diff --git a/.appveyor.yml b/.appveyor.yml index d90d4ba724..f4f56fa159 100644 --- a/.appveyor.yml +++ b/.appveyor.yml @@ -1,3 +1,5 @@ +skip_branch_with_pr: true + environment: matrix: - LIB_TYPE: shared diff --git a/.gitignore b/.gitignore index f883af441e..1ee6c82355 100644 --- a/.gitignore +++ b/.gitignore @@ -54,8 +54,10 @@ GPATH GRTAGS GTAGS -# Windows Build -build/* +# cmake builds +build_*/* + +# Windows build bin/* *.dll *.lib diff --git a/.travis.yml b/.travis.yml index a61a879fa1..555e9a11a2 100644 --- a/.travis.yml +++ b/.travis.yml @@ -48,6 +48,13 @@ matrix: CC=aarch64-linux-gnu-gcc CXX=aarch64-linux-gnu-g++ \ PACKAGES="gcc-aarch64-linux-gnu g++-aarch64-linux-gnu libc6-dev-arm64-cross qemu-system-arm qemu-user" \ TESTSUITE_WRAPPER="qemu-aarch64 -L /usr/aarch64-linux-gnu/" + # Apple M1 (firestorm) build and fast testsuite (qemu) + - os: linux + compiler: aarch64-linux-gnu-gcc + env: OOT=0 TEST=FAST SDE=0 THR="none" CONF="firestorm" \ + CC=aarch64-linux-gnu-gcc CXX=aarch64-linux-gnu-g++ \ + PACKAGES="gcc-aarch64-linux-gnu g++-aarch64-linux-gnu libc6-dev-arm64-cross qemu-system-arm qemu-user" \ + TESTSUITE_WRAPPER="qemu-aarch64 -L /usr/aarch64-linux-gnu/" # armsve build and fast testsuite (qemu) - os: linux compiler: aarch64-linux-gnu-gcc-10 diff --git a/CMakeLists.txt b/CMakeLists.txt index 787f831452..75732f8d0e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,6 +1,38 @@ -##Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved.## +#[=[ -cmake_minimum_required(VERSION 3.15.0) + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +]=] + +cmake_minimum_required(VERSION 3.22.0) if(WIN32) project(AOCL-LibBlis LANGUAGES C CXX) else() @@ -34,7 +66,7 @@ if(WIN32) else() set(BLIS_CONFIG_FAMILY "" CACHE STRING "Set the configuration family for which the BLIS library will be built.") endif() -set_property(CACHE BLIS_CONFIG_FAMILY PROPERTY STRINGS "auto" "generic" "zen" "zen2" "zen3" "zen4" "amdzen") +set_property(CACHE BLIS_CONFIG_FAMILY PROPERTY STRINGS "auto" "generic" "zen" "zen2" "zen3" "zen4" "zen5" "amdzen") # Throw an error if CMake was configured with a configuration which is not enabled yet. if(NOT ((BLIS_CONFIG_FAMILY STREQUAL auto) OR (BLIS_CONFIG_FAMILY STREQUAL generic) OR @@ -42,10 +74,11 @@ if(NOT ((BLIS_CONFIG_FAMILY STREQUAL auto) OR (BLIS_CONFIG_FAMILY STREQUAL zen2) OR (BLIS_CONFIG_FAMILY STREQUAL zen3) OR (BLIS_CONFIG_FAMILY STREQUAL zen4) OR + (BLIS_CONFIG_FAMILY STREQUAL zen5) OR (BLIS_CONFIG_FAMILY STREQUAL amdzen))) message(FATAL_ERROR "Configuration for ${BLIS_CONFIG_FAMILY} is not supported. \ Please re-run cmake and specify one of the following configurations for BLIS_CONFIG_FAMILY: \ - auto, zen, zen2, zen3, zen4, amdzen, generic.") + auto, zen, zen2, zen3, zen4, zen5, amdzen, generic.") endif() # automatic hardware detection @@ -69,7 +102,7 @@ if(BLIS_CONFIG_FAMILY STREQUAL "auto") COMPILE_DEFINITIONS -I${frame_include} -I${base_include} -I${thread_include} -DBLIS_CONFIGURETIME_CPUID -DBLIS_CONFIG_SKX -DBLIS_CONFIG_KNL -DBLIS_CONFIG_HASWELL -DBLIS_CONFIG_SANDYBRIDGE -DBLIS_CONFIG_PENRYN - -DBLIS_CONFIG_ZEN4 -DBLIS_CONFIG_ZEN3 -DBLIS_CONFIG_ZEN2 -DBLIS_CONFIG_ZEN + -DBLIS_CONFIG_ZEN5 -DBLIS_CONFIG_ZEN4 -DBLIS_CONFIG_ZEN3 -DBLIS_CONFIG_ZEN2 -DBLIS_CONFIG_ZEN -DBLIS_CONFIG_EXCAVATOR -DBLIS_CONFIG_STEAMROLLER -DBLIS_CONFIG_PILEDRIVER -DBLIS_CONFIG_BULLDOZER -DBLIS_CONFIG_THUNDERX2 -DBLIS_CONFIG_CORTEXA57 -DBLIS_CONFIG_CORTEXA15 -DBLIS_CONFIG_CORTEXA9 @@ -81,7 +114,8 @@ if(BLIS_CONFIG_FAMILY STREQUAL "auto") if( NOT(${HARDWARE_ARCH} STREQUAL zen OR ${HARDWARE_ARCH} STREQUAL zen2 OR ${HARDWARE_ARCH} STREQUAL zen3 OR - ${HARDWARE_ARCH} STREQUAL zen4) ) + ${HARDWARE_ARCH} STREQUAL zen4 OR + ${HARDWARE_ARCH} STREQUAL zen5) ) set(BLIS_CONFIG_FAMILY "generic") message(WARNING "Only AMD zen architectures are supported. \ Detected ${HARDWARE_ARCH} hardware. Defaulting to generic configuration.") @@ -142,6 +176,8 @@ foreach(KERN ${KERNEL_LIST}) set(KERNEL_LIST_DEFINES "${KERNEL_LIST_DEFINES}#define BLIS_KERNELS_${UCONF}\n") endforeach() + + #------------------------------------ # Option Setting #------------------------------------ @@ -151,9 +187,11 @@ if(WIN32) option(ENABLE_UPPERCASE_API "Export APIs with uppercase." OFF) # Setting path to OpenMP runtime. set(OpenMP_libomp_LIBRARY "C:/Program Files/LLVM/lib/libomp.lib" CACHE STRING "openmp library path") +else(WIN32) + set(OpenMP_libomp_LIBRARY "" CACHE STRING "openmp library path") endif() # Debug & Release flags option setting is only available for Linux. On Windows the default flags are used. -if(NOT WIN32) +if(NOT MSVC) set(ENABLE_DEBUG "off" CACHE STRING "Enable debugging symbols in the library.") set_property(CACHE ENABLE_DEBUG PROPERTY STRINGS "off" "noopt" "opt") if( NOT ((ENABLE_DEBUG STREQUAL "off") OR (ENABLE_DEBUG STREQUAL "noopt") OR (ENABLE_DEBUG STREQUAL "opt")) ) @@ -178,8 +216,16 @@ if(NOT WIN32) set(CMAKE_BUILD_TYPE "") endif() endif() -# Build shared libraries by default -option(BUILD_SHARED_LIBS "Build shared libraries (.dll/.so) instead of static ones (.lib/.a)" ON) + +if(WIN32) + # Build shared libraries only by default + option(BUILD_SHARED_LIBS "Build shared libraries (.dll/.lib) instead of static ones (.lib/.a)" ON) +else() + # Build both shared and static libraries by default + option(BUILD_SHARED_LIBS "Build shared libraries (.dll/.lib)" ON) + option(BUILD_STATIC_LIBS "Build static libraries (.lib/.a)" ON) + option(TEST_WITH_SHARED "If both static and shared libraries are build, run the tests linking the shared library." OFF) +endif() option(ENABLE_SYSTEM "Check if we are building with or without operating system support" ON) set(ENABLE_THREADING "no" CACHE STRING "the threading flag") if(WIN32) @@ -273,12 +319,23 @@ else() endif() set(RENAME_BLIS_ARCH_TYPE "BLIS_ARCH_TYPE" CACHE STRING "BLIS_ARCH_TYPE env var renamed to supplied value") set(RENAME_BLIS_MODEL_TYPE "BLIS_MODEL_TYPE" CACHE STRING "BLIS_MODEL_TYPE env var renamed to supplied value") -if(NOT WIN32) - set(ENABLE_ADDON "" CACHE STRING "Configure with specific addons using a ';'-separated list") +if(ENABLE_ADDON) + if((NOT WIN32) OR + (WIN32 AND (${CMAKE_CXX_COMPILER_ID} STREQUAL "Clang") AND NOT (${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS "18.0"))) + set(ENABLE_ADDON "" CACHE STRING "Configure with specific addons using a ';'-separated list") + else() + message(FATAL_ERROR "On Windows, aocl_gemm addon requires Clang version at least 18.0. Current version: ${CMAKE_CXX_COMPILER_VERSION}") + endif() endif() set(ENABLE_SANDBOX "" CACHE STRING "Enable a separate sandbox implementation of gemm.") # Do not let ENABLE_SANDBOX appear on cmake-gui since the functionality is not yet implemented. mark_as_advanced(ENABLE_SANDBOX) +if(NOT WIN32) + option(ENABLE_COVERAGE "Enable Code Coverage using gcov(only GCC/Debug build)" OFF) +endif() +if(NOT WIN32) + option(ENABLE_ASAN "Enable Address Sanitizer (Debug build)" OFF) +endif() #------------------------------------ # Check memkind @@ -314,7 +371,7 @@ file(STRINGS ${CMAKE_SOURCE_DIR}/version VERSION) # Get timestamp. string(TIMESTAMP BUILD_DATE "%Y%m%d") # Update using the timestamp. -set(VERSION_STRING "AOCL-BLIS ${VERSION} Build ${BUILD_DATE}") +set(VERSION_STRING "AOCL-BLAS ${VERSION} Build ${BUILD_DATE}") # Initial message. message(STATUS "Starting configuration of BLIS ${VERSION_STRING}.") # Check if the user requested a custom version string. @@ -337,23 +394,57 @@ list(GET SO_VERSION 1 SO_VERSION_MINOR) #------------------------------------ include(CMakePrintHelpers) message(STATUS "Printing CMake Configuration Options...") -cmake_print_variables(ENABLE_DEBUG) -# Initialize debug type, using the corresponding cache variable. -set(DEBUG_TYPE ${ENABLE_DEBUG}) -if(ENABLE_DEBUG STREQUAL "off") - message(" Debug symbols disabled.") -elseif(ENABLE_DEBUG STREQUAL "opt") - message(" Enabling debug symbols with optimizations.") -else() #ENABLE_DEBUG=noopt - message(" Enabling debug symbols; optimizations disabled.") -endif() -cmake_print_variables(BUILD_SHARED_LIBS) -if(BUILD_SHARED_LIBS) - message(" Building BLIS as a shared library.") - set(ENABLE_SHARED_01 1) +if(NOT MSVC) + cmake_print_variables(ENABLE_DEBUG) + # Initialize debug type, using the corresponding cache variable. + set(DEBUG_TYPE ${ENABLE_DEBUG}) + if(ENABLE_DEBUG STREQUAL "off") + message(" Debug symbols disabled.") + elseif(ENABLE_DEBUG STREQUAL "opt") + message(" Enabling debug symbols with optimizations.") + else() #ENABLE_DEBUG=noopt + message(" Enabling debug symbols; optimizations disabled.") + endif() +endif() +if(WIN32) + cmake_print_variables(BUILD_SHARED_LIBS) + if(BUILD_SHARED_LIBS) + message(" Building BLIS as a shared library.") + set(ENABLE_SHARED_01 1) + set(TEST_WITH_SHARED ON) + else() + message(" Building BLIS as a static library.") + set(ENABLE_SHARED_01 0) + set(BUILD_STATIC_LIBS ON) + set(TEST_WITH_SHARED OFF) + endif() else() - message(" Building BLIS as a static library.") - set(ENABLE_SHARED_01 0) + cmake_print_variables(BUILD_SHARED_LIBS) + cmake_print_variables(BUILD_STATIC_LIBS) + if(BUILD_SHARED_LIBS AND BUILD_STATIC_LIBS) + message(" Building BLIS as both static and shared libraries.") + set(ENABLE_SHARED_01 1) + cmake_print_variables(TEST_WITH_SHARED) + if(TEST_WITH_SHARED) + message(" Testing using shared library.") + else() + message(" Testing using static library.") + endif() + elseif(BUILD_STATIC_LIBS AND NOT BUILD_SHARED_LIBS) + message(" Building BLIS as a static library (shared library disabled).") + set(ENABLE_SHARED_01 0) + set(TEST_WITH_SHARED OFF) + cmake_print_variables(TEST_WITH_SHARED) + message(" Testing using static library.") + elseif(BUILD_SHARED_LIBS AND NOT BUILD_STATIC_LIBS) + message(" Building BLIS as a shared library (static library disabled).") + set(ENABLE_SHARED_01 1) + set(TEST_WITH_SHARED ON) + cmake_print_variables(TEST_WITH_SHARED) + message(" Testing using shared library.") + else() + message(FATAL_ERROR "Both static and shared libraries were disabled. Please enable one (or both) to continue.") + endif() endif() if(NOT WIN32) cmake_print_variables(EXPORT_SHARED) @@ -429,7 +520,8 @@ if(ENABLE_SBA_POOLS) message(" Internal memory pools for small blocks are enabled.") set(ENABLE_SBA_POOLS_01 1) else() - message(" Internal memory pools for small blocks are disabled.") + #message(" Internal memory pools for small blocks are disabled.") + message(FATAL_ERROR "Disabling memory pools for small blocks is currently disabled, awaiting fixes to this functionality.") set(ENABLE_SBA_POOLS_01 0) endif() cmake_print_variables(ENABLE_MEM_TRACING) @@ -507,7 +599,7 @@ if(ENABLE_MIXED_DT) else() message(" Mixed datatype optimizations requiring extra memory are disabled.") set(ENABLE_MIXED_DT_EXTRA_MEM_01 0) - endif() + endif() set(ENABLE_MIXED_DT_01 1) else() message(" Mixed datatype support is disabled.") @@ -574,24 +666,22 @@ if((INT_TYPE_SIZE STREQUAL "32") AND (BLAS_INT_TYPE_SIZE STREQUAL "64")) To avoid the possibility of truncation, we do not allow use of 64-bit integers in the BLAS API with 32-bit integers in BLIS. \ Please use a different configuration of integers.") endif() -if(NOT WIN32) - cmake_print_variables(ENABLE_ADDON) - if(ENABLE_ADDON STREQUAL "") - message(" Configuring with no addons.") - set(ENABLE_ADDONS_01 0) - else() - # Remove duplicates in the addon list, if they exist. - list(REMOVE_DUPLICATES ENABLE_ADDON) - message(" Configuring with addons:") - foreach(ADDON ${ENABLE_ADDON}) - message(" ${ADDON}") - if(NOT (EXISTS ${CMAKE_SOURCE_DIR}/addon/${ADDON})) - message(FATAL_ERROR "Requested addon sub-directory does not exist! Cannot continue. \ - *** Please verify addon existence and name.") - endif() - endforeach() - set(ENABLE_ADDONS_01 1) - endif() +cmake_print_variables(ENABLE_ADDON) +if(ENABLE_ADDON STREQUAL "") + message(" Configuring with no addons.") + set(ENABLE_ADDONS_01 0) +else() + # Remove duplicates in the addon list, if they exist. + list(REMOVE_DUPLICATES ENABLE_ADDON) + message(" Configuring with addons:") + foreach(ADDON ${ENABLE_ADDON}) + message(" ${ADDON}") + if(NOT (EXISTS ${CMAKE_SOURCE_DIR}/addon/${ADDON})) + message(FATAL_ERROR "Requested addon sub-directory does not exist! Cannot continue. \ + *** Please verify addon existence and name.") + endif() + endforeach() + set(ENABLE_ADDONS_01 1) endif() cmake_print_variables(ENABLE_SANDBOX) if(ENABLE_SANDBOX STREQUAL "") @@ -650,6 +740,40 @@ if(WIN32) message(" Export APIs with lowercase.") endif() endif() +if(NOT WIN32) + cmake_print_variables(ENABLE_COVERAGE) + if(ENABLE_COVERAGE) + if(NOT (${CMAKE_C_COMPILER_ID} MATCHES "GNU")) + message(FATAL_ERROR "Coverage is only supported for GNU/Linux GCC Debug build") + set(ENABLE_COVERAGE OFF) + endif() + if(NOT(ENABLE_DEBUG STREQUAL "noopt")) + message(WARNING "Coverage is only supported for debug builds, but ENABLE_DEBUG=noopt was set.\ + Disabling optimizations to generate the code coverage report.") + set(ENABLE_DEBUG "noopt") + set(DEBUG_TYPE ${ENABLE_DEBUG}) + endif() + endif() + if(ENABLE_COVERAGE) + message(" Code Coverage is enabled.") + else() + message(" Code Coverage is disabled.") + endif() +endif() + +if(NOT WIN32) + cmake_print_variables(ENABLE_ASAN) + if(ENABLE_ASAN) + if(NOT (${CMAKE_C_COMPILER_ID} MATCHES "Clang")) + message(FATAL_ERROR "ASAN is supported only for Clang/Linux" ) + endif() + endif() + if(ENABLE_ASAN) + message(" Address Sanitizer is enabled.") + else() + message(" Address Sanitizer is disabled.") + endif() +endif() # Initialize threading model, using the corresponding cache variable. set(THREADING_MODEL ${ENABLE_THREADING}) @@ -668,11 +792,6 @@ configure_file(build/cmake/bli_config.h.in ${PROJECT_BINARY_DIR}/bli_config.h) # Create a list of #includes, one for each addon in addon_list. set(ADDON_LIST_INCLUDES "") foreach(ADDON ${ENABLE_ADDON}) - if(ADDON STREQUAL "aocl_gemm") - if(("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") AND (CMAKE_CXX_COMPILER_VERSION VERSION_LESS 11.0.0)) - message(FATAL_ERROR "aocl_gemm addon requires a gcc version 11.0.0 or higher.") - endif() - endif() set(ADDON_HEADER "\"${ADDON}.h\"") set(ADDON_LIST_INCLUDES "${ADDON_LIST_INCLUDES}#include ${ADDON_HEADER}\n") endforeach() @@ -773,8 +892,9 @@ add_custom_command(OUTPUT ${PROJECT_BINARY_DIR}/include/${BLIS_CONFIG_FAMILY}/bl "${PROJECT_BINARY_DIR}/include/${BLIS_CONFIG_FAMILY}/blis.h" "${PROJECT_BINARY_DIR}/include" "${ALL_HEADER_PATHS_STRING}" - COMMENT "Generating monolithic blis header file: ${CMAKE_SOURCE_DIR}/include/${BLIS_CONFIG_FAMILY}/blis.h" + COMMENT "Generating monolithic blis header file: ${PROJECT_BINARY_DIR}/include/${BLIS_CONFIG_FAMILY}/blis.h" DEPENDS ${ALL_HEADER_FILES_LIST} + VERBATIM ) add_custom_target(flat-header DEPENDS ${PROJECT_BINARY_DIR}/include/${BLIS_CONFIG_FAMILY}/blis.h) #-------------------------------------------- @@ -788,8 +908,9 @@ if(ENABLE_CBLAS) "${PROJECT_BINARY_DIR}/include/${BLIS_CONFIG_FAMILY}/cblas.h" "${PROJECT_BINARY_DIR}/${include}" "${ALL_HEADER_PATHS_STRING}" - COMMENT "Generating monolithic cblas header file: ${CMAKE_SOURCE_DIR}/include/${BLIS_CONFIG_FAMILY}/cblas.h" + COMMENT "Generating monolithic cblas header file: ${PROJECT_BINARY_DIR}/include/${BLIS_CONFIG_FAMILY}/cblas.h" DEPENDS ${ALL_HEADER_FILES_LIST} + VERBATIM ) add_custom_target(flat-cblas-header DEPENDS ${PROJECT_BINARY_DIR}/include/${BLIS_CONFIG_FAMILY}/cblas.h) endif() @@ -804,7 +925,7 @@ endif() # Define the external libraries we may potentially need at link-time. # Add libm only on Linux and only if Intel compiler is not used. -if((NOT WIN32) AND (NOT ("${CMAKE_CXX_COMPILER_ID}" MATCHES "Intel"))) +if((NOT WIN32) AND (NOT ("${CMAKE_C_COMPILER_ID}" MATCHES "Intel"))) set(LIBM -lm) endif() set(LIBMEMKIND -lmemkind) @@ -821,6 +942,22 @@ if(ENABLE_MEMKIND STREQUAL "yes") list(APPEND LDFLAGS ${LIBMEMKIND}) endif() +#-------------------------------------------- +# Code-coverage flags +#-------------------------------------------- +if(ENABLE_COVERAGE AND (NOT WIN32)) + set(COVERAGE_FLAGS "-fprofile-arcs -ftest-coverage") + list(APPEND CMAKE_C_FLAGS ${COVERAGE_FLAGS}) +endif() + +#-------------------------------------------- +# Address Sanitizer flags +#-------------------------------------------- +if(ENABLE_ASAN AND (NOT WIN32)) + set(ASAN_FLAGS "-g -fsanitize=address") + list(APPEND CMAKE_C_FLAGS ${ASAN_FLAGS}) +endif() + #-------------------------------------------- # Configuration-agnostic flags #-------------------------------------------- @@ -834,7 +971,7 @@ if(NOT WIN32) endif() # Disable tautological comparision warnings in clang. -if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang") +if("${CMAKE_C_COMPILER_ID}" MATCHES "Clang") list(APPEND CWARNFLAGS -Wno-tautological-compare -Wno-pass-failed) endif() @@ -985,6 +1122,9 @@ foreach(ker ${KERNEL_LIST}) if(TARGET ${ker}_KERNELS) list(APPEND OBJECT_LIBRARIES $) endif() + if(TARGET ${ker}_LPGEMM_KERNELS) + list(APPEND OBJECT_LIBRARIES $) + endif() endforeach() # Add objects for reference kernels. foreach(conf ${CONFIG_LIST}) @@ -1026,7 +1166,7 @@ endif() # --- Library name and local paths --- # From old CMake if(WIN32) - add_definitions(-D_CRT_SECURE_NO_WARNINGS) + add_definitions(-D_CRT_SECURE_NO_WARNINGS) add_definitions(-D_CRT_SECURE_NO_DEPRECATE) set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /Oi") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /MP${CMake_MSVC_PARALLEL}") @@ -1050,39 +1190,94 @@ if(NOT (THREADING_MODEL STREQUAL "no")) endif() endif() +set(LIBBLIS_STATIC ${LIBBLIS}) +set(LIBBLIS_SHARED ${LIBBLIS}) +if(WIN32) + string(APPEND LIBBLIS_SHARED -dll) +endif() + +# Set directories for installation of libraries and header files. +set(LIB_DIR ${CMAKE_INSTALL_PREFIX}/lib) +set(INCLUDE_DIR ${CMAKE_INSTALL_PREFIX}/include) +# Set LDFLAGS to be replaced in pc file. +set(LDFLAGS_STRING ${LDFLAGS}) +# Add OpenMP flags as required. +if(THREADING_MODEL STREQUAL "openmp") + list(APPEND LDFLAGS_STRING "${OpenMP_C_FLAGS}") +endif() +string(JOIN " " LDFLAGS_STRING ${LDFLAGS_STRING}) +if(NOT WIN32) + configure_file( + ${CMAKE_SOURCE_DIR}/build/cmake/aocl-blas.pc.in + ${CMAKE_BINARY_DIR}/aocl-blas.pc + @ONLY + ) +endif() +include(GNUInstallDirs) + if(BUILD_SHARED_LIBS) - if(WIN32) - string(APPEND LIBBLIS -dll) - endif() # Build shared library. - add_library(libblis SHARED ${OBJECT_LIBRARIES}) - target_link_libraries(libblis PRIVATE ${LDFLAGS}) - set_target_properties(libblis PROPERTIES LINKER_LANGUAGE C VERSION ${VERSION} SOVERSION ${SO_VERSION_MAJOR}) - set_target_properties(libblis PROPERTIES POSITION_INDEPENDENT_CODE ON) + add_library(libblis-shared SHARED ${OBJECT_LIBRARIES}) + target_link_libraries(libblis-shared PRIVATE ${LDFLAGS}) + set_target_properties(libblis-shared PROPERTIES LINKER_LANGUAGE C VERSION ${VERSION} SOVERSION ${SO_VERSION_MAJOR}) + set_target_properties(libblis-shared PROPERTIES POSITION_INDEPENDENT_CODE ON) if(THREADING_MODEL STREQUAL "openmp") - target_link_libraries(libblis PRIVATE OpenMP::OpenMP_C) + if((NOT ${OpenMP_libomp_LIBRARY} STREQUAL "") AND (NOT WIN32)) + target_link_libraries(libblis-shared PRIVATE ${OpenMP_libomp_LIBRARY}) + else() + target_link_libraries(libblis-shared PRIVATE OpenMP::OpenMP_C) + endif() endif() -else() + add_dependencies(libblis-shared flat-header) + if(ENABLE_CBLAS) + add_dependencies(libblis-shared flat-cblas-header) + endif() + # Add headers as a property to the library. + set_target_properties(libblis-shared PROPERTIES PUBLIC_HEADER "${BLIS_PUBLIC_HEADERS}") + set_target_properties(libblis-shared PROPERTIES OUTPUT_NAME ${LIBBLIS_SHARED}) + # Install targets for shared. + install(TARGETS libblis-shared LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX}/lib + ARCHIVE DESTINATION ${CMAKE_INSTALL_PREFIX}/lib + RUNTIME DESTINATION ${CMAKE_INSTALL_PREFIX}/lib + PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_PREFIX}/include) + set(libblis_depends libblis-shared) +endif() +if(BUILD_STATIC_LIBS OR NOT BUILD_SHARED_LIBS) # Build static library. - add_library(libblis STATIC ${OBJECT_LIBRARIES}) - set_target_properties(libblis PROPERTIES LINKER_LANGUAGE C) + add_library(libblis-static STATIC ${OBJECT_LIBRARIES}) + set_target_properties(libblis-static PROPERTIES LINKER_LANGUAGE C) + # Setting this for static to fix issues where test programs built with gcc 9.4.0 fail to link versions of BLIS build with AOCC 4.0.0. + set_target_properties(libblis-static PROPERTIES POSITION_INDEPENDENT_CODE ON) + add_dependencies(libblis-static flat-header) + if(ENABLE_CBLAS) + add_dependencies(libblis-static flat-cblas-header) + endif() + # Add headers as a property to the library. + set_target_properties(libblis-static PROPERTIES PUBLIC_HEADER "${BLIS_PUBLIC_HEADERS}") + set_target_properties(libblis-static PROPERTIES OUTPUT_NAME ${LIBBLIS_STATIC}) + # Install targets. + install(TARGETS libblis-static LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX}/lib + ARCHIVE DESTINATION ${CMAKE_INSTALL_PREFIX}/lib + RUNTIME DESTINATION ${CMAKE_INSTALL_PREFIX}/lib + PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_PREFIX}/include) + list(APPEND libblis_depends libblis-static) endif() -add_dependencies(libblis flat-header) -if(ENABLE_CBLAS) - add_dependencies(libblis flat-cblas-header) + +if(NOT WIN32) + # Install package-config file. + install(FILES ${CMAKE_BINARY_DIR}/aocl-blas.pc DESTINATION ${CMAKE_INSTALL_PREFIX}/lib/pkgconfig) endif() -# Add headers as a property to the library. -set_target_properties(libblis PROPERTIES PUBLIC_HEADER "${BLIS_PUBLIC_HEADERS}") -set_target_properties(libblis PROPERTIES OUTPUT_NAME ${LIBBLIS}) -# Install targets. -install(TARGETS libblis LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX}/lib - ARCHIVE DESTINATION ${CMAKE_INSTALL_PREFIX}/lib - RUNTIME DESTINATION ${CMAKE_INSTALL_PREFIX}/lib - PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_PREFIX}/include) +# Set libblis to the shared or static libblis depending on the option setting. +if(TEST_WITH_SHARED) + set(libblis_link libblis-shared) +else() + set(libblis_link libblis-static) +endif() # --- Primary targets --- -add_custom_target(libs DEPENDS libblis) +add_custom_target(libblis DEPENDS ${libblis_depends}) +add_custom_target(libs DEPENDS ${libblis}) # Multiple BLIS API testing targets. Result files are generated in ${CMAKE_BINARY_DIR}/testsuite. add_subdirectory(testsuite EXCLUDE_FROM_ALL) @@ -1095,16 +1290,49 @@ if(ENABLE_BLAS) add_subdirectory(blastest EXCLUDE_FROM_ALL) endif() +if(ENABLE_BLAS AND WIN32 AND BUILD_SHARED_LIBS) +set(DETAILED_BLATEST_MESSAGE "Details: Level2 and level3 API tests define a custom version of xerbla_() to test the error codes. \ +On Linux and on Windows/static versions of BLIS library, the custom xerbla_() gets called inside the library\ +due to the linking process and all tests work. On Windows/shared version of the library, symbol resolution\ +happens at load-time so the blis implementation of xerbla_() gets called instead of the custom one. \ +That causes errors when the tests are run which are independent of the BLIS library. \ +Please use static builds only on Windows.") +endif() + # Add generic testing target `test`. set(available_testsuites checkblis) if(ENABLE_BLAS) list(APPEND available_testsuites checkblas) endif() -add_custom_target(test DEPENDS ${available_testsuites}) + +if(WIN32 AND BUILD_SHARED_LIBS) + if(ENABLE_BLAS) + set(TEST_WARNING "Target `test` depends only on target `checkblis` because `checkblas` target is not available on Windows for shared builds of BLIS. ") + endif() +else() + if(ENABLE_BLAS) + list(APPEND available_testsuites checkblas) + endif() +endif() +add_custom_target(tests + DEPENDS ${available_testsuites} + COMMENT "Running target `test`. ${TEST_WARNING} ${DETAILED_BLATEST_MESSAGE}") # Add generic testing target `check`. set(available_testsuites checkblis-fast) -if(ENABLE_BLAS) - list(APPEND available_testsuites checkblas) +if(WIN32 AND BUILD_SHARED_LIBS) + if(ENABLE_BLAS) + set(CHECK_WARNING "Target `check` depends only on target `checkblis-fast` because `checkblas` target is not available on Windows for shared builds of BLIS. ") + endif() +else() + if(ENABLE_BLAS) + list(APPEND available_testsuites checkblas) + endif() endif() -add_custom_target(check DEPENDS ${available_testsuites}) \ No newline at end of file +add_custom_target(check + DEPENDS ${available_testsuites} + COMMENT "Running target `check`. ${CHECK_WARNING} ${DETAILED_BLATEST_MESSAGE}") + +add_subdirectory(bench EXCLUDE_FROM_ALL) + +add_subdirectory(test EXCLUDE_FROM_ALL) diff --git a/CMakePresets.json b/CMakePresets.json new file mode 100644 index 0000000000..1fd45a56c3 --- /dev/null +++ b/CMakePresets.json @@ -0,0 +1,16 @@ +{ + "version": 6, + "cmakeMinimumRequired": { + "major": 3, + "minor": 25, + "patch": 0 + }, + "include": [ + "build/cmake/presets/linux-make-clang.json", + "build/cmake/presets/linux-make-gcc.json", + "build/cmake/presets/linux-make.json", + "build/cmake/presets/linux-ninja.json", + "build/cmake/presets/win-msvc.json", + "build/cmake/presets/win-ninja.json" + ] +} diff --git a/CREDITS b/CREDITS index bcbc889ce0..7f74ad29f6 100644 --- a/CREDITS +++ b/CREDITS @@ -92,6 +92,7 @@ but many others have contributed code and feedback, including Nathaniel Smith @njsmith Shaden Smith @ShadenSmith Tyler Smith @tlrmchlsmth (The University of Texas at Austin) + Snehith @ArcadioN09 Paul Springer @springer13 (RWTH Aachen University) Adam J. Stewart @adamjstewart (University of Illinois at Urbana-Champaign) Vladimir Sukarev diff --git a/LICENSE b/LICENSE index f05ca1125c..9e6434dc38 100644 --- a/LICENSE +++ b/LICENSE @@ -1,43 +1,129 @@ -NOTE: Portions of this project's code are copyrighted by - - The University of Texas at Austin - -while other portions are copyrighted by - - Hewlett Packard Enterprise Development LP - Advanced Micro Devices, Inc. - -with some overlap. Please see file-level license headers for file-specific -copyright info. All parties provide their portions of the code under the -3-clause BSD license, found below. - ---- - -Copyright (C) 2018, The University of Texas at Austin -Copyright (C) 2016, Hewlett Packard Enterprise Development LP -Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are -met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +This summary and the license information provided below is for reference purposes and is not intended to be a comprehensive list of all copyright notices or license terms and conditions applicable to BLAS Library. Please refer to the source code files in BLAS Library for all copyrights and licenses. +AMD copyrighted code (BSD-3-clause) +Copyright Statements +Copyright (C) 2008-2022,Advanced Micro Devices, Inc. All rights reserved. +Copyright (c) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. +License Text http://spdx.org/licenses/BSD-3-Clause +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: +* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. +* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. +* Neither the name of the copyright holders nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +AMD copyrighted code (MIT) +Copyright Statements +Copyright (c) 2019 - present Advanced Micro Devices, Inc. All rights reserved. +License Text http://spdx.org/licenses/MIT +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +flame-blis v-u (BSD-3-Clause) +Copyright Statements +Copyright (C) 2017, Advanced Micro Devices, Inc. +Copyright (C) 2014, The University of Texas at Austin +License Text http://spdx.org/licenses/BSD-3-Clause +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: +* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. +* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. +* Neither the name of the copyright holders nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +flame-blis v0.8.1 (BSD-3-Clause) +Attribution Statements +NOTE: Portions of this project's code are copyrighted by The University of Texas at Austin while other portions are copyrighted by +Advanced Micro Devices, Inc. with some overlap. Please see file-level license headers for file-specific copyright info. +Copyright Statements +Copyright (C) 2018, Advanced Micro Devices, Inc. +Copyright (C) 2014, The University of Texas at Austin +License Text http://spdx.org/licenses/BSD-3-Clause +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: +* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. +* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. +* Neither the name of the copyright holders nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +flame-libflame v3.1 (BSD-3-Clause) +Attribution Statements +Select parts of libflame's f2c implementation were taken from: +https://github.com/juanjosegarciaripoll/f2c +which uses the following license: +Copyright (C) 1990 - 1997 by AT&T, Lucent Technologies and Bellcore. +Permission to use, copy, modify, and distribute this software +and its documentation for any purpose and without fee is hereby +granted, provided that the above copyright notice appear in all +copies and that both that the copyright notice and this +permission notice and warranty disclaimer appear in supporting +documentation, and that the names of AT&T, Bell Laboratories, +Lucent or Bellcore or any of their entities not be used in +advertising or publicity pertaining to distribution of the +software without specific, written prior permission. +AT&T, Lucent and Bellcore disclaim all warranties with regard to +this software, including all implied warranties of +merchantability and fitness. In no event shall AT&T, Lucent or +Bellcore be liable for any special, indirect or consequential +damages or any damages whatsoever resulting from loss of use, +data or profits, whether in an action of contract, negligence or +other tortious action, arising out of or in connection with the +use or performance of this software. +Copyright Statements +Copyright (C) 2014, The University of Texas at Austin +Copyright (C) 1990 - 1997 by AT&T, Lucent Technologies and Bellcore. +License Text http://spdx.org/licenses/BSD-3-Clause +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: +* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. +* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. +* Neither the name of the copyright holders nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE + +google-googletest v-u (BSD-3-Clause) +Copyright Statements +Copyright 2008, Google Inc. +All rights reserved. +License Text http://spdx.org/licenses/BSD-3-Clause +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: +* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. +* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. +* Neither the name of the copyright holders nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +herumi-xbyak v-u (BSD-3-Clause) +Copyright Statements +Copyright (c) 2007 MITSUNARI Shigeo All rights reserved. +License Text https://spdx.org/licenses/BSD-3-Clause.html +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: +* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. +* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. +* Neither the name of the copyright holders nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +jothepro-doxygen-awesome-css v-u (MIT) +Copyright Statements +Copyright (c) 2021 - 2022 jothepro +License Text http://spdx.org/licenses/MIT +MIT License +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + + +lzma-sdk v-u (PD) +Copyright Statements +https://www.nuget.org/packages/LZMA-SDK +License Text +LZMA SDK is placed in the public domain. +Anyone is free to copy, modify, publish, use, compile, sell, or distribute the original LZMA SDK code, either in source code form or as a compiled binary, for any purpose, commercial or non-commercial, and by any means. + diff --git a/Makefile b/Makefile index 4c4c01ffd0..7f46c500dd 100644 --- a/Makefile +++ b/Makefile @@ -5,7 +5,7 @@ # libraries. # # Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -191,6 +191,13 @@ gen-obj-paths-from-src = $(foreach ch, $(1), \ # directories. MK_CONFIG_OBJS := $(call gen-obj-paths-from-src,$(CONFIG_SRC_SUFS),$(MK_CONFIG_SRC),$(CONFIG_PATH),$(BASE_OBJ_CONFIG_PATH)) +MK_KERNELS_LPGEMM_SRC := $(filter ./kernels/zen/lpgemm/%.c, $(MK_KERNELS_SRC)) +MK_KERNELS_LPGEMM_SRC += $(filter ./kernels/zen4/lpgemm/%.c, $(MK_KERNELS_SRC)) +MK_KERNELS_SRC := $(filter-out $(MK_KERNELS_LPGEMM_SRC),$(MK_KERNELS_SRC)) +ifeq ($(filter aocl_gemm, $(ADDON_LIST)), aocl_gemm) + MK_KERNELS_LPGEMM_OBJS := $(call gen-obj-paths-from-src,$(KERNELS_SRC_SUFS),$(MK_KERNELS_LPGEMM_SRC),$(KERNELS_PATH),$(BASE_OBJ_KERNELS_PATH)) +endif + # Generate object file paths for architecture-specific kernel source code. # We target only .c, .s, and .S files. Note that MK_KERNELS_SRC is already # limited to the kernel source corresponding to the kernel sets in @@ -220,10 +227,29 @@ MK_ADDON_KERS_SRC := $(foreach addon, $(ADDON_LIST), \ $(filter $(ADDON_PATH)/$(addon)/$(KERNELS_DIR)/%, \ $(MK_ADDON_SRC)) \ ) + +# Generate non-kernel list for all addons except aocl_gemm +# We process aocl_gemma addon separately. MK_ADDON_OTHER_SRC := $(foreach addon, $(ADDON_LIST), \ - $(filter-out $(ADDON_PATH)/$(addon)/$(KERNELS_DIR)/%, \ - $(MK_ADDON_SRC)) \ + $(if $(filter-out aocl_gemm,$(addon)), \ + $(filter-out $(ADDON_PATH)/$(addon)/$(KERNELS_DIR)/%, \ + $(MK_ADDON_SRC))) \ ) + +# Pick the .cpp files present in JIT folder only in the following conditions +# 1. when gcc version is older than 11.2 +# 2. when aocl_gemm addon is enabled. +ifeq ($(filter aocl_gemm, $(ADDON_LIST)), aocl_gemm) + ifeq ($(GCC_OT_11_2_0),no) + MK_AOCL_GEMM_OTHER_SRC := $(filter-out $(ADDON_PATH)/$(aocl_gemm)/$(KERNELS_DIR)/%, \ + $(MK_ADDON_SRC)) + MK_ADDON_OTHER_SRC := $(filter %.c,$(MK_AOCL_GEMM_OTHER_SRC)) + else + MK_ADDON_OTHER_SRC := $(filter-out $(ADDON_PATH)/$(aocl_gemm)/$(KERNELS_DIR)/%, \ + $(MK_ADDON_SRC)) + endif +endif + MK_ADDON_KERS_OBJS := $(call gen-obj-paths-from-src,$(ADDON_SRC_SUFS),$(MK_ADDON_KERS_SRC),$(ADDON_PATH),$(BASE_OBJ_ADDON_PATH)) MK_ADDON_OTHER_OBJS := $(call gen-obj-paths-from-src,$(ADDON_SRC_SUFS),$(MK_ADDON_OTHER_SRC),$(ADDON_PATH),$(BASE_OBJ_ADDON_PATH)) MK_ADDON_OBJS := $(MK_ADDON_KERS_OBJS) $(MK_ADDON_OTHER_OBJS) @@ -264,6 +290,10 @@ MK_BLIS_OBJS := $(MK_CONFIG_OBJS) \ $(MK_ADDON_OBJS) \ $(MK_SANDBOX_OBJS) +ifeq ($(filter aocl_gemm, $(ADDON_LIST)), aocl_gemm) + MK_BLIS_OBJS += $(MK_KERNELS_LPGEMM_OBJS) +endif + # Optionally filter out the BLAS and CBLAS compatibility layer object files. # This is not actually necessary, since each affected file is guarded by C # preprocessor macros, but it but prevents "empty" object files from being @@ -606,6 +636,19 @@ else endif endef +# first argument: a kernel set (name) being targeted (e.g. haswell). +# second argument: the configuration whose CFLAGS we should use in compilation. +# third argument: the kernel file suffix being considered. +define make-kernels-lpgemm-rule +$(BASE_OBJ_KERNELS_PATH)/$(1)/%.o: $(KERNELS_PATH)/$(1)/%.$(3) $(BLIS_H_FLAT) $(MAKE_DEFS_MK_PATHS) +ifeq ($(ENABLE_VERBOSE),yes) + $(CC) $(call get-kernel-lpgemm-cflags-for,$(2)) -c $$< -o $$@ +else + @echo "Compiling $$@" $(call get-kernel-lpgemm-text-for,$(2)) + @$(CC) $(call get-kernel-lpgemm-cflags-for,$(2)) -c $$< -o $$@ +endif +endef + # first argument: a configuration name from the union of config_list and # config_name, used to look up the CFLAGS to use during compilation. # second argument: the C99 addon file suffix being considered. @@ -710,6 +753,10 @@ $(foreach conf, $(CONFIG_LIST), $(eval $(call make-refkern-rule,$(conf)))) $(foreach suf, $(KERNELS_SRC_SUFS), \ $(foreach kset, $(KERNEL_LIST), $(eval $(call make-kernels-rule,$(kset),$(call get-config-for-kset,$(kset)),$(suf))))) +ifeq ($(filter aocl_gemm, $(ADDON_LIST)), aocl_gemm) + $(foreach suf, $(KERNELS_SRC_SUFS), \ + $(foreach kset, $(KERNEL_LIST), $(eval $(call make-kernels-lpgemm-rule,$(kset)/lpgemm,$(call get-config-for-kset,$(kset)),$(suf))))) +endif # Instantiate the build rule for C addon files. Use the CFLAGS for the # configuration family. $(foreach suf, $(ADDON_C99_SUFS), \ @@ -850,20 +897,14 @@ else @$(RANLIB) $@ endif -# first argument: the base name of the BLAS test driver. -define make-blat-rule -$(BASE_EXE_BLASTEST_PATH)/$(1).x: $(BASE_OBJ_BLASTEST_PATH)/$(1).o $(BLASTEST_F2C_LIB) $(LIBBLIS_LINK) +$(BASE_EXE_BLASTEST_PATH)/%.x: $(BASE_OBJ_BLASTEST_PATH)/%.o $(BLASTEST_F2C_LIB) $(LIBBLIS_LINK) @mkdir -p $(BASE_EXE_BLASTEST_PATH) ifeq ($(ENABLE_VERBOSE),yes) - $(LINKER) $(BASE_OBJ_BLASTEST_PATH)/$(1).o $(BLASTEST_F2C_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $$@ + $(LINKER) $< $(BLASTEST_F2C_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@ else - @echo "Linking $$(@F) against '$(notdir $(BLASTEST_F2C_LIB)) $(LIBBLIS_LINK) $(LDFLAGS)'" - @$(LINKER) $(BASE_OBJ_BLASTEST_PATH)/$(1).o $(BLASTEST_F2C_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $$@ + @echo "Linking $@ against '$(notdir $(BLASTEST_F2C_LIB)) $(LIBBLIS_LINK) "$(LDFLAGS)"'" + @$(LINKER) $< $(BLASTEST_F2C_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@ endif -endef - -# Instantiate the rule above for each driver file. -$(foreach name, $(BLASTEST_DRV_BASES), $(eval $(call make-blat-rule,$(name)))) # A rule to run ?blat1.x driver files. define make-run-blat1-rule @@ -933,7 +974,7 @@ $(TESTSUITE_BIN): $(MK_TESTSUITE_OBJS) $(LIBBLIS_LINK) ifeq ($(ENABLE_VERBOSE),yes) $(LINKER) $(MK_TESTSUITE_OBJS) $(LIBBLIS_LINK) $(LDFLAGS) -o $@ else - @echo "Linking $@ against '$(LIBBLIS_LINK) $(LDFLAGS)'" + @echo "Linking $@ against '$(LIBBLIS_LINK) "$(LDFLAGS)"'" @$(LINKER) $(MK_TESTSUITE_OBJS) $(LIBBLIS_LINK) $(LDFLAGS) -o $@ endif @@ -1081,6 +1122,13 @@ else $(@)/$(CONFIG_DIR)/$(CONFIG_NAME)/ endif +# BLIS library in pkg-configure blis.pc.in file. +ifeq ($(THREADING_MODEL),off) +AOCLLIB := blis +else +AOCLLIB := blis-mt +endif + $(PC_SHARE_DIR_INST): $(PC_IN_FILE) $(MKDIR) $(@) ifeq ($(ENABLE_VERBOSE),no) @@ -1088,6 +1136,7 @@ ifeq ($(ENABLE_VERBOSE),no) endif $(shell cat "$(PC_IN_FILE)" \ | sed -e "s#@PACKAGE_VERSION@#$(VERSION)#g" \ + | sed -e "s#@AOCLLIB@#$(AOCLLIB)#g" \ | sed -e "s#@prefix@#$(prefix)#g" \ | sed -e "s#@exec_prefix@#$(exec_prefix)#g" \ | sed -e "s#@libdir@#$(libdir)#g" \ diff --git a/addon/CMakeLists.txt b/addon/CMakeLists.txt index 073a3fb75b..169d482be9 100644 --- a/addon/CMakeLists.txt +++ b/addon/CMakeLists.txt @@ -1,4 +1,36 @@ -##Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. ## +#[=[ + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +]=] # Writing a function that will be used to generate the required object # libraries for the required addons. @@ -59,6 +91,7 @@ function(generate_addon_targets addon_target) # in get-addon-c99flags-for ${CADDONINCFLAGS} ) + if(THREADING_MODEL STREQUAL "openmp") # Equivalent to CTHREADFLAGS in get-noopt-cflags-for target_link_libraries(${addon_target}_C99_ADDON PRIVATE OpenMP::OpenMP_C) @@ -66,10 +99,8 @@ function(generate_addon_targets addon_target) # in get-noopt-cflags-for target_compile_options(${addon_target}_C99_ADDON PRIVATE ${CTHREADFLAGS}) endif() - if(BUILD_SHARED_LIBS) - # Equivalent to CPICFLAGS in get-noopt-cflags-for - set_target_properties(${addon_target}_C99_ADDON PROPERTIES POSITION_INDEPENDENT_CODE ON) - endif() + # Equivalent to CPICFLAGS in get-noopt-cflags-for + set_target_properties(${addon_target}_C99_ADDON PROPERTIES POSITION_INDEPENDENT_CODE ON) add_dependencies(${addon_target}_C99_ADDON flat-header) # Put all those targets under object-libs-targets folder name so that they appear all together in IDE. set_target_properties(${addon_target}_C99_ADDON PROPERTIES FOLDER object-libs-targets) @@ -128,17 +159,17 @@ function(generate_addon_targets addon_target) # in get-noopt-cflags-for target_compile_options(${addon_target}_C99_KERNEL_ADDON PRIVATE ${CTHREADFLAGS}) endif() - if(BUILD_SHARED_LIBS) - # Equivalent to CPICFLAGS in get-noopt-cflags-for - set_target_properties(${addon_target}_C99_KERNEL_ADDON PROPERTIES POSITION_INDEPENDENT_CODE ON) - endif() + # Equivalent to CPICFLAGS in get-noopt-cflags-for + set_target_properties(${addon_target}_C99_KERNEL_ADDON PROPERTIES POSITION_INDEPENDENT_CODE ON) add_dependencies(${addon_target}_C99_KERNEL_ADDON flat-header) # Put all those targets under object-libs-targets folder name so that they appear all together in IDE. set_target_properties(${addon_target}_C99_KERNEL_ADDON PROPERTIES FOLDER object-libs-targets) endif() - # Collect all subdirectory paths that have at least one file with suffix in ADDON_CXX_SUFS list. - get_filepaths_with_suffixes(LOCAL_SOURCE_CXX_FILES "${CMAKE_CURRENT_SOURCE_DIR}/${addon_target}" "${ADDON_CXX_SUFS}") + if(("${CMAKE_C_COMPILER_ID}" STREQUAL "GNU") AND (CMAKE_C_COMPILER_VERSION VERSION_LESS 11.2.0)) + # Collect all subdirectory paths that have at least one file with suffix in ADDON_CXX_SUFS list. + get_filepaths_with_suffixes(LOCAL_SOURCE_CXX_FILES "${CMAKE_CURRENT_SOURCE_DIR}/${addon_target}" "${ADDON_CXX_SUFS}") + endif() # Only generate the object library if there is at least one source file. list(LENGTH LOCAL_SOURCE_CXX_FILES size) @@ -190,10 +221,8 @@ function(generate_addon_targets addon_target) # in get-noopt-cflags-for target_compile_options(${addon_target}_CXX_ADDON PRIVATE ${CTHREADFLAGS}) endif() - if(BUILD_SHARED_LIBS) - # Equivalent to CPICFLAGS in get-noopt-cflags-for - set_target_properties(${addon_target}_CXX_ADDON PROPERTIES POSITION_INDEPENDENT_CODE ON) - endif() + # Equivalent to CPICFLAGS in get-noopt-cflags-for + set_target_properties(${addon_target}_CXX_ADDON PROPERTIES POSITION_INDEPENDENT_CODE ON) add_dependencies(${addon_target}_CXX_ADDON flat-header) # Put all those targets under object-libs-targets folder name so that they appear all together in IDE. set_target_properties(${addon_target}_CXX_ADDON PROPERTIES FOLDER object-libs-targets) diff --git a/addon/aocl_gemm/JIT/lpgemm_jit_bf16.cpp b/addon/aocl_gemm/JIT/lpgemm_jit_bf16.cpp new file mode 100644 index 0000000000..de4b6b40c1 --- /dev/null +++ b/addon/aocl_gemm/JIT/lpgemm_jit_bf16.cpp @@ -0,0 +1,1509 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "lpgemm_jit_bf16.h" + +// push callee-save registers to stack +void bli_lpgemm_jit:: preamble() +{ + push(rbp); + push(rbx); + push(r12); + push(r13); + push(r14); + push(r15); +} + +// pop the callee-save registers before returning from function. +void bli_lpgemm_jit:: postamble() +{ + pop(r15); + pop(r14); + pop(r13); + pop(r12); + pop(rbx); + pop(rbp); + vzeroupper(); +} + +void bli_lpgemm_jit:: store_zmms_in_stack( dim_t reg_start_idx, + dim_t num_regs, + dim_t stack_off + ) +{ + for( dim_t idx = 0; idx < num_regs; idx++ ) + { + vmovups( ptr[ rsp + zmm_stack_top + stack_off + idx * 64], + Zmm( reg_start_idx + idx ) ); + } +} + +void bli_lpgemm_jit:: get_zmms_from_stack( dim_t reg_start_idx, + dim_t num_regs, + dim_t stack_off + ) +{ + for( dim_t idx = 0; idx < num_regs; idx++ ) + { + vmovups( Zmm( reg_start_idx + idx ), + ptr[ rsp + zmm_stack_top + stack_off + idx * 64] ); + } +} + +//Zero out the registers that will be used for storing accumulated values. +// For a given micro-kernel dimension MRxNR, +// considering a row-major kernel, we need (MR * (NR / num_elems per reg)) +// registers to store accumulated values. +void bli_lpgemm_jit:: reg_init( dim_t m_dim, dim_t n_dim ) +{ + vxorps( Zmm( fma_start_idx ), Zmm( fma_start_idx )); + for( dim_t m = fma_start_idx + 1; m < 32; m++ ) + { + vmovaps( Zmm( m ), Zmm( fma_start_idx ) ); + } +} + + +// This code replicates the existing bf16 kernel. +// Hence unroll factor is hardcoded to be 2. +// To-DO: Make unroll factor as an configurable parameter. +void bli_lpgemm_jit:: kernel_unroll( dim_t m_dim, dim_t n_dim ) +{ + dim_t reg_num; + + // Broadcast elements of A matrix + vpbroadcastd( Zmm( bcst_start_idx ), ptr[ rax ] ); + + // load elements of B matrix into registers + for( dim_t n = 0; n < num_full_loads; n++ ) + vmovdqu16( Zmm( load_start_idx + n ), ptr[ rbx + n * 64 ] ); + + // In case of last load with fringe part, use mask + if( n_rem ) + vmovdqu16( Zmm( load_start_idx + num_full_loads ) + | k3 | T_z, ptr[ rbx + num_full_loads * 64 ] ); + + add( rbx, r10 ); + + for( dim_t m = 0; m < m_dim; m++ ) + { + // broadcast elements of A matrix. + // Using 2 ZMM registers for broadcast. + if( m < ( m_dim - 1 ) ) + { + switch ( m + 1 ) + { + case 1: + case 4: + case 2: vpbroadcastd( Zmm( bcst_start_idx + ( m + 1 ) % 2 ), + ptr[ rax + r8 * ( m + 1 ) ] ); + break; + case 3: vpbroadcastd( Zmm( bcst_start_idx + ( m + 1 ) % 2 ), + ptr[ rax + r13 ] ); + break; + case 5: vpbroadcastd( Zmm( bcst_start_idx + ( m + 1 ) % 2 ), + ptr[ rax + r15 ] ); + break; + default: + break; + } + } + + // move to next column + if( m == ( m_dim - 1 ) ) add( rax, r9 ); + + // Generate FMA instructions. + for( dim_t n = 0; n < num_loads; n++ ) + { + reg_num = fma_start_idx + ( m * num_loads ) + n; + + vdpbf16ps( Zmm( reg_num ), Zmm( bcst_start_idx + m % 2 ), + Zmm( load_start_idx + n ) ); + } + } +} + +void bli_lpgemm_jit:: k_fringe_loop( dim_t m_dim, dim_t n_dim ) +{ + + dim_t reg_num; + + // Broadcast elements of A matrix + vpbroadcastw( Zmm( bcst_start_idx ), ptr[ rax ] ); + + // load elements of B matrix into registers + for( dim_t n = 0; n < num_full_loads; n++ ) + vmovdqu16( Zmm( load_start_idx + n ), ptr[ rbx + n * 64 ] ); + + // In case of last load with fringe part, use mask + if( n_rem ) + vmovdqu16( Zmm( load_start_idx + num_full_loads ) + | k3 | T_z, ptr[ rbx + num_full_loads * 64 ] ); + + + for( dim_t m = 0; m < m_dim; m++ ) + { + if( m < ( m_dim - 1 ) ) + { + // broadcast elements of A matrix. + // Using 2 ZMM registers for broadcast. + switch ( m + 1 ) + { + case 1: + case 4: + case 2: vpbroadcastw( Zmm( bcst_start_idx + ( m + 1 ) % 2 ), + ptr[ rax + r8 * ( m + 1 ) ] ); + break; + case 3: vpbroadcastw( Zmm( bcst_start_idx + ( m + 1 ) % 2 ), + ptr[ rax + r13 ] ); + break; + case 5: vpbroadcastw( Zmm( bcst_start_idx + ( m + 1 ) % 2 ), + ptr[ rax + r15 ] ); + break; + default: + break; + } + } + + // Generate FMA instructions. + for( dim_t n = 0; n < num_loads; n++ ) + { + reg_num = fma_start_idx + ( m * num_loads ) + n; + + vdpbf16ps( Zmm( reg_num ), Zmm( bcst_start_idx + m % 2 ), + Zmm( load_start_idx + n ) ); + } + + } +} + +// Generate required number of mul instructions for scaling with alpha. +void bli_lpgemm_jit:: scale_alpha( dim_t m_dim, dim_t n_dim ) +{ + for( dim_t reg_num = fma_start_idx; reg_num < 32; reg_num++ ) + vmulps( Zmm( reg_num ), Zmm( alpha_reg ), Zmm( reg_num ) ); +} + + +// Scale C by beta and store when beta is a generic value. +void bli_lpgemm_jit:: f32_f32_beta_op( dim_t m_dim, dim_t n_dim) +{ + dim_t reg_num; + for( dim_t m = 0; m < m_dim; m++ ) + { + if( m > 0 ) add( rcx, rdi ); + + for( dim_t n = 0; n < num_full_loads; n++ ) + { + reg_num = fma_start_idx + ( m * num_loads ) + n; + + vmovups( Zmm( load_start_idx + n ) , ptr[ rcx + n * 64 ] ); + + vfmadd231ps( Zmm( reg_num ), Zmm( load_start_idx + n ), + Zmm( beta_reg ) ); + } + + // Use mask in case of n_fringe. + if( n_rem ) + { + reg_num = fma_start_idx + ( m * num_loads ) + num_full_loads; + + vmovups( Zmm( load_start_idx + num_full_loads ) | k4 | T_z, + ptr[ rcx + num_full_loads * 64 ] ); + + vfmadd231ps( Zmm( reg_num ), + Zmm( load_start_idx + num_full_loads ), + Zmm( beta_reg ) ); + } + } +} + +void bli_lpgemm_jit:: bf16_f32_beta_op( dim_t m_dim, dim_t n_dim ) +{ + + dim_t reg_num; + mov( rcx, ptr[ rsp + stack_off_buf_downscale ] ); + mov( rax, ptr[ rsp + stack_off_postop + offsetof( lpgemm_post_op_attr, + rs_c_downscale ) ] ); + + + + // rs_c_downscale *= sizeof(bfloat16) + lea( rax, ptr[ rax * 2 ] ); + mov( rsi, ptr[ rsp + stack_off_postop + + offsetof( lpgemm_post_op_attr, post_op_c_i ) ] ); + mov( rbx, ptr[ rsp + stack_off_postop + + offsetof( lpgemm_post_op_attr, post_op_c_j ) ] ); + + // rsi = post_op_c_i * ( rs_c_downscale * sizeof(bfloat16) ) + imul( rsi, rax ); + + // rsi = post_op_c_i * ( rs_c_downscale * sizeof(bfloat16) ) + // + post_op_c_j * sizeof(bfloat16) + lea( rsi, ptr[ rsi + rbx * 2 ] ); + + add( rcx, rsi ); + + for( dim_t m = 0; m < m_dim; m++ ) + { + for( dim_t n = 0; n < num_full_loads; n++ ) + { + reg_num = fma_start_idx + ( m * num_loads ) + n; + + // convert from 16 bit elements to 32 bit elements + vpmovsxwd( Zmm( load_start_idx + n ), ptr[ rcx + n * 32 ] ); + + // Shift left by 16 bits + vpslld( Zmm( load_start_idx + n ), Zmm( load_start_idx + n ), + 0x10 ); + + // fma with beta + vfmadd231ps( Zmm( reg_num ), Zmm( beta_reg ), + Zmm( load_start_idx + n ) ); + } + if( n_rem ) + { + reg_num = fma_start_idx + ( m * num_loads ) + num_full_loads; + + // load the bf16 elements from the downscale buffer using mask. + vmovdqu16( Ymm( load_start_idx + num_full_loads ) | k4 | T_z, + ptr[rcx + num_full_loads * 32 ] ); + + // convert from 16 bit elements to 32 bit elements + vpmovsxwd( Zmm( load_start_idx + num_full_loads ), + Ymm( load_start_idx + num_full_loads ) ); + + // Shift left by 16 bits + vpslld( Zmm( load_start_idx + num_full_loads ), + Zmm( load_start_idx + num_full_loads ), 0x10 ); + + // fma with beta + vfmadd231ps( Zmm( reg_num ), Zmm( beta_reg ), + Zmm( load_start_idx + num_full_loads ) ); + } + + // move to next row + add( rcx, rax ); + } + +} + +void bli_lpgemm_jit:: clip_f32( dim_t m_dim, dim_t n_dim ) +{ + dim_t min_reg = load_start_idx; + dim_t max_reg = bcst_start_idx; + + // min reg + mov( rax, ptr[ rdx + offsetof( lpgemm_post_op, op_args2 ) ] ); + vbroadcastss( Zmm( min_reg ), ptr[ rax ] ); + + // max reg + mov( rbx, ptr[ rdx + offsetof( lpgemm_post_op, op_args3 ) ] ); + vbroadcastss( Zmm( max_reg ), ptr[ rbx ] ); + + for( dim_t m = fma_start_idx; m < 32; m++ ) + { + vmaxps( Zmm( m ), Zmm( m ), Zmm( min_reg ) ); + vminps( Zmm( m ), Zmm( m ), Zmm( max_reg ) ); + } +} + +void bli_lpgemm_jit:: bf16_f32_matrix_add( dim_t m_dim, dim_t n_dim ) +{ + dim_t reg_num; + + // rcx = matrix ptr + mov( rcx, ptr[ rdx + offsetof( lpgemm_post_op, op_args1 ) ] ); + + // rax = ldm + mov( rax, ptr[ rdx + offsetof( lpgemm_post_op, op_args3 ) ] ); + mov( rax, ptr[ rax ] ); + + // ldm *= sizeof(bfloat16) + lea( rax, ptr[ rax * 2 ] ); + + mov( rsi, ptr[ rsp + stack_off_postop + + offsetof( lpgemm_post_op_attr, post_op_c_i ) ] ); + mov( rbx, ptr[ rsp + stack_off_postop + + offsetof( lpgemm_post_op_attr, post_op_c_j ) ] ); + + // rsi = post_op_c_i * ( rs_c_downscale * sizeof(bfloat16) ) + imul( rsi, rax ); + + // rsi = post_op_c_i * ( rs_c_downscale * sizeof(bfloat16) ) + // + post_op_c_j * sizeof(bfloat16) + lea( rsi, ptr[ rsi + rbx * 2 ] ); + + add( rcx, rsi ); + + for( dim_t m = 0; m < m_dim; m++ ) + { + for( dim_t n = 0; n < num_full_loads; n++ ) + { + reg_num = fma_start_idx + ( m * num_loads ) + n; + + // convert from 16 bit elements to 32 bit elements + vpmovsxwd( Zmm( load_start_idx + n ), ptr[ rcx + n*32 ] ); + + // Shift left by 16 bits + vpslld( Zmm( load_start_idx + n ), Zmm( load_start_idx + n ), + 0x10 ); + + vaddps( Zmm( reg_num ), Zmm( reg_num ), + Zmm( load_start_idx + n ) ); + + } + if( n_rem ) + { + reg_num = fma_start_idx + ( m * num_loads ) + num_full_loads; + + // load the bf16 elements from the downscale buffer using mask. + vmovdqu16( Ymm( load_start_idx + num_full_loads ) | k4 | T_z, + ptr[rcx + num_full_loads * 32 ] ); + + // convert from 16 bit elements to 32 bit elements + vpmovsxwd( Zmm( load_start_idx + num_full_loads ), + Ymm( load_start_idx + num_full_loads ) ); + + // Shift left by 16 bits + vpslld( Zmm(load_start_idx + num_full_loads ), + Zmm( load_start_idx + num_full_loads ), 0x10 ); + + vaddps( Zmm( reg_num ), Zmm( reg_num ), + Zmm( load_start_idx + num_full_loads ) ); + } + + // move to next row + add( rcx, rax ); + } +} + + +void bli_lpgemm_jit:: f32_f32_matrix_add( dim_t m_dim, dim_t n_dim ) +{ + dim_t reg_num; + + // rcx = matrix ptr + mov( rcx, ptr[ rdx + offsetof( lpgemm_post_op, op_args1 ) ] ); + // rax = ldm + mov( rax, ptr[ rdx + offsetof( lpgemm_post_op, op_args3 ) ] ); + mov( rax, ptr[ rax ] ); + + // ldm *= sizeof(float) + lea( rax, ptr[ rax * 4 ] ); + + mov( rsi, ptr[ rsp + stack_off_postop + + offsetof( lpgemm_post_op_attr, post_op_c_i ) ] ); + mov( rbx, ptr[ rsp + stack_off_postop + + offsetof( lpgemm_post_op_attr, post_op_c_j ) ] ); + + // rsi = post_op_c_i * ( rs_c_downscale * sizeof(float) ) + imul( rsi, rax ); + + // rsi = post_op_c_i * ( rs_c_downscale * sizeof(float) ) + // + post_op_c_j * sizeof(float) + lea( rsi, ptr[ rsi + rbx * 4] ); + + add( rcx, rsi ); + + for( dim_t m = 0; m < m_dim; m++ ) + { + for( dim_t n = 0; n < num_full_loads; n++) + { + reg_num = fma_start_idx + ( m * num_loads ) + n; + vmovups(Zmm( load_start_idx + n ), ptr[ rcx + n * 64 ] ); + vaddps( Zmm( reg_num ), Zmm( reg_num ), + Zmm( load_start_idx + n ) ); + } + if( n_rem ) + { + reg_num = fma_start_idx + ( m * num_loads ) + num_full_loads; + vmovups( Zmm( load_start_idx + num_full_loads ) | k4 | T_z, + ptr[ rcx + num_full_loads * 64 ] ); + vaddps( Zmm( reg_num ), Zmm( reg_num ), + Zmm( load_start_idx + num_full_loads ) ); + } + + // move to next row + add( rcx, rax ); + } +} +void bli_lpgemm_jit:: bias_row_major( dim_t m_dim, dim_t n_dim ) +{ + dim_t reg_num; + mov( rax, ptr[ rdx + offsetof( lpgemm_post_op, op_args1 ) ] ); + mov( rbx, ptr[ rsp + stack_off_postop + + offsetof( lpgemm_post_op_attr, post_op_c_j ) ] ); + + mov( rcx, ptr[ rsp + stack_off_postop + + offsetof( lpgemm_post_op_attr, c_stor_type ) ] ); + cmp( rcx, 4 ); + je( "BIAS_BF16_ROW_MAJOR", T_NEAR ); + + // postops_c_j *= sizeof(float) + lea( rbx, ptr[ rbx * 4 ] ); + add( rax, rbx ); + for( dim_t n = 0; n < num_full_loads; n++ ) + { + vmovups( Zmm( load_start_idx + n ), ptr[ rax + n * 64 ] ); + } + if( n_rem ) + { + vmovups( Zmm( load_start_idx + num_full_loads ) | k4, + ptr[ rax + num_full_loads * 64 ] ); + } + jmp( "POST_BIAS_BF16_ROW_MAJOR", T_NEAR ); + + L( "BIAS_BF16_ROW_MAJOR" ); + // postops_c_j *= sizeof(bfloat16) + lea( rbx, ptr[ rbx * 2 ] ); + add( rax, rbx ); + for( dim_t n = 0; n < num_full_loads; n++ ) + { + // convert from 16 bit elements to 32 bit elements + vpmovsxwd( Zmm( load_start_idx + n ), ptr[ rax + n * 32 ] ); + + // Shift left by 16 bits + vpslld( Zmm( load_start_idx + n ), Zmm( load_start_idx + n ), 0x10 ); + } + if( n_rem ) + { + // load the bf16 elements from the downscale buffer using mask. + vmovdqu16( Ymm( load_start_idx + num_full_loads ) | k4 | T_z, + ptr[rax + num_full_loads * 32 ] ); + + // convert from 16 bit elements to 32 bit elements + vpmovsxwd( Zmm( load_start_idx + num_full_loads ), + Ymm( load_start_idx + num_full_loads ) ); + + // Shift left by 16 bits + vpslld( Zmm( load_start_idx + num_full_loads ), + Zmm( load_start_idx + num_full_loads ), 0x10 ); + } + L( "POST_BIAS_BF16_ROW_MAJOR" ); + + for( dim_t m = 0; m < m_dim; m++ ) + { + for( dim_t n = 0; n < num_loads; n++ ) + { + reg_num = fma_start_idx + ( m * num_loads ) + n; + vaddps( Zmm( reg_num ), Zmm( reg_num ), + Zmm( load_start_idx + n ) ); + } + } +} + +void bli_lpgemm_jit:: bias_col_major( dim_t m_dim, dim_t n_dim ) +{ + dim_t reg_num; + + mov( rax, ptr[ rdx + offsetof( lpgemm_post_op, op_args1 ) ] ); + mov( rbx, ptr[ rsp + stack_off_postop + + offsetof( lpgemm_post_op_attr, post_op_c_i ) ] ); + + mov( rcx, ptr[ rsp + stack_off_postop + + offsetof( lpgemm_post_op_attr, c_stor_type ) ] ); + cmp( rcx, 4 ); + je( "BIAS_BF16_COL_MAJOR", T_NEAR ); + + // postops_c_i *= sizeof(float) + lea( rbx, ptr[ rbx * 4 ] ); + add( rax, rbx ); + for( dim_t m = 0; m < m_dim; m++ ) + { + vbroadcastss( Zmm( alpha_reg ), ptr[ rax + m * 4 ] ); + for( dim_t n = 0; n < num_loads; n++ ) + { + reg_num = fma_start_idx + ( m * num_loads ) + n; + vaddps( Zmm( reg_num ), Zmm( reg_num ), Zmm( alpha_reg ) ); + } + } + jmp( "POST_BIAS_BF16_COL_MAJOR", T_NEAR ); + + L( "BIAS_BF16_COL_MAJOR" ); + // postops_c_i *= sizeof(bfloat16) + lea( rbx, ptr[ rbx * 2 ] ); + add( rax, rbx ); + for( dim_t m = 0; m < m_dim; m++ ) + { + vpbroadcastw( Zmm( alpha_reg ), ptr[ rax + m * 4 ] ); + + // convert from 16 bit elements to 32 bit elements + vpmovsxwd( Zmm( alpha_reg ), Ymm( alpha_reg ) ); + + // Shift left by 16 bits + vpslld( Zmm( alpha_reg ), Zmm( alpha_reg ), 0x10 ); + + for( dim_t n = 0; n < num_loads; n++ ) + { + reg_num = fma_start_idx + ( m * num_loads ) + n; + vaddps( Zmm( reg_num ), Zmm( reg_num ), Zmm( alpha_reg ) ); + } + } + L( "POST_BIAS_BF16_COL_MAJOR" ); +} + +void bli_lpgemm_jit:: relu( dim_t m_dim, dim_t n_dim ) +{ + dim_t scratch_reg = bcst_start_idx; + + vpxorq(Zmm( scratch_reg ), Zmm( scratch_reg ), Zmm( scratch_reg ) ); + + for( dim_t m = fma_start_idx; m < 32; m++ ) + { + vmaxps( Zmm( m ), Zmm( m ), Zmm( scratch_reg ) ); + } +} + +void bli_lpgemm_jit:: relu_scale( dim_t m_dim, dim_t n_dim ) +{ + dim_t zero_reg = load_start_idx; + dim_t scale_factor = bcst_start_idx; + + mov( rax, ptr[ rdx + offsetof( lpgemm_post_op, op_args2 ) ] ); + vbroadcastss( Zmm( scale_factor ), ptr[ rax ] ); + vpxorq( Zmm( zero_reg ), Zmm( zero_reg ), Zmm( zero_reg ) ); + + for( dim_t m = fma_start_idx; m < 32; m++ ) + { + vcmpps( k5, Zmm( m ), Zmm( zero_reg ), 0x02 ); + vmulps( Zmm( m ) | k5, Zmm( m ), Zmm( scale_factor ) ); + } +} + +void bli_lpgemm_jit::apply_post_ops_in_high_reg_pressure + ( + const dim_t num_post_op_regs, + std::function< void( dim_t ) > op_fn + ) +{ + dim_t num_push_regs = num_post_op_regs - fma_start_idx ; + + // If number of registers required to compute pots op is more than + // registers available, then push some accum registers to stack + // and use them to compute gelu. + store_zmms_in_stack( fma_start_idx, num_push_regs, 0 ); + + dim_t post_op_start = num_push_regs > 0 ? fma_start_idx + num_push_regs + : fma_start_idx; + + // operate on non-pushed regs + for( dim_t reg = post_op_start; reg < 32; reg++ ) + { + op_fn( reg ); + } + + // Push num_push_regs number of registers from last to stack and + // replace them with the items that were pushed earlier + // and compute on them. + store_zmms_in_stack( 32 - num_push_regs, num_push_regs, + num_push_regs * 64 ); + get_zmms_from_stack( 32 - num_push_regs, num_push_regs, 0); + + for( dim_t reg = 0; reg < num_push_regs; reg++ ) + { + op_fn( 32 - num_push_regs + reg ); + } + + for( dim_t reg = 0; reg < num_push_regs; reg++ ) + vmovups( Zmm( fma_start_idx + reg ), + Zmm( 32 - num_push_regs + reg ) ); + + get_zmms_from_stack( 32 - num_push_regs, num_push_regs, + num_push_regs * 64 ); +} + +//r2 and z, q are scratch regs +//r will be passed in and out of parent function. +void bli_lpgemm_jit:: POLY_EVAL_6_AVX512( ) +{ + vmulps( Zmm( r2 ), Zmm( r ), Zmm( r ) ); + + vbroadcastss( Zmm( const1 ), get_constant(lpgemm_exp_off, 3) ); + + vbroadcastss( Zmm( const2 ), get_constant(lpgemm_exp_off, 2) ); + + vmovups( Zmm( q ), Zmm( const2 ) ); + vfmadd231ps( Zmm( q ), Zmm( const1 ), Zmm( r ) ); + + vbroadcastss( Zmm( const1 ), get_constant(lpgemm_exp_off, 1) ); + + vbroadcastss( Zmm( const2 ), get_constant(lpgemm_exp_off, 0) ); + + vmovups( Zmm( z ), Zmm( const2 ) ); + vfmadd231ps( Zmm( z ), Zmm( const1 ), Zmm( r ) ); + + vfmadd231ps( Zmm( z ), Zmm( r2 ), Zmm( q ) ); + + vmulps(Zmm( r2 ), Zmm( r2 ), Zmm( r2 ) ); + + vbroadcastss( Zmm( const1 ), get_constant(lpgemm_exp_off, 5) ); + + vbroadcastss( Zmm( const2 ), get_constant(lpgemm_exp_off, 4) ); + + vfmadd231ps( Zmm( const2 ), Zmm( const1 ), Zmm( r ) ); + + vfmadd231ps( Zmm( z ), Zmm( const2 ), Zmm( r2 ) ); + vmovups(Zmm( r ), Zmm( z ) ); +} + +// z, r, dn is a scratch register +// takes 'x' as input and returns 'q' to the parent +void bli_lpgemm_jit:: EXPF_AVX512() +{ + vbroadcastss( Zmm( const1 ), get_constant(gelu_macros_off, 0) ); + + vmulps( Zmm( z ), Zmm( x ), Zmm(const1 ) ); + + vbroadcastss( Zmm( const2 ), get_constant(gelu_macros_off, 1) ); + + vaddps( Zmm( dn ), Zmm( z ), Zmm( const2 ) ); + + vsubps( Zmm( r ), Zmm( dn ), Zmm( const2 ) ); + vsubps( Zmm( r ), Zmm( z ), Zmm( r ) ); + + POLY_EVAL_6_AVX512(); + + vpslld( Zmm( dn ), Zmm( dn ), 0x17 ); + + vpaddd( Zmm( q ), Zmm( r ), Zmm( dn ) ); + + vpxorq( Zmm( const2 ), Zmm( const2 ), Zmm( const2 ) ); + + vpbroadcastd( Zmm( const1 ), get_constant(gelu_macros_off, 2) ); + + vcmpps( k5, Zmm( const1 ), Zmm( x ), 0x06 ); + + vpandd( Zmm( q ) | k5, Zmm( q ), Zmm( const2 ) ); + + vbroadcastss( Zmm( const1 ), get_constant(gelu_macros_off, 3) ); + + vcmpps( k5, Zmm( const1 ), Zmm( x ), 0x06 ); + + vbroadcastss( Zmm( x ), get_constant(gelu_macros_off, 4) ); + + vpxord( Zmm( x ) | k5, Zmm( q ), Zmm( const2 ) ); + vmovups(Zmm( q ), Zmm( x ) ); +} + +// uses z, dn, r as scratch regs +// passes r to child macro and gets q +// takes x_tanh as input and gives back x_tanh +void bli_lpgemm_jit:: TANHF_AVX512() +{ + vbroadcastss( Zmm( const1 ), get_constant(gelu_consts_off, 2) ); + + mov( ebx, 0x7FFFFFFF ); + vpbroadcastd( Zmm( const2 ), ebx ); + vpandd( Zmm( x ), Zmm( x_tanh ), Zmm( const2 ) ); + + vmulps( Zmm( x ), Zmm( x ), Zmm( const1 ) ); + + EXPF_AVX512(); + + mov( eax, -1 ); + vbroadcastss( Zmm( const1 ), get_constant(gelu_consts_off, 4) ); + + vaddps( Zmm( z ), Zmm( q ), Zmm( const1 ) ); + + vbroadcastss( Zmm( const2 ), get_constant(gelu_consts_off, 5) ); + + vaddps( Zmm( r ), Zmm( z ), Zmm( const2 ) ); + + vdivps( Zmm( z ), Zmm( z ), Zmm( r ) ); + + vmulps( Zmm( z ), Zmm( z ), Zmm( const1 ) ); + + mov( eax, -2147483648 ); + vpbroadcastd( Zmm( const1 ), eax ); + + vpandd(Zmm( q ), Zmm( x_tanh ), Zmm( const1 ) ); + + vpxord( Zmm( x_tanh ), Zmm( q ), Zmm( z ) ); +} + +void bli_lpgemm_jit:: GELU_TANH_F32_AVX512_DEF(dim_t reg ) +{ + vmulps( Zmm( r2 ), Zmm( reg ), Zmm( reg ) ); + vmulps( Zmm( r2 ), Zmm( r2 ), Zmm( reg ) ); + + vbroadcastss( Zmm( const1 ), get_constant(gelu_consts_off, 0) ); + vmovups( Zmm( r ), Zmm( reg ) ); + vfmadd231ps( Zmm( r ), Zmm( r2 ), Zmm( const1 ) ); + + vbroadcastss( Zmm( const2 ), get_constant(gelu_consts_off, 1) ); + vmulps( Zmm( x_tanh ), Zmm( r ), Zmm( const2 ) ); + + TANHF_AVX512(); + + vbroadcastss( Zmm( const2 ), get_constant(gelu_consts_off, 6) ); + vaddps( Zmm( x_tanh ), Zmm( x_tanh ), Zmm( const2 ) ); + vmulps( Zmm( x_tanh ), Zmm( x_tanh ), Zmm( reg ) ); + + vbroadcastss( Zmm( const1 ), get_constant(gelu_consts_off, 3) ); + vmulps( Zmm( reg ), Zmm( x_tanh ), Zmm( const1 ) ); +} + +void bli_lpgemm_jit:: gelu_tanh( dim_t m_dim, dim_t n_dim ) +{ + apply_post_ops_in_high_reg_pressure + ( + num_gelu_regs, + std::bind + ( + &bli_lpgemm_jit::GELU_TANH_F32_AVX512_DEF, + this, + std::placeholders::_1 + ) + ); +} + +void bli_lpgemm_jit:: POLY_EVAL_HORNER_16_0_AVX512() +{ + vbroadcastss( Zmm( const1 ), get_constant(lpgemm_erf_off, 15) ); + vbroadcastss( Zmm( const2 ), get_constant(lpgemm_erf_off, 14) ); + + vfmadd231ps( Zmm( const2 ), Zmm( r ), Zmm( const1 ) ); + + vbroadcastss( Zmm( const1 ), get_constant(lpgemm_erf_off, 13) ); + vfmadd231ps( Zmm( const1 ), Zmm( r ), Zmm( const2 ) ); + + vbroadcastss( Zmm( const2 ), get_constant(lpgemm_erf_off, 12) ); + vfmadd231ps( Zmm( const2 ), Zmm( r ), Zmm( const1 ) ); + + vbroadcastss( Zmm( const1 ), get_constant(lpgemm_erf_off, 11) ); + vfmadd231ps( Zmm( const1 ), Zmm( r ), Zmm( const2 ) ); + + vbroadcastss( Zmm( const2 ), get_constant(lpgemm_erf_off, 10) ); + vfmadd231ps( Zmm( const2 ), Zmm( r ), Zmm( const1 ) ); + + vbroadcastss( Zmm( const1 ), get_constant(lpgemm_erf_off, 9) ); + vfmadd231ps( Zmm( const1 ), Zmm( r ), Zmm( const2 ) ); + + vbroadcastss( Zmm( const2 ), get_constant(lpgemm_erf_off, 8) ); + vfmadd231ps( Zmm( const2 ), Zmm( r ), Zmm( const1 ) ); + + vbroadcastss( Zmm( const1 ), get_constant(lpgemm_erf_off, 7 ) ); + vfmadd231ps( Zmm( const1 ), Zmm( r ), Zmm( const2 ) ); + + vbroadcastss( Zmm( const2 ), get_constant(lpgemm_erf_off, 6) ); + vfmadd231ps( Zmm( const2 ), Zmm( r ), Zmm( const1 ) ); + + vbroadcastss( Zmm( const1 ), get_constant(lpgemm_erf_off, 5) ); + vfmadd231ps( Zmm( const1 ), Zmm( r ), Zmm( const2 ) ); + + vbroadcastss( Zmm( const2 ), get_constant(lpgemm_erf_off, 4) ); + vfmadd231ps( Zmm( const2 ), Zmm( r ), Zmm( const1 ) ); + + vbroadcastss( Zmm( const1 ), get_constant(lpgemm_erf_off, 3) ); + vfmadd231ps( Zmm( const1 ), Zmm( r ), Zmm( const2 ) ); + + vbroadcastss( Zmm( const2 ), get_constant(lpgemm_erf_off, 2) ); + vfmadd231ps( Zmm( const2 ), Zmm( r ), Zmm( const1 ) ); + + vbroadcastss( Zmm( const1 ), get_constant(lpgemm_erf_off, 1) ); + vfmadd231ps( Zmm( const1 ), Zmm( r ), Zmm( const2 ) ); + + vbroadcastss( Zmm( const2 ), get_constant(lpgemm_erf_off, 0) ); + vfmadd231ps( Zmm( const2 ), Zmm( r ), Zmm( const1 ) ); + + vmulps( Zmm( x ), Zmm( const2 ), Zmm( r ) ); +} + +void bli_lpgemm_jit:: ERF_AVX512() +{ + mov( eax, 0x7FFFFFFF ); + vpbroadcastd( Zmm( const2 ), eax ); + vpandd( Zmm( r ), Zmm( x_erf ), Zmm( const2 ) ); + + POLY_EVAL_HORNER_16_0_AVX512(); + + vbroadcastss( Zmm( const1 ), get_constant(erf_consts_off, 1) ); + + vbroadcastss( Zmm( const2 ), get_constant(erf_consts_off, 3) ); + + vcmpps( k5, Zmm( const2 ), Zmm( r ), 0x06 ); + + vpxorq( Zmm( const2 ), Zmm( const2 ), Zmm( const2 ) ); + + vpxord( Zmm( const1 ) | k5, Zmm( x ), Zmm( const2 ) ); + vmovups( Zmm( x ), Zmm( const1 ) ); + + + vbroadcastss( Zmm( const1 ), get_constant(erf_consts_off, 1) ); + + vcmpps( k5, Zmm( const1 ), Zmm( x ), 0x06 ); + + vpxord( Zmm( const1 ) | k5, Zmm( x ), Zmm( const2 ) ); + + mov( eax, ~(0x7FFFFFFF)); + vpbroadcastd( Zmm( const2 ), eax ); + + vpandd( Zmm( x_erf ), Zmm( x_erf ), Zmm( const2 ) ); + + vpord( Zmm( x_erf ), Zmm( x_erf ), Zmm( const1 ) ); +} + +void bli_lpgemm_jit:: GELU_ERF_F32_AVX512_DEF( dim_t reg ) +{ + vbroadcastss( Zmm( const1 ), get_constant(erf_consts_off, 0) ); + vmulps( Zmm( x_erf ), Zmm( reg ), Zmm( const1 ) ); + + ERF_AVX512(); + + vbroadcastss( Zmm( const2 ), get_constant(erf_consts_off, 1) ); + vaddps( Zmm( x_erf ), Zmm( x_erf ), Zmm( const2 ) ); + + vmulps( Zmm( x_erf ), Zmm( x_erf ), Zmm( reg ) ); + vbroadcastss( Zmm( const2 ), get_constant(erf_consts_off, 2) ); + vmulps( Zmm( reg ), Zmm( x_erf ), Zmm( const2 ) ); + +} + +void bli_lpgemm_jit:: gelu_erf( dim_t m_dim, dim_t n_dim ) +{ + apply_post_ops_in_high_reg_pressure + ( + num_gelu_regs, + std::bind + ( + &bli_lpgemm_jit::GELU_ERF_F32_AVX512_DEF, + this, + std::placeholders::_1 + ) + ); +} + +void bli_lpgemm_jit::SWISH_F32_AVX512_DEF( dim_t reg ) +{ + vpxorq( Zmm( x ), Zmm( x ), Zmm( x ) ); + vfnmadd231ps( Zmm( x ), Zmm( reg ), Zmm( x_tanh ) ); + + // Input reg x and output reg q. + EXPF_AVX512(); + + vbroadcastss( Zmm( const1 ), get_constant(gelu_consts_off, 6) ); + vaddps( Zmm( q ), Zmm( q ), Zmm( const1 ) ); + vdivps( Zmm( reg ), Zmm( reg ), Zmm( q ) ); +} + +void bli_lpgemm_jit::swish( dim_t m_dim, dim_t n_dim ) +{ + mov( rax, ptr[ rdx + offsetof( lpgemm_post_op, op_args2 ) ] ); + vbroadcastss( Zmm( x_tanh ), ptr[ rax ] ); + + apply_post_ops_in_high_reg_pressure + ( + num_gelu_regs, + std::bind + ( + &bli_lpgemm_jit::SWISH_F32_AVX512_DEF, + this, + std::placeholders::_1 + ) + ); +} + +void bli_lpgemm_jit:: store_f32( dim_t m_dim, dim_t n_dim ) +{ + dim_t reg_num; + for( dim_t m = 0; m < m_dim; m++ ) + { + if( m > 0 ) add( rcx, rdi ); + + for( dim_t n = 0; n < num_full_loads; n++ ) + { + reg_num = fma_start_idx + ( m * num_loads ) + n; + vmovups( ptr[ rcx + n * 64 ], Zmm( reg_num ) ); + } + + // Use mask in case of n_fringe. + if( n_rem ) + { + reg_num = fma_start_idx + ( m * num_loads ) + num_full_loads; + vmovups( ptr[ rcx + num_full_loads * 64 ] | k4, Zmm( reg_num ) ); + } + } +} +void bli_lpgemm_jit:: cvt_store_f32_bf16_mask( dim_t m_dim, dim_t n_dim ) +{ + dim_t reg_num; + + mov( rcx, ptr[ rsp + stack_off_buf_downscale ] ); + mov( rax, ptr[ rsp + stack_off_postop + + offsetof( lpgemm_post_op_attr, rs_c_downscale ) ] ); + + // rs_c_downscale *= sizeof(bfloat16) + lea( rax, ptr[rax * 2 ] ); + mov( rsi, ptr[ rsp + stack_off_postop + + offsetof( lpgemm_post_op_attr, post_op_c_i ) ] ); + mov( rbx, ptr[ rsp + stack_off_postop + + offsetof( lpgemm_post_op_attr, post_op_c_j ) ] ); + + imul( rsi, rax ); + lea( rsi, ptr[ rsi + rbx * 2 ] ); + add( rcx, rsi ); + + for( dim_t m = 0; m < m_dim; m++ ) + { + for( dim_t n = 0; n < num_full_loads; n++ ) + { + reg_num = fma_start_idx + ( m * num_loads ) + n; + // convert from 32 bit elements to 16 bit elements + vcvtneps2bf16( Ymm( reg_num ), Zmm( reg_num ) ); + vmovdqu16( ptr[ rcx + n * 32 ], Ymm( reg_num ) ); + } + if( n_rem ) + { + reg_num = fma_start_idx + ( m * num_loads ) + num_full_loads; + // convert from 32 bit elements to 16 bit elements + vcvtneps2bf16( Ymm( reg_num ), Zmm( reg_num ) ); + vmovdqu16( ptr[ rcx + num_full_loads * 32 ] | k4, Ymm( reg_num ) ); + } + // move to next row + add( rcx, rax ); + } +} + +void bli_lpgemm_jit::initialize_params( lpgemm_jit_inputs_t* params ) +{ + // params needed in kernel + // a(r14, rax), b(rbx), c(r12, rcx) podim_ters. To be stored in regs + // rs_a(r8), cs_a(r9), rs_b(r10), rs_c(rdi). + // alpha(rax), beta(rbx) values. To be pushed to stack + // m_iter(r11), ps_a(rax) values. ps_a to be pushed to stack. + // k_iter(rsi), k_left(rsi) value. To be pushed to stack. + + // load values from params struct to registers and stack + if( params->m_loop ) + { + // move address of a + mov( r14, ptr[ rdi + offsetof( lpgemm_jit_params_t, a ) ] ); + mov( r11, ptr[ rdi + offsetof( lpgemm_jit_params_t, m_iter ) ] ); + } + else + { + mov( rax, ptr[ rdi + offsetof(lpgemm_jit_params_t, a ) ] ); + } + + if( params->generate_mask ) + { + // This mask will be used to load/store bf16 elements + kmovd( k3, ptr[ rdi + offsetof( lpgemm_jit_params_t, mask16 ) ] ); + // This mask will be used to load/store f32 elements + kmovw( k4, ptr[ rdi + offsetof(lpgemm_jit_params_t, mask32 ) ] ); + } + + mov( r12, ptr[ rdi + offsetof( lpgemm_jit_params_t, c ) ] ); + mov( r8, ptr[ rdi + offsetof( lpgemm_jit_params_t, rs_a ) ] ); + mov( r9, ptr[ rdi + offsetof( lpgemm_jit_params_t, cs_a ) ] ); + mov( r10, ptr [rdi + offsetof( lpgemm_jit_params_t, rs_b ) ] ); + + + // Push all the params that will be required in later stages + // of kernel to stack. + // Pusing in order ps_a2, k_iter, k_left, alpha, beta, b + mov( rbx, ptr[ rdi + offsetof( lpgemm_jit_params_t, ps_a2 ) ] ); + mov( ptr[ rsp + stack_off_ps_a ], rbx); + + mov( rbx, ptr[ rdi + offsetof( lpgemm_jit_params_t, + k_iter_before_prefetch ) ] ); + mov( ptr[ rsp + stack_off_k_iter_before_prefetch ], rbx ); + + mov( rbx, ptr[ rdi + offsetof( lpgemm_jit_params_t, + k_iter_after_prefetch ) ] ); + mov( ptr[ rsp + stack_off_k_iter_after_prefetch ], rbx ); + + + mov( rbx, ptr[ rdi + offsetof( lpgemm_jit_params_t, k_left ) ] ); + mov( ptr[ rsp + stack_off_k_left ], rbx ); + + mov( rbx, ptr[ rdi + offsetof( lpgemm_jit_params_t, alpha ) ] ); + mov( ptr[ rsp + stack_off_alpha ], rbx ); + + mov( rbx, ptr[ rdi + offsetof( lpgemm_jit_params_t, beta ) ] ); + mov( ptr[ rsp + stack_off_beta ], rbx ); + + mov( rbx, ptr[ rdi + offsetof( lpgemm_jit_params_t, b ) ] ); + mov( ptr[ rsp + stack_off_b_ptr ], rbx ); + + // once all the params that will be required in + // later stages of kernel are pushed to stack, + // move rs_c dim_to rdi. + mov( rdi, ptr[ rdi + offsetof( lpgemm_jit_params_t, rs_c ) ] ); + + + // push all members of lpgemm_post_op_attr struct to stack. + // Since this will be passed as 2nd arg to the function, it will be in rsi + + mov( rbx, ptr[ rsi + offsetof( lpgemm_post_op_attr, post_op_c_i ) ] ); + mov( ptr[ rsp + stack_off_postop + + offsetof( lpgemm_post_op_attr, post_op_c_i ) ], rbx ); + + mov( rcx, ptr[ rsi + offsetof( lpgemm_post_op_attr, post_op_c_j ) ] ); + mov( ptr[ rsp + stack_off_postop + + offsetof( lpgemm_post_op_attr, post_op_c_j ) ], rcx ); + + mov( rbx, ptr[ rsi + offsetof( lpgemm_post_op_attr, rs_c_downscale ) ] ); + mov( ptr[ rsp + stack_off_postop + + offsetof( lpgemm_post_op_attr, rs_c_downscale)], rbx ); + + mov( rcx, ptr[ rsi + offsetof( lpgemm_post_op_attr, cs_c_downscale ) ] ); + mov( ptr[ rsp + stack_off_postop + + offsetof( lpgemm_post_op_attr, cs_c_downscale)], rcx ); + + mov( rbx, ptr[ rsi + offsetof(lpgemm_post_op_attr, buf_downscale ) ] ); + mov( ptr[ rsp + stack_off_buf_downscale ], rbx ); + + mov( rcx, ptr[ rsi + offsetof( lpgemm_post_op_attr, is_first_k ) ] ); + mov( ptr[ rsp + stack_off_postop + + offsetof( lpgemm_post_op_attr, is_first_k ) ], rcx ); + + mov( rbx, ptr[ rsi + offsetof(lpgemm_post_op_attr, is_last_k ) ] ); + mov( ptr[ rsp + stack_off_postop + + offsetof( lpgemm_post_op_attr, is_last_k ) ], rbx ); + + mov( rcx, ptr[ rsi + offsetof( lpgemm_post_op_attr, c_stor_type ) ] ); + mov( ptr[ rsp + stack_off_postop + + offsetof( lpgemm_post_op_attr, c_stor_type ) ], rcx ); + + mov( rbx, ptr[ rsi + offsetof(lpgemm_post_op_attr, b_sum_offset)]); + mov( ptr[ rsp + stack_off_postop + + offsetof( lpgemm_post_op_attr, b_sum_offset )] , rbx ); + + mov( rcx, ptr[ rsi + offsetof( lpgemm_post_op_attr, b_col_sum_vec ) ] ); + mov( ptr[ rsp + stack_off_postop + + offsetof( lpgemm_post_op_attr, b_col_sum_vec ) ], rcx ); + + mov( rbx, ptr[ rsi + + offsetof( lpgemm_post_op_attr, b_col_sum_vec_s16 ) ] ); + + mov( ptr[ rsp + stack_off_postop + + offsetof( lpgemm_post_op_attr, b_col_sum_vec_s16 ) ], rbx ); + + // Storing the address to the head node of post-op list in stack + // It needs to be restored after every loop of m_iter + mov( ptr[ rsp + stack_off_temp_list ], rdx ); + + // initialize top of zmm stack + zmm_stack_top = stack_off_zmm_stack; +} + +void bli_lpgemm_jit:: prefetchC( dim_t m_dim, dim_t n_dim ) +{ + for( dim_t m = 0; m < m_dim; m++ ) + { + if( m > 0 ) add( rcx, rdi ); + for( dim_t n = 0; n < num_loads; n++ ) + { + prefetcht1( ptr[ rcx + n * 64 ] ); + } + } +} + +void bli_lpgemm_jit:: post_op_label_lastk_safe_jump_with_next_ptr() +{ + mov( rdx, ptr[rdx+offsetof( lpgemm_post_op, next ) ] ); + post_op_label_lastk_safe_jump(); +} +void bli_lpgemm_jit:: post_op_label_lastk_safe_jump() +{ + // check if post_ops_list_temp != NULL + cmp( rdx, 0 ); + je( "POST_OPS_6x64_DISABLE", T_NEAR ); + + mov( rax, ptr[ rdx + offsetof( lpgemm_post_op, op_code ) ] ); + cmp( rax, POST_OPS_DISABLE ); + je( "POST_OPS_6x64_DISABLE", T_NEAR ); + cmp( rax, POST_OPS_BIAS ) ; + je( "POST_OPS_BIAS_6x64", T_NEAR ); + cmp( rax, POST_OPS_RELU ); + je( "POST_OPS_RELU_6x64", T_NEAR ); + cmp( rax, POST_OPS_RELU_SCALE ); + je( "POST_OPS_RELU_SCALE_6x64", T_NEAR ); + cmp( rax, POST_OPS_GELU_TANH ); + je( "POST_OPS_GELU_TANH_6x64", T_NEAR ); + cmp( rax, POST_OPS_GELU_ERF ); + je( "POST_OPS_GELU_ERF_6x64", T_NEAR ); + cmp( rax, POST_OPS_CLIP ); + je( "POST_OPS_CLIP_6x64", T_NEAR ); + cmp( rax, POST_OPS_DOWNSCALE ); + je( "POST_OPS_DOWNSCALE_6x64", T_NEAR ); + cmp( rax, POST_OPS_MATRIX_ADD ); + je( "POST_OPS_MATRIX_ADD_6x64", T_NEAR ); + cmp( rax, POST_OPS_SWISH ); + je( "POST_OPS_SWISH_6x64", T_NEAR ); +} + +// Constructor +bli_lpgemm_jit:: bli_lpgemm_jit( void* buffer, size_t bufferSize ) + : CodeGenerator( bufferSize, buffer ) +{ + protect( buffer, bufferSize, PROTECT_RWE ); +} + +// Main kernel function body +void bli_lpgemm_jit::generate_kernel( lpgemm_jit_inputs_t* params ) +{ + + dim_t m_dim = params->MR; + dim_t n_dim = params->NR; + + // In kernel-function pointer array, kernels to handle n < 16 + // are stored at col-index 0. Hacking n_dim to some value 0 < value < 16 + // so masked instructions are generated. + // This will be removed when we support on-the-fly generation of kernels. + if( n_dim == 0 ) + { + n_dim = 2; + params->generate_mask = TRUE; + } + + + n_rem = n_dim % NUM_F32_ELEMS_PER_ZMM; + + // Number of loads that doesn't require mask + num_full_loads = ( n_dim / num_elems_per_reg ); + + // Number of loads in total = full loads + mask load (if required) + num_loads = ( num_full_loads ) + ( n_rem > 0 ? 1 : 0 ); + + // Total number of registers to store accumulated values. + num_fma_regs = m_dim * num_loads; + + // calculating start index for accumulation registers. + // If the kernel requires 'x' number of accumulation regs, we use the + // last 'x' ZMMs available on certain architecture. + // 31 is hardcoded here since we only support AVX-512 as of now, + // This needs to be made as a configurable parameter later. + fma_start_idx = 31 - num_fma_regs + 1; + + // If a kernel requires x registers for loads, we always use the + // first 'x' ZMM registers available for loads. + // And the immediate registers next to load regs are used for broadcast. + bcst_start_idx = load_start_idx + num_loads; + + // While scaling the accumulated registers with beta, + // load regs will be used to load C matrix, + // Hence using broadcast register to store beta value. + beta_reg = bcst_start_idx; + + + preamble(); + // add some spack in stack to store params + sub( rsp, 512 ); + // Initialize all the paramters required for execution of kernel. + // load some values to registers and push the rest of them to stack. + initialize_params( params ); + +/* register usage: + r14, rax - podim_ter for A matrix + r8 - rs_a + r9 - cs_a + r13 - 3 * rs_a + r15 - 5 * rs_a + rbx - podim_ter to B matrix, beta + r10 - rs_b + r12, rcx - podim_ter for C matrix + rdi - rs_c + r11 - m_iter + rsi - k_iter, k_left + rax - ps_a2, alpha +*/ + + + lea( rdi, ptr[ rdi * 4 ] ); // rs_c *= sizeof(float) => rs_c *= 4 + + lea( r8, ptr[ r8 * 2 ] ); // rs_a *= sizeof(dt) => rs_a *= 2 + lea( r9, ptr[ r9 * 2 ] ); // cs_a *= sizeof(dt) => cs_a *= 2 + if ( m_dim >= 4) + lea( r13, ptr[r8 + r8 * 2 ] ); // r13 = 3 * rs_a + if( m_dim >= 6 ) + lea( r15, ptr[r8 + r8 * 4 ] ); // r15 = 5 * rs_a + + lea( r10, ptr[ r10 * 2 ] ); // rs_b *= sizeof(dt) => rs_b *= 2 + + + mov( rcx, r12 ); + + if( params->m_loop ) + { + + L( "BLOOP6X64I" ); + mov( rax, r14 ); // reset rax to current upanel of a. + } + + + mov( rbx, ptr[ rsp + stack_off_b_ptr ] ); // move address of b + + + // Zero all the registers that will be used for accumulation. + reg_init( m_dim, n_dim ); + + // load k_iter + mov( rsi, ptr[ rsp + stack_off_k_iter_before_prefetch ] ); + test( rsi, rsi ); + je( "BPREFETCH", T_NEAR ); + L( "BLOOPKITER" ); + + // Main k-unroll loop + kernel_unroll( m_dim, n_dim ); + + dec( rsi ); // i -= 1 + jne("BLOOPKITER", T_NEAR ); + + L( "BPREFETCH" ); + + prefetchC( m_dim, n_dim ); + + mov( rsi, ptr[ rsp + stack_off_k_iter_after_prefetch ] ); + test( rsi, rsi ); + je( "BCONSIDKLEFT", T_NEAR ); + + L( "AFTERPREFETCH" ); + + kernel_unroll( m_dim, n_dim ); + + dec( rsi ); + jne( "AFTERPREFETCH", T_NEAR ); + + L( "BCONSIDKLEFT" ); + // load k_left + mov( rsi, ptr[ rsp + stack_off_k_left ] ); + test( rsi, rsi ); + je( "BPOSTACCUM", T_NEAR ); + + // k_fringe + k_fringe_loop( m_dim, n_dim ); + + + L( "BPOSTACCUM" ); + + // Generate alpha scaling code only when required. + if( params->alpha_scale ) + { + mov( rax, ptr[ rsp + stack_off_alpha ] ); // load address of alpha + vbroadcastss( Zmm( alpha_reg ), ptr[ rax ] ); + + scale_alpha( m_dim, n_dim ); + + } + + mov( rbx, ptr[ rsp + stack_off_beta ] ); + vbroadcastss( Xmm( beta_reg ), ptr[ rbx ] ); // load address of beta + + // Zero out a register + vxorps( Xmm( alpha_reg ), Xmm( alpha_reg ) ); + // cmp beta value with zero + vucomiss( Xmm( beta_reg ), Xmm( alpha_reg ) ); + // if beta=0, skip beta scaling + je( "BPOSTBETAOP", T_NEAR ); + + // check if buf_downscale is NULL + mov( rax, ptr[ rsp + stack_off_buf_downscale ] ); + cmp( rax, 0 ); + je( "BETAOP", T_NEAR ); + + // Check if is_first_k is 0 + mov( rcx, ptr[ rsp + stack_off_postop + + offsetof( lpgemm_post_op_attr, is_first_k ) ] ); + test( rcx, rcx ); + je( "BETAOP", T_NEAR ); + + L( "DOWNSCALEBETAOP" ); + vbroadcastss( Zmm( beta_reg ), ptr[ rbx ] ); + bf16_f32_beta_op( m_dim, n_dim ); + jmp( "BPOSTBETAOP", T_NEAR ); + + L( "BETAOP" ); + mov( rcx, r12 ); + vbroadcastss( Zmm( beta_reg ), ptr[ rbx ] ); + f32_f32_beta_op( m_dim, n_dim ); + + L( "BPOSTBETAOP" ); + + // Check if is_last_k is 0 + mov( rcx, ptr[ rsp + stack_off_postop + + offsetof( lpgemm_post_op_attr, is_last_k ) ] ); + test(rcx, rcx); + je( "POST_OPS_6x64_DISABLE", T_NEAR ); + + post_op_label_lastk_safe_jump(); + + + L( "POST_OPS_BIAS_6x64" ); + + mov( rax, ptr[ rdx + offsetof( lpgemm_post_op, op_args2 ) ] ); + mov( bl, ptr[ rax ] ); + + //check if op_args2 == 'R' + cmp( bl, 0x52 ); + je("BIAS_ROW_MAJOR", T_NEAR ); + // check if op_args2 == 'r + cmp( bl, 0x72 ); + je( "BIAS_ROW_MAJOR", T_NEAR ); + + bias_col_major( m_dim, n_dim ); + jmp( "POST_BIAS", T_NEAR ); + + L( "BIAS_ROW_MAJOR" ); + bias_row_major( m_dim, n_dim ); + + L( "POST_BIAS" ); + post_op_label_lastk_safe_jump_with_next_ptr(); + + L( "POST_OPS_RELU_6x64" ); + relu( m_dim, n_dim ); + post_op_label_lastk_safe_jump_with_next_ptr(); + + L( "POST_OPS_RELU_SCALE_6x64" ); + relu_scale( m_dim, n_dim ); + post_op_label_lastk_safe_jump_with_next_ptr(); + + L( "POST_OPS_GELU_TANH_6x64" ); + gelu_tanh( m_dim, n_dim ); + post_op_label_lastk_safe_jump_with_next_ptr(); + + L( "POST_OPS_GELU_ERF_6x64" ); + gelu_erf( m_dim, n_dim ); + post_op_label_lastk_safe_jump_with_next_ptr(); + + L( "POST_OPS_CLIP_6x64" ); + clip_f32( m_dim, n_dim ); + post_op_label_lastk_safe_jump_with_next_ptr(); + + L( "POST_OPS_DOWNSCALE_6x64" ); + post_op_label_lastk_safe_jump_with_next_ptr(); + + L( "POST_OPS_MATRIX_ADD_6x64" ); + + mov( rcx, ptr[ rsp + stack_off_postop + + offsetof( lpgemm_post_op_attr, c_stor_type ) ] ); + cmp( rcx, 4 ); + je( "BF16_MATADD", T_NEAR ); + f32_f32_matrix_add( m_dim, n_dim ); + jmp( "POST_MATADD", T_NEAR ); + L( "BF16_MATADD" ); + bf16_f32_matrix_add( m_dim, n_dim ); + L( "POST_MATADD" ); + + post_op_label_lastk_safe_jump_with_next_ptr(); + + L( "POST_OPS_SWISH_6x64" ); + swish( m_dim, n_dim ); + post_op_label_lastk_safe_jump_with_next_ptr(); + + L( "POST_OPS_6x64_DISABLE" ); + + // check if buf_downscale is NULL + mov( rax, ptr[ rsp + stack_off_buf_downscale ] ); + cmp( rax, 0 ); + je( "F32_STORE", T_NEAR ); + + // Check if is_last_k is 0 + mov( rcx, ptr[ rsp + stack_off_postop + + offsetof( lpgemm_post_op_attr, is_last_k ) ] ); + test( rcx, rcx ); + je( "F32_STORE", T_NEAR ); + + L( "BF16_STORE" ); + //mov( rcx, ptr[rsp + stack_off_buf_downscale]); + cvt_store_f32_bf16_mask( m_dim, n_dim ); + jmp( "END", T_NEAR ); + + L( "F32_STORE" ); + mov( rcx, r12 ); + store_f32( m_dim, n_dim ); + + L( "END" ); + + if( params->m_loop ) + { + mov(rax, ptr[ rsp + stack_off_ps_a ] ); + + lea( r12, ptr[ r12 + rdi * 4 ] ); + lea( r12, ptr[ r12 + rdi * 2 ] ); // c_ii = r12 += 6*rs_c; + + lea(r14, ptr[ r14 + rax ] ); // a_ii = r14 += ps_a2 + + //add(, m_dim ); + mov( rax, ptr[ rsp + stack_off_postop + + offsetof( lpgemm_post_op_attr, post_op_c_i ) ] ); + add( rax, m_dim); + + mov( ptr[ rsp + stack_off_postop + + offsetof( lpgemm_post_op_attr, post_op_c_i ) ], rax ); + + mov( rdx, ptr[ rsp + stack_off_temp_list ] ); + + dec(r11); + jne("BLOOP6X64I", T_NEAR); + } + + // release the space that is requested from stack + add( rsp, 512 ); + + // restore the callee-save registers. + postamble(); + + ret(); + + align(64); + L(tables); + + db(reinterpret_cast( &gelu_consts ), sizeof( gelu_consts ) ); + db(reinterpret_cast( &gelu_macros ), sizeof( gelu_macros ) ); + db(reinterpret_cast( &lpgemm_exp ), sizeof( lpgemm_exp ) ); + db(reinterpret_cast( &erf_consts ), sizeof( erf_consts ) ); + db(reinterpret_cast( &lpgemm_erf ), sizeof( lpgemm_erf ) ); + +} + +const void (* bli_lpgemm_jit:: get_function ()const)( lpgemm_jit_params_t*, + lpgemm_post_op_attr*, + lpgemm_post_op* ) +{ + return getCode(); +} + +const void* bli_lpgemm_jit:: get_code ()const +{ + return getCode(); +} +dim_t bli_lpgemm_jit:: get_size () +{ + return getSize(); +} diff --git a/addon/aocl_gemm/JIT/lpgemm_jit_bf16.h b/addon/aocl_gemm/JIT/lpgemm_jit_bf16.h new file mode 100644 index 0000000000..1b914ee7d6 --- /dev/null +++ b/addon/aocl_gemm/JIT/lpgemm_jit_bf16.h @@ -0,0 +1,198 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef JIT_BF16_H +#define JIT_BF16_H + +#include +#include +#include +#include +#include +#include "blis.h" +#include + +using namespace Xbyak; + +class bli_lpgemm_jit: public Xbyak::CodeGenerator +{ + +private : + void preamble(); + void postamble(); + void initialize_params( lpgemm_jit_inputs_t* params ); + void reg_init(dim_t m_dim, dim_t n_dim ); + void kernel_unroll( dim_t m_dim, dim_t n_dim ); + void prefetchC( dim_t m_dim, dim_t n_dim ); + void k_fringe_loop( dim_t m_dim, dim_t n_dim ); + void scale_alpha( dim_t m_dim, dim_t n_dim ); + // beta ops + void bf16_f32_beta_op( dim_t m_dim, dim_t n_dim ); + void f32_f32_beta_op( dim_t m_dim, dim_t n_dim ); + //postops + void clip_f32( dim_t m_dim, dim_t n_dim ); + void f32_f32_matrix_add( dim_t m_dim, dim_t n_dim ); + void bf16_f32_matrix_add( dim_t m_dim, dim_t n_dim ); + void bias_row_major( dim_t m_dim, dim_t n_dim ); + void bias_col_major( dim_t m_dim, dim_t n_dim ); + void relu( dim_t m_dim, dim_t n_dim ); + void relu_scale( dim_t m_dim, dim_t n_dim ); + void gelu_tanh( dim_t m_dim, dim_t n_dim ); + void POLY_EVAL_6_AVX512(); + void EXPF_AVX512(); + void TANHF_AVX512(); + void GELU_TANH_F32_AVX512_DEF( dim_t reg ); + void POLY_EVAL_HORNER_16_0_AVX512(); + void ERF_AVX512(); + void GELU_ERF_F32_AVX512_DEF( dim_t reg ); + void gelu_erf( dim_t m_dim, dim_t n_dim ); + void SWISH_F32_AVX512_DEF( dim_t reg ); + void swish( dim_t m, dim_t n ); + + void apply_post_ops_in_high_reg_pressure + ( + const dim_t num_post_op_regs, + std::function< void( dim_t ) > op_fn + ); + // C store functions + void cvt_store_f32_bf16_mask( dim_t m_dim, dim_t n_dim ); + void store_f32( dim_t m_dim, dim_t n_dim ); + + void post_op_label_lastk_safe_jump_with_next_ptr(); + void post_op_label_lastk_safe_jump(); + + + dim_t num_elems_per_reg = 64 / sizeof(float); + dim_t n_rem; + dim_t num_fma_regs; + dim_t fma_start_idx = 0; + dim_t load_start_idx = 0; + dim_t num_full_loads; + dim_t num_loads; + dim_t bcst_start_idx; + dim_t alpha_reg = fma_start_idx; + dim_t beta_reg; + + // registers used for gelu_tanh + const dim_t num_gelu_regs = 9; + const dim_t const1 = load_start_idx; + const dim_t const2 = load_start_idx+1; + const dim_t x = load_start_idx+2; + const dim_t r = load_start_idx+3; + const dim_t r2 = load_start_idx+4; + const dim_t z = load_start_idx+5; + const dim_t dn = load_start_idx+6; + const dim_t x_tanh = load_start_idx+7; + const dim_t q = load_start_idx+8; + + // registers for gelu_erf + const dim_t num_erf_regs = 5; + const dim_t x_erf = load_start_idx+4; + + // registers used for swish. Reusing the gelu_tanh registers. + const dim_t num_swish_regs = 9; + + const dim_t stack_off_ps_a = 8; + const dim_t stack_off_k_iter_before_prefetch = 16; + const dim_t stack_off_k_iter_after_prefetch = 24; + const dim_t stack_off_k_left = 32; + const dim_t stack_off_alpha = 40; + const dim_t stack_off_beta = 48; + const dim_t stack_off_b_ptr = 56; + const dim_t stack_off_postop = 64; + const dim_t stack_off_buf_downscale = stack_off_postop + + offsetof( lpgemm_post_op_attr, + buf_downscale ); + const dim_t stack_off_temp_list = stack_off_postop + + sizeof( lpgemm_post_op ); + + + const dim_t stack_off_zmm_stack = stack_off_temp_list + 8; + dim_t zmm_stack_top; + + void store_zmms_in_stack( dim_t reg_start_idx, + dim_t num_regs, + dim_t stack_off + ); + + void get_zmms_from_stack( dim_t reg_start_idx, + dim_t num_regs, + dim_t stack_off + ); + + float gelu_consts[7] = { 0.044715, 0.797884, -2, 0.5, -1, 2, 1 }; + float gelu_macros[6] = { 1.4426950408889634, 1.2582912E7, + -88.0f, 88.0f, + (float)(1.0/0.0), -2147483648 }; + + float lpgemm_exp[6] = { 1.0000000754895704, 0.6931472254087585, + 0.2402210737432219, 0.05550297297702539, + 0.009676036358193323, 0.001341000536524434 }; + + float erf_consts[4] = { 0.707107, 1.0, 0.5, 3.553f }; + + float lpgemm_erf[16] = { 1.1283793786592402, 2.5468861568875563E-5, + 0.3756169877289898, 0.004025179163741976, + 0.12947984300439994, 0.0412525204794885, + 0.03918550001070417, 0.07104542913277255, + 0.05717052146749476, 0.025310822854733135, + 0.0067305713376882076, 0.0010410692067591445, + 6.921588102382636E-5, 4.092409485758739E-6, + 1.033131746125426E-6, 5.2927177513236435E-8 }; + + + const dim_t gelu_consts_off = 0; + const dim_t gelu_macros_off = gelu_consts_off + sizeof(gelu_consts); + const dim_t lpgemm_exp_off = gelu_macros_off + sizeof(gelu_macros); + const dim_t erf_consts_off = lpgemm_exp_off + sizeof(lpgemm_exp); + const dim_t lpgemm_erf_off = erf_consts_off + sizeof(erf_consts); + + Xbyak::Address get_constant( dim_t table_off, dim_t value_off ) + { + return ptr[rip + tables + table_off + value_off * 4 ]; + } + Xbyak::Label tables; + +public: + bli_lpgemm_jit( void* buffer, size_t bufferSize ); + void generate_kernel( lpgemm_jit_inputs_t* params ); + const void (*get_function ()const)( lpgemm_jit_params_t*, + lpgemm_post_op_attr*, + lpgemm_post_op* + ); + const void *get_code ()const; + dim_t get_size (); + +}; +#endif diff --git a/addon/aocl_gemm/JIT/lpgemm_jit_c_connector.cpp b/addon/aocl_gemm/JIT/lpgemm_jit_c_connector.cpp new file mode 100644 index 0000000000..7b01b39a92 --- /dev/null +++ b/addon/aocl_gemm/JIT/lpgemm_jit_c_connector.cpp @@ -0,0 +1,72 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +//#include "libjit_c_connector.h" +#include "blis.h" +#include "lpgemm_jit_bf16.h" + + +#ifdef __cplusplus +extern "C" { +#endif + +static bli_lpgemm_jit *lpgemm_jit_objs[LPGEMM_BF16_MR][LPGEMM_BF16_NR]; + +void get_jit_kernel( lpgemm_jit_inputs_t *params, + void* buffer, + dim_t bufferSize + ) +{ + dim_t m_idx = ( params->MR ) % LPGEMM_BF16_MR; + dim_t n_idx = ( params->NR ) / NUM_F32_ELEMS_PER_ZMM; + lpgemm_jit_objs[m_idx][n_idx] = new bli_lpgemm_jit( buffer, bufferSize ); + lpgemm_jit_objs[m_idx][n_idx]->generate_kernel( params ); +} + +void* get_jit_code( lpgemm_jit_inputs_t *params ) +{ + dim_t m_idx = ( params->MR ) % LPGEMM_BF16_MR; + dim_t n_idx = ( params->NR ) / NUM_F32_ELEMS_PER_ZMM; + return ((void*) lpgemm_jit_objs[m_idx][n_idx]->get_code() ); +} + +dim_t get_kernel_size( lpgemm_jit_inputs_t *params ) +{ + dim_t m_idx = ( params->MR ) % LPGEMM_BF16_MR; + dim_t n_idx = ( params->NR ) / NUM_F32_ELEMS_PER_ZMM; + return lpgemm_jit_objs[m_idx][n_idx]->get_size(); +} +#ifdef __cplusplus +} +#endif diff --git a/addon/aocl_gemm/JIT/lpgemm_jit_c_connector.h b/addon/aocl_gemm/JIT/lpgemm_jit_c_connector.h new file mode 100644 index 0000000000..1ae0f16e3d --- /dev/null +++ b/addon/aocl_gemm/JIT/lpgemm_jit_c_connector.h @@ -0,0 +1,55 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef LIBJIT_C_CONNECTOR_H +#define LIBJIT_C_CONNECTOR_H + +#include "blis.h" +#ifdef __cplusplus +extern "C" { +#endif + +BLIS_EXPORT_ADDON void get_jit_kernel( lpgemm_jit_inputs_t* params, + void* buffer, + dim_t bufferSize + ); + +BLIS_EXPORT_ADDON void* get_jit_code( lpgemm_jit_inputs_t *params ); +BLIS_EXPORT_ADDON dim_t get_kernel_size( lpgemm_jit_inputs_t *params ); + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/addon/aocl_gemm/JIT/lpgemm_jit_typedefs.h b/addon/aocl_gemm/JIT/lpgemm_jit_typedefs.h new file mode 100644 index 0000000000..e8f426580b --- /dev/null +++ b/addon/aocl_gemm/JIT/lpgemm_jit_typedefs.h @@ -0,0 +1,78 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef JIT_TYPEDEFS_H +#define JIT_TYPEDEFS_H + +typedef struct +{ + bool m_loop; + bool alpha_scale; + int beta_scale; + dim_t MR; + dim_t NR; + bool generate_mask; +} lpgemm_jit_inputs_t; + +typedef struct { + uint64_t m; + uint64_t n; + uint64_t k; + uint64_t rs_a; + uint64_t cs_a; + uint64_t rs_b; + uint64_t cs_b; + uint64_t rs_c; + uint64_t cs_c; + bfloat16* a; + bfloat16* b; + float* c; + uint64_t ps_a2; + uint64_t m_iter; + uint64_t k_iter_before_prefetch; + uint64_t k_iter_after_prefetch; + uint64_t k_left; + float* alpha; + float* beta; + uint32_t mask16; + uint16_t mask32; +} lpgemm_jit_params_t; + +typedef enum{ + BLIS_BETA_ZERO = 0, + BLIS_BETA_ONE = 1, + BLIS_BETA_MINUS_ONE = 2, + BLIS_BETA_GEN = 3 +} beta_val; +#endif diff --git a/addon/aocl_gemm/JIT/xbyak/xbyak.h b/addon/aocl_gemm/JIT/xbyak/xbyak.h new file mode 100644 index 0000000000..0e96ff533f --- /dev/null +++ b/addon/aocl_gemm/JIT/xbyak/xbyak.h @@ -0,0 +1,3288 @@ +#pragma once +#ifndef XBYAK_XBYAK_H_ +#define XBYAK_XBYAK_H_ +/*! + @file xbyak.h + @brief Xbyak ; JIT assembler for x86(IA32)/x64 by C++ + @author herumi + @url https://github.com/herumi/xbyak + @note modified new BSD license + http://opensource.org/licenses/BSD-3-Clause +*/ +#if (not +0) && !defined(XBYAK_NO_OP_NAMES) // trick to detect whether 'not' is operator or not + #define XBYAK_NO_OP_NAMES +#endif + +#include // for debug print +#include +#include +#include +#include +#ifndef NDEBUG +#include +#endif + +// #define XBYAK_DISABLE_AVX512 + +#if !defined(XBYAK_USE_MMAP_ALLOCATOR) && !defined(XBYAK_DONT_USE_MMAP_ALLOCATOR) + #define XBYAK_USE_MMAP_ALLOCATOR +#endif +#if !defined(__GNUC__) || defined(__MINGW32__) + #undef XBYAK_USE_MMAP_ALLOCATOR +#endif + +#ifdef __GNUC__ + #define XBYAK_GNUC_PREREQ(major, minor) ((__GNUC__) * 100 + (__GNUC_MINOR__) >= (major) * 100 + (minor)) +#else + #define XBYAK_GNUC_PREREQ(major, minor) 0 +#endif + +// This covers -std=(gnu|c)++(0x|11|1y), -stdlib=libc++, and modern Microsoft. +#if ((defined(_MSC_VER) && (_MSC_VER >= 1600)) || defined(_LIBCPP_VERSION) ||\ + ((__cplusplus >= 201103) || defined(__GXX_EXPERIMENTAL_CXX0X__))) + #include + #define XBYAK_STD_UNORDERED_SET std::unordered_set + #include + #define XBYAK_STD_UNORDERED_MAP std::unordered_map + #define XBYAK_STD_UNORDERED_MULTIMAP std::unordered_multimap + +/* + Clang/llvm-gcc and ICC-EDG in 'GCC-mode' always claim to be GCC 4.2, using + libstdcxx 20070719 (from GCC 4.2.1, the last GPL 2 version). +*/ +#elif XBYAK_GNUC_PREREQ(4, 5) || (XBYAK_GNUC_PREREQ(4, 2) && __GLIBCXX__ >= 20070719) || defined(__INTEL_COMPILER) || defined(__llvm__) + #include + #define XBYAK_STD_UNORDERED_SET std::tr1::unordered_set + #include + #define XBYAK_STD_UNORDERED_MAP std::tr1::unordered_map + #define XBYAK_STD_UNORDERED_MULTIMAP std::tr1::unordered_multimap + +#elif defined(_MSC_VER) && (_MSC_VER >= 1500) && (_MSC_VER < 1600) + #include + #define XBYAK_STD_UNORDERED_SET std::tr1::unordered_set + #include + #define XBYAK_STD_UNORDERED_MAP std::tr1::unordered_map + #define XBYAK_STD_UNORDERED_MULTIMAP std::tr1::unordered_multimap + +#else + #include + #define XBYAK_STD_UNORDERED_SET std::set + #include + #define XBYAK_STD_UNORDERED_MAP std::map + #define XBYAK_STD_UNORDERED_MULTIMAP std::multimap +#endif +#ifdef _WIN32 + #ifndef WIN32_LEAN_AND_MEAN + #define WIN32_LEAN_AND_MEAN + #endif + #include + #include + #ifdef _MSC_VER + #define XBYAK_TLS __declspec(thread) + #else + #define XBYAK_TLS __thread + #endif +#elif defined(__GNUC__) + #include + #include + #include + #define XBYAK_TLS __thread +#endif +#if defined(__APPLE__) && !defined(XBYAK_DONT_USE_MAP_JIT) + #define XBYAK_USE_MAP_JIT + #include + #ifndef MAP_JIT + #define MAP_JIT 0x800 + #endif +#endif +#if !defined(_MSC_VER) || (_MSC_VER >= 1600) + #include +#endif + +// MFD_CLOEXEC defined only linux 3.17 or later. +// Android wraps the memfd_create syscall from API version 30. +#if !defined(MFD_CLOEXEC) || (defined(__ANDROID__) && __ANDROID_API__ < 30) + #undef XBYAK_USE_MEMFD +#endif + +#if defined(_WIN64) || defined(__MINGW64__) || (defined(__CYGWIN__) && defined(__x86_64__)) + #define XBYAK64_WIN +#elif defined(__x86_64__) + #define XBYAK64_GCC +#endif +#if !defined(XBYAK64) && !defined(XBYAK32) + #if defined(XBYAK64_GCC) || defined(XBYAK64_WIN) + #define XBYAK64 + #else + #define XBYAK32 + #endif +#endif + +#if (__cplusplus >= 201103) || (defined(_MSC_VER) && _MSC_VER >= 1900) + #undef XBYAK_TLS + #define XBYAK_TLS thread_local + #define XBYAK_VARIADIC_TEMPLATE + #define XBYAK_NOEXCEPT noexcept +#else + #define XBYAK_NOEXCEPT throw() +#endif + +// require c++14 or later +// Visual Studio 2017 version 15.0 or later +// g++-6 or later +#if ((__cplusplus >= 201402L) && !(!defined(__clang__) && defined(__GNUC__) && (__GNUC__ <= 5))) || (defined(_MSC_VER) && _MSC_VER >= 1910) + #define XBYAK_CONSTEXPR constexpr +#else + #define XBYAK_CONSTEXPR +#endif + +#ifdef _MSC_VER + #pragma warning(push) + #pragma warning(disable : 4514) /* remove inline function */ + #pragma warning(disable : 4786) /* identifier is too long */ + #pragma warning(disable : 4503) /* name is too long */ + #pragma warning(disable : 4127) /* constant expresison */ +#endif + +// disable -Warray-bounds because it may be a bug of gcc. https://gcc.gnu.org/bugzilla/show_bug.cgi?id=104603 +#if defined(__GNUC__) && !defined(__clang__) + #define XBYAK_DISABLE_WARNING_ARRAY_BOUNDS + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Warray-bounds" +#endif + +namespace Xbyak { + +enum { + DEFAULT_MAX_CODE_SIZE = 4096, + VERSION = 0x7050 /* 0xABCD = A.BC(.D) */ +}; + +#ifndef MIE_INTEGER_TYPE_DEFINED +#define MIE_INTEGER_TYPE_DEFINED +// for backward compatibility +typedef uint64_t uint64; +typedef int64_t sint64; +typedef uint32_t uint32; +typedef uint16_t uint16; +typedef uint8_t uint8; +#endif + +#ifndef MIE_ALIGN + #ifdef _MSC_VER + #define MIE_ALIGN(x) __declspec(align(x)) + #else + #define MIE_ALIGN(x) __attribute__((aligned(x))) + #endif +#endif +#ifndef MIE_PACK // for shufps + #define MIE_PACK(x, y, z, w) ((x) * 64 + (y) * 16 + (z) * 4 + (w)) +#endif + +enum { + ERR_NONE = 0, + ERR_BAD_ADDRESSING, + ERR_CODE_IS_TOO_BIG, + ERR_BAD_SCALE, + ERR_ESP_CANT_BE_INDEX, + ERR_BAD_COMBINATION, + ERR_BAD_SIZE_OF_REGISTER, + ERR_IMM_IS_TOO_BIG, + ERR_BAD_ALIGN, + ERR_LABEL_IS_REDEFINED, + ERR_LABEL_IS_TOO_FAR, + ERR_LABEL_IS_NOT_FOUND, + ERR_CODE_ISNOT_COPYABLE, + ERR_BAD_PARAMETER, + ERR_CANT_PROTECT, + ERR_CANT_USE_64BIT_DISP, + ERR_OFFSET_IS_TOO_BIG, + ERR_MEM_SIZE_IS_NOT_SPECIFIED, + ERR_BAD_MEM_SIZE, + ERR_BAD_ST_COMBINATION, + ERR_OVER_LOCAL_LABEL, // not used + ERR_UNDER_LOCAL_LABEL, + ERR_CANT_ALLOC, + ERR_ONLY_T_NEAR_IS_SUPPORTED_IN_AUTO_GROW, + ERR_BAD_PROTECT_MODE, + ERR_BAD_PNUM, + ERR_BAD_TNUM, + ERR_BAD_VSIB_ADDRESSING, + ERR_CANT_CONVERT, + ERR_LABEL_ISNOT_SET_BY_L, + ERR_LABEL_IS_ALREADY_SET_BY_L, + ERR_BAD_LABEL_STR, + ERR_MUNMAP, + ERR_OPMASK_IS_ALREADY_SET, + ERR_ROUNDING_IS_ALREADY_SET, + ERR_K0_IS_INVALID, + ERR_EVEX_IS_INVALID, + ERR_SAE_IS_INVALID, + ERR_ER_IS_INVALID, + ERR_INVALID_BROADCAST, + ERR_INVALID_OPMASK_WITH_MEMORY, + ERR_INVALID_ZERO, + ERR_INVALID_RIP_IN_AUTO_GROW, + ERR_INVALID_MIB_ADDRESS, + ERR_X2APIC_IS_NOT_SUPPORTED, + ERR_NOT_SUPPORTED, + ERR_SAME_REGS_ARE_INVALID, + ERR_INVALID_NF, + ERR_INVALID_ZU, + ERR_CANT_USE_REX2, + ERR_INVALID_DFV, + ERR_INVALID_REG_IDX, + ERR_INTERNAL // Put it at last. +}; + +inline const char *ConvertErrorToString(int err) +{ + static const char *errTbl[] = { + "none", + "bad addressing", + "code is too big", + "bad scale", + "esp can't be index", + "bad combination", + "bad size of register", + "imm is too big", + "bad align", + "label is redefined", + "label is too far", + "label is not found", + "code is not copyable", + "bad parameter", + "can't protect", + "can't use 64bit disp(use (void*))", + "offset is too big", + "MEM size is not specified", + "bad mem size", + "bad st combination", + "over local label", + "under local label", + "can't alloc", + "T_SHORT is not supported in AutoGrow", + "bad protect mode", + "bad pNum", + "bad tNum", + "bad vsib addressing", + "can't convert", + "label is not set by L()", + "label is already set by L()", + "bad label string", + "err munmap", + "opmask is already set", + "rounding is already set", + "k0 is invalid", + "evex is invalid", + "sae(suppress all exceptions) is invalid", + "er(embedded rounding) is invalid", + "invalid broadcast", + "invalid opmask with memory", + "invalid zero", + "invalid rip in AutoGrow", + "invalid mib address", + "x2APIC is not supported", + "not supported", + "same regs are invalid", + "invalid NF", + "invalid ZU", + "can't use rex2", + "invalid dfv", + "invalid reg index", + "internal error" + }; + assert(ERR_INTERNAL + 1 == sizeof(errTbl) / sizeof(*errTbl)); + return err <= ERR_INTERNAL ? errTbl[err] : "unknown err"; +} + +#ifdef XBYAK_NO_EXCEPTION +namespace local { + +inline int& GetErrorRef() { + static XBYAK_TLS int err = 0; + return err; +} + +inline void SetError(int err) { + if (local::GetErrorRef()) return; // keep the first err code + local::GetErrorRef() = err; +} + +} // local + +inline void ClearError() { + local::GetErrorRef() = 0; +} +inline int GetError() { return Xbyak::local::GetErrorRef(); } + +#define XBYAK_THROW(err) { Xbyak::local::SetError(err); return; } +#define XBYAK_THROW_RET(err, r) { Xbyak::local::SetError(err); return r; } + +#else +class Error : public std::exception { + int err_; +public: + explicit Error(int err) : err_(err) + { + if (err_ < 0 || err_ > ERR_INTERNAL) { + err_ = ERR_INTERNAL; + } + } + operator int() const { return err_; } + const char *what() const XBYAK_NOEXCEPT + { + return ConvertErrorToString(err_); + } +}; + +// dummy functions +inline void ClearError() { } +inline int GetError() { return 0; } + +inline const char *ConvertErrorToString(const Error& err) +{ + return err.what(); +} + +#define XBYAK_THROW(err) { throw Error(err); } +#define XBYAK_THROW_RET(err, r) { throw Error(err); } + +#endif + +inline void *AlignedMalloc(size_t size, size_t alignment) +{ +#ifdef __MINGW32__ + return __mingw_aligned_malloc(size, alignment); +#elif defined(_WIN32) + return _aligned_malloc(size, alignment); +#else + void *p; + int ret = posix_memalign(&p, alignment, size); + return (ret == 0) ? p : 0; +#endif +} + +inline void AlignedFree(void *p) +{ +#ifdef __MINGW32__ + __mingw_aligned_free(p); +#elif defined(_MSC_VER) + _aligned_free(p); +#else + free(p); +#endif +} + +template +inline const To CastTo(From p) XBYAK_NOEXCEPT +{ + return (const To)(size_t)(p); +} +namespace inner { + +#ifdef _WIN32 +struct SystemInfo { + SYSTEM_INFO info; + SystemInfo() + { + GetSystemInfo(&info); + } +}; +#endif +//static const size_t ALIGN_PAGE_SIZE = 4096; +inline size_t getPageSize() +{ +#ifdef _WIN32 + static const SystemInfo si; + return si.info.dwPageSize; +#else +#ifdef __GNUC__ + static const long pageSize = sysconf(_SC_PAGESIZE); + if (pageSize > 0) { + return (size_t)pageSize; + } +#endif + return 4096; +#endif +} + +inline bool IsInDisp8(uint32_t x) { return 0xFFFFFF80 <= x || x <= 0x7F; } +inline bool IsInInt32(uint64_t x) { return ~uint64_t(0x7fffffffu) <= x || x <= 0x7FFFFFFFU; } + +inline uint32_t VerifyInInt32(uint64_t x) +{ +#if defined(XBYAK64) && !defined(__ILP32__) + if (!IsInInt32(x)) XBYAK_THROW_RET(ERR_OFFSET_IS_TOO_BIG, 0) +#endif + return static_cast(x); +} + +enum LabelMode { + LasIs, // as is + Labs, // absolute + LaddTop // (addr + top) for mov(reg, label) with AutoGrow +}; + +} // inner + +/* + custom allocator +*/ +struct Allocator { + explicit Allocator(const std::string& = "") {} // same interface with MmapAllocator + virtual uint8_t *alloc(size_t size) { return reinterpret_cast(AlignedMalloc(size, inner::getPageSize())); } + virtual void free(uint8_t *p) { AlignedFree(p); } + virtual ~Allocator() {} + /* override to return false if you call protect() manually */ + virtual bool useProtect() const { return true; } +}; + +#ifdef XBYAK_USE_MMAP_ALLOCATOR +#ifdef XBYAK_USE_MAP_JIT +namespace util { + +inline int getMacOsVersionPure() +{ + char buf[64]; + size_t size = sizeof(buf); + int err = sysctlbyname("kern.osrelease", buf, &size, NULL, 0); + if (err != 0) return 0; + char *endp; + int major = strtol(buf, &endp, 10); + if (*endp != '.') return 0; + return major; +} + +inline int getMacOsVersion() +{ + static const int version = getMacOsVersionPure(); + return version; +} + +} // util +#endif +class MmapAllocator : public Allocator { + struct Allocation { + size_t size; +#if defined(XBYAK_USE_MEMFD) + // fd_ is only used with XBYAK_USE_MEMFD. We keep the file open + // during the lifetime of each allocation in order to support + // checkpoint/restore by unprivileged users. + int fd; +#endif + }; + const std::string name_; // only used with XBYAK_USE_MEMFD + typedef XBYAK_STD_UNORDERED_MAP AllocationList; + AllocationList allocList_; +public: + explicit MmapAllocator(const std::string& name = "xbyak") : name_(name) {} + uint8_t *alloc(size_t size) + { + const size_t alignedSizeM1 = inner::getPageSize() - 1; + size = (size + alignedSizeM1) & ~alignedSizeM1; +#if defined(MAP_ANONYMOUS) + int mode = MAP_PRIVATE | MAP_ANONYMOUS; +#elif defined(MAP_ANON) + int mode = MAP_PRIVATE | MAP_ANON; +#else + #error "not supported" +#endif +#if defined(XBYAK_USE_MAP_JIT) + const int mojaveVersion = 18; + if (util::getMacOsVersion() >= mojaveVersion) mode |= MAP_JIT; +#endif + int fd = -1; +#if defined(XBYAK_USE_MEMFD) + fd = memfd_create(name_.c_str(), MFD_CLOEXEC); + if (fd != -1) { + mode = MAP_SHARED; + if (ftruncate(fd, size) != 0) { + close(fd); + XBYAK_THROW_RET(ERR_CANT_ALLOC, 0) + } + } +#endif + void *p = mmap(NULL, size, PROT_READ | PROT_WRITE, mode, fd, 0); + if (p == MAP_FAILED) { + if (fd != -1) close(fd); + XBYAK_THROW_RET(ERR_CANT_ALLOC, 0) + } + assert(p); + Allocation &alloc = allocList_[(uintptr_t)p]; + alloc.size = size; +#if defined(XBYAK_USE_MEMFD) + alloc.fd = fd; +#endif + return (uint8_t*)p; + } + void free(uint8_t *p) + { + if (p == 0) return; + AllocationList::iterator i = allocList_.find((uintptr_t)p); + if (i == allocList_.end()) XBYAK_THROW(ERR_BAD_PARAMETER) + if (munmap((void*)i->first, i->second.size) < 0) XBYAK_THROW(ERR_MUNMAP) +#if defined(XBYAK_USE_MEMFD) + if (i->second.fd != -1) close(i->second.fd); +#endif + allocList_.erase(i); + } +}; +#else +typedef Allocator MmapAllocator; +#endif + +class Address; +class Reg; + +struct ApxFlagNF {}; +struct ApxFlagZU {}; + +// dfv (default flags value) is or operation of these flags +static const int T_of = 8; +static const int T_sf = 4; +static const int T_zf = 2; +static const int T_cf = 1; + +class Operand { + static const uint8_t EXT8BIT = 0x20; + unsigned int idx_:6; // 0..31 + EXT8BIT = 1 if spl/bpl/sil/dil + unsigned int kind_:10; + unsigned int bit_:14; +protected: + unsigned int zero_:1; + unsigned int mask_:3; + unsigned int rounding_:3; + unsigned int NF_:1; + unsigned int ZU_:1; // ND=ZU + void setIdx(int idx) { idx_ = idx; } +public: + enum Kind { + NONE = 0, + MEM = 1 << 0, + REG = 1 << 1, + MMX = 1 << 2, + FPU = 1 << 3, + XMM = 1 << 4, + YMM = 1 << 5, + ZMM = 1 << 6, + OPMASK = 1 << 7, + BNDREG = 1 << 8, + TMM = 1 << 9 + }; + enum Code { +#ifdef XBYAK64 + RAX = 0, RCX, RDX, RBX, RSP, RBP, RSI, RDI, R8, R9, R10, R11, R12, R13, R14, R15, + R16, R17, R18, R19, R20, R21, R22, R23, R24, R25, R26, R27, R28, R29, R30, R31, + R8D = 8, R9D, R10D, R11D, R12D, R13D, R14D, R15D, + R16D, R17D, R18D, R19D, R20D, R21D, R22D, R23D, R24D, R25D, R26D, R27D, R28D, R29D, R30D, R31D, + R8W = 8, R9W, R10W, R11W, R12W, R13W, R14W, R15W, + R16W, R17W, R18W, R19W, R20W, R21W, R22W, R23W, R24W, R25W, R26W, R27W, R28W, R29W, R30W, R31W, + R8B = 8, R9B, R10B, R11B, R12B, R13B, R14B, R15B, + R16B, R17B, R18B, R19B, R20B, R21B, R22B, R23B, R24B, R25B, R26B, R27B, R28B, R29B, R30B, R31B, + SPL = 4, BPL, SIL, DIL, +#endif + EAX = 0, ECX, EDX, EBX, ESP, EBP, ESI, EDI, + AX = 0, CX, DX, BX, SP, BP, SI, DI, + AL = 0, CL, DL, BL, AH, CH, DH, BH + }; + XBYAK_CONSTEXPR Operand() : idx_(0), kind_(0), bit_(0), zero_(0), mask_(0), rounding_(0), NF_(0), ZU_(0) { } + XBYAK_CONSTEXPR Operand(int idx, Kind kind, int bit, bool ext8bit = 0) + : idx_(static_cast(idx | (ext8bit ? EXT8BIT : 0))) + , kind_(kind) + , bit_(bit) + , zero_(0), mask_(0), rounding_(0), NF_(0), ZU_(0) + { + assert((bit_ & (bit_ - 1)) == 0); // bit must be power of two + } + XBYAK_CONSTEXPR Kind getKind() const { return static_cast(kind_); } + XBYAK_CONSTEXPR int getIdx() const { return idx_ & (EXT8BIT - 1); } + XBYAK_CONSTEXPR bool hasIdxBit(int bit) const { return idx_ & (1<= 4) goto ERR; +#else + if (idx >= 32) goto ERR; + if (4 <= idx && idx < 8) idx |= EXT8BIT; +#endif + break; + case 16: + case 32: + case 64: +#ifdef XBYAK32 + if (idx >= 16) goto ERR; +#else + if (idx >= 32) goto ERR; +#endif + break; + case 128: kind = XMM; break; + case 256: kind = YMM; break; + case 512: kind = ZMM; break; + case 8192: kind = TMM; break; + } + idx_ = idx; + kind_ = kind; + bit_ = bit; + if (bit >= 128) return; // keep mask_ and rounding_ + mask_ = 0; + rounding_ = 0; + return; + } +ERR: + XBYAK_THROW(ERR_CANT_CONVERT) +} + +class Label; + +struct Reg8; +struct Reg16; +struct Reg32; +#ifdef XBYAK64 +struct Reg64; +#endif +class Reg : public Operand { +public: + XBYAK_CONSTEXPR Reg() { } + XBYAK_CONSTEXPR Reg(int idx, Kind kind, int bit = 0, bool ext8bit = false) : Operand(idx, kind, bit, ext8bit) { } + // convert to Reg8/Reg16/Reg32/Reg64/XMM/YMM/ZMM + Reg changeBit(int bit) const { Reg r(*this); r.setBit(bit); return r; } + Reg8 cvt8() const; + Reg16 cvt16() const; + Reg32 cvt32() const; +#ifdef XBYAK64 + Reg64 cvt64() const; +#endif + Reg operator|(const ApxFlagNF&) const { Reg r(*this); r.setNF(); return r; } + Reg operator|(const ApxFlagZU&) const { Reg r(*this); r.setZU(); return r; } +}; + +inline const Reg& Operand::getReg() const +{ + assert(!isMEM()); + return static_cast(*this); +} + +struct Reg8 : public Reg { + explicit XBYAK_CONSTEXPR Reg8(int idx = 0, bool ext8bit = false) : Reg(idx, Operand::REG, 8, ext8bit) { } +}; + +struct Reg16 : public Reg { + explicit XBYAK_CONSTEXPR Reg16(int idx = 0) : Reg(idx, Operand::REG, 16) { } +}; + +struct Mmx : public Reg { + explicit XBYAK_CONSTEXPR Mmx(int idx = 0, Kind kind = Operand::MMX, int bit = 64) : Reg(idx, kind, bit) { } +}; + +struct EvexModifierRounding { + enum { + T_RN_SAE = 1, + T_RD_SAE = 2, + T_RU_SAE = 3, + T_RZ_SAE = 4, + T_SAE = 5 + }; + explicit XBYAK_CONSTEXPR EvexModifierRounding(int rounding) : rounding(rounding) {} + int rounding; +}; +struct EvexModifierZero{ XBYAK_CONSTEXPR EvexModifierZero() {}}; + +struct Xmm : public Mmx { + explicit XBYAK_CONSTEXPR Xmm(int idx = 0, Kind kind = Operand::XMM, int bit = 128) : Mmx(idx, kind, bit) { } + XBYAK_CONSTEXPR Xmm(Kind kind, int idx) : Mmx(idx, kind, kind == XMM ? 128 : kind == YMM ? 256 : 512) { } + Xmm operator|(const EvexModifierRounding& emr) const { Xmm r(*this); r.setRounding(emr.rounding); return r; } + Xmm copyAndSetIdx(int idx) const { Xmm ret(*this); ret.setIdx(idx); return ret; } + Xmm copyAndSetKind(Operand::Kind kind) const { Xmm ret(*this); ret.setKind(kind); return ret; } +}; + +struct Ymm : public Xmm { + explicit XBYAK_CONSTEXPR Ymm(int idx = 0, Kind kind = Operand::YMM, int bit = 256) : Xmm(idx, kind, bit) { } + Ymm operator|(const EvexModifierRounding& emr) const { Ymm r(*this); r.setRounding(emr.rounding); return r; } +}; + +struct Zmm : public Ymm { + explicit XBYAK_CONSTEXPR Zmm(int idx = 0) : Ymm(idx, Operand::ZMM, 512) { } + Zmm operator|(const EvexModifierRounding& emr) const { Zmm r(*this); r.setRounding(emr.rounding); return r; } +}; + +#ifdef XBYAK64 +struct Tmm : public Reg { + explicit XBYAK_CONSTEXPR Tmm(int idx = 0, Kind kind = Operand::TMM, int bit = 8192) : Reg(idx, kind, bit) { } +}; +#endif + +struct Opmask : public Reg { + explicit XBYAK_CONSTEXPR Opmask(int idx = 0) : Reg(idx, Operand::OPMASK, 64) {} +}; + +struct BoundsReg : public Reg { + explicit XBYAK_CONSTEXPR BoundsReg(int idx = 0) : Reg(idx, Operand::BNDREG, 128) {} +}; + +templateT operator|(const T& x, const Opmask& k) { T r(x); r.setOpmaskIdx(k.getIdx()); return r; } +templateT operator|(const T& x, const EvexModifierZero&) { T r(x); r.setZero(); return r; } +templateT operator|(const T& x, const EvexModifierRounding& emr) { T r(x); r.setRounding(emr.rounding); return r; } + +struct Fpu : public Reg { + explicit XBYAK_CONSTEXPR Fpu(int idx = 0) : Reg(idx, Operand::FPU, 32) { } +}; + +struct Reg32e : public Reg { + explicit XBYAK_CONSTEXPR Reg32e(int idx, int bit) : Reg(idx, Operand::REG, bit) {} + Reg32e operator|(const ApxFlagNF&) const { Reg32e r(*this); r.setNF(); return r; } + Reg32e operator|(const ApxFlagZU&) const { Reg32e r(*this); r.setZU(); return r; } +}; +struct Reg32 : public Reg32e { + explicit XBYAK_CONSTEXPR Reg32(int idx = 0) : Reg32e(idx, 32) {} +}; +#ifdef XBYAK64 +struct Reg64 : public Reg32e { + explicit XBYAK_CONSTEXPR Reg64(int idx = 0) : Reg32e(idx, 64) {} +}; +struct RegRip { + int64_t disp_; + const Label* label_; + bool isAddr_; + explicit XBYAK_CONSTEXPR RegRip(int64_t disp = 0, const Label* label = 0, bool isAddr = false) : disp_(disp), label_(label), isAddr_(isAddr) {} + friend const RegRip operator+(const RegRip& r, int disp) { + return RegRip(r.disp_ + disp, r.label_, r.isAddr_); + } + friend const RegRip operator-(const RegRip& r, int disp) { + return RegRip(r.disp_ - disp, r.label_, r.isAddr_); + } + friend const RegRip operator+(const RegRip& r, int64_t disp) { + return RegRip(r.disp_ + disp, r.label_, r.isAddr_); + } + friend const RegRip operator-(const RegRip& r, int64_t disp) { + return RegRip(r.disp_ - disp, r.label_, r.isAddr_); + } + friend const RegRip operator+(const RegRip& r, const Label& label) { + if (r.label_ || r.isAddr_) XBYAK_THROW_RET(ERR_BAD_ADDRESSING, RegRip()); + return RegRip(r.disp_, &label); + } + friend const RegRip operator+(const RegRip& r, const void *addr) { + if (r.label_ || r.isAddr_) XBYAK_THROW_RET(ERR_BAD_ADDRESSING, RegRip()); + return RegRip(r.disp_ + (int64_t)addr, 0, true); + } +}; +#endif + +inline Reg8 Reg::cvt8() const +{ + Reg r = changeBit(8); return Reg8(r.getIdx(), r.isExt8bit()); +} + +inline Reg16 Reg::cvt16() const +{ + return Reg16(changeBit(16).getIdx()); +} + +inline Reg32 Reg::cvt32() const +{ + return Reg32(changeBit(32).getIdx()); +} + +#ifdef XBYAK64 +inline Reg64 Reg::cvt64() const +{ + return Reg64(changeBit(64).getIdx()); +} +#endif + +#ifndef XBYAK_DISABLE_SEGMENT +// not derived from Reg +class Segment { + int idx_; +public: + enum { + es, cs, ss, ds, fs, gs + }; + explicit XBYAK_CONSTEXPR Segment(int idx) : idx_(idx) { assert(0 <= idx_ && idx_ < 6); } + int getIdx() const { return idx_; } + const char *toString() const + { + static const char tbl[][3] = { + "es", "cs", "ss", "ds", "fs", "gs" + }; + return tbl[idx_]; + } +}; +#endif + +class RegExp { +public: +#ifdef XBYAK64 + enum { i32e = 32 | 64 }; +#else + enum { i32e = 32 }; +#endif + XBYAK_CONSTEXPR RegExp(size_t disp = 0) : scale_(0), disp_(disp) { } + XBYAK_CONSTEXPR RegExp(const Reg& r, int scale = 1) + : scale_(scale) + , disp_(0) + { + if (!r.isREG(i32e) && !r.is(Reg::XMM|Reg::YMM|Reg::ZMM|Reg::TMM)) XBYAK_THROW(ERR_BAD_SIZE_OF_REGISTER) + if (scale == 0) return; + if (scale != 1 && scale != 2 && scale != 4 && scale != 8) XBYAK_THROW(ERR_BAD_SCALE) + if (r.getBit() >= 128 || scale != 1) { // xmm/ymm is always index + index_ = r; + } else { + base_ = r; + } + } + bool isVsib(int bit = 128 | 256 | 512) const { return index_.isBit(bit); } + RegExp optimize() const + { + RegExp exp = *this; + // [reg * 2] => [reg + reg] + if (index_.isBit(i32e) && !base_.getBit() && scale_ == 2) { + exp.base_ = index_; + exp.scale_ = 1; + } + return exp; + } + bool operator==(const RegExp& rhs) const + { + return base_ == rhs.base_ && index_ == rhs.index_ && disp_ == rhs.disp_ && scale_ == rhs.scale_; + } + const Reg& getBase() const { return base_; } + const Reg& getIndex() const { return index_; } + int getScale() const { return scale_; } + size_t getDisp() const { return disp_; } + XBYAK_CONSTEXPR void verify() const + { + if (base_.getBit() >= 128) XBYAK_THROW(ERR_BAD_SIZE_OF_REGISTER) + if (index_.getBit() && index_.getBit() <= 64) { + if (index_.getIdx() == Operand::ESP) XBYAK_THROW(ERR_ESP_CANT_BE_INDEX) + if (base_.getBit() && base_.getBit() != index_.getBit()) XBYAK_THROW(ERR_BAD_SIZE_OF_REGISTER) + } + } + friend RegExp operator+(const RegExp& a, const RegExp& b); + friend RegExp operator-(const RegExp& e, size_t disp); +private: + /* + [base_ + index_ * scale_ + disp_] + base : Reg32e, index : Reg32e(w/o esp), Xmm, Ymm + */ + Reg base_; + Reg index_; + int scale_; + size_t disp_; +}; + +inline RegExp operator+(const RegExp& a, const RegExp& b) +{ + if (a.index_.getBit() && b.index_.getBit()) XBYAK_THROW_RET(ERR_BAD_ADDRESSING, RegExp()) + RegExp ret = a; + if (!ret.index_.getBit()) { ret.index_ = b.index_; ret.scale_ = b.scale_; } + if (b.base_.getBit()) { + if (ret.base_.getBit()) { + if (ret.index_.getBit()) XBYAK_THROW_RET(ERR_BAD_ADDRESSING, RegExp()) + // base + base => base + index * 1 + ret.index_ = b.base_; + // [reg + esp] => [esp + reg] + if (ret.index_.getIdx() == Operand::ESP) std::swap(ret.base_, ret.index_); + ret.scale_ = 1; + } else { + ret.base_ = b.base_; + } + } + ret.disp_ += b.disp_; + return ret; +} +inline RegExp operator*(const Reg& r, int scale) +{ + return RegExp(r, scale); +} +inline RegExp operator*(int scale, const Reg& r) +{ + return r * scale; +} +inline RegExp operator-(const RegExp& e, size_t disp) +{ + RegExp ret = e; + ret.disp_ -= disp; + return ret; +} + +// 2nd parameter for constructor of CodeArray(maxSize, userPtr, alloc) +void *const AutoGrow = (void*)1; //-V566 +void *const DontSetProtectRWE = (void*)2; //-V566 + +class CodeArray { + enum Type { + USER_BUF = 1, // use userPtr(non alignment, non protect) + ALLOC_BUF, // use new(alignment, protect) + AUTO_GROW // automatically move and grow memory if necessary + }; + CodeArray(const CodeArray& rhs); + void operator=(const CodeArray&); + bool isAllocType() const { return type_ == ALLOC_BUF || type_ == AUTO_GROW; } + struct AddrInfo { + size_t codeOffset; // position to write + size_t jmpAddr; // value to write + int jmpSize; // size of jmpAddr + inner::LabelMode mode; + AddrInfo(size_t _codeOffset, size_t _jmpAddr, int _jmpSize, inner::LabelMode _mode) + : codeOffset(_codeOffset), jmpAddr(_jmpAddr), jmpSize(_jmpSize), mode(_mode) {} + uint64_t getVal(const uint8_t *top) const + { + uint64_t disp = (mode == inner::LaddTop) ? jmpAddr + size_t(top) : (mode == inner::LasIs) ? jmpAddr : jmpAddr - size_t(top); + if (jmpSize == 4) disp = inner::VerifyInInt32(disp); + return disp; + } + }; + typedef std::list AddrInfoList; + AddrInfoList addrInfoList_; + const Type type_; +#ifdef XBYAK_USE_MMAP_ALLOCATOR + MmapAllocator defaultAllocator_; +#else + Allocator defaultAllocator_; +#endif + Allocator *alloc_; +protected: + size_t maxSize_; + uint8_t *top_; + size_t size_; + bool isCalledCalcJmpAddress_; + + bool useProtect() const { return alloc_->useProtect(); } + /* + allocate new memory and copy old data to the new area + */ + void growMemory() + { + const size_t newSize = (std::max)(DEFAULT_MAX_CODE_SIZE, maxSize_ * 2); + uint8_t *newTop = alloc_->alloc(newSize); + if (newTop == 0) XBYAK_THROW(ERR_CANT_ALLOC) + for (size_t i = 0; i < size_; i++) newTop[i] = top_[i]; + alloc_->free(top_); + top_ = newTop; + maxSize_ = newSize; + } + /* + calc jmp address for AutoGrow mode + */ + void calcJmpAddress() + { + if (isCalledCalcJmpAddress_) return; + for (AddrInfoList::const_iterator i = addrInfoList_.begin(), ie = addrInfoList_.end(); i != ie; ++i) { + uint64_t disp = i->getVal(top_); + rewrite(i->codeOffset, disp, i->jmpSize); + } + isCalledCalcJmpAddress_ = true; + } +public: + enum ProtectMode { + PROTECT_RW = 0, // read/write + PROTECT_RWE = 1, // read/write/exec + PROTECT_RE = 2 // read/exec + }; + explicit CodeArray(size_t maxSize, void *userPtr = 0, Allocator *allocator = 0) + : type_(userPtr == AutoGrow ? AUTO_GROW : (userPtr == 0 || userPtr == DontSetProtectRWE) ? ALLOC_BUF : USER_BUF) + , alloc_(allocator ? allocator : (Allocator*)&defaultAllocator_) + , maxSize_(maxSize) + , top_(type_ == USER_BUF ? reinterpret_cast(userPtr) : alloc_->alloc((std::max)(maxSize, 1))) + , size_(0) + , isCalledCalcJmpAddress_(false) + { + if (maxSize_ > 0 && top_ == 0) XBYAK_THROW(ERR_CANT_ALLOC) + if ((type_ == ALLOC_BUF && userPtr != DontSetProtectRWE && useProtect()) && !setProtectMode(PROTECT_RWE, false)) { + alloc_->free(top_); + XBYAK_THROW(ERR_CANT_PROTECT) + } + } + virtual ~CodeArray() + { + if (isAllocType()) { + if (useProtect()) setProtectModeRW(false); + alloc_->free(top_); + } + } + bool setProtectMode(ProtectMode mode, bool throwException = true) + { + bool isOK = protect(top_, maxSize_, mode); + if (isOK) return true; + if (throwException) XBYAK_THROW_RET(ERR_CANT_PROTECT, false) + return false; + } + bool setProtectModeRE(bool throwException = true) { return setProtectMode(PROTECT_RE, throwException); } + bool setProtectModeRW(bool throwException = true) { return setProtectMode(PROTECT_RW, throwException); } + void resetSize() + { + size_ = 0; + addrInfoList_.clear(); + isCalledCalcJmpAddress_ = false; + } + void db(int code) + { + if (size_ >= maxSize_) { + if (type_ == AUTO_GROW) { + growMemory(); + } else { + XBYAK_THROW(ERR_CODE_IS_TOO_BIG) + } + } + top_[size_++] = static_cast(code); + } + void db(const uint8_t *code, size_t codeSize) + { + for (size_t i = 0; i < codeSize; i++) db(code[i]); + } + void db(uint64_t code, size_t codeSize) + { + if (codeSize > 8) XBYAK_THROW(ERR_BAD_PARAMETER) + for (size_t i = 0; i < codeSize; i++) db(static_cast(code >> (i * 8))); + } + void dw(uint32_t code) { db(code, 2); } + void dd(uint32_t code) { db(code, 4); } + void dq(uint64_t code) { db(code, 8); } + const uint8_t *getCode() const { return top_; } + template + const F getCode() const { return reinterpret_cast(top_); } + const uint8_t *getCurr() const { return &top_[size_]; } + template + const F getCurr() const { return reinterpret_cast(&top_[size_]); } + size_t getSize() const { return size_; } + void setSize(size_t size) + { + if (size > maxSize_) XBYAK_THROW(ERR_OFFSET_IS_TOO_BIG) + size_ = size; + } + void dump() const + { + const uint8_t *p = getCode(); + size_t bufSize = getSize(); + size_t remain = bufSize; + for (int i = 0; i < 4; i++) { + size_t disp = 16; + if (remain < 16) { + disp = remain; + } + for (size_t j = 0; j < 16; j++) { + if (j < disp) { + printf("%02X", p[i * 16 + j]); + } + } + putchar('\n'); + remain -= disp; + if (remain == 0) { + break; + } + } + } + /* + @param offset [in] offset from top + @param disp [in] offset from the next of jmp + @param size [in] write size(1, 2, 4, 8) + */ + void rewrite(size_t offset, uint64_t disp, size_t size) + { + assert(offset < maxSize_); + if (size != 1 && size != 2 && size != 4 && size != 8) XBYAK_THROW(ERR_BAD_PARAMETER) + uint8_t *const data = top_ + offset; + for (size_t i = 0; i < size; i++) { + data[i] = static_cast(disp >> (i * 8)); + } + } + void save(size_t offset, size_t val, int size, inner::LabelMode mode) + { + addrInfoList_.push_back(AddrInfo(offset, val, size, mode)); + } + bool isAutoGrow() const { return type_ == AUTO_GROW; } + bool isCalledCalcJmpAddress() const { return isCalledCalcJmpAddress_; } + /** + change exec permission of memory + @param addr [in] buffer address + @param size [in] buffer size + @param protectMode [in] mode(RW/RWE/RE) + @return true(success), false(failure) + */ + static inline bool protect(const void *addr, size_t size, int protectMode) + { +#if defined(_WIN32) + const DWORD c_rw = PAGE_READWRITE; + const DWORD c_rwe = PAGE_EXECUTE_READWRITE; + const DWORD c_re = PAGE_EXECUTE_READ; + DWORD mode; +#else + const int c_rw = PROT_READ | PROT_WRITE; + const int c_rwe = PROT_READ | PROT_WRITE | PROT_EXEC; + const int c_re = PROT_READ | PROT_EXEC; + int mode; +#endif + switch (protectMode) { + case PROTECT_RW: mode = c_rw; break; + case PROTECT_RWE: mode = c_rwe; break; + case PROTECT_RE: mode = c_re; break; + default: + return false; + } +#if defined(_WIN32) + DWORD oldProtect; + return VirtualProtect(const_cast(addr), size, mode, &oldProtect) != 0; +#elif defined(__GNUC__) + size_t pageSize = sysconf(_SC_PAGESIZE); + size_t iaddr = reinterpret_cast(addr); + size_t roundAddr = iaddr & ~(pageSize - static_cast(1)); + return mprotect(reinterpret_cast(roundAddr), size + (iaddr - roundAddr), mode) == 0; +#else + return true; +#endif + } + /** + get aligned memory pointer + @param addr [in] address + @param alignedSize [in] power of two + @return aligned addr by alingedSize + */ + static inline uint8_t *getAlignedAddress(uint8_t *addr, size_t alignedSize = 16) + { + return reinterpret_cast((reinterpret_cast(addr) + alignedSize - 1) & ~(alignedSize - static_cast(1))); + } +}; + +class Address : public Operand { +public: + enum Mode { + M_ModRM, + M_64bitDisp, + M_rip, + M_ripAddr + }; + XBYAK_CONSTEXPR Address(uint32_t sizeBit, bool broadcast, const RegExp& e) + : Operand(0, MEM, sizeBit), e_(e), label_(0), mode_(M_ModRM), immSize(0), disp8N(0), permitVsib(false), broadcast_(broadcast), optimize_(true) + { + e_.verify(); + } +#ifdef XBYAK64 + explicit XBYAK_CONSTEXPR Address(size_t disp) + : Operand(0, MEM, 64), e_(disp), label_(0), mode_(M_64bitDisp), immSize(0), disp8N(0), permitVsib(false), broadcast_(false), optimize_(true) { } + XBYAK_CONSTEXPR Address(uint32_t sizeBit, bool broadcast, const RegRip& addr) + : Operand(0, MEM, sizeBit), e_(addr.disp_), label_(addr.label_), mode_(addr.isAddr_ ? M_ripAddr : M_rip), immSize(0), disp8N(0), permitVsib(false), broadcast_(broadcast), optimize_(true) { } +#endif + RegExp getRegExp() const + { + return optimize_ ? e_.optimize() : e_; + } + Address cloneNoOptimize() const { Address addr = *this; addr.optimize_ = false; return addr; } + Mode getMode() const { return mode_; } + bool is32bit() const { return e_.getBase().getBit() == 32 || e_.getIndex().getBit() == 32; } + bool isOnlyDisp() const { return !e_.getBase().getBit() && !e_.getIndex().getBit(); } // for mov eax + size_t getDisp() const { return e_.getDisp(); } + bool is64bitDisp() const { return mode_ == M_64bitDisp; } // for moffset + bool isBroadcast() const { return broadcast_; } + bool hasRex2() const { return e_.getBase().hasRex2() || e_.getIndex().hasRex2(); } + const Label* getLabel() const { return label_; } + bool operator==(const Address& rhs) const + { + return getBit() == rhs.getBit() && e_ == rhs.e_ && label_ == rhs.label_ && mode_ == rhs.mode_ && immSize == rhs.immSize && disp8N == rhs.disp8N && permitVsib == rhs.permitVsib && broadcast_ == rhs.broadcast_ && optimize_ == rhs.optimize_; + } + bool operator!=(const Address& rhs) const { return !operator==(rhs); } + bool isVsib() const { return e_.isVsib(); } +private: + RegExp e_; + const Label* label_; + Mode mode_; +public: + int immSize; // the size of immediate value of nmemonics (0, 1, 2, 4) + int disp8N; // 0(normal), 1(force disp32), disp8N = {2, 4, 8} + bool permitVsib; +private: + bool broadcast_; + bool optimize_; +}; + +inline const Address& Operand::getAddress() const +{ + assert(isMEM()); + return static_cast(*this); +} +inline Address Operand::getAddress(int immSize) const +{ + Address addr = getAddress(); + addr.immSize = immSize; + return addr; +} + +inline bool Operand::operator==(const Operand& rhs) const +{ + if (isMEM() && rhs.isMEM()) return this->getAddress() == rhs.getAddress(); + return isEqualIfNotInherited(rhs); +} + +inline XBYAK_CONSTEXPR bool Operand::hasRex2() const +{ + return (isREG() && isExtIdx2()) || (isMEM() && static_cast(*this).hasRex2()); +} + +class AddressFrame { + void operator=(const AddressFrame&); + AddressFrame(const AddressFrame&); +public: + const uint32_t bit_; + const bool broadcast_; + explicit XBYAK_CONSTEXPR AddressFrame(uint32_t bit, bool broadcast = false) : bit_(bit), broadcast_(broadcast) { } + Address operator[](const RegExp& e) const + { + return Address(bit_, broadcast_, e); + } + Address operator[](const void *disp) const + { + return Address(bit_, broadcast_, RegExp(reinterpret_cast(disp))); + } +#ifdef XBYAK64 + Address operator[](uint64_t disp) const { return Address(disp); } + Address operator[](const RegRip& addr) const { return Address(bit_, broadcast_, addr); } +#endif +}; + +struct JmpLabel { + size_t endOfJmp; /* offset from top to the end address of jmp */ + int jmpSize; + inner::LabelMode mode; + size_t disp; // disp for [rip + disp] + explicit JmpLabel(size_t endOfJmp = 0, int jmpSize = 0, inner::LabelMode mode = inner::LasIs, size_t disp = 0) + : endOfJmp(endOfJmp), jmpSize(jmpSize), mode(mode), disp(disp) + { + } +}; + +class LabelManager; + +class Label { + mutable LabelManager *mgr; + mutable int id; + friend class LabelManager; +public: + Label() : mgr(0), id(0) {} + Label(const Label& rhs); + Label& operator=(const Label& rhs); + ~Label(); + void clear() { mgr = 0; id = 0; } + int getId() const { return id; } + const uint8_t *getAddress() const; + + // backward compatibility + static inline std::string toStr(int num) + { + char buf[16]; +#if defined(_MSC_VER) && (_MSC_VER < 1900) + _snprintf_s +#else + snprintf +#endif + (buf, sizeof(buf), ".%08x", num); + return buf; + } +}; + +class LabelManager { + // for string label + struct SlabelVal { + size_t offset; + SlabelVal(size_t offset) : offset(offset) {} + }; + typedef XBYAK_STD_UNORDERED_MAP SlabelDefList; + typedef XBYAK_STD_UNORDERED_MULTIMAP SlabelUndefList; + struct SlabelState { + SlabelDefList defList; + SlabelUndefList undefList; + }; + typedef std::list StateList; + // for Label class + struct ClabelVal { + ClabelVal(size_t offset = 0) : offset(offset), refCount(1) {} + size_t offset; + int refCount; + }; + typedef XBYAK_STD_UNORDERED_MAP ClabelDefList; + typedef XBYAK_STD_UNORDERED_MULTIMAP ClabelUndefList; + typedef XBYAK_STD_UNORDERED_SET LabelPtrList; + + CodeArray *base_; + // global : stateList_.front(), local : stateList_.back() + StateList stateList_; + mutable int labelId_; + ClabelDefList clabelDefList_; + ClabelUndefList clabelUndefList_; + LabelPtrList labelPtrList_; + + int getId(const Label& label) const + { + if (label.id == 0) label.id = labelId_++; + return label.id; + } + template + void define_inner(DefList& defList, UndefList& undefList, const T& labelId, size_t addrOffset) + { + // add label + typename DefList::value_type item(labelId, addrOffset); + std::pair ret = defList.insert(item); + if (!ret.second) XBYAK_THROW(ERR_LABEL_IS_REDEFINED) + // search undefined label + for (;;) { + typename UndefList::iterator itr = undefList.find(labelId); + if (itr == undefList.end()) break; + const JmpLabel *jmp = &itr->second; + const size_t offset = jmp->endOfJmp - jmp->jmpSize; + size_t disp; + if (jmp->mode == inner::LaddTop) { + disp = addrOffset; + } else if (jmp->mode == inner::Labs) { + disp = size_t(base_->getCurr()); + } else { + disp = addrOffset - jmp->endOfJmp + jmp->disp; +#ifdef XBYAK64 + if (jmp->jmpSize <= 4 && !inner::IsInInt32(disp)) XBYAK_THROW(ERR_OFFSET_IS_TOO_BIG) +#endif + if (jmp->jmpSize == 1 && !inner::IsInDisp8((uint32_t)disp)) XBYAK_THROW(ERR_LABEL_IS_TOO_FAR) + } + if (base_->isAutoGrow()) { + base_->save(offset, disp, jmp->jmpSize, jmp->mode); + } else { + base_->rewrite(offset, disp, jmp->jmpSize); + } + undefList.erase(itr); + } + } + template + bool getOffset_inner(const DefList& defList, size_t *offset, const T& label) const + { + typename DefList::const_iterator i = defList.find(label); + if (i == defList.end()) return false; + *offset = i->second.offset; + return true; + } + friend class Label; + void incRefCount(int id, Label *label) + { + clabelDefList_[id].refCount++; + labelPtrList_.insert(label); + } + void decRefCount(int id, Label *label) + { + labelPtrList_.erase(label); + ClabelDefList::iterator i = clabelDefList_.find(id); + if (i == clabelDefList_.end()) return; + if (i->second.refCount == 1) { + clabelDefList_.erase(id); + } else { + --i->second.refCount; + } + } + template + bool hasUndefinedLabel_inner(const T& list) const + { +#ifndef NDEBUG + for (typename T::const_iterator i = list.begin(); i != list.end(); ++i) { + std::cerr << "undefined label:" << i->first << std::endl; + } +#endif + return !list.empty(); + } + // detach all labels linked to LabelManager + void resetLabelPtrList() + { + for (LabelPtrList::iterator i = labelPtrList_.begin(), ie = labelPtrList_.end(); i != ie; ++i) { + (*i)->clear(); + } + labelPtrList_.clear(); + } +public: + LabelManager() + { + reset(); + } + ~LabelManager() + { + resetLabelPtrList(); + } + void reset() + { + base_ = 0; + labelId_ = 1; + stateList_.clear(); + stateList_.push_back(SlabelState()); + stateList_.push_back(SlabelState()); + clabelDefList_.clear(); + clabelUndefList_.clear(); + resetLabelPtrList(); + } + void enterLocal() + { + stateList_.push_back(SlabelState()); + } + void leaveLocal() + { + if (stateList_.size() <= 2) XBYAK_THROW(ERR_UNDER_LOCAL_LABEL) + if (hasUndefinedLabel_inner(stateList_.back().undefList)) XBYAK_THROW(ERR_LABEL_IS_NOT_FOUND) + stateList_.pop_back(); + } + void set(CodeArray *base) { base_ = base; } + void defineSlabel(std::string label) + { + if (label == "@b" || label == "@f") XBYAK_THROW(ERR_BAD_LABEL_STR) + if (label == "@@") { + SlabelDefList& defList = stateList_.front().defList; + SlabelDefList::iterator i = defList.find("@f"); + if (i != defList.end()) { + defList.erase(i); + label = "@b"; + } else { + i = defList.find("@b"); + if (i != defList.end()) { + defList.erase(i); + } + label = "@f"; + } + } + SlabelState& st = *label.c_str() == '.' ? stateList_.back() : stateList_.front(); + define_inner(st.defList, st.undefList, label, base_->getSize()); + } + void defineClabel(Label& label) + { + define_inner(clabelDefList_, clabelUndefList_, getId(label), base_->getSize()); + label.mgr = this; + labelPtrList_.insert(&label); + } + void assign(Label& dst, const Label& src) + { + ClabelDefList::const_iterator i = clabelDefList_.find(src.id); + if (i == clabelDefList_.end()) XBYAK_THROW(ERR_LABEL_ISNOT_SET_BY_L) + define_inner(clabelDefList_, clabelUndefList_, dst.id, i->second.offset); + dst.mgr = this; + labelPtrList_.insert(&dst); + } + bool getOffset(size_t *offset, std::string& label) const + { + const SlabelDefList& defList = stateList_.front().defList; + if (label == "@b") { + if (defList.find("@f") != defList.end()) { + label = "@f"; + } else if (defList.find("@b") == defList.end()) { + XBYAK_THROW_RET(ERR_LABEL_IS_NOT_FOUND, false) + } + } else if (label == "@f") { + if (defList.find("@f") != defList.end()) { + label = "@b"; + } + } + const SlabelState& st = *label.c_str() == '.' ? stateList_.back() : stateList_.front(); + return getOffset_inner(st.defList, offset, label); + } + bool getOffset(size_t *offset, const Label& label) const + { + return getOffset_inner(clabelDefList_, offset, getId(label)); + } + void addUndefinedLabel(const std::string& label, const JmpLabel& jmp) + { + SlabelState& st = *label.c_str() == '.' ? stateList_.back() : stateList_.front(); + st.undefList.insert(SlabelUndefList::value_type(label, jmp)); + } + void addUndefinedLabel(const Label& label, const JmpLabel& jmp) + { + clabelUndefList_.insert(ClabelUndefList::value_type(label.id, jmp)); + } + bool hasUndefSlabel() const + { + for (StateList::const_iterator i = stateList_.begin(), ie = stateList_.end(); i != ie; ++i) { + if (hasUndefinedLabel_inner(i->undefList)) return true; + } + return false; + } + bool hasUndefClabel() const { return hasUndefinedLabel_inner(clabelUndefList_); } + const uint8_t *getCode() const { return base_->getCode(); } + bool isReady() const { return !base_->isAutoGrow() || base_->isCalledCalcJmpAddress(); } +}; + +inline Label::Label(const Label& rhs) +{ + id = rhs.id; + mgr = rhs.mgr; + if (mgr) mgr->incRefCount(id, this); +} +inline Label& Label::operator=(const Label& rhs) +{ + if (id) XBYAK_THROW_RET(ERR_LABEL_IS_ALREADY_SET_BY_L, *this) + id = rhs.id; + mgr = rhs.mgr; + if (mgr) mgr->incRefCount(id, this); + return *this; +} +inline Label::~Label() +{ + if (id && mgr) mgr->decRefCount(id, this); +} +inline const uint8_t* Label::getAddress() const +{ + if (mgr == 0 || !mgr->isReady()) return 0; + size_t offset; + if (!mgr->getOffset(&offset, *this)) return 0; + return mgr->getCode() + offset; +} + +typedef enum { + DefaultEncoding, + VexEncoding, + EvexEncoding +} PreferredEncoding; + +class CodeGenerator : public CodeArray { +public: + enum LabelType { + T_SHORT, + T_NEAR, + T_FAR, // far jump + T_AUTO // T_SHORT if possible + }; +private: + CodeGenerator operator=(const CodeGenerator&); // don't call +#ifdef XBYAK64 + enum { i32e = 32 | 64, BIT = 64 }; + static const uint64_t dummyAddr = uint64_t(0x1122334455667788ull); + typedef Reg64 NativeReg; +#else + enum { i32e = 32, BIT = 32 }; + static const size_t dummyAddr = 0x12345678; + typedef Reg32 NativeReg; +#endif + // (XMM, XMM|MEM) + static inline bool isXMM_XMMorMEM(const Operand& op1, const Operand& op2) + { + return op1.isXMM() && (op2.isXMM() || op2.isMEM()); + } + // (MMX, MMX|MEM) or (XMM, XMM|MEM) + static inline bool isXMMorMMX_MEM(const Operand& op1, const Operand& op2) + { + return (op1.isMMX() && (op2.isMMX() || op2.isMEM())) || isXMM_XMMorMEM(op1, op2); + } + // (XMM, MMX|MEM) + static inline bool isXMM_MMXorMEM(const Operand& op1, const Operand& op2) + { + return op1.isXMM() && (op2.isMMX() || op2.isMEM()); + } + // (MMX, XMM|MEM) + static inline bool isMMX_XMMorMEM(const Operand& op1, const Operand& op2) + { + return op1.isMMX() && (op2.isXMM() || op2.isMEM()); + } + // (XMM, REG32|MEM) + static inline bool isXMM_REG32orMEM(const Operand& op1, const Operand& op2) + { + return op1.isXMM() && (op2.isREG(i32e) || op2.isMEM()); + } + // (REG32, XMM|MEM) + static inline bool isREG32_XMMorMEM(const Operand& op1, const Operand& op2) + { + return op1.isREG(i32e) && (op2.isXMM() || op2.isMEM()); + } + // (REG32, REG32|MEM) + static inline bool isREG32_REG32orMEM(const Operand& op1, const Operand& op2) + { + return op1.isREG(i32e) && ((op2.isREG(i32e) && op1.getBit() == op2.getBit()) || op2.isMEM()); + } + static inline bool isValidSSE(const Operand& op1) + { + // SSE instructions do not support XMM16 - XMM31 + return !(op1.isXMM() && op1.getIdx() >= 16); + } + static inline uint8_t rexRXB(int bit, int bit3, const Reg& r, const Reg& b, const Reg& x = Reg()) + { + int v = bit3 ? 8 : 0; + if (r.hasIdxBit(bit)) v |= 4; + if (x.hasIdxBit(bit)) v |= 2; + if (b.hasIdxBit(bit)) v |= 1; + return uint8_t(v); + } + void rex2(int bit3, int rex4bit, const Reg& r, const Reg& b, const Reg& x = Reg()) + { + db(0xD5); + db((rexRXB(4, bit3, r, b, x) << 4) | rex4bit); + } + // return true if rex2 is selected + bool rex(const Operand& op1, const Operand& op2 = Operand(), uint64_t type = 0) + { + if (op1.getNF() | op2.getNF()) XBYAK_THROW_RET(ERR_INVALID_NF, false) + if (op1.getZU() | op2.getZU()) XBYAK_THROW_RET(ERR_INVALID_ZU, false) + uint8_t rex = 0; + const Operand *p1 = &op1, *p2 = &op2; + if (p1->isMEM()) std::swap(p1, p2); + if (p1->isMEM()) XBYAK_THROW_RET(ERR_BAD_COMBINATION, false) + // except movsx(16bit, 32/64bit) + bool p66 = (op1.isBit(16) && !op2.isBit(i32e)) || (op2.isBit(16) && !op1.isBit(i32e)); + if ((type & T_66) || p66) db(0x66); + if (type & T_F2) { + db(0xF2); + } + if (type & T_F3) { + db(0xF3); + } + bool is0F = type & T_0F; + if (p2->isMEM()) { + const Reg& r = *static_cast(p1); + const Address& addr = p2->getAddress(); + const RegExp e = addr.getRegExp(); + const Reg& base = e.getBase(); + const Reg& idx = e.getIndex(); + if (BIT == 64 && addr.is32bit()) db(0x67); + rex = rexRXB(3, r.isREG(64), r, base, idx); + if (r.hasRex2() || addr.hasRex2()) { + if (type & (T_0F38|T_0F3A)) XBYAK_THROW_RET(ERR_CANT_USE_REX2, false) + rex2(is0F, rex, r, base, idx); + return true; + } + if (rex || r.isExt8bit()) rex |= 0x40; + } else { + const Reg& r1 = static_cast(op1); + const Reg& r2 = static_cast(op2); + // ModRM(reg, base); + rex = rexRXB(3, r1.isREG(64) || r2.isREG(64), r2, r1); + if (r1.hasRex2() || r2.hasRex2()) { + if (type & (T_0F38|T_0F3A)) XBYAK_THROW_RET(ERR_CANT_USE_REX2, 0) + rex2(is0F, rex, r2, r1); + return true; + } + if (rex || r1.isExt8bit() || r2.isExt8bit()) rex |= 0x40; + } + if (rex) db(rex); + return false; + } + // @@@begin of avx_type_def.h + static const uint64_t T_NONE = 0ull; + // low 3 bit + static const uint64_t T_N1 = 1ull; + static const uint64_t T_N2 = 2ull; + static const uint64_t T_N4 = 3ull; + static const uint64_t T_N8 = 4ull; + static const uint64_t T_N16 = 5ull; + static const uint64_t T_N32 = 6ull; + static const uint64_t T_NX_MASK = 7ull; + static const uint64_t T_DUP = T_NX_MASK;//1 << 4, // N = (8, 32, 64) + static const uint64_t T_N_VL = 1ull << 3; // N * (1, 2, 4) for VL + static const uint64_t T_APX = 1ull << 4; + static const uint64_t T_66 = 1ull << 5; // pp = 1 + static const uint64_t T_F3 = 1ull << 6; // pp = 2 + static const uint64_t T_ER_R = 1ull << 7; // reg{er} + static const uint64_t T_0F = 1ull << 8; + static const uint64_t T_0F38 = 1ull << 9; + static const uint64_t T_0F3A = 1ull << 10; + static const uint64_t T_L0 = 1ull << 11; + static const uint64_t T_L1 = 1ull << 12; + static const uint64_t T_W0 = 1ull << 13; + static const uint64_t T_W1 = 1ull << 14; + static const uint64_t T_EW0 = 1ull << 15; + static const uint64_t T_EW1 = 1ull << 16; + static const uint64_t T_YMM = 1ull << 17; // support YMM, ZMM + static const uint64_t T_EVEX = 1ull << 18; + static const uint64_t T_ER_X = 1ull << 19; // xmm{er} + static const uint64_t T_ER_Y = 1ull << 20; // ymm{er} + static const uint64_t T_ER_Z = 1ull << 21; // zmm{er} + static const uint64_t T_SAE_X = 1ull << 22; // xmm{sae} + static const uint64_t T_SAE_Y = 1ull << 23; // ymm{sae} + static const uint64_t T_SAE_Z = 1ull << 24; // zmm{sae} + static const uint64_t T_MUST_EVEX = 1ull << 25; // contains T_EVEX + static const uint64_t T_B32 = 1ull << 26; // m32bcst + static const uint64_t T_B64 = 1ull << 27; // m64bcst + static const uint64_t T_B16 = T_B32 | T_B64; // m16bcst (Be careful) + static const uint64_t T_M_K = 1ull << 28; // mem{k} + static const uint64_t T_VSIB = 1ull << 29; + static const uint64_t T_MEM_EVEX = 1ull << 30; // use evex if mem + static const uint64_t T_FP16 = 1ull << 31; // avx512-fp16 + static const uint64_t T_MAP5 = T_FP16 | T_0F; + static const uint64_t T_MAP6 = T_FP16 | T_0F38; + static const uint64_t T_NF = 1ull << 32; // T_nf + static const uint64_t T_CODE1_IF1 = 1ull << 33; // code|=1 if !r.isBit(8) + + static const uint64_t T_ND1 = 1ull << 35; // ND=1 + static const uint64_t T_ZU = 1ull << 36; // ND=ZU + static const uint64_t T_F2 = 1ull << 37; // pp = 3 + // T_66 = 1, T_F3 = 2, T_F2 = 3 + static inline uint32_t getPP(uint64_t type) { return (type & T_66) ? 1 : (type & T_F3) ? 2 : (type & T_F2) ? 3 : 0; } + // @@@end of avx_type_def.h + static inline uint32_t getMap(uint64_t type) { return (type & T_0F) ? 1 : (type & T_0F38) ? 2 : (type & T_0F3A) ? 3 : 0; } + void vex(const Reg& reg, const Reg& base, const Operand *v, uint64_t type, int code, bool x = false) + { + int w = (type & T_W1) ? 1 : 0; + bool is256 = (type & T_L1) ? true : (type & T_L0) ? false : reg.isYMM(); + bool r = reg.isExtIdx(); + bool b = base.isExtIdx(); + int idx = v ? v->getIdx() : 0; + if ((idx | reg.getIdx() | base.getIdx()) >= 16) XBYAK_THROW(ERR_BAD_COMBINATION) + uint32_t pp = getPP(type); + uint32_t vvvv = (((~idx) & 15) << 3) | (is256 ? 4 : 0) | pp; + if (!b && !x && !w && (type & T_0F)) { + db(0xC5); db((r ? 0 : 0x80) | vvvv); + } else { + uint32_t mmmm = getMap(type); + db(0xC4); db((r ? 0 : 0x80) | (x ? 0 : 0x40) | (b ? 0 : 0x20) | mmmm); db((w << 7) | vvvv); + } + db(code); + } + void verifySAE(const Reg& r, uint64_t type) const + { + if (((type & T_SAE_X) && r.isXMM()) || ((type & T_SAE_Y) && r.isYMM()) || ((type & T_SAE_Z) && r.isZMM())) return; + XBYAK_THROW(ERR_SAE_IS_INVALID) + } + void verifyER(const Reg& r, uint64_t type) const + { + if ((type & T_ER_R) && r.isREG(32|64)) return; + if (((type & T_ER_X) && r.isXMM()) || ((type & T_ER_Y) && r.isYMM()) || ((type & T_ER_Z) && r.isZMM())) return; + XBYAK_THROW(ERR_ER_IS_INVALID) + } + // (a, b, c) contains non zero two or three values then err + int verifyDuplicate(int a, int b, int c, int err) + { + int v = a | b | c; + if ((a > 0 && a != v) + (b > 0 && b != v) + (c > 0 && c != v) > 0) XBYAK_THROW_RET(err, 0) + return v; + } + int evex(const Reg& reg, const Reg& base, const Operand *v, uint64_t type, int code, const Reg *x = 0, bool b = false, int aaa = 0, uint32_t VL = 0, bool Hi16Vidx = false) + { + if (!(type & (T_EVEX | T_MUST_EVEX))) XBYAK_THROW_RET(ERR_EVEX_IS_INVALID, 0) + int w = (type & T_EW1) ? 1 : 0; + uint32_t mmm = getMap(type); + if (type & T_FP16) mmm |= 4; + uint32_t pp = getPP(type); + int idx = v ? v->getIdx() : 0; + uint32_t vvvv = ~idx; + + bool R = reg.isExtIdx(); + bool X3 = (x && x->isExtIdx()) || (base.isSIMD() && base.isExtIdx2()); + bool B4 = base.isREG() && base.isExtIdx2(); + bool X4 = x && (x->isREG() && x->isExtIdx2()); + bool B = base.isExtIdx(); + bool Rp = reg.isExtIdx2(); + int LL; + int rounding = verifyDuplicate(reg.getRounding(), base.getRounding(), v ? v->getRounding() : 0, ERR_ROUNDING_IS_ALREADY_SET); + int disp8N = 1; + if (rounding) { + if (rounding == EvexModifierRounding::T_SAE) { + verifySAE(base, type); LL = 0; + } else { + verifyER(base, type); LL = rounding - 1; + } + b = true; + } else { + if (v) VL = (std::max)(VL, v->getBit()); + VL = (std::max)((std::max)(reg.getBit(), base.getBit()), VL); + LL = (VL == 512) ? 2 : (VL == 256) ? 1 : 0; + if (b) { + disp8N = ((type & T_B16) == T_B16) ? 2 : (type & T_B32) ? 4 : 8; + } else if ((type & T_NX_MASK) == T_DUP) { + disp8N = VL == 128 ? 8 : VL == 256 ? 32 : 64; + } else { + if ((type & (T_NX_MASK | T_N_VL)) == 0) { + type |= T_N16 | T_N_VL; // default + } + int low = type & T_NX_MASK; + if (low > 0) { + disp8N = 1 << (low - 1); + if (type & T_N_VL) disp8N *= (VL == 512 ? 4 : VL == 256 ? 2 : 1); + } + } + } + bool V4 = ((v ? v->isExtIdx2() : 0) || Hi16Vidx); + bool z = reg.hasZero() || base.hasZero() || (v ? v->hasZero() : false); + if (aaa == 0) aaa = verifyDuplicate(base.getOpmaskIdx(), reg.getOpmaskIdx(), (v ? v->getOpmaskIdx() : 0), ERR_OPMASK_IS_ALREADY_SET); + if (aaa == 0) z = 0; // clear T_z if mask is not set + db(0x62); + db((R ? 0 : 0x80) | (X3 ? 0 : 0x40) | (B ? 0 : 0x20) | (Rp ? 0 : 0x10) | (B4 ? 8 : 0) | mmm); + db((w == 1 ? 0x80 : 0) | ((vvvv & 15) << 3) | (X4 ? 0 : 4) | (pp & 3)); + db((z ? 0x80 : 0) | ((LL & 3) << 5) | (b ? 0x10 : 0) | (V4 ? 0 : 8) | (aaa & 7)); + db(code); + return disp8N; + } + // evex of Legacy + void evexLeg(const Reg& r, const Reg& b, const Reg& x, const Reg& v, uint64_t type, int sc = NONE) + { + int M = getMap(type); if (M == 0) M = 4; // legacy + int R3 = !r.isExtIdx(); + int X3 = !x.isExtIdx(); + int B3 = b.isExtIdx() ? 0 : 0x20; + int R4 = r.isExtIdx2() ? 0 : 0x10; + int B4 = b.isExtIdx2() ? 0x08 : 0; + int w = (type & T_W0) ? 0 : (r.isBit(64) || v.isBit(64) || (type & T_W1)); + int V = (~v.getIdx() & 15) << 3; + int X4 = x.isExtIdx2() ? 0 : 0x04; + int pp = (type & (T_F2|T_F3|T_66)) ? getPP(type) : (r.isBit(16) || v.isBit(16)); + int V4 = !v.isExtIdx2(); + int ND = (type & T_ZU) ? (r.getZU() || b.getZU()) : (type & T_ND1) ? 1 : (type & T_APX) ? 0 : v.isREG(); + int NF = r.getNF() | b.getNF() | x.getNF() | v.getNF(); + int L = 0; + if ((type & T_NF) == 0 && NF) XBYAK_THROW(ERR_INVALID_NF) + if ((type & T_ZU) == 0 && r.getZU()) XBYAK_THROW(ERR_INVALID_ZU) + db(0x62); + db((R3<<7) | (X3<<6) | B3 | R4 | B4 | M); + db((w<<7) | V | X4 | pp); + if (sc != NONE) { + db((L<<5) | (ND<<4) | sc); + } else { + db((L<<5) | (ND<<4) | (V4<<3) | (NF<<2)); + } + } + void setModRM(int mod, int r1, int r2) + { + db(static_cast((mod << 6) | ((r1 & 7) << 3) | (r2 & 7))); + } + void setSIB(const RegExp& e, int reg, int disp8N = 0) + { + uint64_t disp64 = e.getDisp(); +#if defined(XBYAK64) && !defined(__ILP32__) +#ifdef XBYAK_OLD_DISP_CHECK + // treat 0xffffffff as 0xffffffffffffffff + uint64_t high = disp64 >> 32; + if (high != 0 && high != 0xFFFFFFFF) XBYAK_THROW(ERR_OFFSET_IS_TOO_BIG) +#else + // displacement should be a signed 32-bit value, so also check sign bit + uint64_t high = disp64 >> 31; + if (high != 0 && high != 0x1FFFFFFFF) XBYAK_THROW(ERR_OFFSET_IS_TOO_BIG) +#endif +#endif + uint32_t disp = static_cast(disp64); + const Reg& base = e.getBase(); + const Reg& index = e.getIndex(); + const int baseIdx = base.getIdx(); + const int baseBit = base.getBit(); + const int indexBit = index.getBit(); + enum { + mod00 = 0, mod01 = 1, mod10 = 2 + }; + int mod = mod10; // disp32 + if (!baseBit || ((baseIdx & 7) != Operand::EBP && disp == 0)) { + mod = mod00; + } else { + if (disp8N == 0) { + if (inner::IsInDisp8(disp)) { + mod = mod01; + } + } else { + // disp must be casted to signed + uint32_t t = static_cast(static_cast(disp) / disp8N); + if ((disp % disp8N) == 0 && inner::IsInDisp8(t)) { + disp = t; + mod = mod01; + } + } + } + const int newBaseIdx = baseBit ? (baseIdx & 7) : Operand::EBP; + /* ModR/M = [2:3:3] = [Mod:reg/code:R/M] */ + bool hasSIB = indexBit || (baseIdx & 7) == Operand::ESP; +#ifdef XBYAK64 + if (!baseBit && !indexBit) hasSIB = true; +#endif + if (hasSIB) { + setModRM(mod, reg, Operand::ESP); + /* SIB = [2:3:3] = [SS:index:base(=rm)] */ + const int idx = indexBit ? (index.getIdx() & 7) : Operand::ESP; + const int scale = e.getScale(); + const int SS = (scale == 8) ? 3 : (scale == 4) ? 2 : (scale == 2) ? 1 : 0; + setModRM(SS, idx, newBaseIdx); + } else { + setModRM(mod, reg, newBaseIdx); + } + if (mod == mod01) { + db(disp); + } else if (mod == mod10 || (mod == mod00 && !baseBit)) { + dd(disp); + } + } + LabelManager labelMgr_; + bool isInDisp16(uint32_t x) const { return 0xFFFF8000 <= x || x <= 0x7FFF; } + void writeCode(uint64_t type, const Reg& r, int code, bool rex2 = false) + { + if (!(type&T_APX || rex2)) { + if (type & T_0F) { + db(0x0F); + } else if (type & T_0F38) { + db(0x0F); db(0x38); + } else if (type & T_0F3A) { + db(0x0F); db(0x3A); + } + } + db(code | ((type == 0 || (type & T_CODE1_IF1)) && !r.isBit(8))); + } + void opRR(const Reg& reg1, const Reg& reg2, uint64_t type, int code) + { + bool rex2 = rex(reg2, reg1, type); + writeCode(type, reg1, code, rex2); + setModRM(3, reg1.getIdx(), reg2.getIdx()); + } + void opMR(const Address& addr, const Reg& r, uint64_t type, int code, uint64_t type2 = 0, int code2 = NONE) + { + if (code2 == NONE) code2 = code; + if (type2 && opROO(Reg(), addr, r, type2, code2)) return; + if (addr.is64bitDisp()) XBYAK_THROW(ERR_CANT_USE_64BIT_DISP) + bool rex2 = rex(addr, r, type); + writeCode(type, r, code, rex2); + opAddr(addr, r.getIdx()); + } + void opLoadSeg(const Address& addr, const Reg& reg, uint64_t type, int code) + { + if (reg.isBit(8)) XBYAK_THROW(ERR_BAD_SIZE_OF_REGISTER) + if (addr.is64bitDisp()) XBYAK_THROW(ERR_CANT_USE_64BIT_DISP) + // can't use opMR + rex(addr, reg, type); + if (type & T_0F) db(0x0F); + db(code); + opAddr(addr, reg.getIdx()); + } + // for only MPX(bnd*) + void opMIB(const Address& addr, const Reg& reg, uint64_t type, int code) + { + if (addr.getMode() != Address::M_ModRM) XBYAK_THROW(ERR_INVALID_MIB_ADDRESS) + opMR(addr.cloneNoOptimize(), reg, type, code); + } + void makeJmp(uint32_t disp, LabelType type, uint8_t shortCode, uint8_t longCode, uint8_t longPref) + { + const int shortJmpSize = 2; + const int longHeaderSize = longPref ? 2 : 1; + const int longJmpSize = longHeaderSize + 4; + if (type != T_NEAR && inner::IsInDisp8(disp - shortJmpSize)) { + db(shortCode); db(disp - shortJmpSize); + } else { + if (type == T_SHORT) XBYAK_THROW(ERR_LABEL_IS_TOO_FAR) + if (longPref) db(longPref); + db(longCode); dd(disp - longJmpSize); + } + } + bool isNEAR(LabelType type) const { return type == T_NEAR || (type == T_AUTO && isDefaultJmpNEAR_); } + template + void opJmp(T& label, LabelType type, uint8_t shortCode, uint8_t longCode, uint8_t longPref) + { + if (type == T_FAR) XBYAK_THROW(ERR_NOT_SUPPORTED) + if (isAutoGrow() && size_ + 16 >= maxSize_) growMemory(); /* avoid splitting code of jmp */ + size_t offset = 0; + if (labelMgr_.getOffset(&offset, label)) { /* label exists */ + makeJmp(inner::VerifyInInt32(offset - size_), type, shortCode, longCode, longPref); + } else { + int jmpSize = 0; + if (isNEAR(type)) { + jmpSize = 4; + if (longPref) db(longPref); + db(longCode); dd(0); + } else { + jmpSize = 1; + db(shortCode); db(0); + } + JmpLabel jmp(size_, jmpSize, inner::LasIs); + labelMgr_.addUndefinedLabel(label, jmp); + } + } + void opJmpAbs(const void *addr, LabelType type, uint8_t shortCode, uint8_t longCode, uint8_t longPref = 0) + { + if (type == T_FAR) XBYAK_THROW(ERR_NOT_SUPPORTED) + if (isAutoGrow()) { + if (!isNEAR(type)) XBYAK_THROW(ERR_ONLY_T_NEAR_IS_SUPPORTED_IN_AUTO_GROW) + if (size_ + 16 >= maxSize_) growMemory(); + if (longPref) db(longPref); + db(longCode); + dd(0); + save(size_ - 4, size_t(addr) - size_, 4, inner::Labs); + } else { + makeJmp(inner::VerifyInInt32(reinterpret_cast(addr) - getCurr()), type, shortCode, longCode, longPref); + } + + } + void opJmpOp(const Operand& op, LabelType type, int ext) + { + const int bit = 16|i32e; + if (type == T_FAR) { + if (!op.isMEM(bit)) XBYAK_THROW(ERR_NOT_SUPPORTED) + opRext(op, bit, ext + 1, 0, 0xFF, false); + } else { + opRext(op, bit, ext, 0, 0xFF, true); + } + } + // reg is reg field of ModRM + // immSize is the size for immediate value + void opAddr(const Address &addr, int reg) + { + if (!addr.permitVsib && addr.isVsib()) XBYAK_THROW(ERR_BAD_VSIB_ADDRESSING) + if (addr.getMode() == Address::M_ModRM) { + setSIB(addr.getRegExp(), reg, addr.disp8N); + } else if (addr.getMode() == Address::M_rip || addr.getMode() == Address::M_ripAddr) { + setModRM(0, reg, 5); + if (addr.getLabel()) { // [rip + Label] + putL_inner(*addr.getLabel(), true, addr.getDisp() - addr.immSize); + } else { + size_t disp = addr.getDisp(); + if (addr.getMode() == Address::M_ripAddr) { + if (isAutoGrow()) XBYAK_THROW(ERR_INVALID_RIP_IN_AUTO_GROW) + disp -= (size_t)getCurr() + 4 + addr.immSize; + } + dd(inner::VerifyInInt32(disp)); + } + } + } + void opSSE(const Reg& r, const Operand& op, uint64_t type, int code, bool isValid(const Operand&, const Operand&), int imm8 = NONE) + { + if (isValid && !isValid(r, op)) XBYAK_THROW(ERR_BAD_COMBINATION) + if (!isValidSSE(r) || !isValidSSE(op)) XBYAK_THROW(ERR_NOT_SUPPORTED) + opRO(r, op, type, code, true, (imm8 != NONE) ? 1 : 0); + if (imm8 != NONE) db(imm8); + } + void opMMX_IMM(const Mmx& mmx, int imm8, int code, int ext) + { + if (!isValidSSE(mmx)) XBYAK_THROW(ERR_NOT_SUPPORTED) + uint64_t type = T_0F; + if (mmx.isXMM()) type |= T_66; + opRR(Reg32(ext), mmx, type, code); + db(imm8); + } + void opMMX(const Mmx& mmx, const Operand& op, int code, uint64_t type = T_0F, uint64_t pref = T_66, int imm8 = NONE) + { + if (mmx.isXMM()) type |= pref; + opSSE(mmx, op, type, code, isXMMorMMX_MEM, imm8); + } + void opMovXMM(const Operand& op1, const Operand& op2, uint64_t type, int code) + { + if (!isValidSSE(op1) || !isValidSSE(op2)) XBYAK_THROW(ERR_NOT_SUPPORTED) + if (op1.isXMM() && op2.isMEM()) { + opMR(op2.getAddress(), op1.getReg(), type, code); + } else if (op1.isMEM() && op2.isXMM()) { + opMR(op1.getAddress(), op2.getReg(), type, code | 1); + } else { + XBYAK_THROW(ERR_BAD_COMBINATION) + } + } + // pextr{w,b,d}, extractps + void opExt(const Operand& op, const Mmx& mmx, int code, int imm, bool hasMMX2 = false) + { + if (!isValidSSE(op) || !isValidSSE(mmx)) XBYAK_THROW(ERR_NOT_SUPPORTED) + if (hasMMX2 && op.isREG(i32e)) { /* pextrw is special */ + if (mmx.isXMM()) db(0x66); + opRR(op.getReg(), mmx, T_0F, 0xC5); db(imm); + } else { + opSSE(mmx, op, T_66 | T_0F3A, code, isXMM_REG32orMEM, imm); + } + } + // (r, r, m) or (r, m, r) + bool opROO(const Reg& d, const Operand& op1, const Operand& op2, uint64_t type, int code, int immSize = 0, int sc = NONE) + { + if (!(type & T_MUST_EVEX) && !d.isREG() && !(d.hasRex2NFZU() || op1.hasRex2NFZU() || op2.hasRex2NFZU())) return false; + const Operand *p1 = &op1, *p2 = &op2; + if (p1->isMEM()) { std::swap(p1, p2); } else { if (p2->isMEM()) code |= 2; } + if (p1->isMEM()) XBYAK_THROW_RET(ERR_BAD_COMBINATION, false) + if (p2->isMEM()) { + const Reg& r = *static_cast(p1); + Address addr = p2->getAddress(); + const RegExp e = addr.getRegExp(); + evexLeg(r, e.getBase(), e.getIndex(), d, type, sc); + writeCode(type, d, code); + addr.immSize = immSize; + opAddr(addr, r.getIdx()); + } else { + evexLeg(static_cast(op2), static_cast(op1), Reg(), d, type, sc); + writeCode(type, d, code); + setModRM(3, op2.getIdx(), op1.getIdx()); + } + return true; + } + void opRext(const Operand& op, int bit, int ext, uint64_t type, int code, bool disableRex = false, int immSize = 0, const Reg *d = 0) + { + int opBit = op.getBit(); + if (disableRex && opBit == 64) opBit = 32; + const Reg r(ext, Operand::REG, opBit); + if ((type & T_APX) && op.hasRex2NFZU() && opROO(d ? *d : Reg(0, Operand::REG, opBit), op, r, type, code)) return; + if (op.isMEM()) { + opMR(op.getAddress(immSize), r, type, code); + } else if (op.isREG(bit)) { + opRR(r, op.getReg().changeBit(opBit), type, code); + } else { + XBYAK_THROW(ERR_BAD_COMBINATION) + } + } + void opSetCC(const Operand& op, int ext) + { + if (opROO(Reg(), op, Reg(), T_APX|T_ZU|T_F2, 0x40 | ext)) return; + opRext(op, 8, 0, T_0F, 0x90 | ext); + } + void opShift(const Operand& op, int imm, int ext, const Reg *d = 0) + { + if (d == 0) verifyMemHasSize(op); + if (d && op.getBit() != 0 && d->getBit() != op.getBit()) XBYAK_THROW(ERR_BAD_SIZE_OF_REGISTER) + uint64_t type = T_APX|T_CODE1_IF1; if (ext & 8) type |= T_NF; if (d) type |= T_ND1; + opRext(op, 0, ext&7, type, (0xC0 | ((imm == 1 ? 1 : 0) << 4)), false, (imm != 1) ? 1 : 0, d); + if (imm != 1) db(imm); + } + void opShift(const Operand& op, const Reg8& _cl, int ext, const Reg *d = 0) + { + if (_cl.getIdx() != Operand::CL) XBYAK_THROW(ERR_BAD_COMBINATION) + if (d && op.getBit() != 0 && d->getBit() != op.getBit()) XBYAK_THROW(ERR_BAD_SIZE_OF_REGISTER) + uint64_t type = T_APX|T_CODE1_IF1; if (ext & 8) type |= T_NF; if (d) type |= T_ND1; + opRext(op, 0, ext&7, type, 0xD2, false, 0, d); + } + // condR assumes that op.isREG() is true + void opRO(const Reg& r, const Operand& op, uint64_t type, int code, bool condR = true, int immSize = 0) + { + if (op.isMEM()) { + opMR(op.getAddress(immSize), r, type, code); + } else if (condR) { + opRR(r, op.getReg(), type, code); + } else { + XBYAK_THROW(ERR_BAD_COMBINATION) + } + } + void opShxd(const Reg& d, const Operand& op, const Reg& reg, uint8_t imm, int code, int code2, const Reg8 *_cl = 0) + { + if (_cl && _cl->getIdx() != Operand::CL) XBYAK_THROW(ERR_BAD_COMBINATION) + if (!reg.isREG(16|i32e)) XBYAK_THROW(ERR_BAD_SIZE_OF_REGISTER) + int immSize = _cl ? 0 : 1; + if (_cl) code |= 1; + uint64_t type = T_APX | T_NF; + if (d.isREG()) type |= T_ND1; + if (!opROO(d, op, reg, type, _cl ? code : code2, immSize)) { + opRO(reg, op, T_0F, code, true, immSize); + } + if (!_cl) db(imm); + } + // (REG, REG|MEM), (MEM, REG) + void opRO_MR(const Operand& op1, const Operand& op2, int code) + { + if (op2.isMEM()) { + if (!op1.isREG()) XBYAK_THROW(ERR_BAD_COMBINATION) + opMR(op2.getAddress(), op1.getReg(), 0, code | 2); + } else { + opRO(static_cast(op2), op1, 0, code, op1.getKind() == op2.getKind()); + } + } + uint32_t getImmBit(const Operand& op, uint32_t imm) + { + verifyMemHasSize(op); + uint32_t immBit = inner::IsInDisp8(imm) ? 8 : isInDisp16(imm) ? 16 : 32; + if (op.isBit(8)) immBit = 8; + if (op.getBit() < immBit) XBYAK_THROW_RET(ERR_IMM_IS_TOO_BIG, 0) + if (op.isBit(32|64) && immBit == 16) immBit = 32; /* don't use MEM16 if 32/64bit mode */ + return immBit; + } + // (REG|MEM, IMM) + void opOI(const Operand& op, uint32_t imm, int code, int ext) + { + uint32_t immBit = getImmBit(op, imm); + if (op.isREG() && op.getIdx() == 0 && (op.getBit() == immBit || (op.isBit(64) && immBit == 32))) { // rax, eax, ax, al + rex(op); + db(code | 4 | (immBit == 8 ? 0 : 1)); + } else { + int tmp = immBit < (std::min)(op.getBit(), 32U) ? 2 : 0; + opRext(op, 0, ext, 0, 0x80 | tmp, false, immBit / 8); + } + db(imm, immBit / 8); + } + // (r, r/m, imm) + void opROI(const Reg& d, const Operand& op, uint32_t imm, uint64_t type, int ext) + { + uint32_t immBit = getImmBit(d, imm); + int code = immBit < (std::min)(d.getBit(), 32U) ? 2 : 0; + opROO(d, op, Reg(ext, Operand::REG, d.getBit()), type, 0x80 | code, immBit / 8); + db(imm, immBit / 8); + } + void opIncDec(const Reg& d, const Operand& op, int ext) + { +#ifdef XBYAK64 + if (d.isREG()) { + int code = d.isBit(8) ? 0xFE : 0xFF; + uint64_t type = T_APX|T_NF|T_ND1; + if (d.isBit(16)) type |= T_66; + opROO(d, op, Reg(ext, Operand::REG, d.getBit()), type, code); + return; + } +#else + (void)d; +#endif + verifyMemHasSize(op); +#ifndef XBYAK64 + if (op.isREG() && !op.isBit(8)) { + rex(op); db((ext ? 0x48 : 0x40) | op.getIdx()); + return; + } +#endif + opRext(op, op.getBit(), ext, 0, 0xFE); + } + void opPushPop(const Operand& op, int code, int ext, int alt) + { + if (op.isREG() && op.hasRex2()) { + const Reg& r = static_cast(op); + rex2(0, rexRXB(3, 0, Reg(), r), Reg(), r); + db(alt); + return; + } + int bit = op.getBit(); + if (bit == 16 || bit == BIT) { + if (bit == 16) db(0x66); + if (op.isREG()) { + if (op.getReg().getIdx() >= 8) db(0x41); + db(alt | (op.getIdx() & 7)); + return; + } + if (op.isMEM()) { + opMR(op.getAddress(), Reg(ext, Operand::REG, 32), 0, code); + return; + } + } + XBYAK_THROW(ERR_BAD_COMBINATION) + } + void verifyMemHasSize(const Operand& op) const + { + if (op.isMEM() && op.getBit() == 0) XBYAK_THROW(ERR_MEM_SIZE_IS_NOT_SPECIFIED) + } + /* + mov(r, imm) = db(imm, mov_imm(r, imm)) + */ + int mov_imm(const Reg& reg, uint64_t imm) + { + int bit = reg.getBit(); + const int idx = reg.getIdx(); + int code = 0xB0 | ((bit == 8 ? 0 : 1) << 3); + if (bit == 64 && (imm & ~uint64_t(0xffffffffu)) == 0) { + rex(Reg32(idx)); + bit = 32; + } else { + rex(reg); + if (bit == 64 && inner::IsInInt32(imm)) { + db(0xC7); + code = 0xC0; + bit = 32; + } + } + db(code | (idx & 7)); + return bit / 8; + } + template + void putL_inner(T& label, bool relative = false, size_t disp = 0) + { + const int jmpSize = relative ? 4 : (int)sizeof(size_t); + if (isAutoGrow() && size_ + 16 >= maxSize_) growMemory(); + size_t offset = 0; + if (labelMgr_.getOffset(&offset, label)) { + if (relative) { + db(inner::VerifyInInt32(offset + disp - size_ - jmpSize), jmpSize); + } else if (isAutoGrow()) { + db(uint64_t(0), jmpSize); + save(size_ - jmpSize, offset, jmpSize, inner::LaddTop); + } else { + db(size_t(top_) + offset, jmpSize); + } + return; + } + db(uint64_t(0), jmpSize); + JmpLabel jmp(size_, jmpSize, (relative ? inner::LasIs : isAutoGrow() ? inner::LaddTop : inner::Labs), disp); + labelMgr_.addUndefinedLabel(label, jmp); + } + void opMovxx(const Reg& reg, const Operand& op, uint8_t code) + { + if (op.isBit(32)) XBYAK_THROW(ERR_BAD_COMBINATION) + int w = op.isBit(16); + if (!(reg.isREG() && (reg.getBit() > op.getBit()))) XBYAK_THROW(ERR_BAD_COMBINATION) + opRO(reg, op, T_0F, code | w); + } + void opFpuMem(const Address& addr, uint8_t m16, uint8_t m32, uint8_t m64, uint8_t ext, uint8_t m64ext) + { + if (addr.is64bitDisp()) XBYAK_THROW(ERR_CANT_USE_64BIT_DISP) + uint8_t code = addr.isBit(16) ? m16 : addr.isBit(32) ? m32 : addr.isBit(64) ? m64 : 0; + if (!code) XBYAK_THROW(ERR_BAD_MEM_SIZE) + if (m64ext && addr.isBit(64)) ext = m64ext; + rex(addr, st0); + db(code); + opAddr(addr, ext); + } + // use code1 if reg1 == st0 + // use code2 if reg1 != st0 && reg2 == st0 + void opFpuFpu(const Fpu& reg1, const Fpu& reg2, uint32_t code1, uint32_t code2) + { + uint32_t code = reg1.getIdx() == 0 ? code1 : reg2.getIdx() == 0 ? code2 : 0; + if (!code) XBYAK_THROW(ERR_BAD_ST_COMBINATION) + db(uint8_t(code >> 8)); + db(uint8_t(code | (reg1.getIdx() | reg2.getIdx()))); + } + void opFpu(const Fpu& reg, uint8_t code1, uint8_t code2) + { + db(code1); db(code2 | reg.getIdx()); + } + void opVex(const Reg& r, const Operand *p1, const Operand& op2, uint64_t type, int code, int imm8 = NONE) + { + if (op2.isMEM()) { + Address addr = op2.getAddress(); + const RegExp& regExp = addr.getRegExp(); + const Reg& base = regExp.getBase(); + const Reg& index = regExp.getIndex(); + if (BIT == 64 && addr.is32bit()) db(0x67); + int disp8N = 0; + if ((type & (T_MUST_EVEX|T_MEM_EVEX)) || r.hasEvex() || (p1 && p1->hasEvex()) || addr.isBroadcast() || addr.getOpmaskIdx() || addr.hasRex2()) { + int aaa = addr.getOpmaskIdx(); + if (aaa && !(type & T_M_K)) XBYAK_THROW(ERR_INVALID_OPMASK_WITH_MEMORY) + bool b = false; + if (addr.isBroadcast()) { + if (!(type & (T_B32 | T_B64))) XBYAK_THROW(ERR_INVALID_BROADCAST) + b = true; + } + int VL = regExp.isVsib() ? index.getBit() : 0; + disp8N = evex(r, base, p1, type, code, &index, b, aaa, VL, index.isSIMD() && index.isExtIdx2()); + } else { + vex(r, base, p1, type, code, index.isExtIdx()); + } + if (type & T_VSIB) addr.permitVsib = true; + if (disp8N) addr.disp8N = disp8N; + if (imm8 != NONE) addr.immSize = 1; + opAddr(addr, r.getIdx()); + } else { + const Reg& base = op2.getReg(); + if ((type & T_MUST_EVEX) || r.hasEvex() || (p1 && p1->hasEvex()) || base.hasEvex()) { + evex(r, base, p1, type, code); + } else { + vex(r, base, p1, type, code); + } + setModRM(3, r.getIdx(), base.getIdx()); + } + if (imm8 != NONE) db(imm8); + } + // (r, r, r/m) + // opRRO(a, b, c) == opROO(b, c, a) + void opRRO(const Reg& d, const Reg& r1, const Operand& op2, uint64_t type, uint8_t code, int imm8 = NONE) + { + const unsigned int bit = d.getBit(); + if (r1.getBit() != bit || (op2.isREG() && op2.getBit() != bit)) XBYAK_THROW(ERR_BAD_COMBINATION) + type |= (bit == 64) ? T_W1 : T_W0; + if (d.hasRex2() || r1.hasRex2() || op2.hasRex2() || d.getNF()) { + opROO(r1, op2, d, type, code); + if (imm8 != NONE) db(imm8); + } else { + opVex(d, &r1, op2, type, code, imm8); + } + } + void opAVX_X_X_XM(const Xmm& x1, const Operand& op1, const Operand& op2, uint64_t type, int code, int imm8 = NONE) + { + const Xmm *x2 = static_cast(&op1); + const Operand *op = &op2; + if (op2.isNone()) { // (x1, op1) -> (x1, x1, op1) + x2 = &x1; + op = &op1; + } + // (x1, x2, op) + if (!((x1.isXMM() && x2->isXMM()) || ((type & T_YMM) && ((x1.isYMM() && x2->isYMM()) || (x1.isZMM() && x2->isZMM()))))) XBYAK_THROW(ERR_BAD_COMBINATION) + opVex(x1, x2, *op, type, code, imm8); + } + void opAVX_K_X_XM(const Opmask& k, const Xmm& x2, const Operand& op3, uint64_t type, int code, int imm8 = NONE) + { + if (!op3.isMEM() && (x2.getKind() != op3.getKind())) XBYAK_THROW(ERR_BAD_COMBINATION) + opVex(k, &x2, op3, type, code, imm8); + } + // (x, x/m), (y, x/m256), (z, y/m) + void checkCvt1(const Operand& x, const Operand& op) const + { + if (!op.isMEM() && !(x.is(Operand::XMM | Operand::YMM) && op.isXMM()) && !(x.isZMM() && op.isYMM())) XBYAK_THROW(ERR_BAD_COMBINATION) + } + // (x, x/m), (x, y/m256), (y, z/m) + void checkCvt2(const Xmm& x, const Operand& op) const + { + if (!(x.isXMM() && op.is(Operand::XMM | Operand::YMM | Operand::MEM)) && !(x.isYMM() && op.is(Operand::ZMM | Operand::MEM))) XBYAK_THROW(ERR_BAD_COMBINATION) + } + void opCvt(const Xmm& x, const Operand& op, uint64_t type, int code) + { + Operand::Kind kind = x.isXMM() ? (op.isBit(256) ? Operand::YMM : Operand::XMM) : Operand::ZMM; + opVex(x.copyAndSetKind(kind), &xm0, op, type, code); + } + void opCvt2(const Xmm& x, const Operand& op, uint64_t type, int code) + { + checkCvt2(x, op); + opCvt(x, op, type, code); + } + void opCvt3(const Xmm& x1, const Xmm& x2, const Operand& op, uint64_t type, uint64_t type64, uint64_t type32, uint8_t code) + { + if (!(x1.isXMM() && x2.isXMM() && (op.isREG(i32e) || op.isMEM()))) XBYAK_THROW(ERR_BAD_SIZE_OF_REGISTER) + Xmm x(op.getIdx()); + const Operand *p = op.isREG() ? &x : &op; + opVex(x1, &x2, *p, type | (op.isBit(64) ? type64 : type32), code); + } + // (x, x/y/xword/yword), (y, z/m) + void checkCvt4(const Xmm& x, const Operand& op) const + { + if (!(x.isXMM() && op.is(Operand::XMM | Operand::YMM | Operand::MEM) && op.isBit(128|256)) && !(x.isYMM() && op.is(Operand::ZMM | Operand::MEM))) XBYAK_THROW(ERR_BAD_COMBINATION) + } + // (x, x/y/z/xword/yword/zword) + void opCvt5(const Xmm& x, const Operand& op, uint64_t type, int code) + { + if (!(x.isXMM() && op.isBit(128|256|512))) XBYAK_THROW(ERR_BAD_COMBINATION) + Operand::Kind kind = op.isBit(128) ? Operand::XMM : op.isBit(256) ? Operand::YMM : Operand::ZMM; + opVex(x.copyAndSetKind(kind), &xm0, op, type, code); + } + const Xmm& cvtIdx0(const Operand& x) const + { + return x.isZMM() ? zm0 : x.isYMM() ? ym0 : xm0; + } + // support (x, x/m, imm), (y, y/m, imm) + void opAVX_X_XM_IMM(const Xmm& x, const Operand& op, uint64_t type, int code, int imm8 = NONE) + { + opAVX_X_X_XM(x, cvtIdx0(x), op, type, code, imm8); + } + void opCnt(const Reg& reg, const Operand& op, uint8_t code) + { + if (reg.isBit(8)) XBYAK_THROW(ERR_BAD_SIZE_OF_REGISTER) + bool is16bit = reg.isREG(16) && (op.isREG(16) || op.isMEM()); + if (!is16bit && !(reg.isREG(i32e) && (op.isREG(reg.getBit()) || op.isMEM()))) XBYAK_THROW(ERR_BAD_COMBINATION) + if (is16bit) db(0x66); + opRO(reg.changeBit(i32e == 32 ? 32 : reg.getBit()), op, T_F3 | T_0F, code); + } + void opGather(const Xmm& x1, const Address& addr, const Xmm& x2, uint64_t type, uint8_t code, int mode) + { + const RegExp& regExp = addr.getRegExp(); + if (!regExp.isVsib(128 | 256)) XBYAK_THROW(ERR_BAD_VSIB_ADDRESSING) + const int y_vx_y = 0; + const int y_vy_y = 1; +// const int x_vy_x = 2; + const bool isAddrYMM = regExp.getIndex().getBit() == 256; + if (!x1.isXMM() || isAddrYMM || !x2.isXMM()) { + bool isOK = false; + if (mode == y_vx_y) { + isOK = x1.isYMM() && !isAddrYMM && x2.isYMM(); + } else if (mode == y_vy_y) { + isOK = x1.isYMM() && isAddrYMM && x2.isYMM(); + } else { // x_vy_x + isOK = !x1.isYMM() && isAddrYMM && !x2.isYMM(); + } + if (!isOK) XBYAK_THROW(ERR_BAD_VSIB_ADDRESSING) + } + int i1 = x1.getIdx(); + int i2 = regExp.getIndex().getIdx(); + int i3 = x2.getIdx(); + if (i1 == i2 || i1 == i3 || i2 == i3) XBYAK_THROW(ERR_SAME_REGS_ARE_INVALID); + opAVX_X_X_XM(isAddrYMM ? Ymm(i1) : x1, isAddrYMM ? Ymm(i3) : x2, addr, type, code); + } + enum { + xx_yy_zz = 0, + xx_yx_zy = 1, + xx_xy_yz = 2 + }; + void checkGather2(const Xmm& x1, const Reg& x2, int mode) const + { + if (x1.isXMM() && x2.isXMM()) return; + switch (mode) { + case xx_yy_zz: if ((x1.isYMM() && x2.isYMM()) || (x1.isZMM() && x2.isZMM())) return; + break; + case xx_yx_zy: if ((x1.isYMM() && x2.isXMM()) || (x1.isZMM() && x2.isYMM())) return; + break; + case xx_xy_yz: if ((x1.isXMM() && x2.isYMM()) || (x1.isYMM() && x2.isZMM())) return; + break; + } + XBYAK_THROW(ERR_BAD_VSIB_ADDRESSING) + } + void opGather2(const Xmm& x, const Address& addr, uint64_t type, uint8_t code, int mode) + { + if (x.hasZero()) XBYAK_THROW(ERR_INVALID_ZERO) + const RegExp& regExp = addr.getRegExp(); + checkGather2(x, regExp.getIndex(), mode); + int maskIdx = x.getOpmaskIdx(); + if ((type & T_M_K) && addr.getOpmaskIdx()) maskIdx = addr.getOpmaskIdx(); + if (maskIdx == 0) XBYAK_THROW(ERR_K0_IS_INVALID); + if (!(type & T_M_K) && x.getIdx() == regExp.getIndex().getIdx()) XBYAK_THROW(ERR_SAME_REGS_ARE_INVALID); + opVex(x, 0, addr, type, code); + } + /* + xx_xy_yz ; mode = true + xx_xy_xz ; mode = false + */ + void opVmov(const Operand& op, const Xmm& x, uint64_t type, uint8_t code, bool mode) + { + if (mode) { + if (!op.isMEM() && !((op.isXMM() && x.isXMM()) || (op.isXMM() && x.isYMM()) || (op.isYMM() && x.isZMM()))) XBYAK_THROW(ERR_BAD_COMBINATION) + } else { + if (!op.isMEM() && !op.isXMM()) XBYAK_THROW(ERR_BAD_COMBINATION) + } + opVex(x, 0, op, type, code); + } + void opGatherFetch(const Address& addr, const Xmm& x, uint64_t type, uint8_t code, Operand::Kind kind) + { + if (addr.hasZero()) XBYAK_THROW(ERR_INVALID_ZERO) + if (addr.getRegExp().getIndex().getKind() != kind) XBYAK_THROW(ERR_BAD_VSIB_ADDRESSING) + opVex(x, 0, addr, type, code); + } + void opEncoding(const Xmm& x1, const Xmm& x2, const Operand& op, uint64_t type, int code, PreferredEncoding encoding) + { + opAVX_X_X_XM(x1, x2, op, type | orEvexIf(encoding), code); + } + int orEvexIf(PreferredEncoding encoding) { + if (encoding == DefaultEncoding) { + encoding = defaultEncoding_; + } + if (encoding == EvexEncoding) { +#ifdef XBYAK_DISABLE_AVX512 + XBYAK_THROW(ERR_EVEX_IS_INVALID) +#endif + return T_MUST_EVEX; + } + return 0; + } + void opInOut(const Reg& a, const Reg& d, uint8_t code) + { + if (a.getIdx() == Operand::AL && d.getIdx() == Operand::DX && d.getBit() == 16) { + switch (a.getBit()) { + case 8: db(code); return; + case 16: db(0x66); db(code + 1); return; + case 32: db(code + 1); return; + } + } + XBYAK_THROW(ERR_BAD_COMBINATION) + } + void opInOut(const Reg& a, uint8_t code, uint8_t v) + { + if (a.getIdx() == Operand::AL) { + switch (a.getBit()) { + case 8: db(code); db(v); return; + case 16: db(0x66); db(code + 1); db(v); return; + case 32: db(code + 1); db(v); return; + } + } + XBYAK_THROW(ERR_BAD_COMBINATION) + } + void opCcmp(const Operand& op1, const Operand& op2, int dfv, int code, int sc) // cmp = 0x38, test = 0x84 + { + if (dfv < 0 || 15 < dfv) XBYAK_THROW(ERR_INVALID_DFV) + opROO(Reg(15 - dfv, Operand::REG, (op1.getBit() | op2.getBit())), op1, op2, T_APX|T_CODE1_IF1, code, 0, sc); + } + void opCcmpi(const Operand& op, int imm, int dfv, int sc) + { + if (dfv < 0 || 15 < dfv) XBYAK_THROW(ERR_INVALID_DFV) + uint32_t immBit = getImmBit(op, imm); + uint32_t opBit = op.getBit(); + int tmp = immBit < (std::min)(opBit, 32U) ? 2 : 0; + opROO(Reg(15 - dfv, Operand::REG, opBit), op, Reg(15, Operand::REG, opBit), T_APX|T_CODE1_IF1, 0x80 | tmp, immBit / 8, sc); + db(imm, immBit / 8); + } + void opTesti(const Operand& op, int imm, int dfv, int sc) + { + if (dfv < 0 || 15 < dfv) XBYAK_THROW(ERR_INVALID_DFV) + uint32_t opBit = op.getBit(); + if (opBit == 0) XBYAK_THROW(ERR_MEM_SIZE_IS_NOT_SPECIFIED); + int immBit = (std::min)(opBit, 32U); + opROO(Reg(15 - dfv, Operand::REG, opBit), op, Reg(0, Operand::REG, opBit), T_APX|T_CODE1_IF1, 0xF6, immBit / 8, sc); + db(imm, immBit / 8); + } + void opCfcmov(const Reg& d, const Operand& op1, const Operand& op2, int code) + { + const int dBit = d.getBit(); + const int op2Bit = op2.getBit(); + if (dBit > 0 && op2Bit > 0 && dBit != op2Bit) XBYAK_THROW(ERR_BAD_SIZE_OF_REGISTER) + if (op1.isBit(8) || op2Bit == 8) XBYAK_THROW(ERR_BAD_SIZE_OF_REGISTER) + if (op2.isMEM()) { + if (op1.isMEM()) XBYAK_THROW(ERR_BAD_COMBINATION) + uint64_t type = dBit > 0 ? (T_MUST_EVEX|T_NF) : T_MUST_EVEX; + opROO(d, op2, op1, type, code); + } else { + opROO(d, op1, static_cast(op2)|T_nf, T_MUST_EVEX|T_NF, code); + } + } +#ifdef XBYAK64 + void opAMX(const Tmm& t1, const Address& addr, uint64_t type, int code) + { + // require both base and index + Address addr2 = addr.cloneNoOptimize(); + const RegExp exp = addr2.getRegExp(); + if (exp.getBase().getBit() == 0 || exp.getIndex().getBit() == 0) XBYAK_THROW(ERR_NOT_SUPPORTED) + if (opROO(Reg(), addr2, t1, T_APX|type, code)) return; + opVex(t1, &tmm0, addr2, type, code); + } +#endif + // (reg32e/mem, k) if rev else (k, k/mem/reg32e) + // size = 8, 16, 32, 64 + void opKmov(const Opmask& k, const Operand& op, bool rev, int size) + { + int code = 0; + bool isReg = op.isREG(size < 64 ? 32 : 64); + if (rev) { + code = isReg ? 0x93 : op.isMEM() ? 0x91 : 0; + } else { + code = op.isOPMASK() || op.isMEM() ? 0x90 : isReg ? 0x92 : 0; + } + if (code == 0) XBYAK_THROW(ERR_BAD_COMBINATION) + uint64_t type = T_0F; + switch (size) { + case 8: type |= T_W0|T_66; break; + case 16: type |= T_W0; break; + case 32: type |= isReg ? T_W0|T_F2 : T_W1|T_66; break; + case 64: type |= isReg ? T_W1|T_F2 : T_W1; break; + } + const Operand *p1 = &k, *p2 = &op; + if (code == 0x93) { std::swap(p1, p2); } + if (opROO(Reg(), *p2, *p1, T_APX|type, code)) return; + opVex(static_cast(*p1), 0, *p2, T_L0|type, code); + } + void opEncodeKey(const Reg32& r1, const Reg32& r2, uint8_t code1, uint8_t code2) + { + if (r1.getIdx() < 8 && r2.getIdx() < 8) { + db(0xF3); db(0x0F); db(0x38); db(code1); setModRM(3, r1.getIdx(), r2.getIdx()); + return; + } + opROO(Reg(), r2, r1, T_MUST_EVEX|T_F3, code2); + } + void opSSE_APX(const Xmm& x, const Operand& op, uint64_t type1, uint8_t code1, uint64_t type2, uint8_t code2, int imm = NONE) + { + if (x.getIdx() <= 15 && op.hasRex2() && opROO(Reg(), op, x, type2, code2, imm != NONE ? 1 : 0)) { + if (imm != NONE) db(imm); + return; + } + opSSE(x, op, type1, code1, isXMM_XMMorMEM, imm); + } +public: + unsigned int getVersion() const { return VERSION; } + using CodeArray::db; + const Mmx mm0, mm1, mm2, mm3, mm4, mm5, mm6, mm7; + const Xmm xmm0, xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7; + const Ymm ymm0, ymm1, ymm2, ymm3, ymm4, ymm5, ymm6, ymm7; + const Zmm zmm0, zmm1, zmm2, zmm3, zmm4, zmm5, zmm6, zmm7; + const Xmm &xm0, &xm1, &xm2, &xm3, &xm4, &xm5, &xm6, &xm7; + const Ymm &ym0, &ym1, &ym2, &ym3, &ym4, &ym5, &ym6, &ym7; + const Zmm &zm0, &zm1, &zm2, &zm3, &zm4, &zm5, &zm6, &zm7; + const Reg32 eax, ecx, edx, ebx, esp, ebp, esi, edi; + const Reg16 ax, cx, dx, bx, sp, bp, si, di; + const Reg8 al, cl, dl, bl, ah, ch, dh, bh; + const AddressFrame ptr, byte, word, dword, qword, xword, yword, zword; // xword is same as oword of NASM + const AddressFrame ptr_b, xword_b, yword_b, zword_b; // broadcast such as {1to2}, {1to4}, {1to8}, {1to16}, {b} + const Fpu st0, st1, st2, st3, st4, st5, st6, st7; + const Opmask k0, k1, k2, k3, k4, k5, k6, k7; + const BoundsReg bnd0, bnd1, bnd2, bnd3; + const EvexModifierRounding T_sae, T_rn_sae, T_rd_sae, T_ru_sae, T_rz_sae; // {sae}, {rn-sae}, {rd-sae}, {ru-sae}, {rz-sae} + const EvexModifierZero T_z; // {z} + const ApxFlagNF T_nf; + const ApxFlagZU T_zu; +#ifdef XBYAK64 + const Reg64 rax, rcx, rdx, rbx, rsp, rbp, rsi, rdi, r8, r9, r10, r11, r12, r13, r14, r15; + const Reg64 r16, r17, r18, r19, r20, r21, r22, r23, r24, r25, r26, r27, r28, r29, r30, r31; + const Reg32 r8d, r9d, r10d, r11d, r12d, r13d, r14d, r15d; + const Reg32 r16d, r17d, r18d, r19d, r20d, r21d, r22d, r23d, r24d, r25d, r26d, r27d, r28d, r29d, r30d, r31d; + const Reg16 r8w, r9w, r10w, r11w, r12w, r13w, r14w, r15w; + const Reg16 r16w, r17w, r18w, r19w, r20w, r21w, r22w, r23w, r24w, r25w, r26w, r27w, r28w, r29w, r30w, r31w; + const Reg8 r8b, r9b, r10b, r11b, r12b, r13b, r14b, r15b; + const Reg8 r16b, r17b, r18b, r19b, r20b, r21b, r22b, r23b, r24b, r25b, r26b, r27b, r28b, r29b, r30b, r31b; + const Reg8 spl, bpl, sil, dil; + const Xmm xmm8, xmm9, xmm10, xmm11, xmm12, xmm13, xmm14, xmm15; + const Xmm xmm16, xmm17, xmm18, xmm19, xmm20, xmm21, xmm22, xmm23; + const Xmm xmm24, xmm25, xmm26, xmm27, xmm28, xmm29, xmm30, xmm31; + const Ymm ymm8, ymm9, ymm10, ymm11, ymm12, ymm13, ymm14, ymm15; + const Ymm ymm16, ymm17, ymm18, ymm19, ymm20, ymm21, ymm22, ymm23; + const Ymm ymm24, ymm25, ymm26, ymm27, ymm28, ymm29, ymm30, ymm31; + const Zmm zmm8, zmm9, zmm10, zmm11, zmm12, zmm13, zmm14, zmm15; + const Zmm zmm16, zmm17, zmm18, zmm19, zmm20, zmm21, zmm22, zmm23; + const Zmm zmm24, zmm25, zmm26, zmm27, zmm28, zmm29, zmm30, zmm31; + const Tmm tmm0, tmm1, tmm2, tmm3, tmm4, tmm5, tmm6, tmm7; + const Xmm &xm8, &xm9, &xm10, &xm11, &xm12, &xm13, &xm14, &xm15; // for my convenience + const Xmm &xm16, &xm17, &xm18, &xm19, &xm20, &xm21, &xm22, &xm23; + const Xmm &xm24, &xm25, &xm26, &xm27, &xm28, &xm29, &xm30, &xm31; + const Ymm &ym8, &ym9, &ym10, &ym11, &ym12, &ym13, &ym14, &ym15; + const Ymm &ym16, &ym17, &ym18, &ym19, &ym20, &ym21, &ym22, &ym23; + const Ymm &ym24, &ym25, &ym26, &ym27, &ym28, &ym29, &ym30, &ym31; + const Zmm &zm8, &zm9, &zm10, &zm11, &zm12, &zm13, &zm14, &zm15; + const Zmm &zm16, &zm17, &zm18, &zm19, &zm20, &zm21, &zm22, &zm23; + const Zmm &zm24, &zm25, &zm26, &zm27, &zm28, &zm29, &zm30, &zm31; + const RegRip rip; +#endif +#ifndef XBYAK_DISABLE_SEGMENT + const Segment es, cs, ss, ds, fs, gs; +#endif +private: + bool isDefaultJmpNEAR_; + PreferredEncoding defaultEncoding_; +public: + void L(const std::string& label) { labelMgr_.defineSlabel(label); } + void L(Label& label) { labelMgr_.defineClabel(label); } + Label L() { Label label; L(label); return label; } + void inLocalLabel() { labelMgr_.enterLocal(); } + void outLocalLabel() { labelMgr_.leaveLocal(); } + /* + assign src to dst + require + dst : does not used by L() + src : used by L() + */ + void assignL(Label& dst, const Label& src) { labelMgr_.assign(dst, src); } + /* + put address of label to buffer + @note the put size is 4(32-bit), 8(64-bit) + */ + void putL(std::string label) { putL_inner(label); } + void putL(const Label& label) { putL_inner(label); } + + // set default type of `jmp` of undefined label to T_NEAR + void setDefaultJmpNEAR(bool isNear) { isDefaultJmpNEAR_ = isNear; } + void jmp(const Operand& op, LabelType type = T_AUTO) { opJmpOp(op, type, 4); } + void jmp(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0xEB, 0xE9, 0); } + void jmp(const char *label, LabelType type = T_AUTO) { jmp(std::string(label), type); } + void jmp(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0xEB, 0xE9, 0); } + void jmp(const void *addr, LabelType type = T_AUTO) { opJmpAbs(addr, type, 0xEB, 0xE9); } + + void call(const Operand& op, LabelType type = T_AUTO) { opJmpOp(op, type, 2); } + // call(string label), not const std::string& + void call(std::string label) { opJmp(label, T_NEAR, 0, 0xE8, 0); } + void call(const char *label) { call(std::string(label)); } + void call(const Label& label) { opJmp(label, T_NEAR, 0, 0xE8, 0); } + // call(function pointer) +#ifdef XBYAK_VARIADIC_TEMPLATE + template + void call(Ret(*func)(Params...)) { call(reinterpret_cast(func)); } +#endif + void call(const void *addr) { opJmpAbs(addr, T_NEAR, 0, 0xE8); } + + void test(const Operand& op, const Reg& reg) + { + opRO(reg, op, 0, 0x84, op.getKind() == reg.getKind()); + } + void test(const Operand& op, uint32_t imm) + { + verifyMemHasSize(op); + int immSize = (std::min)(op.getBit() / 8, 4U); + if (op.isREG() && op.getIdx() == 0) { // al, ax, eax + rex(op); + db(0xA8 | (op.isBit(8) ? 0 : 1)); + } else { + opRext(op, 0, 0, 0, 0xF6, false, immSize); + } + db(imm, immSize); + } + void imul(const Reg& reg, const Operand& op, int imm) + { + int s = inner::IsInDisp8(imm) ? 1 : 0; + int immSize = s ? 1 : reg.isREG(16) ? 2 : 4; + uint8_t code = uint8_t(0x69 | (s << 1)); + if (!opROO(Reg(), op, reg, T_APX|T_NF|T_ZU, code, immSize)) { + opRO(reg, op, 0, code, reg.getKind() == op.getKind(), immSize); + } + db(imm, immSize); + } + void push(const Operand& op) { opPushPop(op, 0xFF, 6, 0x50); } + void pop(const Operand& op) { opPushPop(op, 0x8F, 0, 0x58); } + void push(const AddressFrame& af, uint32_t imm) + { + if (af.bit_ == 8) { + db(0x6A); db(imm); + } else if (af.bit_ == 16) { + db(0x66); db(0x68); dw(imm); + } else { + db(0x68); dd(imm); + } + } + /* use "push(word, 4)" if you want "push word 4" */ + void push(uint32_t imm) + { + if (inner::IsInDisp8(imm)) { + push(byte, imm); + } else { + push(dword, imm); + } + } + void mov(const Operand& op1, const Operand& op2) + { + const Reg *reg = 0; + const Address *addr = 0; + uint8_t code = 0; + if (op1.isREG() && op1.getIdx() == 0 && op2.isMEM()) { // mov eax|ax|al, [disp] + reg = &op1.getReg(); + addr= &op2.getAddress(); + code = 0xA0; + } else + if (op1.isMEM() && op2.isREG() && op2.getIdx() == 0) { // mov [disp], eax|ax|al + reg = &op2.getReg(); + addr= &op1.getAddress(); + code = 0xA2; + } +#ifdef XBYAK64 + if (addr && addr->is64bitDisp()) { + if (code) { + rex(*reg); + db(op1.isREG(8) ? 0xA0 : op1.isREG() ? 0xA1 : op2.isREG(8) ? 0xA2 : 0xA3); + db(addr->getDisp(), 8); + } else { + XBYAK_THROW(ERR_BAD_COMBINATION) + } + } else +#else + if (code && addr->isOnlyDisp()) { + rex(*reg, *addr); + db(code | (reg->isBit(8) ? 0 : 1)); + dd(static_cast(addr->getDisp())); + } else +#endif + { + opRO_MR(op1, op2, 0x88); + } + } + void mov(const Operand& op, uint64_t imm) + { + if (op.isREG()) { + const int size = mov_imm(op.getReg(), imm); + db(imm, size); + } else if (op.isMEM()) { + verifyMemHasSize(op); + int immSize = op.getBit() / 8; + if (immSize <= 4) { + int64_t s = int64_t(imm) >> (immSize * 8); + if (s != 0 && s != -1) XBYAK_THROW(ERR_IMM_IS_TOO_BIG) + } else { + if (!inner::IsInInt32(imm)) XBYAK_THROW(ERR_IMM_IS_TOO_BIG) + immSize = 4; + } + opMR(op.getAddress(immSize), Reg(0, Operand::REG, op.getBit()), 0, 0xC6); + db(static_cast(imm), immSize); + } else { + XBYAK_THROW(ERR_BAD_COMBINATION) + } + } + + // The template is used to avoid ambiguity when the 2nd argument is 0. + // When the 2nd argument is 0 the call goes to + // `void mov(const Operand& op, uint64_t imm)`. + template + void mov(const T1&, const T2 *) { T1::unexpected; } + void mov(const NativeReg& reg, const Label& label) + { + mov_imm(reg, dummyAddr); + putL(label); + } + void xchg(const Operand& op1, const Operand& op2) + { + const Operand *p1 = &op1, *p2 = &op2; + if (p1->isMEM() || (p2->isREG(16 | i32e) && p2->getIdx() == 0)) { + p1 = &op2; p2 = &op1; + } + if (p1->isMEM()) XBYAK_THROW(ERR_BAD_COMBINATION) + if (p2->isREG() && (p1->isREG(16 | i32e) && p1->getIdx() == 0) +#ifdef XBYAK64 + && (p2->getIdx() != 0 || !p1->isREG(32)) +#endif + ) { + rex(*p2, *p1); db(0x90 | (p2->getIdx() & 7)); + return; + } + opRO(static_cast(*p1), *p2, 0, 0x86 | (p1->isBit(8) ? 0 : 1), (p1->isREG() && (p1->getBit() == p2->getBit()))); + } + +#ifndef XBYAK_DISABLE_SEGMENT + void push(const Segment& seg) + { + switch (seg.getIdx()) { + case Segment::es: db(0x06); break; + case Segment::cs: db(0x0E); break; + case Segment::ss: db(0x16); break; + case Segment::ds: db(0x1E); break; + case Segment::fs: db(0x0F); db(0xA0); break; + case Segment::gs: db(0x0F); db(0xA8); break; + default: + assert(0); + } + } + void pop(const Segment& seg) + { + switch (seg.getIdx()) { + case Segment::es: db(0x07); break; + case Segment::cs: XBYAK_THROW(ERR_BAD_COMBINATION) + case Segment::ss: db(0x17); break; + case Segment::ds: db(0x1F); break; + case Segment::fs: db(0x0F); db(0xA1); break; + case Segment::gs: db(0x0F); db(0xA9); break; + default: + assert(0); + } + } + void putSeg(const Segment& seg) + { + switch (seg.getIdx()) { + case Segment::es: db(0x2E); break; + case Segment::cs: db(0x36); break; + case Segment::ss: db(0x3E); break; + case Segment::ds: db(0x26); break; + case Segment::fs: db(0x64); break; + case Segment::gs: db(0x65); break; + default: + assert(0); + } + } + void mov(const Operand& op, const Segment& seg) + { + opRO(Reg8(seg.getIdx()), op, 0, 0x8C, op.isREG(16|i32e)); + } + void mov(const Segment& seg, const Operand& op) + { + opRO(Reg8(seg.getIdx()), op.isREG(16|i32e) ? static_cast(op.getReg().cvt32()) : op, 0, 0x8E, op.isREG(16|i32e)); + } +#endif + + enum { NONE = 256 }; + // constructor + CodeGenerator(size_t maxSize = DEFAULT_MAX_CODE_SIZE, void *userPtr = 0, Allocator *allocator = 0) + : CodeArray(maxSize, userPtr, allocator) + , mm0(0), mm1(1), mm2(2), mm3(3), mm4(4), mm5(5), mm6(6), mm7(7) + , xmm0(0), xmm1(1), xmm2(2), xmm3(3), xmm4(4), xmm5(5), xmm6(6), xmm7(7) + , ymm0(0), ymm1(1), ymm2(2), ymm3(3), ymm4(4), ymm5(5), ymm6(6), ymm7(7) + , zmm0(0), zmm1(1), zmm2(2), zmm3(3), zmm4(4), zmm5(5), zmm6(6), zmm7(7) + // for my convenience + , xm0(xmm0), xm1(xmm1), xm2(xmm2), xm3(xmm3), xm4(xmm4), xm5(xmm5), xm6(xmm6), xm7(xmm7) + , ym0(ymm0), ym1(ymm1), ym2(ymm2), ym3(ymm3), ym4(ymm4), ym5(ymm5), ym6(ymm6), ym7(ymm7) + , zm0(zmm0), zm1(zmm1), zm2(zmm2), zm3(zmm3), zm4(zmm4), zm5(zmm5), zm6(zmm6), zm7(zmm7) + + , eax(Operand::EAX), ecx(Operand::ECX), edx(Operand::EDX), ebx(Operand::EBX), esp(Operand::ESP), ebp(Operand::EBP), esi(Operand::ESI), edi(Operand::EDI) + , ax(Operand::AX), cx(Operand::CX), dx(Operand::DX), bx(Operand::BX), sp(Operand::SP), bp(Operand::BP), si(Operand::SI), di(Operand::DI) + , al(Operand::AL), cl(Operand::CL), dl(Operand::DL), bl(Operand::BL), ah(Operand::AH), ch(Operand::CH), dh(Operand::DH), bh(Operand::BH) + , ptr(0), byte(8), word(16), dword(32), qword(64), xword(128), yword(256), zword(512) + , ptr_b(0, true), xword_b(128, true), yword_b(256, true), zword_b(512, true) + , st0(0), st1(1), st2(2), st3(3), st4(4), st5(5), st6(6), st7(7) + , k0(0), k1(1), k2(2), k3(3), k4(4), k5(5), k6(6), k7(7) + , bnd0(0), bnd1(1), bnd2(2), bnd3(3) + , T_sae(EvexModifierRounding::T_SAE), T_rn_sae(EvexModifierRounding::T_RN_SAE), T_rd_sae(EvexModifierRounding::T_RD_SAE), T_ru_sae(EvexModifierRounding::T_RU_SAE), T_rz_sae(EvexModifierRounding::T_RZ_SAE) + , T_z() + , T_nf() + , T_zu() +#ifdef XBYAK64 + , rax(Operand::RAX), rcx(Operand::RCX), rdx(Operand::RDX), rbx(Operand::RBX), rsp(Operand::RSP), rbp(Operand::RBP), rsi(Operand::RSI), rdi(Operand::RDI), r8(Operand::R8), r9(Operand::R9), r10(Operand::R10), r11(Operand::R11), r12(Operand::R12), r13(Operand::R13), r14(Operand::R14), r15(Operand::R15) + , r16(Operand::R16), r17(Operand::R17), r18(Operand::R18), r19(Operand::R19), r20(Operand::R20), r21(Operand::R21), r22(Operand::R22), r23(Operand::R23), r24(Operand::R24), r25(Operand::R25), r26(Operand::R26), r27(Operand::R27), r28(Operand::R28), r29(Operand::R29), r30(Operand::R30), r31(Operand::R31) + , r8d(8), r9d(9), r10d(10), r11d(11), r12d(12), r13d(13), r14d(14), r15d(15) + , r16d(Operand::R16D), r17d(Operand::R17D), r18d(Operand::R18D), r19d(Operand::R19D), r20d(Operand::R20D), r21d(Operand::R21D), r22d(Operand::R22D), r23d(Operand::R23D), r24d(Operand::R24D), r25d(Operand::R25D), r26d(Operand::R26D), r27d(Operand::R27D), r28d(Operand::R28D), r29d(Operand::R29D), r30d(Operand::R30D), r31d(Operand::R31D) + , r8w(8), r9w(9), r10w(10), r11w(11), r12w(12), r13w(13), r14w(14), r15w(15) + , r16w(Operand::R16W), r17w(Operand::R17W), r18w(Operand::R18W), r19w(Operand::R19W), r20w(Operand::R20W), r21w(Operand::R21W), r22w(Operand::R22W), r23w(Operand::R23W), r24w(Operand::R24W), r25w(Operand::R25W), r26w(Operand::R26W), r27w(Operand::R27W), r28w(Operand::R28W), r29w(Operand::R29W), r30w(Operand::R30W), r31w(Operand::R31W) + , r8b(8), r9b(9), r10b(10), r11b(11), r12b(12), r13b(13), r14b(14), r15b(15) + , r16b(Operand::R16B), r17b(Operand::R17B), r18b(Operand::R18B), r19b(Operand::R19B), r20b(Operand::R20B), r21b(Operand::R21B), r22b(Operand::R22B), r23b(Operand::R23B), r24b(Operand::R24B), r25b(Operand::R25B), r26b(Operand::R26B), r27b(Operand::R27B), r28b(Operand::R28B), r29b(Operand::R29B), r30b(Operand::R30B), r31b(Operand::R31B) + , spl(Operand::SPL, true), bpl(Operand::BPL, true), sil(Operand::SIL, true), dil(Operand::DIL, true) + , xmm8(8), xmm9(9), xmm10(10), xmm11(11), xmm12(12), xmm13(13), xmm14(14), xmm15(15) + , xmm16(16), xmm17(17), xmm18(18), xmm19(19), xmm20(20), xmm21(21), xmm22(22), xmm23(23) + , xmm24(24), xmm25(25), xmm26(26), xmm27(27), xmm28(28), xmm29(29), xmm30(30), xmm31(31) + , ymm8(8), ymm9(9), ymm10(10), ymm11(11), ymm12(12), ymm13(13), ymm14(14), ymm15(15) + , ymm16(16), ymm17(17), ymm18(18), ymm19(19), ymm20(20), ymm21(21), ymm22(22), ymm23(23) + , ymm24(24), ymm25(25), ymm26(26), ymm27(27), ymm28(28), ymm29(29), ymm30(30), ymm31(31) + , zmm8(8), zmm9(9), zmm10(10), zmm11(11), zmm12(12), zmm13(13), zmm14(14), zmm15(15) + , zmm16(16), zmm17(17), zmm18(18), zmm19(19), zmm20(20), zmm21(21), zmm22(22), zmm23(23) + , zmm24(24), zmm25(25), zmm26(26), zmm27(27), zmm28(28), zmm29(29), zmm30(30), zmm31(31) + , tmm0(0), tmm1(1), tmm2(2), tmm3(3), tmm4(4), tmm5(5), tmm6(6), tmm7(7) + // for my convenience + , xm8(xmm8), xm9(xmm9), xm10(xmm10), xm11(xmm11), xm12(xmm12), xm13(xmm13), xm14(xmm14), xm15(xmm15) + , xm16(xmm16), xm17(xmm17), xm18(xmm18), xm19(xmm19), xm20(xmm20), xm21(xmm21), xm22(xmm22), xm23(xmm23) + , xm24(xmm24), xm25(xmm25), xm26(xmm26), xm27(xmm27), xm28(xmm28), xm29(xmm29), xm30(xmm30), xm31(xmm31) + , ym8(ymm8), ym9(ymm9), ym10(ymm10), ym11(ymm11), ym12(ymm12), ym13(ymm13), ym14(ymm14), ym15(ymm15) + , ym16(ymm16), ym17(ymm17), ym18(ymm18), ym19(ymm19), ym20(ymm20), ym21(ymm21), ym22(ymm22), ym23(ymm23) + , ym24(ymm24), ym25(ymm25), ym26(ymm26), ym27(ymm27), ym28(ymm28), ym29(ymm29), ym30(ymm30), ym31(ymm31) + , zm8(zmm8), zm9(zmm9), zm10(zmm10), zm11(zmm11), zm12(zmm12), zm13(zmm13), zm14(zmm14), zm15(zmm15) + , zm16(zmm16), zm17(zmm17), zm18(zmm18), zm19(zmm19), zm20(zmm20), zm21(zmm21), zm22(zmm22), zm23(zmm23) + , zm24(zmm24), zm25(zmm25), zm26(zmm26), zm27(zmm27), zm28(zmm28), zm29(zmm29), zm30(zmm30), zm31(zmm31) + , rip() +#endif +#ifndef XBYAK_DISABLE_SEGMENT + , es(Segment::es), cs(Segment::cs), ss(Segment::ss), ds(Segment::ds), fs(Segment::fs), gs(Segment::gs) +#endif + , isDefaultJmpNEAR_(false) + , defaultEncoding_(EvexEncoding) + { + labelMgr_.set(this); + } + void reset() + { + ClearError(); + resetSize(); + labelMgr_.reset(); + labelMgr_.set(this); + } + bool hasUndefinedLabel() const { return labelMgr_.hasUndefSlabel() || labelMgr_.hasUndefClabel(); } + /* + MUST call ready() to complete generating code if you use AutoGrow mode. + It is not necessary for the other mode if hasUndefinedLabel() is true. + */ + void ready(ProtectMode mode = PROTECT_RWE) + { + if (hasUndefinedLabel()) XBYAK_THROW(ERR_LABEL_IS_NOT_FOUND) + if (isAutoGrow()) { + calcJmpAddress(); + if (useProtect()) setProtectMode(mode); + } + } + // set read/exec + void readyRE() { return ready(PROTECT_RE); } +#ifdef XBYAK_TEST + void dump(bool doClear = true) + { + CodeArray::dump(); + if (doClear) size_ = 0; + } +#endif + +#ifdef XBYAK_UNDEF_JNL + #undef jnl +#endif + + // set default encoding to select Vex or Evex + void setDefaultEncoding(PreferredEncoding encoding) { defaultEncoding_ = encoding; } + + void sha1msg12(const Xmm& x, const Operand& op) + { + opROO(Reg(), op, x, T_MUST_EVEX, 0xD9); + } + void bswap(const Reg32e& r) + { + int idx = r.getIdx(); + uint8_t rex = (r.isREG(64) ? 8 : 0) | ((idx & 8) ? 1 : 0); + if (idx >= 16) { + db(0xD5); db((1<<7) | (idx & 16) | rex); + } else { + if (rex) db(0x40 | rex); + db(0x0F); + } + db(0xC8 + (idx & 7)); + } + /* + use single byte nop if useMultiByteNop = false + */ + void nop(size_t size = 1, bool useMultiByteNop = true) + { + if (!useMultiByteNop) { + for (size_t i = 0; i < size; i++) { + db(0x90); + } + return; + } + /* + Intel Architectures Software Developer's Manual Volume 2 + recommended multi-byte sequence of NOP instruction + AMD and Intel seem to agree on the same sequences for up to 9 bytes: + https://support.amd.com/TechDocs/55723_SOG_Fam_17h_Processors_3.00.pdf + */ + static const uint8_t nopTbl[9][9] = { + {0x90}, + {0x66, 0x90}, + {0x0F, 0x1F, 0x00}, + {0x0F, 0x1F, 0x40, 0x00}, + {0x0F, 0x1F, 0x44, 0x00, 0x00}, + {0x66, 0x0F, 0x1F, 0x44, 0x00, 0x00}, + {0x0F, 0x1F, 0x80, 0x00, 0x00, 0x00, 0x00}, + {0x0F, 0x1F, 0x84, 0x00, 0x00, 0x00, 0x00, 0x00}, + {0x66, 0x0F, 0x1F, 0x84, 0x00, 0x00, 0x00, 0x00, 0x00}, + }; + const size_t n = sizeof(nopTbl) / sizeof(nopTbl[0]); + while (size > 0) { + size_t len = (std::min)(n, size); + const uint8_t *seq = nopTbl[len - 1]; + db(seq, len); + size -= len; + } + } +#ifndef XBYAK_DONT_READ_LIST +#include "xbyak_mnemonic.h" + /* + use single byte nop if useMultiByteNop = false + */ + void align(size_t x = 16, bool useMultiByteNop = true) + { + if (x == 1) return; + if (x < 1 || (x & (x - 1))) XBYAK_THROW(ERR_BAD_ALIGN) + if (isAutoGrow()) XBYAK_THROW(ERR_BAD_ALIGN) + size_t remain = size_t(getCurr()) % x; + if (remain) { + nop(x - remain, useMultiByteNop); + } + } +#endif +}; + +template <> +inline void CodeGenerator::mov(const NativeReg& reg, const char *label) // can't use std::string +{ + assert(label); + mov_imm(reg, dummyAddr); + putL(label); +} + +namespace util { +static const XBYAK_CONSTEXPR Mmx mm0(0), mm1(1), mm2(2), mm3(3), mm4(4), mm5(5), mm6(6), mm7(7); +static const XBYAK_CONSTEXPR Xmm xmm0(0), xmm1(1), xmm2(2), xmm3(3), xmm4(4), xmm5(5), xmm6(6), xmm7(7); +static const XBYAK_CONSTEXPR Ymm ymm0(0), ymm1(1), ymm2(2), ymm3(3), ymm4(4), ymm5(5), ymm6(6), ymm7(7); +static const XBYAK_CONSTEXPR Zmm zmm0(0), zmm1(1), zmm2(2), zmm3(3), zmm4(4), zmm5(5), zmm6(6), zmm7(7); +static const XBYAK_CONSTEXPR Reg32 eax(Operand::EAX), ecx(Operand::ECX), edx(Operand::EDX), ebx(Operand::EBX), esp(Operand::ESP), ebp(Operand::EBP), esi(Operand::ESI), edi(Operand::EDI); +static const XBYAK_CONSTEXPR Reg16 ax(Operand::AX), cx(Operand::CX), dx(Operand::DX), bx(Operand::BX), sp(Operand::SP), bp(Operand::BP), si(Operand::SI), di(Operand::DI); +static const XBYAK_CONSTEXPR Reg8 al(Operand::AL), cl(Operand::CL), dl(Operand::DL), bl(Operand::BL), ah(Operand::AH), ch(Operand::CH), dh(Operand::DH), bh(Operand::BH); +static const XBYAK_CONSTEXPR AddressFrame ptr(0), byte(8), word(16), dword(32), qword(64), xword(128), yword(256), zword(512); +static const XBYAK_CONSTEXPR AddressFrame ptr_b(0, true), xword_b(128, true), yword_b(256, true), zword_b(512, true); +static const XBYAK_CONSTEXPR Fpu st0(0), st1(1), st2(2), st3(3), st4(4), st5(5), st6(6), st7(7); +static const XBYAK_CONSTEXPR Opmask k0(0), k1(1), k2(2), k3(3), k4(4), k5(5), k6(6), k7(7); +static const XBYAK_CONSTEXPR BoundsReg bnd0(0), bnd1(1), bnd2(2), bnd3(3); +static const XBYAK_CONSTEXPR EvexModifierRounding T_sae(EvexModifierRounding::T_SAE), T_rn_sae(EvexModifierRounding::T_RN_SAE), T_rd_sae(EvexModifierRounding::T_RD_SAE), T_ru_sae(EvexModifierRounding::T_RU_SAE), T_rz_sae(EvexModifierRounding::T_RZ_SAE); +static const XBYAK_CONSTEXPR EvexModifierZero T_z; +#ifdef XBYAK64 +static const XBYAK_CONSTEXPR Reg64 rax(Operand::RAX), rcx(Operand::RCX), rdx(Operand::RDX), rbx(Operand::RBX), rsp(Operand::RSP), rbp(Operand::RBP), rsi(Operand::RSI), rdi(Operand::RDI), r8(Operand::R8), r9(Operand::R9), r10(Operand::R10), r11(Operand::R11), r12(Operand::R12), r13(Operand::R13), r14(Operand::R14), r15(Operand::R15); +static const XBYAK_CONSTEXPR Reg64 r16(16), r17(17), r18(18), r19(19), r20(20), r21(21), r22(22), r23(23), r24(24), r25(25), r26(26), r27(27), r28(28), r29(29), r30(30), r31(31); +static const XBYAK_CONSTEXPR Reg32 r8d(8), r9d(9), r10d(10), r11d(11), r12d(12), r13d(13), r14d(14), r15d(15); +static const XBYAK_CONSTEXPR Reg32 r16d(16), r17d(17), r18d(18), r19d(19), r20d(20), r21d(21), r22d(22), r23d(23), r24d(24), r25d(25), r26d(26), r27d(27), r28d(28), r29d(29), r30d(30), r31d(31); +static const XBYAK_CONSTEXPR Reg16 r8w(8), r9w(9), r10w(10), r11w(11), r12w(12), r13w(13), r14w(14), r15w(15); +static const XBYAK_CONSTEXPR Reg16 r16w(16), r17w(17), r18w(18), r19w(19), r20w(20), r21w(21), r22w(22), r23w(23), r24w(24), r25w(25), r26w(26), r27w(27), r28w(28), r29w(29), r30w(30), r31w(31); +static const XBYAK_CONSTEXPR Reg8 r8b(8), r9b(9), r10b(10), r11b(11), r12b(12), r13b(13), r14b(14), r15b(15), spl(Operand::SPL, true), bpl(Operand::BPL, true), sil(Operand::SIL, true), dil(Operand::DIL, true); +static const XBYAK_CONSTEXPR Reg8 r16b(16), r17b(17), r18b(18), r19b(19), r20b(20), r21b(21), r22b(22), r23b(23), r24b(24), r25b(25), r26b(26), r27b(27), r28b(28), r29b(29), r30b(30), r31b(31); +static const XBYAK_CONSTEXPR Xmm xmm8(8), xmm9(9), xmm10(10), xmm11(11), xmm12(12), xmm13(13), xmm14(14), xmm15(15); +static const XBYAK_CONSTEXPR Xmm xmm16(16), xmm17(17), xmm18(18), xmm19(19), xmm20(20), xmm21(21), xmm22(22), xmm23(23); +static const XBYAK_CONSTEXPR Xmm xmm24(24), xmm25(25), xmm26(26), xmm27(27), xmm28(28), xmm29(29), xmm30(30), xmm31(31); +static const XBYAK_CONSTEXPR Ymm ymm8(8), ymm9(9), ymm10(10), ymm11(11), ymm12(12), ymm13(13), ymm14(14), ymm15(15); +static const XBYAK_CONSTEXPR Ymm ymm16(16), ymm17(17), ymm18(18), ymm19(19), ymm20(20), ymm21(21), ymm22(22), ymm23(23); +static const XBYAK_CONSTEXPR Ymm ymm24(24), ymm25(25), ymm26(26), ymm27(27), ymm28(28), ymm29(29), ymm30(30), ymm31(31); +static const XBYAK_CONSTEXPR Zmm zmm8(8), zmm9(9), zmm10(10), zmm11(11), zmm12(12), zmm13(13), zmm14(14), zmm15(15); +static const XBYAK_CONSTEXPR Zmm zmm16(16), zmm17(17), zmm18(18), zmm19(19), zmm20(20), zmm21(21), zmm22(22), zmm23(23); +static const XBYAK_CONSTEXPR Zmm zmm24(24), zmm25(25), zmm26(26), zmm27(27), zmm28(28), zmm29(29), zmm30(30), zmm31(31); +static const XBYAK_CONSTEXPR Zmm tmm0(0), tmm1(1), tmm2(2), tmm3(3), tmm4(4), tmm5(5), tmm6(6), tmm7(7); +static const XBYAK_CONSTEXPR RegRip rip; +static const XBYAK_CONSTEXPR ApxFlagNF T_nf; +static const XBYAK_CONSTEXPR ApxFlagZU T_zu; +#endif +#ifndef XBYAK_DISABLE_SEGMENT +static const XBYAK_CONSTEXPR Segment es(Segment::es), cs(Segment::cs), ss(Segment::ss), ds(Segment::ds), fs(Segment::fs), gs(Segment::gs); +#endif +} // util + +#ifdef _MSC_VER + #pragma warning(pop) +#endif + +#if defined(__GNUC__) && !defined(__clang__) + #pragma GCC diagnostic pop +#endif + +} // end of namespace + +#endif // XBYAK_XBYAK_H_ diff --git a/addon/aocl_gemm/JIT/xbyak/xbyak_mnemonic.h b/addon/aocl_gemm/JIT/xbyak/xbyak_mnemonic.h new file mode 100644 index 0000000000..ac2a38fc20 --- /dev/null +++ b/addon/aocl_gemm/JIT/xbyak/xbyak_mnemonic.h @@ -0,0 +1,2582 @@ +const char *getVersionString() const { return "7.05"; } +void aadd(const Address& addr, const Reg32e ®) { opMR(addr, reg, T_0F38, 0x0FC, T_APX); } +void aand(const Address& addr, const Reg32e ®) { opMR(addr, reg, T_0F38|T_66, 0x0FC, T_APX|T_66); } +void adc(const Operand& op, uint32_t imm) { opOI(op, imm, 0x10, 2); } +void adc(const Operand& op1, const Operand& op2) { opRO_MR(op1, op2, 0x10); } +void adc(const Reg& d, const Operand& op, uint32_t imm) { opROI(d, op, imm, T_NONE, 2); } +void adc(const Reg& d, const Operand& op1, const Operand& op2) { opROO(d, op1, op2, T_NONE, 0x10); } +void adcx(const Reg32e& d, const Reg32e& reg, const Operand& op) { opROO(d, op, reg, T_66, 0x66); } +void adcx(const Reg32e& reg, const Operand& op) { if (!reg.isREG(16|i32e) && reg.getBit() == op.getBit()) XBYAK_THROW(ERR_BAD_SIZE_OF_REGISTER) if (opROO(Reg(), op, reg, T_66, 0x66)) return; opRO(reg, op, T_66 | T_0F38, 0xF6); } +void add(const Operand& op, uint32_t imm) { opOI(op, imm, 0x00, 0); } +void add(const Operand& op1, const Operand& op2) { opRO_MR(op1, op2, 0x00); } +void add(const Reg& d, const Operand& op, uint32_t imm) { opROI(d, op, imm, T_NF|T_CODE1_IF1, 0); } +void add(const Reg& d, const Operand& op1, const Operand& op2) { opROO(d, op1, op2, T_NF|T_CODE1_IF1, 0x00); } +void addpd(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_0F | T_66, 0x58, isXMM_XMMorMEM); } +void addps(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_0F, 0x58, isXMM_XMMorMEM); } +void addsd(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_0F | T_F2, 0x58, isXMM_XMMorMEM); } +void addss(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_0F | T_F3, 0x58, isXMM_XMMorMEM); } +void addsubpd(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_66|T_0F|T_YMM, 0xD0, isXMM_XMMorMEM); } +void addsubps(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_F2|T_0F|T_YMM, 0xD0, isXMM_XMMorMEM); } +void adox(const Reg32e& d, const Reg32e& reg, const Operand& op) { opROO(d, op, reg, T_F3, 0x66); } +void adox(const Reg32e& reg, const Operand& op) { if (!reg.isREG(16|i32e) && reg.getBit() == op.getBit()) XBYAK_THROW(ERR_BAD_SIZE_OF_REGISTER) if (opROO(Reg(), op, reg, T_F3, 0x66)) return; opRO(reg, op, T_F3 | T_0F38, 0xF6); } +void aesdec(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_66|T_0F38|T_YMM|T_EVEX, 0xDE, isXMM_XMMorMEM); } +void aesdeclast(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_66|T_0F38|T_YMM|T_EVEX, 0xDF, isXMM_XMMorMEM); } +void aesenc(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_66|T_0F38|T_YMM|T_EVEX, 0xDC, isXMM_XMMorMEM); } +void aesenclast(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_66|T_0F38|T_YMM|T_EVEX, 0xDD, isXMM_XMMorMEM); } +void aesimc(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_66|T_0F38|T_W0, 0xDB, isXMM_XMMorMEM, NONE); } +void aeskeygenassist(const Xmm& xmm, const Operand& op, uint8_t imm) { opSSE(xmm, op, T_66|T_0F3A, 0xDF, isXMM_XMMorMEM, imm); } +void and_(const Operand& op, uint32_t imm) { opOI(op, imm, 0x20, 4); } +void and_(const Operand& op1, const Operand& op2) { opRO_MR(op1, op2, 0x20); } +void and_(const Reg& d, const Operand& op, uint32_t imm) { opROI(d, op, imm, T_NF|T_CODE1_IF1, 4); } +void and_(const Reg& d, const Operand& op1, const Operand& op2) { opROO(d, op1, op2, T_NF|T_CODE1_IF1, 0x20); } +void andn(const Reg32e& r1, const Reg32e& r2, const Operand& op) { opRRO(r1, r2, op, T_APX|T_0F38|T_NF, 0xf2); } +void andnpd(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_0F | T_66, 0x55, isXMM_XMMorMEM); } +void andnps(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_0F, 0x55, isXMM_XMMorMEM); } +void andpd(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_0F | T_66, 0x54, isXMM_XMMorMEM); } +void andps(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_0F, 0x54, isXMM_XMMorMEM); } +void aor(const Address& addr, const Reg32e ®) { opMR(addr, reg, T_0F38|T_F2, 0x0FC, T_APX|T_F2); } +void axor(const Address& addr, const Reg32e ®) { opMR(addr, reg, T_0F38|T_F3, 0x0FC, T_APX|T_F3); } +void bextr(const Reg32e& r1, const Operand& op, const Reg32e& r2) { opRRO(r1, r2, op, T_APX|T_0F38|T_NF, 0xf7); } +void blendpd(const Xmm& xmm, const Operand& op, int imm) { opSSE(xmm, op, T_66 | T_0F3A, 0x0D, isXMM_XMMorMEM, static_cast(imm)); } +void blendps(const Xmm& xmm, const Operand& op, int imm) { opSSE(xmm, op, T_66 | T_0F3A, 0x0C, isXMM_XMMorMEM, static_cast(imm)); } +void blendvpd(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_66|T_0F38, 0x15, isXMM_XMMorMEM, NONE); } +void blendvps(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_66|T_0F38, 0x14, isXMM_XMMorMEM, NONE); } +void blsi(const Reg32e& r, const Operand& op) { opRRO(Reg32e(3, r.getBit()), r, op, T_APX|T_0F38|T_NF, 0xf3); } +void blsmsk(const Reg32e& r, const Operand& op) { opRRO(Reg32e(2, r.getBit()), r, op, T_APX|T_0F38|T_NF, 0xf3); } +void blsr(const Reg32e& r, const Operand& op) { opRRO(Reg32e(1, r.getBit()), r, op, T_APX|T_0F38|T_NF, 0xf3); } +void bnd() { db(0xF2); } +void bndcl(const BoundsReg& bnd, const Operand& op) { opRext(op, i32e, bnd.getIdx(), T_F3 | T_0F, 0x1A, !op.isMEM()); } +void bndcn(const BoundsReg& bnd, const Operand& op) { opRext(op, i32e, bnd.getIdx(), T_F2 | T_0F, 0x1B, !op.isMEM()); } +void bndcu(const BoundsReg& bnd, const Operand& op) { opRext(op, i32e, bnd.getIdx(), T_F2 | T_0F, 0x1A, !op.isMEM()); } +void bndldx(const BoundsReg& bnd, const Address& addr) { opMIB(addr, bnd, T_0F, 0x1A); } +void bndmk(const BoundsReg& bnd, const Address& addr) { opMR(addr, bnd, T_F3 | T_0F, 0x1B); } +void bndmov(const Address& addr, const BoundsReg& bnd) { opMR(addr, bnd, T_66 | T_0F, 0x1B); } +void bndmov(const BoundsReg& bnd, const Operand& op) { opRO(bnd, op, T_66 | T_0F, 0x1A, op.isBNDREG()); } +void bndstx(const Address& addr, const BoundsReg& bnd) { opMIB(addr, bnd, T_0F, 0x1B); } +void bsf(const Reg®, const Operand& op) { opRO(reg, op, T_0F, 0xBC, op.isREG(16|i32e)); } +void bsr(const Reg®, const Operand& op) { opRO(reg, op, T_0F, 0xBD, op.isREG(16|i32e)); } +void bt(const Operand& op, const Reg& reg) { opRO(reg, op, T_0F, 0xA3, op.isREG(16|i32e) && op.getBit() == reg.getBit()); } +void bt(const Operand& op, uint8_t imm) { opRext(op, 16|i32e, 4, T_0F, 0xba, false, 1); db(imm); } +void btc(const Operand& op, const Reg& reg) { opRO(reg, op, T_0F, 0xBB, op.isREG(16|i32e) && op.getBit() == reg.getBit()); } +void btc(const Operand& op, uint8_t imm) { opRext(op, 16|i32e, 7, T_0F, 0xba, false, 1); db(imm); } +void btr(const Operand& op, const Reg& reg) { opRO(reg, op, T_0F, 0xB3, op.isREG(16|i32e) && op.getBit() == reg.getBit()); } +void btr(const Operand& op, uint8_t imm) { opRext(op, 16|i32e, 6, T_0F, 0xba, false, 1); db(imm); } +void bts(const Operand& op, const Reg& reg) { opRO(reg, op, T_0F, 0xAB, op.isREG(16|i32e) && op.getBit() == reg.getBit()); } +void bts(const Operand& op, uint8_t imm) { opRext(op, 16|i32e, 5, T_0F, 0xba, false, 1); db(imm); } +void bzhi(const Reg32e& r1, const Operand& op, const Reg32e& r2) { opRRO(r1, r2, op, T_APX|T_0F38|T_NF, 0xf5); } +void cbw() { db(0x66); db(0x98); } +void ccmpa(const Operand& op, int imm, int dfv = 0) { opCcmpi(op, imm, dfv, 7); } +void ccmpa(const Operand& op1, const Operand& op2, int dfv = 0) { opCcmp(op1, op2, dfv, 0x38, 7); } +void ccmpae(const Operand& op, int imm, int dfv = 0) { opCcmpi(op, imm, dfv, 3); } +void ccmpae(const Operand& op1, const Operand& op2, int dfv = 0) { opCcmp(op1, op2, dfv, 0x38, 3); } +void ccmpb(const Operand& op, int imm, int dfv = 0) { opCcmpi(op, imm, dfv, 2); } +void ccmpb(const Operand& op1, const Operand& op2, int dfv = 0) { opCcmp(op1, op2, dfv, 0x38, 2); } +void ccmpbe(const Operand& op, int imm, int dfv = 0) { opCcmpi(op, imm, dfv, 6); } +void ccmpbe(const Operand& op1, const Operand& op2, int dfv = 0) { opCcmp(op1, op2, dfv, 0x38, 6); } +void ccmpc(const Operand& op, int imm, int dfv = 0) { opCcmpi(op, imm, dfv, 2); } +void ccmpc(const Operand& op1, const Operand& op2, int dfv = 0) { opCcmp(op1, op2, dfv, 0x38, 2); } +void ccmpe(const Operand& op, int imm, int dfv = 0) { opCcmpi(op, imm, dfv, 4); } +void ccmpe(const Operand& op1, const Operand& op2, int dfv = 0) { opCcmp(op1, op2, dfv, 0x38, 4); } +void ccmpf(const Operand& op, int imm, int dfv = 0) { opCcmpi(op, imm, dfv, 11); } +void ccmpf(const Operand& op1, const Operand& op2, int dfv = 0) { opCcmp(op1, op2, dfv, 0x38, 11); } +void ccmpg(const Operand& op, int imm, int dfv = 0) { opCcmpi(op, imm, dfv, 15); } +void ccmpg(const Operand& op1, const Operand& op2, int dfv = 0) { opCcmp(op1, op2, dfv, 0x38, 15); } +void ccmpge(const Operand& op, int imm, int dfv = 0) { opCcmpi(op, imm, dfv, 13); } +void ccmpge(const Operand& op1, const Operand& op2, int dfv = 0) { opCcmp(op1, op2, dfv, 0x38, 13); } +void ccmpl(const Operand& op, int imm, int dfv = 0) { opCcmpi(op, imm, dfv, 12); } +void ccmpl(const Operand& op1, const Operand& op2, int dfv = 0) { opCcmp(op1, op2, dfv, 0x38, 12); } +void ccmple(const Operand& op, int imm, int dfv = 0) { opCcmpi(op, imm, dfv, 14); } +void ccmple(const Operand& op1, const Operand& op2, int dfv = 0) { opCcmp(op1, op2, dfv, 0x38, 14); } +void ccmpna(const Operand& op, int imm, int dfv = 0) { opCcmpi(op, imm, dfv, 6); } +void ccmpna(const Operand& op1, const Operand& op2, int dfv = 0) { opCcmp(op1, op2, dfv, 0x38, 6); } +void ccmpnae(const Operand& op, int imm, int dfv = 0) { opCcmpi(op, imm, dfv, 2); } +void ccmpnae(const Operand& op1, const Operand& op2, int dfv = 0) { opCcmp(op1, op2, dfv, 0x38, 2); } +void ccmpnb(const Operand& op, int imm, int dfv = 0) { opCcmpi(op, imm, dfv, 3); } +void ccmpnb(const Operand& op1, const Operand& op2, int dfv = 0) { opCcmp(op1, op2, dfv, 0x38, 3); } +void ccmpnbe(const Operand& op, int imm, int dfv = 0) { opCcmpi(op, imm, dfv, 7); } +void ccmpnbe(const Operand& op1, const Operand& op2, int dfv = 0) { opCcmp(op1, op2, dfv, 0x38, 7); } +void ccmpnc(const Operand& op, int imm, int dfv = 0) { opCcmpi(op, imm, dfv, 3); } +void ccmpnc(const Operand& op1, const Operand& op2, int dfv = 0) { opCcmp(op1, op2, dfv, 0x38, 3); } +void ccmpne(const Operand& op, int imm, int dfv = 0) { opCcmpi(op, imm, dfv, 5); } +void ccmpne(const Operand& op1, const Operand& op2, int dfv = 0) { opCcmp(op1, op2, dfv, 0x38, 5); } +void ccmpng(const Operand& op, int imm, int dfv = 0) { opCcmpi(op, imm, dfv, 14); } +void ccmpng(const Operand& op1, const Operand& op2, int dfv = 0) { opCcmp(op1, op2, dfv, 0x38, 14); } +void ccmpnge(const Operand& op, int imm, int dfv = 0) { opCcmpi(op, imm, dfv, 12); } +void ccmpnge(const Operand& op1, const Operand& op2, int dfv = 0) { opCcmp(op1, op2, dfv, 0x38, 12); } +void ccmpnl(const Operand& op, int imm, int dfv = 0) { opCcmpi(op, imm, dfv, 13); } +void ccmpnl(const Operand& op1, const Operand& op2, int dfv = 0) { opCcmp(op1, op2, dfv, 0x38, 13); } +void ccmpnle(const Operand& op, int imm, int dfv = 0) { opCcmpi(op, imm, dfv, 15); } +void ccmpnle(const Operand& op1, const Operand& op2, int dfv = 0) { opCcmp(op1, op2, dfv, 0x38, 15); } +void ccmpno(const Operand& op, int imm, int dfv = 0) { opCcmpi(op, imm, dfv, 1); } +void ccmpno(const Operand& op1, const Operand& op2, int dfv = 0) { opCcmp(op1, op2, dfv, 0x38, 1); } +void ccmpns(const Operand& op, int imm, int dfv = 0) { opCcmpi(op, imm, dfv, 9); } +void ccmpns(const Operand& op1, const Operand& op2, int dfv = 0) { opCcmp(op1, op2, dfv, 0x38, 9); } +void ccmpnz(const Operand& op, int imm, int dfv = 0) { opCcmpi(op, imm, dfv, 5); } +void ccmpnz(const Operand& op1, const Operand& op2, int dfv = 0) { opCcmp(op1, op2, dfv, 0x38, 5); } +void ccmpo(const Operand& op, int imm, int dfv = 0) { opCcmpi(op, imm, dfv, 0); } +void ccmpo(const Operand& op1, const Operand& op2, int dfv = 0) { opCcmp(op1, op2, dfv, 0x38, 0); } +void ccmps(const Operand& op, int imm, int dfv = 0) { opCcmpi(op, imm, dfv, 8); } +void ccmps(const Operand& op1, const Operand& op2, int dfv = 0) { opCcmp(op1, op2, dfv, 0x38, 8); } +void ccmpt(const Operand& op, int imm, int dfv = 0) { opCcmpi(op, imm, dfv, 10); } +void ccmpt(const Operand& op1, const Operand& op2, int dfv = 0) { opCcmp(op1, op2, dfv, 0x38, 10); } +void ccmpz(const Operand& op, int imm, int dfv = 0) { opCcmpi(op, imm, dfv, 4); } +void ccmpz(const Operand& op1, const Operand& op2, int dfv = 0) { opCcmp(op1, op2, dfv, 0x38, 4); } +void cdq() { db(0x99); } +void cfcmovb(const Operand& op1, const Operand& op2) { opCfcmov(Reg(), op1, op2, 0x42); } +void cfcmovb(const Reg& d, const Reg& r, const Operand& op) { opCfcmov(d|T_nf, op, r, 0x42); } +void cfcmovbe(const Operand& op1, const Operand& op2) { opCfcmov(Reg(), op1, op2, 0x46); } +void cfcmovbe(const Reg& d, const Reg& r, const Operand& op) { opCfcmov(d|T_nf, op, r, 0x46); } +void cfcmovl(const Operand& op1, const Operand& op2) { opCfcmov(Reg(), op1, op2, 0x4C); } +void cfcmovl(const Reg& d, const Reg& r, const Operand& op) { opCfcmov(d|T_nf, op, r, 0x4C); } +void cfcmovle(const Operand& op1, const Operand& op2) { opCfcmov(Reg(), op1, op2, 0x4E); } +void cfcmovle(const Reg& d, const Reg& r, const Operand& op) { opCfcmov(d|T_nf, op, r, 0x4E); } +void cfcmovnb(const Operand& op1, const Operand& op2) { opCfcmov(Reg(), op1, op2, 0x43); } +void cfcmovnb(const Reg& d, const Reg& r, const Operand& op) { opCfcmov(d|T_nf, op, r, 0x43); } +void cfcmovnbe(const Operand& op1, const Operand& op2) { opCfcmov(Reg(), op1, op2, 0x47); } +void cfcmovnbe(const Reg& d, const Reg& r, const Operand& op) { opCfcmov(d|T_nf, op, r, 0x47); } +void cfcmovnl(const Operand& op1, const Operand& op2) { opCfcmov(Reg(), op1, op2, 0x4D); } +void cfcmovnl(const Reg& d, const Reg& r, const Operand& op) { opCfcmov(d|T_nf, op, r, 0x4D); } +void cfcmovnle(const Operand& op1, const Operand& op2) { opCfcmov(Reg(), op1, op2, 0x4F); } +void cfcmovnle(const Reg& d, const Reg& r, const Operand& op) { opCfcmov(d|T_nf, op, r, 0x4F); } +void cfcmovno(const Operand& op1, const Operand& op2) { opCfcmov(Reg(), op1, op2, 0x41); } +void cfcmovno(const Reg& d, const Reg& r, const Operand& op) { opCfcmov(d|T_nf, op, r, 0x41); } +void cfcmovnp(const Operand& op1, const Operand& op2) { opCfcmov(Reg(), op1, op2, 0x4B); } +void cfcmovnp(const Reg& d, const Reg& r, const Operand& op) { opCfcmov(d|T_nf, op, r, 0x4B); } +void cfcmovns(const Operand& op1, const Operand& op2) { opCfcmov(Reg(), op1, op2, 0x49); } +void cfcmovns(const Reg& d, const Reg& r, const Operand& op) { opCfcmov(d|T_nf, op, r, 0x49); } +void cfcmovnz(const Operand& op1, const Operand& op2) { opCfcmov(Reg(), op1, op2, 0x45); } +void cfcmovnz(const Reg& d, const Reg& r, const Operand& op) { opCfcmov(d|T_nf, op, r, 0x45); } +void cfcmovo(const Operand& op1, const Operand& op2) { opCfcmov(Reg(), op1, op2, 0x40); } +void cfcmovo(const Reg& d, const Reg& r, const Operand& op) { opCfcmov(d|T_nf, op, r, 0x40); } +void cfcmovp(const Operand& op1, const Operand& op2) { opCfcmov(Reg(), op1, op2, 0x4A); } +void cfcmovp(const Reg& d, const Reg& r, const Operand& op) { opCfcmov(d|T_nf, op, r, 0x4A); } +void cfcmovs(const Operand& op1, const Operand& op2) { opCfcmov(Reg(), op1, op2, 0x48); } +void cfcmovs(const Reg& d, const Reg& r, const Operand& op) { opCfcmov(d|T_nf, op, r, 0x48); } +void cfcmovz(const Operand& op1, const Operand& op2) { opCfcmov(Reg(), op1, op2, 0x44); } +void cfcmovz(const Reg& d, const Reg& r, const Operand& op) { opCfcmov(d|T_nf, op, r, 0x44); } +void clc() { db(0xF8); } +void cld() { db(0xFC); } +void cldemote(const Address& addr) { opMR(addr, eax, T_0F, 0x1C); } +void clflush(const Address& addr) { opMR(addr, Reg32(7), T_0F, 0xAE); } +void clflushopt(const Address& addr) { opMR(addr, Reg32(7), T_66 | T_0F, 0xAE); } +void cli() { db(0xFA); } +void clwb(const Address& addr) { opMR(addr, esi, T_66 | T_0F, 0xAE); } +void clzero() { db(0x0F); db(0x01); db(0xFC); } +void cmc() { db(0xF5); } +void cmova(const Reg& d, const Reg& reg, const Operand& op) { opROO(d, op, reg, T_APX|T_ND1, 0x40 | 7); }//-V524 +void cmova(const Reg& reg, const Operand& op) { opRO(reg, op, T_0F, 0x40 | 7, op.isREG(16|i32e)); }//-V524 +void cmovae(const Reg& d, const Reg& reg, const Operand& op) { opROO(d, op, reg, T_APX|T_ND1, 0x40 | 3); }//-V524 +void cmovae(const Reg& reg, const Operand& op) { opRO(reg, op, T_0F, 0x40 | 3, op.isREG(16|i32e)); }//-V524 +void cmovb(const Reg& d, const Reg& reg, const Operand& op) { opROO(d, op, reg, T_APX|T_ND1, 0x40 | 2); }//-V524 +void cmovb(const Reg& reg, const Operand& op) { opRO(reg, op, T_0F, 0x40 | 2, op.isREG(16|i32e)); }//-V524 +void cmovbe(const Reg& d, const Reg& reg, const Operand& op) { opROO(d, op, reg, T_APX|T_ND1, 0x40 | 6); }//-V524 +void cmovbe(const Reg& reg, const Operand& op) { opRO(reg, op, T_0F, 0x40 | 6, op.isREG(16|i32e)); }//-V524 +void cmovc(const Reg& d, const Reg& reg, const Operand& op) { opROO(d, op, reg, T_APX|T_ND1, 0x40 | 2); }//-V524 +void cmovc(const Reg& reg, const Operand& op) { opRO(reg, op, T_0F, 0x40 | 2, op.isREG(16|i32e)); }//-V524 +void cmove(const Reg& d, const Reg& reg, const Operand& op) { opROO(d, op, reg, T_APX|T_ND1, 0x40 | 4); }//-V524 +void cmove(const Reg& reg, const Operand& op) { opRO(reg, op, T_0F, 0x40 | 4, op.isREG(16|i32e)); }//-V524 +void cmovg(const Reg& d, const Reg& reg, const Operand& op) { opROO(d, op, reg, T_APX|T_ND1, 0x40 | 15); }//-V524 +void cmovg(const Reg& reg, const Operand& op) { opRO(reg, op, T_0F, 0x40 | 15, op.isREG(16|i32e)); }//-V524 +void cmovge(const Reg& d, const Reg& reg, const Operand& op) { opROO(d, op, reg, T_APX|T_ND1, 0x40 | 13); }//-V524 +void cmovge(const Reg& reg, const Operand& op) { opRO(reg, op, T_0F, 0x40 | 13, op.isREG(16|i32e)); }//-V524 +void cmovl(const Reg& d, const Reg& reg, const Operand& op) { opROO(d, op, reg, T_APX|T_ND1, 0x40 | 12); }//-V524 +void cmovl(const Reg& reg, const Operand& op) { opRO(reg, op, T_0F, 0x40 | 12, op.isREG(16|i32e)); }//-V524 +void cmovle(const Reg& d, const Reg& reg, const Operand& op) { opROO(d, op, reg, T_APX|T_ND1, 0x40 | 14); }//-V524 +void cmovle(const Reg& reg, const Operand& op) { opRO(reg, op, T_0F, 0x40 | 14, op.isREG(16|i32e)); }//-V524 +void cmovna(const Reg& d, const Reg& reg, const Operand& op) { opROO(d, op, reg, T_APX|T_ND1, 0x40 | 6); }//-V524 +void cmovna(const Reg& reg, const Operand& op) { opRO(reg, op, T_0F, 0x40 | 6, op.isREG(16|i32e)); }//-V524 +void cmovnae(const Reg& d, const Reg& reg, const Operand& op) { opROO(d, op, reg, T_APX|T_ND1, 0x40 | 2); }//-V524 +void cmovnae(const Reg& reg, const Operand& op) { opRO(reg, op, T_0F, 0x40 | 2, op.isREG(16|i32e)); }//-V524 +void cmovnb(const Reg& d, const Reg& reg, const Operand& op) { opROO(d, op, reg, T_APX|T_ND1, 0x40 | 3); }//-V524 +void cmovnb(const Reg& reg, const Operand& op) { opRO(reg, op, T_0F, 0x40 | 3, op.isREG(16|i32e)); }//-V524 +void cmovnbe(const Reg& d, const Reg& reg, const Operand& op) { opROO(d, op, reg, T_APX|T_ND1, 0x40 | 7); }//-V524 +void cmovnbe(const Reg& reg, const Operand& op) { opRO(reg, op, T_0F, 0x40 | 7, op.isREG(16|i32e)); }//-V524 +void cmovnc(const Reg& d, const Reg& reg, const Operand& op) { opROO(d, op, reg, T_APX|T_ND1, 0x40 | 3); }//-V524 +void cmovnc(const Reg& reg, const Operand& op) { opRO(reg, op, T_0F, 0x40 | 3, op.isREG(16|i32e)); }//-V524 +void cmovne(const Reg& d, const Reg& reg, const Operand& op) { opROO(d, op, reg, T_APX|T_ND1, 0x40 | 5); }//-V524 +void cmovne(const Reg& reg, const Operand& op) { opRO(reg, op, T_0F, 0x40 | 5, op.isREG(16|i32e)); }//-V524 +void cmovng(const Reg& d, const Reg& reg, const Operand& op) { opROO(d, op, reg, T_APX|T_ND1, 0x40 | 14); }//-V524 +void cmovng(const Reg& reg, const Operand& op) { opRO(reg, op, T_0F, 0x40 | 14, op.isREG(16|i32e)); }//-V524 +void cmovnge(const Reg& d, const Reg& reg, const Operand& op) { opROO(d, op, reg, T_APX|T_ND1, 0x40 | 12); }//-V524 +void cmovnge(const Reg& reg, const Operand& op) { opRO(reg, op, T_0F, 0x40 | 12, op.isREG(16|i32e)); }//-V524 +void cmovnl(const Reg& d, const Reg& reg, const Operand& op) { opROO(d, op, reg, T_APX|T_ND1, 0x40 | 13); }//-V524 +void cmovnl(const Reg& reg, const Operand& op) { opRO(reg, op, T_0F, 0x40 | 13, op.isREG(16|i32e)); }//-V524 +void cmovnle(const Reg& d, const Reg& reg, const Operand& op) { opROO(d, op, reg, T_APX|T_ND1, 0x40 | 15); }//-V524 +void cmovnle(const Reg& reg, const Operand& op) { opRO(reg, op, T_0F, 0x40 | 15, op.isREG(16|i32e)); }//-V524 +void cmovno(const Reg& d, const Reg& reg, const Operand& op) { opROO(d, op, reg, T_APX|T_ND1, 0x40 | 1); }//-V524 +void cmovno(const Reg& reg, const Operand& op) { opRO(reg, op, T_0F, 0x40 | 1, op.isREG(16|i32e)); }//-V524 +void cmovnp(const Reg& d, const Reg& reg, const Operand& op) { opROO(d, op, reg, T_APX|T_ND1, 0x40 | 11); }//-V524 +void cmovnp(const Reg& reg, const Operand& op) { opRO(reg, op, T_0F, 0x40 | 11, op.isREG(16|i32e)); }//-V524 +void cmovns(const Reg& d, const Reg& reg, const Operand& op) { opROO(d, op, reg, T_APX|T_ND1, 0x40 | 9); }//-V524 +void cmovns(const Reg& reg, const Operand& op) { opRO(reg, op, T_0F, 0x40 | 9, op.isREG(16|i32e)); }//-V524 +void cmovnz(const Reg& d, const Reg& reg, const Operand& op) { opROO(d, op, reg, T_APX|T_ND1, 0x40 | 5); }//-V524 +void cmovnz(const Reg& reg, const Operand& op) { opRO(reg, op, T_0F, 0x40 | 5, op.isREG(16|i32e)); }//-V524 +void cmovo(const Reg& d, const Reg& reg, const Operand& op) { opROO(d, op, reg, T_APX|T_ND1, 0x40 | 0); }//-V524 +void cmovo(const Reg& reg, const Operand& op) { opRO(reg, op, T_0F, 0x40 | 0, op.isREG(16|i32e)); }//-V524 +void cmovp(const Reg& d, const Reg& reg, const Operand& op) { opROO(d, op, reg, T_APX|T_ND1, 0x40 | 10); }//-V524 +void cmovp(const Reg& reg, const Operand& op) { opRO(reg, op, T_0F, 0x40 | 10, op.isREG(16|i32e)); }//-V524 +void cmovpe(const Reg& d, const Reg& reg, const Operand& op) { opROO(d, op, reg, T_APX|T_ND1, 0x40 | 10); }//-V524 +void cmovpe(const Reg& reg, const Operand& op) { opRO(reg, op, T_0F, 0x40 | 10, op.isREG(16|i32e)); }//-V524 +void cmovpo(const Reg& d, const Reg& reg, const Operand& op) { opROO(d, op, reg, T_APX|T_ND1, 0x40 | 11); }//-V524 +void cmovpo(const Reg& reg, const Operand& op) { opRO(reg, op, T_0F, 0x40 | 11, op.isREG(16|i32e)); }//-V524 +void cmovs(const Reg& d, const Reg& reg, const Operand& op) { opROO(d, op, reg, T_APX|T_ND1, 0x40 | 8); }//-V524 +void cmovs(const Reg& reg, const Operand& op) { opRO(reg, op, T_0F, 0x40 | 8, op.isREG(16|i32e)); }//-V524 +void cmovz(const Reg& d, const Reg& reg, const Operand& op) { opROO(d, op, reg, T_APX|T_ND1, 0x40 | 4); }//-V524 +void cmovz(const Reg& reg, const Operand& op) { opRO(reg, op, T_0F, 0x40 | 4, op.isREG(16|i32e)); }//-V524 +void cmp(const Operand& op, uint32_t imm) { opOI(op, imm, 0x38, 7); } +void cmp(const Operand& op1, const Operand& op2) { opRO_MR(op1, op2, 0x38); } +void cmpeqpd(const Xmm& x, const Operand& op) { cmppd(x, op, 0); } +void cmpeqps(const Xmm& x, const Operand& op) { cmpps(x, op, 0); } +void cmpeqsd(const Xmm& x, const Operand& op) { cmpsd(x, op, 0); } +void cmpeqss(const Xmm& x, const Operand& op) { cmpss(x, op, 0); } +void cmplepd(const Xmm& x, const Operand& op) { cmppd(x, op, 2); } +void cmpleps(const Xmm& x, const Operand& op) { cmpps(x, op, 2); } +void cmplesd(const Xmm& x, const Operand& op) { cmpsd(x, op, 2); } +void cmpless(const Xmm& x, const Operand& op) { cmpss(x, op, 2); } +void cmpltpd(const Xmm& x, const Operand& op) { cmppd(x, op, 1); } +void cmpltps(const Xmm& x, const Operand& op) { cmpps(x, op, 1); } +void cmpltsd(const Xmm& x, const Operand& op) { cmpsd(x, op, 1); } +void cmpltss(const Xmm& x, const Operand& op) { cmpss(x, op, 1); } +void cmpneqpd(const Xmm& x, const Operand& op) { cmppd(x, op, 4); } +void cmpneqps(const Xmm& x, const Operand& op) { cmpps(x, op, 4); } +void cmpneqsd(const Xmm& x, const Operand& op) { cmpsd(x, op, 4); } +void cmpneqss(const Xmm& x, const Operand& op) { cmpss(x, op, 4); } +void cmpnlepd(const Xmm& x, const Operand& op) { cmppd(x, op, 6); } +void cmpnleps(const Xmm& x, const Operand& op) { cmpps(x, op, 6); } +void cmpnlesd(const Xmm& x, const Operand& op) { cmpsd(x, op, 6); } +void cmpnless(const Xmm& x, const Operand& op) { cmpss(x, op, 6); } +void cmpnltpd(const Xmm& x, const Operand& op) { cmppd(x, op, 5); } +void cmpnltps(const Xmm& x, const Operand& op) { cmpps(x, op, 5); } +void cmpnltsd(const Xmm& x, const Operand& op) { cmpsd(x, op, 5); } +void cmpnltss(const Xmm& x, const Operand& op) { cmpss(x, op, 5); } +void cmpordpd(const Xmm& x, const Operand& op) { cmppd(x, op, 7); } +void cmpordps(const Xmm& x, const Operand& op) { cmpps(x, op, 7); } +void cmpordsd(const Xmm& x, const Operand& op) { cmpsd(x, op, 7); } +void cmpordss(const Xmm& x, const Operand& op) { cmpss(x, op, 7); } +void cmppd(const Xmm& xmm, const Operand& op, uint8_t imm8) { opSSE(xmm, op, T_0F | T_66, 0xC2, isXMM_XMMorMEM, imm8); } +void cmpps(const Xmm& xmm, const Operand& op, uint8_t imm8) { opSSE(xmm, op, T_0F, 0xC2, isXMM_XMMorMEM, imm8); } +void cmpsb() { db(0xA6); } +void cmpsd() { db(0xA7); } +void cmpsd(const Xmm& xmm, const Operand& op, uint8_t imm8) { opSSE(xmm, op, T_0F | T_F2, 0xC2, isXMM_XMMorMEM, imm8); } +void cmpss(const Xmm& xmm, const Operand& op, uint8_t imm8) { opSSE(xmm, op, T_0F | T_F3, 0xC2, isXMM_XMMorMEM, imm8); } +void cmpsw() { db(0x66); db(0xA7); } +void cmpunordpd(const Xmm& x, const Operand& op) { cmppd(x, op, 3); } +void cmpunordps(const Xmm& x, const Operand& op) { cmpps(x, op, 3); } +void cmpunordsd(const Xmm& x, const Operand& op) { cmpsd(x, op, 3); } +void cmpunordss(const Xmm& x, const Operand& op) { cmpss(x, op, 3); } +void cmpxchg(const Operand& op, const Reg& reg) { opRO(reg, op, T_0F, 0xB0 | (reg.isBit(8) ? 0 : 1), op.getBit() == reg.getBit()); } +void cmpxchg8b(const Address& addr) { opMR(addr, Reg32(1), T_0F, 0xC7); } +void comisd(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_66|T_0F, 0x2F, isXMM_XMMorMEM); } +void comiss(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_0F, 0x2F, isXMM_XMMorMEM); } +void cpuid() { db(0x0F); db(0xA2); } +void crc32(const Reg32e& r, const Operand& op) { if (!((r.isBit(32) && op.isBit(8|16|32)) || (r.isBit(64) && op.isBit(8|64)))) XBYAK_THROW(ERR_BAD_SIZE_OF_REGISTER) int code = 0xF0 | (op.isBit(8) ? 0 : 1); uint64_t type = op.isBit(16) ? T_66:0; if (opROO(Reg(), op, static_cast(r), T_APX|type, code)) return; opRO(r, op, T_F2|T_0F38|type, code); } +void ctesta(const Operand& op, const Reg& r, int dfv = 0) { opCcmp(op, r, dfv, 0x84, 7); } +void ctesta(const Operand& op, int imm, int dfv = 0) { opTesti(op, imm, dfv, 7); } +void ctestae(const Operand& op, const Reg& r, int dfv = 0) { opCcmp(op, r, dfv, 0x84, 3); } +void ctestae(const Operand& op, int imm, int dfv = 0) { opTesti(op, imm, dfv, 3); } +void ctestb(const Operand& op, const Reg& r, int dfv = 0) { opCcmp(op, r, dfv, 0x84, 2); } +void ctestb(const Operand& op, int imm, int dfv = 0) { opTesti(op, imm, dfv, 2); } +void ctestbe(const Operand& op, const Reg& r, int dfv = 0) { opCcmp(op, r, dfv, 0x84, 6); } +void ctestbe(const Operand& op, int imm, int dfv = 0) { opTesti(op, imm, dfv, 6); } +void ctestc(const Operand& op, const Reg& r, int dfv = 0) { opCcmp(op, r, dfv, 0x84, 2); } +void ctestc(const Operand& op, int imm, int dfv = 0) { opTesti(op, imm, dfv, 2); } +void cteste(const Operand& op, const Reg& r, int dfv = 0) { opCcmp(op, r, dfv, 0x84, 4); } +void cteste(const Operand& op, int imm, int dfv = 0) { opTesti(op, imm, dfv, 4); } +void ctestf(const Operand& op, const Reg& r, int dfv = 0) { opCcmp(op, r, dfv, 0x84, 11); } +void ctestf(const Operand& op, int imm, int dfv = 0) { opTesti(op, imm, dfv, 11); } +void ctestg(const Operand& op, const Reg& r, int dfv = 0) { opCcmp(op, r, dfv, 0x84, 15); } +void ctestg(const Operand& op, int imm, int dfv = 0) { opTesti(op, imm, dfv, 15); } +void ctestge(const Operand& op, const Reg& r, int dfv = 0) { opCcmp(op, r, dfv, 0x84, 13); } +void ctestge(const Operand& op, int imm, int dfv = 0) { opTesti(op, imm, dfv, 13); } +void ctestl(const Operand& op, const Reg& r, int dfv = 0) { opCcmp(op, r, dfv, 0x84, 12); } +void ctestl(const Operand& op, int imm, int dfv = 0) { opTesti(op, imm, dfv, 12); } +void ctestle(const Operand& op, const Reg& r, int dfv = 0) { opCcmp(op, r, dfv, 0x84, 14); } +void ctestle(const Operand& op, int imm, int dfv = 0) { opTesti(op, imm, dfv, 14); } +void ctestna(const Operand& op, const Reg& r, int dfv = 0) { opCcmp(op, r, dfv, 0x84, 6); } +void ctestna(const Operand& op, int imm, int dfv = 0) { opTesti(op, imm, dfv, 6); } +void ctestnae(const Operand& op, const Reg& r, int dfv = 0) { opCcmp(op, r, dfv, 0x84, 2); } +void ctestnae(const Operand& op, int imm, int dfv = 0) { opTesti(op, imm, dfv, 2); } +void ctestnb(const Operand& op, const Reg& r, int dfv = 0) { opCcmp(op, r, dfv, 0x84, 3); } +void ctestnb(const Operand& op, int imm, int dfv = 0) { opTesti(op, imm, dfv, 3); } +void ctestnbe(const Operand& op, const Reg& r, int dfv = 0) { opCcmp(op, r, dfv, 0x84, 7); } +void ctestnbe(const Operand& op, int imm, int dfv = 0) { opTesti(op, imm, dfv, 7); } +void ctestnc(const Operand& op, const Reg& r, int dfv = 0) { opCcmp(op, r, dfv, 0x84, 3); } +void ctestnc(const Operand& op, int imm, int dfv = 0) { opTesti(op, imm, dfv, 3); } +void ctestne(const Operand& op, const Reg& r, int dfv = 0) { opCcmp(op, r, dfv, 0x84, 5); } +void ctestne(const Operand& op, int imm, int dfv = 0) { opTesti(op, imm, dfv, 5); } +void ctestng(const Operand& op, const Reg& r, int dfv = 0) { opCcmp(op, r, dfv, 0x84, 14); } +void ctestng(const Operand& op, int imm, int dfv = 0) { opTesti(op, imm, dfv, 14); } +void ctestnge(const Operand& op, const Reg& r, int dfv = 0) { opCcmp(op, r, dfv, 0x84, 12); } +void ctestnge(const Operand& op, int imm, int dfv = 0) { opTesti(op, imm, dfv, 12); } +void ctestnl(const Operand& op, const Reg& r, int dfv = 0) { opCcmp(op, r, dfv, 0x84, 13); } +void ctestnl(const Operand& op, int imm, int dfv = 0) { opTesti(op, imm, dfv, 13); } +void ctestnle(const Operand& op, const Reg& r, int dfv = 0) { opCcmp(op, r, dfv, 0x84, 15); } +void ctestnle(const Operand& op, int imm, int dfv = 0) { opTesti(op, imm, dfv, 15); } +void ctestno(const Operand& op, const Reg& r, int dfv = 0) { opCcmp(op, r, dfv, 0x84, 1); } +void ctestno(const Operand& op, int imm, int dfv = 0) { opTesti(op, imm, dfv, 1); } +void ctestns(const Operand& op, const Reg& r, int dfv = 0) { opCcmp(op, r, dfv, 0x84, 9); } +void ctestns(const Operand& op, int imm, int dfv = 0) { opTesti(op, imm, dfv, 9); } +void ctestnz(const Operand& op, const Reg& r, int dfv = 0) { opCcmp(op, r, dfv, 0x84, 5); } +void ctestnz(const Operand& op, int imm, int dfv = 0) { opTesti(op, imm, dfv, 5); } +void ctesto(const Operand& op, const Reg& r, int dfv = 0) { opCcmp(op, r, dfv, 0x84, 0); } +void ctesto(const Operand& op, int imm, int dfv = 0) { opTesti(op, imm, dfv, 0); } +void ctests(const Operand& op, const Reg& r, int dfv = 0) { opCcmp(op, r, dfv, 0x84, 8); } +void ctests(const Operand& op, int imm, int dfv = 0) { opTesti(op, imm, dfv, 8); } +void ctestt(const Operand& op, const Reg& r, int dfv = 0) { opCcmp(op, r, dfv, 0x84, 10); } +void ctestt(const Operand& op, int imm, int dfv = 0) { opTesti(op, imm, dfv, 10); } +void ctestz(const Operand& op, const Reg& r, int dfv = 0) { opCcmp(op, r, dfv, 0x84, 4); } +void ctestz(const Operand& op, int imm, int dfv = 0) { opTesti(op, imm, dfv, 4); } +void cvtdq2pd(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_F3|T_0F, 0xE6, isXMM_XMMorMEM); } +void cvtdq2ps(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_0F, 0x5B, isXMM_XMMorMEM); } +void cvtpd2dq(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_F2|T_0F, 0xE6, isXMM_XMMorMEM); } +void cvtpd2pi(const Reg& reg, const Operand& op) { opSSE(reg, op, T_66|T_0F, 0x2D, isMMX_XMMorMEM); } +void cvtpd2ps(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_66|T_0F, 0x5A, isXMM_XMMorMEM); } +void cvtpi2pd(const Reg& reg, const Operand& op) { opSSE(reg, op, T_66|T_0F, 0x2A, isXMM_MMXorMEM); } +void cvtpi2ps(const Reg& reg, const Operand& op) { opSSE(reg, op, T_0F, 0x2A, isXMM_MMXorMEM); } +void cvtps2dq(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_66|T_0F, 0x5B, isXMM_XMMorMEM); } +void cvtps2pd(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_0F, 0x5A, isXMM_XMMorMEM); } +void cvtps2pi(const Reg& reg, const Operand& op) { opSSE(reg, op, T_0F, 0x2D, isMMX_XMMorMEM); } +void cvtsd2si(const Reg& reg, const Operand& op) { opSSE(reg, op, T_F2|T_0F, 0x2D, isREG32_XMMorMEM); } +void cvtsd2ss(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_F2|T_0F, 0x5A, isXMM_XMMorMEM); } +void cvtsi2sd(const Reg& reg, const Operand& op) { opSSE(reg, op, T_F2|T_0F, 0x2A, isXMM_REG32orMEM); } +void cvtsi2ss(const Reg& reg, const Operand& op) { opSSE(reg, op, T_F3|T_0F, 0x2A, isXMM_REG32orMEM); } +void cvtss2sd(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_F3|T_0F, 0x5A, isXMM_XMMorMEM); } +void cvtss2si(const Reg& reg, const Operand& op) { opSSE(reg, op, T_F3|T_0F, 0x2D, isREG32_XMMorMEM); } +void cvttpd2dq(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_66|T_0F, 0xE6, isXMM_XMMorMEM); } +void cvttpd2pi(const Reg& reg, const Operand& op) { opSSE(reg, op, T_66|T_0F, 0x2C, isMMX_XMMorMEM); } +void cvttps2dq(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_F3|T_0F, 0x5B, isXMM_XMMorMEM); } +void cvttps2pi(const Reg& reg, const Operand& op) { opSSE(reg, op, T_0F, 0x2C, isMMX_XMMorMEM); } +void cvttsd2si(const Reg& reg, const Operand& op) { opSSE(reg, op, T_F2|T_0F, 0x2C, isREG32_XMMorMEM); } +void cvttss2si(const Reg& reg, const Operand& op) { opSSE(reg, op, T_F3|T_0F, 0x2C, isREG32_XMMorMEM); } +void cwd() { db(0x66); db(0x99); } +void cwde() { db(0x98); } +void dec(const Operand& op) { opIncDec(Reg(), op, 1); } +void dec(const Reg& d, const Operand& op) { opIncDec(d, op, 1); } +void div(const Operand& op) { opRext(op, 0, 6, T_APX|T_NF|T_CODE1_IF1, 0xF6); } +void divpd(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_0F | T_66, 0x5E, isXMM_XMMorMEM); } +void divps(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_0F, 0x5E, isXMM_XMMorMEM); } +void divsd(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_0F | T_F2, 0x5E, isXMM_XMMorMEM); } +void divss(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_0F | T_F3, 0x5E, isXMM_XMMorMEM); } +void dppd(const Xmm& xmm, const Operand& op, int imm) { opSSE(xmm, op, T_66 | T_0F3A, 0x41, isXMM_XMMorMEM, static_cast(imm)); } +void dpps(const Xmm& xmm, const Operand& op, int imm) { opSSE(xmm, op, T_66 | T_0F3A, 0x40, isXMM_XMMorMEM, static_cast(imm)); } +void emms() { db(0x0F); db(0x77); } +void endbr32() { db(0xF3); db(0x0F); db(0x1E); db(0xFB); } +void endbr64() { db(0xF3); db(0x0F); db(0x1E); db(0xFA); } +void enter(uint16_t x, uint8_t y) { db(0xC8); dw(x); db(y); } +void extractps(const Operand& op, const Xmm& xmm, uint8_t imm) { opExt(op, xmm, 0x17, imm); } +void f2xm1() { db(0xD9); db(0xF0); } +void fabs() { db(0xD9); db(0xE1); } +void fadd(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 0, 0); } +void fadd(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xD8C0, 0xDCC0); } +void fadd(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xD8C0, 0xDCC0); } +void faddp() { db(0xDE); db(0xC1); } +void faddp(const Fpu& reg1) { opFpuFpu(reg1, st0, 0x0000, 0xDEC0); } +void faddp(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0x0000, 0xDEC0); } +void fbld(const Address& addr) { opMR(addr, Reg32(4), 0, 0xDF); } +void fbstp(const Address& addr) { opMR(addr, Reg32(6), 0, 0xDF); } +void fchs() { db(0xD9); db(0xE0); } +void fclex() { db(0x9B); db(0xDB); db(0xE2); } +void fcmovb(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDAC0, 0x00C0); } +void fcmovb(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDAC0, 0x00C0); } +void fcmovbe(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDAD0, 0x00D0); } +void fcmovbe(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDAD0, 0x00D0); } +void fcmove(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDAC8, 0x00C8); } +void fcmove(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDAC8, 0x00C8); } +void fcmovnb(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDBC0, 0x00C0); } +void fcmovnb(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDBC0, 0x00C0); } +void fcmovnbe(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDBD0, 0x00D0); } +void fcmovnbe(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDBD0, 0x00D0); } +void fcmovne(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDBC8, 0x00C8); } +void fcmovne(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDBC8, 0x00C8); } +void fcmovnu(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDBD8, 0x00D8); } +void fcmovnu(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDBD8, 0x00D8); } +void fcmovu(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDAD8, 0x00D8); } +void fcmovu(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDAD8, 0x00D8); } +void fcom() { db(0xD8); db(0xD1); } +void fcom(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 2, 0); } +void fcom(const Fpu& reg) { opFpu(reg, 0xD8, 0xD0); } +void fcomi(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDBF0, 0x00F0); } +void fcomi(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDBF0, 0x00F0); } +void fcomip(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDFF0, 0x00F0); } +void fcomip(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDFF0, 0x00F0); } +void fcomp() { db(0xD8); db(0xD9); } +void fcomp(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 3, 0); } +void fcomp(const Fpu& reg) { opFpu(reg, 0xD8, 0xD8); } +void fcompp() { db(0xDE); db(0xD9); } +void fcos() { db(0xD9); db(0xFF); } +void fdecstp() { db(0xD9); db(0xF6); } +void fdiv(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 6, 0); } +void fdiv(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xD8F0, 0xDCF8); } +void fdiv(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xD8F0, 0xDCF8); } +void fdivp() { db(0xDE); db(0xF9); } +void fdivp(const Fpu& reg1) { opFpuFpu(reg1, st0, 0x0000, 0xDEF8); } +void fdivp(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0x0000, 0xDEF8); } +void fdivr(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 7, 0); } +void fdivr(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xD8F8, 0xDCF0); } +void fdivr(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xD8F8, 0xDCF0); } +void fdivrp() { db(0xDE); db(0xF1); } +void fdivrp(const Fpu& reg1) { opFpuFpu(reg1, st0, 0x0000, 0xDEF0); } +void fdivrp(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0x0000, 0xDEF0); } +void ffree(const Fpu& reg) { opFpu(reg, 0xDD, 0xC0); } +void fiadd(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 0, 0); } +void ficom(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 2, 0); } +void ficomp(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 3, 0); } +void fidiv(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 6, 0); } +void fidivr(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 7, 0); } +void fild(const Address& addr) { opFpuMem(addr, 0xDF, 0xDB, 0xDF, 0, 5); } +void fimul(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 1, 0); } +void fincstp() { db(0xD9); db(0xF7); } +void finit() { db(0x9B); db(0xDB); db(0xE3); } +void fist(const Address& addr) { opFpuMem(addr, 0xDF, 0xDB, 0x00, 2, 0); } +void fistp(const Address& addr) { opFpuMem(addr, 0xDF, 0xDB, 0xDF, 3, 7); } +void fisttp(const Address& addr) { opFpuMem(addr, 0xDF, 0xDB, 0xDD, 1, 0); } +void fisub(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 4, 0); } +void fisubr(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 5, 0); } +void fld(const Address& addr) { opFpuMem(addr, 0x00, 0xD9, 0xDD, 0, 0); } +void fld(const Fpu& reg) { opFpu(reg, 0xD9, 0xC0); } +void fld1() { db(0xD9); db(0xE8); } +void fldcw(const Address& addr) { opMR(addr, Reg32(5), 0, 0xD9); } +void fldenv(const Address& addr) { opMR(addr, Reg32(4), 0, 0xD9); } +void fldl2e() { db(0xD9); db(0xEA); } +void fldl2t() { db(0xD9); db(0xE9); } +void fldlg2() { db(0xD9); db(0xEC); } +void fldln2() { db(0xD9); db(0xED); } +void fldpi() { db(0xD9); db(0xEB); } +void fldz() { db(0xD9); db(0xEE); } +void fmul(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 1, 0); } +void fmul(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xD8C8, 0xDCC8); } +void fmul(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xD8C8, 0xDCC8); } +void fmulp() { db(0xDE); db(0xC9); } +void fmulp(const Fpu& reg1) { opFpuFpu(reg1, st0, 0x0000, 0xDEC8); } +void fmulp(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0x0000, 0xDEC8); } +void fnclex() { db(0xDB); db(0xE2); } +void fninit() { db(0xDB); db(0xE3); } +void fnop() { db(0xD9); db(0xD0); } +void fnsave(const Address& addr) { opMR(addr, Reg32(6), 0, 0xDD); } +void fnstcw(const Address& addr) { opMR(addr, Reg32(7), 0, 0xD9); } +void fnstenv(const Address& addr) { opMR(addr, Reg32(6), 0, 0xD9); } +void fnstsw(const Address& addr) { opMR(addr, Reg32(7), 0, 0xDD); } +void fnstsw(const Reg16& r) { if (r.getIdx() != Operand::AX) XBYAK_THROW(ERR_BAD_PARAMETER) db(0xDF); db(0xE0); } +void fpatan() { db(0xD9); db(0xF3); } +void fprem() { db(0xD9); db(0xF8); } +void fprem1() { db(0xD9); db(0xF5); } +void fptan() { db(0xD9); db(0xF2); } +void frndint() { db(0xD9); db(0xFC); } +void frstor(const Address& addr) { opMR(addr, Reg32(4), 0, 0xDD); } +void fsave(const Address& addr) { db(0x9B); opMR(addr, Reg32(6), 0, 0xDD); } +void fscale() { db(0xD9); db(0xFD); } +void fsin() { db(0xD9); db(0xFE); } +void fsincos() { db(0xD9); db(0xFB); } +void fsqrt() { db(0xD9); db(0xFA); } +void fst(const Address& addr) { opFpuMem(addr, 0x00, 0xD9, 0xDD, 2, 0); } +void fst(const Fpu& reg) { opFpu(reg, 0xDD, 0xD0); } +void fstcw(const Address& addr) { db(0x9B); opMR(addr, Reg32(7), 0, 0xD9); } +void fstenv(const Address& addr) { db(0x9B); opMR(addr, Reg32(6), 0, 0xD9); } +void fstp(const Address& addr) { opFpuMem(addr, 0x00, 0xD9, 0xDD, 3, 0); } +void fstp(const Fpu& reg) { opFpu(reg, 0xDD, 0xD8); } +void fstsw(const Address& addr) { db(0x9B); opMR(addr, Reg32(7), 0, 0xDD); } +void fstsw(const Reg16& r) { if (r.getIdx() != Operand::AX) XBYAK_THROW(ERR_BAD_PARAMETER) db(0x9B); db(0xDF); db(0xE0); } +void fsub(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 4, 0); } +void fsub(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xD8E0, 0xDCE8); } +void fsub(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xD8E0, 0xDCE8); } +void fsubp() { db(0xDE); db(0xE9); } +void fsubp(const Fpu& reg1) { opFpuFpu(reg1, st0, 0x0000, 0xDEE8); } +void fsubp(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0x0000, 0xDEE8); } +void fsubr(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 5, 0); } +void fsubr(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xD8E8, 0xDCE0); } +void fsubr(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xD8E8, 0xDCE0); } +void fsubrp() { db(0xDE); db(0xE1); } +void fsubrp(const Fpu& reg1) { opFpuFpu(reg1, st0, 0x0000, 0xDEE0); } +void fsubrp(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0x0000, 0xDEE0); } +void ftst() { db(0xD9); db(0xE4); } +void fucom() { db(0xDD); db(0xE1); } +void fucom(const Fpu& reg) { opFpu(reg, 0xDD, 0xE0); } +void fucomi(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDBE8, 0x00E8); } +void fucomi(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDBE8, 0x00E8); } +void fucomip(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDFE8, 0x00E8); } +void fucomip(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDFE8, 0x00E8); } +void fucomp() { db(0xDD); db(0xE9); } +void fucomp(const Fpu& reg) { opFpu(reg, 0xDD, 0xE8); } +void fucompp() { db(0xDA); db(0xE9); } +void fwait() { db(0x9B); } +void fxam() { db(0xD9); db(0xE5); } +void fxch() { db(0xD9); db(0xC9); } +void fxch(const Fpu& reg) { opFpu(reg, 0xD9, 0xC8); } +void fxrstor(const Address& addr) { opMR(addr, Reg32(1), T_0F, 0xAE); } +void fxtract() { db(0xD9); db(0xF4); } +void fyl2x() { db(0xD9); db(0xF1); } +void fyl2xp1() { db(0xD9); db(0xF9); } +void gf2p8affineinvqb(const Xmm& xmm, const Operand& op, int imm) { opSSE(xmm, op, T_66 | T_0F3A, 0xCF, isXMM_XMMorMEM, static_cast(imm)); } +void gf2p8affineqb(const Xmm& xmm, const Operand& op, int imm) { opSSE(xmm, op, T_66 | T_0F3A, 0xCE, isXMM_XMMorMEM, static_cast(imm)); } +void gf2p8mulb(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_66 | T_0F38, 0xCF, isXMM_XMMorMEM); } +void haddpd(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_66|T_0F|T_YMM, 0x7C, isXMM_XMMorMEM); } +void haddps(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_F2|T_0F|T_YMM, 0x7C, isXMM_XMMorMEM); } +void hlt() { db(0xF4); } +void hsubpd(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_66|T_0F|T_YMM, 0x7D, isXMM_XMMorMEM); } +void hsubps(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_F2|T_0F|T_YMM, 0x7D, isXMM_XMMorMEM); } +void idiv(const Operand& op) { opRext(op, 0, 7, T_APX|T_NF|T_CODE1_IF1, 0xF6); } +void imul(const Operand& op) { opRext(op, 0, 5, T_APX|T_NF|T_CODE1_IF1, 0xF6); } +void imul(const Reg& reg, const Operand& op) { if (opROO(Reg(), op, reg, T_APX|T_NF, 0xAF)) return; opRO(reg, op, T_0F, 0xAF, reg.getKind() == op.getKind()); } +void in_(const Reg& a, const Reg& d) { opInOut(a, d, 0xEC); } +void in_(const Reg& a, uint8_t v) { opInOut(a, 0xE4, v); } +void inc(const Operand& op) { opIncDec(Reg(), op, 0); } +void inc(const Reg& d, const Operand& op) { opIncDec(d, op, 0); } +void insertps(const Xmm& xmm, const Operand& op, uint8_t imm) { opSSE(xmm, op, T_66 | T_0F3A, 0x21, isXMM_XMMorMEM, imm); } +void int3() { db(0xCC); } +void int_(uint8_t x) { db(0xCD); db(x); } +void ja(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x77, 0x87, 0x0F); }//-V524 +void ja(const char *label, LabelType type = T_AUTO) { ja(std::string(label), type); }//-V524 +void ja(const void *addr) { opJmpAbs(addr, T_NEAR, 0x77, 0x87, 0x0F); }//-V524 +void ja(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x77, 0x87, 0x0F); }//-V524 +void jae(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x73, 0x83, 0x0F); }//-V524 +void jae(const char *label, LabelType type = T_AUTO) { jae(std::string(label), type); }//-V524 +void jae(const void *addr) { opJmpAbs(addr, T_NEAR, 0x73, 0x83, 0x0F); }//-V524 +void jae(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x73, 0x83, 0x0F); }//-V524 +void jb(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x72, 0x82, 0x0F); }//-V524 +void jb(const char *label, LabelType type = T_AUTO) { jb(std::string(label), type); }//-V524 +void jb(const void *addr) { opJmpAbs(addr, T_NEAR, 0x72, 0x82, 0x0F); }//-V524 +void jb(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x72, 0x82, 0x0F); }//-V524 +void jbe(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x76, 0x86, 0x0F); }//-V524 +void jbe(const char *label, LabelType type = T_AUTO) { jbe(std::string(label), type); }//-V524 +void jbe(const void *addr) { opJmpAbs(addr, T_NEAR, 0x76, 0x86, 0x0F); }//-V524 +void jbe(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x76, 0x86, 0x0F); }//-V524 +void jc(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x72, 0x82, 0x0F); }//-V524 +void jc(const char *label, LabelType type = T_AUTO) { jc(std::string(label), type); }//-V524 +void jc(const void *addr) { opJmpAbs(addr, T_NEAR, 0x72, 0x82, 0x0F); }//-V524 +void jc(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x72, 0x82, 0x0F); }//-V524 +void je(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x74, 0x84, 0x0F); }//-V524 +void je(const char *label, LabelType type = T_AUTO) { je(std::string(label), type); }//-V524 +void je(const void *addr) { opJmpAbs(addr, T_NEAR, 0x74, 0x84, 0x0F); }//-V524 +void je(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x74, 0x84, 0x0F); }//-V524 +void jg(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7F, 0x8F, 0x0F); }//-V524 +void jg(const char *label, LabelType type = T_AUTO) { jg(std::string(label), type); }//-V524 +void jg(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7F, 0x8F, 0x0F); }//-V524 +void jg(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7F, 0x8F, 0x0F); }//-V524 +void jge(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7D, 0x8D, 0x0F); }//-V524 +void jge(const char *label, LabelType type = T_AUTO) { jge(std::string(label), type); }//-V524 +void jge(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7D, 0x8D, 0x0F); }//-V524 +void jge(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7D, 0x8D, 0x0F); }//-V524 +void jl(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7C, 0x8C, 0x0F); }//-V524 +void jl(const char *label, LabelType type = T_AUTO) { jl(std::string(label), type); }//-V524 +void jl(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7C, 0x8C, 0x0F); }//-V524 +void jl(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7C, 0x8C, 0x0F); }//-V524 +void jle(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7E, 0x8E, 0x0F); }//-V524 +void jle(const char *label, LabelType type = T_AUTO) { jle(std::string(label), type); }//-V524 +void jle(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7E, 0x8E, 0x0F); }//-V524 +void jle(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7E, 0x8E, 0x0F); }//-V524 +void jna(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x76, 0x86, 0x0F); }//-V524 +void jna(const char *label, LabelType type = T_AUTO) { jna(std::string(label), type); }//-V524 +void jna(const void *addr) { opJmpAbs(addr, T_NEAR, 0x76, 0x86, 0x0F); }//-V524 +void jna(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x76, 0x86, 0x0F); }//-V524 +void jnae(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x72, 0x82, 0x0F); }//-V524 +void jnae(const char *label, LabelType type = T_AUTO) { jnae(std::string(label), type); }//-V524 +void jnae(const void *addr) { opJmpAbs(addr, T_NEAR, 0x72, 0x82, 0x0F); }//-V524 +void jnae(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x72, 0x82, 0x0F); }//-V524 +void jnb(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x73, 0x83, 0x0F); }//-V524 +void jnb(const char *label, LabelType type = T_AUTO) { jnb(std::string(label), type); }//-V524 +void jnb(const void *addr) { opJmpAbs(addr, T_NEAR, 0x73, 0x83, 0x0F); }//-V524 +void jnb(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x73, 0x83, 0x0F); }//-V524 +void jnbe(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x77, 0x87, 0x0F); }//-V524 +void jnbe(const char *label, LabelType type = T_AUTO) { jnbe(std::string(label), type); }//-V524 +void jnbe(const void *addr) { opJmpAbs(addr, T_NEAR, 0x77, 0x87, 0x0F); }//-V524 +void jnbe(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x77, 0x87, 0x0F); }//-V524 +void jnc(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x73, 0x83, 0x0F); }//-V524 +void jnc(const char *label, LabelType type = T_AUTO) { jnc(std::string(label), type); }//-V524 +void jnc(const void *addr) { opJmpAbs(addr, T_NEAR, 0x73, 0x83, 0x0F); }//-V524 +void jnc(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x73, 0x83, 0x0F); }//-V524 +void jne(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x75, 0x85, 0x0F); }//-V524 +void jne(const char *label, LabelType type = T_AUTO) { jne(std::string(label), type); }//-V524 +void jne(const void *addr) { opJmpAbs(addr, T_NEAR, 0x75, 0x85, 0x0F); }//-V524 +void jne(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x75, 0x85, 0x0F); }//-V524 +void jng(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7E, 0x8E, 0x0F); }//-V524 +void jng(const char *label, LabelType type = T_AUTO) { jng(std::string(label), type); }//-V524 +void jng(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7E, 0x8E, 0x0F); }//-V524 +void jng(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7E, 0x8E, 0x0F); }//-V524 +void jnge(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7C, 0x8C, 0x0F); }//-V524 +void jnge(const char *label, LabelType type = T_AUTO) { jnge(std::string(label), type); }//-V524 +void jnge(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7C, 0x8C, 0x0F); }//-V524 +void jnge(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7C, 0x8C, 0x0F); }//-V524 +void jnl(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7D, 0x8D, 0x0F); }//-V524 +void jnl(const char *label, LabelType type = T_AUTO) { jnl(std::string(label), type); }//-V524 +void jnl(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7D, 0x8D, 0x0F); }//-V524 +void jnl(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7D, 0x8D, 0x0F); }//-V524 +void jnle(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7F, 0x8F, 0x0F); }//-V524 +void jnle(const char *label, LabelType type = T_AUTO) { jnle(std::string(label), type); }//-V524 +void jnle(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7F, 0x8F, 0x0F); }//-V524 +void jnle(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7F, 0x8F, 0x0F); }//-V524 +void jno(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x71, 0x81, 0x0F); }//-V524 +void jno(const char *label, LabelType type = T_AUTO) { jno(std::string(label), type); }//-V524 +void jno(const void *addr) { opJmpAbs(addr, T_NEAR, 0x71, 0x81, 0x0F); }//-V524 +void jno(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x71, 0x81, 0x0F); }//-V524 +void jnp(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7B, 0x8B, 0x0F); }//-V524 +void jnp(const char *label, LabelType type = T_AUTO) { jnp(std::string(label), type); }//-V524 +void jnp(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7B, 0x8B, 0x0F); }//-V524 +void jnp(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7B, 0x8B, 0x0F); }//-V524 +void jns(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x79, 0x89, 0x0F); }//-V524 +void jns(const char *label, LabelType type = T_AUTO) { jns(std::string(label), type); }//-V524 +void jns(const void *addr) { opJmpAbs(addr, T_NEAR, 0x79, 0x89, 0x0F); }//-V524 +void jns(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x79, 0x89, 0x0F); }//-V524 +void jnz(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x75, 0x85, 0x0F); }//-V524 +void jnz(const char *label, LabelType type = T_AUTO) { jnz(std::string(label), type); }//-V524 +void jnz(const void *addr) { opJmpAbs(addr, T_NEAR, 0x75, 0x85, 0x0F); }//-V524 +void jnz(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x75, 0x85, 0x0F); }//-V524 +void jo(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x70, 0x80, 0x0F); }//-V524 +void jo(const char *label, LabelType type = T_AUTO) { jo(std::string(label), type); }//-V524 +void jo(const void *addr) { opJmpAbs(addr, T_NEAR, 0x70, 0x80, 0x0F); }//-V524 +void jo(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x70, 0x80, 0x0F); }//-V524 +void jp(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7A, 0x8A, 0x0F); }//-V524 +void jp(const char *label, LabelType type = T_AUTO) { jp(std::string(label), type); }//-V524 +void jp(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7A, 0x8A, 0x0F); }//-V524 +void jp(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7A, 0x8A, 0x0F); }//-V524 +void jpe(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7A, 0x8A, 0x0F); }//-V524 +void jpe(const char *label, LabelType type = T_AUTO) { jpe(std::string(label), type); }//-V524 +void jpe(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7A, 0x8A, 0x0F); }//-V524 +void jpe(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7A, 0x8A, 0x0F); }//-V524 +void jpo(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7B, 0x8B, 0x0F); }//-V524 +void jpo(const char *label, LabelType type = T_AUTO) { jpo(std::string(label), type); }//-V524 +void jpo(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7B, 0x8B, 0x0F); }//-V524 +void jpo(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7B, 0x8B, 0x0F); }//-V524 +void js(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x78, 0x88, 0x0F); }//-V524 +void js(const char *label, LabelType type = T_AUTO) { js(std::string(label), type); }//-V524 +void js(const void *addr) { opJmpAbs(addr, T_NEAR, 0x78, 0x88, 0x0F); }//-V524 +void js(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x78, 0x88, 0x0F); }//-V524 +void jz(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x74, 0x84, 0x0F); }//-V524 +void jz(const char *label, LabelType type = T_AUTO) { jz(std::string(label), type); }//-V524 +void jz(const void *addr) { opJmpAbs(addr, T_NEAR, 0x74, 0x84, 0x0F); }//-V524 +void jz(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x74, 0x84, 0x0F); }//-V524 +void lahf() { db(0x9F); } +void lddqu(const Xmm& xmm, const Address& addr) { opMR(addr, xmm, T_F2 | T_0F, 0xF0); } +void ldmxcsr(const Address& addr) { opMR(addr, Reg32(2), T_0F, 0xAE); } +void lea(const Reg& reg, const Address& addr) { if (!reg.isBit(16 | i32e)) XBYAK_THROW(ERR_BAD_SIZE_OF_REGISTER) opMR(addr, reg, 0, 0x8D); } +void leave() { db(0xC9); } +void lfence() { db(0x0F); db(0xAE); db(0xE8); } +void lfs(const Reg& reg, const Address& addr) { opLoadSeg(addr, reg, T_0F, 0xB4); } +void lgs(const Reg& reg, const Address& addr) { opLoadSeg(addr, reg, T_0F, 0xB5); } +void lock() { db(0xF0); } +void lodsb() { db(0xAC); } +void lodsd() { db(0xAD); } +void lodsw() { db(0x66); db(0xAD); } +void loop(const Label& label) { opJmp(label, T_SHORT, 0xE2, 0, 0); } +void loop(const char *label) { loop(std::string(label)); } +void loop(std::string label) { opJmp(label, T_SHORT, 0xE2, 0, 0); } +void loope(const Label& label) { opJmp(label, T_SHORT, 0xE1, 0, 0); } +void loope(const char *label) { loope(std::string(label)); } +void loope(std::string label) { opJmp(label, T_SHORT, 0xE1, 0, 0); } +void loopne(const Label& label) { opJmp(label, T_SHORT, 0xE0, 0, 0); } +void loopne(const char *label) { loopne(std::string(label)); } +void loopne(std::string label) { opJmp(label, T_SHORT, 0xE0, 0, 0); } +void lss(const Reg& reg, const Address& addr) { opLoadSeg(addr, reg, T_0F, 0xB2); } +void lzcnt(const Reg®, const Operand& op) { if (opROO(Reg(), op, reg, T_APX|T_NF, 0xF5)) return; opCnt(reg, op, 0xBD); } +void maskmovdqu(const Xmm& reg1, const Xmm& reg2) { opRR(reg1, reg2, T_66|T_0F, 0xF7); } +void maskmovq(const Mmx& reg1, const Mmx& reg2) { if (!reg1.isMMX() || !reg2.isMMX()) XBYAK_THROW(ERR_BAD_COMBINATION) opRR(reg1, reg2, T_0F, 0xF7); } +void maxpd(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_0F | T_66, 0x5F, isXMM_XMMorMEM); } +void maxps(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_0F, 0x5F, isXMM_XMMorMEM); } +void maxsd(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_0F | T_F2, 0x5F, isXMM_XMMorMEM); } +void maxss(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_0F | T_F3, 0x5F, isXMM_XMMorMEM); } +void mfence() { db(0x0F); db(0xAE); db(0xF0); } +void minpd(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_0F | T_66, 0x5D, isXMM_XMMorMEM); } +void minps(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_0F, 0x5D, isXMM_XMMorMEM); } +void minsd(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_0F | T_F2, 0x5D, isXMM_XMMorMEM); } +void minss(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_0F | T_F3, 0x5D, isXMM_XMMorMEM); } +void monitor() { db(0x0F); db(0x01); db(0xC8); } +void monitorx() { db(0x0F); db(0x01); db(0xFA); } +void movapd(const Address& addr, const Xmm& xmm) { opMR(addr, xmm, T_0F|T_66, 0x29); } +void movapd(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x28, T_0F, T_66); } +void movaps(const Address& addr, const Xmm& xmm) { opMR(addr, xmm, T_0F|T_NONE, 0x29); } +void movaps(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x28, T_0F, T_NONE); } +void movbe(const Address& addr, const Reg& reg) { opMR(addr, reg, T_0F38, 0xF1, T_APX, 0x61); } +void movbe(const Reg& reg, const Address& addr) { opMR(addr, reg, T_0F38, 0xF0, T_APX, 0x60); } +void movd(const Address& addr, const Mmx& mmx) { if (mmx.isXMM()) db(0x66); opMR(addr, mmx, T_0F, 0x7E); } +void movd(const Mmx& mmx, const Address& addr) { if (mmx.isXMM()) db(0x66); opMR(addr, mmx, T_0F, 0x6E); } +void movd(const Mmx& mmx, const Reg32& reg) { if (mmx.isXMM()) db(0x66); opRR(mmx, reg, T_0F, 0x6E); } +void movd(const Reg32& reg, const Mmx& mmx) { if (mmx.isXMM()) db(0x66); opRR(mmx, reg, T_0F, 0x7E); } +void movddup(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_DUP|T_F2|T_0F|T_EW1|T_YMM|T_EVEX|T_ER_X|T_ER_Y|T_ER_Z, 0x12, isXMM_XMMorMEM, NONE); } +void movdir64b(const Reg& reg, const Address& addr) { opMR(addr, reg.cvt32(), T_66|T_0F38, 0xF8, T_APX|T_66); } +void movdiri(const Address& addr, const Reg32e& reg) { opMR(addr, reg, T_0F38, 0xF9, T_APX); } +void movdq2q(const Mmx& mmx, const Xmm& xmm) { opRR(mmx, xmm, T_F2 | T_0F, 0xD6); } +void movdqa(const Address& addr, const Xmm& xmm) { opMR(addr, xmm, T_0F|T_66, 0x7F); } +void movdqa(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x6F, T_0F, T_66); } +void movdqu(const Address& addr, const Xmm& xmm) { opMR(addr, xmm, T_0F|T_F3, 0x7F); } +void movdqu(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x6F, T_0F, T_F3); } +void movhlps(const Xmm& reg1, const Xmm& reg2) { opRR(reg1, reg2, T_0F, 0x12); } +void movhpd(const Operand& op1, const Operand& op2) { opMovXMM(op1, op2, T_66|T_0F, 0x16); } +void movhps(const Operand& op1, const Operand& op2) { opMovXMM(op1, op2, T_0F, 0x16); } +void movlhps(const Xmm& reg1, const Xmm& reg2) { opRR(reg1, reg2, T_0F, 0x16); } +void movlpd(const Operand& op1, const Operand& op2) { opMovXMM(op1, op2, T_66|T_0F, 0x12); } +void movlps(const Operand& op1, const Operand& op2) { opMovXMM(op1, op2, T_0F, 0x12); } +void movmskpd(const Reg32e& reg, const Xmm& xmm) { db(0x66); movmskps(reg, xmm); } +void movmskps(const Reg32e& reg, const Xmm& xmm) { opRR(reg, xmm, T_0F, 0x50); } +void movntdq(const Address& addr, const Xmm& reg) { opMR(addr, Reg16(reg.getIdx()), T_0F, 0xE7); } +void movntdqa(const Xmm& xmm, const Address& addr) { opMR(addr, xmm, T_66 | T_0F38, 0x2A); } +void movnti(const Address& addr, const Reg32e& reg) { opMR(addr, reg, T_0F, 0xC3); } +void movntpd(const Address& addr, const Xmm& reg) { opMR(addr, Reg16(reg.getIdx()), T_0F, 0x2B); } +void movntps(const Address& addr, const Xmm& xmm) { opMR(addr, Mmx(xmm.getIdx()), T_0F, 0x2B); } +void movntq(const Address& addr, const Mmx& mmx) { if (!mmx.isMMX()) XBYAK_THROW(ERR_BAD_COMBINATION) opMR(addr, mmx, T_0F, 0xE7); } +void movq(const Address& addr, const Mmx& mmx) { if (mmx.isXMM()) db(0x66); opMR(addr, mmx, T_0F, mmx.isXMM() ? 0xD6 : 0x7F); } +void movq(const Mmx& mmx, const Operand& op) { if (mmx.isXMM()) db(0xF3); opRO(mmx, op, T_0F, mmx.isXMM() ? 0x7E : 0x6F, mmx.getKind() == op.getKind()); } +void movq2dq(const Xmm& xmm, const Mmx& mmx) { opRR(xmm, mmx, T_F3 | T_0F, 0xD6); } +void movsb() { db(0xA4); } +void movsd() { db(0xA5); } +void movsd(const Address& addr, const Xmm& xmm) { opMR(addr, xmm, T_0F|T_F2, 0x11); } +void movsd(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x10, T_0F, T_F2); } +void movshdup(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_F3|T_0F|T_EW0|T_YMM|T_EVEX, 0x16, isXMM_XMMorMEM, NONE); } +void movsldup(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_F3|T_0F|T_EW0|T_YMM|T_EVEX, 0x12, isXMM_XMMorMEM, NONE); } +void movss(const Address& addr, const Xmm& xmm) { opMR(addr, xmm, T_0F|T_F3, 0x11); } +void movss(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x10, T_0F, T_F3); } +void movsw() { db(0x66); db(0xA5); } +void movsx(const Reg& reg, const Operand& op) { opMovxx(reg, op, 0xBE); } +void movupd(const Address& addr, const Xmm& xmm) { opMR(addr, xmm, T_0F|T_66, 0x11); } +void movupd(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x10, T_0F, T_66); } +void movups(const Address& addr, const Xmm& xmm) { opMR(addr, xmm, T_0F|T_NONE, 0x11); } +void movups(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x10, T_0F, T_NONE); } +void movzx(const Reg& reg, const Operand& op) { opMovxx(reg, op, 0xB6); } +void mpsadbw(const Xmm& xmm, const Operand& op, int imm) { opSSE(xmm, op, T_66 | T_0F3A, 0x42, isXMM_XMMorMEM, static_cast(imm)); } +void mul(const Operand& op) { opRext(op, 0, 4, T_APX|T_NF|T_CODE1_IF1, 0xF6); } +void mulpd(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_0F | T_66, 0x59, isXMM_XMMorMEM); } +void mulps(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_0F, 0x59, isXMM_XMMorMEM); } +void mulsd(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_0F | T_F2, 0x59, isXMM_XMMorMEM); } +void mulss(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_0F | T_F3, 0x59, isXMM_XMMorMEM); } +void mulx(const Reg32e& r1, const Reg32e& r2, const Operand& op) { opRRO(r1, r2, op, T_APX|T_F2|T_0F38, 0xf6); } +void mwait() { db(0x0F); db(0x01); db(0xC9); } +void mwaitx() { db(0x0F); db(0x01); db(0xFB); } +void neg(const Operand& op) { opRext(op, 0, 3, T_APX|T_NF|T_CODE1_IF1, 0xF6); } +void neg(const Reg& d, const Operand& op) { opROO(d, op, Reg(3, Operand::REG, d.getBit()), T_APX|T_NF|T_CODE1_IF1|T_ND1, 0xF6); } +void not_(const Operand& op) { opRext(op, 0, 2, T_APX|T_CODE1_IF1, 0xF6); } +void not_(const Reg& d, const Operand& op) { opROO(d, op, Reg(2, Operand::REG, d.getBit()), T_APX|T_CODE1_IF1|T_ND1, 0xF6); } +void or_(const Operand& op, uint32_t imm) { opOI(op, imm, 0x08, 1); } +void or_(const Operand& op1, const Operand& op2) { opRO_MR(op1, op2, 0x08); } +void or_(const Reg& d, const Operand& op, uint32_t imm) { opROI(d, op, imm, T_NF|T_CODE1_IF1, 1); } +void or_(const Reg& d, const Operand& op1, const Operand& op2) { opROO(d, op1, op2, T_NF|T_CODE1_IF1, 0x08); } +void orpd(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_0F | T_66, 0x56, isXMM_XMMorMEM); } +void orps(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_0F, 0x56, isXMM_XMMorMEM); } +void out_(const Reg& d, const Reg& a) { opInOut(a, d, 0xEE); } +void out_(uint8_t v, const Reg& a) { opInOut(a, 0xE6, v); } +void outsb() { db(0x6E); } +void outsd() { db(0x6F); } +void outsw() { db(0x66); db(0x6F); } +void pabsb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x1C, T_0F38, T_66); } +void pabsd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x1E, T_0F38, T_66); } +void pabsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x1D, T_0F38, T_66); } +void packssdw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x6B); } +void packsswb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x63); } +void packusdw(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_66 | T_0F38, 0x2B, isXMM_XMMorMEM); } +void packuswb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x67); } +void paddb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xFC); } +void paddd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xFE); } +void paddq(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xD4); } +void paddsb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xEC); } +void paddsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xED); } +void paddusb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xDC); } +void paddusw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xDD); } +void paddw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xFD); } +void palignr(const Mmx& mmx, const Operand& op, int imm) { opMMX(mmx, op, 0x0F, T_0F3A, T_66, static_cast(imm)); } +void pand(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xDB); } +void pandn(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xDF); } +void pause() { db(0xF3); db(0x90); } +void pavgb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE0); } +void pavgw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE3); } +void pblendvb(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_66|T_0F38, 0x10, isXMM_XMMorMEM, NONE); } +void pblendw(const Xmm& xmm, const Operand& op, int imm) { opSSE(xmm, op, T_66 | T_0F3A, 0x0E, isXMM_XMMorMEM, static_cast(imm)); } +void pclmulhqhqdq(const Xmm& xmm, const Operand& op) { pclmulqdq(xmm, op, 0x11); } +void pclmulhqlqdq(const Xmm& xmm, const Operand& op) { pclmulqdq(xmm, op, 0x01); } +void pclmullqhqdq(const Xmm& xmm, const Operand& op) { pclmulqdq(xmm, op, 0x10); } +void pclmullqlqdq(const Xmm& xmm, const Operand& op) { pclmulqdq(xmm, op, 0x00); } +void pclmulqdq(const Xmm& xmm, const Operand& op, int imm) { opSSE(xmm, op, T_66 | T_0F3A, 0x44, isXMM_XMMorMEM, static_cast(imm)); } +void pcmpeqb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x74); } +void pcmpeqd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x76); } +void pcmpeqq(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_66 | T_0F38, 0x29, isXMM_XMMorMEM); } +void pcmpeqw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x75); } +void pcmpestri(const Xmm& xmm, const Operand& op, uint8_t imm) { opSSE(xmm, op, T_66|T_0F3A, 0x61, isXMM_XMMorMEM, imm); } +void pcmpestrm(const Xmm& xmm, const Operand& op, uint8_t imm) { opSSE(xmm, op, T_66|T_0F3A, 0x60, isXMM_XMMorMEM, imm); } +void pcmpgtb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x64); } +void pcmpgtd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x66); } +void pcmpgtq(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_66 | T_0F38, 0x37, isXMM_XMMorMEM); } +void pcmpgtw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x65); } +void pcmpistri(const Xmm& xmm, const Operand& op, uint8_t imm) { opSSE(xmm, op, T_66|T_0F3A, 0x63, isXMM_XMMorMEM, imm); } +void pcmpistrm(const Xmm& xmm, const Operand& op, uint8_t imm) { opSSE(xmm, op, T_66|T_0F3A, 0x62, isXMM_XMMorMEM, imm); } +void pdep(const Reg32e& r1, const Reg32e& r2, const Operand& op) { opRRO(r1, r2, op, T_APX|T_F2|T_0F38, 0xf5); } +void pext(const Reg32e& r1, const Reg32e& r2, const Operand& op) { opRRO(r1, r2, op, T_APX|T_F3|T_0F38, 0xf5); } +void pextrb(const Operand& op, const Xmm& xmm, uint8_t imm) { opExt(op, xmm, 0x14, imm); } +void pextrd(const Operand& op, const Xmm& xmm, uint8_t imm) { opExt(op, xmm, 0x16, imm); } +void pextrw(const Operand& op, const Mmx& xmm, uint8_t imm) { opExt(op, xmm, 0x15, imm, true); } +void phaddd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x02, T_0F38, T_66); } +void phaddsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x03, T_0F38, T_66); } +void phaddw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x01, T_0F38, T_66); } +void phminposuw(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_66|T_0F38, 0x41, isXMM_XMMorMEM, NONE); } +void phsubd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x06, T_0F38, T_66); } +void phsubsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x07, T_0F38, T_66); } +void phsubw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x05, T_0F38, T_66); } +void pinsrb(const Xmm& xmm, const Operand& op, uint8_t imm) { opSSE(xmm, op, T_66 | T_0F3A, 0x20, isXMM_REG32orMEM, imm); } +void pinsrd(const Xmm& xmm, const Operand& op, uint8_t imm) { opSSE(xmm, op, T_66 | T_0F3A, 0x22, isXMM_REG32orMEM, imm); } +void pinsrw(const Mmx& mmx, const Operand& op, int imm) { if (!op.isREG(32) && !op.isMEM()) XBYAK_THROW(ERR_BAD_COMBINATION) opSSE(mmx, op, T_0F | (mmx.isXMM() ? T_66 : T_NONE), 0xC4, 0, imm); } +void pmaddubsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x04, T_0F38, T_66); } +void pmaddwd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF5); } +void pmaxsb(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_66 | T_0F38, 0x3C, isXMM_XMMorMEM); } +void pmaxsd(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_66 | T_0F38, 0x3D, isXMM_XMMorMEM); } +void pmaxsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xEE); } +void pmaxub(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xDE); } +void pmaxud(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_66 | T_0F38, 0x3F, isXMM_XMMorMEM); } +void pmaxuw(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_66 | T_0F38, 0x3E, isXMM_XMMorMEM); } +void pminsb(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_66 | T_0F38, 0x38, isXMM_XMMorMEM); } +void pminsd(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_66 | T_0F38, 0x39, isXMM_XMMorMEM); } +void pminsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xEA); } +void pminub(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xDA); } +void pminud(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_66 | T_0F38, 0x3B, isXMM_XMMorMEM); } +void pminuw(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_66 | T_0F38, 0x3A, isXMM_XMMorMEM); } +void pmovmskb(const Reg32e& reg, const Mmx& mmx) { if (mmx.isXMM()) db(0x66); opRR(reg, mmx, T_0F, 0xD7); } +void pmovsxbd(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_N4|T_N_VL|T_66|T_0F38|T_YMM|T_EVEX, 0x21, isXMM_XMMorMEM, NONE); } +void pmovsxbq(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_N2|T_N_VL|T_66|T_0F38|T_YMM|T_EVEX, 0x22, isXMM_XMMorMEM, NONE); } +void pmovsxbw(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_N8|T_N_VL|T_66|T_0F38|T_YMM|T_EVEX, 0x20, isXMM_XMMorMEM, NONE); } +void pmovsxdq(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_N8|T_N_VL|T_66|T_0F38|T_EW0|T_YMM|T_EVEX, 0x25, isXMM_XMMorMEM, NONE); } +void pmovsxwd(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_N8|T_N_VL|T_66|T_0F38|T_YMM|T_EVEX, 0x23, isXMM_XMMorMEM, NONE); } +void pmovsxwq(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_N4|T_N_VL|T_66|T_0F38|T_YMM|T_EVEX, 0x24, isXMM_XMMorMEM, NONE); } +void pmovzxbd(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_N4|T_N_VL|T_66|T_0F38|T_YMM|T_EVEX, 0x31, isXMM_XMMorMEM, NONE); } +void pmovzxbq(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_N2|T_N_VL|T_66|T_0F38|T_YMM|T_EVEX, 0x32, isXMM_XMMorMEM, NONE); } +void pmovzxbw(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_N8|T_N_VL|T_66|T_0F38|T_YMM|T_EVEX, 0x30, isXMM_XMMorMEM, NONE); } +void pmovzxdq(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_N8|T_N_VL|T_66|T_0F38|T_EW0|T_YMM|T_EVEX, 0x35, isXMM_XMMorMEM, NONE); } +void pmovzxwd(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_N8|T_N_VL|T_66|T_0F38|T_YMM|T_EVEX, 0x33, isXMM_XMMorMEM, NONE); } +void pmovzxwq(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_N4|T_N_VL|T_66|T_0F38|T_YMM|T_EVEX, 0x34, isXMM_XMMorMEM, NONE); } +void pmuldq(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_66 | T_0F38, 0x28, isXMM_XMMorMEM); } +void pmulhrsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x0B, T_0F38, T_66); } +void pmulhuw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE4); } +void pmulhw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE5); } +void pmulld(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_66 | T_0F38, 0x40, isXMM_XMMorMEM); } +void pmullw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xD5); } +void pmuludq(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF4); } +void popcnt(const Reg®, const Operand& op) { opCnt(reg, op, 0xB8); } +void popf() { db(0x9D); } +void por(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xEB); } +void prefetchit0(const Address& addr) { opMR(addr, Reg32(7), T_0F, 0x18); } +void prefetchit1(const Address& addr) { opMR(addr, Reg32(6), T_0F, 0x18); } +void prefetchnta(const Address& addr) { opMR(addr, Reg32(0), T_0F, 0x18); } +void prefetcht0(const Address& addr) { opMR(addr, Reg32(1), T_0F, 0x18); } +void prefetcht1(const Address& addr) { opMR(addr, Reg32(2), T_0F, 0x18); } +void prefetcht2(const Address& addr) { opMR(addr, Reg32(3), T_0F, 0x18); } +void prefetchw(const Address& addr) { opMR(addr, Reg32(1), T_0F, 0x0D); } +void prefetchwt1(const Address& addr) { opMR(addr, Reg32(2), T_0F, 0x0D); } +void psadbw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF6); } +void pshufb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x00, T_0F38, T_66); } +void pshufd(const Mmx& mmx, const Operand& op, uint8_t imm8) { opMMX(mmx, op, 0x70, T_0F, T_66, imm8); } +void pshufhw(const Mmx& mmx, const Operand& op, uint8_t imm8) { opMMX(mmx, op, 0x70, T_0F, T_F3, imm8); } +void pshuflw(const Mmx& mmx, const Operand& op, uint8_t imm8) { opMMX(mmx, op, 0x70, T_0F, T_F2, imm8); } +void pshufw(const Mmx& mmx, const Operand& op, uint8_t imm8) { opMMX(mmx, op, 0x70, T_0F, T_NONE, imm8); } +void psignb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x08, T_0F38, T_66); } +void psignd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x0A, T_0F38, T_66); } +void psignw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x09, T_0F38, T_66); } +void pslld(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF2); } +void pslld(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x72, 6); } +void pslldq(const Xmm& xmm, int imm8) { opMMX_IMM(xmm, imm8, 0x73, 7); } +void psllq(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF3); } +void psllq(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x73, 6); } +void psllw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF1); } +void psllw(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x71, 6); } +void psrad(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE2); } +void psrad(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x72, 4); } +void psraw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE1); } +void psraw(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x71, 4); } +void psrld(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xD2); } +void psrld(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x72, 2); } +void psrldq(const Xmm& xmm, int imm8) { opMMX_IMM(xmm, imm8, 0x73, 3); } +void psrlq(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xD3); } +void psrlq(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x73, 2); } +void psrlw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xD1); } +void psrlw(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x71, 2); } +void psubb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF8); } +void psubd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xFA); } +void psubq(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xFB); } +void psubsb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE8); } +void psubsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE9); } +void psubusb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xD8); } +void psubusw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xD9); } +void psubw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF9); } +void ptest(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_66|T_0F38|T_YMM, 0x17, isXMM_XMMorMEM, NONE); } +void punpckhbw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x68); } +void punpckhdq(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x6A); } +void punpckhqdq(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_66|T_0F, 0x6D, isXMM_XMMorMEM); } +void punpckhwd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x69); } +void punpcklbw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x60); } +void punpckldq(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x62); } +void punpcklqdq(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_66|T_0F, 0x6C, isXMM_XMMorMEM); } +void punpcklwd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x61); } +void pushf() { db(0x9C); } +void pxor(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xEF); } +void rcl(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 2); } +void rcl(const Operand& op, int imm) { opShift(op, imm, 2); } +void rcl(const Reg& d, const Operand& op, const Reg8& _cl) { opShift(op, _cl, 2, &d); } +void rcl(const Reg& d, const Operand& op, int imm) { opShift(op, imm, 2, &d); } +void rcpps(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_0F, 0x53, isXMM_XMMorMEM); } +void rcpss(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_0F | T_F3, 0x53, isXMM_XMMorMEM); } +void rcr(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 3); } +void rcr(const Operand& op, int imm) { opShift(op, imm, 3); } +void rcr(const Reg& d, const Operand& op, const Reg8& _cl) { opShift(op, _cl, 3, &d); } +void rcr(const Reg& d, const Operand& op, int imm) { opShift(op, imm, 3, &d); } +void rdmsr() { db(0x0F); db(0x32); } +void rdpmc() { db(0x0F); db(0x33); } +void rdrand(const Reg& r) { if (r.isBit(8)) XBYAK_THROW(ERR_BAD_SIZE_OF_REGISTER) opRR(Reg(6, Operand::REG, r.getBit()), r, T_0F, 0xC7); } +void rdseed(const Reg& r) { if (r.isBit(8)) XBYAK_THROW(ERR_BAD_SIZE_OF_REGISTER) opRR(Reg(7, Operand::REG, r.getBit()), r, T_0F, 0xC7); } +void rdtsc() { db(0x0F); db(0x31); } +void rdtscp() { db(0x0F); db(0x01); db(0xF9); } +void rep() { db(0xF3); } +void repe() { db(0xF3); } +void repne() { db(0xF2); } +void repnz() { db(0xF2); } +void repz() { db(0xF3); } +void ret(int imm = 0) { if (imm) { db(0xC2); dw(imm); } else { db(0xC3); } } +void retf(int imm = 0) { if (imm) { db(0xCA); dw(imm); } else { db(0xCB); } } +void rol(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 8); } +void rol(const Operand& op, int imm) { opShift(op, imm, 8); } +void rol(const Reg& d, const Operand& op, const Reg8& _cl) { opShift(op, _cl, 8, &d); } +void rol(const Reg& d, const Operand& op, int imm) { opShift(op, imm, 8, &d); } +void ror(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 9); } +void ror(const Operand& op, int imm) { opShift(op, imm, 9); } +void ror(const Reg& d, const Operand& op, const Reg8& _cl) { opShift(op, _cl, 9, &d); } +void ror(const Reg& d, const Operand& op, int imm) { opShift(op, imm, 9, &d); } +void rorx(const Reg32e& r, const Operand& op, uint8_t imm) { opRRO(r, Reg32e(0, r.getBit()), op, T_0F3A|T_F2|T_APX, 0xF0, imm); } +void roundpd(const Xmm& xmm, const Operand& op, uint8_t imm) { opSSE(xmm, op, T_66|T_0F3A|T_YMM, 0x09, isXMM_XMMorMEM, imm); } +void roundps(const Xmm& xmm, const Operand& op, uint8_t imm) { opSSE(xmm, op, T_66|T_0F3A|T_YMM, 0x08, isXMM_XMMorMEM, imm); } +void roundsd(const Xmm& xmm, const Operand& op, int imm) { opSSE(xmm, op, T_66 | T_0F3A, 0x0B, isXMM_XMMorMEM, static_cast(imm)); } +void roundss(const Xmm& xmm, const Operand& op, int imm) { opSSE(xmm, op, T_66 | T_0F3A, 0x0A, isXMM_XMMorMEM, static_cast(imm)); } +void rsqrtps(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_0F, 0x52, isXMM_XMMorMEM); } +void rsqrtss(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_0F | T_F3, 0x52, isXMM_XMMorMEM); } +void sahf() { db(0x9E); } +void sal(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 12); } +void sal(const Operand& op, int imm) { opShift(op, imm, 12); } +void sal(const Reg& d, const Operand& op, const Reg8& _cl) { opShift(op, _cl, 12, &d); } +void sal(const Reg& d, const Operand& op, int imm) { opShift(op, imm, 12, &d); } +void sar(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 15); } +void sar(const Operand& op, int imm) { opShift(op, imm, 15); } +void sar(const Reg& d, const Operand& op, const Reg8& _cl) { opShift(op, _cl, 15, &d); } +void sar(const Reg& d, const Operand& op, int imm) { opShift(op, imm, 15, &d); } +void sarx(const Reg32e& r1, const Operand& op, const Reg32e& r2) { opRRO(r1, r2, op, T_APX|T_F3|T_0F38, 0xf7); } +void sbb(const Operand& op, uint32_t imm) { opOI(op, imm, 0x18, 3); } +void sbb(const Operand& op1, const Operand& op2) { opRO_MR(op1, op2, 0x18); } +void sbb(const Reg& d, const Operand& op, uint32_t imm) { opROI(d, op, imm, T_NONE, 3); } +void sbb(const Reg& d, const Operand& op1, const Operand& op2) { opROO(d, op1, op2, T_NONE, 0x18); } +void scasb() { db(0xAE); } +void scasd() { db(0xAF); } +void scasw() { db(0x66); db(0xAF); } +void serialize() { db(0x0F); db(0x01); db(0xE8); } +void seta(const Operand& op) { opSetCC(op, 7); }//-V524 +void setae(const Operand& op) { opSetCC(op, 3); }//-V524 +void setb(const Operand& op) { opSetCC(op, 2); }//-V524 +void setbe(const Operand& op) { opSetCC(op, 6); }//-V524 +void setc(const Operand& op) { opSetCC(op, 2); }//-V524 +void sete(const Operand& op) { opSetCC(op, 4); }//-V524 +void setg(const Operand& op) { opSetCC(op, 15); }//-V524 +void setge(const Operand& op) { opSetCC(op, 13); }//-V524 +void setl(const Operand& op) { opSetCC(op, 12); }//-V524 +void setle(const Operand& op) { opSetCC(op, 14); }//-V524 +void setna(const Operand& op) { opSetCC(op, 6); }//-V524 +void setnae(const Operand& op) { opSetCC(op, 2); }//-V524 +void setnb(const Operand& op) { opSetCC(op, 3); }//-V524 +void setnbe(const Operand& op) { opSetCC(op, 7); }//-V524 +void setnc(const Operand& op) { opSetCC(op, 3); }//-V524 +void setne(const Operand& op) { opSetCC(op, 5); }//-V524 +void setng(const Operand& op) { opSetCC(op, 14); }//-V524 +void setnge(const Operand& op) { opSetCC(op, 12); }//-V524 +void setnl(const Operand& op) { opSetCC(op, 13); }//-V524 +void setnle(const Operand& op) { opSetCC(op, 15); }//-V524 +void setno(const Operand& op) { opSetCC(op, 1); }//-V524 +void setnp(const Operand& op) { opSetCC(op, 11); }//-V524 +void setns(const Operand& op) { opSetCC(op, 9); }//-V524 +void setnz(const Operand& op) { opSetCC(op, 5); }//-V524 +void seto(const Operand& op) { opSetCC(op, 0); }//-V524 +void setp(const Operand& op) { opSetCC(op, 10); }//-V524 +void setpe(const Operand& op) { opSetCC(op, 10); }//-V524 +void setpo(const Operand& op) { opSetCC(op, 11); }//-V524 +void sets(const Operand& op) { opSetCC(op, 8); }//-V524 +void setz(const Operand& op) { opSetCC(op, 4); }//-V524 +void sfence() { db(0x0F); db(0xAE); db(0xF8); } +void sha1msg1(const Xmm& x, const Operand& op) { opSSE_APX(x, op, T_0F38, 0xC9, T_MUST_EVEX, 0xD9); } +void sha1msg2(const Xmm& x, const Operand& op) { opSSE_APX(x, op, T_0F38, 0xCA, T_MUST_EVEX, 0xDA); } +void sha1nexte(const Xmm& x, const Operand& op) { opSSE_APX(x, op, T_0F38, 0xC8, T_MUST_EVEX, 0xD8); } +void sha1rnds4(const Xmm& x, const Operand& op, uint8_t imm) { opSSE_APX(x, op, T_0F3A, 0xCC, T_MUST_EVEX, 0xD4, imm); } +void sha256msg1(const Xmm& x, const Operand& op) { opSSE_APX(x, op, T_0F38, 0xCC, T_MUST_EVEX, 0xDC); } +void sha256msg2(const Xmm& x, const Operand& op) { opSSE_APX(x, op, T_0F38, 0xCD, T_MUST_EVEX, 0xDD); } +void sha256rnds2(const Xmm& x, const Operand& op) { opSSE_APX(x, op, T_0F38, 0xCB, T_MUST_EVEX, 0xDB); } +void shl(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 12); } +void shl(const Operand& op, int imm) { opShift(op, imm, 12); } +void shl(const Reg& d, const Operand& op, const Reg8& _cl) { opShift(op, _cl, 12, &d); } +void shl(const Reg& d, const Operand& op, int imm) { opShift(op, imm, 12, &d); } +void shld(const Operand& op, const Reg& reg, const Reg8& _cl) { opShxd(Reg(), op, reg, 0, 0xA4, 0x24, &_cl); } +void shld(const Operand& op, const Reg& reg, uint8_t imm) { opShxd(Reg(), op, reg, imm, 0xA4, 0x24); } +void shld(const Reg& d, const Operand& op, const Reg& reg, const Reg8& _cl) { opShxd(d, op, reg, 0, 0xA4, 0x24, &_cl); } +void shld(const Reg& d, const Operand& op, const Reg& reg, uint8_t imm) { opShxd(d, op, reg, imm, 0xA4, 0x24); } +void shlx(const Reg32e& r1, const Operand& op, const Reg32e& r2) { opRRO(r1, r2, op, T_APX|T_66|T_0F38, 0xf7); } +void shr(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 13); } +void shr(const Operand& op, int imm) { opShift(op, imm, 13); } +void shr(const Reg& d, const Operand& op, const Reg8& _cl) { opShift(op, _cl, 13, &d); } +void shr(const Reg& d, const Operand& op, int imm) { opShift(op, imm, 13, &d); } +void shrd(const Operand& op, const Reg& reg, const Reg8& _cl) { opShxd(Reg(), op, reg, 0, 0xAC, 0x2C, &_cl); } +void shrd(const Operand& op, const Reg& reg, uint8_t imm) { opShxd(Reg(), op, reg, imm, 0xAC, 0x2C); } +void shrd(const Reg& d, const Operand& op, const Reg& reg, const Reg8& _cl) { opShxd(d, op, reg, 0, 0xAC, 0x2C, &_cl); } +void shrd(const Reg& d, const Operand& op, const Reg& reg, uint8_t imm) { opShxd(d, op, reg, imm, 0xAC, 0x2C); } +void shrx(const Reg32e& r1, const Operand& op, const Reg32e& r2) { opRRO(r1, r2, op, T_APX|T_F2|T_0F38, 0xf7); } +void shufpd(const Xmm& xmm, const Operand& op, uint8_t imm8) { opSSE(xmm, op, T_0F | T_66, 0xC6, isXMM_XMMorMEM, imm8); } +void shufps(const Xmm& xmm, const Operand& op, uint8_t imm8) { opSSE(xmm, op, T_0F, 0xC6, isXMM_XMMorMEM, imm8); } +void sqrtpd(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_0F | T_66, 0x51, isXMM_XMMorMEM); } +void sqrtps(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_0F, 0x51, isXMM_XMMorMEM); } +void sqrtsd(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_0F | T_F2, 0x51, isXMM_XMMorMEM); } +void sqrtss(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_0F | T_F3, 0x51, isXMM_XMMorMEM); } +void stac() { db(0x0F); db(0x01); db(0xCB); } +void stc() { db(0xF9); } +void std() { db(0xFD); } +void sti() { db(0xFB); } +void stmxcsr(const Address& addr) { opMR(addr, Reg32(3), T_0F, 0xAE); } +void stosb() { db(0xAA); } +void stosd() { db(0xAB); } +void stosw() { db(0x66); db(0xAB); } +void sub(const Operand& op, uint32_t imm) { opOI(op, imm, 0x28, 5); } +void sub(const Operand& op1, const Operand& op2) { opRO_MR(op1, op2, 0x28); } +void sub(const Reg& d, const Operand& op, uint32_t imm) { opROI(d, op, imm, T_NF|T_CODE1_IF1, 5); } +void sub(const Reg& d, const Operand& op1, const Operand& op2) { opROO(d, op1, op2, T_NF|T_CODE1_IF1, 0x28); } +void subpd(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_0F | T_66, 0x5C, isXMM_XMMorMEM); } +void subps(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_0F, 0x5C, isXMM_XMMorMEM); } +void subsd(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_0F | T_F2, 0x5C, isXMM_XMMorMEM); } +void subss(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_0F | T_F3, 0x5C, isXMM_XMMorMEM); } +void sysenter() { db(0x0F); db(0x34); } +void sysexit() { db(0x0F); db(0x35); } +void tpause(const Reg32& r) { int idx = r.getIdx(); if (idx > 7) XBYAK_THROW(ERR_BAD_PARAMETER) db(0x66); db(0x0F); db(0xAE); setModRM(3, 6, idx); } +void tzcnt(const Reg®, const Operand& op) { if (opROO(Reg(), op, reg, T_APX|T_NF, 0xF4)) return; opCnt(reg, op, 0xBC); } +void ucomisd(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_66|T_0F, 0x2E, isXMM_XMMorMEM); } +void ucomiss(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_0F, 0x2E, isXMM_XMMorMEM); } +void ud2() { db(0x0F); db(0x0B); } +void umonitor(const Reg& r) { int idx = r.getIdx(); if (idx > 7) XBYAK_THROW(ERR_BAD_PARAMETER) int bit = r.getBit(); if (BIT != bit) { if ((BIT == 32 && bit == 16) || (BIT == 64 && bit == 32)) { db(0x67); } else { XBYAK_THROW(ERR_BAD_SIZE_OF_REGISTER) } } db(0xF3); db(0x0F); db(0xAE); setModRM(3, 6, idx); } +void umwait(const Reg32& r) { int idx = r.getIdx(); if (idx > 7) XBYAK_THROW(ERR_BAD_PARAMETER) db(0xF2); db(0x0F); db(0xAE); setModRM(3, 6, idx); } +void unpckhpd(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_0F | T_66, 0x15, isXMM_XMMorMEM); } +void unpckhps(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_0F, 0x15, isXMM_XMMorMEM); } +void unpcklpd(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_0F | T_66, 0x14, isXMM_XMMorMEM); } +void unpcklps(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_0F, 0x14, isXMM_XMMorMEM); } +void vaddpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x58); } +void vaddps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x58); } +void vaddsd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F2 | T_EW1 | T_EVEX | T_ER_X | T_N8, 0x58); } +void vaddss(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F3 | T_EW0 | T_EVEX | T_ER_X | T_N4, 0x58); } +void vaddsubpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_66|T_0F|T_YMM, 0xD0); } +void vaddsubps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_F2|T_0F|T_YMM, 0xD0); } +void vaesdec(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_66|T_0F38|T_YMM|T_EVEX, 0xDE); } +void vaesdeclast(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_66|T_0F38|T_YMM|T_EVEX, 0xDF); } +void vaesenc(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_66|T_0F38|T_YMM|T_EVEX, 0xDC); } +void vaesenclast(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_66|T_0F38|T_YMM|T_EVEX, 0xDD); } +void vaesimc(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66|T_0F38|T_W0, 0xDB); } +void vaeskeygenassist(const Xmm& xm, const Operand& op, uint8_t imm) { opAVX_X_XM_IMM(xm, op, T_66|T_0F3A, 0xDF, imm); } +void vandnpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x55); } +void vandnps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x55); } +void vandpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x54); } +void vandps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x54); } +void vbcstnebf162ps(const Xmm& x, const Address& addr) { opVex(x, 0, addr, T_F3|T_0F38|T_W0|T_YMM|T_B16, 0xB1); } +void vbcstnesh2ps(const Xmm& x, const Address& addr) { opVex(x, 0, addr, T_66|T_0F38|T_W0|T_YMM|T_B16, 0xB1); } +void vblendpd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8_t imm) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F3A|T_W0|T_YMM, 0x0D, imm); } +void vblendps(const Xmm& x1, const Xmm& x2, const Operand& op, uint8_t imm) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F3A|T_W0|T_YMM, 0x0C, imm); } +void vblendvpd(const Xmm& x1, const Xmm& x2, const Operand& op, const Xmm& x4) { opAVX_X_X_XM(x1, x2, op, T_0F3A | T_66 | T_YMM, 0x4B, x4.getIdx() << 4); } +void vblendvps(const Xmm& x1, const Xmm& x2, const Operand& op, const Xmm& x4) { opAVX_X_X_XM(x1, x2, op, T_0F3A | T_66 | T_YMM, 0x4A, x4.getIdx() << 4); } +void vbroadcastf128(const Ymm& y, const Address& addr) { opAVX_X_XM_IMM(y, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x1A); } +void vbroadcasti128(const Ymm& y, const Address& addr) { opAVX_X_XM_IMM(y, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x5A); } +void vbroadcastsd(const Ymm& y, const Operand& op) { if (!op.isMEM() && !(y.isYMM() && op.isXMM()) && !(y.isZMM() && op.isXMM())) XBYAK_THROW(ERR_BAD_COMBINATION) opAVX_X_XM_IMM(y, op, T_0F38 | T_66 | T_W0 | T_YMM | T_EVEX | T_EW1 | T_N8, 0x19); } +void vbroadcastss(const Xmm& x, const Operand& op) { if (!(op.isXMM() || op.isMEM())) XBYAK_THROW(ERR_BAD_COMBINATION) opAVX_X_XM_IMM(x, op, T_N4|T_66|T_0F38|T_W0|T_YMM|T_EVEX, 0x18); } +void vcmpeq_ospd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 16); } +void vcmpeq_osps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 16); } +void vcmpeq_ossd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 16); } +void vcmpeq_osss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 16); } +void vcmpeq_uqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 8); } +void vcmpeq_uqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 8); } +void vcmpeq_uqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 8); } +void vcmpeq_uqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 8); } +void vcmpeq_uspd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 24); } +void vcmpeq_usps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 24); } +void vcmpeq_ussd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 24); } +void vcmpeq_usss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 24); } +void vcmpeqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 0); } +void vcmpeqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 0); } +void vcmpeqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 0); } +void vcmpeqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 0); } +void vcmpfalse_ospd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 27); } +void vcmpfalse_osps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 27); } +void vcmpfalse_ossd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 27); } +void vcmpfalse_osss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 27); } +void vcmpfalsepd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 11); } +void vcmpfalseps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 11); } +void vcmpfalsesd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 11); } +void vcmpfalsess(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 11); } +void vcmpge_oqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 29); } +void vcmpge_oqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 29); } +void vcmpge_oqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 29); } +void vcmpge_oqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 29); } +void vcmpgepd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 13); } +void vcmpgeps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 13); } +void vcmpgesd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 13); } +void vcmpgess(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 13); } +void vcmpgt_oqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 30); } +void vcmpgt_oqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 30); } +void vcmpgt_oqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 30); } +void vcmpgt_oqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 30); } +void vcmpgtpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 14); } +void vcmpgtps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 14); } +void vcmpgtsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 14); } +void vcmpgtss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 14); } +void vcmple_oqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 18); } +void vcmple_oqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 18); } +void vcmple_oqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 18); } +void vcmple_oqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 18); } +void vcmplepd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 2); } +void vcmpleps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 2); } +void vcmplesd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 2); } +void vcmpless(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 2); } +void vcmplt_oqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 17); } +void vcmplt_oqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 17); } +void vcmplt_oqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 17); } +void vcmplt_oqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 17); } +void vcmpltpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 1); } +void vcmpltps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 1); } +void vcmpltsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 1); } +void vcmpltss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 1); } +void vcmpneq_oqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 12); } +void vcmpneq_oqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 12); } +void vcmpneq_oqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 12); } +void vcmpneq_oqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 12); } +void vcmpneq_ospd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 28); } +void vcmpneq_osps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 28); } +void vcmpneq_ossd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 28); } +void vcmpneq_osss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 28); } +void vcmpneq_uspd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 20); } +void vcmpneq_usps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 20); } +void vcmpneq_ussd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 20); } +void vcmpneq_usss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 20); } +void vcmpneqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 4); } +void vcmpneqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 4); } +void vcmpneqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 4); } +void vcmpneqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 4); } +void vcmpnge_uqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 25); } +void vcmpnge_uqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 25); } +void vcmpnge_uqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 25); } +void vcmpnge_uqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 25); } +void vcmpngepd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 9); } +void vcmpngeps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 9); } +void vcmpngesd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 9); } +void vcmpngess(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 9); } +void vcmpngt_uqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 26); } +void vcmpngt_uqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 26); } +void vcmpngt_uqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 26); } +void vcmpngt_uqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 26); } +void vcmpngtpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 10); } +void vcmpngtps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 10); } +void vcmpngtsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 10); } +void vcmpngtss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 10); } +void vcmpnle_uqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 22); } +void vcmpnle_uqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 22); } +void vcmpnle_uqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 22); } +void vcmpnle_uqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 22); } +void vcmpnlepd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 6); } +void vcmpnleps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 6); } +void vcmpnlesd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 6); } +void vcmpnless(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 6); } +void vcmpnlt_uqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 21); } +void vcmpnlt_uqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 21); } +void vcmpnlt_uqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 21); } +void vcmpnlt_uqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 21); } +void vcmpnltpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 5); } +void vcmpnltps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 5); } +void vcmpnltsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 5); } +void vcmpnltss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 5); } +void vcmpord_spd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 23); } +void vcmpord_sps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 23); } +void vcmpord_ssd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 23); } +void vcmpord_sss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 23); } +void vcmpordpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 7); } +void vcmpordps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 7); } +void vcmpordsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 7); } +void vcmpordss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 7); } +void vcmppd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8_t imm) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_YMM, 0xC2, imm); } +void vcmpps(const Xmm& x1, const Xmm& x2, const Operand& op, uint8_t imm) { opAVX_X_X_XM(x1, x2, op, T_0F|T_YMM, 0xC2, imm); } +void vcmpsd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8_t imm) { opAVX_X_X_XM(x1, x2, op, T_F2|T_0F, 0xC2, imm); } +void vcmpss(const Xmm& x1, const Xmm& x2, const Operand& op, uint8_t imm) { opAVX_X_X_XM(x1, x2, op, T_F3|T_0F, 0xC2, imm); } +void vcmptrue_uspd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 31); } +void vcmptrue_usps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 31); } +void vcmptrue_ussd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 31); } +void vcmptrue_usss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 31); } +void vcmptruepd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 15); } +void vcmptrueps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 15); } +void vcmptruesd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 15); } +void vcmptruess(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 15); } +void vcmpunord_spd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 19); } +void vcmpunord_sps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 19); } +void vcmpunord_ssd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 19); } +void vcmpunord_sss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 19); } +void vcmpunordpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 3); } +void vcmpunordps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 3); } +void vcmpunordsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 3); } +void vcmpunordss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 3); } +void vcomisd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8|T_66|T_0F|T_EW1|T_EVEX|T_SAE_X, 0x2F); } +void vcomiss(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N4|T_0F|T_EW0|T_EVEX|T_SAE_X, 0x2F); } +void vcvtdq2pd(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_0F | T_F3 | T_YMM | T_EVEX | T_EW0 | T_B32 | T_N8 | T_N_VL, 0xE6); } +void vcvtdq2ps(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_0F|T_EW0|T_YMM|T_EVEX|T_ER_Z|T_B32, 0x5B); } +void vcvtneebf162ps(const Xmm& x, const Address& addr) { opVex(x, 0, addr, T_F3|T_0F38|T_W0|T_YMM, 0xB0); } +void vcvtneeph2ps(const Xmm& x, const Address& addr) { opVex(x, 0, addr, T_66|T_0F38|T_W0|T_YMM, 0xB0); } +void vcvtneobf162ps(const Xmm& x, const Address& addr) { opVex(x, 0, addr, T_F2|T_0F38|T_W0|T_YMM, 0xB0); } +void vcvtneoph2ps(const Xmm& x, const Address& addr) { opVex(x, 0, addr, T_0F38|T_W0|T_YMM, 0xB0); } +void vcvtneps2bf16(const Xmm& x, const Operand& op, PreferredEncoding encoding = DefaultEncoding) { opCvt2(x, op, T_F3|T_0F38|T_EW0|T_YMM|T_SAE_Z|T_B32|orEvexIf(encoding), 0x72); } +void vcvtpd2dq(const Xmm& x, const Operand& op) { opCvt2(x, op, T_0F | T_F2 | T_YMM | T_EVEX | T_EW1 | T_B64 | T_ER_Z, 0xE6); } +void vcvtpd2ps(const Xmm& x, const Operand& op) { opCvt2(x, op, T_0F | T_66 | T_YMM | T_EVEX | T_EW1 | T_B64 | T_ER_Z, 0x5A); } +void vcvtph2ps(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_0F38 | T_66 | T_W0 | T_EVEX | T_EW0 | T_N8 | T_N_VL | T_SAE_Y, 0x13); } +void vcvtps2dq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66|T_0F|T_EW0|T_YMM|T_EVEX|T_ER_Z|T_B32, 0x5B); } +void vcvtps2pd(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_0F | T_YMM | T_EVEX | T_EW0 | T_B32 | T_N8 | T_N_VL | T_SAE_Y, 0x5A); } +void vcvtps2ph(const Operand& op, const Xmm& x, uint8_t imm) { checkCvt1(x, op); opVex(x, 0, op, T_0F3A | T_66 | T_W0 | T_EVEX | T_EW0 | T_N8 | T_N_VL | T_SAE_Y | T_M_K, 0x1D, imm); } +void vcvtsd2si(const Reg32& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F2 | T_W0 | T_EVEX | T_EW0 | T_N4 | T_ER_X, 0x2D); } +void vcvtsd2ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8|T_F2|T_0F|T_EW1|T_EVEX|T_ER_X, 0x5A); } +void vcvtsi2sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opCvt3(x1, x2, op, T_0F | T_F2 | T_EVEX, T_W1 | T_EW1 | T_ER_X | T_N8, T_W0 | T_EW0 | T_N4, 0x2A); } +void vcvtsi2ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opCvt3(x1, x2, op, T_0F | T_F3 | T_EVEX | T_ER_X, T_W1 | T_EW1 | T_N8, T_W0 | T_EW0 | T_N4, 0x2A); } +void vcvtss2sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4|T_F3|T_0F|T_EW0|T_EVEX|T_SAE_X, 0x5A); } +void vcvtss2si(const Reg32& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F3 | T_W0 | T_EVEX | T_EW0 | T_ER_X | T_N8, 0x2D); } +void vcvttpd2dq(const Xmm& x, const Operand& op) { opCvt2(x, op, T_66 | T_0F | T_YMM | T_EVEX |T_EW1 | T_B64 | T_ER_Z, 0xE6); } +void vcvttps2dq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_F3|T_0F|T_EW0|T_YMM|T_EVEX|T_SAE_Z|T_B32, 0x5B); } +void vcvttsd2si(const Reg32& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F2 | T_W0 | T_EVEX | T_EW0 | T_N4 | T_SAE_X, 0x2C); } +void vcvttss2si(const Reg32& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F3 | T_W0 | T_EVEX | T_EW0 | T_SAE_X | T_N8, 0x2C); } +void vdivpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x5E); } +void vdivps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x5E); } +void vdivsd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F2 | T_EW1 | T_EVEX | T_ER_X | T_N8, 0x5E); } +void vdivss(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F3 | T_EW0 | T_EVEX | T_ER_X | T_N4, 0x5E); } +void vdppd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8_t imm) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F3A|T_W0, 0x41, imm); } +void vdpps(const Xmm& x1, const Xmm& x2, const Operand& op, uint8_t imm) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F3A|T_W0|T_YMM, 0x40, imm); } +void vextractf128(const Operand& op, const Ymm& y, uint8_t imm) { if (!(op.isXMEM() && y.isYMM())) XBYAK_THROW(ERR_BAD_COMBINATION) opVex(y, 0, op, T_0F3A | T_66 | T_W0 | T_YMM, 0x19, imm); } +void vextracti128(const Operand& op, const Ymm& y, uint8_t imm) { if (!(op.isXMEM() && y.isYMM())) XBYAK_THROW(ERR_BAD_COMBINATION) opVex(y, 0, op, T_0F3A | T_66 | T_W0 | T_YMM, 0x39, imm); } +void vextractps(const Operand& op, const Xmm& x, uint8_t imm) { if (!((op.isREG(32) || op.isMEM()) && x.isXMM())) XBYAK_THROW(ERR_BAD_COMBINATION) opVex(x, 0, op, T_0F3A | T_66 | T_W0 | T_EVEX | T_N4, 0x17, imm); } +void vfmadd132pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_W1|T_EW1|T_YMM|T_EVEX|T_B64, 0x98); } +void vfmadd132ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_W0|T_EW0|T_YMM|T_EVEX|T_B32, 0x98); } +void vfmadd132sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8|T_66|T_0F38|T_W1|T_EW1|T_EVEX|T_ER_X, 0x99); } +void vfmadd132ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4|T_66|T_0F38|T_W0|T_EW0|T_EVEX|T_ER_X, 0x99); } +void vfmadd213pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_W1|T_EW1|T_YMM|T_EVEX|T_B64, 0xA8); } +void vfmadd213ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_W0|T_EW0|T_YMM|T_EVEX|T_B32, 0xA8); } +void vfmadd213sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8|T_66|T_0F38|T_W1|T_EW1|T_EVEX|T_ER_X, 0xA9); } +void vfmadd213ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4|T_66|T_0F38|T_W0|T_EW0|T_EVEX|T_ER_X, 0xA9); } +void vfmadd231pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_W1|T_EW1|T_YMM|T_EVEX|T_B64, 0xB8); } +void vfmadd231ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_W0|T_EW0|T_YMM|T_EVEX|T_B32, 0xB8); } +void vfmadd231sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8|T_66|T_0F38|T_W1|T_EW1|T_EVEX|T_ER_X, 0xB9); } +void vfmadd231ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4|T_66|T_0F38|T_W0|T_EW0|T_EVEX|T_ER_X, 0xB9); } +void vfmaddsub132pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_W1|T_EW1|T_YMM|T_EVEX|T_B64, 0x96); } +void vfmaddsub132ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_W0|T_EW0|T_YMM|T_EVEX|T_B32, 0x96); } +void vfmaddsub213pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_W1|T_EW1|T_YMM|T_EVEX|T_B64, 0xA6); } +void vfmaddsub213ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_W0|T_EW0|T_YMM|T_EVEX|T_B32, 0xA6); } +void vfmaddsub231pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_W1|T_EW1|T_YMM|T_EVEX|T_B64, 0xB6); } +void vfmaddsub231ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_W0|T_EW0|T_YMM|T_EVEX|T_B32, 0xB6); } +void vfmsub132pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_W1|T_EW1|T_YMM|T_EVEX|T_B64, 0x9A); } +void vfmsub132ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_W0|T_EW0|T_YMM|T_EVEX|T_B32, 0x9A); } +void vfmsub132sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8|T_66|T_0F38|T_W1|T_EW1|T_EVEX|T_ER_X, 0x9B); } +void vfmsub132ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4|T_66|T_0F38|T_W0|T_EW0|T_EVEX|T_ER_X, 0x9B); } +void vfmsub213pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_W1|T_EW1|T_YMM|T_EVEX|T_B64, 0xAA); } +void vfmsub213ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_W0|T_EW0|T_YMM|T_EVEX|T_B32, 0xAA); } +void vfmsub213sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8|T_66|T_0F38|T_W1|T_EW1|T_EVEX|T_ER_X, 0xAB); } +void vfmsub213ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4|T_66|T_0F38|T_W0|T_EW0|T_EVEX|T_ER_X, 0xAB); } +void vfmsub231pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_W1|T_EW1|T_YMM|T_EVEX|T_B64, 0xBA); } +void vfmsub231ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_W0|T_EW0|T_YMM|T_EVEX|T_B32, 0xBA); } +void vfmsub231sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8|T_66|T_0F38|T_W1|T_EW1|T_EVEX|T_ER_X, 0xBB); } +void vfmsub231ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4|T_66|T_0F38|T_W0|T_EW0|T_EVEX|T_ER_X, 0xBB); } +void vfmsubadd132pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_W1|T_EW1|T_YMM|T_EVEX|T_B64, 0x97); } +void vfmsubadd132ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_W0|T_EW0|T_YMM|T_EVEX|T_B32, 0x97); } +void vfmsubadd213pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_W1|T_EW1|T_YMM|T_EVEX|T_B64, 0xA7); } +void vfmsubadd213ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_W0|T_EW0|T_YMM|T_EVEX|T_B32, 0xA7); } +void vfmsubadd231pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_W1|T_EW1|T_YMM|T_EVEX|T_B64, 0xB7); } +void vfmsubadd231ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_W0|T_EW0|T_YMM|T_EVEX|T_B32, 0xB7); } +void vfnmadd132pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_W1|T_EW1|T_YMM|T_EVEX|T_B64, 0x9C); } +void vfnmadd132ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_W0|T_EW0|T_YMM|T_EVEX|T_B32, 0x9C); } +void vfnmadd132sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8|T_66|T_0F38|T_W1|T_EW1|T_EVEX|T_ER_X, 0x9D); } +void vfnmadd132ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4|T_66|T_0F38|T_W0|T_EW0|T_EVEX|T_ER_X, 0x9D); } +void vfnmadd213pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_W1|T_EW1|T_YMM|T_EVEX|T_B64, 0xAC); } +void vfnmadd213ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_W0|T_EW0|T_YMM|T_EVEX|T_B32, 0xAC); } +void vfnmadd213sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8|T_66|T_0F38|T_W1|T_EW1|T_EVEX|T_ER_X, 0xAD); } +void vfnmadd213ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4|T_66|T_0F38|T_W0|T_EW0|T_EVEX|T_ER_X, 0xAD); } +void vfnmadd231pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_W1|T_EW1|T_YMM|T_EVEX|T_B64, 0xBC); } +void vfnmadd231ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_W0|T_EW0|T_YMM|T_EVEX|T_B32, 0xBC); } +void vfnmadd231sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8|T_66|T_0F38|T_W1|T_EW1|T_EVEX|T_ER_X, 0xBD); } +void vfnmadd231ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4|T_66|T_0F38|T_W0|T_EW0|T_EVEX|T_ER_X, 0xBD); } +void vfnmsub132pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_W1|T_EW1|T_YMM|T_EVEX|T_B64, 0x9E); } +void vfnmsub132ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_W0|T_EW0|T_YMM|T_EVEX|T_B32, 0x9E); } +void vfnmsub132sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8|T_66|T_0F38|T_W1|T_EW1|T_EVEX|T_ER_X, 0x9F); } +void vfnmsub132ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4|T_66|T_0F38|T_W0|T_EW0|T_EVEX|T_ER_X, 0x9F); } +void vfnmsub213pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_W1|T_EW1|T_YMM|T_EVEX|T_B64, 0xAE); } +void vfnmsub213ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_W0|T_EW0|T_YMM|T_EVEX|T_B32, 0xAE); } +void vfnmsub213sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8|T_66|T_0F38|T_W1|T_EW1|T_EVEX|T_ER_X, 0xAF); } +void vfnmsub213ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4|T_66|T_0F38|T_W0|T_EW0|T_EVEX|T_ER_X, 0xAF); } +void vfnmsub231pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_W1|T_EW1|T_YMM|T_EVEX|T_B64, 0xBE); } +void vfnmsub231ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_W0|T_EW0|T_YMM|T_EVEX|T_B32, 0xBE); } +void vfnmsub231sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8|T_66|T_0F38|T_W1|T_EW1|T_EVEX|T_ER_X, 0xBF); } +void vfnmsub231ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4|T_66|T_0F38|T_W0|T_EW0|T_EVEX|T_ER_X, 0xBF); } +void vgatherdpd(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W1, 0x92, 0); } +void vgatherdps(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W0, 0x92, 1); } +void vgatherqpd(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W1, 0x93, 1); } +void vgatherqps(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W0, 0x93, 2); } +void vgf2p8affineinvqb(const Xmm& x1, const Xmm& x2, const Operand& op, uint8_t imm) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F3A|T_W1|T_EW1|T_YMM|T_EVEX|T_SAE_Z|T_B64, 0xCF, imm); } +void vgf2p8affineqb(const Xmm& x1, const Xmm& x2, const Operand& op, uint8_t imm) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F3A|T_W1|T_EW1|T_YMM|T_EVEX|T_SAE_Z|T_B64, 0xCE, imm); } +void vgf2p8mulb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_W0|T_EW0|T_YMM|T_EVEX|T_SAE_Z, 0xCF); } +void vhaddpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_66|T_0F|T_YMM, 0x7C); } +void vhaddps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_F2|T_0F|T_YMM, 0x7C); } +void vhsubpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_66|T_0F|T_YMM, 0x7D); } +void vhsubps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_F2|T_0F|T_YMM, 0x7D); } +void vinsertf128(const Ymm& y1, const Ymm& y2, const Operand& op, uint8_t imm) { if (!(y1.isYMM() && y2.isYMM() && op.isXMEM())) XBYAK_THROW(ERR_BAD_COMBINATION) opVex(y1, &y2, op, T_0F3A | T_66 | T_W0 | T_YMM, 0x18, imm); } +void vinserti128(const Ymm& y1, const Ymm& y2, const Operand& op, uint8_t imm) { if (!(y1.isYMM() && y2.isYMM() && op.isXMEM())) XBYAK_THROW(ERR_BAD_COMBINATION) opVex(y1, &y2, op, T_0F3A | T_66 | T_W0 | T_YMM, 0x38, imm); } +void vinsertps(const Xmm& x1, const Xmm& x2, const Operand& op, uint8_t imm) { opAVX_X_X_XM(x1, x2, op, T_N4|T_66|T_0F3A|T_W0|T_EW0|T_EVEX, 0x21, imm); } +void vlddqu(const Xmm& x, const Address& addr) { opAVX_X_X_XM(x, cvtIdx0(x), addr, T_0F | T_F2 | T_W0 | T_YMM, 0xF0); } +void vldmxcsr(const Address& addr) { opAVX_X_X_XM(xm2, xm0, addr, T_0F, 0xAE); } +void vmaskmovdqu(const Xmm& x1, const Xmm& x2) { opAVX_X_X_XM(x1, xm0, x2, T_0F | T_66, 0xF7); } +void vmaskmovpd(const Address& addr, const Xmm& x1, const Xmm& x2) { opAVX_X_X_XM(x2, x1, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x2F); } +void vmaskmovpd(const Xmm& x1, const Xmm& x2, const Address& addr) { opAVX_X_X_XM(x1, x2, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x2D); } +void vmaskmovps(const Address& addr, const Xmm& x1, const Xmm& x2) { opAVX_X_X_XM(x2, x1, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x2E); } +void vmaskmovps(const Xmm& x1, const Xmm& x2, const Address& addr) { opAVX_X_X_XM(x1, x2, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x2C); } +void vmaxpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x5F); } +void vmaxps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x5F); } +void vmaxsd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F2 | T_EW1 | T_EVEX | T_ER_X | T_N8, 0x5F); } +void vmaxss(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F3 | T_EW0 | T_EVEX | T_ER_X | T_N4, 0x5F); } +void vminpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x5D); } +void vminps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x5D); } +void vminsd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F2 | T_EW1 | T_EVEX | T_ER_X | T_N8, 0x5D); } +void vminss(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F3 | T_EW0 | T_EVEX | T_ER_X | T_N4, 0x5D); } +void vmovapd(const Address& addr, const Xmm& xmm) { opAVX_X_XM_IMM(xmm, addr, T_66|T_0F|T_EW1|T_YMM|T_EVEX|T_M_K, 0x29); } +void vmovapd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66|T_0F|T_EW1|T_YMM|T_EVEX, 0x28); } +void vmovaps(const Address& addr, const Xmm& xmm) { opAVX_X_XM_IMM(xmm, addr, T_0F|T_EW0|T_YMM|T_EVEX|T_M_K, 0x29); } +void vmovaps(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_0F|T_EW0|T_YMM|T_EVEX, 0x28); } +void vmovd(const Operand& op, const Xmm& x) { if (!op.isREG(32) && !op.isMEM()) XBYAK_THROW(ERR_BAD_COMBINATION) opAVX_X_X_XM(x, xm0, op, T_0F | T_66 | T_W0 | T_EVEX | T_N4, 0x7E); } +void vmovd(const Xmm& x, const Operand& op) { if (!op.isREG(32) && !op.isMEM()) XBYAK_THROW(ERR_BAD_COMBINATION) opAVX_X_X_XM(x, xm0, op, T_0F | T_66 | T_W0 | T_EVEX | T_N4, 0x6E); } +void vmovddup(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_DUP|T_F2|T_0F|T_EW1|T_YMM|T_EVEX|T_ER_X|T_ER_Y|T_ER_Z, 0x12); } +void vmovdqa(const Address& addr, const Xmm& xmm) { opAVX_X_XM_IMM(xmm, addr, T_66|T_0F|T_YMM, 0x7F); } +void vmovdqa(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66|T_0F|T_YMM, 0x6F); } +void vmovdqu(const Address& addr, const Xmm& xmm) { opAVX_X_XM_IMM(xmm, addr, T_F3|T_0F|T_YMM, 0x7F); } +void vmovdqu(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_F3|T_0F|T_YMM, 0x6F); } +void vmovhlps(const Xmm& x1, const Xmm& x2, const Operand& op = Operand()) { if (!op.isNone() && !op.isXMM()) XBYAK_THROW(ERR_BAD_COMBINATION) opAVX_X_X_XM(x1, x2, op, T_0F | T_EVEX | T_EW0, 0x12); } +void vmovhpd(const Address& addr, const Xmm& x) { opAVX_X_X_XM(x, xm0, addr, T_N8|T_66|T_0F|T_EW1|T_EVEX, 0x17); } +void vmovhpd(const Xmm& x, const Operand& op1, const Operand& op2 = Operand()) { if (!op2.isNone() && !op2.isMEM()) XBYAK_THROW(ERR_BAD_COMBINATION) opAVX_X_X_XM(x, op1, op2, T_N8|T_66|T_0F|T_EW1|T_EVEX, 0x16); } +void vmovhps(const Address& addr, const Xmm& x) { opAVX_X_X_XM(x, xm0, addr, T_N8|T_0F|T_EW0|T_EVEX, 0x17); } +void vmovhps(const Xmm& x, const Operand& op1, const Operand& op2 = Operand()) { if (!op2.isNone() && !op2.isMEM()) XBYAK_THROW(ERR_BAD_COMBINATION) opAVX_X_X_XM(x, op1, op2, T_N8|T_0F|T_EW0|T_EVEX, 0x16); } +void vmovlhps(const Xmm& x1, const Xmm& x2, const Operand& op = Operand()) { if (!op.isNone() && !op.isXMM()) XBYAK_THROW(ERR_BAD_COMBINATION) opAVX_X_X_XM(x1, x2, op, T_0F | T_EVEX | T_EW0, 0x16); } +void vmovlpd(const Address& addr, const Xmm& x) { opAVX_X_X_XM(x, xm0, addr, T_N8|T_66|T_0F|T_EW1|T_EVEX, 0x13); } +void vmovlpd(const Xmm& x, const Operand& op1, const Operand& op2 = Operand()) { if (!op2.isNone() && !op2.isMEM()) XBYAK_THROW(ERR_BAD_COMBINATION) opAVX_X_X_XM(x, op1, op2, T_N8|T_66|T_0F|T_EW1|T_EVEX, 0x12); } +void vmovlps(const Address& addr, const Xmm& x) { opAVX_X_X_XM(x, xm0, addr, T_N8|T_0F|T_EW0|T_EVEX, 0x13); } +void vmovlps(const Xmm& x, const Operand& op1, const Operand& op2 = Operand()) { if (!op2.isNone() && !op2.isMEM()) XBYAK_THROW(ERR_BAD_COMBINATION) opAVX_X_X_XM(x, op1, op2, T_N8|T_0F|T_EW0|T_EVEX, 0x12); } +void vmovmskpd(const Reg& r, const Xmm& x) { if (!r.isBit(i32e)) XBYAK_THROW(ERR_BAD_COMBINATION) opAVX_X_X_XM(x.isXMM() ? Xmm(r.getIdx()) : Ymm(r.getIdx()), cvtIdx0(x), x, T_0F | T_66 | T_W0 | T_YMM, 0x50); } +void vmovmskps(const Reg& r, const Xmm& x) { if (!r.isBit(i32e)) XBYAK_THROW(ERR_BAD_COMBINATION) opAVX_X_X_XM(x.isXMM() ? Xmm(r.getIdx()) : Ymm(r.getIdx()), cvtIdx0(x), x, T_0F | T_W0 | T_YMM, 0x50); } +void vmovntdq(const Address& addr, const Xmm& x) { opVex(x, 0, addr, T_0F | T_66 | T_YMM | T_EVEX | T_EW0, 0xE7); } +void vmovntdqa(const Xmm& x, const Address& addr) { opVex(x, 0, addr, T_0F38 | T_66 | T_YMM | T_EVEX | T_EW0, 0x2A); } +void vmovntpd(const Address& addr, const Xmm& x) { opVex(x, 0, addr, T_0F | T_66 | T_YMM | T_EVEX | T_EW1, 0x2B); } +void vmovntps(const Address& addr, const Xmm& x) { opVex(x, 0, addr, T_0F | T_YMM | T_EVEX | T_EW0, 0x2B); } +void vmovq(const Address& addr, const Xmm& x) { opAVX_X_X_XM(x, xm0, addr, T_0F | T_66 | T_EVEX | T_EW1 | T_N8, x.getIdx() < 16 ? 0xD6 : 0x7E); } +void vmovq(const Xmm& x, const Address& addr) { uint64_t type; uint8_t code; if (x.getIdx() < 16) { type = T_0F | T_F3; code = 0x7E; } else { type = T_0F | T_66 | T_EVEX | T_EW1 | T_N8; code = 0x6E; } opAVX_X_X_XM(x, xm0, addr, type, code); } +void vmovq(const Xmm& x1, const Xmm& x2) { opAVX_X_X_XM(x1, xm0, x2, T_0F | T_F3 | T_EVEX | T_EW1 | T_N8, 0x7E); } +void vmovsd(const Address& addr, const Xmm& x) { opAVX_X_X_XM(x, xm0, addr, T_N8|T_F2|T_0F|T_EW1|T_EVEX | T_M_K, 0x11); } +void vmovsd(const Xmm& x, const Address& addr) { opAVX_X_X_XM(x, xm0, addr, T_N8|T_F2|T_0F|T_EW1|T_EVEX, 0x10); } +void vmovsd(const Xmm& x1, const Xmm& x2, const Operand& op = Operand()) { if (!op.isNone() && !op.isXMM()) XBYAK_THROW(ERR_BAD_COMBINATION) opAVX_X_X_XM(x1, x2, op, T_N8|T_F2|T_0F|T_EW1|T_EVEX, 0x10); } +void vmovshdup(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_F3|T_0F|T_EW0|T_YMM|T_EVEX, 0x16); } +void vmovsldup(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_F3|T_0F|T_EW0|T_YMM|T_EVEX, 0x12); } +void vmovss(const Address& addr, const Xmm& x) { opAVX_X_X_XM(x, xm0, addr, T_N4|T_F3|T_0F|T_EW0|T_EVEX | T_M_K, 0x11); } +void vmovss(const Xmm& x, const Address& addr) { opAVX_X_X_XM(x, xm0, addr, T_N4|T_F3|T_0F|T_EW0|T_EVEX, 0x10); } +void vmovss(const Xmm& x1, const Xmm& x2, const Operand& op = Operand()) { if (!op.isNone() && !op.isXMM()) XBYAK_THROW(ERR_BAD_COMBINATION) opAVX_X_X_XM(x1, x2, op, T_N4|T_F3|T_0F|T_EW0|T_EVEX, 0x10); } +void vmovupd(const Address& addr, const Xmm& xmm) { opAVX_X_XM_IMM(xmm, addr, T_66|T_0F|T_EW1|T_YMM|T_EVEX|T_M_K, 0x11); } +void vmovupd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66|T_0F|T_EW1|T_YMM|T_EVEX, 0x10); } +void vmovups(const Address& addr, const Xmm& xmm) { opAVX_X_XM_IMM(xmm, addr, T_0F|T_EW0|T_YMM|T_EVEX|T_M_K, 0x11); } +void vmovups(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_0F|T_EW0|T_YMM|T_EVEX, 0x10); } +void vmpsadbw(const Xmm& x1, const Xmm& x2, const Operand& op, uint8_t imm) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F3A|T_W0|T_YMM, 0x42, imm); } +void vmulpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x59); } +void vmulps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x59); } +void vmulsd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F2 | T_EW1 | T_EVEX | T_ER_X | T_N8, 0x59); } +void vmulss(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F3 | T_EW0 | T_EVEX | T_ER_X | T_N4, 0x59); } +void vorpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x56); } +void vorps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x56); } +void vpabsb(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66|T_0F38|T_YMM|T_EVEX, 0x1C); } +void vpabsd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66|T_0F38|T_EW0|T_YMM|T_EVEX|T_B32, 0x1E); } +void vpabsw(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66|T_0F38|T_YMM|T_EVEX, 0x1D); } +void vpackssdw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_EW0|T_YMM|T_EVEX|T_B32, 0x6B); } +void vpacksswb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_YMM|T_EVEX, 0x63); } +void vpackusdw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_EW0|T_YMM|T_EVEX|T_B32, 0x2B); } +void vpackuswb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_YMM|T_EVEX, 0x67); } +void vpaddb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_YMM|T_EVEX, 0xFC); } +void vpaddd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_EW0|T_YMM|T_EVEX|T_B32, 0xFE); } +void vpaddq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_EW1|T_YMM|T_EVEX|T_B64, 0xD4); } +void vpaddsb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_YMM|T_EVEX, 0xEC); } +void vpaddsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_YMM|T_EVEX, 0xED); } +void vpaddusb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_YMM|T_EVEX, 0xDC); } +void vpaddusw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_YMM|T_EVEX, 0xDD); } +void vpaddw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_YMM|T_EVEX, 0xFD); } +void vpalignr(const Xmm& x1, const Xmm& x2, const Operand& op, uint8_t imm) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F3A|T_YMM|T_EVEX, 0x0F, imm); } +void vpand(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_YMM, 0xDB); } +void vpandn(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_YMM, 0xDF); } +void vpavgb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_YMM|T_EVEX, 0xE0); } +void vpavgw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_YMM|T_EVEX, 0xE3); } +void vpblendd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8_t imm) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F3A|T_W0|T_YMM, 0x02, imm); } +void vpblendvb(const Xmm& x1, const Xmm& x2, const Operand& op, const Xmm& x4) { opAVX_X_X_XM(x1, x2, op, T_0F3A | T_66 | T_YMM, 0x4C, x4.getIdx() << 4); } +void vpblendw(const Xmm& x1, const Xmm& x2, const Operand& op, uint8_t imm) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F3A|T_W0|T_YMM, 0x0E, imm); } +void vpbroadcastb(const Xmm& x, const Operand& op) { if (!(op.isXMM() || op.isMEM())) XBYAK_THROW(ERR_BAD_COMBINATION) opAVX_X_XM_IMM(x, op, T_N1|T_66|T_0F38|T_W0|T_YMM|T_EVEX, 0x78); } +void vpbroadcastd(const Xmm& x, const Operand& op) { if (!(op.isXMM() || op.isMEM())) XBYAK_THROW(ERR_BAD_COMBINATION) opAVX_X_XM_IMM(x, op, T_N4|T_66|T_0F38|T_W0|T_YMM|T_EVEX, 0x58); } +void vpbroadcastq(const Xmm& x, const Operand& op) { if (!(op.isXMM() || op.isMEM())) XBYAK_THROW(ERR_BAD_COMBINATION) opAVX_X_XM_IMM(x, op, T_N8|T_66|T_0F38|T_W0|T_EW1|T_YMM|T_EVEX, 0x59); } +void vpbroadcastw(const Xmm& x, const Operand& op) { if (!(op.isXMM() || op.isMEM())) XBYAK_THROW(ERR_BAD_COMBINATION) opAVX_X_XM_IMM(x, op, T_N2|T_66|T_0F38|T_W0|T_YMM|T_EVEX, 0x79); } +void vpclmulhqhqdq(const Xmm& x1, const Xmm& x2, const Operand& op) { vpclmulqdq(x1, x2, op, 0x11); } +void vpclmulhqlqdq(const Xmm& x1, const Xmm& x2, const Operand& op) { vpclmulqdq(x1, x2, op, 0x01); } +void vpclmullqhqdq(const Xmm& x1, const Xmm& x2, const Operand& op) { vpclmulqdq(x1, x2, op, 0x10); } +void vpclmullqlqdq(const Xmm& x1, const Xmm& x2, const Operand& op) { vpclmulqdq(x1, x2, op, 0x00); } +void vpclmulqdq(const Xmm& x1, const Xmm& x2, const Operand& op, uint8_t imm) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F3A|T_W0|T_YMM|T_EVEX, 0x44, imm); } +void vpcmpeqb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_YMM, 0x74); } +void vpcmpeqd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_YMM, 0x76); } +void vpcmpeqq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_YMM, 0x29); } +void vpcmpeqw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_YMM, 0x75); } +void vpcmpestri(const Xmm& xm, const Operand& op, uint8_t imm) { opAVX_X_XM_IMM(xm, op, T_66|T_0F3A, 0x61, imm); } +void vpcmpestrm(const Xmm& xm, const Operand& op, uint8_t imm) { opAVX_X_XM_IMM(xm, op, T_66|T_0F3A, 0x60, imm); } +void vpcmpgtb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_YMM, 0x64); } +void vpcmpgtd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_YMM, 0x66); } +void vpcmpgtq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_YMM, 0x37); } +void vpcmpgtw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_YMM, 0x65); } +void vpcmpistri(const Xmm& xm, const Operand& op, uint8_t imm) { opAVX_X_XM_IMM(xm, op, T_66|T_0F3A, 0x63, imm); } +void vpcmpistrm(const Xmm& xm, const Operand& op, uint8_t imm) { opAVX_X_XM_IMM(xm, op, T_66|T_0F3A, 0x62, imm); } +void vpdpbssd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_F2|T_0F38|T_W0|T_YMM, 0x50); } +void vpdpbssds(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_F2|T_0F38|T_W0|T_YMM, 0x51); } +void vpdpbsud(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_F3|T_0F38|T_W0|T_YMM, 0x50); } +void vpdpbsuds(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_F3|T_0F38|T_W0|T_YMM, 0x51); } +void vpdpbusd(const Xmm& x1, const Xmm& x2, const Operand& op, PreferredEncoding encoding = DefaultEncoding) { opEncoding(x1, x2, op, T_66|T_0F38|T_EW0|T_YMM|T_SAE_Z|T_B32, 0x50, encoding); } +void vpdpbusds(const Xmm& x1, const Xmm& x2, const Operand& op, PreferredEncoding encoding = DefaultEncoding) { opEncoding(x1, x2, op, T_66|T_0F38|T_EW0|T_YMM|T_SAE_Z|T_B32, 0x51, encoding); } +void vpdpbuud(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_0F38|T_W0|T_YMM, 0x50); } +void vpdpbuuds(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_0F38|T_W0|T_YMM, 0x51); } +void vpdpwssd(const Xmm& x1, const Xmm& x2, const Operand& op, PreferredEncoding encoding = DefaultEncoding) { opEncoding(x1, x2, op, T_66|T_0F38|T_EW0|T_YMM|T_SAE_Z|T_B32, 0x52, encoding); } +void vpdpwssds(const Xmm& x1, const Xmm& x2, const Operand& op, PreferredEncoding encoding = DefaultEncoding) { opEncoding(x1, x2, op, T_66|T_0F38|T_EW0|T_YMM|T_SAE_Z|T_B32, 0x53, encoding); } +void vpdpwsud(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_F3|T_0F38|T_W0|T_YMM, 0xD2); } +void vpdpwsuds(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_F3|T_0F38|T_W0|T_YMM, 0xD3); } +void vpdpwusd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_W0|T_YMM, 0xD2); } +void vpdpwusds(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_W0|T_YMM, 0xD3); } +void vpdpwuud(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_0F38|T_W0|T_YMM, 0xD2); } +void vpdpwuuds(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_0F38|T_W0|T_YMM, 0xD3); } +void vperm2f128(const Ymm& y1, const Ymm& y2, const Operand& op, uint8_t imm) { if (!(y1.isYMM() && y2.isYMM() && op.isYMEM())) XBYAK_THROW(ERR_BAD_COMBINATION) opVex(y1, &y2, op, T_0F3A | T_66 | T_W0 | T_YMM, 0x06, imm); } +void vperm2i128(const Ymm& y1, const Ymm& y2, const Operand& op, uint8_t imm) { if (!(y1.isYMM() && y2.isYMM() && op.isYMEM())) XBYAK_THROW(ERR_BAD_COMBINATION) opVex(y1, &y2, op, T_0F3A | T_66 | T_W0 | T_YMM, 0x46, imm); } +void vpermd(const Ymm& y1, const Ymm& y2, const Operand& op) { opAVX_X_X_XM(y1, y2, op, T_66|T_0F38|T_W0|T_EW0|T_YMM|T_EVEX|T_B32, 0x36); } +void vpermilpd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_W0|T_EW1|T_YMM|T_EVEX|T_B64, 0x0D); } +void vpermilpd(const Xmm& xm, const Operand& op, uint8_t imm) { opAVX_X_XM_IMM(xm, op, T_66|T_0F3A|T_EW1|T_YMM|T_EVEX|T_B64, 0x05, imm); } +void vpermilps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_W0|T_EW0|T_YMM|T_EVEX|T_B32, 0x0C); } +void vpermilps(const Xmm& xm, const Operand& op, uint8_t imm) { opAVX_X_XM_IMM(xm, op, T_66|T_0F3A|T_EW0|T_YMM|T_EVEX|T_B32, 0x04, imm); } +void vpermpd(const Ymm& y, const Operand& op, uint8_t imm) { opAVX_X_XM_IMM(y, op, T_66|T_0F3A|T_W1|T_EW1|T_YMM|T_EVEX|T_B64, 0x01, imm); } +void vpermpd(const Ymm& y1, const Ymm& y2, const Operand& op) { opAVX_X_X_XM(y1, y2, op, T_66|T_0F38|T_EW1|T_YMM|T_MUST_EVEX|T_B64, 0x16); } +void vpermps(const Ymm& y1, const Ymm& y2, const Operand& op) { opAVX_X_X_XM(y1, y2, op, T_66|T_0F38|T_W0|T_EW0|T_YMM|T_EVEX|T_B32, 0x16); } +void vpermq(const Ymm& y, const Operand& op, uint8_t imm) { opAVX_X_XM_IMM(y, op, T_66|T_0F3A|T_W1|T_EW1|T_YMM|T_EVEX|T_B64, 0x00, imm); } +void vpermq(const Ymm& y1, const Ymm& y2, const Operand& op) { opAVX_X_X_XM(y1, y2, op, T_66|T_0F38|T_W0|T_EW1|T_YMM|T_EVEX|T_B64, 0x36); } +void vpextrb(const Operand& op, const Xmm& x, uint8_t imm) { if (!((op.isREG(8|16|i32e) || op.isMEM()) && x.isXMM())) XBYAK_THROW(ERR_BAD_COMBINATION) opVex(x, 0, op, T_0F3A | T_66 | T_EVEX | T_N1, 0x14, imm); } +void vpextrd(const Operand& op, const Xmm& x, uint8_t imm) { if (!((op.isREG(32) || op.isMEM()) && x.isXMM())) XBYAK_THROW(ERR_BAD_COMBINATION) opVex(x, 0, op, T_0F3A | T_66 | T_W0 | T_EVEX | T_EW0 | T_N4, 0x16, imm); } +void vpextrq(const Operand& op, const Xmm& x, uint8_t imm) { if (!((op.isREG(64) || op.isMEM()) && x.isXMM())) XBYAK_THROW(ERR_BAD_COMBINATION) opVex(x, 0, op, T_0F3A | T_66 | T_W1 | T_EVEX | T_EW1 | T_N8, 0x16, imm); } +void vpextrw(const Operand& op, const Xmm& x, uint8_t imm) { if (!((op.isREG(16|i32e) || op.isMEM()) && x.isXMM())) XBYAK_THROW(ERR_BAD_COMBINATION) if (op.isREG() && x.getIdx() < 16) { opAVX_X_X_XM(Xmm(op.getIdx()), xm0, x, T_0F | T_66, 0xC5, imm); } else { opVex(x, 0, op, T_0F3A | T_66 | T_EVEX | T_N2, 0x15, imm); } } +void vpgatherdd(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W0, 0x90, 1); } +void vpgatherdq(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W1, 0x90, 0); } +void vpgatherqd(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W0, 0x91, 2); } +void vpgatherqq(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W1, 0x91, 1); } +void vphaddd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_YMM, 0x02); } +void vphaddsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_YMM, 0x03); } +void vphaddw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_YMM, 0x01); } +void vphminposuw(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66|T_0F38, 0x41); } +void vphsubd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_YMM, 0x06); } +void vphsubsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_YMM, 0x07); } +void vphsubw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_YMM, 0x05); } +void vpinsrb(const Xmm& x1, const Xmm& x2, const Operand& op, uint8_t imm) { if (!(x1.isXMM() && x2.isXMM() && (op.isREG(32) || op.isMEM()))) XBYAK_THROW(ERR_BAD_COMBINATION) opVex(x1, &x2, op, T_0F3A | T_66 | T_EVEX | T_N1, 0x20, imm); } +void vpinsrd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8_t imm) { if (!(x1.isXMM() && x2.isXMM() && (op.isREG(32) || op.isMEM()))) XBYAK_THROW(ERR_BAD_COMBINATION) opVex(x1, &x2, op, T_0F3A | T_66 | T_W0 | T_EVEX | T_EW0 | T_N4, 0x22, imm); } +void vpinsrq(const Xmm& x1, const Xmm& x2, const Operand& op, uint8_t imm) { if (!(x1.isXMM() && x2.isXMM() && (op.isREG(64) || op.isMEM()))) XBYAK_THROW(ERR_BAD_COMBINATION) opVex(x1, &x2, op, T_0F3A | T_66 | T_W1 | T_EVEX | T_EW1 | T_N8, 0x22, imm); } +void vpinsrw(const Xmm& x1, const Xmm& x2, const Operand& op, uint8_t imm) { if (!(x1.isXMM() && x2.isXMM() && (op.isREG(32) || op.isMEM()))) XBYAK_THROW(ERR_BAD_COMBINATION) opVex(x1, &x2, op, T_0F | T_66 | T_EVEX | T_N2, 0xC4, imm); } +void vpmadd52huq(const Xmm& x1, const Xmm& x2, const Operand& op, PreferredEncoding encoding = DefaultEncoding) { opEncoding(x1, x2, op, T_66|T_0F38|T_EW1|T_YMM|T_B64, 0xB5, encoding); } +void vpmadd52luq(const Xmm& x1, const Xmm& x2, const Operand& op, PreferredEncoding encoding = DefaultEncoding) { opEncoding(x1, x2, op, T_66|T_0F38|T_EW1|T_YMM|T_B64, 0xB4, encoding); } +void vpmaddubsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_YMM|T_EVEX, 0x04); } +void vpmaddwd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_YMM|T_EVEX, 0xF5); } +void vpmaskmovd(const Address& addr, const Xmm& x1, const Xmm& x2) { opAVX_X_X_XM(x2, x1, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x8E); } +void vpmaskmovd(const Xmm& x1, const Xmm& x2, const Address& addr) { opAVX_X_X_XM(x1, x2, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x8C); } +void vpmaskmovq(const Address& addr, const Xmm& x1, const Xmm& x2) { opAVX_X_X_XM(x2, x1, addr, T_0F38 | T_66 | T_W1 | T_YMM, 0x8E); } +void vpmaskmovq(const Xmm& x1, const Xmm& x2, const Address& addr) { opAVX_X_X_XM(x1, x2, addr, T_0F38 | T_66 | T_W1 | T_YMM, 0x8C); } +void vpmaxsb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_YMM|T_EVEX, 0x3C); } +void vpmaxsd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_EW0|T_YMM|T_EVEX|T_B32, 0x3D); } +void vpmaxsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_YMM|T_EVEX, 0xEE); } +void vpmaxub(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_YMM|T_EVEX, 0xDE); } +void vpmaxud(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_EW0|T_YMM|T_EVEX|T_B32, 0x3F); } +void vpmaxuw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_YMM|T_EVEX, 0x3E); } +void vpminsb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_YMM|T_EVEX, 0x38); } +void vpminsd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_EW0|T_YMM|T_EVEX|T_B32, 0x39); } +void vpminsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_YMM|T_EVEX, 0xEA); } +void vpminub(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_YMM|T_EVEX, 0xDA); } +void vpminud(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_EW0|T_YMM|T_EVEX|T_B32, 0x3B); } +void vpminuw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_YMM|T_EVEX, 0x3A); } +void vpmovmskb(const Reg32e& r, const Xmm& x) { if (!x.is(Operand::XMM | Operand::YMM)) XBYAK_THROW(ERR_BAD_COMBINATION) opVex(x.isYMM() ? Ymm(r.getIdx()) : Xmm(r.getIdx()), 0, x, T_0F | T_66 | T_YMM, 0xD7); } +void vpmovsxbd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N4|T_N_VL|T_66|T_0F38|T_YMM|T_EVEX, 0x21); } +void vpmovsxbq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N2|T_N_VL|T_66|T_0F38|T_YMM|T_EVEX, 0x22); } +void vpmovsxbw(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8|T_N_VL|T_66|T_0F38|T_YMM|T_EVEX, 0x20); } +void vpmovsxdq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8|T_N_VL|T_66|T_0F38|T_EW0|T_YMM|T_EVEX, 0x25); } +void vpmovsxwd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8|T_N_VL|T_66|T_0F38|T_YMM|T_EVEX, 0x23); } +void vpmovsxwq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N4|T_N_VL|T_66|T_0F38|T_YMM|T_EVEX, 0x24); } +void vpmovzxbd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N4|T_N_VL|T_66|T_0F38|T_YMM|T_EVEX, 0x31); } +void vpmovzxbq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N2|T_N_VL|T_66|T_0F38|T_YMM|T_EVEX, 0x32); } +void vpmovzxbw(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8|T_N_VL|T_66|T_0F38|T_YMM|T_EVEX, 0x30); } +void vpmovzxdq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8|T_N_VL|T_66|T_0F38|T_EW0|T_YMM|T_EVEX, 0x35); } +void vpmovzxwd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8|T_N_VL|T_66|T_0F38|T_YMM|T_EVEX, 0x33); } +void vpmovzxwq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N4|T_N_VL|T_66|T_0F38|T_YMM|T_EVEX, 0x34); } +void vpmuldq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_EW1|T_YMM|T_EVEX|T_B64, 0x28); } +void vpmulhrsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_YMM|T_EVEX, 0x0B); } +void vpmulhuw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_YMM|T_EVEX, 0xE4); } +void vpmulhw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_YMM|T_EVEX, 0xE5); } +void vpmulld(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_EW0|T_YMM|T_EVEX|T_B32, 0x40); } +void vpmullw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_YMM|T_EVEX, 0xD5); } +void vpmuludq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_EW1|T_YMM|T_EVEX|T_B64, 0xF4); } +void vpor(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_YMM, 0xEB); } +void vpsadbw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_YMM|T_EVEX, 0xF6); } +void vpshufb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_YMM|T_EVEX, 0x00); } +void vpshufd(const Xmm& xm, const Operand& op, uint8_t imm) { opAVX_X_XM_IMM(xm, op, T_66|T_0F|T_EW0|T_YMM|T_EVEX|T_B32, 0x70, imm); } +void vpshufhw(const Xmm& xm, const Operand& op, uint8_t imm) { opAVX_X_XM_IMM(xm, op, T_F3|T_0F|T_YMM|T_EVEX, 0x70, imm); } +void vpshuflw(const Xmm& xm, const Operand& op, uint8_t imm) { opAVX_X_XM_IMM(xm, op, T_F2|T_0F|T_YMM|T_EVEX, 0x70, imm); } +void vpsignb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_YMM, 0x08); } +void vpsignd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_YMM, 0x0A); } +void vpsignw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_YMM, 0x09); } +void vpslld(const Xmm& x, const Operand& op, uint8_t imm) { opAVX_X_X_XM(Xmm(x.getKind(), 6), x, op, T_66|T_0F|T_EW0|T_YMM|T_EVEX|T_B32|T_MEM_EVEX, 0x72, imm); } +void vpslld(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16|T_66|T_0F|T_EW0|T_YMM|T_EVEX, 0xF2); } +void vpslldq(const Xmm& x, const Operand& op, uint8_t imm) { opAVX_X_X_XM(Xmm(x.getKind(), 7), x, op, T_66|T_0F|T_YMM|T_EVEX|T_MEM_EVEX, 0x73, imm); } +void vpsllq(const Xmm& x, const Operand& op, uint8_t imm) { opAVX_X_X_XM(Xmm(x.getKind(), 6), x, op, T_66|T_0F|T_EW1|T_YMM|T_EVEX|T_B64|T_MEM_EVEX, 0x73, imm); } +void vpsllq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16|T_66|T_0F|T_EW1|T_YMM|T_EVEX, 0xF3); } +void vpsllvd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_W0|T_EW0|T_YMM|T_EVEX|T_B32, 0x47); } +void vpsllvq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_W1|T_EW1|T_YMM|T_EVEX|T_B64, 0x47); } +void vpsllw(const Xmm& x, const Operand& op, uint8_t imm) { opAVX_X_X_XM(Xmm(x.getKind(), 6), x, op, T_66|T_0F|T_YMM|T_EVEX|T_MEM_EVEX, 0x71, imm); } +void vpsllw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16|T_66|T_0F|T_YMM|T_EVEX, 0xF1); } +void vpsrad(const Xmm& x, const Operand& op, uint8_t imm) { opAVX_X_X_XM(Xmm(x.getKind(), 4), x, op, T_66|T_0F|T_EW0|T_YMM|T_EVEX|T_B32|T_MEM_EVEX, 0x72, imm); } +void vpsrad(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16|T_66|T_0F|T_EW0|T_YMM|T_EVEX, 0xE2); } +void vpsravd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_W0|T_EW0|T_YMM|T_EVEX|T_B32, 0x46); } +void vpsraw(const Xmm& x, const Operand& op, uint8_t imm) { opAVX_X_X_XM(Xmm(x.getKind(), 4), x, op, T_66|T_0F|T_YMM|T_EVEX|T_MEM_EVEX, 0x71, imm); } +void vpsraw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16|T_66|T_0F|T_YMM|T_EVEX, 0xE1); } +void vpsrld(const Xmm& x, const Operand& op, uint8_t imm) { opAVX_X_X_XM(Xmm(x.getKind(), 2), x, op, T_66|T_0F|T_EW0|T_YMM|T_EVEX|T_B32|T_MEM_EVEX, 0x72, imm); } +void vpsrld(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16|T_66|T_0F|T_EW0|T_YMM|T_EVEX, 0xD2); } +void vpsrldq(const Xmm& x, const Operand& op, uint8_t imm) { opAVX_X_X_XM(Xmm(x.getKind(), 3), x, op, T_66|T_0F|T_YMM|T_EVEX|T_MEM_EVEX, 0x73, imm); } +void vpsrlq(const Xmm& x, const Operand& op, uint8_t imm) { opAVX_X_X_XM(Xmm(x.getKind(), 2), x, op, T_66|T_0F|T_EW1|T_YMM|T_EVEX|T_B64|T_MEM_EVEX, 0x73, imm); } +void vpsrlq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16|T_66|T_0F|T_EW1|T_YMM|T_EVEX, 0xD3); } +void vpsrlvd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_W0|T_EW0|T_YMM|T_EVEX|T_B32, 0x45); } +void vpsrlvq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_W1|T_EW1|T_YMM|T_EVEX|T_B64, 0x45); } +void vpsrlw(const Xmm& x, const Operand& op, uint8_t imm) { opAVX_X_X_XM(Xmm(x.getKind(), 2), x, op, T_66|T_0F|T_YMM|T_EVEX|T_MEM_EVEX, 0x71, imm); } +void vpsrlw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16|T_66|T_0F|T_YMM|T_EVEX, 0xD1); } +void vpsubb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_YMM|T_EVEX, 0xF8); } +void vpsubd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_EW0|T_YMM|T_EVEX|T_B32, 0xFA); } +void vpsubq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_EW1|T_YMM|T_EVEX|T_B64, 0xFB); } +void vpsubsb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_YMM|T_EVEX, 0xE8); } +void vpsubsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_YMM|T_EVEX, 0xE9); } +void vpsubusb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_YMM|T_EVEX, 0xD8); } +void vpsubusw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_YMM|T_EVEX, 0xD9); } +void vpsubw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_YMM|T_EVEX, 0xF9); } +void vptest(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66|T_0F38|T_YMM, 0x17); } +void vpunpckhbw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_YMM|T_EVEX, 0x68); } +void vpunpckhdq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_EW0|T_YMM|T_EVEX|T_B32, 0x6A); } +void vpunpckhqdq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_EW1|T_YMM|T_EVEX|T_B64, 0x6D); } +void vpunpckhwd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_YMM|T_EVEX, 0x69); } +void vpunpcklbw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_YMM|T_EVEX, 0x60); } +void vpunpckldq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_EW0|T_YMM|T_EVEX|T_B32, 0x62); } +void vpunpcklqdq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_EW1|T_YMM|T_EVEX|T_B64, 0x6C); } +void vpunpcklwd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_YMM|T_EVEX, 0x61); } +void vpxor(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_YMM, 0xEF); } +void vrcpps(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_0F|T_YMM, 0x53); } +void vrcpss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_F3|T_0F, 0x53); } +void vroundpd(const Xmm& xm, const Operand& op, uint8_t imm) { opAVX_X_XM_IMM(xm, op, T_66|T_0F3A|T_YMM, 0x09, imm); } +void vroundps(const Xmm& xm, const Operand& op, uint8_t imm) { opAVX_X_XM_IMM(xm, op, T_66|T_0F3A|T_YMM, 0x08, imm); } +void vroundsd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8_t imm) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F3A|T_W0, 0x0B, imm); } +void vroundss(const Xmm& x1, const Xmm& x2, const Operand& op, uint8_t imm) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F3A|T_W0, 0x0A, imm); } +void vrsqrtps(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_0F|T_YMM, 0x52); } +void vrsqrtss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_F3|T_0F, 0x52); } +void vsha512msg1(const Ymm& y, const Xmm& x) { if (!(y.isYMM() && x.isXMM())) XBYAK_THROW(ERR_BAD_PARAMETER) opVex(y, 0, x, T_F2 | T_0F38 | T_W0 | T_YMM, 0xCC); } +void vsha512msg2(const Ymm& y1, const Ymm& y2) { if (!(y1.isYMM() && y2.isYMM())) XBYAK_THROW(ERR_BAD_PARAMETER) opVex(y1, 0, y2, T_F2 | T_0F38 | T_W0 | T_YMM, 0xCD); } +void vsha512rnds2(const Ymm& y1, const Ymm& y2, const Xmm& x) { if (!(y1.isYMM() && y2.isYMM() && x.isXMM())) XBYAK_THROW(ERR_BAD_PARAMETER) opVex(y1, &y2, x, T_F2 | T_0F38 | T_W0 | T_YMM, 0xCB); } +void vshufpd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8_t imm) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_EW1|T_YMM|T_EVEX|T_B64, 0xC6, imm); } +void vshufps(const Xmm& x1, const Xmm& x2, const Operand& op, uint8_t imm) { opAVX_X_X_XM(x1, x2, op, T_0F|T_EW0|T_YMM|T_EVEX|T_B32, 0xC6, imm); } +void vsm3msg1(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_0F38|T_W0|T_EW0|T_EVEX, 0xDA); } +void vsm3msg2(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_W0|T_EW0|T_EVEX, 0xDA); } +void vsm3rnds2(const Xmm& x1, const Xmm& x2, const Operand& op, uint8_t imm) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F3A|T_W0|T_EW0|T_EVEX, 0xDE, imm); } +void vsm4key4(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_F3|T_0F38|T_W0|T_EW0|T_EVEX, 0xDA); } +void vsm4rnds4(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_F2|T_0F38|T_W0|T_EW0|T_EVEX, 0xDA); } +void vsqrtpd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66|T_0F|T_EW1|T_YMM|T_EVEX|T_ER_Z|T_B64, 0x51); } +void vsqrtps(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_0F|T_EW0|T_YMM|T_EVEX|T_ER_Z|T_B32, 0x51); } +void vsqrtsd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8|T_F2|T_0F|T_EW1|T_EVEX|T_ER_X, 0x51); } +void vsqrtss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4|T_F3|T_0F|T_EW0|T_EVEX|T_ER_X, 0x51); } +void vstmxcsr(const Address& addr) { opAVX_X_X_XM(xm3, xm0, addr, T_0F, 0xAE); } +void vsubpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x5C); } +void vsubps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x5C); } +void vsubsd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F2 | T_EW1 | T_EVEX | T_ER_X | T_N8, 0x5C); } +void vsubss(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F3 | T_EW0 | T_EVEX | T_ER_X | T_N4, 0x5C); } +void vtestpd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66|T_0F38|T_YMM, 0x0F); } +void vtestps(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66|T_0F38|T_YMM, 0x0E); } +void vucomisd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8|T_66|T_0F|T_EW1|T_EVEX|T_SAE_X, 0x2E); } +void vucomiss(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N4|T_0F|T_EW0|T_EVEX|T_SAE_X, 0x2E); } +void vunpckhpd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_EW1|T_YMM|T_EVEX|T_B64, 0x15); } +void vunpckhps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_0F|T_EW0|T_YMM|T_EVEX|T_B32, 0x15); } +void vunpcklpd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_EW1|T_YMM|T_EVEX|T_B64, 0x14); } +void vunpcklps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_0F|T_EW0|T_YMM|T_EVEX|T_B32, 0x14); } +void vxorpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x57); } +void vxorps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x57); } +void vzeroall() { db(0xC5); db(0xFC); db(0x77); } +void vzeroupper() { db(0xC5); db(0xF8); db(0x77); } +void wait() { db(0x9B); } +void wbinvd() { db(0x0F); db(0x09); } +void wrmsr() { db(0x0F); db(0x30); } +void xabort(uint8_t imm) { db(0xC6); db(0xF8); db(imm); } +void xadd(const Operand& op, const Reg& reg) { opRO(reg, op, T_0F, 0xC0 | (reg.isBit(8) ? 0 : 1), op.getBit() == reg.getBit()); } +void xbegin(uint32_t rel) { db(0xC7); db(0xF8); dd(rel); } +void xend() { db(0x0F); db(0x01); db(0xD5); } +void xgetbv() { db(0x0F); db(0x01); db(0xD0); } +void xlatb() { db(0xD7); } +void xor_(const Operand& op, uint32_t imm) { opOI(op, imm, 0x30, 6); } +void xor_(const Operand& op1, const Operand& op2) { opRO_MR(op1, op2, 0x30); } +void xor_(const Reg& d, const Operand& op, uint32_t imm) { opROI(d, op, imm, T_NF|T_CODE1_IF1, 6); } +void xor_(const Reg& d, const Operand& op1, const Operand& op2) { opROO(d, op1, op2, T_NF|T_CODE1_IF1, 0x30); } +void xorpd(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_0F | T_66, 0x57, isXMM_XMMorMEM); } +void xorps(const Xmm& xmm, const Operand& op) { opSSE(xmm, op, T_0F, 0x57, isXMM_XMMorMEM); } +#ifdef XBYAK_ENABLE_OMITTED_OPERAND +void vblendpd(const Xmm& x, const Operand& op, uint8_t imm) { vblendpd(x, x, op, imm); } +void vblendps(const Xmm& x, const Operand& op, uint8_t imm) { vblendps(x, x, op, imm); } +void vblendvpd(const Xmm& x1, const Operand& op, const Xmm& x4) { vblendvpd(x1, x1, op, x4); } +void vblendvps(const Xmm& x1, const Operand& op, const Xmm& x4) { vblendvps(x1, x1, op, x4); } +void vcmpeq_ospd(const Xmm& x, const Operand& op) { vcmpeq_ospd(x, x, op); } +void vcmpeq_osps(const Xmm& x, const Operand& op) { vcmpeq_osps(x, x, op); } +void vcmpeq_ossd(const Xmm& x, const Operand& op) { vcmpeq_ossd(x, x, op); } +void vcmpeq_osss(const Xmm& x, const Operand& op) { vcmpeq_osss(x, x, op); } +void vcmpeq_uqpd(const Xmm& x, const Operand& op) { vcmpeq_uqpd(x, x, op); } +void vcmpeq_uqps(const Xmm& x, const Operand& op) { vcmpeq_uqps(x, x, op); } +void vcmpeq_uqsd(const Xmm& x, const Operand& op) { vcmpeq_uqsd(x, x, op); } +void vcmpeq_uqss(const Xmm& x, const Operand& op) { vcmpeq_uqss(x, x, op); } +void vcmpeq_uspd(const Xmm& x, const Operand& op) { vcmpeq_uspd(x, x, op); } +void vcmpeq_usps(const Xmm& x, const Operand& op) { vcmpeq_usps(x, x, op); } +void vcmpeq_ussd(const Xmm& x, const Operand& op) { vcmpeq_ussd(x, x, op); } +void vcmpeq_usss(const Xmm& x, const Operand& op) { vcmpeq_usss(x, x, op); } +void vcmpeqpd(const Xmm& x, const Operand& op) { vcmpeqpd(x, x, op); } +void vcmpeqps(const Xmm& x, const Operand& op) { vcmpeqps(x, x, op); } +void vcmpeqsd(const Xmm& x, const Operand& op) { vcmpeqsd(x, x, op); } +void vcmpeqss(const Xmm& x, const Operand& op) { vcmpeqss(x, x, op); } +void vcmpfalse_ospd(const Xmm& x, const Operand& op) { vcmpfalse_ospd(x, x, op); } +void vcmpfalse_osps(const Xmm& x, const Operand& op) { vcmpfalse_osps(x, x, op); } +void vcmpfalse_ossd(const Xmm& x, const Operand& op) { vcmpfalse_ossd(x, x, op); } +void vcmpfalse_osss(const Xmm& x, const Operand& op) { vcmpfalse_osss(x, x, op); } +void vcmpfalsepd(const Xmm& x, const Operand& op) { vcmpfalsepd(x, x, op); } +void vcmpfalseps(const Xmm& x, const Operand& op) { vcmpfalseps(x, x, op); } +void vcmpfalsesd(const Xmm& x, const Operand& op) { vcmpfalsesd(x, x, op); } +void vcmpfalsess(const Xmm& x, const Operand& op) { vcmpfalsess(x, x, op); } +void vcmpge_oqpd(const Xmm& x, const Operand& op) { vcmpge_oqpd(x, x, op); } +void vcmpge_oqps(const Xmm& x, const Operand& op) { vcmpge_oqps(x, x, op); } +void vcmpge_oqsd(const Xmm& x, const Operand& op) { vcmpge_oqsd(x, x, op); } +void vcmpge_oqss(const Xmm& x, const Operand& op) { vcmpge_oqss(x, x, op); } +void vcmpgepd(const Xmm& x, const Operand& op) { vcmpgepd(x, x, op); } +void vcmpgeps(const Xmm& x, const Operand& op) { vcmpgeps(x, x, op); } +void vcmpgesd(const Xmm& x, const Operand& op) { vcmpgesd(x, x, op); } +void vcmpgess(const Xmm& x, const Operand& op) { vcmpgess(x, x, op); } +void vcmpgt_oqpd(const Xmm& x, const Operand& op) { vcmpgt_oqpd(x, x, op); } +void vcmpgt_oqps(const Xmm& x, const Operand& op) { vcmpgt_oqps(x, x, op); } +void vcmpgt_oqsd(const Xmm& x, const Operand& op) { vcmpgt_oqsd(x, x, op); } +void vcmpgt_oqss(const Xmm& x, const Operand& op) { vcmpgt_oqss(x, x, op); } +void vcmpgtpd(const Xmm& x, const Operand& op) { vcmpgtpd(x, x, op); } +void vcmpgtps(const Xmm& x, const Operand& op) { vcmpgtps(x, x, op); } +void vcmpgtsd(const Xmm& x, const Operand& op) { vcmpgtsd(x, x, op); } +void vcmpgtss(const Xmm& x, const Operand& op) { vcmpgtss(x, x, op); } +void vcmple_oqpd(const Xmm& x, const Operand& op) { vcmple_oqpd(x, x, op); } +void vcmple_oqps(const Xmm& x, const Operand& op) { vcmple_oqps(x, x, op); } +void vcmple_oqsd(const Xmm& x, const Operand& op) { vcmple_oqsd(x, x, op); } +void vcmple_oqss(const Xmm& x, const Operand& op) { vcmple_oqss(x, x, op); } +void vcmplepd(const Xmm& x, const Operand& op) { vcmplepd(x, x, op); } +void vcmpleps(const Xmm& x, const Operand& op) { vcmpleps(x, x, op); } +void vcmplesd(const Xmm& x, const Operand& op) { vcmplesd(x, x, op); } +void vcmpless(const Xmm& x, const Operand& op) { vcmpless(x, x, op); } +void vcmplt_oqpd(const Xmm& x, const Operand& op) { vcmplt_oqpd(x, x, op); } +void vcmplt_oqps(const Xmm& x, const Operand& op) { vcmplt_oqps(x, x, op); } +void vcmplt_oqsd(const Xmm& x, const Operand& op) { vcmplt_oqsd(x, x, op); } +void vcmplt_oqss(const Xmm& x, const Operand& op) { vcmplt_oqss(x, x, op); } +void vcmpltpd(const Xmm& x, const Operand& op) { vcmpltpd(x, x, op); } +void vcmpltps(const Xmm& x, const Operand& op) { vcmpltps(x, x, op); } +void vcmpltsd(const Xmm& x, const Operand& op) { vcmpltsd(x, x, op); } +void vcmpltss(const Xmm& x, const Operand& op) { vcmpltss(x, x, op); } +void vcmpneq_oqpd(const Xmm& x, const Operand& op) { vcmpneq_oqpd(x, x, op); } +void vcmpneq_oqps(const Xmm& x, const Operand& op) { vcmpneq_oqps(x, x, op); } +void vcmpneq_oqsd(const Xmm& x, const Operand& op) { vcmpneq_oqsd(x, x, op); } +void vcmpneq_oqss(const Xmm& x, const Operand& op) { vcmpneq_oqss(x, x, op); } +void vcmpneq_ospd(const Xmm& x, const Operand& op) { vcmpneq_ospd(x, x, op); } +void vcmpneq_osps(const Xmm& x, const Operand& op) { vcmpneq_osps(x, x, op); } +void vcmpneq_ossd(const Xmm& x, const Operand& op) { vcmpneq_ossd(x, x, op); } +void vcmpneq_osss(const Xmm& x, const Operand& op) { vcmpneq_osss(x, x, op); } +void vcmpneq_uspd(const Xmm& x, const Operand& op) { vcmpneq_uspd(x, x, op); } +void vcmpneq_usps(const Xmm& x, const Operand& op) { vcmpneq_usps(x, x, op); } +void vcmpneq_ussd(const Xmm& x, const Operand& op) { vcmpneq_ussd(x, x, op); } +void vcmpneq_usss(const Xmm& x, const Operand& op) { vcmpneq_usss(x, x, op); } +void vcmpneqpd(const Xmm& x, const Operand& op) { vcmpneqpd(x, x, op); } +void vcmpneqps(const Xmm& x, const Operand& op) { vcmpneqps(x, x, op); } +void vcmpneqsd(const Xmm& x, const Operand& op) { vcmpneqsd(x, x, op); } +void vcmpneqss(const Xmm& x, const Operand& op) { vcmpneqss(x, x, op); } +void vcmpnge_uqpd(const Xmm& x, const Operand& op) { vcmpnge_uqpd(x, x, op); } +void vcmpnge_uqps(const Xmm& x, const Operand& op) { vcmpnge_uqps(x, x, op); } +void vcmpnge_uqsd(const Xmm& x, const Operand& op) { vcmpnge_uqsd(x, x, op); } +void vcmpnge_uqss(const Xmm& x, const Operand& op) { vcmpnge_uqss(x, x, op); } +void vcmpngepd(const Xmm& x, const Operand& op) { vcmpngepd(x, x, op); } +void vcmpngeps(const Xmm& x, const Operand& op) { vcmpngeps(x, x, op); } +void vcmpngesd(const Xmm& x, const Operand& op) { vcmpngesd(x, x, op); } +void vcmpngess(const Xmm& x, const Operand& op) { vcmpngess(x, x, op); } +void vcmpngt_uqpd(const Xmm& x, const Operand& op) { vcmpngt_uqpd(x, x, op); } +void vcmpngt_uqps(const Xmm& x, const Operand& op) { vcmpngt_uqps(x, x, op); } +void vcmpngt_uqsd(const Xmm& x, const Operand& op) { vcmpngt_uqsd(x, x, op); } +void vcmpngt_uqss(const Xmm& x, const Operand& op) { vcmpngt_uqss(x, x, op); } +void vcmpngtpd(const Xmm& x, const Operand& op) { vcmpngtpd(x, x, op); } +void vcmpngtps(const Xmm& x, const Operand& op) { vcmpngtps(x, x, op); } +void vcmpngtsd(const Xmm& x, const Operand& op) { vcmpngtsd(x, x, op); } +void vcmpngtss(const Xmm& x, const Operand& op) { vcmpngtss(x, x, op); } +void vcmpnle_uqpd(const Xmm& x, const Operand& op) { vcmpnle_uqpd(x, x, op); } +void vcmpnle_uqps(const Xmm& x, const Operand& op) { vcmpnle_uqps(x, x, op); } +void vcmpnle_uqsd(const Xmm& x, const Operand& op) { vcmpnle_uqsd(x, x, op); } +void vcmpnle_uqss(const Xmm& x, const Operand& op) { vcmpnle_uqss(x, x, op); } +void vcmpnlepd(const Xmm& x, const Operand& op) { vcmpnlepd(x, x, op); } +void vcmpnleps(const Xmm& x, const Operand& op) { vcmpnleps(x, x, op); } +void vcmpnlesd(const Xmm& x, const Operand& op) { vcmpnlesd(x, x, op); } +void vcmpnless(const Xmm& x, const Operand& op) { vcmpnless(x, x, op); } +void vcmpnlt_uqpd(const Xmm& x, const Operand& op) { vcmpnlt_uqpd(x, x, op); } +void vcmpnlt_uqps(const Xmm& x, const Operand& op) { vcmpnlt_uqps(x, x, op); } +void vcmpnlt_uqsd(const Xmm& x, const Operand& op) { vcmpnlt_uqsd(x, x, op); } +void vcmpnlt_uqss(const Xmm& x, const Operand& op) { vcmpnlt_uqss(x, x, op); } +void vcmpnltpd(const Xmm& x, const Operand& op) { vcmpnltpd(x, x, op); } +void vcmpnltps(const Xmm& x, const Operand& op) { vcmpnltps(x, x, op); } +void vcmpnltsd(const Xmm& x, const Operand& op) { vcmpnltsd(x, x, op); } +void vcmpnltss(const Xmm& x, const Operand& op) { vcmpnltss(x, x, op); } +void vcmpord_spd(const Xmm& x, const Operand& op) { vcmpord_spd(x, x, op); } +void vcmpord_sps(const Xmm& x, const Operand& op) { vcmpord_sps(x, x, op); } +void vcmpord_ssd(const Xmm& x, const Operand& op) { vcmpord_ssd(x, x, op); } +void vcmpord_sss(const Xmm& x, const Operand& op) { vcmpord_sss(x, x, op); } +void vcmpordpd(const Xmm& x, const Operand& op) { vcmpordpd(x, x, op); } +void vcmpordps(const Xmm& x, const Operand& op) { vcmpordps(x, x, op); } +void vcmpordsd(const Xmm& x, const Operand& op) { vcmpordsd(x, x, op); } +void vcmpordss(const Xmm& x, const Operand& op) { vcmpordss(x, x, op); } +void vcmppd(const Xmm& x, const Operand& op, uint8_t imm) { vcmppd(x, x, op, imm); } +void vcmpps(const Xmm& x, const Operand& op, uint8_t imm) { vcmpps(x, x, op, imm); } +void vcmpsd(const Xmm& x, const Operand& op, uint8_t imm) { vcmpsd(x, x, op, imm); } +void vcmpss(const Xmm& x, const Operand& op, uint8_t imm) { vcmpss(x, x, op, imm); } +void vcmptrue_uspd(const Xmm& x, const Operand& op) { vcmptrue_uspd(x, x, op); } +void vcmptrue_usps(const Xmm& x, const Operand& op) { vcmptrue_usps(x, x, op); } +void vcmptrue_ussd(const Xmm& x, const Operand& op) { vcmptrue_ussd(x, x, op); } +void vcmptrue_usss(const Xmm& x, const Operand& op) { vcmptrue_usss(x, x, op); } +void vcmptruepd(const Xmm& x, const Operand& op) { vcmptruepd(x, x, op); } +void vcmptrueps(const Xmm& x, const Operand& op) { vcmptrueps(x, x, op); } +void vcmptruesd(const Xmm& x, const Operand& op) { vcmptruesd(x, x, op); } +void vcmptruess(const Xmm& x, const Operand& op) { vcmptruess(x, x, op); } +void vcmpunord_spd(const Xmm& x, const Operand& op) { vcmpunord_spd(x, x, op); } +void vcmpunord_sps(const Xmm& x, const Operand& op) { vcmpunord_sps(x, x, op); } +void vcmpunord_ssd(const Xmm& x, const Operand& op) { vcmpunord_ssd(x, x, op); } +void vcmpunord_sss(const Xmm& x, const Operand& op) { vcmpunord_sss(x, x, op); } +void vcmpunordpd(const Xmm& x, const Operand& op) { vcmpunordpd(x, x, op); } +void vcmpunordps(const Xmm& x, const Operand& op) { vcmpunordps(x, x, op); } +void vcmpunordsd(const Xmm& x, const Operand& op) { vcmpunordsd(x, x, op); } +void vcmpunordss(const Xmm& x, const Operand& op) { vcmpunordss(x, x, op); } +void vcvtsd2ss(const Xmm& x, const Operand& op) { vcvtsd2ss(x, x, op); } +void vcvtsi2sd(const Xmm& x, const Operand& op) { vcvtsi2sd(x, x, op); } +void vcvtsi2ss(const Xmm& x, const Operand& op) { vcvtsi2ss(x, x, op); } +void vcvtss2sd(const Xmm& x, const Operand& op) { vcvtss2sd(x, x, op); } +void vdppd(const Xmm& x, const Operand& op, uint8_t imm) { vdppd(x, x, op, imm); } +void vdpps(const Xmm& x, const Operand& op, uint8_t imm) { vdpps(x, x, op, imm); } +void vinsertps(const Xmm& x, const Operand& op, uint8_t imm) { vinsertps(x, x, op, imm); } +void vmpsadbw(const Xmm& x, const Operand& op, uint8_t imm) { vmpsadbw(x, x, op, imm); } +void vpackssdw(const Xmm& x, const Operand& op) { vpackssdw(x, x, op); } +void vpacksswb(const Xmm& x, const Operand& op) { vpacksswb(x, x, op); } +void vpackusdw(const Xmm& x, const Operand& op) { vpackusdw(x, x, op); } +void vpackuswb(const Xmm& x, const Operand& op) { vpackuswb(x, x, op); } +void vpaddb(const Xmm& x, const Operand& op) { vpaddb(x, x, op); } +void vpaddd(const Xmm& x, const Operand& op) { vpaddd(x, x, op); } +void vpaddq(const Xmm& x, const Operand& op) { vpaddq(x, x, op); } +void vpaddsb(const Xmm& x, const Operand& op) { vpaddsb(x, x, op); } +void vpaddsw(const Xmm& x, const Operand& op) { vpaddsw(x, x, op); } +void vpaddusb(const Xmm& x, const Operand& op) { vpaddusb(x, x, op); } +void vpaddusw(const Xmm& x, const Operand& op) { vpaddusw(x, x, op); } +void vpaddw(const Xmm& x, const Operand& op) { vpaddw(x, x, op); } +void vpalignr(const Xmm& x, const Operand& op, uint8_t imm) { vpalignr(x, x, op, imm); } +void vpand(const Xmm& x, const Operand& op) { vpand(x, x, op); } +void vpandn(const Xmm& x, const Operand& op) { vpandn(x, x, op); } +void vpavgb(const Xmm& x, const Operand& op) { vpavgb(x, x, op); } +void vpavgw(const Xmm& x, const Operand& op) { vpavgw(x, x, op); } +void vpblendd(const Xmm& x, const Operand& op, uint8_t imm) { vpblendd(x, x, op, imm); } +void vpblendvb(const Xmm& x1, const Operand& op, const Xmm& x4) { vpblendvb(x1, x1, op, x4); } +void vpblendw(const Xmm& x, const Operand& op, uint8_t imm) { vpblendw(x, x, op, imm); } +void vpclmulqdq(const Xmm& x, const Operand& op, uint8_t imm) { vpclmulqdq(x, x, op, imm); } +void vpcmpeqb(const Xmm& x, const Operand& op) { vpcmpeqb(x, x, op); } +void vpcmpeqd(const Xmm& x, const Operand& op) { vpcmpeqd(x, x, op); } +void vpcmpeqq(const Xmm& x, const Operand& op) { vpcmpeqq(x, x, op); } +void vpcmpeqw(const Xmm& x, const Operand& op) { vpcmpeqw(x, x, op); } +void vpcmpgtb(const Xmm& x, const Operand& op) { vpcmpgtb(x, x, op); } +void vpcmpgtd(const Xmm& x, const Operand& op) { vpcmpgtd(x, x, op); } +void vpcmpgtq(const Xmm& x, const Operand& op) { vpcmpgtq(x, x, op); } +void vpcmpgtw(const Xmm& x, const Operand& op) { vpcmpgtw(x, x, op); } +void vphaddd(const Xmm& x, const Operand& op) { vphaddd(x, x, op); } +void vphaddsw(const Xmm& x, const Operand& op) { vphaddsw(x, x, op); } +void vphaddw(const Xmm& x, const Operand& op) { vphaddw(x, x, op); } +void vphsubd(const Xmm& x, const Operand& op) { vphsubd(x, x, op); } +void vphsubsw(const Xmm& x, const Operand& op) { vphsubsw(x, x, op); } +void vphsubw(const Xmm& x, const Operand& op) { vphsubw(x, x, op); } +void vpinsrb(const Xmm& x, const Operand& op, uint8_t imm) { vpinsrb(x, x, op, imm); } +void vpinsrd(const Xmm& x, const Operand& op, uint8_t imm) { vpinsrd(x, x, op, imm); } +void vpinsrq(const Xmm& x, const Operand& op, uint8_t imm) { vpinsrq(x, x, op, imm); } +void vpinsrw(const Xmm& x, const Operand& op, uint8_t imm) { vpinsrw(x, x, op, imm); } +void vpmaddubsw(const Xmm& x, const Operand& op) { vpmaddubsw(x, x, op); } +void vpmaddwd(const Xmm& x, const Operand& op) { vpmaddwd(x, x, op); } +void vpmaxsb(const Xmm& x, const Operand& op) { vpmaxsb(x, x, op); } +void vpmaxsd(const Xmm& x, const Operand& op) { vpmaxsd(x, x, op); } +void vpmaxsw(const Xmm& x, const Operand& op) { vpmaxsw(x, x, op); } +void vpmaxub(const Xmm& x, const Operand& op) { vpmaxub(x, x, op); } +void vpmaxud(const Xmm& x, const Operand& op) { vpmaxud(x, x, op); } +void vpmaxuw(const Xmm& x, const Operand& op) { vpmaxuw(x, x, op); } +void vpminsb(const Xmm& x, const Operand& op) { vpminsb(x, x, op); } +void vpminsd(const Xmm& x, const Operand& op) { vpminsd(x, x, op); } +void vpminsw(const Xmm& x, const Operand& op) { vpminsw(x, x, op); } +void vpminub(const Xmm& x, const Operand& op) { vpminub(x, x, op); } +void vpminud(const Xmm& x, const Operand& op) { vpminud(x, x, op); } +void vpminuw(const Xmm& x, const Operand& op) { vpminuw(x, x, op); } +void vpmuldq(const Xmm& x, const Operand& op) { vpmuldq(x, x, op); } +void vpmulhrsw(const Xmm& x, const Operand& op) { vpmulhrsw(x, x, op); } +void vpmulhuw(const Xmm& x, const Operand& op) { vpmulhuw(x, x, op); } +void vpmulhw(const Xmm& x, const Operand& op) { vpmulhw(x, x, op); } +void vpmulld(const Xmm& x, const Operand& op) { vpmulld(x, x, op); } +void vpmullw(const Xmm& x, const Operand& op) { vpmullw(x, x, op); } +void vpmuludq(const Xmm& x, const Operand& op) { vpmuludq(x, x, op); } +void vpor(const Xmm& x, const Operand& op) { vpor(x, x, op); } +void vpsadbw(const Xmm& x, const Operand& op) { vpsadbw(x, x, op); } +void vpsignb(const Xmm& x, const Operand& op) { vpsignb(x, x, op); } +void vpsignd(const Xmm& x, const Operand& op) { vpsignd(x, x, op); } +void vpsignw(const Xmm& x, const Operand& op) { vpsignw(x, x, op); } +void vpslld(const Xmm& x, const Operand& op) { vpslld(x, x, op); } +void vpslld(const Xmm& x, uint8_t imm) { vpslld(x, x, imm); } +void vpslldq(const Xmm& x, uint8_t imm) { vpslldq(x, x, imm); } +void vpsllq(const Xmm& x, const Operand& op) { vpsllq(x, x, op); } +void vpsllq(const Xmm& x, uint8_t imm) { vpsllq(x, x, imm); } +void vpsllw(const Xmm& x, const Operand& op) { vpsllw(x, x, op); } +void vpsllw(const Xmm& x, uint8_t imm) { vpsllw(x, x, imm); } +void vpsrad(const Xmm& x, const Operand& op) { vpsrad(x, x, op); } +void vpsrad(const Xmm& x, uint8_t imm) { vpsrad(x, x, imm); } +void vpsraw(const Xmm& x, const Operand& op) { vpsraw(x, x, op); } +void vpsraw(const Xmm& x, uint8_t imm) { vpsraw(x, x, imm); } +void vpsrld(const Xmm& x, const Operand& op) { vpsrld(x, x, op); } +void vpsrld(const Xmm& x, uint8_t imm) { vpsrld(x, x, imm); } +void vpsrldq(const Xmm& x, uint8_t imm) { vpsrldq(x, x, imm); } +void vpsrlq(const Xmm& x, const Operand& op) { vpsrlq(x, x, op); } +void vpsrlq(const Xmm& x, uint8_t imm) { vpsrlq(x, x, imm); } +void vpsrlw(const Xmm& x, const Operand& op) { vpsrlw(x, x, op); } +void vpsrlw(const Xmm& x, uint8_t imm) { vpsrlw(x, x, imm); } +void vpsubb(const Xmm& x, const Operand& op) { vpsubb(x, x, op); } +void vpsubd(const Xmm& x, const Operand& op) { vpsubd(x, x, op); } +void vpsubq(const Xmm& x, const Operand& op) { vpsubq(x, x, op); } +void vpsubsb(const Xmm& x, const Operand& op) { vpsubsb(x, x, op); } +void vpsubsw(const Xmm& x, const Operand& op) { vpsubsw(x, x, op); } +void vpsubusb(const Xmm& x, const Operand& op) { vpsubusb(x, x, op); } +void vpsubusw(const Xmm& x, const Operand& op) { vpsubusw(x, x, op); } +void vpsubw(const Xmm& x, const Operand& op) { vpsubw(x, x, op); } +void vpunpckhbw(const Xmm& x, const Operand& op) { vpunpckhbw(x, x, op); } +void vpunpckhdq(const Xmm& x, const Operand& op) { vpunpckhdq(x, x, op); } +void vpunpckhqdq(const Xmm& x, const Operand& op) { vpunpckhqdq(x, x, op); } +void vpunpckhwd(const Xmm& x, const Operand& op) { vpunpckhwd(x, x, op); } +void vpunpcklbw(const Xmm& x, const Operand& op) { vpunpcklbw(x, x, op); } +void vpunpckldq(const Xmm& x, const Operand& op) { vpunpckldq(x, x, op); } +void vpunpcklqdq(const Xmm& x, const Operand& op) { vpunpcklqdq(x, x, op); } +void vpunpcklwd(const Xmm& x, const Operand& op) { vpunpcklwd(x, x, op); } +void vpxor(const Xmm& x, const Operand& op) { vpxor(x, x, op); } +void vrcpss(const Xmm& x, const Operand& op) { vrcpss(x, x, op); } +void vroundsd(const Xmm& x, const Operand& op, uint8_t imm) { vroundsd(x, x, op, imm); } +void vroundss(const Xmm& x, const Operand& op, uint8_t imm) { vroundss(x, x, op, imm); } +void vrsqrtss(const Xmm& x, const Operand& op) { vrsqrtss(x, x, op); } +void vshufpd(const Xmm& x, const Operand& op, uint8_t imm) { vshufpd(x, x, op, imm); } +void vshufps(const Xmm& x, const Operand& op, uint8_t imm) { vshufps(x, x, op, imm); } +void vsqrtsd(const Xmm& x, const Operand& op) { vsqrtsd(x, x, op); } +void vsqrtss(const Xmm& x, const Operand& op) { vsqrtss(x, x, op); } +void vunpckhpd(const Xmm& x, const Operand& op) { vunpckhpd(x, x, op); } +void vunpckhps(const Xmm& x, const Operand& op) { vunpckhps(x, x, op); } +void vunpcklpd(const Xmm& x, const Operand& op) { vunpcklpd(x, x, op); } +void vunpcklps(const Xmm& x, const Operand& op) { vunpcklps(x, x, op); } +#endif +#ifdef XBYAK64 +void jecxz(std::string label) { db(0x67); opJmp(label, T_SHORT, 0xe3, 0, 0); } +void jecxz(const Label& label) { db(0x67); opJmp(label, T_SHORT, 0xe3, 0, 0); } +void jrcxz(std::string label) { opJmp(label, T_SHORT, 0xe3, 0, 0); } +void jrcxz(const Label& label) { opJmp(label, T_SHORT, 0xe3, 0, 0); } +void cdqe() { db(0x48); db(0x98); } +void cqo() { db(0x48); db(0x99); } +void cmpsq() { db(0x48); db(0xA7); } +void popfq() { db(0x9D); } +void pushfq() { db(0x9C); } +void lodsq() { db(0x48); db(0xAD); } +void movsq() { db(0x48); db(0xA5); } +void scasq() { db(0x48); db(0xAF); } +void stosq() { db(0x48); db(0xAB); } +void syscall() { db(0x0F); db(0x05); } +void sysret() { db(0x0F); db(0x07); } +void clui() { db(0xF3); db(0x0F); db(0x01); db(0xEE); } +void stui() { db(0xF3); db(0x0F); db(0x01); db(0xEF); } +void testui() { db(0xF3); db(0x0F); db(0x01); db(0xED); } +void uiret() { db(0xF3); db(0x0F); db(0x01); db(0xEC); } +void cmpxchg16b(const Address& addr) { opMR(addr, Reg64(1), T_0F, 0xC7); } +void fxrstor64(const Address& addr) { opMR(addr, Reg64(1), T_0F, 0xAE); } +void movq(const Reg64& reg, const Mmx& mmx) { if (mmx.isXMM()) db(0x66); opRR(mmx, reg, T_0F, 0x7E); } +void movq(const Mmx& mmx, const Reg64& reg) { if (mmx.isXMM()) db(0x66); opRR(mmx, reg, T_0F, 0x6E); } +void movsxd(const Reg64& reg, const Operand& op) { if (!op.isBit(32)) XBYAK_THROW(ERR_BAD_COMBINATION) opRO(reg, op, 0, 0x63); } +void pextrq(const Operand& op, const Xmm& xmm, uint8_t imm) { if (!op.isREG(64) && !op.isMEM()) XBYAK_THROW(ERR_BAD_COMBINATION) opSSE(Reg64(xmm.getIdx()), op, T_66 | T_0F3A, 0x16, 0, imm); } +void pinsrq(const Xmm& xmm, const Operand& op, uint8_t imm) { if (!op.isREG(64) && !op.isMEM()) XBYAK_THROW(ERR_BAD_COMBINATION) opSSE(Reg64(xmm.getIdx()), op, T_66 | T_0F3A, 0x22, 0, imm); } +void senduipi(const Reg64& r) { opRR(Reg32(6), r.cvt32(), T_F3 | T_0F, 0xC7); } +void vcvtss2si(const Reg64& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F3 | T_W1 | T_EVEX | T_EW1 | T_ER_X | T_N8, 0x2D); } +void vcvttss2si(const Reg64& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F3 | T_W1 | T_EVEX | T_EW1 | T_SAE_X | T_N8, 0x2C); } +void vcvtsd2si(const Reg64& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F2 | T_W1 | T_EVEX | T_EW1 | T_N4 | T_ER_X, 0x2D); } +void vcvttsd2si(const Reg64& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F2 | T_W1 | T_EVEX | T_EW1 | T_N4 | T_SAE_X, 0x2C); } +void vmovq(const Xmm& x, const Reg64& r) { opAVX_X_X_XM(x, xm0, Xmm(r.getIdx()), T_66 | T_0F | T_W1 | T_EVEX | T_EW1, 0x6E); } +void vmovq(const Reg64& r, const Xmm& x) { opAVX_X_X_XM(x, xm0, Xmm(r.getIdx()), T_66 | T_0F | T_W1 | T_EVEX | T_EW1, 0x7E); } +void jmpabs(uint64_t addr) { db(0xD5); db(0x00); db(0xA1); dq(addr); } +void push2(const Reg64& r1, const Reg64& r2) { opROO(r1, r2, Reg64(6), T_APX|T_ND1|T_W0, 0xFF); } +void push2p(const Reg64& r1, const Reg64& r2) { opROO(r1, r2, Reg64(6), T_APX|T_ND1|T_W1, 0xFF); } +void pop2(const Reg64& r1, const Reg64& r2) { opROO(r1, r2, Reg64(0), T_APX|T_ND1|T_W0, 0x8F); } +void pop2p(const Reg64& r1, const Reg64& r2) { opROO(r1, r2, Reg64(0), T_APX|T_ND1|T_W1, 0x8F); } +void cmpbexadd(const Address& addr, const Reg32e& r1, const Reg32e& r2) { opRRO(r1, r2, addr, T_APX|T_66|T_0F38, 0xE6); } +void cmpbxadd(const Address& addr, const Reg32e& r1, const Reg32e& r2) { opRRO(r1, r2, addr, T_APX|T_66|T_0F38, 0xE2); } +void cmplexadd(const Address& addr, const Reg32e& r1, const Reg32e& r2) { opRRO(r1, r2, addr, T_APX|T_66|T_0F38, 0xEE); } +void cmplxadd(const Address& addr, const Reg32e& r1, const Reg32e& r2) { opRRO(r1, r2, addr, T_APX|T_66|T_0F38, 0xEC); } +void cmpnbexadd(const Address& addr, const Reg32e& r1, const Reg32e& r2) { opRRO(r1, r2, addr, T_APX|T_66|T_0F38, 0xE7); } +void cmpnbxadd(const Address& addr, const Reg32e& r1, const Reg32e& r2) { opRRO(r1, r2, addr, T_APX|T_66|T_0F38, 0xE3); } +void cmpnlexadd(const Address& addr, const Reg32e& r1, const Reg32e& r2) { opRRO(r1, r2, addr, T_APX|T_66|T_0F38, 0xEF); } +void cmpnlxadd(const Address& addr, const Reg32e& r1, const Reg32e& r2) { opRRO(r1, r2, addr, T_APX|T_66|T_0F38, 0xED); } +void cmpnoxadd(const Address& addr, const Reg32e& r1, const Reg32e& r2) { opRRO(r1, r2, addr, T_APX|T_66|T_0F38, 0xE1); } +void cmpnpxadd(const Address& addr, const Reg32e& r1, const Reg32e& r2) { opRRO(r1, r2, addr, T_APX|T_66|T_0F38, 0xEB); } +void cmpnsxadd(const Address& addr, const Reg32e& r1, const Reg32e& r2) { opRRO(r1, r2, addr, T_APX|T_66|T_0F38, 0xE9); } +void cmpnzxadd(const Address& addr, const Reg32e& r1, const Reg32e& r2) { opRRO(r1, r2, addr, T_APX|T_66|T_0F38, 0xE5); } +void cmpoxadd(const Address& addr, const Reg32e& r1, const Reg32e& r2) { opRRO(r1, r2, addr, T_APX|T_66|T_0F38, 0xE0); } +void cmppxadd(const Address& addr, const Reg32e& r1, const Reg32e& r2) { opRRO(r1, r2, addr, T_APX|T_66|T_0F38, 0xEA); } +void cmpsxadd(const Address& addr, const Reg32e& r1, const Reg32e& r2) { opRRO(r1, r2, addr, T_APX|T_66|T_0F38, 0xE8); } +void cmpzxadd(const Address& addr, const Reg32e& r1, const Reg32e& r2) { opRRO(r1, r2, addr, T_APX|T_66|T_0F38, 0xE4); } +void aesdec128kl(const Xmm& x, const Address& addr) { opSSE_APX(x, addr, T_F3|T_0F38, 0xDD, T_F3|T_MUST_EVEX, 0xDD); } +void aesdec256kl(const Xmm& x, const Address& addr) { opSSE_APX(x, addr, T_F3|T_0F38, 0xDF, T_F3|T_MUST_EVEX, 0xDF); } +void aesdecwide128kl(const Address& addr) { opSSE_APX(xmm1, addr, T_F3|T_0F38, 0xD8, T_F3|T_MUST_EVEX, 0xD8); } +void aesdecwide256kl(const Address& addr) { opSSE_APX(xmm3, addr, T_F3|T_0F38, 0xD8, T_F3|T_MUST_EVEX, 0xD8); } +void aesenc128kl(const Xmm& x, const Address& addr) { opSSE_APX(x, addr, T_F3|T_0F38, 0xDC, T_F3|T_MUST_EVEX, 0xDC); } +void aesenc256kl(const Xmm& x, const Address& addr) { opSSE_APX(x, addr, T_F3|T_0F38, 0xDE, T_F3|T_MUST_EVEX, 0xDE); } +void aesencwide128kl(const Address& addr) { opSSE_APX(xmm0, addr, T_F3|T_0F38, 0xD8, T_F3|T_MUST_EVEX, 0xD8); } +void aesencwide256kl(const Address& addr) { opSSE_APX(xmm2, addr, T_F3|T_0F38, 0xD8, T_F3|T_MUST_EVEX, 0xD8); } +void encodekey128(const Reg32& r1, const Reg32& r2) { opEncodeKey(r1, r2, 0xFA, 0xDA); } +void encodekey256(const Reg32& r1, const Reg32& r2) { opEncodeKey(r1, r2, 0xFB, 0xDB); } +void ldtilecfg(const Address& addr) { if (opROO(Reg(), addr, tmm0, T_APX|T_0F38|T_W0, 0x49)) return; opVex(tmm0, &tmm0, addr, T_0F38|T_W0, 0x49); } +void sttilecfg(const Address& addr) { if (opROO(Reg(), addr, tmm0, T_APX|T_66|T_0F38|T_W0, 0x49)) return; opVex(tmm0, &tmm0, addr, T_66|T_0F38 | T_W0, 0x49); } +void tileloadd(const Tmm& tm, const Address& addr) { opAMX(tm, addr, T_F2|T_0F38|T_W0, 0x4B); } +void tileloaddt1(const Tmm& tm, const Address& addr) { opAMX(tm, addr, T_66|T_0F38|T_W0, 0x4B); } +void tilerelease() { db(0xc4); db(0xe2); db(0x78); db(0x49); db(0xc0); } +void tilestored(const Address& addr, const Tmm& tm) { if (opROO(Reg(), addr, tm, T_APX|T_F3|T_0F38|T_W0, 0x4B)) return; opVex(tm, &tmm0, addr, T_F3|T_0F38|T_W0, 0x4B); } +void tilezero(const Tmm& Tmm) { opVex(Tmm, &tmm0, tmm0, T_F2 | T_0F38 | T_W0, 0x49); } +void tdpbssd(const Tmm& x1, const Tmm& x2, const Tmm& x3) { opVex(x1, &x3, x2, T_F2 | T_0F38 | T_W0, 0x5e); } +void tdpbsud(const Tmm& x1, const Tmm& x2, const Tmm& x3) { opVex(x1, &x3, x2, T_F3 | T_0F38 | T_W0, 0x5e); } +void tdpbusd(const Tmm& x1, const Tmm& x2, const Tmm& x3) { opVex(x1, &x3, x2, T_66 | T_0F38 | T_W0, 0x5e); } +void tdpbuud(const Tmm& x1, const Tmm& x2, const Tmm& x3) { opVex(x1, &x3, x2, T_0F38 | T_W0, 0x5e); } +void tdpfp16ps(const Tmm &x1, const Tmm &x2, const Tmm &x3) { opVex(x1, &x3, x2, T_F2 | T_0F38 | T_W0, 0x5c); } +void tdpbf16ps(const Tmm& x1, const Tmm& x2, const Tmm& x3) { opVex(x1, &x3, x2, T_F3 | T_0F38 | T_W0, 0x5c); } +#else +void jcxz(std::string label) { db(0x67); opJmp(label, T_SHORT, 0xe3, 0, 0); } +void jcxz(const Label& label) { db(0x67); opJmp(label, T_SHORT, 0xe3, 0, 0); } +void jecxz(std::string label) { opJmp(label, T_SHORT, 0xe3, 0, 0); } +void jecxz(const Label& label) { opJmp(label, T_SHORT, 0xe3, 0, 0); } +void aaa() { db(0x37); } +void aad() { db(0xD5); db(0x0A); } +void aam() { db(0xD4); db(0x0A); } +void aas() { db(0x3F); } +void daa() { db(0x27); } +void das() { db(0x2F); } +void into() { db(0xCE); } +void popad() { db(0x61); } +void popfd() { db(0x9D); } +void pusha() { db(0x60); } +void pushad() { db(0x60); } +void pushfd() { db(0x9C); } +void popa() { db(0x61); } +void lds(const Reg& reg, const Address& addr) { opLoadSeg(addr, reg, T_NONE, 0xC5); } +void les(const Reg& reg, const Address& addr) { opLoadSeg(addr, reg, T_NONE, 0xC4); } +#endif +#ifndef XBYAK_NO_OP_NAMES +void and(const Operand& op1, const Operand& op2) { and_(op1, op2); } +void and(const Operand& op, uint32_t imm) { and_(op, imm); } +void or(const Operand& op1, const Operand& op2) { or_(op1, op2); } +void or(const Operand& op, uint32_t imm) { or_(op, imm); } +void xor(const Operand& op1, const Operand& op2) { xor_(op1, op2); } +void xor(const Operand& op, uint32_t imm) { xor_(op, imm); } +void not(const Operand& op) { not_(op); } +#endif +#ifndef XBYAK_DISABLE_AVX512 +void kaddb(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W0, 0x4A); } +void kaddd(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W1, 0x4A); } +void kaddq(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W1, 0x4A); } +void kaddw(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W0, 0x4A); } +void kandb(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W0, 0x41); } +void kandd(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W1, 0x41); } +void kandnb(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W0, 0x42); } +void kandnd(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W1, 0x42); } +void kandnq(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W1, 0x42); } +void kandnw(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W0, 0x42); } +void kandq(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W1, 0x41); } +void kandw(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W0, 0x41); } +void kmovb(const Address& addr, const Opmask& k) { opKmov(k, addr, true, 8); } +void kmovb(const Opmask& k, const Operand& op) { opKmov(k, op, false, 8); } +void kmovb(const Reg32& r, const Opmask& k) { opKmov(k, r, true, 8); } +void kmovd(const Address& addr, const Opmask& k) { opKmov(k, addr, true, 32); } +void kmovd(const Opmask& k, const Operand& op) { opKmov(k, op, false, 32); } +void kmovd(const Reg32& r, const Opmask& k) { opKmov(k, r, true, 32); } +void kmovq(const Address& addr, const Opmask& k) { opKmov(k, addr, true, 64); } +void kmovq(const Opmask& k, const Operand& op) { opKmov(k, op, false, 64); } +void kmovw(const Address& addr, const Opmask& k) { opKmov(k, addr, true, 16); } +void kmovw(const Opmask& k, const Operand& op) { opKmov(k, op, false, 16); } +void kmovw(const Reg32& r, const Opmask& k) { opKmov(k, r, true, 16); } +void knotb(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_66 | T_W0, 0x44); } +void knotd(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_66 | T_W1, 0x44); } +void knotq(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_W1, 0x44); } +void knotw(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_W0, 0x44); } +void korb(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W0, 0x45); } +void kord(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W1, 0x45); } +void korq(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W1, 0x45); } +void kortestb(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_66 | T_W0, 0x98); } +void kortestd(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_66 | T_W1, 0x98); } +void kortestq(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_W1, 0x98); } +void kortestw(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_W0, 0x98); } +void korw(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W0, 0x45); } +void kshiftlb(const Opmask& r1, const Opmask& r2, uint8_t imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W0, 0x32, imm); } +void kshiftld(const Opmask& r1, const Opmask& r2, uint8_t imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W0, 0x33, imm); } +void kshiftlq(const Opmask& r1, const Opmask& r2, uint8_t imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W1, 0x33, imm); } +void kshiftlw(const Opmask& r1, const Opmask& r2, uint8_t imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W1, 0x32, imm); } +void kshiftrb(const Opmask& r1, const Opmask& r2, uint8_t imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W0, 0x30, imm); } +void kshiftrd(const Opmask& r1, const Opmask& r2, uint8_t imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W0, 0x31, imm); } +void kshiftrq(const Opmask& r1, const Opmask& r2, uint8_t imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W1, 0x31, imm); } +void kshiftrw(const Opmask& r1, const Opmask& r2, uint8_t imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W1, 0x30, imm); } +void ktestb(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_66 | T_W0, 0x99); } +void ktestd(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_66 | T_W1, 0x99); } +void ktestq(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_W1, 0x99); } +void ktestw(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_W0, 0x99); } +void kunpckbw(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W0, 0x4B); } +void kunpckdq(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W1, 0x4B); } +void kunpckwd(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W0, 0x4B); } +void kxnorb(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W0, 0x46); } +void kxnord(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W1, 0x46); } +void kxnorq(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W1, 0x46); } +void kxnorw(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W0, 0x46); } +void kxorb(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W0, 0x47); } +void kxord(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W1, 0x47); } +void kxorq(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W1, 0x47); } +void kxorw(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W0, 0x47); } +void v4fmaddps(const Zmm& z1, const Zmm& z2, const Address& addr) { opAVX_X_X_XM(z1, z2, addr, T_0F38 | T_F2 | T_EW0 | T_YMM | T_MUST_EVEX | T_N16, 0x9A); } +void v4fmaddss(const Xmm& x1, const Xmm& x2, const Address& addr) { opAVX_X_X_XM(x1, x2, addr, T_0F38 | T_F2 | T_EW0 | T_MUST_EVEX | T_N16, 0x9B); } +void v4fnmaddps(const Zmm& z1, const Zmm& z2, const Address& addr) { opAVX_X_X_XM(z1, z2, addr, T_0F38 | T_F2 | T_EW0 | T_YMM | T_MUST_EVEX | T_N16, 0xAA); } +void v4fnmaddss(const Xmm& x1, const Xmm& x2, const Address& addr) { opAVX_X_X_XM(x1, x2, addr, T_0F38 | T_F2 | T_EW0 | T_MUST_EVEX | T_N16, 0xAB); } +void vaddph(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_MAP5 | T_EW0 | T_YMM | T_MUST_EVEX | T_ER_Z | T_B16, 0x58); } +void vaddsh(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_MAP5 | T_F3 | T_EW0 | T_MUST_EVEX | T_ER_X | T_N2, 0x58); } +void valignd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8_t imm) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F3A|T_EW0|T_YMM|T_MUST_EVEX, 0x03, imm); } +void valignq(const Xmm& x1, const Xmm& x2, const Operand& op, uint8_t imm) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F3A|T_EW1|T_YMM|T_MUST_EVEX, 0x03, imm); } +void vblendmpd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_EW1|T_YMM|T_MUST_EVEX|T_B64, 0x65); } +void vblendmps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_EW0|T_YMM|T_MUST_EVEX|T_B32, 0x65); } +void vbroadcastf32x2(const Ymm& y, const Operand& op) { opAVX_X_XM_IMM(y, op, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW0 | T_N8, 0x19); } +void vbroadcastf32x4(const Ymm& y, const Address& addr) { opAVX_X_XM_IMM(y, addr, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW0 | T_N16, 0x1A); } +void vbroadcastf32x8(const Zmm& y, const Address& addr) { opAVX_X_XM_IMM(y, addr, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW0 | T_N32, 0x1B); } +void vbroadcastf64x2(const Ymm& y, const Address& addr) { opAVX_X_XM_IMM(y, addr, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW1 | T_N16, 0x1A); } +void vbroadcastf64x4(const Zmm& y, const Address& addr) { opAVX_X_XM_IMM(y, addr, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW1 | T_N32, 0x1B); } +void vbroadcasti32x2(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW0 | T_N8, 0x59); } +void vbroadcasti32x4(const Ymm& y, const Operand& op) { opAVX_X_XM_IMM(y, op, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW0 | T_N16, 0x5A); } +void vbroadcasti32x8(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW0 | T_N32, 0x5B); } +void vbroadcasti64x2(const Ymm& y, const Operand& op) { opAVX_X_XM_IMM(y, op, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW1 | T_N16, 0x5A); } +void vbroadcasti64x4(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW1 | T_N32, 0x5B); } +void vcmpeq_ospd(const Opmask& k, const Xmm& x, const Operand& op) { vcmppd(k, x, op, 16); } +void vcmpeq_osps(const Opmask& k, const Xmm& x, const Operand& op) { vcmpps(k, x, op, 16); } +void vcmpeq_ossd(const Opmask& k, const Xmm& x, const Operand& op) { vcmpsd(k, x, op, 16); } +void vcmpeq_osss(const Opmask& k, const Xmm& x, const Operand& op) { vcmpss(k, x, op, 16); } +void vcmpeq_uqpd(const Opmask& k, const Xmm& x, const Operand& op) { vcmppd(k, x, op, 8); } +void vcmpeq_uqps(const Opmask& k, const Xmm& x, const Operand& op) { vcmpps(k, x, op, 8); } +void vcmpeq_uqsd(const Opmask& k, const Xmm& x, const Operand& op) { vcmpsd(k, x, op, 8); } +void vcmpeq_uqss(const Opmask& k, const Xmm& x, const Operand& op) { vcmpss(k, x, op, 8); } +void vcmpeq_uspd(const Opmask& k, const Xmm& x, const Operand& op) { vcmppd(k, x, op, 24); } +void vcmpeq_usps(const Opmask& k, const Xmm& x, const Operand& op) { vcmpps(k, x, op, 24); } +void vcmpeq_ussd(const Opmask& k, const Xmm& x, const Operand& op) { vcmpsd(k, x, op, 24); } +void vcmpeq_usss(const Opmask& k, const Xmm& x, const Operand& op) { vcmpss(k, x, op, 24); } +void vcmpeqpd(const Opmask& k, const Xmm& x, const Operand& op) { vcmppd(k, x, op, 0); } +void vcmpeqps(const Opmask& k, const Xmm& x, const Operand& op) { vcmpps(k, x, op, 0); } +void vcmpeqsd(const Opmask& k, const Xmm& x, const Operand& op) { vcmpsd(k, x, op, 0); } +void vcmpeqss(const Opmask& k, const Xmm& x, const Operand& op) { vcmpss(k, x, op, 0); } +void vcmpfalse_ospd(const Opmask& k, const Xmm& x, const Operand& op) { vcmppd(k, x, op, 27); } +void vcmpfalse_osps(const Opmask& k, const Xmm& x, const Operand& op) { vcmpps(k, x, op, 27); } +void vcmpfalse_ossd(const Opmask& k, const Xmm& x, const Operand& op) { vcmpsd(k, x, op, 27); } +void vcmpfalse_osss(const Opmask& k, const Xmm& x, const Operand& op) { vcmpss(k, x, op, 27); } +void vcmpfalsepd(const Opmask& k, const Xmm& x, const Operand& op) { vcmppd(k, x, op, 11); } +void vcmpfalseps(const Opmask& k, const Xmm& x, const Operand& op) { vcmpps(k, x, op, 11); } +void vcmpfalsesd(const Opmask& k, const Xmm& x, const Operand& op) { vcmpsd(k, x, op, 11); } +void vcmpfalsess(const Opmask& k, const Xmm& x, const Operand& op) { vcmpss(k, x, op, 11); } +void vcmpge_oqpd(const Opmask& k, const Xmm& x, const Operand& op) { vcmppd(k, x, op, 29); } +void vcmpge_oqps(const Opmask& k, const Xmm& x, const Operand& op) { vcmpps(k, x, op, 29); } +void vcmpge_oqsd(const Opmask& k, const Xmm& x, const Operand& op) { vcmpsd(k, x, op, 29); } +void vcmpge_oqss(const Opmask& k, const Xmm& x, const Operand& op) { vcmpss(k, x, op, 29); } +void vcmpgepd(const Opmask& k, const Xmm& x, const Operand& op) { vcmppd(k, x, op, 13); } +void vcmpgeps(const Opmask& k, const Xmm& x, const Operand& op) { vcmpps(k, x, op, 13); } +void vcmpgesd(const Opmask& k, const Xmm& x, const Operand& op) { vcmpsd(k, x, op, 13); } +void vcmpgess(const Opmask& k, const Xmm& x, const Operand& op) { vcmpss(k, x, op, 13); } +void vcmpgt_oqpd(const Opmask& k, const Xmm& x, const Operand& op) { vcmppd(k, x, op, 30); } +void vcmpgt_oqps(const Opmask& k, const Xmm& x, const Operand& op) { vcmpps(k, x, op, 30); } +void vcmpgt_oqsd(const Opmask& k, const Xmm& x, const Operand& op) { vcmpsd(k, x, op, 30); } +void vcmpgt_oqss(const Opmask& k, const Xmm& x, const Operand& op) { vcmpss(k, x, op, 30); } +void vcmpgtpd(const Opmask& k, const Xmm& x, const Operand& op) { vcmppd(k, x, op, 14); } +void vcmpgtps(const Opmask& k, const Xmm& x, const Operand& op) { vcmpps(k, x, op, 14); } +void vcmpgtsd(const Opmask& k, const Xmm& x, const Operand& op) { vcmpsd(k, x, op, 14); } +void vcmpgtss(const Opmask& k, const Xmm& x, const Operand& op) { vcmpss(k, x, op, 14); } +void vcmple_oqpd(const Opmask& k, const Xmm& x, const Operand& op) { vcmppd(k, x, op, 18); } +void vcmple_oqps(const Opmask& k, const Xmm& x, const Operand& op) { vcmpps(k, x, op, 18); } +void vcmple_oqsd(const Opmask& k, const Xmm& x, const Operand& op) { vcmpsd(k, x, op, 18); } +void vcmple_oqss(const Opmask& k, const Xmm& x, const Operand& op) { vcmpss(k, x, op, 18); } +void vcmplepd(const Opmask& k, const Xmm& x, const Operand& op) { vcmppd(k, x, op, 2); } +void vcmpleps(const Opmask& k, const Xmm& x, const Operand& op) { vcmpps(k, x, op, 2); } +void vcmplesd(const Opmask& k, const Xmm& x, const Operand& op) { vcmpsd(k, x, op, 2); } +void vcmpless(const Opmask& k, const Xmm& x, const Operand& op) { vcmpss(k, x, op, 2); } +void vcmplt_oqpd(const Opmask& k, const Xmm& x, const Operand& op) { vcmppd(k, x, op, 17); } +void vcmplt_oqps(const Opmask& k, const Xmm& x, const Operand& op) { vcmpps(k, x, op, 17); } +void vcmplt_oqsd(const Opmask& k, const Xmm& x, const Operand& op) { vcmpsd(k, x, op, 17); } +void vcmplt_oqss(const Opmask& k, const Xmm& x, const Operand& op) { vcmpss(k, x, op, 17); } +void vcmpltpd(const Opmask& k, const Xmm& x, const Operand& op) { vcmppd(k, x, op, 1); } +void vcmpltps(const Opmask& k, const Xmm& x, const Operand& op) { vcmpps(k, x, op, 1); } +void vcmpltsd(const Opmask& k, const Xmm& x, const Operand& op) { vcmpsd(k, x, op, 1); } +void vcmpltss(const Opmask& k, const Xmm& x, const Operand& op) { vcmpss(k, x, op, 1); } +void vcmpneq_oqpd(const Opmask& k, const Xmm& x, const Operand& op) { vcmppd(k, x, op, 12); } +void vcmpneq_oqps(const Opmask& k, const Xmm& x, const Operand& op) { vcmpps(k, x, op, 12); } +void vcmpneq_oqsd(const Opmask& k, const Xmm& x, const Operand& op) { vcmpsd(k, x, op, 12); } +void vcmpneq_oqss(const Opmask& k, const Xmm& x, const Operand& op) { vcmpss(k, x, op, 12); } +void vcmpneq_ospd(const Opmask& k, const Xmm& x, const Operand& op) { vcmppd(k, x, op, 28); } +void vcmpneq_osps(const Opmask& k, const Xmm& x, const Operand& op) { vcmpps(k, x, op, 28); } +void vcmpneq_ossd(const Opmask& k, const Xmm& x, const Operand& op) { vcmpsd(k, x, op, 28); } +void vcmpneq_osss(const Opmask& k, const Xmm& x, const Operand& op) { vcmpss(k, x, op, 28); } +void vcmpneq_uspd(const Opmask& k, const Xmm& x, const Operand& op) { vcmppd(k, x, op, 20); } +void vcmpneq_usps(const Opmask& k, const Xmm& x, const Operand& op) { vcmpps(k, x, op, 20); } +void vcmpneq_ussd(const Opmask& k, const Xmm& x, const Operand& op) { vcmpsd(k, x, op, 20); } +void vcmpneq_usss(const Opmask& k, const Xmm& x, const Operand& op) { vcmpss(k, x, op, 20); } +void vcmpneqpd(const Opmask& k, const Xmm& x, const Operand& op) { vcmppd(k, x, op, 4); } +void vcmpneqps(const Opmask& k, const Xmm& x, const Operand& op) { vcmpps(k, x, op, 4); } +void vcmpneqsd(const Opmask& k, const Xmm& x, const Operand& op) { vcmpsd(k, x, op, 4); } +void vcmpneqss(const Opmask& k, const Xmm& x, const Operand& op) { vcmpss(k, x, op, 4); } +void vcmpnge_uqpd(const Opmask& k, const Xmm& x, const Operand& op) { vcmppd(k, x, op, 25); } +void vcmpnge_uqps(const Opmask& k, const Xmm& x, const Operand& op) { vcmpps(k, x, op, 25); } +void vcmpnge_uqsd(const Opmask& k, const Xmm& x, const Operand& op) { vcmpsd(k, x, op, 25); } +void vcmpnge_uqss(const Opmask& k, const Xmm& x, const Operand& op) { vcmpss(k, x, op, 25); } +void vcmpngepd(const Opmask& k, const Xmm& x, const Operand& op) { vcmppd(k, x, op, 9); } +void vcmpngeps(const Opmask& k, const Xmm& x, const Operand& op) { vcmpps(k, x, op, 9); } +void vcmpngesd(const Opmask& k, const Xmm& x, const Operand& op) { vcmpsd(k, x, op, 9); } +void vcmpngess(const Opmask& k, const Xmm& x, const Operand& op) { vcmpss(k, x, op, 9); } +void vcmpngt_uqpd(const Opmask& k, const Xmm& x, const Operand& op) { vcmppd(k, x, op, 26); } +void vcmpngt_uqps(const Opmask& k, const Xmm& x, const Operand& op) { vcmpps(k, x, op, 26); } +void vcmpngt_uqsd(const Opmask& k, const Xmm& x, const Operand& op) { vcmpsd(k, x, op, 26); } +void vcmpngt_uqss(const Opmask& k, const Xmm& x, const Operand& op) { vcmpss(k, x, op, 26); } +void vcmpngtpd(const Opmask& k, const Xmm& x, const Operand& op) { vcmppd(k, x, op, 10); } +void vcmpngtps(const Opmask& k, const Xmm& x, const Operand& op) { vcmpps(k, x, op, 10); } +void vcmpngtsd(const Opmask& k, const Xmm& x, const Operand& op) { vcmpsd(k, x, op, 10); } +void vcmpngtss(const Opmask& k, const Xmm& x, const Operand& op) { vcmpss(k, x, op, 10); } +void vcmpnle_uqpd(const Opmask& k, const Xmm& x, const Operand& op) { vcmppd(k, x, op, 22); } +void vcmpnle_uqps(const Opmask& k, const Xmm& x, const Operand& op) { vcmpps(k, x, op, 22); } +void vcmpnle_uqsd(const Opmask& k, const Xmm& x, const Operand& op) { vcmpsd(k, x, op, 22); } +void vcmpnle_uqss(const Opmask& k, const Xmm& x, const Operand& op) { vcmpss(k, x, op, 22); } +void vcmpnlepd(const Opmask& k, const Xmm& x, const Operand& op) { vcmppd(k, x, op, 6); } +void vcmpnleps(const Opmask& k, const Xmm& x, const Operand& op) { vcmpps(k, x, op, 6); } +void vcmpnlesd(const Opmask& k, const Xmm& x, const Operand& op) { vcmpsd(k, x, op, 6); } +void vcmpnless(const Opmask& k, const Xmm& x, const Operand& op) { vcmpss(k, x, op, 6); } +void vcmpnlt_uqpd(const Opmask& k, const Xmm& x, const Operand& op) { vcmppd(k, x, op, 21); } +void vcmpnlt_uqps(const Opmask& k, const Xmm& x, const Operand& op) { vcmpps(k, x, op, 21); } +void vcmpnlt_uqsd(const Opmask& k, const Xmm& x, const Operand& op) { vcmpsd(k, x, op, 21); } +void vcmpnlt_uqss(const Opmask& k, const Xmm& x, const Operand& op) { vcmpss(k, x, op, 21); } +void vcmpnltpd(const Opmask& k, const Xmm& x, const Operand& op) { vcmppd(k, x, op, 5); } +void vcmpnltps(const Opmask& k, const Xmm& x, const Operand& op) { vcmpps(k, x, op, 5); } +void vcmpnltsd(const Opmask& k, const Xmm& x, const Operand& op) { vcmpsd(k, x, op, 5); } +void vcmpnltss(const Opmask& k, const Xmm& x, const Operand& op) { vcmpss(k, x, op, 5); } +void vcmpord_spd(const Opmask& k, const Xmm& x, const Operand& op) { vcmppd(k, x, op, 23); } +void vcmpord_sps(const Opmask& k, const Xmm& x, const Operand& op) { vcmpps(k, x, op, 23); } +void vcmpord_ssd(const Opmask& k, const Xmm& x, const Operand& op) { vcmpsd(k, x, op, 23); } +void vcmpord_sss(const Opmask& k, const Xmm& x, const Operand& op) { vcmpss(k, x, op, 23); } +void vcmpordpd(const Opmask& k, const Xmm& x, const Operand& op) { vcmppd(k, x, op, 7); } +void vcmpordps(const Opmask& k, const Xmm& x, const Operand& op) { vcmpps(k, x, op, 7); } +void vcmpordsd(const Opmask& k, const Xmm& x, const Operand& op) { vcmpsd(k, x, op, 7); } +void vcmpordss(const Opmask& k, const Xmm& x, const Operand& op) { vcmpss(k, x, op, 7); } +void vcmppd(const Opmask& k, const Xmm& x, const Operand& op, uint8_t imm) { opAVX_K_X_XM(k, x, op, T_66|T_0F|T_EW1|T_YMM|T_SAE_Z|T_MUST_EVEX|T_B64, 0xC2, imm); } +void vcmpph(const Opmask& k, const Xmm& x, const Operand& op, uint8_t imm) { opAVX_K_X_XM(k, x, op, T_0F3A|T_EW0|T_YMM|T_SAE_Z|T_MUST_EVEX|T_B16, 0xC2, imm); } +void vcmpps(const Opmask& k, const Xmm& x, const Operand& op, uint8_t imm) { opAVX_K_X_XM(k, x, op, T_0F|T_EW0|T_YMM|T_SAE_Z|T_MUST_EVEX|T_B32, 0xC2, imm); } +void vcmpsd(const Opmask& k, const Xmm& x, const Operand& op, uint8_t imm) { opAVX_K_X_XM(k, x, op, T_N8|T_F2|T_0F|T_EW1|T_SAE_Z|T_MUST_EVEX, 0xC2, imm); } +void vcmpsh(const Opmask& k, const Xmm& x, const Operand& op, uint8_t imm) { opAVX_K_X_XM(k, x, op, T_N2|T_F3|T_0F3A|T_EW0|T_SAE_X|T_MUST_EVEX, 0xC2, imm); } +void vcmpss(const Opmask& k, const Xmm& x, const Operand& op, uint8_t imm) { opAVX_K_X_XM(k, x, op, T_N4|T_F3|T_0F|T_EW0|T_SAE_Z|T_MUST_EVEX, 0xC2, imm); } +void vcmptrue_uspd(const Opmask& k, const Xmm& x, const Operand& op) { vcmppd(k, x, op, 31); } +void vcmptrue_usps(const Opmask& k, const Xmm& x, const Operand& op) { vcmpps(k, x, op, 31); } +void vcmptrue_ussd(const Opmask& k, const Xmm& x, const Operand& op) { vcmpsd(k, x, op, 31); } +void vcmptrue_usss(const Opmask& k, const Xmm& x, const Operand& op) { vcmpss(k, x, op, 31); } +void vcmptruepd(const Opmask& k, const Xmm& x, const Operand& op) { vcmppd(k, x, op, 15); } +void vcmptrueps(const Opmask& k, const Xmm& x, const Operand& op) { vcmpps(k, x, op, 15); } +void vcmptruesd(const Opmask& k, const Xmm& x, const Operand& op) { vcmpsd(k, x, op, 15); } +void vcmptruess(const Opmask& k, const Xmm& x, const Operand& op) { vcmpss(k, x, op, 15); } +void vcmpunord_spd(const Opmask& k, const Xmm& x, const Operand& op) { vcmppd(k, x, op, 19); } +void vcmpunord_sps(const Opmask& k, const Xmm& x, const Operand& op) { vcmpps(k, x, op, 19); } +void vcmpunord_ssd(const Opmask& k, const Xmm& x, const Operand& op) { vcmpsd(k, x, op, 19); } +void vcmpunord_sss(const Opmask& k, const Xmm& x, const Operand& op) { vcmpss(k, x, op, 19); } +void vcmpunordpd(const Opmask& k, const Xmm& x, const Operand& op) { vcmppd(k, x, op, 3); } +void vcmpunordps(const Opmask& k, const Xmm& x, const Operand& op) { vcmpps(k, x, op, 3); } +void vcmpunordsd(const Opmask& k, const Xmm& x, const Operand& op) { vcmpsd(k, x, op, 3); } +void vcmpunordss(const Opmask& k, const Xmm& x, const Operand& op) { vcmpss(k, x, op, 3); } +void vcomish(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_MAP5 | T_MUST_EVEX | T_EW0 | T_SAE_X | T_N2, 0x2F); } +void vcompressb(const Operand& op, const Xmm& x) { opAVX_X_XM_IMM(x, op, T_N1|T_66|T_0F38|T_EW0|T_YMM|T_MUST_EVEX, 0x63); } +void vcompresspd(const Operand& op, const Xmm& x) { opAVX_X_XM_IMM(x, op, T_N8|T_66|T_0F38|T_EW1|T_YMM|T_MUST_EVEX, 0x8A); } +void vcompressps(const Operand& op, const Xmm& x) { opAVX_X_XM_IMM(x, op, T_N4|T_66|T_0F38|T_EW0|T_YMM|T_MUST_EVEX, 0x8A); } +void vcompressw(const Operand& op, const Xmm& x) { opAVX_X_XM_IMM(x, op, T_N2|T_66|T_0F38|T_EW1|T_YMM|T_MUST_EVEX, 0x63); } +void vcvtdq2ph(const Xmm& x, const Operand& op) { checkCvt4(x, op); opCvt(x, op, T_N16|T_N_VL|T_MAP5|T_EW0|T_YMM|T_ER_Z|T_MUST_EVEX|T_B32, 0x5B); } +void vcvtne2ps2bf16(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_F2|T_0F38|T_EW0|T_YMM|T_SAE_Z|T_MUST_EVEX|T_B32, 0x72); } +void vcvtpd2ph(const Xmm& x, const Operand& op) { opCvt5(x, op, T_N16|T_N_VL|T_66|T_MAP5|T_EW1|T_ER_Z|T_MUST_EVEX|T_B64, 0x5A); } +void vcvtpd2qq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66|T_0F|T_EW1|T_YMM|T_ER_Z|T_MUST_EVEX|T_B64, 0x7B); } +void vcvtpd2udq(const Xmm& x, const Operand& op) { opCvt2(x, op, T_0F|T_EW1|T_YMM|T_ER_Z|T_MUST_EVEX|T_B64, 0x79); } +void vcvtpd2uqq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66|T_0F|T_EW1|T_YMM|T_ER_Z|T_MUST_EVEX|T_B64, 0x79); } +void vcvtph2dq(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_N8|T_N_VL|T_66|T_MAP5|T_EW0|T_YMM|T_ER_Y|T_MUST_EVEX|T_B16, 0x5B); } +void vcvtph2pd(const Xmm& x, const Operand& op) { if (!op.isXMM() && !op.isMEM()) XBYAK_THROW(ERR_BAD_MEM_SIZE) opVex(x, 0, op, T_N4|T_N_VL|T_MAP5|T_EW0|T_YMM|T_SAE_X|T_MUST_EVEX|T_B16, 0x5A); } +void vcvtph2psx(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_N8|T_N_VL|T_66|T_MAP6|T_EW0|T_YMM|T_SAE_Y|T_MUST_EVEX|T_B16, 0x13); } +void vcvtph2qq(const Xmm& x, const Operand& op) { if (!op.isXMM() && !op.isMEM()) XBYAK_THROW(ERR_BAD_MEM_SIZE) opVex(x, 0, op, T_N4|T_N_VL|T_66|T_MAP5|T_EW0|T_YMM|T_ER_X|T_MUST_EVEX|T_B16, 0x7B); } +void vcvtph2udq(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_N8|T_N_VL|T_MAP5|T_EW0|T_YMM|T_ER_Y|T_MUST_EVEX|T_B16, 0x79); } +void vcvtph2uqq(const Xmm& x, const Operand& op) { if (!op.isXMM() && !op.isMEM()) XBYAK_THROW(ERR_BAD_MEM_SIZE) opVex(x, 0, op, T_N4|T_N_VL|T_66|T_MAP5|T_EW0|T_YMM|T_ER_X|T_MUST_EVEX|T_B16, 0x79); } +void vcvtph2uw(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_MAP5|T_EW0|T_YMM|T_ER_Z|T_MUST_EVEX|T_B16, 0x7D); } +void vcvtph2w(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66|T_MAP5|T_EW0|T_YMM|T_ER_Z|T_MUST_EVEX|T_B16, 0x7D); } +void vcvtps2phx(const Xmm& x, const Operand& op) { checkCvt4(x, op); opCvt(x, op, T_N16|T_N_VL|T_66|T_MAP5|T_EW0|T_ER_Z|T_MUST_EVEX|T_B32, 0x1D); } +void vcvtps2qq(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_N8|T_N_VL|T_66|T_0F|T_EW0|T_YMM|T_ER_Y|T_MUST_EVEX|T_B32, 0x7B); } +void vcvtps2udq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_0F|T_EW0|T_YMM|T_ER_Z|T_MUST_EVEX|T_B32, 0x79); } +void vcvtps2uqq(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_N8|T_N_VL|T_66|T_0F|T_EW0|T_YMM|T_ER_Y|T_MUST_EVEX|T_B32, 0x79); } +void vcvtqq2pd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_F3|T_0F|T_EW1|T_YMM|T_ER_Z|T_MUST_EVEX|T_B64, 0xE6); } +void vcvtqq2ph(const Xmm& x, const Operand& op) { opCvt5(x, op, T_N16|T_N_VL|T_MAP5|T_EW1|T_ER_Z|T_MUST_EVEX|T_B64, 0x5B); } +void vcvtqq2ps(const Xmm& x, const Operand& op) { opCvt2(x, op, T_0F|T_EW1|T_YMM|T_ER_Z|T_MUST_EVEX|T_B64, 0x5B); } +void vcvtsd2sh(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8|T_F2|T_MAP5|T_EW1|T_ER_X|T_MUST_EVEX, 0x5A); } +void vcvtsd2usi(const Reg32e& r, const Operand& op) { uint64_t type = (T_N8|T_F2|T_0F|T_ER_X|T_MUST_EVEX) | (r.isREG(64) ? T_EW1 : T_EW0); opVex(r, &xm0, op, type, 0x79); } +void vcvtsh2sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N2|T_F3|T_MAP5|T_EW0|T_SAE_X|T_MUST_EVEX, 0x5A); } +void vcvtsh2si(const Reg32e& r, const Operand& op) { uint64_t type = (T_N2|T_F3|T_MAP5|T_ER_X|T_MUST_EVEX) | (r.isREG(64) ? T_EW1 : T_EW0); opVex(r, &xm0, op, type, 0x2D); } +void vcvtsh2ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N2|T_MAP6|T_EW0|T_SAE_X|T_MUST_EVEX, 0x13); } +void vcvtsh2usi(const Reg32e& r, const Operand& op) { uint64_t type = (T_N2|T_F3|T_MAP5|T_ER_X|T_MUST_EVEX) | (r.isREG(64) ? T_EW1 : T_EW0); opVex(r, &xm0, op, type, 0x79); } +void vcvtsi2sh(const Xmm& x1, const Xmm& x2, const Operand& op) { if (!(x1.isXMM() && x2.isXMM() && op.isBit(32|64))) XBYAK_THROW(ERR_BAD_COMBINATION) uint64_t type = (T_F3|T_MAP5|T_ER_R|T_MUST_EVEX|T_M_K) | (op.isBit(32) ? (T_EW0 | T_N4) : (T_EW1 | T_N8)); opVex(x1, &x2, op, type, 0x2A); } +void vcvtss2sh(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4|T_MAP5|T_EW0|T_ER_X|T_MUST_EVEX, 0x1D); } +void vcvtss2usi(const Reg32e& r, const Operand& op) { uint64_t type = (T_N4|T_F3|T_0F|T_ER_X|T_MUST_EVEX) | (r.isREG(64) ? T_EW1 : T_EW0); opVex(r, &xm0, op, type, 0x79); } +void vcvttpd2qq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66|T_0F|T_EW1|T_YMM|T_SAE_Z|T_MUST_EVEX|T_B64, 0x7A); } +void vcvttpd2udq(const Xmm& x, const Operand& op) { opCvt2(x, op, T_0F|T_EW1|T_YMM|T_SAE_Z|T_MUST_EVEX|T_B64, 0x78); } +void vcvttpd2uqq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66|T_0F|T_EW1|T_YMM|T_SAE_Z|T_MUST_EVEX|T_B64, 0x78); } +void vcvttph2dq(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_N8|T_N_VL|T_F3|T_MAP5|T_EW0|T_YMM|T_SAE_Y|T_MUST_EVEX|T_B16, 0x5B); } +void vcvttph2qq(const Xmm& x, const Operand& op) { if (!op.isXMM() && !op.isMEM()) XBYAK_THROW(ERR_BAD_MEM_SIZE) opVex(x, 0, op, T_N4|T_N_VL|T_66|T_MAP5|T_EW0|T_YMM|T_SAE_X|T_MUST_EVEX|T_B16, 0x7A); } +void vcvttph2udq(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_N8|T_N_VL|T_MAP5|T_EW0|T_YMM|T_SAE_Y|T_MUST_EVEX|T_B16, 0x78); } +void vcvttph2uqq(const Xmm& x, const Operand& op) { if (!op.isXMM() && !op.isMEM()) XBYAK_THROW(ERR_BAD_MEM_SIZE) opVex(x, 0, op, T_N4|T_N_VL|T_66|T_MAP5|T_EW0|T_YMM|T_SAE_X|T_MUST_EVEX|T_B16, 0x78); } +void vcvttph2uw(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_MAP5|T_EW0|T_YMM|T_SAE_Z|T_MUST_EVEX|T_B16, 0x7C); } +void vcvttph2w(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66|T_MAP5|T_EW0|T_YMM|T_SAE_Z|T_MUST_EVEX|T_B16, 0x7C); } +void vcvttps2qq(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_N8|T_N_VL|T_66|T_0F|T_EW0|T_YMM|T_SAE_Y|T_MUST_EVEX|T_B32, 0x7A); } +void vcvttps2udq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_0F|T_EW0|T_YMM|T_SAE_Z|T_MUST_EVEX|T_B32, 0x78); } +void vcvttps2uqq(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_N8|T_N_VL|T_66|T_0F|T_EW0|T_YMM|T_SAE_Y|T_MUST_EVEX|T_B32, 0x78); } +void vcvttsd2usi(const Reg32e& r, const Operand& op) { uint64_t type = (T_N8|T_F2|T_0F|T_SAE_X|T_MUST_EVEX) | (r.isREG(64) ? T_EW1 : T_EW0); opVex(r, &xm0, op, type, 0x78); } +void vcvttsh2si(const Reg32e& r, const Operand& op) { uint64_t type = (T_N2|T_F3|T_MAP5|T_EW0|T_SAE_X|T_MUST_EVEX) | (r.isREG(64) ? T_EW1 : T_EW0); opVex(r, &xm0, op, type, 0x2C); } +void vcvttsh2usi(const Reg32e& r, const Operand& op) { uint64_t type = (T_N2|T_F3|T_MAP5|T_EW0|T_SAE_X|T_MUST_EVEX) | (r.isREG(64) ? T_EW1 : T_EW0); opVex(r, &xm0, op, type, 0x78); } +void vcvttss2usi(const Reg32e& r, const Operand& op) { uint64_t type = (T_N4|T_F3|T_0F|T_SAE_X|T_MUST_EVEX) | (r.isREG(64) ? T_EW1 : T_EW0); opVex(r, &xm0, op, type, 0x78); } +void vcvtudq2pd(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_N8|T_N_VL|T_F3|T_0F|T_EW0|T_YMM|T_MUST_EVEX|T_B32, 0x7A); } +void vcvtudq2ph(const Xmm& x, const Operand& op) { checkCvt4(x, op); opCvt(x, op, T_N16|T_N_VL|T_F2|T_MAP5|T_EW0|T_ER_Z|T_MUST_EVEX|T_B32, 0x7A); } +void vcvtudq2ps(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_F2|T_0F|T_EW0|T_YMM|T_ER_Z|T_MUST_EVEX|T_B32, 0x7A); } +void vcvtuqq2pd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_F3|T_0F|T_EW1|T_YMM|T_ER_Z|T_MUST_EVEX|T_B64, 0x7A); } +void vcvtuqq2ph(const Xmm& x, const Operand& op) { opCvt5(x, op, T_N16|T_N_VL|T_F2|T_MAP5|T_EW1|T_ER_Z|T_MUST_EVEX|T_B64, 0x7A); } +void vcvtuqq2ps(const Xmm& x, const Operand& op) { opCvt2(x, op, T_F2|T_0F|T_EW1|T_YMM|T_ER_Z|T_MUST_EVEX|T_B64, 0x7A); } +void vcvtusi2sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opCvt3(x1, x2, op, T_F2 | T_0F | T_MUST_EVEX, T_W1 | T_EW1 | T_ER_X | T_N8, T_W0 | T_EW0 | T_N4, 0x7B); } +void vcvtusi2sh(const Xmm& x1, const Xmm& x2, const Operand& op) { if (!(x1.isXMM() && x2.isXMM() && op.isBit(32|64))) XBYAK_THROW(ERR_BAD_COMBINATION) uint64_t type = (T_F3|T_MAP5|T_ER_R|T_MUST_EVEX|T_M_K) | (op.isBit(32) ? (T_EW0 | T_N4) : (T_EW1 | T_N8)); opVex(x1, &x2, op, type, 0x7B); } +void vcvtusi2ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opCvt3(x1, x2, op, T_F3 | T_0F | T_MUST_EVEX | T_ER_X, T_W1 | T_EW1 | T_N8, T_W0 | T_EW0 | T_N4, 0x7B); } +void vcvtuw2ph(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_F2|T_MAP5|T_EW0|T_YMM|T_ER_Z|T_MUST_EVEX|T_B16, 0x7D); } +void vcvtw2ph(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_F3|T_MAP5|T_EW0|T_YMM|T_ER_Z|T_MUST_EVEX|T_B16, 0x7D); } +void vdbpsadbw(const Xmm& x1, const Xmm& x2, const Operand& op, uint8_t imm) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F3A|T_EW0|T_YMM|T_MUST_EVEX, 0x42, imm); } +void vdivph(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_MAP5 | T_EW0 | T_YMM | T_MUST_EVEX | T_ER_Z | T_B16, 0x5E); } +void vdivsh(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_MAP5 | T_F3 | T_EW0 | T_MUST_EVEX | T_ER_X | T_N2, 0x5E); } +void vdpbf16ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_F3|T_0F38|T_EW0|T_YMM|T_SAE_Z|T_MUST_EVEX|T_B32, 0x52); } +void vexp2pd(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW1 | T_B64 | T_SAE_Z, 0xC8); } +void vexp2ps(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW0 | T_B32 | T_SAE_Z, 0xC8); } +void vexpandpd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_N8|T_66|T_0F38|T_EW1|T_YMM|T_MUST_EVEX, 0x88); } +void vexpandps(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_N4|T_66|T_0F38|T_EW0|T_YMM|T_MUST_EVEX, 0x88); } +void vextractf32x4(const Operand& op, const Ymm& r, uint8_t imm) { if (!op.is(Operand::MEM | Operand::XMM)) XBYAK_THROW(ERR_BAD_COMBINATION) opVex(r, 0, op, T_N16|T_66|T_0F3A|T_EW0|T_YMM|T_MUST_EVEX, 0x19, imm); } +void vextractf32x8(const Operand& op, const Zmm& r, uint8_t imm) { if (!op.is(Operand::MEM | Operand::YMM)) XBYAK_THROW(ERR_BAD_COMBINATION) opVex(r, 0, op, T_N32|T_66|T_0F3A|T_EW0|T_YMM|T_MUST_EVEX, 0x1B, imm); } +void vextractf64x2(const Operand& op, const Ymm& r, uint8_t imm) { if (!op.is(Operand::MEM | Operand::XMM)) XBYAK_THROW(ERR_BAD_COMBINATION) opVex(r, 0, op, T_N16|T_66|T_0F3A|T_EW1|T_YMM|T_MUST_EVEX, 0x19, imm); } +void vextractf64x4(const Operand& op, const Zmm& r, uint8_t imm) { if (!op.is(Operand::MEM | Operand::YMM)) XBYAK_THROW(ERR_BAD_COMBINATION) opVex(r, 0, op, T_N32|T_66|T_0F3A|T_EW1|T_YMM|T_MUST_EVEX, 0x1B, imm); } +void vextracti32x4(const Operand& op, const Ymm& r, uint8_t imm) { if (!op.is(Operand::MEM | Operand::XMM)) XBYAK_THROW(ERR_BAD_COMBINATION) opVex(r, 0, op, T_N16|T_66|T_0F3A|T_EW0|T_YMM|T_MUST_EVEX, 0x39, imm); } +void vextracti32x8(const Operand& op, const Zmm& r, uint8_t imm) { if (!op.is(Operand::MEM | Operand::YMM)) XBYAK_THROW(ERR_BAD_COMBINATION) opVex(r, 0, op, T_N32|T_66|T_0F3A|T_EW0|T_YMM|T_MUST_EVEX, 0x3B, imm); } +void vextracti64x2(const Operand& op, const Ymm& r, uint8_t imm) { if (!op.is(Operand::MEM | Operand::XMM)) XBYAK_THROW(ERR_BAD_COMBINATION) opVex(r, 0, op, T_N16|T_66|T_0F3A|T_EW1|T_YMM|T_MUST_EVEX, 0x39, imm); } +void vextracti64x4(const Operand& op, const Zmm& r, uint8_t imm) { if (!op.is(Operand::MEM | Operand::YMM)) XBYAK_THROW(ERR_BAD_COMBINATION) opVex(r, 0, op, T_N32|T_66|T_0F3A|T_EW1|T_YMM|T_MUST_EVEX, 0x3B, imm); } +void vfcmaddcph(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_F2|T_MAP6|T_EW0|T_YMM|T_ER_Z|T_MUST_EVEX|T_B32, 0x56); } +void vfcmulcph(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_F2|T_MAP6|T_EW0|T_YMM|T_ER_Z|T_MUST_EVEX|T_B32, 0xD6); } +void vfixupimmpd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8_t imm) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F3A|T_EW1|T_YMM|T_SAE_Z|T_MUST_EVEX|T_B64, 0x54, imm); } +void vfixupimmps(const Xmm& x1, const Xmm& x2, const Operand& op, uint8_t imm) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F3A|T_EW0|T_YMM|T_SAE_Z|T_MUST_EVEX|T_B32, 0x54, imm); } +void vfixupimmsd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8_t imm) { opAVX_X_X_XM(x1, x2, op, T_N8|T_66|T_0F3A|T_EW1|T_SAE_Z|T_MUST_EVEX, 0x55, imm); } +void vfixupimmss(const Xmm& x1, const Xmm& x2, const Operand& op, uint8_t imm) { opAVX_X_X_XM(x1, x2, op, T_N4|T_66|T_0F3A|T_EW0|T_SAE_Z|T_MUST_EVEX, 0x55, imm); } +void vfmadd132ph(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_MAP6|T_EW0|T_YMM|T_ER_Z|T_MUST_EVEX|T_B16, 0x98); } +void vfmadd132sh(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N2|T_66|T_MAP6|T_EW0|T_ER_X|T_MUST_EVEX, 0x99); } +void vfmadd213ph(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_MAP6|T_EW0|T_YMM|T_ER_Z|T_MUST_EVEX|T_B16, 0xA8); } +void vfmadd213sh(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N2|T_66|T_MAP6|T_EW0|T_ER_X|T_MUST_EVEX, 0xA9); } +void vfmadd231ph(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_MAP6|T_EW0|T_YMM|T_ER_Z|T_MUST_EVEX|T_B16, 0xB8); } +void vfmadd231sh(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N2|T_66|T_MAP6|T_EW0|T_ER_X|T_MUST_EVEX, 0xB9); } +void vfmaddcph(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_F3|T_MAP6|T_EW0|T_YMM|T_ER_Z|T_MUST_EVEX|T_B32, 0x56); } +void vfmaddsub132ph(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_MAP6|T_EW0|T_YMM|T_ER_Z|T_MUST_EVEX|T_B16, 0x96); } +void vfmaddsub213ph(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_MAP6|T_EW0|T_YMM|T_ER_Z|T_MUST_EVEX|T_B16, 0xA6); } +void vfmaddsub231ph(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_MAP6|T_EW0|T_YMM|T_ER_Z|T_MUST_EVEX|T_B16, 0xB6); } +void vfmsub132ph(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_MAP6|T_EW0|T_YMM|T_ER_Z|T_MUST_EVEX|T_B16, 0x9A); } +void vfmsub132sh(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N2|T_66|T_MAP6|T_EW0|T_ER_X|T_MUST_EVEX, 0x9B); } +void vfmsub213ph(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_MAP6|T_EW0|T_YMM|T_ER_Z|T_MUST_EVEX|T_B16, 0xAA); } +void vfmsub213sh(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N2|T_66|T_MAP6|T_EW0|T_ER_X|T_MUST_EVEX, 0xAB); } +void vfmsub231ph(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_MAP6|T_EW0|T_YMM|T_ER_Z|T_MUST_EVEX|T_B16, 0xBA); } +void vfmsub231sh(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N2|T_66|T_MAP6|T_EW0|T_ER_X|T_MUST_EVEX, 0xBB); } +void vfmsubadd132ph(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_MAP6|T_EW0|T_YMM|T_ER_Z|T_MUST_EVEX|T_B16, 0x97); } +void vfmsubadd213ph(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_MAP6|T_EW0|T_YMM|T_ER_Z|T_MUST_EVEX|T_B16, 0xA7); } +void vfmsubadd231ph(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_MAP6|T_EW0|T_YMM|T_ER_Z|T_MUST_EVEX|T_B16, 0xB7); } +void vfmulcph(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_F3|T_MAP6|T_EW0|T_YMM|T_ER_Z|T_MUST_EVEX|T_B32, 0xD6); } +void vfnmadd132ph(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_MAP6|T_EW0|T_YMM|T_ER_Z|T_MUST_EVEX|T_B16, 0x9C); } +void vfnmadd132sh(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N2|T_66|T_MAP6|T_EW0|T_ER_X|T_MUST_EVEX, 0x9D); } +void vfnmadd213ph(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_MAP6|T_EW0|T_YMM|T_ER_Z|T_MUST_EVEX|T_B16, 0xAC); } +void vfnmadd213sh(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N2|T_66|T_MAP6|T_EW0|T_ER_X|T_MUST_EVEX, 0xAD); } +void vfnmadd231ph(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_MAP6|T_EW0|T_YMM|T_ER_Z|T_MUST_EVEX|T_B16, 0xBC); } +void vfnmadd231sh(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N2|T_66|T_MAP6|T_EW0|T_ER_X|T_MUST_EVEX, 0xBD); } +void vfnmsub132ph(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_MAP6|T_EW0|T_YMM|T_ER_Z|T_MUST_EVEX|T_B16, 0x9E); } +void vfnmsub132sh(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N2|T_66|T_MAP6|T_EW0|T_ER_X|T_MUST_EVEX, 0x9F); } +void vfnmsub213ph(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_MAP6|T_EW0|T_YMM|T_ER_Z|T_MUST_EVEX|T_B16, 0xAE); } +void vfnmsub213sh(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N2|T_66|T_MAP6|T_EW0|T_ER_X|T_MUST_EVEX, 0xAF); } +void vfnmsub231ph(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_MAP6|T_EW0|T_YMM|T_ER_Z|T_MUST_EVEX|T_B16, 0xBE); } +void vfnmsub231sh(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N2|T_66|T_MAP6|T_EW0|T_ER_X|T_MUST_EVEX, 0xBF); } +void vfpclasspd(const Opmask& k, const Operand& op, uint8_t imm) { if (!op.isBit(128|256|512)) XBYAK_THROW(ERR_BAD_MEM_SIZE) opVex(k.changeBit(op.getBit()), 0, op, T_66 | T_0F3A | T_MUST_EVEX | T_YMM | T_EW1 | T_B64, 0x66, imm); } +void vfpclassph(const Opmask& k, const Operand& op, uint8_t imm) { if (!op.isBit(128|256|512)) XBYAK_THROW(ERR_BAD_MEM_SIZE) opVex(k.changeBit(op.getBit()), 0, op, T_0F3A | T_MUST_EVEX | T_YMM | T_EW0 | T_B16, 0x66, imm); } +void vfpclassps(const Opmask& k, const Operand& op, uint8_t imm) { if (!op.isBit(128|256|512)) XBYAK_THROW(ERR_BAD_MEM_SIZE) opVex(k.changeBit(op.getBit()), 0, op, T_66 | T_0F3A | T_MUST_EVEX | T_YMM | T_EW0 | T_B32, 0x66, imm); } +void vfpclasssd(const Opmask& k, const Operand& op, uint8_t imm) { if (!op.isXMEM()) XBYAK_THROW(ERR_BAD_MEM_SIZE) opVex(k, 0, op, T_66 | T_0F3A | T_MUST_EVEX | T_EW1 | T_N8, 0x67, imm); } +void vfpclasssh(const Opmask& k, const Operand& op, uint8_t imm) { if (!op.isXMEM()) XBYAK_THROW(ERR_BAD_MEM_SIZE) opVex(k, 0, op, T_0F3A | T_MUST_EVEX | T_EW0 | T_N2, 0x67, imm); } +void vfpclassss(const Opmask& k, const Operand& op, uint8_t imm) { if (!op.isXMEM()) XBYAK_THROW(ERR_BAD_MEM_SIZE) opVex(k, 0, op, T_66 | T_0F3A | T_MUST_EVEX | T_EW0 | T_N4, 0x67, imm); } +void vgatherdpd(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N8|T_66|T_0F38|T_EW1|T_YMM|T_MUST_EVEX|T_VSIB, 0x92, 1); } +void vgatherdps(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N4|T_66|T_0F38|T_EW0|T_YMM|T_MUST_EVEX|T_VSIB, 0x92, 0); } +void vgatherpf0dpd(const Address& addr) { opGatherFetch(addr, zm1, T_N8|T_66|T_0F38|T_EW1|T_MUST_EVEX|T_M_K|T_VSIB, 0xC6, Operand::YMM); } +void vgatherpf0dps(const Address& addr) { opGatherFetch(addr, zm1, T_N4|T_66|T_0F38|T_EW0|T_MUST_EVEX|T_M_K|T_VSIB, 0xC6, Operand::ZMM); } +void vgatherpf0qpd(const Address& addr) { opGatherFetch(addr, zm1, T_N8|T_66|T_0F38|T_EW1|T_MUST_EVEX|T_M_K|T_VSIB, 0xC7, Operand::ZMM); } +void vgatherpf0qps(const Address& addr) { opGatherFetch(addr, zm1, T_N4|T_66|T_0F38|T_EW0|T_MUST_EVEX|T_M_K|T_VSIB, 0xC7, Operand::ZMM); } +void vgatherpf1dpd(const Address& addr) { opGatherFetch(addr, zm2, T_N8|T_66|T_0F38|T_EW1|T_MUST_EVEX|T_M_K|T_VSIB, 0xC6, Operand::YMM); } +void vgatherpf1dps(const Address& addr) { opGatherFetch(addr, zm2, T_N4|T_66|T_0F38|T_EW0|T_MUST_EVEX|T_M_K|T_VSIB, 0xC6, Operand::ZMM); } +void vgatherpf1qpd(const Address& addr) { opGatherFetch(addr, zm2, T_N8|T_66|T_0F38|T_EW1|T_MUST_EVEX|T_M_K|T_VSIB, 0xC7, Operand::ZMM); } +void vgatherpf1qps(const Address& addr) { opGatherFetch(addr, zm2, T_N4|T_66|T_0F38|T_EW0|T_MUST_EVEX|T_M_K|T_VSIB, 0xC7, Operand::ZMM); } +void vgatherqpd(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N8|T_66|T_0F38|T_EW1|T_YMM|T_MUST_EVEX|T_VSIB, 0x93, 0); } +void vgatherqps(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N4|T_66|T_0F38|T_EW0|T_YMM|T_MUST_EVEX|T_VSIB, 0x93, 2); } +void vgetexppd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66|T_0F38|T_EW1|T_YMM|T_SAE_Z|T_MUST_EVEX|T_B64, 0x42); } +void vgetexpph(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66|T_MAP6|T_EW0|T_YMM|T_SAE_Z|T_MUST_EVEX|T_B16, 0x42); } +void vgetexpps(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66|T_0F38|T_EW0|T_YMM|T_SAE_Z|T_MUST_EVEX|T_B32, 0x42); } +void vgetexpsd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8|T_66|T_0F38|T_EW1|T_SAE_X|T_MUST_EVEX, 0x43); } +void vgetexpsh(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N2|T_66|T_MAP6|T_EW0|T_SAE_X|T_MUST_EVEX, 0x43); } +void vgetexpss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4|T_66|T_0F38|T_EW0|T_SAE_X|T_MUST_EVEX, 0x43); } +void vgetmantpd(const Xmm& x, const Operand& op, uint8_t imm) { opAVX_X_XM_IMM(x, op, T_66|T_0F3A|T_EW1|T_YMM|T_SAE_Z|T_MUST_EVEX|T_B64, 0x26, imm); } +void vgetmantph(const Xmm& x, const Operand& op, uint8_t imm) { opAVX_X_XM_IMM(x, op, T_0F3A|T_EW0|T_YMM|T_SAE_Z|T_MUST_EVEX|T_B16, 0x26, imm); } +void vgetmantps(const Xmm& x, const Operand& op, uint8_t imm) { opAVX_X_XM_IMM(x, op, T_66|T_0F3A|T_EW0|T_YMM|T_SAE_Z|T_MUST_EVEX|T_B32, 0x26, imm); } +void vgetmantsd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8_t imm) { opAVX_X_X_XM(x1, x2, op, T_N8|T_66|T_0F3A|T_EW1|T_SAE_X|T_MUST_EVEX, 0x27, imm); } +void vgetmantsh(const Xmm& x1, const Xmm& x2, const Operand& op, uint8_t imm) { opAVX_X_X_XM(x1, x2, op, T_N2|T_0F3A|T_EW0|T_SAE_X|T_MUST_EVEX, 0x27, imm); } +void vgetmantss(const Xmm& x1, const Xmm& x2, const Operand& op, uint8_t imm) { opAVX_X_X_XM(x1, x2, op, T_N4|T_66|T_0F3A|T_EW0|T_SAE_X|T_MUST_EVEX, 0x27, imm); } +void vinsertf32x4(const Ymm& r1, const Ymm& r2, const Operand& op, uint8_t imm) {if (!(r1.getKind() == r2.getKind() && op.is(Operand::MEM | Operand::XMM))) XBYAK_THROW(ERR_BAD_COMBINATION) opVex(r1, &r2, op, T_N16|T_66|T_0F3A|T_EW0|T_YMM|T_MUST_EVEX, 0x18, imm); } +void vinsertf32x8(const Zmm& r1, const Zmm& r2, const Operand& op, uint8_t imm) {if (!op.is(Operand::MEM | Operand::YMM)) XBYAK_THROW(ERR_BAD_COMBINATION) opVex(r1, &r2, op, T_N32|T_66|T_0F3A|T_EW0|T_YMM|T_MUST_EVEX, 0x1A, imm); } +void vinsertf64x2(const Ymm& r1, const Ymm& r2, const Operand& op, uint8_t imm) {if (!(r1.getKind() == r2.getKind() && op.is(Operand::MEM | Operand::XMM))) XBYAK_THROW(ERR_BAD_COMBINATION) opVex(r1, &r2, op, T_N16|T_66|T_0F3A|T_EW1|T_YMM|T_MUST_EVEX, 0x18, imm); } +void vinsertf64x4(const Zmm& r1, const Zmm& r2, const Operand& op, uint8_t imm) {if (!op.is(Operand::MEM | Operand::YMM)) XBYAK_THROW(ERR_BAD_COMBINATION) opVex(r1, &r2, op, T_N32|T_66|T_0F3A|T_EW1|T_YMM|T_MUST_EVEX, 0x1A, imm); } +void vinserti32x4(const Ymm& r1, const Ymm& r2, const Operand& op, uint8_t imm) {if (!(r1.getKind() == r2.getKind() && op.is(Operand::MEM | Operand::XMM))) XBYAK_THROW(ERR_BAD_COMBINATION) opVex(r1, &r2, op, T_N16|T_66|T_0F3A|T_EW0|T_YMM|T_MUST_EVEX, 0x38, imm); } +void vinserti32x8(const Zmm& r1, const Zmm& r2, const Operand& op, uint8_t imm) {if (!op.is(Operand::MEM | Operand::YMM)) XBYAK_THROW(ERR_BAD_COMBINATION) opVex(r1, &r2, op, T_N32|T_66|T_0F3A|T_EW0|T_YMM|T_MUST_EVEX, 0x3A, imm); } +void vinserti64x2(const Ymm& r1, const Ymm& r2, const Operand& op, uint8_t imm) {if (!(r1.getKind() == r2.getKind() && op.is(Operand::MEM | Operand::XMM))) XBYAK_THROW(ERR_BAD_COMBINATION) opVex(r1, &r2, op, T_N16|T_66|T_0F3A|T_EW1|T_YMM|T_MUST_EVEX, 0x38, imm); } +void vinserti64x4(const Zmm& r1, const Zmm& r2, const Operand& op, uint8_t imm) {if (!op.is(Operand::MEM | Operand::YMM)) XBYAK_THROW(ERR_BAD_COMBINATION) opVex(r1, &r2, op, T_N32|T_66|T_0F3A|T_EW1|T_YMM|T_MUST_EVEX, 0x3A, imm); } +void vmaxph(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_MAP5 | T_EW0 | T_YMM | T_MUST_EVEX | T_ER_Z | T_B16, 0x5F); } +void vmaxsh(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_MAP5 | T_F3 | T_EW0 | T_MUST_EVEX | T_ER_X | T_N2, 0x5F); } +void vminph(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_MAP5 | T_EW0 | T_YMM | T_MUST_EVEX | T_ER_Z | T_B16, 0x5D); } +void vminsh(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_MAP5 | T_F3 | T_EW0 | T_MUST_EVEX | T_ER_X | T_N2, 0x5D); } +void vmovdqa32(const Address& addr, const Xmm& x) { opAVX_X_XM_IMM(x, addr, T_66|T_0F|T_EW0|T_YMM|T_ER_X|T_ER_Y|T_ER_Z|T_MUST_EVEX|T_M_K, 0x7F); } +void vmovdqa32(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66|T_0F|T_EW0|T_YMM|T_ER_X|T_ER_Y|T_ER_Z|T_MUST_EVEX, 0x6F); } +void vmovdqa64(const Address& addr, const Xmm& x) { opAVX_X_XM_IMM(x, addr, T_66|T_0F|T_EW1|T_YMM|T_ER_X|T_ER_Y|T_ER_Z|T_MUST_EVEX|T_M_K, 0x7F); } +void vmovdqa64(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66|T_0F|T_EW1|T_YMM|T_ER_X|T_ER_Y|T_ER_Z|T_MUST_EVEX, 0x6F); } +void vmovdqu16(const Address& addr, const Xmm& x) { opAVX_X_XM_IMM(x, addr, T_F2|T_0F|T_EW1|T_YMM|T_ER_X|T_ER_Y|T_ER_Z|T_MUST_EVEX|T_M_K, 0x7F); } +void vmovdqu16(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_F2|T_0F|T_EW1|T_YMM|T_ER_X|T_ER_Y|T_ER_Z|T_MUST_EVEX, 0x6F); } +void vmovdqu32(const Address& addr, const Xmm& x) { opAVX_X_XM_IMM(x, addr, T_F3|T_0F|T_EW0|T_YMM|T_ER_X|T_ER_Y|T_ER_Z|T_MUST_EVEX|T_M_K, 0x7F); } +void vmovdqu32(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_F3|T_0F|T_EW0|T_YMM|T_ER_X|T_ER_Y|T_ER_Z|T_MUST_EVEX, 0x6F); } +void vmovdqu64(const Address& addr, const Xmm& x) { opAVX_X_XM_IMM(x, addr, T_F3|T_0F|T_EW1|T_YMM|T_ER_X|T_ER_Y|T_ER_Z|T_MUST_EVEX|T_M_K, 0x7F); } +void vmovdqu64(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_F3|T_0F|T_EW1|T_YMM|T_ER_X|T_ER_Y|T_ER_Z|T_MUST_EVEX, 0x6F); } +void vmovdqu8(const Address& addr, const Xmm& x) { opAVX_X_XM_IMM(x, addr, T_F2|T_0F|T_EW0|T_YMM|T_ER_X|T_ER_Y|T_ER_Z|T_MUST_EVEX|T_M_K, 0x7F); } +void vmovdqu8(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_F2|T_0F|T_EW0|T_YMM|T_ER_X|T_ER_Y|T_ER_Z|T_MUST_EVEX, 0x6F); } +void vmovsh(const Address& addr, const Xmm& x) { opAVX_X_XM_IMM(x, addr, T_N2|T_F3|T_MAP5|T_EW0|T_MUST_EVEX|T_M_K, 0x11); } +void vmovsh(const Xmm& x, const Address& addr) { opAVX_X_X_XM(x, xm0, addr, T_N2|T_F3|T_MAP5|T_EW0|T_MUST_EVEX, 0x10); } +void vmovsh(const Xmm& x1, const Xmm& x2, const Xmm& x3) { opAVX_X_X_XM(x1, x2, x3, T_N2|T_F3|T_MAP5|T_EW0|T_MUST_EVEX, 0x10); } +void vmovw(const Address& addr, const Xmm& x) { opAVX_X_XM_IMM(x, addr, T_N2|T_66|T_MAP5|T_MUST_EVEX, 0x7E); } +void vmovw(const Reg32e& r, const Xmm& x) { opAVX_X_X_XM(x, xm0, r, T_N2|T_66|T_MAP5|T_MUST_EVEX, 0x7E); } +void vmovw(const Xmm& x, const Operand& op) { if (!op.isREG(32|64) && !op.isMEM()) XBYAK_THROW(ERR_BAD_COMBINATION) opAVX_X_X_XM(x, xm0, op, T_N2|T_66|T_MAP5|T_MUST_EVEX, 0x6E); } +void vmulph(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_MAP5 | T_EW0 | T_YMM | T_MUST_EVEX | T_ER_Z | T_B16, 0x59); } +void vmulsh(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_MAP5 | T_F3 | T_EW0 | T_MUST_EVEX | T_ER_X | T_N2, 0x59); } +void vp2intersectd(const Opmask& k, const Xmm& x, const Operand& op) { if (k.getOpmaskIdx() != 0) XBYAK_THROW(ERR_OPMASK_IS_ALREADY_SET) opAVX_K_X_XM(k, x, op, T_F2 | T_0F38 | T_YMM | T_EVEX | T_EW0 | T_B32, 0x68); } +void vp2intersectq(const Opmask& k, const Xmm& x, const Operand& op) { if (k.getOpmaskIdx() != 0) XBYAK_THROW(ERR_OPMASK_IS_ALREADY_SET) opAVX_K_X_XM(k, x, op, T_F2 | T_0F38 | T_YMM | T_EVEX | T_EW1 | T_B64, 0x68); } +void vp4dpwssd(const Zmm& z1, const Zmm& z2, const Address& addr) { opAVX_X_X_XM(z1, z2, addr, T_0F38 | T_F2 | T_EW0 | T_YMM | T_MUST_EVEX | T_N16, 0x52); } +void vp4dpwssds(const Zmm& z1, const Zmm& z2, const Address& addr) { opAVX_X_X_XM(z1, z2, addr, T_0F38 | T_F2 | T_EW0 | T_YMM | T_MUST_EVEX | T_N16, 0x53); } +void vpabsq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_MUST_EVEX | T_EW1 | T_B64 | T_YMM, 0x1F); } +void vpandd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_EW0|T_YMM|T_MUST_EVEX|T_B32, 0xDB); } +void vpandnd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_EW0|T_YMM|T_MUST_EVEX|T_B32, 0xDF); } +void vpandnq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_EW1|T_YMM|T_MUST_EVEX|T_B64, 0xDF); } +void vpandq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_EW1|T_YMM|T_MUST_EVEX|T_B64, 0xDB); } +void vpblendmb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_EW0|T_YMM|T_MUST_EVEX, 0x66); } +void vpblendmd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_EW0|T_YMM|T_MUST_EVEX|T_B32, 0x64); } +void vpblendmq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_EW1|T_YMM|T_MUST_EVEX|T_B64, 0x64); } +void vpblendmw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_EW1|T_YMM|T_MUST_EVEX, 0x66); } +void vpbroadcastb(const Xmm& x, const Reg8& r) { opVex(x, 0, r, T_66|T_0F38|T_EW0|T_YMM|T_MUST_EVEX, 0x7A); } +void vpbroadcastd(const Xmm& x, const Reg32& r) { opVex(x, 0, r, T_66|T_0F38|T_EW0|T_YMM|T_MUST_EVEX, 0x7C); } +void vpbroadcastmb2q(const Xmm& x, const Opmask& k) { opVex(x, 0, k, T_F3 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW1, 0x2A); } +void vpbroadcastmw2d(const Xmm& x, const Opmask& k) { opVex(x, 0, k, T_F3 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW0, 0x3A); } +void vpbroadcastw(const Xmm& x, const Reg16& r) { opVex(x, 0, r, T_66|T_0F38|T_EW0|T_YMM|T_MUST_EVEX, 0x7B); } +void vpcmpb(const Opmask& k, const Xmm& x, const Operand& op, uint8_t imm) { opAVX_K_X_XM(k, x, op, T_66|T_0F3A|T_EW0|T_YMM|T_MUST_EVEX, 0x3F, imm); } +void vpcmpd(const Opmask& k, const Xmm& x, const Operand& op, uint8_t imm) { opAVX_K_X_XM(k, x, op, T_66|T_0F3A|T_EW0|T_YMM|T_MUST_EVEX|T_B32, 0x1F, imm); } +void vpcmpeqb(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66|T_0F|T_YMM|T_MUST_EVEX, 0x74); } +void vpcmpeqd(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66|T_0F|T_YMM|T_MUST_EVEX|T_B32, 0x76); } +void vpcmpeqq(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66|T_0F38|T_EW1|T_YMM|T_MUST_EVEX|T_B64, 0x29); } +void vpcmpeqw(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66|T_0F|T_YMM|T_MUST_EVEX, 0x75); } +void vpcmpgtb(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66|T_0F|T_YMM|T_MUST_EVEX, 0x64); } +void vpcmpgtd(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66|T_0F|T_EW0|T_YMM|T_MUST_EVEX|T_B32, 0x66); } +void vpcmpgtq(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66|T_0F38|T_EW1|T_YMM|T_MUST_EVEX|T_B64, 0x37); } +void vpcmpgtw(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66|T_0F|T_YMM|T_MUST_EVEX, 0x65); } +void vpcmpq(const Opmask& k, const Xmm& x, const Operand& op, uint8_t imm) { opAVX_K_X_XM(k, x, op, T_66|T_0F3A|T_EW1|T_YMM|T_MUST_EVEX|T_B64, 0x1F, imm); } +void vpcmpub(const Opmask& k, const Xmm& x, const Operand& op, uint8_t imm) { opAVX_K_X_XM(k, x, op, T_66|T_0F3A|T_EW0|T_YMM|T_MUST_EVEX, 0x3E, imm); } +void vpcmpud(const Opmask& k, const Xmm& x, const Operand& op, uint8_t imm) { opAVX_K_X_XM(k, x, op, T_66|T_0F3A|T_EW0|T_YMM|T_MUST_EVEX|T_B32, 0x1E, imm); } +void vpcmpuq(const Opmask& k, const Xmm& x, const Operand& op, uint8_t imm) { opAVX_K_X_XM(k, x, op, T_66|T_0F3A|T_EW1|T_YMM|T_MUST_EVEX|T_B64, 0x1E, imm); } +void vpcmpuw(const Opmask& k, const Xmm& x, const Operand& op, uint8_t imm) { opAVX_K_X_XM(k, x, op, T_66|T_0F3A|T_EW1|T_YMM|T_MUST_EVEX, 0x3E, imm); } +void vpcmpw(const Opmask& k, const Xmm& x, const Operand& op, uint8_t imm) { opAVX_K_X_XM(k, x, op, T_66|T_0F3A|T_EW1|T_YMM|T_MUST_EVEX, 0x3F, imm); } +void vpcompressd(const Operand& op, const Xmm& x) { opAVX_X_XM_IMM(x, op, T_N4|T_66|T_0F38|T_EW0|T_YMM|T_MUST_EVEX, 0x8B); } +void vpcompressq(const Operand& op, const Xmm& x) { opAVX_X_XM_IMM(x, op, T_N8|T_66|T_0F38|T_EW1|T_YMM|T_MUST_EVEX, 0x8B); } +void vpconflictd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66|T_0F38|T_EW0|T_YMM|T_MUST_EVEX|T_B32, 0xC4); } +void vpconflictq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66|T_0F38|T_EW1|T_YMM|T_MUST_EVEX|T_B64, 0xC4); } +void vpermb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_EW0|T_YMM|T_MUST_EVEX, 0x8D); } +void vpermi2b(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_EW0|T_YMM|T_MUST_EVEX, 0x75); } +void vpermi2d(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_EW0|T_YMM|T_MUST_EVEX|T_B32, 0x76); } +void vpermi2pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_EW1|T_YMM|T_MUST_EVEX|T_B64, 0x77); } +void vpermi2ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_EW0|T_YMM|T_MUST_EVEX|T_B32, 0x77); } +void vpermi2q(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_EW1|T_YMM|T_MUST_EVEX|T_B64, 0x76); } +void vpermi2w(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_EW1|T_YMM|T_MUST_EVEX, 0x75); } +void vpermt2b(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_EW0|T_YMM|T_MUST_EVEX, 0x7D); } +void vpermt2d(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_EW0|T_YMM|T_MUST_EVEX|T_B32, 0x7E); } +void vpermt2pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_EW1|T_YMM|T_MUST_EVEX|T_B64, 0x7F); } +void vpermt2ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_EW0|T_YMM|T_MUST_EVEX|T_B32, 0x7F); } +void vpermt2q(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_EW1|T_YMM|T_MUST_EVEX|T_B64, 0x7E); } +void vpermt2w(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_EW1|T_YMM|T_MUST_EVEX, 0x7D); } +void vpermw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_EW1|T_YMM|T_MUST_EVEX, 0x8D); } +void vpexpandb(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_N1|T_66|T_0F38|T_EW0|T_YMM|T_SAE_Z|T_MUST_EVEX, 0x62); } +void vpexpandd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_N4|T_66|T_0F38|T_EW0|T_YMM|T_MUST_EVEX, 0x89); } +void vpexpandq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_N8|T_66|T_0F38|T_EW1|T_YMM|T_MUST_EVEX, 0x89); } +void vpexpandw(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_N2|T_66|T_0F38|T_EW1|T_YMM|T_SAE_Z|T_MUST_EVEX, 0x62); } +void vpgatherdd(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N4|T_66|T_0F38|T_EW0|T_YMM|T_MUST_EVEX|T_VSIB, 0x90, 0); } +void vpgatherdq(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N8|T_66|T_0F38|T_EW1|T_YMM|T_MUST_EVEX|T_VSIB, 0x90, 1); } +void vpgatherqd(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N4|T_66|T_0F38|T_EW0|T_YMM|T_MUST_EVEX|T_VSIB, 0x91, 2); } +void vpgatherqq(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N8|T_66|T_0F38|T_EW1|T_YMM|T_MUST_EVEX|T_VSIB, 0x91, 0); } +void vplzcntd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66|T_0F38|T_EW0|T_YMM|T_MUST_EVEX|T_B32, 0x44); } +void vplzcntq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66|T_0F38|T_EW1|T_YMM|T_MUST_EVEX|T_B64, 0x44); } +void vpmaxsq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_EW1|T_YMM|T_MUST_EVEX|T_B64, 0x3D); } +void vpmaxuq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_EW1|T_YMM|T_MUST_EVEX|T_B64, 0x3F); } +void vpminsq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_EW1|T_YMM|T_MUST_EVEX|T_B64, 0x39); } +void vpminuq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_EW1|T_YMM|T_MUST_EVEX|T_B64, 0x3B); } +void vpmovb2m(const Opmask& k, const Xmm& x) { opVex(k, 0, x, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW0, 0x29); } +void vpmovd2m(const Opmask& k, const Xmm& x) { opVex(k, 0, x, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW0, 0x39); } +void vpmovdb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N4|T_N_VL|T_F3|T_0F38|T_EW0|T_YMM|T_MUST_EVEX|T_M_K, 0x31, false); } +void vpmovdw(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8|T_N_VL|T_F3|T_0F38|T_EW0|T_YMM|T_MUST_EVEX|T_M_K, 0x33, true); } +void vpmovm2b(const Xmm& x, const Opmask& k) { opVex(x, 0, k, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW0, 0x28); } +void vpmovm2d(const Xmm& x, const Opmask& k) { opVex(x, 0, k, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW0, 0x38); } +void vpmovm2q(const Xmm& x, const Opmask& k) { opVex(x, 0, k, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW1, 0x38); } +void vpmovm2w(const Xmm& x, const Opmask& k) { opVex(x, 0, k, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW1, 0x28); } +void vpmovq2m(const Opmask& k, const Xmm& x) { opVex(k, 0, x, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW1, 0x39); } +void vpmovqb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N2|T_N_VL|T_F3|T_0F38|T_EW0|T_YMM|T_MUST_EVEX|T_M_K, 0x32, false); } +void vpmovqd(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8|T_N_VL|T_F3|T_0F38|T_EW0|T_YMM|T_MUST_EVEX|T_M_K, 0x35, true); } +void vpmovqw(const Operand& op, const Xmm& x) { opVmov(op, x, T_N4|T_N_VL|T_F3|T_0F38|T_EW0|T_YMM|T_MUST_EVEX|T_M_K, 0x34, false); } +void vpmovsdb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N4|T_N_VL|T_F3|T_0F38|T_EW0|T_YMM|T_MUST_EVEX|T_M_K, 0x21, false); } +void vpmovsdw(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8|T_N_VL|T_F3|T_0F38|T_EW0|T_YMM|T_MUST_EVEX|T_M_K, 0x23, true); } +void vpmovsqb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N2|T_N_VL|T_F3|T_0F38|T_EW0|T_YMM|T_MUST_EVEX|T_M_K, 0x22, false); } +void vpmovsqd(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8|T_N_VL|T_F3|T_0F38|T_EW0|T_YMM|T_MUST_EVEX|T_M_K, 0x25, true); } +void vpmovsqw(const Operand& op, const Xmm& x) { opVmov(op, x, T_N4|T_N_VL|T_F3|T_0F38|T_EW0|T_YMM|T_MUST_EVEX|T_M_K, 0x24, false); } +void vpmovswb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8|T_N_VL|T_F3|T_0F38|T_EW0|T_YMM|T_MUST_EVEX|T_M_K, 0x20, true); } +void vpmovusdb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N4|T_N_VL|T_F3|T_0F38|T_EW0|T_YMM|T_MUST_EVEX|T_M_K, 0x11, false); } +void vpmovusdw(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8|T_N_VL|T_F3|T_0F38|T_EW0|T_YMM|T_MUST_EVEX|T_M_K, 0x13, true); } +void vpmovusqb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N2|T_N_VL|T_F3|T_0F38|T_EW0|T_YMM|T_MUST_EVEX|T_M_K, 0x12, false); } +void vpmovusqd(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8|T_N_VL|T_F3|T_0F38|T_EW0|T_YMM|T_MUST_EVEX|T_M_K, 0x15, true); } +void vpmovusqw(const Operand& op, const Xmm& x) { opVmov(op, x, T_N4|T_N_VL|T_F3|T_0F38|T_EW0|T_YMM|T_MUST_EVEX|T_M_K, 0x14, false); } +void vpmovuswb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8|T_N_VL|T_F3|T_0F38|T_EW0|T_YMM|T_MUST_EVEX|T_M_K, 0x10, true); } +void vpmovw2m(const Opmask& k, const Xmm& x) { opVex(k, 0, x, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW1, 0x29); } +void vpmovwb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8|T_N_VL|T_F3|T_0F38|T_EW0|T_YMM|T_MUST_EVEX|T_M_K, 0x30, true); } +void vpmullq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_EW1|T_YMM|T_MUST_EVEX|T_B64, 0x40); } +void vpmultishiftqb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_EW1|T_YMM|T_MUST_EVEX|T_B64, 0x83); } +void vpopcntb(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66|T_0F38|T_EW0|T_YMM|T_SAE_Z|T_MUST_EVEX, 0x54); } +void vpopcntd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66|T_0F38|T_EW0|T_YMM|T_SAE_Z|T_MUST_EVEX|T_B32, 0x55); } +void vpopcntq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66|T_0F38|T_EW1|T_YMM|T_SAE_Z|T_MUST_EVEX|T_B64, 0x55); } +void vpopcntw(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66|T_0F38|T_EW1|T_YMM|T_SAE_Z|T_MUST_EVEX, 0x54); } +void vpord(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_EW0|T_YMM|T_MUST_EVEX|T_B32, 0xEB); } +void vporq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_EW1|T_YMM|T_MUST_EVEX|T_B64, 0xEB); } +void vprold(const Xmm& x, const Operand& op, uint8_t imm) { opAVX_X_X_XM(Xmm(x.getKind(), 1), x, op, T_66|T_0F|T_EW0|T_YMM|T_MUST_EVEX|T_B32, 0x72, imm); } +void vprolq(const Xmm& x, const Operand& op, uint8_t imm) { opAVX_X_X_XM(Xmm(x.getKind(), 1), x, op, T_66|T_0F|T_EW1|T_YMM|T_MUST_EVEX|T_B64, 0x72, imm); } +void vprolvd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_EW0|T_YMM|T_MUST_EVEX|T_B32, 0x15); } +void vprolvq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_EW1|T_YMM|T_MUST_EVEX|T_B64, 0x15); } +void vprord(const Xmm& x, const Operand& op, uint8_t imm) { opAVX_X_X_XM(Xmm(x.getKind(), 0), x, op, T_66|T_0F|T_EW0|T_YMM|T_MUST_EVEX|T_B32, 0x72, imm); } +void vprorq(const Xmm& x, const Operand& op, uint8_t imm) { opAVX_X_X_XM(Xmm(x.getKind(), 0), x, op, T_66|T_0F|T_EW1|T_YMM|T_MUST_EVEX|T_B64, 0x72, imm); } +void vprorvd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_EW0|T_YMM|T_MUST_EVEX|T_B32, 0x14); } +void vprorvq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_EW1|T_YMM|T_MUST_EVEX|T_B64, 0x14); } +void vpscatterdd(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N4|T_66|T_0F38|T_EW0|T_YMM|T_MUST_EVEX|T_M_K|T_VSIB, 0xA0, 0); } +void vpscatterdq(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N8|T_66|T_0F38|T_EW1|T_YMM|T_MUST_EVEX|T_M_K|T_VSIB, 0xA0, 1); } +void vpscatterqd(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N4|T_66|T_0F38|T_EW0|T_YMM|T_MUST_EVEX|T_M_K|T_VSIB, 0xA1, 2); } +void vpscatterqq(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N8|T_66|T_0F38|T_EW1|T_YMM|T_MUST_EVEX|T_M_K|T_VSIB, 0xA1, 0); } +void vpshldd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8_t imm) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F3A|T_EW0|T_YMM|T_SAE_Z|T_MUST_EVEX|T_B32, 0x71, imm); } +void vpshldq(const Xmm& x1, const Xmm& x2, const Operand& op, uint8_t imm) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F3A|T_EW1|T_YMM|T_SAE_Z|T_MUST_EVEX|T_B64, 0x71, imm); } +void vpshldvd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_EW0|T_YMM|T_SAE_Z|T_MUST_EVEX|T_B32, 0x71); } +void vpshldvq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_EW1|T_YMM|T_SAE_Z|T_MUST_EVEX|T_B64, 0x71); } +void vpshldvw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_EW1|T_YMM|T_SAE_Z|T_MUST_EVEX, 0x70); } +void vpshldw(const Xmm& x1, const Xmm& x2, const Operand& op, uint8_t imm) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F3A|T_EW1|T_YMM|T_SAE_Z|T_MUST_EVEX, 0x70, imm); } +void vpshrdd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8_t imm) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F3A|T_EW0|T_YMM|T_SAE_Z|T_MUST_EVEX|T_B32, 0x73, imm); } +void vpshrdq(const Xmm& x1, const Xmm& x2, const Operand& op, uint8_t imm) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F3A|T_EW1|T_YMM|T_SAE_Z|T_MUST_EVEX|T_B64, 0x73, imm); } +void vpshrdvd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_EW0|T_YMM|T_SAE_Z|T_MUST_EVEX|T_B32, 0x73); } +void vpshrdvq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_EW1|T_YMM|T_SAE_Z|T_MUST_EVEX|T_B64, 0x73); } +void vpshrdvw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_EW1|T_YMM|T_SAE_Z|T_MUST_EVEX, 0x72); } +void vpshrdw(const Xmm& x1, const Xmm& x2, const Operand& op, uint8_t imm) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F3A|T_EW1|T_YMM|T_SAE_Z|T_MUST_EVEX, 0x72, imm); } +void vpshufbitqmb(const Opmask& k, const Xmm& x, const Operand& op) { opVex(k, &x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x8F); } +void vpsllvw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_EW1|T_YMM|T_MUST_EVEX, 0x12); } +void vpsraq(const Xmm& x, const Operand& op, uint8_t imm) { opAVX_X_X_XM(Xmm(x.getKind(), 4), x, op, T_66|T_0F|T_EW1|T_YMM|T_MUST_EVEX|T_B64, 0x72, imm); } +void vpsraq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16|T_66|T_0F|T_EW1|T_YMM|T_MUST_EVEX, 0xE2); } +void vpsravq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_EW1|T_YMM|T_MUST_EVEX|T_B64, 0x46); } +void vpsravw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_EW1|T_YMM|T_MUST_EVEX, 0x11); } +void vpsrlvw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_EW1|T_YMM|T_MUST_EVEX, 0x10); } +void vpternlogd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8_t imm) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F3A|T_EW0|T_YMM|T_MUST_EVEX|T_B32, 0x25, imm); } +void vpternlogq(const Xmm& x1, const Xmm& x2, const Operand& op, uint8_t imm) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F3A|T_EW1|T_YMM|T_MUST_EVEX|T_B64, 0x25, imm); } +void vptestmb(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66|T_0F38|T_EW0|T_YMM|T_MUST_EVEX, 0x26); } +void vptestmd(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66|T_0F38|T_EW0|T_YMM|T_MUST_EVEX|T_B32, 0x27); } +void vptestmq(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66|T_0F38|T_EW1|T_YMM|T_MUST_EVEX|T_B64, 0x27); } +void vptestmw(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66|T_0F38|T_EW1|T_YMM|T_MUST_EVEX, 0x26); } +void vptestnmb(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_F3|T_0F38|T_EW0|T_YMM|T_MUST_EVEX, 0x26); } +void vptestnmd(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_F3|T_0F38|T_EW0|T_YMM|T_MUST_EVEX|T_B32, 0x27); } +void vptestnmq(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_F3|T_0F38|T_EW1|T_YMM|T_MUST_EVEX|T_B64, 0x27); } +void vptestnmw(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_F3|T_0F38|T_EW1|T_YMM|T_MUST_EVEX, 0x26); } +void vpxord(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_EW0|T_YMM|T_MUST_EVEX|T_B32, 0xEF); } +void vpxorq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F|T_EW1|T_YMM|T_MUST_EVEX|T_B64, 0xEF); } +void vrangepd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8_t imm) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F3A|T_EW1|T_YMM|T_SAE_Z|T_MUST_EVEX|T_B64, 0x50, imm); } +void vrangeps(const Xmm& x1, const Xmm& x2, const Operand& op, uint8_t imm) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F3A|T_EW0|T_YMM|T_SAE_Z|T_MUST_EVEX|T_B32, 0x50, imm); } +void vrangesd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8_t imm) { opAVX_X_X_XM(x1, x2, op, T_N8|T_66|T_0F3A|T_EW1|T_SAE_X|T_MUST_EVEX, 0x51, imm); } +void vrangess(const Xmm& x1, const Xmm& x2, const Operand& op, uint8_t imm) { opAVX_X_X_XM(x1, x2, op, T_N4|T_66|T_0F3A|T_EW0|T_SAE_X|T_MUST_EVEX, 0x51, imm); } +void vrcp14pd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66|T_0F38|T_EW1|T_YMM|T_MUST_EVEX|T_B64, 0x4C); } +void vrcp14ps(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66|T_0F38|T_EW0|T_YMM|T_MUST_EVEX|T_B32, 0x4C); } +void vrcp14sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8|T_66|T_0F38|T_EW1|T_MUST_EVEX, 0x4D); } +void vrcp14ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4|T_66|T_0F38|T_EW0|T_MUST_EVEX, 0x4D); } +void vrcp28pd(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW1 | T_B64 | T_SAE_Z, 0xCA); } +void vrcp28ps(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW0 | T_B32 | T_SAE_Z, 0xCA); } +void vrcp28sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8|T_66|T_0F38|T_EW1|T_SAE_X|T_MUST_EVEX, 0xCB); } +void vrcp28ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4|T_66|T_0F38|T_EW0|T_SAE_X|T_MUST_EVEX, 0xCB); } +void vrcpph(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66|T_MAP6|T_EW0|T_YMM|T_MUST_EVEX|T_B16, 0x4C); } +void vrcpsh(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N2|T_66|T_MAP6|T_EW0|T_MUST_EVEX, 0x4D); } +void vreducepd(const Xmm& x, const Operand& op, uint8_t imm) { opAVX_X_XM_IMM(x, op, T_66|T_0F3A|T_EW1|T_YMM|T_SAE_Z|T_MUST_EVEX|T_B64, 0x56, imm); } +void vreduceph(const Xmm& x, const Operand& op, uint8_t imm) { opAVX_X_XM_IMM(x, op, T_0F3A|T_EW0|T_YMM|T_SAE_Z|T_MUST_EVEX|T_B16, 0x56, imm); } +void vreduceps(const Xmm& x, const Operand& op, uint8_t imm) { opAVX_X_XM_IMM(x, op, T_66|T_0F3A|T_EW0|T_YMM|T_SAE_Z|T_MUST_EVEX|T_B32, 0x56, imm); } +void vreducesd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8_t imm) { opAVX_X_X_XM(x1, x2, op, T_N8|T_66|T_0F3A|T_EW1|T_SAE_X|T_MUST_EVEX, 0x57, imm); } +void vreducesh(const Xmm& x1, const Xmm& x2, const Operand& op, uint8_t imm) { opAVX_X_X_XM(x1, x2, op, T_N2|T_0F3A|T_EW0|T_SAE_X|T_MUST_EVEX, 0x57, imm); } +void vreducess(const Xmm& x1, const Xmm& x2, const Operand& op, uint8_t imm) { opAVX_X_X_XM(x1, x2, op, T_N4|T_66|T_0F3A|T_EW0|T_SAE_X|T_MUST_EVEX, 0x57, imm); } +void vrndscalepd(const Xmm& x, const Operand& op, uint8_t imm) { opAVX_X_XM_IMM(x, op, T_66|T_0F3A|T_EW1|T_YMM|T_SAE_Z|T_MUST_EVEX|T_B64, 0x09, imm); } +void vrndscaleph(const Xmm& x, const Operand& op, uint8_t imm) { opAVX_X_XM_IMM(x, op, T_0F3A|T_EW0|T_YMM|T_SAE_Z|T_MUST_EVEX|T_B16, 0x08, imm); } +void vrndscaleps(const Xmm& x, const Operand& op, uint8_t imm) { opAVX_X_XM_IMM(x, op, T_66|T_0F3A|T_EW0|T_YMM|T_SAE_Z|T_MUST_EVEX|T_B32, 0x08, imm); } +void vrndscalesd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8_t imm) { opAVX_X_X_XM(x1, x2, op, T_N8|T_66|T_0F3A|T_EW1|T_SAE_X|T_MUST_EVEX, 0x0B, imm); } +void vrndscalesh(const Xmm& x1, const Xmm& x2, const Operand& op, uint8_t imm) { opAVX_X_X_XM(x1, x2, op, T_N2|T_0F3A|T_EW0|T_SAE_X|T_MUST_EVEX, 0x0A, imm); } +void vrndscaless(const Xmm& x1, const Xmm& x2, const Operand& op, uint8_t imm) { opAVX_X_X_XM(x1, x2, op, T_N4|T_66|T_0F3A|T_EW0|T_SAE_X|T_MUST_EVEX, 0x0A, imm); } +void vrsqrt14pd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66|T_0F38|T_EW1|T_YMM|T_MUST_EVEX|T_B64, 0x4E); } +void vrsqrt14ps(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66|T_0F38|T_EW0|T_YMM|T_MUST_EVEX|T_B32, 0x4E); } +void vrsqrt14sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8|T_66|T_0F38|T_EW1|T_YMM|T_MUST_EVEX, 0x4F); } +void vrsqrt14ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4|T_66|T_0F38|T_EW0|T_YMM|T_MUST_EVEX, 0x4F); } +void vrsqrt28pd(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW1 | T_B64 | T_SAE_Z, 0xCC); } +void vrsqrt28ps(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW0 | T_B32 | T_SAE_Z, 0xCC); } +void vrsqrt28sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8|T_66|T_0F38|T_EW1|T_SAE_X|T_MUST_EVEX, 0xCD); } +void vrsqrt28ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4|T_66|T_0F38|T_EW0|T_SAE_X|T_MUST_EVEX, 0xCD); } +void vrsqrtph(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66|T_MAP6|T_EW0|T_YMM|T_MUST_EVEX|T_B16, 0x4E); } +void vrsqrtsh(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N2|T_66|T_MAP6|T_EW0|T_MUST_EVEX, 0x4F); } +void vscalefpd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_EW1|T_YMM|T_ER_Z|T_MUST_EVEX|T_B64, 0x2C); } +void vscalefph(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_MAP6|T_EW0|T_YMM|T_ER_Z|T_MUST_EVEX|T_B16, 0x2C); } +void vscalefps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66|T_0F38|T_EW0|T_YMM|T_ER_Z|T_MUST_EVEX|T_B32, 0x2C); } +void vscalefsd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8|T_66|T_0F38|T_EW1|T_ER_X|T_MUST_EVEX, 0x2D); } +void vscalefsh(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N2|T_66|T_MAP6|T_EW0|T_ER_X|T_MUST_EVEX, 0x2D); } +void vscalefss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4|T_66|T_0F38|T_EW0|T_ER_X|T_MUST_EVEX, 0x2D); } +void vscatterdpd(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N8|T_66|T_0F38|T_EW1|T_YMM|T_MUST_EVEX|T_M_K|T_VSIB, 0xA2, 1); } +void vscatterdps(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N4|T_66|T_0F38|T_EW0|T_YMM|T_MUST_EVEX|T_M_K|T_VSIB, 0xA2, 0); } +void vscatterpf0dpd(const Address& addr) { opGatherFetch(addr, zm5, T_N8|T_66|T_0F38|T_EW1|T_MUST_EVEX|T_M_K|T_VSIB, 0xC6, Operand::YMM); } +void vscatterpf0dps(const Address& addr) { opGatherFetch(addr, zm5, T_N4|T_66|T_0F38|T_EW0|T_MUST_EVEX|T_M_K|T_VSIB, 0xC6, Operand::ZMM); } +void vscatterpf0qpd(const Address& addr) { opGatherFetch(addr, zm5, T_N8|T_66|T_0F38|T_EW1|T_MUST_EVEX|T_M_K|T_VSIB, 0xC7, Operand::ZMM); } +void vscatterpf0qps(const Address& addr) { opGatherFetch(addr, zm5, T_N4|T_66|T_0F38|T_EW0|T_MUST_EVEX|T_M_K|T_VSIB, 0xC7, Operand::ZMM); } +void vscatterpf1dpd(const Address& addr) { opGatherFetch(addr, zm6, T_N8|T_66|T_0F38|T_EW1|T_MUST_EVEX|T_M_K|T_VSIB, 0xC6, Operand::YMM); } +void vscatterpf1dps(const Address& addr) { opGatherFetch(addr, zm6, T_N4|T_66|T_0F38|T_EW0|T_MUST_EVEX|T_M_K|T_VSIB, 0xC6, Operand::ZMM); } +void vscatterpf1qpd(const Address& addr) { opGatherFetch(addr, zm6, T_N8|T_66|T_0F38|T_EW1|T_MUST_EVEX|T_M_K|T_VSIB, 0xC7, Operand::ZMM); } +void vscatterpf1qps(const Address& addr) { opGatherFetch(addr, zm6, T_N4|T_66|T_0F38|T_EW0|T_MUST_EVEX|T_M_K|T_VSIB, 0xC7, Operand::ZMM); } +void vscatterqpd(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N8|T_66|T_0F38|T_EW1|T_YMM|T_MUST_EVEX|T_M_K|T_VSIB, 0xA3, 0); } +void vscatterqps(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N4|T_66|T_0F38|T_EW0|T_YMM|T_MUST_EVEX|T_M_K|T_VSIB, 0xA3, 2); } +void vshuff32x4(const Ymm& y1, const Ymm& y2, const Operand& op, uint8_t imm) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F3A | T_YMM | T_MUST_EVEX | T_EW0 | T_B32, 0x23, imm); } +void vshuff64x2(const Ymm& y1, const Ymm& y2, const Operand& op, uint8_t imm) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F3A | T_YMM | T_MUST_EVEX | T_EW1 | T_B64, 0x23, imm); } +void vshufi32x4(const Ymm& y1, const Ymm& y2, const Operand& op, uint8_t imm) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F3A | T_YMM | T_MUST_EVEX | T_EW0 | T_B32, 0x43, imm); } +void vshufi64x2(const Ymm& y1, const Ymm& y2, const Operand& op, uint8_t imm) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F3A | T_YMM | T_MUST_EVEX | T_EW1 | T_B64, 0x43, imm); } +void vsqrtph(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_MAP5|T_EW0|T_YMM|T_ER_Z|T_MUST_EVEX|T_B16, 0x51); } +void vsqrtsh(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N2|T_F3|T_MAP5|T_EW0|T_ER_X|T_MUST_EVEX, 0x51); } +void vsubph(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_MAP5 | T_EW0 | T_YMM | T_MUST_EVEX | T_ER_Z | T_B16, 0x5C); } +void vsubsh(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_MAP5 | T_F3 | T_EW0 | T_MUST_EVEX | T_ER_X | T_N2, 0x5C); } +void vucomish(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_MAP5 | T_MUST_EVEX | T_EW0 | T_SAE_X | T_N2, 0x2E); } +#ifdef XBYAK64 +void kmovq(const Reg64& r, const Opmask& k) { opKmov(k, r, true, 64); } +void vpbroadcastq(const Xmm& x, const Reg64& r) { opVex(x, 0, r, T_66|T_0F38|T_EW1|T_YMM|T_MUST_EVEX, 0x7C); } +#endif +#endif diff --git a/addon/aocl_gemm/aocl_bf16_type.h b/addon/aocl_gemm/aocl_bf16_type.h index f8b2fd431a..6203267188 100644 --- a/addon/aocl_gemm/aocl_bf16_type.h +++ b/addon/aocl_gemm/aocl_bf16_type.h @@ -1,9 +1,11 @@ - /* + BLIS An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -15,6 +17,7 @@ - Neither the name(s) of the copyright holder(s) nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR @@ -26,6 +29,7 @@ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ #ifndef AOCL_GEMM_HALF_PRECISION_TYPE_H #define AOCL_GEMM_HALF_PRECISION_TYPE_H diff --git a/addon/aocl_gemm/aocl_eltwise_ops.c b/addon/aocl_gemm/aocl_eltwise_ops.c new file mode 100644 index 0000000000..72faa9b671 --- /dev/null +++ b/addon/aocl_gemm/aocl_eltwise_ops.c @@ -0,0 +1,275 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "aocl_eltwise_ops_interface_apis.h" +#include "aocl_gemm_check.h" +#include "lpgemm_types.h" +#include "lpgemm_thread_decor_openmp.h" +#include "lpgemm_utils.h" +#include "lpgemm_config.h" +#include "lpgemm_post_ops.h" + +BLIS_INLINE void aocl_eltwise_ops_bf16of32_base + ( + const char order, + const char transa, + const char transb, + const dim_t m, + const dim_t n, + const bfloat16* a, + const dim_t lda, + float* b, + const dim_t ldb, + aocl_post_op* post_op_unparsed, + AOCL_STORAGE_TYPE c_downscale + ) +{ + trans_t blis_transa; + trans_t blis_transb; + + // Check if avx512_vnni ISA is supported, lpgemm matmul only works with it. + if ( bli_cpuid_is_avx512bf16_supported() == FALSE ) + { + bli_print_msg(" AVX512_BF16 ISA not supported by processor, " + "cannot perform bf16bf16f32 gemm.", __FILE__, __LINE__ ); + return; // Error. + } + + /* Initialize BLIS. */ + bli_init_auto(); + + // Set MC, NC, KC, NR, MR. + aocl_lpgemm_init_global_cntx(); + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + bli_param_map_netlib_to_blis_trans(transa, &blis_transa); + bli_param_map_netlib_to_blis_trans(transb, &blis_transb); + + bool is_column_major = ((order == 'c') || (order == 'C')); + + // Column major support disabled for int API's till micro-kernel + // post-ops are updated to account for column major. + if ( ( is_column_major == TRUE ) || + ( bli_is_trans( blis_transa ) ) || + ( bli_is_trans( blis_transb ) ) ) + { + bli_print_msg("Column major and transpose not supported.", + __FILE__, __LINE__); + return; + } + + // The strides are set assuming a row major kernel. + inc_t rs_a = lda; + inc_t cs_a = 1; + inc_t rs_b = ldb; + inc_t cs_b = 1; + + // Convert post op struct to post op linked list format. + lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS]; + err_t err = lpgemm_translate_to_post_ops_list + ( + post_op_unparsed, post_op_list, + NULL, ( void* )( &order ), + m, n + ); + if( err != BLIS_SUCCESS ) return; + + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_g; + bli_rntm_init_from_global( &rntm_g ); + bli_pba_rntm_set_pba( &rntm_g ); + + lpgemm_eltwise_ops_cntx_t* lcntx_g = + lpgemm_eltwise_ops_get_global_cntx_obj( BF16OF32 ); + +#ifdef BLIS_ENABLE_OPENMP + + lpgemm_eltwise_ops_bf16of32_openmp_thread_decorator + ( + m, n, + a, rs_a, cs_a, + b, rs_b, cs_b, + &rntm_g, lcntx_g, + post_op_list, c_downscale + ); +#else + lpgemm_eltwise_ops_bf16of32_thread_decorator + ( + m, n, + a, rs_a, cs_a, + b, rs_b, cs_b, + &rntm_g, lcntx_g, + post_op_list, c_downscale + ); +#endif +} + +AOCL_UTIL_ELTWISE_OPS(bfloat16,float,bf16of32) +{ + AOCL_UTIL_ELTWISE_OPS_CHECK + ( + "bf16of32", + order, transa, transb, + m, n, + a, lda, + b, ldb + ); + + aocl_eltwise_ops_bf16of32_base + ( + order, transa, transb, + m, n, + a, lda, + b, ldb, + post_op_unparsed, F32 + ); +} + +AOCL_UTIL_ELTWISE_OPS(bfloat16,bfloat16,bf16obf16) +{ + AOCL_UTIL_ELTWISE_OPS_CHECK + ( + "bf16obf16", + order, transa, transb, + m, n, + a, lda, + b, ldb + ); + + // Even though b matrix is typecasted to float*, actual load/store + // and matrix traversal will happen as bfloat16* type. This typecast + // is only to ensure code is reused. + aocl_eltwise_ops_bf16of32_base + ( + order, transa, transb, + m, n, + a, lda, + ( float* )b, ldb, + post_op_unparsed, BF16 + ); +} + +AOCL_UTIL_ELTWISE_OPS(float,float,f32of32) +{ + AOCL_UTIL_ELTWISE_OPS_CHECK + ( + "f32of32", + order, transa, transb, + m, n, + a, lda, + b, ldb + ); + + trans_t blis_transa; + trans_t blis_transb; + + // Check if avx512_vnni ISA is supported, lpgemm matmul only works with it. + if ( bli_cpuid_is_avx512bf16_supported() == FALSE ) + { + bli_print_msg(" AVX512_BF16 ISA not supported by processor, " + "cannot perform bf16bf16f32 gemm.", __FILE__, __LINE__ ); + return; // Error. + } + + /* Initialize BLIS. */ + bli_init_auto(); + + // Set MC, NC, KC, NR, MR. + aocl_lpgemm_init_global_cntx(); + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + bli_param_map_netlib_to_blis_trans(transa, &blis_transa); + bli_param_map_netlib_to_blis_trans(transb, &blis_transb); + + bool is_column_major = ((order == 'c') || (order == 'C')); + + // Column major support disabled for int API's till micro-kernel + // post-ops are updated to account for column major. + if ( ( is_column_major == TRUE ) || + ( bli_is_trans( blis_transa ) ) || + ( bli_is_trans( blis_transb ) ) ) + { + bli_print_msg("Column major and transpose not supported.", + __FILE__, __LINE__); + return; + } + + // The strides are set assuming a row major kernel. + inc_t rs_a = lda; + inc_t cs_a = 1; + inc_t rs_b = ldb; + inc_t cs_b = 1; + + // Convert post op struct to post op linked list format. + lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS]; + err_t err = lpgemm_translate_to_post_ops_list + ( + post_op_unparsed, post_op_list, + NULL, ( void* )( &order ), + m, n + ); + if( err != BLIS_SUCCESS ) return; + + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_g; + bli_rntm_init_from_global( &rntm_g ); + bli_pba_rntm_set_pba( &rntm_g ); + + lpgemm_eltwise_ops_cntx_t* lcntx_g = + lpgemm_eltwise_ops_get_global_cntx_obj( F32OF32 ); + +#ifdef BLIS_ENABLE_OPENMP + + lpgemm_eltwise_ops_f32of32_openmp_thread_decorator + ( + m, n, + a, rs_a, cs_a, + b, rs_b, cs_b, + &rntm_g, lcntx_g, + post_op_list, F32 + ); +#else + lpgemm_eltwise_ops_f32of32_thread_decorator + ( + m, n, + a, rs_a, cs_a, + b, rs_b, cs_b, + &rntm_g, lcntx_g, + post_op_list, F32 + ); +#endif +} \ No newline at end of file diff --git a/addon/aocl_gemm/aocl_eltwise_ops_interface_apis.h b/addon/aocl_gemm/aocl_eltwise_ops_interface_apis.h new file mode 100644 index 0000000000..8f057d7fce --- /dev/null +++ b/addon/aocl_gemm/aocl_eltwise_ops_interface_apis.h @@ -0,0 +1,60 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef AOCL_ELTWISE_OPS_INTERFACE_H +#define AOCL_ELTWISE_OPS_INTERFACE_H + +#include "aocl_gemm_post_ops.h" +#include "aocl_bf16_type.h" + +#define AOCL_UTIL_ELTWISE_OPS(A_type,B_type,LP_SFX) \ +BLIS_EXPORT_ADDON void aocl_gemm_eltwise_ops_ ## LP_SFX \ + ( \ + const char order, \ + const char transa, \ + const char transb, \ + const dim_t m, \ + const dim_t n, \ + const A_type* a, \ + const dim_t lda, \ + B_type* b, \ + const dim_t ldb, \ + aocl_post_op* post_op_unparsed \ + ) \ + +AOCL_UTIL_ELTWISE_OPS(bfloat16,float,bf16of32); +AOCL_UTIL_ELTWISE_OPS(bfloat16,bfloat16,bf16obf16); +AOCL_UTIL_ELTWISE_OPS(float,float,f32of32); + +#endif // AOCL_ELTWISE_OPS_INTERFACE_H diff --git a/addon/aocl_gemm/aocl_gemm.h b/addon/aocl_gemm/aocl_gemm.h index 027f895591..070f05bc7d 100644 --- a/addon/aocl_gemm/aocl_gemm.h +++ b/addon/aocl_gemm/aocl_gemm.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -38,17 +38,24 @@ #include "aocl_gemm_post_ops.h" #include "aocl_gemm_interface_apis.h" #include "aocl_util_interface_apis.h" +#include "aocl_eltwise_ops_interface_apis.h" #include "aocl_bf16_type.h" #include "lpgemm_config.h" #include "lpgemm_post_ops.h" #include "lpgemm_kernels.h" +#include "lpgemm_eltwise_ops_kernels.h" #include "lpgemm_utils_kernels.h" #include "lpgemm_pack_bf16.h" #include "lpgemm_packb_s16.h" +#include "lpgemm_packa_s16.h" #include "lpgemm_packa.h" #include "lpgemm_packb.h" #include "lpgemm_packa_s8.h" #include "lpgemm_packb_s8.h" #include "lpgemm_packb_s8s16.h" - +#include "lpgemm_pack_f32.h" +#include "lpgemm_jit_typedefs.h" +#ifdef LPGEMM_BF16_JIT +#include "lpgemm_jit_c_connector.h" +#endif #endif // BLIS_ADDON_LPGEMM diff --git a/addon/aocl_gemm/aocl_gemm_bf16_utils.c b/addon/aocl_gemm/aocl_gemm_bf16_utils.c index de709e8f90..146fd97d8f 100644 --- a/addon/aocl_gemm/aocl_gemm_bf16_utils.c +++ b/addon/aocl_gemm/aocl_gemm_bf16_utils.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -73,11 +73,32 @@ AOCL_GEMM_GET_REORDER_BUF_SIZE(bf16bf16f32of32) // loaded; and since k_dim needs to be at least 2, having n_dim at least 16 // should give 2x16=32 elements, enough for 1 zmm register.The padding is // not rounded to NR (=64), since that would result in memory wastage. - dim_t n_reorder = make_multiple_of_n( n, 16 ); +#if (defined(BLIS_KERNELS_ZEN4) && (!defined(LPGEMM_BF16_JIT))) + dim_t n_reorder; + + if( n == 1 ) + { + n_reorder = 1; + } + else + { + n_reorder = make_multiple_of_n( n, 16 ); + } // Extra space since packing does length in multiples of 2. + dim_t k_reorder; + if( n == 1 ) + { + k_reorder = k; + } + else + { + k_reorder = make_multiple_of_n( k, 2 ); + } +#else + dim_t n_reorder = make_multiple_of_n( n, 16 );; dim_t k_reorder = make_multiple_of_n( k, 2 ); - +#endif siz_t size_req = sizeof( int16_t ) * k_reorder * n_reorder; return size_req; @@ -134,7 +155,23 @@ AOCL_GEMM_REORDER(bfloat16, bf16bf16f32of32) { return; // A reorder not supported. } - +#if (defined(BLIS_KERNELS_ZEN4) && (!defined(LPGEMM_BF16_JIT))) + if( n == 1 ) + { + if( rs_b == 1 ) + { + memcpy( reorder_buf_addr, input_buf_addr, ( k * sizeof( bfloat16 ) ) ); + } + else + { + for( dim_t k0 = 0; k0 < k; k0++ ) + { + reorder_buf_addr[k0] = input_buf_addr[k0*rs_b]; + } + } + return; + } +#endif // Initialize a local runtime with global settings if necessary. Note // that in the case that a runtime is passed in, we make a local copy. rntm_t rntm_g; @@ -157,3 +194,135 @@ AOCL_GEMM_REORDER(bfloat16, bf16bf16f32of32) reorderb_nr64_bf16bf16f32of32( &b, &b_reorder, &rntm_g, lcntx_g ); } + +AOCL_GEMM_GET_REORDER_BUF_SIZE(bf16s4f32of32) +{ + if ((k <= 0) || (n <= 0)) + { + return 0; // Error. + } + + // Check if avx512_bf16 ISA is supported, lpgemm matmul only works with it. + if (bli_cpuid_is_avx512bf16_supported() == FALSE) + { + bli_print_msg(" AVX512_BF16 ISA not supported by processor, " + "cannot perform bf16bf16f32 gemm.", + __FILE__, __LINE__); + return 0; // Error. + } + + /* Initialize BLIS. */ + bli_init_auto(); + + // Set MC, NC, KC, NR, MR. + aocl_lpgemm_init_global_cntx(); + + AOCL_MATRIX_TYPE input_mat_type; + bli_param_map_char_to_lpmat_type(mat_type, &input_mat_type); + + if (input_mat_type == A_MATRIX) + { + return 0; // A reorder not supported. + } + + dim_t n_reorder; + + /*if (n == 1) + { + n_reorder = 1; + } + else*/ + { + n_reorder = make_multiple_of_n(n, 16); + } + + // Extra space since packing does length in multiples of 2. + dim_t k_reorder; + /*if (n == 1) + { + k_reorder = k; + } + else*/ + { + k_reorder = make_multiple_of_n(k, 2); + } + + siz_t size_req = (sizeof(int8_t) * k_reorder * n_reorder)/2; + return size_req; +} + +AOCL_GEMM_REORDER(int8_t, bf16s4f32of32) +{ + trans_t blis_trans; + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + bli_param_map_netlib_to_blis_trans(trans, &blis_trans); + + if ((input_buf_addr == NULL) || (reorder_buf_addr == NULL) || + (k <= 0) || (n <= 0) || (bli_is_notrans(blis_trans) && (ldb < n)) || + (bli_is_trans(blis_trans) && (ldb < k))) + { + return; // Error. + } + + inc_t rs_b, cs_b; + if ((order == 'r') || (order == 'R')) + { + rs_b = bli_is_notrans(blis_trans) ? ldb : 1; + cs_b = bli_is_notrans(blis_trans) ? 1 : ldb; + } + else if ((order == 'c') || (order == 'C')) + { + rs_b = bli_is_notrans(blis_trans) ? 1 : ldb; + cs_b = bli_is_notrans(blis_trans) ? ldb : 1; + } + else + { + return; // Error + } + + // Check if avx512_bf16 ISA is supported, lpgemm matmul only works with it. + if (bli_cpuid_is_avx512bf16_supported() == FALSE) + { + bli_print_msg(" AVX512_BF16 ISA not supported by processor, " + "cannot perform bf16bf16f32 gemm.", + __FILE__, __LINE__); + return; // Error. + } + + /* Initialize BLIS. */ + bli_init_auto(); + + // Set MC, NC, KC, NR, MR. + aocl_lpgemm_init_global_cntx(); + + AOCL_MATRIX_TYPE input_mat_type; + bli_param_map_char_to_lpmat_type(mat_type, &input_mat_type); + + if (input_mat_type == A_MATRIX) + { + return; // A reorder not supported. + } + + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_g; + bli_rntm_init_from_global(&rntm_g); + bli_pba_rntm_set_pba(&rntm_g); + + lpgemm_cntx_t *lcntx_g = lpgemm_get_global_cntx_obj(BF16S4F32OF32); + + // Create dummy b_reorder obj. + lpgemm_obj_t b_reorder; + b_reorder.storage.aligned_buffer = reorder_buf_addr; + + // Create dummy original b obj; + lpgemm_obj_t b; + b.storage.aligned_buffer = (void *)input_buf_addr; + b.rs = rs_b; + b.cs = cs_b; + b.width = n; + b.length = k; + + reorderb_nr64_bf16s4f32of32(&b, &b_reorder, &rntm_g, lcntx_g); +} diff --git a/addon/aocl_gemm/aocl_gemm_bf16bf16f32obf16.c b/addon/aocl_gemm/aocl_gemm_bf16bf16f32obf16.c index 897facfbda..d266dfd051 100644 --- a/addon/aocl_gemm/aocl_gemm_bf16bf16f32obf16.c +++ b/addon/aocl_gemm/aocl_gemm_bf16bf16f32obf16.c @@ -47,19 +47,6 @@ AOCL_GEMM_MATMUL(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16) trans_t blis_transa; trans_t blis_transb; - // There is this use case where lpgemm will be compiled using gcc9.4 - // (where bf16 ISA is not supported), but deployed on a zen4+ sustem - // (which supports bf16 ISA). Here the bf16 kernels will be concealed - // and not compiled, and subsequently this api should error out and - // return early, even if bf16 ISA is supported by machine. -#if defined( BLIS_GCC ) && ( __GNUC__ < 10 ) - { - bli_print_msg("bf16bf16f32obf16 compiled using a compiler not " - "supporting BF16 ISA.", __FILE__, __LINE__ ); - return; // Error. - } -#endif - // Check if avx512_vnni ISA is supported, lpgemm matmul only works with it. if ( bli_cpuid_is_avx512bf16_supported() == FALSE ) { @@ -85,6 +72,18 @@ AOCL_GEMM_MATMUL(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16) c, ldc ); +#ifdef LPGEMM_BF16_JIT + dim_t num_N_variants = ( LPGEMM_BF16_NR / NUM_F32_ELEMS_PER_ZMM ) + 1; + for( dim_t m = 0; m < LPGEMM_BF16_MR; m++ ) + for( dim_t n = 0; n < num_N_variants; n++ ) + if( lpgemm_get_jit_kernel(m, n ) == NULL ) + { + bli_print_msg(" Could not generate bf16bf16f32obf16 " + " kernels using JIT.", __FILE__, __LINE__ ); + return; + } +#endif + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ bli_param_map_netlib_to_blis_trans( transa, &blis_transa ); bli_param_map_netlib_to_blis_trans( transb, &blis_transb ); @@ -166,7 +165,8 @@ AOCL_GEMM_MATMUL(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16) err_t err = lpgemm_translate_to_post_ops_list ( post_op_unparsed, post_op_list, - ( void* )c, ( void* )( &order ) + ( void* )c, ( void* )( &order ), + m, n ); if( err != BLIS_SUCCESS ) return; diff --git a/addon/aocl_gemm/aocl_gemm_bf16bf16f32of32.c b/addon/aocl_gemm/aocl_gemm_bf16bf16f32of32.c index 0ca2602898..cd9c8b7a50 100644 --- a/addon/aocl_gemm/aocl_gemm_bf16bf16f32of32.c +++ b/addon/aocl_gemm/aocl_gemm_bf16bf16f32of32.c @@ -47,19 +47,6 @@ AOCL_GEMM_MATMUL(bfloat16,bfloat16,float,float,bf16bf16f32of32) trans_t blis_transa; trans_t blis_transb; - // There is this use case where lpgemm will be compiled using gcc9.4 - // (where bf16 ISA is not supported), but deployed on a zen4+ sustem - // (which supports bf16 ISA). Here the bf16 kernels will be concealed - // and not compiled, and subsequently this api should error out and - // return early, even if bf16 ISA is supported by machine. -#if defined( BLIS_GCC ) && ( __GNUC__ < 10 ) - { - bli_print_msg("bf16bf16f32of32 compiled using a compiler not " - "supporting BF16 ISA.", __FILE__, __LINE__ ); - return; // Error. - } -#endif - // Check if avx512_vnni ISA is supported, lpgemm matmul only works with it. if ( bli_cpuid_is_avx512bf16_supported() == FALSE ) { @@ -85,6 +72,19 @@ AOCL_GEMM_MATMUL(bfloat16,bfloat16,float,float,bf16bf16f32of32) c, ldc ); +#ifdef LPGEMM_BF16_JIT + dim_t num_N_variants = ( LPGEMM_BF16_NR / NUM_F32_ELEMS_PER_ZMM ) + 1; + for( dim_t m = 0; m < LPGEMM_BF16_MR; m++ ) + for( dim_t n = 0; n < num_N_variants; n++ ) + if( lpgemm_get_jit_kernel(m, n) == NULL ) + { + bli_print_msg(" Could not generate bf16bf16f32of32 " + " kernels using JIT.", __FILE__, __LINE__ ); + return; + } +#endif + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ bli_param_map_netlib_to_blis_trans( transa, &blis_transa ); bli_param_map_netlib_to_blis_trans( transb, &blis_transb ); @@ -166,7 +166,8 @@ AOCL_GEMM_MATMUL(bfloat16,bfloat16,float,float,bf16bf16f32of32) err_t err = lpgemm_translate_to_post_ops_list ( post_op_unparsed, post_op_list, - ( void* )c, ( void* )( &order ) + ( void* )c, ( void* )( &order ), + m, n ); if( err != BLIS_SUCCESS ) return; diff --git a/addon/aocl_gemm/aocl_gemm_bf16s4f32of32.c b/addon/aocl_gemm/aocl_gemm_bf16s4f32of32.c new file mode 100644 index 0000000000..7451ab3cd0 --- /dev/null +++ b/addon/aocl_gemm/aocl_gemm_bf16s4f32of32.c @@ -0,0 +1,398 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "aocl_gemm_interface_apis.h" +#include "aocl_gemm_check.h" +#include "lpgemm_types.h" +#include "lpgemm_post_ops.h" +#include "lpgemm_thread_decor_openmp.h" +#include "lpgemm_5loop_interface_apis.h" +#include "lpgemm_config.h" +#include "lpgemm_utils.h" + +AOCL_GEMM_MATMUL(bfloat16, int8_t, float, float, bf16s4f32of32) +{ + trans_t blis_transa; + trans_t blis_transb; + + // Check if avx512_vnni ISA is supported, lpgemm matmul only works with it. + if (bli_cpuid_is_avx512bf16_supported() == FALSE) + { + bli_print_msg(" AVX512_BF16 ISA not supported by processor, " + "cannot perform bf16bf16f32 gemm.", + __FILE__, __LINE__); + return; // Error. + } + + /* Initialize BLIS. */ + bli_init_auto(); + + // Set MC, NC, KC, NR, MR. + aocl_lpgemm_init_global_cntx(); + + // check for validity of params. + AOCL_GEMM_CHECK( + "bf16s4f32obf16", + order, transa, transb, + m, n, k, + a, lda, mem_format_a, + b, ldb, mem_format_b, + c, ldc); + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + bli_param_map_netlib_to_blis_trans(transa, &blis_transa); + bli_param_map_netlib_to_blis_trans(transb, &blis_transb); + + bool is_row_major = ((order == 'r') || (order == 'R')); + bool is_column_major = ((order == 'c') || (order == 'C')); + + // The strides are set assuming a row major kernel. + inc_t rs_a = lda; + inc_t cs_a = 1; + + if (bli_is_trans(blis_transa)) + { + rs_a = 1; + cs_a = lda; + } + inc_t rs_b = ldb; + inc_t cs_b = 1; + + if (bli_is_trans(blis_transb)) + { + rs_b = 1; + cs_b = ldb; + } + const inc_t rs_c = ldc; + const inc_t cs_c = 1; + + AOCL_MEMORY_TAG mtag_a; + AOCL_MEMORY_TAG mtag_b; + + bli_param_map_char_to_lpmtag(mem_format_a, &mtag_a); + bli_param_map_char_to_lpmtag(mem_format_b, &mtag_b); + + // Reorder is not supported for A matrix + if ((is_row_major == TRUE) && (mtag_a == REORDERED)) + { + bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__); + return; + } + // Inputs swapped in column major, A becomes B from kernel point of view. + // Reorder is not supported for column major matrices. + else if ((is_column_major == TRUE) && ((mtag_b == REORDERED) || (mtag_a == REORDERED))) + { + bli_print_msg(" Reordering of column major matrices is not supported.", __FILE__, __LINE__); + return; + } + + // From 5-loop function point of view + // B matrix needs to be packed in a certain format in order to be loaded + // and used in bf16 instrution. As such the mtag_b always needs to be either + // packed or reordered. B matrix as it is (unpacked) cannot be used, and + // the mtag_b is set to packed to enable runtime packing. + if ((is_row_major == TRUE) && (mtag_b == UNPACKED)) + { + mtag_b = PACK; + } + // Inputs swapped in column major, A becomes B from kernel point of view. + else if ((is_column_major == TRUE) && (mtag_a == UNPACKED)) + { + mtag_a = PACK; + } + + // From 5-loop function point of view, + // A matrix when in column major storage needs to be packed to row-major + // storage as kernel expects A matrix to be in row-major format. + if ((is_row_major == TRUE) && (bli_is_trans(blis_transa))) + { + mtag_a = PACK; + } + // Inputs swapped in column major, A becomes B from kernel point of view. + else if ((is_column_major == TRUE) && (bli_is_trans(blis_transb))) + { + mtag_b = PACK; + } + + // Convert post op struct to post op linked list format. + lpgemm_pre_op pre_op_list[AOCL_MAX_PRE_OPS]; + err_t err = lpgemm_translate_to_pre_ops_list + ( + post_op_unparsed->pre_ops, + pre_op_list, + m, n, k + ); + if (err != BLIS_SUCCESS) + return; + + // Convert post op struct to post op linked list format. + lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS]; + err = lpgemm_translate_to_post_ops_list + ( + post_op_unparsed, + post_op_list, + (void *)c, (void *)(&order), + m, n + ); + if (err != BLIS_SUCCESS) + return; + + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_g; + bli_rntm_init_from_global(&rntm_g); + bli_pba_rntm_set_pba(&rntm_g); + + lpgemm_cntx_t *lcntx_g = lpgemm_get_global_cntx_obj(BF16S4F32OF32); + +#ifdef BLIS_ENABLE_OPENMP + + if (is_column_major == TRUE) + { + // Swapping inputs not possible in case of mixed precision. + bli_print_msg(" column major not supported yet in bf16s4f32o.", __FILE__, __LINE__); + return; + } + else + { + lpgemm_bf16s4f32of32_openmp_thread_decorator + ( + m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + c, rs_c, cs_c, + alpha, beta, + &rntm_g, lcntx_g, pre_op_list, + post_op_list, F32 + ); + } +#else + // Swapping inputs to induce row major computation for column major inputs. + if (is_column_major == TRUE) + { + // Swapping inputs not possible in case of mixed precision. + bli_print_msg(" column major not supported yet in bf16s4f32o.", __FILE__, __LINE__); + return; + } + else + { + lpgemm_bf16s4f32of32_thread_decorator + ( + m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + c, rs_c, cs_c, + alpha, beta, + &rntm_g, lcntx_g, pre_op_list, + post_op_list, F32 + ); + } +#endif +} + +AOCL_GEMM_MATMUL(bfloat16, int8_t, bfloat16, float, bf16s4f32obf16) +{ + trans_t blis_transa; + trans_t blis_transb; + + // Check if avx512_vnni ISA is supported, lpgemm matmul only works with it. + if (bli_cpuid_is_avx512bf16_supported() == FALSE) + { + bli_print_msg(" AVX512_BF16 ISA not supported by processor, " + "cannot perform bf16bf16f32 gemm.", + __FILE__, __LINE__); + return; // Error. + } + + /* Initialize BLIS. */ + bli_init_auto(); + + // Set MC, NC, KC, NR, MR. + aocl_lpgemm_init_global_cntx(); + + // check for validity of params. + AOCL_GEMM_CHECK( + "bf16s4f32of32", + order, transa, transb, + m, n, k, + a, lda, mem_format_a, + b, ldb, mem_format_b, + c, ldc); + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + bli_param_map_netlib_to_blis_trans(transa, &blis_transa); + bli_param_map_netlib_to_blis_trans(transb, &blis_transb); + + bool is_row_major = ((order == 'r') || (order == 'R')); + bool is_column_major = ((order == 'c') || (order == 'C')); + + // The strides are set assuming a row major kernel. + inc_t rs_a = lda; + inc_t cs_a = 1; + + if (bli_is_trans(blis_transa)) + { + rs_a = 1; + cs_a = lda; + } + + inc_t rs_b = ldb; + inc_t cs_b = 1; + + if (bli_is_trans(blis_transb)) + { + rs_b = 1; + cs_b = ldb; + } + const inc_t rs_c = ldc; + const inc_t cs_c = 1; + + AOCL_MEMORY_TAG mtag_a; + AOCL_MEMORY_TAG mtag_b; + + bli_param_map_char_to_lpmtag(mem_format_a, &mtag_a); + bli_param_map_char_to_lpmtag(mem_format_b, &mtag_b); + + // Reorder is not supported for A matrix + if ((is_row_major == TRUE) && (mtag_a == REORDERED)) + { + bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__); + return; + } + // Inputs swapped in column major, A becomes B from kernel point of view. + // Reorder is not supported for column major matrices. + else if ((is_column_major == TRUE) && ((mtag_b == REORDERED) || (mtag_a == REORDERED))) + { + bli_print_msg(" Reordering of column major matrices is not supported.", __FILE__, __LINE__); + return; + } + + // From 5-loop function point of view + // B matrix needs to be packed in a certain format in order to be loaded + // and used in bf16 instrution. As such the mtag_b always needs to be either + // packed or reordered. B matrix as it is (unpacked) cannot be used, and + // the mtag_b is set to packed to enable runtime packing. + if ((is_row_major == TRUE) && (mtag_b == UNPACKED)) + { + mtag_b = PACK; + } + // Inputs swapped in column major, A becomes B from kernel point of view. + else if ((is_column_major == TRUE) && (mtag_a == UNPACKED)) + { + mtag_a = PACK; + } + + // From 5-loop function point of view, + // A matrix when in column major storage needs to be packed to row-major + // storage as kernel expects A matrix to be in row-major format. + if ((is_row_major == TRUE) && (bli_is_trans(blis_transa))) + { + mtag_a = PACK; + } + // Inputs swapped in column major, A becomes B from kernel point of view. + else if ((is_column_major == TRUE) && (bli_is_trans(blis_transb))) + { + mtag_b = PACK; + } + + // Convert post op struct to post op linked list format. + lpgemm_pre_op pre_op_list[AOCL_MAX_PRE_OPS]; + err_t err = lpgemm_translate_to_pre_ops_list( + post_op_unparsed->pre_ops, pre_op_list, + m, n, k); + + if (err != BLIS_SUCCESS) + return; + + // Convert post op struct to post op linked list format. + lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS]; + err = lpgemm_translate_to_post_ops_list( + post_op_unparsed, post_op_list, + (void *)c, (void *)(&order), + m, n); + + if (err != BLIS_SUCCESS) + return; + + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_g; + bli_rntm_init_from_global(&rntm_g); + bli_pba_rntm_set_pba(&rntm_g); + + lpgemm_cntx_t *lcntx_g = lpgemm_get_global_cntx_obj(BF16S4F32OF32); + +#ifdef BLIS_ENABLE_OPENMP + // Swapping inputs to induce row major computation for column major inputs. + if (is_column_major == TRUE) + { + // Swapping inputs not possible in case of mixed precision. + bli_print_msg(" column major not supported yet in bf16s4f32o.", __FILE__, __LINE__); + return; + } + else + { + lpgemm_bf16s4f32of32_openmp_thread_decorator + ( + m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + (float *)c, rs_c, cs_c, + alpha, beta, + &rntm_g, lcntx_g, pre_op_list, + post_op_list, BF16 + ); + } +#else + // Swapping inputs to induce row major computation for column major inputs. + if (is_column_major == TRUE) + { + // Swapping inputs not possible in case of mixed precision. + bli_print_msg(" column major not supported yet in bf16s4f32o.", __FILE__, __LINE__); + return; + } + else + { + lpgemm_bf16s4f32of32_thread_decorator( + m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + (float*)c, rs_c, cs_c, + alpha, beta, + &rntm_g, lcntx_g, pre_op_list, + post_op_list, BF16); + } +#endif +} diff --git a/addon/aocl_gemm/aocl_gemm_check.h b/addon/aocl_gemm/aocl_gemm_check.h index a49fb78007..d47591906b 100644 --- a/addon/aocl_gemm/aocl_gemm_check.h +++ b/addon/aocl_gemm/aocl_gemm_check.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -32,7 +32,6 @@ */ -// yet to add validity check for postops #define AOCL_GEMM_CHECK( op_str, \ order, transa, transb, \ m, n, k, \ @@ -102,3 +101,56 @@ return; \ } \ } + +#define AOCL_UTIL_ELTWISE_OPS_CHECK( op_str, \ + order, transa, transb, \ + m, n, \ + a, lda, \ + b, ldb \ + ) \ +{ \ + int32_t info = 0; \ + bool col_stored, row_stored; \ + bool nota, notb, ta, tb; \ + \ + col_stored = ( order == 'c' ) || ( order == 'C' ); \ + row_stored = ( order == 'r' ) || ( order == 'R' ); \ + \ + nota = ( transa == 'n' ) || ( transa == 'N' ); \ + notb = ( transb == 'n' ) || ( transb == 'N' ); \ + \ + ta = ( transa == 't' ) || ( transa == 'T' ); \ + tb = ( transb == 't' ) || ( transb == 'T' ); \ + \ + if( ( order != 'r') && ( order != 'R' ) && ( order != 'c' ) && ( order != 'C' ) ) \ + info = 1; \ + else if( ( transa != 'n' ) && ( transa != 'N' ) && ( transa != 't' ) && ( transa != 'T' ) ) \ + info = 2; \ + else if( ( transb != 'n' ) && ( transb != 'N' ) && ( transb != 't' ) && ( transb != 'T' ) ) \ + info = 3; \ + else if ( m <= 0 ) \ + info = 4; \ + else if ( n <= 0 ) \ + info = 5; \ + else if ( a == NULL ) \ + info = 6; \ + else if ( row_stored && ( ( nota && ( lda < n ) ) || ( ta && ( lda < m ) ) ) ) \ + info = 7; \ + else if ( col_stored && ( ( nota && ( lda < m ) ) || ( ta && ( lda < n ) ) ) ) \ + info = 8; \ + else if ( b == NULL ) \ + info = 9; \ + else if ( row_stored && ( ( notb && ( ldb < n ) ) || ( tb && ( ldb < m ) ) ) ) \ + info = 10; \ + else if ( col_stored && ( ( notb && ( ldb < m ) ) || ( tb && ( ldb < n ) ) ) ) \ + info = 11; \ + \ + if( info != 0 ) \ + { \ + char print_msg[ 100 ]; \ + \ + sprintf( print_msg, "** On entry to %6s, parameter number %2i had an illegal value", op_str, info); \ + bli_print_msg(print_msg, __FILE__, __LINE__); \ + return; \ + } \ +} diff --git a/addon/aocl_gemm/aocl_gemm_f32f32f32of32.c b/addon/aocl_gemm/aocl_gemm_f32f32f32of32.c index 107b651b71..e3db6e3864 100644 --- a/addon/aocl_gemm/aocl_gemm_f32f32f32of32.c +++ b/addon/aocl_gemm/aocl_gemm_f32f32f32of32.c @@ -80,24 +80,26 @@ AOCL_GEMM_MATMUL(float,float,float,float,f32f32f32of32) bli_param_map_netlib_to_blis_trans( transa, &blis_transa ); bli_param_map_netlib_to_blis_trans( transb, &blis_transb ); - /* Perform BLAS parameter checking. */ - // Transpose not supported. - if ( ( blis_transa != BLIS_NO_TRANSPOSE ) || - ( blis_transb != BLIS_NO_TRANSPOSE ) ) - { - AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, \ - "Input matrix transpose not supported."); - return; // Error. - } - bool is_row_major = ( ( order == 'r' ) || ( order == 'R' ) ); bool is_column_major = ( ( order == 'c' ) || ( order == 'C' ) ); // The strides are set assuming a row major kernel. - const inc_t rs_a = lda; - const inc_t cs_a = 1; - const inc_t rs_b = ldb; - const inc_t cs_b = 1; + inc_t rs_a = lda; + inc_t cs_a = 1; + + if(bli_is_trans(blis_transa)) { + rs_a = 1; + cs_a = lda; + } + + inc_t rs_b = ldb; + inc_t cs_b = 1; + + if(bli_is_trans(blis_transb)) { + rs_b = 1; + cs_b = ldb; + } + const inc_t rs_c = ldc; const inc_t cs_c = 1; @@ -107,11 +109,19 @@ AOCL_GEMM_MATMUL(float,float,float,float,f32f32f32of32) bli_param_map_char_to_lpmtag( mem_format_a, &mtag_a ); bli_param_map_char_to_lpmtag( mem_format_b, &mtag_b ); - if ( ( is_column_major == TRUE ) && ( mtag_b == REORDERED ) ) + // Reordered A not supported now. + if ( ( is_row_major == TRUE ) && ( mtag_a == REORDERED ) ) { - AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, \ - "Reordered B matrix not supported in column major case."); - return; + bli_print_msg(" Reordering of A matrix is not supported.", __FILE__, __LINE__ ); + return; // Error. + } + + // Inputs swapped in column major, A becomes B from kernel point of view. + else if ( ( is_column_major == TRUE ) && ( ( mtag_b == REORDERED ) || (mtag_a == REORDERED ) ) ) + { + bli_print_msg(" Reordering of column major matrices is not supported.", + __FILE__, __LINE__ ); + return; //Error } // By default enable packing for B matrix. Before the 5 loop, based on @@ -127,19 +137,17 @@ AOCL_GEMM_MATMUL(float,float,float,float,f32f32f32of32) mtag_a = PACK; } - // Reordered A not supported now. - if ( ( is_row_major == TRUE ) && ( mtag_a == REORDERED ) ) + // From 5-loop function point of view, + // A matrix when in column major storage needs to be packed to row-major + // storage as kernel expects A matrix to be in row-major format. + if( ( is_row_major == TRUE ) && ( bli_is_trans(blis_transa ) ) ) { - AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, \ - "A matrix reordering not supported for row major inputs."); - return; // Error. + mtag_a = PACK; } // Inputs swapped in column major, A becomes B from kernel point of view. - else if ( ( is_column_major == TRUE ) && ( mtag_b == REORDERED ) ) + else if ( ( is_column_major == TRUE ) && ( bli_is_trans(blis_transb ) ) ) { - AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, \ - "B matrix reordering not supported for column major inputs."); - return; // Error. + mtag_b = PACK; } // Convert post op struct to post op linked list format. @@ -147,7 +155,8 @@ AOCL_GEMM_MATMUL(float,float,float,float,f32f32f32of32) err_t err = lpgemm_translate_to_post_ops_list ( post_op_unparsed, post_op_list, - ( void* )c, ( void* )( &order ) + ( void* )c, ( void* )( &order ), + m, n ); if( err != BLIS_SUCCESS ) return; @@ -159,7 +168,6 @@ AOCL_GEMM_MATMUL(float,float,float,float,f32f32f32of32) bli_pba_rntm_set_pba( &rntm_g ); lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( F32F32F32OF32 ); - #ifdef BLIS_ENABLE_OPENMP // The lpgemm_cntx_t argument will be NULL for f32 since it still uses // BLIS cntx_t internally. Its a workaround for now and will be replaced diff --git a/addon/aocl_gemm/aocl_gemm_f32f32f32of32_utils.c b/addon/aocl_gemm/aocl_gemm_f32f32f32of32_utils.c index 3b801ce0db..d8e3ccb7e8 100644 --- a/addon/aocl_gemm/aocl_gemm_f32f32f32of32_utils.c +++ b/addon/aocl_gemm/aocl_gemm_f32f32f32of32_utils.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -74,7 +74,15 @@ AOCL_GEMM_GET_REORDER_BUF_SIZE(f32f32f32of32) const dim_t NR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NR, cntx ); // Extra space since packing does width in multiples of NR. - const dim_t n_reorder = ( ( n + NR - 1 ) / NR ) * NR; + dim_t n_reorder; + if(n == 1) + { + //When n == 1, LPGEMV doesn't expect B to be reordered. + n_reorder = 1; + }else + { + n_reorder = ( ( n + NR - 1 ) / NR ) * NR; + } siz_t size_req = sizeof( float ) * k * n_reorder; @@ -84,12 +92,34 @@ AOCL_GEMM_GET_REORDER_BUF_SIZE(f32f32f32of32) // Pack B into row stored column panels. AOCL_GEMM_REORDER(float,f32f32f32of32) { + trans_t blis_trans; + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + bli_param_map_netlib_to_blis_trans(trans, &blis_trans); + if ( ( input_buf_addr == NULL ) || ( reorder_buf_addr == NULL ) || - ( k <= 0 ) || ( n <= 0 ) || ( ldb < n ) ) + ( k <= 0 ) || ( n <= 0 ) || ( bli_is_notrans( blis_trans ) && ( ldb < n ) ) || + ( bli_is_trans( blis_trans ) && ( ldb < k ) ) ) { return; // Error. } + // Only supports row major packing now. + inc_t rs_b, cs_b; + if ((order == 'r') || (order == 'R')) + { + rs_b = bli_is_notrans(blis_trans) ? ldb : 1; + cs_b = bli_is_notrans(blis_trans) ? 1 : ldb; + } + else if ((order == 'c') || (order == 'C')) + { + rs_b = bli_is_notrans(blis_trans) ? 1 : ldb; + cs_b = bli_is_notrans(blis_trans) ? ldb : 1; + } + else + { + return; // Error + } + // Check if AVX2 ISA is supported, lpgemm fp32 matmul only works with it. if ( bli_cpuid_is_avx2fma3_supported() == FALSE ) { @@ -97,7 +127,7 @@ AOCL_GEMM_REORDER(float,f32f32f32of32) "cannot perform f32f32f32 gemm.", __FILE__, __LINE__ ); return; // Error. } - + /* Initialize BLIS. */ bli_init_auto(); @@ -121,10 +151,6 @@ AOCL_GEMM_REORDER(float,f32f32f32of32) const dim_t KC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_KC, cntx ); const dim_t NR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NR, cntx ); - // Only supports row major packing now. - inc_t rs_b = ldb; - inc_t cs_b = 1; - inc_t rs_p = NR; float one_local = *PASTEMAC(s,1); @@ -144,6 +170,23 @@ AOCL_GEMM_REORDER(float,f32f32f32of32) dim_t n_threads = bli_rntm_num_threads( &rntm_g ); n_threads = ( n_threads > 0 ) ? n_threads : 1; + //When n == 1, B marix becomes a vector. + //Reordering is avoided so that LPGEMV can process it efficiently. + if(n == 1) + { + if(rs_b == 1) + { + memcpy(reorder_buf_addr, input_buf_addr, (k * sizeof(BLIS_FLOAT))); + }else + { + for(dim_t k0 = 0; k0 < k; k0++) + { + reorder_buf_addr[k0] = input_buf_addr[k0*rs_b]; + } + } + return; + } + #ifdef BLIS_ENABLE_OPENMP _Pragma( "omp parallel num_threads(n_threads)" ) { @@ -162,7 +205,6 @@ AOCL_GEMM_REORDER(float,f32f32f32of32) // gets multiple of NR columns. dim_t jc_start, jc_end; bli_thread_range_sub( &thread_jc, n, NR, FALSE, &jc_start, &jc_end ); - for ( dim_t jc = jc_start; jc < jc_end; jc += NC ) { dim_t nc0 = bli_min( ( jc_end - jc ), NC ); @@ -180,7 +222,7 @@ AOCL_GEMM_REORDER(float,f32f32f32of32) // Compute the total number of iterations we'll need. dim_t n_iter = ( nc0 + NR - 1 ) / NR; - + for ( dim_t pc = 0; pc < k; pc += KC ) { dim_t kc0 = bli_min( ( k - pc ), KC ); diff --git a/addon/aocl_gemm/aocl_gemm_interface_apis.h b/addon/aocl_gemm/aocl_gemm_interface_apis.h index 7009cf1e2e..c1c8709367 100644 --- a/addon/aocl_gemm/aocl_gemm_interface_apis.h +++ b/addon/aocl_gemm/aocl_gemm_interface_apis.h @@ -55,6 +55,8 @@ AOCL_GEMM_GET_REORDER_BUF_SIZE(u8s8s16os16); AOCL_GEMM_GET_REORDER_BUF_SIZE(bf16bf16f32of32); AOCL_GEMM_GET_REORDER_BUF_SIZE(s8s8s32os32); AOCL_GEMM_GET_REORDER_BUF_SIZE(s8s8s16os16); +AOCL_GEMM_GET_REORDER_BUF_SIZE(u8s4s32os32); +AOCL_GEMM_GET_REORDER_BUF_SIZE(bf16s4f32of32); // Performs reordering of input matrix. Reordering is the process of packing // the entire matrix upfront, so that the benefits of packed matrix is obtained @@ -78,10 +80,9 @@ AOCL_GEMM_REORDER(int8_t,u8s8s16os16); AOCL_GEMM_REORDER(bfloat16,bf16bf16f32of32); AOCL_GEMM_REORDER(int8_t,s8s8s32os32); AOCL_GEMM_REORDER(int8_t,s8s8s16os16); +AOCL_GEMM_REORDER(int8_t,u8s4s32os32); +AOCL_GEMM_REORDER(int8_t, bf16s4f32of32); -// Only supports matrices in row major format. This api can perform gemm with -// both normal as well as reordered B matrix as opposesd to sgemm (only -// supports former). This api can be considered analogous to packed sgemm api. #define AOCL_GEMM_MATMUL(A_type,B_type,C_type,Sum_type,LP_SFX) \ BLIS_EXPORT_ADDON void aocl_gemm_ ## LP_SFX \ ( \ @@ -117,4 +118,7 @@ AOCL_GEMM_MATMUL(int8_t,int8_t,int8_t,int32_t,s8s8s32os8); AOCL_GEMM_MATMUL(int8_t,int8_t,int16_t,int16_t,s8s8s16os16); AOCL_GEMM_MATMUL(int8_t,int8_t,int8_t,int16_t,s8s8s16os8); +AOCL_GEMM_MATMUL(bfloat16, int8_t, float, float, bf16s4f32of32); +AOCL_GEMM_MATMUL(bfloat16, int8_t, bfloat16, float, bf16s4f32obf16); + #endif // AOCL_GEMM_INTERFACE_H diff --git a/addon/aocl_gemm/aocl_gemm_post_ops.h b/addon/aocl_gemm/aocl_gemm_post_ops.h index dbf869fae1..5571bc605c 100644 --- a/addon/aocl_gemm/aocl_gemm_post_ops.h +++ b/addon/aocl_gemm/aocl_gemm_post_ops.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -35,7 +35,8 @@ #ifndef AOCL_GEMM_POST_OPS_H #define AOCL_GEMM_POST_OPS_H -#define AOCL_MAX_POST_OPS 5 +#define AOCL_MAX_POST_OPS 8 +#define AOCL_MAX_PRE_OPS 1 typedef enum { @@ -44,6 +45,7 @@ typedef enum GELU_TANH = 2, GELU_ERF = 3, CLIP = 4, + SWISH = 5, } AOCL_ELT_ALGO_TYPE; typedef enum @@ -52,6 +54,8 @@ typedef enum ELTWISE = 2, BIAS = 3, SCALE = 4, + MATRIX_ADD = 5, + MATRIX_MUL = 6, } AOCL_POST_OP_TYPE; typedef struct @@ -67,12 +71,15 @@ typedef struct void* scale_factor; void* buff; void* zero_point; + dim_t scale_factor_len; + dim_t zero_point_len; } aocl_post_op_sum; // Also use for scale. typedef struct { bool is_power_of_2; void* scale_factor; + dim_t scale_factor_len; aocl_eltwise_algo algo; } aocl_post_op_eltwise; @@ -83,9 +90,45 @@ typedef struct typedef struct { - aocl_post_op_sum sum; - aocl_post_op_eltwise* eltwise; //Multiple eltwise allowed. - aocl_post_op_bias bias; + void* matrix; + dim_t ldm; +} aocl_post_op_matrix_add; + +typedef struct +{ + void* matrix; + dim_t ldm; +} aocl_post_op_matrix_mul; +typedef struct +{ + void* zero_point; + //len should be one which is one or n i.e., one zp + //per tensor or one zp per channel respectively + dim_t zero_point_len; +} aocl_pre_op_zp; + +typedef struct +{ + void* scale_factor; + //len should be one which is one or n i.e., one sf + //per tensor or one sf per channel respectively + dim_t scale_factor_len; +} aocl_pre_op_sf; + +typedef struct +{ + aocl_pre_op_zp *b_zp; + aocl_pre_op_sf *b_scl; + dim_t seq_length; +} aocl_pre_op; + +typedef struct +{ + aocl_post_op_sum* sum; // Multiple scale/sum allowed. + aocl_post_op_eltwise* eltwise; // Multiple eltwise allowed. + aocl_post_op_bias* bias; + aocl_post_op_matrix_add* matrix_add; + aocl_post_op_matrix_mul* matrix_mul; // eg: seq_length = 2 dim_t seq_length; @@ -93,6 +136,10 @@ typedef struct // eg: seq_vector[0] = BIAS, seq_vector[1] = ELTWISE means bias followed // by eltwise(relu, if AOCL_ELT_ALGO_TYPE = 1). AOCL_POST_OP_TYPE* seq_vector; + + //Pass pre-op structure also through post-ops + aocl_pre_op *pre_ops; + } aocl_post_op; #endif //AOCL_GEMM_POST_OPS_H diff --git a/addon/aocl_gemm/aocl_gemm_s8s8s16os16.c b/addon/aocl_gemm/aocl_gemm_s8s8s16os16.c index e9533536ab..2f73fcf42b 100644 --- a/addon/aocl_gemm/aocl_gemm_s8s8s16os16.c +++ b/addon/aocl_gemm/aocl_gemm_s8s8s16os16.c @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -78,10 +78,9 @@ AOCL_GEMM_MATMUL(int8_t,int8_t,int16_t,int16_t,s8s8s16os16) /* Perform BLAS parameter checking. */ // Transpose not supported. - if ( ( blis_transa != BLIS_NO_TRANSPOSE ) || - ( blis_transb != BLIS_NO_TRANSPOSE ) ) + if ( ( blis_transb != BLIS_NO_TRANSPOSE ) ) { - bli_print_msg(" Transpose of matrices is not supported.", __FILE__, __LINE__ ); + bli_print_msg(" Transpose of B matrices is not supported.", __FILE__, __LINE__ ); return; // Error. } @@ -91,10 +90,10 @@ AOCL_GEMM_MATMUL(int8_t,int8_t,int16_t,int16_t,s8s8s16os16) return; // Only row major supported. } - const inc_t rs_a = lda; - const inc_t cs_a = 1; - const inc_t rs_b = ldb; - const inc_t cs_b = 1; + inc_t rs_a = lda; + inc_t cs_a = 1; + inc_t rs_b = ldb; + inc_t cs_b = 1; const inc_t rs_c = ldc; const inc_t cs_c = 1; @@ -104,6 +103,16 @@ AOCL_GEMM_MATMUL(int8_t,int8_t,int16_t,int16_t,s8s8s16os16) bli_param_map_char_to_lpmtag(mem_format_a, &mtag_a); bli_param_map_char_to_lpmtag(mem_format_b, &mtag_b); + // Pack is enabled for row major storage when trans A is true. + // Pack tranforms column major matrix to row-major storage as kernel + // expects A matrix to be in row-major format. + if ( bli_is_trans( blis_transa ) ) + { + rs_a = 1; + cs_a = lda; + mtag_a = PACK; + } + // B matrix needs to be packed in a certain format in order to be loaded // and used in VNNI instrution. As such the mtag_b always needs to be either // packed or reordered. B matrix as it is (unpacked) cannot be used, and @@ -113,8 +122,8 @@ AOCL_GEMM_MATMUL(int8_t,int8_t,int16_t,int16_t,s8s8s16os16) mtag_b = PACK; } - // Only unpacked A supported now. - if (mtag_a != UNPACKED) + // Only unpacked A supported now for row-major A matrix. + if ( !( bli_is_trans( blis_transa ) ) && ( mtag_a != UNPACKED ) ) { bli_print_msg(" A matrix needs to be unpacked.", __FILE__, __LINE__ ); return; // Error. @@ -125,7 +134,8 @@ AOCL_GEMM_MATMUL(int8_t,int8_t,int16_t,int16_t,s8s8s16os16) err_t err = lpgemm_translate_to_post_ops_list ( post_op_unparsed, post_op_list, - ( void* )c, ( void* )( &order ) + ( void* )c, ( void* )( &order ), + m, n ); if( err != BLIS_SUCCESS ) return; diff --git a/addon/aocl_gemm/aocl_gemm_s8s8s16os16_utils.c b/addon/aocl_gemm/aocl_gemm_s8s8s16os16_utils.c index 2d02416c6c..093616d2ef 100644 --- a/addon/aocl_gemm/aocl_gemm_s8s8s16os16_utils.c +++ b/addon/aocl_gemm/aocl_gemm_s8s8s16os16_utils.c @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -68,18 +68,33 @@ AOCL_GEMM_GET_REORDER_BUF_SIZE(s8s8s16os16) return 0; // A reorder not supported. } - // Extra space since packing does width in multiples of 16. The vpmaddubsw - // instruction can be used as long as atleast one ymm register can be fully - // loaded; and since k_dim needs to be at least 2, having n_dim atleast 16 - // should give 2x16=32 elements, enough for 1 ymm register.The padding is - // not rounded to NR (=16), since that would result in memory wastage. - dim_t n_reorder = make_multiple_of_n(n, 16); - // Extra space since packing does length in multiples of 2. - dim_t k_reorder = make_multiple_of_n(k, 2); + dim_t n_reorder; + if( n == 1 ) + { + n_reorder = 1; + } + else + { + n_reorder = make_multiple_of_n( n, 16 ); + + } - // Extra memory of n_reorder * sizeof( int16_t ) to store sum of every column of B matrix buffer - siz_t size_req = sizeof(int8_t) * k_reorder * n_reorder + ( n_reorder * sizeof( int16_t )); + // Extra space since packing does length in multiples of 4. + dim_t k_reorder; + if( n == 1 ) + { + k_reorder = k; + } + else + { + k_reorder = make_multiple_of_n( k, 4 ); + } + + // Extra memory of n_reorder * sizeof( int16_t ) + // to store sum of every column of B matrix buffer + siz_t size_req = sizeof(int8_t) * k_reorder * n_reorder + + ( n_reorder * sizeof( int16_t )); return size_req; } @@ -92,6 +107,16 @@ AOCL_GEMM_REORDER(int8_t,s8s8s16os16) return; // Error. } + trans_t blis_trans; + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + bli_param_map_netlib_to_blis_trans(trans, &blis_trans); + + if( bli_is_trans( blis_trans ) ) + { + bli_print_msg(" Transpose of matrix is not supported in " + "s8s8s16 gemm.", __FILE__, __LINE__ ); + return; // Error. + } // Check if AVX2 ISA is supported, lpgemm s8s8s16os16 matmul only works with it. if ( bli_cpuid_is_avx2fma3_supported() == FALSE ) { @@ -114,6 +139,22 @@ AOCL_GEMM_REORDER(int8_t,s8s8s16os16) return; // A reorder not supported. } + if( n == 1 ) + { + int16_t* pack_b_column_sum = ( int16_t* ) ( reorder_buf_addr + + ( sizeof( int8_t ) * n * k )); + + *pack_b_column_sum = 0; + + for( dim_t k0 = 0; k0 < k; k0++ ) + { + reorder_buf_addr[k0] = input_buf_addr[ k0 * ldb ]; + *pack_b_column_sum += reorder_buf_addr[k0]; + } + *pack_b_column_sum *= 128; + return; + } + // Initialize a local runtime with global settings if necessary. Note // that in the case that a runtime is passed in, we make a local copy. rntm_t rntm_g; diff --git a/addon/aocl_gemm/aocl_gemm_s8s8s16os8.c b/addon/aocl_gemm/aocl_gemm_s8s8s16os8.c index 8b30c51801..19bbfff7bd 100644 --- a/addon/aocl_gemm/aocl_gemm_s8s8s16os8.c +++ b/addon/aocl_gemm/aocl_gemm_s8s8s16os8.c @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -78,10 +78,9 @@ AOCL_GEMM_MATMUL(int8_t,int8_t,int8_t,int16_t,s8s8s16os8) /* Perform BLAS parameter checking. */ // Transpose not supported. - if ( ( blis_transa != BLIS_NO_TRANSPOSE ) || - ( blis_transb != BLIS_NO_TRANSPOSE ) ) + if ( ( blis_transb != BLIS_NO_TRANSPOSE ) ) { - bli_print_msg(" Transpose of matrices is not supported.", __FILE__, __LINE__ ); + bli_print_msg(" Transpose of B matrices is not supported.", __FILE__, __LINE__ ); return; // Error. } @@ -91,10 +90,10 @@ AOCL_GEMM_MATMUL(int8_t,int8_t,int8_t,int16_t,s8s8s16os8) return; // Only row major supported. } - const inc_t rs_a = lda; - const inc_t cs_a = 1; - const inc_t rs_b = ldb; - const inc_t cs_b = 1; + inc_t rs_a = lda; + inc_t cs_a = 1; + inc_t rs_b = ldb; + inc_t cs_b = 1; const inc_t rs_c = ldc; const inc_t cs_c = 1; @@ -104,6 +103,16 @@ AOCL_GEMM_MATMUL(int8_t,int8_t,int8_t,int16_t,s8s8s16os8) bli_param_map_char_to_lpmtag(mem_format_a, &mtag_a); bli_param_map_char_to_lpmtag(mem_format_b, &mtag_b); + // Pack is enabled for row major storage when trans A is true. + // Pack tranforms column major matrix to row-major storage as kernel + // expects A matrix to be in row-major format. + if ( bli_is_trans( blis_transa ) ) + { + rs_a = 1; + cs_a = lda; + mtag_a = PACK; + } + // B matrix needs to be packed in a certain format in order to be loaded // and used in VNNI instrution. As such the mtag_b always needs to be either // packed or reordered. B matrix as it is (unpacked) cannot be used, and @@ -113,8 +122,8 @@ AOCL_GEMM_MATMUL(int8_t,int8_t,int8_t,int16_t,s8s8s16os8) mtag_b = PACK; } - // Only unpacked A supported now. - if (mtag_a != UNPACKED) + // Only unpacked A supported now for row-major A matrix. + if ( !( bli_is_trans( blis_transa ) ) && ( mtag_a != UNPACKED ) ) { bli_print_msg(" A matrix needs to be unpacked.", __FILE__, __LINE__ ); return; // Error. @@ -125,7 +134,8 @@ AOCL_GEMM_MATMUL(int8_t,int8_t,int8_t,int16_t,s8s8s16os8) err_t err = lpgemm_translate_to_post_ops_list ( post_op_unparsed, post_op_list, - ( void* )c, ( void* )( &order ) + ( void* )c, ( void* )( &order ), + m, n ); if( err != BLIS_SUCCESS ) return; diff --git a/addon/aocl_gemm/aocl_gemm_s8s8s32os32.c b/addon/aocl_gemm/aocl_gemm_s8s8s32os32.c index 413de3f543..747f9155e0 100644 --- a/addon/aocl_gemm/aocl_gemm_s8s8s32os32.c +++ b/addon/aocl_gemm/aocl_gemm_s8s8s32os32.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -60,7 +60,7 @@ AOCL_GEMM_MATMUL(int8_t,int8_t,int32_t,int32_t,s8s8s32os32) // Set MC, NC, KC, NR, MR. aocl_lpgemm_init_global_cntx(); - + // check for validity of params. AOCL_GEMM_CHECK ( @@ -76,25 +76,36 @@ AOCL_GEMM_MATMUL(int8_t,int8_t,int32_t,int32_t,s8s8s32os32) bli_param_map_netlib_to_blis_trans( transa, &blis_transa ); bli_param_map_netlib_to_blis_trans( transb, &blis_transb ); - /* Perform BLAS parameter checking. */ - // Transpose not supported. - if ( ( blis_transa != BLIS_NO_TRANSPOSE ) || - ( blis_transb != BLIS_NO_TRANSPOSE ) ) + bool is_row_major = ((order == 'r') || (order == 'R')); + bool is_column_major = ((order == 'c') || (order == 'C')); + + // Column major support disabled for int API's till micro-kernel + // post-ops are updated to account for column major. + if ( (is_column_major == TRUE) && (post_op_unparsed != NULL) ) { - bli_print_msg(" Transpose of matrices is not supported.", __FILE__, __LINE__ ); - return; // Error. + bli_print_msg("Column major inputs not supported with Post-ops.", + __FILE__, __LINE__); + return; + } + + inc_t rs_a = lda; + inc_t cs_a = 1; + + if (bli_is_trans(blis_transa)) + { + rs_a = 1; + cs_a = lda; } - if ( ( order != 'r' ) && ( order != 'R' ) ) + inc_t rs_b = ldb; + inc_t cs_b = 1; + + if (bli_is_trans(blis_transb)) { - bli_print_msg(" Operation only supports row-major matrices.", __FILE__, __LINE__ ); - return; // Only row major supported. + rs_b = 1; + cs_b = ldb; } - const inc_t rs_a = lda; - const inc_t cs_a = 1; - const inc_t rs_b = ldb; - const inc_t cs_b = 1; const inc_t rs_c = ldc; const inc_t cs_c = 1; @@ -104,20 +115,49 @@ AOCL_GEMM_MATMUL(int8_t,int8_t,int32_t,int32_t,s8s8s32os32) bli_param_map_char_to_lpmtag( mem_format_a, &mtag_a ); bli_param_map_char_to_lpmtag( mem_format_b, &mtag_b ); + // Reorder is not supported for A matrix + if ((is_row_major == TRUE) && (mtag_a == REORDERED)) + { + bli_print_msg(" Reordering of A matrix is not supported " + "in row major case.", __FILE__, __LINE__); + return; + } + // Inputs swapped in column major, A becomes B from kernel point of view. + // Reorder is not supported for column major matrices. + else if ((is_column_major == TRUE) && + ((mtag_b == REORDERED) || (mtag_a == REORDERED))) + { + bli_print_msg(" Reordering of column major matrices " + "is not supported.", __FILE__, __LINE__); + return; + } + + // From 5-loop function point of view // B matrix needs to be packed in a certain format in order to be loaded - // and used in VNNI instrution. As such the mtag_b always needs to be either + // and used in bf16 instrution. As such the mtag_b always needs to be either // packed or reordered. B matrix as it is (unpacked) cannot be used, and // the mtag_b is set to packed to enable runtime packing. - if ( mtag_b == UNPACKED ) + if ((is_row_major == TRUE) && (mtag_b == UNPACKED)) { mtag_b = PACK; } + // Inputs swapped in column major, A becomes B from kernel point of view. + else if ((is_column_major == TRUE) && (mtag_a == UNPACKED)) + { + mtag_a = PACK; + } - // Only unpacked A supported now. - if ( mtag_a != UNPACKED ) + // From 5-loop function point of view, + // A matrix when in column major storage needs to be packed to row-major + // storage as kernel expects A matrix to be in row-major format. + if ((is_row_major == TRUE) && (bli_is_trans(blis_transa))) { - bli_print_msg(" A matrix needs to be unpacked.", __FILE__, __LINE__ ); - return; // Error. + mtag_a = PACK; + } + // Inputs swapped in column major, A becomes B from kernel point of view. + else if ((is_column_major == TRUE) && (bli_is_trans(blis_transb))) + { + mtag_b = PACK; } // Convert post op struct to post op linked list format. @@ -125,7 +165,8 @@ AOCL_GEMM_MATMUL(int8_t,int8_t,int32_t,int32_t,s8s8s32os32) err_t err = lpgemm_translate_to_post_ops_list ( post_op_unparsed, post_op_list, - ( void* )c, ( void* )( &order ) + ( void* )c, ( void* )( &order ), + m, n ); if( err != BLIS_SUCCESS ) return; @@ -139,26 +180,52 @@ AOCL_GEMM_MATMUL(int8_t,int8_t,int32_t,int32_t,s8s8s32os32) lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( S8S8S32OS32 ); #ifdef BLIS_ENABLE_OPENMP - lpgemm_s8s8s32o32_openmp_thread_decorator - ( - m, n, k, - a, rs_a, cs_a, mtag_a, - b, rs_b, cs_b, mtag_b, - c, rs_c, cs_c, - alpha, beta, - &rntm_g, lcntx_g, - post_op_list, S32 - ); + // Swapping inputs to induce row major computation for column major inputs. + if (is_column_major == TRUE) + { + lpgemm_s8s8s32o32_openmp_thread_decorator( + n, m, k, + b, rs_b, cs_b, mtag_b, + a, rs_a, cs_a, mtag_a, + (int32_t *)c, rs_c, cs_c, + alpha, beta, + &rntm_g, lcntx_g, + post_op_list, S32); + } + else + { + lpgemm_s8s8s32o32_openmp_thread_decorator( + m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + (int32_t *)c, rs_c, cs_c, + alpha, beta, + &rntm_g, lcntx_g, + post_op_list, S32); + } #else - lpgemm_s8s8s32o32_thread_decorator - ( - m, n, k, - a, rs_a, cs_a, mtag_a, - b, rs_b, cs_b, mtag_b, - c, rs_c, cs_c, - alpha, beta, - &rntm_g, lcntx_g, - post_op_list, S32 - ); + // Swapping inputs to induce row major computation for column major inputs. + if (is_column_major == TRUE) + { + lpgemm_s8s8s32o32_thread_decorator( + n, m, k, + b, rs_b, cs_b, mtag_b, + a, rs_a, cs_a, mtag_a, + (int32_t *)c, rs_c, cs_c, + alpha, beta, + &rntm_g, lcntx_g, + post_op_list, S32); + } + else + { + lpgemm_s8s8s32o32_thread_decorator( + m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + (int32_t *)c, rs_c, cs_c, + alpha, beta, + &rntm_g, lcntx_g, + post_op_list, S32); + } #endif } diff --git a/addon/aocl_gemm/aocl_gemm_s8s8s32os32_utils.c b/addon/aocl_gemm/aocl_gemm_s8s8s32os32_utils.c index ef4484aee5..c017eb0c3e 100644 --- a/addon/aocl_gemm/aocl_gemm_s8s8s32os32_utils.c +++ b/addon/aocl_gemm/aocl_gemm_s8s8s32os32_utils.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -73,11 +73,32 @@ AOCL_GEMM_GET_REORDER_BUF_SIZE(s8s8s32os32) // loaded; and since k_dim needs to be atleast 4, having n_dim atleast 16 // should give 4x16=64 elements, enough for 1 zmm register.The padding is // not rounded to NR (=64), since that would result in memory wastage. - dim_t n_reorder = make_multiple_of_n( n, 16 ); +#ifdef BLIS_KERNELS_ZEN4 + dim_t n_reorder; + if( n == 1 ) + { + n_reorder = 1; + } + else + { + n_reorder = make_multiple_of_n( n, 16 ); + + } // Extra space since packing does length in multiples of 4. + dim_t k_reorder; + if( n == 1 ) + { + k_reorder = k; + } + else + { + k_reorder = make_multiple_of_n( k, 4 ); + } +#else + dim_t n_reorder = make_multiple_of_n( n, 16 ); dim_t k_reorder = make_multiple_of_n( k, 4 ); - +#endif //extra memory of n_reorder * sizeof(int32_t) to store sum of every column of B matrix buffer siz_t size_req = sizeof( int8_t ) * k_reorder * n_reorder + ( n_reorder * sizeof( int32_t ) ); @@ -86,12 +107,33 @@ AOCL_GEMM_GET_REORDER_BUF_SIZE(s8s8s32os32) AOCL_GEMM_REORDER(int8_t,s8s8s32os32) { - if ( ( input_buf_addr == NULL ) || ( reorder_buf_addr == NULL ) || - ( k <= 0 ) || ( n <= 0 ) || ( ldb < n ) ) + trans_t blis_trans; + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + bli_param_map_netlib_to_blis_trans(trans, &blis_trans); + + if ((input_buf_addr == NULL) || (reorder_buf_addr == NULL) || + (k <= 0) || (n <= 0) || (bli_is_notrans(blis_trans) && (ldb < n)) || + (bli_is_trans(blis_trans) && (ldb < k)) ) { return; // Error. } + inc_t rs_b, cs_b; + if ((order == 'r') || (order == 'R')) + { + rs_b = bli_is_notrans(blis_trans) ? ldb : 1; + cs_b = bli_is_notrans(blis_trans) ? 1 : ldb; + } + else if ((order == 'c') || (order == 'C')) + { + rs_b = bli_is_notrans(blis_trans) ? 1 : ldb; + cs_b = bli_is_notrans(blis_trans) ? ldb : 1; + } + else + { + return; // Error + } + // Check if avx512_vnni ISA is supported, lpgemm matmul only works with it. if ( bli_cpuid_is_avx512vnni_supported() == FALSE ) { @@ -113,7 +155,23 @@ AOCL_GEMM_REORDER(int8_t,s8s8s32os32) { return; // A reorder not supported. } - +#ifdef BLIS_KERNELS_ZEN4 + if( n == 1 ) + { + int32_t* pack_b_column_sum = ( int32_t* ) ( reorder_buf_addr + + ( sizeof( int8_t ) * n * k )); + + *pack_b_column_sum = 0; + + for( dim_t k0 = 0; k0 < k; k0++ ) + { + reorder_buf_addr[k0] = input_buf_addr[ k0 * rs_b ]; + *pack_b_column_sum += reorder_buf_addr[k0]; + } + *pack_b_column_sum *= 128; + return; + } +#endif // Initialize a local runtime with global settings if necessary. Note // that in the case that a runtime is passed in, we make a local copy. rntm_t rntm_g; @@ -129,7 +187,8 @@ AOCL_GEMM_REORDER(int8_t,s8s8s32os32) // Create dummy original b obj; lpgemm_obj_t b; b.storage.aligned_buffer = ( void* )input_buf_addr; - b.rs = ldb; + b.rs = rs_b; + b.cs = cs_b; b.width = n; b.length = k; diff --git a/addon/aocl_gemm/aocl_gemm_s8s8s32os8.c b/addon/aocl_gemm/aocl_gemm_s8s8s32os8.c index 5e7f3ec71c..ffeef5ba15 100644 --- a/addon/aocl_gemm/aocl_gemm_s8s8s32os8.c +++ b/addon/aocl_gemm/aocl_gemm_s8s8s32os8.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -76,48 +76,88 @@ AOCL_GEMM_MATMUL(int8_t,int8_t,int8_t,int32_t,s8s8s32os8) bli_param_map_netlib_to_blis_trans( transa, &blis_transa ); bli_param_map_netlib_to_blis_trans( transb, &blis_transb ); - /* Perform BLAS parameter checking. */ - // Transpose not supported. - if ( ( blis_transa != BLIS_NO_TRANSPOSE ) || - ( blis_transb != BLIS_NO_TRANSPOSE ) ) + bool is_row_major = ((order == 'r') || (order == 'R')); + bool is_column_major = ((order == 'c') || (order == 'C')); + + // Column major support disabled for int API's till micro-kernel + // post-ops are updated to account for column major. + if ( (is_column_major == TRUE) && (post_op_unparsed != NULL) ) { - bli_print_msg(" Transpose of matrices is not supported.", __FILE__, __LINE__ ); - return; // Error. + bli_print_msg("Column major inputs not supported with Post-ops.", + __FILE__, __LINE__); + return; } + + // The strides are set assuming a row major kernel. + inc_t rs_a = lda; + inc_t cs_a = 1; - if ( ( order != 'r' ) && ( order != 'R' ) ) + if (bli_is_trans(blis_transa)) { - bli_print_msg(" Operation only supports row-major matrices.", __FILE__, __LINE__ ); - return; // Only row major supported. + rs_a = 1; + cs_a = lda; } - const inc_t rs_a = lda; - const inc_t cs_a = 1; - const inc_t rs_b = ldb; - const inc_t cs_b = 1; + inc_t rs_b = ldb; + inc_t cs_b = 1; + + if (bli_is_trans(blis_transb)) + { + rs_b = 1; + cs_b = ldb; + } const inc_t rs_c = ldc; const inc_t cs_c = 1; AOCL_MEMORY_TAG mtag_a; AOCL_MEMORY_TAG mtag_b; - bli_param_map_char_to_lpmtag( mem_format_a, &mtag_a ); - bli_param_map_char_to_lpmtag( mem_format_b, &mtag_b ); + bli_param_map_char_to_lpmtag(mem_format_a, &mtag_a); + bli_param_map_char_to_lpmtag(mem_format_b, &mtag_b); + // Reorder is not supported for A matrix + if ((is_row_major == TRUE) && (mtag_a == REORDERED)) + { + bli_print_msg(" Reordering of A matrix is not supported in " + " row major case.", __FILE__, __LINE__); + return; + } + // Inputs swapped in column major, A becomes B from kernel point of view. + // Reorder is not supported for column major matrices. + else if ((is_column_major == TRUE) && + ((mtag_b == REORDERED) || (mtag_a == REORDERED))) + { + bli_print_msg(" Reordering of column major matrices is " + " not supported.", __FILE__, __LINE__); + return; + } + + // From 5-loop function point of view // B matrix needs to be packed in a certain format in order to be loaded - // and used in VNNI instrution. As such the mtag_b always needs to be either + // and used in bf16 instrution. As such the mtag_b always needs to be either // packed or reordered. B matrix as it is (unpacked) cannot be used, and // the mtag_b is set to packed to enable runtime packing. - if ( mtag_b == UNPACKED ) + if ((is_row_major == TRUE) && (mtag_b == UNPACKED)) { mtag_b = PACK; } + // Inputs swapped in column major, A becomes B from kernel point of view. + else if ((is_column_major == TRUE) && (mtag_a == UNPACKED)) + { + mtag_a = PACK; + } - // Only unpacked A supported now. - if ( mtag_a != UNPACKED ) + // From 5-loop function point of view, + // A matrix when in column major storage needs to be packed to row-major + // storage as kernel expects A matrix to be in row-major format. + if ((is_row_major == TRUE) && (bli_is_trans(blis_transa))) { - bli_print_msg(" A matrix needs to be unpacked.", __FILE__, __LINE__ ); - return; // Error. + mtag_a = PACK; + } + // Inputs swapped in column major, A becomes B from kernel point of view. + else if ((is_column_major == TRUE) && (bli_is_trans(blis_transb))) + { + mtag_b = PACK; } // Convert post op struct to post op linked list format. @@ -125,7 +165,8 @@ AOCL_GEMM_MATMUL(int8_t,int8_t,int8_t,int32_t,s8s8s32os8) err_t err = lpgemm_translate_to_post_ops_list ( post_op_unparsed, post_op_list, - ( void* )c, ( void* )( &order ) + ( void* )c, ( void* )( &order ), + m, n ); if( err != BLIS_SUCCESS ) return; @@ -139,26 +180,59 @@ AOCL_GEMM_MATMUL(int8_t,int8_t,int8_t,int32_t,s8s8s32os8) lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( S8S8S32OS32 ); #ifdef BLIS_ENABLE_OPENMP - lpgemm_s8s8s32o32_openmp_thread_decorator - ( - m, n, k, - a, rs_a, cs_a, mtag_a, - b, rs_b, cs_b, mtag_b, - ( int32_t* )c, rs_c, cs_c, - alpha, beta, - &rntm_g, lcntx_g, - post_op_list, S8 - ); + // Swapping inputs to induce row major computation for column major inputs. + if (is_column_major == TRUE) + { + lpgemm_s8s8s32o32_openmp_thread_decorator + ( + n, m, k, + b, rs_b, cs_b, mtag_b, + a, rs_a, cs_a, mtag_a, + (int32_t *)c, rs_c, cs_c, + alpha, beta, + &rntm_g, lcntx_g, + post_op_list, S8 + ); + } + else + { + lpgemm_s8s8s32o32_openmp_thread_decorator + ( + m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + (int32_t *)c, rs_c, cs_c, + alpha, beta, + &rntm_g, lcntx_g, + post_op_list, S8 + ); + } #else - lpgemm_s8s8s32o32_thread_decorator - ( - m, n, k, - a, rs_a, cs_a, mtag_a, - b, rs_b, cs_b, mtag_b, - ( int32_t* )c, rs_c, cs_c, - alpha, beta, - &rntm_g, lcntx_g, - post_op_list, S8 - ); + // Swapping inputs to induce row major computation for column major inputs. + if (is_column_major == TRUE) + { + lpgemm_s8s8s32o32_thread_decorator + ( + n, m, k, + b, rs_b, cs_b, mtag_b, + a, rs_a, cs_a, mtag_a, + (int32_t *)c, rs_c, cs_c, + alpha, beta, + &rntm_g, lcntx_g, + post_op_list, S8); + } + else + { + lpgemm_s8s8s32o32_thread_decorator + ( + m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + (int32_t *)c, rs_c, cs_c, + alpha, beta, + &rntm_g, lcntx_g, + post_op_list, S8 + ); + } #endif } diff --git a/addon/aocl_gemm/aocl_gemm_u8s4s32os32_utils.c b/addon/aocl_gemm/aocl_gemm_u8s4s32os32_utils.c new file mode 100644 index 0000000000..74f0c0cb65 --- /dev/null +++ b/addon/aocl_gemm/aocl_gemm_u8s4s32os32_utils.c @@ -0,0 +1,209 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "aocl_gemm_interface_apis.h" +#include "lpgemm_types.h" +#include "lpgemm_config.h" +#include "lpgemm_utils.h" +#include "lpgemm_reorder.h" + +AOCL_GEMM_GET_REORDER_BUF_SIZE(u8s4s32os32) +{ + if ( ( k <= 0 ) || ( n <= 0 ) ) + { + return 0; // Error. + } + + if ( ( order != 'r' ) && ( order != 'R' ) ) + { + return 0; //Only row major suppored for int4 reordering. + } + + // Check if avx512_vnni ISA is supported, lpgemm matmul only works with it. + if ( bli_cpuid_is_avx512vnni_supported() == FALSE ) + { + bli_print_msg(" AVX512_VNNI ISA not supported by processor, " + "cannot perform int4 reordering.", __FILE__, __LINE__ ); + return 0; // Error. + } + + /* Initialize BLIS. */ + bli_init_auto(); + + // Set MC, NC, KC, NR, MR. + aocl_lpgemm_init_global_cntx(); + + AOCL_MATRIX_TYPE input_mat_type; + bli_param_map_char_to_lpmat_type( mat_type, &input_mat_type ); + + if ( input_mat_type == A_MATRIX ) + { + return 0; // A reorder not supported. + } + + // Extra space since packing does width in multiples of 16. The vnni + // instruction can be used as long as at least one zmm register can be fully + // loaded; and since k_dim needs to be at least 4, having n_dim at least 16 + // should give 4x16=64 elements, enough for 1 zmm register.The padding is + // not rounded to NR (=64), since that would result in memory wastage. +#ifdef BLIS_KERNELS_ZEN4 + dim_t n_reorder; + if( n == 1 ) + { + n_reorder = 1; + } + else + { + n_reorder = make_multiple_of_n( n, 16 ); + } + + // Extra space since packing does length in multiples of 4. + dim_t k_reorder; + if( n == 1 ) + { + k_reorder = k; + } + else + { + k_reorder = make_multiple_of_n( k, 4 ); + } +#else + dim_t n_reorder = make_multiple_of_n( n, 16 ); + dim_t k_reorder = make_multiple_of_n( k, 4 ); +#endif + + siz_t size_req = sizeof( int8_t ) * k_reorder * n_reorder; + + return size_req; +} + +AOCL_GEMM_REORDER(int8_t,u8s4s32os32) +{ + trans_t blis_trans; + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + bli_param_map_netlib_to_blis_trans(trans, &blis_trans); + + // Transpose not supported for int4 reordering. + if ( ( input_buf_addr == NULL ) || ( reorder_buf_addr == NULL ) || + ( k <= 0 ) || ( n <= 0 ) || ( bli_is_trans( blis_trans ) ) || + ( bli_is_notrans( blis_trans ) && ( ldb < n ) ) ) + { + return; // Error. + } + + if ( ( order != 'r' ) && ( order != 'R' ) ) + { + bli_print_msg(" Only row major int4 matrix reordering supported.", + __FILE__, __LINE__ ); + return; //Only row major suppored for int4 reordering. + } + + inc_t rs_b = ldb; + inc_t cs_b = 1; + + // Check if avx512_vnni ISA is supported, lpgemm matmul only works with it. + if ( bli_cpuid_is_avx512vnni_supported() == FALSE ) + { + bli_print_msg(" AVX512_VNNI ISA not supported by processor, " + "cannot perform int4 reordering.", __FILE__, __LINE__ ); + return; // Error. + } + + /* Initialize BLIS. */ + bli_init_auto(); + + // Set MC, NC, KC, NR, MR. + aocl_lpgemm_init_global_cntx(); + + AOCL_MATRIX_TYPE input_mat_type; + bli_param_map_char_to_lpmat_type( mat_type, &input_mat_type ); + + if ( input_mat_type == A_MATRIX ) + { + bli_print_msg(" Only int4 B matrix reordering supported.", + __FILE__, __LINE__ ); + return; // A reorder not supported. + } + +#ifdef BLIS_KERNELS_ZEN4 + if( n == 1 ) + { + for ( dim_t ii = 0; ii < k; ++ii ) + { + int8_t lo_val; + dim_t b_inc = ii * rs_b; + // Even index will have data at low 4 bits, and odd at hi 4 bits. + if ( ( b_inc % 2 ) != 0 ) + { + lo_val = ( input_buf_addr[( b_inc / 2 )] >> 4 ) & 0x0F; + } + else + { + lo_val = input_buf_addr[( b_inc / 2 )] & 0x0F; + } + + // Signed scale. + if ( lo_val & 0x08 ) + { + lo_val = lo_val | 0xF0; + } + reorder_buf_addr[ii] = lo_val; + } + return; + } +#endif + + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_g; + bli_rntm_init_from_global( &rntm_g ); + bli_pba_rntm_set_pba( &rntm_g ); + + lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( U8S4S32OS32 ); + + // Create dummy b_reorder obj. + lpgemm_obj_t b_reorder; + b_reorder.storage.aligned_buffer = reorder_buf_addr; + + // Create dummy original b obj; + lpgemm_obj_t b; + b.storage.aligned_buffer = ( void* )input_buf_addr; + b.rs = rs_b; + b.cs = cs_b; + b.width = n; + b.length = k; + + reorderb_nr64_u8s4s32o32( &b, &b_reorder, &rntm_g, lcntx_g ); +} diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s16os16.c b/addon/aocl_gemm/aocl_gemm_u8s8s16os16.c index c0614c643b..d6b179f29b 100644 --- a/addon/aocl_gemm/aocl_gemm_u8s8s16os16.c +++ b/addon/aocl_gemm/aocl_gemm_u8s8s16os16.c @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -78,10 +78,9 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int16_t,int16_t,u8s8s16os16) /* Perform BLAS parameter checking. */ // Transpose not supported. - if ( ( blis_transa != BLIS_NO_TRANSPOSE ) || - ( blis_transb != BLIS_NO_TRANSPOSE ) ) + if ( ( blis_transb != BLIS_NO_TRANSPOSE ) ) { - bli_print_msg(" Transpose of matrices is not supported.", __FILE__, __LINE__ ); + bli_print_msg(" Transpose of B matrices is not supported.", __FILE__, __LINE__ ); return; // Error. } @@ -91,10 +90,10 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int16_t,int16_t,u8s8s16os16) return; // Only row major supported. } - const inc_t rs_a = lda; - const inc_t cs_a = 1; - const inc_t rs_b = ldb; - const inc_t cs_b = 1; + inc_t rs_a = lda; + inc_t cs_a = 1; + inc_t rs_b = ldb; + inc_t cs_b = 1; const inc_t rs_c = ldc; const inc_t cs_c = 1; @@ -104,6 +103,16 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int16_t,int16_t,u8s8s16os16) bli_param_map_char_to_lpmtag(mem_format_a, &mtag_a); bli_param_map_char_to_lpmtag(mem_format_b, &mtag_b); + // Pack is enabled for row major storage when trans A is true. + // Pack tranforms column major matrix to row-major storage as kernel + // expects A matrix to be in row-major format. + if ( bli_is_trans( blis_transa ) ) + { + rs_a = 1; + cs_a = lda; + mtag_a = PACK; + } + // B matrix needs to be packed in a certain format in order to be loaded // and used in VNNI instrution. As such the mtag_b always needs to be either // packed or reordered. B matrix as it is (unpacked) cannot be used, and @@ -113,8 +122,8 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int16_t,int16_t,u8s8s16os16) mtag_b = PACK; } - // Only unpacked A supported now. - if (mtag_a != UNPACKED) + // Only unpacked A supported now for row-major A matrix. + if ( !( bli_is_trans( blis_transa ) ) && ( mtag_a != UNPACKED ) ) { bli_print_msg(" A matrix needs to be unpacked.", __FILE__, __LINE__ ); return; // Error. @@ -125,7 +134,8 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int16_t,int16_t,u8s8s16os16) err_t err = lpgemm_translate_to_post_ops_list ( post_op_unparsed, post_op_list, - ( void* )c, ( void* )( &order ) + ( void* )c, ( void* )( &order ), + m, n ); if( err != BLIS_SUCCESS ) return; diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s16os16_utils.c b/addon/aocl_gemm/aocl_gemm_u8s8s16os16_utils.c index fd0c64203f..60707f7cc9 100644 --- a/addon/aocl_gemm/aocl_gemm_u8s8s16os16_utils.c +++ b/addon/aocl_gemm/aocl_gemm_u8s8s16os16_utils.c @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -73,10 +73,26 @@ AOCL_GEMM_GET_REORDER_BUF_SIZE(u8s8s16os16) // loaded; and since k_dim needs to be at least 2, having n_dim at least 16 // should give 2x16=32 elements, enough for 1 ymm register.The padding is // not rounded to NR (=16), since that would result in memory wastage. - dim_t n_reorder = make_multiple_of_n(n, 16); - // Extra space since packing does length in multiples of 2. - dim_t k_reorder = make_multiple_of_n(k, 2); + dim_t n_reorder; + if( n == 1 ) + { + n_reorder = 1; + } + else + { + n_reorder = make_multiple_of_n( n, 16 ); + } + + dim_t k_reorder; + if( n == 1 ) + { + k_reorder = k; + } + else + { + k_reorder = make_multiple_of_n( k, 2 ); + } siz_t size_req = sizeof(int8_t) * k_reorder * n_reorder; @@ -113,6 +129,23 @@ AOCL_GEMM_REORDER(int8_t,u8s8s16os16) return; // A reorder not supported. } + if( n == 1 ) + { + if (ldb == 1) + { + memcpy( reorder_buf_addr, input_buf_addr, + ( k * sizeof( int8_t ) ) ); + } + else + { + for( dim_t k0 = 0; k0 < k; k0++ ) + { + reorder_buf_addr[k0] = input_buf_addr[k0 * ldb]; + } + } + return; + } + // Initialize a local runtime with global settings if necessary. Note // that in the case that a runtime is passed in, we make a local copy. rntm_t rntm_g; diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s16os8.c b/addon/aocl_gemm/aocl_gemm_u8s8s16os8.c index e8d7b9d146..3c10c75303 100644 --- a/addon/aocl_gemm/aocl_gemm_u8s8s16os8.c +++ b/addon/aocl_gemm/aocl_gemm_u8s8s16os8.c @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -78,10 +78,9 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,int16_t,u8s8s16os8) /* Perform BLAS parameter checking. */ // Transpose not supported. - if ( ( blis_transa != BLIS_NO_TRANSPOSE ) || - ( blis_transb != BLIS_NO_TRANSPOSE ) ) + if ( ( blis_transb != BLIS_NO_TRANSPOSE ) ) { - bli_print_msg(" Transpose of matrices is not supported.", __FILE__, __LINE__ ); + bli_print_msg(" Transpose of B matrices is not supported.", __FILE__, __LINE__ ); return; // Error. } @@ -91,10 +90,10 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,int16_t,u8s8s16os8) return; // Only row major supported. } - const inc_t rs_a = lda; - const inc_t cs_a = 1; - const inc_t rs_b = ldb; - const inc_t cs_b = 1; + inc_t rs_a = lda; + inc_t cs_a = 1; + inc_t rs_b = ldb; + inc_t cs_b = 1; const inc_t rs_c = ldc; const inc_t cs_c = 1; @@ -104,6 +103,16 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,int16_t,u8s8s16os8) bli_param_map_char_to_lpmtag(mem_format_a, &mtag_a); bli_param_map_char_to_lpmtag(mem_format_b, &mtag_b); + // Pack is enabled for row major storage when trans A is true. + // Pack tranforms column major matrix to row-major storage as kernel + // expects A matrix to be in row-major format. + if ( bli_is_trans( blis_transa ) ) + { + rs_a = 1; + cs_a = lda; + mtag_a = PACK; + } + // B matrix needs to be packed in a certain format in order to be loaded // and used in VNNI instrution. As such the mtag_b always needs to be either // packed or reordered. B matrix as it is (unpacked) cannot be used, and @@ -113,8 +122,8 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,int16_t,u8s8s16os8) mtag_b = PACK; } - // Only unpacked A supported now. - if (mtag_a != UNPACKED) + // Only unpacked A supported now for row-major A matrix. + if ( !( bli_is_trans( blis_transa ) ) && ( mtag_a != UNPACKED ) ) { bli_print_msg(" A matrix needs to be unpacked.", __FILE__, __LINE__ ); return; // Error. @@ -125,7 +134,8 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,int16_t,u8s8s16os8) err_t err = lpgemm_translate_to_post_ops_list ( post_op_unparsed, post_op_list, - ( void* )c, ( void* )( &order ) + ( void* )c, ( void* )( &order ), + m, n ); if( err != BLIS_SUCCESS ) return; diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s16ou8.c b/addon/aocl_gemm/aocl_gemm_u8s8s16ou8.c index fef861be1e..f29028d57a 100644 --- a/addon/aocl_gemm/aocl_gemm_u8s8s16ou8.c +++ b/addon/aocl_gemm/aocl_gemm_u8s8s16ou8.c @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -78,10 +78,9 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,uint8_t,int16_t,u8s8s16ou8) /* Perform BLAS parameter checking. */ // Transpose not supported. - if ( ( blis_transa != BLIS_NO_TRANSPOSE ) || - ( blis_transb != BLIS_NO_TRANSPOSE ) ) + if ( ( blis_transb != BLIS_NO_TRANSPOSE ) ) { - bli_print_msg(" Transpose of matrices is not supported.", __FILE__, __LINE__ ); + bli_print_msg(" Transpose of B matrices is not supported.", __FILE__, __LINE__ ); return; // Error. } @@ -91,10 +90,10 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,uint8_t,int16_t,u8s8s16ou8) return; // Only row major supported. } - const inc_t rs_a = lda; - const inc_t cs_a = 1; - const inc_t rs_b = ldb; - const inc_t cs_b = 1; + inc_t rs_a = lda; + inc_t cs_a = 1; + inc_t rs_b = ldb; + inc_t cs_b = 1; const inc_t rs_c = ldc; const inc_t cs_c = 1; @@ -104,6 +103,16 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,uint8_t,int16_t,u8s8s16ou8) bli_param_map_char_to_lpmtag(mem_format_a, &mtag_a); bli_param_map_char_to_lpmtag(mem_format_b, &mtag_b); + // Pack is enabled for row major storage when trans A is true. + // Pack tranforms column major matrix to row-major storage as kernel + // expects A matrix to be in row-major format. + if ( bli_is_trans( blis_transa ) ) + { + rs_a = 1; + cs_a = lda; + mtag_a = PACK; + } + // B matrix needs to be packed in a certain format in order to be loaded // and used in VNNI instrution. As such the mtag_b always needs to be either // packed or reordered. B matrix as it is (unpacked) cannot be used, and @@ -113,8 +122,8 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,uint8_t,int16_t,u8s8s16ou8) mtag_b = PACK; } - // Only unpacked A supported now. - if (mtag_a != UNPACKED) + // Only unpacked A supported now for row-major A matrix. + if ( !( bli_is_trans( blis_transa ) ) && ( mtag_a != UNPACKED ) ) { bli_print_msg(" A matrix needs to be unpacked.", __FILE__, __LINE__ ); return; // Error. @@ -125,7 +134,8 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,uint8_t,int16_t,u8s8s16ou8) err_t err = lpgemm_translate_to_post_ops_list ( post_op_unparsed, post_op_list, - ( void* )c, ( void* )( &order ) + ( void* )c, ( void* )( &order ), + m, n ); if( err != BLIS_SUCCESS ) return; diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s32os32.c b/addon/aocl_gemm/aocl_gemm_u8s8s32os32.c index d89e6861c3..56c1b06dbe 100644 --- a/addon/aocl_gemm/aocl_gemm_u8s8s32os32.c +++ b/addon/aocl_gemm/aocl_gemm_u8s8s32os32.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -76,48 +76,90 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int32_t,int32_t,u8s8s32os32) bli_param_map_netlib_to_blis_trans( transa, &blis_transa ); bli_param_map_netlib_to_blis_trans( transb, &blis_transb ); - /* Perform BLAS parameter checking. */ - // Transpose not supported. - if ( ( blis_transa != BLIS_NO_TRANSPOSE ) || - ( blis_transb != BLIS_NO_TRANSPOSE ) ) + bool is_row_major = ((order == 'r') || (order == 'R')); + bool is_column_major = ((order == 'c') || (order == 'C')); + + // Column major support disabled for int API's till micro-kernel + // post-ops are updated to account for column major. + if ( (is_column_major == TRUE) && (post_op_unparsed != NULL) ) { - bli_print_msg(" Transpose of matrices is not supported.", __FILE__, __LINE__ ); - return; // Error. + bli_print_msg("Column major inputs not supported with Post-ops.", + __FILE__, __LINE__); + return; + } + + inc_t rs_a = lda; + inc_t cs_a = 1; + + if (bli_is_trans(blis_transa)) + { + rs_a = 1; + cs_a = lda; } - if ( ( order != 'r' ) && ( order != 'R' ) ) + inc_t rs_b = ldb; + inc_t cs_b = 1; + + if (bli_is_trans(blis_transb)) { - bli_print_msg(" Operation only supports row-major matrices.", __FILE__, __LINE__ ); - return; // Only row major supported. + rs_b = 1; + cs_b = ldb; } - const inc_t rs_a = lda; - const inc_t cs_a = 1; - const inc_t rs_b = ldb; - const inc_t cs_b = 1; const inc_t rs_c = ldc; const inc_t cs_c = 1; AOCL_MEMORY_TAG mtag_a; AOCL_MEMORY_TAG mtag_b; - bli_param_map_char_to_lpmtag( mem_format_a, &mtag_a ); - bli_param_map_char_to_lpmtag( mem_format_b, &mtag_b ); + bli_param_map_char_to_lpmtag(mem_format_a, &mtag_a); + bli_param_map_char_to_lpmtag(mem_format_b, &mtag_b); + + // Reorder is not supported for A matrix + if ((is_row_major == TRUE) && (mtag_a == REORDERED)) + { + bli_print_msg(" Reordering of A matrix is not supported " + "in row major case.", + __FILE__, __LINE__); + return; + } + // Inputs swapped in column major, A becomes B from kernel point of view. + // Reorder is not supported for column major matrices. + else if ((is_column_major == TRUE) && + ((mtag_b == REORDERED) || (mtag_a == REORDERED))) + { + bli_print_msg(" Reordering of column major matrices " + "is not supported.", + __FILE__, __LINE__); + return; + } + // From 5-loop function point of view // B matrix needs to be packed in a certain format in order to be loaded - // and used in VNNI instrution. As such the mtag_b always needs to be either + // and used in bf16 instrution. As such the mtag_b always needs to be either // packed or reordered. B matrix as it is (unpacked) cannot be used, and // the mtag_b is set to packed to enable runtime packing. - if ( mtag_b == UNPACKED ) + if ((is_row_major == TRUE) && (mtag_b == UNPACKED)) { mtag_b = PACK; } + // Inputs swapped in column major, A becomes B from kernel point of view. + else if ((is_column_major == TRUE) && (mtag_a == UNPACKED)) + { + mtag_a = PACK; + } - // Only unpacked A supported now. - if ( mtag_a != UNPACKED ) + // From 5-loop function point of view, + // A matrix when in column major storage needs to be packed to row-major + // storage as kernel expects A matrix to be in row-major format. + if ((is_row_major == TRUE) && (bli_is_trans(blis_transa))) { - bli_print_msg(" A matrix needs to be unpacked.", __FILE__, __LINE__ ); - return; // Error. + mtag_a = PACK; + } + // Inputs swapped in column major, A becomes B from kernel point of view. + else if ((is_column_major == TRUE) && (bli_is_trans(blis_transb))) + { + mtag_b = PACK; } // Convert post op struct to post op linked list format. @@ -125,7 +167,8 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int32_t,int32_t,u8s8s32os32) err_t err = lpgemm_translate_to_post_ops_list ( post_op_unparsed, post_op_list, - ( void* )c, ( void* )( &order ) + ( void* )c, ( void* )( &order ), + m, n ); if( err != BLIS_SUCCESS ) return; @@ -139,26 +182,52 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int32_t,int32_t,u8s8s32os32) lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( U8S8S32OS32 ); #ifdef BLIS_ENABLE_OPENMP - lpgemm_u8s8s32o32_openmp_thread_decorator - ( - m, n, k, - a, rs_a, cs_a, mtag_a, - b, rs_b, cs_b, mtag_b, - c, rs_c, cs_c, - alpha, beta, - &rntm_g, lcntx_g, - post_op_list, S32 - ); + // Swapping inputs to induce row major computation for column major inputs. + if (is_column_major == TRUE) + { + lpgemm_u8s8s32o32_openmp_thread_decorator( + n, m, k, + (uint8_t *)b, rs_b, cs_b, mtag_b, + (int8_t *)a, rs_a, cs_a, mtag_a, + c, rs_c, cs_c, + alpha, beta, + &rntm_g, lcntx_g, + post_op_list, S32); + } + else + { + lpgemm_u8s8s32o32_openmp_thread_decorator( + m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + c, rs_c, cs_c, + alpha, beta, + &rntm_g, lcntx_g, + post_op_list, S32); + } #else - lpgemm_u8s8s32o32_thread_decorator - ( - m, n, k, - a, rs_a, cs_a, mtag_a, - b, rs_b, cs_b, mtag_b, - c, rs_c, cs_c, - alpha, beta, - &rntm_g, lcntx_g, - post_op_list, S32 - ); + // Swapping inputs to induce row major computation for column major inputs. + if (is_column_major == TRUE) + { + lpgemm_u8s8s32o32_thread_decorator( + n, m, k, + (uint8_t *)b, rs_b, cs_b, mtag_b, + (int8_t *)a, rs_a, cs_a, mtag_a, + c, rs_c, cs_c, + alpha, beta, + &rntm_g, lcntx_g, + post_op_list, S32); + } + else + { + lpgemm_u8s8s32o32_thread_decorator( + m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + c, rs_c, cs_c, + alpha, beta, + &rntm_g, lcntx_g, + post_op_list, S32); + } #endif } diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s32os32_utils.c b/addon/aocl_gemm/aocl_gemm_u8s8s32os32_utils.c index b62c294cc6..6992c4d376 100644 --- a/addon/aocl_gemm/aocl_gemm_u8s8s32os32_utils.c +++ b/addon/aocl_gemm/aocl_gemm_u8s8s32os32_utils.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -73,10 +73,31 @@ AOCL_GEMM_GET_REORDER_BUF_SIZE(u8s8s32os32) // loaded; and since k_dim needs to be at least 4, having n_dim at least 16 // should give 4x16=64 elements, enough for 1 zmm register.The padding is // not rounded to NR (=64), since that would result in memory wastage. - dim_t n_reorder = make_multiple_of_n( n, 16 ); +#ifdef BLIS_KERNELS_ZEN4 + dim_t n_reorder; + if( n == 1 ) + { + n_reorder = 1; + } + else + { + n_reorder = make_multiple_of_n( n, 16 ); + } // Extra space since packing does length in multiples of 4. + dim_t k_reorder; + if( n == 1 ) + { + k_reorder = k; + } + else + { + k_reorder = make_multiple_of_n( k, 4 ); + } +#else + dim_t n_reorder = make_multiple_of_n( n, 16 ); dim_t k_reorder = make_multiple_of_n( k, 4 ); +#endif siz_t size_req = sizeof( int8_t ) * k_reorder * n_reorder; @@ -85,12 +106,33 @@ AOCL_GEMM_GET_REORDER_BUF_SIZE(u8s8s32os32) AOCL_GEMM_REORDER(int8_t,u8s8s32os32) { - if ( ( input_buf_addr == NULL ) || ( reorder_buf_addr == NULL ) || - ( k <= 0 ) || ( n <= 0 ) || ( ldb < n ) ) + trans_t blis_trans; + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + bli_param_map_netlib_to_blis_trans(trans, &blis_trans); + + if ( ( input_buf_addr == NULL ) || ( reorder_buf_addr == NULL ) || + ( k <= 0 ) || ( n <= 0 ) || ( bli_is_notrans( blis_trans ) && ( ldb < n ) ) || + ( bli_is_trans( blis_trans ) && ( ldb < k ) ) ) { return; // Error. } + inc_t rs_b, cs_b; + if ((order == 'r') || (order == 'R')) + { + rs_b = bli_is_notrans(blis_trans) ? ldb : 1; + cs_b = bli_is_notrans(blis_trans) ? 1 : ldb; + } + else if ((order == 'c') || (order == 'C')) + { + rs_b = bli_is_notrans(blis_trans) ? 1 : ldb; + cs_b = bli_is_notrans(blis_trans) ? ldb : 1; + } + else + { + return; // Error + } + // Check if avx512_vnni ISA is supported, lpgemm matmul only works with it. if ( bli_cpuid_is_avx512vnni_supported() == FALSE ) { @@ -113,6 +155,24 @@ AOCL_GEMM_REORDER(int8_t,u8s8s32os32) return; // A reorder not supported. } +#ifdef BLIS_KERNELS_ZEN4 + if( n == 1 ) + { + if (rs_b == 1) + { + memcpy( reorder_buf_addr, input_buf_addr, ( k * sizeof( int8_t ) ) ); + } + else + { + for( dim_t k0 = 0; k0 < k; k0++ ) + { + reorder_buf_addr[k0] = input_buf_addr[k0 * rs_b]; + } + } + return; + } +#endif + // Initialize a local runtime with global settings if necessary. Note // that in the case that a runtime is passed in, we make a local copy. rntm_t rntm_g; @@ -128,7 +188,8 @@ AOCL_GEMM_REORDER(int8_t,u8s8s32os32) // Create dummy original b obj; lpgemm_obj_t b; b.storage.aligned_buffer = ( void* )input_buf_addr; - b.rs = ldb; + b.rs = rs_b; + b.cs = cs_b; b.width = n; b.length = k; diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s32os8.c b/addon/aocl_gemm/aocl_gemm_u8s8s32os8.c index 6dab94b1fc..13184b5939 100644 --- a/addon/aocl_gemm/aocl_gemm_u8s8s32os8.c +++ b/addon/aocl_gemm/aocl_gemm_u8s8s32os8.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -48,7 +48,7 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8) trans_t blis_transb; // Check if avx512_vnni ISA is supported, lpgemm matmul only works with it. - if ( bli_cpuid_is_avx512vnni_supported() == FALSE ) + if (bli_cpuid_is_avx512vnni_supported() == FALSE) { bli_print_msg(" AVX512_VNNI ISA not supported by processor, " "cannot perform u8s8s32 gemm.", __FILE__, __LINE__ ); @@ -73,51 +73,93 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8) ); /* Map BLAS chars to their corresponding BLIS enumerated type value. */ - bli_param_map_netlib_to_blis_trans( transa, &blis_transa ); - bli_param_map_netlib_to_blis_trans( transb, &blis_transb ); + bli_param_map_netlib_to_blis_trans(transa, &blis_transa); + bli_param_map_netlib_to_blis_trans(transb, &blis_transb); - /* Perform BLAS parameter checking. */ - // Transpose not supported. - if ( ( blis_transa != BLIS_NO_TRANSPOSE ) || - ( blis_transb != BLIS_NO_TRANSPOSE ) ) + bool is_row_major = ((order == 'r') || (order == 'R')); + bool is_column_major = ((order == 'c') || (order == 'C')); + + // Column major support disabled for int API's till micro-kernel + // post-ops are updated to account for column major. + if ( (is_column_major == TRUE) && (post_op_unparsed != NULL) ) { - bli_print_msg(" Transpose of matrices is not supported.", __FILE__, __LINE__ ); - return; // Error. + bli_print_msg("Column major inputs not supported with Post-ops.", + __FILE__, __LINE__); + return; + } + + inc_t rs_a = lda; + inc_t cs_a = 1; + + if (bli_is_trans(blis_transa)) + { + rs_a = 1; + cs_a = lda; } - if ( ( order != 'r' ) && ( order != 'R' ) ) + inc_t rs_b = ldb; + inc_t cs_b = 1; + + if (bli_is_trans(blis_transb)) { - bli_print_msg(" Operation only supports row-major matrices.", __FILE__, __LINE__ ); - return; // Only row major supported. + rs_b = 1; + cs_b = ldb; } - const inc_t rs_a = lda; - const inc_t cs_a = 1; - const inc_t rs_b = ldb; - const inc_t cs_b = 1; const inc_t rs_c = ldc; const inc_t cs_c = 1; AOCL_MEMORY_TAG mtag_a; AOCL_MEMORY_TAG mtag_b; - bli_param_map_char_to_lpmtag( mem_format_a, &mtag_a ); - bli_param_map_char_to_lpmtag( mem_format_b, &mtag_b ); + bli_param_map_char_to_lpmtag(mem_format_a, &mtag_a); + bli_param_map_char_to_lpmtag(mem_format_b, &mtag_b); + + // Reorder is not supported for A matrix + if ((is_row_major == TRUE) && (mtag_a == REORDERED)) + { + bli_print_msg(" Reordering of A matrix is not supported " + "in row major case.", + __FILE__, __LINE__); + return; + } + // Inputs swapped in column major, A becomes B from kernel point of view. + // Reorder is not supported for column major matrices. + else if ((is_column_major == TRUE) && + ((mtag_b == REORDERED) || (mtag_a == REORDERED))) + { + bli_print_msg(" Reordering of column major matrices " + "is not supported.", + __FILE__, __LINE__); + return; + } + // From 5-loop function point of view // B matrix needs to be packed in a certain format in order to be loaded - // and used in VNNI instrution. As such the mtag_b always needs to be either + // and used in bf16 instrution. As such the mtag_b always needs to be either // packed or reordered. B matrix as it is (unpacked) cannot be used, and // the mtag_b is set to packed to enable runtime packing. - if ( mtag_b == UNPACKED ) + if ((is_row_major == TRUE) && (mtag_b == UNPACKED)) { mtag_b = PACK; } + // Inputs swapped in column major, A becomes B from kernel point of view. + else if ((is_column_major == TRUE) && (mtag_a == UNPACKED)) + { + mtag_a = PACK; + } - // Only unpacked A supported now. - if ( mtag_a != UNPACKED ) + // From 5-loop function point of view, + // A matrix when in column major storage needs to be packed to row-major + // storage as kernel expects A matrix to be in row-major format. + if ((is_row_major == TRUE) && (bli_is_trans(blis_transa))) { - bli_print_msg(" A matrix needs to be unpacked.", __FILE__, __LINE__ ); - return; // Error. + mtag_a = PACK; + } + // Inputs swapped in column major, A becomes B from kernel point of view. + else if ((is_column_major == TRUE) && (bli_is_trans(blis_transb))) + { + mtag_b = PACK; } // Convert post op struct to post op linked list format. @@ -125,7 +167,8 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8) err_t err = lpgemm_translate_to_post_ops_list ( post_op_unparsed, post_op_list, - ( void* )c, ( void* )( &order ) + ( void* )c, ( void* )( &order ), + m, n ); if( err != BLIS_SUCCESS ) return; @@ -139,26 +182,53 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8) lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( U8S8S32OS32 ); #ifdef BLIS_ENABLE_OPENMP - lpgemm_u8s8s32o32_openmp_thread_decorator - ( - m, n, k, - a, rs_a, cs_a, mtag_a, - b, rs_b, cs_b, mtag_b, - ( int32_t* )c, rs_c, cs_c, - alpha, beta, - &rntm_g, lcntx_g, - post_op_list, S8 - ); + // Swapping inputs to induce row major computation for column major inputs. + if (is_column_major == TRUE) + { + lpgemm_u8s8s32o32_openmp_thread_decorator( + n, m, k, + (uint8_t *)b, rs_b, cs_b, mtag_b, + (int8_t *)a, rs_a, cs_a, mtag_a, + (int32_t *)c, rs_c, cs_c, + alpha, beta, + &rntm_g, lcntx_g, + post_op_list, S8); + } + else + { + lpgemm_u8s8s32o32_openmp_thread_decorator( + m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + (int32_t *)c, rs_c, cs_c, + alpha, beta, + &rntm_g, lcntx_g, + post_op_list, S8); + } #else - lpgemm_u8s8s32o32_thread_decorator - ( - m, n, k, - a, rs_a, cs_a, mtag_a, - b, rs_b, cs_b, mtag_b, - ( int32_t* )c, rs_c, cs_c, - alpha, beta, - &rntm_g, lcntx_g, - post_op_list, S8 - ); + // Swapping inputs to induce row major computation for column major inputs. + if (is_column_major == TRUE) + { + lpgemm_u8s8s32o32_thread_decorator( + n, m, k, + (uint8_t *)b, rs_b, cs_b, mtag_b, + (int8_t *)a, rs_a, cs_a, mtag_a, + (int32_t *)c, rs_c, cs_c, + alpha, beta, + &rntm_g, lcntx_g, + post_op_list, S8); + } + else + { + lpgemm_u8s8s32o32_thread_decorator( + m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + (int32_t *)c, rs_c, cs_c, + alpha, beta, + &rntm_g, lcntx_g, + post_op_list, S8); + } #endif + } diff --git a/addon/aocl_gemm/aocl_util_interface_apis.h b/addon/aocl_gemm/aocl_util_interface_apis.h index d2983b8a64..ffe4843c28 100644 --- a/addon/aocl_gemm/aocl_util_interface_apis.h +++ b/addon/aocl_gemm/aocl_util_interface_apis.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -35,8 +35,11 @@ #ifndef AOCL_UTIL_INTERFACE_H #define AOCL_UTIL_INTERFACE_H +#include "aocl_gemm_post_ops.h" +#include "aocl_bf16_type.h" + #define AOCL_UTIL_L1_OP(V_type,OP_type) \ -BLIS_EXPORT_ADDON void aocl_ ## OP_type \ +BLIS_EXPORT_ADDON void aocl_gemm_ ## OP_type \ ( \ const dim_t n, \ V_type* x, \ diff --git a/addon/aocl_gemm/aocl_util_l1_ops.c b/addon/aocl_gemm/aocl_util_l1_ops.c index 11a4b83078..4cc702c861 100644 --- a/addon/aocl_gemm/aocl_util_l1_ops.c +++ b/addon/aocl_gemm/aocl_util_l1_ops.c @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/addon/aocl_gemm/config/lpgemm_blksz_map.h b/addon/aocl_gemm/config/lpgemm_blksz_map.h index 9991a3eb70..b0870f2c5a 100644 --- a/addon/aocl_gemm/config/lpgemm_blksz_map.h +++ b/addon/aocl_gemm/config/lpgemm_blksz_map.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -41,15 +41,24 @@ #define LPGEMM_BLKSZ_MAP_ZEN4 \ XMACRO(U8S8S16OS16, 252, 2048, 2048, 6, 32, 0, 0, 2*32, 32) \ XMACRO(U8S8S32OS32, 144, 1024, 2048, 6, 64, 4, 24, 4*64, 64) \ - XMACRO(BF16BF16F32OF32, 144, 1024, 2048, 6, 64, 0, 0, 2*64, 64/2) \ + XMACRO(BF16BF16F32OF32, 144, 1024, 4096, 6, 64, 0, 0, 2*64, 64/2) \ + XMACRO(BF16S4F32OF32, 144, 1024, 4096, 6, 64, 0, 0, 2*64, 64/2) \ XMACRO(S8S8S32OS32, 144, 1024, 2048, 6, 64, 4, 24, 4*64, 64) \ - XMACRO(S8S8S16OS16, 252, 2048, 2048, 6, 32, 0, 0, 2*32, 32) \ + XMACRO(S8S8S16OS16, 252, 2048, 2048, 6, 32, 0, 0, 2*32, 32) \ + XMACRO(U8S4S32OS32, 144, 1024, 2048, 6, 64, 4, 24, 4*64, 64) \ #define LPGEMM_BLKSZ_MAP_ZEN \ - XMACRO(U8S8S16OS16, 252, 2048, 2048, 6, 32, 0, 0, 2*32, 32) \ + XMACRO(U8S8S16OS16, 240, 2048, 2048, 6, 32, 0, 0, 2*32, 32) \ XMACRO(U8S8S32OS32, 144, 1024, 2048, 6, 64, 4, 24, 4*64, 64) \ XMACRO(BF16BF16F32OF32, 144, 1024, 2048, 6, 64, 0, 0, 2*64, 64/2) \ XMACRO(S8S8S32OS32, 144, 1024, 2048, 6, 64, 4, 24, 4*64, 64) \ - XMACRO(S8S8S16OS16, 252, 2048, 2048, 6, 32, 0, 0, 2*32, 32) \ + XMACRO(S8S8S16OS16, 240, 2048, 2048, 6, 32, 0, 0, 2*32, 32) \ + XMACRO(U8S4S32OS32, 144, 1024, 2048, 6, 64, 4, 24, 4*64, 64) \ + XMACRO(BF16S4F32OF32, 144, 1024, 2048, 6, 64, 0, 0, 2*64, 64/2) \ + +#define LPGEMM_ELTWISE_OPS_BLKSZ_MAP_ZEN4 \ + XMACRO(BF16OF32, 144, 1024, 2048, 6, 64) \ + +#define LPGEMM_ELTWISE_OPS_BLKSZ_MAP_ZEN #endif //LPGEMM_BLKSZ_MAP_H diff --git a/addon/aocl_gemm/config/lpgemm_config.c b/addon/aocl_gemm/config/lpgemm_config.c index ca1020e324..3fee7d2cb2 100644 --- a/addon/aocl_gemm/config/lpgemm_config.c +++ b/addon/aocl_gemm/config/lpgemm_config.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -39,16 +39,31 @@ #include "lpgemm_kernels.h" #include "lpgemm_pack_bf16.h" #include "lpgemm_packb_s16.h" +#include "lpgemm_packa_s16.h" #include "lpgemm_packa.h" #include "lpgemm_packb.h" #include "lpgemm_packa_s8.h" #include "lpgemm_packb_s8.h" #include "lpgemm_packb_s8s16.h" +#include "lpgemm_pack_f32.h" static lpgemm_cntx_t global_cntx_t_list[AOCL_OPERATION_TYPE_LEN] \ - __attribute__((aligned(64))); //Only one op type supported now. + __attribute__((aligned(64))); //Only one op type supported now. static lpgemm_util_cntx_t global_util_cntx_t_list[AOCL_UTIL_OPERATION_TYPE_LEN] \ - __attribute__((aligned(64))); //Only post-ops like utils. + __attribute__((aligned(64))); //Only post-ops like utils. +static lpgemm_eltwise_ops_cntx_t + global_eltwise_ops_cntx_t_list[AOCL_ELTWISE_OPS_OPERATION_TYPE_LEN] \ + __attribute__((aligned(64))); //Post-ops only utils without gemm. + +// This array is to store function pointers to jit generated kernels. +static void* global_jit_kernels[ LPGEMM_BF16_MR ] + [ ( LPGEMM_BF16_NR / NUM_F32_ELEMS_PER_ZMM ) + 1 ] + __attribute__((aligned(64))); + +// Buffer size is chosen in order to accommodate the +// worst-case scenario for MR=6 and NR=64. +// The buffersize is chosen using bruteforce method. +#define JIT_KERNEL_SIZE ( 10 * BLIS_PAGE_SIZE ) static bli_pthread_once_t once_check_lpgemm_func_map_init = BLIS_PTHREAD_ONCE_INIT; @@ -58,6 +73,7 @@ static void _lpgemm_util_cntx_init_func_map() global_util_cntx_t_list[F32_GELU_TANH].kern_fun_ptr = NULL; global_util_cntx_t_list[F32_GELU_ERF].kern_fun_ptr = NULL; + global_util_cntx_t_list[F32_SOFTMAX].kern_fun_ptr = NULL; // Kernel dispatch object factory. if ( bli_cpuid_is_avx512bf16_supported() == TRUE ) @@ -82,12 +98,31 @@ static void _lpgemm_util_cntx_init_func_map() #undef UMACRO } +static void _lpgemm_eltwise_ops_cntx_init_func_map() +{ +#define POMACRO(ID,FUNC_PTR) \ + global_eltwise_ops_cntx_t_list[ID].eltwise_ops_kern_fun_ptr = FUNC_PTR; + + global_eltwise_ops_cntx_t_list[BF16OF32].eltwise_ops_kern_fun_ptr = NULL; + + // Kernel dispatch object factory. + if ( bli_cpuid_is_avx512bf16_supported() == TRUE ) + { +#ifdef BLIS_KERNELS_ZEN4 + LPGEMM_ELTWISE_OPS_KERN_FUNC_MAP_AVX512_VNNI_BF16 +#endif + } + +#undef POMACRO +} + static void _lpgemm_cntx_init_func_map() { #define KMACRO(ID,FUNC_PTR) global_cntx_t_list[ID].kern_fun_ptr = FUNC_PTR; #define PAMACRO(ID,FUNC_PTR) global_cntx_t_list[ID].packa_fun_ptr = FUNC_PTR; #define PBMACRO(ID,FUNC_PTR) global_cntx_t_list[ID].packb_fun_ptr = FUNC_PTR; - +#define PBSMACRO(ID, FUNC_PTR) global_cntx_t_list[ID].packsclb_fun_ptr = FUNC_PTR; +#define JITMACRO(ID, FUNC_PTR) global_cntx_t_list[ID].jit_kernel = FUNC_PTR; //TODO: Default initialize with reference kernels so that kernel pointer // will be valid even in case none of the zen optimized kernels are // available. This scenario could happen if the addon was built using @@ -97,6 +132,7 @@ static void _lpgemm_cntx_init_func_map() global_cntx_t_list[U8S8S32OS32].kern_fun_ptr = NULL; global_cntx_t_list[F32F32F32OF32].kern_fun_ptr = NULL; global_cntx_t_list[BF16BF16F32OF32].kern_fun_ptr = NULL; + global_cntx_t_list[BF16S4F32OF32].kern_fun_ptr = NULL; // Kernel dispatch object factory. if ( bli_cpuid_is_avx512bf16_supported() == TRUE ) @@ -105,6 +141,37 @@ static void _lpgemm_cntx_init_func_map() LPGEMM_KERN_FUNC_MAP_AVX512_VNNI_BF16 LPGEMM_PACKA_FUNC_MAP_AVX512_VNNI_BF16 LPGEMM_PACKB_FUNC_MAP_AVX512_VNNI_BF16 + LPGEMM_PACKSCLB_FUNC_MAP_AVX512_VNNI_BF16 + +#ifdef LPGEMM_BF16_JIT + lpgemm_jit_inputs_t inputs; + inputs.alpha_scale = TRUE; + inputs.beta_scale = BLIS_BETA_GEN; + + err_t err; + + dim_t num_N_vars = ( LPGEMM_BF16_NR / NUM_F32_ELEMS_PER_ZMM ) + 1; + + for ( dim_t m = 0; m < LPGEMM_BF16_MR; m++ ) + { + for( dim_t n = 0; n < num_N_vars; n++ ) + { + inputs.MR = ( m == 0 ) ? LPGEMM_BF16_MR : m; + inputs.NR = n * 16; + inputs.m_loop = ( m == 0 ) ? TRUE: FALSE; + inputs.generate_mask = ( n == 0 ) ? TRUE: FALSE; + global_jit_kernels[m][n] = bli_malloc_user( JIT_KERNEL_SIZE, + &err ); + if( global_jit_kernels[m][n] != NULL ) + { + get_jit_kernel( &inputs, + global_jit_kernels[m][n], + JIT_KERNEL_SIZE + ); + } + } + } +#endif #endif } else if ( bli_cpuid_is_avx512vnni_supported() == TRUE ) @@ -138,6 +205,16 @@ static void _lpgemm_cntx_init_func_map() #undef KMACRO } + void lpgemm_set_jit_kernel( void* kernel_fp, dim_t m_index, dim_t n_index ) +{ + global_jit_kernels[m_index][n_index] = kernel_fp; +} + + void* lpgemm_get_jit_kernel( dim_t m_index, dim_t n_index ) +{ + return global_jit_kernels[m_index][n_index]; +} + BLIS_INLINE void lpgemm_set_block_sizes_global_cntx ( AOCL_OPERATION_TYPE op_type, @@ -197,10 +274,51 @@ static void _lpgemm_cntx_init_blksz_map() #undef XMACRO } +BLIS_INLINE void lpgemm_set_block_sizes_global_eltwise_ops_cntx + ( + AOCL_ELTWISE_OPS_OPERATION_TYPE op_type, + dim_t MC, + dim_t NC, + dim_t KC, + dim_t MR, + dim_t NR + ) +{ + global_eltwise_ops_cntx_t_list[op_type].blksz.MC = MC; + global_eltwise_ops_cntx_t_list[op_type].blksz.NC = NC; + global_eltwise_ops_cntx_t_list[op_type].blksz.KC = KC; + global_eltwise_ops_cntx_t_list[op_type].blksz.MR = MR; + global_eltwise_ops_cntx_t_list[op_type].blksz.NR = NR; +} + +static void _lpgemm_eltwise_ops_cntx_init_blksz_map() +{ +#define XMACRO(ID,MC,NC,KC,MR,NR) \ + lpgemm_set_block_sizes_global_eltwise_ops_cntx(ID, MC, NC, KC, MR, NR); + + // Ideally the blocksize needs to be set based on arch id. However + // since this code is also expected to work on other vendor machines, + // the blocksize for a particular version of zen id is generalized + // for all machines that support the ISA supported by that particular + // zen id. + if ( bli_cpuid_is_avx512bf16_supported() == TRUE ) + { + LPGEMM_ELTWISE_OPS_BLKSZ_MAP_ZEN4 + } + else + { + LPGEMM_ELTWISE_OPS_BLKSZ_MAP_ZEN + } + +#undef XMACRO +} + static void lpgemm_cntx_init_map() { _lpgemm_cntx_init_func_map(); _lpgemm_cntx_init_blksz_map(); + _lpgemm_eltwise_ops_cntx_init_blksz_map(); + _lpgemm_eltwise_ops_cntx_init_func_map(); _lpgemm_util_cntx_init_func_map(); } @@ -224,6 +342,12 @@ lpgemm_util_cntx_t* lpgemm_util_get_global_cntx_obj( AOCL_UTIL_OPERATION_TYPE op return &global_util_cntx_t_list[op]; } +lpgemm_eltwise_ops_cntx_t* lpgemm_eltwise_ops_get_global_cntx_obj + ( AOCL_ELTWISE_OPS_OPERATION_TYPE op ) +{ + return &global_eltwise_ops_cntx_t_list[op]; +} + dim_t lpgemm_get_block_size_MC_global_cntx( AOCL_OPERATION_TYPE op_type ) { return global_cntx_t_list[op_type].blksz.MC; diff --git a/addon/aocl_gemm/config/lpgemm_config.h b/addon/aocl_gemm/config/lpgemm_config.h index 87020d0c3d..7645d6951f 100644 --- a/addon/aocl_gemm/config/lpgemm_config.h +++ b/addon/aocl_gemm/config/lpgemm_config.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -37,9 +37,10 @@ #include "lpgemm_types.h" -// equals to number of ops in enum AOCL_OPERATION_TYPE. -extern lpgemm_cntx_t lpgemm_global_cntx_t_list[AOCL_OPERATION_TYPE_LEN]; -extern lpgemm_cntx_t lpgemm_util_global_cntx_t_list[AOCL_UTIL_OPERATION_TYPE_LEN]; +#define LPGEMM_BF16_MR 6 +#define LPGEMM_BF16_NR 64 +// num_f32_elems_per_zmm = zmm_width / sizeof( float ) +#define NUM_F32_ELEMS_PER_ZMM ( 64 / sizeof(float) ) void aocl_lpgemm_init_global_cntx(); @@ -47,6 +48,9 @@ lpgemm_cntx_t* lpgemm_get_global_cntx_obj( AOCL_OPERATION_TYPE op ); lpgemm_util_cntx_t* lpgemm_util_get_global_cntx_obj( AOCL_UTIL_OPERATION_TYPE op ); +lpgemm_eltwise_ops_cntx_t* lpgemm_eltwise_ops_get_global_cntx_obj + ( AOCL_ELTWISE_OPS_OPERATION_TYPE op ); + dim_t lpgemm_get_block_size_MC_global_cntx( AOCL_OPERATION_TYPE op_type ); dim_t lpgemm_get_block_size_NC_global_cntx( AOCL_OPERATION_TYPE op_type ); @@ -61,6 +65,10 @@ void lpgemm_get_packa_strides( lpgemm_cntx_t* lcntx, dim_t* rs, dim_t* cs ); void lpgemm_get_packb_strides( lpgemm_cntx_t* lcntx, dim_t* rs, dim_t* cs ); +void lpgemm_set_jit_kernel( void* kernel_fp, dim_t m_index, dim_t n_index ); + +void* lpgemm_get_jit_kernel( dim_t m_index, dim_t n_index ); + void lpgemm_mod_block_size_s16 ( dim_t m, diff --git a/addon/aocl_gemm/config/lpgemm_func_map.h b/addon/aocl_gemm/config/lpgemm_func_map.h index 875a211985..f4d2f2b833 100644 --- a/addon/aocl_gemm/config/lpgemm_func_map.h +++ b/addon/aocl_gemm/config/lpgemm_func_map.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -50,43 +50,62 @@ KMACRO(U8S8S32OS32, lpgemm_rowvar_u8s8s32o32_6x64) \ KMACRO(F32F32F32OF32, lpgemm_rowvar_f32f32f32of32_avx512_6x64m) \ KMACRO(BF16BF16F32OF32, lpgemm_rowvar_bf16bf16f32of32_6x64) \ + KMACRO(BF16S4F32OF32, lpgemm_rowvar_bf16bf16f32of32_6x64) \ KMACRO(S8S8S32OS32, lpgemm_rowvar_s8s8s32os32_6x64) \ KMACRO(S8S8S16OS16, lpgemm_rowvar_s8s8s16o16_6x32) \ #define LPGEMM_PACKA_FUNC_MAP_AVX512_VNNI_BF16 \ - PAMACRO(U8S8S16OS16, NULL) \ - PAMACRO(U8S8S32OS32, packa_k64_u8s8s32o32) \ + PAMACRO(U8S8S16OS16, packa_u8s8s16os16) \ + PAMACRO(U8S8S32OS32, packa_u8s8s32os32) \ PAMACRO(BF16BF16F32OF32, packa_mr16_bf16bf16f32of32) \ - PAMACRO(S8S8S32OS32, packa_k64_s8s8s32os32) \ - PAMACRO(S8S8S16OS16, NULL) \ + PAMACRO(BF16S4F32OF32, packa_mr16_bf16bf16f32of32) \ + PAMACRO(S8S8S32OS32, packa_u8s8s32os32) \ + PAMACRO(S8S8S16OS16, packa_u8s8s16os16) -#define LPGEMM_PACKB_FUNC_MAP_AVX512_VNNI_BF16 \ - PBMACRO(U8S8S16OS16, packb_nr32_u8s8s16o16) \ - PBMACRO(U8S8S32OS32, packb_nr64_u8s8s32o32) \ +#define LPGEMM_PACKB_FUNC_MAP_AVX512_VNNI_BF16 \ + PBMACRO(U8S8S16OS16, packb_nr32_u8s8s16o16) \ + PBMACRO(U8S8S32OS32, packb_nr64_u8s8s32o32) \ PBMACRO(BF16BF16F32OF32, packb_nr64_bf16bf16f32of32) \ - PBMACRO(S8S8S32OS32, packb_nr64_s8s8s32os32) \ - PBMACRO(S8S8S16OS16, packb_nr32_s8s8s16o16) \ + PBMACRO(S8S8S32OS32, packb_nr64_s8s8s32os32) \ + PBMACRO(S8S8S16OS16, packb_nr32_s8s8s16o16) \ + PBMACRO(U8S4S32OS32, packb_nr64_u8s4s32o32) \ + PBMACRO(BF16S4F32OF32, packb_nr64_bf16s4f32of32) + +#define LPGEMM_PACKSCLB_FUNC_MAP_AVX512_VNNI_BF16 \ + PBSMACRO(U8S8S16OS16, NULL) \ + PBSMACRO(U8S8S32OS32, NULL) \ + PBSMACRO(BF16BF16F32OF32, NULL) \ + PBSMACRO(S8S8S32OS32, NULL) \ + PBSMACRO(S8S8S16OS16, NULL) \ + PBSMACRO(U8S4S32OS32, NULL) \ + PBSMACRO(BF16S4F32OF32, packsclb_nr64_bf16s4f32of32) \ + +#define LPGEMM_ELTWISE_OPS_KERN_FUNC_MAP_AVX512_VNNI_BF16 \ + POMACRO(BF16OF32, lpgemm_eltwise_ops_kernel_bf16of32_6x64) \ + POMACRO(F32OF32, lpgemm_eltwise_ops_kernel_f32of32_6x64) \ #define LPGEMM_UTIL_KERN_FUNC_MAP_AVX512_VNNI_BF16 \ UMACRO(F32_GELU_TANH, lpgemm_util_f32_gelu_tanh_avx512_kernel) \ UMACRO(F32_GELU_ERF, lpgemm_util_f32_gelu_erf_avx512_kernel) \ UMACRO(F32_SOFTMAX, lpgemm_util_f32_softmax_avx512_kernel) \ -// Icelake + #define LPGEMM_KERN_FUNC_MAP_AVX512_VNNI \ KMACRO(U8S8S16OS16, lpgemm_rowvar_u8s8s16o16_6x32) \ KMACRO(U8S8S32OS32, lpgemm_rowvar_u8s8s32o32_6x64) \ KMACRO(F32F32F32OF32, lpgemm_rowvar_f32f32f32of32_avx512_6x64m) \ KMACRO(BF16BF16F32OF32, lpgemm_rowvar_bf16bf16f32of32_6x64) \ + KMACRO(BF16S4F32OF32, lpgemm_rowvar_bf16bf16f32of32_6x64) \ KMACRO(S8S8S32OS32, lpgemm_rowvar_s8s8s32os32_6x64) \ KMACRO(S8S8S16OS16, lpgemm_rowvar_s8s8s16o16_6x32) \ #define LPGEMM_PACKA_FUNC_MAP_AVX512_VNNI \ - PAMACRO(U8S8S16OS16, NULL) \ - PAMACRO(U8S8S32OS32, packa_k64_u8s8s32o32) \ + PAMACRO(U8S8S16OS16, packa_u8s8s16os16) \ + PAMACRO(U8S8S32OS32, packa_u8s8s32os32) \ PAMACRO(BF16BF16F32OF32, packa_mr16_bf16bf16f32of32) \ - PAMACRO(S8S8S32OS32, packa_k64_s8s8s32os32) \ - PAMACRO(S8S8S16OS16, NULL) \ + PAMACRO(BF16S4F32OF32, packa_mr16_bf16bf16f32of32) \ + PAMACRO(S8S8S32OS32, packa_u8s8s32os32) \ + PAMACRO(S8S8S16OS16, packa_u8s8s16os16) \ #define LPGEMM_PACKB_FUNC_MAP_AVX512_VNNI \ PBMACRO(U8S8S16OS16, packb_nr32_u8s8s16o16) \ @@ -94,62 +113,75 @@ PBMACRO(BF16BF16F32OF32, packb_nr64_bf16bf16f32of32) \ PBMACRO(S8S8S32OS32, packb_nr64_s8s8s32os32) \ PBMACRO(S8S8S16OS16, packb_nr32_s8s8s16o16) \ + PBMACRO(U8S4S32OS32, packb_nr64_u8s4s32o32) \ + PBSMACRO(BF16S4F32OF32, packb_nr64_bf16s4f32of32) #define LPGEMM_UTIL_KERN_FUNC_MAP_AVX512_VNNI \ UMACRO(F32_GELU_TANH, lpgemm_util_f32_gelu_tanh_avx512_kernel) \ UMACRO(F32_GELU_ERF, lpgemm_util_f32_gelu_erf_avx512_kernel) \ UMACRO(F32_SOFTMAX, lpgemm_util_f32_softmax_avx512_kernel) \ -// Skylake + #define LPGEMM_KERN_FUNC_MAP_AVX512 \ KMACRO(U8S8S16OS16, lpgemm_rowvar_u8s8s16o16_6x32) \ KMACRO(U8S8S32OS32, lpgemm_rowvar_u8s8s32o32_6x64) \ KMACRO(F32F32F32OF32, lpgemm_rowvar_f32f32f32of32_avx512_6x64m) \ KMACRO(BF16BF16F32OF32, lpgemm_rowvar_bf16bf16f32of32_6x64) \ + KMACRO(BF16S4F32OF32, lpgemm_rowvar_bf16bf16f32of32_6x64) \ KMACRO(S8S8S32OS32, lpgemm_rowvar_s8s8s32os32_6x64) \ KMACRO(S8S8S16OS16, lpgemm_rowvar_s8s8s16o16_6x32) \ #define LPGEMM_PACKA_FUNC_MAP_AVX512 \ - PAMACRO(U8S8S16OS16, NULL) \ - PAMACRO(U8S8S32OS32, packa_k64_u8s8s32o32) \ + PAMACRO(U8S8S16OS16, packa_u8s8s16os16) \ + PAMACRO(U8S8S32OS32, packa_u8s8s32os32) \ PAMACRO(BF16BF16F32OF32, packa_mr16_bf16bf16f32of32) \ - PAMACRO(S8S8S32OS32, packa_k64_s8s8s32os32) \ - PAMACRO(S8S8S16OS16, NULL) \ + PAMACRO(BF16S4F32OF32, packa_mr16_bf16bf16f32of32) \ + PAMACRO(S8S8S32OS32, packa_u8s8s32os32) \ + PAMACRO(S8S8S16OS16, packa_u8s8s16os16) \ #define LPGEMM_PACKB_FUNC_MAP_AVX512 \ PBMACRO(U8S8S16OS16, packb_nr32_u8s8s16o16) \ PBMACRO(U8S8S32OS32, packb_nr64_u8s8s32o32) \ - PBMACRO(BF16BF16F32OF32, packb_nr64_bf16bf16f32of32) \ + PBMACRO(BF16BF16F32OF32, NULL) \ PBMACRO(S8S8S32OS32, packb_nr64_s8s8s32os32) \ PBMACRO(S8S8S16OS16, packb_nr32_s8s8s16o16) \ + PBMACRO(U8S4S32OS32, packb_nr64_u8s4s32o32) \ + PBMACRO(BF16S4F32OF32, NULL) \ + PBSMACRO(BF16S4F32OF32, NULL) \ + #define LPGEMM_UTIL_KERN_FUNC_MAP_AVX512 \ UMACRO(F32_GELU_TANH, lpgemm_util_f32_gelu_tanh_avx512_kernel) \ UMACRO(F32_GELU_ERF, lpgemm_util_f32_gelu_erf_avx512_kernel) \ UMACRO(F32_SOFTMAX, lpgemm_util_f32_softmax_avx512_kernel) \ -// Milan, Haswell +// Milan #define LPGEMM_KERN_FUNC_MAP_AVX2 \ KMACRO(U8S8S16OS16, lpgemm_rowvar_u8s8s16o16_6x32) \ KMACRO(U8S8S32OS32, NULL) \ KMACRO(F32F32F32OF32, lpgemm_rowvar_f32f32f32of32_6x16m) \ KMACRO(BF16BF16F32OF32, NULL) \ + KMACRO(BF16S4F32OF32, NULL) \ KMACRO(S8S8S32OS32, NULL) \ KMACRO(S8S8S16OS16, lpgemm_rowvar_s8s8s16o16_6x32) \ #define LPGEMM_PACKA_FUNC_MAP_AVX2 \ - PAMACRO(U8S8S16OS16, NULL) \ + PAMACRO(U8S8S16OS16, packa_u8s8s16os16) \ PAMACRO(U8S8S32OS32, NULL) \ PAMACRO(BF16BF16F32OF32, NULL) \ + KMACRO(BF16S4F32OF32, NULL) \ PAMACRO(S8S8S32OS32, NULL) \ - PAMACRO(S8S8S16OS16, NULL) \ + PAMACRO(S8S8S16OS16, packa_u8s8s16os16) \ #define LPGEMM_PACKB_FUNC_MAP_AVX2 \ PBMACRO(U8S8S16OS16, packb_nr32_u8s8s16o16) \ PBMACRO(U8S8S32OS32, NULL) \ PBMACRO(BF16BF16F32OF32, NULL) \ + KMACRO(BF16S4F32OF32, NULL) \ PBMACRO(S8S8S32OS32, NULL) \ PBMACRO(S8S8S16OS16, packb_nr32_s8s8s16o16) \ + PBMACRO(U8S4S32OS32, NULL) \ + PBSMACRO(BF16S4F32OF32, NULL) \ #define LPGEMM_UTIL_KERN_FUNC_MAP_AVX2 \ UMACRO(F32_GELU_TANH, lpgemm_util_f32_gelu_tanh_avx2_kernel) \ diff --git a/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_bf16.c b/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_bf16.c index 5a0201443b..7de91491ba 100644 --- a/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_bf16.c +++ b/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_bf16.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -64,9 +64,300 @@ typedef void (*lpgemm_rowvar_bf16) lpgemm_post_op_attr ); +#ifdef BLIS_KERNELS_ZEN4 +LPGEMV(bfloat16, bfloat16, float, bf16bf16f32of32) +{ + dim_t NC = lcntx->blksz.NC; + dim_t KC = lcntx->blksz.KC; + dim_t MC = lcntx->blksz.MC; + dim_t NR = lcntx->blksz.NR; + + // Strides are updated based on matrix packing/reordering. + bfloat16* a_use = ( bfloat16* )a; + inc_t rs_a_use = rs_a; + inc_t cs_a_use = cs_a; + + bfloat16* b_use = ( bfloat16* )b; + inc_t rs_b_use = rs_b; + inc_t cs_b_use = cs_b; + + + float *c_use = NULL; + bfloat16* pack_a_buffer_bf16; + + lpgemm_post_op_attr post_ops_attr; + post_ops_attr.c_stor_type = c_downscale; + if (c_downscale < F32) post_ops_attr.buf_downscale = c; + else post_ops_attr.buf_downscale = NULL; + + siz_t mem_a_size_req = 0; + siz_t mem_b_size_req = 0; + + mem_t mem_a = BLIS_MEM_INITIALIZER; + mem_t mem_b = BLIS_MEM_INITIALIZER; + + bfloat16* pack_b_buffer_bf16; + + // Generate thrinfo objects for jc and ic loops from lpgemm_thrinfo_t. + thrinfo_t thread_jc; + thrinfo_t thread_ic; + + lpgemm_gen_thrinfo( thread, &thread_jc, &thread_ic ); + + if( n == 1 ) + { + // Increased MR from 6 to 16 to make use of 32 ZMM registers + dim_t MR = 16; + + // pack B matrix if rs_b > 1 + if( ( mtag_b == PACK ) && ( rs_b != 1 ) ) + { + mem_b_size_req = sizeof( bfloat16 ) * k; + + lpgemm_alloc_mem_panel + ( + mem_b_size_req, BLIS_BUFFER_FOR_GEN_USE, + &mem_b, rntm + ); + + pack_b_buffer_bf16 = ( bfloat16* ) bli_mem_buffer( &mem_b ); + + for( dim_t k0 = 0; k0 < k; k0++ ) + { + pack_b_buffer_bf16[k0] = b[ k0*rs_b ]; + } + + b_use = pack_b_buffer_bf16; + rs_b_use = 1; + cs_b_use = 1; + } + + // Compute the IC loop thread range for the current thread. + dim_t ic_start, ic_end; + thread_ic.n_way = ( thread_ic.n_way == 1 ) ? + ( thread->n_threads ) : ( thread_ic.n_way ); + thread_ic.work_id = thread->tid; + bli_thread_range_sub(&thread_ic, m, MR, FALSE, &ic_start, &ic_end); + + for (dim_t ic = ic_start; ic < ic_end; ic += MC) + { + dim_t mc0 = bli_min((ic_end - ic), MC); + const bfloat16 *a_use = a + ic * rs_a; + c_use = c + ic * rs_c; + post_ops_attr.post_op_c_i = ic; + post_ops_attr.post_op_c_j = 0; + post_ops_attr.rs_c_downscale = rs_c; + + if( mtag_a == PACK ) + { + mem_a_size_req = sizeof( bfloat16 ) * mc0 * k; + + lpgemm_alloc_mem_panel + ( + mem_a_size_req, BLIS_BUFFER_FOR_GEN_USE, + &mem_a, rntm + ); + + pack_a_buffer_bf16 = ( bfloat16* ) bli_mem_buffer( &mem_a ); + + ( ( pack_bf16 ) lcntx->packa_fun_ptr ) + ( + pack_a_buffer_bf16, + ( a + ( rs_a * ic )), rs_a, cs_a, + mc0, k, + &rs_a_use, &cs_a_use + ); + a_use = pack_a_buffer_bf16; + } + // Call lpgemv_n_one kernel + lpgemv_n_one_bf16bf16f32of32 + ( + mc0, k, + a_use, rs_a_use, cs_a_use, mtag_a, + b_use, rs_b_use, cs_b_use, mtag_b, + c_use, rs_c, cs_c, + alpha, beta, + MR, KC, + post_op_list, + &post_ops_attr + ); + } + + // Release pack buffers + if( mtag_a == PACK && bli_mem_is_alloc( &mem_a ) ) + { + bli_pba_release(rntm, &mem_a); + } + if( mtag_b == PACK && bli_mem_is_alloc( &mem_b ) ) + { + bli_pba_release(rntm, &mem_b); + } + } + else + { + // Compute the JC loop thread range for the current thread. + dim_t jc_start, jc_end; + thread_jc.n_way = ( thread_jc.n_way == 1 ) ? + ( thread->n_threads ) : ( thread_jc.n_way ); + thread_jc.work_id = thread->tid; + bli_thread_range_sub(&thread_jc, n, NR, FALSE, &jc_start, &jc_end); + + dim_t packb_min_NR = 16; + + dim_t k_updated = k; + k_updated += ( k_updated & 0x1 ); + + dim_t kc0 = bli_min( k, KC ); + + kc0 += ( kc0 & 0x1 ); + + rs_a_use = rs_a; + cs_a_use = 2; + + if ( mtag_a == PACK ) + { + mem_a_size_req = sizeof( bfloat16 ) * k; + + lpgemm_alloc_mem_panel + ( + mem_a_size_req, BLIS_BUFFER_FOR_GEN_USE, + &mem_a, rntm + ); + + pack_a_buffer_bf16 = + ( bfloat16* ) bli_mem_buffer( &mem_a ); + + ( ( pack_bf16 )lcntx->packa_fun_ptr ) + ( + pack_a_buffer_bf16, + a, rs_a, cs_a, + 1, k, + &rs_a_use, &cs_a_use + ); + + a_use = pack_a_buffer_bf16; + } + + for (dim_t jc = jc_start; jc < jc_end; jc += NC) + { + dim_t nc0 = bli_min((jc_end - jc), NC); + c_use = c + jc * cs_c; + + dim_t jc_cur_loop = jc; + dim_t jc_cur_loop_rem = 0; + dim_t n_sub_updated = 0; + + if (mtag_b == REORDERED) + { + + get_B_panel_reordered_start_offset_width( + jc, n, NC, packb_min_NR, + &jc_cur_loop, &jc_cur_loop_rem, + &nc0, &n_sub_updated); + + b_use = (bfloat16*) ( b + (jc_cur_loop * k_updated ) ); + + lpgemm_get_packb_strides( lcntx, &rs_b_use, &cs_b_use ); + } + else if( mtag_b == PACK ) + { + + dim_t nc0_updated = make_multiple_of_n( nc0, packb_min_NR ); + mem_b_size_req = sizeof( bfloat16 ) * nc0_updated * k_updated; + + n_sub_updated = nc0_updated; + + lpgemm_alloc_mem_panel + ( + mem_b_size_req, BLIS_BUFFER_FOR_B_PANEL, + &mem_b, rntm + ); + + pack_b_buffer_bf16 = + ( bfloat16* ) bli_mem_buffer( &mem_b ); + + for ( dim_t pc = 0; pc < k; pc += KC ) + { + dim_t kc0 = bli_min( ( k - pc ), KC ); + + dim_t kc0_updated = kc0; + kc0_updated += ( kc0_updated & 0x1 ); + + ( ( pack_bf16 )lcntx->packb_fun_ptr ) + ( + ( ( bfloat16* )pack_b_buffer_bf16 ) + + ( n_sub_updated * pc ), + ( ( ( bfloat16* )b ) + + ( rs_b * pc ) + ( jc * cs_b ) ), + rs_b, cs_b, nc0, kc0, &rs_b_use, &cs_b_use + ); + } + + b_use = pack_b_buffer_bf16; + } + + post_ops_attr.post_op_c_i = 0; + post_ops_attr.post_op_c_j = jc; + post_ops_attr.rs_c_downscale = rs_c; + + lpgemv_m_one_bf16bf16f32of32 + ( + nc0, k, + a_use, rs_a_use, cs_a_use, mtag_a, + b_use, rs_b_use, cs_b_use, mtag_b, + c_use, rs_c, cs_c, + alpha, beta, + NR, KC, + n_sub_updated, + jc_cur_loop_rem, + post_op_list, + &post_ops_attr + ); + + if (mtag_b == REORDERED) + { + adjust_B_panel_reordered_jc(&jc, jc_cur_loop); + } + } // jc loop + + // Release pack buffers. + if ( mtag_b == PACK && bli_mem_is_alloc( &mem_b ) ) + { + bli_pba_release( rntm, &mem_b ); + } + if( mtag_a == PACK && bli_mem_is_alloc( &mem_a ) ) + { + bli_pba_release(rntm, &mem_a); + } + } +} +#endif + + // B should always be packed. LPGEMM_5LOOP(bfloat16,bfloat16,float,bf16bf16f32of32) { + +#if (defined(BLIS_KERNELS_ZEN4) && (!defined(LPGEMM_BF16_JIT))) + // Handle using LPGEMV when m or/and n equal to 1 + // The avx512 check will be removed when avx2 kernels added in future + if ( (n == 1) || ( m == 1 ) ) + { + lpgemv_rowvar_bf16bf16f32of32( m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + c, rs_c, cs_c, + alpha, + beta, + rntm, + thread, + lcntx, + post_op_list, + c_downscale); + return; + } +#endif + dim_t NC = lcntx->blksz.NC; dim_t KC = lcntx->blksz.KC; dim_t MC = lcntx->blksz.MC; @@ -331,7 +622,7 @@ LPGEMM_5LOOP(bfloat16,bfloat16,float,bf16bf16f32of32) lpgemm_alloc_mem_panel ( - mem_a_size_req, BLIS_BUFFER_FOR_A_BLOCK, + mem_a_size_req, BLIS_BUFFER_FOR_GEN_USE, &mem_a, rntm ); diff --git a/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_bf16_eltwise_ops.c b/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_bf16_eltwise_ops.c new file mode 100644 index 0000000000..a98863f761 --- /dev/null +++ b/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_bf16_eltwise_ops.c @@ -0,0 +1,107 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "lpgemm_eltwise_ops_interface_apis.h" +#include "lpgemm_eltwise_ops_kernels.h" +#include "lpgemm_utils.h" +#include "lpgemm_thrinfo_utils.h" +#include "lpgemm_config.h" + +// Kernel function prototypes. +typedef void (*lpgemm_util_post_ops_kernel_f32) + ( + const dim_t, + const dim_t, + const bfloat16*, + const dim_t, + const dim_t, + float*, + const dim_t, + const dim_t, + lpgemm_post_op*, + lpgemm_post_op_attr + ); + +LPGEMM_ELTWISE_OPS_IFACE(bfloat16,float,bf16of32) +{ + dim_t NR = lcntx->blksz.NR; + dim_t MR = lcntx->blksz.MR; + + lpgemm_post_op_attr post_ops_attr; + post_ops_attr.c_stor_type = c_downscale; + post_ops_attr.buf_downscale = NULL; + + // Generate thrinfo objects for jc and ic loops from lpgemm_thrinfo_t. + thrinfo_t thread_jc; + thrinfo_t thread_ic; + + lpgemm_gen_thrinfo( thread, &thread_jc, &thread_ic ); + + // Compute the JC, IC loop thread range for the current thread. + dim_t jc_start, jc_end; + bli_thread_range_sub( &thread_jc, n, NR, FALSE, &jc_start, &jc_end ); + + dim_t ic_start, ic_end; + bli_thread_range_sub( &thread_ic, m, MR, FALSE, &ic_start, &ic_end ); + + post_ops_attr.post_op_c_i = ic_start; + post_ops_attr.post_op_c_j = jc_start; + post_ops_attr.rs_c_downscale = rs_b; + post_ops_attr.cs_c_downscale = cs_b; + post_ops_attr.is_first_k = FALSE; + post_ops_attr.is_last_k = TRUE; // Should always be TRUE here. + + // Advance the matrix to the right positions based on thread id. + // To note that float and bfloat16 are both handled using this same + // frame, so the strides needs to be updated on the actual b matrix + // datatype or the c_downscale value. + dim_t dsize = sizeof( float ); + if ( post_ops_attr.c_stor_type == BF16 ) + { + dsize = sizeof( bfloat16 ); + } + + int8_t* b_i = ( int8_t* )b; + + ( ( lpgemm_util_post_ops_kernel_f32 )( lcntx->eltwise_ops_kern_fun_ptr ) ) + ( + ( ic_end - ic_start ), ( jc_end - jc_start ), + a + ( rs_a * ic_start ) + ( cs_a * jc_start ), + rs_a, cs_a, + ( float* )( b_i + ( dsize * ( ( rs_b * ic_start ) + + ( cs_b * jc_start ) ) ) ), rs_b, cs_b, + post_op_list, post_ops_attr + ); +} diff --git a/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_bf16s4.c b/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_bf16s4.c new file mode 100644 index 0000000000..79de11b16f --- /dev/null +++ b/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_bf16s4.c @@ -0,0 +1,452 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "lpgemm_5loop_interface_apis.h" +#include "lpgemm_pack_bf16.h" +#include "lpgemm_kernels.h" +#include "lpgemm_utils.h" +#include "lpgemm_thrinfo_utils.h" +#include "lpgemm_config.h" + +// Kernel function prototypes +typedef void (*lpgemm_rowvar_bf16)( + const dim_t, + const dim_t, + const dim_t, + const bfloat16 *, + const dim_t, + const dim_t, + const dim_t, + const bfloat16 *, + const dim_t, + const dim_t, + float *, + const dim_t, + const dim_t, + const float, + const float, + lpgemm_post_op *, + lpgemm_post_op_attr); + +// B should always be packed. +LPGEMM_5LOOP1(bfloat16, int8_t, float, bf16s4f32of32) +{ + dim_t NC = lcntx->blksz.NC; + dim_t KC = lcntx->blksz.KC; + dim_t MC = lcntx->blksz.MC; + dim_t NR = lcntx->blksz.NR; + dim_t MR = lcntx->blksz.MR; + + const int16_t *a_use = NULL; + dim_t cs_a_use = cs_a; + dim_t rs_a_use = rs_a; + dim_t a_block_stride = 0; + + const bfloat16 *b_use = NULL; + int8_t* b_reorder = NULL; + dim_t rs_b_use = rs_b; + dim_t cs_b_use = cs_b; + + float *c_use_jc = NULL; + float *c_use_ic = NULL; + dim_t rs_c_use = rs_c; + dim_t rs_c_downscale = rs_c; + + // Pack buffer for B. + bfloat16 *pack_b_buffer_bf16; + bfloat16 *pack_a_buffer_bf16; + mem_t mem_b = BLIS_MEM_INITIALIZER; + mem_t mem_a = BLIS_MEM_INITIALIZER; + siz_t mem_b_size_req = 0; + siz_t mem_a_size_req = 0; + dim_t packb_min_NR = 16; + + // Temporary buffer for C accumulation when downscaling is required. + float *temp_scal_c_buffer_bf16; + mem_t mem_scale_c = BLIS_MEM_INITIALIZER; + siz_t mem_scale_c_size_req = 0; + + // kc needs to be a multiple of 2 so that it can be used with dpbf16_ps + // instruction. Padding is added in cases this condition is not + // satisfied, and therefore the k offset used for packed/reordered + // buffer needs to be updated. + dim_t k_updated = k; + k_updated += (k_updated & 0x1); + + // To decide whether to apply post ops or not. + bool is_last_k = FALSE; + + // To decide whether to use original s8 C or temp buffer for beta scale. + bool is_first_k = FALSE; + + lpgemm_post_op_attr post_ops_attr; + post_ops_attr.c_stor_type = c_downscale; + if (c_downscale < F32) + { + post_ops_attr.buf_downscale = c; + } + else + { + post_ops_attr.buf_downscale = NULL; + } + + post_ops_attr.pre_op_scale_factor = pre_op_list->scale_factor; + post_ops_attr.pre_op_scale_factor_len = pre_op_list->scale_factor_len; + + // Generate thrinfo objects for jc and ic loops from lpgemm_thrinfo_t. + thrinfo_t thread_jc; + thrinfo_t thread_ic; + + lpgemm_gen_thrinfo(thread, &thread_jc, &thread_ic); + + // Compute the JC, IC loop thread range for the current thread. + dim_t jc_start, jc_end; + bli_thread_range_sub(&thread_jc, n, NR, FALSE, &jc_start, &jc_end); + + dim_t ic_start, ic_end; + bli_thread_range_sub(&thread_ic, m, MR, FALSE, &ic_start, &ic_end); + + if( mtag_b == PACK_NR ) + { + /* Allocating private pack buffer of size KCxNR for each thread */ + mem_b_size_req = ( KC * NR * sizeof( bfloat16 ) ); + + lpgemm_alloc_mem_panel( + mem_b_size_req, BLIS_BUFFER_FOR_GEN_USE, + &mem_b, rntm); + } + + for (dim_t jc = jc_start; jc < jc_end; jc += NC) + { + dim_t nc0 = bli_min((jc_end - jc), NC); + dim_t nc0_updated = make_multiple_of_n( nc0, 16 ); + + dim_t jc_cur_loop = jc; + dim_t jc_cur_loop_rem = 0; + dim_t n_sub_updated = 0; + + /* B should always be reordered */ + { + get_B_panel_reordered_start_offset_width( + jc, n, NC, packb_min_NR, + &jc_cur_loop, &jc_cur_loop_rem, + &nc0, &n_sub_updated); + + lpgemm_get_packb_strides(lcntx, &rs_b_use, &cs_b_use); + } + + if (c_downscale == F32) + { + c_use_jc = c + jc; + } + // Temp accumulaton buffer for C allocation. + else if (c_downscale < F32) + { + // Buffer memory is only required if output needs to be + // persisted across iterations of the pc/KC loop. + // It was observed that the locks used while checking out + // a buffer from memory pool had an impact on performance + // and is better to not checkout if k <= KC. + if (k > KC) + { + mem_scale_c_size_req = sizeof(float) * nc0 * (ic_end - ic_start); + + lpgemm_alloc_mem_panel( + mem_scale_c_size_req, BLIS_BUFFER_FOR_GEN_USE, + &mem_scale_c, rntm); + + temp_scal_c_buffer_bf16 = bli_mem_buffer(&mem_scale_c); + + c_use_jc = (float *)temp_scal_c_buffer_bf16; + } + + // The temp c buffer stride is modified as opposed to original C matrix. + rs_c_use = nc0; + } + + for (dim_t pc = 0; pc < k; pc += KC) + { + float beta0 = (pc == 0) ? beta : 1; + dim_t kc0 = bli_min((k - pc), KC); + + // No parallelization in k dim, k always starts at 0. + is_first_k = (pc == 0) ? (TRUE) : (FALSE); + post_ops_attr.is_first_k = is_first_k; + + is_last_k = ((pc + KC) >= k) ? (TRUE) : (FALSE); + post_ops_attr.is_last_k = is_last_k; + + // kc0 needs to be a multiple of 2 so that it can be + // used with dpbf16_ps instruction. Padding is added in + // cases this condition is not satisfied, and therefore + // the kc0 offsets used for packed/reordered buffers + // needs to be updated. + dim_t kc0_updated = kc0; + kc0_updated += (kc0_updated & 0x1); + + // B is always supposed to be reordered. + b_reorder = (int8_t*)b + ( ( jc_cur_loop * k_updated ) + + ( n_sub_updated * pc ) + + ( jc_cur_loop_rem * kc0_updated ) ) / 2; + + + // B matrix will always be packed. + if ( mtag_b == PACK_KC ) + { + // Pack B chunks are based on jc work id. + dim_t jc_work_id = bli_thread_work_id(&thread_jc); + + // Using child thrinfo (thread_ic) tid to decide chief thread + // per B matrix chunk (jc work id group) + if (bli_thread_am_ochief(&thread_ic)) + { + // nc0 needs to be a multiple of 16 since this gives maximum + // vectorization. Packing B always results in buffers with width + // which is a multiple of 16. Subsequently the nc0 offsets used + // for packed/reordered buffers needs to be updated. + mem_b_size_req = sizeof(bfloat16) * nc0_updated * kc0_updated; + + lpgemm_alloc_mem_panel( + mem_b_size_req, BLIS_BUFFER_FOR_B_PANEL, + &mem_b, rntm); + + thread->comm[jc_work_id].sent_object = + bli_mem_buffer(&mem_b); + } + + // All threads in work group should wait till chief thread has + // finished allocating the packing buffers. + bli_thrcomm_barrier( + bli_thread_ocomm_id(&thread_ic), + &thread->comm[jc_work_id]); + + pack_b_buffer_bf16 = + (bfloat16 *)thread->comm[jc_work_id].sent_object; + + // Compute the B panel per thread loop range for parallel + // packing using ic_ways number of threads. Since atmost only + // ic_ways threads can be used, the thread_ic attributes are + // used to split the loop range. + dim_t jc_packb_start, jc_packb_end; + bli_thread_range_sub( + &thread_ic, nc0, NR, FALSE, + &jc_packb_start, &jc_packb_end); + + dim_t pre_op_off = jc_cur_loop + jc_cur_loop_rem + + jc_packb_start; + + // Ensure thread ranges are valid, especially cases where no: + // of threads available for parallelization are greater than + // no: of B panel NR chunks. + if ((jc_packb_end > jc_packb_start) && + (jc_packb_start < (jc + nc0))) + { + ((pack_s4bf16)lcntx->packsclb_fun_ptr)( + pack_b_buffer_bf16 + (jc_packb_start * kc0_updated), + b_reorder + (jc_packb_start * kc0_updated)/2, + (jc_packb_end - jc_packb_start), kc0, + &rs_b_use, &cs_b_use, + pre_op_list, pre_op_off); + } + else + { + lpgemm_get_packb_strides(lcntx, &rs_b_use, &cs_b_use); + } + + // All threads in work group should wait till B matrix packing + // is completed by the participating threads. + bli_thrcomm_barrier( + bli_thread_ocomm_id(&thread_ic), + &thread->comm[jc_work_id]); + b_use = pack_b_buffer_bf16; + } + + for (dim_t ic = ic_start; ic < ic_end; ic += MC) + { + dim_t mc0 = bli_min((ic_end - ic), MC); + + // Only per thread C matrix is stored in temp buffer, so both + // per thread jc and ic start should be normalized to zero. + if (c_downscale < F32) + { + c_use_ic = c_use_jc + (rs_c_use * (ic - ic_start)); + } + else + { + c_use_ic = c_use_jc + (rs_c_use * ic); + } + + if (mtag_a == UNPACKED) + { + a_use = a + (rs_a * ic) + (cs_a * pc); + + // bf16 kernel reads 2 elements, totalling 4 bytes in a + // single broadcast for use in bf16 instruction. + // Non bf16 based kernel requires update to this code. + cs_a_use = 2; + a_block_stride = rs_a; + rs_a_use = rs_a; + } + else if (mtag_a == PACK) + { + + mem_a_size_req = sizeof(bfloat16) * mc0 * kc0; + + lpgemm_alloc_mem_panel( + mem_a_size_req, BLIS_BUFFER_FOR_GEN_USE, + &mem_a, rntm); + + pack_a_buffer_bf16 = + (bfloat16 *)bli_mem_buffer(&mem_a); + + ((pack_bf16)lcntx->packa_fun_ptr)( + pack_a_buffer_bf16, + (a + (rs_a * ic) + (cs_a * pc)), rs_a, cs_a, + mc0, kc0, + &rs_a_use, &cs_a_use); + a_use = pack_a_buffer_bf16; + a_block_stride = rs_a_use; + } + + for (dim_t jr = 0; jr < nc0; jr += NR) + { + dim_t nr0 = bli_min((nc0 - jr), NR); + + // Post ops meta attributes. + post_ops_attr.post_op_c_i = ic; + post_ops_attr.post_op_c_j = (jc + jr); + post_ops_attr.rs_c_downscale = rs_c_downscale; + + if( mtag_b == PACK_NR ) + { + int8_t* b_jr = b_reorder + ( jr * kc0_updated ) / 2; + dim_t pre_op_off = jc_cur_loop + jc_cur_loop_rem + + jr; + + bfloat16* b_use_jr = bli_mem_buffer(&mem_b); + + /* packing B at JR level */ + ((pack_s4bf16)lcntx->packsclb_fun_ptr)( b_use_jr, b_jr, nr0, kc0, + &rs_b_use, &cs_b_use, + pre_op_list, pre_op_off ); + + /* packed B kernel */ + ((lpgemm_rowvar_bf16)lcntx->kern_fun_ptr)( + mc0, nr0, kc0, + a_use, rs_a_use, cs_a_use, a_block_stride, + b_use_jr, rs_b_use, cs_b_use, + (c_use_ic + jr), rs_c_use, 1, + alpha, beta0, + post_op_list, post_ops_attr); + } + else if ( mtag_b == PACK_KC) + { + bfloat16* b_use_jr = ( bfloat16* )b_use + ( jr * kc0_updated ); + + /* packed B kernel */ + ((lpgemm_rowvar_bf16)lcntx->kern_fun_ptr)( + mc0, nr0, kc0, + a_use, rs_a_use, cs_a_use, a_block_stride, + b_use_jr, rs_b_use, cs_b_use, + (c_use_ic + jr), rs_c_use, 1, + alpha, beta0, + post_op_list, post_ops_attr); + } +#ifdef BLIS_KERNELS_ZEN4 + else // mtag_b == UNPACKED + { + int8_t* b_jr = b_reorder + ( jr * kc0_updated ) / 2; + post_ops_attr.pre_op_off = jc_cur_loop + jc_cur_loop_rem + + jr; + + /* bf16s4f32of32 kernel */ + lpgemm_rowvar_bf16s4f32of32_6x64m( + mc0, nr0, kc0, + a_use, rs_a_use, cs_a_use, a_block_stride, + b_jr, rs_b_use, cs_b_use, + (c_use_ic + jr), rs_c_use, 1, + alpha, beta0, + post_op_list, post_ops_attr ); + } +#endif + } + } + } + /* B is always reordered */ + { + adjust_B_panel_reordered_jc(&jc, jc_cur_loop); + } + } + + // Release pack buffers. + if ( mtag_b == PACK_KC ) + { + // All threads in work group should wait till B matrix usage is + // completed by the participating threads. + bli_thrcomm_barrier( + bli_thread_ocomm_id(&thread_jc), + &thread->comm[bli_thread_work_id(&thread_jc)]); + + if (bli_thread_am_ochief(&thread_ic)) + { + if (bli_mem_is_alloc(&mem_b)) + { + bli_pba_release(rntm, &mem_b); + } + } + } + else if ( mtag_b == PACK_NR ) + { + /* releasing private B buffer */ + if (bli_mem_is_alloc(&mem_b)) + { + bli_pba_release(rntm, &mem_b); + } + } + if (mtag_a == PACK) + { + if (bli_mem_is_alloc(&mem_a)) + { + bli_pba_release(rntm, &mem_a); + } + } + if (c_downscale < F32) + { + if (bli_mem_is_alloc(&mem_scale_c)) + { + bli_pba_release(rntm, &mem_scale_c); + } + } +} diff --git a/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_reorder_bf16.c b/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_reorder_bf16.c index 99c17b909f..9305b142c4 100644 --- a/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_reorder_bf16.c +++ b/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_reorder_bf16.c @@ -168,3 +168,127 @@ void reorderb_nr64_bf16bf16f32of32 b_reorder->cs = cs_b_reorder; b_reorder->mtag = REORDERED; } + +void reorderb_nr64_bf16s4f32of32( + lpgemm_obj_t *b, + lpgemm_obj_t *b_reorder, + rntm_t *rntm, + lpgemm_cntx_t *lcntx) +{ + dim_t NC = lcntx->blksz.NC; + dim_t KC = lcntx->blksz.KC; + dim_t NR = lcntx->blksz.NR; + + // Extracting the matrix properties from the lpgemm object + dim_t rs_b = b->rs; + dim_t cs_b = b->cs; + dim_t n = b->width; + dim_t k = b->length; + + dim_t rs_b_reorder; + dim_t cs_b_reorder; + + // k needs to be a multiple of 2 so that it can be used with dpbf + // instruction. Padding is added in cases this condition is not + // satisfied, and therefore the k offset used for packed/reordered + // buffer needs to be updated. + dim_t k_updated = k; + k_updated += (k_updated & 0x1); + + dim_t n_threads = bli_rntm_num_threads(rntm); + n_threads = (n_threads > 0) ? n_threads : 1; + +#ifdef BLIS_ENABLE_OPENMP + _Pragma("omp parallel num_threads(n_threads)") + { + // Initialise a local thrinfo obj for work split across threads. + thrinfo_t thread_jc; + bli_thrinfo_set_n_way(n_threads, &thread_jc); + bli_thrinfo_set_work_id(omp_get_thread_num(), &thread_jc); +#else + { + // Initialise a local thrinfo obj for work split across threads. + thrinfo_t thread_jc; + bli_thrinfo_set_n_way(1, &thread_jc); + bli_thrinfo_set_work_id(0, &thread_jc); +#endif + // Compute the JC loop thread range for the current thread. + dim_t jc_start, jc_end; + bli_thread_range_sub(&thread_jc, n, NR, FALSE, &jc_start, &jc_end); + + for (dim_t jc = jc_start; jc < jc_end; jc += NC) + { + dim_t nc0 = bli_min((jc_end - jc), NC); + + dim_t jc_cur_loop = jc; + dim_t jc_cur_loop_rem = 0; + dim_t n_sub_updated = 0; + + get_B_panel_reordered_start_offset_width( + jc, n, NC, 16, + &jc_cur_loop, &jc_cur_loop_rem, + &nc0, &n_sub_updated); + + for (dim_t pc = 0; pc < k; pc += KC) + { + dim_t kc0 = bli_min((k - pc), KC); + + // k needs to be a multiple of 2 so that it can be used with dpbf + // instruction. Padding is added in cases this condition is not + // satisfied, and therefore the k offset used for packed/reordered + // buffer needs to be updated. + dim_t kc0_updated = kc0; + kc0_updated += (kc0_updated & 0x1); + + // The offsets are calculated in such a way that it resembles + // the reorder buffer traversal in single threaded reordering. + // The panel boundaries (KCxNC) remain as it is accessed in + // single thread, and as a consequence a thread with jc_start + // inside the panel cannot consider NC range for reorder. It + // has to work with NC' < NC, and the offset is calulated using + // prev NC panels spanning k dim + cur NC panel spaning pc loop + // cur iteration + (NC - NC') spanning current kc0 (<= KC). + // + // Eg: Consider the following reordered buffer diagram: + // t1 t2 + // | | + // | |..NC..| + // | | | + // |.NC. |.NC. |NC'|NC" + // pc=0-+-----+-----+---+--+ + // KC| | | | | + // | 1 | 3 | 5 | + // pc=KC-+-----+-----+---st-+ + // KC| | | | | + // | 2 | 4 | 6 | 7| + // pc=k=2KC-+-----+-----+---+--+ + // |jc=0 |jc=NC|jc=2NC| + // + // The numbers 1,2..6,7 denotes the order in which reordered + // KCxNC blocks are stored in memory, ie: block 1 followed by 2 + // followed by 3, etc. Given two threads t1 and t2, and t2 needs + // to acces point st in the reorder buffer to write the data: + // The offset calulation logic will be: + // jc_cur_loop = 2NC, jc_cur_loop_rem = NC', pc = KC, + // n_sub_updated = NC, k = 2KC, kc0_updated = KC + // + // st = ( jc_cur_loop * k ) + // + ( n_sub_updated * pc ) + // + ( NC' * kc0_updated) + ((pack_s4)lcntx->packb_fun_ptr)( + ((int8_t *)b_reorder->storage.aligned_buffer) + + ( (jc_cur_loop * k_updated) + (n_sub_updated * pc) + + (jc_cur_loop_rem * kc0_updated) ) / 2, + (((int8_t *)b->storage.aligned_buffer) + + ( (rs_b * pc) + (jc * cs_b) ) / 2), + rs_b, cs_b, nc0, kc0, &rs_b_reorder, &cs_b_reorder, NULL); + } + + adjust_B_panel_reordered_jc(&jc, jc_cur_loop); + } + } + + b_reorder->rs = rs_b_reorder; + b_reorder->cs = cs_b_reorder; + b_reorder->mtag = REORDERED; +} diff --git a/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_reorder_bf16.h b/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_reorder_bf16.h index d9fddedb6e..6595753dc0 100644 --- a/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_reorder_bf16.h +++ b/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_reorder_bf16.h @@ -45,4 +45,12 @@ void reorderb_nr64_bf16bf16f32of32 lpgemm_cntx_t* lcntx ); +void reorderb_nr64_bf16s4f32of32 + ( + lpgemm_obj_t * b, + lpgemm_obj_t * b_reorder, + rntm_t* rntm, + lpgemm_cntx_t* lcntx + ); + #endif // LPGEMM_REORDER_H diff --git a/addon/aocl_gemm/frame/f32f32f32/lpgemm_f32_eltwise_ops.c b/addon/aocl_gemm/frame/f32f32f32/lpgemm_f32_eltwise_ops.c new file mode 100644 index 0000000000..38ea9a2343 --- /dev/null +++ b/addon/aocl_gemm/frame/f32f32f32/lpgemm_f32_eltwise_ops.c @@ -0,0 +1,102 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "lpgemm_eltwise_ops_interface_apis.h" +#include "lpgemm_eltwise_ops_kernels.h" +#include "lpgemm_utils.h" +#include "lpgemm_thrinfo_utils.h" +#include "lpgemm_config.h" + +// Kernel function prototypes. +typedef void (*lpgemm_util_post_ops_kernel_f32) + ( + const dim_t, + const dim_t, + const float*, + const dim_t, + const dim_t, + float*, + const dim_t, + const dim_t, + lpgemm_post_op*, + lpgemm_post_op_attr + ); + +LPGEMM_ELTWISE_OPS_IFACE(float,float,f32of32) +{ + dim_t NR = lcntx->blksz.NR; + dim_t MR = lcntx->blksz.MR; + + lpgemm_post_op_attr post_ops_attr; + post_ops_attr.c_stor_type = c_downscale; + post_ops_attr.buf_downscale = NULL; + + // Generate thrinfo objects for jc and ic loops from lpgemm_thrinfo_t. + thrinfo_t thread_jc; + thrinfo_t thread_ic; + + lpgemm_gen_thrinfo( thread, &thread_jc, &thread_ic ); + + // Compute the JC, IC loop thread range for the current thread. + dim_t jc_start, jc_end; + bli_thread_range_sub( &thread_jc, n, NR, FALSE, &jc_start, &jc_end ); + + dim_t ic_start, ic_end; + bli_thread_range_sub( &thread_ic, m, MR, FALSE, &ic_start, &ic_end ); + + post_ops_attr.post_op_c_i = ic_start; + post_ops_attr.post_op_c_j = jc_start; + post_ops_attr.rs_c_downscale = rs_b; + post_ops_attr.cs_c_downscale = cs_b; + post_ops_attr.is_first_k = FALSE; + post_ops_attr.is_last_k = TRUE; // Should always be TRUE here. + + // Advance the matrix to the right positions based on thread id. + // To note that float and bfloat16 are both handled using this same + // frame, so the strides needs to be updated on the actual b matrix + // datatype or the c_downscale value. + dim_t dsize = sizeof( float ); + int8_t* b_i = ( int8_t* )b; + + ( ( lpgemm_util_post_ops_kernel_f32 )( lcntx->eltwise_ops_kern_fun_ptr ) ) + ( + ( ic_end - ic_start ), ( jc_end - jc_start ), + a + ( rs_a * ic_start ) + ( cs_a * jc_start ), + rs_a, cs_a, + ( float* )( b_i + ( dsize * ( ( rs_b * ic_start ) + + ( cs_b * jc_start ) ) ) ), rs_b, cs_b, + post_op_list, post_ops_attr + ); +} diff --git a/addon/aocl_gemm/frame/f32f32f32/lpgemm_f32f32f32.c b/addon/aocl_gemm/frame/f32f32f32/lpgemm_f32f32f32.c index 61e8cf8654..c94a8d80d2 100644 --- a/addon/aocl_gemm/frame/f32f32f32/lpgemm_f32f32f32.c +++ b/addon/aocl_gemm/frame/f32f32f32/lpgemm_f32f32f32.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -38,6 +38,7 @@ #include "lpgemm_utils.h" #include "lpgemm_thrinfo_utils.h" #include "lpgemm_kernels.h" +#include "lpgemm_pack_f32.h" // Kernel function prototypes typedef void (*lpgemm_rowvar_f32) @@ -87,8 +88,286 @@ void lpgemm_pack_b_f32f32f32of32 cntx_t* cntx ); -LPGEMM_5LOOP(float,float,float,f32f32f32of32) +#ifdef BLIS_KERNELS_ZEN4 +LPGEMV(float, float, float, f32f32f32of32) { + cntx_t *cntx = bli_gks_query_cntx(); + num_t dt = BLIS_FLOAT; + + const float* a_use = (float*)a; + inc_t rs_a_use = rs_a; + inc_t cs_a_use = cs_a; + + float* b_use = (float*)b; + inc_t rs_b_use = rs_b; + inc_t cs_b_use = cs_b; + inc_t ps_b_use; + + siz_t mem_a_size_req = 0; + mem_t mem_a = BLIS_MEM_INITIALIZER; + siz_t mem_b_size_req = 0; + mem_t mem_b = BLIS_MEM_INITIALIZER; + + float* pack_a_buffer_f32f32f32of32; + float* pack_b_buffer_f32f32f32of32; + + // Query the context for various blocksizes. + const dim_t NR = bli_cntx_get_l3_sup_blksz_def_dt(dt, BLIS_NR, cntx); + const dim_t NC = bli_cntx_get_l3_sup_blksz_def_dt(dt, BLIS_NC, cntx); + const dim_t KC = bli_cntx_get_l3_sup_blksz_def_dt(dt, BLIS_KC, cntx); + const dim_t MC = bli_cntx_get_l3_sup_blksz_def_dt(dt, BLIS_KC, cntx); + + // Strides are updated based on matrix packing/reordering. + float *c_use = NULL; + + lpgemm_post_op_attr post_ops_attr; + post_ops_attr.c_stor_type = c_downscale; + if (c_downscale < F32) post_ops_attr.buf_downscale = c; + else post_ops_attr.buf_downscale = NULL; + + // Generate thrinfo objects for jc and ic loops from lpgemm_thrinfo_t. + thrinfo_t thread_jc; + thrinfo_t thread_ic; + lpgemm_gen_thrinfo(thread, &thread_jc, &thread_ic); + + if(n == 1) + { + float* pack_b_buffer_f32f32f32of32; + //TODO: AVX2 support need to be added + // Increased MR from 6 to 16 to make use of 32 ZMM registers + dim_t MR = 16; + // Pack B matrix if rs_b > 1 + if( ( mtag_b == PACK ) && ( rs_b != 1 ) ) + { + mem_b_size_req = sizeof( float ) * k; + + lpgemm_alloc_mem_panel + ( + mem_b_size_req, BLIS_BUFFER_FOR_GEN_USE, + &mem_b, rntm + ); + + pack_b_buffer_f32f32f32of32 = ( float* ) bli_mem_buffer( &mem_b ); + + for( dim_t k0 = 0; k0 < k; k0++ ) + { + pack_b_buffer_f32f32f32of32[k0] = b[ k0*rs_b ]; + } + + b_use = pack_b_buffer_f32f32f32of32; + rs_b_use = 1; + cs_b_use = 1; + } + + // Compute the IC loop thread range for the current thread. + dim_t ic_start, ic_end; + thread_ic.n_way = ( thread_ic.n_way == 1 ) ? + ( thread->n_threads ) : ( thread_ic.n_way ); + thread_ic.work_id = thread->tid; + bli_thread_range_sub(&thread_ic, m, MR, FALSE, &ic_start, &ic_end); + + for (dim_t ic = ic_start; ic < ic_end; ic += MC) + { + dim_t mc0 = bli_min((ic_end - ic), MC); + a_use = a + ic * rs_a; + c_use = c + ic * rs_c; + post_ops_attr.post_op_c_i = ic; + + if( mtag_a == PACK && cs_a != 1 ) + { + mem_a_size_req = sizeof(float) * mc0 * k; + lpgemm_alloc_mem_panel + ( + mem_a_size_req, BLIS_BUFFER_FOR_GEN_USE, + &mem_a, rntm + ); + pack_a_buffer_f32f32f32of32 = ( float* )bli_mem_buffer( &mem_a ); + + packa_mr16_f32f32f32of32_col_major + ( + pack_a_buffer_f32f32f32of32, + a_use, rs_a, cs_a, + mc0, k, + &rs_a_use, &cs_a_use + ); + a_use = pack_a_buffer_f32f32f32of32; + } + + // Call lpgemv_n_one kernel + lpgemv_n_one_f32f32f32of32 + ( + mc0, k, + a_use, rs_a_use, cs_a_use, mtag_a, + b_use, rs_b_use, cs_b_use, mtag_b, + c_use, rs_c, cs_c, + alpha, beta, + MR, KC, + post_op_list, + &post_ops_attr + ); + } + if ( ( mtag_a == PACK ) && ( bli_mem_is_alloc( &mem_a ) ) ) + { + bli_pba_release( rntm, &mem_a ); + } + if ( ( mtag_b == PACK ) && ( bli_mem_is_alloc( &mem_b ) ) ) + { + bli_pba_release( rntm, &mem_b ); + } + } + else + { + // Compute the JC loop thread range for the current thread. + dim_t jc_start, jc_end; + thread_jc.n_way = ( thread_jc.n_way == 1 ) ? + ( thread->n_threads ) : ( thread_jc.n_way ); + thread_jc.work_id = thread->tid; + bli_thread_range_sub(&thread_jc, n, NR, FALSE, &jc_start, &jc_end); + + if ( mtag_a == PACK ) + { + mem_a_size_req = sizeof( float ) * k; + + lpgemm_alloc_mem_panel + ( + mem_a_size_req, BLIS_BUFFER_FOR_GEN_USE, + &mem_a, rntm + ); + + pack_a_buffer_f32f32f32of32 = + ( float* ) bli_mem_buffer( &mem_a ); + + packa_mr16_f32f32f32of32_col_major + ( + pack_a_buffer_f32f32f32of32, + a_use, rs_a, cs_a, + 1, k, + &rs_a_use, &cs_a_use + ); + + a_use = pack_a_buffer_f32f32f32of32; + } + + for (dim_t jc = jc_start; jc < jc_end; jc += NC) + { + dim_t nc0 = bli_min((jc_end - jc), NC); + c_use = c + jc * cs_c; + + dim_t jc_cur_loop = jc; + dim_t jc_cur_loop_rem = 0; + dim_t n_sub_updated = 0; + + if (mtag_b == REORDERED) + { + get_B_panel_reordered_start_offset_width( + jc, n, NC, NR, + &jc_cur_loop, &jc_cur_loop_rem, + &nc0, &n_sub_updated); + + b_use = (float*) ( b + (jc_cur_loop * k) ); + + rs_b_use = NR; + cs_b_use = 1; + } + else if (mtag_b == PACK) + { + // nc0 needs to be a multiple of 16 since this gives maximum + // vectorization. Packing B always results in buffers with width + // which is a multiple of 16. Subsequently the nc0 offsets used + // for packed/reordered buffers needs to be updated. + dim_t nc0_updated = make_multiple_of_n( nc0, NR ); + + mem_b_size_req = sizeof( float ) * nc0_updated * k; + n_sub_updated = nc0_updated; + + lpgemm_alloc_mem_panel + ( + mem_b_size_req, BLIS_BUFFER_FOR_B_PANEL, + &mem_b, rntm + ); + + pack_b_buffer_f32f32f32of32 = + ( float* ) bli_mem_buffer( &mem_b ); + + for ( dim_t pc = 0; pc < k; pc += KC ) + { + dim_t kc0 = bli_min( ( k - pc ), KC ); + + // Set the strides for pack buffer. + rs_b_use = NR; + cs_b_use = 1; + ps_b_use = kc0; + + lpgemm_pack_b_f32f32f32of32 + ( + ( b + ( rs_b * pc ) + ( cs_b * jc ) ), + pack_b_buffer_f32f32f32of32 + ( n_sub_updated * pc ), + nc0 , kc0, + rs_b, cs_b, ( NR * ps_b_use ), NR, + cntx + ); + } + b_use = pack_b_buffer_f32f32f32of32; + } + else + { + b_use = (float*) b + jc * cs_b; + } + + //update post-op pointer + post_ops_attr.post_op_c_j = jc; + + // Call kernel + lpgemv_m_one_f32f32f32of32 + ( + nc0, k, + a_use, rs_a_use, cs_a_use, mtag_a, + b_use, rs_b_use, cs_b_use, mtag_b, + c_use, rs_c, cs_c, + alpha, beta, + NR, KC, + n_sub_updated, + jc_cur_loop_rem, + post_op_list, + &post_ops_attr + ); + + if (mtag_b == REORDERED) + { + adjust_B_panel_reordered_jc(&jc, jc_cur_loop); + } + } // jc loop + + // Release pack buffers. + if ( ( mtag_b == PACK ) && ( bli_mem_is_alloc( &mem_b ) ) ) + { + bli_pba_release( rntm, &mem_b ); + } + } +} +#endif + +LPGEMM_5LOOP(float, float, float, f32f32f32of32) +{ +#ifdef BLIS_KERNELS_ZEN4 + // Handle using LPGEMV when m or/and n equal to 1 + // The avx512 check will be removed when avx2 kernels added in future + if ( ( ( m == 1 ) || ( n == 1 ) ) && (bli_cpuid_is_avx512_supported() == TRUE) ) + { + lpgemv_rowvar_f32f32f32of32(m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + c, rs_c, cs_c, + alpha, + beta, + rntm, + thread, + lcntx, + post_op_list, + c_downscale); + return; + } +#endif // Query the global cntx. cntx_t* cntx = bli_gks_query_cntx(); @@ -101,8 +380,6 @@ LPGEMM_5LOOP(float,float,float,f32f32f32of32) const dim_t MC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_MC, cntx ); const dim_t KC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_KC, cntx ); - /*ToDo: Based on context kernel 6x64m or 6x16m will be picked here */ - // Strides are updated based on matrix packing/reordering. const float* a_use = NULL; dim_t rs_a_use = rs_a; @@ -133,8 +410,8 @@ LPGEMM_5LOOP(float,float,float,f32f32f32of32) mem_t mem_a = BLIS_MEM_INITIALIZER; siz_t mem_a_size_req = 0; - // Check if packing of A is required. - bool should_pack_B = bli_rntm_pack_b( rntm ); + // Check if packing of B is required. + bool should_pack_B = bli_rntm_pack_b( rntm ) || ( rs_b == 1 ); // Pack buffer for B. float* pack_b_buffer_f32f32f32of32; @@ -150,7 +427,7 @@ LPGEMM_5LOOP(float,float,float,f32f32f32of32) bool is_first_k = FALSE; lpgemm_post_op_attr post_ops_attr; - post_ops_attr.c_stor_type = c_downscale; + post_ops_attr.c_stor_type = c_downscale; if ( c_downscale < F32 ) { post_ops_attr.buf_downscale = c; @@ -333,7 +610,7 @@ LPGEMM_5LOOP(float,float,float,f32f32f32of32) lpgemm_pack_a_f32f32f32of32 ( - ( a + ( rs_a * ic ) + pc ), + ( a + ( rs_a * ic ) + ( pc * cs_a) ), pack_a_buffer_f32f32f32of32, mc0, kc0, rs_a, cs_a, ps_a_use, MR, @@ -344,7 +621,7 @@ LPGEMM_5LOOP(float,float,float,f32f32f32of32) } else { - a_use = a + ( rs_a * ic ) + pc; + a_use = a + ( rs_a * ic ) + ( pc * cs_a ); ps_a_use = MR * rs_a; } @@ -382,7 +659,7 @@ LPGEMM_5LOOP(float,float,float,f32f32f32of32) } // Release pack buffers. - if ( mtag_b == PACK ) + if ( ( mtag_b == PACK ) && ( should_pack_B == TRUE ) ) { // All threads in work group should wait till B matrix usage is // completed by the participating threads. @@ -438,7 +715,7 @@ void lpgemm_pack_a_f32f32f32of32 float* p_temp = reorder_buf_addr_a; dim_t ir, it; - // Iterate over every logical micropanel in the source matrix. + // Iterate over every logical micropanel in the source mmatrix. for ( ir = 0, it = 0; it < m_iter; ir += MR, it += 1 ) { dim_t panel_dim_i = bli_min( MR, m - ir ); diff --git a/addon/aocl_gemm/frame/lpgemm_5loop_interface_apis.h b/addon/aocl_gemm/frame/lpgemm_5loop_interface_apis.h index a0920edaf3..c28da2c9f9 100644 --- a/addon/aocl_gemm/frame/lpgemm_5loop_interface_apis.h +++ b/addon/aocl_gemm/frame/lpgemm_5loop_interface_apis.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -50,9 +50,9 @@ void lpgemm_rowvar_ ## LP_SFX \ const dim_t cs_a, \ const AOCL_MEMORY_TAG mtag_a, \ const B_type* b, \ - const dim_t rs_b, \ - const dim_t cs_b, \ - const AOCL_MEMORY_TAG mtag_b, \ + dim_t rs_b, \ + dim_t cs_b, \ + AOCL_MEMORY_TAG mtag_b, \ C_type* c, \ const dim_t rs_c, \ const dim_t cs_c, \ @@ -71,4 +71,65 @@ LPGEMM_5LOOP(float,float,float,f32f32f32of32); LPGEMM_5LOOP(bfloat16,bfloat16,float,bf16bf16f32of32); LPGEMM_5LOOP(int8_t,int8_t,int32_t,s8s8s32o32); LPGEMM_5LOOP(int8_t,int8_t,int16_t,s8s8s16o16); + +#define LPGEMM_5LOOP1(A_type,B_type,C_type,LP_SFX) \ +void lpgemm_rowvar_ ## LP_SFX \ + ( \ + const dim_t m, \ + const dim_t n, \ + const dim_t k, \ + const A_type* a, \ + const dim_t rs_a, \ + const dim_t cs_a, \ + const AOCL_MEMORY_TAG mtag_a, \ + const B_type* b, \ + const dim_t rs_b, \ + const dim_t cs_b, \ + const AOCL_MEMORY_TAG mtag_b, \ + C_type* c, \ + const dim_t rs_c, \ + const dim_t cs_c, \ + const C_type alpha, \ + const C_type beta, \ + rntm_t* rntm, \ + lpgemm_thrinfo_t* thread, \ + lpgemm_cntx_t* lcntx, \ + lpgemm_pre_op* pre_op_list, \ + lpgemm_post_op* post_op_list, \ + AOCL_STORAGE_TYPE c_downscale \ + ) \ + +LPGEMM_5LOOP1(bfloat16,int8_t,float,bf16s4f32of32); + +#define LPGEMV(A_type, B_type, C_type, LP_SFX) \ +void lpgemv_rowvar_ ## LP_SFX \ + ( \ + const dim_t m, \ + const dim_t n, \ + const dim_t k, \ + const A_type *a, \ + const dim_t rs_a, \ + const dim_t cs_a, \ + const AOCL_MEMORY_TAG mtag_a, \ + const B_type *b, \ + const dim_t rs_b, \ + const dim_t cs_b, \ + const AOCL_MEMORY_TAG mtag_b, \ + C_type *c, \ + const dim_t rs_c, \ + const dim_t cs_c, \ + const C_type alpha, \ + const C_type beta, \ + rntm_t *rntm, \ + lpgemm_thrinfo_t *thread, \ + lpgemm_cntx_t *lcntx, \ + lpgemm_post_op *post_op_list, \ + AOCL_STORAGE_TYPE c_downscale \ + ) \ + +LPGEMV(float, float, float, f32f32f32of32); +LPGEMV(bfloat16,bfloat16,float,bf16bf16f32of32); +LPGEMV(uint8_t,int8_t,int32_t,u8s8s32os32); +LPGEMV(int8_t,int8_t,int32_t,s8s8s32os32); + #endif // LPGEMM_5LOOP_INTF_H diff --git a/addon/aocl_gemm/frame/lpgemm_eltwise_ops_interface_apis.h b/addon/aocl_gemm/frame/lpgemm_eltwise_ops_interface_apis.h new file mode 100644 index 0000000000..9337fd05c0 --- /dev/null +++ b/addon/aocl_gemm/frame/lpgemm_eltwise_ops_interface_apis.h @@ -0,0 +1,63 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef LPGEMM_POSTOP_INTF_H +#define LPGEMM_POSTOP_INTF_H + +#include "lpgemm_types.h" +#include "lpgemm_post_ops.h" +#include "aocl_bf16_type.h" + +#define LPGEMM_ELTWISE_OPS_IFACE(A_type,B_type,LP_SFX) \ +void lpgemm_eltwise_ops_interface_ ## LP_SFX \ + ( \ + const dim_t m, \ + const dim_t n, \ + const A_type* a, \ + const dim_t rs_a, \ + const dim_t cs_a, \ + B_type* b, \ + const dim_t rs_b, \ + const dim_t cs_b, \ + rntm_t* rntm, \ + lpgemm_thrinfo_t* thread, \ + lpgemm_eltwise_ops_cntx_t* lcntx, \ + lpgemm_post_op* post_op_list, \ + AOCL_STORAGE_TYPE c_downscale \ + ) \ + +LPGEMM_ELTWISE_OPS_IFACE(bfloat16,float,bf16of32); +LPGEMM_ELTWISE_OPS_IFACE(float,float,f32of32); + +#endif //LPGEMM_POSTOP_INTF_H diff --git a/addon/aocl_gemm/frame/lpgemm_post_ops.c b/addon/aocl_gemm/frame/lpgemm_post_ops.c index 92f5849c20..f6f7cdd0f4 100644 --- a/addon/aocl_gemm/frame/lpgemm_post_ops.c +++ b/addon/aocl_gemm/frame/lpgemm_post_ops.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -35,22 +35,96 @@ #include "blis.h" #include "lpgemm_post_ops.h" -BLIS_INLINE void lpgemm_set_node_params +BLIS_INLINE void lpgemm_set_pre_ops_node_params ( - lpgemm_post_op* post_op_node, - LPGEMM_POST_OP_CODE op_code, - void* op1, - void* op2, - void* op3, + lpgemm_pre_op* pre_op_node, + void* zero_point, void* scale_factor, - bool is_power_of_2 + dim_t zero_point_len, + dim_t scale_factor_len ) +{ + pre_op_node->scale_factor = scale_factor; + pre_op_node->scale_factor_len = scale_factor_len; + pre_op_node->zp = zero_point; + pre_op_node->zp_len = zero_point_len; + pre_op_node->next = NULL; +} + +err_t lpgemm_translate_to_pre_ops_list( + aocl_pre_op *pre_op_unparsed, + lpgemm_pre_op *pre_op_list, + dim_t m, + dim_t n, + dim_t k) +{ + (void)(m); // Unused for now, potential to be used later. + (void)(k); // Unused for now, potential to be used later. + + if ((pre_op_unparsed == NULL) || (pre_op_unparsed->seq_length <= 0)) + { + lpgemm_set_pre_ops_node_params + ( + pre_op_list, + NULL, NULL, 0, 0 + ); + + return BLIS_SUCCESS; + } + + if ((pre_op_unparsed->seq_length > AOCL_MAX_POST_OPS)) + { + lpgemm_set_pre_ops_node_params + ( + pre_op_list, + NULL, NULL, 0, 0 + ); + + bli_print_msg(" Max supported pre-ops is 2, supplied input pre-ops" + " are more. Exiting..", + __FILE__, __LINE__); + return BLIS_UNEXPECTED_VECTOR_DIM; // Error, seq length exceeds max pre ops permitted. + } + + for (dim_t i = 0; i < pre_op_unparsed->seq_length; ++i) + { + if (pre_op_unparsed->b_zp != NULL && pre_op_unparsed->b_scl!=NULL) + { + lpgemm_set_pre_ops_node_params + ( + pre_op_list, + (pre_op_unparsed->b_zp)->zero_point, + (pre_op_unparsed->b_scl)->scale_factor, + (pre_op_unparsed->b_zp)->zero_point_len, + (pre_op_unparsed->b_scl)->scale_factor_len + ); + } + + // Simulating linked link using an array. + if (i < (pre_op_unparsed->seq_length - 1)) + { + (pre_op_list + i)->next = (pre_op_list + i + 1); + } + } + return BLIS_SUCCESS; +} + +BLIS_INLINE void lpgemm_set_node_params( + lpgemm_post_op *post_op_node, + LPGEMM_POST_OP_CODE op_code, + void *op1, + void *op2, + void *op3, + void *scale_factor, + dim_t scale_factor_len, + bool is_power_of_2) { post_op_node->op_code = op_code; post_op_node->op_args1 = op1; post_op_node->op_args2 = op2; post_op_node->op_args3 = op3; post_op_node->scale_factor = scale_factor; + post_op_node->scale_factor_len = scale_factor_len; post_op_node->is_power_of_2 = is_power_of_2; post_op_node->next = NULL; } @@ -60,16 +134,22 @@ err_t lpgemm_translate_to_post_ops_list aocl_post_op* post_op_unparsed, lpgemm_post_op* post_op_list, void* scale_buffer, - void* meta_arg + void* meta_arg, + dim_t m, + dim_t n ) { - if ( post_op_unparsed == NULL ) + ( void )( scale_buffer ); //Unused for now, potential to be used later. + ( void )( m ); //Unused for now, potential to be used later. + + if ( ( post_op_unparsed == NULL ) || ( post_op_unparsed->seq_length <= 0 ) ) { lpgemm_set_node_params ( post_op_list, POST_OPS_DISABLE, - NULL, NULL, NULL, NULL, FALSE + NULL, NULL, NULL, NULL, 0, FALSE ); + return BLIS_SUCCESS; } @@ -78,27 +158,39 @@ err_t lpgemm_translate_to_post_ops_list lpgemm_set_node_params ( post_op_list, POST_OPS_DISABLE, - NULL, NULL, NULL, NULL, FALSE + NULL, NULL, NULL, NULL, 0, FALSE ); - return BLIS_SUCCESS; //Error, seq length exceeds max post ops permitted. + + bli_print_msg(" Max supported post-ops is 5, supplied input post-ops" \ + " are more. Exiting..", __FILE__, __LINE__ ); + return BLIS_UNEXPECTED_VECTOR_DIM; //Error, seq length exceeds max post ops permitted. } - dim_t e_i = 0; //Multiple eltwise supported. + dim_t e_i = 0; // Multiple eltwise supported. + dim_t s_i = 0; // Multiple sum/scale supported. + dim_t b_i = 0; // Multiple bias supported. + dim_t m_i = 0; // Multiple matrix add supported. + dim_t mul_i = 0; // Multiple matrix mul supported. for ( dim_t i = 0; i < post_op_unparsed->seq_length; ++i ) { // Dispatcher code switch ( *( post_op_unparsed->seq_vector + i ) ) { case SUM: - lpgemm_set_node_params - ( - ( post_op_list + i ), POST_OPS_SUM, - post_op_unparsed->sum.buff, - post_op_unparsed->sum.zero_point, - NULL, - post_op_unparsed->sum.scale_factor, - post_op_unparsed->sum.is_power_of_2 - ); + { + lpgemm_set_node_params + ( + ( post_op_list + i ), POST_OPS_SUM, + ( post_op_unparsed->sum + s_i )->buff, + ( post_op_unparsed->sum + s_i )->zero_point, + NULL, + ( post_op_unparsed->sum + s_i )->scale_factor, + ( post_op_unparsed->sum + s_i )->scale_factor_len, + ( post_op_unparsed->sum + s_i )->is_power_of_2 + ); + + s_i += 1; + } break; case ELTWISE: { @@ -132,6 +224,14 @@ err_t lpgemm_translate_to_post_ops_list } tmp_code = POST_OPS_CLIP; break; + case SWISH: + if( ( post_op_unparsed->eltwise + e_i )->algo.alpha == NULL ) + { + bli_print_msg(" Post_op.alpha is NULL. Exiting..", __FILE__, __LINE__ ); + return BLIS_NULL_POINTER; + } + tmp_code = POST_OPS_SWISH; + break; default: break; } @@ -142,39 +242,116 @@ err_t lpgemm_translate_to_post_ops_list ( post_op_unparsed->eltwise + e_i )->algo.alpha, ( post_op_unparsed->eltwise + e_i )->algo.beta, ( post_op_unparsed->eltwise + e_i )->scale_factor, + ( post_op_unparsed->eltwise + e_i )->scale_factor_len, ( post_op_unparsed->eltwise + e_i )->is_power_of_2 ); e_i += 1; } break; case BIAS: - if( post_op_unparsed->bias.bias == NULL ) { - bli_print_msg(" Post_op.bias is NULL. Exiting..", __FILE__, __LINE__ ); - return BLIS_NULL_POINTER; + if( ( post_op_unparsed->bias + b_i )->bias == NULL ) + { + bli_print_msg(" Post_op.bias is NULL. Exiting..", __FILE__, __LINE__ ); + return BLIS_NULL_POINTER; + } + + lpgemm_set_node_params + ( + ( post_op_list + i ), POST_OPS_BIAS, + ( post_op_unparsed->bias + b_i )->bias, + meta_arg, NULL, NULL, 0, FALSE + ); + + b_i += 1; } - lpgemm_set_node_params - ( - ( post_op_list + i ), POST_OPS_BIAS, - post_op_unparsed->bias.bias, - meta_arg, NULL, NULL, FALSE - ); break; case SCALE: - if( ( post_op_unparsed->sum.scale_factor == NULL ) || - ( post_op_unparsed->sum.zero_point == NULL ) ) { - bli_print_msg(" Post_op.scale scale_factor or zero_point is NULL. Exiting..", __FILE__, __LINE__ ); - return BLIS_NULL_POINTER; + if ( ( ( post_op_unparsed->sum + s_i )->scale_factor_len > 0 ) && + ( ( post_op_unparsed->sum + s_i )->scale_factor == NULL ) ) + { + bli_print_msg(" Post_op.scale scale_factor is NULL. Exiting..", + __FILE__, __LINE__ ); + return BLIS_NULL_POINTER; + } + if ( ( ( post_op_unparsed->sum + s_i )->zero_point_len > 0 ) && + ( ( post_op_unparsed->sum + s_i )->zero_point == NULL ) ) + { + bli_print_msg(" Post_op.scale zero_point is NULL. Exiting..", + __FILE__, __LINE__ ); + return BLIS_NULL_POINTER; + } + if ( ( ( post_op_unparsed->sum + s_i )->scale_factor_len != 1 ) && + ( ( post_op_unparsed->sum + s_i )->scale_factor_len < n ) ) + { + bli_print_msg(" Post_op.scale scale factor length is < n." \ + " Exiting..", __FILE__, __LINE__ ); + return BLIS_NULL_POINTER; + } + if ( ( ( post_op_unparsed->sum + s_i )->zero_point_len != 1 ) && + ( ( post_op_unparsed->sum + s_i )->zero_point_len < n ) ) + { + bli_print_msg(" Post_op.scale zero point length is < n." \ + " Exiting..", __FILE__, __LINE__ ); + return BLIS_NULL_POINTER; + } + + lpgemm_set_node_params + ( + ( post_op_list + i ), POST_OPS_DOWNSCALE, + ( post_op_unparsed->sum + s_i )->zero_point, + meta_arg, &( ( post_op_unparsed->sum + s_i )->zero_point_len ), + ( post_op_unparsed->sum + s_i )->scale_factor, + ( post_op_unparsed->sum + s_i )->scale_factor_len, + FALSE + ); + + s_i += 1; + } + break; + case MATRIX_ADD: + { + if ( ( ( post_op_unparsed->matrix_add + m_i )->matrix == NULL ) || + ( ( post_op_unparsed->matrix_add + m_i )->ldm <= 0 ) ) + { + bli_print_msg(" Post_op.matrix_add attributes are invalid. Exiting..", + __FILE__, __LINE__ ); + return BLIS_NULL_POINTER; + } + + lpgemm_set_node_params + ( + ( post_op_list + i ), POST_OPS_MATRIX_ADD, + ( post_op_unparsed->matrix_add + m_i )->matrix, + meta_arg, &( ( post_op_unparsed->matrix_add + m_i )->ldm ), + NULL, 0, FALSE + ); + + m_i += 1; } - lpgemm_set_node_params - ( - ( post_op_list + i ), POST_OPS_DOWNSCALE, - post_op_unparsed->sum.zero_point, - meta_arg, scale_buffer, - post_op_unparsed->sum.scale_factor, FALSE - ); break; + case MATRIX_MUL: + { + if ( ( ( post_op_unparsed->matrix_mul + mul_i )->matrix == NULL ) || + ( ( post_op_unparsed->matrix_mul + mul_i )->ldm <= 0 ) ) + { + bli_print_msg(" Post_op.matrix_add attributes are invalid. Exiting..", + __FILE__, __LINE__ ); + return BLIS_NULL_POINTER; + } + + lpgemm_set_node_params + ( + ( post_op_list + i ), POST_OPS_MATRIX_MUL, + ( post_op_unparsed->matrix_mul + mul_i )->matrix, + meta_arg, &( ( post_op_unparsed->matrix_mul + mul_i )->ldm ), + NULL, 0, FALSE + ); + + mul_i += 1; + } + break; default: break; } diff --git a/addon/aocl_gemm/frame/lpgemm_post_ops.h b/addon/aocl_gemm/frame/lpgemm_post_ops.h index ed1d3ed86b..7565ef293b 100644 --- a/addon/aocl_gemm/frame/lpgemm_post_ops.h +++ b/addon/aocl_gemm/frame/lpgemm_post_ops.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -45,45 +45,76 @@ typedef enum POST_OPS_GELU_ERF = 5, POST_OPS_CLIP = 6, POST_OPS_DOWNSCALE = 7, - POST_OPS_SUM = 8, + POST_OPS_MATRIX_ADD = 8, + POST_OPS_SWISH = 9, + POST_OPS_MATRIX_MUL = 10, + POST_OPS_SUM = 11, + } LPGEMM_POST_OP_CODE; // Used as an internal structure. typedef struct lpgemm_post_op_t { - LPGEMM_POST_OP_CODE op_code; - void* op_args1; - void* op_args2; // alpha, zero_point, storage order - void* op_args3; // beta, downscale buffer/original C matrix + uint64_t op_code; + void* op_args1; // zero_point, bias, sum_buff + void* op_args2; // alpha, storage order, sum_zero_point + void* op_args3; // beta, zero_point_len void* scale_factor; + dim_t scale_factor_len; bool is_power_of_2; struct lpgemm_post_op_t* next; } lpgemm_post_op; +// Used as an internal structure. +typedef struct lpgemm_pre_op_t +{ + uint64_t op_code; + void *scale_factor; + dim_t scale_factor_len; + void *zp; + dim_t zp_len; + struct lpgemm_pre_op_t *next; +} lpgemm_pre_op; + // Used as an internal structure. typedef struct lpgemm_post_op_attr_t { - dim_t post_op_c_i; - dim_t post_op_c_j; - dim_t rs_c_downscale; - dim_t cs_c_downscale; + uint64_t post_op_c_i; + uint64_t post_op_c_j; + uint64_t rs_c_downscale; + uint64_t cs_c_downscale; void* buf_downscale; - bool is_first_k; - bool is_last_k; - AOCL_STORAGE_TYPE c_stor_type; - dim_t b_sum_offset; + uint64_t is_first_k; + uint64_t is_last_k; + uint64_t c_stor_type; + uint64_t b_sum_offset; int32_t* b_col_sum_vec; int16_t* b_col_sum_vec_s16; + void* pre_op_scale_factor; + dim_t pre_op_scale_factor_len; + dim_t pre_op_off; } lpgemm_post_op_attr; + err_t lpgemm_translate_to_post_ops_list ( aocl_post_op* post_op_unparsed, lpgemm_post_op* post_op_list, void* scale_buffer, - void* meta_arg + void* meta_arg, + dim_t m, + dim_t n ); +err_t lpgemm_translate_to_pre_ops_list + ( + aocl_pre_op *pre_op_unparsed, + lpgemm_pre_op *pre_op_list, + dim_t m, + dim_t n, + dim_t k + ); + #define POST_OP_LABEL_LASTK_SAFE_JUMP \ if ( ( post_ops_attr.is_last_k == TRUE ) && ( post_ops_list_temp != NULL ) ) \ { \ diff --git a/addon/aocl_gemm/frame/lpgemm_types.h b/addon/aocl_gemm/frame/lpgemm_types.h index 28f210a067..b0c69079b3 100644 --- a/addon/aocl_gemm/frame/lpgemm_types.h +++ b/addon/aocl_gemm/frame/lpgemm_types.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -37,9 +37,9 @@ typedef enum { - INT8 = 0, - INT16 = 1, - INT32 = 2 + LPGEMM_INT8 = 0, + LPGEMM_INT16 = 1, + LPGEMM_INT32 = 2 } AOCL_ARRAY_TYPE; // Enum to denote the storage data type (output matrix). @@ -63,14 +63,16 @@ typedef enum // Enum name template:A_mat_type ## B_mat_type ## Accumulate_type ## C_mat_type. typedef enum { - U8S8S16OS16 = 0, // uint8_t - A, int8_t - B, int16_t - C - U8S8S32OS32 = 1, // uint8_t - A, int8_t - B, int32_t - C - F32F32F32OF32 = 2, // float - A, float - B, float - C + U8S8S16OS16 = 0, // uint8_t - A, int8_t - B, int16_t - C + U8S8S32OS32 = 1, // uint8_t - A, int8_t - B, int32_t - C + F32F32F32OF32 = 2, // float - A, float - B, float - C BF16BF16F32OF32 = 3, // bf16 - A, bf16 - B, float - C - S8S8S32OS32 = 4, // int8_t - A, int8_t - B, int32_t - C - S8S8S16OS16 = 5 // int8_t - A, int8_t - B, int16_t - C + S8S8S32OS32 = 4, // int8_t - A, int8_t - B, int32_t - C + S8S8S16OS16 = 5, // int8_t - A, int8_t - B, int16_t - C + U8S4S32OS32 = 6, // Only used for reordering int4_t B matrix. + BF16S4F32OF32 = 7 // Only used for reordering int4_t B matrix. } AOCL_OPERATION_TYPE; -#define AOCL_OPERATION_TYPE_LEN 6 +#define AOCL_OPERATION_TYPE_LEN 8 typedef enum { @@ -80,11 +82,20 @@ typedef enum } AOCL_UTIL_OPERATION_TYPE; #define AOCL_UTIL_OPERATION_TYPE_LEN 3 +typedef enum +{ + BF16OF32 = 0, + F32OF32 = 1 +} AOCL_ELTWISE_OPS_OPERATION_TYPE; +#define AOCL_ELTWISE_OPS_OPERATION_TYPE_LEN 2 + typedef enum { UNPACKED = 0, PACK = 1, - REORDERED = 2, + PACK_KC = 2, + PACK_NR = 3, + REORDERED = 4, } AOCL_MEMORY_TAG; typedef enum @@ -143,9 +154,16 @@ typedef struct void_fp kern_fun_ptr; void_fp packa_fun_ptr; void_fp packb_fun_ptr; + void_fp packsclb_fun_ptr; lpgemm_pack_strides_t pack_s; } lpgemm_cntx_t; +typedef struct +{ + lpgemm_block_size_t blksz; + void_fp eltwise_ops_kern_fun_ptr; +} lpgemm_eltwise_ops_cntx_t; + typedef struct { void_fp kern_fun_ptr; diff --git a/addon/aocl_gemm/frame/s8s8s16/lpgemm_reorder_s8s16.c b/addon/aocl_gemm/frame/s8s8s16/lpgemm_reorder_s8s16.c index 474014d5df..bef75e7315 100644 --- a/addon/aocl_gemm/frame/s8s8s16/lpgemm_reorder_s8s16.c +++ b/addon/aocl_gemm/frame/s8s8s16/lpgemm_reorder_s8s16.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -71,7 +71,7 @@ void aocl_reorderb_nr32_s8s8s16o16 // To access the last row of B matrix - Column sum of B matrix int16_t* pack_b_column_sum = ( int16_t* ) ( b_reorder->storage.aligned_buffer + ( sizeof( int8_t ) * n_updated * k_updated )); - for (int idx = 0; idx < n_updated; idx++ ) + for (dim_t idx = 0; idx < n_updated; idx++ ) { *( pack_b_column_sum + idx ) = 0; } @@ -169,16 +169,6 @@ void aocl_reorderb_nr32_s8s8s16o16 adjust_B_panel_reordered_jc( &jc, jc_cur_loop ); } } - // for (int i =0; i< k_updated; i++) - // { - // for (int j=0; j< n_updated; j++) - // { - // printf(" %d ", *( int8_t* )(b->storage.aligned_buffer + i*n_updated + j )); - // } - // printf(" \n "); - // } - // for (int i =0; i< n_updated; i++) - // printf(" %d ", *(pack_b_column_sum + i)); // Changing the packed matrix properties in the packed matrix object b_reorder->rs = rs_b_reorder; diff --git a/addon/aocl_gemm/frame/s8s8s16/lpgemm_s8s8s16.c b/addon/aocl_gemm/frame/s8s8s16/lpgemm_s8s8s16.c index 974ff4f3eb..1eec1c56c2 100644 --- a/addon/aocl_gemm/frame/s8s8s16/lpgemm_s8s8s16.c +++ b/addon/aocl_gemm/frame/s8s8s16/lpgemm_s8s8s16.c @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -39,6 +39,7 @@ #include "lpgemm_utils_s8.h" #include "lpgemm_config.h" #include "lpgemm_thrinfo_utils.h" +#include "lpgemm_packa_s16.h" // Kernel function prototypes typedef void (*lpgemm_rowvar_s16_s8) @@ -62,6 +63,149 @@ typedef void (*lpgemm_rowvar_s16_s8) lpgemm_post_op_attr ); + + +LPGEMV(int8_t,int8_t,int16_t,s8s8s16os16) +{ + dim_t KC = lcntx->blksz.KC; + dim_t MC = lcntx->blksz.MC; + + // Strides are updated based on matrix packing/reordering. + int8_t* a_use = ( int8_t* )a; + inc_t rs_a_use = rs_a; + inc_t cs_a_use = cs_a; + + int8_t* b_use = ( int8_t* )b; + inc_t rs_b_use = rs_b; + inc_t cs_b_use = cs_b; + + int16_t *c_use = NULL; + + lpgemm_post_op_attr post_ops_attr; + post_ops_attr.c_stor_type = c_downscale; + if (c_downscale < S16) post_ops_attr.buf_downscale = c; + else post_ops_attr.buf_downscale = NULL; + + siz_t mem_a_size_req = 0; + siz_t mem_b_size_req = 0; + + mem_t mem_a = BLIS_MEM_INITIALIZER; + mem_t mem_b = BLIS_MEM_INITIALIZER; + + int8_t* pack_a_buffer; + int8_t* pack_b_buffer; + + // Generate thrinfo objects for jc and ic loops from lpgemm_thrinfo_t. + thrinfo_t thread_jc; + thrinfo_t thread_ic; + + lpgemm_gen_thrinfo( thread, &thread_jc, &thread_ic ); + + // Increased MR from 6 to 8 to make use of 16 ymm regs + dim_t MR = 8; + + // Pack B matrix if rs_b > 1 + if( ( mtag_b == PACK ) ) + { + mem_b_size_req = sizeof( int8_t ) * k + sizeof( int16_t ); + + lpgemm_alloc_mem_panel + ( + mem_b_size_req, BLIS_BUFFER_FOR_GEN_USE, + &mem_b, rntm + ); + + pack_b_buffer = ( int8_t* ) bli_mem_buffer( &mem_b ); + + int16_t* pack_b_column_sum = ( int16_t* ) ( pack_b_buffer + + ( sizeof( int8_t ) * k )); + + *pack_b_column_sum = 0; + + for( dim_t k0 = 0; k0 < k; k0++ ) + { + pack_b_buffer[k0] = b[ k0*rs_b ]; + *pack_b_column_sum += pack_b_buffer[k0]; + } + *pack_b_column_sum *= 128; + post_ops_attr.b_col_sum_vec_s16 = pack_b_column_sum; + + b_use = pack_b_buffer; + rs_b_use = 1; + cs_b_use = 1; + } + else if ( mtag_b == REORDERED ) + { + post_ops_attr.b_col_sum_vec_s16 = ( int16_t* ) ( b + k ); + } + + // Compute the IC loop thread range for the current thread. + dim_t ic_start, ic_end; + thread_ic.n_way = ( thread_ic.n_way == 1 ) ? + ( thread->n_threads ) : ( thread_ic.n_way ); + thread_ic.work_id = thread->tid; + bli_thread_range_sub(&thread_ic, m, MR, FALSE, &ic_start, &ic_end); + + for (dim_t ic = ic_start; ic < ic_end; ic += MC) + { + dim_t mc0 = bli_min((ic_end - ic), MC); + + a_use = (int8_t*)a + ic * rs_a; + + c_use = c + ic * rs_c; + + post_ops_attr.post_op_c_i = ic; + post_ops_attr.post_op_c_j = 0; + post_ops_attr.rs_c_downscale = rs_c; + + if( mtag_a == PACK ) + { + mem_a_size_req = sizeof( int8_t ) * mc0 * k; + + lpgemm_alloc_mem_panel + ( + mem_a_size_req, BLIS_BUFFER_FOR_GEN_USE, + &mem_a, rntm + ); + + pack_a_buffer = ( int8_t* ) bli_mem_buffer( &mem_a ); + + ( ( packa_s16 ) lcntx->packa_fun_ptr ) + ( + ( uint8_t* )pack_a_buffer, + ( uint8_t* )( a + ( rs_a * ic )), rs_a, cs_a, + mc0, k, + &rs_a_use, &cs_a_use + ); + a_use = pack_a_buffer; + } + + // Call lpgemv_n_one kernel + lpgemv_n_one_s8s8s16os16 + ( + mc0, k, + a_use, rs_a_use, cs_a_use, mtag_a, + b_use, rs_b_use, cs_b_use, mtag_b, + c_use, rs_c, cs_c, + alpha, beta, + MR, KC, + post_op_list, + &post_ops_attr + ); + } + + // Release pack buffers + if( mtag_a == PACK && bli_mem_is_alloc( &mem_a ) ) + { + bli_pba_release(rntm, &mem_a); + } + if( mtag_b == PACK && bli_mem_is_alloc( &mem_b ) ) + { + bli_pba_release(rntm, &mem_b); + } +} + + // B should always be packed. LPGEMM_5LOOP(int8_t,int8_t,int16_t,s8s8s16o16) { @@ -79,10 +223,29 @@ LPGEMM_5LOOP(int8_t,int8_t,int16_t,s8s8s16o16) return; } + + if( n == 1 ) + { + lpgemv_rowvar_s8s8s16os16( m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + c, rs_c, cs_c, + alpha, + beta, + rntm, + thread, + lcntx, + post_op_list, + c_downscale ); + return; + } + + const int8_t *b_use; const int8_t *a_use; dim_t rs_a_use = rs_a; dim_t cs_a_use = cs_a; + dim_t a_block_stride = 0; dim_t rs_b_use = rs_b; dim_t cs_b_use = cs_b; @@ -92,6 +255,11 @@ LPGEMM_5LOOP(int8_t,int8_t,int16_t,s8s8s16o16) dim_t rs_c_use = rs_c; dim_t rs_c_downscale = rs_c; + // Pack buffer for A. + int8_t* pack_a_buffer_s8s8s16o16; + mem_t mem_a = BLIS_MEM_INITIALIZER; + siz_t mem_a_size_req = 0; + // Pack buffer for B. int8_t *pack_b_buffer_s8s8s16o16; mem_t mem_b = BLIS_MEM_INITIALIZER; @@ -280,7 +448,7 @@ LPGEMM_5LOOP(int8_t,int8_t,int16_t,s8s8s16o16) ( pack_b_buffer_s8s8s16o16 + (jc_packb_start * kc0_updated), - pack_b_column_sum + ( cs_b * jc_packb_start ), + pack_b_column_sum + ( cs_b * jc_packb_start ), (b + (rs_b * pc) + (cs_b * jc) + (cs_b * jc_packb_start)), rs_b, @@ -339,10 +507,48 @@ LPGEMM_5LOOP(int8_t,int8_t,int16_t,s8s8s16o16) c_use_ic = c_use_jc + ( rs_c_use * ic ); } - a_use = a + (rs_a * ic) + (cs_a * pc); - cs_a_use = 1; + // Matrix A packed and reordered code path is not triggerred + // currently for row-major inputs since we do not support it yet. + // Pack is enabled for column-major inputs to transform into + // row-major inputs as kernel expects row storage format. + if ( mtag_a == PACK ) + { + mem_a_size_req = sizeof( uint8_t ) * mc0 * kc0_updated; + + lpgemm_alloc_mem_panel + ( + mem_a_size_req, BLIS_BUFFER_FOR_A_BLOCK, + &mem_a, rntm + ); + pack_a_buffer_s8s8s16o16 = ( int8_t* )bli_mem_buffer( &mem_a ); + + ( ( packa_s16 )lcntx->packa_fun_ptr ) + ( + ( uint8_t* )pack_a_buffer_s8s8s16o16, + ( uint8_t* )( a + ( rs_a * ic ) + ( cs_a * pc ) ), rs_a, cs_a, + mc0, kc0, + &rs_a_use, &cs_a_use + ); + a_use = pack_a_buffer_s8s8s16o16; + + if( cs_a == 1 ) + { + a_block_stride = kc0_updated; + } + + else + { + a_block_stride = rs_a_use; + } - dim_t a_block_stride = rs_a; + } + + else + { + a_use = a + ( rs_a * ic ) + ( cs_a * pc ); + cs_a_use = 1; + a_block_stride = rs_a; + } post_ops_attr.b_sum_offset = 0; diff --git a/addon/aocl_gemm/frame/s8s8s32/lpgemm_reorder_s8.c b/addon/aocl_gemm/frame/s8s8s32/lpgemm_reorder_s8.c index ece6c48762..fcf5ec622c 100644 --- a/addon/aocl_gemm/frame/s8s8s32/lpgemm_reorder_s8.c +++ b/addon/aocl_gemm/frame/s8s8s32/lpgemm_reorder_s8.c @@ -50,8 +50,9 @@ void reorderb_nr64_s8s8s32o32 dim_t NC = lcntx->blksz.NC; dim_t KC = lcntx->blksz.KC; dim_t NR = lcntx->blksz.NR; - + dim_t rs_b = b->rs; + dim_t cs_b = b->cs; dim_t rs_b_reorder; dim_t cs_b_reorder; @@ -68,7 +69,10 @@ void reorderb_nr64_s8s8s32o32 dim_t n_threads = bli_rntm_num_threads( rntm ); n_threads = ( n_threads > 0 ) ? n_threads : 1; - int32_t* pack_b_column_sum = ( int32_t* ) ( b_reorder->storage.aligned_buffer + ( sizeof( int8_t ) * n_updated * k_updated )); + int32_t* pack_b_column_sum = + ( int32_t* ) ( b_reorder->storage.aligned_buffer + + ( sizeof( int8_t ) * n_updated * k_updated )); + for ( dim_t idx = 0; idx < n_updated; idx++ ) { *( pack_b_column_sum + idx ) = 0; @@ -159,8 +163,8 @@ void reorderb_nr64_s8s8s32o32 ( jc_cur_loop_rem * kc0_updated ) ), pack_b_column_sum + jc, ( ( ( int8_t* )b->storage.aligned_buffer ) + - ( rs_b * pc ) + jc ), - rs_b, nc0, kc0, &rs_b_reorder, &cs_b_reorder + ( rs_b * pc ) + jc * cs_b), + rs_b, cs_b, nc0, kc0, &rs_b_reorder, &cs_b_reorder ); } adjust_B_panel_reordered_jc( &jc, jc_cur_loop ); diff --git a/addon/aocl_gemm/frame/s8s8s32/lpgemm_s8s8s32.c b/addon/aocl_gemm/frame/s8s8s32/lpgemm_s8s8s32.c index 21fa102fd4..24307b89ad 100644 --- a/addon/aocl_gemm/frame/s8s8s32/lpgemm_s8s8s32.c +++ b/addon/aocl_gemm/frame/s8s8s32/lpgemm_s8s8s32.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -40,6 +40,7 @@ #include "lpgemm_utils_s8.h" #include "lpgemm_thrinfo_utils.h" #include "lpgemm_config.h" +#include "lpgemm_packa.h" // Kernel function prototypes typedef void (*lpgemm_rowvar_s32_s8) @@ -63,6 +64,300 @@ typedef void (*lpgemm_rowvar_s32_s8) lpgemm_post_op_attr ); +#ifdef BLIS_KERNELS_ZEN4 + +LPGEMV(int8_t,int8_t,int32_t,s8s8s32o32) +{ + dim_t NC = lcntx->blksz.NC; + dim_t KC = lcntx->blksz.KC; + dim_t MC = lcntx->blksz.MC; + dim_t NR = lcntx->blksz.NR; + + // Strides are updated based on matrix packing/reordering. + int8_t* a_use = ( int8_t* )a; + inc_t rs_a_use = rs_a; + inc_t cs_a_use = cs_a; + + int8_t* b_use = ( int8_t* )b; + dim_t rs_b_use = rs_b; + inc_t cs_b_use = cs_b; + + int32_t *c_use = NULL; + + int32_t* pack_b_column_sum = NULL; + + lpgemm_post_op_attr post_ops_attr; + post_ops_attr.c_stor_type = c_downscale; + if (c_downscale < S32) post_ops_attr.buf_downscale = c; + else post_ops_attr.buf_downscale = NULL; + + siz_t mem_a_size_req = 0; + siz_t mem_b_size_req = 0; + + mem_t mem_a = BLIS_MEM_INITIALIZER; + mem_t mem_b = BLIS_MEM_INITIALIZER; + + int8_t* pack_b_buffer_s8s8s32os32; + int8_t* pack_a_buffer_s8s8s32os32; + + // Generate thrinfo objects for jc and ic loops from lpgemm_thrinfo_t. + thrinfo_t thread_jc; + thrinfo_t thread_ic; + + lpgemm_gen_thrinfo( thread, &thread_jc, &thread_ic ); + + if( n == 1 ) + { + // Increased MR from 6 to 16 to make use of 32 ZMM registers + dim_t MR = 16; + + // pack B matrix if rs_b > 1 + if( ( mtag_b == PACK ) ) + { + mem_b_size_req = sizeof( int8_t ) * k + sizeof( int32_t ); + + lpgemm_alloc_mem_panel + ( + mem_b_size_req, BLIS_BUFFER_FOR_GEN_USE, + &mem_b, rntm + ); + + pack_b_buffer_s8s8s32os32 = ( int8_t* ) bli_mem_buffer( &mem_b ); + + int32_t* pack_b_column_sum = ( int32_t* ) ( pack_b_buffer_s8s8s32os32 + + ( sizeof( int8_t ) * k )); + + *pack_b_column_sum = 0; + + for( dim_t k0 = 0; k0 < k; k0++ ) + { + pack_b_buffer_s8s8s32os32[k0] = b[ k0*rs_b ]; + *pack_b_column_sum += pack_b_buffer_s8s8s32os32[k0]; + } + *pack_b_column_sum *= 128; + post_ops_attr.b_col_sum_vec = pack_b_column_sum; + + b_use = pack_b_buffer_s8s8s32os32; + rs_b_use = 1; + cs_b_use = 1; + } + else if( mtag_b == REORDERED ) + { + post_ops_attr.b_col_sum_vec = ( int32_t* )( b + k ); + } + + // Compute the IC loop thread range for the current thread. + dim_t ic_start, ic_end; + thread_ic.n_way = ( thread_ic.n_way == 1 ) ? + ( thread->n_threads ) : ( thread_ic.n_way ); + thread_ic.work_id = thread->tid; + bli_thread_range_sub(&thread_ic, m, MR, FALSE, &ic_start, &ic_end); + + for (dim_t ic = ic_start; ic < ic_end; ic += MC) + { + dim_t mc0 = bli_min((ic_end - ic), MC); + + const int8_t *a_use = a + ic * rs_a; + c_use = c + ic * rs_c; + + post_ops_attr.post_op_c_i = ic; + post_ops_attr.post_op_c_j = 0; + post_ops_attr.rs_c_downscale = rs_c; + + if( mtag_a == PACK ) + { + mem_a_size_req = sizeof( int8_t ) * mc0 * k; + + lpgemm_alloc_mem_panel + ( + mem_a_size_req, BLIS_BUFFER_FOR_GEN_USE, + &mem_a, rntm + ); + + pack_a_buffer_s8s8s32os32 = (int8_t*)bli_mem_buffer( &mem_a ); + + ( ( packa_s32 ) lcntx->packa_fun_ptr ) + ( + ( uint8_t* ) pack_a_buffer_s8s8s32os32, + ( uint8_t* )( a + ( rs_a * ic )), rs_a, cs_a, + mc0, k, + &rs_a_use, &cs_a_use + ); + a_use = pack_a_buffer_s8s8s32os32; + } + // Call lpgemv_n_one kernel + lpgemv_n_one_s8s8s32os32 + ( + mc0, k, + a_use, rs_a_use, cs_a_use, mtag_a, + b_use, rs_b_use, cs_b_use, mtag_b, + c_use, rs_c, cs_c, + alpha, beta, + MR, KC, + post_op_list, + &post_ops_attr + ); + } + + // Release pack buffers + if( mtag_a == PACK && bli_mem_is_alloc( &mem_a ) ) + { + bli_pba_release(rntm, &mem_a); + } + if( mtag_b == PACK && bli_mem_is_alloc( &mem_b ) ) + { + bli_pba_release(rntm, &mem_b); + } + } + else + { + dim_t jc_start, jc_end; + thread_jc.n_way = ( thread_jc.n_way == 1 ) ? + ( thread->n_threads ) : ( thread_jc.n_way ); + thread_jc.work_id = thread->tid; + bli_thread_range_sub(&thread_jc, n, NR, FALSE, &jc_start, &jc_end); + + dim_t packb_min_NR = get_packb_s8s8s32o32_min_NR(); + + dim_t k_updated = make_multiple_of_n( k, 4 ); + dim_t n_updated = make_multiple_of_n( n, 16 ); + + rs_a_use = rs_a; + cs_a_use = 4; + + + if ( mtag_a == PACK ) + { + mem_a_size_req = sizeof( uint8_t ) * k; + + lpgemm_alloc_mem_panel + ( + mem_a_size_req, BLIS_BUFFER_FOR_GEN_USE, + &mem_a, rntm + ); + + pack_a_buffer_s8s8s32os32 = + ( int8_t* ) bli_mem_buffer( &mem_a ); + + ( ( packa_s32 )lcntx->packa_fun_ptr ) + ( + ( uint8_t* )pack_a_buffer_s8s8s32os32, + ( uint8_t* )a, rs_a, cs_a, + 1, k, + &rs_a_use, &cs_a_use + ); + + a_use = pack_a_buffer_s8s8s32os32; + } + + for (dim_t jc = jc_start; jc < jc_end; jc += NC) + { + dim_t nc0 = bli_min((jc_end - jc), NC); + c_use = c + jc; + + dim_t jc_cur_loop = jc; + dim_t jc_cur_loop_rem = 0; + dim_t n_sub_updated = 0; + + if (mtag_b == REORDERED) + { + get_B_panel_reordered_start_offset_width( + jc, n, NC, packb_min_NR, + &jc_cur_loop, &jc_cur_loop_rem, + &nc0, &n_sub_updated ); + + b_use = (int8_t*) ( b + (jc_cur_loop * k_updated ) ); + + lpgemm_get_packb_strides( lcntx, &rs_b_use, &cs_b_use ); + + post_ops_attr.b_col_sum_vec = ( (int32_t*)( b + + ( k_updated * n_updated ) ) ) + + jc; + } + else if( mtag_b == PACK ) + { + dim_t nc0_updated = make_multiple_of_n( nc0, packb_min_NR ); + + mem_b_size_req = sizeof( int8_t ) * nc0_updated * k_updated + + ( nc0_updated * sizeof( int32_t ) ); + + n_sub_updated = nc0_updated; + + lpgemm_alloc_mem_panel + ( + mem_b_size_req, BLIS_BUFFER_FOR_B_PANEL, + &mem_b, rntm + ); + + pack_b_buffer_s8s8s32os32 = + ( int8_t* ) bli_mem_buffer( &mem_b ); + + + pack_b_column_sum = ( int32_t* )( pack_b_buffer_s8s8s32os32 + + ( sizeof( int8_t ) * nc0_updated + * k_updated ) ); + + for (dim_t idx = 0; idx < nc0; idx++ ) + { + *( pack_b_column_sum + idx ) = 0; + } + + for ( dim_t pc = 0; pc < k; pc += KC ) + { + dim_t kc0 = bli_min( ( k - pc ), KC ); + + ( ( packb_s32_s8 )lcntx->packb_fun_ptr ) + ( + ( pack_b_buffer_s8s8s32os32 ) + + ( n_sub_updated * pc ), + pack_b_column_sum, + ( b + ( rs_b * pc ) + (jc * cs_b)), + rs_b, cs_b, nc0, kc0, &rs_b_use, &cs_b_use + ); + } + + b_use = pack_b_buffer_s8s8s32os32; + post_ops_attr.b_col_sum_vec = pack_b_column_sum; + } + + post_ops_attr.post_op_c_i = 0; + post_ops_attr.post_op_c_j = jc; + post_ops_attr.rs_c_downscale = rs_c; + post_ops_attr.b_sum_offset = 0; + + lpgemv_m_one_s8s8s32os32 + ( + nc0, k, + a_use, rs_a_use, cs_a_use, mtag_a, + b_use, rs_b_use, cs_b_use, mtag_b, + c_use, rs_c, cs_c, + alpha, beta, + NR, KC, + n_sub_updated, + jc_cur_loop_rem, + post_op_list, + &post_ops_attr + ); + + if (mtag_b == REORDERED) + { + adjust_B_panel_reordered_jc(&jc, jc_cur_loop); + } + } // jc loop + + // Release pack buffers. + if ( mtag_b == PACK && bli_mem_is_alloc( &mem_b ) ) + { + bli_pba_release( rntm, &mem_b ); + } + if( mtag_a == PACK && bli_mem_is_alloc( &mem_a ) ) + { + bli_pba_release(rntm, &mem_a); + } + } +} + +#endif // B should always be packed. LPGEMM_5LOOP(int8_t,int8_t,int32_t,s8s8s32o32) { @@ -78,6 +373,26 @@ LPGEMM_5LOOP(int8_t,int8_t,int32_t,s8s8s32o32) return; } +#ifdef BLIS_KERNELS_ZEN4 + + if( ( m == 1 ) || ( n == 1 ) ) + { + lpgemv_rowvar_s8s8s32o32( m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + c, rs_c, cs_c, + alpha, + beta, + rntm, + thread, + lcntx, + post_op_list, + c_downscale ); + return; + } + +#endif + // Strides are updated based on matrix packing/reordering. const int8_t* a_use = NULL; dim_t rs_a_use = rs_a; @@ -232,7 +547,8 @@ LPGEMM_5LOOP(int8_t,int8_t,int32_t,s8s8s32o32) // which is a multiple of 16. Subsequently the nc0 offsets used // for packed/reordered buffers needs to be updated.pack - mem_b_size_req = sizeof( int8_t ) * nc0_updated * kc0_updated + ( nc0_updated * sizeof( int32_t ) ); + mem_b_size_req = sizeof( int8_t ) * nc0_updated * kc0_updated + + ( nc0_updated * sizeof( int32_t ) ); lpgemm_alloc_mem_panel ( @@ -267,7 +583,9 @@ LPGEMM_5LOOP(int8_t,int8_t,int32_t,s8s8s32o32) if ( pc == 0) { - pack_b_column_sum = ( int32_t* )( pack_b_buffer_s8s8s32o32 + ( sizeof( int8_t ) * nc0_updated * kc0_updated ) ); + pack_b_column_sum = ( int32_t* )( pack_b_buffer_s8s8s32o32 + + ( sizeof( int8_t ) * nc0_updated + * kc0_updated ) ); } // Ensure thread ranges are valid, especially cases where no: @@ -289,7 +607,7 @@ LPGEMM_5LOOP(int8_t,int8_t,int32_t,s8s8s32o32) pack_b_buffer_s8s8s32o32 + ( jc_packb_start * kc0_updated ), pack_b_column_sum + ( cs_b * jc_packb_start ), ( b + ( rs_b * pc ) + ( cs_b * jc ) + - ( cs_b * jc_packb_start ) ), rs_b, + ( cs_b * jc_packb_start ) ), rs_b, cs_b, ( jc_packb_end - jc_packb_start ), kc0, &rs_b_use, &cs_b_use ); @@ -349,7 +667,7 @@ LPGEMM_5LOOP(int8_t,int8_t,int32_t,s8s8s32o32) // currently since we do not support it yet. if ( mtag_a == PACK ) { - mem_a_size_req = sizeof( int8_t ) * mc0 * kc0_updated; + mem_a_size_req = sizeof( uint8_t ) * mc0 * kc0_updated; lpgemm_alloc_mem_panel ( @@ -358,17 +676,25 @@ LPGEMM_5LOOP(int8_t,int8_t,int32_t,s8s8s32o32) ); pack_a_buffer_s8s8s32o32 = ( int8_t* )bli_mem_buffer( &mem_a ); - ( ( packa_s32_s8 )lcntx->packa_fun_ptr ) + ( ( packa_s32 )lcntx->packa_fun_ptr ) ( - pack_a_buffer_s8s8s32o32, - ( a + ( rs_a * ic ) + pc ), rs_a, + ( uint8_t* )pack_a_buffer_s8s8s32o32, + ( uint8_t* )( a + ( rs_a * ic ) + ( cs_a * pc ) ), rs_a, cs_a, mc0, kc0, &rs_a_use, &cs_a_use ); a_use = pack_a_buffer_s8s8s32o32; - a_block_stride = kc0_updated; - } + if( cs_a == 1 ) + { + a_block_stride = kc0_updated; + } + + else + { + a_block_stride = rs_a_use; + } + } else { a_use = a + ( rs_a * ic ) + ( cs_a * pc ); diff --git a/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.c b/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.c index ad0e7f10d5..50e42e6c4e 100644 --- a/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.c +++ b/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.c @@ -37,6 +37,7 @@ #include "lpgemm_thread_decor_openmp.h" #include "lpgemm_types.h" #include "lpgemm_5loop_interface_apis.h" +#include "lpgemm_eltwise_ops_interface_apis.h" #ifdef BLIS_ENABLE_OPENMP @@ -298,18 +299,16 @@ BLIS_INLINE void lpgemm_s16o16_get_threading dim_t NR = lpgemm_get_block_size_NR_global_cntx( op_type ); dim_t MR = lpgemm_get_block_size_MR_global_cntx( op_type ); - dim_t mr_blks = ( m + MR - 1 ) / MR; - dim_t nr_blks = ( n + NR - 1 ) / NR; if ( n <= NR ) { - ( *ic_ways ) = ( mr_blks < ( *n_threads ) ) ? mr_blks : ( *n_threads ); + ( *ic_ways ) = ( *n_threads ); ( *jc_ways ) = 1; ( *n_threads ) = ( *ic_ways ) * ( *jc_ways ); } else if ( m <= MR ) { - ( *jc_ways ) = ( nr_blks < ( *n_threads ) ) ? nr_blks : ( *n_threads ); + ( *jc_ways ) = ( *n_threads ); ( *ic_ways ) = 1; ( *n_threads ) = ( *ic_ways ) * ( *jc_ways ); } @@ -317,26 +316,6 @@ BLIS_INLINE void lpgemm_s16o16_get_threading { // If BLIS_NUM_THREADS are set, generate jc,ic from the same. bli_thread_partition_2x2( ( *n_threads ), m, n, ic_ways, jc_ways ); - if ( ( mr_blks < ( *ic_ways ) ) && ( nr_blks < ( *jc_ways ) ) ) - { - ( *ic_ways ) = mr_blks; - ( *jc_ways ) = nr_blks; - ( *n_threads ) = ( *ic_ways ) * ( *jc_ways ); - } - else if ( mr_blks < ( *ic_ways ) ) - { - ( *ic_ways ) = mr_blks; - dim_t rem_jc_ways = ( dim_t )( ( *n_threads ) / ( *ic_ways ) ); - ( *jc_ways ) = ( rem_jc_ways < nr_blks ) ? rem_jc_ways : nr_blks; - ( *n_threads ) = ( *ic_ways ) * ( *jc_ways ); - } - else if ( nr_blks < ( *jc_ways ) ) - { - ( *jc_ways ) = nr_blks; - dim_t rem_ic_ways = ( dim_t )( ( *n_threads ) / ( *jc_ways ) ); - ( *ic_ways ) = ( rem_ic_ways < mr_blks ) ? rem_ic_ways : mr_blks; - ( *n_threads ) = ( *ic_ways ) * ( *jc_ways ); - } } } else @@ -424,13 +403,13 @@ BLIS_INLINE void lpgemm_s32o32_get_threading if ( n <= NR ) { - ( *ic_ways ) = ( mr_blks < ( *n_threads ) ) ? mr_blks : ( *n_threads ); + ( *ic_ways ) = ( *n_threads ); ( *jc_ways ) = 1; ( *n_threads ) = ( *ic_ways ) * ( *jc_ways ); } else if ( m <= MR ) { - ( *jc_ways ) = ( nr_blks < ( *n_threads ) ) ? nr_blks : ( *n_threads ); + ( *jc_ways ) = ( *n_threads ); ( *ic_ways ) = 1; ( *n_threads ) = ( *ic_ways ) * ( *jc_ways ); } @@ -438,27 +417,7 @@ BLIS_INLINE void lpgemm_s32o32_get_threading { // If BLIS_NUM_THREADS are set, generate jc,ic from the same. bli_thread_partition_2x2( ( *n_threads ), m, n, ic_ways, jc_ways ); - if ( ( mr_blks < ( *ic_ways ) ) && ( nr_blks < ( *jc_ways ) ) ) - { - ( *ic_ways ) = mr_blks; - ( *jc_ways ) = nr_blks; - ( *n_threads ) = ( *ic_ways ) * ( *jc_ways ); - } - else if ( mr_blks < ( *ic_ways ) ) - { - ( *ic_ways ) = mr_blks; - dim_t rem_jc_ways = ( dim_t )( ( *n_threads ) / ( *ic_ways ) ); - ( *jc_ways ) = ( rem_jc_ways < nr_blks ) ? rem_jc_ways : nr_blks; - ( *n_threads ) = ( *ic_ways ) * ( *jc_ways ); - } - else if ( nr_blks < ( *jc_ways ) ) - { - ( *jc_ways ) = nr_blks; - dim_t rem_ic_ways = ( dim_t )( ( *n_threads ) / ( *jc_ways ) ); - ( *ic_ways ) = ( rem_ic_ways < mr_blks ) ? rem_ic_ways : mr_blks; - ( *n_threads ) = ( *ic_ways ) * ( *jc_ways ); - } - else + if ( ( mr_blks >= ( *ic_ways ) ) && ( nr_blks >= ( *jc_ways ) ) ) { lpgemm_pnl_wrk_heur_adjust_ic_jc_ways ( @@ -552,13 +511,13 @@ BLIS_INLINE void lpgemm_bf16bf16f32of32_get_threading if ( n <= NR ) { - ( *ic_ways ) = ( mr_blks < ( *n_threads ) ) ? mr_blks : ( *n_threads ); + ( *ic_ways ) = ( *n_threads ); ( *jc_ways ) = 1; ( *n_threads ) = ( *ic_ways ) * ( *jc_ways ); } else if ( m <= MR ) { - ( *jc_ways ) = ( nr_blks < ( *n_threads ) ) ? nr_blks : ( *n_threads ); + ( *jc_ways ) = ( *n_threads ); ( *ic_ways ) = 1; ( *n_threads ) = ( *ic_ways ) * ( *jc_ways ); } @@ -566,27 +525,7 @@ BLIS_INLINE void lpgemm_bf16bf16f32of32_get_threading { // If BLIS_NUM_THREADS are set, generate jc,ic from the same. bli_thread_partition_2x2( ( *n_threads ), m, n, ic_ways, jc_ways ); - if ( ( mr_blks < ( *ic_ways ) ) && ( nr_blks < ( *jc_ways ) ) ) - { - ( *ic_ways ) = mr_blks; - ( *jc_ways ) = nr_blks; - ( *n_threads ) = ( *ic_ways ) * ( *jc_ways ); - } - else if ( mr_blks < ( *ic_ways ) ) - { - ( *ic_ways ) = mr_blks; - dim_t rem_jc_ways = ( dim_t )( ( *n_threads ) / ( *ic_ways ) ); - ( *jc_ways ) = ( rem_jc_ways < nr_blks ) ? rem_jc_ways : nr_blks; - ( *n_threads ) = ( *ic_ways ) * ( *jc_ways ); - } - else if ( nr_blks < ( *jc_ways ) ) - { - ( *jc_ways ) = nr_blks; - dim_t rem_ic_ways = ( dim_t )( ( *n_threads ) / ( *jc_ways ) ); - ( *ic_ways ) = ( rem_ic_ways < mr_blks ) ? rem_ic_ways : mr_blks; - ( *n_threads ) = ( *ic_ways ) * ( *jc_ways ); - } - else + if ( ( mr_blks >= ( *ic_ways ) ) && ( nr_blks >= ( *jc_ways ) ) ) { lpgemm_pnl_wrk_heur_adjust_ic_jc_ways ( @@ -658,13 +597,13 @@ BLIS_INLINE void lpgemm_f32f32f32of32_get_threading if ( n <= NR ) { - ( *ic_ways ) = ( mr_blks < ( *n_threads ) ) ? mr_blks : ( *n_threads ); + ( *ic_ways ) = ( *n_threads ); ( *jc_ways ) = 1; ( *n_threads ) = ( *ic_ways ) * ( *jc_ways ); } else if ( m <= MR ) { - ( *jc_ways ) = ( nr_blks < ( *n_threads ) ) ? nr_blks : ( *n_threads ); + ( *jc_ways ) = ( *n_threads ); ( *ic_ways ) = 1; ( *n_threads ) = ( *ic_ways ) * ( *jc_ways ); } @@ -672,27 +611,7 @@ BLIS_INLINE void lpgemm_f32f32f32of32_get_threading { // If BLIS_NUM_THREADS are set, generate jc,ic from the same. bli_thread_partition_2x2( ( *n_threads ), m, n, ic_ways, jc_ways ); - if ( ( mr_blks < ( *ic_ways ) ) && ( nr_blks < ( *jc_ways ) ) ) - { - ( *ic_ways ) = mr_blks; - ( *jc_ways ) = nr_blks; - ( *n_threads ) = ( *ic_ways ) * ( *jc_ways ); - } - else if ( mr_blks < ( *ic_ways ) ) - { - ( *ic_ways ) = mr_blks; - dim_t rem_jc_ways = ( dim_t )( ( *n_threads ) / ( *ic_ways ) ); - ( *jc_ways ) = ( rem_jc_ways < nr_blks ) ? rem_jc_ways : nr_blks; - ( *n_threads ) = ( *ic_ways ) * ( *jc_ways ); - } - else if ( nr_blks < ( *jc_ways ) ) - { - ( *jc_ways ) = nr_blks; - dim_t rem_ic_ways = ( dim_t )( ( *n_threads ) / ( *jc_ways ) ); - ( *ic_ways ) = ( rem_ic_ways < mr_blks ) ? rem_ic_ways : mr_blks; - ( *n_threads ) = ( *ic_ways ) * ( *jc_ways ); - } - else + if ( ( mr_blks >= ( *ic_ways ) ) && ( nr_blks >= ( *jc_ways ) ) ) { lpgemm_adjust_ic_jc_ways ( @@ -721,8 +640,8 @@ BLIS_INLINE void lpgemm_f32f32f32of32_get_threading if ( ( m >= MT ) && ( n >= NT ) && ( k >= KT ) ) { - if (((k <= page_size_b_floatx2 ) && ( m_ic > MT_2 ) && ( n_jc >= NT ) ) || - ((bli_cpuid_is_avx512_supported() == FALSE ) && (k > page_size_b_floatx2))) + if (((k <= page_size_b_floatx2) && (m_ic > MT_2) && (n_jc >= NT)) || + ((bli_cpuid_is_avx512_supported() == FALSE) && (k > page_size_b_floatx2))) { bli_rntm_set_pack_b( 1, rntm_g ); bli_rntm_set_pack_a( 1, rntm_g ); @@ -830,6 +749,305 @@ GEN_LPGEMM_OPENMP_DECORATOR(float,float,float,f32f32f32of32) GEN_LPGEMM_OPENMP_DECORATOR(int8_t,int8_t,int32_t,s8s8s32o32) GEN_LPGEMM_OPENMP_DECORATOR(int8_t,int8_t,int16_t,s8s8s16o16) +#define GEN_LPGEMM_OPENMP_DECORATOR_MP(A_type,B_type,C_type,LPGEMM_SFX) \ +void lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \ + ( \ + const dim_t m, \ + const dim_t n, \ + const dim_t k, \ + const A_type* a, \ + const dim_t rs_a, \ + const dim_t cs_a, \ + const AOCL_MEMORY_TAG mtag_a, \ + const B_type* b, \ + const dim_t rs_b, \ + const dim_t cs_b, \ + AOCL_MEMORY_TAG mtag_b, \ + C_type* c, \ + const dim_t rs_c, \ + const dim_t cs_c, \ + const C_type alpha, \ + const C_type beta, \ + rntm_t* rntm_g, \ + lpgemm_cntx_t* lcntx, \ + lpgemm_pre_op* pre_op_list, \ + lpgemm_post_op* post_op_list, \ + AOCL_STORAGE_TYPE c_downscale \ + ) \ +{ \ + dim_t n_threads; \ + \ + /* Factorization of threads along m and n dimension respectively.*/ \ + dim_t ic_ways; \ + dim_t jc_ways; \ + \ + lpgemm_bf16bf16f32of32_get_threading \ + ( \ + &n_threads, \ + &ic_ways, &jc_ways, \ + m, n, k, rntm_g \ + ); \ + \ + /* Decide whether to go with pack-based implementation + or kernel-level implementation */ \ + dim_t MC = lpgemm_get_block_size_MC_global_cntx( BF16BF16F32OF32 ); \ + if( ( m / ic_ways ) > MC ) mtag_b = PACK_KC; \ + else mtag_b = UNPACKED; \ + \ + /* Set the packing block allocator field of the rntm. This will be + * inherited by all of the child threads when they make local copies of + * the rntm below.*/ \ + bli_pba_rntm_set_pba( rntm_g ); \ + \ + thrcomm_t static_lpgemm_comms[BLIS_LPGEMM_NUM_STATIC_COMMS]; \ + thrcomm_t* cur_lpgemm_comms = static_lpgemm_comms; \ + err_t bli_errors = BLIS_SUCCESS; \ + \ + if ( jc_ways > BLIS_LPGEMM_NUM_STATIC_COMMS ) \ + { \ + cur_lpgemm_comms = bli_malloc_intl( jc_ways * sizeof( thrcomm_t ), &bli_errors ); \ + } \ + for ( dim_t i = 0; i < jc_ways; ++i ) \ + { \ + bli_thrcomm_init( ic_ways, &cur_lpgemm_comms[i] ); \ + } \ + \ + _Pragma( "omp parallel num_threads(n_threads)" ) \ + { \ + /* Create a thread-local copy of the master thread's rntm_t. This is + * necessary since we want each thread to be able to track its own + * small block pool_t as it executes down the function stack.*/ \ + rntm_t rntm_l = *rntm_g; \ + \ + /* lpgemm_thrinfo_t object will be used to generate thrinfo_t objects + * for use in blis mt framework inside the respective mat mul driver + * functions.*/ \ + lpgemm_thrinfo_t thread; \ + thread.n_threads = n_threads; \ + thread.tid = omp_get_thread_num(); \ + thread.ic_ways = ic_ways; \ + thread.jc_ways = jc_ways; \ + thread.comm = cur_lpgemm_comms; \ + \ + lpgemm_rowvar_ ## LPGEMM_SFX \ + ( \ + m, n, k, \ + a, rs_a, cs_a, mtag_a, \ + b, rs_b, cs_b, mtag_b, \ + c, rs_c, cs_c,\ + alpha, \ + beta, \ + &rntm_l, \ + &thread, \ + lcntx, \ + pre_op_list, \ + post_op_list, c_downscale \ + ); \ + } \ + if ( jc_ways > BLIS_LPGEMM_NUM_STATIC_COMMS ) \ + { \ + bli_free_intl( cur_lpgemm_comms ); \ + } \ +} \ + +GEN_LPGEMM_OPENMP_DECORATOR_MP(bfloat16, int8_t, float, bf16s4f32of32) + +BLIS_INLINE void lpgemm_eltwise_ops_bf16of32_get_threading + ( + dim_t* n_threads, + dim_t* ic_ways, + dim_t* jc_ways, + dim_t m, + dim_t n, + rntm_t* rntm_g, + lpgemm_eltwise_ops_cntx_t* lcntx + ) +{ + *n_threads = bli_rntm_num_threads( rntm_g ); + *jc_ways = bli_rntm_jc_ways( rntm_g ); + *ic_ways = bli_rntm_ic_ways( rntm_g ); + + if ( ( ( *ic_ways ) > 0 ) || ( ( *jc_ways ) > 0 ) ) + { + // If BLIS_IC_NT or JC_NT are set. + // Default cases. + *ic_ways = ( ( *ic_ways ) > 0 ) ? ( *ic_ways ) : 1; + *jc_ways = ( ( *jc_ways ) > 0 ) ? ( *jc_ways ) : 1; + + *n_threads = ( *jc_ways ) * ( *ic_ways ); + } + else if ( ( *n_threads ) > 1 ) + { + dim_t NR = lcntx->blksz.NR; + dim_t MR = lcntx->blksz.MR; + dim_t mr_blks = ( m + MR - 1 ) / MR; + dim_t nr_blks = ( n + NR - 1 ) / NR; + + if ( n <= NR ) + { + ( *ic_ways ) = ( mr_blks < ( *n_threads ) ) ? mr_blks : ( *n_threads ); + ( *jc_ways ) = 1; + ( *n_threads ) = ( *ic_ways ) * ( *jc_ways ); + } + else if ( m <= MR ) + { + ( *jc_ways ) = ( nr_blks < ( *n_threads ) ) ? nr_blks : ( *n_threads ); + ( *ic_ways ) = 1; + ( *n_threads ) = ( *ic_ways ) * ( *jc_ways ); + } + else if ( mr_blks >= ( *n_threads ) ) + { + ( *ic_ways ) = ( *n_threads ); + ( *jc_ways ) = 1; + } + else if ( mr_blks >= ( dim_t )( ( 3.0 / 4.0 ) * ( *n_threads ) ) ) + { + ( *ic_ways ) = mr_blks; + dim_t rem_jc_ways = ( dim_t )( ( *n_threads ) / ( *ic_ways ) ); + ( *jc_ways ) = ( rem_jc_ways < nr_blks ) ? rem_jc_ways : nr_blks; + ( *n_threads ) = ( *ic_ways ) * ( *jc_ways ); + } + else + { + // If BLIS_NUM_THREADS are set, generate jc,ic from the same. + bli_thread_partition_2x2( ( *n_threads ), m, n, ic_ways, jc_ways ); + if ( ( mr_blks < ( *ic_ways ) ) && ( nr_blks < ( *jc_ways ) ) ) + { + ( *ic_ways ) = mr_blks; + ( *jc_ways ) = nr_blks; + ( *n_threads ) = ( *ic_ways ) * ( *jc_ways ); + } + else if ( mr_blks < ( *ic_ways ) ) + { + ( *ic_ways ) = mr_blks; + dim_t rem_jc_ways = ( dim_t )( ( *n_threads ) / ( *ic_ways ) ); + ( *jc_ways ) = ( rem_jc_ways < nr_blks ) ? rem_jc_ways : nr_blks; + ( *n_threads ) = ( *ic_ways ) * ( *jc_ways ); + } + else if ( nr_blks < ( *jc_ways ) ) + { + ( *jc_ways ) = nr_blks; + dim_t rem_ic_ways = ( dim_t )( ( *n_threads ) / ( *jc_ways ) ); + ( *ic_ways ) = ( rem_ic_ways < mr_blks ) ? rem_ic_ways : mr_blks; + ( *n_threads ) = ( *ic_ways ) * ( *jc_ways ); + } + } + } + else + { + // Setting all the values to 1 in case n_threads <= 1. This ensures + // the threading parameters are valid. + ( *n_threads ) = 1; + ( *jc_ways ) = 1; + ( *ic_ways ) = 1; + } +} + +BLIS_INLINE void lpgemm_eltwise_ops_f32of32_get_threading + ( + dim_t* n_threads, + dim_t* ic_ways, + dim_t* jc_ways, + dim_t m, + dim_t n, + rntm_t* rntm_g, + lpgemm_eltwise_ops_cntx_t* lcntx + ) +{ + lpgemm_eltwise_ops_bf16of32_get_threading + ( + n_threads, + ic_ways, jc_ways, + m, n, rntm_g, + lcntx + ); +} + +#define GEN_UTIL_ELTWISE_OPS_OPENMP_DECORATOR(A_type,B_type,LPGEMM_SFX) \ +void lpgemm_eltwise_ops_ ## LPGEMM_SFX ## _openmp_thread_decorator \ + ( \ + const dim_t m, \ + const dim_t n, \ + const A_type* a, \ + const dim_t rs_a, \ + const dim_t cs_a, \ + B_type* b, \ + const dim_t rs_b, \ + const dim_t cs_b, \ + rntm_t* rntm_g, \ + lpgemm_eltwise_ops_cntx_t* lcntx, \ + lpgemm_post_op* post_op_list, \ + AOCL_STORAGE_TYPE c_downscale \ + ) \ +{ \ + dim_t n_threads; \ + \ + /* Factorization of threads along m and n dimension respectively.*/ \ + dim_t ic_ways; \ + dim_t jc_ways; \ + \ + lpgemm_eltwise_ops_ ## LPGEMM_SFX ## _get_threading \ + ( \ + &n_threads, \ + &ic_ways, &jc_ways, \ + m, n, rntm_g, lcntx \ + ); \ + \ + /* Set the packing block allocator field of the rntm. This will be + * inherited by all of the child threads when they make local copies of + * the rntm below.*/ \ + bli_pba_rntm_set_pba( rntm_g ); \ + \ + thrcomm_t static_lpgemm_comms[BLIS_LPGEMM_NUM_STATIC_COMMS]; \ + thrcomm_t* cur_lpgemm_comms = static_lpgemm_comms; \ + err_t bli_errors = BLIS_SUCCESS; \ + \ + if ( jc_ways > BLIS_LPGEMM_NUM_STATIC_COMMS ) \ + { \ + cur_lpgemm_comms = bli_malloc_intl( jc_ways * sizeof( thrcomm_t ), &bli_errors ); \ + } \ + for ( dim_t i = 0; i < jc_ways; ++i ) \ + { \ + bli_thrcomm_init( ic_ways, &cur_lpgemm_comms[i] ); \ + } \ + \ + _Pragma( "omp parallel num_threads(n_threads)" ) \ + { \ + /* Create a thread-local copy of the master thread's rntm_t. This is + * necessary since we want each thread to be able to track its own + * small block pool_t as it executes down the function stack.*/ \ + rntm_t rntm_l = *rntm_g; \ + \ + /* lpgemm_thrinfo_t object will be used to generate thrinfo_t objects + * for use in blis mt framework inside the respective mat mul driver + * functions.*/ \ + lpgemm_thrinfo_t thread; \ + thread.n_threads = n_threads; \ + thread.tid = omp_get_thread_num(); \ + thread.ic_ways = ic_ways; \ + thread.jc_ways = jc_ways; \ + thread.comm = cur_lpgemm_comms; \ + \ + lpgemm_eltwise_ops_interface_ ## LPGEMM_SFX \ + ( \ + m, n, \ + a, rs_a, cs_a, \ + b, rs_b, cs_b, \ + &rntm_l, \ + &thread, \ + lcntx, \ + post_op_list, c_downscale \ + ); \ + } \ + if ( jc_ways > BLIS_LPGEMM_NUM_STATIC_COMMS ) \ + { \ + bli_free_intl( cur_lpgemm_comms ); \ + } \ +} \ + +GEN_UTIL_ELTWISE_OPS_OPENMP_DECORATOR(bfloat16,float,bf16of32) +GEN_UTIL_ELTWISE_OPS_OPENMP_DECORATOR(float,float,f32of32) + #else #define GEN_LPGEMM_DECORATOR(A_type,B_type,C_type,LPGEMM_SFX) \ @@ -905,4 +1123,132 @@ GEN_LPGEMM_DECORATOR(float,float,float,f32f32f32of32) GEN_LPGEMM_DECORATOR(int8_t,int8_t,int32_t,s8s8s32o32) GEN_LPGEMM_DECORATOR(int8_t,int8_t,int16_t,s8s8s16o16) +#define GEN_LPGEMM_DECORATOR1(A_type,B_type,C_type,LPGEMM_SFX) \ +void lpgemm_ ## LPGEMM_SFX ## _thread_decorator \ + ( \ + const dim_t m, \ + const dim_t n, \ + const dim_t k, \ + const A_type* a, \ + const dim_t rs_a, \ + const dim_t cs_a, \ + const AOCL_MEMORY_TAG mtag_a, \ + const B_type* b, \ + const dim_t rs_b, \ + const dim_t cs_b, \ + const AOCL_MEMORY_TAG mtag_b, \ + C_type* c, \ + const dim_t rs_c, \ + const dim_t cs_c, \ + const C_type alpha, \ + const C_type beta, \ + rntm_t* rntm_g, \ + lpgemm_cntx_t* lcntx, \ + lpgemm_pre_op* pre_op_list, \ + lpgemm_post_op* post_op_list, \ + AOCL_STORAGE_TYPE c_downscale \ + ) \ +{ \ + dim_t n_threads = 1; \ + \ + /* Factorization of threads along m and n dimension respectively.*/ \ + dim_t ic_ways = 1; \ + dim_t jc_ways = 1; \ + \ + /* Set the packing block allocator field of the rntm. This will be + * inherited by all of the child threads when they make local copies of + * the rntm below.*/ \ + bli_pba_rntm_set_pba( rntm_g ); \ + \ + thrcomm_t static_lpgemm_comm; \ + thrcomm_t* cur_lpgemm_comm = &static_lpgemm_comm; \ + \ + bli_thrcomm_init( ic_ways, cur_lpgemm_comm ); \ + \ + /* lpgemm_thrinfo_t object will be used to generate thrinfo_t objects + * for use in blis mt framework inside the respective mat mul driver + * functions.*/ \ + lpgemm_thrinfo_t thread; \ + thread.n_threads = n_threads; \ + thread.tid = 0; \ + thread.ic_ways = ic_ways; \ + thread.jc_ways = jc_ways; \ + thread.comm = cur_lpgemm_comm; \ + \ + lpgemm_rowvar_ ## LPGEMM_SFX \ + ( \ + m, n, k, \ + a, rs_a, cs_a, mtag_a, \ + b, rs_b, cs_b, mtag_b, \ + c, rs_c, cs_c, \ + alpha, \ + beta, \ + rntm_g, \ + &thread, \ + lcntx, \ + pre_op_list, \ + post_op_list, c_downscale \ + ); \ +} + +GEN_LPGEMM_DECORATOR1(bfloat16, int8_t, float, bf16s4f32of32) + +#define GEN_UTIL_ELTWISE_OPS_DECORATOR(A_type,B_type,LPGEMM_SFX) \ +void lpgemm_eltwise_ops_ ## LPGEMM_SFX ## _thread_decorator \ + ( \ + const dim_t m, \ + const dim_t n, \ + const A_type* a, \ + const dim_t rs_a, \ + const dim_t cs_a, \ + B_type* b, \ + const dim_t rs_b, \ + const dim_t cs_b, \ + rntm_t* rntm_g, \ + lpgemm_eltwise_ops_cntx_t* lcntx, \ + lpgemm_post_op* post_op_list, \ + AOCL_STORAGE_TYPE c_downscale \ + ) \ +{ \ + dim_t n_threads = 1; \ + \ + /* Factorization of threads along m and n dimension respectively.*/ \ + dim_t ic_ways = 1; \ + dim_t jc_ways = 1; \ + \ + /* Set the packing block allocator field of the rntm. This will be + * inherited by all of the child threads when they make local copies of + * the rntm below.*/ \ + bli_pba_rntm_set_pba( rntm_g ); \ + \ + thrcomm_t static_lpgemm_comm; \ + thrcomm_t* cur_lpgemm_comm = &static_lpgemm_comm; \ + \ + bli_thrcomm_init( ic_ways, cur_lpgemm_comm ); \ + \ + /* lpgemm_thrinfo_t object will be used to generate thrinfo_t objects + * for use in blis mt framework inside the respective mat mul driver + * functions.*/ \ + lpgemm_thrinfo_t thread; \ + thread.n_threads = n_threads; \ + thread.tid = 0; \ + thread.ic_ways = ic_ways; \ + thread.jc_ways = jc_ways; \ + thread.comm = cur_lpgemm_comm; \ + \ + lpgemm_eltwise_ops_interface_ ## LPGEMM_SFX \ + ( \ + m, n, \ + a, rs_a, cs_a, \ + b, rs_b, cs_b, \ + rntm_g, \ + &thread, \ + lcntx, \ + post_op_list, c_downscale \ + ); \ +} \ + +GEN_UTIL_ELTWISE_OPS_DECORATOR(bfloat16,float,bf16of32) +GEN_UTIL_ELTWISE_OPS_DECORATOR(float,float,f32of32) + #endif diff --git a/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.h b/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.h index 4fd0a12bff..6c18973d06 100644 --- a/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.h +++ b/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.h @@ -73,6 +73,55 @@ GEN_LPGEMM_OPENMP_DECORATOR_FN(float,float,float,f32f32f32of32) GEN_LPGEMM_OPENMP_DECORATOR_FN(int8_t,int8_t,int32_t,s8s8s32o32) GEN_LPGEMM_OPENMP_DECORATOR_FN(int8_t,int8_t,int16_t,s8s8s16o16) + +#define GEN_LPGEMM_OPENMP_DECORATOR_FN1(A_type,B_type,C_type,LPGEMM_SFX) \ +void lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \ + ( \ + const dim_t m, \ + const dim_t n, \ + const dim_t k, \ + const A_type* a, \ + const dim_t rs_a, \ + const dim_t cs_a, \ + const AOCL_MEMORY_TAG mtag_a, \ + const B_type* b, \ + const dim_t rs_b, \ + const dim_t cs_b, \ + const AOCL_MEMORY_TAG mtag_b, \ + C_type* c, \ + const dim_t rs_c, \ + const dim_t cs_c, \ + const C_type alpha, \ + const C_type beta, \ + rntm_t* rntm_g, \ + lpgemm_cntx_t* lcntx, \ + lpgemm_pre_op* pre_op_list, \ + lpgemm_post_op* post_op_list, \ + AOCL_STORAGE_TYPE c_downscale \ + ); \ + +GEN_LPGEMM_OPENMP_DECORATOR_FN1(bfloat16, int8_t, float, bf16s4f32of32) + +#define GEN_UTIL_ELTWISE_OPS_OPENMP_DECORATOR_FN(A_type,B_type,LPGEMM_SFX) \ +void lpgemm_eltwise_ops_ ## LPGEMM_SFX ## _openmp_thread_decorator \ + ( \ + const dim_t m, \ + const dim_t n, \ + const A_type* a, \ + const dim_t rs_a, \ + const dim_t cs_a, \ + B_type* b, \ + const dim_t rs_b, \ + const dim_t cs_b, \ + rntm_t* rntm_g, \ + lpgemm_eltwise_ops_cntx_t* lcntx, \ + lpgemm_post_op* post_op_list, \ + AOCL_STORAGE_TYPE c_downscale \ + ); \ + +GEN_UTIL_ELTWISE_OPS_OPENMP_DECORATOR_FN(bfloat16,float,bf16of32) +GEN_UTIL_ELTWISE_OPS_OPENMP_DECORATOR_FN(float,float,f32of32) + #else #define GEN_LPGEMM_DECORATOR_FN(A_type,B_type,C_type,LPGEMM_SFX) \ @@ -107,6 +156,54 @@ GEN_LPGEMM_DECORATOR_FN(float,float,float,f32f32f32of32) GEN_LPGEMM_DECORATOR_FN(int8_t,int8_t,int32_t,s8s8s32o32) GEN_LPGEMM_DECORATOR_FN(int8_t,int8_t,int16_t,s8s8s16o16) +#define GEN_LPGEMM_DECORATOR_FN1(A_type,B_type,C_type,LPGEMM_SFX) \ +void lpgemm_ ## LPGEMM_SFX ## _thread_decorator \ + ( \ + const dim_t m, \ + const dim_t n, \ + const dim_t k, \ + const A_type* a, \ + const dim_t rs_a, \ + const dim_t cs_a, \ + const AOCL_MEMORY_TAG mtag_a, \ + const B_type* b, \ + const dim_t rs_b, \ + const dim_t cs_b, \ + const AOCL_MEMORY_TAG mtag_b, \ + C_type* c, \ + const dim_t rs_c, \ + const dim_t cs_c, \ + const C_type alpha, \ + const C_type beta, \ + rntm_t* rntm_g, \ + lpgemm_cntx_t* lcntx, \ + lpgemm_pre_op* pre_op_list, \ + lpgemm_post_op* post_op_list, \ + AOCL_STORAGE_TYPE c_downscale \ + ); \ + +GEN_LPGEMM_DECORATOR_FN1(bfloat16, int8_t, float, bf16s4f32of32) + +#define GEN_UTIL_ELTWISE_OPS_DECORATOR_FN(A_type,B_type,LPGEMM_SFX) \ +void lpgemm_eltwise_ops_ ## LPGEMM_SFX ## _thread_decorator \ + ( \ + const dim_t m, \ + const dim_t n, \ + const A_type* a, \ + const dim_t rs_a, \ + const dim_t cs_a, \ + B_type* b, \ + const dim_t rs_b, \ + const dim_t cs_b, \ + rntm_t* rntm_g, \ + lpgemm_eltwise_ops_cntx_t* lcntx, \ + lpgemm_post_op* post_op_list, \ + AOCL_STORAGE_TYPE c_downscale \ + ); \ + +GEN_UTIL_ELTWISE_OPS_DECORATOR_FN(bfloat16,float,bf16of32) +GEN_UTIL_ELTWISE_OPS_DECORATOR_FN(float,float,f32of32) + #endif #endif //LPGEMM_THREAD_DECOR_OPENMP_H diff --git a/addon/aocl_gemm/frame/u8s8s16/lpgemm_u8s8s16.c b/addon/aocl_gemm/frame/u8s8s16/lpgemm_u8s8s16.c index 5e4740a952..65d3081dd7 100644 --- a/addon/aocl_gemm/frame/u8s8s16/lpgemm_u8s8s16.c +++ b/addon/aocl_gemm/frame/u8s8s16/lpgemm_u8s8s16.c @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,6 +35,7 @@ #include "blis.h" #include "lpgemm_5loop_interface_apis.h" #include "lpgemm_packb_s16.h" +#include "lpgemm_packa_s16.h" #include "lpgemm_kernels.h" #include "lpgemm_utils.h" #include "lpgemm_config.h" @@ -62,6 +63,137 @@ typedef void (*lpgemm_rowvar_s16) lpgemm_post_op_attr ); + + +LPGEMV(uint8_t,int8_t,int16_t,u8s8s16os16) +{ + dim_t KC = lcntx->blksz.KC; + dim_t MC = lcntx->blksz.MC; + + // Strides are updated based on matrix packing/reordering. + uint8_t* a_use = ( uint8_t* )a; + inc_t rs_a_use = rs_a; + inc_t cs_a_use = cs_a; + + int8_t* b_use = ( int8_t* )b; + inc_t rs_b_use = rs_b; + inc_t cs_b_use = cs_b; + + int16_t *c_use = NULL; + + lpgemm_post_op_attr post_ops_attr; + post_ops_attr.c_stor_type = c_downscale; + if (c_downscale < S16) post_ops_attr.buf_downscale = c; + else post_ops_attr.buf_downscale = NULL; + + siz_t mem_a_size_req = 0; + siz_t mem_b_size_req = 0; + + mem_t mem_a = BLIS_MEM_INITIALIZER; + mem_t mem_b = BLIS_MEM_INITIALIZER; + + uint8_t* pack_a_buffer; + int8_t* pack_b_buffer; + + // Generate thrinfo objects for jc and ic loops from lpgemm_thrinfo_t. + thrinfo_t thread_jc; + thrinfo_t thread_ic; + + lpgemm_gen_thrinfo( thread, &thread_jc, &thread_ic ); + + // Increased MR from 6 to 8 to make use of 16 ymm regs + dim_t MR = 8; + + // Pack B matrix if rs_b > 1 + if( ( mtag_b == PACK ) && ( rs_b != 1 ) ) + { + mem_b_size_req = sizeof( int8_t ) * k; + + lpgemm_alloc_mem_panel + ( + mem_b_size_req, BLIS_BUFFER_FOR_GEN_USE, + &mem_b, rntm + ); + + pack_b_buffer = ( int8_t* ) bli_mem_buffer( &mem_b ); + + for( dim_t k0 = 0; k0 < k; k0++ ) + { + pack_b_buffer[k0] = b[ k0*rs_b ]; + } + + b_use = pack_b_buffer; + rs_b_use = 1; + cs_b_use = 1; + } + + // Compute the IC loop thread range for the current thread. + dim_t ic_start, ic_end; + thread_ic.n_way = ( thread_ic.n_way == 1 ) ? + ( thread->n_threads ) : ( thread_ic.n_way ); + thread_ic.work_id = thread->tid; + bli_thread_range_sub(&thread_ic, m, MR, FALSE, &ic_start, &ic_end); + + for (dim_t ic = ic_start; ic < ic_end; ic += MC) + { + dim_t mc0 = bli_min((ic_end - ic), MC); + + a_use = (uint8_t*)a + ic * rs_a; + + c_use = c + ic * rs_c; + + post_ops_attr.post_op_c_i = ic; + post_ops_attr.post_op_c_j = 0; + post_ops_attr.rs_c_downscale = rs_c; + + if( mtag_a == PACK ) + { + mem_a_size_req = sizeof( uint8_t ) * mc0 * k; + + lpgemm_alloc_mem_panel + ( + mem_a_size_req, BLIS_BUFFER_FOR_GEN_USE, + &mem_a, rntm + ); + + pack_a_buffer = ( uint8_t* ) bli_mem_buffer( &mem_a ); + + ( ( packa_s16 ) lcntx->packa_fun_ptr ) + ( + pack_a_buffer, + ( a + ( rs_a * ic )), rs_a, cs_a, + mc0, k, + &rs_a_use, &cs_a_use + ); + a_use = pack_a_buffer; + } + + // Call lpgemv_n_one kernel + lpgemv_n_one_u8s8s16os16 + ( + mc0, k, + a_use, rs_a_use, cs_a_use, mtag_a, + b_use, rs_b_use, cs_b_use, mtag_b, + c_use, rs_c, cs_c, + alpha, beta, + MR, KC, + post_op_list, + &post_ops_attr + ); + } + + // Release pack buffers + if( mtag_a == PACK && bli_mem_is_alloc( &mem_a ) ) + { + bli_pba_release(rntm, &mem_a); + } + if( mtag_b == PACK && bli_mem_is_alloc( &mem_b ) ) + { + bli_pba_release(rntm, &mem_b); + } +} + + // B should always be packed. LPGEMM_5LOOP(uint8_t,int8_t,int16_t,u8s8s16o16) { @@ -79,10 +211,27 @@ LPGEMM_5LOOP(uint8_t,int8_t,int16_t,u8s8s16o16) return; } + if( n == 1 ) + { + lpgemv_rowvar_u8s8s16os16( m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + c, rs_c, cs_c, + alpha, + beta, + rntm, + thread, + lcntx, + post_op_list, + c_downscale ); + return; + } + const int8_t *b_use; const uint8_t *a_use; dim_t rs_a_use = rs_a; dim_t cs_a_use = cs_a; + dim_t a_block_stride = 0; dim_t rs_b_use = rs_b; dim_t cs_b_use = cs_b; @@ -92,6 +241,11 @@ LPGEMM_5LOOP(uint8_t,int8_t,int16_t,u8s8s16o16) dim_t rs_c_use = rs_c; dim_t rs_c_downscale = rs_c; + // Pack buffer for A. + uint8_t* pack_a_buffer_u8s8s16o16; + mem_t mem_a = BLIS_MEM_INITIALIZER; + siz_t mem_a_size_req = 0; + // Pack buffer for B. int8_t *pack_b_buffer_u8s8s16o16; mem_t mem_b = BLIS_MEM_INITIALIZER; @@ -315,10 +469,53 @@ LPGEMM_5LOOP(uint8_t,int8_t,int16_t,u8s8s16o16) c_use_ic = c_use_jc + ( rs_c_use * ic ); } - a_use = a + (rs_a * ic) + (cs_a * pc); - cs_a_use = 1; + // Matrix A packed and reordered code path is not triggerred + // currently for row-major inputs since we do not support it yet. + // Pack is enabled for column-major inputs to transform into + // row-major inputs as kernel expects row storage format. + if ( mtag_a == PACK ) + { + mem_a_size_req = sizeof( uint8_t ) * mc0 * kc0_updated; + + lpgemm_alloc_mem_panel + ( + mem_a_size_req, BLIS_BUFFER_FOR_A_BLOCK, + &mem_a, rntm + ); + pack_a_buffer_u8s8s16o16 = ( uint8_t* )bli_mem_buffer( &mem_a ); - dim_t a_block_stride = rs_a; + ( ( packa_s16 )lcntx->packa_fun_ptr ) + ( + pack_a_buffer_u8s8s16o16, + ( a + ( rs_a * ic ) + ( cs_a * pc ) ), rs_a, cs_a, + mc0, kc0, + &rs_a_use, &cs_a_use + ); + a_use = pack_a_buffer_u8s8s16o16; + + if( cs_a == 1 ) + { + a_block_stride = kc0_updated; + } + + else + { + a_block_stride = rs_a_use; + } + + } + else if ( mtag_a == REORDERED ) + { + lpgemm_get_packa_strides( lcntx, &rs_a_use, &cs_a_use ); + a_use = a + ( pc * m ) + ( kc0_updated * ic ); + a_block_stride = kc0_updated; + } + else + { + a_use = a + ( rs_a * ic ) + ( cs_a * pc ); + cs_a_use = 1; + a_block_stride = rs_a; + } for (dim_t jr = 0; jr < nc0; jr += NR) { diff --git a/addon/aocl_gemm/frame/u8s8s32/lpgemm_reorder.c b/addon/aocl_gemm/frame/u8s8s32/lpgemm_reorder.c index 14dff21af4..e1fba65be4 100644 --- a/addon/aocl_gemm/frame/u8s8s32/lpgemm_reorder.c +++ b/addon/aocl_gemm/frame/u8s8s32/lpgemm_reorder.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -52,6 +52,7 @@ void reorderb_nr64_u8s8s32o32 dim_t NR = lcntx->blksz.NR; dim_t rs_b = b->rs; + dim_t cs_b = b->cs; dim_t rs_b_reorder; dim_t cs_b_reorder; @@ -145,17 +146,14 @@ void reorderb_nr64_u8s8s32o32 // st = ( jc_cur_loop * k ) // + ( n_sub_updated * pc ) // + ( NC' * kc0_updated) - ( ( packb_s32 )lcntx->packb_fun_ptr ) - ( - ( ( ( int8_t* )b_reorder->storage.aligned_buffer ) + - ( jc_cur_loop * k_updated ) + ( n_sub_updated * pc ) + - ( jc_cur_loop_rem * kc0_updated ) ), - ( ( ( int8_t* )b->storage.aligned_buffer ) + - ( rs_b * pc ) + jc ), - rs_b, nc0, kc0, &rs_b_reorder, &cs_b_reorder - ); + ((packb_s32)lcntx->packb_fun_ptr)( + (((int8_t *)b_reorder->storage.aligned_buffer) + + (jc_cur_loop * k_updated) + (n_sub_updated * pc) + + (jc_cur_loop_rem * kc0_updated)), + (((int8_t *)b->storage.aligned_buffer) + + (rs_b * pc) + jc * cs_b), + rs_b, cs_b, nc0, kc0, &rs_b_reorder, &cs_b_reorder); } - adjust_B_panel_reordered_jc( &jc, jc_cur_loop ); } } @@ -177,6 +175,7 @@ void reordera_mr6_u8s8s32o32 dim_t KC = lcntx->blksz.KC; dim_t rs_a = a->rs; + dim_t cs_a = a->cs; dim_t rs_a_reorder; dim_t cs_a_reorder; @@ -202,7 +201,7 @@ void reordera_mr6_u8s8s32o32 ( ( ( uint8_t* )a_reorder->storage.aligned_buffer ) + ( pc * m ) + ( ic * kc0_updated ) ), ( ( ( uint8_t* )a->storage.aligned_buffer ) + ( rs_a * ic ) + pc ), - rs_a, mc0, kc0, &rs_a_reorder, &cs_a_reorder + rs_a, cs_a, mc0, kc0, &rs_a_reorder, &cs_a_reorder ); } } diff --git a/addon/aocl_gemm/frame/u8s8s32/lpgemm_reorder.h b/addon/aocl_gemm/frame/u8s8s32/lpgemm_reorder.h index 58a5255637..a9a6a9b0ca 100644 --- a/addon/aocl_gemm/frame/u8s8s32/lpgemm_reorder.h +++ b/addon/aocl_gemm/frame/u8s8s32/lpgemm_reorder.h @@ -53,4 +53,12 @@ void reordera_mr6_u8s8s32o32 lpgemm_cntx_t* lcntx ); +void reorderb_nr64_u8s4s32o32 + ( + lpgemm_obj_t* b, + lpgemm_obj_t* b_reorder, + rntm_t* rntm, + lpgemm_cntx_t* lcntx + ); + #endif //LPGEMM_REORDER_H diff --git a/addon/aocl_gemm/frame/u8s8s32/lpgemm_s4_reorder.c b/addon/aocl_gemm/frame/u8s8s32/lpgemm_s4_reorder.c new file mode 100644 index 0000000000..9c03b3f9b9 --- /dev/null +++ b/addon/aocl_gemm/frame/u8s8s32/lpgemm_s4_reorder.c @@ -0,0 +1,173 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "lpgemm_utils.h" +#include "lpgemm_reorder.h" +#include "lpgemm_packb.h" +#include "lpgemm_config.h" + +void reorderb_nr64_u8s4s32o32 + ( + lpgemm_obj_t* b, + lpgemm_obj_t* b_reorder, + rntm_t* rntm, + lpgemm_cntx_t* lcntx + ) +{ + dim_t NC = lcntx->blksz.NC; + dim_t KC = lcntx->blksz.KC; + dim_t NR = lcntx->blksz.NR; + + if ( ( ( KC % 2 ) != 0 ) || ( ( NC % 2 ) != 0 ) || ( ( NR % 2 ) != 0 ) ) + { + bli_print_msg(" Only even KC, NC, and NR supported for int4 B" + " matrix reordering.", + __FILE__, __LINE__ ); + return; // Odd KC, NC, NR not supported. + } + + dim_t rs_b = b->rs; + dim_t cs_b = b->cs; + dim_t rs_b_reorder; + dim_t cs_b_reorder; + + dim_t n = b->width; + dim_t k = b->length; + + // k needs to be a multiple of 4 so that it can be used with vpdpbusd + // instruction. Padding is added in cases this condition is not + // satisfied, and therefore the k offset used for packed/reordered + // buffer needs to be updated. + dim_t k_updated = make_multiple_of_n( k, 4 ); + + dim_t n_threads = bli_rntm_num_threads( rntm ); + n_threads = ( n_threads > 0 ) ? n_threads : 1; + +#ifdef BLIS_ENABLE_OPENMP + _Pragma( "omp parallel num_threads(n_threads)" ) + { + // Initialise a local thrinfo obj for work split across threads. + thrinfo_t thread_jc; + bli_thrinfo_set_n_way( n_threads, &thread_jc ); + bli_thrinfo_set_work_id( omp_get_thread_num(), &thread_jc ); +#else + { + // Initialise a local thrinfo obj for work split across threads. + thrinfo_t thread_jc; + bli_thrinfo_set_n_way( 1, &thread_jc ); + bli_thrinfo_set_work_id( 0, &thread_jc ); +#endif + // Compute the JC loop thread range for the current thread. + dim_t jc_start, jc_end; + bli_thread_range_sub( &thread_jc, n, NR, FALSE, &jc_start, &jc_end ); + + for ( dim_t jc = jc_start; jc < jc_end; jc += NC ) + { + dim_t nc0 = bli_min( ( jc_end - jc ), NC ); + + dim_t jc_cur_loop = jc; + dim_t jc_cur_loop_rem = 0; + dim_t n_sub_updated; + + get_B_panel_reordered_start_offset_width + ( + jc, n, NC, get_packb_u8s8s32o32_min_NR(), + &jc_cur_loop, &jc_cur_loop_rem, + &nc0, &n_sub_updated + ); + + for ( dim_t pc = 0; pc < k; pc += KC ) + { + dim_t kc0 = bli_min( ( k - pc ), KC ); + + // kc0 needs to be a multiple of 4 so that it can be used with + // vpdpbusd instruction. Padding is added in cases this + // condition is not satisfied, and therefore the kc0 offsets + // used for packed/reordered buffers needs to be updated. + dim_t kc0_updated = make_multiple_of_n( kc0, 4 ); + + // The offsets are calculated in such a way that it resembles + // the reorder buffer traversal in single threaded reordering. + // The panel boundaries (KCxNC) remain as it is accessed in + // single thread, and as a consequence a thread with jc_start + // inside the panel cannot consider NC range for reorder. It + // has to work with NC' < NC, and the offset is calulated using + // prev NC panels spanning k dim + cur NC panel spaning pc loop + // cur iteration + (NC - NC') spanning current kc0 (<= KC). + // + //Eg: Consider the following reordered buffer diagram: + // t1 t2 + // | | + // | |..NC..| + // | | | + // |.NC. |.NC. |NC'|NC" + // pc=0-+-----+-----+---+--+ + // KC| | | | | + // | 1 | 3 | 5 | + // pc=KC-+-----+-----+---st-+ + // KC| | | | | + // | 2 | 4 | 6 | 7| + // pc=k=2KC-+-----+-----+---+--+ + // |jc=0 |jc=NC|jc=2NC| + // + // The numbers 1,2..6,7 denotes the order in which reordered + // KCxNC blocks are stored in memory, ie: block 1 followed by 2 + // followed by 3, etc. Given two threads t1 and t2, and t2 needs + // to acces point st in the reorder buffer to write the data: + // The offset calulation logic will be: + // jc_cur_loop = 2NC, jc_cur_loop_rem = NC', pc = KC, + // n_sub_updated = NC, k = 2KC, kc0_updated = KC + // + // st = ( jc_cur_loop * k ) + // + ( n_sub_updated * pc ) + // + ( NC' * kc0_updated) + // The int4 input buffer increment needs to be halved to + // account for the byte level traversal. + ( ( packb_s32 )lcntx->packb_fun_ptr )( + ( ( ( int8_t * )b_reorder->storage.aligned_buffer ) + + ( jc_cur_loop * k_updated ) + ( n_sub_updated * pc ) + + ( jc_cur_loop_rem * kc0_updated ) ), + ( ( (int8_t * )b->storage.aligned_buffer ) + + ( ( ( rs_b * pc ) + ( jc * cs_b ) ) / 2 ) ), + rs_b, cs_b, nc0, kc0, &rs_b_reorder, &cs_b_reorder ); + } + adjust_B_panel_reordered_jc( &jc, jc_cur_loop ); + } + } + + b_reorder->rs = rs_b_reorder; + b_reorder->cs = cs_b_reorder; + b_reorder->mtag = REORDERED; +} diff --git a/addon/aocl_gemm/frame/u8s8s32/lpgemm_u8s8s32.c b/addon/aocl_gemm/frame/u8s8s32/lpgemm_u8s8s32.c index 29239803d6..3651576826 100644 --- a/addon/aocl_gemm/frame/u8s8s32/lpgemm_u8s8s32.c +++ b/addon/aocl_gemm/frame/u8s8s32/lpgemm_u8s8s32.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -63,6 +63,260 @@ typedef void (*lpgemm_rowvar_s32) lpgemm_post_op_attr ); +#ifdef BLIS_KERNELS_ZEN4 + +LPGEMV(uint8_t,int8_t,int32_t,u8s8s32os32) +{ + dim_t NC = lcntx->blksz.NC; + dim_t KC = lcntx->blksz.KC; + dim_t MC = lcntx->blksz.MC; + dim_t NR = lcntx->blksz.NR; + + // Strides are updated based on matrix packing/reordering. + uint8_t* a_use = ( uint8_t* )a; + inc_t rs_a_use = rs_a; + inc_t cs_a_use = cs_a; + + int8_t* b_use = ( int8_t* )b; + inc_t rs_b_use = rs_b; + inc_t cs_b_use = cs_b; + + int32_t *c_use = NULL; + + lpgemm_post_op_attr post_ops_attr; + post_ops_attr.c_stor_type = c_downscale; + if (c_downscale < S32) post_ops_attr.buf_downscale = c; + else post_ops_attr.buf_downscale = NULL; + + siz_t mem_a_size_req = 0; + siz_t mem_b_size_req = 0; + + mem_t mem_a = BLIS_MEM_INITIALIZER; + mem_t mem_b = BLIS_MEM_INITIALIZER; + + uint8_t* pack_a_buffer; + int8_t* pack_b_buffer; + + // Generate thrinfo objects for jc and ic loops from lpgemm_thrinfo_t. + thrinfo_t thread_jc; + thrinfo_t thread_ic; + + lpgemm_gen_thrinfo( thread, &thread_jc, &thread_ic ); + + if( n == 1 ) + { + // Increased MR from 6 to 16 to make use of 32 ZMM registers + dim_t MR = 16; + + // Pack B matrix if rs_b > 1 + if( ( mtag_b == PACK ) && ( rs_b != 1 ) ) + { + mem_b_size_req = sizeof( int8_t ) * k; + + lpgemm_alloc_mem_panel + ( + mem_b_size_req, BLIS_BUFFER_FOR_GEN_USE, + &mem_b, rntm + ); + + pack_b_buffer = ( int8_t* ) bli_mem_buffer( &mem_b ); + + for( dim_t k0 = 0; k0 < k; k0++ ) + { + pack_b_buffer[k0] = b[ k0*rs_b ]; + } + + b_use = pack_b_buffer; + rs_b_use = 1; + cs_b_use = 1; + + } + // Compute the IC loop thread range for the current thread. + dim_t ic_start, ic_end; + thread_ic.n_way = ( thread_ic.n_way == 1 ) ? + ( thread->n_threads ) : ( thread_ic.n_way ); + thread_ic.work_id = thread->tid; + bli_thread_range_sub(&thread_ic, m, MR, FALSE, &ic_start, &ic_end); + + for (dim_t ic = ic_start; ic < ic_end; ic += MC) + { + dim_t mc0 = bli_min((ic_end - ic), MC); + const uint8_t *a_use = a + ic * rs_a; + c_use = c + ic * rs_c; + post_ops_attr.post_op_c_i = ic; + post_ops_attr.post_op_c_j = 0; + post_ops_attr.rs_c_downscale = rs_c; + + if( mtag_a == PACK ) + { + mem_a_size_req = sizeof( uint8_t ) * mc0 * k; + + lpgemm_alloc_mem_panel + ( + mem_a_size_req, BLIS_BUFFER_FOR_GEN_USE, + &mem_a, rntm + ); + + pack_a_buffer = ( uint8_t* ) bli_mem_buffer( &mem_a ); + + ( ( packa_s32 ) lcntx->packa_fun_ptr ) + ( + pack_a_buffer, + ( a + ( rs_a * ic )), rs_a, cs_a, + mc0, k, + &rs_a_use, &cs_a_use + ); + a_use = pack_a_buffer; + } + // Call lpgemv_n_one kernel + lpgemv_n_one_u8s8s32os32 + ( + mc0, k, + a_use, rs_a_use, cs_a_use, mtag_a, + b_use, rs_b_use, cs_b_use, mtag_b, + c_use, rs_c, cs_c, + alpha, beta, + MR, KC, + post_op_list, + &post_ops_attr + ); + } + + // Release pack buffers + if( mtag_a == PACK && bli_mem_is_alloc( &mem_a ) ) + { + bli_pba_release(rntm, &mem_a); + } + if( mtag_b == PACK && bli_mem_is_alloc( &mem_b ) ) + { + bli_pba_release(rntm, &mem_b); + } + } + else + { + // Compute the JC loop thread range for the current thread. + dim_t jc_start, jc_end; + thread_jc.n_way = ( thread_jc.n_way == 1 ) ? + ( thread->n_threads ) : ( thread_jc.n_way ); + thread_jc.work_id = thread->tid; + bli_thread_range_sub(&thread_jc, n, NR, FALSE, &jc_start, &jc_end); + + dim_t packb_min_NR = get_packb_u8s8s32o32_min_NR(); + + dim_t k_updated = make_multiple_of_n( k, 4 ); + + rs_a_use = rs_a; + cs_a_use = 4; + + if ( mtag_a == PACK ) + { + mem_a_size_req = sizeof( uint8_t ) * k; + + lpgemm_alloc_mem_panel + ( + mem_a_size_req, BLIS_BUFFER_FOR_GEN_USE, + &mem_a, rntm + ); + + pack_a_buffer = ( uint8_t* ) bli_mem_buffer( &mem_a ); + + ( ( packa_s32 )lcntx->packa_fun_ptr ) + ( + pack_a_buffer, + a, rs_a, cs_a, + 1, k, + &rs_a_use, &cs_a_use + ); + + a_use = pack_a_buffer; + } + + for (dim_t jc = jc_start; jc < jc_end; jc += NC) + { + dim_t nc0 = bli_min((jc_end - jc), NC); + c_use = c + jc; + + dim_t jc_cur_loop = jc; + dim_t jc_cur_loop_rem = 0; + dim_t n_sub_updated = 0; + + if (mtag_b == REORDERED) + { + get_B_panel_reordered_start_offset_width( + jc, n, NC, packb_min_NR, + &jc_cur_loop, &jc_cur_loop_rem, + &nc0, &n_sub_updated ); + + b_use = (int8_t*) ( b + (jc_cur_loop * k_updated ) ); + lpgemm_get_packb_strides( lcntx, &rs_b_use, &cs_b_use ); + } + else if( mtag_b == PACK ) + { + dim_t nc0_updated = make_multiple_of_n( nc0, packb_min_NR ); + mem_b_size_req = sizeof( int8_t ) * nc0_updated * k_updated; + + n_sub_updated = nc0_updated; + + lpgemm_alloc_mem_panel + ( + mem_b_size_req, BLIS_BUFFER_FOR_B_PANEL, + &mem_b, rntm + ); + + pack_b_buffer = ( int8_t* ) bli_mem_buffer( &mem_b ); + + for ( dim_t pc = 0; pc < k; pc += KC ) + { + dim_t kc0 = bli_min( ( k - pc ), KC ); + + ( ( packb_s32 )lcntx->packb_fun_ptr ) + ( + ( ( int8_t* )pack_b_buffer + ( n_sub_updated * pc )), + ( ( ( int8_t* )b ) + ( rs_b * pc ) + (jc * cs_b)), + rs_b, cs_b, nc0, kc0, &rs_b_use, &cs_b_use + ); + } + + b_use = pack_b_buffer; + } + + post_ops_attr.post_op_c_i = 0; + post_ops_attr.post_op_c_j = jc; + post_ops_attr.rs_c_downscale = rs_c; + + lpgemv_m_one_u8s8s32os32 + ( + nc0, k, + a_use, rs_a_use, cs_a_use, mtag_a, + b_use, rs_b_use, cs_b_use, mtag_b, + c_use, rs_c, cs_c, + alpha, beta, + NR, KC, + n_sub_updated, + jc_cur_loop_rem, + post_op_list, + &post_ops_attr + ); + + if (mtag_b == REORDERED) + { + adjust_B_panel_reordered_jc( &jc, jc_cur_loop ); + } + } // jc loop + + // Release pack buffers. + if ( mtag_b == PACK && bli_mem_is_alloc( &mem_b ) ) + { + bli_pba_release( rntm, &mem_b ); + } + if( mtag_a == PACK && bli_mem_is_alloc( &mem_a ) ) + { + bli_pba_release(rntm, &mem_a); + } + } +} +#endif + // B should always be packed. LPGEMM_5LOOP(uint8_t,int8_t,int32_t,u8s8s32o32) { @@ -78,6 +332,26 @@ LPGEMM_5LOOP(uint8_t,int8_t,int32_t,u8s8s32o32) return; } +#ifdef BLIS_KERNELS_ZEN4 + + if( ( m == 1 ) || ( n == 1 ) ) + { + lpgemv_rowvar_u8s8s32os32( m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + c, rs_c, cs_c, + alpha, + beta, + rntm, + thread, + lcntx, + post_op_list, + c_downscale ); + return; + } + +#endif + // Strides are updated based on matrix packing/reordering. const uint8_t* a_use = NULL; dim_t rs_a_use = rs_a; @@ -99,7 +373,7 @@ LPGEMM_5LOOP(uint8_t,int8_t,int32_t,u8s8s32o32) siz_t mem_a_size_req = 0; // Pack buffer for B. - int8_t* pack_b_buffer_u8s8s32o32; + int8_t* pack_b_buffer; mem_t mem_b = BLIS_MEM_INITIALIZER; siz_t mem_b_size_req = 0; dim_t packb_min_NR = get_packb_u8s8s32o32_min_NR(); @@ -236,7 +510,7 @@ LPGEMM_5LOOP(uint8_t,int8_t,int32_t,u8s8s32o32) ); thread->comm[jc_work_id].sent_object = - bli_mem_buffer( &mem_b ); + bli_mem_buffer( &mem_b ); } // All threads in work group should wait till chief thread has @@ -247,8 +521,7 @@ LPGEMM_5LOOP(uint8_t,int8_t,int32_t,u8s8s32o32) &thread->comm[jc_work_id] ); - pack_b_buffer_u8s8s32o32 = - ( int8_t* ) thread->comm[jc_work_id].sent_object; + pack_b_buffer = ( int8_t* ) thread->comm[jc_work_id].sent_object; // Compute the B panel per thread loop range for parallel // packing using ic_ways number of threads. Since atmost only @@ -269,9 +542,9 @@ LPGEMM_5LOOP(uint8_t,int8_t,int32_t,u8s8s32o32) { ( ( packb_s32 )lcntx->packb_fun_ptr ) ( - pack_b_buffer_u8s8s32o32 + ( jc_packb_start * kc0_updated ), + pack_b_buffer + ( jc_packb_start * kc0_updated ), ( b + ( rs_b * pc ) + ( cs_b * jc ) + - ( cs_b * jc_packb_start ) ), rs_b, + ( cs_b * jc_packb_start ) ), rs_b, cs_b, ( jc_packb_end - jc_packb_start ), kc0, &rs_b_use, &cs_b_use ); @@ -288,7 +561,7 @@ LPGEMM_5LOOP(uint8_t,int8_t,int32_t,u8s8s32o32) bli_thread_ocomm_id( &thread_ic ), &thread->comm[jc_work_id] ); - b_use = pack_b_buffer_u8s8s32o32; + b_use = pack_b_buffer; } else if ( mtag_b == REORDERED ) { @@ -324,7 +597,9 @@ LPGEMM_5LOOP(uint8_t,int8_t,int32_t,u8s8s32o32) } // Matrix A packed and reordered code path is not triggerred - // currently since we do not support it yet. + // currently for row-major inputs since we do not support it yet. + // Pack is enabled for column-major inputs to transform into + // row-major inputs as kernel expects row storage format. if ( mtag_a == PACK ) { mem_a_size_req = sizeof( uint8_t ) * mc0 * kc0_updated; @@ -339,12 +614,21 @@ LPGEMM_5LOOP(uint8_t,int8_t,int32_t,u8s8s32o32) ( ( packa_s32 )lcntx->packa_fun_ptr ) ( pack_a_buffer_u8s8s32o32, - ( a + ( rs_a * ic ) + pc ), rs_a, + ( a + ( rs_a * ic ) + ( cs_a * pc ) ), rs_a, cs_a, mc0, kc0, &rs_a_use, &cs_a_use ); a_use = pack_a_buffer_u8s8s32o32; - a_block_stride = kc0_updated; + + if( cs_a == 1 ) + { + a_block_stride = kc0_updated; + } + + else + { + a_block_stride = rs_a_use; + } } else if ( mtag_a == REORDERED ) { @@ -365,7 +649,7 @@ LPGEMM_5LOOP(uint8_t,int8_t,int32_t,u8s8s32o32) for ( dim_t jr = 0; jr < nc0; jr += NR ) { - dim_t nr0 = bli_min( ( nc0 - jr ), NR ); + dim_t nr0 = bli_min((nc0 - jr), NR); // Post ops meta attributes. post_ops_attr.post_op_c_i = ic; diff --git a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_pack_bf16.h b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_pack_bf16.h index 1ceb833180..5073dbbabf 100644 --- a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_pack_bf16.h +++ b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_pack_bf16.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -47,6 +47,17 @@ BLIS_INLINE dim_t get_packb_bf16bf16f32of32_min_NR() return 16; } +typedef void (*pack_s4bf16)( + bfloat16 *, + const int8_t *, + const dim_t, + const dim_t, + dim_t *, + dim_t *, + lpgemm_pre_op*, + dim_t + ); + typedef void (*pack_bf16) ( bfloat16*, @@ -59,6 +70,19 @@ typedef void (*pack_bf16) dim_t* ); +typedef void (*pack_s4) + ( + int8_t*, + const int8_t*, + const dim_t, + const dim_t, + const dim_t, + const dim_t, + dim_t*, + dim_t*, + lpgemm_pre_op* + ); + void packb_nr64_bf16bf16f32of32 ( bfloat16* pack_b_buffer_bf16bf16f32of32, @@ -71,6 +95,30 @@ void packb_nr64_bf16bf16f32of32 dim_t* cs_p ); +void packb_nr64_bf16s4f32of32 + ( + int8_t* pack_b_buffer, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + const dim_t NC, + const dim_t KC, + dim_t* rs_p, + dim_t* cs_p, + lpgemm_pre_op* pre_op + ); + +void packsclb_nr64_bf16s4f32of32 + ( + bfloat16* packb_bf16, + const int8_t* b, + const dim_t NC, + const dim_t KC, + dim_t *rs_p, + dim_t *cs_p, + lpgemm_pre_op* b_pre_ops, + dim_t pre_op_off + ); void packa_mr16_bf16bf16f32of32 ( diff --git a/addon/aocl_gemm/kernels/f32f32f32/lpgemm_pack_f32.h b/addon/aocl_gemm/kernels/f32f32f32/lpgemm_pack_f32.h new file mode 100644 index 0000000000..3f799bba7a --- /dev/null +++ b/addon/aocl_gemm/kernels/f32f32f32/lpgemm_pack_f32.h @@ -0,0 +1,50 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ +#ifndef BLIS_GEMM_F32_PACKA +#define BLIS_GEMM_F32_PACKA + +void packa_mr16_f32f32f32of32_col_major + ( + float* pack_a_buffer, + const float* a, + const dim_t rs_a, + const dim_t cs_a, + const dim_t MC, + const dim_t KC, + dim_t* rs_p, + dim_t* cs_p + ); +#endif + + diff --git a/addon/aocl_gemm/kernels/lpgemm_eltwise_ops_kernels.h b/addon/aocl_gemm/kernels/lpgemm_eltwise_ops_kernels.h new file mode 100644 index 0000000000..7f5715e73f --- /dev/null +++ b/addon/aocl_gemm/kernels/lpgemm_eltwise_ops_kernels.h @@ -0,0 +1,81 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_LPGEMM_ELTWISE_OPS_KERN_H +#define BLIS_LPGEMM_ELTWISE_OPS_KERN_H + +#define LPGEMM_ELTWISE_OPS_KERNEL(A_type,B_type,LP_SFX) \ +void lpgemm_eltwise_ops_kernel_ ## LP_SFX \ + ( \ + const dim_t m0, \ + const dim_t n0, \ + const A_type* a, \ + const dim_t rs_a, \ + const dim_t cs_a, \ + B_type* b, \ + const dim_t rs_b, \ + const dim_t cs_b, \ + lpgemm_post_op* post_ops_list, \ + lpgemm_post_op_attr post_ops_attr \ + ) \ + +LPGEMM_ELTWISE_OPS_KERNEL(bfloat16,float,bf16of32_6x64); +LPGEMM_ELTWISE_OPS_KERNEL(float,float,f32of32_6x64); + +#define LPGEMM_ELTWISE_OPS_M_FRINGE_KERNEL(A_type,B_type,LP_SFX) \ +void lpgemm_eltwise_ops_kernel_ ## LP_SFX \ + ( \ + const dim_t n0, \ + const A_type* a, \ + const dim_t rs_a, \ + const dim_t cs_a, \ + B_type* b, \ + const dim_t rs_b, \ + const dim_t cs_b, \ + lpgemm_post_op* post_ops_list, \ + lpgemm_post_op_attr post_ops_attr \ + ) \ + +LPGEMM_ELTWISE_OPS_M_FRINGE_KERNEL(bfloat16,float,bf16of32_5x64); +LPGEMM_ELTWISE_OPS_M_FRINGE_KERNEL(bfloat16,float,bf16of32_4x64); +LPGEMM_ELTWISE_OPS_M_FRINGE_KERNEL(bfloat16,float,bf16of32_3x64); +LPGEMM_ELTWISE_OPS_M_FRINGE_KERNEL(bfloat16,float,bf16of32_2x64); +LPGEMM_ELTWISE_OPS_M_FRINGE_KERNEL(bfloat16,float,bf16of32_1x64); +LPGEMM_ELTWISE_OPS_M_FRINGE_KERNEL(float,float,f32of32_5x64); +LPGEMM_ELTWISE_OPS_M_FRINGE_KERNEL(float,float,f32of32_4x64); +LPGEMM_ELTWISE_OPS_M_FRINGE_KERNEL(float,float,f32of32_3x64); +LPGEMM_ELTWISE_OPS_M_FRINGE_KERNEL(float,float,f32of32_2x64); +LPGEMM_ELTWISE_OPS_M_FRINGE_KERNEL(float,float,f32of32_1x64); + +#endif //BLIS_LPGEMM_ELTWISE_OPS_KERN_H diff --git a/addon/aocl_gemm/kernels/lpgemm_kernels.h b/addon/aocl_gemm/kernels/lpgemm_kernels.h index 83132e8fbf..673df4e527 100644 --- a/addon/aocl_gemm/kernels/lpgemm_kernels.h +++ b/addon/aocl_gemm/kernels/lpgemm_kernels.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -38,6 +38,13 @@ #include "lpgemm_post_ops.h" #include "aocl_bf16_type.h" +// Disable BF16 kernel in cases where compilers support other avx 512 +// features except BF16 ISA. +#if ( defined( BLIS_GCC ) && ( ( __GNUC__ < 11 ) || \ + ( ( __GNUC__ == 11 ) && ( __GNUC_MINOR__ < 2 ) ) ) ) +#define LPGEMM_BF16_JIT +#endif + typedef void (*lpgemm_m_fringe_f32_ker_ft) ( const dim_t k0, @@ -52,7 +59,7 @@ typedef void (*lpgemm_m_fringe_f32_ker_ft) const float alpha, const float beta, lpgemm_post_op* post_ops_list, - lpgemm_post_op_attr post_ops_attr + lpgemm_post_op_attr post_ops_attr ); #define LPGEMM_MAIN_KERN(A_type,B_type,C_type,LP_SFX) \ @@ -80,11 +87,13 @@ void lpgemm_rowvar_ ## LP_SFX \ LPGEMM_MAIN_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x64); LPGEMM_MAIN_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x32); LPGEMM_MAIN_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_6x64); +LPGEMM_MAIN_KERN(bfloat16,int8_t,float,bf16s4f32of32_6x64m); LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_6x16m); LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_avx512_6x64m); LPGEMM_MAIN_KERN(int8_t,int8_t,int32_t,s8s8s32os32_6x64); LPGEMM_MAIN_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6x32); + #define LPGEMM_M_FRINGE_KERN(A_type,B_type,C_type,LP_SFX) \ void lpgemm_rowvar_ ## LP_SFX \ ( \ @@ -119,6 +128,12 @@ LPGEMM_M_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_3x64); LPGEMM_M_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_2x64); LPGEMM_M_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_1x64); +LPGEMM_M_FRINGE_KERN(bfloat16,int8_t,float,bf16s4f32of32_5x64); +LPGEMM_M_FRINGE_KERN(bfloat16,int8_t,float,bf16s4f32of32_4x64); +LPGEMM_M_FRINGE_KERN(bfloat16,int8_t,float,bf16s4f32of32_3x64); +LPGEMM_M_FRINGE_KERN(bfloat16,int8_t,float,bf16s4f32of32_2x64); +LPGEMM_M_FRINGE_KERN(bfloat16,int8_t,float,bf16s4f32of32_1x64); + LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_5x64); LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_4x64); LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_3x64); @@ -170,6 +185,29 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4x32); LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_2x32); LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_1x32); + +#define LPGEMM_N_LT_NR0_FRINGE_KERN1(A_type,B_type,C_type,LP_SFX) \ +void lpgemm_rowvar_ ## LP_SFX \ + ( \ + const dim_t k0, \ + const A_type* a, \ + const dim_t rs_a, \ + const dim_t cs_a, \ + const B_type* b, \ + const dim_t rs_b, \ + const dim_t cs_b, \ + C_type* c, \ + const dim_t rs_c, \ + const C_type alpha, \ + const C_type beta, \ + const dim_t n0_rem, \ + lpgemm_post_op* post_ops_list, \ + lpgemm_post_op_attr post_ops_attr \ + ) \ + +LPGEMM_N_LT_NR0_FRINGE_KERN1( bfloat16, int8_t, float, bf16s4f32of32_4xlt16 ); + + #define LPGEMM_N_FRINGE_KERN(A_type,B_type,C_type,LP_SFX) \ void lpgemm_rowvar_ ## LP_SFX \ ( \ @@ -202,6 +240,10 @@ LPGEMM_N_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_6x16); LPGEMM_N_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_6x32); LPGEMM_N_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_6x48); +LPGEMM_N_FRINGE_KERN(bfloat16,int8_t,float,bf16s4f32of32_6x16m); +LPGEMM_N_FRINGE_KERN(bfloat16,int8_t,float,bf16s4f32of32_6x32m); +LPGEMM_N_FRINGE_KERN(bfloat16,int8_t,float,bf16s4f32of32_6x48m); + LPGEMM_N_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_6x48m); LPGEMM_N_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_6x32m); LPGEMM_N_FRINGE_KERN(float,float,float,f32f32f32of32_6x8m); @@ -242,6 +284,8 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_12xlt16); LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6xlt16); LPGEMM_N_LT_NR0_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_6xlt16); +LPGEMM_N_LT_NR0_FRINGE_KERN(bfloat16,int8_t,float,bf16s4f32of32_6xlt16m); + LPGEMM_N_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_6xlt16); @@ -301,6 +345,22 @@ LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_3x48); LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_2x48); LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_1x48); +LPGEMM_MN_FRINGE_KERN(bfloat16,int8_t,float,bf16s4f32of32_5x16); +LPGEMM_MN_FRINGE_KERN(bfloat16,int8_t,float,bf16s4f32of32_4x16); +LPGEMM_MN_FRINGE_KERN(bfloat16,int8_t,float,bf16s4f32of32_3x16); +LPGEMM_MN_FRINGE_KERN(bfloat16,int8_t,float,bf16s4f32of32_2x16); +LPGEMM_MN_FRINGE_KERN(bfloat16,int8_t,float,bf16s4f32of32_1x16); +LPGEMM_MN_FRINGE_KERN(bfloat16,int8_t,float,bf16s4f32of32_5x32); +LPGEMM_MN_FRINGE_KERN(bfloat16,int8_t,float,bf16s4f32of32_4x32); +LPGEMM_MN_FRINGE_KERN(bfloat16,int8_t,float,bf16s4f32of32_3x32); +LPGEMM_MN_FRINGE_KERN(bfloat16,int8_t,float,bf16s4f32of32_2x32); +LPGEMM_MN_FRINGE_KERN(bfloat16,int8_t,float,bf16s4f32of32_1x32); +LPGEMM_MN_FRINGE_KERN(bfloat16,int8_t,float,bf16s4f32of32_5x48); +LPGEMM_MN_FRINGE_KERN(bfloat16,int8_t,float,bf16s4f32of32_4x48); +LPGEMM_MN_FRINGE_KERN(bfloat16,int8_t,float,bf16s4f32of32_3x48); +LPGEMM_MN_FRINGE_KERN(bfloat16,int8_t,float,bf16s4f32of32_2x48); +LPGEMM_MN_FRINGE_KERN(bfloat16,int8_t,float,bf16s4f32of32_1x48); + LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_5x16); LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_4x16); LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_3x16); @@ -356,6 +416,12 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_3xlt16); LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_2xlt16); LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_1xlt16); +LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16,int8_t,float,bf16s4f32of32_5xlt16); +LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16,int8_t,float,bf16s4f32of32_4xlt16); +LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16,int8_t,float,bf16s4f32of32_3xlt16); +LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16,int8_t,float,bf16s4f32of32_2xlt16); +LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16,int8_t,float,bf16s4f32of32_1xlt16); + LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_5xlt16); LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_4xlt16); LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_3xlt16); @@ -366,4 +432,66 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4xlt16); LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_2xlt16); LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_1xlt16); +#define LPGEMV_M_EQ1_KERN(A_type,B_type,C_type,LP_SFX) \ +void lpgemv_m_one_ ## LP_SFX \ +( \ + const dim_t n0, \ + const dim_t k, \ + const A_type *a, \ + const dim_t rs_a, \ + const dim_t cs_a, \ + const AOCL_MEMORY_TAG mtag_a, \ + const B_type *b, \ + dim_t rs_b, \ + const dim_t cs_b, \ + const AOCL_MEMORY_TAG mtag_b, \ + C_type *c, \ + const dim_t rs_c, \ + const dim_t cs_c, \ + const C_type alpha, \ + const C_type beta, \ + dim_t NR, \ + const dim_t KC, \ + const dim_t n_sub_updated, \ + const dim_t jc_cur_loop_rem, \ + lpgemm_post_op *post_op, \ + lpgemm_post_op_attr *post_op_attr \ + ) \ + +LPGEMV_M_EQ1_KERN(float, float, float,f32f32f32of32); +LPGEMV_M_EQ1_KERN(bfloat16,bfloat16,float,bf16bf16f32of32); +LPGEMV_M_EQ1_KERN(uint8_t,int8_t,int32_t,u8s8s32os32); +LPGEMV_M_EQ1_KERN(int8_t,int8_t,int32_t,s8s8s32os32); + +#define LPGEMV_N_EQ1_KERN(A_type,B_type,C_type,LP_SFX) \ +void lpgemv_n_one_ ## LP_SFX \ +( \ + const dim_t m0, \ + const dim_t k, \ + const A_type *a, \ + const dim_t rs_a, \ + const dim_t cs_a, \ + const AOCL_MEMORY_TAG mtag_a, \ + const B_type *b, \ + const dim_t rs_b, \ + const dim_t cs_b, \ + const AOCL_MEMORY_TAG mtag_b, \ + C_type *c, \ + const dim_t rs_c, \ + const dim_t cs_c, \ + const C_type alpha, \ + const C_type beta, \ + const dim_t MR, \ + const dim_t KC, \ + lpgemm_post_op *post_op, \ + lpgemm_post_op_attr *post_op_attr \ +) \ + +LPGEMV_N_EQ1_KERN(float, float, float,f32f32f32of32); +LPGEMV_N_EQ1_KERN(bfloat16, bfloat16, float,bf16bf16f32of32); +LPGEMV_N_EQ1_KERN(uint8_t,int8_t,int32_t,u8s8s32os32); +LPGEMV_N_EQ1_KERN(uint8_t,int8_t,int16_t,u8s8s16os16); +LPGEMV_N_EQ1_KERN(int8_t,int8_t,int32_t,s8s8s32os32); +LPGEMV_N_EQ1_KERN(int8_t,int8_t,int16_t,s8s8s16os16); + #endif //BLIS_LPGEMM_KERN_H diff --git a/addon/aocl_gemm/kernels/lpgemm_utils_kernels.h b/addon/aocl_gemm/kernels/lpgemm_utils_kernels.h index 7849e5a537..fdc2cc98b7 100644 --- a/addon/aocl_gemm/kernels/lpgemm_utils_kernels.h +++ b/addon/aocl_gemm/kernels/lpgemm_utils_kernels.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/addon/aocl_gemm/kernels/s8s8s32/lpgemm_packb_s8.h b/addon/aocl_gemm/kernels/s8s8s32/lpgemm_packb_s8.h index 661c153436..da3e5c62df 100644 --- a/addon/aocl_gemm/kernels/s8s8s32/lpgemm_packb_s8.h +++ b/addon/aocl_gemm/kernels/s8s8s32/lpgemm_packb_s8.h @@ -53,6 +53,7 @@ typedef void (*packb_s32_s8) const dim_t, const dim_t, const dim_t, + const dim_t, dim_t*, dim_t* ); @@ -62,11 +63,12 @@ void packb_nr64_s8s8s32os32 int8_t* pack_b_buffer_s8s8s32o32, int32_t* pack_b_column_sum, const int8_t* b, - const dim_t ldb, + const dim_t rs_b, + const dim_t cs_b, const dim_t NC, const dim_t KC, - dim_t* rs_b, - dim_t* cs_b + dim_t* rs_p, + dim_t* cs_p ); #endif //BLIS_GEMM_INT8_PACKB_S8 diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_packa_s16.h b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_packa_s16.h new file mode 100644 index 0000000000..a94a5aa132 --- /dev/null +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_packa_s16.h @@ -0,0 +1,62 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_GEMM_INT8_U8S8S16_PACKA +#define BLIS_GEMM_INT8_U8S8S16_PACKA + +typedef void (*packa_s16) + ( + uint8_t*, + const uint8_t*, + const dim_t, + const dim_t, + const dim_t, + const dim_t, + dim_t*, + dim_t* + ); + +void packa_u8s8s16os16 + ( + uint8_t* pack_a_buffer_u8s8s16o16, + const uint8_t* a, + const dim_t rs, + const dim_t cs, + const dim_t MC, + const dim_t KC, + dim_t* rs_a, + dim_t* cs_a + ); + +#endif //BLIS_GEMM_INT8_U8S8S16_PACKA diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_packa.h b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_packa.h index d0d507cbfb..3498c50688 100644 --- a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_packa.h +++ b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_packa.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -42,15 +42,17 @@ typedef void (*packa_s32) const dim_t, const dim_t, const dim_t, + const dim_t, dim_t*, dim_t* ); -void packa_k64_u8s8s32o32 +void packa_u8s8s32os32 ( uint8_t* pack_a_buffer_u8s8s32o32, const uint8_t* a, - const dim_t lda, + const dim_t rs, + const dim_t cs, const dim_t MC, const dim_t KC, dim_t* rs_a, diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_packb.h b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_packb.h index 2849cc8c33..d5246316ef 100644 --- a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_packb.h +++ b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_packb.h @@ -52,6 +52,7 @@ typedef void (*packb_s32) const dim_t, const dim_t, const dim_t, + const dim_t, dim_t*, dim_t* ); @@ -60,11 +61,24 @@ void packb_nr64_u8s8s32o32 ( int8_t* pack_b_buffer_u8s8s32o32, const int8_t* b, - const dim_t ldb, + const dim_t rs_b, + const dim_t cs_b, + const dim_t NC, + const dim_t KC, + dim_t* rs_p, + dim_t* cs_p + ); + +void packb_nr64_u8s4s32o32 + ( + int8_t* pack_b_buffer_u8s8s32o32, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, const dim_t NC, const dim_t KC, - dim_t* rs_b, - dim_t* cs_b + dim_t* rs_p, + dim_t* cs_p ); #endif //BLIS_GEMM_INT8_PACKB diff --git a/addon/gemmd/gemmd.h b/addon/gemmd/gemmd.h index cab61bd181..2aeca7fd71 100644 --- a/addon/gemmd/gemmd.h +++ b/addon/gemmd/gemmd.h @@ -14,7 +14,7 @@ - Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. - - Neither the name of copyright holder(s) nor the names + - Neither the name(s) of the copyright holder(s) nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. diff --git a/aocl_dtl/CMakeLists.txt b/aocl_dtl/CMakeLists.txt index 5b69f0e116..ec55db21f1 100644 --- a/aocl_dtl/CMakeLists.txt +++ b/aocl_dtl/CMakeLists.txt @@ -1,4 +1,36 @@ -##Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. ## +#[=[ + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +]=] # Collect all subdirectory paths that have at least one file with suffix in AOCLDTL_SRC_SUFS list. get_filepaths_with_suffixes(LOCAL_SOURCE_FILES ${CMAKE_CURRENT_SOURCE_DIR} "${AOCLDTL_SRC_SUFS}") @@ -50,10 +82,8 @@ elseif(THREADING_MODEL STREQUAL "pthreads") # in get-noopt-cflags-for target_compile_options(AOCL_DTL PRIVATE ${CTHREADFLAGS}) endif() -if(BUILD_SHARED_LIBS) - # Equivalent to CPICFLAGS in get-noopt-cflags-for - set_target_properties(AOCL_DTL PROPERTIES POSITION_INDEPENDENT_CODE ON) -endif() +# Equivalent to CPICFLAGS in get-noopt-cflags-for +set_target_properties(AOCL_DTL PROPERTIES POSITION_INDEPENDENT_CODE ON) add_dependencies(AOCL_DTL flat-header) # Put all those targets under object-libs-targets folder name so that they appear all together in IDE. set_target_properties(AOCL_DTL PROPERTIES FOLDER object-libs-targets) diff --git a/aocl_dtl/aocldtl_blis.c b/aocl_dtl/aocldtl_blis.c old mode 100755 new mode 100644 index 90be337f26..80c87b3650 --- a/aocl_dtl/aocldtl_blis.c +++ b/aocl_dtl/aocldtl_blis.c @@ -3,7 +3,7 @@ * * Description : BLIS library specific debug helpes. * - * Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. * *==================================================================*/ @@ -503,6 +503,7 @@ void AOCL_DTL_log_her_sizes(int8 loglevel, void AOCL_DTL_log_dotv_sizes(int8 loglevel, char dt_type, + const f77_char conjx, const f77_int n, const f77_int incx, const f77_int incy, @@ -512,8 +513,8 @@ void AOCL_DTL_log_dotv_sizes(int8 loglevel, { char buffer[256]; - // { n, incx, incy} - sprintf(buffer, "%c %ld %ld %ld\n", dt_type, (dim_t)n, (dim_t)incx, (dim_t)incy); + // { conjx, n, incx, incy} + sprintf(buffer, "%c %c %ld %ld %ld\n", dt_type, conjx, (dim_t)n, (dim_t)incx, (dim_t)incy); DTL_Trace(loglevel, TRACE_TYPE_LOG, function_name, function_name, line, buffer); } diff --git a/aocl_dtl/aocldtl_blis.h b/aocl_dtl/aocldtl_blis.h old mode 100755 new mode 100644 index 275ad0a484..d1679d7ce4 --- a/aocl_dtl/aocldtl_blis.h +++ b/aocl_dtl/aocldtl_blis.h @@ -3,7 +3,7 @@ * * Description : BLIS library specific debug helpes. * - * Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. * *==================================================================*/ @@ -298,6 +298,7 @@ void AOCL_DTL_log_axpy_sizes ( int8 loglevel, void AOCL_DTL_log_dotv_sizes( int8 loglevel, char dt_type, + const f77_char conjx, const f77_int n, const f77_int incx, const f77_int incy, @@ -517,9 +518,9 @@ void AOCL_DTL_log_trmm_sizes(int8 loglevel, AOCL_DTL_log_axpy_sizes(loglevel, dt_type, n, alpha, incx, incy, __FILE__,\ __FUNCTION__, __LINE__); -#define AOCL_DTL_LOG_DOTV_INPUTS(loglevel, dt_type, n, incx, incy) \ +#define AOCL_DTL_LOG_DOTV_INPUTS(loglevel, dt_type, conjx, n, incx, incy) \ if (gbIsLoggingEnabled) \ - AOCL_DTL_log_dotv_sizes(loglevel, dt_type, n, incx, incy, __FILE__, __FUNCTION__, __LINE__); \ + AOCL_DTL_log_dotv_sizes(loglevel, dt_type, conjx, n, incx, incy, __FILE__, __FUNCTION__, __LINE__); \ #define AOCL_DTL_LOG_SYR2_INPUTS(loglevel, dt_type, uploa, m, alpha, incx, incy, lda) \ if (gbIsLoggingEnabled) \ @@ -607,7 +608,7 @@ void AOCL_DTL_log_trmm_sizes(int8 loglevel, #define AOCL_DTL_LOG_AXPY_INPUTS(loglevel, dt_type, n, alpha, incx, incy) -#define AOCL_DTL_LOG_DOTV_INPUTS(loglevel, dt_type, n, incx, incy) +#define AOCL_DTL_LOG_DOTV_INPUTS(loglevel, dt_type, conjx, n, incx, incy) #define AOCL_DTL_LOG_SYR2_INPUTS(loglevel, dt_type, uploa, m, alpha, incx, incy, lda) diff --git a/aocl_dtl/aoclfal.c b/aocl_dtl/aoclfal.c index e96a42cf7c..b9eabe228f 100644 --- a/aocl_dtl/aoclfal.c +++ b/aocl_dtl/aoclfal.c @@ -1,9 +1,9 @@ /*=================================================================== * File Name : aoclfal.c * - * Description : Platform/os independed file handling API's + * Description : Platform/os independent file handling API's * - * Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. * *==================================================================*/ diff --git a/aocl_dtl/aoclfal.h b/aocl_dtl/aoclfal.h index c37b699be9..f14fb6e62e 100644 --- a/aocl_dtl/aoclfal.h +++ b/aocl_dtl/aoclfal.h @@ -1,10 +1,10 @@ /*=================================================================== * File Name : aoclfal.h * - * Description : Interfaces for platform/os independed file + * Description : Interfaces for platform/os independent file * handling API's * - * Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. * *==================================================================*/ diff --git a/bench/CMakeLists.txt b/bench/CMakeLists.txt index 4c6fed1140..d9106b8adc 100644 --- a/bench/CMakeLists.txt +++ b/bench/CMakeLists.txt @@ -1,104 +1,211 @@ -##Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.## +#[=[ -add_definitions(-DBLAS="AOCL") -add_definitions(-DN_REPEAT=1000) -add_definitions(-DINT_FS="%lld") -add_definitions(-DUINT_FS="%llu") + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. -add_executable(BenchAmaxv bench_amaxv.c) -target_link_libraries(BenchAmaxv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(BenchAmaxv OpenMP::OpenMP_CXX) -endif() -target_link_libraries(BenchAmaxv optimized "${LIB_NAME}.lib") + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. -add_executable(BenchAxpbyv bench_axpbyv.c) -target_link_libraries(BenchAxpbyv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(BenchAxpbyv OpenMP::OpenMP_CXX) -endif() -target_link_libraries(BenchAxpbyv optimized "${LIB_NAME}.lib") + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. -add_executable(BenchCopyv bench_copyv.c) -target_link_libraries(BenchCopyv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(BenchCopyv OpenMP::OpenMP_CXX) -endif() -target_link_libraries(BenchCopyv optimized "${LIB_NAME}.lib") + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -add_executable(BenchDotv bench_dotv.c) -target_link_libraries(BenchDotv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(BenchDotv OpenMP::OpenMP_CXX) -endif() -target_link_libraries(BenchDotv optimized "${LIB_NAME}.lib") +]=] -add_executable(BenchGemm bench_gemm.c) -target_link_libraries(BenchGemm debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(BenchGemm OpenMP::OpenMP_CXX) +# Comments: +# Set the path to the BLIS installation. +set(BLIS_INSTALL_PATH "" CACHE STRING "Setting the path to a BLIS installation that needs testing.") +if(BLIS_INSTALL_PATH) + message(STATUS "BLIS_INSTALL_PATH :" ${BLIS_INSTALL_PATH}) endif() -target_link_libraries(BenchGemm optimized "${LIB_NAME}.lib") -add_executable(BenchGemmt bench_gemmt.c) -target_link_libraries(BenchGemmt debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(BenchGemmt OpenMP::OpenMP_CXX) -endif() -target_link_libraries(BenchGemmt optimized "${LIB_NAME}.lib") +# - DIST_PATH is assumed to not exist if BLIS_INSTALL_PATH is given. +# - We must use recursively expanded assignment for LIB_PATH and INC_PATH in +# the second case because CONFIG_NAME is not yet set. +# Override the value of CINCFLAGS so that the value of CFLAGS returned by +# get-user-cflags-for() is not cluttered up with include paths needed only +# while building BLIS. -add_executable(BenchGemv bench_gemv.c) -target_link_libraries(BenchGemv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(BenchGemv OpenMP::OpenMP_CXX) +#if(NOT DEFINED BLIS_INSTALL_PATH) +if(BLIS_INSTALL_PATH STREQUAL "") + set(DIST_PATH ${CMAKE_BINARY_DIR}) + set(LIB_PATH ${DIST_PATH}/lib/${BLIS_CONFIG_FAMILY}) + set(INC_PATH ${DIST_PATH}/include/${BLIS_CONFIG_FAMILY}) + set(CINFLAGS ${INC_PATH}) + set(LIBBLIS ${libblis_link}) +else() + set(LIB_PATH ${BLIS_INSTALL_PATH}/lib) + set(INC_PATH ${BLIS_INSTALL_PATH}/include/${BLIS_CONFIG_FAMILY}) + set(CINFLAGS ${INC_PATH}) + # Set up the library name. + if(WIN32) + set(LIB_BLIS AOCL-LibBlis-Win) + else() + set(LIB_BLIS ${libblis_link}) + endif() + # Append if threading is required. + if(NOT (ENABLE_THREADING STREQUAL "no")) + if(WIN32) + string(APPEND LIB_BLIS -MT) + else() + string(APPEND LIB_BLIS -mt) + endif() + endif() + # Append for dll if necessary. + if(WIN32 AND BUILD_SHARED_LIBS) + string(APPEND LIB_BLIS -dll) + endif() + # Setting the suffix for find_library(). + if(WIN32) + set(LIB_BLIS .lib) + else() + if(BUILD_SHARED_LIBS) + string(APPEND LIB_BLIS .so) + else() + string(APPEND LIB_BLIS .a) + endif() + endif() + set(LIBBLIS ${LIB_PATH}/${LIB_BLIS}) + message(STATUS "BLIS_INSTALL_PATH : " ${LIBBLIS}) endif() -target_link_libraries(BenchGemv optimized "${LIB_NAME}.lib") -add_executable(BenchGer bench_ger.c) -target_link_libraries(BenchGer debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(BenchGer OpenMP::OpenMP_CXX) +if(WIN32) + set(LIBSUFFIX dll) +else() + set(LIBSUFFIX so) endif() -target_link_libraries(BenchGer optimized "${LIB_NAME}.lib") -add_executable(BenchNrm2 bench_nrm2.c) -target_link_libraries(BenchNrm2 debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(BenchNrm2 OpenMP::OpenMP_CXX) +set(NREPEATS "1000" CACHE STRING "Set no. of times loop repeats.") +set(MKL_PATH $ENV{MKLROOT} CACHE STRING "Set MKL_PATH.") +if(THREADING_MODEL STREQUAL "no") + set(MKL_THREAD "${MKL_PATH}/libmkl_sequential.${LIBSUFFIX}") +else() + set(MKL_THREAD "${MKL_PATH}/libmkl_gnu_thread.${LIBSUFFIX}") + set(MKL_OMP iomp5) endif() -target_link_libraries(BenchNrm2 optimized "${LIB_NAME}.lib") +set(INTEL_LP64 "${MKL_PATH}/libmkl_intel_lp64.${LIBSUFFIX}") +set(MKL_CORE "${MKL_PATH}/libmkl_core.${LIBSUFFIX}") +set(COMMON_LIBS pthread m dl ${MKL_OMP}) +set(MKL_LIB ${INTEL_LP64} ${MKL_CORE} ${MKL_THREAD} ${COMMON_LIBS}) +set(OPENBLAS_PATH "/home/amd/mylibs/openblas" CACHE STRING "Set OPENBLAS_PATH.") +set(OPENBLAS_LIB "${OPENBLAS_PATH}/libopenblas.${LIBSUFFIX}") +set(ATLAS_PATH "/home/amd/mylibs/atlas" CACHE STRING "Set ATLAS_PATH.") +set(F77BLAS_LIB "${ATLAS_PATH}/libf77blas.${LIBSUFFIX}") +set(ATLAS_LIB "${ATLAS_PATH}/libatlas.${LIBSUFFIX}") +set(ATLAS_LIB ${ATLAS_LIB} ${F77BLAS_LIB}) -add_executable(BenchScalv bench_scalv.c) -target_link_libraries(BenchScalv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(BenchScalv OpenMP::OpenMP_CXX) -endif() -target_link_libraries(BenchScalv optimized "${LIB_NAME}.lib") -add_executable(BenchSwapv bench_swapv.c) -target_link_libraries(BenchSwapv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(BenchSwapv OpenMP::OpenMP_CXX) -endif() -target_link_libraries(BenchSwapv optimized "${LIB_NAME}.lib") +# Include the corresponding make_defs.cmake that holds the required compiler options. +include(${CMAKE_SOURCE_DIR}/config/${BLIS_CONFIG_FAMILY}/make_defs.cmake) -add_executable(BenchSyrk bench_syrk.c) -target_link_libraries(BenchSyrk debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(BenchSyrk OpenMP::OpenMP_CXX) -endif() -target_link_libraries(BenchSyrk optimized "${LIB_NAME}.lib") +# Gather all local source files. +file(GLOB file_list LIST_DIRECTORIES false RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}/" "*.c") -add_executable(BenchTrsm bench_trsm.c) -target_link_libraries(BenchTrsm debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(BenchTrsm OpenMP::OpenMP_CXX) +# Defining the format specifiers to read long double value from input file using fscanf +if (WIN32 AND ((INT_SIZE STREQUAL "auto") OR (INT_SIZE STREQUAL "64"))) + set(BENCH_FLAGS -DN_REPEAT=${NREPEATS} -DINT_FS="%lld" -DUINT_FS="%llu") +elseif ((INT_SIZE STREQUAL "auto") OR (INT_SIZE STREQUAL "64")) + set(BENCH_FLAGS -DN_REPEAT=${NREPEATS} -DINT_FS="%ld" -DUINT_FS="%lu") +else() + set(BENCH_FLAGS -DN_REPEAT=${NREPEATS} -DINT_FS="%d" -DUINT_FS="%u") endif() -target_link_libraries(BenchTrsm optimized "${LIB_NAME}.lib") -add_executable(BenchTrsv bench_trsv.c) -target_link_libraries(BenchTrsv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(BenchTrsv OpenMP::OpenMP_CXX) +# Create an executable using the sources above. +function(benchexe extn) + set(dblas "aocl") + if(extn STREQUAL "mkl") + set(BLAS_LIBS ${MKL_LIB}) + set(dblas ${extn}) + elseif(extn STREQUAL "openblas") + set(BLAS_LIBS ${OPENBLAS_LIB}) + set(dblas ${extn}) + elseif(extn STREQUAL "atlas") + set(BLAS_LIBS ${ATLAS_LIB}) + set(dblas ${extn}) + endif() + set(BENCH_FLAGS "${BENCH_FLAGS}" -DBLAS="${dblas}") + foreach(src ${file_list}) + string(REGEX REPLACE ".c$" "" exec_name ${src}) + set(exec_name "${exec_name}_${extn}") + add_executable(${exec_name}.x ${src}) + target_compile_options(${exec_name}.x + PRIVATE + # load-var-for,COPTFLAGS + ${COPTFLAGS} + ) + if(WIN32 AND BUILD_SHARED_LIBS) + target_compile_definitions(${exec_name}.x + PRIVATE + # in get-noopt-cflags-for + ${VERS_DEF} + "-DBLIS_EXPORT=__declspec(dllimport)" + ${BENCH_FLAGS} + ) + else() + target_compile_definitions(${exec_name}.x + PRIVATE + # in get-noopt-cflags-for + ${VERS_DEF} + ${BENCH_FLAGS} + ) + endif() + target_include_directories(${exec_name}.x + BEFORE + PRIVATE + # in get-noopt-cflags-for + ${CINFLAGS} + ) + target_link_libraries(${exec_name}.x PRIVATE ${BLAS_LIBS} ${LIBBLIS} ${LDFLAGS}) + if(THREADING_MODEL STREQUAL "openmp") + if((NOT ${OpenMP_libomp_LIBRARY} STREQUAL "") AND (NOT WIN32)) + target_link_libraries(${exec_name}.x PRIVATE ${OpenMP_libomp_LIBRARY}) + else() + target_link_libraries(${exec_name}.x PRIVATE OpenMP::OpenMP_C) + endif() + endif() + list(APPEND temp_executables ${exec_name}.x) + endforeach() + set(bench_executables ${temp_executables} PARENT_SCOPE) +endfunction() + +benchexe("blis") +add_custom_target(bench_blis DEPENDS ${bench_executables}) +benchexe("mkl") +add_custom_target(bench_mkl DEPENDS ${bench_executables}) +benchexe("openblas") +add_custom_target(bench_openblas DEPENDS ${bench_executables}) +benchexe("atlas") +add_custom_target(bench_atlas DEPENDS ${bench_executables}) +add_custom_target(benchmark DEPENDS bench_blis bench_mkl bench_openblas) + +# Put all those targets under bench-targets folder name so that they appear all together in IDE. +# NOTE : To run bench for atlas, add bench_atlas to the bench-targets +set_target_properties(benchmark bench_blis bench_mkl bench_openblas PROPERTIES FOLDER bench-targets) + +# Add bench_aocl_gemm only if aocl_gemm is in the ENABLE_ADDON list. +# This needs to work in cases where both aocl_gemm and gemmd are requested. +# lpgemm_index will be -1 if it's not found in ENABLE_ADDON list. +list(FIND ENABLE_ADDON "aocl_gemm" lpgemm_index) +if(NOT (lpgemm_index STREQUAL -1)) + add_subdirectory(bench_aocl_gemm EXCLUDE_FROM_ALL) endif() -target_link_libraries(BenchTrsv optimized "${LIB_NAME}.lib") diff --git a/bench/Makefile b/bench/Makefile old mode 100755 new mode 100644 index cc1b7297dc..aeeb6f615b --- a/bench/Makefile +++ b/bench/Makefile @@ -1,4 +1,3 @@ - # # # BLIS @@ -6,7 +5,7 @@ # libraries. # # Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2017 - 2023, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2017 - 2024, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -194,6 +193,7 @@ blis: \ bench_copyv_blis.x \ bench_swapv_blis.x \ bench_axpbyv_blis.x \ + bench_axpyv_blis.x \ bench_gemm_pack_compute_blis.x openblas: \ @@ -210,7 +210,8 @@ openblas: \ bench_amaxv_openblas.x \ bench_copyv_openblas.x \ bench_swapv_openblas.x \ - bench_axpbyv_openblas.x + bench_axpbyv_openblas.x \ + bench_axpyv_openblas.x atlas: \ bench_gemm_atlas.x \ @@ -225,7 +226,8 @@ atlas: \ bench_amaxv_atlas.x \ bench_copyv_atlas.x \ bench_swapv_atlas.x \ - bench_axpbyv_atlax.x + bench_axpbyv_atlas.x \ + bench_axpyv_atlas.x mkl: \ bench_gemm_mkl.x \ @@ -242,6 +244,7 @@ mkl: \ bench_copyv_mkl.x \ bench_swapv_mkl.x \ bench_axpbyv_mkl.x \ + bench_axpyv_mkl.x \ bench_gemm_pack_compute_mkl.x diff --git a/bench/bench_amaxv.c b/bench/bench_amaxv.c index c4df0cd4d7..d803a36ec8 100644 --- a/bench/bench_amaxv.c +++ b/bench/bench_amaxv.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -14,7 +14,7 @@ - Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. - - Neither the name of The University of Texas nor the names of its + - Neither the name(s) of the copyright holder(s) nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. @@ -94,7 +94,7 @@ int main( int argc, char** argv ) exit(1); } - fprintf(fout, "Dt\t n\t incx\t gflops\n"); + fprintf(fout, "Func Dt n incx max_index gflops\n"); dim_t n; inc_t incx; diff --git a/bench/bench_aocl_gemm/CMakeLists.txt b/bench/bench_aocl_gemm/CMakeLists.txt new file mode 100644 index 0000000000..64380b8744 --- /dev/null +++ b/bench/bench_aocl_gemm/CMakeLists.txt @@ -0,0 +1,95 @@ +#[=[ + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +]=] + +# Comments: +# Gather all local source files. +file(GLOB file_list LIST_DIRECTORIES false RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}/" "*.c") + +# Defining the format specifiers to read long double value from input file using fscanf +if (WIN32 AND ((INT_SIZE STREQUAL "auto") OR (INT_SIZE STREQUAL "64"))) + set(LPGEMM_FLAGS -DBLAS="aocl" -DN_REPEAT=${NREPEATS} -DINT_FS="%lld" -DUINT_FS="%llu") +elseif ((INT_SIZE STREQUAL "auto") OR (INT_SIZE STREQUAL "64")) + set(LPGEMM_FLAGS -DBLAS="aocl" -DN_REPEAT=${NREPEATS} -DINT_FS="%ld" -DUINT_FS="%lu") +else() + set(LPGEMM_FLAGS -DBLAS="aocl" -DN_REPEAT=${NREPEATS} -DINT_FS="%d" -DUINT_FS="%u") +endif() + +# Create an executable using the sources above. +function(lpgemmbenchexe extn) + foreach(src ${file_list}) + string(REGEX REPLACE ".c$" "" exec_name ${src}) + set(exec_name "${exec_name}_${extn}") + add_executable(${exec_name}.x ${src}) + target_compile_options(${exec_name}.x + PRIVATE + # load-var-for,COPTFLAGS + ${COPTFLAGS} + ) + if(WIN32 AND BUILD_SHARED_LIBS) + target_compile_definitions(${exec_name}.x + PRIVATE + # in get-noopt-cflags-for + ${VERS_DEF} + "-DBLIS_EXPORT=__declspec(dllimport)" + ${LPGEMM_FLAGS} + ) + else() + target_compile_definitions(${exec_name}.x + PRIVATE + # in get-noopt-cflags-for + ${VERS_DEF} + ${LPGEMM_FLAGS} + ) + endif() + target_include_directories(${exec_name}.x + BEFORE + PRIVATE + # in get-noopt-cflags-for + ${CINFLAGS} + ) + target_link_libraries(${exec_name}.x PRIVATE ${LIBBLIS} ${LDFLAGS}) + if(THREADING_MODEL STREQUAL "openmp") + target_link_libraries(${exec_name}.x PRIVATE OpenMP::OpenMP_C) + endif() + list(APPEND temp_executables ${exec_name}.x) + endforeach() + set(bench_executables ${temp_executables} PARENT_SCOPE) +endfunction() + +lpgemmbenchexe("blis") +add_custom_target(lpgemm_blis DEPENDS ${bench_executables}) +add_custom_target(benchmark_lpgemm DEPENDS lpgemm_blis) + +# Put all those targets under bench_aocl_gemm-targets folder name so that they appear all together in IDE. +set_target_properties(benchmark_lpgemm lpgemm_blis PROPERTIES FOLDER bench_aocl_gemm-targets) diff --git a/bench/bench_aocl_gemm/Makefile b/bench/bench_aocl_gemm/Makefile old mode 100755 new mode 100644 index 897a982ba3..c8c2b732a1 --- a/bench/bench_aocl_gemm/Makefile +++ b/bench/bench_aocl_gemm/Makefile @@ -4,7 +4,7 @@ # An object-based framework for developing high-performance BLAS-like # libraries. # -# Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -30,6 +30,7 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # +# # Makefile for lpgemm bench. # @@ -107,7 +108,8 @@ all: blis blis: \ bench_lpgemm_blis.x \ - bench_lpgemm_utils_blis.x + bench_lpgemm_utils_blis.x \ + bench_lpgemm_eltwise_ops_blis.x # --Object file rules -- diff --git a/bench/bench_aocl_gemm/bench_eltwise_ops_input.txt b/bench/bench_aocl_gemm/bench_eltwise_ops_input.txt new file mode 100644 index 0000000000..4ffe331451 --- /dev/null +++ b/bench/bench_aocl_gemm/bench_eltwise_ops_input.txt @@ -0,0 +1,36 @@ +r n n 577 2057 2057 2057 *:scale=vector,zp=scalar,bias +r n n 577 2064 2064 2064 *:scale=vector,zp=scalar,bias +r n n 577 2080 2080 2080 *:bias +r n n 577 2096 2096 2096 *:scale=scalar,zp=vector,bias +r n n 577 2112 2112 2112 *:bias +r n n 577 2067 2067 2067 *:scale=vector,zp=scalar,bias +r n n 577 2085 2085 2085 *:bias +r n n 577 2099 2099 2099 *:bias +r n n 577 2118 2118 2118 *:scale=vector,zp=scalar,bias,gelu_tanh,clip +r n n 578 2057 2057 2057 *:bias,gelu_tanh,clip +r n n 578 2064 2064 2064 *:bias,gelu_tanh,clip +r n n 578 2080 2080 2080 *:bias,gelu_tanh,clip +r n n 578 2096 2096 2096 *:scale=vector,zp=scalar,bias,gelu_tanh,clip +r n n 578 2112 2112 2112 *:bias,gelu_tanh,clip +r n n 578 2067 2067 2067 *:scale=scalar,zp=vector,bias,gelu_tanh,clip +r n n 578 2085 2085 2085 *:bias,gelu_tanh,clip +r n n 578 2099 2099 2099 *:scale=scalar,zp=vector,bias,gelu_tanh,clip +r n n 578 2118 2118 2118 *:bias,gelu_tanh,clip +r n n 579 2057 2057 2057 *:scale=scalar,zp=scalar,bias,gelu_tanh,clip +r n n 579 2064 2064 2064 *:bias,gelu_tanh,clip +r n n 579 2080 2080 2080 *:scale=vector,zp=vector,bias,gelu_tanh,clip +r n n 579 2096 2096 2096 *:bias,gelu_tanh,clip +r n n 579 2112 2112 2112 *:scale=vector,zp=vector,bias,gelu_tanh,clip +r n n 579 2067 2067 2067 *:bias,gelu_tanh,clip +r n n 579 2085 2085 2085 *:scale=scalar,zp=scalar,bias,gelu_tanh,clip +r n n 579 2099 2099 2099 *:bias,gelu_tanh,clip +r n n 579 2118 2118 2118 *:bias,gelu_tanh,clip +r n n 581 2057 2057 2057 *:bias,clip +r n n 581 2064 2064 2064 *:scale=scalar,zp=vector,bias,clip +r n n 581 2080 2080 2080 *:bias,clip +r n n 581 2096 2096 2096 *:scale=scalar,zp=scalar,bias,clip +r n n 581 2112 2112 2112 *:scale=scalar,zp=scalar,bias,clip +r n n 581 2067 2067 2067 *:scale=scalar,zp=scalar,bias,clip +r n n 581 2085 2085 2085 *:bias,clip +r n n 581 2099 2099 2099 *:scale=vector,zp=vector,bias,clip +r n n 581 2118 2118 2118 *:bias,clip diff --git a/bench/bench_aocl_gemm/bench_input.txt b/bench/bench_aocl_gemm/bench_input.txt index fbde59de5a..a92ed6f75c 100644 --- a/bench/bench_aocl_gemm/bench_input.txt +++ b/bench/bench_aocl_gemm/bench_input.txt @@ -1,3782 +1,27 @@ -u s32 r n n n p 480 20 2050 2050 20 20 none -u s8 r n n n p 480 20 2050 2050 20 20 none -u s32 s8 r n n n p 480 20 2050 2050 20 20 bias,relu,clip -u s8 s8 r n n n p 480 20 2050 2050 20 20 bias,relu,clip -u s32 r n n n p 481 20 2050 2050 20 20 none -u s8 r n n n p 481 20 2050 2050 20 20 none -u s32 s8 r n n n p 481 20 2050 2050 20 20 bias,relu,clip -u s8 s8 r n n n p 481 20 2050 2050 20 20 bias,relu,clip -u s32 r n n n p 482 20 2050 2050 20 20 none -u s8 r n n n p 482 20 2050 2050 20 20 none -u s32 s8 r n n n p 482 20 2050 2050 20 20 bias,relu,clip -u s8 s8 r n n n p 482 20 2050 2050 20 20 bias,relu,clip -u s32 r n n n p 483 20 2050 2050 20 20 none -u s8 r n n n p 483 20 2050 2050 20 20 none -u s32 s8 r n n n p 483 20 2050 2050 20 20 bias,relu,clip -u s8 s8 r n n n p 483 20 2050 2050 20 20 bias,relu,clip -u s32 r n n n R 484 20 2050 2050 20 20 none -u s8 r n n n R 484 20 2050 2050 20 20 none -u s32 s8 r n n n R 484 20 2050 2050 20 20 bias,relu,clip -u s8 s8 r n n n R 484 20 2050 2050 20 20 bias,relu,clip -u s32 r n n n R 485 20 2050 2050 20 20 none -u s8 r n n n R 485 20 2050 2050 20 20 none -u s32 s8 r n n n R 485 20 2050 2050 20 20 bias,relu,clip -u s8 s8 r n n n R 485 20 2050 2050 20 20 bias,relu,clip -u s32 r n n n R 480 39 2050 2050 39 39 none -u s8 r n n n R 480 39 2050 2050 39 39 none -u s32 s8 r n n n R 480 39 2050 2050 39 39 bias,relu,clip -u s8 s8 r n n n R 480 39 2050 2050 39 39 bias,relu,clip -u s32 r n n n R 481 39 2050 2050 39 39 none -u s8 r n n n R 481 39 2050 2050 39 39 none -u s32 s8 r n n n R 481 39 2050 2050 39 39 bias,relu,clip -u s8 s8 r n n n R 481 39 2050 2050 39 39 bias,relu,clip -u s32 r n n n R 482 39 2050 2050 39 39 none -u s8 r n n n R 482 39 2050 2050 39 39 none -u s32 s8 r n n n R 482 39 2050 2050 39 39 bias,relu,clip -u s8 s8 r n n n R 482 39 2050 2050 39 39 bias,relu,clip -u s32 r n n n R 483 39 2050 2050 39 39 none -u s8 r n n n R 483 39 2050 2050 39 39 none -u s32 s8 r n n n R 483 39 2050 2050 39 39 bias,relu,clip -u s8 s8 r n n n R 483 39 2050 2050 39 39 bias,relu,clip -u s32 r n n n R 484 39 2050 2050 39 39 none -u s8 r n n n R 484 39 2050 2050 39 39 none -u s32 s8 r n n n R 484 39 2050 2050 39 39 bias,relu,clip -u s8 s8 r n n n R 484 39 2050 2050 39 39 bias,relu,clip -u s32 r n n n p 485 39 2050 2050 39 39 none -u s8 r n n n p 485 39 2050 2050 39 39 none -u s32 s8 r n n n p 485 39 2050 2050 39 39 bias,relu,clip -u s8 s8 r n n n p 485 39 2050 2050 39 39 bias,relu,clip -u s32 r n n n p 480 50 2050 2050 50 50 none -u s8 r n n n p 480 50 2050 2050 50 50 none -u s32 s8 r n n n p 480 50 2050 2050 50 50 bias,relu,clip -u s8 s8 r n n n p 480 50 2050 2050 50 50 bias,relu,clip -u s32 r n n n p 481 50 2050 2050 50 50 none -u s8 r n n n p 481 50 2050 2050 50 50 none -u s32 s8 r n n n p 481 50 2050 2050 50 50 bias,relu,clip -u s8 s8 r n n n p 481 50 2050 2050 50 50 bias,relu,clip -u s32 r n n n p 482 50 2050 2050 50 50 none -u s8 r n n n p 482 50 2050 2050 50 50 none -u s32 s8 r n n n p 482 50 2050 2050 50 50 bias,relu,clip -u s8 s8 r n n n p 482 50 2050 2050 50 50 bias,relu,clip -u s32 r n n n p 483 50 2050 2050 50 50 none -u s8 r n n n p 483 50 2050 2050 50 50 none -u s32 s8 r n n n p 483 50 2050 2050 50 50 bias,relu,clip -u s8 s8 r n n n p 483 50 2050 2050 50 50 bias,relu,clip -u s32 r n n n p 484 50 2050 2050 50 50 none -u s8 r n n n p 484 50 2050 2050 50 50 none -u s32 s8 r n n n p 484 50 2050 2050 50 50 bias,relu,clip -u s8 s8 r n n n p 484 50 2050 2050 50 50 bias,relu,clip -u s32 r n n n p 485 50 2050 2050 50 50 none -u s8 r n n n p 485 50 2050 2050 50 50 none -u s32 s8 r n n n p 485 50 2050 2050 50 50 bias,relu,clip -u s8 s8 r n n n p 485 50 2050 2050 50 50 bias,relu,clip -u s32 r n n n R 480 1108 2050 2050 1108 1108 none -u s8 r n n n R 480 1108 2050 2050 1108 1108 none -u s32 s8 r n n n R 480 1108 2050 2050 1108 1108 bias,relu,clip -u s8 s8 r n n n R 480 1108 2050 2050 1108 1108 bias,relu,clip -u s32 r n n n R 481 1108 2050 2050 1108 1108 none -u s8 r n n n R 481 1108 2050 2050 1108 1108 none -u s32 s8 r n n n R 481 1108 2050 2050 1108 1108 bias,relu,clip -u s8 s8 r n n n R 481 1108 2050 2050 1108 1108 bias,relu,clip -u s32 r n n n R 482 1108 2050 2050 1108 1108 none -u s8 r n n n R 482 1108 2050 2050 1108 1108 none -u s32 s8 r n n n R 482 1108 2050 2050 1108 1108 bias,relu,clip -u s8 s8 r n n n R 482 1108 2050 2050 1108 1108 bias,relu,clip -u s32 r n n n R 483 1108 2050 2050 1108 1108 none -u s8 r n n n R 483 1108 2050 2050 1108 1108 none -u s32 s8 r n n n R 483 1108 2050 2050 1108 1108 bias,relu,clip -u s8 s8 r n n n R 483 1108 2050 2050 1108 1108 bias,relu,clip -u s32 r n n n R 484 1108 2050 2050 1108 1108 none -u s8 r n n n R 484 1108 2050 2050 1108 1108 none -u s32 s8 r n n n R 484 1108 2050 2050 1108 1108 bias,relu,clip -u s8 s8 r n n n R 484 1108 2050 2050 1108 1108 bias,relu,clip -u s32 r n n n R 485 1108 2050 2050 1108 1108 none -u s8 r n n n R 485 1108 2050 2050 1108 1108 none -u s32 s8 r n n n R 485 1108 2050 2050 1108 1108 bias,relu,clip -u s8 s8 r n n n R 485 1108 2050 2050 1108 1108 bias,relu,clip -u s32 r n n n R 480 1127 2050 2050 1127 1127 none -u s8 r n n n R 480 1127 2050 2050 1127 1127 none -u s32 s8 r n n n R 480 1127 2050 2050 1127 1127 bias,relu,clip -u s8 s8 r n n n R 480 1127 2050 2050 1127 1127 bias,relu,clip -u s32 r n n n R 481 1127 2050 2050 1127 1127 none -u s8 r n n n R 481 1127 2050 2050 1127 1127 none -u s32 s8 r n n n R 481 1127 2050 2050 1127 1127 bias,relu,clip -u s8 s8 r n n n R 481 1127 2050 2050 1127 1127 bias,relu,clip -u s32 r n n n R 482 1127 2050 2050 1127 1127 none -u s8 r n n n R 482 1127 2050 2050 1127 1127 none -u s32 s8 r n n n R 482 1127 2050 2050 1127 1127 bias,relu,clip -u s8 s8 r n n n R 482 1127 2050 2050 1127 1127 bias,relu,clip -u s32 r n n n R 483 1127 2050 2050 1127 1127 none -u s8 r n n n R 483 1127 2050 2050 1127 1127 none -u s32 s8 r n n n R 483 1127 2050 2050 1127 1127 bias,relu,clip -u s8 s8 r n n n R 483 1127 2050 2050 1127 1127 bias,relu,clip -u s32 r n n n p 484 1127 2050 2050 1127 1127 none -u s8 r n n n p 484 1127 2050 2050 1127 1127 none -u s32 s8 r n n n p 484 1127 2050 2050 1127 1127 bias,relu,clip -u s8 s8 r n n n p 484 1127 2050 2050 1127 1127 bias,relu,clip -u s32 r n n n p 485 1127 2050 2050 1127 1127 none -u s8 r n n n p 485 1127 2050 2050 1127 1127 none -u s32 s8 r n n n p 485 1127 2050 2050 1127 1127 bias,relu,clip -u s8 s8 r n n n p 485 1127 2050 2050 1127 1127 bias,relu,clip -u s32 r n n n p 480 1138 2050 2050 1138 1138 none -u s8 r n n n p 480 1138 2050 2050 1138 1138 none -u s32 s8 r n n n p 480 1138 2050 2050 1138 1138 bias,relu,clip -u s8 s8 r n n n p 480 1138 2050 2050 1138 1138 bias,relu,clip -u s32 r n n n p 481 1138 2050 2050 1138 1138 none -u s8 r n n n p 481 1138 2050 2050 1138 1138 none -u s32 s8 r n n n p 481 1138 2050 2050 1138 1138 bias,relu,clip -u s8 s8 r n n n p 481 1138 2050 2050 1138 1138 bias,relu,clip -u s32 r n n n p 482 1138 2050 2050 1138 1138 none -u s8 r n n n p 482 1138 2050 2050 1138 1138 none -u s32 s8 r n n n p 482 1138 2050 2050 1138 1138 bias,relu,clip -u s8 s8 r n n n p 482 1138 2050 2050 1138 1138 bias,relu,clip -u s32 r n n n p 483 1138 2050 2050 1138 1138 none -u s8 r n n n p 483 1138 2050 2050 1138 1138 none -u s32 s8 r n n n p 483 1138 2050 2050 1138 1138 bias,relu,clip -u s8 s8 r n n n p 483 1138 2050 2050 1138 1138 bias,relu,clip -u s32 r n n n p 484 1138 2050 2050 1138 1138 none -u s8 r n n n p 484 1138 2050 2050 1138 1138 none -u s32 s8 r n n n p 484 1138 2050 2050 1138 1138 bias,relu,clip -u s8 s8 r n n n p 484 1138 2050 2050 1138 1138 bias,relu,clip -u s32 r n n n p 485 1138 2050 2050 1138 1138 none -u s8 r n n n p 485 1138 2050 2050 1138 1138 none -u s32 s8 r n n n p 485 1138 2050 2050 1138 1138 bias,relu,clip -u s8 s8 r n n n p 485 1138 2050 2050 1138 1138 bias,relu,clip -u s32 r n n n p 1 1 3 3 1 1 none -u s8 r n n n p 1 1 3 3 1 1 none -u s32 s8 r n n n p 1 1 3 3 1 1 bias,relu,clip -u s8 s8 r n n n p 1 1 3 3 1 1 bias,relu,clip -u s32 r n n n p 1 9 3 3 9 9 none -u s8 r n n n p 1 9 3 3 9 9 none -u s32 s8 r n n n p 1 9 3 3 9 9 bias,relu,clip -u s8 s8 r n n n p 1 9 3 3 9 9 bias,relu,clip -u s32 r n n n p 1 2048 3 3 2048 2048 none -u s8 r n n n p 1 2048 3 3 2048 2048 none -u s32 s8 r n n n p 1 2048 3 3 2048 2048 bias,relu,clip -u s8 s8 r n n n p 1 2048 3 3 2048 2048 bias,relu,clip -u s32 r n n n p 1 2048 5192 5192 2048 2048 none -u s8 r n n n p 1 2048 5192 5192 2048 2048 none -u s32 s8 r n n n p 1 2048 5192 5192 2048 2048 bias,relu,clip -u s8 s8 r n n n p 1 2048 5192 5192 2048 2048 bias,relu,clip -u s32 r n n n p 9 1 3 3 1 1 none -u s8 r n n n p 9 1 3 3 1 1 none -u s32 s8 r n n n p 9 1 3 3 1 1 bias,relu,clip -u s8 s8 r n n n p 9 1 3 3 1 1 bias,relu,clip -u s32 r n n n p 576 1 3500 3500 1 1 none -u s8 r n n n p 576 1 3500 3500 1 1 none -u s32 s8 r n n n p 576 1 3500 3500 1 1 bias,relu,clip -u s8 s8 r n n n p 576 1 3500 3500 1 1 bias,relu,clip -u s32 r n n n p 1 1 1 1 1 1 none -u s8 r n n n p 1 1 1 1 1 1 none -u s32 s8 r n n n p 1 1 1 1 1 1 bias,relu,clip -u s8 s8 r n n n p 1 1 1 1 1 1 bias,relu,clip -u s32 r n n n p 102 1088 1024 1024 1088 1088 none -u s8 r n n n p 102 1088 1024 1024 1088 1088 none -u s32 s8 r n n n p 102 1088 1024 1024 1088 1088 bias,relu,clip -u s8 s8 r n n n p 102 1088 1024 1024 1088 1088 bias,relu,clip -u s32 r n n n p 102 2048 1024 1024 2048 2048 none -u s8 r n n n p 102 2048 1024 1024 2048 2048 none -u s32 s8 r n n n p 102 2048 1024 1024 2048 2048 bias,relu,clip -u s8 s8 r n n n p 102 2048 1024 1024 2048 2048 bias,relu,clip -u s32 r n n n p 485 656 1024 1024 656 656 none -u s8 r n n n p 485 656 1024 1024 656 656 none -u s32 s8 r n n n p 485 656 1024 1024 656 656 bias,relu,clip -u s8 s8 r n n n p 485 656 1024 1024 656 656 bias,relu,clip -u s32 r n n n p 483 656 1024 1024 656 656 none -u s8 r n n n p 483 656 1024 1024 656 656 none -u s32 s8 r n n n p 483 656 1024 1024 656 656 bias,relu,clip -u s8 s8 r n n n p 483 656 1024 1024 656 656 bias,relu,clip -u s32 r n n n p 81 128 3 3 128 128 none -u s8 r n n n p 81 128 3 3 128 128 none -u s32 s8 r n n n p 81 128 3 3 128 128 bias,relu,clip -u s8 s8 r n n n p 81 128 3 3 128 128 bias,relu,clip -u s32 r n n n p 1022 512 515 515 512 512 none -u s8 r n n n p 1022 512 515 515 512 512 none -u s32 s8 r n n n p 1022 512 515 515 512 512 bias,relu,clip -u s8 s8 r n n n p 1022 512 515 515 512 512 bias,relu,clip -u s32 r n n n p 74 512 515 515 512 512 none -u s8 r n n n p 74 512 515 515 512 512 none -u s32 s8 r n n n p 74 512 515 515 512 512 bias,relu,clip -u s8 s8 r n n n p 74 512 515 515 512 512 bias,relu,clip -u s32 r n n n p 253 2048 515 515 2048 2048 none -u s8 r n n n p 253 2048 515 515 2048 2048 none -u s32 s8 r n n n p 253 2048 515 515 2048 2048 bias,relu,clip -u s8 s8 r n n n p 253 2048 515 515 2048 2048 bias,relu,clip -u s32 r n n n p 8192 1040 515 515 1040 1040 none -u s8 r n n n p 8192 1040 515 515 1040 1040 none -u s32 s8 r n n n p 8192 1040 515 515 1040 1040 bias,relu,clip -u s8 s8 r n n n p 8192 1040 515 515 1040 1040 bias,relu,clip -u s32 r n n n p 10 1029 515 515 1029 1029 none -u s8 r n n n p 10 1029 515 515 1029 1029 none -u s32 s8 r n n n p 10 1029 515 515 1029 1029 bias,relu,clip -u s8 s8 r n n n p 10 1029 515 515 1029 1029 bias,relu,clip -u s32 r n n n p 24 1040 2050 2050 1040 1040 none -u s8 r n n n p 24 1040 2050 2050 1040 1040 none -u s32 s8 r n n n p 24 1040 2050 2050 1040 1040 bias,relu,clip -u s8 s8 r n n n p 24 1040 2050 2050 1040 1040 bias,relu,clip -u s32 r n n n p 1024 1029 2050 2050 1029 1029 none -u s8 r n n n p 1024 1029 2050 2050 1029 1029 none -u s32 s8 r n n n p 1024 1029 2050 2050 1029 1029 bias,relu,clip -u s8 s8 r n n n p 1024 1029 2050 2050 1029 1029 bias,relu,clip -u s32 r n n n p 480 660 2050 2050 660 660 none -u s8 r n n n p 480 660 2050 2050 660 660 none -u s32 s8 r n n n p 480 660 2050 2050 660 660 bias,relu,clip -u s8 s8 r n n n p 480 660 2050 2050 660 660 bias,relu,clip -u s32 r n n n p 481 660 2050 2050 660 660 none -u s8 r n n n p 481 660 2050 2050 660 660 none -u s32 s8 r n n n p 481 660 2050 2050 660 660 bias,relu,clip -u s8 s8 r n n n p 481 660 2050 2050 660 660 bias,relu,clip -u s32 r n n n p 482 660 2050 2050 660 660 none -u s8 r n n n p 482 660 2050 2050 660 660 none -u s32 s8 r n n n p 482 660 2050 2050 660 660 bias,relu,clip -u s8 s8 r n n n p 482 660 2050 2050 660 660 bias,relu,clip -u s32 r n n n p 483 660 2050 2050 660 660 none -u s8 r n n n p 483 660 2050 2050 660 660 none -u s32 s8 r n n n p 483 660 2050 2050 660 660 bias,relu,clip -u s8 s8 r n n n p 483 660 2050 2050 660 660 bias,relu,clip -u s32 r n n n p 484 660 2050 2050 660 660 none -u s8 r n n n p 484 660 2050 2050 660 660 none -u s32 s8 r n n n p 484 660 2050 2050 660 660 bias,relu,clip -u s8 s8 r n n n p 484 660 2050 2050 660 660 bias,relu,clip -u s32 r n n n p 485 660 2050 2050 660 660 none -u s8 r n n n p 485 660 2050 2050 660 660 none -u s32 s8 r n n n p 485 660 2050 2050 660 660 bias,relu,clip -u s8 s8 r n n n p 485 660 2050 2050 660 660 bias,relu,clip -u s32 r n n n p 480 679 2050 2050 679 679 none -u s8 r n n n p 480 679 2050 2050 679 679 none -u s32 s8 r n n n p 480 679 2050 2050 679 679 bias,relu,clip -u s8 s8 r n n n p 480 679 2050 2050 679 679 bias,relu,clip -u s32 r n n n p 481 679 2050 2050 679 679 none -u s8 r n n n p 481 679 2050 2050 679 679 none -u s32 s8 r n n n p 481 679 2050 2050 679 679 bias,relu,clip -u s8 s8 r n n n p 481 679 2050 2050 679 679 bias,relu,clip -u s32 r n n n p 482 679 2050 2050 679 679 none -u s8 r n n n p 482 679 2050 2050 679 679 none -u s32 s8 r n n n p 482 679 2050 2050 679 679 bias,relu,clip -u s8 s8 r n n n p 482 679 2050 2050 679 679 bias,relu,clip -u s32 r n n n p 483 679 2050 2050 679 679 none -u s8 r n n n p 483 679 2050 2050 679 679 none -u s32 s8 r n n n p 483 679 2050 2050 679 679 bias,relu,clip -u s8 s8 r n n n p 483 679 2050 2050 679 679 bias,relu,clip -u s32 r n n n p 484 679 2050 2050 679 679 none -u s8 r n n n p 484 679 2050 2050 679 679 none -u s32 s8 r n n n p 484 679 2050 2050 679 679 bias,relu,clip -u s8 s8 r n n n p 484 679 2050 2050 679 679 bias,relu,clip -u s32 r n n n p 485 679 2050 2050 679 679 none -u s8 r n n n p 485 679 2050 2050 679 679 none -u s32 s8 r n n n p 485 679 2050 2050 679 679 bias,relu,clip -u s8 s8 r n n n p 485 679 2050 2050 679 679 bias,relu,clip -u s32 r n n n p 480 690 2050 2050 690 690 none -u s8 r n n n p 480 690 2050 2050 690 690 none -u s32 s8 r n n n p 480 690 2050 2050 690 690 bias,relu,clip -u s8 s8 r n n n p 480 690 2050 2050 690 690 bias,relu,clip -u s32 r n n n p 481 690 2050 2050 690 690 none -u s8 r n n n p 481 690 2050 2050 690 690 none -u s32 s8 r n n n p 481 690 2050 2050 690 690 bias,relu,clip -u s8 s8 r n n n p 481 690 2050 2050 690 690 bias,relu,clip -u s32 r n n n p 482 690 2050 2050 690 690 none -u s8 r n n n p 482 690 2050 2050 690 690 none -u s32 s8 r n n n p 482 690 2050 2050 690 690 bias,relu,clip -u s8 s8 r n n n p 482 690 2050 2050 690 690 bias,relu,clip -u s32 r n n n p 483 690 2050 2050 690 690 none -u s8 r n n n p 483 690 2050 2050 690 690 none -u s32 s8 r n n n p 483 690 2050 2050 690 690 bias,relu,clip -u s8 s8 r n n n p 483 690 2050 2050 690 690 bias,relu,clip -u s32 r n n n p 484 690 2050 2050 690 690 none -u s8 r n n n p 484 690 2050 2050 690 690 none -u s32 s8 r n n n p 484 690 2050 2050 690 690 bias,relu,clip -u s8 s8 r n n n p 484 690 2050 2050 690 690 bias,relu,clip -u s32 r n n n p 485 690 2050 2050 690 690 none -u s8 r n n n p 485 690 2050 2050 690 690 none -u s32 s8 r n n n p 485 690 2050 2050 690 690 bias,relu,clip -u s8 s8 r n n n p 485 690 2050 2050 690 690 bias,relu,clip -u s32 r n n n p 480 660 2048 2048 660 660 none -u s8 r n n n p 480 660 2048 2048 660 660 none -u s32 s8 r n n n p 480 660 2048 2048 660 660 bias,relu,clip -u s8 s8 r n n n p 480 660 2048 2048 660 660 bias,relu,clip -u s32 r n n n p 481 660 2048 2048 660 660 none -u s8 r n n n p 481 660 2048 2048 660 660 none -u s32 s8 r n n n p 481 660 2048 2048 660 660 bias,relu,clip -u s8 s8 r n n n p 481 660 2048 2048 660 660 bias,relu,clip -u s32 r n n n p 482 660 2048 2048 660 660 none -u s8 r n n n p 482 660 2048 2048 660 660 none -u s32 s8 r n n n p 482 660 2048 2048 660 660 bias,relu,clip -u s8 s8 r n n n p 482 660 2048 2048 660 660 bias,relu,clip -u s32 r n n n p 483 660 2048 2048 660 660 none -u s8 r n n n p 483 660 2048 2048 660 660 none -u s32 s8 r n n n p 483 660 2048 2048 660 660 bias,relu,clip -u s8 s8 r n n n p 483 660 2048 2048 660 660 bias,relu,clip -u s32 r n n n p 484 660 2048 2048 660 660 none -u s8 r n n n p 484 660 2048 2048 660 660 none -u s32 s8 r n n n p 484 660 2048 2048 660 660 bias,relu,clip -u s8 s8 r n n n p 484 660 2048 2048 660 660 bias,relu,clip -u s32 r n n n p 485 660 2048 2048 660 660 none -u s8 r n n n p 485 660 2048 2048 660 660 none -u s32 s8 r n n n p 485 660 2048 2048 660 660 bias,relu,clip -u s8 s8 r n n n p 485 660 2048 2048 660 660 bias,relu,clip -u s32 r n n n p 480 679 2048 2048 679 679 none -u s8 r n n n p 480 679 2048 2048 679 679 none -u s32 s8 r n n n p 480 679 2048 2048 679 679 bias,relu,clip -u s8 s8 r n n n p 480 679 2048 2048 679 679 bias,relu,clip -u s32 r n n n p 481 679 2048 2048 679 679 none -u s8 r n n n p 481 679 2048 2048 679 679 none -u s32 s8 r n n n p 481 679 2048 2048 679 679 bias,relu,clip -u s8 s8 r n n n p 481 679 2048 2048 679 679 bias,relu,clip -u s32 r n n n p 482 679 2048 2048 679 679 none -u s8 r n n n p 482 679 2048 2048 679 679 none -u s32 s8 r n n n p 482 679 2048 2048 679 679 bias,relu,clip -u s8 s8 r n n n p 482 679 2048 2048 679 679 bias,relu,clip -u s32 r n n n p 483 679 2048 2048 679 679 none -u s8 r n n n p 483 679 2048 2048 679 679 none -u s32 s8 r n n n p 483 679 2048 2048 679 679 bias,relu,clip -u s8 s8 r n n n p 483 679 2048 2048 679 679 bias,relu,clip -u s32 r n n n p 484 679 2048 2048 679 679 none -u s8 r n n n p 484 679 2048 2048 679 679 none -u s32 s8 r n n n p 484 679 2048 2048 679 679 bias,relu,clip -u s8 s8 r n n n p 484 679 2048 2048 679 679 bias,relu,clip -u s32 r n n n p 485 679 2048 2048 679 679 none -u s8 r n n n p 485 679 2048 2048 679 679 none -u s32 s8 r n n n p 485 679 2048 2048 679 679 bias,relu,clip -u s8 s8 r n n n p 485 679 2048 2048 679 679 bias,relu,clip -u s32 r n n n p 480 690 2048 2048 690 690 none -u s8 r n n n p 480 690 2048 2048 690 690 none -u s32 s8 r n n n p 480 690 2048 2048 690 690 bias,relu,clip -u s8 s8 r n n n p 480 690 2048 2048 690 690 bias,relu,clip -u s32 r n n n p 481 690 2048 2048 690 690 none -u s8 r n n n p 481 690 2048 2048 690 690 none -u s32 s8 r n n n p 481 690 2048 2048 690 690 bias,relu,clip -u s8 s8 r n n n p 481 690 2048 2048 690 690 bias,relu,clip -u s32 r n n n p 482 690 2048 2048 690 690 none -u s8 r n n n p 482 690 2048 2048 690 690 none -u s32 s8 r n n n p 482 690 2048 2048 690 690 bias,relu,clip -u s8 s8 r n n n p 482 690 2048 2048 690 690 bias,relu,clip -u s32 r n n n p 483 690 2048 2048 690 690 none -u s8 r n n n p 483 690 2048 2048 690 690 none -u s32 s8 r n n n p 483 690 2048 2048 690 690 bias,relu,clip -u s8 s8 r n n n p 483 690 2048 2048 690 690 bias,relu,clip -u s32 r n n n p 484 690 2048 2048 690 690 none -u s8 r n n n p 484 690 2048 2048 690 690 none -u s32 s8 r n n n p 484 690 2048 2048 690 690 bias,relu,clip -u s8 s8 r n n n p 484 690 2048 2048 690 690 bias,relu,clip -u s32 r n n n p 485 690 2048 2048 690 690 none -u s8 r n n n p 485 690 2048 2048 690 690 none -u s32 s8 r n n n p 485 690 2048 2048 690 690 bias,relu,clip -u s8 s8 r n n n p 485 690 2048 2048 690 690 bias,relu,clip -u s32 r n n n p 480 656 1024 1024 656 656 none -u s8 r n n n p 480 656 1024 1024 656 656 none -u s32 s8 r n n n p 480 656 1024 1024 656 656 bias,relu,clip -u s8 s8 r n n n p 480 656 1024 1024 656 656 bias,relu,clip -u s32 r n n n p 480 128 3 3 128 128 none -u s8 r n n n p 480 128 3 3 128 128 none -u s32 s8 r n n n p 480 128 3 3 128 128 bias,relu,clip -u s8 s8 r n n n p 480 128 3 3 128 128 bias,relu,clip -u s32 r n n n p 1024 512 515 515 512 512 none -u s8 r n n n p 1024 512 515 515 512 512 none -u s32 s8 r n n n p 1024 512 515 515 512 512 bias,relu,clip -u s8 s8 r n n n p 1024 512 515 515 512 512 bias,relu,clip -u s32 r n n n p 1024 2048 1024 1024 2048 2048 none -u s8 r n n n p 1024 2048 1024 1024 2048 2048 none -u s32 s8 r n n n p 1024 2048 1024 1024 2048 2048 bias,relu,clip -u s8 s8 r n n n p 1024 2048 1024 1024 2048 2048 bias,relu,clip -u s32 r n n n p 1024 2048 515 515 2048 2048 none -u s8 r n n n p 1024 2048 515 515 2048 2048 none -u s32 s8 r n n n p 1024 2048 515 515 2048 2048 bias,relu,clip -u s8 s8 r n n n p 1024 2048 515 515 2048 2048 bias,relu,clip -u s32 r n n n p 1024 1040 515 515 1040 1040 none -u s8 r n n n p 1024 1040 515 515 1040 1040 none -u s32 s8 r n n n p 1024 1040 515 515 1040 1040 bias,relu,clip -u s8 s8 r n n n p 1024 1040 515 515 1040 1040 bias,relu,clip -u s32 r n n n p 5 1029 515 515 1029 1029 none -u s8 r n n n p 5 1029 515 515 1029 1029 none -u s32 s8 r n n n p 5 1029 515 515 1029 1029 bias,relu,clip -u s8 s8 r n n n p 5 1029 515 515 1029 1029 bias,relu,clip -u s32 r n n n p 1024 1029 515 515 1029 1029 none -u s8 r n n n p 1024 1029 515 515 1029 1029 none -u s32 s8 r n n n p 1024 1029 515 515 1029 1029 bias,relu,clip -u s8 s8 r n n n p 1024 1029 515 515 1029 1029 bias,relu,clip -u s32 r n n n p 1024 1040 2050 2050 1040 1040 none -u s8 r n n n p 1024 1040 2050 2050 1040 1040 none -u s32 s8 r n n n p 1024 1040 2050 2050 1040 1040 bias,relu,clip -u s8 s8 r n n n p 1024 1040 2050 2050 1040 1040 bias,relu,clip -u s32 r n n n p 1029 1029 2050 2050 1029 1029 none -u s8 r n n n p 1029 1029 2050 2050 1029 1029 none -u s32 s8 r n n n p 1029 1029 2050 2050 1029 1029 bias,relu,clip -u s8 s8 r n n n p 1029 1029 2050 2050 1029 1029 bias,relu,clip -u s32 r n n n R 480 646 2050 2050 646 646 none -u s8 r n n n R 480 646 2050 2050 646 646 none -u s32 s8 r n n n R 480 646 2050 2050 646 646 bias,relu,clip -u s8 s8 r n n n R 480 646 2050 2050 646 646 bias,relu,clip -u s32 r n n n R 481 646 2050 2050 646 646 none -u s8 r n n n R 481 646 2050 2050 646 646 none -u s32 s8 r n n n R 481 646 2050 2050 646 646 bias,relu,clip -u s8 s8 r n n n R 481 646 2050 2050 646 646 bias,relu,clip -u s32 r n n n R 482 646 2050 2050 646 646 none -u s8 r n n n R 482 646 2050 2050 646 646 none -u s32 s8 r n n n R 482 646 2050 2050 646 646 bias,relu,clip -u s8 s8 r n n n R 482 646 2050 2050 646 646 bias,relu,clip -u s32 r n n n R 483 646 2050 2050 646 646 none -u s8 r n n n R 483 646 2050 2050 646 646 none -u s32 s8 r n n n R 483 646 2050 2050 646 646 bias,relu,clip -u s8 s8 r n n n R 483 646 2050 2050 646 646 bias,relu,clip -u s32 r n n n R 484 646 2050 2050 646 646 none -u s8 r n n n R 484 646 2050 2050 646 646 none -u s32 s8 r n n n R 484 646 2050 2050 646 646 bias,relu,clip -u s8 s8 r n n n R 484 646 2050 2050 646 646 bias,relu,clip -u s32 r n n n R 485 646 2050 2050 646 646 none -u s8 r n n n R 485 646 2050 2050 646 646 none -u s32 s8 r n n n R 485 646 2050 2050 646 646 bias,relu,clip -u s8 s8 r n n n R 485 646 2050 2050 646 646 bias,relu,clip -u s32 r n n n R 481 656 2050 2050 656 656 none -u s8 r n n n R 481 656 2050 2050 656 656 none -u s32 s8 r n n n R 481 656 2050 2050 656 656 bias,relu,clip -u s8 s8 r n n n R 481 656 2050 2050 656 656 bias,relu,clip -u s32 r n n n R 482 656 2050 2050 656 656 none -u s8 r n n n R 482 656 2050 2050 656 656 none -u s32 s8 r n n n R 482 656 2050 2050 656 656 bias,relu,clip -u s8 s8 r n n n R 482 656 2050 2050 656 656 bias,relu,clip -u s32 r n n n R 483 656 2050 2050 656 656 none -u s8 r n n n R 483 656 2050 2050 656 656 none -u s32 s8 r n n n R 483 656 2050 2050 656 656 bias,relu,clip -u s8 s8 r n n n R 483 656 2050 2050 656 656 bias,relu,clip -u s32 r n n n R 484 656 2050 2050 656 656 none -u s8 r n n n R 484 656 2050 2050 656 656 none -u s32 s8 r n n n R 484 656 2050 2050 656 656 bias,relu,clip -u s8 s8 r n n n R 484 656 2050 2050 656 656 bias,relu,clip -u s32 r n n n p 485 656 2050 2050 656 656 none -u s8 r n n n p 485 656 2050 2050 656 656 none -u s32 s8 r n n n p 485 656 2050 2050 656 656 bias,relu,clip -u s8 s8 r n n n p 485 656 2050 2050 656 656 bias,relu,clip -u s32 r n n n p 480 672 2050 2050 672 672 none -u s8 r n n n p 480 672 2050 2050 672 672 none -u s32 s8 r n n n p 480 672 2050 2050 672 672 bias,relu,clip -u s8 s8 r n n n p 480 672 2050 2050 672 672 bias,relu,clip -u s32 r n n n p 481 672 2050 2050 672 672 none -u s8 r n n n p 481 672 2050 2050 672 672 none -u s32 s8 r n n n p 481 672 2050 2050 672 672 bias,relu,clip -u s8 s8 r n n n p 481 672 2050 2050 672 672 bias,relu,clip -u s32 r n n n p 482 672 2050 2050 672 672 none -u s8 r n n n p 482 672 2050 2050 672 672 none -u s32 s8 r n n n p 482 672 2050 2050 672 672 bias,relu,clip -u s8 s8 r n n n p 482 672 2050 2050 672 672 bias,relu,clip -u s32 r n n n p 483 672 2050 2050 672 672 none -u s8 r n n n p 483 672 2050 2050 672 672 none -u s32 s8 r n n n p 483 672 2050 2050 672 672 bias,relu,clip -u s8 s8 r n n n p 483 672 2050 2050 672 672 bias,relu,clip -u s32 r n n n p 484 672 2050 2050 672 672 none -u s8 r n n n p 484 672 2050 2050 672 672 none -u s32 s8 r n n n p 484 672 2050 2050 672 672 bias,relu,clip -u s8 s8 r n n n p 484 672 2050 2050 672 672 bias,relu,clip -u s32 r n n n p 485 672 2050 2050 672 672 none -u s8 r n n n p 485 672 2050 2050 672 672 none -u s32 s8 r n n n p 485 672 2050 2050 672 672 bias,relu,clip -u s8 s8 r n n n p 485 672 2050 2050 672 672 bias,relu,clip -u s32 r n n n p 480 688 2050 2050 688 688 none -u s8 r n n n p 480 688 2050 2050 688 688 none -u s32 s8 r n n n p 480 688 2050 2050 688 688 bias,relu,clip -u s8 s8 r n n n p 480 688 2050 2050 688 688 bias,relu,clip -u s32 r n n n p 481 688 2050 2050 688 688 none -u s8 r n n n p 481 688 2050 2050 688 688 none -u s32 s8 r n n n p 481 688 2050 2050 688 688 bias,relu,clip -u s8 s8 r n n n p 481 688 2050 2050 688 688 bias,relu,clip -u s32 r n n n r 482 688 2050 2050 688 688 none -u s8 r n n n r 482 688 2050 2050 688 688 none -u s32 s8 r n n n r 482 688 2050 2050 688 688 bias,relu,clip -u s8 s8 r n n n r 482 688 2050 2050 688 688 bias,relu,clip -u s32 r n n n r 483 688 2050 2050 688 688 none -u s8 r n n n r 483 688 2050 2050 688 688 none -u s32 s8 r n n n r 483 688 2050 2050 688 688 bias,relu,clip -u s8 s8 r n n n r 483 688 2050 2050 688 688 bias,relu,clip -u s32 r n n n r 484 688 2050 2050 688 688 none -u s8 r n n n r 484 688 2050 2050 688 688 none -u s32 s8 r n n n r 484 688 2050 2050 688 688 bias,relu,clip -u s8 s8 r n n n r 484 688 2050 2050 688 688 bias,relu,clip -u s32 r n n n r 485 688 2050 2050 688 688 none -u s8 r n n n r 485 688 2050 2050 688 688 none -u s32 s8 r n n n r 485 688 2050 2050 688 688 bias,relu,clip -u s8 s8 r n n n r 485 688 2050 2050 688 688 bias,relu,clip -u s32 r n n n r 1024 512 64 64 512 512 none -u s8 r n n n r 1024 512 64 64 512 512 none -u s32 s8 r n n n r 1024 512 64 64 512 512 bias,relu,clip -u s8 s8 r n n n r 1024 512 64 64 512 512 bias,relu,clip -u s32 r n n n r 16 256 512 512 256 256 none -u s8 r n n n r 16 256 512 512 256 256 none -u s32 s8 r n n n r 16 256 512 512 256 256 bias,relu,clip -u s8 s8 r n n n r 16 256 512 512 256 256 bias,relu,clip -u s32 r n n n r 480 640 512 512 640 640 none -u s8 r n n n r 480 640 512 512 640 640 none -u s32 s8 r n n n r 480 640 512 512 640 640 bias,relu,clip -u s8 s8 r n n n r 480 640 512 512 640 640 bias,relu,clip -u s32 r n n n r 64 768 512 512 768 768 none -u s8 r n n n r 64 768 512 512 768 768 none -u s32 s8 r n n n r 64 768 512 512 768 768 bias,relu,clip -u s8 s8 r n n n r 64 768 512 512 768 768 bias,relu,clip -u s32 r n n n r 128 128 128 128 128 128 none -u s8 r n n n r 128 128 128 128 128 128 none -u s32 s8 r n n n r 128 128 128 128 128 128 bias,relu,clip -u s8 s8 r n n n r 128 128 128 128 128 128 bias,relu,clip -u s32 r n n n r 1024 64 512 512 64 64 none -u s8 r n n n r 1024 64 512 512 64 64 none -u s32 s8 r n n n r 1024 64 512 512 64 64 bias,relu,clip -u s8 s8 r n n n r 1024 64 512 512 64 64 bias,relu,clip -u s32 r n n n r 1024 256 32 32 256 256 none -u s8 r n n n r 1024 256 32 32 256 256 none -u s32 s8 r n n n r 1024 256 32 32 256 256 bias,relu,clip -u s8 s8 r n n n r 1024 256 32 32 256 256 bias,relu,clip -u s32 r n n n r 1024 512 64 64 512 512 none -u s8 r n n n r 1024 512 64 64 512 512 none -u s32 s8 r n n n r 1024 512 64 64 512 512 bias,relu,clip -u s8 s8 r n n n r 1024 512 64 64 512 512 bias,relu,clip -u s32 r n n n r 480 640 512 512 640 640 none -u s8 r n n n r 480 640 512 512 640 640 none -u s32 s8 r n n n r 480 640 512 512 640 640 bias,relu,clip -u s8 s8 r n n n r 480 640 512 512 640 640 bias,relu,clip -u s32 r n n n p 1024 32 256 256 32 32 none -u s8 r n n n p 1024 32 256 256 32 32 none -u s32 s8 r n n n p 1024 32 256 256 32 32 bias,relu,clip -u s8 s8 r n n n p 1024 32 256 256 32 32 bias,relu,clip -u s32 r n n n P 1024 64 512 512 64 64 none -u s8 r n n n P 1024 64 512 512 64 64 none -u s32 s8 r n n n P 1024 64 512 512 64 64 bias,relu,clip -u s8 s8 r n n n P 1024 64 512 512 64 64 bias,relu,clip -u s32 r n n n P 64 800 320 320 800 800 none -u s8 r n n n P 64 800 320 320 800 800 none -u s32 s8 r n n n P 64 800 320 320 800 800 bias,relu,clip -u s8 s8 r n n n P 64 800 320 320 800 800 bias,relu,clip -u s32 r n n n P 64 768 512 512 768 768 none -u s8 r n n n P 64 768 512 512 768 768 none -u s32 s8 r n n n P 64 768 512 512 768 768 bias,relu,clip -u s8 s8 r n n n P 64 768 512 512 768 768 bias,relu,clip -u s32 r n n n P 16 256 512 512 256 256 none -u s8 r n n n P 16 256 512 512 256 256 none -u s32 s8 r n n n P 16 256 512 512 256 256 bias,relu,clip -u s8 s8 r n n n P 16 256 512 512 256 256 bias,relu,clip -u s32 r n n n P 128 128 128 128 128 128 none -u s8 r n n n P 128 128 128 128 128 128 none -u s32 s8 r n n n P 128 128 128 128 128 128 bias,relu,clip -u s8 s8 r n n n P 128 128 128 128 128 128 bias,relu,clip -u s32 r n n n P 256 512 256 256 512 512 none -u s8 r n n n P 256 512 256 256 512 512 none -u s32 s8 r n n n P 256 512 256 256 512 512 bias,relu,clip -u s8 s8 r n n n P 256 512 256 256 512 512 bias,relu,clip -u s32 r n n n P 1024 1024 1024 1024 1024 1024 none -u s8 r n n n P 1024 1024 1024 1024 1024 1024 none -u s32 s8 r n n n P 1024 1024 1024 1024 1024 1024 bias,relu,clip -u s8 s8 r n n n P 1024 1024 1024 1024 1024 1024 bias,relu,clip -u s32 r n n n P 480 640 1024 1024 640 640 none -u s8 r n n n P 480 640 1024 1024 640 640 none -u s32 s8 r n n n P 480 640 1024 1024 640 640 bias,relu,clip -u s8 s8 r n n n P 480 640 1024 1024 640 640 bias,relu,clip -u s32 r n n n P 480 640 256 256 640 640 none -u s8 r n n n P 480 640 256 256 640 640 none -u s32 s8 r n n n P 480 640 256 256 640 640 bias,relu,clip -u s8 s8 r n n n P 480 640 256 256 640 640 bias,relu,clip -u s32 r n n n P 8 64 32 32 64 64 none -u s8 r n n n P 8 64 32 32 64 64 none -u s32 s8 r n n n P 8 64 32 32 64 64 bias,relu,clip -u s8 s8 r n n n P 8 64 32 32 64 64 bias,relu,clip -u s32 r n n n P 9 64 32 32 64 64 none -u s8 r n n n P 9 64 32 32 64 64 none -u s32 s8 r n n n P 9 64 32 32 64 64 bias,relu,clip -u s8 s8 r n n n P 9 64 32 32 64 64 bias,relu,clip -u s32 r n n n P 10 128 64 64 128 128 none -u s8 r n n n P 10 128 64 64 128 128 none -u s32 s8 r n n n P 10 128 64 64 128 128 bias,relu,clip -u s8 s8 r n n n P 10 128 64 64 128 128 bias,relu,clip -u s32 r n n n P 8 8 8 8 8 8 none -u s8 r n n n P 8 8 8 8 8 8 none -u s32 s8 r n n n P 8 8 8 8 8 8 bias,relu,clip -u s8 s8 r n n n P 8 8 8 8 8 8 bias,relu,clip -u s32 r n n n P 12 12 12 12 12 12 none -u s8 r n n n P 12 12 12 12 12 12 none -u s32 s8 r n n n P 12 12 12 12 12 12 bias,relu,clip -u s8 s8 r n n n P 12 12 12 12 12 12 bias,relu,clip -u s32 r n n n P 25 25 25 25 25 25 none -u s8 r n n n P 25 25 25 25 25 25 none -u s32 s8 r n n n P 25 25 25 25 25 25 bias,relu,clip -u s8 s8 r n n n P 25 25 25 25 25 25 bias,relu,clip -u s32 r n n n P 25 25 20 20 25 25 none -u s8 r n n n P 25 25 20 20 25 25 none -u s32 s8 r n n n P 25 25 20 20 25 25 bias,relu,clip -u s8 s8 r n n n P 25 25 20 20 25 25 bias,relu,clip -u s32 r n n n r 4096 256 5 5 256 256 none -u s8 r n n n r 4096 256 5 5 256 256 none -u s32 s8 r n n n r 4096 256 5 5 256 256 bias,relu,clip -u s8 s8 r n n n r 4096 256 5 5 256 256 bias,relu,clip -u s32 r n n n r 3000 256 128 128 256 256 none -u s8 r n n n r 3000 256 128 128 256 256 none -u s32 s8 r n n n r 3000 256 128 128 256 256 bias,relu,clip -u s8 s8 r n n n r 3000 256 128 128 256 256 bias,relu,clip -u s32 r n n n r 4096 1024 512 512 1024 1024 none -u s8 r n n n r 4096 1024 512 512 1024 1024 none -u s32 s8 r n n n r 4096 1024 512 512 1024 1024 bias,relu,clip -u s8 s8 r n n n r 4096 1024 512 512 1024 1024 bias,relu,clip -u s32 r n n n r 144 256 5 5 256 256 none -u s8 r n n n r 144 256 5 5 256 256 none -u s32 s8 r n n n r 144 256 5 5 256 256 bias,relu,clip -u s8 s8 r n n n r 144 256 5 5 256 256 bias,relu,clip -u s32 r n n n r 144 256 128 128 256 256 none -u s8 r n n n r 144 256 128 128 256 256 none -u s32 s8 r n n n r 144 256 128 128 256 256 bias,relu,clip -u s8 s8 r n n n r 144 256 128 128 256 256 bias,relu,clip -u s32 r n n n r 144 1024 512 512 1024 1024 none -u s8 r n n n r 144 1024 512 512 1024 1024 none -u s32 s8 r n n n r 144 1024 512 512 1024 1024 bias,relu,clip -u s8 s8 r n n n r 144 1024 512 512 1024 1024 bias,relu,clip -u s32 r n n n r 480 688 256 256 688 688 none -u s8 r n n n r 480 688 256 256 688 688 none -u s32 s8 r n n n r 480 688 256 256 688 688 bias,relu,clip -u s8 s8 r n n n r 480 688 256 256 688 688 bias,relu,clip -u s32 r n n n r 480 640 512 512 640 640 none -u s8 r n n n r 480 640 512 512 640 640 none -u s32 s8 r n n n r 480 640 512 512 640 640 bias,relu,clip -u s8 s8 r n n n r 480 640 512 512 640 640 bias,relu,clip -u s32 r n n n r 480 640 1024 1024 640 640 none -u s8 r n n n r 480 640 1024 1024 640 640 none -u s32 s8 r n n n r 480 640 1024 1024 640 640 bias,relu,clip -u s8 s8 r n n n r 480 640 1024 1024 640 640 bias,relu,clip -u s32 r n n n r 64 800 320 320 800 800 none -u s8 r n n n r 64 800 320 320 800 800 none -u s32 s8 r n n n r 64 800 320 320 800 800 bias,relu,clip -u s8 s8 r n n n r 64 800 320 320 800 800 bias,relu,clip -u s32 r n n n r 64 768 512 512 768 768 none -u s8 r n n n r 64 768 512 512 768 768 none -u s32 s8 r n n n r 64 768 512 512 768 768 bias,relu,clip -u s8 s8 r n n n r 64 768 512 512 768 768 bias,relu,clip -u s32 r n n n r 16 256 512 512 256 256 none -u s8 r n n n r 16 256 512 512 256 256 none -u s32 s8 r n n n r 16 256 512 512 256 256 bias,relu,clip -u s8 s8 r n n n r 16 256 512 512 256 256 bias,relu,clip -u s32 r n n n r 128 128 128 128 128 128 none -u s8 r n n n r 128 128 128 128 128 128 none -u s32 s8 r n n n r 128 128 128 128 128 128 bias,relu,clip -u s8 s8 r n n n r 128 128 128 128 128 128 bias,relu,clip -u s32 r n n n r 256 512 256 256 512 512 none -u s8 r n n n r 256 512 256 256 512 512 none -u s32 s8 r n n n r 256 512 256 256 512 512 bias,relu,clip -u s8 s8 r n n n r 256 512 256 256 512 512 bias,relu,clip -u s32 r n n n r 1024 1024 1024 1024 1024 1024 none -u s8 r n n n r 1024 1024 1024 1024 1024 1024 none -u s32 s8 r n n n r 1024 1024 1024 1024 1024 1024 bias,relu,clip -u s8 s8 r n n n r 1024 1024 1024 1024 1024 1024 bias,relu,clip -u s32 r n n n r 1024 32 256 256 32 32 none -u s8 r n n n r 1024 32 256 256 32 32 none -u s32 s8 r n n n r 1024 32 256 256 32 32 bias,relu,clip -u s8 s8 r n n n r 1024 32 256 256 32 32 bias,relu,clip -u s32 r n n n r 1024 64 512 512 64 64 none -u s8 r n n n r 1024 64 512 512 64 64 none -u s32 s8 r n n n r 1024 64 512 512 64 64 bias,relu,clip -u s8 s8 r n n n r 1024 64 512 512 64 64 bias,relu,clip -u s32 r n n n r 1024 256 32 32 256 256 none -u s8 r n n n r 1024 256 32 32 256 256 none -u s32 s8 r n n n r 1024 256 32 32 256 256 bias,relu,clip -u s8 s8 r n n n r 1024 256 32 32 256 256 bias,relu,clip -u s32 r n n n r 1024 512 64 64 512 512 none -u s8 r n n n r 1024 512 64 64 512 512 none -u s32 s8 r n n n r 1024 512 64 64 512 512 bias,relu,clip -u s8 s8 r n n n r 1024 512 64 64 512 512 bias,relu,clip -u s32 r n n n r 512 32 256 256 32 32 none -u s8 r n n n r 512 32 256 256 32 32 none -u s32 s8 r n n n r 512 32 256 256 32 32 bias,relu,clip -u s8 s8 r n n n r 512 32 256 256 32 32 bias,relu,clip -u s32 r n n n r 512 768 512 512 768 768 none -u s8 r n n n r 512 768 512 512 768 768 none -u s32 s8 r n n n r 512 768 512 512 768 768 bias,relu,clip -u s8 s8 r n n n r 512 768 512 512 768 768 bias,relu,clip -u s32 r n n n r 512 256 32 32 256 256 none -u s8 r n n n r 512 256 32 32 256 256 none -u s32 s8 r n n n r 512 256 32 32 256 256 bias,relu,clip -u s8 s8 r n n n r 512 256 32 32 256 256 bias,relu,clip -u s32 r n n n r 512 512 64 64 512 512 none -u s8 r n n n r 512 512 64 64 512 512 none -u s32 s8 r n n n r 512 512 64 64 512 512 bias,relu,clip -u s8 s8 r n n n r 512 512 64 64 512 512 bias,relu,clip -u s32 r n n n r 512 256 768 768 256 256 none -u s8 r n n n r 512 256 768 768 256 256 none -u s32 s8 r n n n r 512 256 768 768 256 256 bias,relu,clip -u s8 s8 r n n n r 512 256 768 768 256 256 bias,relu,clip -u s32 r n n n r 768 768 1024 1024 768 768 none -u s8 r n n n r 768 768 1024 1024 768 768 none -u s32 s8 r n n n r 768 768 1024 1024 768 768 bias,relu,clip -u s8 s8 r n n n r 768 768 1024 1024 768 768 bias,relu,clip -u s32 r n n n r 768 768 768 768 768 768 none -u s8 r n n n r 768 768 768 768 768 768 none -u s32 s8 r n n n r 768 768 768 768 768 768 bias,relu,clip -u s8 s8 r n n n r 768 768 768 768 768 768 bias,relu,clip -u s32 r n n n r 2048 2048 2048 2048 2048 2048 none -u s8 r n n n r 2048 2048 2048 2048 2048 2048 none -u s32 s8 r n n n r 2048 2048 2048 2048 2048 2048 bias,relu,clip -u s8 s8 r n n n r 2048 2048 2048 2048 2048 2048 bias,relu,clip -u s32 r n n n r 4096 4096 4096 4096 4096 4096 none -u s8 r n n n r 4096 4096 4096 4096 4096 4096 none -u s32 s8 r n n n r 4096 4096 4096 4096 4096 4096 bias,relu,clip -u s8 s8 r n n n r 4096 4096 4096 4096 4096 4096 bias,relu,clip -f f32 c n n n p 2482 1127 2050 2482 2050 2482 none -f f32 f32 c n n n p 2482 1127 2050 2482 2050 2482 bias,relu,clip -f f32 c n n n p 2483 1127 2050 2483 2050 2483 none -f f32 f32 c n n n p 2483 1127 2050 2483 2050 2483 bias,relu,clip -f f32 c n n n p 2484 1127 2050 2484 2050 2484 none -f f32 f32 c n n n p 2484 1127 2050 2484 2050 2484 bias,relu,clip -f f32 c n n n p 2485 1127 2050 2485 2050 2485 none -f f32 f32 c n n n p 2485 1127 2050 2485 2050 2485 bias,relu,clip -f f32 c n n n p 480 1138 2050 480 2050 480 none -f f32 f32 c n n n p 480 1138 2050 480 2050 480 bias,relu,clip -f f32 c n n n p 481 1138 2050 481 2050 481 none -f f32 f32 c n n n p 481 1138 2050 481 2050 481 bias,relu,clip -f f32 c n n n p 482 1138 2050 482 2050 482 none -f f32 f32 c n n n p 482 1138 2050 482 2050 482 bias,relu,clip -f f32 c n n n p 483 1138 2050 483 2050 483 none -f f32 f32 c n n n p 483 1138 2050 483 2050 483 bias,relu,clip -f f32 c n n n p 484 1138 2050 484 2050 484 none -f f32 f32 c n n n p 484 1138 2050 484 2050 484 bias,relu,clip -f f32 c n n n p 485 1138 2050 485 2050 485 none -f f32 f32 c n n n p 485 1138 2050 485 2050 485 bias,relu,clip -f f32 c n n n p 1 1 3 1 3 1 none -f f32 f32 c n n n p 1 1 3 1 3 1 bias,relu,clip -f f32 c n n n p 1 9 3 1 3 1 none -f f32 f32 c n n n p 1 9 3 1 3 1 bias,relu,clip -f f32 c n n n p 1 2048 3 1 3 1 none -f f32 f32 c n n n p 1 2048 3 1 3 1 bias,relu,clip -f f32 c n n n p 1 2048 5192 1 5192 1 none -f f32 f32 c n n n p 1 2048 5192 1 5192 1 bias,relu,clip -f f32 c n n n p 9 1 3 9 3 9 none -f f32 f32 c n n n p 9 1 3 9 3 9 bias,relu,clip -f f32 c n n n p 576 1 3500 576 3500 576 none -f f32 f32 c n n n p 576 1 3500 576 3500 576 bias,relu,clip -f f32 c n n n p 1 1 1 1 1 1 none -f f32 f32 c n n n p 1 1 1 1 1 1 bias,relu,clip -f f32 c n n n p 102 1088 1024 102 1024 102 none -f f32 f32 c n n n p 102 1088 1024 102 1024 102 bias,relu,clip -b f32 r n n n r 480 20 2050 2050 20 20 none -b bf16 r n n n r 480 20 2050 2050 20 20 none -b f32 bf16 r n n n r 480 20 2050 2050 20 20 bias,relu,clip -b bf16 bf16 r n n n r 480 20 2050 2050 20 20 bias,relu,clip -b f32 r n n n r 481 20 2050 2050 20 20 none -b bf16 r n n n r 481 20 2050 2050 20 20 none -b f32 bf16 r n n n r 481 20 2050 2050 20 20 bias,relu,clip -b bf16 bf16 r n n n r 481 20 2050 2050 20 20 bias,relu,clip -b f32 r n n n r 482 20 2050 2050 20 20 none -b bf16 r n n n r 482 20 2050 2050 20 20 none -b f32 bf16 r n n n r 482 20 2050 2050 20 20 bias,relu,clip -b bf16 bf16 r n n n r 482 20 2050 2050 20 20 bias,relu,clip -b f32 r n n n p 483 20 2050 2050 20 20 none -b bf16 r n n n p 483 20 2050 2050 20 20 none -b f32 bf16 r n n n p 483 20 2050 2050 20 20 bias,relu,clip -b bf16 bf16 r n n n p 483 20 2050 2050 20 20 bias,relu,clip -b f32 r n n n R 484 20 2050 2050 20 20 none -b bf16 r n n n R 484 20 2050 2050 20 20 none -b f32 bf16 r n n n R 484 20 2050 2050 20 20 bias,relu,clip -b bf16 bf16 r n n n R 484 20 2050 2050 20 20 bias,relu,clip -b f32 r n n n R 485 20 2050 2050 20 20 none -b bf16 r n n n R 485 20 2050 2050 20 20 none -b f32 bf16 r n n n R 485 20 2050 2050 20 20 bias,relu,clip -b bf16 bf16 r n n n R 485 20 2050 2050 20 20 bias,relu,clip -b f32 r n n n R 480 39 2050 2050 39 39 none -b bf16 r n n n R 480 39 2050 2050 39 39 none -b f32 bf16 r n n n R 480 39 2050 2050 39 39 bias,relu,clip -b bf16 bf16 r n n n R 480 39 2050 2050 39 39 bias,relu,clip -b f32 r n n n R 481 39 2050 2050 39 39 none -b bf16 r n n n R 481 39 2050 2050 39 39 none -b f32 bf16 r n n n R 481 39 2050 2050 39 39 bias,relu,clip -b bf16 bf16 r n n n R 481 39 2050 2050 39 39 bias,relu,clip -b f32 r n n n R 482 39 2050 2050 39 39 none -b bf16 r n n n R 482 39 2050 2050 39 39 none -b f32 bf16 r n n n R 482 39 2050 2050 39 39 bias,relu,clip -b bf16 bf16 r n n n R 482 39 2050 2050 39 39 bias,relu,clip -b f32 r n n n R 483 39 2050 2050 39 39 none -b bf16 r n n n R 483 39 2050 2050 39 39 none -b f32 bf16 r n n n R 483 39 2050 2050 39 39 bias,relu,clip -b bf16 bf16 r n n n R 483 39 2050 2050 39 39 bias,relu,clip -b f32 r n n n R 484 39 2050 2050 39 39 none -b bf16 r n n n R 484 39 2050 2050 39 39 none -b f32 bf16 r n n n R 484 39 2050 2050 39 39 bias,relu,clip -b bf16 bf16 r n n n R 484 39 2050 2050 39 39 bias,relu,clip -b f32 r n n n p 485 39 2050 2050 39 39 none -b bf16 r n n n p 485 39 2050 2050 39 39 none -b f32 bf16 r n n n p 485 39 2050 2050 39 39 bias,relu,clip -b bf16 bf16 r n n n p 485 39 2050 2050 39 39 bias,relu,clip -b f32 r n n n p 480 50 2050 2050 50 50 none -b bf16 r n n n p 480 50 2050 2050 50 50 none -b f32 bf16 r n n n p 480 50 2050 2050 50 50 bias,relu,clip -b bf16 bf16 r n n n p 480 50 2050 2050 50 50 bias,relu,clip -b f32 r n n n p 481 50 2050 2050 50 50 none -b bf16 r n n n p 481 50 2050 2050 50 50 none -b f32 bf16 r n n n p 481 50 2050 2050 50 50 bias,relu,clip -b bf16 bf16 r n n n p 481 50 2050 2050 50 50 bias,relu,clip -b f32 r n n n p 482 50 2050 2050 50 50 none -b bf16 r n n n p 482 50 2050 2050 50 50 none -b f32 bf16 r n n n p 482 50 2050 2050 50 50 bias,relu,clip -b bf16 bf16 r n n n p 482 50 2050 2050 50 50 bias,relu,clip -b f32 r n n n p 483 50 2050 2050 50 50 none -b bf16 r n n n p 483 50 2050 2050 50 50 none -b f32 bf16 r n n n p 483 50 2050 2050 50 50 bias,relu,clip -b bf16 bf16 r n n n p 483 50 2050 2050 50 50 bias,relu,clip -b f32 r n n n p 484 50 2050 2050 50 50 none -b bf16 r n n n p 484 50 2050 2050 50 50 none -b f32 bf16 r n n n p 484 50 2050 2050 50 50 bias,relu,clip -b bf16 bf16 r n n n p 484 50 2050 2050 50 50 bias,relu,clip -b f32 r n n n p 485 50 2050 2050 50 50 none -b bf16 r n n n p 485 50 2050 2050 50 50 none -b f32 bf16 r n n n p 485 50 2050 2050 50 50 bias,relu,clip -b bf16 bf16 r n n n p 485 50 2050 2050 50 50 bias,relu,clip -b f32 r n n n R 480 1108 2050 2050 1108 1108 none -b bf16 r n n n R 480 1108 2050 2050 1108 1108 none -b f32 bf16 r n n n R 480 1108 2050 2050 1108 1108 bias,relu,clip -b bf16 bf16 r n n n R 480 1108 2050 2050 1108 1108 bias,relu,clip -b f32 r n n n R 481 1108 2050 2050 1108 1108 none -b bf16 r n n n R 481 1108 2050 2050 1108 1108 none -b f32 bf16 r n n n R 481 1108 2050 2050 1108 1108 bias,relu,clip -b bf16 bf16 r n n n R 481 1108 2050 2050 1108 1108 bias,relu,clip -b f32 r n n n R 482 1108 2050 2050 1108 1108 none -b bf16 r n n n R 482 1108 2050 2050 1108 1108 none -b f32 bf16 r n n n R 482 1108 2050 2050 1108 1108 bias,relu,clip -b bf16 bf16 r n n n R 482 1108 2050 2050 1108 1108 bias,relu,clip -b f32 r n n n R 483 1108 2050 2050 1108 1108 none -b bf16 r n n n R 483 1108 2050 2050 1108 1108 none -b f32 bf16 r n n n R 483 1108 2050 2050 1108 1108 bias,relu,clip -b bf16 bf16 r n n n R 483 1108 2050 2050 1108 1108 bias,relu,clip -b f32 r n n n R 484 1108 2050 2050 1108 1108 none -b bf16 r n n n R 484 1108 2050 2050 1108 1108 none -b f32 bf16 r n n n R 484 1108 2050 2050 1108 1108 bias,relu,clip -b bf16 bf16 r n n n R 484 1108 2050 2050 1108 1108 bias,relu,clip -b f32 r n n n R 485 1108 2050 2050 1108 1108 none -b bf16 r n n n R 485 1108 2050 2050 1108 1108 none -b f32 bf16 r n n n R 485 1108 2050 2050 1108 1108 bias,relu,clip -b bf16 bf16 r n n n R 485 1108 2050 2050 1108 1108 bias,relu,clip -b f32 r n n n R 480 1127 2050 2050 1127 1127 none -b bf16 r n n n R 480 1127 2050 2050 1127 1127 none -b f32 bf16 r n n n R 480 1127 2050 2050 1127 1127 bias,relu,clip -b bf16 bf16 r n n n R 480 1127 2050 2050 1127 1127 bias,relu,clip -b f32 r n n n R 481 1127 2050 2050 1127 1127 none -b bf16 r n n n R 481 1127 2050 2050 1127 1127 none -b f32 bf16 r n n n R 481 1127 2050 2050 1127 1127 bias,relu,clip -b bf16 bf16 r n n n R 481 1127 2050 2050 1127 1127 bias,relu,clip -b f32 r n n n R 482 1127 2050 2050 1127 1127 none -b bf16 r n n n R 482 1127 2050 2050 1127 1127 none -b f32 bf16 r n n n R 482 1127 2050 2050 1127 1127 bias,relu,clip -b bf16 bf16 r n n n R 482 1127 2050 2050 1127 1127 bias,relu,clip -b f32 r n n n R 483 1127 2050 2050 1127 1127 none -b bf16 r n n n R 483 1127 2050 2050 1127 1127 none -b f32 bf16 r n n n R 483 1127 2050 2050 1127 1127 bias,relu,clip -b bf16 bf16 r n n n R 483 1127 2050 2050 1127 1127 bias,relu,clip -b f32 r n n n p 484 1127 2050 2050 1127 1127 none -b bf16 r n n n p 484 1127 2050 2050 1127 1127 none -b f32 bf16 r n n n p 484 1127 2050 2050 1127 1127 bias,relu,clip -b bf16 bf16 r n n n p 484 1127 2050 2050 1127 1127 bias,relu,clip -b f32 r n n n p 485 1127 2050 2050 1127 1127 none -b bf16 r n n n p 485 1127 2050 2050 1127 1127 none -b f32 bf16 r n n n p 485 1127 2050 2050 1127 1127 bias,relu,clip -b bf16 bf16 r n n n p 485 1127 2050 2050 1127 1127 bias,relu,clip -b f32 r n n n p 480 1138 2050 2050 1138 1138 none -b bf16 r n n n p 480 1138 2050 2050 1138 1138 none -b f32 bf16 r n n n p 480 1138 2050 2050 1138 1138 bias,relu,clip -b bf16 bf16 r n n n p 480 1138 2050 2050 1138 1138 bias,relu,clip -b f32 r n n n p 481 1138 2050 2050 1138 1138 none -b bf16 r n n n p 481 1138 2050 2050 1138 1138 none -b f32 bf16 r n n n p 481 1138 2050 2050 1138 1138 bias,relu,clip -b bf16 bf16 r n n n p 481 1138 2050 2050 1138 1138 bias,relu,clip -b f32 r n n n p 482 1138 2050 2050 1138 1138 none -b bf16 r n n n p 482 1138 2050 2050 1138 1138 none -b f32 bf16 r n n n p 482 1138 2050 2050 1138 1138 bias,relu,clip -b bf16 bf16 r n n n p 482 1138 2050 2050 1138 1138 bias,relu,clip -b f32 r n n n p 483 1138 2050 2050 1138 1138 none -b bf16 r n n n p 483 1138 2050 2050 1138 1138 none -b f32 bf16 r n n n p 483 1138 2050 2050 1138 1138 bias,relu,clip -b bf16 bf16 r n n n p 483 1138 2050 2050 1138 1138 bias,relu,clip -b f32 r n n n p 484 1138 2050 2050 1138 1138 none -b bf16 r n n n p 484 1138 2050 2050 1138 1138 none -b f32 bf16 r n n n p 484 1138 2050 2050 1138 1138 bias,relu,clip -b bf16 bf16 r n n n p 484 1138 2050 2050 1138 1138 bias,relu,clip -b f32 r n n n p 485 1138 2050 2050 1138 1138 none -b bf16 r n n n p 485 1138 2050 2050 1138 1138 none -b f32 bf16 r n n n p 485 1138 2050 2050 1138 1138 bias,relu,clip -b bf16 bf16 r n n n p 485 1138 2050 2050 1138 1138 bias,relu,clip -b f32 r n n n p 1 1 3 3 1 1 none -b bf16 r n n n p 1 1 3 3 1 1 none -b f32 bf16 r n n n p 1 1 3 3 1 1 bias,relu,clip -b bf16 bf16 r n n n p 1 1 3 3 1 1 bias,relu,clip -b f32 r n n n p 1 9 3 3 9 9 none -b bf16 r n n n p 1 9 3 3 9 9 none -b f32 bf16 r n n n p 1 9 3 3 9 9 bias,relu,clip -b bf16 bf16 r n n n p 1 9 3 3 9 9 bias,relu,clip -b f32 r n n n p 1 2048 3 3 2048 2048 none -b bf16 r n n n p 1 2048 3 3 2048 2048 none -b f32 bf16 r n n n p 1 2048 3 3 2048 2048 bias,relu,clip -b bf16 bf16 r n n n p 1 2048 3 3 2048 2048 bias,relu,clip -b f32 r n n n p 1 2048 5192 5192 2048 2048 none -b bf16 r n n n p 1 2048 5192 5192 2048 2048 none -b f32 bf16 r n n n p 1 2048 5192 5192 2048 2048 bias,relu,clip -b bf16 bf16 r n n n p 1 2048 5192 5192 2048 2048 bias,relu,clip -b f32 r n n n p 9 1 3 3 1 1 none -b bf16 r n n n p 9 1 3 3 1 1 none -b f32 bf16 r n n n p 9 1 3 3 1 1 bias,relu,clip -b bf16 bf16 r n n n p 9 1 3 3 1 1 bias,relu,clip -b f32 r n n n p 576 1 3500 3500 1 1 none -b bf16 r n n n p 576 1 3500 3500 1 1 none -b f32 bf16 r n n n p 576 1 3500 3500 1 1 bias,relu,clip -b bf16 bf16 r n n n p 576 1 3500 3500 1 1 bias,relu,clip -b f32 r n n n p 1 1 1 1 1 1 none -b bf16 r n n n p 1 1 1 1 1 1 none -b f32 bf16 r n n n p 1 1 1 1 1 1 bias,relu,clip -b bf16 bf16 r n n n p 1 1 1 1 1 1 bias,relu,clip -b f32 r n n n p 102 1088 1024 1024 1088 1088 none -b bf16 r n n n p 102 1088 1024 1024 1088 1088 none -b f32 bf16 r n n n p 102 1088 1024 1024 1088 1088 bias,relu,clip -b bf16 bf16 r n n n p 102 1088 1024 1024 1088 1088 bias,relu,clip -b f32 r n n n p 102 2048 1024 1024 2048 2048 none -b bf16 r n n n p 102 2048 1024 1024 2048 2048 none -b f32 bf16 r n n n p 102 2048 1024 1024 2048 2048 bias,relu,clip -b bf16 bf16 r n n n p 102 2048 1024 1024 2048 2048 bias,relu,clip -b f32 r n n n p 485 656 1024 1024 656 656 none -b bf16 r n n n p 485 656 1024 1024 656 656 none -b f32 bf16 r n n n p 485 656 1024 1024 656 656 bias,relu,clip -b bf16 bf16 r n n n p 485 656 1024 1024 656 656 bias,relu,clip -b f32 r n n n p 483 656 1024 1024 656 656 none -b bf16 r n n n p 483 656 1024 1024 656 656 none -b f32 bf16 r n n n p 483 656 1024 1024 656 656 bias,relu,clip -b bf16 bf16 r n n n p 483 656 1024 1024 656 656 bias,relu,clip -b f32 r n n n p 81 128 3 3 128 128 none -b bf16 r n n n p 81 128 3 3 128 128 none -b f32 bf16 r n n n p 81 128 3 3 128 128 bias,relu,clip -b bf16 bf16 r n n n p 81 128 3 3 128 128 bias,relu,clip -b f32 r n n n p 1022 512 515 515 512 512 none -b bf16 r n n n p 1022 512 515 515 512 512 none -b f32 bf16 r n n n p 1022 512 515 515 512 512 bias,relu,clip -b bf16 bf16 r n n n p 1022 512 515 515 512 512 bias,relu,clip -b f32 r n n n p 74 512 515 515 512 512 none -b bf16 r n n n p 74 512 515 515 512 512 none -b f32 bf16 r n n n p 74 512 515 515 512 512 bias,relu,clip -b bf16 bf16 r n n n p 74 512 515 515 512 512 bias,relu,clip -b f32 r n n n p 253 2048 515 515 2048 2048 none -b bf16 r n n n p 253 2048 515 515 2048 2048 none -b f32 bf16 r n n n p 253 2048 515 515 2048 2048 bias,relu,clip -b bf16 bf16 r n n n p 253 2048 515 515 2048 2048 bias,relu,clip -b f32 r n n n p 8192 1040 515 515 1040 1040 none -b bf16 r n n n p 8192 1040 515 515 1040 1040 none -b f32 bf16 r n n n p 8192 1040 515 515 1040 1040 bias,relu,clip -b bf16 bf16 r n n n p 8192 1040 515 515 1040 1040 bias,relu,clip -b f32 r n n n p 10 1029 515 515 1029 1029 none -b bf16 r n n n p 10 1029 515 515 1029 1029 none -b f32 bf16 r n n n p 10 1029 515 515 1029 1029 bias,relu,clip -b bf16 bf16 r n n n p 10 1029 515 515 1029 1029 bias,relu,clip -b f32 r n n n p 24 1040 2050 2050 1040 1040 none -b bf16 r n n n p 24 1040 2050 2050 1040 1040 none -b f32 bf16 r n n n p 24 1040 2050 2050 1040 1040 bias,relu,clip -b bf16 bf16 r n n n p 24 1040 2050 2050 1040 1040 bias,relu,clip -b f32 r n n n p 1024 1029 2050 2050 1029 1029 none -b bf16 r n n n p 1024 1029 2050 2050 1029 1029 none -b f32 bf16 r n n n p 1024 1029 2050 2050 1029 1029 bias,relu,clip -b bf16 bf16 r n n n p 1024 1029 2050 2050 1029 1029 bias,relu,clip -b f32 r n n n p 480 660 2050 2050 660 660 none -b bf16 r n n n p 480 660 2050 2050 660 660 none -b f32 bf16 r n n n p 480 660 2050 2050 660 660 bias,relu,clip -b bf16 bf16 r n n n p 480 660 2050 2050 660 660 bias,relu,clip -b f32 r n n n p 481 660 2050 2050 660 660 none -b bf16 r n n n p 481 660 2050 2050 660 660 none -b f32 bf16 r n n n p 481 660 2050 2050 660 660 bias,relu,clip -b bf16 bf16 r n n n p 481 660 2050 2050 660 660 bias,relu,clip -b f32 r n n n p 482 660 2050 2050 660 660 none -b bf16 r n n n p 482 660 2050 2050 660 660 none -b f32 bf16 r n n n p 482 660 2050 2050 660 660 bias,relu,clip -b bf16 bf16 r n n n p 482 660 2050 2050 660 660 bias,relu,clip -b f32 r n n n p 483 660 2050 2050 660 660 none -b bf16 r n n n p 483 660 2050 2050 660 660 none -b f32 bf16 r n n n p 483 660 2050 2050 660 660 bias,relu,clip -b bf16 bf16 r n n n p 483 660 2050 2050 660 660 bias,relu,clip -b f32 r n n n p 484 660 2050 2050 660 660 none -b bf16 r n n n p 484 660 2050 2050 660 660 none -b f32 bf16 r n n n p 484 660 2050 2050 660 660 bias,relu,clip -b bf16 bf16 r n n n p 484 660 2050 2050 660 660 bias,relu,clip -b f32 r n n n p 485 660 2050 2050 660 660 none -b bf16 r n n n p 485 660 2050 2050 660 660 none -b f32 bf16 r n n n p 485 660 2050 2050 660 660 bias,relu,clip -b bf16 bf16 r n n n p 485 660 2050 2050 660 660 bias,relu,clip -b f32 r n n n p 480 679 2050 2050 679 679 none -b bf16 r n n n p 480 679 2050 2050 679 679 none -b f32 bf16 r n n n p 480 679 2050 2050 679 679 bias,relu,clip -b bf16 bf16 r n n n p 480 679 2050 2050 679 679 bias,relu,clip -b f32 r n n n p 481 679 2050 2050 679 679 none -b bf16 r n n n p 481 679 2050 2050 679 679 none -b f32 bf16 r n n n p 481 679 2050 2050 679 679 bias,relu,clip -b bf16 bf16 r n n n p 481 679 2050 2050 679 679 bias,relu,clip -b f32 r n n n p 482 679 2050 2050 679 679 none -b bf16 r n n n p 482 679 2050 2050 679 679 none -b f32 bf16 r n n n p 482 679 2050 2050 679 679 bias,relu,clip -b bf16 bf16 r n n n p 482 679 2050 2050 679 679 bias,relu,clip -b f32 r n n n p 483 679 2050 2050 679 679 none -b bf16 r n n n p 483 679 2050 2050 679 679 none -b f32 bf16 r n n n p 483 679 2050 2050 679 679 bias,relu,clip -b bf16 bf16 r n n n p 483 679 2050 2050 679 679 bias,relu,clip -b f32 r n n n p 484 679 2050 2050 679 679 none -b bf16 r n n n p 484 679 2050 2050 679 679 none -b f32 bf16 r n n n p 484 679 2050 2050 679 679 bias,relu,clip -b bf16 bf16 r n n n p 484 679 2050 2050 679 679 bias,relu,clip -b f32 r n n n p 485 679 2050 2050 679 679 none -b bf16 r n n n p 485 679 2050 2050 679 679 none -b f32 bf16 r n n n p 485 679 2050 2050 679 679 bias,relu,clip -b bf16 bf16 r n n n p 485 679 2050 2050 679 679 bias,relu,clip -b f32 r n n n p 480 690 2050 2050 690 690 none -b bf16 r n n n p 480 690 2050 2050 690 690 none -b f32 bf16 r n n n p 480 690 2050 2050 690 690 bias,relu,clip -b bf16 bf16 r n n n p 480 690 2050 2050 690 690 bias,relu,clip -b f32 r n n n p 481 690 2050 2050 690 690 none -b bf16 r n n n p 481 690 2050 2050 690 690 none -b f32 bf16 r n n n p 481 690 2050 2050 690 690 bias,relu,clip -b bf16 bf16 r n n n p 481 690 2050 2050 690 690 bias,relu,clip -b f32 r n n n p 482 690 2050 2050 690 690 none -b bf16 r n n n p 482 690 2050 2050 690 690 none -b f32 bf16 r n n n p 482 690 2050 2050 690 690 bias,relu,clip -b bf16 bf16 r n n n p 482 690 2050 2050 690 690 bias,relu,clip -b f32 r n n n p 483 690 2050 2050 690 690 none -b bf16 r n n n p 483 690 2050 2050 690 690 none -b f32 bf16 r n n n p 483 690 2050 2050 690 690 bias,relu,clip -b bf16 bf16 r n n n p 483 690 2050 2050 690 690 bias,relu,clip -b f32 r n n n p 484 690 2050 2050 690 690 none -b bf16 r n n n p 484 690 2050 2050 690 690 none -b f32 bf16 r n n n p 484 690 2050 2050 690 690 bias,relu,clip -b bf16 bf16 r n n n p 484 690 2050 2050 690 690 bias,relu,clip -b f32 r n n n p 485 690 2050 2050 690 690 none -b bf16 r n n n p 485 690 2050 2050 690 690 none -b f32 bf16 r n n n p 485 690 2050 2050 690 690 bias,relu,clip -b bf16 bf16 r n n n p 485 690 2050 2050 690 690 bias,relu,clip -b f32 r n n n p 480 660 2048 2048 660 660 none -b bf16 r n n n p 480 660 2048 2048 660 660 none -b f32 bf16 r n n n p 480 660 2048 2048 660 660 bias,relu,clip -b bf16 bf16 r n n n p 480 660 2048 2048 660 660 bias,relu,clip -b f32 r n n n p 481 660 2048 2048 660 660 none -b bf16 r n n n p 481 660 2048 2048 660 660 none -b f32 bf16 r n n n p 481 660 2048 2048 660 660 bias,relu,clip -b bf16 bf16 r n n n p 481 660 2048 2048 660 660 bias,relu,clip -b f32 r n n n p 482 660 2048 2048 660 660 none -b bf16 r n n n p 482 660 2048 2048 660 660 none -b f32 bf16 r n n n p 482 660 2048 2048 660 660 bias,relu,clip -b bf16 bf16 r n n n p 482 660 2048 2048 660 660 bias,relu,clip -b f32 r n n n p 483 660 2048 2048 660 660 none -b bf16 r n n n p 483 660 2048 2048 660 660 none -b f32 bf16 r n n n p 483 660 2048 2048 660 660 bias,relu,clip -b bf16 bf16 r n n n p 483 660 2048 2048 660 660 bias,relu,clip -b f32 r n n n p 484 660 2048 2048 660 660 none -b bf16 r n n n p 484 660 2048 2048 660 660 none -b f32 bf16 r n n n p 484 660 2048 2048 660 660 bias,relu,clip -b bf16 bf16 r n n n p 484 660 2048 2048 660 660 bias,relu,clip -b f32 r n n n p 485 660 2048 2048 660 660 none -b bf16 r n n n p 485 660 2048 2048 660 660 none -b f32 bf16 r n n n p 485 660 2048 2048 660 660 bias,relu,clip -b bf16 bf16 r n n n p 485 660 2048 2048 660 660 bias,relu,clip -b f32 r n n n p 480 679 2048 2048 679 679 none -b bf16 r n n n p 480 679 2048 2048 679 679 none -b f32 bf16 r n n n p 480 679 2048 2048 679 679 bias,relu,clip -b bf16 bf16 r n n n p 480 679 2048 2048 679 679 bias,relu,clip -b f32 r n n n p 481 679 2048 2048 679 679 none -b bf16 r n n n p 481 679 2048 2048 679 679 none -b f32 bf16 r n n n p 481 679 2048 2048 679 679 bias,relu,clip -b bf16 bf16 r n n n p 481 679 2048 2048 679 679 bias,relu,clip -b f32 r n n n p 482 679 2048 2048 679 679 none -b bf16 r n n n p 482 679 2048 2048 679 679 none -b f32 bf16 r n n n p 482 679 2048 2048 679 679 bias,relu,clip -b bf16 bf16 r n n n p 482 679 2048 2048 679 679 bias,relu,clip -b f32 r n n n p 483 679 2048 2048 679 679 none -b bf16 r n n n p 483 679 2048 2048 679 679 none -b f32 bf16 r n n n p 483 679 2048 2048 679 679 bias,relu,clip -b bf16 bf16 r n n n p 483 679 2048 2048 679 679 bias,relu,clip -b f32 r n n n p 484 679 2048 2048 679 679 none -b bf16 r n n n p 484 679 2048 2048 679 679 none -b f32 bf16 r n n n p 484 679 2048 2048 679 679 bias,relu,clip -b bf16 bf16 r n n n p 484 679 2048 2048 679 679 bias,relu,clip -b f32 r n n n p 485 679 2048 2048 679 679 none -b bf16 r n n n p 485 679 2048 2048 679 679 none -b f32 bf16 r n n n p 485 679 2048 2048 679 679 bias,relu,clip -b bf16 bf16 r n n n p 485 679 2048 2048 679 679 bias,relu,clip -b f32 r n n n p 480 690 2048 2048 690 690 none -b bf16 r n n n p 480 690 2048 2048 690 690 none -b f32 bf16 r n n n p 480 690 2048 2048 690 690 bias,relu,clip -b bf16 bf16 r n n n p 480 690 2048 2048 690 690 bias,relu,clip -b f32 r n n n p 481 690 2048 2048 690 690 none -b bf16 r n n n p 481 690 2048 2048 690 690 none -b f32 bf16 r n n n p 481 690 2048 2048 690 690 bias,relu,clip -b bf16 bf16 r n n n p 481 690 2048 2048 690 690 bias,relu,clip -b f32 r n n n p 482 690 2048 2048 690 690 none -b bf16 r n n n p 482 690 2048 2048 690 690 none -b f32 bf16 r n n n p 482 690 2048 2048 690 690 bias,relu,clip -b bf16 bf16 r n n n p 482 690 2048 2048 690 690 bias,relu,clip -b f32 r n n n p 483 690 2048 2048 690 690 none -b bf16 r n n n p 483 690 2048 2048 690 690 none -b f32 bf16 r n n n p 483 690 2048 2048 690 690 bias,relu,clip -b bf16 bf16 r n n n p 483 690 2048 2048 690 690 bias,relu,clip -b f32 r n n n p 484 690 2048 2048 690 690 none -b bf16 r n n n p 484 690 2048 2048 690 690 none -b f32 bf16 r n n n p 484 690 2048 2048 690 690 bias,relu,clip -b bf16 bf16 r n n n p 484 690 2048 2048 690 690 bias,relu,clip -b f32 r n n n p 485 690 2048 2048 690 690 none -b bf16 r n n n p 485 690 2048 2048 690 690 none -b f32 bf16 r n n n p 485 690 2048 2048 690 690 bias,relu,clip -b bf16 bf16 r n n n p 485 690 2048 2048 690 690 bias,relu,clip -b f32 r n n n p 480 656 1024 1024 656 656 none -b bf16 r n n n p 480 656 1024 1024 656 656 none -b f32 bf16 r n n n p 480 656 1024 1024 656 656 bias,relu,clip -b bf16 bf16 r n n n p 480 656 1024 1024 656 656 bias,relu,clip -b f32 r n n n p 480 128 3 3 128 128 none -b bf16 r n n n p 480 128 3 3 128 128 none -b f32 bf16 r n n n p 480 128 3 3 128 128 bias,relu,clip -b bf16 bf16 r n n n p 480 128 3 3 128 128 bias,relu,clip -b f32 r n n n p 1024 512 515 515 512 512 none -b bf16 r n n n p 1024 512 515 515 512 512 none -b f32 bf16 r n n n p 1024 512 515 515 512 512 bias,relu,clip -b bf16 bf16 r n n n p 1024 512 515 515 512 512 bias,relu,clip -b f32 r n n n p 1024 2048 1024 1024 2048 2048 none -b bf16 r n n n p 1024 2048 1024 1024 2048 2048 none -b f32 bf16 r n n n p 1024 2048 1024 1024 2048 2048 bias,relu,clip -b bf16 bf16 r n n n p 1024 2048 1024 1024 2048 2048 bias,relu,clip -b f32 r n n n p 1024 2048 515 515 2048 2048 none -b bf16 r n n n p 1024 2048 515 515 2048 2048 none -b f32 bf16 r n n n p 1024 2048 515 515 2048 2048 bias,relu,clip -b bf16 bf16 r n n n p 1024 2048 515 515 2048 2048 bias,relu,clip -b f32 r n n n p 1024 1040 515 515 1040 1040 none -b bf16 r n n n p 1024 1040 515 515 1040 1040 none -b f32 bf16 r n n n p 1024 1040 515 515 1040 1040 bias,relu,clip -b bf16 bf16 r n n n p 1024 1040 515 515 1040 1040 bias,relu,clip -b f32 r n n n p 5 1029 515 515 1029 1029 none -b bf16 r n n n p 5 1029 515 515 1029 1029 none -b f32 bf16 r n n n p 5 1029 515 515 1029 1029 bias,relu,clip -b bf16 bf16 r n n n p 5 1029 515 515 1029 1029 bias,relu,clip -b f32 r n n n p 1024 1029 515 515 1029 1029 none -b bf16 r n n n p 1024 1029 515 515 1029 1029 none -b f32 bf16 r n n n p 1024 1029 515 515 1029 1029 bias,relu,clip -b bf16 bf16 r n n n p 1024 1029 515 515 1029 1029 bias,relu,clip -b f32 r n n n p 1024 1040 2050 2050 1040 1040 none -b bf16 r n n n p 1024 1040 2050 2050 1040 1040 none -b f32 bf16 r n n n p 1024 1040 2050 2050 1040 1040 bias,relu,clip -b bf16 bf16 r n n n p 1024 1040 2050 2050 1040 1040 bias,relu,clip -b f32 r n n n p 1029 1029 2050 2050 1029 1029 none -b bf16 r n n n p 1029 1029 2050 2050 1029 1029 none -b f32 bf16 r n n n p 1029 1029 2050 2050 1029 1029 bias,relu,clip -b bf16 bf16 r n n n p 1029 1029 2050 2050 1029 1029 bias,relu,clip -b f32 r n n n R 480 646 2050 2050 646 646 none -b bf16 r n n n R 480 646 2050 2050 646 646 none -b f32 bf16 r n n n R 480 646 2050 2050 646 646 bias,relu,clip -b bf16 bf16 r n n n R 480 646 2050 2050 646 646 bias,relu,clip -b f32 r n n n R 481 646 2050 2050 646 646 none -b bf16 r n n n R 481 646 2050 2050 646 646 none -b f32 bf16 r n n n R 481 646 2050 2050 646 646 bias,relu,clip -b bf16 bf16 r n n n R 481 646 2050 2050 646 646 bias,relu,clip -b f32 r n n n R 482 646 2050 2050 646 646 none -b bf16 r n n n R 482 646 2050 2050 646 646 none -b f32 bf16 r n n n R 482 646 2050 2050 646 646 bias,relu,clip -b bf16 bf16 r n n n R 482 646 2050 2050 646 646 bias,relu,clip -b f32 r n n n R 483 646 2050 2050 646 646 none -b bf16 r n n n R 483 646 2050 2050 646 646 none -b f32 bf16 r n n n R 483 646 2050 2050 646 646 bias,relu,clip -b bf16 bf16 r n n n R 483 646 2050 2050 646 646 bias,relu,clip -b f32 r n n n R 484 646 2050 2050 646 646 none -b bf16 r n n n R 484 646 2050 2050 646 646 none -b f32 bf16 r n n n R 484 646 2050 2050 646 646 bias,relu,clip -b bf16 bf16 r n n n R 484 646 2050 2050 646 646 bias,relu,clip -b f32 r n n n R 485 646 2050 2050 646 646 none -b bf16 r n n n R 485 646 2050 2050 646 646 none -b f32 bf16 r n n n R 485 646 2050 2050 646 646 bias,relu,clip -b bf16 bf16 r n n n R 485 646 2050 2050 646 646 bias,relu,clip -b f32 r n n n R 481 656 2050 2050 656 656 none -b bf16 r n n n R 481 656 2050 2050 656 656 none -b f32 bf16 r n n n R 481 656 2050 2050 656 656 bias,relu,clip -b bf16 bf16 r n n n R 481 656 2050 2050 656 656 bias,relu,clip -b f32 r n n n R 482 656 2050 2050 656 656 none -b bf16 r n n n R 482 656 2050 2050 656 656 none -b f32 bf16 r n n n R 482 656 2050 2050 656 656 bias,relu,clip -b bf16 bf16 r n n n R 482 656 2050 2050 656 656 bias,relu,clip -b f32 r n n n R 483 656 2050 2050 656 656 none -b bf16 r n n n R 483 656 2050 2050 656 656 none -b f32 bf16 r n n n R 483 656 2050 2050 656 656 bias,relu,clip -b bf16 bf16 r n n n R 483 656 2050 2050 656 656 bias,relu,clip -b f32 r n n n R 484 656 2050 2050 656 656 none -b bf16 r n n n R 484 656 2050 2050 656 656 none -b f32 bf16 r n n n R 484 656 2050 2050 656 656 bias,relu,clip -b bf16 bf16 r n n n R 484 656 2050 2050 656 656 bias,relu,clip -b f32 r n n n p 485 656 2050 2050 656 656 none -b bf16 r n n n p 485 656 2050 2050 656 656 none -b f32 bf16 r n n n p 485 656 2050 2050 656 656 bias,relu,clip -b bf16 bf16 r n n n p 485 656 2050 2050 656 656 bias,relu,clip -b f32 r n n n p 480 672 2050 2050 672 672 none -b bf16 r n n n p 480 672 2050 2050 672 672 none -b f32 bf16 r n n n p 480 672 2050 2050 672 672 bias,relu,clip -b bf16 bf16 r n n n p 480 672 2050 2050 672 672 bias,relu,clip -b f32 r n n n p 481 672 2050 2050 672 672 none -b bf16 r n n n p 481 672 2050 2050 672 672 none -b f32 bf16 r n n n p 481 672 2050 2050 672 672 bias,relu,clip -b bf16 bf16 r n n n p 481 672 2050 2050 672 672 bias,relu,clip -b f32 r n n n p 482 672 2050 2050 672 672 none -b bf16 r n n n p 482 672 2050 2050 672 672 none -b f32 bf16 r n n n p 482 672 2050 2050 672 672 bias,relu,clip -b bf16 bf16 r n n n p 482 672 2050 2050 672 672 bias,relu,clip -b f32 r n n n p 483 672 2050 2050 672 672 none -b bf16 r n n n p 483 672 2050 2050 672 672 none -b f32 bf16 r n n n p 483 672 2050 2050 672 672 bias,relu,clip -b bf16 bf16 r n n n p 483 672 2050 2050 672 672 bias,relu,clip -b f32 r n n n p 484 672 2050 2050 672 672 none -b bf16 r n n n p 484 672 2050 2050 672 672 none -b f32 bf16 r n n n p 484 672 2050 2050 672 672 bias,relu,clip -b bf16 bf16 r n n n p 484 672 2050 2050 672 672 bias,relu,clip -b f32 r n n n p 485 672 2050 2050 672 672 none -b bf16 r n n n p 485 672 2050 2050 672 672 none -b f32 bf16 r n n n p 485 672 2050 2050 672 672 bias,relu,clip -b bf16 bf16 r n n n p 485 672 2050 2050 672 672 bias,relu,clip -b f32 r n n n p 480 688 2050 2050 688 688 none -b bf16 r n n n p 480 688 2050 2050 688 688 none -b f32 bf16 r n n n p 480 688 2050 2050 688 688 bias,relu,clip -b bf16 bf16 r n n n p 480 688 2050 2050 688 688 bias,relu,clip -b f32 r n n n p 481 688 2050 2050 688 688 none -b bf16 r n n n p 481 688 2050 2050 688 688 none -b f32 bf16 r n n n p 481 688 2050 2050 688 688 bias,relu,clip -b bf16 bf16 r n n n p 481 688 2050 2050 688 688 bias,relu,clip -b f32 r n n n r 482 688 2050 2050 688 688 none -b bf16 r n n n r 482 688 2050 2050 688 688 none -b f32 bf16 r n n n r 482 688 2050 2050 688 688 bias,relu,clip -b bf16 bf16 r n n n r 482 688 2050 2050 688 688 bias,relu,clip -b f32 r n n n r 483 688 2050 2050 688 688 none -b bf16 r n n n r 483 688 2050 2050 688 688 none -b f32 bf16 r n n n r 483 688 2050 2050 688 688 bias,relu,clip -b bf16 bf16 r n n n r 483 688 2050 2050 688 688 bias,relu,clip -b f32 r n n n r 484 688 2050 2050 688 688 none -b bf16 r n n n r 484 688 2050 2050 688 688 none -b f32 bf16 r n n n r 484 688 2050 2050 688 688 bias,relu,clip -b bf16 bf16 r n n n r 484 688 2050 2050 688 688 bias,relu,clip -b f32 r n n n r 485 688 2050 2050 688 688 none -b bf16 r n n n r 485 688 2050 2050 688 688 none -b f32 bf16 r n n n r 485 688 2050 2050 688 688 bias,relu,clip -b bf16 bf16 r n n n r 485 688 2050 2050 688 688 bias,relu,clip -b f32 r n n n r 1024 512 64 64 512 512 none -b bf16 r n n n r 1024 512 64 64 512 512 none -b f32 bf16 r n n n r 1024 512 64 64 512 512 bias,relu,clip -b bf16 bf16 r n n n r 1024 512 64 64 512 512 bias,relu,clip -b f32 r n n n r 16 256 512 512 256 256 none -b bf16 r n n n r 16 256 512 512 256 256 none -b f32 bf16 r n n n r 16 256 512 512 256 256 bias,relu,clip -b bf16 bf16 r n n n r 16 256 512 512 256 256 bias,relu,clip -b f32 r n n n r 480 640 512 512 640 640 none -b bf16 r n n n r 480 640 512 512 640 640 none -b f32 bf16 r n n n r 480 640 512 512 640 640 bias,relu,clip -b bf16 bf16 r n n n r 480 640 512 512 640 640 bias,relu,clip -b f32 r n n n r 64 768 512 512 768 768 none -b bf16 r n n n r 64 768 512 512 768 768 none -b f32 bf16 r n n n r 64 768 512 512 768 768 bias,relu,clip -b bf16 bf16 r n n n r 64 768 512 512 768 768 bias,relu,clip -b f32 r n n n r 128 128 128 128 128 128 none -b bf16 r n n n r 128 128 128 128 128 128 none -b f32 bf16 r n n n r 128 128 128 128 128 128 bias,relu,clip -b bf16 bf16 r n n n r 128 128 128 128 128 128 bias,relu,clip -b f32 r n n n r 1024 64 512 512 64 64 none -b bf16 r n n n r 1024 64 512 512 64 64 none -b f32 bf16 r n n n r 1024 64 512 512 64 64 bias,relu,clip -b bf16 bf16 r n n n r 1024 64 512 512 64 64 bias,relu,clip -b f32 r n n n r 1024 256 32 32 256 256 none -b bf16 r n n n r 1024 256 32 32 256 256 none -b f32 bf16 r n n n r 1024 256 32 32 256 256 bias,relu,clip -b bf16 bf16 r n n n r 1024 256 32 32 256 256 bias,relu,clip -b f32 r n n n r 1024 512 64 64 512 512 none -b bf16 r n n n r 1024 512 64 64 512 512 none -b f32 bf16 r n n n r 1024 512 64 64 512 512 bias,relu,clip -b bf16 bf16 r n n n r 1024 512 64 64 512 512 bias,relu,clip -b f32 r n n n r 480 640 512 512 640 640 none -b bf16 r n n n r 480 640 512 512 640 640 none -b f32 bf16 r n n n r 480 640 512 512 640 640 bias,relu,clip -b bf16 bf16 r n n n r 480 640 512 512 640 640 bias,relu,clip -b f32 r n n n p 1024 32 256 256 32 32 none -b bf16 r n n n p 1024 32 256 256 32 32 none -b f32 bf16 r n n n p 1024 32 256 256 32 32 bias,relu,clip -b bf16 bf16 r n n n p 1024 32 256 256 32 32 bias,relu,clip -b f32 r n n n P 1024 64 512 512 64 64 none -b bf16 r n n n P 1024 64 512 512 64 64 none -b f32 bf16 r n n n P 1024 64 512 512 64 64 bias,relu,clip -b bf16 bf16 r n n n P 1024 64 512 512 64 64 bias,relu,clip -b f32 r n n n P 64 800 320 320 800 800 none -b bf16 r n n n P 64 800 320 320 800 800 none -b f32 bf16 r n n n P 64 800 320 320 800 800 bias,relu,clip -b bf16 bf16 r n n n P 64 800 320 320 800 800 bias,relu,clip -b f32 r n n n P 64 768 512 512 768 768 none -b bf16 r n n n P 64 768 512 512 768 768 none -b f32 bf16 r n n n P 64 768 512 512 768 768 bias,relu,clip -b bf16 bf16 r n n n P 64 768 512 512 768 768 bias,relu,clip -b f32 r n n n P 16 256 512 512 256 256 none -b bf16 r n n n P 16 256 512 512 256 256 none -b f32 bf16 r n n n P 16 256 512 512 256 256 bias,relu,clip -b bf16 bf16 r n n n P 16 256 512 512 256 256 bias,relu,clip -b f32 r n n n P 128 128 128 128 128 128 none -b bf16 r n n n P 128 128 128 128 128 128 none -b f32 bf16 r n n n P 128 128 128 128 128 128 bias,relu,clip -b bf16 bf16 r n n n P 128 128 128 128 128 128 bias,relu,clip -b f32 r n n n P 256 512 256 256 512 512 none -b bf16 r n n n P 256 512 256 256 512 512 none -b f32 bf16 r n n n P 256 512 256 256 512 512 bias,relu,clip -b bf16 bf16 r n n n P 256 512 256 256 512 512 bias,relu,clip -b f32 r n n n P 1024 1024 1024 1024 1024 1024 none -b bf16 r n n n P 1024 1024 1024 1024 1024 1024 none -b f32 bf16 r n n n P 1024 1024 1024 1024 1024 1024 bias,relu,clip -b bf16 bf16 r n n n P 1024 1024 1024 1024 1024 1024 bias,relu,clip -b f32 r n n n P 480 640 1024 1024 640 640 none -b bf16 r n n n P 480 640 1024 1024 640 640 none -b f32 bf16 r n n n P 480 640 1024 1024 640 640 bias,relu,clip -b bf16 bf16 r n n n P 480 640 1024 1024 640 640 bias,relu,clip -b f32 r n n n P 480 640 256 256 640 640 none -b bf16 r n n n P 480 640 256 256 640 640 none -b f32 bf16 r n n n P 480 640 256 256 640 640 bias,relu,clip -b bf16 bf16 r n n n P 480 640 256 256 640 640 bias,relu,clip -b f32 r n n n P 8 64 32 32 64 64 none -b bf16 r n n n P 8 64 32 32 64 64 none -b f32 bf16 r n n n P 8 64 32 32 64 64 bias,relu,clip -b bf16 bf16 r n n n P 8 64 32 32 64 64 bias,relu,clip -b f32 r n n n P 9 64 32 32 64 64 none -b bf16 r n n n P 9 64 32 32 64 64 none -b f32 bf16 r n n n P 9 64 32 32 64 64 bias,relu,clip -b bf16 bf16 r n n n P 9 64 32 32 64 64 bias,relu,clip -b f32 r n n n P 10 128 64 64 128 128 none -b bf16 r n n n P 10 128 64 64 128 128 none -b f32 bf16 r n n n P 10 128 64 64 128 128 bias,relu,clip -b bf16 bf16 r n n n P 10 128 64 64 128 128 bias,relu,clip -b f32 r n n n P 8 8 8 8 8 8 none -b bf16 r n n n P 8 8 8 8 8 8 none -b f32 bf16 r n n n P 8 8 8 8 8 8 bias,relu,clip -b bf16 bf16 r n n n P 8 8 8 8 8 8 bias,relu,clip -b f32 r n n n P 12 12 12 12 12 12 none -b bf16 r n n n P 12 12 12 12 12 12 none -b f32 bf16 r n n n P 12 12 12 12 12 12 bias,relu,clip -b bf16 bf16 r n n n P 12 12 12 12 12 12 bias,relu,clip -b f32 r n n n P 25 25 25 25 25 25 none -b bf16 r n n n P 25 25 25 25 25 25 none -b f32 bf16 r n n n P 25 25 25 25 25 25 bias,relu,clip -b bf16 bf16 r n n n P 25 25 25 25 25 25 bias,relu,clip -b f32 r n n n P 25 25 20 20 25 25 none -b bf16 r n n n P 25 25 20 20 25 25 none -b f32 bf16 r n n n P 25 25 20 20 25 25 bias,relu,clip -b bf16 bf16 r n n n P 25 25 20 20 25 25 bias,relu,clip -b f32 c n n n p 485 39 2050 485 2050 485 none -b bf16 c n n n p 485 39 2050 485 2050 485 none -b f32 bf16 c n n n p 485 39 2050 485 2050 485 bias,relu,clip -b bf16 bf16 c n n n p 485 39 2050 485 2050 485 bias,relu,clip -b f32 c n n n p 480 50 2050 480 2050 480 none -b bf16 c n n n p 480 50 2050 480 2050 480 none -b f32 bf16 c n n n p 480 50 2050 480 2050 480 bias,relu,clip -b bf16 bf16 c n n n p 480 50 2050 480 2050 480 bias,relu,clip -b f32 c n n n p 481 50 2050 481 2050 481 none -b bf16 c n n n p 481 50 2050 481 2050 481 none -b f32 bf16 c n n n p 481 50 2050 481 2050 481 bias,relu,clip -b bf16 bf16 c n n n p 481 50 2050 481 2050 481 bias,relu,clip -b f32 c n n n p 482 50 2050 482 2050 482 none -b bf16 c n n n p 482 50 2050 482 2050 482 none -b f32 bf16 c n n n p 482 50 2050 482 2050 482 bias,relu,clip -b bf16 bf16 c n n n p 482 50 2050 482 2050 482 bias,relu,clip -b f32 c n n n p 483 50 2050 483 2050 483 none -b bf16 c n n n p 483 50 2050 483 2050 483 none -b f32 bf16 c n n n p 483 50 2050 483 2050 483 bias,relu,clip -b bf16 bf16 c n n n p 483 50 2050 483 2050 483 bias,relu,clip -b f32 c n n n p 484 50 2050 484 2050 484 none -b bf16 c n n n p 484 50 2050 484 2050 484 none -b f32 bf16 c n n n p 484 50 2050 484 2050 484 bias,relu,clip -b bf16 bf16 c n n n p 484 50 2050 484 2050 484 bias,relu,clip -b f32 c n n n p 485 50 2050 485 2050 485 none -b bf16 c n n n p 485 50 2050 485 2050 485 none -b f32 bf16 c n n n p 485 50 2050 485 2050 485 bias,relu,clip -b bf16 bf16 c n n n p 485 50 2050 485 2050 485 bias,relu,clip -b f32 c n n n p 484 1127 2050 484 2050 484 none -b bf16 c n n n p 484 1127 2050 484 2050 484 none -b f32 bf16 c n n n p 484 1127 2050 484 2050 484 bias,relu,clip -b bf16 bf16 c n n n p 484 1127 2050 484 2050 484 bias,relu,clip -b f32 c n n n p 485 1127 2050 485 2050 485 none -b bf16 c n n n p 485 1127 2050 485 2050 485 none -b f32 bf16 c n n n p 485 1127 2050 485 2050 485 bias,relu,clip -b bf16 bf16 c n n n p 485 1127 2050 485 2050 485 bias,relu,clip -b f32 c n n n p 480 1138 2050 480 2050 480 none -b bf16 c n n n p 480 1138 2050 480 2050 480 none -b f32 bf16 c n n n p 480 1138 2050 480 2050 480 bias,relu,clip -b bf16 bf16 c n n n p 480 1138 2050 480 2050 480 bias,relu,clip -b f32 c n n n p 481 1138 2050 481 2050 481 none -b bf16 c n n n p 481 1138 2050 481 2050 481 none -b f32 bf16 c n n n p 481 1138 2050 481 2050 481 bias,relu,clip -b bf16 bf16 c n n n p 481 1138 2050 481 2050 481 bias,relu,clip -b f32 c n n n p 482 1138 2050 482 2050 482 none -b bf16 c n n n p 482 1138 2050 482 2050 482 none -b f32 bf16 c n n n p 482 1138 2050 482 2050 482 bias,relu,clip -b bf16 bf16 c n n n p 482 1138 2050 482 2050 482 bias,relu,clip -b f32 c n n n p 483 1138 2050 483 2050 483 none -b bf16 c n n n p 483 1138 2050 483 2050 483 none -b f32 bf16 c n n n p 483 1138 2050 483 2050 483 bias,relu,clip -b bf16 bf16 c n n n p 483 1138 2050 483 2050 483 bias,relu,clip -b f32 c n n n p 484 1138 2050 484 2050 484 none -b bf16 c n n n p 484 1138 2050 484 2050 484 none -b f32 bf16 c n n n p 484 1138 2050 484 2050 484 bias,relu,clip -b bf16 bf16 c n n n p 484 1138 2050 484 2050 484 bias,relu,clip -b f32 c n n n p 485 1138 2050 485 2050 485 none -b bf16 c n n n p 485 1138 2050 485 2050 485 none -b f32 bf16 c n n n p 485 1138 2050 485 2050 485 bias,relu,clip -b bf16 bf16 c n n n p 485 1138 2050 485 2050 485 bias,relu,clip -b f32 c n n n p 1 1 3 1 3 1 none -b bf16 c n n n p 1 1 3 1 3 1 none -b f32 bf16 c n n n p 1 1 3 1 3 1 bias,relu,clip -b bf16 bf16 c n n n p 1 1 3 1 3 1 bias,relu,clip -b f32 c n n n p 1 9 3 1 3 1 none -b bf16 c n n n p 1 9 3 1 3 1 none -b f32 bf16 c n n n p 1 9 3 1 3 1 bias,relu,clip -b bf16 bf16 c n n n p 1 9 3 1 3 1 bias,relu,clip -b f32 c n n n p 1 2048 3 1 3 1 none -b bf16 c n n n p 1 2048 3 1 3 1 none -b f32 bf16 c n n n p 1 2048 3 1 3 1 bias,relu,clip -b bf16 bf16 c n n n p 1 2048 3 1 3 1 bias,relu,clip -b f32 c n n n p 1 2048 5192 1 5192 1 none -b bf16 c n n n p 1 2048 5192 1 5192 1 none -b f32 bf16 c n n n p 1 2048 5192 1 5192 1 bias,relu,clip -b bf16 bf16 c n n n p 1 2048 5192 1 5192 1 bias,relu,clip -b f32 c n n n p 9 1 3 9 3 9 none -b bf16 c n n n p 9 1 3 9 3 9 none -b f32 bf16 c n n n p 9 1 3 9 3 9 bias,relu,clip -b bf16 bf16 c n n n p 9 1 3 9 3 9 bias,relu,clip -b f32 c n n n p 576 1 3500 576 3500 576 none -b bf16 c n n n p 576 1 3500 576 3500 576 none -b f32 bf16 c n n n p 576 1 3500 576 3500 576 bias,relu,clip -b bf16 bf16 c n n n p 576 1 3500 576 3500 576 bias,relu,clip -b f32 c n n n p 1 1 1 1 1 1 none -b bf16 c n n n p 1 1 1 1 1 1 none -b f32 bf16 c n n n p 1 1 1 1 1 1 bias,relu,clip -b bf16 bf16 c n n n p 1 1 1 1 1 1 bias,relu,clip -b f32 c n n n p 102 1088 1024 102 1024 102 none -b bf16 c n n n p 102 1088 1024 102 1024 102 none -b f32 bf16 c n n n p 102 1088 1024 102 1024 102 bias,relu,clip -b bf16 bf16 c n n n p 102 1088 1024 102 1024 102 bias,relu,clip -b f32 c n n n p 102 2048 1024 102 1024 102 none -b bf16 c n n n p 102 2048 1024 102 1024 102 none -b f32 bf16 c n n n p 102 2048 1024 102 1024 102 bias,relu,clip -b bf16 bf16 c n n n p 102 2048 1024 102 1024 102 bias,relu,clip -b f32 c n n n p 485 656 1024 485 1024 485 none -b bf16 c n n n p 485 656 1024 485 1024 485 none -b f32 bf16 c n n n p 485 656 1024 485 1024 485 bias,relu,clip -b bf16 bf16 c n n n p 485 656 1024 485 1024 485 bias,relu,clip -b f32 c n n n p 483 656 1024 483 1024 483 none -b bf16 c n n n p 483 656 1024 483 1024 483 none -b f32 bf16 c n n n p 483 656 1024 483 1024 483 bias,relu,clip -b bf16 bf16 c n n n p 483 656 1024 483 1024 483 bias,relu,clip -b f32 c n n n p 81 128 3 81 3 81 none -b bf16 c n n n p 81 128 3 81 3 81 none -b f32 bf16 c n n n p 81 128 3 81 3 81 bias,relu,clip -b bf16 bf16 c n n n p 81 128 3 81 3 81 bias,relu,clip -b f32 c n n n p 1022 512 515 1022 515 1022 none -b bf16 c n n n p 1022 512 515 1022 515 1022 none -b f32 bf16 c n n n p 1022 512 515 1022 515 1022 bias,relu,clip -b bf16 bf16 c n n n p 1022 512 515 1022 515 1022 bias,relu,clip -b f32 c n n n p 74 512 515 74 515 74 none -b bf16 c n n n p 74 512 515 74 515 74 none -b f32 bf16 c n n n p 74 512 515 74 515 74 bias,relu,clip -b bf16 bf16 c n n n p 74 512 515 74 515 74 bias,relu,clip -b f32 c n n n p 253 2048 515 253 515 253 none -b bf16 c n n n p 253 2048 515 253 515 253 none -b f32 bf16 c n n n p 253 2048 515 253 515 253 bias,relu,clip -b bf16 bf16 c n n n p 253 2048 515 253 515 253 bias,relu,clip -b f32 c n n n p 8192 1040 515 8192 515 8192 none -b bf16 c n n n p 8192 1040 515 8192 515 8192 none -b f32 bf16 c n n n p 8192 1040 515 8192 515 8192 bias,relu,clip -b bf16 bf16 c n n n p 8192 1040 515 8192 515 8192 bias,relu,clip -b f32 c n n n p 10 1029 515 10 515 10 none -b bf16 c n n n p 10 1029 515 10 515 10 none -b f32 bf16 c n n n p 10 1029 515 10 515 10 bias,relu,clip -b bf16 bf16 c n n n p 10 1029 515 10 515 10 bias,relu,clip -b f32 c n n n p 24 1040 2050 24 2050 24 none -b bf16 c n n n p 24 1040 2050 24 2050 24 none -b f32 bf16 c n n n p 24 1040 2050 24 2050 24 bias,relu,clip -b bf16 bf16 c n n n p 24 1040 2050 24 2050 24 bias,relu,clip -b f32 c n n n p 1024 1029 2050 1024 2050 1024 none -b bf16 c n n n p 1024 1029 2050 1024 2050 1024 none -b f32 bf16 c n n n p 1024 1029 2050 1024 2050 1024 bias,relu,clip -b bf16 bf16 c n n n p 1024 1029 2050 1024 2050 1024 bias,relu,clip -b f32 c n n n p 480 660 2050 480 2050 480 none -b bf16 c n n n p 480 660 2050 480 2050 480 none -b f32 bf16 c n n n p 480 660 2050 480 2050 480 bias,relu,clip -b bf16 bf16 c n n n p 480 660 2050 480 2050 480 bias,relu,clip -b f32 c n n n p 481 660 2050 481 2050 481 none -b bf16 c n n n p 481 660 2050 481 2050 481 none -b f32 bf16 c n n n p 481 660 2050 481 2050 481 bias,relu,clip -b bf16 bf16 c n n n p 481 660 2050 481 2050 481 bias,relu,clip -b f32 c n n n p 482 660 2050 482 2050 482 none -b bf16 c n n n p 482 660 2050 482 2050 482 none -b f32 bf16 c n n n p 482 660 2050 482 2050 482 bias,relu,clip -b bf16 bf16 c n n n p 482 660 2050 482 2050 482 bias,relu,clip -b f32 c n n n p 483 660 2050 483 2050 483 none -b bf16 c n n n p 483 660 2050 483 2050 483 none -b f32 bf16 c n n n p 483 660 2050 483 2050 483 bias,relu,clip -b bf16 bf16 c n n n p 483 660 2050 483 2050 483 bias,relu,clip -b f32 c n n n p 484 660 2050 484 2050 484 none -b bf16 c n n n p 484 660 2050 484 2050 484 none -b f32 bf16 c n n n p 484 660 2050 484 2050 484 bias,relu,clip -b bf16 bf16 c n n n p 484 660 2050 484 2050 484 bias,relu,clip -b f32 c n n n p 485 660 2050 485 2050 485 none -b bf16 c n n n p 485 660 2050 485 2050 485 none -b f32 bf16 c n n n p 485 660 2050 485 2050 485 bias,relu,clip -b bf16 bf16 c n n n p 485 660 2050 485 2050 485 bias,relu,clip -b f32 c n n n p 480 679 2050 480 2050 480 none -b bf16 c n n n p 480 679 2050 480 2050 480 none -b f32 bf16 c n n n p 480 679 2050 480 2050 480 bias,relu,clip -b bf16 bf16 c n n n p 480 679 2050 480 2050 480 bias,relu,clip -b f32 c n n n p 481 679 2050 481 2050 481 none -b bf16 c n n n p 481 679 2050 481 2050 481 none -b f32 bf16 c n n n p 481 679 2050 481 2050 481 bias,relu,clip -b bf16 bf16 c n n n p 481 679 2050 481 2050 481 bias,relu,clip -b f32 c n n n p 482 679 2050 482 2050 482 none -b bf16 c n n n p 482 679 2050 482 2050 482 none -b f32 bf16 c n n n p 482 679 2050 482 2050 482 bias,relu,clip -b bf16 bf16 c n n n p 482 679 2050 482 2050 482 bias,relu,clip -b f32 c n n n p 483 679 2050 483 2050 483 none -b bf16 c n n n p 483 679 2050 483 2050 483 none -b f32 bf16 c n n n p 483 679 2050 483 2050 483 bias,relu,clip -b bf16 bf16 c n n n p 483 679 2050 483 2050 483 bias,relu,clip -b f32 c n n n p 484 679 2050 484 2050 484 none -b bf16 c n n n p 484 679 2050 484 2050 484 none -b f32 bf16 c n n n p 484 679 2050 484 2050 484 bias,relu,clip -b bf16 bf16 c n n n p 484 679 2050 484 2050 484 bias,relu,clip -b f32 c n n n p 485 679 2050 485 2050 485 none -b bf16 c n n n p 485 679 2050 485 2050 485 none -b f32 bf16 c n n n p 485 679 2050 485 2050 485 bias,relu,clip -b bf16 bf16 c n n n p 485 679 2050 485 2050 485 bias,relu,clip -b f32 c n n n p 480 690 2050 480 2050 480 none -b bf16 c n n n p 480 690 2050 480 2050 480 none -b f32 bf16 c n n n p 480 690 2050 480 2050 480 bias,relu,clip -b bf16 bf16 c n n n p 480 690 2050 480 2050 480 bias,relu,clip -b f32 c n n n p 481 690 2050 481 2050 481 none -b bf16 c n n n p 481 690 2050 481 2050 481 none -b f32 bf16 c n n n p 481 690 2050 481 2050 481 bias,relu,clip -b bf16 bf16 c n n n p 481 690 2050 481 2050 481 bias,relu,clip -b f32 c n n n p 482 690 2050 482 2050 482 none -b bf16 c n n n p 482 690 2050 482 2050 482 none -b f32 bf16 c n n n p 482 690 2050 482 2050 482 bias,relu,clip -b bf16 bf16 c n n n p 482 690 2050 482 2050 482 bias,relu,clip -b f32 c n n n p 483 690 2050 483 2050 483 none -b bf16 c n n n p 483 690 2050 483 2050 483 none -b f32 bf16 c n n n p 483 690 2050 483 2050 483 bias,relu,clip -b bf16 bf16 c n n n p 483 690 2050 483 2050 483 bias,relu,clip -b f32 c n n n p 484 690 2050 484 2050 484 none -b bf16 c n n n p 484 690 2050 484 2050 484 none -b f32 bf16 c n n n p 484 690 2050 484 2050 484 bias,relu,clip -b bf16 bf16 c n n n p 484 690 2050 484 2050 484 bias,relu,clip -b f32 c n n n p 485 690 2050 485 2050 485 none -b bf16 c n n n p 485 690 2050 485 2050 485 none -b f32 bf16 c n n n p 485 690 2050 485 2050 485 bias,relu,clip -b bf16 bf16 c n n n p 485 690 2050 485 2050 485 bias,relu,clip -b f32 c n n n p 480 660 2048 480 2048 480 none -b bf16 c n n n p 480 660 2048 480 2048 480 none -b f32 bf16 c n n n p 480 660 2048 480 2048 480 bias,relu,clip -b bf16 bf16 c n n n p 480 660 2048 480 2048 480 bias,relu,clip -b f32 c n n n p 481 660 2048 481 2048 481 none -b bf16 c n n n p 481 660 2048 481 2048 481 none -b f32 bf16 c n n n p 481 660 2048 481 2048 481 bias,relu,clip -b bf16 bf16 c n n n p 481 660 2048 481 2048 481 bias,relu,clip -b f32 c n n n p 482 660 2048 482 2048 482 none -b bf16 c n n n p 482 660 2048 482 2048 482 none -b f32 bf16 c n n n p 482 660 2048 482 2048 482 bias,relu,clip -b bf16 bf16 c n n n p 482 660 2048 482 2048 482 bias,relu,clip -b f32 c n n n p 483 660 2048 483 2048 483 none -b bf16 c n n n p 483 660 2048 483 2048 483 none -b f32 bf16 c n n n p 483 660 2048 483 2048 483 bias,relu,clip -b bf16 bf16 c n n n p 483 660 2048 483 2048 483 bias,relu,clip -b f32 c n n n p 484 660 2048 484 2048 484 none -b bf16 c n n n p 484 660 2048 484 2048 484 none -b f32 bf16 c n n n p 484 660 2048 484 2048 484 bias,relu,clip -b bf16 bf16 c n n n p 484 660 2048 484 2048 484 bias,relu,clip -b f32 c n n n p 485 660 2048 485 2048 485 none -b bf16 c n n n p 485 660 2048 485 2048 485 none -b f32 bf16 c n n n p 485 660 2048 485 2048 485 bias,relu,clip -b bf16 bf16 c n n n p 485 660 2048 485 2048 485 bias,relu,clip -b f32 c n n n p 480 679 2048 480 2048 480 none -b bf16 c n n n p 480 679 2048 480 2048 480 none -b f32 bf16 c n n n p 480 679 2048 480 2048 480 bias,relu,clip -b bf16 bf16 c n n n p 480 679 2048 480 2048 480 bias,relu,clip -b f32 c n n n p 481 679 2048 481 2048 481 none -b bf16 c n n n p 481 679 2048 481 2048 481 none -b f32 bf16 c n n n p 481 679 2048 481 2048 481 bias,relu,clip -b bf16 bf16 c n n n p 481 679 2048 481 2048 481 bias,relu,clip -b f32 c n n n p 482 679 2048 482 2048 482 none -b bf16 c n n n p 482 679 2048 482 2048 482 none -b f32 bf16 c n n n p 482 679 2048 482 2048 482 bias,relu,clip -b bf16 bf16 c n n n p 482 679 2048 482 2048 482 bias,relu,clip -b f32 c n n n p 483 679 2048 483 2048 483 none -b bf16 c n n n p 483 679 2048 483 2048 483 none -b f32 bf16 c n n n p 483 679 2048 483 2048 483 bias,relu,clip -b bf16 bf16 c n n n p 483 679 2048 483 2048 483 bias,relu,clip -b f32 c n n n p 484 679 2048 484 2048 484 none -b bf16 c n n n p 484 679 2048 484 2048 484 none -b f32 bf16 c n n n p 484 679 2048 484 2048 484 bias,relu,clip -b bf16 bf16 c n n n p 484 679 2048 484 2048 484 bias,relu,clip -b f32 c n n n p 485 679 2048 485 2048 485 none -b bf16 c n n n p 485 679 2048 485 2048 485 none -b f32 bf16 c n n n p 485 679 2048 485 2048 485 bias,relu,clip -b bf16 bf16 c n n n p 485 679 2048 485 2048 485 bias,relu,clip -b f32 c n n n p 480 690 2048 480 2048 480 none -b bf16 c n n n p 480 690 2048 480 2048 480 none -b f32 bf16 c n n n p 480 690 2048 480 2048 480 bias,relu,clip -b bf16 bf16 c n n n p 480 690 2048 480 2048 480 bias,relu,clip -b f32 c n n n p 481 690 2048 481 2048 481 none -b bf16 c n n n p 481 690 2048 481 2048 481 none -b f32 bf16 c n n n p 481 690 2048 481 2048 481 bias,relu,clip -b bf16 bf16 c n n n p 481 690 2048 481 2048 481 bias,relu,clip -b f32 c n n n p 482 690 2048 482 2048 482 none -b bf16 c n n n p 482 690 2048 482 2048 482 none -b f32 bf16 c n n n p 482 690 2048 482 2048 482 bias,relu,clip -b bf16 bf16 c n n n p 482 690 2048 482 2048 482 bias,relu,clip -b f32 c n n n p 483 690 2048 483 2048 483 none -b bf16 c n n n p 483 690 2048 483 2048 483 none -b f32 bf16 c n n n p 483 690 2048 483 2048 483 bias,relu,clip -b bf16 bf16 c n n n p 483 690 2048 483 2048 483 bias,relu,clip -b f32 c n n n p 484 690 2048 484 2048 484 none -b bf16 c n n n p 484 690 2048 484 2048 484 none -b f32 bf16 c n n n p 484 690 2048 484 2048 484 bias,relu,clip -b bf16 bf16 c n n n p 484 690 2048 484 2048 484 bias,relu,clip -b f32 c n n n p 485 690 2048 485 2048 485 none -b bf16 c n n n p 485 690 2048 485 2048 485 none -b f32 bf16 c n n n p 485 690 2048 485 2048 485 bias,relu,clip -b bf16 bf16 c n n n p 485 690 2048 485 2048 485 bias,relu,clip -b f32 c n n n p 480 656 1024 480 1024 480 none -b bf16 c n n n p 480 656 1024 480 1024 480 none -b f32 bf16 c n n n p 480 656 1024 480 1024 480 bias,relu,clip -b bf16 bf16 c n n n p 480 656 1024 480 1024 480 bias,relu,clip -b f32 c n n n p 480 128 3 480 3 480 none -b bf16 c n n n p 480 128 3 480 3 480 none -b f32 bf16 c n n n p 480 128 3 480 3 480 bias,relu,clip -b bf16 bf16 c n n n p 480 128 3 480 3 480 bias,relu,clip -b f32 c n n n p 1024 512 515 1024 515 1024 none -b bf16 c n n n p 1024 512 515 1024 515 1024 none -b f32 bf16 c n n n p 1024 512 515 1024 515 1024 bias,relu,clip -b bf16 bf16 c n n n p 1024 512 515 1024 515 1024 bias,relu,clip -b f32 c n n n p 1024 2048 1024 1024 1024 1024 none -b bf16 c n n n p 1024 2048 1024 1024 1024 1024 none -b f32 bf16 c n n n p 1024 2048 1024 1024 1024 1024 bias,relu,clip -b bf16 bf16 c n n n p 1024 2048 1024 1024 1024 1024 bias,relu,clip -b f32 c n n n p 1024 2048 515 1024 515 1024 none -b bf16 c n n n p 1024 2048 515 1024 515 1024 none -b f32 bf16 c n n n p 1024 2048 515 1024 515 1024 bias,relu,clip -b bf16 bf16 c n n n p 1024 2048 515 1024 515 1024 bias,relu,clip -b f32 c p n n n 1024 1040 515 1024 515 1024 none -b bf16 c p n n n 1024 1040 515 1024 515 1024 none -b f32 bf16 c p n n n 1024 1040 515 1024 515 1024 bias,relu,clip -b bf16 bf16 c p n n n 1024 1040 515 1024 515 1024 bias,relu,clip -b f32 c p n n n 5 1029 515 5 515 5 none -b bf16 c p n n n 5 1029 515 5 515 5 none -b f32 bf16 c p n n n 5 1029 515 5 515 5 bias,relu,clip -b bf16 bf16 c p n n n 5 1029 515 5 515 5 bias,relu,clip -b f32 c p n n n 1024 1029 515 1024 515 1024 none -b bf16 c p n n n 1024 1029 515 1024 515 1024 none -b f32 bf16 c p n n n 1024 1029 515 1024 515 1024 bias,relu,clip -b bf16 bf16 c p n n n 1024 1029 515 1024 515 1024 bias,relu,clip -b f32 c p n n n 1024 1040 2050 1024 2050 1024 none -b bf16 c p n n n 1024 1040 2050 1024 2050 1024 none -b f32 bf16 c p n n n 1024 1040 2050 1024 2050 1024 bias,relu,clip -b bf16 bf16 c p n n n 1024 1040 2050 1024 2050 1024 bias,relu,clip -b f32 c p n n n 1029 1029 2050 1029 2050 1029 none -b bf16 c p n n n 1029 1029 2050 1029 2050 1029 none -b f32 bf16 c p n n n 1029 1029 2050 1029 2050 1029 bias,relu,clip -b bf16 bf16 c p n n n 1029 1029 2050 1029 2050 1029 bias,relu,clip -b f32 c p n n n 485 656 2050 485 2050 485 none -b bf16 c p n n n 485 656 2050 485 2050 485 none -b f32 bf16 c p n n n 485 656 2050 485 2050 485 bias,relu,clip -b bf16 bf16 c p n n n 485 656 2050 485 2050 485 bias,relu,clip -b f32 c p n n n 480 672 2050 480 2050 480 none -b bf16 c p n n n 480 672 2050 480 2050 480 none -b f32 bf16 c p n n n 480 672 2050 480 2050 480 bias,relu,clip -b bf16 bf16 c p n n n 480 672 2050 480 2050 480 bias,relu,clip -b f32 c p n n n 481 672 2050 481 2050 481 none -b bf16 c p n n n 481 672 2050 481 2050 481 none -b f32 bf16 c p n n n 481 672 2050 481 2050 481 bias,relu,clip -b bf16 bf16 c p n n n 481 672 2050 481 2050 481 bias,relu,clip -b f32 c p n n n 482 672 2050 482 2050 482 none -b bf16 c p n n n 482 672 2050 482 2050 482 none -b f32 bf16 c p n n n 482 672 2050 482 2050 482 bias,relu,clip -b bf16 bf16 c p n n n 482 672 2050 482 2050 482 bias,relu,clip -b f32 c p n n n 483 672 2050 483 2050 483 none -b bf16 c p n n n 483 672 2050 483 2050 483 none -b f32 bf16 c p n n n 483 672 2050 483 2050 483 bias,relu,clip -b bf16 bf16 c p n n n 483 672 2050 483 2050 483 bias,relu,clip -b f32 c p n n n 484 672 2050 484 2050 484 none -b bf16 c p n n n 484 672 2050 484 2050 484 none -b f32 bf16 c p n n n 484 672 2050 484 2050 484 bias,relu,clip -b bf16 bf16 c p n n n 484 672 2050 484 2050 484 bias,relu,clip -b f32 c p n n n 485 672 2050 485 2050 485 none -b bf16 c p n n n 485 672 2050 485 2050 485 none -b f32 bf16 c p n n n 485 672 2050 485 2050 485 bias,relu,clip -b bf16 bf16 c p n n n 485 672 2050 485 2050 485 bias,relu,clip -b f32 c p n n n 480 688 2050 480 2050 480 none -b bf16 c p n n n 480 688 2050 480 2050 480 none -b f32 bf16 c p n n n 480 688 2050 480 2050 480 bias,relu,clip -b bf16 bf16 c p n n n 480 688 2050 480 2050 480 bias,relu,clip -b f32 c p n n n 481 688 2050 481 2050 481 none -b bf16 c p n n n 481 688 2050 481 2050 481 none -b f32 bf16 c p n n n 481 688 2050 481 2050 481 bias,relu,clip -b bf16 bf16 c p n n n 481 688 2050 481 2050 481 bias,relu,clip -b f32 c p n n n 1024 32 256 1024 256 1024 none -b bf16 c p n n n 1024 32 256 1024 256 1024 none -b f32 bf16 c p n n n 1024 32 256 1024 256 1024 bias,relu,clip -b bf16 bf16 c p n n n 1024 32 256 1024 256 1024 bias,relu,clip -b f32 c P n n n 1024 64 512 1024 512 1024 none -b bf16 c P n n n 1024 64 512 1024 512 1024 none -b f32 bf16 c P n n n 1024 64 512 1024 512 1024 bias,relu,clip -b bf16 bf16 c P n n n 1024 64 512 1024 512 1024 bias,relu,clip -b f32 c P n n n 64 800 320 64 320 64 none -b bf16 c P n n n 64 800 320 64 320 64 none -b f32 bf16 c P n n n 64 800 320 64 320 64 bias,relu,clip -b bf16 bf16 c P n n n 64 800 320 64 320 64 bias,relu,clip -b f32 c P n n n 64 768 512 64 512 64 none -b bf16 c P n n n 64 768 512 64 512 64 none -b f32 bf16 c P n n n 64 768 512 64 512 64 bias,relu,clip -b bf16 bf16 c P n n n 64 768 512 64 512 64 bias,relu,clip -b f32 c P n n n 16 256 512 16 512 16 none -b bf16 c P n n n 16 256 512 16 512 16 none -b f32 bf16 c P n n n 16 256 512 16 512 16 bias,relu,clip -b bf16 bf16 c P n n n 16 256 512 16 512 16 bias,relu,clip -b f32 c P n n n 128 128 128 128 128 128 none -b bf16 c P n n n 128 128 128 128 128 128 none -b f32 bf16 c P n n n 128 128 128 128 128 128 bias,relu,clip -b bf16 bf16 c P n n n 128 128 128 128 128 128 bias,relu,clip -b f32 c P n n n 256 512 256 256 256 256 none -b bf16 c P n n n 256 512 256 256 256 256 none -b f32 bf16 c P n n n 256 512 256 256 256 256 bias,relu,clip -b bf16 bf16 c P n n n 256 512 256 256 256 256 bias,relu,clip -b f32 c P n n n 1024 1024 1024 1024 1024 1024 none -b bf16 c P n n n 1024 1024 1024 1024 1024 1024 none -b f32 bf16 c P n n n 1024 1024 1024 1024 1024 1024 bias,relu,clip -b bf16 bf16 c P n n n 1024 1024 1024 1024 1024 1024 bias,relu,clip -b f32 c P n n n 480 640 1024 480 1024 480 none -b bf16 c P n n n 480 640 1024 480 1024 480 none -b f32 bf16 c P n n n 480 640 1024 480 1024 480 bias,relu,clip -b bf16 bf16 c P n n n 480 640 1024 480 1024 480 bias,relu,clip -b f32 c P n n n 480 640 256 480 256 480 none -b bf16 c P n n n 480 640 256 480 256 480 none -b f32 bf16 c P n n n 480 640 256 480 256 480 bias,relu,clip -b bf16 bf16 c P n n n 480 640 256 480 256 480 bias,relu,clip -b f32 c P n n n 8 64 32 8 32 8 none -b bf16 c P n n n 8 64 32 8 32 8 none -b f32 bf16 c P n n n 8 64 32 8 32 8 bias,relu,clip -b bf16 bf16 c P n n n 8 64 32 8 32 8 bias,relu,clip -b f32 c P n n n 9 64 32 9 32 9 none -b bf16 c P n n n 9 64 32 9 32 9 none -b f32 bf16 c P n n n 9 64 32 9 32 9 bias,relu,clip -b bf16 bf16 c P n n n 9 64 32 9 32 9 bias,relu,clip -b f32 c P n n n 10 128 64 10 64 10 none -b bf16 c P n n n 10 128 64 10 64 10 none -b f32 bf16 c P n n n 10 128 64 10 64 10 bias,relu,clip -b bf16 bf16 c P n n n 10 128 64 10 64 10 bias,relu,clip -b f32 c P n n n 8 8 8 8 8 8 none -b bf16 c P n n n 8 8 8 8 8 8 none -b f32 bf16 c P n n n 8 8 8 8 8 8 bias,relu,clip -b bf16 bf16 c P n n n 8 8 8 8 8 8 bias,relu,clip -b f32 c P n n n 12 12 12 12 12 12 none -b bf16 c P n n n 12 12 12 12 12 12 none -b f32 bf16 c P n n n 12 12 12 12 12 12 bias,relu,clip -b bf16 bf16 c P n n n 12 12 12 12 12 12 bias,relu,clip -b f32 c P n n n 25 25 25 25 25 25 none -b bf16 c P n n n 25 25 25 25 25 25 none -b f32 bf16 c P n n n 25 25 25 25 25 25 bias,relu,clip -b bf16 bf16 c P n n n 25 25 25 25 25 25 bias,relu,clip -b f32 c P n n n 25 25 20 25 20 25 none -b bf16 c P n n n 25 25 20 25 20 25 none -b f32 bf16 c P n n n 25 25 20 25 20 25 bias,relu,clip -b bf16 bf16 c P n n n 25 25 20 25 20 25 bias,relu,clip -s s16 r n n n r 480 20 2050 2050 20 20 none -s s8 r n n n r 480 20 2050 2050 20 20 none -s u8 r n n n r 480 20 2050 2050 20 20 none -s s16 u8 r n n n r 480 20 2050 2050 20 20 bias,relu,clip -s s8 u8 r n n n r 480 20 2050 2050 20 20 bias,relu,clip -s u8 u8 r n n n r 480 20 2050 2050 20 20 bias,relu,clip -s s16 r n n n r 481 20 2050 2050 20 20 none -s s8 r n n n r 481 20 2050 2050 20 20 none -s u8 r n n n r 481 20 2050 2050 20 20 none -s s16 u8 r n n n r 481 20 2050 2050 20 20 bias,relu,clip -s s8 u8 r n n n r 481 20 2050 2050 20 20 bias,relu,clip -s u8 u8 r n n n r 481 20 2050 2050 20 20 bias,relu,clip -s s16 r n n n r 482 20 2050 2050 20 20 none -s s8 r n n n r 482 20 2050 2050 20 20 none -s u8 r n n n r 482 20 2050 2050 20 20 none -s s16 u8 r n n n r 482 20 2050 2050 20 20 bias,relu,clip -s s8 u8 r n n n r 482 20 2050 2050 20 20 bias,relu,clip -s u8 u8 r n n n r 482 20 2050 2050 20 20 bias,relu,clip -s s16 r n n n p 483 20 2050 2050 20 20 none -s s8 r n n n p 483 20 2050 2050 20 20 none -s u8 r n n n p 483 20 2050 2050 20 20 none -s s16 u8 r n n n p 483 20 2050 2050 20 20 bias,relu,clip -s s8 u8 r n n n p 483 20 2050 2050 20 20 bias,relu,clip -s u8 u8 r n n n p 483 20 2050 2050 20 20 bias,relu,clip -s s16 r n n n R 484 20 2050 2050 20 20 none -s s8 r n n n R 484 20 2050 2050 20 20 none -s u8 r n n n R 484 20 2050 2050 20 20 none -s s16 u8 r n n n R 484 20 2050 2050 20 20 bias,relu,clip -s s8 u8 r n n n R 484 20 2050 2050 20 20 bias,relu,clip -s u8 u8 r n n n R 484 20 2050 2050 20 20 bias,relu,clip -s s16 r n n n R 485 20 2050 2050 20 20 none -s s8 r n n n R 485 20 2050 2050 20 20 none -s u8 r n n n R 485 20 2050 2050 20 20 none -s s16 u8 r n n n R 485 20 2050 2050 20 20 bias,relu,clip -s s8 u8 r n n n R 485 20 2050 2050 20 20 bias,relu,clip -s u8 u8 r n n n R 485 20 2050 2050 20 20 bias,relu,clip -s s16 r n n n R 480 39 2050 2050 39 39 none -s s8 r n n n R 480 39 2050 2050 39 39 none -s u8 r n n n R 480 39 2050 2050 39 39 none -s s16 u8 r n n n R 480 39 2050 2050 39 39 bias,relu,clip -s s8 u8 r n n n R 480 39 2050 2050 39 39 bias,relu,clip -s u8 u8 r n n n R 480 39 2050 2050 39 39 bias,relu,clip -s s16 r n n n R 481 39 2050 2050 39 39 none -s s8 r n n n R 481 39 2050 2050 39 39 none -s u8 r n n n R 481 39 2050 2050 39 39 none -s s16 u8 r n n n R 481 39 2050 2050 39 39 bias,relu,clip -s s8 u8 r n n n R 481 39 2050 2050 39 39 bias,relu,clip -s u8 u8 r n n n R 481 39 2050 2050 39 39 bias,relu,clip -s s16 r n n n R 482 39 2050 2050 39 39 none -s s8 r n n n R 482 39 2050 2050 39 39 none -s u8 r n n n R 482 39 2050 2050 39 39 none -s s16 u8 r n n n R 482 39 2050 2050 39 39 bias,relu,clip -s s8 u8 r n n n R 482 39 2050 2050 39 39 bias,relu,clip -s u8 u8 r n n n R 482 39 2050 2050 39 39 bias,relu,clip -s s16 r n n n R 483 39 2050 2050 39 39 none -s s8 r n n n R 483 39 2050 2050 39 39 none -s u8 r n n n R 483 39 2050 2050 39 39 none -s s16 u8 r n n n R 483 39 2050 2050 39 39 bias,relu,clip -s s8 u8 r n n n R 483 39 2050 2050 39 39 bias,relu,clip -s u8 u8 r n n n R 483 39 2050 2050 39 39 bias,relu,clip -s s16 r n n n R 484 39 2050 2050 39 39 none -s s8 r n n n R 484 39 2050 2050 39 39 none -s u8 r n n n R 484 39 2050 2050 39 39 none -s s16 u8 r n n n R 484 39 2050 2050 39 39 bias,relu,clip -s s8 u8 r n n n R 484 39 2050 2050 39 39 bias,relu,clip -s u8 u8 r n n n R 484 39 2050 2050 39 39 bias,relu,clip -s s16 r n n n p 485 39 2050 2050 39 39 none -s s8 r n n n p 485 39 2050 2050 39 39 none -s u8 r n n n p 485 39 2050 2050 39 39 none -s s16 u8 r n n n p 485 39 2050 2050 39 39 bias,relu,clip -s s8 u8 r n n n p 485 39 2050 2050 39 39 bias,relu,clip -s u8 u8 r n n n p 485 39 2050 2050 39 39 bias,relu,clip -s s16 r n n n p 480 50 2050 2050 50 50 none -s s8 r n n n p 480 50 2050 2050 50 50 none -s u8 r n n n p 480 50 2050 2050 50 50 none -s s16 u8 r n n n p 480 50 2050 2050 50 50 bias,relu,clip -s s8 u8 r n n n p 480 50 2050 2050 50 50 bias,relu,clip -s u8 u8 r n n n p 480 50 2050 2050 50 50 bias,relu,clip -s s16 r n n n p 481 50 2050 2050 50 50 none -s s8 r n n n p 481 50 2050 2050 50 50 none -s u8 r n n n p 481 50 2050 2050 50 50 none -s s16 u8 r n n n p 481 50 2050 2050 50 50 bias,relu,clip -s s8 u8 r n n n p 481 50 2050 2050 50 50 bias,relu,clip -s u8 u8 r n n n p 481 50 2050 2050 50 50 bias,relu,clip -s s16 r n n n p 482 50 2050 2050 50 50 none -s s8 r n n n p 482 50 2050 2050 50 50 none -s u8 r n n n p 482 50 2050 2050 50 50 none -s s16 u8 r n n n p 482 50 2050 2050 50 50 bias,relu,clip -s s8 u8 r n n n p 482 50 2050 2050 50 50 bias,relu,clip -s u8 u8 r n n n p 482 50 2050 2050 50 50 bias,relu,clip -s s16 r n n n p 483 50 2050 2050 50 50 none -s s8 r n n n p 483 50 2050 2050 50 50 none -s u8 r n n n p 483 50 2050 2050 50 50 none -s s16 u8 r n n n p 483 50 2050 2050 50 50 bias,relu,clip -s s8 u8 r n n n p 483 50 2050 2050 50 50 bias,relu,clip -s u8 u8 r n n n p 483 50 2050 2050 50 50 bias,relu,clip -s s16 r n n n p 484 50 2050 2050 50 50 none -s s8 r n n n p 484 50 2050 2050 50 50 none -s u8 r n n n p 484 50 2050 2050 50 50 none -s s16 u8 r n n n p 484 50 2050 2050 50 50 bias,relu,clip -s s8 u8 r n n n p 484 50 2050 2050 50 50 bias,relu,clip -s u8 u8 r n n n p 484 50 2050 2050 50 50 bias,relu,clip -s s16 r n n n p 485 50 2050 2050 50 50 none -s s8 r n n n p 485 50 2050 2050 50 50 none -s u8 r n n n p 485 50 2050 2050 50 50 none -s s16 u8 r n n n p 485 50 2050 2050 50 50 bias,relu,clip -s s8 u8 r n n n p 485 50 2050 2050 50 50 bias,relu,clip -s u8 u8 r n n n p 485 50 2050 2050 50 50 bias,relu,clip -s s16 r n n n R 480 1108 2050 2050 1108 1108 none -s s8 r n n n R 480 1108 2050 2050 1108 1108 none -s u8 r n n n R 480 1108 2050 2050 1108 1108 none -s s16 u8 r n n n R 480 1108 2050 2050 1108 1108 bias,relu,clip -s s8 u8 r n n n R 480 1108 2050 2050 1108 1108 bias,relu,clip -s u8 u8 r n n n R 480 1108 2050 2050 1108 1108 bias,relu,clip -s s16 r n n n R 481 1108 2050 2050 1108 1108 none -s s8 r n n n R 481 1108 2050 2050 1108 1108 none -s u8 r n n n R 481 1108 2050 2050 1108 1108 none -s s16 u8 r n n n R 481 1108 2050 2050 1108 1108 bias,relu,clip -s s8 u8 r n n n R 481 1108 2050 2050 1108 1108 bias,relu,clip -s u8 u8 r n n n R 481 1108 2050 2050 1108 1108 bias,relu,clip -s s16 r n n n R 482 1108 2050 2050 1108 1108 none -s s8 r n n n R 482 1108 2050 2050 1108 1108 none -s u8 r n n n R 482 1108 2050 2050 1108 1108 none -s s16 u8 r n n n R 482 1108 2050 2050 1108 1108 bias,relu,clip -s s8 u8 r n n n R 482 1108 2050 2050 1108 1108 bias,relu,clip -s u8 u8 r n n n R 482 1108 2050 2050 1108 1108 bias,relu,clip -s s16 r n n n R 483 1108 2050 2050 1108 1108 none -s s8 r n n n R 483 1108 2050 2050 1108 1108 none -s u8 r n n n R 483 1108 2050 2050 1108 1108 none -s s16 u8 r n n n R 483 1108 2050 2050 1108 1108 bias,relu,clip -s s8 u8 r n n n R 483 1108 2050 2050 1108 1108 bias,relu,clip -s u8 u8 r n n n R 483 1108 2050 2050 1108 1108 bias,relu,clip -s s16 r n n n R 484 1108 2050 2050 1108 1108 none -s s8 r n n n R 484 1108 2050 2050 1108 1108 none -s u8 r n n n R 484 1108 2050 2050 1108 1108 none -s s16 u8 r n n n R 484 1108 2050 2050 1108 1108 bias,relu,clip -s s8 u8 r n n n R 484 1108 2050 2050 1108 1108 bias,relu,clip -s u8 u8 r n n n R 484 1108 2050 2050 1108 1108 bias,relu,clip -s s16 r n n n R 485 1108 2050 2050 1108 1108 none -s s8 r n n n R 485 1108 2050 2050 1108 1108 none -s u8 r n n n R 485 1108 2050 2050 1108 1108 none -s s16 u8 r n n n R 485 1108 2050 2050 1108 1108 bias,relu,clip -s s8 u8 r n n n R 485 1108 2050 2050 1108 1108 bias,relu,clip -s u8 u8 r n n n R 485 1108 2050 2050 1108 1108 bias,relu,clip -s s16 r n n n R 480 1127 2050 2050 1127 1127 none -s s8 r n n n R 480 1127 2050 2050 1127 1127 none -s u8 r n n n R 480 1127 2050 2050 1127 1127 none -s s16 u8 r n n n R 480 1127 2050 2050 1127 1127 bias,relu,clip -s s8 u8 r n n n R 480 1127 2050 2050 1127 1127 bias,relu,clip -s u8 u8 r n n n R 480 1127 2050 2050 1127 1127 bias,relu,clip -s s16 r n n n R 481 1127 2050 2050 1127 1127 none -s s8 r n n n R 481 1127 2050 2050 1127 1127 none -s u8 r n n n R 481 1127 2050 2050 1127 1127 none -s s16 u8 r n n n R 481 1127 2050 2050 1127 1127 bias,relu,clip -s s8 u8 r n n n R 481 1127 2050 2050 1127 1127 bias,relu,clip -s u8 u8 r n n n R 481 1127 2050 2050 1127 1127 bias,relu,clip -s s16 r n n n R 482 1127 2050 2050 1127 1127 none -s s8 r n n n R 482 1127 2050 2050 1127 1127 none -s u8 r n n n R 482 1127 2050 2050 1127 1127 none -s s16 u8 r n n n R 482 1127 2050 2050 1127 1127 bias,relu,clip -s s8 u8 r n n n R 482 1127 2050 2050 1127 1127 bias,relu,clip -s u8 u8 r n n n R 482 1127 2050 2050 1127 1127 bias,relu,clip -s s16 r n n n R 483 1127 2050 2050 1127 1127 none -s s8 r n n n R 483 1127 2050 2050 1127 1127 none -s u8 r n n n R 483 1127 2050 2050 1127 1127 none -s s16 u8 r n n n R 483 1127 2050 2050 1127 1127 bias,relu,clip -s s8 u8 r n n n R 483 1127 2050 2050 1127 1127 bias,relu,clip -s u8 u8 r n n n R 483 1127 2050 2050 1127 1127 bias,relu,clip -s s16 r n n n p 484 1127 2050 2050 1127 1127 none -s s8 r n n n p 484 1127 2050 2050 1127 1127 none -s u8 r n n n p 484 1127 2050 2050 1127 1127 none -s s16 u8 r n n n p 484 1127 2050 2050 1127 1127 bias,relu,clip -s s8 u8 r n n n p 484 1127 2050 2050 1127 1127 bias,relu,clip -s u8 u8 r n n n p 484 1127 2050 2050 1127 1127 bias,relu,clip -s s16 r n n n p 485 1127 2050 2050 1127 1127 none -s s8 r n n n p 485 1127 2050 2050 1127 1127 none -s u8 r n n n p 485 1127 2050 2050 1127 1127 none -s s16 u8 r n n n p 485 1127 2050 2050 1127 1127 bias,relu,clip -s s8 u8 r n n n p 485 1127 2050 2050 1127 1127 bias,relu,clip -s u8 u8 r n n n p 485 1127 2050 2050 1127 1127 bias,relu,clip -s s16 r n n n p 480 1138 2050 2050 1138 1138 none -s s8 r n n n p 480 1138 2050 2050 1138 1138 none -s u8 r n n n p 480 1138 2050 2050 1138 1138 none -s s16 u8 r n n n p 480 1138 2050 2050 1138 1138 bias,relu,clip -s s8 u8 r n n n p 480 1138 2050 2050 1138 1138 bias,relu,clip -s u8 u8 r n n n p 480 1138 2050 2050 1138 1138 bias,relu,clip -s s16 r n n n p 481 1138 2050 2050 1138 1138 none -s s8 r n n n p 481 1138 2050 2050 1138 1138 none -s u8 r n n n p 481 1138 2050 2050 1138 1138 none -s s16 u8 r n n n p 481 1138 2050 2050 1138 1138 bias,relu,clip -s s8 u8 r n n n p 481 1138 2050 2050 1138 1138 bias,relu,clip -s u8 u8 r n n n p 481 1138 2050 2050 1138 1138 bias,relu,clip -s s16 r n n n p 482 1138 2050 2050 1138 1138 none -s s8 r n n n p 482 1138 2050 2050 1138 1138 none -s u8 r n n n p 482 1138 2050 2050 1138 1138 none -s s16 u8 r n n n p 482 1138 2050 2050 1138 1138 bias,relu,clip -s s8 u8 r n n n p 482 1138 2050 2050 1138 1138 bias,relu,clip -s u8 u8 r n n n p 482 1138 2050 2050 1138 1138 bias,relu,clip -s s16 r n n n p 483 1138 2050 2050 1138 1138 none -s s8 r n n n p 483 1138 2050 2050 1138 1138 none -s u8 r n n n p 483 1138 2050 2050 1138 1138 none -s s16 u8 r n n n p 483 1138 2050 2050 1138 1138 bias,relu,clip -s s8 u8 r n n n p 483 1138 2050 2050 1138 1138 bias,relu,clip -s u8 u8 r n n n p 483 1138 2050 2050 1138 1138 bias,relu,clip -s s16 r n n n p 484 1138 2050 2050 1138 1138 none -s s8 r n n n p 484 1138 2050 2050 1138 1138 none -s u8 r n n n p 484 1138 2050 2050 1138 1138 none -s s16 u8 r n n n p 484 1138 2050 2050 1138 1138 bias,relu,clip -s s8 u8 r n n n p 484 1138 2050 2050 1138 1138 bias,relu,clip -s u8 u8 r n n n p 484 1138 2050 2050 1138 1138 bias,relu,clip -s s16 r n n n p 485 1138 2050 2050 1138 1138 none -s s8 r n n n p 485 1138 2050 2050 1138 1138 none -s u8 r n n n p 485 1138 2050 2050 1138 1138 none -s s16 u8 r n n n p 485 1138 2050 2050 1138 1138 bias,relu,clip -s s8 u8 r n n n p 485 1138 2050 2050 1138 1138 bias,relu,clip -s u8 u8 r n n n p 485 1138 2050 2050 1138 1138 bias,relu,clip -s s16 r n n n p 1 1 3 3 1 1 none -s s8 r n n n p 1 1 3 3 1 1 none -s u8 r n n n p 1 1 3 3 1 1 none -s s16 u8 r n n n p 1 1 3 3 1 1 bias,relu,clip -s s8 u8 r n n n p 1 1 3 3 1 1 bias,relu,clip -s u8 u8 r n n n p 1 1 3 3 1 1 bias,relu,clip -s s16 r n n n p 1 9 3 3 9 9 none -s s8 r n n n p 1 9 3 3 9 9 none -s u8 r n n n p 1 9 3 3 9 9 none -s s16 u8 r n n n p 1 9 3 3 9 9 bias,relu,clip -s s8 u8 r n n n p 1 9 3 3 9 9 bias,relu,clip -s u8 u8 r n n n p 1 9 3 3 9 9 bias,relu,clip -s s16 r n n n p 1 2048 3 3 2048 2048 none -s s8 r n n n p 1 2048 3 3 2048 2048 none -s u8 r n n n p 1 2048 3 3 2048 2048 none -s s16 u8 r n n n p 1 2048 3 3 2048 2048 bias,relu,clip -s s8 u8 r n n n p 1 2048 3 3 2048 2048 bias,relu,clip -s u8 u8 r n n n p 1 2048 3 3 2048 2048 bias,relu,clip -s s16 r n n n p 1 2048 5192 5192 2048 2048 none -s s8 r n n n p 1 2048 5192 5192 2048 2048 none -s u8 r n n n p 1 2048 5192 5192 2048 2048 none -s s16 u8 r n n n p 1 2048 5192 5192 2048 2048 bias,relu,clip -s s8 u8 r n n n p 1 2048 5192 5192 2048 2048 bias,relu,clip -s u8 u8 r n n n p 1 2048 5192 5192 2048 2048 bias,relu,clip -s s16 r n n n p 9 1 3 3 1 1 none -s s8 r n n n p 9 1 3 3 1 1 none -s u8 r n n n p 9 1 3 3 1 1 none -s s16 u8 r n n n p 9 1 3 3 1 1 bias,relu,clip -s s8 u8 r n n n p 9 1 3 3 1 1 bias,relu,clip -s u8 u8 r n n n p 9 1 3 3 1 1 bias,relu,clip -s s16 r n n n p 576 1 3500 3500 1 1 none -s s8 r n n n p 576 1 3500 3500 1 1 none -s u8 r n n n p 576 1 3500 3500 1 1 none -s s16 u8 r n n n p 576 1 3500 3500 1 1 bias,relu,clip -s s8 u8 r n n n p 576 1 3500 3500 1 1 bias,relu,clip -s u8 u8 r n n n p 576 1 3500 3500 1 1 bias,relu,clip -s s16 r n n n p 1 1 1 1 1 1 none -s s8 r n n n p 1 1 1 1 1 1 none -s u8 r n n n p 1 1 1 1 1 1 none -s s16 u8 r n n n p 1 1 1 1 1 1 bias,relu,clip -s s8 u8 r n n n p 1 1 1 1 1 1 bias,relu,clip -s u8 u8 r n n n p 1 1 1 1 1 1 bias,relu,clip -s s16 r n n n p 102 1088 1024 1024 1088 1088 none -s s8 r n n n p 102 1088 1024 1024 1088 1088 none -s u8 r n n n p 102 1088 1024 1024 1088 1088 none -s s16 u8 r n n n p 102 1088 1024 1024 1088 1088 bias,relu,clip -s s8 u8 r n n n p 102 1088 1024 1024 1088 1088 bias,relu,clip -s u8 u8 r n n n p 102 1088 1024 1024 1088 1088 bias,relu,clip -s s16 r n n n p 102 2048 1024 1024 2048 2048 none -s s8 r n n n p 102 2048 1024 1024 2048 2048 none -s u8 r n n n p 102 2048 1024 1024 2048 2048 none -s s16 u8 r n n n p 102 2048 1024 1024 2048 2048 bias,relu,clip -s s8 u8 r n n n p 102 2048 1024 1024 2048 2048 bias,relu,clip -s u8 u8 r n n n p 102 2048 1024 1024 2048 2048 bias,relu,clip -s s16 r n n n p 485 656 1024 1024 656 656 none -s s8 r n n n p 485 656 1024 1024 656 656 none -s u8 r n n n p 485 656 1024 1024 656 656 none -s s16 u8 r n n n p 485 656 1024 1024 656 656 bias,relu,clip -s s8 u8 r n n n p 485 656 1024 1024 656 656 bias,relu,clip -s u8 u8 r n n n p 485 656 1024 1024 656 656 bias,relu,clip -s s16 r n n n p 483 656 1024 1024 656 656 none -s s8 r n n n p 483 656 1024 1024 656 656 none -s u8 r n n n p 483 656 1024 1024 656 656 none -s s16 u8 r n n n p 483 656 1024 1024 656 656 bias,relu,clip -s s8 u8 r n n n p 483 656 1024 1024 656 656 bias,relu,clip -s u8 u8 r n n n p 483 656 1024 1024 656 656 bias,relu,clip -s s16 r n n n p 81 128 3 3 128 128 none -s s8 r n n n p 81 128 3 3 128 128 none -s u8 r n n n p 81 128 3 3 128 128 none -s s16 u8 r n n n p 81 128 3 3 128 128 bias,relu,clip -s s8 u8 r n n n p 81 128 3 3 128 128 bias,relu,clip -s u8 u8 r n n n p 81 128 3 3 128 128 bias,relu,clip -s s16 r n n n p 1022 512 515 515 512 512 none -s s8 r n n n p 1022 512 515 515 512 512 none -s u8 r n n n p 1022 512 515 515 512 512 none -s s16 u8 r n n n p 1022 512 515 515 512 512 bias,relu,clip -s s8 u8 r n n n p 1022 512 515 515 512 512 bias,relu,clip -s u8 u8 r n n n p 1022 512 515 515 512 512 bias,relu,clip -s s16 r n n n p 74 512 515 515 512 512 none -s s8 r n n n p 74 512 515 515 512 512 none -s u8 r n n n p 74 512 515 515 512 512 none -s s16 u8 r n n n p 74 512 515 515 512 512 bias,relu,clip -s s8 u8 r n n n p 74 512 515 515 512 512 bias,relu,clip -s u8 u8 r n n n p 74 512 515 515 512 512 bias,relu,clip -s s16 r n n n p 253 2048 515 515 2048 2048 none -s s8 r n n n p 253 2048 515 515 2048 2048 none -s u8 r n n n p 253 2048 515 515 2048 2048 none -s s16 u8 r n n n p 253 2048 515 515 2048 2048 bias,relu,clip -s s8 u8 r n n n p 253 2048 515 515 2048 2048 bias,relu,clip -s u8 u8 r n n n p 253 2048 515 515 2048 2048 bias,relu,clip -s s16 r n n n p 8192 1040 515 515 1040 1040 none -s s8 r n n n p 8192 1040 515 515 1040 1040 none -s u8 r n n n p 8192 1040 515 515 1040 1040 none -s s16 u8 r n n n p 8192 1040 515 515 1040 1040 bias,relu,clip -s s8 u8 r n n n p 8192 1040 515 515 1040 1040 bias,relu,clip -s u8 u8 r n n n p 8192 1040 515 515 1040 1040 bias,relu,clip -s s16 r n n n p 10 1029 515 515 1029 1029 none -s s8 r n n n p 10 1029 515 515 1029 1029 none -s u8 r n n n p 10 1029 515 515 1029 1029 none -s s16 u8 r n n n p 10 1029 515 515 1029 1029 bias,relu,clip -s s8 u8 r n n n p 10 1029 515 515 1029 1029 bias,relu,clip -s u8 u8 r n n n p 10 1029 515 515 1029 1029 bias,relu,clip -s s16 r n n n p 24 1040 2050 2050 1040 1040 none -s s8 r n n n p 24 1040 2050 2050 1040 1040 none -s u8 r n n n p 24 1040 2050 2050 1040 1040 none -s s16 u8 r n n n p 24 1040 2050 2050 1040 1040 bias,relu,clip -s s8 u8 r n n n p 24 1040 2050 2050 1040 1040 bias,relu,clip -s u8 u8 r n n n p 24 1040 2050 2050 1040 1040 bias,relu,clip -s s16 r n n n p 1024 1029 2050 2050 1029 1029 none -s s8 r n n n p 1024 1029 2050 2050 1029 1029 none -s u8 r n n n p 1024 1029 2050 2050 1029 1029 none -s s16 u8 r n n n p 1024 1029 2050 2050 1029 1029 bias,relu,clip -s s8 u8 r n n n p 1024 1029 2050 2050 1029 1029 bias,relu,clip -s u8 u8 r n n n p 1024 1029 2050 2050 1029 1029 bias,relu,clip -s s16 r n n n p 480 660 2050 2050 660 660 none -s s8 r n n n p 480 660 2050 2050 660 660 none -s u8 r n n n p 480 660 2050 2050 660 660 none -s s16 u8 r n n n p 480 660 2050 2050 660 660 bias,relu,clip -s s8 u8 r n n n p 480 660 2050 2050 660 660 bias,relu,clip -s u8 u8 r n n n p 480 660 2050 2050 660 660 bias,relu,clip -s s16 r n n n p 481 660 2050 2050 660 660 none -s s8 r n n n p 481 660 2050 2050 660 660 none -s u8 r n n n p 481 660 2050 2050 660 660 none -s s16 u8 r n n n p 481 660 2050 2050 660 660 bias,relu,clip -s s8 u8 r n n n p 481 660 2050 2050 660 660 bias,relu,clip -s u8 u8 r n n n p 481 660 2050 2050 660 660 bias,relu,clip -s s16 r n n n p 482 660 2050 2050 660 660 none -s s8 r n n n p 482 660 2050 2050 660 660 none -s u8 r n n n p 482 660 2050 2050 660 660 none -s s16 u8 r n n n p 482 660 2050 2050 660 660 bias,relu,clip -s s8 u8 r n n n p 482 660 2050 2050 660 660 bias,relu,clip -s u8 u8 r n n n p 482 660 2050 2050 660 660 bias,relu,clip -s s16 r n n n p 483 660 2050 2050 660 660 none -s s8 r n n n p 483 660 2050 2050 660 660 none -s u8 r n n n p 483 660 2050 2050 660 660 none -s s16 u8 r n n n p 483 660 2050 2050 660 660 bias,relu,clip -s s8 u8 r n n n p 483 660 2050 2050 660 660 bias,relu,clip -s u8 u8 r n n n p 483 660 2050 2050 660 660 bias,relu,clip -s s16 r n n n p 484 660 2050 2050 660 660 none -s s8 r n n n p 484 660 2050 2050 660 660 none -s u8 r n n n p 484 660 2050 2050 660 660 none -s s16 u8 r n n n p 484 660 2050 2050 660 660 bias,relu,clip -s s8 u8 r n n n p 484 660 2050 2050 660 660 bias,relu,clip -s u8 u8 r n n n p 484 660 2050 2050 660 660 bias,relu,clip -s s16 r n n n p 485 660 2050 2050 660 660 none -s s8 r n n n p 485 660 2050 2050 660 660 none -s u8 r n n n p 485 660 2050 2050 660 660 none -s s16 u8 r n n n p 485 660 2050 2050 660 660 bias,relu,clip -s s8 u8 r n n n p 485 660 2050 2050 660 660 bias,relu,clip -s u8 u8 r n n n p 485 660 2050 2050 660 660 bias,relu,clip -s s16 r n n n p 480 679 2050 2050 679 679 none -s s8 r n n n p 480 679 2050 2050 679 679 none -s u8 r n n n p 480 679 2050 2050 679 679 none -s s16 u8 r n n n p 480 679 2050 2050 679 679 bias,relu,clip -s s8 u8 r n n n p 480 679 2050 2050 679 679 bias,relu,clip -s u8 u8 r n n n p 480 679 2050 2050 679 679 bias,relu,clip -s s16 r n n n p 481 679 2050 2050 679 679 none -s s8 r n n n p 481 679 2050 2050 679 679 none -s u8 r n n n p 481 679 2050 2050 679 679 none -s s16 u8 r n n n p 481 679 2050 2050 679 679 bias,relu,clip -s s8 u8 r n n n p 481 679 2050 2050 679 679 bias,relu,clip -s u8 u8 r n n n p 481 679 2050 2050 679 679 bias,relu,clip -s s16 r n n n p 482 679 2050 2050 679 679 none -s s8 r n n n p 482 679 2050 2050 679 679 none -s u8 r n n n p 482 679 2050 2050 679 679 none -s s16 u8 r n n n p 482 679 2050 2050 679 679 bias,relu,clip -s s8 u8 r n n n p 482 679 2050 2050 679 679 bias,relu,clip -s u8 u8 r n n n p 482 679 2050 2050 679 679 bias,relu,clip -s s16 r n n n p 483 679 2050 2050 679 679 none -s s8 r n n n p 483 679 2050 2050 679 679 none -s u8 r n n n p 483 679 2050 2050 679 679 none -s s16 u8 r n n n p 483 679 2050 2050 679 679 bias,relu,clip -s s8 u8 r n n n p 483 679 2050 2050 679 679 bias,relu,clip -s u8 u8 r n n n p 483 679 2050 2050 679 679 bias,relu,clip -s s16 r n n n p 484 679 2050 2050 679 679 none -s s8 r n n n p 484 679 2050 2050 679 679 none -s u8 r n n n p 484 679 2050 2050 679 679 none -s s16 u8 r n n n p 484 679 2050 2050 679 679 bias,relu,clip -s s8 u8 r n n n p 484 679 2050 2050 679 679 bias,relu,clip -s u8 u8 r n n n p 484 679 2050 2050 679 679 bias,relu,clip -s s16 r n n n p 485 679 2050 2050 679 679 none -s s8 r n n n p 485 679 2050 2050 679 679 none -s u8 r n n n p 485 679 2050 2050 679 679 none -s s16 u8 r n n n p 485 679 2050 2050 679 679 bias,relu,clip -s s8 u8 r n n n p 485 679 2050 2050 679 679 bias,relu,clip -s u8 u8 r n n n p 485 679 2050 2050 679 679 bias,relu,clip -s s16 r n n n p 480 690 2050 2050 690 690 none -s s8 r n n n p 480 690 2050 2050 690 690 none -s u8 r n n n p 480 690 2050 2050 690 690 none -s s16 u8 r n n n p 480 690 2050 2050 690 690 bias,relu,clip -s s8 u8 r n n n p 480 690 2050 2050 690 690 bias,relu,clip -s u8 u8 r n n n p 480 690 2050 2050 690 690 bias,relu,clip -s s16 r n n n p 481 690 2050 2050 690 690 none -s s8 r n n n p 481 690 2050 2050 690 690 none -s u8 r n n n p 481 690 2050 2050 690 690 none -s s16 u8 r n n n p 481 690 2050 2050 690 690 bias,relu,clip -s s8 u8 r n n n p 481 690 2050 2050 690 690 bias,relu,clip -s u8 u8 r n n n p 481 690 2050 2050 690 690 bias,relu,clip -s s16 r n n n p 482 690 2050 2050 690 690 none -s s8 r n n n p 482 690 2050 2050 690 690 none -s u8 r n n n p 482 690 2050 2050 690 690 none -s s16 u8 r n n n p 482 690 2050 2050 690 690 bias,relu,clip -s s8 u8 r n n n p 482 690 2050 2050 690 690 bias,relu,clip -s u8 u8 r n n n p 482 690 2050 2050 690 690 bias,relu,clip -s s16 r n n n p 483 690 2050 2050 690 690 none -s s8 r n n n p 483 690 2050 2050 690 690 none -s u8 r n n n p 483 690 2050 2050 690 690 none -s s16 u8 r n n n p 483 690 2050 2050 690 690 bias,relu,clip -s s8 u8 r n n n p 483 690 2050 2050 690 690 bias,relu,clip -s u8 u8 r n n n p 483 690 2050 2050 690 690 bias,relu,clip -s s16 r n n n p 484 690 2050 2050 690 690 none -s s8 r n n n p 484 690 2050 2050 690 690 none -s u8 r n n n p 484 690 2050 2050 690 690 none -s s16 u8 r n n n p 484 690 2050 2050 690 690 bias,relu,clip -s s8 u8 r n n n p 484 690 2050 2050 690 690 bias,relu,clip -s u8 u8 r n n n p 484 690 2050 2050 690 690 bias,relu,clip -s s16 r n n n p 485 690 2050 2050 690 690 none -s s8 r n n n p 485 690 2050 2050 690 690 none -s u8 r n n n p 485 690 2050 2050 690 690 none -s s16 u8 r n n n p 485 690 2050 2050 690 690 bias,relu,clip -s s8 u8 r n n n p 485 690 2050 2050 690 690 bias,relu,clip -s u8 u8 r n n n p 485 690 2050 2050 690 690 bias,relu,clip -s s16 r n n n p 480 660 2048 2048 660 660 none -s s8 r n n n p 480 660 2048 2048 660 660 none -s u8 r n n n p 480 660 2048 2048 660 660 none -s s16 u8 r n n n p 480 660 2048 2048 660 660 bias,relu,clip -s s8 u8 r n n n p 480 660 2048 2048 660 660 bias,relu,clip -s u8 u8 r n n n p 480 660 2048 2048 660 660 bias,relu,clip -s s16 r n n n p 481 660 2048 2048 660 660 none -s s8 r n n n p 481 660 2048 2048 660 660 none -s u8 r n n n p 481 660 2048 2048 660 660 none -s s16 u8 r n n n p 481 660 2048 2048 660 660 bias,relu,clip -s s8 u8 r n n n p 481 660 2048 2048 660 660 bias,relu,clip -s u8 u8 r n n n p 481 660 2048 2048 660 660 bias,relu,clip -s s16 r n n n p 482 660 2048 2048 660 660 none -s s8 r n n n p 482 660 2048 2048 660 660 none -s u8 r n n n p 482 660 2048 2048 660 660 none -s s16 u8 r n n n p 482 660 2048 2048 660 660 bias,relu,clip -s s8 u8 r n n n p 482 660 2048 2048 660 660 bias,relu,clip -s u8 u8 r n n n p 482 660 2048 2048 660 660 bias,relu,clip -s s16 r n n n p 483 660 2048 2048 660 660 none -s s8 r n n n p 483 660 2048 2048 660 660 none -s u8 r n n n p 483 660 2048 2048 660 660 none -s s16 u8 r n n n p 483 660 2048 2048 660 660 bias,relu,clip -s s8 u8 r n n n p 483 660 2048 2048 660 660 bias,relu,clip -s u8 u8 r n n n p 483 660 2048 2048 660 660 bias,relu,clip -s s16 r n n n p 484 660 2048 2048 660 660 none -s s8 r n n n p 484 660 2048 2048 660 660 none -s u8 r n n n p 484 660 2048 2048 660 660 none -s s16 u8 r n n n p 484 660 2048 2048 660 660 bias,relu,clip -s s8 u8 r n n n p 484 660 2048 2048 660 660 bias,relu,clip -s u8 u8 r n n n p 484 660 2048 2048 660 660 bias,relu,clip -s s16 r n n n p 485 660 2048 2048 660 660 none -s s8 r n n n p 485 660 2048 2048 660 660 none -s u8 r n n n p 485 660 2048 2048 660 660 none -s s16 u8 r n n n p 485 660 2048 2048 660 660 bias,relu,clip -s s8 u8 r n n n p 485 660 2048 2048 660 660 bias,relu,clip -s u8 u8 r n n n p 485 660 2048 2048 660 660 bias,relu,clip -s s16 r n n n p 480 679 2048 2048 679 679 none -s s8 r n n n p 480 679 2048 2048 679 679 none -s u8 r n n n p 480 679 2048 2048 679 679 none -s s16 u8 r n n n p 480 679 2048 2048 679 679 bias,relu,clip -s s8 u8 r n n n p 480 679 2048 2048 679 679 bias,relu,clip -s u8 u8 r n n n p 480 679 2048 2048 679 679 bias,relu,clip -s s16 r n n n p 481 679 2048 2048 679 679 none -s s8 r n n n p 481 679 2048 2048 679 679 none -s u8 r n n n p 481 679 2048 2048 679 679 none -s s16 u8 r n n n p 481 679 2048 2048 679 679 bias,relu,clip -s s8 u8 r n n n p 481 679 2048 2048 679 679 bias,relu,clip -s u8 u8 r n n n p 481 679 2048 2048 679 679 bias,relu,clip -s s16 r n n n p 482 679 2048 2048 679 679 none -s s8 r n n n p 482 679 2048 2048 679 679 none -s u8 r n n n p 482 679 2048 2048 679 679 none -s s16 u8 r n n n p 482 679 2048 2048 679 679 bias,relu,clip -s s8 u8 r n n n p 482 679 2048 2048 679 679 bias,relu,clip -s u8 u8 r n n n p 482 679 2048 2048 679 679 bias,relu,clip -s s16 r n n n p 483 679 2048 2048 679 679 none -s s8 r n n n p 483 679 2048 2048 679 679 none -s u8 r n n n p 483 679 2048 2048 679 679 none -s s16 u8 r n n n p 483 679 2048 2048 679 679 bias,relu,clip -s s8 u8 r n n n p 483 679 2048 2048 679 679 bias,relu,clip -s u8 u8 r n n n p 483 679 2048 2048 679 679 bias,relu,clip -s s16 r n n n p 484 679 2048 2048 679 679 none -s s8 r n n n p 484 679 2048 2048 679 679 none -s u8 r n n n p 484 679 2048 2048 679 679 none -s s16 u8 r n n n p 484 679 2048 2048 679 679 bias,relu,clip -s s8 u8 r n n n p 484 679 2048 2048 679 679 bias,relu,clip -s u8 u8 r n n n p 484 679 2048 2048 679 679 bias,relu,clip -s s16 r n n n p 485 679 2048 2048 679 679 none -s s8 r n n n p 485 679 2048 2048 679 679 none -s u8 r n n n p 485 679 2048 2048 679 679 none -s s16 u8 r n n n p 485 679 2048 2048 679 679 bias,relu,clip -s s8 u8 r n n n p 485 679 2048 2048 679 679 bias,relu,clip -s u8 u8 r n n n p 485 679 2048 2048 679 679 bias,relu,clip -s s16 r n n n p 480 690 2048 2048 690 690 none -s s8 r n n n p 480 690 2048 2048 690 690 none -s u8 r n n n p 480 690 2048 2048 690 690 none -s s16 u8 r n n n p 480 690 2048 2048 690 690 bias,relu,clip -s s8 u8 r n n n p 480 690 2048 2048 690 690 bias,relu,clip -s u8 u8 r n n n p 480 690 2048 2048 690 690 bias,relu,clip -s s16 r n n n p 481 690 2048 2048 690 690 none -s s8 r n n n p 481 690 2048 2048 690 690 none -s u8 r n n n p 481 690 2048 2048 690 690 none -s s16 u8 r n n n p 481 690 2048 2048 690 690 bias,relu,clip -s s8 u8 r n n n p 481 690 2048 2048 690 690 bias,relu,clip -s u8 u8 r n n n p 481 690 2048 2048 690 690 bias,relu,clip -s s16 r n n n p 482 690 2048 2048 690 690 none -s s8 r n n n p 482 690 2048 2048 690 690 none -s u8 r n n n p 482 690 2048 2048 690 690 none -s s16 u8 r n n n p 482 690 2048 2048 690 690 bias,relu,clip -s s8 u8 r n n n p 482 690 2048 2048 690 690 bias,relu,clip -s u8 u8 r n n n p 482 690 2048 2048 690 690 bias,relu,clip -s s16 r n n n p 483 690 2048 2048 690 690 none -s s8 r n n n p 483 690 2048 2048 690 690 none -s u8 r n n n p 483 690 2048 2048 690 690 none -s s16 u8 r n n n p 483 690 2048 2048 690 690 bias,relu,clip -s s8 u8 r n n n p 483 690 2048 2048 690 690 bias,relu,clip -s u8 u8 r n n n p 483 690 2048 2048 690 690 bias,relu,clip -s s16 r n n n p 484 690 2048 2048 690 690 none -s s8 r n n n p 484 690 2048 2048 690 690 none -s u8 r n n n p 484 690 2048 2048 690 690 none -s s16 u8 r n n n p 484 690 2048 2048 690 690 bias,relu,clip -s s8 u8 r n n n p 484 690 2048 2048 690 690 bias,relu,clip -s u8 u8 r n n n p 484 690 2048 2048 690 690 bias,relu,clip -s s16 r n n n p 485 690 2048 2048 690 690 none -s s8 r n n n p 485 690 2048 2048 690 690 none -s u8 r n n n p 485 690 2048 2048 690 690 none -s s16 u8 r n n n p 485 690 2048 2048 690 690 bias,relu,clip -s s8 u8 r n n n p 485 690 2048 2048 690 690 bias,relu,clip -s u8 u8 r n n n p 485 690 2048 2048 690 690 bias,relu,clip -s s16 r n n n p 480 656 1024 1024 656 656 none -s s8 r n n n p 480 656 1024 1024 656 656 none -s u8 r n n n p 480 656 1024 1024 656 656 none -s s16 u8 r n n n p 480 656 1024 1024 656 656 bias,relu,clip -s s8 u8 r n n n p 480 656 1024 1024 656 656 bias,relu,clip -s u8 u8 r n n n p 480 656 1024 1024 656 656 bias,relu,clip -s s16 r n n n p 480 128 3 3 128 128 none -s s8 r n n n p 480 128 3 3 128 128 none -s u8 r n n n p 480 128 3 3 128 128 none -s s16 u8 r n n n p 480 128 3 3 128 128 bias,relu,clip -s s8 u8 r n n n p 480 128 3 3 128 128 bias,relu,clip -s u8 u8 r n n n p 480 128 3 3 128 128 bias,relu,clip -s s16 r n n n p 1024 512 515 515 512 512 none -s s8 r n n n p 1024 512 515 515 512 512 none -s u8 r n n n p 1024 512 515 515 512 512 none -s s16 u8 r n n n p 1024 512 515 515 512 512 bias,relu,clip -s s8 u8 r n n n p 1024 512 515 515 512 512 bias,relu,clip -s u8 u8 r n n n p 1024 512 515 515 512 512 bias,relu,clip -s s16 r n n n p 1024 2048 1024 1024 2048 2048 none -s s8 r n n n p 1024 2048 1024 1024 2048 2048 none -s u8 r n n n p 1024 2048 1024 1024 2048 2048 none -s s16 u8 r n n n p 1024 2048 1024 1024 2048 2048 bias,relu,clip -s s8 u8 r n n n p 1024 2048 1024 1024 2048 2048 bias,relu,clip -s u8 u8 r n n n p 1024 2048 1024 1024 2048 2048 bias,relu,clip -s s16 r n n n p 1024 2048 515 515 2048 2048 none -s s8 r n n n p 1024 2048 515 515 2048 2048 none -s u8 r n n n p 1024 2048 515 515 2048 2048 none -s s16 u8 r n n n p 1024 2048 515 515 2048 2048 bias,relu,clip -s s8 u8 r n n n p 1024 2048 515 515 2048 2048 bias,relu,clip -s u8 u8 r n n n p 1024 2048 515 515 2048 2048 bias,relu,clip -s s16 r n n n p 1024 1040 515 515 1040 1040 none -s s8 r n n n p 1024 1040 515 515 1040 1040 none -s u8 r n n n p 1024 1040 515 515 1040 1040 none -s s16 u8 r n n n p 1024 1040 515 515 1040 1040 bias,relu,clip -s s8 u8 r n n n p 1024 1040 515 515 1040 1040 bias,relu,clip -s u8 u8 r n n n p 1024 1040 515 515 1040 1040 bias,relu,clip -s s16 r n n n p 5 1029 515 515 1029 1029 none -s s8 r n n n p 5 1029 515 515 1029 1029 none -s u8 r n n n p 5 1029 515 515 1029 1029 none -s s16 u8 r n n n p 5 1029 515 515 1029 1029 bias,relu,clip -s s8 u8 r n n n p 5 1029 515 515 1029 1029 bias,relu,clip -s u8 u8 r n n n p 5 1029 515 515 1029 1029 bias,relu,clip -s s16 r n n n p 1024 1029 515 515 1029 1029 none -s s8 r n n n p 1024 1029 515 515 1029 1029 none -s u8 r n n n p 1024 1029 515 515 1029 1029 none -s s16 u8 r n n n p 1024 1029 515 515 1029 1029 bias,relu,clip -s s8 u8 r n n n p 1024 1029 515 515 1029 1029 bias,relu,clip -s u8 u8 r n n n p 1024 1029 515 515 1029 1029 bias,relu,clip -s s16 r n n n p 1024 1040 2050 2050 1040 1040 none -s s8 r n n n p 1024 1040 2050 2050 1040 1040 none -s u8 r n n n p 1024 1040 2050 2050 1040 1040 none -s s16 u8 r n n n p 1024 1040 2050 2050 1040 1040 bias,relu,clip -s s8 u8 r n n n p 1024 1040 2050 2050 1040 1040 bias,relu,clip -s u8 u8 r n n n p 1024 1040 2050 2050 1040 1040 bias,relu,clip -s s16 r n n n p 1029 1029 2050 2050 1029 1029 none -s s8 r n n n p 1029 1029 2050 2050 1029 1029 none -s u8 r n n n p 1029 1029 2050 2050 1029 1029 none -s s16 u8 r n n n p 1029 1029 2050 2050 1029 1029 bias,relu,clip -s s8 u8 r n n n p 1029 1029 2050 2050 1029 1029 bias,relu,clip -s u8 u8 r n n n p 1029 1029 2050 2050 1029 1029 bias,relu,clip -s s16 r n n n R 480 646 2050 2050 646 646 none -s s8 r n n n R 480 646 2050 2050 646 646 none -s u8 r n n n R 480 646 2050 2050 646 646 none -s s16 u8 r n n n R 480 646 2050 2050 646 646 bias,relu,clip -s s8 u8 r n n n R 480 646 2050 2050 646 646 bias,relu,clip -s u8 u8 r n n n R 480 646 2050 2050 646 646 bias,relu,clip -s s16 r n n n R 481 646 2050 2050 646 646 none -s s8 r n n n R 481 646 2050 2050 646 646 none -s u8 r n n n R 481 646 2050 2050 646 646 none -s s16 u8 r n n n R 481 646 2050 2050 646 646 bias,relu,clip -s s8 u8 r n n n R 481 646 2050 2050 646 646 bias,relu,clip -s u8 u8 r n n n R 481 646 2050 2050 646 646 bias,relu,clip -s s16 r n n n R 482 646 2050 2050 646 646 none -s s8 r n n n R 482 646 2050 2050 646 646 none -s u8 r n n n R 482 646 2050 2050 646 646 none -s s16 u8 r n n n R 482 646 2050 2050 646 646 bias,relu,clip -s s8 u8 r n n n R 482 646 2050 2050 646 646 bias,relu,clip -s u8 u8 r n n n R 482 646 2050 2050 646 646 bias,relu,clip -s s16 r n n n R 483 646 2050 2050 646 646 none -s s8 r n n n R 483 646 2050 2050 646 646 none -s u8 r n n n R 483 646 2050 2050 646 646 none -s s16 u8 r n n n R 483 646 2050 2050 646 646 bias,relu,clip -s s8 u8 r n n n R 483 646 2050 2050 646 646 bias,relu,clip -s u8 u8 r n n n R 483 646 2050 2050 646 646 bias,relu,clip -s s16 r n n n R 484 646 2050 2050 646 646 none -s s8 r n n n R 484 646 2050 2050 646 646 none -s u8 r n n n R 484 646 2050 2050 646 646 none -s s16 u8 r n n n R 484 646 2050 2050 646 646 bias,relu,clip -s s8 u8 r n n n R 484 646 2050 2050 646 646 bias,relu,clip -s u8 u8 r n n n R 484 646 2050 2050 646 646 bias,relu,clip -s s16 r n n n R 485 646 2050 2050 646 646 none -s s8 r n n n R 485 646 2050 2050 646 646 none -s u8 r n n n R 485 646 2050 2050 646 646 none -s s16 u8 r n n n R 485 646 2050 2050 646 646 bias,relu,clip -s s8 u8 r n n n R 485 646 2050 2050 646 646 bias,relu,clip -s u8 u8 r n n n R 485 646 2050 2050 646 646 bias,relu,clip -s s16 r n n n R 481 656 2050 2050 656 656 none -s s8 r n n n R 481 656 2050 2050 656 656 none -s u8 r n n n R 481 656 2050 2050 656 656 none -s s16 u8 r n n n R 481 656 2050 2050 656 656 bias,relu,clip -s s8 u8 r n n n R 481 656 2050 2050 656 656 bias,relu,clip -s u8 u8 r n n n R 481 656 2050 2050 656 656 bias,relu,clip -s s16 r n n n R 482 656 2050 2050 656 656 none -s s8 r n n n R 482 656 2050 2050 656 656 none -s u8 r n n n R 482 656 2050 2050 656 656 none -s s16 u8 r n n n R 482 656 2050 2050 656 656 bias,relu,clip -s s8 u8 r n n n R 482 656 2050 2050 656 656 bias,relu,clip -s u8 u8 r n n n R 482 656 2050 2050 656 656 bias,relu,clip -s s16 r n n n R 483 656 2050 2050 656 656 none -s s8 r n n n R 483 656 2050 2050 656 656 none -s u8 r n n n R 483 656 2050 2050 656 656 none -s s16 u8 r n n n R 483 656 2050 2050 656 656 bias,relu,clip -s s8 u8 r n n n R 483 656 2050 2050 656 656 bias,relu,clip -s u8 u8 r n n n R 483 656 2050 2050 656 656 bias,relu,clip -s s16 r n n n R 484 656 2050 2050 656 656 none -s s8 r n n n R 484 656 2050 2050 656 656 none -s u8 r n n n R 484 656 2050 2050 656 656 none -s s16 u8 r n n n R 484 656 2050 2050 656 656 bias,relu,clip -s s8 u8 r n n n R 484 656 2050 2050 656 656 bias,relu,clip -s u8 u8 r n n n R 484 656 2050 2050 656 656 bias,relu,clip -s s16 r n n n p 485 656 2050 2050 656 656 none -s s8 r n n n p 485 656 2050 2050 656 656 none -s u8 r n n n p 485 656 2050 2050 656 656 none -s s16 u8 r n n n p 485 656 2050 2050 656 656 bias,relu,clip -s s8 u8 r n n n p 485 656 2050 2050 656 656 bias,relu,clip -s u8 u8 r n n n p 485 656 2050 2050 656 656 bias,relu,clip -s s16 r n n n p 480 672 2050 2050 672 672 none -s s8 r n n n p 480 672 2050 2050 672 672 none -s u8 r n n n p 480 672 2050 2050 672 672 none -s s16 u8 r n n n p 480 672 2050 2050 672 672 bias,relu,clip -s s8 u8 r n n n p 480 672 2050 2050 672 672 bias,relu,clip -s u8 u8 r n n n p 480 672 2050 2050 672 672 bias,relu,clip -s s16 r n n n p 481 672 2050 2050 672 672 none -s s8 r n n n p 481 672 2050 2050 672 672 none -s u8 r n n n p 481 672 2050 2050 672 672 none -s s16 u8 r n n n p 481 672 2050 2050 672 672 bias,relu,clip -s s8 u8 r n n n p 481 672 2050 2050 672 672 bias,relu,clip -s u8 u8 r n n n p 481 672 2050 2050 672 672 bias,relu,clip -s s16 r n n n p 482 672 2050 2050 672 672 none -s s8 r n n n p 482 672 2050 2050 672 672 none -s u8 r n n n p 482 672 2050 2050 672 672 none -s s16 u8 r n n n p 482 672 2050 2050 672 672 bias,relu,clip -s s8 u8 r n n n p 482 672 2050 2050 672 672 bias,relu,clip -s u8 u8 r n n n p 482 672 2050 2050 672 672 bias,relu,clip -s s16 r n n n p 483 672 2050 2050 672 672 none -s s8 r n n n p 483 672 2050 2050 672 672 none -s u8 r n n n p 483 672 2050 2050 672 672 none -s s16 u8 r n n n p 483 672 2050 2050 672 672 bias,relu,clip -s s8 u8 r n n n p 483 672 2050 2050 672 672 bias,relu,clip -s u8 u8 r n n n p 483 672 2050 2050 672 672 bias,relu,clip -s s16 r n n n p 484 672 2050 2050 672 672 none -s s8 r n n n p 484 672 2050 2050 672 672 none -s u8 r n n n p 484 672 2050 2050 672 672 none -s s16 u8 r n n n p 484 672 2050 2050 672 672 bias,relu,clip -s s8 u8 r n n n p 484 672 2050 2050 672 672 bias,relu,clip -s u8 u8 r n n n p 484 672 2050 2050 672 672 bias,relu,clip -s s16 r n n n p 485 672 2050 2050 672 672 none -s s8 r n n n p 485 672 2050 2050 672 672 none -s u8 r n n n p 485 672 2050 2050 672 672 none -s s16 u8 r n n n p 485 672 2050 2050 672 672 bias,relu,clip -s s8 u8 r n n n p 485 672 2050 2050 672 672 bias,relu,clip -s u8 u8 r n n n p 485 672 2050 2050 672 672 bias,relu,clip -s s16 r n n n p 480 688 2050 2050 688 688 none -s s8 r n n n p 480 688 2050 2050 688 688 none -s u8 r n n n p 480 688 2050 2050 688 688 none -s s16 u8 r n n n p 480 688 2050 2050 688 688 bias,relu,clip -s s8 u8 r n n n p 480 688 2050 2050 688 688 bias,relu,clip -s u8 u8 r n n n p 480 688 2050 2050 688 688 bias,relu,clip -s s16 r n n n p 481 688 2050 2050 688 688 none -s s8 r n n n p 481 688 2050 2050 688 688 none -s u8 r n n n p 481 688 2050 2050 688 688 none -s s16 u8 r n n n p 481 688 2050 2050 688 688 bias,relu,clip -s s8 u8 r n n n p 481 688 2050 2050 688 688 bias,relu,clip -s u8 u8 r n n n p 481 688 2050 2050 688 688 bias,relu,clip -s s16 r n n n r 482 688 2050 2050 688 688 none -s s8 r n n n r 482 688 2050 2050 688 688 none -s u8 r n n n r 482 688 2050 2050 688 688 none -s s16 u8 r n n n r 482 688 2050 2050 688 688 bias,relu,clip -s s8 u8 r n n n r 482 688 2050 2050 688 688 bias,relu,clip -s u8 u8 r n n n r 482 688 2050 2050 688 688 bias,relu,clip -s s16 r n n n r 483 688 2050 2050 688 688 none -s s8 r n n n r 483 688 2050 2050 688 688 none -s u8 r n n n r 483 688 2050 2050 688 688 none -s s16 u8 r n n n r 483 688 2050 2050 688 688 bias,relu,clip -s s8 u8 r n n n r 483 688 2050 2050 688 688 bias,relu,clip -s u8 u8 r n n n r 483 688 2050 2050 688 688 bias,relu,clip -s s16 r n n n r 484 688 2050 2050 688 688 none -s s8 r n n n r 484 688 2050 2050 688 688 none -s u8 r n n n r 484 688 2050 2050 688 688 none -s s16 u8 r n n n r 484 688 2050 2050 688 688 bias,relu,clip -s s8 u8 r n n n r 484 688 2050 2050 688 688 bias,relu,clip -s u8 u8 r n n n r 484 688 2050 2050 688 688 bias,relu,clip -s s16 r n n n r 485 688 2050 2050 688 688 none -s s8 r n n n r 485 688 2050 2050 688 688 none -s u8 r n n n r 485 688 2050 2050 688 688 none -s s16 u8 r n n n r 485 688 2050 2050 688 688 bias,relu,clip -s s8 u8 r n n n r 485 688 2050 2050 688 688 bias,relu,clip -s u8 u8 r n n n r 485 688 2050 2050 688 688 bias,relu,clip -s s16 r n n n r 1024 512 64 64 512 512 none -s s8 r n n n r 1024 512 64 64 512 512 none -s u8 r n n n r 1024 512 64 64 512 512 none -s s16 u8 r n n n r 1024 512 64 64 512 512 bias,relu,clip -s s8 u8 r n n n r 1024 512 64 64 512 512 bias,relu,clip -s u8 u8 r n n n r 1024 512 64 64 512 512 bias,relu,clip -s s16 r n n n r 16 256 512 512 256 256 none -s s8 r n n n r 16 256 512 512 256 256 none -s u8 r n n n r 16 256 512 512 256 256 none -s s16 u8 r n n n r 16 256 512 512 256 256 bias,relu,clip -s s8 u8 r n n n r 16 256 512 512 256 256 bias,relu,clip -s u8 u8 r n n n r 16 256 512 512 256 256 bias,relu,clip -s s16 r n n n r 480 640 512 512 640 640 none -s s8 r n n n r 480 640 512 512 640 640 none -s u8 r n n n r 480 640 512 512 640 640 none -s s16 u8 r n n n r 480 640 512 512 640 640 bias,relu,clip -s s8 u8 r n n n r 480 640 512 512 640 640 bias,relu,clip -s u8 u8 r n n n r 480 640 512 512 640 640 bias,relu,clip -s s16 r n n n r 64 768 512 512 768 768 none -s s8 r n n n r 64 768 512 512 768 768 none -s u8 r n n n r 64 768 512 512 768 768 none -s s16 u8 r n n n r 64 768 512 512 768 768 bias,relu,clip -s s8 u8 r n n n r 64 768 512 512 768 768 bias,relu,clip -s u8 u8 r n n n r 64 768 512 512 768 768 bias,relu,clip -s s16 r n n n r 128 128 128 128 128 128 none -s s8 r n n n r 128 128 128 128 128 128 none -s u8 r n n n r 128 128 128 128 128 128 none -s s16 u8 r n n n r 128 128 128 128 128 128 bias,relu,clip -s s8 u8 r n n n r 128 128 128 128 128 128 bias,relu,clip -s u8 u8 r n n n r 128 128 128 128 128 128 bias,relu,clip -s s16 r n n n r 1024 64 512 512 64 64 none -s s8 r n n n r 1024 64 512 512 64 64 none -s u8 r n n n r 1024 64 512 512 64 64 none -s s16 u8 r n n n r 1024 64 512 512 64 64 bias,relu,clip -s s8 u8 r n n n r 1024 64 512 512 64 64 bias,relu,clip -s u8 u8 r n n n r 1024 64 512 512 64 64 bias,relu,clip -s s16 r n n n r 1024 256 32 32 256 256 none -s s8 r n n n r 1024 256 32 32 256 256 none -s u8 r n n n r 1024 256 32 32 256 256 none -s s16 u8 r n n n r 1024 256 32 32 256 256 bias,relu,clip -s s8 u8 r n n n r 1024 256 32 32 256 256 bias,relu,clip -s u8 u8 r n n n r 1024 256 32 32 256 256 bias,relu,clip -s s16 r n n n r 1024 512 64 64 512 512 none -s s8 r n n n r 1024 512 64 64 512 512 none -s u8 r n n n r 1024 512 64 64 512 512 none -s s16 u8 r n n n r 1024 512 64 64 512 512 bias,relu,clip -s s8 u8 r n n n r 1024 512 64 64 512 512 bias,relu,clip -s u8 u8 r n n n r 1024 512 64 64 512 512 bias,relu,clip -s s16 r n n n r 480 640 512 512 640 640 none -s s8 r n n n r 480 640 512 512 640 640 none -s u8 r n n n r 480 640 512 512 640 640 none -s s16 u8 r n n n r 480 640 512 512 640 640 bias,relu,clip -s s8 u8 r n n n r 480 640 512 512 640 640 bias,relu,clip -s u8 u8 r n n n r 480 640 512 512 640 640 bias,relu,clip -s s16 r n n n p 1024 32 256 256 32 32 none -s s8 r n n n p 1024 32 256 256 32 32 none -s u8 r n n n p 1024 32 256 256 32 32 none -s s16 u8 r n n n p 1024 32 256 256 32 32 bias,relu,clip -s s8 u8 r n n n p 1024 32 256 256 32 32 bias,relu,clip -s u8 u8 r n n n p 1024 32 256 256 32 32 bias,relu,clip -s s16 r n n n P 1024 64 512 512 64 64 none -s s8 r n n n P 1024 64 512 512 64 64 none -s u8 r n n n P 1024 64 512 512 64 64 none -s s16 u8 r n n n P 1024 64 512 512 64 64 bias,relu,clip -s s8 u8 r n n n P 1024 64 512 512 64 64 bias,relu,clip -s u8 u8 r n n n P 1024 64 512 512 64 64 bias,relu,clip -s s16 r n n n P 64 800 320 320 800 800 none -s s8 r n n n P 64 800 320 320 800 800 none -s u8 r n n n P 64 800 320 320 800 800 none -s s16 u8 r n n n P 64 800 320 320 800 800 bias,relu,clip -s s8 u8 r n n n P 64 800 320 320 800 800 bias,relu,clip -s u8 u8 r n n n P 64 800 320 320 800 800 bias,relu,clip -s s16 r n n n P 64 768 512 512 768 768 none -s s8 r n n n P 64 768 512 512 768 768 none -s u8 r n n n P 64 768 512 512 768 768 none -s s16 u8 r n n n P 64 768 512 512 768 768 bias,relu,clip -s s8 u8 r n n n P 64 768 512 512 768 768 bias,relu,clip -s u8 u8 r n n n P 64 768 512 512 768 768 bias,relu,clip -s s16 r n n n P 16 256 512 512 256 256 none -s s8 r n n n P 16 256 512 512 256 256 none -s u8 r n n n P 16 256 512 512 256 256 none -s s16 u8 r n n n P 16 256 512 512 256 256 bias,relu,clip -s s8 u8 r n n n P 16 256 512 512 256 256 bias,relu,clip -s u8 u8 r n n n P 16 256 512 512 256 256 bias,relu,clip -s s16 r n n n P 128 128 128 128 128 128 none -s s8 r n n n P 128 128 128 128 128 128 none -s u8 r n n n P 128 128 128 128 128 128 none -s s16 u8 r n n n P 128 128 128 128 128 128 bias,relu,clip -s s8 u8 r n n n P 128 128 128 128 128 128 bias,relu,clip -s u8 u8 r n n n P 128 128 128 128 128 128 bias,relu,clip -s s16 r n n n P 256 512 256 256 512 512 none -s s8 r n n n P 256 512 256 256 512 512 none -s u8 r n n n P 256 512 256 256 512 512 none -s s16 u8 r n n n P 256 512 256 256 512 512 bias,relu,clip -s s8 u8 r n n n P 256 512 256 256 512 512 bias,relu,clip -s u8 u8 r n n n P 256 512 256 256 512 512 bias,relu,clip -s s16 r n n n P 1024 1024 1024 1024 1024 1024 none -s s8 r n n n P 1024 1024 1024 1024 1024 1024 none -s u8 r n n n P 1024 1024 1024 1024 1024 1024 none -s s16 u8 r n n n P 1024 1024 1024 1024 1024 1024 bias,relu,clip -s s8 u8 r n n n P 1024 1024 1024 1024 1024 1024 bias,relu,clip -s u8 u8 r n n n P 1024 1024 1024 1024 1024 1024 bias,relu,clip -s s16 r n n n P 480 640 1024 1024 640 640 none -s s8 r n n n P 480 640 1024 1024 640 640 none -s u8 r n n n P 480 640 1024 1024 640 640 none -s s16 u8 r n n n P 480 640 1024 1024 640 640 bias,relu,clip -s s8 u8 r n n n P 480 640 1024 1024 640 640 bias,relu,clip -s u8 u8 r n n n P 480 640 1024 1024 640 640 bias,relu,clip -s s16 r n n n P 480 640 256 256 640 640 none -s s8 r n n n P 480 640 256 256 640 640 none -s u8 r n n n P 480 640 256 256 640 640 none -s s16 u8 r n n n P 480 640 256 256 640 640 bias,relu,clip -s s8 u8 r n n n P 480 640 256 256 640 640 bias,relu,clip -s u8 u8 r n n n P 480 640 256 256 640 640 bias,relu,clip -s s16 r n n n P 8 64 32 32 64 64 none -s s8 r n n n P 8 64 32 32 64 64 none -s u8 r n n n P 8 64 32 32 64 64 none -s s16 u8 r n n n P 8 64 32 32 64 64 bias,relu,clip -s s8 u8 r n n n P 8 64 32 32 64 64 bias,relu,clip -s u8 u8 r n n n P 8 64 32 32 64 64 bias,relu,clip -s s16 r n n n P 9 64 32 32 64 64 none -s s8 r n n n P 9 64 32 32 64 64 none -s u8 r n n n P 9 64 32 32 64 64 none -s s16 u8 r n n n P 9 64 32 32 64 64 bias,relu,clip -s s8 u8 r n n n P 9 64 32 32 64 64 bias,relu,clip -s u8 u8 r n n n P 9 64 32 32 64 64 bias,relu,clip -s s16 r n n n P 10 128 64 64 128 128 none -s s8 r n n n P 10 128 64 64 128 128 none -s u8 r n n n P 10 128 64 64 128 128 none -s s16 u8 r n n n P 10 128 64 64 128 128 bias,relu,clip -s s8 u8 r n n n P 10 128 64 64 128 128 bias,relu,clip -s u8 u8 r n n n P 10 128 64 64 128 128 bias,relu,clip -s s16 r n n n P 8 8 8 8 8 8 none -s s8 r n n n P 8 8 8 8 8 8 none -s u8 r n n n P 8 8 8 8 8 8 none -s s16 u8 r n n n P 8 8 8 8 8 8 bias,relu,clip -s s8 u8 r n n n P 8 8 8 8 8 8 bias,relu,clip -s u8 u8 r n n n P 8 8 8 8 8 8 bias,relu,clip -s s16 r n n n P 12 12 12 12 12 12 none -s s8 r n n n P 12 12 12 12 12 12 none -s u8 r n n n P 12 12 12 12 12 12 none -s s16 u8 r n n n P 12 12 12 12 12 12 bias,relu,clip -s s8 u8 r n n n P 12 12 12 12 12 12 bias,relu,clip -s u8 u8 r n n n P 12 12 12 12 12 12 bias,relu,clip -s s16 r n n n P 25 25 25 25 25 25 none -s s8 r n n n P 25 25 25 25 25 25 none -s u8 r n n n P 25 25 25 25 25 25 none -s s16 u8 r n n n P 25 25 25 25 25 25 bias,relu,clip -s s8 u8 r n n n P 25 25 25 25 25 25 bias,relu,clip -s u8 u8 r n n n P 25 25 25 25 25 25 bias,relu,clip -s s16 r n n n P 25 25 20 20 25 25 none -s s8 r n n n P 25 25 20 20 25 25 none -s u8 r n n n P 25 25 20 20 25 25 none -s s16 u8 r n n n P 25 25 20 20 25 25 bias,relu,clip -s s8 u8 r n n n P 25 25 20 20 25 25 bias,relu,clip -s u8 u8 r n n n P 25 25 20 20 25 25 bias,relu,clip -i s32 r n n n p 480 20 2050 2050 20 20 none -i s8 r n n n p 480 20 2050 2050 20 20 none -i s32 s8 r n n n p 480 20 2050 2050 20 20 bias,relu,clip -i s8 s8 r n n n p 480 20 2050 2050 20 20 bias,relu,clip -i s32 r n n n p 481 20 2050 2050 20 20 none -i s8 r n n n p 481 20 2050 2050 20 20 none -i s32 s8 r n n n p 481 20 2050 2050 20 20 bias,relu,clip -i s8 s8 r n n n p 481 20 2050 2050 20 20 bias,relu,clip -i s32 r n n n p 482 20 2050 2050 20 20 none -i s8 r n n n p 482 20 2050 2050 20 20 none -i s32 s8 r n n n p 482 20 2050 2050 20 20 bias,relu,clip -i s8 s8 r n n n p 482 20 2050 2050 20 20 bias,relu,clip -i s32 r n n n p 483 20 2050 2050 20 20 none -i s8 r n n n p 483 20 2050 2050 20 20 none -i s32 s8 r n n n p 483 20 2050 2050 20 20 bias,relu,clip -i s8 s8 r n n n p 483 20 2050 2050 20 20 bias,relu,clip -i s32 r n n n R 484 20 2050 2050 20 20 none -i s8 r n n n R 484 20 2050 2050 20 20 none -i s32 s8 r n n n R 484 20 2050 2050 20 20 bias,relu,clip -i s8 s8 r n n n R 484 20 2050 2050 20 20 bias,relu,clip -i s32 r n n n R 485 20 2050 2050 20 20 none -i s8 r n n n R 485 20 2050 2050 20 20 none -i s32 s8 r n n n R 485 20 2050 2050 20 20 bias,relu,clip -i s8 s8 r n n n R 485 20 2050 2050 20 20 bias,relu,clip -i s32 r n n n R 480 39 2050 2050 39 39 none -i s8 r n n n R 480 39 2050 2050 39 39 none -i s32 s8 r n n n R 480 39 2050 2050 39 39 bias,relu,clip -i s8 s8 r n n n R 480 39 2050 2050 39 39 bias,relu,clip -i s32 r n n n R 481 39 2050 2050 39 39 none -i s8 r n n n R 481 39 2050 2050 39 39 none -i s32 s8 r n n n R 481 39 2050 2050 39 39 bias,relu,clip -i s8 s8 r n n n R 481 39 2050 2050 39 39 bias,relu,clip -i s32 r n n n R 482 39 2050 2050 39 39 none -i s8 r n n n R 482 39 2050 2050 39 39 none -i s32 s8 r n n n R 482 39 2050 2050 39 39 bias,relu,clip -i s8 s8 r n n n R 482 39 2050 2050 39 39 bias,relu,clip -i s32 r n n n R 483 39 2050 2050 39 39 none -i s8 r n n n R 483 39 2050 2050 39 39 none -i s32 s8 r n n n R 483 39 2050 2050 39 39 bias,relu,clip -i s8 s8 r n n n R 483 39 2050 2050 39 39 bias,relu,clip -i s32 r n n n R 484 39 2050 2050 39 39 none -i s8 r n n n R 484 39 2050 2050 39 39 none -i s32 s8 r n n n R 484 39 2050 2050 39 39 bias,relu,clip -i s8 s8 r n n n R 484 39 2050 2050 39 39 bias,relu,clip -i s32 r n n n p 485 39 2050 2050 39 39 none -i s8 r n n n p 485 39 2050 2050 39 39 none -i s32 s8 r n n n p 485 39 2050 2050 39 39 bias,relu,clip -i s8 s8 r n n n p 485 39 2050 2050 39 39 bias,relu,clip -i s32 r n n n p 480 50 2050 2050 50 50 none -i s8 r n n n p 480 50 2050 2050 50 50 none -i s32 s8 r n n n p 480 50 2050 2050 50 50 bias,relu,clip -i s8 s8 r n n n p 480 50 2050 2050 50 50 bias,relu,clip -i s32 r n n n p 481 50 2050 2050 50 50 none -i s8 r n n n p 481 50 2050 2050 50 50 none -i s32 s8 r n n n p 481 50 2050 2050 50 50 bias,relu,clip -i s8 s8 r n n n p 481 50 2050 2050 50 50 bias,relu,clip -i s32 r n n n p 482 50 2050 2050 50 50 none -i s8 r n n n p 482 50 2050 2050 50 50 none -i s32 s8 r n n n p 482 50 2050 2050 50 50 bias,relu,clip -i s8 s8 r n n n p 482 50 2050 2050 50 50 bias,relu,clip -i s32 r n n n p 483 50 2050 2050 50 50 none -i s8 r n n n p 483 50 2050 2050 50 50 none -i s32 s8 r n n n p 483 50 2050 2050 50 50 bias,relu,clip -i s8 s8 r n n n p 483 50 2050 2050 50 50 bias,relu,clip -i s32 r n n n p 484 50 2050 2050 50 50 none -i s8 r n n n p 484 50 2050 2050 50 50 none -i s32 s8 r n n n p 484 50 2050 2050 50 50 bias,relu,clip -i s8 s8 r n n n p 484 50 2050 2050 50 50 bias,relu,clip -i s32 r n n n p 485 50 2050 2050 50 50 none -i s8 r n n n p 485 50 2050 2050 50 50 none -i s32 s8 r n n n p 485 50 2050 2050 50 50 bias,relu,clip -i s8 s8 r n n n p 485 50 2050 2050 50 50 bias,relu,clip -i s32 r n n n R 480 1108 2050 2050 1108 1108 none -i s8 r n n n R 480 1108 2050 2050 1108 1108 none -i s32 s8 r n n n R 480 1108 2050 2050 1108 1108 bias,relu,clip -i s8 s8 r n n n R 480 1108 2050 2050 1108 1108 bias,relu,clip -i s32 r n n n R 481 1108 2050 2050 1108 1108 none -i s8 r n n n R 481 1108 2050 2050 1108 1108 none -i s32 s8 r n n n R 481 1108 2050 2050 1108 1108 bias,relu,clip -i s8 s8 r n n n R 481 1108 2050 2050 1108 1108 bias,relu,clip -i s32 r n n n R 482 1108 2050 2050 1108 1108 none -i s8 r n n n R 482 1108 2050 2050 1108 1108 none -i s32 s8 r n n n R 482 1108 2050 2050 1108 1108 bias,relu,clip -i s8 s8 r n n n R 482 1108 2050 2050 1108 1108 bias,relu,clip -i s32 r n n n R 483 1108 2050 2050 1108 1108 none -i s8 r n n n R 483 1108 2050 2050 1108 1108 none -i s32 s8 r n n n R 483 1108 2050 2050 1108 1108 bias,relu,clip -i s8 s8 r n n n R 483 1108 2050 2050 1108 1108 bias,relu,clip -i s32 r n n n R 484 1108 2050 2050 1108 1108 none -i s8 r n n n R 484 1108 2050 2050 1108 1108 none -i s32 s8 r n n n R 484 1108 2050 2050 1108 1108 bias,relu,clip -i s8 s8 r n n n R 484 1108 2050 2050 1108 1108 bias,relu,clip -i s32 r n n n R 485 1108 2050 2050 1108 1108 none -i s8 r n n n R 485 1108 2050 2050 1108 1108 none -i s32 s8 r n n n R 485 1108 2050 2050 1108 1108 bias,relu,clip -i s8 s8 r n n n R 485 1108 2050 2050 1108 1108 bias,relu,clip -i s32 r n n n R 480 1127 2050 2050 1127 1127 none -i s8 r n n n R 480 1127 2050 2050 1127 1127 none -i s32 s8 r n n n R 480 1127 2050 2050 1127 1127 bias,relu,clip -i s8 s8 r n n n R 480 1127 2050 2050 1127 1127 bias,relu,clip -i s32 r n n n R 481 1127 2050 2050 1127 1127 none -i s8 r n n n R 481 1127 2050 2050 1127 1127 none -i s32 s8 r n n n R 481 1127 2050 2050 1127 1127 bias,relu,clip -i s8 s8 r n n n R 481 1127 2050 2050 1127 1127 bias,relu,clip -i s32 r n n n R 482 1127 2050 2050 1127 1127 none -i s8 r n n n R 482 1127 2050 2050 1127 1127 none -i s32 s8 r n n n R 482 1127 2050 2050 1127 1127 bias,relu,clip -i s8 s8 r n n n R 482 1127 2050 2050 1127 1127 bias,relu,clip -i s32 r n n n R 483 1127 2050 2050 1127 1127 none -i s8 r n n n R 483 1127 2050 2050 1127 1127 none -i s32 s8 r n n n R 483 1127 2050 2050 1127 1127 bias,relu,clip -i s8 s8 r n n n R 483 1127 2050 2050 1127 1127 bias,relu,clip -i s32 r n n n p 484 1127 2050 2050 1127 1127 none -i s8 r n n n p 484 1127 2050 2050 1127 1127 none -i s32 s8 r n n n p 484 1127 2050 2050 1127 1127 bias,relu,clip -i s8 s8 r n n n p 484 1127 2050 2050 1127 1127 bias,relu,clip -i s32 r n n n p 485 1127 2050 2050 1127 1127 none -i s8 r n n n p 485 1127 2050 2050 1127 1127 none -i s32 s8 r n n n p 485 1127 2050 2050 1127 1127 bias,relu,clip -i s8 s8 r n n n p 485 1127 2050 2050 1127 1127 bias,relu,clip -i s32 r n n n p 480 1138 2050 2050 1138 1138 none -i s8 r n n n p 480 1138 2050 2050 1138 1138 none -i s32 s8 r n n n p 480 1138 2050 2050 1138 1138 bias,relu,clip -i s8 s8 r n n n p 480 1138 2050 2050 1138 1138 bias,relu,clip -i s32 r n n n p 481 1138 2050 2050 1138 1138 none -i s8 r n n n p 481 1138 2050 2050 1138 1138 none -i s32 s8 r n n n p 481 1138 2050 2050 1138 1138 bias,relu,clip -i s8 s8 r n n n p 481 1138 2050 2050 1138 1138 bias,relu,clip -i s32 r n n n p 482 1138 2050 2050 1138 1138 none -i s8 r n n n p 482 1138 2050 2050 1138 1138 none -i s32 s8 r n n n p 482 1138 2050 2050 1138 1138 bias,relu,clip -i s8 s8 r n n n p 482 1138 2050 2050 1138 1138 bias,relu,clip -i s32 r n n n p 483 1138 2050 2050 1138 1138 none -i s8 r n n n p 483 1138 2050 2050 1138 1138 none -i s32 s8 r n n n p 483 1138 2050 2050 1138 1138 bias,relu,clip -i s8 s8 r n n n p 483 1138 2050 2050 1138 1138 bias,relu,clip -i s32 r n n n p 484 1138 2050 2050 1138 1138 none -i s8 r n n n p 484 1138 2050 2050 1138 1138 none -i s32 s8 r n n n p 484 1138 2050 2050 1138 1138 bias,relu,clip -i s8 s8 r n n n p 484 1138 2050 2050 1138 1138 bias,relu,clip -i s32 r n n n p 485 1138 2050 2050 1138 1138 none -i s8 r n n n p 485 1138 2050 2050 1138 1138 none -i s32 s8 r n n n p 485 1138 2050 2050 1138 1138 bias,relu,clip -i s8 s8 r n n n p 485 1138 2050 2050 1138 1138 bias,relu,clip -i s32 r n n n p 1 1 3 3 1 1 none -i s8 r n n n p 1 1 3 3 1 1 none -i s32 s8 r n n n p 1 1 3 3 1 1 bias,relu,clip -i s8 s8 r n n n p 1 1 3 3 1 1 bias,relu,clip -i s32 r n n n p 1 9 3 3 9 9 none -i s8 r n n n p 1 9 3 3 9 9 none -i s32 s8 r n n n p 1 9 3 3 9 9 bias,relu,clip -i s8 s8 r n n n p 1 9 3 3 9 9 bias,relu,clip -i s32 r n n n p 1 2048 3 3 2048 2048 none -i s8 r n n n p 1 2048 3 3 2048 2048 none -i s32 s8 r n n n p 1 2048 3 3 2048 2048 bias,relu,clip -i s8 s8 r n n n p 1 2048 3 3 2048 2048 bias,relu,clip -i s32 r n n n p 1 2048 5192 5192 2048 2048 none -i s8 r n n n p 1 2048 5192 5192 2048 2048 none -i s32 s8 r n n n p 1 2048 5192 5192 2048 2048 bias,relu,clip -i s8 s8 r n n n p 1 2048 5192 5192 2048 2048 bias,relu,clip -i s32 r n n n p 9 1 3 3 1 1 none -i s8 r n n n p 9 1 3 3 1 1 none -i s32 s8 r n n n p 9 1 3 3 1 1 bias,relu,clip -i s8 s8 r n n n p 9 1 3 3 1 1 bias,relu,clip -i s32 r n n n p 576 1 3500 3500 1 1 none -i s8 r n n n p 576 1 3500 3500 1 1 none -i s32 s8 r n n n p 576 1 3500 3500 1 1 bias,relu,clip -i s8 s8 r n n n p 576 1 3500 3500 1 1 bias,relu,clip -i s32 r n n n p 1 1 1 1 1 1 none -i s8 r n n n p 1 1 1 1 1 1 none -i s32 s8 r n n n p 1 1 1 1 1 1 bias,relu,clip -i s8 s8 r n n n p 1 1 1 1 1 1 bias,relu,clip -i s32 r n n n p 102 1088 1024 1024 1088 1088 none -i s8 r n n n p 102 1088 1024 1024 1088 1088 none -i s32 s8 r n n n p 102 1088 1024 1024 1088 1088 bias,relu,clip -i s8 s8 r n n n p 102 1088 1024 1024 1088 1088 bias,relu,clip -i s32 r n n n p 102 2048 1024 1024 2048 2048 none -i s8 r n n n p 102 2048 1024 1024 2048 2048 none -i s32 s8 r n n n p 102 2048 1024 1024 2048 2048 bias,relu,clip -i s8 s8 r n n n p 102 2048 1024 1024 2048 2048 bias,relu,clip -i s32 r n n n p 485 656 1024 1024 656 656 none -i s8 r n n n p 485 656 1024 1024 656 656 none -i s32 s8 r n n n p 485 656 1024 1024 656 656 bias,relu,clip -i s8 s8 r n n n p 485 656 1024 1024 656 656 bias,relu,clip -i s32 r n n n p 483 656 1024 1024 656 656 none -i s8 r n n n p 483 656 1024 1024 656 656 none -i s32 s8 r n n n p 483 656 1024 1024 656 656 bias,relu,clip -i s8 s8 r n n n p 483 656 1024 1024 656 656 bias,relu,clip -i s32 r n n n p 81 128 3 3 128 128 none -i s8 r n n n p 81 128 3 3 128 128 none -i s32 s8 r n n n p 81 128 3 3 128 128 bias,relu,clip -i s8 s8 r n n n p 81 128 3 3 128 128 bias,relu,clip -i s32 r n n n p 1022 512 515 515 512 512 none -i s8 r n n n p 1022 512 515 515 512 512 none -i s32 s8 r n n n p 1022 512 515 515 512 512 bias,relu,clip -i s8 s8 r n n n p 1022 512 515 515 512 512 bias,relu,clip -i s32 r n n n p 74 512 515 515 512 512 none -i s8 r n n n p 74 512 515 515 512 512 none -i s32 s8 r n n n p 74 512 515 515 512 512 bias,relu,clip -i s8 s8 r n n n p 74 512 515 515 512 512 bias,relu,clip -i s32 r n n n p 253 2048 515 515 2048 2048 none -i s8 r n n n p 253 2048 515 515 2048 2048 none -i s32 s8 r n n n p 253 2048 515 515 2048 2048 bias,relu,clip -i s8 s8 r n n n p 253 2048 515 515 2048 2048 bias,relu,clip -i s32 r n n n p 8192 1040 515 515 1040 1040 none -i s8 r n n n p 8192 1040 515 515 1040 1040 none -i s32 s8 r n n n p 8192 1040 515 515 1040 1040 bias,relu,clip -i s8 s8 r n n n p 8192 1040 515 515 1040 1040 bias,relu,clip -i s32 r n n n p 10 1029 515 515 1029 1029 none -i s8 r n n n p 10 1029 515 515 1029 1029 none -i s32 s8 r n n n p 10 1029 515 515 1029 1029 bias,relu,clip -i s8 s8 r n n n p 10 1029 515 515 1029 1029 bias,relu,clip -i s32 r n n n p 24 1040 2050 2050 1040 1040 none -i s8 r n n n p 24 1040 2050 2050 1040 1040 none -i s32 s8 r n n n p 24 1040 2050 2050 1040 1040 bias,relu,clip -i s8 s8 r n n n p 24 1040 2050 2050 1040 1040 bias,relu,clip -i s32 r n n n p 1024 1029 2050 2050 1029 1029 none -i s8 r n n n p 1024 1029 2050 2050 1029 1029 none -i s32 s8 r n n n p 1024 1029 2050 2050 1029 1029 bias,relu,clip -i s8 s8 r n n n p 1024 1029 2050 2050 1029 1029 bias,relu,clip -i s32 r n n n p 480 660 2050 2050 660 660 none -i s8 r n n n p 480 660 2050 2050 660 660 none -i s32 s8 r n n n p 480 660 2050 2050 660 660 bias,relu,clip -i s8 s8 r n n n p 480 660 2050 2050 660 660 bias,relu,clip -i s32 r n n n p 481 660 2050 2050 660 660 none -i s8 r n n n p 481 660 2050 2050 660 660 none -i s32 s8 r n n n p 481 660 2050 2050 660 660 bias,relu,clip -i s8 s8 r n n n p 481 660 2050 2050 660 660 bias,relu,clip -i s32 r n n n p 482 660 2050 2050 660 660 none -i s8 r n n n p 482 660 2050 2050 660 660 none -i s32 s8 r n n n p 482 660 2050 2050 660 660 bias,relu,clip -i s8 s8 r n n n p 482 660 2050 2050 660 660 bias,relu,clip -i s32 r n n n p 483 660 2050 2050 660 660 none -i s8 r n n n p 483 660 2050 2050 660 660 none -i s32 s8 r n n n p 483 660 2050 2050 660 660 bias,relu,clip -i s8 s8 r n n n p 483 660 2050 2050 660 660 bias,relu,clip -i s32 r n n n p 484 660 2050 2050 660 660 none -i s8 r n n n p 484 660 2050 2050 660 660 none -i s32 s8 r n n n p 484 660 2050 2050 660 660 bias,relu,clip -i s8 s8 r n n n p 484 660 2050 2050 660 660 bias,relu,clip -i s32 r n n n p 485 660 2050 2050 660 660 none -i s8 r n n n p 485 660 2050 2050 660 660 none -i s32 s8 r n n n p 485 660 2050 2050 660 660 bias,relu,clip -i s8 s8 r n n n p 485 660 2050 2050 660 660 bias,relu,clip -i s32 r n n n p 480 679 2050 2050 679 679 none -i s8 r n n n p 480 679 2050 2050 679 679 none -i s32 s8 r n n n p 480 679 2050 2050 679 679 bias,relu,clip -i s8 s8 r n n n p 480 679 2050 2050 679 679 bias,relu,clip -i s32 r n n n p 481 679 2050 2050 679 679 none -i s8 r n n n p 481 679 2050 2050 679 679 none -i s32 s8 r n n n p 481 679 2050 2050 679 679 bias,relu,clip -i s8 s8 r n n n p 481 679 2050 2050 679 679 bias,relu,clip -i s32 r n n n p 482 679 2050 2050 679 679 none -i s8 r n n n p 482 679 2050 2050 679 679 none -i s32 s8 r n n n p 482 679 2050 2050 679 679 bias,relu,clip -i s8 s8 r n n n p 482 679 2050 2050 679 679 bias,relu,clip -i s32 r n n n p 483 679 2050 2050 679 679 none -i s8 r n n n p 483 679 2050 2050 679 679 none -i s32 s8 r n n n p 483 679 2050 2050 679 679 bias,relu,clip -i s8 s8 r n n n p 483 679 2050 2050 679 679 bias,relu,clip -i s32 r n n n p 484 679 2050 2050 679 679 none -i s8 r n n n p 484 679 2050 2050 679 679 none -i s32 s8 r n n n p 484 679 2050 2050 679 679 bias,relu,clip -i s8 s8 r n n n p 484 679 2050 2050 679 679 bias,relu,clip -i s32 r n n n p 485 679 2050 2050 679 679 none -i s8 r n n n p 485 679 2050 2050 679 679 none -i s32 s8 r n n n p 485 679 2050 2050 679 679 bias,relu,clip -i s8 s8 r n n n p 485 679 2050 2050 679 679 bias,relu,clip -i s32 r n n n p 480 690 2050 2050 690 690 none -i s8 r n n n p 480 690 2050 2050 690 690 none -i s32 s8 r n n n p 480 690 2050 2050 690 690 bias,relu,clip -i s8 s8 r n n n p 480 690 2050 2050 690 690 bias,relu,clip -i s32 r n n n p 481 690 2050 2050 690 690 none -i s8 r n n n p 481 690 2050 2050 690 690 none -i s32 s8 r n n n p 481 690 2050 2050 690 690 bias,relu,clip -i s8 s8 r n n n p 481 690 2050 2050 690 690 bias,relu,clip -i s32 r n n n p 482 690 2050 2050 690 690 none -i s8 r n n n p 482 690 2050 2050 690 690 none -i s32 s8 r n n n p 482 690 2050 2050 690 690 bias,relu,clip -i s8 s8 r n n n p 482 690 2050 2050 690 690 bias,relu,clip -i s32 r n n n p 483 690 2050 2050 690 690 none -i s8 r n n n p 483 690 2050 2050 690 690 none -i s32 s8 r n n n p 483 690 2050 2050 690 690 bias,relu,clip -i s8 s8 r n n n p 483 690 2050 2050 690 690 bias,relu,clip -i s32 r n n n p 484 690 2050 2050 690 690 none -i s8 r n n n p 484 690 2050 2050 690 690 none -i s32 s8 r n n n p 484 690 2050 2050 690 690 bias,relu,clip -i s8 s8 r n n n p 484 690 2050 2050 690 690 bias,relu,clip -i s32 r n n n p 485 690 2050 2050 690 690 none -i s8 r n n n p 485 690 2050 2050 690 690 none -i s32 s8 r n n n p 485 690 2050 2050 690 690 bias,relu,clip -i s8 s8 r n n n p 485 690 2050 2050 690 690 bias,relu,clip -i s32 r n n n p 480 660 2048 2048 660 660 none -i s8 r n n n p 480 660 2048 2048 660 660 none -i s32 s8 r n n n p 480 660 2048 2048 660 660 bias,relu,clip -i s8 s8 r n n n p 480 660 2048 2048 660 660 bias,relu,clip -i s32 r n n n p 481 660 2048 2048 660 660 none -i s8 r n n n p 481 660 2048 2048 660 660 none -i s32 s8 r n n n p 481 660 2048 2048 660 660 bias,relu,clip -i s8 s8 r n n n p 481 660 2048 2048 660 660 bias,relu,clip -i s32 r n n n p 482 660 2048 2048 660 660 none -i s8 r n n n p 482 660 2048 2048 660 660 none -i s32 s8 r n n n p 482 660 2048 2048 660 660 bias,relu,clip -i s8 s8 r n n n p 482 660 2048 2048 660 660 bias,relu,clip -i s32 r n n n p 483 660 2048 2048 660 660 none -i s8 r n n n p 483 660 2048 2048 660 660 none -i s32 s8 r n n n p 483 660 2048 2048 660 660 bias,relu,clip -i s8 s8 r n n n p 483 660 2048 2048 660 660 bias,relu,clip -i s32 r n n n p 484 660 2048 2048 660 660 none -i s8 r n n n p 484 660 2048 2048 660 660 none -i s32 s8 r n n n p 484 660 2048 2048 660 660 bias,relu,clip -i s8 s8 r n n n p 484 660 2048 2048 660 660 bias,relu,clip -i s32 r n n n p 485 660 2048 2048 660 660 none -i s8 r n n n p 485 660 2048 2048 660 660 none -i s32 s8 r n n n p 485 660 2048 2048 660 660 bias,relu,clip -i s8 s8 r n n n p 485 660 2048 2048 660 660 bias,relu,clip -i s32 r n n n p 480 679 2048 2048 679 679 none -i s8 r n n n p 480 679 2048 2048 679 679 none -i s32 s8 r n n n p 480 679 2048 2048 679 679 bias,relu,clip -i s8 s8 r n n n p 480 679 2048 2048 679 679 bias,relu,clip -i s32 r n n n p 481 679 2048 2048 679 679 none -i s8 r n n n p 481 679 2048 2048 679 679 none -i s32 s8 r n n n p 481 679 2048 2048 679 679 bias,relu,clip -i s8 s8 r n n n p 481 679 2048 2048 679 679 bias,relu,clip -i s32 r n n n p 482 679 2048 2048 679 679 none -i s8 r n n n p 482 679 2048 2048 679 679 none -i s32 s8 r n n n p 482 679 2048 2048 679 679 bias,relu,clip -i s8 s8 r n n n p 482 679 2048 2048 679 679 bias,relu,clip -i s32 r n n n p 483 679 2048 2048 679 679 none -i s8 r n n n p 483 679 2048 2048 679 679 none -i s32 s8 r n n n p 483 679 2048 2048 679 679 bias,relu,clip -i s8 s8 r n n n p 483 679 2048 2048 679 679 bias,relu,clip -i s32 r n n n p 484 679 2048 2048 679 679 none -i s8 r n n n p 484 679 2048 2048 679 679 none -i s32 s8 r n n n p 484 679 2048 2048 679 679 bias,relu,clip -i s8 s8 r n n n p 484 679 2048 2048 679 679 bias,relu,clip -i s32 r n n n p 485 679 2048 2048 679 679 none -i s8 r n n n p 485 679 2048 2048 679 679 none -i s32 s8 r n n n p 485 679 2048 2048 679 679 bias,relu,clip -i s8 s8 r n n n p 485 679 2048 2048 679 679 bias,relu,clip -i s32 r n n n p 480 690 2048 2048 690 690 none -i s8 r n n n p 480 690 2048 2048 690 690 none -i s32 s8 r n n n p 480 690 2048 2048 690 690 bias,relu,clip -i s8 s8 r n n n p 480 690 2048 2048 690 690 bias,relu,clip -i s32 r n n n p 481 690 2048 2048 690 690 none -i s8 r n n n p 481 690 2048 2048 690 690 none -i s32 s8 r n n n p 481 690 2048 2048 690 690 bias,relu,clip -i s8 s8 r n n n p 481 690 2048 2048 690 690 bias,relu,clip -i s32 r n n n p 482 690 2048 2048 690 690 none -i s8 r n n n p 482 690 2048 2048 690 690 none -i s32 s8 r n n n p 482 690 2048 2048 690 690 bias,relu,clip -i s8 s8 r n n n p 482 690 2048 2048 690 690 bias,relu,clip -i s32 r n n n p 483 690 2048 2048 690 690 none -i s8 r n n n p 483 690 2048 2048 690 690 none -i s32 s8 r n n n p 483 690 2048 2048 690 690 bias,relu,clip -i s8 s8 r n n n p 483 690 2048 2048 690 690 bias,relu,clip -i s32 r n n n p 484 690 2048 2048 690 690 none -i s8 r n n n p 484 690 2048 2048 690 690 none -i s32 s8 r n n n p 484 690 2048 2048 690 690 bias,relu,clip -i s8 s8 r n n n p 484 690 2048 2048 690 690 bias,relu,clip -i s32 r n n n p 485 690 2048 2048 690 690 none -i s8 r n n n p 485 690 2048 2048 690 690 none -i s32 s8 r n n n p 485 690 2048 2048 690 690 bias,relu,clip -i s8 s8 r n n n p 485 690 2048 2048 690 690 bias,relu,clip -i s32 r n n n p 480 656 1024 1024 656 656 none -i s8 r n n n p 480 656 1024 1024 656 656 none -i s32 s8 r n n n p 480 656 1024 1024 656 656 bias,relu,clip -i s8 s8 r n n n p 480 656 1024 1024 656 656 bias,relu,clip -i s32 r n n n p 480 128 3 3 128 128 none -i s8 r n n n p 480 128 3 3 128 128 none -i s32 s8 r n n n p 480 128 3 3 128 128 bias,relu,clip -i s8 s8 r n n n p 480 128 3 3 128 128 bias,relu,clip -i s32 r n n n p 1024 512 515 515 512 512 none -i s8 r n n n p 1024 512 515 515 512 512 none -i s32 s8 r n n n p 1024 512 515 515 512 512 bias,relu,clip -i s8 s8 r n n n p 1024 512 515 515 512 512 bias,relu,clip -i s32 r n n n p 1024 2048 1024 1024 2048 2048 none -i s8 r n n n p 1024 2048 1024 1024 2048 2048 none -i s32 s8 r n n n p 1024 2048 1024 1024 2048 2048 bias,relu,clip -i s8 s8 r n n n p 1024 2048 1024 1024 2048 2048 bias,relu,clip -i s32 r n n n p 1024 2048 515 515 2048 2048 none -i s8 r n n n p 1024 2048 515 515 2048 2048 none -i s32 s8 r n n n p 1024 2048 515 515 2048 2048 bias,relu,clip -i s8 s8 r n n n p 1024 2048 515 515 2048 2048 bias,relu,clip -i s32 r n n n p 1024 1040 515 515 1040 1040 none -i s8 r n n n p 1024 1040 515 515 1040 1040 none -i s32 s8 r n n n p 1024 1040 515 515 1040 1040 bias,relu,clip -i s8 s8 r n n n p 1024 1040 515 515 1040 1040 bias,relu,clip -i s32 r n n n p 5 1029 515 515 1029 1029 none -i s8 r n n n p 5 1029 515 515 1029 1029 none -i s32 s8 r n n n p 5 1029 515 515 1029 1029 bias,relu,clip -i s8 s8 r n n n p 5 1029 515 515 1029 1029 bias,relu,clip -i s32 r n n n p 1024 1029 515 515 1029 1029 none -i s8 r n n n p 1024 1029 515 515 1029 1029 none -i s32 s8 r n n n p 1024 1029 515 515 1029 1029 bias,relu,clip -i s8 s8 r n n n p 1024 1029 515 515 1029 1029 bias,relu,clip -i s32 r n n n p 1024 1040 2050 2050 1040 1040 none -i s8 r n n n p 1024 1040 2050 2050 1040 1040 none -i s32 s8 r n n n p 1024 1040 2050 2050 1040 1040 bias,relu,clip -i s8 s8 r n n n p 1024 1040 2050 2050 1040 1040 bias,relu,clip -i s32 r n n n p 1029 1029 2050 2050 1029 1029 none -i s8 r n n n p 1029 1029 2050 2050 1029 1029 none -i s32 s8 r n n n p 1029 1029 2050 2050 1029 1029 bias,relu,clip -i s8 s8 r n n n p 1029 1029 2050 2050 1029 1029 bias,relu,clip -i s32 r n n n R 480 646 2050 2050 646 646 none -i s8 r n n n R 480 646 2050 2050 646 646 none -i s32 s8 r n n n R 480 646 2050 2050 646 646 bias,relu,clip -i s8 s8 r n n n R 480 646 2050 2050 646 646 bias,relu,clip -i s32 r n n n R 481 646 2050 2050 646 646 none -i s8 r n n n R 481 646 2050 2050 646 646 none -i s32 s8 r n n n R 481 646 2050 2050 646 646 bias,relu,clip -i s8 s8 r n n n R 481 646 2050 2050 646 646 bias,relu,clip -i s32 r n n n R 482 646 2050 2050 646 646 none -i s8 r n n n R 482 646 2050 2050 646 646 none -i s32 s8 r n n n R 482 646 2050 2050 646 646 bias,relu,clip -i s8 s8 r n n n R 482 646 2050 2050 646 646 bias,relu,clip -i s32 r n n n R 483 646 2050 2050 646 646 none -i s8 r n n n R 483 646 2050 2050 646 646 none -i s32 s8 r n n n R 483 646 2050 2050 646 646 bias,relu,clip -i s8 s8 r n n n R 483 646 2050 2050 646 646 bias,relu,clip -i s32 r n n n R 484 646 2050 2050 646 646 none -i s8 r n n n R 484 646 2050 2050 646 646 none -i s32 s8 r n n n R 484 646 2050 2050 646 646 bias,relu,clip -i s8 s8 r n n n R 484 646 2050 2050 646 646 bias,relu,clip -i s32 r n n n R 485 646 2050 2050 646 646 none -i s8 r n n n R 485 646 2050 2050 646 646 none -i s32 s8 r n n n R 485 646 2050 2050 646 646 bias,relu,clip -i s8 s8 r n n n R 485 646 2050 2050 646 646 bias,relu,clip -i s32 r n n n R 481 656 2050 2050 656 656 none -i s8 r n n n R 481 656 2050 2050 656 656 none -i s32 s8 r n n n R 481 656 2050 2050 656 656 bias,relu,clip -i s8 s8 r n n n R 481 656 2050 2050 656 656 bias,relu,clip -i s32 r n n n R 482 656 2050 2050 656 656 none -i s8 r n n n R 482 656 2050 2050 656 656 none -i s32 s8 r n n n R 482 656 2050 2050 656 656 bias,relu,clip -i s8 s8 r n n n R 482 656 2050 2050 656 656 bias,relu,clip -i s32 r n n n R 483 656 2050 2050 656 656 none -i s8 r n n n R 483 656 2050 2050 656 656 none -i s32 s8 r n n n R 483 656 2050 2050 656 656 bias,relu,clip -i s8 s8 r n n n R 483 656 2050 2050 656 656 bias,relu,clip -i s32 r n n n R 484 656 2050 2050 656 656 none -i s8 r n n n R 484 656 2050 2050 656 656 none -i s32 s8 r n n n R 484 656 2050 2050 656 656 bias,relu,clip -i s8 s8 r n n n R 484 656 2050 2050 656 656 bias,relu,clip -i s32 r n n n p 485 656 2050 2050 656 656 none -i s8 r n n n p 485 656 2050 2050 656 656 none -i s32 s8 r n n n p 485 656 2050 2050 656 656 bias,relu,clip -i s8 s8 r n n n p 485 656 2050 2050 656 656 bias,relu,clip -i s32 r n n n p 480 672 2050 2050 672 672 none -i s8 r n n n p 480 672 2050 2050 672 672 none -i s32 s8 r n n n p 480 672 2050 2050 672 672 bias,relu,clip -i s8 s8 r n n n p 480 672 2050 2050 672 672 bias,relu,clip -i s32 r n n n p 481 672 2050 2050 672 672 none -i s8 r n n n p 481 672 2050 2050 672 672 none -i s32 s8 r n n n p 481 672 2050 2050 672 672 bias,relu,clip -i s8 s8 r n n n p 481 672 2050 2050 672 672 bias,relu,clip -i s32 r n n n p 482 672 2050 2050 672 672 none -i s8 r n n n p 482 672 2050 2050 672 672 none -i s32 s8 r n n n p 482 672 2050 2050 672 672 bias,relu,clip -i s8 s8 r n n n p 482 672 2050 2050 672 672 bias,relu,clip -i s32 r n n n p 483 672 2050 2050 672 672 none -i s8 r n n n p 483 672 2050 2050 672 672 none -i s32 s8 r n n n p 483 672 2050 2050 672 672 bias,relu,clip -i s8 s8 r n n n p 483 672 2050 2050 672 672 bias,relu,clip -i s32 r n n n p 484 672 2050 2050 672 672 none -i s8 r n n n p 484 672 2050 2050 672 672 none -i s32 s8 r n n n p 484 672 2050 2050 672 672 bias,relu,clip -i s8 s8 r n n n p 484 672 2050 2050 672 672 bias,relu,clip -i s32 r n n n p 485 672 2050 2050 672 672 none -i s8 r n n n p 485 672 2050 2050 672 672 none -i s32 s8 r n n n p 485 672 2050 2050 672 672 bias,relu,clip -i s8 s8 r n n n p 485 672 2050 2050 672 672 bias,relu,clip -i s32 r n n n p 480 688 2050 2050 688 688 none -i s8 r n n n p 480 688 2050 2050 688 688 none -i s32 s8 r n n n p 480 688 2050 2050 688 688 bias,relu,clip -i s8 s8 r n n n p 480 688 2050 2050 688 688 bias,relu,clip -i s32 r n n n p 481 688 2050 2050 688 688 none -i s8 r n n n p 481 688 2050 2050 688 688 none -i s32 s8 r n n n p 481 688 2050 2050 688 688 bias,relu,clip -i s8 s8 r n n n p 481 688 2050 2050 688 688 bias,relu,clip -i s32 r n n n r 482 688 2050 2050 688 688 none -i s8 r n n n r 482 688 2050 2050 688 688 none -i s32 s8 r n n n r 482 688 2050 2050 688 688 bias,relu,clip -i s8 s8 r n n n r 482 688 2050 2050 688 688 bias,relu,clip -i s32 r n n n r 483 688 2050 2050 688 688 none -i s8 r n n n r 483 688 2050 2050 688 688 none -i s32 s8 r n n n r 483 688 2050 2050 688 688 bias,relu,clip -i s8 s8 r n n n r 483 688 2050 2050 688 688 bias,relu,clip -i s32 r n n n r 484 688 2050 2050 688 688 none -i s8 r n n n r 484 688 2050 2050 688 688 none -i s32 s8 r n n n r 484 688 2050 2050 688 688 bias,relu,clip -i s8 s8 r n n n r 484 688 2050 2050 688 688 bias,relu,clip -i s32 r n n n r 485 688 2050 2050 688 688 none -i s8 r n n n r 485 688 2050 2050 688 688 none -i s32 s8 r n n n r 485 688 2050 2050 688 688 bias,relu,clip -i s8 s8 r n n n r 485 688 2050 2050 688 688 bias,relu,clip -i s32 r n n n r 1024 512 64 64 512 512 none -i s8 r n n n r 1024 512 64 64 512 512 none -i s32 s8 r n n n r 1024 512 64 64 512 512 bias,relu,clip -i s8 s8 r n n n r 1024 512 64 64 512 512 bias,relu,clip -i s32 r n n n r 16 256 512 512 256 256 none -i s8 r n n n r 16 256 512 512 256 256 none -i s32 s8 r n n n r 16 256 512 512 256 256 bias,relu,clip -i s8 s8 r n n n r 16 256 512 512 256 256 bias,relu,clip -i s32 r n n n r 480 640 512 512 640 640 none -i s8 r n n n r 480 640 512 512 640 640 none -i s32 s8 r n n n r 480 640 512 512 640 640 bias,relu,clip -i s8 s8 r n n n r 480 640 512 512 640 640 bias,relu,clip -i s32 r n n n r 64 768 512 512 768 768 none -i s8 r n n n r 64 768 512 512 768 768 none -i s32 s8 r n n n r 64 768 512 512 768 768 bias,relu,clip -i s8 s8 r n n n r 64 768 512 512 768 768 bias,relu,clip -i s32 r n n n r 128 128 128 128 128 128 none -i s8 r n n n r 128 128 128 128 128 128 none -i s32 s8 r n n n r 128 128 128 128 128 128 bias,relu,clip -i s8 s8 r n n n r 128 128 128 128 128 128 bias,relu,clip -i s32 r n n n r 1024 64 512 512 64 64 none -i s8 r n n n r 1024 64 512 512 64 64 none -i s32 s8 r n n n r 1024 64 512 512 64 64 bias,relu,clip -i s8 s8 r n n n r 1024 64 512 512 64 64 bias,relu,clip -i s32 r n n n r 1024 256 32 32 256 256 none -i s8 r n n n r 1024 256 32 32 256 256 none -i s32 s8 r n n n r 1024 256 32 32 256 256 bias,relu,clip -i s8 s8 r n n n r 1024 256 32 32 256 256 bias,relu,clip -i s32 r n n n r 1024 512 64 64 512 512 none -i s8 r n n n r 1024 512 64 64 512 512 none -i s32 s8 r n n n r 1024 512 64 64 512 512 bias,relu,clip -i s8 s8 r n n n r 1024 512 64 64 512 512 bias,relu,clip -i s32 r n n n r 480 640 512 512 640 640 none -i s8 r n n n r 480 640 512 512 640 640 none -i s32 s8 r n n n r 480 640 512 512 640 640 bias,relu,clip -i s8 s8 r n n n r 480 640 512 512 640 640 bias,relu,clip -i s32 r n n n p 1024 32 256 256 32 32 none -i s8 r n n n p 1024 32 256 256 32 32 none -i s32 s8 r n n n p 1024 32 256 256 32 32 bias,relu,clip -i s8 s8 r n n n p 1024 32 256 256 32 32 bias,relu,clip -i s32 r n n n P 1024 64 512 512 64 64 none -i s8 r n n n P 1024 64 512 512 64 64 none -i s32 s8 r n n n P 1024 64 512 512 64 64 bias,relu,clip -i s8 s8 r n n n P 1024 64 512 512 64 64 bias,relu,clip -i s32 r n n n P 64 800 320 320 800 800 none -i s8 r n n n P 64 800 320 320 800 800 none -i s32 s8 r n n n P 64 800 320 320 800 800 bias,relu,clip -i s8 s8 r n n n P 64 800 320 320 800 800 bias,relu,clip -i s32 r n n n P 64 768 512 512 768 768 none -i s8 r n n n P 64 768 512 512 768 768 none -i s32 s8 r n n n P 64 768 512 512 768 768 bias,relu,clip -i s8 s8 r n n n P 64 768 512 512 768 768 bias,relu,clip -i s32 r n n n P 16 256 512 512 256 256 none -i s8 r n n n P 16 256 512 512 256 256 none -i s32 s8 r n n n P 16 256 512 512 256 256 bias,relu,clip -i s8 s8 r n n n P 16 256 512 512 256 256 bias,relu,clip -i s32 r n n n P 128 128 128 128 128 128 none -i s8 r n n n P 128 128 128 128 128 128 none -i s32 s8 r n n n P 128 128 128 128 128 128 bias,relu,clip -i s8 s8 r n n n P 128 128 128 128 128 128 bias,relu,clip -i s32 r n n n P 256 512 256 256 512 512 none -i s8 r n n n P 256 512 256 256 512 512 none -i s32 s8 r n n n P 256 512 256 256 512 512 bias,relu,clip -i s8 s8 r n n n P 256 512 256 256 512 512 bias,relu,clip -i s32 r n n n P 1024 1024 1024 1024 1024 1024 none -i s8 r n n n P 1024 1024 1024 1024 1024 1024 none -i s32 s8 r n n n P 1024 1024 1024 1024 1024 1024 bias,relu,clip -i s8 s8 r n n n P 1024 1024 1024 1024 1024 1024 bias,relu,clip -i s32 r n n n P 480 640 1024 1024 640 640 none -i s8 r n n n P 480 640 1024 1024 640 640 none -i s32 s8 r n n n P 480 640 1024 1024 640 640 bias,relu,clip -i s8 s8 r n n n P 480 640 1024 1024 640 640 bias,relu,clip -i s32 r n n n P 480 640 256 256 640 640 none -i s8 r n n n P 480 640 256 256 640 640 none -i s32 s8 r n n n P 480 640 256 256 640 640 bias,relu,clip -i s8 s8 r n n n P 480 640 256 256 640 640 bias,relu,clip -i s32 r n n n P 8 64 32 32 64 64 none -i s8 r n n n P 8 64 32 32 64 64 none -i s32 s8 r n n n P 8 64 32 32 64 64 bias,relu,clip -i s8 s8 r n n n P 8 64 32 32 64 64 bias,relu,clip -i s32 r n n n P 9 64 32 32 64 64 none -i s8 r n n n P 9 64 32 32 64 64 none -i s32 s8 r n n n P 9 64 32 32 64 64 bias,relu,clip -i s8 s8 r n n n P 9 64 32 32 64 64 bias,relu,clip -i s32 r n n n P 10 128 64 64 128 128 none -i s8 r n n n P 10 128 64 64 128 128 none -i s32 s8 r n n n P 10 128 64 64 128 128 bias,relu,clip -i s8 s8 r n n n P 10 128 64 64 128 128 bias,relu,clip -i s32 r n n n P 8 8 8 8 8 8 none -i s8 r n n n P 8 8 8 8 8 8 none -i s32 s8 r n n n P 8 8 8 8 8 8 bias,relu,clip -i s8 s8 r n n n P 8 8 8 8 8 8 bias,relu,clip -i s32 r n n n P 12 12 12 12 12 12 none -i s8 r n n n P 12 12 12 12 12 12 none -i s32 s8 r n n n P 12 12 12 12 12 12 bias,relu,clip -i s8 s8 r n n n P 12 12 12 12 12 12 bias,relu,clip -i s32 r n n n P 25 25 25 25 25 25 none -i s8 r n n n P 25 25 25 25 25 25 none -i s32 s8 r n n n P 25 25 25 25 25 25 bias,relu,clip -i s8 s8 r n n n P 25 25 25 25 25 25 bias,relu,clip -i s32 r n n n P 25 25 20 20 25 25 none -i s8 r n n n P 25 25 20 20 25 25 none -i s32 s8 r n n n P 25 25 20 20 25 25 bias,relu,clip -i s8 s8 r n n n P 25 25 20 20 25 25 bias,relu,clip -f f32 r n n n p 480 20 2050 2050 20 20 none -f f32 f32 r n n n p 480 20 2050 2050 20 20 bias,relu,clip -f f32 r n n n p 481 20 2050 2050 20 20 none -f f32 f32 r n n n p 481 20 2050 2050 20 20 bias,relu,clip -f f32 r n n n p 482 20 2050 2050 20 20 none -f f32 f32 r n n n p 482 20 2050 2050 20 20 bias,relu,clip -f f32 r n n n p 483 20 2050 2050 20 20 none -f f32 f32 r n n n p 483 20 2050 2050 20 20 bias,relu,clip -f f32 r n n n R 484 20 2050 2050 20 20 none -f f32 f32 r n n n R 484 20 2050 2050 20 20 bias,relu,clip -f f32 r n n n R 485 20 2050 2050 20 20 none -f f32 f32 r n n n R 485 20 2050 2050 20 20 bias,relu,clip -f f32 r n n n R 480 39 2050 2050 39 39 none -f f32 f32 r n n n R 480 39 2050 2050 39 39 bias,relu,clip -f f32 r n n n R 481 39 2050 2050 39 39 none -f f32 f32 r n n n R 481 39 2050 2050 39 39 bias,relu,clip -f f32 r n n n R 482 39 2050 2050 39 39 none -f f32 f32 r n n n R 482 39 2050 2050 39 39 bias,relu,clip -f f32 r n n n R 483 39 2050 2050 39 39 none -f f32 f32 r n n n R 483 39 2050 2050 39 39 bias,relu,clip -f f32 r n n n R 484 39 2050 2050 39 39 none -f f32 f32 r n n n R 484 39 2050 2050 39 39 bias,relu,clip -f f32 r n n n p 485 39 2050 2050 39 39 none -f f32 f32 r n n n p 485 39 2050 2050 39 39 bias,relu,clip -f f32 r n n n p 480 50 2050 2050 50 50 none -f f32 f32 r n n n p 480 50 2050 2050 50 50 bias,relu,clip -f f32 r n n n p 481 50 2050 2050 50 50 none -f f32 f32 r n n n p 481 50 2050 2050 50 50 bias,relu,clip -f f32 r n n n p 482 50 2050 2050 50 50 none -f f32 f32 r n n n p 482 50 2050 2050 50 50 bias,relu,clip -f f32 r n n n p 483 50 2050 2050 50 50 none -f f32 f32 r n n n p 483 50 2050 2050 50 50 bias,relu,clip -f f32 r n n n p 484 50 2050 2050 50 50 none -f f32 f32 r n n n p 484 50 2050 2050 50 50 bias,relu,clip -f f32 r n n n p 485 50 2050 2050 50 50 none -f f32 f32 r n n n p 485 50 2050 2050 50 50 bias,relu,clip -f f32 r n n n R 480 1108 2050 2050 1108 1108 none -f f32 f32 r n n n R 480 1108 2050 2050 1108 1108 bias,relu,clip -f f32 r n n n R 481 1108 2050 2050 1108 1108 none -f f32 f32 r n n n R 481 1108 2050 2050 1108 1108 bias,relu,clip -f f32 r n n n R 482 1108 2050 2050 1108 1108 none -f f32 f32 r n n n R 482 1108 2050 2050 1108 1108 bias,relu,clip -f f32 r n n n R 483 1108 2050 2050 1108 1108 none -f f32 f32 r n n n R 483 1108 2050 2050 1108 1108 bias,relu,clip -f f32 r n n n R 484 1108 2050 2050 1108 1108 none -f f32 f32 r n n n R 484 1108 2050 2050 1108 1108 bias,relu,clip -f f32 r n n n R 485 1108 2050 2050 1108 1108 none -f f32 f32 r n n n R 485 1108 2050 2050 1108 1108 bias,relu,clip -f f32 r n n n R 480 1127 2050 2050 1127 1127 none -f f32 f32 r n n n R 480 1127 2050 2050 1127 1127 bias,relu,clip -f f32 r n n n R 481 1127 2050 2050 1127 1127 none -f f32 f32 r n n n R 481 1127 2050 2050 1127 1127 bias,relu,clip -f f32 r n n n R 482 1127 2050 2050 1127 1127 none -f f32 f32 r n n n R 482 1127 2050 2050 1127 1127 bias,relu,clip -f f32 r n n n R 483 1127 2050 2050 1127 1127 none -f f32 f32 r n n n R 483 1127 2050 2050 1127 1127 bias,relu,clip -f f32 r n n n p 484 1127 2050 2050 1127 1127 none -f f32 f32 r n n n p 484 1127 2050 2050 1127 1127 bias,relu,clip -f f32 r n n n p 485 1127 2050 2050 1127 1127 none -f f32 f32 r n n n p 485 1127 2050 2050 1127 1127 bias,relu,clip -f f32 r n n n p 480 1138 2050 2050 1138 1138 none -f f32 f32 r n n n p 480 1138 2050 2050 1138 1138 bias,relu,clip -f f32 r n n n p 481 1138 2050 2050 1138 1138 none -f f32 f32 r n n n p 481 1138 2050 2050 1138 1138 bias,relu,clip -f f32 r n n n p 482 1138 2050 2050 1138 1138 none -f f32 f32 r n n n p 482 1138 2050 2050 1138 1138 bias,relu,clip -f f32 r n n n p 483 1138 2050 2050 1138 1138 none -f f32 f32 r n n n p 483 1138 2050 2050 1138 1138 bias,relu,clip -f f32 r n n n p 484 1138 2050 2050 1138 1138 none -f f32 f32 r n n n p 484 1138 2050 2050 1138 1138 bias,relu,clip -f f32 r n n n p 485 1138 2050 2050 1138 1138 none -f f32 f32 r n n n p 485 1138 2050 2050 1138 1138 bias,relu,clip -f f32 r n n n p 1 1 3 3 1 1 none -f f32 f32 r n n n p 1 1 3 3 1 1 bias,relu,clip -f f32 r n n n p 1 9 3 3 9 9 none -f f32 f32 r n n n p 1 9 3 3 9 9 bias,relu,clip -f f32 r n n n p 1 2048 3 3 2048 2048 none -f f32 f32 r n n n p 1 2048 3 3 2048 2048 bias,relu,clip -f f32 r n n n p 1 2048 5192 5192 2048 2048 none -f f32 f32 r n n n p 1 2048 5192 5192 2048 2048 bias,relu,clip -f f32 r n n n p 9 1 3 3 1 1 none -f f32 f32 r n n n p 9 1 3 3 1 1 bias,relu,clip -f f32 r n n n p 576 1 3500 3500 1 1 none -f f32 f32 r n n n p 576 1 3500 3500 1 1 bias,relu,clip -f f32 r n n n p 1 1 1 1 1 1 none -f f32 f32 r n n n p 1 1 1 1 1 1 bias,relu,clip -f f32 r n n n p 102 1088 1024 1024 1088 1088 none -f f32 f32 r n n n p 102 1088 1024 1024 1088 1088 bias,relu,clip -f f32 r n n n p 102 2048 1024 1024 2048 2048 none -f f32 f32 r n n n p 102 2048 1024 1024 2048 2048 bias,relu,clip -f f32 r n n n p 485 656 1024 1024 656 656 none -f f32 f32 r n n n p 485 656 1024 1024 656 656 bias,relu,clip -f f32 r n n n p 483 656 1024 1024 656 656 none -f f32 f32 r n n n p 483 656 1024 1024 656 656 bias,relu,clip -f f32 r n n n p 81 128 3 3 128 128 none -f f32 f32 r n n n p 81 128 3 3 128 128 bias,relu,clip -f f32 r n n n p 1022 512 515 515 512 512 none -f f32 f32 r n n n p 1022 512 515 515 512 512 bias,relu,clip -f f32 r n n n p 74 512 515 515 512 512 none -f f32 f32 r n n n p 74 512 515 515 512 512 bias,relu,clip -f f32 r n n n p 253 2048 515 515 2048 2048 none -f f32 f32 r n n n p 253 2048 515 515 2048 2048 bias,relu,clip -f f32 r n n n p 8192 1040 515 515 1040 1040 none -f f32 f32 r n n n p 8192 1040 515 515 1040 1040 bias,relu,clip -f f32 r n n n p 10 1029 515 515 1029 1029 none -f f32 f32 r n n n p 10 1029 515 515 1029 1029 bias,relu,clip -f f32 r n n n p 24 1040 2050 2050 1040 1040 none -f f32 f32 r n n n p 24 1040 2050 2050 1040 1040 bias,relu,clip -f f32 r n n n p 1024 1029 2050 2050 1029 1029 none -f f32 f32 r n n n p 1024 1029 2050 2050 1029 1029 bias,relu,clip -f f32 r n n n p 480 660 2050 2050 660 660 none -f f32 f32 r n n n p 480 660 2050 2050 660 660 bias,relu,clip -f f32 r n n n p 481 660 2050 2050 660 660 none -f f32 f32 r n n n p 481 660 2050 2050 660 660 bias,relu,clip -f f32 r n n n p 482 660 2050 2050 660 660 none -f f32 f32 r n n n p 482 660 2050 2050 660 660 bias,relu,clip -f f32 r n n n p 483 660 2050 2050 660 660 none -f f32 f32 r n n n p 483 660 2050 2050 660 660 bias,relu,clip -f f32 r n n n p 484 660 2050 2050 660 660 none -f f32 f32 r n n n p 484 660 2050 2050 660 660 bias,relu,clip -f f32 r n n n p 485 660 2050 2050 660 660 none -f f32 f32 r n n n p 485 660 2050 2050 660 660 bias,relu,clip -f f32 r n n n p 480 679 2050 2050 679 679 none -f f32 f32 r n n n p 480 679 2050 2050 679 679 bias,relu,clip -f f32 r n n n p 481 679 2050 2050 679 679 none -f f32 f32 r n n n p 481 679 2050 2050 679 679 bias,relu,clip -f f32 r n n n p 482 679 2050 2050 679 679 none -f f32 f32 r n n n p 482 679 2050 2050 679 679 bias,relu,clip -f f32 r n n n p 483 679 2050 2050 679 679 none -f f32 f32 r n n n p 483 679 2050 2050 679 679 bias,relu,clip -f f32 r n n n p 484 679 2050 2050 679 679 none -f f32 f32 r n n n p 484 679 2050 2050 679 679 bias,relu,clip -f f32 r n n n p 485 679 2050 2050 679 679 none -f f32 f32 r n n n p 485 679 2050 2050 679 679 bias,relu,clip -f f32 r n n n p 480 690 2050 2050 690 690 none -f f32 f32 r n n n p 480 690 2050 2050 690 690 bias,relu,clip -f f32 r n n n p 481 690 2050 2050 690 690 none -f f32 f32 r n n n p 481 690 2050 2050 690 690 bias,relu,clip -f f32 r n n n p 482 690 2050 2050 690 690 none -f f32 f32 r n n n p 482 690 2050 2050 690 690 bias,relu,clip -f f32 r n n n p 483 690 2050 2050 690 690 none -f f32 f32 r n n n p 483 690 2050 2050 690 690 bias,relu,clip -f f32 r n n n p 484 690 2050 2050 690 690 none -f f32 f32 r n n n p 484 690 2050 2050 690 690 bias,relu,clip -f f32 r n n n p 485 690 2050 2050 690 690 none -f f32 f32 r n n n p 485 690 2050 2050 690 690 bias,relu,clip -f f32 r n n n p 480 660 2048 2048 660 660 none -f f32 f32 r n n n p 480 660 2048 2048 660 660 bias,relu,clip -f f32 r n n n p 481 660 2048 2048 660 660 none -f f32 f32 r n n n p 481 660 2048 2048 660 660 bias,relu,clip -f f32 r n n n p 482 660 2048 2048 660 660 none -f f32 f32 r n n n p 482 660 2048 2048 660 660 bias,relu,clip -f f32 r n n n p 483 660 2048 2048 660 660 none -f f32 f32 r n n n p 483 660 2048 2048 660 660 bias,relu,clip -f f32 r n n n p 484 660 2048 2048 660 660 none -f f32 f32 r n n n p 484 660 2048 2048 660 660 bias,relu,clip -f f32 r n n n p 485 660 2048 2048 660 660 none -f f32 f32 r n n n p 485 660 2048 2048 660 660 bias,relu,clip -f f32 r n n n p 480 679 2048 2048 679 679 none -f f32 f32 r n n n p 480 679 2048 2048 679 679 bias,relu,clip -f f32 r n n n p 481 679 2048 2048 679 679 none -f f32 f32 r n n n p 481 679 2048 2048 679 679 bias,relu,clip -f f32 r n n n p 482 679 2048 2048 679 679 none -f f32 f32 r n n n p 482 679 2048 2048 679 679 bias,relu,clip -f f32 r n n n p 483 679 2048 2048 679 679 none -f f32 f32 r n n n p 483 679 2048 2048 679 679 bias,relu,clip -f f32 r n n n p 484 679 2048 2048 679 679 none -f f32 f32 r n n n p 484 679 2048 2048 679 679 bias,relu,clip -f f32 r n n n p 485 679 2048 2048 679 679 none -f f32 f32 r n n n p 485 679 2048 2048 679 679 bias,relu,clip -f f32 r n n n p 480 690 2048 2048 690 690 none -f f32 f32 r n n n p 480 690 2048 2048 690 690 bias,relu,clip -f f32 r n n n p 481 690 2048 2048 690 690 none -f f32 f32 r n n n p 481 690 2048 2048 690 690 bias,relu,clip -f f32 r n n n p 482 690 2048 2048 690 690 none -f f32 f32 r n n n p 482 690 2048 2048 690 690 bias,relu,clip -f f32 r n n n p 483 690 2048 2048 690 690 none -f f32 f32 r n n n p 483 690 2048 2048 690 690 bias,relu,clip -f f32 r n n n p 484 690 2048 2048 690 690 none -f f32 f32 r n n n p 484 690 2048 2048 690 690 bias,relu,clip -f f32 r n n n p 485 690 2048 2048 690 690 none -f f32 f32 r n n n p 485 690 2048 2048 690 690 bias,relu,clip -f f32 r n n n p 480 656 1024 1024 656 656 none -f f32 f32 r n n n p 480 656 1024 1024 656 656 bias,relu,clip -f f32 r n n n p 480 128 3 3 128 128 none -f f32 f32 r n n n p 480 128 3 3 128 128 bias,relu,clip -f f32 r n n n p 1024 512 515 515 512 512 none -f f32 f32 r n n n p 1024 512 515 515 512 512 bias,relu,clip -f f32 r n n n p 1024 2048 1024 1024 2048 2048 none -f f32 f32 r n n n p 1024 2048 1024 1024 2048 2048 bias,relu,clip -f f32 r n n n p 1024 2048 515 515 2048 2048 none -f f32 f32 r n n n p 1024 2048 515 515 2048 2048 bias,relu,clip -f f32 r n n n p 1024 1040 515 515 1040 1040 none -f f32 f32 r n n n p 1024 1040 515 515 1040 1040 bias,relu,clip -f f32 r n n n p 5 1029 515 515 1029 1029 none -f f32 f32 r n n n p 5 1029 515 515 1029 1029 bias,relu,clip -f f32 r n n n p 1024 1029 515 515 1029 1029 none -f f32 f32 r n n n p 1024 1029 515 515 1029 1029 bias,relu,clip -f f32 r n n n p 1024 1040 2050 2050 1040 1040 none -f f32 f32 r n n n p 1024 1040 2050 2050 1040 1040 bias,relu,clip -f f32 r n n n p 1029 1029 2050 2050 1029 1029 none -f f32 f32 r n n n p 1029 1029 2050 2050 1029 1029 bias,relu,clip -f f32 r n n n R 480 646 2050 2050 646 646 none -f f32 f32 r n n n R 480 646 2050 2050 646 646 bias,relu,clip -f f32 r n n n R 481 646 2050 2050 646 646 none -f f32 f32 r n n n R 481 646 2050 2050 646 646 bias,relu,clip -f f32 r n n n R 482 646 2050 2050 646 646 none -f f32 f32 r n n n R 482 646 2050 2050 646 646 bias,relu,clip -f f32 r n n n R 483 646 2050 2050 646 646 none -f f32 f32 r n n n R 483 646 2050 2050 646 646 bias,relu,clip -f f32 r n n n R 484 646 2050 2050 646 646 none -f f32 f32 r n n n R 484 646 2050 2050 646 646 bias,relu,clip -f f32 r n n n R 485 646 2050 2050 646 646 none -f f32 f32 r n n n R 485 646 2050 2050 646 646 bias,relu,clip -f f32 r n n n R 481 656 2050 2050 656 656 none -f f32 f32 r n n n R 481 656 2050 2050 656 656 bias,relu,clip -f f32 r n n n R 482 656 2050 2050 656 656 none -f f32 f32 r n n n R 482 656 2050 2050 656 656 bias,relu,clip -f f32 r n n n R 483 656 2050 2050 656 656 none -f f32 f32 r n n n R 483 656 2050 2050 656 656 bias,relu,clip -f f32 r n n n R 484 656 2050 2050 656 656 none -f f32 f32 r n n n R 484 656 2050 2050 656 656 bias,relu,clip -f f32 r n n n p 485 656 2050 2050 656 656 none -f f32 f32 r n n n p 485 656 2050 2050 656 656 bias,relu,clip -f f32 r n n n p 480 672 2050 2050 672 672 none -f f32 f32 r n n n p 480 672 2050 2050 672 672 bias,relu,clip -f f32 r n n n p 481 672 2050 2050 672 672 none -f f32 f32 r n n n p 481 672 2050 2050 672 672 bias,relu,clip -f f32 r n n n p 482 672 2050 2050 672 672 none -f f32 f32 r n n n p 482 672 2050 2050 672 672 bias,relu,clip -f f32 r n n n p 483 672 2050 2050 672 672 none -f f32 f32 r n n n p 483 672 2050 2050 672 672 bias,relu,clip -f f32 r n n n p 484 672 2050 2050 672 672 none -f f32 f32 r n n n p 484 672 2050 2050 672 672 bias,relu,clip -f f32 r n n n p 485 672 2050 2050 672 672 none -f f32 f32 r n n n p 485 672 2050 2050 672 672 bias,relu,clip -f f32 r n n n p 480 688 2050 2050 688 688 none -f f32 f32 r n n n p 480 688 2050 2050 688 688 bias,relu,clip -f f32 r n n n p 481 688 2050 2050 688 688 none -f f32 f32 r n n n p 481 688 2050 2050 688 688 bias,relu,clip -f f32 r n n n r 482 688 2050 2050 688 688 none -f f32 f32 r n n n r 482 688 2050 2050 688 688 bias,relu,clip -f f32 r n n n r 483 688 2050 2050 688 688 none -f f32 f32 r n n n r 483 688 2050 2050 688 688 bias,relu,clip -f f32 r n n n r 484 688 2050 2050 688 688 none -f f32 f32 r n n n r 484 688 2050 2050 688 688 bias,relu,clip -f f32 r n n n r 485 688 2050 2050 688 688 none -f f32 f32 r n n n r 485 688 2050 2050 688 688 bias,relu,clip -f f32 r n n n r 1024 512 64 64 512 512 none -f f32 f32 r n n n r 1024 512 64 64 512 512 bias,relu,clip -f f32 r n n n r 16 256 512 512 256 256 none -f f32 f32 r n n n r 16 256 512 512 256 256 bias,relu,clip -f f32 r n n n r 480 640 512 512 640 640 none -f f32 f32 r n n n r 480 640 512 512 640 640 bias,relu,clip -f f32 r n n n r 64 768 512 512 768 768 none -f f32 f32 r n n n r 64 768 512 512 768 768 bias,relu,clip -f f32 r n n n r 128 128 128 128 128 128 none -f f32 f32 r n n n r 128 128 128 128 128 128 bias,relu,clip -f f32 r n n n r 1024 64 512 512 64 64 none -f f32 f32 r n n n r 1024 64 512 512 64 64 bias,relu,clip -f f32 r n n n r 1024 256 32 32 256 256 none -f f32 f32 r n n n r 1024 256 32 32 256 256 bias,relu,clip -f f32 r n n n r 1024 512 64 64 512 512 none -f f32 f32 r n n n r 1024 512 64 64 512 512 bias,relu,clip -f f32 r n n n r 480 640 512 512 640 640 none -f f32 f32 r n n n r 480 640 512 512 640 640 bias,relu,clip -f f32 r n n n p 1024 32 256 256 32 32 none -f f32 f32 r n n n p 1024 32 256 256 32 32 bias,relu,clip -f f32 r n n n P 1024 64 512 512 64 64 none -f f32 f32 r n n n P 1024 64 512 512 64 64 bias,relu,clip -f f32 r n n n P 64 800 320 320 800 800 none -f f32 f32 r n n n P 64 800 320 320 800 800 bias,relu,clip -f f32 r n n n P 64 768 512 512 768 768 none -f f32 f32 r n n n P 64 768 512 512 768 768 bias,relu,clip -f f32 r n n n P 16 256 512 512 256 256 none -f f32 f32 r n n n P 16 256 512 512 256 256 bias,relu,clip -f f32 r n n n P 128 128 128 128 128 128 none -f f32 f32 r n n n P 128 128 128 128 128 128 bias,relu,clip -f f32 r n n n P 256 512 256 256 512 512 none -f f32 f32 r n n n P 256 512 256 256 512 512 bias,relu,clip -f f32 r n n n P 1024 1024 1024 1024 1024 1024 none -f f32 f32 r n n n P 1024 1024 1024 1024 1024 1024 bias,relu,clip -f f32 r n n n P 480 640 1024 1024 640 640 none -f f32 f32 r n n n P 480 640 1024 1024 640 640 bias,relu,clip -f f32 r n n n P 480 640 256 256 640 640 none -f f32 f32 r n n n P 480 640 256 256 640 640 bias,relu,clip -f f32 r n n n P 8 64 32 32 64 64 none -f f32 f32 r n n n P 8 64 32 32 64 64 bias,relu,clip -f f32 r n n n P 9 64 32 32 64 64 none -f f32 f32 r n n n P 9 64 32 32 64 64 bias,relu,clip -f f32 r n n n P 10 128 64 64 128 128 none -f f32 f32 r n n n P 10 128 64 64 128 128 bias,relu,clip -f f32 r n n n P 8 8 8 8 8 8 none -f f32 f32 r n n n P 8 8 8 8 8 8 bias,relu,clip -f f32 r n n n P 12 12 12 12 12 12 none -f f32 f32 r n n n P 12 12 12 12 12 12 bias,relu,clip -f f32 r n n n P 25 25 25 25 25 25 none -f f32 f32 r n n n P 25 25 25 25 25 25 bias,relu,clip -f f32 r n n n P 25 25 20 20 25 25 none -f f32 f32 r n n n P 25 25 20 20 25 25 bias,relu,clip -i s32 r n n n r 4096 256 5 5 256 256 none -i s8 r n n n r 4096 256 5 5 256 256 none -i s32 s8 r n n n r 4096 256 5 5 256 256 bias,relu,clip -i s8 s8 r n n n r 4096 256 5 5 256 256 bias,relu,clip -i s32 r n n n r 3000 256 128 128 256 256 none -i s8 r n n n r 3000 256 128 128 256 256 none -i s32 s8 r n n n r 3000 256 128 128 256 256 bias,relu,clip -i s8 s8 r n n n r 3000 256 128 128 256 256 bias,relu,clip -i s32 r n n n r 4096 1024 512 512 1024 1024 none -i s8 r n n n r 4096 1024 512 512 1024 1024 none -i s32 s8 r n n n r 4096 1024 512 512 1024 1024 bias,relu,clip -i s8 s8 r n n n r 4096 1024 512 512 1024 1024 bias,relu,clip -i s32 r n n n r 144 256 5 5 256 256 none -i s8 r n n n r 144 256 5 5 256 256 none -i s32 s8 r n n n r 144 256 5 5 256 256 bias,relu,clip -i s8 s8 r n n n r 144 256 5 5 256 256 bias,relu,clip -i s32 r n n n r 144 256 128 128 256 256 none -i s8 r n n n r 144 256 128 128 256 256 none -i s32 s8 r n n n r 144 256 128 128 256 256 bias,relu,clip -i s8 s8 r n n n r 144 256 128 128 256 256 bias,relu,clip -i s32 r n n n r 144 1024 512 512 1024 1024 none -i s8 r n n n r 144 1024 512 512 1024 1024 none -i s32 s8 r n n n r 144 1024 512 512 1024 1024 bias,relu,clip -i s8 s8 r n n n r 144 1024 512 512 1024 1024 bias,relu,clip -i s32 r n n n r 480 688 256 256 688 688 none -i s8 r n n n r 480 688 256 256 688 688 none -i s32 s8 r n n n r 480 688 256 256 688 688 bias,relu,clip -i s8 s8 r n n n r 480 688 256 256 688 688 bias,relu,clip -i s32 r n n n r 480 640 512 512 640 640 none -i s8 r n n n r 480 640 512 512 640 640 none -i s32 s8 r n n n r 480 640 512 512 640 640 bias,relu,clip -i s8 s8 r n n n r 480 640 512 512 640 640 bias,relu,clip -i s32 r n n n r 480 640 1024 1024 640 640 none -i s8 r n n n r 480 640 1024 1024 640 640 none -i s32 s8 r n n n r 480 640 1024 1024 640 640 bias,relu,clip -i s8 s8 r n n n r 480 640 1024 1024 640 640 bias,relu,clip -i s32 r n n n r 64 800 320 320 800 800 none -i s8 r n n n r 64 800 320 320 800 800 none -i s32 s8 r n n n r 64 800 320 320 800 800 bias,relu,clip -i s8 s8 r n n n r 64 800 320 320 800 800 bias,relu,clip -i s32 r n n n r 64 768 512 512 768 768 none -i s8 r n n n r 64 768 512 512 768 768 none -i s32 s8 r n n n r 64 768 512 512 768 768 bias,relu,clip -i s8 s8 r n n n r 64 768 512 512 768 768 bias,relu,clip -i s32 r n n n r 16 256 512 512 256 256 none -i s8 r n n n r 16 256 512 512 256 256 none -i s32 s8 r n n n r 16 256 512 512 256 256 bias,relu,clip -i s8 s8 r n n n r 16 256 512 512 256 256 bias,relu,clip -i s32 r n n n r 128 128 128 128 128 128 none -i s8 r n n n r 128 128 128 128 128 128 none -i s32 s8 r n n n r 128 128 128 128 128 128 bias,relu,clip -i s8 s8 r n n n r 128 128 128 128 128 128 bias,relu,clip -i s32 r n n n r 256 512 256 256 512 512 none -i s8 r n n n r 256 512 256 256 512 512 none -i s32 s8 r n n n r 256 512 256 256 512 512 bias,relu,clip -i s8 s8 r n n n r 256 512 256 256 512 512 bias,relu,clip -i s32 r n n n r 1024 1024 1024 1024 1024 1024 none -i s8 r n n n r 1024 1024 1024 1024 1024 1024 none -i s32 s8 r n n n r 1024 1024 1024 1024 1024 1024 bias,relu,clip -i s8 s8 r n n n r 1024 1024 1024 1024 1024 1024 bias,relu,clip -i s32 r n n n r 1024 32 256 256 32 32 none -i s8 r n n n r 1024 32 256 256 32 32 none -i s32 s8 r n n n r 1024 32 256 256 32 32 bias,relu,clip -i s8 s8 r n n n r 1024 32 256 256 32 32 bias,relu,clip -i s32 r n n n r 1024 64 512 512 64 64 none -i s8 r n n n r 1024 64 512 512 64 64 none -i s32 s8 r n n n r 1024 64 512 512 64 64 bias,relu,clip -i s8 s8 r n n n r 1024 64 512 512 64 64 bias,relu,clip -i s32 r n n n r 1024 256 32 32 256 256 none -i s8 r n n n r 1024 256 32 32 256 256 none -i s32 s8 r n n n r 1024 256 32 32 256 256 bias,relu,clip -i s8 s8 r n n n r 1024 256 32 32 256 256 bias,relu,clip -i s32 r n n n r 1024 512 64 64 512 512 none -i s8 r n n n r 1024 512 64 64 512 512 none -i s32 s8 r n n n r 1024 512 64 64 512 512 bias,relu,clip -i s8 s8 r n n n r 1024 512 64 64 512 512 bias,relu,clip -i s32 r n n n r 512 32 256 256 32 32 none -i s8 r n n n r 512 32 256 256 32 32 none -i s32 s8 r n n n r 512 32 256 256 32 32 bias,relu,clip -i s8 s8 r n n n r 512 32 256 256 32 32 bias,relu,clip -i s32 r n n n r 512 768 512 512 768 768 none -i s8 r n n n r 512 768 512 512 768 768 none -i s32 s8 r n n n r 512 768 512 512 768 768 bias,relu,clip -i s8 s8 r n n n r 512 768 512 512 768 768 bias,relu,clip -i s32 r n n n r 512 256 32 32 256 256 none -i s8 r n n n r 512 256 32 32 256 256 none -i s32 s8 r n n n r 512 256 32 32 256 256 bias,relu,clip -i s8 s8 r n n n r 512 256 32 32 256 256 bias,relu,clip -i s32 r n n n r 512 512 64 64 512 512 none -i s8 r n n n r 512 512 64 64 512 512 none -i s32 s8 r n n n r 512 512 64 64 512 512 bias,relu,clip -i s8 s8 r n n n r 512 512 64 64 512 512 bias,relu,clip -i s32 r n n n r 512 256 768 768 256 256 none -i s8 r n n n r 512 256 768 768 256 256 none -i s32 s8 r n n n r 512 256 768 768 256 256 bias,relu,clip -i s8 s8 r n n n r 512 256 768 768 256 256 bias,relu,clip -i s32 r n n n r 768 768 1024 1024 768 768 none -i s8 r n n n r 768 768 1024 1024 768 768 none -i s32 s8 r n n n r 768 768 1024 1024 768 768 bias,relu,clip -i s8 s8 r n n n r 768 768 1024 1024 768 768 bias,relu,clip -i s32 r n n n r 768 768 768 768 768 768 none -i s8 r n n n r 768 768 768 768 768 768 none -i s32 s8 r n n n r 768 768 768 768 768 768 bias,relu,clip -i s8 s8 r n n n r 768 768 768 768 768 768 bias,relu,clip -i s32 r n n n r 2048 2048 2048 2048 2048 2048 none -i s8 r n n n r 2048 2048 2048 2048 2048 2048 none -i s32 s8 r n n n r 2048 2048 2048 2048 2048 2048 bias,relu,clip -i s8 s8 r n n n r 2048 2048 2048 2048 2048 2048 bias,relu,clip -i s32 r n n n r 4096 4096 4096 4096 4096 4096 none -i s8 r n n n r 4096 4096 4096 4096 4096 4096 none -i s32 s8 r n n n r 4096 4096 4096 4096 4096 4096 bias,relu,clip -i s8 s8 r n n n r 4096 4096 4096 4096 4096 4096 bias,relu,clip -f f32 r n n n r 4096 256 5 5 256 256 none -f f32 f32 r n n n r 4096 256 5 5 256 256 bias,relu,clip -f f32 r n n n r 3000 256 128 128 256 256 none -f f32 f32 r n n n r 3000 256 128 128 256 256 bias,relu,clip -f f32 r n n n r 4096 1024 512 512 1024 1024 none -f f32 f32 r n n n r 4096 1024 512 512 1024 1024 bias,relu,clip -f f32 r n n n r 144 256 5 5 256 256 none -f f32 f32 r n n n r 144 256 5 5 256 256 bias,relu,clip -f f32 r n n n r 144 256 128 128 256 256 none -f f32 f32 r n n n r 144 256 128 128 256 256 bias,relu,clip -f f32 r n n n r 144 1024 512 512 1024 1024 none -f f32 f32 r n n n r 144 1024 512 512 1024 1024 bias,relu,clip -f f32 r n n n r 480 688 256 256 688 688 none -f f32 f32 r n n n r 480 688 256 256 688 688 bias,relu,clip -f f32 r n n n r 480 640 512 512 640 640 none -f f32 f32 r n n n r 480 640 512 512 640 640 bias,relu,clip -f f32 r n n n r 480 640 1024 1024 640 640 none -f f32 f32 r n n n r 480 640 1024 1024 640 640 bias,relu,clip -f f32 r n n n r 64 800 320 320 800 800 none -f f32 f32 r n n n r 64 800 320 320 800 800 bias,relu,clip -f f32 r n n n r 64 768 512 512 768 768 none -f f32 f32 r n n n r 64 768 512 512 768 768 bias,relu,clip -f f32 r n n n r 16 256 512 512 256 256 none -f f32 f32 r n n n r 16 256 512 512 256 256 bias,relu,clip -f f32 r n n n r 128 128 128 128 128 128 none -f f32 f32 r n n n r 128 128 128 128 128 128 bias,relu,clip -f f32 r n n n r 256 512 256 256 512 512 none -f f32 f32 r n n n r 256 512 256 256 512 512 bias,relu,clip -f f32 r n n n r 1024 1024 1024 1024 1024 1024 none -f f32 f32 r n n n r 1024 1024 1024 1024 1024 1024 bias,relu,clip -f f32 r n n n r 1024 32 256 256 32 32 none -f f32 f32 r n n n r 1024 32 256 256 32 32 bias,relu,clip -f f32 r n n n r 1024 64 512 512 64 64 none -f f32 f32 r n n n r 1024 64 512 512 64 64 bias,relu,clip -f f32 r n n n r 1024 256 32 32 256 256 none -f f32 f32 r n n n r 1024 256 32 32 256 256 bias,relu,clip -f f32 r n n n r 1024 512 64 64 512 512 none -f f32 f32 r n n n r 1024 512 64 64 512 512 bias,relu,clip -f f32 r n n n r 512 32 256 256 32 32 none -f f32 f32 r n n n r 512 32 256 256 32 32 bias,relu,clip -f f32 r n n n r 512 768 512 512 768 768 none -f f32 f32 r n n n r 512 768 512 512 768 768 bias,relu,clip -f f32 r n n n r 512 256 32 32 256 256 none -f f32 f32 r n n n r 512 256 32 32 256 256 bias,relu,clip -f f32 r n n n r 512 512 64 64 512 512 none -f f32 f32 r n n n r 512 512 64 64 512 512 bias,relu,clip -f f32 r n n n r 512 256 768 768 256 256 none -f f32 f32 r n n n r 512 256 768 768 256 256 bias,relu,clip -f f32 r n n n r 768 768 1024 1024 768 768 none -f f32 f32 r n n n r 768 768 1024 1024 768 768 bias,relu,clip -f f32 r n n n r 768 768 768 768 768 768 none -f f32 f32 r n n n r 768 768 768 768 768 768 bias,relu,clip -f f32 r n n n r 2048 2048 2048 2048 2048 2048 none -f f32 f32 r n n n r 2048 2048 2048 2048 2048 2048 bias,relu,clip -f f32 r n n n r 4096 4096 4096 4096 4096 4096 none -f f32 f32 r n n n r 4096 4096 4096 4096 4096 4096 bias,relu,clip -f f32 r n n n r 2048 1024 1024 1024 1024 1024 none -f f32 f32 r n n n r 2048 1024 1024 1024 1024 1024 bias,relu,clip -f f32 r n n n r 2048 4096 1024 1024 4096 4096 none -f f32 f32 r n n n r 2048 4096 1024 1024 4096 4096 bias,relu,clip -f f32 r n n n r 2048 1024 4096 4096 1024 1024 none -f f32 f32 r n n n r 2048 1024 4096 4096 1024 1024 bias,relu,clip -f f32 r n n n r 2048 1024 2 2 1024 1024 none -f f32 f32 r n n n r 2048 1024 2 2 1024 1024 bias,relu,clip -f f32 r n n n r 128 1024 1024 1024 1024 1024 none -f f32 f32 r n n n r 128 1024 1024 1024 1024 1024 bias,relu,clip -f f32 r n n n r 1536 768 768 768 768 768 none -f f32 f32 r n n n r 1536 768 768 768 768 768 bias,relu,clip -f f32 r n n n r 1536 3072 768 768 3072 3072 none -f f32 f32 r n n n r 1536 3072 768 768 3072 3072 bias,relu,clip -f f32 r n n n r 1536 768 3072 3072 768 768 none -f f32 f32 r n n n r 1536 768 3072 3072 768 768 bias,relu,clip -f f32 r n n n r 1536 768 2 2 768 768 none -f f32 f32 r n n n r 1536 768 2 2 768 768 bias,relu,clip -f f32 r n n n r 128 768 768 768 768 768 none -f f32 f32 r n n n r 128 768 768 768 768 768 bias,relu,clip -f f32 r n n n r 1024 8 13 13 8 8 none -f f32 f32 r n n n r 1024 8 13 13 8 8 bias,relu,clip -f f32 r n n n r 1024 4 8 8 4 4 none -f f32 f32 r n n n r 1024 4 8 8 4 4 bias,relu,clip -f f32 r n n n r 1024 128 355 355 128 128 none -f f32 f32 r n n n r 1024 128 355 355 128 128 bias,relu,clip -f f32 r n n n r 1024 64 128 128 64 64 none -f f32 f32 r n n n r 1024 64 128 128 64 64 bias,relu,clip -f f32 r n n n r 1024 1 64 64 1 1 none -f f32 f32 r n n n r 1024 1 64 64 1 1 bias,relu,clip -f f32 r n n n r 480 1 256 256 1 1 none -f f32 f32 r n n n r 480 1 256 256 1 1 bias,relu,clip -f f32 r n n n r 480 256 512 512 256 256 none -f f32 f32 r n n n r 480 256 512 512 256 256 bias,relu,clip -f f32 r n n n r 480 1024 845 845 1024 1024 none -f f32 f32 r n n n r 480 1024 845 845 1024 1024 bias,relu,clip -f f32 r n n n r 480 512 1024 1024 512 512 none -f f32 f32 r n n n r 480 512 1024 1024 512 512 bias,relu,clip -f f32 r n n n r 10 17191 128 128 17191 17191 none -f f32 f32 r n n n r 10 17191 128 128 17191 17191 bias,relu,clip -f f32 r n n n r 10 512 256 256 512 512 none -f f32 f32 r n n n r 10 512 256 256 512 512 bias,relu,clip +r n t n r 288 12 6460 6460 6460 12 bf16s4f32of32:none +r n t n r 150 2048 6460 6460 6460 2048 bf16s4f32of32:none +r n n n r 1 10 2050 2050 20 20 bf16bf16f32obf16:none +r n n n r 482 690 2050 2050 690 690 f32f32f32of32:bias,matrix_mul +r n n n r 253 2048 660 660 2048 2048 bf16bf16f32of32:matrix_mul,clip +c n n n p 100 200 300 100 300 100 f32f32f32of32:matrix_mul,gelu_tanh +c t n n n 16 256 512 512 512 256 bf16bf16f32of32:matrix_mul +r n n n n 160 6424 2051 2051 6424 6424 *:bias,swish +r n n n r 74 512 515 515 512 512 *:none +r n n n r 253 2048 660 660 2048 2048 *:matrix_add +r n n n p 81 128 3 3 128 128 u8s8s32os32:bias,relu,clip +r n n n p 81 128 3 3 128 128 u8s8s32os8:bias,relu,clip +r n n n p 181 1280 3000 3000 1280 1280 *:bias,relu,clip,matrix_add +r n n n r 482 690 2050 2050 690 690 *:scale=scalar,zp=scalar,gelu_tanh,clip +r n n n r 482 690 2050 2050 690 690 *:scale=vector,zp=vector,bias,gelu_erf,clip +c n n n p 100 200 300 100 300 100 f32f32f32of32:bias,gelu_tanh,clip +c n n n p 100 200 300 100 300 100 f32f32f32of32:bias,gelu_erf,clip +r n n n r 144 1024 512 512 1024 1024 *:scale=vector,zp=scalar,relu,clip +r n n n r 144 1024 512 512 1024 1024 *:zp=vector,scale=scalar,relu,clip +r n n n r 128 128 128 128 128 128 *:bias,relu,clip +r n n n r 100 200 300 300 200 200 u8s8s16ou8:none +c t n n n 16 256 512 512 512 256 bf16bf16f32of32:none +r n n n r 144 6424 2090 2090 6424 6424 *:bias,swish +c n n n n 160 6400 2051 160 2051 160 bf16bf16f32obf16:bias,matrix_mul +c n n n n 160 6400 2051 160 2051 160 bf16bf16f32of32:bias,matrix_add +r n n n n 160 6424 2051 2051 6424 6424 *:bias,swish +r n n n r 74 512 515 515 512 512 *:none diff --git a/bench/bench_aocl_gemm/bench_lpgemm.c b/bench/bench_aocl_gemm/bench_lpgemm.c index bb70a087b2..f366cf3a97 100644 --- a/bench/bench_aocl_gemm/bench_lpgemm.c +++ b/bench/bench_aocl_gemm/bench_lpgemm.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -32,68 +32,9 @@ */ -#include -#include -#include -#include -#include -#include -#include -#include +#include "bench_lpgemm_helpers.h" -#include "blis.h" - - -// Used to clip downscaled output, will be set in the main loop based -// on the accumulation and C data type. -int64_t DSCALE_CLIP_MIN = 0; -int64_t DSCALE_CLIP_MAX = 0; - -// Mode can be one of the follwoing: -// 1. p - performance, used for benchmarks. -// 2. a - accuracy, used to test accuracy/correctness. -// Default value is p, can be modified by passing command line arg. -char bench_mode = 'p'; - -int32_t global_n_repeat = 0; - -char global_dscale_out = 'n'; - -dim_t num_eltwise = 0; // To keep track of eltwise operations. - -#define _XSTR(str) #str -#define XSTR(str) _XSTR(str) - -#define GEN_FUNC_NAME(prototype,ctype) prototype ## ctype - -static inline void float_to_bf16( float* float_value, bfloat16* bf16_val ) -{ - /*Set offset 2 to copy most significant 2 bytes of float - to convert float values to bf16 values*/ - memcpy( ( bf16_val ), (char *)( float_value ) + 2, sizeof ( bfloat16 ) ); -} - -static inline void convert_float_arr_to_bf16( float* array, bfloat16* array_bf16, int size ) -{ - for (int i=0; i< size; i++) - { - float_to_bf16( ( array + i ), ( array_bf16 + i ) ); - } -} - - -static inline void bfloat16_to_float( bfloat16 bf16_val, float* float_val ) -{ - int32_t inter_temp = *( ( int16_t* ) &bf16_val ); - inter_temp = inter_temp << 16; - memcpy( float_val, &inter_temp, sizeof( int32_t ) ); -} - -#define CONVERT_TO_FLOAT(ctype) \ -static inline void GEN_FUNC_NAME(ctype,_to_float) ( ctype val, float* float_val ) \ -{ \ - *float_val = (float) val; \ -} \ +char global_pre_op = 'n'; CONVERT_TO_FLOAT(uint8_t) CONVERT_TO_FLOAT(int8_t) @@ -101,8 +42,6 @@ CONVERT_TO_FLOAT(int16_t) CONVERT_TO_FLOAT(float) CONVERT_TO_FLOAT(int32_t) - - /* Helper functions to print matrices when debugging */ void print_matrix_bfloat16 ( @@ -126,11 +65,11 @@ void print_matrix_bfloat16 } #define PRINT_MATRIX(ctype) \ -void print_matrix_## ctype ( ctype* a, int32_t m, int32_t n, int32_t rs, int32_t cs) \ +void print_matrix_## ctype ( ctype* a, dim_t m, dim_t n, dim_t rs, dim_t cs) \ { \ - for(int32_t i = 0; i < m; i++) \ + for(dim_t i = 0; i < m; i++) \ { \ - for(int32_t j = 0; j < n; j++) \ + for(dim_t j = 0; j < n; j++) \ { \ printf("%f ", (float) (*(a + i * ( rs ) + j * cs ) ) ); \ } \ @@ -144,95 +83,36 @@ PRINT_MATRIX(int16_t) PRINT_MATRIX(float) PRINT_MATRIX(int32_t) -void* lpgemm_malloc( int32_t size ) -{ - void* p; - // creating a dummy buffer of size 4 bytes in case - // size of the matrix is negative. - if( size <= 0 ) - { - p = malloc( 4 ); - return p; - } - - if( bench_mode == 'a' ) - { - p = malloc(size); - } - else - { - err_t err = BLIS_SUCCESS; - p = bli_malloc_user(size, &err); - } - if ( p == NULL ) - { - printf("Unable to allocate memory.\n"); - exit(1); - } - return p; -} - -void lpgemm_free( void* p ) -{ - if( p == NULL) - { - printf("Attempt to free null pointer\n"); - return; - } - - if( bench_mode == 'a' ) - { - free(p); - } - else - { - bli_free_user(p); - } -} - -#define GEN_FILL_ARRAY_FUNC(ctype) \ -void fill_array_ ## ctype ( void* arr, dim_t size ) \ -{ \ - if( size < 0 ) return; \ - ctype* temp_arr = ( ctype* ) arr; \ - for ( dim_t i = 0; i < size; ++i ) \ - { \ - temp_arr[i] = ( ctype )( i % 5 ); \ - } \ -} \ - -GEN_FILL_ARRAY_FUNC(uint8_t) GEN_FILL_ARRAY_FUNC(int8_t) GEN_FILL_ARRAY_FUNC(int16_t) GEN_FILL_ARRAY_FUNC(float) GEN_FILL_ARRAY_FUNC(int32_t) -void fill_array_bfloat16( void* arr, dim_t size ) +void fill_array_uint8_t ( void* arr, dim_t size ) { - err_t bli_errors = BLIS_SUCCESS; if( size < 0 ) return; - float* c_float = ( float* ) bli_malloc_user( sizeof( float ) * size, &bli_errors ); + uint8_t* temp_arr = ( uint8_t* ) arr; for ( dim_t i = 0; i < size; ++i ) { - c_float[i] = i % 5; + temp_arr[i] = ( uint8_t )( rand() % 5 ); } - convert_float_arr_to_bf16( c_float, arr, size ); - if ( c_float != NULL ) +} + +void fill_array_int4_c_t( void* arr, dim_t size ) +{ + int8_t int4_c_t_values[8] = { 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF }; + //int8_t int4_c_t_values[8] = { 0x01, 0x23, 0x45, 0x67, 0x01, 0x23, 0x45, 0x67 }; + dim_t int4_c_t_size = ( size + 1 ) / 2; + if ( size < 0 ) return; + // Fill in pairs for in4_t since 4 bits/half byte access is not + // straight forward. + int8_t* temp_arr = ( int8_t* )arr; + for (dim_t i = 0; i < int4_c_t_size; ++i) { - bli_free_user( c_float ); + temp_arr[i] = int4_c_t_values[( rand() % 8 )]; } } -#define GEN_FILL_ARRAY_POST_OPS_FUNC(ctype) \ -void fill_array_post_ops_ ## ctype ( void* arr, dim_t size ) \ -{ \ - ctype* temp_arr = ( ctype* ) arr; \ - for ( dim_t i = 0; i < size; ++i ) \ - { \ - temp_arr[i] = ( ctype )( i % 20 ); \ - } \ -} \ - GEN_FILL_ARRAY_POST_OPS_FUNC(int16_t) GEN_FILL_ARRAY_POST_OPS_FUNC(int32_t) GEN_FILL_ARRAY_POST_OPS_FUNC(float) @@ -265,61 +145,6 @@ void mat_mul_ ## BLAS_SFX \ b, ldb, op_b, \ beta, \ c, ldc, post_op ); \ - \ - /*dim_t MR = 6; \ - dim_t NR = 16; \ - \ - __m512i selector1; \ - __m512i all_zero = _mm512_setzero_epi32(); \ - __m512i c0; \ - __m512i c1; \ - __m512i c2; \ - __m512i c3; \ - __m512i c4; \ - __m512i c5; \ - \ - for ( dim_t i = 0; i < m; i += MR ) \ - { \ - if ( ( i + MR ) > m ) \ - { \ - break; \ - } \ - for ( dim_t j = 0; j < n; j += NR ) \ - { \ - if ( ( j + NR ) > n ) \ - { \ - break; \ - } \ - selector1 = _mm512_loadu_epi32( (int32_t*)post_op->bias.bias + j ); \ - c0 = _mm512_loadu_epi32( c + ( ( i + 0 ) * ldc ) + j ); \ - c1 = _mm512_loadu_epi32( c + ( ( i + 1 ) * ldc ) + j ); \ - c2 = _mm512_loadu_epi32( c + ( ( i + 2 ) * ldc ) + j ); \ - c3 = _mm512_loadu_epi32( c + ( ( i + 3 ) * ldc ) + j ); \ - c4 = _mm512_loadu_epi32( c + ( ( i + 4 ) * ldc ) + j ); \ - c5 = _mm512_loadu_epi32( c + ( ( i + 5 ) * ldc ) + j ); \ - \ - c0 = _mm512_add_epi32( selector1, c0 ); \ - c1 = _mm512_add_epi32( selector1, c1 ); \ - c2 = _mm512_add_epi32( selector1, c2 ); \ - c3 = _mm512_add_epi32( selector1, c3 ); \ - c4 = _mm512_add_epi32( selector1, c4 ); \ - c5 = _mm512_add_epi32( selector1, c5 ); \ - \ - c0 = _mm512_max_epi32( all_zero, c0 ); \ - c1 = _mm512_max_epi32( all_zero, c1 ); \ - c2 = _mm512_max_epi32( all_zero, c2 ); \ - c3 = _mm512_max_epi32( all_zero, c3 ); \ - c4 = _mm512_max_epi32( all_zero, c4 ); \ - c5 = _mm512_max_epi32( all_zero, c5 ); \ - \ - _mm512_storeu_epi32( c + ( ( i + 0 ) * ldc ) + j, c0 ); \ - _mm512_storeu_epi32( c + ( ( i + 1 ) * ldc ) + j, c1 ); \ - _mm512_storeu_epi32( c + ( ( i + 2 ) * ldc ) + j, c2 ); \ - _mm512_storeu_epi32( c + ( ( i + 3 ) * ldc ) + j, c3 ); \ - _mm512_storeu_epi32( c + ( ( i + 4 ) * ldc ) + j, c4 ); \ - _mm512_storeu_epi32( c + ( ( i + 5 ) * ldc ) + j, c5 ); \ - } \ - } */\ } \ GEN_BLIS_MAT_MUL_FUNC(uint8_t,int8_t,int16_t,int16_t,u8s8s16os16) @@ -334,6 +159,8 @@ GEN_BLIS_MAT_MUL_FUNC(int8_t,int8_t,int32_t,int32_t,s8s8s32os32) GEN_BLIS_MAT_MUL_FUNC(int8_t,int8_t,int8_t,int32_t,s8s8s32os8) GEN_BLIS_MAT_MUL_FUNC(int8_t,int8_t,int16_t,int16_t,s8s8s16os16) GEN_BLIS_MAT_MUL_FUNC(int8_t,int8_t,int8_t,int16_t,s8s8s16os8) +GEN_BLIS_MAT_MUL_FUNC(bfloat16,int8_t,float,float,bf16s4f32of32) +GEN_BLIS_MAT_MUL_FUNC(bfloat16,int8_t,bfloat16,float,bf16s4f32obf16) double get_gflops ( @@ -358,10 +185,10 @@ void print_result dim_t lda, dim_t ldb, dim_t ldc, - double runtime + double gflops ) { - double gflops = get_gflops( m, n, k, runtime ); + //double gflops = get_gflops( m, n, k, runtime ); printf("%s transa:%c, transb:%c, m: %ld, n: %ld, k: %ld, lda: %ld, ldb: %ld, ldc: %ld," \ " Gops: %f, n_repeats: %d\n", msg, transa, transb, m, n, k, lda, ldb, ldc, gflops, n_repeats); @@ -390,17 +217,12 @@ void mat_mul_bench_driver_ ## BLAS_SFX \ aocl_post_op* post_op\ ) \ { \ - double min_time_diff = DBL_MAX; \ + double dtime; \ + double dtime_save = DBL_MAX; \ +\ for ( int32_t nr = 0; nr < n_repeats; ++nr ) \ { \ - if ( bench_mode == 'a' ) \ - { \ - int32_t size_C = ( ( stor_order == 'r') || ( stor_order == 'R' ) )? m * ldc : n * ldc; \ - GEN_FUNC_NAME(fill_array_,C_type)( c, ( size_C ) ); \ - } \ - \ - struct timespec tstart={0,0}, tend={0,0}; \ - clock_gettime(CLOCK_MONOTONIC, &tstart); \ + dtime = bli_clock(); \ \ GEN_FUNC_NAME(mat_mul_,BLAS_SFX) \ ( \ @@ -413,15 +235,12 @@ void mat_mul_bench_driver_ ## BLAS_SFX \ post_op \ ); \ \ - clock_gettime(CLOCK_MONOTONIC, &tend); \ + dtime_save = bli_clock_min_diff( dtime_save, dtime ); \ \ - double diff = \ - ( ( double ) tend.tv_sec + ( 1.0e-9 * tend.tv_nsec ) ) - \ - ( ( double ) tstart.tv_sec + ( 1.0e-9 * tstart.tv_nsec ) ); \ - min_time_diff = ( diff < min_time_diff ) ? diff : min_time_diff; \ } \ + double gflops = ( 2.0 * m * k * n ) / ( dtime_save * 1.0e9 ); \ \ - print_result( XSTR(BLAS_SFX), n_repeats, transa, transb, m, n, k, lda, ldb, ldc, min_time_diff); \ + print_result( XSTR(BLAS_SFX), n_repeats, transa, transb, m, n, k, lda, ldb, ldc, gflops); \ } \ GEN_MAT_MUL_BENCH_DRV_FUNC(uint8_t,int8_t,int16_t,int16_t,u8s8s16os16) @@ -436,16 +255,8 @@ GEN_MAT_MUL_BENCH_DRV_FUNC(int8_t,int8_t,int32_t,int32_t,s8s8s32os32) GEN_MAT_MUL_BENCH_DRV_FUNC(int8_t,int8_t,int8_t,int32_t,s8s8s32os8) GEN_MAT_MUL_BENCH_DRV_FUNC(int8_t,int8_t,int16_t,int16_t,s8s8s16os16) GEN_MAT_MUL_BENCH_DRV_FUNC(int8_t,int8_t,int8_t,int16_t,s8s8s16os8) - -int max (int a, int b) -{ - return ( a > b ? a : b ); -} - -int min (int a, int b) -{ - return ( a < b ? a : b ); -} +GEN_MAT_MUL_BENCH_DRV_FUNC(bfloat16,int8_t,float,float,bf16s4f32of32) +GEN_MAT_MUL_BENCH_DRV_FUNC(bfloat16,int8_t,bfloat16,float,bf16s4f32obf16) #define GEN_MAT_MUL_ACC_CHK_DOWNSCALE(C_type,ACCUM_type,SCALE_type,BLAS_DOWNSCALE_SFX) \ static inline ACCUM_type mat_mul_accuracy_check_downscale_ ## BLAS_DOWNSCALE_SFX \ @@ -455,14 +266,26 @@ static inline ACCUM_type mat_mul_accuracy_check_downscale_ ## BLAS_DOWNSCALE_SFX dim_t j \ )\ { \ + dim_t j_scale = j; \ + if ( ( post_op->sum )->scale_factor_len == 1 ) \ + { \ + j_scale = 0; \ + } \ + \ + dim_t j_zp = j; \ + if ( ( post_op->sum )->zero_point_len == 1 ) \ + { \ + j_zp = 0; \ + } \ + \ ACCUM_type out_temp_accum = \ ( ACCUM_type )min( \ max( nearbyintf( ( SCALE_type )( temp_accum ) * \ - ( *( ( SCALE_type* )post_op->sum.scale_factor + j ) ) ) + \ - *( ( C_type* )post_op->sum.zero_point + j ), \ + ( *( ( SCALE_type* )( post_op->sum )->scale_factor + j_scale ) ) ) + \ + *( ( C_type* )( post_op->sum )->zero_point + j_zp ), \ DSCALE_CLIP_MIN ), \ DSCALE_CLIP_MAX ); \ - return out_temp_accum; \ + return out_temp_accum; \ }\ GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,int16_t,float,u8s8s16os8) @@ -478,7 +301,25 @@ static inline float mat_mul_accuracy_check_downscale_bf16bf16f32obf16 dim_t j ) { - return temp_accum; + dim_t j_scale = j; + if ( ( post_op->sum )->scale_factor_len == 1 ) + { + j_scale = 0; + } + + dim_t j_zp = j; + if ( ( post_op->sum )->zero_point_len == 1 ) + { + j_zp = 0; + } + + float zp_float = 0.0; + bfloat16_to_float( *( ( bfloat16* )( post_op->sum )->zero_point + j_zp ), + &zp_float ); + float out_temp_accum = ( temp_accum * + ( *( ( float* )( post_op->sum )->scale_factor + j_scale ) ) + + zp_float ); + return out_temp_accum; } #define GEN_MAT_MUL_ACC_CHK_ACCUM(A_type, B_type, C_type,ACCUM_type,BLAS_SFX) \ @@ -498,31 +339,99 @@ static inline ACCUM_type mat_mul_accuracy_check_accum_ ## BLAS_SFX \ dim_t cs_c_ref, \ dim_t i, \ dim_t j, \ - dim_t k \ - )\ -{\ + dim_t k, \ + bool int4_testing, /* Workaround to enable int4 B matrix testing. */\ + aocl_pre_op* pre_op /* Workaround to enable B pre-ops. */ \ + ) \ +{ \ + ( void )int4_testing; \ + ( void ) pre_op; \ for ( dim_t p = 0; p < k; ++p) \ { \ temp_accum += ( *( a + ( i * rs_a ) + ( cs_a * p ) ) * \ *( b + ( rs_b * p ) + ( cs_b * j ) ) ); \ } \ -\ + \ temp_accum = ( beta * ( * (c_ref + ( rs_c_ref * i ) + ( cs_c_ref * j ) ) ) ) \ + ( alpha * temp_accum ); \ return temp_accum; \ -}\ +} \ GEN_MAT_MUL_ACC_CHK_ACCUM(uint8_t,int8_t,int8_t,int16_t,u8s8s16os8) GEN_MAT_MUL_ACC_CHK_ACCUM(uint8_t,int8_t,uint8_t,int16_t,u8s8s16ou8) GEN_MAT_MUL_ACC_CHK_ACCUM(uint8_t,int8_t,int16_t,int16_t,u8s8s16os16) -GEN_MAT_MUL_ACC_CHK_ACCUM(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8) -GEN_MAT_MUL_ACC_CHK_ACCUM(uint8_t,int8_t,int32_t,int32_t,u8s8s32os32) GEN_MAT_MUL_ACC_CHK_ACCUM(float,float,float,float,f32f32f32of32) GEN_MAT_MUL_ACC_CHK_ACCUM(int8_t,int8_t,int8_t,int32_t,s8s8s32os8) GEN_MAT_MUL_ACC_CHK_ACCUM(int8_t,int8_t,int32_t,int32_t,s8s8s32os32) GEN_MAT_MUL_ACC_CHK_ACCUM(int8_t,int8_t,int8_t,int16_t,s8s8s16os8) GEN_MAT_MUL_ACC_CHK_ACCUM(int8_t,int8_t,int16_t,int16_t,s8s8s16os16) +#define GEN_MAT_MUL_ACC_CHK_ACCUM_INT4(A_type, B_type, C_type,ACCUM_type,BLAS_SFX) \ +static inline ACCUM_type mat_mul_accuracy_check_accum_ ## BLAS_SFX \ + (\ + A_type* a, \ + B_type* b, \ + C_type* c_ref, \ + ACCUM_type temp_accum,\ + ACCUM_type alpha, \ + ACCUM_type beta, \ + dim_t rs_a, \ + dim_t rs_b, \ + dim_t cs_a, \ + dim_t cs_b, \ + dim_t rs_c_ref, \ + dim_t cs_c_ref, \ + dim_t i, \ + dim_t j, \ + dim_t k, \ + bool int4_testing, /* Workaround to enable int4 B matrix testing. */\ + aocl_pre_op* pre_op /* Workaround to enable B pre-ops. */ \ + ) \ +{ \ + ( void ) pre_op; \ + if ( int4_testing == FALSE ) \ + { \ + for ( dim_t p = 0; p < k; ++p) \ + { \ + temp_accum += ( *( a + ( i * rs_a ) + ( cs_a * p ) ) * \ + *( b + ( rs_b * p ) + ( cs_b * j ) ) ); \ + } \ + } \ + else \ + { \ + for ( dim_t p = 0; p < k; ++p) \ + { \ + /* Get B matrix int4_t value and upscale it to int8_t. */ \ + dim_t b_inc = ( rs_b * p ) + ( cs_b * j ); \ + int8_t b_val = 0; \ + /* Even index will have data at low 4 bits, and odd at hi 4 bits. + * B matrix increments has to be halved to account for 4 bit + * traversal. */ \ + if ( ( b_inc % 2 ) != 0 ) \ + { \ + b_val = ( ( *( b + ( b_inc / 2 ) ) ) >> 4 ) & 0x0F; \ + } \ + else \ + { \ + b_val = ( *( b + ( b_inc / 2 ) ) ) & 0x0F; \ + } \ + /* Signed scale. */ \ + if ( b_val & 0x08 ) \ + { \ + b_val = b_val | 0xF0; \ + } \ + temp_accum += ( *( a + ( i * rs_a ) + ( cs_a * p ) ) * b_val ); \ + } \ + } \ + \ + temp_accum = ( beta * ( * (c_ref + ( rs_c_ref * i ) + ( cs_c_ref * j ) ) ) ) \ + + ( alpha * temp_accum ); \ + return temp_accum; \ +} \ + +GEN_MAT_MUL_ACC_CHK_ACCUM_INT4(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8) +GEN_MAT_MUL_ACC_CHK_ACCUM_INT4(uint8_t,int8_t,int32_t,int32_t,u8s8s32os32) + static inline float mat_mul_accuracy_check_accum_bf16bf16f32of32 ( bfloat16* a, @@ -539,9 +448,13 @@ static inline float mat_mul_accuracy_check_accum_bf16bf16f32of32 dim_t cs_c_ref, dim_t i, dim_t j, - dim_t k + dim_t k, + bool int4_testing, /* Ignored for bf16 testing */\ + aocl_pre_op* pre_op /* Workaround to enable B pre-ops. */ \ ) { + ( void )int4_testing; + ( void ) pre_op; for ( dim_t p = 0; p < k; ++p) { float a_float, b_float; @@ -570,9 +483,13 @@ static inline float mat_mul_accuracy_check_accum_bf16bf16f32obf16 dim_t cs_c_ref, dim_t i, dim_t j, - dim_t k + dim_t k, + bool int4_testing, /* Ignored for bf16 testing */\ + aocl_pre_op* pre_op /* Workaround to enable B pre-ops. */ \ ) { + ( void )int4_testing; + ( void ) pre_op; for ( dim_t p = 0; p < k; ++p) { float a_float, b_float; @@ -587,18 +504,151 @@ static inline float mat_mul_accuracy_check_accum_bf16bf16f32obf16 return temp_accum; } -#define GEN_GELU_TANH_POSTOP_INT(ACCUM_type,BLAS_SFX) \ -static inline ACCUM_type GELU_TANH_post_op_ ## BLAS_SFX \ - (\ - ACCUM_type temp_accum \ - )\ -{\ - float gelu_reference = 0.5 *(double)temp_accum * (1 + tanhf( 0.797884 * ( (double)temp_accum + \ - ( 0.044715 * ((double)temp_accum * (double)temp_accum * \ - (double)temp_accum ) ) ) ) ); \ - temp_accum = round (gelu_reference); \ - return temp_accum; \ -}\ +static inline float get_s4_to_f32_scale_val + ( + int8_t* b, + dim_t j, + dim_t b_inc, + aocl_pre_op* pre_op + ) +{ + float b_float = 0.0; + int8_t b_val = 0; + + /* Even index will have data at low 4 bits, and odd at hi 4 bits. + * B matrix increments has to be halved to account for 4 bit + * traversal. */ + if ( ( b_inc % 2 ) != 0 ) + { + b_val = ( ( *( b + ( b_inc / 2 ) ) ) >> 4 ) & 0x0F; + } + else + { + b_val = ( *( b + ( b_inc / 2 ) ) ) & 0x0F; + } + + /* Signed scale. */ + if ( b_val & 0x08 ) + { + b_val = b_val | 0xF0; + } + + if ( ( pre_op != NULL ) && ( pre_op->seq_length > 0 ) ) + { + dim_t j_zp = j; + if ( ( pre_op->b_zp != NULL ) && + ( ( pre_op->b_zp )->zero_point_len == 1 ) ) + { + j_zp = 0; + } + dim_t j_scale = j; + if ( ( pre_op->b_scl != NULL ) && + ( ( pre_op->b_scl )->scale_factor_len == 1 ) ) + { + j_scale = 0; + } + + // Assuming only 1 scale and zp. + int8_t zp = 0; + if ( ( pre_op->b_zp != NULL ) && + ( ( pre_op->b_zp )->zero_point != NULL ) ) + { + zp = *( ( int8_t* )( pre_op->b_zp )->zero_point + j_zp ); + } + + float scale_factor = 1.0; + if ( ( pre_op->b_scl != NULL ) && + ( ( pre_op->b_scl )->scale_factor != NULL ) ) + { + scale_factor = *( ( float* )( pre_op->b_scl )->scale_factor + j_scale ); + } + b_float = (float)( b_val - zp ) * scale_factor; + } + else + { + b_float = (float)( b_val); + } + + return b_float; +} + +static inline float mat_mul_accuracy_check_accum_bf16s4f32of32 + ( + bfloat16* a, + int8_t* b, + float* c_ref, + float temp_accum, + float alpha, + float beta, + dim_t rs_a, + dim_t rs_b, + dim_t cs_a, + dim_t cs_b, + dim_t rs_c_ref, + dim_t cs_c_ref, + dim_t i, + dim_t j, + dim_t k, + bool int4_testing, /* Ignored s4 implies int4 testing. */\ + aocl_pre_op* pre_op /* Workaround to enable B pre-ops. */ \ + ) +{ + ( void )int4_testing; + for ( dim_t p = 0; p < k; ++p) + { + float a_float, b_float; + bfloat16_to_float( *( a + i * rs_a + p * cs_a ) , &a_float); + + /* Get B matrix int4_t value and upscale it to float. */ + dim_t b_inc = ( rs_b * p ) + ( cs_b * j ); + b_float = get_s4_to_f32_scale_val( b, j, b_inc, pre_op ); + + temp_accum += ( ( a_float ) * ( b_float ) ); + } + temp_accum = ( beta * ( * (c_ref + ( rs_c_ref * i ) + ( cs_c_ref * j ) ) ) ) + + ( alpha * temp_accum ); + return temp_accum; +} + +static inline float mat_mul_accuracy_check_accum_bf16s4f32obf16 + ( + bfloat16* a, + int8_t* b, + bfloat16* c_ref, + float temp_accum, + float alpha, + float beta, + dim_t rs_a, + dim_t rs_b, + dim_t cs_a, + dim_t cs_b, + dim_t rs_c_ref, + dim_t cs_c_ref, + dim_t i, + dim_t j, + dim_t k, + bool int4_testing, /* Ignored for bf16 testing */\ + aocl_pre_op* pre_op /* Workaround to enable B pre-ops. */ \ + ) +{ + ( void )int4_testing; + for ( dim_t p = 0; p < k; ++p) + { + float a_float, b_float; + bfloat16_to_float( *( a + i*rs_a + p*cs_a ), &a_float ); + + /* Get B matrix int4_t value and upscale it to float. */ + dim_t b_inc = ( rs_b * p ) + ( cs_b * j ); + b_float = get_s4_to_f32_scale_val( b, j, b_inc, pre_op ); + + temp_accum += ( ( a_float ) * ( b_float ) ); + } + float c_ref_float; + bfloat16_to_float( *( c_ref + i*rs_c_ref + j*cs_c_ref ), &c_ref_float ); + temp_accum = ( beta * ( c_ref_float ) ) + ( alpha * temp_accum ); + + return temp_accum; +} GEN_GELU_TANH_POSTOP_INT(int16_t,u8s8s16os8) GEN_GELU_TANH_POSTOP_INT(int16_t,u8s8s16ou8) @@ -610,32 +660,11 @@ GEN_GELU_TANH_POSTOP_INT(int32_t,s8s8s32os32) GEN_GELU_TANH_POSTOP_INT(int16_t,s8s8s16os8) GEN_GELU_TANH_POSTOP_INT(int16_t,s8s8s16os16) -#define GEN_GELU_TANH_POSTOP_FLOAT(BLAS_SFX) \ -static inline float GELU_TANH_post_op_ ## BLAS_SFX \ - (\ - float temp_accum \ - )\ -{\ - temp_accum = 0.5 *(double)temp_accum * (1 + tanhf( 0.797884 * ( (double)temp_accum + \ - ( 0.044715 * ((double)temp_accum * (double)temp_accum * \ - (double)temp_accum ) ) ) ) ); \ - return temp_accum; \ -}\ - GEN_GELU_TANH_POSTOP_FLOAT(f32f32f32of32) GEN_GELU_TANH_POSTOP_FLOAT(bf16bf16f32of32) GEN_GELU_TANH_POSTOP_FLOAT(bf16bf16f32obf16) - -#define GEN_GELU_ERF_POSTOP_INT(ACCUM_type,BLAS_SFX) \ -static inline ACCUM_type GELU_ERF_post_op_ ## BLAS_SFX \ - (\ - ACCUM_type temp_accum \ - )\ -{\ - float gelu_reference = 0.5 *(double)temp_accum * (1 + erff( (double)temp_accum * 0.707107 )); \ - temp_accum = round (gelu_reference); \ - return temp_accum; \ -}\ +GEN_GELU_TANH_POSTOP_FLOAT(bf16s4f32of32) +GEN_GELU_TANH_POSTOP_FLOAT(bf16s4f32obf16) GEN_GELU_ERF_POSTOP_INT(int16_t,u8s8s16os8) GEN_GELU_ERF_POSTOP_INT(int16_t,u8s8s16ou8) @@ -647,29 +676,75 @@ GEN_GELU_ERF_POSTOP_INT(int32_t,s8s8s32os32) GEN_GELU_ERF_POSTOP_INT(int16_t,s8s8s16os8) GEN_GELU_ERF_POSTOP_INT(int16_t,s8s8s16os16) -#define GEN_GELU_ERF_POSTOP_FLOAT(BLAS_SFX) \ -static inline float GELU_ERF_post_op_ ## BLAS_SFX \ - (\ - float temp_accum \ - )\ -{\ - temp_accum = 0.5 *(double)temp_accum * (1 + erff( (double)temp_accum * 0.707107 )); \ - return temp_accum; \ -}\ - GEN_GELU_ERF_POSTOP_FLOAT(f32f32f32of32) GEN_GELU_ERF_POSTOP_FLOAT(bf16bf16f32of32) GEN_GELU_ERF_POSTOP_FLOAT(bf16bf16f32obf16) - -#define GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(C_type, ACCUM_type) \ -void mat_mul_get_output_type_val ## ACCUM_type ## C_type \ - ( \ - C_type* out_temp_accum, \ - ACCUM_type* temp_accum \ - ) \ -{ \ - ( *out_temp_accum ) = ( C_type )( *temp_accum ); \ -} \ +GEN_GELU_ERF_POSTOP_FLOAT(bf16s4f32of32) +GEN_GELU_ERF_POSTOP_FLOAT(bf16s4f32obf16) + +GEN_SWISH_POSTOP_INT(int16_t,u8s8s16os8) +GEN_SWISH_POSTOP_INT(int16_t,u8s8s16ou8) +GEN_SWISH_POSTOP_INT(int16_t,u8s8s16os16) +GEN_SWISH_POSTOP_INT(int32_t,u8s8s32os8) +GEN_SWISH_POSTOP_INT(int32_t,u8s8s32os32) +GEN_SWISH_POSTOP_INT(int32_t,s8s8s32os8) +GEN_SWISH_POSTOP_INT(int32_t,s8s8s32os32) +GEN_SWISH_POSTOP_INT(int16_t,s8s8s16os8) +GEN_SWISH_POSTOP_INT(int16_t,s8s8s16os16) + +GEN_SWISH_POSTOP_FLOAT(f32f32f32of32) +GEN_SWISH_POSTOP_FLOAT(bf16bf16f32of32) +GEN_SWISH_POSTOP_FLOAT(bf16bf16f32obf16) +GEN_SWISH_POSTOP_FLOAT(bf16s4f32of32) +GEN_SWISH_POSTOP_FLOAT(bf16s4f32obf16) + +GEN_GET_MATRIX_ADD_POST_OP_VAL_BF16(bfloat16,bf16bf16f32obf16) +GEN_GET_MATRIX_ADD_POST_OP_VAL_BF16(bfloat16,bf16s4f32obf16) + +GEN_GET_MATRIX_ADD_POST_OP_VAL(int8_t,int32_t,u8s8s32os8) +GEN_GET_MATRIX_ADD_POST_OP_VAL(int32_t,int32_t,u8s8s32os32) +GEN_GET_MATRIX_ADD_POST_OP_VAL(int8_t,int16_t,u8s8s16os8) +GEN_GET_MATRIX_ADD_POST_OP_VAL(uint8_t,int16_t,u8s8s16ou8) +GEN_GET_MATRIX_ADD_POST_OP_VAL(int16_t,int16_t,u8s8s16os16) +GEN_GET_MATRIX_ADD_POST_OP_VAL(int8_t,int32_t,s8s8s32os8) +GEN_GET_MATRIX_ADD_POST_OP_VAL(int32_t,int32_t,s8s8s32os32) +GEN_GET_MATRIX_ADD_POST_OP_VAL(int8_t,int16_t,s8s8s16os8) +GEN_GET_MATRIX_ADD_POST_OP_VAL(int16_t,int16_t,s8s8s16os16) +GEN_GET_MATRIX_ADD_POST_OP_VAL(float,float,f32f32f32of32) +GEN_GET_MATRIX_ADD_POST_OP_VAL(float,float,bf16bf16f32of32) +GEN_GET_MATRIX_ADD_POST_OP_VAL(float,float,bf16s4f32of32) + +GEN_GET_MATRIX_MUL_POST_OP_VAL_BF16(bfloat16,bf16bf16f32obf16) +GEN_GET_MATRIX_MUL_POST_OP_VAL_BF16(bfloat16,bf16s4f32obf16) + +GEN_GET_MATRIX_MUL_POST_OP_VAL(int8_t,int32_t,u8s8s32os8) +GEN_GET_MATRIX_MUL_POST_OP_VAL(int32_t,int32_t,u8s8s32os32) +GEN_GET_MATRIX_MUL_POST_OP_VAL(int8_t,int16_t,u8s8s16os8) +GEN_GET_MATRIX_MUL_POST_OP_VAL(uint8_t,int16_t,u8s8s16ou8) +GEN_GET_MATRIX_MUL_POST_OP_VAL(int16_t,int16_t,u8s8s16os16) +GEN_GET_MATRIX_MUL_POST_OP_VAL(int8_t,int32_t,s8s8s32os8) +GEN_GET_MATRIX_MUL_POST_OP_VAL(int32_t,int32_t,s8s8s32os32) +GEN_GET_MATRIX_MUL_POST_OP_VAL(int8_t,int16_t,s8s8s16os8) +GEN_GET_MATRIX_MUL_POST_OP_VAL(int16_t,int16_t,s8s8s16os16) +GEN_GET_MATRIX_MUL_POST_OP_VAL(float,float,f32f32f32of32) +GEN_GET_MATRIX_MUL_POST_OP_VAL(float,float,bf16bf16f32of32) +GEN_GET_MATRIX_MUL_POST_OP_VAL(float,float,bf16s4f32of32) + +GEN_GET_BIAS_POST_OP_VAL_BF16(bf16bf16f32obf16) +GEN_GET_BIAS_POST_OP_VAL_BF16(bf16s4f32obf16) + +GEN_GET_BIAS_POST_OP_VAL(int32_t,u8s8s32os8) +GEN_GET_BIAS_POST_OP_VAL(int32_t,u8s8s32os32) +GEN_GET_BIAS_POST_OP_VAL(int16_t,u8s8s16os8) +GEN_GET_BIAS_POST_OP_VAL(int16_t,u8s8s16ou8) +GEN_GET_BIAS_POST_OP_VAL(int16_t,u8s8s16os16) +GEN_GET_BIAS_POST_OP_VAL(int32_t,s8s8s32os8) +GEN_GET_BIAS_POST_OP_VAL(int32_t,s8s8s32os32) +GEN_GET_BIAS_POST_OP_VAL(int16_t,s8s8s16os8) +GEN_GET_BIAS_POST_OP_VAL(int16_t,s8s8s16os16) +GEN_GET_BIAS_POST_OP_VAL(float,f32f32f32of32) +GEN_GET_BIAS_POST_OP_VAL(float,bf16bf16f32of32) +GEN_GET_BIAS_POST_OP_VAL(float,bf16s4f32of32) GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(int32_t,int32_t) GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(int8_t,int32_t) @@ -678,15 +753,6 @@ GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(int8_t,int16_t) GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(uint8_t,int16_t) GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(float,float) -void mat_mul_get_output_type_valfloatbfloat16 - ( - bfloat16* out_temp_accum, - float* temp_accum - ) -{ - float_to_bf16( temp_accum, out_temp_accum ); -} - #define GEN_MAT_MUL_ACC_CHK_DRV_FUNC(A_type,B_type,C_type,ACCUM_type,SCALE_type,BLAS_SFX,BLAS_DOWNSCALE_SFX) \ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \ ( \ @@ -707,7 +773,8 @@ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \ dim_t ldc, \ C_type* c_ref, \ dim_t ldc_ref, \ - aocl_post_op* post_op\ + aocl_post_op* post_op, \ + bool int4_testing /* Workaround to enable int4 B matrix testing. */ \ ) \ { \ dim_t rs_a, cs_a; \ @@ -765,6 +832,11 @@ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \ cs_c_ref = ldc_ref; \ } \ \ + aocl_pre_op* a_pre_op = NULL; \ + if ( post_op != NULL ) \ + { \ + a_pre_op = post_op->pre_ops; \ + } \ for ( dim_t i = 0; i < m; ++i ) \ { \ for ( dim_t j = 0; j < n; ++j ) \ @@ -773,7 +845,9 @@ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \ C_type out_temp_accum = 0; \ \ temp_accum = GEN_FUNC_NAME(mat_mul_accuracy_check_accum_,BLAS_SFX) \ - (a,b,c_ref,temp_accum,alpha,beta,rs_a,rs_b,cs_a,cs_b,rs_c_ref,cs_c_ref,i,j,k); \ + (a, b, c_ref, temp_accum, alpha, beta,\ + rs_a, rs_b, cs_a, cs_b, rs_c_ref, cs_c_ref, i, j, k, \ + int4_testing, a_pre_op); \ \ if ( post_op != NULL ) \ { \ @@ -782,7 +856,8 @@ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \ { \ if ( post_op->seq_vector[op_id] == BIAS ) \ { \ - temp_accum += ( *( ( ACCUM_type* )post_op->bias.bias + j ) ); \ + temp_accum += GEN_FUNC_NAME(get_bias_post_op_val_,BLAS_SFX) \ + ( ( post_op->bias )->bias, j ); \ } \ else if ( post_op->seq_vector[op_id] == ELTWISE ) \ { \ @@ -807,6 +882,15 @@ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \ temp_accum = GEN_FUNC_NAME(GELU_ERF_post_op_,BLAS_SFX) (temp_accum);\ ele_i += 1; \ } \ + else if ( ( post_op->eltwise + ele_i )->algo.algo_type == \ + SWISH ) /* SiLU*/ \ + { \ + temp_accum = GEN_FUNC_NAME(SWISH_post_op_,BLAS_SFX) \ + (temp_accum, \ + *( ( ACCUM_type* ) \ + ( post_op->eltwise + ele_i )->algo.alpha ) );\ + ele_i += 1; \ + } \ else if ( ( post_op->eltwise + ele_i )->algo.algo_type == \ RELU ) /* ReLU*/ \ { \ @@ -838,6 +922,32 @@ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \ temp_accum = GEN_FUNC_NAME(mat_mul_accuracy_check_downscale_,BLAS_DOWNSCALE_SFX) \ (temp_accum, post_op, j); \ } \ + else if ( post_op->seq_vector[op_id] == MATRIX_ADD ) \ + { \ + dim_t rs_m = ( post_op->matrix_add )->ldm; \ + dim_t cs_m = 1; \ + if ( ( stor_order == 'C' ) || ( stor_order == 'c' ) ) \ + { \ + cs_m = rs_m; \ + rs_m = 1; \ + } \ + temp_accum += GEN_FUNC_NAME(get_matrix_add_post_op_val_,BLAS_SFX) \ + ( *( ( C_type* )( post_op->matrix_add )->matrix + \ + ( i * rs_m ) + ( j * cs_m ) ) ); \ + } \ + else if ( post_op->seq_vector[op_id] == MATRIX_MUL ) \ + { \ + dim_t rs_m = ( post_op->matrix_mul )->ldm; \ + dim_t cs_m = 1; \ + if ( ( stor_order == 'C' ) || ( stor_order == 'c' ) ) \ + { \ + cs_m = rs_m; \ + rs_m = 1; \ + } \ + temp_accum *= GEN_FUNC_NAME(get_matrix_mul_post_op_val_,BLAS_SFX) \ + ( *( ( C_type* )( post_op->matrix_mul )->matrix + \ + ( i * rs_m ) + ( j * cs_m ) ) ); \ + } \ else \ {} \ } \ @@ -848,7 +958,8 @@ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \ &out_temp_accum, &temp_accum \ ); \ \ - if ( *( c + ( rs_c * i ) + ( cs_c * j ) ) != out_temp_accum ) \ + if ( ( ( *( c + ( rs_c * i ) + ( cs_c * j ) ) - out_temp_accum ) > 1.0E-5 ) || \ + ( ( out_temp_accum - *( c + ( rs_c * i ) + ( cs_c * j ) ) ) > 1.0E-5 ) ) \ { \ float comp_float, ref_float; \ GEN_FUNC_NAME(C_type,_to_float)(*( c + ( rs_c * i ) + ( cs_c * j ) ), &comp_float); \ @@ -861,8 +972,8 @@ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \ ref_float, comp_float - ref_float); \ fflush( fout ); \ } \ - printf("failure, m: %ld, n: %ld, k: %ld, computed:%f, ref:%f, diff:%f\n", i, j, k, \ - comp_float, ref_float, comp_float-ref_float); \ + printf("failure, m_index: %ld, n_index: %ld, k: %ld, computed:%f, ref:%f," \ + "diff:%f\n", i, j, k, comp_float, ref_float, comp_float-ref_float); \ goto cleanup_acc; \ } \ } \ @@ -883,25 +994,35 @@ GEN_MAT_MUL_ACC_CHK_DRV_FUNC(int8_t,int8_t,int32_t,int32_t,float,s8s8s32os32,s8s GEN_MAT_MUL_ACC_CHK_DRV_FUNC(int8_t,int8_t,int8_t,int32_t,float,s8s8s32os8,s8s8s32os8) GEN_MAT_MUL_ACC_CHK_DRV_FUNC(int8_t,int8_t,int16_t,int16_t,float,s8s8s16os16,s8s8s16os8) GEN_MAT_MUL_ACC_CHK_DRV_FUNC(int8_t,int8_t,int8_t,int16_t,float,s8s8s16os8,s8s8s16os8) +GEN_MAT_MUL_ACC_CHK_DRV_FUNC(bfloat16,int8_t,float,float,float,bf16s4f32of32,bf16bf16f32obf16) +GEN_MAT_MUL_ACC_CHK_DRV_FUNC(bfloat16,int8_t,bfloat16,float,float,bf16s4f32obf16,bf16bf16f32obf16) -#define GEN_MAT_MUL_POST_OPS_CREATOR(C_DSCALE_type,C_type,DSCALE_type,BLAS_SFX) \ -aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \ +#define GEN_MAT_MUL_POST_OPS_CREATOR(C_DSCALE_type,C_type,DSCALE_type,BIAS_type,BLAS_SFX) \ +static inline aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \ ( \ dim_t m, \ dim_t n, \ - char* post_ops_str \ + char* post_ops_str, \ + char stor_order \ ) \ { \ + if ( ( ( post_ops_str == NULL ) || \ + ( strcmp( post_ops_str, "none" ) == 0 ) ) && \ + ( global_dscale_out == 'n' ) && ( global_pre_op == 'n' ) ) \ + { \ + return NULL; \ + } \ + \ aocl_post_op* post_ops = NULL; \ post_ops = ( aocl_post_op* ) malloc( sizeof( aocl_post_op ) ); \ \ - if ( ( post_ops == NULL ) && ( global_dscale_out == 'n' ) ) \ + if ( post_ops == NULL ) \ { \ return NULL; \ } \ \ - /* Only supporting 5 post ops at max for now.*/ \ - dim_t max_post_ops_seq_length = 5; \ + /* Only supporting 8 post ops at max for now.*/ \ + dim_t max_post_ops_seq_length = 8; \ post_ops->seq_vector = ( AOCL_POST_OP_TYPE* ) \ malloc \ ( \ @@ -911,37 +1032,89 @@ aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \ \ if ( post_ops->seq_vector == NULL ) \ { \ - free( post_ops ); \ - return NULL; \ + goto err_handler; \ } \ \ /* Parse post ops list.*/ \ dim_t cur_op_index = 0; \ /* Ensure the buffers that use NULL check in deinit code is properly set to NULL.*/ \ post_ops->eltwise = NULL; \ - post_ops->bias.bias = NULL; \ - post_ops->sum.scale_factor = NULL; \ - post_ops->sum.buff = NULL; \ - post_ops->sum.zero_point = NULL; \ - if ( post_ops_str != NULL ) \ + \ + /* Bench limitation: can only support 1 bias, but LPGEMM can support + * multiple bias post-ops. */ \ + post_ops->bias = NULL; \ + post_ops->bias = malloc( sizeof( aocl_post_op_bias ) ); \ + if ( post_ops->bias == NULL ) \ + { \ + goto err_handler; \ + } \ + ( post_ops->bias )->bias = NULL; \ + \ + /* Bench limitation: can only support 1 scale, but LPGEMM can support + * multiple scale post-ops. */ \ + post_ops->sum = NULL; \ + post_ops->sum = malloc( sizeof( aocl_post_op_sum ) ); \ + if ( post_ops->sum == NULL ) \ { \ - char* ops_tok = strtok(post_ops_str, ", " ); \ - bool is_relu = FALSE; \ - bool is_param_relu = FALSE; \ - bool is_gelu_tanh = FALSE; \ - bool is_gelu_erf = FALSE; \ - bool is_clip = FALSE; \ - dim_t activator_idx = 0; \ - dim_t clip_idx = 0; \ + goto err_handler; \ + } \ + ( post_ops->sum )->scale_factor = NULL; \ + ( post_ops->sum )->buff = NULL; \ + ( post_ops->sum )->zero_point = NULL; \ + ( post_ops->sum )->scale_factor_len = 0; \ + ( post_ops->sum )->zero_point_len = 0; \ + \ + /* Bench limitation: can only support 1 matrix add, but LPGEMM can support + * multiple scale post-ops. */ \ + post_ops->matrix_add = NULL; \ + post_ops->matrix_add = malloc( sizeof( aocl_post_op_matrix_add ) ); \ + if ( post_ops->matrix_add == NULL ) \ + { \ + goto err_handler; \ + } \ + ( post_ops->matrix_add )->matrix = NULL; \ + ( post_ops->matrix_add )->ldm = 0; \ +\ + /* Bench limitation: can only support 1 matrix mul, but LPGEMM can support + * multiple scale post-ops. */ \ + post_ops->matrix_mul = NULL; \ + post_ops->matrix_mul = malloc( sizeof( aocl_post_op_matrix_mul ) ); \ + if ( post_ops->matrix_mul == NULL ) \ + { \ + goto err_handler; \ + } \ + ( post_ops->matrix_mul )->matrix = NULL; \ + ( post_ops->matrix_mul )->ldm = 0; \ + \ + bool is_bias = FALSE; \ + bool is_relu = FALSE; \ + bool is_param_relu = FALSE; \ + bool is_gelu_tanh = FALSE; \ + bool is_gelu_erf = FALSE; \ + bool is_swish = FALSE; \ + bool is_clip = FALSE; \ + bool is_scalar_scale = FALSE; \ + bool is_scalar_zp = FALSE; \ + bool is_matrix_add = FALSE; \ + bool is_matrix_mul = FALSE; \ + dim_t activator_idx = 0; \ + dim_t clip_idx = 0; \ + \ + /* Post-Ops string parser. */ \ + num_eltwise = 0; /* Global variable, zero out for definied behavior. */\ + if ( strcmp( post_ops_str, "none" ) != 0 ) \ + { \ + char* ops_tok = strtok(post_ops_str, ", =" ); \ \ /* Ensure only one activator is used as an eltwise post-op.*/ \ bool is_activator_set = FALSE; \ - num_eltwise = 0; \ while ( ops_tok ) \ { \ + str_tolower( ops_tok ); \ if ( strcmp( ops_tok, "bias" ) == 0 ) \ { \ post_ops->seq_vector[cur_op_index] = BIAS; \ + is_bias = TRUE; \ cur_op_index++; \ } \ else if ( ( strcmp( ops_tok, "relu" ) == 0 ) && \ @@ -964,6 +1137,16 @@ aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \ activator_idx = cur_op_index; \ cur_op_index++; \ } \ + else if ( ( strcmp( ops_tok, "swish" ) == 0 ) && \ + ( is_activator_set == FALSE ) ) \ + { \ + post_ops->seq_vector[cur_op_index] = ELTWISE; \ + is_swish = TRUE; \ + is_activator_set = TRUE; \ + num_eltwise += 1; \ + activator_idx = cur_op_index; \ + cur_op_index++; \ + } \ else if ( ( strcmp( ops_tok, "gelu_tanh" ) == 0 ) && \ ( is_activator_set == FALSE ) ) \ { \ @@ -992,50 +1175,83 @@ aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \ clip_idx = cur_op_index; \ cur_op_index++; \ } \ - ops_tok = strtok( NULL, ", " ); \ - } \ + else if ( strcmp( ops_tok, "scale" ) == 0 ) \ + { \ + ops_tok = strtok( NULL, ", " ); \ + str_tolower( ops_tok ); \ + if ( ( strcmp( ops_tok, "scalar" ) == 0 ) || \ + ( strcmp( ops_tok, "s" ) == 0 ) ) \ + { \ + is_scalar_scale = TRUE; \ + } \ + } \ + else if ( strcmp( ops_tok, "zp" ) == 0 ) \ + { \ + ops_tok = strtok( NULL, ", " ); \ + str_tolower( ops_tok ); \ + if ( ( strcmp( ops_tok, "scalar" ) == 0 ) || \ + ( strcmp( ops_tok, "s" ) == 0 ) ) \ + { \ + is_scalar_zp = TRUE; \ + } \ + } \ + else if ( strcmp( ops_tok, "matrix_add" ) == 0 ) \ + { \ + post_ops->seq_vector[cur_op_index] = MATRIX_ADD; \ + is_matrix_add = TRUE; \ + cur_op_index++; \ + } \ + else if ( strcmp( ops_tok, "matrix_mul" ) == 0 ) \ + { \ + post_ops->seq_vector[cur_op_index] = MATRIX_MUL; \ + is_matrix_mul = TRUE; \ + cur_op_index++; \ + } \ \ - /* Allocate bias buffer, return early if alloc fails.*/ \ - post_ops->bias.bias = malloc( n * sizeof( C_type ) ); \ - if ( post_ops->bias.bias == NULL ) \ - { \ - free( post_ops->seq_vector ); \ - free( post_ops ); \ - return NULL; \ + ops_tok = strtok( NULL, ", =" ); \ } \ - GEN_FUNC_NAME(fill_array_post_ops_,C_type)( post_ops->bias.bias, n ); \ + } \ \ - post_ops->eltwise = malloc( num_eltwise * sizeof( aocl_post_op_eltwise ) ); \ - if ( post_ops->eltwise == NULL ) \ + if ( is_bias == TRUE ) \ + { \ + /* Allocate bias buffer, return early if alloc fails.*/ \ + ( post_ops->bias )->bias = malloc( n * sizeof( C_type ) ); \ + if ( ( post_ops->bias )->bias == NULL ) \ { \ - free( post_ops->bias.bias ); \ - free( post_ops->seq_vector ); \ - free( post_ops ); \ - return NULL; \ + goto err_handler; \ } \ + GEN_FUNC_NAME(fill_array_post_ops_,BIAS_type)( ( post_ops->bias )->bias, n ); \ + } \ \ - if ( num_eltwise > 0 ) \ + if ( num_eltwise > 0 ) \ + { \ + if ( num_eltwise > 1 ) \ { \ - if ( num_eltwise > 1 ) \ + if ( activator_idx < clip_idx ) \ { \ - if ( activator_idx < clip_idx ) \ - { \ - activator_idx = 0; \ - clip_idx = 1; \ - } \ - else \ - { \ - activator_idx = 1; \ - clip_idx = 0; \ - } \ + activator_idx = 0; \ + clip_idx = 1; \ } \ else \ { \ - activator_idx = 0; \ - clip_idx = 0; \ + activator_idx = 1; \ + clip_idx = 0; \ } \ } \ - /* Only one of relu,prelu,gelu_tanh,gelu_erf allowed as an activator.*/ \ + else \ + { \ + activator_idx = 0; \ + clip_idx = 0; \ + } \ + \ + post_ops->eltwise = malloc( num_eltwise * sizeof( aocl_post_op_eltwise ) ); \ + if ( post_ops->eltwise == NULL ) \ + { \ + goto err_handler; \ + } \ + \ + /* Only one of relu, prelu, swish, gelu_tanh, gelu_erf allowed as + * an activator. */ \ if ( is_relu == TRUE ) \ { \ ( post_ops->eltwise + activator_idx )->is_power_of_2 = FALSE; \ @@ -1048,11 +1264,30 @@ aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \ { \ ( post_ops->eltwise + activator_idx )->is_power_of_2 = FALSE; \ ( post_ops->eltwise + activator_idx )->scale_factor = NULL; \ - ( post_ops->eltwise + activator_idx )->algo.beta = NULL; \ + ( post_ops->eltwise + activator_idx )->algo.alpha = NULL; \ ( post_ops->eltwise + activator_idx )->algo.alpha = malloc( sizeof( C_type ) ); \ + if ( ( post_ops->eltwise + activator_idx )->algo.alpha == NULL ) \ + { \ + goto err_handler; \ + } \ *( ( C_type* ) ( post_ops->eltwise + activator_idx )->algo.alpha ) = ( C_type )6; \ + ( post_ops->eltwise + activator_idx )->algo.beta = NULL; \ ( post_ops->eltwise + activator_idx )->algo.algo_type = PRELU; \ } \ + if ( is_swish == TRUE ) \ + { \ + ( post_ops->eltwise + activator_idx )->is_power_of_2 = FALSE; \ + ( post_ops->eltwise + activator_idx )->scale_factor = NULL; \ + ( post_ops->eltwise + activator_idx )->algo.alpha = NULL; \ + ( post_ops->eltwise + activator_idx )->algo.alpha = malloc( sizeof( C_type ) ); \ + if ( ( post_ops->eltwise + activator_idx )->algo.alpha == NULL ) \ + { \ + goto err_handler; \ + } \ + *( ( C_type* ) ( post_ops->eltwise + activator_idx )->algo.alpha ) = ( C_type )2; \ + ( post_ops->eltwise + activator_idx )->algo.beta = NULL; \ + ( post_ops->eltwise + activator_idx )->algo.algo_type = SWISH; \ + } \ else if ( is_gelu_tanh == TRUE ) \ { \ ( post_ops->eltwise + activator_idx )->is_power_of_2 = FALSE; \ @@ -1073,8 +1308,18 @@ aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \ { \ ( post_ops->eltwise + clip_idx )->is_power_of_2 = FALSE; \ ( post_ops->eltwise + clip_idx )->scale_factor = NULL; \ + ( post_ops->eltwise + clip_idx )->algo.alpha = NULL; \ + ( post_ops->eltwise + clip_idx )->algo.beta = NULL; \ ( post_ops->eltwise + clip_idx )->algo.alpha = malloc( sizeof( C_type ) ); \ + if ( ( post_ops->eltwise + clip_idx )->algo.alpha == NULL ) \ + { \ + goto err_handler; \ + } \ ( post_ops->eltwise + clip_idx )->algo.beta = malloc( sizeof( C_type ) ); \ + if ( ( post_ops->eltwise + clip_idx )->algo.beta == NULL ) \ + { \ + goto err_handler; \ + } \ *( ( C_type* ) ( post_ops->eltwise + clip_idx )->algo.alpha ) = ( C_type ) ( -64 ); \ *( ( C_type* ) ( post_ops->eltwise + clip_idx )->algo.beta ) = ( C_type ) ( 23 ); \ ( post_ops->eltwise + clip_idx )->algo.algo_type = CLIP; \ @@ -1086,95 +1331,180 @@ aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \ post_ops->seq_vector[cur_op_index] = SCALE; \ cur_op_index++; \ \ - post_ops->sum.is_power_of_2 = FALSE; \ + ( post_ops->sum )->is_power_of_2 = FALSE; \ if ( global_dscale_out == 'y' ) \ { \ + dim_t n_scale = n; \ + if ( is_scalar_scale == TRUE ) \ + { \ + n_scale = 1; \ + } \ + \ + dim_t n_zp = n; \ + if ( is_scalar_zp == TRUE ) \ + { \ + n_zp = 1; \ + } \ + \ /* Allocate scale buffer, return early if alloc fails.*/ \ - post_ops->sum.scale_factor = malloc( n * sizeof( DSCALE_type ) ); \ - post_ops->sum.zero_point = malloc( n * sizeof( C_DSCALE_type ) ); \ - if ( ( post_ops->sum.scale_factor == NULL ) || \ - ( post_ops->sum.zero_point == NULL ) ) \ + ( post_ops->sum )->scale_factor = malloc( n_scale * sizeof( DSCALE_type ) ); \ + if ( ( post_ops->sum )->scale_factor == NULL ) \ { \ - free ( post_ops->eltwise ); \ - free ( post_ops->bias.bias ); \ - free( post_ops->seq_vector ); \ - if ( post_ops->sum.zero_point != NULL ) \ - { \ - free( post_ops->sum.zero_point ); \ - } \ - if ( post_ops->sum.scale_factor != NULL ) \ - { \ - free( post_ops->sum.scale_factor ); \ - } \ - free( post_ops ); \ - return NULL; \ + goto err_handler; \ } \ + ( post_ops->sum )->zero_point = malloc( n_zp * sizeof( C_DSCALE_type ) ); \ + if ( ( post_ops->sum )->zero_point == NULL ) \ + { \ + goto err_handler; \ + } \ + \ /* Fill scale factor and zero points.*/ \ - DSCALE_type* temp_dscale_ptr = ( DSCALE_type* )post_ops->sum.scale_factor; \ - C_DSCALE_type* temp_dzero_point_ptr = ( C_DSCALE_type* )post_ops->sum.zero_point; \ - for ( dim_t i = 0; i < n; ++i ) \ + DSCALE_type* temp_dscale_ptr = ( DSCALE_type* )( post_ops->sum )->scale_factor; \ + for ( dim_t i = 0; i < n_scale; ++i ) \ + { \ + temp_dscale_ptr[i] = ( ( DSCALE_type )2 ); \ + } \ + ( post_ops->sum )->scale_factor_len = n_scale; \ + \ + C_DSCALE_type* temp_dzero_point_ptr = ( C_DSCALE_type* )( post_ops->sum )->zero_point; \ + for ( dim_t i = 0; i < n_zp; ++i ) \ { \ - temp_dscale_ptr[i] = ( ( DSCALE_type )1 )/ ( ( DSCALE_type )1000 ); \ - temp_dzero_point_ptr[i] = (C_DSCALE_type)( i % 126 ); \ + temp_dzero_point_ptr[i] = (C_DSCALE_type)( ( i + 9 ) % 126 ); \ } \ + ( post_ops->sum )->zero_point_len = n_zp; \ + } \ + } \ + \ + if ( is_matrix_add == TRUE ) \ + { \ + /* Allocate add matrix buffer, return early if alloc fails.*/ \ + dim_t ele_dsize = 0; \ + if ( global_dscale_out == 'y' ) \ + { \ + ele_dsize = sizeof( C_DSCALE_type ); \ + } \ + else \ + { \ + ele_dsize = sizeof( C_type ); \ + } \ + ( post_ops->matrix_add )->matrix = malloc( m * n * ele_dsize ); \ + if ( ( post_ops->matrix_add )->matrix == NULL ) \ + { \ + goto err_handler; \ + } \ + if ( global_dscale_out == 'y' ) \ + { \ + GEN_FUNC_NAME(fill_array_,C_DSCALE_type)( ( post_ops->matrix_add )->matrix, ( m * n ) ); \ + } \ + else \ + { \ + GEN_FUNC_NAME(fill_array_,C_type)( ( post_ops->matrix_add )->matrix, ( m * n ) ); \ + } \ + if ( ( stor_order == 'C' ) || ( stor_order == 'c' ) ) \ + { \ + ( post_ops->matrix_add )->ldm = m; \ + } \ + else \ + { \ + ( post_ops->matrix_add )->ldm = n; \ + } \ + } \ + \ + if ( is_matrix_mul == TRUE ) \ + { \ + /* Allocate mul matrix buffer, return early if alloc fails.*/ \ + dim_t ele_dsize = 0; \ + if ( global_dscale_out == 'y' ) \ + { \ + ele_dsize = sizeof( C_DSCALE_type ); \ + } \ + else \ + { \ + ele_dsize = sizeof( C_type ); \ + } \ + ( post_ops->matrix_mul )->matrix = malloc( m * n * ele_dsize ); \ + if ( ( post_ops->matrix_mul )->matrix == NULL ) \ + { \ + goto err_handler; \ + } \ + if ( global_dscale_out == 'y' ) \ + { \ + GEN_FUNC_NAME(fill_array_,C_DSCALE_type)( ( post_ops->matrix_mul )->matrix, ( m * n ) ); \ + } \ + else \ + { \ + GEN_FUNC_NAME(fill_array_,C_type)( ( post_ops->matrix_mul )->matrix, ( m * n ) ); \ + } \ + if ( ( stor_order == 'C' ) || ( stor_order == 'c' ) ) \ + { \ + ( post_ops->matrix_mul )->ldm = m; \ + } \ + else \ + { \ + ( post_ops->matrix_mul )->ldm = n; \ } \ } \ \ post_ops->seq_length = cur_op_index; \ + \ + /* Setup the pre_ops struct */ \ + post_ops->pre_ops = NULL; \ + if ( global_pre_op == 'y' ) \ + { \ + post_ops->pre_ops = malloc( sizeof( aocl_pre_op ) ); \ + if ( post_ops->pre_ops == NULL ) { goto err_handler; } \ + \ + ( post_ops->pre_ops )->b_zp = malloc( sizeof( aocl_pre_op_zp ) ); \ + if ( ( post_ops->pre_ops )->b_zp == NULL ) { goto err_handler; } \ + \ + ( post_ops->pre_ops )->b_scl = malloc( sizeof( aocl_pre_op_sf ) ); \ + if ( ( post_ops->pre_ops )->b_scl == NULL ) { goto err_handler; } \ + \ + /* Only int8_t zero point supported in pre-ops. */ \ + /* Not handled in 4x64 bf16s4f32of32 kernel */ \ + ( ( post_ops->pre_ops )->b_zp )->zero_point = malloc( n * sizeof( int8_t ) ); \ + if ( ( ( post_ops->pre_ops )->b_zp )->zero_point == NULL ) { goto err_handler; } \ + for ( dim_t i = 0; i < n; ++i ) \ + { \ + ( ( int8_t* )( ( post_ops->pre_ops )->b_zp )->zero_point )[i] = ( int8_t )( 0 ); \ + } \ + ( ( post_ops->pre_ops )->b_zp )->zero_point_len = n; \ +\ + /* Only float scale factor supported in pre-ops. */ \ + ( ( post_ops->pre_ops )->b_scl )->scale_factor = malloc( n * sizeof( float ) ); \ + if ( ( ( post_ops->pre_ops )->b_scl )->scale_factor == NULL ) { goto err_handler; } \ + for ( dim_t i = 0; i < n; ++i ) \ + { \ + ( ( float* )( ( post_ops->pre_ops )->b_scl )->scale_factor )[i] = ( ( float )( ( i + 1 ) % 5 ) ); \ + } \ + ( ( post_ops->pre_ops )->b_scl )->scale_factor_len = n; \ + \ + ( post_ops->pre_ops )->seq_length = 1; \ + } \ \ return post_ops; \ + \ + err_handler: \ + lpgemm_destroy_post_ops_struct( post_ops ); \ + return NULL; \ } \ -GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,int16_t,float,u8s8s16os16) -GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,int32_t,float,u8s8s32os32) -GEN_MAT_MUL_POST_OPS_CREATOR(bfloat16,float,float,bf16bf16f32of32) -GEN_MAT_MUL_POST_OPS_CREATOR(bfloat16,float,float,f32f32f32of32) -GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,int32_t,float,s8s8s32os32) -GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,int16_t,float,s8s8s16os16) - -void lpgemm_destroy_post_ops_struct( aocl_post_op* post_ops ) -{ - if ( post_ops == NULL ) - { - return; - } - - if ( post_ops->eltwise != NULL ) - { - for ( dim_t i = 0; i < num_eltwise; ++i ) - { - if ( ( post_ops->eltwise + i )->algo.alpha != NULL ) - { - free( ( post_ops->eltwise + i )->algo.alpha ); - } - if ( ( post_ops->eltwise + i )->algo.beta != NULL ) - { - free( ( post_ops->eltwise + i )->algo.beta ); - } - } - free( post_ops->eltwise ); - } - if ( post_ops->sum.scale_factor != NULL ) - { - free( post_ops->sum.scale_factor ); - } - if ( post_ops->sum.zero_point != NULL ) - { - free( post_ops->sum.zero_point ); - } - if ( post_ops->bias.bias != NULL ) - { - free( post_ops->bias.bias ); - } - if( post_ops->seq_vector != NULL ) - { - free( post_ops->seq_vector ); - } - - free( post_ops ); -} - -#define GEN_MAT_MUL_BENCH_MAIN_FUNC(A_type, B_type, C_type, Sum_type, BLAS_SFX, REORDER_SFX) \ +GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,int16_t,float,int16_t,u8s8s16os16) +GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,int32_t,float,int32_t,u8s8s32os32) +GEN_MAT_MUL_POST_OPS_CREATOR(bfloat16,float,float,bfloat16,bf16bf16f32of32) +GEN_MAT_MUL_POST_OPS_CREATOR(bfloat16,float,float,float,f32f32f32of32) +GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,int32_t,float,int32_t,s8s8s32os32) +GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,int16_t,float,int16_t,s8s8s16os16) + +// Hack to fix compiler errors. +#define GET_B_TYPE_bf16bf16f32of32 bfloat16 +#define GET_B_TYPE_u8s8s16os16 int8_t +#define GET_B_TYPE_u8s8s32os32 int8_t +#define GET_B_TYPE_f32f32f32of32 float +#define GET_B_TYPE_s8s8s32os32 int8_t +#define GET_B_TYPE_s8s8s16os16 int8_t + +#define GEN_MAT_MUL_BENCH_MAIN_FUNC(A_type, B_type, C_type, Sum_type, BLAS_SFX, REORDER_SFX, INT4_REORDER_SFX) \ void mat_mul_bench_main_ ## BLAS_SFX \ ( \ FILE* fin, \ @@ -1184,13 +1514,14 @@ void mat_mul_bench_main_ ## BLAS_SFX \ char transb, \ char op_a, \ char op_b, \ - int32_t m, \ - int32_t n, \ - int32_t k, \ - int32_t stride_a, \ - int32_t stride_b, \ - int32_t stride_c, \ - char* post_ops_str \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + dim_t stride_a, \ + dim_t stride_b, \ + dim_t stride_c, \ + char* post_ops_str, \ + bool int4_testing /* Workaround to enable int4 B matrix testing. */\ ) \ { \ int32_t n_repeats = bli_max( 30, bli_min(( 3e10 / ( ( int64_t )m * n * k )), 1000 )); \ @@ -1199,9 +1530,9 @@ void mat_mul_bench_main_ ## BLAS_SFX \ n_repeats = global_n_repeat; \ } \ \ - int32_t size_A = 0; \ - int32_t size_B = 0; \ - int32_t size_C = 0; \ + dim_t size_A = 0; \ + dim_t size_B = 0; \ + dim_t size_C = 0; \ if( ( stor_order == 'r' ) || ( stor_order == 'R' ) ) \ { \ size_A = ( ( transa == 'n' ) || ( transa == 'N' ) ) ? m * stride_a : k * stride_a; \ @@ -1218,7 +1549,14 @@ void mat_mul_bench_main_ ## BLAS_SFX \ GEN_FUNC_NAME(fill_array_,A_type)(a, size_A ); \ \ B_type* b = ( B_type* ) lpgemm_malloc( sizeof( B_type ) * size_B ); \ - GEN_FUNC_NAME(fill_array_,B_type)(b, size_B ); \ + if ( int4_testing == FALSE ) \ + { \ + GEN_FUNC_NAME(fill_array_,B_type)(b, size_B ); \ + } \ + else \ + { \ + GEN_FUNC_NAME(fill_array_,int4_c_t)(b, size_B); \ + } \ \ C_type* c = ( C_type* ) lpgemm_malloc( sizeof( C_type ) * size_C ); \ \ @@ -1227,7 +1565,7 @@ void mat_mul_bench_main_ ## BLAS_SFX \ if ( bench_mode == 'a' ) \ { \ GEN_FUNC_NAME(fill_array_,C_type)( c, ( size_C ) ); \ - GEN_FUNC_NAME(fill_array_,C_type)( c_ref, ( size_C ) ); \ + memcpy(c_ref, c , (size_C * sizeof(C_type))); \ } \ else \ { \ @@ -1247,12 +1585,14 @@ void mat_mul_bench_main_ ## BLAS_SFX \ n_repeats = 1; \ alpha = 2; \ beta = 9; \ - } \ + } \ \ aocl_post_op* post_op = NULL; \ - if ( ( post_ops_str != NULL ) || ( global_dscale_out == 'y' ) ) \ + if ( ( ( post_ops_str != NULL ) && \ + ( strcmp( post_ops_str, "none" ) != 0 ) ) || \ + ( global_dscale_out == 'y' ) || ( global_pre_op == 'y' ) ) \ { \ - post_op = GEN_FUNC_NAME(lpgemm_create_post_ops_struct_,REORDER_SFX)( m, n, post_ops_str ); \ + post_op = GEN_FUNC_NAME(lpgemm_create_post_ops_struct_,REORDER_SFX)( m, n, post_ops_str, stor_order ); \ if ( post_op == NULL ) \ { \ printf(" post op struct allocation failure, returning.\n"); \ @@ -1276,12 +1616,29 @@ void mat_mul_bench_main_ ## BLAS_SFX \ } \ else if ( ( op_b == 'r' ) || ( op_b == 'R' ) ) \ { \ + B_type* b_reorder = NULL; \ /* Reorder B.*/ \ - siz_t b_reorder_buf_siz_req = \ - GEN_FUNC_NAME(aocl_get_reorder_buf_size_,REORDER_SFX)( stor_order, transb, 'B', k, n ); \ + if ( int4_testing == FALSE ) \ + { \ + siz_t b_reorder_buf_siz_req = \ + GEN_FUNC_NAME(aocl_get_reorder_buf_size_,REORDER_SFX)( stor_order, transb, 'B', k, n ); \ + \ + b_reorder = ( B_type* ) lpgemm_malloc( b_reorder_buf_siz_req ); \ + GEN_FUNC_NAME(aocl_reorder_,REORDER_SFX)( stor_order, transb, 'B', \ + ( GET_B_TYPE_ ## REORDER_SFX * )b, \ + ( GET_B_TYPE_ ## REORDER_SFX * )b_reorder, \ + k, n, stride_b ); \ + } \ + /* It has to be ensured, for now, only int4 testing takes else path. */ \ + else \ + { \ + siz_t b_reorder_buf_siz_req = \ + GEN_FUNC_NAME(aocl_get_reorder_buf_size_,INT4_REORDER_SFX)( stor_order, transb, 'B', k, n ); \ \ - B_type* b_reorder = ( B_type* ) lpgemm_malloc( b_reorder_buf_siz_req ); \ - GEN_FUNC_NAME(aocl_reorder_,REORDER_SFX)( stor_order, transb, 'B', b, b_reorder, k, n, stride_b ); \ + b_reorder = ( B_type* ) lpgemm_malloc( b_reorder_buf_siz_req ); \ + GEN_FUNC_NAME(aocl_reorder_,INT4_REORDER_SFX)( stor_order, transb, 'B', \ + ( int8_t* )b, ( int8_t* )b_reorder, k, n, stride_b ); \ + } \ \ GEN_FUNC_NAME(mat_mul_bench_driver_,BLAS_SFX) \ ( \ @@ -1307,7 +1664,7 @@ void mat_mul_bench_main_ ## BLAS_SFX \ beta, \ c, stride_c, \ c_ref, stride_c, \ - post_op \ + post_op, int4_testing \ ); \ } \ \ @@ -1319,18 +1676,21 @@ void mat_mul_bench_main_ ## BLAS_SFX \ lpgemm_free( c_ref ); \ } \ -GEN_MAT_MUL_BENCH_MAIN_FUNC(bfloat16,bfloat16,float,float,bf16bf16f32of32,bf16bf16f32of32) -GEN_MAT_MUL_BENCH_MAIN_FUNC(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16,bf16bf16f32of32) -GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t,int8_t,int16_t,int16_t,u8s8s16os16,u8s8s16os16) -GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t,int8_t,int8_t,int16_t,u8s8s16os8,u8s8s16os16) -GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t,int8_t,uint8_t,int16_t,u8s8s16ou8,u8s8s16os16) -GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t,int8_t,int32_t,int32_t,u8s8s32os32,u8s8s32os32) -GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8,u8s8s32os32) -GEN_MAT_MUL_BENCH_MAIN_FUNC(float,float,float,float,f32f32f32of32,f32f32f32of32) -GEN_MAT_MUL_BENCH_MAIN_FUNC(int8_t,int8_t,int32_t,int32_t,s8s8s32os32,s8s8s32os32) -GEN_MAT_MUL_BENCH_MAIN_FUNC(int8_t,int8_t,int8_t,int32_t,s8s8s32os8,s8s8s32os32) -GEN_MAT_MUL_BENCH_MAIN_FUNC(int8_t,int8_t,int16_t,int16_t,s8s8s16os16,s8s8s16os16) -GEN_MAT_MUL_BENCH_MAIN_FUNC(int8_t,int8_t,int8_t,int16_t,s8s8s16os8,s8s8s16os16) +GEN_MAT_MUL_BENCH_MAIN_FUNC(bfloat16,bfloat16,float,float,bf16bf16f32of32,bf16bf16f32of32,bf16s4f32of32) +GEN_MAT_MUL_BENCH_MAIN_FUNC(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16,bf16bf16f32of32,bf16s4f32of32) +GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t,int8_t,int16_t,int16_t,u8s8s16os16,u8s8s16os16,u8s4s32os32) +GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t,int8_t,int8_t,int16_t,u8s8s16os8,u8s8s16os16,u8s4s32os32) +GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t,int8_t,uint8_t,int16_t,u8s8s16ou8,u8s8s16os16,u8s4s32os32) +GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t,int8_t,int32_t,int32_t,u8s8s32os32,u8s8s32os32,u8s4s32os32) +GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8,u8s8s32os32,u8s4s32os32) +GEN_MAT_MUL_BENCH_MAIN_FUNC(float,float,float,float,f32f32f32of32,f32f32f32of32,bf16s4f32of32) +GEN_MAT_MUL_BENCH_MAIN_FUNC(int8_t,int8_t,int32_t,int32_t,s8s8s32os32,s8s8s32os32,u8s4s32os32) +GEN_MAT_MUL_BENCH_MAIN_FUNC(int8_t,int8_t,int8_t,int32_t,s8s8s32os8,s8s8s32os32,u8s4s32os32) +GEN_MAT_MUL_BENCH_MAIN_FUNC(int8_t,int8_t,int16_t,int16_t,s8s8s16os16,s8s8s16os16,u8s4s32os32) +GEN_MAT_MUL_BENCH_MAIN_FUNC(int8_t,int8_t,int8_t,int16_t,s8s8s16os8,s8s8s16os16,u8s4s32os32) +GEN_MAT_MUL_BENCH_MAIN_FUNC(bfloat16,int8_t,float,float,bf16s4f32of32,bf16bf16f32of32,bf16s4f32of32) +GEN_MAT_MUL_BENCH_MAIN_FUNC(bfloat16,int8_t,bfloat16,float,bf16s4f32obf16,bf16bf16f32of32,bf16s4f32of32) + int main( int argc, char** argv ) { FILE* fin = NULL; @@ -1364,7 +1724,7 @@ int main( int argc, char** argv ) " 1. u8s8s32os32 -d s8 = u8s8s32os8.\n" \ " 2. u8s8s16os16 -d s8 = u8s8s16os8.\n" \ " 3. u8s8s16os16 -d u8 = u8s8s16ou8.\n" \ - " 4. bf16bf16f32obf32 -d bf16 = bf16bf16f32obf16.\n" \ + " 4. bf16bf16f32of32 -d bf16 = bf16bf16f32obf16.\n" \ " 5. s8s8s32os32 -d s8 = s8s8s32os8.\n" \ " 6. s8s8s16os16 -d s8 = s8s8s16os8.\n" \ " Example: ./bench_lpgemm -m a -n 2 -o bias,relu -d bf16 -i input.txt\n" \ @@ -1373,25 +1733,38 @@ int main( int argc, char** argv ) } char* file_name = NULL; - char post_ops_str[50]; - char* post_ops_str_dest = NULL; //Strtok is used to parse, need to maintain a copy. - char dscale_type_str[10]; + +#define GEMM_TYPE_STR_LEN 24 + char gemm_type_str[GEMM_TYPE_STR_LEN]; + +#define POST_OPS_STR_LEN 104 + char post_ops_str[POST_OPS_STR_LEN]; + char post_ops_str_dest[POST_OPS_STR_LEN]; //Strtok is used to parse, need to maintain a copy. + +#define OPS_INPUT_STR_LEN 128 + char ops_input_str[OPS_INPUT_STR_LEN]; // Parse CLI arguments. - opterr = 0; - int opt_val; - while ( ( opt_val = getopt( argc, argv, "i:m:n:" ) ) != -1 ) + getopt_t state; + // Initialize the state for running bli_getopt(). Here, 0 is the + // initial value for opterr, which suppresses error messages. + bli_getopt_init_state( 0, &state ); + + int opt; + // Process all option arguments until we get a -1, which means we're done. + while( (opt = bli_getopt( argc, argv, "i:m:n:", &state )) != -1 ) { - switch ( opt_val ) + char opt_ch = ( char )opt; + switch( opt_ch ) { case 'i': - file_name = optarg; + file_name = state.optarg; break; case 'm': - bench_mode = ( ( ( *optarg ) == 'a' ) || ( ( *optarg ) == 'p' ) ) ? ( *optarg ) : 'p'; + bench_mode = ( ( ( *state.optarg ) == 'a' ) || ( ( *state.optarg ) == 'p' ) ) ? ( *state.optarg ) : 'p'; break; case 'n': - global_n_repeat = ( atoi( optarg ) > 0 ) ? atoi( optarg ) : 0; + global_n_repeat = ( atoi( state.optarg ) > 0 ) ? atoi( state.optarg ) : 0; break; default: break; @@ -1424,15 +1797,14 @@ int main( int argc, char** argv ) fout = fopen( "lpgemm_accuracy_test_failures.txt", "w" ); - char op_type_char; char op_a, op_b; char stor_order; char transa, transb; - int32_t m, n, k; - int32_t stride_a, stride_b, stride_c; + dim_t m, n, k; + dim_t stride_a, stride_b, stride_c; const dim_t len_list_omp_cores_for_testing = 2; - const dim_t list_omp_cores_for_testing[2] = { 80, 1 }; + const dim_t list_omp_cores_for_testing[2] = { 1, 64 }; dim_t core_index = 0; bool can_run = TRUE; @@ -1466,220 +1838,268 @@ int main( int argc, char** argv ) } // Input format: data_type stor_type pack/reorder m n k lda ldb ldc - while ( fscanf( fin, "%c %s %c %c %c %c %c %d %d %d %d %d %d %s\n", - &op_type_char, dscale_type_str, &stor_order, &transa, &transb, &op_a, &op_b, &m, &n, &k, - &stride_a, &stride_b, &stride_c, post_ops_str ) == 14 ) + while ( fscanf( fin, "%c %c %c %c %c " INT_FS INT_FS INT_FS + INT_FS INT_FS INT_FS " %s\n", &stor_order, &transa, + &transb, &op_a, &op_b, &m, &n, &k, &stride_a, + &stride_b, &stride_c, ops_input_str ) == 12 ) { + char* ops_tok = strtok( ops_input_str, ":" ); + strncpy( gemm_type_str, ops_tok, GEMM_TYPE_STR_LEN - 1 ); + str_tolower( gemm_type_str ); \ + + ops_tok = strtok( NULL, "" ); + if ( ops_tok != NULL ) + { + strncpy( post_ops_str, ops_tok, POST_OPS_STR_LEN - 1 ); + } + else + { + strncpy( post_ops_str, "none", POST_OPS_STR_LEN - 1 ); + } + stor_order = ( ( stor_order == 'r' ) || ( stor_order == 'R' ) || ( stor_order == 'c' ) || ( stor_order == 'C' ) ) ? stor_order : 'r'; - if ( strcmp( post_ops_str, "none" ) != 0 ) + if ( ( strcmp( gemm_type_str, "u8s8s32os32" ) == 0 ) || + ( strcmp( gemm_type_str, "*" ) == 0 ) ) + { + // Copy the original post op str to a temp string buffer. + // Done so that strtok can be applied on the same (strtok + // is a destructive parser. + strncpy( post_ops_str_dest, post_ops_str, POST_OPS_STR_LEN ); + global_dscale_out = 'n'; + global_pre_op = 'n'; + GEN_FUNC_NAME(mat_mul_bench_main_,u8s8s32os32) + ( + fin, fout, stor_order, transa, transb, op_a, op_b, + m, n, k, stride_a, stride_b, stride_c, + post_ops_str_dest, FALSE + ); + } + if ( ( strcmp( gemm_type_str, "u8s8s32os8" ) == 0 ) || + ( strcmp( gemm_type_str, "*" ) == 0 ) ) { - post_ops_str_dest = ( char* )malloc \ - ( ( strlen( post_ops_str) + 1 )* sizeof( char ) ); - strcpy( post_ops_str_dest, post_ops_str ); + strncpy( post_ops_str_dest, post_ops_str, POST_OPS_STR_LEN ); + global_dscale_out = 'y'; + global_pre_op = 'n'; + DSCALE_CLIP_MIN = -128; + DSCALE_CLIP_MAX = +127; + GEN_FUNC_NAME(mat_mul_bench_main_,u8s8s32os8) + ( + fin, fout, stor_order, transa, transb, op_a, op_b, + m, n, k, stride_a, stride_b, stride_c, + post_ops_str_dest, FALSE + ); } - - if ( ( op_type_char == 'i' ) || ( op_type_char == 'I' ) ) + if ( ( strcmp( gemm_type_str, "u8s4s32os32" ) == 0 ) || + ( strcmp( gemm_type_str, "*" ) == 0 ) ) { - if ( ( strcmp( dscale_type_str, "S32" ) == 0 ) || - ( strcmp( dscale_type_str, "s32" ) == 0 ) ) + // Copy the original post op str to a temp string buffer. + // Done so that strtok can be applied on the same (strtok + // is a destructive parser. + strncpy( post_ops_str_dest, post_ops_str, POST_OPS_STR_LEN ); + global_dscale_out = 'n'; + global_pre_op = 'n'; + + if ( ( op_b != 'r' ) && ( op_b != 'R' ) ) + { + printf("Int4 B matrix only permitted if B reodering " + "is enabled.\n"); + } + else { - global_dscale_out = 'n'; GEN_FUNC_NAME(mat_mul_bench_main_,u8s8s32os32) ( fin, fout, stor_order, transa, transb, op_a, op_b, m, n, k, stride_a, stride_b, stride_c, - post_ops_str_dest + post_ops_str_dest, TRUE ); } - else - { - if ( ( strcmp( dscale_type_str, "S8" ) == 0 ) || - ( strcmp( dscale_type_str, "s8" ) == 0 ) ) - { - global_dscale_out = 'y'; - DSCALE_CLIP_MIN = -128; - DSCALE_CLIP_MAX = +127; - GEN_FUNC_NAME(mat_mul_bench_main_,u8s8s32os8) - ( - fin, fout, stor_order, transa, transb, op_a, op_b, - m, n, k, stride_a, stride_b, stride_c, - post_ops_str_dest - ); - } - else - { - printf("Downscale type not supported.\n"); - } - } } - else if ( ( op_type_char == 'f' ) || ( op_type_char == 'F' ) ) + if ( ( strcmp( gemm_type_str, "f32f32f32of32" ) == 0 ) || + ( strcmp( gemm_type_str, "*" ) == 0 ) ) { + strncpy( post_ops_str_dest, post_ops_str, POST_OPS_STR_LEN ); global_dscale_out = 'n'; + global_pre_op = 'n'; GEN_FUNC_NAME(mat_mul_bench_main_,f32f32f32of32) ( fin, fout, stor_order, transa, transb, op_a, op_b, m, n, k, stride_a, stride_b, stride_c, - post_ops_str_dest + post_ops_str_dest, FALSE + ); + } + if ( ( strcmp( gemm_type_str, "u8s8s16os16" ) == 0 ) || + ( strcmp( gemm_type_str, "*" ) == 0 ) ) + { + strncpy( post_ops_str_dest, post_ops_str, POST_OPS_STR_LEN ); + global_dscale_out = 'n'; + global_pre_op = 'n'; + GEN_FUNC_NAME(mat_mul_bench_main_,u8s8s16os16) + ( + fin, fout, stor_order, transa, transb, op_a, op_b, + m, n, k, stride_a, stride_b, stride_c, + post_ops_str_dest, FALSE + ); + } + if ( ( strcmp( gemm_type_str, "u8s8s16os8" ) == 0 ) || + ( strcmp( gemm_type_str, "*" ) == 0 ) ) + { + strncpy( post_ops_str_dest, post_ops_str, POST_OPS_STR_LEN ); + global_dscale_out = 'y'; + global_pre_op = 'n'; + DSCALE_CLIP_MIN = -128; + DSCALE_CLIP_MAX = +127; + GEN_FUNC_NAME(mat_mul_bench_main_,u8s8s16os8) + ( + fin, fout, stor_order, transa, transb, op_a, op_b, + m, n, k, stride_a, stride_b, stride_c, + post_ops_str_dest, FALSE + ); + } + if ( ( strcmp( gemm_type_str, "u8s8s16ou8" ) == 0 ) || + ( strcmp( gemm_type_str, "*" ) == 0 ) ) + { + strncpy( post_ops_str_dest, post_ops_str, POST_OPS_STR_LEN ); + global_dscale_out = 'y'; + global_pre_op = 'n'; + DSCALE_CLIP_MIN = 0; + DSCALE_CLIP_MAX = +255; + GEN_FUNC_NAME(mat_mul_bench_main_,u8s8s16ou8) + ( + fin, fout, stor_order, transa, transb, op_a, op_b, + m, n, k, stride_a, stride_b, stride_c, + post_ops_str_dest, FALSE + ); + } + if ( ( strcmp( gemm_type_str, "bf16bf16f32of32" ) == 0 ) || + ( strcmp( gemm_type_str, "*" ) == 0 ) ) + { + strncpy( post_ops_str_dest, post_ops_str, POST_OPS_STR_LEN ); + global_dscale_out = 'n'; + global_pre_op = 'n'; + GEN_FUNC_NAME(mat_mul_bench_main_, bf16bf16f32of32) + ( + fin, fout, stor_order, transa, transb, op_a, op_b, + m, n, k, stride_a, stride_b, stride_c, + post_ops_str_dest, FALSE + ); + } + if ( ( strcmp( gemm_type_str, "bf16bf16f32obf16" ) == 0 ) || + ( strcmp( gemm_type_str, "*" ) == 0 ) ) + { + strncpy( post_ops_str_dest, post_ops_str, POST_OPS_STR_LEN ); + global_dscale_out = 'y'; + global_pre_op = 'n'; + GEN_FUNC_NAME(mat_mul_bench_main_, bf16bf16f32obf16) + ( + fin, fout, stor_order, transa, transb, op_a, op_b, + m, n, k, stride_a, stride_b, stride_c, + post_ops_str_dest, FALSE ); } - else if ((op_type_char == 's') || (op_type_char == 'S')) + if ( strcmp( gemm_type_str, "bf16s4f32of32" ) == 0 ) { - if ( ( strcmp( dscale_type_str, "S16" ) == 0 ) || - ( strcmp( dscale_type_str, "s16" ) == 0 ) ) + strncpy( post_ops_str_dest, post_ops_str, POST_OPS_STR_LEN ); + global_dscale_out = 'n'; + global_pre_op = 'y'; + + if ( ( op_b != 'r' ) && ( op_b != 'R' ) ) + { + printf("Int4 B matrix only permitted if B reodering " + "is enabled.\n"); + } + else { - global_dscale_out = 'n'; - GEN_FUNC_NAME(mat_mul_bench_main_,u8s8s16os16) + GEN_FUNC_NAME(mat_mul_bench_main_, bf16s4f32of32) ( fin, fout, stor_order, transa, transb, op_a, op_b, m, n, k, stride_a, stride_b, stride_c, - post_ops_str_dest + post_ops_str_dest, TRUE ); } - else - { - if ( ( strcmp( dscale_type_str, "S8" ) == 0 ) || - ( strcmp( dscale_type_str, "s8" ) == 0 ) ) - { - global_dscale_out = 'y'; - DSCALE_CLIP_MIN = -128; - DSCALE_CLIP_MAX = +127; - GEN_FUNC_NAME(mat_mul_bench_main_,u8s8s16os8) - ( - fin, fout, stor_order, transa, transb, op_a, op_b, - m, n, k, stride_a, stride_b, stride_c, - post_ops_str_dest - ); - } - else if ( ( strcmp( dscale_type_str, "U8" ) == 0 ) || - ( strcmp( dscale_type_str, "u8" ) == 0 ) ) - { - global_dscale_out = 'y'; - DSCALE_CLIP_MIN = 0; - DSCALE_CLIP_MAX = +255; - GEN_FUNC_NAME(mat_mul_bench_main_,u8s8s16ou8) - ( - fin, fout, stor_order, transa, transb, op_a, op_b, - m, n, k, stride_a, stride_b, stride_c, - post_ops_str_dest - ); - } - else - { - printf("Downscale type not supported.\n"); - } - } } - else if ((op_type_char == 'b') || (op_type_char == 'B')) + if ( strcmp( gemm_type_str, "bf16s4f32obf16" ) == 0 ) { - if ( ( strcmp( dscale_type_str, "F32" ) == 0 ) || - ( strcmp( dscale_type_str, "f32" ) == 0 ) ) + strncpy( post_ops_str_dest, post_ops_str, POST_OPS_STR_LEN ); + global_dscale_out = 'y'; + global_pre_op = 'y'; + + if ( ( op_b != 'r' ) && ( op_b != 'R' ) ) { - global_dscale_out = 'n'; - GEN_FUNC_NAME(mat_mul_bench_main_, bf16bf16f32of32) - ( - fin, fout, stor_order, transa, transb, op_a, op_b, - m, n, k, stride_a, stride_b, stride_c, - post_ops_str_dest - ); + printf("Int4 B matrix only permitted if B reodering " + "is enabled.\n"); } - else if ( ( strcmp( dscale_type_str, "BF16" ) == 0 ) || - ( strcmp( dscale_type_str, "bf16" ) == 0 ) ) + else { - global_dscale_out = 'y'; - GEN_FUNC_NAME(mat_mul_bench_main_, bf16bf16f32obf16) + GEN_FUNC_NAME(mat_mul_bench_main_, bf16s4f32obf16) ( fin, fout, stor_order, transa, transb, op_a, op_b, m, n, k, stride_a, stride_b, stride_c, - post_ops_str_dest + post_ops_str_dest, TRUE ); } - else - { - printf("Downscale type not supported.\n"); - } } - else if ( ( op_type_char == 'u' ) || ( op_type_char == 'U' ) ) + if ( ( strcmp( gemm_type_str, "s8s8s32os32" ) == 0 ) || + ( strcmp( gemm_type_str, "*" ) == 0 ) ) { - if ( ( strcmp( dscale_type_str, "S32" ) == 0 ) || - ( strcmp( dscale_type_str, "s32" ) == 0 ) ) - { - global_dscale_out = 'n'; - GEN_FUNC_NAME(mat_mul_bench_main_,s8s8s32os32) - ( - fin, fout, stor_order, transa, transb, op_a, op_b, - m, n, k, stride_a, stride_b, stride_c, - post_ops_str_dest - ); - } - else - { - if ( ( strcmp( dscale_type_str, "S8" ) == 0 ) || - ( strcmp( dscale_type_str, "s8" ) == 0 ) ) - { - global_dscale_out = 'y'; - DSCALE_CLIP_MIN = -128; - DSCALE_CLIP_MAX = +127; - GEN_FUNC_NAME(mat_mul_bench_main_,s8s8s32os8) - ( - fin, fout, stor_order, transa, transb, op_a, op_b, - m, n, k, stride_a, stride_b, stride_c, - post_ops_str_dest - ); - } - else - { - printf("Downscale type not supported.\n"); - } - } + strncpy( post_ops_str_dest, post_ops_str, POST_OPS_STR_LEN ); + global_dscale_out = 'n'; + global_pre_op = 'n'; + GEN_FUNC_NAME(mat_mul_bench_main_,s8s8s32os32) + ( + fin, fout, stor_order, transa, transb, op_a, op_b, + m, n, k, stride_a, stride_b, stride_c, + post_ops_str_dest, FALSE + ); } - else if ( ( op_type_char == 'v' ) || ( op_type_char == 'V' ) ) + if ( ( strcmp( gemm_type_str, "s8s8s32os8" ) == 0 ) || + ( strcmp( gemm_type_str, "*" ) == 0 ) ) { - if ( ( strcmp( dscale_type_str, "S16" ) == 0 ) || - ( strcmp( dscale_type_str, "s16" ) == 0 ) ) - { - global_dscale_out = 'n'; - GEN_FUNC_NAME(mat_mul_bench_main_,s8s8s16os16) - ( - fin, fout, stor_order, transa, transb, op_a, op_b, - m, n, k, stride_a, stride_b, stride_c, - post_ops_str_dest - ); - } - else - { - if ( ( strcmp( dscale_type_str, "S8" ) == 0 ) || - ( strcmp( dscale_type_str, "s8" ) == 0 ) ) - { - global_dscale_out = 'y'; - DSCALE_CLIP_MIN = -128; - DSCALE_CLIP_MAX = +127; - GEN_FUNC_NAME(mat_mul_bench_main_,s8s8s16os8) - ( - fin, fout, stor_order, transa, transb, op_a, op_b, - m, n, k, stride_a, stride_b, stride_c, - post_ops_str_dest - ); - } - else - { - printf("Downscale type not supported.\n"); - } - } + strncpy( post_ops_str_dest, post_ops_str, POST_OPS_STR_LEN ); + global_dscale_out = 'y'; + global_pre_op = 'n'; + DSCALE_CLIP_MIN = -128; + DSCALE_CLIP_MAX = +127; + GEN_FUNC_NAME(mat_mul_bench_main_,s8s8s32os8) + ( + fin, fout, stor_order, transa, transb, op_a, op_b, + m, n, k, stride_a, stride_b, stride_c, + post_ops_str_dest, FALSE + ); + } + if ( ( strcmp( gemm_type_str, "s8s8s16os16" ) == 0 ) || + ( strcmp( gemm_type_str, "*" ) == 0 ) ) + { + strncpy( post_ops_str_dest, post_ops_str, POST_OPS_STR_LEN ); + global_dscale_out = 'n'; + global_pre_op = 'n'; + GEN_FUNC_NAME(mat_mul_bench_main_,s8s8s16os16) + ( + fin, fout, stor_order, transa, transb, op_a, op_b, + m, n, k, stride_a, stride_b, stride_c, + post_ops_str_dest, FALSE + ); } - if ( strcmp( post_ops_str, "none" ) != 0 ) + if ( ( strcmp( gemm_type_str, "s8s8s16os8" ) == 0 ) || + ( strcmp( gemm_type_str, "*" ) == 0 ) ) { - strcpy( post_ops_str_dest, post_ops_str ); + strncpy( post_ops_str_dest, post_ops_str, POST_OPS_STR_LEN ); + global_dscale_out = 'y'; + global_pre_op = 'n'; + DSCALE_CLIP_MIN = -128; + DSCALE_CLIP_MAX = +127; + GEN_FUNC_NAME(mat_mul_bench_main_,s8s8s16os8) + ( + fin, fout, stor_order, transa, transb, op_a, op_b, + m, n, k, stride_a, stride_b, stride_c, + post_ops_str_dest, FALSE + ); } } } - if ( post_ops_str_dest != NULL ) - { - free( post_ops_str_dest ); - } if ( fin ) { fclose( fin ); diff --git a/bench/bench_aocl_gemm/bench_lpgemm_eltwise_ops.c b/bench/bench_aocl_gemm/bench_lpgemm_eltwise_ops.c new file mode 100644 index 0000000000..8f7811f8fe --- /dev/null +++ b/bench/bench_aocl_gemm/bench_lpgemm_eltwise_ops.c @@ -0,0 +1,1216 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "bench_lpgemm_helpers.h" + +GEN_FILL_ARRAY_FUNC(float) + +GEN_FILL_ARRAY_POST_OPS_FUNC(float) + +CONVERT_TO_FLOAT(float) + +void print_result + ( + const char* msg, + int32_t n_repeats, + char transa, + char transb, + dim_t m, + dim_t n, + dim_t lda, + dim_t ldb, + double gflops + ) +{ + printf("%s transa:%c, transb:%c, m: %ld, n: %ld, lda: %ld, ldb: %ld" \ + " Gops: %f, n_repeats: %d\n", + msg, transa, transb, m, n, lda, ldb, gflops, n_repeats); +} + +#define GEN_ELTWISE_OPS_GET_TEMP_ACCUM(A_type,ACCUM_type,LP_SFX) \ +ACCUM_type eltwise_ops_get_temp_accum_ ## LP_SFX \ + ( \ + A_type* a, \ + dim_t rs_a, \ + dim_t cs_a, \ + dim_t i, \ + dim_t j \ + ) \ +{ \ + float a_float; \ + bfloat16_to_float( *( a + ( i * rs_a ) + ( j * cs_a ) ), &a_float ); \ + return a_float; \ +} \ + +GEN_ELTWISE_OPS_GET_TEMP_ACCUM(bfloat16,float,bf16of32) +GEN_ELTWISE_OPS_GET_TEMP_ACCUM(bfloat16,float,bf16obf16) + +#define GEN_ELTWISE_OPS_GET_TEMP_ACCUM_F(A_type,ACCUM_type,LP_SFX) \ +ACCUM_type eltwise_ops_get_temp_accum_ ## LP_SFX \ + ( \ + A_type* a, \ + dim_t rs_a, \ + dim_t cs_a, \ + dim_t i, \ + dim_t j \ + ) \ +{ \ + float a_float = *( a + ( i * rs_a ) + ( j * cs_a ) ); \ + return a_float; \ +} \ + +GEN_ELTWISE_OPS_GET_TEMP_ACCUM_F(float,float,f32of32) + +GEN_GET_BIAS_POST_OP_VAL(float,bf16of32) +GEN_GET_BIAS_POST_OP_VAL_BF16(bf16obf16) +GEN_GET_BIAS_POST_OP_VAL(float,f32of32) + +GEN_GELU_TANH_POSTOP_FLOAT(bf16of32) +GEN_GELU_TANH_POSTOP_FLOAT(bf16obf16) +GEN_GELU_TANH_POSTOP_FLOAT(f32of32) + +GEN_GELU_ERF_POSTOP_FLOAT(bf16of32) +GEN_GELU_ERF_POSTOP_FLOAT(bf16obf16) +GEN_GELU_ERF_POSTOP_FLOAT(f32of32) + +GEN_SWISH_POSTOP_FLOAT(bf16of32) +GEN_SWISH_POSTOP_FLOAT(bf16obf16) +GEN_SWISH_POSTOP_FLOAT(f32of32) + +static inline float eltwise_ops_accuracy_check_downscale_bf16of32 + ( + float temp_accum, + aocl_post_op* post_op, + dim_t j + ) +{ + dim_t j_scale = j; + if ( ( post_op->sum )->scale_factor_len == 1 ) + { + j_scale = 0; + } + + dim_t j_zp = j; + if ( ( post_op->sum )->zero_point_len == 1 ) + { + j_zp = 0; + } + + float zp_float = *( ( float* )( post_op->sum )->zero_point + j_zp ); + float out_temp_accum = ( temp_accum * + ( *( ( float* )( post_op->sum )->scale_factor + j_scale ) ) + + zp_float ); + return out_temp_accum; +} + +static inline float eltwise_ops_accuracy_check_downscale_bf16obf16 + ( + float temp_accum, + aocl_post_op* post_op, + dim_t j + ) +{ + dim_t j_scale = j; + if ( ( post_op->sum )->scale_factor_len == 1 ) + { + j_scale = 0; + } + + dim_t j_zp = j; + if ( ( post_op->sum )->zero_point_len == 1 ) + { + j_zp = 0; + } + + float zp_float = 0.0; + bfloat16_to_float( *( ( bfloat16* )( post_op->sum )->zero_point + j_zp ), + &zp_float ); + float out_temp_accum = ( temp_accum * + ( *( ( float* )( post_op->sum )->scale_factor + j_scale ) ) + + zp_float ); + return out_temp_accum; +} + +static inline float eltwise_ops_accuracy_check_downscale_f32of32 + ( + float temp_accum, + aocl_post_op* post_op, + dim_t j + ) +{ + dim_t j_scale = j; + if ( ( post_op->sum )->scale_factor_len == 1 ) + { + j_scale = 0; + } + + dim_t j_zp = j; + if ( ( post_op->sum )->zero_point_len == 1 ) + { + j_zp = 0; + } + + float zp_float = *( ( float* )( post_op->sum )->zero_point + j_zp ); + float out_temp_accum = ( temp_accum * + ( *( ( float* )( post_op->sum )->scale_factor + j_scale ) ) + + zp_float ); + return out_temp_accum; +} + +GEN_GET_MATRIX_ADD_POST_OP_VAL(float,float,bf16of32) +GEN_GET_MATRIX_ADD_POST_OP_VAL_BF16(bfloat16,bf16obf16) +GEN_GET_MATRIX_ADD_POST_OP_VAL(float,float,f32of32) + +GEN_GET_MATRIX_MUL_POST_OP_VAL(float,float,bf16of32) +GEN_GET_MATRIX_MUL_POST_OP_VAL_BF16(bfloat16,bf16obf16) +GEN_GET_MATRIX_MUL_POST_OP_VAL(float,float,f32of32) + +GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(float,float) + +#define GEN_ELTWISE_OPS_ACC_CHK_DRV_FUNC(A_type,B_type,ACCUM_type,LP_SFX) \ +void eltwise_ops_accuracy_check_driver_ ## LP_SFX \ + ( \ + FILE* fout, \ + const char stor_order, \ + char transa, \ + char transb, \ + dim_t m, \ + dim_t n, \ + A_type* a, \ + dim_t lda, \ + B_type* b, \ + dim_t ldb, \ + aocl_post_op* post_op \ + ) \ +{ \ + dim_t rs_a, cs_a; \ + if( ( transa == 'n' ) || ( transa == 'N' ) ) \ + { \ + rs_a = lda; \ + cs_a = 1; \ + } \ + else \ + { \ + rs_a = 1; \ + cs_a = lda; \ + } \ + dim_t rs_b, cs_b; \ + if( ( transb == 'n' ) || ( transb == 'N' ) ) \ + { \ + rs_b = ldb; \ + cs_b = 1; \ + } \ + else \ + { \ + rs_b = 1; \ + cs_b = ldb; \ + } \ + \ + if ( ( stor_order == 'C' ) || ( stor_order == 'c' ) ) \ + { \ + if( transa == 'n' || transa == 'N') \ + { \ + rs_a = 1; \ + cs_a = lda; \ + } \ + else \ + { \ + rs_a = lda; \ + cs_a = 1; \ + } \ + if( ( transb == 'n' ) || ( transb == 'N' ) ) \ + { \ + rs_b = 1; \ + cs_b = ldb; \ + } \ + else \ + { \ + rs_b = ldb; \ + cs_b = 1; \ + } \ + } \ + \ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + ACCUM_type temp_accum = 0; \ + B_type out_temp_accum = 0; \ + \ + temp_accum = GEN_FUNC_NAME(eltwise_ops_get_temp_accum_,LP_SFX) \ + ( a, rs_a, cs_a, i, j ); \ +\ + if ( post_op != NULL ) \ + { \ + dim_t ele_i = 0; \ + for ( dim_t op_id = 0; op_id < post_op->seq_length; ++op_id ) \ + { \ + if ( post_op->seq_vector[op_id] == BIAS ) \ + { \ + temp_accum += GEN_FUNC_NAME(get_bias_post_op_val_,LP_SFX) \ + ( ( post_op->bias )->bias, j ); \ + } \ + else if ( post_op->seq_vector[op_id] == ELTWISE ) \ + { \ + if ( ( post_op->eltwise + ele_i )->algo.algo_type == \ + PRELU ) /* PReLU*/ \ + { \ + temp_accum = ( temp_accum > 0 ) ? \ + temp_accum : \ + ( temp_accum * \ + *( ( ACCUM_type* ) ( post_op->eltwise + ele_i )->algo.alpha ) ); \ + ele_i += 1; \ + } \ + else if ( ( post_op->eltwise + ele_i )->algo.algo_type == \ + GELU_TANH ) /* TANH GeLU*/ \ + { \ + temp_accum = GEN_FUNC_NAME(GELU_TANH_post_op_,LP_SFX) (temp_accum);\ + ele_i += 1; \ + } \ + else if ( ( post_op->eltwise + ele_i )->algo.algo_type == \ + GELU_ERF ) /* ERF GeLU*/ \ + { \ + temp_accum = GEN_FUNC_NAME(GELU_ERF_post_op_,LP_SFX) (temp_accum);\ + ele_i += 1; \ + } \ + else if ( ( post_op->eltwise + ele_i )->algo.algo_type == \ + SWISH ) /* SiLU*/ \ + { \ + temp_accum = GEN_FUNC_NAME(SWISH_post_op_,LP_SFX) \ + (temp_accum, \ + *( ( ACCUM_type* ) \ + ( post_op->eltwise + ele_i )->algo.alpha ) );\ + ele_i += 1; \ + } \ + else if ( ( post_op->eltwise + ele_i )->algo.algo_type == \ + RELU ) /* ReLU*/ \ + { \ + temp_accum = ( temp_accum > 0 ) ? temp_accum : 0 ; \ + ele_i += 1; \ + } \ + else if ( ( post_op->eltwise + ele_i )->algo.algo_type == \ + CLIP ) /* CLIP*/ \ + { \ + temp_accum = \ + min \ + ( \ + max \ + ( \ + temp_accum, \ + *( ( ACCUM_type* ) \ + ( post_op->eltwise + ele_i )->algo.alpha ) \ + ), \ + *( ( ACCUM_type* ) \ + ( post_op->eltwise + ele_i )->algo.beta) \ + ); \ + ele_i += 1; \ + } \ + else \ + {} \ + } \ + else if ( post_op->seq_vector[op_id] == SCALE ) \ + { \ + temp_accum = GEN_FUNC_NAME(eltwise_ops_accuracy_check_downscale_,LP_SFX) \ + (temp_accum, post_op, j); \ + } \ + else if ( post_op->seq_vector[op_id] == MATRIX_ADD ) \ + { \ + dim_t rs_m = ( post_op->matrix_add )->ldm; \ + dim_t cs_m = 1; \ + if ( ( stor_order == 'C' ) || ( stor_order == 'c' ) ) \ + { \ + cs_m = rs_m; \ + rs_m = 1; \ + } \ + temp_accum += GEN_FUNC_NAME(get_matrix_add_post_op_val_,LP_SFX) \ + ( *( ( B_type* )( post_op->matrix_add )->matrix + \ + ( i * rs_m ) + ( j * cs_m ) ) ); \ + } \ + else if ( post_op->seq_vector[op_id] == MATRIX_MUL ) \ + { \ + dim_t rs_m = ( post_op->matrix_mul )->ldm; \ + dim_t cs_m = 1; \ + if ( ( stor_order == 'C' ) || ( stor_order == 'c' ) ) \ + { \ + cs_m = rs_m; \ + rs_m = 1; \ + } \ + temp_accum *= GEN_FUNC_NAME(get_matrix_mul_post_op_val_,LP_SFX) \ + ( *( ( B_type* )( post_op->matrix_mul )->matrix + \ + ( i * rs_m ) + ( j * cs_m ) ) ); \ + } \ + else \ + {} \ + } \ + } \ + /* Need to convert to downscaled type if required.*/ \ + mat_mul_get_output_type_val ## ACCUM_type ## B_type \ + ( \ + &out_temp_accum, &temp_accum \ + ); \ + \ + if ( *( b + ( rs_b * i ) + ( cs_b * j ) ) != out_temp_accum ) \ + { \ + float comp_float, ref_float; \ + GEN_FUNC_NAME(B_type,_to_float)(*( b + ( rs_b * i ) + ( cs_b * j ) ), &comp_float); \ + GEN_FUNC_NAME(B_type,_to_float)(out_temp_accum, &ref_float); \ + if ( fout ) \ + { \ + fprintf( fout, "%s Failure input m: %ld, n: %ld," \ + " lda: %ld, ldb: %ld, computed:%f, ref:%f, diff:%f\n", \ + XSTR(LP_SFX), m, n, lda, ldb, comp_float, \ + ref_float, comp_float - ref_float); \ + fflush( fout ); \ + } \ + printf("failure, m: %ld, n: %ld, computed:%f, ref:%f, diff:%f\n", i, j, \ + comp_float, ref_float, comp_float-ref_float); \ + goto cleanup_acc; \ + } \ + } \ + } \ +cleanup_acc: \ + return; \ +} \ + +GEN_ELTWISE_OPS_ACC_CHK_DRV_FUNC(bfloat16,float,float,bf16of32) +GEN_ELTWISE_OPS_ACC_CHK_DRV_FUNC(bfloat16,bfloat16,float,bf16obf16) +GEN_ELTWISE_OPS_ACC_CHK_DRV_FUNC(float,float,float,f32of32) + +#define GEN_ELTWISE_OPS_BENCH_DRV_FUNC(A_type,B_type,LP_SFX) \ +void eltwise_ops_bench_driver_ ## LP_SFX \ + ( \ + char stor_order, \ + char transa, \ + char transb, \ + int32_t n_repeats, \ + dim_t m, \ + dim_t n, \ + A_type* a, \ + dim_t lda, \ + B_type* b, \ + dim_t ldb, \ + aocl_post_op* post_op \ + ) \ +{ \ + double dtime; \ + double dtime_save = DBL_MAX; \ +\ + for ( int32_t nr = 0; nr < n_repeats; ++nr ) \ + { \ + dtime = bli_clock(); \ + \ + GEN_FUNC_NAME(aocl_gemm_eltwise_ops_,LP_SFX) \ + ( \ + stor_order, transa, transb, \ + m, n, \ + a, lda, \ + b, ldb, \ + post_op \ + ); \ + \ + dtime_save = bli_clock_min_diff( dtime_save, dtime ); \ + \ + } \ + double gflops = ( m * n ) / ( dtime_save * 1.0e9 ); \ + \ + print_result( XSTR(LP_SFX), n_repeats, transa, transb, m, n, lda, ldb, gflops); \ +} \ + +GEN_ELTWISE_OPS_BENCH_DRV_FUNC(bfloat16,float,bf16of32) +GEN_ELTWISE_OPS_BENCH_DRV_FUNC(bfloat16,bfloat16,bf16obf16) +GEN_ELTWISE_OPS_BENCH_DRV_FUNC(float,float,f32of32) + +#define GEN_ELTWISE_OPS_POST_OPS_CREATOR(C_DSCALE_type,C_type,DSCALE_type,BLAS_SFX) \ +static inline aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \ + ( \ + dim_t m, \ + dim_t n, \ + char* post_ops_str, \ + char stor_order \ + ) \ +{ \ + if ( ( ( post_ops_str == NULL ) || \ + ( strcmp( post_ops_str, "none" ) == 0 ) ) && \ + ( global_dscale_out == 'n' ) ) \ + { \ + return NULL; \ + } \ + \ + aocl_post_op* post_ops = NULL; \ + post_ops = ( aocl_post_op* ) malloc( sizeof( aocl_post_op ) ); \ + \ + if ( post_ops == NULL ) \ + { \ + return NULL; \ + } \ + \ + /* Only supporting 8 post ops at max for now.*/ \ + dim_t max_post_ops_seq_length = 8; \ + post_ops->seq_vector = ( AOCL_POST_OP_TYPE* ) \ + malloc \ + ( \ + max_post_ops_seq_length * \ + sizeof( AOCL_POST_OP_TYPE ) \ + ); \ + \ + if ( post_ops->seq_vector == NULL ) \ + { \ + goto err_handler; \ + } \ + \ + /* Parse post ops list.*/ \ + dim_t cur_op_index = 0; \ + /* Ensure the buffers that use NULL check in deinit code is properly set to NULL.*/ \ + post_ops->eltwise = NULL; \ + \ + /* Bench limitation: can only support 1 bias, but LPGEMM can support + * multiple scale post-ops. */ \ + post_ops->bias = NULL; \ + post_ops->bias = malloc( sizeof( aocl_post_op_bias ) ); \ + if ( post_ops->bias == NULL ) \ + { \ + goto err_handler; \ + } \ + ( post_ops->bias )->bias = NULL; \ + \ + /* Bench limitation: can only support 1 scale, but LPGEMM can support + * multiple scale post-ops. */ \ + post_ops->sum = NULL; \ + post_ops->sum = malloc( sizeof( aocl_post_op_sum ) ); \ + if ( post_ops->sum == NULL ) \ + { \ + goto err_handler; \ + } \ + ( post_ops->sum )->scale_factor = NULL; \ + ( post_ops->sum )->buff = NULL; \ + ( post_ops->sum )->zero_point = NULL; \ + ( post_ops->sum )->scale_factor_len = 0; \ + ( post_ops->sum )->zero_point_len = 0; \ + \ + /* Bench limitation: can only support 1 matrix add, but LPGEMM can support + * multiple matrix add post-ops. */ \ + post_ops->matrix_add = NULL; \ + post_ops->matrix_add = malloc( sizeof( aocl_post_op_matrix_add ) ); \ + if ( post_ops->matrix_add == NULL ) \ + { \ + goto err_handler; \ + } \ + ( post_ops->matrix_add )->matrix = NULL; \ + ( post_ops->matrix_add )->ldm = 0; \ +\ + /* Bench limitation: can only support 1 matrix mul, but LPGEMM can support + * multiple matrix mul post-ops. */ \ + post_ops->matrix_mul = NULL; \ + post_ops->matrix_mul = malloc( sizeof( aocl_post_op_matrix_mul ) ); \ + if ( post_ops->matrix_mul == NULL ) \ + { \ + goto err_handler; \ + } \ + ( post_ops->matrix_mul )->matrix = NULL; \ + ( post_ops->matrix_mul )->ldm = 0; \ + \ + bool is_bias = FALSE; \ + bool is_relu = FALSE; \ + bool is_param_relu = FALSE; \ + bool is_gelu_tanh = FALSE; \ + bool is_gelu_erf = FALSE; \ + bool is_swish = FALSE; \ + bool is_clip = FALSE; \ + bool is_scalar_scale = FALSE; \ + bool is_scalar_zp = FALSE; \ + bool is_matrix_add = FALSE; \ + bool is_matrix_mul = FALSE; \ + dim_t activator_idx = 0; \ + dim_t clip_idx = 0; \ + \ + /* Post-Ops string parser. */ \ + num_eltwise = 0; /* Global variable, zero out for definied behavior. */\ + if ( strcmp( post_ops_str, "none" ) != 0 ) \ + { \ + char* ops_tok = strtok(post_ops_str, ", =" ); \ + \ + /* Ensure only one activator is used as an eltwise post-op.*/ \ + bool is_activator_set = FALSE; \ + while ( ops_tok ) \ + { \ + str_tolower( ops_tok ); \ + if ( strcmp( ops_tok, "bias" ) == 0 ) \ + { \ + post_ops->seq_vector[cur_op_index] = BIAS; \ + is_bias = TRUE; \ + cur_op_index++; \ + } \ + else if ( ( strcmp( ops_tok, "relu" ) == 0 ) && \ + ( is_activator_set == FALSE ) ) \ + { \ + post_ops->seq_vector[cur_op_index] = ELTWISE; \ + is_relu = TRUE; \ + is_activator_set = TRUE; \ + num_eltwise += 1; \ + activator_idx = cur_op_index; \ + cur_op_index++; \ + } \ + else if ( ( strcmp( ops_tok, "prelu" ) == 0 ) && \ + ( is_activator_set == FALSE ) ) \ + { \ + post_ops->seq_vector[cur_op_index] = ELTWISE; \ + is_param_relu = TRUE; \ + is_activator_set = TRUE; \ + num_eltwise += 1; \ + activator_idx = cur_op_index; \ + cur_op_index++; \ + } \ + else if ( ( strcmp( ops_tok, "swish" ) == 0 ) && \ + ( is_activator_set == FALSE ) ) \ + { \ + post_ops->seq_vector[cur_op_index] = ELTWISE; \ + is_swish = TRUE; \ + is_activator_set = TRUE; \ + num_eltwise += 1; \ + activator_idx = cur_op_index; \ + cur_op_index++; \ + } \ + else if ( ( strcmp( ops_tok, "gelu_tanh" ) == 0 ) && \ + ( is_activator_set == FALSE ) ) \ + { \ + post_ops->seq_vector[cur_op_index] = ELTWISE; \ + is_gelu_tanh = TRUE; \ + is_activator_set = TRUE; \ + num_eltwise += 1; \ + activator_idx = cur_op_index; \ + cur_op_index++; \ + } \ + else if ( ( strcmp( ops_tok, "gelu_erf" ) == 0 ) && \ + ( is_activator_set == FALSE ) ) \ + { \ + post_ops->seq_vector[cur_op_index] = ELTWISE; \ + is_gelu_erf = TRUE; \ + is_activator_set = TRUE; \ + num_eltwise += 1; \ + activator_idx = cur_op_index; \ + cur_op_index++; \ + } \ + else if ( strcmp( ops_tok, "clip" ) == 0 ) \ + { \ + post_ops->seq_vector[cur_op_index] = ELTWISE; \ + is_clip = TRUE; \ + num_eltwise += 1; \ + clip_idx = cur_op_index; \ + cur_op_index++; \ + } \ + else if ( strcmp( ops_tok, "scale" ) == 0 ) \ + { \ + ops_tok = strtok( NULL, ", " ); \ + str_tolower( ops_tok ); \ + if ( ( strcmp( ops_tok, "scalar" ) == 0 ) || \ + ( strcmp( ops_tok, "s" ) == 0 ) ) \ + { \ + is_scalar_scale = TRUE; \ + } \ + } \ + else if ( strcmp( ops_tok, "zp" ) == 0 ) \ + { \ + ops_tok = strtok( NULL, ", " ); \ + str_tolower( ops_tok ); \ + if ( ( strcmp( ops_tok, "scalar" ) == 0 ) || \ + ( strcmp( ops_tok, "s" ) == 0 ) ) \ + { \ + is_scalar_zp = TRUE; \ + } \ + } \ + else if ( strcmp( ops_tok, "matrix_add" ) == 0 ) \ + { \ + post_ops->seq_vector[cur_op_index] = MATRIX_ADD; \ + is_matrix_add = TRUE; \ + cur_op_index++; \ + } \ + else if ( strcmp( ops_tok, "matrix_mul" ) == 0 ) \ + { \ + post_ops->seq_vector[cur_op_index] = MATRIX_MUL; \ + is_matrix_mul = TRUE; \ + cur_op_index++; \ + } \ + \ + ops_tok = strtok( NULL, ", =" ); \ + } \ + } \ + \ + if ( is_bias == TRUE ) \ + { \ + /* Allocate bias buffer, return early if alloc fails.*/ \ + ( post_ops->bias )->bias = malloc( n * sizeof( C_type ) ); \ + if ( ( post_ops->bias )->bias == NULL ) \ + { \ + goto err_handler; \ + } \ + if ( global_dscale_out == 'y' ) \ + { \ + GEN_FUNC_NAME(fill_array_post_ops_,C_DSCALE_type)( ( post_ops->bias )->bias, n ); \ + } \ + else \ + { \ + GEN_FUNC_NAME(fill_array_post_ops_,C_type)( ( post_ops->bias )->bias, n ); \ + } \ + } \ + \ + if ( num_eltwise > 0 ) \ + { \ + if ( num_eltwise > 1 ) \ + { \ + if ( activator_idx < clip_idx ) \ + { \ + activator_idx = 0; \ + clip_idx = 1; \ + } \ + else \ + { \ + activator_idx = 1; \ + clip_idx = 0; \ + } \ + } \ + else \ + { \ + activator_idx = 0; \ + clip_idx = 0; \ + } \ + \ + post_ops->eltwise = malloc( num_eltwise * sizeof( aocl_post_op_eltwise ) ); \ + if ( post_ops->eltwise == NULL ) \ + { \ + goto err_handler; \ + } \ + \ + /* Only one of relu, prelu, swish, gelu_tanh, gelu_erf allowed as + * an activator. */ \ + if ( is_relu == TRUE ) \ + { \ + ( post_ops->eltwise + activator_idx )->is_power_of_2 = FALSE; \ + ( post_ops->eltwise + activator_idx )->scale_factor = NULL; \ + ( post_ops->eltwise + activator_idx )->algo.alpha = NULL; \ + ( post_ops->eltwise + activator_idx )->algo.beta = NULL; \ + ( post_ops->eltwise + activator_idx )->algo.algo_type = RELU; \ + } \ + else if ( is_param_relu == TRUE ) \ + { \ + ( post_ops->eltwise + activator_idx )->is_power_of_2 = FALSE; \ + ( post_ops->eltwise + activator_idx )->scale_factor = NULL; \ + ( post_ops->eltwise + activator_idx )->algo.alpha = NULL; \ + ( post_ops->eltwise + activator_idx )->algo.alpha = malloc( sizeof( C_type ) ); \ + if ( ( post_ops->eltwise + activator_idx )->algo.alpha == NULL ) \ + { \ + goto err_handler; \ + } \ + *( ( C_type* ) ( post_ops->eltwise + activator_idx )->algo.alpha ) = ( C_type )6; \ + ( post_ops->eltwise + activator_idx )->algo.beta = NULL; \ + ( post_ops->eltwise + activator_idx )->algo.algo_type = PRELU; \ + } \ + if ( is_swish == TRUE ) \ + { \ + ( post_ops->eltwise + activator_idx )->is_power_of_2 = FALSE; \ + ( post_ops->eltwise + activator_idx )->scale_factor = NULL; \ + ( post_ops->eltwise + activator_idx )->algo.alpha = NULL; \ + ( post_ops->eltwise + activator_idx )->algo.alpha = malloc( sizeof( C_type ) ); \ + if ( ( post_ops->eltwise + activator_idx )->algo.alpha == NULL ) \ + { \ + goto err_handler; \ + } \ + *( ( C_type* ) ( post_ops->eltwise + activator_idx )->algo.alpha ) = ( C_type )2; \ + ( post_ops->eltwise + activator_idx )->algo.beta = NULL; \ + ( post_ops->eltwise + activator_idx )->algo.algo_type = SWISH; \ + } \ + else if ( is_gelu_tanh == TRUE ) \ + { \ + ( post_ops->eltwise + activator_idx )->is_power_of_2 = FALSE; \ + ( post_ops->eltwise + activator_idx )->scale_factor = NULL; \ + ( post_ops->eltwise + activator_idx )->algo.alpha = NULL; \ + ( post_ops->eltwise + activator_idx )->algo.beta = NULL; \ + ( post_ops->eltwise + activator_idx )->algo.algo_type = GELU_TANH; \ + } \ + else if ( is_gelu_erf == TRUE ) \ + { \ + ( post_ops->eltwise + activator_idx )->is_power_of_2 = FALSE; \ + ( post_ops->eltwise + activator_idx )->scale_factor = NULL; \ + ( post_ops->eltwise + activator_idx )->algo.alpha = NULL; \ + ( post_ops->eltwise + activator_idx )->algo.beta = NULL; \ + ( post_ops->eltwise + activator_idx )->algo.algo_type = GELU_ERF; \ + } \ + if ( is_clip == TRUE ) \ + { \ + ( post_ops->eltwise + clip_idx )->is_power_of_2 = FALSE; \ + ( post_ops->eltwise + clip_idx )->scale_factor = NULL; \ + ( post_ops->eltwise + clip_idx )->algo.alpha = NULL; \ + ( post_ops->eltwise + clip_idx )->algo.beta = NULL; \ + ( post_ops->eltwise + clip_idx )->algo.alpha = malloc( sizeof( DSCALE_type ) ); \ + if ( ( post_ops->eltwise + clip_idx )->algo.alpha == NULL ) \ + { \ + goto err_handler; \ + } \ + ( post_ops->eltwise + clip_idx )->algo.beta = malloc( sizeof( DSCALE_type ) ); \ + if ( ( post_ops->eltwise + clip_idx )->algo.beta == NULL ) \ + { \ + goto err_handler; \ + } \ + *( ( DSCALE_type* ) ( post_ops->eltwise + clip_idx )->algo.alpha ) = ( DSCALE_type ) ( -64 ); \ + *( ( DSCALE_type* ) ( post_ops->eltwise + clip_idx )->algo.beta ) = ( DSCALE_type ) ( 23 ); \ + ( post_ops->eltwise + clip_idx )->algo.algo_type = CLIP; \ + } \ + } \ + \ + if ( global_dscale_out == 'y' ) \ + { \ + post_ops->seq_vector[cur_op_index] = SCALE; \ + cur_op_index++; \ + \ + ( post_ops->sum )->is_power_of_2 = FALSE; \ + if ( global_dscale_out == 'y' ) \ + { \ + dim_t n_scale = n; \ + if ( is_scalar_scale == TRUE ) \ + { \ + n_scale = 1; \ + } \ + \ + dim_t n_zp = n; \ + if ( is_scalar_zp == TRUE ) \ + { \ + n_zp = 1; \ + } \ + \ + /* Allocate scale buffer, return early if alloc fails.*/ \ + ( post_ops->sum )->scale_factor = malloc( n_scale * sizeof( DSCALE_type ) ); \ + if ( ( post_ops->sum )->scale_factor == NULL ) \ + { \ + goto err_handler; \ + } \ + ( post_ops->sum )->zero_point = malloc( n_zp * sizeof( C_DSCALE_type ) ); \ + if ( ( post_ops->sum )->zero_point == NULL ) \ + { \ + goto err_handler; \ + } \ + \ + /* Fill scale factor and zero points.*/ \ + DSCALE_type* temp_dscale_ptr = ( DSCALE_type* )( post_ops->sum )->scale_factor; \ + for ( dim_t i = 0; i < n_scale; ++i ) \ + { \ + temp_dscale_ptr[i] = ( ( DSCALE_type )1 )/ ( ( DSCALE_type )1000 ); \ + } \ + ( post_ops->sum )->scale_factor_len = n_scale; \ + \ + C_DSCALE_type* temp_dzero_point_ptr = ( C_DSCALE_type* )( post_ops->sum )->zero_point; \ + GEN_FUNC_NAME(fill_array_,C_DSCALE_type)( temp_dzero_point_ptr, n_zp ); \ + ( post_ops->sum )->zero_point_len = n_zp; \ + } \ + } \ + \ + if ( is_matrix_add == TRUE ) \ + { \ + /* Allocate bias buffer, return early if alloc fails.*/ \ + dim_t ele_dsize = 0; \ + if ( global_dscale_out == 'y' ) \ + { \ + ele_dsize = sizeof( C_DSCALE_type ); \ + } \ + else \ + { \ + ele_dsize = sizeof( C_type ); \ + } \ + ( post_ops->matrix_add )->matrix = malloc( m * n * ele_dsize ); \ + if ( ( post_ops->matrix_add )->matrix == NULL ) \ + { \ + goto err_handler; \ + } \ + if ( global_dscale_out == 'y' ) \ + { \ + GEN_FUNC_NAME(fill_array_,C_DSCALE_type)( ( post_ops->matrix_add )->matrix, ( m * n ) ); \ + } \ + else \ + { \ + GEN_FUNC_NAME(fill_array_,C_type)( ( post_ops->matrix_add )->matrix, ( m * n ) ); \ + } \ + if ( ( stor_order == 'C' ) || ( stor_order == 'c' ) ) \ + { \ + ( post_ops->matrix_add )->ldm = m; \ + } \ + else \ + { \ + ( post_ops->matrix_add )->ldm = n; \ + } \ + } \ + \ + if ( is_matrix_mul == TRUE ) \ + { \ + /* Allocate bias buffer, return early if alloc fails.*/ \ + dim_t ele_dsize = 0; \ + if ( global_dscale_out == 'y' ) \ + { \ + ele_dsize = sizeof( C_DSCALE_type ); \ + } \ + else \ + { \ + ele_dsize = sizeof( C_type ); \ + } \ + ( post_ops->matrix_mul )->matrix = malloc( m * n * ele_dsize ); \ + if ( ( post_ops->matrix_mul )->matrix == NULL ) \ + { \ + goto err_handler; \ + } \ + if ( global_dscale_out == 'y' ) \ + { \ + GEN_FUNC_NAME(fill_array_,C_DSCALE_type)( ( post_ops->matrix_mul )->matrix, ( m * n ) ); \ + } \ + else \ + { \ + GEN_FUNC_NAME(fill_array_,C_type)( ( post_ops->matrix_mul )->matrix, ( m * n ) ); \ + } \ + if ( ( stor_order == 'C' ) || ( stor_order == 'c' ) ) \ + { \ + ( post_ops->matrix_mul )->ldm = m; \ + } \ + else \ + { \ + ( post_ops->matrix_mul )->ldm = n; \ + } \ + } \ + \ + post_ops->seq_length = cur_op_index; \ + \ + post_ops->pre_ops = NULL; \ + \ + return post_ops; \ + \ + err_handler: \ + lpgemm_destroy_post_ops_struct( post_ops ); \ + return NULL; \ +} \ + +GEN_ELTWISE_OPS_POST_OPS_CREATOR(bfloat16,float,float,bf16of32) +GEN_ELTWISE_OPS_POST_OPS_CREATOR(bfloat16,bfloat16,float,bf16obf16) +GEN_ELTWISE_OPS_POST_OPS_CREATOR(float,float,float,f32of32) + +#define GEN_ELTWISE_OPS_BENCH_MAIN_FUNC(A_type, B_type, LP_SFX) \ +void eltwise_ops_bench_main_ ## LP_SFX \ + ( \ + FILE* fout, \ + char stor_order, \ + char transa, \ + char transb, \ + int32_t m, \ + int32_t n, \ + int32_t stride_a, \ + int32_t stride_b, \ + char* post_ops_str \ + ) \ +{ \ + int32_t n_repeats = bli_max( 30, bli_min( ( 3e10 / ( ( int64_t )m * n ) ), 1000 ) ); \ + if ( global_n_repeat > 0 ) \ + { \ + n_repeats = global_n_repeat; \ + } \ + \ + int32_t size_A = 0; \ + int32_t size_B = 0; \ + if( ( stor_order == 'r' ) || ( stor_order == 'R' ) ) \ + { \ + size_A = ( ( transa == 'n' ) || ( transa == 'N' ) ) ? m * stride_a : n * stride_a; \ + size_B = ( ( transb == 'n' ) || ( transb == 'N' ) ) ? m * stride_b : n * stride_b; \ + } \ + else \ + { \ + size_A = ( ( transa == 'n' ) || ( transa == 'N' ) ) ? n * stride_a : m * stride_a; \ + size_B = ( ( transb == 'n' ) || ( transb == 'N' ) ) ? n * stride_b : m * stride_b; \ + } \ + \ + A_type* a = ( A_type* ) lpgemm_malloc( sizeof( A_type ) * size_A ); \ + GEN_FUNC_NAME(fill_array_,A_type)(a, size_A ); \ + \ + B_type* b = ( B_type* ) lpgemm_malloc( sizeof( B_type ) * size_B ); \ + memset( ( void* ) b, 0, sizeof( B_type ) * size_B ); \ + \ + if ( bench_mode == 'a' ) \ + { \ + n_repeats = 1; \ + } \ + \ + aocl_post_op* post_op = NULL; \ + if ( ( ( post_ops_str != NULL ) && \ + ( strcmp( post_ops_str, "none" ) != 0 ) ) || \ + ( global_dscale_out == 'y' ) ) \ + { \ + post_op = GEN_FUNC_NAME(lpgemm_create_post_ops_struct_,LP_SFX)( m, n, post_ops_str, stor_order ); \ + if ( post_op == NULL ) \ + { \ + printf(" post op struct allocation failure, returning.\n"); \ + return; \ + } \ + } \ + \ + GEN_FUNC_NAME(eltwise_ops_bench_driver_,LP_SFX) \ + ( \ + stor_order, transa, transb, n_repeats, \ + m, n, \ + a, stride_a, \ + b, stride_b, \ + post_op \ + ); \ + \ + if ( bench_mode == 'a' ) \ + { \ + printf(" Running accuracy check.\n"); \ + GEN_FUNC_NAME(eltwise_ops_accuracy_check_driver_,LP_SFX) \ + ( \ + fout, stor_order, transa, transb, \ + m, n,\ + a, stride_a, \ + b, stride_b, \ + post_op \ + ); \ + } \ + \ + lpgemm_destroy_post_ops_struct( post_op ); \ + \ + lpgemm_free( a ); \ + lpgemm_free( b ); \ +} \ + +GEN_ELTWISE_OPS_BENCH_MAIN_FUNC(bfloat16,float,bf16of32) +GEN_ELTWISE_OPS_BENCH_MAIN_FUNC(bfloat16,bfloat16,bf16obf16) +GEN_ELTWISE_OPS_BENCH_MAIN_FUNC(float,float,f32of32) + +int main( int argc, char** argv ) +{ + FILE* fin = NULL; + if ( argc < 5 ) + { + printf + ( + "Usage: ./bench_lpgemm_eltwise_ops -i input.txt -m mode < -n 100 -o op1,op2 >\n" \ + "--Mode is either a or p.\n" \ + "\ta is used for accuracy testing.\n" \ + "\tp is used for performance benchmarking.\n" \ + "--n_repeats can be set optionally using -n arg.\n" \ + "--Post ops can be executed optionaly by providing a coma separated\n" \ + " list of post-ops after -o arg. Following post-ops are supported:\n" \ + " 1. bias\n" \ + " 2. 4 activators\n" \ + " a. relu\n" \ + " b. prelu\n" \ + " c. gelu_tanh\n" \ + " d. gelu_erf\n" \ + " 3.clip\n" \ + " Atleast one post-op needs to be specified if the -o arg is used.\n" \ + " eg: -o gelu_tanh; -o bias,relu ; -o clip,prelu,bias.\n" \ + " It is to be noted only one activator can be used at a time.\n" \ + " If more than one activator is used, only the first activator is\n" \ + " applied and the other activators are ignored.\n" \ + " Example: ./bench_lpgemm_eltwise_ops -m a -n 2 -o bias,relu -i input.txt\n" \ + ); + exit( 1 ); + } + + char* file_name = NULL; + +#define ELTWISE_OPS_TYPE_STR_LEN 24 + char eltwise_ops_type_str[ELTWISE_OPS_TYPE_STR_LEN]; + +#define POST_OPS_STR_LEN 104 + char post_ops_str[POST_OPS_STR_LEN]; + char post_ops_str_dest[POST_OPS_STR_LEN]; //Strtok is used to parse, need to maintain a copy. + +#define OPS_INPUT_STR_LEN 128 + char ops_input_str[OPS_INPUT_STR_LEN]; + + // Parse CLI arguments. + getopt_t state; + // Initialize the state for running bli_getopt(). Here, 0 is the + // initial value for opterr, which suppresses error messages. + bli_getopt_init_state( 0, &state ); + + int opt; + // Process all option arguments until we get a -1, which means we're done. + while( (opt = bli_getopt( argc, argv, "i:m:n:", &state )) != -1 ) + { + char opt_ch = ( char )opt; + switch( opt_ch ) + { + case 'i': + file_name = state.optarg; + break; + case 'm': + bench_mode = ( ( ( *state.optarg ) == 'a' ) || ( ( *state.optarg ) == 'p' ) ) ? ( *state.optarg ) : 'p'; + break; + case 'n': + global_n_repeat = ( atoi( state.optarg ) > 0 ) ? atoi( state.optarg ) : 0; + break; + default: + break; + } + } + + if ( bench_mode == 'p' ) + { + printf( "Running bench in performance benchmarking mode.\n" ); + } + else if ( bench_mode == 'a' ) + { + printf( "Running bench in accuracy/correctness testing mode.\n" ); + } + + if ( file_name == NULL ) + { + printf( " File name provided is invalid.\n" ); + exit( 1 ); + } + + fin = fopen( file_name, "r" ); + if (fin == NULL) + { + printf( "Error opening the file %s\n", argv[1] ); + exit( 1 ); + } + + FILE* fout = NULL; + + fout = fopen( "lpgemm_eltwise_ops_accuracy_test_failures.txt", "w" ); + + char stor_order; + char transa, transb; + int32_t m, n; + int32_t stride_a, stride_b; + + const dim_t len_list_omp_cores_for_testing = 1; + const dim_t list_omp_cores_for_testing[1] = { 1 }; + + dim_t core_index = 0; + bool can_run = TRUE; + while ( ( can_run == TRUE ) && ( fseek( fin, 0L, SEEK_SET ) == 0 ) ) + { + if ( bench_mode == 'p' ) + { + can_run = FALSE; + } + else if ( bench_mode == 'a' ) + { + // For accuracy testing, we test accuracy using multiple different + // number of cores. This helps uncover any bugs related to over + // subscription or varying thread factorizations. + // Set current number of cores. +#ifdef BLIS_ENABLE_OPENMP + omp_set_num_threads( list_omp_cores_for_testing[core_index] ); +#endif + printf( "Accuracy test using %ld threads.\n", + list_omp_cores_for_testing[core_index] ); + + core_index++; + if ( core_index < len_list_omp_cores_for_testing ) + { + can_run = TRUE; + } + else + { + can_run = FALSE; + } + } + + // Input format: data_type stor_type pack m n lda ldb + while ( fscanf( fin, "%c %c %c %d %d %d %d %s\n", + &stor_order, &transa, &transb, &m, &n, + &stride_a, &stride_b, ops_input_str ) == 8 ) + { + char* ops_tok = strtok( ops_input_str, ":" ); + strncpy( eltwise_ops_type_str, ops_tok, ELTWISE_OPS_TYPE_STR_LEN - 1 ); + str_tolower( eltwise_ops_type_str ); \ + + ops_tok = strtok( NULL, "" ); + if ( ops_tok != NULL ) + { + strncpy( post_ops_str, ops_tok, POST_OPS_STR_LEN - 1 ); + } + else + { + strncpy( post_ops_str, "none", POST_OPS_STR_LEN - 1 ); + } + + stor_order = ( ( stor_order == 'r' ) || ( stor_order == 'R' ) || + ( stor_order == 'c' ) || ( stor_order == 'C' ) ) ? + stor_order : 'r'; + + if ( ( strcmp( eltwise_ops_type_str, "bf16of32" ) == 0 ) || + ( strcmp( eltwise_ops_type_str, "*" ) == 0 ) ) + { + strncpy( post_ops_str_dest, post_ops_str, POST_OPS_STR_LEN ); + global_dscale_out = 'n'; + GEN_FUNC_NAME(eltwise_ops_bench_main_, bf16of32) + ( + fout, stor_order, transa, transb, + m, n, stride_a, stride_b, + post_ops_str_dest + ); + } + if ( ( strcmp( eltwise_ops_type_str, "bf16obf16" ) == 0 ) || + ( strcmp( eltwise_ops_type_str, "*" ) == 0 ) ) + { + strncpy( post_ops_str_dest, post_ops_str, POST_OPS_STR_LEN ); + global_dscale_out = 'y'; + GEN_FUNC_NAME(eltwise_ops_bench_main_, bf16obf16) + ( + fout, stor_order, transa, transb, + m, n, stride_a, stride_b, + post_ops_str_dest + ); + } + if ( ( strcmp( eltwise_ops_type_str, "f32of32" ) == 0 ) || + ( strcmp( eltwise_ops_type_str, "*" ) == 0 ) ) + { + strncpy( post_ops_str_dest, post_ops_str, POST_OPS_STR_LEN ); + global_dscale_out = 'n'; + GEN_FUNC_NAME(eltwise_ops_bench_main_, f32of32) + ( + fout, stor_order, transa, transb, + m, n, stride_a, stride_b, + post_ops_str_dest + ); + } + } + } + + if ( fin ) + { + fclose( fin ); + } + if ( fout ) + { + fclose( fout ); + } + return 0; +} diff --git a/bench/bench_aocl_gemm/bench_lpgemm_helpers.h b/bench/bench_aocl_gemm/bench_lpgemm_helpers.h new file mode 100644 index 0000000000..ded5aa9ca8 --- /dev/null +++ b/bench/bench_aocl_gemm/bench_lpgemm_helpers.h @@ -0,0 +1,439 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef LPGEMM_BENCH_UTILS_H +#define LPGEMM_BENCH_UTILS_H + +#include +#include +#include +#include +#include +#include +#include + +#include "blis.h" + +// Used to clip downscaled output, will be set in the main loop based +// on the accumulation and C data type. +int64_t DSCALE_CLIP_MIN = 0; +int64_t DSCALE_CLIP_MAX = 0; + +// Mode can be one of the follwoing: +// 1. p - performance, used for benchmarks. +// 2. a - accuracy, used to test accuracy/correctness. +// Default value is p, can be modified by passing command line arg. +char bench_mode = 'p'; + +int32_t global_n_repeat = 0; + +char global_dscale_out = 'n'; + +dim_t num_eltwise = 0; // To keep track of eltwise operations. + +#define _XSTR(str) #str +#define XSTR(str) _XSTR(str) + +#define GEN_FUNC_NAME(prototype,ctype) prototype ## ctype + +// Inplace to lower func. +static inline void str_tolower( char* str ) +{ + for ( char* c = str; ( *c ) != '\0'; ++c ) + { *( c ) = tolower( *( c ) ); } +} + +#define CONVERT_TO_FLOAT(ctype) \ +static inline void GEN_FUNC_NAME(ctype,_to_float) ( ctype val, float* float_val ) \ +{ \ + *float_val = (float) val; \ +} \ + +static inline void float_to_bf16( float* float_value, bfloat16* bf16_val ) +{ + /*Set offset 2 to copy most significant 2 bytes of float + to convert float values to bf16 values*/ + memcpy( ( bf16_val ), (char *)( float_value ) + 2, sizeof ( bfloat16 ) ); +} + +// Only works for little endian systems. +static inline void bfloat16_to_float( bfloat16 bf16_val, float* float_val ) +{ + int32_t inter_temp = *( ( int16_t* ) &bf16_val ); + inter_temp = inter_temp << 16; + memcpy( float_val, &inter_temp, sizeof( int32_t ) ); +} + +static inline void convert_float_arr_to_bf16( float* array, bfloat16* array_bf16, dim_t size ) +{ + for (dim_t i=0; i< size; i++) + { + float_to_bf16( ( array + i ), ( array_bf16 + i ) ); + } +} + +static inline void* lpgemm_malloc( dim_t size ) +{ + void* p; + // creating a dummy buffer of size 4 bytes in case + // size of the matrix is negative. + if( size <= 0 ) + { + p = malloc( 4 ); + return p; + } + + if( bench_mode == 'a' ) + { + p = malloc(size); + } + else + { + err_t err = BLIS_SUCCESS; + p = bli_malloc_user(size, &err); + } + if ( p == NULL ) + { + printf("Unable to allocate memory.\n"); + exit(1); + } + return p; +} + +static inline void lpgemm_free( void* p ) +{ + if( p == NULL) + { + printf("Attempt to free null pointer\n"); + return; + } + + if( bench_mode == 'a' ) + { + free(p); + } + else + { + bli_free_user(p); + } +} + +/* Matrix fill helper macros. */ +#define GEN_FILL_ARRAY_FUNC(ctype) \ +static inline void fill_array_ ## ctype ( void* arr, dim_t size ) \ +{ \ + if( size < 0 ) return; \ + ctype* temp_arr = ( ctype* ) arr; \ + for ( dim_t i = 0; i < size; ++i ) \ + { \ + temp_arr[i] = ( ctype )( ( rand() % 11 ) - 5 ); \ + } \ +} \ + +static inline void fill_array_bfloat16( void* arr, dim_t size ) +{ + err_t bli_errors = BLIS_SUCCESS; + if( size < 0 ) return; + float* c_float = ( float* ) bli_malloc_user( sizeof( float ) * size, &bli_errors ); + for ( dim_t i = 0; i < size; ++i ) + { + c_float[i] = (rand() % 5 ); + } + convert_float_arr_to_bf16( c_float, arr, size ); + if ( c_float != NULL ) + { + bli_free_user( c_float ); + } +} + +#define GEN_FILL_ARRAY_POST_OPS_FUNC(ctype) \ +static inline void fill_array_post_ops_ ## ctype ( void* arr, dim_t size ) \ +{ \ + ctype* temp_arr = ( ctype* ) arr; \ + for ( dim_t i = 0; i < size; ++i ) \ + { \ + temp_arr[i] = ( ctype )( rand() % 5 ); \ + } \ +} \ + +static inline void fill_array_post_ops_bfloat16( void* arr, dim_t size ) +{ + fill_array_bfloat16( arr, size ); +} + +/* POST-OPS Helper macros. */ + +/* Bias. */ +#define GEN_GET_BIAS_POST_OP_VAL_BF16(BLAS_SFX) \ +static inline float get_bias_post_op_val_ ## BLAS_SFX \ + ( \ + void* post_op_bias_ptr, \ + dim_t j \ + ) \ +{ \ + float ret_val = 0.0; \ + bfloat16_to_float( *( ( bfloat16* )post_op_bias_ptr + j ), &ret_val ); \ + return ret_val; \ +} \ + +#define GEN_GET_BIAS_POST_OP_VAL(ACCUM_type,BLAS_SFX) \ +static inline ACCUM_type get_bias_post_op_val_ ## BLAS_SFX \ + ( \ + void* post_op_bias_ptr, \ + dim_t j \ + ) \ +{ \ + return *( ( ACCUM_type* )post_op_bias_ptr + j ); \ +} \ + +/* GELU Tanh. */ +#define GEN_GELU_TANH_POSTOP_INT(ACCUM_type,BLAS_SFX) \ +static inline ACCUM_type GELU_TANH_post_op_ ## BLAS_SFX \ + ( \ + ACCUM_type temp_accum \ + ) \ +{ \ + float gelu_reference = 0.5 *(double)temp_accum * (1 + tanhf( 0.797884 * ( (double)temp_accum + \ + ( 0.044715 * ((double)temp_accum * (double)temp_accum * \ + (double)temp_accum ) ) ) ) ); \ + temp_accum = round (gelu_reference); \ + return temp_accum; \ +} \ + +#define GEN_GELU_TANH_POSTOP_FLOAT(BLAS_SFX) \ +static inline float GELU_TANH_post_op_ ## BLAS_SFX \ + ( \ + float temp_accum \ + ) \ +{ \ + temp_accum = 0.5 *(double)temp_accum * (1 + tanhf( 0.797884 * ( (double)temp_accum + \ + ( 0.044715 * ((double)temp_accum * (double)temp_accum * \ + (double)temp_accum ) ) ) ) ); \ + return temp_accum; \ +} \ + +/* GELU Erf. */ +#define GEN_GELU_ERF_POSTOP_INT(ACCUM_type,BLAS_SFX) \ +static inline ACCUM_type GELU_ERF_post_op_ ## BLAS_SFX \ + ( \ + ACCUM_type temp_accum \ + ) \ +{ \ + float gelu_reference = 0.5 *(double)temp_accum * (1 + erff( (double)temp_accum * 0.707107 )); \ + temp_accum = round (gelu_reference); \ + return temp_accum; \ +} \ + +#define GEN_GELU_ERF_POSTOP_FLOAT(BLAS_SFX) \ +static inline float GELU_ERF_post_op_ ## BLAS_SFX \ + ( \ + float temp_accum \ + ) \ +{ \ + temp_accum = 0.5 *(double)temp_accum * (1 + erff( (double)temp_accum * 0.707107 )); \ + return temp_accum; \ +} \ + +/* SWISH. */ +#define GEN_SWISH_POSTOP_INT(ACCUM_type,BLAS_SFX) \ +static inline ACCUM_type SWISH_post_op_ ## BLAS_SFX \ + ( \ + ACCUM_type temp_accum, \ + ACCUM_type alpha \ + ) \ +{ \ + float swish_reference = ( temp_accum / ( 1 + \ + expf( ( double )alpha * temp_accum * -1 ) ) ); \ + temp_accum = round (swish_reference); \ + return temp_accum; \ +} \ + +#define GEN_SWISH_POSTOP_FLOAT(BLAS_SFX) \ +static inline float SWISH_post_op_ ## BLAS_SFX \ + ( \ + float temp_accum, \ + float alpha \ + ) \ +{ \ + temp_accum = ( temp_accum / ( 1 + \ + expf( ( double )alpha * temp_accum * -1 ) ) ); \ + return temp_accum; \ +} \ + +/* Matrix Add. */ +#define GEN_GET_MATRIX_ADD_POST_OP_VAL_BF16(C_type,BLAS_SFX) \ +static inline float get_matrix_add_post_op_val_ ## BLAS_SFX \ + ( \ + C_type val \ + ) \ +{ \ + float ret_val = 0.0; \ + bfloat16_to_float( val, &ret_val ); \ + return ret_val; \ +} \ + +#define GEN_GET_MATRIX_ADD_POST_OP_VAL(C_type,ACCUM_type,BLAS_SFX) \ +static inline ACCUM_type get_matrix_add_post_op_val_ ## BLAS_SFX \ + ( \ + C_type val \ + ) \ +{ \ + return (ACCUM_type) val; \ +} \ + +#define GEN_GET_MATRIX_MUL_POST_OP_VAL_BF16(C_type,BLAS_SFX) \ +static inline float get_matrix_mul_post_op_val_ ## BLAS_SFX \ + ( \ + C_type val \ + ) \ +{ \ + float ret_val = 0.0; \ + bfloat16_to_float( val, &ret_val ); \ + return ret_val; \ +} \ + +#define GEN_GET_MATRIX_MUL_POST_OP_VAL(C_type,ACCUM_type,BLAS_SFX) \ +static inline ACCUM_type get_matrix_mul_post_op_val_ ## BLAS_SFX \ + ( \ + C_type val \ + ) \ +{ \ + return (ACCUM_type) val; \ +} \ + +/* Final output type value getter. */ +#define GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(C_type, ACCUM_type) \ +static inline void mat_mul_get_output_type_val ## ACCUM_type ## C_type \ + ( \ + C_type* out_temp_accum, \ + ACCUM_type* temp_accum \ + ) \ +{ \ + ( *out_temp_accum ) = ( C_type )( *temp_accum ); \ +} \ + +static inline void mat_mul_get_output_type_valfloatbfloat16 + ( + bfloat16* out_temp_accum, + float* temp_accum + ) +{ + /* Fix for rounding bias. */ + uint32_t inter_temp; + memcpy( &inter_temp, temp_accum, sizeof( float ) ); + + /* Check if 16th bit is set */ + uint32_t tlsb = ( inter_temp & ( uint32_t )0x00010000 ) > 16; + + /* Adding rounding bias. */ + uint32_t rounded = inter_temp + ( uint32_t )0x00007FFF + tlsb; + memcpy( temp_accum, &rounded, sizeof( float ) ); + + float_to_bf16( temp_accum, out_temp_accum ); +} + +#ifndef WIN32 +static inline int max (int a, int b) +{ + return ( a > b ? a : b ); +} + +static inline int min (int a, int b) +{ + return ( a < b ? a : b ); +} +#endif + +static inline void lpgemm_destroy_post_ops_struct( aocl_post_op* post_ops ) +{ + if ( post_ops == NULL ) + { + return; + } + + if ( post_ops->eltwise != NULL ) + { + for ( dim_t i = 0; i < num_eltwise; ++i ) + { + free( ( post_ops->eltwise + i )->algo.alpha ); + free( ( post_ops->eltwise + i )->algo.beta ); + } + free( post_ops->eltwise ); + } + + if ( post_ops->matrix_add != NULL ) + { + free( ( post_ops->matrix_add )->matrix ); + free( post_ops->matrix_add ); + } + + if ( post_ops->sum != NULL ) + { + free( ( post_ops->sum )->scale_factor ); + free( ( post_ops->sum )->zero_point ); + free( post_ops->sum ); + } + + if ( post_ops->matrix_mul != NULL ) + { + free( ( post_ops->matrix_mul )->matrix ); + free( post_ops->matrix_mul ); + } + + if ( post_ops->bias != NULL ) + { + free( ( post_ops->bias )->bias ); + free( post_ops->bias ); + } + + if ( post_ops->pre_ops != NULL ) + { + if ( ( post_ops->pre_ops )->b_zp != NULL ) + { + free( ( ( post_ops->pre_ops )->b_zp )->zero_point ); + free( ( post_ops->pre_ops )->b_zp ); + } + if ( ( post_ops->pre_ops )->b_scl != NULL ) + { + free( ( ( post_ops->pre_ops )->b_scl )->scale_factor ); + free( ( post_ops->pre_ops )->b_scl ); + } + free( post_ops->pre_ops ); + } + + free( post_ops->seq_vector ); + free( post_ops ); +} + +#endif //LPGEMM_BENCH_UTILS_H diff --git a/bench/bench_aocl_gemm/bench_lpgemm_utils.c b/bench/bench_aocl_gemm/bench_lpgemm_utils.c index 8ce8104df5..02c0c23769 100644 --- a/bench/bench_aocl_gemm/bench_lpgemm_utils.c +++ b/bench/bench_aocl_gemm/bench_lpgemm_utils.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -37,7 +37,6 @@ #include #include #include -#include #include #include "blis.h" @@ -89,31 +88,27 @@ void gelu_bench_driver_ ## GELU_SFX \ inc_t incx \ ) \ { \ - double min_time_diff = DBL_MAX; \ + double dtime; \ + double dtime_save = DBL_MAX; \ for ( int32_t nr = 0; nr < n_repeats; ++nr ) \ { \ - struct timespec tstart={0,0}, tend={0,0}; \ - clock_gettime(CLOCK_MONOTONIC, &tstart); \ + dtime = bli_clock(); \ \ if ( bench_mode == 'a' ) \ { \ GEN_FUNC_NAME(fill_array_,V_type)( x, ( n * incx) ); \ } \ \ - GEN_FUNC_NAME(aocl_,GELU_SFX) \ + GEN_FUNC_NAME(aocl_gemm_,GELU_SFX) \ ( \ n, x, incx \ ); \ \ - clock_gettime(CLOCK_MONOTONIC, &tend); \ + dtime_save = bli_clock_min_diff( dtime_save, dtime ); \ \ - double diff = \ - ( ( double ) tend.tv_sec + ( 1.0e-9 * tend.tv_nsec ) ) - \ - ( ( double ) tstart.tv_sec + ( 1.0e-9 * tstart.tv_nsec ) ); \ - min_time_diff = ( diff < min_time_diff ) ? diff : min_time_diff; \ } \ \ - print_result( XSTR(GELU_SFX), n_repeats, n, incx, min_time_diff); \ + print_result( XSTR(GELU_SFX), n_repeats, n, incx, dtime_save); \ } \ GEN_GELU_BENCH_DRV_FN(float,gelu_tanh_f32) @@ -128,31 +123,26 @@ void softmax_bench_driver_ ## SOFTMAX_SFX \ inc_t incx \ ) \ { \ - double min_time_diff = DBL_MAX; \ + double dtime; \ + double dtime_save = DBL_MAX; \ for ( int32_t nr = 0; nr < n_repeats; ++nr ) \ { \ - struct timespec tstart={0,0}, tend={0,0}; \ - clock_gettime(CLOCK_MONOTONIC, &tstart); \ + dtime = bli_clock(); \ \ if ( bench_mode == 'a' ) \ { \ GEN_FUNC_NAME(fill_array_,V_type)( x, ( n * incx) ); \ } \ \ - GEN_FUNC_NAME(aocl_,SOFTMAX_SFX) \ + GEN_FUNC_NAME(aocl_gemm_,SOFTMAX_SFX) \ ( \ n, x, incx \ ); \ \ - clock_gettime(CLOCK_MONOTONIC, &tend); \ - \ - double diff = \ - ( ( double ) tend.tv_sec + ( 1.0e-9 * tend.tv_nsec ) ) - \ - ( ( double ) tstart.tv_sec + ( 1.0e-9 * tstart.tv_nsec ) ); \ - min_time_diff = ( diff < min_time_diff ) ? diff : min_time_diff; \ + dtime_save = bli_clock_min_diff( dtime_save, dtime ); \ } \ \ - print_result( XSTR(SOFTMAX_SFX), n_repeats, n, incx, min_time_diff); \ + print_result( XSTR(SOFTMAX_SFX), n_repeats, n, incx, dtime_save); \ } \ GEN_SOFTMAX_BENCH_DRV_FN(float,softmax_f32) @@ -323,22 +313,26 @@ int main( int argc, char** argv ) } char* file_name = NULL; - - // Parse CLI arguments. - opterr = 0; - int opt_val; - while ( ( opt_val = getopt( argc, argv, "i:m:n:" ) ) != -1 ) + getopt_t state; + // Initialize the state for running bli_getopt(). Here, 0 is the + // initial value for opterr, which suppresses error messages. + bli_getopt_init_state( 0, &state ); + + int opt; + // Process all option arguments until we get a -1, which means we're done. + while( (opt = bli_getopt( argc, argv, "i:m:n:", &state )) != -1 ) { - switch ( opt_val ) + char opt_ch = ( char )opt; + switch( opt_ch ) { case 'i': - file_name = optarg; + file_name = state.optarg; break; case 'm': - bench_mode = ( ( ( *optarg ) == 'a' ) || ( ( *optarg ) == 'p' ) ) ? ( *optarg ) : 'p'; + bench_mode = ( ( ( *state.optarg ) == 'a' ) || ( ( *state.optarg ) == 'p' ) ) ? ( *state.optarg ) : 'p'; break; case 'n': - global_n_repeat = ( atoi( optarg ) > 0 ) ? atoi( optarg ) : 0; + global_n_repeat = ( atoi( state.optarg ) > 0 ) ? atoi( state.optarg ) : 0; break; default: break; diff --git a/bench/bench_aocl_gemm/data_gen_lpgemm.py b/bench/bench_aocl_gemm/data_gen_lpgemm.py index 3bc3a24421..563ccaf57d 100644 --- a/bench/bench_aocl_gemm/data_gen_lpgemm.py +++ b/bench/bench_aocl_gemm/data_gen_lpgemm.py @@ -34,19 +34,19 @@ # Initializing global mnk_array.This array will be used to store all mnk values mnk_array = [] -max_elem = 2500; +max_elem = 2600; out_file_name = "accuracy_test_data_lpgemm.txt" # Important mnk generator function.This will generate all possible combinations # of m,n,k values using formula m(t+1)=ROUND(m(t)*Base,0)+offset def mnk_generator(): k_1 = 1 - incr_k = 20 + incr_k = 500 while (k_1 <= max_elem): n_1 = 1 - incr_n = 20 + incr_n = 200 while (n_1 <= max_elem): m_1 = 1 - incr_m = 20 + incr_m = 100 while (m_1 <= max_elem): mnk_array.append([m_1, n_1, k_1]) if (m_1 == 1): @@ -68,8 +68,8 @@ def data_gen(): fout = open(out_file_name, "w") for ele in mnk_array: - fout.write("i r " + str(ele[0]) + " " + str(ele[1]) + " " + str(ele[2]) + " " +\ - str(ele[2]) + " " + str(ele[1]) + " " + str(ele[1]) + "\n") + fout.write("r n n n r " + str(ele[0]) + " " + str(ele[1]) + " " + str(ele[2]) + " " +\ + str(ele[2]) + " " + str(ele[1]) + " " + str(ele[1]) + " u8s8s32os32:none" + "\n") fout.truncate(fout.tell() - 1) fout.close() diff --git a/bench/bench_axpbyv.c b/bench/bench_axpbyv.c index db62ead33e..4dfa86666b 100644 --- a/bench/bench_axpbyv.c +++ b/bench/bench_axpbyv.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -14,7 +14,7 @@ - Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. - - Neither the name of The University of Texas nor the names of its + - Neither the name(s) of the copyright holder(s) nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. @@ -86,7 +86,7 @@ int main( int argc, char** argv ) #ifdef DEBUG fprintf( fout, "gflops\n" ); #else - fprintf(fout, "Dt\t n\t alpha_r\t alpha_i\t beta_r\t beta_i\t gflops\n" ); + fprintf(fout, "Func Dt n alpha_r alpha_i incx beta_r beta_i incy gflops\n" ); #endif dim_t n; // dimension @@ -253,8 +253,8 @@ int main( int argc, char** argv ) (unsigned long)n, gflops ); - fprintf( fout, "%c\t %ld\t %lf\t %lf\t %lf\t %lf\t %6.3f\n", - dt_ch, n, alpha_r, alpha_i, beta_r, beta_i, gflops ); + fprintf( fout, "%s %c %ld %lf %lf %ld %lf %lf %ld %6.3f\n", + tmp, dt_ch, n, alpha_r, alpha_i, incx, beta_r, beta_i, incy, gflops ); fflush( fout ); bli_obj_free( &x ); diff --git a/bench/bench_axpyv.c b/bench/bench_axpyv.c new file mode 100644 index 0000000000..03e2d64f85 --- /dev/null +++ b/bench/bench_axpyv.c @@ -0,0 +1,258 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifdef WIN32 +#include +#else +#include +#endif +#include "blis.h" + +#ifndef DT +#define DT BLIS_DOUBLE +#endif +#define AOCL_MATRIX_INITIALISATION + +int main( int argc, char** argv ) +{ + obj_t x, y, y_save, alpha; // BLIS objects + dim_t p_inc = 0; // To keep track of number of inputs + num_t dt; // BLIS datatype + char dt_ch; // {S, D, Z, C} from input + int r, n_repeats; // repetition counter; number of repeats + + double dtime; + double dtime_save; + double gflops; + + FILE* fin = NULL; // Input FILE* + FILE* fout = NULL; // Output FILE* + + n_repeats = N_REPEAT; // Fetched from Makefile + + dt = DT; // Set datatype as BLIS_DOUBLE + + if ( argc < 3 ) + { + printf( "Usage: ./bench_axpyv_XX.x input.txt output.txt\n" ); + exit( 1 ); + } + + fin = fopen( argv[1], "r" ); // Open input file in read mode + if ( fin == NULL ) + { + printf( "Error opening input file %s\n", argv[1] ); + exit( 1 ); + } + + fout = fopen( argv[2], "w" ); // Open output file in write mode + if ( fout == NULL ) + { + printf( "Error opening output file %s\n", argv[2] ); + exit( 1 ); + } + +#ifdef DEBUG + fprintf( fout, "gflops\n" ); +#else + fprintf(fout, "Func Dt n alphaR alphaI incx incy gflops\n" ); +#endif + + dim_t n; // dimension + inc_t incx; // stride x + inc_t incy; // stride y + char tmp[256]; // to store function name, line not present in logs + double alpha_r, alpha_i; + + // {function name} {S, D, C, Z} {n} + // {alpha_r} {alpha_i} {incx} {incy} + while ( fscanf( fin, "%s %c " INT_FS " %lf %lf " INT_FS INT_FS "\n", + tmp, &dt_ch, &n, + &alpha_r, &alpha_i, &incx, &incy ) == 7 ) + { + if ( dt_ch == 'D' || dt_ch == 'd' ) dt = BLIS_DOUBLE; + else if ( dt_ch == 'Z' || dt_ch == 'z' ) dt = BLIS_DCOMPLEX; + else if ( dt_ch == 'S' || dt_ch == 's' ) dt = BLIS_FLOAT; + else if ( dt_ch == 'C' || dt_ch == 'c' ) dt = BLIS_SCOMPLEX; + else + { + printf( "Invalid data type %c\n", dt_ch ); + continue; + } + + // Creating BLIS objects + bli_obj_create( dt, n, 1, incx, 1, &x ); // For input vector x + bli_obj_create( dt, n, 1, incy, 1, &y ); // For output vector y + bli_obj_create( dt, n, 1, incy, 1, &y_save ); // For vector y_save + bli_obj_create( dt, 1, 1, 0, 0, &alpha); // For input scalar alpha + + #ifdef AOCL_MATRIX_INITIALISATION + bli_randm( &x ); + bli_randm( &y ); + #endif + + // Copying contents of y to y_save + bli_copyv( &y, &y_save ); + + bli_setsc( alpha_r, alpha_i, &alpha ); + + dtime_save = DBL_MAX; + + for ( r = 0; r < n_repeats; ++r ) + { + // Copying contents of y_save to y + bli_copyv( &y_save, &y ); + + dtime = bli_clock(); + +#ifdef BLIS + bli_axpyv( &alpha, &x, &y ); +#else + f77_int nn = bli_obj_length( &x ); + f77_int blas_incx = bli_obj_vector_inc( &x ); + f77_int blas_incy = bli_obj_vector_inc( &y ); + + if ( bli_is_float( dt ) ) + { + float* alphap = bli_obj_buffer( &alpha ); + float* xp = bli_obj_buffer( &x ); + float* yp = bli_obj_buffer( &y ); + +#ifdef CBLAS + cblas_saxpy( nn, + *alphap, + xp, + blas_incx, + yp, + blas_incy ); +#else + saxpy_( &nn, + alphap, + xp, + &blas_incx, + yp, + &blas_incy ); +#endif + } + else if ( bli_is_double( dt ) ) + { + double* alphap = bli_obj_buffer( &alpha ); + double* xp = bli_obj_buffer( &x ); + double* yp = bli_obj_buffer( &y ); + +#ifdef CBLAS + cblas_daxpy( nn, + *alphap, + xp, + blas_incx, + yp, + blas_incy ); +#else + daxpy_( &nn, + alphap, + xp, + &blas_incx, + yp, + &blas_incy ); +#endif + } + else if ( bli_is_scomplex( dt ) ) + { + scomplex* alphap = bli_obj_buffer( &alpha ); + scomplex* xp = bli_obj_buffer( &x ); + scomplex* yp = bli_obj_buffer( &y ); + +#ifdef CBLAS + cblas_caxpy( nn, + *alphap, + xp, + blas_incx, + yp, + blas_incy ); +#else + caxpy_( &nn, + alphap, + xp, + &blas_incx, + yp, + &blas_incy ); +#endif + } + else if ( bli_is_dcomplex( dt ) ) + { + dcomplex* alphap = bli_obj_buffer( &alpha ); + dcomplex* xp = bli_obj_buffer( &x ); + dcomplex* yp = bli_obj_buffer( &y ); + +#ifdef CBLAS + cblas_zaxpy( nn, + *alphap, + xp, + blas_incx, + yp, + blas_incy ); +#else + zaxpy_( &nn, + alphap, + xp, + &blas_incx, + yp, + &blas_incy ); +#endif + } +#endif + + dtime_save = bli_clock_min_diff( dtime_save, dtime ); + } + gflops = ( 3.0 * n ) / ( dtime_save * 1.0e9 ); + if ( bli_is_complex( dt ) ) gflops *= 4.0; + + printf( "data_axpyv_%s", BLAS ); + + p_inc++; + printf( " %4lu [ %4lu %7.2f ];\n", + (unsigned long)(p_inc), + (unsigned long)n, + gflops ); + + fprintf( fout, "%s %c %ld %lf %lf %ld %ld %6.3f\n", + tmp, dt_ch, n, alpha_r, alpha_i, incx, incy, gflops ); + fflush( fout ); + + bli_obj_free( &x ); + bli_obj_free( &y ); + } + + return 0; +} diff --git a/bench/bench_copyv.c b/bench/bench_copyv.c index 1e7f20e647..24d8cbc8c1 100644 --- a/bench/bench_copyv.c +++ b/bench/bench_copyv.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -14,7 +14,7 @@ - Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. - - Neither the name of The University of Texas nor the names of its + - Neither the name(s) of the copyright holder(s) nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. @@ -94,7 +94,7 @@ int main( int argc, char** argv ) exit(1); } - fprintf(fout, "Dt\t n\t incx\t incy\t gflops\n"); + fprintf(fout, "Func Dt n incx incy gflops\n"); char tmp[256]; // to store function name, line no present in logs. dim_t n; diff --git a/bench/bench_dotv.c b/bench/bench_dotv.c index 9ca0cd386d..502834b315 100644 --- a/bench/bench_dotv.c +++ b/bench/bench_dotv.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -14,7 +14,7 @@ - Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. - - Neither the name of The University of Texas nor the names of its + - Neither the name(s) of the copyright holder(s) nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. @@ -45,7 +45,6 @@ #define DT BLIS_DOUBLE #endif - #define AOCL_MATRIX_INITIALISATION //#define BLIS_ENABLE_CBLAS @@ -63,7 +62,7 @@ int main( int argc, char** argv ) obj_t x, y, res; dim_t p_inc = 0; // to keep track of number of inputs num_t dt; - char dt_ch; + char dt_ch, conjx_ch; int r, n_repeats; double dtime; @@ -95,22 +94,23 @@ int main( int argc, char** argv ) exit(1); } - fprintf(fout, "Dt\t n\t incx\t incy\t gflops\n"); + fprintf(fout, "Func Dt trans n incx incy gflops\n"); dim_t n; inc_t incx; inc_t incy; + conj_t conjx; char tmp[256]; // to store function name, line no present in logs. - // {S,D,C,Z} {n incx incy} - while (fscanf(fin, "%s %c " INT_FS INT_FS INT_FS "\n", - tmp, &dt_ch, &n, &incx, &incy) == 5) + // {S,D,C,Z} {conjx n incx incy} + while (fscanf(fin, "%s %c %c " INT_FS INT_FS INT_FS "\n", + tmp, &dt_ch, &conjx_ch, &n, &incx, &incy) == 6) { #ifdef PRINT - fprintf (stdout, "Input = %s %c %ld %ld %ld %6.3f\n", - tmp, dt_ch, n, incx, incy, gflops); + fprintf (stdout, "Input = %s %c %c %ld %ld %ld %6.3f\n", + tmp, dt_ch, conjx_ch, n, incx, incy, gflops); #endif if (dt_ch == 'D' || dt_ch == 'd') dt = BLIS_DOUBLE; @@ -123,6 +123,14 @@ int main( int argc, char** argv ) continue; } + if ( conjx_ch == 'C' || conjx_ch == 'c' ) conjx = BLIS_CONJUGATE; + else if ( conjx_ch == 'N' || conjx_ch == 'n' ) conjx = BLIS_NO_CONJUGATE; + else + { + printf("Invalid conjugate value %c\n", conjx_ch); + continue; + } + // Create objects with required sizes and strides. // // The ?dot routines perform a vector-vector reduction operation defined as @@ -196,34 +204,61 @@ int main( int argc, char** argv ) yp, &incy ); #endif } - else if ( bli_is_scomplex( dt ) ) + else if ( bli_is_scomplex( dt ) && !bli_is_conj( conjx ) ) + { + scomplex* xp = bli_obj_buffer( &x ); + scomplex* yp = bli_obj_buffer( &y ); + scomplex* resp = bli_obj_buffer( &res ); + +#ifdef CBLAS + cblas_cdotu_sub( nn, + xp, incx, + yp, incy, resp ); +#else + +#ifdef BLIS_DISABLE_COMPLEX_RETURN_INTEL + *resp = cdotu_( &nn, + xp, &incx, + yp, &incy ); + +#else + cdotu_( resp, &nn, + xp, &incx, + yp, &incy ); + + +#endif // BLIS_DISABLE_COMPLEX_RETURN_INTEL + +#endif + } + else if ( bli_is_scomplex( dt ) && bli_is_conj( conjx ) ) { scomplex* xp = bli_obj_buffer( &x ); scomplex* yp = bli_obj_buffer( &y ); scomplex* resp = bli_obj_buffer( &res ); #ifdef CBLAS - cblas_cdotu_sub(nn, - xp, incx, - yp, incy, resp ); + cblas_cdotc_sub( nn, + xp, incx, + yp, incy, resp ); #else #ifdef BLIS_DISABLE_COMPLEX_RETURN_INTEL - *resp = cdotu_(&nn, - xp, &incx, - yp, &incy ); + *resp = cdotc_( &nn, + xp, &incx, + yp, &incy ); #else - cdotu_(resp, &nn, - xp, &incx, - yp, &incy ); + cdotc_( resp, &nn, + xp, &incx, + yp, &incy ); -#endif // BLIS_DISABLE_COMPLEX_RETURN_INTEL ... +#endif // BLIS_DISABLE_COMPLEX_RETURN_INTEL #endif } - else if ( bli_is_dcomplex( dt ) ) + else if ( bli_is_dcomplex( dt ) && !bli_is_conj( conjx ) ) { dcomplex* xp = bli_obj_buffer( &x ); dcomplex* yp = bli_obj_buffer( &y ); @@ -242,19 +277,47 @@ int main( int argc, char** argv ) #else zdotu_( resp, &nn, + xp, &incx, + yp, &incy ); + + +#endif // BLIS_DISABLE_COMPLEX_RETURN_INTEL + +#endif + } + else if ( bli_is_dcomplex( dt ) && bli_is_conj( conjx ) ) + { + dcomplex* xp = bli_obj_buffer( &x ); + dcomplex* yp = bli_obj_buffer( &y ); + dcomplex* resp = bli_obj_buffer( &res ); + +#ifdef CBLAS + cblas_zdotc_sub( nn, + xp, incx, + yp, incy, resp ); +#else + +#ifdef BLIS_DISABLE_COMPLEX_RETURN_INTEL + *resp = zdotc_( &nn, xp, &incx, yp, &incy ); +#else + zdotc_( resp, &nn, + xp, &incx, + yp, &incy ); + #endif // BLIS_DISABLE_COMPLEX_RETURN_INTEL #endif + } #endif // BLIS Interface #ifdef PRINT - bli_printm( "a after", &a, "%4.1f", "" ); + bli_printm( "res", &res, "%4.1f", "" ); exit(1); #endif @@ -272,7 +335,7 @@ int main( int argc, char** argv ) (unsigned long)n, gflops); - fprintf (fout, "%s %c %ld %ld %ld %6.3f\n", tmp, dt_ch, n, incx, incy, gflops); + fprintf (fout, "%s %c %c %ld %ld %ld %6.3f\n", tmp, dt_ch, conjx_ch, n, incx, incy, gflops); fflush(fout); diff --git a/bench/bench_gemm.c b/bench/bench_gemm.c old mode 100755 new mode 100644 index 454b8b0bc0..8ac6f83953 --- a/bench/bench_gemm.c +++ b/bench/bench_gemm.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -15,7 +15,7 @@ - Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. - - Neither the name of The University of Texas nor the names of its + - Neither the name(s) of the copyright holder(s) nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. @@ -114,7 +114,7 @@ int main( int argc, char** argv ) n_repeats = atoi(argv[3]); } - fprintf(fout, "Dt transa transb m n k alphaR alphaI lda ldb betaR betaI ldc gflops\n"); + fprintf(fout, "Func Dt transa transb m n k alphaR alphaI lda ldb betaR betaI ldc gflops\n"); // Following variables are needed for scanf to read inputs properly // however they are not used in bench. @@ -240,7 +240,7 @@ int main( int argc, char** argv ) #ifdef AOCL_MATRIX_INITIALISATION bli_randm( &a ); bli_randm( &b ); - bli_randm( &c ); + bli_randm( &c_save ); #endif bli_obj_set_conjtrans( transa, &a); @@ -249,7 +249,7 @@ int main( int argc, char** argv ) bli_setsc( alpha_r, alpha_i, &alpha ); bli_setsc( beta_r, beta_i, &beta ); - bli_copym( &c, &c_save ); + // bli_copym( &c, &c_save ); dtime_save = DBL_MAX; @@ -482,8 +482,8 @@ int main( int argc, char** argv ) (unsigned long)n, (unsigned long)k, gflops); - fprintf (fout, "%c %c %c %ld %ld %ld %lf %lf %ld %ld %lf %lf %ld %6.3f\n", \ - dt_ch, transA_c, transB_c, m, n, k, alpha_r, alpha_i, lda, ldb, beta_r, beta_i, ldc, gflops); + fprintf (fout, "%s %c %c %c %ld %ld %ld %lf %lf %ld %ld %lf %lf %ld %6.3f\n", \ + api_name, dt_ch, transA_c, transB_c, m, n, k, alpha_r, alpha_i, lda, ldb, beta_r, beta_i, ldc, gflops); fflush(fout); diff --git a/bench/bench_gemm_pack_compute.c b/bench/bench_gemm_pack_compute.c old mode 100755 new mode 100644 index 30236ee859..22f8e9ba78 --- a/bench/bench_gemm_pack_compute.c +++ b/bench/bench_gemm_pack_compute.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -14,7 +14,7 @@ - Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. - - Neither the name of The University of Texas nor the names of its + - Neither the name(s) of the copyright holder(s) nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. diff --git a/bench/bench_gemmt.c b/bench/bench_gemmt.c index cd2e5bf9b8..a2ddef1a13 100644 --- a/bench/bench_gemmt.c +++ b/bench/bench_gemmt.c @@ -4,8 +4,9 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - Redistributions of source code must retain the above copyright @@ -13,7 +14,7 @@ - Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. - - Neither the name of The University of Texas nor the names of its + - Neither the name(s) of the copyright holder(s) nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. @@ -107,7 +108,7 @@ int main( int argc, char** argv ) printf("Error opening output file %s\n", argv[2]); exit(1); } - fprintf(fout, "Dt\t uplo\t n\t k\t lda\t ldb\t ldc\t transa\t transb\t alphaR\t alphaI\t betaR\t betaI\t gflops\n"); + fprintf(fout, "Func Dt uplo n k lda ldb ldc transa transb alphaR alphaI betaR betaI gflops\n"); inc_t lda; @@ -463,8 +464,8 @@ int main( int argc, char** argv ) ( unsigned long )n, ( unsigned long )k, gflops ); - fprintf(fout, "%c\t %c\t %ld\t %ld\t %ld\t %ld\t %ld\t %c\t %c\t %lf\t %lf\t %lf\t %lf\t %6.3f\n", \ - dt_ch, uplo_c, n, k, lda, ldb, ldc, + fprintf(fout, "%s %c %c %ld %ld %ld %ld %ld %c %c %lf %lf %lf %lf %6.3f\n", \ + tmp, dt_ch, uplo_c, n, k, lda, ldb, ldc, transA_c, transB_c, alpha_r, alpha_i, diff --git a/bench/bench_gemv.c b/bench/bench_gemv.c old mode 100755 new mode 100644 index dd77a0539c..e8e9f121ca --- a/bench/bench_gemv.c +++ b/bench/bench_gemv.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -14,7 +14,7 @@ - Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. - - Neither the name of The University of Texas nor the names of its + - Neither the name(s) of the copyright holder(s) nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. @@ -99,7 +99,7 @@ int main( int argc, char** argv ) exit(1); } - fprintf(fout, "Dt transa\t m\t n\t alpha\t lda\t incx\t beta\t incy\t gflops\n"); + fprintf(fout, "Func Dt transa m n alphaR alphaI lda incx betaR betaI incy gflops\n"); char transA; dim_t m; diff --git a/bench/bench_ger.c b/bench/bench_ger.c index b4ee38a799..537ed016cb 100644 --- a/bench/bench_ger.c +++ b/bench/bench_ger.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -14,7 +14,7 @@ - Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. - - Neither the name of The University of Texas nor the names of its + - Neither the name(s) of the copyright holder(s) nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. @@ -101,7 +101,7 @@ int main( int argc, char** argv ) exit(1); } - fprintf(fout, "Dt\t m\t n\t alpha\t incx\t incy\t lda\t gflops\n"); + fprintf(fout, "Func Dt m n alphaR alphaI incx incy lda gflops\n"); dim_t m; dim_t n; diff --git a/bench/bench_nrm2.c b/bench/bench_nrm2.c index ae79eb3307..20fd140b4e 100644 --- a/bench/bench_nrm2.c +++ b/bench/bench_nrm2.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -14,7 +14,7 @@ - Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. - - Neither the name of The University of Texas nor the names of its + - Neither the name(s) of the copyright holder(s) nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. @@ -100,7 +100,7 @@ int main( int argc, char** argv ) exit(1); } - fprintf(fout, "Dt\t n\t incx\t gflops\n"); + fprintf(fout, "Func Dt n incx gflops\n"); dim_t n; inc_t incx; char tmp[256]; // to store function name, line no present in logs. @@ -225,8 +225,8 @@ int main( int argc, char** argv ) (unsigned long)n, gflops); - fprintf (fout, "%c %ld %ld %6.3f\n", - dt_ch, n, incx, gflops); + fprintf (fout, "%s %c %ld %ld %6.3f\n", + tmp, dt_ch, n, incx, gflops); fflush(fout); diff --git a/bench/bench_scalv.c b/bench/bench_scalv.c index 80b3762ea2..d3ce99718c 100644 --- a/bench/bench_scalv.c +++ b/bench/bench_scalv.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -14,7 +14,7 @@ - Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. - - Neither the name of The University of Texas nor the names of its + - Neither the name(s) of the copyright holder(s) nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. @@ -63,8 +63,8 @@ int main( int argc, char** argv ) obj_t x, x_save; obj_t alpha; dim_t p_inc = 0; // to keep track of number of inputs - num_t dt; - char dt_ch; + num_t dt_x, dt_alpha; + char dt_ch_x, dt_ch_alpha; int r, n_repeats; double dtime; @@ -76,7 +76,8 @@ int main( int argc, char** argv ) n_repeats = N_REPEAT; // This macro will get from Makefile. - dt = DT; + dt_x = DT; + dt_alpha = DT; if (argc < 3) { @@ -96,31 +97,44 @@ int main( int argc, char** argv ) exit(1); } - fprintf(fout, "Dt\t alpha\t n\t incx\t gflops\n"); + fprintf(fout, "Func Dt alphaR alphaI n incx gflops\n"); dim_t n; double alpha_r, alpha_i; inc_t incx; + char dt_ch[3]; // to store the API datatype char tmp[256]; // to store function name, line no present in logs. - // {S,D,C,Z} {alpha n incx} - while (fscanf(fin, "%s %c %lf %lf " INT_FS INT_FS "\n", - tmp, &dt_ch, &alpha_r, &alpha_i, &n, &incx) == 6) + while (fscanf(fin, "%s %s %lf %lf " INT_FS INT_FS "\n", + tmp, dt_ch, &alpha_r, &alpha_i, &n, &incx) == 6) { + dt_ch[2] = '\0'; // Null terminating the string for logging purpose #ifdef PRINT - fprintf (stdout, "Input = %s %c %lf %lf %ld %ld\n", + fprintf (stdout, "Input = %s %s %lf %lf %ld %ld\n", tmp, dt_ch, alpha_r, alpha_i, n, incx); #endif - - if (dt_ch == 'D' || dt_ch == 'd') dt = BLIS_DOUBLE; - else if (dt_ch == 'Z' || dt_ch == 'z') dt = BLIS_DCOMPLEX; - else if (dt_ch == 'S' || dt_ch == 's') dt = BLIS_FLOAT; - else if (dt_ch == 'C' || dt_ch == 'c') dt = BLIS_SCOMPLEX; + // Acquiring the datatype of input vector x + dt_ch_x = dt_ch[0]; + if (dt_ch_x == 'D' || dt_ch_x == 'd') dt_x = BLIS_DOUBLE; + else if (dt_ch_x == 'Z' || dt_ch_x == 'z') dt_x = BLIS_DCOMPLEX; + else if (dt_ch_x == 'S' || dt_ch_x == 's') dt_x = BLIS_FLOAT; + else if (dt_ch_x == 'C' || dt_ch_x == 'c') dt_x = BLIS_SCOMPLEX; + else + { + printf("Invalid data type %c\n", dt_ch_x); + continue; + } + + // Acquiring the datatype of input scalar alpha + dt_ch_alpha = dt_ch[1]; + if (dt_ch_alpha == 'D' || dt_ch_alpha == 'd') dt_alpha = BLIS_DOUBLE; + else if (dt_ch_alpha == 'S' || dt_ch_alpha == 's') dt_alpha = BLIS_FLOAT; + else if(dt_ch_alpha == '\0') dt_alpha = dt_x; else { - printf("Invalid data type %c\n", dt_ch); + printf("Invalid data type %c\n", dt_ch_alpha); continue; } @@ -135,14 +149,14 @@ int main( int argc, char** argv ) // a is a scalar // X is an n-element vector. - bli_obj_create( dt, n, 1, incx, 1, &x ); - bli_obj_create( dt, n, 1, incx, 1, &x_save ); + bli_obj_create( dt_x, n, 1, incx, 1, &x ); + bli_obj_create( dt_x, n, 1, incx, 1, &x_save ); #ifdef AOCL_MATRIX_INITIALISATION bli_randm( &x ); #endif - bli_obj_create( dt, 1, 1, 0, 0, &alpha ); + bli_obj_create( dt_alpha, 1, 1, 0, 0, &alpha ); bli_setsc( alpha_r, alpha_i, &alpha ); bli_copym( &x, &x_save ); @@ -168,19 +182,19 @@ int main( int argc, char** argv ) f77_int nn = bli_obj_length( &x ); f77_int blas_incx = bli_obj_vector_inc( &x ); - if ( bli_is_float( dt ) ){ + if ( bli_is_float( dt_x ) && bli_is_float( dt_alpha ) ){ float* xp = bli_obj_buffer( &x ); float* scalar = bli_obj_buffer( &alpha ); #ifdef CBLAS cblas_sscal( nn, *scalar, xp, blas_incx ); -#else // cblas scal +#else // cblas sscal sscal_( &nn, scalar, xp, &blas_incx ); -#endif // cblas scal +#endif // cblas sscal } - else if ( bli_is_double( dt ) ) + else if ( bli_is_double( dt_x ) && bli_is_double( dt_alpha ) ) { double* xp = bli_obj_buffer( &x ); @@ -195,7 +209,7 @@ int main( int argc, char** argv ) xp, &blas_incx ); #endif // cblas dscal } - else if ( bli_is_scomplex( dt ) ) + else if ( bli_is_scomplex( dt_x ) && bli_is_scomplex( dt_alpha ) ) { scomplex* xp = bli_obj_buffer( &x ); scomplex* scalar = bli_obj_buffer( &alpha ); @@ -209,7 +223,7 @@ int main( int argc, char** argv ) xp, &blas_incx ); #endif // cblas cscal } - else if ( bli_is_dcomplex( dt ) ) + else if ( bli_is_dcomplex( dt_x ) && bli_is_dcomplex( dt_alpha ) ) { dcomplex* xp = bli_obj_buffer( &x ); dcomplex* scalar = bli_obj_buffer( &alpha ); @@ -220,7 +234,33 @@ int main( int argc, char** argv ) #else // cblas zscal zscal_( &nn, scalar, xp, &blas_incx ); -#endif // cblas zcscal +#endif // cblas zscal + } + else if ( bli_is_scomplex( dt_x ) && bli_is_float( dt_alpha ) ) + { + scomplex* xp = bli_obj_buffer( &x ); + float* scalar = bli_obj_buffer( &alpha ); +#ifdef CBLAS + cblas_csscal( nn, + *scalar, + xp, blas_incx ); +#else // cblas csscal + csscal_( &nn, scalar, + xp, &blas_incx ); +#endif // cblas csscal + } + else if ( bli_is_dcomplex( dt_x ) && bli_is_double( dt_alpha ) ) + { + dcomplex* xp = bli_obj_buffer( &x ); + double* scalar = bli_obj_buffer( &alpha ); +#ifdef CBLAS + cblas_zdscal( nn, + *scalar, + xp, blas_incx ); +#else // cblas zdscal + zdscal_( &nn, scalar, + xp, &blas_incx ); +#endif // cblas zdscal } #endif // BLIS Interface @@ -235,7 +275,11 @@ int main( int argc, char** argv ) gflops = n / ( dtime_save * 1.0e9 ); - if ( bli_is_complex( dt ) ) gflops *= 4.0; + if ( bli_is_complex( dt_x ) ) + { + if( bli_is_complex( dt_alpha ) ) gflops *= 4.0; + else if( bli_is_real( dt_alpha ) ) gflops *= 2.0; + } printf( "data_scalv_%s", BLAS ); @@ -245,7 +289,7 @@ int main( int argc, char** argv ) (unsigned long)n, gflops); - fprintf (fout, "%s %c %lf %lf %ld %ld %6.3f\n", + fprintf (fout, "%s %s %lf %lf %ld %ld %6.3f\n", tmp, dt_ch, alpha_r, alpha_i, n, incx, gflops); fflush(fout); diff --git a/bench/bench_swapv.c b/bench/bench_swapv.c index 3040d7b582..fe3ac5d84f 100644 --- a/bench/bench_swapv.c +++ b/bench/bench_swapv.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -14,7 +14,7 @@ - Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. - - Neither the name of The University of Texas nor the names of its + - Neither the name(s) of the copyright holder(s) nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. @@ -95,7 +95,7 @@ int main( int argc, char** argv ) exit(1); } - fprintf(fout, "Dt\t n\t incx\t incy\t gflops\n"); + fprintf(fout, "Func Dt n incx incy gflops\n"); dim_t n; inc_t incx; diff --git a/bench/bench_syrk.c b/bench/bench_syrk.c index 5bcc20e060..b7a3ea87f2 100644 --- a/bench/bench_syrk.c +++ b/bench/bench_syrk.c @@ -4,8 +4,9 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - Redistributions of source code must retain the above copyright @@ -13,7 +14,7 @@ - Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. - - Neither the name of The University of Texas nor the names of its + - Neither the name(s) of the copyright holder(s) nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. @@ -106,7 +107,7 @@ int main( int argc, char** argv ) printf("Error opening output file %s\n", argv[2]); exit(1); } - fprintf(fout, "Dt uploc transa n\t k\t alphaR\t alphaI\t betaR\t betaI\t lda\t ldc\t gflops\n"); + fprintf(fout, "Func Dt uploc transa n k alphaR alphaI lda betaR betaI ldc gflops\n"); inc_t lda; @@ -411,12 +412,11 @@ int main( int argc, char** argv ) ( unsigned long )n, ( unsigned long )k, gflops ); - fprintf(fout, "%c %c %c %ld\t %ld\t %lf\t %lf\t %lf\t %lf\t %lu\t %lu\t %6.3f\n", \ - dt_ch, uplo_c, transA_c, n, k, + fprintf(fout, "%s %c %c %c %ld %ld %lf %lf %lu %lf %lf %lu %6.3f\n", \ + tmp, dt_ch, uplo_c, transA_c, n, k, alpha_r, alpha_i, - beta_r, beta_i, - lda, ldc, - gflops + lda, beta_r, beta_i, + ldc, gflops ); fflush(fout); diff --git a/bench/bench_trsm.c b/bench/bench_trsm.c index 87dd677a4d..9ea4cd57f4 100644 --- a/bench/bench_trsm.c +++ b/bench/bench_trsm.c @@ -1,11 +1,11 @@ - /* + BLIS An object-based framework for developing high-performance BLAS-like libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -15,9 +15,10 @@ - Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. - - Neither the name of The University of Texas nor the names of its + - Neither the name(s) of the copyright holder(s) nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR @@ -97,7 +98,7 @@ int main( int argc, char** argv ) printf("Error opening the file %s\n", argv[2]); exit(1); } - fprintf(fout,"dt\t side\t uploa\t transa\t diaga\t m\t n\t lda\t ldb\t alphaR\t alphaI\t gflops\n"); + fprintf(fout,"Func dt side uploa transa diaga m n lda ldb alphaR alphaI gflops\n"); dim_t lda,ldb; f77_char dt_type_arg, side_arg, uploa_arg, transa_arg, diaga_arg; @@ -398,9 +399,9 @@ int main( int argc, char** argv ) printf( "( %2lu, 1:2 ) = [ %4lu %7.2f ];\n", ( unsigned long )p_inc, ( unsigned long )m, gflops ); - fprintf(fout,"%c\t %c\t %c\t %c\t %c\t %4lu\t %4lu\t %4lu\t %4lu\t %6.3f\t %6.3f\t %6.3f\n", - dt_type_arg, side_arg, uploa_arg, transa_arg, - diaga_arg, (unsigned long )m, (unsigned long ) n, (unsigned long )lda, + fprintf(fout,"%s %c %c %c %c %c %4lu %4lu %4lu %4lu %6.3f %6.3f %6.3f\n", + logline, dt_type_arg, side_arg, uploa_arg, transa_arg, + diaga_arg, (unsigned long )m, (unsigned long )n, (unsigned long )lda, (unsigned long )ldb, alphaR, alphaI, gflops); fflush(fout); bli_obj_free( &alpha ); diff --git a/bench/bench_trsv.c b/bench/bench_trsv.c index 4714f813d4..db1812a9e4 100644 --- a/bench/bench_trsv.c +++ b/bench/bench_trsv.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -14,7 +14,7 @@ - Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. - - Neither the name of The University of Texas nor the names of its + - Neither the name(s) of the copyright holder(s) nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. @@ -118,7 +118,7 @@ int main( int argc, char** argv ) exit(1); } - fprintf(fout, "Dt uploa\t transa\t diaga\t m\t lda\t incx\t gflops\n"); + fprintf(fout, "Func Dt uploa transa diaga m lda incx gflops\n"); // {S,D,C,Z} {uploa transa diaga m lda, incx} while (fscanf(fin, "%s %c %c %c %c " INT_FS INT_FS INT_FS "\n", @@ -383,7 +383,7 @@ int main( int argc, char** argv ) ( unsigned long )p_inc, ( unsigned long )m, gflops ); - fprintf (fout, "%s\t %c\t %c\t %c\t %c\t %ld\t %ld\t %ld\t %6.3f\n", + fprintf (fout, "%s %c %c %c %c %ld %ld %ld %6.3f\n", tmp, dt_ch, uploa_c, transA, diaga_c, m, lda, incx, gflops); fflush(fout); diff --git a/bench/inputaxpyv.txt b/bench/inputaxpyv.txt new file mode 100644 index 0000000000..173b6e496c --- /dev/null +++ b/bench/inputaxpyv.txt @@ -0,0 +1,40 @@ +saxpyv_ S 32 0.900000 0.000000 1 1 +saxpyv_ S 64 1.000000 0.000000 1 1 +saxpyv_ S 100 -1 0.000000 1 1 +saxpyv_ S 200 -1.100000 0.000000 1 1 +saxpyv_ S 300 1.100000 0.000000 1 1 +saxpyv_ S 400 0.900000 0.000000 1 1 +saxpyv_ S 500 1.000000 0.000000 1 1 +saxpyv_ S 1000 -1 0.000000 1 1 +saxpyv_ S 5000 -1.100000 0.000000 1 1 +saxpyv_ S 10000 1.100000 0.000000 1 1 +daxpyv_ D 32 0.900000 0.000000 1 1 +daxpyv_ D 64 1.000000 0.000000 1 1 +daxpyv_ D 100 -1 0.000000 1 1 +daxpyv_ D 200 -1.100000 0.000000 1 1 +daxpyv_ D 300 1.100000 0.000000 1 1 +daxpyv_ D 400 0.900000 0.000000 1 1 +daxpyv_ D 500 1.000000 0.000000 1 1 +daxpyv_ D 1000 -1 0.000000 1 1 +daxpyv_ D 5000 -1.100000 0.000000 1 1 +daxpyv_ D 10000 1.100000 0.000000 1 1 +caxpyv_ C 32 0.900000 -1.100000 1 1 +caxpyv_ C 64 1.000000 1.100000 1 1 +caxpyv_ C 100 -1 1.000000 1 1 +caxpyv_ C 200 -1.100000 0.900000 1 1 +caxpyv_ C 300 1.100000 1.000000 1 1 +caxpyv_ C 400 0.900000 -1.100000 1 1 +caxpyv_ C 500 1.000000 1.000000 1 1 +caxpyv_ C 1000 -1 0.900000 1 1 +caxpyv_ C 5000 -1.100000 -1 1 1 +caxpyv_ C 10000 1.100000 -1 1 1 +zaxpyv_ Z 32 0.900000 -1.100000 1 1 +zaxpyv_ Z 64 1.000000 1.100000 1 1 +zaxpyv_ Z 100 -1 1.000000 1 1 +zaxpyv_ Z 200 -1.100000 0.900000 1 1 +zaxpyv_ Z 300 1.100000 1.000000 1 1 +zaxpyv_ Z 400 0.900000 -1.100000 1 1 +zaxpyv_ Z 500 1.000000 1.000000 1 1 +zaxpyv_ Z 1000 -1 0.900000 1 1 +zaxpyv_ Z 5000 -1.100000 -1 1 1 +zaxpyv_ Z 10000 1.100000 -1 1 1 diff --git a/bench/inputdotv.txt b/bench/inputdotv.txt index 16e761de86..7f13f935d2 100644 --- a/bench/inputdotv.txt +++ b/bench/inputdotv.txt @@ -1,32 +1,56 @@ -ddot_:183: D 0 100 1 -ddot_:183: D 1 100 1 -ddot_:183: D 10 100 1 -ddot_:183: D 11 100 1 -ddot_:183: D 12 100 1 -ddot_:183: D 13 100 1 -ddot_:183: D 14 100 1 -ddot_:183: D 15 100 1 -ddot_:183: D 2 100 1 -ddot_:183: D 3 100 1 -ddot_:183: D 4 100 1 -ddot_:183: D 5 100 1 -ddot_:183: D 6 100 1 -ddot_:183: D 7 100 1 -ddot_:183: D 8 100 1 -ddot_:183: D 9 100 1 -ddot_:183: D 100 100 100 -ddot_:183: D 100 1 100 -ddot_:183: D 100 100 1 -ddot_:183: D 100 1 1 -sdot_:102: S 4000 1 1 -sdot_:102: S 4000 1 1 -sdot_:102: S 4000 1 1 -sdot_:102: S 3960 1 1 -sdot_:102: S 3960 1 1 -sdot_:102: S 3960 1 1 -sdot_:102: S 3920 1 1 -sdot_:102: S 3920 1 1 -sdot_:102: S 3920 1 1 -sdot_:102: S 3880 1 1 -sdot_:102: S 3880 1 1 -sdot_:102: S 3880 1 1 +ddot_:183: D N 0 100 1 +ddot_:183: D N 1 100 1 +ddot_:183: D N 10 100 1 +ddot_:183: D N 11 100 1 +ddot_:183: D N 12 100 1 +ddot_:183: D N 13 100 1 +ddot_:183: D N 14 100 1 +ddot_:183: D N 15 100 1 +ddot_:183: D N 2 100 1 +ddot_:183: D N 3 100 1 +ddot_:183: D N 4 100 1 +ddot_:183: D N 5 100 1 +ddot_:183: D N 6 100 1 +ddot_:183: D N 7 100 1 +ddot_:183: D N 8 100 1 +ddot_:183: D N 9 100 1 +ddot_:183: D N 100 100 100 +ddot_:183: D N 100 1 100 +ddot_:183: D N 100 100 1 +ddot_:183: D N 100 1 1 +sdot_:102: S N 4000 1 1 +sdot_:102: S N 4000 1 1 +sdot_:102: S N 4000 1 1 +sdot_:102: S N 3960 1 1 +sdot_:102: S N 3960 1 1 +sdot_:102: S N 3960 1 1 +sdot_:102: S N 3920 1 1 +sdot_:102: S N 3920 1 1 +sdot_:102: S N 3920 1 1 +sdot_:102: S N 3880 1 1 +sdot_:102: S N 3880 1 1 +sdot_:102: S N 3880 1 1 +cdot_ C N 4000 1 1 +cdot_ C N 4000 1 1 +cdot_ C N 4000 1 1 +cdot_ C N 3960 1 1 +cdot_ C N 3960 1 1 +cdot_ C N 3960 1 1 +cdot_ C C 3920 1 1 +cdot_ C C 3920 1 1 +cdot_ C C 3920 1 1 +cdot_ C C 3880 1 1 +cdot_ C C 3880 1 1 +cdot_ C C 3880 1 1 +zdot_ Z N 4000 1 1 +zdot_ Z N 4000 1 1 +zdot_ Z N 4000 1 1 +zdot_ Z N 3960 1 1 +zdot_ Z N 3960 1 1 +zdot_ Z N 3960 1 1 +zdot_ Z C 3920 1 1 +zdot_ Z C 3920 1 1 +zdot_ Z C 3920 1 1 +zdot_ Z C 3880 1 1 +zdot_ Z C 3880 1 1 +zdot_ Z C 3880 1 1 diff --git a/bench/inputscalv.txt b/bench/inputscalv.txt index 858574546c..a27c5b7924 100644 --- a/bench/inputscalv.txt +++ b/bench/inputscalv.txt @@ -1,13 +1,13 @@ -dscal_:171: D -0.147008 0.000000 8 1 -dscal_:171: D -0.180536 0.000000 5 1 -dscal_:171: D -0.194791 0.000000 30 1 -dscal_:171: D -0.248750 0.000000 24 1 -dscal_:171: D -0.263444 0.000000 7 1 -dscal_:171: D -0.264469 0.000000 13 1 -dscal_:171: D -0.288548 0.000000 22 1 -dscal_:171: D -0.314614 0.000000 9 1 -dscal_:171: D -0.349634 0.000000 14 1 -dscal_:171: D -0.403135 0.000000 23 1 +sscal_:171: S -0.147008 0.000000 8 1 +sscal_:171: S -0.180536 0.000000 5 1 +sscal_:171: S -0.194791 0.000000 30 1 +sscal_:171: S -0.248750 0.000000 24 1 +sscal_:171: S -0.263444 0.000000 7 1 +sscal_:171: S -0.264469 0.000000 13 1 +sscal_:171: S -0.288548 0.000000 22 1 +sscal_:171: S -0.314614 0.000000 9 1 +sscal_:171: S -0.349634 0.000000 14 1 +sscal_:171: S -0.403135 0.000000 23 1 dscal_:171: D -0.421537 0.000000 31 1 dscal_:171: D -0.449256 0.000000 40 1 dscal_:171: D -0.500709 0.000000 42 1 @@ -18,34 +18,43 @@ dscal_:171: D -0.550148 0.000000 25 1 dscal_:171: D -0.559501 0.000000 44 1 dscal_:171: D -0.612256 0.000000 2 1 dscal_:171: D -0.755356 0.000000 45 1 -dscal_:171: D -0.759262 0.000000 47 1 -dscal_:171: D -0.900525 0.000000 48 1 -dscal_:171: D 0.216330 0.000000 4 1 -dscal_:171: D 0.220087 0.000000 10 1 -dscal_:171: D 0.252043 0.000000 21 1 -dscal_:171: D 0.280487 0.000000 15 1 -dscal_:171: D 0.296225 0.000000 29 1 -dscal_:171: D 0.299399 0.000000 18 1 -dscal_:171: D 0.314779 0.000000 12 1 -dscal_:171: D 0.321521 0.000000 17 1 -dscal_:171: D 0.324458 0.000000 11 1 -dscal_:171: D 0.339212 0.000000 0 1 -dscal_:171: D 0.359467 0.000000 20 1 -dscal_:171: D 0.364805 0.000000 19 1 -dscal_:171: D 0.377414 0.000000 28 1 -dscal_:171: D 0.384282 0.000000 3 1 -dscal_:171: D 0.394021 0.000000 36 1 -dscal_:171: D 0.411089 0.000000 37 1 -dscal_:171: D 0.429686 0.000000 27 1 -dscal_:171: D 0.436665 0.000000 34 1 -dscal_:171: D 0.459632 0.000000 33 1 -dscal_:171: D 0.468809 0.000000 16 1 -dscal_:171: D 0.471083 0.000000 32 1 -dscal_:171: D 0.474866 0.000000 38 1 -dscal_:171: D 0.487050 0.000000 35 1 -dscal_:171: D 0.553630 0.000000 39 1 -dscal_:171: D 0.591314 0.000000 1 1 -dscal_:171: D 0.600389 0.000000 41 1 -dscal_:171: D 0.749844 0.000000 43 1 -dscal_:171: D 1.002156 0.000000 49 1 - +cscal_:171: C -0.759262 -0.759262 47 1 +cscal_:171: C -0.900525 -0.900525 48 1 +cscal_:171: C 0.216330 0.216330 4 1 +cscal_:171: C 0.220087 0.220087 10 1 +cscal_:171: C 0.252043 0.252043 21 1 +cscal_:171: C 0.280487 0.280487 15 1 +cscal_:171: C 0.296225 0.296225 29 1 +cscal_:171: C 0.299399 0.299399 18 1 +cscal_:171: C 0.314779 0.314779 12 1 +cscal_:171: C 0.321521 0.321521 17 1 +zscal_:171: Z 0.324458 0.324458 11 1 +zscal_:171: Z 0.339212 0.339212 0 1 +zscal_:171: Z 0.359467 0.359467 20 1 +zscal_:171: Z 0.364805 0.364805 19 1 +zscal_:171: Z 0.377414 0.377414 28 1 +zscal_:171: Z 0.384282 0.384282 3 1 +zscal_:171: Z 0.394021 0.394021 36 1 +zscal_:171: Z 0.411089 0.411089 37 1 +zscal_:171: Z 0.429686 0.429686 27 1 +zscal_:171: Z 0.436665 0.436665 34 1 +csscal_:171: CS 0.459632 0.000000 33 1 +csscal_:171: CS 0.468809 0.000000 16 1 +csscal_:171: CS 0.471083 0.000000 32 1 +csscal_:171: CS 0.474866 0.000000 38 1 +csscal_:171: CS 0.487050 0.000000 35 1 +csscal_:171: CS 0.553630 0.000000 39 1 +csscal_:171: CS 0.591314 0.000000 1 1 +csscal_:171: CS 0.600389 0.000000 41 1 +csscal_:171: CS 0.749844 0.000000 43 1 +csscal_:171: CS 1.002156 0.000000 49 1 +zdscal_:171: ZD 0.459632 0.000000 33 1 +zdscal_:171: ZD 0.468809 0.000000 16 1 +zdscal_:171: ZD 0.471083 0.000000 32 1 +zdscal_:171: ZD 0.474866 0.000000 38 1 +zdscal_:171: ZD 0.487050 0.000000 35 1 +zdscal_:171: ZD 0.553630 0.000000 39 1 +zdscal_:171: ZD 0.591314 0.000000 1 1 +zdscal_:171: ZD 0.600389 0.000000 41 1 +zdscal_:171: ZD 0.749844 0.000000 43 1 +zdscal_:171: ZD 1.002156 0.000000 49 1 diff --git a/blastest/CMakeLists.txt b/blastest/CMakeLists.txt index c8a653c2fa..02d99a3b4c 100644 --- a/blastest/CMakeLists.txt +++ b/blastest/CMakeLists.txt @@ -1,4 +1,36 @@ -##Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved.## +#[=[ + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +]=] # Comments: # - DIST_PATH is assumed to not exist if BLIS_INSTALL_PATH is given. @@ -10,7 +42,7 @@ if(NOT DEFINED BLIS_INSTALL_PATH) set(INC_PATH ${DIST_PATH}/include/${BLIS_CONFIG_FAMILY}) else() set(LIB_PATH ${BLIS_INSTALL_PATH}/lib) - set(INC_PATH ${BLIS_INSTALL_PATH}/include/blis) + set(INC_PATH ${BLIS_INSTALL_PATH}/include/${BLIS_CONFIG_FAMILY}) endif() # Include the corresponding make_defs.cmake that holds the required compiler options. @@ -30,7 +62,7 @@ target_compile_options(f2c ${CMISCFLAGS} ${CLANGFLAGS} # Suppress warnings about uninitialized functions - -Wno-maybe-uninitialized -Wno-parentheses -Wfatal-errors + -Wno-uninitialized -Wno-parentheses -Wfatal-errors ) target_compile_definitions(f2c PRIVATE @@ -49,7 +81,11 @@ target_include_directories(f2c ) target_link_libraries(f2c PRIVATE ${LDFLAGS}) if(THREADING_MODEL STREQUAL "openmp") - target_link_libraries(f2c PRIVATE OpenMP::OpenMP_C) + if((NOT ${OpenMP_libomp_LIBRARY} STREQUAL "") AND (NOT WIN32)) + target_link_libraries(f2c PRIVATE ${OpenMP_libomp_LIBRARY}) + else() + target_link_libraries(f2c PRIVATE OpenMP::OpenMP_C) + endif() endif() # Put all those targets under blastest-targets-targets folder name so that they appear all together in IDE. set_target_properties(f2c PROPERTIES FOLDER blastest-targets) @@ -74,7 +110,7 @@ foreach(source ${blastest_sources}) ${CMISCFLAGS} ${CLANGFLAGS} # Suppress warnings about uninitialized functions - -Wno-parentheses -Wno-maybe-uninitialized + -Wno-parentheses -Wno-uninitialized ) target_compile_definitions(${exec_name}.x PRIVATE @@ -91,9 +127,13 @@ foreach(source ${blastest_sources}) # and the path to blis.h ${INC_PATH} ) - target_link_libraries(${exec_name}.x PRIVATE f2c libblis ${LDFLAGS}) + target_link_libraries(${exec_name}.x PRIVATE f2c ${libblis_link} ${LDFLAGS}) if(THREADING_MODEL STREQUAL "openmp") - target_link_libraries(${exec_name}.x PRIVATE OpenMP::OpenMP_C) + if((NOT ${OpenMP_libomp_LIBRARY} STREQUAL "") AND (NOT WIN32)) + target_link_libraries(${exec_name}.x PRIVATE ${OpenMP_libomp_LIBRARY}) + else() + target_link_libraries(${exec_name}.x PRIVATE OpenMP::OpenMP_C) + endif() endif() set_target_properties(${exec_name}.x PROPERTIES CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) # Put all those targets under blastest-targets-targets folder name so that they appear all together in IDE. @@ -105,7 +145,7 @@ foreach(source ${blastest_sources}) COMMENT "Running ${exec_name}.x with output redirected to out.${exec_name}" DEPENDS ${exec_name}.x BYPRODUCTS ${CMAKE_BINARY_DIR}/out.${exec_name} - WORKING_DIRECTORY $ + WORKING_DIRECTORY $ VERBATIM ) else()# name has 2 or 3 @@ -114,7 +154,7 @@ foreach(source ${blastest_sources}) COMMENT "Running ${exec_name}.x with input ${CMAKE_CURRENT_SOURCE_DIR}/input/${exec_name}.in and output saved to out.${exec_name}" DEPENDS ${exec_name}.x BYPRODUCTS ${CMAKE_BINARY_DIR}/out.${exec_name} - WORKING_DIRECTORY $ + WORKING_DIRECTORY $ VERBATIM ) endif() @@ -123,11 +163,22 @@ foreach(source ${blastest_sources}) list(APPEND test_executables "run-${exec_name}") endforeach() -add_custom_target(testblas DEPENDS ${test_executables}) -add_custom_target(checkblas - COMMAND ${Python_EXECUTABLE} ${CMAKE_SOURCE_DIR}/build/cmake/check-blastest.py "." - DEPENDS testblas - WORKING_DIRECTORY $ - ) +if(WIN32 AND BUILD_SHARED_LIBS) + add_custom_target(testblas + DEPENDS ${libblis_link} + COMMENT "`testblas` target is not available on Windows for shared builds of BLIS. ${DETAILED_BLATEST_MESSAGE}" + ) + add_custom_target(checkblas + DEPENDS testblas + COMMENT "`checkblas` target is not available on Windows for shared builds of BLIS. ${DETAILED_BLATEST_MESSAGE}" + ) +else() + add_custom_target(testblas DEPENDS ${test_executables}) + add_custom_target(checkblas + COMMAND ${Python_EXECUTABLE} ${CMAKE_SOURCE_DIR}/build/cmake/check-blastest.py "." + DEPENDS testblas + WORKING_DIRECTORY $ + ) +endif() # Put all those targets under blastest-targets-targets folder name so that they appear all together in IDE. -set_target_properties(testblas checkblas PROPERTIES FOLDER blastest-targets) \ No newline at end of file +set_target_properties(testblas checkblas PROPERTIES FOLDER blastest-targets) diff --git a/blastest/f2c/arith.h b/blastest/f2c/arith.h index 11a071d511..8beaabfda1 100644 --- a/blastest/f2c/arith.h +++ b/blastest/f2c/arith.h @@ -27,10 +27,10 @@ use or performance of this software. #include #include -#ifdef _MSC_VER -#define isnan _isnan -#define isinf(x) (!_finite(x)) -#endif + + + + #ifndef isnan # define isnan(x) \ diff --git a/blastest/f2c/f2c.h b/blastest/f2c/f2c.h index fdebec8afd..48575e6e0c 100644 --- a/blastest/f2c/f2c.h +++ b/blastest/f2c/f2c.h @@ -33,11 +33,7 @@ use or performance of this software. #include #include -#ifdef _MSC_VER -# include -#else -# include -#endif +#include #ifdef __cplusplus extern "C" { @@ -161,10 +157,12 @@ struct Namelist { }; typedef struct Namelist Namelist; -#define abs(x) ((x) >= 0 ? (x) : -(x)) -#define dabs(x) (doublereal)abs(x) +#ifndef _MSC_VER #define min(a,b) ((a) <= (b) ? (a) : (b)) #define max(a,b) ((a) >= (b) ? (a) : (b)) +#endif +#define abs(x) ((x) >= 0 ? (x) : -(x)) +#define dabs(x) (doublereal)abs(x) #define dmin(a,b) (doublereal)min(a,b) #define dmax(a,b) (doublereal)max(a,b) #define bit_test(a,b) ((a) >> (b) & 1) diff --git a/blastest/src/cblat2.c b/blastest/src/cblat2.c index 2916a36a4e..c18ffe0b70 100644 --- a/blastest/src/cblat2.c +++ b/blastest/src/cblat2.c @@ -1,4 +1,7 @@ /* cblat2.f -- translated by f2c (version 20100827). + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + You must link the resulting object file with libf2c: on Microsoft Windows system, link with libf2c.lib; on Linux or Unix systems, link with .../path/to/libf2c.a -lm @@ -5451,7 +5454,7 @@ real sdiff_(real *x, real *y) } /* chkxer_ */ -/* Subroutine */ int xerbla_(char *srname, integer *info, ftnlen srname_len) +/* Subroutine */ void xerbla_(char *srname, integer *info, ftnlen srname_len) { /* Format strings */ static char fmt_9999[] = "(\002 ******* XERBLA WAS CALLED WITH INFO =" @@ -5515,7 +5518,7 @@ real sdiff_(real *x, real *y) e_wsfe(); infoc_2.ok = FALSE_; } - return 0; + return; /* End of XERBLA */ diff --git a/blastest/src/cblat3.c b/blastest/src/cblat3.c index a5b870f0f3..549f7828ff 100644 --- a/blastest/src/cblat3.c +++ b/blastest/src/cblat3.c @@ -1,4 +1,7 @@ /* cblat3.f -- translated by f2c (version 20100827). + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + You must link the resulting object file with libf2c: on Microsoft Windows system, link with libf2c.lib; on Linux or Unix systems, link with .../path/to/libf2c.a -lm @@ -5815,7 +5818,7 @@ real sdiff_(real *x, real *y) } /* chkxer_ */ -/* Subroutine */ int xerbla_(char *srname, integer *info, ftnlen srname_len) +/* Subroutine */ void xerbla_(char *srname, integer *info, ftnlen srname_len) { /* Format strings */ static char fmt_9999[] = "(\002 ******* XERBLA WAS CALLED WITH INFO =" @@ -5881,7 +5884,7 @@ real sdiff_(real *x, real *y) e_wsfe(); infoc_2.ok = FALSE_; } - return 0; + return; /* End of XERBLA */ diff --git a/blastest/src/dblat1.c b/blastest/src/dblat1.c index 14665d844f..945cfaacb8 100644 --- a/blastest/src/dblat1.c +++ b/blastest/src/dblat1.c @@ -1034,7 +1034,7 @@ static real c_b81 = 0.f; /* Local variables */ real sd; - extern real s_epsilon_(); + extern real s_epsilon_(real *); /* Fortran I/O blocks */ static cilist io___125 = { 0, 6, 0, fmt_99999, 0 }; diff --git a/blastest/src/dblat2.c b/blastest/src/dblat2.c index 0cdc8f16f3..1f00b0c53d 100644 --- a/blastest/src/dblat2.c +++ b/blastest/src/dblat2.c @@ -1,4 +1,7 @@ /* dblat2.f -- translated by f2c (version 20100827). + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + You must link the resulting object file with libf2c: on Microsoft Windows system, link with libf2c.lib; on Linux or Unix systems, link with .../path/to/libf2c.a -lm @@ -5143,7 +5146,7 @@ doublereal ddiff_(doublereal *x, doublereal *y) } /* chkxer_ */ -/* Subroutine */ int xerbla_(char *srname, integer *info, ftnlen srname_len) +/* Subroutine */ void xerbla_(char *srname, integer *info, ftnlen srname_len) { /* Format strings */ static char fmt_9999[] = "(\002 ******* XERBLA WAS CALLED WITH INFO =" @@ -5207,7 +5210,7 @@ doublereal ddiff_(doublereal *x, doublereal *y) e_wsfe(); infoc_2.ok = FALSE_; } - return 0; + return; /* End of XERBLA */ diff --git a/blastest/src/dblat3.c b/blastest/src/dblat3.c index d7a85e29c1..dfdad1f474 100644 --- a/blastest/src/dblat3.c +++ b/blastest/src/dblat3.c @@ -1,4 +1,7 @@ /* dblat3.f -- translated by f2c (version 20100827). + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + You must link the resulting object file with libf2c: on Microsoft Windows system, link with libf2c.lib; on Linux or Unix systems, link with .../path/to/libf2c.a -lm @@ -4563,7 +4566,7 @@ doublereal ddiff_(doublereal *x, doublereal *y) } /* chkxer_ */ -/* Subroutine */ int xerbla_(char *srname, integer *info, ftnlen srname_len) +/* Subroutine */ void xerbla_(char *srname, integer *info, ftnlen srname_len) { /* Format strings */ static char fmt_9999[] = "(\002 ******* XERBLA WAS CALLED WITH INFO =" @@ -4629,7 +4632,7 @@ doublereal ddiff_(doublereal *x, doublereal *y) e_wsfe(); infoc_2.ok = FALSE_; } - return 0; + return; /* End of XERBLA */ diff --git a/blastest/src/sblat2.c b/blastest/src/sblat2.c index 54d0a010af..6b974a605c 100644 --- a/blastest/src/sblat2.c +++ b/blastest/src/sblat2.c @@ -1,4 +1,7 @@ /* sblat2.f -- translated by f2c (version 20100827). + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + You must link the resulting object file with libf2c: on Microsoft Windows system, link with libf2c.lib; on Linux or Unix systems, link with .../path/to/libf2c.a -lm @@ -5105,7 +5108,7 @@ real sdiff_(real *x, real *y) } /* chkxer_ */ -/* Subroutine */ int xerbla_(char *srname, integer *info, ftnlen srname_len) +/* Subroutine */ void xerbla_(char *srname, integer *info, ftnlen srname_len) { /* Format strings */ static char fmt_9999[] = "(\002 ******* XERBLA WAS CALLED WITH INFO =" @@ -5169,7 +5172,7 @@ real sdiff_(real *x, real *y) e_wsfe(); infoc_2.ok = FALSE_; } - return 0; + return; /* End of XERBLA */ diff --git a/blastest/src/sblat3.c b/blastest/src/sblat3.c index dc5ef5738b..e018df8eb1 100644 --- a/blastest/src/sblat3.c +++ b/blastest/src/sblat3.c @@ -1,4 +1,7 @@ /* sblat3.f -- translated by f2c (version 20100827). + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + You must link the resulting object file with libf2c: on Microsoft Windows system, link with libf2c.lib; on Linux or Unix systems, link with .../path/to/libf2c.a -lm @@ -4538,7 +4541,7 @@ real sdiff_(real *x, real *y) } /* chkxer_ */ -/* Subroutine */ int xerbla_(char *srname, integer *info, ftnlen srname_len) +/* Subroutine */ void xerbla_(char *srname, integer *info, ftnlen srname_len) { /* Format strings */ static char fmt_9999[] = "(\002 ******* XERBLA WAS CALLED WITH INFO =" @@ -4604,7 +4607,7 @@ real sdiff_(real *x, real *y) e_wsfe(); infoc_2.ok = FALSE_; } - return 0; + return; /* End of XERBLA */ diff --git a/blastest/src/zblat2.c b/blastest/src/zblat2.c index 030f03b833..4894addff8 100644 --- a/blastest/src/zblat2.c +++ b/blastest/src/zblat2.c @@ -1,4 +1,7 @@ /* zblat2.f -- translated by f2c (version 20100827). + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + You must link the resulting object file with libf2c: on Microsoft Windows system, link with libf2c.lib; on Linux or Unix systems, link with .../path/to/libf2c.a -lm @@ -5500,7 +5503,7 @@ doublereal ddiff_(doublereal *x, doublereal *y) } /* chkxer_ */ -/* Subroutine */ int xerbla_(char *srname, integer *info, ftnlen srname_len) +/* Subroutine */ void xerbla_(char *srname, integer *info, ftnlen srname_len) { /* Format strings */ static char fmt_9999[] = "(\002 ******* XERBLA WAS CALLED WITH INFO =" @@ -5564,7 +5567,7 @@ doublereal ddiff_(doublereal *x, doublereal *y) e_wsfe(); infoc_2.ok = FALSE_; } - return 0; + return; /* End of XERBLA */ diff --git a/blastest/src/zblat3.c b/blastest/src/zblat3.c index 3ff3634b68..45e37e5851 100644 --- a/blastest/src/zblat3.c +++ b/blastest/src/zblat3.c @@ -1,4 +1,7 @@ /* zblat3.f -- translated by f2c (version 20100827). + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + You must link the resulting object file with libf2c: on Microsoft Windows system, link with libf2c.lib; on Linux or Unix systems, link with .../path/to/libf2c.a -lm @@ -5850,7 +5853,7 @@ doublereal ddiff_(doublereal *x, doublereal *y) } /* chkxer_ */ -/* Subroutine */ int xerbla_(char *srname, integer *info, ftnlen srname_len) +/* Subroutine */ void xerbla_(char *srname, integer *info, ftnlen srname_len) { /* Format strings */ static char fmt_9999[] = "(\002 ******* XERBLA WAS CALLED WITH INFO =" @@ -5916,7 +5919,7 @@ doublereal ddiff_(doublereal *x, doublereal *y) e_wsfe(); infoc_2.ok = FALSE_; } - return 0; + return; /* End of XERBLA */ diff --git a/blis.pc.in b/blis.pc.in index 57dbafec45..b507b314d2 100644 --- a/blis.pc.in +++ b/blis.pc.in @@ -6,6 +6,6 @@ includedir=@includedir@ Name: BLIS Description: BLAS-like Library Instantiation Software Framework Version: @PACKAGE_VERSION@ -Libs: -L${libdir} -lblis +Libs: -L${libdir} -l@AOCLLIB@ Libs.private: @LDFLAGS@ Cflags: -I${includedir}/blis diff --git a/build/cmake/aocl-blas.pc.in b/build/cmake/aocl-blas.pc.in new file mode 100644 index 0000000000..6279740c37 --- /dev/null +++ b/build/cmake/aocl-blas.pc.in @@ -0,0 +1,11 @@ +prefix=@CMAKE_INSTALL_PREFIX@ +exec_prefix=${prefix} +libdir=@LIB_DIR@ +includedir=@INCLUDE_DIR@ + +Name: AOCL-BLAS +Description: BLAS-like Library Instantiation Software Framework +Version: @VERSION_STRING@ +Libs: -L${libdir} -l@LIBBLIS@ +Libs.private: @LDFLAGS_STRING@ +Cflags: -I${includedir} diff --git a/build/cmake/bli_addon.h.in b/build/cmake/bli_addon.h.in index b002b43619..cd21e85e36 100644 --- a/build/cmake/bli_addon.h.in +++ b/build/cmake/bli_addon.h.in @@ -1,6 +1,37 @@ /* - * Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. - */ + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ #ifndef BLIS_ADDON_H #define BLIS_ADDON_H diff --git a/build/cmake/bli_config.h.in b/build/cmake/bli_config.h.in index aed543b868..0cacfef83e 100644 --- a/build/cmake/bli_config.h.in +++ b/build/cmake/bli_config.h.in @@ -1,183 +1,213 @@ -/* - * Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. - */ - -#ifndef BLIS_CONFIG_H -#define BLIS_CONFIG_H - -// Enabled configuration "family" (config_name) -${CONFIG_NAME_DEFINE} - -// Enabled sub-configurations (config_list) -${CONFIG_LIST_DEFINES} - -// Enabled kernel sets (kernel_list) -${KERNEL_LIST_DEFINES} - -//This macro is enabled only for ZEN family configurations. -//This enables us to use different cache-blocking sizes for TRSM instead of common level-3 cache-block sizes. -#if ${ENABLE_AOCL_ZEN_01} -#define AOCL_BLIS_ZEN -#endif - -#if ${ENABLE_AOCL_DYNAMIC_01} -#define AOCL_DYNAMIC -#endif - -#if ${ENABLE_SYSTEM_01} -#define BLIS_ENABLE_SYSTEM -#else -#define BLIS_DISABLE_SYSTEM -#endif - -#if ${ENABLE_OPENMP_01} -#define BLIS_ENABLE_OPENMP -#endif - -#if ${ENABLE_PTHREADS_01} -#define BLIS_ENABLE_PTHREADS -#endif - -#if ${ENABLE_JRIR_SLAB_01} -#define BLIS_ENABLE_JRIR_SLAB -#endif - -#if ${ENABLE_JRIR_RR_01} -#define BLIS_ENABLE_JRIR_RR -#endif - -#if ${ENABLE_PBA_POOLS_01} -#define BLIS_ENABLE_PBA_POOLS -#else -#define BLIS_DISABLE_PBA_POOLS -#endif - -#if ${ENABLE_SBA_POOLS_01} -#define BLIS_ENABLE_SBA_POOLS -#else -#define BLIS_DISABLE_SBA_POOLS -#endif - -#if ${ENABLE_MEM_TRACING_01} -#define BLIS_ENABLE_MEM_TRACING -#else -#define BLIS_DISABLE_MEM_TRACING -#endif - -#if ${INT_TYPE_SIZE} == 64 -#define BLIS_INT_TYPE_SIZE 64 -#elif ${INT_TYPE_SIZE} == 32 -#define BLIS_INT_TYPE_SIZE 32 -#else -// determine automatically -#endif - -#if ${BLAS_INT_TYPE_SIZE} == 64 -#define BLIS_BLAS_INT_TYPE_SIZE 64 -#elif ${BLAS_INT_TYPE_SIZE} == 32 -#define BLIS_BLAS_INT_TYPE_SIZE 32 -#else -// determine automatically -#endif - -#ifndef BLIS_ENABLE_BLAS -#ifndef BLIS_DISABLE_BLAS -#if ${ENABLE_BLAS_01} -#define BLIS_ENABLE_BLAS -#else -#define BLIS_DISABLE_BLAS -#endif -#endif -#endif - -#ifndef BLIS_ENABLE_CBLAS -#ifndef BLIS_DISABLE_CBLAS -#if ${ENABLE_CBLAS_01} -#define BLIS_ENABLE_CBLAS -#else -#define BLIS_DISABLE_CBLAS -#endif -#endif -#endif - -// If the CBLAS compatibility layer was enabled while the BLAS layer -// was not enabled, we must enable the BLAS layer here. Also undefine -// BLIS_DISABLE_BLAS to ensure consistency. -#ifdef BLIS_ENABLE_CBLAS -#ifndef BLIS_ENABLE_BLAS -#define BLIS_ENABLE_BLAS -#endif -#undef BLIS_DISABLE_BLAS -#endif // BLIS_ENABLE_CBLAS - -#ifndef BLIS_ENABLE_MIXED_DT -#ifndef BLIS_DISABLE_MIXED_DT -#if ${ENABLE_MIXED_DT_01} -#define BLIS_ENABLE_MIXED_DT -#else -#define BLIS_DISABLE_MIXED_DT -#endif -#endif -#endif - -#ifndef BLIS_ENABLE_MIXED_DT_EXTRA_MEM -#ifndef BLIS_DISABLE_MIXED_DT_EXTRA_MEM -#if ${ENABLE_MIXED_DT_EXTRA_MEM_01} -#define BLIS_ENABLE_MIXED_DT_EXTRA_MEM -#else -#define BLIS_DISABLE_MIXED_DT_EXTRA_MEM -#endif -#endif -#endif - -#if ${ENABLE_SUP_HANDLING_01} -#define BLIS_ENABLE_SUP_HANDLING -#else -#define BLIS_DISABLE_SUP_HANDLING -#endif - -#if ${ENABLE_MEMKIND_01} -#define BLIS_ENABLE_MEMKIND -#else -#define BLIS_DISABLE_MEMKIND -#endif - -#if ${ENABLE_TRSM_PREINVERSION_01} -#define BLIS_ENABLE_TRSM_PREINVERSION -#else -#define BLIS_DISABLE_TRSM_PREINVERSION -#endif - -#if ${ENABLE_PRAGMA_OMP_SIMD_01} -#define BLIS_ENABLE_PRAGMA_OMP_SIMD -#else -#define BLIS_DISABLE_PRAGMA_OMP_SIMD -#endif - -#if ${ENABLE_SANDBOX_01} -#define BLIS_ENABLE_SANDBOX -#else -#define BLIS_DISABLE_SANDBOX -#endif - -#if ${ENABLE_SHARED_01} -#define BLIS_ENABLE_SHARED -#else -#define BLIS_DISABLE_SHARED -#endif - -#if ${COMPLEX_RETURN_INTEL_01} -#define BLIS_ENABLE_COMPLEX_RETURN_INTEL -#else -#define BLIS_DISABLE_COMPLEX_RETURN_INTEL -#endif - -#if ${DISABLE_BLIS_ARCH_TYPE_01} -#define DISABLE_BLIS_ARCH_TYPE -#define DISABLE_BLIS_MODEL_TYPE -#endif - -#define __blis_arch_type_name "${RENAME_BLIS_ARCH_TYPE}" -#define __blis_model_type_name "${RENAME_BLIS_MODEL_TYPE}" - -#endif +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_CONFIG_H +#define BLIS_CONFIG_H + +// Enabled configuration "family" (config_name) +${CONFIG_NAME_DEFINE} + +// Enabled sub-configurations (config_list) +${CONFIG_LIST_DEFINES} + +// Enabled kernel sets (kernel_list) +${KERNEL_LIST_DEFINES} + +//This macro is enabled only for ZEN family configurations. +//This enables us to use different cache-blocking sizes for TRSM instead of common level-3 cache-block sizes. +#if ${ENABLE_AOCL_ZEN_01} +#define AOCL_BLIS_ZEN +#endif + +#if ${ENABLE_AOCL_DYNAMIC_01} +#define AOCL_DYNAMIC +#endif + +#if ${ENABLE_SYSTEM_01} +#define BLIS_ENABLE_SYSTEM +#else +#define BLIS_DISABLE_SYSTEM +#endif + +#if ${ENABLE_OPENMP_01} +#define BLIS_ENABLE_OPENMP +#endif + +#if ${ENABLE_PTHREADS_01} +#define BLIS_ENABLE_PTHREADS +#endif + +#if ${ENABLE_JRIR_SLAB_01} +#define BLIS_ENABLE_JRIR_SLAB +#endif + +#if ${ENABLE_JRIR_RR_01} +#define BLIS_ENABLE_JRIR_RR +#endif + +#if ${ENABLE_PBA_POOLS_01} +#define BLIS_ENABLE_PBA_POOLS +#else +#define BLIS_DISABLE_PBA_POOLS +#endif + +#if ${ENABLE_SBA_POOLS_01} +#define BLIS_ENABLE_SBA_POOLS +#else +#define BLIS_DISABLE_SBA_POOLS +#endif + +#if ${ENABLE_MEM_TRACING_01} +#define BLIS_ENABLE_MEM_TRACING +#else +#define BLIS_DISABLE_MEM_TRACING +#endif + +#if ${INT_TYPE_SIZE} == 64 +#define BLIS_INT_TYPE_SIZE 64 +#elif ${INT_TYPE_SIZE} == 32 +#define BLIS_INT_TYPE_SIZE 32 +#else +// determine automatically +#endif + +#if ${BLAS_INT_TYPE_SIZE} == 64 +#define BLIS_BLAS_INT_TYPE_SIZE 64 +#elif ${BLAS_INT_TYPE_SIZE} == 32 +#define BLIS_BLAS_INT_TYPE_SIZE 32 +#else +// determine automatically +#endif + +#ifndef BLIS_ENABLE_BLAS +#ifndef BLIS_DISABLE_BLAS +#if ${ENABLE_BLAS_01} +#define BLIS_ENABLE_BLAS +#else +#define BLIS_DISABLE_BLAS +#endif +#endif +#endif + +#ifndef BLIS_ENABLE_CBLAS +#ifndef BLIS_DISABLE_CBLAS +#if ${ENABLE_CBLAS_01} +#define BLIS_ENABLE_CBLAS +#else +#define BLIS_DISABLE_CBLAS +#endif +#endif +#endif + +// If the CBLAS compatibility layer was enabled while the BLAS layer +// was not enabled, we must enable the BLAS layer here. Also undefine +// BLIS_DISABLE_BLAS to ensure consistency. +#ifdef BLIS_ENABLE_CBLAS +#ifndef BLIS_ENABLE_BLAS +#define BLIS_ENABLE_BLAS +#endif +#undef BLIS_DISABLE_BLAS +#endif // BLIS_ENABLE_CBLAS + +#ifndef BLIS_ENABLE_MIXED_DT +#ifndef BLIS_DISABLE_MIXED_DT +#if ${ENABLE_MIXED_DT_01} +#define BLIS_ENABLE_MIXED_DT +#else +#define BLIS_DISABLE_MIXED_DT +#endif +#endif +#endif + +#ifndef BLIS_ENABLE_MIXED_DT_EXTRA_MEM +#ifndef BLIS_DISABLE_MIXED_DT_EXTRA_MEM +#if ${ENABLE_MIXED_DT_EXTRA_MEM_01} +#define BLIS_ENABLE_MIXED_DT_EXTRA_MEM +#else +#define BLIS_DISABLE_MIXED_DT_EXTRA_MEM +#endif +#endif +#endif + +#if ${ENABLE_SUP_HANDLING_01} +#define BLIS_ENABLE_SUP_HANDLING +#else +#define BLIS_DISABLE_SUP_HANDLING +#endif + +#if ${ENABLE_MEMKIND_01} +#define BLIS_ENABLE_MEMKIND +#else +#define BLIS_DISABLE_MEMKIND +#endif + +#if ${ENABLE_TRSM_PREINVERSION_01} +#define BLIS_ENABLE_TRSM_PREINVERSION +#else +#define BLIS_DISABLE_TRSM_PREINVERSION +#endif + +#if ${ENABLE_PRAGMA_OMP_SIMD_01} +#define BLIS_ENABLE_PRAGMA_OMP_SIMD +#else +#define BLIS_DISABLE_PRAGMA_OMP_SIMD +#endif + +#if ${ENABLE_SANDBOX_01} +#define BLIS_ENABLE_SANDBOX +#else +#define BLIS_DISABLE_SANDBOX +#endif + +#if ${ENABLE_SHARED_01} +#define BLIS_ENABLE_SHARED +#else +#define BLIS_DISABLE_SHARED +#endif + +#if ${COMPLEX_RETURN_INTEL_01} +#define BLIS_ENABLE_COMPLEX_RETURN_INTEL +#else +#define BLIS_DISABLE_COMPLEX_RETURN_INTEL +#endif + +#if ${DISABLE_BLIS_ARCH_TYPE_01} +#define DISABLE_BLIS_ARCH_TYPE +#define DISABLE_BLIS_MODEL_TYPE +#endif + +#define __blis_arch_type_name "${RENAME_BLIS_ARCH_TYPE}" +#define __blis_model_type_name "${RENAME_BLIS_MODEL_TYPE}" + +#endif diff --git a/build/cmake/check-blastest.py b/build/cmake/check-blastest.py index 8e1123cf80..e57f7764e4 100644 --- a/build/cmake/check-blastest.py +++ b/build/cmake/check-blastest.py @@ -22,10 +22,13 @@ def check_blastest(): if has_failure: print("\033[0;31m At least one BLAS test failed. :( \033[0m") print("\033[0;31m Please see the corresponding out.* for details. \033[0m") + exit(1) elif is_empty: print("\033[0;31m At least one BLAS test resulted without a PASS. :( \033[0m") print("\033[0;31m Please ensure that the corresponding out.* was generated correctly. \033[0m") + exit(1) else: print("\033[0;32m All BLAS tests passed! \033[0m") + exit(0) check_blastest() diff --git a/build/cmake/check-blistest.py b/build/cmake/check-blistest.py index 983f8e8241..e2679771bc 100644 --- a/build/cmake/check-blistest.py +++ b/build/cmake/check-blistest.py @@ -13,10 +13,12 @@ def check_blistest(): if "FAILURE" in content: print("\033[0;31m At least one BLIS test failed. :( \033[0m") print("\033[0;31m Please see the corresponding output.testsuite* for details. \033[0m") + exit(1) elif not "PASS" in content: print("\033[0;31m No BLIS test resulted in PASS. :( \033[0m") print("\033[0;31m Please ensure that the corresponding output.testsuite* was generated correctly. \033[0m") + exit(1) else: print("\033[0;32m All BLIS tests passed! \033[0m") - + exit(0) check_blistest() diff --git a/build/cmake/config_print.py b/build/cmake/config_print.py index f5fc767711..8252c49631 100644 --- a/build/cmake/config_print.py +++ b/build/cmake/config_print.py @@ -1,4 +1,4 @@ -##Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved.## # Import modules import os @@ -21,7 +21,7 @@ def main(): print(" ") print(" confname The name of the sub-directory inside of the 'config'") print(" directory containing the desired BLIS configuration.") - print(" Currently, only amdzen, zen, zen2, zen3, zen4 and generic") + print(" Currently, only amdzen, zen, zen2, zen3, zen4, zen5 and generic") print(" configuration options are supported.") print(" Note that confname MUST be specified; if it is not,") print(" configure will complain. To build a completely generic") @@ -43,17 +43,16 @@ def main(): print( " kept in the framework, otherwise optimization is" ) print( " turned off. Available options are 'opt', 'noopt' and 'off'." ) print( " " ) - print( " --disable-static, --enable-static" ) + print( " -DBUILD_SHARED_LIBS=ON or -DBUILD_SHARED_LIBS=OFF" ) print( " " ) - print( " Disable (enabled by default) building BLIS as a static" ) - print( " library. If the static library build is disabled, the" ) - print( " shared library build must remain enabled." ) + print( " Enable building the shared BLIS library (default)." ) + print( " If the shared library build is disabled, the static library" ) + print( " is built." ) print( " " ) - print( " --disable-shared, --enable-shared" ) + print( " -DBUILD_STATIC_LIBS=ON or -DBUILD_STATIC_LIBS=OFF" ) print( " " ) - print( " Disable (enabled by default) building BLIS as a shared" ) - print( " library. If the shared library build is disabled, the" ) - print( " static library build must remain enabled." ) + print( " Enable building the static BLIS library (default) (Linux only)." ) + print( " On Linux, we can have builds for both shared and static libraries." ) print( " " ) print( " -DEXPORT_SHARED=[SYMBOLS]" ) print( " " ) @@ -285,6 +284,17 @@ def main(): print( " " ) print( " Export APIs with uppercase" ) print( " " ) + print( " -DENABLE_COVERAGE=ON or -DENABLE_COVERAGE=OFF" ) + print( " " ) + print( " Enable (disabled by default) generation of code coverage" ) + print( " report in html format. Code coverage support is provided" ) + print( " only on LINUX with GCC compiler." ) + print( " " ) + print( " -DENABLE_ASAN=ON or -DENABLE_ASAN=OFF" ) + print( " " ) + print( " Enable (disabled by default) Address Sanitizer to find " ) + print( " memory access error. Address Sanitizer support is provided" ) + print( " only on LINUX with Clang compiler" ) print( " " ) print( " Additional CMake Variables:" ) print( " " ) diff --git a/build/cmake/presets/base.json b/build/cmake/presets/base.json new file mode 100644 index 0000000000..4225b0b6dd --- /dev/null +++ b/build/cmake/presets/base.json @@ -0,0 +1,103 @@ +{ + "version": 6, + "configurePresets": [ + { + "name": "lp64", + "hidden": true, + "cacheVariables": { + "INT_SIZE": "32", + "BLAS_INT_SIZE": "32" + } + }, + { + "name": "ilp64", + "hidden": true, + "cacheVariables": { + "INT_SIZE": "64", + "BLAS_INT_SIZE": "64" + } + }, + { + "name": "st", + "hidden": true, + "cacheVariables": { + "ENABLE_THREADING": "no" + } + }, + { + "name": "mt", + "hidden": true, + "cacheVariables": { + "ENABLE_THREADING": "openmp" + } + }, + { + "name": "amdzen", + "hidden": true, + "cacheVariables": { + "BLIS_CONFIG_FAMILY": "amdzen" + } + }, + { + "name": "auto", + "hidden": true, + "cacheVariables": { + "BLIS_CONFIG_FAMILY": "auto" + } + }, + { + "name": "zen5", + "hidden": true, + "cacheVariables": { + "BLIS_CONFIG_FAMILY": "zen5" + } + }, + { + "name": "static", + "hidden": true, + "cacheVariables": { + "BUILD_SHARED_LIBS": "OFF" + } + }, + { + "name": "shared", + "hidden": true, + "cacheVariables": { + "BUILD_SHARED_LIBS": "ON" + } + }, + { + "name": "linux-static", + "description": "Build both static and shared libs on Linux but test with static.", + "hidden": true, + "cacheVariables": { + "TEST_WITH_SHARED": "OFF" + } + }, + { + "name": "linux-shared", + "description": "Build both static and shared libs on Linux but test with shared.", + "hidden": true, + "cacheVariables": { + "TEST_WITH_SHARED": "ON" + } + }, + { + "name": "base", + "hidden": true, + "cacheVariables": { + "ENABLE_CBLAS": "ON" + }, + "binaryDir": "${sourceDir}/build-${presetName}" + } + ], + "buildPresets": [ + { + "name": "base", + "configurePreset": "base", + "targets": "install", + "configuration": "Release", + "jobs": 0 + } + ] +} diff --git a/build/cmake/presets/linux-make-clang.json b/build/cmake/presets/linux-make-clang.json new file mode 100644 index 0000000000..aeec17bf14 --- /dev/null +++ b/build/cmake/presets/linux-make-clang.json @@ -0,0 +1,921 @@ +{ + "version": 6, + "include": [ + "base.json" + ], + "configurePresets": [ + { + "name": "linux-make-clang", + "inherits": "base", + "hidden": true, + "cacheVariables": { + "ENABLE_ADDON": "aocl_gemm", + "COMPLEX_RETURN": "intel", + "CMAKE_C_COMPILER": "clang", + "CMAKE_CXX_COMPILER": "clang++" + }, + "generator": "Unix Makefiles", + "condition": { + "type": "notEquals", + "lhs": "${hostSystemName}", + "rhs": "Windows" + } + }, + { + "name": "linux-make-clang-st-lp64-amdzen-static", + "inherits": ["linux-make-clang", "st", "lp64", "amdzen", "linux-static"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-lp64-amdzen" + }, + "hidden": false + }, + { + "name": "linux-make-clang-st-lp64-amdzen-shared", + "inherits": ["linux-make-clang", "st", "lp64", "amdzen", "linux-shared"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-lp64-amdzen" + }, + "hidden": false + }, + { + "name": "linux-make-clang-mt-lp64-amdzen-static", + "inherits": ["linux-make-clang", "mt", "lp64", "amdzen", "linux-static"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-lp64-amdzen" + }, + "hidden": false + }, + { + "name": "linux-make-clang-mt-lp64-amdzen-shared", + "inherits": ["linux-make-clang", "mt", "lp64", "amdzen", "linux-shared"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-lp64-amdzen" + }, + "hidden": false + }, + { + "name": "linux-make-clang-st-ilp64-amdzen-static", + "inherits": ["linux-make-clang", "st", "ilp64", "amdzen", "linux-static"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-ilp64-amdzen" + }, + "hidden": false + }, + { + "name": "linux-make-clang-st-ilp64-amdzen-shared", + "inherits": ["linux-make-clang", "st", "ilp64", "amdzen", "linux-shared"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-ilp64-amdzen" + }, + "hidden": false + }, + { + "name": "linux-make-clang-mt-ilp64-amdzen-static", + "inherits": ["linux-make-clang", "mt", "ilp64", "amdzen", "linux-static"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-ilp64-amdzen" + }, + "hidden": false + }, + { + "name": "linux-make-clang-mt-ilp64-amdzen-shared", + "inherits": ["linux-make-clang", "mt", "ilp64", "amdzen", "linux-shared"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-ilp64-amdzen" + }, + "hidden": false + }, + + { + "name": "linux-make-clang-st-lp64-auto-static", + "inherits": ["linux-make-clang", "st", "lp64", "auto", "linux-static"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-lp64-auto" + }, + "hidden": false + }, + { + "name": "linux-make-clang-st-lp64-auto-shared", + "inherits": ["linux-make-clang", "st", "lp64", "auto", "linux-shared"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-lp64-auto" + }, + "hidden": false + }, + { + "name": "linux-make-clang-mt-lp64-auto-static", + "inherits": ["linux-make-clang", "mt", "lp64", "auto", "linux-static"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-lp64-auto" + }, + "hidden": false + }, + { + "name": "linux-make-clang-mt-lp64-auto-shared", + "inherits": ["linux-make-clang", "mt", "lp64", "auto", "linux-shared"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-lp64-auto" + }, + "hidden": false + }, + { + "name": "linux-make-clang-st-ilp64-auto-static", + "inherits": ["linux-make-clang", "st", "ilp64", "auto", "linux-static"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-ilp64-auto" + }, + "hidden": false + }, + { + "name": "linux-make-clang-st-ilp64-auto-shared", + "inherits": ["linux-make-clang", "st", "ilp64", "auto", "linux-shared"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-ilp64-auto" + }, + "hidden": false + }, + { + "name": "linux-make-clang-mt-ilp64-auto-static", + "inherits": ["linux-make-clang", "mt", "ilp64", "auto", "linux-static"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-ilp64-auto" + }, + "hidden": false + }, + { + "name": "linux-make-clang-mt-ilp64-auto-shared", + "inherits": ["linux-make-clang", "mt", "ilp64", "auto", "linux-shared"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-ilp64-auto" + }, + "hidden": false + }, + { + "name": "linux-make-clang-st-lp64-zen5-static", + "inherits": ["linux-make-clang", "st", "lp64", "zen5", "linux-static"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-lp64-zen5" + }, + "hidden": false + }, + { + "name": "linux-make-clang-st-lp64-zen5-shared", + "inherits": ["linux-make-clang", "st", "lp64", "zen5", "linux-shared"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-lp64-zen5" + }, + "hidden": false + }, + { + "name": "linux-make-clang-mt-lp64-zen5-static", + "inherits": ["linux-make-clang", "mt", "lp64", "zen5", "linux-static"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-lp64-zen5" + }, + "hidden": false + }, + { + "name": "linux-make-clang-mt-lp64-zen5-shared", + "inherits": ["linux-make-clang", "mt", "lp64", "zen5", "linux-shared"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-lp64-zen5" + }, + "hidden": false + }, + { + "name": "linux-make-clang-st-ilp64-zen5-static", + "inherits": ["linux-make-clang", "st", "ilp64", "zen5", "linux-static"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-ilp64-zen5" + }, + "hidden": false + }, + { + "name": "linux-make-clang-st-ilp64-zen5-shared", + "inherits": ["linux-make-clang", "st", "ilp64", "zen5", "linux-shared"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-ilp64-zen5" + }, + "hidden": false + }, + { + "name": "linux-make-clang-mt-ilp64-zen5-static", + "inherits": ["linux-make-clang", "mt", "ilp64", "zen5", "linux-static"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-ilp64-zen5" + }, + "hidden": false + }, + { + "name": "linux-make-clang-mt-ilp64-zen5-shared", + "inherits": ["linux-make-clang", "mt", "ilp64", "zen5", "linux-shared"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-ilp64-zen5" + }, + "hidden": false + } + ], + "buildPresets": [ + { + "name": "linux-make-clang-st-lp64-amdzen-static", + "configurePreset": "linux-make-clang-st-lp64-amdzen-static", + "inherits": "base" + }, + { + "name": "linux-make-clang-st-lp64-amdzen-shared", + "configurePreset": "linux-make-clang-st-lp64-amdzen-shared", + "inherits": "base" + }, + { + "name": "linux-make-clang-mt-lp64-amdzen-static", + "configurePreset": "linux-make-clang-mt-lp64-amdzen-static", + "inherits": "base" + }, + { + "name": "linux-make-clang-mt-lp64-amdzen-shared", + "configurePreset": "linux-make-clang-mt-lp64-amdzen-shared", + "inherits": "base" + }, + { + "name": "linux-make-clang-st-ilp64-amdzen-static", + "configurePreset": "linux-make-clang-st-ilp64-amdzen-static", + "inherits": "base" + }, + { + "name": "linux-make-clang-st-ilp64-amdzen-shared", + "configurePreset": "linux-make-clang-st-ilp64-amdzen-shared", + "inherits": "base" + }, + { + "name": "linux-make-clang-mt-ilp64-amdzen-static", + "configurePreset": "linux-make-clang-mt-ilp64-amdzen-static", + "inherits": "base" + }, + { + "name": "linux-make-clang-mt-ilp64-amdzen-shared", + "configurePreset": "linux-make-clang-mt-ilp64-amdzen-shared", + "inherits": "base" + }, + { + "name": "linux-make-clang-st-lp64-auto-static", + "configurePreset": "linux-make-clang-st-lp64-auto-static", + "inherits": "base" + }, + { + "name": "linux-make-clang-st-lp64-auto-shared", + "configurePreset": "linux-make-clang-st-lp64-auto-shared", + "inherits": "base" + }, + { + "name": "linux-make-clang-mt-lp64-auto-static", + "configurePreset": "linux-make-clang-mt-lp64-auto-static", + "inherits": "base" + }, + { + "name": "linux-make-clang-mt-lp64-auto-shared", + "configurePreset": "linux-make-clang-mt-lp64-auto-shared", + "inherits": "base" + }, + { + "name": "linux-make-clang-st-ilp64-auto-static", + "configurePreset": "linux-make-clang-st-ilp64-auto-static", + "inherits": "base" + }, + { + "name": "linux-make-clang-st-ilp64-auto-shared", + "configurePreset": "linux-make-clang-st-ilp64-auto-shared", + "inherits": "base" + }, + { + "name": "linux-make-clang-mt-ilp64-auto-static", + "configurePreset": "linux-make-clang-mt-ilp64-auto-static", + "inherits": "base" + }, + { + "name": "linux-make-clang-mt-ilp64-auto-shared", + "configurePreset": "linux-make-clang-mt-ilp64-auto-shared", + "inherits": "base" + }, + + { + "name": "linux-make-clang-st-lp64-zen5-static", + "configurePreset": "linux-make-clang-st-lp64-zen5-static", + "inherits": "base" + }, + { + "name": "linux-make-clang-st-lp64-zen5-shared", + "configurePreset": "linux-make-clang-st-lp64-zen5-shared", + "inherits": "base" + }, + { + "name": "linux-make-clang-mt-lp64-zen5-static", + "configurePreset": "linux-make-clang-mt-lp64-zen5-static", + "inherits": "base" + }, + { + "name": "linux-make-clang-mt-lp64-zen5-shared", + "configurePreset": "linux-make-clang-mt-lp64-zen5-shared", + "inherits": "base" + }, + { + "name": "linux-make-clang-st-ilp64-zen5-static", + "configurePreset": "linux-make-clang-st-ilp64-zen5-static", + "inherits": "base" + }, + { + "name": "linux-make-clang-st-ilp64-zen5-shared", + "configurePreset": "linux-make-clang-st-ilp64-zen5-shared", + "inherits": "base" + }, + { + "name": "linux-make-clang-mt-ilp64-zen5-static", + "configurePreset": "linux-make-clang-mt-ilp64-zen5-static", + "inherits": "base" + }, + { + "name": "linux-make-clang-mt-ilp64-zen5-shared", + "configurePreset": "linux-make-clang-mt-ilp64-zen5-shared", + "inherits": "base" + }, + + { + "name": "linux-make-clang-st-lp64-amdzen-static-check", + "description": "Check static single-threaded LP64 BLIS with amdzen option on Linux", + "configurePreset": "linux-make-clang-st-lp64-amdzen-static", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-clang-st-lp64-amdzen-shared-check", + "description": "Check shared single-threaded LP64 BLIS with amdzen option on Linux", + "configurePreset": "linux-make-clang-st-lp64-amdzen-shared", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-clang-mt-lp64-amdzen-static-check", + "description": "Check multithreaded static LP64 BLIS with amdzen option on Linux", + "configurePreset": "linux-make-clang-mt-lp64-amdzen-static", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-clang-mt-lp64-amdzen-shared-check", + "description": "Check multithreaded shared LP64 BLIS with amdzen option on Linux", + "configurePreset": "linux-make-clang-mt-lp64-amdzen-shared", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-clang-st-ilp64-amdzen-static-check", + "description": "Check single-threaded static ILP64 BLIS with amdzen option on Linux", + "configurePreset": "linux-make-clang-st-ilp64-amdzen-static", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-clang-st-ilp64-amdzen-shared-check", + "description": "Check single-threaded shared ILP64 BLIS with amdzen option on Linux", + "configurePreset": "linux-make-clang-st-ilp64-amdzen-shared", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-clang-mt-ilp64-amdzen-static-check", + "description": "Check multithreaded static ILP64 BLIS with amdzen option on Linux", + "configurePreset": "linux-make-clang-mt-ilp64-amdzen-static", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-clang-mt-ilp64-amdzen-shared-check", + "description": "Check multithreaded shared ILP64 BLIS with amdzen option on Linux", + "configurePreset": "linux-make-clang-mt-ilp64-amdzen-shared", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-clang-st-lp64-auto-static-check", + "description": "Check static single-threaded LP64 BLIS with auto option on Linux", + "configurePreset": "linux-make-clang-st-lp64-auto-static", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-clang-st-lp64-auto-shared-check", + "description": "Check shared single-threaded LP64 BLIS with auto option on Linux", + "configurePreset": "linux-make-clang-st-lp64-auto-shared", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-clang-mt-lp64-auto-static-check", + "description": "Check multithreaded static LP64 BLIS with auto option on Linux", + "configurePreset": "linux-make-clang-mt-lp64-auto-static", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-clang-mt-lp64-auto-shared-check", + "description": "Check multithreaded shared LP64 BLIS with auto option on Linux", + "configurePreset": "linux-make-clang-mt-lp64-auto-shared", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-clang-st-ilp64-auto-static-check", + "description": "Check single-threaded static ILP64 BLIS with auto option on Linux", + "configurePreset": "linux-make-clang-st-ilp64-auto-static", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-clang-st-ilp64-auto-shared-check", + "description": "Check single-threaded shared ILP64 BLIS with auto option on Linux", + "configurePreset": "linux-make-clang-st-ilp64-auto-shared", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-clang-mt-ilp64-auto-static-check", + "description": "Check multithreaded static ILP64 BLIS with auto option on Linux", + "configurePreset": "linux-make-clang-mt-ilp64-auto-static", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-clang-mt-ilp64-auto-shared-check", + "description": "Check multithreaded shared ILP64 BLIS with auto option on Linux", + "configurePreset": "linux-make-clang-mt-ilp64-auto-shared", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-clang-st-lp64-zen5-static-check", + "description": "Check static single-threaded LP64 BLIS with zen5 option on Linux", + "configurePreset": "linux-make-clang-st-lp64-zen5-static", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-clang-st-lp64-zen5-shared-check", + "description": "Check shared single-threaded LP64 BLIS with zen5 option on Linux", + "configurePreset": "linux-make-clang-st-lp64-zen5-shared", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-clang-mt-lp64-zen5-static-check", + "description": "Check multithreaded static LP64 BLIS with zen5 option on Linux", + "configurePreset": "linux-make-clang-mt-lp64-zen5-static", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-clang-mt-lp64-zen5-shared-check", + "description": "Check multithreaded shared LP64 BLIS with zen5 option on Linux", + "configurePreset": "linux-make-clang-mt-lp64-zen5-shared", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-clang-st-ilp64-zen5-static-check", + "description": "Check single-threaded static ILP64 BLIS with zen5 option on Linux", + "configurePreset": "linux-make-clang-st-ilp64-zen5-static", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-clang-st-ilp64-zen5-shared-check", + "description": "Check single-threaded shared ILP64 BLIS with zen5 option on Linux", + "configurePreset": "linux-make-clang-st-ilp64-zen5-shared", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-clang-mt-ilp64-zen5-static-check", + "description": "Check multithreaded static ILP64 BLIS with zen5 option on Linux", + "configurePreset": "linux-make-clang-mt-ilp64-zen5-static", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-clang-mt-ilp64-zen5-shared-check", + "description": "Check multithreaded shared ILP64 BLIS with zen5 option on Linux", + "configurePreset": "linux-make-clang-mt-ilp64-zen5-shared", + "targets": ["check", "checkblis-salt", "checkblis-md"] + } + ], + "workflowPresets": [ + { + "name": "linux-make-clang-st-lp64-amdzen-static", + "description": "Build and check single-threaded static BLIS for amdzen configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-clang-st-lp64-amdzen-static" + }, + { + "type": "build", + "name": "linux-make-clang-st-lp64-amdzen-static" + }, + { + "type": "build", + "name": "linux-make-clang-st-lp64-amdzen-static-check" + } + ] + }, + { + "name": "linux-make-clang-st-lp64-amdzen-shared", + "description": "Build and check single-threaded shared BLIS for amdzen configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-clang-st-lp64-amdzen-shared" + }, + { + "type": "build", + "name": "linux-make-clang-st-lp64-amdzen-shared" + }, + { + "type": "build", + "name": "linux-make-clang-st-lp64-amdzen-shared-check" + } + ] + }, + { + "name": "linux-make-clang-mt-lp64-amdzen-static", + "description": "Build and check multithreaded static BLIS for amdzen configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-clang-mt-lp64-amdzen-static" + }, + { + "type": "build", + "name": "linux-make-clang-mt-lp64-amdzen-static" + }, + { + "type": "build", + "name": "linux-make-clang-mt-lp64-amdzen-static-check" + } + ] + }, + { + "name": "linux-make-clang-mt-lp64-amdzen-shared", + "description": "Build and check multithreaded shared BLIS for amdzen configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-clang-mt-lp64-amdzen-shared" + }, + { + "type": "build", + "name": "linux-make-clang-mt-lp64-amdzen-shared" + }, + { + "type": "build", + "name": "linux-make-clang-mt-lp64-amdzen-shared-check" + } + ] + }, + { + "name": "linux-make-clang-st-ilp64-amdzen-static", + "description": "Build and check single-threaded static BLIS for amdzen configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-clang-st-ilp64-amdzen-static" + }, + { + "type": "build", + "name": "linux-make-clang-st-ilp64-amdzen-static" + }, + { + "type": "build", + "name": "linux-make-clang-st-ilp64-amdzen-static-check" + } + ] + }, + { + "name": "linux-make-clang-st-ilp64-amdzen-shared", + "description": "Build and check single-threaded shared BLIS for amdzen configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-clang-st-ilp64-amdzen-shared" + }, + { + "type": "build", + "name": "linux-make-clang-st-ilp64-amdzen-shared" + }, + { + "type": "build", + "name": "linux-make-clang-st-ilp64-amdzen-shared-check" + } + ] + }, + { + "name": "linux-make-clang-mt-ilp64-amdzen-static", + "description": "Build and check multithreaded static BLIS for amdzen configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-clang-mt-ilp64-amdzen-static" + }, + { + "type": "build", + "name": "linux-make-clang-mt-ilp64-amdzen-static" + }, + { + "type": "build", + "name": "linux-make-clang-mt-ilp64-amdzen-static-check" + } + ] + }, + { + "name": "linux-make-clang-mt-ilp64-amdzen-shared", + "description": "Build and check multithreaded shared BLIS for amdzen configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-clang-mt-ilp64-amdzen-shared" + }, + { + "type": "build", + "name": "linux-make-clang-mt-ilp64-amdzen-shared" + }, + { + "type": "build", + "name": "linux-make-clang-mt-ilp64-amdzen-shared-check" + } + ] + }, + + { + "name": "linux-make-clang-st-lp64-auto-static", + "description": "Build and check single-threaded static BLIS for auto configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-clang-st-lp64-auto-static" + }, + { + "type": "build", + "name": "linux-make-clang-st-lp64-auto-static" + }, + { + "type": "build", + "name": "linux-make-clang-st-lp64-auto-static-check" + } + ] + }, + { + "name": "linux-make-clang-st-lp64-auto-shared", + "description": "Build and check single-threaded shared BLIS for auto configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-clang-st-lp64-auto-shared" + }, + { + "type": "build", + "name": "linux-make-clang-st-lp64-auto-shared" + }, + { + "type": "build", + "name": "linux-make-clang-st-lp64-auto-shared-check" + } + ] + }, + { + "name": "linux-make-clang-mt-lp64-auto-static", + "description": "Build and check multithreaded static BLIS for auto configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-clang-mt-lp64-auto-static" + }, + { + "type": "build", + "name": "linux-make-clang-mt-lp64-auto-static" + }, + { + "type": "build", + "name": "linux-make-clang-mt-lp64-auto-static-check" + } + ] + }, + { + "name": "linux-make-clang-mt-lp64-auto-shared", + "description": "Build and check multithreaded shared BLIS for auto configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-clang-mt-lp64-auto-shared" + }, + { + "type": "build", + "name": "linux-make-clang-mt-lp64-auto-shared" + }, + { + "type": "build", + "name": "linux-make-clang-mt-lp64-auto-shared-check" + } + ] + }, + { + "name": "linux-make-clang-st-ilp64-auto-static", + "description": "Build and check single-threaded static BLIS for auto configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-clang-st-ilp64-auto-static" + }, + { + "type": "build", + "name": "linux-make-clang-st-ilp64-auto-static" + }, + { + "type": "build", + "name": "linux-make-clang-st-ilp64-auto-static-check" + } + ] + }, + { + "name": "linux-make-clang-st-ilp64-auto-shared", + "description": "Build and check single-threaded shared BLIS for auto configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-clang-st-ilp64-auto-shared" + }, + { + "type": "build", + "name": "linux-make-clang-st-ilp64-auto-shared" + }, + { + "type": "build", + "name": "linux-make-clang-st-ilp64-auto-shared-check" + } + ] + }, + { + "name": "linux-make-clang-mt-ilp64-auto-static", + "description": "Build and check multithreaded static BLIS for auto configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-clang-mt-ilp64-auto-static" + }, + { + "type": "build", + "name": "linux-make-clang-mt-ilp64-auto-static" + }, + { + "type": "build", + "name": "linux-make-clang-mt-ilp64-auto-static-check" + } + ] + }, + { + "name": "linux-make-clang-mt-ilp64-auto-shared", + "description": "Build and check multithreaded shared BLIS for auto configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-clang-mt-ilp64-auto-shared" + }, + { + "type": "build", + "name": "linux-make-clang-mt-ilp64-auto-shared" + }, + { + "type": "build", + "name": "linux-make-clang-mt-ilp64-auto-shared-check" + } + ] + }, + { + "name": "linux-make-clang-st-lp64-zen5-static", + "description": "Build and check single-threaded static BLIS for zen5 configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-clang-st-lp64-zen5-static" + }, + { + "type": "build", + "name": "linux-make-clang-st-lp64-zen5-static" + }, + { + "type": "build", + "name": "linux-make-clang-st-lp64-zen5-static-check" + } + ] + }, + { + "name": "linux-make-clang-st-lp64-zen5-shared", + "description": "Build and check single-threaded shared BLIS for zen5 configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-clang-st-lp64-zen5-shared" + }, + { + "type": "build", + "name": "linux-make-clang-st-lp64-zen5-shared" + }, + { + "type": "build", + "name": "linux-make-clang-st-lp64-zen5-shared-check" + } + ] + }, + { + "name": "linux-make-clang-mt-lp64-zen5-static", + "description": "Build and check multithreaded static BLIS for zen5 configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-clang-mt-lp64-zen5-static" + }, + { + "type": "build", + "name": "linux-make-clang-mt-lp64-zen5-static" + }, + { + "type": "build", + "name": "linux-make-clang-mt-lp64-zen5-static-check" + } + ] + }, + { + "name": "linux-make-clang-mt-lp64-zen5-shared", + "description": "Build and check multithreaded shared BLIS for zen5 configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-clang-mt-lp64-zen5-shared" + }, + { + "type": "build", + "name": "linux-make-clang-mt-lp64-zen5-shared" + }, + { + "type": "build", + "name": "linux-make-clang-mt-lp64-zen5-shared-check" + } + ] + }, + { + "name": "linux-make-clang-st-ilp64-zen5-static", + "description": "Build and check single-threaded static BLIS for zen5 configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-clang-st-ilp64-zen5-static" + }, + { + "type": "build", + "name": "linux-make-clang-st-ilp64-zen5-static" + }, + { + "type": "build", + "name": "linux-make-clang-st-ilp64-zen5-static-check" + } + ] + }, + { + "name": "linux-make-clang-st-ilp64-zen5-shared", + "description": "Build and check single-threaded shared BLIS for zen5 configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-clang-st-ilp64-zen5-shared" + }, + { + "type": "build", + "name": "linux-make-clang-st-ilp64-zen5-shared" + }, + { + "type": "build", + "name": "linux-make-clang-st-ilp64-zen5-shared-check" + } + ] + }, + { + "name": "linux-make-clang-mt-ilp64-zen5-static", + "description": "Build and check multithreaded static BLIS for zen5 configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-clang-mt-ilp64-zen5-static" + }, + { + "type": "build", + "name": "linux-make-clang-mt-ilp64-zen5-static" + }, + { + "type": "build", + "name": "linux-make-clang-mt-ilp64-zen5-static-check" + } + ] + }, + { + "name": "linux-make-clang-mt-ilp64-zen5-shared", + "description": "Build and check multithreaded shared BLIS for zen5 configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-clang-mt-ilp64-zen5-shared" + }, + { + "type": "build", + "name": "linux-make-clang-mt-ilp64-zen5-shared" + }, + { + "type": "build", + "name": "linux-make-clang-mt-ilp64-zen5-shared-check" + } + ] + } + ] +} diff --git a/build/cmake/presets/linux-make-gcc.json b/build/cmake/presets/linux-make-gcc.json new file mode 100644 index 0000000000..99a4664471 --- /dev/null +++ b/build/cmake/presets/linux-make-gcc.json @@ -0,0 +1,921 @@ +{ + "version": 6, + "include": [ + "base.json" + ], + "configurePresets": [ + { + "name": "linux-make-gcc", + "inherits": "base", + "hidden": true, + "cacheVariables": { + "ENABLE_ADDON": "aocl_gemm", + "COMPLEX_RETURN": "gnu", + "CMAKE_C_COMPILER": "gcc", + "CMAKE_CXX_COMPILER": "g++" + }, + "generator": "Unix Makefiles", + "condition": { + "type": "notEquals", + "lhs": "${hostSystemName}", + "rhs": "Windows" + } + }, + { + "name": "linux-make-gcc-st-lp64-amdzen-static", + "inherits": ["linux-make-gcc", "st", "lp64", "amdzen", "linux-static"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-lp64-amdzen" + }, + "hidden": false + }, + { + "name": "linux-make-gcc-st-lp64-amdzen-shared", + "inherits": ["linux-make-gcc", "st", "lp64", "amdzen", "linux-shared"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-lp64-amdzen" + }, + "hidden": false + }, + { + "name": "linux-make-gcc-mt-lp64-amdzen-static", + "inherits": ["linux-make-gcc", "mt", "lp64", "amdzen", "linux-static"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-lp64-amdzen" + }, + "hidden": false + }, + { + "name": "linux-make-gcc-mt-lp64-amdzen-shared", + "inherits": ["linux-make-gcc", "mt", "lp64", "amdzen", "linux-shared"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-lp64-amdzen" + }, + "hidden": false + }, + { + "name": "linux-make-gcc-st-ilp64-amdzen-static", + "inherits": ["linux-make-gcc", "st", "ilp64", "amdzen", "linux-static"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-ilp64-amdzen" + }, + "hidden": false + }, + { + "name": "linux-make-gcc-st-ilp64-amdzen-shared", + "inherits": ["linux-make-gcc", "st", "ilp64", "amdzen", "linux-shared"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-ilp64-amdzen" + }, + "hidden": false + }, + { + "name": "linux-make-gcc-mt-ilp64-amdzen-static", + "inherits": ["linux-make-gcc", "mt", "ilp64", "amdzen", "linux-static"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-ilp64-amdzen" + }, + "hidden": false + }, + { + "name": "linux-make-gcc-mt-ilp64-amdzen-shared", + "inherits": ["linux-make-gcc", "mt", "ilp64", "amdzen", "linux-shared"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-ilp64-amdzen" + }, + "hidden": false + }, + { + "name": "linux-make-gcc-st-lp64-zen5-static", + "inherits": ["linux-make-gcc", "st", "lp64", "zen5", "linux-static"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-lp64-zen5" + }, + "hidden": false + }, + { + "name": "linux-make-gcc-st-lp64-zen5-shared", + "inherits": ["linux-make-gcc", "st", "lp64", "zen5", "linux-shared"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-lp64-zen5" + }, + "hidden": false + }, + { + "name": "linux-make-gcc-mt-lp64-zen5-static", + "inherits": ["linux-make-gcc", "mt", "lp64", "zen5", "linux-static"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-lp64-zen5" + }, + "hidden": false + }, + { + "name": "linux-make-gcc-mt-lp64-zen5-shared", + "inherits": ["linux-make-gcc", "mt", "lp64", "zen5", "linux-shared"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-lp64-zen5" + }, + "hidden": false + }, + { + "name": "linux-make-gcc-st-ilp64-zen5-static", + "inherits": ["linux-make-gcc", "st", "ilp64", "zen5", "linux-static"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-ilp64-zen5" + }, + "hidden": false + }, + { + "name": "linux-make-gcc-st-ilp64-zen5-shared", + "inherits": ["linux-make-gcc", "st", "ilp64", "zen5", "linux-shared"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-ilp64-zen5" + }, + "hidden": false + }, + { + "name": "linux-make-gcc-mt-ilp64-zen5-static", + "inherits": ["linux-make-gcc", "mt", "ilp64", "zen5", "linux-static"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-ilp64-zen5" + }, + "hidden": false + }, + { + "name": "linux-make-gcc-mt-ilp64-zen5-shared", + "inherits": ["linux-make-gcc", "mt", "ilp64", "zen5", "linux-shared"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-ilp64-zen5" + }, + "hidden": false + }, + + { + "name": "linux-make-gcc-st-lp64-auto-static", + "inherits": ["linux-make-gcc", "st", "lp64", "auto", "linux-static"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-lp64-auto" + }, + "hidden": false + }, + { + "name": "linux-make-gcc-st-lp64-auto-shared", + "inherits": ["linux-make-gcc", "st", "lp64", "auto", "linux-shared"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-lp64-auto" + }, + "hidden": false + }, + { + "name": "linux-make-gcc-mt-lp64-auto-static", + "inherits": ["linux-make-gcc", "mt", "lp64", "auto", "linux-static"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-lp64-auto" + }, + "hidden": false + }, + { + "name": "linux-make-gcc-mt-lp64-auto-shared", + "inherits": ["linux-make-gcc", "mt", "lp64", "auto", "linux-shared"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-lp64-auto" + }, + "hidden": false + }, + { + "name": "linux-make-gcc-st-ilp64-auto-static", + "inherits": ["linux-make-gcc", "st", "ilp64", "auto", "linux-static"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-ilp64-auto" + }, + "hidden": false + }, + { + "name": "linux-make-gcc-st-ilp64-auto-shared", + "inherits": ["linux-make-gcc", "st", "ilp64", "auto", "linux-shared"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-ilp64-auto" + }, + "hidden": false + }, + { + "name": "linux-make-gcc-mt-ilp64-auto-static", + "inherits": ["linux-make-gcc", "mt", "ilp64", "auto", "linux-static"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-ilp64-auto" + }, + "hidden": false + }, + { + "name": "linux-make-gcc-mt-ilp64-auto-shared", + "inherits": ["linux-make-gcc", "mt", "ilp64", "auto", "linux-shared"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-ilp64-auto" + }, + "hidden": false + } + ], + "buildPresets": [ + { + "name": "linux-make-gcc-st-lp64-amdzen-static", + "configurePreset": "linux-make-gcc-st-lp64-amdzen-static", + "inherits": "base" + }, + { + "name": "linux-make-gcc-st-lp64-amdzen-shared", + "configurePreset": "linux-make-gcc-st-lp64-amdzen-shared", + "inherits": "base" + }, + { + "name": "linux-make-gcc-mt-lp64-amdzen-static", + "configurePreset": "linux-make-gcc-mt-lp64-amdzen-static", + "inherits": "base" + }, + { + "name": "linux-make-gcc-mt-lp64-amdzen-shared", + "configurePreset": "linux-make-gcc-mt-lp64-amdzen-shared", + "inherits": "base" + }, + { + "name": "linux-make-gcc-st-ilp64-amdzen-static", + "configurePreset": "linux-make-gcc-st-ilp64-amdzen-static", + "inherits": "base" + }, + { + "name": "linux-make-gcc-st-ilp64-amdzen-shared", + "configurePreset": "linux-make-gcc-st-ilp64-amdzen-shared", + "inherits": "base" + }, + { + "name": "linux-make-gcc-mt-ilp64-amdzen-static", + "configurePreset": "linux-make-gcc-mt-ilp64-amdzen-static", + "inherits": "base" + }, + { + "name": "linux-make-gcc-mt-ilp64-amdzen-shared", + "configurePreset": "linux-make-gcc-mt-ilp64-amdzen-shared", + "inherits": "base" + }, + { + "name": "linux-make-gcc-st-lp64-zen5-static", + "configurePreset": "linux-make-gcc-st-lp64-zen5-static", + "inherits": "base" + }, + { + "name": "linux-make-gcc-st-lp64-zen5-shared", + "configurePreset": "linux-make-gcc-st-lp64-zen5-shared", + "inherits": "base" + }, + { + "name": "linux-make-gcc-mt-lp64-zen5-static", + "configurePreset": "linux-make-gcc-mt-lp64-zen5-static", + "inherits": "base" + }, + { + "name": "linux-make-gcc-mt-lp64-zen5-shared", + "configurePreset": "linux-make-gcc-mt-lp64-zen5-shared", + "inherits": "base" + }, + { + "name": "linux-make-gcc-st-ilp64-zen5-static", + "configurePreset": "linux-make-gcc-st-ilp64-zen5-static", + "inherits": "base" + }, + { + "name": "linux-make-gcc-st-ilp64-zen5-shared", + "configurePreset": "linux-make-gcc-st-ilp64-zen5-shared", + "inherits": "base" + }, + { + "name": "linux-make-gcc-mt-ilp64-zen5-static", + "configurePreset": "linux-make-gcc-mt-ilp64-zen5-static", + "inherits": "base" + }, + { + "name": "linux-make-gcc-mt-ilp64-zen5-shared", + "configurePreset": "linux-make-gcc-mt-ilp64-zen5-shared", + "inherits": "base" + }, + { + "name": "linux-make-gcc-st-lp64-auto-static", + "configurePreset": "linux-make-gcc-st-lp64-auto-static", + "inherits": "base" + }, + { + "name": "linux-make-gcc-st-lp64-auto-shared", + "configurePreset": "linux-make-gcc-st-lp64-auto-shared", + "inherits": "base" + }, + { + "name": "linux-make-gcc-mt-lp64-auto-static", + "configurePreset": "linux-make-gcc-mt-lp64-auto-static", + "inherits": "base" + }, + { + "name": "linux-make-gcc-mt-lp64-auto-shared", + "configurePreset": "linux-make-gcc-mt-lp64-auto-shared", + "inherits": "base" + }, + { + "name": "linux-make-gcc-st-ilp64-auto-static", + "configurePreset": "linux-make-gcc-st-ilp64-auto-static", + "inherits": "base" + }, + { + "name": "linux-make-gcc-st-ilp64-auto-shared", + "configurePreset": "linux-make-gcc-st-ilp64-auto-shared", + "inherits": "base" + }, + { + "name": "linux-make-gcc-mt-ilp64-auto-static", + "configurePreset": "linux-make-gcc-mt-ilp64-auto-static", + "inherits": "base" + }, + { + "name": "linux-make-gcc-mt-ilp64-auto-shared", + "configurePreset": "linux-make-gcc-mt-ilp64-auto-shared", + "inherits": "base" + }, + + { + "name": "linux-make-gcc-st-lp64-amdzen-static-check", + "description": "Check static single-threaded LP64 BLIS with amdzen option on Linux", + "configurePreset": "linux-make-gcc-st-lp64-amdzen-static", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-gcc-st-lp64-amdzen-shared-check", + "description": "Check shared single-threaded LP64 BLIS with amdzen option on Linux", + "configurePreset": "linux-make-gcc-st-lp64-amdzen-shared", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-gcc-mt-lp64-amdzen-static-check", + "description": "Check multithreaded static LP64 BLIS with amdzen option on Linux", + "configurePreset": "linux-make-gcc-mt-lp64-amdzen-static", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-gcc-mt-lp64-amdzen-shared-check", + "description": "Check multithreaded shared LP64 BLIS with amdzen option on Linux", + "configurePreset": "linux-make-gcc-mt-lp64-amdzen-shared", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-gcc-st-ilp64-amdzen-static-check", + "description": "Check single-threaded static ILP64 BLIS with amdzen option on Linux", + "configurePreset": "linux-make-gcc-st-ilp64-amdzen-static", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-gcc-st-ilp64-amdzen-shared-check", + "description": "Check single-threaded shared ILP64 BLIS with amdzen option on Linux", + "configurePreset": "linux-make-gcc-st-ilp64-amdzen-shared", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-gcc-mt-ilp64-amdzen-static-check", + "description": "Check multithreaded static ILP64 BLIS with amdzen option on Linux", + "configurePreset": "linux-make-gcc-mt-ilp64-amdzen-static", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-gcc-mt-ilp64-amdzen-shared-check", + "description": "Check multithreaded shared ILP64 BLIS with amdzen option on Linux", + "configurePreset": "linux-make-gcc-mt-ilp64-amdzen-shared", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-gcc-st-lp64-zen5-static-check", + "description": "Check static single-threaded LP64 BLIS with zen5 option on Linux", + "configurePreset": "linux-make-gcc-st-lp64-zen5-static", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-gcc-st-lp64-zen5-shared-check", + "description": "Check shared single-threaded LP64 BLIS with zen5 option on Linux", + "configurePreset": "linux-make-gcc-st-lp64-zen5-shared", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-gcc-mt-lp64-zen5-static-check", + "description": "Check multithreaded static LP64 BLIS with zen5 option on Linux", + "configurePreset": "linux-make-gcc-mt-lp64-zen5-static", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-gcc-mt-lp64-zen5-shared-check", + "description": "Check multithreaded shared LP64 BLIS with zen5 option on Linux", + "configurePreset": "linux-make-gcc-mt-lp64-zen5-shared", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-gcc-st-ilp64-zen5-static-check", + "description": "Check single-threaded static ILP64 BLIS with zen5 option on Linux", + "configurePreset": "linux-make-gcc-st-ilp64-zen5-static", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-gcc-st-ilp64-zen5-shared-check", + "description": "Check single-threaded shared ILP64 BLIS with zen5 option on Linux", + "configurePreset": "linux-make-gcc-st-ilp64-zen5-shared", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-gcc-mt-ilp64-zen5-static-check", + "description": "Check multithreaded static ILP64 BLIS with zen5 option on Linux", + "configurePreset": "linux-make-gcc-mt-ilp64-zen5-static", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-gcc-mt-ilp64-zen5-shared-check", + "description": "Check multithreaded shared ILP64 BLIS with zen5 option on Linux", + "configurePreset": "linux-make-gcc-mt-ilp64-zen5-shared", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-gcc-st-lp64-auto-static-check", + "description": "Check static single-threaded LP64 BLIS with auto option on Linux", + "configurePreset": "linux-make-gcc-st-lp64-auto-static", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-gcc-st-lp64-auto-shared-check", + "description": "Check shared single-threaded LP64 BLIS with auto option on Linux", + "configurePreset": "linux-make-gcc-st-lp64-auto-shared", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-gcc-mt-lp64-auto-static-check", + "description": "Check multithreaded static LP64 BLIS with auto option on Linux", + "configurePreset": "linux-make-gcc-mt-lp64-auto-static", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-gcc-mt-lp64-auto-shared-check", + "description": "Check multithreaded shared LP64 BLIS with auto option on Linux", + "configurePreset": "linux-make-gcc-mt-lp64-auto-shared", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-gcc-st-ilp64-auto-static-check", + "description": "Check single-threaded static ILP64 BLIS with auto option on Linux", + "configurePreset": "linux-make-gcc-st-ilp64-auto-static", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-gcc-st-ilp64-auto-shared-check", + "description": "Check single-threaded shared ILP64 BLIS with auto option on Linux", + "configurePreset": "linux-make-gcc-st-ilp64-auto-shared", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-gcc-mt-ilp64-auto-static-check", + "description": "Check multithreaded static ILP64 BLIS with auto option on Linux", + "configurePreset": "linux-make-gcc-mt-ilp64-auto-static", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-gcc-mt-ilp64-auto-shared-check", + "description": "Check multithreaded shared ILP64 BLIS with auto option on Linux", + "configurePreset": "linux-make-gcc-mt-ilp64-auto-shared", + "targets": ["check", "checkblis-salt", "checkblis-md"] + } + ], + "workflowPresets": [ + { + "name": "linux-make-gcc-st-lp64-amdzen-static", + "description": "Build and check single-threaded static BLIS for amdzen configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-gcc-st-lp64-amdzen-static" + }, + { + "type": "build", + "name": "linux-make-gcc-st-lp64-amdzen-static" + }, + { + "type": "build", + "name": "linux-make-gcc-st-lp64-amdzen-static-check" + } + ] + }, + { + "name": "linux-make-gcc-st-lp64-amdzen-shared", + "description": "Build and check single-threaded shared BLIS for amdzen configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-gcc-st-lp64-amdzen-shared" + }, + { + "type": "build", + "name": "linux-make-gcc-st-lp64-amdzen-shared" + }, + { + "type": "build", + "name": "linux-make-gcc-st-lp64-amdzen-shared-check" + } + ] + }, + { + "name": "linux-make-gcc-mt-lp64-amdzen-static", + "description": "Build and check multithreaded static BLIS for amdzen configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-gcc-mt-lp64-amdzen-static" + }, + { + "type": "build", + "name": "linux-make-gcc-mt-lp64-amdzen-static" + }, + { + "type": "build", + "name": "linux-make-gcc-mt-lp64-amdzen-static-check" + } + ] + }, + { + "name": "linux-make-gcc-mt-lp64-amdzen-shared", + "description": "Build and check multithreaded shared BLIS for amdzen configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-gcc-mt-lp64-amdzen-shared" + }, + { + "type": "build", + "name": "linux-make-gcc-mt-lp64-amdzen-shared" + }, + { + "type": "build", + "name": "linux-make-gcc-mt-lp64-amdzen-shared-check" + } + ] + }, + { + "name": "linux-make-gcc-st-ilp64-amdzen-static", + "description": "Build and check single-threaded static BLIS for amdzen configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-gcc-st-ilp64-amdzen-static" + }, + { + "type": "build", + "name": "linux-make-gcc-st-ilp64-amdzen-static" + }, + { + "type": "build", + "name": "linux-make-gcc-st-ilp64-amdzen-static-check" + } + ] + }, + { + "name": "linux-make-gcc-st-ilp64-amdzen-shared", + "description": "Build and check single-threaded shared BLIS for amdzen configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-gcc-st-ilp64-amdzen-shared" + }, + { + "type": "build", + "name": "linux-make-gcc-st-ilp64-amdzen-shared" + }, + { + "type": "build", + "name": "linux-make-gcc-st-ilp64-amdzen-shared-check" + } + ] + }, + { + "name": "linux-make-gcc-mt-ilp64-amdzen-static", + "description": "Build and check multithreaded static BLIS for amdzen configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-gcc-mt-ilp64-amdzen-static" + }, + { + "type": "build", + "name": "linux-make-gcc-mt-ilp64-amdzen-static" + }, + { + "type": "build", + "name": "linux-make-gcc-mt-ilp64-amdzen-static-check" + } + ] + }, + { + "name": "linux-make-gcc-mt-ilp64-amdzen-shared", + "description": "Build and check multithreaded shared BLIS for amdzen configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-gcc-mt-ilp64-amdzen-shared" + }, + { + "type": "build", + "name": "linux-make-gcc-mt-ilp64-amdzen-shared" + }, + { + "type": "build", + "name": "linux-make-gcc-mt-ilp64-amdzen-shared-check" + } + ] + }, + + { + "name": "linux-make-gcc-st-lp64-zen5-static", + "description": "Build and check single-threaded static BLIS for zen5 configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-gcc-st-lp64-zen5-static" + }, + { + "type": "build", + "name": "linux-make-gcc-st-lp64-zen5-static" + }, + { + "type": "build", + "name": "linux-make-gcc-st-lp64-zen5-static-check" + } + ] + }, + { + "name": "linux-make-gcc-st-lp64-zen5-shared", + "description": "Build and check single-threaded shared BLIS for zen5 configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-gcc-st-lp64-zen5-shared" + }, + { + "type": "build", + "name": "linux-make-gcc-st-lp64-zen5-shared" + }, + { + "type": "build", + "name": "linux-make-gcc-st-lp64-zen5-shared-check" + } + ] + }, + { + "name": "linux-make-gcc-mt-lp64-zen5-static", + "description": "Build and check multithreaded static BLIS for zen5 configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-gcc-mt-lp64-zen5-static" + }, + { + "type": "build", + "name": "linux-make-gcc-mt-lp64-zen5-static" + }, + { + "type": "build", + "name": "linux-make-gcc-mt-lp64-zen5-static-check" + } + ] + }, + { + "name": "linux-make-gcc-mt-lp64-zen5-shared", + "description": "Build and check multithreaded shared BLIS for zen5 configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-gcc-mt-lp64-zen5-shared" + }, + { + "type": "build", + "name": "linux-make-gcc-mt-lp64-zen5-shared" + }, + { + "type": "build", + "name": "linux-make-gcc-mt-lp64-zen5-shared-check" + } + ] + }, + { + "name": "linux-make-gcc-st-ilp64-zen5-static", + "description": "Build and check single-threaded static BLIS for zen5 configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-gcc-st-ilp64-zen5-static" + }, + { + "type": "build", + "name": "linux-make-gcc-st-ilp64-zen5-static" + }, + { + "type": "build", + "name": "linux-make-gcc-st-ilp64-zen5-static-check" + } + ] + }, + { + "name": "linux-make-gcc-st-ilp64-zen5-shared", + "description": "Build and check single-threaded shared BLIS for zen5 configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-gcc-st-ilp64-zen5-shared" + }, + { + "type": "build", + "name": "linux-make-gcc-st-ilp64-zen5-shared" + }, + { + "type": "build", + "name": "linux-make-gcc-st-ilp64-zen5-shared-check" + } + ] + }, + { + "name": "linux-make-gcc-mt-ilp64-zen5-static", + "description": "Build and check multithreaded static BLIS for zen5 configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-gcc-mt-ilp64-zen5-static" + }, + { + "type": "build", + "name": "linux-make-gcc-mt-ilp64-zen5-static" + }, + { + "type": "build", + "name": "linux-make-gcc-mt-ilp64-zen5-static-check" + } + ] + }, + { + "name": "linux-make-gcc-mt-ilp64-zen5-shared", + "description": "Build and check multithreaded shared BLIS for zen5 configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-gcc-mt-ilp64-zen5-shared" + }, + { + "type": "build", + "name": "linux-make-gcc-mt-ilp64-zen5-shared" + }, + { + "type": "build", + "name": "linux-make-gcc-mt-ilp64-zen5-shared-check" + } + ] + }, + + { + "name": "linux-make-gcc-st-lp64-auto-static", + "description": "Build and check single-threaded static BLIS for auto configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-gcc-st-lp64-auto-static" + }, + { + "type": "build", + "name": "linux-make-gcc-st-lp64-auto-static" + }, + { + "type": "build", + "name": "linux-make-gcc-st-lp64-auto-static-check" + } + ] + }, + { + "name": "linux-make-gcc-st-lp64-auto-shared", + "description": "Build and check single-threaded shared BLIS for auto configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-gcc-st-lp64-auto-shared" + }, + { + "type": "build", + "name": "linux-make-gcc-st-lp64-auto-shared" + }, + { + "type": "build", + "name": "linux-make-gcc-st-lp64-auto-shared-check" + } + ] + }, + { + "name": "linux-make-gcc-mt-lp64-auto-static", + "description": "Build and check multithreaded static BLIS for auto configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-gcc-mt-lp64-auto-static" + }, + { + "type": "build", + "name": "linux-make-gcc-mt-lp64-auto-static" + }, + { + "type": "build", + "name": "linux-make-gcc-mt-lp64-auto-static-check" + } + ] + }, + { + "name": "linux-make-gcc-mt-lp64-auto-shared", + "description": "Build and check multithreaded shared BLIS for auto configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-gcc-mt-lp64-auto-shared" + }, + { + "type": "build", + "name": "linux-make-gcc-mt-lp64-auto-shared" + }, + { + "type": "build", + "name": "linux-make-gcc-mt-lp64-auto-shared-check" + } + ] + }, + { + "name": "linux-make-gcc-st-ilp64-auto-static", + "description": "Build and check single-threaded static BLIS for auto configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-gcc-st-ilp64-auto-static" + }, + { + "type": "build", + "name": "linux-make-gcc-st-ilp64-auto-static" + }, + { + "type": "build", + "name": "linux-make-gcc-st-ilp64-auto-static-check" + } + ] + }, + { + "name": "linux-make-gcc-st-ilp64-auto-shared", + "description": "Build and check single-threaded shared BLIS for auto configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-gcc-st-ilp64-auto-shared" + }, + { + "type": "build", + "name": "linux-make-gcc-st-ilp64-auto-shared" + }, + { + "type": "build", + "name": "linux-make-gcc-st-ilp64-auto-shared-check" + } + ] + }, + { + "name": "linux-make-gcc-mt-ilp64-auto-static", + "description": "Build and check multithreaded static BLIS for auto configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-gcc-mt-ilp64-auto-static" + }, + { + "type": "build", + "name": "linux-make-gcc-mt-ilp64-auto-static" + }, + { + "type": "build", + "name": "linux-make-gcc-mt-ilp64-auto-static-check" + } + ] + }, + { + "name": "linux-make-gcc-mt-ilp64-auto-shared", + "description": "Build and check multithreaded shared BLIS for auto configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-gcc-mt-ilp64-auto-shared" + }, + { + "type": "build", + "name": "linux-make-gcc-mt-ilp64-auto-shared" + }, + { + "type": "build", + "name": "linux-make-gcc-mt-ilp64-auto-shared-check" + } + ] + } + ] +} diff --git a/build/cmake/presets/linux-make.json b/build/cmake/presets/linux-make.json new file mode 100644 index 0000000000..084a07bef0 --- /dev/null +++ b/build/cmake/presets/linux-make.json @@ -0,0 +1,621 @@ +{ + "version": 6, + "include": [ + "base.json" + ], + "configurePresets": [ + { + "name": "linux-make", + "inherits": "base", + "hidden": true, + "cacheVariables": { + "ENABLE_ADDON": "aocl_gemm" + }, + "generator": "Unix Makefiles", + "condition": { + "type": "notEquals", + "lhs": "${hostSystemName}", + "rhs": "Windows" + } + }, + { + "name": "linux-make-st-lp64-amdzen-static", + "inherits": ["linux-make", "st", "lp64", "amdzen", "linux-static"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-lp64-amdzen" + }, + "hidden": false + }, + { + "name": "linux-make-st-lp64-amdzen-shared", + "inherits": ["linux-make", "st", "lp64", "amdzen", "linux-shared"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-lp64-amdzen" + }, + "hidden": false + }, + { + "name": "linux-make-mt-lp64-amdzen-static", + "inherits": ["linux-make", "mt", "lp64", "amdzen", "linux-static"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-lp64-amdzen" + }, + "hidden": false + }, + { + "name": "linux-make-mt-lp64-amdzen-shared", + "inherits": ["linux-make", "mt", "lp64", "amdzen", "linux-shared"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-lp64-amdzen" + }, + "hidden": false + }, + { + "name": "linux-make-st-ilp64-amdzen-static", + "inherits": ["linux-make", "st", "ilp64", "amdzen", "linux-static"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-ilp64-amdzen" + }, + "hidden": false + }, + { + "name": "linux-make-st-ilp64-amdzen-shared", + "inherits": ["linux-make", "st", "ilp64", "amdzen", "linux-shared"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-ilp64-amdzen" + }, + "hidden": false + }, + { + "name": "linux-make-mt-ilp64-amdzen-static", + "inherits": ["linux-make", "mt", "ilp64", "amdzen", "linux-static"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-ilp64-amdzen" + }, + "hidden": false + }, + { + "name": "linux-make-mt-ilp64-amdzen-shared", + "inherits": ["linux-make", "mt", "ilp64", "amdzen", "linux-shared"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-ilp64-amdzen" + }, + "hidden": false + }, + + { + "name": "linux-make-st-lp64-auto-static", + "inherits": ["linux-make", "st", "lp64", "auto", "linux-static"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-lp64-auto" + }, + "hidden": false + }, + { + "name": "linux-make-st-lp64-auto-shared", + "inherits": ["linux-make", "st", "lp64", "auto", "linux-shared"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-lp64-auto" + }, + "hidden": false + }, + { + "name": "linux-make-mt-lp64-auto-static", + "inherits": ["linux-make", "mt", "lp64", "auto", "linux-static"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-lp64-auto" + }, + "hidden": false + }, + { + "name": "linux-make-mt-lp64-auto-shared", + "inherits": ["linux-make", "mt", "lp64", "auto", "linux-shared"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-lp64-auto" + }, + "hidden": false + }, + { + "name": "linux-make-st-ilp64-auto-static", + "inherits": ["linux-make", "st", "ilp64", "auto", "linux-static"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-ilp64-auto" + }, + "hidden": false + }, + { + "name": "linux-make-st-ilp64-auto-shared", + "inherits": ["linux-make", "st", "ilp64", "auto", "linux-shared"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-ilp64-auto" + }, + "hidden": false + }, + { + "name": "linux-make-mt-ilp64-auto-static", + "inherits": ["linux-make", "mt", "ilp64", "auto", "linux-static"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-ilp64-auto" + }, + "hidden": false + }, + { + "name": "linux-make-mt-ilp64-auto-shared", + "inherits": ["linux-make", "mt", "ilp64", "auto", "linux-shared"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-ilp64-auto" + }, + "hidden": false + } + ], + "buildPresets": [ + { + "name": "linux-make-st-lp64-amdzen-static", + "configurePreset": "linux-make-st-lp64-amdzen-static", + "inherits": "base" + }, + { + "name": "linux-make-st-lp64-amdzen-shared", + "configurePreset": "linux-make-st-lp64-amdzen-shared", + "inherits": "base" + }, + { + "name": "linux-make-mt-lp64-amdzen-static", + "configurePreset": "linux-make-mt-lp64-amdzen-static", + "inherits": "base" + }, + { + "name": "linux-make-mt-lp64-amdzen-shared", + "configurePreset": "linux-make-mt-lp64-amdzen-shared", + "inherits": "base" + }, + { + "name": "linux-make-st-ilp64-amdzen-static", + "configurePreset": "linux-make-st-ilp64-amdzen-static", + "inherits": "base" + }, + { + "name": "linux-make-st-ilp64-amdzen-shared", + "configurePreset": "linux-make-st-ilp64-amdzen-shared", + "inherits": "base" + }, + { + "name": "linux-make-mt-ilp64-amdzen-static", + "configurePreset": "linux-make-mt-ilp64-amdzen-static", + "inherits": "base" + }, + { + "name": "linux-make-mt-ilp64-amdzen-shared", + "configurePreset": "linux-make-mt-ilp64-amdzen-shared", + "inherits": "base" + }, + { + "name": "linux-make-st-lp64-auto-static", + "configurePreset": "linux-make-st-lp64-auto-static", + "inherits": "base" + }, + { + "name": "linux-make-st-lp64-auto-shared", + "configurePreset": "linux-make-st-lp64-auto-shared", + "inherits": "base" + }, + { + "name": "linux-make-mt-lp64-auto-static", + "configurePreset": "linux-make-mt-lp64-auto-static", + "inherits": "base" + }, + { + "name": "linux-make-mt-lp64-auto-shared", + "configurePreset": "linux-make-mt-lp64-auto-shared", + "inherits": "base" + }, + { + "name": "linux-make-st-ilp64-auto-static", + "configurePreset": "linux-make-st-ilp64-auto-static", + "inherits": "base" + }, + { + "name": "linux-make-st-ilp64-auto-shared", + "configurePreset": "linux-make-st-ilp64-auto-shared", + "inherits": "base" + }, + { + "name": "linux-make-mt-ilp64-auto-static", + "configurePreset": "linux-make-mt-ilp64-auto-static", + "inherits": "base" + }, + { + "name": "linux-make-mt-ilp64-auto-shared", + "configurePreset": "linux-make-mt-ilp64-auto-shared", + "inherits": "base" + }, + + { + "name": "linux-make-st-lp64-amdzen-static-check", + "description": "Check static single-threaded LP64 BLIS with amdzen option on Linux", + "configurePreset": "linux-make-st-lp64-amdzen-static", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-st-lp64-amdzen-shared-check", + "description": "Check shared single-threaded LP64 BLIS with amdzen option on Linux", + "configurePreset": "linux-make-st-lp64-amdzen-shared", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-mt-lp64-amdzen-static-check", + "description": "Check multithreaded static LP64 BLIS with amdzen option on Linux", + "configurePreset": "linux-make-mt-lp64-amdzen-static", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-mt-lp64-amdzen-shared-check", + "description": "Check multithreaded shared LP64 BLIS with amdzen option on Linux", + "configurePreset": "linux-make-mt-lp64-amdzen-shared", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-st-ilp64-amdzen-static-check", + "description": "Check single-threaded static ILP64 BLIS with amdzen option on Linux", + "configurePreset": "linux-make-st-ilp64-amdzen-static", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-st-ilp64-amdzen-shared-check", + "description": "Check single-threaded shared ILP64 BLIS with amdzen option on Linux", + "configurePreset": "linux-make-st-ilp64-amdzen-shared", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-mt-ilp64-amdzen-static-check", + "description": "Check multithreaded static ILP64 BLIS with amdzen option on Linux", + "configurePreset": "linux-make-mt-ilp64-amdzen-static", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-mt-ilp64-amdzen-shared-check", + "description": "Check multithreaded shared ILP64 BLIS with amdzen option on Linux", + "configurePreset": "linux-make-mt-ilp64-amdzen-shared", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-st-lp64-auto-static-check", + "description": "Check static single-threaded LP64 BLIS with auto option on Linux", + "configurePreset": "linux-make-st-lp64-auto-static", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-st-lp64-auto-shared-check", + "description": "Check shared single-threaded LP64 BLIS with auto option on Linux", + "configurePreset": "linux-make-st-lp64-auto-shared", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-mt-lp64-auto-static-check", + "description": "Check multithreaded static LP64 BLIS with auto option on Linux", + "configurePreset": "linux-make-mt-lp64-auto-static", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-mt-lp64-auto-shared-check", + "description": "Check multithreaded shared LP64 BLIS with auto option on Linux", + "configurePreset": "linux-make-mt-lp64-auto-shared", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-st-ilp64-auto-static-check", + "description": "Check single-threaded static ILP64 BLIS with auto option on Linux", + "configurePreset": "linux-make-st-ilp64-auto-static", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-st-ilp64-auto-shared-check", + "description": "Check single-threaded shared ILP64 BLIS with auto option on Linux", + "configurePreset": "linux-make-st-ilp64-auto-shared", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-mt-ilp64-auto-static-check", + "description": "Check multithreaded static ILP64 BLIS with auto option on Linux", + "configurePreset": "linux-make-mt-ilp64-auto-static", + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-make-mt-ilp64-auto-shared-check", + "description": "Check multithreaded shared ILP64 BLIS with auto option on Linux", + "configurePreset": "linux-make-mt-ilp64-auto-shared", + "targets": ["check", "checkblis-salt", "checkblis-md"] + } + ], + "workflowPresets": [ + { + "name": "linux-make-st-lp64-amdzen-static", + "description": "Build and check single-threaded static BLIS for amdzen configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-st-lp64-amdzen-static" + }, + { + "type": "build", + "name": "linux-make-st-lp64-amdzen-static" + }, + { + "type": "build", + "name": "linux-make-st-lp64-amdzen-static-check" + } + ] + }, + { + "name": "linux-make-st-lp64-amdzen-shared", + "description": "Build and check single-threaded shared BLIS for amdzen configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-st-lp64-amdzen-shared" + }, + { + "type": "build", + "name": "linux-make-st-lp64-amdzen-shared" + }, + { + "type": "build", + "name": "linux-make-st-lp64-amdzen-shared-check" + } + ] + }, + { + "name": "linux-make-mt-lp64-amdzen-static", + "description": "Build and check multithreaded static BLIS for amdzen configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-mt-lp64-amdzen-static" + }, + { + "type": "build", + "name": "linux-make-mt-lp64-amdzen-static" + }, + { + "type": "build", + "name": "linux-make-mt-lp64-amdzen-static-check" + } + ] + }, + { + "name": "linux-make-mt-lp64-amdzen-shared", + "description": "Build and check multithreaded shared BLIS for amdzen configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-mt-lp64-amdzen-shared" + }, + { + "type": "build", + "name": "linux-make-mt-lp64-amdzen-shared" + }, + { + "type": "build", + "name": "linux-make-mt-lp64-amdzen-shared-check" + } + ] + }, + { + "name": "linux-make-st-ilp64-amdzen-static", + "description": "Build and check single-threaded static BLIS for amdzen configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-st-ilp64-amdzen-static" + }, + { + "type": "build", + "name": "linux-make-st-ilp64-amdzen-static" + }, + { + "type": "build", + "name": "linux-make-st-ilp64-amdzen-static-check" + } + ] + }, + { + "name": "linux-make-st-ilp64-amdzen-shared", + "description": "Build and check single-threaded shared BLIS for amdzen configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-st-ilp64-amdzen-shared" + }, + { + "type": "build", + "name": "linux-make-st-ilp64-amdzen-shared" + }, + { + "type": "build", + "name": "linux-make-st-ilp64-amdzen-shared-check" + } + ] + }, + { + "name": "linux-make-mt-ilp64-amdzen-static", + "description": "Build and check multithreaded static BLIS for amdzen configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-mt-ilp64-amdzen-static" + }, + { + "type": "build", + "name": "linux-make-mt-ilp64-amdzen-static" + }, + { + "type": "build", + "name": "linux-make-mt-ilp64-amdzen-static-check" + } + ] + }, + { + "name": "linux-make-mt-ilp64-amdzen-shared", + "description": "Build and check multithreaded shared BLIS for amdzen configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-mt-ilp64-amdzen-shared" + }, + { + "type": "build", + "name": "linux-make-mt-ilp64-amdzen-shared" + }, + { + "type": "build", + "name": "linux-make-mt-ilp64-amdzen-shared-check" + } + ] + }, + + { + "name": "linux-make-st-lp64-auto-static", + "description": "Build and check single-threaded static BLIS for auto configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-st-lp64-auto-static" + }, + { + "type": "build", + "name": "linux-make-st-lp64-auto-static" + }, + { + "type": "build", + "name": "linux-make-st-lp64-auto-static-check" + } + ] + }, + { + "name": "linux-make-st-lp64-auto-shared", + "description": "Build and check single-threaded shared BLIS for auto configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-st-lp64-auto-shared" + }, + { + "type": "build", + "name": "linux-make-st-lp64-auto-shared" + }, + { + "type": "build", + "name": "linux-make-st-lp64-auto-shared-check" + } + ] + }, + { + "name": "linux-make-mt-lp64-auto-static", + "description": "Build and check multithreaded static BLIS for auto configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-mt-lp64-auto-static" + }, + { + "type": "build", + "name": "linux-make-mt-lp64-auto-static" + }, + { + "type": "build", + "name": "linux-make-mt-lp64-auto-static-check" + } + ] + }, + { + "name": "linux-make-mt-lp64-auto-shared", + "description": "Build and check multithreaded shared BLIS for auto configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-mt-lp64-auto-shared" + }, + { + "type": "build", + "name": "linux-make-mt-lp64-auto-shared" + }, + { + "type": "build", + "name": "linux-make-mt-lp64-auto-shared-check" + } + ] + }, + { + "name": "linux-make-st-ilp64-auto-static", + "description": "Build and check single-threaded static BLIS for auto configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-st-ilp64-auto-static" + }, + { + "type": "build", + "name": "linux-make-st-ilp64-auto-static" + }, + { + "type": "build", + "name": "linux-make-st-ilp64-auto-static-check" + } + ] + }, + { + "name": "linux-make-st-ilp64-auto-shared", + "description": "Build and check single-threaded shared BLIS for auto configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-st-ilp64-auto-shared" + }, + { + "type": "build", + "name": "linux-make-st-ilp64-auto-shared" + }, + { + "type": "build", + "name": "linux-make-st-ilp64-auto-shared-check" + } + ] + }, + { + "name": "linux-make-mt-ilp64-auto-static", + "description": "Build and check multithreaded static BLIS for auto configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-mt-ilp64-auto-static" + }, + { + "type": "build", + "name": "linux-make-mt-ilp64-auto-static" + }, + { + "type": "build", + "name": "linux-make-mt-ilp64-auto-static-check" + } + ] + }, + { + "name": "linux-make-mt-ilp64-auto-shared", + "description": "Build and check multithreaded shared BLIS for auto configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-make-mt-ilp64-auto-shared" + }, + { + "type": "build", + "name": "linux-make-mt-ilp64-auto-shared" + }, + { + "type": "build", + "name": "linux-make-mt-ilp64-auto-shared-check" + } + ] + } + ] +} diff --git a/build/cmake/presets/linux-ninja.json b/build/cmake/presets/linux-ninja.json new file mode 100644 index 0000000000..d249d7a938 --- /dev/null +++ b/build/cmake/presets/linux-ninja.json @@ -0,0 +1,637 @@ +{ + "version": 6, + "include": [ + "base.json" + ], + "configurePresets": [ + { + "name": "linux-ninja", + "inherits": "base", + "hidden": true, + "cacheVariables": { + "ENABLE_ADDON": "aocl_gemm" + }, + "generator": "Ninja", + "condition": { + "type": "notEquals", + "lhs": "${hostSystemName}", + "rhs": "Windows" + } + }, + { + "name": "linux-ninja-st-lp64-amdzen-static", + "inherits": ["linux-ninja", "st", "lp64", "amdzen", "linux-static"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-lp64-amdzen" + }, + "hidden": false + }, + { + "name": "linux-ninja-st-lp64-amdzen-shared", + "inherits": ["linux-ninja", "st", "lp64", "amdzen", "linux-shared"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-lp64-amdzen" + }, + "hidden": false + }, + { + "name": "linux-ninja-mt-lp64-amdzen-static", + "inherits": ["linux-ninja", "mt", "lp64", "amdzen", "linux-static"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-lp64-amdzen" + }, + "hidden": false + }, + { + "name": "linux-ninja-mt-lp64-amdzen-shared", + "inherits": ["linux-ninja", "mt", "lp64", "amdzen", "linux-shared"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-lp64-amdzen" + }, + "hidden": false + }, + { + "name": "linux-ninja-st-ilp64-amdzen-static", + "inherits": ["linux-ninja", "st", "ilp64", "amdzen", "linux-static"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-ilp64-amdzen" + }, + "hidden": false + }, + { + "name": "linux-ninja-st-ilp64-amdzen-shared", + "inherits": ["linux-ninja", "st", "ilp64", "amdzen", "linux-shared"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-ilp64-amdzen" + }, + "hidden": false + }, + { + "name": "linux-ninja-mt-ilp64-amdzen-static", + "inherits": ["linux-ninja", "mt", "ilp64", "amdzen", "linux-static"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-ilp64-amdzen" + }, + "hidden": false + }, + { + "name": "linux-ninja-mt-ilp64-amdzen-shared", + "inherits": ["linux-ninja", "mt", "ilp64", "amdzen", "linux-shared"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-ilp64-amdzen" + }, + "hidden": false + }, + + { + "name": "linux-ninja-st-lp64-auto-static", + "inherits": ["linux-ninja", "st", "lp64", "auto", "linux-static"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-lp64-auto" + }, + "hidden": false + }, + { + "name": "linux-ninja-st-lp64-auto-shared", + "inherits": ["linux-ninja", "st", "lp64", "auto", "linux-shared"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-lp64-auto" + }, + "hidden": false + }, + { + "name": "linux-ninja-mt-lp64-auto-static", + "inherits": ["linux-ninja", "mt", "lp64", "auto", "linux-static"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-lp64-auto" + }, + "hidden": false + }, + { + "name": "linux-ninja-mt-lp64-auto-shared", + "inherits": ["linux-ninja", "mt", "lp64", "auto", "linux-shared"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-lp64-auto" + }, + "hidden": false + }, + { + "name": "linux-ninja-st-ilp64-auto-static", + "inherits": ["linux-ninja", "st", "ilp64", "auto", "linux-static"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-ilp64-auto" + }, + "hidden": false + }, + { + "name": "linux-ninja-st-ilp64-auto-shared", + "inherits": ["linux-ninja", "st", "ilp64", "auto", "linux-shared"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-ilp64-auto" + }, + "hidden": false + }, + { + "name": "linux-ninja-mt-ilp64-auto-static", + "inherits": ["linux-ninja", "mt", "ilp64", "auto", "linux-static"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-ilp64-auto" + }, + "hidden": false + }, + { + "name": "linux-ninja-mt-ilp64-auto-shared", + "inherits": ["linux-ninja", "mt", "ilp64", "auto", "linux-shared"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-linux-ilp64-auto" + }, + "hidden": false + } + ], + "buildPresets": [ + { + "name": "linux-ninja-st-lp64-amdzen-static", + "configurePreset": "linux-ninja-st-lp64-amdzen-static", + "inherits": "base" + }, + { + "name": "linux-ninja-st-lp64-amdzen-shared", + "configurePreset": "linux-ninja-st-lp64-amdzen-shared", + "inherits": "base" + }, + { + "name": "linux-ninja-mt-lp64-amdzen-static", + "configurePreset": "linux-ninja-mt-lp64-amdzen-static", + "inherits": "base" + }, + { + "name": "linux-ninja-mt-lp64-amdzen-shared", + "configurePreset": "linux-ninja-mt-lp64-amdzen-shared", + "inherits": "base" + }, + { + "name": "linux-ninja-st-ilp64-amdzen-static", + "configurePreset": "linux-ninja-st-ilp64-amdzen-static", + "inherits": "base" + }, + { + "name": "linux-ninja-st-ilp64-amdzen-shared", + "configurePreset": "linux-ninja-st-ilp64-amdzen-shared", + "inherits": "base" + }, + { + "name": "linux-ninja-mt-ilp64-amdzen-static", + "configurePreset": "linux-ninja-mt-ilp64-amdzen-static", + "inherits": "base" + }, + { + "name": "linux-ninja-mt-ilp64-amdzen-shared", + "configurePreset": "linux-ninja-mt-ilp64-amdzen-shared", + "inherits": "base" + }, + { + "name": "linux-ninja-st-lp64-auto-static", + "configurePreset": "linux-ninja-st-lp64-auto-static", + "inherits": "base" + }, + { + "name": "linux-ninja-st-lp64-auto-shared", + "configurePreset": "linux-ninja-st-lp64-auto-shared", + "inherits": "base" + }, + { + "name": "linux-ninja-mt-lp64-auto-static", + "configurePreset": "linux-ninja-mt-lp64-auto-static", + "inherits": "base" + }, + { + "name": "linux-ninja-mt-lp64-auto-shared", + "configurePreset": "linux-ninja-mt-lp64-auto-shared", + "inherits": "base" + }, + { + "name": "linux-ninja-st-ilp64-auto-static", + "configurePreset": "linux-ninja-st-ilp64-auto-static", + "inherits": "base" + }, + { + "name": "linux-ninja-st-ilp64-auto-shared", + "configurePreset": "linux-ninja-st-ilp64-auto-shared", + "inherits": "base" + }, + { + "name": "linux-ninja-mt-ilp64-auto-static", + "configurePreset": "linux-ninja-mt-ilp64-auto-static", + "inherits": "base" + }, + { + "name": "linux-ninja-mt-ilp64-auto-shared", + "configurePreset": "linux-ninja-mt-ilp64-auto-shared", + "inherits": "base" + }, + + { + "name": "linux-ninja-st-lp64-amdzen-static-check", + "description": "Check static single-threaded LP64 BLIS with amdzen option on Linux", + "configurePreset": "linux-ninja-st-lp64-amdzen-static", + "jobs": 1, + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-ninja-st-lp64-amdzen-shared-check", + "description": "Check shared single-threaded LP64 BLIS with amdzen option on Linux", + "configurePreset": "linux-ninja-st-lp64-amdzen-shared", + "jobs": 1, + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-ninja-mt-lp64-amdzen-static-check", + "description": "Check multithreaded static LP64 BLIS with amdzen option on Linux", + "configurePreset": "linux-ninja-mt-lp64-amdzen-static", + "jobs": 1, + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-ninja-mt-lp64-amdzen-shared-check", + "description": "Check multithreaded shared LP64 BLIS with amdzen option on Linux", + "configurePreset": "linux-ninja-mt-lp64-amdzen-shared", + "jobs": 1, + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-ninja-st-ilp64-amdzen-static-check", + "description": "Check single-threaded static ILP64 BLIS with amdzen option on Linux", + "configurePreset": "linux-ninja-st-ilp64-amdzen-static", + "jobs": 1, + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-ninja-st-ilp64-amdzen-shared-check", + "description": "Check single-threaded shared ILP64 BLIS with amdzen option on Linux", + "configurePreset": "linux-ninja-st-ilp64-amdzen-shared", + "jobs": 1, + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-ninja-mt-ilp64-amdzen-static-check", + "description": "Check multithreaded static ILP64 BLIS with amdzen option on Linux", + "configurePreset": "linux-ninja-mt-ilp64-amdzen-static", + "jobs": 1, + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-ninja-mt-ilp64-amdzen-shared-check", + "description": "Check multithreaded shared ILP64 BLIS with amdzen option on Linux", + "configurePreset": "linux-ninja-mt-ilp64-amdzen-shared", + "jobs": 1, + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-ninja-st-lp64-auto-static-check", + "description": "Check static single-threaded LP64 BLIS with auto option on Linux", + "configurePreset": "linux-ninja-st-lp64-auto-static", + "jobs": 1, + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-ninja-st-lp64-auto-shared-check", + "description": "Check shared single-threaded LP64 BLIS with auto option on Linux", + "configurePreset": "linux-ninja-st-lp64-auto-shared", + "jobs": 1, + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-ninja-mt-lp64-auto-static-check", + "description": "Check multithreaded static LP64 BLIS with auto option on Linux", + "configurePreset": "linux-ninja-mt-lp64-auto-static", + "jobs": 1, + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-ninja-mt-lp64-auto-shared-check", + "description": "Check multithreaded shared LP64 BLIS with auto option on Linux", + "configurePreset": "linux-ninja-mt-lp64-auto-shared", + "jobs": 1, + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-ninja-st-ilp64-auto-static-check", + "description": "Check single-threaded static ILP64 BLIS with auto option on Linux", + "configurePreset": "linux-ninja-st-ilp64-auto-static", + "jobs": 1, + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-ninja-st-ilp64-auto-shared-check", + "description": "Check single-threaded shared ILP64 BLIS with auto option on Linux", + "configurePreset": "linux-ninja-st-ilp64-auto-shared", + "jobs": 1, + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-ninja-mt-ilp64-auto-static-check", + "description": "Check multithreaded static ILP64 BLIS with auto option on Linux", + "configurePreset": "linux-ninja-mt-ilp64-auto-static", + "jobs": 1, + "targets": ["check", "checkblis-salt", "checkblis-md"] + }, + { + "name": "linux-ninja-mt-ilp64-auto-shared-check", + "description": "Check multithreaded shared ILP64 BLIS with auto option on Linux", + "configurePreset": "linux-ninja-mt-ilp64-auto-shared", + "jobs": 1, + "targets": ["check", "checkblis-salt", "checkblis-md"] + } + ], + "workflowPresets": [ + { + "name": "linux-ninja-st-lp64-amdzen-static", + "description": "Build and check single-threaded static BLIS for amdzen configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-ninja-st-lp64-amdzen-static" + }, + { + "type": "build", + "name": "linux-ninja-st-lp64-amdzen-static" + }, + { + "type": "build", + "name": "linux-ninja-st-lp64-amdzen-static-check" + } + ] + }, + { + "name": "linux-ninja-st-lp64-amdzen-shared", + "description": "Build and check single-threaded shared BLIS for amdzen configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-ninja-st-lp64-amdzen-shared" + }, + { + "type": "build", + "name": "linux-ninja-st-lp64-amdzen-shared" + }, + { + "type": "build", + "name": "linux-ninja-st-lp64-amdzen-shared-check" + } + ] + }, + { + "name": "linux-ninja-mt-lp64-amdzen-static", + "description": "Build and check multithreaded static BLIS for amdzen configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-ninja-mt-lp64-amdzen-static" + }, + { + "type": "build", + "name": "linux-ninja-mt-lp64-amdzen-static" + }, + { + "type": "build", + "name": "linux-ninja-mt-lp64-amdzen-static-check" + } + ] + }, + { + "name": "linux-ninja-mt-lp64-amdzen-shared", + "description": "Build and check multithreaded shared BLIS for amdzen configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-ninja-mt-lp64-amdzen-shared" + }, + { + "type": "build", + "name": "linux-ninja-mt-lp64-amdzen-shared" + }, + { + "type": "build", + "name": "linux-ninja-mt-lp64-amdzen-shared-check" + } + ] + }, + { + "name": "linux-ninja-st-ilp64-amdzen-static", + "description": "Build and check single-threaded static BLIS for amdzen configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-ninja-st-ilp64-amdzen-static" + }, + { + "type": "build", + "name": "linux-ninja-st-ilp64-amdzen-static" + }, + { + "type": "build", + "name": "linux-ninja-st-ilp64-amdzen-static-check" + } + ] + }, + { + "name": "linux-ninja-st-ilp64-amdzen-shared", + "description": "Build and check single-threaded shared BLIS for amdzen configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-ninja-st-ilp64-amdzen-shared" + }, + { + "type": "build", + "name": "linux-ninja-st-ilp64-amdzen-shared" + }, + { + "type": "build", + "name": "linux-ninja-st-ilp64-amdzen-shared-check" + } + ] + }, + { + "name": "linux-ninja-mt-ilp64-amdzen-static", + "description": "Build and check multithreaded static BLIS for amdzen configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-ninja-mt-ilp64-amdzen-static" + }, + { + "type": "build", + "name": "linux-ninja-mt-ilp64-amdzen-static" + }, + { + "type": "build", + "name": "linux-ninja-mt-ilp64-amdzen-static-check" + } + ] + }, + { + "name": "linux-ninja-mt-ilp64-amdzen-shared", + "description": "Build and check multithreaded shared BLIS for amdzen configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-ninja-mt-ilp64-amdzen-shared" + }, + { + "type": "build", + "name": "linux-ninja-mt-ilp64-amdzen-shared" + }, + { + "type": "build", + "name": "linux-ninja-mt-ilp64-amdzen-shared-check" + } + ] + }, + + { + "name": "linux-ninja-st-lp64-auto-static", + "description": "Build and check single-threaded static BLIS for auto configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-ninja-st-lp64-auto-static" + }, + { + "type": "build", + "name": "linux-ninja-st-lp64-auto-static" + }, + { + "type": "build", + "name": "linux-ninja-st-lp64-auto-static-check" + } + ] + }, + { + "name": "linux-ninja-st-lp64-auto-shared", + "description": "Build and check single-threaded shared BLIS for auto configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-ninja-st-lp64-auto-shared" + }, + { + "type": "build", + "name": "linux-ninja-st-lp64-auto-shared" + }, + { + "type": "build", + "name": "linux-ninja-st-lp64-auto-shared-check" + } + ] + }, + { + "name": "linux-ninja-mt-lp64-auto-static", + "description": "Build and check multithreaded static BLIS for auto configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-ninja-mt-lp64-auto-static" + }, + { + "type": "build", + "name": "linux-ninja-mt-lp64-auto-static" + }, + { + "type": "build", + "name": "linux-ninja-mt-lp64-auto-static-check" + } + ] + }, + { + "name": "linux-ninja-mt-lp64-auto-shared", + "description": "Build and check multithreaded shared BLIS for auto configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-ninja-mt-lp64-auto-shared" + }, + { + "type": "build", + "name": "linux-ninja-mt-lp64-auto-shared" + }, + { + "type": "build", + "name": "linux-ninja-mt-lp64-auto-shared-check" + } + ] + }, + { + "name": "linux-ninja-st-ilp64-auto-static", + "description": "Build and check single-threaded static BLIS for auto configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-ninja-st-ilp64-auto-static" + }, + { + "type": "build", + "name": "linux-ninja-st-ilp64-auto-static" + }, + { + "type": "build", + "name": "linux-ninja-st-ilp64-auto-static-check" + } + ] + }, + { + "name": "linux-ninja-st-ilp64-auto-shared", + "description": "Build and check single-threaded shared BLIS for auto configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-ninja-st-ilp64-auto-shared" + }, + { + "type": "build", + "name": "linux-ninja-st-ilp64-auto-shared" + }, + { + "type": "build", + "name": "linux-ninja-st-ilp64-auto-shared-check" + } + ] + }, + { + "name": "linux-ninja-mt-ilp64-auto-static", + "description": "Build and check multithreaded static BLIS for auto configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-ninja-mt-ilp64-auto-static" + }, + { + "type": "build", + "name": "linux-ninja-mt-ilp64-auto-static" + }, + { + "type": "build", + "name": "linux-ninja-mt-ilp64-auto-static-check" + } + ] + }, + { + "name": "linux-ninja-mt-ilp64-auto-shared", + "description": "Build and check multithreaded shared BLIS for auto configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-ninja-mt-ilp64-auto-shared" + }, + { + "type": "build", + "name": "linux-ninja-mt-ilp64-auto-shared" + }, + { + "type": "build", + "name": "linux-ninja-mt-ilp64-auto-shared-check" + } + ] + } + ] +} diff --git a/build/cmake/presets/win-msvc.json b/build/cmake/presets/win-msvc.json new file mode 100644 index 0000000000..43e7a36995 --- /dev/null +++ b/build/cmake/presets/win-msvc.json @@ -0,0 +1,624 @@ +{ + "version": 6, + "include": [ + "base.json" + ], + "configurePresets": [ + { + "name": "win-msvc", + "inherits": "base", + "hidden": true, + "cacheVariables": { + "COMPLEX_RETURN": "intel", + "ENABLE_NO_UNDERSCORE_API": "ON", + "OpenMP_libomp_LIBRARY": "$env{OpenMP_lib_path}/libiomp5md.lib" + }, + "generator": "Visual Studio 17 2022", + "condition": { + "type": "equals", + "lhs": "${hostSystemName}", + "rhs": "Windows" + }, + "toolset": "ClangCl" + }, + { + "name": "win-msvc-st-lp64-amdzen-static", + "inherits": ["win-msvc", "st", "lp64", "amdzen", "static"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-win-msvc-lp64-amdzen" + }, + "hidden": false + }, + { + "name": "win-msvc-st-lp64-amdzen-shared", + "inherits": ["win-msvc", "st", "lp64", "amdzen", "shared"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-win-msvc-lp64-amdzen" + }, + "hidden": false + }, + { + "name": "win-msvc-mt-lp64-amdzen-static", + "inherits": ["win-msvc", "mt", "lp64", "amdzen", "static"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-win-msvc-lp64-amdzen" + }, + "hidden": false + }, + { + "name": "win-msvc-mt-lp64-amdzen-shared", + "inherits": ["win-msvc", "mt", "lp64", "amdzen", "shared"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-win-msvc-lp64-amdzen" + }, + "hidden": false + }, + { + "name": "win-msvc-st-ilp64-amdzen-static", + "inherits": ["win-msvc", "st", "ilp64", "amdzen", "static"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-win-msvc-ilp64-amdzen" + }, + "hidden": false + }, + { + "name": "win-msvc-st-ilp64-amdzen-shared", + "inherits": ["win-msvc", "st", "ilp64", "amdzen", "shared"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-win-msvc-ilp64-amdzen" + }, + "hidden": false + }, + { + "name": "win-msvc-mt-ilp64-amdzen-static", + "inherits": ["win-msvc", "mt", "ilp64", "amdzen", "static"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-win-msvc-ilp64-amdzen" + }, + "hidden": false + }, + { + "name": "win-msvc-mt-ilp64-amdzen-shared", + "inherits": ["win-msvc", "mt", "ilp64", "amdzen", "shared"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-win-msvc-ilp64-amdzen" + }, + "hidden": false + }, + + { + "name": "win-msvc-st-lp64-auto-static", + "inherits": ["win-msvc", "st", "lp64", "auto", "static"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-win-msvc-lp64-auto" + }, + "hidden": false + }, + { + "name": "win-msvc-st-lp64-auto-shared", + "inherits": ["win-msvc", "st", "lp64", "auto", "shared"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-win-msvc-lp64-auto" + }, + "hidden": false + }, + { + "name": "win-msvc-mt-lp64-auto-static", + "inherits": ["win-msvc", "mt", "lp64", "auto", "static"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-win-msvc-lp64-auto" + }, + "hidden": false + }, + { + "name": "win-msvc-mt-lp64-auto-shared", + "inherits": ["win-msvc", "mt", "lp64", "auto", "shared"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-win-msvc-lp64-auto" + }, + "hidden": false + }, + { + "name": "win-msvc-st-ilp64-auto-static", + "inherits": ["win-msvc", "st", "ilp64", "auto", "static"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-win-msvc-ilp64-auto" + }, + "hidden": false + }, + { + "name": "win-msvc-st-ilp64-auto-shared", + "inherits": ["win-msvc", "st", "ilp64", "auto", "shared"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-win-msvc-ilp64-auto" + }, + "hidden": false + }, + { + "name": "win-msvc-mt-ilp64-auto-static", + "inherits": ["win-msvc", "mt", "ilp64", "auto", "static"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-win-msvc-ilp64-auto" + }, + "hidden": false + }, + { + "name": "win-msvc-mt-ilp64-auto-shared", + "inherits": ["win-msvc", "mt", "ilp64", "auto", "shared"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-win-msvc-ilp64-auto" + }, + "hidden": false + } + ], + "buildPresets": [ + + { + "name": "win-msvc-st-lp64-amdzen-static", + "configurePreset": "win-msvc-st-lp64-amdzen-static", + "inherits": "base" + }, + { + "name": "win-msvc-st-lp64-amdzen-shared", + "configurePreset": "win-msvc-st-lp64-amdzen-shared", + "inherits": "base" + }, + { + "name": "win-msvc-mt-lp64-amdzen-static", + "configurePreset": "win-msvc-mt-lp64-amdzen-static", + "inherits": "base" + }, + { + "name": "win-msvc-mt-lp64-amdzen-shared", + "configurePreset": "win-msvc-mt-lp64-amdzen-shared", + "inherits": "base" + }, + { + "name": "win-msvc-st-ilp64-amdzen-static", + "configurePreset": "win-msvc-st-ilp64-amdzen-static", + "inherits": "base" + }, + { + "name": "win-msvc-st-ilp64-amdzen-shared", + "configurePreset": "win-msvc-st-ilp64-amdzen-shared", + "inherits": "base" + }, + { + "name": "win-msvc-mt-ilp64-amdzen-static", + "configurePreset": "win-msvc-mt-ilp64-amdzen-static", + "inherits": "base" + }, + { + "name": "win-msvc-mt-ilp64-amdzen-shared", + "configurePreset": "win-msvc-mt-ilp64-amdzen-shared", + "inherits": "base" + }, + { + "name": "win-msvc-st-lp64-auto-static", + "configurePreset": "win-msvc-st-lp64-auto-static", + "inherits": "base" + }, + { + "name": "win-msvc-st-lp64-auto-shared", + "configurePreset": "win-msvc-st-lp64-auto-shared", + "inherits": "base" + }, + { + "name": "win-msvc-mt-lp64-auto-static", + "configurePreset": "win-msvc-mt-lp64-auto-static", + "inherits": "base" + }, + { + "name": "win-msvc-mt-lp64-auto-shared", + "configurePreset": "win-msvc-mt-lp64-auto-shared", + "inherits": "base" + }, + { + "name": "win-msvc-st-ilp64-auto-static", + "configurePreset": "win-msvc-st-ilp64-auto-static", + "inherits": "base" + }, + { + "name": "win-msvc-st-ilp64-auto-shared", + "configurePreset": "win-msvc-st-ilp64-auto-shared", + "inherits": "base" + }, + { + "name": "win-msvc-mt-ilp64-auto-static", + "configurePreset": "win-msvc-mt-ilp64-auto-static", + "inherits": "base" + }, + { + "name": "win-msvc-mt-ilp64-auto-shared", + "configurePreset": "win-msvc-mt-ilp64-auto-shared", + "inherits": "base" + } , + { + "name": "win-msvc-st-lp64-amdzen-static-check", + "description": "Check static single-threaded LP64 BLIS with amdzen option on Windows", + "configurePreset": "win-msvc-st-lp64-amdzen-static", + "targets": ["check", "testsuite/checkblis-md"] + }, + { + "name": "win-msvc-st-lp64-amdzen-shared-check", + "description": "Check shared single-threaded LP64 BLIS with amdzen option on Windows", + "configurePreset": "win-msvc-st-lp64-amdzen-shared", + "targets": ["check", "testsuite/checkblis-md"] + }, + { + "name": "win-msvc-mt-lp64-amdzen-static-check", + "description": "Check multithreaded static LP64 BLIS with amdzen option on Windows", + "configurePreset": "win-msvc-mt-lp64-amdzen-static", + "targets": ["check", "testsuite/checkblis-md"] + }, + { + "name": "win-msvc-mt-lp64-amdzen-shared-check", + "description": "Check multithreaded shared LP64 BLIS with amdzen option on Windows", + "configurePreset": "win-msvc-mt-lp64-amdzen-shared", + "targets": ["check", "testsuite/checkblis-md"] + }, + { + "name": "win-msvc-st-ilp64-amdzen-static-check", + "description": "Check single-threaded static ILP64 BLIS with amdzen option on Windows", + "configurePreset": "win-msvc-st-ilp64-amdzen-static", + "targets": ["check", "testsuite/checkblis-md"] + }, + { + "name": "win-msvc-st-ilp64-amdzen-shared-check", + "description": "Check single-threaded shared ILP64 BLIS with amdzen option on Windows", + "configurePreset": "win-msvc-st-ilp64-amdzen-shared", + "targets": ["check", "testsuite/checkblis-md"] + }, + { + "name": "win-msvc-mt-ilp64-amdzen-static-check", + "description": "Check multithreaded static ILP64 BLIS with amdzen option on Windows", + "configurePreset": "win-msvc-mt-ilp64-amdzen-static", + "targets": ["check", "testsuite/checkblis-md"] + }, + { + "name": "win-msvc-mt-ilp64-amdzen-shared-check", + "description": "Check multithreaded shared ILP64 BLIS with amdzen option on Windows", + "configurePreset": "win-msvc-mt-ilp64-amdzen-shared", + "targets": ["check", "testsuite/checkblis-md"] + }, + { + "name": "win-msvc-st-lp64-auto-static-check", + "description": "Check static single-threaded LP64 BLIS with auto option on Windows", + "configurePreset": "win-msvc-st-lp64-auto-static", + "targets": ["check", "testsuite/checkblis-md"] + }, + { + "name": "win-msvc-st-lp64-auto-shared-check", + "description": "Check shared single-threaded LP64 BLIS with auto option on Windows", + "configurePreset": "win-msvc-st-lp64-auto-shared", + "targets": ["check", "testsuite/checkblis-md"] + }, + { + "name": "win-msvc-mt-lp64-auto-static-check", + "description": "Check multithreaded static LP64 BLIS with auto option on Windows", + "configurePreset": "win-msvc-mt-lp64-auto-static", + "targets": ["check", "testsuite/checkblis-md"] + }, + { + "name": "win-msvc-mt-lp64-auto-shared-check", + "description": "Check multithreaded shared LP64 BLIS with auto option on Windows", + "configurePreset": "win-msvc-mt-lp64-auto-shared", + "targets": ["check", "testsuite/checkblis-md"] + }, + { + "name": "win-msvc-st-ilp64-auto-static-check", + "description": "Check single-threaded static ILP64 BLIS with auto option on Windows", + "configurePreset": "win-msvc-st-ilp64-auto-static", + "targets": ["check", "testsuite/checkblis-md"] + }, + { + "name": "win-msvc-st-ilp64-auto-shared-check", + "description": "Check single-threaded shared ILP64 BLIS with auto option on Windows", + "configurePreset": "win-msvc-st-ilp64-auto-shared", + "targets": ["check", "testsuite/checkblis-md"] + }, + { + "name": "win-msvc-mt-ilp64-auto-static-check", + "description": "Check multithreaded static ILP64 BLIS with auto option on Windows", + "configurePreset": "win-msvc-mt-ilp64-auto-static", + "targets": ["check", "testsuite/checkblis-md"] + }, + { + "name": "win-msvc-mt-ilp64-auto-shared-check", + "description": "Check multithreaded shared ILP64 BLIS with auto option on Windows", + "configurePreset": "win-msvc-mt-ilp64-auto-shared", + "targets": ["check", "testsuite/checkblis-md"] + } + ], + "workflowPresets": [ + { + "name": "win-msvc-st-lp64-amdzen-static", + "description": "Build and check single-threaded static BLIS for amdzen configuration on Windows", + "steps": [ + { + "type": "configure", + "name": "win-msvc-st-lp64-amdzen-static" + }, + { + "type": "build", + "name": "win-msvc-st-lp64-amdzen-static" + }, + { + "type": "build", + "name": "win-msvc-st-lp64-amdzen-static-check" + } + ] + }, + { + "name": "win-msvc-st-lp64-amdzen-shared", + "description": "Build and check single-threaded shared BLIS for amdzen configuration on Windows", + "steps": [ + { + "type": "configure", + "name": "win-msvc-st-lp64-amdzen-shared" + }, + { + "type": "build", + "name": "win-msvc-st-lp64-amdzen-shared" + }, + { + "type": "build", + "name": "win-msvc-st-lp64-amdzen-shared-check" + } + ] + }, + { + "name": "win-msvc-mt-lp64-amdzen-static", + "description": "Build and check multithreaded static BLIS for amdzen configuration on Windows", + "steps": [ + { + "type": "configure", + "name": "win-msvc-mt-lp64-amdzen-static" + }, + { + "type": "build", + "name": "win-msvc-mt-lp64-amdzen-static" + }, + { + "type": "build", + "name": "win-msvc-mt-lp64-amdzen-static-check" + } + ] + }, + { + "name": "win-msvc-mt-lp64-amdzen-shared", + "description": "Build and check multithreaded shared BLIS for amdzen configuration on Windows", + "steps": [ + { + "type": "configure", + "name": "win-msvc-mt-lp64-amdzen-shared" + }, + { + "type": "build", + "name": "win-msvc-mt-lp64-amdzen-shared" + }, + { + "type": "build", + "name": "win-msvc-mt-lp64-amdzen-shared-check" + } + ] + }, + { + "name": "win-msvc-st-ilp64-amdzen-static", + "description": "Build and check single-threaded static BLIS for amdzen configuration on Windows", + "steps": [ + { + "type": "configure", + "name": "win-msvc-st-ilp64-amdzen-static" + }, + { + "type": "build", + "name": "win-msvc-st-ilp64-amdzen-static" + }, + { + "type": "build", + "name": "win-msvc-st-ilp64-amdzen-static-check" + } + ] + }, + { + "name": "win-msvc-st-ilp64-amdzen-shared", + "description": "Build and check single-threaded shared BLIS for amdzen configuration on Windows", + "steps": [ + { + "type": "configure", + "name": "win-msvc-st-ilp64-amdzen-shared" + }, + { + "type": "build", + "name": "win-msvc-st-ilp64-amdzen-shared" + }, + { + "type": "build", + "name": "win-msvc-st-ilp64-amdzen-shared-check" + } + ] + }, + { + "name": "win-msvc-mt-ilp64-amdzen-static", + "description": "Build and check multithreaded static BLIS for amdzen configuration on Windows", + "steps": [ + { + "type": "configure", + "name": "win-msvc-mt-ilp64-amdzen-static" + }, + { + "type": "build", + "name": "win-msvc-mt-ilp64-amdzen-static" + }, + { + "type": "build", + "name": "win-msvc-mt-ilp64-amdzen-static-check" + } + ] + }, + { + "name": "win-msvc-mt-ilp64-amdzen-shared", + "description": "Build and check multithreaded shared BLIS for amdzen configuration on Windows", + "steps": [ + { + "type": "configure", + "name": "win-msvc-mt-ilp64-amdzen-shared" + }, + { + "type": "build", + "name": "win-msvc-mt-ilp64-amdzen-shared" + }, + { + "type": "build", + "name": "win-msvc-mt-ilp64-amdzen-shared-check" + } + ] + }, + + { + "name": "win-msvc-st-lp64-auto-static", + "description": "Build and check single-threaded static BLIS for auto configuration on Windows", + "steps": [ + { + "type": "configure", + "name": "win-msvc-st-lp64-auto-static" + }, + { + "type": "build", + "name": "win-msvc-st-lp64-auto-static" + }, + { + "type": "build", + "name": "win-msvc-st-lp64-auto-static-check" + } + ] + }, + { + "name": "win-msvc-st-lp64-auto-shared", + "description": "Build and check single-threaded shared BLIS for auto configuration on Windows", + "steps": [ + { + "type": "configure", + "name": "win-msvc-st-lp64-auto-shared" + }, + { + "type": "build", + "name": "win-msvc-st-lp64-auto-shared" + }, + { + "type": "build", + "name": "win-msvc-st-lp64-auto-shared-check" + } + ] + }, + { + "name": "win-msvc-mt-lp64-auto-static", + "description": "Build and check multithreaded static BLIS for auto configuration on Windows", + "steps": [ + { + "type": "configure", + "name": "win-msvc-mt-lp64-auto-static" + }, + { + "type": "build", + "name": "win-msvc-mt-lp64-auto-static" + }, + { + "type": "build", + "name": "win-msvc-mt-lp64-auto-static-check" + } + ] + }, + { + "name": "win-msvc-mt-lp64-auto-shared", + "description": "Build and check multithreaded shared BLIS for auto configuration on Windows", + "steps": [ + { + "type": "configure", + "name": "win-msvc-mt-lp64-auto-shared" + }, + { + "type": "build", + "name": "win-msvc-mt-lp64-auto-shared" + }, + { + "type": "build", + "name": "win-msvc-mt-lp64-auto-shared-check" + } + ] + }, + { + "name": "win-msvc-st-ilp64-auto-static", + "description": "Build and check single-threaded static BLIS for auto configuration on Windows", + "steps": [ + { + "type": "configure", + "name": "win-msvc-st-ilp64-auto-static" + }, + { + "type": "build", + "name": "win-msvc-st-ilp64-auto-static" + }, + { + "type": "build", + "name": "win-msvc-st-ilp64-auto-static-check" + } + ] + }, + { + "name": "win-msvc-st-ilp64-auto-shared", + "description": "Build and check single-threaded shared BLIS for auto configuration on Windows", + "steps": [ + { + "type": "configure", + "name": "win-msvc-st-ilp64-auto-shared" + }, + { + "type": "build", + "name": "win-msvc-st-ilp64-auto-shared" + }, + { + "type": "build", + "name": "win-msvc-st-ilp64-auto-shared-check" + } + ] + }, + { + "name": "win-msvc-mt-ilp64-auto-static", + "description": "Build and check multithreaded static BLIS for auto configuration on Windows", + "steps": [ + { + "type": "configure", + "name": "win-msvc-mt-ilp64-auto-static" + }, + { + "type": "build", + "name": "win-msvc-mt-ilp64-auto-static" + }, + { + "type": "build", + "name": "win-msvc-mt-ilp64-auto-static-check" + } + ] + }, + { + "name": "win-msvc-mt-ilp64-auto-shared", + "description": "Build and check multithreaded shared BLIS for auto configuration on Windows", + "steps": [ + { + "type": "configure", + "name": "win-msvc-mt-ilp64-auto-shared" + }, + { + "type": "build", + "name": "win-msvc-mt-ilp64-auto-shared" + }, + { + "type": "build", + "name": "win-msvc-mt-ilp64-auto-shared-check" + } + ] + } + ] +} diff --git a/build/cmake/presets/win-ninja.json b/build/cmake/presets/win-ninja.json new file mode 100644 index 0000000000..a5123fbc6b --- /dev/null +++ b/build/cmake/presets/win-ninja.json @@ -0,0 +1,639 @@ +{ + "version": 6, + "include": [ + "base.json" + ], + "configurePresets": [ + { + "name": "win-ninja", + "inherits": "base", + "hidden": true, + "cacheVariables": { + "COMPLEX_RETURN": "intel", + "ENABLE_NO_UNDERSCORE_API": "ON", + "CMAKE_C_COMPILER": "C:/Program Files/LLVM/bin/clang-cl.exe", + "CMAKE_CXX_COMPILER": "C:/Program Files/LLVM/bin/clang-cl.exe", + "OpenMP_libomp_LIBRARY": "$env{OpenMP_lib_path}/libiomp5md.lib" + }, + "generator": "Ninja", + "condition": { + "type": "equals", + "lhs": "${hostSystemName}", + "rhs": "Windows" + } + }, + { + "name": "win-ninja-st-lp64-amdzen-static", + "inherits": ["win-ninja", "st", "lp64", "amdzen", "static"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-win-ninja-lp64-amdzen" + }, + "hidden": false + }, + { + "name": "win-ninja-st-lp64-amdzen-shared", + "inherits": ["win-ninja", "st", "lp64", "amdzen", "shared"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-win-ninja-lp64-amdzen" + }, + "hidden": false + }, + { + "name": "win-ninja-mt-lp64-amdzen-static", + "inherits": ["win-ninja", "mt", "lp64", "amdzen", "static"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-win-ninja-lp64-amdzen" + }, + "hidden": false + }, + { + "name": "win-ninja-mt-lp64-amdzen-shared", + "inherits": ["win-ninja", "mt", "lp64", "amdzen", "shared"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-win-ninja-lp64-amdzen" + }, + "hidden": false + }, + { + "name": "win-ninja-st-ilp64-amdzen-static", + "inherits": ["win-ninja", "st", "ilp64", "amdzen", "static"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-win-ninja-ilp64-amdzen" + }, + "hidden": false + }, + { + "name": "win-ninja-st-ilp64-amdzen-shared", + "inherits": ["win-ninja", "st", "ilp64", "amdzen", "shared"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-win-ninja-ilp64-amdzen" + }, + "hidden": false + }, + { + "name": "win-ninja-mt-ilp64-amdzen-static", + "inherits": ["win-ninja", "mt", "ilp64", "amdzen", "static"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-win-ninja-ilp64-amdzen" + }, + "hidden": false + }, + { + "name": "win-ninja-mt-ilp64-amdzen-shared", + "inherits": ["win-ninja", "mt", "ilp64", "amdzen", "shared"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-win-ninja-ilp64-amdzen" + }, + "hidden": false + }, + { + "name": "win-ninja-st-lp64-auto-static", + "inherits": ["win-ninja", "st", "lp64", "auto", "static"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-win-ninja-lp64-auto" + }, + "hidden": false + }, + { + "name": "win-ninja-st-lp64-auto-shared", + "inherits": ["win-ninja", "st", "lp64", "auto", "shared"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-win-ninja-lp64-auto" + }, + "hidden": false + }, + { + "name": "win-ninja-mt-lp64-auto-static", + "inherits": ["win-ninja", "mt", "lp64", "auto", "static"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-win-ninja-lp64-auto" + }, + "hidden": false + }, + { + "name": "win-ninja-mt-lp64-auto-shared", + "inherits": ["win-ninja", "mt", "lp64", "auto", "shared"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-win-ninja-lp64-auto" + }, + "hidden": false + }, + { + "name": "win-ninja-st-ilp64-auto-static", + "inherits": ["win-ninja", "st", "ilp64", "auto", "static"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-win-ninja-ilp64-auto" + }, + "hidden": false + }, + { + "name": "win-ninja-st-ilp64-auto-shared", + "inherits": ["win-ninja", "st", "ilp64", "auto", "shared"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-win-ninja-ilp64-auto" + }, + "hidden": false + }, + { + "name": "win-ninja-mt-ilp64-auto-static", + "inherits": ["win-ninja", "mt", "ilp64", "auto", "static"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-win-ninja-ilp64-auto" + }, + "hidden": false + }, + { + "name": "win-ninja-mt-ilp64-auto-shared", + "inherits": ["win-ninja", "mt", "ilp64", "auto", "shared"], + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/install-win-ninja-ilp64-auto" + }, + "hidden": false + } + ], + "buildPresets": [ + { + "name": "win-ninja-st-lp64-amdzen-static", + "configurePreset": "win-ninja-st-lp64-amdzen-static", + "inherits": "base" + }, + { + "name": "win-ninja-st-lp64-amdzen-shared", + "configurePreset": "win-ninja-st-lp64-amdzen-shared", + "inherits": "base" + }, + { + "name": "win-ninja-mt-lp64-amdzen-static", + "configurePreset": "win-ninja-mt-lp64-amdzen-static", + "inherits": "base" + }, + { + "name": "win-ninja-mt-lp64-amdzen-shared", + "configurePreset": "win-ninja-mt-lp64-amdzen-shared", + "inherits": "base" + }, + { + "name": "win-ninja-st-ilp64-amdzen-static", + "configurePreset": "win-ninja-st-ilp64-amdzen-static", + "inherits": "base" + }, + { + "name": "win-ninja-st-ilp64-amdzen-shared", + "configurePreset": "win-ninja-st-ilp64-amdzen-shared", + "inherits": "base" + }, + { + "name": "win-ninja-mt-ilp64-amdzen-static", + "configurePreset": "win-ninja-mt-ilp64-amdzen-static", + "inherits": "base" + }, + { + "name": "win-ninja-mt-ilp64-amdzen-shared", + "configurePreset": "win-ninja-mt-ilp64-amdzen-shared", + "inherits": "base" + }, + { + "name": "win-ninja-st-lp64-auto-static", + "configurePreset": "win-ninja-st-lp64-auto-static", + "inherits": "base" + }, + { + "name": "win-ninja-st-lp64-auto-shared", + "configurePreset": "win-ninja-st-lp64-auto-shared", + "inherits": "base" + }, + { + "name": "win-ninja-mt-lp64-auto-static", + "configurePreset": "win-ninja-mt-lp64-auto-static", + "inherits": "base" + }, + { + "name": "win-ninja-mt-lp64-auto-shared", + "configurePreset": "win-ninja-mt-lp64-auto-shared", + "inherits": "base" + }, + { + "name": "win-ninja-st-ilp64-auto-static", + "configurePreset": "win-ninja-st-ilp64-auto-static", + "inherits": "base" + }, + { + "name": "win-ninja-st-ilp64-auto-shared", + "configurePreset": "win-ninja-st-ilp64-auto-shared", + "inherits": "base" + }, + { + "name": "win-ninja-mt-ilp64-auto-static", + "configurePreset": "win-ninja-mt-ilp64-auto-static", + "inherits": "base" + }, + { + "name": "win-ninja-mt-ilp64-auto-shared", + "configurePreset": "win-ninja-mt-ilp64-auto-shared", + "inherits": "base" + }, + { + "name": "win-ninja-st-lp64-amdzen-static-check", + "description": "Check static single-threaded LP64 BLIS with amdzen option on Windows", + "configurePreset": "win-ninja-st-lp64-amdzen-static", + "jobs": 1, + "targets": ["check", "testsuite/checkblis-md"] + }, + { + "name": "win-ninja-st-lp64-amdzen-shared-check", + "description": "Check shared single-threaded LP64 BLIS with amdzen option on Windows", + "configurePreset": "win-ninja-st-lp64-amdzen-shared", + "jobs": 1, + "targets": ["check", "testsuite/checkblis-md"] + }, + { + "name": "win-ninja-mt-lp64-amdzen-static-check", + "description": "Check multithreaded static LP64 BLIS with amdzen option on Windows", + "configurePreset": "win-ninja-mt-lp64-amdzen-static", + "jobs": 1, + "targets": ["check", "testsuite/checkblis-md"] + }, + { + "name": "win-ninja-mt-lp64-amdzen-shared-check", + "description": "Check multithreaded shared LP64 BLIS with amdzen option on Windows", + "configurePreset": "win-ninja-mt-lp64-amdzen-shared", + "jobs": 1, + "targets": ["check", "testsuite/checkblis-md"] + }, + { + "name": "win-ninja-st-ilp64-amdzen-static-check", + "description": "Check single-threaded static ILP64 BLIS with amdzen option on Windows", + "configurePreset": "win-ninja-st-ilp64-amdzen-static", + "jobs": 1, + "targets": ["check", "testsuite/checkblis-md"] + }, + { + "name": "win-ninja-st-ilp64-amdzen-shared-check", + "description": "Check single-threaded shared ILP64 BLIS with amdzen option on Windows", + "configurePreset": "win-ninja-st-ilp64-amdzen-shared", + "jobs": 1, + "targets": ["check", "testsuite/checkblis-md"] + }, + { + "name": "win-ninja-mt-ilp64-amdzen-static-check", + "description": "Check multithreaded static ILP64 BLIS with amdzen option on Windows", + "configurePreset": "win-ninja-mt-ilp64-amdzen-static", + "jobs": 1, + "targets": ["check", "testsuite/checkblis-md"] + }, + { + "name": "win-ninja-mt-ilp64-amdzen-shared-check", + "description": "Check multithreaded shared ILP64 BLIS with amdzen option on Windows", + "configurePreset": "win-ninja-mt-ilp64-amdzen-shared", + "jobs": 1, + "targets": ["check", "testsuite/checkblis-md"] + }, + { + "name": "win-ninja-st-lp64-auto-static-check", + "description": "Check static single-threaded LP64 BLIS with auto option on Windows", + "configurePreset": "win-ninja-st-lp64-auto-static", + "jobs": 1, + "targets": ["check", "testsuite/checkblis-md"] + }, + { + "name": "win-ninja-st-lp64-auto-shared-check", + "description": "Check shared single-threaded LP64 BLIS with auto option on Windows", + "configurePreset": "win-ninja-st-lp64-auto-shared", + "jobs": 1, + "targets": ["check", "testsuite/checkblis-md"] + }, + { + "name": "win-ninja-mt-lp64-auto-static-check", + "description": "Check multithreaded static LP64 BLIS with auto option on Windows", + "configurePreset": "win-ninja-mt-lp64-auto-static", + "jobs": 1, + "targets": ["check", "testsuite/checkblis-md"] + }, + { + "name": "win-ninja-mt-lp64-auto-shared-check", + "description": "Check multithreaded shared LP64 BLIS with auto option on Windows", + "configurePreset": "win-ninja-mt-lp64-auto-shared", + "jobs": 1, + "targets": ["check", "testsuite/checkblis-md"] + }, + { + "name": "win-ninja-st-ilp64-auto-static-check", + "description": "Check single-threaded static ILP64 BLIS with auto option on Windows", + "configurePreset": "win-ninja-st-ilp64-auto-static", + "jobs": 1, + "targets": ["check", "testsuite/checkblis-md"] + }, + { + "name": "win-ninja-st-ilp64-auto-shared-check", + "description": "Check single-threaded shared ILP64 BLIS with auto option on Windows", + "configurePreset": "win-ninja-st-ilp64-auto-shared", + "jobs": 1, + "targets": ["check", "testsuite/checkblis-md"] + }, + { + "name": "win-ninja-mt-ilp64-auto-static-check", + "description": "Check multithreaded static ILP64 BLIS with auto option on Windows", + "configurePreset": "win-ninja-mt-ilp64-auto-static", + "jobs": 1, + "targets": ["check", "testsuite/checkblis-md"] + }, + { + "name": "win-ninja-mt-ilp64-auto-shared-check", + "description": "Check multithreaded shared ILP64 BLIS with auto option on Windows", + "configurePreset": "win-ninja-mt-ilp64-auto-shared", + "jobs": 1, + "targets": ["check", "testsuite/checkblis-md"] + } + ], + "workflowPresets": [ + { + "name": "win-ninja-st-lp64-amdzen-static", + "description": "Build and check single-threaded static BLIS for amdzen configuration on Windows", + "steps": [ + { + "type": "configure", + "name": "win-ninja-st-lp64-amdzen-static" + }, + { + "type": "build", + "name": "win-ninja-st-lp64-amdzen-static" + }, + { + "type": "build", + "name": "win-ninja-st-lp64-amdzen-static-check" + } + ] + }, + { + "name": "win-ninja-st-lp64-amdzen-shared", + "description": "Build and check single-threaded shared BLIS for amdzen configuration on Windows", + "steps": [ + { + "type": "configure", + "name": "win-ninja-st-lp64-amdzen-shared" + }, + { + "type": "build", + "name": "win-ninja-st-lp64-amdzen-shared" + }, + { + "type": "build", + "name": "win-ninja-st-lp64-amdzen-shared-check" + } + ] + }, + { + "name": "win-ninja-mt-lp64-amdzen-static", + "description": "Build and check multithreaded static BLIS for amdzen configuration on Windows", + "steps": [ + { + "type": "configure", + "name": "win-ninja-mt-lp64-amdzen-static" + }, + { + "type": "build", + "name": "win-ninja-mt-lp64-amdzen-static" + }, + { + "type": "build", + "name": "win-ninja-mt-lp64-amdzen-static-check" + } + ] + }, + { + "name": "win-ninja-mt-lp64-amdzen-shared", + "description": "Build and check multithreaded shared BLIS for amdzen configuration on Windows", + "steps": [ + { + "type": "configure", + "name": "win-ninja-mt-lp64-amdzen-shared" + }, + { + "type": "build", + "name": "win-ninja-mt-lp64-amdzen-shared" + }, + { + "type": "build", + "name": "win-ninja-mt-lp64-amdzen-shared-check" + } + ] + }, + { + "name": "win-ninja-st-ilp64-amdzen-static", + "description": "Build and check single-threaded static BLIS for amdzen configuration on Windows", + "steps": [ + { + "type": "configure", + "name": "win-ninja-st-ilp64-amdzen-static" + }, + { + "type": "build", + "name": "win-ninja-st-ilp64-amdzen-static" + }, + { + "type": "build", + "name": "win-ninja-st-ilp64-amdzen-static-check" + } + ] + }, + { + "name": "win-ninja-st-ilp64-amdzen-shared", + "description": "Build and check single-threaded shared BLIS for amdzen configuration on Windows", + "steps": [ + { + "type": "configure", + "name": "win-ninja-st-ilp64-amdzen-shared" + }, + { + "type": "build", + "name": "win-ninja-st-ilp64-amdzen-shared" + }, + { + "type": "build", + "name": "win-ninja-st-ilp64-amdzen-shared-check" + } + ] + }, + { + "name": "win-ninja-mt-ilp64-amdzen-static", + "description": "Build and check multithreaded static BLIS for amdzen configuration on Windows", + "steps": [ + { + "type": "configure", + "name": "win-ninja-mt-ilp64-amdzen-static" + }, + { + "type": "build", + "name": "win-ninja-mt-ilp64-amdzen-static" + }, + { + "type": "build", + "name": "win-ninja-mt-ilp64-amdzen-static-check" + } + ] + }, + { + "name": "win-ninja-mt-ilp64-amdzen-shared", + "description": "Build and check multithreaded shared BLIS for amdzen configuration on Windows", + "steps": [ + { + "type": "configure", + "name": "win-ninja-mt-ilp64-amdzen-shared" + }, + { + "type": "build", + "name": "win-ninja-mt-ilp64-amdzen-shared" + }, + { + "type": "build", + "name": "win-ninja-mt-ilp64-amdzen-shared-check" + } + ] + }, + + { + "name": "win-ninja-st-lp64-auto-static", + "description": "Build and check single-threaded static BLIS for auto configuration on Windows", + "steps": [ + { + "type": "configure", + "name": "win-ninja-st-lp64-auto-static" + }, + { + "type": "build", + "name": "win-ninja-st-lp64-auto-static" + }, + { + "type": "build", + "name": "win-ninja-st-lp64-auto-static-check" + } + ] + }, + { + "name": "win-ninja-st-lp64-auto-shared", + "description": "Build and check single-threaded shared BLIS for auto configuration on Windows", + "steps": [ + { + "type": "configure", + "name": "win-ninja-st-lp64-auto-shared" + }, + { + "type": "build", + "name": "win-ninja-st-lp64-auto-shared" + }, + { + "type": "build", + "name": "win-ninja-st-lp64-auto-shared-check" + } + ] + }, + { + "name": "win-ninja-mt-lp64-auto-static", + "description": "Build and check multithreaded static BLIS for auto configuration on Windows", + "steps": [ + { + "type": "configure", + "name": "win-ninja-mt-lp64-auto-static" + }, + { + "type": "build", + "name": "win-ninja-mt-lp64-auto-static" + }, + { + "type": "build", + "name": "win-ninja-mt-lp64-auto-static-check" + } + ] + }, + { + "name": "win-ninja-mt-lp64-auto-shared", + "description": "Build and check multithreaded shared BLIS for auto configuration on Windows", + "steps": [ + { + "type": "configure", + "name": "win-ninja-mt-lp64-auto-shared" + }, + { + "type": "build", + "name": "win-ninja-mt-lp64-auto-shared" + }, + { + "type": "build", + "name": "win-ninja-mt-lp64-auto-shared-check" + } + ] + }, + { + "name": "win-ninja-st-ilp64-auto-static", + "description": "Build and check single-threaded static BLIS for auto configuration on Windows", + "steps": [ + { + "type": "configure", + "name": "win-ninja-st-ilp64-auto-static" + }, + { + "type": "build", + "name": "win-ninja-st-ilp64-auto-static" + }, + { + "type": "build", + "name": "win-ninja-st-ilp64-auto-static-check" + } + ] + }, + { + "name": "win-ninja-st-ilp64-auto-shared", + "description": "Build and check single-threaded shared BLIS for auto configuration on Windows", + "steps": [ + { + "type": "configure", + "name": "win-ninja-st-ilp64-auto-shared" + }, + { + "type": "build", + "name": "win-ninja-st-ilp64-auto-shared" + }, + { + "type": "build", + "name": "win-ninja-st-ilp64-auto-shared-check" + } + ] + }, + { + "name": "win-ninja-mt-ilp64-auto-static", + "description": "Build and check multithreaded static BLIS for auto configuration on Windows", + "steps": [ + { + "type": "configure", + "name": "win-ninja-mt-ilp64-auto-static" + }, + { + "type": "build", + "name": "win-ninja-mt-ilp64-auto-static" + }, + { + "type": "build", + "name": "win-ninja-mt-ilp64-auto-static-check" + } + ] + }, + { + "name": "win-ninja-mt-ilp64-auto-shared", + "description": "Build and check multithreaded shared BLIS for auto configuration on Windows", + "steps": [ + { + "type": "configure", + "name": "win-ninja-mt-ilp64-auto-shared" + }, + { + "type": "build", + "name": "win-ninja-mt-ilp64-auto-shared" + }, + { + "type": "build", + "name": "win-ninja-mt-ilp64-auto-shared-check" + } + ] + } + ] +} diff --git a/build/cmake/subdir_helper_functions.cmake b/build/cmake/subdir_helper_functions.cmake index ad41a3001c..8d422f568c 100644 --- a/build/cmake/subdir_helper_functions.cmake +++ b/build/cmake/subdir_helper_functions.cmake @@ -1,4 +1,36 @@ -##Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.## +#[=[ + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +]=] # Create a list of keywords for files that need to be ignored by the system. file(READ ${CMAKE_SOURCE_DIR}/build/gen-make-frags/ignore_list IGNORE_LIST) diff --git a/build/config.mk.in b/build/config.mk.in index eddb69f705..cf6be8255e 100644 --- a/build/config.mk.in +++ b/build/config.mk.in @@ -94,6 +94,7 @@ CC := @CC@ GCC_OT_4_9_0 := @gcc_older_than_4_9_0@ GCC_OT_6_1_0 := @gcc_older_than_6_1_0@ GCC_OT_9_1_0 := @gcc_older_than_9_1_0@ +GCC_OT_11_2_0 := @gcc_older_than_11_2_0@ # The C++ compiler. NOTE: A C++ is typically not needed. CXX := @CXX@ @@ -172,6 +173,9 @@ ARG_MAX_HACK := @enable_arg_max_hack@ MK_ENABLE_STATIC := @enable_static@ MK_ENABLE_SHARED := @enable_shared@ +# Whether to use an install_name based on @rpath. +MK_ENABLE_RPATH := @enable_rpath@ + # Whether to export all symbols within the shared library, even those symbols # that are considered to be for internal use only. EXPORT_SHARED := @export_shared@ diff --git a/build/flatten-headers.py b/build/flatten-headers.py index 563725a7e9..d23dfc4482 100755 --- a/build/flatten-headers.py +++ b/build/flatten-headers.py @@ -398,8 +398,14 @@ def main(): % output_name, verbose_flag ) sys.exit() - # Print usage if we don't have exactly four arguments. - if len( args ) != 4: + # Print usage if we don't have minimum four arguments. + if len( args ) < 4: + print_usage() + sys.exit() + elif "||" in args[:4] or "'(set', 'FAIL_LINE=3&', 'goto', ':ABORT)'" in args[:4]: + print('\n==============================================') + print(sys.argv) + print('==============================================\n') print_usage() sys.exit() diff --git a/common.mk b/common.mk index 7f200545ed..a3e21c6267 100644 --- a/common.mk +++ b/common.mk @@ -5,7 +5,7 @@ # libraries. # # Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -71,6 +71,7 @@ $(eval $(call store-var-for,CWARNFLAGS, $(1))) $(eval $(call store-var-for,CDBGFLAGS, $(1))) $(eval $(call store-var-for,COPTFLAGS, $(1))) $(eval $(call store-var-for,CKOPTFLAGS, $(1))) +$(eval $(call store-var-for,CKLPOPTFLAGS, $(1))) $(eval $(call store-var-for,CKVECFLAGS, $(1))) $(eval $(call store-var-for,CROPTFLAGS, $(1))) $(eval $(call store-var-for,CRVECFLAGS, $(1))) @@ -159,6 +160,15 @@ get-kernel-cflags-for = $(strip $(call load-var-for,CKOPTFLAGS,$(1)) \ $(BUILD_SYMFLAGS) \ ) +get-kernel-lpgemm-cflags-for = $(strip $(call load-var-for,CKOPTFLAGS,$(1)) \ + $(call load-var-for,CKLPOPTFLAGS,$(1)) \ + $(call load-var-for,CKVECFLAGS,$(1)) \ + $(call get-noopt-cflags-for,$(1)) \ + $(COMPSIMDFLAGS) \ + $(BUILD_CPPFLAGS) \ + $(BUILD_SYMFLAGS) \ + ) + # When compiling addons, we use flags similar to those of general framework # source. This ensures that the same code can be linked and run across various # sub-configurations. @@ -224,6 +234,7 @@ get-config-text-for = "('$(1)' CFLAGS for config code)" get-frame-text-for = "('$(1)' CFLAGS for framework code)" get-aocldtl-text-for = "('$(1)' CFLAGS for AOCL debug and trace code)" get-kernel-text-for = "('$(1)' CFLAGS for kernels)" +get-kernel-lpgemm-text-for= "('$(1)' CFLAGS for lpgemm kernels)" get-addon-c99text-for = "('$(1)' CFLAGS for addons)" get-addon-cxxtext-for = "('$(1)' CXXFLAGS for addons)" get-addon-kernel-text-for = "('$(1)' CFLAGS for addon kernels)" @@ -557,11 +568,19 @@ LIBM := -lm endif LIBMEMKIND := -lmemkind +# Linking standard c++ library for aocl_gemm addon. +STDCXX := +ifeq ($(GCC_OT_11_2_0),yes) + ifeq ($(filter aocl_gemm, $(ADDON_LIST)), aocl_gemm) + STDCXX := -lstdc++ + endif +endif + # Default linker flags. # NOTE: -lpthread is needed unconditionally because BLIS uses pthread_once() # to initialize itself in a thread-safe manner. The one exception to this # rule: if --disable-system is given at configure-time, LIBPTHREAD is empty. -LDFLAGS := $(LDFLAGS_PRESET) $(LIBM) $(LIBPTHREAD) +LDFLAGS := $(LDFLAGS_PRESET) $(LIBM) $(LIBPTHREAD) $(STDCXX) # Add libmemkind to the link-time flags, if it was enabled at configure-time. ifeq ($(MK_ENABLE_MEMKIND),yes) @@ -583,7 +602,11 @@ endif ifeq ($(OS_NAME),Darwin) # OS X shared library link flags. SOFLAGS := -dynamiclib +ifeq ($(MK_ENABLE_RPATH),yes) +SOFLAGS += -Wl,-install_name,@rpath/$(LIBBLIS_SONAME) +else SOFLAGS += -Wl,-install_name,$(libdir)/$(LIBBLIS_SONAME) +endif else SOFLAGS := -shared ifeq ($(IS_WIN),yes) @@ -619,7 +642,17 @@ ifeq ($(MK_ENABLE_SHARED),yes) LIBBLIS_LINK := $(LIBBLIS_SO_PATH) ifeq ($(IS_WIN),no) # For Linux and OS X: set rpath property of shared object. - LDFLAGS += -Wl,-rpath,$(BASE_LIB_PATH) + ifeq ($(OS_NAME),Darwin) + # rpath for any executables generated in the top level directory + LDFLAGS += -Wl,-rpath,@executable_path/$(BASE_LIB_PATH) + # rpath for BLAS tests and test_libblis.x + LDFLAGS += -Wl,-rpath,@executable_path/../../../$(BASE_LIB_PATH) + else + # rpath for any executables generated in the top level directory + LDFLAGS += -Wl,-rpath,'$$ORIGIN/$(BASE_LIB_PATH)' + # rpath for BLAS tests and test_libblis.x + LDFLAGS += -Wl,-rpath,'$$ORIGIN/../../../../$(BASE_LIB_PATH)' + endif endif endif # On windows, use the shared library even if static is created. diff --git a/config/CMakeLists.txt b/config/CMakeLists.txt index b23fb85a4e..9fa3071ab1 100644 --- a/config/CMakeLists.txt +++ b/config/CMakeLists.txt @@ -1,4 +1,36 @@ -##Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc ## +#[=[ + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +]=] # Writing a function that will be used to generate the required object # libraries for the required configs. @@ -52,10 +84,8 @@ function(generate_config_targets config_target) # in get-noopt-cflags-for target_compile_options(${config_target}_CONFIG PRIVATE ${CTHREADFLAGS}) endif() - if(BUILD_SHARED_LIBS) - # Equivalent to CPICFLAGS in get-noopt-cflags-for - set_target_properties(${config_target}_CONFIG PROPERTIES POSITION_INDEPENDENT_CODE ON) - endif() + # Equivalent to CPICFLAGS in get-noopt-cflags-for + set_target_properties(${config_target}_CONFIG PROPERTIES POSITION_INDEPENDENT_CODE ON) add_dependencies(${config_target}_CONFIG flat-header) # Put all those targets under object-libs-targets folder name so that they appear all together in IDE. set_target_properties(${config_target}_CONFIG PROPERTIES FOLDER object-libs-targets) @@ -108,10 +138,8 @@ function(generate_config_targets config_target) # in get-noopt-cflags-for target_compile_options(${config_target}_REFINIT PRIVATE ${CTHREADFLAGS}) endif() - if(BUILD_SHARED_LIBS) - # Equivalent to CPICFLAGS in get-noopt-cflags-for - set_target_properties(${config_target}_REFINIT PROPERTIES POSITION_INDEPENDENT_CODE ON) - endif() + # Equivalent to CPICFLAGS in get-noopt-cflags-for + set_target_properties(${config_target}_REFINIT PROPERTIES POSITION_INDEPENDENT_CODE ON) add_dependencies(${config_target}_REFINIT flat-header) # Put all those targets under object-libs-targets folder name so that they appear all together in IDE. set_target_properties(${config_target}_REFINIT PROPERTIES FOLDER object-libs-targets) @@ -172,10 +200,8 @@ function(generate_config_targets config_target) # in get-noopt-cflags-for target_compile_options(${config_target}_REFKERN PRIVATE ${CTHREADFLAGS}) endif() - if(BUILD_SHARED_LIBS) - # Equivalent to CPICFLAGS in get-noopt-cflags-for - set_target_properties(${config_target}_REFKERN PROPERTIES POSITION_INDEPENDENT_CODE ON) - endif() + # Equivalent to CPICFLAGS in get-noopt-cflags-for + set_target_properties(${config_target}_REFKERN PROPERTIES POSITION_INDEPENDENT_CODE ON) add_dependencies(${config_target}_REFKERN flat-header) # Put all those targets under object-libs-targets folder name so that they appear all together in IDE. set_target_properties(${config_target}_REFKERN PROPERTIES FOLDER object-libs-targets) diff --git a/config/a64fx/bli_cntx_init_a64fx.c b/config/a64fx/bli_cntx_init_a64fx.c index 5061570f80..5132b2824c 100644 --- a/config/a64fx/bli_cntx_init_a64fx.c +++ b/config/a64fx/bli_cntx_init_a64fx.c @@ -49,29 +49,32 @@ void bli_cntx_init_a64fx( cntx_t* cntx ) // their storage preferences. bli_cntx_set_l3_nat_ukrs ( - 2, - BLIS_GEMM_UKR, BLIS_FLOAT, bli_sgemm_armsve_asm_2vx10_unindexed, FALSE, - BLIS_GEMM_UKR, BLIS_DOUBLE, bli_dgemm_armsve_asm_2vx10_unindexed, FALSE, + 4, + BLIS_GEMM_UKR, BLIS_FLOAT, bli_sgemm_armsve_asm_2vx10_unindexed, FALSE, + BLIS_GEMM_UKR, BLIS_DOUBLE, bli_dgemm_armsve_asm_2vx10_unindexed, FALSE, + BLIS_GEMM_UKR, BLIS_SCOMPLEX, bli_cgemm_armsve_asm_2vx10_unindexed, FALSE, + BLIS_GEMM_UKR, BLIS_DCOMPLEX, bli_zgemm_armsve_asm_2vx10_unindexed, FALSE, cntx ); // Set SVE-512 packing routine. bli_cntx_set_packm_kers ( - 3, + 2, BLIS_PACKM_10XK_KER, BLIS_DOUBLE, bli_dpackm_armsve512_asm_10xk, - BLIS_PACKM_12XK_KER, BLIS_DOUBLE, bli_dpackm_armsve512_asm_12xk, + // 12xk is not used and disabled for GCC 8-9 compatibility. + // BLIS_PACKM_12XK_KER, BLIS_DOUBLE, bli_dpackm_armsve512_int_12xk, BLIS_PACKM_16XK_KER, BLIS_DOUBLE, bli_dpackm_armsve512_asm_16xk, cntx ); // Initialize level-3 blocksize objects with architecture-specific values. // s d c z - bli_blksz_init_easy( &blkszs[ BLIS_MR ], 32, 16, -1, -1 ); - bli_blksz_init_easy( &blkszs[ BLIS_NR ], 10, 10, -1, -1 ); - bli_blksz_init_easy( &blkszs[ BLIS_MC ], 256, 128, -1, -1 ); - bli_blksz_init_easy( &blkszs[ BLIS_KC ], 2048, 2048, -1, -1 ); - bli_blksz_init_easy( &blkszs[ BLIS_NC ], 23040, 26880, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_MR ], 32, 16, 16, 8 ); + bli_blksz_init_easy( &blkszs[ BLIS_NR ], 10, 10, 10, 10 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 256, 128, 192, 96 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], 2048, 2048, 1536, 1536 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 23040, 26880, 11520, 11760 ); // Update the context with the current architecture's register and cache // blocksizes (and multiples) for native execution. diff --git a/config/amdzen/make_defs.cmake b/config/amdzen/make_defs.cmake index ac7d1b506e..89deb14b71 100644 --- a/config/amdzen/make_defs.cmake +++ b/config/amdzen/make_defs.cmake @@ -1,4 +1,36 @@ -##Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. ## +#[=[ + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +]=] # For architecture independent files we still need to define # the required flags. diff --git a/config/arm64/make_defs.mk b/config/arm64/make_defs.mk index e7e1977995..fc1a062e68 100644 --- a/config/arm64/make_defs.mk +++ b/config/arm64/make_defs.mk @@ -65,7 +65,11 @@ CKOPTFLAGS := $(COPTFLAGS) -O3 ifeq ($(CC_VENDOR),gcc) CKVECFLAGS := -march=armv8-a else -$(error gcc is required for this configuration.) +ifeq ($(CC_VENDOR),clang) +CKVECFLAGS := -march=armv8-a +else +$(error gcc or clang is required for this configuration.) +endif endif # Flags specific to reference kernels. diff --git a/config/armsve/bli_armsve_config_utils.c b/config/armsve/bli_armsve_config_utils.c index fdddeebabe..70501e39db 100644 --- a/config/armsve/bli_armsve_config_utils.c +++ b/config/armsve/bli_armsve_config_utils.c @@ -89,4 +89,6 @@ void PASTEMAC(ch, _blksz_armsve) (dim_t *m_r_, dim_t *n_r_, \ EXPANDMAC_BLKSZ_ARMSVE( s, 4 ) EXPANDMAC_BLKSZ_ARMSVE( d, 8 ) +EXPANDMAC_BLKSZ_ARMSVE( c, 8 ) +EXPANDMAC_BLKSZ_ARMSVE( z, 16 ) diff --git a/config/armsve/bli_armsve_config_utils.h b/config/armsve/bli_armsve_config_utils.h index 07aa9ba7d2..87bba73ed5 100644 --- a/config/armsve/bli_armsve_config_utils.h +++ b/config/armsve/bli_armsve_config_utils.h @@ -39,4 +39,6 @@ dim_t bli_vl_bits_armsve(void); void bli_s_blksz_armsve(dim_t *m_r_, dim_t *n_r_, dim_t *k_c_, dim_t *m_c_, dim_t *n_c_); void bli_d_blksz_armsve(dim_t *m_r_, dim_t *n_r_, dim_t *k_c_, dim_t *m_c_, dim_t *n_c_); +void bli_c_blksz_armsve(dim_t *m_r_, dim_t *n_r_, dim_t *k_c_, dim_t *m_c_, dim_t *n_c_); +void bli_z_blksz_armsve(dim_t *m_r_, dim_t *n_r_, dim_t *k_c_, dim_t *m_c_, dim_t *n_c_); diff --git a/config/armsve/bli_cntx_init_armsve.c b/config/armsve/bli_cntx_init_armsve.c index 434979f915..fafed2229b 100644 --- a/config/armsve/bli_cntx_init_armsve.c +++ b/config/armsve/bli_cntx_init_armsve.c @@ -50,17 +50,23 @@ void bli_cntx_init_armsve( cntx_t* cntx ) // Block size. dim_t m_r_s, n_r_s, k_c_s, m_c_s, n_c_s; dim_t m_r_d, n_r_d, k_c_d, m_c_d, n_c_d; + dim_t m_r_c, n_r_c, k_c_c, m_c_c, n_c_c; + dim_t m_r_z, n_r_z, k_c_z, m_c_z, n_c_z; bli_s_blksz_armsve(&m_r_s, &n_r_s, &k_c_s, &m_c_s, &n_c_s); bli_d_blksz_armsve(&m_r_d, &n_r_d, &k_c_d, &m_c_d, &n_c_d); + bli_c_blksz_armsve(&m_r_c, &n_r_c, &k_c_c, &m_c_c, &n_c_c); + bli_z_blksz_armsve(&m_r_z, &n_r_z, &k_c_z, &m_c_z, &n_c_z); // Update the context with optimized native gemm micro-kernels and // their storage preferences. bli_cntx_set_l3_nat_ukrs ( - 2, + 4, // These are vector-length agnostic kernels. Yet knowing mr is required at runtime. - BLIS_GEMM_UKR, BLIS_FLOAT, bli_sgemm_armsve_asm_2vx10_unindexed, FALSE, - BLIS_GEMM_UKR, BLIS_DOUBLE, bli_dgemm_armsve_asm_2vx10_unindexed, FALSE, + BLIS_GEMM_UKR, BLIS_FLOAT, bli_sgemm_armsve_asm_2vx10_unindexed, FALSE, + BLIS_GEMM_UKR, BLIS_DOUBLE, bli_dgemm_armsve_asm_2vx10_unindexed, FALSE, + BLIS_GEMM_UKR, BLIS_SCOMPLEX, bli_cgemm_armsve_asm_2vx10_unindexed, FALSE, + BLIS_GEMM_UKR, BLIS_DCOMPLEX, bli_zgemm_armsve_asm_2vx10_unindexed, FALSE, cntx ); @@ -68,9 +74,8 @@ void bli_cntx_init_armsve( cntx_t* cntx ) if (m_r_d==16) bli_cntx_set_packm_kers ( - 3, + 2, BLIS_PACKM_10XK_KER, BLIS_DOUBLE, bli_dpackm_armsve512_asm_10xk, - BLIS_PACKM_12XK_KER, BLIS_DOUBLE, bli_dpackm_armsve512_asm_12xk, BLIS_PACKM_16XK_KER, BLIS_DOUBLE, bli_dpackm_armsve512_asm_16xk, cntx ); @@ -78,17 +83,17 @@ void bli_cntx_init_armsve( cntx_t* cntx ) bli_cntx_set_packm_kers ( 1, - BLIS_PACKM_8XK_KER, BLIS_DOUBLE, bli_dpackm_armsve256_asm_8xk, + BLIS_PACKM_8XK_KER, BLIS_DOUBLE, bli_dpackm_armsve256_int_8xk, cntx ); // Initialize level-3 blocksize objects with architecture-specific values. // s d c z - bli_blksz_init_easy( &blkszs[ BLIS_MR ], m_r_s, m_r_d, -1, -1 ); - bli_blksz_init_easy( &blkszs[ BLIS_NR ], n_r_s, n_r_d, -1, -1 ); - bli_blksz_init_easy( &blkszs[ BLIS_MC ], m_c_s, m_c_d, -1, -1 ); - bli_blksz_init_easy( &blkszs[ BLIS_KC ], k_c_s, k_c_d, -1, -1 ); - bli_blksz_init_easy( &blkszs[ BLIS_NC ], n_c_s, n_c_d, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_MR ], m_r_s, m_r_d, m_r_c, m_r_z ); + bli_blksz_init_easy( &blkszs[ BLIS_NR ], n_r_s, n_r_d, n_r_c, n_r_z ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], m_c_s, m_c_d, m_c_c, m_c_z ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], k_c_s, k_c_d, k_c_c, k_c_z ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], n_c_s, n_c_d, n_c_c, n_c_z ); // Update the context with the current architecture's register and cache // blocksizes (and multiples) for native execution. diff --git a/config/cortexa53/make_defs.mk b/config/cortexa53/make_defs.mk index 2745e6dc5c..b5b2220a67 100644 --- a/config/cortexa53/make_defs.mk +++ b/config/cortexa53/make_defs.mk @@ -65,7 +65,11 @@ CKOPTFLAGS := $(COPTFLAGS) -O3 -ftree-vectorize ifeq ($(CC_VENDOR),gcc) CKVECFLAGS := -mcpu=cortex-a53 else -$(error gcc is required for this configuration.) +ifeq ($(CC_VENDOR),clang) +CKVECFLAGS := -mcpu=cortex-a53 +else +$(error gcc or clang is required for this configuration.) +endif endif # Flags specific to reference kernels. diff --git a/config/cortexa57/make_defs.mk b/config/cortexa57/make_defs.mk index 2fcb955cc4..83565b8a79 100644 --- a/config/cortexa57/make_defs.mk +++ b/config/cortexa57/make_defs.mk @@ -65,7 +65,11 @@ CKOPTFLAGS := $(COPTFLAGS) -O3 -ftree-vectorize ifeq ($(CC_VENDOR),gcc) CKVECFLAGS := -mcpu=cortex-a57 else -$(error gcc is required for this configuration.) +ifeq ($(CC_VENDOR),clang) +CKVECFLAGS := -mcpu=cortex-a57 +else +$(error gcc or clang is required for this configuration.) +endif endif # Flags specific to reference kernels. diff --git a/config/firestorm/bli_cntx_init_firestorm.c b/config/firestorm/bli_cntx_init_firestorm.c new file mode 100644 index 0000000000..a15ce03448 --- /dev/null +++ b/config/firestorm/bli_cntx_init_firestorm.c @@ -0,0 +1,144 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +void bli_cntx_init_firestorm( cntx_t* cntx ) +{ + blksz_t blkszs[ BLIS_NUM_BLKSZS ]; + blksz_t thresh[ BLIS_NUM_THRESH ]; + + // Set default kernel blocksizes and functions. + bli_cntx_init_firestorm_ref( cntx ); + + // ------------------------------------------------------------------------- + + // Update the context with optimized native gemm micro-kernels and + // their storage preferences. + bli_cntx_set_l3_nat_ukrs + ( + 2, + BLIS_GEMM_UKR, BLIS_FLOAT, bli_sgemm_armv8a_asm_8x12, FALSE, + BLIS_GEMM_UKR, BLIS_DOUBLE, bli_dgemm_armv8a_asm_6x8, FALSE, + cntx + ); + + // Update the context with optimized packm kernels. + bli_cntx_set_packm_kers + ( + 4, + BLIS_PACKM_8XK_KER, BLIS_FLOAT, bli_spackm_armv8a_int_8xk, + BLIS_PACKM_12XK_KER, BLIS_FLOAT, bli_spackm_armv8a_int_12xk, + BLIS_PACKM_6XK_KER, BLIS_DOUBLE, bli_dpackm_armv8a_int_6xk, + BLIS_PACKM_8XK_KER, BLIS_DOUBLE, bli_dpackm_armv8a_int_8xk, + cntx + ); + + // Initialize level-3 blocksize objects with architecture-specific values. + // s d c z + bli_blksz_init_easy( &blkszs[ BLIS_MR ], 8, 6, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_NR ], 12, 8, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 120, 252, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], 640, 3072, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 3072, 8192, -1, -1 ); + + // Update the context with the current architecture's register and cache + // blocksizes (and multiples) for native execution. + bli_cntx_set_blkszs + ( + BLIS_NAT, 5, + BLIS_NC, &blkszs[ BLIS_NC ], BLIS_NR, + BLIS_KC, &blkszs[ BLIS_KC ], BLIS_KR, + BLIS_MC, &blkszs[ BLIS_MC ], BLIS_MR, + BLIS_NR, &blkszs[ BLIS_NR ], BLIS_NR, + BLIS_MR, &blkszs[ BLIS_MR ], BLIS_MR, + cntx + ); + + // ------------------------------------------------------------------------- + + // Initialize sup thresholds with architecture-appropriate values. + // s d c z + bli_blksz_init_easy( &thresh[ BLIS_MT ], -1, 99, -1, -1 ); + bli_blksz_init_easy( &thresh[ BLIS_NT ], -1, 99, -1, -1 ); + bli_blksz_init_easy( &thresh[ BLIS_KT ], -1, 99, -1, -1 ); + + // Initialize the context with the sup thresholds. + bli_cntx_set_l3_sup_thresh + ( + 3, + BLIS_MT, &thresh[ BLIS_MT ], + BLIS_NT, &thresh[ BLIS_NT ], + BLIS_KT, &thresh[ BLIS_KT ], + cntx + ); + + // Update the context with optimized small/unpacked gemm kernels. + bli_cntx_set_l3_sup_kers + ( + 8, + BLIS_RRR, BLIS_DOUBLE, bli_dgemmsup_rv_armv8a_asm_6x8m, TRUE, + BLIS_RRC, BLIS_DOUBLE, bli_dgemmsup_rd_armv8a_asm_6x8m, TRUE, + BLIS_RCR, BLIS_DOUBLE, bli_dgemmsup_rv_armv8a_asm_6x8m, TRUE, + BLIS_RCC, BLIS_DOUBLE, bli_dgemmsup_rv_armv8a_asm_6x8n, TRUE, + BLIS_CRR, BLIS_DOUBLE, bli_dgemmsup_rv_armv8a_asm_6x8m, TRUE, + BLIS_CRC, BLIS_DOUBLE, bli_dgemmsup_rd_armv8a_asm_6x8n, TRUE, + BLIS_CCR, BLIS_DOUBLE, bli_dgemmsup_rv_armv8a_asm_6x8n, TRUE, + BLIS_CCC, BLIS_DOUBLE, bli_dgemmsup_rv_armv8a_asm_6x8n, TRUE, + cntx + ); + + // Initialize level-3 sup blocksize objects with architecture-specific + // values. + // s d c z + bli_blksz_init_easy( &blkszs[ BLIS_MR ], -1, 6, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_NR ], -1, 8, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], -1, 240, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], -1, 1024, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], -1, 3072, -1, -1 ); + + // Update the context with the current architecture's register and cache + // blocksizes for small/unpacked level-3 problems. + bli_cntx_set_l3_sup_blkszs + ( + 5, + BLIS_NC, &blkszs[ BLIS_NC ], + BLIS_KC, &blkszs[ BLIS_KC ], + BLIS_MC, &blkszs[ BLIS_MC ], + BLIS_NR, &blkszs[ BLIS_NR ], + BLIS_MR, &blkszs[ BLIS_MR ], + cntx + ); +} + diff --git a/config/firestorm/bli_family_firestorm.h b/config/firestorm/bli_family_firestorm.h new file mode 100644 index 0000000000..4a60ed2f2b --- /dev/null +++ b/config/firestorm/bli_family_firestorm.h @@ -0,0 +1,76 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +//#ifndef BLIS_FAMILY_H +//#define BLIS_FAMILY_H + + +// -- MEMORY ALLOCATION -------------------------------------------------------- + +#define BLIS_SIMD_ALIGN_SIZE 16 + + +#if 0 +// -- LEVEL-3 MICRO-KERNEL CONSTANTS ------------------------------------------- + +#define BLIS_SGEMM_UKERNEL bli_sgemm_opt_8x12 +#define BLIS_DEFAULT_MR_S 8 +#define BLIS_DEFAULT_NR_S 12 +#define BLIS_DEFAULT_MC_S 120 //1536 //336 //416 // 1280 //160 // 160 // 160 //2048 //336 +#define BLIS_DEFAULT_KC_S 640 //1536 //336 //704 //1280 //672 //528 // 856 //2048 //528 +#define BLIS_DEFAULT_NC_S 3072 + +#define BLIS_DGEMM_UKERNEL bli_dgemm_opt_6x8 +#define BLIS_DEFAULT_MR_D 6 +#define BLIS_DEFAULT_NR_D 8 +#define BLIS_DEFAULT_MC_D 120 //1536 //160 //80 //176 +#define BLIS_DEFAULT_KC_D 240 //1536 //304 //336 //368 +#define BLIS_DEFAULT_NC_D 3072 + +#define BLIS_DEFAULT_MR_C 8 +#define BLIS_DEFAULT_NR_C 4 +#define BLIS_DEFAULT_MC_C 64 +#define BLIS_DEFAULT_KC_C 128 +#define BLIS_DEFAULT_NC_C 4096 + +#define BLIS_DEFAULT_MR_Z 8 +#define BLIS_DEFAULT_NR_Z 4 +#define BLIS_DEFAULT_MC_Z 64 +#define BLIS_DEFAULT_KC_Z 128 +#define BLIS_DEFAULT_NC_Z 4096 +#endif + + +//#endif + diff --git a/config/firestorm/make_defs.mk b/config/firestorm/make_defs.mk new file mode 100644 index 0000000000..dc4286e6a8 --- /dev/null +++ b/config/firestorm/make_defs.mk @@ -0,0 +1,82 @@ +# +# +# BLIS +# An object-based framework for developing high-performance BLAS-like +# libraries. +# +# Copyright (C) 2014, The University of Texas at Austin +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# - Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# - Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# - Neither the name(s) of the copyright holder(s) nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# + + +# Declare the name of the current configuration and add it to the +# running list of configurations included by common.mk. +THIS_CONFIG := firestorm +#CONFIGS_INCL += $(THIS_CONFIG) + +# +# --- Determine the C compiler and related flags --- +# + +# NOTE: The build system will append these variables with various +# general-purpose/configuration-agnostic flags in common.mk. You +# may specify additional flags here as needed. +CPPROCFLAGS := -D_GNU_SOURCE +CMISCFLAGS := +CPICFLAGS := +CWARNFLAGS := + +ifneq ($(DEBUG_TYPE),off) +CDBGFLAGS := -g +endif + +ifeq ($(DEBUG_TYPE),noopt) +COPTFLAGS := -O0 +else +COPTFLAGS := -O2 -march=armv8-a +endif + +# Flags specific to optimized kernels. +CKOPTFLAGS := $(COPTFLAGS) -O3 -ftree-vectorize +CKVECFLAGS := -march=armv8-a + +# Flags specific to reference kernels. +CROPTFLAGS := $(CKOPTFLAGS) +ifeq ($(CC_VENDOR),gcc) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast +else +ifeq ($(CC_VENDOR),clang) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast +else +CRVECFLAGS := $(CKVECFLAGS) +endif +endif + +# Store all of the variables here to new variables containing the +# configuration name. +$(eval $(call store-make-defs,$(THIS_CONFIG))) + diff --git a/config/generic/make_defs.cmake b/config/generic/make_defs.cmake index d99d08e691..c483904c46 100644 --- a/config/generic/make_defs.cmake +++ b/config/generic/make_defs.cmake @@ -1,4 +1,36 @@ -##Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. ## +#[=[ + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +]=] if(NOT WIN32) if(NOT (DEBUG_TYPE STREQUAL "off")) @@ -19,11 +51,11 @@ else() set(CKOPTFLAGS ${COPTFLAGS} -O3) endif() -if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") +if("${CMAKE_C_COMPILER_ID}" STREQUAL "GNU") # Placeholder in case we want to add gcc-specific flags. -elseif("${CMAKE_CXX_COMPILER_ID}" STREQUAL "icc") +elseif("${CMAKE_C_COMPILER_ID}" STREQUAL "icc") # Placeholder in case we want to add icc-specific flags. -elseif("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang") +elseif("${CMAKE_C_COMPILER_ID}" STREQUAL "Clang") # Placeholder in case we want to add clang-specific flags. else() message(FATAL_ERROR "gcc, icc, or clang is required for this configuration.") @@ -31,9 +63,9 @@ endif() # Flags specific to reference kernels. set(CROPTFLAGS ${CKOPTFLAGS}) -if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") +if("${CMAKE_C_COMPILER_ID}" STREQUAL "GNU") set(CRVECFLAGS ${CKVECFLAGS}) -elseif("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang") +elseif("${CMAKE_C_COMPILER_ID}" STREQUAL "Clang") set(CRVECFLAGS ${CKVECFLAGS}) else() set(CRVECFLAGS ${CKVECFLAGS}) diff --git a/config/zen/amd_config.cmake b/config/zen/amd_config.cmake index df3284d8fb..70fb5b23e4 100644 --- a/config/zen/amd_config.cmake +++ b/config/zen/amd_config.cmake @@ -1,4 +1,36 @@ -##Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. ## +#[=[ + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +]=] if(NOT WIN32) if(NOT (DEBUG_TYPE STREQUAL "off")) @@ -24,11 +56,11 @@ endif() if(MSVC) set(CKVECFLAGS -mavx2 -mfma -mno-fma4 -mno-tbm -mno-xop -mno-lwp) -elseif("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") +elseif("${CMAKE_C_COMPILER_ID}" STREQUAL "GNU") set(CKVECFLAGS -mavx2 -mfpmath=sse -mfma) -elseif("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang") +elseif("${CMAKE_C_COMPILER_ID}" MATCHES "Clang") set(CKVECFLAGS -mavx2 -mfpmath=sse -mfma -mno-fma4 -mno-tbm -mno-xop -mno-lwp) - execute_process(COMMAND ${CMAKE_CXX_COMPILER} --version OUTPUT_VARIABLE clang_full_version_string) + execute_process(COMMAND ${CMAKE_C_COMPILER} --version OUTPUT_VARIABLE clang_full_version_string) string(REGEX MATCH "^[^\n]*" CLANG_VERSION_STRING "${clang_full_version_string}") string(REGEX MATCHALL "(AOCC.LLVM)" CLANG_STRING "${CLANG_VERSION_STRING}") if("${CLANG_STRING}" MATCHES "(AOCC.LLVM)") @@ -40,9 +72,9 @@ endif() # Flags specific to reference kernels. set(CROPTFLAGS ${CKOPTFLAGS}) -if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") +if("${CMAKE_C_COMPILER_ID}" STREQUAL "GNU") set(CRVECFLAGS ${CKVECFLAGS} -funsafe-math-optimizations -ffp-contract=fast) -elseif("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang") +elseif("${CMAKE_C_COMPILER_ID}" MATCHES "Clang") set(CRVECFLAGS ${CKVECFLAGS} -funsafe-math-optimizations -ffp-contract=fast) else() set(CRVECFLAGS ${CKVECFLAGS}) diff --git a/config/zen/amd_config.mk b/config/zen/amd_config.mk index 5ca32b268a..10c3e09491 100644 --- a/config/zen/amd_config.mk +++ b/config/zen/amd_config.mk @@ -1,10 +1,10 @@ # # -# BLIS +# BLIS # An object-based framework for developing high-performance BLAS-like # libraries. # -# Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2021 - 2024, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are diff --git a/config/zen/bli_cntx_init_zen.c b/config/zen/bli_cntx_init_zen.c index d88ea7577e..376f7d87e8 100644 --- a/config/zen/bli_cntx_init_zen.c +++ b/config/zen/bli_cntx_init_zen.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -102,7 +102,13 @@ void bli_cntx_init_zen( cntx_t* cntx ) // Update the context with optimized level-1v kernels. bli_cntx_set_l1v_kers ( - 29, + 40, + // addv + BLIS_ADDV_KER, BLIS_FLOAT, bli_saddv_zen_int, + BLIS_ADDV_KER, BLIS_DOUBLE, bli_daddv_zen_int, + BLIS_ADDV_KER, BLIS_SCOMPLEX, bli_caddv_zen_int, + BLIS_ADDV_KER, BLIS_DCOMPLEX, bli_zaddv_zen_int, + // amaxv BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int, BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int, @@ -134,6 +140,7 @@ void bli_cntx_init_zen( cntx_t* cntx ) // scalv BLIS_SCALV_KER, BLIS_FLOAT, bli_sscalv_zen_int10, BLIS_SCALV_KER, BLIS_DOUBLE, bli_dscalv_zen_int10, + BLIS_SCALV_KER, BLIS_SCOMPLEX, bli_cscalv_zen_int, BLIS_SCALV_KER, BLIS_DCOMPLEX, bli_zscalv_zen_int, // swapv @@ -143,13 +150,19 @@ void bli_cntx_init_zen( cntx_t* cntx ) // copyv BLIS_COPYV_KER, BLIS_FLOAT, bli_scopyv_zen_int, BLIS_COPYV_KER, BLIS_DOUBLE, bli_dcopyv_zen_int, + BLIS_COPYV_KER, BLIS_SCOMPLEX, bli_ccopyv_zen_int, BLIS_COPYV_KER, BLIS_DCOMPLEX, bli_zcopyv_zen_int, // setv BLIS_SETV_KER, BLIS_FLOAT, bli_ssetv_zen_int, BLIS_SETV_KER, BLIS_DOUBLE, bli_dsetv_zen_int, + BLIS_SETV_KER, BLIS_SCOMPLEX, bli_csetv_zen_int, + BLIS_SETV_KER, BLIS_DCOMPLEX, bli_zsetv_zen_int, // scal2v + BLIS_SCAL2V_KER, BLIS_FLOAT, bli_sscal2v_zen_int, + BLIS_SCAL2V_KER, BLIS_DOUBLE, bli_dscal2v_zen_int, + BLIS_SCAL2V_KER, BLIS_SCOMPLEX, bli_cscal2v_zen_int, BLIS_SCAL2V_KER, BLIS_DCOMPLEX, bli_zscal2v_zen_int, cntx ); diff --git a/config/zen/make_defs.cmake b/config/zen/make_defs.cmake index 682434bf52..449f441805 100644 --- a/config/zen/make_defs.cmake +++ b/config/zen/make_defs.cmake @@ -1,4 +1,36 @@ -##Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. ## +#[=[ + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +]=] # Include file containing common flags for all AMD architectures include(${CMAKE_SOURCE_DIR}/config/zen/amd_config.cmake) @@ -14,6 +46,9 @@ if(NOT WIN32) endif() endif() +# Flags specific to LPGEMM kernels. +set(CKLPOPTFLAGS "") + # Flags specific to optimized kernels. # NOTE: The -fomit-frame-pointer option is needed for some kernels because # they make explicit use of the rbp register. @@ -23,16 +58,20 @@ else() set(CKOPTFLAGS ${COPTFLAGS} -fomit-frame-pointer) endif() -if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") +if("${CMAKE_C_COMPILER_ID}" STREQUAL "GNU") list(APPEND CKVECFLAGS -march=znver1) - if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 9.0.0) - list(APPEND CKOPTFLAGS -fno-tree-partial-pre -fno-tree-pre -fno-tree-loop-vectorize -fno-gcse) + if(CMAKE_C_COMPILER_VERSION VERSION_GREATER_EQUAL 9.0.0) + list(APPEND CKLPOPTFLAGS -fno-tree-partial-pre -fno-tree-pre -fno-tree-loop-vectorize -fno-gcse) endif() endif() +if("${CMAKE_C_COMPILER_ID}" STREQUAL "Clang") + list(APPEND CKVECFLAGS -march=znver1) +endif() # clang + # Flags specific to reference kernels. set(CROPTFLAGS ${CKOPTFLAGS}) -if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") +if("${CMAKE_C_COMPILER_ID}" STREQUAL "GNU") set(CRVECFLAGS ${CKVECFLAGS}) else() set(CRVECFLAGS ${CKVECFLAGS}) diff --git a/config/zen/make_defs.mk b/config/zen/make_defs.mk index 4e8896bfb2..ef8a21cff9 100644 --- a/config/zen/make_defs.mk +++ b/config/zen/make_defs.mk @@ -5,7 +5,7 @@ # libraries. # # Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -78,16 +78,23 @@ endif # NOTE: The -fomit-frame-pointer option is needed for some kernels because # they make explicit use of the rbp register. CKOPTFLAGS := $(COPTFLAGS) -fomit-frame-pointer +# Additional flag which is required for lpgemm kernels +CKLPOPTFLAGS := + ifeq ($(CC_VENDOR),gcc) CKVECFLAGS += -march=znver1 GCC_VERSION := $(strip $(shell $(CC) -dumpversion | cut -d. -f1)) ifeq ($(shell test $(GCC_VERSION) -ge 9; echo $$?),0) - CKOPTFLAGS += -fno-tree-partial-pre -fno-tree-pre -fno-tree-loop-vectorize -fno-gcse + CKLPOPTFLAGS += -fno-tree-partial-pre -fno-tree-pre -fno-tree-loop-vectorize -fno-gcse endif endif# gcc +ifeq ($(CC_VENDOR),clang) + CKVECFLAGS += -march=znver1 +endif # clang + # Flags specific to reference kernels. CROPTFLAGS := $(CKOPTFLAGS) ifeq ($(CC_VENDOR),gcc) diff --git a/config/zen2/bli_cntx_init_zen2.c b/config/zen2/bli_cntx_init_zen2.c index c7d8137329..a55e7cdbe2 100644 --- a/config/zen2/bli_cntx_init_zen2.c +++ b/config/zen2/bli_cntx_init_zen2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -117,7 +117,13 @@ void bli_cntx_init_zen2( cntx_t* cntx ) // Update the context with optimized level-1v kernels. bli_cntx_set_l1v_kers ( - 29, + 40, + // addv + BLIS_ADDV_KER, BLIS_FLOAT, bli_saddv_zen_int, + BLIS_ADDV_KER, BLIS_DOUBLE, bli_daddv_zen_int, + BLIS_ADDV_KER, BLIS_SCOMPLEX, bli_caddv_zen_int, + BLIS_ADDV_KER, BLIS_DCOMPLEX, bli_zaddv_zen_int, + // amaxv BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int, BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int, @@ -149,6 +155,7 @@ void bli_cntx_init_zen2( cntx_t* cntx ) // scalv BLIS_SCALV_KER, BLIS_FLOAT, bli_sscalv_zen_int10, BLIS_SCALV_KER, BLIS_DOUBLE, bli_dscalv_zen_int10, + BLIS_SCALV_KER, BLIS_SCOMPLEX, bli_cscalv_zen_int, BLIS_SCALV_KER, BLIS_DCOMPLEX, bli_zscalv_zen_int, // swapv @@ -158,13 +165,19 @@ void bli_cntx_init_zen2( cntx_t* cntx ) // copyv BLIS_COPYV_KER, BLIS_FLOAT, bli_scopyv_zen_int, BLIS_COPYV_KER, BLIS_DOUBLE, bli_dcopyv_zen_int, + BLIS_COPYV_KER, BLIS_SCOMPLEX, bli_ccopyv_zen_int, BLIS_COPYV_KER, BLIS_DCOMPLEX, bli_zcopyv_zen_int, // setv BLIS_SETV_KER, BLIS_FLOAT, bli_ssetv_zen_int, BLIS_SETV_KER, BLIS_DOUBLE, bli_dsetv_zen_int, + BLIS_SETV_KER, BLIS_SCOMPLEX, bli_csetv_zen_int, + BLIS_SETV_KER, BLIS_DCOMPLEX, bli_zsetv_zen_int, // scal2v + BLIS_SCAL2V_KER, BLIS_FLOAT, bli_sscal2v_zen_int, + BLIS_SCAL2V_KER, BLIS_DOUBLE, bli_dscal2v_zen_int, + BLIS_SCAL2V_KER, BLIS_SCOMPLEX, bli_cscal2v_zen_int, BLIS_SCAL2V_KER, BLIS_DCOMPLEX, bli_zscal2v_zen_int, cntx ); diff --git a/config/zen2/make_defs.cmake b/config/zen2/make_defs.cmake index 2296a3d2c2..dfd5624e66 100644 --- a/config/zen2/make_defs.cmake +++ b/config/zen2/make_defs.cmake @@ -1,4 +1,36 @@ -##Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. ## +#[=[ + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +]=] # Include file containing common flags for all AMD architectures include(${CMAKE_SOURCE_DIR}/config/zen/amd_config.cmake) @@ -14,6 +46,9 @@ if(NOT WIN32) endif() endif() +# Flags specific to LPGEMM kernels. +set(CKLPOPTFLAGS "") + # Flags specific to optimized kernels. # NOTE: The -fomit-frame-pointer option is needed for some kernels because # they make explicit use of the rbp register. @@ -24,11 +59,11 @@ else() endif() # gcc or clang version must be at least 4.0 -if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") - if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 9.0.0) +if("${CMAKE_C_COMPILER_ID}" STREQUAL "GNU") + if(CMAKE_C_COMPILER_VERSION VERSION_GREATER_EQUAL 9.0.0) # gcc 9.0 or later list(APPEND CKVECFLAGS -march=znver2) - list(APPEND CKOPTFLAGS -fno-tree-partial-pre -fno-tree-pre -fno-tree-loop-vectorize -fno-gcse) + list(APPEND CKLPOPTFLAGS -fno-tree-partial-pre -fno-tree-pre -fno-tree-loop-vectorize -fno-gcse) else() # If gcc is older than 9.1.0 but at least 6.1.0, then we can use -march=znver1 # as the fallback option. @@ -37,7 +72,7 @@ if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") endif() endif() # gcc -if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang") +if("${CMAKE_C_COMPILER_ID}" STREQUAL "Clang") # AOCC clang has various formats for the version line # AOCC.LLVM.2.0.0.B191.2019_07_19 clang version 8.0.0 (CLANG: Jenkins AOCC_2_0_0-Build#191) (based on LLVM AOCC.LLVM.2.0.0.B191.2019_07_19) # AOCC.LLVM.2.1.0.B1030.2019_11_12 clang version 9.0.0 (CLANG: Build#1030) (based on LLVM AOCC.LLVM.2.1.0.B1030.2019_11_12) @@ -49,7 +84,7 @@ if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang") # For our purpose we just want to know if it version 2x or 3x or 4x # But also set these in case we are using upstream LLVM clang - execute_process(COMMAND ${CMAKE_CXX_COMPILER} --version OUTPUT_VARIABLE clang_full_version_string) + execute_process(COMMAND ${CMAKE_C_COMPILER} --version OUTPUT_VARIABLE clang_full_version_string) string(REGEX MATCH "^[^\n]*" CLANG_VERSION_STRING "${clang_full_version_string}") string(REGEX MATCHALL "(AOCC_2|AOCC_3|AOCC_4|AOCC|LLVM|clang)" CLANG_STRING "${CLANG_VERSION_STRING}") string(REGEX REPLACE ".*clang version ([0-9]+\\.[0-9]+).*" "\\1" CLANG_VERSION "${CLANG_VERSION_STRING}") @@ -63,7 +98,7 @@ if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang") elseif("${CLANG_STRING}" MATCHES "(AOCC_2|LLVM)") # AOCC version 2x we will enable znver2 list(APPEND CKVECFLAGS -march=znver2) - elseif(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 9.0.0) + elseif(CMAKE_C_COMPILER_VERSION VERSION_GREATER_EQUAL 9.0.0) # LLVM clang 9.0 or later list(APPEND CKVECFLAGS -march=znver2) else() diff --git a/config/zen2/make_defs.mk b/config/zen2/make_defs.mk index b54ebda881..995fb8c644 100644 --- a/config/zen2/make_defs.mk +++ b/config/zen2/make_defs.mk @@ -5,7 +5,7 @@ # libraries. # # Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2019 - 2024, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -74,6 +74,8 @@ endif # NOTE: The -fomit-frame-pointer option is needed for some kernels because # they make explicit use of the rbp register. CKOPTFLAGS := $(COPTFLAGS) -fomit-frame-pointer +# Additional flag which is required for lpgemm kernels +CKLPOPTFLAGS := # gcc or clang version must be at least 4.0 ifeq ($(CC_VENDOR),gcc) @@ -82,7 +84,7 @@ ifeq ($(CC_VENDOR),gcc) ifeq ($(shell test $(GCC_VERSION) -ge 9; echo $$?),0) # gcc 9.0 or later CKVECFLAGS += -march=znver2 - CKOPTFLAGS += -fno-tree-partial-pre -fno-tree-pre -fno-tree-loop-vectorize -fno-gcse + CKLPOPTFLAGS += -fno-tree-partial-pre -fno-tree-pre -fno-tree-loop-vectorize -fno-gcse else # If gcc is older than 9.1.0 but at least 6.1.0, then we can use -march=znver1 # as the fallback option. diff --git a/config/zen3/bli_cntx_init_zen3.c b/config/zen3/bli_cntx_init_zen3.c index b5b99eb609..d356c2eb9f 100644 --- a/config/zen3/bli_cntx_init_zen3.c +++ b/config/zen3/bli_cntx_init_zen3.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -120,7 +120,13 @@ void bli_cntx_init_zen3( cntx_t* cntx ) // Update the context with optimized level-1v kernels. bli_cntx_set_l1v_kers ( - 29, + 40, + // addv + BLIS_ADDV_KER, BLIS_FLOAT, bli_saddv_zen_int, + BLIS_ADDV_KER, BLIS_DOUBLE, bli_daddv_zen_int, + BLIS_ADDV_KER, BLIS_SCOMPLEX, bli_caddv_zen_int, + BLIS_ADDV_KER, BLIS_DCOMPLEX, bli_zaddv_zen_int, + // amaxv BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int, BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int, @@ -152,6 +158,7 @@ void bli_cntx_init_zen3( cntx_t* cntx ) // scalv BLIS_SCALV_KER, BLIS_FLOAT, bli_sscalv_zen_int10, BLIS_SCALV_KER, BLIS_DOUBLE, bli_dscalv_zen_int10, + BLIS_SCALV_KER, BLIS_SCOMPLEX, bli_cscalv_zen_int, BLIS_SCALV_KER, BLIS_DCOMPLEX, bli_zscalv_zen_int, // swapv @@ -161,13 +168,19 @@ void bli_cntx_init_zen3( cntx_t* cntx ) // copyv BLIS_COPYV_KER, BLIS_FLOAT, bli_scopyv_zen_int, BLIS_COPYV_KER, BLIS_DOUBLE, bli_dcopyv_zen_int, + BLIS_COPYV_KER, BLIS_SCOMPLEX, bli_ccopyv_zen_int, BLIS_COPYV_KER, BLIS_DCOMPLEX, bli_zcopyv_zen_int, // setv BLIS_SETV_KER, BLIS_FLOAT, bli_ssetv_zen_int, BLIS_SETV_KER, BLIS_DOUBLE, bli_dsetv_zen_int, + BLIS_SETV_KER, BLIS_SCOMPLEX, bli_csetv_zen_int, + BLIS_SETV_KER, BLIS_DCOMPLEX, bli_zsetv_zen_int, // scal2v + BLIS_SCAL2V_KER, BLIS_FLOAT, bli_sscal2v_zen_int, + BLIS_SCAL2V_KER, BLIS_DOUBLE, bli_dscal2v_zen_int, + BLIS_SCAL2V_KER, BLIS_SCOMPLEX, bli_cscal2v_zen_int, BLIS_SCAL2V_KER, BLIS_DCOMPLEX, bli_zscal2v_zen_int, cntx ); diff --git a/config/zen3/make_defs.cmake b/config/zen3/make_defs.cmake index 077deb68c3..adb808ce42 100644 --- a/config/zen3/make_defs.cmake +++ b/config/zen3/make_defs.cmake @@ -1,4 +1,36 @@ -##Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. ## +#[=[ + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +]=] # FLAGS that are specific to the 'zen3' architecture are added here. # FLAGS that are common for all the AMD architectures are present in @@ -20,6 +52,9 @@ if(NOT WIN32) endif() endif() +# Flags specific to LPGEMM kernels. +set(CKLPOPTFLAGS "") + # Flags specific to optimized kernels. # NOTE: The -fomit-frame-pointer option is needed for some kernels because # they make explicit use of the rbp register. @@ -29,20 +64,20 @@ else() set(CKOPTFLAGS ${COPTFLAGS} -fomit-frame-pointer) endif() -if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") - if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 11.0.0) +if("${CMAKE_C_COMPILER_ID}" STREQUAL "GNU") + if(CMAKE_C_COMPILER_VERSION VERSION_GREATER_EQUAL 11.0.0) # gcc 11.0 or later list(APPEND CKVECFLAGS -march=znver3) - # Update CKOPTFLAGS for gcc to use O3 optimization without + # Update CKLPOPTFLAGS for gcc to use O3 optimization without # -ftree-pre and -ftree-partial-pre flag. These flag results # in suboptimal code generation for instrinsic based kernels. # The -ftree-loop-vectorize results in inefficient code gen # for amd optimized l1 kernels based on instrinsics. - list(APPEND CKOPTFLAGS -fno-tree-partial-pre -fno-tree-pre -fno-tree-loop-vectorize -fno-gcse) - elseif(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 9.0.0) + list(APPEND CKLPOPTFLAGS -fno-tree-partial-pre -fno-tree-pre -fno-tree-loop-vectorize -fno-gcse) + elseif(CMAKE_C_COMPILER_VERSION VERSION_GREATER_EQUAL 9.0.0) # gcc 9.0 or later list(APPEND CKVECFLAGS -march=znver2) - list(APPEND CKOPTFLAGS -fno-tree-partial-pre -fno-tree-pre -fno-tree-loop-vectorize -fno-gcse) + list(APPEND CKLPOPTFLAGS -fno-tree-partial-pre -fno-tree-pre -fno-tree-loop-vectorize -fno-gcse) else() # If gcc is older than 9.1.0 but at least 6.1.0, then we can use -march=znver1 # as the fallback option. @@ -51,7 +86,7 @@ if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") endif() endif() -if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang") +if("${CMAKE_C_COMPILER_ID}" STREQUAL "Clang") # AOCC clang has various formats for the version line # AOCC.LLVM.2.0.0.B191.2019_07_19 clang version 8.0.0 (CLANG: Jenkins AOCC_2_0_0-Build#191) (based on LLVM AOCC.LLVM.2.0.0.B191.2019_07_19) # AOCC.LLVM.2.1.0.B1030.2019_11_12 clang version 9.0.0 (CLANG: Build#1030) (based on LLVM AOCC.LLVM.2.1.0.B1030.2019_11_12) @@ -63,7 +98,7 @@ if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang") # For our purpose we just want to know if it version 2x or 3x or 4x # But also set these in case we are using upstream LLVM clang - execute_process(COMMAND ${CMAKE_CXX_COMPILER} --version OUTPUT_VARIABLE clang_full_version_string) + execute_process(COMMAND ${CMAKE_C_COMPILER} --version OUTPUT_VARIABLE clang_full_version_string) string(REGEX MATCH "^[^\n]*" CLANG_VERSION_STRING "${clang_full_version_string}") string(REGEX MATCHALL "(AOCC_2|AOCC_3|AOCC_4|AOCC|LLVM|clang)" CLANG_STRING "${CLANG_VERSION_STRING}") string(REGEX REPLACE ".*clang version ([0-9]+\\.[0-9]+).*" "\\1" CLANG_VERSION "${CLANG_VERSION_STRING}") @@ -77,7 +112,7 @@ if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang") elseif("${CLANG_STRING}" MATCHES "(AOCC_2|LLVM)") # AOCC version 2x we will enable znver2 list(APPEND CKVECFLAGS -march=znver2) - elseif(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 9.0.0) + elseif(CMAKE_C_COMPILER_VERSION VERSION_GREATER_EQUAL 9.0.0) # LLVM clang 9.0 or later list(APPEND CKVECFLAGS -march=znver2) else() diff --git a/config/zen3/make_defs.mk b/config/zen3/make_defs.mk index 727be9d603..d1943f6ac9 100644 --- a/config/zen3/make_defs.mk +++ b/config/zen3/make_defs.mk @@ -5,7 +5,7 @@ # libraries. # # Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2019 - 2024, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -74,6 +74,8 @@ endif # NOTE: The -fomit-frame-pointer option is needed for some kernels because # they make explicit use of the rbp register. CKOPTFLAGS := $(COPTFLAGS) -fomit-frame-pointer +# Additional flag which is required for lpgemm kernels +CKLPOPTFLAGS := # gcc or clang version must be at least 4.0 ifeq ($(CC_VENDOR),gcc) @@ -87,11 +89,11 @@ ifeq ($(CC_VENDOR),gcc) # in suboptimal code generation for instrinsic based kernels. # The -ftree-loop-vectorize results in inefficient code gen # for amd optimized l1 kernels based on instrinsics. - CKOPTFLAGS += -fno-tree-partial-pre -fno-tree-pre -fno-tree-loop-vectorize -fno-gcse + CKLPOPTFLAGS += -fno-tree-partial-pre -fno-tree-pre -fno-tree-loop-vectorize -fno-gcse else ifeq ($(shell test $(GCC_VERSION) -ge 9; echo $$?),0) # gcc 9.0 or later CKVECFLAGS += -march=znver2 - CKOPTFLAGS += -fno-tree-partial-pre -fno-tree-pre -fno-tree-loop-vectorize -fno-gcse + CKLPOPTFLAGS += -fno-tree-partial-pre -fno-tree-pre -fno-tree-loop-vectorize -fno-gcse else # If gcc is older than 9.1.0 but at least 6.1.0, then we can use -march=znver1 # as the fallback option. diff --git a/config/zen4/bli_cntx_init_zen4.c b/config/zen4/bli_cntx_init_zen4.c index 8a79ff8a1f..e13ebf7590 100644 --- a/config/zen4/bli_cntx_init_zen4.c +++ b/config/zen4/bli_cntx_init_zen4.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -44,26 +44,22 @@ bli_blksz_init_easy( &blkszs[ BLIS_MR ], 32, 32, 3, 12 ); \ bli_blksz_init_easy( &blkszs[ BLIS_NR ], 12, 6, 8, 4 ); \ bli_blksz_init_easy( &blkszs[ BLIS_MC ], 512, 128, 144, 60 ); \ - bli_blksz_init ( &blkszs[ BLIS_KC ], 480, 512, 256, 512, \ - 480, 320, 256, 160 ); \ + bli_blksz_init_easy( &blkszs[ BLIS_KC ], 480, 512, 256, 512 ); \ bli_blksz_init_easy( &blkszs[ BLIS_NC ], 6144, 4002, 4080, 2004 ); \ \ bli_blksz_init_easy( &blkszs[ BLIS_AF ], 5, 5, -1, -1 ); \ - bli_blksz_init_easy( &blkszs[ BLIS_DF ], 8, 8, -1, -1 ); \ - + bli_blksz_init_easy( &blkszs[ BLIS_DF ], 8, 8, -1, -1 ); #define BLI_CNTX_DEFAULT_BLKSZ_LIST_BERGAMO(blkszs) \ /* s d c z */ \ bli_blksz_init_easy( &blkszs[ BLIS_MR ], 32, 32, 3, 12 ); \ bli_blksz_init_easy( &blkszs[ BLIS_NR ], 12, 6, 8, 4 ); \ bli_blksz_init_easy( &blkszs[ BLIS_MC ], 512, 64, 144, 60 ); \ - bli_blksz_init ( &blkszs[ BLIS_KC ], 480, 512, 256, 512, \ - 480, 320, 256, 160 ); \ + bli_blksz_init_easy( &blkszs[ BLIS_KC ], 480, 512, 256, 512 ); \ bli_blksz_init_easy( &blkszs[ BLIS_NC ], 6144, 3600, 4080, 2004 ); \ \ bli_blksz_init_easy( &blkszs[ BLIS_AF ], 5, 5, -1, -1 ); \ - bli_blksz_init_easy( &blkszs[ BLIS_DF ], 8, 8, -1, -1 ); \ - + bli_blksz_init_easy( &blkszs[ BLIS_DF ], 8, 8, -1, -1 ); void bli_cntx_init_zen4( cntx_t* cntx ) { @@ -156,14 +152,17 @@ void bli_cntx_init_zen4( cntx_t* cntx ) // Update the context with optimized level-1v kernels. bli_cntx_set_l1v_kers ( - 28, + 32, + // addv + BLIS_ADDV_KER, BLIS_DOUBLE, bli_daddv_zen_int_avx512, + // amaxv BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int_avx512, BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int, // axpbyv BLIS_AXPBYV_KER, BLIS_FLOAT, bli_saxpbyv_zen_int10, - BLIS_AXPBYV_KER, BLIS_DOUBLE, bli_daxpbyv_zen_int10, + BLIS_AXPBYV_KER, BLIS_DOUBLE, bli_daxpbyv_zen_int_avx512, BLIS_AXPBYV_KER, BLIS_SCOMPLEX, bli_caxpbyv_zen_int, BLIS_AXPBYV_KER, BLIS_DCOMPLEX, bli_zaxpbyv_zen_int, @@ -171,13 +170,13 @@ void bli_cntx_init_zen4( cntx_t* cntx ) BLIS_AXPYV_KER, BLIS_FLOAT, bli_saxpyv_zen_int_avx512, BLIS_AXPYV_KER, BLIS_DOUBLE, bli_daxpyv_zen_int_avx512, BLIS_AXPYV_KER, BLIS_SCOMPLEX, bli_caxpyv_zen_int5, - BLIS_AXPYV_KER, BLIS_DCOMPLEX, bli_zaxpyv_zen_int5, + BLIS_AXPYV_KER, BLIS_DCOMPLEX, bli_zaxpyv_zen_int_avx512, // dotv BLIS_DOTV_KER, BLIS_FLOAT, bli_sdotv_zen_int_avx512, BLIS_DOTV_KER, BLIS_DOUBLE, bli_ddotv_zen_int_avx512, BLIS_DOTV_KER, BLIS_SCOMPLEX, bli_cdotv_zen_int5, - BLIS_DOTV_KER, BLIS_DCOMPLEX, bli_zdotv_zen_int5, + BLIS_DOTV_KER, BLIS_DCOMPLEX, bli_zdotv_zen_int_avx512, // dotxv BLIS_DOTXV_KER, BLIS_FLOAT, bli_sdotxv_zen_int, @@ -187,22 +186,25 @@ void bli_cntx_init_zen4( cntx_t* cntx ) // scalv BLIS_SCALV_KER, BLIS_FLOAT, bli_sscalv_zen_int_avx512, BLIS_SCALV_KER, BLIS_DOUBLE, bli_dscalv_zen_int_avx512, - BLIS_SCALV_KER, BLIS_DCOMPLEX, bli_zscalv_zen_int, + BLIS_SCALV_KER, BLIS_SCOMPLEX, bli_cscalv_zen_int_avx512, + BLIS_SCALV_KER, BLIS_DCOMPLEX, bli_zscalv_zen_int_avx512, // swapv BLIS_SWAPV_KER, BLIS_FLOAT, bli_sswapv_zen_int8, BLIS_SWAPV_KER, BLIS_DOUBLE, bli_dswapv_zen_int8, // copyv - BLIS_COPYV_KER, BLIS_FLOAT, bli_scopyv_zen_int, - BLIS_COPYV_KER, BLIS_DOUBLE, bli_dcopyv_zen_int, - BLIS_COPYV_KER, BLIS_DCOMPLEX, bli_zcopyv_zen_int, + BLIS_COPYV_KER, BLIS_FLOAT, bli_scopyv_zen4_asm_avx512, + BLIS_COPYV_KER, BLIS_DOUBLE, bli_dcopyv_zen4_asm_avx512, + BLIS_COPYV_KER, BLIS_DCOMPLEX, bli_zcopyv_zen4_asm_avx512, // setv - BLIS_SETV_KER, BLIS_FLOAT, bli_ssetv_zen_int, - BLIS_SETV_KER, BLIS_DOUBLE, bli_dsetv_zen_int, + BLIS_SETV_KER, BLIS_FLOAT, bli_ssetv_zen_int_avx512, + BLIS_SETV_KER, BLIS_DOUBLE, bli_dsetv_zen_int_avx512, + BLIS_SETV_KER, BLIS_DCOMPLEX, bli_zsetv_zen_int_avx512, // scal2v + BLIS_SCAL2V_KER, BLIS_DOUBLE, bli_dscal2v_zen_int_avx512, BLIS_SCAL2V_KER, BLIS_DCOMPLEX, bli_zscal2v_zen_int, cntx ); @@ -359,11 +361,10 @@ void bli_cntx_init_zen4( cntx_t* cntx ) // triangular objects with architecture-specific values. // // s d c z - bli_blksz_init ( &blkszs[ BLIS_MR ], 6, 6, 3, 3, - 9, 9, 3, 3 ); + bli_blksz_init_easy( &blkszs[ BLIS_MR ], 6, 24, 3, 4 ); bli_blksz_init_easy( &blkszs[ BLIS_NR ], 16, 8, 8, 4 ); - bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 72, 72, 36 ); - bli_blksz_init_easy( &blkszs[ BLIS_KC ], 512, 256, 128, 64 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 144, 72, 48 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], 512, 480, 128, 64 ); bli_blksz_init_easy( &blkszs[ BLIS_NC ], 8160, 4080, 2040, 1020 ); // Update the context with the current architecture's register and cache @@ -392,14 +393,14 @@ void bli_cntx_init_zen4( cntx_t* cntx ) BLIS_CCR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16n, TRUE, BLIS_CCC, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16n, TRUE, - BLIS_RRR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, + BLIS_RRR, BLIS_DOUBLE, bli_dgemmsup_rv_zen4_asm_24x8m, FALSE, BLIS_RRC, BLIS_DOUBLE, bli_dgemmsup_rd_haswell_asm_6x8m, TRUE, - BLIS_RCR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, - BLIS_RCC, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, - BLIS_CRR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, + BLIS_RCR, BLIS_DOUBLE, bli_dgemmsup_rv_zen4_asm_24x8m, FALSE, + BLIS_RCC, BLIS_DOUBLE, bli_dgemmsup_rv_zen4_asm_24x8m, FALSE, + BLIS_CRR, BLIS_DOUBLE, bli_dgemmsup_rv_zen4_asm_24x8m, FALSE, BLIS_CRC, BLIS_DOUBLE, bli_dgemmsup_rd_haswell_asm_6x8n, TRUE, - BLIS_CCR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, - BLIS_CCC, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, + BLIS_CCR, BLIS_DOUBLE, bli_dgemmsup_rv_zen4_asm_24x8m, FALSE, + BLIS_CCC, BLIS_DOUBLE, bli_dgemmsup_rv_zen4_asm_24x8m, FALSE, BLIS_RRR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, BLIS_RCR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, @@ -408,14 +409,14 @@ void bli_cntx_init_zen4( cntx_t* cntx ) BLIS_CCR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, BLIS_CCC, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, - BLIS_RRR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, + BLIS_RRR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen4_asm_4x4m, TRUE, BLIS_RRC, BLIS_DCOMPLEX, bli_zgemmsup_rd_zen_asm_3x4m, TRUE, - BLIS_RCR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, - BLIS_RCC, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, - BLIS_CRR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, + BLIS_RCR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen4_asm_4x4m, TRUE, + BLIS_RCC, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen4_asm_4x4m, TRUE, + BLIS_CRR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen4_asm_4x4m, TRUE, BLIS_CRC, BLIS_DCOMPLEX, bli_zgemmsup_rd_zen_asm_3x4n, TRUE, - BLIS_CCR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, - BLIS_CCC, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, + BLIS_CCR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen4_asm_4x4m, TRUE, + BLIS_CCC, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen4_asm_4x4m, TRUE, cntx ); } diff --git a/config/zen4/bli_family_zen4.h b/config/zen4/bli_family_zen4.h index bacf8b62a4..67dedef858 100644 --- a/config/zen4/bli_family_zen4.h +++ b/config/zen4/bli_family_zen4.h @@ -52,11 +52,4 @@ #define BLIS_SMALL_MATRIX_A_THRES_M_SYRK 96 #define BLIS_SMALL_MATRIX_A_THRES_N_SYRK 128 -// -- SIMD config -------------------------------------------------------- - -#define BLIS_SIMD_ALIGN_SIZE 64 - -#define BLIS_SIMD_SIZE 64 -#define BLIS_SIMD_NUM_REGISTERS 32 - #endif diff --git a/config/zen4/make_defs.cmake b/config/zen4/make_defs.cmake index e5ce4401b7..78106bb7e6 100644 --- a/config/zen4/make_defs.cmake +++ b/config/zen4/make_defs.cmake @@ -1,4 +1,36 @@ -##Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. ## +#[=[ + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +]=] # FLAGS that are specific to the 'zen4' architecture are added here. # FLAGS that are common for all the AMD architectures are present in @@ -18,6 +50,9 @@ if(NOT WIN32) endif() endif() +# Flags specific to LPGEMM kernels. +set(CKLPOPTFLAGS "") + # Flags specific to optimized kernels. # NOTE: The -fomit-frame-pointer option is needed for some kernels because # they make explicit use of the rbp register. @@ -27,40 +62,44 @@ else() set(CKOPTFLAGS ${COPTFLAGS} -fomit-frame-pointer) endif() -if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") - if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 13.0.0) +if("${CMAKE_C_COMPILER_ID}" STREQUAL "GNU") + if(CMAKE_C_COMPILER_VERSION VERSION_GREATER_EQUAL 13.0.0) # gcc 13.0 or later list(APPEND CKVECFLAGS -march=znver4) list(APPEND CRVECFLAGS -march=znver4) - # Update CKOPTFLAGS for gcc to use O3 optimization without + # Update CKLPOPTFLAGS for gcc to use O3 optimization without # -ftree-pre and -ftree-partial-pre flag. These flag results # in suboptimal code generation for instrinsic based kernels. # The -ftree-loop-vectorize results in inefficient code gen # for amd optimized l1 kernels based on instrinsics. - list(APPEND CKOPTFLAGS -fno-tree-partial-pre -fno-tree-pre -fno-tree-loop-vectorize) - elseif(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 11.0.0) + list(APPEND CKLPOPTFLAGS -fno-tree-partial-pre -fno-tree-pre -fno-tree-loop-vectorize) + elseif(CMAKE_C_COMPILER_VERSION VERSION_GREATER_EQUAL 11.0.0) # gcc 11.0 or later - list(APPEND CKVECFLAGS -march=znver3 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mavx512bf16) + list(APPEND CKVECFLAGS -march=znver3 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mavx512bf16 -mavx512vbmi) list(APPEND CRVECFLAGS -march=znver3) - list(APPEND CKOPTFLAGS -fno-tree-partial-pre -fno-tree-pre -fno-tree-loop-vectorize) - elseif(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 9.0.0) + list(APPEND CKLPOPTFLAGS -fno-tree-partial-pre -fno-tree-pre -fno-tree-loop-vectorize) + elseif(CMAKE_C_COMPILER_VERSION VERSION_GREATER_EQUAL 9.0.0) # gcc 9.0 or later - list(APPEND CKVECFLAGS -march=znver2 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni) + list(APPEND CKVECFLAGS -march=znver2 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mavx512vbmi) list(APPEND CRVECFLAGS -march=znver2) - list(APPEND CKOPTFLAGS -fno-tree-partial-pre -fno-tree-pre -fno-tree-loop-vectorize) - elseif(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 8.0.0) + list(APPEND CKLPOPTFLAGS -fno-tree-partial-pre -fno-tree-pre -fno-tree-loop-vectorize) + elseif(CMAKE_C_COMPILER_VERSION VERSION_GREATER_EQUAL 8.0.0) # gcc 8.0 or later - list(APPEND CKVECFLAGS -march=znver1 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni) + list(APPEND CKVECFLAGS -march=znver1 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mavx512vbmi) + list(APPEND CRVECFLAGS -march=znver1) + elseif(CMAKE_C_COMPILER_VERSION VERSION_GREATER_EQUAL 7.0.0) + # gcc 7.0 or later + list(APPEND CKVECFLAGS -march=znver1 -mavx512f -mavx512dq -mavx512bw -mavx512vl) list(APPEND CRVECFLAGS -march=znver1) else() - # If gcc is older than 8.0.0 but at least 6.1.0, then we can use -march=znver1 + # If gcc is older than 7.0.0 but at least 6.1.0, then we can use -march=znver1 # as the fallback option. list(APPEND CKVECFLAGS -march=znver1 -mno-avx256-split-unaligned-store) list(APPEND CRVECFLAGS -march=znver1 -mno-avx256-split-unaligned-store) endif() endif() # gcc -if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang") +if("${CMAKE_C_COMPILER_ID}" STREQUAL "Clang") # AOCC clang has various formats for the version line # AOCC.LLVM.2.0.0.B191.2019_07_19 clang version 8.0.0 (CLANG: Jenkins AOCC_2_0_0-Build#191) (based on LLVM AOCC.LLVM.2.0.0.B191.2019_07_19) @@ -72,37 +111,40 @@ if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang") # For our purpose we just want to know if it version 2x or 3x or 4x # But also set these in case we are using upstream LLVM clang - execute_process(COMMAND ${CMAKE_CXX_COMPILER} --version OUTPUT_VARIABLE clang_full_version_string) + execute_process(COMMAND ${CMAKE_C_COMPILER} --version OUTPUT_VARIABLE clang_full_version_string) string(REGEX MATCH "^[^\n]*" CLANG_VERSION_STRING "${clang_full_version_string}") string(REGEX MATCHALL "(AOCC_2|AOCC_3|AOCC_4|AOCC|LLVM|clang)" CLANG_STRING "${CLANG_VERSION_STRING}") string(REGEX REPLACE ".*clang version ([0-9]+\\.[0-9]+).*" "\\1" CLANG_VERSION "${CLANG_VERSION_STRING}") + if(NOT WIN32) + set(alignloops "-falign-loops=64") + endif() if("${CLANG_STRING}" MATCHES "AOCC_4") # AOCC version 4x we will enable znver4 - list(APPEND CKVECFLAGS -march=znver4 -falign-loops=64) + list(APPEND CKVECFLAGS -march=znver4 ${alignloops}) list(APPEND CRVECFLAGS -march=znver4) elseif("${CLANG_STRING}" MATCHES "AOCC_3") # AOCC version 3x we will enable znver3 - list(APPEND CKVECFLAGS -march=znver3 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mavx512bf16 -falign-loops=64) + list(APPEND CKVECFLAGS -march=znver3 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mavx512bf16 -mavx512vbmi ${alignloops}) list(APPEND CRVECFLAGS -march=znver3) elseif("${CLANG_STRING}" MATCHES "(AOCC_2|LLVM)") # AOCC version 2x we will enable znver2 - list(APPEND CKVECFLAGS -march=znver2 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni) + list(APPEND CKVECFLAGS -march=znver2 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mavx512vbmi) list(APPEND CRVECFLAGS -march=znver2) - elseif(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 16.0.0) + elseif(CMAKE_C_COMPILER_VERSION VERSION_GREATER_EQUAL 16.0.0) # LLVM clang 16.0 or later - list(APPEND CKVECFLAGS -march=znver4 -falign-loops=64) + list(APPEND CKVECFLAGS -march=znver4 ${alignloops}) list(APPEND CRVECFLAGS -march=znver4) - elseif(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 13.0.0) + elseif(CMAKE_C_COMPILER_VERSION VERSION_GREATER_EQUAL 13.0.0) # LLVM clang 13.0 or later - list(APPEND CKVECFLAGS -march=znver3 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mavx512bf16 -falign-loops=64) + list(APPEND CKVECFLAGS -march=znver3 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mavx512bf16 -mavx512vbmi ${alignloops}) list(APPEND CRVECFLAGS -march=znver3) - elseif(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 9.0.0) + elseif(CMAKE_C_COMPILER_VERSION VERSION_GREATER_EQUAL 9.0.0) # LLVM clang 9.0 or later - list(APPEND CKVECFLAGS -march=znver2 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mavx512bf16 -falign-loops=64) + list(APPEND CKVECFLAGS -march=znver2 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mavx512bf16 -mavx512vbmi ${alignloops}) list(APPEND CRVECFLAGS -march=znver2) else() - list(APPEND CKVECFLAGS -march=znver1 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -falign-loops=64) + list(APPEND CKVECFLAGS -march=znver1 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mavx512vbmi ${alignloops}) list(APPEND CRVECFLAGS -march=znver1) endif() endif() diff --git a/config/zen4/make_defs.mk b/config/zen4/make_defs.mk index bca80fcc9f..95008d8b6e 100644 --- a/config/zen4/make_defs.mk +++ b/config/zen4/make_defs.mk @@ -4,7 +4,7 @@ # An object-based framework for developing high-performance BLAS-like # libraries. # -# Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -73,6 +73,8 @@ endif # NOTE: The -fomit-frame-pointer option is needed for some kernels because # they make explicit use of the rbp register. CKOPTFLAGS := $(COPTFLAGS) -fomit-frame-pointer +# Additional flag which is required for lpgemm kernels +CKLPOPTFLAGS := # gcc or clang version must be at least 4.0 ifeq ($(CC_VENDOR),gcc) @@ -87,23 +89,27 @@ ifeq ($(CC_VENDOR),gcc) # in suboptimal code generation for instrinsic based kernels. # The -ftree-loop-vectorize results in inefficient code gen # for amd optimized l1 kernels based on instrinsics. - CKOPTFLAGS += -fno-tree-partial-pre -fno-tree-pre -fno-tree-loop-vectorize + CKLPOPTFLAGS += -fno-tree-partial-pre -fno-tree-pre -fno-tree-loop-vectorize else ifeq ($(shell test $(GCC_VERSION) -ge 11; echo $$?),0) # gcc 11.0 or later - CKVECFLAGS += -march=znver3 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mavx512bf16 + CKVECFLAGS += -march=znver3 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mavx512bf16 -mavx512vbmi CRVECFLAGS += -march=znver3 - CKOPTFLAGS += -fno-tree-partial-pre -fno-tree-pre -fno-tree-loop-vectorize + CKLPOPTFLAGS += -fno-tree-partial-pre -fno-tree-pre -fno-tree-loop-vectorize else ifeq ($(shell test $(GCC_VERSION) -ge 9; echo $$?),0) # gcc 9.0 or later - CKVECFLAGS += -march=znver2 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni + CKVECFLAGS += -march=znver2 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mavx512vbmi CRVECFLAGS += -march=znver2 - CKOPTFLAGS += -fno-tree-partial-pre -fno-tree-pre -fno-tree-loop-vectorize + CKLPOPTFLAGS += -fno-tree-partial-pre -fno-tree-pre -fno-tree-loop-vectorize else ifeq ($(shell test $(GCC_VERSION) -ge 8; echo $$?),0) # gcc 8.0 or later - CKVECFLAGS += -march=znver1 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni + CKVECFLAGS += -march=znver1 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mavx512vbmi + CRVECFLAGS += -march=znver1 + else ifeq ($(shell test $(GCC_VERSION) -ge 7; echo $$?),0) + # gcc 7.0 or later + CKVECFLAGS += -march=znver1 -mavx512f -mavx512dq -mavx512bw -mavx512vl CRVECFLAGS += -march=znver1 else - # If gcc is older than 8.0.0 but at least 6.1.0, then we can use -march=znver1 + # If gcc is older than 7.0.0 but at least 6.1.0, then we can use -march=znver1 # as the fallback option. CKVECFLAGS += -march=znver1 -mno-avx256-split-unaligned-store CRVECFLAGS += -march=znver1 -mno-avx256-split-unaligned-store @@ -132,11 +138,11 @@ ifeq ($(CC_VENDOR),clang) CRVECFLAGS += -march=znver4 else ifeq ($(strip $(shell $(CC) -v |&head -1 |grep -c 'AOCC_3')),1) # AOCC version 3x we will enable znver3 - CKVECFLAGS += -march=znver3 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mavx512bf16 -falign-loops=64 + CKVECFLAGS += -march=znver3 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mavx512bf16 -mavx512vbmi -falign-loops=64 CRVECFLAGS += -march=znver3 else ifeq ($(strip $(shell $(CC) -v |&head -1 |grep -c 'AOCC.LLVM.2\|AOCC_2')),1) # AOCC version 2x we will enable znver2 - CKVECFLAGS += -march=znver2 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni + CKVECFLAGS += -march=znver2 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mavx512vbmi CRVECFLAGS += -march=znver2 else ifeq ($(shell test $(CC_MAJOR) -ge 16; echo $$?),0) # LLVM clang 16.0 or later @@ -144,14 +150,14 @@ ifeq ($(CC_VENDOR),clang) CRVECFLAGS += -march=znver4 else ifeq ($(shell test $(CC_MAJOR) -ge 13; echo $$?),0) # LLVM clang 13.0 or later - CKVECFLAGS += -march=znver3 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mavx512bf16 -falign-loops=64 + CKVECFLAGS += -march=znver3 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mavx512bf16 -mavx512vbmi -falign-loops=64 CRVECFLAGS += -march=znver3 else ifeq ($(shell test $(CC_MAJOR) -ge 9; echo $$?),0) # LLVM clang 9.0 or later - CKVECFLAGS += -march=znver2 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mavx512bf16 -falign-loops=64 + CKVECFLAGS += -march=znver2 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mavx512bf16 -mavx512vbmi -falign-loops=64 CRVECFLAGS += -march=znver2 else - CKVECFLAGS += -march=znver1 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -falign-loops=64 + CKVECFLAGS += -march=znver1 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mavx512vbmi -falign-loops=64 CRVECFLAGS += -march=znver1 endif endif # clang diff --git a/config/zen5/bli_cntx_init_zen5.c b/config/zen5/bli_cntx_init_zen5.c new file mode 100644 index 0000000000..8e0cafcbea --- /dev/null +++ b/config/zen5/bli_cntx_init_zen5.c @@ -0,0 +1,423 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +/* + * List of default block sizes for zen4. + * Converted it to macro as this list is used at multiple places in this file. + */ + +/* Blocksizes for double(d) datetype are tuned for Turin, rest are copied from Genoa */ +#define BLI_CNTX_DEFAULT_BLKSZ_LIST_TURIN(blkszs) \ + /* s d c z */ \ + bli_blksz_init_easy( &blkszs[ BLIS_MR ], 32, 8, 3, 12 ); \ + bli_blksz_init_easy( &blkszs[ BLIS_NR ], 12, 24, 8, 4 ); \ + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 512, 120, 144, 60 ); \ + bli_blksz_init_easy( &blkszs[ BLIS_KC ], 480, 512, 256, 512 ); \ + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 6144, 2016, 4080, 2004 ); \ + \ + bli_blksz_init_easy( &blkszs[ BLIS_AF ], 5, 5, -1, -1 ); \ + bli_blksz_init_easy( &blkszs[ BLIS_DF ], 8, 8, -1, -1 ); + +/* Blocksizes for double(d) datetype are tuned for Turin, rest are copied from Bergamo */ +#define BLI_CNTX_DEFAULT_BLKSZ_LIST_TURIN_DENSE(blkszs) \ + /* s d c z */ \ + bli_blksz_init_easy( &blkszs[ BLIS_MR ], 32, 8, 3, 12 ); \ + bli_blksz_init_easy( &blkszs[ BLIS_NR ], 12, 24, 8, 4 ); \ + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 512, 120, 144, 60 ); \ + bli_blksz_init_easy( &blkszs[ BLIS_KC ], 480, 512, 256, 512 ); \ + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 6144, 2016, 4080, 2004 ); \ + \ + bli_blksz_init_easy( &blkszs[ BLIS_AF ], 5, 5, -1, -1 ); \ + bli_blksz_init_easy( &blkszs[ BLIS_DF ], 8, 8, -1, -1 ); + +void bli_cntx_init_zen5( cntx_t* cntx ) +{ + blksz_t blkszs[ BLIS_NUM_BLKSZS ]; + blksz_t thresh[ BLIS_NUM_THRESH ]; + + // Set default kernel blocksizes and functions. + bli_cntx_init_zen5_ref( cntx ); + + // ------------------------------------------------------------------------- + + // Update the context with optimized native gemm micro-kernels and + // their storage preferences. + bli_cntx_set_l3_nat_ukrs + ( + 13, + // gemm + BLIS_GEMM_UKR, BLIS_FLOAT, bli_sgemm_skx_asm_32x12_l2, FALSE, + BLIS_GEMM_UKR, BLIS_DOUBLE, bli_dgemm_avx512_asm_8x24, TRUE, + BLIS_GEMM_UKR, BLIS_SCOMPLEX, bli_cgemm_haswell_asm_3x8, TRUE, + /*bli_zgemm_zen4_asm_12x4 is a column preferred kernel*/ + BLIS_GEMM_UKR, BLIS_DCOMPLEX, bli_zgemm_zen4_asm_12x4, FALSE, + + // Different GEMM kernels are used for TRSM for zen4 architecture + BLIS_GEMM_FOR_TRSM_UKR, BLIS_FLOAT, bli_sgemm_haswell_asm_6x16, TRUE, + BLIS_GEMM_FOR_TRSM_UKR, BLIS_DOUBLE, bli_dgemm_zen4_asm_8x24, TRUE, + BLIS_GEMM_FOR_TRSM_UKR, BLIS_DCOMPLEX, bli_zgemm_zen4_asm_4x12, TRUE, + + // gemmtrsm_l + BLIS_GEMMTRSM_L_UKR, BLIS_FLOAT, bli_sgemmtrsm_l_haswell_asm_6x16, TRUE, + BLIS_GEMMTRSM_L_UKR, BLIS_DOUBLE, bli_dgemmtrsm_l_zen4_asm_8x24, TRUE, + BLIS_GEMMTRSM_L_UKR, BLIS_DCOMPLEX, bli_zgemmtrsm_l_zen4_asm_4x12, TRUE, + // gemmtrsm_u + BLIS_GEMMTRSM_U_UKR, BLIS_FLOAT, bli_sgemmtrsm_u_haswell_asm_6x16, TRUE, + BLIS_GEMMTRSM_U_UKR, BLIS_DOUBLE, bli_dgemmtrsm_u_zen4_asm_8x24, TRUE, + BLIS_GEMMTRSM_U_UKR, BLIS_DCOMPLEX, bli_zgemmtrsm_u_zen4_asm_4x12, TRUE, + cntx + ); + + // Update the context with architecture specific threshold functions + bli_cntx_set_l3_thresh_funcs + ( + 3, + // GEMM + BLIS_GEMM, bli_cntx_gemmsup_thresh_is_met_zen5, + // GEMMT + BLIS_GEMMT, bli_cntx_gemmtsup_thresh_is_met_zen, + // SYRK + BLIS_SYRK, bli_cntx_syrksup_thresh_is_met_zen, + cntx + ); + + // Update the context with optimized packm kernels. + bli_cntx_set_packm_kers + ( + 11, + BLIS_PACKM_6XK_KER, BLIS_FLOAT, bli_spackm_haswell_asm_6xk, + BLIS_PACKM_16XK_KER, BLIS_FLOAT, bli_spackm_haswell_asm_16xk, + BLIS_PACKM_6XK_KER, BLIS_DOUBLE, bli_dpackm_haswell_asm_6xk, + BLIS_PACKM_8XK_KER, BLIS_DOUBLE, bli_dpackm_zen4_asm_8xk, + BLIS_PACKM_24XK_KER, BLIS_DOUBLE, bli_dpackm_zen4_asm_24xk, + BLIS_PACKM_32XK_KER, BLIS_DOUBLE, bli_dpackm_zen4_asm_32xk, + BLIS_PACKM_3XK_KER, BLIS_SCOMPLEX, bli_cpackm_haswell_asm_3xk, + BLIS_PACKM_8XK_KER, BLIS_SCOMPLEX, bli_cpackm_haswell_asm_8xk, + BLIS_PACKM_3XK_KER, BLIS_DCOMPLEX, bli_zpackm_haswell_asm_3xk, + BLIS_PACKM_12XK_KER, BLIS_DCOMPLEX, bli_zpackm_zen4_asm_12xk, + BLIS_PACKM_4XK_KER, BLIS_DCOMPLEX, bli_zpackm_zen4_asm_4xk, + cntx + ); + + // Update the context with optimized level-1f kernels. + bli_cntx_set_l1f_kers + ( + 9, + // axpyf + BLIS_AXPYF_KER, BLIS_FLOAT, bli_saxpyf_zen_int_5, + BLIS_AXPYF_KER, BLIS_DOUBLE, bli_daxpyf_zen_int_5, + BLIS_AXPYF_KER, BLIS_SCOMPLEX, bli_caxpyf_zen_int_5, + BLIS_AXPYF_KER, BLIS_DCOMPLEX, bli_zaxpyf_zen_int_5, + // dotxf + BLIS_DOTXF_KER, BLIS_FLOAT, bli_sdotxf_zen_int_8, + BLIS_DOTXF_KER, BLIS_DOUBLE, bli_ddotxf_zen_int_8, + BLIS_DOTXF_KER, BLIS_DCOMPLEX, bli_zdotxf_zen_int_6, + BLIS_DOTXF_KER, BLIS_SCOMPLEX, bli_cdotxf_zen_int_6, + // axpy2v + BLIS_AXPY2V_KER, BLIS_DOUBLE, bli_daxpy2v_zen_int, + cntx + ); + + // Update the context with optimized level-1v kernels. + bli_cntx_set_l1v_kers + ( + 32, + // addv + BLIS_ADDV_KER, BLIS_DOUBLE, bli_daddv_zen_int_avx512, + + // amaxv + BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int_avx512, + BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int, + + // axpbyv + BLIS_AXPBYV_KER, BLIS_FLOAT, bli_saxpbyv_zen_int10, + BLIS_AXPBYV_KER, BLIS_DOUBLE, bli_daxpbyv_zen_int_avx512, + BLIS_AXPBYV_KER, BLIS_SCOMPLEX, bli_caxpbyv_zen_int, + BLIS_AXPBYV_KER, BLIS_DCOMPLEX, bli_zaxpbyv_zen_int, + + // axpyv + BLIS_AXPYV_KER, BLIS_FLOAT, bli_saxpyv_zen_int_avx512, + BLIS_AXPYV_KER, BLIS_DOUBLE, bli_daxpyv_zen_int_avx512, + BLIS_AXPYV_KER, BLIS_SCOMPLEX, bli_caxpyv_zen_int5, + BLIS_AXPYV_KER, BLIS_DCOMPLEX, bli_zaxpyv_zen_int_avx512, + + // dotv + BLIS_DOTV_KER, BLIS_FLOAT, bli_sdotv_zen_int_avx512, + BLIS_DOTV_KER, BLIS_DOUBLE, bli_ddotv_zen_int_avx512, + BLIS_DOTV_KER, BLIS_SCOMPLEX, bli_cdotv_zen_int5, + BLIS_DOTV_KER, BLIS_DCOMPLEX, bli_zdotv_zen_int_avx512, + + // dotxv + BLIS_DOTXV_KER, BLIS_FLOAT, bli_sdotxv_zen_int, + BLIS_DOTXV_KER, BLIS_DOUBLE, bli_ddotxv_zen_int, + BLIS_DOTXV_KER, BLIS_DCOMPLEX, bli_zdotxv_zen_int, + + // scalv + BLIS_SCALV_KER, BLIS_FLOAT, bli_sscalv_zen_int_avx512, + BLIS_SCALV_KER, BLIS_DOUBLE, bli_dscalv_zen_int_avx512, + BLIS_SCALV_KER, BLIS_SCOMPLEX, bli_cscalv_zen_int_avx512, + BLIS_SCALV_KER, BLIS_DCOMPLEX, bli_zscalv_zen_int_avx512, + + // swapv + BLIS_SWAPV_KER, BLIS_FLOAT, bli_sswapv_zen_int8, + BLIS_SWAPV_KER, BLIS_DOUBLE, bli_dswapv_zen_int8, + + // copyv + BLIS_COPYV_KER, BLIS_FLOAT, bli_scopyv_zen4_asm_avx512, + BLIS_COPYV_KER, BLIS_DOUBLE, bli_dcopyv_zen4_asm_avx512, + BLIS_COPYV_KER, BLIS_DCOMPLEX, bli_zcopyv_zen4_asm_avx512, + + // setv + BLIS_SETV_KER, BLIS_FLOAT, bli_ssetv_zen_int_avx512, + BLIS_SETV_KER, BLIS_DOUBLE, bli_dsetv_zen_int_avx512, + BLIS_SETV_KER, BLIS_DCOMPLEX, bli_zsetv_zen_int_avx512, + + // scal2v + BLIS_SCAL2V_KER, BLIS_DOUBLE, bli_dscal2v_zen_int_avx512, + BLIS_SCAL2V_KER, BLIS_DCOMPLEX, bli_zscal2v_zen_int, + cntx + ); + + // Initialize level-3 blocksize objects with architecture-specific values. + // + // These are reference block sizes and may be overridden based on + // number of threads used at runtime. + + if ( bli_init_model_query_id() == BLIS_MODEL_TURIN_DENSE ) + { + BLI_CNTX_DEFAULT_BLKSZ_LIST_TURIN_DENSE(blkszs); + } + else // BLIS_MODEL_DEFAULT choice, also currently used for BLIS_MODEL_TURIN + { + BLI_CNTX_DEFAULT_BLKSZ_LIST_TURIN(blkszs); + } + + // Update the context with the current architecture's register and cache + // blocksizes (and multiples) for native execution. + bli_cntx_set_blkszs + ( + BLIS_NAT, 7, + // level-3 + BLIS_NC, &blkszs[ BLIS_NC ], BLIS_NR, + BLIS_KC, &blkszs[ BLIS_KC ], BLIS_KR, + BLIS_MC, &blkszs[ BLIS_MC ], BLIS_MR, + BLIS_NR, &blkszs[ BLIS_NR ], BLIS_NR, + BLIS_MR, &blkszs[ BLIS_MR ], BLIS_MR, + // level-1f + BLIS_AF, &blkszs[ BLIS_AF ], BLIS_AF, + BLIS_DF, &blkszs[ BLIS_DF ], BLIS_DF, + cntx + ); + + // ------------------------------------------------------------------------- + + // Initialize TRSM blocksize objects with architecture-specific values. + // Using different cache block sizes for TRSM instead of common level-3 block sizes. + // Tuning is done for double-precision only. + // s d c z + bli_blksz_init_easy( &blkszs[ BLIS_MR ], 6, 8, 3, 4 ); + bli_blksz_init_easy( &blkszs[ BLIS_NR ], 16, 24, 8, 12 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 120, 144, 40 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], 256, 512, 256, 512 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 4080, 4008, 4080, 2004 ); + + // Update the context with the current architecture's register and cache + // blocksizes for level-3 TRSM problems. + bli_cntx_set_trsm_blkszs + ( + 5, + // level-3 + BLIS_NC, &blkszs[ BLIS_NC ], + BLIS_KC, &blkszs[ BLIS_KC ], + BLIS_MC, &blkszs[ BLIS_MC ], + BLIS_NR, &blkszs[ BLIS_NR ], + BLIS_MR, &blkszs[ BLIS_MR ], + cntx + ); + + // ------------------------------------------------------------------------- + + // Initialize sup thresholds with architecture-appropriate values. + // s d c z + bli_blksz_init_easy( &thresh[ BLIS_MT ], 682, 1000, 380, 110 ); + bli_blksz_init_easy( &thresh[ BLIS_NT ], 512, 1000, 256, 128 ); + bli_blksz_init_easy( &thresh[ BLIS_KT ], 240, 220, 220, 110 ); + + // Initialize the context with the sup thresholds. + bli_cntx_set_l3_sup_thresh + ( + 3, + BLIS_MT, &thresh[ BLIS_MT ], + BLIS_NT, &thresh[ BLIS_NT ], + BLIS_KT, &thresh[ BLIS_KT ], + cntx + ); + + // Initialize the context with the sup handlers. + bli_cntx_set_l3_sup_handlers + ( + 2, + BLIS_GEMM, bli_gemmsup_ref, + BLIS_GEMMT, bli_gemmtsup_ref, + cntx + ); + + // Update the context with optimized small/unpacked gemm kernels. + bli_cntx_set_l3_sup_kers + ( + 30, + BLIS_RRR, BLIS_DOUBLE, bli_dgemmsup_rv_zen5_asm_24x8m, FALSE, + BLIS_RRC, BLIS_DOUBLE, bli_dgemmsup_rv_zen5_asm_24x8m, FALSE, + BLIS_RCR, BLIS_DOUBLE, bli_dgemmsup_rv_zen5_asm_24x8m, FALSE, + BLIS_RCC, BLIS_DOUBLE, bli_dgemmsup_rv_zen5_asm_24x8m, FALSE, + BLIS_CRR, BLIS_DOUBLE, bli_dgemmsup_rv_zen5_asm_24x8m, FALSE, + BLIS_CRC, BLIS_DOUBLE, bli_dgemmsup_rv_zen5_asm_24x8m, FALSE, + BLIS_CCR, BLIS_DOUBLE, bli_dgemmsup_rv_zen5_asm_24x8m, FALSE, + BLIS_CCC, BLIS_DOUBLE, bli_dgemmsup_rv_zen5_asm_24x8m, FALSE, + + BLIS_RRR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x64m_avx512, TRUE, + BLIS_RRC, BLIS_FLOAT, bli_sgemmsup_rd_zen_asm_6x64m_avx512, TRUE, + BLIS_RCR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x64m_avx512, TRUE, + BLIS_RCC, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x64n_avx512, TRUE, + BLIS_CRR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x64m_avx512, TRUE, + BLIS_CRC, BLIS_FLOAT, bli_sgemmsup_rd_zen_asm_6x64n_avx512, TRUE, + BLIS_CCR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x64n_avx512, TRUE, + BLIS_CCC, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x64n_avx512, TRUE, + + BLIS_RRR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, + BLIS_RCR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, + BLIS_CRR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, + BLIS_RCC, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, + BLIS_CCR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, + BLIS_CCC, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, + + BLIS_RRR, BLIS_DCOMPLEX, bli_zgemmsup_cv_zen4_asm_12x4m, FALSE, + BLIS_RRC, BLIS_DCOMPLEX, bli_zgemmsup_cv_zen4_asm_12x4m, FALSE, + BLIS_RCR, BLIS_DCOMPLEX, bli_zgemmsup_cv_zen4_asm_12x4m, FALSE, + BLIS_RCC, BLIS_DCOMPLEX, bli_zgemmsup_cv_zen4_asm_12x4m, FALSE, + BLIS_CRR, BLIS_DCOMPLEX, bli_zgemmsup_cv_zen4_asm_12x4m, FALSE, + BLIS_CRC, BLIS_DCOMPLEX, bli_zgemmsup_cv_zen4_asm_12x4m, FALSE, + BLIS_CCR, BLIS_DCOMPLEX, bli_zgemmsup_cv_zen4_asm_12x4m, FALSE, + BLIS_CCC, BLIS_DCOMPLEX, bli_zgemmsup_cv_zen4_asm_12x4m, FALSE, + cntx + ); + + // Initialize level-3 sup blocksize objects with architecture-specific + // values. + // s d c z + bli_blksz_init_easy( &blkszs[ BLIS_MR ], 6, 24, 3, 12 ); + bli_blksz_init_easy( &blkszs[ BLIS_NR ], 64, 8, 8, 4 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 192, 144, 72, 48 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], 512, 384, 128, 64 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 8064, 4032, 2040, 1020 ); + + // Update the context with the current architecture's register and cache + // blocksizes for small/unpacked level-3 problems. + bli_cntx_set_l3_sup_blkszs + ( + 5, + // level-3 + BLIS_NC, &blkszs[ BLIS_NC ], + BLIS_KC, &blkszs[ BLIS_KC ], + BLIS_MC, &blkszs[ BLIS_MC ], + BLIS_NR, &blkszs[ BLIS_NR ], + BLIS_MR, &blkszs[ BLIS_MR ], + cntx + ); + + // Initialize level-3 sup blocksize objects for operations dealing with + // triangular objects with architecture-specific values. + // + // s d c z + bli_blksz_init_easy( &blkszs[ BLIS_MR ], 6, 24, 3, 4 ); + bli_blksz_init_easy( &blkszs[ BLIS_NR ], 16, 8, 8, 4 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 144, 72, 48 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], 512, 480, 128, 64 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 8160, 4080, 2040, 1020 ); + + // Update the context with the current architecture's register and cache + // blocksizes (and multiples) for native execution. + bli_cntx_set_l3_sup_tri_blkszs + ( + 5, + // level-3 + BLIS_NC, &blkszs[ BLIS_NC ], + BLIS_KC, &blkszs[ BLIS_KC ], + BLIS_MC, &blkszs[ BLIS_MC ], + BLIS_NR, &blkszs[ BLIS_NR ], + BLIS_MR, &blkszs[ BLIS_MR ], + cntx + ); + + bli_cntx_set_l3_sup_tri_kers + ( + 30, + BLIS_RRR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16m, TRUE, + BLIS_RRC, BLIS_FLOAT, bli_sgemmsup_rd_zen_asm_6x16m, TRUE, + BLIS_RCR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16m, TRUE, + BLIS_RCC, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16n, TRUE, + BLIS_CRR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16m, TRUE, + BLIS_CRC, BLIS_FLOAT, bli_sgemmsup_rd_zen_asm_6x16n, TRUE, + BLIS_CCR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16n, TRUE, + BLIS_CCC, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16n, TRUE, + + BLIS_RRR, BLIS_DOUBLE, bli_dgemmsup_rv_zen5_asm_24x8m, FALSE, + BLIS_RRC, BLIS_DOUBLE, bli_dgemmsup_rd_haswell_asm_6x8m, TRUE, + BLIS_RCR, BLIS_DOUBLE, bli_dgemmsup_rv_zen5_asm_24x8m, FALSE, + BLIS_RCC, BLIS_DOUBLE, bli_dgemmsup_rv_zen5_asm_24x8m, FALSE, + BLIS_CRR, BLIS_DOUBLE, bli_dgemmsup_rv_zen5_asm_24x8m, FALSE, + BLIS_CRC, BLIS_DOUBLE, bli_dgemmsup_rd_haswell_asm_6x8n, TRUE, + BLIS_CCR, BLIS_DOUBLE, bli_dgemmsup_rv_zen5_asm_24x8m, FALSE, + BLIS_CCC, BLIS_DOUBLE, bli_dgemmsup_rv_zen5_asm_24x8m, FALSE, + + BLIS_RRR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, + BLIS_RCR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, + BLIS_CRR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, + BLIS_RCC, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, + BLIS_CCR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, + BLIS_CCC, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, + + BLIS_RRR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen4_asm_4x4m, TRUE, + BLIS_RRC, BLIS_DCOMPLEX, bli_zgemmsup_rd_zen_asm_3x4m, TRUE, + BLIS_RCR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen4_asm_4x4m, TRUE, + BLIS_RCC, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen4_asm_4x4m, TRUE, + BLIS_CRR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, + BLIS_CRC, BLIS_DCOMPLEX, bli_zgemmsup_rd_zen_asm_3x4n, TRUE, + BLIS_CCR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen4_asm_4x4m, TRUE, + BLIS_CCC, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen4_asm_4x4m, TRUE, + cntx + ); +} diff --git a/config/zen5/bli_family_zen5.h b/config/zen5/bli_family_zen5.h new file mode 100644 index 0000000000..25bd14c42e --- /dev/null +++ b/config/zen5/bli_family_zen5.h @@ -0,0 +1,55 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLI_FAMILY_ZEN5_ +#define BLI_FAMILY_ZEN5_ + +// By default, it is effective to parallelize the outer loops. +// Setting these macros to 1 will force JR and IR inner loops +// to be not parallelized. +#define BLIS_THREAD_MAX_IR 1 +#define BLIS_THREAD_MAX_JR 1 + +#define BLIS_ENABLE_SMALL_MATRIX +#define BLIS_ENABLE_SMALL_MATRIX_TRSM + +// This will select the threshold below which small matrix code will be called. +#define BLIS_SMALL_MATRIX_THRES 700 +#define BLIS_SMALL_M_RECT_MATRIX_THRES 160 +#define BLIS_SMALL_K_RECT_MATRIX_THRES 128 + +#define BLIS_SMALL_MATRIX_A_THRES_M_SYRK 96 +#define BLIS_SMALL_MATRIX_A_THRES_N_SYRK 128 + +#endif diff --git a/config/zen5/make_defs.cmake b/config/zen5/make_defs.cmake new file mode 100644 index 0000000000..9f6d4476af --- /dev/null +++ b/config/zen5/make_defs.cmake @@ -0,0 +1,168 @@ +#[=[ + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +]=] + +# FLAGS that are specific to the 'zen5' architecture are added here. +# FLAGS that are common for all the AMD architectures are present in +# config/zen/amd_config.mk. + +# Include file containing common flags for all AMD architectures +include(${CMAKE_SOURCE_DIR}/config/zen/amd_config.cmake) +if(NOT WIN32) + if(NOT (DEBUG_TYPE STREQUAL "off")) + set(CDBGFLAGS -g) + endif() + + if(DEBUG_TYPE STREQUAL "noopt") + set(COPTFLAGS -O0) + else() # off or opt + set(COPTFLAGS -O3) + endif() +endif() + +# Flags specific to LPGEMM kernels. +set(CKLPOPTFLAGS "") + +# Flags specific to optimized kernels. +# NOTE: The -fomit-frame-pointer option is needed for some kernels because +# they make explicit use of the rbp register. +if(MSVC) + set(CKOPTFLAGS ${COPTFLAGS} /Oy) +else() + set(CKOPTFLAGS ${COPTFLAGS} -fomit-frame-pointer) +endif() + +if("${CMAKE_C_COMPILER_ID}" STREQUAL "GNU") + if(CMAKE_C_COMPILER_VERSION VERSION_GREATER_EQUAL 14.0.0) + # gcc 14.0 or later + list(APPEND CKVECFLAGS -march=znver5) + list(APPEND CRVECFLAGS -march=znver5) + # Update CKLPOPTFLAGS for gcc to use O3 optimization without + # -ftree-pre and -ftree-partial-pre flag. These flag results + # in suboptimal code generation for instrinsic based kernels. + # The -ftree-loop-vectorize results in inefficient code gen + # for amd optimized l1 kernels based on instrinsics. + list(APPEND CKLPOPTFLAGS -fno-tree-partial-pre -fno-tree-pre -fno-tree-loop-vectorize) + elseif(CMAKE_C_COMPILER_VERSION VERSION_GREATER_EQUAL 13.0.0) + # gcc 13.0 or later + list(APPEND CKVECFLAGS -march=znver4) + list(APPEND CRVECFLAGS -march=znver4) + # Update CKOPTFLAGS for gcc to use O3 optimization without + # -ftree-pre and -ftree-partial-pre flag. These flag results + # in suboptimal code generation for instrinsic based kernels. + # The -ftree-loop-vectorize results in inefficient code gen + # for amd optimized l1 kernels based on instrinsics. + list(APPEND CKLPOPTFLAGS -fno-tree-partial-pre -fno-tree-pre -fno-tree-loop-vectorize) + elseif(CMAKE_C_COMPILER_VERSION VERSION_GREATER_EQUAL 11.0.0) + # gcc 11.0 or later + list(APPEND CKVECFLAGS -march=znver3 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mavx512bf16 -mavx512vbmi) + list(APPEND CRVECFLAGS -march=znver3) + list(APPEND CKLPOPTFLAGS -fno-tree-partial-pre -fno-tree-pre -fno-tree-loop-vectorize) + elseif(CMAKE_C_COMPILER_VERSION VERSION_GREATER_EQUAL 9.0.0) + # gcc 9.0 or later + list(APPEND CKVECFLAGS -march=znver2 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mavx512vbmi) + list(APPEND CRVECFLAGS -march=znver2) + list(APPEND CKLPOPTFLAGS -fno-tree-partial-pre -fno-tree-pre -fno-tree-loop-vectorize) + elseif(CMAKE_C_COMPILER_VERSION VERSION_GREATER_EQUAL 8.0.0) + # gcc 8.0 or later + list(APPEND CKVECFLAGS -march=znver1 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mavx512vbmi) + list(APPEND CRVECFLAGS -march=znver1) + elseif(CMAKE_C_COMPILER_VERSION VERSION_GREATER_EQUAL 7.0.0) + # gcc 7.0 or later + list(APPEND CKVECFLAGS -march=znver1 -mavx512f -mavx512dq -mavx512bw -mavx512vl) + list(APPEND CRVECFLAGS -march=znver1) + else() + # If gcc is older than 7.0.0 but at least 6.1.0, then we can use -march=znver1 + # as the fallback option. + list(APPEND CKVECFLAGS -march=znver1 -mno-avx256-split-unaligned-store) + list(APPEND CRVECFLAGS -march=znver1 -mno-avx256-split-unaligned-store) + endif() +endif() # gcc + +if("${CMAKE_C_COMPILER_ID}" STREQUAL "Clang") + # AOCC clang has various formats for the version line + + # AOCC.LLVM.2.0.0.B191.2019_07_19 clang version 8.0.0 (CLANG: Jenkins AOCC_2_0_0-Build#191) (based on LLVM AOCC.LLVM.2.0.0.B191.2019_07_19) + # AOCC.LLVM.2.1.0.B1030.2019_11_12 clang version 9.0.0 (CLANG: Build#1030) (based on LLVM AOCC.LLVM.2.1.0.B1030.2019_11_12) + # AMD clang version 10.0.0 (CLANG: AOCC_2.2.0-Build#93 2020_06_25) (based on LLVM Mirror.Version.10.0.0) + # AMD clang version 11.0.0 (CLANG: AOCC_2.3.0-Build#85 2020_11_10) (based on LLVM Mirror.Version.11.0.0) + # AMD clang version 12.0.0 (CLANG: AOCC_3.0.0-Build#2 2020_11_05) (based on LLVM Mirror.Version.12.0.0) + # AMD clang version 14.0.0 (CLANG: AOCC_4.0.0-Build#98 2022_06_15) (based on LLVM Mirror.Version.14.0.0) + # For our purpose we just want to know if it version 2x or 3x or 4x + + # But also set these in case we are using upstream LLVM clang + execute_process(COMMAND ${CMAKE_C_COMPILER} --version OUTPUT_VARIABLE clang_full_version_string) + string(REGEX MATCH "^[^\n]*" CLANG_VERSION_STRING "${clang_full_version_string}") + string(REGEX MATCHALL "(AOCC_2|AOCC_3|AOCC_4|AOCC|LLVM|clang)" CLANG_STRING "${CLANG_VERSION_STRING}") + string(REGEX REPLACE ".*clang version ([0-9]+\\.[0-9]+).*" "\\1" CLANG_VERSION "${CLANG_VERSION_STRING}") + + if(NOT WIN32) + set(alignloops "-falign-loops=64") + endif() + if("${CLANG_STRING}" MATCHES "AOCC_5") + # AOCC version 5x we will enable znver5 + list(APPEND CKVECFLAGS -march=znver5 ${alignloops}) + list(APPEND CRVECFLAGS -march=znver5) + elseif("${CLANG_STRING}" MATCHES "AOCC_4") + # AOCC version 4x we will enable znver4 + list(APPEND CKVECFLAGS -march=znver4 ${alignloops}) + list(APPEND CRVECFLAGS -march=znver4) + elseif("${CLANG_STRING}" MATCHES "AOCC_3") + # AOCC version 3x we will enable znver3 + list(APPEND CKVECFLAGS -march=znver3 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mavx512bf16 -mavx512vbmi ${alignloops}) + list(APPEND CRVECFLAGS -march=znver3) + elseif("${CLANG_STRING}" MATCHES "(AOCC_2|LLVM)") + # AOCC version 2x we will enable znver2 + list(APPEND CKVECFLAGS -march=znver2 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mavx512vbmi) + list(APPEND CRVECFLAGS -march=znver2) + elseif(CMAKE_C_COMPILER_VERSION VERSION_GREATER_EQUAL 16.0.0) + # LLVM clang 16.0 or later + list(APPEND CKVECFLAGS -march=znver4 ${alignloops}) + list(APPEND CRVECFLAGS -march=znver4) + elseif(CMAKE_C_COMPILER_VERSION VERSION_GREATER_EQUAL 13.0.0) + # LLVM clang 13.0 or later + list(APPEND CKVECFLAGS -march=znver3 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mavx512bf16 -mavx512vbmi ${alignloops}) + list(APPEND CRVECFLAGS -march=znver3) + elseif(CMAKE_C_COMPILER_VERSION VERSION_GREATER_EQUAL 9.0.0) + # LLVM clang 9.0 or later + list(APPEND CKVECFLAGS -march=znver2 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mavx512bf16 -mavx512vbmi ${alignloops}) + list(APPEND CRVECFLAGS -march=znver2) + else() + list(APPEND CKVECFLAGS -march=znver1 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mavx512vbmi ${alignloops}) + list(APPEND CRVECFLAGS -march=znver1) + endif() +endif() + +# Flags specific to reference kernels. +set(CROPTFLAGS ${CKOPTFLAGS}) +set(CRVECFLAGS ${CKVECFLAGS}) diff --git a/config/zen5/make_defs.mk b/config/zen5/make_defs.mk new file mode 100644 index 0000000000..1830290373 --- /dev/null +++ b/config/zen5/make_defs.mk @@ -0,0 +1,186 @@ +# +# +# BLIS +# An object-based framework for developing high-performance BLAS-like +# libraries. +# +# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# - Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# - Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# - Neither the name(s) of the copyright holder(s) nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# + +# FLAGS that are specific to the 'zen5' architecture are added here. +# FLAGS that are common for all the AMD architectures are present in +# config/zen/amd_config.mk. + +# Declare the name of the current configuration and add it to the +# running list of configurations included by common.mk. +THIS_CONFIG := zen5 +#CONFIGS_INCL += $(THIS_CONFIG) + +# Include file containing common flags for all AMD architectures +AMD_CONFIG_FILE := amd_config.mk +AMD_CONFIG_PATH := $(BASE_SHARE_PATH)/config/zen +-include $(AMD_CONFIG_PATH)/$(AMD_CONFIG_FILE) + +# +# --- Determine the C compiler and related flags --- +# + +# NOTE: The build system will append these variables with various +# general-purpose/configuration-agnostic flags in common.mk. You +# may specify additional flags here as needed. + +CPPROCFLAGS := +CMISCFLAGS := +CPICFLAGS := +CWARNFLAGS := + +ifneq ($(DEBUG_TYPE),off) + CDBGFLAGS := -g +endif + +ifeq ($(DEBUG_TYPE),noopt) + COPTFLAGS := -O0 +else + COPTFLAGS := -O3 +endif + +# Flags specific to optimized kernels. +# NOTE: The -fomit-frame-pointer option is needed for some kernels because +# they make explicit use of the rbp register. +CKOPTFLAGS := $(COPTFLAGS) -fomit-frame-pointer +# Additional flag which is required for lpgemm kernels +CKLPOPTFLAGS := + +# gcc or clang version must be at least 4.0 +ifeq ($(CC_VENDOR),gcc) + GCC_VERSION := $(strip $(shell $(CC) -dumpversion | cut -d. -f1)) + + ifeq ($(shell test $(GCC_VERSION) -ge 14; echo $$?),0) + # gcc 14.0 or later + CKVECFLAGS += -march=znver5 + CRVECFLAGS += -march=znver5 + # Update CKOPTFLAGS for gcc to use O3 optimization without + # -ftree-pre and -ftree-partial-pre flag. These flag results + # in suboptimal code generation for instrinsic based kernels. + # The -ftree-loop-vectorize results in inefficient code gen + # for amd optimized l1 kernels based on instrinsics. + CKLPOPTFLAGS += -fno-tree-partial-pre -fno-tree-pre -fno-tree-loop-vectorize + else ifeq ($(shell test $(GCC_VERSION) -ge 13; echo $$?),0) + # gcc 13.0 or later + CKVECFLAGS += -march=znver4 + CRVECFLAGS += -march=znver4 + # Update CKOPTFLAGS for gcc to use O3 optimization without + # -ftree-pre and -ftree-partial-pre flag. These flag results + # in suboptimal code generation for instrinsic based kernels. + # The -ftree-loop-vectorize results in inefficient code gen + # for amd optimized l1 kernels based on instrinsics. + CKLPOPTFLAGS += -fno-tree-partial-pre -fno-tree-pre -fno-tree-loop-vectorize + else ifeq ($(shell test $(GCC_VERSION) -ge 11; echo $$?),0) + # gcc 11.0 or later + CKVECFLAGS += -march=znver3 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mavx512bf16 -mavx512vbmi + CRVECFLAGS += -march=znver3 + CKLPOPTFLAGS += -fno-tree-partial-pre -fno-tree-pre -fno-tree-loop-vectorize + else ifeq ($(shell test $(GCC_VERSION) -ge 9; echo $$?),0) + # gcc 9.0 or later + CKVECFLAGS += -march=znver2 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mavx512vbmi + CRVECFLAGS += -march=znver2 + CKLPOPTFLAGS += -fno-tree-partial-pre -fno-tree-pre -fno-tree-loop-vectorize + else ifeq ($(shell test $(GCC_VERSION) -ge 8; echo $$?),0) + # gcc 8.0 or later + CKVECFLAGS += -march=znver1 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mavx512vbmi + CRVECFLAGS += -march=znver1 + else ifeq ($(shell test $(GCC_VERSION) -ge 7; echo $$?),0) + # gcc 7.0 or later + CKVECFLAGS += -march=znver1 -mavx512f -mavx512dq -mavx512bw -mavx512vl + CRVECFLAGS += -march=znver1 + else + # If gcc is older than 7.0.0 but at least 6.1.0, then we can use -march=znver1 + # as the fallback option. + CKVECFLAGS += -march=znver1 -mno-avx256-split-unaligned-store + CRVECFLAGS += -march=znver1 -mno-avx256-split-unaligned-store + endif +endif # gcc + +ifeq ($(CC_VENDOR),clang) + # AOCC clang has various formats for the version line + + # AOCC.LLVM.2.0.0.B191.2019_07_19 clang version 8.0.0 (CLANG: Jenkins AOCC_2_0_0-Build#191) (based on LLVM AOCC.LLVM.2.0.0.B191.2019_07_19) + # AOCC.LLVM.2.1.0.B1030.2019_11_12 clang version 9.0.0 (CLANG: Build#1030) (based on LLVM AOCC.LLVM.2.1.0.B1030.2019_11_12) + # AMD clang version 10.0.0 (CLANG: AOCC_2.2.0-Build#93 2020_06_25) (based on LLVM Mirror.Version.10.0.0) + # AMD clang version 11.0.0 (CLANG: AOCC_2.3.0-Build#85 2020_11_10) (based on LLVM Mirror.Version.11.0.0) + # AMD clang version 12.0.0 (CLANG: AOCC_3.0.0-Build#2 2020_11_05) (based on LLVM Mirror.Version.12.0.0) + # AMD clang version 14.0.0 (CLANG: AOCC_4.0.0-Build#98 2022_06_15) (based on LLVM Mirror.Version.14.0.0) + + # For our purpose we just want to know if it version 2x or 3x or 4x + + # But also set these in case we are using upstream LLVM clang + VENDOR_STRING := $(strip $(shell ${CC_VENDOR} --version | egrep -o '[0-9]+\.[0-9]+\.?[0-9]*')) + CC_MAJOR := $(shell (echo ${VENDOR_STRING} | cut -d. -f1)) + + ifeq ($(strip $(shell $(CC) -v |&head -1 |grep -c 'AOCC_5')),1) + # AOCC version 5x we will enable znver5 + CKVECFLAGS += -march=znver5 -falign-loops=64 + CRVECFLAGS += -march=znver5 + else ifeq ($(strip $(shell $(CC) -v |&head -1 |grep -c 'AOCC_4')),1) + # AOCC version 4x we will enable znver4 + CKVECFLAGS += -march=znver4 -falign-loops=64 + CRVECFLAGS += -march=znver4 + else ifeq ($(strip $(shell $(CC) -v |&head -1 |grep -c 'AOCC_3')),1) + # AOCC version 3x we will enable znver3 + CKVECFLAGS += -march=znver3 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mavx512bf16 -mavx512vbmi -falign-loops=64 + CRVECFLAGS += -march=znver3 + else ifeq ($(strip $(shell $(CC) -v |&head -1 |grep -c 'AOCC.LLVM.2\|AOCC_2')),1) + # AOCC version 2x we will enable znver2 + CKVECFLAGS += -march=znver2 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mavx512vbmi + CRVECFLAGS += -march=znver2 + else ifeq ($(shell test $(CC_MAJOR) -ge 16; echo $$?),0) + # LLVM clang 16.0 or later + CKVECFLAGS += -march=znver4 -falign-loops=64 + CRVECFLAGS += -march=znver4 + else ifeq ($(shell test $(CC_MAJOR) -ge 13; echo $$?),0) + # LLVM clang 13.0 or later + CKVECFLAGS += -march=znver3 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mavx512bf16 -mavx512vbmi -falign-loops=64 + CRVECFLAGS += -march=znver3 + else ifeq ($(shell test $(CC_MAJOR) -ge 9; echo $$?),0) + # LLVM clang 9.0 or later + CKVECFLAGS += -march=znver2 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mavx512bf16 -mavx512vbmi -falign-loops=64 + CRVECFLAGS += -march=znver2 + else + CKVECFLAGS += -march=znver1 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mavx512vbmi -falign-loops=64 + CRVECFLAGS += -march=znver1 + endif +endif # clang + +# Flags specific to reference kernels. +CROPTFLAGS := $(CKOPTFLAGS) +CRVECFLAGS := $(CKVECFLAGS) + +# Store all of the variables here to new variables containing the +# configuration name. +$(eval $(call store-make-defs,$(THIS_CONFIG))) + diff --git a/config_registry b/config_registry index cd0f9bbb68..37d310e4be 100644 --- a/config_registry +++ b/config_registry @@ -11,12 +11,9 @@ x86_64: intel64 amdzen amd64_legacy intel64: skx knl haswell sandybridge penryn generic amd64_legacy: excavator steamroller piledriver bulldozer generic -amdzen: zen4 zen3 zen2 zen generic - -# NOTE: ARM families will remain disabled until runtime hardware detection -# logic is added to BLIS. -#arm64: cortexa57 generic -#arm32: cortexa15 cortexa9 generic +amdzen: zen5 zen4 zen3 zen2 zen generic +arm64: firestorm thunderx2 cortexa57 cortexa53 generic +arm32: cortexa15 cortexa9 generic # Intel architectures. skx: skx/skx/haswell/zen/zen4 @@ -26,6 +23,7 @@ sandybridge: sandybridge penryn: penryn # AMD architectures. +zen5: zen5/zen5/zen4/skx/zen3/zen2/zen/haswell zen4: zen4/zen4/skx/zen3/zen2/zen/haswell zen3: zen3/zen3/zen2/zen/haswell zen2: zen2/zen2/zen/haswell @@ -38,6 +36,7 @@ bulldozer: bulldozer # ARM architectures. armsve: armsve/armsve a64fx: a64fx/armsve +firestorm: firestorm/armv8a thunderx2: thunderx2/armv8a cortexa57: cortexa57/armv8a cortexa53: cortexa53/armv8a diff --git a/configure b/configure index 92a34632bb..d961146193 100755 --- a/configure +++ b/configure @@ -5,7 +5,7 @@ # libraries. # # Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -144,6 +144,12 @@ print_usage() echo " library. If the shared library build is disabled, the" echo " static library build must remain enabled." echo " " + echo " --enable-rpath, --disable-rpath" + echo " " + echo " Enable (disabled by default) setting an install_name for" + echo " dynamic libraries on macOS which starts with @rpath rather" + echo " than the absolute install path." + echo " " echo " -e SYMBOLS, --export-shared[=SYMBOLS]" echo " " echo " Specify the subset of library symbols that are exported" @@ -723,13 +729,21 @@ read_registry_file() if [ "${mem}" != "${mems_mem}" ]; then #clist="${config_registry[$config]}" - clist=$(query_array "config_registry" ${config}) + clisttmp=$(query_array "config_registry" ${config}) # Replace the current config with its constituent config set, # canonicalize whitespace, and then remove duplicate config # set names, if they exist. Finally, update the config registry # with the new config list. - newclist=$(echo -e "${clist}" | sed -e "s/${mem}/${mems_mem}/g") + # NOTE: WE must use substitute_words() rather than a simple sed + # expression because we need to avoid matching partial strings. + # For example, if clist above contains "foo bar barsk" and we use + # sed to substitute "bee boo" as the members of "bar", the + # result would (incorrectly) be "foo bee boo bee boosk", + # which would then get reduced, via rm_duplicate_words(), to + # "foo bee boo boosk". + #newclist=$(echo -e "${clist}" | sed -e "s/${mem}/${mems_mem}/g") + newclist=$(substitute_words "${mem}" "${mems_mem}" "${clisttmp}") newclist=$(canonicalize_ws "${newclist}") newclist=$(rm_duplicate_words "${newclist}") @@ -812,6 +826,13 @@ read_registry_file() # canonicalize whitespace, and then remove duplicate kernel # set names, if they exist. Finally, update the kernel registry # with the new kernel list. + # NOTE: WE must use substitute_words() rather than a simple sed + # expression because we need to avoid matching partial strings. + # For example, if klist above contains "foo bar barsk" and we use + # sed to substitute "bee boo" as the members of "bar", the + # result would (incorrectly) be "foo bee boo bee boosk", + # which would then get reduced, via rm_duplicate_words(), to + # "foo bee boo boosk". #newklist=$(echo -e "${klisttmp}" | sed -e "s/${ker}/${kers_ker}/g") newklist=$(substitute_words "${ker}" "${kers_ker}" "${klisttmp}") newklist=$(canonicalize_ws "${newklist}") @@ -1558,6 +1579,8 @@ check_compiler() echo "${script_name}: checking for blacklisted configurations due to ${cc} ${cc_version}." + # Fixme: check on a64fx, neoverse, and others + # gcc if [ "x${cc_vendor}" = "xgcc" ]; then @@ -1719,7 +1742,7 @@ check_compiler_version_ranges() gcc_older_than_4_9_0='no' gcc_older_than_6_1_0='no' gcc_older_than_9_1_0='no' - + gcc_older_than_11_2_0='no' echo "${script_name}: checking ${cc} ${cc_version} against known consequential version ranges." # gcc @@ -1744,6 +1767,19 @@ check_compiler_version_ranges() echo "${script_name}: note: found ${cc} version older than 9.1." gcc_older_than_9_1_0='yes' fi + + # Check for gcc < 11.2.0 (ie: 11.2 or older). + if [ ${cc_major} -lt 11 ]; then + echo "${script_name}: note: found ${cc} version older than 11.2.0." + gcc_older_than_11_2_0='yes' + else + if [ ${cc_major} -eq 11 ]; then + if [ ${cc_minor} -lt 2 ]; then + echo "${script_name}: note: found ${cc} version older than 11.2.0." + gcc_older_than_11_2_0='yes' + fi + fi + fi fi # icc @@ -1877,7 +1913,7 @@ set_default_version() echo "${script_name}: determining default version string." # Use what's in the version file as-is. - version="AOCL-BLIS $(cat "${version_file}") Build $(date +%Y%m%d)" + version="AOCL-BLAS $(cat "${version_file}") Build $(date +%Y%m%d)" } @@ -2061,6 +2097,7 @@ main() enable_arg_max_hack='no' enable_static='yes' enable_shared='yes' + enable_rpath='no' export_shared='public' enable_pba_pools='yes' enable_sba_pools='yes' @@ -2194,6 +2231,12 @@ main() disable-shared) enable_shared='no' ;; + enable-rpath) + enable_rpath='yes' + ;; + disable-rpath) + enable_rpath='no' + ;; export-shared=*) export_shared=${OPTARG#*=} ;; @@ -2813,7 +2856,7 @@ main() # Based on the number of sub-configurations, set default value for disable_blis_arch_type # (if user hasn't set option). BLIS_ARCH_TYPE functionality only makes sense for use with - # processor families containing multiple sub-configurations, but user can force the + # processor families containing multiple sub-configurations, but user can force the # functionality to be enabled/disabled with --enable-blis-arch-type/--disable-blis-arch-type # configure options. if [ "x${disable_blis_arch_type}" = "xunset" ]; then @@ -3090,7 +3133,9 @@ main() echo "${script_name}: internal memory pools for small blocks are enabled." enable_sba_pools_01=1 else - echo "${script_name}: internal memory pools for small blocks are disabled." + #echo "${script_name}: internal memory pools for small blocks are disabled." + echo "${script_name}: *** disabling memory pools for small blocks is currently disabled, awaiting fixes to this functionality." + exit 1 enable_sba_pools_01=0 fi if [ "x${enable_mem_tracing}" = "xyes" ]; then @@ -3474,6 +3519,7 @@ main() | sed -e "s/@gcc_older_than_4_9_0@/${gcc_older_than_4_9_0}/g" \ | sed -e "s/@gcc_older_than_6_1_0@/${gcc_older_than_6_1_0}/g" \ | sed -e "s/@gcc_older_than_9_1_0@/${gcc_older_than_9_1_0}/g" \ + | sed -e "s/@gcc_older_than_11_2_0@/${gcc_older_than_11_2_0}/g" \ | sed -e "s/@CC@/${cc_esc}/g" \ | sed -e "s/@CXX@/${cxx_esc}/g" \ | sed -e "s/@RANLIB@/${ranlib_esc}/g" \ @@ -3495,6 +3541,7 @@ main() | sed -e "s/@enable_arg_max_hack@/${enable_arg_max_hack}/g" \ | sed -e "s/@enable_static@/${enable_static}/g" \ | sed -e "s/@enable_shared@/${enable_shared}/g" \ + | sed -e "s/@enable_rpath@/${enable_rpath}/g" \ | sed -e "s/@export_shared@/${export_shared}/g" \ | sed -e "s/@enable_blas@/${enable_blas}/g" \ | sed -e "s/@enable_cblas@/${enable_cblas}/g" \ diff --git a/docs/BLISObjectAPI.md b/docs/BLISObjectAPI.md index 9a06e29a49..5e8ed3d8fb 100644 --- a/docs/BLISObjectAPI.md +++ b/docs/BLISObjectAPI.md @@ -2336,16 +2336,9 @@ char* bli_info_get_trsm_u_ukr_impl_string( ind_t method, num_t dt ) ``` Possible implementation (ie: the `ind_t method` argument) types are: - * `BLIS_3MH`: Implementation based on the 3m method applied at the highest level, outside the 5th loop around the microkernel. - * `BLIS_3M1`: Implementation based on the 3m method applied within the 1st loop around the microkernel. - * `BLIS_4MH`: Implementation based on the 4m method applied at the highest level, outside the 5th loop around the microkernel. - * `BLIS_4M1B`: Implementation based on the 4m method applied within the 1st loop around the microkernel. Computation is ordered such that the 1st loop is fissured into two loops, the first of which multiplies the real part of the current micropanel of packed matrix B (against all real and imaginary parts of packed matrix A), and the second of which multiplies the imaginary part of the current micropanel of packed matrix B. - * `BLIS_4M1A`: Implementation based on the 4m method applied within the 1st loop around the microkernel. Computation is ordered such that real and imaginary components of the current micropanels are completely used before proceeding to the next virtual microkernel invocation. * `BLIS_1M`: Implementation based on the 1m method. (This is the default induced method when real domain kernels are present but complex kernels are missing.) * `BLIS_NAT`: Implementation based on "native" execution (ie: NOT an induced method). -**NOTE**: `BLIS_3M3` and `BLIS_3M2` have been deprecated from the `typedef enum` of `ind_t`, and `BLIS_4M1B` is also effectively no longer available, though the `typedef enum` value still exists. - Possible microkernel types (ie: the return values for `bli_info_get_*_ukr_impl_string()`) are: * `BLIS_REFERENCE_UKERNEL` (`"refrnce"`): This value is returned when the queried microkernel is provided by the reference implementation. * `BLIS_VIRTUAL_UKERNEL` (`"virtual"`): This value is returned when the queried microkernel is driven by a the "virtual" microkernel provided by an induced method. This happens for any `method` value that is not `BLIS_NAT` (ie: native), but only applies to the complex domain. diff --git a/docs/BLISTypedAPI.md b/docs/BLISTypedAPI.md index a29870169d..6279a5df9a 100644 --- a/docs/BLISTypedAPI.md +++ b/docs/BLISTypedAPI.md @@ -2015,16 +2015,9 @@ char* bli_info_get_trsm_u_ukr_impl_string( ind_t method, num_t dt ) ``` Possible implementation (ie: the `ind_t method` argument) types are: - * `BLIS_3MH`: Implementation based on the 3m method applied at the highest level, outside the 5th loop around the microkernel. - * `BLIS_3M1`: Implementation based on the 3m method applied within the 1st loop around the microkernel. - * `BLIS_4MH`: Implementation based on the 4m method applied at the highest level, outside the 5th loop around the microkernel. - * `BLIS_4M1B`: Implementation based on the 4m method applied within the 1st loop around the microkernel. Computation is ordered such that the 1st loop is fissured into two loops, the first of which multiplies the real part of the current micropanel of packed matrix B (against all real and imaginary parts of packed matrix A), and the second of which multiplies the imaginary part of the current micropanel of packed matrix B. - * `BLIS_4M1A`: Implementation based on the 4m method applied within the 1st loop around the microkernel. Computation is ordered such that real and imaginary components of the current micropanels are completely used before proceeding to the next virtual microkernel invocation. * `BLIS_1M`: Implementation based on the 1m method. (This is the default induced method when real domain kernels are present but complex kernels are missing.) * `BLIS_NAT`: Implementation based on "native" execution (ie: NOT an induced method). -**NOTE**: `BLIS_3M3` and `BLIS_3M2` have been deprecated from the `typedef enum` of `ind_t`, and `BLIS_4M1B` is also effectively no longer available, though the `typedef enum` value still exists. - Possible microkernel types (ie: the return values for `bli_info_get_*_ukr_impl_string()`) are: * `BLIS_REFERENCE_UKERNEL` (`"refrnce"`): This value is returned when the queried microkernel is provided by the reference implementation. * `BLIS_VIRTUAL_UKERNEL` (`"virtual"`): This value is returned when the queried microkernel is driven by a the "virtual" microkernel provided by an induced method. This happens for any `method` value that is not `BLIS_NAT` (ie: native), but only applies to the complex domain. diff --git a/docs/CMakeBuildSystem.md b/docs/CMakeBuildSystem.md index 92b85cf432..4f8091c8d9 100644 --- a/docs/CMakeBuildSystem.md +++ b/docs/CMakeBuildSystem.md @@ -72,9 +72,9 @@ On Windows, specify Visual Studio generator using cmake -G "Visual Studio 17 2022" ``` -For the rest of this documentation, we will use the platform-agnostic commands to build the libraries, but the usual make commands can be used instead. On the following command snippets we ommit specifying the generator, but one can use their prefered way of building using common CMake practices. +For the rest of this documentation, we will use the platform-agnostic commands to build the libraries, but the usual make commands can be used instead. On the following command snippets we ommit specifying the generator, but one can use their prefered way of building using common CMake practices. -### Choosing a configuration +### Choosing a configuration This step is equivalent to running `./configure ` using the Make system. In this case, simply run: ``` @@ -139,7 +139,7 @@ Please note that CMake does not provide functionality to uninstall targets. ## Available targets -The BLIS CMake system aims to be combatible with the current `make` system. For that reason, it implements the same targets for the generation of libraries and the tests. The table of avalable targets can be found below. +The BLIS CMake system aims to be combatible with the current `make` system. For that reason, it implements the same targets for the generation of libraries and the tests. The table of available targets can be found below. | target | Description | |:----------------|:---------------------------------------------------| @@ -159,8 +159,9 @@ The BLIS CMake system aims to be combatible with the current `make` system. For | `testsuite` | Same as `testblis`. | | `testblas` | Run the BLAS test drivers with default parameters (runs for a few seconds). | | `checkbliscpp` | Run the BLIS C++ tests (runs for a few seconds). | +| `coverage` | Run the code-coverage that generates html report (runs for 5-10 minutes). | -**_NOTE:_** +**_NOTE:_** Using those targets sets the environment appropriately, so copying the input files and/or the DLL in case of Windows builds is not required. ### Running the testsuites @@ -172,13 +173,13 @@ Using those targets sets the environment appropriately, so copying the input fil The CMake system is designed to closely relate to the BLIS Make system. Assuming that a user has followed the steps in [Configuration How To](ConfigurationHowTo.md), adding the new configuration on the CMake system requires the following steps: * Add a `make_defs.cmake` file which is equivalent to `make_defs.mk`. One can see `blis/config/zen/make_defs.cmake` and `blis/config/zen/make_defs.mk` for an example. -* Update `blis/CMakeLists.txt` to remove the error for the particular new configuration and to add the option in `set_property()` so that it appears in cmake-gui. +* Update `blis/CMakeLists.txt` to remove the error for the particular new configuration and to add the option in `set_property()` so that it appears in cmake-gui. ## Some examples In this section we provide some examples for users that are familiar with the build system based in Makefiles and want to try the new CMake system. -**_NOTE:_** +**_NOTE:_** The CMake system generates the shared libraries by default. To build the static libraries, you need to specify the corresponding CMake variable below ``` cmake .. -DBUILD_SHARED_LIBS=OFF -DBLIS_CONFIG_FAMILY=amdzen @@ -207,7 +208,7 @@ cmake .. -G "Visual Studio 17 2022" -TClangCl -DENABLE_THREADING=openmp -DINT_SI ### Example 2: single-threaded ILP64 libraries for amdzen configuration with aocl_gemm addon enabled and default compiler -**_NOTE:_** +**_NOTE:_** Addon functionality is currently available only on Linux. * With configure script: @@ -220,6 +221,71 @@ Addon functionality is currently available only on Linux. cmake .. -DENABLE_THREADING=no -DINT_SIZE=64 -DBLAS_INT_SIZE=64 -DENABLE_ADDON=aocl_gemm -DBLIS_CONFIG_FAMILY=amdzen ``` +### Bench +* Bench is used to measure performance. The bench targets depend on BLIS library, which is built depending on the cmake configuration. + +## 1. Bench CMake Configuration + +## 1.1.Move to "bench" folder within blis_build dir created during configuring cmake. + +## 1.2.Now build bench selecting the targets +# 1.2.1.To build blis targets +* To build the benchmark executables with the BLIS library built from CMake project use +``` +$ cmake .. +$ cmake --build . --target bench_blis #builds blis extension executables +``` + +* To build the benchmark executables with any BLIS package provide a path to the installation using +``` +$ cmake .. -DBLIS_INSTALL_PATH=/BLIS_installation_path +$ cmake --build . --target bench_blis #builds blis extension executables +``` + +## 1.2.2.To build MKL targets +* To build the benchmark executables with MKLROOT use +``` +$ cmake .. +$ cmake --build . --target bench_mkl #builds mkl extension executables +``` + +* If MKLROOT is not set, then set MKL_PATH and build the benchmark executables using +``` +$ cmake .. -DMKL_PATH=/path_to_MKL_library +$ cmake --build . --target bench_mkl #builds mkl extension executables +``` + +## 1.2.3.To build openblas targets +* To build benchmark executables for Openblas,set the OPENBLAS_PATH and build using +``` +$ cmake .. -DOPENBLAS_PATH=/path_to_Openblas +$ cmake --build . --target bench_openblas #builds openblas extension executables +``` + +## 1.2.4.To build for all targets +* To build for all benchmark executables set the MKL_PATH,OPENBLAS_PATH, then build using +``` +$ cmake .. -DMKL_PATH=/path_to_MKL_library -DOPENBLAS_PATH=/path_to_Openblas +$ cmake --build . --target benchmark #builds for all targets +``` + +## 2.To measure performance for "bench_aocl_gemm" only when lpgemm is configured during cmake. +``` +cmake .. -DENABLE_ADDON="aocl_gemm" +``` + +# 2.1.Move to "bench_aocl_gemm" folder within blis_build/bench folder. + +# 2.2.Now build bench_aocl_gemm +``` +$ cmake --build . or cmake --build . --target benchmark_lpgemm +``` + +## 3.Run any of the bench executable +``` + ./ ../../bench/inputfile.txt outfile.txt +``` + ## Conclusion -The BLIS CMake system is developed and maintained by AMD. You can contact us on the email-id toolchainsupport@amd.com. You can also raise any issue/suggestion on the git-hub repository at https://github.com/amd/blis/issues. \ No newline at end of file +The BLIS CMake system is developed and maintained by AMD. You can contact us on the email-id toolchainsupport@amd.com. You can also raise any issue/suggestion on the git-hub repository at https://github.com/amd/blis/issues. diff --git a/docs/Doxyfile b/docs/Doxyfile index 36ae286238..2fafbc5049 100644 --- a/docs/Doxyfile +++ b/docs/Doxyfile @@ -44,7 +44,7 @@ DOXYFILE_ENCODING = UTF-8 # title of most generated pages and in a few other places. # The default value is: My Project. -PROJECT_NAME = AOCL-BLIS +PROJECT_NAME = AOCL-BLAS # The PROJECT_NUMBER tag can be used to enter a project or revision number. This # could be handy for archiving the generated documentation or if some version diff --git a/docs/HardwareSupport.md b/docs/HardwareSupport.md index 32e5c4a630..944cfa8ee1 100644 --- a/docs/HardwareSupport.md +++ b/docs/HardwareSupport.md @@ -24,7 +24,7 @@ A few remarks / reminders: | AMD Steamroller (AVX/FMA3) | `steamroller` | `sdcz` | | | AMD Excavator (AVX/FMA3) | `excavator` | `sdcz` | | | AMD Zen (AVX/FMA3) | `zen` | `sdcz` | `sd` | -| Intel Core2 (SSE3) | `penryn` | `sd` | `d` | +| Intel Core2 (SSE3) | `penryn` | `sd` | `d` | | Intel Sandy/Ivy Bridge (AVX/FMA3) | `sandybridge` | `sdcz` | | | Intel Haswell, Broadwell (AVX/FMA3) | `haswell` | `sdcz` | `sd` | | Intel Sky/Kaby/CoffeeLake (AVX/FMA3) | `haswell` | `sdcz` | `sd` | @@ -35,6 +35,8 @@ A few remarks / reminders: | ARMv7 Cortex-A15 (NEON) | `cortex-a15` | `sd` | | | ARMv8 Cortex-A53 (NEON) | `cortex-a53` | `sd` | | | ARMv8 Cortex-A57 (NEON) | `cortex-a57` | `sd` | | +| ARMv8.1 ThunderX2 (NEON) | `thunderx2` | `sd` | | +| ARMv8.1 A64FX (SVE) | `a64fx` | `d` | | | IBM Blue Gene/Q (QPX int) | `bgq` | `d` | | | IBM Power7 (QPX int) | `power7` | `d` | | | template (C99) | `template` | `sdcz` | `sdcz` | diff --git a/docs/Main_Page.md b/docs/Main_Page.md index 39c2e12c85..9e9fe0925c 100644 --- a/docs/Main_Page.md +++ b/docs/Main_Page.md @@ -1,5 +1,5 @@ @mainpage -# Welcome to AOCL-BLIS +# Welcome to AOCL-BLAS --- @@ -14,9 +14,9 @@ ## Introduction - AOCL BLIS BLIS is a portable software framework for instantiating high-performance BLAS-like dense linear algebra libraries. The framework was designed to isolate essential kernels of computation that, when optimized, immediately enable optimized implementations of most of its commonly used and computationally intensive operations. BLIS is written in ISO C99 and available under a new/modified/3-clause BSD license. While BLIS exports a new BLAS-like API, it also includes a BLAS compatibility layer which gives application developers access to BLIS implementations via traditional BLAS routine calls. An object-based API unique to BLIS is also available. + AOCL BLAS BLIS is a portable software framework for instantiating high-performance BLAS-like dense linear algebra libraries. The framework was designed to isolate essential kernels of computation that, when optimized, immediately enable optimized implementations of most of its commonly used and computationally intensive operations. BLIS is written in ISO C99 and available under a new/modified/3-clause BSD license. While BLIS exports a new BLAS-like API, it also includes a BLAS compatibility layer which gives application developers access to BLIS implementations via traditional BLAS routine calls. An object-based API unique to BLIS is also available. -How to Download BLIS +How to Download AOCL BLAS -------------------- There are a few ways to download BLIS. We list the most common four ways below. @@ -135,4 +135,4 @@ omitted (mostly for brevity's sake) and thus more examples could be written. ## CONTACTS -AOCL BLIS is developed and maintained by AMD. You can contact us on the email-id [aoclsupport@amd.com](mailto:aoclsupport@amd.com) +AOCL BLAS is developed and maintained by AMD. You can contact us on the email-id [aoclsupport@amd.com](mailto:aoclsupport@amd.com) diff --git a/docs/Performance.md b/docs/Performance.md index 051be7aea9..f4992d1dee 100644 --- a/docs/Performance.md +++ b/docs/Performance.md @@ -550,9 +550,9 @@ The `runthese.m` file will contain example invocations of the function. * Operating system: RHEL 8.3 * Page size: 256 bytes * Compiler: gcc 10.1.0 -* Results gathered: 2 April 2021; BLIS and SSL2 updated on 20 May 2021 +* Results gathered: 2 April 2021; BLIS and SSL2 updated on 21 Sept 2021 * Implementations tested: - * BLIS 61584de (post-0.8.1) + * BLIS b05279d (post-0.8.1) * configured with: * `../configure -t none CFLAGS="-DCACHE_SECTOR_SIZE_READONLY" a64fx` (single-threaded) * `../configure -t openmp CFLAGS="-DCACHE_SECTOR_SIZE_READONLY" a64fx` (multithreaded) @@ -574,7 +574,7 @@ The `runthese.m` file will contain example invocations of the function. * Multithreaded (12 core) execution requested via `export OMP_NUM_THREADS=12` * Multithreaded (48 core) execution requested via `export OMP_NUM_THREADS=48` * **NOTE**: While this version of ARMPL does provide multithreaded implementations of `symm`/`hemm`, `syrk`/`herk`, `trmm`, or `trsm` (with the exception `dtrsm`), but these implementations yield very low performance, and their long run times led us to skip collecting these data altogether. - * Fujitsu SSL2 (Fujitsu toolchain 1.2.31) + * Fujitsu SSL2 (Fujitsu toolchain 1.2.33) * Single-threaded (1 core) execution requested via `export OMP_NUM_THREADS=1 NPARALLEL=1` * Multithreaded (12 core) execution requested via `export OMP_NUM_THREADS=12 NPARALLEL=12` * Multithreaded (48 core) execution requested via `export OMP_NUM_THREADS=48 NPARALLEL=48` diff --git a/docs/Sandboxes.md b/docs/Sandboxes.md index 8f404d0a6b..cbc0add53e 100644 --- a/docs/Sandboxes.md +++ b/docs/Sandboxes.md @@ -17,13 +17,9 @@ Simply put, a sandbox in BLIS provides an alternative implementation to the `gemm` operation. To get a little more specific, a sandbox provides an alternative implementation -to the function `bli_gemmnat()`, which is the object-based API call for -computing the `gemm` operation via native execution. - -**Note**: Native execution simply means that an induced method will not be used. -It's what you probably already think of when you think of implementing the -`gemm` operation: a series of loops around an optimized (usually assembly-based) -microkernel with some packing functions thrown in at various levels. +to the function `bli_gemm_ex()`, which is the +[expert interface](BLISObjectAPI.md##basic-vs-expert-interfaces) for calling the +[object-based API](BLISObjectAPI.md#gemm) for the `gemm` operation. Why sandboxes? Sometimes you want to experiment with tweaks or changes to the `gemm` operation, but you want to do so in a simple environment rather than @@ -45,18 +41,11 @@ corresponds to a sub-directory of `sandbox` named `gemmlike`. (Reminder: the `auto` argument is the configuration target and thus unrelated to sandboxes.) -NOTE: If you want your sandbox implementation to handle *all* problem -sizes and shapes, you'll need to disable the skinny/unpacked "sup" -sub-framework within BLIS, which is enabled by default. This can be -done by passing the `--disable-sup-handling` option to configure: -``` -$ ./configure --enable-sandbox=gemmlike --disable-sup-handling auto -``` -If you leave sup enabled, the sup implementation will, at runtime, detect -and handle certain smaller problem sizes upstream of where BLIS calls -`bli_gemmnat()` while all other problems will fall to your sandbox -implementation. Thus, you should only leave sup enabled if you are fine -with those smaller problems being handled by sup. +NOTE: Using your own sandbox implementation means that BLIS will call your +sandbox for *all* problem sizes and shapes, for *all* datatypes supported +by BLIS. If you intend to only implement a subset of this functionality +within your sandbox, you should be sure to redirect execution back into +the core framework for the parts that you don't wish to reimplement yourself. As `configure` runs, you should get output that includes lines similar to: @@ -67,13 +56,12 @@ configure: sandbox/gemmlike And when you build BLIS, the last files to be compiled will be the source code in the specified sandbox: ``` -Compiling obj/haswell/sandbox/gemmlike/bli_gemmnat.o ('haswell' CFLAGS for sandboxes) Compiling obj/haswell/sandbox/gemmlike/bls_gemm.o ('haswell' CFLAGS for sandboxes) Compiling obj/haswell/sandbox/gemmlike/bls_gemm_bp_var1.o ('haswell' CFLAGS for sandboxes) ... ``` That's it! After the BLIS library is built, it will contain your chosen -sandbox's implementation of `bli_gemmnat()` instead of the default +sandbox's implementation of `bli_gemm_ex()` instead of the default BLIS implementation. ## Sandbox rules @@ -97,7 +85,7 @@ Note that `blis.h` already contains all of its definitions inside of an `extern "C"` block, so you should be able to `#include "blis.h"` from your C++11 source code without any issues. -3. All of your code to replace BLIS's default implementation of `bli_gemmnat()` +3. All of your code to replace BLIS's default implementation of `bli_gemm_ex()` should reside in the named sandbox directory, or some directory therein. (Obviously.) For example, the "gemmlike" sandbox is located in `sandbox/gemmlike`. All of the code associated with this sandbox will be @@ -105,7 +93,7 @@ contained within `sandbox/gemmlike`. Note that you absolutely *may* include additional code and interfaces within the sandbox, if you wish -- code and interfaces that are not directly or indirectly needed for satisfying the the "contract" set forth by the sandbox (i.e., including a local definition -of`bli_gemmnat()`). +of`bli_gemm_ex()`). 4. The *only* header file that is required of your sandbox is `bli_sandbox.h`. It must be named `bli_sandbox.h` because `blis.h` will `#include` this file @@ -119,12 +107,12 @@ you should only place things (e.g. prototypes or type definitions) in (b) an *application* that calls your sandbox-enabled BLIS library. Usually, neither of these situations will require any of your local definitions since those local definitions are only needed to define your sandbox -implementation of `bli_gemmnat()`, and this function is already prototyped by +implementation of `bli_gemm_ex()`, and this function is already prototyped by BLIS. *But if you are adding additional APIs and/or operations to the sandbox -that are unrelated to `bli_gemmnat()`, then you'll want to `#include` those +that are unrelated to `bli_gemm_ex()`, then you'll want to `#include` those function prototypes from within `bli_sandbox.h`* -5. Your definition of `bli_gemmnat()` should be the **only function you define** +5. Your definition of `bli_gemm_ex()` should be the **only function you define** in your sandbox that begins with `bli_`. If you define other functions that begin with `bli_`, you risk a namespace collision with existing framework functions. To guarantee safety, please prefix your locally-defined sandbox @@ -147,9 +135,9 @@ For example, with a BLIS sandbox you **can** do the following kinds of things: kernels, which can already be customized within each sub-configuration); - try inlining your functions manually; - pivot away from using `obj_t` objects at higher algorithmic level (such as - immediately after calling `bli_gemmnat()`) to try to avoid some overhead; + immediately after calling `bli_gemm_ex()`) to try to avoid some overhead; - create experimental implementations of new BLAS-like operations (provided - that you also provide an implementation of `bli_gemmnat()`). + that you also provide an implementation of `bli_gemm_ex()`). You **cannot**, however, use a sandbox to do the following kinds of things: - define new datatypes (half-precision, quad-precision, short integer, etc.) @@ -167,8 +155,8 @@ Another important limitation is the fact that the build system currently uses # Example framework CFLAGS used by 'haswell' sub-configuration -O3 -Wall -Wno-unused-function -Wfatal-errors -fPIC -std=c99 -D_POSIX_C_SOURCE=200112L -I./include/haswell -I./frame/3/ --I./frame/ind/ukernels/ -I./frame/1m/ -I./frame/1f/ -I./frame/1/ --I./frame/include -DBLIS_VERSION_STRING=\"0.3.2-51\" +-I./frame/1m/ -I./frame/1f/ -I./frame/1/ -I./frame/include +-DBLIS_VERSION_STRING=\"0.3.2-51\" ``` which are likely more general-purpose than the `CFLAGS` used for, say, optimized kernels or even reference kernels. @@ -176,8 +164,8 @@ optimized kernels or even reference kernels. # Example optimized kernel CFLAGS used by 'haswell' sub-configuration -O3 -mavx2 -mfma -mfpmath=sse -march=core-avx2 -Wall -Wno-unused-function -Wfatal-errors -fPIC -std=c99 -D_POSIX_C_SOURCE=200112L -I./include/haswell --I./frame/3/ -I./frame/ind/ukernels/ -I./frame/1m/ -I./frame/1f/ -I./frame/1/ --I./frame/include -DBLIS_VERSION_STRING=\"0.3.2-51\" +-I./frame/3/ -I./frame/1m/ -I./frame/1f/ -I./frame/1/ -I./frame/include +-DBLIS_VERSION_STRING=\"0.3.2-51\" ``` (To see precisely which flags are being employed for any given file, enable verbosity at compile-time via `make V=1`.) Compiling sandboxes with these more diff --git a/docs/Testsuite.md b/docs/Testsuite.md index 917a7e4a7c..d34955f0ad 100644 --- a/docs/Testsuite.md +++ b/docs/Testsuite.md @@ -128,11 +128,6 @@ sdcz # Datatype(s) to test: 300 # Problem size: maximum to test 100 # Problem size: increment between experiments # Complex level-3 implementations to test -1 # 3mh ('1' = enable; '0' = disable) -1 # 3m1 ('1' = enable; '0' = disable) -1 # 4mh ('1' = enable; '0' = disable) -1 # 4m1b ('1' = enable; '0' = disable) -1 # 4m1a ('1' = enable; '0' = disable) 1 # 1m ('1' = enable; '0' = disable) 1 # native ('1' = enable; '0' = disable) 1 # Simulate application-level threading: @@ -169,7 +164,7 @@ _**Test gemm with mixed-precision operands?**_ This boolean determines whether ` _**Problem size.**_ These values determine the first problem size to test, the maximum problem size to test, and the increment between problem sizes. Note that the maximum problem size only bounds the range of problem sizes; it is not guaranteed to be tested. Example: If the initial problem size is 128, the maximum is 1000, and the increment is 64, then the last problem size to be tested will be 960. -_**Complex level-3 implementations to test.**_ With the exception of the switch marked `native`, these switches control whether experimental complex domain implementations are tested (when applicable). These implementations employ induced methods complex matrix multiplication and apply to some (though not all) of the level-3 operations. If you don't know what these are, you can ignore them. The `native` switch corresponds to native execution of complex domain level-3 operations, which we test by default. We also test the `1m` method, since it is the induced method of choice when complex microkernels are not available. Note that all of these induced method tests (including `native`) are automatically disabled if the `c` and `z` datatypes are disabled. +_**Complex level-3 implementations to test.**_ This section lists which complex domain implementations of level-3 operations are tested. If you don't know what these are, you can ignore them. The `native` switch corresponds to native execution of complex domain level-3 operations, which we test by default. We also test the `1m` method, since it is the induced method of choice when optimized complex microkernels are not available. Note that all of these induced method tests (including `native`) are automatically disabled if the `c` and `z` datatypes are disabled. _**Simulate application-level threading.**_ This setting specifies the number of threads the testsuite will spawn, and is meant to allow the user to exercise BLIS as a multithreaded application might if it were to make multiple concurrent calls to BLIS operations. (Note that the threading controlled by this option is orthogonal to, and has no effect on, whatever multithreading may be employed _within_ BLIS, as specified by the environment variables described in the [Multithreading](Multithreading.md) documentation.) When this option is set to 1, the testsuite is run with only one thread. When set to n > 1 threads, the spawned threads will parallelize (in round-robin fashion) the total set of tests specified by the testsuite input files, executing them in roughly the same order as that of a sequential execution. diff --git a/docs/graphs/large/l3_perf_a64fx_jc1ic1jr12_nt12.pdf b/docs/graphs/large/l3_perf_a64fx_jc1ic1jr12_nt12.pdf index e273d1d098..4d27944170 100644 Binary files a/docs/graphs/large/l3_perf_a64fx_jc1ic1jr12_nt12.pdf and b/docs/graphs/large/l3_perf_a64fx_jc1ic1jr12_nt12.pdf differ diff --git a/docs/graphs/large/l3_perf_a64fx_jc1ic1jr12_nt12.png b/docs/graphs/large/l3_perf_a64fx_jc1ic1jr12_nt12.png index 1316647d65..f51548effb 100644 Binary files a/docs/graphs/large/l3_perf_a64fx_jc1ic1jr12_nt12.png and b/docs/graphs/large/l3_perf_a64fx_jc1ic1jr12_nt12.png differ diff --git a/docs/graphs/large/l3_perf_a64fx_jc1ic4jr12_nt48.pdf b/docs/graphs/large/l3_perf_a64fx_jc1ic4jr12_nt48.pdf index b311e0f5db..845dfaf862 100644 Binary files a/docs/graphs/large/l3_perf_a64fx_jc1ic4jr12_nt48.pdf and b/docs/graphs/large/l3_perf_a64fx_jc1ic4jr12_nt48.pdf differ diff --git a/docs/graphs/large/l3_perf_a64fx_jc1ic4jr12_nt48.png b/docs/graphs/large/l3_perf_a64fx_jc1ic4jr12_nt48.png index c2719da87a..08e46c6723 100644 Binary files a/docs/graphs/large/l3_perf_a64fx_jc1ic4jr12_nt48.png and b/docs/graphs/large/l3_perf_a64fx_jc1ic4jr12_nt48.png differ diff --git a/docs/graphs/large/l3_perf_a64fx_nt1.pdf b/docs/graphs/large/l3_perf_a64fx_nt1.pdf index 6f0b8c74fc..97a31560a1 100644 Binary files a/docs/graphs/large/l3_perf_a64fx_nt1.pdf and b/docs/graphs/large/l3_perf_a64fx_nt1.pdf differ diff --git a/docs/graphs/large/l3_perf_a64fx_nt1.png b/docs/graphs/large/l3_perf_a64fx_nt1.png index f2cb381786..0b7c2d72aa 100644 Binary files a/docs/graphs/large/l3_perf_a64fx_nt1.png and b/docs/graphs/large/l3_perf_a64fx_nt1.png differ diff --git a/frame/1/bli_l1v_tapi.c b/frame/1/bli_l1v_tapi.c index b7637e7ebd..452e9ce156 100644 --- a/frame/1/bli_l1v_tapi.c +++ b/frame/1/bli_l1v_tapi.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -135,6 +135,14 @@ void PASTEMAC2(ch,opname,EX_SUF) \ ) \ { \ AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_2) \ +\ + /* Early exit in case n is 0, or alpha is 0 and beta is 1 */ \ + if ( bli_zero_dim1( n ) || \ + ( PASTEMAC( ch, eq0 )( *alpha ) && PASTEMAC( ch, eq1 )( *beta ) ) ) \ + { \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2) \ + return; \ + } \ \ bli_init_once(); \ \ @@ -162,7 +170,6 @@ void PASTEMAC2(ch,opname,EX_SUF) \ INSERT_GENTFUNC_BASIC( axpbyv, BLIS_AXPBYV_KER ) - #undef GENTFUNC #define GENTFUNC( ctype, ch, opname, kerid ) \ \ @@ -203,6 +210,53 @@ void PASTEMAC2(ch,opname,EX_SUF) \ } INSERT_GENTFUNC_BASIC( axpyv, BLIS_AXPYV_KER ) + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname, kerid ) \ +\ +void PASTEMAC2(ch,opname,EX_SUF) \ + ( \ + conj_t conjx, \ + dim_t n, \ + ctype* alpha, \ + ctype* x, inc_t incx, \ + ctype* y, inc_t incy \ + BLIS_TAPI_EX_PARAMS \ + ) \ +{ \ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_2) \ +\ + /* The behaviour is undefined when increments are negative or 0 */ \ + /* So, return early */ \ + if( ( incx <= 0 ) || ( incy <= 0 ) ) \ + { \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2) \ + return; \ + } \ + bli_init_once(); \ +\ + BLIS_TAPI_EX_DECLS \ +\ + const num_t dt = PASTEMAC(ch,type); \ +\ + /* Obtain a valid context from the gks if necessary. */ \ + if ( cntx == NULL ) \ + cntx = bli_gks_query_cntx(); \ +\ + PASTECH2(ch,opname,_ker_ft) f = bli_cntx_get_l1v_ker_dt( dt, kerid, cntx ); \ +\ + f \ + ( \ + conjx, \ + n, \ + alpha, \ + x, incx, \ + y, incy, \ + cntx \ + ); \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2) \ +} + INSERT_GENTFUNC_BASIC( scal2v, BLIS_SCAL2V_KER ) diff --git a/frame/1m/bli_l1m_ft_ker.h b/frame/1m/bli_l1m_ft_ker.h index e8ebdec0d8..1146ca7d2c 100644 --- a/frame/1m/bli_l1m_ft_ker.h +++ b/frame/1m/bli_l1m_ft_ker.h @@ -110,28 +110,6 @@ typedef void (*PASTECH3(ch,opname,_ker,tsuf)) \ INSERT_GENTDEF( unpackm_cxk ) -// packm_3mis_ker -// packm_4mi_ker - -#undef GENTDEF -#define GENTDEF( ctype, ch, opname, tsuf ) \ -\ -typedef void (*PASTECH3(ch,opname,_ker,tsuf)) \ - ( \ - conj_t conja, \ - dim_t cdim, \ - dim_t n, \ - dim_t n_max, \ - ctype* restrict kappa, \ - ctype* restrict a, inc_t inca, inc_t lda, \ - ctype* restrict p, inc_t is_p, inc_t ldp, \ - cntx_t* restrict cntx \ - ); - -INSERT_GENTDEF( packm_cxk_3mis ) -INSERT_GENTDEF( packm_cxk_4mi ) - -// packm_rih_ker // packm_1er_ker #undef GENTDEF @@ -150,12 +128,8 @@ typedef void (*PASTECH3(ch,opname,_ker,tsuf)) \ cntx_t* restrict cntx \ ); -INSERT_GENTDEF( packm_cxk_rih ) INSERT_GENTDEF( packm_cxk_1er ) - - - #endif diff --git a/frame/1m/bli_l1m_ker.h b/frame/1m/bli_l1m_ker.h index f79a292d33..76d51af2b0 100644 --- a/frame/1m/bli_l1m_ker.h +++ b/frame/1m/bli_l1m_ker.h @@ -74,51 +74,6 @@ INSERT_GENTPROT_BASIC0( unpackm_14xk_ker_name ) INSERT_GENTPROT_BASIC0( unpackm_16xk_ker_name ) -// 3mis packm kernels - -#undef GENTPROT -#define GENTPROT PACKM_3MIS_KER_PROT - -INSERT_GENTPROT_BASIC0( packm_2xk_3mis_ker_name ) -INSERT_GENTPROT_BASIC0( packm_4xk_3mis_ker_name ) -INSERT_GENTPROT_BASIC0( packm_6xk_3mis_ker_name ) -INSERT_GENTPROT_BASIC0( packm_8xk_3mis_ker_name ) -INSERT_GENTPROT_BASIC0( packm_10xk_3mis_ker_name ) -INSERT_GENTPROT_BASIC0( packm_12xk_3mis_ker_name ) -INSERT_GENTPROT_BASIC0( packm_14xk_3mis_ker_name ) -INSERT_GENTPROT_BASIC0( packm_16xk_3mis_ker_name ) - - -// 4mi packm kernels - -#undef GENTPROT -#define GENTPROT PACKM_4MI_KER_PROT - -INSERT_GENTPROT_BASIC0( packm_2xk_4mi_ker_name ) -INSERT_GENTPROT_BASIC0( packm_4xk_4mi_ker_name ) -INSERT_GENTPROT_BASIC0( packm_6xk_4mi_ker_name ) -INSERT_GENTPROT_BASIC0( packm_8xk_4mi_ker_name ) -INSERT_GENTPROT_BASIC0( packm_10xk_4mi_ker_name ) -INSERT_GENTPROT_BASIC0( packm_12xk_4mi_ker_name ) -INSERT_GENTPROT_BASIC0( packm_14xk_4mi_ker_name ) -INSERT_GENTPROT_BASIC0( packm_16xk_4mi_ker_name ) - - -// rih packm kernels - -#undef GENTPROT -#define GENTPROT PACKM_RIH_KER_PROT - -INSERT_GENTPROT_BASIC0( packm_2xk_rih_ker_name ) -INSERT_GENTPROT_BASIC0( packm_4xk_rih_ker_name ) -INSERT_GENTPROT_BASIC0( packm_6xk_rih_ker_name ) -INSERT_GENTPROT_BASIC0( packm_8xk_rih_ker_name ) -INSERT_GENTPROT_BASIC0( packm_10xk_rih_ker_name ) -INSERT_GENTPROT_BASIC0( packm_12xk_rih_ker_name ) -INSERT_GENTPROT_BASIC0( packm_14xk_rih_ker_name ) -INSERT_GENTPROT_BASIC0( packm_16xk_rih_ker_name ) - - // 1e/1r packm kernels #undef GENTPROT diff --git a/frame/1m/bli_l1m_ker_prot.h b/frame/1m/bli_l1m_ker_prot.h index 3bbdc2c253..02d3296220 100644 --- a/frame/1m/bli_l1m_ker_prot.h +++ b/frame/1m/bli_l1m_ker_prot.h @@ -70,58 +70,6 @@ void PASTEMAC(ch,varname) \ ); -// 3mis packm kernels - -#define PACKM_3MIS_KER_PROT( ctype, ch, varname ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - conj_t conja, \ - dim_t cdim, \ - dim_t n, \ - dim_t n_max, \ - ctype* restrict kappa, \ - ctype* restrict a, inc_t inca, inc_t lda, \ - ctype* restrict p, inc_t is_p, inc_t ldp, \ - cntx_t* restrict cntx \ - ); - - -// 4mi packm kernels - -#define PACKM_4MI_KER_PROT( ctype, ch, varname ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - conj_t conja, \ - dim_t cdim, \ - dim_t n, \ - dim_t n_max, \ - ctype* restrict kappa, \ - ctype* restrict a, inc_t inca, inc_t lda, \ - ctype* restrict p, inc_t is_p, inc_t ldp, \ - cntx_t* restrict cntx \ - ); - - -// rih packm kernels - -#define PACKM_RIH_KER_PROT( ctype, ch, varname ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - conj_t conja, \ - pack_t schema, \ - dim_t cdim, \ - dim_t n, \ - dim_t n_max, \ - ctype* restrict kappa, \ - ctype* restrict a, inc_t inca, inc_t lda, \ - ctype* restrict p, inc_t ldp, \ - cntx_t* restrict cntx \ - ); - - // 1e/1r packm kernels #define PACKM_1ER_KER_PROT( ctype, ch, varname ) \ diff --git a/frame/1m/packm/bli_packm.h b/frame/1m/packm/bli_packm.h index 20a2e373a9..8c9f5bbd18 100644 --- a/frame/1m/packm/bli_packm.h +++ b/frame/1m/packm/bli_packm.h @@ -43,15 +43,9 @@ #include "bli_packm_var.h" #include "bli_packm_struc_cxk.h" -#include "bli_packm_struc_cxk_4mi.h" -#include "bli_packm_struc_cxk_3mis.h" -#include "bli_packm_struc_cxk_rih.h" #include "bli_packm_struc_cxk_1er.h" #include "bli_packm_cxk.h" -#include "bli_packm_cxk_4mi.h" -#include "bli_packm_cxk_3mis.h" -#include "bli_packm_cxk_rih.h" #include "bli_packm_cxk_1er.h" #include "bli_pack_full.h" diff --git a/frame/1m/packm/bli_packm_blk_var1.c b/frame/1m/packm/bli_packm_blk_var1.c index c720317b96..6f95b58999 100644 --- a/frame/1m/packm/bli_packm_blk_var1.c +++ b/frame/1m/packm/bli_packm_blk_var1.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -71,31 +71,10 @@ static func_t packm_struc_cxk_kers[BLIS_NUM_PACK_SCHEMA_TYPES] = // 0000 row/col panels { { bli_spackm_struc_cxk, bli_cpackm_struc_cxk, bli_dpackm_struc_cxk, bli_zpackm_struc_cxk, } }, -// 0001 row/col panels: 4m interleaved - { { NULL, bli_cpackm_struc_cxk_4mi, - NULL, bli_zpackm_struc_cxk_4mi, } }, -// 0010 row/col panels: 3m interleaved - { { NULL, bli_cpackm_struc_cxk_3mis, - NULL, bli_zpackm_struc_cxk_3mis, } }, -// 0011 row/col panels: 4m separated (NOT IMPLEMENTED) - { { NULL, NULL, - NULL, NULL, } }, -// 0100 row/col panels: 3m separated - { { NULL, bli_cpackm_struc_cxk_3mis, - NULL, bli_zpackm_struc_cxk_3mis, } }, -// 0101 row/col panels: real only - { { NULL, bli_cpackm_struc_cxk_rih, - NULL, bli_zpackm_struc_cxk_rih, } }, -// 0110 row/col panels: imaginary only - { { NULL, bli_cpackm_struc_cxk_rih, - NULL, bli_zpackm_struc_cxk_rih, } }, -// 0111 row/col panels: real+imaginary only - { { NULL, bli_cpackm_struc_cxk_rih, - NULL, bli_zpackm_struc_cxk_rih, } }, -// 1000 row/col panels: 1m-expanded (1e) +// 0001 row/col panels: 1m-expanded (1e) { { NULL, bli_cpackm_struc_cxk_1er, NULL, bli_zpackm_struc_cxk_1er, } }, -// 1001 row/col panels: 1m-reordered (1r) +// 0010 row/col panels: 1m-reordered (1r) { { NULL, bli_cpackm_struc_cxk_1er, NULL, bli_zpackm_struc_cxk_1er, } }, }; @@ -203,17 +182,22 @@ void bli_packm_blk_var1 // Acquire the buffer to the kappa chosen above. buf_kappa = bli_obj_buffer_for_1x1( dt_p, kappa_p ); } + +#ifdef BLIS_KERNELS_ZEN5 + // For DGEMM in ZEN5, scale by alpha during packing + if + ( + ( bli_obj_dt( p ) == BLIS_DOUBLE ) && + ( bli_arch_query_id() == BLIS_ARCH_ZEN5 ) + ) + { + bli_obj_scalar_detach( p, &kappa ); + // Reset the attached scalar (to 1.0). + bli_obj_scalar_reset( p ); + buf_kappa = kappa.buffer; + } +#endif - -#if 0 - if ( bli_is_4mi_packed( schema ) ) packm_kers = packm_struc_cxk_4mi_kers; - else if ( bli_is_3mi_packed( schema ) || - bli_is_3ms_packed( schema ) ) packm_kers = packm_struc_cxk_3mis_kers; - else if ( bli_is_ro_packed( schema ) || - bli_is_io_packed( schema ) || - bli_is_rpi_packed( schema ) ) packm_kers = packm_struc_cxk_rih_kers; - else packm_kers = packm_struc_cxk_kers; -#else // The original idea here was to read the packm_ukr from the context // if it is non-NULL. The problem is, it requires that we be able to // assume that the packm_ukr field is initialized to NULL, which it @@ -239,7 +223,6 @@ void bli_packm_blk_var1 //packm_kers = bli_cntx_packm_ukrs( cntx ); packm_kers = cntx_packm_kers; } -#endif #endif // Query the datatype-specific function pointer from the func_t object. @@ -337,8 +320,6 @@ void PASTEMAC(ch,varname) \ bool row_stored; \ bool col_stored; \ inc_t is_p_use; \ - dim_t ss_num; \ - dim_t ss_den; \ \ ctype* restrict c_use; \ ctype* restrict p_use; \ @@ -409,17 +390,6 @@ void PASTEMAC(ch,varname) \ m_panel_max = &panel_dim_max; \ n_panel_max = &panel_len_max_i; \ } \ -\ - /* Compute the storage stride scaling. Usually this is just 1. However, - in the case of interleaved 3m, we need to scale by 3/2, and in the - cases of real-only, imag-only, or summed-only, we need to scale by - 1/2. In both cases, we are compensating for the fact that pointer - arithmetic occurs in terms of complex elements rather than real - elements. */ \ - if ( bli_is_3mi_packed( schema ) ) { ss_num = 3; ss_den = 2; } \ - else if ( bli_is_3ms_packed( schema ) ) { ss_num = 1; ss_den = 2; } \ - else if ( bli_is_rih_packed( schema ) ) { ss_num = 1; ss_den = 2; } \ - else { ss_num = 1; ss_den = 1; } \ \ /* Compute the total number of iterations we'll need. */ \ n_iter = iter_dim / panel_dim_max + ( iter_dim % panel_dim_max ? 1 : 0 ); \ @@ -550,7 +520,7 @@ void PASTEMAC(ch,varname) \ /* NOTE: This value is usually LESS than ps_p because triangular matrices usually have several micro-panels that are shorter than a "full" micro-panel. */ \ - p_inc = ( is_p_use * ss_num ) / ss_den; \ + p_inc = is_p_use; \ } \ else if ( bli_is_herm_or_symm( strucc ) ) \ { \ @@ -706,29 +676,6 @@ bli_thread_barrier( thread ); \ bli_thread_barrier( thread ); \ } \ */ -/* - if ( bli_is_4mi_packed( schema ) ) { \ - printf( "packm_var2: is_p_use = %lu\n", is_p_use ); \ - if ( col_stored ) { \ - if ( 0 ) \ - PASTEMAC(chr,fprintm)( stdout, "packm_var2: a_r", *m_panel_use, *n_panel_use, \ - ( ctype_r* )c_use, 2*rs_c, 2*cs_c, "%4.1f", "" ); \ - PASTEMAC(chr,fprintm)( stdout, "packm_var2: ap_r", *m_panel_max, *n_panel_max, \ - ( ctype_r* )p_use, rs_p, cs_p, "%4.1f", "" ); \ - PASTEMAC(chr,fprintm)( stdout, "packm_var2: ap_i", *m_panel_max, *n_panel_max, \ - ( ctype_r* )p_use + is_p_use, rs_p, cs_p, "%4.1f", "" ); \ - } \ - if ( row_stored ) { \ - if ( 0 ) \ - PASTEMAC(chr,fprintm)( stdout, "packm_var2: b_r", *m_panel_use, *n_panel_use, \ - ( ctype_r* )c_use, 2*rs_c, 2*cs_c, "%4.1f", "" ); \ - PASTEMAC(chr,fprintm)( stdout, "packm_var2: bp_r", *m_panel_max, *n_panel_max, \ - ( ctype_r* )p_use, rs_p, cs_p, "%4.1f", "" ); \ - PASTEMAC(chr,fprintm)( stdout, "packm_var2: bp_i", *m_panel_max, *n_panel_max, \ - ( ctype_r* )p_use + is_p_use, rs_p, cs_p, "%4.1f", "" ); \ - } \ - } \ -*/ /* PASTEMAC(chr,fprintm)( stdout, "packm_var2: bp_rpi", *m_panel_max, *n_panel_max, \ ( ctype_r* )p_use, rs_p, cs_p, "%4.1f", "" ); \ diff --git a/frame/1m/packm/bli_packm_cxk_3mis.c b/frame/1m/packm/bli_packm_cxk_3mis.c deleted file mode 100644 index 9435f6a736..0000000000 --- a/frame/1m/packm/bli_packm_cxk_3mis.c +++ /dev/null @@ -1,204 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname ) \ -\ -void PASTEMAC(ch,opname) \ - ( \ - conj_t conja, \ - dim_t panel_dim, \ - dim_t panel_dim_max, \ - dim_t panel_len, \ - dim_t panel_len_max, \ - ctype* kappa, \ - ctype* a, inc_t inca, inc_t lda, \ - ctype* p, inc_t is_p, inc_t ldp, \ - cntx_t* cntx \ - ) \ -{ \ - /* Note that we use panel_dim_max, not panel_dim, to query the packm - kernel function pointer. This means that we always use the same - kernel, even for edge cases. */ \ - num_t dt = PASTEMAC(ch,type); \ - l1mkr_t ker_id = panel_dim_max; \ -\ - PASTECH2(ch,opname,_ker_ft) f; \ -\ - /* Query the context for the packm kernel corresponding to the current - panel dimension, or kernel id. If the id is invalid, the function will - return NULL. */ \ - f = bli_cntx_get_packm_ker_dt( dt, ker_id, cntx ); \ -\ - /* If there exists a kernel implementation for the micro-panel dimension - provided, we invoke the implementation. Otherwise, we use scal2m. */ \ - if ( f != NULL ) \ - { \ - f \ - ( \ - conja, \ - panel_dim, \ - panel_len, \ - panel_len_max, \ - kappa, \ - a, inca, lda, \ - p, is_p, ldp, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Treat the micro-panel as panel_dim x panel_len and column-stored - (unit row stride). */ \ -\ - PASTEMAC(ch,scal2ri3s_mxn) \ - ( \ - conja, \ - panel_dim, \ - panel_len, \ - kappa, \ - a, inca, lda, \ - p, 1, ldp, is_p \ - ); \ -\ - /* If panel_dim < panel_dim_max, then we zero those unused rows. */ \ - if ( panel_dim < panel_dim_max ) \ - { \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - const dim_t i = panel_dim; \ - const dim_t m_edge = panel_dim_max - i; \ - const dim_t n_edge = panel_len_max; \ - ctype_r* p_edge_r = ( ctype_r* )p + (i )*1; \ - ctype_r* p_edge_i = ( ctype_r* )p + is_p + (i )*1; \ - ctype_r* p_edge_rpi = ( ctype_r* )p + 2*is_p + (i )*1; \ -\ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_r, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_i, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_rpi, 1, ldp, \ - cntx, \ - NULL \ - ); \ - } \ -\ - /* If panel_len < panel_len_max, then we zero those unused columns. */ \ - if ( panel_len < panel_len_max ) \ - { \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - const dim_t j = panel_len; \ - const dim_t m_edge = panel_dim_max; \ - const dim_t n_edge = panel_len_max - j; \ - ctype_r* p_edge_r = ( ctype_r* )p + (j )*ldp; \ - ctype_r* p_edge_i = ( ctype_r* )p + is_p + (j )*ldp; \ - ctype_r* p_edge_rpi = ( ctype_r* )p + 2*is_p + (j )*ldp; \ -\ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_r, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_i, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_rpi, 1, ldp, \ - cntx, \ - NULL \ - ); \ - } \ - } \ -} - -INSERT_GENTFUNCCO_BASIC0( packm_cxk_3mis ) - diff --git a/frame/1m/packm/bli_packm_cxk_4mi.c b/frame/1m/packm/bli_packm_cxk_4mi.c deleted file mode 100644 index c22f551cca..0000000000 --- a/frame/1m/packm/bli_packm_cxk_4mi.c +++ /dev/null @@ -1,146 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname ) \ -\ -void PASTEMAC(ch,opname) \ - ( \ - conj_t conja, \ - dim_t panel_dim, \ - dim_t panel_dim_max, \ - dim_t panel_len, \ - dim_t panel_len_max, \ - ctype* kappa, \ - ctype* a, inc_t inca, inc_t lda, \ - ctype* p, inc_t is_p, inc_t ldp, \ - cntx_t* cntx \ - ) \ -{ \ - /* Note that we use panel_dim_max, not panel_dim, to query the packm - kernel function pointer. This means that we always use the same - kernel, even for edge cases. */ \ - num_t dt = PASTEMAC(ch,type); \ - l1mkr_t ker_id = panel_dim_max; \ -\ - PASTECH2(ch,opname,_ker_ft) f; \ -\ - /* Query the context for the packm kernel corresponding to the current - panel dimension, or kernel id. If the id is invalid, the function will - return NULL. */ \ - f = bli_cntx_get_packm_ker_dt( dt, ker_id, cntx ); \ -\ - /* If there exists a kernel implementation for the micro-panel dimension - provided, we invoke the implementation. Otherwise, we use scal2m. */ \ - if ( f != NULL ) \ - { \ - f \ - ( \ - conja, \ - panel_dim, \ - panel_len, \ - panel_len_max, \ - kappa, \ - a, inca, lda, \ - p, is_p, ldp, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Treat the micro-panel as panel_dim x panel_len and column-stored - (unit row stride). */ \ -\ - PASTEMAC(ch,scal2ris_mxn) \ - ( \ - conja, \ - panel_dim, \ - panel_len, \ - kappa, \ - a, inca, lda, \ - p, 1, ldp, is_p \ - ); \ -\ - /* If panel_dim < panel_dim_max, then we zero those unused rows. */ \ - if ( panel_dim != panel_dim_max ) \ - { \ - const dim_t i = panel_dim; \ - const dim_t m_edge = panel_dim_max - i; \ - const dim_t n_edge = panel_len_max; \ - ctype_r* p_edge_r = ( ctype_r* )p + (i )*1; \ - ctype_r* p_edge_i = ( ctype_r* )p + is_p + (i )*1; \ -\ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_r, 1, ldp \ - ); \ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_i, 1, ldp \ - ); \ - } \ -\ - /* If panel_len < panel_len_max, then we zero those unused columns. */ \ - if ( panel_len != panel_len_max ) \ - { \ - const dim_t j = panel_len; \ - const dim_t m_edge = panel_dim_max; \ - const dim_t n_edge = panel_len_max - j; \ - ctype_r* p_edge_r = ( ctype_r* )p + (j )*ldp; \ - ctype_r* p_edge_i = ( ctype_r* )p + is_p + (j )*ldp; \ -\ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_r, 1, ldp \ - ); \ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_i, 1, ldp \ - ); \ - } \ - } \ -} - -INSERT_GENTFUNCCO_BASIC0( packm_cxk_4mi ) - diff --git a/frame/1m/packm/bli_packm_cxk_rih.c b/frame/1m/packm/bli_packm_cxk_rih.c deleted file mode 100644 index 1f2c9f240a..0000000000 --- a/frame/1m/packm/bli_packm_cxk_rih.c +++ /dev/null @@ -1,151 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname ) \ -\ -void PASTEMAC(ch,opname) \ - ( \ - conj_t conja, \ - pack_t schema, \ - dim_t panel_dim, \ - dim_t panel_dim_max, \ - dim_t panel_len, \ - dim_t panel_len_max, \ - ctype* kappa, \ - ctype* a, inc_t inca, inc_t lda, \ - ctype* p, inc_t ldp, \ - cntx_t* cntx \ - ) \ -{ \ - /* Note that we use panel_dim_max, not panel_dim, to query the packm - kernel function pointer. This means that we always use the same - kernel, even for edge cases. */ \ - num_t dt = PASTEMAC(ch,type); \ - l1mkr_t ker_id = panel_dim_max; \ -\ - PASTECH2(ch,opname,_ker_ft) f; \ -\ - /* Query the context for the packm kernel corresponding to the current - panel dimension, or kernel id. If the id is invalid, the function will - return NULL. */ \ - f = bli_cntx_get_packm_ker_dt( dt, ker_id, cntx ); \ -\ - /* If there exists a kernel implementation for the micro-panel dimension - provided, we invoke the implementation. Otherwise, we use scal2m. */ \ - if ( 0 && f != NULL ) \ - { \ - f \ - ( \ - conja, \ - schema, \ - panel_dim, \ - panel_len, \ - panel_len_max, \ - kappa, \ - a, inca, lda, \ - p, ldp, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Treat the micro-panel as panel_dim x panel_len and column-stored - (unit row stride). */ \ -\ - PASTEMAC(ch,scal2rihs_mxn) \ - ( \ - schema, \ - conja, \ - panel_dim, \ - panel_len, \ - kappa, \ - a, inca, lda, \ - p, 1, ldp \ - ); \ -\ - /* If panel_dim < panel_dim_max, then we zero those unused rows. */ \ - if ( panel_dim != panel_dim_max ) \ - { \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - const dim_t i = panel_dim; \ - const dim_t m_edge = panel_dim_max - i; \ - const dim_t n_edge = panel_len_max; \ - ctype_r* p_edge_r = ( ctype_r* )p + (i )*1; \ -\ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_r, 1, ldp, \ - cntx, \ - NULL \ - ); \ - } \ -\ - /* If panel_len < panel_len_max, then we zero those unused columns. */ \ - if ( panel_len != panel_len_max ) \ - { \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - const dim_t j = panel_len; \ - const dim_t m_edge = panel_dim_max; \ - const dim_t n_edge = panel_len_max - j; \ - ctype_r* p_edge_r = ( ctype_r* )p + (j )*ldp; \ -\ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_r, 1, ldp, \ - cntx, \ - NULL \ - ); \ - } \ - } \ -} - -INSERT_GENTFUNCCO_BASIC0( packm_cxk_rih ) - diff --git a/frame/1m/packm/bli_packm_init.c b/frame/1m/packm/bli_packm_init.c index a23da8c342..c2e0cfe389 100644 --- a/frame/1m/packm/bli_packm_init.c +++ b/frame/1m/packm/bli_packm_init.c @@ -113,52 +113,6 @@ siz_t bli_packm_init return 0; } -#if 0 - pack_t schema; - - if ( bli_cntx_method( cntx ) != BLIS_NAT ) - { - // We now ignore the pack_schema field in the control tree and - // extract the schema from the context, depending on whether we are - // preparing to pack a block of A or panel of B. For A and B, we must - // obtain the schema from the context since the induced methods reuse - // the same control trees used by native execution, and those induced - // methods specify the schema used by the current execution phase - // within the context (whereas the control tree does not change). - - if ( pack_buf_type == BLIS_BUFFER_FOR_A_BLOCK ) - { - schema = bli_cntx_schema_a_block( cntx ); - } - else if ( pack_buf_type == BLIS_BUFFER_FOR_B_PANEL ) - { - schema = bli_cntx_schema_b_panel( cntx ); - } - else // if ( pack_buf_type == BLIS_BUFFER_FOR_C_PANEL ) - { - schema = bli_cntl_packm_params_pack_schema( cntl ); - } - } - else // ( bli_cntx_method( cntx ) == BLIS_NAT ) - { - // For native execution, we obtain the schema from the control tree - // node. (Notice that it doesn't matter if the pack_buf_type is for - // A or B.) - schema = bli_cntl_packm_params_pack_schema( cntl ); - } - // This is no longer needed now that we branch between native and - // non-native cases above. -#if 0 - if ( pack_buf_type == BLIS_BUFFER_FOR_C_PANEL ) - { - // If we get a request to pack C for some reason, it is likely - // not part of an induced method, and so it would be safe (and - // necessary) to read the pack schema from the control tree. - schema = bli_cntl_packm_params_pack_schema( cntl ); - } -#endif -#endif - // Prepare a few other variables based on properties of the control // tree. @@ -393,7 +347,7 @@ siz_t bli_packm_init_pack bli_is_panel_packed( schema ) ) { dim_t m_panel; - dim_t ps_p, ps_p_orig; + dim_t ps_p; // The panel dimension (for each datatype) should be equal to the // default (logical) blocksize multiple in the m dimension. @@ -418,58 +372,17 @@ siz_t bli_packm_init_pack // dimension of the matrix is not a whole multiple of MR. ps_p = cs_p * n_p_pad; - // As a general rule, we don't want micropanel strides to be odd. This - // is primarily motivated by our desire to support interleaved 3m - // micropanels, in which case we have to scale the panel stride - // by 3/2. That division by 2 means the numerator (prior to being - // scaled by 3) must be even. + // As a general rule, we don't want micropanel strides to be odd. + // NOTE: This safety feature *may* not be necessary anymore, but was + // definitely needed to support certain variations of the 3m method. if ( bli_is_odd( ps_p ) ) ps_p += 1; - // Preserve this early panel stride value for use later, if needed. - ps_p_orig = ps_p; - - // Here, we adjust the panel stride, if necessary. Remember: ps_p is - // always interpreted as being in units of the datatype of the object - // which is not necessarily how the micropanels will be stored. For - // interleaved 3m, we will increase ps_p by 50%, and for ro/io/rpi, - // we halve ps_p. Why? Because the macro-kernel indexes in units of - // the complex datatype. So these changes "trick" it into indexing - // the correct amount. - if ( bli_is_3mi_packed( schema ) ) - { - ps_p = ( ps_p * 3 ) / 2; - } - else if ( bli_is_3ms_packed( schema ) || - bli_is_ro_packed( schema ) || - bli_is_io_packed( schema ) || - bli_is_rpi_packed( schema ) ) - { - // The division by 2 below assumes that ps_p is an even number. - // However, it is possible that, at this point, ps_p is an odd. - // If it is indeed odd, we nudge it higher. - if ( bli_is_odd( ps_p ) ) ps_p += 1; - - // Despite the fact that the packed micropanels will contain - // real elements, the panel stride that we store in the obj_t - // (which is passed into the macro-kernel) needs to be in units - // of complex elements, since the macro-kernel will index through - // micropanels via complex pointer arithmetic for trmm/trsm. - // Since the indexing "increment" will be twice as large as each - // actual stored element, we divide the panel_stride by 2. - ps_p = ps_p / 2; - } - - // Set the imaginary stride (in units of fundamental elements) for - // 3m and 4m (separated or interleaved). We use ps_p_orig since - // that variable tracks the number of real part elements contained - // within each micropanel of the source matrix. Therefore, this - // is the number of real elements that must be traversed before - // reaching the imaginary part (3mi/4mi) of the packed micropanel, - // or the real part of the next micropanel (3ms). - if ( bli_is_3mi_packed( schema ) ) is_p = ps_p_orig; - else if ( bli_is_4mi_packed( schema ) ) is_p = ps_p_orig; - else if ( bli_is_3ms_packed( schema ) ) is_p = ps_p_orig * ( m_p_pad / m_panel ); - else is_p = 1; + // Set the imaginary stride (in units of fundamental elements). + // This is the number of real elements that must be traversed before + // reaching the imaginary part of the packed micropanel. NOTE: the + // imaginary stride is mostly vestigial and left over from the 3m + // and 4m implementations. + is_p = 1; // Store the strides and panel dimension in P. bli_obj_set_strides( rs_p, cs_p, p ); @@ -486,7 +399,7 @@ siz_t bli_packm_init_pack bli_is_panel_packed( schema ) ) { dim_t n_panel; - dim_t ps_p, ps_p_orig; + dim_t ps_p; // The panel dimension (for each datatype) should be equal to the // default (logical) blocksize multiple in the n dimension. @@ -512,58 +425,17 @@ siz_t bli_packm_init_pack // dimension of the matrix is not a whole multiple of NR. ps_p = m_p_pad * rs_p; - // As a general rule, we don't want micropanel strides to be odd. This - // is primarily motivated by our desire to support interleaved 3m - // micropanels, in which case we have to scale the panel stride - // by 3/2. That division by 2 means the numerator (prior to being - // scaled by 3) must be even. + // As a general rule, we don't want micropanel strides to be odd. + // NOTE: This safety feature *may* not be necessary anymore, but was + // definitely needed to support certain variations of the 3m method. if ( bli_is_odd( ps_p ) ) ps_p += 1; - // Preserve this early panel stride value for use later, if needed. - ps_p_orig = ps_p; - - // Here, we adjust the panel stride, if necessary. Remember: ps_p is - // always interpreted as being in units of the datatype of the object - // which is not necessarily how the micropanels will be stored. For - // interleaved 3m, we will increase ps_p by 50%, and for ro/io/rpi, - // we halve ps_p. Why? Because the macro-kernel indexes in units of - // the complex datatype. So these changes "trick" it into indexing - // the correct amount. - if ( bli_is_3mi_packed( schema ) ) - { - ps_p = ( ps_p * 3 ) / 2; - } - else if ( bli_is_3ms_packed( schema ) || - bli_is_ro_packed( schema ) || - bli_is_io_packed( schema ) || - bli_is_rpi_packed( schema ) ) - { - // The division by 2 below assumes that ps_p is an even number. - // However, it is possible that, at this point, ps_p is an odd. - // If it is indeed odd, we nudge it higher. - if ( bli_is_odd( ps_p ) ) ps_p += 1; - - // Despite the fact that the packed micropanels will contain - // real elements, the panel stride that we store in the obj_t - // (which is passed into the macro-kernel) needs to be in units - // of complex elements, since the macro-kernel will index through - // micropanels via complex pointer arithmetic for trmm/trsm. - // Since the indexing "increment" will be twice as large as each - // actual stored element, we divide the panel_stride by 2. - ps_p = ps_p / 2; - } - - // Set the imaginary stride (in units of fundamental elements) for - // 3m and 4m (separated or interleaved). We use ps_p_orig since - // that variable tracks the number of real part elements contained - // within each micropanel of the source matrix. Therefore, this - // is the number of real elements that must be traversed before - // reaching the imaginary part (3mi/4mi) of the packed micropanel, - // or the real part of the next micropanel (3ms). - if ( bli_is_3mi_packed( schema ) ) is_p = ps_p_orig; - else if ( bli_is_4mi_packed( schema ) ) is_p = ps_p_orig; - else if ( bli_is_3ms_packed( schema ) ) is_p = ps_p_orig * ( n_p_pad / n_panel ); - else is_p = 1; + // Set the imaginary stride (in units of fundamental elements). + // This is the number of real elements that must be traversed before + // reaching the imaginary part of the packed micropanel. NOTE: the + // imaginary stride is mostly vestigial and left over from the 3m + // and 4m implementations. + is_p = 1; // Store the strides and panel dimension in P. bli_obj_set_strides( rs_p, cs_p, p ); diff --git a/frame/1m/packm/bli_packm_struc_cxk_3mis.c b/frame/1m/packm/bli_packm_struc_cxk_3mis.c deleted file mode 100644 index 95908c8e7b..0000000000 --- a/frame/1m/packm/bli_packm_struc_cxk_3mis.c +++ /dev/null @@ -1,842 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, varname, kername ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - struc_t strucc, \ - doff_t diagoffc, \ - diag_t diagc, \ - uplo_t uploc, \ - conj_t conjc, \ - pack_t schema, \ - bool invdiag, \ - dim_t m_panel, \ - dim_t n_panel, \ - dim_t m_panel_max, \ - dim_t n_panel_max, \ - ctype* restrict kappa, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - ctype* restrict p, inc_t rs_p, inc_t cs_p, \ - inc_t is_p, \ - cntx_t* cntx \ - ) \ -{ \ - dim_t panel_dim; \ - dim_t panel_dim_max; \ - dim_t panel_len; \ - dim_t panel_len_max; \ - inc_t incc, ldc; \ - inc_t ldp; \ -\ -\ - /* Determine the dimensions and relative strides of the micro-panel - based on its pack schema. */ \ - if ( bli_is_col_packed( schema ) ) \ - { \ - /* Prepare to pack to row-stored column panel. */ \ - panel_dim = n_panel; \ - panel_dim_max = n_panel_max; \ - panel_len = m_panel; \ - panel_len_max = m_panel_max; \ - incc = cs_c; \ - ldc = rs_c; \ - ldp = rs_p; \ - } \ - else /* if ( bli_is_row_packed( schema ) ) */ \ - { \ - /* Prepare to pack to column-stored row panel. */ \ - panel_dim = m_panel; \ - panel_dim_max = m_panel_max; \ - panel_len = n_panel; \ - panel_len_max = n_panel_max; \ - incc = rs_c; \ - ldc = cs_c; \ - ldp = cs_p; \ - } \ -\ -\ - /* Handle micro-panel packing based on the structure of the matrix - being packed. */ \ - if ( bli_is_general( strucc ) ) \ - { \ - /* For micro-panels of general matrices, we can call the pack - kernel front-end directly. */ \ - PASTEMAC(ch,kername) \ - ( \ - conjc, \ - panel_dim, \ - panel_dim_max, \ - panel_len, \ - panel_len_max, \ - kappa, \ - c, incc, ldc, \ - p, is_p, ldp, \ - cntx \ - ); \ - } \ - else if ( bli_is_herm_or_symm( strucc ) ) \ - { \ - /* Call a helper function for micro-panels of Hermitian/symmetric - matrices. */ \ - PASTEMAC(ch,packm_herm_cxk_3mis) \ - ( \ - strucc, \ - diagoffc, \ - uploc, \ - conjc, \ - schema, \ - m_panel, \ - n_panel, \ - m_panel_max, \ - n_panel_max, \ - panel_dim, \ - panel_dim_max, \ - panel_len, \ - panel_len_max, \ - kappa, \ - c, rs_c, cs_c, \ - incc, ldc, \ - p, rs_p, cs_p, \ - is_p, ldp, \ - cntx \ - ); \ - } \ - else /* ( bli_is_triangular( strucc ) ) */ \ - { \ - /* Call a helper function for micro-panels of triangular - matrices. */ \ - PASTEMAC(ch,packm_tri_cxk_3mis) \ - ( \ - strucc, \ - diagoffc, \ - diagc, \ - uploc, \ - conjc, \ - schema, \ - invdiag, \ - m_panel, \ - n_panel, \ - m_panel_max, \ - n_panel_max, \ - panel_dim, \ - panel_dim_max, \ - panel_len, \ - panel_len_max, \ - kappa, \ - c, rs_c, cs_c, \ - incc, ldc, \ - p, rs_p, cs_p, \ - is_p, ldp, \ - cntx \ - ); \ - } \ -\ -\ - /* If m_panel < m_panel_max, or n_panel < n_panel_max, we would normally - fill the edge region (the bottom m_panel_max - m_panel rows or right- - side n_panel_max - n_panel columns) of the micropanel with zeros. - However, this responsibility has been moved to the packm microkernel. - This change allows experts to use custom kernels that pack to custom - packing formats when the problem size is not a nice multiple of the - register blocksize. */ \ -/* - if ( m_panel != m_panel_max ) \ - { \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - dim_t i = m_panel; \ - dim_t m_edge = m_panel_max - i; \ - dim_t n_edge = n_panel_max; \ - ctype_r* p_edge_r = ( ctype_r* )p + (i )*rs_p; \ - ctype_r* p_edge_i = ( ctype_r* )p + is_p + (i )*rs_p; \ - ctype_r* p_edge_rpi = ( ctype_r* )p + 2*is_p + (i )*rs_p; \ -\ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_r, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_i, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_rpi, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ - } \ -*/ \ -\ -/* - if ( n_panel != n_panel_max ) \ - { \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - dim_t j = n_panel; \ - dim_t m_edge = m_panel_max; \ - dim_t n_edge = n_panel_max - j; \ - ctype_r* p_edge_r = ( ctype_r* )p + (j )*cs_p; \ - ctype_r* p_edge_i = ( ctype_r* )p + is_p + (j )*cs_p; \ - ctype_r* p_edge_rpi = ( ctype_r* )p + 2*is_p + (j )*cs_p; \ -\ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_r, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_i, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_rpi, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ - } \ -*/ \ -\ -\ - if ( bli_is_triangular( strucc ) ) \ - { \ - /* If this panel is an edge case in both panel dimension and length, - then it must be a bottom-right corner case. Set the part of the - diagonal that extends into the zero-padded region to identity. - NOTE: This is actually only necessary when packing for trsm, as - it helps prevent NaNs and Infs from creeping into the computation. - However, we set the region to identity for trmm as well. Those - 1.0's end up getting muliplied by the 0.0's in the zero-padded - region of the other matrix, so there is no harm in this. */ \ - if ( m_panel != m_panel_max && \ - n_panel != n_panel_max ) \ - { \ - ctype_r* restrict one_r = PASTEMAC(chr,1); \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - dim_t i = m_panel; \ - dim_t j = n_panel; \ - dim_t m_br = m_panel_max - i; \ - dim_t n_br = n_panel_max - j; \ - ctype_r* p_br_r = ( ctype_r* )p + (i )*rs_p + (j )*cs_p; \ - ctype_r* p_br_i = ( ctype_r* )p + is_p + (i )*rs_p + (j )*cs_p; \ -\ - PASTEMAC2(chr,setd,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - m_br, \ - n_br, \ - one_r, \ - p_br_r, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setd,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - m_br, \ - n_br, \ - zero_r, \ - p_br_i, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ - } \ - } \ -} - -INSERT_GENTFUNCCO_BASIC( packm_struc_cxk_3mis, packm_cxk_3mis ) - - - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, varname, kername ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - struc_t strucc, \ - doff_t diagoffc, \ - uplo_t uploc, \ - conj_t conjc, \ - pack_t schema, \ - dim_t m_panel, \ - dim_t n_panel, \ - dim_t m_panel_max, \ - dim_t n_panel_max, \ - dim_t panel_dim, \ - dim_t panel_dim_max, \ - dim_t panel_len, \ - dim_t panel_len_max, \ - ctype* restrict kappa, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - inc_t incc, inc_t ldc, \ - ctype* restrict p, inc_t rs_p, inc_t cs_p, \ - inc_t is_p, inc_t ldp, \ - cntx_t* cntx \ - ) \ -{ \ - doff_t diagoffc_abs; \ - dim_t i, j; \ - bool row_stored; \ - bool col_stored; \ -\ -\ - /* Create flags to incidate row or column storage. Note that the - schema bit that encodes row or column is describing the form of - micro-panel, not the storage in the micro-panel. Hence the - mismatch in "row" and "column" semantics. */ \ - row_stored = bli_is_col_packed( schema ); \ - col_stored = bli_is_row_packed( schema ); \ -\ -\ - /* Handle the case where the micro-panel does NOT intersect the - diagonal separately from the case where it does intersect. */ \ - if ( !bli_intersects_diag_n( diagoffc, m_panel, n_panel ) ) \ - { \ - /* If the current panel is unstored, we need to make a few - adjustments so we refer to the data where it is actually - stored, also taking conjugation into account. (Note this - implicitly assumes we are operating on a dense panel - within a larger symmetric or Hermitian matrix, since a - general matrix would not contain any unstored region.) */ \ - if ( bli_is_unstored_subpart_n( diagoffc, uploc, m_panel, n_panel ) ) \ - { \ - c = c + diagoffc * ( doff_t )cs_c + \ - -diagoffc * ( doff_t )rs_c; \ - bli_swap_incs( &incc, &ldc ); \ -\ - if ( bli_is_hermitian( strucc ) ) \ - bli_toggle_conj( &conjc ); \ - } \ -\ - /* Pack the full panel. */ \ - PASTEMAC(ch,kername) \ - ( \ - conjc, \ - panel_dim, \ - panel_dim_max, \ - panel_len, \ - panel_len_max, \ - kappa, \ - c, incc, ldc, \ - p, is_p, ldp, \ - cntx \ - ); \ - } \ - else /* if ( bli_intersects_diag_n( diagoffc, m_panel, n_panel ) ) */ \ - { \ - ctype_r* restrict p_r = ( ctype_r* )p; \ -\ - ctype_r* restrict one_r = PASTEMAC(chr,1); \ - ctype_r* restrict minus_one_r = PASTEMAC(chr,m1); \ -\ - ctype* restrict c10; \ - ctype_r* restrict p10; \ - dim_t p10_dim, p10_len; \ - inc_t incc10, ldc10; \ - doff_t diagoffc10; \ - conj_t conjc10; \ -\ - ctype* restrict c12; \ - ctype_r* restrict p12; \ - dim_t p12_dim, p12_len; \ - inc_t incc12, ldc12; \ - doff_t diagoffc12; \ - conj_t conjc12; \ -\ - /* Sanity check. Diagonals should not intersect the short end of - a micro-panel. If they do, then somehow the constraints on - cache blocksizes being a whole multiple of the register - blocksizes was somehow violated. */ \ - if ( ( col_stored && diagoffc < 0 ) || \ - ( row_stored && diagoffc > 0 ) ) \ - bli_check_error_code( BLIS_NOT_YET_IMPLEMENTED ); \ -\ - diagoffc_abs = bli_abs( diagoffc ); \ -\ - if ( ( row_stored && bli_is_upper( uploc ) ) || \ - ( col_stored && bli_is_lower( uploc ) ) ) \ - { \ - p10_dim = panel_dim; \ - p10_len = diagoffc_abs; \ - p10 = p_r; \ - c10 = c; \ - incc10 = incc; \ - ldc10 = ldc; \ - conjc10 = conjc; \ -\ - p12_dim = panel_dim; \ - p12_len = panel_len - p10_len; \ - j = p10_len; \ - diagoffc12 = diagoffc_abs - j; \ - p12 = p_r + (j )*ldp; \ - c12 = c + (j )*ldc; \ - c12 = c12 + diagoffc12 * ( doff_t )cs_c + \ - -diagoffc12 * ( doff_t )rs_c; \ - incc12 = ldc; \ - ldc12 = incc; \ - conjc12 = conjc; \ -\ - if ( bli_is_hermitian( strucc ) ) \ - bli_toggle_conj( &conjc12 ); \ - } \ - else /* if ( ( row_stored && bli_is_lower( uploc ) ) || \ - ( col_stored && bli_is_upper( uploc ) ) ) */ \ - { \ - p10_dim = panel_dim; \ - p10_len = diagoffc_abs + panel_dim; \ - diagoffc10 = diagoffc; \ - p10 = p_r; \ - c10 = c; \ - c10 = c10 + diagoffc10 * ( doff_t )cs_c + \ - -diagoffc10 * ( doff_t )rs_c; \ - incc10 = ldc; \ - ldc10 = incc; \ - conjc10 = conjc; \ -\ - p12_dim = panel_dim; \ - p12_len = panel_len - p10_len; \ - j = p10_len; \ - p12 = p_r + (j )*ldp; \ - c12 = c + (j )*ldc; \ - incc12 = incc; \ - ldc12 = ldc; \ - conjc12 = conjc; \ -\ - if ( bli_is_hermitian( strucc ) ) \ - bli_toggle_conj( &conjc10 ); \ - } \ -\ - /* Pack to p10. For upper storage, this includes the unstored - triangle of c11. */ \ - /* NOTE: Since we're only packing partial panels here, we pass in - p1x_len as panel_len_max; otherwise, the packm kernel will zero- - fill the columns up to panel_len_max, which is not what we need - or want to happen. */ \ - PASTEMAC(ch,kername) \ - ( \ - conjc10, \ - p10_dim, \ - panel_dim_max, \ - p10_len, \ - p10_len, \ - kappa, \ - c10, incc10, ldc10, \ - ( ctype* )p10, is_p, ldp, \ - cntx \ - ); \ -\ - /* Pack to p12. For lower storage, this includes the unstored - triangle of c11. */ \ - /* NOTE: Since we're only packing partial panels here, we pass in - p1x_len as panel_len_max; otherwise, the packm kernel will zero- - fill the columns up to panel_len_max, which is not what we need - or want to happen. */ \ - PASTEMAC(ch,kername) \ - ( \ - conjc12, \ - p12_dim, \ - panel_dim_max, \ - p12_len, \ - p12_len, \ - kappa, \ - c12, incc12, ldc12, \ - ( ctype* )p12, is_p, ldp, \ - cntx \ - ); \ -\ - /* Pack the stored triangle of c11 to p11. */ \ - { \ - dim_t p11_m = panel_dim; \ - dim_t p11_n = panel_dim; \ - inc_t rs_c11 = 2*rs_c; \ - inc_t cs_c11 = 2*cs_c; \ - dim_t j2 = diagoffc_abs; \ - ctype* c11 = ( ctype* )c + (j2 )*ldc; \ - ctype_r* p11 = ( ctype_r* )p_r + (j2 )*ldp; \ - ctype_r* c11_r = ( ctype_r* )c11; \ - ctype_r* c11_i = ( ctype_r* )c11 + 1; \ - ctype_r* p11_r = ( ctype_r* )p11; \ - ctype_r* p11_i = ( ctype_r* )p11 + is_p; \ - ctype_r* alpha_r = one_r; \ - ctype_r* alpha_i = ( bli_is_conj( conjc ) ? minus_one_r : one_r ); \ - ctype_r kappa_r = PASTEMAC(ch,real)( *kappa ); \ - ctype_r kappa_i = PASTEMAC(ch,imag)( *kappa ); \ -\ - /* Copy the real part of the stored triangle of c11 to p11_r. */ \ - PASTEMAC2(chr,scal2m,BLIS_TAPI_EX_SUF) \ - ( \ - 0, \ - BLIS_NONUNIT_DIAG, \ - uploc, \ - BLIS_NO_TRANSPOSE, \ - p11_m, \ - p11_n, \ - alpha_r, \ - c11_r, rs_c11, cs_c11, \ - p11_r, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ -\ - /* Copy the imaginary part of the stored triangle of c11 to p11_i, - scaling by -1 if conjugation on c was requested. */ \ - PASTEMAC2(chr,scal2m,BLIS_TAPI_EX_SUF) \ - ( \ - 0, \ - BLIS_NONUNIT_DIAG, \ - uploc, \ - BLIS_NO_TRANSPOSE, \ - p11_m, \ - p11_n, \ - alpha_i, \ - c11_i, rs_c11, cs_c11, \ - p11_i, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ -\ - /* If source matrix c is Hermitian, we have to zero out the - imaginary components of the diagonal of p11 in case the - corresponding elements in c11 were not already zero. */ \ - if ( bli_is_hermitian( strucc ) ) \ - { \ - for ( i = 0; i < p11_m; ++i ) \ - { \ - ctype_r* pi11_i = p11_i + (i )*rs_p + (i )*cs_p; \ -\ - PASTEMAC(chr,set0s)( *pi11_i ); \ - } \ - } \ -\ - /* Apply kappa to the part of p11 that corresponds to the stored - part of c11 that was copied above. */ \ - if ( bli_is_upper( uploc ) ) \ - { \ - PASTEMAC(ch,scalris_mxn_u) \ - ( \ - 0, \ - p11_m, \ - p11_n, \ - &kappa_r, \ - &kappa_i, \ - p11_r, \ - p11_i, rs_p, cs_p \ - ); \ - } \ - else \ - { \ - PASTEMAC(ch,scalris_mxn_l) \ - ( \ - 0, \ - p11_m, \ - p11_n, \ - &kappa_r, \ - &kappa_i, \ - p11_r, \ - p11_i, rs_p, cs_p \ - ); \ - } \ -\ - /* Update the p11 section of the ri panel. It simply needs - to contain the sum of p11_r + p11_i. */ \ - { \ - ctype_r* p11_rpi = p11_i + is_p; \ -\ - for ( j = 0; j < p11_n; ++j ) \ - for ( i = 0; i < p11_m; ++i ) \ - { \ - ctype_r* pi11_r = p11_r + (i )*rs_p + (j )*cs_p; \ - ctype_r* pi11_i = p11_i + (i )*rs_p + (j )*cs_p; \ - ctype_r* pi11_rpi = p11_rpi + (i )*rs_p + (j )*cs_p; \ -\ - PASTEMAC(chr,add3s) \ - ( \ - *pi11_r, \ - *pi11_i, \ - *pi11_rpi \ - ); \ - } \ - } \ - } \ - } \ -} - -INSERT_GENTFUNCCO_BASIC( packm_herm_cxk_3mis, packm_cxk_3mis ) - - - - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, varname, kername ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - struc_t strucc, \ - doff_t diagoffp, \ - diag_t diagc, \ - uplo_t uploc, \ - conj_t conjc, \ - pack_t schema, \ - bool invdiag, \ - dim_t m_panel, \ - dim_t n_panel, \ - dim_t m_panel_max, \ - dim_t n_panel_max, \ - dim_t panel_dim, \ - dim_t panel_dim_max, \ - dim_t panel_len, \ - dim_t panel_len_max, \ - ctype* restrict kappa, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - inc_t incc, inc_t ldc, \ - ctype* restrict p, inc_t rs_p, inc_t cs_p, \ - inc_t is_p, inc_t ldp, \ - cntx_t* cntx \ - ) \ -{ \ - /* Pack the panel. */ \ - PASTEMAC(ch,kername) \ - ( \ - conjc, \ - panel_dim, \ - panel_dim_max, \ - panel_len, \ - panel_len_max, \ - kappa, \ - c, incc, ldc, \ - p, is_p, ldp, \ - cntx \ - ); \ -\ -\ - /* Tweak the panel according to its triangular structure */ \ - { \ - ctype_r* p_r = ( ctype_r* )p + 0; \ - ctype_r* p_i = ( ctype_r* )p + is_p; \ - ctype_r* p_rpi = ( ctype_r* )p + 2*is_p; \ -\ - dim_t j = bli_abs( diagoffp ); \ - ctype_r* p11_r = p_r + (j )*ldp; \ - ctype_r* p11_i = p_i + (j )*ldp; \ - ctype_r* p11_rpi = p_rpi + (j )*ldp; \ -\ - dim_t p11_m = m_panel; \ - dim_t p11_n = n_panel; \ -\ - dim_t min_p11_m_n; \ -\ - if ( diagoffp < 0 ) p11_m -= j; \ - else if ( diagoffp > 0 ) p11_n -= j; \ -\ - min_p11_m_n = bli_min( p11_m, p11_n ); \ -\ -\ - /* If the diagonal of c is implicitly unit, explicitly set the - the diagonal of the packed panel to kappa. */ \ - if ( bli_is_unit_diag( diagc ) ) \ - { \ - ctype_r kappa_r = PASTEMAC(ch,real)( *kappa ); \ - ctype_r kappa_i = PASTEMAC(ch,imag)( *kappa ); \ - dim_t i; \ -\ - PASTEMAC2(chr,setd,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - diagoffp, \ - m_panel, \ - n_panel, \ - &kappa_r, \ - p_r, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setd,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - diagoffp, \ - m_panel, \ - n_panel, \ - &kappa_i, \ - p_i, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ -\ - /* Update the diagonal of the p11 section of the rpi panel. - It simply needs to contain the sum of diagonals of p11_r - and p11_i. */ \ - for ( i = 0; i < min_p11_m_n; ++i ) \ - { \ - ctype_r* pi11_r = p11_r + (i )*rs_p + (i )*cs_p; \ - ctype_r* pi11_i = p11_i + (i )*rs_p + (i )*cs_p; \ - ctype_r* pi11_rpi = p11_rpi + (i )*rs_p + (i )*cs_p; \ -\ - PASTEMAC(chr,add3s)( *pi11_r, *pi11_i, *pi11_rpi ); \ - } \ - } \ -\ - /* If requested, invert the diagonal of the packed panel. Note - that we do not need to update the ri panel since inverted - diagonals are only needed by trsm, which does not use the - p11 section of the ri panel. */ \ - if ( invdiag == TRUE ) \ - { \ - dim_t i; \ -\ - for ( i = 0; i < min_p11_m_n; ++i ) \ - { \ - ctype_r* pi11_r = p11_r + (i )*rs_p + (i )*cs_p; \ - ctype_r* pi11_i = p11_i + (i )*rs_p + (i )*cs_p; \ -\ - PASTEMAC(ch,invertris)( *pi11_r, *pi11_i ); \ - } \ - } \ -\ - /* Set the region opposite the diagonal of p to zero. To do this, - we need to reference the "unstored" region on the other side of - the diagonal. This amounts to toggling uploc and then shifting - the diagonal offset to shrink the newly referenced region (by - one diagonal). Note that this zero-filling is not needed for - trsm, since the unstored region is not referenced by the trsm - micro-kernel; however, zero-filling is needed for trmm, which - uses the gemm micro-kernel.*/ \ - { \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - uplo_t uplop = uploc; \ -\ - bli_toggle_uplo( &uplop ); \ - bli_shift_diag_offset_to_shrink_uplo( uplop, &diagoffp ); \ -\ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - diagoffp, \ - BLIS_NONUNIT_DIAG, \ - uplop, \ - m_panel, \ - n_panel, \ - zero_r, \ - p_r, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - diagoffp, \ - BLIS_NONUNIT_DIAG, \ - uplop, \ - m_panel, \ - n_panel, \ - zero_r, \ - p_i, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - diagoffp, \ - BLIS_NONUNIT_DIAG, \ - uplop, \ - m_panel, \ - n_panel, \ - zero_r, \ - p_rpi, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ - } \ - } \ -} - -INSERT_GENTFUNCCO_BASIC( packm_tri_cxk_3mis, packm_cxk_3mis ) - diff --git a/frame/1m/packm/bli_packm_struc_cxk_3mis.h b/frame/1m/packm/bli_packm_struc_cxk_3mis.h deleted file mode 100644 index 01c8510a43..0000000000 --- a/frame/1m/packm/bli_packm_struc_cxk_3mis.h +++ /dev/null @@ -1,121 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#undef GENTPROTCO -#define GENTPROTCO( ctype, ctype_r, ch, chr, varname ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - struc_t strucc, \ - doff_t diagoffp, \ - diag_t diagc, \ - uplo_t uploc, \ - conj_t conjc, \ - pack_t schema, \ - bool invdiag, \ - dim_t m_panel, \ - dim_t n_panel, \ - dim_t m_panel_max, \ - dim_t n_panel_max, \ - ctype* restrict kappa, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - ctype* restrict p, inc_t rs_p, inc_t cs_p, \ - inc_t is_p, \ - cntx_t* cntx \ - ); - -INSERT_GENTPROTCO_BASIC0( packm_struc_cxk_3mis ) - - - -#undef GENTPROTCO -#define GENTPROTCO( ctype, ctype_r, ch, chr, varname ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - struc_t strucc, \ - doff_t diagoffc, \ - uplo_t uploc, \ - conj_t conjc, \ - pack_t schema, \ - dim_t m_panel, \ - dim_t n_panel, \ - dim_t m_panel_max, \ - dim_t n_panel_max, \ - dim_t panel_dim, \ - dim_t panel_dim_max, \ - dim_t panel_len, \ - dim_t panel_len_max, \ - ctype* restrict kappa, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - inc_t incc, inc_t ldc, \ - ctype* restrict p, inc_t rs_p, inc_t cs_p, \ - inc_t is_p, inc_t ldp, \ - cntx_t* cntx \ - ); - -INSERT_GENTPROTCO_BASIC0( packm_herm_cxk_3mis ) - - - -#undef GENTPROTCO -#define GENTPROTCO( ctype, ctype_r, ch, chr, varname ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - struc_t strucc, \ - doff_t diagoffc, \ - diag_t diagc, \ - uplo_t uploc, \ - conj_t conjc, \ - pack_t schema, \ - bool invdiag, \ - dim_t m_panel, \ - dim_t n_panel, \ - dim_t m_panel_max, \ - dim_t n_panel_max, \ - dim_t panel_dim, \ - dim_t panel_dim_max, \ - dim_t panel_len, \ - dim_t panel_len_max, \ - ctype* restrict kappa, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - inc_t incc, inc_t ldc, \ - ctype* restrict p, inc_t rs_p, inc_t cs_p, \ - inc_t is_p, inc_t ldp, \ - cntx_t* cntx \ - ); - -INSERT_GENTPROTCO_BASIC0( packm_tri_cxk_3mis ) - diff --git a/frame/1m/packm/bli_packm_struc_cxk_4mi.c b/frame/1m/packm/bli_packm_struc_cxk_4mi.c deleted file mode 100644 index 62c2d5086d..0000000000 --- a/frame/1m/packm/bli_packm_struc_cxk_4mi.c +++ /dev/null @@ -1,757 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, varname, kername ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - struc_t strucc, \ - doff_t diagoffc, \ - diag_t diagc, \ - uplo_t uploc, \ - conj_t conjc, \ - pack_t schema, \ - bool invdiag, \ - dim_t m_panel, \ - dim_t n_panel, \ - dim_t m_panel_max, \ - dim_t n_panel_max, \ - ctype* restrict kappa, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - ctype* restrict p, inc_t rs_p, inc_t cs_p, \ - inc_t is_p, \ - cntx_t* cntx \ - ) \ -{ \ - dim_t panel_dim; \ - dim_t panel_dim_max; \ - dim_t panel_len; \ - dim_t panel_len_max; \ - inc_t incc, ldc; \ - inc_t ldp; \ -\ -\ - /* Determine the dimensions and relative strides of the micro-panel - based on its pack schema. */ \ - if ( bli_is_col_packed( schema ) ) \ - { \ - /* Prepare to pack to row-stored column panel. */ \ - panel_dim = n_panel; \ - panel_dim_max = n_panel_max; \ - panel_len = m_panel; \ - panel_len_max = m_panel_max; \ - incc = cs_c; \ - ldc = rs_c; \ - ldp = rs_p; \ - } \ - else /* if ( bli_is_row_packed( schema ) ) */ \ - { \ - /* Prepare to pack to column-stored row panel. */ \ - panel_dim = m_panel; \ - panel_dim_max = m_panel_max; \ - panel_len = n_panel; \ - panel_len_max = n_panel_max; \ - incc = rs_c; \ - ldc = cs_c; \ - ldp = cs_p; \ - } \ -\ -\ - /* Handle micro-panel packing based on the structure of the matrix - being packed. */ \ - if ( bli_is_general( strucc ) ) \ - { \ - /* For micro-panels of general matrices, we can call the pack - kernel front-end directly. */ \ - PASTEMAC(ch,kername) \ - ( \ - conjc, \ - panel_dim, \ - panel_dim_max, \ - panel_len, \ - panel_len_max, \ - kappa, \ - c, incc, ldc, \ - p, is_p, ldp, \ - cntx \ - ); \ - } \ - else if ( bli_is_herm_or_symm( strucc ) ) \ - { \ - /* Call a helper function for micro-panels of Hermitian/symmetric - matrices. */ \ - PASTEMAC(ch,packm_herm_cxk_4mi) \ - ( \ - strucc, \ - diagoffc, \ - uploc, \ - conjc, \ - schema, \ - m_panel, \ - n_panel, \ - m_panel_max, \ - n_panel_max, \ - panel_dim, \ - panel_dim_max, \ - panel_len, \ - panel_len_max, \ - kappa, \ - c, rs_c, cs_c, \ - incc, ldc, \ - p, rs_p, cs_p, \ - is_p, ldp, \ - cntx \ - ); \ - } \ - else /* ( bli_is_triangular( strucc ) ) */ \ - { \ - /* Call a helper function for micro-panels of triangular - matrices. */ \ - PASTEMAC(ch,packm_tri_cxk_4mi) \ - ( \ - strucc, \ - diagoffc, \ - diagc, \ - uploc, \ - conjc, \ - schema, \ - invdiag, \ - m_panel, \ - n_panel, \ - m_panel_max, \ - n_panel_max, \ - panel_dim, \ - panel_dim_max, \ - panel_len, \ - panel_len_max, \ - kappa, \ - c, rs_c, cs_c, \ - incc, ldc, \ - p, rs_p, cs_p, \ - is_p, ldp, \ - cntx \ - ); \ - } \ -\ -\ - /* If m_panel < m_panel_max, or n_panel < n_panel_max, we would normally - fill the edge region (the bottom m_panel_max - m_panel rows or right- - side n_panel_max - n_panel columns) of the micropanel with zeros. - However, this responsibility has been moved to the packm microkernel. - This change allows experts to use custom kernels that pack to custom - packing formats when the problem size is not a nice multiple of the - register blocksize. */ \ -/* - if ( m_panel != m_panel_max ) \ - { \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - dim_t i = m_panel; \ - dim_t m_edge = m_panel_max - i; \ - dim_t n_edge = n_panel_max; \ - ctype_r* p_edge_r = ( ctype_r* )p + (i )*rs_p; \ - ctype_r* p_edge_i = ( ctype_r* )p + is_p + (i )*rs_p; \ -\ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_r, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_i, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ - } \ -\ - if ( n_panel != n_panel_max ) \ - { \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - dim_t j = n_panel; \ - dim_t m_edge = m_panel_max; \ - dim_t n_edge = n_panel_max - j; \ - ctype_r* p_edge_r = ( ctype_r* )p + (j )*cs_p; \ - ctype_r* p_edge_i = ( ctype_r* )p + is_p + (j )*cs_p; \ -\ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_r, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_i, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ - } \ -*/ \ -\ -\ - if ( bli_is_triangular( strucc ) ) \ - { \ - /* If this panel is an edge case in both panel dimension and length, - then it must be a bottom-right corner case. Set the part of the - diagonal that extends into the zero-padded region to identity. - NOTE: This is actually only necessary when packing for trsm, as - it helps prevent NaNs and Infs from creeping into the computation. - However, we set the region to identity for trmm as well. Those - 1.0's end up getting muliplied by the 0.0's in the zero-padded - region of the other matrix, so there is no harm in this. */ \ - if ( m_panel != m_panel_max && \ - n_panel != n_panel_max ) \ - { \ - ctype_r* restrict one_r = PASTEMAC(chr,1); \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - dim_t i = m_panel; \ - dim_t j = n_panel; \ - dim_t m_br = m_panel_max - i; \ - dim_t n_br = n_panel_max - j; \ - ctype_r* p_br_r = ( ctype_r* )p + (i )*rs_p + (j )*cs_p; \ - ctype_r* p_br_i = ( ctype_r* )p + is_p + (i )*rs_p + (j )*cs_p; \ -\ - PASTEMAC2(chr,setd,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - m_br, \ - n_br, \ - one_r, \ - p_br_r, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setd,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - m_br, \ - n_br, \ - zero_r, \ - p_br_i, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ - } \ - } \ -} - -INSERT_GENTFUNCCO_BASIC( packm_struc_cxk_4mi, packm_cxk_4mi ) - - - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, varname, kername ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - struc_t strucc, \ - doff_t diagoffc, \ - uplo_t uploc, \ - conj_t conjc, \ - pack_t schema, \ - dim_t m_panel, \ - dim_t n_panel, \ - dim_t m_panel_max, \ - dim_t n_panel_max, \ - dim_t panel_dim, \ - dim_t panel_dim_max, \ - dim_t panel_len, \ - dim_t panel_len_max, \ - ctype* restrict kappa, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - inc_t incc, inc_t ldc, \ - ctype* restrict p, inc_t rs_p, inc_t cs_p, \ - inc_t is_p, inc_t ldp, \ - cntx_t* cntx \ - ) \ -{ \ - doff_t diagoffc_abs; \ - dim_t i, j; \ - bool row_stored; \ - bool col_stored; \ -\ -\ - /* Create flags to incidate row or column storage. Note that the - schema bit that encodes row or column is describing the form of - micro-panel, not the storage in the micro-panel. Hence the - mismatch in "row" and "column" semantics. */ \ - row_stored = bli_is_col_packed( schema ); \ - col_stored = bli_is_row_packed( schema ); \ -\ -\ - /* Handle the case where the micro-panel does NOT intersect the - diagonal separately from the case where it does intersect. */ \ - if ( !bli_intersects_diag_n( diagoffc, m_panel, n_panel ) ) \ - { \ - /* If the current panel is unstored, we need to make a few - adjustments so we refer to the data where it is actually - stored, also taking conjugation into account. (Note this - implicitly assumes we are operating on a dense panel - within a larger symmetric or Hermitian matrix, since a - general matrix would not contain any unstored region.) */ \ - if ( bli_is_unstored_subpart_n( diagoffc, uploc, m_panel, n_panel ) ) \ - { \ - c = c + diagoffc * ( doff_t )cs_c + \ - -diagoffc * ( doff_t )rs_c; \ - bli_swap_incs( &incc, &ldc ); \ -\ - if ( bli_is_hermitian( strucc ) ) \ - bli_toggle_conj( &conjc ); \ - } \ -\ - /* Pack the full panel. */ \ - PASTEMAC(ch,kername) \ - ( \ - conjc, \ - panel_dim, \ - panel_dim_max, \ - panel_len, \ - panel_len_max, \ - kappa, \ - c, incc, ldc, \ - p, is_p, ldp, \ - cntx \ - ); \ - } \ - else /* if ( bli_intersects_diag_n( diagoffc, m_panel, n_panel ) ) */ \ - { \ - ctype_r* restrict p_r = ( ctype_r* )p; \ -\ - ctype_r* restrict one_r = PASTEMAC(chr,1); \ - ctype_r* restrict minus_one_r = PASTEMAC(chr,m1); \ -\ - ctype* restrict c10; \ - ctype_r* restrict p10; \ - dim_t p10_dim, p10_len; \ - inc_t incc10, ldc10; \ - doff_t diagoffc10; \ - conj_t conjc10; \ -\ - ctype* restrict c12; \ - ctype_r* restrict p12; \ - dim_t p12_dim, p12_len; \ - inc_t incc12, ldc12; \ - doff_t diagoffc12; \ - conj_t conjc12; \ -\ - /* Sanity check. Diagonals should not intersect the short end of - a micro-panel. If they do, then somehow the constraints on - cache blocksizes being a whole multiple of the register - blocksizes was somehow violated. */ \ - if ( ( col_stored && diagoffc < 0 ) || \ - ( row_stored && diagoffc > 0 ) ) \ - bli_check_error_code( BLIS_NOT_YET_IMPLEMENTED ); \ -\ - diagoffc_abs = bli_abs( diagoffc ); \ -\ - if ( ( row_stored && bli_is_upper( uploc ) ) || \ - ( col_stored && bli_is_lower( uploc ) ) ) \ - { \ - p10_dim = panel_dim; \ - p10_len = diagoffc_abs; \ - p10 = p_r; \ - c10 = c; \ - incc10 = incc; \ - ldc10 = ldc; \ - conjc10 = conjc; \ -\ - p12_dim = panel_dim; \ - p12_len = panel_len - p10_len; \ - j = p10_len; \ - diagoffc12 = diagoffc_abs - j; \ - p12 = p_r + (j )*ldp; \ - c12 = c + (j )*ldc; \ - c12 = c12 + diagoffc12 * ( doff_t )cs_c + \ - -diagoffc12 * ( doff_t )rs_c; \ - incc12 = ldc; \ - ldc12 = incc; \ - conjc12 = conjc; \ -\ - if ( bli_is_hermitian( strucc ) ) \ - bli_toggle_conj( &conjc12 ); \ - } \ - else /* if ( ( row_stored && bli_is_lower( uploc ) ) || \ - ( col_stored && bli_is_upper( uploc ) ) ) */ \ - { \ - p10_dim = panel_dim; \ - p10_len = diagoffc_abs + panel_dim; \ - diagoffc10 = diagoffc; \ - p10 = p_r; \ - c10 = c; \ - c10 = c10 + diagoffc10 * ( doff_t )cs_c + \ - -diagoffc10 * ( doff_t )rs_c; \ - incc10 = ldc; \ - ldc10 = incc; \ - conjc10 = conjc; \ -\ - p12_dim = panel_dim; \ - p12_len = panel_len - p10_len; \ - j = p10_len; \ - p12 = p_r + (j )*ldp; \ - c12 = c + (j )*ldc; \ - incc12 = incc; \ - ldc12 = ldc; \ - conjc12 = conjc; \ -\ - if ( bli_is_hermitian( strucc ) ) \ - bli_toggle_conj( &conjc10 ); \ - } \ -\ - /* Pack to p10. For upper storage, this includes the unstored - triangle of c11. */ \ - /* NOTE: Since we're only packing partial panels here, we pass in - p1x_len as panel_len_max; otherwise, the packm kernel will zero- - fill the columns up to panel_len_max, which is not what we need - or want to happen. */ \ - PASTEMAC(ch,kername) \ - ( \ - conjc10, \ - p10_dim, \ - panel_dim_max, \ - p10_len, \ - p10_len, \ - kappa, \ - c10, incc10, ldc10, \ - ( ctype* )p10, is_p, ldp, \ - cntx \ - ); \ -\ - /* Pack to p12. For lower storage, this includes the unstored - triangle of c11. */ \ - /* NOTE: Since we're only packing partial panels here, we pass in - p1x_len as panel_len_max; otherwise, the packm kernel will zero- - fill the columns up to panel_len_max, which is not what we need - or want to happen. */ \ - PASTEMAC(ch,kername) \ - ( \ - conjc12, \ - p12_dim, \ - panel_dim_max, \ - p12_len, \ - p12_len, \ - kappa, \ - c12, incc12, ldc12, \ - ( ctype* )p12, is_p, ldp, \ - cntx \ - ); \ -\ - /* Pack the stored triangle of c11 to p11. */ \ - { \ - dim_t p11_m = panel_dim; \ - dim_t p11_n = panel_dim; \ - inc_t rs_c11 = 2*rs_c; \ - inc_t cs_c11 = 2*cs_c; \ - dim_t j2 = diagoffc_abs; \ - ctype* c11 = ( ctype* )c + (j2 )*ldc; \ - ctype_r* p11 = ( ctype_r* )p_r + (j2 )*ldp; \ - ctype_r* c11_r = ( ctype_r* )c11; \ - ctype_r* c11_i = ( ctype_r* )c11 + 1; \ - ctype_r* p11_r = ( ctype_r* )p11; \ - ctype_r* p11_i = ( ctype_r* )p11 + is_p; \ - ctype_r* alpha_r = one_r; \ - ctype_r* alpha_i = ( bli_is_conj( conjc ) ? minus_one_r : one_r ); \ - ctype_r kappa_r = PASTEMAC(ch,real)( *kappa ); \ - ctype_r kappa_i = PASTEMAC(ch,imag)( *kappa ); \ -\ - /* Copy the real part of the stored triangle of c11 to p11_r. */ \ - PASTEMAC2(chr,scal2m,BLIS_TAPI_EX_SUF) \ - ( \ - 0, \ - BLIS_NONUNIT_DIAG, \ - uploc, \ - BLIS_NO_TRANSPOSE, \ - p11_m, \ - p11_n, \ - alpha_r, \ - c11_r, rs_c11, cs_c11, \ - p11_r, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ -\ - /* Copy the imaginary part of the stored triangle of c11 to p11_i, - scaling by -1 if conjugation on c was requested. */ \ - PASTEMAC2(chr,scal2m,BLIS_TAPI_EX_SUF) \ - ( \ - 0, \ - BLIS_NONUNIT_DIAG, \ - uploc, \ - BLIS_NO_TRANSPOSE, \ - p11_m, \ - p11_n, \ - alpha_i, \ - c11_i, rs_c11, cs_c11, \ - p11_i, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ -\ - /* If source matrix c is Hermitian, we have to zero out the - imaginary components of the diagonal of p11 in case the - corresponding elements in c11 were not already zero. */ \ - if ( bli_is_hermitian( strucc ) ) \ - { \ - for ( i = 0; i < p11_m; ++i ) \ - { \ - ctype_r* pi11_i = p11_i + (i )*rs_p + (i )*cs_p; \ -\ - PASTEMAC(chr,set0s)( *pi11_i ); \ - } \ - } \ -\ - /* Apply kappa to the part of p11 that corresponds to the stored - part of c11 that was copied above. */ \ - if ( bli_is_upper( uploc ) ) \ - { \ - PASTEMAC(ch,scalris_mxn_u) \ - ( \ - 0, \ - p11_m, \ - p11_n, \ - &kappa_r, \ - &kappa_i, \ - p11_r, \ - p11_i, rs_p, cs_p \ - ); \ - } \ - else \ - { \ - PASTEMAC(ch,scalris_mxn_l) \ - ( \ - 0, \ - p11_m, \ - p11_n, \ - &kappa_r, \ - &kappa_i, \ - p11_r, \ - p11_i, rs_p, cs_p \ - ); \ - } \ -/* - PASTEMAC(chr,fprintm)( stdout, "packm_herm_cxk: ap_r copied", m_panel_max, n_panel_max, \ - p_r + 0*is_p, rs_p, cs_p, "%4.1f", "" ); \ - PASTEMAC(chr,fprintm)( stdout, "packm_herm_cxk: ap_i copied", m_panel_max, n_panel_max, \ - p_r + 1*is_p, rs_p, cs_p, "%4.1f", "" ); \ -*/ \ - } \ - } \ -} - -INSERT_GENTFUNCCO_BASIC( packm_herm_cxk_4mi, packm_cxk_4mi ) - - - - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, varname, kername ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - struc_t strucc, \ - doff_t diagoffp, \ - diag_t diagc, \ - uplo_t uploc, \ - conj_t conjc, \ - pack_t schema, \ - bool invdiag, \ - dim_t m_panel, \ - dim_t n_panel, \ - dim_t m_panel_max, \ - dim_t n_panel_max, \ - dim_t panel_dim, \ - dim_t panel_dim_max, \ - dim_t panel_len, \ - dim_t panel_len_max, \ - ctype* restrict kappa, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - inc_t incc, inc_t ldc, \ - ctype* restrict p, inc_t rs_p, inc_t cs_p, \ - inc_t is_p, inc_t ldp, \ - cntx_t* cntx \ - ) \ -{ \ - /* Pack the panel. */ \ - PASTEMAC(ch,kername) \ - ( \ - conjc, \ - panel_dim, \ - panel_dim_max, \ - panel_len, \ - panel_len_max, \ - kappa, \ - c, incc, ldc, \ - p, is_p, ldp, \ - cntx \ - ); \ -\ -\ - /* Tweak the panel according to its triangular structure */ \ - { \ - ctype_r* p_r = ( ctype_r* )p; \ - ctype_r* p_i = ( ctype_r* )p + is_p; \ -\ - dim_t j = bli_abs( diagoffp ); \ - ctype_r* p11_r = p_r + (j )*ldp; \ - ctype_r* p11_i = p_i + (j )*ldp; \ -\ - /* If the diagonal of c is implicitly unit, explicitly set the - the diagonal of the packed panel to kappa. */ \ - if ( bli_is_unit_diag( diagc ) ) \ - { \ - ctype_r kappa_r = PASTEMAC(ch,real)( *kappa ); \ - ctype_r kappa_i = PASTEMAC(ch,imag)( *kappa ); \ -\ - PASTEMAC2(chr,setd,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - diagoffp, \ - m_panel, \ - n_panel, \ - &kappa_r, \ - p_r, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setd,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - diagoffp, \ - m_panel, \ - n_panel, \ - &kappa_i, \ - p_i, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ - } \ -\ -\ - /* If requested, invert the diagonal of the packed panel. */ \ - if ( invdiag == TRUE ) \ - { \ - dim_t i; \ -\ - for ( i = 0; i < panel_dim; ++i ) \ - { \ - ctype_r* pi11_r = p11_r + (i )*rs_p + (i )*cs_p; \ - ctype_r* pi11_i = p11_i + (i )*rs_p + (i )*cs_p; \ -\ - PASTEMAC(ch,invertris)( *pi11_r, *pi11_i ); \ - } \ - } \ -\ -\ - /* Set the region opposite the diagonal of p to zero. To do this, - we need to reference the "unstored" region on the other side of - the diagonal. This amounts to toggling uploc and then shifting - the diagonal offset to shrink the newly referenced region (by - one diagonal). Note that this zero-filling is not needed for - trsm, since the unstored region is not referenced by the trsm - micro-kernel; however, zero-filling is needed for trmm, which - uses the gemm micro-kernel.*/ \ - { \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - uplo_t uplop = uploc; \ -\ - bli_toggle_uplo( &uplop ); \ - bli_shift_diag_offset_to_shrink_uplo( uplop, &diagoffp ); \ -\ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - diagoffp, \ - BLIS_NONUNIT_DIAG, \ - uplop, \ - m_panel, \ - n_panel, \ - zero_r, \ - p_r, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - diagoffp, \ - BLIS_NONUNIT_DIAG, \ - uplop, \ - m_panel, \ - n_panel, \ - zero_r, \ - p_i, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ - } \ - } \ -} - -INSERT_GENTFUNCCO_BASIC( packm_tri_cxk_4mi, packm_cxk_4mi ) - diff --git a/frame/1m/packm/bli_packm_struc_cxk_4mi.h b/frame/1m/packm/bli_packm_struc_cxk_4mi.h deleted file mode 100644 index 5abfb585fd..0000000000 --- a/frame/1m/packm/bli_packm_struc_cxk_4mi.h +++ /dev/null @@ -1,121 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#undef GENTPROTCO -#define GENTPROTCO( ctype, ctype_r, ch, chr, varname ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - struc_t strucc, \ - doff_t diagoffp, \ - diag_t diagc, \ - uplo_t uploc, \ - conj_t conjc, \ - pack_t schema, \ - bool invdiag, \ - dim_t m_panel, \ - dim_t n_panel, \ - dim_t m_panel_max, \ - dim_t n_panel_max, \ - ctype* restrict kappa, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - ctype* restrict p, inc_t rs_p, inc_t cs_p, \ - inc_t is_p, \ - cntx_t* cntx \ - ); - -INSERT_GENTPROTCO_BASIC0( packm_struc_cxk_4mi ) - - - -#undef GENTPROTCO -#define GENTPROTCO( ctype, ctype_r, ch, chr, varname ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - struc_t strucc, \ - doff_t diagoffc, \ - uplo_t uploc, \ - conj_t conjc, \ - pack_t schema, \ - dim_t m_panel, \ - dim_t n_panel, \ - dim_t m_panel_max, \ - dim_t n_panel_max, \ - dim_t panel_dim, \ - dim_t panel_dim_max, \ - dim_t panel_len, \ - dim_t panel_len_max, \ - ctype* restrict kappa, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - inc_t incc, inc_t ldc, \ - ctype* restrict p, inc_t rs_p, inc_t cs_p, \ - inc_t is_p, inc_t ldp, \ - cntx_t* cntx \ - ); - -INSERT_GENTPROTCO_BASIC0( packm_herm_cxk_4mi ) - - - -#undef GENTPROTCO -#define GENTPROTCO( ctype, ctype_r, ch, chr, varname ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - struc_t strucc, \ - doff_t diagoffc, \ - diag_t diagc, \ - uplo_t uploc, \ - conj_t conjc, \ - pack_t schema, \ - bool invdiag, \ - dim_t m_panel, \ - dim_t n_panel, \ - dim_t m_panel_max, \ - dim_t n_panel_max, \ - dim_t panel_dim, \ - dim_t panel_dim_max, \ - dim_t panel_len, \ - dim_t panel_len_max, \ - ctype* restrict kappa, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - inc_t incc, inc_t ldc, \ - ctype* restrict p, inc_t rs_p, inc_t cs_p, \ - inc_t is_p, inc_t ldp, \ - cntx_t* cntx \ - ); - -INSERT_GENTPROTCO_BASIC0( packm_tri_cxk_4mi ) - diff --git a/frame/1m/packm/bli_packm_struc_cxk_rih.c b/frame/1m/packm/bli_packm_struc_cxk_rih.c deleted file mode 100644 index 59b34ede8a..0000000000 --- a/frame/1m/packm/bli_packm_struc_cxk_rih.c +++ /dev/null @@ -1,625 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, varname, kername ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - struc_t strucc, \ - doff_t diagoffc, \ - diag_t diagc, \ - uplo_t uploc, \ - conj_t conjc, \ - pack_t schema, \ - bool invdiag, \ - dim_t m_panel, \ - dim_t n_panel, \ - dim_t m_panel_max, \ - dim_t n_panel_max, \ - ctype* restrict kappa, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - ctype* restrict p, inc_t rs_p, inc_t cs_p, \ - inc_t is_p, \ - cntx_t* cntx \ - ) \ -{ \ - dim_t panel_dim; \ - dim_t panel_dim_max; \ - dim_t panel_len; \ - dim_t panel_len_max; \ - inc_t incc, ldc; \ - inc_t ldp; \ -\ -\ - /* Determine the dimensions and relative strides of the micro-panel - based on its pack schema. */ \ - if ( bli_is_col_packed( schema ) ) \ - { \ - /* Prepare to pack to row-stored column panel. */ \ - panel_dim = n_panel; \ - panel_dim_max = n_panel_max; \ - panel_len = m_panel; \ - panel_len_max = m_panel_max; \ - incc = cs_c; \ - ldc = rs_c; \ - ldp = rs_p; \ - } \ - else /* if ( bli_is_row_packed( schema ) ) */ \ - { \ - /* Prepare to pack to column-stored row panel. */ \ - panel_dim = m_panel; \ - panel_dim_max = m_panel_max; \ - panel_len = n_panel; \ - panel_len_max = n_panel_max; \ - incc = rs_c; \ - ldc = cs_c; \ - ldp = cs_p; \ - } \ -\ -\ - /* Handle micro-panel packing based on the structure of the matrix - being packed. */ \ - if ( bli_is_general( strucc ) ) \ - { \ - /* For micro-panels of general matrices, we can call the pack - kernel front-end directly. */ \ - PASTEMAC(ch,kername) \ - ( \ - conjc, \ - schema, \ - panel_dim, \ - panel_dim_max, \ - panel_len, \ - panel_len_max, \ - kappa, \ - c, incc, ldc, \ - p, ldp, \ - cntx \ - ); \ - } \ - else if ( bli_is_herm_or_symm( strucc ) ) \ - { \ - /* Call a helper function for micro-panels of Hermitian/symmetric - matrices. */ \ - PASTEMAC(ch,packm_herm_cxk_rih) \ - ( \ - strucc, \ - diagoffc, \ - uploc, \ - conjc, \ - schema, \ - m_panel, \ - n_panel, \ - m_panel_max, \ - n_panel_max, \ - panel_dim, \ - panel_dim_max, \ - panel_len, \ - panel_len_max, \ - kappa, \ - c, rs_c, cs_c, \ - incc, ldc, \ - p, rs_p, cs_p, \ - ldp, \ - cntx \ - ); \ - } \ - else /* ( bli_is_triangular( strucc ) ) */ \ - { \ - /* Call a helper function for micro-panels of triangular - matrices. */ \ - PASTEMAC(ch,packm_tri_cxk_rih) \ - ( \ - strucc, \ - diagoffc, \ - diagc, \ - uploc, \ - conjc, \ - schema, \ - invdiag, \ - m_panel, \ - n_panel, \ - m_panel_max, \ - n_panel_max, \ - panel_dim, \ - panel_dim_max, \ - panel_len, \ - panel_len_max, \ - kappa, \ - c, rs_c, cs_c, \ - incc, ldc, \ - p, rs_p, cs_p, \ - ldp, \ - cntx \ - ); \ - } \ -\ -\ - /* If m_panel < m_panel_max, or n_panel < n_panel_max, we would normally - fill the edge region (the bottom m_panel_max - m_panel rows or right- - side n_panel_max - n_panel columns) of the micropanel with zeros. - However, this responsibility has been moved to the packm microkernel. - This change allows experts to use custom kernels that pack to custom - packing formats when the problem size is not a nice multiple of the - register blocksize. */ \ -/* - if ( m_panel != m_panel_max ) \ - { \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - dim_t i = m_panel; \ - dim_t m_edge = m_panel_max - i; \ - dim_t n_edge = n_panel_max; \ - ctype_r* p_edge_r = ( ctype_r* )p + (i )*rs_p; \ -\ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_r, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ - } \ -\ - if ( n_panel != n_panel_max ) \ - { \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - dim_t j = n_panel; \ - dim_t m_edge = m_panel_max; \ - dim_t n_edge = n_panel_max - j; \ - ctype_r* p_edge_r = ( ctype_r* )p + (j )*cs_p; \ -\ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_r, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ - } \ -*/ \ -\ -\ - if ( bli_is_triangular( strucc ) ) \ - { \ - /* If this panel is an edge case in both panel dimension and length, - then it must be a bottom-right corner case. Set the part of the - diagonal that extends into the zero-padded region to identity. - NOTE: This is actually only necessary when packing for trsm, as - it helps prevent NaNs and Infs from creeping into the computation. - However, we set the region to identity for trmm as well. Those - 1.0's end up getting muliplied by the 0.0's in the zero-padded - region of the other matrix, so there is no harm in this. */ \ - if ( m_panel != m_panel_max && \ - n_panel != n_panel_max ) \ - { \ - /* We don't need this case if we aren't supporting trsm. - Why? Because trmm's packm control tree node should be - using k dimension multiples of 1 (kr == 1), which means - there will never be zero padding at the far end of a - micro-panel. */ \ - } \ - } \ -\ -\ -/* - { \ - if ( bli_is_col_packed( schema ) ) \ - PASTEMAC(chr,fprintm)( stdout, "packm_struc_cxk_rih: bp copied", m_panel_max, n_panel_max, \ - ( ctype_r* )p, rs_p, cs_p, "%4.1f", "" ); \ - else if ( bli_is_row_packed( schema ) ) \ - PASTEMAC(chr,fprintm)( stdout, "packm_struc_cxk_rih: ap copied", m_panel_max, n_panel_max, \ - ( ctype_r* )p, rs_p, cs_p, "%4.1f", "" ); \ - } \ -*/ \ - \ -\ -} - -INSERT_GENTFUNCCO_BASIC( packm_struc_cxk_rih, packm_cxk_rih ) - - - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, varname, kername ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - struc_t strucc, \ - doff_t diagoffc, \ - uplo_t uploc, \ - conj_t conjc, \ - pack_t schema, \ - dim_t m_panel, \ - dim_t n_panel, \ - dim_t m_panel_max, \ - dim_t n_panel_max, \ - dim_t panel_dim, \ - dim_t panel_dim_max, \ - dim_t panel_len, \ - dim_t panel_len_max, \ - ctype* restrict kappa, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - inc_t incc, inc_t ldc, \ - ctype* restrict p, inc_t rs_p, inc_t cs_p, \ - inc_t ldp, \ - cntx_t* cntx \ - ) \ -{ \ - bool row_stored; \ - bool col_stored; \ - doff_t diagoffc_abs; \ - dim_t j; \ -\ -\ - /* Create flags to incidate row or column storage. Note that the - schema bit that encodes row or column is describing the form of - micro-panel, not the storage in the micro-panel. Hence the - mismatch in "row" and "column" semantics. */ \ - row_stored = bli_is_col_packed( schema ); \ - col_stored = bli_is_row_packed( schema ); \ -\ -\ - /* Handle the case where the micro-panel does NOT intersect the - diagonal separately from the case where it does intersect. */ \ - if ( !bli_intersects_diag_n( diagoffc, m_panel, n_panel ) ) \ - { \ - /* If the current panel is unstored, we need to make a few - adjustments so we refer to the data where it is actually - stored, also taking conjugation into account. (Note this - implicitly assumes we are operating on a dense panel - within a larger symmetric or Hermitian matrix, since a - general matrix would not contain any unstored region.) */ \ - if ( bli_is_unstored_subpart_n( diagoffc, uploc, m_panel, n_panel ) ) \ - { \ - c = c + diagoffc * ( doff_t )cs_c + \ - -diagoffc * ( doff_t )rs_c; \ - bli_swap_incs( &incc, &ldc ); \ -\ - if ( bli_is_hermitian( strucc ) ) \ - bli_toggle_conj( &conjc ); \ - } \ -\ - /* Pack the full panel. */ \ - PASTEMAC(ch,kername) \ - ( \ - conjc, \ - schema, \ - panel_dim, \ - panel_dim_max, \ - panel_len, \ - panel_len_max, \ - kappa, \ - c, incc, ldc, \ - p, ldp, \ - cntx \ - ); \ - } \ - else /* if ( bli_intersects_diag_n( diagoffc, m_panel, n_panel ) ) */ \ - { \ - ctype_r* restrict p_r = ( ctype_r* )p; \ -\ - ctype* restrict c10; \ - ctype_r* restrict p10; \ - dim_t p10_dim, p10_len; \ - inc_t incc10, ldc10; \ - doff_t diagoffc10; \ - conj_t conjc10; \ -\ - ctype* restrict c12; \ - ctype_r* restrict p12; \ - dim_t p12_dim, p12_len; \ - inc_t incc12, ldc12; \ - doff_t diagoffc12; \ - conj_t conjc12; \ -\ - /* Sanity check. Diagonals should not intersect the short end of - a micro-panel. If they do, then somehow the constraints on - cache blocksizes being a whole multiple of the register - blocksizes was somehow violated. */ \ - if ( ( col_stored && diagoffc < 0 ) || \ - ( row_stored && diagoffc > 0 ) ) \ - bli_check_error_code( BLIS_NOT_YET_IMPLEMENTED ); \ -\ - diagoffc_abs = bli_abs( diagoffc ); \ -\ - if ( ( row_stored && bli_is_upper( uploc ) ) || \ - ( col_stored && bli_is_lower( uploc ) ) ) \ - { \ - p10_dim = panel_dim; \ - p10_len = diagoffc_abs; \ - p10 = p_r; \ - c10 = c; \ - incc10 = incc; \ - ldc10 = ldc; \ - conjc10 = conjc; \ -\ - p12_dim = panel_dim; \ - p12_len = panel_len - p10_len; \ - j = p10_len; \ - diagoffc12 = diagoffc_abs - j; \ - p12 = p_r + (j )*ldp; \ - c12 = c + (j )*ldc; \ - c12 = c12 + diagoffc12 * ( doff_t )cs_c + \ - -diagoffc12 * ( doff_t )rs_c; \ - incc12 = ldc; \ - ldc12 = incc; \ - conjc12 = conjc; \ -\ - if ( bli_is_hermitian( strucc ) ) \ - bli_toggle_conj( &conjc12 ); \ - } \ - else /* if ( ( row_stored && bli_is_lower( uploc ) ) || \ - ( col_stored && bli_is_upper( uploc ) ) ) */ \ - { \ - p10_dim = panel_dim; \ - p10_len = diagoffc_abs + panel_dim; \ - diagoffc10 = diagoffc; \ - p10 = p_r; \ - c10 = c; \ - c10 = c10 + diagoffc10 * ( doff_t )cs_c + \ - -diagoffc10 * ( doff_t )rs_c; \ - incc10 = ldc; \ - ldc10 = incc; \ - conjc10 = conjc; \ -\ - p12_dim = panel_dim; \ - p12_len = panel_len - p10_len; \ - j = p10_len; \ - p12 = p_r + (j )*ldp; \ - c12 = c + (j )*ldc; \ - incc12 = incc; \ - ldc12 = ldc; \ - conjc12 = conjc; \ -\ - if ( bli_is_hermitian( strucc ) ) \ - bli_toggle_conj( &conjc10 ); \ - } \ -\ - /* Pack to p10. For upper storage, this includes the unstored - triangle of c11. */ \ - /* NOTE: Since we're only packing partial panels here, we pass in - p1x_len as panel_len_max; otherwise, the packm kernel will zero- - fill the columns up to panel_len_max, which is not what we need - or want to happen. */ \ - PASTEMAC(ch,kername) \ - ( \ - conjc10, \ - schema, \ - p10_dim, \ - panel_dim_max, \ - p10_len, \ - p10_len, \ - kappa, \ - c10, incc10, ldc10, \ - ( ctype* )p10, ldp, \ - cntx \ - ); \ -\ - /* Pack to p12. For lower storage, this includes the unstored - triangle of c11. */ \ - /* NOTE: Since we're only packing partial panels here, we pass in - p1x_len as panel_len_max; otherwise, the packm kernel will zero- - fill the columns up to panel_len_max, which is not what we need - or want to happen. */ \ - PASTEMAC(ch,kername) \ - ( \ - conjc12, \ - schema, \ - p12_dim, \ - panel_dim_max, \ - p12_len, \ - p12_len, \ - kappa, \ - c12, incc12, ldc12, \ - ( ctype* )p12, ldp, \ - cntx \ - ); \ -\ - /* Pack the stored triangle of c11 to p11. */ \ - { \ - dim_t j2 = diagoffc_abs; \ - /*ctype_r* restrict p_r = ( ctype_r* )p;*/ \ - ctype* restrict c11 = c + (j2 )*ldc; \ - ctype_r* restrict p11_r = p_r + (j2 )*ldp; \ -\ - PASTEMAC(ch,scal2rihs_mxn_uplo) \ - ( \ - schema, \ - uploc, \ - conjc, \ - panel_dim, \ - kappa, \ - c11, rs_c, cs_c, \ - p11_r, rs_p, cs_p \ - ); \ -\ - /* If we are packing a micro-panel with Hermitian structure, - we must take special care of the diagonal. Now, if kappa - were guaranteed to be unit, all we would need to do is - explicitly zero out the imaginary part of the diagonal of - p11, in case the diagonal of the source matrix contained - garbage (non-zero) imaginary values. HOWEVER, since kappa - can be non-unit, things become a little more complicated. - In general, we must re-apply the kappa scalar to ONLY the - real part of the diagonal of the source matrix and save - the result to the diagonal of p11. */ \ - if ( bli_is_hermitian( strucc ) ) \ - { \ - PASTEMAC3(ch,chr,ch,scal2rihs_mxn_diag) \ - ( \ - schema, \ - panel_dim, \ - panel_dim, \ - kappa, \ - c11, rs_c, cs_c, \ - p11_r, rs_p, cs_p \ - ); \ - } \ -\ -/* - PASTEMAC(chr,fprintm)( stdout, "packm_herm_cxk: ap_r copied", m_panel_max, n_panel_max, \ - p_r + 0*is_p, rs_p, cs_p, "%4.1f", "" ); \ - PASTEMAC(chr,fprintm)( stdout, "packm_herm_cxk: ap_i copied", m_panel_max, n_panel_max, \ - p_r + 1*is_p, rs_p, cs_p, "%4.1f", "" ); \ -*/ \ - } \ - } \ -} - -INSERT_GENTFUNCCO_BASIC( packm_herm_cxk_rih, packm_cxk_rih ) - - - - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, varname, kername ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - struc_t strucc, \ - doff_t diagoffp, \ - diag_t diagc, \ - uplo_t uploc, \ - conj_t conjc, \ - pack_t schema, \ - bool invdiag, \ - dim_t m_panel, \ - dim_t n_panel, \ - dim_t m_panel_max, \ - dim_t n_panel_max, \ - dim_t panel_dim, \ - dim_t panel_dim_max, \ - dim_t panel_len, \ - dim_t panel_len_max, \ - ctype* restrict kappa, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - inc_t incc, inc_t ldc, \ - ctype* restrict p, inc_t rs_p, inc_t cs_p, \ - inc_t ldp, \ - cntx_t* cntx \ - ) \ -{ \ - /* Pack the panel. */ \ - PASTEMAC(ch,kername) \ - ( \ - conjc, \ - schema, \ - panel_dim, \ - panel_dim_max, \ - panel_len, \ - panel_len_max, \ - kappa, \ - c, incc, ldc, \ - p, ldp, \ - cntx \ - ); \ -\ -\ - /* Tweak the panel according to its triangular structure */ \ - { \ - ctype_r* p_r = ( ctype_r* )p; \ -\ - dim_t j = bli_abs( diagoffp ); \ - ctype_r* p11_r = p_r + (j )*ldp; \ -\ - /* If the diagonal of c is implicitly unit, explicitly set the - the diagonal of the packed panel to kappa. */ \ - if ( bli_is_unit_diag( diagc ) ) \ - { \ - PASTEMAC(ch,setrihs_mxn_diag) \ - ( \ - schema, \ - panel_dim, \ - panel_dim, \ - kappa, \ - p11_r, rs_p, cs_p \ - ); \ - } \ -\ -\ - /* If requested, invert the diagonal of the packed panel. */ \ - if ( invdiag == TRUE ) \ - { \ - /* We don't need this case if we aren't supporting trsm. */ \ - } \ -\ -\ - /* Set the region opposite the diagonal of p to zero. To do this, - we need to reference the "unstored" region on the other side of - the diagonal. This amounts to toggling uploc and then shifting - the diagonal offset to shrink the newly referenced region (by - one diagonal). */ \ - { \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - uplo_t uplop = uploc; \ -\ - bli_toggle_uplo( &uplop ); \ - bli_shift_diag_offset_to_shrink_uplo( uplop, &diagoffp ); \ -\ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - diagoffp, \ - BLIS_NONUNIT_DIAG, \ - uplop, \ - m_panel, \ - n_panel, \ - zero_r, \ - p_r, rs_p, cs_p, \ - cntx, \ - NULL \ - ); \ - } \ - } \ -} - -INSERT_GENTFUNCCO_BASIC( packm_tri_cxk_rih, packm_cxk_rih ) - diff --git a/frame/1m/packm/bli_packm_struc_cxk_rih.h b/frame/1m/packm/bli_packm_struc_cxk_rih.h deleted file mode 100644 index 0af4d33e82..0000000000 --- a/frame/1m/packm/bli_packm_struc_cxk_rih.h +++ /dev/null @@ -1,121 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#undef GENTPROTCO -#define GENTPROTCO( ctype, ctype_r, ch, chr, varname ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - struc_t strucc, \ - doff_t diagoffp, \ - diag_t diagc, \ - uplo_t uploc, \ - conj_t conjc, \ - pack_t schema, \ - bool invdiag, \ - dim_t m_panel, \ - dim_t n_panel, \ - dim_t m_panel_max, \ - dim_t n_panel_max, \ - ctype* restrict kappa, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - ctype* restrict p, inc_t rs_p, inc_t cs_p, \ - inc_t is_p, \ - cntx_t* cntx \ - ); - -INSERT_GENTPROTCO_BASIC0( packm_struc_cxk_rih ) - - - -#undef GENTPROTCO -#define GENTPROTCO( ctype, ctype_r, ch, chr, varname ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - struc_t strucc, \ - doff_t diagoffc, \ - uplo_t uploc, \ - conj_t conjc, \ - pack_t schema, \ - dim_t m_panel, \ - dim_t n_panel, \ - dim_t m_panel_max, \ - dim_t n_panel_max, \ - dim_t panel_dim, \ - dim_t panel_dim_max, \ - dim_t panel_len, \ - dim_t panel_len_max, \ - ctype* restrict kappa, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - inc_t incc, inc_t ldc, \ - ctype* restrict p, inc_t rs_p, inc_t cs_p, \ - inc_t ldp, \ - cntx_t* cntx \ - ); - -INSERT_GENTPROTCO_BASIC0( packm_herm_cxk_rih ) - - - -#undef GENTPROTCO -#define GENTPROTCO( ctype, ctype_r, ch, chr, varname ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - struc_t strucc, \ - doff_t diagoffc, \ - diag_t diagc, \ - uplo_t uploc, \ - conj_t conjc, \ - pack_t schema, \ - bool invdiag, \ - dim_t m_panel, \ - dim_t n_panel, \ - dim_t m_panel_max, \ - dim_t n_panel_max, \ - dim_t panel_dim, \ - dim_t panel_dim_max, \ - dim_t panel_len, \ - dim_t panel_len_max, \ - ctype* restrict kappa, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - inc_t incc, inc_t ldc, \ - ctype* restrict p, inc_t rs_p, inc_t cs_p, \ - inc_t ldp, \ - cntx_t* cntx \ - ); - -INSERT_GENTPROTCO_BASIC0( packm_tri_cxk_rih ) - diff --git a/frame/2/gemv/bli_gemv_unf_var1_amd.c b/frame/2/gemv/bli_gemv_unf_var1_amd.c index 86d0692163..6aedb8bd6b 100644 --- a/frame/2/gemv/bli_gemv_unf_var1_amd.c +++ b/frame/2/gemv/bli_gemv_unf_var1_amd.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -736,7 +736,21 @@ void bli_zgemv_unf_var1 switch (id) { + case BLIS_ARCH_ZEN5: case BLIS_ARCH_ZEN4: +#if defined(BLIS_KERNELS_ZEN4) + /* + Assign the AVX2 based kernel function pointers for + DOTXF, SCAL2Vand corresponding fusing + factor of DOTXF kernel + */ + + dotxf_kr_ptr = bli_zdotxf_zen_int_8_avx512; + b_fuse = 8; + + scal2v_kr_ptr = bli_zscal2v_zen_int; + break; +#endif case BLIS_ARCH_ZEN: case BLIS_ARCH_ZEN2: case BLIS_ARCH_ZEN3: diff --git a/frame/2/gemv/bli_gemv_unf_var2_amd.c b/frame/2/gemv/bli_gemv_unf_var2_amd.c index d8a0c8911b..cbf545d642 100644 --- a/frame/2/gemv/bli_gemv_unf_var2_amd.c +++ b/frame/2/gemv/bli_gemv_unf_var2_amd.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -311,6 +311,7 @@ void bli_dgemv_unf_var2 switch (id) { + case BLIS_ARCH_ZEN5: case BLIS_ARCH_ZEN4: #if defined(BLIS_KERNELS_ZEN4) /* @@ -319,8 +320,8 @@ void bli_dgemv_unf_var2 factor of DAXPYF kernel */ - axpyf_kr_ptr = bli_daxpyf_zen_int_8; - b_fuse = 8; + axpyf_kr_ptr = bli_daxpyf_zen_int_avx512; + b_fuse = 32; scalv_kr_ptr = bli_dscalv_zen_int_avx512; @@ -690,10 +691,11 @@ void bli_zgemv_unf_var2 Function pointer declaration for the functions that will be used by this API */ - zaxpyf_ker_ft axpyf_kr_ptr; // ZAXPYF + zaxpyf_ker_ft axpyf_kr_ptr; // ZAXPYF zscal2v_ker_ft scal2v_kr_ptr; // ZSCAL2V - zscalv_ker_ft scalv_kr_ptr; // ZSCALV - zcopyv_ker_ft copyv_kr_ptr; // ZCOPYV + zscalv_ker_ft scalv_kr_ptr; // ZSCALV + zcopyv_ker_ft copyv_kr_ptr; // ZCOPYV + zsetv_ker_ft setv_kr_ptr; // ZSETV /* Boolean to check if the y has been packed @@ -703,7 +705,21 @@ void bli_zgemv_unf_var2 switch (id) { + case BLIS_ARCH_ZEN5: case BLIS_ARCH_ZEN4: +#if defined(BLIS_KERNELS_ZEN4) + axpyf_kr_ptr = bli_zaxpyf_zen_int_8_avx512; + b_fuse = 8; + + scal2v_kr_ptr = bli_zscal2v_zen_int; + + scalv_kr_ptr = bli_zscalv_zen_int; + + copyv_kr_ptr = bli_zcopyv_zen_int; + + setv_kr_ptr = bli_zsetv_zen_int_avx512; + break; +#endif case BLIS_ARCH_ZEN: case BLIS_ARCH_ZEN2: case BLIS_ARCH_ZEN3: @@ -723,6 +739,7 @@ void bli_zgemv_unf_var2 copyv_kr_ptr = bli_zcopyv_zen_int; + setv_kr_ptr = bli_zsetv_zen_int; break; default: // For non-Zen architectures, query the context if it is NULL @@ -741,6 +758,8 @@ void bli_zgemv_unf_var2 scalv_kr_ptr = bli_cntx_get_l1v_ker_dt(BLIS_DCOMPLEX, BLIS_SCALV_KER, cntx); copyv_kr_ptr = bli_cntx_get_l1v_ker_dt(BLIS_DCOMPLEX, BLIS_COPYV_KER, cntx); + + setv_kr_ptr = bli_cntx_get_l1v_ker_dt(BLIS_DCOMPLEX, BLIS_SETV_KER, cntx); } /* @@ -814,11 +833,26 @@ void bli_zgemv_unf_var2 } else { + /* + Invoke the ZSETV function using the function + pointer only when beta is 0. + */ + if(PASTEMAC(z, eq0)(*beta)) + { + setv_kr_ptr + ( + BLIS_NO_CONJUGATE, + n_elem, + beta, + y_buf, buf_incy, + cntx + ); + } /* Invoke the ZSCALV function using the function - pointer only when alpha is not 1. + pointer only when beta is not 1. */ - if(!PASTEMAC(z, eq1)(*beta)) + else if(!PASTEMAC(z, eq1)(*beta)) { scalv_kr_ptr ( diff --git a/frame/2/trsv/bli_trsv_unf_var1_amd.c b/frame/2/trsv/bli_trsv_unf_var1_amd.c index 6714e79a08..de460164d6 100644 --- a/frame/2/trsv/bli_trsv_unf_var1_amd.c +++ b/frame/2/trsv/bli_trsv_unf_var1_amd.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -77,15 +77,19 @@ void PASTEMAC(ch,varname) \ conj_t conja; \ \ /* x = alpha * x; */ \ - PASTEMAC2(ch,scalv,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - m, \ - alpha, \ - x, incx, \ - cntx, \ - NULL \ - ); \ + /* Avoid alpha scaling when alpha is one */ \ + if ( !PASTEMAC(ch, eq1)(*alpha) ) \ + { \ + PASTEMAC2(ch,scalv,BLIS_TAPI_EX_SUF) \ + ( \ + BLIS_NO_CONJUGATE, \ + m, \ + alpha, \ + x, incx, \ + cntx, \ + NULL \ + ); \ + } \ \ if ( bli_does_notrans( transa ) ) \ { \ @@ -298,15 +302,28 @@ void bli_dtrsv_unf_var1 // This function is invoked on all architectures including 'generic'. // Non-AVX2+FMA3 platforms will use the kernels derived from the context. if (bli_cpuid_is_avx2fma3_supported() == TRUE) { - kfp_df = bli_ddotxf_zen_int_8; - b_fuse = 8; + arch_t id = bli_arch_query_id(); + switch (id) + { +#if defined(BLIS_KERNELS_ZEN4) + case BLIS_ARCH_ZEN5: + case BLIS_ARCH_ZEN4: + kfp_df = bli_ddotxf_zen_int_avx512; + b_fuse = 8; + break; +#endif + default: + kfp_df = bli_ddotxf_zen_int_8; + b_fuse = 8; + break; + } } else { - if ( cntx == NULL ) cntx = bli_gks_query_cntx(); - num_t dt = PASTEMAC(d,type); - kfp_df = bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXF_KER, cntx ); - b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_DF, cntx ); + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); + num_t dt = PASTEMAC(d,type); + kfp_df = bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXF_KER, cntx ); + b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_DF, cntx ); } /* We reduce all of the possible cases down to just lower/upper. */ diff --git a/frame/2/trsv/bli_trsv_unf_var2_amd.c b/frame/2/trsv/bli_trsv_unf_var2_amd.c index d04e1b9aca..b943da2a20 100644 --- a/frame/2/trsv/bli_trsv_unf_var2_amd.c +++ b/frame/2/trsv/bli_trsv_unf_var2_amd.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -295,13 +295,60 @@ void bli_dtrsv_unf_var2 conja = bli_extract_conj( transa ); - PASTECH(d,axpyf_ker_ft) kfp_af; + PASTECH(d,axpyf_ker_ft) kfp_af = NULL; // This function is invoked on all architectures including 'generic'. // Non-AVX2+FMA3 platforms will use the kernels derived from the context. if (bli_cpuid_is_avx2fma3_supported() == TRUE) { - kfp_af = bli_daxpyf_zen_int_16x4; - b_fuse = 4; + arch_t id = bli_arch_query_id(); + switch (id) + { +#if defined(BLIS_KERNELS_ZEN4) + case BLIS_ARCH_ZEN5: + case BLIS_ARCH_ZEN4: + { +#ifdef BLIS_ENABLE_OPENMP + // For sizes < 800 ST kernels are performing better. + if (m > 800) + { + rntm_t rntm; + bli_rntm_init_from_global(&rntm); + dim_t n_threads = bli_rntm_num_threads(&rntm); + // If NT == 1, don't use MT kernel. + if ( n_threads > 1 ) + { + kfp_af = bli_daxpyf_zen_int32_avx512_mt; + b_fuse = 32; + } + } +#endif + if ( kfp_af == NULL ) + { + // AVX2 kernel performs better for small sizes on Genoa + if ( id == BLIS_ARCH_ZEN4 && m < 380 ) + { + kfp_af = bli_daxpyf_zen_int_16x4; + b_fuse = 4; + } + else if ( m < 2500 ) + { + kfp_af = bli_daxpyf_zen_int8_avx512; + b_fuse = 8; + } + else + { + kfp_af = bli_daxpyf_zen_int12_avx512; + b_fuse = 12; + } + } + break; + } +#endif + default: + kfp_af = bli_daxpyf_zen_int_16x4; + b_fuse = 4; + break; + } } else { @@ -668,15 +715,19 @@ void bli_ztrsv_unf_var2 if( cntx == NULL ) cntx = bli_gks_query_cntx(); /* x = alpha * x; */ - PASTEMAC2(z, scalv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - m, - alpha, - x, incx, - cntx, - NULL - ); + /* Avoid alpha scaling when alpha is one */ + if ( !PASTEMAC(z, eq1)(*alpha) ) + { + PASTEMAC2(z, scalv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + m, + alpha, + x, incx, + cntx, + NULL + ); + } if( bli_does_notrans( transa ) ) { diff --git a/frame/3/bli_l3.h b/frame/3/bli_l3.h index 3e1f18de18..69cffa9601 100644 --- a/frame/3/bli_l3.h +++ b/frame/3/bli_l3.h @@ -46,24 +46,15 @@ #include "bli_l3_direct.h" #include "bli_l3_prune.h" #include "bli_l3_packm.h" +#include "bli_l3_schema.h" -// Prototype object APIs (expert and non-expert). -#include "bli_oapi_ex.h" +// Prototype object APIs (basic and expert). #include "bli_l3_oapi.h" -#include "bli_xapi_undef.h" +#include "bli_l3_oapi_ex.h" -#include "bli_oapi_ba.h" -#include "bli_l3_oapi.h" -#include "bli_xapi_undef.h" - -// Prototype typed APIs (expert and non-expert). -#include "bli_tapi_ex.h" -#include "bli_l3_tapi.h" -#include "bli_xapi_undef.h" - -#include "bli_tapi_ba.h" +// Prototype typed APIs (basic and expert). #include "bli_l3_tapi.h" -#include "bli_xapi_undef.h" +#include "bli_l3_tapi_ex.h" // Define function types for small/unpacked handlers/kernels. #include "bli_l3_sup_oft.h" diff --git a/frame/3/bli_l3_blocksize.c b/frame/3/bli_l3_blocksize.c index 595b5410ab..51844eebe5 100644 --- a/frame/3/bli_l3_blocksize.c +++ b/frame/3/bli_l3_blocksize.c @@ -1,11 +1,11 @@ - /* +/* BLIS An object-based framework for developing high-performance BLAS-like libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/frame/3/bli_l3_check.c b/frame/3/bli_l3_check.c index 43ba867283..284e733bd3 100644 --- a/frame/3/bli_l3_check.c +++ b/frame/3/bli_l3_check.c @@ -99,7 +99,7 @@ void bli_hemm_check { err_t e_val; - // Perform checks common to hemm/symm. + // Perform checks common to hemm/symm/trmm/trsm. bli_hemm_basic_check( side, alpha, a, b, beta, c, cntx ); @@ -249,7 +249,7 @@ void bli_syr2k_check bli_check_error_code( e_val ); } -void bli_trmm_check +void bli_trmm3_check ( side_t side, obj_t* alpha, @@ -262,7 +262,7 @@ void bli_trmm_check { err_t e_val; - // Perform checks common to hemm/symm. + // Perform checks common to hemm/symm/trmm/trsm. bli_hemm_basic_check( side, alpha, a, b, beta, c, cntx ); @@ -272,22 +272,41 @@ void bli_trmm_check bli_check_error_code( e_val ); } +void bli_trmm_check + ( + side_t side, + obj_t* alpha, + obj_t* a, + obj_t* b, + cntx_t* cntx + ) +{ + err_t e_val; + + // Perform checks common to hemm/symm/trmm/trsm. + + bli_hemm_basic_check( side, alpha, a, b, &BLIS_ZERO, b, cntx ); + + // Check object structure. + + e_val = bli_check_triangular_object( a ); + bli_check_error_code( e_val ); +} + void bli_trsm_check ( side_t side, obj_t* alpha, obj_t* a, obj_t* b, - obj_t* beta, - obj_t* c, cntx_t* cntx ) { err_t e_val; - // Perform checks common to hemm/symm. + // Perform checks common to hemm/symm/trmm/trsm. - bli_hemm_basic_check( side, alpha, a, b, beta, c, cntx ); + bli_hemm_basic_check( side, alpha, a, b, &BLIS_ZERO, b, cntx ); // Check object structure. diff --git a/frame/3/bli_l3_check.h b/frame/3/bli_l3_check.h index b2216c34bd..c600d60b9a 100644 --- a/frame/3/bli_l3_check.h +++ b/frame/3/bli_l3_check.h @@ -72,8 +72,7 @@ void PASTEMAC(opname,_check) \ GENPROT( hemm ) GENPROT( symm ) -GENPROT( trmm ) -GENPROT( trsm ) +GENPROT( trmm3 ) #undef GENPROT @@ -92,6 +91,22 @@ GENPROT( herk ) GENPROT( syrk ) +#undef GENPROT +#define GENPROT( opname ) \ +\ +void PASTEMAC(opname,_check) \ + ( \ + side_t side, \ + obj_t* alpha, \ + obj_t* a, \ + obj_t* b, \ + cntx_t* cntx \ + ); + +GENPROT( trmm ) +GENPROT( trsm ) + + // ----------------------------------------------------------------------------- void bli_gemm_basic_check diff --git a/frame/ind/bli_l3_ind.c b/frame/3/bli_l3_ind.c similarity index 68% rename from frame/ind/bli_l3_ind.c rename to frame/3/bli_l3_ind.c index b7cb0fcdee..ae8b6e0b52 100644 --- a/frame/ind/bli_l3_ind.c +++ b/frame/3/bli_l3_ind.c @@ -35,23 +35,13 @@ #include "blis.h" -static void_fp bli_l3_ind_oper_fp[BLIS_NUM_IND_METHODS][BLIS_NUM_LEVEL3_OPS] = +// This array tracks whether a particular operation is implemented for each of +// the induced methods. +static bool bli_l3_ind_oper_impl[BLIS_NUM_IND_METHODS][BLIS_NUM_LEVEL3_OPS] = { - /* gemm hemm herk her2k symm syrk, syr2k trmm3 trmm trsm gemmt*/ -/* 3mh */ { bli_gemm3mh, bli_hemm3mh, bli_herk3mh, bli_her2k3mh, bli_symm3mh, - bli_syrk3mh, bli_syr2k3mh, bli_trmm33mh, NULL, NULL , NULL }, -/* 3m1 */ { bli_gemm3m1, bli_hemm3m1, bli_herk3m1, bli_her2k3m1, bli_symm3m1, - bli_syrk3m1, bli_syr2k3m1, bli_trmm33m1, bli_trmm3m1, bli_trsm3m1 , NULL }, -/* 4mh */ { bli_gemm4mh, bli_hemm4mh, bli_herk4mh, bli_her2k4mh, bli_symm4mh, - bli_syrk4mh, bli_syr2k4mh, bli_trmm34mh, NULL, NULL , NULL }, -/* 4mb */ { bli_gemm4mb, NULL, NULL, NULL, NULL, - NULL, NULL, NULL, NULL, NULL , NULL }, -/* 4m1 */ { bli_gemm4m1, bli_hemm4m1, bli_herk4m1, bli_her2k4m1, bli_symm4m1, - bli_syrk4m1, bli_syr2k4m1, bli_trmm34m1, bli_trmm4m1, bli_trsm4m1 , NULL }, -/* 1m */ { bli_gemm1m, bli_hemm1m, bli_herk1m, bli_her2k1m, bli_symm1m, - bli_syrk1m, bli_syr2k1m, bli_trmm31m, bli_trmm1m, bli_trsm1m , NULL }, -/* nat */ { bli_gemmnat, bli_hemmnat, bli_herknat, bli_her2knat, bli_symmnat, - bli_syrknat, bli_syr2knat, bli_trmm3nat, bli_trmmnat, bli_trsmnat , bli_gemmtnat }, + /* gemm gemmt hemm herk her2k symm syrk syr2k trmm3 trmm trsm */ +/* 1m */ { TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE }, +/* nat */ { TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE } }; // @@ -64,21 +54,11 @@ static void_fp bli_l3_ind_oper_fp[BLIS_NUM_IND_METHODS][BLIS_NUM_LEVEL3_OPS] = static BLIS_THREAD_LOCAL bool bli_l3_ind_oper_st[BLIS_NUM_IND_METHODS][BLIS_NUM_LEVEL3_OPS][2] = { - /* gemm hemm herk her2k symm syrk, syr2k trmm3 trmm trsm */ + /* gemm gemmt hemm herk her2k symm syrk syr2k trmm3 trmm trsm */ /* c z */ -/* 3mh */ { {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, +/* 1m */ { {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE} }, -/* 3m1 */ { {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, - {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE} }, -/* 4mh */ { {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, - {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE} }, -/* 4mb */ { {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, - {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE} }, -/* 4m1 */ { {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, - {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE} }, -/* 1m */ { {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, - {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE} }, -/* nat */ { {TRUE,TRUE}, {TRUE,TRUE}, {TRUE,TRUE}, {TRUE,TRUE}, {TRUE,TRUE}, +/* nat */ { {TRUE,TRUE}, {TRUE,TRUE}, {TRUE,TRUE}, {TRUE,TRUE}, {TRUE,TRUE}, {TRUE,TRUE}, {TRUE,TRUE}, {TRUE,TRUE}, {TRUE,TRUE}, {TRUE,TRUE}, {TRUE,TRUE} }, }; @@ -87,16 +67,14 @@ bool bli_l3_ind_oper_st[BLIS_NUM_IND_METHODS][BLIS_NUM_LEVEL3_OPS][2] = #undef GENFUNC #define GENFUNC( opname, optype ) \ \ -void_fp PASTEMAC(opname,ind_get_avail)( num_t dt ) \ +ind_t PASTEMAC(opname,ind_find_avail)( num_t dt ) \ { \ - return bli_ind_oper_get_avail( optype, dt ); \ + return bli_l3_ind_oper_find_avail( optype, dt ); \ } -/* -bool PASTEMAC(opname,ind_has_avail)( num_t dt ) -{ - return bli_ind_oper_has_avail( optype, dt ); -} -*/ +//bool PASTEMAC(opname,ind_has_avail)( num_t dt ) +//{ +// return bli_ind_oper_has_avail( optype, dt ); +//} GENFUNC( gemm, BLIS_GEMM ) GENFUNC( gemmt, BLIS_GEMMT ) @@ -115,16 +93,16 @@ GENFUNC( trsm, BLIS_TRSM ) #if 0 bool bli_l3_ind_oper_is_avail( opid_t oper, ind_t method, num_t dt ) { - void_fp func; - bool stat; + bool enabled; + bool stat; // If the datatype is real, it is never available. if ( !bli_is_complex( dt ) ) return FALSE; - func = bli_l3_ind_oper_get_func( oper, method ); - stat = bli_l3_ind_oper_get_enable( oper, method, dt ); + enabled = bli_l3_ind_oper_is_impl( oper, method ); + stat = bli_l3_ind_oper_get_enable( oper, method, dt ); - return ( func != NULL && stat == TRUE ); + return ( enabled == TRUE && stat == TRUE ); } #endif @@ -147,11 +125,11 @@ ind_t bli_l3_ind_oper_find_avail( opid_t oper, num_t dt ) // current operation and datatype. for ( im = 0; im < BLIS_NUM_IND_METHODS; ++im ) { - void_fp func = bli_l3_ind_oper_get_func( oper, im ); - bool stat = bli_l3_ind_oper_get_enable( oper, im, dt ); + bool enabled = bli_l3_ind_oper_is_impl( oper, im ); + bool stat = bli_l3_ind_oper_get_enable( oper, im, dt ); - if ( func != NULL && - stat == TRUE ) return im; + if ( enabled == TRUE && + stat == TRUE ) return im; } // This return statement should never execute since the native index @@ -257,8 +235,7 @@ bool bli_l3_ind_oper_get_enable( opid_t oper, ind_t method, num_t dt ) // ----------------------------------------------------------------------------- -void_fp bli_l3_ind_oper_get_func( opid_t oper, ind_t method ) +bool bli_l3_ind_oper_is_impl( opid_t oper, ind_t method ) { - return bli_l3_ind_oper_fp[ method ][ oper ]; + return bli_l3_ind_oper_impl[ method ][ oper ]; } - diff --git a/frame/ind/bli_l3_ind.h b/frame/3/bli_l3_ind.h similarity index 96% rename from frame/ind/bli_l3_ind.h rename to frame/3/bli_l3_ind.h index 3d035cf2f9..87499428d5 100644 --- a/frame/ind/bli_l3_ind.h +++ b/frame/3/bli_l3_ind.h @@ -41,7 +41,7 @@ #undef GENPROT #define GENPROT( opname ) \ \ -void_fp PASTEMAC(opname,ind_get_avail)( num_t dt ); +ind_t PASTEMAC(opname,ind_find_avail)( num_t dt ); /*bool PASTEMAC(opname,ind_has_avail)( num_t dt ); */ GENPROT( gemm ) @@ -70,7 +70,7 @@ void bli_l3_ind_oper_set_enable_all( opid_t oper, num_t dt, bool status ); void bli_l3_ind_oper_set_enable( opid_t oper, ind_t method, num_t dt, bool status ); bool bli_l3_ind_oper_get_enable( opid_t oper, ind_t method, num_t dt ); -void_fp bli_l3_ind_oper_get_func( opid_t oper, ind_t method ); +bool bli_l3_ind_oper_is_impl( opid_t oper, ind_t method ); #endif diff --git a/frame/ind/ukernels/bli_l3_ind_ukr.h b/frame/3/bli_l3_ind_ukr.h similarity index 84% rename from frame/ind/ukernels/bli_l3_ind_ukr.h rename to frame/3/bli_l3_ind_ukr.h index 53cb0b6f88..f73a6ad907 100644 --- a/frame/ind/ukernels/bli_l3_ind_ukr.h +++ b/frame/3/bli_l3_ind_ukr.h @@ -53,11 +53,6 @@ void PASTEMAC(ch,opname) \ cntx_t* restrict cntx \ ); -INSERT_GENTPROT_BASIC0( gemm3mh_ukr_name ) -INSERT_GENTPROT_BASIC0( gemm3m1_ukr_name ) -INSERT_GENTPROT_BASIC0( gemm4mh_ukr_name ) -INSERT_GENTPROT_BASIC0( gemm4mb_ukr_name ) -INSERT_GENTPROT_BASIC0( gemm4m1_ukr_name ) INSERT_GENTPROT_BASIC0( gemm1m_ukr_name ) @@ -77,10 +72,6 @@ void PASTEMAC(ch,opname) \ cntx_t* restrict cntx \ ); -INSERT_GENTPROT_BASIC0( gemmtrsm3m1_l_ukr_name ) -INSERT_GENTPROT_BASIC0( gemmtrsm3m1_u_ukr_name ) -INSERT_GENTPROT_BASIC0( gemmtrsm4m1_l_ukr_name ) -INSERT_GENTPROT_BASIC0( gemmtrsm4m1_u_ukr_name ) INSERT_GENTPROT_BASIC0( gemmtrsm1m_l_ukr_name ) INSERT_GENTPROT_BASIC0( gemmtrsm1m_u_ukr_name ) @@ -97,10 +88,6 @@ void PASTEMAC(ch,opname) \ cntx_t* restrict cntx \ ); -INSERT_GENTPROT_BASIC0( trsm3m1_l_ukr_name ) -INSERT_GENTPROT_BASIC0( trsm3m1_u_ukr_name ) -INSERT_GENTPROT_BASIC0( trsm4m1_l_ukr_name ) -INSERT_GENTPROT_BASIC0( trsm4m1_u_ukr_name ) INSERT_GENTPROT_BASIC0( trsm1m_l_ukr_name ) INSERT_GENTPROT_BASIC0( trsm1m_u_ukr_name ) diff --git a/frame/3/bli_l3_oapi.c b/frame/3/bli_l3_oapi.c index 07054968eb..b5507a31cb 100644 --- a/frame/3/bli_l3_oapi.c +++ b/frame/3/bli_l3_oapi.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2021, The University of Texas at Austin Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without @@ -33,217 +33,31 @@ */ -// Guard the function definitions so that they are only compiled when -// #included from files that define the object API macros. -#ifdef BLIS_ENABLE_OAPI +#include "blis.h" // -// Define object-based interfaces. +// Define object-based interfaces (basic). // #undef GENFRONT #define GENFRONT( opname ) \ \ -void PASTEMAC(opname,EX_SUF) \ +void PASTEMAC0(opname) \ ( \ obj_t* alpha, \ obj_t* a, \ obj_t* b, \ obj_t* beta, \ obj_t* c \ - BLIS_OAPI_EX_PARAMS \ ) \ { \ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_2) \ - bli_init_once(); \ -\ - BLIS_OAPI_EX_DECLS \ -\ - /* If C has a zero dimension, return early. */ \ - if ( bli_obj_has_zero_dim( c ) ) {\ - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2) \ - return; \ - }\ -\ - /* if alpha or A or B has a zero dimension, \ - scale C by beta and return early. */ \ - if ( bli_obj_equals( alpha, &BLIS_ZERO ) || \ - bli_obj_has_zero_dim( a ) || \ - bli_obj_has_zero_dim( b ) ) \ - {\ - bli_scalm( beta, c ); \ - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2) \ - return;\ - }\ -\ - /* If the rntm is non-NULL, it may indicate that we should forgo sup - handling altogether. */ \ - bool enable_sup = TRUE; \ - if ( rntm != NULL ) enable_sup = bli_rntm_l3_sup( rntm ); \ -\ - if ( enable_sup ) \ - { \ - /* Execute the small/unpacked oapi handler. If it finds that the problem - does not fall within the thresholds that define "small", or for some - other reason decides not to use the small/unpacked implementation, - the function returns with BLIS_FAILURE, which causes execution to - proceed towards the conventional implementation. */ \ - err_t result = PASTEMAC(opname,sup)( alpha, a, b, beta, c, cntx, rntm ); \ - if ( result == BLIS_SUCCESS ) {\ - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2) \ - return; \ - } \ - } \ -\ - /* Only proceed with an induced method if each of the operands have a - complex storage datatype. NOTE: Allowing precisions to vary while - using 1m, which is what we do here, is unique to gemm; other level-3 - operations use 1m only if all storage datatypes are equal (and they - ignore the computation precision). If any operands are real, skip the - induced method chooser function and proceed directly with native - execution. */ \ - if ( bli_obj_is_complex( c ) && \ - bli_obj_is_complex( a ) && \ - bli_obj_is_complex( b ) ) \ - { \ - /* Invoke the operation's "ind" function--its induced method front-end. - For complex problems, it calls the highest priority induced method - that is available (ie: implemented and enabled), and if none are - enabled, it calls native execution. (For real problems, it calls - the operation's native execution interface.) */ \ - PASTEMAC(opname,ind)( alpha, a, b, beta, c, cntx, rntm ); \ - } \ - else \ - { \ - PASTEMAC(opname,nat)( alpha, a, b, beta, c, cntx, rntm ); \ - } \ - \ - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2) \ + /* Invoke the expert interface and request default cntx_t and rntm_t + objects. */ \ + PASTEMAC(opname,_ex)( alpha, a, b, beta, c, NULL, NULL ); \ } GENFRONT( gemm ) - - -#undef GENFRONT -#define GENFRONT( opname ) \ -\ -void PASTEMAC(opname,EX_SUF) \ - ( \ - obj_t* alpha, \ - obj_t* a, \ - obj_t* b, \ - obj_t* beta, \ - obj_t* c \ - BLIS_OAPI_EX_PARAMS \ - ) \ -{ \ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_2) \ - bli_init_once(); \ -\ - BLIS_OAPI_EX_DECLS \ -\ - /* If C has a zero dimension, return early. */ \ - if ( bli_obj_has_zero_dim( c ) ) {\ - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2) \ - return; \ - }\ -\ - /* if alpha or A or B has a zero dimension, \ - scale C by beta and return early. */ \ - if ( bli_obj_equals( alpha, &BLIS_ZERO ) || \ - bli_obj_has_zero_dim( a ) || \ - bli_obj_has_zero_dim( b ) ) \ - {\ - bli_scalm( beta, c ); \ - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2) \ - return;\ - }\ -\ - /* If the rntm is non-NULL, it may indicate that we should forgo sup - handling altogether. */ \ - bool enable_sup = TRUE; \ - if ( rntm != NULL ) enable_sup = bli_rntm_l3_sup( rntm ); \ -\ - if ( enable_sup ) \ - { \ - /* Execute the small/unpacked oapi handler. If it finds that the problem - does not fall within the thresholds that define "small", or for some - other reason decides not to use the small/unpacked implementation, - the function returns with BLIS_FAILURE, which causes execution to - proceed towards the conventional implementation. */ \ - err_t result = PASTEMAC(opname,sup)( alpha, a, b, beta, c, cntx, rntm ); \ - if ( result == BLIS_SUCCESS ) {\ - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2) \ - return; \ - } \ - } \ -\ - /* Only proceed with an induced method if each of the operands have a - complex storage datatype. NOTE: Allowing precisions to vary while - using 1m, which is what we do here, is unique to gemm; other level-3 - operations use 1m only if all storage datatypes are equal (and they - ignore the computation precision). If any operands are real, skip the - induced method chooser function and proceed directly with native - execution. */ \ - if ( bli_obj_is_complex( c ) && \ - bli_obj_is_complex( a ) && \ - bli_obj_is_complex( b ) ) \ - { \ - /* GEMMT Todo: Currently we support only native implementation - for complex datatypes.*/ \ - PASTEMAC(opname,nat)( alpha, a, b, beta, c, cntx, rntm ); \ - } \ - else \ - { \ - PASTEMAC(opname,nat)( alpha, a, b, beta, c, cntx, rntm ); \ - } \ - \ - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2) \ -} GENFRONT( gemmt ) - -#undef GENFRONT -#define GENFRONT( opname ) \ -\ -void PASTEMAC(opname,EX_SUF) \ - ( \ - obj_t* alpha, \ - obj_t* a, \ - obj_t* b, \ - obj_t* beta, \ - obj_t* c \ - BLIS_OAPI_EX_PARAMS \ - ) \ -{ \ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_2) \ - bli_init_once(); \ -\ - BLIS_OAPI_EX_DECLS \ -\ - /* Only proceed with an induced method if each of the operands have a - complex storage datatype. NOTE: Allowing precisions to vary while - using 1m, which is what we do here, is unique to gemm; other level-3 - operations use 1m only if all storage datatypes are equal (and they - ignore the computation precision). If any operands are real, skip the - induced method chooser function and proceed directly with native - execution. */ \ - if ( bli_obj_is_complex( c ) && \ - bli_obj_is_complex( a ) && \ - bli_obj_is_complex( b ) ) \ - { \ - /* Invoke the operation's "ind" function--its induced method front-end. - For complex problems, it calls the highest priority induced method - that is available (ie: implemented and enabled), and if none are - enabled, it calls native execution. (For real problems, it calls - the operation's native execution interface.) */ \ - PASTEMAC(opname,ind)( alpha, a, b, beta, c, cntx, rntm ); \ - } \ - else \ - { \ - PASTEMAC(opname,nat)( alpha, a, b, beta, c, cntx, rntm ); \ - } \ -} - GENFRONT( her2k ) GENFRONT( syr2k ) @@ -251,7 +65,7 @@ GENFRONT( syr2k ) #undef GENFRONT #define GENFRONT( opname ) \ \ -void PASTEMAC(opname,EX_SUF) \ +void PASTEMAC0(opname) \ ( \ side_t side, \ obj_t* alpha, \ @@ -259,34 +73,11 @@ void PASTEMAC(opname,EX_SUF) \ obj_t* b, \ obj_t* beta, \ obj_t* c \ - BLIS_OAPI_EX_PARAMS \ ) \ { \ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_2) \ - bli_init_once(); \ -\ - BLIS_OAPI_EX_DECLS \ -\ - /* Only proceed with an induced method if all operands have the same - (complex) datatype. If any datatypes differ, skip the induced method - chooser function and proceed directly with native execution, which is - where mixed datatype support will be implemented (if at all). */ \ - if ( bli_obj_dt( a ) == bli_obj_dt( c ) && \ - bli_obj_dt( b ) == bli_obj_dt( c ) && \ - bli_obj_is_complex( c ) ) \ - { \ - /* Invoke the operation's "ind" function--its induced method front-end. - For complex problems, it calls the highest priority induced method - that is available (ie: implemented and enabled), and if none are - enabled, it calls native execution. (For real problems, it calls - the operation's native execution interface.) */ \ - PASTEMAC(opname,ind)( side, alpha, a, b, beta, c, cntx, rntm ); \ - } \ - else \ - { \ - PASTEMAC(opname,nat)( side, alpha, a, b, beta, c, cntx, rntm ); \ - } \ - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2)\ + /* Invoke the expert interface and request default cntx_t and rntm_t + objects. */ \ + PASTEMAC(opname,_ex)( side, alpha, a, b, beta, c, NULL, NULL ); \ } GENFRONT( hemm ) @@ -297,157 +88,39 @@ GENFRONT( trmm3 ) #undef GENFRONT #define GENFRONT( opname ) \ \ -void PASTEMAC(opname,EX_SUF) \ +void PASTEMAC0(opname) \ ( \ obj_t* alpha, \ obj_t* a, \ obj_t* beta, \ obj_t* c \ - BLIS_OAPI_EX_PARAMS \ ) \ { \ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_2) \ - bli_init_once(); \ -\ - BLIS_OAPI_EX_DECLS \ -\ - /* If C has a zero dimension, return early. */ \ - if ( bli_obj_has_zero_dim( c ) ) {\ - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2) \ - return; \ - } \ -\ - /* If alpha or A or B has a zero dimension, \ - * scale C by beta and return early. */ \ -\ - if( bli_obj_equals( alpha, &BLIS_ZERO ) || \ - bli_obj_has_zero_dim( a ) ) \ - { \ - bli_scalm( beta, c ); \ - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2) \ - return; \ - } \ -\ - /* If the rntm is non-NULL, it may indicate that we should forgo SUP handling altogether. */ \ - bool enable_sup = TRUE; \ - if( rntm != NULL ) enable_sup = bli_rntm_l3_sup( rntm ); \ -\ - if( enable_sup ) \ - { \ - /* Execute the small/unpacked oapi handler. - * If it finds that the problem does not fall within the - * thresholds that define "small", or for some other reason - * decides not to use the small/unpacked implementation, - * the function returns with BLIS_FAILURE, which causes excution - * to proceed forward towards conventional implementation, */ \ -\ - err_t result = PASTEMAC(opname, sup) ( alpha, a, beta, c, cntx, rntm ); \ - if( result == BLIS_SUCCESS ) { \ - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2) \ - return; \ - } \ - } \ -\ - /* Only proceed with an induced method if all operands have the same - (complex) datatype. If any datatypes differ, skip the induced method - chooser function and proceed directly with native execution, which is - where mixed datatype support will be implemented (if at all). */ \ - if ( bli_obj_dt( a ) == bli_obj_dt( c ) && \ - bli_obj_is_complex( c ) ) \ - { \ - /* Invoke the operation's "ind" function--its induced method front-end. - For complex problems, it calls the highest priority induced method - that is available (ie: implemented and enabled), and if none are - enabled, it calls native execution. (For real problems, it calls - the operation's native execution interface.) */ \ - PASTEMAC(opname,ind)( alpha, a, beta, c, cntx, rntm ); \ - } \ - else \ - { \ - PASTEMAC(opname,nat)( alpha, a, beta, c, cntx, rntm ); \ - } \ + /* Invoke the expert interface and request default cntx_t and rntm_t + objects. */ \ + PASTEMAC(opname,_ex)( alpha, a, beta, c, NULL, NULL ); \ } +GENFRONT( herk ) GENFRONT( syrk ) -#undef GENFRONT -#define GENFRONT( opname ) \ -\ -void PASTEMAC(opname,EX_SUF) \ - ( \ - obj_t* alpha, \ - obj_t* a, \ - obj_t* beta, \ - obj_t* c \ - BLIS_OAPI_EX_PARAMS \ - ) \ -{ \ - bli_init_once(); \ -\ - BLIS_OAPI_EX_DECLS \ -\ - /* Only proceed with an induced method if all operands have the same - (complex) datatype. If any datatypes differ, skip the induced method - chooser function and proceed directly with native execution, which is - where mixed datatype support will be implemented (if at all). */ \ - if ( bli_obj_dt( a ) == bli_obj_dt( c ) && \ - bli_obj_is_complex( c ) ) \ - { \ - /* Invoke the operation's "ind" function--its induced method front-end. - For complex problems, it calls the highest priority induced method - that is available (ie: implemented and enabled), and if none are - enabled, it calls native execution. (For real problems, it calls - the operation's native execution interface.) */ \ - PASTEMAC(opname,ind)( alpha, a, beta, c, cntx, rntm ); \ - } \ - else \ - { \ - PASTEMAC(opname,nat)( alpha, a, beta, c, cntx, rntm ); \ - } \ -} -GENFRONT(herk) #undef GENFRONT #define GENFRONT( opname ) \ \ -void PASTEMAC(opname,EX_SUF) \ +void PASTEMAC0(opname) \ ( \ side_t side, \ obj_t* alpha, \ obj_t* a, \ obj_t* b \ - BLIS_OAPI_EX_PARAMS \ ) \ { \ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_2) \ - bli_init_once(); \ -\ - BLIS_OAPI_EX_DECLS \ -\ - /* Only proceed with an induced method if all operands have the same - (complex) datatype. If any datatypes differ, skip the induced method - chooser function and proceed directly with native execution, which is - where mixed datatype support will be implemented (if at all). */ \ - if ( bli_obj_dt( a ) == bli_obj_dt( b ) && \ - bli_obj_is_complex( b ) ) \ - { \ - /* Invoke the operation's "ind" function--its induced method front-end. - For complex problems, it calls the highest priority induced method - that is available (ie: implemented and enabled), and if none are - enabled, it calls native execution. (For real problems, it calls - the operation's native execution interface.) */ \ - PASTEMAC(opname,ind)( side, alpha, a, b, cntx, rntm ); \ - } \ - else \ - { \ - PASTEMAC(opname,nat)( side, alpha, a, b, cntx, rntm ); \ - } \ - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2); \ + /* Invoke the expert interface and request default cntx_t and rntm_t + objects. */ \ + PASTEMAC(opname,_ex)( side, alpha, a, b, NULL, NULL ); \ } GENFRONT( trmm ) GENFRONT( trsm ) - -#endif - diff --git a/frame/3/bli_l3_oapi.h b/frame/3/bli_l3_oapi.h index 5375d5708c..9100c93e7c 100644 --- a/frame/3/bli_l3_oapi.h +++ b/frame/3/bli_l3_oapi.h @@ -35,20 +35,19 @@ // -// Prototype object-based interfaces. +// Prototype object-based interfaces (basic). // #undef GENPROT #define GENPROT( opname ) \ \ -BLIS_EXPORT_BLIS void PASTEMAC(opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC0(opname) \ ( \ obj_t* alpha, \ obj_t* a, \ obj_t* b, \ obj_t* beta, \ obj_t* c \ - BLIS_OAPI_EX_PARAMS \ ); GENPROT( gemm ) @@ -60,7 +59,7 @@ GENPROT( syr2k ) #undef GENPROT #define GENPROT( opname ) \ \ -BLIS_EXPORT_BLIS void PASTEMAC(opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC0(opname) \ ( \ side_t side, \ obj_t* alpha, \ @@ -68,7 +67,6 @@ BLIS_EXPORT_BLIS void PASTEMAC(opname,EX_SUF) \ obj_t* b, \ obj_t* beta, \ obj_t* c \ - BLIS_OAPI_EX_PARAMS \ ); GENPROT( hemm ) @@ -79,13 +77,12 @@ GENPROT( trmm3 ) #undef GENPROT #define GENPROT( opname ) \ \ -BLIS_EXPORT_BLIS void PASTEMAC(opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC0(opname) \ ( \ obj_t* alpha, \ obj_t* a, \ obj_t* beta, \ obj_t* c \ - BLIS_OAPI_EX_PARAMS \ ); GENPROT( herk ) @@ -95,13 +92,12 @@ GENPROT( syrk ) #undef GENPROT #define GENPROT( opname ) \ \ -BLIS_EXPORT_BLIS void PASTEMAC(opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC0(opname) \ ( \ side_t side, \ obj_t* alpha, \ obj_t* a, \ obj_t* b \ - BLIS_OAPI_EX_PARAMS \ ); GENPROT( trmm ) diff --git a/frame/3/bli_l3_oapi_ex.c b/frame/3/bli_l3_oapi_ex.c index 76f4fe16ab..a51270c4fe 100644 --- a/frame/3/bli_l3_oapi_ex.c +++ b/frame/3/bli_l3_oapi_ex.c @@ -4,7 +4,8 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2021, The University of Texas at Austin + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -34,13 +35,519 @@ #include "blis.h" -// Include cpp macros that instantiate the API definition templates as -// having expert parameters. -#include "bli_oapi_ex.h" +// +// Define object-based interfaces (expert). +// -// Define the macro protecting the object API definitions. -#define BLIS_ENABLE_OAPI +#undef GENFRONT +#define GENFRONT( opname ) \ +\ +void PASTEMAC(opname,BLIS_OAPI_EX_SUF) \ + ( \ + obj_t* alpha, \ + obj_t* a, \ + obj_t* b, \ + obj_t* beta, \ + obj_t* c, \ + cntx_t* cntx, \ + rntm_t* rntm \ + ) \ +{ \ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_2) \ + bli_init_once(); \ +\ + /* If C has a zero dimension, return early. */ \ + if ( bli_obj_has_zero_dim( c ) ) {\ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2) \ + return; \ + }\ +\ + /* if alpha or A or B has a zero dimension, \ + scale C by beta and return early. */ \ + if ( bli_obj_equals( alpha, &BLIS_ZERO ) || \ + bli_obj_has_zero_dim( a ) || \ + bli_obj_has_zero_dim( b ) ) \ + {\ + bli_scalm( beta, c ); \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2) \ + return;\ + }\ +\ + /* If the rntm is non-NULL, it may indicate that we should forgo sup + handling altogether. */ \ + bool enable_sup = TRUE; \ + if ( rntm != NULL ) enable_sup = bli_rntm_l3_sup( rntm ); \ +\ + if ( enable_sup ) \ + { \ + /* Execute the small/unpacked oapi handler. If it finds that the problem + does not fall within the thresholds that define "small", or for some + other reason decides not to use the small/unpacked implementation, + the function returns with BLIS_FAILURE, which causes execution to + proceed towards the conventional implementation. */ \ + err_t result = PASTEMAC(opname,sup)( alpha, a, b, beta, c, cntx, rntm ); \ + if ( result == BLIS_SUCCESS ) \ + { \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2) \ + return; \ + } \ + } \ +\ + /* Initialize a local runtime with global settings if necessary. Note + that in the case that a runtime is passed in, we make a local copy. */ \ + rntm_t rntm_l; \ + if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); rntm = &rntm_l; } \ + else { rntm_l = *rntm; rntm = &rntm_l; } \ +\ + /* Default to using native execution. */ \ + num_t dt = bli_obj_dt( c ); \ + ind_t im = BLIS_NAT; \ +\ + /* If each matrix operand has a complex storage datatype, try to get an + induced method (if one is available and enabled). NOTE: Allowing + precisions to vary while using 1m, which is what we do here, is unique + to gemm; other level-3 operations use 1m only if all storage datatypes + are equal (and they ignore the computation precision). */ \ + if ( bli_obj_is_complex( c ) && \ + bli_obj_is_complex( a ) && \ + bli_obj_is_complex( b ) ) \ + { \ + /* Find the highest priority induced method that is both enabled and + available for the current operation. (If an induced method is + available but not enabled, or simply unavailable, BLIS_NAT will + be returned here.) */ \ + im = PASTEMAC(opname,ind_find_avail)( dt ); \ + } \ +\ + /* If necessary, obtain a valid context from the gks using the induced + method id determined above. */ \ + if ( cntx == NULL ) cntx = bli_gks_query_ind_cntx( im, dt ); \ +\ + /* Check the operands. */ \ + if ( bli_error_checking_is_enabled() ) \ + PASTEMAC(opname,_check)( alpha, a, b, beta, c, cntx ); \ +\ + /* Invoke the operation's front-end and request the default control tree. */ \ + PASTEMAC(opname,_front)( alpha, a, b, beta, c, cntx, rntm, NULL ); \ +\ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2) \ +} -// Include the object API definitions here. -#include "bli_l3_oapi.c" +// If a sandbox was enabled, we forgo defining bli_gemm_ex() since it will be +// defined in the sandbox environment. +#ifndef BLIS_ENABLE_SANDBOX +GENFRONT( gemm ) +#endif + +#undef GENFRONT +#define GENFRONT( opname ) \ +\ +void PASTEMAC(opname,BLIS_OAPI_EX_SUF) \ + ( \ + obj_t* alpha, \ + obj_t* a, \ + obj_t* b, \ + obj_t* beta, \ + obj_t* c, \ + cntx_t* cntx, \ + rntm_t* rntm \ + ) \ +{ \ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_2) \ + bli_init_once(); \ +\ + /* If C has a zero dimension, return early. */ \ + if ( bli_obj_has_zero_dim( c ) ) {\ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2) \ + return; \ + }\ +\ + /* if alpha or A or B has a zero dimension, \ + scale C by beta and return early. */ \ + if ( bli_obj_equals( alpha, &BLIS_ZERO ) || \ + bli_obj_has_zero_dim( a ) || \ + bli_obj_has_zero_dim( b ) ) \ + {\ + bli_scalm( beta, c ); \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2) \ + return;\ + }\ +\ + /* If the rntm is non-NULL, it may indicate that we should forgo sup + handling altogether. */ \ + bool enable_sup = TRUE; \ + if ( rntm != NULL ) enable_sup = bli_rntm_l3_sup( rntm ); \ +\ + if ( enable_sup ) \ + { \ + /* Execute the small/unpacked oapi handler. If it finds that the problem + does not fall within the thresholds that define "small", or for some + other reason decides not to use the small/unpacked implementation, + the function returns with BLIS_FAILURE, which causes execution to + proceed towards the conventional implementation. */ \ + err_t result = PASTEMAC(opname,sup)( alpha, a, b, beta, c, cntx, rntm ); \ + if ( result == BLIS_SUCCESS ) \ + {\ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2) \ + return; \ + } \ + } \ +\ + /* Initialize a local runtime with global settings if necessary. Note + that in the case that a runtime is passed in, we make a local copy. */ \ + rntm_t rntm_l; \ + if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); rntm = &rntm_l; } \ + else { rntm_l = *rntm; rntm = &rntm_l; } \ +\ + /* Default to using native execution. */ \ + num_t dt = bli_obj_dt( c ); \ + ind_t im = BLIS_NAT; \ +\ + /* If all matrix operands are complex and of the same storage datatype, try + to get an induced method (if one is available and enabled). */ \ + if ( bli_obj_dt( a ) == bli_obj_dt( c ) && \ + bli_obj_dt( b ) == bli_obj_dt( c ) && \ + bli_obj_is_complex( c ) ) \ + { \ + /* Find the highest priority induced method that is both enabled and + available for the current operation. (If an induced method is + available but not enabled, or simply unavailable, BLIS_NAT will + be returned here.) */ \ + im = PASTEMAC(opname,ind_find_avail)( dt ); \ + } \ +\ + /* If necessary, obtain a valid context from the gks using the induced + method id determined above. */ \ + if ( cntx == NULL ) cntx = bli_gks_query_ind_cntx( im, dt ); \ +\ + /* Check the operands. */ \ + if ( bli_error_checking_is_enabled() ) \ + PASTEMAC(opname,_check)( alpha, a, b, beta, c, cntx ); \ +\ + /* Invoke the operation's front-end and request the default control tree. */ \ + PASTEMAC(opname,_front)( alpha, a, b, beta, c, cntx, rntm, NULL ); \ +\ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2) \ +} + +GENFRONT( gemmt ) + +#undef GENFRONT +#define GENFRONT( opname ) \ +\ +void PASTEMAC(opname,BLIS_OAPI_EX_SUF) \ + ( \ + obj_t* alpha, \ + obj_t* a, \ + obj_t* b, \ + obj_t* beta, \ + obj_t* c, \ + cntx_t* cntx, \ + rntm_t* rntm \ + ) \ +{ \ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_2) \ + bli_init_once(); \ +\ + /* Initialize a local runtime with global settings if necessary. Note + that in the case that a runtime is passed in, we make a local copy. */ \ + rntm_t rntm_l; \ + if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); rntm = &rntm_l; } \ + else { rntm_l = *rntm; rntm = &rntm_l; } \ +\ + /* Default to using native execution. */ \ + num_t dt = bli_obj_dt( c ); \ + ind_t im = BLIS_NAT; \ +\ + /* If all matrix operands are complex and of the same storage datatype, try + to get an induced method (if one is available and enabled). */ \ + if ( bli_obj_dt( a ) == bli_obj_dt( c ) && \ + bli_obj_dt( b ) == bli_obj_dt( c ) && \ + bli_obj_is_complex( c ) ) \ + { \ + /* Find the highest priority induced method that is both enabled and + available for the current operation. (If an induced method is + available but not enabled, or simply unavailable, BLIS_NAT will + be returned here.) */ \ + im = PASTEMAC(opname,ind_find_avail)( dt ); \ + } \ +\ + /* If necessary, obtain a valid context from the gks using the induced + method id determined above. */ \ + if ( cntx == NULL ) cntx = bli_gks_query_ind_cntx( im, dt ); \ +\ + /* Check the operands. */ \ + if ( bli_error_checking_is_enabled() ) \ + PASTEMAC(opname,_check)( alpha, a, b, beta, c, cntx ); \ +\ + /* Invoke the operation's front-end and request the default control tree. */ \ + PASTEMAC(opname,_front)( alpha, a, b, beta, c, cntx, rntm, NULL ); \ +\ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2) \ +} + +GENFRONT( her2k ) +GENFRONT( syr2k ) + + +#undef GENFRONT +#define GENFRONT( opname ) \ +\ +void PASTEMAC(opname,BLIS_OAPI_EX_SUF) \ + ( \ + side_t side, \ + obj_t* alpha, \ + obj_t* a, \ + obj_t* b, \ + obj_t* beta, \ + obj_t* c, \ + cntx_t* cntx, \ + rntm_t* rntm \ + ) \ +{ \ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_2) \ + bli_init_once(); \ +\ + /* Initialize a local runtime with global settings if necessary. Note + that in the case that a runtime is passed in, we make a local copy. */ \ + rntm_t rntm_l; \ + if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); rntm = &rntm_l; } \ + else { rntm_l = *rntm; rntm = &rntm_l; } \ +\ + /* Default to using native execution. */ \ + num_t dt = bli_obj_dt( c ); \ + ind_t im = BLIS_NAT; \ +\ + /* If all matrix operands are complex and of the same storage datatype, try + to get an induced method (if one is available and enabled). */ \ + if ( bli_obj_dt( a ) == bli_obj_dt( c ) && \ + bli_obj_dt( b ) == bli_obj_dt( c ) && \ + bli_obj_is_complex( c ) ) \ + { \ + /* Find the highest priority induced method that is both enabled and + available for the current operation. (If an induced method is + available but not enabled, or simply unavailable, BLIS_NAT will + be returned here.) */ \ + im = PASTEMAC(opname,ind_find_avail)( dt ); \ + } \ +\ + /* If necessary, obtain a valid context from the gks using the induced + method id determined above. */ \ + if ( cntx == NULL ) cntx = bli_gks_query_ind_cntx( im, dt ); \ +\ + /* Check the operands. */ \ + if ( bli_error_checking_is_enabled() ) \ + PASTEMAC(opname,_check)( side, alpha, a, b, beta, c, cntx ); \ +\ + /* Invoke the operation's front-end and request the default control tree. */ \ + PASTEMAC(opname,_front)( side, alpha, a, b, beta, c, cntx, rntm, NULL ); \ +\ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2) \ +} + +GENFRONT( hemm ) +GENFRONT( symm ) +GENFRONT( trmm3 ) + + +#undef GENFRONT +#define GENFRONT( opname ) \ +\ +void PASTEMAC(opname,BLIS_OAPI_EX_SUF) \ + ( \ + obj_t* alpha, \ + obj_t* a, \ + obj_t* beta, \ + obj_t* c, \ + cntx_t* cntx, \ + rntm_t* rntm \ + ) \ +{ \ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_2) \ + bli_init_once(); \ +\ + /* If C has a zero dimension, return early. */ \ + if ( bli_obj_has_zero_dim( c ) ) {\ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2) \ + return; \ + } \ +\ + /* If alpha or A or B has a zero dimension, \ + scale C by beta and return early. */ \ +\ + if( bli_obj_equals( alpha, &BLIS_ZERO ) || \ + bli_obj_has_zero_dim( a ) ) \ + { \ + bli_scalm( beta, c ); \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2) \ + return; \ + } \ +\ + /* If the rntm is non-NULL, it may indicate that we should forgo SUP handling altogether. */ \ + bool enable_sup = TRUE; \ + if( rntm != NULL ) enable_sup = bli_rntm_l3_sup( rntm ); \ +\ + if( enable_sup ) \ + { \ + /* Execute the small/unpacked oapi handler. + If it finds that the problem does not fall within the + thresholds that define "small", or for some other reason + decides not to use the small/unpacked implementation, + the function returns with BLIS_FAILURE, which causes excution + to proceed forward towards conventional implementation, */ \ +\ + err_t result = PASTEMAC(opname, sup) ( alpha, a, beta, c, cntx, rntm ); \ + if( result == BLIS_SUCCESS ) { \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2) \ + return; \ + } \ + } \ +\ + /* Initialize a local runtime with global settings if necessary. Note + that in the case that a runtime is passed in, we make a local copy. */ \ + rntm_t rntm_l; \ + if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); rntm = &rntm_l; } \ + else { rntm_l = *rntm; rntm = &rntm_l; } \ +\ + /* Default to using native execution. */ \ + num_t dt = bli_obj_dt( c ); \ + ind_t im = BLIS_NAT; \ +\ + /* If all matrix operands are complex and of the same storage datatype, try + to get an induced method (if one is available and enabled). */ \ + if ( bli_obj_dt( a ) == bli_obj_dt( c ) && \ + bli_obj_is_complex( c ) ) \ + { \ + /* Find the highest priority induced method that is both enabled and + available for the current operation. (If an induced method is + available but not enabled, or simply unavailable, BLIS_NAT will + be returned here.) */ \ + im = PASTEMAC(opname,ind_find_avail)( dt ); \ + } \ +\ + /* If necessary, obtain a valid context from the gks using the induced + method id determined above. */ \ + if ( cntx == NULL ) cntx = bli_gks_query_ind_cntx( im, dt ); \ +\ + /* Check the operands. */ \ + if ( bli_error_checking_is_enabled() ) \ + PASTEMAC(opname,_check)( alpha, a, beta, c, cntx ); \ +\ + /* Invoke the operation's front-end and request the default control tree. */ \ + PASTEMAC(opname,_front)( alpha, a, beta, c, cntx, rntm, NULL ); \ +\ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2) \ +} + +GENFRONT( syrk ) + +#undef GENFRONT +#define GENFRONT( opname ) \ +\ +void PASTEMAC(opname,BLIS_OAPI_EX_SUF) \ + ( \ + obj_t* alpha, \ + obj_t* a, \ + obj_t* beta, \ + obj_t* c, \ + cntx_t* cntx, \ + rntm_t* rntm \ + ) \ +{ \ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_2) \ + bli_init_once(); \ +\ + /* Initialize a local runtime with global settings if necessary. Note + that in the case that a runtime is passed in, we make a local copy. */ \ + rntm_t rntm_l; \ + if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); rntm = &rntm_l; } \ + else { rntm_l = *rntm; rntm = &rntm_l; } \ +\ + /* Default to using native execution. */ \ + num_t dt = bli_obj_dt( c ); \ + ind_t im = BLIS_NAT; \ +\ + /* If all matrix operands are complex and of the same storage datatype, try + to get an induced method (if one is available and enabled). */ \ + if ( bli_obj_dt( a ) == bli_obj_dt( c ) && \ + bli_obj_is_complex( c ) ) \ + { \ + /* Find the highest priority induced method that is both enabled and + available for the current operation. (If an induced method is + available but not enabled, or simply unavailable, BLIS_NAT will + be returned here.) */ \ + im = PASTEMAC(opname,ind_find_avail)( dt ); \ + } \ +\ + /* If necessary, obtain a valid context from the gks using the induced + method id determined above. */ \ + if ( cntx == NULL ) cntx = bli_gks_query_ind_cntx( im, dt ); \ +\ + /* Check the operands. */ \ + if ( bli_error_checking_is_enabled() ) \ + PASTEMAC(opname,_check)( alpha, a, beta, c, cntx ); \ +\ + /* Invoke the operation's front-end and request the default control tree. */ \ + PASTEMAC(opname,_front)( alpha, a, beta, c, cntx, rntm, NULL ); \ +\ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2) \ +} + +GENFRONT( herk ) + + +#undef GENFRONT +#define GENFRONT( opname ) \ +\ +void PASTEMAC(opname,BLIS_OAPI_EX_SUF) \ + ( \ + side_t side, \ + obj_t* alpha, \ + obj_t* a, \ + obj_t* b, \ + cntx_t* cntx, \ + rntm_t* rntm \ + ) \ +{ \ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_2) \ + bli_init_once(); \ +\ + /* Initialize a local runtime with global settings if necessary. Note + that in the case that a runtime is passed in, we make a local copy. */ \ + rntm_t rntm_l; \ + if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); rntm = &rntm_l; } \ + else { rntm_l = *rntm; rntm = &rntm_l; } \ +\ + /* Default to using native execution. */ \ + num_t dt = bli_obj_dt( b ); \ + ind_t im = BLIS_NAT; \ +\ + /* If all matrix operands are complex and of the same storage datatype, try + to get an induced method (if one is available and enabled). */ \ + if ( bli_obj_dt( a ) == bli_obj_dt( b ) && \ + bli_obj_is_complex( b ) ) \ + { \ + /* Find the highest priority induced method that is both enabled and + available for the current operation. (If an induced method is + available but not enabled, or simply unavailable, BLIS_NAT will + be returned here.) */ \ + im = PASTEMAC(opname,ind_find_avail)( dt ); \ + } \ +\ + /* If necessary, obtain a valid context from the gks using the induced + method id determined above. */ \ + if ( cntx == NULL ) cntx = bli_gks_query_ind_cntx( im, dt ); \ +\ + /* Check the operands. */ \ + if ( bli_error_checking_is_enabled() ) \ + PASTEMAC(opname,_check)( side, alpha, a, b, cntx ); \ +\ + /* Invoke the operation's front-end and request the default control tree. */ \ + PASTEMAC(opname,_front)( side, alpha, a, b, cntx, rntm, NULL ); \ +\ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2) \ +} + +GENFRONT( trmm ) +GENFRONT( trsm ) diff --git a/sandbox/gemmlike/bli_gemmnat.c b/frame/3/bli_l3_oapi_ex.h similarity index 55% rename from sandbox/gemmlike/bli_gemmnat.c rename to frame/3/bli_l3_oapi_ex.h index 37fb701859..0b7cf0981c 100644 --- a/sandbox/gemmlike/bli_gemmnat.c +++ b/frame/3/bli_l3_oapi_ex.h @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -32,21 +33,15 @@ */ -// Given the current architecture of BLIS sandboxes, bli_gemmnat() is the -// entry point to any sandbox implementation. -// NOTE: This function is implemented identically to the function that it -// overrides in frame/ind/oapi/bli_l3_nat_oapi.c. This means that we are -// forgoing the option of customizing the implementations that underlie -// bli_gemm() and bli_?gemm(). Any new code defined in this sandbox -// directory, however, will be included in the BLIS. +// +// Prototype object-based interfaces (expert). +// -#include "blis.h" - -#undef GENFRONT -#define GENFRONT( opname, cname, imeth ) \ +#undef GENPROT +#define GENPROT( opname ) \ \ -void PASTEMAC(opname,imeth) \ +BLIS_EXPORT_BLIS void PASTEMAC(opname,BLIS_OAPI_EX_SUF) \ ( \ obj_t* alpha, \ obj_t* a, \ @@ -55,34 +50,64 @@ void PASTEMAC(opname,imeth) \ obj_t* c, \ cntx_t* cntx, \ rntm_t* rntm \ - ) \ -{ \ -\ - /* A switch to easily toggle whether we use the sandbox implementation - of bls_gemm() as the implementation for bli_gemm(). (This allows for - easy testing of bls_gemm() via the testsuite.) */ \ - if ( 1 ) \ - { \ - bls_gemm_ex( alpha, a, b, beta, c, cntx, rntm ); \ - return; \ - } \ -\ - bli_init_once(); \ + ); + +GENPROT( gemm ) +GENPROT( gemmt ) +GENPROT( her2k ) +GENPROT( syr2k ) + + +#undef GENPROT +#define GENPROT( opname ) \ \ - /* Obtain a valid (native) context from the gks if necessary. */ \ - if ( cntx == NULL ) cntx = bli_gks_query_cntx(); \ +BLIS_EXPORT_BLIS void PASTEMAC(opname,BLIS_OAPI_EX_SUF) \ + ( \ + side_t side, \ + obj_t* alpha, \ + obj_t* a, \ + obj_t* b, \ + obj_t* beta, \ + obj_t* c, \ + cntx_t* cntx, \ + rntm_t* rntm \ + ); + +GENPROT( hemm ) +GENPROT( symm ) +GENPROT( trmm3 ) + + +#undef GENPROT +#define GENPROT( opname ) \ \ - /* Initialize a local runtime with global settings if necessary. Note - that in the case that a runtime is passed in, we make a local copy. */ \ - rntm_t rntm_l; \ - if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); rntm = &rntm_l; } \ - else { rntm_l = *rntm; rntm = &rntm_l; } \ +BLIS_EXPORT_BLIS void PASTEMAC(opname,BLIS_OAPI_EX_SUF) \ + ( \ + obj_t* alpha, \ + obj_t* a, \ + obj_t* beta, \ + obj_t* c, \ + cntx_t* cntx, \ + rntm_t* rntm \ + ); + +GENPROT( herk ) +GENPROT( syrk ) + + +#undef GENPROT +#define GENPROT( opname ) \ \ - /* Invoke the operation's front end. */ \ - PASTEMAC(opname,_front) \ - ( \ - alpha, a, b, beta, c, cntx, rntm, NULL \ - ); \ -} - -GENFRONT( gemm, gemm, nat ) +BLIS_EXPORT_BLIS void PASTEMAC(opname,BLIS_OAPI_EX_SUF) \ + ( \ + side_t side, \ + obj_t* alpha, \ + obj_t* a, \ + obj_t* b, \ + cntx_t* cntx, \ + rntm_t* rntm \ + ); + +GENPROT( trmm ) +GENPROT( trsm ) + diff --git a/frame/3/bli_l3_schema.c b/frame/3/bli_l3_schema.c new file mode 100644 index 0000000000..bde30c5277 --- /dev/null +++ b/frame/3/bli_l3_schema.c @@ -0,0 +1,80 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +void bli_l3_set_schemas + ( + obj_t* a, + obj_t* b, + obj_t* c, + cntx_t* cntx + ) +{ + // Begin with pack schemas for native execution. + pack_t schema_a = BLIS_PACKED_ROW_PANELS; + pack_t schema_b = BLIS_PACKED_COL_PANELS; + + // When executing the 1m method, choose the appropriate pack schemas based + // on the microkernel preference encoded within the current cntx_t (which + // was presumably returned by the gks). + if ( bli_cntx_method( cntx ) == BLIS_1M ) + { + num_t dt = bli_obj_domain( c ) | bli_obj_comp_prec( c ); + + // Note that bli_cntx_l3_vir_ukr_prefers_cols_dt() will use the real + // projection of dt to query the preference of the corresponding native + // real-domain microkernel. This is what ultimately determines which + // variant of 1m is applicable. + if ( bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ) ) + { + schema_a = BLIS_PACKED_ROW_PANELS_1E; + schema_b = BLIS_PACKED_COL_PANELS_1R; + } + else + { + schema_a = BLIS_PACKED_ROW_PANELS_1R; + schema_b = BLIS_PACKED_COL_PANELS_1E; + } + } + + // Embed the schemas into the objects for A and B. This is a sort of hack + // for communicating the desired pack schemas to bli_gemm_cntl_create() + // (via bli_l3_thread_decorator() and bli_l3_cntl_create_if()). This allows + // us to subsequently access the schemas from the control tree, which + // hopefully reduces some confusion, particularly in bli_packm_init(). + bli_obj_set_pack_schema( schema_a, a ); + bli_obj_set_pack_schema( schema_b, b ); +} + diff --git a/frame/1m/packm/bli_packm_cxk_4mi.h b/frame/3/bli_l3_schema.h similarity index 76% rename from frame/1m/packm/bli_packm_cxk_4mi.h rename to frame/3/bli_l3_schema.h index 244f2d045e..c6a12ce520 100644 --- a/frame/1m/packm/bli_packm_cxk_4mi.h +++ b/frame/3/bli_l3_schema.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2021, The University of Texas at Austin Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -32,22 +32,10 @@ */ - -#undef GENTPROTCO -#define GENTPROTCO( ctype, ctype_r, ch, chr, varname ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - conj_t conja, \ - dim_t panel_dim, \ - dim_t panel_dim_max, \ - dim_t panel_len, \ - dim_t panel_len_max, \ - ctype* kappa, \ - ctype* a, inc_t inca, inc_t lda, \ - ctype* p, inc_t is_p, inc_t ldp, \ - cntx_t* cntx \ +void bli_l3_set_schemas + ( + obj_t* a, + obj_t* b, + obj_t* c, + cntx_t* cntx ); - -INSERT_GENTPROTCO_BASIC0( packm_cxk_4mi ) - diff --git a/frame/3/bli_l3_smart_threading.c b/frame/3/bli_l3_smart_threading.c index 309ae7265e..d10b13269c 100644 --- a/frame/3/bli_l3_smart_threading.c +++ b/frame/3/bli_l3_smart_threading.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -250,7 +250,7 @@ static err_t bli_gemm_ic_jc_optimum_sup_arch_dispatcher max_available_nt, cntx, rntm ); } - else if ( id == BLIS_ARCH_ZEN4 ) + else if ( id == BLIS_ARCH_ZEN5 || id == BLIS_ARCH_ZEN4 ) { ret_val = bli_gemm_ic_jc_optimum_sup_zen4 ( diff --git a/frame/3/bli_l3_sup.c b/frame/3/bli_l3_sup.c index 252601b742..8fe977d4a5 100644 --- a/frame/3/bli_l3_sup.c +++ b/frame/3/bli_l3_sup.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -107,9 +107,12 @@ err_t bli_gemmsup if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); rntm = &rntm_l; } else { rntm_l = *rntm; rntm = &rntm_l; } -#if defined(BLIS_FAMILY_ZEN4) || defined(BLIS_FAMILY_AMDZEN) || defined(BLIS_FAMILY_X86_64) +#if defined(BLIS_FAMILY_ZEN5) || defined(BLIS_FAMILY_ZEN4) || defined(BLIS_FAMILY_AMDZEN) || defined(BLIS_FAMILY_X86_64) - if((bli_arch_query_id() == BLIS_ARCH_ZEN4)) + // Query the architecture ID + arch_t id = bli_arch_query_id(); + + if((id == BLIS_ARCH_ZEN5) || (id == BLIS_ARCH_ZEN4)) { if(( bli_obj_dt(a) == BLIS_DOUBLE ) || ( bli_obj_dt(a) == BLIS_DCOMPLEX)) { @@ -202,6 +205,15 @@ err_t bli_gemmtsup return BLIS_FAILURE; #endif +#if defined(__x86_64__) || defined(_M_X64) || defined(__i386) || defined(_M_IX86) + if (bli_cpuid_is_avx2fma3_supported() == FALSE){ + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_2, "AVX instruction is not supported"); + return BLIS_FAILURE; + } +#else + return BLIS_FAILURE; +#endif + // Return early if this is a mixed-datatype computation. if ( bli_obj_dt( c ) != bli_obj_dt( a ) || bli_obj_dt( c ) != bli_obj_dt( b ) || diff --git a/frame/3/bli_l3_sup_int.h b/frame/3/bli_l3_sup_int.h index 09ecda6268..0bb4ae5eef 100644 --- a/frame/3/bli_l3_sup_int.h +++ b/frame/3/bli_l3_sup_int.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -32,7 +32,7 @@ */ -err_t bli_gemmsup_int +BLIS_EXPORT_BLIS err_t bli_gemmsup_int ( obj_t* alpha, obj_t* a, diff --git a/frame/3/bli_l3_sup_int_amd.c b/frame/3/bli_l3_sup_int_amd.c index 3b8fce5b3e..cbf5e46f6b 100644 --- a/frame/3/bli_l3_sup_int_amd.c +++ b/frame/3/bli_l3_sup_int_amd.c @@ -427,7 +427,29 @@ err_t bli_gemmtsup_int /* Enable packing for A matrix for higher sizes. Note that pack A * * becomes pack B inside var2m because this is transpose case*/ - if(bli_is_double(dt) && (n_threads==1)) + arch_t cpu_id = bli_arch_query_id(); + /* Do not pack A for ZEN4 and ZEN5 because the GEMM kernels + * used are column major and GEMMT kernels used are row major. + * Packing matrix A makes matrix B in the GEMMT kernels column + * major which is not supported by row major kernels. + * + * C<- alpha * op(A) *op(B) + beta * C. + * C(nxn) - A(n x k) * B(k x n) + * DGEMM is col-preferred kernel + * DGEMMT = DGEMM + DGEMMT + * DGEMM is col-preferred and DGEMMT is row-preferred. + * DGEMM is evaluated as C = A*B (all col-storage) + * whereas DGEMMT is evaluated as C = B * A (row-storage). + * When A is packed it is packed as row-panels with + * col-stored elements. + * So DGEMM is evaluated as C = A*B (A is col-stored) + * it aligns with col-stored preference. + * For DGEMMT: C = B * A, here A will become col-stored because of packing + * and as result it will break the DGEMMT kernel assumption that A is + * row-storage. + **/ + if( ( cpu_id != BLIS_ARCH_ZEN4 && cpu_id != BLIS_ARCH_ZEN5) && + bli_is_double(dt) && (n_threads==1)) { if((m > 320) && (k > 50)) bli_rntm_set_pack_a( 1, rntm ); diff --git a/frame/3/bli_l3_sup_ker_prot.h b/frame/3/bli_l3_sup_ker_prot.h index 65ecbecb81..5dbfeefe94 100644 --- a/frame/3/bli_l3_sup_ker_prot.h +++ b/frame/3/bli_l3_sup_ker_prot.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -71,7 +71,7 @@ err_t PASTEMAC0(opname) \ #define TRSMSMALL_KER_PROT( ch, opname ) \ \ -BLIS_INLINE err_t PASTEMAC(ch,opname) \ +err_t PASTEMAC(ch,opname) \ ( \ obj_t* AlphaObj, \ obj_t* a, \ diff --git a/frame/3/bli_l3_sup_packm_var.c b/frame/3/bli_l3_sup_packm_var.c index 70a5148128..052892a601 100644 --- a/frame/3/bli_l3_sup_packm_var.c +++ b/frame/3/bli_l3_sup_packm_var.c @@ -275,29 +275,6 @@ bli_thread_barrier( thread ); \ bli_thread_barrier( thread ); \ } \ */ -/* - if ( bli_is_4mi_packed( schema ) ) { \ - printf( "packm_var2: is_p_use = %lu\n", is_p_use ); \ - if ( col_stored ) { \ - if ( 0 ) \ - PASTEMAC(chr,fprintm)( stdout, "packm_var2: a_r", *m_panel_use, *n_panel_use, \ - ( ctype_r* )c_use, 2*rs_c, 2*cs_c, "%4.1f", "" ); \ - PASTEMAC(chr,fprintm)( stdout, "packm_var2: ap_r", *m_panel_max, *n_panel_max, \ - ( ctype_r* )p_use, rs_p, cs_p, "%4.1f", "" ); \ - PASTEMAC(chr,fprintm)( stdout, "packm_var2: ap_i", *m_panel_max, *n_panel_max, \ - ( ctype_r* )p_use + is_p_use, rs_p, cs_p, "%4.1f", "" ); \ - } \ - if ( row_stored ) { \ - if ( 0 ) \ - PASTEMAC(chr,fprintm)( stdout, "packm_var2: b_r", *m_panel_use, *n_panel_use, \ - ( ctype_r* )c_use, 2*rs_c, 2*cs_c, "%4.1f", "" ); \ - PASTEMAC(chr,fprintm)( stdout, "packm_var2: bp_r", *m_panel_max, *n_panel_max, \ - ( ctype_r* )p_use, rs_p, cs_p, "%4.1f", "" ); \ - PASTEMAC(chr,fprintm)( stdout, "packm_var2: bp_i", *m_panel_max, *n_panel_max, \ - ( ctype_r* )p_use + is_p_use, rs_p, cs_p, "%4.1f", "" ); \ - } \ - } \ -*/ /* PASTEMAC(chr,fprintm)( stdout, "packm_var2: bp_rpi", *m_panel_max, *n_panel_max, \ ( ctype_r* )p_use, rs_p, cs_p, "%4.1f", "" ); \ diff --git a/frame/3/bli_l3_tapi.c b/frame/3/bli_l3_tapi.c index 5cd67f968c..0c0644609d 100644 --- a/frame/3/bli_l3_tapi.c +++ b/frame/3/bli_l3_tapi.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2021, The University of Texas at Austin Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without @@ -33,18 +33,16 @@ */ -// Guard the function definitions so that they are only compiled when -// #included from files that define the typed API macros. -#ifdef BLIS_ENABLE_TAPI +#include "blis.h" // -// Define BLAS-like interfaces with typed operands. +// Define BLAS-like interfaces with typed operands (basic). // #undef GENTFUNC #define GENTFUNC( ctype, ch, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +void PASTEMAC(ch,opname) \ ( \ trans_t transa, \ trans_t transb, \ @@ -56,56 +54,70 @@ void PASTEMAC2(ch,opname,EX_SUF) \ ctype* b, inc_t rs_b, inc_t cs_b, \ ctype* beta, \ ctype* c, inc_t rs_c, inc_t cs_c \ - BLIS_TAPI_EX_PARAMS \ ) \ { \ - bli_init_once(); \ -\ - BLIS_TAPI_EX_DECLS \ -\ - const num_t dt = PASTEMAC(ch,type); \ -\ - obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ - obj_t ao = BLIS_OBJECT_INITIALIZER; \ - obj_t bo = BLIS_OBJECT_INITIALIZER; \ - obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \ - obj_t co = BLIS_OBJECT_INITIALIZER; \ -\ - dim_t m_a, n_a; \ - dim_t m_b, n_b; \ -\ - bli_set_dims_with_trans( transa, m, k, &m_a, &n_a ); \ - bli_set_dims_with_trans( transb, k, n, &m_b, &n_b ); \ -\ - bli_obj_init_finish_1x1( dt, alpha, &alphao ); \ - bli_obj_init_finish_1x1( dt, beta, &betao ); \ -\ - bli_obj_init_finish( dt, m_a, n_a, a, rs_a, cs_a, &ao ); \ - bli_obj_init_finish( dt, m_b, n_b, b, rs_b, cs_b, &bo ); \ - bli_obj_init_finish( dt, m, n, c, rs_c, cs_c, &co ); \ -\ - bli_obj_set_conjtrans( transa, &ao ); \ - bli_obj_set_conjtrans( transb, &bo ); \ -\ - PASTEMAC(opname,BLIS_OAPI_EX_SUF) \ + /* Invoke the expert interface and request default cntx_t and rntm_t + objects. */ \ + PASTEMAC2(ch,opname,BLIS_TAPI_EX_SUF) \ ( \ - &alphao, \ - &ao, \ - &bo, \ - &betao, \ - &co, \ - cntx, \ - rntm \ + transa, \ + transb, \ + m, n, k, \ + alpha, \ + a, rs_a, cs_a, \ + b, rs_b, cs_b, \ + beta, \ + c, rs_c, cs_c, \ + NULL, \ + NULL \ ); \ } INSERT_GENTFUNC_BASIC0( gemm ) + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTEMAC(ch,opname) \ + ( \ + uplo_t uploc, \ + trans_t transa, \ + trans_t transb, \ + dim_t m, \ + dim_t k, \ + ctype* alpha, \ + ctype* a, inc_t rs_a, inc_t cs_a, \ + ctype* b, inc_t rs_b, inc_t cs_b, \ + ctype* beta, \ + ctype* c, inc_t rs_c, inc_t cs_c \ + ) \ +{ \ + /* Invoke the expert interface and request default cntx_t and rntm_t + objects. */ \ + PASTEMAC2(ch,opname,BLIS_TAPI_EX_SUF) \ + ( \ + uploc, \ + transa, \ + transb, \ + m, k, \ + alpha, \ + a, rs_a, cs_a, \ + b, rs_b, cs_b, \ + beta, \ + c, rs_c, cs_c, \ + NULL, \ + NULL \ + ); \ +} + INSERT_GENTFUNC_BASIC0( gemmt ) + #undef GENTFUNC #define GENTFUNC( ctype, ch, opname, struca ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +void PASTEMAC(ch,opname) \ ( \ side_t side, \ uplo_t uploa, \ @@ -118,50 +130,24 @@ void PASTEMAC2(ch,opname,EX_SUF) \ ctype* b, inc_t rs_b, inc_t cs_b, \ ctype* beta, \ ctype* c, inc_t rs_c, inc_t cs_c \ - BLIS_TAPI_EX_PARAMS \ ) \ { \ - bli_init_once(); \ -\ - BLIS_TAPI_EX_DECLS \ -\ - const num_t dt = PASTEMAC(ch,type); \ -\ - obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ - obj_t ao = BLIS_OBJECT_INITIALIZER; \ - obj_t bo = BLIS_OBJECT_INITIALIZER; \ - obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \ - obj_t co = BLIS_OBJECT_INITIALIZER; \ -\ - dim_t mn_a; \ - dim_t m_b, n_b; \ -\ - bli_set_dim_with_side( side, m, n, &mn_a ); \ - bli_set_dims_with_trans( transb, m, n, &m_b, &n_b ); \ -\ - bli_obj_init_finish_1x1( dt, alpha, &alphao ); \ - bli_obj_init_finish_1x1( dt, beta, &betao ); \ -\ - bli_obj_init_finish( dt, mn_a, mn_a, a, rs_a, cs_a, &ao ); \ - bli_obj_init_finish( dt, m_b, n_b, b, rs_b, cs_b, &bo ); \ - bli_obj_init_finish( dt, m, n, c, rs_c, cs_c, &co ); \ -\ - bli_obj_set_uplo( uploa, &ao ); \ - bli_obj_set_conj( conja, &ao ); \ - bli_obj_set_conjtrans( transb, &bo ); \ -\ - bli_obj_set_struc( struca, &ao ); \ -\ - PASTEMAC(opname,BLIS_OAPI_EX_SUF) \ + /* Invoke the expert interface and request default cntx_t and rntm_t + objects. */ \ + PASTEMAC2(ch,opname,BLIS_TAPI_EX_SUF) \ ( \ side, \ - &alphao, \ - &ao, \ - &bo, \ - &betao, \ - &co, \ - cntx, \ - rntm \ + uploa, \ + conja, \ + transb, \ + m, n, \ + alpha, \ + a, rs_a, cs_a, \ + b, rs_b, cs_b, \ + beta, \ + c, rs_c, cs_c, \ + NULL, \ + NULL \ ); \ } @@ -172,7 +158,7 @@ INSERT_GENTFUNC_BASIC( symm, BLIS_SYMMETRIC ) #undef GENTFUNCR #define GENTFUNCR( ctype, ctype_r, ch, chr, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +void PASTEMAC(ch,opname) \ ( \ uplo_t uploc, \ trans_t transa, \ @@ -182,44 +168,21 @@ void PASTEMAC2(ch,opname,EX_SUF) \ ctype* a, inc_t rs_a, inc_t cs_a, \ ctype_r* beta, \ ctype* c, inc_t rs_c, inc_t cs_c \ - BLIS_TAPI_EX_PARAMS \ ) \ { \ - bli_init_once(); \ -\ - BLIS_TAPI_EX_DECLS \ -\ - const num_t dt_r = PASTEMAC(chr,type); \ - const num_t dt = PASTEMAC(ch,type); \ -\ - obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ - obj_t ao = BLIS_OBJECT_INITIALIZER; \ - obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \ - obj_t co = BLIS_OBJECT_INITIALIZER; \ -\ - dim_t m_a, n_a; \ -\ - bli_set_dims_with_trans( transa, m, k, &m_a, &n_a ); \ -\ - bli_obj_init_finish_1x1( dt_r, alpha, &alphao ); \ - bli_obj_init_finish_1x1( dt_r, beta, &betao ); \ -\ - bli_obj_init_finish( dt, m_a, n_a, a, rs_a, cs_a, &ao ); \ - bli_obj_init_finish( dt, m, m, c, rs_c, cs_c, &co ); \ -\ - bli_obj_set_uplo( uploc, &co ); \ - bli_obj_set_conjtrans( transa, &ao ); \ -\ - bli_obj_set_struc( BLIS_HERMITIAN, &co ); \ -\ - PASTEMAC(opname,BLIS_OAPI_EX_SUF) \ + /* Invoke the expert interface and request default cntx_t and rntm_t + objects. */ \ + PASTEMAC2(ch,opname,BLIS_TAPI_EX_SUF) \ ( \ - &alphao, \ - &ao, \ - &betao, \ - &co, \ - cntx, \ - rntm \ + uploc, \ + transa, \ + m, k, \ + alpha, \ + a, rs_a, cs_a, \ + beta, \ + c, rs_c, cs_c, \ + NULL, \ + NULL \ ); \ } @@ -229,7 +192,7 @@ INSERT_GENTFUNCR_BASIC0( herk ) #undef GENTFUNCR #define GENTFUNCR( ctype, ctype_r, ch, chr, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +void PASTEMAC(ch,opname) \ ( \ uplo_t uploc, \ trans_t transa, \ @@ -241,50 +204,23 @@ void PASTEMAC2(ch,opname,EX_SUF) \ ctype* b, inc_t rs_b, inc_t cs_b, \ ctype_r* beta, \ ctype* c, inc_t rs_c, inc_t cs_c \ - BLIS_TAPI_EX_PARAMS \ ) \ { \ - bli_init_once(); \ -\ - BLIS_TAPI_EX_DECLS \ -\ - const num_t dt_r = PASTEMAC(chr,type); \ - const num_t dt = PASTEMAC(ch,type); \ -\ - obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ - obj_t ao = BLIS_OBJECT_INITIALIZER; \ - obj_t bo = BLIS_OBJECT_INITIALIZER; \ - obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \ - obj_t co = BLIS_OBJECT_INITIALIZER; \ -\ - dim_t m_a, n_a; \ - dim_t m_b, n_b; \ -\ - bli_set_dims_with_trans( transa, m, k, &m_a, &n_a ); \ - bli_set_dims_with_trans( transb, m, k, &m_b, &n_b ); \ -\ - bli_obj_init_finish_1x1( dt, alpha, &alphao ); \ - bli_obj_init_finish_1x1( dt_r, beta, &betao ); \ -\ - bli_obj_init_finish( dt, m_a, n_a, a, rs_a, cs_a, &ao ); \ - bli_obj_init_finish( dt, m_b, n_b, b, rs_b, cs_b, &bo ); \ - bli_obj_init_finish( dt, m, m, c, rs_c, cs_c, &co ); \ -\ - bli_obj_set_uplo( uploc, &co ); \ - bli_obj_set_conjtrans( transa, &ao ); \ - bli_obj_set_conjtrans( transb, &bo ); \ -\ - bli_obj_set_struc( BLIS_HERMITIAN, &co ); \ -\ - PASTEMAC(opname,BLIS_OAPI_EX_SUF) \ + /* Invoke the expert interface and request default cntx_t and rntm_t + objects. */ \ + PASTEMAC2(ch,opname,BLIS_TAPI_EX_SUF) \ ( \ - &alphao, \ - &ao, \ - &bo, \ - &betao, \ - &co, \ - cntx, \ - rntm \ + uploc, \ + transa, \ + transb, \ + m, k, \ + alpha, \ + a, rs_a, cs_a, \ + b, rs_b, cs_b, \ + beta, \ + c, rs_c, cs_c, \ + NULL, \ + NULL \ ); \ } @@ -294,7 +230,7 @@ INSERT_GENTFUNCR_BASIC0( her2k ) #undef GENTFUNC #define GENTFUNC( ctype, ch, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +void PASTEMAC(ch,opname) \ ( \ uplo_t uploc, \ trans_t transa, \ @@ -304,43 +240,21 @@ void PASTEMAC2(ch,opname,EX_SUF) \ ctype* a, inc_t rs_a, inc_t cs_a, \ ctype* beta, \ ctype* c, inc_t rs_c, inc_t cs_c \ - BLIS_TAPI_EX_PARAMS \ ) \ { \ - bli_init_once(); \ -\ - BLIS_TAPI_EX_DECLS \ -\ - const num_t dt = PASTEMAC(ch,type); \ -\ - obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ - obj_t ao = BLIS_OBJECT_INITIALIZER; \ - obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \ - obj_t co = BLIS_OBJECT_INITIALIZER; \ -\ - dim_t m_a, n_a; \ -\ - bli_set_dims_with_trans( transa, m, k, &m_a, &n_a ); \ -\ - bli_obj_init_finish_1x1( dt, alpha, &alphao ); \ - bli_obj_init_finish_1x1( dt, beta, &betao ); \ -\ - bli_obj_init_finish( dt, m_a, n_a, a, rs_a, cs_a, &ao ); \ - bli_obj_init_finish( dt, m, m, c, rs_c, cs_c, &co ); \ -\ - bli_obj_set_uplo( uploc, &co ); \ - bli_obj_set_conjtrans( transa, &ao ); \ -\ - bli_obj_set_struc( BLIS_SYMMETRIC, &co ); \ -\ - PASTEMAC(opname,BLIS_OAPI_EX_SUF) \ + /* Invoke the expert interface and request default cntx_t and rntm_t + objects. */ \ + PASTEMAC2(ch,opname,BLIS_TAPI_EX_SUF) \ ( \ - &alphao, \ - &ao, \ - &betao, \ - &co, \ - cntx, \ - rntm \ + uploc, \ + transa, \ + m, k, \ + alpha, \ + a, rs_a, cs_a, \ + beta, \ + c, rs_c, cs_c, \ + NULL, \ + NULL \ ); \ } @@ -350,7 +264,7 @@ INSERT_GENTFUNC_BASIC0( syrk ) #undef GENTFUNC #define GENTFUNC( ctype, ch, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +void PASTEMAC(ch,opname) \ ( \ uplo_t uploc, \ trans_t transa, \ @@ -362,49 +276,23 @@ void PASTEMAC2(ch,opname,EX_SUF) \ ctype* b, inc_t rs_b, inc_t cs_b, \ ctype* beta, \ ctype* c, inc_t rs_c, inc_t cs_c \ - BLIS_TAPI_EX_PARAMS \ ) \ { \ - bli_init_once(); \ -\ - BLIS_TAPI_EX_DECLS \ -\ - const num_t dt = PASTEMAC(ch,type); \ -\ - obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ - obj_t ao = BLIS_OBJECT_INITIALIZER; \ - obj_t bo = BLIS_OBJECT_INITIALIZER; \ - obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \ - obj_t co = BLIS_OBJECT_INITIALIZER; \ -\ - dim_t m_a, n_a; \ - dim_t m_b, n_b; \ -\ - bli_set_dims_with_trans( transa, m, k, &m_a, &n_a ); \ - bli_set_dims_with_trans( transb, m, k, &m_b, &n_b ); \ -\ - bli_obj_init_finish_1x1( dt, alpha, &alphao ); \ - bli_obj_init_finish_1x1( dt, beta, &betao ); \ -\ - bli_obj_init_finish( dt, m_a, n_a, a, rs_a, cs_a, &ao ); \ - bli_obj_init_finish( dt, m_b, n_b, b, rs_b, cs_b, &bo ); \ - bli_obj_init_finish( dt, m, m, c, rs_c, cs_c, &co ); \ -\ - bli_obj_set_uplo( uploc, &co ); \ - bli_obj_set_conjtrans( transa, &ao ); \ - bli_obj_set_conjtrans( transb, &bo ); \ -\ - bli_obj_set_struc( BLIS_SYMMETRIC, &co ); \ -\ - PASTEMAC(opname,BLIS_OAPI_EX_SUF) \ + /* Invoke the expert interface and request default cntx_t and rntm_t + objects. */ \ + PASTEMAC2(ch,opname,BLIS_TAPI_EX_SUF) \ ( \ - &alphao, \ - &ao, \ - &bo, \ - &betao, \ - &co, \ - cntx, \ - rntm \ + uploc, \ + transa, \ + transb, \ + m, k, \ + alpha, \ + a, rs_a, cs_a, \ + b, rs_b, cs_b, \ + beta, \ + c, rs_c, cs_c, \ + NULL, \ + NULL \ ); \ } @@ -414,7 +302,7 @@ INSERT_GENTFUNC_BASIC0( syr2k ) #undef GENTFUNC #define GENTFUNC( ctype, ch, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +void PASTEMAC(ch,opname) \ ( \ side_t side, \ uplo_t uploa, \ @@ -428,51 +316,25 @@ void PASTEMAC2(ch,opname,EX_SUF) \ ctype* b, inc_t rs_b, inc_t cs_b, \ ctype* beta, \ ctype* c, inc_t rs_c, inc_t cs_c \ - BLIS_TAPI_EX_PARAMS \ ) \ { \ - bli_init_once(); \ -\ - BLIS_TAPI_EX_DECLS \ -\ - const num_t dt = PASTEMAC(ch,type); \ -\ - obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ - obj_t ao = BLIS_OBJECT_INITIALIZER; \ - obj_t bo = BLIS_OBJECT_INITIALIZER; \ - obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \ - obj_t co = BLIS_OBJECT_INITIALIZER; \ -\ - dim_t mn_a; \ - dim_t m_b, n_b; \ -\ - bli_set_dim_with_side( side, m, n, &mn_a ); \ - bli_set_dims_with_trans( transb, m, n, &m_b, &n_b ); \ -\ - bli_obj_init_finish_1x1( dt, alpha, &alphao ); \ - bli_obj_init_finish_1x1( dt, beta, &betao ); \ -\ - bli_obj_init_finish( dt, mn_a, mn_a, a, rs_a, cs_a, &ao ); \ - bli_obj_init_finish( dt, m_b, n_b, b, rs_b, cs_b, &bo ); \ - bli_obj_init_finish( dt, m, n, c, rs_c, cs_c, &co ); \ -\ - bli_obj_set_uplo( uploa, &ao ); \ - bli_obj_set_diag( diaga, &ao ); \ - bli_obj_set_conjtrans( transa, &ao ); \ - bli_obj_set_conjtrans( transb, &bo ); \ -\ - bli_obj_set_struc( BLIS_TRIANGULAR, &ao ); \ -\ - PASTEMAC(opname,BLIS_OAPI_EX_SUF) \ + /* Invoke the expert interface and request default cntx_t and rntm_t + objects. */ \ + PASTEMAC2(ch,opname,BLIS_TAPI_EX_SUF) \ ( \ side, \ - &alphao, \ - &ao, \ - &bo, \ - &betao, \ - &co, \ - cntx, \ - rntm \ + uploa, \ + transa, \ + diaga, \ + transb, \ + m, n, \ + alpha, \ + a, rs_a, cs_a, \ + b, rs_b, cs_b, \ + beta, \ + c, rs_c, cs_c, \ + NULL, \ + NULL \ ); \ } @@ -482,7 +344,7 @@ INSERT_GENTFUNC_BASIC0( trmm3 ) #undef GENTFUNC #define GENTFUNC( ctype, ch, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +void PASTEMAC(ch,opname) \ ( \ side_t side, \ uplo_t uploa, \ @@ -493,48 +355,25 @@ void PASTEMAC2(ch,opname,EX_SUF) \ ctype* alpha, \ ctype* a, inc_t rs_a, inc_t cs_a, \ ctype* b, inc_t rs_b, inc_t cs_b \ - BLIS_TAPI_EX_PARAMS \ ) \ { \ - bli_init_once(); \ -\ - BLIS_TAPI_EX_DECLS \ -\ - const num_t dt = PASTEMAC(ch,type); \ -\ - obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ - obj_t ao = BLIS_OBJECT_INITIALIZER; \ - obj_t bo = BLIS_OBJECT_INITIALIZER; \ -\ - dim_t mn_a; \ -\ - bli_set_dim_with_side( side, m, n, &mn_a ); \ -\ - bli_obj_init_finish_1x1( dt, alpha, &alphao ); \ -\ - bli_obj_init_finish( dt, mn_a, mn_a, a, rs_a, cs_a, &ao ); \ - bli_obj_init_finish( dt, m, n, b, rs_b, cs_b, &bo ); \ -\ - bli_obj_set_uplo( uploa, &ao ); \ - bli_obj_set_diag( diaga, &ao ); \ - bli_obj_set_conjtrans( transa, &ao ); \ -\ - bli_obj_set_struc( BLIS_TRIANGULAR, &ao ); \ -\ - PASTEMAC(opname,BLIS_OAPI_EX_SUF) \ + /* Invoke the expert interface and request default cntx_t and rntm_t + objects. */ \ + PASTEMAC2(ch,opname,BLIS_TAPI_EX_SUF) \ ( \ side, \ - &alphao, \ - &ao, \ - &bo, \ - cntx, \ - rntm \ + uploa, \ + transa, \ + diaga, \ + m, n, \ + alpha, \ + a, rs_a, cs_a, \ + b, rs_b, cs_b, \ + NULL, \ + NULL \ ); \ } INSERT_GENTFUNC_BASIC0( trmm ) INSERT_GENTFUNC_BASIC0( trsm ) - -#endif - diff --git a/frame/3/bli_l3_tapi.h b/frame/3/bli_l3_tapi.h index 77a6bd25c0..a101f6c6a1 100644 --- a/frame/3/bli_l3_tapi.h +++ b/frame/3/bli_l3_tapi.h @@ -35,13 +35,13 @@ // -// Prototype BLAS-like interfaces with typed operands. +// Prototype BLAS-like interfaces with typed operands (basic). // #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC(ch,opname) \ ( \ trans_t transa, \ trans_t transb, \ @@ -53,16 +53,14 @@ BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ctype* b, inc_t rs_b, inc_t cs_b, \ ctype* beta, \ ctype* c, inc_t rs_c, inc_t cs_c \ - BLIS_TAPI_EX_PARAMS \ ); INSERT_GENTPROT_BASIC0( gemm ) -INSERT_GENTPROT_BASIC0( gemmt ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC(ch,opname) \ ( \ side_t side, \ uplo_t uploa, \ @@ -75,7 +73,6 @@ BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ctype* b, inc_t rs_b, inc_t cs_b, \ ctype* beta, \ ctype* c, inc_t rs_c, inc_t cs_c \ - BLIS_TAPI_EX_PARAMS \ ); INSERT_GENTPROT_BASIC0( hemm ) @@ -85,7 +82,7 @@ INSERT_GENTPROT_BASIC0( symm ) #undef GENTPROTR #define GENTPROTR( ctype, ctype_r, ch, chr, opname ) \ \ -BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC(ch,opname) \ ( \ uplo_t uploc, \ trans_t transa, \ @@ -95,7 +92,6 @@ BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ctype* a, inc_t rs_a, inc_t cs_a, \ ctype_r* beta, \ ctype* c, inc_t rs_c, inc_t cs_c \ - BLIS_TAPI_EX_PARAMS \ ); INSERT_GENTPROTR_BASIC0( herk ) @@ -104,7 +100,7 @@ INSERT_GENTPROTR_BASIC0( herk ) #undef GENTPROTR #define GENTPROTR( ctype, ctype_r, ch, chr, opname ) \ \ -BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC(ch,opname) \ ( \ uplo_t uploc, \ trans_t transa, \ @@ -116,7 +112,6 @@ BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ctype* b, inc_t rs_b, inc_t cs_b, \ ctype_r* beta, \ ctype* c, inc_t rs_c, inc_t cs_c \ - BLIS_TAPI_EX_PARAMS \ ); INSERT_GENTPROTR_BASIC0( her2k ) @@ -125,7 +120,7 @@ INSERT_GENTPROTR_BASIC0( her2k ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC(ch,opname) \ ( \ uplo_t uploc, \ trans_t transa, \ @@ -135,7 +130,6 @@ BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ctype* a, inc_t rs_a, inc_t cs_a, \ ctype* beta, \ ctype* c, inc_t rs_c, inc_t cs_c \ - BLIS_TAPI_EX_PARAMS \ ); INSERT_GENTPROT_BASIC0( syrk ) @@ -144,7 +138,7 @@ INSERT_GENTPROT_BASIC0( syrk ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC(ch,opname) \ ( \ uplo_t uploc, \ trans_t transa, \ @@ -156,16 +150,16 @@ BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ctype* b, inc_t rs_b, inc_t cs_b, \ ctype* beta, \ ctype* c, inc_t rs_c, inc_t cs_c \ - BLIS_TAPI_EX_PARAMS \ ); +INSERT_GENTPROT_BASIC0( gemmt ) INSERT_GENTPROT_BASIC0( syr2k ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC(ch,opname) \ ( \ side_t side, \ uplo_t uploa, \ @@ -179,7 +173,6 @@ BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ctype* b, inc_t rs_b, inc_t cs_b, \ ctype* beta, \ ctype* c, inc_t rs_c, inc_t cs_c \ - BLIS_TAPI_EX_PARAMS \ ); INSERT_GENTPROT_BASIC0( trmm3 ) @@ -188,7 +181,7 @@ INSERT_GENTPROT_BASIC0( trmm3 ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC(ch,opname) \ ( \ side_t side, \ uplo_t uploa, \ @@ -199,7 +192,6 @@ BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ctype* alpha, \ ctype* a, inc_t rs_a, inc_t cs_a, \ ctype* b, inc_t rs_b, inc_t cs_b \ - BLIS_TAPI_EX_PARAMS \ ); INSERT_GENTPROT_BASIC0( trmm ) diff --git a/frame/3/bli_l3_tapi_ex.c b/frame/3/bli_l3_tapi_ex.c index 609bf8e78d..8c7682c76c 100644 --- a/frame/3/bli_l3_tapi_ex.c +++ b/frame/3/bli_l3_tapi_ex.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -34,13 +35,553 @@ #include "blis.h" -// Include cpp macros that instantiate the API definition templates as -// having expert parameters. -#include "bli_tapi_ex.h" +// +// Define BLAS-like interfaces with typed operands (expert). +// -// Define the macro protecting the typed API definitions. -#define BLIS_ENABLE_TAPI +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTEMAC2(ch,opname,BLIS_OAPI_EX_SUF) \ + ( \ + trans_t transa, \ + trans_t transb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + ctype* alpha, \ + ctype* a, inc_t rs_a, inc_t cs_a, \ + ctype* b, inc_t rs_b, inc_t cs_b, \ + ctype* beta, \ + ctype* c, inc_t rs_c, inc_t cs_c, \ + cntx_t* cntx, \ + rntm_t* rntm \ + ) \ +{ \ + bli_init_once(); \ +\ + const num_t dt = PASTEMAC(ch,type); \ +\ + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t ao = BLIS_OBJECT_INITIALIZER; \ + obj_t bo = BLIS_OBJECT_INITIALIZER; \ + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t co = BLIS_OBJECT_INITIALIZER; \ +\ + dim_t m_a, n_a; \ + dim_t m_b, n_b; \ +\ + bli_set_dims_with_trans( transa, m, k, &m_a, &n_a ); \ + bli_set_dims_with_trans( transb, k, n, &m_b, &n_b ); \ +\ + bli_obj_init_finish_1x1( dt, alpha, &alphao ); \ + bli_obj_init_finish_1x1( dt, beta, &betao ); \ +\ + bli_obj_init_finish( dt, m_a, n_a, a, rs_a, cs_a, &ao ); \ + bli_obj_init_finish( dt, m_b, n_b, b, rs_b, cs_b, &bo ); \ + bli_obj_init_finish( dt, m, n, c, rs_c, cs_c, &co ); \ +\ + bli_obj_set_conjtrans( transa, &ao ); \ + bli_obj_set_conjtrans( transb, &bo ); \ +\ + PASTEMAC(opname,BLIS_OAPI_EX_SUF) \ + ( \ + &alphao, \ + &ao, \ + &bo, \ + &betao, \ + &co, \ + cntx, \ + rntm \ + ); \ +} -// Include the typed API definitions here. -#include "bli_l3_tapi.c" +INSERT_GENTFUNC_BASIC0( gemm ) + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname, struca ) \ +\ +void PASTEMAC2(ch,opname,BLIS_OAPI_EX_SUF) \ + ( \ + side_t side, \ + uplo_t uploa, \ + conj_t conja, \ + trans_t transb, \ + dim_t m, \ + dim_t n, \ + ctype* alpha, \ + ctype* a, inc_t rs_a, inc_t cs_a, \ + ctype* b, inc_t rs_b, inc_t cs_b, \ + ctype* beta, \ + ctype* c, inc_t rs_c, inc_t cs_c, \ + cntx_t* cntx, \ + rntm_t* rntm \ + ) \ +{ \ + bli_init_once(); \ +\ + const num_t dt = PASTEMAC(ch,type); \ +\ + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t ao = BLIS_OBJECT_INITIALIZER; \ + obj_t bo = BLIS_OBJECT_INITIALIZER; \ + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t co = BLIS_OBJECT_INITIALIZER; \ +\ + dim_t mn_a; \ + dim_t m_b, n_b; \ +\ + bli_set_dim_with_side( side, m, n, &mn_a ); \ + bli_set_dims_with_trans( transb, m, n, &m_b, &n_b ); \ +\ + bli_obj_init_finish_1x1( dt, alpha, &alphao ); \ + bli_obj_init_finish_1x1( dt, beta, &betao ); \ +\ + bli_obj_init_finish( dt, mn_a, mn_a, a, rs_a, cs_a, &ao ); \ + bli_obj_init_finish( dt, m_b, n_b, b, rs_b, cs_b, &bo ); \ + bli_obj_init_finish( dt, m, n, c, rs_c, cs_c, &co ); \ +\ + bli_obj_set_uplo( uploa, &ao ); \ + bli_obj_set_conj( conja, &ao ); \ + bli_obj_set_conjtrans( transb, &bo ); \ +\ + bli_obj_set_struc( struca, &ao ); \ +\ + PASTEMAC(opname,BLIS_OAPI_EX_SUF) \ + ( \ + side, \ + &alphao, \ + &ao, \ + &bo, \ + &betao, \ + &co, \ + cntx, \ + rntm \ + ); \ +} + +INSERT_GENTFUNC_BASIC( hemm, BLIS_HERMITIAN ) +INSERT_GENTFUNC_BASIC( symm, BLIS_SYMMETRIC ) + + +#undef GENTFUNCR +#define GENTFUNCR( ctype, ctype_r, ch, chr, opname ) \ +\ +void PASTEMAC2(ch,opname,BLIS_OAPI_EX_SUF) \ + ( \ + uplo_t uploc, \ + trans_t transa, \ + dim_t m, \ + dim_t k, \ + ctype_r* alpha, \ + ctype* a, inc_t rs_a, inc_t cs_a, \ + ctype_r* beta, \ + ctype* c, inc_t rs_c, inc_t cs_c, \ + cntx_t* cntx, \ + rntm_t* rntm \ + ) \ +{ \ + bli_init_once(); \ +\ + const num_t dt_r = PASTEMAC(chr,type); \ + const num_t dt = PASTEMAC(ch,type); \ +\ + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t ao = BLIS_OBJECT_INITIALIZER; \ + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t co = BLIS_OBJECT_INITIALIZER; \ +\ + dim_t m_a, n_a; \ +\ + bli_set_dims_with_trans( transa, m, k, &m_a, &n_a ); \ +\ + bli_obj_init_finish_1x1( dt_r, alpha, &alphao ); \ + bli_obj_init_finish_1x1( dt_r, beta, &betao ); \ +\ + bli_obj_init_finish( dt, m_a, n_a, a, rs_a, cs_a, &ao ); \ + bli_obj_init_finish( dt, m, m, c, rs_c, cs_c, &co ); \ +\ + bli_obj_set_uplo( uploc, &co ); \ + bli_obj_set_conjtrans( transa, &ao ); \ +\ + bli_obj_set_struc( BLIS_HERMITIAN, &co ); \ +\ + PASTEMAC(opname,BLIS_OAPI_EX_SUF) \ + ( \ + &alphao, \ + &ao, \ + &betao, \ + &co, \ + cntx, \ + rntm \ + ); \ +} + +INSERT_GENTFUNCR_BASIC0( herk ) + + +#undef GENTFUNCR +#define GENTFUNCR( ctype, ctype_r, ch, chr, opname ) \ +\ +void PASTEMAC2(ch,opname,BLIS_OAPI_EX_SUF) \ + ( \ + uplo_t uploc, \ + trans_t transa, \ + trans_t transb, \ + dim_t m, \ + dim_t k, \ + ctype* alpha, \ + ctype* a, inc_t rs_a, inc_t cs_a, \ + ctype* b, inc_t rs_b, inc_t cs_b, \ + ctype_r* beta, \ + ctype* c, inc_t rs_c, inc_t cs_c, \ + cntx_t* cntx, \ + rntm_t* rntm \ + ) \ +{ \ + bli_init_once(); \ +\ + const num_t dt_r = PASTEMAC(chr,type); \ + const num_t dt = PASTEMAC(ch,type); \ +\ + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t ao = BLIS_OBJECT_INITIALIZER; \ + obj_t bo = BLIS_OBJECT_INITIALIZER; \ + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t co = BLIS_OBJECT_INITIALIZER; \ +\ + dim_t m_a, n_a; \ + dim_t m_b, n_b; \ +\ + bli_set_dims_with_trans( transa, m, k, &m_a, &n_a ); \ + bli_set_dims_with_trans( transb, m, k, &m_b, &n_b ); \ +\ + bli_obj_init_finish_1x1( dt, alpha, &alphao ); \ + bli_obj_init_finish_1x1( dt_r, beta, &betao ); \ +\ + bli_obj_init_finish( dt, m_a, n_a, a, rs_a, cs_a, &ao ); \ + bli_obj_init_finish( dt, m_b, n_b, b, rs_b, cs_b, &bo ); \ + bli_obj_init_finish( dt, m, m, c, rs_c, cs_c, &co ); \ +\ + bli_obj_set_uplo( uploc, &co ); \ + bli_obj_set_conjtrans( transa, &ao ); \ + bli_obj_set_conjtrans( transb, &bo ); \ +\ + bli_obj_set_struc( BLIS_HERMITIAN, &co ); \ +\ + PASTEMAC(opname,BLIS_OAPI_EX_SUF) \ + ( \ + &alphao, \ + &ao, \ + &bo, \ + &betao, \ + &co, \ + cntx, \ + rntm \ + ); \ +} + +INSERT_GENTFUNCR_BASIC0( her2k ) + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTEMAC2(ch,opname,BLIS_OAPI_EX_SUF) \ + ( \ + uplo_t uploc, \ + trans_t transa, \ + dim_t m, \ + dim_t k, \ + ctype* alpha, \ + ctype* a, inc_t rs_a, inc_t cs_a, \ + ctype* beta, \ + ctype* c, inc_t rs_c, inc_t cs_c, \ + cntx_t* cntx, \ + rntm_t* rntm \ + ) \ +{ \ + bli_init_once(); \ +\ + const num_t dt = PASTEMAC(ch,type); \ +\ + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t ao = BLIS_OBJECT_INITIALIZER; \ + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t co = BLIS_OBJECT_INITIALIZER; \ +\ + dim_t m_a, n_a; \ +\ + bli_set_dims_with_trans( transa, m, k, &m_a, &n_a ); \ +\ + bli_obj_init_finish_1x1( dt, alpha, &alphao ); \ + bli_obj_init_finish_1x1( dt, beta, &betao ); \ +\ + bli_obj_init_finish( dt, m_a, n_a, a, rs_a, cs_a, &ao ); \ + bli_obj_init_finish( dt, m, m, c, rs_c, cs_c, &co ); \ +\ + bli_obj_set_uplo( uploc, &co ); \ + bli_obj_set_conjtrans( transa, &ao ); \ +\ + bli_obj_set_struc( BLIS_SYMMETRIC, &co ); \ +\ + PASTEMAC(opname,BLIS_OAPI_EX_SUF) \ + ( \ + &alphao, \ + &ao, \ + &betao, \ + &co, \ + cntx, \ + rntm \ + ); \ +} + +INSERT_GENTFUNC_BASIC0( syrk ) + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTEMAC2(ch,opname,BLIS_OAPI_EX_SUF) \ + ( \ + uplo_t uploc, \ + trans_t transa, \ + trans_t transb, \ + dim_t m, \ + dim_t k, \ + ctype* alpha, \ + ctype* a, inc_t rs_a, inc_t cs_a, \ + ctype* b, inc_t rs_b, inc_t cs_b, \ + ctype* beta, \ + ctype* c, inc_t rs_c, inc_t cs_c, \ + cntx_t* cntx, \ + rntm_t* rntm \ + ) \ +{ \ + bli_init_once(); \ +\ + const num_t dt = PASTEMAC(ch,type); \ +\ + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t ao = BLIS_OBJECT_INITIALIZER; \ + obj_t bo = BLIS_OBJECT_INITIALIZER; \ + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t co = BLIS_OBJECT_INITIALIZER; \ +\ + dim_t m_a, n_a; \ + dim_t m_b, n_b; \ +\ + bli_set_dims_with_trans( transa, m, k, &m_a, &n_a ); \ + bli_set_dims_with_trans( transb, m, k, &m_b, &n_b ); \ +\ + bli_obj_init_finish_1x1( dt, alpha, &alphao ); \ + bli_obj_init_finish_1x1( dt, beta, &betao ); \ +\ + bli_obj_init_finish( dt, m_a, n_a, a, rs_a, cs_a, &ao ); \ + bli_obj_init_finish( dt, m_b, n_b, b, rs_b, cs_b, &bo ); \ + bli_obj_init_finish( dt, m, m, c, rs_c, cs_c, &co ); \ +\ + bli_obj_set_uplo( uploc, &co ); \ + bli_obj_set_conjtrans( transa, &ao ); \ + bli_obj_set_conjtrans( transb, &bo ); \ +\ + bli_obj_set_struc( BLIS_SYMMETRIC, &co ); \ +\ + PASTEMAC(opname,BLIS_OAPI_EX_SUF) \ + ( \ + &alphao, \ + &ao, \ + &bo, \ + &betao, \ + &co, \ + cntx, \ + rntm \ + ); \ +} + +INSERT_GENTFUNC_BASIC0( syr2k ) + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTEMAC2(ch,opname,BLIS_OAPI_EX_SUF) \ + ( \ + uplo_t uploc, \ + trans_t transa, \ + trans_t transb, \ + dim_t m, \ + dim_t k, \ + ctype* alpha, \ + ctype* a, inc_t rs_a, inc_t cs_a, \ + ctype* b, inc_t rs_b, inc_t cs_b, \ + ctype* beta, \ + ctype* c, inc_t rs_c, inc_t cs_c, \ + cntx_t* cntx, \ + rntm_t* rntm \ + ) \ +{ \ + bli_init_once(); \ +\ + const num_t dt = PASTEMAC(ch,type); \ +\ + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t ao = BLIS_OBJECT_INITIALIZER; \ + obj_t bo = BLIS_OBJECT_INITIALIZER; \ + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t co = BLIS_OBJECT_INITIALIZER; \ +\ + dim_t m_a, n_a; \ + dim_t m_b, n_b; \ +\ + bli_set_dims_with_trans( transa, m, k, &m_a, &n_a ); \ + bli_set_dims_with_trans( transb, k, m, &m_b, &n_b ); \ +\ + bli_obj_init_finish_1x1( dt, alpha, &alphao ); \ + bli_obj_init_finish_1x1( dt, beta, &betao ); \ +\ + bli_obj_init_finish( dt, m_a, n_a, a, rs_a, cs_a, &ao ); \ + bli_obj_init_finish( dt, m_b, n_b, b, rs_b, cs_b, &bo ); \ + bli_obj_init_finish( dt, m, m, c, rs_c, cs_c, &co ); \ +\ + bli_obj_set_uplo( uploc, &co ); \ + bli_obj_set_conjtrans( transa, &ao ); \ + bli_obj_set_conjtrans( transb, &bo ); \ +\ + PASTEMAC(opname,BLIS_OAPI_EX_SUF) \ + ( \ + &alphao, \ + &ao, \ + &bo, \ + &betao, \ + &co, \ + cntx, \ + rntm \ + ); \ +} + +INSERT_GENTFUNC_BASIC0( gemmt ) + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTEMAC2(ch,opname,BLIS_OAPI_EX_SUF) \ + ( \ + side_t side, \ + uplo_t uploa, \ + trans_t transa, \ + diag_t diaga, \ + trans_t transb, \ + dim_t m, \ + dim_t n, \ + ctype* alpha, \ + ctype* a, inc_t rs_a, inc_t cs_a, \ + ctype* b, inc_t rs_b, inc_t cs_b, \ + ctype* beta, \ + ctype* c, inc_t rs_c, inc_t cs_c, \ + cntx_t* cntx, \ + rntm_t* rntm \ + ) \ +{ \ + bli_init_once(); \ +\ + const num_t dt = PASTEMAC(ch,type); \ +\ + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t ao = BLIS_OBJECT_INITIALIZER; \ + obj_t bo = BLIS_OBJECT_INITIALIZER; \ + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t co = BLIS_OBJECT_INITIALIZER; \ +\ + dim_t mn_a; \ + dim_t m_b, n_b; \ +\ + bli_set_dim_with_side( side, m, n, &mn_a ); \ + bli_set_dims_with_trans( transb, m, n, &m_b, &n_b ); \ +\ + bli_obj_init_finish_1x1( dt, alpha, &alphao ); \ + bli_obj_init_finish_1x1( dt, beta, &betao ); \ +\ + bli_obj_init_finish( dt, mn_a, mn_a, a, rs_a, cs_a, &ao ); \ + bli_obj_init_finish( dt, m_b, n_b, b, rs_b, cs_b, &bo ); \ + bli_obj_init_finish( dt, m, n, c, rs_c, cs_c, &co ); \ +\ + bli_obj_set_uplo( uploa, &ao ); \ + bli_obj_set_diag( diaga, &ao ); \ + bli_obj_set_conjtrans( transa, &ao ); \ + bli_obj_set_conjtrans( transb, &bo ); \ +\ + bli_obj_set_struc( BLIS_TRIANGULAR, &ao ); \ +\ + PASTEMAC(opname,BLIS_OAPI_EX_SUF) \ + ( \ + side, \ + &alphao, \ + &ao, \ + &bo, \ + &betao, \ + &co, \ + cntx, \ + rntm \ + ); \ +} + +INSERT_GENTFUNC_BASIC0( trmm3 ) + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTEMAC2(ch,opname,BLIS_OAPI_EX_SUF) \ + ( \ + side_t side, \ + uplo_t uploa, \ + trans_t transa, \ + diag_t diaga, \ + dim_t m, \ + dim_t n, \ + ctype* alpha, \ + ctype* a, inc_t rs_a, inc_t cs_a, \ + ctype* b, inc_t rs_b, inc_t cs_b, \ + cntx_t* cntx, \ + rntm_t* rntm \ + ) \ +{ \ + bli_init_once(); \ +\ + const num_t dt = PASTEMAC(ch,type); \ +\ + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t ao = BLIS_OBJECT_INITIALIZER; \ + obj_t bo = BLIS_OBJECT_INITIALIZER; \ +\ + dim_t mn_a; \ +\ + bli_set_dim_with_side( side, m, n, &mn_a ); \ +\ + bli_obj_init_finish_1x1( dt, alpha, &alphao ); \ +\ + bli_obj_init_finish( dt, mn_a, mn_a, a, rs_a, cs_a, &ao ); \ + bli_obj_init_finish( dt, m, n, b, rs_b, cs_b, &bo ); \ +\ + bli_obj_set_uplo( uploa, &ao ); \ + bli_obj_set_diag( diaga, &ao ); \ + bli_obj_set_conjtrans( transa, &ao ); \ +\ + bli_obj_set_struc( BLIS_TRIANGULAR, &ao ); \ +\ + PASTEMAC(opname,BLIS_OAPI_EX_SUF) \ + ( \ + side, \ + &alphao, \ + &ao, \ + &bo, \ + cntx, \ + rntm \ + ); \ +} + +INSERT_GENTFUNC_BASIC0( trmm ) +INSERT_GENTFUNC_BASIC0( trsm ) diff --git a/frame/ind/tapi/bli_l3_ind_tapi.h b/frame/3/bli_l3_tapi_ex.h similarity index 63% rename from frame/ind/tapi/bli_l3_ind_tapi.h rename to frame/3/bli_l3_tapi_ex.h index 49ff6a8739..702c6c1a48 100644 --- a/frame/ind/tapi/bli_l3_ind_tapi.h +++ b/frame/3/bli_l3_tapi_ex.h @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -33,10 +34,14 @@ */ +// +// Prototype BLAS-like interfaces with typed operands (expert). +// + #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC(ch,opname) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,BLIS_TAPI_EX_SUF) \ ( \ trans_t transa, \ trans_t transb, \ @@ -52,18 +57,12 @@ void PASTEMAC(ch,opname) \ rntm_t* rntm \ ); -INSERT_GENTPROT_BASIC0( gemm3mh ) -INSERT_GENTPROT_BASIC0( gemm3m1 ) -INSERT_GENTPROT_BASIC0( gemm4mh ) -INSERT_GENTPROT_BASIC0( gemm4mb ) -INSERT_GENTPROT_BASIC0( gemm4m1 ) -INSERT_GENTPROT_BASIC0( gemm1m ) - +INSERT_GENTPROT_BASIC0( gemm ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC(ch,opname) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,BLIS_TAPI_EX_SUF) \ ( \ side_t side, \ uplo_t uploa, \ @@ -80,144 +79,99 @@ void PASTEMAC(ch,opname) \ rntm_t* rntm \ ); -INSERT_GENTPROT_BASIC0( hemm3mh ) -INSERT_GENTPROT_BASIC0( hemm3m1 ) -INSERT_GENTPROT_BASIC0( hemm4mh ) -INSERT_GENTPROT_BASIC0( hemm4m1 ) -INSERT_GENTPROT_BASIC0( hemm1m ) +INSERT_GENTPROT_BASIC0( hemm ) +INSERT_GENTPROT_BASIC0( symm ) #undef GENTPROTR #define GENTPROTR( ctype, ctype_r, ch, chr, opname ) \ \ -void PASTEMAC(ch,opname) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,BLIS_TAPI_EX_SUF) \ ( \ uplo_t uploc, \ trans_t transa, \ - trans_t transb, \ dim_t m, \ dim_t k, \ - ctype* alpha, \ + ctype_r* alpha, \ ctype* a, inc_t rs_a, inc_t cs_a, \ - ctype* b, inc_t rs_b, inc_t cs_b, \ ctype_r* beta, \ ctype* c, inc_t rs_c, inc_t cs_c, \ cntx_t* cntx, \ - rntm_t* rntmx \ + rntm_t* rntm \ ); -INSERT_GENTPROTR_BASIC0( her2k3mh ) -INSERT_GENTPROTR_BASIC0( her2k3m1 ) -INSERT_GENTPROTR_BASIC0( her2k4mh ) -INSERT_GENTPROTR_BASIC0( her2k4m1 ) -INSERT_GENTPROTR_BASIC0( her2k1m ) +INSERT_GENTPROTR_BASIC0( herk ) #undef GENTPROTR #define GENTPROTR( ctype, ctype_r, ch, chr, opname ) \ \ -void PASTEMAC(ch,opname) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,BLIS_TAPI_EX_SUF) \ ( \ uplo_t uploc, \ trans_t transa, \ + trans_t transb, \ dim_t m, \ dim_t k, \ - ctype_r* alpha, \ + ctype* alpha, \ ctype* a, inc_t rs_a, inc_t cs_a, \ + ctype* b, inc_t rs_b, inc_t cs_b, \ ctype_r* beta, \ ctype* c, inc_t rs_c, inc_t cs_c, \ - cntx_t* cntx, \ - rntm_t* rntmx \ - ); - -INSERT_GENTPROTR_BASIC0( herk3mh ) -INSERT_GENTPROTR_BASIC0( herk3m1 ) -INSERT_GENTPROTR_BASIC0( herk4mh ) -INSERT_GENTPROTR_BASIC0( herk4m1 ) -INSERT_GENTPROTR_BASIC0( herk1m ) - - -#undef GENTPROT -#define GENTPROT( ctype, ch, opname ) \ -\ -void PASTEMAC(ch,opname) \ - ( \ - side_t side, \ - uplo_t uploa, \ - conj_t conja, \ - trans_t transb, \ - dim_t m, \ - dim_t n, \ - ctype* alpha, \ - ctype* a, inc_t rs_a, inc_t cs_a, \ - ctype* b, inc_t rs_b, inc_t cs_b, \ - ctype* beta, \ - ctype* c, inc_t rs_c, inc_t cs_c, \ - cntx_t* cntx, \ - rntm_t* rntm \ + cntx_t* cntx, \ + rntm_t* rntm \ ); -INSERT_GENTPROT_BASIC0( symm3mh ) -INSERT_GENTPROT_BASIC0( symm3m1 ) -INSERT_GENTPROT_BASIC0( symm4mh ) -INSERT_GENTPROT_BASIC0( symm4m1 ) -INSERT_GENTPROT_BASIC0( symm1m ) +INSERT_GENTPROTR_BASIC0( her2k ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC(ch,opname) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,BLIS_TAPI_EX_SUF) \ ( \ uplo_t uploc, \ trans_t transa, \ - trans_t transb, \ dim_t m, \ dim_t k, \ ctype* alpha, \ ctype* a, inc_t rs_a, inc_t cs_a, \ - ctype* b, inc_t rs_b, inc_t cs_b, \ ctype* beta, \ ctype* c, inc_t rs_c, inc_t cs_c, \ cntx_t* cntx, \ rntm_t* rntm \ ); -INSERT_GENTPROT_BASIC0( syr2k3mh ) -INSERT_GENTPROT_BASIC0( syr2k3m1 ) -INSERT_GENTPROT_BASIC0( syr2k4mh ) -INSERT_GENTPROT_BASIC0( syr2k4m1 ) -INSERT_GENTPROT_BASIC0( syr2k1m ) +INSERT_GENTPROT_BASIC0( syrk ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC(ch,opname) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,BLIS_TAPI_EX_SUF) \ ( \ uplo_t uploc, \ trans_t transa, \ + trans_t transb, \ dim_t m, \ dim_t k, \ ctype* alpha, \ ctype* a, inc_t rs_a, inc_t cs_a, \ + ctype* b, inc_t rs_b, inc_t cs_b, \ ctype* beta, \ ctype* c, inc_t rs_c, inc_t cs_c, \ cntx_t* cntx, \ rntm_t* rntm \ ); -INSERT_GENTPROT_BASIC0( syrk3mh ) -INSERT_GENTPROT_BASIC0( syrk3m1 ) -INSERT_GENTPROT_BASIC0( syrk4mh ) -INSERT_GENTPROT_BASIC0( syrk4m1 ) -INSERT_GENTPROT_BASIC0( syrk1m ) +INSERT_GENTPROT_BASIC0( gemmt ) +INSERT_GENTPROT_BASIC0( syr2k ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC(ch,opname) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,BLIS_TAPI_EX_SUF) \ ( \ side_t side, \ uplo_t uploa, \ @@ -235,40 +189,13 @@ void PASTEMAC(ch,opname) \ rntm_t* rntm \ ); -INSERT_GENTPROT_BASIC0( trmm33mh ) -INSERT_GENTPROT_BASIC0( trmm33m1 ) -INSERT_GENTPROT_BASIC0( trmm34mh ) -INSERT_GENTPROT_BASIC0( trmm34m1 ) -INSERT_GENTPROT_BASIC0( trmm31m ) - - -#undef GENTPROT -#define GENTPROT( ctype, ch, opname ) \ -\ -void PASTEMAC(ch,opname) \ - ( \ - side_t side, \ - uplo_t uploa, \ - trans_t transa, \ - diag_t diaga, \ - dim_t m, \ - dim_t n, \ - ctype* alpha, \ - ctype* a, inc_t rs_a, inc_t cs_a, \ - ctype* b, inc_t rs_b, inc_t cs_b, \ - cntx_t* cntx, \ - rntm_t* rntm \ - ); - -INSERT_GENTPROT_BASIC0( trmm3m1 ) -INSERT_GENTPROT_BASIC0( trmm4m1 ) -INSERT_GENTPROT_BASIC0( trmm1m ) +INSERT_GENTPROT_BASIC0( trmm3 ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC(ch,opname) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,BLIS_TAPI_EX_SUF) \ ( \ side_t side, \ uplo_t uploa, \ @@ -283,7 +210,6 @@ void PASTEMAC(ch,opname) \ rntm_t* rntm \ ); -INSERT_GENTPROT_BASIC0( trsm3m1 ) -INSERT_GENTPROT_BASIC0( trsm4m1 ) -INSERT_GENTPROT_BASIC0( trsm1m ) +INSERT_GENTPROT_BASIC0( trmm ) +INSERT_GENTPROT_BASIC0( trsm ) diff --git a/frame/3/bli_l3_thrinfo.h b/frame/3/bli_l3_thrinfo.h index 4e6406acd9..a2a9218a2d 100644 --- a/frame/3/bli_l3_thrinfo.h +++ b/frame/3/bli_l3_thrinfo.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -93,7 +93,7 @@ void bli_l3_thrinfo_free thrinfo_t* thread ); -void bli_l3_sup_thrinfo_free +BLIS_EXPORT_BLIS void bli_l3_sup_thrinfo_free ( rntm_t* rntm, thrinfo_t* thread @@ -110,7 +110,7 @@ void bli_l3_thrinfo_create_root thrinfo_t** thread ); -void bli_l3_sup_thrinfo_create_root +BLIS_EXPORT_BLIS void bli_l3_sup_thrinfo_create_root ( dim_t id, thrcomm_t* gl_comm, diff --git a/frame/3/gemm/bli_gemm_front.c b/frame/3/gemm/bli_gemm_front.c index 063f40ff9c..7941d7a910 100644 --- a/frame/3/gemm/bli_gemm_front.c +++ b/frame/3/gemm/bli_gemm_front.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -54,10 +54,6 @@ void bli_gemm_front obj_t b_local; obj_t c_local; - // Check parameters. - if ( bli_error_checking_is_enabled() ) - bli_gemm_check( alpha, a, b, beta, c, cntx ); - // If C has a zero dimension, return early. if ( bli_obj_has_zero_dim( c ) ) { @@ -79,6 +75,29 @@ void bli_gemm_front bli_obj_alias_to( b, &b_local ); bli_obj_alias_to( c, &c_local ); +#ifdef BLIS_ENABLE_GEMM_MD + // Don't perform the following optimization for ccr or crc cases, as + // those cases are sensitive to the ukernel storage preference (ie: + // transposing the operation would break them). + if ( !bli_gemm_md_is_ccr( &a_local, &b_local, &c_local ) && + !bli_gemm_md_is_crc( &a_local, &b_local, &c_local ) ) +#endif + // An optimization: If C is stored by rows and the micro-kernel prefers + // contiguous columns, or if C is stored by columns and the micro-kernel + // prefers contiguous rows, transpose the entire operation to allow the + // micro-kernel to access elements of C in its preferred manner. + if ( bli_cntx_l3_vir_ukr_dislikes_storage_of( &c_local, BLIS_GEMM_UKR, cntx ) ) + { + bli_obj_swap( &a_local, &b_local ); + + bli_obj_induce_trans( &a_local ); + bli_obj_induce_trans( &b_local ); + bli_obj_induce_trans( &c_local ); + } + + // Set the pack schemas within the objects. + bli_l3_set_schemas( &a_local, &b_local, &c_local, cntx ); + #ifdef BLIS_ENABLE_GEMM_MD cntx_t cntx_local; @@ -98,24 +117,8 @@ void bli_gemm_front // is adjusted to point to cntx_local.) bli_gemm_md( &a_local, &b_local, beta, &c_local, &cntx_local, &cntx ); } - //else // homogeneous datatypes #endif - // Load the pack schemas from the context and embed them into the objects - // for A and B. (Native contexts are initialized with the correct pack - // schemas, as are contexts for 1m, and if necessary bli_gemm_md() would - // have made a copy and modified the schemas, so reading them from the - // context should be a safe bet at this point.) This is a sort of hack for - // communicating the desired pack schemas to bli_gemm_cntl_create() (via - // bli_l3_thread_decorator() and bli_l3_cntl_create_if()). This allows us - // to subsequently access the schemas from the control tree, which - // hopefully reduces some confusion, particularly in bli_packm_init(). - const pack_t schema_a = bli_cntx_schema_a_block( cntx ); - const pack_t schema_b = bli_cntx_schema_b_panel( cntx ); - - bli_obj_set_pack_schema( schema_a, &a_local ); - bli_obj_set_pack_schema( schema_b, &b_local ); - // Next, we handle the possibility of needing to typecast alpha to the // computation datatype and/or beta to the storage datatype of C. @@ -242,7 +245,7 @@ void bli_gemm_front bli_obj_set_exec_dt( dt_exec, &ct ); bli_obj_set_comp_dt( dt_comp, &ct ); - // A naive approach would cast C to the comptuation datatype, + // A naive approach would cast C to the computation datatype, // compute with beta, and then cast the result back to the // user-provided output matrix. However, we employ a different // approach that halves the number of memops on C (or its diff --git a/frame/3/gemm/bli_gemm_front_amd.c b/frame/3/gemm/bli_gemm_front_amd.c index b64baf0001..991b04e56c 100644 --- a/frame/3/gemm/bli_gemm_front_amd.c +++ b/frame/3/gemm/bli_gemm_front_amd.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -89,6 +89,29 @@ void bli_gemm_front bli_obj_alias_to( b, &b_local ); bli_obj_alias_to( c, &c_local ); +#ifdef BLIS_ENABLE_GEMM_MD + // Don't perform the following optimization for ccr or crc cases, as + // those cases are sensitive to the ukernel storage preference (ie: + // transposing the operation would break them). + if ( !bli_gemm_md_is_ccr( &a_local, &b_local, &c_local ) && + !bli_gemm_md_is_crc( &a_local, &b_local, &c_local ) ) +#endif + // An optimization: If C is stored by rows and the micro-kernel prefers + // contiguous columns, or if C is stored by columns and the micro-kernel + // prefers contiguous rows, transpose the entire operation to allow the + // micro-kernel to access elements of C in its preferred manner. + if ( bli_cntx_l3_vir_ukr_dislikes_storage_of( &c_local, BLIS_GEMM_UKR, cntx ) ) + { + bli_obj_swap( &a_local, &b_local ); + + bli_obj_induce_trans( &a_local ); + bli_obj_induce_trans( &b_local ); + bli_obj_induce_trans( &c_local ); + } + + // Set the pack schemas within the objects. + bli_l3_set_schemas( &a_local, &b_local, &c_local, cntx ); + #ifdef BLIS_ENABLE_GEMM_MD cntx_t cntx_local; @@ -111,21 +134,6 @@ void bli_gemm_front //else // homogeneous datatypes #endif - // Load the pack schemas from the context and embed them into the objects - // for A and B. (Native contexts are initialized with the correct pack - // schemas, as are contexts for 1m, and if necessary bli_gemm_md() would - // have made a copy and modified the schemas, so reading them from the - // context should be a safe bet at this point.) This is a sort of hack for - // communicating the desired pack schemas to bli_gemm_cntl_create() (via - // bli_l3_thread_decorator() and bli_l3_cntl_create_if()). This allows us - // to subsequently access the schemas from the control tree, which - // hopefully reduces some confusion, particularly in bli_packm_init(). - const pack_t schema_a = bli_cntx_schema_a_block( cntx ); - const pack_t schema_b = bli_cntx_schema_b_panel( cntx ); - - bli_obj_set_pack_schema( schema_a, &a_local ); - bli_obj_set_pack_schema( schema_b, &b_local ); - // Next, we handle the possibility of needing to typecast alpha to the // computation datatype and/or beta to the storage datatype of C. @@ -136,9 +144,11 @@ void bli_gemm_front // In case of dzgemm, if the microkernel prefers column output, // we will induce a transposition and perform C+= A*B // where A( formerly B) is complex. Hence attach alpha to A. +#ifdef BLIS_ENABLE_GEMM_MD if ( bli_gemm_md_is_ccr( &a_local, &b_local, &c_local )) bli_obj_scalar_attach( BLIS_NO_CONJUGATE, alpha, &a_local ); else +#endif bli_obj_scalar_attach( BLIS_NO_CONJUGATE, alpha, &b_local ); // Attach beta to C, and in the process typecast beta to the target @@ -275,7 +285,7 @@ void bli_gemm_front bli_obj_set_exec_dt( dt_exec, &ct ); bli_obj_set_comp_dt( dt_comp, &ct ); - // A naive approach would cast C to the comptuation datatype, + // A naive approach would cast C to the computation datatype, // compute with beta, and then cast the result back to the // user-provided output matrix. However, we employ a different // approach that halves the number of memops on C (or its diff --git a/frame/3/gemm/bli_gemm_int.c b/frame/3/gemm/bli_gemm_int.c index 405c74d76b..34b4e56e98 100644 --- a/frame/3/gemm/bli_gemm_int.c +++ b/frame/3/gemm/bli_gemm_int.c @@ -60,7 +60,8 @@ void bli_gemm_int bli_gemm_basic_check( alpha, a, b, beta, c, cntx ); // If C has a zero dimension, return early. - if ( bli_obj_has_zero_dim( c ) ) { + if ( bli_obj_has_zero_dim( c ) ) + { AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4); return; } @@ -69,9 +70,9 @@ void bli_gemm_int if ( bli_obj_has_zero_dim( a ) || bli_obj_has_zero_dim( b ) ) { - if ( bli_thread_am_ochief( thread ) ) - bli_scalm( beta, c ); - bli_thread_barrier( thread ); + if ( bli_thread_am_ochief( thread ) ) + bli_scalm( beta, c ); + bli_thread_barrier( thread ); AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4); return; } @@ -84,9 +85,9 @@ void bli_gemm_int // This should never execute. bli_abort(); - if ( bli_thread_am_ochief( thread ) ) - bli_scalm( beta, c ); - bli_thread_barrier( thread ); + if ( bli_thread_am_ochief( thread ) ) + bli_scalm( beta, c ); + bli_thread_barrier( thread ); AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4); return; } @@ -100,14 +101,14 @@ void bli_gemm_int // to B. if ( !bli_obj_equals( alpha, &BLIS_ONE ) ) { - bli_obj_scalar_apply_scalar( alpha, &b_local ); + bli_obj_scalar_apply_scalar( alpha, &b_local ); } // If beta is non-unit, typecast and apply it to the scalar attached // to C. if ( !bli_obj_equals( beta, &BLIS_ONE ) ) { - bli_obj_scalar_apply_scalar( beta, &c_local ); + bli_obj_scalar_apply_scalar( beta, &c_local ); } // Create the next node in the thrinfo_t structure. @@ -116,17 +117,6 @@ void bli_gemm_int // Extract the function pointer from the current control tree node. f = bli_cntl_var_func( cntl ); - // Somewhat hackish support for 4m1b method implementation. - { - ind_t im = bli_cntx_method( cntx ); - - if ( im != BLIS_NAT ) - { - if ( im == BLIS_4M1B ) - if ( f == bli_gemm_ker_var2 ) f = bli_gemm4mb_ker_var2; - } - } - // Invoke the variant. f ( @@ -136,7 +126,7 @@ void bli_gemm_int cntx, rntm, cntl, - thread + thread ); AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4); diff --git a/frame/3/gemm/bli_gemm_ker_var2.c b/frame/3/gemm/bli_gemm_ker_var2.c index f91e22d435..a536dcc135 100644 --- a/frame/3/gemm/bli_gemm_ker_var2.c +++ b/frame/3/gemm/bli_gemm_ker_var2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -171,8 +171,36 @@ void bli_gemm_ker_var2 // function pointer. f = ftypes[dt_exec]; - // Invoke the function. - f( schema_a, +#ifdef BLIS_KERNELS_ZEN5 + const long MR = 8; + const long NR = 24; + + // Optimizes macro kernel is avaible for DGEMM + // for ZEN5. This optimized macro kernel does not support + // fringe cases. Only row major stored C is supported. + // TODO: Add macro kernel function pointer in cntx + if + ( + ( bli_obj_dt( c ) == BLIS_DOUBLE ) && + ( bli_arch_query_id() == BLIS_ARCH_ZEN5 ) && + ( cs_c == 1 ) && // use this kernel only for row major C + ( (n%NR) == 0 ) && ( (m%MR) == 0 ) && + // use generic macro kernel for mixed precision + ( bli_obj_elem_size( a ) == 8 ) && // check if elem_sizeof(a) == sizeof(double) + ( bli_obj_is_real( a ) ) && // check if A is real + ( bli_obj_elem_size( b ) == 8 ) && // check if elem_sizeof(b) == sizeof(double) + ( bli_obj_is_real( b ) ) // check if B is real + ) + { + bli_dgemm_avx512_asm_8x24_macro_kernel + ( + n, m, k, buf_c, buf_a, buf_b, rs_c, buf_beta + ); + } + else +#endif + { + f( schema_a, schema_b, m, n, @@ -187,6 +215,7 @@ void bli_gemm_ker_var2 cntx, rntm, thread ); + } AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_6); } @@ -224,7 +253,17 @@ void PASTEMAC(ch,varname) \ /*const dim_t PACKNR = rs_b;*/ \ \ /* Query the context for the micro-kernel address and cast it to its - function pointer type. */ \ + function pointer type. Note that the virtual gemm ukernel is queried + instead of the native gemm ukernel. This is needed for certain + situations for the 1m method that require an extra layer of logic + to allow for handling (for example) complex values of beta. Also + note that under certain circumstances, the real-domain version of + this macrokernel will be called for 1m (NOT the complex version) + as an optimization. In these cases, the corresponding real-domain + slots within the cntx_t's virtual gemm ukernel func_t will contain + pointers to the *native* gemm ukernel, thanks to logic in the + context initialization function for the induced method (defined + in bli_cntx_ref.c). */ \ PASTECH(ch,gemm_ukr_ft) \ gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ \ diff --git a/frame/3/gemm/bli_gemm_ker_var2_md.c b/frame/3/gemm/bli_gemm_ker_var2_md.c index 3df524dd2e..09c279d149 100644 --- a/frame/3/gemm/bli_gemm_ker_var2_md.c +++ b/frame/3/gemm/bli_gemm_ker_var2_md.c @@ -368,8 +368,6 @@ void PASTEMAC2(chc,che,varname) \ then accumulate it into C via the xpbys_mxn macro. */ \ /*if ( 1 )*/ \ { \ - /*bli_auxinfo_set_dt_on_output( dte, &aux );*/ \ -\ /* Invoke the gemm micro-kernel. */ \ gemm_ukr \ ( \ @@ -392,48 +390,6 @@ void PASTEMAC2(chc,che,varname) \ c11, rs_c, cs_c \ ); \ } \ -/* - else if ( m_cur == MR && n_cur == NR ) \ - { \ - bli_auxinfo_set_dt_on_output( dtc, &aux ); \ -\ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - ( ctype_e* )beta_cast, \ - ( ctype_e* )c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - bli_auxinfo_set_dt_on_output( dte, &aux ); \ -\ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - PASTEMAC3(che,chc,chc,xpbys_mxn) \ - ( \ - m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - beta_cast, \ - c11, rs_c, cs_c \ - ); \ - } \ -*/ \ } \ } \ \ diff --git a/frame/3/gemm/bli_gemm_md.c b/frame/3/gemm/bli_gemm_md.c index 66f8414a27..c21e3f0ef1 100644 --- a/frame/3/gemm/bli_gemm_md.c +++ b/frame/3/gemm/bli_gemm_md.c @@ -187,6 +187,10 @@ mddm_t bli_gemm_md_ccr bli_obj_induce_trans( b ); bli_obj_induce_trans( c ); + // We must swap the pack schemas because the schemas were set before + // the objects were swapped. + bli_obj_swap_pack_schemas( a, b ); + return bli_gemm_md_crc( a, b, beta, c, cntx_local, cntx ); } @@ -230,7 +234,7 @@ mddm_t bli_gemm_md_ccr bli_blksz_scale_def_max( 1, 2, BLIS_SCOMPLEX, blksz_mc ); bli_blksz_scale_def_max( 1, 2, BLIS_DCOMPLEX, blksz_mc ); - // Use the default pack schemas in the context. + // Use the default pack schemas in the objects. // static func_t* bli_cntx_get_l3_vir_ukrs( l3ukr_t ukr_id, cntx_t* cntx ) func_t* l3_vir_ukrs = bli_cntx_get_l3_vir_ukrs( BLIS_GEMM_UKR, *cntx ); @@ -288,6 +292,10 @@ mddm_t bli_gemm_md_crc bli_obj_induce_trans( b ); bli_obj_induce_trans( c ); + // We must swap the pack schemas because the schemas were set before + // the objects were swapped. + bli_obj_swap_pack_schemas( a, b ); + return bli_gemm_md_ccr( a, b, beta, c, cntx_local, cntx ); } @@ -331,7 +339,7 @@ mddm_t bli_gemm_md_crc bli_blksz_scale_def_max( 1, 2, BLIS_SCOMPLEX, blksz_nc ); bli_blksz_scale_def_max( 1, 2, BLIS_DCOMPLEX, blksz_nc ); - // Use the default pack schemas in the context. + // Use the default pack schemas in the objects. // static func_t* bli_cntx_get_l3_vir_ukrs( l3ukr_t ukr_id, cntx_t* cntx ) func_t* l3_vir_ukrs = bli_cntx_get_l3_vir_ukrs( BLIS_GEMM_UKR, *cntx ); @@ -405,8 +413,8 @@ mddm_t bli_gemm_md_rcc // Use the 1r pack schema for both A and B with the conjugation // of A or B toggled (to produce ar * br - ai * bi). - bli_cntx_set_schema_a_block( BLIS_PACKED_ROW_PANELS_1R, *cntx ); - bli_cntx_set_schema_b_panel( BLIS_PACKED_COL_PANELS_1R, *cntx ); + bli_obj_set_pack_schema( BLIS_PACKED_ROW_PANELS_1R, a ); + bli_obj_set_pack_schema( BLIS_PACKED_COL_PANELS_1R, b ); bli_obj_toggle_conj( b ); @@ -485,7 +493,7 @@ mddm_t bli_gemm_md_crr } #endif - // Use the default pack schemas in the context. + // Use the default pack schemas in the objects. // Return the computation and execution domains. return doms; @@ -523,7 +531,7 @@ mddm_t bli_gemm_md_rcr // Overwrite the complex obj_t with its real-only alias. *a = a_real; - // Use the default pack schemas in the context. + // Use the default pack schemas in the objects. // Return the computation and execution domains. return doms; @@ -561,7 +569,7 @@ mddm_t bli_gemm_md_rrc // Overwrite the complex obj_t with its real-only alias. *b = b_real; - // Use the default pack schemas in the context. + // Use the default pack schemas in the objects. // Return the computation and execution domains. return doms; @@ -591,7 +599,7 @@ mddm_t bli_gemm_md_rrr doms.comp = BLIS_REAL; doms.exec = BLIS_REAL; - // Use the default pack schemas in the context. + // Use the default pack schemas in the objects. // Return the computation and execution domains. return doms; @@ -621,248 +629,10 @@ mddm_t bli_gemm_md_ccc doms.comp = BLIS_COMPLEX; doms.exec = BLIS_COMPLEX; - // Use the default pack schemas in the context. + // Use the default pack schemas in the objects. // Return the computation and execution domains. return doms; } -// ----------------------------------------------------------------------------- - -#if 0 -void bli_gemm_md_front - ( - obj_t* alpha, - obj_t* a, - obj_t* b, - obj_t* beta, - obj_t* c, - cntx_t* cntx, - rntm_t* rntm, - cntl_t* cntl - ) -{ - bli_init_once(); - - obj_t a_local; - obj_t b_local; - obj_t c_local; - - // Check parameters. - if ( bli_error_checking_is_enabled() ) - bli_gemm_check( alpha, a, b, beta, c, cntx ); - - // If alpha is zero, scale by beta and return. - if ( bli_obj_equals( alpha, &BLIS_ZERO ) ) - { - bli_scalm( beta, c ); - return; - } - - // Alias A, B, and C in case we need to apply transformations. - bli_obj_alias_to( a, &a_local ); - bli_obj_alias_to( b, &b_local ); - bli_obj_alias_to( c, &c_local ); - - // An optimization: If C is stored by rows and the micro-kernel prefers - // contiguous columns, or if C is stored by columns and the micro-kernel - // prefers contiguous rows, transpose the entire operation to allow the - // micro-kernel to access elements of C in its preferred manner. - if ( bli_cntx_l3_vir_ukr_dislikes_storage_of( &c_local, BLIS_GEMM_UKR, cntx ) ) - { - bli_obj_swap( &a_local, &b_local ); - - bli_obj_induce_trans( &a_local ); - bli_obj_induce_trans( &b_local ); - bli_obj_induce_trans( &c_local ); - } - - cntx_t cntx_local; - - // Handle mixed domain cases in bli_gemm_md(), which may modify - // the objects or the context. (If the context is modified, cntx - // is adjusted to point to cntx_local.) - bli_gemm_md( &a_local, &b_local, beta, &c_local, &cntx_local, &cntx ); - - // Record the threading for each level within the context. - bli_rntm_set_ways_for_op - ( - BLIS_GEMM, - BLIS_LEFT, // ignored for gemm/hemm/symm - bli_obj_length( &c_local ), - bli_obj_width( &c_local ), - bli_obj_width( &a_local ), - rntm - ); - - // Invoke the internal back-end via the thread handler. - bli_l3_thread_decorator - ( - bli_gemm_int, - BLIS_GEMM, // operation family id - alpha, - &a_local, - &b_local, - beta, - &c_local, - cntx, - rntm, - cntl - ); -} - -// ----------------------------------------------------------------------------- - -void bli_gemm_md_zgemm - ( - obj_t* alpha, - obj_t* a, - obj_t* b, - obj_t* beta, - obj_t* c, - cntx_t* cntx, - rntm_t* rntm, - cntl_t* cntl - ) -{ - bli_init_once(); - - obj_t a_local; - obj_t b_local; - obj_t c_local; - -#if 1 - obj_t am, bm, cm; - obj_t* c_orig; - - //if ( is_md == TRUE ) - { - //num_t dt_c2 = bli_obj_dt( c ); - //num_t dt_c1 = bli_dt_proj_to_complex( dt_c2 ); - //num_t dt_c = bli_dt_proj_to_double_prec( dt_c1 ); - //num_t dt_c = bli_obj_dt_proj_to_complex( c ); - num_t dt_c = BLIS_DCOMPLEX; - - if ( bli_obj_is_single_prec( c ) ) dt_c = BLIS_SCOMPLEX; - else dt_c = BLIS_DCOMPLEX; - - if ( bli_obj_is_real( a ) && - bli_obj_is_real( b ) && - bli_obj_is_real( c ) ) dt_c = bli_dt_proj_to_real( dt_c ); - - dim_t m = bli_obj_length( c ); - dim_t n = bli_obj_width( c ); - dim_t k = bli_obj_width_after_trans( a ); - - bli_obj_create( dt_c, m, k, 0, 0, &am ); - bli_obj_create( dt_c, k, n, 0, 0, &bm ); - bli_obj_create( dt_c, m, n, 0, 0, &cm ); - - //bli_projm( a, &am ); - //bli_projm( b, &bm ); - //bli_projm( c, &cm ); - bli_castm( a, &am ); - bli_castm( b, &bm ); - bli_castm( c, &cm ); - - c_orig = c; - - a = &am; - b = &bm; - c = &cm; - } -#endif - - // Check parameters. - if ( bli_error_checking_is_enabled() ) - bli_gemm_check( alpha, a, b, beta, c, cntx ); - - // If alpha is zero, scale by beta and return. - if ( bli_obj_equals( alpha, &BLIS_ZERO ) ) - { - bli_scalm( beta, c ); - return; - } - - // Alias A, B, and C in case we need to apply transformations. - bli_obj_alias_to( a, &a_local ); - bli_obj_alias_to( b, &b_local ); - bli_obj_alias_to( c, &c_local ); - - // An optimization: If C is stored by rows and the micro-kernel prefers - // contiguous columns, or if C is stored by columns and the micro-kernel - // prefers contiguous rows, transpose the entire operation to allow the - // micro-kernel to access elements of C in its preferred manner. - if ( bli_cntx_l3_vir_ukr_dislikes_storage_of( &c_local, BLIS_GEMM_UKR, cntx ) ) - { - bli_obj_swap( &a_local, &b_local ); - - bli_obj_induce_trans( &a_local ); - bli_obj_induce_trans( &b_local ); - bli_obj_induce_trans( &c_local ); - } - - { - // A sort of hack for communicating the desired pack schemas for A and B - // to bli_gemm_cntl_create() (via bli_l3_thread_decorator() and - // bli_l3_cntl_create_if()). This allows us to access the schemas from - // the control tree, which hopefully reduces some confusion, particularly - // in bli_packm_init(). - if ( bli_cntx_method( cntx ) == BLIS_NAT ) - { - bli_obj_set_pack_schema( BLIS_PACKED_ROW_PANELS, &a_local ); - bli_obj_set_pack_schema( BLIS_PACKED_COL_PANELS, &b_local ); - } - else // if ( bli_cntx_method( cntx ) != BLIS_NAT ) - { - pack_t schema_a = bli_cntx_schema_a_block( cntx ); - pack_t schema_b = bli_cntx_schema_b_panel( cntx ); - - bli_obj_set_pack_schema( schema_a, &a_local ); - bli_obj_set_pack_schema( schema_b, &b_local ); - } - } - - // Parse and interpret the contents of the rntm_t object to properly - // set the ways of parallelism for each loop, and then make any - // additional modifications necessary for the current operation. - bli_rntm_set_ways_for_op - ( - BLIS_GEMM, - BLIS_LEFT, // ignored for gemm/hemm/symm - bli_obj_length( &c_local ), - bli_obj_width( &c_local ), - bli_obj_width( &a_local ), - rntm - ); - - // Invoke the internal back-end via the thread handler. - bli_l3_thread_decorator - ( - bli_gemm_int, - BLIS_GEMM, // operation family id - alpha, - &a_local, - &b_local, - beta, - &c_local, - cntx, - rntm, - cntl - ); - -#if 1 - //if ( is_md == TRUE ) - { - //bli_projm( &cm, c_orig ); - bli_castm( &cm, c_orig ); - - bli_obj_free( &am ); - bli_obj_free( &bm ); - bli_obj_free( &cm ); - } -#endif -} -#endif - #endif diff --git a/frame/3/gemm/bli_gemm_var.h b/frame/3/gemm/bli_gemm_var.h index cde4c9de8c..d01dd8ad4e 100644 --- a/frame/3/gemm/bli_gemm_var.h +++ b/frame/3/gemm/bli_gemm_var.h @@ -62,9 +62,6 @@ GENPROT( gemm_ker_var1 ) GENPROT( gemm_ker_var2 ) -// Headers for induced algorithms: -GENPROT( gemm4mb_ker_var2 ) // 4m1b - // // Prototype BLAS-like interfaces with void pointer operands. @@ -94,6 +91,3 @@ void PASTEMAC(ch,varname) \ INSERT_GENTPROT_BASIC0( gemm_ker_var2 ) -// Headers for induced algorithms: -INSERT_GENTPROT_BASIC0( gemm4mb_ker_var2 ) // 4m1b - diff --git a/frame/3/gemm/ind/bli_gemm4mb_ker_var2.c b/frame/3/gemm/ind/bli_gemm4mb_ker_var2.c deleted file mode 100644 index fbdd12bbc0..0000000000 --- a/frame/3/gemm/ind/bli_gemm4mb_ker_var2.c +++ /dev/null @@ -1,365 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - -#define FUNCPTR_T gemm_fp - -typedef void (*FUNCPTR_T)( - pack_t schema_a, - pack_t schema_b, - dim_t m, - dim_t n, - dim_t k, - void* alpha, - void* a, inc_t cs_a, inc_t is_a, - dim_t pd_a, inc_t ps_a, - void* b, inc_t rs_b, inc_t is_b, - dim_t pd_b, inc_t ps_b, - void* beta, - void* c, inc_t rs_c, inc_t cs_c, - cntx_t* cntx, - rntm_t* rntm, - thrinfo_t* thread - ); - -static FUNCPTR_T GENARRAY(ftypes,gemm4mb_ker_var2); - - -void bli_gemm4mb_ker_var2 - ( - obj_t* a, - obj_t* b, - obj_t* c, - cntx_t* cntx, - rntm_t* rntm, - cntl_t* cntl, - thrinfo_t* thread - ) -{ - num_t dt_exec = bli_obj_exec_dt( c ); - - pack_t schema_a = bli_obj_pack_schema( a ); - pack_t schema_b = bli_obj_pack_schema( b ); - - dim_t m = bli_obj_length( c ); - dim_t n = bli_obj_width( c ); - dim_t k = bli_obj_width( a ); - - void* buf_a = bli_obj_buffer_at_off( a ); - inc_t cs_a = bli_obj_col_stride( a ); - inc_t is_a = bli_obj_imag_stride( a ); - dim_t pd_a = bli_obj_panel_dim( a ); - inc_t ps_a = bli_obj_panel_stride( a ); - - void* buf_b = bli_obj_buffer_at_off( b ); - inc_t rs_b = bli_obj_row_stride( b ); - inc_t is_b = bli_obj_imag_stride( b ); - dim_t pd_b = bli_obj_panel_dim( b ); - inc_t ps_b = bli_obj_panel_stride( b ); - - void* buf_c = bli_obj_buffer_at_off( c ); - inc_t rs_c = bli_obj_row_stride( c ); - inc_t cs_c = bli_obj_col_stride( c ); - - obj_t scalar_a; - obj_t scalar_b; - - void* buf_alpha; - void* buf_beta; - - FUNCPTR_T f; - - // Detach and multiply the scalars attached to A and B. - bli_obj_scalar_detach( a, &scalar_a ); - bli_obj_scalar_detach( b, &scalar_b ); - bli_mulsc( &scalar_a, &scalar_b ); - - // Grab the addresses of the internal scalar buffers for the scalar - // merged above and the scalar attached to C. - buf_alpha = bli_obj_internal_scalar_buffer( &scalar_b ); - buf_beta = bli_obj_internal_scalar_buffer( c ); - - // Index into the type combination array to extract the correct - // function pointer. - f = ftypes[dt_exec]; - - // Invoke the function. - f( schema_a, - schema_b, - m, - n, - k, - buf_alpha, - buf_a, cs_a, is_a, - pd_a, ps_a, - buf_b, rs_b, is_b, - pd_b, ps_b, - buf_beta, - buf_c, rs_c, cs_c, - cntx, - rntm, - thread ); -} - - -#undef GENTFUNC -#define GENTFUNC( ctype, ch, varname ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - pack_t schema_a, \ - pack_t schema_b, \ - dim_t m, \ - dim_t n, \ - dim_t k, \ - void* alpha, \ - void* a, inc_t cs_a, inc_t is_a, \ - dim_t pd_a, inc_t ps_a, \ - void* b, inc_t rs_b, inc_t is_b, \ - dim_t pd_b, inc_t ps_b, \ - void* beta, \ - void* c, inc_t rs_c, inc_t cs_c, \ - cntx_t* cntx, \ - rntm_t* rntm, \ - thrinfo_t* thread \ - ) \ -{ \ - const num_t dt = PASTEMAC(ch,type); \ -\ - /* Alias some constants to simpler names. */ \ - const dim_t MR = pd_a; \ - const dim_t NR = pd_b; \ - /*const dim_t PACKMR = cs_a;*/ \ - /*const dim_t PACKNR = rs_b;*/ \ -\ - /* Query the context for the micro-kernel address and cast it to its - function pointer type. */ \ - PASTECH(ch,gemm_ukr_ft) \ - gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ -\ - /* Temporary C buffer for edge cases. */ \ - ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ - / sizeof( ctype ) ] \ - __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ - const inc_t rs_ct = ( col_pref ? 1 : NR ); \ - const inc_t cs_ct = ( col_pref ? MR : 1 ); \ -\ - ctype* restrict zero = PASTEMAC(ch,0); \ - ctype* restrict one = PASTEMAC(ch,1); \ - ctype* restrict a_cast = a; \ - ctype* restrict b_cast = b; \ - ctype* restrict c_cast = c; \ - ctype* restrict alpha_cast = alpha; \ - ctype* restrict beta_cast = beta; \ - ctype* restrict b1; \ - ctype* restrict c1; \ -\ - dim_t m_iter, m_left; \ - dim_t n_iter, n_left; \ - dim_t i, j; \ - dim_t ii; \ - dim_t m_cur; \ - dim_t n_cur; \ - inc_t rstep_a; \ - inc_t cstep_b; \ - inc_t rstep_c, cstep_c; \ - auxinfo_t aux; \ -\ - /* - Assumptions/assertions: - rs_a == 1 - cs_a == PACKMR - pd_a == MR - ps_a == stride to next micro-panel of A - rs_b == PACKNR - cs_b == 1 - pd_b == NR - ps_b == stride to next micro-panel of B - rs_c == (no assumptions) - cs_c == (no assumptions) - */ \ -\ - /* If any dimension is zero, return immediately. */ \ - if ( bli_zero_dim3( m, n, k ) ) return; \ -\ - /* Clear the temporary C buffer in case it has any infs or NaNs. */ \ - PASTEMAC(ch,set0s_mxn)( MR, NR, \ - ct, rs_ct, cs_ct ); \ -\ - /* Compute number of primary and leftover components of the m and n - dimensions. */ \ - n_iter = n / NR; \ - n_left = n % NR; \ -\ - m_iter = m / MR; \ - m_left = m % MR; \ -\ - if ( n_left ) ++n_iter; \ - if ( m_left ) ++m_iter; \ -\ - /* Determine some increments used to step through A, B, and C. */ \ - rstep_a = ps_a; \ -\ - cstep_b = ps_b; \ -\ - rstep_c = rs_c * MR; \ - cstep_c = cs_c * NR; \ -\ - /* Save the pack schemas of A and B to the auxinfo_t object. */ \ - bli_auxinfo_set_schema_a( schema_a, &aux ); \ - bli_auxinfo_set_schema_b( schema_b, &aux ); \ -\ - /* Save the imaginary stride of A and B to the auxinfo_t object. */ \ - bli_auxinfo_set_is_a( is_a, &aux ); \ - bli_auxinfo_set_is_b( is_b, &aux ); \ -\ - thrinfo_t* caucus = bli_thrinfo_sub_node( thread ); \ - dim_t jr_num_threads = bli_thread_n_way( thread ); \ - dim_t jr_thread_id = bli_thread_work_id( thread ); \ - dim_t ir_num_threads = bli_thread_n_way( caucus ); \ - dim_t ir_thread_id = bli_thread_work_id( caucus ); \ -\ - dim_t jr_inc = jr_num_threads; \ - dim_t ir_inc = ir_num_threads; \ -\ - /* Loop over the n dimension (NR columns at a time). */ \ - for ( j = jr_thread_id; j < n_iter; j += jr_num_threads ) \ - { \ - ctype* restrict a1; \ - ctype* restrict c11; \ - ctype* restrict b2; \ - \ - b1 = b_cast + j * cstep_b; \ - c1 = c_cast + j * cstep_c; \ -\ - n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); \ -\ - /* Initialize our next panel of B to be the current panel of B. */ \ - b2 = b1; \ -\ - /* In the 4mb method, we execute the ir loop twice: once for b_r - and once for b_i. */ \ - for ( ii = 0; ii < 2; ++ii ) \ - { \ - ctype* restrict beta_use; \ -\ - if ( ii == 0 ) \ - { \ - bli_auxinfo_set_schema_b( BLIS_PACKED_COL_PANELS_RO, &aux ); \ - beta_use = beta_cast; \ - } \ - else \ - { \ - bli_auxinfo_set_schema_b( BLIS_PACKED_COL_PANELS_IO, &aux ); \ - beta_use = one; \ - } \ -\ - /* Loop over the m dimension (MR rows at a time). */ \ - for ( i = ir_thread_id; i < m_iter; i += ir_num_threads ) \ - { \ - ctype* restrict a2; \ -\ - a1 = a_cast + i * rstep_a; \ - c11 = c1 + i * rstep_c; \ -\ - m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \ -\ - /* Compute the addresses of the next panels of A and B. */ \ - a2 = bli_gemm_get_next_a_upanel( a1, rstep_a, ir_inc ); \ - if ( bli_is_last_iter_rr( i, m_iter, ir_thread_id, ir_num_threads ) ) \ - { \ - a2 = a_cast; \ - b2 = bli_gemm_get_next_b_upanel( b1, cstep_b, jr_inc ); \ - if ( bli_is_last_iter_rr( j, n_iter, jr_thread_id, jr_num_threads ) ) \ - b2 = b_cast; \ - } \ -\ - /* Save addresses of next panels of A and B to the auxinfo_t - object. */ \ - bli_auxinfo_set_next_a( a2, &aux ); \ - bli_auxinfo_set_next_b( b2, &aux ); \ -\ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ -/*PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var3 (4m1b): c before", 8, 6, c11, rs_c, cs_c, "%4.1f", "" );*/ \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - beta_use, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ -/*PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var3 (4m1b): c after", 8, 6, c11, rs_c, cs_c, "%4.1f", "" );*/ \ - } \ - else \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Scale the bottom edge of C and add the result from above. */ \ - PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - beta_use, \ - c11, rs_c, cs_c ); \ - } \ - } \ - } \ - } \ -/*printf( "gemm_ker_var3 (4m1b): returning\n" );*/ \ -\ -/*PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var3: b1", k, NR, b1, NR, 1, "%4.1f", "" ); \ -PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var3: a1", MR, k, a1, 1, MR, "%4.1f", "" );*/ \ -} - -INSERT_GENTFUNC_BASIC0( gemm4mb_ker_var2 ) - diff --git a/frame/3/gemm/ind/old/bli_gemm3m2_ker_var2.c b/frame/3/gemm/ind/old/bli_gemm3m2_ker_var2.c deleted file mode 100644 index 5981424ae2..0000000000 --- a/frame/3/gemm/ind/old/bli_gemm3m2_ker_var2.c +++ /dev/null @@ -1,363 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - -#define FUNCPTR_T gemm_fp - -typedef void (*FUNCPTR_T)( - pack_t schema_a, - pack_t schema_b, - dim_t m, - dim_t n, - dim_t k, - void* alpha, - void* a, inc_t cs_a, inc_t is_a, - dim_t pd_a, inc_t ps_a, - void* b, inc_t rs_b, inc_t is_b, - dim_t pd_b, inc_t ps_b, - void* beta, - void* c, inc_t rs_c, inc_t cs_c, - cntx_t* cntx, - thrinfo_t* thread - ); - -static FUNCPTR_T GENARRAY(ftypes,gemm3m2_ker_var2); - - -void bli_gemm3m2_ker_var2 - ( - obj_t* a, - obj_t* b, - obj_t* c, - cntx_t* cntx, - cntl_t* cntl, - thrinfo_t* thread - ) -{ - num_t dt_exec = bli_obj_exec_dt( c ); - - pack_t schema_a = bli_obj_pack_schema( a ); - pack_t schema_b = bli_obj_pack_schema( b ); - - dim_t m = bli_obj_length( c ); - dim_t n = bli_obj_width( c ); - dim_t k = bli_obj_width( a ); - - void* buf_a = bli_obj_buffer_at_off( a ); - inc_t cs_a = bli_obj_col_stride( a ); - inc_t is_a = bli_obj_imag_stride( a ); - dim_t pd_a = bli_obj_panel_dim( a ); - inc_t ps_a = bli_obj_panel_stride( a ); - - void* buf_b = bli_obj_buffer_at_off( b ); - inc_t rs_b = bli_obj_row_stride( b ); - inc_t is_b = bli_obj_imag_stride( b ); - dim_t pd_b = bli_obj_panel_dim( b ); - inc_t ps_b = bli_obj_panel_stride( b ); - - void* buf_c = bli_obj_buffer_at_off( c ); - inc_t rs_c = bli_obj_row_stride( c ); - inc_t cs_c = bli_obj_col_stride( c ); - - obj_t scalar_a; - obj_t scalar_b; - - void* buf_alpha; - void* buf_beta; - - FUNCPTR_T f; - - // Detach and multiply the scalars attached to A and B. - bli_obj_scalar_detach( a, &scalar_a ); - bli_obj_scalar_detach( b, &scalar_b ); - bli_mulsc( &scalar_a, &scalar_b ); - - // Grab the addresses of the internal scalar buffers for the scalar - // merged above and the scalar attached to C. - buf_alpha = bli_obj_internal_scalar_buffer( &scalar_b ); - buf_beta = bli_obj_internal_scalar_buffer( c ); - - // Index into the type combination array to extract the correct - // function pointer. - f = ftypes[dt_exec]; - - // Invoke the function. - f( schema_a, - schema_b, - m, - n, - k, - buf_alpha, - buf_a, cs_a, is_a, - pd_a, ps_a, - buf_b, rs_b, is_b, - pd_b, ps_b, - buf_beta, - buf_c, rs_c, cs_c, - cntx, - thread ); -} - - -#undef GENTFUNC -#define GENTFUNC( ctype, ch, varname ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - pack_t schema_a, \ - pack_t schema_b, \ - dim_t m, \ - dim_t n, \ - dim_t k, \ - void* alpha, \ - void* a, inc_t cs_a, inc_t is_a, \ - dim_t pd_a, inc_t ps_a, \ - void* b, inc_t rs_b, inc_t is_b, \ - dim_t pd_b, inc_t ps_b, \ - void* beta, \ - void* c, inc_t rs_c, inc_t cs_c, \ - cntx_t* cntx, \ - thrinfo_t* thread \ - ) \ -{ \ - const num_t dt = PASTEMAC(ch,type); \ -\ - /* Alias some constants to simpler names. */ \ - const dim_t MR = pd_a; \ - const dim_t NR = pd_b; \ - /*const dim_t PACKMR = cs_a;*/ \ - /*const dim_t PACKNR = rs_b;*/ \ -\ - /* Query the context for the micro-kernel address and cast it to its - function pointer type. */ \ - PASTECH(ch,gemm_ukr_ft) \ - gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ -\ - /* Temporary C buffer for edge cases. */ \ - ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ - / sizeof( ctype ) ] \ - __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ - const inc_t rs_ct = ( col_pref ? 1 : NR ); \ - const inc_t cs_ct = ( col_pref ? MR : 1 ); \ -\ - ctype* restrict zero = PASTEMAC(ch,0); \ - ctype* restrict one = PASTEMAC(ch,1); \ - ctype* restrict a_cast = a; \ - ctype* restrict b_cast = b; \ - ctype* restrict c_cast = c; \ - ctype* restrict alpha_cast = alpha; \ - ctype* restrict beta_cast = beta; \ - ctype* restrict b1; \ - ctype* restrict c1; \ -\ - dim_t m_iter, m_left; \ - dim_t n_iter, n_left; \ - dim_t i, j; \ - dim_t ii; \ - dim_t m_cur; \ - dim_t n_cur; \ - inc_t rstep_a; \ - inc_t cstep_b; \ - inc_t rstep_c, cstep_c; \ - auxinfo_t aux; \ -\ - /* - Assumptions/assertions: - rs_a == 1 - cs_a == PACKMR - pd_a == MR - ps_a == stride to next micro-panel of A - rs_b == PACKNR - cs_b == 1 - pd_b == NR - ps_b == stride to next micro-panel of B - rs_c == (no assumptions) - cs_c == (no assumptions) - */ \ -\ - /* If any dimension is zero, return immediately. */ \ - if ( bli_zero_dim3( m, n, k ) ) return; \ -\ - /* Clear the temporary C buffer in case it has any infs or NaNs. */ \ - PASTEMAC(ch,set0s_mxn)( MR, NR, \ - ct, rs_ct, cs_ct ); \ -\ - /* Compute number of primary and leftover components of the m and n - dimensions. */ \ - n_iter = n / NR; \ - n_left = n % NR; \ -\ - m_iter = m / MR; \ - m_left = m % MR; \ -\ - if ( n_left ) ++n_iter; \ - if ( m_left ) ++m_iter; \ -\ - /* Determine some increments used to step through A, B, and C. */ \ - rstep_a = ps_a; \ -\ - cstep_b = ps_b; \ -\ - rstep_c = rs_c * MR; \ - cstep_c = cs_c * NR; \ -\ - /* Save the pack schemas of A and B to the auxinfo_t object. */ \ - bli_auxinfo_set_schema_a( schema_a, &aux ); \ - bli_auxinfo_set_schema_b( schema_b, &aux ); \ -\ - /* Save the imaginary stride of A and B to the auxinfo_t object. */ \ - bli_auxinfo_set_is_a( is_a, &aux ); \ - bli_auxinfo_set_is_b( is_b, &aux ); \ -\ - thrinfo_t* caucus = bli_thrinfo_sub_node( thread ); \ - dim_t jr_num_threads = bli_thread_n_way( thread ); \ - dim_t jr_thread_id = bli_thread_work_id( thread ); \ - dim_t ir_num_threads = bli_thread_n_way( caucus ); \ - dim_t ir_thread_id = bli_thread_work_id( caucus ); \ -\ - /* Loop over the n dimension (NR columns at a time). */ \ - for ( j = jr_thread_id; j < n_iter; j += jr_num_threads ) \ - { \ - ctype* restrict a1; \ - ctype* restrict c11; \ - ctype* restrict b2; \ - \ - b1 = b_cast + j * cstep_b; \ - c1 = c_cast + j * cstep_c; \ -\ - n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); \ -\ - /* Initialize our next panel of B to be the current panel of B. */ \ - b2 = b1; \ -\ - /* In the 3m2 method, we execute the ir loop thrice: once for - a_r[ir] * b_r, once for a_i[ir] * b_i, and once for - a_{r+i}[ir] * b_{r+i}. */ \ - for ( ii = 0; ii < 3; ++ii ) \ - { \ - ctype* restrict beta_use; \ -\ - if ( ii == 0 ) \ - { \ - bli_auxinfo_set_schema_a( BLIS_PACKED_ROW_PANELS_RO, &aux ); \ - bli_auxinfo_set_schema_b( BLIS_PACKED_COL_PANELS_RO, &aux ); \ - beta_use = beta_cast; \ - } \ - else if ( ii == 1 ) \ - { \ - bli_auxinfo_set_schema_a( BLIS_PACKED_ROW_PANELS_IO, &aux ); \ - bli_auxinfo_set_schema_b( BLIS_PACKED_COL_PANELS_IO, &aux ); \ - beta_use = one; \ - } \ - else \ - { \ - bli_auxinfo_set_schema_a( BLIS_PACKED_ROW_PANELS_RPI, &aux ); \ - bli_auxinfo_set_schema_b( BLIS_PACKED_COL_PANELS_RPI, &aux ); \ - beta_use = one; \ - } \ -\ - /* Loop over the m dimension (MR rows at a time). */ \ - for ( i = ir_thread_id; i < m_iter; i += ir_num_threads ) \ - { \ - ctype* restrict a2; \ -\ - a1 = a_cast + i * rstep_a; \ - c11 = c1 + i * rstep_c; \ -\ - m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \ -\ - /* Compute the addresses of the next panels of A and B. */ \ - a2 = bli_gemm_get_next_a_upanel( caucus, a1, rstep_a ); \ - if ( bli_is_last_iter( i, m_iter, ir_thread_id, ir_num_threads ) ) \ - { \ - a2 = a_cast; \ - b2 = bli_gemm_get_next_b_upanel( thread, b1, cstep_b ); \ - if ( bli_is_last_iter( j, n_iter, jr_thread_id, jr_num_threads ) ) \ - b2 = b_cast; \ - } \ -\ - /* Save addresses of next panels of A and B to the auxinfo_t - object. */ \ - bli_auxinfo_set_next_a( a2, &aux ); \ - bli_auxinfo_set_next_b( b2, &aux ); \ -\ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - beta_use, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Scale the bottom edge of C and add the result from above. */ \ - PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - beta_use, \ - c11, rs_c, cs_c ); \ - } \ - } \ - } \ - } \ -\ -/*PASTEMAC(ch,fprintm)( stdout, "gemm3m2_ker_var2: b1", k, NR, b1, NR, 1, "%4.1f", "" ); \ -PASTEMAC(ch,fprintm)( stdout, "gemm3m2_ker_var2: a1", MR, k, a1, 1, MR, "%4.1f", "" );*/ \ -} - -INSERT_GENTFUNC_BASIC0( gemm3m2_ker_var2 ) - diff --git a/frame/3/gemmt/bli_gemmt_front.c b/frame/3/gemmt/bli_gemmt_front.c index 3a80d80cff..21f7695f15 100644 --- a/frame/3/gemmt/bli_gemmt_front.c +++ b/frame/3/gemmt/bli_gemmt_front.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -54,10 +54,21 @@ void bli_gemmt_front obj_t b_local; obj_t c_local; + // If C has a zero dimension, return early. + if ( bli_obj_has_zero_dim( c ) ) + { + return; + } - // Check parameters. - if ( bli_error_checking_is_enabled() ) - bli_gemm_check( alpha, a, b, beta, c, cntx ); + // If alpha is zero, or if A or B has a zero dimension, scale C by beta + // and return early. + if ( bli_obj_equals( alpha, &BLIS_ZERO ) || + bli_obj_has_zero_dim( a ) || + bli_obj_has_zero_dim( b ) ) + { + bli_scalm( beta, c ); + return; + } // Alias A, B, and C in case we need to apply transformations. bli_obj_alias_to( a, &a_local ); @@ -86,20 +97,8 @@ void bli_gemmt_front //else // homogeneous datatypes #endif - // Load the pack schemas from the context and embed them into the objects - // for A and B. (Native contexts are initialized with the correct pack - // schemas, as are contexts for 1m, and if necessary bli_gemm_md() would - // have made a copy and modified the schemas, so reading them from the - // context should be a safe bet at this point.) This is a sort of hack for - // communicating the desired pack schemas to bli_gemm_cntl_create() (via - // bli_l3_thread_decorator() and bli_l3_cntl_create_if()). This allows us - // to subsequently access the schemas from the control tree, which - // hopefully reduces some confusion, particularly in bli_packm_init(). - const pack_t schema_a = bli_cntx_schema_a_block( cntx ); - const pack_t schema_b = bli_cntx_schema_b_panel( cntx ); - - bli_obj_set_pack_schema( schema_a, &a_local ); - bli_obj_set_pack_schema( schema_b, &b_local ); + // Set the pack schemas within the objects, as appropriate. + bli_l3_set_schemas( &a_local, &b_local, &c_local, cntx ); // Next, we handle the possibility of needing to typecast alpha to the // computation datatype and/or beta to the storage datatype of C. @@ -227,7 +226,7 @@ void bli_gemmt_front bli_obj_set_exec_dt( dt_exec, &ct ); bli_obj_set_comp_dt( dt_comp, &ct ); - // A naive approach would cast C to the comptuation datatype, + // A naive approach would cast C to the computation datatype, // compute with beta, and then cast the result back to the // user-provided output matrix. However, we employ a different // approach that halves the number of memops on C (or its diff --git a/frame/3/gemmt/bli_gemmt_sup_var1n2m_amd.c b/frame/3/gemmt/bli_gemmt_sup_var1n2m_amd.c index 912b043f70..ac1c85178a 100644 --- a/frame/3/gemmt/bli_gemmt_sup_var1n2m_amd.c +++ b/frame/3/gemmt/bli_gemmt_sup_var1n2m_amd.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -75,11 +75,31 @@ typedef void (*gemmt_ker_ft) cntx_t* restrict cntx ); +// these kernels are compiled as part of zen4 config +// use them only when BLIS_KERNELS_ZEN4 is defined +// Look-up table for Gemmt Upper Variant Kernels +#if defined(BLIS_KERNELS_ZEN4) +gemmt_ker_ft ker_fpus_zen4[3] = + { + bli_dgemmsup_rv_zen4_asm_24x8m_upper_0, + bli_dgemmsup_rv_zen4_asm_24x8m_upper_1, + bli_dgemmsup_rv_zen4_asm_24x8m_upper_2 + }; + +//Look-up table for Gemmt Lower Variant Kernels +gemmt_ker_ft ker_fpls_zen4[3] = + { + bli_dgemmsup_rv_zen4_asm_24x8m_lower_0, + bli_dgemmsup_rv_zen4_asm_24x8m_lower_1, + bli_dgemmsup_rv_zen4_asm_24x8m_lower_2 + }; +#endif + // these kernels are compiled as part of haswell config // use them only when BLIS_KERNELS_HASWELL is defined -#ifdef BLIS_KERNELS_HASWELL +#if defined(BLIS_KERNELS_HASWELL) //Look-up table for Gemmt Upper Variant Kernels -gemmt_ker_ft ker_fpus[14] = +gemmt_ker_ft ker_fpus_haswell[14] = { bli_dgemmsup_rv_haswell_asm_6x8m_0x0_U, bli_dgemmsup_rv_haswell_asm_6x8m_6x0_U, @@ -97,7 +117,7 @@ gemmt_ker_ft ker_fpus[14] = bli_dgemmsup_rd_haswell_asm_6x8m_0x0_combined_U}; //Look-up table for Gemmt Lower Variant Kernels -gemmt_ker_ft ker_fpls[14] = +gemmt_ker_ft ker_fpls_haswell[14] = { bli_dgemmsup_rv_haswell_asm_6x8m_0x0_L, bli_dgemmsup_rv_haswell_asm_6x8m_6x0_L, @@ -304,7 +324,6 @@ void bli_gemmtsup_ref_var1n AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_5); } - #undef GENTFUNC #define GENTFUNC( ctype, ch, opname, uplo, varname ) \ \ @@ -1929,143 +1948,36 @@ void PASTEMACT(ch,opname,uplo,varname) \ { \ const dim_t mr_cur = (i+MR-1) < mc_cur ? MR : mc_cur - i; \ \ - /* Prerequisites : MR = 6, NR = 8. - An optimization: allow the last jr iteration to contain up to NRE - In DGEMMT API implementation, kernel operates on 6x8 block. MR and - NR are set as 6 and 8 respectively. 24 being the LCM of 6 and 8, - the diagonal pattern repeats for every 24x24 block. - This pattern is exploited to achieve the optimization in diagonal - blocks by computing only the required elements. In the previous - implementation, all the 48 outputs of the given 6x8 block are - computed and stored into a temporary buffer. Later, the required - elements are copied into the final C output buffer. - With this optimization, we are avoiding copy operation and also - reducing the number of computations. - Variables m_off_24 and n_off_24 respectively store the m and n - offsets from the starting point of the corresponding 24x24 block. - Variables m_idx and n_idx store indices of the current 6x8 block - along m and n dimensions, in 24x24 block. m_idx is computed as - (m_off_24 / MR) while n_idx is computed as (n_off_24 / NR). - Range of m_idx is 0 <= m_idx <= 3 and the range of n_idx is - 0 <= n_idx <= 2. Based on these indices, for the given 6x8 block, - logic is implemented to identify the relevant kernel from the - look-up table. - During instances, where m is not a multiple of 6 or n is not a - multiple of 8, it goes to the default gemm kernel. MR and NR must be - 6 and 8 for these kernels to achieve the expected functionality.*/ \ -\ - dim_t m_off_24 = m_off_cblock % 24; \ - dim_t n_off_24 = n_off_cblock % 24; \ - dim_t m_idx = (dim_t)(m_off_24 / MR); \ - dim_t n_idx = (dim_t)(n_off_24 / NR); \ -\ - /* Check if m, n indices are multiple of MR and NR respectively - and current block is a complete 6x8 block */ \ - bool idx_supported = ((m_off_24 % MR) == 0) && ((n_off_24 % NR) == 0)\ - && (MR == 6) && (NR == 8) \ - && (bli_cpuid_is_avx2fma3_supported() == TRUE) && (mr_cur == MR) && (nr_cur == NR); \ -\ - /* m_idx and n_idx would be equal only if the current block is - a diagonal block */\ - if( (dt == BLIS_DOUBLE) && (m_idx == n_idx) && (idx_supported) ) { \ - /* index of kernel in lookup table is 2*m_idx) */ \ - dim_t ker_idx; \ - ker_idx = m_idx<<1; \ -\ - /* If there is another 6x8 diagonal block pending for computation - after the current 6x8 diagonal block, then the two blocks can - be computed together(12x8). This combined kernel is implemented - only for the case where n_idx = 2 i.e., n_off_24 = 16. To call - this, it has to be ensured that at least 12 rows are pending in - C for computation. (m_off + 2 * MR <=m). Usage of this combined - kernel saves the entire time to execute one kernel*/ \ - if( (n_idx == 2) && (m_off_cblock + MR + MR <= m) ) {\ - ker_idx = 6; /* use combined kernel, index of combined kernel - in lookup table is 6 */\ - } \ - /* use rd kernel if B is column major storage */ \ - if( stor_id == BLIS_RRC ) { \ - ker_idx += 7; /* index of rd kernel*/ \ - } \ - gemmt_ker_ft ker_fp = ker_fpls[ker_idx]; \ - ker_fp \ - ( \ - conja, \ - conjb, \ - mr_cur, \ - nr_cur, \ - kc_cur, \ - (double*) alpha_cast, \ - (double*) a_ir, rs_a_use, cs_a_use, \ - (double*) b_jr, rs_b_use, cs_b_use, \ - (double*) beta_use, \ - (double*) c_ir, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - /* 6x8 block where m_idx == n_idx+1 also has some parts of the diagonal */\ - else if( (dt == BLIS_DOUBLE) && (m_idx == n_idx+1) && (idx_supported) ) { \ - /* If current block was already computed in the combined kernel it - can be skipped combined kernel is only implemented for n_idx=2, - i == m_zero is only true for the first iteration therefore if - i == m_zero then the current 6x8 block was not computed in - combined kernel*/ \ - if( (n_idx != 2) || (i == m_zero) ) { \ - dim_t ker_idx = (n_idx << 1) + 1; \ - /* use rd kernel if B is column major storage */ \ - if( stor_id == BLIS_RRC ) { ker_idx += 7; } \ - gemmt_ker_ft ker_fp = ker_fpls[ker_idx]; \ - ker_fp \ - ( \ - conja, \ - conjb, \ - mr_cur, \ - nr_cur, \ - kc_cur, \ - (double*) alpha_cast, \ - (double*) a_ir, rs_a_use, cs_a_use, \ - (double*) b_jr, rs_b_use, cs_b_use, \ - (double*) beta_use, \ - (double*) c_ir, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ + gemmsup_ker \ + ( \ + conja, \ + conjb, \ + mr_cur, \ + nr_cur, \ + kc_cur, \ + alpha_cast, \ + a_ir, rs_a_use, cs_a_use, \ + b_jr, rs_b_use, cs_b_use, \ + zero, \ + ct, rs_ct, cs_ct, \ + &aux, \ + cntx \ + ); \ + if( col_pref ) \ + { \ + PASTEMAC(ch,update_upper_triang)( n_off_cblock, m_off_cblock, \ + nr_cur, mr_cur, \ + ct, cs_ct, rs_ct, \ + beta_use, \ + c_ir, cs_c, rs_c ); \ } \ - /* Call the regular kernel for non applicable cases */ \ - else { \ - gemmsup_ker \ - ( \ - conja, \ - conjb, \ - mr_cur, \ - nr_cur, \ - kc_cur, \ - alpha_cast, \ - a_ir, rs_a_use, cs_a_use, \ - b_jr, rs_b_use, cs_b_use, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ - if( col_pref ) \ - { \ - PASTEMAC(ch,update_upper_triang)( n_off_cblock, m_off_cblock, \ - nr_cur, mr_cur, \ - ct, cs_ct, rs_ct, \ - beta_use, \ - c_ir, cs_c, rs_c ); \ - } \ - else \ - { \ - PASTEMAC(ch,update_lower_triang)( m_off_cblock, n_off_cblock, \ - mr_cur, nr_cur, \ - ct, rs_ct, cs_ct, \ - beta_use, \ - c_ir, rs_c, cs_c ); \ - }\ + else \ + { \ + PASTEMAC(ch,update_lower_triang)( m_off_cblock, n_off_cblock, \ + mr_cur, nr_cur, \ + ct, rs_ct, cs_ct, \ + beta_use, \ + c_ir, rs_c, cs_c ); \ }\ \ a_ir += ps_a_use; \ @@ -2122,7 +2034,719 @@ PASTEMAC(ch,fprintm)( stdout, "gemmsup_ref_var2: c ", mr_cur, nr_cur, c_ir, rs_c */ \ } -INSERT_GENTFUNC_L( gemmtsup, ref_var2m ) +INSERT_GENTFUNC_L_SC( gemmtsup, ref_var2m ) + +/* DGEMMT SUP kernel */ +void bli_dgemmtsup_l_ref_var2m + ( + bool packa, + bool packb, + conj_t conja, + conj_t conjb, + dim_t m, + dim_t n, + dim_t k, + void* restrict alpha, + void* restrict a, inc_t rs_a, inc_t cs_a, + void* restrict b, inc_t rs_b, inc_t cs_b, + void* restrict beta, + void* restrict c, inc_t rs_c, inc_t cs_c, + stor3_t stor_id, + cntx_t* restrict cntx, + rntm_t* restrict rntm, + thrinfo_t* restrict thread + ) +{ + const num_t dt = PASTEMAC(d,type); + + double* restrict zero = PASTEMAC(d,0); + + /* If m or n is zero, return immediately. */ + if ( bli_zero_dim2( m, n ) ) return; + + /* If k < 1 or alpha is zero, scale by beta and return. */ + if ( k < 1 || PASTEMAC(d,eq0)( *(( double* )alpha) ) ) + { + if ( bli_thread_am_ochief( thread ) ) + { + PASTEMAC(d,scalm) + ( + BLIS_NO_CONJUGATE, + 0, + BLIS_NONUNIT_DIAG, + BLIS_DENSE, + m, n, + beta, + c, rs_c, cs_c + ); + } + return; + } + + /* Query the context for various blocksizes. */ + dim_t NR = bli_cntx_get_l3_sup_tri_blksz_def_dt( dt, BLIS_NR, cntx ); + dim_t MR = bli_cntx_get_l3_sup_tri_blksz_def_dt( dt, BLIS_MR, cntx ); + dim_t NC = bli_cntx_get_l3_sup_tri_blksz_def_dt( dt, BLIS_NC, cntx ); + dim_t MC = bli_cntx_get_l3_sup_tri_blksz_def_dt( dt, BLIS_MC, cntx ); + dim_t KC0 = bli_cntx_get_l3_sup_tri_blksz_def_dt( dt, BLIS_KC, cntx ); + /* Query the maximum blocksize for NR, which implies a maximum blocksize + extension for the final iteration. */ + dim_t NRM = bli_cntx_get_l3_sup_tri_blksz_max_dt( dt, BLIS_NR, cntx ); + + /* Query the context for the sup microkernel address and cast it to its + function pointer type. */ + PASTECH(d,gemmsup_ker_ft) + gemmsup_ker = bli_cntx_get_l3_sup_tri_ker_dt( dt, stor_id, cntx ); + + if( ( 0 == NR ) || ( 0 == MR ) || ( 0 == NC ) || ( 0 == MC ) || ( 0 == KC0 ) ) + { + NR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NR, cntx ); + MR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_MR, cntx ); + NC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NC, cntx ); + MC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_MC, cntx ); + KC0 = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_KC, cntx ); + NRM = bli_cntx_get_l3_sup_blksz_max_dt( dt, BLIS_NR, cntx ); + gemmsup_ker = bli_cntx_get_l3_sup_ker_dt( dt, stor_id, cntx ); + } + const dim_t NRE = NRM - NR; + + dim_t KC; + if ( packa && packb ) + { + KC = KC0; + } + else if ( packb ) + { + if ( stor_id == BLIS_RRR || + stor_id == BLIS_CCC ) KC = KC0; + else if ( stor_id == BLIS_RRC || + stor_id == BLIS_CRC ) KC = KC0; + else if ( stor_id == BLIS_RCR || + stor_id == BLIS_CCR ) KC = (( KC0 / 4 ) / 4 ) * 4; + else KC = KC0; + } + else if ( packa ) + { + if ( stor_id == BLIS_RRR || + stor_id == BLIS_CCC ) KC = (( KC0 / 2 ) / 2 ) * 2; + else if ( stor_id == BLIS_RRC || + stor_id == BLIS_CRC ) KC = KC0; + else if ( stor_id == BLIS_RCR || + stor_id == BLIS_CCR ) KC = (( KC0 / 4 ) / 4 ) * 4; + else KC = KC0; + } + else /* if ( !packa && !packb ) */ + { + if ( stor_id == BLIS_RRR || + stor_id == BLIS_CCC ) KC = KC0; + else if ( stor_id == BLIS_RRC || + stor_id == BLIS_CRC ) KC = KC0; + else if ( m <= MR && n <= NR ) KC = KC0; + else if ( m <= 2*MR && n <= 2*NR ) KC = KC0 / 2; + else if ( m <= 3*MR && n <= 3*NR ) KC = (( KC0 / 3 ) / 4 ) * 4; + else if ( m <= 4*MR && n <= 4*NR ) KC = KC0 / 4; + else KC = (( KC0 / 5 ) / 4 ) * 4; + } + + /* Compute partitioning step values for each matrix of each loop. */ + const inc_t jcstep_c = cs_c; + const inc_t jcstep_b = cs_b; + + const inc_t pcstep_a = cs_a; + const inc_t pcstep_b = rs_b; + + const inc_t icstep_c = rs_c; + const inc_t icstep_a = rs_a; + + const inc_t jrstep_c = cs_c * NR; + + const inc_t irstep_c = rs_c * MR; + + /* + const inc_t jrstep_b = cs_b * NR; + ( void )jrstep_b; + + const inc_t irstep_c = rs_c * MR; + const inc_t irstep_a = rs_a * MR; + */ + + double ct[ BLIS_STACK_BUF_MAX_SIZE / sizeof( double ) ] __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); + + /* storage-scheme of ct should be same as that of C. + Since update routines only support row-major order, + col_pref flag is used to induce transpose to matrices before + passing to update routine whenever C is col-stored */ + const bool col_pref = (rs_c == 1)? 1 : 0; + + const inc_t rs_ct = ( col_pref ? 1 : NR ); + const inc_t cs_ct = ( col_pref ? MR : 1 ); + + double* restrict a_00 = a; + double* restrict b_00 = b; + double* restrict c_00 = c; + double* restrict alpha_cast = alpha; + double* restrict beta_cast = beta; + + /* Make local copies of beta and one scalars to prevent any unnecessary + sharing of cache lines between the cores' caches. */ + double beta_local = *beta_cast; + double one_local = *PASTEMAC(d,1); + + auxinfo_t aux; + + /* Parse and interpret the contents of the rntm_t object to properly + set the ways of parallelism for each loop. */ + /*bli_rntm_set_ways_from_rntm_sup( m, n, k, rntm );*/ + + /* Initialize a mem_t entry for A and B. Strictly speaking, this is only + needed for the matrix we will be packing (if any), but we do it + unconditionally to be safe. An alternative way of initializing the + mem_t entries is: + + bli_mem_clear( &mem_a ); + bli_mem_clear( &mem_b ); + */ + mem_t mem_a = BLIS_MEM_INITIALIZER; + mem_t mem_b = BLIS_MEM_INITIALIZER; + + /* Define an array of bszid_t ids, which will act as our substitute for + the cntl_t tree. */ + /* 5thloop 4thloop packb 3rdloop packa 2ndloop 1stloop ukrloop */ + bszid_t bszids_nopack[6] = { BLIS_NC, BLIS_KC, BLIS_MC, BLIS_NR, BLIS_MR, BLIS_KR }; + bszid_t bszids_packa [7] = { BLIS_NC, BLIS_KC, BLIS_MC, BLIS_NO_PART, BLIS_NR, BLIS_MR, BLIS_KR }; + bszid_t bszids_packb [7] = { BLIS_NC, BLIS_KC, BLIS_NO_PART, BLIS_MC, BLIS_NR, BLIS_MR, BLIS_KR }; + bszid_t bszids_packab[8] = { BLIS_NC, BLIS_KC, BLIS_NO_PART, BLIS_MC, BLIS_NO_PART, BLIS_NR, BLIS_MR, BLIS_KR }; + bszid_t* restrict bszids; + + /* Set the bszids pointer to the correct bszids array above based on which + matrices (if any) are being packed. */ + if ( packa ) { if ( packb ) bszids = bszids_packab; + else bszids = bszids_packa; } + else { if ( packb ) bszids = bszids_packb; + else bszids = bszids_nopack; } + + /* Determine whether we are using more than one thread. */ + const bool is_mt = bli_rntm_calc_num_threads( rntm ); + + thrinfo_t* restrict thread_jc = NULL; + thrinfo_t* restrict thread_pc = NULL; + thrinfo_t* restrict thread_pb = NULL; + thrinfo_t* restrict thread_ic = NULL; + thrinfo_t* restrict thread_pa = NULL; + thrinfo_t* restrict thread_jr = NULL; + + /* Grow the thrinfo_t tree. */ + bszid_t* restrict bszids_jc = bszids; + thread_jc = thread; + bli_thrinfo_sup_grow( rntm, bszids_jc, thread_jc ); + + /* Compute the JC loop thread range for the current thread. */ + dim_t jc_start, jc_end; + bli_thread_range_weighted_sub( thread_jc, 0, BLIS_LOWER, m, n, NR, FALSE, &jc_start, &jc_end ); + const dim_t n_local = jc_end - jc_start; + + /* Compute number of primary and leftover components of the JC loop. */ + /*const dim_t jc_iter = ( n_local + NC - 1 ) / NC;*/ + const dim_t jc_left = n_local % NC; + + dim_t m_off_cblock, n_off_cblock; + dim_t m_off = 0; + dim_t n_off = 0; + doff_t diagoffc; + dim_t i, ip; + + /* Loop over the n dimension (NC rows/columns at a time). */ + /*for ( dim_t jj = 0; jj < jc_iter; jj += 1 )*/ + for ( dim_t jj = jc_start; jj < jc_end; jj += NC ) + { + /* Calculate the thread's current JC block dimension. */ + const dim_t nc_cur = ( NC <= jc_end - jj ? NC : jc_left ); + + double* restrict b_jc = b_00 + jj * jcstep_b; + double* restrict c_jc = c_00 + jj * jcstep_c; + + /* Grow the thrinfo_t tree. */ + bszid_t* restrict bszids_pc = &bszids_jc[1]; + thread_pc = bli_thrinfo_sub_node( thread_jc ); + bli_thrinfo_sup_grow( rntm, bszids_pc, thread_pc ); + + /* Compute the PC loop thread range for the current thread. */ + const dim_t pc_start = 0, pc_end = k; + const dim_t k_local = k; + + /* Compute number of primary and leftover components of the PC loop. */ + /*const dim_t pc_iter = ( k_local + KC - 1 ) / KC;*/ + const dim_t pc_left = k_local % KC; + + /* Loop over the k dimension (KC rows/columns at a time). */ + /*for ( dim_t pp = 0; pp < pc_iter; pp += 1 )*/ + for ( dim_t pp = pc_start; pp < pc_end; pp += KC ) + { + /* Calculate the thread's current PC block dimension. */ + const dim_t kc_cur = ( KC <= pc_end - pp ? KC : pc_left ); + + double* restrict a_pc = a_00 + pp * pcstep_a; + double* restrict b_pc = b_jc + pp * pcstep_b; + + /* Only apply beta to the first iteration of the pc loop. */ + double* restrict beta_use = ( pp == 0 ? &beta_local : &one_local ); + + m_off = 0; + n_off = jj; + diagoffc = m_off - n_off; + + double* b_use; + inc_t rs_b_use, cs_b_use, ps_b_use; + + /* Set the bszid_t array and thrinfo_t pointer based on whether + we will be packing B. If we won't be packing B, we alias to + the _pc variables so that code further down can unconditionally + reference the _pb variables. Note that *if* we will be packing + B, the thrinfo_t node will have already been created by a + previous call to bli_thrinfo_grow(), since bszid values of + BLIS_NO_PART cause the tree to grow by two (e.g. to the next + bszid that is a normal bszid_t value). */ + bszid_t* restrict bszids_pb; + if ( packb ) { bszids_pb = &bszids_pc[1]; + thread_pb = bli_thrinfo_sub_node( thread_pc ); } + else { bszids_pb = &bszids_pc[0]; + thread_pb = thread_pc; } + + /* Determine the packing buffer and related parameters for matrix + B. (If B will not be packed, then a_use will be set to point to + b and the _b_use strides will be set accordingly.) Then call + the packm sup variant chooser, which will call the appropriate + implementation based on the schema deduced from the stor_id. */ + PASTEMAC(d,packm_sup_b) + ( + packb, + BLIS_BUFFER_FOR_B_PANEL, /* This algorithm packs matrix B to */ + stor_id, /* a "panel of B." */ + BLIS_NO_TRANSPOSE, + KC, NC, /* This "panel of B" is (at most) KC x NC. */ + kc_cur, nc_cur, NR, + &one_local, + b_pc, rs_b, cs_b, + &b_use, &rs_b_use, &cs_b_use, + &ps_b_use, + cntx, + rntm, + &mem_b, + thread_pb + ); + + /* Alias a_use so that it's clear this is our current block of + matrix B. */ + double* restrict b_pc_use = b_use; + + /* We don't need to embed the panel stride of B within the auxinfo_t + object because this variant iterates through B in the jr loop, + which occurs here, within the macrokernel, not within the + millikernel. */ + /*bli_auxinfo_set_ps_b( ps_b_use, &aux );*/ + + /* Grow the thrinfo_t tree. */ + bszid_t* restrict bszids_ic = &bszids_pb[1]; + thread_ic = bli_thrinfo_sub_node( thread_pb ); + bli_thrinfo_sup_grow( rntm, bszids_ic, thread_ic ); + + /* Compute the IC loop thread range for the current thread. */ + dim_t ic_start, ic_end; + bli_thread_range_weighted_sub( thread_ic, -diagoffc, BLIS_UPPER, nc_cur, m, MR, FALSE, &ic_start, &ic_end ); + const dim_t m_local = ic_end - ic_start; + + /* Compute number of primary and leftover components of the IC loop. */ + /*const dim_t ic_iter = ( m_local + MC - 1 ) / MC;*/ + const dim_t ic_left = m_local % MC; + + /* Loop over the m dimension (MC rows at a time). */ + /*for ( dim_t ii = 0; ii < ic_iter; ii += 1 )*/ + for ( dim_t ii = ic_start; ii < ic_end; ii += MC ) + { + /* Calculate the thread's current IC block dimension. */ + dim_t mc_cur = ( MC <= ic_end - ii ? MC : ic_left ); + dim_t nc_pruned = nc_cur; + + double* restrict a_ic = a_pc + ii * icstep_a; + double* restrict c_ic = c_jc + ii * icstep_c; + + m_off = ii; + + if(bli_gemmt_is_strictly_above_diag( m_off, n_off, mc_cur, nc_cur ) ) continue; + + diagoffc = m_off - n_off; + + if( diagoffc < 0 ) + { + ip = -diagoffc / MR; + i = ip * MR; + mc_cur = mc_cur - i; + diagoffc = -diagoffc % MR; + m_off += i; + c_ic = c_ic + ( i ) * rs_c; + a_ic = a_ic + ( i ) * rs_a; + } + + if( ( diagoffc + mc_cur ) < nc_cur ) + { + nc_pruned = diagoffc + mc_cur; + } + + double* a_use; + inc_t rs_a_use, cs_a_use, ps_a_use; + + /* Set the bszid_t array and thrinfo_t pointer based on whether + we will be packing B. If we won't be packing A, we alias to + the _ic variables so that code further down can unconditionally + reference the _pa variables. Note that *if* we will be packing + A, the thrinfo_t node will have already been created by a + previous call to bli_thrinfo_grow(), since bszid values of + BLIS_NO_PART cause the tree to grow by two (e.g. to the next + bszid that is a normal bszid_t value). */ + bszid_t* restrict bszids_pa; + if ( packa ) { bszids_pa = &bszids_ic[1]; + thread_pa = bli_thrinfo_sub_node( thread_ic ); } + else { bszids_pa = &bszids_ic[0]; + thread_pa = thread_ic; } + + /* Determine the packing buffer and related parameters for matrix + A. (If A will not be packed, then a_use will be set to point to + a and the _a_use strides will be set accordingly.) Then call + the packm sup variant chooser, which will call the appropriate + implementation based on the schema deduced from the stor_id. */ + PASTEMAC(d,packm_sup_a) + ( + packa, + BLIS_BUFFER_FOR_A_BLOCK, /* This algorithm packs matrix A to */ + stor_id, /* a "block of A." */ + BLIS_NO_TRANSPOSE, + MC, KC, /* This "block of A" is (at most) MC x KC. */ + mc_cur, kc_cur, MR, + &one_local, + a_ic, rs_a, cs_a, + &a_use, &rs_a_use, &cs_a_use, + &ps_a_use, + cntx, + rntm, + &mem_a, + thread_pa + ); + + /* Alias a_use so that it's clear this is our current block of + matrix A. */ + double* restrict a_ic_use = a_use; + + /* Embed the panel stride of A within the auxinfo_t object. The + millikernel will query and use this to iterate through + micropanels of A (if needed). */ + bli_auxinfo_set_ps_a( ps_a_use, &aux ); + + /* Grow the thrinfo_t tree. */ + bszid_t* restrict bszids_jr = &bszids_pa[1]; + thread_jr = bli_thrinfo_sub_node( thread_pa ); + bli_thrinfo_sup_grow( rntm, bszids_jr, thread_jr ); + + /* Compute number of primary and leftover components of the JR loop. */ + dim_t jr_iter = ( nc_pruned + NR - 1 ) / NR; + dim_t jr_left = nc_pruned % NR; + + /* Compute the JR loop thread range for the current thread. */ + dim_t jr_start, jr_end; + bli_thread_range_sub( thread_jr, jr_iter, 1, FALSE, &jr_start, &jr_end ); + + /* An optimization: allow the last jr iteration to contain up to NRE + columns of C and B. (If NRE > NR, the mkernel has agreed to handle + these cases.) Note that this prevents us from declaring jr_iter and + jr_left as const. NOTE: We forgo this optimization when packing B + since packing an extended edge case is not yet supported. */ + if ( !packb && !is_mt ) + if ( NRE != 0 && 1 < jr_iter && jr_left != 0 && jr_left <= NRE ) + { + jr_iter--; jr_left += NR; + } + + /* Loop over the n dimension (NR columns at a time). */ + /*for ( dim_t j = 0; j < jr_iter; j += 1 )*/ + for ( dim_t j = jr_start; j < jr_end; j += 1 ) + { + const dim_t nr_cur = ( bli_is_not_edge_f( j, jr_iter, jr_left ) ? NR : jr_left ); + + /* + double* restrict b_jr = b_pc_use + j * jrstep_b; + */ + double* restrict b_jr = b_pc_use + j * ps_b_use; + double* restrict c_jr = c_ic + j * jrstep_c; + + dim_t i; + dim_t m_zero = 0; + dim_t n_iter_zero = 0; + + m_off_cblock = m_off; + n_off_cblock = n_off + j * NR; + + if(bli_gemmt_is_strictly_below_diag(m_off_cblock, n_off_cblock, mc_cur, nc_cur)) + { + m_zero = 0; + } + else + { + /* compute number of rows that are filled with zeroes and can be ignored */ + n_iter_zero = (n_off_cblock < m_off_cblock)? 0 : (n_off_cblock - m_off)/MR; + m_zero = n_iter_zero * MR; + } + + double* restrict a_ir = a_ic_use + n_iter_zero * ps_a_use; + double* restrict c_ir = c_jr + n_iter_zero * irstep_c; + + /* Ignore the zero region */ + m_off_cblock += m_zero; + + /* Compute the triangular part */ + for( i = m_zero; (i < mc_cur) && ( m_off_cblock < n_off_cblock + nr_cur); i += MR ) + { + const dim_t mr_cur = (i+MR-1) < mc_cur ? MR : mc_cur - i; + dim_t m_off_24 = m_off_cblock % 24; + dim_t n_off_24 = n_off_cblock % 24; + dim_t m_idx = (dim_t)(m_off_24 / MR); + dim_t n_idx = (dim_t)(n_off_24 / NR); + #ifdef BLIS_KERNELS_ZEN4 + if ( (MR == 24) && (NR == 8) && bli_cpuid_is_avx512_supported() && + (stor_id != BLIS_CRC && stor_id != BLIS_RRC) && + // verify if micro panel intersects with diagonal + // if distance from diagonal (n_off_cblock - m_off_cblock) is greater + // than (LCM(MR, NR) - NR) then it implies that micro panel is far + // from diagonal therefore it does not intersect with it. + (n_off_cblock - m_off_cblock) <= 16 // (n_off_cblock - m_off_cblock) <= (LCM(MR, NR) - NR) + ) + { + /* + call traingular 24x8 DGEMMT kernels + */ + // Difference between n_off_cblock and m_off_cblock is same as + // the size of empty region before diagonal region. + // kernel_idx = 0 is used when empty region size <= 0 + // kernel_idx = 1 is used when empty region size <= 8 + // kernel_idx = 2 is used when empty region size <= 16 + ker_fpls_zen4[(n_off_cblock - m_off_cblock)/NR] + ( + conja, + conjb, + mr_cur, + nr_cur, + kc_cur, + (double*) alpha_cast, + (double*) a_ir, rs_a_use, cs_a_use, + (double*) b_jr, rs_b_use, cs_b_use, + (double*) beta_use, + (double*) c_ir, rs_c, cs_c, + &aux, + cntx + ); + a_ir += ps_a_use; + c_ir += irstep_c; + m_off_cblock += mr_cur; + continue; + } + #endif + #ifdef BLIS_KERNELS_HASWELL + /* Prerequisites : MR = 6, NR = 8. + An optimization: allow the last jr iteration to contain up to NRE + In DGEMMT API implementation, kernel operates on 6x8 block. MR and + NR are set as 6 and 8 respectively. 24 being the LCM of 6 and 8, + the diagonal pattern repeats for every 24x24 block. + This pattern is exploited to achieve the optimization in diagonal + blocks by computing only the required elements. In the previous + implementation, all the 48 outputs of the given 6x8 block are + computed and stored into a temporary buffer. Later, the required + elements are copied into the final C output buffer. + With this optimization, we are avoiding copy operation and also + reducing the number of computations. + Variables m_off_24 and n_off_24 respectively store the m and n + offsets from the starting point of the corresponding 24x24 block. + Variables m_idx and n_idx store indices of the current 6x8 block + along m and n dimensions, in 24x24 block. m_idx is computed as + (m_off_24 / MR) while n_idx is computed as (n_off_24 / NR). + Range of m_idx is 0 <= m_idx <= 3 and the range of n_idx is + 0 <= n_idx <= 2. Based on these indices, for the given 6x8 block, + logic is implemented to identify the relevant kernel from the + look-up table. + During instances, where m is not a multiple of 6 or n is not a + multiple of 8, it goes to the default gemm kernel. MR and NR must be + 6 and 8 for these kernels to achieve the expected functionality.*/ + + + /* Check if m, n indices are multiple of MR and NR respectively + and current block is a complete 6x8 block */ + bool idx_supported = ((m_off_24 % MR) == 0) && ((n_off_24 % NR) == 0) + && (MR == 6) && (NR == 8) + && (bli_cpuid_is_avx2fma3_supported() == TRUE) && (mr_cur == MR) && (nr_cur == NR); + + /* m_idx and n_idx would be equal only if the current block is + a diagonal block */ + if( (dt == BLIS_DOUBLE) && (m_idx == n_idx) && (idx_supported) ) + { + /* index of kernel in lookup table is 2*m_idx) */ + dim_t ker_idx; + ker_idx = m_idx<<1; + + /* If there is another 6x8 diagonal block pending for computation + after the current 6x8 diagonal block, then the two blocks can + be computed together(12x8). This combined kernel is implemented + only for the case where n_idx = 2 i.e., n_off_24 = 16. To call + this, it has to be ensured that at least 12 rows are pending in + C for computation. (m_off + 2 * MR <=m). Usage of this combined + kernel saves the entire time to execute one kernel*/ + if( (n_idx == 2) && (m_off_cblock + MR + MR <= m) ) { + ker_idx = 6; /* use combined kernel, index of combined kernel + in lookup table is 6 */ + } + /* use rd kernel if B is column major storage */ + if( stor_id == BLIS_RRC ) { + ker_idx += 7; /* index of rd kernel*/ + } + gemmt_ker_ft ker_fp = ker_fpls_haswell[ker_idx]; + ker_fp + ( + conja, + conjb, + mr_cur, + nr_cur, + kc_cur, + (double*) alpha_cast, + (double*) a_ir, rs_a_use, cs_a_use, + (double*) b_jr, rs_b_use, cs_b_use, + (double*) beta_use, + (double*) c_ir, rs_c, cs_c, + &aux, + cntx + ); + a_ir += ps_a_use; + c_ir += irstep_c; + m_off_cblock += mr_cur; + continue; + } + /* 6x8 block where m_idx == n_idx+1 also has some parts of the diagonal */ + else if ( (dt == BLIS_DOUBLE) && (m_idx == n_idx+1) && (idx_supported) ) + { + /* If current block was already computed in the combined kernel it + can be skipped combined kernel is only implemented for n_idx=2, + i == m_zero is only true for the first iteration therefore if + i == m_zero then the current 6x8 block was not computed in + combined kernel + */ + if ((n_idx != 2) || (i == m_zero)) + { + dim_t ker_idx = (n_idx << 1) + 1; + /* use rd kernel if B is column major storage */ + if( stor_id == BLIS_RRC ) { ker_idx += 7; } + gemmt_ker_ft ker_fp = ker_fpls_haswell[ker_idx]; + ker_fp + ( + conja, + conjb, + mr_cur, + nr_cur, + kc_cur, + (double*) alpha_cast, + (double*) a_ir, rs_a_use, cs_a_use, + (double*) b_jr, rs_b_use, cs_b_use, + (double*) beta_use, + (double*) c_ir, rs_c, cs_c, + &aux, + cntx + ); + } + a_ir += ps_a_use; + c_ir += irstep_c; + m_off_cblock += mr_cur; + continue; + } + #endif + gemmsup_ker + ( + conja, + conjb, + mr_cur, + nr_cur, + kc_cur, + alpha_cast, + a_ir, rs_a_use, cs_a_use, + b_jr, rs_b_use, cs_b_use, + zero, + ct, rs_ct, cs_ct, + &aux, + cntx + ); + if( col_pref ) + { + PASTEMAC(d,update_upper_triang)( n_off_cblock, m_off_cblock, + nr_cur, mr_cur, + ct, cs_ct, rs_ct, + beta_use, + c_ir, cs_c, rs_c ); + } + else + { + PASTEMAC(d,update_lower_triang)( m_off_cblock, n_off_cblock, + mr_cur, nr_cur, + ct, rs_ct, cs_ct, + beta_use, + c_ir, rs_c, cs_c ); + } + + a_ir += ps_a_use; + c_ir += irstep_c; + m_off_cblock += mr_cur; + } + + /* Invoke the gemmsup millikernel for remaining rectangular part. */ + gemmsup_ker + ( + conja, + conjb, + (i > mc_cur)? 0: mc_cur - i, + nr_cur, + kc_cur, + alpha_cast, + a_ir, rs_a_use, cs_a_use, + b_jr, rs_b_use, cs_b_use, + beta_use, + c_ir, rs_c, cs_c, + &aux, + cntx + ); + + } + } + + /* NOTE: This barrier is only needed if we are packing B (since + that matrix is packed within the pc loop of this variant). */ + if ( packb ) bli_thread_barrier( thread_pb ); + } + } + + /* Release any memory that was acquired for packing matrices A and B. */ + PASTEMAC(d,packm_sup_finalize_mem_a) + ( + packa, + rntm, + &mem_a, + thread_pa + ); + PASTEMAC(d,packm_sup_finalize_mem_b) + ( + packb, + rntm, + &mem_b, + thread_pb + ); + +/* +PASTEMAC(d,fprintm)( stdout, "gemmsup_ref_var2: b1", kc_cur, nr_cur, b_jr, rs_b, cs_b, "%4.1f", "" ); +PASTEMAC(d,fprintm)( stdout, "gemmsup_ref_var2: a1", mr_cur, kc_cur, a_ir, rs_a, cs_a, "%4.1f", "" ); +PASTEMAC(d,fprintm)( stdout, "gemmsup_ref_var2: c ", mr_cur, nr_cur, c_ir, rs_c, cs_c, "%4.1f", "" ); +*/ +} #undef GENTFUNC #define GENTFUNC( ctype, ch, opname, uplo, varname ) \ @@ -2621,108 +3245,7 @@ void PASTEMACT(ch,opname,uplo,varname) \ for( dim_t i = m_rect;( i < mc_cur) && (m_off_cblock < n_off_cblock + nr_cur); i += MR ) \ { \ const dim_t mr_cur = (i+MR-1) < mc_cur ? MR : mc_cur - i; \ - /* Prerequisites : MR = 6, NR = 8. - An optimization: allow the last jr iteration to contain up to NRE - In DGEMMT API implementation, kernel operates on 6x8 block. MR and - NR are set as 6 and 8 respectively. 24 being the LCM of 6 and 8, - the diagonal pattern repeats for every 24x24 block. - This pattern is exploited to achieve the optimization in diagonal - blocks by computing only the required elements. In the previous - implementation, all the 48 outputs of the given 6x8 block are - computed and stored into a temporary buffer. Later, the required - elements are copied into the final C output buffer. - With this optimization, we are avoiding copy operation and also - reducing the number of computations. - Variables m_off_24 and n_off_24 respectively store the m and n - offsets from the starting point of the corresponding 24x24 block. - Variables m_idx and n_idx store indices of the current 6x8 block - along m and n dimensions, in 24x24 block. m_idx is computed as - (m_off_24 / MR) while n_idx is computed as (n_off_24 / NR). - Range of m_idx is 0 <= m_idx <= 3 and the range of n_idx is - 0 <= n_idx <= 2. Based on these indices, for the given 6x8 block, - logic is implemented to identify the relevant kernel from the - look-up table. - During instances, where m is not a multiple of 6 or n is not a - multiple of 8, it goes to the default gemm kernel. MR and NR must be - 6 and 8 for these kernels to achieve the expected functionality.*/ \ - dim_t m_off_24 = m_off_cblock % 24; \ - dim_t n_off_24 = n_off_cblock % 24; \ - dim_t m_idx = (dim_t)(m_off_24 / MR); \ - dim_t n_idx = (dim_t)(n_off_24 / NR); \ -\ - /* Check if m, n indices are multiple of MR and NR respectively - and current block is a complete 6x8 block */ \ - bool idx_supported = ((m_off_24 % MR) == 0) && ((n_off_24 % NR) == 0)\ - && (MR == 6) && (NR == 8) \ - && (bli_cpuid_is_avx2fma3_supported() == TRUE) && (mr_cur==MR) && (nr_cur==NR); \ -\ - /* m_idx and n_idx would be equal only if the current block is - a diagonal block */\ - if( (dt == BLIS_DOUBLE) && (m_idx == n_idx) && idx_supported ) { \ - dim_t ker_idx = m_idx<<1; \ - /* If there is another 6x8 diagonal block pending for computation - after the current 6x8 diagonal block, then the two blocks can - be computed together(12x8). This combined kernel is implemented - only for the case where n_idx = 0 i.e., n_off_24 = 0. To call - this, it has to be ensured that at least 12 rows are pending in - C for computation (i+ MR + MR <= mc_cur). Usage of this combined - kernel saves the entire time to execute one kernel*/ \ - if( (n_idx == 0) && (i+ MR + MR <= mc_cur) ) { \ - ker_idx = 6; /* use combined kernel, index of combined kernel - in lookup table is 6 */\ - } \ - /* if B is column storage we use rd kernel*/ \ - if( stor_id == BLIS_RRC ) { \ - ker_idx += 7; /* index of rd kernel*/\ - } \ - gemmt_ker_ft ker_fp = ker_fpus[ker_idx]; \ - ker_fp \ - ( \ - conja, \ - conjb, \ - mr_cur, \ - nr_cur, \ - kc_cur, \ - (double*) alpha_cast, \ - (double*) a_ir, rs_a_use, cs_a_use, \ - (double*) b_jr, rs_b_use, cs_b_use, \ - (double*) beta_use, \ - (double*) c_ir, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - /* 6x8 block where m_idx == n_idx+1 also has some parts of the diagonal */\ - else if( (dt == BLIS_DOUBLE) && (m_idx == n_idx+1) && (idx_supported) ) { \ - /* If current block was already computed in the combined kernel it - can be skipped combined kernel is only implemented for n_idx=0, - i == m_rect is only true for the first iteration therefore if - i == m_rect then the current 6x8 block was not computed in - combined kernel*/ \ - if( (n_idx != 0) || (i == m_rect) ) { \ - dim_t ker_idx = (n_idx << 1) + 1 ; \ - /* use rd kernel if B is column major storage */ \ - if( stor_id == BLIS_RRC ) { ker_idx += 7; } \ - gemmt_ker_ft ker_fp = ker_fpus[ker_idx]; \ - ker_fp \ - ( \ - conja, \ - conjb, \ - mr_cur, \ - nr_cur, \ - kc_cur, \ - (double*) alpha_cast, \ - (double*) a_ir, rs_a_use, cs_a_use, \ - (double*) b_jr, rs_b_use, cs_b_use, \ - (double*) beta_use, \ - (double*) c_ir, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - } \ - /* call the regular kernel for non applicable cases */ \ - else { \ + { \ gemmsup_ker \ ( \ conja, \ @@ -2794,5 +3317,1912 @@ PASTEMAC(ch,fprintm)( stdout, "gemmsup_ref_var2: c ", mr_cur, nr_cur, c_ir, rs_c */ \ } -INSERT_GENTFUNC_U( gemmtsup, ref_var2m ) +INSERT_GENTFUNC_U_SC( gemmtsup, ref_var2m ) +void bli_dgemmtsup_u_ref_var2m + ( + bool packa, + bool packb, + conj_t conja, + conj_t conjb, + dim_t m, + dim_t n, + dim_t k, + void* restrict alpha, + void* restrict a, inc_t rs_a, inc_t cs_a, + void* restrict b, inc_t rs_b, inc_t cs_b, + void* restrict beta, + void* restrict c, inc_t rs_c, inc_t cs_c, + stor3_t stor_id, + cntx_t* restrict cntx, + rntm_t* restrict rntm, + thrinfo_t* restrict thread + ) +{ + const num_t dt = PASTEMAC(d,type); + + double* restrict zero = PASTEMAC(d,0); + + /* If m or n is zero, return immediately. */ + if ( bli_zero_dim2( m, n ) ) return; + + /* If k < 1 or alpha is zero, scale by beta and return. */ + if ( k < 1 || PASTEMAC(d,eq0)( *(( double* )alpha) ) ) + { + if ( bli_thread_am_ochief( thread ) ) + { + PASTEMAC(d,scalm) + ( + BLIS_NO_CONJUGATE, + 0, + BLIS_NONUNIT_DIAG, + BLIS_DENSE, + m, n, + beta, + c, rs_c, cs_c + ); + } + return; + } + + /* Query the context for various blocksizes. */ + dim_t NR = bli_cntx_get_l3_sup_tri_blksz_def_dt( dt, BLIS_NR, cntx ); + dim_t MR = bli_cntx_get_l3_sup_tri_blksz_def_dt( dt, BLIS_MR, cntx ); + dim_t NC = bli_cntx_get_l3_sup_tri_blksz_def_dt( dt, BLIS_NC, cntx ); + dim_t MC = bli_cntx_get_l3_sup_tri_blksz_def_dt( dt, BLIS_MC, cntx ); + dim_t KC0 = bli_cntx_get_l3_sup_tri_blksz_def_dt( dt, BLIS_KC, cntx ); + + /* Query the maximum blocksize for NR, which implies a maximum blocksize + extension for the final iteration. */ + dim_t NRM = bli_cntx_get_l3_sup_tri_blksz_max_dt( dt, BLIS_NR, cntx ); + + /* Query the context for the sup microkernel address and cast it to its + function pointer type. */ + PASTECH(d,gemmsup_ker_ft) + gemmsup_ker = bli_cntx_get_l3_sup_tri_ker_dt( dt, stor_id, cntx ); + + if( ( 0 == NR ) || ( 0 == MR ) || ( 0 == NC ) || ( 0 == MC ) || ( 0 == KC0 ) ) + { + NR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NR, cntx ); + MR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_MR, cntx ); + NC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NC, cntx ); + MC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_MC, cntx ); + KC0 = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_KC, cntx ); + NRM = bli_cntx_get_l3_sup_blksz_max_dt( dt, BLIS_NR, cntx ); + gemmsup_ker = bli_cntx_get_l3_sup_ker_dt( dt, stor_id, cntx ); + } + const dim_t NRE = NRM - NR; + + dim_t KC; + if ( packa && packb ) + { + KC = KC0; + } + else if ( packb ) + { + if ( stor_id == BLIS_RRR || + stor_id == BLIS_CCC ) KC = KC0; + else if ( stor_id == BLIS_RRC || + stor_id == BLIS_CRC ) KC = KC0; + else if ( stor_id == BLIS_RCR || + stor_id == BLIS_CCR ) KC = (( KC0 / 4 ) / 4 ) * 4; + else KC = KC0; + } + else if ( packa ) + { + if ( stor_id == BLIS_RRR || + stor_id == BLIS_CCC ) KC = (( KC0 / 2 ) / 2 ) * 2; + else if ( stor_id == BLIS_RRC || + stor_id == BLIS_CRC ) KC = KC0; + else if ( stor_id == BLIS_RCR || + stor_id == BLIS_CCR ) KC = (( KC0 / 4 ) / 4 ) * 4; + else KC = KC0; + } + else /* if ( !packa && !packb ) */ + { + if ( stor_id == BLIS_RRR || + stor_id == BLIS_CCC ) KC = KC0; + else if ( stor_id == BLIS_RRC || + stor_id == BLIS_CRC ) KC = KC0; + else if ( stor_id == BLIS_RCR ) + { + if ( m <= 4*MR ) KC = KC0; + else if ( m <= 36*MR ) KC = KC0 / 2; + else if ( m <= 56*MR ) KC = (( KC0 / 3 ) / 4 ) * 4; + else KC = KC0 / 4; + } + else if ( m <= MR && n <= NR ) KC = KC0; + else if ( m <= 2*MR && n <= 2*NR ) KC = KC0 / 2; + else if ( m <= 3*MR && n <= 3*NR ) KC = (( KC0 / 3 ) / 4 ) * 4; + else if ( m <= 4*MR && n <= 4*NR ) KC = KC0 / 4; + else KC = (( KC0 / 5 ) / 4 ) * 4; + } + + /* Compute partitioning step values for each matrix of each loop. */ + const inc_t jcstep_c = cs_c; + const inc_t jcstep_b = cs_b; + + const inc_t pcstep_a = cs_a; + const inc_t pcstep_b = rs_b; + + const inc_t icstep_c = rs_c; + const inc_t icstep_a = rs_a; + + const inc_t jrstep_c = cs_c * NR; + + const inc_t irstep_c = rs_c * MR; + + /* + const inc_t jrstep_b = cs_b * NR; + ( void )jrstep_b; + + const inc_t irstep_c = rs_c * MR; + const inc_t irstep_a = rs_a * MR; + */ + + double ct[ BLIS_STACK_BUF_MAX_SIZE / sizeof( double ) ] __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); + + /* Storage scheme of ct should be same as that of C. + Since update routines only support row-major order, + col_pref flag is used to induce transpose to matrices before + passing to update routine whenever C is col-stored */ + const bool col_pref = (rs_c == 1) ? 1 : 0; + + const inc_t rs_ct = ( col_pref ? 1 : NR ); + const inc_t cs_ct = ( col_pref ? MR : 1 ); + + double* restrict a_00 = a; + double* restrict b_00 = b; + double* restrict c_00 = c; + double* restrict alpha_cast = alpha; + double* restrict beta_cast = beta; + + /* Make local copies of beta and one scalars to prevent any unnecessary + sharing of cache lines between the cores' caches. */ + double beta_local = *beta_cast; + double one_local = *PASTEMAC(d,1); + + auxinfo_t aux; + + /* Parse and interpret the contents of the rntm_t object to properly + set the ways of parallelism for each loop. */ + /*bli_rntm_set_ways_from_rntm_sup( m, n, k, rntm );*/ + + /* Initialize a mem_t entry for A and B. Strictly speaking, this is only + needed for the matrix we will be packing (if any), but we do it + unconditionally to be safe. An alternative way of initializing the + mem_t entries is: + + bli_mem_clear( &mem_a ); + bli_mem_clear( &mem_b ); + */ + mem_t mem_a = BLIS_MEM_INITIALIZER; + mem_t mem_b = BLIS_MEM_INITIALIZER; + + /* Define an array of bszid_t ids, which will act as our substitute for + the cntl_t tree. */ + /* 5thloop 4thloop packb 3rdloop packa 2ndloop 1stloop ukrloop */ + bszid_t bszids_nopack[6] = { BLIS_NC, BLIS_KC, BLIS_MC, BLIS_NR, BLIS_MR, BLIS_KR }; + bszid_t bszids_packa [7] = { BLIS_NC, BLIS_KC, BLIS_MC, BLIS_NO_PART, BLIS_NR, BLIS_MR, BLIS_KR }; + bszid_t bszids_packb [7] = { BLIS_NC, BLIS_KC, BLIS_NO_PART, BLIS_MC, BLIS_NR, BLIS_MR, BLIS_KR }; + bszid_t bszids_packab[8] = { BLIS_NC, BLIS_KC, BLIS_NO_PART, BLIS_MC, BLIS_NO_PART, BLIS_NR, BLIS_MR, BLIS_KR }; + bszid_t* restrict bszids; + + /* Set the bszids pointer to the correct bszids array above based on which + matrices (if any) are being packed. */ + if ( packa ) { if ( packb ) bszids = bszids_packab; + else bszids = bszids_packa; } + else { if ( packb ) bszids = bszids_packb; + else bszids = bszids_nopack; } + + /* Determine whether we are using more than one thread. */ + const bool is_mt = bli_rntm_calc_num_threads( rntm ); + + thrinfo_t* restrict thread_jc = NULL; + thrinfo_t* restrict thread_pc = NULL; + thrinfo_t* restrict thread_pb = NULL; + thrinfo_t* restrict thread_ic = NULL; + thrinfo_t* restrict thread_pa = NULL; + thrinfo_t* restrict thread_jr = NULL; + + /* Grow the thrinfo_t tree. */ + bszid_t* restrict bszids_jc = bszids; + thread_jc = thread; + bli_thrinfo_sup_grow( rntm, bszids_jc, thread_jc ); + + /* Compute the JC loop thread range for the current thread. */ + dim_t jc_start, jc_end; + bli_thread_range_weighted_sub( thread_jc, 0, BLIS_UPPER, m, n, NR, FALSE, &jc_start, &jc_end ); + const dim_t n_local = jc_end - jc_start; + + dim_t m_off = 0; + dim_t n_off = 0; + doff_t diagoffc; + dim_t m_off_cblock, n_off_cblock; + dim_t jp, j; + + /* Compute number of primary and leftover components of the JC loop. */ + /*const dim_t jc_iter = ( n_local + NC - 1 ) / NC;*/ + const dim_t jc_left = n_local % NC; + + /* Loop over the n dimension (NC rows/columns at a time). */ + /*for ( dim_t jj = 0; jj < jc_iter; jj += 1 )*/ + for ( dim_t jj = jc_start; jj < jc_end; jj += NC ) + { + /* Calculate the thread's current JC block dimension. */ + const dim_t nc_cur = ( NC <= jc_end - jj ? NC : jc_left ); + + double* restrict b_jc = b_00 + jj * jcstep_b; + double* restrict c_jc = c_00 + jj * jcstep_c; + + /* Grow the thrinfo_t tree. */ + bszid_t* restrict bszids_pc = &bszids_jc[1]; + thread_pc = bli_thrinfo_sub_node( thread_jc ); + bli_thrinfo_sup_grow( rntm, bszids_pc, thread_pc ); + + /* Compute the PC loop thread range for the current thread. */ + const dim_t pc_start = 0, pc_end = k; + const dim_t k_local = k; + + /* Compute number of primary and leftover components of the PC loop. */ + /*const dim_t pc_iter = ( k_local + KC - 1 ) / KC;*/ + const dim_t pc_left = k_local % KC; + + /* Loop over the k dimension (KC rows/columns at a time). */ + /*for ( dim_t pp = 0; pp < pc_iter; pp += 1 )*/ + for ( dim_t pp = pc_start; pp < pc_end; pp += KC ) + { + /* Calculate the thread's current PC block dimension. */ + const dim_t kc_cur = ( KC <= pc_end - pp ? KC : pc_left ); + + double* restrict a_pc = a_00 + pp * pcstep_a; + double* restrict b_pc = b_jc + pp * pcstep_b; + + /* Only apply beta to the first iteration of the pc loop. */ + double* restrict beta_use = ( pp == 0 ? &beta_local : &one_local ); + + m_off = 0; + n_off = jj; + diagoffc = m_off - n_off; + + double* b_use; + inc_t rs_b_use, cs_b_use, ps_b_use; + + /* Set the bszid_t array and thrinfo_t pointer based on whether + we will be packing B. If we won't be packing B, we alias to + the _pc variables so that code further down can unconditionally + reference the _pb variables. Note that *if* we will be packing + B, the thrinfo_t node will have already been created by a + previous call to bli_thrinfo_grow(), since bszid values of + BLIS_NO_PART cause the tree to grow by two (e.g. to the next + bszid that is a normal bszid_t value). */ + bszid_t* restrict bszids_pb; + if ( packb ) { bszids_pb = &bszids_pc[1]; + thread_pb = bli_thrinfo_sub_node( thread_pc ); } + else { bszids_pb = &bszids_pc[0]; + thread_pb = thread_pc; } + + /* Determine the packing buffer and related parameters for matrix + B. (If B will not be packed, then a_use will be set to point to + b and the _b_use strides will be set accordingly.) Then call + the packm sup variant chooser, which will call the appropriate + implementation based on the schema deduced from the stor_id. */ + PASTEMAC(d,packm_sup_b) + ( + packb, + BLIS_BUFFER_FOR_B_PANEL, /* This algorithm packs matrix B to */ + stor_id, /* a "panel of B." */ + BLIS_NO_TRANSPOSE, + KC, NC, /* This "panel of B" is (at most) KC x NC. */ + kc_cur, nc_cur, NR, + &one_local, + b_pc, rs_b, cs_b, + &b_use, &rs_b_use, &cs_b_use, + &ps_b_use, + cntx, + rntm, + &mem_b, + thread_pb + ); + + /* Alias a_use so that it's clear this is our current block of + matrix B. */ + double* restrict b_pc_use = b_use; + + /* We don't need to embed the panel stride of B within the auxinfo_t + object because this variant iterates through B in the jr loop, + which occurs here, within the macrokernel, not within the + millikernel. */ + /*bli_auxinfo_set_ps_b( ps_b_use, &aux );*/ + + /* Grow the thrinfo_t tree. */ + bszid_t* restrict bszids_ic = &bszids_pb[1]; + thread_ic = bli_thrinfo_sub_node( thread_pb ); + bli_thrinfo_sup_grow( rntm, bszids_ic, thread_ic ); + + /* Compute the IC loop thread range for the current thread. */ + dim_t ic_start, ic_end; + bli_thread_range_weighted_sub( thread_ic, -diagoffc, BLIS_LOWER, nc_cur, m, MR, FALSE, &ic_start, &ic_end ); + const dim_t m_local = ic_end - ic_start; + + /* Compute number of primary and leftover components of the IC loop. */ + /*const dim_t ic_iter = ( m_local + MC - 1 ) / MC;*/ + const dim_t ic_left = m_local % MC; + + /* Loop over the m dimension (MC rows at a time). */ + /*for ( dim_t ii = 0; ii < ic_iter; ii += 1 )*/ + for ( dim_t ii = ic_start; ii < ic_end; ii += MC ) + { + /* Calculate the thread's current IC block dimension. */ + dim_t mc_cur = ( MC <= ic_end - ii ? MC : ic_left ); + + dim_t nc_pruned = nc_cur; + + m_off = ii; + n_off = jj; + + if(bli_gemmt_is_strictly_below_diag(m_off, n_off, mc_cur, nc_cur)) continue; + + double* restrict a_ic = a_pc + ii * icstep_a; + double* restrict c_ic = c_jc + ii * icstep_c; + + doff_t diagoffc = m_off - n_off; + + double* restrict b_pc_pruned = b_pc_use; + + if(diagoffc > 0 ) + { + jp = diagoffc / NR; + j = jp * NR; + nc_pruned = nc_cur - j; + n_off += j; + diagoffc = diagoffc % NR; + c_ic = c_ic + ( j ) * cs_c; + b_pc_pruned = b_pc_use + ( jp ) * ps_b_use; + } + + if( ( ( -diagoffc ) + nc_pruned ) < mc_cur ) + { + mc_cur = -diagoffc + nc_pruned; + } + + double* a_use; + inc_t rs_a_use, cs_a_use, ps_a_use; + + /* Set the bszid_t array and thrinfo_t pointer based on whether + we will be packing B. If we won't be packing A, we alias to + the _ic variables so that code further down can unconditionally + reference the _pa variables. Note that *if* we will be packing + A, the thrinfo_t node will have already been created by a + previous call to bli_thrinfo_grow(), since bszid values of + BLIS_NO_PART cause the tree to grow by two (e.g. to the next + bszid that is a normal bszid_t value). */ + bszid_t* restrict bszids_pa; + if ( packa ) { bszids_pa = &bszids_ic[1]; + thread_pa = bli_thrinfo_sub_node( thread_ic ); } + else { bszids_pa = &bszids_ic[0]; + thread_pa = thread_ic; } + + /* Determine the packing buffer and related parameters for matrix + A. (If A will not be packed, then a_use will be set to point to + a and the _a_use strides will be set accordingly.) Then call + the packm sup variant chooser, which will call the appropriate + implementation based on the schema deduced from the stor_id. */ + PASTEMAC(d,packm_sup_a) + ( + packa, + BLIS_BUFFER_FOR_A_BLOCK, /* This algorithm packs matrix A to */ + stor_id, /* a "block of A." */ + BLIS_NO_TRANSPOSE, + MC, KC, /* This "block of A" is (at most) MC x KC. */ + mc_cur, kc_cur, MR, + &one_local, + a_ic, rs_a, cs_a, + &a_use, &rs_a_use, &cs_a_use, + &ps_a_use, + cntx, + rntm, + &mem_a, + thread_pa + ); + + /* Alias a_use so that it's clear this is our current block of + matrix A. */ + double* restrict a_ic_use = a_use; + + /* Embed the panel stride of A within the auxinfo_t object. The + millikernel will query and use this to iterate through + micropanels of A (if needed). */ + bli_auxinfo_set_ps_a( ps_a_use, &aux ); + + /* Grow the thrinfo_t tree. */ + bszid_t* restrict bszids_jr = &bszids_pa[1]; + thread_jr = bli_thrinfo_sub_node( thread_pa ); + bli_thrinfo_sup_grow( rntm, bszids_jr, thread_jr ); + + /* Compute number of primary and leftover components of the JR loop. */ + dim_t jr_iter = ( nc_pruned + NR - 1 ) / NR; + dim_t jr_left = nc_pruned % NR; + + /* Compute the JR loop thread range for the current thread. */ + dim_t jr_start, jr_end; + bli_thread_range_sub( thread_jr, jr_iter, 1, FALSE, &jr_start, &jr_end ); + + /* An optimization: allow the last jr iteration to contain up to NRE + columns of C and B. (If NRE > NR, the mkernel has agreed to handle + these cases.) Note that this prevents us from declaring jr_iter and + jr_left as const. NOTE: We forgo this optimization when packing B + since packing an extended edge case is not yet supported. */ + if ( !packb && !is_mt ) + if ( NRE != 0 && 1 < jr_iter && jr_left != 0 && jr_left <= NRE ) + { + jr_iter--; jr_left += NR; + } + + /* Loop over the n dimension (NR columns at a time). */ + /*for ( dim_t j = 0; j < jr_iter; j += 1 )*/ + for ( dim_t j = jr_start; j < jr_end; j += 1 ) + { + const dim_t nr_cur = ( bli_is_not_edge_f( j, jr_iter, jr_left ) ? NR : jr_left ); + + /* + double* restrict b_jr = b_pc_use + j * jrstep_b; + */ + double* restrict b_jr = b_pc_pruned + j * ps_b_use; + double* restrict c_jr = c_ic + j * jrstep_c; + dim_t m_rect = 0; + dim_t n_iter_rect = 0; + + m_off_cblock = m_off; + n_off_cblock = n_off + j * NR; + + if(bli_gemmt_is_strictly_above_diag(m_off_cblock, n_off_cblock, mc_cur, nr_cur)) + { + m_rect = mc_cur; + } + else + { + /* calculate the number of rows in rectangular region of the block */ + n_iter_rect = n_off_cblock < m_off_cblock ? 0: (n_off_cblock - m_off_cblock) / MR; + m_rect = n_iter_rect * MR; + } + + /* Compute the rectangular part */ + gemmsup_ker + ( + conja, + conjb, + m_rect, + nr_cur, + kc_cur, + alpha_cast, + a_ic_use, rs_a_use, cs_a_use, + b_jr, rs_b_use, cs_b_use, + beta_use, + c_jr, rs_c, cs_c, + &aux, + cntx + ); + + m_off_cblock = m_off + m_rect; + + double* restrict a_ir = a_ic_use + n_iter_rect * ps_a_use; + double* restrict c_ir = c_jr + n_iter_rect * irstep_c; + + /* compute the remaining triangular part */ + for( dim_t i = m_rect;( i < mc_cur) && (m_off_cblock < n_off_cblock + nr_cur); i += MR ) + { + const dim_t mr_cur = (i+MR-1) < mc_cur ? MR : mc_cur - i; + dim_t m_off_24 = m_off_cblock % 24; + dim_t n_off_24 = n_off_cblock % 24; + dim_t m_idx = (dim_t)(m_off_24 / MR); + dim_t n_idx = (dim_t)(n_off_24 / NR); + #ifdef BLIS_KERNELS_ZEN4 + if ( (n_idx == m_idx) && (MR == 24) && (NR == 8) && bli_cpuid_is_avx512_supported() && + (stor_id != BLIS_CRC && stor_id != BLIS_RRC) && + // verify if micro panel intersects with diagonal + // if distance from diagonal (n_off_cblock - m_off_cblock) is greater + // than (LCM(MR, NR) - NR) then it implies that micro panel is far + // from diagonal therefore it it does not intersect with it. + (n_off_cblock - m_off_cblock) <= 16 // (n_off_cblock - m_off_cblock) <= (LCM(MR, NR) - NR) + ) + { + /* + call traingular 24x8 DGEMMT kernels + */ + // Difference between n_off_cblock and m_off_cblock is same as + // the size of full GEMM region. + // kernel_idx = 0 is used when full GEMM region size <= 0 + // kernel_idx = 1 is used when full GEMM region size <= 8 + // kernel_idx = 2 is used when full GEMM region size <= 16 + ker_fpus_zen4[(n_off_cblock - m_off_cblock)/NR] + ( + conja, + conjb, + mr_cur, + nr_cur, + kc_cur, + (double *)alpha_cast, + (double *)a_ir, rs_a_use, cs_a_use, + (double *)b_jr, rs_b_use, cs_b_use, + (double *)beta_use, + (double *)c_ir, rs_c, cs_c, + &aux, + cntx + ); + a_ir += ps_a_use; + c_ir += irstep_c; + m_off_cblock += mr_cur; + continue; + } + #endif + #ifdef BLIS_KERNELS_HASWELL + /* Prerequisites : MR = 6, NR = 8. + An optimization: allow the last jr iteration to contain up to NRE + In DGEMMT API implementation, kernel operates on 6x8 block. MR and + NR are set as 6 and 8 respectively. 24 being the LCM of 6 and 8, + the diagonal pattern repeats for every 24x24 block. + This pattern is exploited to achieve the optimization in diagonal + blocks by computing only the required elements. In the previous + implementation, all the 48 outputs of the given 6x8 block are + computed and stored into a temporary buffer. Later, the required + elements are copied into the final C output buffer. + With this optimization, we are avoiding copy operation and also + reducing the number of computations. + Variables m_off_24 and n_off_24 respectively store the m and n + offsets from the starting point of the corresponding 24x24 block. + Variables m_idx and n_idx store indices of the current 6x8 block + along m and n dimensions, in 24x24 block. m_idx is computed as + (m_off_24 / MR) while n_idx is computed as (n_off_24 / NR). + Range of m_idx is 0 <= m_idx <= 3 and the range of n_idx is + 0 <= n_idx <= 2. Based on these indices, for the given 6x8 block, + logic is implemented to identify the relevant kernel from the + look-up table. + During instances, where m is not a multiple of 6 or n is not a + multiple of 8, it goes to the default gemm kernel. MR and NR must be + 6 and 8 for these kernels to achieve the expected functionality.*/ + // dim_t m_off_24 = m_off_cblock % 24; + // dim_t n_off_24 = n_off_cblock % 24; + // dim_t m_idx = (dim_t)(m_off_24 / MR); + // dim_t n_idx = (dim_t)(n_off_24 / NR); + + /* Check if m, n indices are multiple of MR and NR respectively + and current block is a complete 6x8 block */ + bool idx_supported = ((m_off_24 % MR) == 0) && ((n_off_24 % NR) == 0) + && (MR == 6) && (NR == 8) + && (bli_cpuid_is_avx2fma3_supported() == TRUE) && (mr_cur==MR) && (nr_cur==NR); + + /* m_idx and n_idx would be equal only if the current block is + a diagonal block */ + if( (dt == BLIS_DOUBLE) && (m_idx == n_idx) && idx_supported ) + { + dim_t ker_idx = m_idx<<1; + /* If there is another 6x8 diagonal block pending for computation + after the current 6x8 diagonal block, then the two blocks can + be computed together(12x8). This combined kernel is implemented + only for the case where n_idx = 0 i.e., n_off_24 = 0. To call + this, it has to be ensured that at least 12 rows are pending in + C for computation (i+ MR + MR <= mc_cur). Usage of this combined + kernel saves the entire time to execute one kernel*/ + if( (n_idx == 0) && (i+ MR + MR <= mc_cur) ) { + ker_idx = 6; /* use combined kernel, index of combined kernel + in lookup table is 6 */ + } + /* if B is column storage we use rd kernel*/ + if( stor_id == BLIS_RRC ) { + ker_idx += 7; /* index of rd kernel*/ + } + gemmt_ker_ft ker_fp = ker_fpus_haswell[ker_idx]; + ker_fp + ( + conja, + conjb, + mr_cur, + nr_cur, + kc_cur, + (double*) alpha_cast, + (double*) a_ir, rs_a_use, cs_a_use, + (double*) b_jr, rs_b_use, cs_b_use, + (double*) beta_use, + (double*) c_ir, rs_c, cs_c, + &aux, + cntx + ); + a_ir += ps_a_use; + c_ir += irstep_c; + m_off_cblock += mr_cur; + continue; + } + /* 6x8 block where m_idx == n_idx+1 also has some parts of the diagonal */ + else if ( (dt == BLIS_DOUBLE) && (m_idx == n_idx+1) && (idx_supported) ) + { + /* If current block was already computed in the combined kernel it + can be skipped combined kernel is only implemented for n_idx=0, + i == m_rect is only true for the first iteration therefore if + i == m_rect then the current 6x8 block was not computed in + combined kernel + */ + if ( (n_idx != 0) || (i == m_rect) ) + { + dim_t ker_idx = (n_idx << 1) + 1 ; + /* use rd kernel if B is column major storage */ + if( stor_id == BLIS_RRC ) { ker_idx += 7; } + + gemmt_ker_ft ker_fp = ker_fpus_haswell[ker_idx]; + + ker_fp + ( + conja, + conjb, + mr_cur, + nr_cur, + kc_cur, + (double*) alpha_cast, + (double*) a_ir, rs_a_use, cs_a_use, + (double*) b_jr, rs_b_use, cs_b_use, + (double*) beta_use, + (double*) c_ir, rs_c, cs_c, + &aux, + cntx + ); + } + a_ir += ps_a_use; + c_ir += irstep_c; + m_off_cblock += mr_cur; + continue; + } + #endif + gemmsup_ker + ( + conja, + conjb, + mr_cur, + nr_cur, + kc_cur, + alpha_cast, + a_ir, rs_a_use, cs_a_use, + b_jr, rs_b_use, cs_b_use, + zero, + ct, rs_ct, cs_ct, + &aux, + cntx + ); + + if( col_pref ) + { + PASTEMAC(d,update_lower_triang)( n_off_cblock, m_off_cblock, + nr_cur, mr_cur, + ct, cs_ct, rs_ct, + beta_use, + c_ir, cs_c, rs_c ); + } + else + { + PASTEMAC(d,update_upper_triang)( m_off_cblock, n_off_cblock, + mr_cur, nr_cur, + ct, rs_ct, cs_ct, + beta_use, + c_ir, rs_c, cs_c ); + } + + a_ir += ps_a_use; + c_ir += irstep_c; + m_off_cblock += mr_cur; + + } + } + } + + /* NOTE: This barrier is only needed if we are packing B (since + that matrix is packed within the pc loop of this variant). */ + if ( packb ) bli_thread_barrier( thread_pb ); + } + } + + /* Release any memory that was acquired for packing matrices A and B. */ + PASTEMAC(d,packm_sup_finalize_mem_a) + ( + packa, + rntm, + &mem_a, + thread_pa + ); + PASTEMAC(d,packm_sup_finalize_mem_b) + ( + packb, + rntm, + &mem_b, + thread_pb + ); + +/* +PASTEMAC(d,fprintm)( stdout, "gemmsup_ref_var2: b1", kc_cur, nr_cur, b_jr, rs_b, cs_b, "%4.1f", "" ); +PASTEMAC(d,fprintm)( stdout, "gemmsup_ref_var2: a1", mr_cur, kc_cur, a_ir, rs_a, cs_a, "%4.1f", "" ); +PASTEMAC(d,fprintm)( stdout, "gemmsup_ref_var2: c ", mr_cur, nr_cur, c_ir, rs_c, cs_c, "%4.1f", "" ); +*/ +} + +/***************************************************************/ +/* AVX512 Kernel - gemmsup_rv_zen4_asm_4x4m */ +/* Check if AVX512 kernel can be called for certain conditions */ +/* 1. Architecture: ZEN4 or ZEN5 */ +/* 2. Storage: If it is CRC, RRC AVX2 code path is invoked */ +/* for other storage formats AVX512 will be called*/ +/* 3. BlockSize: Kernel is optimised for MR=NR=4 */ +/***************************************************************/ +#if defined (BLIS_KERNELS_ZEN4) + +#define LOWER_TRIANGLE_OPTIMIZATION_DCOMPLEX() \ + if ((MR == 4) && (NR == 4) && (stor_id != BLIS_CRC) && (stor_id != BLIS_RRC)) \ + { \ + bli_zgemmsup_rv_zen4_asm_4x4m_lower \ + ( \ + conja, \ + conjb, \ + mr_cur, \ + nr_cur, \ + kc_cur, \ + (dcomplex*) alpha_cast, \ + (dcomplex*) a_ir, rs_a_use, cs_a_use, \ + (dcomplex*) b_jr, rs_b_use, cs_b_use, \ + (dcomplex*) beta_use, \ + (dcomplex*) c_ir, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ + } \ + /* call the regular kernel for non applicable cases */ \ + else + +#define UPPER_TRIANGLE_OPTIMIZATION_DCOMPLEX() \ + if ((MR == 4) && (NR == 4) && (stor_id != BLIS_CRC) && (stor_id != BLIS_RRC)) \ + { \ + bli_zgemmsup_rv_zen4_asm_4x4m_upper \ + ( \ + conja, \ + conjb, \ + mr_cur, \ + nr_cur, \ + kc_cur, \ + (dcomplex*) alpha_cast, \ + (dcomplex*) a_ir, rs_a_use, cs_a_use, \ + (dcomplex*) b_jr, rs_b_use, cs_b_use, \ + (dcomplex*) beta_use, \ + (dcomplex*) c_ir, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ + } \ + /* call the regular kernel for non applicable cases */ \ + else + +#else + #define LOWER_TRIANGLE_OPTIMIZATION_DCOMPLEX() + #define UPPER_TRIANGLE_OPTIMIZATION_DCOMPLEX() + +#endif + +void bli_zgemmtsup_l_ref_var2m + ( \ + bool packa, + bool packb, + conj_t conja, + conj_t conjb, + dim_t m, + dim_t n, + dim_t k, + void* restrict alpha, + void* restrict a, inc_t rs_a, inc_t cs_a, + void* restrict b, inc_t rs_b, inc_t cs_b, + void* restrict beta, + void* restrict c, inc_t rs_c, inc_t cs_c, + stor3_t stor_id, + cntx_t* restrict cntx, + rntm_t* restrict rntm, + thrinfo_t* restrict thread + ) +{ + const num_t dt = PASTEMAC(z,type); + + dcomplex* restrict zero = PASTEMAC(z,0); + + /* If m or n is zero, return immediately. */ + if ( bli_zero_dim2( m, n ) ) return; + + /* If k < 1 or alpha is zero, scale by beta and return. */ + if ( k < 1 || PASTEMAC(z,eq0)( *(( dcomplex* )alpha) ) ) + { + if ( bli_thread_am_ochief( thread ) ) + { + PASTEMAC(z,scalm) + ( + BLIS_NO_CONJUGATE, + 0, + BLIS_NONUNIT_DIAG, + BLIS_DENSE, + m, n, + beta, + c, rs_c, cs_c + ); + } + return; + } + + /* Query the context for various blocksizes. */ + dim_t NR = bli_cntx_get_l3_sup_tri_blksz_def_dt( dt, BLIS_NR, cntx ); + dim_t MR = bli_cntx_get_l3_sup_tri_blksz_def_dt( dt, BLIS_MR, cntx ); + dim_t NC = bli_cntx_get_l3_sup_tri_blksz_def_dt( dt, BLIS_NC, cntx ); + dim_t MC = bli_cntx_get_l3_sup_tri_blksz_def_dt( dt, BLIS_MC, cntx ); + dim_t KC0 = bli_cntx_get_l3_sup_tri_blksz_def_dt( dt, BLIS_KC, cntx ); + /* Query the maximum blocksize for NR, which implies a maximum blocksize + extension for the final iteration. */ + dim_t NRM = bli_cntx_get_l3_sup_tri_blksz_max_dt( dt, BLIS_NR, cntx ); + + /* Query the context for the sup microkernel address and cast it to its + function pointer type. */ + PASTECH(z,gemmsup_ker_ft) + gemmsup_ker = bli_cntx_get_l3_sup_tri_ker_dt( dt, stor_id, cntx ); + + if( ( 0 == NR ) || ( 0 == MR ) || ( 0 == NC ) || ( 0 == MC ) || ( 0 == KC0 ) ) + { + NR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NR, cntx ); + MR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_MR, cntx ); + NC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NC, cntx ); + MC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_MC, cntx ); + KC0 = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_KC, cntx ); + NRM = bli_cntx_get_l3_sup_blksz_max_dt( dt, BLIS_NR, cntx ); + gemmsup_ker = bli_cntx_get_l3_sup_ker_dt( dt, stor_id, cntx ); + } + const dim_t NRE = NRM - NR; + + dim_t KC; + if ( packa && packb ) + { + KC = KC0; + } + else if ( packb ) + { + if ( stor_id == BLIS_RRR || + stor_id == BLIS_CCC ) KC = KC0; + else if ( stor_id == BLIS_RRC || + stor_id == BLIS_CRC ) KC = KC0; + else if ( stor_id == BLIS_RCR || + stor_id == BLIS_CCR ) KC = (( KC0 / 4 ) / 4 ) * 4; + else KC = KC0; + } + else if ( packa ) + { + if ( stor_id == BLIS_RRR || + stor_id == BLIS_CCC ) KC = (( KC0 / 2 ) / 2 ) * 2; + else if ( stor_id == BLIS_RRC || + stor_id == BLIS_CRC ) KC = KC0; + else if ( stor_id == BLIS_RCR || + stor_id == BLIS_CCR ) KC = (( KC0 / 4 ) / 4 ) * 4; + else KC = KC0; + } + else /* if ( !packa && !packb ) */ + { + if ( stor_id == BLIS_RRR || + stor_id == BLIS_CCC ) KC = KC0; + else if ( stor_id == BLIS_RRC || + stor_id == BLIS_CRC ) KC = KC0; + else if ( m <= MR && n <= NR ) KC = KC0; + else if ( m <= 2*MR && n <= 2*NR ) KC = KC0 / 2; + else if ( m <= 3*MR && n <= 3*NR ) KC = (( KC0 / 3 ) / 4 ) * 4; + else if ( m <= 4*MR && n <= 4*NR ) KC = KC0 / 4; + else KC = (( KC0 / 5 ) / 4 ) * 4; + } + + /* Compute partitioning step values for each matrix of each loop. */ + const inc_t jcstep_c = cs_c; + const inc_t jcstep_b = cs_b; + + const inc_t pcstep_a = cs_a; + const inc_t pcstep_b = rs_b; + + const inc_t icstep_c = rs_c; + const inc_t icstep_a = rs_a; + + const inc_t jrstep_c = cs_c * NR; + + const inc_t irstep_c = rs_c * MR; + + /* + const inc_t jrstep_b = cs_b * NR; + ( void )jrstep_b; + + const inc_t irstep_c = rs_c * MR; + const inc_t irstep_a = rs_a * MR; + */ + + dcomplex ct[ BLIS_STACK_BUF_MAX_SIZE / sizeof( dcomplex ) ] __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); + + /* storage-scheme of ct should be same as that of C. + Since update routines only support row-major order, + col_pref flag is used to induce transpose to matrices before + passing to update routine whenever C is col-stored */ + const bool col_pref = (rs_c == 1)? 1 : 0; + + const inc_t rs_ct = ( col_pref ? 1 : NR ); + const inc_t cs_ct = ( col_pref ? MR : 1 ); + + dcomplex* restrict a_00 = a; + dcomplex* restrict b_00 = b; + dcomplex* restrict c_00 = c; + dcomplex* restrict alpha_cast = alpha; + dcomplex* restrict beta_cast = beta; + + /* Make local copies of beta and one scalars to prevent any unnecessary + sharing of cache lines between the cores' caches. */ \ + dcomplex beta_local = *beta_cast; + dcomplex one_local = *PASTEMAC(z,1); + + auxinfo_t aux; + + /* Parse and interpret the contents of the rntm_t object to properly + set the ways of parallelism for each loop. */ + /*bli_rntm_set_ways_from_rntm_sup( m, n, k, rntm );*/ + + /* Initialize a mem_t entry for A and B. Strictly speaking, this is only + needed for the matrix we will be packing (if any), but we do it + unconditionally to be safe. An alternative way of initializing the + mem_t entries is: + + bli_mem_clear( &mem_a ); + bli_mem_clear( &mem_b ); + */ + mem_t mem_a = BLIS_MEM_INITIALIZER; + mem_t mem_b = BLIS_MEM_INITIALIZER; + + /* Define an array of bszid_t ids, which will act as our substitute for + the cntl_t tree. */ + /* 5thloop 4thloop packb 3rdloop packa 2ndloop 1stloop ukrloop */ + bszid_t bszids_nopack[6] = { BLIS_NC, BLIS_KC, BLIS_MC, BLIS_NR, BLIS_MR, BLIS_KR }; + bszid_t bszids_packa [7] = { BLIS_NC, BLIS_KC, BLIS_MC, BLIS_NO_PART, BLIS_NR, BLIS_MR, BLIS_KR }; + bszid_t bszids_packb [7] = { BLIS_NC, BLIS_KC, BLIS_NO_PART, BLIS_MC, BLIS_NR, BLIS_MR, BLIS_KR }; + bszid_t bszids_packab[8] = { BLIS_NC, BLIS_KC, BLIS_NO_PART, BLIS_MC, BLIS_NO_PART, BLIS_NR, BLIS_MR, BLIS_KR }; + bszid_t* restrict bszids; + + /* Set the bszids pointer to the correct bszids array above based on which + matrices (if any) are being packed. */ + if ( packa ) { if ( packb ) bszids = bszids_packab; + else bszids = bszids_packa; } + else { if ( packb ) bszids = bszids_packb; + else bszids = bszids_nopack; } + + /* Determine whether we are using more than one thread. */ + const bool is_mt = bli_rntm_calc_num_threads( rntm ); + + thrinfo_t* restrict thread_jc = NULL; + thrinfo_t* restrict thread_pc = NULL; + thrinfo_t* restrict thread_pb = NULL; + thrinfo_t* restrict thread_ic = NULL; + thrinfo_t* restrict thread_pa = NULL; + thrinfo_t* restrict thread_jr = NULL; + + /* Grow the thrinfo_t tree. */ + bszid_t* restrict bszids_jc = bszids; + thread_jc = thread; + bli_thrinfo_sup_grow( rntm, bszids_jc, thread_jc ); + + /* Compute the JC loop thread range for the current thread. */ + dim_t jc_start, jc_end; + bli_thread_range_weighted_sub( thread_jc, 0, BLIS_LOWER, m, n, NR, FALSE, &jc_start, &jc_end ); + const dim_t n_local = jc_end - jc_start; + + /* Compute number of primary and leftover components of the JC loop. */ + /*const dim_t jc_iter = ( n_local + NC - 1 ) / NC;*/ + const dim_t jc_left = n_local % NC; + + dim_t m_off_cblock, n_off_cblock; + dim_t m_off = 0; + dim_t n_off = 0; + doff_t diagoffc; + dim_t i, ip; + + /* Loop over the n dimension (NC rows/columns at a time). */ + /*for ( dim_t jj = 0; jj < jc_iter; jj += 1 )*/ + for ( dim_t jj = jc_start; jj < jc_end; jj += NC ) + { + /* Calculate the thread's current JC block dimension. */ + const dim_t nc_cur = ( NC <= jc_end - jj ? NC : jc_left ); + + dcomplex* restrict b_jc = b_00 + jj * jcstep_b; + dcomplex* restrict c_jc = c_00 + jj * jcstep_c; + + /* Grow the thrinfo_t tree. */ + bszid_t* restrict bszids_pc = &bszids_jc[1]; + thread_pc = bli_thrinfo_sub_node( thread_jc ); + bli_thrinfo_sup_grow( rntm, bszids_pc, thread_pc ); + + /* Compute the PC loop thread range for the current thread. */ + const dim_t pc_start = 0, pc_end = k; + const dim_t k_local = k; + + /* Compute number of primary and leftover components of the PC loop. */ + /*const dim_t pc_iter = ( k_local + KC - 1 ) / KC;*/ + const dim_t pc_left = k_local % KC; + + /* Loop over the k dimension (KC rows/columns at a time). */ + /*for ( dim_t pp = 0; pp < pc_iter; pp += 1 )*/ + for ( dim_t pp = pc_start; pp < pc_end; pp += KC ) + { + /* Calculate the thread's current PC block dimension. */ + const dim_t kc_cur = ( KC <= pc_end - pp ? KC : pc_left ); + + dcomplex* restrict a_pc = a_00 + pp * pcstep_a; + dcomplex* restrict b_pc = b_jc + pp * pcstep_b; + + /* Only apply beta to the first iteration of the pc loop. */ + dcomplex* restrict beta_use = ( pp == 0 ? &beta_local : &one_local ); + + m_off = 0; + n_off = jj; + diagoffc = m_off - n_off; + + dcomplex* b_use; + inc_t rs_b_use, cs_b_use, ps_b_use; + + /* Set the bszid_t array and thrinfo_t pointer based on whether + we will be packing B. If we won't be packing B, we alias to + the _pc variables so that code further down can unconditionally + reference the _pb variables. Note that *if* we will be packing + B, the thrinfo_t node will have already been created by a + previous call to bli_thrinfo_grow(), since bszid values of + BLIS_NO_PART cause the tree to grow by two (e.g. to the next + bszid that is a normal bszid_t value). */ + bszid_t* restrict bszids_pb; + if ( packb ) { bszids_pb = &bszids_pc[1]; + thread_pb = bli_thrinfo_sub_node( thread_pc ); } + else { bszids_pb = &bszids_pc[0]; + thread_pb = thread_pc; } + + /* Determine the packing buffer and related parameters for matrix + B. (If B will not be packed, then a_use will be set to point to + b and the _b_use strides will be set accordingly.) Then call + the packm sup variant chooser, which will call the appropriate + implementation based on the schema deduced from the stor_id. */ \ + PASTEMAC(z,packm_sup_b) + ( + packb, + BLIS_BUFFER_FOR_B_PANEL, /* This algorithm packs matrix B to */ + stor_id, /* a "panel of B." */ + BLIS_NO_TRANSPOSE, + KC, NC, /* This "panel of B" is (at most) KC x NC. */ + kc_cur, nc_cur, NR, + &one_local, + b_pc, rs_b, cs_b, + &b_use, &rs_b_use, &cs_b_use, + &ps_b_use, + cntx, + rntm, + &mem_b, + thread_pb + ); + + /* Alias a_use so that it's clear this is our current block of + matrix B. */ + dcomplex* restrict b_pc_use = b_use; + + /* We don't need to embed the panel stride of B within the auxinfo_t + object because this variant iterates through B in the jr loop, + which occurs here, within the macrokernel, not within the + millikernel. */ + /*bli_auxinfo_set_ps_b( ps_b_use, &aux );*/ + + /* Grow the thrinfo_t tree. */ + bszid_t* restrict bszids_ic = &bszids_pb[1]; + thread_ic = bli_thrinfo_sub_node( thread_pb ); + bli_thrinfo_sup_grow( rntm, bszids_ic, thread_ic ); + + /* Compute the IC loop thread range for the current thread. */ + dim_t ic_start, ic_end; + bli_thread_range_weighted_sub( thread_ic, -diagoffc, BLIS_UPPER, nc_cur, m, MR, FALSE, &ic_start, &ic_end ); + const dim_t m_local = ic_end - ic_start; + + /* Compute number of primary and leftover components of the IC loop. */ + /*const dim_t ic_iter = ( m_local + MC - 1 ) / MC;*/ + const dim_t ic_left = m_local % MC; + + /* Loop over the m dimension (MC rows at a time). */ + /*for ( dim_t ii = 0; ii < ic_iter; ii += 1 )*/ + for ( dim_t ii = ic_start; ii < ic_end; ii += MC ) + { + /* Calculate the thread's current IC block dimension. */ + dim_t mc_cur = ( MC <= ic_end - ii ? MC : ic_left ); + dim_t nc_pruned = nc_cur; + + dcomplex* restrict a_ic = a_pc + ii * icstep_a; + dcomplex* restrict c_ic = c_jc + ii * icstep_c; + + m_off = ii; + + if(bli_gemmt_is_strictly_above_diag( m_off, n_off, mc_cur, nc_cur ) ) continue; + + diagoffc = m_off - n_off; + + if( diagoffc < 0 ) + { + ip = -diagoffc / MR; + i = ip * MR; + mc_cur = mc_cur - i; + diagoffc = -diagoffc % MR; + m_off += i; + c_ic = c_ic + ( i ) * rs_c; + a_ic = a_ic + ( i ) * rs_a; + } + + if( ( diagoffc + mc_cur ) < nc_cur ) + { + nc_pruned = diagoffc + mc_cur; + } + + dcomplex* a_use; + inc_t rs_a_use, cs_a_use, ps_a_use; + + /* Set the bszid_t array and thrinfo_t pointer based on whether + we will be packing B. If we won't be packing A, we alias to + the _ic variables so that code further down can unconditionally + reference the _pa variables. Note that *if* we will be packing + A, the thrinfo_t node will have already been created by a + previous call to bli_thrinfo_grow(), since bszid values of + BLIS_NO_PART cause the tree to grow by two (e.g. to the next + bszid that is a normal bszid_t value). */ + bszid_t* restrict bszids_pa; + if ( packa ) { bszids_pa = &bszids_ic[1]; + thread_pa = bli_thrinfo_sub_node( thread_ic ); } + else { bszids_pa = &bszids_ic[0]; + thread_pa = thread_ic; } + + /* Determine the packing buffer and related parameters for matrix + A. (If A will not be packed, then a_use will be set to point to + a and the _a_use strides will be set accordingly.) Then call + the packm sup variant chooser, which will call the appropriate + implementation based on the schema deduced from the stor_id. */ \ + PASTEMAC(z,packm_sup_a) + ( + packa, + BLIS_BUFFER_FOR_A_BLOCK, /* This algorithm packs matrix A to */ + stor_id, /* a "block of A." */ + BLIS_NO_TRANSPOSE, + MC, KC, /* This "block of A" is (at most) MC x KC. */ + mc_cur, kc_cur, MR, + &one_local, + a_ic, rs_a, cs_a, + &a_use, &rs_a_use, &cs_a_use, + &ps_a_use, + cntx, + rntm, + &mem_a, + thread_pa + ); + + /* Alias a_use so that it's clear this is our current block of + matrix A. */ + dcomplex* restrict a_ic_use = a_use; + + /* Embed the panel stride of A within the auxinfo_t object. The + millikernel will query and use this to iterate through + micropanels of A (if needed). */ + bli_auxinfo_set_ps_a( ps_a_use, &aux ); + + /* Grow the thrinfo_t tree. */ + bszid_t* restrict bszids_jr = &bszids_pa[1]; + thread_jr = bli_thrinfo_sub_node( thread_pa ); + bli_thrinfo_sup_grow( rntm, bszids_jr, thread_jr ); + + /* Compute number of primary and leftover components of the JR loop. */ + dim_t jr_iter = ( nc_pruned + NR - 1 ) / NR; + dim_t jr_left = nc_pruned % NR; + + /* Compute the JR loop thread range for the current thread. */ + dim_t jr_start, jr_end; + bli_thread_range_sub( thread_jr, jr_iter, 1, FALSE, &jr_start, &jr_end ); + + /* An optimization: allow the last jr iteration to contain up to NRE + columns of C and B. (If NRE > NR, the mkernel has agreed to handle + these cases.) Note that this prevents us from declaring jr_iter and + jr_left as const. NOTE: We forgo this optimization when packing B + since packing an extended edge case is not yet supported. */ + if ( !packb && !is_mt ) + if ( NRE != 0 && 1 < jr_iter && jr_left != 0 && jr_left <= NRE ) + { + jr_iter--; jr_left += NR; + } + + /* Loop over the n dimension (NR columns at a time). */ + /*for ( dim_t j = 0; j < jr_iter; j += 1 )*/ + for ( dim_t j = jr_start; j < jr_end; j += 1 ) + { + const dim_t nr_cur = ( bli_is_not_edge_f( j, jr_iter, jr_left ) ? NR : jr_left ); + + /* + dcomplex* restrict b_jr = b_pc_use + j * jrstep_b; + */ + dcomplex* restrict b_jr = b_pc_use + j * ps_b_use; + dcomplex* restrict c_jr = c_ic + j * jrstep_c; + + dim_t i; + dim_t m_zero = 0; + dim_t n_iter_zero = 0; + + m_off_cblock = m_off; + n_off_cblock = n_off + j * NR; + + if(bli_gemmt_is_strictly_below_diag(m_off_cblock, n_off_cblock, mc_cur, nc_cur)) + { + m_zero = 0; + } + else + { + /* compute number of rows that are filled with zeroes and can be ignored */ + n_iter_zero = (n_off_cblock < m_off_cblock)? 0 : (n_off_cblock - m_off)/MR; + m_zero = n_iter_zero * MR; + } + + dcomplex* restrict a_ir = a_ic_use + n_iter_zero * ps_a_use; + dcomplex* restrict c_ir = c_jr + n_iter_zero * irstep_c; + + /* Ignore the zero region */ + m_off_cblock += m_zero; + + /* Compute the triangular part */ + for( i = m_zero; (i < mc_cur) && ( m_off_cblock < n_off_cblock + nr_cur); i += MR ) + { + const dim_t mr_cur = (i+MR-1) < mc_cur ? MR : mc_cur - i; + + LOWER_TRIANGLE_OPTIMIZATION_DCOMPLEX() + { + gemmsup_ker + ( + conja, + conjb, + mr_cur, + nr_cur, + kc_cur, + alpha_cast, + a_ir, rs_a_use, cs_a_use, + b_jr, rs_b_use, cs_b_use, + zero, + ct, rs_ct, cs_ct, + &aux, + cntx + ); + if( col_pref ) + { + PASTEMAC(z,update_upper_triang)( n_off_cblock, m_off_cblock, + nr_cur, mr_cur, + ct, cs_ct, rs_ct, + beta_use, + c_ir, cs_c, rs_c ); + } + else + { + PASTEMAC(z,update_lower_triang)( m_off_cblock, n_off_cblock, + mr_cur, nr_cur, + ct, rs_ct, cs_ct, + beta_use, + c_ir, rs_c, cs_c ); + } + } + + a_ir += ps_a_use; + c_ir += irstep_c; + m_off_cblock += mr_cur; + } + + /* Invoke the gemmsup millikernel for remaining rectangular part. */ + gemmsup_ker + ( + conja, + conjb, + (i > mc_cur)? 0: mc_cur - i, + nr_cur, + kc_cur, + alpha_cast, + a_ir, rs_a_use, cs_a_use, + b_jr, rs_b_use, cs_b_use, + beta_use, + c_ir, rs_c, cs_c, + &aux, + cntx + ); + + } + } + + /* NOTE: This barrier is only needed if we are packing B (since + that matrix is packed within the pc loop of this variant). */ + if ( packb ) bli_thread_barrier( thread_pb ); + } + } + + /* Release any memory that was acquired for packing matrices A and B. */ + PASTEMAC(z,packm_sup_finalize_mem_a) + ( + packa, + rntm, + &mem_a, + thread_pa + ); + PASTEMAC(z,packm_sup_finalize_mem_b) + ( + packb, + rntm, + &mem_b, + thread_pb + ); + +/* +PASTEMAC(z,fprintm)( stdout, "gemmsup_ref_var2: b1", kc_cur, nr_cur, b_jr, rs_b, cs_b, "%4.1f", "" ); +PASTEMAC(z,fprintm)( stdout, "gemmsup_ref_var2: a1", mr_cur, kc_cur, a_ir, rs_a, cs_a, "%4.1f", "" ); +PASTEMAC(z,fprintm)( stdout, "gemmsup_ref_var2: c ", mr_cur, nr_cur, c_ir, rs_c, cs_c, "%4.1f", "" ); +*/ +} + +void bli_zgemmtsup_u_ref_var2m + ( + bool packa, + bool packb, + conj_t conja, + conj_t conjb, + dim_t m, + dim_t n, + dim_t k, + void* restrict alpha, + void* restrict a, inc_t rs_a, inc_t cs_a, + void* restrict b, inc_t rs_b, inc_t cs_b, + void* restrict beta, + void* restrict c, inc_t rs_c, inc_t cs_c, + stor3_t stor_id, + cntx_t* restrict cntx, + rntm_t* restrict rntm, + thrinfo_t* restrict thread + ) +{ + const num_t dt = PASTEMAC(z,type); + + dcomplex* restrict zero = PASTEMAC(z,0); + + /* If m or n is zero, return immediately. */ + if ( bli_zero_dim2( m, n ) ) return; + + /* If k < 1 or alpha is zero, scale by beta and return. */ + if ( k < 1 || PASTEMAC(z,eq0)( *(( dcomplex* )alpha) ) ) + { + if ( bli_thread_am_ochief( thread ) ) + { + PASTEMAC(z,scalm) + ( + BLIS_NO_CONJUGATE, + 0, + BLIS_NONUNIT_DIAG, + BLIS_DENSE, + m, n, + beta, + c, rs_c, cs_c + ); + } + return; + } + + /* Query the context for various blocksizes. */ + dim_t NR = bli_cntx_get_l3_sup_tri_blksz_def_dt( dt, BLIS_NR, cntx ); + dim_t MR = bli_cntx_get_l3_sup_tri_blksz_def_dt( dt, BLIS_MR, cntx ); + dim_t NC = bli_cntx_get_l3_sup_tri_blksz_def_dt( dt, BLIS_NC, cntx ); + dim_t MC = bli_cntx_get_l3_sup_tri_blksz_def_dt( dt, BLIS_MC, cntx ); + dim_t KC0 = bli_cntx_get_l3_sup_tri_blksz_def_dt( dt, BLIS_KC, cntx ); + + /* Query the maximum blocksize for NR, which implies a maximum blocksize + extension for the final iteration. */ + dim_t NRM = bli_cntx_get_l3_sup_tri_blksz_max_dt( dt, BLIS_NR, cntx ); + + /* Query the context for the sup microkernel address and cast it to its + function pointer type. */ + PASTECH(z,gemmsup_ker_ft) + gemmsup_ker = bli_cntx_get_l3_sup_tri_ker_dt( dt, stor_id, cntx ); + + if( ( 0 == NR ) || ( 0 == MR ) || ( 0 == NC ) || ( 0 == MC ) || ( 0 == KC0 ) ) + { + NR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NR, cntx ); + MR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_MR, cntx ); + NC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NC, cntx ); + MC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_MC, cntx ); + KC0 = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_KC, cntx ); + NRM = bli_cntx_get_l3_sup_blksz_max_dt( dt, BLIS_NR, cntx ); + gemmsup_ker = bli_cntx_get_l3_sup_ker_dt( dt, stor_id, cntx ); + } + const dim_t NRE = NRM - NR; + + dim_t KC; + if ( packa && packb ) + { + KC = KC0; + } + else if ( packb ) + { + if ( stor_id == BLIS_RRR || + stor_id == BLIS_CCC ) KC = KC0; + else if ( stor_id == BLIS_RRC || + stor_id == BLIS_CRC ) KC = KC0; + else if ( stor_id == BLIS_RCR || + stor_id == BLIS_CCR ) KC = (( KC0 / 4 ) / 4 ) * 4; + else KC = KC0; + } + else if ( packa ) + { + if ( stor_id == BLIS_RRR || + stor_id == BLIS_CCC ) KC = (( KC0 / 2 ) / 2 ) * 2; + else if ( stor_id == BLIS_RRC || + stor_id == BLIS_CRC ) KC = KC0; + else if ( stor_id == BLIS_RCR || + stor_id == BLIS_CCR ) KC = (( KC0 / 4 ) / 4 ) * 4; + else KC = KC0; + } + else /* if ( !packa && !packb ) */ + { + if ( stor_id == BLIS_RRR || + stor_id == BLIS_CCC ) KC = KC0; + else if ( stor_id == BLIS_RRC || + stor_id == BLIS_CRC ) KC = KC0; + else if ( stor_id == BLIS_RCR ) + { + if ( m <= 4*MR ) KC = KC0; + else if ( m <= 36*MR ) KC = KC0 / 2; + else if ( m <= 56*MR ) KC = (( KC0 / 3 ) / 4 ) * 4; + else KC = KC0 / 4; + } + else if ( m <= MR && n <= NR ) KC = KC0; + else if ( m <= 2*MR && n <= 2*NR ) KC = KC0 / 2; + else if ( m <= 3*MR && n <= 3*NR ) KC = (( KC0 / 3 ) / 4 ) * 4; + else if ( m <= 4*MR && n <= 4*NR ) KC = KC0 / 4; + else KC = (( KC0 / 5 ) / 4 ) * 4; + } + + /* Compute partitioning step values for each matrix of each loop. */ + const inc_t jcstep_c = cs_c; + const inc_t jcstep_b = cs_b; + + const inc_t pcstep_a = cs_a; + const inc_t pcstep_b = rs_b; + + const inc_t icstep_c = rs_c; + const inc_t icstep_a = rs_a; + + const inc_t jrstep_c = cs_c * NR; + + const inc_t irstep_c = rs_c * MR; + + /* + const inc_t jrstep_b = cs_b * NR; + ( void )jrstep_b; + + const inc_t irstep_c = rs_c * MR; + const inc_t irstep_a = rs_a * MR; + */ + + dcomplex ct[ BLIS_STACK_BUF_MAX_SIZE / sizeof( dcomplex ) ] __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); + + /* Storage scheme of ct should be same as that of C. + Since update routines only support row-major order, + col_pref flag is used to induce transpose to matrices before + passing to update routine whenever C is col-stored */ + const bool col_pref = (rs_c == 1) ? 1 : 0; + + const inc_t rs_ct = ( col_pref ? 1 : NR ); + const inc_t cs_ct = ( col_pref ? MR : 1 ); + + dcomplex* restrict a_00 = a; + dcomplex* restrict b_00 = b; + dcomplex* restrict c_00 = c; + dcomplex* restrict alpha_cast = alpha; + dcomplex* restrict beta_cast = beta; + + /* Make local copies of beta and one scalars to prevent any unnecessary + sharing of caze lines between the cores' cazes. */ + dcomplex beta_local = *beta_cast; + dcomplex one_local = *PASTEMAC(z,1); + + auxinfo_t aux; + + /* Parse and interpret the contents of the rntm_t object to properly + set the ways of parallelism for each loop. */ + /*bli_rntm_set_ways_from_rntm_sup( m, n, k, rntm );*/ + + /* Initialize a mem_t entry for A and B. Strictly speaking, this is only + needed for the matrix we will be packing (if any), but we do it + unconditionally to be safe. An alternative way of initializing the + mem_t entries is: + + bli_mem_clear( &mem_a ); + bli_mem_clear( &mem_b ); + */ + mem_t mem_a = BLIS_MEM_INITIALIZER; + mem_t mem_b = BLIS_MEM_INITIALIZER; + + /* Define an array of bszid_t ids, which will act as our substitute for + the cntl_t tree. */ + /* 5thloop 4thloop packb 3rdloop packa 2ndloop 1stloop ukrloop */ + bszid_t bszids_nopack[6] = { BLIS_NC, BLIS_KC, BLIS_MC, BLIS_NR, BLIS_MR, BLIS_KR }; + bszid_t bszids_packa [7] = { BLIS_NC, BLIS_KC, BLIS_MC, BLIS_NO_PART, BLIS_NR, BLIS_MR, BLIS_KR }; + bszid_t bszids_packb [7] = { BLIS_NC, BLIS_KC, BLIS_NO_PART, BLIS_MC, BLIS_NR, BLIS_MR, BLIS_KR }; + bszid_t bszids_packab[8] = { BLIS_NC, BLIS_KC, BLIS_NO_PART, BLIS_MC, BLIS_NO_PART, BLIS_NR, BLIS_MR, BLIS_KR }; + bszid_t* restrict bszids; + + /* Set the bszids pointer to the correct bszids array above based on whiz + matrices (if any) are being packed. */ + if ( packa ) { if ( packb ) bszids = bszids_packab; + else bszids = bszids_packa; } + else { if ( packb ) bszids = bszids_packb; + else bszids = bszids_nopack; } + + /* Determine whether we are using more than one thread. */ + const bool is_mt = bli_rntm_calc_num_threads( rntm ); + + thrinfo_t* restrict thread_jc = NULL; + thrinfo_t* restrict thread_pc = NULL; + thrinfo_t* restrict thread_pb = NULL; + thrinfo_t* restrict thread_ic = NULL; + thrinfo_t* restrict thread_pa = NULL; + thrinfo_t* restrict thread_jr = NULL; + + /* Grow the thrinfo_t tree. */ + bszid_t* restrict bszids_jc = bszids; + thread_jc = thread; + bli_thrinfo_sup_grow( rntm, bszids_jc, thread_jc ); + + /* Compute the JC loop thread range for the current thread. */ + dim_t jc_start, jc_end; + bli_thread_range_weighted_sub( thread_jc, 0, BLIS_UPPER, m, n, NR, FALSE, &jc_start, &jc_end ); + const dim_t n_local = jc_end - jc_start; + + dim_t m_off = 0; + dim_t n_off = 0; + doff_t diagoffc; + dim_t m_off_cblock, n_off_cblock; + dim_t jp, j; + + /* Compute number of primary and leftover components of the JC loop. */ + /*const dim_t jc_iter = ( n_local + NC - 1 ) / NC;*/ + const dim_t jc_left = n_local % NC; + + /* Loop over the n dimension (NC rows/columns at a time). */ + /*for ( dim_t jj = 0; jj < jc_iter; jj += 1 )*/ + for ( dim_t jj = jc_start; jj < jc_end; jj += NC ) + { + /* Calculate the thread's current JC block dimension. */ + const dim_t nc_cur = ( NC <= jc_end - jj ? NC : jc_left ); + + dcomplex* restrict b_jc = b_00 + jj * jcstep_b; + dcomplex* restrict c_jc = c_00 + jj * jcstep_c; + + /* Grow the thrinfo_t tree. */ + bszid_t* restrict bszids_pc = &bszids_jc[1]; + thread_pc = bli_thrinfo_sub_node( thread_jc ); + bli_thrinfo_sup_grow( rntm, bszids_pc, thread_pc ); + + /* Compute the PC loop thread range for the current thread. */ + const dim_t pc_start = 0, pc_end = k; + const dim_t k_local = k; + + /* Compute number of primary and leftover components of the PC loop. */ + /*const dim_t pc_iter = ( k_local + KC - 1 ) / KC;*/ + const dim_t pc_left = k_local % KC; + + /* Loop over the k dimension (KC rows/columns at a time). */ + /*for ( dim_t pp = 0; pp < pc_iter; pp += 1 )*/ + for ( dim_t pp = pc_start; pp < pc_end; pp += KC ) + { + /* Calculate the thread's current PC block dimension. */ + const dim_t kc_cur = ( KC <= pc_end - pp ? KC : pc_left ); + + dcomplex* restrict a_pc = a_00 + pp * pcstep_a; + dcomplex* restrict b_pc = b_jc + pp * pcstep_b; + + /* Only apply beta to the first iteration of the pc loop. */ + dcomplex* restrict beta_use = ( pp == 0 ? &beta_local : &one_local ); + + m_off = 0; + n_off = jj; + diagoffc = m_off - n_off; + + dcomplex* b_use; + inc_t rs_b_use, cs_b_use, ps_b_use; + + /* Set the bszid_t array and thrinfo_t pointer based on whether + we will be packing B. If we won't be packing B, we alias to + the _pc variables so that code further down can unconditionally + reference the _pb variables. Note that *if* we will be packing + B, the thrinfo_t node will have already been created by a + previous call to bli_thrinfo_grow(), since bszid values of + BLIS_NO_PART cause the tree to grow by two (e.g. to the next + bszid that is a normal bszid_t value). */ + bszid_t* restrict bszids_pb; + if ( packb ) { bszids_pb = &bszids_pc[1]; + thread_pb = bli_thrinfo_sub_node( thread_pc ); } + else { bszids_pb = &bszids_pc[0]; + thread_pb = thread_pc; } + + /* Determine the packing buffer and related parameters for matrix + B. (If B will not be packed, then a_use will be set to point to + b and the _b_use strides will be set accordingly.) Then call + the packm sup variant chooser, which will call the appropriate + implementation based on the schema deduced from the stor_id. */ + PASTEMAC(z,packm_sup_b) + ( + packb, + BLIS_BUFFER_FOR_B_PANEL, /* This algorithm packs matrix B to */ + stor_id, /* a "panel of B." */ + BLIS_NO_TRANSPOSE, + KC, NC, /* This "panel of B" is (at most) KC x NC. */ + kc_cur, nc_cur, NR, + &one_local, + b_pc, rs_b, cs_b, + &b_use, &rs_b_use, &cs_b_use, + &ps_b_use, + cntx, + rntm, + &mem_b, + thread_pb + ); + + /* Alias a_use so that it's clear this is our current block of + matrix B. */ + dcomplex* restrict b_pc_use = b_use; + + /* We don't need to embed the panel stride of B within the auxinfo_t + object because this variant iterates through B in the jr loop, + whiz occurs here, within the macrokernel, not within the + millikernel. */ + /*bli_auxinfo_set_ps_b( ps_b_use, &aux );*/ + + /* Grow the thrinfo_t tree. */ + bszid_t* restrict bszids_ic = &bszids_pb[1]; + thread_ic = bli_thrinfo_sub_node( thread_pb ); + bli_thrinfo_sup_grow( rntm, bszids_ic, thread_ic ); + + /* Compute the IC loop thread range for the current thread. */ + dim_t ic_start, ic_end; + bli_thread_range_weighted_sub( thread_ic, -diagoffc, BLIS_LOWER, nc_cur, m, MR, FALSE, &ic_start, &ic_end ); + const dim_t m_local = ic_end - ic_start; + + /* Compute number of primary and leftover components of the IC loop. */ + /*const dim_t ic_iter = ( m_local + MC - 1 ) / MC;*/ + const dim_t ic_left = m_local % MC; + + /* Loop over the m dimension (MC rows at a time). */ + /*for ( dim_t ii = 0; ii < ic_iter; ii += 1 )*/ + for ( dim_t ii = ic_start; ii < ic_end; ii += MC ) + { + /* Calculate the thread's current IC block dimension. */ + dim_t mc_cur = ( MC <= ic_end - ii ? MC : ic_left ); + + dim_t nc_pruned = nc_cur; + + m_off = ii; + n_off = jj; + + if(bli_gemmt_is_strictly_below_diag(m_off, n_off, mc_cur, nc_cur)) continue; + + dcomplex* restrict a_ic = a_pc + ii * icstep_a; + dcomplex* restrict c_ic = c_jc + ii * icstep_c; + + doff_t diagoffc = m_off - n_off; + + dcomplex* restrict b_pc_pruned = b_pc_use; + + if(diagoffc > 0 ) + { + jp = diagoffc / NR; + j = jp * NR; + nc_pruned = nc_cur - j; + n_off += j; + diagoffc = diagoffc % NR; + c_ic = c_ic + ( j ) * cs_c; + b_pc_pruned = b_pc_use + ( jp ) * ps_b_use; + } + + if( ( ( -diagoffc ) + nc_pruned ) < mc_cur ) + { + mc_cur = -diagoffc + nc_pruned; + } + + dcomplex* a_use; + inc_t rs_a_use, cs_a_use, ps_a_use; + + /* Set the bszid_t array and thrinfo_t pointer based on whether + we will be packing B. If we won't be packing A, we alias to + the _ic variables so that code further down can unconditionally + reference the _pa variables. Note that *if* we will be packing + A, the thrinfo_t node will have already been created by a + previous call to bli_thrinfo_grow(), since bszid values of + BLIS_NO_PART cause the tree to grow by two (e.g. to the next + bszid that is a normal bszid_t value). */ + bszid_t* restrict bszids_pa; + if ( packa ) { bszids_pa = &bszids_ic[1]; + thread_pa = bli_thrinfo_sub_node( thread_ic ); } + else { bszids_pa = &bszids_ic[0]; + thread_pa = thread_ic; } + + /* Determine the packing buffer and related parameters for matrix + A. (If A will not be packed, then a_use will be set to point to + a and the _a_use strides will be set accordingly.) Then call + the packm sup variant chooser, which will call the appropriate + implementation based on the schema deduced from the stor_id. */ + PASTEMAC(z,packm_sup_a) + ( + packa, + BLIS_BUFFER_FOR_A_BLOCK, /* This algorithm packs matrix A to */ + stor_id, /* a "block of A." */ + BLIS_NO_TRANSPOSE, + MC, KC, /* This "block of A" is (at most) MC x KC. */ + mc_cur, kc_cur, MR, + &one_local, + a_ic, rs_a, cs_a, + &a_use, &rs_a_use, &cs_a_use, + &ps_a_use, + cntx, + rntm, + &mem_a, + thread_pa + ); + + /* Alias a_use so that it's clear this is our current block of + matrix A. */ + dcomplex* restrict a_ic_use = a_use; + + /* Embed the panel stride of A within the auxinfo_t object. The + millikernel will query and use this to iterate through + micropanels of A (if needed). */ + bli_auxinfo_set_ps_a( ps_a_use, &aux ); + + /* Grow the thrinfo_t tree. */ + bszid_t* restrict bszids_jr = &bszids_pa[1]; + thread_jr = bli_thrinfo_sub_node( thread_pa ); + bli_thrinfo_sup_grow( rntm, bszids_jr, thread_jr ); + + /* Compute number of primary and leftover components of the JR loop. */ + dim_t jr_iter = ( nc_pruned + NR - 1 ) / NR; + dim_t jr_left = nc_pruned % NR; + + /* Compute the JR loop thread range for the current thread. */ + dim_t jr_start, jr_end; + bli_thread_range_sub( thread_jr, jr_iter, 1, FALSE, &jr_start, &jr_end ); + + /* An optimization: allow the last jr iteration to contain up to NRE + columns of C and B. (If NRE > NR, the mkernel has agreed to handle + these cases.) Note that this prevents us from declaring jr_iter and + jr_left as const. NOTE: We forgo this optimization when packing B + since packing an extended edge case is not yet supported. */ + if ( !packb && !is_mt ) + if ( NRE != 0 && 1 < jr_iter && jr_left != 0 && jr_left <= NRE ) + { + jr_iter--; jr_left += NR; + } + + /* Loop over the n dimension (NR columns at a time). */ + /*for ( dim_t j = 0; j < jr_iter; j += 1 )*/ + for ( dim_t j = jr_start; j < jr_end; j += 1 ) + { + const dim_t nr_cur = ( bli_is_not_edge_f( j, jr_iter, jr_left ) ? NR : jr_left ); + + /* + dcomplex* restrict b_jr = b_pc_use + j * jrstep_b; + */ + dcomplex* restrict b_jr = b_pc_pruned + j * ps_b_use; + dcomplex* restrict c_jr = c_ic + j * jrstep_c; + dim_t m_rect = 0; + dim_t n_iter_rect = 0; + + m_off_cblock = m_off; + n_off_cblock = n_off + j * NR; + + if(bli_gemmt_is_strictly_above_diag(m_off_cblock, n_off_cblock, mc_cur, nr_cur)) + { + m_rect = mc_cur; + } + else + { + /* calculate the number of rows in rectangular region of the block */ + n_iter_rect = n_off_cblock < m_off_cblock ? 0: (n_off_cblock - m_off_cblock) / MR; + m_rect = n_iter_rect * MR; + } + + /* Compute the rectangular part */ + gemmsup_ker + ( + conja, + conjb, + m_rect, + nr_cur, + kc_cur, + alpha_cast, + a_ic_use, rs_a_use, cs_a_use, + b_jr, rs_b_use, cs_b_use, + beta_use, + c_jr, rs_c, cs_c, + &aux, + cntx + ); + + m_off_cblock = m_off + m_rect; + + dcomplex* restrict a_ir = a_ic_use + n_iter_rect * ps_a_use; + dcomplex* restrict c_ir = c_jr + n_iter_rect * irstep_c; + + /* compute the remaining triangular part */ + for( dim_t i = m_rect;( i < mc_cur) && (m_off_cblock < n_off_cblock + nr_cur); i += MR ) + { + const dim_t mr_cur = (i+MR-1) < mc_cur ? MR : mc_cur - i; + UPPER_TRIANGLE_OPTIMIZATION_DCOMPLEX() + { + gemmsup_ker + ( + conja, + conjb, + mr_cur, + nr_cur, + kc_cur, + alpha_cast, + a_ir, rs_a_use, cs_a_use, + b_jr, rs_b_use, cs_b_use, + zero, + ct, rs_ct, cs_ct, + &aux, + cntx + ); + + if( col_pref ) + { + PASTEMAC(z,update_lower_triang)( n_off_cblock, m_off_cblock, + nr_cur, mr_cur, + ct, cs_ct, rs_ct, + beta_use, + c_ir, cs_c, rs_c ); + } + else + { + PASTEMAC(z,update_upper_triang)( m_off_cblock, n_off_cblock, + mr_cur, nr_cur, + ct, rs_ct, cs_ct, + beta_use, + c_ir, rs_c, cs_c ); + } + } + + a_ir += ps_a_use; + c_ir += irstep_c; + m_off_cblock += mr_cur; + + } + } + } + + /* NOTE: This barrier is only needed if we are packing B (since + that matrix is packed within the pc loop of this variant). */ + if ( packb ) bli_thread_barrier( thread_pb ); + } + } + + /* Release any memory that was acquired for packing matrices A and B. */ + PASTEMAC(z,packm_sup_finalize_mem_a) + ( + packa, + rntm, + &mem_a, + thread_pa + ); + PASTEMAC(z,packm_sup_finalize_mem_b) + ( + packb, + rntm, + &mem_b, + thread_pb + ); + +/* +PASTEMAC(z,fprintm)( stdout, "gemmsup_ref_var2: b1", kc_cur, nr_cur, b_jr, rs_b, cs_b, "%4.1f", "" ); +PASTEMAC(z,fprintm)( stdout, "gemmsup_ref_var2: a1", mr_cur, kc_cur, a_ir, rs_a, cs_a, "%4.1f", "" ); +PASTEMAC(z,fprintm)( stdout, "gemmsup_ref_var2: c ", mr_cur, nr_cur, c_ir, rs_c, cs_c, "%4.1f", "" ); +*/ +} diff --git a/frame/3/hemm/bli_hemm_front.c b/frame/3/hemm/bli_hemm_front.c index a9878e0f9e..a9dd543511 100644 --- a/frame/3/hemm/bli_hemm_front.c +++ b/frame/3/hemm/bli_hemm_front.c @@ -54,10 +54,6 @@ void bli_hemm_front obj_t b_local; obj_t c_local; - // Check parameters. - if ( bli_error_checking_is_enabled() ) - bli_hemm_check( side, alpha, a, b, beta, c, cntx ); - // If alpha is zero, scale by beta and return. if ( bli_obj_equals( alpha, &BLIS_ZERO ) ) { @@ -131,6 +127,9 @@ void bli_hemm_front } #endif + // Set the pack schemas within the objects. + bli_l3_set_schemas( &a_local, &b_local, &c_local, cntx ); + // Set each alias as the root object. // NOTE: We MUST wait until we are done potentially swapping the objects // before setting the root fields! @@ -151,17 +150,6 @@ void bli_hemm_front rntm ); - // A sort of hack for communicating the desired pack schemas for A and B - // to bli_gemm_cntl_create() (via bli_l3_thread_decorator() and - // bli_l3_cntl_create_if()). This allows us to access the schemas from - // the control tree, which hopefully reduces some confusion, particularly - // in bli_packm_init(). - pack_t schema_a = bli_cntx_schema_a_block( cntx ); - pack_t schema_b = bli_cntx_schema_b_panel( cntx ); - - bli_obj_set_pack_schema( schema_a, &a_local ); - bli_obj_set_pack_schema( schema_b, &b_local ); - // Invoke the internal back-end. bli_l3_thread_decorator ( diff --git a/frame/3/her2k/bli_her2k_front.c b/frame/3/her2k/bli_her2k_front.c index 39d4dfc0d6..dbe672eb6f 100644 --- a/frame/3/her2k/bli_her2k_front.c +++ b/frame/3/her2k/bli_her2k_front.c @@ -56,10 +56,6 @@ void bli_her2k_front obj_t b_local; obj_t ah_local; - // Check parameters. - if ( bli_error_checking_is_enabled() ) - bli_her2k_check( alpha, a, b, beta, c, cntx ); - // If alpha is zero, scale by beta, zero the imaginary components of // the diagonal elements, and return. if ( bli_obj_equals( alpha, &BLIS_ZERO ) ) @@ -84,12 +80,6 @@ void bli_her2k_front bli_obj_induce_trans( &ah_local ); bli_obj_toggle_conj( &ah_local ); - // Initialize a conjugated copy of alpha. - bli_obj_scalar_init_detached_copy_of( bli_obj_dt( a ), - BLIS_CONJUGATE, - alpha, - &alpha_conj ); - // An optimization: If C is stored by rows and the micro-kernel prefers // contiguous columns, or if C is stored by columns and the micro-kernel // prefers contiguous rows, transpose the entire operation to allow the @@ -107,6 +97,16 @@ void bli_her2k_front bli_obj_induce_trans( &c_local ); } + // Set the pack schemas within the objects. + bli_l3_set_schemas( &a_local, &bh_local, &c_local, cntx ); + bli_l3_set_schemas( &b_local, &ah_local, &c_local, cntx ); + + // Initialize a conjugated copy of alpha. + bli_obj_scalar_init_detached_copy_of( bli_obj_dt( a ), + BLIS_CONJUGATE, + alpha, + &alpha_conj ); + // Parse and interpret the contents of the rntm_t object to properly // set the ways of parallelism for each loop, and then make any // additional modifications necessary for the current operation. @@ -120,19 +120,6 @@ void bli_her2k_front rntm ); - // A sort of hack for communicating the desired pack schemas for A and B - // to bli_gemm_cntl_create() (via bli_l3_thread_decorator() and - // bli_l3_cntl_create_if()). This allows us to access the schemas from - // the control tree, which hopefully reduces some confusion, particularly - // in bli_packm_init(). - pack_t schema_a = bli_cntx_schema_a_block( cntx ); - pack_t schema_b = bli_cntx_schema_b_panel( cntx ); - - bli_obj_set_pack_schema( schema_a, &a_local ); - bli_obj_set_pack_schema( schema_b, &bh_local ); - bli_obj_set_pack_schema( schema_a, &b_local ); - bli_obj_set_pack_schema( schema_b, &ah_local ); - // Invoke herk twice, using beta only the first time. // Invoke the internal back-end. diff --git a/frame/3/herk/bli_herk_front.c b/frame/3/herk/bli_herk_front.c index 9ba19b3a36..ffdef07d67 100644 --- a/frame/3/herk/bli_herk_front.c +++ b/frame/3/herk/bli_herk_front.c @@ -52,10 +52,6 @@ void bli_herk_front obj_t ah_local; obj_t c_local; - // Check parameters. - if ( bli_error_checking_is_enabled() ) - bli_herk_check( alpha, a, beta, c, cntx ); - // If alpha is zero, scale by beta, zero the imaginary components of // the diagonal elements, and return. if ( bli_obj_equals( alpha, &BLIS_ZERO ) ) @@ -87,6 +83,9 @@ void bli_herk_front bli_obj_induce_trans( &c_local ); } + // Set the pack schemas within the objects. + bli_l3_set_schemas( &a_local, &ah_local, &c_local, cntx ); + // Parse and interpret the contents of the rntm_t object to properly // set the ways of parallelism for each loop, and then make any // additional modifications necessary for the current operation. @@ -100,17 +99,6 @@ void bli_herk_front rntm ); - // A sort of hack for communicating the desired pack schemas for A and B - // to bli_gemm_cntl_create() (via bli_l3_thread_decorator() and - // bli_l3_cntl_create_if()). This allows us to access the schemas from - // the control tree, which hopefully reduces some confusion, particularly - // in bli_packm_init(). - pack_t schema_a = bli_cntx_schema_a_block( cntx ); - pack_t schema_b = bli_cntx_schema_b_panel( cntx ); - - bli_obj_set_pack_schema( schema_a, &a_local ); - bli_obj_set_pack_schema( schema_b, &ah_local ); - // Invoke the internal back-end. bli_l3_thread_decorator ( diff --git a/frame/3/herk/bli_herk_l_ker_var2.c b/frame/3/herk/bli_herk_l_ker_var2.c index 1f5f544b45..d2f9ee2dbb 100644 --- a/frame/3/herk/bli_herk_l_ker_var2.c +++ b/frame/3/herk/bli_herk_l_ker_var2.c @@ -279,9 +279,6 @@ void PASTEMAC(ch,varname) \ /* Save the imaginary stride of A and B to the auxinfo_t object. */ \ bli_auxinfo_set_is_a( is_a, &aux ); \ bli_auxinfo_set_is_b( is_b, &aux ); \ -\ - /* Save the desired output datatype (indicating no typecasting). */ \ - /*bli_auxinfo_set_dt_on_output( dt, &aux );*/ \ \ /* The 'thread' argument points to the thrinfo_t node for the 2nd (jr) loop around the microkernel. Here we query the thrinfo_t node for the diff --git a/frame/3/herk/bli_herk_u_ker_var2.c b/frame/3/herk/bli_herk_u_ker_var2.c index d84fadecb9..df29cb9a6c 100644 --- a/frame/3/herk/bli_herk_u_ker_var2.c +++ b/frame/3/herk/bli_herk_u_ker_var2.c @@ -281,9 +281,6 @@ void PASTEMAC(ch,varname) \ /* Save the imaginary stride of A and B to the auxinfo_t object. */ \ bli_auxinfo_set_is_a( is_a, &aux ); \ bli_auxinfo_set_is_b( is_b, &aux ); \ -\ - /* Save the desired output datatype (indicating no typecasting). */ \ - /*bli_auxinfo_set_dt_on_output( dt, &aux );*/ \ \ /* The 'thread' argument points to the thrinfo_t node for the 2nd (jr) loop around the microkernel. Here we query the thrinfo_t node for the diff --git a/frame/3/symm/bli_symm_front.c b/frame/3/symm/bli_symm_front.c index a395a1c1c6..f56b238d55 100644 --- a/frame/3/symm/bli_symm_front.c +++ b/frame/3/symm/bli_symm_front.c @@ -54,10 +54,6 @@ void bli_symm_front obj_t b_local; obj_t c_local; - // Check parameters. - if ( bli_error_checking_is_enabled() ) - bli_symm_check( side, alpha, a, b, beta, c, cntx ); - // If alpha is zero, scale by beta and return. if ( bli_obj_equals( alpha, &BLIS_ZERO ) ) { @@ -130,6 +126,9 @@ void bli_symm_front } #endif + // Set the pack schemas within the objects. + bli_l3_set_schemas( &a_local, &b_local, &c_local, cntx ); + // Set each alias as the root object. // NOTE: We MUST wait until we are done potentially swapping the objects // before setting the root fields! @@ -150,17 +149,6 @@ void bli_symm_front rntm ); - // A sort of hack for communicating the desired pack schemas for A and B - // to bli_gemm_cntl_create() (via bli_l3_thread_decorator() and - // bli_l3_cntl_create_if()). This allows us to access the schemas from - // the control tree, which hopefully reduces some confusion, particularly - // in bli_packm_init(). - pack_t schema_a = bli_cntx_schema_a_block( cntx ); - pack_t schema_b = bli_cntx_schema_b_panel( cntx ); - - bli_obj_set_pack_schema( schema_a, &a_local ); - bli_obj_set_pack_schema( schema_b, &b_local ); - // Invoke the internal back-end. bli_l3_thread_decorator ( diff --git a/frame/3/syr2k/bli_syr2k_front.c b/frame/3/syr2k/bli_syr2k_front.c index dfd1f575a5..3999c7cf4d 100644 --- a/frame/3/syr2k/bli_syr2k_front.c +++ b/frame/3/syr2k/bli_syr2k_front.c @@ -55,10 +55,6 @@ void bli_syr2k_front obj_t b_local; obj_t at_local; - // Check parameters. - if ( bli_error_checking_is_enabled() ) - bli_syr2k_check( alpha, a, b, beta, c, cntx ); - // If alpha is zero, scale by beta and return. if ( bli_obj_equals( alpha, &BLIS_ZERO ) ) { @@ -88,6 +84,10 @@ void bli_syr2k_front bli_obj_induce_trans( &c_local ); } + // Set the pack schemas within the objects. + bli_l3_set_schemas( &a_local, &bt_local, &c_local, cntx ); + bli_l3_set_schemas( &b_local, &at_local, &c_local, cntx ); + // Parse and interpret the contents of the rntm_t object to properly // set the ways of parallelism for each loop, and then make any // additional modifications necessary for the current operation. @@ -101,19 +101,6 @@ void bli_syr2k_front rntm ); - // A sort of hack for communicating the desired pack schemas for A and B - // to bli_gemm_cntl_create() (via bli_l3_thread_decorator() and - // bli_l3_cntl_create_if()). This allows us to access the schemas from - // the control tree, which hopefully reduces some confusion, particularly - // in bli_packm_init(). - pack_t schema_a = bli_cntx_schema_a_block( cntx ); - pack_t schema_b = bli_cntx_schema_b_panel( cntx ); - - bli_obj_set_pack_schema( schema_a, &a_local ); - bli_obj_set_pack_schema( schema_b, &bt_local ); - bli_obj_set_pack_schema( schema_a, &b_local ); - bli_obj_set_pack_schema( schema_b, &at_local ); - // Invoke herk twice, using beta only the first time. // Invoke the internal back-end. diff --git a/frame/3/syrk/bli_syrk_front.c b/frame/3/syrk/bli_syrk_front.c index d0b2a14f1b..e773bc4f45 100644 --- a/frame/3/syrk/bli_syrk_front.c +++ b/frame/3/syrk/bli_syrk_front.c @@ -61,9 +61,13 @@ void bli_syrk_front bli_obj_alias_to( a, &at_local ); bli_obj_induce_trans( &at_local ); - // Check parameters. - if ( bli_error_checking_is_enabled() ) - bli_syrk_check( alpha, a, beta, c, cntx ); +#if 0 +#ifdef BLIS_ENABLE_SMALL_MATRIX + gint_t status = bli_syrk_small( alpha, &a_local, &at_local, beta, &c_local, + cntx, cntl ); + if ( status == BLIS_SUCCESS ) return; +#endif +#endif // If alpha is zero, scale by beta and return. if ( bli_obj_equals( alpha, &BLIS_ZERO ) ) @@ -81,6 +85,9 @@ void bli_syrk_front bli_obj_induce_trans( &c_local ); } + // Set the pack schemas within the objects. + bli_l3_set_schemas( &a_local, &at_local, &c_local, cntx ); + // Parse and interpret the contents of the rntm_t object to properly // set the ways of parallelism for each loop, and then make any // additional modifications necessary for the current operation. @@ -94,17 +101,6 @@ void bli_syrk_front rntm ); - // A sort of hack for communicating the desired pack schemas for A and B - // to bli_gemm_cntl_create() (via bli_l3_thread_decorator() and - // bli_l3_cntl_create_if()). This allows us to access the schemas from - // the control tree, which hopefully reduces some confusion, particularly - // in bli_packm_init(). - pack_t schema_a = bli_cntx_schema_a_block( cntx ); - pack_t schema_b = bli_cntx_schema_b_panel( cntx ); - - bli_obj_set_pack_schema( schema_a, &a_local ); - bli_obj_set_pack_schema( schema_b, &at_local ); - // Invoke the internal back-end. bli_l3_thread_decorator ( diff --git a/frame/3/trmm/bli_trmm_front.c b/frame/3/trmm/bli_trmm_front.c index 852e3fdef7..f941045e4a 100644 --- a/frame/3/trmm/bli_trmm_front.c +++ b/frame/3/trmm/bli_trmm_front.c @@ -52,10 +52,6 @@ void bli_trmm_front obj_t b_local; obj_t c_local; - // Check parameters. - if ( bli_error_checking_is_enabled() ) - bli_trmm_check( side, alpha, a, b, &BLIS_ZERO, b, cntx ); - // If alpha is zero, scale by beta and return. if ( bli_obj_equals( alpha, &BLIS_ZERO ) ) { @@ -148,6 +144,9 @@ void bli_trmm_front #endif + // Set the pack schemas within the objects. + bli_l3_set_schemas( &a_local, &b_local, &c_local, cntx ); + // Set each alias as the root object. // NOTE: We MUST wait until we are done potentially swapping the objects // before setting the root fields! @@ -168,17 +167,6 @@ void bli_trmm_front rntm ); - // A sort of hack for communicating the desired pack schemas for A and B - // to bli_gemm_cntl_create() (via bli_l3_thread_decorator() and - // bli_l3_cntl_create_if()). This allows us to access the schemas from - // the control tree, which hopefully reduces some confusion, particularly - // in bli_packm_init(). - pack_t schema_a = bli_cntx_schema_a_block( cntx ); - pack_t schema_b = bli_cntx_schema_b_panel( cntx ); - - bli_obj_set_pack_schema( schema_a, &a_local ); - bli_obj_set_pack_schema( schema_b, &b_local ); - // Invoke the internal back-end. bli_l3_thread_decorator ( diff --git a/frame/3/trmm/bli_trmm_front_amd.c b/frame/3/trmm/bli_trmm_front_amd.c index 534df6fec5..287f0ced42 100644 --- a/frame/3/trmm/bli_trmm_front_amd.c +++ b/frame/3/trmm/bli_trmm_front_amd.c @@ -54,7 +54,7 @@ void bli_trmm_front // Check parameters. if ( bli_error_checking_is_enabled() ) - bli_trmm_check( side, alpha, a, b, &BLIS_ZERO, b, cntx ); + bli_trmm_check( side, alpha, a, b, cntx ); // If alpha is zero, scale by beta and return. if ( bli_obj_equals( alpha, &BLIS_ZERO ) ) @@ -148,6 +148,9 @@ void bli_trmm_front #endif + // Set the pack schemas within the objects. + bli_l3_set_schemas( &a_local, &b_local, &c_local, cntx ); + // Set each alias as the root object. // NOTE: We MUST wait until we are done potentially swapping the objects // before setting the root fields! @@ -177,17 +180,6 @@ void bli_trmm_front rntm ); - // A sort of hack for communicating the desired pack schemas for A and B - // to bli_gemm_cntl_create() (via bli_l3_thread_decorator() and - // bli_l3_cntl_create_if()). This allows us to access the schemas from - // the control tree, which hopefully reduces some confusion, particularly - // in bli_packm_init(). - pack_t schema_a = bli_cntx_schema_a_block( cntx ); - pack_t schema_b = bli_cntx_schema_b_panel( cntx ); - - bli_obj_set_pack_schema( schema_a, &a_local ); - bli_obj_set_pack_schema( schema_b, &b_local ); - // Invoke the internal back-end. bli_l3_thread_decorator ( diff --git a/frame/3/trmm/bli_trmm_ll_ker_var2.c b/frame/3/trmm/bli_trmm_ll_ker_var2.c index e7cdd5f1f8..564f76cf4e 100644 --- a/frame/3/trmm/bli_trmm_ll_ker_var2.c +++ b/frame/3/trmm/bli_trmm_ll_ker_var2.c @@ -203,9 +203,6 @@ void PASTEMAC(ch,varname) \ inc_t rstep_c, cstep_c; \ inc_t istep_a; \ inc_t istep_b; \ - inc_t off_scl; \ - inc_t ss_a_num; \ - inc_t ss_a_den; \ inc_t ps_a_cur; \ inc_t is_a_cur; \ auxinfo_t aux; \ @@ -243,30 +240,6 @@ void PASTEMAC(ch,varname) \ matrix), which is used by 4m1/3m1 implementations, we need this unreduced value of k. */ \ k_full = k; \ -\ - /* Compute indexing scaling factor for for 4m or 3m. This is - needed because one of the packing register blocksizes (PACKMR - or PACKNR) is used to index into the micro-panels of the non- - triangular matrix when computing with a diagonal-intersecting - micro-panel of the triangular matrix. In the case of 4m or 3m, - real values are stored in both sub-panels, and so the indexing - needs to occur in units of real values. The value computed - here is divided into the complex pointer offset to cause the - pointer to be advanced by the correct value. */ \ - if ( bli_is_4mi_packed( schema_a ) || \ - bli_is_3mi_packed( schema_a ) || \ - bli_is_rih_packed( schema_a ) ) off_scl = 2; \ - else off_scl = 1; \ -\ - /* Compute the storage stride scaling. Usually this is just 1. - However, in the case of interleaved 3m, we need to scale the - offset by 3/2. And if we are packing real-only, imag-only, or - summed-only, we need to scale the computed panel sizes by 1/2 - to compensate for the fact that the pointer arithmetic occurs - in terms of complex elements rather than real elements. */ \ - if ( bli_is_3mi_packed( schema_a ) ) { ss_a_num = 3; ss_a_den = 2; } \ - else if ( bli_is_rih_packed( schema_a ) ) { ss_a_num = 1; ss_a_den = 2; } \ - else { ss_a_num = 1; ss_a_den = 1; } \ \ /* If there is a zero region above where the diagonal of A intersects the left edge of the block, adjust the pointer to C and treat this case as @@ -317,9 +290,6 @@ void PASTEMAC(ch,varname) \ \ /* Save the imaginary stride of B to the auxinfo_t object. */ \ bli_auxinfo_set_is_b( istep_b, &aux ); \ -\ - /* Save the desired output datatype (indicating no typecasting). */ \ - /*bli_auxinfo_set_dt_on_output( dt, &aux );*/ \ \ /* The 'thread' argument points to the thrinfo_t node for the 2nd (jr) loop around the microkernel. Here we query the thrinfo_t node for the @@ -387,12 +357,12 @@ void PASTEMAC(ch,varname) \ intersecting micro-panel. */ \ is_a_cur = k_a1011 * PACKMR; \ is_a_cur += ( bli_is_odd( is_a_cur ) ? 1 : 0 ); \ - ps_a_cur = ( is_a_cur * ss_a_num ) / ss_a_den; \ + ps_a_cur = is_a_cur; \ \ /* NOTE: ir loop parallelism disabled for now. */ \ /*if ( bli_trmm_my_iter( i, ir_thread ) ) {*/ \ \ - b1_i = b1 + ( off_a1011 * PACKNR ) / off_scl; \ + b1_i = b1 + off_a1011 * PACKNR; \ \ /* Compute the addresses of the next panels of A and B. */ \ a2 = a1; \ @@ -408,10 +378,6 @@ void PASTEMAC(ch,varname) \ object. */ \ bli_auxinfo_set_next_a( a2, &aux ); \ bli_auxinfo_set_next_b( b2, &aux ); \ -\ - /* Save the 4m1/3m1 imaginary stride of A to the auxinfo_t - object. */ \ - bli_auxinfo_set_is_a( is_a_cur, &aux ); \ \ /* Handle interior and edge cases separately. */ \ if ( m_cur == MR && n_cur == NR ) \ @@ -479,10 +445,6 @@ void PASTEMAC(ch,varname) \ object. */ \ bli_auxinfo_set_next_a( a2, &aux ); \ bli_auxinfo_set_next_b( b2, &aux ); \ -\ - /* Save the 4m1/3m1 imaginary stride of A to the auxinfo_t - object. */ \ - bli_auxinfo_set_is_a( istep_a, &aux ); \ \ /* Handle interior and edge cases separately. */ \ if ( m_cur == MR && n_cur == NR ) \ diff --git a/frame/3/trmm/bli_trmm_lu_ker_var2.c b/frame/3/trmm/bli_trmm_lu_ker_var2.c index 7cf21b07f0..56b848fc0f 100644 --- a/frame/3/trmm/bli_trmm_lu_ker_var2.c +++ b/frame/3/trmm/bli_trmm_lu_ker_var2.c @@ -203,9 +203,6 @@ void PASTEMAC(ch,varname) \ inc_t rstep_c, cstep_c; \ inc_t istep_a; \ inc_t istep_b; \ - inc_t off_scl; \ - inc_t ss_a_num; \ - inc_t ss_a_den; \ inc_t ps_a_cur; \ inc_t is_a_cur; \ auxinfo_t aux; \ @@ -243,30 +240,6 @@ void PASTEMAC(ch,varname) \ matrix), which is used by 4m1/3m1 implementations, we need this unreduced value of k. */ \ k_full = k; \ -\ - /* Compute indexing scaling factor for for 4m or 3m. This is - needed because one of the packing register blocksizes (PACKMR - or PACKNR) is used to index into the micro-panels of the non- - triangular matrix when computing with a diagonal-intersecting - micro-panel of the triangular matrix. In the case of 4m or 3m, - real values are stored in both sub-panels, and so the indexing - needs to occur in units of real values. The value computed - here is divided into the complex pointer offset to cause the - pointer to be advanced by the correct value. */ \ - if ( bli_is_4mi_packed( schema_a ) || \ - bli_is_3mi_packed( schema_a ) || \ - bli_is_rih_packed( schema_a ) ) off_scl = 2; \ - else off_scl = 1; \ -\ - /* Compute the storage stride scaling. Usually this is just 1. - However, in the case of interleaved 3m, we need to scale the - offset by 3/2. And if we are packing real-only, imag-only, or - summed-only, we need to scale the computed panel sizes by 1/2 - to compensate for the fact that the pointer arithmetic occurs - in terms of complex elements rather than real elements. */ \ - if ( bli_is_3mi_packed( schema_a ) ) { ss_a_num = 3; ss_a_den = 2; } \ - else if ( bli_is_rih_packed( schema_a ) ) { ss_a_num = 1; ss_a_den = 2; } \ - else { ss_a_num = 1; ss_a_den = 1; } \ \ /* If there is a zero region to the left of where the diagonal of A intersects the top edge of the block, adjust the pointer to B and @@ -278,7 +251,7 @@ void PASTEMAC(ch,varname) \ i = diagoffa; \ k = k - i; \ diagoffa = 0; \ - b_cast = b_cast + ( i * PACKNR ) / off_scl; \ + b_cast = b_cast + i * PACKNR; \ } \ \ /* If there is a zero region below where the diagonal of A intersects the @@ -324,9 +297,6 @@ void PASTEMAC(ch,varname) \ \ /* Save the imaginary stride of B to the auxinfo_t object. */ \ bli_auxinfo_set_is_b( istep_b, &aux ); \ -\ - /* Save the desired output datatype (indicating no typecasting). */ \ - /*bli_auxinfo_set_dt_on_output( dt, &aux );*/ \ \ /* The 'thread' argument points to the thrinfo_t node for the 2nd (jr) loop around the microkernel. Here we query the thrinfo_t node for the @@ -394,12 +364,12 @@ void PASTEMAC(ch,varname) \ intersecting micro-panel. */ \ is_a_cur = k_a1112 * PACKMR; \ is_a_cur += ( bli_is_odd( is_a_cur ) ? 1 : 0 ); \ - ps_a_cur = ( is_a_cur * ss_a_num ) / ss_a_den; \ + ps_a_cur = is_a_cur; \ \ /* NOTE: ir loop parallelism disabled for now. */ \ /*if ( bli_trmm_my_iter( i, ir_thread ) ) {*/ \ \ - b1_i = b1 + ( off_a1112 * PACKNR ) / off_scl; \ + b1_i = b1 + off_a1112 * PACKNR; \ \ /* Compute the addresses of the next panels of A and B. */ \ a2 = a1; \ @@ -415,10 +385,6 @@ void PASTEMAC(ch,varname) \ object. */ \ bli_auxinfo_set_next_a( a2, &aux ); \ bli_auxinfo_set_next_b( b2, &aux ); \ -\ - /* Save the 4m1/3m1 imaginary stride of A to the auxinfo_t - object. */ \ - bli_auxinfo_set_is_a( is_a_cur, &aux ); \ \ /* Handle interior and edge cases separately. */ \ if ( m_cur == MR && n_cur == NR ) \ @@ -486,10 +452,6 @@ void PASTEMAC(ch,varname) \ object. */ \ bli_auxinfo_set_next_a( a2, &aux ); \ bli_auxinfo_set_next_b( b2, &aux ); \ -\ - /* Save the 4m1/3m1 imaginary stride of A to the auxinfo_t - object. */ \ - bli_auxinfo_set_is_a( istep_a, &aux ); \ \ /* Handle interior and edge cases separately. */ \ if ( m_cur == MR && n_cur == NR ) \ diff --git a/frame/3/trmm/bli_trmm_rl_ker_var2.c b/frame/3/trmm/bli_trmm_rl_ker_var2.c index 737b7d2ed2..cd3c9696f3 100644 --- a/frame/3/trmm/bli_trmm_rl_ker_var2.c +++ b/frame/3/trmm/bli_trmm_rl_ker_var2.c @@ -203,9 +203,6 @@ void PASTEMAC(ch,varname) \ inc_t rstep_c, cstep_c; \ inc_t istep_a; \ inc_t istep_b; \ - inc_t off_scl; \ - inc_t ss_b_num; \ - inc_t ss_b_den; \ inc_t ps_b_cur; \ inc_t is_b_cur; \ auxinfo_t aux; \ @@ -243,30 +240,6 @@ void PASTEMAC(ch,varname) \ matrix), which is used by 4m1/3m1 implementations, we need this unreduced value of k. */ \ k_full = k; \ -\ - /* Compute indexing scaling factor for for 4m or 3m. This is - needed because one of the packing register blocksizes (PACKMR - or PACKNR) is used to index into the micro-panels of the non- - triangular matrix when computing with a diagonal-intersecting - micro-panel of the triangular matrix. In the case of 4m or 3m, - real values are stored in both sub-panels, and so the indexing - needs to occur in units of real values. The value computed - here is divided into the complex pointer offset to cause the - pointer to be advanced by the correct value. */ \ - if ( bli_is_4mi_packed( schema_b ) || \ - bli_is_3mi_packed( schema_b ) || \ - bli_is_rih_packed( schema_b ) ) off_scl = 2; \ - else off_scl = 1; \ -\ - /* Compute the storage stride scaling. Usually this is just 1. - However, in the case of interleaved 3m, we need to scale the - offset by 3/2. And if we are packing real-only, imag-only, or - summed-only, we need to scale the computed panel sizes by 1/2 - to compensate for the fact that the pointer arithmetic occurs - in terms of complex elements rather than real elements. */ \ - if ( bli_is_3mi_packed( schema_b ) ) { ss_b_num = 3; ss_b_den = 2; } \ - else if ( bli_is_rih_packed( schema_b ) ) { ss_b_num = 1; ss_b_den = 2; } \ - else { ss_b_num = 1; ss_b_den = 1; } \ \ /* If there is a zero region above where the diagonal of B intersects the left edge of the panel, adjust the pointer to A and treat this @@ -278,7 +251,7 @@ void PASTEMAC(ch,varname) \ j = -diagoffb; \ k = k - j; \ diagoffb = 0; \ - a_cast = a_cast + ( j * PACKMR ) / off_scl; \ + a_cast = a_cast + j * PACKMR; \ } \ \ /* If there is a zero region to the right of where the diagonal @@ -324,9 +297,6 @@ void PASTEMAC(ch,varname) \ \ /* Save the imaginary stride of A to the auxinfo_t object. */ \ bli_auxinfo_set_is_a( istep_a, &aux ); \ -\ - /* Save the desired output datatype (indicating no typecasting). */ \ - /*bli_auxinfo_set_dt_on_output( dt, &aux );*/ \ \ thrinfo_t* caucus = bli_thrinfo_sub_node( thread ); \ \ @@ -387,10 +357,6 @@ void PASTEMAC(ch,varname) \ b2 = b1; \ \ { \ - /* Save the 4m1/3m1 imaginary stride of B to the auxinfo_t - object. */ \ - bli_auxinfo_set_is_b( istep_b, &aux ); \ -\ /* Loop over the m dimension (MR rows at a time). */ \ for ( i = ir_start; i < ir_end; i += ir_inc ) \ { \ @@ -504,13 +470,9 @@ void PASTEMAC(ch,varname) \ intersecting micro-panel. */ \ is_b_cur = k_b1121 * PACKNR; \ is_b_cur += ( bli_is_odd( is_b_cur ) ? 1 : 0 ); \ - ps_b_cur = ( is_b_cur * ss_b_num ) / ss_b_den; \ + ps_b_cur = is_b_cur; \ \ if ( bli_trmm_my_iter_rr( j, thread ) ) { \ -\ - /* Save the 4m1/3m1 imaginary stride of B to the auxinfo_t - object. */ \ - bli_auxinfo_set_is_b( is_b_cur, &aux ); \ \ /* Loop over the m dimension (MR rows at a time). */ \ for ( i = 0; i < m_iter; ++i ) \ @@ -522,7 +484,7 @@ void PASTEMAC(ch,varname) \ \ m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \ \ - a1_i = a1 + ( off_b1121 * PACKMR ) / off_scl; \ + a1_i = a1 + off_b1121 * PACKMR; \ \ /* Compute the addresses of the next panels of A and B. */ \ a2 = a1; \ diff --git a/frame/3/trmm/bli_trmm_ru_ker_var2.c b/frame/3/trmm/bli_trmm_ru_ker_var2.c index 622e968c55..8a6e87ee0b 100644 --- a/frame/3/trmm/bli_trmm_ru_ker_var2.c +++ b/frame/3/trmm/bli_trmm_ru_ker_var2.c @@ -203,9 +203,6 @@ void PASTEMAC(ch,varname) \ inc_t rstep_c, cstep_c; \ inc_t istep_a; \ inc_t istep_b; \ - inc_t off_scl; \ - inc_t ss_b_num; \ - inc_t ss_b_den; \ inc_t ps_b_cur; \ inc_t is_b_cur; \ auxinfo_t aux; \ @@ -243,30 +240,6 @@ void PASTEMAC(ch,varname) \ matrix), which is used by 4m1/3m1 implementations, we need this unreduced value of k. */ \ k_full = k; \ -\ - /* Compute indexing scaling factor for for 4m or 3m. This is - needed because one of the packing register blocksizes (PACKMR - or PACKNR) is used to index into the micro-panels of the non- - triangular matrix when computing with a diagonal-intersecting - micro-panel of the triangular matrix. In the case of 4m or 3m, - real values are stored in both sub-panels, and so the indexing - needs to occur in units of real values. The value computed - here is divided into the complex pointer offset to cause the - pointer to be advanced by the correct value. */ \ - if ( bli_is_4mi_packed( schema_b ) || \ - bli_is_3mi_packed( schema_b ) || \ - bli_is_rih_packed( schema_b ) ) off_scl = 2; \ - else off_scl = 1; \ -\ - /* Compute the storage stride scaling. Usually this is just 1. - However, in the case of interleaved 3m, we need to scale the - offset by 3/2. And if we are packing real-only, imag-only, or - summed-only, we need to scale the computed panel sizes by 1/2 - to compensate for the fact that the pointer arithmetic occurs - in terms of complex elements rather than real elements. */ \ - if ( bli_is_3mi_packed( schema_b ) ) { ss_b_num = 3; ss_b_den = 2; } \ - else if ( bli_is_rih_packed( schema_b ) ) { ss_b_num = 1; ss_b_den = 2; } \ - else { ss_b_num = 1; ss_b_den = 1; } \ \ /* If there is a zero region to the left of where the diagonal of B intersects the top edge of the panel, adjust the pointer to C and @@ -325,9 +298,6 @@ void PASTEMAC(ch,varname) \ \ /* Save the imaginary stride of A to the auxinfo_t object. */ \ bli_auxinfo_set_is_a( istep_a, &aux ); \ -\ - /* Save the desired output datatype (indicating no typecasting). */ \ - /*bli_auxinfo_set_dt_on_output( dt, &aux );*/ \ \ /* The 'thread' argument points to the thrinfo_t node for the 2nd (jr) loop around the microkernel. Here we query the thrinfo_t node for the @@ -409,13 +379,9 @@ void PASTEMAC(ch,varname) \ intersecting micro-panel. */ \ is_b_cur = k_b0111 * PACKNR; \ is_b_cur += ( bli_is_odd( is_b_cur ) ? 1 : 0 ); \ - ps_b_cur = ( is_b_cur * ss_b_num ) / ss_b_den; \ + ps_b_cur = is_b_cur; \ \ if ( bli_trmm_my_iter_rr( j, thread ) ) { \ -\ - /* Save the 4m1/3m1 imaginary stride of B to the auxinfo_t - object. */ \ - bli_auxinfo_set_is_b( is_b_cur, &aux ); \ \ /* Loop over the m dimension (MR rows at a time). */ \ for ( i = 0; i < m_iter; ++i ) \ @@ -427,7 +393,7 @@ void PASTEMAC(ch,varname) \ \ m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \ \ - a1_i = a1 + ( off_b0111 * PACKMR ) / off_scl; \ + a1_i = a1 + off_b0111 * PACKMR; \ \ /* Compute the addresses of the next panels of A and B. */ \ a2 = a1; \ @@ -542,10 +508,6 @@ void PASTEMAC(ch,varname) \ This allows the current macro-kernel to work for both trmm and trmm3. */ \ { \ - /* Save the 4m1/3m1 imaginary stride of B to the auxinfo_t - object. */ \ - bli_auxinfo_set_is_b( istep_b, &aux ); \ -\ /* Loop over the m dimension (MR rows at a time). */ \ for ( i = ir_start; i < ir_end; i += ir_inc ) \ { \ diff --git a/frame/3/trmm3/bli_trmm3_front.c b/frame/3/trmm3/bli_trmm3_front.c index 9042d1478d..c27c3ad5ae 100644 --- a/frame/3/trmm3/bli_trmm3_front.c +++ b/frame/3/trmm3/bli_trmm3_front.c @@ -54,10 +54,6 @@ void bli_trmm3_front obj_t b_local; obj_t c_local; - // Check parameters. - if ( bli_error_checking_is_enabled() ) - bli_trmm_check( side, alpha, a, b, beta, c, cntx ); - // If alpha is zero, scale by beta and return. if ( bli_obj_equals( alpha, &BLIS_ZERO ) ) { @@ -141,6 +137,9 @@ void bli_trmm3_front #endif + // Set the pack schemas within the objects. + bli_l3_set_schemas( &a_local, &b_local, &c_local, cntx ); + // Set each alias as the root object. // NOTE: We MUST wait until we are done potentially swapping the objects // before setting the root fields! @@ -161,17 +160,6 @@ void bli_trmm3_front rntm ); - // A sort of hack for communicating the desired pack schemas for A and B - // to bli_gemm_cntl_create() (via bli_l3_thread_decorator() and - // bli_l3_cntl_create_if()). This allows us to access the schemas from - // the control tree, which hopefully reduces some confusion, particularly - // in bli_packm_init(). - pack_t schema_a = bli_cntx_schema_a_block( cntx ); - pack_t schema_b = bli_cntx_schema_b_panel( cntx ); - - bli_obj_set_pack_schema( schema_a, &a_local ); - bli_obj_set_pack_schema( schema_b, &b_local ); - // Invoke the internal back-end. bli_l3_thread_decorator ( diff --git a/frame/3/trsm/bli_trsm_front.c b/frame/3/trsm/bli_trsm_front.c index 080b9713f0..c056028f27 100644 --- a/frame/3/trsm/bli_trsm_front.c +++ b/frame/3/trsm/bli_trsm_front.c @@ -56,9 +56,12 @@ void bli_trsm_front obj_t b_local; obj_t c_local; - // Check parameters. - if ( bli_error_checking_is_enabled() ) - bli_trsm_check( side, alpha, a, b, &BLIS_ZERO, b, cntx ); +#if 0 +#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM + gint_t status = bli_trsm_small( side, alpha, a, b, cntx, cntl ); + if ( status == BLIS_SUCCESS ) return; +#endif +#endif // If alpha is zero, scale by beta and return. if ( bli_obj_equals( alpha, &BLIS_ZERO ) ) @@ -120,6 +123,9 @@ void bli_trsm_front #endif + // Set the pack schemas within the objects. + bli_l3_set_schemas( &a_local, &b_local, &c_local, cntx ); + // Set each alias as the root object. // NOTE: We MUST wait until we are done potentially swapping the objects // before setting the root fields! @@ -145,25 +151,6 @@ void bli_trsm_front rntm ); - // A sort of hack for communicating the desired pack schemas for A and B - // to bli_trsm_cntl_create() (via bli_l3_thread_decorator() and - // bli_l3_cntl_create_if()). This allows us to access the schemas from - // the control tree, which hopefully reduces some confusion, particularly - // in bli_packm_init(). - if ( bli_cntx_method( cntx ) == BLIS_NAT ) - { - bli_obj_set_pack_schema( BLIS_PACKED_ROW_PANELS, &a_local ); - bli_obj_set_pack_schema( BLIS_PACKED_COL_PANELS, &b_local ); - } - else // if ( bli_cntx_method( cntx_trsm ) != BLIS_NAT ) - { - pack_t schema_a = bli_cntx_schema_a_block( cntx ); - pack_t schema_b = bli_cntx_schema_b_panel( cntx ); - - bli_obj_set_pack_schema( schema_a, &a_local ); - bli_obj_set_pack_schema( schema_b, &b_local ); - } - // Invoke the internal back-end. bli_l3_thread_decorator ( diff --git a/frame/3/trsm/bli_trsm_ll_ker_var2.c b/frame/3/trsm/bli_trsm_ll_ker_var2.c index 279df5277a..a1a585cb3c 100644 --- a/frame/3/trsm/bli_trsm_ll_ker_var2.c +++ b/frame/3/trsm/bli_trsm_ll_ker_var2.c @@ -217,9 +217,6 @@ void PASTEMAC(ch,varname) \ inc_t rstep_c, cstep_c; \ inc_t istep_a; \ inc_t istep_b; \ - inc_t off_scl; \ - inc_t ss_a_num; \ - inc_t ss_a_den; \ inc_t ps_a_cur; \ inc_t is_a_cur; \ auxinfo_t aux; \ @@ -265,29 +262,6 @@ void PASTEMAC(ch,varname) \ matrix), which is used by 4m1/3m1 implementations, we need this unreduced value of k. */ \ k_full = ( k % MR != 0 ? k + MR - ( k % MR ) : k ); \ -\ - /* Compute indexing scaling factor for for 4m or 3m. This is - needed because one of the packing register blocksizes (PACKMR - or PACKNR) is used to index into the micro-panels of the non- - triangular matrix when computing with a diagonal-intersecting - micro-panel of the triangular matrix. In the case of 4m or 3m, - real values are stored in both sub-panels, and so the indexing - needs to occur in units of real values. The value computed - here is divided into the complex pointer offset to cause the - pointer to be advanced by the correct value. */ \ - if ( bli_is_4mi_packed( schema_a ) || \ - bli_is_3mi_packed( schema_a ) || \ - bli_is_rih_packed( schema_a ) ) off_scl = 2; \ - else off_scl = 1; \ -\ - /* Compute the storage stride scaling. Usually this is just 1. - However, in the case of interleaved 3m, we need to scale the - offset by 3/2. Note that real-only, imag-only, and summed-only - packing formats are not applicable here since trsm is a two- - operand operation only (unlike trmm, which is capable of three- - operand). */ \ - if ( bli_is_3mi_packed( schema_a ) ) { ss_a_num = 3; ss_a_den = 2; } \ - else { ss_a_num = 1; ss_a_den = 1; } \ \ /* If there is a zero region above where the diagonal of A intersects the left edge of the block, adjust the pointer to C and treat this case as @@ -355,9 +329,6 @@ void PASTEMAC(ch,varname) \ \ /* Save the imaginary stride of B to the auxinfo_t object. */ \ bli_auxinfo_set_is_b( istep_b, &aux ); \ -\ - /* Save the desired output datatype (indicating no typecasting). */ \ - /*bli_auxinfo_set_dt_on_output( dt, &aux );*/ \ \ /* We don't bother querying the thrinfo_t node for the 1st loop because we can't parallelize that loop in trsm due to the inter-iteration @@ -427,18 +398,18 @@ void PASTEMAC(ch,varname) \ intersecting micro-panel. */ \ is_a_cur = k_a1011 * PACKMR; \ is_a_cur += ( bli_is_odd( is_a_cur ) ? 1 : 0 ); \ - ps_a_cur = ( is_a_cur * ss_a_num ) / ss_a_den; \ + ps_a_cur = is_a_cur; \ \ /* Compute the addresses of the panel A10 and the triangular block A11. */ \ a10 = a1; \ - /* a11 = a1 + ( k_a10 * PACKMR ) / off_scl; */ \ - a11 = bli_ptr_inc_by_frac( a1, sizeof( ctype ), k_a10 * PACKMR, off_scl ); \ + a11 = a1 + k_a10 * PACKMR; \ + /*a11 = bli_ptr_inc_by_frac( a1, sizeof( ctype ), k_a10 * PACKMR, 1 );*/ \ \ /* Compute the addresses of the panel B01 and the block B11. */ \ - b01 = b1 + ( off_a10 * PACKNR ) / off_scl; \ - b11 = b1 + ( off_a11 * PACKNR ) / off_scl; \ + b01 = b1 + off_a10 * PACKNR; \ + b11 = b1 + off_a11 * PACKNR; \ \ /* Compute the addresses of the next panels of A and B. */ \ a2 = a1 + ps_a_cur; \ @@ -454,10 +425,6 @@ void PASTEMAC(ch,varname) \ object. */ \ bli_auxinfo_set_next_a( a2, &aux ); \ bli_auxinfo_set_next_b( b2, &aux ); \ -\ - /* Save the 4m1/3m1 imaginary stride of A to the auxinfo_t - object. */ \ - bli_auxinfo_set_is_a( is_a_cur, &aux ); \ \ /* Handle interior and edge cases separately. */ \ if ( m_cur == MR && n_cur == NR ) \ @@ -518,10 +485,6 @@ void PASTEMAC(ch,varname) \ object. */ \ bli_auxinfo_set_next_a( a2, &aux ); \ bli_auxinfo_set_next_b( b2, &aux ); \ -\ - /* Save the 4m1/3m1 imaginary stride of A to the auxinfo_t - object. */ \ - bli_auxinfo_set_is_a( istep_a, &aux ); \ \ /* Handle interior and edge cases separately. */ \ if ( m_cur == MR && n_cur == NR ) \ @@ -569,44 +532,11 @@ void PASTEMAC(ch,varname) \ } \ \ /* -if ( bli_is_4mi_packed( schema_a ) ){ \ -PASTEMAC(d,fprintm)( stdout, "trsm4m1_ll_ker_var2: b_r before", k, n, \ - ( double* )b, rs_b, 1, "%4.1f", "" ); \ -PASTEMAC(d,fprintm)( stdout, "trsm4m1_ll_ker_var2: b_i before", k, n, \ - ( double* )b+72, rs_b, 1, "%4.1f", "" ); \ -}else{ \ -PASTEMAC(d,fprintm)( stdout, "trsmnat_ll_ker_var2: b_r before", k, n, \ - ( double* )b, 2*rs_b, 2, "%4.1f", "" ); \ -PASTEMAC(d,fprintm)( stdout, "trsmnat_ll_ker_var2: b_i before", k, n, \ - ( double* )b+1, 2*rs_b, 2, "%4.1f", "" ); \ -} \ -*/ \ -\ -/* PASTEMAC(d,fprintm)( stdout, "trsm_ll_ker_var2: a11p_r computed", MR, MR, \ ( double* )a11, 1, PACKMR, "%4.1f", "" ); \ */ \ \ /* -if ( bli_is_4mi_packed( schema_a ) ){ \ -PASTEMAC(d,fprintm)( stdout, "trsm4m1_ll_ker_var2: b_r after", k, n, \ - ( double* )b, rs_b, 1, "%4.1f", "" ); \ -PASTEMAC(d,fprintm)( stdout, "trsm4m1_ll_ker_var2: b_i after", k, n, \ - ( double* )b+72, rs_b, 1, "%4.1f", "" ); \ -}else{ \ -PASTEMAC(d,fprintm)( stdout, "trsmnat_ll_ker_var2: b_r after", k, n, \ - ( double* )b, 2*rs_b, 2, "%4.1f", "" ); \ -PASTEMAC(d,fprintm)( stdout, "trsmnat_ll_ker_var2: b_i after", k, n, \ - ( double* )b+1, 2*rs_b, 2, "%4.1f", "" ); \ -} \ - -PASTEMAC(d,fprintm)( stdout, "trsm_ll_ker_var2: b_r", m, n, \ - ( double* )c, 1, cs_c, "%4.1f", "" ); \ -PASTEMAC(d,fprintm)( stdout, "trsm_ll_ker_var2: b_i", m, n, \ - ( double* )c + 8*9, 1, cs_c, "%4.1f", "" ); \ -*/ \ -\ -/* PASTEMAC(ch,fprintm)( stdout, "trsm_ll_ker_var2: a1 (diag)", MR, k_a1011, a1, 1, MR, "%5.2f", "" ); \ PASTEMAC(ch,fprintm)( stdout, "trsm_ll_ker_var2: a11 (diag)", MR, MR, a11, 1, MR, "%5.2f", "" ); \ PASTEMAC(ch,fprintm)( stdout, "trsm_ll_ker_var2: b1 (diag)", k_a1011, NR, bp_i, NR, 1, "%5.2f", "" ); \ diff --git a/frame/3/trsm/bli_trsm_lu_ker_var2.c b/frame/3/trsm/bli_trsm_lu_ker_var2.c index 5a68106a78..d4e3cc2661 100644 --- a/frame/3/trsm/bli_trsm_lu_ker_var2.c +++ b/frame/3/trsm/bli_trsm_lu_ker_var2.c @@ -218,9 +218,6 @@ void PASTEMAC(ch,varname) \ inc_t rstep_c, cstep_c; \ inc_t istep_a; \ inc_t istep_b; \ - inc_t off_scl; \ - inc_t ss_a_num; \ - inc_t ss_a_den; \ inc_t ps_a_cur; \ inc_t is_a_cur; \ auxinfo_t aux; \ @@ -266,29 +263,6 @@ void PASTEMAC(ch,varname) \ matrix), which is used by 4m1/3m1 implementations, we need this unreduced value of k. */ \ k_full = ( k % MR != 0 ? k + MR - ( k % MR ) : k ); \ -\ - /* Compute indexing scaling factor for for 4m or 3m. This is - needed because one of the packing register blocksizes (PACKMR - or PACKNR) is used to index into the micro-panels of the non- - triangular matrix when computing with a diagonal-intersecting - micro-panel of the triangular matrix. In the case of 4m or 3m, - real values are stored in both sub-panels, and so the indexing - needs to occur in units of real values. The value computed - here is divided into the complex pointer offset to cause the - pointer to be advanced by the correct value. */ \ - if ( bli_is_4mi_packed( schema_a ) || \ - bli_is_3mi_packed( schema_a ) || \ - bli_is_rih_packed( schema_a ) ) off_scl = 2; \ - else off_scl = 1; \ -\ - /* Compute the storage stride scaling. Usually this is just 1. - However, in the case of interleaved 3m, we need to scale the - offset by 3/2. Note that real-only, imag-only, and summed-only - packing formats are not applicable here since trsm is a two- - operand operation only (unlike trmm, which is capable of three- - operand). */ \ - if ( bli_is_3mi_packed( schema_a ) ) { ss_a_num = 3; ss_a_den = 2; } \ - else { ss_a_num = 1; ss_a_den = 1; } \ \ /* If there is a zero region to the left of where the diagonal of A intersects the top edge of the block, adjust the pointer to B and @@ -300,7 +274,7 @@ void PASTEMAC(ch,varname) \ i = diagoffa; \ k = k - i; \ diagoffa = 0; \ - b_cast = b_cast + ( i * PACKNR ) / off_scl; \ + b_cast = b_cast + i * PACKNR; \ } \ \ /* If there is a zero region below where the diagonal of A intersects the @@ -363,9 +337,6 @@ void PASTEMAC(ch,varname) \ \ /* Save the imaginary stride of B to the auxinfo_t object. */ \ bli_auxinfo_set_is_b( istep_b, &aux ); \ -\ - /* Save the desired output datatype (indicating no typecasting). */ \ - /*bli_auxinfo_set_dt_on_output( dt, &aux );*/ \ \ /* We don't bother querying the thrinfo_t node for the 1st loop because we can't parallelize that loop in trsm due to the inter-iteration @@ -437,18 +408,18 @@ void PASTEMAC(ch,varname) \ intersecting micro-panel. */ \ is_a_cur = k_a1112 * PACKMR; \ is_a_cur += ( bli_is_odd( is_a_cur ) ? 1 : 0 ); \ - ps_a_cur = ( is_a_cur * ss_a_num ) / ss_a_den; \ + ps_a_cur = is_a_cur; \ \ /* Compute the addresses of the triangular block A11 and the panel A12. */ \ a11 = a1; \ - /* a12 = a1 + ( k_a11 * PACKMR ) / off_scl; */ \ - a12 = bli_ptr_inc_by_frac( a1, sizeof( ctype ), k_a11 * PACKMR, off_scl ); \ + a12 = a1 + k_a11 * PACKMR; \ + /*a12 = bli_ptr_inc_by_frac( a1, sizeof( ctype ), k_a11 * PACKMR, 1 );*/ \ \ /* Compute the addresses of the panel B01 and the block B11. */ \ - b11 = b1 + ( off_a11 * PACKNR ) / off_scl; \ - b21 = b1 + ( off_a12 * PACKNR ) / off_scl; \ + b11 = b1 + off_a11 * PACKNR; \ + b21 = b1 + off_a12 * PACKNR; \ \ /* Compute the addresses of the next panels of A and B. */ \ a2 = a1 + ps_a_cur; \ @@ -464,10 +435,6 @@ void PASTEMAC(ch,varname) \ object. */ \ bli_auxinfo_set_next_a( a2, &aux ); \ bli_auxinfo_set_next_b( b2, &aux ); \ -\ - /* Save the 4m1/3m1 imaginary stride of A to the auxinfo_t - object. */ \ - bli_auxinfo_set_is_a( is_a_cur, &aux ); \ \ /* Handle interior and edge cases separately. */ \ if ( m_cur == MR && n_cur == NR ) \ @@ -528,10 +495,6 @@ void PASTEMAC(ch,varname) \ object. */ \ bli_auxinfo_set_next_a( a2, &aux ); \ bli_auxinfo_set_next_b( b2, &aux ); \ -\ - /* Save the 4m1/3m1 imaginary stride of A to the auxinfo_t - object. */ \ - bli_auxinfo_set_is_a( istep_a, &aux ); \ \ /* Handle interior and edge cases separately. */ \ if ( m_cur == MR && n_cur == NR ) \ diff --git a/frame/3/trsm/bli_trsm_rl_ker_var2.c b/frame/3/trsm/bli_trsm_rl_ker_var2.c index 507785c6cc..c66adfa193 100644 --- a/frame/3/trsm/bli_trsm_rl_ker_var2.c +++ b/frame/3/trsm/bli_trsm_rl_ker_var2.c @@ -224,9 +224,6 @@ void PASTEMAC(ch,varname) \ inc_t rstep_c, cstep_c; \ inc_t istep_a; \ inc_t istep_b; \ - inc_t off_scl; \ - inc_t ss_b_num; \ - inc_t ss_b_den; \ inc_t ps_b_cur; \ inc_t is_b_cur; \ auxinfo_t aux; \ @@ -280,29 +277,6 @@ void PASTEMAC(ch,varname) \ matrix), which is used by 4m1/3m1 implementations, we need this unreduced value of k. */ \ k_full = ( k % NR != 0 ? k + NR - ( k % NR ) : k ); \ -\ - /* Compute indexing scaling factor for for 4m or 3m. This is - needed because one of the packing register blocksizes (PACKMR - or PACKNR) is used to index into the micro-panels of the non- - triangular matrix when computing with a diagonal-intersecting - micro-panel of the triangular matrix. In the case of 4m or 3m, - real values are stored in both sub-panels, and so the indexing - needs to occur in units of real values. The value computed - here is divided into the complex pointer offset to cause the - pointer to be advanced by the correct value. */ \ - if ( bli_is_4mi_packed( schema_b ) || \ - bli_is_3mi_packed( schema_b ) || \ - bli_is_rih_packed( schema_b ) ) off_scl = 2; \ - else off_scl = 1; \ -\ - /* Compute the storage stride scaling. Usually this is just 1. - However, in the case of interleaved 3m, we need to scale the - offset by 3/2. Note that real-only, imag-only, and summed-only - packing formats are not applicable here since trsm is a two- - operand operation only (unlike trmm, which is capable of three- - operand). */ \ - if ( bli_is_3mi_packed( schema_b ) ) { ss_b_num = 3; ss_b_den = 2; } \ - else { ss_b_num = 1; ss_b_den = 1; } \ \ /* If there is a zero region above where the diagonal of B intersects the left edge of the panel, adjust the pointer to A and treat this @@ -314,7 +288,7 @@ void PASTEMAC(ch,varname) \ j = -diagoffb; \ k = k - j; \ diagoffb = 0; \ - a_cast = a_cast + ( j * PACKMR ) / off_scl; \ + a_cast = a_cast + j * PACKMR; \ } \ \ /* If there is a zero region to the right of where the diagonal @@ -386,9 +360,6 @@ void PASTEMAC(ch,varname) \ NOTE: We swap the values for A and B since the triangular "A" matrix is actually contained within B. */ \ bli_auxinfo_set_is_b( istep_a, &aux ); \ -\ - /* Save the desired output datatype (indicating no typecasting). */ \ - /*bli_auxinfo_set_dt_on_output( dt, &aux );*/ \ \ b1 = b_cast; \ c1 = c_cast; \ @@ -430,20 +401,14 @@ void PASTEMAC(ch,varname) \ \ /* Compute the addresses of the triangular block B11 and the panel B21. */ \ - b11 = b1; \ - /* b21 = b1 + ( k_b11 * PACKNR ) / off_scl; */ \ - b21 = bli_ptr_inc_by_frac( b1, sizeof( ctype ), k_b11 * PACKNR, off_scl ); \ + b11 = b1; \ + b21 = b1 + k_b11 * PACKNR; \ + /*b21 = bli_ptr_inc_by_frac( b1, sizeof( ctype ), k_b11 * PACKNR, 1 );*/ \ \ /* Compute the panel stride for the current micro-panel. */ \ is_b_cur = k_b1121 * PACKNR; \ is_b_cur += ( bli_is_odd( is_b_cur ) ? 1 : 0 ); \ - ps_b_cur = ( is_b_cur * ss_b_num ) / ss_b_den; \ -\ - /* Save the 4m1/3m1 imaginary stride of B to the auxinfo_t - object. - NOTE: We swap the values for A and B since the triangular - "A" matrix is actually contained within B. */ \ - bli_auxinfo_set_is_a( is_b_cur, &aux ); \ + ps_b_cur = is_b_cur; \ \ /* Loop over the m dimension (MR rows at a time). */ \ for ( i = 0; i < m_iter; ++i ) \ @@ -457,8 +422,8 @@ void PASTEMAC(ch,varname) \ m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \ \ /* Compute the addresses of the A11 block and A12 panel. */ \ - a11 = a1 + ( off_b11 * PACKMR ) / off_scl; \ - a12 = a1 + ( off_b21 * PACKMR ) / off_scl; \ + a11 = a1 + off_b11 * PACKMR; \ + a12 = a1 + off_b21 * PACKMR; \ \ /* Compute the addresses of the next panels of A and B. */ \ a2 = a1; \ @@ -525,12 +490,6 @@ void PASTEMAC(ch,varname) \ } \ else if ( bli_is_strictly_below_diag_n( diagoffb_j, k, NR ) ) \ { \ - /* Save the 4m1/3m1 imaginary stride of B to the auxinfo_t - object. - NOTE: We swap the values for A and B since the triangular - "A" matrix is actually contained within B. */ \ - bli_auxinfo_set_is_a( istep_b, &aux ); \ -\ /* Loop over the m dimension (MR rows at a time). */ \ for ( i = 0; i < m_iter; ++i ) \ { \ diff --git a/frame/3/trsm/bli_trsm_ru_ker_var2.c b/frame/3/trsm/bli_trsm_ru_ker_var2.c index 53ad570f60..9c3e690df0 100644 --- a/frame/3/trsm/bli_trsm_ru_ker_var2.c +++ b/frame/3/trsm/bli_trsm_ru_ker_var2.c @@ -222,9 +222,6 @@ void PASTEMAC(ch,varname) \ inc_t rstep_c, cstep_c; \ inc_t istep_a; \ inc_t istep_b; \ - inc_t off_scl; \ - inc_t ss_b_num; \ - inc_t ss_b_den; \ inc_t ps_b_cur; \ inc_t is_b_cur; \ auxinfo_t aux; \ @@ -278,29 +275,6 @@ void PASTEMAC(ch,varname) \ matrix), which is used by 4m1/3m1 implementations, we need this unreduced value of k. */ \ k_full = ( k % NR != 0 ? k + NR - ( k % NR ) : k ); \ -\ - /* Compute indexing scaling factor for for 4m or 3m. This is - needed because one of the packing register blocksizes (PACKMR - or PACKNR) is used to index into the micro-panels of the non- - triangular matrix when computing with a diagonal-intersecting - micro-panel of the triangular matrix. In the case of 4m or 3m, - real values are stored in both sub-panels, and so the indexing - needs to occur in units of real values. The value computed - here is divided into the complex pointer offset to cause the - pointer to be advanced by the correct value. */ \ - if ( bli_is_4mi_packed( schema_b ) || \ - bli_is_3mi_packed( schema_b ) || \ - bli_is_rih_packed( schema_b ) ) off_scl = 2; \ - else off_scl = 1; \ -\ - /* Compute the storage stride scaling. Usually this is just 1. - However, in the case of interleaved 3m, we need to scale the - offset by 3/2. Note that real-only, imag-only, and summed-only - packing formats are not applicable here since trsm is a two- - operand operation only (unlike trmm, which is capable of three- - operand). */ \ - if ( bli_is_3mi_packed( schema_b ) ) { ss_b_num = 3; ss_b_den = 2; } \ - else { ss_b_num = 1; ss_b_den = 1; } \ \ /* If there is a zero region to the left of where the diagonal of B intersects the top edge of the panel, adjust the pointer to C and @@ -380,9 +354,6 @@ void PASTEMAC(ch,varname) \ NOTE: We swap the values for A and B since the triangular "A" matrix is actually contained within B. */ \ bli_auxinfo_set_is_b( istep_a, &aux ); \ -\ - /* Save the desired output datatype (indicating no typecasting). */ \ - /*bli_auxinfo_set_dt_on_output( dt, &aux );*/ \ \ b1 = b_cast; \ c1 = c_cast; \ @@ -422,20 +393,14 @@ void PASTEMAC(ch,varname) \ \ /* Compute the addresses of the panel B10 and the triangular block B11. */ \ - b01 = b1; \ - /* b11 = b1 + ( k_b01 * PACKNR ) / off_scl; */ \ - b11 = bli_ptr_inc_by_frac( b1, sizeof( ctype ), k_b01 * PACKNR, off_scl ); \ + b01 = b1; \ + b11 = b1 + k_b01 * PACKNR; \ + /*b11 = bli_ptr_inc_by_frac( b1, sizeof( ctype ), k_b01 * PACKNR, 1 );*/ \ \ /* Compute the panel stride for the current micro-panel. */ \ is_b_cur = k_b0111 * PACKNR; \ is_b_cur += ( bli_is_odd( is_b_cur ) ? 1 : 0 ); \ - ps_b_cur = ( is_b_cur * ss_b_num ) / ss_b_den; \ -\ - /* Save the 4m1/3m1 imaginary stride of B to the auxinfo_t - object. - NOTE: We swap the values for A and B since the triangular - "A" matrix is actually contained within B. */ \ - bli_auxinfo_set_is_a( is_b_cur, &aux ); \ + ps_b_cur = is_b_cur; \ \ /* Loop over the m dimension (MR rows at a time). */ \ for ( i = 0; i < m_iter; ++i ) \ @@ -449,8 +414,8 @@ void PASTEMAC(ch,varname) \ m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \ \ /* Compute the addresses of the A10 panel and A11 block. */ \ - a10 = a1 + ( off_b01 * PACKMR ) / off_scl; \ - a11 = a1 + ( off_b11 * PACKMR ) / off_scl; \ + a10 = a1 + off_b01 * PACKMR; \ + a11 = a1 + off_b11 * PACKMR; \ \ /* Compute the addresses of the next panels of A and B. */ \ a2 = a1; \ @@ -517,12 +482,6 @@ void PASTEMAC(ch,varname) \ } \ else if ( bli_is_strictly_above_diag_n( diagoffb_j, k, NR ) ) \ { \ - /* Save the 4m1/3m1 imaginary stride of B to the auxinfo_t - object. - NOTE: We swap the values for A and B since the triangular - "A" matrix is actually contained within B. */ \ - bli_auxinfo_set_is_a( istep_b, &aux ); \ -\ /* Loop over the m dimension (MR rows at a time). */ \ for ( i = 0; i < m_iter; ++i ) \ { \ diff --git a/frame/CMakeLists.txt b/frame/CMakeLists.txt index 29070ae1a1..524ac64e93 100644 --- a/frame/CMakeLists.txt +++ b/frame/CMakeLists.txt @@ -1,4 +1,36 @@ -##Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved.## +#[=[ + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +]=] # Collect all subdirectory paths that have at least one file with suffix in FRAME_SRC_SUFS list. get_filepaths_with_suffixes(LOCAL_SOURCE_FILES ${CMAKE_CURRENT_SOURCE_DIR} "${FRAME_SRC_SUFS}") @@ -91,10 +123,8 @@ elseif(THREADING_MODEL STREQUAL "pthreads") # in get-noopt-cflags-for target_compile_options(FRAME PRIVATE ${CTHREADFLAGS}) endif() -if(BUILD_SHARED_LIBS) - # Equivalent to CPICFLAGS in get-noopt-cflags-for - set_target_properties(FRAME PROPERTIES POSITION_INDEPENDENT_CODE ON) -endif() +# Equivalent to CPICFLAGS in get-noopt-cflags-for +set_target_properties(FRAME PROPERTIES POSITION_INDEPENDENT_CODE ON) add_dependencies(FRAME flat-header) # Put all those targets under object-libs-targets folder name so that they appear all together in IDE. set_target_properties(FRAME PROPERTIES FOLDER object-libs-targets) diff --git a/frame/base/CMakeLists.txt b/frame/base/CMakeLists.txt index 798a642fe5..84ae518306 100644 --- a/frame/base/CMakeLists.txt +++ b/frame/base/CMakeLists.txt @@ -1,4 +1,36 @@ -##Copyright (C) 2021 - 2023, Advanced Micro Devices, Inc. All rights reserved.## +#[=[ + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +]=] target_sources("${PROJECT_NAME}" PUBLIC diff --git a/frame/base/bli_arch.c b/frame/base/bli_arch.c index e4d4edfbac..3c93a48737 100644 --- a/frame/base/bli_arch.c +++ b/frame/base/bli_arch.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -91,7 +91,6 @@ bool bli_aocl_enable_instruction_query( void ) arch_t bli_arch_query_id( void ) { - bli_arch_set_id_once(); bli_arch_check_id_once(); // Simply return the id that was previously cached. @@ -100,7 +99,6 @@ arch_t bli_arch_query_id( void ) model_t bli_model_query_id( void ) { - bli_arch_set_id_once(); bli_arch_check_id_once(); // Simply return the model_id that was previously cached. @@ -179,23 +177,41 @@ void bli_arch_set_id( void ) #ifndef BLIS_CONFIGURETIME_CPUID if ( req_id != -1 ) { - // BLIS_ARCH_TYPE was set. Cautiously check whether its value is usable. + // BLIS_ARCH_TYPE and/or AOCL_ENABLE_INSTRUCTIONS was set. + // Cautiously check whether its value is usable. - // If req_id was set to an invalid arch_t value (ie: outside the range - // [1,BLIS_NUM_ARCHS-1]), output an error message and abort. + // Test if req_id was set to an invalid arch_t value (ie: outside the range + // [1,BLIS_NUM_ARCHS-1]), and handle appropriately depending on how it was set. if ( bli_error_checking_is_enabled() ) { err_t e_val = bli_check_valid_arch_id( req_id ); - bli_check_error_code( e_val ); + if (aocl_e_i) + { + // AOCL_ENABLE_INSTRUCTIONS was used: + // If req_id is invalid, ignore user supplied + // value and reset to -1 so we'll use normal + // subconfig selection below. + if ( e_val != BLIS_SUCCESS ) + req_id = -1; + } + else + { + // BLIS_ARCH_TYPE was used: + // Abort on invalid value. + bli_check_error_code( e_val ); + } } + } + if ( req_id != -1 ) + { // Check again context actually initialized deferred to // bli_arch_check_id() called later. // For now, we can only be confident that req_id is in range. arch_id = req_id; - } - else + } + else #endif #endif @@ -234,6 +250,9 @@ void bli_arch_set_id( void ) #endif // AMD microarchitectures. + #ifdef BLIS_FAMILY_ZEN5 + arch_id = BLIS_ARCH_ZEN5; + #endif #ifdef BLIS_FAMILY_ZEN4 arch_id = BLIS_ARCH_ZEN4; #endif @@ -266,6 +285,9 @@ void bli_arch_set_id( void ) #ifdef BLIS_FAMILY_A64FX arch_id = BLIS_ARCH_A64FX; #endif + #ifdef BLIS_FAMILY_FIRESTORM + id = BLIS_ARCH_FIRESTORM; + #endif #ifdef BLIS_FAMILY_THUNDERX2 arch_id = BLIS_ARCH_THUNDERX2; #endif @@ -356,6 +378,7 @@ void bli_arch_check_id( void ) { bli_arch_set_id_once(); + bool arch_not_in_build = FALSE; bool arch_reset = FALSE; arch_t orig_arch_id= req_id; model_t orig_model_id = model_id; @@ -376,106 +399,106 @@ void bli_arch_check_id( void ) #ifndef BLIS_CONFIGURETIME_CPUID if ( req_id != -1 ) { - // BLIS_ARCH_TYPE was set. Cautiously check whether its value is usable. - // In BLAS1 and BLAS2 routines, bli_init_auto() may not have been // called, so ensure cntx has been initialized here. bli_gks_init_once(); - bool test_arch = TRUE; - while (test_arch) - { + // At this point, we know that req_id is in the valid range, but we + // don't yet know if it refers to a context that was actually + // initialized. Query the address of an internal context data structure + // corresponding to req_id. This pointer will be NULL if the associated + // subconfig is not available. + cntx_t** req_cntx = bli_gks_lookup_id( req_id ); - // At this point, we know that req_id is in the valid range, but we - // don't yet know if it refers to a context that was actually - // initialized. Query the address of an internal context data structure - // corresponding to req_id. This pointer will be NULL if the associated - // subconfig is not available. - cntx_t** req_cntx = bli_gks_lookup_id( req_id ); + if ( aocl_e_i ) + { + // AOCL_ENABLE_INSTRUCTIONS was set. Cautiously check whether its value is usable. // This function checks the context pointer and aborts with a useful // error message if the pointer is found to be NULL. if ( bli_error_checking_is_enabled() ) { err_t e_val = bli_check_initialized_gks_cntx( req_cntx ); - bli_check_error_code( e_val ); + if ( e_val != BLIS_SUCCESS ) + { + arch_not_in_build = TRUE; + arch_reset = TRUE; + req_id = actual_arch_id; + model_id = actual_model_id; + } } - // If BLIS_ARCH_TYPE (or renamed version of this environment variable) - // was set, we always use this value of req_id to set arch_id. - // However, if AOCL_ENABLE_INSTRUCTIONS was set instead, we check for - // ISA compatibility and switch to a supported option if necessary. - if ( aocl_e_i ) - { #if defined(__x86_64__) || defined(_M_X64) || defined(__i386) || defined(_M_IX86) - // If AVX2 test fails here we assume either: - // 1. Config was either zen, zen2, zen3, zen4, haswell or skx, - // so there is no fallback code path, hence error checking - // above will fail. - // 2. Config was amdzen, intel64 or x86_64, and will have - // generic code path. - if ( !bli_cpuid_is_avx2fma3_supported() ) + // If AVX2 test fails here we assume either: + // 1. Config was either zen, zen2, zen3, zen4, zen5, haswell or skx, + // so there is no fallback code path, hence error checking + // above will fail. + // 2. Config was amdzen, intel64 or x86_64, and will have + // generic code path. + if ( !bli_cpuid_is_avx2fma3_supported() ) + { + switch (req_id) { - switch (req_id) - { - case BLIS_ARCH_ZEN4: - case BLIS_ARCH_ZEN3: - case BLIS_ARCH_ZEN2: - case BLIS_ARCH_ZEN: - case BLIS_ARCH_EXCAVATOR: - case BLIS_ARCH_SKX: - case BLIS_ARCH_HASWELL: - arch_reset = TRUE; - req_id = BLIS_ARCH_GENERIC; - model_id = BLIS_MODEL_DEFAULT; - continue; - break; - } + case BLIS_ARCH_ZEN5: + case BLIS_ARCH_ZEN4: + case BLIS_ARCH_ZEN3: + case BLIS_ARCH_ZEN2: + case BLIS_ARCH_ZEN: + case BLIS_ARCH_EXCAVATOR: + case BLIS_ARCH_SKX: + case BLIS_ARCH_HASWELL: + arch_reset = TRUE; + req_id = actual_arch_id; + model_id = actual_model_id; + break; } - // If AVX512 test fails here we assume either: - // 1. Config was either zen4 or skx, so there is - // no fallback code path, hence error checking - // above will fail. - // 2. Config was amdzen, intel64 or x86_64, and will have - // appropriate avx2 code path to try. - if ( !bli_cpuid_is_avx512_supported() ) + } + // If AVX512 test fails here we assume either: + // 1. Config was either zen5, zen4 or skx, so there is + // no fallback code path, hence error checking + // above will fail. + // 2. Config was amdzen, intel64 or x86_64, and will have + // appropriate avx2 code path to try. + if ( !bli_cpuid_is_avx512_supported() ) + { + switch (req_id) { - switch (req_id) - { - case BLIS_ARCH_ZEN4: - arch_reset = TRUE; - req_id = BLIS_ARCH_ZEN3; - model_id = BLIS_MODEL_DEFAULT; - continue; - break; - case BLIS_ARCH_SKX: - arch_reset = TRUE; - req_id = BLIS_ARCH_HASWELL; - model_id = BLIS_MODEL_DEFAULT; - continue; - break; - } + case BLIS_ARCH_ZEN5: + case BLIS_ARCH_ZEN4: + case BLIS_ARCH_SKX: + arch_reset = TRUE; + req_id = actual_arch_id; + model_id = actual_model_id; + break; } - // If both tests above pass, we accept req_id choice. - test_arch = FALSE; - - // Note: Pre-AVX2 systems from AMD and Intel, and Intel KNL, - // have not been included in these tests, and thus could - // continue to give illegal instruction errors on other - // platforms, just as if BLIS_ARCH_TYPE was set to the - // same value. -#else - // Non-x86 platforms just accept value given for now. - // Similar logic to x86 if block could be implemented - // here if desired. - test_arch = FALSE; -#endif } - else + + // Note: Pre-AVX2 systems from AMD and Intel, and Intel KNL, + // have not been included in these tests, and thus could + // continue to give illegal instruction errors on other + // platforms, just as if BLIS_ARCH_TYPE was set to the + // same value. +#else + // Non-x86 platforms just accept value given for now. + // Similar logic to x86 if block could be implemented + // here if desired. + arch_reset = FALSE; +#endif + } + else + { + // BLIS_ARCH_TYPE was set. Cautiously check whether its value is usable. + + // This function checks the context pointer and aborts with a useful + // error message if the pointer is found to be NULL. + if ( bli_error_checking_is_enabled() ) { - test_arch = FALSE; + err_t e_val = bli_check_initialized_gks_cntx( req_cntx ); + bli_check_error_code( e_val ); } + // If BLIS_ARCH_TYPE (or renamed version of this environment variable) + // was set, we always use this value of req_id to set arch_id. } // Finally, we can be confident that req_id (1) is in range and (2) @@ -488,16 +511,50 @@ void bli_arch_check_id( void ) if ( bli_arch_get_logging() ) { - if ( arch_reset ) + if ( req_id == -1 && aocl_e_i) + { + // AOCL_ENABLE_INSTRUCTIONS was set to an invalid value + // normal system arch_id was used instead. + if ( model_id == BLIS_MODEL_DEFAULT ) + { + fprintf( stderr, "libblis: AOCL_ENABLE_INSTRUCTIONS env var was set to an invalid value.\n" + "libblis: Selecting system default sub-configuration '%s'.\n", + bli_arch_string( arch_id ) ); + } + else + { + fprintf( stderr, "libblis: AOCL_ENABLE_INSTRUCTIONS env var was set to an invalid value.\n" + "libblis: Selecting system default sub-configuration '%s', model '%s'.\n", + bli_arch_string( arch_id ), bli_model_string( model_id ) ); + } + } + else if ( arch_not_in_build ) { if ( orig_model_id == BLIS_MODEL_DEFAULT ) { - fprintf( stderr, "libblis: Sub-configuration '%s' is not supported on this system.\nlibblis: Switching to sub-configuration '%s'.\n", + fprintf( stderr, "libblis: Sub-configuration '%s' is not implemented in this build.\n" + "libblis: Selecting system default sub-configuration '%s'.\n", bli_arch_string( orig_arch_id ), bli_arch_string( arch_id ) ); } else { - fprintf( stderr, "libblis: Sub-configuration '%s', model '%s' is not supported on this system.\nlibblis: Switching to sub-configuration '%s', model '%s'.\n", + fprintf( stderr, "libblis: Sub-configuration '%s', model '%s' is not implemented in this build.\n" + "libblis: Selecting system default sub-configuration '%s', model '%s'.\n", + bli_arch_string( orig_arch_id ), bli_model_string( orig_model_id ), bli_arch_string( arch_id ), bli_model_string( model_id ) ); + } + } + else if ( arch_reset ) + { + if ( orig_model_id == BLIS_MODEL_DEFAULT ) + { + fprintf( stderr, "libblis: Sub-configuration '%s' is not supported on this system.\n" + "libblis: Selecting system default sub-configuration '%s'.\n", + bli_arch_string( orig_arch_id ), bli_arch_string( arch_id ) ); + } + else + { + fprintf( stderr, "libblis: Sub-configuration '%s', model '%s' is not supported on this system.\n" + "libblis: Selecting system default sub-configuration '%s', model '%s'.\n", bli_arch_string( orig_arch_id ), bli_model_string( orig_model_id ), bli_arch_string( arch_id ), bli_model_string( model_id ) ); } } @@ -569,6 +626,7 @@ static char* config_name[ BLIS_NUM_ARCHS ] = "sandybridge", "penryn", + "zen5", "zen4", "zen3", "zen2", @@ -578,11 +636,12 @@ static char* config_name[ BLIS_NUM_ARCHS ] = "piledriver", "bulldozer", + "armsve", + "a64fx", + "firestorm", "thunderx2", "cortexa57", "cortexa53", - "armsve", - "a64fx", "cortexa15", "cortexa9", @@ -609,6 +668,9 @@ static char* model_name[ BLIS_NUM_MODELS ] = "default", + "Turin", + "Turin Dense", + "Genoa", "Bergamo", "Genoa-X", diff --git a/frame/base/bli_auxinfo.h b/frame/base/bli_auxinfo.h index 4d5909f33f..68b6cc7cd6 100644 --- a/frame/base/bli_auxinfo.h +++ b/frame/base/bli_auxinfo.h @@ -74,13 +74,6 @@ BLIS_INLINE inc_t bli_auxinfo_ps_b( auxinfo_t* ai ) return ai->ps_b; } -#if 0 -BLIS_INLINE inc_t bli_auxinfo_dt_on_output( auxinfo_t* ai ) -{ - return ai->dt_on_output; -} -#endif - // auxinfo_t field modification @@ -125,12 +118,5 @@ BLIS_INLINE void bli_auxinfo_set_ps_b( inc_t ps, auxinfo_t* ai ) ai->ps_b = ps; } -#if 0 -BLIS_INLINE void bli_auxinfo_set_dt_on_output( num_t dt_on_output, auxinfo_t* ai ) -{ - ai->dt_on_output = dt_on_output; -} -#endif - #endif diff --git a/frame/base/bli_check.c b/frame/base/bli_check.c index a7c3d194ba..35d2be082c 100644 --- a/frame/base/bli_check.c +++ b/frame/base/bli_check.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -898,6 +898,14 @@ err_t bli_check_valid_model_id( arch_t arch_id, model_t model_id ) // Model ranges are specified in bli_type_defs.h err_t e_val = BLIS_INVALID_MODEL_ID; + if ( arch_id == BLIS_ARCH_ZEN5 ) + { + if ( ( gint_t )model_id >= BLIS_MODEL_TURIN && + ( gint_t )model_id <= BLIS_MODEL_TURIN_DENSE ) + { + e_val = BLIS_SUCCESS; + } + } if ( arch_id == BLIS_ARCH_ZEN4 ) { if ( ( gint_t )model_id >= BLIS_MODEL_GENOA && diff --git a/frame/base/bli_cntx.c b/frame/base/bli_cntx.c index 8c2651c4cb..1cc5a2e13e 100644 --- a/frame/base/bli_cntx.c +++ b/frame/base/bli_cntx.c @@ -224,12 +224,6 @@ void bli_cntx_set_blkszs( ind_t method, dim_t n_bs, ... ) double msclr = msclrs[ i ]; blksz_t* blksz = blkszs[ i ]; - // NOTE: This is a bug! We need to grab the actual blocksize - // multiple, which is not at blkszs[i], but rather somewhere else - // in the array. In order to fix this, you probably need to store - // the contents of blkszs (and all the other arrays) by bs_id - // rather than i in the first loop. - blksz_t* bmult = blkszs[ i ]; blksz_t* cntx_blksz = &cntx_blkszs[ bs_id ]; @@ -248,20 +242,6 @@ void bli_cntx_set_blkszs( ind_t method, dim_t n_bs, ... ) // blocksize object. bli_blksz_scale_def( 1, ( dim_t )dsclr, BLIS_SCOMPLEX, cntx_blksz ); bli_blksz_scale_def( 1, ( dim_t )dsclr, BLIS_DCOMPLEX, cntx_blksz ); - - // Perform rounding to ensure the newly scaled values are still - // multiples of their register blocksize multiples. But only - // perform this rounding when the blocksize id is not equal to - // the blocksize multiple id (ie: we don't round down scaled - // register blocksizes since they are their own multiples). - // Also, we skip the rounding for 1m since it should never need - // such rounding. - if ( bs_id != bm_id && method != BLIS_1M ) - { - // Round the newly-scaled blocksizes down to their multiple. - bli_blksz_reduce_def_to( BLIS_FLOAT, bmult, BLIS_SCOMPLEX, cntx_blksz ); - bli_blksz_reduce_def_to( BLIS_DOUBLE, bmult, BLIS_DCOMPLEX, cntx_blksz ); - } } // Similarly, if the maximum blocksize scalar is non-unit, we need @@ -272,20 +252,6 @@ void bli_cntx_set_blkszs( ind_t method, dim_t n_bs, ... ) // blocksize object. bli_blksz_scale_max( 1, ( dim_t )msclr, BLIS_SCOMPLEX, cntx_blksz ); bli_blksz_scale_max( 1, ( dim_t )msclr, BLIS_DCOMPLEX, cntx_blksz ); - - // Perform rounding to ensure the newly scaled values are still - // multiples of their register blocksize multiples. But only - // perform this rounding when the blocksize id is not equal to - // the blocksize multiple id (ie: we don't round down scaled - // register blocksizes since they are their own multiples). - // Also, we skip the rounding for 1m since it should never need - // such rounding. - if ( bs_id != bm_id && method != BLIS_1M ) - { - // Round the newly-scaled blocksizes down to their multiple. - bli_blksz_reduce_max_to( BLIS_FLOAT, bmult, BLIS_SCOMPLEX, cntx_blksz ); - bli_blksz_reduce_max_to( BLIS_DOUBLE, bmult, BLIS_DCOMPLEX, cntx_blksz ); - } } // Copy the blocksize multiple id into the context. @@ -323,13 +289,14 @@ void bli_cntx_set_blkszs( ind_t method, dim_t n_bs, ... ) // ----------------------------------------------------------------------------- -void bli_cntx_set_ind_blkszs( ind_t method, dim_t n_bs, ... ) +void bli_cntx_set_ind_blkszs( ind_t method, num_t dt, dim_t n_bs, ... ) { /* Example prototypes: void bli_gks_cntx_set_ind_blkszs ( ind_t method != BLIS_NAT, + num_t dt, dim_t n_bs, bszid_t bs0_id, dim_t def_scalr0, dim_t max_scalr0, bszid_t bs1_id, dim_t def_scalr1, dim_t max_scalr1, @@ -346,6 +313,9 @@ void bli_cntx_set_ind_blkszs( ind_t method, dim_t n_bs, ... ) dim_t i; err_t r_val; + // Project the given datatype to the real domain. This will be used later on. + num_t dt_real = bli_dt_proj_to_real( dt ); + // Return early if called with BLIS_NAT. if ( method == BLIS_NAT ) return; @@ -418,77 +388,35 @@ void bli_cntx_set_ind_blkszs( ind_t method, dim_t n_bs, ... ) //blksz_t* cntx_blksz = &cntx_blkszs[ bs_id ]; - // Query the blocksize multiple's blocksize id. - bszid_t bm_id = bli_cntx_get_bmult_id( bs_id, cntx ); - // Query the context for the blksz_t object assoicated with the // current blocksize id, and also query the object corresponding // to the blocksize multiple. blksz_t* cntx_blksz = bli_cntx_get_blksz( bs_id, cntx ); - blksz_t* cntx_bmult = bli_cntx_get_bmult( bs_id, cntx ); blksz_t* cntx_trsm_blksz = bli_cntx_get_trsm_blksz( bs_id, cntx ); - // Copy the real domain values of the blksz_t object into the - // the complex domain slots of the same object. - bli_blksz_copy_dt( BLIS_FLOAT, cntx_blksz, BLIS_SCOMPLEX, cntx_blksz ); - bli_blksz_copy_dt( BLIS_DOUBLE, cntx_blksz, BLIS_DCOMPLEX, cntx_blksz ); - bli_blksz_copy_dt( BLIS_FLOAT, cntx_blksz, BLIS_SCOMPLEX, cntx_trsm_blksz); - bli_blksz_copy_dt( BLIS_DOUBLE, cntx_blksz, BLIS_DCOMPLEX, cntx_trsm_blksz); + // Copy the real domain value of the blksz_t object into the + // corresponding complex domain slot of the same object. + bli_blksz_copy_dt( dt_real, cntx_blksz, dt, cntx_blksz ); + bli_blksz_copy_dt( dt_real, cntx_blksz, dt, cntx_trsm_blksz ); // If the default blocksize scalar is non-unit, we need to scale // the complex domain default blocksizes. if ( dsclr != 1.0 ) { - // Scale the complex domain default blocksize values in the - // blocksize object. - bli_blksz_scale_def( 1, ( dim_t )dsclr, BLIS_SCOMPLEX, cntx_blksz ); - bli_blksz_scale_def( 1, ( dim_t )dsclr, BLIS_DCOMPLEX, cntx_blksz ); - bli_blksz_scale_def( 1, ( dim_t )dsclr, BLIS_SCOMPLEX, cntx_trsm_blksz); - bli_blksz_scale_def( 1, ( dim_t )dsclr, BLIS_DCOMPLEX, cntx_trsm_blksz); - - // Perform rounding to ensure the newly scaled values are still - // multiples of their register blocksize multiples. But only - // perform this rounding when the blocksize id is not equal to - // the blocksize multiple id (ie: we don't round down scaled - // register blocksizes since they are their own multiples). - // Also, we skip the rounding for 1m since it should never need - // such rounding. - if ( bs_id != bm_id && method != BLIS_1M ) - { - // Round the newly-scaled blocksizes down to their multiple. - bli_blksz_reduce_def_to( BLIS_FLOAT, cntx_bmult, BLIS_SCOMPLEX, cntx_blksz ); - bli_blksz_reduce_def_to( BLIS_DOUBLE, cntx_bmult, BLIS_DCOMPLEX, cntx_blksz ); - bli_blksz_reduce_def_to( BLIS_FLOAT, cntx_bmult, BLIS_SCOMPLEX, cntx_trsm_blksz ); - bli_blksz_reduce_def_to( BLIS_DOUBLE, cntx_bmult, BLIS_DCOMPLEX, cntx_trsm_blksz ); - } + // Scale the default blocksize value corresponding to the given + // datatype. + bli_blksz_scale_def( 1, ( dim_t )dsclr, dt, cntx_blksz ); + bli_blksz_scale_def( 1, ( dim_t )dsclr, dt, cntx_trsm_blksz ); } // Similarly, if the maximum blocksize scalar is non-unit, we need // to scale the complex domain maximum blocksizes. if ( msclr != 1.0 ) { - // Scale the complex domain maximum blocksize values in the - // blocksize object. - bli_blksz_scale_max( 1, ( dim_t )msclr, BLIS_SCOMPLEX, cntx_blksz ); - bli_blksz_scale_max( 1, ( dim_t )msclr, BLIS_DCOMPLEX, cntx_blksz ); - bli_blksz_scale_max( 1, ( dim_t )msclr, BLIS_SCOMPLEX, cntx_trsm_blksz ); - bli_blksz_scale_max( 1, ( dim_t )msclr, BLIS_DCOMPLEX, cntx_trsm_blksz ); - - // Perform rounding to ensure the newly scaled values are still - // multiples of their register blocksize multiples. But only - // perform this rounding when the blocksize id is not equal to - // the blocksize multiple id (ie: we don't round down scaled - // register blocksizes since they are their own multiples). - // Also, we skip the rounding for 1m since it should never need - // such rounding. - if ( bs_id != bm_id && method != BLIS_1M ) - { - // Round the newly-scaled blocksizes down to their multiple. - bli_blksz_reduce_max_to( BLIS_FLOAT, cntx_bmult, BLIS_SCOMPLEX, cntx_blksz ); - bli_blksz_reduce_max_to( BLIS_DOUBLE, cntx_bmult, BLIS_DCOMPLEX, cntx_blksz ); - bli_blksz_reduce_max_to( BLIS_FLOAT, cntx_bmult, BLIS_SCOMPLEX, cntx_trsm_blksz ); - bli_blksz_reduce_max_to( BLIS_DOUBLE, cntx_bmult, BLIS_DCOMPLEX, cntx_trsm_blksz ); - } + // Scale the maximum blocksize value corresponding to the given + // datatype. + bli_blksz_scale_max( 1, ( dim_t )msclr, dt, cntx_blksz ); + bli_blksz_scale_max( 1, ( dim_t )msclr, dt, cntx_trsm_blksz ); } } } diff --git a/frame/base/bli_cntx.h b/frame/base/bli_cntx.h index 8de023a2b2..d4e38fef99 100644 --- a/frame/base/bli_cntx.h +++ b/frame/base/bli_cntx.h @@ -63,9 +63,6 @@ typedef struct cntx_s func_t* unpackm_kers; ind_t method; - pack_t schema_a; - pack_t schema_b; - pack_t schema_c; } cntx_t; */ @@ -156,18 +153,6 @@ BLIS_INLINE ind_t bli_cntx_method( cntx_t* cntx ) { return cntx->method; } -BLIS_INLINE pack_t bli_cntx_schema_a_block( cntx_t* cntx ) -{ - return cntx->schema_a_block; -} -BLIS_INLINE pack_t bli_cntx_schema_b_panel( cntx_t* cntx ) -{ - return cntx->schema_b_panel; -} -BLIS_INLINE pack_t bli_cntx_schema_c_panel( cntx_t* cntx ) -{ - return cntx->schema_c_panel; -} // ----------------------------------------------------------------------------- @@ -179,23 +164,6 @@ BLIS_INLINE void bli_cntx_set_method( ind_t method, cntx_t* cntx ) { cntx->method = method; } -BLIS_INLINE void bli_cntx_set_schema_a_block( pack_t schema, cntx_t* cntx ) -{ - cntx->schema_a_block = schema; -} -BLIS_INLINE void bli_cntx_set_schema_b_panel( pack_t schema, cntx_t* cntx ) -{ - cntx->schema_b_panel = schema; -} -BLIS_INLINE void bli_cntx_set_schema_c_panel( pack_t schema, cntx_t* cntx ) -{ - cntx->schema_c_panel = schema; -} -BLIS_INLINE void bli_cntx_set_schema_ab_blockpanel( pack_t sa, pack_t sb, cntx_t* cntx ) -{ - bli_cntx_set_schema_a_block( sa, cntx ); - bli_cntx_set_schema_b_panel( sb, cntx ); -} // ----------------------------------------------------------------------------- @@ -942,7 +910,7 @@ BLIS_EXPORT_BLIS void bli_cntx_set_blkszs( ind_t method, dim_t n_bs, ... ); BLIS_EXPORT_BLIS void bli_cntx_set_trsm_blkszs( dim_t n_bs, ... ); -BLIS_EXPORT_BLIS void bli_cntx_set_ind_blkszs( ind_t method, dim_t n_bs, ... ); +BLIS_EXPORT_BLIS void bli_cntx_set_ind_blkszs( ind_t method, num_t dt, dim_t n_bs, ... ); BLIS_EXPORT_BLIS void bli_cntx_set_l3_nat_ukrs( dim_t n_ukrs, ... ); BLIS_EXPORT_BLIS void bli_cntx_set_l3_vir_ukrs( dim_t n_ukrs, ... ); diff --git a/frame/base/bli_cpuid.c b/frame/base/bli_cpuid.c index d54c6a8bb3..d89e7e34cd 100644 --- a/frame/base/bli_cpuid.c +++ b/frame/base/bli_cpuid.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2024, Advanced Micro Devices, Inc. All rights reserved. Copyright (C) 2019, Dave Love, University of Manchester Redistribution and use in source and binary forms, with or without @@ -94,6 +94,9 @@ static bool is_avx512_supported = FALSE; static bool is_avx512vnni_supported = FALSE; static bool is_avx512bf16_supported = FALSE; +// Variable to represent FP/SIMD execution datapath width. +static uint32_t bli_fp_datapath = -1; + // Variables to store the cache sizes (in KB). L3 size is shared by all // logical processors in the package (i.e. per socket). static uint32_t bli_l1d_cache_size = -1; @@ -118,6 +121,9 @@ arch_t bli_cpuid_query_id( void ) bli_cpuid_check_avx512vnni_support( family, model, features ); bli_cpuid_check_avx512bf16_support( family, model, features ); + // Check FP/SIMD execution datapath + bli_cpuid_check_datapath( vendor, features ); + // Find out cache sizes and set in static variables. // Currently only enabled for VENDOR_AMD. bli_cpuid_check_cache( vendor ); @@ -134,6 +140,9 @@ arch_t bli_cpuid_query_id( void ) printf( "AVX512 VNNI = %d\n", is_avx512vnni_supported ); printf( "AVX512 BF16 = %d\n", is_avx512bf16_supported ); + const char* datapath_names[] = {"UNSET", "FP128", "INVALID", "FP256", "FP512"}; + printf( "FP/SIMD datapath = %d (%s)\n", bli_fp_datapath, datapath_names[bli_fp_datapath+1] ); + printf( "Cache Information:\n" ); printf( "L1I size = %u KB\n",bli_l1i_cache_size ); printf( "L1D size = %u KB\n",bli_l1d_cache_size ); @@ -185,13 +194,25 @@ arch_t bli_cpuid_query_id( void ) } else if ( vendor == VENDOR_AMD ) { - // Check for each AMD configuration that is enabled, check for that // microarchitecture. We check from most recent to most dated. +#ifdef BLIS_CONFIG_ZEN5 + if ( bli_cpuid_is_zen5( family, model, features ) ) + return BLIS_ARCH_ZEN5; +#endif #ifdef BLIS_CONFIG_ZEN4 if ( bli_cpuid_is_zen4( family, model, features ) ) return BLIS_ARCH_ZEN4; +#endif +#ifdef BLIS_CONFIG_ZEN5 // Fallback test for future AMD processors + // Assume zen5 (if available) is preferable to zen4. + if ( is_avx512_supported ) + return BLIS_ARCH_ZEN5; +#endif +#ifdef BLIS_CONFIG_ZEN4 + // Fallback test for future AMD processors + // Use zen4 if zen5 is not available. if ( is_avx512_supported ) return BLIS_ARCH_ZEN4; #endif @@ -207,6 +228,12 @@ arch_t bli_cpuid_query_id( void ) if ( bli_cpuid_is_zen( family, model, features ) ) return BLIS_ARCH_ZEN; #endif +#ifdef BLIS_CONFIG_ZEN3 + // Fallback test for future AMD processors + // Use zen3 if AVX512 support is not available but AVX2 is. + if ( is_avx2fma3_supported ) + return BLIS_ARCH_ZEN3; +#endif #ifdef BLIS_CONFIG_EXCAVATOR if ( bli_cpuid_is_excavator( family, model, features ) ) return BLIS_ARCH_EXCAVATOR; @@ -240,6 +267,22 @@ model_t bli_cpuid_query_model_id( arch_t arch_id ) // Set default for architectures where separate models haven't been defined. model_t cpuid_model = BLIS_MODEL_DEFAULT; +#ifdef BLIS_CONFIG_ZEN5 + if (arch_id == BLIS_ARCH_ZEN5) + { + // Call the CPUID instruction and parse its results into a family id, + // model id, and a feature bit field. The return value encodes the + // vendor. + + uint32_t __attribute__ ((unused)) vendor; + uint32_t family, model, features; + + vendor = bli_cpuid_query( &family, &model, &features ); + + // Check CPU model. + cpuid_model = bli_cpuid_get_zen5_cpuid_model( family, model, features ); + } +#endif #ifdef BLIS_CONFIG_ZEN4 if (arch_id == BLIS_ARCH_ZEN4) { @@ -386,6 +429,60 @@ bool bli_cpuid_is_penryn } // ----------------------------------------------------------------------------- +bool bli_cpuid_is_zen5 + ( + uint32_t family, + uint32_t model, + uint32_t features + ) +{ + // Check for expected CPU features. + const uint32_t expected = FEATURE_SSE3 | + FEATURE_SSSE3 | + FEATURE_SSE41 | + FEATURE_SSE42 | + FEATURE_AVX | + FEATURE_FMA3 | + FEATURE_AVX2 | + FEATURE_AVX512F | + FEATURE_AVX512DQ | + FEATURE_AVX512CD | + FEATURE_AVX512BW | + FEATURE_AVX512VL | + FEATURE_AVX512VNNI | + FEATURE_AVX512BF16 | + FEATURE_MOVDIRI | + FEATURE_MOVDIR64B | + FEATURE_AVX512VP2INTERSECT | + FEATURE_AVXVNNI; + + if ( !bli_cpuid_has_features( features, expected ) ) return FALSE; + + // For zen5 the family id is 0x1A + if ( family != 0x1A ) return FALSE; + + return TRUE; +} +model_t bli_cpuid_get_zen5_cpuid_model + ( + uint32_t family, + uint32_t model, + uint32_t features + ) +{ + // Look at model of CPU and set cpuid_model appropriately. + // For Zen5, the default is Turin. + model_t cpuid_model = BLIS_MODEL_TURIN; + if ( family == 0x1A ) + { + if ( 0x10 <= model && model <= 0x1f ) // Turin Dense + { + cpuid_model = BLIS_MODEL_TURIN_DENSE; + } + } + return cpuid_model; +} + bool bli_cpuid_is_zen4 ( uint32_t family, @@ -438,6 +535,14 @@ model_t bli_cpuid_get_zen4_cpuid_model { cpuid_model = BLIS_MODEL_BERGAMO; } + else + { + uint32_t l3_cache_size = bli_cpuid_query_l3_cache_size(); + if ( l3_cache_size > 393216 ) + { + cpuid_model = BLIS_MODEL_GENOA_X; + } + } } return cpuid_model; } @@ -824,6 +929,12 @@ bool bli_cpuid_is_avx512bf16_supported( void ) return is_avx512bf16_supported; } +uint32_t bli_cpuid_query_fp_datapath( void ) +{ + bli_cpuid_query_id_once(); + return bli_fp_datapath; +} + uint32_t bli_cpuid_query_l1d_cache_size( void ) { bli_cpuid_query_id_once(); @@ -854,9 +965,6 @@ arch_t bli_cpuid_query_id( void ) { uint32_t vendor, model, part, features; - // Call the CPUID instruction and parse its results into a model id, - // part id, and a feature bit field. The return value encodes the - // vendor. vendor = bli_cpuid_query( &model, &part, &features ); #if 0 @@ -872,24 +980,9 @@ arch_t bli_cpuid_query_id( void ) { if ( model == MODEL_ARMV8 ) { + return part; // Check for each ARMv8 configuration that is enabled, check for that // microarchitecture. We check from most recent to most dated. -#ifdef BLIS_CONFIG_ARMSVE - if ( bli_cpuid_is_armsve( model, part, features ) ) - return BLIS_ARCH_ARMSVE; -#endif -#ifdef BLIS_CONFIG_A64FX - if ( bli_cpuid_is_a64fx( model, part, features ) ) - return BLIS_ARCH_A64FX; -#endif -#ifdef BLIS_CONFIG_THUNDERX2 - if ( bli_cpuid_is_thunderx2( model, part, features ) ) - return BLIS_ARCH_THUNDERX2; -#endif -#ifdef BLIS_CONFIG_CORTEXA57 - if ( bli_cpuid_is_cortexa57( model, part, features ) ) - return BLIS_ARCH_CORTEXA57; -#endif // If none of the other sub-configurations were detected, return // the 'generic' arch_t id value. return BLIS_ARCH_GENERIC; @@ -925,81 +1018,6 @@ model_t bli_cpuid_query_model_id( arch_t arch_id ) return BLIS_MODEL_DEFAULT; } -bool bli_cpuid_is_thunderx2 - ( - uint32_t family, - uint32_t model, - uint32_t features - ) -{ - // Check for expected CPU features. - const uint32_t expected = FEATURE_NEON; - - if ( !bli_cpuid_has_features( features, expected ) ) return FALSE; - - return TRUE; -} - -bool bli_cpuid_is_cortexa57 - ( - uint32_t family, - uint32_t model, - uint32_t features - ) -{ - // Check for expected CPU features. - const uint32_t expected = FEATURE_NEON; - - if ( !bli_cpuid_has_features( features, expected ) ) return FALSE; - - return TRUE; -} - -bool bli_cpuid_is_cortexa53 - ( - uint32_t family, - uint32_t model, - uint32_t features - ) -{ - // Check for expected CPU features. - const uint32_t expected = FEATURE_NEON; - - if ( !bli_cpuid_has_features( features, expected ) ) return FALSE; - - return TRUE; -} - -bool bli_cpuid_is_armsve - ( - uint32_t family, - uint32_t model, - uint32_t features - ) -{ - // Check for expected CPU features. - const uint32_t expected = FEATURE_SVE; - - if ( !bli_cpuid_has_features( features, expected ) ) return FALSE; - - return TRUE; -} - -bool bli_cpuid_is_a64fx - ( - uint32_t family, - uint32_t model, - uint32_t features - ) -{ - // Check for expected CPU features. - const uint32_t expected = FEATURE_SVE; - - if ( !bli_cpuid_has_features( features, expected ) ) return FALSE; - - return TRUE; -} - bool bli_cpuid_is_cortexa15 ( uint32_t family, @@ -1010,9 +1028,7 @@ bool bli_cpuid_is_cortexa15 // Check for expected CPU features. const uint32_t expected = FEATURE_NEON; - if ( !bli_cpuid_has_features( features, expected ) ) return FALSE; - - return TRUE; + return bli_cpuid_has_features( features, expected ) && model == 0xc0f; } bool bli_cpuid_is_cortexa9 @@ -1025,9 +1041,7 @@ bool bli_cpuid_is_cortexa9 // Check for expected CPU features. const uint32_t expected = FEATURE_NEON; - if ( !bli_cpuid_has_features( features, expected ) ) return FALSE; - - return TRUE; + return bli_cpuid_has_features( features, expected ) && model == 0xc09; } #else @@ -1060,7 +1074,7 @@ model_t bli_cpuid_query_model_id( arch_t arch_id ) Copyright (C) 2017, The University of Texas at Austin Copyright (C) 2017, Devin Matthews - Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -1092,29 +1106,36 @@ model_t bli_cpuid_query_model_id( arch_t arch_id ) enum { - // input register(s) output register - FEATURE_MASK_SSE3 = (1u<< 0), // cpuid[eax=1] :ecx[0] - FEATURE_MASK_SSSE3 = (1u<< 9), // cpuid[eax=1] :ecx[9] - FEATURE_MASK_SSE41 = (1u<<19), // cpuid[eax=1] :ecx[19] - FEATURE_MASK_SSE42 = (1u<<20), // cpuid[eax=1] :ecx[20] - FEATURE_MASK_AVX = (1u<<28), // cpuid[eax=1] :ecx[28] - FEATURE_MASK_AVX2 = (1u<< 5), // cpuid[eax=7,ecx=0] :ebx[5] - FEATURE_MASK_FMA3 = (1u<<12), // cpuid[eax=1] :ecx[12] - FEATURE_MASK_FMA4 = (1u<<16), // cpuid[eax=0x80000001]:ecx[16] - FEATURE_MASK_AVX512F = (1u<<16), // cpuid[eax=7,ecx=0] :ebx[16] - FEATURE_MASK_AVX512DQ = (1u<<17), // cpuid[eax=7,ecx=0] :ebx[17] - FEATURE_MASK_AVX512PF = (1u<<26), // cpuid[eax=7,ecx=0] :ebx[26] - FEATURE_MASK_AVX512ER = (1u<<27), // cpuid[eax=7,ecx=0] :ebx[27] - FEATURE_MASK_AVX512CD = (1u<<28), // cpuid[eax=7,ecx=0] :ebx[28] - FEATURE_MASK_AVX512BW = (1u<<30), // cpuid[eax=7,ecx=0] :ebx[30] - FEATURE_MASK_AVX512VL = (1u<<31), // cpuid[eax=7,ecx=0] :ebx[31] - FEATURE_MASK_AVX512VNNI = (1u<<11), // cpuid[eax=7,ecx=0] :ecx[11] - FEATURE_MASK_AVX512BF16 = (1u<< 5), // cpuid[eax=7,ecx=1] :eax[5] - FEATURE_MASK_XGETBV = (1u<<26)| - (1u<<27), // cpuid[eax=1] :ecx[27:26] - XGETBV_MASK_XMM = 0x02u, // xcr0[1] - XGETBV_MASK_YMM = 0x04u, // xcr0[2] - XGETBV_MASK_ZMM = 0xe0u // xcr0[7:5] + // input register(s) output register + FEATURE_MASK_SSE3 = (1u<< 0), // cpuid[eax=1] :ecx[0] + FEATURE_MASK_SSSE3 = (1u<< 9), // cpuid[eax=1] :ecx[9] + FEATURE_MASK_SSE41 = (1u<<19), // cpuid[eax=1] :ecx[19] + FEATURE_MASK_SSE42 = (1u<<20), // cpuid[eax=1] :ecx[20] + FEATURE_MASK_AVX = (1u<<28), // cpuid[eax=1] :ecx[28] + FEATURE_MASK_AVX2 = (1u<< 5), // cpuid[eax=7,ecx=0] :ebx[5] + FEATURE_MASK_FMA3 = (1u<<12), // cpuid[eax=1] :ecx[12] + FEATURE_MASK_FMA4 = (1u<<16), // cpuid[eax=0x80000001] :ecx[16] + FEATURE_MASK_AVX512F = (1u<<16), // cpuid[eax=7,ecx=0] :ebx[16] + FEATURE_MASK_AVX512DQ = (1u<<17), // cpuid[eax=7,ecx=0] :ebx[17] + FEATURE_MASK_AVX512PF = (1u<<26), // cpuid[eax=7,ecx=0] :ebx[26] + FEATURE_MASK_AVX512ER = (1u<<27), // cpuid[eax=7,ecx=0] :ebx[27] + FEATURE_MASK_AVX512CD = (1u<<28), // cpuid[eax=7,ecx=0] :ebx[28] + FEATURE_MASK_AVX512BW = (1u<<30), // cpuid[eax=7,ecx=0] :ebx[30] + FEATURE_MASK_AVX512VL = (1u<<31), // cpuid[eax=7,ecx=0] :ebx[31] + FEATURE_MASK_AVX512VNNI = (1u<<11), // cpuid[eax=7,ecx=0] :ecx[11] + FEATURE_MASK_MOVDIRI = (1u<<27), // cpuid[eax=7,ecx=0] :ecx[27] + FEATURE_MASK_MOVDIR64B = (1u<<28), // cpuid[eax=7,ecx=0] :ecx[28] + FEATURE_MASK_AVX512VP2INTERSECT = (1u<<8), // cpuid[eax=7,ecx=0] :edx[8] + FEATURE_MASK_AVXVNNI = (1u<< 4), // cpuid[eax=7,ecx=1] :eax[4] + FEATURE_MASK_AVX512BF16 = (1u<< 5), // cpuid[eax=7,ecx=1] :eax[5] + FEATURE_MASK_XGETBV = (1u<<26)| + (1u<<27), // cpuid[eax=1] :ecx[27:26] + XGETBV_MASK_XMM = 0x02u, // xcr0[1] + XGETBV_MASK_YMM = 0x04u, // xcr0[2] + XGETBV_MASK_ZMM = 0xe0u, // xcr0[7:5] + FEATURE_MASK_DATAPATH_FP128 = (1u<<0), // cpuid[eax=0x8000001A] :eax[0] + FEATURE_MASK_DATAPATH_FP256 = (1u<<2), // cpuid[eax=0x8000001A] :eax[2] + FEATURE_MASK_DATAPATH_FP512 = (1u<<3) // cpuid[eax=0x8000001A] :eax[3] }; @@ -1178,6 +1199,10 @@ uint32_t bli_cpuid_query if ( bli_cpuid_has_features( ebx, FEATURE_MASK_AVX512VL ) ) *features |= FEATURE_AVX512VL; if ( bli_cpuid_has_features( ecx, FEATURE_MASK_AVX512VNNI ) ) *features |= FEATURE_AVX512VNNI; + if ( bli_cpuid_has_features( ecx, FEATURE_MASK_MOVDIRI ) ) *features |= FEATURE_MOVDIRI; + if ( bli_cpuid_has_features( ecx, FEATURE_MASK_MOVDIR64B ) ) *features |= FEATURE_MOVDIR64B; + + if ( bli_cpuid_has_features( edx, FEATURE_MASK_AVX512VP2INTERSECT ) ) *features |= FEATURE_AVX512VP2INTERSECT; // This is actually a macro that modifies the last four operands, // hence why they are not passed by address. @@ -1186,8 +1211,8 @@ uint32_t bli_cpuid_query // 5th feature bit of the returned value __cpuid_count( 7, 1, eax, ebx, ecx, edx ); + if ( bli_cpuid_has_features( eax, FEATURE_MASK_AVXVNNI ) ) *features |= FEATURE_AVXVNNI; if ( bli_cpuid_has_features( eax, FEATURE_MASK_AVX512BF16 ) ) *features |= FEATURE_AVX512BF16; - } // Check extended processor info / features bits for AMD-specific features. @@ -1205,6 +1230,17 @@ uint32_t bli_cpuid_query if ( bli_cpuid_has_features( ecx, FEATURE_MASK_FMA4 ) ) *features |= FEATURE_FMA4; } + if ( cpuid_max_ext >= 0x8000001Au ) + { + // This is actually a macro that modifies the last four operands, + // hence why they are not passed by address. + // This returns extended feature flags in EAX. + __cpuid( 0x8000001A, eax, ebx, ecx, edx ); + + if ( bli_cpuid_has_features( eax, FEATURE_MASK_DATAPATH_FP128 ) ) *features |= FEATURE_DATAPATH_FP128; + if ( bli_cpuid_has_features( eax, FEATURE_MASK_DATAPATH_FP256 ) ) *features |= FEATURE_DATAPATH_FP256; + if ( bli_cpuid_has_features( eax, FEATURE_MASK_DATAPATH_FP512 ) ) *features |= FEATURE_DATAPATH_FP512; + } // Unconditionally check processor info / features bits. { @@ -1306,8 +1342,8 @@ uint32_t bli_cpuid_query // only if the xcr[7:5] bits are set. If they are not set, then // clear all feature bits related to AVX-512. if ( !bli_cpuid_has_features( eax, XGETBV_MASK_XMM | - XGETBV_MASK_YMM | - XGETBV_MASK_ZMM ) ) + XGETBV_MASK_YMM | + XGETBV_MASK_ZMM ) ) { *features &= ~( FEATURE_AVX512F | FEATURE_AVX512DQ | @@ -1322,7 +1358,7 @@ uint32_t bli_cpuid_query // only if the xcr[2] bit is set. If it is not set, then // clear all feature bits related to AVX. if ( !bli_cpuid_has_features( eax, XGETBV_MASK_XMM | - XGETBV_MASK_YMM ) ) + XGETBV_MASK_YMM ) ) { *features &= ~( FEATURE_AVX | FEATURE_AVX2 | @@ -1375,6 +1411,34 @@ uint32_t bli_cpuid_query return VENDOR_UNKNOWN; } +void bli_cpuid_check_datapath( + uint32_t vendor, + uint32_t features ) +{ + if ( vendor == VENDOR_AMD ) + { + uint32_t expected; + expected = FEATURE_DATAPATH_FP512; + if ( bli_cpuid_has_features( features, expected ) ) + { + bli_fp_datapath = DATAPATH_FP512; + return; + } + expected = FEATURE_DATAPATH_FP256; + if ( bli_cpuid_has_features( features, expected ) ) + { + bli_fp_datapath = DATAPATH_FP256; + return; + } + expected = FEATURE_DATAPATH_FP128; + if ( bli_cpuid_has_features( features, expected ) ) + { + bli_fp_datapath = DATAPATH_FP128; + return; + } + } +} + void bli_cpuid_check_cache( uint32_t vendor ) { if ( vendor == VENDOR_AMD ) @@ -1494,7 +1558,243 @@ int vpu_count( void ) } } -#elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) +#elif defined(__aarch64__) + +#ifdef __linux__ +// This is adapted from OpenBLAS. See +// https://www.kernel.org/doc/html/latest/arm64/cpu-feature-registers.html +// for the mechanism, but not the magic numbers. + +// Fixme: Could these be missing in older Linux? +#include +#include + +#ifndef HWCAP_CPUID +#define HWCAP_CPUID (1 << 11) +#endif +/* From https://www.kernel.org/doc/html/latest/arm64/sve.html and the + aarch64 hwcap.h */ +#ifndef HWCAP_SVE +#define HWCAP_SVE (1 << 22) +#endif +/* Maybe also for AT_HWCAP2 +#define HWCAP2_SVE2(1 << 1) +et al +) */ + +#endif //__linux__ + +#ifdef __APPLE__ +#include +// #include +#endif + +static uint32_t get_coretype + ( + uint32_t* features + ) +{ + int implementer = 0x00, part = 0x000; + *features = FEATURE_NEON; + +#ifdef __linux__ + if ( getauxval( AT_HWCAP ) & HWCAP_CPUID ) + { + // Also available from + // /sys/devices/system/cpu/cpu0/regs/identification/midr_el1 + // and split out in /proc/cpuinfo (with a tab before the colon): + // CPU part : 0x0a1 + + uint64_t midr_el1; + __asm("mrs %0, MIDR_EL1" : "=r" (midr_el1)); + /* + * MIDR_EL1 + * + * 31 24 23 20 19 16 15 4 3 0 + * ----------------------------------------------------------------- + * | Implementer | Variant | Architecture | Part Number | Revision | + * ----------------------------------------------------------------- + */ + implementer = (midr_el1 >> 24) & 0xFF; + part = (midr_el1 >> 4) & 0xFFF; + } + + bool has_sve = getauxval( AT_HWCAP ) & HWCAP_SVE; + if (has_sve) + *features |= FEATURE_SVE; +#endif //__linux__ + +#ifdef __APPLE__ + // Better values could be obtained from sysctlbyname() + implementer = 0x61; //Apple + part = 0x023; //Firestorm +#endif //__APPLE__ + + // From Linux arch/arm64/include/asm/cputype.h + // ARM_CPU_IMP_ARM 0x41 + // ARM_CPU_IMP_APM 0x50 + // ARM_CPU_IMP_CAVIUM 0x43 + // ARM_CPU_IMP_BRCM 0x42 + // ARM_CPU_IMP_QCOM 0x51 + // ARM_CPU_IMP_NVIDIA 0x4E + // ARM_CPU_IMP_FUJITSU 0x46 + // ARM_CPU_IMP_HISI 0x48 + // ARM_CPU_IMP_APPLE 0x61 + // + // ARM_CPU_PART_AEM_V8 0xD0F + // ARM_CPU_PART_FOUNDATION 0xD00 + // ARM_CPU_PART_CORTEX_A57 0xD07 + // ARM_CPU_PART_CORTEX_A72 0xD08 + // ARM_CPU_PART_CORTEX_A53 0xD03 + // ARM_CPU_PART_CORTEX_A73 0xD09 + // ARM_CPU_PART_CORTEX_A75 0xD0A + // ARM_CPU_PART_CORTEX_A35 0xD04 + // ARM_CPU_PART_CORTEX_A55 0xD05 + // ARM_CPU_PART_CORTEX_A76 0xD0B + // ARM_CPU_PART_NEOVERSE_N1 0xD0C + // ARM_CPU_PART_CORTEX_A77 0xD0D + // from GCC: + // ARM_CPU_PART_CORTEX_A78 0xd41 + // ARM_CPU_PART_CORTEX_X1 0xd44 + // ARM_CPU_PART_CORTEX_V1 0xd40 + // ARM_CPU_PART_CORTEX_N2 0xd49 + // ARM_CPU_PART_CORTEX_R82 0xd15 + // + // APM_CPU_PART_POTENZA 0x000 + // + // CAVIUM_CPU_PART_THUNDERX 0x0A1 + // CAVIUM_CPU_PART_THUNDERX_81XX 0x0A2 + // CAVIUM_CPU_PART_THUNDERX_83XX 0x0A3 + // CAVIUM_CPU_PART_THUNDERX2 0x0AF + // CAVIUM_CPU_PART_THUNDERX3 0x0B8 // taken from OpenBLAS + // + // BRCM_CPU_PART_BRAHMA_B53 0x100 + // BRCM_CPU_PART_VULCAN 0x516 + // + // QCOM_CPU_PART_FALKOR_V1 0x800 + // QCOM_CPU_PART_FALKOR 0xC00 + // QCOM_CPU_PART_KRYO 0x200 + // QCOM_CPU_PART_KRYO_3XX_SILVER 0x803 + // QCOM_CPU_PART_KRYO_4XX_GOLD 0x804 + // QCOM_CPU_PART_KRYO_4XX_SILVER 0x805 + // + // NVIDIA_CPU_PART_DENVER 0x003 + // NVIDIA_CPU_PART_CARMEL 0x004 + // + // FUJITSU_CPU_PART_A64FX 0x001 + // + // HISI_CPU_PART_TSV110 0xD01 + + // APPLE_CPU_PART_M1_ICESTORM 0x022 + // APPLE_CPU_PART_M1_FIRESTORM 0x023 + + // Fixme: After merging the vpu_count branch we could report the + // part here with bli_dolog. + switch(implementer) + { + case 0x41: // ARM + switch (part) + { +#ifdef BLIS_CONFIG_CORTEXA57 + case 0xd07: // Cortex A57 + return BLIS_ARCH_CORTEXA57; +#endif +#ifdef BLIS_CONFIG_CORTEXA53 + case 0xd03: // Cortex A53 + return BLIS_ARCH_CORTEXA53; +#endif +#ifdef BLIS_CONFIG_THUNDERX2 + case 0xd0c: // Neoverse N1 (and Graviton G2?) + return BLIS_ARCH_THUNDERX2; //placeholder for N1 +#endif + } + break; + case 0x42: // Broadcom + switch (part) + { +#ifdef BLIS_CONFIG_THUNDERX2 + case 0x516: // Vulcan + return BLIS_ARCH_THUNDERX2; +#endif + } + break; + case 0x43: // Cavium + switch (part) + { +#ifdef BLIS_CONFIG_THUNDERX2 + case 0x0af: // ThunderX2 + case 0x0b8: // ThunderX3 + return BLIS_ARCH_THUNDERX2; +#endif + } + break; + case 0x46: // Fujitsu + switch (part) + { +#ifdef BLIS_CONFIG_A64FX + case 0x001: // A64FX + return BLIS_ARCH_A64FX; +#endif + } + break; + case 0x61: // Apple + switch (part) + { +#ifdef BLIS_CONFIG_FIRESTORM + case 0x022: // Icestorm (M1.LITTLE) + case 0x023: // Firestorm (M1.big) + return BLIS_ARCH_FIRESTORM; +#endif + } + break; + } + +#ifdef BLIS_CONFIG_ARMSVE + if (has_sve) + return BLIS_ARCH_ARMSVE; +#endif + +// Can't use #if defined(...) here because of parsing done for autoconfiguration +#ifdef BLIS_CONFIG_CORTEXA57 + return BLIS_ARCH_CORTEXA57; +#else +#ifdef BLIS_CONFIG_CORTEXA53 + return BLIS_ARCH_CORTEXA53; +#else + return BLIS_ARCH_GENERIC; +#endif +#endif +} + +uint32_t bli_cpuid_query + ( + uint32_t* model, + uint32_t* part, + uint32_t* features + ) +{ + *model = MODEL_ARMV8; + *part = get_coretype(features); + + return VENDOR_ARM; +} + +#elif defined(__arm__) || defined(_M_ARM) + +/* + I can't easily find documentation to do this as for aarch64, though + it presumably could be unearthed from Linux code. However, on + Linux 5.2 (and Androids's 3.4), /proc/cpuinfo has this sort of + thing, used below: + + CPU implementer : 0x41 + CPU architecture: 7 + CPU variant : 0x3 + CPU part : 0xc09 + + The complication for family selection is that Neon is optional for + CortexA9, for instance. That's tested in bli_cpuid_is_cortexa9. + */ #define TEMP_BUFFER_SIZE 200 diff --git a/frame/base/bli_cpuid.h b/frame/base/bli_cpuid.h index 5b52297589..26215f9cac 100644 --- a/frame/base/bli_cpuid.h +++ b/frame/base/bli_cpuid.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -55,6 +55,8 @@ arch_t bli_cpuid_query_id( void ); model_t bli_cpuid_query_model_id( arch_t id ); +uint32_t bli_cpuid_query_fp_datapath( void ); + uint32_t bli_cpuid_query_l1d_cache_size( void ); uint32_t bli_cpuid_query_l1i_cache_size( void ); uint32_t bli_cpuid_query_l2_cache_size( void ); @@ -68,6 +70,7 @@ bool bli_cpuid_is_sandybridge( uint32_t family, uint32_t model, uint32_t feature bool bli_cpuid_is_penryn( uint32_t family, uint32_t model, uint32_t features ); // AMD +bool bli_cpuid_is_zen5( uint32_t family, uint32_t model, uint32_t features ); bool bli_cpuid_is_zen4( uint32_t family, uint32_t model, uint32_t features ); bool bli_cpuid_is_avx512_fallback( uint32_t family, uint32_t model, uint32_t features ); bool bli_cpuid_is_zen3( uint32_t family, uint32_t model, uint32_t features ); @@ -78,6 +81,7 @@ bool bli_cpuid_is_steamroller( uint32_t family, uint32_t model, uint32_t feature bool bli_cpuid_is_piledriver( uint32_t family, uint32_t model, uint32_t features ); bool bli_cpuid_is_bulldozer( uint32_t family, uint32_t model, uint32_t features ); +model_t bli_cpuid_get_zen5_cpuid_model( uint32_t family, uint32_t model, uint32_t features ); model_t bli_cpuid_get_zen4_cpuid_model( uint32_t family, uint32_t model, uint32_t features ); model_t bli_cpuid_get_zen3_cpuid_model( uint32_t family, uint32_t model, uint32_t features ); @@ -92,6 +96,8 @@ bool bli_cpuid_is_cortexa9( uint32_t model, uint32_t part, uint32_t features ); uint32_t bli_cpuid_query( uint32_t* family, uint32_t* model, uint32_t* features ); +void bli_cpuid_check_datapath( uint32_t vendor, uint32_t features ); + void bli_cpuid_check_cache( uint32_t vendor ); // ----------------------------------------------------------------------------- @@ -167,23 +173,41 @@ enum }; enum { - FEATURE_SSE3 = 0x0001, - FEATURE_SSSE3 = 0x0002, - FEATURE_SSE41 = 0x0004, - FEATURE_SSE42 = 0x0008, - FEATURE_AVX = 0x0010, - FEATURE_AVX2 = 0x0020, - FEATURE_FMA3 = 0x0040, - FEATURE_FMA4 = 0x0080, - FEATURE_AVX512F = 0x0100, - FEATURE_AVX512DQ = 0x0200, - FEATURE_AVX512PF = 0x0400, - FEATURE_AVX512ER = 0x0800, - FEATURE_AVX512CD = 0x1000, - FEATURE_AVX512BW = 0x2000, - FEATURE_AVX512VL = 0x4000, - FEATURE_AVX512VNNI = 0x8000, - FEATURE_AVX512BF16 = 0x10000 + FEATURE_SSE3 = 0x0001, + FEATURE_SSSE3 = 0x0002, + FEATURE_SSE41 = 0x0004, + FEATURE_SSE42 = 0x0008, + FEATURE_AVX = 0x0010, + FEATURE_AVX2 = 0x0020, + FEATURE_FMA3 = 0x0040, + FEATURE_FMA4 = 0x0080, + FEATURE_AVX512F = 0x0100, + FEATURE_AVX512DQ = 0x0200, + FEATURE_AVX512PF = 0x0400, + FEATURE_AVX512ER = 0x0800, + FEATURE_AVX512CD = 0x1000, + FEATURE_AVX512BW = 0x2000, + FEATURE_AVX512VL = 0x4000, + FEATURE_AVX512VNNI = 0x8000, + FEATURE_AVX512BF16 = 0x10000, + FEATURE_AVXVNNI = 0x20000, + FEATURE_AVX512VP2INTERSECT = 0x40000, + FEATURE_MOVDIRI = 0x80000, + FEATURE_MOVDIR64B = 0x100000, + FEATURE_DATAPATH_FP128 = 0x200000, + FEATURE_DATAPATH_FP256 = 0x400000, + FEATURE_DATAPATH_FP512 = 0x800000 +}; + +// To reduce confusion, include MOVU bit so enum values match those in +// CPUID_Fn8000001A_EAX id function. +enum +{ + DATAPATH_UNSET = -1, + DATAPATH_FP128, + DATAPATH_MOVU, + DATAPATH_FP256, + DATAPATH_FP512 }; #elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) diff --git a/frame/base/bli_env.c b/frame/base/bli_env.c index 229aae2581..b110127f8b 100644 --- a/frame/base/bli_env.c +++ b/frame/base/bli_env.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -150,6 +150,10 @@ gint_t bli_env_get_var_arch_type( const char* env, gint_t fallback ) r_val = BLIS_ARCH_PENRYN; } // AMD + else if (strcmp(str, "zen5") == 0) + { + r_val = BLIS_ARCH_ZEN5; + } else if (strcmp(str, "zen4") == 0) { r_val = BLIS_ARCH_ZEN4; @@ -184,44 +188,99 @@ gint_t bli_env_get_var_arch_type( const char* env, gint_t fallback ) r_val = BLIS_ARCH_BULLDOZER; } // Some aliases for mapping AMD and Intel ISA - // names to a suitable sub-configuration. -#if defined(BLIS_FAMILY_AMDZEN) || defined(BLIS_FAMILY_X86_64) || defined(BLIS_FAMILY_ZEN4) || defined(BLIS_FAMILY_ZEN3) || defined(BLIS_FAMILY_ZEN2) || defined(BLIS_FAMILY_ZEN) + // names to a suitable sub-configuration for each + // x86-64 processor family. +#if defined(BLIS_FAMILY_AMDZEN) else if (strcmp(str, "avx512") == 0) { r_val = BLIS_ARCH_ZEN4; } + else if (strcmp(str, "avx2") == 0) + { + r_val = BLIS_ARCH_ZEN3; + } + else if (strcmp(str, "avx") == 0) + { + r_val = BLIS_ARCH_GENERIC; + } + else if ((strcmp(str, "sse4_2") == 0) || + (strcmp(str, "sse4.2") == 0) || + (strcmp(str, "sse4_1") == 0) || + (strcmp(str, "sse4.1") == 0) || + (strcmp(str, "sse4a") == 0) || + (strcmp(str, "sse4") == 0) || + (strcmp(str, "ssse3") == 0) || + (strcmp(str, "sse3") == 0) || + (strcmp(str, "sse2") == 0)) + { + r_val = BLIS_ARCH_GENERIC; + } #endif -#if defined(BLIS_FAMILY_INTEL64) || defined(BLIS_FAMILY_SKX) || defined(BLIS_FAMILY_HASWELL) +#if defined(BLIS_FAMILY_X86_64) else if (strcmp(str, "avx512") == 0) { - r_val = BLIS_ARCH_SKX; + r_val = BLIS_ARCH_ZEN4; } -#endif -#if defined(BLIS_FAMILY_AMDZEN) || defined(BLIS_FAMILY_X86_64) || defined(BLIS_FAMILY_ZEN4) ||defined(BLIS_FAMILY_ZEN3) else if (strcmp(str, "avx2") == 0) { r_val = BLIS_ARCH_ZEN3; } -#endif -#if defined(BLIS_FAMILY_ZEN2) - else if (strcmp(str, "avx2") == 0) + else if (strcmp(str, "avx") == 0) { - r_val = BLIS_ARCH_ZEN2; + r_val = BLIS_ARCH_SANDYBRIDGE; } -#endif -#if defined(BLIS_FAMILY_ZEN) - else if (strcmp(str, "avx2") == 0) + else if ((strcmp(str, "sse4_2") == 0) || + (strcmp(str, "sse4.2") == 0) || + (strcmp(str, "sse4_1") == 0) || + (strcmp(str, "sse4.1") == 0) || + (strcmp(str, "sse4a") == 0) || + (strcmp(str, "sse4") == 0) || + (strcmp(str, "ssse3") == 0) || + (strcmp(str, "sse3") == 0) || + (strcmp(str, "sse2") == 0)) { - r_val = BLIS_ARCH_ZEN; + r_val = BLIS_ARCH_GENERIC; } #endif -#if defined(BLIS_FAMILY_INTEL64) || defined(BLIS_FAMILY_SKX) || defined(BLIS_FAMILY_HASWELL) +#if defined(BLIS_FAMILY_INTEL64) + else if (strcmp(str, "avx512") == 0) + { + r_val = BLIS_ARCH_SKX; + } else if (strcmp(str, "avx2") == 0) { r_val = BLIS_ARCH_HASWELL; } + else if (strcmp(str, "avx") == 0) + { + r_val = BLIS_ARCH_SANDYBRIDGE; + } + else if ((strcmp(str, "sse4_2") == 0) || + (strcmp(str, "sse4.2") == 0) || + (strcmp(str, "sse4_1") == 0) || + (strcmp(str, "sse4.1") == 0) || + (strcmp(str, "sse4a") == 0) || + (strcmp(str, "sse4") == 0) || + (strcmp(str, "ssse3") == 0) || + (strcmp(str, "sse3") == 0) || + (strcmp(str, "sse2") == 0)) + { + r_val = BLIS_ARCH_GENERIC; + } #endif // ARM + else if (strcmp(str, "armsve") == 0) + { + r_val = BLIS_ARCH_ARMSVE; + } + else if (strcmp(str, "a64fx") == 0) + { + r_val = BLIS_ARCH_A64FX; + } + else if (strcmp(str, "firestorm") == 0) + { + r_val = BLIS_ARCH_FIRESTORM; + } else if (strcmp(str, "thunderx2") == 0) { r_val = BLIS_ARCH_THUNDERX2; @@ -313,7 +372,17 @@ gint_t bli_env_get_var_model_type( const char* env, gint_t fallback ) str[i] = tolower(str[i]); } // AMD - if (strcmp(str, "genoa") == 0) + if (strcmp(str, "turin") == 0) + { + r_val = BLIS_MODEL_TURIN; + } + else if ((strcmp(str, "turin_dense") == 0) || + (strcmp(str, "turin-dense") == 0) || + (strcmp(str, "turindense") == 0)) + { + r_val = BLIS_MODEL_TURIN_DENSE; + } + else if (strcmp(str, "genoa") == 0) { r_val = BLIS_MODEL_GENOA; } diff --git a/frame/base/bli_error.c b/frame/base/bli_error.c index 8e60f57039..8853fd43fa 100644 --- a/frame/base/bli_error.c +++ b/frame/base/bli_error.c @@ -36,7 +36,7 @@ #include "blis.h" // Internal array to hold error strings. -static char bli_error_string[BLIS_MAX_NUM_ERR_MSGS][BLIS_MAX_ERR_MSG_LENGTH] = +static char *bli_error_string[-BLIS_ERROR_CODE_MAX] = { [-BLIS_INVALID_ERROR_CHECKING_LEVEL] = "Invalid error checking level.", [-BLIS_UNDEFINED_ERROR_CODE] = "Undefined error code.", @@ -134,11 +134,8 @@ void bli_abort( void ) // ----------------------------------------------------------------------------- -// A mutex to allow synchronous access to bli_err_chk_level. -static bli_pthread_mutex_t err_mutex = BLIS_PTHREAD_MUTEX_INITIALIZER; - // Current error checking level. -static errlev_t bli_err_chk_level = BLIS_FULL_ERROR_CHECKING; +static BLIS_THREAD_LOCAL errlev_t bli_err_chk_level = BLIS_FULL_ERROR_CHECKING; errlev_t bli_error_checking_level( void ) { @@ -152,17 +149,7 @@ void bli_error_checking_level_set( errlev_t new_level ) e_val = bli_check_valid_error_level( new_level ); bli_check_error_code( e_val ); - // Acquire the mutex protecting bli_err_chk_level. - bli_pthread_mutex_lock( &err_mutex ); - - // BEGIN CRITICAL SECTION - { - bli_err_chk_level = new_level; - } - // END CRITICAL SECTION - - // Release the mutex protecting bli_err_chk_level. - bli_pthread_mutex_unlock( &err_mutex ); + bli_err_chk_level = new_level; } bool bli_error_checking_is_enabled( void ) diff --git a/frame/base/bli_gks.c b/frame/base/bli_gks.c index 321d725554..943988a246 100644 --- a/frame/base/bli_gks.c +++ b/frame/base/bli_gks.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -50,7 +50,7 @@ static void_fp cntx_ref_init[ BLIS_NUM_ARCHS ]; // Define a function pointer type for context initialization functions. typedef void (*nat_cntx_init_ft)( cntx_t* cntx ); typedef void (*ref_cntx_init_ft)( cntx_t* cntx ); -typedef void (*ind_cntx_init_ft)( ind_t method, num_t dt, cntx_t* cntx ); +typedef void (*ind_cntx_init_ft)( ind_t method, cntx_t* cntx ); // ----------------------------------------------------------------------------- @@ -107,6 +107,11 @@ void bli_gks_init( void ) #endif // AMD architectures +#ifdef BLIS_CONFIG_ZEN5 + bli_gks_register_cntx( BLIS_ARCH_ZEN5, bli_cntx_init_zen5, + bli_cntx_init_zen5_ref, + bli_cntx_init_zen5_ind ); +#endif #ifdef BLIS_CONFIG_ZEN4 bli_gks_register_cntx( BLIS_ARCH_ZEN4, bli_cntx_init_zen4, bli_cntx_init_zen4_ref, @@ -149,6 +154,11 @@ void bli_gks_init( void ) #endif // ARM architectures +#ifdef BLIS_CONFIG_A64FX + bli_gks_register_cntx( BLIS_ARCH_A64FX, bli_cntx_init_a64fx, + bli_cntx_init_a64fx_ref, + bli_cntx_init_a64fx_ind ); +#endif #ifdef BLIS_CONFIG_THUNDERX2 bli_gks_register_cntx( BLIS_ARCH_THUNDERX2, bli_cntx_init_thunderx2, bli_cntx_init_thunderx2_ref, @@ -174,6 +184,11 @@ void bli_gks_init( void ) bli_cntx_init_a64fx_ref, bli_cntx_init_a64fx_ind ); #endif +#ifdef BLIS_CONFIG_FIRESTORM + bli_gks_register_cntx( BLIS_ARCH_FIRESTORM, bli_cntx_init_firestorm, + bli_cntx_init_firestorm_ref, + bli_cntx_init_firestorm_ind ); +#endif #ifdef BLIS_CONFIG_CORTEXA15 bli_gks_register_cntx( BLIS_ARCH_CORTEXA15, bli_cntx_init_cortexa15, bli_cntx_init_cortexa15_ref, @@ -622,7 +637,7 @@ cntx_t* bli_gks_query_ind_cntx // function for the current induced method. (That function assumes // that the context is pre- initialized with values for native // execution.) - f( ind, dt, gks_id_ind ); + f( ind, gks_id_ind ); } } // END CRITICAL SECTION diff --git a/frame/ind/bli_ind.c b/frame/base/bli_ind.c similarity index 85% rename from frame/ind/bli_ind.c rename to frame/base/bli_ind.c index 28fb44669d..a359e89a38 100644 --- a/frame/ind/bli_ind.c +++ b/frame/base/bli_ind.c @@ -36,11 +36,6 @@ static char* bli_ind_impl_str[BLIS_NUM_IND_METHODS] = { -/* 3mh */ "3mh", -/* 3m1 */ "3m1", -/* 4mh */ "4mh", -/* 4m1b */ "4m1b", -/* 4m1a */ "4m1a", /* 1m */ "1m", /* nat */ "native", }; @@ -147,8 +142,9 @@ bool bli_ind_oper_is_impl( opid_t oper, ind_t method ) if ( bli_opid_is_level3( oper ) ) { - // Look up whether its func_t pointer in the table is NULL. - is_impl = ( bli_l3_ind_oper_get_func( oper, method ) != NULL ); + // Look up whether the operation is implemented for the given induced + // method id. + is_impl = bli_l3_ind_oper_is_impl( oper, method ); } else { @@ -162,39 +158,6 @@ bool bli_ind_oper_is_impl( opid_t oper, ind_t method ) return is_impl; } -#if 0 -bool bli_ind_oper_has_avail( opid_t oper, num_t dt ) -{ - ind_t method = bli_ind_oper_find_avail( oper, dt ); - - if ( method == BLIS_NAT ) return FALSE; - else return TRUE; -} -#endif - -void_fp bli_ind_oper_get_avail( opid_t oper, num_t dt ) -{ - void_fp func_p; - - if ( bli_opid_is_level3( oper ) ) - { - ind_t method = bli_ind_oper_find_avail( oper, dt ); - - func_p = bli_l3_ind_oper_get_func( oper, method ); - } - else - { - // Currently, any operation that is not level-3 does not - // have induced method implementations. (This should actually - // assign the pointer to be the native front-end, but for - // now there are no calls to bli_ind_oper_get_avail() in the - // context of level-2 operations. - func_p = NULL; - } - - return func_p; -} - ind_t bli_ind_oper_find_avail( opid_t oper, num_t dt ) { ind_t method; diff --git a/frame/ind/bli_ind.h b/frame/base/bli_ind.h similarity index 89% rename from frame/ind/bli_ind.h rename to frame/base/bli_ind.h index 57bd6e5c59..85cad648e9 100644 --- a/frame/ind/bli_ind.h +++ b/frame/base/bli_ind.h @@ -38,16 +38,6 @@ // level-3 induced method management #include "bli_l3_ind.h" -// level-3 object APIs -#include "bli_l3_ind_oapi.h" - -// level-3 typed APIs -#include "bli_l3_ind_tapi.h" - -// level-3 cntx initialization -#include "bli_cntx_ind_stage.h" - - void bli_ind_init( void ); void bli_ind_finalize( void ); @@ -62,8 +52,6 @@ BLIS_EXPORT_BLIS void bli_ind_disable_all_dt( num_t dt ); BLIS_EXPORT_BLIS void bli_ind_oper_enable_only( opid_t oper, ind_t method, num_t dt ); BLIS_EXPORT_BLIS bool bli_ind_oper_is_impl( opid_t oper, ind_t method ); -//bool bli_ind_oper_has_avail( opid_t oper, num_t dt ); -BLIS_EXPORT_BLIS void_fp bli_ind_oper_get_avail( opid_t oper, num_t dt ); BLIS_EXPORT_BLIS ind_t bli_ind_oper_find_avail( opid_t oper, num_t dt ); BLIS_EXPORT_BLIS char* bli_ind_oper_get_avail_impl_string( opid_t oper, num_t dt ); diff --git a/frame/base/bli_init.c b/frame/base/bli_init.c index ed0567f3cc..511fbe7f85 100644 --- a/frame/base/bli_init.c +++ b/frame/base/bli_init.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/frame/base/bli_init.h b/frame/base/bli_init.h index f174ac0f99..9cf2378ca4 100644 --- a/frame/base/bli_init.h +++ b/frame/base/bli_init.h @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -35,8 +36,8 @@ BLIS_EXPORT_BLIS void bli_init( void ); BLIS_EXPORT_BLIS void bli_finalize( void ); -void bli_init_auto( void ); -void bli_finalize_auto( void ); +BLIS_EXPORT_BLIS void bli_init_auto( void ); +BLIS_EXPORT_BLIS void bli_finalize_auto( void ); void bli_init_apis( void ); void bli_finalize_apis( void ); diff --git a/frame/base/bli_pba.h b/frame/base/bli_pba.h index 23e35452d0..cbb57de9ac 100644 --- a/frame/base/bli_pba.h +++ b/frame/base/bli_pba.h @@ -6,7 +6,7 @@ Copyright (C) 2014, The University of Texas at Austin Copyright (C) 2016, Hewlett Packard Enterprise Development LP - Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -144,7 +144,7 @@ void bli_pba_release mem_t* mem ); -void bli_pba_rntm_set_pba +BLIS_EXPORT_BLIS void bli_pba_rntm_set_pba ( rntm_t* rntm ); diff --git a/frame/base/bli_query.c b/frame/base/bli_query.c index c62a30cccd..454f17d191 100644 --- a/frame/base/bli_query.c +++ b/frame/base/bli_query.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -36,7 +37,7 @@ bool bli_obj_equals( obj_t* a, obj_t* b ) { -#if 0 +#if 1 bool r_val = FALSE; num_t dt_a; num_t dt_b; @@ -45,7 +46,15 @@ bool bli_obj_equals( obj_t* a, obj_t* b ) // The function is not yet implemented for vectors and matrices. if ( !bli_obj_is_1x1( a ) || !bli_obj_is_1x1( b ) ) - bli_check_error_code( BLIS_NOT_YET_IMPLEMENTED ); + { + + if ( bli_obj_is_vector( a ) && bli_obj_is_vector( b ) ) + bli_eqv( a, b, &r_val ); + else + bli_eqm( a, b, &r_val ); + + return r_val; + } dt_a = bli_obj_dt( a ); dt_b = bli_obj_dt( b ); diff --git a/frame/base/bli_rntm.c b/frame/base/bli_rntm.c index 91d3b5753e..85f6ec1776 100644 --- a/frame/base/bli_rntm.c +++ b/frame/base/bli_rntm.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2021 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -575,7 +575,10 @@ void bli_nthreads_optimum( dim_t n = bli_obj_width(c); dim_t k = bli_obj_width_after_trans(a); - if(bli_arch_query_id() == BLIS_ARCH_ZEN4) + + // Query the architecture ID + arch_t id = bli_arch_query_id(); + if(id == BLIS_ARCH_ZEN5 || id == BLIS_ARCH_ZEN4) { if(n < m) { @@ -1138,7 +1141,7 @@ void bli_nthreads_optimum( } } } - else + else // Not BLIS_ARCH_ZEN5 or BLIS_ARCH_ZEN4 { if( k >= 128) { @@ -1283,13 +1286,21 @@ void bli_nthreads_optimum( { dim_t m = bli_obj_length(c); dim_t n = bli_obj_width(c); - +#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM + if ( (m <= 300) && (n <= 300) ) + n_threads_ideal = 8; + else if ( (m <= 400) && (n <= 400) ) + n_threads_ideal = 16; + else if ( (m <= 900) && (n <= 900) ) + n_threads_ideal = 32; +#else if((m>=64) && (m<=256) && (n>=64) && (n<=256)) { n_threads_ideal = 8; } +#endif } - else if( family == BLIS_GEMMT && bli_obj_is_double(c) ) + else if( family == BLIS_GEMMT && ( bli_obj_is_double(c) || bli_obj_is_dcomplex(c) ) ) { dim_t n = bli_obj_length(c); dim_t k = bli_obj_width_after_trans(a); @@ -1587,7 +1598,7 @@ err_t bli_smart_threading_sup 1. For non-Zen architectures, return -1. The expectation is that this is handled in the higher layer */ -static void aocl_dscalv_dynamic +BLIS_INLINE void aocl_dscalv_dynamic ( arch_t arch_id, dim_t n_elem, @@ -1601,6 +1612,7 @@ static void aocl_dscalv_dynamic */ switch (arch_id) { + case BLIS_ARCH_ZEN5: case BLIS_ARCH_ZEN4: case BLIS_ARCH_ZEN: case BLIS_ARCH_ZEN2: @@ -1612,9 +1624,9 @@ static void aocl_dscalv_dynamic *nt_ideal = 2; else if (n_elem <= 500000) *nt_ideal = 8; - else if (n_elem <= 4000000) - *nt_ideal = 12; else if (n_elem <= 2500000) + *nt_ideal = 12; + else if (n_elem <= 4000000) *nt_ideal = 16; else if(n_elem <= 7000000) *nt_ideal = 24; @@ -1661,7 +1673,7 @@ static void aocl_dscalv_dynamic 1. For non-Zen architectures, return -1. The expectation is that this is handled in the higher layer */ -static void aocl_zdscalv_dynamic +BLIS_INLINE void aocl_zdscalv_dynamic ( arch_t arch_id, dim_t n_elem, @@ -1675,6 +1687,7 @@ static void aocl_zdscalv_dynamic */ switch (arch_id) { + case BLIS_ARCH_ZEN5: case BLIS_ARCH_ZEN4: case BLIS_ARCH_ZEN: case BLIS_ARCH_ZEN2: @@ -1731,7 +1744,7 @@ static void aocl_zdscalv_dynamic 1. For non-Zen architectures, return -1. The expectation is that this is handled in the higher layer */ -static void aocl_daxpyv_dynamic +BLIS_INLINE void aocl_daxpyv_dynamic ( arch_t arch_id, dim_t n_elem, @@ -1744,7 +1757,46 @@ static void aocl_daxpyv_dynamic */ switch (arch_id) { + case BLIS_ARCH_ZEN5: + + if ( n_elem <= 34000 ) + *nt_ideal = 1; + else if ( n_elem <= 82000 ) + *nt_ideal = 4; + else if ( n_elem <= 2330000 ) + *nt_ideal = 8; + else if ( n_elem <= 4250000 ) + *nt_ideal = 16; + else if ( n_elem <= 7000000 ) + *nt_ideal = 32; + else if ( n_elem <= 21300000 ) + *nt_ideal = 64; + else + // For sizes in this range, AOCL dynamic does not make any change + *nt_ideal = -1; + + break; + case BLIS_ARCH_ZEN4: + + if ( n_elem <= 11000 ) + *nt_ideal = 1; + else if ( n_elem <= 130000 ) + *nt_ideal = 4; + else if ( n_elem <= 2230000 ) + *nt_ideal = 8; + else if ( n_elem <= 3400000 ) + *nt_ideal = 16; + else if ( n_elem <= 9250000 ) + *nt_ideal = 32; + else if ( n_elem <= 15800000 ) + *nt_ideal = 64; + else + // For sizes in this range, AOCL dynamic does not make any change + *nt_ideal = -1; + + break; + case BLIS_ARCH_ZEN: case BLIS_ARCH_ZEN2: case BLIS_ARCH_ZEN3: @@ -1780,6 +1832,104 @@ static void aocl_daxpyv_dynamic } } +/* + Functionality: + -------------- + This function decides the AOCL dynamic logic for L1 zaxpyv API based on the + architecture ID and size of the input variable. + + Function signature + ------------------- + + This function takes the following input: + + * 'arch_id' - Architecture ID of the system (copy of BLIS global arch id) + * 'n_elem' - Number of elements in the vector + * 'nt_ideal' - Ideal number of threads + + The function has been made static to restrict its scope. + + Exception + ---------- + + 1. For non-Zen architectures, return -1. The expectation is that this is handled + in the higher layer +*/ +BLIS_INLINE void aocl_zaxpyv_dynamic + ( + arch_t arch_id, + dim_t n_elem, + dim_t* nt_ideal + ) +{ + /* + Pick the AOCL dynamic logic based on the + architecture ID + */ + switch (arch_id) + { + case BLIS_ARCH_ZEN5: + + if ( n_elem <= 16000 ) + *nt_ideal = 1; + else if (n_elem <= 43000) + *nt_ideal = 4; + else if (n_elem <= 2300000) + *nt_ideal = 8; + else if (n_elem <= 4000000) + *nt_ideal = 32; + else if (n_elem <= 6600000) + *nt_ideal = 64; + else if (n_elem <= 6600000) + *nt_ideal = 96; + else + *nt_ideal = 128; + break; + + case BLIS_ARCH_ZEN4: + + if ( n_elem <= 4600 ) + *nt_ideal = 1; + else if (n_elem <= 6700) + *nt_ideal = 2; + else if (n_elem <= 61500) + *nt_ideal = 4; + else if (n_elem <= 1200000) + *nt_ideal = 8; + else if (n_elem <= 4000000) + *nt_ideal = 32; + else + *nt_ideal = 96; + break; + + case BLIS_ARCH_ZEN: + case BLIS_ARCH_ZEN2: + case BLIS_ARCH_ZEN3: + + if ( n_elem <= 2600 ) + *nt_ideal = 1; + else if( n_elem <= 11000) + *nt_ideal = 2; + else if (n_elem <= 33000) + *nt_ideal = 4; + else + // Performance does not scale with number of threads beyond 8 threads + *nt_ideal = 8; + break; + + default: + /* + Without this default condition, compiler will throw + a warning saying other conditions are not handled + */ + + /* + For other architectures, AOCL dynamic does not make any change + */ + *nt_ideal = -1; + } +} + /* Functionality: -------------- @@ -1803,7 +1953,7 @@ static void aocl_daxpyv_dynamic 1. For non-Zen architectures, return -1. The expectation is that this is handled in the higher layer */ -static void aocl_ddotv_dynamic +BLIS_INLINE void aocl_ddotv_dynamic ( arch_t arch_id, dim_t n_elem, @@ -1816,6 +1966,7 @@ static void aocl_ddotv_dynamic */ switch (arch_id) { + case BLIS_ARCH_ZEN5: case BLIS_ARCH_ZEN4: case BLIS_ARCH_ZEN: case BLIS_ARCH_ZEN2: @@ -1850,103 +2001,693 @@ static void aocl_ddotv_dynamic } } -#endif // AOCL_DYNAMIC +BLIS_INLINE void aocl_zdotv_dynamic + ( + arch_t arch_id, + dim_t n_elem, + dim_t* nt_ideal + ) +{ + /* + Pick the AOCL dynamic logic based on the + architecture ID + */ + switch (arch_id) + { + case BLIS_ARCH_ZEN5: + case BLIS_ARCH_ZEN4: + case BLIS_ARCH_ZEN: + case BLIS_ARCH_ZEN2: + case BLIS_ARCH_ZEN3: + // @note: Further tuning can be done. + if ( n_elem <= 2080 ) + *nt_ideal = 1; + else if (n_elem <= 3328 ) + *nt_ideal = 4; + else if (n_elem <= 98304) + *nt_ideal = 8; + else if (n_elem <= 262144) + *nt_ideal = 32; + else if (n_elem <= 524288) + *nt_ideal = 64; + else + // For sizes in this range, AOCL dynamic does not make any change + *nt_ideal = -1; + + break; + + default: + /* + Without this default condition, compiler will throw + a warning saying other conditions are not handled + */ + + /* + For other architectures, AOCL dynamic does not make any change + */ + *nt_ideal = -1; + } +} /* Functionality: -------------- - - This function does the following: - 1. Reads the number of threads requested by the user from the rntm variable - 2. Acts as the gateway to the AOCL dynamic logic if AOCL dynamic is enabled - and alters the count of the number of threads accordingly + This function decides the AOCL dynamic logic for L1 dcopyv API based on the + architecture ID, input type and size of the input variable. Function signature ------------------- This function takes the following input: - * 'ker_id' - ID of kernel invoking this function - * 'datatype_a' - Datatype 1 of kernel - * 'datatype_b' - Datatype 2 of kernel * 'arch_id' - Architecture ID of the system (copy of BLIS global arch id) * 'n_elem' - Number of elements in the vector * 'nt_ideal' - Ideal number of threads + The function has been made static to restrict its scope. + Exception ---------- - None + 1. For non-Zen architectures, return -1. The expectation is that this is handled + in the higher layer */ -void bli_nthreads_l1 + +BLIS_INLINE void aocl_dcopyv_dynamic ( - l1vkr_t ker_id, - num_t data_type_a, - num_t data_type_b, - arch_t arch_id, - dim_t n_elem, - dim_t* nt_ideal + arch_t arch_id, + dim_t n_elem, + dim_t* nt_ideal ) { -#ifdef AOCL_DYNAMIC - /* - This code sections dispatches the AOCL dynamic logic kernel for - L1 APIs based on the kernel ID and the data type. - */ - // Function pointer to AOCL Dynamic logic kernel - void (*aocl_dynamic_func_l1)(arch_t, dim_t, dim_t* ) = NULL; + // Pick the AOCL dynamic logic based on the + // architecture ID - // Pick the aocl dynamic thread decision kernel based on the kernel ID - switch (ker_id) + switch (arch_id) { - case BLIS_SCALV_KER: + case BLIS_ARCH_ZEN5: - /* - When input data types do not match the call is from mixed precision - */ - if (data_type_a != data_type_b) - { - // Function for ZDSCALV - aocl_dynamic_func_l1 = aocl_zdscalv_dynamic; - } + if ( n_elem <= 39000 ) + *nt_ideal = 1; + else if ( n_elem <= 46000 ) + *nt_ideal = 2; + else if (n_elem <= 160000) + *nt_ideal = 4; else - { - // Function for DSCALV - aocl_dynamic_func_l1 = aocl_dscalv_dynamic; - } - + *nt_ideal = 8; + // dcopy does not scale with more than 8 threads break; - case BLIS_AXPYV_KER: - - // Function for DAXPYV - aocl_dynamic_func_l1 = aocl_daxpyv_dynamic; + case BLIS_ARCH_ZEN4: + if ( n_elem <= 17000 ) + *nt_ideal = 1; + else if (n_elem <= 62000) + *nt_ideal = 2; + else if (n_elem <= 96000) + *nt_ideal = 4; + else + *nt_ideal = 8; + // dcopy does not scale with more than 8 threads break; + case BLIS_ARCH_ZEN: + case BLIS_ARCH_ZEN2: + case BLIS_ARCH_ZEN3: - case BLIS_DOTV_KER: - - // Function for DDOTV - aocl_dynamic_func_l1 = aocl_ddotv_dynamic; - + if ( n_elem <= 17000 ) + *nt_ideal = 1; + else if (n_elem <= 52200) + *nt_ideal = 4; + else + *nt_ideal = 8; + // dcopy does not scale with more than 8 threads break; default: - /* - For kernels that do no have AOCL dynamic logic, - use the number of threads requested by the user. - */ + // Without this default condition, compiler will throw + // a warning saying other conditions are not handled + // For other architectures, AOCL dynamic does not make any change *nt_ideal = -1; } +} - /* - For APIs that do not have AOCL dynamic - logic, aocl_dynamic_func_l1 will be NULL. - */ - if( aocl_dynamic_func_l1 != NULL) - { - // Call the AOCL dynamic logic kernel - aocl_dynamic_func_l1 +/* + Functionality: + -------------- + This function decides the AOCL dynamic logic for L1 zcopyv API based on the + architecture ID, input type and size of the input variable. + + Function signature + ------------------- + + This function takes the following input: + + * 'arch_id' - Architecture ID of the system (copy of BLIS global arch id) + * 'n_elem' - Number of elements in the vector + * 'nt_ideal' - Ideal number of threads + + The function has been made static to restrict its scope. + + Exception + ---------- + + 1. For non-Zen architectures, return -1. The expectation is that this is handled + in the higher layer +*/ + +BLIS_INLINE void aocl_zcopyv_dynamic + ( + arch_t arch_id, + dim_t n_elem, + dim_t* nt_ideal + ) +{ + // Pick the AOCL dynamic logic based on the + // architecture ID + + switch (arch_id) + { + case BLIS_ARCH_ZEN5: + case BLIS_ARCH_ZEN4: + case BLIS_ARCH_ZEN: + case BLIS_ARCH_ZEN2: + case BLIS_ARCH_ZEN3: + + if ( n_elem <= 4600 ) + *nt_ideal = 1; + else if (n_elem <= 5100) + *nt_ideal = 2; + else if (n_elem <= 22000) + *nt_ideal = 4; + else if (n_elem <= 240000) + *nt_ideal = 8; + else if (n_elem <=380000) + *nt_ideal = 16; + else if (n_elem <= 1700000) + *nt_ideal = 32; + else if (n_elem <= 3700000) + *nt_ideal = 64; + else + // For sizes in this range, AOCL dynamic does not make any change + *nt_ideal = -1; + + break; + + default: + // Without this default condition, compiler will throw + // a warning saying other conditions are not handled + + // For other architectures, AOCL dynamic does not make any change + *nt_ideal = -1; + } +} + +/* + Functionality: + -------------- + This function decides the AOCL dynamic logic for L1 dnormfv API based on the + architecture ID and size of the input variable. + + Function signature + ------------------- + + This function takes the following input: + + * 'arch_id' - Architecture ID of the system (copy of BLIS global arch id) + * 'n_elem' - Number of elements in the vector + * 'nt_ideal' - Ideal number of threads + + Exception + ---------- + + 1. For non-Zen architectures, return -1. The expectation is that this is handled + in the higher layer +*/ +void aocl_dnormfv_dynamic + ( + arch_t arch_id, + dim_t n_elem, + dim_t* nt_ideal + ) +{ + /* + Pick the AOCL dynamic logic based on the + architecture ID + */ + switch ( arch_id ) + { + case BLIS_ARCH_ZEN5: + + #ifdef __clang__ + // Threshold setting based on LLVM's OpenMP + if ( n_elem < 6000 ) + *nt_ideal = 1; + else if ( n_elem < 16900 ) + *nt_ideal = 4; + else if ( n_elem < 126000 ) + *nt_ideal = 8; + else if ( n_elem < 200000 ) + *nt_ideal = 16; + else if ( n_elem < 250000 ) + *nt_ideal = 32; + else if ( n_elem < 500000 ) + *nt_ideal = 64; + else + // For sizes in this range, AOCL dynamic does not make any change + *nt_ideal = -1; + #else + // Threshold setting based on GNU's OpenMP + if ( n_elem < 4500 ) + *nt_ideal = 1; + else if ( n_elem < 15400 ) + *nt_ideal = 4; + else if ( n_elem < 285000 ) + *nt_ideal = 8; + else if ( n_elem < 604000 ) + *nt_ideal = 16; + else if ( n_elem < 2780000 ) + *nt_ideal = 32; + else if ( n_elem < 10500000 ) + *nt_ideal = 64; + else + // For sizes in this range, AOCL dynamic does not make any change + *nt_ideal = -1; + #endif + + break; + + case BLIS_ARCH_ZEN4: + case BLIS_ARCH_ZEN: + case BLIS_ARCH_ZEN2: + case BLIS_ARCH_ZEN3: + + if ( n_elem < 4000 ) + *nt_ideal = 1; + else if ( n_elem < 17000 ) + *nt_ideal = 4; + else if ( n_elem < 136000 ) + *nt_ideal = 8; + else if ( n_elem < 365000 ) + *nt_ideal = 16; + else if ( n_elem < 2950000 ) + *nt_ideal = 32; + else + // For sizes in this range, AOCL dynamic does not make any change + *nt_ideal = -1; + + break; + + default: + /* + Without this default condition, compiler will throw + a warning saying other conditions are not handled + */ + + /* + For other architectures, AOCL dynamic does not make any change + */ + *nt_ideal = -1; + } +} + +/* + Functionality: + -------------- + This function decides the AOCL dynamic logic for L1 znormfv API based on the + architecture ID and size of the input variable. + + Function signature + ------------------- + + This function takes the following input: + + * 'arch_id' - Architecture ID of the system (copy of BLIS global arch id) + * 'n_elem' - Number of elements in the vector + * 'nt_ideal' - Ideal number of threads + + Exception + ---------- + + 1. For non-Zen architectures, return -1. The expectation is that this is handled + in the higher layer +*/ +void aocl_znormfv_dynamic + ( + arch_t arch_id, + dim_t n_elem, + dim_t* nt_ideal + ) +{ + /* + Pick the AOCL dynamic logic based on the + architecture ID + */ + switch ( arch_id ) + { + case BLIS_ARCH_ZEN5: + case BLIS_ARCH_ZEN4: + case BLIS_ARCH_ZEN: + case BLIS_ARCH_ZEN2: + case BLIS_ARCH_ZEN3: + + if ( n_elem < 2000 ) + *nt_ideal = 1; + else if ( n_elem < 6500 ) + *nt_ideal = 4; + else if ( n_elem < 71000 ) + *nt_ideal = 8; + else if ( n_elem < 200000 ) + *nt_ideal = 16; + else if ( n_elem < 1530000 ) + *nt_ideal = 32; + else + // For sizes in this range, AOCL dynamic does not make any change + *nt_ideal = -1; + + break; + + default: + /* + Without this default condition, compiler will throw + a warning saying other conditions are not handled + */ + + /* + For other architectures, AOCL dynamic does not make any change + */ + *nt_ideal = -1; + } +} + +static void aocl_daxpyf_dynamic + ( + arch_t arch_id, + dim_t n_elem, + dim_t* nt_ideal + ) +{ + + // Pick the AOCL dynamic logic based on the + // architecture ID + + switch (arch_id) + { + case BLIS_ARCH_ZEN5: + case BLIS_ARCH_ZEN4: + case BLIS_ARCH_ZEN3: + case BLIS_ARCH_ZEN2: + case BLIS_ARCH_ZEN: + + if ( n_elem <= 128 ) + *nt_ideal = 1; + // these nt_ideal sizes are tuned for trsv only, + // when axpyf kernels are enabled for gemv, these might need + // to be re tuned + + // else if ( n_elem <= 224) + // *nt_ideal = 2; + // else if ( n_elem <= 860) + // *nt_ideal = 4; + else + *nt_ideal = 8; + // axpyf does not scale with more than 8 threads + + break; + + default: + /* + Without this default condition, compiler will throw + a warning saying other conditions are not handled + */ + + /* + For other architectures, AOCL dynamic does not make any change + */ + *nt_ideal = -1; + } +} + +#endif // AOCL_DYNAMIC + +/* + Functionality: + -------------- + + This function does the following: + 1. Reads the number of threads requested by the user from the rntm variable + 2. Acts as the gateway to the AOCL dynamic logic if AOCL dynamic is enabled + and alters the count of the number of threads accordingly + + Function signature + ------------------- + + This function takes the following input: + + * 'ker_id' - ID of kernel invoking this function + * 'datatype_a' - Datatype 1 of kernel + * 'datatype_b' - Datatype 2 of kernel + * 'arch_id' - Architecture ID of the system (copy of BLIS global arch id) + * 'n_elem' - Number of elements in the vector + * 'nt_ideal' - Ideal number of threads + + Exception + ---------- + + None +*/ +void bli_nthreads_l1 + ( + l1vkr_t ker_id, + num_t data_type_a, + num_t data_type_b, + arch_t arch_id, + dim_t n_elem, + dim_t* nt_ideal + ) +{ +#ifdef AOCL_DYNAMIC + /* + This code sections dispatches the AOCL dynamic logic kernel for + L1 APIs based on the kernel ID and the data type. + */ + // Function pointer to AOCL Dynamic logic kernel + void (*aocl_dynamic_func_l1)(arch_t, dim_t, dim_t* ) = NULL; + + // Pick the aocl dynamic thread decision kernel based on the kernel ID + switch (ker_id) + { + case BLIS_SCALV_KER: + + /* + When input data types do not match the call is from mixed precision + */ + if (data_type_a != data_type_b) + { + // Function for ZDSCALV + aocl_dynamic_func_l1 = aocl_zdscalv_dynamic; + } + else + { + // Function for DSCALV + aocl_dynamic_func_l1 = aocl_dscalv_dynamic; + } + + break; + + case BLIS_AXPYV_KER: + + if ( data_type_a == BLIS_DOUBLE ) + { + // Function for DAXPYV + aocl_dynamic_func_l1 = aocl_daxpyv_dynamic; + } + else if ( data_type_a == BLIS_DCOMPLEX ) + { + // Function for ZAXPYV + aocl_dynamic_func_l1 = aocl_zaxpyv_dynamic; + } + break; + + case BLIS_DOTV_KER: + + if ( data_type_a == BLIS_DOUBLE ) + { + // Function for DDOTV + aocl_dynamic_func_l1 = aocl_ddotv_dynamic; + } + else if ( data_type_a == BLIS_DCOMPLEX ) + { + // Function for ZDOTV + aocl_dynamic_func_l1 = aocl_zdotv_dynamic; + } + + break; + + case BLIS_COPYV_KER: + + if ( data_type_a == BLIS_DOUBLE) + { + // Function for DCOPYV + aocl_dynamic_func_l1 = aocl_dcopyv_dynamic; + } + else if ( data_type_a == BLIS_DCOMPLEX ) + { + // Function for ZCOPYV + aocl_dynamic_func_l1 = aocl_zcopyv_dynamic; + } + break; + + default: + /* + For kernels that do no have AOCL dynamic logic, + use the number of threads requested by the user. + */ + *nt_ideal = -1; + } + + /* + For APIs that do not have AOCL dynamic + logic, aocl_dynamic_func_l1 will be NULL. + */ + if( aocl_dynamic_func_l1 != NULL) + { + // Call the AOCL dynamic logic kernel + aocl_dynamic_func_l1 + ( + arch_id, + n_elem, + nt_ideal + ); + + if (*nt_ideal == 1) + { + // Return early when the number of threads is 1 + return; + } + } + +#endif + // Initialized to avoid compiler warning + rntm_t rntm_local; + + // Initialize a local runtime with global settings. + bli_rntm_init_from_global(&rntm_local); + + // Query the total number of threads from the rntm_t object. + dim_t nt_rntm = bli_rntm_num_threads(&rntm_local); + + if (nt_rntm <= 0) + { + // nt is less than one if BLIS manual setting of parallelism + // has been used. Parallelism here will be product of values. + nt_rntm = bli_rntm_calc_num_threads(&rntm_local); + } + +#ifdef AOCL_DYNAMIC + + // Calculate the actual number of threads that will be spawned + if (*nt_ideal != -1) + { + // The if block is executed for all Zen architectures + *nt_ideal = bli_min(nt_rntm, *nt_ideal); + } + else + { + /* + For non-Zen architectures and very large sizes, + spawn the actual number of threads requested + */ + *nt_ideal = nt_rntm; + } + + /* + When the number of element to be processed is less + than the number of threads spawn n_elem number of threads. + */ + if (n_elem < *nt_ideal) + { + *nt_ideal = n_elem; + } +#else + + // Calculate the actual number of threads that will be spawned + *nt_ideal = nt_rntm; + +#endif +} + +/* + Functionality: + -------------- + + This function does the following: + 1. Reads the number of threads requested by the user from the rntm variable + 2. Acts as the gateway to the AOCL dynamic logic if AOCL dynamic is enabled + and alters the count of the number of threads accordingly + + Function signature + ------------------- + + This function takes the following input: + + * 'ker_id' - ID of kernel invoking this function + * 'datatype_a' - Datatype 1 of kernel + * 'datatype_b' - Datatype 2 of kernel + * 'arch_id' - Architecture ID of the system (copy of BLIS global arch id) + * 'n_elem' - Number of elements in the vector + * 'nt_ideal' - Ideal number of threads + + Exception + ---------- + + None +*/ +void bli_nthreads_l1f + ( + l1fkr_t ker_id, + num_t data_type_a, + num_t data_type_b, + arch_t arch_id, + dim_t n_elem, + dim_t* nt_ideal + ) +{ +#ifdef AOCL_DYNAMIC + /* + This code sections dispatches the AOCL dynamic logic kernel for + L1 APIs based on the kernel ID and the data type. + */ + // Function pointer to AOCL Dynamic logic kernel + void (*aocl_dynamic_func_l1f)(arch_t, dim_t, dim_t* ) = NULL; + + // Pick the aocl dynamic thread decision kernel based on the kernel ID + switch (ker_id) + { + case BLIS_AXPYF_KER: + + if ( data_type_a == BLIS_DOUBLE ) + { + // Function for DAXPYF + aocl_dynamic_func_l1f = aocl_daxpyf_dynamic; + } + break; + + default: + /* + For kernels that do no have AOCL dynamic logic, + use the number of threads requested by the user. + */ + *nt_ideal = -1; + } + + /* + For APIs that do not have AOCL dynamic + logic, aocl_dynamic_func_l1f will be NULL. + */ + if( aocl_dynamic_func_l1f != NULL) + { + // Call the AOCL dynamic logic kernel + aocl_dynamic_func_l1f ( arch_id, n_elem, diff --git a/frame/base/bli_rntm.h b/frame/base/bli_rntm.h index 5df21f811e..344bac9f3b 100644 --- a/frame/base/bli_rntm.h +++ b/frame/base/bli_rntm.h @@ -6,7 +6,7 @@ Copyright (C) 2014, The University of Texas at Austin Copyright (C) 2016, Hewlett Packard Enterprise Development LP - Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -63,7 +63,7 @@ void bli_rntm_set_ways_from_rntm rntm_t* rntm ); -void bli_rntm_set_ways_from_rntm_sup +BLIS_EXPORT_BLIS void bli_rntm_set_ways_from_rntm_sup ( dim_t m, dim_t n, @@ -83,7 +83,7 @@ dim_t bli_rntm_calc_num_threads_in ); #ifdef AOCL_DYNAMIC -void bli_nthreads_optimum +BLIS_EXPORT_BLIS void bli_nthreads_optimum ( obj_t* a, obj_t* b, @@ -103,6 +103,20 @@ err_t bli_smart_threading_sup ); #endif +void aocl_dnormfv_dynamic + ( + arch_t arch_id, + dim_t n_elem, + dim_t* nt_ideal + ); + +void aocl_znormfv_dynamic + ( + arch_t arch_id, + dim_t n_elem, + dim_t* nt_ideal + ); + void bli_nthreads_l1 ( l1vkr_t ker_id, @@ -113,6 +127,16 @@ void bli_nthreads_l1 dim_t* nt_ideal ); +void bli_nthreads_l1f + ( + l1fkr_t ker_id, + num_t data_type_a, + num_t data_type_b, + arch_t arch_id, + dim_t n_elem, + dim_t* nt_ideal + ); + // Runtime object type (defined in bli_type_defs.h) /* diff --git a/frame/base/bli_sba.h b/frame/base/bli_sba.h index 63e48200c5..74e67f55df 100644 --- a/frame/base/bli_sba.h +++ b/frame/base/bli_sba.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -42,17 +42,17 @@ apool_t* bli_sba_query( void ); void bli_sba_init( void ); void bli_sba_finalize( void ); -array_t* bli_sba_checkout_array +BLIS_EXPORT_BLIS array_t* bli_sba_checkout_array ( const siz_t n_threads ); -void bli_sba_checkin_array +BLIS_EXPORT_BLIS void bli_sba_checkin_array ( array_t* restrict array ); -void bli_sba_rntm_set_pool +BLIS_EXPORT_BLIS void bli_sba_rntm_set_pool ( siz_t index, array_t* restrict array, diff --git a/frame/compat/bla_amax.c b/frame/compat/bla_amax.c index 8036237d71..7302ec5969 100644 --- a/frame/compat/bla_amax.c +++ b/frame/compat/bla_amax.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -62,7 +62,13 @@ f77_int PASTEF772S(i,chx,blasname) \ being returned, which is not what we want. */ \ if ( *n < 1 || *incx <= 0 ) { \ AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, "iamax_: vector empty") \ - return 0; \ + return 0; \ + }\ +\ + /* If n=1, return 1 here to emulate netlib BLAS and avoid touching vector */ \ + if ( *n == 1 ) { \ + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, "iamax_: n=1") \ + return 1; \ }\ \ /* Initialize BLIS. */ \ diff --git a/frame/compat/bla_amax.h b/frame/compat/bla_amax.h index 0a7cee7f2c..4a9a4acee9 100644 --- a/frame/compat/bla_amax.h +++ b/frame/compat/bla_amax.h @@ -5,8 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. - + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/frame/compat/bla_amax_amd.c b/frame/compat/bla_amax_amd.c index 4e7b4fb22f..bf5abf735a 100644 --- a/frame/compat/bla_amax_amd.c +++ b/frame/compat/bla_amax_amd.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -62,7 +62,13 @@ f77_int PASTEF772S(i,chx,blasname) \ being returned, which is not what we want. */ \ if ( *n < 1 || *incx <= 0 ) { \ AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, "iamax_: vector empty") \ - return 0; \ + return 0; \ + }\ +\ + /* If n=1, return 1 here to emulate netlib BLAS and avoid touching vector */ \ + if ( *n == 1 ) { \ + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, "iamax_: n=1") \ + return 1; \ }\ \ /* Initialize BLIS. */ \ @@ -133,6 +139,12 @@ f77_int isamax_blis_impl return 0; } + /* If n=1, return 1 here to emulate netlib BLAS and avoid touching vector */ + if ( *n == 1 ) { + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, "iamax_: n=1"); + return 1; + } + /* Initialize BLIS. */ // bli_init_auto(); @@ -242,6 +254,12 @@ f77_int idamax_blis_impl return 0; } + /* If n=1, return 1 here to emulate netlib BLAS and avoid touching vector */ + if ( *n == 1 ) { + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, "iamax_: n=1"); + return 1; + } + /* When the length of the vector is one it is going to be the element with the maximum absolute value. This early return condition is defined in @@ -301,6 +319,7 @@ f77_int idamax_blis_impl // Pick the kernel based on the architecture ID switch (id) { + case BLIS_ARCH_ZEN5: case BLIS_ARCH_ZEN4: case BLIS_ARCH_ZEN: case BLIS_ARCH_ZEN2: diff --git a/frame/compat/bla_amin.c b/frame/compat/bla_amin.c index ada7a899eb..520b25c34e 100644 --- a/frame/compat/bla_amin.c +++ b/frame/compat/bla_amin.c @@ -4,8 +4,8 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. - + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/frame/compat/bla_asum.c b/frame/compat/bla_asum.c index 1ad70d1944..5ec4d61eea 100644 --- a/frame/compat/bla_asum.c +++ b/frame/compat/bla_asum.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -38,8 +38,8 @@ // // Define BLAS-to-BLIS interfaces. // -#undef GENTFUNCR2 -#define GENTFUNCR2( ftype_x, ftype_r, chx, chr, blasname, blisname ) \ +#undef GENTFUNCR3 +#define GENTFUNCR3( ftype_x, ftype_r, chx, chr, chru, blasname, blisname ) \ \ ftype_r PASTEF772S(chr,chx,blasname) \ ( \ @@ -53,6 +53,14 @@ ftype_r PASTEF772S(chr,chx,blasname) \ ftype_x* x0; \ inc_t incx0; \ ftype_r asum; \ +\ + asum = *PASTEMAC(chru,0); \ +\ + /* Early return scenarios */ \ + if ( *n < 1 || *incx <= 0 ) { \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ + return asum; \ + }\ \ /* Initialize BLIS. */ \ bli_init_auto(); \ @@ -92,5 +100,5 @@ ftype_r PASTEF772(chr,chx,blasname) \ } \ ) -INSERT_GENTFUNCR2_BLAS( asum, asumv ) +INSERT_GENTFUNCR3_BLAS( asum, asumv ) diff --git a/frame/compat/bla_asum.h b/frame/compat/bla_asum.h index b3bc565c7f..b9e1e472f4 100644 --- a/frame/compat/bla_asum.h +++ b/frame/compat/bla_asum.h @@ -5,8 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. - + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/frame/compat/bla_axpby.h b/frame/compat/bla_axpby.h index c8c384d01a..cb95788c3e 100644 --- a/frame/compat/bla_axpby.h +++ b/frame/compat/bla_axpby.h @@ -4,8 +4,8 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. - + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/frame/compat/bla_axpby_amd.c b/frame/compat/bla_axpby_amd.c new file mode 100644 index 0000000000..7e935433d5 --- /dev/null +++ b/frame/compat/bla_axpby_amd.c @@ -0,0 +1,638 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + + +// +// Define BLAS-to-BLIS interfaces. +// +#undef GENTFUNC +#define GENTFUNC( ftype, ch, blasname, blisname ) \ +\ +void PASTEF77S(ch,blasname) \ + ( \ + const f77_int* n, \ + const ftype* alpha, \ + const ftype* x, const f77_int* incx, \ + const ftype* beta, \ + ftype* y, const f77_int* incy \ + ) \ +{ \ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) \ + AOCL_DTL_LOG_AXPBY_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(ch), *n, (void*)alpha, *incx, (void*)beta, *incy) \ + dim_t n0; \ + ftype* x0; \ + ftype* y0; \ + inc_t incx0; \ + inc_t incy0; \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + /* Convert/typecast negative values of n to zero. */ \ + bli_convert_blas_dim1( *n, n0 ); \ +\ + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ \ + bli_convert_blas_incv( n0, (ftype*)x, *incx, x0, incx0 ); \ + bli_convert_blas_incv( n0, (ftype*)y, *incy, y0, incy0 ); \ +\ + /* Call BLIS interface. */ \ + PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \ + ( \ + BLIS_NO_CONJUGATE, \ + n0, \ + (ftype*)alpha, \ + x0, incx0, \ + (ftype*)beta, \ + y0, incy0, \ + NULL, \ + NULL \ + ); \ +\ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +}\ +\ +IF_BLIS_ENABLE_BLAS(\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_int* n, \ + const ftype* alpha, \ + const ftype* x, const f77_int* incx, \ + const ftype* beta, \ + ftype* y, const f77_int* incy \ + ) \ +{ \ + PASTEF77S(ch,blasname) \ + ( n, alpha, x, incx, beta, y, incy ); \ +} \ +) + +void saxpby_blis_impl +( + const f77_int* n, + const float* alpha, + const float* x, const f77_int* incx, + const float* beta, + float* y, const f77_int* incy +) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) + AOCL_DTL_LOG_AXPY_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'S', *n, (float *)alpha, *incx, *incy) + + /* Early exit in case n is 0, or alpha is 0 and beta is 1 */ + if ( ( *n <= 0 ) || + ( PASTEMAC( s, eq0 )( *alpha ) && PASTEMAC( s, eq1 )( *beta ) ) ) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) + return; + } + + dim_t n0; + float *x0; + float *y0; + inc_t incx0; + inc_t incy0; + + /* Initialize BLIS. */ + // bli_init_auto(); + + n0 = ( dim_t )( *n ); + + /* + If the input increments are negative, adjust the pointers so we can + use positive increments instead. + */ + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + + x0 = ( ( float * )x ) + ( n0 - 1 ) * ( -( *incx ) ); + incx0 = ( inc_t )( *incx ); + } + else + { + x0 = ( ( float* )x ); + incx0 = ( inc_t )( *incx ); + } + if ( *incy < 0 ) + { + y0 = ( ( float* )y ) + ( n0 - 1 ) * ( -( *incy ) ); + incy0 = ( inc_t )(*incy); + } + else + { + y0 = ( ( float* )y ); + incy0 = ( inc_t )( *incy ); + } + + cntx_t *cntx = NULL; + + // Query the architecture ID + arch_t id = bli_arch_query_id(); + + /* + Function pointer declaration for the function + that will be used by this API + */ + saxpbyv_ker_ft axpbyv_ker_ptr; // DAXPBYV + + // Pick the kernel based on the architecture ID + switch (id) + { + case BLIS_ARCH_ZEN5: + case BLIS_ARCH_ZEN4: + case BLIS_ARCH_ZEN: + case BLIS_ARCH_ZEN2: + case BLIS_ARCH_ZEN3: + axpbyv_ker_ptr = bli_saxpbyv_zen_int10; + + break; + default: + + // For non-Zen architectures, query the context + cntx = bli_gks_query_cntx(); + + // Query the context for the kernel function pointers for saxpbyv + axpbyv_ker_ptr = bli_cntx_get_l1v_ker_dt(BLIS_FLOAT, BLIS_AXPBYV_KER, cntx); + } + + // Call the function based on the function pointer assigned above + axpbyv_ker_ptr + ( + BLIS_NO_CONJUGATE, + n0, + (float *)alpha, + x0, incx0, + (float *)beta, + y0, incy0, + cntx + ); + + /* Finalize BLIS. */ + // bli_finalize_auto(); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); +} + +#ifdef BLIS_ENABLE_BLAS +void saxpby_ +( + const f77_int* n, + const float* alpha, + const float* x, const f77_int* incx, + const float* beta, + float* y, const f77_int* incy +) +{ + saxpby_blis_impl( n, alpha, x, incx, beta, y, incy ) ; +} +#endif + +//------------------------------------------------------------------------- + +void daxpby_blis_impl +( + const f77_int* n, + const double* alpha, + const double* x, const f77_int* incx, + const double* beta, + double* y, const f77_int* incy +) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) + AOCL_DTL_LOG_AXPY_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'D', *n, (double *)alpha, *incx, *incy) + + /* Early exit in case n is 0, or alpha is 0 and beta is 1 */ + if ( ( *n <= 0 ) || + ( PASTEMAC( d, eq0 )( *alpha ) && PASTEMAC( d, eq1 )( *beta ) ) ) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) + return; + } + + dim_t n0; + double *x0; + double *y0; + inc_t incx0; + inc_t incy0; + + /* Initialize BLIS. */ + // bli_init_auto(); + + n0 = ( dim_t )( *n ); + + /* + If the input increments are negative, adjust the pointers so we can + use positive increments instead. + */ + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + + x0 = ( ( double * )x ) + ( n0 - 1 ) * ( -( *incx ) ); + incx0 = ( inc_t )( *incx ); + } + else + { + x0 = ( ( double* )x ); + incx0 = ( inc_t )( *incx ); + } + if ( *incy < 0 ) + { + y0 = ( ( double* )y ) + ( n0 - 1 ) * ( -( *incy ) ); + incy0 = ( inc_t )(*incy); + } + else + { + y0 = ( ( double* )y ); + incy0 = ( inc_t )( *incy ); + } + + cntx_t *cntx = NULL; + + // Query the architecture ID + arch_t id = bli_arch_query_id(); + + /* + Function pointer declarations for the function + that will be used by this API + */ + daxpbyv_ker_ft axpbyv_ker_ptr; // DAXPBYV + + // Pick the kernel based on the architecture ID + switch (id) + { + case BLIS_ARCH_ZEN5: + case BLIS_ARCH_ZEN4: +#if defined(BLIS_KERNELS_ZEN4) + axpbyv_ker_ptr = bli_daxpbyv_zen_int_avx512; + + break; +#endif + case BLIS_ARCH_ZEN: + case BLIS_ARCH_ZEN2: + case BLIS_ARCH_ZEN3: + axpbyv_ker_ptr = bli_daxpbyv_zen_int10; + + break; + default: + + // For non-Zen architectures, query the context + cntx = bli_gks_query_cntx(); + + // Query the context for the kernel function pointers for daxpbyv + axpbyv_ker_ptr = bli_cntx_get_l1v_ker_dt(BLIS_DOUBLE, BLIS_AXPBYV_KER, cntx); + } + + // Call the function based on the function pointer assigned above + axpbyv_ker_ptr + ( + BLIS_NO_CONJUGATE, + n0, + (double *)alpha, + x0, incx0, + (double *)beta, + y0, incy0, + cntx + ); + + /* Finalize BLIS. */ + // bli_finalize_auto(); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); +} + +#ifdef BLIS_ENABLE_BLAS +void daxpby_ +( + const f77_int* n, + const double* alpha, + const double* x, const f77_int* incx, + const double* beta, + double* y, const f77_int* incy +) +{ + daxpby_blis_impl( n, alpha, x, incx, beta, y, incy ) ; +} +#endif + +//------------------------------------------------------------------------- + +void caxpby_blis_impl +( + const f77_int* n, + const scomplex* alpha, + const scomplex* x, const f77_int* incx, + const scomplex* beta, + scomplex* y, const f77_int* incy +) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) + AOCL_DTL_LOG_AXPY_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'C', *n, (scomplex *)alpha, *incx, *incy) + + /* Early exit in case n is 0, or alpha is 0 and beta is 1 */ + if ( ( *n <= 0 ) || + ( PASTEMAC( c, eq0 )( *alpha ) && PASTEMAC( c, eq1 )( *beta ) ) ) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) + return; + } + + dim_t n0; + scomplex *x0; + scomplex *y0; + inc_t incx0; + inc_t incy0; + + /* Initialize BLIS. */ + // bli_init_auto(); + + n0 = ( dim_t )( *n ); + + /* + If the input increments are negative, adjust the pointers so we can + use positive increments instead. + */ + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + + x0 = ( ( scomplex * )x ) + ( n0 - 1 ) * ( -( *incx ) ); + incx0 = ( inc_t )( *incx ); + } + else + { + x0 = ( ( scomplex* )x ); + incx0 = ( inc_t )( *incx ); + } + if ( *incy < 0 ) + { + y0 = ( ( scomplex* )y ) + ( n0 - 1 ) * ( -( *incy ) ); + incy0 = ( inc_t )(*incy); + } + else + { + y0 = ( ( scomplex* )y ); + incy0 = ( inc_t )( *incy ); + } + + cntx_t *cntx = NULL; + + // Query the architecture ID + arch_t id = bli_arch_query_id(); + + /* + Function pointer declarations for the function + that will be used by this API + */ + caxpbyv_ker_ft axpbyv_ker_ptr; // caxpbyV + + // Pick the kernel based on the architecture ID + switch (id) + { + case BLIS_ARCH_ZEN5: + case BLIS_ARCH_ZEN4: + case BLIS_ARCH_ZEN: + case BLIS_ARCH_ZEN2: + case BLIS_ARCH_ZEN3: + axpbyv_ker_ptr = bli_caxpbyv_zen_int; + + break; + default: + + // For non-Zen architectures, query the context + cntx = bli_gks_query_cntx(); + + // Query the context for the kernel function pointers for caxpbyv + axpbyv_ker_ptr = bli_cntx_get_l1v_ker_dt(BLIS_SCOMPLEX, BLIS_AXPBYV_KER, cntx); + } + + // Call the function based on the function pointer assigned above + axpbyv_ker_ptr + ( + BLIS_NO_CONJUGATE, + n0, + (scomplex *)alpha, + x0, incx0, + (scomplex *)beta, + y0, incy0, + cntx + ); + + /* Finalize BLIS. */ + // bli_finalize_auto(); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); +} + +#ifdef BLIS_ENABLE_BLAS +void caxpby_ +( + const f77_int* n, + const scomplex* alpha, + const scomplex* x, const f77_int* incx, + const scomplex* beta, + scomplex* y, const f77_int* incy +) +{ + caxpby_blis_impl( n, alpha, x, incx, beta, y, incy ) ; +} +#endif + +//------------------------------------------------------------------------- + +void zaxpby_blis_impl +( + const f77_int* n, + const dcomplex* alpha, + const dcomplex* x, const f77_int* incx, + const dcomplex* beta, + dcomplex* y, const f77_int* incy +) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) + AOCL_DTL_LOG_AXPY_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'Z', *n, (dcomplex *)alpha, *incx, *incy) + + /* Early exit in case n is 0, or alpha is 0 and beta is 1 */ + if ( ( *n <= 0 ) || + ( PASTEMAC( c, eq0 )( *alpha ) && PASTEMAC( c, eq1 )( *beta ) ) ) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) + return; + } + + dim_t n0; + dcomplex *x0; + dcomplex *y0; + inc_t incx0; + inc_t incy0; + + /* Initialize BLIS. */ + // bli_init_auto(); + + n0 = ( dim_t )( *n ); + + /* + If the input increments are negative, adjust the pointers so we can + use positive increments instead. + */ + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + + x0 = ( ( dcomplex * )x ) + ( n0 - 1 ) * ( -( *incx ) ); + incx0 = ( inc_t )( *incx ); + } + else + { + x0 = ( ( dcomplex* )x ); + incx0 = ( inc_t )( *incx ); + } + if ( *incy < 0 ) + { + y0 = ( ( dcomplex* )y ) + ( n0 - 1 ) * ( -( *incy ) ); + incy0 = ( inc_t )(*incy); + } + else + { + y0 = ( ( dcomplex* )y ); + incy0 = ( inc_t )( *incy ); + } + + cntx_t *cntx = NULL; + + // Query the architecture ID + arch_t id = bli_arch_query_id(); + + /* + Function pointer declarations for the function + that will be used by this API + */ + zaxpbyv_ker_ft axpbyv_ker_ptr; // zaxpbyV + + // Pick the kernel based on the architecture ID + switch (id) + { + case BLIS_ARCH_ZEN5: + case BLIS_ARCH_ZEN4: + case BLIS_ARCH_ZEN: + case BLIS_ARCH_ZEN2: + case BLIS_ARCH_ZEN3: + axpbyv_ker_ptr = bli_zaxpbyv_zen_int; + + break; + default: + + // For non-Zen architectures, query the context + cntx = bli_gks_query_cntx(); + + // Query the context for the kernel function pointers for zaxpbyv + axpbyv_ker_ptr = bli_cntx_get_l1v_ker_dt(BLIS_DCOMPLEX, BLIS_AXPBYV_KER, cntx); + } + + // Call the function based on the function pointer assigned above + axpbyv_ker_ptr + ( + BLIS_NO_CONJUGATE, + n0, + (dcomplex *)alpha, + x0, incx0, + (dcomplex *)beta, + y0, incy0, + cntx + ); + + /* Finalize BLIS. */ + // bli_finalize_auto(); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); +} + +#ifdef BLIS_ENABLE_BLAS +void zaxpby_ +( + const f77_int* n, + const dcomplex* alpha, + const dcomplex* x, const f77_int* incx, + const dcomplex* beta, + dcomplex* y, const f77_int* incy +) +{ + zaxpby_blis_impl( n, alpha, x, incx, beta, y, incy ) ; +} +#endif diff --git a/frame/compat/bla_axpy.c b/frame/compat/bla_axpy.c index feffbc4955..98ad7a38b2 100644 --- a/frame/compat/bla_axpy.c +++ b/frame/compat/bla_axpy.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -50,6 +50,18 @@ void PASTEF77S(ch,blasname) \ ftype* y, const f77_int* incy \ ) \ { \ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) \ + AOCL_DTL_LOG_AXPY_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(ch), *n, (void*)alpha, *incx, *incy) \ +\ + /* + BLAS exception: If the vector dimension is zero, or if alpha is zero, return early. + */ \ + if ((*n) <= 0 || PASTEMAC(ch, eq0)(*alpha)) \ + { \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + return; \ + } \ +\ dim_t n0; \ ftype* x0; \ ftype* y0; \ @@ -58,8 +70,7 @@ void PASTEF77S(ch,blasname) \ \ /* Initialize BLIS. */ \ bli_init_auto(); \ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) \ - AOCL_DTL_LOG_AXPY_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(ch), *n, (void*)alpha, *incx, *incy) \ +\ /* Convert/typecast negative values of n to zero. */ \ bli_convert_blas_dim1( *n, n0 ); \ \ diff --git a/frame/compat/bla_axpy.h b/frame/compat/bla_axpy.h index d83ce50ff7..b2db7842bd 100644 --- a/frame/compat/bla_axpy.h +++ b/frame/compat/bla_axpy.h @@ -5,8 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. - + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/frame/compat/bla_axpy_amd.c b/frame/compat/bla_axpy_amd.c index 381cc10e67..a27765ca8a 100644 --- a/frame/compat/bla_axpy_amd.c +++ b/frame/compat/bla_axpy_amd.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -63,7 +63,19 @@ void PASTEF77S(ch,blasname) \ ftype* y, const f77_int* incy \ ) \ { \ - dim_t n0; \ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) \ + AOCL_DTL_LOG_AXPY_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(ch), *n, (void*)alpha, *incx, *incy) \ +\ + /* + BLAS exception: If the vector dimension is zero, or if alpha is zero, return early. + */ \ + if ((*n) <= 0 || PASTEMAC(ch, eq0)(*alpha)) \ + { \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + return; \ + } \ +\ + dim_t n_elem; \ ftype* x0; \ ftype* y0; \ inc_t incx0; \ @@ -71,21 +83,20 @@ void PASTEF77S(ch,blasname) \ \ /* Initialize BLIS. */ \ bli_init_auto(); \ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) \ - AOCL_DTL_LOG_AXPY_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(ch), *n, (void*)alpha, *incx, *incy) \ +\ /* Convert/typecast negative values of n to zero. */ \ - bli_convert_blas_dim1( *n, n0 ); \ + bli_convert_blas_dim1( *n, n_elem ); \ \ /* If the input increments are negative, adjust the pointers so we can use positive increments instead. */ \ - bli_convert_blas_incv( n0, (ftype*)x, *incx, x0, incx0 ); \ - bli_convert_blas_incv( n0, (ftype*)y, *incy, y0, incy0 ); \ + bli_convert_blas_incv( n_elem, (ftype*)x, *incx, x0, incx0 ); \ + bli_convert_blas_incv( n_elem, (ftype*)y, *incy, y0, incy0 ); \ \ /* Call BLIS interface. */ \ PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \ ( \ BLIS_NO_CONJUGATE, \ - n0, \ + n_elem, \ (ftype*)alpha, \ x0, incx0, \ y0, incy0, \ @@ -130,7 +141,6 @@ void saxpy_blis_impl if ((*n) <= 0 || PASTEMAC(s, eq0)(*alpha)) { AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return; } @@ -201,6 +211,7 @@ void saxpy_blis_impl // Pick the kernel based on the architecture ID switch (id) { + case BLIS_ARCH_ZEN5: case BLIS_ARCH_ZEN4: #if defined(BLIS_KERNELS_ZEN4) axpyv_ker_ptr = bli_saxpyv_zen_int_avx512; @@ -247,10 +258,12 @@ void saxpy_ float* y, const f77_int* incy ) { - saxpy_blis_impl( n, alpha, x, incx, y, incy ) ; + saxpy_blis_impl( n, alpha, x, incx, y, incy ) ; } #endif +//------------------------------------------------------------------------- + void daxpy_blis_impl ( const f77_int* n, @@ -259,14 +272,24 @@ void daxpy_blis_impl double* y, const f77_int* incy ) { + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) + AOCL_DTL_LOG_AXPY_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'D', *n, (double*)alpha, *incx, *incy) + + /* + BLAS exception: If the vector dimension is zero, or if alpha is zero, return early. + */ + if ((*n) <= 0 || PASTEMAC(d, eq0)(*alpha)) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; + } + dim_t n_elem; double* x0; double* y0; inc_t incx0; inc_t incy0; - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) - AOCL_DTL_LOG_AXPY_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'D', *n, (double*)alpha, *incx, *incy) /* Initialize BLIS. */ // bli_init_auto(); @@ -274,13 +297,6 @@ void daxpy_blis_impl if ( *n < 0 ) n_elem = ( dim_t )0; else n_elem = ( dim_t )(*n); - // BLAS exception to return early when n <= 0 or alpha is 0.0 - if(*n <= 0 || bli_deq0(*alpha)) - { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return; - } - /* If the input increments are negative, adjust the pointers so we can use positive increments instead. */ if ( *incx < 0 ) @@ -327,11 +343,11 @@ void daxpy_blis_impl // Pick the kernel based on the architecture ID switch (arch_id_local) { + case BLIS_ARCH_ZEN5: case BLIS_ARCH_ZEN4: #if defined(BLIS_KERNELS_ZEN4) - axpyv_ker_ptr = bli_daxpyv_zen_int_avx512; - - break; + axpyv_ker_ptr = bli_daxpyv_zen_int_avx512; + break; #endif case BLIS_ARCH_ZEN: case BLIS_ARCH_ZEN2: @@ -386,33 +402,44 @@ void daxpy_blis_impl ); AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) - return; + #ifdef BLIS_ENABLE_OPENMP } _Pragma("omp parallel num_threads(nt)") { - dim_t start, length; + dim_t start, end, length; + thrinfo_t thrinfo_vec; - // Get the thread ID - dim_t thread_id = omp_get_thread_num(); + // The block size is the minimum factor, whose multiple will ensure that only + // the vector code section is executed. Furthermore, for double datatype it corresponds + // to one cacheline size. + dim_t block_size = 8; // Get the actual number of threads spawned - dim_t nt_use = omp_get_num_threads(); + thrinfo_vec.n_way = omp_get_num_threads(); + + // Get the thread ID + thrinfo_vec.work_id = omp_get_thread_num(); /* Calculate the compute range for the current thread based on the actual number of threads spawned */ - bli_thread_vector_partition + + bli_thread_range_sub ( + &thrinfo_vec, n_elem, - nt_use, - &start, &length, - thread_id + block_size, + FALSE, + &start, + &end ); + length = end - start; + // Adjust the local pointer for computation double *x_thread_local = x0 + (start * incx0); double *y_thread_local = y0 + (start * incy0); @@ -444,9 +471,12 @@ void daxpy_ double* y, const f77_int* incy ) { - daxpy_blis_impl( n, alpha, x, incx, y, incy ) ; + daxpy_blis_impl( n, alpha, x, incx, y, incy ) ; } #endif + +//------------------------------------------------------------------------- + void caxpy_blis_impl ( const f77_int* n, @@ -455,24 +485,34 @@ void caxpy_blis_impl scomplex* y, const f77_int* incy ) { - dim_t n0; - scomplex* x0; - scomplex* y0; - inc_t incx0; - inc_t incy0; - - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) - AOCL_DTL_LOG_AXPY_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'C', *n, (scomplex*)alpha, *incx, *incy) - - /* Initialize BLIS. */ - // bli_init_auto(); - /* Convert/typecast negative values of n to zero. */ - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - if ( *incx < 0 ) + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) + AOCL_DTL_LOG_AXPY_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'C', *n, (scomplex*)alpha, *incx, *incy) + + /* + BLAS exception: If the vector dimension is zero, or if alpha is zero, return early. + */ + if ((*n) <= 0 || PASTEMAC(c, eq0)(*alpha)) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; + } + + dim_t n_elem; + scomplex* x0; + scomplex* y0; + inc_t incx0; + inc_t incy0; + + /* Initialize BLIS. */ + // bli_init_auto(); + + /* Convert/typecast negative values of n to zero. */ + if ( *n < 0 ) n_elem = ( dim_t )0; + else n_elem = ( dim_t )(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + if ( *incx < 0 ) { /* The semantics of negative stride in BLAS are that the vector operand be traversed in reverse order. (Another way to think @@ -486,58 +526,59 @@ void caxpy_blis_impl BLIS, if this backwards traversal is desired, the caller *must* pass in the address to the (n-1)th (i.e., the bottom-most or right-most) element along with a negative stride. */ - x0 = ((scomplex*)x) + (n0-1)*(-*incx); + x0 = ((scomplex*)x) + (n_elem-1)*(-*incx); incx0 = ( inc_t )(*incx); } - else + else { x0 = ((scomplex*)x); incx0 = ( inc_t )(*incx); } - if ( *incy < 0 ) + if ( *incy < 0 ) { - y0 = ((scomplex*)y) + (n0-1)*(-*incy); + y0 = ((scomplex*)y) + (n_elem-1)*(-*incy); incy0 = ( inc_t )(*incy); } - else + else { y0 = ((scomplex*)y); incy0 = ( inc_t )(*incy); } - // This function is invoked on all architectures including 'generic'. - // Non-AVX2+FMA3 platforms will use the kernels derived from the context. - if (bli_cpuid_is_avx2fma3_supported() == TRUE) - { - bli_caxpyv_zen_int5 - ( - BLIS_NO_CONJUGATE, - n0, - (scomplex*)alpha, - x0, incx0, - y0, incy0, - NULL - ); - - } - else - { - PASTEMAC2(c,axpyv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - n0, - (scomplex*)alpha, - x0, incx0, - y0, incy0, - NULL, - NULL - ); - } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - /* Finalize BLIS. */ - // bli_finalize_auto(); + // This function is invoked on all architectures including 'generic'. + // Non-AVX2+FMA3 platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx2fma3_supported() == TRUE) + { + bli_caxpyv_zen_int5 + ( + BLIS_NO_CONJUGATE, + n_elem, + (scomplex*)alpha, + x0, incx0, + y0, incy0, + NULL + ); + + } + else + { + PASTEMAC2(c,axpyv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n_elem, + (scomplex*)alpha, + x0, incx0, + y0, incy0, + NULL, + NULL + ); + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS. */ + // bli_finalize_auto(); } + #ifdef BLIS_ENABLE_BLAS void caxpy_ ( @@ -547,9 +588,12 @@ void caxpy_ scomplex* y, const f77_int* incy ) { - caxpy_blis_impl( n, alpha, x, incx, y, incy ) ; + caxpy_blis_impl( n, alpha, x, incx, y, incy ) ; } #endif + +//------------------------------------------------------------------------- + void zaxpy_blis_impl ( const f77_int* n, @@ -558,25 +602,36 @@ void zaxpy_blis_impl dcomplex* y, const f77_int* incy ) { - dim_t n0; - dcomplex* x0; - dcomplex* y0; - inc_t incx0; - inc_t incy0; + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) + AOCL_DTL_LOG_AXPY_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'Z', *n, (dcomplex*)alpha, *incx, *incy) - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) - AOCL_DTL_LOG_AXPY_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'Z', *n, (dcomplex*)alpha, *incx, *incy) + /* + BLAS exception: If the vector dimension is zero, or if alpha is zero, return early. + */ + if ((*n) <= 0 || PASTEMAC(z, eq0)(*alpha)) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; + } - /* Initialize BLIS. */ - // bli_init_auto(); + dim_t n_elem; + dcomplex* x0; + dcomplex* y0; + inc_t incx0; + inc_t incy0; - /* Convert/typecast negative values of n to zero. */ - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); + /* Initialize BLIS. */ + // bli_init_auto(); - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - if ( *incx < 0 ) + // Convert/typecast negative values of n to zero. + if ( *n < 0 ) + n_elem = ( dim_t )0; + else + n_elem = ( dim_t )(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + if ( *incx < 0 ) { /* The semantics of negative stride in BLAS are that the vector operand be traversed in reverse order. (Another way to think @@ -590,58 +645,157 @@ void zaxpy_blis_impl BLIS, if this backwards traversal is desired, the caller *must* pass in the address to the (n-1)th (i.e., the bottom-most or right-most) element along with a negative stride. */ - x0 = ((dcomplex*)x) + (n0-1)*(-*incx); + x0 = ( (dcomplex*)x ) + ( n_elem - 1) * ( -*incx ); incx0 = ( inc_t )(*incx); } - else + else { - x0 = ((dcomplex*)x); + x0 = ( (dcomplex*)x ); incx0 = ( inc_t )(*incx); } - if ( *incy < 0 ) + if ( *incy < 0 ) { - y0 = ((dcomplex*)y) + (n0-1)*(-*incy); + y0 = ( (dcomplex*)y ) + ( n_elem - 1 ) * ( -*incy ); incy0 = ( inc_t )(*incy); } - else + else { - y0 = ((dcomplex*)y); + y0 = ( (dcomplex*)y ); incy0 = ( inc_t )(*incy); } - // This function is invoked on all architectures including 'generic'. - // Non-AVX2+FMA3 platforms will use the kernels derived from the context. - if (bli_cpuid_is_avx2fma3_supported() == TRUE) - { - bli_zaxpyv_zen_int5 - ( - BLIS_NO_CONJUGATE, - n0, - (dcomplex*)alpha, - x0, incx0, - y0, incy0, - NULL - ); - - } - else - { - PASTEMAC2(z,axpyv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - n0, - (dcomplex*)alpha, - x0, incx0, - y0, incy0, - NULL, - NULL - ); - } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - /* Finalize BLIS. */ - // bli_finalize_auto(); + // Definition of function pointer + zaxpyv_ker_ft axpyv_ker_ptr; + + cntx_t *cntx = NULL; + + // Query the architecture ID + arch_t arch_id_local = bli_arch_query_id(); + + // Pick the kernel based on the architecture ID + switch (arch_id_local) + { + case BLIS_ARCH_ZEN5: + case BLIS_ARCH_ZEN4: + +#if defined(BLIS_KERNELS_ZEN4) + // AVX512 Kernel + axpyv_ker_ptr = bli_zaxpyv_zen_int_avx512; + break; +#endif + case BLIS_ARCH_ZEN: + case BLIS_ARCH_ZEN2: + case BLIS_ARCH_ZEN3: + + // AVX2 Kernel + axpyv_ker_ptr = bli_zaxpyv_zen_int5; + break; + + default: + + // Query the context + cntx = bli_gks_query_cntx(); + + // Query the function pointer using the context + axpyv_ker_ptr = bli_cntx_get_l1v_ker_dt(BLIS_DCOMPLEX, BLIS_AXPYV_KER, cntx); + } + +#ifdef BLIS_ENABLE_OPENMP + /* + Initializing the number of thread to one + to avoid compiler warnings + */ + dim_t nt = 1; + + /* + For the given problem size and architecture, the function + returns the optimum number of threads with AOCL dynamic enabled + else it returns the number of threads requested by the user. + */ + + bli_nthreads_l1 + ( + BLIS_AXPYV_KER, + BLIS_DCOMPLEX, + BLIS_DCOMPLEX, + arch_id_local, + n_elem, + &nt + ); + + if (nt == 1) + { +#endif + + axpyv_ker_ptr + ( + BLIS_NO_CONJUGATE, + n_elem, + (dcomplex*)alpha, + x0, incx0, + y0, incy0, + cntx + ); + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) + return; + +#ifdef BLIS_ENABLE_OPENMP + } + + _Pragma("omp parallel num_threads(nt)") + { + dim_t start, end, length; + thrinfo_t thread; + + // The factor by which the size should be a multiple during thread partition. The main loop of the kernel can handle 32 elements at a time hence 32 is selected for block_size. + dim_t block_size = 32; + + // Get the thread ID + bli_thrinfo_set_work_id( omp_get_thread_num(), &thread ); + + // Get the actual number of threads spawned + bli_thrinfo_set_n_way( omp_get_num_threads(), &thread ); + + /* + Calculate the compute range for the current thread + based on the actual number of threads spawned + */ + + bli_thread_range_sub + ( + &thread, + n_elem, + block_size, + FALSE, + &start, + &end + ); + + length = end - start; + + // Adjust the local pointer for computation + dcomplex* x_thread_local = x0 + (start * incx0); + dcomplex* y_thread_local = y0 + (start * incy0); + + // Invoke the function based on the kernel function pointer + axpyv_ker_ptr + ( + BLIS_NO_CONJUGATE, + length, + (dcomplex*)alpha, + x_thread_local, incx0, + y_thread_local, incy0, + cntx + ); + } +#endif // BLIS_ENABLE_OPENMP + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS. */ + // bli_finalize_auto(); } + #ifdef BLIS_ENABLE_BLAS void zaxpy_ ( @@ -651,8 +805,6 @@ void zaxpy_ dcomplex* y, const f77_int* incy ) { - zaxpy_blis_impl( n, alpha, x, incx, y, incy ) ; + zaxpy_blis_impl( n, alpha, x, incx, y, incy ) ; } - - #endif diff --git a/frame/compat/bla_copy.c b/frame/compat/bla_copy.c index f250d46919..f23358440b 100644 --- a/frame/compat/bla_copy.c +++ b/frame/compat/bla_copy.c @@ -5,19 +5,19 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/frame/compat/bla_copy.h b/frame/compat/bla_copy.h index 14634096eb..fa1b3448f5 100644 --- a/frame/compat/bla_copy.h +++ b/frame/compat/bla_copy.h @@ -5,8 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. - + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/frame/compat/bla_copy_amd.c b/frame/compat/bla_copy_amd.c index bf45f5f823..92a741dd9d 100644 --- a/frame/compat/bla_copy_amd.c +++ b/frame/compat/bla_copy_amd.c @@ -5,19 +5,19 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -43,11 +43,11 @@ #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ void PASTEF77S(ch,blasname) \ - ( \ - const f77_int* n, \ - const ftype* x, const f77_int* incx, \ - ftype* y, const f77_int* incy \ - ) \ + ( \ + const f77_int* n, \ + const ftype* x, const f77_int* incx, \ + ftype* y, const f77_int* incy \ + ) \ { \ dim_t n0; \ ftype* x0; \ @@ -66,39 +66,41 @@ void PASTEF77S(ch,blasname) \ \ /* If the input increments are negative, adjust the pointers so we can use positive increments instead. */ \ - bli_convert_blas_incv(n0, (ftype*)x, *incx, x0, incx0); \ - bli_convert_blas_incv(n0, (ftype*)y, *incy, y0, incy0); \ - \ - /* Call BLIS interface. */ \ - PASTEMAC2(ch, blisname, BLIS_TAPI_EX_SUF) \ - (\ - BLIS_NO_CONJUGATE, \ - n0, \ - x0, incx0, \ - y0, incy0, \ - NULL, \ - NULL \ - ); \ - \ + bli_convert_blas_incv(n0, (ftype*)x, *incx, x0, incx0); \ + bli_convert_blas_incv(n0, (ftype*)y, *incy, y0, incy0); \ \ - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ + /* Call BLIS interface. */ \ + PASTEMAC2(ch, blisname, BLIS_TAPI_EX_SUF) \ + (\ + BLIS_NO_CONJUGATE, \ + n0, \ + x0, incx0, \ + y0, incy0, \ + NULL, \ + NULL \ + ); \ + \ \ - /* Finalize BLIS. */ \ - bli_finalize_auto(); \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ +\ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ }\ \ IF_BLIS_ENABLE_BLAS(\ void PASTEF77(ch,blasname) \ - ( \ - const f77_int* n, \ - const ftype* x, const f77_int* incx, \ - ftype* y, const f77_int* incy \ - ) \ + ( \ + const f77_int* n, \ + const ftype* x, const f77_int* incx, \ + ftype* y, const f77_int* incy \ + ) \ { \ - PASTEF77S(ch,blasname)( n, x, incx, y, incy ); \ + PASTEF77S(ch,blasname)( n, x, incx, y, incy ); \ } \ ) +// --------------------------------------------------------- + void scopy_blis_impl ( const f77_int* n, @@ -114,7 +116,9 @@ void scopy_blis_impl AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) AOCL_DTL_LOG_COPY_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'S', *n, *incx, *incy) + /* Initialize BLIS. */ + // bli_init_auto(); /* Convert/typecast negative values of n to zero. */ @@ -162,37 +166,50 @@ void scopy_blis_impl incy0 = (inc_t)(*incy); } - // This function is invoked on all architectures including 'generic'. - // Non-AVX2+FMA3 platforms will use the kernels derived from the context. - if (bli_cpuid_is_avx2fma3_supported() == TRUE) - { - /* Call BLIS kernel */ - bli_scopyv_zen_int - ( - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - NULL - ); - } - else + cntx_t *cntx = NULL; + + // Query the architecture ID + arch_t id = bli_arch_query_id(); + + // Function pointer declaration for the function + // that will be used by this API + scopyv_ker_ft copyv_ker_ptr; // SCOPYV + + // Pick the kernel based on the architecture ID + switch (id) { - PASTEMAC2(s, copyv, BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - NULL, - NULL - ); + case BLIS_ARCH_ZEN5: + case BLIS_ARCH_ZEN4: +#if defined(BLIS_KERNELS_ZEN4) + copyv_ker_ptr = bli_scopyv_zen4_asm_avx512; + break; +#endif + case BLIS_ARCH_ZEN: + case BLIS_ARCH_ZEN2: + case BLIS_ARCH_ZEN3: + copyv_ker_ptr = bli_scopyv_zen_int; + break; + default: + // For non-Zen architectures, query the context + cntx = bli_gks_query_cntx(); + // Query the context for the kernel function pointers for scopyv + copyv_ker_ptr = bli_cntx_get_l1v_ker_dt(BLIS_FLOAT, BLIS_COPYV_KER, cntx); } + copyv_ker_ptr + ( + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + cntx + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) /* Finalize BLIS. */ // bli_finalize_auto(); } + #ifdef BLIS_ENABLE_BLAS void scopy_ ( @@ -204,6 +221,9 @@ void scopy_ scopy_blis_impl( n, x, incx, y, incy ); } #endif + +// -------------------------------------------------------------------- + void dcopy_blis_impl ( const f77_int* n, @@ -220,7 +240,7 @@ void dcopy_blis_impl AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); AOCL_DTL_LOG_COPY_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'D', *n, *incx, *incy) /* Initialize BLIS. */ -// bli_init_auto(); + // bli_init_auto(); /* Convert/typecast negative values of n to zero. */ if (*n < 0) @@ -267,48 +287,349 @@ void dcopy_blis_impl incy0 = (inc_t)(*incy); } - // This function is invoked on all architectures including 'generic'. - // Non-AVX2+FMA3 platforms will use the kernels derived from the context. - if (bli_cpuid_is_avx2fma3_supported() == TRUE) + cntx_t *cntx = NULL; + + // Query the architecture ID + arch_t id = bli_arch_query_id(); + + // Function pointer declaration for the function + // that will be used by this API + dcopyv_ker_ft copyv_ker_ptr; // DCOPYV + + // Pick the kernel based on the architecture ID + switch (id) { - /* Call BLIS kernel */ - bli_dcopyv_zen_int + case BLIS_ARCH_ZEN5: + case BLIS_ARCH_ZEN4: +#if defined(BLIS_KERNELS_ZEN4) + // For Zen4 and Zen5, kernel implemented in AVX512 is used + copyv_ker_ptr = bli_dcopyv_zen4_asm_avx512; + break; +#endif + case BLIS_ARCH_ZEN: + case BLIS_ARCH_ZEN2: + case BLIS_ARCH_ZEN3: + // For Zen1, Zen2 and Zen3 architectures, kernel implemented in AVX2 is used. + copyv_ker_ptr = bli_dcopyv_zen_int; + break; + default: + // For non-Zen architectures, query the context + cntx = bli_gks_query_cntx(); + // Query the context for the kernel function pointers for dcopyv + copyv_ker_ptr = bli_cntx_get_l1v_ker_dt(BLIS_DOUBLE, BLIS_COPYV_KER, cntx); + } + +#ifdef BLIS_ENABLE_OPENMP + /* + Initializing the number of thread to one + to avoid compiler warnings + */ + dim_t nt = 1; + + /* + For the given problem size and architecture, the function + returns the optimum number of threads with AOCL dynamic enabled + else it returns the number of threads requested by the user. + */ + bli_nthreads_l1 + ( + BLIS_COPYV_KER, + BLIS_DOUBLE, + BLIS_DOUBLE, + id, + n0, + &nt + ); + + /* + If the number of optimum threads is 1, the OpenMP overhead + is avoided by calling the function directly + */ + if (nt == 1) + { +#endif + + copyv_ker_ptr + ( + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + cntx + ); + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) + return; + +#ifdef BLIS_ENABLE_OPENMP + } + + _Pragma("omp parallel num_threads(nt)") + { + dim_t start, end, length; + thrinfo_t thread; + + // The factor by which the size should be a multiple during thread partition. + // The main loop of the kernel can handle 32 elements at a time hence 32 is selected for block_size. + dim_t block_size = 32; + + // Get the thread ID + bli_thrinfo_set_work_id( omp_get_thread_num(), &thread ); + + // Get the actual number of threads spawned + bli_thrinfo_set_n_way( omp_get_num_threads(), &thread ); + + /* + Calculate the compute range for the current thread + based on the actual number of threads spawned + */ + + bli_thread_range_sub ( - BLIS_NO_CONJUGATE, + &thread, n0, - x0, incx0, - y0, incy0, - NULL + block_size, + FALSE, + &start, + &end ); - } - else - { - PASTEMAC2(d, copyv, BLIS_TAPI_EX_SUF) + + length = end - start; + + // Adjust the local pointer for computation + double *x_thread_local = x0 + (start * incx0); + double *y_thread_local = y0 + (start * incy0); + + // Invoke the function based on the kernel function pointer + copyv_ker_ptr ( BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - NULL, - NULL + length, + x_thread_local, incx0, + y_thread_local, incy0, + cntx ); } +#endif // BLIS_ENABLE_OPENMP AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) /* Finalize BLIS. */ // bli_finalize_auto(); + } #ifdef BLIS_ENABLE_BLAS + void dcopy_ ( const f77_int* n, const double* x, const f77_int* incx, double* y, const f77_int* incy ) + { dcopy_blis_impl( n, x, incx, y, incy ); } #endif -INSERT_GENTFUNC_BLAS_CZ(copy, copyv) + +// --------------------------------------------------------------- + +void zcopy_blis_impl +( + const f77_int* n, + const dcomplex* x, const f77_int* incx, + dcomplex* y, const f77_int* incy +) +{ + dim_t n0; + dcomplex* x0; + dcomplex* y0; + inc_t incx0; + inc_t incy0; + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) + AOCL_DTL_LOG_COPY_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'Z', *n, *incx, *incy) + + /* Initialize BLIS. */ + +// bli_init_auto(); + + /* Convert/typecast negative values of n to zero. */ + if (*n < 0) + n0 = (dim_t)0; + else + n0 = (dim_t)(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + if (*incx < 0) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + + x0 = (dcomplex*)((x)+(n0 - 1)*(-*incx)); + incx0 = (inc_t)(*incx); + + } + else + { + x0 = (dcomplex*)(x); + incx0 = (inc_t)(*incx); + } + + if (*incy < 0) + { + y0 = (y)+(n0 - 1)*(-*incy); + incy0 = (inc_t)(*incy); + + } + else + { + y0 = (y); + incy0 = (inc_t)(*incy); + } + + cntx_t *cntx = NULL; + + // Query the architecture ID + arch_t id = bli_arch_query_id(); + + // Function pointer declaration for the function + // that will be used by this API + zcopyv_ker_ft copyv_ker_ptr; // ZCOPYV + + // Pick the kernel based on the architecture ID + switch (id) + { + case BLIS_ARCH_ZEN5: + case BLIS_ARCH_ZEN4: +#if defined(BLIS_KERNELS_ZEN4) + // For Zen4 and Zen5 architecture, kernel implemented in AVX512 is used + copyv_ker_ptr = bli_zcopyv_zen4_asm_avx512; + break; +#endif + case BLIS_ARCH_ZEN: + case BLIS_ARCH_ZEN2: + case BLIS_ARCH_ZEN3: + // For Zen1, Zen2 and Zen3 architectures, kernel implemented in AVX2 is used. + copyv_ker_ptr = bli_zcopyv_zen_int; + break; + default: + // For non-Zen architectures, query the context + cntx = bli_gks_query_cntx(); + // Query the context for the kernel function pointers for zcopyv + copyv_ker_ptr = bli_cntx_get_l1v_ker_dt(BLIS_DCOMPLEX, BLIS_COPYV_KER, cntx); + } + +#ifdef BLIS_ENABLE_OPENMP + /* + Initializing the number of thread to one + to avoid compiler warnings + */ + dim_t nt = 1; + + /* + For the given problem size and architecture, the function + returns the optimum number of threads with AOCL dynamic enabled + else it returns the number of threads requested by the user. + */ + bli_nthreads_l1 + ( + BLIS_COPYV_KER, + BLIS_DCOMPLEX, + BLIS_DCOMPLEX, + id, + n0, + &nt + ); + + /* + If the number of optimum threads is 1, the OpenMP overhead + is avoided by calling the function directly + */ + if (nt == 1) + { +#endif + + copyv_ker_ptr + ( + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + cntx + ); + +#ifdef BLIS_ENABLE_OPENMP + } + + else + { + _Pragma("omp parallel num_threads(nt)") + { + dim_t start, length; + + // Get the thread ID + dim_t thread_id = omp_get_thread_num(); + + // Get the actual number of threads spawned + dim_t nt_use = omp_get_num_threads(); + /* + Calculate the compute range for the current thread + based on the actual number of threads spawned + */ + bli_thread_vector_partition + ( + n0, + nt_use, + &start, &length, + thread_id + ); + + // Adjust the local pointer for computation + dcomplex *x_thread_local = x0 + (start * incx0); + dcomplex *y_thread_local = y0 + (start * incy0); + + // Invoke the function based on the kernel function pointer + copyv_ker_ptr + ( + BLIS_NO_CONJUGATE, + length, + x_thread_local, incx0, + y_thread_local, incy0, + cntx + ); + } + } + +#endif + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) + /* Finalize BLIS. */ +// bli_finalize_auto(); +} + +#ifdef BLIS_ENABLE_BLAS +void zcopy_ +( + const f77_int* n, + const dcomplex* x, const f77_int* incx, + dcomplex* y, const f77_int* incy +) +{ + zcopy_blis_impl( n, x, incx, y, incy ); +} +#endif + +INSERT_GENTFUNC_BLAS_C(copy, copyv) diff --git a/frame/compat/bla_dot.c b/frame/compat/bla_dot.c index 76c2cdf48d..7c1f125f28 100644 --- a/frame/compat/bla_dot.c +++ b/frame/compat/bla_dot.c @@ -50,7 +50,7 @@ ftype PASTEF772S(ch,blasname,chc) \ ) \ { \ AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); \ - AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(ch), *n, *incx, *incy); \ + AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(ch), *MKSTR(blis_conjx), *n, *incx, *incy); \ dim_t n0; \ ftype* x0; \ ftype* y0; \ @@ -119,7 +119,7 @@ void PASTEF772S(ch,blasname,chc) \ ) \ { \ AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); \ - AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(ch), *n, *incx, *incy); \ + AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(ch), *MKSTR(blis_conjx), *n, *incx, *incy); \ dim_t n0; \ ftype* x0; \ ftype* y0; \ @@ -229,7 +229,7 @@ double PASTEF77S(d,sdot) dim_t i; AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); - AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'D', *n, *incx, *incy); + AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'D', 'N', *n, *incx, *incy); /* Initialization of BLIS is not required. */ /* Convert/typecast negative values of n to zero. */ diff --git a/frame/compat/bla_dot.h b/frame/compat/bla_dot.h index c06dd69334..7fc599df6b 100644 --- a/frame/compat/bla_dot.h +++ b/frame/compat/bla_dot.h @@ -5,8 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. - + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/frame/compat/bla_dot_amd.c b/frame/compat/bla_dot_amd.c index 4e7e5ca907..92d773410a 100644 --- a/frame/compat/bla_dot_amd.c +++ b/frame/compat/bla_dot_amd.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -61,7 +61,7 @@ ftype PASTEF772S(ch,blasname,chc) \ ) \ { \ AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); \ - AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(ch), *n, *incx, *incy); \ + AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(ch), *MKSTR(blis_conjx), *n, *incx, *incy); \ dim_t n0; \ ftype* x0; \ ftype* y0; \ @@ -120,7 +120,7 @@ float sdot_blis_impl ) { AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); - AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'S', *n, *incx, *incy); + AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'S', 'N', *n, *incx, *incy); dim_t n0; float* x0; float* y0; @@ -198,6 +198,7 @@ float sdot_blis_impl // Pick the kernel based on the architecture ID switch (arch_id) { + case BLIS_ARCH_ZEN5: case BLIS_ARCH_ZEN4: #if defined(BLIS_KERNELS_ZEN4) @@ -258,7 +259,7 @@ double ddot_blis_impl ) { AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); - AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'D', *n, *incx, *incy); + AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'D', 'N', *n, *incx, *incy); dim_t n_elem; double* x0; double* y0; @@ -328,6 +329,7 @@ double ddot_blis_impl // Pick the kernel based on the architecture ID switch (arch_id_local) { + case BLIS_ARCH_ZEN5: case BLIS_ARCH_ZEN4: #if defined(BLIS_KERNELS_ZEN4) @@ -468,8 +470,19 @@ double ddot_blis_impl } else { - nt = 1; - rho_temp = ρ + dotv_ker_ptr + ( + BLIS_NO_CONJUGATE, + BLIS_NO_CONJUGATE, + n_elem, + x0, incx0, + y0, incy0, + &rho, + cntx + ); + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) + return rho; } _Pragma("omp parallel num_threads(nt)") @@ -553,7 +566,7 @@ scomplex cdotu_blis_impl ) { AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); - AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'C', *n, *incx, *incy); + AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'C', 'N', *n, *incx, *incy); dim_t n0; scomplex* x0; scomplex* y0; @@ -562,11 +575,11 @@ scomplex cdotu_blis_impl scomplex rho; /* Initialize BLIS. */ -// bli_init_auto(); + // bli_init_auto(); /* Convert/typecast negative values of n to zero. */ if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); + else n0 = ( dim_t )(*n); /* If the input increments are negative, adjust the pointers so we can use positive increments instead. */ @@ -641,7 +654,7 @@ scomplex cdotu_blis_impl } /* Finalize BLIS. */ -// bli_finalize_auto(); + // bli_finalize_auto(); AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); return rho; } @@ -670,15 +683,17 @@ dcomplex zdotu_blis_impl inc_t incy0; dcomplex rho; + PASTEMAC(z,set0s)( rho ); // Initializing rho to 0. + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); - AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'Z', *n, *incx, *incy); + AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'Z', 'N', *n, *incx, *incy); /* Initialize BLIS. */ -// bli_init_auto(); + // bli_init_auto(); /* Convert/typecast negative values of n to zero. */ if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); + else n0 = ( dim_t )(*n); /* If the input increments are negative, adjust the pointers so we can use positive increments instead. */ @@ -720,40 +735,221 @@ dcomplex zdotu_blis_impl incy0 = ( inc_t )(*incy); } - // This function is invoked on all architectures including 'generic'. - // Non-AVX2+FMA3 platforms will use the kernels derived from the context. - if (bli_cpuid_is_avx2fma3_supported() == TRUE) + cntx_t *cntx = NULL; + + // Query the architecture ID + arch_t arch_id_local = bli_arch_query_id(); + zdotv_ker_ft zdotv_ker_ptr; + + switch ( arch_id_local ) { - /* Call BLIS kernel. */ - bli_zdotv_zen_int5 + case BLIS_ARCH_ZEN5: + case BLIS_ARCH_ZEN4: +#if defined(BLIS_KERNELS_ZEN4) + zdotv_ker_ptr = bli_zdotv_zen_int_avx512; + break; +#endif + + case BLIS_ARCH_ZEN3: + case BLIS_ARCH_ZEN2: + case BLIS_ARCH_ZEN: + zdotv_ker_ptr = bli_zdotv_zen_int5; + break; + + default: + // For non-Zen architectures, query the context + cntx = bli_gks_query_cntx(); + + // Query the context for the kernel function pointers for zdotv + zdotv_ker_ptr = bli_cntx_get_l1v_ker_dt(BLIS_DCOMPLEX, BLIS_DOTV_KER, cntx); + break; + } + +#ifdef BLIS_ENABLE_OPENMP + // Initialize number of threads to one. + dim_t nt = 1; + + bli_nthreads_l1 + ( + BLIS_DOTV_KER, + BLIS_DCOMPLEX, + BLIS_DCOMPLEX, + arch_id_local, + n0, + &nt + ); + + /* + If the number of optimum threads is 1, the OpenMP overhead + is avoided by calling the function directly + */ + if (nt == 1) + { +#endif + zdotv_ker_ptr ( - BLIS_NO_CONJUGATE, - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - &rho, - NULL + BLIS_NO_CONJUGATE, + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + &rho, + cntx ); + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) + + return rho; +#ifdef BLIS_ENABLE_OPENMP + } + + /* + Here we know that more than one thread needs to be spawned. + + In such a case, each thread will need its own rho value to + do the accumulation. These temporary rho's will be accumulated + in the end. + */ + rntm_t rntm_l; + bli_rntm_init_from_global( &rntm_l ); + + dcomplex *rho_temp = NULL; + + /* + Initialize mem pool buffer to NULL and size to 0 + "buf" and "size" fields are assigned once memory + is allocated from the pool in bli_pba_acquire_m(). + This will ensure bli_mem_is_alloc() will be passed on + an allocated memory if created or a NULL . + */ + mem_t mem_buf_rho; + mem_buf_rho.pblk.buf = NULL; + mem_buf_rho.pblk.block_size = 0; + mem_buf_rho.buf_type = 0; + mem_buf_rho.size = 0; + mem_buf_rho.pool = NULL; + + /* + In order to get the buffer from pool via rntm access to + memory broker is needed.Following are initializations + for rntm. + */ + bli_rntm_set_num_threads_only(1, &rntm_l); + bli_pba_rntm_set_pba(&rntm_l); + + // Calculate the size required for rho buffer. + size_t buffer_size = nt * sizeof(dcomplex); + +#ifdef BLIS_ENABLE_MEM_TRACING + printf("bli_zdotu(): get mem pool block\n"); +#endif + + /* + Acquire a buffer (nt * size(dcomplex)) from the memory broker + and save the associated mem_t entry to mem_buf_rho. + */ + bli_pba_acquire_m + ( + &rntm_l, + buffer_size, + BLIS_BITVAL_BUFFER_FOR_A_BLOCK, + &mem_buf_rho + ); + + /* Continue if rho buffer memory is allocated*/ + if ( bli_mem_is_alloc( &mem_buf_rho ) ) + { + rho_temp = bli_mem_buffer( &mem_buf_rho ); + + /* + Initializing rho_temp buffer to zeros. + + This is done to handle cases when the + number of threads launched is not equal + to the number of threads requested. In + such cases, the garbage value in the created + buffer will not be overwritten by valid values. + + This will ensure that garbage values will + not get accumulated with the final result. + */ + for ( dim_t i = 0; i < nt; ++i ) + PASTEMAC(z,set0s)( *(rho_temp + i) ); } else { - /* Call BLIS interface. */ - PASTEMAC2(z,dotv,BLIS_TAPI_EX_SUF) + zdotv_ker_ptr ( - BLIS_NO_CONJUGATE, - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - &rho, - NULL, - NULL + BLIS_NO_CONJUGATE, + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + &rho, + cntx ); + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) + return rho; + } + + _Pragma("omp parallel num_threads(nt)") + { + dim_t start, length; + + // Get the thread ID + dim_t thread_id = omp_get_thread_num(); + + // Get the actual number of threads spawned + dim_t nt_use = omp_get_num_threads(); + + /* + Calculate the compute range for the current thread + based on the actual number of threads spawned + */ + bli_thread_vector_partition + ( + n0, + nt_use, + &start, &length, + thread_id + ); + + // Adjust the local pointer for computation + dcomplex *x_thread_local = x0 + (start * incx0); + dcomplex *y_thread_local = y0 + (start * incy0); + + // Invoke the function based on the kernel function pointer + zdotv_ker_ptr + ( + BLIS_NO_CONJUGATE, + BLIS_NO_CONJUGATE, + length, + x_thread_local, incx0, + y_thread_local, incy0, + rho_temp + thread_id, + cntx + ); + } + + /* + Accumulate the values in rho_temp only when mem is allocated. + When the memory cannot be allocated rho_temp will point to + rho + */ + if ( bli_mem_is_alloc( &mem_buf_rho ) ) + { + // Accumulating the nt thread outputs to rho + for ( dim_t i = 0; i < nt; ++i ) + PASTEMAC(z,adds)( *(rho_temp + i), rho ); + + // Releasing the allocated memory if it was allocated + bli_pba_release( &rntm_l, &mem_buf_rho ); } +#endif /* Finalize BLIS. */ -// bli_finalize_auto(); + // bli_finalize_auto(); AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); @@ -785,14 +981,14 @@ scomplex cdotc_blis_impl scomplex rho; AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); - AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'C', *n, *incx, *incy); + AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'C', 'C', *n, *incx, *incy); /* Initialize BLIS. */ -// bli_init_auto(); + // bli_init_auto(); /* Convert/typecast negative values of n to zero. */ if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); + else n0 = ( dim_t )(*n); /* If the input increments are negative, adjust the pointers so we can use positive increments instead. */ @@ -867,7 +1063,7 @@ scomplex cdotc_blis_impl } /* Finalize BLIS. */ -// bli_finalize_auto(); + // bli_finalize_auto(); AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); return rho; @@ -891,7 +1087,7 @@ dcomplex zdotc_blis_impl ) { AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); - AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'Z', *n, *incx, *incy); + AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'Z', 'C', *n, *incx, *incy); dim_t n0; dcomplex* x0; dcomplex* y0; @@ -899,12 +1095,14 @@ dcomplex zdotc_blis_impl inc_t incy0; dcomplex rho; + PASTEMAC(z,set0s)( rho ); // Initializing rho to 0. + /* Initialize BLIS. */ -// bli_init_auto(); + // bli_init_auto(); /* Convert/typecast negative values of n to zero. */ if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); + else n0 = ( dim_t )(*n); /* If the input increments are negative, adjust the pointers so we can use positive increments instead. */ @@ -946,40 +1144,223 @@ dcomplex zdotc_blis_impl incy0 = ( inc_t )(*incy); } - // This function is invoked on all architectures including 'generic'. - // Non-AVX2+FMA3 platforms will use the kernels derived from the context. - if (bli_cpuid_is_avx2fma3_supported() == TRUE) + cntx_t *cntx = NULL; + + // Query the architecture ID + arch_t arch_id_local = bli_arch_query_id(); + zdotv_ker_ft zdotv_ker_ptr; + + switch ( arch_id_local ) { - /* Call BLIS kernel. */ - bli_zdotv_zen_int5 + case BLIS_ARCH_ZEN5: + case BLIS_ARCH_ZEN4: +#if defined(BLIS_KERNELS_ZEN4) + // Currently only the AVX512 intrinsic kernel is enabled. + zdotv_ker_ptr = bli_zdotv_zen_int_avx512; + // zdotv_ker_ptr = bli_zdotv_zen4_asm_avx512; + break; +#endif + + case BLIS_ARCH_ZEN3: + case BLIS_ARCH_ZEN2: + case BLIS_ARCH_ZEN: + zdotv_ker_ptr = bli_zdotv_zen_int5; + break; + + default: + // For non-Zen architectures, query the context + cntx = bli_gks_query_cntx(); + + // Query the context for the kernel function pointers for zdotv + zdotv_ker_ptr = bli_cntx_get_l1v_ker_dt(BLIS_DCOMPLEX, BLIS_DOTV_KER, cntx); + break; + } + +#ifdef BLIS_ENABLE_OPENMP + // Initialize number of threads to one. + dim_t nt = 1; + + bli_nthreads_l1 + ( + BLIS_DOTV_KER, + BLIS_DCOMPLEX, + BLIS_DCOMPLEX, + arch_id_local, + n0, + &nt + ); + + /* + If the number of optimum threads is 1, the OpenMP overhead + is avoided by calling the function directly + */ + if (nt == 1) + { +#endif + zdotv_ker_ptr ( - BLIS_CONJUGATE, - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - &rho, - NULL + BLIS_CONJUGATE, + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + &rho, + cntx ); + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) + + return rho; +#ifdef BLIS_ENABLE_OPENMP + } + + /* + Here we know that more than one thread needs to be spawned. + + In such a case, each thread will need its own rho value to + do the accumulation. These temporary rho's will be accumulated + in the end. + */ + rntm_t rntm_l; + bli_rntm_init_from_global( &rntm_l ); + + dcomplex *rho_temp = NULL; + + /* + Initialize mem pool buffer to NULL and size to 0 + "buf" and "size" fields are assigned once memory + is allocated from the pool in bli_pba_acquire_m(). + This will ensure bli_mem_is_alloc() will be passed on + an allocated memory if created or a NULL . + */ + mem_t mem_buf_rho; + mem_buf_rho.pblk.buf = NULL; + mem_buf_rho.pblk.block_size = 0; + mem_buf_rho.buf_type = 0; + mem_buf_rho.size = 0; + mem_buf_rho.pool = NULL; + + /* + In order to get the buffer from pool via rntm access to + memory broker is needed.Following are initializations + for rntm. + */ + bli_rntm_set_num_threads_only(1, &rntm_l); + bli_pba_rntm_set_pba(&rntm_l); + + // Calculate the size required for rho buffer. + size_t buffer_size = nt * sizeof(dcomplex); + +#ifdef BLIS_ENABLE_MEM_TRACING + printf("bli_zdotc(): get mem pool block\n"); +#endif + + /* + Acquire a buffer (nt * size(dcomplex)) from the memory broker + and save the associated mem_t entry to mem_buf_rho. + */ + bli_pba_acquire_m + ( + &rntm_l, + buffer_size, + BLIS_BITVAL_BUFFER_FOR_A_BLOCK, + &mem_buf_rho + ); + + /* Continue if rho buffer memory is allocated*/ + if ( bli_mem_is_alloc( &mem_buf_rho ) ) + { + rho_temp = bli_mem_buffer( &mem_buf_rho ); + + /* + Initializing rho_temp buffer to zeros. + + This is done to handle cases when the + number of threads launched is not equal + to the number of threads requested. In + such cases, the garbage value in the created + buffer will not be overwritten by valid values. + + This will ensure that garbage values will + not get accumulated with the final result. + */ + for ( dim_t i = 0; i < nt; ++i ) + PASTEMAC(z,set0s)( *(rho_temp + i) ); } else { - /* Call BLIS interface. */ - PASTEMAC2(z,dotv,BLIS_TAPI_EX_SUF) + zdotv_ker_ptr ( - BLIS_CONJUGATE, - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - &rho, - NULL, - NULL + BLIS_CONJUGATE, + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + &rho, + cntx ); + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) + return rho; } + _Pragma("omp parallel num_threads(nt)") + { + dim_t start, length; + + // Get the thread ID + dim_t thread_id = omp_get_thread_num(); + + // Get the actual number of threads spawned + dim_t nt_use = omp_get_num_threads(); + + /* + Calculate the compute range for the current thread + based on the actual number of threads spawned + */ + bli_thread_vector_partition + ( + n0, + nt_use, + &start, &length, + thread_id + ); + + // Adjust the local pointer for computation + dcomplex *x_thread_local = x0 + (start * incx0); + dcomplex *y_thread_local = y0 + (start * incy0); + + // Invoke the function based on the kernel function pointer + zdotv_ker_ptr + ( + BLIS_CONJUGATE, + BLIS_NO_CONJUGATE, + length, + x_thread_local, incx0, + y_thread_local, incy0, + rho_temp + thread_id, + cntx + ); + } + + /* + Accumulate the values in rho_temp only when mem is allocated. + When the memory cannot be allocated rho_temp will point to + rho + */ + if ( bli_mem_is_alloc( &mem_buf_rho ) ) + { + // Accumulating the nt thread outputs to rho + for ( dim_t i = 0; i < nt; ++i ) + PASTEMAC(z,adds)( *(rho_temp + i), rho ); + + // Releasing the allocated memory if it was allocated + bli_pba_release( &rntm_l, &mem_buf_rho ); + } +#endif + /* Finalize BLIS. */ -// bli_finalize_auto(); + // bli_finalize_auto(); AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); @@ -1011,7 +1392,7 @@ void PASTEF772S(ch,blasname,chc) \ ) \ { \ AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); \ - AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(ch), *n, *incx, *incy); \ + AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(ch), *MKSTR(blis_conjx), *n, *incx, *incy); \ dim_t n0; \ ftype* x0; \ ftype* y0; \ @@ -1120,7 +1501,7 @@ double PASTEF77S(d,sdot) dim_t i; AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); - AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'D', *n, *incx, *incy); + AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'D', 'N', *n, *incx, *incy); /* Initialization of BLIS is not required. */ /* Convert/typecast negative values of n to zero. */ diff --git a/frame/compat/bla_gemm.c b/frame/compat/bla_gemm.c index 5bdbc392c4..13f33be400 100644 --- a/frame/compat/bla_gemm.c +++ b/frame/compat/bla_gemm.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -415,10 +415,9 @@ void PASTEF77(ch,blasname) \ ) #endif -#ifdef BLIS_ENABLE_BLAS INSERT_GENTFUNC_BLAS( gemm,gemm ) -void dzgemm_ +void dzgemm_blis_impl ( const f77_char* transa, const f77_char* transb, @@ -539,13 +538,44 @@ void dzgemm_ bli_obj_set_conjtrans( blis_transb, &bo ); // fall back on native path when zgemm is not handled in sup path. - bli_gemmnat(&alphao, &ao, &bo, &betao, &co, NULL, NULL); + //bli_gemmnat(&alphao, &ao, &bo, &betao, &co, NULL, NULL); + + /* Default to using native execution. */ + ind_t im = BLIS_NAT; + + /* Mix of real and complex matrix data types, so assuming + induced methods will not be available */ + + /* Obtain a valid context from the gks using the induced + method id determined above. */ + cntx_t* cntx = bli_gks_query_ind_cntx( im, dt ); + + rntm_t rntm_l; + bli_rntm_init_from_global( &rntm_l ); + /* Invoke the operation's front-end and request the default control tree. */ + PASTEMAC(gemm,_front)( &alphao, &ao, &bo, &betao, &co, cntx, &rntm_l, NULL ); AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(z), *m, *n, *k); AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); /* Finalize BLIS. */ bli_finalize_auto(); }// end of dzgemm_ - +#ifdef BLIS_ENABLE_BLAS +void dzgemm_ + ( + const f77_char* transa, + const f77_char* transb, + const f77_int* m, + const f77_int* n, + const f77_int* k, + const dcomplex* alpha, + const double* a, const f77_int* lda, + const dcomplex* b, const f77_int* ldb, + const dcomplex* beta, + dcomplex* c, const f77_int* ldc + ) +{ + dzgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc ); +} #endif diff --git a/frame/compat/bla_gemm3m.c b/frame/compat/bla_gemm3m.c index 7b476991a5..3cb3ace8b6 100644 --- a/frame/compat/bla_gemm3m.c +++ b/frame/compat/bla_gemm3m.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -58,79 +58,95 @@ void PASTEF77S(ch,blasname) \ ftype* c, const f77_int* ldc \ ) \ { \ - trans_t blis_transa; \ - trans_t blis_transb; \ - dim_t m0, n0, k0; \ - inc_t rs_a, cs_a; \ - inc_t rs_b, cs_b; \ - inc_t rs_c, cs_c; \ -\ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); \ - /* Initialize BLIS. */ \ - bli_init_auto(); \ -\ - /* Perform BLAS parameter checking. */ \ - PASTEBLACHK(blasname) \ - ( \ - MKSTR(ch), \ - MKSTR(blasname), \ - transa, \ - transb, \ - m, \ - n, \ - k, \ - lda, \ - ldb, \ - ldc \ - ); \ -\ - /* Quick return if possible. */ \ - if ( *m == 0 || *n == 0 || (( PASTEMAC(ch,eq0)( *alpha ) || *k == 0) \ - && PASTEMAC(ch,eq1)( *beta ) )) \ - { \ - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ - /* Finalize BLIS. */ \ - bli_finalize_auto(); \ - return; \ - } \ -\ - /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ - bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ - bli_param_map_netlib_to_blis_trans( *transb, &blis_transb ); \ -\ - /* Typecast BLAS integers to BLIS integers. */ \ - bli_convert_blas_dim1( *m, m0 ); \ - bli_convert_blas_dim1( *n, n0 ); \ - bli_convert_blas_dim1( *k, k0 ); \ -\ - /* Set the row and column strides of the matrix operands. */ \ - rs_a = 1; \ - cs_a = *lda; \ - rs_b = 1; \ - cs_b = *ldb; \ - rs_c = 1; \ - cs_c = *ldc; \ -\ - /* Call BLIS interface. */ \ - PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \ - ( \ - blis_transa, \ - blis_transb, \ - m0, \ - n0, \ - k0, \ - (ftype*)alpha, \ - (ftype*)a, rs_a, cs_a, \ - (ftype*)b, rs_b, cs_b, \ - (ftype*)beta, \ - (ftype*)c, rs_c, cs_c, \ - NULL, \ - NULL \ - ); \ -\ - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ - /* Finalize BLIS. */ \ - bli_finalize_auto(); \ + trans_t blis_transa; \ + trans_t blis_transb; \ + dim_t m0, n0, k0; \ + inc_t rs_a, cs_a; \ + inc_t rs_b, cs_b; \ + inc_t rs_c, cs_c; \ +\ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); \ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + /* Perform BLAS parameter checking. */ \ + PASTEBLACHK(blasname) \ + ( \ + MKSTR(ch), \ + MKSTR(blasname), \ + transa, \ + transb, \ + m, \ + n, \ + k, \ + lda, \ + ldb, \ + ldc \ + ); \ +\ + /* Quick return if possible. */ \ + if ( *m == 0 || *n == 0 || (( PASTEMAC(ch,eq0)( *alpha ) || *k == 0) \ + && PASTEMAC(ch,eq1)( *beta ) )) \ + { \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ + return; \ + } \ +\ + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ + bli_param_map_netlib_to_blis_trans( *transb, &blis_transb ); \ +\ + /* Typecast BLAS integers to BLIS integers. */ \ + bli_convert_blas_dim1( *m, m0 ); \ + bli_convert_blas_dim1( *n, n0 ); \ + bli_convert_blas_dim1( *k, k0 ); \ +\ + /* Set the row and column strides of the matrix operands. */ \ + rs_a = 1; \ + cs_a = *lda; \ + rs_b = 1; \ + cs_b = *ldb; \ + rs_c = 1; \ + cs_c = *ldc; \ +\ + /* As a placeholder, invoke 1m since BLIS does no longer contains an + official 3m implementation. Note that we do this by inlining an + abbreviated version of bli_gemm_ex() so that we can bypass + consideration of sup, which doesn't make sense in this context. */ \ + { \ + cntx_t* cntx = bli_gks_query_ind_cntx( BLIS_1M, dt ); \ +\ + rntm_t rntm_l; \ + rntm_t* rntm = &rntm_l; \ + bli_rntm_init_from_global( rntm ); \ +\ + /* Note that we MUST disable sup handling since it could redirect + execution for some problem sizes to a non-3m implementation. */ \ + bli_rntm_disable_l3_sup( rntm ); \ +\ + /* Call BLIS interface. */ \ + PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \ + ( \ + blis_transa, \ + blis_transb, \ + m0, \ + n0, \ + k0, \ + (ftype*)alpha, \ + (ftype*)a, rs_a, cs_a, \ + (ftype*)b, rs_b, cs_b, \ + (ftype*)beta, \ + (ftype*)c, rs_c, cs_c, \ + cntx, \ + rntm \ + ); \ + } \ +\ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ } \ IF_BLIS_ENABLE_BLAS(\ void PASTEF77(ch,blasname) \ @@ -147,7 +163,7 @@ void PASTEF77(ch,blasname) \ ftype* c, const f77_int* ldc \ ) \ { \ - PASTEF77S(ch,blasname) ( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc ); \ + PASTEF77S(ch,blasname) ( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc ); \ } \ ) @@ -170,94 +186,114 @@ void PASTEF77S(ch,blasname) \ ftype* c, const f77_int* ldc \ ) \ { \ - trans_t blis_transa; \ - trans_t blis_transb; \ - dim_t m0, n0, k0; \ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO) \ -\ - /* Initialize BLIS. */ \ - bli_init_auto(); \ -\ - /* Perform BLAS parameter checking. */ \ - PASTEBLACHK(blasname) \ - ( \ - MKSTR(ch), \ - MKSTR(blasname), \ - transa, \ - transb, \ - m, \ - n, \ - k, \ - lda, \ - ldb, \ - ldc \ - ); \ -\ - /* Quick return if possible. */ \ - if ( *m == 0 || *n == 0 || (( PASTEMAC(ch,eq0)( *alpha ) || *k == 0) \ - && PASTEMAC(ch,eq1)( *beta ) )) \ - { \ - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ - /* Finalize BLIS. */ \ - bli_finalize_auto(); \ - return; \ - } \ -\ - /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ - bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ - bli_param_map_netlib_to_blis_trans( *transb, &blis_transb ); \ -\ - /* Typecast BLAS integers to BLIS integers. */ \ - bli_convert_blas_dim1( *m, m0 ); \ - bli_convert_blas_dim1( *n, n0 ); \ - bli_convert_blas_dim1( *k, k0 ); \ -\ - /* Set the row and column strides of the matrix operands. */ \ - const inc_t rs_a = 1; \ - const inc_t cs_a = *lda; \ - const inc_t rs_b = 1; \ - const inc_t cs_b = *ldb; \ - const inc_t rs_c = 1; \ - const inc_t cs_c = *ldc; \ -\ - const num_t dt = PASTEMAC(ch,type); \ -\ - obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ - obj_t ao = BLIS_OBJECT_INITIALIZER; \ - obj_t bo = BLIS_OBJECT_INITIALIZER; \ - obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \ - obj_t co = BLIS_OBJECT_INITIALIZER; \ -\ - dim_t m0_a, n0_a; \ - dim_t m0_b, n0_b; \ -\ - bli_set_dims_with_trans( blis_transa, m0, k0, &m0_a, &n0_a ); \ - bli_set_dims_with_trans( blis_transb, k0, n0, &m0_b, &n0_b ); \ -\ - bli_obj_init_finish_1x1( dt, (ftype*)alpha, &alphao ); \ - bli_obj_init_finish_1x1( dt, (ftype*)beta, &betao ); \ -\ - bli_obj_init_finish( dt, m0_a, n0_a, (ftype*)a, rs_a, cs_a, &ao ); \ - bli_obj_init_finish( dt, m0_b, n0_b, (ftype*)b, rs_b, cs_b, &bo ); \ - bli_obj_init_finish( dt, m0, n0, (ftype*)c, rs_c, cs_c, &co ); \ -\ - bli_obj_set_conjtrans( blis_transa, &ao ); \ - bli_obj_set_conjtrans( blis_transb, &bo ); \ -\ - PASTEMAC(blisname,ind) \ - ( \ - &alphao, \ - &ao, \ - &bo, \ - &betao, \ - &co, \ - NULL, \ - NULL \ - ); \ -\ - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) \ - /* Finalize BLIS. */ \ - bli_finalize_auto(); \ + trans_t blis_transa; \ + trans_t blis_transb; \ + dim_t m0, n0, k0; \ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO) \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + /* Perform BLAS parameter checking. */ \ + PASTEBLACHK(blasname) \ + ( \ + MKSTR(ch), \ + MKSTR(blasname), \ + transa, \ + transb, \ + m, \ + n, \ + k, \ + lda, \ + ldb, \ + ldc \ + ); \ +\ + /* Quick return if possible. */ \ + if ( *m == 0 || *n == 0 || (( PASTEMAC(ch,eq0)( *alpha ) || *k == 0) \ + && PASTEMAC(ch,eq1)( *beta ) )) \ + { \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ + return; \ + } \ +\ + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ + bli_param_map_netlib_to_blis_trans( *transb, &blis_transb ); \ +\ + /* Typecast BLAS integers to BLIS integers. */ \ + bli_convert_blas_dim1( *m, m0 ); \ + bli_convert_blas_dim1( *n, n0 ); \ + bli_convert_blas_dim1( *k, k0 ); \ +\ + /* Set the row and column strides of the matrix operands. */ \ + const inc_t rs_a = 1; \ + const inc_t cs_a = *lda; \ + const inc_t rs_b = 1; \ + const inc_t cs_b = *ldb; \ + const inc_t rs_c = 1; \ + const inc_t cs_c = *ldc; \ +\ + const num_t dt = PASTEMAC(ch,type); \ +\ + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t ao = BLIS_OBJECT_INITIALIZER; \ + obj_t bo = BLIS_OBJECT_INITIALIZER; \ + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t co = BLIS_OBJECT_INITIALIZER; \ +\ + dim_t m0_a, n0_a; \ + dim_t m0_b, n0_b; \ +\ + bli_set_dims_with_trans( blis_transa, m0, k0, &m0_a, &n0_a ); \ + bli_set_dims_with_trans( blis_transb, k0, n0, &m0_b, &n0_b ); \ +\ + bli_obj_init_finish_1x1( dt, (ftype*)alpha, &alphao ); \ + bli_obj_init_finish_1x1( dt, (ftype*)beta, &betao ); \ +\ + bli_obj_init_finish( dt, m0_a, n0_a, (ftype*)a, rs_a, cs_a, &ao ); \ + bli_obj_init_finish( dt, m0_b, n0_b, (ftype*)b, rs_b, cs_b, &bo ); \ + bli_obj_init_finish( dt, m0, n0, (ftype*)c, rs_c, cs_c, &co ); \ +\ + bli_obj_set_conjtrans( blis_transa, &ao ); \ + bli_obj_set_conjtrans( blis_transb, &bo ); \ +\ + /* As a placeholder, invoke 1m since BLIS does no longer contains an + official 3m implementation. Note that we do this by inlining an + abbreviated version of bli_gemm_ex() so that we can bypass + consideration of sup, which doesn't make sense in this context. */ \ + { \ + cntx_t* cntx = bli_gks_query_ind_cntx( BLIS_1M, dt ); \ +\ + rntm_t rntm_l; \ + rntm_t* rntm = &rntm_l; \ + bli_rntm_init_from_global( &rntm_l ); \ +\ + /* This is probably not needed given that we performed BLAS-style + parameter checking above, but bli_gemm_check() is normally called + in the normal course of bli_gemm_ex(). */ \ + if ( bli_error_checking_is_enabled() ) \ + bli_gemm_check( &alphao, &ao, &bo, &betao, &co, cntx ); \ +\ + PASTEMAC(blisname,_front) \ + ( \ + &alphao, \ + &ao, \ + &bo, \ + &betao, \ + &co, \ + cntx, \ + rntm, \ + NULL \ + ); \ + } \ +\ +\ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ } \ IF_BLIS_ENABLE_BLAS(\ void PASTEF77(ch,blasname) \ @@ -274,7 +310,7 @@ void PASTEF77(ch,blasname) \ ftype* c, const f77_int* ldc \ ) \ { \ - PASTEF77S(ch,blasname) ( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc ); \ + PASTEF77S(ch,blasname) ( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc ); \ } \ ) diff --git a/frame/compat/bla_gemm_amd.c b/frame/compat/bla_gemm_amd.c index e858bf6147..a34305c9c2 100644 --- a/frame/compat/bla_gemm_amd.c +++ b/frame/compat/bla_gemm_amd.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -43,7 +43,7 @@ #define GEMM_BLIS_IMPL(ch, blasname) \ PASTEF77S(ch,blasname) ( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc ); \ arch_t id = bli_arch_query_id(); \ - if (id == BLIS_ARCH_ZEN4) \ + if (id == BLIS_ARCH_ZEN5 || id == BLIS_ARCH_ZEN4) \ { \ bli_zero_zmm(); \ } \ @@ -602,8 +602,8 @@ void dgemm_blis_impl c, *ldc ); } -#if defined(BLIS_FAMILY_ZEN4) || defined(BLIS_FAMILY_AMDZEN) || defined(BLIS_FAMILY_X86_64) - else if( arch_id == BLIS_ARCH_ZEN4 ) +#if defined(BLIS_FAMILY_ZEN5) || defined(BLIS_FAMILY_ZEN4) || defined(BLIS_FAMILY_AMDZEN) || defined(BLIS_FAMILY_X86_64) + else if( arch_id == BLIS_ARCH_ZEN5 || arch_id == BLIS_ARCH_ZEN4 ) { ret = bli_dgemm_24x8_avx512_k1_nn( m0, n0, k0, (double*)alpha, @@ -839,8 +839,20 @@ void dgemm_blis_impl } // fall back on native path when dgemm is not handled in sup path. - bli_gemmnat(&alphao, &ao, &bo, &betao, &co, NULL, NULL); + //bli_gemmnat(&alphao, &ao, &bo, &betao, &co, NULL, NULL); + /* Default to using native execution. */ + ind_t im = BLIS_NAT; + + /* Obtain a valid context from the gks using the induced + method id determined above. */ + cntx_t* cntx = bli_gks_query_ind_cntx( im, dt ); + + rntm_t rntm_l; + bli_rntm_init_from_global( &rntm_l ); + + /* Invoke the operation's front-end and request the default control tree. */ + PASTEMAC(gemm,_front)( &alphao, &ao, &bo, &betao, &co, cntx, &rntm_l, NULL ); /* PASTEMAC(gemm, BLIS_OAPI_EX_SUF) */ /* ( */ @@ -876,7 +888,7 @@ void dgemm_ dgemm_blis_impl(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); #if defined(BLIS_KERNELS_ZEN4) arch_t id = bli_arch_query_id(); - if (id == BLIS_ARCH_ZEN4) + if (id == BLIS_ARCH_ZEN5 || id == BLIS_ARCH_ZEN4) { bli_zero_zmm(); } @@ -1116,23 +1128,94 @@ void zgemm_blis_impl - The input constraints are that k should be 1, and transa and transb should be N and N respectively. */ - if( ( k0 == 1 ) && bli_is_notrans( blis_transa ) && bli_is_notrans( blis_transb ) ) + if( ( k0 == 1 ) && bli_is_notrans( blis_transa ) && + bli_is_notrans( blis_transb ) ) { - bli_zgemm_4x4_avx2_k1_nn - ( - m0, n0, k0, - (dcomplex*)alpha, - (dcomplex*)a, *lda, - (dcomplex*)b, *ldb, - (dcomplex*)beta, - c, *ldc - ); + err_t ret = BLIS_FAILURE; + arch_t arch_id = bli_arch_query_id(); - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(z), *m, *n, *k); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - /* Finalize BLIS */ - bli_finalize_auto(); - return; + if( arch_id == BLIS_ARCH_ZEN || arch_id == BLIS_ARCH_ZEN2 || + arch_id == BLIS_ARCH_ZEN3 ) + { + ret = bli_zgemm_4x4_avx2_k1_nn + ( + m0, n0, k0, + (dcomplex*)alpha, + (dcomplex*)a, *lda, + (dcomplex*)b, *ldb, + (dcomplex*)beta, + c, *ldc + ); + } + +#if defined(BLIS_KERNELS_ZEN4) + else if ( arch_id == BLIS_ARCH_ZEN4 ) + { + // Redirecting to AVX-2 kernel if load direction( m0 ) is < 30. + // This holds true irrespective of the broadcast direction( n0 ) + if( m0 < 30 ) + { + ret = bli_zgemm_4x4_avx2_k1_nn + ( + m0, n0, k0, + (dcomplex*)alpha, + (dcomplex*)a, *lda, + (dcomplex*)b, *ldb, + (dcomplex*)beta, + c, *ldc + ); + } + else + { + ret = bli_zgemm_16x4_avx512_k1_nn + ( + m0, n0, k0, + (dcomplex*)alpha, + (dcomplex*)a, *lda, + (dcomplex*)b, *ldb, + (dcomplex*)beta, + c, *ldc + ); + } + } + else if ( arch_id == BLIS_ARCH_ZEN5 ) + { + // Redirecting to AVX-2 kernel if the dimensions are < 30 + // ( i.e, small or tiny sizes ), or if the load directon( m0 ) < 10 + if( ( m0 < 30 && n0 < 30 ) || m0 < 10 ) + { + ret = bli_zgemm_4x4_avx2_k1_nn + ( + m0, n0, k0, + (dcomplex*)alpha, + (dcomplex*)a, *lda, + (dcomplex*)b, *ldb, + (dcomplex*)beta, + c, *ldc + ); + } + else + { + ret = bli_zgemm_16x4_avx512_k1_nn + ( + m0, n0, k0, + (dcomplex*)alpha, + (dcomplex*)a, *lda, + (dcomplex*)b, *ldb, + (dcomplex*)beta, + c, *ldc + ); + } + } +#endif + if( ret == BLIS_SUCCESS ) + { + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(z), *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS */ + bli_finalize_auto(); + return; + } } const num_t dt = BLIS_DCOMPLEX; @@ -1212,7 +1295,32 @@ void zgemm_blis_impl } // fall back on native path when zgemm is not handled in sup path. - bli_gemmnat(&alphao, &ao, &bo, &betao, &co, NULL, NULL); + //bli_gemmnat(&alphao, &ao, &bo, &betao, &co, NULL, NULL); + + /* Default to using native execution. */ + ind_t im = BLIS_NAT; + + /* As each matrix operand has a complex storage datatype, try to get an + induced method (if one is available and enabled). NOTE: Allowing + precisions to vary while using 1m, which is what we do here, is unique + to gemm; other level-3 operations use 1m only if all storage datatypes + are equal (and they ignore the computation precision). */ + + /* Find the highest priority induced method that is both enabled and + available for the current operation. (If an induced method is + available but not enabled, or simply unavailable, BLIS_NAT will + be returned here.) */ + im = bli_gemmind_find_avail( dt ); + + /* Obtain a valid context from the gks using the induced + method id determined above. */ + cntx_t* cntx = bli_gks_query_ind_cntx( im, dt ); + + rntm_t rntm_l; + bli_rntm_init_from_global( &rntm_l ); + + /* Invoke the operation's front-end and request the default control tree. */ + PASTEMAC(gemm,_front)( &alphao, &ao, &bo, &betao, &co, cntx, &rntm_l, NULL ); AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(z), *m, *n, *k); AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); @@ -1237,7 +1345,7 @@ void zgemm_ zgemm_blis_impl(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); #if defined(BLIS_KERNELS_ZEN4) arch_t id = bli_arch_query_id(); - if (id == BLIS_ARCH_ZEN4) + if (id == BLIS_ARCH_ZEN5 || id == BLIS_ARCH_ZEN4) { bli_zero_zmm(); } @@ -1367,7 +1475,23 @@ void dzgemm_blis_impl bli_obj_set_conjtrans( blis_transb, &bo ); // fall back on native path when zgemm is not handled in sup path. - bli_gemmnat(&alphao, &ao, &bo, &betao, &co, NULL, NULL); + //bli_gemmnat(&alphao, &ao, &bo, &betao, &co, NULL, NULL); + + /* Default to using native execution. */ + ind_t im = BLIS_NAT; + + /* Mix of real and complex matrix data types, so assuming + induced methods will not be available */ + + /* Obtain a valid context from the gks using the induced + method id determined above. */ + cntx_t* cntx = bli_gks_query_ind_cntx( im, dt ); + + rntm_t rntm_l; + bli_rntm_init_from_global( &rntm_l ); + + /* Invoke the operation's front-end and request the default control tree. */ + PASTEMAC(gemm,_front)( &alphao, &ao, &bo, &betao, &co, cntx, &rntm_l, NULL ); AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(z), *m, *n, *k); AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); @@ -1392,7 +1516,7 @@ void dzgemm_ dzgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc ); #if defined(BLIS_KERNELS_ZEN4) arch_t id = bli_arch_query_id(); - if (id == BLIS_ARCH_ZEN4) + if (id == BLIS_ARCH_ZEN5 || id == BLIS_ARCH_ZEN4) { bli_zero_zmm(); } diff --git a/frame/compat/bla_gemm_compute.c b/frame/compat/bla_gemm_compute.c index 8d9f3697b9..ee8813bffc 100644 --- a/frame/compat/bla_gemm_compute.c +++ b/frame/compat/bla_gemm_compute.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -86,7 +86,7 @@ void sgemm_compute_blis_impl ); /* Quick return. */ - if ( *m == 0 || *n == 0 ) + if ( *m == 0 || *n == 0 || ( ( *k == 0) && PASTEMAC(s,eq1)( *beta ) ) ) { /* Finalize BLIS. */ bli_finalize_auto(); @@ -214,7 +214,7 @@ void dgemm_compute_blis_impl ); /* Quick return. */ - if ( *m == 0 || *n == 0 ) + if ( *m == 0 || *n == 0 || ( ( *k == 0) && PASTEMAC(d,eq1)( *beta ) ) ) { /* Finalize BLIS. */ bli_finalize_auto(); diff --git a/frame/compat/bla_gemm_pack_get_size.c b/frame/compat/bla_gemm_pack_get_size.c index 32f2acfccb..18463a0530 100644 --- a/frame/compat/bla_gemm_pack_get_size.c +++ b/frame/compat/bla_gemm_pack_get_size.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -81,7 +81,7 @@ f77_int dgemm_pack_get_size_blis_impl f77_int n = *pn; f77_int k = *pk; - // Retreive cache-blocking parameters used in GEMM + // Retrieve cache-blocking parameters used in GEMM #if 0 // Not needed, MR and NR should do const dim_t MC = bli_cntx_get_blksz_def_dt( dt, BLIS_MC, cntx ); @@ -207,7 +207,7 @@ f77_int sgemm_pack_get_size_blis_impl f77_int n = *pn; f77_int k = *pk; - // Retreive cache-blocking parameters used in GEMM + // Retrieve cache-blocking parameters used in GEMM #if 0 // Not needed, MR and NR should do const dim_t MC = bli_cntx_get_blksz_def_dt( dt, BLIS_MC, cntx ); diff --git a/frame/compat/bla_gemmt.c b/frame/compat/bla_gemmt.c index 14ee1f15dd..233f789174 100644 --- a/frame/compat/bla_gemmt.c +++ b/frame/compat/bla_gemmt.c @@ -43,7 +43,7 @@ #define GEMMT_BLIS_IMPL(ch, blasname) \ PASTEF77S(ch,blasname) ( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc ); \ arch_t id = bli_arch_query_id(); \ - if (id == BLIS_ARCH_ZEN4) \ + if (id == BLIS_ARCH_ZEN5 || id == BLIS_ARCH_ZEN4) \ { \ bli_zero_zmm(); \ } \ diff --git a/frame/compat/bla_gemv.c b/frame/compat/bla_gemv.c index c910e9eb1e..1a38495269 100644 --- a/frame/compat/bla_gemv.c +++ b/frame/compat/bla_gemv.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -77,8 +77,9 @@ void PASTEF77S(ch,blasname) \ incy \ ); \ \ - if (*m == 0 || *n == 0) { \ - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + if ( *m == 0 || *n == 0 || \ + ( PASTEMAC(ch,eq0)( *alpha ) && PASTEMAC(ch,eq1)( *beta ) ) ) { \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ return; \ } \ \ diff --git a/frame/compat/bla_gemv.h b/frame/compat/bla_gemv.h index 3b8a7a61aa..9a1be594cf 100644 --- a/frame/compat/bla_gemv.h +++ b/frame/compat/bla_gemv.h @@ -5,8 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. - + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/frame/compat/bla_gemv_amd.c b/frame/compat/bla_gemv_amd.c index 4ac431e48a..224f6aca50 100644 --- a/frame/compat/bla_gemv_amd.c +++ b/frame/compat/bla_gemv_amd.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -85,8 +85,9 @@ void PASTEF77S(ch,blasname) \ incy \ ); \ \ - if (*m == 0 || *n == 0) { \ - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + if ( *m == 0 || *n == 0 || \ + ( PASTEMAC(ch,eq0)( *alpha ) && PASTEMAC(ch,eq1)( *beta ) ) ) { \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ return; \ } \ \ @@ -207,7 +208,8 @@ void dgemv_blis_impl incy ); - if (*m == 0 || *n == 0) + if ( *m == 0 || *n == 0 || \ + ( PASTEMAC(d,eq0)( *alpha ) && PASTEMAC(d,eq1)( *beta ) ) ) { AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); return; @@ -310,6 +312,31 @@ void dgemv_blis_impl NULL, NULL ); + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; + } + + /** + * DGEMV Tiny Path + * If the matrix dimensions are within 8x8 then calculate the result + * using DGEMV Reference kernel. + */ + if ( m0 < 8 && n0 < 8 ) + { + bli_dgemv_zen_ref + ( + blis_transa, + m0, + n0, + (double*)alpha, + (double*)a, rs_a, cs_a, + x0, incx0, + (double*)beta, + y0, incy0, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); return; } @@ -412,7 +439,8 @@ void sgemv_blis_impl incy ); - if (*m == 0 || *n == 0) + if ( *m == 0 || *n == 0 || \ + ( PASTEMAC(s,eq0)( *alpha ) && PASTEMAC(s,eq1)( *beta ) ) ) { AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); return; @@ -612,7 +640,8 @@ void cgemv_blis_impl incy ); - if (*m == 0 || *n == 0) + if ( *m == 0 || *n == 0 || \ + ( PASTEMAC(c,eq0)( *alpha ) && PASTEMAC(c,eq1)( *beta ) ) ) { AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); return; @@ -854,7 +883,8 @@ void zgemv_blis_impl incy ); - if (*m == 0 || *n == 0) + if ( *m == 0 || *n == 0 || \ + ( PASTEMAC(z,eq0)( *alpha ) && PASTEMAC(z,eq1)( *beta ) ) ) { AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); return; diff --git a/frame/compat/bla_ger.h b/frame/compat/bla_ger.h index 290ff0d754..2312cc3ede 100644 --- a/frame/compat/bla_ger.h +++ b/frame/compat/bla_ger.h @@ -5,8 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. - + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/frame/compat/bla_hemm.c b/frame/compat/bla_hemm.c index 79c9458345..406ad3e732 100644 --- a/frame/compat/bla_hemm.c +++ b/frame/compat/bla_hemm.c @@ -45,7 +45,7 @@ #define HEMM_BLIS_IMPL(ch, blasname) \ PASTEF77S(ch,blasname) ( side, uploa, m, n, alpha, a, lda, b, ldb, beta, c, ldc ); \ arch_t id = bli_arch_query_id(); \ - if (id == BLIS_ARCH_ZEN4) \ + if (id == BLIS_ARCH_ZEN5 || id == BLIS_ARCH_ZEN4) \ { \ bli_zero_zmm(); \ } \ diff --git a/frame/compat/bla_hemv.h b/frame/compat/bla_hemv.h index 2c1a2526b1..f22e56379b 100644 --- a/frame/compat/bla_hemv.h +++ b/frame/compat/bla_hemv.h @@ -5,8 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. - + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/frame/compat/bla_her.c b/frame/compat/bla_her.c old mode 100755 new mode 100644 diff --git a/frame/compat/bla_her.h b/frame/compat/bla_her.h index 627f990e73..67fb0c32e3 100644 --- a/frame/compat/bla_her.h +++ b/frame/compat/bla_her.h @@ -5,8 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. - + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/frame/compat/bla_her2.h b/frame/compat/bla_her2.h index 906e3d8512..310e48cf73 100644 --- a/frame/compat/bla_her2.h +++ b/frame/compat/bla_her2.h @@ -5,8 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. - + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/frame/compat/bla_her2k.c b/frame/compat/bla_her2k.c old mode 100755 new mode 100644 index 62bec3e764..1e81522faf --- a/frame/compat/bla_her2k.c +++ b/frame/compat/bla_her2k.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -45,7 +45,7 @@ #define HER2K_BLIS_IMPL(ch, blasname) \ PASTEF77S(ch,blasname) ( uploc, transa, m, k, alpha, a, lda, b, ldb, beta, c, ldc ); \ arch_t id = bli_arch_query_id(); \ - if (id == BLIS_ARCH_ZEN4) \ + if (id == BLIS_ARCH_ZEN5 || id == BLIS_ARCH_ZEN4) \ { \ bli_zero_zmm(); \ } \ diff --git a/frame/compat/bla_herk.c b/frame/compat/bla_herk.c old mode 100755 new mode 100644 index 337c470c1d..0ef1069ee5 --- a/frame/compat/bla_herk.c +++ b/frame/compat/bla_herk.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -45,7 +45,7 @@ #define HERK_BLIS_IMPL(ch, blasname) \ PASTEF77S(ch,blasname) ( uploc, transa, m, k, alpha, a, lda, beta, c, ldc ); \ arch_t id = bli_arch_query_id(); \ - if (id == BLIS_ARCH_ZEN4) \ + if (id == BLIS_ARCH_ZEN5 || id == BLIS_ARCH_ZEN4) \ { \ bli_zero_zmm(); \ } \ diff --git a/frame/compat/bla_imatcopy.c b/frame/compat/bla_imatcopy.c index a3feceba48..13e59e28e2 100644 --- a/frame/compat/bla_imatcopy.c +++ b/frame/compat/bla_imatcopy.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -36,628 +36,1374 @@ #ifdef BLIS_ENABLE_BLAS -static dim_t bli_siMatCopy_cn(dim_t rows,dim_t cols,const float alpha,float* a,dim_t lda, dim_t ldb); - -static dim_t bli_diMatCopy_cn(dim_t rows,dim_t cols,const double alpha,double* a,dim_t lda, dim_t ldb); - -static dim_t bli_ciMatCopy_cn(dim_t rows,dim_t cols,const scomplex alpha,scomplex* a,dim_t lda, dim_t ldb); - -static dim_t bli_ciMatCopy_cr(dim_t rows,dim_t cols,const scomplex alpha,scomplex* a,dim_t lda, dim_t ldb); - -static dim_t bli_ziMatCopy_cn(dim_t rows,dim_t cols,const dcomplex alpha,dcomplex* a,dim_t lda, dim_t ldb); - -static dim_t bli_ziMatCopy_cr(dim_t rows,dim_t cols,const dcomplex alpha,dcomplex* a,dim_t lda, dim_t ldb); - -static void bli_stranspose(float* A,float* B,dim_t cols, dim_t rows); - -static void bli_dtranspose(double* A,double* B,dim_t cols, dim_t rows); - -static void bli_ctranspose(scomplex* A,scomplex* B,dim_t cols, dim_t rows); - -static void bli_ztranspose(dcomplex* A,dcomplex* B,dim_t cols, dim_t rows); - -static void bli_stranspose(float* A,float* B,dim_t cols, dim_t rows) +static dim_t bli_siMatCopy_cn + ( + dim_t rows, + dim_t cols, + const float alpha, + float* a, + dim_t lda, + dim_t ldb + ); + +static dim_t bli_siMatCopy_ct + ( + dim_t rows, + dim_t cols, + const float alpha, + float* a, + dim_t lda, + dim_t ldb + ); + +static dim_t bli_diMatCopy_cn + ( + dim_t rows, + dim_t cols, + const double alpha, + double* a, + dim_t lda, + dim_t ldb + ); + +static dim_t bli_diMatCopy_ct + ( + dim_t rows, + dim_t cols, + const double alpha, + double* a, + dim_t lda, + dim_t ldb + ); + +static dim_t bli_ciMatCopy_cn + ( + dim_t rows, + dim_t cols, + const scomplex alpha, + scomplex* a, + dim_t lda, + dim_t ldb + ); + +static dim_t bli_ciMatCopy_ct + ( + dim_t rows, + dim_t cols, + const scomplex alpha, + scomplex* a, + dim_t lda, + dim_t ldb + ); + +static dim_t bli_ciMatCopy_cr + ( + dim_t rows, + dim_t cols, + const scomplex alpha, + scomplex* a, + dim_t lda, + dim_t ldb + ); + +static dim_t bli_ciMatCopy_cc + ( + dim_t rows, + dim_t cols, + const scomplex alpha, + scomplex* a, + dim_t lda, + dim_t ldb + ); + +static dim_t bli_ziMatCopy_cn + ( + dim_t rows, + dim_t cols, + const dcomplex alpha, + dcomplex* a, + dim_t lda, + dim_t ldb + ); + +static dim_t bli_ziMatCopy_ct + ( + dim_t rows, + dim_t cols, + const dcomplex alpha, + dcomplex* a, + dim_t lda, + dim_t ldb + ); + +static dim_t bli_ziMatCopy_cr + ( + dim_t rows, + dim_t cols, + const dcomplex alpha, + dcomplex* a, + dim_t lda, + dim_t ldb + ); + +static dim_t bli_ziMatCopy_cc + ( + dim_t rows, + dim_t cols, + const dcomplex alpha, + dcomplex* a, + dim_t lda, + dim_t ldb + ); + +void simatcopy_ + ( + f77_char* trans, + f77_int* rows, + f77_int* cols, + const float* alpha, + float* aptr, + f77_int* lda, + f77_int* ldb + ) { - for (dim_t i = 0; i < cols; i++) - for (dim_t j = 0; j < rows; j++) - B[j*cols + i] = A[i*rows +j]; + //printf("I am from simatcopy_\n"); + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); + //bli_init_once(); + if ( !( *trans == 'n' || *trans == 'N' || + *trans == 't' || *trans == 'T' || + *trans == 'c' || *trans == 'C' || + *trans == 'r' || *trans == 'R' ) ) + { + bli_print_msg( " Invalid trans setting simatcopy_() .", __FILE__, __LINE__ ); + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, "Invalid value for trans parameter"); + return ; + } + + if ( *rows <= 0 || *cols <= 0 || alpha == NULL || aptr == NULL || *lda < 1 || *ldb < 1 ) + { + bli_print_msg( " Invalid function parameters simatcopy_() .", __FILE__, __LINE__ ); + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, "Invalid function parameters"); + return ; + } + + if ( *trans == 'n' || *trans == 'N' ) + { + bli_siMatCopy_cn + ( + *rows, *cols, *alpha, + aptr, *lda, *ldb + ); + } + else if ( *trans == 't' || *trans == 'T' ) + { + bli_siMatCopy_ct + ( + *rows, *cols, *alpha, + aptr, *lda, *ldb + ); + } + else if ( *trans == 'c' || *trans == 'C' ) + { + bli_siMatCopy_ct + ( + *rows, *cols, *alpha, + aptr, *lda, *ldb + ); + } + else if ( *trans == 'r' || *trans == 'R' ) + { + bli_siMatCopy_cn + ( + *rows, *cols, *alpha, + aptr, *lda, *ldb + ); + } + else + { + // do nothing + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return ; } -static void bli_dtranspose(double* A,double* B,dim_t cols, dim_t rows) +void dimatcopy_ + ( + f77_char* trans, + f77_int* rows, + f77_int* cols, + const double* alpha, + double* aptr, + f77_int* lda, + f77_int* ldb + ) { - for (dim_t i = 0; i < cols; i++) - for (dim_t j = 0; j < rows; j++) - B[j*cols + i] = A[i*rows +j]; + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); + //bli_init_once(); + if ( !( *trans == 'n' || *trans == 'N' || + *trans == 't' || *trans == 'T' || + *trans == 'c' || *trans == 'C' || + *trans == 'r' || *trans == 'R' ) ) + { + bli_print_msg( " Invalid trans setting dimatcopy_() .", __FILE__, __LINE__ ); + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, "Invalid value for trans parameter"); + return ; + } + + if ( *rows <= 0 || *cols <= 0 || alpha == NULL || aptr == NULL || *lda < 1 || *ldb < 1 ) + { + bli_print_msg( " Invalid function parameters dimatcopy_() .", __FILE__, __LINE__ ); + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, "Invalid function parameters"); + return ; + } + + if ( *trans == 'n' || *trans == 'N' ) + { + bli_diMatCopy_cn + ( + *rows, *cols, *alpha, + aptr, *lda, *ldb + ); + } + else if ( *trans == 't' || *trans == 'T' ) + { + bli_diMatCopy_ct + ( + *rows, *cols, *alpha, + aptr, *lda, *ldb + ); + } + else if ( *trans == 'c' || *trans == 'C' ) + { + bli_diMatCopy_ct + ( + *rows, *cols, *alpha, + aptr, *lda, *ldb + ); + } + else if ( *trans == 'r' || *trans == 'R' ) + { + bli_diMatCopy_cn + ( + *rows, *cols, *alpha, + aptr, *lda, *ldb + ); + } + else + { + // do nothing + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return ; } -static void bli_ctranspose(scomplex* A,scomplex* B,dim_t cols, dim_t rows) +void cimatcopy_ + ( + f77_char* trans, + f77_int* rows, + f77_int* cols, + const scomplex* alpha, + scomplex* aptr, + f77_int* lda, + f77_int* ldb + ) { - for (dim_t i = 0; i < cols; i++) - for (dim_t j = 0; j < rows; j++) - B[j*cols + i] = A[i*rows +j]; + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); + //bli_init_once(); + if ( !( *trans == 'n' || *trans == 'N' || + *trans == 't' || *trans == 'T' || + *trans == 'c' || *trans == 'C' || + *trans == 'r' || *trans == 'R' ) ) + { + bli_print_msg( " Invalid trans setting cimatcopy_() .", __FILE__, __LINE__ ); + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, "Invalid value for trans parameter"); + return ; + } + + if ( *rows <= 0 || *cols <= 0 || alpha == NULL || aptr == NULL || *lda < 1 || *ldb < 1 ) + { + bli_print_msg( " Invalid function parameters cimatcopy_() .", __FILE__, __LINE__ ); + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, "Invalid function parameters"); + return ; + } + + if ( *trans == 'n' || *trans == 'N' ) + { + bli_ciMatCopy_cn + ( + *rows, *cols, *alpha, + aptr, *lda, *ldb + ); + } + else if ( *trans == 't' || *trans == 'T' ) + { + bli_ciMatCopy_ct + ( + *rows, *cols, *alpha, + aptr, *lda, *ldb + ); + } + else if ( *trans == 'c' || *trans == 'C' ) + { + bli_ciMatCopy_cc + ( + *rows, *cols, *alpha, + aptr, *lda, *ldb + ); + } + else if ( *trans == 'r' || *trans == 'R' ) + { + bli_ciMatCopy_cr + ( + *rows, *cols, *alpha, + aptr, *lda, *ldb + ); + } + else + { + // do nothing + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return ; } -static void bli_ztranspose(dcomplex* A,dcomplex* B,dim_t cols, dim_t rows) +void zimatcopy_ + ( + f77_char* trans, + f77_int* rows, + f77_int* cols, + const dcomplex* alpha, + dcomplex* aptr, + f77_int* lda, + f77_int* ldb + ) { - for (dim_t i = 0; i < cols; i++) - for (dim_t j = 0; j < rows; j++) - B[j*cols + i] = A[i*rows +j]; + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); + //bli_init_once(); + if ( !( *trans == 'n' || *trans == 'N' || + *trans == 't' || *trans == 'T' || + *trans == 'c' || *trans == 'C' || + *trans == 'r' || *trans == 'R' ) ) + { + bli_print_msg( " Invalid trans setting zimatcopy_() .", __FILE__, __LINE__ ); + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, "Invalid value for trans parameter"); + return ; + } + + if ( *rows <= 0 || *cols <= 0 || alpha == NULL || aptr == NULL || *lda < 1 || *ldb < 1 ) + { + bli_print_msg( " Invalid function parameters zimatcopy_() .", __FILE__, __LINE__ ); + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, "Invalid function parameters"); + return ; + } + + if ( *trans == 'n' || *trans == 'N' ) + { + bli_ziMatCopy_cn + ( + *rows, *cols, *alpha, + aptr, *lda, *ldb + ); + } + else if ( *trans == 't' || *trans == 'T' ) + { + bli_ziMatCopy_ct + ( + *rows, *cols, *alpha, + aptr, *lda, *ldb + ); + } + else if ( *trans == 'c' || *trans == 'C' ) + { + bli_ziMatCopy_cc + ( + *rows, *cols, *alpha, + aptr, *lda, *ldb + ); + } + else if ( *trans == 'r' || *trans == 'R' ) + { + bli_ziMatCopy_cr + ( + *rows, *cols, *alpha, + aptr, *lda, *ldb + ); + } + else + { + // do nothing + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return ; } -void simatcopy_ (f77_char* trans, f77_int* rows, f77_int* cols, const float* alpha,float* aptr, f77_int* lda, f77_int* ldb) +// suffix cn means - column major & non-trans +static dim_t bli_siMatCopy_cn(dim_t rows,dim_t cols,const float alpha,float* a,dim_t lda,dim_t ldb) { - //printf("I am from simatcopy_\n"); - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); - //bli_init_once(); - if ( !(*trans == 'n' || *trans == 'N' || - *trans == 't' || *trans == 'T' || - *trans == 'c' || *trans == 'C' || - *trans == 'r' || *trans == 'R')) - { - bli_print_msg( " Invalid trans setting simatcopy_() .", __FILE__, __LINE__ ); - AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, "Invalid value for trans parameter"); - return ; - } - - if ( *rows <= 0 || *cols <= 0 || alpha == NULL || aptr == NULL || *lda < 1 || *ldb < 1) - { - bli_print_msg( " Invalid function parameters simatcopy_() .", __FILE__, __LINE__ ); - AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, "Invalid function parameters"); - return ; - } - - if ( *trans == 'n' || *trans == 'N') - { - bli_siMatCopy_cn(*rows,*cols,*alpha,aptr,*lda,*ldb); - } - else if ( *trans == 't' || *trans == 'T') - { - //pre transpose - err_t r_val; - float* temp = (float* ) bli_malloc_user((*rows)*(*lda)*sizeof(float), &r_val); - bli_stranspose(aptr,temp,*lda,*rows); - - for (dim_t i = 0; i < *cols; i++) - memcpy(&aptr[i*(*lda)],&temp[i*(*lda)],(*rows)*sizeof(float)); - - bli_siMatCopy_cn(*cols,*rows,*alpha,aptr,*lda,*ldb); - - //post transpose - //bli_stranspose(temp,aptr,*lda,*cols); - bli_free_user(temp); - } - else if ( *trans == 'c' || *trans == 'C') - { - //pre transpose - err_t r_val; - float* temp = (float* ) bli_malloc_user((*rows)*(*lda)*sizeof(float), &r_val); - bli_stranspose(aptr,temp,*lda,*rows); - - for (dim_t i = 0; i < *cols; i++) - memcpy(&aptr[i*(*lda)],&temp[i*(*lda)],(*rows)*sizeof(float)); - - //bli_siMatCopy_cn(*cols,*rows,*alpha,temp,*lda,*ldb); - - bli_siMatCopy_cn(*cols,*rows,*alpha,aptr,*lda,*ldb); - //post transpose - //bli_stranspose(temp,aptr,*lda,*cols); - bli_free_user(temp); - } - else if ( *trans == 'r' || *trans == 'R') - { - bli_siMatCopy_cn(*rows,*cols,*alpha,aptr,*lda,*ldb); - } - else - { - // do nothing - } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return ; + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_2); + dim_t i,j; + + float* s_aptr; + float* d_aptr; + + if ( rows <= 0 || cols <= 0 || a == NULL || lda < rows || ldb < rows ) + { + fprintf( stderr, " Invalid trans setting bli_siMatCopy_cn() %ld %ld %ld %ld \n", + ( long )rows, ( long )cols, ( long )lda, ( long )ldb); + bli_print_msg( " Invalid function parameters bli_siMatCopy_cn() .", __FILE__, __LINE__ ); + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_2, "Invalid function parameters"); + return ( 0 ); + } + + if ( lda == ldb && alpha == 1.0 ) + return ( 0 ); + + s_aptr = a; + d_aptr = a; + + if ( lda >= ldb ) + { + for ( i = 0; i < cols ; i++ ) + { + for ( j = 0; j < rows; j++ ) + { + d_aptr[j] = alpha * s_aptr[j]; + } + s_aptr += lda; + d_aptr += ldb; + } + } + else + { + // Acquring memory for auxiliary buffer(in case lda < ldb). This is + // needed in order to avoid overwriting subsequent reads from the input. + // This extra buffer is allocated exactly the amount of memory that + // is needed to store the required elements from input(rows x cols) + err_t r_val; + float* buf = (float *) bli_malloc_user((rows)*(cols)*sizeof(float), &r_val); + + if( buf != NULL ) + { + // Loading from input, storing onto auxiliary buffer + float *d_acopy = buf; + for ( i = 0; i < cols ; i++ ) + { + for ( j = 0; j < rows; j++ ) + { + d_acopy[j] = alpha * s_aptr[j]; + } + s_aptr += lda; + d_acopy += rows; + } + + // Loading from auxiliary buffer, storing onto output + d_acopy = buf; + for ( i = 0; i < cols ; i++ ) + { + for ( j = 0; j < rows; j++ ) + { + d_aptr[j] = d_acopy[j]; + } + d_acopy += rows; + d_aptr += ldb; + } + + bli_free_user(buf); + } + else + { + f77_int mem_fail_info = -10; + xerbla_(MKSTR(bli_siMatCopy_cn), &mem_fail_info, (f77_int)16); + } + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2); + return ( 0 ); } -void dimatcopy_ (f77_char* trans, f77_int* rows, f77_int* cols, const double* alpha,double* aptr, f77_int* lda, f77_int* ldb) +// suffix cn means - column major & non-trans +static dim_t bli_diMatCopy_cn(dim_t rows,dim_t cols,const double alpha,double* a,dim_t lda,dim_t ldb) { - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); - //bli_init_once(); - if ( !(*trans == 'n' || *trans == 'N' || - *trans == 't' || *trans == 'T' || - *trans == 'c' || *trans == 'C' || - *trans == 'r' || *trans == 'R')) - { - bli_print_msg( " Invalid trans setting dimatcopy_() .", __FILE__, __LINE__ ); - AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, "Invalid value for trans parameter"); - return ; - } - - if ( *rows <= 0 || *cols <= 0 || alpha == NULL || aptr == NULL || *lda < 1 || *ldb < 1) - { - bli_print_msg( " Invalid function parameters dimatcopy_() .", __FILE__, __LINE__ ); - AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, "Invalid function parameters"); - return ; - } - - if ( *trans == 'n' || *trans == 'N') - { - bli_diMatCopy_cn(*rows,*cols,*alpha,aptr,*lda,*ldb); - } - else if ( *trans == 't' || *trans == 'T') - { - //pre transpose - err_t r_val; - double* temp = (double* ) bli_malloc_user((*rows)*(*lda)*sizeof(double), &r_val); - bli_dtranspose(aptr,temp,*lda,*rows); - - for (dim_t i = 0; i < *cols; i++) - memcpy(&aptr[i*(*lda)],&temp[i*(*lda)],(*rows)*sizeof(double)); - - bli_diMatCopy_cn(*cols,*rows,*alpha,aptr,*lda,*ldb); - - //post transpose - //bli_dtranspose(temp,aptr,*rows,*lda); - //bli_dtranspose(temp,aptr,*lda,*cols); - bli_free_user(temp); - } - else if ( *trans == 'c' || *trans == 'C') - { - //pre transpose - err_t r_val; - double* temp = (double* ) bli_malloc_user((*rows)*(*lda)*sizeof(double), &r_val); - bli_dtranspose(aptr,temp,*lda,*rows); - - for (dim_t i = 0; i < *cols; i++) - memcpy(&aptr[i*(*lda)],&temp[i*(*lda)],(*rows)*sizeof(double)); - - bli_diMatCopy_cn(*cols,*rows,*alpha,aptr,*lda,*ldb); - - //post transpose - //bli_dtranspose(temp,aptr,*lda,*cols); - bli_free_user(temp); - } - else if ( *trans == 'r' || *trans == 'R') - { - bli_diMatCopy_cn(*rows,*cols,*alpha,aptr,*lda,*ldb); - } - else - { - // do nothing - } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return ; + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_2); + dim_t i,j; + double* s_aptr; + double* d_aptr; + + if ( rows <= 0 || cols <= 0 || a == NULL || lda < rows || ldb < rows ) + { + fprintf( stderr, " Invalid trans setting bli_diMatcopy_cn() %ld %ld %ld %ld \n", + ( long )rows, ( long )cols, ( long )lda, ( long )ldb); + bli_print_msg( " Invalid function parameters bli_diMatCopy_cn() .", __FILE__, __LINE__ ); + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_2, "Invalid function parameters"); + return ( 0 ); + } + + if ( lda == ldb && alpha == 1.0) + return ( 0 ); + + s_aptr = a; + d_aptr = a; + + if ( lda >= ldb ) + { + for ( i = 0; i < cols ; i++ ) + { + for ( j = 0; j < rows; j++ ) + { + d_aptr[j] = alpha * s_aptr[j]; + } + s_aptr += lda; + d_aptr += ldb; + } + } + else + { + // Acquring memory for auxiliary buffer(in case lda < ldb). This is + // needed in order to avoid overwriting subsequent reads from the input. + // This extra buffer is allocated exactly the amount of memory that + // is needed to store the required elements from input(rows x cols) + err_t r_val; + double* buf = (double *) bli_malloc_user((rows)*(cols)*sizeof(double), &r_val); + + if( buf != NULL ) + { + // Loading from input, storing onto auxiliary buffer + double *d_acopy = buf; + for ( i = 0; i < cols ; i++ ) + { + for ( j = 0; j < rows; j++ ) + { + d_acopy[j] = alpha * s_aptr[j]; + } + s_aptr += lda; + d_acopy += rows; + } + + // Loading from auxiliary buffer, storing onto output + d_acopy = buf; + for ( i = 0; i < cols ; i++ ) + { + for ( j = 0; j < rows; j++ ) + { + d_aptr[j] = d_acopy[j]; + } + d_acopy += rows; + d_aptr += ldb; + } + + bli_free_user(buf); + } + else + { + f77_int mem_fail_info = -10; + xerbla_(MKSTR(bli_diMatCopy_cn), &mem_fail_info, (f77_int)16); + } + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2); + return ( 0 ); } -void cimatcopy_ (f77_char* trans, f77_int* rows, f77_int* cols, const scomplex* alpha,scomplex* aptr, f77_int* lda, f77_int* ldb) +// suffix cn means - column major & non-trans +static dim_t bli_ciMatCopy_cn(dim_t rows,dim_t cols,const scomplex alpha,scomplex* a,dim_t lda,dim_t ldb) { - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); - //bli_init_once(); - if ( !(*trans == 'n' || *trans == 'N' || - *trans == 't' || *trans == 'T' || - *trans == 'c' || *trans == 'C' || - *trans == 'r' || *trans == 'R')) - { - bli_print_msg( " Invalid trans setting cimatcopy_() .", __FILE__, __LINE__ ); - AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, "Invalid value for trans parameter"); - return ; - } - - if ( *rows <= 0 || *cols <= 0 || alpha == NULL || aptr == NULL || *lda < 1 || *ldb < 1) - { - bli_print_msg( " Invalid function parameters cimatcopy_() .", __FILE__, __LINE__ ); - AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, "Invalid function parameters"); - return ; - } - - if ( *trans == 'n' || *trans == 'N') - { - bli_ciMatCopy_cn(*rows,*cols,*alpha,aptr,*lda,*ldb); - } - else if ( *trans == 't' || *trans == 'T') - { - //pre transpose - err_t r_val; - scomplex* temp = (scomplex* ) bli_malloc_user((*rows)*(*lda)*sizeof(scomplex), &r_val); - bli_ctranspose(aptr,temp,*lda,*rows); - - //bli_ciMatCopy_cn(*cols,*rows,*alpha,temp,*lda,*ldb); - for (dim_t i = 0; i < *cols; i++) - memcpy(&aptr[i*(*lda)],&temp[i*(*lda)],(*rows)*sizeof(scomplex)); - bli_ciMatCopy_cn(*cols,*rows,*alpha,aptr,*lda,*ldb); - - //post transpose - //bli_ctranspose(temp,aptr,*lda,*cols); - bli_free_user(temp); - } - else if ( *trans == 'c' || *trans == 'C') - { - - //pre transpose - err_t r_val; - scomplex* temp = (scomplex* ) bli_malloc_user((*rows)*(*lda)*sizeof(scomplex), &r_val); - bli_ctranspose(aptr,temp,*lda,*rows); - - //bli_ciMatCopy_cr(*cols,*rows,*alpha,temp,*lda,*ldb); - for (dim_t i = 0; i < *cols; i++) - memcpy(&aptr[i*(*lda)],&temp[i*(*lda)],(*rows)*sizeof(scomplex)); - bli_ciMatCopy_cr(*cols,*rows,*alpha,aptr,*lda,*ldb); - - //post transpose - //bli_ctranspose(temp,aptr,*lda,*cols); - bli_free_user(temp); - } - else if ( *trans == 'r' || *trans == 'R') - { - bli_ciMatCopy_cr(*rows,*cols,*alpha,aptr,*lda,*ldb); - } - else - { - // do nothing - } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return ; + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_2); + dim_t i,j; + scomplex* s_aptr; + scomplex* d_aptr; + + if ( rows <= 0 || cols <= 0 || a == NULL || lda < rows || ldb < rows ) + { + fprintf( stderr, " Invalid trans setting bli_ciMatCopy_cn() %ld %ld %ld %ld \n", + ( long )rows, ( long )cols, ( long )lda, ( long )ldb); + bli_print_msg( " Invalid function parameters bli_ciMatCopy_cn() .", __FILE__, __LINE__ ); + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_2, "Invalid function parameters"); + return ( 0 ); + } + + if ( lda == ldb && alpha.real == 1.0 && alpha.imag == 0.0 ) + return ( 0 ); + + s_aptr = a; + d_aptr = a; + + if ( lda >= ldb ) + { + for ( i = 0; i < cols ; i++ ) + { + for ( j = 0; j < rows; j++ ) + { + scomplex temp = s_aptr[j]; + d_aptr[j].real = ( ( alpha.real * temp.real ) - ( alpha.imag * temp.imag ) ); + d_aptr[j].imag = ( ( alpha.real * temp.imag ) + ( alpha.imag * temp.real ) ); + } + s_aptr += lda; + d_aptr += ldb; + } + } + else + { + // Acquring memory for auxiliary buffer(in case lda < ldb). This is + // needed in order to avoid overwriting subsequent reads from the input. + // This extra buffer is allocated exactly the amount of memory that + // is needed to store the required elements from input(rows x cols) + err_t r_val; + scomplex* buf = (scomplex *) bli_malloc_user((rows)*(cols)*sizeof(scomplex), &r_val); + + if( buf != NULL ) + { + // Loading from input, storing onto auxiliary buffer + scomplex *d_acopy = buf; + for ( i = 0; i < cols ; i++ ) + { + for ( j = 0; j < rows; j++ ) + { + scomplex temp = s_aptr[j]; + d_acopy[j].real = ( ( alpha.real * temp.real ) - ( alpha.imag * temp.imag ) ); + d_acopy[j].imag = ( ( alpha.real * temp.imag ) + ( alpha.imag * temp.real ) ); + } + s_aptr += lda; + d_acopy += rows; + } + + // Loading from auxiliary buffer, storing onto output + d_acopy = buf; + for ( i = 0; i < cols ; i++ ) + { + for ( j = 0; j < rows; j++ ) + { + d_aptr[j] = d_acopy[j]; + } + d_acopy += rows; + d_aptr += ldb; + } + + bli_free_user(buf); + } + else + { + f77_int mem_fail_info = -10; + xerbla_(MKSTR(bli_ciMatCopy_cn), &mem_fail_info, (f77_int)16); + } + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2); + return ( 0 ); } -void zimatcopy_ (f77_char* trans, f77_int* rows, f77_int* cols, const dcomplex* alpha,dcomplex* aptr, f77_int* lda, f77_int* ldb) +// suffix cn means - column major & non-trans +static dim_t bli_ziMatCopy_cn(dim_t rows,dim_t cols,const dcomplex alpha,dcomplex* a,dim_t lda,dim_t ldb) { - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); - //bli_init_once(); - if ( !(*trans == 'n' || *trans == 'N' || - *trans == 't' || *trans == 'T' || - *trans == 'c' || *trans == 'C' || - *trans == 'r' || *trans == 'R')) - { - bli_print_msg( " Invalid trans setting zimatcopy_() .", __FILE__, __LINE__ ); - AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, "Invalid value for trans parameter"); - return ; - } - - if ( *rows <= 0 || *cols <= 0 || alpha == NULL || aptr == NULL || *lda < 1 || *ldb < 1) - { - bli_print_msg( " Invalid function parameters dimatcopy_() .", __FILE__, __LINE__ ); - AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, "Invalid function parameters"); - return ; - } - - if ( *trans == 'n' || *trans == 'N') - { - bli_ziMatCopy_cn(*rows,*cols,*alpha,aptr,*lda,*ldb); - } - else if ( *trans == 't' || *trans == 'T') - { - - //pre transpose - err_t r_val; - dcomplex* temp = (dcomplex *) bli_malloc_user((*rows)*(*lda)*sizeof(dcomplex), &r_val); - bli_ztranspose(aptr,temp,*lda,*rows); - - //bli_ziMatCopy_cn(*cols,*rows,*alpha,temp,*lda,*ldb); - for (dim_t i = 0; i < *cols; i++) - memcpy(&aptr[i*(*lda)],&temp[i*(*lda)],(*rows)*sizeof(dcomplex)); - bli_ziMatCopy_cn(*cols,*rows,*alpha,aptr,*lda,*ldb); - - //post transpose - //bli_ztranspose(temp,aptr,*lda,*cols); - bli_free_user(temp); - } - else if ( *trans == 'c' || *trans == 'C') - { - //pre transpose - err_t r_val; - dcomplex* temp = (dcomplex *) bli_malloc_user((*rows)*(*lda)*sizeof(dcomplex), &r_val); - bli_ztranspose(aptr,temp,*lda,*rows); - - //bli_ziMatCopy_cr(*cols,*rows,*alpha,temp,*lda,*ldb); - for (dim_t i = 0; i < *cols; i++) - memcpy(&aptr[i*(*lda)],&temp[i*(*lda)],(*rows)*sizeof(scomplex)); - bli_ziMatCopy_cr(*cols,*rows,*alpha,aptr,*lda,*ldb); - - //post transpose - //bli_ztranspose(temp,aptr,*lda,*cols); - bli_free_user(temp); - } - else if ( *trans == 'r' || *trans == 'R') - { - bli_ziMatCopy_cr(*rows,*cols,*alpha,aptr,*lda,*ldb); - } - else - { - // do nothing - } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return ; + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_2); + dim_t i,j; + dcomplex* s_aptr; + dcomplex* d_aptr; + + if ( rows <= 0 || cols <= 0 || a == NULL || lda < rows || ldb < rows ) + { + fprintf( stderr, " Invalid trans setting bli_ziMatCopy_cn() %ld %ld %ld %ld \n", + ( long )rows, ( long )cols, ( long )lda, ( long )ldb); + bli_print_msg( " Invalid function parameters bli_ziMatCopy_cn() .", __FILE__, __LINE__ ); + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_2, "Invalid function parameters"); + return ( 0 ); + } + + if ( lda == ldb && alpha.real == 1.0 && alpha.imag == 0.0 ) + return ( 0 ); + + s_aptr = a; + d_aptr = a; + + if ( lda >= ldb ) + { + for ( i = 0; i < cols ; i++ ) + { + for ( j = 0; j < rows; j++ ) + { + dcomplex temp = s_aptr[j]; + d_aptr[j].real = ( ( alpha.real * temp.real ) - ( alpha.imag * temp.imag ) ); + d_aptr[j].imag = ( ( alpha.real * temp.imag ) + ( alpha.imag * temp.real ) ); + } + s_aptr += lda; + d_aptr += ldb; + } + } + else + { + // Acquring memory for auxiliary buffer(in case lda < ldb). This is + // needed in order to avoid overwriting subsequent reads from the input. + // This extra buffer is allocated exactly the amount of memory that + // is needed to store the required elements from input(rows x cols) + err_t r_val; + dcomplex* buf = (dcomplex *) bli_malloc_user((rows)*(cols)*sizeof(dcomplex), &r_val); + + if( buf != NULL ) + { + // Loading from input, storing onto auxiliary buffer + dcomplex *d_acopy = buf; + for ( i = 0; i < cols ; i++ ) + { + for ( j = 0; j < rows; j++ ) + { + dcomplex temp = s_aptr[j]; + d_acopy[j].real = ( ( alpha.real * temp.real ) - ( alpha.imag * temp.imag ) ); + d_acopy[j].imag = ( ( alpha.real * temp.imag ) + ( alpha.imag * temp.real ) ); + } + s_aptr += lda; + d_acopy += rows; + } + + // Loading from auxiliary buffer, storing onto output + d_acopy = buf; + for ( i = 0; i < cols ; i++ ) + { + for ( j = 0; j < rows; j++ ) + { + d_aptr[j] = d_acopy[j]; + } + d_acopy += rows; + d_aptr += ldb; + } + + bli_free_user(buf); + } + else + { + f77_int mem_fail_info = -10; + xerbla_(MKSTR(bli_ziMatCopy_cn), &mem_fail_info, (f77_int)16); + } + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2); + return ( 0 ); } -// suffix cn means - column major & non-trans -static dim_t bli_siMatCopy_cn(dim_t rows,dim_t cols,const float alpha,float* a,dim_t lda, dim_t ldb) +// suffix ct means - column major & trans +static dim_t bli_siMatCopy_ct(dim_t rows,dim_t cols,const float alpha,float* a,dim_t lda,dim_t ldb) { - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_2); - dim_t i,j; - - float* s_aptr; - float* d_aptr; - - if ( rows <= 0 || cols <= 0 || a == NULL || lda < cols || ldb < cols) - { - fprintf( stderr, " Invalid trans setting bli_siMatCopy_cn() %ld %ld %ld %ld \n", - ( long )rows, ( long )cols, ( long )lda, ( long )ldb); - bli_print_msg( " Invalid function parameters bli_siMatCopy_cn() .", __FILE__, __LINE__ ); - AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_2, "Invalid function parameters"); - return (0); - } - - if ( lda == ldb && alpha == 1.0) - return (0); - - s_aptr = a; - d_aptr = a; - if ( alpha == 0.0 ) - { - for ( i=0; i= ldb ) + { + for ( i = 0; i < cols ; i++ ) + { + for ( j = 0; j < rows; j++ ) + { + scomplex temp = s_aptr[j]; + d_aptr[j].real = ( ( alpha.imag * temp.imag ) + ( alpha.real * temp.real ) ); + d_aptr[j].imag = ( ( alpha.imag * temp.real ) - ( alpha.real * temp.imag ) ); + } + s_aptr += lda; + d_aptr += ldb; + } + } + else + { + // Acquring memory for auxiliary buffer(in case lda < ldb). This is + // needed in order to avoid overwriting subsequent reads from the input. + // This extra buffer is allocated exactly the amount of memory that + // is needed to store the required elements from input(cols x rows) + err_t r_val; + scomplex* buf = (scomplex *) bli_malloc_user((rows)*(cols)*sizeof(scomplex), &r_val); + + if( buf != NULL ) + { + // Loading from input, storing onto auxiliary buffer + scomplex *d_acopy = buf; + for ( i = 0; i < cols ; i++ ) + { + for ( j = 0; j < rows; j++ ) + { + scomplex temp = s_aptr[j]; + d_acopy[j].real = ( ( alpha.imag * temp.imag ) + ( alpha.real * temp.real ) ); + d_acopy[j].imag = ( ( alpha.imag * temp.real ) - ( alpha.real * temp.imag ) ); + } + s_aptr += lda; + d_acopy += rows; + } + + // Loading from auxiliary buffer, storing onto output + d_acopy = buf; + for ( i = 0; i < cols ; i++ ) + { + for ( j = 0; j < rows; j++ ) + { + d_aptr[j] = d_acopy[j]; + } + d_acopy += rows; + d_aptr += ldb; + } + + bli_free_user(buf); + } + else + { + f77_int mem_fail_info = -10; + xerbla_(MKSTR(bli_ciMatCopy_cr), &mem_fail_info, (f77_int)16); + } + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2); + return ( 0 ); } // suffix cr means - column major & conjugate static dim_t bli_ziMatCopy_cr(dim_t rows,dim_t cols,const dcomplex alpha,dcomplex* a,dim_t lda, dim_t ldb) { - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_2); - dim_t i,j; - dcomplex* s_aptr; - dcomplex* d_aptr; - - if ( rows <= 0 || cols <= 0 || a == NULL || lda < cols || ldb < cols) - { - fprintf( stderr, " Invalid trans setting bli_ziMatCopy_cr() %ld %ld %ld %ld \n", - ( long )rows, ( long )cols, ( long )lda, ( long )ldb); - bli_print_msg( " Invalid function parameters bli_ziMatCopy_cr() .", __FILE__, __LINE__ ); - AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_2, "Invalid function parameters"); - return (0); - } - s_aptr = a; - d_aptr = a; - if ( alpha.real == 0.0 && alpha.imag == 0.0 ) - { - for ( i=0; i= ldb ) + { + for ( i = 0; i < cols ; i++ ) + { + for ( j = 0; j < rows; j++ ) + { + dcomplex temp = s_aptr[j]; + d_aptr[j].real = ( ( alpha.imag * temp.imag ) + ( alpha.real * temp.real ) ); + d_aptr[j].imag = ( ( alpha.imag * temp.real ) - ( alpha.real * temp.imag ) ); + } + s_aptr += lda; + d_aptr += ldb; + } + } + else + { + // Acquring memory for auxiliary buffer(in case lda < ldb). This is + // needed in order to avoid overwriting subsequent reads from the input. + // This extra buffer is allocated exactly the amount of memory that + // is needed to store the required elements from input(cols x rows) + err_t r_val; + dcomplex* buf = (dcomplex *) bli_malloc_user((rows)*(cols)*sizeof(dcomplex), &r_val); + + if( buf != NULL ) + { + // Loading from input, storing onto auxiliary buffer + dcomplex *d_acopy = buf; + for ( i = 0; i < cols ; i++ ) + { + for ( j = 0; j < rows; j++ ) + { + dcomplex temp = s_aptr[j]; + d_acopy[j].real = ( ( alpha.imag * temp.imag ) + ( alpha.real * temp.real ) ); + d_acopy[j].imag = ( ( alpha.imag * temp.real ) - ( alpha.real * temp.imag ) ); + } + s_aptr += lda; + d_acopy += rows; + } + + // Loading from auxiliary buffer, storing onto output + d_acopy = buf; + for ( i = 0; i < cols ; i++ ) + { + for ( j = 0; j < rows; j++ ) + { + d_aptr[j] = d_acopy[j]; + } + d_acopy += rows; + d_aptr += ldb; + } + + bli_free_user(buf); + } + else + { + f77_int mem_fail_info = -10; + xerbla_(MKSTR(bli_ziMatCopy_cr), &mem_fail_info, (f77_int)16); + } + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2); + return ( 0 ); +} + +// suffix cc means - column major & conjugate trans +static dim_t bli_ciMatCopy_cc(dim_t rows,dim_t cols,const scomplex alpha,scomplex* a,dim_t lda,dim_t ldb) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_2); + dim_t i,j; + + scomplex* s_aptr; + scomplex* d_aptr; + + if ( rows <= 0 || cols <= 0 || a == NULL || lda < rows || ldb < cols ) + { + fprintf( stderr, " Invalid trans setting bli_ciMatCopy_ct() %ld %ld %ld %ld \n", + ( long )rows, ( long )cols, ( long )lda, ( long )ldb); + bli_print_msg( " Invalid function parameters bli_ciMatCopy_ct() .", __FILE__, __LINE__ ); + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_2, "Invalid function parameters"); + return ( 0 ); + } + + s_aptr = a; + d_aptr = a; + + // Acquring memory for auxiliary buffer(in case lda < ldb). This is + // needed in order to avoid overwriting subsequent reads from the input. + // This extra buffer is allocated exactly the amount of memory that + // is needed to store the required elements from input(cols x rows) + err_t r_val; + scomplex* buf = (scomplex *) bli_malloc_user((cols)*(rows)*sizeof(scomplex), &r_val); + + if( buf != NULL ) + { + // Loading from input, storing onto auxiliary buffer + scomplex *d_acopy = buf; + for ( i = 0; i < cols ; i++ ) + { + for ( j = 0; j < rows; j++ ) + { + scomplex temp = s_aptr[j]; + d_acopy[j * cols].real = ( ( alpha.imag * temp.imag ) + ( alpha.real * temp.real ) ); + d_acopy[j * cols].imag = ( ( alpha.imag * temp.real ) - ( alpha.real * temp.imag ) ); + } + s_aptr += lda; + d_acopy += 1; + } + + // Loading from auxiliary buffer, storing onto output + d_acopy = buf; + for ( j = 0; j < rows; j++ ) + { + for ( i = 0; i < cols; i++ ) + { + d_aptr[i] = d_acopy[i]; + } + d_acopy += cols; + d_aptr += ldb; + } + + bli_free_user(buf); + } + else + { + f77_int mem_fail_info = -10; + xerbla_(MKSTR(bli_ciMatCopy_ct), &mem_fail_info, (f77_int)16); + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2); + return ( 0 ); +} + +// suffix cc means - column major & conjugate trans +static dim_t bli_ziMatCopy_cc(dim_t rows,dim_t cols,const dcomplex alpha,dcomplex* a,dim_t lda,dim_t ldb) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_2); + dim_t i,j; + + dcomplex* s_aptr; + dcomplex* d_aptr; + + if ( rows <= 0 || cols <= 0 || a == NULL || lda < rows || ldb < cols ) + { + fprintf( stderr, " Invalid trans setting bli_ziMatCopy_ct() %ld %ld %ld %ld \n", + ( long )rows, ( long )cols, ( long )lda, ( long )ldb); + bli_print_msg( " Invalid function parameters bli_ziMatCopy_ct() .", __FILE__, __LINE__ ); + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_2, "Invalid function parameters"); + return ( 0 ); + } + + s_aptr = a; + d_aptr = a; + + // Acquring memory for auxiliary buffer(in case lda < ldb). This is + // needed in order to avoid overwriting subsequent reads from the input. + // This extra buffer is allocated exactly the amount of memory that + // is needed to store the required elements from input(cols x rows) + err_t r_val; + dcomplex* buf = (dcomplex *) bli_malloc_user((cols)*(rows)*sizeof(dcomplex), &r_val); + + if( buf != NULL ) + { + // Loading from input, storing onto auxiliary buffer + dcomplex *d_acopy = buf; + for ( i = 0; i < cols ; i++ ) + { + for ( j = 0; j < rows; j++ ) + { + dcomplex temp = s_aptr[j]; + d_acopy[j * cols].real = ( ( alpha.imag * temp.imag ) + ( alpha.real * temp.real ) ); + d_acopy[j * cols].imag = ( ( alpha.imag * temp.real ) - ( alpha.real * temp.imag ) ); + } + s_aptr += lda; + d_acopy += 1; + } + + // Loading from auxiliary buffer, storing onto output + d_acopy = buf; + for ( j = 0; j < rows; j++ ) + { + for ( i = 0; i < cols; i++ ) + { + d_aptr[i] = d_acopy[i]; + } + d_acopy += cols; + d_aptr += ldb; + } + + bli_free_user(buf); + } + else + { + f77_int mem_fail_info = -10; + xerbla_(MKSTR(bli_ziMatCopy_ct), &mem_fail_info, (f77_int)16); + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2); + return ( 0 ); } #endif diff --git a/frame/compat/bla_nrm2.c b/frame/compat/bla_nrm2.c old mode 100755 new mode 100644 diff --git a/frame/compat/bla_nrm2.h b/frame/compat/bla_nrm2.h index c4e9ec8b4d..c3922ca002 100644 --- a/frame/compat/bla_nrm2.h +++ b/frame/compat/bla_nrm2.h @@ -5,8 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. - + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/frame/compat/bla_omatcopy.c b/frame/compat/bla_omatcopy.c index 80a9650565..18f1c29d5d 100644 --- a/frame/compat/bla_omatcopy.c +++ b/frame/compat/bla_omatcopy.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -36,902 +36,911 @@ #ifdef BLIS_ENABLE_BLAS -static dim_t bli_soMatCopy_cn(dim_t rows,dim_t cols,const float alpha,const float* a,dim_t lda,float* b,dim_t ldb); +static dim_t bli_soMatCopy_cn(dim_t rows, dim_t cols, const float alpha, const float* a, dim_t lda, float* b, dim_t ldb); -static dim_t bli_soMatCopy_ct(dim_t rows,dim_t cols,const float alpha,const float* a,dim_t lda,float* b,dim_t ldb); +static dim_t bli_soMatCopy_ct(dim_t rows, dim_t cols, const float alpha, const float* a, dim_t lda, float* b, dim_t ldb); -static dim_t bli_doMatCopy_cn(dim_t rows,dim_t cols,const double alpha,const double* a,dim_t lda,double* b,dim_t ldb); +static dim_t bli_doMatCopy_cn(dim_t rows, dim_t cols, const double alpha, const double* a, dim_t lda, double* b, dim_t ldb); -static dim_t bli_doMatCopy_ct(dim_t rows,dim_t cols,const double alpha,const double* a,dim_t lda,double* b,dim_t ldb); +static dim_t bli_doMatCopy_ct(dim_t rows, dim_t cols, const double alpha, const double* a, dim_t lda, double* b, dim_t ldb); -static dim_t bli_coMatCopy_cn(dim_t rows,dim_t cols,const scomplex alpha,const scomplex* a,dim_t lda,scomplex* b,dim_t ldb); +static dim_t bli_coMatCopy_cn(dim_t rows, dim_t cols, const scomplex alpha, const scomplex* a, dim_t lda, scomplex* b, dim_t ldb); -static dim_t bli_coMatCopy_ct(dim_t rows,dim_t cols,const scomplex alpha,const scomplex* a,dim_t lda,scomplex* b,dim_t ldb); +static dim_t bli_coMatCopy_ct(dim_t rows, dim_t cols, const scomplex alpha, const scomplex* a, dim_t lda, scomplex* b, dim_t ldb); -static dim_t bli_coMatCopy_cr(dim_t rows,dim_t cols,const scomplex alpha,const scomplex* a,dim_t lda,scomplex* b,dim_t ldb); +static dim_t bli_coMatCopy_cr(dim_t rows, dim_t cols, const scomplex alpha, const scomplex* a, dim_t lda, scomplex* b, dim_t ldb); -static dim_t bli_coMatCopy_cc(dim_t rows,dim_t cols,const scomplex alpha,const scomplex* a,dim_t lda,scomplex* b,dim_t ldb); +static dim_t bli_coMatCopy_cc(dim_t rows, dim_t cols, const scomplex alpha, const scomplex* a, dim_t lda, scomplex* b, dim_t ldb); -static dim_t bli_zoMatCopy_cn(dim_t rows,dim_t cols,const dcomplex alpha,const dcomplex* a,dim_t lda,dcomplex* b,dim_t ldb); +static dim_t bli_zoMatCopy_cn(dim_t rows, dim_t cols, const dcomplex alpha, const dcomplex* a, dim_t lda, dcomplex* b, dim_t ldb); -static dim_t bli_zoMatCopy_ct(dim_t rows,dim_t cols,const dcomplex alpha,const dcomplex* a,dim_t lda,dcomplex* b,dim_t ldb); +static dim_t bli_zoMatCopy_ct(dim_t rows, dim_t cols, const dcomplex alpha, const dcomplex* a, dim_t lda, dcomplex* b, dim_t ldb); -static dim_t bli_zoMatCopy_cr(dim_t rows,dim_t cols,const dcomplex alpha,const dcomplex* a,dim_t lda,dcomplex* b,dim_t ldb); +static dim_t bli_zoMatCopy_cr(dim_t rows, dim_t cols, const dcomplex alpha, const dcomplex* a, dim_t lda, dcomplex* b, dim_t ldb); -static dim_t bli_zoMatCopy_cc(dim_t rows,dim_t cols,const dcomplex alpha,const dcomplex* a,dim_t lda,dcomplex* b,dim_t ldb); +static dim_t bli_zoMatCopy_cc(dim_t rows, dim_t cols, const dcomplex alpha, const dcomplex* a, dim_t lda, dcomplex* b, dim_t ldb); void somatcopy_ (f77_char* trans, f77_int* rows, f77_int* cols, const float* alpha, const float* aptr, f77_int* lda, float* bptr, f77_int* ldb) { - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); - if ( !(*trans == 'n' || *trans == 'N' || - *trans == 't' || *trans == 'T' || - *trans == 'c' || *trans == 'C' || - *trans == 'r' || *trans == 'R')) - { - bli_print_msg( " Invalid value of trans parameter in somatcopy_() .", __FILE__, __LINE__ ); - AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, "Invalid value for trans parameter"); - return ; - } - if ( *rows <= 0 || *cols <= 0 || alpha == NULL || aptr == NULL || bptr == NULL || *lda < 1 || *ldb < 1 ) - { - bli_print_msg( " Invalid function parameter in somatcopy_() .", __FILE__, __LINE__ ); - AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, "Invalid function parameters"); - return ; - } - if ( *trans == 'n' || *trans == 'N') - { - bli_soMatCopy_cn(*rows,*cols,*alpha,aptr,*lda,bptr,*ldb); - } - else if ( *trans == 't' || *trans == 'T') - { - bli_soMatCopy_ct(*rows,*cols,*alpha,aptr,*lda,bptr,*ldb); - } - else if ( *trans == 'c' || *trans == 'C') - { - bli_soMatCopy_ct(*rows,*cols,*alpha,aptr,*lda,bptr,*ldb); - } - else if ( *trans == 'r' || *trans == 'R') - { - bli_soMatCopy_cn(*rows,*cols,*alpha,aptr,*lda,bptr,*ldb); - } - else - { - // do nothing - } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return ; + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); + if ( !(*trans == 'n' || *trans == 'N' || + *trans == 't' || *trans == 'T' || + *trans == 'c' || *trans == 'C' || + *trans == 'r' || *trans == 'R')) + { + bli_print_msg( " Invalid value of trans parameter in somatcopy_() .", __FILE__, __LINE__ ); + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, "Invalid value for trans parameter"); + return ; + } + if ( *rows <= 0 || *cols <= 0 || alpha == NULL || aptr == NULL || bptr == NULL || *lda < 1 || *ldb < 1 ) + { + bli_print_msg( " Invalid function parameter in somatcopy_() .", __FILE__, __LINE__ ); + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, "Invalid function parameters"); + return ; + } + if ( *trans == 'n' || *trans == 'N') + { + bli_soMatCopy_cn(*rows, *cols, *alpha, aptr, *lda, bptr, *ldb); + } + else if ( *trans == 't' || *trans == 'T') + { + bli_soMatCopy_ct(*rows,*cols,*alpha,aptr,*lda,bptr,*ldb); + } + else if ( *trans == 'c' || *trans == 'C') + { + bli_soMatCopy_ct(*rows,*cols,*alpha,aptr,*lda,bptr,*ldb); + } + else if ( *trans == 'r' || *trans == 'R') + { + bli_soMatCopy_cn(*rows,*cols,*alpha,aptr,*lda,bptr,*ldb); + } + else + { + // do nothing + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return ; } void domatcopy_ (f77_char* trans, f77_int* rows, f77_int* cols, const double* alpha, const double* aptr, f77_int* lda, double* bptr, f77_int* ldb) { - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); - //bli_init_once(); - if ( !(*trans == 'n' || *trans == 'N' || - *trans == 't' || *trans == 'T' || - *trans == 'c' || *trans == 'C' || - *trans == 'r' || *trans == 'R')) - { - bli_print_msg( " Invalid value of trans parameter in domatcopy_() .", __FILE__, __LINE__ ); - AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, "Invalid value for trans parameter"); - return ; - } - if ( *rows <= 0 || *cols <= 0 || alpha == NULL || aptr == NULL || bptr == NULL || *lda < 1 || *ldb < 1 ) - { - bli_print_msg( " Invalid function parameter in domatcopy_() .", __FILE__, __LINE__ ); - AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, "Invalid function parameters"); - return ; - } - if ( *trans == 'n' || *trans == 'N') - { - bli_doMatCopy_cn(*rows,*cols,*alpha,aptr,*lda,bptr,*ldb); - } - else if ( *trans == 't' || *trans == 'T') - { - bli_doMatCopy_ct(*rows,*cols,*alpha,aptr,*lda,bptr,*ldb); - } - else if ( *trans == 'c' || *trans == 'C') - { - bli_doMatCopy_ct(*rows,*cols,*alpha,aptr,*lda,bptr,*ldb); - } - else if ( *trans == 'r' || *trans == 'R') - { - bli_doMatCopy_cn(*rows,*cols,*alpha,aptr,*lda,bptr,*ldb); - } - else - { - // do nothing - } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return ; + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); + //bli_init_once(); + if ( !(*trans == 'n' || *trans == 'N' || + *trans == 't' || *trans == 'T' || + *trans == 'c' || *trans == 'C' || + *trans == 'r' || *trans == 'R')) + { + bli_print_msg( " Invalid value of trans parameter in domatcopy_() .", __FILE__, __LINE__ ); + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, "Invalid value for trans parameter"); + return ; + } + if ( *rows <= 0 || *cols <= 0 || alpha == NULL || aptr == NULL || bptr == NULL || *lda < 1 || *ldb < 1 ) + { + bli_print_msg( " Invalid function parameter in domatcopy_() .", __FILE__, __LINE__ ); + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, "Invalid function parameters"); + return ; + } + if ( *trans == 'n' || *trans == 'N') + { + bli_doMatCopy_cn(*rows,*cols,*alpha,aptr,*lda,bptr,*ldb); + } + else if ( *trans == 't' || *trans == 'T') + { + bli_doMatCopy_ct(*rows,*cols,*alpha,aptr,*lda,bptr,*ldb); + } + else if ( *trans == 'c' || *trans == 'C') + { + bli_doMatCopy_ct(*rows,*cols,*alpha,aptr,*lda,bptr,*ldb); + } + else if ( *trans == 'r' || *trans == 'R') + { + bli_doMatCopy_cn(*rows,*cols,*alpha,aptr,*lda,bptr,*ldb); + } + else + { + // do nothing + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return ; } void comatcopy_ (f77_char* trans, f77_int* rows, f77_int* cols, const scomplex* alpha, const scomplex* aptr, f77_int* lda, scomplex* bptr, f77_int* ldb) { - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); - //bli_init_once(); - if ( !(*trans == 'n' || *trans == 'N' || - *trans == 't' || *trans == 'T' || - *trans == 'c' || *trans == 'C' || - *trans == 'r' || *trans == 'R')) - { - bli_print_msg( " Invalid value of trans parameter in comatcopy_() .", __FILE__, __LINE__ ); - AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, "Invalid value for trans parameter"); - return ; - } - if ( *rows <= 0 || *cols <= 0 || alpha == NULL || aptr == NULL || bptr == NULL || *lda < 1 || *ldb < 1 ) - { - bli_print_msg( " Invalid function parameter in comatcopy_() .", __FILE__, __LINE__ ); - AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, "Invalid function parameters"); - return ; - } - if ( *trans == 'n' || *trans == 'N') - { - bli_coMatCopy_cn(*rows,*cols,*alpha,aptr,*lda,bptr,*ldb); - } - else if ( *trans == 't' || *trans == 'T') - { - bli_coMatCopy_ct(*rows,*cols,*alpha,aptr,*lda,bptr,*ldb); - } - else if ( *trans == 'c' || *trans == 'C') - { - bli_coMatCopy_cc(*rows,*cols,*alpha,aptr,*lda,bptr,*ldb); - } - else if ( *trans == 'r' || *trans == 'R') - { - bli_coMatCopy_cr(*rows,*cols,*alpha,aptr,*lda,bptr,*ldb); - } - else - { - // do nothing - } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return ; + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); + //bli_init_once(); + if ( !(*trans == 'n' || *trans == 'N' || + *trans == 't' || *trans == 'T' || + *trans == 'c' || *trans == 'C' || + *trans == 'r' || *trans == 'R')) + { + bli_print_msg( " Invalid value of trans parameter in comatcopy_() .", __FILE__, __LINE__ ); + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, "Invalid value for trans parameter"); + return ; + } + if ( *rows <= 0 || *cols <= 0 || alpha == NULL || aptr == NULL || bptr == NULL || + (*lda < 1 ) || (*ldb < 1 ) ) + { + bli_print_msg( " Invalid function parameter in comatcopy_() .", __FILE__, __LINE__ ); + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, "Invalid function parameters"); + return ; + } + if ( *trans == 'n' || *trans == 'N') + { + bli_coMatCopy_cn(*rows,*cols,*alpha,aptr,*lda,bptr,*ldb); + } + else if ( *trans == 't' || *trans == 'T') + { + bli_coMatCopy_ct(*rows,*cols,*alpha,aptr,*lda,bptr,*ldb); + } + else if ( *trans == 'c' || *trans == 'C') + { + bli_coMatCopy_cc(*rows,*cols,*alpha,aptr,*lda,bptr,*ldb); + } + else if ( *trans == 'r' || *trans == 'R') + { + bli_coMatCopy_cr(*rows,*cols,*alpha,aptr,*lda,bptr,*ldb); + } + else + { + // do nothing + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return ; } void zomatcopy_ (f77_char* trans, f77_int* rows, f77_int* cols, const dcomplex* alpha, const dcomplex* aptr, f77_int* lda, dcomplex* bptr, f77_int* ldb) { - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); - //bli_init_once(); - if ( !(*trans == 'n' || *trans == 'N' || - *trans == 't' || *trans == 'T' || - *trans == 'c' || *trans == 'C' || - *trans == 'r' || *trans == 'R')) - { - bli_print_msg( " Invalid value of trans parameter in zomatcopy_() .", __FILE__, __LINE__ ); - AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, "Invalid value for trans parameter"); - return ; - } - if ( *rows <= 0 || *cols <= 0 || alpha == NULL || aptr == NULL || bptr == NULL || *lda < 1 || *ldb < 1 ) - { - bli_print_msg( " Invalid function parameter in zomatcopy_() .", __FILE__, __LINE__ ); - AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, "Invalid function parameters"); - return ; - } - if ( *trans == 'n' || *trans == 'N') - { - bli_zoMatCopy_cn(*rows,*cols,*alpha,aptr,*lda,bptr,*ldb); - } - else if ( *trans == 't' || *trans == 'T') - { - bli_zoMatCopy_ct(*rows,*cols,*alpha,aptr,*lda,bptr,*ldb); - } - else if ( *trans == 'c' || *trans == 'C') - { - bli_zoMatCopy_cc(*rows,*cols,*alpha,aptr,*lda,bptr,*ldb); - } - else if ( *trans == 'r' || *trans == 'R') - { - bli_zoMatCopy_cr(*rows,*cols,*alpha,aptr,*lda,bptr,*ldb); - } - else - { - // do nothing - } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return ; + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); + //bli_init_once(); + if ( !(*trans == 'n' || *trans == 'N' || + *trans == 't' || *trans == 'T' || + *trans == 'c' || *trans == 'C' || + *trans == 'r' || *trans == 'R')) + { + bli_print_msg( " Invalid value of trans parameter in zomatcopy_() .", __FILE__, __LINE__ ); + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, "Invalid value for trans parameter"); + return ; + } + if ( *rows <= 0 || *cols <= 0 || alpha == NULL || aptr == NULL || bptr == NULL || *lda < 1 || *ldb < 1 ) + { + bli_print_msg( " Invalid function parameter in zomatcopy_() .", __FILE__, __LINE__ ); + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, "Invalid function parameters"); + return ; + } + if ( *trans == 'n' || *trans == 'N') + { + bli_zoMatCopy_cn(*rows,*cols,*alpha,aptr,*lda,bptr,*ldb); + } + else if ( *trans == 't' || *trans == 'T') + { + bli_zoMatCopy_ct(*rows,*cols,*alpha,aptr,*lda,bptr,*ldb); + } + else if ( *trans == 'c' || *trans == 'C') + { + bli_zoMatCopy_cc(*rows,*cols,*alpha,aptr,*lda,bptr,*ldb); + } + else if ( *trans == 'r' || *trans == 'R') + { + bli_zoMatCopy_cr(*rows,*cols,*alpha,aptr,*lda,bptr,*ldb); + } + else + { + // do nothing + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return ; } // suffix cn means - column major & non-trans -static dim_t bli_soMatCopy_cn(dim_t rows,dim_t cols,const float alpha,const float* a,dim_t lda,float* b,dim_t ldb) +static dim_t bli_soMatCopy_cn(dim_t rows, dim_t cols, const float alpha, const float* a, dim_t lda, float* b, dim_t ldb) { - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_2); - dim_t i,j; - const float* aptr; - float* bptr; - if ( rows <= 0 || cols <= 0 || a == NULL || b == NULL || lda < rows || ldb < rows ) - { - bli_print_msg( " Invalid function parameter in bli_soMatCopy_cn() .", __FILE__, __LINE__ ); - AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_2, "Invalid function parameters"); - return (0); - } - - aptr = a; - bptr = b; - - if ( alpha == 0.0 ) - { - for ( i=0; i 1. + // Time complexity of TRSM is (M^2 * N) in left variants + // and (N^2 * M) in right variants. + // Therefore time taken by Small path for left variant will be + // (M^2 * N) + // and time taken by Native path for left variant will be + // (M^2 * N) / S + X + // We should take small code path when + // (M^2 * N) < (M^2 * N) / S + X + // solving this gives us + // (M^2 * N) < (X * S) / ( S - 1) + // Here RHS is constant, which can be found using empirical data + // (X * S) / ( S - 1) is found to be around 6.3e6 on Turin + // In order the reduce the possiblity of overflow, taking log on + // both sides gives us + // 2log(m) + log(n) < 6.8 for left variant + if ( ( blis_side == BLIS_LEFT ) && + ( (log10(n0) + (2*log10(m0)) ) < 6.8 ) ) + { + ker_ft = bli_trsm_small_AVX512; + } + else if ( ( blis_side == BLIS_RIGHT ) && + ( (log10(m0) + (2*log10(n0)) ) < 6.8 ) ) + { + ker_ft = bli_trsm_small_AVX512; + } + break; +#endif // BLIS_KERNELS_ZEN4 + case BLIS_ARCH_ZEN4: +#if defined(BLIS_KERNELS_ZEN4) + if ((!is_parallel && ((dim_a < 1500) && (size_b < 5e6)) ) || + (is_parallel && (m0+n0)<200)) + { /* For sizes where m and n < 50,avx2 kernels are performing better, - except for sizes where n is multiple of 8.*/ + except for sizes where n is multiple of 8.*/ if (((n0 % 8 == 0) && (n0 < 50)) || ((m0 > 50) && (n0 > 50))) { ker_ft = bli_trsm_small_AVX512; @@ -1127,36 +1176,61 @@ void dtrsm_blis_impl { ker_ft = bli_trsm_small; } - break; + } + break; #endif // BLIS_KERNELS_ZEN4 - case BLIS_ARCH_ZEN: - case BLIS_ARCH_ZEN2: - case BLIS_ARCH_ZEN3: - default: + case BLIS_ARCH_ZEN: + case BLIS_ARCH_ZEN2: + case BLIS_ARCH_ZEN3: + default: + if ((!is_parallel && ((dim_a < 1500) && (size_b < 5e6)) ) || + (is_parallel && (m0+n0)<200)) + { ker_ft = bli_trsm_small; - break; - } + } + break; } #ifdef BLIS_ENABLE_OPENMP - if( (ker_ft == NULL) && (is_parallel) && - ((dim_a < 2500) && (size_b < 5e6)) ) + switch(id) { - switch(id) - { - case BLIS_ARCH_ZEN4: + case BLIS_ARCH_ZEN5: #if defined(BLIS_KERNELS_ZEN4) + if( (is_parallel) && n0 > 10 && m0 > 10 ) + { + if ( ( blis_side == BLIS_LEFT ) && + ( (log10(n0) + (2*log10(m0)) ) < 6.8 ) ) + { + ker_ft = bli_trsm_small_mt_AVX512; + } + else if ( ( blis_side == BLIS_RIGHT ) && + ( (log10(m0) + (2*log10(n0)) ) < 6.8 ) ) + { + ker_ft = bli_trsm_small_mt_AVX512; + } + } + break; +#endif// BLIS_KERNELS_ZEN4 + case BLIS_ARCH_ZEN4: +#if defined(BLIS_KERNELS_ZEN4) + if( (ker_ft == NULL) && (is_parallel) && + ((dim_a < 2500) && (size_b < 5e6)) ) + { ker_ft = bli_trsm_small_mt_AVX512; - break; + } + break; #endif// BLIS_KERNELS_ZEN4 - case BLIS_ARCH_ZEN: - case BLIS_ARCH_ZEN2: - case BLIS_ARCH_ZEN3: - default: + case BLIS_ARCH_ZEN: + case BLIS_ARCH_ZEN2: + case BLIS_ARCH_ZEN3: + default: + if( (ker_ft == NULL) && (is_parallel) && + ((dim_a < 2500) && (size_b < 5e6)) ) + { ker_ft = bli_trsm_small_mt; - break; + } + break; } - } #endif// BLIS_ENABLE_OPENMP if(ker_ft) @@ -1174,15 +1248,29 @@ void dtrsm_blis_impl } // bli_cpuid_is_avx2fma3_supported #endif// END of BLIS_ENABLE_SMALL_MATRIX_TRSM - bli_trsmnat - ( - blis_side, - &alphao, - &ao, - &bo, - NULL, - NULL - ); + //bli_trsmnat + //( + // blis_side, + // &alphao, + // &ao, + // &bo, + // NULL, + // NULL + //); + + /* Default to using native execution. */ + ind_t im = BLIS_NAT; + + /* Obtain a valid context from the gks using the induced + method id determined above. */ + cntx_t* cntx = bli_gks_query_ind_cntx( im, dt ); + + rntm_t rntm_l; + bli_rntm_init_from_global( &rntm_l ); + + /* Invoke the operation's front-end and request the default control tree. */ + PASTEMAC(trsm,_front)( blis_side, &alphao, &ao, &bo, cntx, &rntm_l, NULL ); \ + AOCL_DTL_LOG_TRSM_STATS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(d), *side, *m, *n); AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) /* Finalize BLIS. */ @@ -1205,7 +1293,7 @@ void dtrsm_ dtrsm_blis_impl ( side, uploa, transa, diaga, m, n, alpha, a, lda, b, ldb ); #if defined(BLIS_KERNELS_ZEN4) arch_t id = bli_arch_query_id(); - if (id == BLIS_ARCH_ZEN4) + if (id == BLIS_ARCH_ZEN5 || id == BLIS_ARCH_ZEN4) { bli_zero_zmm(); } @@ -1343,15 +1431,19 @@ void ztrsm_blis_impl } else if( ( blis_side == BLIS_RIGHT ) && ( m0 != 1 ) ) { - bli_zscalv_ex - ( - conja, - m0, - (dcomplex*)alpha, - (dcomplex*)b, rs_b, - NULL, - NULL - ); + /* Avoid alpha scaling when alpha is one */ + if ( !PASTEMAC(z, eq1)(*alpha) ) + { + bli_zscalv_ex + ( + conja, + m0, + (dcomplex*)alpha, + (dcomplex*)b, rs_b, + NULL, + NULL + ); + } if(blis_diaga == BLIS_NONUNIT_DIAG) { dcomplex inva = {1.0, 0.0}; @@ -1447,15 +1539,19 @@ void ztrsm_blis_impl } else if(( blis_side == BLIS_LEFT ) && ( n0 != 1 )) { - bli_zscalv_ex - ( - conja, - n0, - (dcomplex*)alpha, - (dcomplex*)b, cs_b, - NULL, - NULL - ); + /* Avoid alpha scaling when alpha is one */ + if ( !PASTEMAC(z, eq1)(*alpha) ) + { + bli_zscalv_ex + ( + conja, + n0, + (dcomplex*)alpha, + (dcomplex*)b, cs_b, + NULL, + NULL + ); + } if(blis_diaga == BLIS_NONUNIT_DIAG) { dcomplex inva = {1.0, 0.0}; @@ -1525,12 +1621,27 @@ void ztrsm_blis_impl #ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM // This function is invoked on all architectures including 'generic'. // Non-AVX2+FMA3 platforms will use the kernels derived from the context. - if (bli_cpuid_is_avx2fma3_supported() == TRUE) + if ( bli_cpuid_is_avx2fma3_supported() == TRUE ) { /* bli_ztrsm_small is performing better existing native * implementations for [m,n]<=1000 for single thread. * In case of multithread when [m,n]<=128 single thread implementation * is doing better than native multithread */ + typedef err_t (*ztrsm_small_ker_ft) + ( + side_t side, + obj_t* alpha, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl, + bool is_parallel + ); + err_t status = BLIS_NOT_YET_IMPLEMENTED; + + // trsm small kernel function pointer definition + ztrsm_small_ker_ft ker_ft = NULL; + arch_t id = bli_arch_query_id(); bool is_parallel = bli_thread_get_is_parallel(); dim_t dim_a = n0; if (blis_side == BLIS_LEFT) @@ -1538,42 +1649,91 @@ void ztrsm_blis_impl // size of output matrix(B) dim_t size_b = m0*n0; - if((!is_parallel && m0<=500 && n0<=500) || - (is_parallel && (m0+n0)<128) || - (dim_a<35 && size_b<3500)) +#if defined(BLIS_ENABLE_OPENMP) && defined(BLIS_KERNELS_ZEN4) + if (( is_parallel ) && + ( (dim_a > 10) && (dim_a < 2500) && (size_b > 500) && (size_b < 5e5) ) && + ( id == BLIS_ARCH_ZEN4 )) { - err_t status; - status = bli_trsm_small - ( - blis_side, - &alphao, - &ao, - &bo, - NULL, - NULL, - is_parallel - ); - if (status == BLIS_SUCCESS) + if (!bli_obj_has_conj(&ao)) { - AOCL_DTL_LOG_TRSM_STATS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(z), *side, *m, *n); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - /* Finalize BLIS. */ - bli_finalize_auto(); - return; + ker_ft = bli_trsm_small_mt_AVX512; + } + else + { + ker_ft = bli_trsm_small_mt; + } + } +#endif + if( ( ker_ft == NULL ) && + ( ( ( !is_parallel ) && + ( (( m0 <= 500 ) && ( n0 <= 500 )) || ( (dim_a < 75) && (size_b < 3.2e5)))) || + ( ( is_parallel ) && + ( (m0 + n0 < 180) || (size_b < 5000) ) ) + ) + ) + { + switch (id) + { + case BLIS_ARCH_ZEN5: + case BLIS_ARCH_ZEN4: +#if defined(BLIS_KERNELS_ZEN4) + // ZTRSM AVX512 code path do not support + // conjugate + if (!bli_obj_has_conj(&ao)) + { + ker_ft = bli_trsm_small_AVX512; + } + else + { + ker_ft = bli_trsm_small; + } + break; +#endif // BLIS_KERNELS_ZEN4 + case BLIS_ARCH_ZEN: + case BLIS_ARCH_ZEN2: + case BLIS_ARCH_ZEN3: + default: + ker_ft = bli_trsm_small; + break; } } + if(ker_ft) + { + status = ker_ft(blis_side, &alphao, &ao, &bo, NULL, NULL, is_parallel); + } + if (status == BLIS_SUCCESS) + { + AOCL_DTL_LOG_TRSM_STATS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(z), *side, *m, *n); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } } // bli_cpuid_is_avx2fma3_supported #endif// END of BLIS_ENABLE_SMALL_MATRIX_TRSM - bli_trsmnat - ( - blis_side, - &alphao, - &ao, - &bo, - NULL, - NULL - ); + //bli_trsmnat + //( + // blis_side, + // &alphao, + // &ao, + // &bo, + // NULL, + // NULL + //); + + /* Default to using native execution. */ + ind_t im = BLIS_NAT; + + /* Obtain a valid context from the gks using the induced + method id determined above. */ + cntx_t* cntx = bli_gks_query_ind_cntx( im, dt ); + + rntm_t rntm_l; + bli_rntm_init_from_global( &rntm_l ); + + /* Invoke the operation's front-end and request the default control tree. */ + PASTEMAC(trsm,_front)( blis_side, &alphao, &ao, &bo, cntx, &rntm_l, NULL ); \ AOCL_DTL_LOG_TRSM_STATS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(z), *side, *m, *n); AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) @@ -1597,7 +1757,7 @@ void ztrsm_ ztrsm_blis_impl ( side, uploa, transa, diaga, m, n, alpha, a, lda, b, ldb ); #if defined(BLIS_KERNELS_ZEN4) arch_t id = bli_arch_query_id(); - if (id == BLIS_ARCH_ZEN4) + if (id == BLIS_ARCH_ZEN5 || id == BLIS_ARCH_ZEN4) { bli_zero_zmm(); } @@ -1949,15 +2109,28 @@ void ctrsm_blis_impl } // bli_cpuid_is_avx2fma3_supported #endif - bli_trsmnat - ( - blis_side, - &alphao, - &ao, - &bo, - NULL, - NULL - ); + //bli_trsmnat + //( + // blis_side, + // &alphao, + // &ao, + // &bo, + // NULL, + // NULL + //); + + /* Default to using native execution. */ + ind_t im = BLIS_NAT; + + /* Obtain a valid context from the gks using the induced + method id determined above. */ + cntx_t* cntx = bli_gks_query_ind_cntx( im, dt ); + + rntm_t rntm_l; + bli_rntm_init_from_global( &rntm_l ); + + /* Invoke the operation's front-end and request the default control tree. */ + PASTEMAC(trsm,_front)( blis_side, &alphao, &ao, &bo, cntx, &rntm_l, NULL ); \ AOCL_DTL_LOG_TRSM_STATS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(c), *side, *m, *n); AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) @@ -1981,7 +2154,7 @@ void ctrsm_ ctrsm_blis_impl ( side, uploa, transa, diaga, m, n, alpha, a, lda, b, ldb ); #if defined(BLIS_KERNELS_ZEN4) arch_t id = bli_arch_query_id(); - if (id == BLIS_ARCH_ZEN4) + if (id == BLIS_ARCH_ZEN5 || id == BLIS_ARCH_ZEN4) { bli_zero_zmm(); } diff --git a/frame/compat/bla_trsv.h b/frame/compat/bla_trsv.h index 47b02935d9..267cb5fef6 100644 --- a/frame/compat/bla_trsv.h +++ b/frame/compat/bla_trsv.h @@ -5,8 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. - + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/frame/compat/cblas/f77_sub/f77_amax_sub.c b/frame/compat/cblas/f77_sub/f77_amax_sub.c index c394ed4d40..3d964cce4c 100644 --- a/frame/compat/cblas/f77_sub/f77_amax_sub.c +++ b/frame/compat/cblas/f77_sub/f77_amax_sub.c @@ -5,8 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. - + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/frame/compat/cblas/f77_sub/f77_amax_sub.h b/frame/compat/cblas/f77_sub/f77_amax_sub.h index 35d501ba4a..dd23ca212e 100644 --- a/frame/compat/cblas/f77_sub/f77_amax_sub.h +++ b/frame/compat/cblas/f77_sub/f77_amax_sub.h @@ -5,8 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. - + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/frame/compat/cblas/f77_sub/f77_amin_sub.c b/frame/compat/cblas/f77_sub/f77_amin_sub.c index 244928d7bb..615f648ef9 100644 --- a/frame/compat/cblas/f77_sub/f77_amin_sub.c +++ b/frame/compat/cblas/f77_sub/f77_amin_sub.c @@ -4,8 +4,8 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. - + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/frame/compat/cblas/f77_sub/f77_asum_sub.c b/frame/compat/cblas/f77_sub/f77_asum_sub.c index befac150e0..80f251c160 100644 --- a/frame/compat/cblas/f77_sub/f77_asum_sub.c +++ b/frame/compat/cblas/f77_sub/f77_asum_sub.c @@ -5,8 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. - + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/frame/compat/cblas/f77_sub/f77_asum_sub.h b/frame/compat/cblas/f77_sub/f77_asum_sub.h index de3d99bfc9..f2cb6faabd 100644 --- a/frame/compat/cblas/f77_sub/f77_asum_sub.h +++ b/frame/compat/cblas/f77_sub/f77_asum_sub.h @@ -5,8 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. - + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/frame/compat/cblas/f77_sub/f77_dot_sub.c b/frame/compat/cblas/f77_sub/f77_dot_sub.c index f497ab97f0..80d8e37030 100644 --- a/frame/compat/cblas/f77_sub/f77_dot_sub.c +++ b/frame/compat/cblas/f77_sub/f77_dot_sub.c @@ -5,8 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. - + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/frame/compat/cblas/f77_sub/f77_dot_sub.h b/frame/compat/cblas/f77_sub/f77_dot_sub.h index 54a40a9a02..95382975ac 100644 --- a/frame/compat/cblas/f77_sub/f77_dot_sub.h +++ b/frame/compat/cblas/f77_sub/f77_dot_sub.h @@ -5,8 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. - + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/frame/compat/cblas/f77_sub/f77_nrm2_sub.c b/frame/compat/cblas/f77_sub/f77_nrm2_sub.c index 72fa07593a..3e8e7dd312 100644 --- a/frame/compat/cblas/f77_sub/f77_nrm2_sub.c +++ b/frame/compat/cblas/f77_sub/f77_nrm2_sub.c @@ -5,8 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. - + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/frame/compat/cblas/f77_sub/f77_nrm2_sub.h b/frame/compat/cblas/f77_sub/f77_nrm2_sub.h index dbe2809741..ee77b54b50 100644 --- a/frame/compat/cblas/f77_sub/f77_nrm2_sub.h +++ b/frame/compat/cblas/f77_sub/f77_nrm2_sub.h @@ -5,8 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. - + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/frame/compat/cblas/src/cblas.h b/frame/compat/cblas/src/cblas.h index 7d57b15bf5..a2b805c621 100644 --- a/frame/compat/cblas/src/cblas.h +++ b/frame/compat/cblas/src/cblas.h @@ -1,6 +1,10 @@ /* - Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -181,7 +185,7 @@ void BLIS_EXPORT_BLAS cblas_zaxpby(f77_int N, const void *alpha, /* - * Routines with S and D prefix only + * Routines with S D C Z CS and ZD prefixes */ void BLIS_EXPORT_BLAS cblas_srotg(float *a, float *b, float *c, float *s); void BLIS_EXPORT_BLAS cblas_srotmg(float *d1, float *d2, float *b1, const float b2, float *P); @@ -197,10 +201,13 @@ void BLIS_EXPORT_BLAS cblas_drot(f77_int N, double *X, f77_int incX, void BLIS_EXPORT_BLAS cblas_drotm(f77_int N, double *X, f77_int incX, double *Y, f77_int incY, const double *P); +void BLIS_EXPORT_BLAS cblas_crotg(void *a, void *b, float *c, void *s); +void BLIS_EXPORT_BLAS cblas_csrot(f77_int N, void *X, f77_int incX, + void *Y, f77_int incY, const float c, const float s); +void BLIS_EXPORT_BLAS cblas_zrotg(void *a, void *b, double *c, void *s); +void BLIS_EXPORT_BLAS cblas_zdrot(f77_int N, void *X, f77_int incX, + void *Y, f77_int incY, const double c, const double s); -/* - * Routines with S D C Z CS and ZD prefixes - */ void BLIS_EXPORT_BLAS cblas_sscal(f77_int N, float alpha, float *X, f77_int incX); void BLIS_EXPORT_BLAS cblas_dscal(f77_int N, double alpha, double *X, f77_int incX); void BLIS_EXPORT_BLAS cblas_cscal(f77_int N, const void *alpha, void *X, f77_int incX); diff --git a/frame/compat/cblas/src/cblas_crotg.c b/frame/compat/cblas/src/cblas_crotg.c new file mode 100644 index 0000000000..d6abc39ff9 --- /dev/null +++ b/frame/compat/cblas/src/cblas_crotg.c @@ -0,0 +1,50 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#ifdef BLIS_ENABLE_CBLAS +/* + * cblas_crotg.c + * + * The program is a C interface to crotg. + * + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_crotg( void *a, void *b, float *c, void *s ) +{ + F77_crotg((scomplex*)a, (scomplex*)b, c, (scomplex*)s); +} +#endif diff --git a/frame/compat/cblas/src/cblas_csrot.c b/frame/compat/cblas/src/cblas_csrot.c new file mode 100644 index 0000000000..af80700b90 --- /dev/null +++ b/frame/compat/cblas/src/cblas_csrot.c @@ -0,0 +1,58 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#ifdef BLIS_ENABLE_CBLAS +/* + * cblas_csrot.c + * + * The program is a C interface to csrot. + * + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_csrot( f77_int N, void *X, f77_int incX, void *Y, + f77_int incY, const float c, const float s ) +{ +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX; F77_incY=incY; +#else + #define F77_N N + #define F77_incX incX + #define F77_incY incY +#endif + F77_csrot( &F77_N, (scomplex*)X, &F77_incX, (scomplex*)Y, &F77_incY, &c, &s ); +} +#endif diff --git a/frame/compat/cblas/src/cblas_f77.h b/frame/compat/cblas/src/cblas_f77.h index 18bbad51b7..ce9400f31e 100644 --- a/frame/compat/cblas/src/cblas_f77.h +++ b/frame/compat/cblas/src/cblas_f77.h @@ -7,10 +7,42 @@ * * (Heavily hacked down from the original) * - * Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. - * */ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + #ifndef CBLAS_F77_H #define CBLAS_F77_H @@ -224,6 +256,10 @@ #define F77_drotmg drotmg_blis_impl #define F77_drot drot_blis_impl #define F77_drotm drotm_blis_impl +#define F77_crotg crotg_blis_impl +#define F77_csrot csrot_blis_impl +#define F77_zrotg zrotg_blis_impl +#define F77_zdrot zdrot_blis_impl #define F77_sswap sswap_blis_impl #define F77_scopy scopy_blis_impl #define F77_saxpy saxpy_blis_impl diff --git a/frame/compat/cblas/src/cblas_xerbla.c b/frame/compat/cblas/src/cblas_xerbla.c index ebe6bd8009..16bece55f5 100644 --- a/frame/compat/cblas/src/cblas_xerbla.c +++ b/frame/compat/cblas/src/cblas_xerbla.c @@ -7,13 +7,17 @@ #include "cblas.h" #include "cblas_f77.h" +// The global rntm_t structure. (The definition resides in bli_rntm.c.) +extern rntm_t global_rntm; + +// Make thread settings local to each thread calling BLIS routines. +// (The definition resides in bli_rntm.c.) +extern BLIS_THREAD_LOCAL rntm_t tl_rntm; + void cblas_xerbla(f77_int info, const char *rout, const char *form, ...) { extern int RowMajorStrg; char empty[1] = ""; - va_list argptr; - - va_start(argptr, form); if (RowMajorStrg) { @@ -60,12 +64,36 @@ void cblas_xerbla(f77_int info, const char *rout, const char *form, ...) else if (info == 6) info = 8; } } + if (info) - fprintf(stderr, "Parameter %jd to routine %s was incorrect\n", ( intmax_t )info, rout); - vfprintf(stderr, form, argptr); - va_end(argptr); - if (info && !info) - F77_xerbla(empty, &info, 0); /* Force link of our F77 error handler */ - exit(-1); + { + // Make sure rntm variables are initialized. + bli_init_once(); + + // Store info value in thread-local rntm data structure. + gint_t info_value = (gint_t) info; + bli_rntm_set_info_value_only( info_value, &tl_rntm ); + + bool print_on_error = bli_rntm_print_on_error( &global_rntm ); + if (print_on_error) + { + va_list argptr; + va_start(argptr, form); + + fprintf(stderr, "Parameter %d to routine %s was incorrect\n", (int)info, rout); + vfprintf(stderr, form, argptr); + va_end(argptr); + } + + bool stop_on_error = bli_rntm_stop_on_error( &global_rntm ); + if (stop_on_error) + { + bli_abort(); + } + + if (info && !info) + F77_xerbla(empty, &info, 0); /* Force link of our F77 error handler */ + } } #endif + diff --git a/frame/compat/cblas/src/cblas_zdrot.c b/frame/compat/cblas/src/cblas_zdrot.c new file mode 100644 index 0000000000..5337d9a284 --- /dev/null +++ b/frame/compat/cblas/src/cblas_zdrot.c @@ -0,0 +1,58 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#ifdef BLIS_ENABLE_CBLAS +/* + * cblas_zdrot.c + * + * The program is a C interface to zdrot. + * + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_zdrot( f77_int N, void *X, f77_int incX, void *Y, + f77_int incY, const double c, const double s ) +{ +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX; F77_incY=incY; +#else + #define F77_N N + #define F77_incX incX + #define F77_incY incY +#endif + F77_zdrot( &F77_N, (dcomplex*)X, &F77_incX, (dcomplex*)Y, &F77_incY, &c, &s ); +} +#endif diff --git a/frame/compat/cblas/src/cblas_zrotg.c b/frame/compat/cblas/src/cblas_zrotg.c new file mode 100644 index 0000000000..275c6660f4 --- /dev/null +++ b/frame/compat/cblas/src/cblas_zrotg.c @@ -0,0 +1,50 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#ifdef BLIS_ENABLE_CBLAS +/* + * cblas_zrotg.c + * + * The program is a C interface to zrotg. + * + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_zrotg( void *a, void *b, double *c, void *s ) +{ + F77_zrotg((dcomplex*)a, (dcomplex*)b, c, (dcomplex*)s); +} +#endif diff --git a/frame/compat/f2c/bla_gbmv.c b/frame/compat/f2c/bla_gbmv.c index 671153b950..1fa41dc92f 100644 --- a/frame/compat/f2c/bla_gbmv.c +++ b/frame/compat/f2c/bla_gbmv.c @@ -5,8 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. - + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/frame/compat/f2c/bla_gbmv.h b/frame/compat/f2c/bla_gbmv.h index 39df264978..2990a365cb 100644 --- a/frame/compat/f2c/bla_gbmv.h +++ b/frame/compat/f2c/bla_gbmv.h @@ -5,8 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. - + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/frame/compat/f2c/bla_hbmv.c b/frame/compat/f2c/bla_hbmv.c index 3398493afb..43403fd0b7 100644 --- a/frame/compat/f2c/bla_hbmv.c +++ b/frame/compat/f2c/bla_hbmv.c @@ -5,8 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. - + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/frame/compat/f2c/bla_hbmv.h b/frame/compat/f2c/bla_hbmv.h index 1d8bda65ff..748074bd3e 100644 --- a/frame/compat/f2c/bla_hbmv.h +++ b/frame/compat/f2c/bla_hbmv.h @@ -5,8 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. - + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/frame/compat/f2c/bla_hpmv.c b/frame/compat/f2c/bla_hpmv.c index 446eb24a49..4f64d6260b 100644 --- a/frame/compat/f2c/bla_hpmv.c +++ b/frame/compat/f2c/bla_hpmv.c @@ -5,8 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. - + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/frame/compat/f2c/bla_hpmv.h b/frame/compat/f2c/bla_hpmv.h index c7f1bc0822..3f23f89d2f 100644 --- a/frame/compat/f2c/bla_hpmv.h +++ b/frame/compat/f2c/bla_hpmv.h @@ -5,8 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. - + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/frame/compat/f2c/bla_hpr.c b/frame/compat/f2c/bla_hpr.c index a4300c6463..586975f5c7 100644 --- a/frame/compat/f2c/bla_hpr.c +++ b/frame/compat/f2c/bla_hpr.c @@ -5,8 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. - + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/frame/compat/f2c/bla_hpr.h b/frame/compat/f2c/bla_hpr.h index 24c7b238d4..2eabfab02d 100644 --- a/frame/compat/f2c/bla_hpr.h +++ b/frame/compat/f2c/bla_hpr.h @@ -5,8 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. - + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/frame/compat/f2c/bla_hpr2.c b/frame/compat/f2c/bla_hpr2.c index 5f4b9c0b2d..b488a8c6f7 100644 --- a/frame/compat/f2c/bla_hpr2.c +++ b/frame/compat/f2c/bla_hpr2.c @@ -5,8 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. - + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/frame/compat/f2c/bla_hpr2.h b/frame/compat/f2c/bla_hpr2.h index ccffc7c5b7..5f8633f990 100644 --- a/frame/compat/f2c/bla_hpr2.h +++ b/frame/compat/f2c/bla_hpr2.h @@ -5,8 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. - + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/frame/compat/f2c/bla_rot.c b/frame/compat/f2c/bla_rot.c index cb5ef37f3c..d70b88ddfb 100644 --- a/frame/compat/f2c/bla_rot.c +++ b/frame/compat/f2c/bla_rot.c @@ -5,8 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. - + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/frame/compat/f2c/bla_rot.h b/frame/compat/f2c/bla_rot.h index f6c28d5a3e..8dda48274d 100644 --- a/frame/compat/f2c/bla_rot.h +++ b/frame/compat/f2c/bla_rot.h @@ -5,8 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. - + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/frame/compat/f2c/bla_rotg.c b/frame/compat/f2c/bla_rotg.c index b892e3dfee..ecce3660f9 100644 --- a/frame/compat/f2c/bla_rotg.c +++ b/frame/compat/f2c/bla_rotg.c @@ -5,8 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. - + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/frame/compat/f2c/bla_rotg.h b/frame/compat/f2c/bla_rotg.h index 8558d4fec4..4c1e619d82 100644 --- a/frame/compat/f2c/bla_rotg.h +++ b/frame/compat/f2c/bla_rotg.h @@ -5,8 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. - + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/frame/compat/f2c/bla_rotm.c b/frame/compat/f2c/bla_rotm.c index 4ce727abd1..608a845cdd 100644 --- a/frame/compat/f2c/bla_rotm.c +++ b/frame/compat/f2c/bla_rotm.c @@ -5,8 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. - + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/frame/compat/f2c/bla_rotm.h b/frame/compat/f2c/bla_rotm.h index ce33623c5d..bc74d5f4b2 100644 --- a/frame/compat/f2c/bla_rotm.h +++ b/frame/compat/f2c/bla_rotm.h @@ -5,8 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. - + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/frame/compat/f2c/bla_rotmg.c b/frame/compat/f2c/bla_rotmg.c index a599d74dad..0de1537c2d 100644 --- a/frame/compat/f2c/bla_rotmg.c +++ b/frame/compat/f2c/bla_rotmg.c @@ -5,8 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. - + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/frame/compat/f2c/bla_rotmg.h b/frame/compat/f2c/bla_rotmg.h index 5595842145..e264e2a191 100644 --- a/frame/compat/f2c/bla_rotmg.h +++ b/frame/compat/f2c/bla_rotmg.h @@ -5,8 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. - + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/frame/compat/f2c/bla_sbmv.c b/frame/compat/f2c/bla_sbmv.c index ec9236bf51..716d9ffa41 100644 --- a/frame/compat/f2c/bla_sbmv.c +++ b/frame/compat/f2c/bla_sbmv.c @@ -5,8 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. - + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/frame/compat/f2c/bla_sbmv.h b/frame/compat/f2c/bla_sbmv.h index 56a89cfe2d..a70d1caa10 100644 --- a/frame/compat/f2c/bla_sbmv.h +++ b/frame/compat/f2c/bla_sbmv.h @@ -5,8 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. - + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/frame/compat/f2c/bla_spmv.c b/frame/compat/f2c/bla_spmv.c index d0e8e5c58a..8ba132d0d0 100644 --- a/frame/compat/f2c/bla_spmv.c +++ b/frame/compat/f2c/bla_spmv.c @@ -5,8 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. - + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/frame/compat/f2c/bla_spmv.h b/frame/compat/f2c/bla_spmv.h index 7652207ce0..5c4a42a54b 100644 --- a/frame/compat/f2c/bla_spmv.h +++ b/frame/compat/f2c/bla_spmv.h @@ -5,8 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. - + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/frame/compat/f2c/bla_spr.c b/frame/compat/f2c/bla_spr.c index fbc3c81b28..24933c9eb8 100644 --- a/frame/compat/f2c/bla_spr.c +++ b/frame/compat/f2c/bla_spr.c @@ -5,8 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. - + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/frame/compat/f2c/bla_spr.h b/frame/compat/f2c/bla_spr.h index 2b2da5bb19..cfb217b79e 100644 --- a/frame/compat/f2c/bla_spr.h +++ b/frame/compat/f2c/bla_spr.h @@ -5,8 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. - + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/frame/compat/f2c/bla_spr2.c b/frame/compat/f2c/bla_spr2.c index beb2d92c0d..9202c3fa6d 100644 --- a/frame/compat/f2c/bla_spr2.c +++ b/frame/compat/f2c/bla_spr2.c @@ -5,8 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. - + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/frame/compat/f2c/bla_spr2.h b/frame/compat/f2c/bla_spr2.h index 2567cea9ae..9e0120e184 100644 --- a/frame/compat/f2c/bla_spr2.h +++ b/frame/compat/f2c/bla_spr2.h @@ -5,8 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. - + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/frame/compat/f2c/bla_tbmv.c b/frame/compat/f2c/bla_tbmv.c index ebc587df6c..31a57805ef 100644 --- a/frame/compat/f2c/bla_tbmv.c +++ b/frame/compat/f2c/bla_tbmv.c @@ -5,8 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. - + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/frame/compat/f2c/bla_tbmv.h b/frame/compat/f2c/bla_tbmv.h index c91d9579f7..bc1465631c 100644 --- a/frame/compat/f2c/bla_tbmv.h +++ b/frame/compat/f2c/bla_tbmv.h @@ -5,8 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. - + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/frame/compat/f2c/bla_tbsv.h b/frame/compat/f2c/bla_tbsv.h index ce5ecba108..bf9ae74eb1 100644 --- a/frame/compat/f2c/bla_tbsv.h +++ b/frame/compat/f2c/bla_tbsv.h @@ -5,8 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. - + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/frame/compat/f2c/bla_tpmv.c b/frame/compat/f2c/bla_tpmv.c index 802c00c2eb..7a2849c7c3 100644 --- a/frame/compat/f2c/bla_tpmv.c +++ b/frame/compat/f2c/bla_tpmv.c @@ -5,8 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. - + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/frame/compat/f2c/bla_tpmv.h b/frame/compat/f2c/bla_tpmv.h index e6fb29db46..3dc6303150 100644 --- a/frame/compat/f2c/bla_tpmv.h +++ b/frame/compat/f2c/bla_tpmv.h @@ -5,8 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. - + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/frame/compat/f2c/bla_tpsv.c b/frame/compat/f2c/bla_tpsv.c index bc4e3f4d49..a6eabb94ff 100644 --- a/frame/compat/f2c/bla_tpsv.c +++ b/frame/compat/f2c/bla_tpsv.c @@ -5,8 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. - + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/frame/compat/f2c/bla_tpsv.h b/frame/compat/f2c/bla_tpsv.h index ce083e23a1..2613fc2c56 100644 --- a/frame/compat/f2c/bla_tpsv.h +++ b/frame/compat/f2c/bla_tpsv.h @@ -5,8 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. - + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/frame/compat/f2c/bla_xerbla.c b/frame/compat/f2c/bla_xerbla.c index 0e0ec59d34..8453b4d99d 100644 --- a/frame/compat/f2c/bla_xerbla.c +++ b/frame/compat/f2c/bla_xerbla.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -49,7 +49,7 @@ extern BLIS_THREAD_LOCAL rntm_t tl_rntm; /* Table of constant values */ -/* Subroutine */ int xerbla_blis_impl(const bla_character *srname, const bla_integer *info, ftnlen srname_len) +/* Subroutine */ void xerbla_blis_impl(const bla_character *srname, const bla_integer *info, ftnlen srname_len) { /* -- LAPACK auxiliary routine (preliminary version) -- */ /* Univ. of Tennessee, Univ. of California Berkeley, NAG Ltd., */ @@ -93,8 +93,20 @@ extern BLIS_THREAD_LOCAL rntm_t tl_rntm; bool print_on_error = bli_rntm_print_on_error( &global_rntm ); if (print_on_error) { - printf("** On entry to %6s, parameter number %2i had an illegal value\n", - srname, (int)*info); + // The check for -10 is specific to xerbla_()'s use-case in ?imatcopy_() APIs. + // The definition of an info value for memory failure could be abstracted + // to a higher layer, if needed. This would enable us to reuse xerbla_() + // with this specific info value, in case of encountering a memory allocation + // failure. + if( *info == -10 ) + { + printf("** On entry to %6s, memory allocation failed\n", srname); + } + else + { + printf("** On entry to %6s, parameter number %2i had an illegal value\n", + srname, (int)*info); + } } bool stop_on_error = bli_rntm_stop_on_error( &global_rntm ); @@ -105,15 +117,15 @@ extern BLIS_THREAD_LOCAL rntm_t tl_rntm; /* End of XERBLA */ - return 0; + return; } /* xerbla_blis_impl */ #ifdef BLIS_ENABLE_BLAS -/* Subroutine */ int PASTEF770(xerbla)(const bla_character *srname, const bla_integer *info, ftnlen srname_len) +/* Subroutine */ void PASTEF770(xerbla)(const bla_character *srname, const bla_integer *info, ftnlen srname_len) { xerbla_blis_impl(srname, info, srname_len); - return 0; + return; } /* xerbla */ #endif diff --git a/frame/compat/f2c/bla_xerbla.h b/frame/compat/f2c/bla_xerbla.h index 72f9b7592d..7f0fb2d0db 100644 --- a/frame/compat/f2c/bla_xerbla.h +++ b/frame/compat/f2c/bla_xerbla.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -33,10 +33,10 @@ */ -BLIS_EXPORT_BLAS int xerbla_blis_impl(const bla_character *srname, const bla_integer *info, ftnlen srname_len); +BLIS_EXPORT_BLAS void xerbla_blis_impl(const bla_character *srname, const bla_integer *info, ftnlen srname_len); #ifdef BLIS_ENABLE_BLAS -BLIS_EXPORT_BLAS int PASTEF770(xerbla)(const bla_character *srname, const bla_integer *info, ftnlen srname_len); +BLIS_EXPORT_BLAS void PASTEF770(xerbla)(const bla_character *srname, const bla_integer *info, ftnlen srname_len); #endif diff --git a/frame/compat/f2c/bla_xerbla_array.c b/frame/compat/f2c/bla_xerbla_array.c index 2521cd5d23..411de5de66 100644 --- a/frame/compat/f2c/bla_xerbla_array.c +++ b/frame/compat/f2c/bla_xerbla_array.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -36,7 +37,7 @@ #define MAX_NUM_CHARS 32 -int xerbla_array_blis_impl(const bla_character *srname_array, const bla_integer srname_len, const bla_integer *info) +void xerbla_array_blis_impl(const bla_character *srname_array, const bla_integer srname_len, const bla_integer *info) { int i; #if 1 @@ -65,14 +66,15 @@ int xerbla_array_blis_impl(const bla_character *srname_array, const bla_integer // Call xerbla_(). PASTE_XERBLA( srname, info, ( ftnlen )srname_len ); - return 0; + return; } #ifdef BLIS_ENABLE_BLAS -int PASTEF770(xerbla_array)(const bla_character *srname_array, const bla_integer srname_len, const bla_integer *info) +void PASTEF770(xerbla_array)(const bla_character *srname_array, const bla_integer srname_len, const bla_integer *info) { - return xerbla_array_blis_impl(srname_array, srname_len, info); + xerbla_array_blis_impl(srname_array, srname_len, info); + return; } #endif diff --git a/frame/compat/f2c/bla_xerbla_array.h b/frame/compat/f2c/bla_xerbla_array.h index f007fadc1d..8ddb571ed7 100644 --- a/frame/compat/f2c/bla_xerbla_array.h +++ b/frame/compat/f2c/bla_xerbla_array.h @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -32,10 +33,10 @@ */ -BLIS_EXPORT_BLAS int xerbla_array_blis_impl(const bla_character *srname, const bla_integer srname_len, const bla_integer *info); +BLIS_EXPORT_BLAS void xerbla_array_blis_impl(const bla_character *srname, const bla_integer srname_len, const bla_integer *info); #ifdef BLIS_ENABLE_BLAS -BLIS_EXPORT_BLAS int PASTEF770(xerbla_array)(const bla_character *srname, const bla_integer srname_len, const bla_integer *info); +BLIS_EXPORT_BLAS void PASTEF770(xerbla_array)(const bla_character *srname, const bla_integer srname_len, const bla_integer *info); #endif diff --git a/frame/include/bli_arch_config.h b/frame/include/bli_arch_config.h index 0b0107efab..2f497f1e76 100644 --- a/frame/include/bli_arch_config.h +++ b/frame/include/bli_arch_config.h @@ -6,7 +6,7 @@ Copyright (C) 2014, The University of Texas at Austin Copyright (C) 2016, Hewlett Packard Enterprise Development LP - Copyright (C) 2019 - 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -62,6 +62,9 @@ CNTX_INIT_PROTS( penryn ) #endif // -- AMD64 architectures -- +#ifdef BLIS_CONFIG_ZEN5 +CNTX_INIT_PROTS( zen5 ) +#endif #ifdef BLIS_CONFIG_ZEN4 CNTX_INIT_PROTS( zen4 ) #endif @@ -95,6 +98,9 @@ CNTX_INIT_PROTS( armsve ) #ifdef BLIS_CONFIG_A64FX CNTX_INIT_PROTS( a64fx ) #endif +#ifdef BLIS_CONFIG_FIRESTORM +CNTX_INIT_PROTS( firestorm ) +#endif #ifdef BLIS_CONFIG_THUNDERX2 CNTX_INIT_PROTS( thunderx2 ) #endif @@ -177,6 +183,9 @@ CNTX_INIT_PROTS( generic ) // -- AMD64 architectures -- +#ifdef BLIS_FAMILY_ZEN5 +#include "bli_family_zen5.h" +#endif #ifdef BLIS_FAMILY_ZEN4 #include "bli_family_zen4.h" #endif @@ -202,6 +211,14 @@ CNTX_INIT_PROTS( generic ) #include "bli_family_bulldozer.h" #endif +// -- ARM families -- +#ifdef BLIS_FAMILY_ARM64 +#include "bli_family_arm64.h" +#endif +#ifdef BLIS_FAMILY_ARM32 +#include "bli_family_arm32.h" +#endif + // -- ARM architectures -- #ifdef BLIS_FAMILY_ARMSVE @@ -210,6 +227,9 @@ CNTX_INIT_PROTS( generic ) #ifdef BLIS_FAMILY_A64FX #include "bli_family_a64fx.h" #endif +#ifdef BLIS_FAMILY_FIRESTORM +#include "bli_family_firestorm.h" +#endif #ifdef BLIS_FAMILY_THUNDERX2 #include "bli_family_thunderx2.h" #endif @@ -276,9 +296,15 @@ CNTX_INIT_PROTS( generic ) #endif // -- AMD64 architectures -- +#ifdef BLIS_KERNELS_ZEN5 +#include "bli_kernels_zen5.h" +#endif #ifdef BLIS_KERNELS_ZEN4 #include "bli_kernels_zen4.h" #endif +//#ifdef BLIS_KERNELS_ZEN3 +//#include "bli_kernels_zen3.h" +//#endif #ifdef BLIS_KERNELS_ZEN2 #include "bli_kernels_zen2.h" #endif diff --git a/frame/include/bli_arch_config_pre.h b/frame/include/bli_arch_config_pre.h index 1ab0561d83..86c5992306 100644 --- a/frame/include/bli_arch_config_pre.h +++ b/frame/include/bli_arch_config_pre.h @@ -69,7 +69,6 @@ void PASTEMAC2(cntx_init_,archname,BLIS_REF_SUFFIX) \ void PASTEMAC2(cntx_init_,archname,BLIS_IND_SUFFIX) \ ( \ ind_t method, \ - num_t dt, \ cntx_t* cntx \ ); diff --git a/frame/include/bli_blas_blis_impl_interface_defs.h b/frame/include/bli_blas_blis_impl_interface_defs.h new file mode 100644 index 0000000000..fc5e6a1d27 --- /dev/null +++ b/frame/include/bli_blas_blis_impl_interface_defs.h @@ -0,0 +1,366 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_BLAS_INTERFACE_DEFS_H +#define BLIS_BLAS_INTERFACE_DEFS_H + +#ifdef BLIS_ENABLE_NO_UNDERSCORE_API +#ifdef BLIS_ENABLE_BLAS + +#define isamax_blis_impl_ isamax_blis_impl +#define idamax_blis_impl_ idamax_blis_impl +#define icamax_blis_impl_ icamax_blis_impl +#define izamax_blis_impl_ izamax_blis_impl +#define sasum_blis_impl_ sasum_blis_impl +#define dasum_blis_impl_ dasum_blis_impl +#define scasum_blis_impl_ scasum_blis_impl +#define dzasum_blis_impl_ dzasum_blis_impl +#define saxpy_blis_impl_ saxpy_blis_impl +#define daxpy_blis_impl_ daxpy_blis_impl +#define caxpy_blis_impl_ caxpy_blis_impl +#define zaxpy_blis_impl_ zaxpy_blis_impl +#define scopy_blis_impl_ scopy_blis_impl +#define dcopy_blis_impl_ dcopy_blis_impl +#define ccopy_blis_impl_ ccopy_blis_impl +#define zcopy_blis_impl_ zcopy_blis_impl +#define sdot_blis_impl_ sdot_blis_impl +#define ddot_blis_impl_ ddot_blis_impl +#define cdotc_blis_impl_ cdotc_blis_impl +#define zdotc_blis_impl_ zdotc_blis_impl +#define cdotu_blis_impl_ cdotu_blis_impl +#define zdotu_blis_impl_ zdotu_blis_impl +#define snrm2_blis_impl_ snrm2_blis_impl +#define dnrm2_blis_impl_ dnrm2_blis_impl +#define scnrm2_blis_impl_ scnrm2_blis_impl +#define dznrm2_blis_impl_ dznrm2_blis_impl +#define sscal_blis_impl_ sscal_blis_impl +#define dscal_blis_impl_ dscal_blis_impl +#define cscal_blis_impl_ cscal_blis_impl +#define csscal_blis_impl_ csscal_blis_impl +#define zscal_blis_impl_ zscal_blis_impl +#define zdscal_blis_impl_ zdscal_blis_impl +#define sswap_blis_impl_ sswap_blis_impl +#define dswap_blis_impl_ dswap_blis_impl +#define cswap_blis_impl_ cswap_blis_impl +#define zswap_blis_impl_ zswap_blis_impl +#define sgemv_blis_impl_ sgemv_blis_impl +#define dgemv_blis_impl_ dgemv_blis_impl +#define cgemv_blis_impl_ cgemv_blis_impl +#define zgemv_blis_impl_ zgemv_blis_impl +#define sger_blis_impl_ sger_blis_impl +#define dger_blis_impl_ dger_blis_impl +#define cgerc_blis_impl_ cgerc_blis_impl +#define cgeru_blis_impl_ cgeru_blis_impl +#define zgerc_blis_impl_ zgerc_blis_impl +#define zgeru_blis_impl_ zgeru_blis_impl +#define chemv_blis_impl_ chemv_blis_impl +#define zhemv_blis_impl_ zhemv_blis_impl +#define cher_blis_impl_ cher_blis_impl +#define zher_blis_impl_ zher_blis_impl +#define cher2_blis_impl_ cher2_blis_impl +#define zher2_blis_impl_ zher2_blis_impl +#define ssymv_blis_impl_ ssymv_blis_impl +#define dsymv_blis_impl_ dsymv_blis_impl +#define csymm_blis_impl_ csymm_blis_impl +#define zsymm_blis_impl_ zsymm_blis_impl +#define ssyr_blis_impl_ ssyr_blis_impl +#define dsyr_blis_impl_ dsyr_blis_impl +#define csyrk_blis_impl_ csyrk_blis_impl +#define csyrk_blis_impl_ csyrk_blis_impl +#define zsyrk_blis_impl_ zsyrk_blis_impl +#define ssyr2_blis_impl_ ssyr2_blis_impl +#define dsyr2_blis_impl_ dsyr2_blis_impl +#define csyr2k_blis_impl_ csyr2k_blis_impl +#define zsyr2k_blis_impl_ zsyr2k_blis_impl +#define strmv_blis_impl_ strmv_blis_impl +#define dtrmv_blis_impl_ dtrmv_blis_impl +#define ctrmv_blis_impl_ ctrmv_blis_impl +#define ztrmv_blis_impl_ ztrmv_blis_impl +#define strsv_blis_impl_ strsv_blis_impl +#define dtrsv_blis_impl_ dtrsv_blis_impl +#define ctrsv_blis_impl_ ctrsv_blis_impl +#define ztrsv_blis_impl_ ztrsv_blis_impl +#define sgemm_blis_impl_ sgemm_blis_impl +#define dgemm_blis_impl_ dgemm_blis_impl +#define cgemm_blis_impl_ cgemm_blis_impl +#define zgemm_blis_impl_ zgemm_blis_impl +#define chemm_blis_impl_ chemm_blis_impl +#define zhemm_blis_impl_ zhemm_blis_impl +#define dgemmt_blis_impl_ dgemmt_blis_impl +#define sgemmt_blis_impl_ sgemmt_blis_impl +#define zgemmt_blis_impl_ zgemmt_blis_impl +#define cgemmt_blis_impl_ cgemmt_blis_impl +#define sgemm_batch_blis_impl_ sgemm_batch_blis_impl +#define dgemm_batch_blis_impl_ dgemm_batch_blis_impl +#define cgemm_batch_blis_impl_ cgemm_batch_blis_impl +#define zgemm_batch_blis_impl_ zgemm_batch_blis_impl +#define sgemm_compute_blis_impl_ sgemm_compute_blis_impl +#define dgemm_compute_blis_impl_ dgemm_compute_blis_impl +#define sgemm_pack_get_size_blis_impl_ sgemm_pack_get_size_blis_impl +#define dgemm_pack_get_size_blis_impl_ dgemm_pack_get_size_blis_impl +#define sgemm_pack_blis_impl_ sgemm_pack_blis_impl +#define dgemm_pack_blis_impl_ dgemm_pack_blis_impl +#define saxpby_blis_impl_ saxpby_blis_impl +#define daxpby_blis_impl_ daxpby_blis_impl +#define caxpby_blis_impl_ caxpby_blis_impl +#define zaxpby_blis_impl_ zaxpby_blis_impl +#define cher2k_blis_impl_ cher2k_blis_impl +#define zher2k_blis_impl_ zher2k_blis_impl +#define cherk_blis_impl_ cherk_blis_impl +#define zherk_blis_impl_ zherk_blis_impl +#define ssymm_blis_impl_ ssymm_blis_impl +#define dsymm_blis_impl_ dsymm_blis_impl +#define ssyr2k_blis_impl_ ssyr2k_blis_impl +#define dsyr2k_blis_impl_ dsyr2k_blis_impl +#define ssyrk_blis_impl_ ssyrk_blis_impl +#define dsyrk_blis_impl_ dsyrk_blis_impl +#define strmm_blis_impl_ strmm_blis_impl +#define dtrmm_blis_impl_ dtrmm_blis_impl +#define ctrmm_blis_impl_ ctrmm_blis_impl +#define ztrmm_blis_impl_ ztrmm_blis_impl +#define strsm_blis_impl_ strsm_blis_impl +#define dtrsm_blis_impl_ dtrsm_blis_impl +#define ctrsm_blis_impl_ ctrsm_blis_impl +#define ztrsm_blis_impl_ ztrsm_blis_impl +#define lsame_blis_impl_ lsame_blis_impl + +#endif // BLIS_ENABLE_BLAS +#endif // BLIS_ENABLE_NO_UNDERSCORE_API + +#ifdef BLIS_ENABLE_UPPERCASE_API +#ifdef BLIS_ENABLE_BLAS + +#define caxpby_blis_impl CAXPBY_BLIS_IMPL +#define caxpy_blis_impl CAXPY_BLIS_IMPL +#define ccopy_blis_impl CCOPY_BLIS_IMPL +#define cdotc_blis_impl CDOTC_BLIS_IMPL +#define cdotcsub_blis_impl CDOTCSUB_BLIS_IMPL +#define cdotu_blis_impl CDOTU_BLIS_IMPL +#define cdotusub_blis_impl CDOTUSUB_BLIS_IMPL +#define cgbmv_blis_impl CGBMV_BLIS_IMPL +#define cgemm_blis_impl CGEMM_BLIS_IMPL +#define cgemm3m_blis_impl CGEMM3M_BLIS_IMPL +#define cgemm_batch_blis_impl CGEMM_BATCH_BLIS_IMPL +#define cgemmt_blis_impl CGEMMT_BLIS_IMPL +#define cgemv_blis_impl CGEMV_BLIS_IMPL +#define cgerc_blis_impl CGERC_BLIS_IMPL +#define cgeru_blis_impl CGERU_BLIS_IMPL +#define chbmv_blis_impl CHBMV_BLIS_IMPL +#define chemm_blis_impl CHEMM_BLIS_IMPL +#define chemv_blis_impl CHEMV_BLIS_IMPL +#define cher_blis_impl CHER_BLIS_IMPL +#define cher2_blis_impl CHER2_BLIS_IMPL +#define cher2k_blis_impl CHER2K_BLIS_IMPL +#define cherk_blis_impl CHERK_BLIS_IMPL +#define chpmv_blis_impl CHPMV_BLIS_IMPL +#define chpr_blis_impl CHPR_BLIS_IMPL +#define chpr2_blis_impl CHPR2_BLIS_IMPL +#define crotg_blis_impl CROTG_BLIS_IMPL +#define cscal_blis_impl CSCAL_BLIS_IMPL +#define csrot_blis_impl CSROT_BLIS_IMPL +#define csscal_blis_impl CSSCAL_BLIS_IMPL +#define cswap_blis_impl CSWAP_BLIS_IMPL +#define csymm_blis_impl CSYMM_BLIS_IMPL +#define csyr2k_blis_impl CSYR2K_BLIS_IMPL +#define csyrk_blis_impl CSYRK_BLIS_IMPL +#define ctbmv_blis_impl CTBMV_BLIS_IMPL +#define ctbsv_blis_impl CTBSV_BLIS_IMPL +#define ctpmv_blis_impl CTPMV_BLIS_IMPL +#define ctpsv_blis_impl CTPSV_BLIS_IMPL +#define ctrmm_blis_impl CTRMM_BLIS_IMPL +#define ctrmv_blis_impl CTRMV_BLIS_IMPL +#define ctrsm_blis_impl CTRSM_BLIS_IMPL +#define ctrsv_blis_impl CTRSV_BLIS_IMPL +#define dasum_blis_impl DASUM_BLIS_IMPL +#define dasumsub_blis_impl DASUMSUB_BLIS_IMPL +#define daxpby_blis_impl DAXPBY_BLIS_IMPL +#define daxpy_blis_impl DAXPY_BLIS_IMPL +#define dcabs1_blis_impl DCABS1_BLIS_IMPL +#define dcopy_blis_impl DCOPY_BLIS_IMPL +#define ddot_blis_impl DDOT_BLIS_IMPL +#define ddotsub_blis_impl DDOTSUB_BLIS_IMPL +#define dgbmv_blis_impl DGBMV_BLIS_IMPL +#define dgemm_blis_impl DGEMM_BLIS_IMPL +#define dgemm_batch_blis_impl DGEMM_BATCH_BLIS_IMPL +#define dgemm_compute_blis_impl DGEMM_COMPUTE_BLIS_IMPL +#define dgemm_pack_get_size_blis_impl DGEMM_PACK_GET_SIZE_BLIS_IMPL +#define dgemm_pack_blis_impl DGEMM_PACK_BLIS_IMPL +#define dgemmt_blis_impl DGEMMT_BLIS_IMPL +#define dgemv_blis_impl DGEMV_BLIS_IMPL +#define dger_blis_impl DGER_BLIS_IMPL +#define dnrm2_blis_impl DNRM2_BLIS_IMPL +#define dnrm2sub_blis_impl DNRM2SUB_BLIS_IMPL +#define drot_blis_impl DROT_BLIS_IMPL +#define drotg_blis_impl DROTG_BLIS_IMPL +#define drotm_blis_impl DROTM_BLIS_IMPL +#define drotmg_blis_impl DROTMG_BLIS_IMPL +#define dsbmv_blis_impl DSBMV_BLIS_IMPL +#define dscal_blis_impl DSCAL_BLIS_IMPL +#define dsdot_blis_impl DSDOT_BLIS_IMPL +#define dsdotsub_blis_impl DSDOTSUB_BLIS_IMPL +#define dspmv_blis_impl DSPMV_BLIS_IMPL +#define dspr_blis_impl DSPR_BLIS_IMPL +#define dspr2_blis_impl DSPR2_BLIS_IMPL +#define dswap_blis_impl DSWAP_BLIS_IMPL +#define dsymm_blis_impl DSYMM_BLIS_IMPL +#define dsymv_blis_impl DSYMV_BLIS_IMPL +#define dsyr_blis_impl DSYR_BLIS_IMPL +#define dsyr2_blis_impl DSYR2_BLIS_IMPL +#define dsyr2k_blis_impl DSYR2K_BLIS_IMPL +#define dsyrk_blis_impl DSYRK_BLIS_IMPL +#define dtbmv_blis_impl DTBMV_BLIS_IMPL +#define dtbsv_blis_impl DTBSV_BLIS_IMPL +#define dtpmv_blis_impl DTPMV_BLIS_IMPL +#define dtpsv_blis_impl DTPSV_BLIS_IMPL +#define dtrmm_blis_impl DTRMM_BLIS_IMPL +#define dtrmv_blis_impl DTRMV_BLIS_IMPL +#define dtrsm_blis_impl DTRSM_BLIS_IMPL +#define dtrsv_blis_impl DTRSV_BLIS_IMPL +#define dzasum_blis_impl DZASUM_BLIS_IMPL +#define dzasumsub_blis_impl DZASUMSUB_BLIS_IMPL +#define dznrm2_blis_impl DZNRM2_BLIS_IMPL +#define dznrm2sub_blis_impl DZNRM2SUB_BLIS_IMPL +#define icamax_blis_impl ICAMAX_BLIS_IMPL +#define icamaxsub_blis_impl ICAMAXSUB_BLIS_IMPL +#define icamin_blis_impl ICAMIN_BLIS_IMPL +#define icaminsub_blis_impl ICAMINSUB_BLIS_IMPL +#define idamax_blis_impl IDAMAX_BLIS_IMPL +#define idamaxsub_blis_impl IDAMAXSUB_BLIS_IMPL +#define idamin_blis_impl IDAMIN_BLIS_IMPL +#define idaminsub_blis_impl IDAMINSUB_BLIS_IMPL +#define isamax_blis_impl ISAMAX_BLIS_IMPL +#define isamaxsub_blis_impl ISAMAXSUB_BLIS_IMPL +#define isamin_blis_impl ISAMIN_BLIS_IMPL +#define isaminsub_blis_impl ISAMINSUB_BLIS_IMPL +#define izamax_blis_impl IZAMAX_BLIS_IMPL +#define izamaxsub_blis_impl IZAMAXSUB_BLIS_IMPL +#define izamin_blis_impl IZAMIN_BLIS_IMPL +#define izaminsub_blis_impl IZAMINSUB_BLIS_IMPL +#define lsame_blis_impl LSAME_BLIS_IMPL +#define sasum_blis_impl SASUM_BLIS_IMPL +#define sasumsub_blis_impl SASUMSUB_BLIS_IMPL +#define saxpby_blis_impl SAXPBY_BLIS_IMPL +#define saxpy_blis_impl SAXPY_BLIS_IMPL +#define scabs1_blis_impl SCABS1_BLIS_IMPL +#define scasum_blis_impl SCASUM_BLIS_IMPL +#define scasumsub_blis_impl SCASUMSUB_BLIS_IMPL +#define scnrm2_blis_impl SCNRM2_BLIS_IMPL +#define scnrm2sub_blis_impl SCNRM2SUB_BLIS_IMPL +#define scopy_blis_impl SCOPY_BLIS_IMPL +#define sdot_blis_impl SDOT_BLIS_IMPL +#define sdotsub_blis_impl SDOTSUB_BLIS_IMPL +#define sdsdot_blis_impl SDSDOT_BLIS_IMPL +#define sdsdotsub_blis_impl SDSDOTSUB_BLIS_IMPL +#define sgbmv_blis_impl SGBMV_BLIS_IMPL +#define sgemm_blis_impl SGEMM_BLIS_IMPL +#define sgemm_batch_blis_impl SGEMM_BATCH_BLIS_IMPL +#define sgemm_compute_blis_impl SGEMM_COMPUTE_BLIS_IMPL +#define sgemm_pack_get_size_blis_impl SGEMM_PACK_GET_SIZE_BLIS_IMPL +#define sgemm_pack_blis_impl SGEMM_PACK_BLIS_IMPL +#define sgemmt_blis_impl SGEMMT_BLIS_IMPL +#define sgemv_blis_impl SGEMV_BLIS_IMPL +#define sger_blis_impl SGER_BLIS_IMPL +#define snrm2_blis_impl SNRM2_BLIS_IMPL +#define snrm2sub_blis_impl SNRM2SUB_BLIS_IMPL +#define srot_blis_impl SROT_BLIS_IMPL +#define srotg_blis_impl SROTG_BLIS_IMPL +#define srotm_blis_impl SROTM_BLIS_IMPL +#define srotmg_blis_impl SROTMG_BLIS_IMPL +#define ssbmv_blis_impl SSBMV_BLIS_IMPL +#define sscal_blis_impl SSCAL_BLIS_IMPL +#define sspmv_blis_impl SSPMV_BLIS_IMPL +#define sspr_blis_impl SSPR_BLIS_IMPL +#define sspr2_blis_impl SSPR2_BLIS_IMPL +#define sswap_blis_impl SSWAP_BLIS_IMPL +#define ssymm_blis_impl SSYMM_BLIS_IMPL +#define ssymv_blis_impl SSYMV_BLIS_IMPL +#define ssyr_blis_impl SSYR_BLIS_IMPL +#define ssyr2_blis_impl SSYR2_BLIS_IMPL +#define ssyr2k_blis_impl SSYR2K_BLIS_IMPL +#define ssyrk_blis_impl SSYRK_BLIS_IMPL +#define stbmv_blis_impl STBMV_BLIS_IMPL +#define stbsv_blis_impl STBSV_BLIS_IMPL +#define stpmv_blis_impl STPMV_BLIS_IMPL +#define stpsv_blis_impl STPSV_BLIS_IMPL +#define strmm_blis_impl STRMM_BLIS_IMPL +#define strmv_blis_impl STRMV_BLIS_IMPL +#define strsm_blis_impl STRSM_BLIS_IMPL +#define strsv_blis_impl STRSV_BLIS_IMPL +#define xerbla_blis_impl XERBLA_BLIS_IMPL +#define zaxpby_blis_impl ZAXPBY_BLIS_IMPL +#define zaxpy_blis_impl ZAXPY_BLIS_IMPL +#define zcopy_blis_impl ZCOPY_BLIS_IMPL +#define zdotc_blis_impl ZDOTC_BLIS_IMPL +#define zdotcsub_blis_impl ZDOTCSUB_BLIS_IMPL +#define zdotu_blis_impl ZDOTU_BLIS_IMPL +#define zdotusub_blis_impl ZDOTUSUB_BLIS_IMPL +#define zdrot_blis_impl ZDROT_BLIS_IMPL +#define zdscal_blis_impl ZDSCAL_BLIS_IMPL +#define zgbmv_blis_impl ZGBMV_BLIS_IMPL +#define zgemm_blis_impl ZGEMM_BLIS_IMPL +#define zgemm3m_blis_impl ZGEMM3M_BLIS_IMPL +#define zgemm_batch_blis_impl ZGEMM_BATCH_BLIS_IMPL +#define zgemmt_blis_impl ZGEMMT_BLIS_IMPL +#define zgemv_blis_impl ZGEMV_BLIS_IMPL +#define zgerc_blis_impl ZGERC_BLIS_IMPL +#define zgeru_blis_impl ZGERU_BLIS_IMPL +#define zhbmv_blis_impl ZHBMV_BLIS_IMPL +#define zhemm_blis_impl ZHEMM_BLIS_IMPL +#define zhemv_blis_impl ZHEMV_BLIS_IMPL +#define zher_blis_impl ZHER_BLIS_IMPL +#define zher2_blis_impl ZHER2_BLIS_IMPL +#define zher2k_blis_impl ZHER2K_BLIS_IMPL +#define zherk_blis_impl ZHERK_BLIS_IMPL +#define zhpmv_blis_impl ZHPMV_BLIS_IMPL +#define zhpr_blis_impl ZHPR_BLIS_IMPL +#define zhpr2_blis_impl ZHPR2_BLIS_IMPL +#define zrotg_blis_impl ZROTG_BLIS_IMPL +#define zscal_blis_impl ZSCAL_BLIS_IMPL +#define zswap_blis_impl ZSWAP_BLIS_IMPL +#define zsymm_blis_impl ZSYMM_BLIS_IMPL +#define zsyr2k_blis_impl ZSYR2K_BLIS_IMPL +#define zsyrk_blis_impl ZSYRK_BLIS_IMPL +#define ztbmv_blis_impl ZTBMV_BLIS_IMPL +#define ztbsv_blis_impl ZTBSV_BLIS_IMPL +#define ztpmv_blis_impl ZTPMV_BLIS_IMPL +#define ztpsv_blis_impl ZTPSV_BLIS_IMPL +#define ztrmm_blis_impl ZTRMM_BLIS_IMPL +#define ztrmv_blis_impl ZTRMV_BLIS_IMPL +#define ztrsm_blis_impl ZTRSM_BLIS_IMPL +#define ztrsv_blis_impl ZTRSV_BLIS_IMPL + +#endif // BLIS_ENABLE_BLAS +#endif // BLIS_ENABLE_UPPERCASE_API + +#endif diff --git a/frame/include/bli_blas_interface_defs.h b/frame/include/bli_blas_interface_defs.h new file mode 100644 index 0000000000..3f872fa675 --- /dev/null +++ b/frame/include/bli_blas_interface_defs.h @@ -0,0 +1,400 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_BLAS_INTERFACE_DEFS_H +#define BLIS_BLAS_INTERFACE_DEFS_H + +#ifdef BLIS_ENABLE_NO_UNDERSCORE_API +#ifdef BLIS_ENABLE_BLAS + +#define isamax_ isamax +#define idamax_ idamax +#define icamax_ icamax +#define izamax_ izamax +#define sasum_ sasum +#define dasum_ dasum +#define scasum_ scasum +#define dzasum_ dzasum +#define saxpy_ saxpy +#define daxpy_ daxpy +#define caxpy_ caxpy +#define zaxpy_ zaxpy +#define scopy_ scopy +#define dcopy_ dcopy +#define ccopy_ ccopy +#define zcopy_ zcopy +#define sdot_ sdot +#define ddot_ ddot +#define cdotc_ cdotc +#define zdotc_ zdotc +#define cdotu_ cdotu +#define zdotu_ zdotu +#define snrm2_ snrm2 +#define dnrm2_ dnrm2 +#define scnrm2_ scnrm2 +#define dznrm2_ dznrm2 +#define sscal_ sscal +#define dscal_ dscal +#define cscal_ cscal +#define csscal_ csscal +#define zscal_ zscal +#define zdscal_ zdscal +#define sswap_ sswap +#define dswap_ dswap +#define cswap_ cswap +#define zswap_ zswap +#define sgemv_ sgemv +#define dgemv_ dgemv +#define cgemv_ cgemv +#define zgemv_ zgemv +#define sger_ sger +#define dger_ dger +#define cgerc_ cgerc +#define cgeru_ cgeru +#define zgerc_ zgerc +#define zgeru_ zgeru +#define chemv_ chemv +#define zhemv_ zhemv +#define cher_ cher +#define zher_ zher +#define cher2_ cher2 +#define zher2_ zher2 +#define ssymv_ ssymv +#define dsymv_ dsymv +#define csymm_ csymm +#define zsymm_ zsymm +#define ssyr_ ssyr +#define dsyr_ dsyr +#define csyrk_ csyrk +#define csyrk_ csyrk +#define zsyrk_ zsyrk +#define ssyr2_ ssyr2 +#define dsyr2_ dsyr2 +#define csyr2k_ csyr2k +#define zsyr2k_ zsyr2k +#define strmv_ strmv +#define dtrmv_ dtrmv +#define ctrmv_ ctrmv +#define ztrmv_ ztrmv +#define strsv_ strsv +#define dtrsv_ dtrsv +#define ctrsv_ ctrsv +#define ztrsv_ ztrsv +#define sgemm_ sgemm +#define dgemm_ dgemm +#define cgemm_ cgemm +#define zgemm_ zgemm +#define chemm_ chemm +#define zhemm_ zhemm +#define dgemmt_ dgemmt +#define sgemmt_ sgemmt +#define zgemmt_ zgemmt +#define cgemmt_ cgemmt +#define sgemm_batch_ sgemm_batch +#define dgemm_batch_ dgemm_batch +#define cgemm_batch_ cgemm_batch +#define zgemm_batch_ zgemm_batch +#define sgemm_compute_ sgemm_compute +#define dgemm_compute_ dgemm_compute +#define sgemm_pack_get_size_ sgemm_pack_get_size +#define dgemm_pack_get_size_ dgemm_pack_get_size +#define sgemm_pack_ sgemm_pack +#define dgemm_pack_ dgemm_pack +#define saxpby_ saxpby +#define daxpby_ daxpby +#define caxpby_ caxpby +#define zaxpby_ zaxpby +#define cher2k_ cher2k +#define zher2k_ zher2k +#define cherk_ cherk +#define zherk_ zherk +#define ssymm_ ssymm +#define dsymm_ dsymm +#define ssyr2k_ ssyr2k +#define dsyr2k_ dsyr2k +#define ssyrk_ ssyrk +#define dsyrk_ dsyrk +#define strmm_ strmm +#define dtrmm_ dtrmm +#define ctrmm_ ctrmm +#define ztrmm_ ztrmm +#define strsm_ strsm +#define dtrsm_ dtrsm +#define ctrsm_ ctrsm +#define ztrsm_ ztrsm +#define lsame_ lsame + +#define cimatcopy_ cimatcopy +#define comatadd_ comatadd +#define comatcopy2_ comatcopy2 +#define comatcopy_ comatcopy +#define dimatcopy_ dimatcopy +#define domatadd_ domatadd +#define domatcopy2_ domatcopy2 +#define domatcopy_ domatcopy +#define simatcopy_ simatcopy +#define somatadd_ somatadd +#define somatcopy2_ somatcopy2 +#define somatcopy_ somatcopy +#define zimatcopy_ zimatcopy +#define zomatadd_ zomatadd +#define zomatcopy2_ zomatcopy2 +#define zomatcopy_ zomatcopy + +#endif // BLIS_ENABLE_BLAS +#endif // BLIS_ENABLE_NO_UNDERSCORE_API + +#ifdef BLIS_ENABLE_UPPERCASE_API +#ifdef BLIS_ENABLE_BLAS + +#define caxpby CAXPBY +#define caxpy CAXPY +#define ccopy CCOPY +#define cdotc CDOTC +#define cdotcsub CDOTCSUB +#define cdotu CDOTU +#define cdotusub CDOTUSUB +#define cgbmv CGBMV +#define cgemm CGEMM +#define cgemm3m CGEMM3M +#define cgemm_batch CGEMM_BATCH +#define cgemmt CGEMMT +#define cgemv CGEMV +#define cgerc CGERC +#define cgeru CGERU +#define chbmv CHBMV +#define chemm CHEMM +#define chemv CHEMV +#define cher CHER +#define cher2 CHER2 +#define cher2k CHER2K +#define cherk CHERK +#define chpmv CHPMV +#define chpr CHPR +#define chpr2 CHPR2 +#define crotg CROTG +#define cscal CSCAL +#define csrot CSROT +#define csscal CSSCAL +#define cswap CSWAP +#define csymm CSYMM +#define csyr2k CSYR2K +#define csyrk CSYRK +#define ctbmv CTBMV +#define ctbsv CTBSV +#define ctpmv CTPMV +#define ctpsv CTPSV +#define ctrmm CTRMM +#define ctrmv CTRMV +#define ctrsm CTRSM +#define ctrsv CTRSV +#define dasum DASUM +#define dasumsub DASUMSUB +#define daxpby DAXPBY +#define daxpy DAXPY +#define dcabs1 DCABS1 +#define dcopy DCOPY +#define ddot DDOT +#define ddotsub DDOTSUB +#define dgbmv DGBMV +#define dgemm DGEMM +#define dgemm_batch DGEMM_BATCH +#define dgemm_compute DGEMM_COMPUTE +#define dgemm_pack_get_size DGEMM_PACK_GET_SIZE +#define dgemm_pack DGEMM_PACK +#define dgemmt DGEMMT +#define dgemv DGEMV +#define dger DGER +#define dnrm2 DNRM2 +#define dnrm2sub DNRM2SUB +#define drot DROT +#define drotg DROTG +#define drotm DROTM +#define drotmg DROTMG +#define dsbmv DSBMV +#define dscal DSCAL +#define dsdot DSDOT +#define dsdotsub DSDOTSUB +#define dspmv DSPMV +#define dspr DSPR +#define dspr2 DSPR2 +#define dswap DSWAP +#define dsymm DSYMM +#define dsymv DSYMV +#define dsyr DSYR +#define dsyr2 DSYR2 +#define dsyr2k DSYR2K +#define dsyrk DSYRK +#define dtbmv DTBMV +#define dtbsv DTBSV +#define dtpmv DTPMV +#define dtpsv DTPSV +#define dtrmm DTRMM +#define dtrmv DTRMV +#define dtrsm DTRSM +#define dtrsv DTRSV +#define dzasum DZASUM +#define dzasumsub DZASUMSUB +#define dznrm2 DZNRM2 +#define dznrm2sub DZNRM2SUB +#define icamax ICAMAX +#define icamaxsub ICAMAXSUB +#define icamin ICAMIN +#define icaminsub ICAMINSUB +#define idamax IDAMAX +#define idamaxsub IDAMAXSUB +#define idamin IDAMIN +#define idaminsub IDAMINSUB +#define isamax ISAMAX +#define isamaxsub ISAMAXSUB +#define isamin ISAMIN +#define isaminsub ISAMINSUB +#define izamax IZAMAX +#define izamaxsub IZAMAXSUB +#define izamin IZAMIN +#define izaminsub IZAMINSUB +#define lsame LSAME +#define sasum SASUM +#define sasumsub SASUMSUB +#define saxpby SAXPBY +#define saxpy SAXPY +#define scabs1 SCABS1 +#define scasum SCASUM +#define scasumsub SCASUMSUB +#define scnrm2 SCNRM2 +#define scnrm2sub SCNRM2SUB +#define scopy SCOPY +#define sdot SDOT +#define sdotsub SDOTSUB +#define sdsdot SDSDOT +#define sdsdotsub SDSDOTSUB +#define sgbmv SGBMV +#define sgemm SGEMM +#define sgemm_batch SGEMM_BATCH +#define sgemm_compute SGEMM_COMPUTE +#define sgemm_pack_get_size SGEMM_PACK_GET_SIZE +#define sgemm_pack SGEMM_PACK +#define sgemmt SGEMMT +#define sgemv SGEMV +#define sger SGER +#define snrm2 SNRM2 +#define snrm2sub SNRM2SUB +#define srot SROT +#define srotg SROTG +#define srotm SROTM +#define srotmg SROTMG +#define ssbmv SSBMV +#define sscal SSCAL +#define sspmv SSPMV +#define sspr SSPR +#define sspr2 SSPR2 +#define sswap SSWAP +#define ssymm SSYMM +#define ssymv SSYMV +#define ssyr SSYR +#define ssyr2 SSYR2 +#define ssyr2k SSYR2K +#define ssyrk SSYRK +#define stbmv STBMV +#define stbsv STBSV +#define stpmv STPMV +#define stpsv STPSV +#define strmm STRMM +#define strmv STRMV +#define strsm STRSM +#define strsv STRSV +#define xerbla XERBLA +#define zaxpby ZAXPBY +#define zaxpy ZAXPY +#define zcopy ZCOPY +#define zdotc ZDOTC +#define zdotcsub ZDOTCSUB +#define zdotu ZDOTU +#define zdotusub ZDOTUSUB +#define zdrot ZDROT +#define zdscal ZDSCAL +#define zgbmv ZGBMV +#define zgemm ZGEMM +#define zgemm3m ZGEMM3M +#define zgemm_batch ZGEMM_BATCH +#define zgemmt ZGEMMT +#define zgemv ZGEMV +#define zgerc ZGERC +#define zgeru ZGERU +#define zhbmv ZHBMV +#define zhemm ZHEMM +#define zhemv ZHEMV +#define zher ZHER +#define zher2 ZHER2 +#define zher2k ZHER2K +#define zherk ZHERK +#define zhpmv ZHPMV +#define zhpr ZHPR +#define zhpr2 ZHPR2 +#define zrotg ZROTG +#define zscal ZSCAL +#define zswap ZSWAP +#define zsymm ZSYMM +#define zsyr2k ZSYR2K +#define zsyrk ZSYRK +#define ztbmv ZTBMV +#define ztbsv ZTBSV +#define ztpmv ZTPMV +#define ztpsv ZTPSV +#define ztrmm ZTRMM +#define ztrmv ZTRMV +#define ztrsm ZTRSM +#define ztrsv ZTRSV + +#define cimatcopy CIMATCOPY +#define comatadd COMATADD +#define comatcopy2 COMATCOPY2 +#define comatcopy COMATCOPY +#define dimatcopy DIMATCOPY +#define domatadd DOMATADD +#define domatcopy2 DOMATCOPY2 +#define domatcopy DOMATCOPY +#define simatcopy SIMATCOPY +#define somatadd SOMATADD +#define somatcopy2 SOMATCOPY2 +#define somatcopy SOMATCOPY +#define zimatcopy ZIMATCOPY +#define zomatadd ZOMATADD +#define zomatcopy2 ZOMATCOPY2 +#define zomatcopy ZOMATCOPY + +#endif // BLIS_ENABLE_BLAS +#endif // BLIS_ENABLE_UPPERCASE_API + +#endif diff --git a/frame/include/bli_error_macro_defs.h b/frame/include/bli_error_macro_defs.h index a0c9ea6ab3..00d8acdcb8 100644 --- a/frame/include/bli_error_macro_defs.h +++ b/frame/include/bli_error_macro_defs.h @@ -35,12 +35,6 @@ #ifndef BLIS_ERROR_MACRO_DEFS_H #define BLIS_ERROR_MACRO_DEFS_H -// -- Error-related macros -- - -// Used to determine the size of the array of error strings. -#define BLIS_MAX_NUM_ERR_MSGS 200 -#define BLIS_MAX_ERR_MSG_LENGTH 200 - // Used to insert filenames and line numbers into error-checking code. #define bli_check_error_code( code ) \ bli_check_error_code_helper( code, __FILE__, __LINE__ ) diff --git a/frame/include/bli_gentfunc_macro_defs.h b/frame/include/bli_gentfunc_macro_defs.h index 9836819b98..940f0f2e85 100644 --- a/frame/include/bli_gentfunc_macro_defs.h +++ b/frame/include/bli_gentfunc_macro_defs.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 23, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -152,6 +152,18 @@ GENTFUNCR2( scomplex, float, c, s, blasname, blisname ) \ GENTFUNCR2( dcomplex, double, z, d, blasname, blisname ) +// -- Alternate three-operand macro (one char for complex, one for real proj +// for name, one for real proj for use) -- + + +#define INSERT_GENTFUNCR3_BLAS( blasname, blisname ) \ +\ +GENTFUNCR3( float, float, s, , s, blasname, blisname ) \ +GENTFUNCR3( double, double, d, , d, blasname, blisname ) \ +GENTFUNCR3( scomplex, float, c, s, s, blasname, blisname ) \ +GENTFUNCR3( dcomplex, double, z, d, d, blasname, blisname ) + + // -- Extended two-operand macro (used only for scal) -- #define INSERT_GENTFUNCSCAL_BLAS_C( blasname, blisname ) \ @@ -162,12 +174,12 @@ GENTFUNCSCAL( scomplex, float, c, s, blasname, blisname ) #define INSERT_GENTFUNCSCAL_BLAS( blasname, blisname ) \ \ -GENTFUNCSCAL( float, float, s, , blasname, blisname ) \ -GENTFUNCSCAL( double, double, d, , blasname, blisname ) \ -GENTFUNCSCAL( scomplex, scomplex, c, , blasname, blisname ) \ -GENTFUNCSCAL( dcomplex, dcomplex, z, , blasname, blisname ) \ -GENTFUNCSCAL( scomplex, float, c, s, blasname, blisname ) \ -GENTFUNCSCAL( dcomplex, double, z, d, blasname, blisname ) +GENTFUNCSCAL( float, float, s, , s, blasname, blisname ) \ +GENTFUNCSCAL( double, double, d, , d, blasname, blisname ) \ +GENTFUNCSCAL( scomplex, scomplex, c, , c, blasname, blisname ) \ +GENTFUNCSCAL( dcomplex, dcomplex, z, , z, blasname, blisname ) \ +GENTFUNCSCAL( scomplex, float, c, s, s, blasname, blisname ) \ +GENTFUNCSCAL( dcomplex, double, z, d, d, blasname, blisname ) // --GEMMT specific kernels ---------------------------------------------------- @@ -187,6 +199,31 @@ GENTFUNC(scomplex, c, opname, u, funcname) \ GENTFUNC(dcomplex, z, opname, u, funcname) +#define INSERT_GENTFUNC_L_SDC( opname, funcname ) \ +\ +GENTFUNC(float, s, opname, l, funcname) \ +GENTFUNC(double, d, opname, l, funcname) \ +GENTFUNC(scomplex, c, opname, l, funcname) + + +#define INSERT_GENTFUNC_U_SDC( opname, funcname ) \ +\ +GENTFUNC(float, s, opname, u, funcname) \ +GENTFUNC(double, d, opname, u, funcname) \ +GENTFUNC(scomplex, c, opname, u, funcname) + +#define INSERT_GENTFUNC_L_SC( opname, funcname ) \ +\ +GENTFUNC(float, s, opname, l, funcname) \ +GENTFUNC(scomplex, c, opname, l, funcname) + + +#define INSERT_GENTFUNC_U_SC( opname, funcname ) \ +\ +GENTFUNC(float, s, opname, u, funcname) \ +GENTFUNC(scomplex, c, opname, u, funcname) + + // -- Macros for functions with one operand ------------------------------------ diff --git a/frame/include/bli_gentprot_macro_defs.h b/frame/include/bli_gentprot_macro_defs.h index 9321077b1f..1e6223224d 100644 --- a/frame/include/bli_gentprot_macro_defs.h +++ b/frame/include/bli_gentprot_macro_defs.h @@ -1,4 +1,3 @@ - /* BLIS @@ -6,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/frame/include/bli_macro_defs.h b/frame/include/bli_macro_defs.h index f4fbeca63f..e35e48c8af 100644 --- a/frame/include/bli_macro_defs.h +++ b/frame/include/bli_macro_defs.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -99,7 +99,7 @@ #endif // Macros to define names _blis_impl suffix, *_blis_impl is the blis -// blis implmenation of the respective API's which is invoked from CBLAS +// blis implementation of the respective API's which is invoked from CBLAS // and BLAS wrapper. #define PASTEF770S(name) name ## _blis_impl #define PASTEF77S(ch1,name) ch1 ## name ## _blis_impl @@ -125,367 +125,9 @@ #include "bli_oapi_macro_defs.h" #include "bli_tapi_macro_defs.h" +// -- Include definitions for BLAS interfaces -#ifdef BLIS_ENABLE_NO_UNDERSCORE_API +#include "bli_blas_interface_defs.h" +#include "bli_blas_blis_impl_interface_defs.h" -#ifdef BLIS_ENABLE_BLAS -#define isamax_ isamax -#define idamax_ idamax -#define icamax_ icamax -#define izamax_ izamax -#define sasum_ sasum -#define dasum_ dasum -#define scasum_ scasum -#define dzasum_ dzasum -#define saxpy_ saxpy -#define daxpy_ daxpy -#define caxpy_ caxpy -#define zaxpy_ zaxpy -#define scopy_ scopy -#define dcopy_ dcopy -#define ccopy_ ccopy -#define zcopy_ zcopy -#define sdot_ sdot -#define ddot_ ddot -#define cdotc_ cdotc -#define zdotc_ zdotc -#define cdotu_ cdotu -#define zdotu_ zdotu -#define snrm2_ snrm2 -#define dnrm2_ dnrm2 -#define scnrm2_ scnrm2 -#define dznrm2_ dznrm2 -#define sscal_ sscal -#define dscal_ dscal -#define cscal_ cscal -#define csscal_ csscal -#define zscal_ zscal -#define zdscal_ zdscal -#define sswap_ sswap -#define dswap_ dswap -#define cswap_ cswap -#define zswap_ zswap -#define sgemv_ sgemv -#define dgemv_ dgemv -#define cgemv_ cgemv -#define zgemv_ zgemv -#define sger_ sger -#define dger_ dger -#define cgerc_ cgerc -#define cgeru_ cgeru -#define zgerc_ zgerc -#define zgeru_ zgeru -#define chemv_ chemv -#define zhemv_ zhemv -#define cher_ cher -#define zher_ zher -#define cher2_ cher2 -#define zher2_ zher2 -#define ssymv_ ssymv -#define dsymv_ dsymv -#define csymm_ csymm -#define zsymm_ zsymm -#define ssyr_ ssyr -#define dsyr_ dsyr -#define csyrk_ csyrk -#define csyrk_ csyrk -#define zsyrk_ zsyrk -#define ssyr2_ ssyr2 -#define dsyr2_ dsyr2 -#define csyr2k_ csyr2k -#define zsyr2k_ zsyr2k -#define strmv_ strmv -#define dtrmv_ dtrmv -#define ctrmv_ ctrmv -#define ztrmv_ ztrmv -#define strsv_ strsv -#define dtrsv_ dtrsv -#define ctrsv_ ctrsv -#define ztrsv_ ztrsv -#define sgemm_ sgemm -#define dgemm_ dgemm -#define cgemm_ cgemm -#define zgemm_ zgemm -#define chemm_ chemm -#define zhemm_ zhemm -#define dgemmt_ dgemmt -#define sgemmt_ sgemmt -#define zgemmt_ zgemmt -#define cgemmt_ cgemmt -#define sgemm_batch_ sgemm_batch -#define dgemm_batch_ dgemm_batch -#define cgemm_batch_ cgemm_batch -#define zgemm_batch_ zgemm_batch -#define sgemm_compute_ sgemm_compute -#define dgemm_compute_ dgemm_compute -#define sgemm_pack_get_size_ sgemm_pack_get_size -#define dgemm_pack_get_size_ dgemm_pack_get_size -#define sgemm_pack_ sgemm_pack -#define dgemm_pack_ dgemm_pack -#define saxpby_ saxpby -#define daxpby_ daxpby -#define caxpby_ caxpby -#define zaxpby_ zaxpby -#define cher2k_ cher2k -#define zher2k_ zher2k -#define cherk_ cherk -#define zherk_ zherk -#define ssymm_ ssymm -#define dsymm_ dsymm -#define ssyr2k_ ssyr2k -#define dsyr2k_ dsyr2k -#define ssyrk_ ssyrk -#define dsyrk_ dsyrk -#define strmm_ strmm -#define dtrmm_ dtrmm -#define ctrmm_ ctrmm -#define ztrmm_ ztrmm -#define strsm_ strsm -#define dtrsm_ dtrsm -#define ctrsm_ ctrsm -#define ztrsm_ ztrsm -#define lsame_ lsame - -#define cimatcopy_ cimatcopy -#define comatadd_ comatadd -#define comatcopy2_ comatcopy2 -#define comatcopy_ comatcopy -#define dimatcopy_ dimatcopy -#define domatadd_ domatadd -#define domatcopy2_ domatcopy2 -#define domatcopy_ domatcopy -#define simatcopy_ simatcopy -#define somatadd_ somatadd -#define somatcopy2_ somatcopy2 -#define somatcopy_ somatcopy -#define zimatcopy_ zimatcopy -#define zomatadd_ zomatadd -#define zomatcopy2_ zomatcopy2 -#define zomatcopy_ zomatcopy - -#endif // BLIS_ENABLE_BLAS -#endif // BLIS_ENABLE_NO_UNDERSCORE_API - - -#ifdef BLIS_ENABLE_UPPERCASE_API - -#ifdef BLIS_ENABLE_BLAS -#define caxpby CAXPBY -#define caxpy CAXPY -#define ccopy CCOPY -#define cdotc CDOTC -#define cdotcsub CDOTCSUB -#define cdotu CDOTU -#define cdotusub CDOTUSUB -#define cgbmv CGBMV -#define cgemm CGEMM -#define cgemm3m CGEMM3M -#define cgemm_batch CGEMM_BATCH -#define cgemmt CGEMMT -#define cgemv CGEMV -#define cgerc CGERC -#define cgeru CGERU -#define chbmv CHBMV -#define chemm CHEMM -#define chemv CHEMV -#define cher CHER -#define cher2 CHER2 -#define cher2k CHER2K -#define cherk CHERK -#define chpmv CHPMV -#define chpr CHPR -#define chpr2 CHPR2 -#define cimatcopy CIMATCOPY -#define comatadd COMATADD -#define comatcopy2 COMATCOPY2 -#define comatcopy COMATCOPY -#define crotg CROTG -#define cscal CSCAL -#define csrot CSROT -#define csscal CSSCAL -#define cswap CSWAP -#define csymm CSYMM -#define csyr2k CSYR2K -#define csyrk CSYRK -#define ctbmv CTBMV -#define ctbsv CTBSV -#define ctpmv CTPMV -#define ctpsv CTPSV -#define ctrmm CTRMM -#define ctrmv CTRMV -#define ctrsm CTRSM -#define ctrsv CTRSV -#define dasum DASUM -#define dasumsub DASUMSUB -#define daxpby DAXPBY -#define daxpy DAXPY -#define dcabs1 DCABS1 -#define dcopy DCOPY -#define ddot DDOT -#define ddotsub DDOTSUB -#define dgbmv DGBMV -#define dgemm DGEMM -#define dgemm_batch DGEMM_BATCH -#define dgemm_compute DGEMM_COMPUTE -#define dgemm_pack_get_size DGEMM_PACK_GET_SIZE -#define dgemm_pack DGEMM_PACK -#define dgemmt DGEMMT -#define dgemv DGEMV -#define dger DGER -#define dnrm2 DNRM2 -#define dnrm2sub DNRM2SUB -#define dimatcopy DIMATCOPY -#define domatadd DOMATADD -#define domatcopy2 DOMATCOPY2 -#define domatcopy DOMATCOPY -#define drot DROT -#define drotg DROTG -#define drotm DROTM -#define drotmg DROTMG -#define dsbmv DSBMV -#define dscal DSCAL -#define dsdot DSDOT -#define dsdotsub DSDOTSUB -#define dspmv DSPMV -#define dspr DSPR -#define dspr2 DSPR2 -#define dswap DSWAP -#define dsymm DSYMM -#define dsymv DSYMV -#define dsyr DSYR -#define dsyr2 DSYR2 -#define dsyr2k DSYR2K -#define dsyrk DSYRK -#define dtbmv DTBMV -#define dtbsv DTBSV -#define dtpmv DTPMV -#define dtpsv DTPSV -#define dtrmm DTRMM -#define dtrmv DTRMV -#define dtrsm DTRSM -#define dtrsv DTRSV -#define dzasum DZASUM -#define dzasumsub DZASUMSUB -#define dznrm2 DZNRM2 -#define dznrm2sub DZNRM2SUB -#define icamax ICAMAX -#define icamaxsub ICAMAXSUB -#define icamin ICAMIN -#define icaminsub ICAMINSUB -#define idamax IDAMAX -#define idamaxsub IDAMAXSUB -#define idamin IDAMIN -#define idaminsub IDAMINSUB -#define isamax ISAMAX -#define isamaxsub ISAMAXSUB -#define isamin ISAMIN -#define isaminsub ISAMINSUB -#define izamax IZAMAX -#define izamaxsub IZAMAXSUB -#define izamin IZAMIN -#define izaminsub IZAMINSUB -#define lsame LSAME -#define sasum SASUM -#define sasumsub SASUMSUB -#define saxpby SAXPBY -#define saxpy SAXPY -#define scabs1 SCABS1 -#define scasum SCASUM -#define scasumsub SCASUMSUB -#define scnrm2 SCNRM2 -#define scnrm2sub SCNRM2SUB -#define scopy SCOPY -#define sdot SDOT -#define sdotsub SDOTSUB -#define sdsdot SDSDOT -#define sdsdotsub SDSDOTSUB -#define sgbmv SGBMV -#define sgemm SGEMM -#define sgemm_batch SGEMM_BATCH -#define sgemm_compute SGEMM_COMPUTE -#define sgemm_pack_get_size SGEMM_PACK_GET_SIZE -#define sgemm_pack SGEMM_PACK -#define sgemmt SGEMMT -#define sgemv SGEMV -#define sger SGER -#define snrm2 SNRM2 -#define snrm2sub SNRM2SUB -#define simatcopy SIMATCOPY -#define somatadd SOMATADD -#define somatcopy2 SOMATCOPY2 -#define somatcopy SOMATCOPY -#define srot SROT -#define srotg SROTG -#define srotm SROTM -#define srotmg SROTMG -#define ssbmv SSBMV -#define sscal SSCAL -#define sspmv SSPMV -#define sspr SSPR -#define sspr2 SSPR2 -#define sswap SSWAP -#define ssymm SSYMM -#define ssymv SSYMV -#define ssyr SSYR -#define ssyr2 SSYR2 -#define ssyr2k SSYR2K -#define ssyrk SSYRK -#define stbmv STBMV -#define stbsv STBSV -#define stpmv STPMV -#define stpsv STPSV -#define strmm STRMM -#define strmv STRMV -#define strsm STRSM -#define strsv STRSV -#define xerbla XERBLA -#define zaxpby ZAXPBY -#define zaxpy ZAXPY -#define zcopy ZCOPY -#define zdotc ZDOTC -#define zdotcsub ZDOTCSUB -#define zdotu ZDOTU -#define zdotusub ZDOTUSUB -#define zdrot ZDROT -#define zdscal ZDSCAL -#define zgbmv ZGBMV -#define zgemm ZGEMM -#define zgemm3m ZGEMM3M -#define zgemm_batch ZGEMM_BATCH -#define zgemmt ZGEMMT -#define zgemv ZGEMV -#define zgerc ZGERC -#define zgeru ZGERU -#define zhbmv ZHBMV -#define zhemm ZHEMM -#define zhemv ZHEMV -#define zher ZHER -#define zher2 ZHER2 -#define zher2k ZHER2K -#define zherk ZHERK -#define zhpmv ZHPMV -#define zhpr ZHPR -#define zhpr2 ZHPR2 -#define zimatcopy ZIMATCOPY -#define zomatadd ZOMATADD -#define zomatcopy2 ZOMATCOPY2 -#define zomatcopy ZOMATCOPY -#define zrotg ZROTG -#define zscal ZSCAL -#define zswap ZSWAP -#define zsymm ZSYMM -#define zsyr2k ZSYR2K -#define zsyrk ZSYRK -#define ztbmv ZTBMV -#define ztbsv ZTBSV -#define ztpmv ZTPMV -#define ztpsv ZTPSV -#define ztrmm ZTRMM -#define ztrmv ZTRMV -#define ztrsm ZTRSM -#define ztrsv ZTRSV #endif - -#endif // BLIS_ENABLE_BLAS -#endif // BLIS_ENABLE_UPPERCASE_API - diff --git a/frame/include/bli_param_macro_defs.h b/frame/include/bli_param_macro_defs.h index eae4619807..dc13f93c5a 100644 --- a/frame/include/bli_param_macro_defs.h +++ b/frame/include/bli_param_macro_defs.h @@ -1000,50 +1000,6 @@ BLIS_INLINE bool bli_is_panel_packed( pack_t schema ) ( schema & BLIS_PACK_PANEL_BIT ); } -BLIS_INLINE bool bli_is_4mi_packed( pack_t schema ) -{ - return ( bool ) - ( ( schema & BLIS_PACK_FORMAT_BITS ) == BLIS_BITVAL_4MI ); -} - -BLIS_INLINE bool bli_is_3mi_packed( pack_t schema ) -{ - return ( bool ) - ( ( schema & BLIS_PACK_FORMAT_BITS ) == BLIS_BITVAL_3MI ); -} - -BLIS_INLINE bool bli_is_3ms_packed( pack_t schema ) -{ - return ( bool ) - ( ( schema & BLIS_PACK_FORMAT_BITS ) == BLIS_BITVAL_3MS ); -} - -BLIS_INLINE bool bli_is_ro_packed( pack_t schema ) -{ - return ( bool ) - ( ( schema & BLIS_PACK_FORMAT_BITS ) == BLIS_BITVAL_RO ); -} - -BLIS_INLINE bool bli_is_io_packed( pack_t schema ) -{ - return ( bool ) - ( ( schema & BLIS_PACK_FORMAT_BITS ) == BLIS_BITVAL_IO ); -} - -BLIS_INLINE bool bli_is_rpi_packed( pack_t schema ) -{ - return ( bool ) - ( ( schema & BLIS_PACK_FORMAT_BITS ) == BLIS_BITVAL_RPI ); -} - -BLIS_INLINE bool bli_is_rih_packed( pack_t schema ) -{ - return ( bool ) - ( bli_is_ro_packed( schema ) || - bli_is_io_packed( schema ) || - bli_is_rpi_packed( schema ) ); -} - BLIS_INLINE bool bli_is_1r_packed( pack_t schema ) { return ( bool ) @@ -1082,20 +1038,6 @@ BLIS_INLINE guint_t bli_pack_schema_index( pack_t schema ) } - -// pointer-related - -// Increment a pointer by an integer fraction: -// p0 + (num/dem) -// where p0 is a pointer to a datatype of size sizeof_p0. -BLIS_INLINE void_fp bli_ptr_inc_by_frac( void_fp p0, siz_t sizeof_p0, dim_t num, dim_t den ) -{ - return ( void_fp ) - ( ( char* )p0 + ( ( num * ( dim_t )sizeof_p0 ) / den ) ); -} - - - // Set dimensions, increments, effective uplo/diagoff, etc for ONE matrix // argument. diff --git a/frame/include/bli_scalar_macro_defs.h b/frame/include/bli_scalar_macro_defs.h index f8c3996430..293c80f910 100644 --- a/frame/include/bli_scalar_macro_defs.h +++ b/frame/include/bli_scalar_macro_defs.h @@ -206,37 +206,6 @@ #include "bli_set0bbs_mxn.h" -// -- 3m-specific scalar macros -- - -#include "bli_copyri3s.h" -#include "bli_copyjri3s.h" - -#include "bli_scal2ri3s.h" -#include "bli_scal2jri3s.h" - -#include "bli_scal2ri3s_mxn.h" - - -// -- 4mh/3mh-specific scalar macros -- - -// ro -#include "bli_scal2ros.h" -#include "bli_scal2jros.h" - -// io -#include "bli_scal2ios.h" -#include "bli_scal2jios.h" - -// rpi -#include "bli_scal2rpis.h" -#include "bli_scal2jrpis.h" - -#include "bli_scal2rihs_mxn.h" -#include "bli_scal2rihs_mxn_diag.h" -#include "bli_scal2rihs_mxn_uplo.h" -#include "bli_setrihs_mxn_diag.h" - - // -- 1m-specific scalar macros -- // 1e diff --git a/frame/include/bli_trsm_small_ref.h b/frame/include/bli_trsm_small_ref.h index 715db884e3..3a23e1ee98 100644 --- a/frame/include/bli_trsm_small_ref.h +++ b/frame/include/bli_trsm_small_ref.h @@ -1,3 +1,37 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + #ifdef BLIS_ENABLE_TRSM_PREINVERSION #define DIAG_ELE_INV_OPS(a, b) (a / b) #define DIAG_ELE_EVAL_OPS(a, b) (a * b) diff --git a/frame/include/bli_type_defs.h b/frame/include/bli_type_defs.h index e3355e8432..4417124b3a 100644 --- a/frame/include/bli_type_defs.h +++ b/frame/include/bli_type_defs.h @@ -6,7 +6,7 @@ Copyright (C) 2014, The University of Texas at Austin Copyright (C) 2016, Hewlett Packard Enterprise Development LP - Copyright (C) 2021 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -258,24 +258,10 @@ typedef void (*free_ft) ( void* p ); - 1 0000 01: packed by columns - 1 0000 10: packed by row panels - 1 0000 11: packed by column panels - - 1 0001 10: packed by 4m interleaved row panels - - 1 0001 11: packed by 4m interleaved column panels - - 1 0010 10: packed by 3m interleaved row panels - - 1 0010 11: packed by 3m interleaved column panels - - 1 0011 10: packed by 4m separated row panels (not used) - - 1 0011 11: packed by 4m separated column panels (not used) - - 1 0100 10: packed by 3m separated row panels - - 1 0100 11: packed by 3m separated column panels - - 1 0101 10: packed real-only row panels - - 1 0101 11: packed real-only column panels - - 1 0110 10: packed imag-only row panels - - 1 0110 11: packed imag-only column panels - - 1 0111 10: packed real+imag row panels - - 1 0111 11: packed real+imag column panels - - 1 1000 10: packed by 1m expanded row panels - - 1 1000 11: packed by 1m expanded column panels - - 1 1001 10: packed by 1m reordered row panels - - 1 1001 11: packed by 1m reordered column panels + - 1 0001 10: packed by 1m expanded row panels + - 1 0001 11: packed by 1m expanded column panels + - 1 0010 10: packed by 1m reordered row panels + - 1 0010 11: packed by 1m reordered column panels 23 Packed panel order if upper-stored - 0 == forward order if upper - 1 == reverse order if upper @@ -413,34 +399,13 @@ typedef void (*free_ft) ( void* p ); #define BLIS_BITVAL_UNIT_DIAG BLIS_UNIT_DIAG_BIT #define BLIS_BITVAL_INVERT_DIAG BLIS_INVERT_DIAG_BIT #define BLIS_BITVAL_NOT_PACKED 0x0 -#define BLIS_BITVAL_4MI ( 0x1 << BLIS_PACK_FORMAT_SHIFT ) -#define BLIS_BITVAL_3MI ( 0x2 << BLIS_PACK_FORMAT_SHIFT ) -#define BLIS_BITVAL_4MS ( 0x3 << BLIS_PACK_FORMAT_SHIFT ) -#define BLIS_BITVAL_3MS ( 0x4 << BLIS_PACK_FORMAT_SHIFT ) -#define BLIS_BITVAL_RO ( 0x5 << BLIS_PACK_FORMAT_SHIFT ) -#define BLIS_BITVAL_IO ( 0x6 << BLIS_PACK_FORMAT_SHIFT ) -#define BLIS_BITVAL_RPI ( 0x7 << BLIS_PACK_FORMAT_SHIFT ) -#define BLIS_BITVAL_1E ( 0x8 << BLIS_PACK_FORMAT_SHIFT ) -#define BLIS_BITVAL_1R ( 0x9 << BLIS_PACK_FORMAT_SHIFT ) +#define BLIS_BITVAL_1E ( 0x1 << BLIS_PACK_FORMAT_SHIFT ) +#define BLIS_BITVAL_1R ( 0x2 << BLIS_PACK_FORMAT_SHIFT ) #define BLIS_BITVAL_PACKED_UNSPEC ( BLIS_PACK_BIT ) #define BLIS_BITVAL_PACKED_ROWS ( BLIS_PACK_BIT ) #define BLIS_BITVAL_PACKED_COLUMNS ( BLIS_PACK_BIT | BLIS_PACK_RC_BIT ) #define BLIS_BITVAL_PACKED_ROW_PANELS ( BLIS_PACK_BIT | BLIS_PACK_PANEL_BIT ) #define BLIS_BITVAL_PACKED_COL_PANELS ( BLIS_PACK_BIT | BLIS_PACK_PANEL_BIT | BLIS_PACK_RC_BIT ) -#define BLIS_BITVAL_PACKED_ROW_PANELS_4MI ( BLIS_PACK_BIT | BLIS_BITVAL_4MI | BLIS_PACK_PANEL_BIT ) -#define BLIS_BITVAL_PACKED_COL_PANELS_4MI ( BLIS_PACK_BIT | BLIS_BITVAL_4MI | BLIS_PACK_PANEL_BIT | BLIS_PACK_RC_BIT ) -#define BLIS_BITVAL_PACKED_ROW_PANELS_3MI ( BLIS_PACK_BIT | BLIS_BITVAL_3MI | BLIS_PACK_PANEL_BIT ) -#define BLIS_BITVAL_PACKED_COL_PANELS_3MI ( BLIS_PACK_BIT | BLIS_BITVAL_3MI | BLIS_PACK_PANEL_BIT | BLIS_PACK_RC_BIT ) -#define BLIS_BITVAL_PACKED_ROW_PANELS_4MS ( BLIS_PACK_BIT | BLIS_BITVAL_4MS | BLIS_PACK_PANEL_BIT ) -#define BLIS_BITVAL_PACKED_COL_PANELS_4MS ( BLIS_PACK_BIT | BLIS_BITVAL_4MS | BLIS_PACK_PANEL_BIT | BLIS_PACK_RC_BIT ) -#define BLIS_BITVAL_PACKED_ROW_PANELS_3MS ( BLIS_PACK_BIT | BLIS_BITVAL_3MS | BLIS_PACK_PANEL_BIT ) -#define BLIS_BITVAL_PACKED_COL_PANELS_3MS ( BLIS_PACK_BIT | BLIS_BITVAL_3MS | BLIS_PACK_PANEL_BIT | BLIS_PACK_RC_BIT ) -#define BLIS_BITVAL_PACKED_ROW_PANELS_RO ( BLIS_PACK_BIT | BLIS_BITVAL_RO | BLIS_PACK_PANEL_BIT ) -#define BLIS_BITVAL_PACKED_COL_PANELS_RO ( BLIS_PACK_BIT | BLIS_BITVAL_RO | BLIS_PACK_PANEL_BIT | BLIS_PACK_RC_BIT ) -#define BLIS_BITVAL_PACKED_ROW_PANELS_IO ( BLIS_PACK_BIT | BLIS_BITVAL_IO | BLIS_PACK_PANEL_BIT ) -#define BLIS_BITVAL_PACKED_COL_PANELS_IO ( BLIS_PACK_BIT | BLIS_BITVAL_IO | BLIS_PACK_PANEL_BIT | BLIS_PACK_RC_BIT ) -#define BLIS_BITVAL_PACKED_ROW_PANELS_RPI ( BLIS_PACK_BIT | BLIS_BITVAL_RPI | BLIS_PACK_PANEL_BIT ) -#define BLIS_BITVAL_PACKED_COL_PANELS_RPI ( BLIS_PACK_BIT | BLIS_BITVAL_RPI | BLIS_PACK_PANEL_BIT | BLIS_PACK_RC_BIT ) #define BLIS_BITVAL_PACKED_ROW_PANELS_1E ( BLIS_PACK_BIT | BLIS_BITVAL_1E | BLIS_PACK_PANEL_BIT ) #define BLIS_BITVAL_PACKED_COL_PANELS_1E ( BLIS_PACK_BIT | BLIS_BITVAL_1E | BLIS_PACK_PANEL_BIT | BLIS_PACK_RC_BIT ) #define BLIS_BITVAL_PACKED_ROW_PANELS_1R ( BLIS_PACK_BIT | BLIS_BITVAL_1R | BLIS_PACK_PANEL_BIT ) @@ -553,20 +518,6 @@ typedef enum BLIS_PACKED_COLUMNS = BLIS_BITVAL_PACKED_COLUMNS, BLIS_PACKED_ROW_PANELS = BLIS_BITVAL_PACKED_ROW_PANELS, BLIS_PACKED_COL_PANELS = BLIS_BITVAL_PACKED_COL_PANELS, - BLIS_PACKED_ROW_PANELS_4MI = BLIS_BITVAL_PACKED_ROW_PANELS_4MI, - BLIS_PACKED_COL_PANELS_4MI = BLIS_BITVAL_PACKED_COL_PANELS_4MI, - BLIS_PACKED_ROW_PANELS_3MI = BLIS_BITVAL_PACKED_ROW_PANELS_3MI, - BLIS_PACKED_COL_PANELS_3MI = BLIS_BITVAL_PACKED_COL_PANELS_3MI, - BLIS_PACKED_ROW_PANELS_4MS = BLIS_BITVAL_PACKED_ROW_PANELS_4MS, - BLIS_PACKED_COL_PANELS_4MS = BLIS_BITVAL_PACKED_COL_PANELS_4MS, - BLIS_PACKED_ROW_PANELS_3MS = BLIS_BITVAL_PACKED_ROW_PANELS_3MS, - BLIS_PACKED_COL_PANELS_3MS = BLIS_BITVAL_PACKED_COL_PANELS_3MS, - BLIS_PACKED_ROW_PANELS_RO = BLIS_BITVAL_PACKED_ROW_PANELS_RO, - BLIS_PACKED_COL_PANELS_RO = BLIS_BITVAL_PACKED_COL_PANELS_RO, - BLIS_PACKED_ROW_PANELS_IO = BLIS_BITVAL_PACKED_ROW_PANELS_IO, - BLIS_PACKED_COL_PANELS_IO = BLIS_BITVAL_PACKED_COL_PANELS_IO, - BLIS_PACKED_ROW_PANELS_RPI = BLIS_BITVAL_PACKED_ROW_PANELS_RPI, - BLIS_PACKED_COL_PANELS_RPI = BLIS_BITVAL_PACKED_COL_PANELS_RPI, BLIS_PACKED_ROW_PANELS_1E = BLIS_BITVAL_PACKED_ROW_PANELS_1E, BLIS_PACKED_COL_PANELS_1E = BLIS_BITVAL_PACKED_COL_PANELS_1E, BLIS_PACKED_ROW_PANELS_1R = BLIS_BITVAL_PACKED_ROW_PANELS_1R, @@ -574,10 +525,8 @@ typedef enum } pack_t; // We combine row and column packing into one "type", and we start -// with BLIS_PACKED_ROW_PANELS, _COLUMN_PANELS. We also count the -// schema pair for "4ms" (4m separated), because its bit value has -// been reserved, even though we don't use it. -#define BLIS_NUM_PACK_SCHEMA_TYPES 10 +// with BLIS_PACKED_ROW_PANELS, _COLUMN_PANELS. +#define BLIS_NUM_PACK_SCHEMA_TYPES 3 // -- Pack order type -- @@ -670,12 +619,7 @@ typedef enum typedef enum { - BLIS_3MH = 0, - BLIS_3M1, - BLIS_4MH, - BLIS_4M1B, - BLIS_4M1A, - BLIS_1M, + BLIS_1M = 0, BLIS_NAT, BLIS_IND_FIRST = 0, BLIS_IND_LAST = BLIS_NAT @@ -683,13 +627,8 @@ typedef enum #define BLIS_NUM_IND_METHODS (BLIS_NAT+1) -// These are used in bli_*_oapi.c to construct the ind_t values from +// These are used in bli_l3_*_oapi.c to construct the ind_t values from // the induced method substrings that go into function names. -#define bli_3mh BLIS_3MH -#define bli_3m1 BLIS_3M1 -#define bli_4mh BLIS_4MH -#define bli_4mb BLIS_4M1B -#define bli_4m1 BLIS_4M1A #define bli_1m BLIS_1M #define bli_nat BLIS_NAT @@ -1023,6 +962,7 @@ typedef enum BLIS_ARCH_PENRYN, // AMD + BLIS_ARCH_ZEN5, BLIS_ARCH_ZEN4, BLIS_ARCH_ZEN3, BLIS_ARCH_ZEN2, @@ -1035,6 +975,7 @@ typedef enum // ARM BLIS_ARCH_ARMSVE, BLIS_ARCH_A64FX, + BLIS_ARCH_FIRESTORM, BLIS_ARCH_THUNDERX2, BLIS_ARCH_CORTEXA57, BLIS_ARCH_CORTEXA53, @@ -1063,6 +1004,10 @@ typedef enum // Default model BLIS_MODEL_DEFAULT, + // AMD Zen5 + BLIS_MODEL_TURIN, + BLIS_MODEL_TURIN_DENSE, + // AMD Zen4 BLIS_MODEL_GENOA, BLIS_MODEL_BERGAMO, @@ -1249,9 +1194,6 @@ typedef struct inc_t ps_a; inc_t ps_b; - // The type to convert to on output. - //num_t dt_on_output; - } auxinfo_t; @@ -1574,9 +1516,6 @@ typedef struct cntx_s func_t unpackm_kers[ BLIS_NUM_UNPACKM_KERS ]; ind_t method; - pack_t schema_a_block; - pack_t schema_b_panel; - pack_t schema_c_panel; } cntx_t; diff --git a/frame/include/bli_x86_asm_macros.h b/frame/include/bli_x86_asm_macros.h index a039361a1d..2d81842e90 100644 --- a/frame/include/bli_x86_asm_macros.h +++ b/frame/include/bli_x86_asm_macros.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2018, The University of Texas at Austin - Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -147,6 +147,7 @@ #define ALIGN8 ".p2align 3 \n\t" #define ALIGN16 ".p2align 4 \n\t" #define ALIGN32 ".p2align 5 \n\t" +#define ALIGN64 ".p2align 6 \n\t" #endif @@ -1153,11 +1154,13 @@ // Conversions +#define CVTSI2SD(_0, _1) INSTR_(cvtsi2sd, _0, _1) #define CVTSS2SD(_0, _1) INSTR_(cvtss2sd, _0, _1) #define CVTSD2SS(_0, _1) INSTR_(cvtsd2ss, _0, _1) #define CVTPS2PD(_0, _1) INSTR_(cvtps2pd, _0, _1) #define CVTPD2PS(_0, _1) INSTR_(cvtpd2ps, _0, _1) +#define cvtsi2sd(_0, _1) CVTSI2SD(_0, _1) #define cvtss2sd(_0, _1) CVTSS2SD(_0, _1) #define cvtsd2ss(_0, _1) CVTSD2SS(_0, _1) #define cvtps2pd(_0, _1) CVTPS2PD(_0, _1) diff --git a/frame/include/blis.h b/frame/include/blis.h index 28174a4bba..f44fffaeae 100644 --- a/frame/include/blis.h +++ b/frame/include/blis.h @@ -6,7 +6,7 @@ Copyright (C) 2014, The University of Texas at Austin Copyright (C) 2016, Hewlett Packard Enterprise Development LP - Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -200,13 +200,7 @@ extern "C" { // -- addon definitions -- -// NOTE: These definitions should not be included much earlier since an addon -// may wish to utilize other types and definitions provided by BLIS. -// TODO: Disable addon header file inclusion for windows since configure -// script is not executed, and subsequently the header file ie not generated. -#if !defined(_WIN32) && !defined(__CYGWIN__) #include "bli_addon.h" -#endif // -- sandbox implementation -- diff --git a/frame/include/level0/io/bli_scal2ios.h b/frame/include/level0/old/io/bli_scal2ios.h similarity index 100% rename from frame/include/level0/io/bli_scal2ios.h rename to frame/include/level0/old/io/bli_scal2ios.h diff --git a/frame/include/level0/io/bli_scal2jios.h b/frame/include/level0/old/io/bli_scal2jios.h similarity index 100% rename from frame/include/level0/io/bli_scal2jios.h rename to frame/include/level0/old/io/bli_scal2jios.h diff --git a/frame/include/level0/ri3/bli_copyjri3s.h b/frame/include/level0/old/ri3/bli_copyjri3s.h similarity index 100% rename from frame/include/level0/ri3/bli_copyjri3s.h rename to frame/include/level0/old/ri3/bli_copyjri3s.h diff --git a/frame/include/level0/ri3/bli_copyri3s.h b/frame/include/level0/old/ri3/bli_copyri3s.h similarity index 100% rename from frame/include/level0/ri3/bli_copyri3s.h rename to frame/include/level0/old/ri3/bli_copyri3s.h diff --git a/frame/include/level0/ri3/bli_scal2jri3s.h b/frame/include/level0/old/ri3/bli_scal2jri3s.h similarity index 100% rename from frame/include/level0/ri3/bli_scal2jri3s.h rename to frame/include/level0/old/ri3/bli_scal2jri3s.h diff --git a/frame/include/level0/ri3/bli_scal2ri3s.h b/frame/include/level0/old/ri3/bli_scal2ri3s.h similarity index 100% rename from frame/include/level0/ri3/bli_scal2ri3s.h rename to frame/include/level0/old/ri3/bli_scal2ri3s.h diff --git a/frame/include/level0/ri3/bli_scal2ri3s_mxn.h b/frame/include/level0/old/ri3/bli_scal2ri3s_mxn.h similarity index 100% rename from frame/include/level0/ri3/bli_scal2ri3s_mxn.h rename to frame/include/level0/old/ri3/bli_scal2ri3s_mxn.h diff --git a/frame/include/level0/rih/bli_scal2rihs_mxn.h b/frame/include/level0/old/rih/bli_scal2rihs_mxn.h similarity index 100% rename from frame/include/level0/rih/bli_scal2rihs_mxn.h rename to frame/include/level0/old/rih/bli_scal2rihs_mxn.h diff --git a/frame/include/level0/rih/bli_scal2rihs_mxn_diag.h b/frame/include/level0/old/rih/bli_scal2rihs_mxn_diag.h similarity index 100% rename from frame/include/level0/rih/bli_scal2rihs_mxn_diag.h rename to frame/include/level0/old/rih/bli_scal2rihs_mxn_diag.h diff --git a/frame/include/level0/rih/bli_scal2rihs_mxn_uplo.h b/frame/include/level0/old/rih/bli_scal2rihs_mxn_uplo.h similarity index 100% rename from frame/include/level0/rih/bli_scal2rihs_mxn_uplo.h rename to frame/include/level0/old/rih/bli_scal2rihs_mxn_uplo.h diff --git a/frame/include/level0/rih/bli_setrihs_mxn_diag.h b/frame/include/level0/old/rih/bli_setrihs_mxn_diag.h similarity index 100% rename from frame/include/level0/rih/bli_setrihs_mxn_diag.h rename to frame/include/level0/old/rih/bli_setrihs_mxn_diag.h diff --git a/frame/include/level0/ro/bli_scal2jros.h b/frame/include/level0/old/ro/bli_scal2jros.h similarity index 100% rename from frame/include/level0/ro/bli_scal2jros.h rename to frame/include/level0/old/ro/bli_scal2jros.h diff --git a/frame/include/level0/ro/bli_scal2ros.h b/frame/include/level0/old/ro/bli_scal2ros.h similarity index 100% rename from frame/include/level0/ro/bli_scal2ros.h rename to frame/include/level0/old/ro/bli_scal2ros.h diff --git a/frame/include/level0/rpi/bli_scal2jrpis.h b/frame/include/level0/old/rpi/bli_scal2jrpis.h similarity index 100% rename from frame/include/level0/rpi/bli_scal2jrpis.h rename to frame/include/level0/old/rpi/bli_scal2jrpis.h diff --git a/frame/include/level0/rpi/bli_scal2rpis.h b/frame/include/level0/old/rpi/bli_scal2rpis.h similarity index 100% rename from frame/include/level0/rpi/bli_scal2rpis.h rename to frame/include/level0/old/rpi/bli_scal2rpis.h diff --git a/frame/ind/cntx/bli_cntx_ind_stage.c b/frame/ind/cntx/bli_cntx_ind_stage.c deleted file mode 100644 index b5c15d5d75..0000000000 --- a/frame/ind/cntx/bli_cntx_ind_stage.c +++ /dev/null @@ -1,148 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - -typedef void (*cntx_stage_ft)( dim_t stage, cntx_t* cntx ); - -static void_fp bli_cntx_ind_stage_fp[BLIS_NUM_IND_METHODS] = -{ -/* 3mh */ bli_cntx_3mh_stage, -/* 3m1 */ bli_cntx_3m1_stage, -/* 4mh */ bli_cntx_4mh_stage, -/* 4mb */ bli_cntx_4mb_stage, -/* 4m1 */ bli_cntx_4m1_stage, -/* 1m */ bli_cntx_1m_stage, -/* nat */ bli_cntx_nat_stage -}; - - -// ----------------------------------------------------------------------------- - -// Execute the context initialization/finalization function associated -// with a given induced method. - -void bli_cntx_ind_stage( ind_t method, dim_t stage, cntx_t* cntx ) -{ - cntx_stage_ft func = bli_cntx_ind_stage_fp[ method ]; - - func( stage, cntx ); -} - -// ----------------------------------------------------------------------------- - -// These functions modify a context, if needed, for the particular "stage" of -// the induced method execution. Some induced methods do not make use of this -// feature. NOTE: ANY INDUCED METHOD THAT HAS A NON-EMPTY _stage() FUNCTION -// IS NOT THREAT-SAFE FOR APPLICATION-LEVEL THREADING. - -// ----------------------------------------------------------------------------- - -void bli_cntx_3mh_stage( dim_t stage, cntx_t* cntx ) -{ - // Set the pack_t schemas as a function of the stage of execution. - if ( stage == 0 ) - { - bli_cntx_set_schema_a_block( BLIS_PACKED_ROW_PANELS_RO, cntx ); - bli_cntx_set_schema_b_panel( BLIS_PACKED_COL_PANELS_RO, cntx ); - } - else if ( stage == 1 ) - { - bli_cntx_set_schema_a_block( BLIS_PACKED_ROW_PANELS_IO, cntx ); - bli_cntx_set_schema_b_panel( BLIS_PACKED_COL_PANELS_IO, cntx ); - } - else // if ( stage == 2 ) - { - bli_cntx_set_schema_a_block( BLIS_PACKED_ROW_PANELS_RPI, cntx ); - bli_cntx_set_schema_b_panel( BLIS_PACKED_COL_PANELS_RPI, cntx ); - } -} - -// ----------------------------------------------------------------------------- - -void bli_cntx_3m1_stage( dim_t stage, cntx_t* cntx ) -{ -} - -// ----------------------------------------------------------------------------- - -void bli_cntx_4mh_stage( dim_t stage, cntx_t* cntx ) -{ - // Set the pack_t schemas as a function of the stage of execution. - if ( stage == 0 ) - { - bli_cntx_set_schema_a_block( BLIS_PACKED_ROW_PANELS_RO, cntx ); - bli_cntx_set_schema_b_panel( BLIS_PACKED_COL_PANELS_RO, cntx ); - } - else if ( stage == 1 ) - { - bli_cntx_set_schema_a_block( BLIS_PACKED_ROW_PANELS_IO, cntx ); - bli_cntx_set_schema_b_panel( BLIS_PACKED_COL_PANELS_IO, cntx ); - } - else if ( stage == 2 ) - { - bli_cntx_set_schema_a_block( BLIS_PACKED_ROW_PANELS_RO, cntx ); - bli_cntx_set_schema_b_panel( BLIS_PACKED_COL_PANELS_IO, cntx ); - } - else // if ( stage == 3 ) - { - bli_cntx_set_schema_a_block( BLIS_PACKED_ROW_PANELS_IO, cntx ); - bli_cntx_set_schema_b_panel( BLIS_PACKED_COL_PANELS_RO, cntx ); - } -} - -// ----------------------------------------------------------------------------- - -void bli_cntx_4mb_stage( dim_t stage, cntx_t* cntx ) -{ -} - -// ----------------------------------------------------------------------------- - -void bli_cntx_4m1_stage( dim_t stage, cntx_t* cntx ) -{ -} - -// ----------------------------------------------------------------------------- - -void bli_cntx_1m_stage( dim_t stage, cntx_t* cntx ) -{ -} - -// ----------------------------------------------------------------------------- - -void bli_cntx_nat_stage( dim_t stage, cntx_t* cntx ) -{ -} - diff --git a/frame/ind/cntx/bli_cntx_ind_stage.h b/frame/ind/cntx/bli_cntx_ind_stage.h deleted file mode 100644 index 124421665a..0000000000 --- a/frame/ind/cntx/bli_cntx_ind_stage.h +++ /dev/null @@ -1,44 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -void bli_cntx_ind_stage( ind_t method, dim_t stage, cntx_t* cntx ); - -void bli_cntx_3mh_stage( dim_t stage, cntx_t* cntx ); -void bli_cntx_3m1_stage( dim_t stage, cntx_t* cntx ); -void bli_cntx_4mh_stage( dim_t stage, cntx_t* cntx ); -void bli_cntx_4mb_stage( dim_t stage, cntx_t* cntx ); -void bli_cntx_4m1_stage( dim_t stage, cntx_t* cntx ); -void bli_cntx_1m_stage( dim_t stage, cntx_t* cntx ); -void bli_cntx_nat_stage( dim_t stage, cntx_t* cntx ); - diff --git a/frame/ind/oapi/bli_l3_3m4m1m_oapi.c b/frame/ind/oapi/bli_l3_3m4m1m_oapi.c deleted file mode 100644 index 6cb5b71837..0000000000 --- a/frame/ind/oapi/bli_l3_3m4m1m_oapi.c +++ /dev/null @@ -1,443 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - -// -- gemm/her2k/syr2k --------------------------------------------------------- - -#undef GENFRONT -#define GENFRONT( opname, cname, imeth, nstage ) \ -\ -void PASTEMAC(opname,imeth) \ - ( \ - obj_t* alpha, \ - obj_t* a, \ - obj_t* b, \ - obj_t* beta, \ - obj_t* c, \ - cntx_t* cntx, \ - rntm_t* rntm \ - ) \ -{ \ - bli_init_once(); \ -\ - ind_t ind = PASTEMAC0(imeth); \ - num_t dt = bli_obj_dt( c ); \ - obj_t* beta_use = beta; \ -\ - dim_t i; \ -\ - /* If the objects are in the real domain, execute the native - implementation. */ \ - if ( bli_obj_is_real( c ) ) \ - { \ - PASTEMAC(opname,nat)( alpha, a, b, beta, c, cntx, rntm ); \ - return; \ - } \ -\ - /* A temporary hack to easily specify the 1m algorithm (block-panel or - panel-block). */ \ -/* - if ( PASTEMAC(opname,imeth) == bli_gemm1m ) \ - { \ - bli_gemm1mbp( alpha, a, b, beta, c ); \ - return; \ - } \ - else if ( PASTEMAC(opname,imeth) == bli_gemm3m1 ) \ - { \ - bli_gemm1mpb( alpha, a, b, beta, c ); \ - return; \ - } \ -*/ \ -\ - /* Query a context for the current induced method. This context is - managed and cached by the gks and should not be freed by the caller. - Note that the datatype argument is needed because it will be passed - in when bli_gks_query_ind_cntx() eventually calls the induced method's - _cntx_init() function. */ \ - cntx = bli_gks_query_ind_cntx( ind, dt ); \ -\ - /* 3mh and 4mh change the context for each stage, and so in order to - remain thread-safe, we must make a local copy of the context for - those induced methods. */ \ - cntx_t cntx_l; \ - if ( ind == BLIS_3MH || ind == BLIS_4MH ) { cntx_l = *cntx; cntx = &cntx_l; } \ -\ - /* Initialize a local runtime with global settings if necessary. Note - that in the case that a runtime is passed in, we make a local copy. */ \ - rntm_t rntm_l; \ - if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); rntm = &rntm_l; } \ - else { rntm_l = *rntm; rntm = &rntm_l; } \ -\ - /* Some induced methods execute in multiple "stages". */ \ - for ( i = 0; i < nstage; ++i ) \ - { \ - /* Prepare the context for the ith stage of computation. */ \ - bli_cntx_ind_stage( ind, i, cntx ); \ -\ - /* For multi-stage methods, use BLIS_ONE as beta after the first - stage. */ \ - if ( i > 0 ) beta_use = &BLIS_ONE; \ -\ - /* Invoke the operation's front end and request the default control - tree. */ \ - PASTEMAC(opname,_front)( alpha, a, b, beta_use, c, cntx, rntm, NULL ); \ - } \ -} - -// gemm -GENFRONT( gemm, gemm, 3mh, 3 ) -GENFRONT( gemm, gemm, 3m1, 1 ) -GENFRONT( gemm, gemm, 4mh, 4 ) -GENFRONT( gemm, gemm, 4mb, 1 ) -GENFRONT( gemm, gemm, 4m1, 1 ) -GENFRONT( gemm, gemm, 1m, 1 ) - -// her2k -GENFRONT( her2k, gemm, 3mh, 3 ) -GENFRONT( her2k, gemm, 3m1, 1 ) -GENFRONT( her2k, gemm, 4mh, 4 ) -//GENFRONT( her2k, gemm, 4mb, 1 ) // Not implemented. -GENFRONT( her2k, gemm, 4m1, 1 ) -GENFRONT( her2k, gemm, 1m, 1 ) - -// syr2k -GENFRONT( syr2k, gemm, 3mh, 3 ) -GENFRONT( syr2k, gemm, 3m1, 1 ) -GENFRONT( syr2k, gemm, 4mh, 4 ) -//GENFRONT( syr2k, gemm, 4mb, 1 ) // Not implemented. -GENFRONT( syr2k, gemm, 4m1, 1 ) -GENFRONT( syr2k, gemm, 1m, 1 ) - - -// -- hemm/symm/trmm3 ---------------------------------------------------------- - -#undef GENFRONT -#define GENFRONT( opname, cname, imeth, nstage ) \ -\ -void PASTEMAC(opname,imeth) \ - ( \ - side_t side, \ - obj_t* alpha, \ - obj_t* a, \ - obj_t* b, \ - obj_t* beta, \ - obj_t* c, \ - cntx_t* cntx, \ - rntm_t* rntm \ - ) \ -{ \ - bli_init_once(); \ -\ - ind_t ind = PASTEMAC0(imeth); \ - num_t dt = bli_obj_dt( c ); \ - obj_t* beta_use = beta; \ -\ - dim_t i; \ -\ - /* If the objects are in the real domain, execute the native - implementation. */ \ - if ( bli_obj_is_real( c ) ) \ - { \ - PASTEMAC(opname,nat)( side, alpha, a, b, beta, c, cntx, rntm ); \ - return; \ - } \ -\ - /* Query a context for the current induced method. This context is - managed and cached by the gks and should not be freed by the caller. - Note that the datatype argument is needed because it will be passed - in when bli_gks_query_ind_cntx() eventually calls the induced method's - _cntx_init() function. */ \ - cntx = bli_gks_query_ind_cntx( ind, dt ); \ -\ - /* 3mh and 4mh change the context for each stage, and so in order to - remain thread-safe, we must make a local copy of the context for - those induced methods. */ \ - cntx_t cntx_l; \ - if ( ind == BLIS_3MH || ind == BLIS_4MH ) { cntx_l = *cntx; cntx = &cntx_l; } \ -\ - /* Initialize a local runtime with global settings if necessary. Note - that in the case that a runtime is passed in, we make a local copy. */ \ - rntm_t rntm_l; \ - if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); rntm = &rntm_l; } \ - else { rntm_l = *rntm; rntm = &rntm_l; } \ -\ - /* Some induced methods execute in multiple "stages". */ \ - for ( i = 0; i < nstage; ++i ) \ - { \ - /* Prepare the context for the ith stage of computation. */ \ - bli_cntx_ind_stage( ind, i, cntx ); \ -\ - /* For multi-stage methods, use BLIS_ONE as beta after the first - stage. */ \ - if ( i > 0 ) beta_use = &BLIS_ONE; \ -\ - /* Invoke the operation's front end and request the default control - tree. */ \ - PASTEMAC(opname,_front)( side, alpha, a, b, beta_use, c, cntx, rntm, NULL ); \ - } \ -} - -// hemm -GENFRONT( hemm, gemm, 3mh, 3 ) -GENFRONT( hemm, gemm, 3m1, 1 ) -GENFRONT( hemm, gemm, 4mh, 4 ) -//GENFRONT( hemm, gemm, 4mb, 1 ) // Not implemented. -GENFRONT( hemm, gemm, 4m1, 1 ) -GENFRONT( hemm, gemm, 1m, 1 ) - -// symm -GENFRONT( symm, gemm, 3mh, 3 ) -GENFRONT( symm, gemm, 3m1, 1 ) -GENFRONT( symm, gemm, 4mh, 4 ) -//GENFRONT( symm, gemm, 4mb, 1 ) // Not implemented. -GENFRONT( symm, gemm, 4m1, 1 ) -GENFRONT( symm, gemm, 1m, 1 ) - -// trmm3 -GENFRONT( trmm3, gemm, 3mh, 3 ) -GENFRONT( trmm3, gemm, 3m1, 1 ) -GENFRONT( trmm3, gemm, 4mh, 4 ) -//GENFRONT( trmm3, gemm, 4mb, 1 ) // Not implemented. -GENFRONT( trmm3, gemm, 4m1, 1 ) -GENFRONT( trmm3, gemm, 1m, 1 ) - - -// -- herk/syrk ---------------------------------------------------------------- - -#undef GENFRONT -#define GENFRONT( opname, cname, imeth, nstage ) \ -\ -void PASTEMAC(opname,imeth) \ - ( \ - obj_t* alpha, \ - obj_t* a, \ - obj_t* beta, \ - obj_t* c, \ - cntx_t* cntx, \ - rntm_t* rntm \ - ) \ -{ \ - bli_init_once(); \ -\ - ind_t ind = PASTEMAC0(imeth); \ - num_t dt = bli_obj_dt( c ); \ - obj_t* beta_use = beta; \ -\ - dim_t i; \ -\ - /* If the objects are in the real domain, execute the native - implementation. */ \ - if ( bli_obj_is_real( c ) ) \ - { \ - PASTEMAC(opname,nat)( alpha, a, beta, c, cntx, rntm ); \ - return; \ - } \ -\ - /* Query a context for the current induced method. This context is - managed and cached by the gks and should not be freed by the caller. - Note that the datatype argument is needed because it will be passed - in when bli_gks_query_ind_cntx() eventually calls the induced method's - _cntx_init() function. */ \ - cntx = bli_gks_query_ind_cntx( ind, dt ); \ -\ - /* 3mh and 4mh change the context for each stage, and so in order to - remain thread-safe, we must make a local copy of the context for - those induced methods. */ \ - cntx_t cntx_l; \ - if ( ind == BLIS_3MH || ind == BLIS_4MH ) { cntx_l = *cntx; cntx = &cntx_l; } \ -\ - /* Initialize a local runtime with global settings if necessary. Note - that in the case that a runtime is passed in, we make a local copy. */ \ - rntm_t rntm_l; \ - if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); rntm = &rntm_l; } \ - else { rntm_l = *rntm; rntm = &rntm_l; } \ -\ - /* Some induced methods execute in multiple "stages". */ \ - for ( i = 0; i < nstage; ++i ) \ - { \ - /* Prepare the context for the ith stage of computation. */ \ - bli_cntx_ind_stage( ind, i, cntx ); \ -\ - /* For multi-stage methods, use BLIS_ONE as beta after the first - stage. */ \ - if ( i > 0 ) beta_use = &BLIS_ONE; \ -\ - /* Invoke the operation's front end and request the default control - tree. */ \ - PASTEMAC(opname,_front)( alpha, a, beta_use, c, cntx, rntm, NULL ); \ - } \ -} - -// herk -GENFRONT( herk, gemm, 3mh, 3 ) -GENFRONT( herk, gemm, 3m1, 1 ) -GENFRONT( herk, gemm, 4mh, 4 ) -//GENFRONT( herk, gemm, 4mb, 1 ) // Not implemented. -GENFRONT( herk, gemm, 4m1, 1 ) -GENFRONT( herk, gemm, 1m, 1 ) - -// syrk -GENFRONT( syrk, gemm, 3mh, 3 ) -GENFRONT( syrk, gemm, 3m1, 1 ) -GENFRONT( syrk, gemm, 4mh, 4 ) -//GENFRONT( syrk, gemm, 4mb, 1 ) // Not implemented. -GENFRONT( syrk, gemm, 4m1, 1 ) -GENFRONT( syrk, gemm, 1m, 1 ) - - -// -- trmm --------------------------------------------------------------------- - -#undef GENFRONT -#define GENFRONT( opname, cname, imeth, nstage ) \ -\ -void PASTEMAC(opname,imeth) \ - ( \ - side_t side, \ - obj_t* alpha, \ - obj_t* a, \ - obj_t* b, \ - cntx_t* cntx, \ - rntm_t* rntm \ - ) \ -{ \ - bli_init_once(); \ -\ - ind_t ind = PASTEMAC0(imeth); \ - num_t dt = bli_obj_dt( b ); \ -\ - dim_t i; \ -\ - /* If the objects are in the real domain, execute the native - implementation. */ \ - if ( bli_obj_is_real( b ) ) \ - { \ - PASTEMAC(opname,nat)( side, alpha, a, b, cntx, rntm ); \ - return; \ - } \ -\ - /* Query a context for the current induced method. This context is - managed and cached by the gks and should not be freed by the caller. - Note that the datatype argument is needed because it will be passed - in when bli_gks_query_ind_cntx() eventually calls the induced method's - _cntx_init() function. */ \ - cntx = bli_gks_query_ind_cntx( ind, dt ); \ -\ - /* Initialize a local runtime with global settings if necessary. Note - that in the case that a runtime is passed in, we make a local copy. */ \ - rntm_t rntm_l; \ - if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); rntm = &rntm_l; } \ - else { rntm_l = *rntm; rntm = &rntm_l; } \ -\ - /* Some induced methods execute in multiple "stages". */ \ - for ( i = 0; i < nstage; ++i ) \ - { \ - /* Prepare the context for the ith stage of computation. */ \ - bli_cntx_ind_stage( ind, i, cntx ); \ -\ - /* Invoke the operation's front end and request the default control - tree. */ \ - PASTEMAC(opname,_front)( side, alpha, a, b, cntx, rntm, NULL ); \ - } \ -} - -// trmm -//GENFRONT( trmm, gemm, 3mh, 3 ) // Unimplementable. -GENFRONT( trmm, gemm, 3m1, 1 ) -//GENFRONT( trmm, gemm, 4mh, 4 ) // Unimplementable. -//GENFRONT( trmm, gemm, 4mb, 1 ) // Unimplementable. -GENFRONT( trmm, gemm, 4m1, 1 ) -GENFRONT( trmm, gemm, 1m, 1 ) - - -// -- trsm --------------------------------------------------------------------- - -#undef GENFRONT -#define GENFRONT( opname, cname, imeth, nstage ) \ -\ -void PASTEMAC(opname,imeth) \ - ( \ - side_t side, \ - obj_t* alpha, \ - obj_t* a, \ - obj_t* b, \ - cntx_t* cntx, \ - rntm_t* rntm \ - ) \ -{ \ - bli_init_once(); \ -\ - ind_t ind = PASTEMAC0(imeth); \ - num_t dt = bli_obj_dt( b ); \ -\ - /* If the objects are in the real domain, execute the native - implementation. */ \ - if ( bli_obj_is_real( b ) ) \ - { \ - PASTEMAC(opname,nat)( side, alpha, a, b, cntx, rntm ); \ - return; \ - } \ -\ - /* Query a context for the current induced method. This context is - managed and cached by the gks and should not be freed by the caller. - Note that the datatype argument is needed because it will be passed - in when bli_gks_query_ind_cntx() eventually calls the induced method's - _cntx_init() function. */ \ - cntx = bli_gks_query_ind_cntx( ind, dt ); \ -\ - /* Initialize a local runtime with global settings if necessary. Note - that in the case that a runtime is passed in, we make a local copy. */ \ - rntm_t rntm_l; \ - if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); rntm = &rntm_l; } \ - else { rntm_l = *rntm; rntm = &rntm_l; } \ -\ - { \ - /* NOTE: trsm cannot be implemented via any induced method that - needs to execute in stages (e.g. 3mh, 4mh). */ \ -\ - /* Invoke the operation's front end and request the default control - tree. */ \ - PASTEMAC(opname,_front)( side, alpha, a, b, cntx, rntm, NULL ); \ - } \ -} - -// trsm -//GENFRONT( trmm, trsm, 3mh, 3 ) // Unimplementable. -GENFRONT( trsm, trsm, 3m1, 1 ) -//GENFRONT( trmm, trsm, 4mh, 4 ) // Unimplementable. -//GENFRONT( trmm, trsm, 4mb, 1 ) // Unimplementable. -GENFRONT( trsm, trsm, 4m1, 1 ) -GENFRONT( trsm, trsm, 1m, 1 ) - diff --git a/frame/ind/oapi/bli_l3_ind_oapi.c b/frame/ind/oapi/bli_l3_ind_oapi.c deleted file mode 100644 index 95a7734c0e..0000000000 --- a/frame/ind/oapi/bli_l3_ind_oapi.c +++ /dev/null @@ -1,175 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - - -// -- gemm/her2k/syr2k --------------------------------------------------------- - -#undef GENFRONT -#define GENFRONT( opname, imeth ) \ -\ -void PASTEMAC(opname,imeth) \ - ( \ - obj_t* alpha, \ - obj_t* a, \ - obj_t* b, \ - obj_t* beta, \ - obj_t* c, \ - cntx_t* cntx, \ - rntm_t* rntm \ - ) \ -{ \ - bli_init_once(); \ -\ - num_t dt = bli_obj_dt( c ); \ - PASTECH(opname,_oft) func = PASTEMAC(opname,ind_get_avail)( dt ); \ -\ - /* Initialize a local runtime with global settings if necessary. Note - that in the case that a runtime is passed in, we make a local copy. */ \ - rntm_t rntm_l; \ - if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); rntm = &rntm_l; } \ - else { rntm_l = *rntm; rntm = &rntm_l; } \ -\ - func( alpha, a, b, beta, c, cntx, rntm ); \ -} - -GENFRONT( gemm, ind ) -GENFRONT( gemmt, ind ) -GENFRONT( her2k, ind ) -GENFRONT( syr2k, ind ) - - -// -- hemm/symm/trmm3 ---------------------------------------------------------- - -#undef GENFRONT -#define GENFRONT( opname, imeth ) \ -\ -void PASTEMAC(opname,imeth) \ - ( \ - side_t side, \ - obj_t* alpha, \ - obj_t* a, \ - obj_t* b, \ - obj_t* beta, \ - obj_t* c, \ - cntx_t* cntx, \ - rntm_t* rntm \ - ) \ -{ \ - bli_init_once(); \ -\ - num_t dt = bli_obj_dt( c ); \ - PASTECH(opname,_oft) func = PASTEMAC(opname,ind_get_avail)( dt ); \ -\ - /* Initialize a local runtime with global settings if necessary. Note - that in the case that a runtime is passed in, we make a local copy. */ \ - rntm_t rntm_l; \ - if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); rntm = &rntm_l; } \ - else { rntm_l = *rntm; rntm = &rntm_l; } \ -\ - func( side, alpha, a, b, beta, c, cntx, rntm ); \ -} - -GENFRONT( hemm, ind ) -GENFRONT( symm, ind ) -GENFRONT( trmm3, ind ) - - -// -- herk/syrk ---------------------------------------------------------------- - -#undef GENFRONT -#define GENFRONT( opname, imeth ) \ -\ -void PASTEMAC(opname,imeth) \ - ( \ - obj_t* alpha, \ - obj_t* a, \ - obj_t* beta, \ - obj_t* c, \ - cntx_t* cntx, \ - rntm_t* rntm \ - ) \ -{ \ - bli_init_once(); \ -\ - num_t dt = bli_obj_dt( c ); \ - PASTECH(opname,_oft) func = PASTEMAC(opname,ind_get_avail)( dt ); \ -\ - /* Initialize a local runtime with global settings if necessary. Note - that in the case that a runtime is passed in, we make a local copy. */ \ - rntm_t rntm_l; \ - if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); rntm = &rntm_l; } \ - else { rntm_l = *rntm; rntm = &rntm_l; } \ -\ - func( alpha, a, beta, c, cntx, rntm ); \ -} - -GENFRONT( herk, ind ) -GENFRONT( syrk, ind ) - - -// -- trmm/trsm ---------------------------------------------------------------- - -#undef GENFRONT -#define GENFRONT( opname, imeth ) \ -\ -void PASTEMAC(opname,imeth) \ - ( \ - side_t side, \ - obj_t* alpha, \ - obj_t* a, \ - obj_t* b, \ - cntx_t* cntx, \ - rntm_t* rntm \ - ) \ -{ \ - bli_init_once(); \ -\ - num_t dt = bli_obj_dt( b ); \ - PASTECH(opname,_oft) func = PASTEMAC(opname,ind_get_avail)( dt ); \ -\ - /* Initialize a local runtime with global settings if necessary. Note - that in the case that a runtime is passed in, we make a local copy. */ \ - rntm_t rntm_l; \ - if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); rntm = &rntm_l; } \ - else { rntm_l = *rntm; rntm = &rntm_l; } \ -\ - func( side, alpha, a, b, cntx, rntm ); \ -} - -GENFRONT( trmm, ind ) -GENFRONT( trsm, ind ) - diff --git a/frame/ind/oapi/bli_l3_ind_oapi.h b/frame/ind/oapi/bli_l3_ind_oapi.h deleted file mode 100644 index 642bed39ff..0000000000 --- a/frame/ind/oapi/bli_l3_ind_oapi.h +++ /dev/null @@ -1,99 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - - -// -// Generate object-based prototypes for induced methods that work for -// trmm and trsm (ie: two-operand operations). -// -#undef GENPROT -#define GENPROT( imeth ) \ -\ -BLIS_EXPORT_BLIS void PASTEMAC(gemm,imeth) ( obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ -BLIS_EXPORT_BLIS void PASTEMAC(gemmt,imeth)( obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ -BLIS_EXPORT_BLIS void PASTEMAC(hemm,imeth) ( side_t side, obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ -BLIS_EXPORT_BLIS void PASTEMAC(herk,imeth) ( obj_t* alpha, obj_t* a, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ -BLIS_EXPORT_BLIS void PASTEMAC(her2k,imeth)( obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ -BLIS_EXPORT_BLIS void PASTEMAC(symm,imeth) ( side_t side, obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ -BLIS_EXPORT_BLIS void PASTEMAC(syrk,imeth) ( obj_t* alpha, obj_t* a, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ -BLIS_EXPORT_BLIS void PASTEMAC(syr2k,imeth)( obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ -BLIS_EXPORT_BLIS void PASTEMAC(trmm3,imeth)( side_t side, obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ -BLIS_EXPORT_BLIS void PASTEMAC(trmm,imeth) ( side_t side, obj_t* alpha, obj_t* a, obj_t* b, cntx_t* cntx, rntm_t* rntm ); \ -BLIS_EXPORT_BLIS void PASTEMAC(trsm,imeth) ( side_t side, obj_t* alpha, obj_t* a, obj_t* b, cntx_t* cntx, rntm_t* rntm ); \ -BLIS_EXPORT_BLIS void PASTEMAC(gemmt,imeth) ( obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ - -GENPROT( nat ) -GENPROT( ind ) -GENPROT( 3m1 ) -GENPROT( 4m1 ) -GENPROT( 1m ) - - -// -// Generate object-based prototypes for induced methods that do NOT work -// for trmm and trsm (ie: two-operand operations). -// -#undef GENPROT_NO2OP -#define GENPROT_NO2OP( imeth ) \ -\ -BLIS_EXPORT_BLIS void PASTEMAC(gemm,imeth) ( obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ -BLIS_EXPORT_BLIS void PASTEMAC(hemm,imeth) ( side_t side, obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ -BLIS_EXPORT_BLIS void PASTEMAC(herk,imeth) ( obj_t* alpha, obj_t* a, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ -BLIS_EXPORT_BLIS void PASTEMAC(her2k,imeth)( obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ -BLIS_EXPORT_BLIS void PASTEMAC(symm,imeth) ( side_t side, obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ -BLIS_EXPORT_BLIS void PASTEMAC(syrk,imeth) ( obj_t* alpha, obj_t* a, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ -BLIS_EXPORT_BLIS void PASTEMAC(syr2k,imeth)( obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ -BLIS_EXPORT_BLIS void PASTEMAC(trmm3,imeth)( side_t side, obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); - -GENPROT_NO2OP( 3mh ) -GENPROT_NO2OP( 4mh ) -GENPROT_NO2OP( 4mb ) - - -// -// Generate object-based prototypes for 1m methods that specify an algorithm -// (e.g., block-panel or panel-block). -// - -/* -#undef GENPROT -#define GENPROT( imeth, alg ) \ -\ -BLIS_EXPORT_BLIS void PASTEMAC2(gemm,imeth,alg) ( obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c ); \ -*/ - -//GENPROT( 1m, bp ) -//GENPROT( 1m, pb ) - diff --git a/frame/ind/oapi/bli_l3_nat_oapi.c b/frame/ind/oapi/bli_l3_nat_oapi.c deleted file mode 100644 index b41fa3e075..0000000000 --- a/frame/ind/oapi/bli_l3_nat_oapi.c +++ /dev/null @@ -1,243 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2020, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - -// NOTE: The function definitions in this file can be consolidated with the -// definitions for the other induced methods. The only advantage of keeping -// them separate is that it allows us to avoid the very small loop overhead -// of executing one iteration of a for loop, plus the overhead of calling a -// function that does nothing (ie: the _cntx_init_stage() function). - -// -- gemm/her2k/syr2k/gemmt --------------------------------------------------- - -#undef GENFRONT -#define GENFRONT( opname, cname, imeth ) \ -\ -void PASTEMAC(opname,imeth) \ - ( \ - obj_t* alpha, \ - obj_t* a, \ - obj_t* b, \ - obj_t* beta, \ - obj_t* c, \ - cntx_t* cntx, \ - rntm_t* rntm \ - ) \ -{ \ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_2) \ - bli_init_once(); \ -\ - /* Obtain a valid (native) context from the gks if necessary. */ \ - if ( cntx == NULL ) cntx = bli_gks_query_cntx(); \ -\ - /* Initialize a local runtime with global settings if necessary. Note - that in the case that a runtime is passed in, we make a local copy. */ \ - rntm_t rntm_l; \ - if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); rntm = &rntm_l; } \ - else { rntm_l = *rntm; rntm = &rntm_l; } \ -\ - /* Invoke the operation's front end. */ \ - PASTEMAC(opname,_front) \ - ( \ - alpha, a, b, beta, c, cntx, rntm, NULL \ - ); \ - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2) \ -} - -// If a sandbox was enabled, do not define bli_gemmnat() since it will be -// defined in the sandbox environment. -#ifndef BLIS_ENABLE_SANDBOX -GENFRONT( gemm, gemm, nat ) -GENFRONT( gemmt, gemm, nat ) -#endif -GENFRONT( her2k, gemm, nat ) -GENFRONT( syr2k, gemm, nat ) - - -// -- hemm/symm/trmm3 ---------------------------------------------------------- - -#undef GENFRONT -#define GENFRONT( opname, cname, imeth ) \ -\ -void PASTEMAC(opname,imeth) \ - ( \ - side_t side, \ - obj_t* alpha, \ - obj_t* a, \ - obj_t* b, \ - obj_t* beta, \ - obj_t* c, \ - cntx_t* cntx, \ - rntm_t* rntm \ - ) \ -{ \ - bli_init_once(); \ -\ - /* Obtain a valid (native) context from the gks if necessary. */ \ - if ( cntx == NULL ) cntx = bli_gks_query_cntx(); \ -\ - /* Initialize a local runtime with global settings if necessary. Note - that in the case that a runtime is passed in, we make a local copy. */ \ - rntm_t rntm_l; \ - if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); rntm = &rntm_l; } \ - else { rntm_l = *rntm; rntm = &rntm_l; } \ -\ - /* Invoke the operation's front end. */ \ - PASTEMAC(opname,_front) \ - ( \ - side, alpha, a, b, beta, c, cntx, rntm, NULL \ - ); \ -} - -GENFRONT( hemm, gemm, nat ) -GENFRONT( symm, gemm, nat ) -GENFRONT( trmm3, gemm, nat ) - - -// -- herk/syrk ---------------------------------------------------------------- - -#undef GENFRONT -#define GENFRONT( opname, cname, imeth ) \ -\ -void PASTEMAC(opname,imeth) \ - ( \ - obj_t* alpha, \ - obj_t* a, \ - obj_t* beta, \ - obj_t* c, \ - cntx_t* cntx, \ - rntm_t* rntm \ - ) \ -{ \ - bli_init_once(); \ -\ - /* Obtain a valid (native) context from the gks if necessary. */ \ - if ( cntx == NULL ) cntx = bli_gks_query_cntx(); \ -\ - /* Initialize a local runtime with global settings if necessary. Note - that in the case that a runtime is passed in, we make a local copy. */ \ - rntm_t rntm_l; \ - if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); rntm = &rntm_l; } \ - else { rntm_l = *rntm; rntm = &rntm_l; } \ -\ - /* Invoke the operation's front end. */ \ - PASTEMAC(opname,_front) \ - ( \ - alpha, a, beta, c, cntx, rntm, NULL \ - ); \ -} - -GENFRONT( herk, gemm, nat ) -GENFRONT( syrk, gemm, nat ) - - -// -- trmm --------------------------------------------------------------------- - -#undef GENFRONT -#define GENFRONT( opname, cname, imeth ) \ -\ -void PASTEMAC(opname,imeth) \ - ( \ - side_t side, \ - obj_t* alpha, \ - obj_t* a, \ - obj_t* b, \ - cntx_t* cntx, \ - rntm_t* rntm \ - ) \ -{ \ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_2) \ - bli_init_once(); \ -\ - /* Obtain a valid (native) context from the gks if necessary. */ \ - if ( cntx == NULL ) cntx = bli_gks_query_cntx(); \ -\ - /* Initialize a local runtime with global settings if necessary. Note - that in the case that a runtime is passed in, we make a local copy. */ \ - rntm_t rntm_l; \ - if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); rntm = &rntm_l; } \ - else { rntm_l = *rntm; rntm = &rntm_l; } \ -\ - /* Invoke the operation's front end. */ \ - PASTEMAC(opname,_front) \ - ( \ - side, alpha, a, b, cntx, rntm, NULL \ - ); \ - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2) \ -} - -GENFRONT( trmm, gemm, nat ) - - -// -- trsm --------------------------------------------------------------------- - -#undef GENFRONT -#define GENFRONT( opname, cname, imeth ) \ -\ -void PASTEMAC(opname,imeth) \ - ( \ - side_t side, \ - obj_t* alpha, \ - obj_t* a, \ - obj_t* b, \ - cntx_t* cntx, \ - rntm_t* rntm \ - ) \ -{ \ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_2) \ -\ - bli_init_once(); \ -\ - /* Obtain a valid (native) context from the gks if necessary. */ \ - if ( cntx == NULL ) cntx = bli_gks_query_cntx(); \ -\ - /* Initialize a local runtime with global settings if necessary. Note - that in the case that a runtime is passed in, we make a local copy. */ \ - rntm_t rntm_l; \ - if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); rntm = &rntm_l; } \ - else { rntm_l = *rntm; rntm = &rntm_l; } \ -\ - /* Invoke the operation's front end. */ \ - PASTEMAC(opname,_front) \ - ( \ - side, alpha, a, b, cntx, rntm, NULL \ - ); \ -\ - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2) \ -} - -GENFRONT( trsm, trsm, nat ) - diff --git a/frame/ind/tapi/bli_l3_ind_tapi.c b/frame/ind/tapi/bli_l3_ind_tapi.c deleted file mode 100644 index 9ca7746bc0..0000000000 --- a/frame/ind/tapi/bli_l3_ind_tapi.c +++ /dev/null @@ -1,664 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - - -// -- gemm --------------------------------------------------------------------- - -#undef GENTFUNC -#define GENTFUNC( ctype, ch, opname ) \ -\ -void PASTEMAC(ch,opname) \ - ( \ - trans_t transa, \ - trans_t transb, \ - dim_t m, \ - dim_t n, \ - dim_t k, \ - ctype* alpha, \ - ctype* a, inc_t rs_a, inc_t cs_a, \ - ctype* b, inc_t rs_b, inc_t cs_b, \ - ctype* beta, \ - ctype* c, inc_t rs_c, inc_t cs_c, \ - cntx_t* cntx, \ - rntm_t* rntm \ - ) \ -{ \ - bli_init_once(); \ -\ - const num_t dt = PASTEMAC(ch,type); \ -\ - obj_t alphao, ao, bo, betao, co; \ -\ - dim_t m_a, n_a; \ - dim_t m_b, n_b; \ -\ - bli_set_dims_with_trans( transa, m, k, &m_a, &n_a ); \ - bli_set_dims_with_trans( transb, k, n, &m_b, &n_b ); \ -\ - bli_obj_create_1x1_with_attached_buffer( dt, alpha, &alphao ); \ - bli_obj_create_1x1_with_attached_buffer( dt, beta, &betao ); \ -\ - bli_obj_create_with_attached_buffer( dt, m_a, n_a, a, rs_a, cs_a, &ao ); \ - bli_obj_create_with_attached_buffer( dt, m_b, n_b, b, rs_b, cs_b, &bo ); \ - bli_obj_create_with_attached_buffer( dt, m, n, c, rs_c, cs_c, &co ); \ -\ - bli_obj_set_conjtrans( transa, &ao ); \ - bli_obj_set_conjtrans( transb, &bo ); \ -\ - PASTEMAC0(opname) \ - ( \ - &alphao, \ - &ao, \ - &bo, \ - &betao, \ - &co, \ - cntx, \ - rntm \ - ); \ -} - -INSERT_GENTFUNC_BASIC0( gemm3mh ) -INSERT_GENTFUNC_BASIC0( gemm3m1 ) -INSERT_GENTFUNC_BASIC0( gemm4mh ) -INSERT_GENTFUNC_BASIC0( gemm4mb ) -INSERT_GENTFUNC_BASIC0( gemm4m1 ) -INSERT_GENTFUNC_BASIC0( gemm1m ) - - -// -- hemm --------------------------------------------------------------------- - -#undef GENTFUNC -#define GENTFUNC( ctype, ch, opname ) \ -\ -void PASTEMAC(ch,opname) \ - ( \ - side_t side, \ - uplo_t uploa, \ - conj_t conja, \ - trans_t transb, \ - dim_t m, \ - dim_t n, \ - ctype* alpha, \ - ctype* a, inc_t rs_a, inc_t cs_a, \ - ctype* b, inc_t rs_b, inc_t cs_b, \ - ctype* beta, \ - ctype* c, inc_t rs_c, inc_t cs_c, \ - cntx_t* cntx, \ - rntm_t* rntm \ - ) \ -{ \ - bli_init_once(); \ -\ - const num_t dt = PASTEMAC(ch,type); \ -\ - obj_t alphao, ao, bo, betao, co; \ -\ - dim_t mn_a; \ - dim_t m_b, n_b; \ -\ - bli_set_dim_with_side( side, m, n, &mn_a ); \ - bli_set_dims_with_trans( transb, m, n, &m_b, &n_b ); \ -\ - bli_obj_create_1x1_with_attached_buffer( dt, alpha, &alphao ); \ - bli_obj_create_1x1_with_attached_buffer( dt, beta, &betao ); \ -\ - bli_obj_create_with_attached_buffer( dt, mn_a, mn_a, a, rs_a, cs_a, &ao ); \ - bli_obj_create_with_attached_buffer( dt, m_b, n_b, b, rs_b, cs_b, &bo ); \ - bli_obj_create_with_attached_buffer( dt, m, n, c, rs_c, cs_c, &co ); \ -\ - bli_obj_set_uplo( uploa, &ao ); \ - bli_obj_set_conj( conja, &ao ); \ - bli_obj_set_conjtrans( transb, &bo ); \ -\ - bli_obj_set_struc( BLIS_HERMITIAN, &ao ); \ -\ - PASTEMAC0(opname) \ - ( \ - side, \ - &alphao, \ - &ao, \ - &bo, \ - &betao, \ - &co, \ - cntx, \ - rntm \ - ); \ -} - -INSERT_GENTFUNC_BASIC0( hemm3mh ) -INSERT_GENTFUNC_BASIC0( hemm3m1 ) -INSERT_GENTFUNC_BASIC0( hemm4mh ) -INSERT_GENTFUNC_BASIC0( hemm4m1 ) -INSERT_GENTFUNC_BASIC0( hemm1m ) - - -// -- herk --------------------------------------------------------------------- - -#undef GENTFUNCR -#define GENTFUNCR( ctype, ctype_r, ch, chr, opname ) \ -\ -void PASTEMAC(ch,opname) \ - ( \ - uplo_t uploc, \ - trans_t transa, \ - dim_t m, \ - dim_t k, \ - ctype_r* alpha, \ - ctype* a, inc_t rs_a, inc_t cs_a, \ - ctype_r* beta, \ - ctype* c, inc_t rs_c, inc_t cs_c, \ - cntx_t* cntx, \ - rntm_t* rntm \ - ) \ -{ \ - bli_init_once(); \ -\ - const num_t dt_r = PASTEMAC(chr,type); \ - const num_t dt = PASTEMAC(ch,type); \ -\ - obj_t alphao, ao, betao, co; \ -\ - dim_t m_a, n_a; \ -\ - bli_set_dims_with_trans( transa, m, k, &m_a, &n_a ); \ -\ - bli_obj_create_1x1_with_attached_buffer( dt_r, alpha, &alphao ); \ - bli_obj_create_1x1_with_attached_buffer( dt_r, beta, &betao ); \ -\ - bli_obj_create_with_attached_buffer( dt, m_a, n_a, a, rs_a, cs_a, &ao ); \ - bli_obj_create_with_attached_buffer( dt, m, m, c, rs_c, cs_c, &co ); \ -\ - bli_obj_set_uplo( uploc, &co ); \ - bli_obj_set_conjtrans( transa, &ao ); \ -\ - bli_obj_set_struc( BLIS_HERMITIAN, &co ); \ -\ - PASTEMAC0(opname) \ - ( \ - &alphao, \ - &ao, \ - &betao, \ - &co, \ - cntx, \ - rntm \ - ); \ -} - -INSERT_GENTFUNCR_BASIC0( herk3mh ) -INSERT_GENTFUNCR_BASIC0( herk3m1 ) -INSERT_GENTFUNCR_BASIC0( herk4mh ) -INSERT_GENTFUNCR_BASIC0( herk4m1 ) -INSERT_GENTFUNCR_BASIC0( herk1m ) - - -// -- her2k -------------------------------------------------------------------- - -#undef GENTFUNCR -#define GENTFUNCR( ctype, ctype_r, ch, chr, opname ) \ -\ -void PASTEMAC(ch,opname) \ - ( \ - uplo_t uploc, \ - trans_t transa, \ - trans_t transb, \ - dim_t m, \ - dim_t k, \ - ctype* alpha, \ - ctype* a, inc_t rs_a, inc_t cs_a, \ - ctype* b, inc_t rs_b, inc_t cs_b, \ - ctype_r* beta, \ - ctype* c, inc_t rs_c, inc_t cs_c, \ - cntx_t* cntx, \ - rntm_t* rntm \ - ) \ -{ \ - bli_init_once(); \ -\ - const num_t dt_r = PASTEMAC(chr,type); \ - const num_t dt = PASTEMAC(ch,type); \ -\ - obj_t alphao, ao, bo, betao, co; \ -\ - dim_t m_a, n_a; \ - dim_t m_b, n_b; \ -\ - bli_set_dims_with_trans( transa, m, k, &m_a, &n_a ); \ - bli_set_dims_with_trans( transb, m, k, &m_b, &n_b ); \ -\ - bli_obj_create_1x1_with_attached_buffer( dt, alpha, &alphao ); \ - bli_obj_create_1x1_with_attached_buffer( dt_r, beta, &betao ); \ -\ - bli_obj_create_with_attached_buffer( dt, m_a, n_a, a, rs_a, cs_a, &ao ); \ - bli_obj_create_with_attached_buffer( dt, m_b, n_b, b, rs_b, cs_b, &bo ); \ - bli_obj_create_with_attached_buffer( dt, m, m, c, rs_c, cs_c, &co ); \ -\ - bli_obj_set_uplo( uploc, &co ); \ - bli_obj_set_conjtrans( transa, &ao ); \ - bli_obj_set_conjtrans( transb, &bo ); \ -\ - bli_obj_set_struc( BLIS_HERMITIAN, &co ); \ -\ - PASTEMAC0(opname) \ - ( \ - &alphao, \ - &ao, \ - &bo, \ - &betao, \ - &co, \ - cntx, \ - rntm \ - ); \ -} - -INSERT_GENTFUNCR_BASIC0( her2k3mh ) -INSERT_GENTFUNCR_BASIC0( her2k3m1 ) -INSERT_GENTFUNCR_BASIC0( her2k4mh ) -INSERT_GENTFUNCR_BASIC0( her2k4m1 ) -INSERT_GENTFUNCR_BASIC0( her2k1m ) - - -// -- symm --------------------------------------------------------------------- - -#undef GENTFUNC -#define GENTFUNC( ctype, ch, opname ) \ -\ -void PASTEMAC(ch,opname) \ - ( \ - side_t side, \ - uplo_t uploa, \ - conj_t conja, \ - trans_t transb, \ - dim_t m, \ - dim_t n, \ - ctype* alpha, \ - ctype* a, inc_t rs_a, inc_t cs_a, \ - ctype* b, inc_t rs_b, inc_t cs_b, \ - ctype* beta, \ - ctype* c, inc_t rs_c, inc_t cs_c, \ - cntx_t* cntx, \ - rntm_t* rntm \ - ) \ -{ \ - bli_init_once(); \ -\ - const num_t dt = PASTEMAC(ch,type); \ -\ - obj_t alphao, ao, bo, betao, co; \ -\ - dim_t mn_a; \ - dim_t m_b, n_b; \ -\ - bli_set_dim_with_side( side, m, n, &mn_a ); \ - bli_set_dims_with_trans( transb, m, n, &m_b, &n_b ); \ -\ - bli_obj_create_1x1_with_attached_buffer( dt, alpha, &alphao ); \ - bli_obj_create_1x1_with_attached_buffer( dt, beta, &betao ); \ -\ - bli_obj_create_with_attached_buffer( dt, mn_a, mn_a, a, rs_a, cs_a, &ao ); \ - bli_obj_create_with_attached_buffer( dt, m_b, n_b, b, rs_b, cs_b, &bo ); \ - bli_obj_create_with_attached_buffer( dt, m, n, c, rs_c, cs_c, &co ); \ -\ - bli_obj_set_uplo( uploa, &ao ); \ - bli_obj_set_conj( conja, &ao ); \ - bli_obj_set_conjtrans( transb, &bo ); \ -\ - bli_obj_set_struc( BLIS_SYMMETRIC, &ao ); \ -\ - PASTEMAC0(opname) \ - ( \ - side, \ - &alphao, \ - &ao, \ - &bo, \ - &betao, \ - &co, \ - cntx, \ - rntm \ - ); \ -} - -INSERT_GENTFUNC_BASIC0( symm3mh ) -INSERT_GENTFUNC_BASIC0( symm3m1 ) -INSERT_GENTFUNC_BASIC0( symm4mh ) -INSERT_GENTFUNC_BASIC0( symm4m1 ) -INSERT_GENTFUNC_BASIC0( symm1m ) - - -// -- syrk --------------------------------------------------------------------- - -#undef GENTFUNC -#define GENTFUNC( ctype, ch, opname ) \ -\ -void PASTEMAC(ch,opname) \ - ( \ - uplo_t uploc, \ - trans_t transa, \ - dim_t m, \ - dim_t k, \ - ctype* alpha, \ - ctype* a, inc_t rs_a, inc_t cs_a, \ - ctype* beta, \ - ctype* c, inc_t rs_c, inc_t cs_c, \ - cntx_t* cntx, \ - rntm_t* rntm \ - ) \ -{ \ - bli_init_once(); \ -\ - const num_t dt = PASTEMAC(ch,type); \ -\ - obj_t alphao, ao, betao, co; \ -\ - dim_t m_a, n_a; \ -\ - bli_set_dims_with_trans( transa, m, k, &m_a, &n_a ); \ -\ - bli_obj_create_1x1_with_attached_buffer( dt, alpha, &alphao ); \ - bli_obj_create_1x1_with_attached_buffer( dt, beta, &betao ); \ -\ - bli_obj_create_with_attached_buffer( dt, m_a, n_a, a, rs_a, cs_a, &ao ); \ - bli_obj_create_with_attached_buffer( dt, m, m, c, rs_c, cs_c, &co ); \ -\ - bli_obj_set_uplo( uploc, &co ); \ - bli_obj_set_conjtrans( transa, &ao ); \ -\ - bli_obj_set_struc( BLIS_SYMMETRIC, &co ); \ -\ - PASTEMAC0(opname) \ - ( \ - &alphao, \ - &ao, \ - &betao, \ - &co, \ - cntx, \ - rntm \ - ); \ -} - -INSERT_GENTFUNC_BASIC0( syrk3mh ) -INSERT_GENTFUNC_BASIC0( syrk3m1 ) -INSERT_GENTFUNC_BASIC0( syrk4mh ) -INSERT_GENTFUNC_BASIC0( syrk4m1 ) -INSERT_GENTFUNC_BASIC0( syrk1m ) - - -// -- syr2k -------------------------------------------------------------------- - -#undef GENTFUNC -#define GENTFUNC( ctype, ch, opname ) \ -\ -void PASTEMAC(ch,opname) \ - ( \ - uplo_t uploc, \ - trans_t transa, \ - trans_t transb, \ - dim_t m, \ - dim_t k, \ - ctype* alpha, \ - ctype* a, inc_t rs_a, inc_t cs_a, \ - ctype* b, inc_t rs_b, inc_t cs_b, \ - ctype* beta, \ - ctype* c, inc_t rs_c, inc_t cs_c, \ - cntx_t* cntx, \ - rntm_t* rntm \ - ) \ -{ \ - bli_init_once(); \ -\ - const num_t dt = PASTEMAC(ch,type); \ -\ - obj_t alphao, ao, bo, betao, co; \ -\ - dim_t m_a, n_a; \ - dim_t m_b, n_b; \ -\ - bli_set_dims_with_trans( transa, m, k, &m_a, &n_a ); \ - bli_set_dims_with_trans( transb, m, k, &m_b, &n_b ); \ -\ - bli_obj_create_1x1_with_attached_buffer( dt, alpha, &alphao ); \ - bli_obj_create_1x1_with_attached_buffer( dt, beta, &betao ); \ -\ - bli_obj_create_with_attached_buffer( dt, m_a, n_a, a, rs_a, cs_a, &ao ); \ - bli_obj_create_with_attached_buffer( dt, m_b, n_b, b, rs_b, cs_b, &bo ); \ - bli_obj_create_with_attached_buffer( dt, m, m, c, rs_c, cs_c, &co ); \ -\ - bli_obj_set_uplo( uploc, &co ); \ - bli_obj_set_conjtrans( transa, &ao ); \ - bli_obj_set_conjtrans( transb, &bo ); \ -\ - bli_obj_set_struc( BLIS_SYMMETRIC, &co ); \ -\ - PASTEMAC0(opname) \ - ( \ - &alphao, \ - &ao, \ - &bo, \ - &betao, \ - &co, \ - cntx, \ - rntm \ - ); \ -} - -INSERT_GENTFUNC_BASIC0( syr2k3mh ) -INSERT_GENTFUNC_BASIC0( syr2k3m1 ) -INSERT_GENTFUNC_BASIC0( syr2k4mh ) -INSERT_GENTFUNC_BASIC0( syr2k4m1 ) -INSERT_GENTFUNC_BASIC0( syr2k1m ) - - -// -- trmm3 -------------------------------------------------------------------- - -#undef GENTFUNC -#define GENTFUNC( ctype, ch, opname ) \ -\ -void PASTEMAC(ch,opname) \ - ( \ - side_t side, \ - uplo_t uploa, \ - trans_t transa, \ - diag_t diaga, \ - trans_t transb, \ - dim_t m, \ - dim_t n, \ - ctype* alpha, \ - ctype* a, inc_t rs_a, inc_t cs_a, \ - ctype* b, inc_t rs_b, inc_t cs_b, \ - ctype* beta, \ - ctype* c, inc_t rs_c, inc_t cs_c, \ - cntx_t* cntx, \ - rntm_t* rntm \ - ) \ -{ \ - bli_init_once(); \ -\ - const num_t dt = PASTEMAC(ch,type); \ -\ - obj_t alphao, ao, bo, betao, co; \ -\ - dim_t mn_a; \ - dim_t m_b, n_b; \ -\ - bli_set_dim_with_side( side, m, n, &mn_a ); \ - bli_set_dims_with_trans( transb, m, n, &m_b, &n_b ); \ -\ - bli_obj_create_1x1_with_attached_buffer( dt, alpha, &alphao ); \ - bli_obj_create_1x1_with_attached_buffer( dt, beta, &betao ); \ -\ - bli_obj_create_with_attached_buffer( dt, mn_a, mn_a, a, rs_a, cs_a, &ao ); \ - bli_obj_create_with_attached_buffer( dt, m_b, n_b, b, rs_b, cs_b, &bo ); \ - bli_obj_create_with_attached_buffer( dt, m, n, c, rs_c, cs_c, &co ); \ -\ - bli_obj_set_uplo( uploa, &ao ); \ - bli_obj_set_diag( diaga, &ao ); \ - bli_obj_set_conjtrans( transa, &ao ); \ - bli_obj_set_conjtrans( transb, &bo ); \ -\ - bli_obj_set_struc( BLIS_TRIANGULAR, &ao ); \ -\ - PASTEMAC0(opname) \ - ( \ - side, \ - &alphao, \ - &ao, \ - &bo, \ - &betao, \ - &co, \ - cntx, \ - rntm \ - ); \ -} - -INSERT_GENTFUNC_BASIC0( trmm33mh ) -INSERT_GENTFUNC_BASIC0( trmm33m1 ) -INSERT_GENTFUNC_BASIC0( trmm34mh ) -INSERT_GENTFUNC_BASIC0( trmm34m1 ) -INSERT_GENTFUNC_BASIC0( trmm31m ) - - -// -- trmm --------------------------------------------------------------------- - -#undef GENTFUNC -#define GENTFUNC( ctype, ch, opname ) \ -\ -void PASTEMAC(ch,opname) \ - ( \ - side_t side, \ - uplo_t uploa, \ - trans_t transa, \ - diag_t diaga, \ - dim_t m, \ - dim_t n, \ - ctype* alpha, \ - ctype* a, inc_t rs_a, inc_t cs_a, \ - ctype* b, inc_t rs_b, inc_t cs_b, \ - cntx_t* cntx, \ - rntm_t* rntm \ - ) \ -{ \ - bli_init_once(); \ -\ - const num_t dt = PASTEMAC(ch,type); \ -\ - obj_t alphao, ao, bo; \ -\ - dim_t mn_a; \ -\ - bli_set_dim_with_side( side, m, n, &mn_a ); \ -\ - bli_obj_create_1x1_with_attached_buffer( dt, alpha, &alphao ); \ -\ - bli_obj_create_with_attached_buffer( dt, mn_a, mn_a, a, rs_a, cs_a, &ao ); \ - bli_obj_create_with_attached_buffer( dt, m, n, b, rs_b, cs_b, &bo ); \ -\ - bli_obj_set_uplo( uploa, &ao ); \ - bli_obj_set_diag( diaga, &ao ); \ - bli_obj_set_conjtrans( transa, &ao ); \ -\ - bli_obj_set_struc( BLIS_TRIANGULAR, &ao ); \ -\ - PASTEMAC0(opname) \ - ( \ - side, \ - &alphao, \ - &ao, \ - &bo, \ - cntx, \ - rntm \ - ); \ -} - -INSERT_GENTFUNC_BASIC0( trmm3m1 ) -INSERT_GENTFUNC_BASIC0( trmm4m1 ) -INSERT_GENTFUNC_BASIC0( trmm1m ) - - -// -- trsm --------------------------------------------------------------------- - -#undef GENTFUNC -#define GENTFUNC( ctype, ch, opname ) \ -\ -void PASTEMAC(ch,opname) \ - ( \ - side_t side, \ - uplo_t uploa, \ - trans_t transa, \ - diag_t diaga, \ - dim_t m, \ - dim_t n, \ - ctype* alpha, \ - ctype* a, inc_t rs_a, inc_t cs_a, \ - ctype* b, inc_t rs_b, inc_t cs_b, \ - cntx_t* cntx, \ - rntm_t* rntm \ - ) \ -{ \ - bli_init_once(); \ -\ - const num_t dt = PASTEMAC(ch,type); \ -\ - obj_t alphao, ao, bo; \ -\ - dim_t mn_a; \ -\ - bli_set_dim_with_side( side, m, n, &mn_a ); \ -\ - bli_obj_create_1x1_with_attached_buffer( dt, alpha, &alphao ); \ -\ - bli_obj_create_with_attached_buffer( dt, mn_a, mn_a, a, rs_a, cs_a, &ao ); \ - bli_obj_create_with_attached_buffer( dt, m, n, b, rs_b, cs_b, &bo ); \ -\ - bli_obj_set_uplo( uploa, &ao ); \ - bli_obj_set_diag( diaga, &ao ); \ - bli_obj_set_conjtrans( transa, &ao ); \ -\ - bli_obj_set_struc( BLIS_TRIANGULAR, &ao ); \ -\ - PASTEMAC0(opname) \ - ( \ - side, \ - &alphao, \ - &ao, \ - &bo, \ - cntx, \ - rntm \ - ); \ -} - -INSERT_GENTFUNC_BASIC0( trsm3m1 ) -INSERT_GENTFUNC_BASIC0( trsm4m1 ) -INSERT_GENTFUNC_BASIC0( trsm1m ) - diff --git a/frame/thread/bli_thrcomm.h b/frame/thread/bli_thrcomm.h index 0ea7b7531b..26ca5be311 100644 --- a/frame/thread/bli_thrcomm.h +++ b/frame/thread/bli_thrcomm.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -52,10 +52,10 @@ BLIS_INLINE dim_t bli_thrcomm_num_threads( thrcomm_t* comm ) // Thread communicator prototypes. -thrcomm_t* bli_thrcomm_create( rntm_t* rntm, dim_t n_threads ); -void bli_thrcomm_free( rntm_t* rntm, thrcomm_t* comm ); -void bli_thrcomm_init( dim_t n_threads, thrcomm_t* comm ); -void bli_thrcomm_cleanup( thrcomm_t* comm ); +BLIS_EXPORT_BLIS thrcomm_t* bli_thrcomm_create( rntm_t* rntm, dim_t n_threads ); +BLIS_EXPORT_BLIS void bli_thrcomm_free( rntm_t* rntm, thrcomm_t* comm ); +BLIS_EXPORT_BLIS void bli_thrcomm_init( dim_t n_threads, thrcomm_t* comm ); +BLIS_EXPORT_BLIS void bli_thrcomm_cleanup( thrcomm_t* comm ); BLIS_EXPORT_BLIS void bli_thrcomm_barrier( dim_t thread_id, thrcomm_t* comm ); BLIS_EXPORT_BLIS void* bli_thrcomm_bcast( dim_t inside_id, void* to_send, thrcomm_t* comm ); diff --git a/frame/thread/bli_thread.c b/frame/thread/bli_thread.c index 3f0f9a0a07..19db63b84b 100644 --- a/frame/thread/bli_thread.c +++ b/frame/thread/bli_thread.c @@ -114,7 +114,7 @@ void bli_thread_range_sub // In this function, we partition the space between all_start and // all_end into n_way partitions, each a multiple of block_factor - // with the exception of the one partition that recieves the + // with the exception of the one partition that receives the // "edge" case (if applicable). // // Here are examples of various thread partitionings, in units of @@ -1902,13 +1902,13 @@ void bli_thread_update_rntm_from_env // current status of global_rntm. Must do this every time, in case // global_rntm has been updated by blis-specific threading function calls. - // NOTE: We don't need to acquire the global_rntm_mutex here because this - // function is updating the thread local tl_rntm (not global_rntm). - bool auto_factor = FALSE; dim_t jc, pc, ic, jr, ir, nt; bool blis_mt; + // Acquire the mutex protecting global_rntm. + bli_pthread_mutex_lock( &global_rntm_mutex ); + // Extract threading data from global_rntm. nt = bli_rntm_num_threads( &global_rntm ); jc = bli_rntm_jc_ways( &global_rntm ); @@ -1918,6 +1918,9 @@ void bli_thread_update_rntm_from_env ir = bli_rntm_ir_ways( &global_rntm ); blis_mt = bli_rntm_blis_mt( &global_rntm ); + // Release the mutex protecting global_rntm. + bli_pthread_mutex_unlock( &global_rntm_mutex ); + #ifdef BLIS_ENABLE_MULTITHREADING // Environment variables BLIS_NUM_THREADS and BLIS_*_NT have been read diff --git a/frame/thread/bli_thread.h b/frame/thread/bli_thread.h index 0f67ab7cd0..b06ee8242c 100644 --- a/frame/thread/bli_thread.h +++ b/frame/thread/bli_thread.h @@ -6,8 +6,8 @@ Copyright (C) 2014, The University of Texas at Austin Copyright (C) 2016, Hewlett Packard Enterprise Development LP - Copyright (C) 2018 - 23, Advanced Micro Devices, Inc. All rights reserved. - + Copyright (C) 2018 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/frame/util/bli_util.h b/frame/util/bli_util.h index 0ee84b8e15..57f59b85e3 100644 --- a/frame/util/bli_util.h +++ b/frame/util/bli_util.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -33,6 +33,9 @@ */ +#ifndef BLI_UTIL_H_ +#define BLI_UTIL_H_ + #include "bli_util_check.h" // Prototype object APIs (expert and non-expert). @@ -68,5 +71,11 @@ // and without underscore, lowercase without underscore. #include "bli_util_api_wrap.h" +// Header file define different formats of BLAS APIs- uppercase with +// and without underscore, lowercase without underscore. +#include "bli_util_api_wrap_blis_impl.h" + // Public interface for the progress feature #include "bli_util_progress.h" + +#endif // BLI_UTIL_H_ diff --git a/frame/util/bli_util_api_wrap.c b/frame/util/bli_util_api_wrap.c index f2521bd047..3af1f34024 100644 --- a/frame/util/bli_util_api_wrap.c +++ b/frame/util/bli_util_api_wrap.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -779,6 +779,7 @@ void DZGEMM_( const f77_char *transa, const f77_char *transb, const f77_int *m, { dzgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } + void DGEMV(const char *trans,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) { dgemv_blis_impl( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); @@ -1853,19 +1854,19 @@ void STRSV_(const char *uplo,const char *trans,const char *diag,const f77_ strsv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); } -int XERBLA(const char *srname,const f77_int *info, ftnlen n) +void XERBLA(const char *srname,const f77_int *info, ftnlen n) { - return xerbla_blis_impl( srname, info, n); + xerbla_blis_impl( srname, info, n); } -int XERBLA_(const char *srname,const f77_int *info, ftnlen n) +void XERBLA_(const char *srname,const f77_int *info, ftnlen n) { - return xerbla_blis_impl( srname, info, n); + xerbla_blis_impl( srname, info, n); } -int xerbla(const char *srname,const f77_int *info, ftnlen n) +void xerbla(const char *srname,const f77_int *info, ftnlen n) { - return xerbla_blis_impl( srname, info, n); + xerbla_blis_impl( srname, info, n); } void ZAXPY(const f77_int *n,const dcomplex *za,const dcomplex *zx,const f77_int *incx,dcomplex *zy,const f77_int *incy) diff --git a/frame/util/bli_util_api_wrap.h b/frame/util/bli_util_api_wrap.h index f4a1d49492..0f71491dbc 100644 --- a/frame/util/bli_util_api_wrap.h +++ b/frame/util/bli_util_api_wrap.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -32,6 +32,9 @@ */ +#ifndef BLI_UTIL_API_WRAP_H_ +#define BLI_UTIL_API_WRAP_H_ + #ifdef BLIS_ENABLE_BLAS // file define different formats of BLAS APIs- uppercase with @@ -322,7 +325,7 @@ BLIS_EXPORT_BLIS scomplex CDOTC(const f77_int* n, const scomplex* x, const f77 BLIS_EXPORT_BLIS scomplex cdotc(const f77_int* n, const scomplex* x, const f77_int* incx, const scomplex* y, const f77_int* incy); -BLIS_EXPORT_BLIS scomplex CDOTC_ (const f77_int* n, const scomplex* x, const f77_int* incx, const scomplex* y, const f77_int* incy); +BLIS_EXPORT_BLIS scomplex CDOTC_(const f77_int* n, const scomplex* x, const f77_int* incx, const scomplex* y, const f77_int* incy); @@ -336,15 +339,15 @@ BLIS_EXPORT_BLIS scomplex CDOTU_(const f77_int* n, const scomplex* x, const f7 BLIS_EXPORT_BLIS dcomplex ZDOTC(const f77_int* n, const dcomplex* x, const f77_int* incx, const dcomplex* y, const f77_int* incy); -BLIS_EXPORT_BLIS dcomplex zdotc (const f77_int* n, const dcomplex* x, const f77_int* incx, const dcomplex* y, const f77_int* incy); +BLIS_EXPORT_BLIS dcomplex zdotc(const f77_int* n, const dcomplex* x, const f77_int* incx, const dcomplex* y, const f77_int* incy); -BLIS_EXPORT_BLIS dcomplex ZDOTC_ (const f77_int* n, const dcomplex* x, const f77_int* incx, const dcomplex* y, const f77_int* incy); +BLIS_EXPORT_BLIS dcomplex ZDOTC_(const f77_int* n, const dcomplex* x, const f77_int* incx, const dcomplex* y, const f77_int* incy); BLIS_EXPORT_BLIS dcomplex ZDOTU(const f77_int* n, const dcomplex* x, const f77_int* incx, const dcomplex* y, const f77_int* incy); -BLIS_EXPORT_BLIS dcomplex zdotu (const f77_int* n, const dcomplex* x, const f77_int* incx, const dcomplex* y, const f77_int* incy); +BLIS_EXPORT_BLIS dcomplex zdotu(const f77_int* n, const dcomplex* x, const f77_int* incx, const dcomplex* y, const f77_int* incy); BLIS_EXPORT_BLIS dcomplex ZDOTU_(const f77_int* n, const dcomplex* x, const f77_int* incx, const dcomplex* y, const f77_int* incy); @@ -1486,11 +1489,11 @@ BLIS_EXPORT_BLIS f77_int LSAME_(const char *ca, const char *cb, const f77_in -BLIS_EXPORT_BLIS int XERBLA(const char *srname, const f77_int *info, ftnlen n); +BLIS_EXPORT_BLIS void XERBLA(const char *srname, const f77_int *info, ftnlen n); -BLIS_EXPORT_BLIS int xerbla(const char *srname, const f77_int *info, ftnlen n); +BLIS_EXPORT_BLIS void xerbla(const char *srname, const f77_int *info, ftnlen n); -BLIS_EXPORT_BLIS int XERBLA_(const char *srname, const f77_int *info, ftnlen n); +BLIS_EXPORT_BLIS void XERBLA_(const char *srname, const f77_int *info, ftnlen n); @@ -1797,3 +1800,5 @@ BLIS_EXPORT_BLIS void ZOMATCOPY_(f77_char* trans, f77_int* rows, f77_int* cols #endif #endif // BLIS_ENABLE_BLAS + +#endif // BLI_UTIL_API_WRAP_H_ diff --git a/frame/util/bli_util_api_wrap_blis_impl.c b/frame/util/bli_util_api_wrap_blis_impl.c new file mode 100644 index 0000000000..886d9500be --- /dev/null +++ b/frame/util/bli_util_api_wrap_blis_impl.c @@ -0,0 +1,3172 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +// file define different formats of BLAS APIs- uppercase with +// and without underscore, lowercase without underscore. + +#include "blis.h" +#include "bli_util_api_wrap.h" + +// wrapper functions to support additional symbols +#ifndef BLIS_ENABLE_NO_UNDERSCORE_API +#ifndef BLIS_ENABLE_UPPERCASE_API +void CAXPY_BLIS_IMPL(const f77_int *n,const scomplex *ca,const scomplex *cx,const f77_int *incx,scomplex *cy,const f77_int *incy) +{ + caxpy_blis_impl( n, ca, cx, incx, cy, incy); +} + +void caxpy_blis_impl_(const f77_int *n,const scomplex *ca,const scomplex *cx,const f77_int *incx,scomplex *cy,const f77_int *incy) +{ + caxpy_blis_impl( n, ca, cx, incx, cy, incy); +} + +void CAXPY_BLIS_IMPL_(const f77_int *n,const scomplex *ca,const scomplex *cx,const f77_int *incx,scomplex *cy,const f77_int *incy) +{ + caxpy_blis_impl( n, ca, cx, incx, cy, incy); +} + +void CCOPY_BLIS_IMPL(const f77_int *n,const scomplex *cx,const f77_int *incx,scomplex *cy,const f77_int *incy) +{ + ccopy_blis_impl( n, cx, incx, cy, incy); +} + +void ccopy_blis_impl_(const f77_int *n,const scomplex *cx,const f77_int *incx,scomplex *cy,const f77_int *incy) +{ + ccopy_blis_impl( n, cx, incx, cy, incy); +} + +void CCOPY_BLIS_IMPL_(const f77_int *n,const scomplex *cx,const f77_int *incx,scomplex *cy,const f77_int *incy) +{ + ccopy_blis_impl( n, cx, incx, cy, incy); +} + +#ifdef BLIS_DISABLE_COMPLEX_RETURN_INTEL +scomplex CDOTC_BLIS_IMPL(const f77_int* n,const scomplex* x, const f77_int* incx,const scomplex* y, const f77_int* incy) +{ + return cdotc_blis_impl( n, x, incx, y, incy); +} + +scomplex cdotc_blis_impl_(const f77_int* n,const scomplex* x, const f77_int* incx,const scomplex* y, const f77_int* incy) +{ + return cdotc_blis_impl( n, x, incx, y, incy); +} + +scomplex CDOTC_BLIS_IMPL_(const f77_int* n,const scomplex* x, const f77_int* incx,const scomplex* y, const f77_int* incy) +{ + return cdotc_blis_impl( n, x, incx, y, incy); +} + +scomplex CDOTU_BLIS_IMPL(const f77_int* n,const scomplex* x, const f77_int* incx,const scomplex* y, const f77_int* incy) +{ + return cdotu_blis_impl( n, x, incx, y, incy); +} + +scomplex cdotu_blis_impl_(const f77_int* n,const scomplex* x, const f77_int* incx,const scomplex* y, const f77_int* incy) +{ + return cdotu_blis_impl( n, x, incx, y, incy); +} + +scomplex CDOTU_BLIS_IMPL_(const f77_int* n,const scomplex* x, const f77_int* incx,const scomplex* y, const f77_int* incy) +{ + return cdotu_blis_impl( n, x, incx, y, incy); +} + +dcomplex ZDOTC_BLIS_IMPL(const f77_int* n, const dcomplex* x, const f77_int* incx, const dcomplex* y, const f77_int* incy) +{ + return zdotc_blis_impl( n, x, incx, y, incy); +} + +dcomplex zdotc_blis_impl_(const f77_int* n, const dcomplex* x, const f77_int* incx, const dcomplex* y, const f77_int* incy) +{ + return zdotc_blis_impl( n, x, incx, y, incy); +} + +dcomplex ZDOTC_BLIS_IMPL_(const f77_int* n, const dcomplex* x, const f77_int* incx, const dcomplex* y, const f77_int* incy) +{ + return zdotc_blis_impl( n, x, incx, y, incy); +} + +dcomplex ZDOTU_BLIS_IMPL(const f77_int* n, const dcomplex* x, const f77_int* incx, const dcomplex* y, const f77_int* incy) +{ + return zdotu_blis_impl( n, x, incx, y, incy); +} + +dcomplex zdotu_blis_impl_(const f77_int* n, const dcomplex* x, const f77_int* incx, const dcomplex* y, const f77_int* incy) +{ + return zdotu_blis_impl( n, x, incx, y, incy); +} + +dcomplex ZDOTU_BLIS_IMPL_(const f77_int* n, const dcomplex* x, const f77_int* incx, const dcomplex* y, const f77_int* incy) +{ + return zdotu_blis_impl( n, x, incx, y, incy); +} +#else +void CDOTC_BLIS_IMPL(scomplex* retval,const f77_int *n, const scomplex *cx, const f77_int *incx, const scomplex *cy, const f77_int *incy) +{ + cdotc_blis_impl( retval, n, cx, incx, cy, incy); +} + +void cdotc_blis_impl_(scomplex* retval,const f77_int *n, const scomplex *cx, const f77_int *incx, const scomplex *cy, const f77_int *incy) +{ + cdotc_blis_impl( retval, n, cx, incx, cy, incy); +} + +void CDOTC_BLIS_IMPL_(scomplex* retval,const f77_int *n, const scomplex *cx, const f77_int *incx, const scomplex *cy, const f77_int *incy) +{ + cdotc_blis_impl( retval, n, cx, incx, cy, incy); +} + +void CDOTU_BLIS_IMPL(scomplex* retval,const f77_int *n, const scomplex *cx, const f77_int *incx, const scomplex *cy, const f77_int *incy) +{ + cdotu_blis_impl( retval, n, cx, incx, cy, incy); +} + +void cdotu_blis_impl_(scomplex* retval,const f77_int *n, const scomplex *cx, const f77_int *incx, const scomplex *cy, const f77_int *incy) +{ + cdotu_blis_impl( retval, n, cx, incx, cy, incy); +} + +void CDOTU_BLIS_IMPL_(scomplex* retval,const f77_int *n, const scomplex *cx, const f77_int *incx, const scomplex *cy, const f77_int *incy) +{ + cdotu_blis_impl( retval, n, cx, incx, cy, incy); +} + +void ZDOTC_BLIS_IMPL(dcomplex* retval,const f77_int *n, const dcomplex *zx, const f77_int *incx, const dcomplex *zy, const f77_int *incy) +{ + zdotc_blis_impl( retval, n, zx, incx, zy, incy); +} + +void zdotc_blis_impl_(dcomplex* retval,const f77_int *n, const dcomplex *zx, const f77_int *incx, const dcomplex *zy, const f77_int *incy) +{ + zdotc_blis_impl( retval, n, zx, incx, zy, incy); +} + +void ZDOTC_BLIS_IMPL_(dcomplex* retval,const f77_int *n, const dcomplex *zx, const f77_int *incx, const dcomplex *zy, const f77_int *incy) +{ + zdotc_blis_impl( retval, n, zx, incx, zy, incy); +} + +void ZDOTU_BLIS_IMPL(dcomplex* retval,const f77_int *n, const dcomplex *zx, const f77_int *incx, const dcomplex *zy, const f77_int *incy) +{ + zdotu_blis_impl( retval, n, zx, incx, zy, incy); +} + +void zdotu_blis_impl_(dcomplex* retval,const f77_int *n, const dcomplex *zx, const f77_int *incx, const dcomplex *zy, const f77_int *incy) +{ + zdotu_blis_impl( retval, n, zx, incx, zy, incy); +} + +void ZDOTU_BLIS_IMPL_(dcomplex* retval,const f77_int *n, const dcomplex *zx, const f77_int *incx, const dcomplex *zy, const f77_int *incy) +{ + zdotu_blis_impl( retval, n, zx, incx, zy, incy); +} +#endif + +void CGBMV_BLIS_IMPL(const char *trans,const f77_int *m,const f77_int *n,const f77_int *kl,const f77_int *ku,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) +{ + cgbmv_blis_impl( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); +} + +void cgbmv_blis_impl_(const char *trans,const f77_int *m,const f77_int *n,const f77_int *kl,const f77_int *ku,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) +{ + cgbmv_blis_impl( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); +} + +void CGBMV_BLIS_IMPL_(const char *trans,const f77_int *m,const f77_int *n,const f77_int *kl,const f77_int *ku,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) +{ + cgbmv_blis_impl( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); +} + +void CGEMM_BLIS_IMPL(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) +{ + cgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void cgemm_blis_impl_(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) +{ + cgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void CGEMM_BLIS_IMPL_(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) +{ + cgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void CGEMV_BLIS_IMPL(const char *trans,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) +{ + cgemv_blis_impl( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); +} + +void cgemv_blis_impl_(const char *trans,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) +{ + cgemv_blis_impl( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); +} + +void CGEMV_BLIS_IMPL_(const char *trans,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) +{ + cgemv_blis_impl( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); +} + +void CGERC_BLIS_IMPL(const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *x,const f77_int *incx,const scomplex *y,const f77_int *incy,scomplex *a,const f77_int *lda) +{ + cgerc_blis_impl( m, n, alpha, x, incx, y, incy, a, lda); +} + +void cgerc_blis_impl_(const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *x,const f77_int *incx,const scomplex *y,const f77_int *incy,scomplex *a,const f77_int *lda) +{ + cgerc_blis_impl( m, n, alpha, x, incx, y, incy, a, lda); +} + +void CGERC_BLIS_IMPL_(const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *x,const f77_int *incx,const scomplex *y,const f77_int *incy,scomplex *a,const f77_int *lda) +{ + cgerc_blis_impl( m, n, alpha, x, incx, y, incy, a, lda); +} + +void CGERU_BLIS_IMPL(const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *x,const f77_int *incx,const scomplex *y,const f77_int *incy,scomplex *a,const f77_int *lda) +{ + cgeru_blis_impl( m, n, alpha, x, incx, y, incy, a, lda); +} + +void cgeru_blis_impl_(const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *x,const f77_int *incx,const scomplex *y,const f77_int *incy,scomplex *a,const f77_int *lda) +{ + cgeru_blis_impl( m, n, alpha, x, incx, y, incy, a, lda); +} + +void CGERU_BLIS_IMPL_(const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *x,const f77_int *incx,const scomplex *y,const f77_int *incy,scomplex *a,const f77_int *lda) +{ + cgeru_blis_impl( m, n, alpha, x, incx, y, incy, a, lda); +} + +void CHBMV_BLIS_IMPL(const char *uplo,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) +{ + chbmv_blis_impl( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); +} + +void chbmv_blis_impl_(const char *uplo,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) +{ + chbmv_blis_impl( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); +} + +void CHBMV_BLIS_IMPL_(const char *uplo,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) +{ + chbmv_blis_impl( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); +} + +void CHEMM_BLIS_IMPL(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) +{ + chemm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void chemm_blis_impl_(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) +{ + chemm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void CHEMM_BLIS_IMPL_(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) +{ + chemm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void CHEMV_BLIS_IMPL(const char *uplo,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) +{ + chemv_blis_impl( uplo, n, alpha, a, lda, x, incx, beta, y, incy); +} + +void chemv_blis_impl_(const char *uplo,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) +{ + chemv_blis_impl( uplo, n, alpha, a, lda, x, incx, beta, y, incy); +} + +void CHEMV_BLIS_IMPL_(const char *uplo,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) +{ + chemv_blis_impl( uplo, n, alpha, a, lda, x, incx, beta, y, incy); +} + +void CHER_BLIS_IMPL(const char *uplo,const f77_int *n,const float *alpha,const scomplex *x,const f77_int *incx,scomplex *a,const f77_int *lda) +{ + cher_blis_impl( uplo, n, alpha, x, incx, a, lda); +} + +void cher_blis_impl_(const char *uplo,const f77_int *n,const float *alpha,const scomplex *x,const f77_int *incx,scomplex *a,const f77_int *lda) +{ + cher_blis_impl( uplo, n, alpha, x, incx, a, lda); +} + +void CHER_BLIS_IMPL_(const char *uplo,const f77_int *n,const float *alpha,const scomplex *x,const f77_int *incx,scomplex *a,const f77_int *lda) +{ + cher_blis_impl( uplo, n, alpha, x, incx, a, lda); +} + +void CHER2_BLIS_IMPL(const char *uplo,const f77_int *n,const scomplex *alpha,const scomplex *x,const f77_int *incx,const scomplex *y,const f77_int *incy,scomplex *a,const f77_int *lda) +{ + cher2_blis_impl( uplo, n, alpha, x, incx, y, incy, a, lda); +} + +void cher2_blis_impl_(const char *uplo,const f77_int *n,const scomplex *alpha,const scomplex *x,const f77_int *incx,const scomplex *y,const f77_int *incy,scomplex *a,const f77_int *lda) +{ + cher2_blis_impl( uplo, n, alpha, x, incx, y, incy, a, lda); +} + +void CHER2_BLIS_IMPL_(const char *uplo,const f77_int *n,const scomplex *alpha,const scomplex *x,const f77_int *incx,const scomplex *y,const f77_int *incy,scomplex *a,const f77_int *lda) +{ + cher2_blis_impl( uplo, n, alpha, x, incx, y, incy, a, lda); +} + +void CHER2K_BLIS_IMPL(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const float *beta,scomplex *c,const f77_int *ldc) +{ + cher2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void cher2k_blis_impl_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const float *beta,scomplex *c,const f77_int *ldc) +{ + cher2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void CHER2K_BLIS_IMPL_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const float *beta,scomplex *c,const f77_int *ldc) +{ + cher2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void CHERK_BLIS_IMPL(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const float *alpha,const scomplex *a,const f77_int *lda,const float *beta,scomplex *c,const f77_int *ldc) +{ + cherk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); +} + +void cherk_blis_impl_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const float *alpha,const scomplex *a,const f77_int *lda,const float *beta,scomplex *c,const f77_int *ldc) +{ + cherk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); +} + +void CHERK_BLIS_IMPL_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const float *alpha,const scomplex *a,const f77_int *lda,const float *beta,scomplex *c,const f77_int *ldc) +{ + cherk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); +} + +void CHPMV_BLIS_IMPL(const char *uplo,const f77_int *n,const scomplex *alpha,const scomplex *ap,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) +{ + chpmv_blis_impl( uplo, n, alpha, ap, x, incx, beta, y, incy); +} + +void chpmv_blis_impl_(const char *uplo,const f77_int *n,const scomplex *alpha,const scomplex *ap,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) +{ + chpmv_blis_impl( uplo, n, alpha, ap, x, incx, beta, y, incy); +} + +void CHPMV_BLIS_IMPL_(const char *uplo,const f77_int *n,const scomplex *alpha,const scomplex *ap,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) +{ + chpmv_blis_impl( uplo, n, alpha, ap, x, incx, beta, y, incy); +} + +void CHPR_BLIS_IMPL(const char *uplo,const f77_int *n,const float *alpha,const scomplex *x,const f77_int *incx,scomplex *ap) +{ + chpr_blis_impl( uplo, n, alpha, x, incx, ap); +} + +void chpr_blis_impl_(const char *uplo,const f77_int *n,const float *alpha,const scomplex *x,const f77_int *incx,scomplex *ap) +{ + chpr_blis_impl( uplo, n, alpha, x, incx, ap); +} + +void CHPR_BLIS_IMPL_(const char *uplo,const f77_int *n,const float *alpha,const scomplex *x,const f77_int *incx,scomplex *ap) +{ + chpr_blis_impl( uplo, n, alpha, x, incx, ap); +} + +void CHPR2_BLIS_IMPL(const char *uplo,const f77_int *n,const scomplex *alpha,const scomplex *x,const f77_int *incx,const scomplex *y,const f77_int *incy,scomplex *ap) +{ + chpr2_blis_impl( uplo, n, alpha, x, incx, y, incy, ap); +} + +void chpr2_blis_impl_(const char *uplo,const f77_int *n,const scomplex *alpha,const scomplex *x,const f77_int *incx,const scomplex *y,const f77_int *incy,scomplex *ap) +{ + chpr2_blis_impl( uplo, n, alpha, x, incx, y, incy, ap); +} + +void CHPR2_BLIS_IMPL_(const char *uplo,const f77_int *n,const scomplex *alpha,const scomplex *x,const f77_int *incx,const scomplex *y,const f77_int *incy,scomplex *ap) +{ + chpr2_blis_impl( uplo, n, alpha, x, incx, y, incy, ap); +} + +void CROTG_BLIS_IMPL(scomplex *ca, bla_scomplex *cb, bla_real *c,scomplex *s) +{ + crotg_blis_impl( ca, cb, c, s); +} + +void crotg_blis_impl_(scomplex *ca, bla_scomplex *cb, bla_real *c,scomplex *s) +{ + crotg_blis_impl( ca, cb, c, s); +} + +void CROTG_BLIS_IMPL_(scomplex *ca, bla_scomplex *cb, bla_real *c,scomplex *s) +{ + crotg_blis_impl( ca, cb, c, s); +} + +void CSCAL_BLIS_IMPL(const f77_int *n,const scomplex *ca,scomplex *cx,const f77_int *incx) +{ + cscal_blis_impl( n, ca, cx, incx); +} + +void cscal_blis_impl_(const f77_int *n,const scomplex *ca,scomplex *cx,const f77_int *incx) +{ + cscal_blis_impl( n, ca, cx, incx); +} + +void CSCAL_BLIS_IMPL_(const f77_int *n,const scomplex *ca,scomplex *cx,const f77_int *incx) +{ + cscal_blis_impl( n, ca, cx, incx); +} + +void CSROT_BLIS_IMPL(const f77_int *n,scomplex *cx,const f77_int *incx,scomplex *cy,const f77_int *incy,const float *c,const float *s) +{ + csrot_blis_impl( n, cx, incx, cy, incy, c, s); +} + +void csrot_blis_impl_(const f77_int *n,scomplex *cx,const f77_int *incx,scomplex *cy,const f77_int *incy,const float *c,const float *s) +{ + csrot_blis_impl( n, cx, incx, cy, incy, c, s); +} + +void CSROT_BLIS_IMPL_(const f77_int *n,scomplex *cx,const f77_int *incx,scomplex *cy,const f77_int *incy,const float *c,const float *s) +{ + csrot_blis_impl( n, cx, incx, cy, incy, c, s); +} + +void CSSCAL_BLIS_IMPL(const f77_int *n,const float *sa,scomplex *cx,const f77_int *incx) +{ + csscal_blis_impl( n, sa, cx, incx); +} + +void csscal_blis_impl_(const f77_int *n,const float *sa,scomplex *cx,const f77_int *incx) +{ + csscal_blis_impl( n, sa, cx, incx); +} + +void CSSCAL_BLIS_IMPL_(const f77_int *n,const float *sa,scomplex *cx,const f77_int *incx) +{ + csscal_blis_impl( n, sa, cx, incx); +} + +void CSWAP_BLIS_IMPL(const f77_int *n,scomplex *cx,const f77_int *incx,scomplex *cy,const f77_int *incy) +{ + cswap_blis_impl( n, cx, incx, cy, incy); +} + +void cswap_blis_impl_(const f77_int *n,scomplex *cx,const f77_int *incx,scomplex *cy,const f77_int *incy) +{ + cswap_blis_impl( n, cx, incx, cy, incy); +} + +void CSWAP_BLIS_IMPL_(const f77_int *n,scomplex *cx,const f77_int *incx,scomplex *cy,const f77_int *incy) +{ + cswap_blis_impl( n, cx, incx, cy, incy); +} + +void CSYMM_BLIS_IMPL(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) +{ + csymm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void csymm_blis_impl_(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) +{ + csymm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void CSYMM_BLIS_IMPL_(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) +{ + csymm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void CSYR2K_BLIS_IMPL(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) +{ + csyr2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void csyr2k_blis_impl_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) +{ + csyr2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void CSYR2K_BLIS_IMPL_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) +{ + csyr2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void CSYRK_BLIS_IMPL(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *beta,scomplex *c,const f77_int *ldc) +{ + csyrk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); +} + +void csyrk_blis_impl_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *beta,scomplex *c,const f77_int *ldc) +{ + csyrk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); +} + +void CSYRK_BLIS_IMPL_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *beta,scomplex *c,const f77_int *ldc) +{ + csyrk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); +} + +void CTBMV_BLIS_IMPL(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const scomplex *a,const f77_int *lda,scomplex *x,const f77_int *incx) +{ + ctbmv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); +} + +void ctbmv_blis_impl_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const scomplex *a,const f77_int *lda,scomplex *x,const f77_int *incx) +{ + ctbmv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); +} + +void CTBMV_BLIS_IMPL_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const scomplex *a,const f77_int *lda,scomplex *x,const f77_int *incx) +{ + ctbmv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); +} + +void CTBSV_BLIS_IMPL(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const scomplex *a,const f77_int *lda,scomplex *x,const f77_int *incx) +{ + ctbsv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); +} + +void ctbsv_blis_impl_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const scomplex *a,const f77_int *lda,scomplex *x,const f77_int *incx) +{ + ctbsv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); +} + +void CTBSV_BLIS_IMPL_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const scomplex *a,const f77_int *lda,scomplex *x,const f77_int *incx) +{ + ctbsv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); +} + +void CTPMV_BLIS_IMPL(const char *uplo,const char *trans,const char *diag,const f77_int *n,const scomplex *ap,scomplex *x,const f77_int *incx) +{ + ctpmv_blis_impl( uplo, trans, diag, n, ap, x, incx); +} + +void ctpmv_blis_impl_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const scomplex *ap,scomplex *x,const f77_int *incx) +{ + ctpmv_blis_impl( uplo, trans, diag, n, ap, x, incx); +} + +void CTPMV_BLIS_IMPL_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const scomplex *ap,scomplex *x,const f77_int *incx) +{ + ctpmv_blis_impl( uplo, trans, diag, n, ap, x, incx); +} + +void CTPSV_BLIS_IMPL(const char *uplo,const char *trans,const char *diag,const f77_int *n,const scomplex *ap,scomplex *x,const f77_int *incx) +{ + ctpsv_blis_impl( uplo, trans, diag, n, ap, x, incx); +} + +void ctpsv_blis_impl_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const scomplex *ap,scomplex *x,const f77_int *incx) +{ + ctpsv_blis_impl( uplo, trans, diag, n, ap, x, incx); +} + +void CTPSV_BLIS_IMPL_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const scomplex *ap,scomplex *x,const f77_int *incx) +{ + ctpsv_blis_impl( uplo, trans, diag, n, ap, x, incx); +} + +void CTRMM_BLIS_IMPL(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,scomplex *b,const f77_int *ldb) +{ + ctrmm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); +} + +void ctrmm_blis_impl_(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,scomplex *b,const f77_int *ldb) +{ + ctrmm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); +} + +void CTRMM_BLIS_IMPL_(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,scomplex *b,const f77_int *ldb) +{ + ctrmm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); +} + +void CTRMV_BLIS_IMPL(const char *uplo,const char *trans,const char *diag,const f77_int *n,const scomplex *a,const f77_int *lda,scomplex *x,const f77_int *incx) +{ + ctrmv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); +} + +void ctrmv_blis_impl_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const scomplex *a,const f77_int *lda,scomplex *x,const f77_int *incx) +{ + ctrmv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); +} + +void CTRMV_BLIS_IMPL_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const scomplex *a,const f77_int *lda,scomplex *x,const f77_int *incx) +{ + ctrmv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); +} + +void CTRSM_BLIS_IMPL(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,scomplex *b,const f77_int *ldb) +{ + ctrsm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); +} + +void ctrsm_blis_impl_(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,scomplex *b,const f77_int *ldb) +{ + ctrsm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); +} + +void CTRSM_BLIS_IMPL_(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,scomplex *b,const f77_int *ldb) +{ + ctrsm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); +} + +void CTRSV_BLIS_IMPL(const char *uplo,const char *trans,const char *diag,const f77_int *n,const scomplex *a,const f77_int *lda,scomplex *x,const f77_int *incx) +{ + ctrsv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); +} + +void ctrsv_blis_impl_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const scomplex *a,const f77_int *lda,scomplex *x,const f77_int *incx) +{ + ctrsv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); +} + +void CTRSV_BLIS_IMPL_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const scomplex *a,const f77_int *lda,scomplex *x,const f77_int *incx) +{ + ctrsv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); +} + +double DASUM_BLIS_IMPL(const f77_int *n,const double *dx,const f77_int *incx) +{ + return dasum_blis_impl( n, dx, incx); +} + +double dasum_blis_impl_(const f77_int *n,const double *dx,const f77_int *incx) +{ + return dasum_blis_impl( n, dx, incx); +} + +double DASUM_BLIS_IMPL_(const f77_int *n,const double *dx,const f77_int *incx) +{ + return dasum_blis_impl( n, dx, incx); +} + +void DAXPY_BLIS_IMPL(const f77_int *n,const double *da,const double *dx,const f77_int *incx,double *dy,const f77_int *incy) +{ + daxpy_blis_impl( n, da, dx, incx, dy, incy); +} + +void daxpy_blis_impl_(const f77_int *n,const double *da,const double *dx,const f77_int *incx,double *dy,const f77_int *incy) +{ + daxpy_blis_impl( n, da, dx, incx, dy, incy); +} + +void DAXPY_BLIS_IMPL_(const f77_int *n,const double *da,const double *dx,const f77_int *incx,double *dy,const f77_int *incy) +{ + daxpy_blis_impl( n, da, dx, incx, dy, incy); +} + +double DCABS1_BLIS_IMPL(bla_dcomplex *z) +{ + return dcabs1_blis_impl( z); +} + +double dcabs1_blis_impl_(bla_dcomplex *z) +{ + return dcabs1_blis_impl( z); +} + +double DCABS1_BLIS_IMPL_(bla_dcomplex *z) +{ + return dcabs1_blis_impl( z); +} + +void DCOPY_BLIS_IMPL(const f77_int *n,const double *dx,const f77_int *incx,double *dy,const f77_int *incy) +{ + dcopy_blis_impl( n, dx, incx, dy, incy); +} + +void dcopy_blis_impl_(const f77_int *n,const double *dx,const f77_int *incx,double *dy,const f77_int *incy) +{ + dcopy_blis_impl( n, dx, incx, dy, incy); +} + +void DCOPY_BLIS_IMPL_(const f77_int *n,const double *dx,const f77_int *incx,double *dy,const f77_int *incy) +{ + dcopy_blis_impl( n, dx, incx, dy, incy); +} + +double DDOT_BLIS_IMPL(const f77_int *n,const double *dx,const f77_int *incx,const double *dy,const f77_int *incy) +{ + return ddot_blis_impl( n, dx, incx, dy, incy); +} + +double ddot_blis_impl_(const f77_int *n,const double *dx,const f77_int *incx,const double *dy,const f77_int *incy) +{ + return ddot_blis_impl( n, dx, incx, dy, incy); +} + +double DDOT_BLIS_IMPL_(const f77_int *n,const double *dx,const f77_int *incx,const double *dy,const f77_int *incy) +{ + return ddot_blis_impl( n, dx, incx, dy, incy); +} + +void DGBMV_BLIS_IMPL(const char *trans,const f77_int *m,const f77_int *n,const f77_int *kl,const f77_int *ku,const double *alpha,const double *a,const f77_int *lda,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) +{ + dgbmv_blis_impl( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); +} + +void dgbmv_blis_impl_(const char *trans,const f77_int *m,const f77_int *n,const f77_int *kl,const f77_int *ku,const double *alpha,const double *a,const f77_int *lda,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) +{ + dgbmv_blis_impl( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); +} + +void DGBMV_BLIS_IMPL_(const char *trans,const f77_int *m,const f77_int *n,const f77_int *kl,const f77_int *ku,const double *alpha,const double *a,const f77_int *lda,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) +{ + dgbmv_blis_impl( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); +} + +void DGEMM_BLIS_IMPL(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *b,const f77_int *ldb,const double *beta,double *c,const f77_int *ldc) +{ + dgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void dgemm_blis_impl_(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *b,const f77_int *ldb,const double *beta,double *c,const f77_int *ldc) +{ + dgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void DGEMM_BLIS_IMPL_(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *b,const f77_int *ldb,const double *beta,double *c,const f77_int *ldc) +{ + dgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void DZGEMM_BLIS_IMPL( const f77_char *transa, const f77_char *transb, const f77_int *m, const f77_int *n, const f77_int *k, const dcomplex *alpha, const double *a, const f77_int *lda, const dcomplex *b, const f77_int *ldb, const dcomplex *beta, dcomplex *c, const f77_int *ldc ) +{ + dzgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void dzgemm_blis_impl_( const f77_char *transa, const f77_char *transb, const f77_int *m, const f77_int *n, const f77_int *k, const dcomplex *alpha, const double *a, const f77_int *lda, const dcomplex *b, const f77_int *ldb, const dcomplex *beta, dcomplex *c, const f77_int *ldc ) +{ + dzgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void DZGEMM_BLIS_IMPL_( const f77_char *transa, const f77_char *transb, const f77_int *m, const f77_int *n, const f77_int *k, const dcomplex *alpha, const double *a, const f77_int *lda, const dcomplex *b, const f77_int *ldb, const dcomplex *beta, dcomplex *c, const f77_int *ldc ) +{ + dzgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void DGEMV_BLIS_IMPL(const char *trans,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) +{ + dgemv_blis_impl( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); +} + +void dgemv_blis_impl_(const char *trans,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) +{ + dgemv_blis_impl( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); +} + +void DGEMV_BLIS_IMPL_(const char *trans,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) +{ + dgemv_blis_impl( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); +} + +void DGER_BLIS_IMPL(const f77_int *m,const f77_int *n,const double *alpha,const double *x,const f77_int *incx,const double *y,const f77_int *incy,double *a,const f77_int *lda) +{ + dger_blis_impl( m, n, alpha, x, incx, y, incy, a, lda); +} + +void dger_blis_impl_(const f77_int *m,const f77_int *n,const double *alpha,const double *x,const f77_int *incx,const double *y,const f77_int *incy,double *a,const f77_int *lda) +{ + dger_blis_impl( m, n, alpha, x, incx, y, incy, a, lda); +} + +void DGER_BLIS_IMPL_(const f77_int *m,const f77_int *n,const double *alpha,const double *x,const f77_int *incx,const double *y,const f77_int *incy,double *a,const f77_int *lda) +{ + dger_blis_impl( m, n, alpha, x, incx, y, incy, a, lda); +} + +double DNRM2_BLIS_IMPL(const f77_int *n,const double *x,const f77_int *incx) +{ + return dnrm2_blis_impl( n, x, incx); +} + +double dnrm2_blis_impl_(const f77_int *n,const double *x,const f77_int *incx) +{ + return dnrm2_blis_impl( n, x, incx); +} + +double DNRM2_BLIS_IMPL_(const f77_int *n,const double *x,const f77_int *incx) +{ + return dnrm2_blis_impl( n, x, incx); +} + +void DROT_BLIS_IMPL(const f77_int *n,double *dx,const f77_int *incx,double *dy,const f77_int *incy,const double *c,const double *s) +{ + drot_blis_impl( n, dx, incx, dy, incy, c, s); +} + +void drot_blis_impl_(const f77_int *n,double *dx,const f77_int *incx,double *dy,const f77_int *incy,const double *c,const double *s) +{ + drot_blis_impl( n, dx, incx, dy, incy, c, s); +} + +void DROT_BLIS_IMPL_(const f77_int *n,double *dx,const f77_int *incx,double *dy,const f77_int *incy,const double *c,const double *s) +{ + drot_blis_impl( n, dx, incx, dy, incy, c, s); +} + +void DROTG_BLIS_IMPL(double *da,double *db,double *c,double *s) +{ + drotg_blis_impl( da, db, c, s); +} + +void drotg_blis_impl_(double *da,double *db,double *c,double *s) +{ + drotg_blis_impl( da, db, c, s); +} + +void DROTG_BLIS_IMPL_(double *da,double *db,double *c,double *s) +{ + drotg_blis_impl( da, db, c, s); +} + +void DROTM_BLIS_IMPL(const f77_int *n,double *dx,const f77_int *incx,double *dy,const f77_int *incy,const double *dparam) +{ + drotm_blis_impl( n, dx, incx, dy, incy, dparam); +} + +void drotm_blis_impl_(const f77_int *n,double *dx,const f77_int *incx,double *dy,const f77_int *incy,const double *dparam) +{ + drotm_blis_impl( n, dx, incx, dy, incy, dparam); +} + +void DROTM_BLIS_IMPL_(const f77_int *n,double *dx,const f77_int *incx,double *dy,const f77_int *incy,const double *dparam) +{ + drotm_blis_impl( n, dx, incx, dy, incy, dparam); +} + +void DROTMG_BLIS_IMPL(double *dd1,double *dd2,double *dx1,const double *dy1,double *dparam) +{ + drotmg_blis_impl( dd1, dd2, dx1, dy1, dparam); +} + +void drotmg_blis_impl_(double *dd1,double *dd2,double *dx1,const double *dy1,double *dparam) +{ + drotmg_blis_impl( dd1, dd2, dx1, dy1, dparam); +} + +void DROTMG_BLIS_IMPL_(double *dd1,double *dd2,double *dx1,const double *dy1,double *dparam) +{ + drotmg_blis_impl( dd1, dd2, dx1, dy1, dparam); +} + +void DSBMV_BLIS_IMPL(const char *uplo,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) +{ + dsbmv_blis_impl( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); +} + +void dsbmv_blis_impl_(const char *uplo,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) +{ + dsbmv_blis_impl( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); +} + +void DSBMV_BLIS_IMPL_(const char *uplo,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) +{ + dsbmv_blis_impl( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); +} + +void DSCAL_BLIS_IMPL(const f77_int *n,const double *da,double *dx,const f77_int *incx) +{ + dscal_blis_impl( n, da, dx, incx); +} + +void dscal_blis_impl_(const f77_int *n,const double *da,double *dx,const f77_int *incx) +{ + dscal_blis_impl( n, da, dx, incx); +} + +void DSCAL_BLIS_IMPL_(const f77_int *n,const double *da,double *dx,const f77_int *incx) +{ + dscal_blis_impl( n, da, dx, incx); +} + +double DSDOT_BLIS_IMPL(const f77_int *n,const float *sx,const f77_int *incx,const float *sy,const f77_int *incy) +{ + return dsdot_blis_impl( n, sx, incx, sy, incy); +} + +double dsdot_blis_impl_(const f77_int *n,const float *sx,const f77_int *incx,const float *sy,const f77_int *incy) +{ + return dsdot_blis_impl( n, sx, incx, sy, incy); +} + +double DSDOT_BLIS_IMPL_(const f77_int *n,const float *sx,const f77_int *incx,const float *sy,const f77_int *incy) +{ + return dsdot_blis_impl( n, sx, incx, sy, incy); +} + +void DSPMV_BLIS_IMPL(const char *uplo,const f77_int *n,const double *alpha,const double *ap,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) +{ + dspmv_blis_impl( uplo, n, alpha, ap, x, incx, beta, y, incy); +} + +void dspmv_blis_impl_(const char *uplo,const f77_int *n,const double *alpha,const double *ap,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) +{ + dspmv_blis_impl( uplo, n, alpha, ap, x, incx, beta, y, incy); +} + +void DSPMV_BLIS_IMPL_(const char *uplo,const f77_int *n,const double *alpha,const double *ap,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) +{ + dspmv_blis_impl( uplo, n, alpha, ap, x, incx, beta, y, incy); +} + +void DSPR_BLIS_IMPL(const char *uplo,const f77_int *n,const double *alpha,const double *x,const f77_int *incx,double *ap) +{ + dspr_blis_impl( uplo, n, alpha, x, incx, ap); +} + +void dspr_blis_impl_(const char *uplo,const f77_int *n,const double *alpha,const double *x,const f77_int *incx,double *ap) +{ + dspr_blis_impl( uplo, n, alpha, x, incx, ap); +} + +void DSPR_BLIS_IMPL_(const char *uplo,const f77_int *n,const double *alpha,const double *x,const f77_int *incx,double *ap) +{ + dspr_blis_impl( uplo, n, alpha, x, incx, ap); +} + +void DSPR2_BLIS_IMPL(const char *uplo,const f77_int *n,const double *alpha,const double *x,const f77_int *incx,const double *y,const f77_int *incy,double *ap) +{ + dspr2_blis_impl( uplo, n, alpha, x, incx, y, incy, ap); +} + +void dspr2_blis_impl_(const char *uplo,const f77_int *n,const double *alpha,const double *x,const f77_int *incx,const double *y,const f77_int *incy,double *ap) +{ + dspr2_blis_impl( uplo, n, alpha, x, incx, y, incy, ap); +} + +void DSPR2_BLIS_IMPL_(const char *uplo,const f77_int *n,const double *alpha,const double *x,const f77_int *incx,const double *y,const f77_int *incy,double *ap) +{ + dspr2_blis_impl( uplo, n, alpha, x, incx, y, incy, ap); +} + +void DSWAP_BLIS_IMPL(const f77_int *n,double *dx,const f77_int *incx,double *dy,const f77_int *incy) +{ + dswap_blis_impl( n, dx, incx, dy, incy); +} + +void dswap_blis_impl_(const f77_int *n,double *dx,const f77_int *incx,double *dy,const f77_int *incy) +{ + dswap_blis_impl( n, dx, incx, dy, incy); +} + +void DSWAP_BLIS_IMPL_(const f77_int *n,double *dx,const f77_int *incx,double *dy,const f77_int *incy) +{ + dswap_blis_impl( n, dx, incx, dy, incy); +} + +void DSYMM_BLIS_IMPL(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,const double *b,const f77_int *ldb,const double *beta,double *c,const f77_int *ldc) +{ + dsymm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void dsymm_blis_impl_(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,const double *b,const f77_int *ldb,const double *beta,double *c,const f77_int *ldc) +{ + dsymm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void DSYMM_BLIS_IMPL_(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,const double *b,const f77_int *ldb,const double *beta,double *c,const f77_int *ldc) +{ + dsymm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void DSYMV_BLIS_IMPL(const char *uplo,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) +{ + dsymv_blis_impl( uplo, n, alpha, a, lda, x, incx, beta, y, incy); +} + +void dsymv_blis_impl_(const char *uplo,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) +{ + dsymv_blis_impl( uplo, n, alpha, a, lda, x, incx, beta, y, incy); +} + +void DSYMV_BLIS_IMPL_(const char *uplo,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) +{ + dsymv_blis_impl( uplo, n, alpha, a, lda, x, incx, beta, y, incy); +} + +void DSYR_BLIS_IMPL(const char *uplo,const f77_int *n,const double *alpha,const double *x,const f77_int *incx,double *a,const f77_int *lda) +{ + dsyr_blis_impl( uplo, n, alpha, x, incx, a, lda); +} + +void dsyr_blis_impl_(const char *uplo,const f77_int *n,const double *alpha,const double *x,const f77_int *incx,double *a,const f77_int *lda) +{ + dsyr_blis_impl( uplo, n, alpha, x, incx, a, lda); +} + +void DSYR_BLIS_IMPL_(const char *uplo,const f77_int *n,const double *alpha,const double *x,const f77_int *incx,double *a,const f77_int *lda) +{ + dsyr_blis_impl( uplo, n, alpha, x, incx, a, lda); +} + +void DSYR2_BLIS_IMPL(const char *uplo,const f77_int *n,const double *alpha,const double *x,const f77_int *incx,const double *y,const f77_int *incy,double *a,const f77_int *lda) +{ + dsyr2_blis_impl( uplo, n, alpha, x, incx, y, incy, a, lda); +} + +void dsyr2_blis_impl_(const char *uplo,const f77_int *n,const double *alpha,const double *x,const f77_int *incx,const double *y,const f77_int *incy,double *a,const f77_int *lda) +{ + dsyr2_blis_impl( uplo, n, alpha, x, incx, y, incy, a, lda); +} + +void DSYR2_BLIS_IMPL_(const char *uplo,const f77_int *n,const double *alpha,const double *x,const f77_int *incx,const double *y,const f77_int *incy,double *a,const f77_int *lda) +{ + dsyr2_blis_impl( uplo, n, alpha, x, incx, y, incy, a, lda); +} + +void DSYR2K_BLIS_IMPL(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *b,const f77_int *ldb,const double *beta,double *c,const f77_int *ldc) +{ + dsyr2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void dsyr2k_blis_impl_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *b,const f77_int *ldb,const double *beta,double *c,const f77_int *ldc) +{ + dsyr2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void DSYR2K_BLIS_IMPL_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *b,const f77_int *ldb,const double *beta,double *c,const f77_int *ldc) +{ + dsyr2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void DSYRK_BLIS_IMPL(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *beta,double *c,const f77_int *ldc) +{ + dsyrk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); +} + +void dsyrk_blis_impl_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *beta,double *c,const f77_int *ldc) +{ + dsyrk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); +} + +void DSYRK_BLIS_IMPL_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *beta,double *c,const f77_int *ldc) +{ + dsyrk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); +} + +void DTBMV_BLIS_IMPL(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const double *a,const f77_int *lda,double *x,const f77_int *incx) +{ + dtbmv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); +} + +void dtbmv_blis_impl_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const double *a,const f77_int *lda,double *x,const f77_int *incx) +{ + dtbmv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); +} + +void DTBMV_BLIS_IMPL_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const double *a,const f77_int *lda,double *x,const f77_int *incx) +{ + dtbmv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); +} + +void DTBSV_BLIS_IMPL(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const double *a,const f77_int *lda,double *x,const f77_int *incx) +{ + dtbsv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); +} + +void dtbsv_blis_impl_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const double *a,const f77_int *lda,double *x,const f77_int *incx) +{ + dtbsv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); +} + +void DTBSV_BLIS_IMPL_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const double *a,const f77_int *lda,double *x,const f77_int *incx) +{ + dtbsv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); +} + +void DTPMV_BLIS_IMPL(const char *uplo,const char *trans,const char *diag,const f77_int *n,const double *ap,double *x,const f77_int *incx) +{ + dtpmv_blis_impl( uplo, trans, diag, n, ap, x, incx); +} + +void dtpmv_blis_impl_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const double *ap,double *x,const f77_int *incx) +{ + dtpmv_blis_impl( uplo, trans, diag, n, ap, x, incx); +} + +void DTPMV_BLIS_IMPL_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const double *ap,double *x,const f77_int *incx) +{ + dtpmv_blis_impl( uplo, trans, diag, n, ap, x, incx); +} + +void DTPSV_BLIS_IMPL(const char *uplo,const char *trans,const char *diag,const f77_int *n,const double *ap,double *x,const f77_int *incx) +{ + dtpsv_blis_impl( uplo, trans, diag, n, ap, x, incx); +} + +void dtpsv_blis_impl_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const double *ap,double *x,const f77_int *incx) +{ + dtpsv_blis_impl( uplo, trans, diag, n, ap, x, incx); +} + +void DTPSV_BLIS_IMPL_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const double *ap,double *x,const f77_int *incx) +{ + dtpsv_blis_impl( uplo, trans, diag, n, ap, x, incx); +} + +void DTRMM_BLIS_IMPL(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,double *b,const f77_int *ldb) +{ + dtrmm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); +} + +void dtrmm_blis_impl_(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,double *b,const f77_int *ldb) +{ + dtrmm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); +} + +void DTRMM_BLIS_IMPL_(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,double *b,const f77_int *ldb) +{ + dtrmm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); +} + +void DTRMV_BLIS_IMPL(const char *uplo,const char *trans,const char *diag,const f77_int *n,const double *a,const f77_int *lda,double *x,const f77_int *incx) +{ + dtrmv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); +} + +void dtrmv_blis_impl_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const double *a,const f77_int *lda,double *x,const f77_int *incx) +{ + dtrmv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); +} + +void DTRMV_BLIS_IMPL_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const double *a,const f77_int *lda,double *x,const f77_int *incx) +{ + dtrmv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); +} + +void DTRSM_BLIS_IMPL(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,double *b,const f77_int *ldb) +{ + dtrsm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); +} + +void dtrsm_blis_impl_(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,double *b,const f77_int *ldb) +{ + dtrsm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); +} + +void DTRSM_BLIS_IMPL_(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,double *b,const f77_int *ldb) +{ + dtrsm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); +} + +void DTRSV_BLIS_IMPL(const char *uplo,const char *trans,const char *diag,const f77_int *n,const double *a,const f77_int *lda,double *x,const f77_int *incx) +{ + dtrsv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); +} + +void dtrsv_blis_impl_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const double *a,const f77_int *lda,double *x,const f77_int *incx) +{ + dtrsv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); +} + +void DTRSV_BLIS_IMPL_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const double *a,const f77_int *lda,double *x,const f77_int *incx) +{ + dtrsv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); +} + +double DZASUM_BLIS_IMPL(const f77_int *n,const dcomplex *zx,const f77_int *incx) +{ + return dzasum_blis_impl( n, zx, incx); +} + +double dzasum_blis_impl_(const f77_int *n,const dcomplex *zx,const f77_int *incx) +{ + return dzasum_blis_impl( n, zx, incx); +} + +double DZASUM_BLIS_IMPL_(const f77_int *n,const dcomplex *zx,const f77_int *incx) +{ + return dzasum_blis_impl( n, zx, incx); +} + +double DZNRM2_BLIS_IMPL(const f77_int *n,const dcomplex *x,const f77_int *incx) +{ + return dznrm2_blis_impl( n, x, incx); +} + +double dznrm2_blis_impl_(const f77_int *n,const dcomplex *x,const f77_int *incx) +{ + return dznrm2_blis_impl( n, x, incx); +} + +double DZNRM2_BLIS_IMPL_(const f77_int *n,const dcomplex *x,const f77_int *incx) +{ + return dznrm2_blis_impl( n, x, incx); +} + +f77_int ICAMAX_BLIS_IMPL(const f77_int *n,const scomplex *cx,const f77_int *incx) +{ + return icamax_blis_impl( n, cx, incx); +} + +f77_int icamax_blis_impl_(const f77_int *n,const scomplex *cx,const f77_int *incx) +{ + return icamax_blis_impl( n, cx, incx); +} + +f77_int ICAMAX_BLIS_IMPL_(const f77_int *n,const scomplex *cx,const f77_int *incx) +{ + return icamax_blis_impl( n, cx, incx); +} + +f77_int IDAMAX_BLIS_IMPL(const f77_int *n,const double *dx,const f77_int *incx) +{ + return idamax_blis_impl( n, dx, incx); +} + +f77_int idamax_blis_impl_(const f77_int *n,const double *dx,const f77_int *incx) +{ + return idamax_blis_impl( n, dx, incx); +} + +f77_int IDAMAX_BLIS_IMPL_(const f77_int *n,const double *dx,const f77_int *incx) +{ + return idamax_blis_impl( n, dx, incx); +} + +f77_int ISAMAX_BLIS_IMPL(const f77_int *n,const float *sx,const f77_int *incx) +{ + return isamax_blis_impl( n, sx, incx); +} + +f77_int isamax_blis_impl_(const f77_int *n,const float *sx,const f77_int *incx) +{ + return isamax_blis_impl( n, sx, incx); +} + +f77_int ISAMAX_BLIS_IMPL_(const f77_int *n,const float *sx,const f77_int *incx) +{ + return isamax_blis_impl( n, sx, incx); +} + +f77_int IZAMAX_BLIS_IMPL(const f77_int *n,const dcomplex *zx,const f77_int *incx) +{ + return izamax_blis_impl( n, zx, incx); +} + +f77_int izamax_blis_impl_(const f77_int *n,const dcomplex *zx,const f77_int *incx) +{ + return izamax_blis_impl( n, zx, incx); +} + +f77_int IZAMAX_BLIS_IMPL_(const f77_int *n,const dcomplex *zx,const f77_int *incx) +{ + return izamax_blis_impl( n, zx, incx); +} + +f77_int LSAME_BLIS_IMPL(const char *ca,const char *cb,const f77_int a,const f77_int b) +{ + return lsame_blis_impl( ca, cb, a, b); +} + +f77_int LSAME_BLIS_IMPL_(const char *ca,const char *cb,const f77_int a,const f77_int b) +{ + return lsame_blis_impl( ca, cb, a, b); +} + +f77_int lsame_blis_impl_(const char *ca,const char *cb,const f77_int a,const f77_int b) +{ + return lsame_blis_impl( ca, cb, a, b); +} + +float SASUM_BLIS_IMPL(const f77_int *n,const float *sx, const f77_int *incx) +{ + return sasum_blis_impl( n, sx, incx); +} + +float sasum_blis_impl_(const f77_int *n,const float *sx, const f77_int *incx) +{ + return sasum_blis_impl( n, sx, incx); +} + +float SASUM_BLIS_IMPL_(const f77_int *n,const float *sx, const f77_int *incx) +{ + return sasum_blis_impl( n, sx, incx); +} + +void SAXPY_BLIS_IMPL(const f77_int *n,const float *sa,const float *sx,const f77_int *incx,float *sy,const f77_int *incy) +{ + saxpy_blis_impl( n, sa, sx, incx, sy, incy); +} + +void saxpy_blis_impl_(const f77_int *n,const float *sa,const float *sx,const f77_int *incx,float *sy,const f77_int *incy) +{ + saxpy_blis_impl( n, sa, sx, incx, sy, incy); +} + +void SAXPY_BLIS_IMPL_(const f77_int *n,const float *sa,const float *sx,const f77_int *incx,float *sy,const f77_int *incy) +{ + saxpy_blis_impl( n, sa, sx, incx, sy, incy); +} + + +float SCASUM_BLIS_IMPL(const f77_int *n,const scomplex *cx, const f77_int *incx) +{ + return scasum_blis_impl( n, cx, incx); +} + +float scasum_blis_impl_(const f77_int *n,const scomplex *cx, const f77_int *incx) +{ + return scasum_blis_impl( n, cx, incx); +} + +float SCASUM_BLIS_IMPL_(const f77_int *n,const scomplex *cx, const f77_int *incx) +{ + return scasum_blis_impl( n, cx, incx); +} + + + +float SCNRM2_BLIS_IMPL(const f77_int *n,const scomplex *x, const f77_int *incx) +{ + return scnrm2_blis_impl( n, x, incx); +} + +float scnrm2_blis_impl_(const f77_int *n,const scomplex *x, const f77_int *incx) +{ + return scnrm2_blis_impl( n, x, incx); +} + +float SCNRM2_BLIS_IMPL_(const f77_int *n,const scomplex *x, const f77_int *incx) +{ + return scnrm2_blis_impl( n, x, incx); +} + + +void SCOPY_BLIS_IMPL(const f77_int *n,const float *sx,const f77_int *incx,float *sy,const f77_int *incy) +{ + scopy_blis_impl( n, sx, incx, sy, incy); +} + +void scopy_blis_impl_(const f77_int *n,const float *sx,const f77_int *incx,float *sy,const f77_int *incy) +{ + scopy_blis_impl( n, sx, incx, sy, incy); +} + +void SCOPY_BLIS_IMPL_(const f77_int *n,const float *sx,const f77_int *incx,float *sy,const f77_int *incy) +{ + scopy_blis_impl( n, sx, incx, sy, incy); +} + + +float SDOT_BLIS_IMPL(const f77_int *n,const float *sx, const f77_int *incx, const float *sy, const f77_int *incy) +{ + return sdot_blis_impl( n, sx, incx, sy, incy); +} + +float sdot_blis_impl_(const f77_int *n,const float *sx, const f77_int *incx, const float *sy, const f77_int *incy) +{ + return sdot_blis_impl( n, sx, incx, sy, incy); +} + +float SDOT_BLIS_IMPL_(const f77_int *n,const float *sx, const f77_int *incx, const float *sy, const f77_int *incy) +{ + return sdot_blis_impl( n, sx, incx, sy, incy); +} + + +float SDSDOT_BLIS_IMPL(const f77_int *n,const float *sb, const float *sx, const f77_int *incx, const float *sy, const f77_int *incy) +{ + return sdsdot_blis_impl( n, sb, sx, incx, sy, incy); +} + +float sdsdot_blis_impl_(const f77_int *n,const float *sb, const float *sx, const f77_int *incx, const float *sy, const f77_int *incy) +{ + return sdsdot_blis_impl( n, sb, sx, incx, sy, incy); +} + +float SDSDOT_BLIS_IMPL_(const f77_int *n,const float *sb, const float *sx, const f77_int *incx, const float *sy, const f77_int *incy) +{ + return sdsdot_blis_impl( n, sb, sx, incx, sy, incy); +} + + +void SGBMV_BLIS_IMPL(const char *trans,const f77_int *m,const f77_int *n,const f77_int *kl,const f77_int *ku,const float *alpha,const float *a,const f77_int *lda,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) +{ + sgbmv_blis_impl( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); +} + +void sgbmv_blis_impl_(const char *trans,const f77_int *m,const f77_int *n,const f77_int *kl,const f77_int *ku,const float *alpha,const float *a,const f77_int *lda,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) +{ + sgbmv_blis_impl( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); +} + +void SGBMV_BLIS_IMPL_(const char *trans,const f77_int *m,const f77_int *n,const f77_int *kl,const f77_int *ku,const float *alpha,const float *a,const f77_int *lda,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) +{ + sgbmv_blis_impl( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); +} + +void SGEMM_BLIS_IMPL(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *b,const f77_int *ldb,const float *beta,float *c,const f77_int *ldc) +{ + sgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void sgemm_blis_impl_(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *b,const f77_int *ldb,const float *beta,float *c,const f77_int *ldc) +{ + sgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void SGEMM_BLIS_IMPL_(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *b,const f77_int *ldb,const float *beta,float *c,const f77_int *ldc) +{ + sgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void SGEMV_BLIS_IMPL(const char *trans,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) +{ + sgemv_blis_impl( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); +} + +void sgemv_blis_impl_(const char *trans,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) +{ + sgemv_blis_impl( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); +} + +void SGEMV_BLIS_IMPL_(const char *trans,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) +{ + sgemv_blis_impl( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); +} + +void SGER_BLIS_IMPL(const f77_int *m,const f77_int *n,const float *alpha,const float *x,const f77_int *incx,const float *y,const f77_int *incy,float *a,const f77_int *lda) +{ + sger_blis_impl( m, n, alpha, x, incx, y, incy, a, lda); +} + +void sger_blis_impl_(const f77_int *m,const f77_int *n,const float *alpha,const float *x,const f77_int *incx,const float *y,const f77_int *incy,float *a,const f77_int *lda) +{ + sger_blis_impl( m, n, alpha, x, incx, y, incy, a, lda); +} + +void SGER_BLIS_IMPL_(const f77_int *m,const f77_int *n,const float *alpha,const float *x,const f77_int *incx,const float *y,const f77_int *incy,float *a,const f77_int *lda) +{ + sger_blis_impl( m, n, alpha, x, incx, y, incy, a, lda); +} + + +float SNRM2_BLIS_IMPL(const f77_int *n,const float *x, const f77_int *incx) +{ + return snrm2_blis_impl( n, x, incx); +} + +float snrm2_blis_impl_(const f77_int *n,const float *x, const f77_int *incx) +{ + return snrm2_blis_impl( n, x, incx); +} + +float SNRM2_BLIS_IMPL_(const f77_int *n,const float *x, const f77_int *incx) +{ + return snrm2_blis_impl( n, x, incx); +} + + +void SROT_BLIS_IMPL(const f77_int *n,float *sx,const f77_int *incx,float *sy,const f77_int *incy,const float *c,const float *s) +{ + srot_blis_impl( n, sx, incx, sy, incy, c, s); +} + +void srot_blis_impl_(const f77_int *n,float *sx,const f77_int *incx,float *sy,const f77_int *incy,const float *c,const float *s) +{ + srot_blis_impl( n, sx, incx, sy, incy, c, s); +} + +void SROT_BLIS_IMPL_(const f77_int *n,float *sx,const f77_int *incx,float *sy,const f77_int *incy,const float *c,const float *s) +{ + srot_blis_impl( n, sx, incx, sy, incy, c, s); +} + +void SROTG_BLIS_IMPL(float *sa,float *sb,float *c,float *s) +{ + srotg_blis_impl( sa, sb, c, s); +} + +void srotg_blis_impl_(float *sa,float *sb,float *c,float *s) +{ + srotg_blis_impl( sa, sb, c, s); +} + +void SROTG_BLIS_IMPL_(float *sa,float *sb,float *c,float *s) +{ + srotg_blis_impl( sa, sb, c, s); +} + +void SROTM_BLIS_IMPL(const f77_int *n,float *sx,const f77_int *incx,float *sy,const f77_int *incy,const float *sparam) +{ + srotm_blis_impl( n, sx, incx, sy, incy, sparam); +} + +void srotm_blis_impl_(const f77_int *n,float *sx,const f77_int *incx,float *sy,const f77_int *incy,const float *sparam) +{ + srotm_blis_impl( n, sx, incx, sy, incy, sparam); +} + +void SROTM_BLIS_IMPL_(const f77_int *n,float *sx,const f77_int *incx,float *sy,const f77_int *incy,const float *sparam) +{ + srotm_blis_impl( n, sx, incx, sy, incy, sparam); +} + +void SROTMG_BLIS_IMPL(float *sd1,float *sd2,float *sx1,const float *sy1,float *sparam) +{ + srotmg_blis_impl( sd1, sd2, sx1, sy1, sparam); +} + +void srotmg_blis_impl_(float *sd1,float *sd2,float *sx1,const float *sy1,float *sparam) +{ + srotmg_blis_impl( sd1, sd2, sx1, sy1, sparam); +} + +void SROTMG_BLIS_IMPL_(float *sd1,float *sd2,float *sx1,const float *sy1,float *sparam) +{ + srotmg_blis_impl( sd1, sd2, sx1, sy1, sparam); +} + +void SSBMV_BLIS_IMPL(const char *uplo,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) +{ + ssbmv_blis_impl( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); +} + +void ssbmv_blis_impl_(const char *uplo,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) +{ + ssbmv_blis_impl( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); +} + +void SSBMV_BLIS_IMPL_(const char *uplo,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) +{ + ssbmv_blis_impl( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); +} + +void SSCAL_BLIS_IMPL(const f77_int *n,const float *sa,float *sx,const f77_int *incx) +{ + sscal_blis_impl( n, sa, sx, incx); +} + +void sscal_blis_impl_(const f77_int *n,const float *sa,float *sx,const f77_int *incx) +{ + sscal_blis_impl( n, sa, sx, incx); +} + +void SSCAL_BLIS_IMPL_(const f77_int *n,const float *sa,float *sx,const f77_int *incx) +{ + sscal_blis_impl( n, sa, sx, incx); +} + +void SSPMV_BLIS_IMPL(const char *uplo,const f77_int *n,const float *alpha,const float *ap,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) +{ + sspmv_blis_impl( uplo, n, alpha, ap, x, incx, beta, y, incy); +} + +void sspmv_blis_impl_(const char *uplo,const f77_int *n,const float *alpha,const float *ap,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) +{ + sspmv_blis_impl( uplo, n, alpha, ap, x, incx, beta, y, incy); +} + +void SSPMV_BLIS_IMPL_(const char *uplo,const f77_int *n,const float *alpha,const float *ap,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) +{ + sspmv_blis_impl( uplo, n, alpha, ap, x, incx, beta, y, incy); +} + +void SSPR_BLIS_IMPL(const char *uplo,const f77_int *n,const float *alpha,const float *x,const f77_int *incx,float *ap) +{ + sspr_blis_impl( uplo, n, alpha, x, incx, ap); +} + +void sspr_blis_impl_(const char *uplo,const f77_int *n,const float *alpha,const float *x,const f77_int *incx,float *ap) +{ + sspr_blis_impl( uplo, n, alpha, x, incx, ap); +} + +void SSPR_BLIS_IMPL_(const char *uplo,const f77_int *n,const float *alpha,const float *x,const f77_int *incx,float *ap) +{ + sspr_blis_impl( uplo, n, alpha, x, incx, ap); +} + +void SSPR2_BLIS_IMPL(const char *uplo,const f77_int *n,const float *alpha,const float *x,const f77_int *incx,const float *y,const f77_int *incy,float *ap) +{ + sspr2_blis_impl( uplo, n, alpha, x, incx, y, incy, ap); +} + +void sspr2_blis_impl_(const char *uplo,const f77_int *n,const float *alpha,const float *x,const f77_int *incx,const float *y,const f77_int *incy,float *ap) +{ + sspr2_blis_impl( uplo, n, alpha, x, incx, y, incy, ap); +} + +void SSPR2_BLIS_IMPL_(const char *uplo,const f77_int *n,const float *alpha,const float *x,const f77_int *incx,const float *y,const f77_int *incy,float *ap) +{ + sspr2_blis_impl( uplo, n, alpha, x, incx, y, incy, ap); +} + +void SSWAP_BLIS_IMPL(const f77_int *n,float *sx,const f77_int *incx,float *sy,const f77_int *incy) +{ + sswap_blis_impl( n, sx, incx, sy, incy); +} + +void sswap_blis_impl_(const f77_int *n,float *sx,const f77_int *incx,float *sy,const f77_int *incy) +{ + sswap_blis_impl( n, sx, incx, sy, incy); +} + +void SSWAP_BLIS_IMPL_(const f77_int *n,float *sx,const f77_int *incx,float *sy,const f77_int *incy) +{ + sswap_blis_impl( n, sx, incx, sy, incy); +} + +void SSYMM_BLIS_IMPL(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,const float *b,const f77_int *ldb,const float *beta,float *c,const f77_int *ldc) +{ + ssymm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void ssymm_blis_impl_(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,const float *b,const f77_int *ldb,const float *beta,float *c,const f77_int *ldc) +{ + ssymm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void SSYMM_BLIS_IMPL_(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,const float *b,const f77_int *ldb,const float *beta,float *c,const f77_int *ldc) +{ + ssymm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void SSYMV_BLIS_IMPL(const char *uplo,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) +{ + ssymv_blis_impl( uplo, n, alpha, a, lda, x, incx, beta, y, incy); +} + +void ssymv_blis_impl_(const char *uplo,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) +{ + ssymv_blis_impl( uplo, n, alpha, a, lda, x, incx, beta, y, incy); +} + +void SSYMV_BLIS_IMPL_(const char *uplo,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) +{ + ssymv_blis_impl( uplo, n, alpha, a, lda, x, incx, beta, y, incy); +} + +void SSYR_BLIS_IMPL(const char *uplo,const f77_int *n,const float *alpha,const float *x,const f77_int *incx,float *a,const f77_int *lda) +{ + ssyr_blis_impl( uplo, n, alpha, x, incx, a, lda); +} + +void ssyr_blis_impl_(const char *uplo,const f77_int *n,const float *alpha,const float *x,const f77_int *incx,float *a,const f77_int *lda) +{ + ssyr_blis_impl( uplo, n, alpha, x, incx, a, lda); +} + +void SSYR_BLIS_IMPL_(const char *uplo,const f77_int *n,const float *alpha,const float *x,const f77_int *incx,float *a,const f77_int *lda) +{ + ssyr_blis_impl( uplo, n, alpha, x, incx, a, lda); +} + +void SSYR2_BLIS_IMPL(const char *uplo,const f77_int *n,const float *alpha,const float *x,const f77_int *incx,const float *y,const f77_int *incy,float *a,const f77_int *lda) +{ + ssyr2_blis_impl( uplo, n, alpha, x, incx, y, incy, a, lda); +} + +void ssyr2_blis_impl_(const char *uplo,const f77_int *n,const float *alpha,const float *x,const f77_int *incx,const float *y,const f77_int *incy,float *a,const f77_int *lda) +{ + ssyr2_blis_impl( uplo, n, alpha, x, incx, y, incy, a, lda); +} + +void SSYR2_BLIS_IMPL_(const char *uplo,const f77_int *n,const float *alpha,const float *x,const f77_int *incx,const float *y,const f77_int *incy,float *a,const f77_int *lda) +{ + ssyr2_blis_impl( uplo, n, alpha, x, incx, y, incy, a, lda); +} + +void SSYR2K_BLIS_IMPL(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *b,const f77_int *ldb,const float *beta,float *c,const f77_int *ldc) +{ + ssyr2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void ssyr2k_blis_impl_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *b,const f77_int *ldb,const float *beta,float *c,const f77_int *ldc) +{ + ssyr2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void SSYR2K_BLIS_IMPL_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *b,const f77_int *ldb,const float *beta,float *c,const f77_int *ldc) +{ + ssyr2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void SSYRK_BLIS_IMPL(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *beta,float *c,const f77_int *ldc) +{ + ssyrk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); +} + +void ssyrk_blis_impl_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *beta,float *c,const f77_int *ldc) +{ + ssyrk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); +} + +void SSYRK_BLIS_IMPL_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *beta,float *c,const f77_int *ldc) +{ + ssyrk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); +} + +void STBMV_BLIS_IMPL(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const float *a,const f77_int *lda,float *x,const f77_int *incx) +{ + stbmv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); +} + +void stbmv_blis_impl_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const float *a,const f77_int *lda,float *x,const f77_int *incx) +{ + stbmv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); +} + +void STBMV_BLIS_IMPL_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const float *a,const f77_int *lda,float *x,const f77_int *incx) +{ + stbmv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); +} + +void STBSV_BLIS_IMPL(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const float *a,const f77_int *lda,float *x,const f77_int *incx) +{ + stbsv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); +} + +void stbsv_blis_impl_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const float *a,const f77_int *lda,float *x,const f77_int *incx) +{ + stbsv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); +} + +void STBSV_BLIS_IMPL_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const float *a,const f77_int *lda,float *x,const f77_int *incx) +{ + stbsv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); +} + +void STPMV_BLIS_IMPL(const char *uplo,const char *trans,const char *diag,const f77_int *n,const float *ap,float *x,const f77_int *incx) +{ + stpmv_blis_impl( uplo, trans, diag, n, ap, x, incx); +} + +void stpmv_blis_impl_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const float *ap,float *x,const f77_int *incx) +{ + stpmv_blis_impl( uplo, trans, diag, n, ap, x, incx); +} + +void STPMV_BLIS_IMPL_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const float *ap,float *x,const f77_int *incx) +{ + stpmv_blis_impl( uplo, trans, diag, n, ap, x, incx); +} + +void STPSV_BLIS_IMPL(const char *uplo,const char *trans,const char *diag,const f77_int *n,const float *ap,float *x,const f77_int *incx) +{ + stpsv_blis_impl( uplo, trans, diag, n, ap, x, incx); +} + +void stpsv_blis_impl_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const float *ap,float *x,const f77_int *incx) +{ + stpsv_blis_impl( uplo, trans, diag, n, ap, x, incx); +} + +void STPSV_BLIS_IMPL_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const float *ap,float *x,const f77_int *incx) +{ + stpsv_blis_impl( uplo, trans, diag, n, ap, x, incx); +} + +void STRMM_BLIS_IMPL(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,float *b,const f77_int *ldb) +{ + strmm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); +} + +void strmm_blis_impl_(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,float *b,const f77_int *ldb) +{ + strmm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); +} + +void STRMM_BLIS_IMPL_(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,float *b,const f77_int *ldb) +{ + strmm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); +} + +void STRMV_BLIS_IMPL(const char *uplo,const char *trans,const char *diag,const f77_int *n,const float *a,const f77_int *lda,float *x,const f77_int *incx) +{ + strmv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); +} + +void strmv_blis_impl_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const float *a,const f77_int *lda,float *x,const f77_int *incx) +{ + strmv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); +} + +void STRMV_BLIS_IMPL_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const float *a,const f77_int *lda,float *x,const f77_int *incx) +{ + strmv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); +} + +void STRSM_BLIS_IMPL(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,float *b,const f77_int *ldb) +{ + strsm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); +} + +void strsm_blis_impl_(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,float *b,const f77_int *ldb) +{ + strsm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); +} + +void STRSM_BLIS_IMPL_(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,float *b,const f77_int *ldb) +{ + strsm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); +} + +void STRSV_BLIS_IMPL(const char *uplo,const char *trans,const char *diag,const f77_int *n,const float *a,const f77_int *lda,float *x,const f77_int *incx) +{ + strsv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); +} + +void strsv_blis_impl_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const float *a,const f77_int *lda,float *x,const f77_int *incx) +{ + strsv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); +} + +void STRSV_BLIS_IMPL_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const float *a,const f77_int *lda,float *x,const f77_int *incx) +{ + strsv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); +} + +void XERBLA_BLIS_IMPL(const char *srname,const f77_int *info, ftnlen n) +{ + xerbla_blis_impl( srname, info, n); +} + +void XERBLA_BLIS_IMPL_(const char *srname,const f77_int *info, ftnlen n) +{ + xerbla_blis_impl( srname, info, n); +} + +void xerbla_blis_impl_(const char *srname,const f77_int *info, ftnlen n) +{ + xerbla_blis_impl( srname, info, n); +} + +void ZAXPY_BLIS_IMPL(const f77_int *n,const dcomplex *za,const dcomplex *zx,const f77_int *incx,dcomplex *zy,const f77_int *incy) +{ + zaxpy_blis_impl( n, za, zx, incx, zy, incy); +} + +void zaxpy_blis_impl_(const f77_int *n,const dcomplex *za,const dcomplex *zx,const f77_int *incx,dcomplex *zy,const f77_int *incy) +{ + zaxpy_blis_impl( n, za, zx, incx, zy, incy); +} + +void ZAXPY_BLIS_IMPL_(const f77_int *n,const dcomplex *za,const dcomplex *zx,const f77_int *incx,dcomplex *zy,const f77_int *incy) +{ + zaxpy_blis_impl( n, za, zx, incx, zy, incy); +} + +void ZCOPY_BLIS_IMPL(const f77_int *n,const dcomplex *zx,const f77_int *incx,dcomplex *zy,const f77_int *incy) +{ + zcopy_blis_impl( n, zx, incx, zy, incy); +} + +void zcopy_blis_impl_(const f77_int *n,const dcomplex *zx,const f77_int *incx,dcomplex *zy,const f77_int *incy) +{ + zcopy_blis_impl( n, zx, incx, zy, incy); +} + +void ZCOPY_BLIS_IMPL_(const f77_int *n,const dcomplex *zx,const f77_int *incx,dcomplex *zy,const f77_int *incy) +{ + zcopy_blis_impl( n, zx, incx, zy, incy); +} + +void ZDROT_BLIS_IMPL(const f77_int *n,dcomplex *cx,const f77_int *incx,dcomplex *cy,const f77_int *incy,const double *c,const double *s) +{ + zdrot_blis_impl( n, cx, incx, cy, incy, c, s); +} + +void zdrot_blis_impl_(const f77_int *n,dcomplex *cx,const f77_int *incx,dcomplex *cy,const f77_int *incy,const double *c,const double *s) +{ + zdrot_blis_impl( n, cx, incx, cy, incy, c, s); +} + +void ZDROT_BLIS_IMPL_(const f77_int *n,dcomplex *cx,const f77_int *incx,dcomplex *cy,const f77_int *incy,const double *c,const double *s) +{ + zdrot_blis_impl( n, cx, incx, cy, incy, c, s); +} + +void ZDSCAL_BLIS_IMPL(const f77_int *n,const double *da,dcomplex *zx,const f77_int *incx) +{ + zdscal_blis_impl( n, da, zx, incx); +} + +void zdscal_blis_impl_(const f77_int *n,const double *da,dcomplex *zx,const f77_int *incx) +{ + zdscal_blis_impl( n, da, zx, incx); +} + +void ZDSCAL_BLIS_IMPL_(const f77_int *n,const double *da,dcomplex *zx,const f77_int *incx) +{ + zdscal_blis_impl( n, da, zx, incx); +} + +void ZGBMV_BLIS_IMPL(const char *trans,const f77_int *m,const f77_int *n,const f77_int *kl,const f77_int *ku,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) +{ + zgbmv_blis_impl( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); +} + +void zgbmv_blis_impl_(const char *trans,const f77_int *m,const f77_int *n,const f77_int *kl,const f77_int *ku,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) +{ + zgbmv_blis_impl( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); +} + +void ZGBMV_BLIS_IMPL_(const char *trans,const f77_int *m,const f77_int *n,const f77_int *kl,const f77_int *ku,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) +{ + zgbmv_blis_impl( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); +} + +void ZGEMM_BLIS_IMPL(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) +{ + zgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void zgemm_blis_impl_(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) +{ + zgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void ZGEMM_BLIS_IMPL_(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) +{ + zgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void ZGEMV_BLIS_IMPL(const char *trans,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) +{ + zgemv_blis_impl( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); +} + +void zgemv_blis_impl_(const char *trans,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) +{ + zgemv_blis_impl( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); +} + +void ZGEMV_BLIS_IMPL_(const char *trans,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) +{ + zgemv_blis_impl( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); +} + +void ZGERC_BLIS_IMPL(const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *x,const f77_int *incx,const dcomplex *y,const f77_int *incy,dcomplex *a,const f77_int *lda) +{ + zgerc_blis_impl( m, n, alpha, x, incx, y, incy, a, lda); +} + +void zgerc_blis_impl_(const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *x,const f77_int *incx,const dcomplex *y,const f77_int *incy,dcomplex *a,const f77_int *lda) +{ + zgerc_blis_impl( m, n, alpha, x, incx, y, incy, a, lda); +} + +void ZGERC_BLIS_IMPL_(const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *x,const f77_int *incx,const dcomplex *y,const f77_int *incy,dcomplex *a,const f77_int *lda) +{ + zgerc_blis_impl( m, n, alpha, x, incx, y, incy, a, lda); +} + +void ZGERU_BLIS_IMPL(const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *x,const f77_int *incx,const dcomplex *y,const f77_int *incy,dcomplex *a,const f77_int *lda) +{ + zgeru_blis_impl( m, n, alpha, x, incx, y, incy, a, lda); +} + +void zgeru_blis_impl_(const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *x,const f77_int *incx,const dcomplex *y,const f77_int *incy,dcomplex *a,const f77_int *lda) +{ + zgeru_blis_impl( m, n, alpha, x, incx, y, incy, a, lda); +} + +void ZGERU_BLIS_IMPL_(const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *x,const f77_int *incx,const dcomplex *y,const f77_int *incy,dcomplex *a,const f77_int *lda) +{ + zgeru_blis_impl( m, n, alpha, x, incx, y, incy, a, lda); +} + +void ZHBMV_BLIS_IMPL(const char *uplo,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) +{ + zhbmv_blis_impl( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); +} + +void zhbmv_blis_impl_(const char *uplo,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) +{ + zhbmv_blis_impl( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); +} + +void ZHBMV_BLIS_IMPL_(const char *uplo,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) +{ + zhbmv_blis_impl( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); +} + +void ZHEMM_BLIS_IMPL(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) +{ + zhemm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void zhemm_blis_impl_(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) +{ + zhemm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void ZHEMM_BLIS_IMPL_(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) +{ + zhemm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void ZHEMV_BLIS_IMPL(const char *uplo,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) +{ + zhemv_blis_impl( uplo, n, alpha, a, lda, x, incx, beta, y, incy); +} + +void zhemv_blis_impl_(const char *uplo,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) +{ + zhemv_blis_impl( uplo, n, alpha, a, lda, x, incx, beta, y, incy); +} + +void ZHEMV_BLIS_IMPL_(const char *uplo,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) +{ + zhemv_blis_impl( uplo, n, alpha, a, lda, x, incx, beta, y, incy); +} + +void ZHER_BLIS_IMPL(const char *uplo,const f77_int *n,const double *alpha,const dcomplex *x,const f77_int *incx,dcomplex *a,const f77_int *lda) +{ + zher_blis_impl( uplo, n, alpha, x, incx, a, lda); +} + +void zher_blis_impl_(const char *uplo,const f77_int *n,const double *alpha,const dcomplex *x,const f77_int *incx,dcomplex *a,const f77_int *lda) +{ + zher_blis_impl( uplo, n, alpha, x, incx, a, lda); +} + +void ZHER_BLIS_IMPL_(const char *uplo,const f77_int *n,const double *alpha,const dcomplex *x,const f77_int *incx,dcomplex *a,const f77_int *lda) +{ + zher_blis_impl( uplo, n, alpha, x, incx, a, lda); +} + +void ZHER2_BLIS_IMPL(const char *uplo,const f77_int *n,const dcomplex *alpha,const dcomplex *x,const f77_int *incx,const dcomplex *y,const f77_int *incy,dcomplex *a,const f77_int *lda) +{ + zher2_blis_impl( uplo, n, alpha, x, incx, y, incy, a, lda); +} + +void zher2_blis_impl_(const char *uplo,const f77_int *n,const dcomplex *alpha,const dcomplex *x,const f77_int *incx,const dcomplex *y,const f77_int *incy,dcomplex *a,const f77_int *lda) +{ + zher2_blis_impl( uplo, n, alpha, x, incx, y, incy, a, lda); +} + +void ZHER2_BLIS_IMPL_(const char *uplo,const f77_int *n,const dcomplex *alpha,const dcomplex *x,const f77_int *incx,const dcomplex *y,const f77_int *incy,dcomplex *a,const f77_int *lda) +{ + zher2_blis_impl( uplo, n, alpha, x, incx, y, incy, a, lda); +} + +void ZHER2K_BLIS_IMPL(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const double *beta,dcomplex *c,const f77_int *ldc) +{ + zher2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void zher2k_blis_impl_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const double *beta,dcomplex *c,const f77_int *ldc) +{ + zher2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void ZHER2K_BLIS_IMPL_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const double *beta,dcomplex *c,const f77_int *ldc) +{ + zher2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void ZHERK_BLIS_IMPL(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const double *alpha,const dcomplex *a,const f77_int *lda,const double *beta,dcomplex *c,const f77_int *ldc) +{ + zherk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); +} + +void zherk_blis_impl_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const double *alpha,const dcomplex *a,const f77_int *lda,const double *beta,dcomplex *c,const f77_int *ldc) +{ + zherk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); +} + +void ZHERK_BLIS_IMPL_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const double *alpha,const dcomplex *a,const f77_int *lda,const double *beta,dcomplex *c,const f77_int *ldc) +{ + zherk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); +} + +void ZHPMV_BLIS_IMPL(const char *uplo,const f77_int *n,const dcomplex *alpha,const dcomplex *ap,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) +{ + zhpmv_blis_impl( uplo, n, alpha, ap, x, incx, beta, y, incy); +} + +void zhpmv_blis_impl_(const char *uplo,const f77_int *n,const dcomplex *alpha,const dcomplex *ap,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) +{ + zhpmv_blis_impl( uplo, n, alpha, ap, x, incx, beta, y, incy); +} + +void ZHPMV_BLIS_IMPL_(const char *uplo,const f77_int *n,const dcomplex *alpha,const dcomplex *ap,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) +{ + zhpmv_blis_impl( uplo, n, alpha, ap, x, incx, beta, y, incy); +} + +void ZHPR_BLIS_IMPL(const char *uplo,const f77_int *n,const bla_double *alpha,const dcomplex *x,const f77_int *incx,dcomplex *ap) +{ + zhpr_blis_impl( uplo, n, alpha, x, incx, ap); +} + +void zhpr_blis_impl_(const char *uplo,const f77_int *n,const bla_double *alpha,const dcomplex *x,const f77_int *incx,dcomplex *ap) +{ + zhpr_blis_impl( uplo, n, alpha, x, incx, ap); +} + +void ZHPR_BLIS_IMPL_(const char *uplo,const f77_int *n,const bla_double *alpha,const dcomplex *x,const f77_int *incx,dcomplex *ap) +{ + zhpr_blis_impl( uplo, n, alpha, x, incx, ap); +} + +void ZHPR2_BLIS_IMPL(const char *uplo,const f77_int *n,const dcomplex *alpha,const dcomplex *x,const f77_int *incx,const dcomplex *y,const f77_int *incy,dcomplex *ap) +{ + zhpr2_blis_impl( uplo, n, alpha, x, incx, y, incy, ap); +} + +void zhpr2_blis_impl_(const char *uplo,const f77_int *n,const dcomplex *alpha,const dcomplex *x,const f77_int *incx,const dcomplex *y,const f77_int *incy,dcomplex *ap) +{ + zhpr2_blis_impl( uplo, n, alpha, x, incx, y, incy, ap); +} + +void ZHPR2_BLIS_IMPL_(const char *uplo,const f77_int *n,const dcomplex *alpha,const dcomplex *x,const f77_int *incx,const dcomplex *y,const f77_int *incy,dcomplex *ap) +{ + zhpr2_blis_impl( uplo, n, alpha, x, incx, y, incy, ap); +} + +void ZROTG_BLIS_IMPL(dcomplex *ca,bla_dcomplex *cb,bla_double *c,dcomplex *s) +{ + zrotg_blis_impl( ca, cb, c, s); +} + +void zrotg_blis_impl_(dcomplex *ca,bla_dcomplex *cb,bla_double *c,dcomplex *s) +{ + zrotg_blis_impl( ca, cb, c, s); +} + +void ZROTG_BLIS_IMPL_(dcomplex *ca,bla_dcomplex *cb,bla_double *c,dcomplex *s) +{ + zrotg_blis_impl( ca, cb, c, s); +} + +void ZSCAL_BLIS_IMPL(const f77_int *n,const dcomplex *za,dcomplex *zx,const f77_int *incx) +{ + zscal_blis_impl( n, za, zx, incx); +} + +void zscal_blis_impl_(const f77_int *n,const dcomplex *za,dcomplex *zx,const f77_int *incx) +{ + zscal_blis_impl( n, za, zx, incx); +} + +void ZSCAL_BLIS_IMPL_(const f77_int *n,const dcomplex *za,dcomplex *zx,const f77_int *incx) +{ + zscal_blis_impl( n, za, zx, incx); +} + +void ZSWAP_BLIS_IMPL(const f77_int *n,dcomplex *zx,const f77_int *incx,dcomplex *zy,const f77_int *incy) +{ + zswap_blis_impl( n, zx, incx, zy, incy); +} + +void zswap_blis_impl_(const f77_int *n,dcomplex *zx,const f77_int *incx,dcomplex *zy,const f77_int *incy) +{ + zswap_blis_impl( n, zx, incx, zy, incy); +} + +void ZSWAP_BLIS_IMPL_(const f77_int *n,dcomplex *zx,const f77_int *incx,dcomplex *zy,const f77_int *incy) +{ + zswap_blis_impl( n, zx, incx, zy, incy); +} + +void ZSYMM_BLIS_IMPL(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) +{ + zsymm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void zsymm_blis_impl_(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) +{ + zsymm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void ZSYMM_BLIS_IMPL_(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) +{ + zsymm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void ZSYR2K_BLIS_IMPL(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) +{ + zsyr2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void zsyr2k_blis_impl_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) +{ + zsyr2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void ZSYR2K_BLIS_IMPL_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) +{ + zsyr2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void ZSYRK_BLIS_IMPL(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *beta,dcomplex *c,const f77_int *ldc) +{ + zsyrk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); +} + +void zsyrk_blis_impl_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *beta,dcomplex *c,const f77_int *ldc) +{ + zsyrk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); +} + +void ZSYRK_BLIS_IMPL_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *beta,dcomplex *c,const f77_int *ldc) +{ + zsyrk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); +} + +void ZTBMV_BLIS_IMPL(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const dcomplex *a,const f77_int *lda,dcomplex *x,const f77_int *incx) +{ + ztbmv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); +} + +void ztbmv_blis_impl_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const dcomplex *a,const f77_int *lda,dcomplex *x,const f77_int *incx) +{ + ztbmv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); +} + +void ZTBMV_BLIS_IMPL_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const dcomplex *a,const f77_int *lda,dcomplex *x,const f77_int *incx) +{ + ztbmv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); +} + +void ZTBSV_BLIS_IMPL(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const dcomplex *a,const f77_int *lda,dcomplex *x,const f77_int *incx) +{ + ztbsv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); +} + +void ztbsv_blis_impl_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const dcomplex *a,const f77_int *lda,dcomplex *x,const f77_int *incx) +{ + ztbsv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); +} + +void ZTBSV_BLIS_IMPL_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const dcomplex *a,const f77_int *lda,dcomplex *x,const f77_int *incx) +{ + ztbsv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); +} + +void ZTPMV_BLIS_IMPL(const char *uplo,const char *trans,const char *diag,const f77_int *n,const dcomplex *ap,dcomplex *x,const f77_int *incx) +{ + ztpmv_blis_impl( uplo, trans, diag, n, ap, x, incx); +} + +void ztpmv_blis_impl_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const dcomplex *ap,dcomplex *x,const f77_int *incx) +{ + ztpmv_blis_impl( uplo, trans, diag, n, ap, x, incx); +} + +void ZTPMV_BLIS_IMPL_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const dcomplex *ap,dcomplex *x,const f77_int *incx) +{ + ztpmv_blis_impl( uplo, trans, diag, n, ap, x, incx); +} + +void ZTPSV_BLIS_IMPL(const char *uplo,const char *trans,const char *diag,const f77_int *n,const dcomplex *ap,dcomplex *x,const f77_int *incx) +{ + ztpsv_blis_impl( uplo, trans, diag, n, ap, x, incx); +} + +void ztpsv_blis_impl_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const dcomplex *ap,dcomplex *x,const f77_int *incx) +{ + ztpsv_blis_impl( uplo, trans, diag, n, ap, x, incx); +} + +void ZTPSV_BLIS_IMPL_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const dcomplex *ap,dcomplex *x,const f77_int *incx) +{ + ztpsv_blis_impl( uplo, trans, diag, n, ap, x, incx); +} + +void ZTRMM_BLIS_IMPL(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,dcomplex *b,const f77_int *ldb) +{ + ztrmm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); +} + +void ztrmm_blis_impl_(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,dcomplex *b,const f77_int *ldb) +{ + ztrmm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); +} + +void ZTRMM_BLIS_IMPL_(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,dcomplex *b,const f77_int *ldb) +{ + ztrmm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); +} + +void ZTRMV_BLIS_IMPL(const char *uplo,const char *trans,const char *diag,const f77_int *n,const dcomplex *a,const f77_int *lda,dcomplex *x,const f77_int *incx) +{ + ztrmv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); +} + +void ztrmv_blis_impl_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const dcomplex *a,const f77_int *lda,dcomplex *x,const f77_int *incx) +{ + ztrmv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); +} + +void ZTRMV_BLIS_IMPL_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const dcomplex *a,const f77_int *lda,dcomplex *x,const f77_int *incx) +{ + ztrmv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); +} + +void ZTRSM_BLIS_IMPL(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,dcomplex *b,const f77_int *ldb) +{ + ztrsm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); +} + +void ztrsm_blis_impl_(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,dcomplex *b,const f77_int *ldb) +{ + ztrsm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); +} + +void ZTRSM_BLIS_IMPL_(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,dcomplex *b,const f77_int *ldb) +{ + ztrsm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); +} + +void ZTRSV_BLIS_IMPL(const char *uplo,const char *trans,const char *diag,const f77_int *n,const dcomplex *a,const f77_int *lda,dcomplex *x,const f77_int *incx) +{ + ztrsv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); +} + +void ztrsv_blis_impl_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const dcomplex *a,const f77_int *lda,dcomplex *x,const f77_int *incx) +{ + ztrsv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); +} + +void ZTRSV_BLIS_IMPL_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const dcomplex *a,const f77_int *lda,dcomplex *x,const f77_int *incx) +{ + ztrsv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); +} + +#ifdef BLIS_ENABLE_CBLAS + +void CDOTCSUB_BLIS_IMPL( const f77_int* n, const scomplex* x,const f77_int* incx, const scomplex* y, const f77_int* incy, scomplex* rval) +{ + cdotcsub_blis_impl( n, x, incx, y, incy, rval); +} + +void cdotcsub_blis_impl_( const f77_int* n, const scomplex* x,const f77_int* incx, const scomplex* y, const f77_int* incy, scomplex* rval) +{ + cdotcsub_blis_impl( n, x, incx, y, incy, rval); +} + +void CDOTCSUB_BLIS_IMPL_( const f77_int* n, const scomplex* x,const f77_int* incx, const scomplex* y, const f77_int* incy, scomplex* rval) +{ + cdotcsub_blis_impl( n, x, incx, y, incy, rval); +} + +void CDOTUSUB_BLIS_IMPL( const f77_int* n, const scomplex* x,const f77_int* incxy, const scomplex* y, const f77_int* incy, scomplex* rval) +{ + cdotusub_blis_impl( n, x, incxy, y, incy, rval); +} + +void cdotusub_blis_impl_( const f77_int* n, const scomplex* x,const f77_int* incxy, const scomplex* y, const f77_int* incy, scomplex* rval) +{ + cdotusub_blis_impl( n, x, incxy, y, incy, rval); +} + +void CDOTUSUB_BLIS_IMPL_( const f77_int* n, const scomplex* x,const f77_int* incxy, const scomplex* y, const f77_int* incy, scomplex* rval) +{ + cdotusub_blis_impl( n, x, incxy, y, incy, rval); +} + +#endif // BLIS_ENABLE_CBLAS + +void CGEMM3M_BLIS_IMPL( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const scomplex* alpha, const scomplex* a, const f77_int* lda, const scomplex* b, const f77_int* ldb, const scomplex* beta, scomplex* c, const f77_int* ldc) +{ + cgemm3m_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void cgemm3m_blis_impl_( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const scomplex* alpha, const scomplex* a, const f77_int* lda, const scomplex* b, const f77_int* ldb, const scomplex* beta, scomplex* c, const f77_int* ldc) +{ + cgemm3m_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void CGEMM3M_BLIS_IMPL_( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const scomplex* alpha, const scomplex* a, const f77_int* lda, const scomplex* b, const f77_int* ldb, const scomplex* beta, scomplex* c, const f77_int* ldc) +{ + cgemm3m_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void CGEMM_BATCH_BLIS_IMPL( const f77_char* transa_array, const f77_char* transb_array,const f77_int *m_array, const f77_int *n_array, const f77_int *k_array,const scomplex* alpha_array, const scomplex** a_array, const f77_int *lda_array, const scomplex** b_array, const f77_int *ldb_array, const scomplex* beta_array, scomplex** c_array, const f77_int *ldc_array, const f77_int* group_count, const f77_int *group_size) +{ + cgemm_batch_blis_impl( transa_array, transb_array, m_array, n_array, k_array, alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array, group_count, group_size); +} + +void cgemm_batch_blis_impl_( const f77_char* transa_array, const f77_char* transb_array,const f77_int *m_array, const f77_int *n_array, const f77_int *k_array,const scomplex* alpha_array, const scomplex** a_array, const f77_int *lda_array, const scomplex** b_array, const f77_int *ldb_array, const scomplex* beta_array, scomplex** c_array, const f77_int *ldc_array, const f77_int* group_count, const f77_int *group_size) +{ + cgemm_batch_blis_impl( transa_array, transb_array, m_array, n_array, k_array, alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array, group_count, group_size); +} + +void CGEMM_BATCH_BLIS_IMPL_( const f77_char* transa_array, const f77_char* transb_array,const f77_int *m_array, const f77_int *n_array, const f77_int *k_array,const scomplex* alpha_array, const scomplex** a_array, const f77_int *lda_array, const scomplex** b_array, const f77_int *ldb_array, const scomplex* beta_array, scomplex** c_array, const f77_int *ldc_array, const f77_int* group_count, const f77_int *group_size) +{ + cgemm_batch_blis_impl( transa_array, transb_array, m_array, n_array, k_array, alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array, group_count, group_size); +} + +void CGEMMT_BLIS_IMPL( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const scomplex* alpha, const scomplex* a, const f77_int* lda, const scomplex* b, const f77_int* ldb, const scomplex* beta, scomplex* c, const f77_int* ldc) +{ + cgemmt_blis_impl( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void cgemmt_blis_impl_( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const scomplex* alpha, const scomplex* a, const f77_int* lda, const scomplex* b, const f77_int* ldb, const scomplex* beta, scomplex* c, const f77_int* ldc) +{ + cgemmt_blis_impl( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void CGEMMT_BLIS_IMPL_( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const scomplex* alpha, const scomplex* a, const f77_int* lda, const scomplex* b, const f77_int* ldb, const scomplex* beta, scomplex* c, const f77_int* ldc) +{ + cgemmt_blis_impl( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +#ifdef BLIS_ENABLE_CBLAS + +void DASUMSUB_BLIS_IMPL(const f77_int* n, const double* x, const f77_int* incx, double* rval) +{ + dasumsub_blis_impl( n, x, incx, rval); +} + +void dasumsub_blis_impl_(const f77_int* n, const double* x, const f77_int* incx, double* rval) +{ + dasumsub_blis_impl( n, x, incx, rval); +} + +void DASUMSUB_BLIS_IMPL_(const f77_int* n, const double* x, const f77_int* incx, double* rval) +{ + dasumsub_blis_impl( n, x, incx, rval); +} + +#endif // BLIS_ENABLE_CBLAS + +void DAXPBY_BLIS_IMPL(const f77_int* n, const double* alpha, const double *x, const f77_int* incx, const double* beta, double *y, const f77_int* incy) +{ + daxpby_blis_impl( n, alpha, x, incx, beta, y, incy); +} + +void daxpby_blis_impl_(const f77_int* n, const double* alpha, const double *x, const f77_int* incx, const double* beta, double *y, const f77_int* incy) +{ + daxpby_blis_impl( n, alpha, x, incx, beta, y, incy); +} + +void DAXPBY_BLIS_IMPL_(const f77_int* n, const double* alpha, const double *x, const f77_int* incx, const double* beta, double *y, const f77_int* incy) +{ + daxpby_blis_impl( n, alpha, x, incx, beta, y, incy); +} + +#ifdef BLIS_ENABLE_CBLAS + +void DDOTSUB_BLIS_IMPL(const f77_int* n, const double* x, const f77_int* incx, const double* y, const f77_int* incy, double* rval) +{ + ddotsub_blis_impl( n, x, incx, y, incy, rval); +} + +void ddotsub_blis_impl_(const f77_int* n, const double* x, const f77_int* incx, const double* y, const f77_int* incy, double* rval) +{ + ddotsub_blis_impl( n, x, incx, y, incy, rval); +} + +void DDOTSUB_BLIS_IMPL_(const f77_int* n, const double* x, const f77_int* incx, const double* y, const f77_int* incy, double* rval) +{ + ddotsub_blis_impl( n, x, incx, y, incy, rval); +} + +#endif // BLIS_ENABLE_CBLAS + +void DGEMM_BATCH_BLIS_IMPL( const f77_char* transa_array, const f77_char* transb_array,const f77_int *m_array, const f77_int *n_array, const f77_int *k_array,const double* alpha_array, const double** a_array, const f77_int *lda_array, const double** b_array, const f77_int *ldb_array, const double* beta_array, double** c_array, const f77_int *ldc_array, const f77_int* group_count, const f77_int *group_size) +{ + dgemm_batch_blis_impl( transa_array, transb_array, m_array, n_array, k_array, alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array, group_count, group_size); +} + +void dgemm_batch_blis_impl_( const f77_char* transa_array, const f77_char* transb_array,const f77_int *m_array, const f77_int *n_array, const f77_int *k_array,const double* alpha_array, const double** a_array, const f77_int *lda_array, const double** b_array, const f77_int *ldb_array, const double* beta_array, double** c_array, const f77_int *ldc_array, const f77_int* group_count, const f77_int *group_size) +{ + dgemm_batch_blis_impl( transa_array, transb_array, m_array, n_array, k_array, alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array, group_count, group_size); +} + +void DGEMM_BATCH_BLIS_IMPL_( const f77_char* transa_array, const f77_char* transb_array,const f77_int *m_array, const f77_int *n_array, const f77_int *k_array,const double* alpha_array, const double** a_array, const f77_int *lda_array, const double** b_array, const f77_int *ldb_array, const double* beta_array, double** c_array, const f77_int *ldc_array, const f77_int* group_count, const f77_int *group_size) +{ + dgemm_batch_blis_impl( transa_array, transb_array, m_array, n_array, k_array, alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array, group_count, group_size); +} + +f77_int DGEMM_PACK_GET_SIZE_BLIS_IMPL(const f77_char* identifier, const f77_int* pm, const f77_int* pn, const f77_int* pk) +{ + return dgemm_pack_get_size_blis_impl( identifier, pm, pn, pk ); +} + +f77_int dgemm_pack_get_size_blis_impl_(const f77_char* identifier, const f77_int* pm, const f77_int* pn, const f77_int* pk) +{ + return dgemm_pack_get_size_blis_impl( identifier, pm, pn, pk ); +} + +f77_int DGEMM_PACK_GET_SIZE_BLIS_IMPL_(const f77_char* identifier, const f77_int* pm, const f77_int* pn, const f77_int* pk) +{ + return dgemm_pack_get_size_blis_impl( identifier, pm, pn, pk ); +} + +void DGEMM_PACK_BLIS_IMPL( const f77_char* identifier, const f77_char* trans, const f77_int* mm, const f77_int* nn, const f77_int* kk, const double* alpha, const double* src, const f77_int* pld, double* dest ) +{ + dgemm_pack_blis_impl( identifier, trans, mm, nn, kk, alpha, src, pld, dest ); +} + +void dgemm_pack_blis_impl_( const f77_char* identifier, const f77_char* trans, const f77_int* mm, const f77_int* nn, const f77_int* kk, const double* alpha, const double* src, const f77_int* pld, double* dest ) +{ + dgemm_pack_blis_impl( identifier, trans, mm, nn, kk, alpha, src, pld, dest ); +} + +void DGEMM_PACK_BLIS_IMPL_( const f77_char* identifier, const f77_char* trans, const f77_int* mm, const f77_int* nn, const f77_int* kk, const double* alpha, const double* src, const f77_int* pld, double* dest ) +{ + dgemm_pack_blis_impl( identifier, trans, mm, nn, kk, alpha, src, pld, dest ); +} + +void DGEMM_COMPUTE_BLIS_IMPL( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const double* a, const f77_int* lda, const double* b, const f77_int* ldb, const double* beta, double* c, const f77_int* ldc ) +{ + f77_int rs_a = 1; + f77_int rs_b = 1; + f77_int rs_c = 1; + dgemm_compute_blis_impl( transa, transb, m, n, k, a, &rs_a, lda, b, &rs_b, ldb, beta, c, &rs_c, ldc ); +} + +void dgemm_compute_blis_impl_( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const double* a, const f77_int* lda, const double* b, const f77_int* ldb, const double* beta, double* c, const f77_int* ldc ) +{ + f77_int rs_a = 1; + f77_int rs_b = 1; + f77_int rs_c = 1; + dgemm_compute_blis_impl( transa, transb, m, n, k, a, &rs_a, lda, b, &rs_b, ldb, beta, c, &rs_c, ldc ); +} + +void DGEMM_COMPUTE_BLIS_IMPL_( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const double* a, const f77_int* lda, const double* b, const f77_int* ldb, const double* beta, double* c, const f77_int* ldc ) +{ + f77_int rs_a = 1; + f77_int rs_b = 1; + f77_int rs_c = 1; + dgemm_compute_blis_impl( transa, transb, m, n, k, a, &rs_a, lda, b, &rs_b, ldb, beta, c, &rs_c, ldc ); +} + +void DGEMMT_BLIS_IMPL( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const double* alpha, const double* a, const f77_int* lda, const double* b, const f77_int* ldb, const double* beta, double* c, const f77_int* ldc) +{ + dgemmt_blis_impl( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void dgemmt_blis_impl_( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const double* alpha, const double* a, const f77_int* lda, const double* b, const f77_int* ldb, const double* beta, double* c, const f77_int* ldc) +{ + dgemmt_blis_impl( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void DGEMMT_BLIS_IMPL_( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const double* alpha, const double* a, const f77_int* lda, const double* b, const f77_int* ldb, const double* beta, double* c, const f77_int* ldc) +{ + dgemmt_blis_impl( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +#ifdef BLIS_ENABLE_CBLAS + +void DNRM2SUB_BLIS_IMPL(const f77_int* n, const double* x, const f77_int* incx, double *rval) +{ + dnrm2sub_blis_impl( n, x, incx, rval); +} + +void dnrm2sub_blis_impl_(const f77_int* n, const double* x, const f77_int* incx, double *rval) +{ + dnrm2sub_blis_impl( n, x, incx, rval); +} + +void DNRM2SUB_BLIS_IMPL_(const f77_int* n, const double* x, const f77_int* incx, double *rval) +{ + dnrm2sub_blis_impl( n, x, incx, rval); +} + +#endif // BLIS_ENABLE_CBLAS + +#ifdef BLIS_ENABLE_CBLAS + +void DZASUMSUB_BLIS_IMPL(const f77_int* n, const dcomplex* x, const f77_int* incx, double* rval) +{ + dzasumsub_blis_impl( n, x, incx, rval); +} + +void dzasumsub_blis_impl_(const f77_int* n, const dcomplex* x, const f77_int* incx, double* rval) +{ + dzasumsub_blis_impl( n, x, incx, rval); +} + +void DZASUMSUB_BLIS_IMPL_(const f77_int* n, const dcomplex* x, const f77_int* incx, double* rval) +{ + dzasumsub_blis_impl( n, x, incx, rval); +} + +void DZNRM2SUB_BLIS_IMPL(const f77_int* n, const dcomplex* x, const f77_int* incx, double* rval) +{ + dznrm2sub_blis_impl( n, x, incx, rval); +} + +void dznrm2sub_blis_impl_(const f77_int* n, const dcomplex* x, const f77_int* incx, double* rval) +{ + dznrm2sub_blis_impl( n, x, incx, rval); +} + +void DZNRM2SUB_BLIS_IMPL_(const f77_int* n, const dcomplex* x, const f77_int* incx, double* rval) +{ + dznrm2sub_blis_impl( n, x, incx, rval); +} + +void ICAMAXSUB_BLIS_IMPL(const f77_int* n, const scomplex* x, const f77_int* incx, f77_int* rval) +{ + icamaxsub_blis_impl( n, x, incx, rval); +} + +void icamaxsub_blis_impl_(const f77_int* n, const scomplex* x, const f77_int* incx, f77_int* rval) +{ + icamaxsub_blis_impl( n, x, incx, rval); +} + +void ICAMAXSUB_BLIS_IMPL_(const f77_int* n, const scomplex* x, const f77_int* incx, f77_int* rval) +{ + icamaxsub_blis_impl( n, x, incx, rval); +} + +#endif // BLIS_ENABLE_CBLAS + +f77_int ICAMIN_BLIS_IMPL( const f77_int* n, const scomplex* x, const f77_int* incx) +{ + return icamin_blis_impl( n, x, incx); +} + +f77_int icamin_blis_impl_( const f77_int* n, const scomplex* x, const f77_int* incx) +{ + return icamin_blis_impl( n, x, incx); +} + +f77_int ICAMIN_BLIS_IMPL_( const f77_int* n, const scomplex* x, const f77_int* incx) +{ + return icamin_blis_impl( n, x, incx); +} + +#ifdef BLIS_ENABLE_CBLAS + +void ICAMINSUB_BLIS_IMPL( const f77_int* n, const scomplex* x, const f77_int* incx, f77_int* rval) +{ + icaminsub_blis_impl( n, x, incx, rval); +} + +void icaminsub_blis_impl_( const f77_int* n, const scomplex* x, const f77_int* incx, f77_int* rval) +{ + icaminsub_blis_impl( n, x, incx, rval); +} + +void ICAMINSUB_BLIS_IMPL_( const f77_int* n, const scomplex* x, const f77_int* incx, f77_int* rval) +{ + icaminsub_blis_impl( n, x, incx, rval); +} + +void IDAMAXSUB_BLIS_IMPL( const f77_int* n, const double* x, const f77_int* incx, f77_int* rval) +{ + idamaxsub_blis_impl( n, x, incx, rval); +} + +void idamaxsub_blis_impl_( const f77_int* n, const double* x, const f77_int* incx, f77_int* rval) +{ + idamaxsub_blis_impl( n, x, incx, rval); +} + +void IDAMAXSUB_BLIS_IMPL_( const f77_int* n, const double* x, const f77_int* incx, f77_int* rval) +{ + idamaxsub_blis_impl( n, x, incx, rval); +} + +#endif // BLIS_ENABLE_CBLAS + +f77_int IDAMIN_BLIS_IMPL( const f77_int* n, const double* x, const f77_int* incx) +{ + return idamin_blis_impl( n, x, incx); +} + +f77_int idamin_blis_impl_( const f77_int* n, const double* x, const f77_int* incx) +{ + return idamin_blis_impl( n, x, incx); +} + +f77_int IDAMIN_BLIS_IMPL_( const f77_int* n, const double* x, const f77_int* incx) +{ + return idamin_blis_impl( n, x, incx); +} + +#ifdef BLIS_ENABLE_CBLAS + +void IDAMINSUB_BLIS_IMPL(const f77_int* n, const double* x, const f77_int* incx, f77_int* rval) +{ + idaminsub_blis_impl( n, x, incx, rval); +} + +void idaminsub_blis_impl_(const f77_int* n, const double* x, const f77_int* incx, f77_int* rval) +{ + idaminsub_blis_impl( n, x, incx, rval); +} + +void IDAMINSUB_BLIS_IMPL_(const f77_int* n, const double* x, const f77_int* incx, f77_int* rval) +{ + idaminsub_blis_impl( n, x, incx, rval); +} + +void ISAMAXSUB_BLIS_IMPL( const f77_int* n, const float* x, const f77_int* incx, f77_int* rval) +{ + isamaxsub_blis_impl( n, x, incx, rval); +} + +void isamaxsub_blis_impl_( const f77_int* n, const float* x, const f77_int* incx, f77_int* rval) +{ + isamaxsub_blis_impl( n, x, incx, rval); +} + +void ISAMAXSUB_BLIS_IMPL_( const f77_int* n, const float* x, const f77_int* incx, f77_int* rval) +{ + isamaxsub_blis_impl( n, x, incx, rval); +} + +#endif // BLIS_ENABLE_CBLAS + +f77_int ISAMIN_BLIS_IMPL( const f77_int* n, const float* x, const f77_int* incx) +{ + return isamin_blis_impl( n, x, incx); +} + +f77_int isamin_blis_impl_( const f77_int* n, const float* x, const f77_int* incx) +{ + return isamin_blis_impl( n, x, incx); +} + +f77_int ISAMIN_BLIS_IMPL_( const f77_int* n, const float* x, const f77_int* incx) +{ + return isamin_blis_impl( n, x, incx); +} + +#ifdef BLIS_ENABLE_CBLAS + +void ISAMINSUB_BLIS_IMPL( const f77_int* n, const float* x, const f77_int* incx, f77_int* rval) +{ + isaminsub_blis_impl( n, x, incx, rval); +} + +void isaminsub_blis_impl_( const f77_int* n, const float* x, const f77_int* incx, f77_int* rval) +{ + isaminsub_blis_impl( n, x, incx, rval); +} + +void ISAMINSUB_BLIS_IMPL_( const f77_int* n, const float* x, const f77_int* incx, f77_int* rval) +{ + isaminsub_blis_impl( n, x, incx, rval); +} + +void IZAMAXSUB_BLIS_IMPL( const f77_int* n, const dcomplex* x, const f77_int* incx, f77_int* rval) +{ + izamaxsub_blis_impl( n, x, incx, rval); +} + +void izamaxsub_blis_impl_( const f77_int* n, const dcomplex* x, const f77_int* incx, f77_int* rval) +{ + izamaxsub_blis_impl( n, x, incx, rval); +} + +void IZAMAXSUB_BLIS_IMPL_( const f77_int* n, const dcomplex* x, const f77_int* incx, f77_int* rval) +{ + izamaxsub_blis_impl( n, x, incx, rval); +} + +#endif // BLIS_ENABLE_CBLAS + +f77_int IZAMIN_BLIS_IMPL( const f77_int* n, const dcomplex* x, const f77_int* incx) +{ + return izamin_blis_impl( n, x, incx); +} + +f77_int izamin_blis_impl_( const f77_int* n, const dcomplex* x, const f77_int* incx) +{ + return izamin_blis_impl( n, x, incx); +} + +f77_int IZAMIN_BLIS_IMPL_( const f77_int* n, const dcomplex* x, const f77_int* incx) +{ + return izamin_blis_impl( n, x, incx); +} + +#ifdef BLIS_ENABLE_CBLAS + +void IZAMINSUB_BLIS_IMPL( const f77_int* n, const dcomplex* x, const f77_int* incx, f77_int* rval) +{ + izaminsub_blis_impl( n, x, incx, rval); +} + +void izaminsub_blis_impl_( const f77_int* n, const dcomplex* x, const f77_int* incx, f77_int* rval) +{ + izaminsub_blis_impl( n, x, incx, rval); +} + +void IZAMINSUB_BLIS_IMPL_( const f77_int* n, const dcomplex* x, const f77_int* incx, f77_int* rval) +{ + izaminsub_blis_impl( n, x, incx, rval); +} + +void SASUMSUB_BLIS_IMPL( const f77_int* n, const float* x, const f77_int* incx, float* rval) +{ + sasumsub_blis_impl( n, x, incx, rval); +} + +void sasumsub_blis_impl_( const f77_int* n, const float* x, const f77_int* incx, float* rval) +{ + sasumsub_blis_impl( n, x, incx, rval); +} + +void SASUMSUB_BLIS_IMPL_( const f77_int* n, const float* x, const f77_int* incx, float* rval) +{ + sasumsub_blis_impl( n, x, incx, rval); +} + +#endif // BLIS_ENABLE_CBLAS + +void SAXPBY_BLIS_IMPL( const f77_int* n, const float* alpha, const float *x, const f77_int* incx, const float* beta, float *y, const f77_int* incy) +{ + saxpby_blis_impl( n, alpha, x, incx, beta, y, incy); +} + +void saxpby_blis_impl_( const f77_int* n, const float* alpha, const float *x, const f77_int* incx, const float* beta, float *y, const f77_int* incy) +{ + saxpby_blis_impl( n, alpha, x, incx, beta, y, incy); +} + +void SAXPBY_BLIS_IMPL_( const f77_int* n, const float* alpha, const float *x, const f77_int* incx, const float* beta, float *y, const f77_int* incy) +{ + saxpby_blis_impl( n, alpha, x, incx, beta, y, incy); +} + +#ifdef BLIS_ENABLE_CBLAS + +void SCASUMSUB_BLIS_IMPL( const f77_int* n, const scomplex* x, const f77_int* incx, float* rval) +{ + scasumsub_blis_impl( n, x, incx, rval); +} + +void scasumsub_blis_impl_( const f77_int* n, const scomplex* x, const f77_int* incx, float* rval) +{ + scasumsub_blis_impl( n, x, incx, rval); +} + +void SCASUMSUB_BLIS_IMPL_( const f77_int* n, const scomplex* x, const f77_int* incx, float* rval) +{ + scasumsub_blis_impl( n, x, incx, rval); +} + +void SCNRM2SUB_BLIS_IMPL( const f77_int* n, const scomplex* x, const f77_int* incx, float* rval) +{ + scnrm2sub_blis_impl( n, x, incx, rval); +} + +void scnrm2sub_blis_impl_( const f77_int* n, const scomplex* x, const f77_int* incx, float* rval) +{ + scnrm2sub_blis_impl( n, x, incx, rval); +} + +void SCNRM2SUB_BLIS_IMPL_( const f77_int* n, const scomplex* x, const f77_int* incx, float* rval) +{ + scnrm2sub_blis_impl( n, x, incx, rval); +} + +void SDOTSUB_BLIS_IMPL( const f77_int* n, const float* x, const f77_int* incx, const float* y, const f77_int* incy, float* rval) +{ + sdotsub_blis_impl( n, x, incx, y, incy, rval); +} + +void sdotsub_blis_impl_( const f77_int* n, const float* x, const f77_int* incx, const float* y, const f77_int* incy, float* rval) +{ + sdotsub_blis_impl( n, x, incx, y, incy, rval); +} + +void SDOTSUB_BLIS_IMPL_( const f77_int* n, const float* x, const f77_int* incx, const float* y, const f77_int* incy, float* rval) +{ + sdotsub_blis_impl( n, x, incx, y, incy, rval); +} + +#endif // BLIS_ENABLE_CBLAS + +void SGEMM_BATCH_BLIS_IMPL(const f77_char* transa_array, const f77_char* transb_array,const f77_int *m_array, const f77_int *n_array, const f77_int *k_array,const float* alpha_array, const float** a_array, const f77_int *lda_array, const float** b_array, const f77_int *ldb_array, const float* beta_array, float** c_array, const f77_int *ldc_array, const f77_int* group_count, const f77_int *group_size) +{ + sgemm_batch_blis_impl( transa_array, transb_array, m_array, n_array, k_array, alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array, group_count, group_size); +} + +void sgemm_batch_blis_impl_(const f77_char* transa_array, const f77_char* transb_array,const f77_int *m_array, const f77_int *n_array, const f77_int *k_array,const float* alpha_array, const float** a_array, const f77_int *lda_array, const float** b_array, const f77_int *ldb_array, const float* beta_array, float** c_array, const f77_int *ldc_array, const f77_int* group_count, const f77_int *group_size) +{ + sgemm_batch_blis_impl( transa_array, transb_array, m_array, n_array, k_array, alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array, group_count, group_size); +} + +void SGEMM_BATCH_BLIS_IMPL_(const f77_char* transa_array, const f77_char* transb_array,const f77_int *m_array, const f77_int *n_array, const f77_int *k_array,const float* alpha_array, const float** a_array, const f77_int *lda_array, const float** b_array, const f77_int *ldb_array, const float* beta_array, float** c_array, const f77_int *ldc_array, const f77_int* group_count, const f77_int *group_size) +{ + sgemm_batch_blis_impl( transa_array, transb_array, m_array, n_array, k_array, alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array, group_count, group_size); +} + +f77_int SGEMM_PACK_GET_SIZE_BLIS_IMPL(const f77_char* identifier, const f77_int* pm, const f77_int* pn, const f77_int* pk) +{ + return sgemm_pack_get_size_blis_impl( identifier, pm, pn, pk ); +} + +f77_int sgemm_pack_get_size_blis_impl_(const f77_char* identifier, const f77_int* pm, const f77_int* pn, const f77_int* pk) +{ + return sgemm_pack_get_size_blis_impl( identifier, pm, pn, pk ); +} + +f77_int SGEMM_PACK_GET_SIZE_BLIS_IMPL_(const f77_char* identifier, const f77_int* pm, const f77_int* pn, const f77_int* pk) +{ + return sgemm_pack_get_size_blis_impl( identifier, pm, pn, pk ); +} + +void SGEMM_PACK_BLIS_IMPL( const f77_char* identifier, const f77_char* trans, const f77_int* mm, const f77_int* nn, const f77_int* kk, const float* alpha, const float* src, const f77_int* pld, float* dest ) +{ + sgemm_pack_blis_impl( identifier, trans, mm, nn, kk, alpha, src, pld, dest ); +} + +void sgemm_pack_blis_impl_( const f77_char* identifier, const f77_char* trans, const f77_int* mm, const f77_int* nn, const f77_int* kk, const float* alpha, const float* src, const f77_int* pld, float* dest ) +{ + sgemm_pack_blis_impl( identifier, trans, mm, nn, kk, alpha, src, pld, dest ); +} + +void SGEMM_PACK_BLIS_IMPL_( const f77_char* identifier, const f77_char* trans, const f77_int* mm, const f77_int* nn, const f77_int* kk, const float* alpha, const float* src, const f77_int* pld, float* dest ) +{ + sgemm_pack_blis_impl( identifier, trans, mm, nn, kk, alpha, src, pld, dest ); +} + +void SGEMM_COMPUTE_BLIS_IMPL( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const float* a, const f77_int* lda, const float* b, const f77_int* ldb, const float* beta, float* c, const f77_int* ldc ) +{ + f77_int rs_a = 1; + f77_int rs_b = 1; + f77_int rs_c = 1; + sgemm_compute_blis_impl( transa, transb, m, n, k, a, &rs_a, lda, b, &rs_b, ldb, beta, c, &rs_c, ldc ); +} + +void sgemm_compute_blis_impl_( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const float* a, const f77_int* lda, const float* b, const f77_int* ldb, const float* beta, float* c, const f77_int* ldc ) +{ + f77_int rs_a = 1; + f77_int rs_b = 1; + f77_int rs_c = 1; + sgemm_compute_blis_impl( transa, transb, m, n, k, a, &rs_a, lda, b, &rs_b, ldb, beta, c, &rs_c, ldc ); +} + +void SGEMM_COMPUTE_BLIS_IMPL_( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const float* a, const f77_int* lda, const float* b, const f77_int* ldb, const float* beta, float* c, const f77_int* ldc ) +{ + f77_int rs_a = 1; + f77_int rs_b = 1; + f77_int rs_c = 1; + sgemm_compute_blis_impl( transa, transb, m, n, k, a, &rs_a, lda, b, &rs_b, ldb, beta, c, &rs_c, ldc ); +} + +void SGEMMT_BLIS_IMPL( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const float* alpha, const float* a, const f77_int* lda, const float* b, const f77_int* ldb, const float* beta, float* c, const f77_int* ldc) +{ + sgemmt_blis_impl( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void sgemmt_blis_impl_( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const float* alpha, const float* a, const f77_int* lda, const float* b, const f77_int* ldb, const float* beta, float* c, const f77_int* ldc) +{ + sgemmt_blis_impl( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void SGEMMT_BLIS_IMPL_( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const float* alpha, const float* a, const f77_int* lda, const float* b, const f77_int* ldb, const float* beta, float* c, const f77_int* ldc) +{ + sgemmt_blis_impl( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +#ifdef BLIS_ENABLE_CBLAS + +void SNRM2SUB_BLIS_IMPL( const f77_int* n, const float* x, const f77_int* incx, float *rval) +{ + snrm2sub_blis_impl( n, x, incx, rval); +} + +void snrm2sub_blis_impl_( const f77_int* n, const float* x, const f77_int* incx, float *rval) +{ + snrm2sub_blis_impl( n, x, incx, rval); +} + +void SNRM2SUB_BLIS_IMPL_( const f77_int* n, const float* x, const f77_int* incx, float *rval) +{ + snrm2sub_blis_impl( n, x, incx, rval); +} + +#endif // BLIS_ENABLE_CBLAS + +void ZAXPBY_BLIS_IMPL( const f77_int* n, const dcomplex* alpha, const dcomplex *x, const f77_int* incx, const dcomplex* beta, dcomplex *y, const f77_int* incy) +{ + zaxpby_blis_impl( n, alpha, x, incx, beta, y, incy); +} + +void zaxpby_blis_impl_( const f77_int* n, const dcomplex* alpha, const dcomplex *x, const f77_int* incx, const dcomplex* beta, dcomplex *y, const f77_int* incy) +{ + zaxpby_blis_impl( n, alpha, x, incx, beta, y, incy); +} + +void ZAXPBY_BLIS_IMPL_( const f77_int* n, const dcomplex* alpha, const dcomplex *x, const f77_int* incx, const dcomplex* beta, dcomplex *y, const f77_int* incy) +{ + zaxpby_blis_impl( n, alpha, x, incx, beta, y, incy); +} + +#ifdef BLIS_ENABLE_CBLAS + +void ZDOTCSUB_BLIS_IMPL( const f77_int* n, const dcomplex* x, const f77_int* incx, const dcomplex* y, const f77_int* incy, dcomplex* rval) +{ + zdotcsub_blis_impl( n, x, incx, y, incy, rval); +} + +void zdotcsub_blis_impl_( const f77_int* n, const dcomplex* x, const f77_int* incx, const dcomplex* y, const f77_int* incy, dcomplex* rval) +{ + zdotcsub_blis_impl( n, x, incx, y, incy, rval); +} + +void ZDOTCSUB_BLIS_IMPL_( const f77_int* n, const dcomplex* x, const f77_int* incx, const dcomplex* y, const f77_int* incy, dcomplex* rval) +{ + zdotcsub_blis_impl( n, x, incx, y, incy, rval); +} + +void ZDOTUSUB_BLIS_IMPL( const f77_int* n, const dcomplex* x, const f77_int* incx,const dcomplex* y, const f77_int* incy, dcomplex* rval) +{ + zdotusub_blis_impl( n, x, incx, y, incy, rval); +} + +void zdotusub_blis_impl_( const f77_int* n, const dcomplex* x, const f77_int* incx,const dcomplex* y, const f77_int* incy, dcomplex* rval) +{ + zdotusub_blis_impl( n, x, incx, y, incy, rval); +} + +void ZDOTUSUB_BLIS_IMPL_( const f77_int* n, const dcomplex* x, const f77_int* incx,const dcomplex* y, const f77_int* incy, dcomplex* rval) +{ + zdotusub_blis_impl( n, x, incx, y, incy, rval); +} + +#endif // BLIS_ENABLE_CBLAS + +void ZGEMM3M_BLIS_IMPL( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const dcomplex* alpha, const dcomplex* a, const f77_int* lda, const dcomplex* b, const f77_int* ldb, const dcomplex* beta, dcomplex* c, const f77_int* ldc) +{ + zgemm3m_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void zgemm3m_blis_impl_( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const dcomplex* alpha, const dcomplex* a, const f77_int* lda, const dcomplex* b, const f77_int* ldb, const dcomplex* beta, dcomplex* c, const f77_int* ldc) +{ + zgemm3m_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void ZGEMM3M_BLIS_IMPL_( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const dcomplex* alpha, const dcomplex* a, const f77_int* lda, const dcomplex* b, const f77_int* ldb, const dcomplex* beta, dcomplex* c, const f77_int* ldc) +{ + zgemm3m_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void ZGEMM_BATCH_BLIS_IMPL( const f77_char* transa_array, const f77_char* transb_array,const f77_int *m_array, const f77_int *n_array, const f77_int *k_array,const dcomplex* alpha_array, const dcomplex** a_array, const f77_int *lda_array, const dcomplex** b_array, const f77_int *ldb_array, const dcomplex* beta_array, dcomplex** c_array, const f77_int *ldc_array, const f77_int* group_count, const f77_int *group_size) +{ + zgemm_batch_blis_impl( transa_array, transb_array, m_array, n_array, k_array, alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array, group_count, group_size); +} + +void zgemm_batch_blis_impl_( const f77_char* transa_array, const f77_char* transb_array,const f77_int *m_array, const f77_int *n_array, const f77_int *k_array,const dcomplex* alpha_array, const dcomplex** a_array, const f77_int *lda_array, const dcomplex** b_array, const f77_int *ldb_array, const dcomplex* beta_array, dcomplex** c_array, const f77_int *ldc_array, const f77_int* group_count, const f77_int *group_size) +{ + zgemm_batch_blis_impl( transa_array, transb_array, m_array, n_array, k_array, alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array, group_count, group_size); +} + +void ZGEMM_BATCH_BLIS_IMPL_( const f77_char* transa_array, const f77_char* transb_array,const f77_int *m_array, const f77_int *n_array, const f77_int *k_array,const dcomplex* alpha_array, const dcomplex** a_array, const f77_int *lda_array, const dcomplex** b_array, const f77_int *ldb_array, const dcomplex* beta_array, dcomplex** c_array, const f77_int *ldc_array, const f77_int* group_count, const f77_int *group_size) +{ + zgemm_batch_blis_impl( transa_array, transb_array, m_array, n_array, k_array, alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array, group_count, group_size); +} + +void ZGEMMT_BLIS_IMPL( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const dcomplex* alpha, const dcomplex* a, const f77_int* lda, const dcomplex* b, const f77_int* ldb, const dcomplex* beta, dcomplex* c, const f77_int* ldc) +{ + zgemmt_blis_impl( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void zgemmt_blis_impl_( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const dcomplex* alpha, const dcomplex* a, const f77_int* lda, const dcomplex* b, const f77_int* ldb, const dcomplex* beta, dcomplex* c, const f77_int* ldc) +{ + zgemmt_blis_impl( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void ZGEMMT_BLIS_IMPL_( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const dcomplex* alpha, const dcomplex* a, const f77_int* lda, const dcomplex* b, const f77_int* ldb, const dcomplex* beta, dcomplex* c, const f77_int* ldc) +{ + zgemmt_blis_impl( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +float SCABS1_BLIS_IMPL(bla_scomplex* z) +{ + return scabs1_blis_impl( z); +} + +float scabs1_blis_impl_(bla_scomplex* z) +{ + return scabs1_blis_impl( z); +} + +float SCABS1_BLIS_IMPL_(bla_scomplex* z) +{ + return scabs1_blis_impl( z); + +} + +#ifdef BLIS_ENABLE_CBLAS + +void SDSDOTSUB_BLIS_IMPL( const f77_int* n, float* sb, const float* x, const f77_int* incx, const float* y, const f77_int* incy, float* dot) +{ + sdsdotsub_blis_impl( n, sb, x, incx, y, incy, dot); +} + +void sdsdotsub_blis_impl_( const f77_int* n, float* sb, const float* x, const f77_int* incx, const float* y, const f77_int* incy, float* dot) +{ + sdsdotsub_blis_impl( n, sb, x, incx, y, incy, dot); +} + +void SDSDOTSUB_BLIS_IMPL_( const f77_int* n, float* sb, const float* x, const f77_int* incx, const float* y, const f77_int* incy, float* dot) +{ + sdsdotsub_blis_impl( n, sb, x, incx, y, incy, dot); +} + +void DSDOTSUB_BLIS_IMPL( const f77_int* n, const float* x, const f77_int* incx, const float* y, const f77_int* incy, double* dot) +{ + dsdotsub_blis_impl( n, x, incx, y, incy, dot); +} + +void dsdotsub_blis_impl_( const f77_int* n, const float* x, const f77_int* incx, const float* y, const f77_int* incy, double* dot) +{ + dsdotsub_blis_impl( n, x, incx, y, incy, dot); +} + +void DSDOTSUB_BLIS_IMPL_( const f77_int* n, const float* x, const f77_int* incx, const float* y, const f77_int* incy, double* dot) +{ + dsdotsub_blis_impl( n, x, incx, y, incy, dot); +} + +#endif // BLIS_ENABLE_CBLAS + +void CAXPBY_BLIS_IMPL( const f77_int* n, const scomplex* alpha, const scomplex *x, const f77_int* incx, const scomplex* beta, scomplex *y, const f77_int* incy) +{ + caxpby_blis_impl(n, alpha, x, incx, beta, y, incy); +} + +void caxpby_blis_impl_( const f77_int* n, const scomplex* alpha, const scomplex *x, const f77_int* incx, const scomplex* beta, scomplex *y, const f77_int* incy) +{ + caxpby_blis_impl(n, alpha, x, incx, beta, y, incy); +} + +void CAXPBY_BLIS_IMPL_( const f77_int* n, const scomplex* alpha, const scomplex *x, const f77_int* incx, const scomplex* beta, scomplex *y, const f77_int* incy) +{ + caxpby_blis_impl(n, alpha, x, incx, beta, y, incy); +} + +#endif +#endif diff --git a/frame/util/bli_util_api_wrap_blis_impl.h b/frame/util/bli_util_api_wrap_blis_impl.h new file mode 100644 index 0000000000..3da4f2ddef --- /dev/null +++ b/frame/util/bli_util_api_wrap_blis_impl.h @@ -0,0 +1,1677 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLI_UTIL_API_WRAP_BLIS_IMPL_H_ +#define BLI_UTIL_API_WRAP_BLIS_IMPL_H_ + +// file define different formats of BLAS _blis_impl APIs- uppercase with +// and without underscore, lowercase without underscore. + +#ifndef BLIS_ENABLE_NO_UNDERSCORE_API +#ifndef BLIS_ENABLE_UPPERCASE_API +//Level 1 APIs +BLIS_EXPORT_BLIS void SROTG_BLIS_IMPL(float *sa, float *sb, float *c, float *s); + +BLIS_EXPORT_BLIS void srotg_blis_impl_(float *sa, float *sb, float *c, float *s); + +BLIS_EXPORT_BLIS void SROTG_BLIS_IMPL_(float *sa, float *sb, float *c, float *s); + + + +BLIS_EXPORT_BLIS void SROTMG_BLIS_IMPL(float *sd1, float *sd2, float *sx1, const float *sy1, float *sparam); + +BLIS_EXPORT_BLIS void srotmg_blis_impl_(float *sd1, float *sd2, float *sx1, const float *sy1, float *sparam); + +BLIS_EXPORT_BLIS void SROTMG_BLIS_IMPL_(float *sd1, float *sd2, float *sx1, const float *sy1, float *sparam); + + + +BLIS_EXPORT_BLIS void SROT_BLIS_IMPL(const f77_int *n, float *sx, const f77_int *incx, float *sy, const f77_int *incy, const float *c, const float *s); + +BLIS_EXPORT_BLIS void srot_blis_impl_(const f77_int *n, float *sx, const f77_int *incx, float *sy, const f77_int *incy, const float *c, const float *s); + +BLIS_EXPORT_BLIS void SROT_BLIS_IMPL_(const f77_int *n, float *sx, const f77_int *incx, float *sy, const f77_int *incy, const float *c, const float *s); + + + +BLIS_EXPORT_BLIS void SROTM_BLIS_IMPL(const f77_int *n, float *sx, const f77_int *incx, float *sy, const f77_int *incy, const float *sparam); + +BLIS_EXPORT_BLIS void srotm_blis_impl_(const f77_int *n, float *sx, const f77_int *incx, float *sy, const f77_int *incy, const float *sparam); + +BLIS_EXPORT_BLIS void SROTM_BLIS_IMPL_(const f77_int *n, float *sx, const f77_int *incx, float *sy, const f77_int *incy, const float *sparam); + + + +BLIS_EXPORT_BLIS void SSWAP_BLIS_IMPL(const f77_int *n, float *sx, const f77_int *incx, float *sy, const f77_int *incy); + +BLIS_EXPORT_BLIS void sswap_blis_impl_(const f77_int *n, float *sx, const f77_int *incx, float *sy, const f77_int *incy); + +BLIS_EXPORT_BLIS void SSWAP_BLIS_IMPL_(const f77_int *n, float *sx, const f77_int *incx, float *sy, const f77_int *incy); + + + +BLIS_EXPORT_BLIS void SSCAL_BLIS_IMPL(const f77_int *n, const float *sa, float *sx, const f77_int *incx); + +BLIS_EXPORT_BLIS void sscal_blis_impl_(const f77_int *n, const float *sa, float *sx, const f77_int *incx); + +BLIS_EXPORT_BLIS void SSCAL_BLIS_IMPL_(const f77_int *n, const float *sa, float *sx, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void SCOPY_BLIS_IMPL(const f77_int *n, const float *sx, const f77_int *incx, float *sy, const f77_int *incy); + +BLIS_EXPORT_BLIS void scopy_blis_impl_(const f77_int *n, const float *sx, const f77_int *incx, float *sy, const f77_int *incy); + +BLIS_EXPORT_BLIS void SCOPY_BLIS_IMPL_(const f77_int *n, const float *sx, const f77_int *incx, float *sy, const f77_int *incy); + + + +BLIS_EXPORT_BLIS void SAXPY_BLIS_IMPL(const f77_int *n, const float *sa, const float *sx, const f77_int *incx, float *sy, const f77_int *incy); + +BLIS_EXPORT_BLIS void saxpy_blis_impl_(const f77_int *n, const float *sa, const float *sx, const f77_int *incx, float *sy, const f77_int *incy); + +BLIS_EXPORT_BLIS void SAXPY_BLIS_IMPL_(const f77_int *n, const float *sa, const float *sx, const f77_int *incx, float *sy, const f77_int *incy); + + + +BLIS_EXPORT_BLIS float SDOT_BLIS_IMPL(const f77_int *n, const float *sx, const f77_int *incx, const float *sy, const f77_int *incy); + +BLIS_EXPORT_BLIS float sdot_blis_impl_(const f77_int *n, const float *sx, const f77_int *incx, const float *sy, const f77_int *incy); + +BLIS_EXPORT_BLIS float SDOT_BLIS_IMPL_(const f77_int *n, const float *sx, const f77_int *incx, const float *sy, const f77_int *incy); + + + +BLIS_EXPORT_BLIS float SDSDOT_BLIS_IMPL(const f77_int *n, const float *sb, const float *sx, const f77_int *incx, const float *sy, const f77_int *incy); + +BLIS_EXPORT_BLIS float sdsdot_blis_impl_(const f77_int *n, const float *sb, const float *sx, const f77_int *incx, const float *sy, const f77_int *incy); + +BLIS_EXPORT_BLIS float SDSDOT_BLIS_IMPL_(const f77_int *n, const float *sb, const float *sx, const f77_int *incx, const float *sy, const f77_int *incy); + + + +BLIS_EXPORT_BLIS float SNRM2_BLIS_IMPL(const f77_int *n, const float *x, const f77_int *incx); + +BLIS_EXPORT_BLIS float snrm2_blis_impl_(const f77_int *n, const float *x, const f77_int *incx); + +BLIS_EXPORT_BLIS float SNRM2_BLIS_IMPL_(const f77_int *n, const float *x, const f77_int *incx); + + + +BLIS_EXPORT_BLIS float SCNRM2_BLIS_IMPL(const f77_int *n, const scomplex *x, const f77_int *incx); + +BLIS_EXPORT_BLIS float scnrm2_blis_impl_(const f77_int *n, const scomplex *x, const f77_int *incx); + +BLIS_EXPORT_BLIS float SCNRM2_BLIS_IMPL_(const f77_int *n, const scomplex *x, const f77_int *incx); + + + +BLIS_EXPORT_BLIS float SASUM_BLIS_IMPL(const f77_int *n, const float *sx, const f77_int *incx); + +BLIS_EXPORT_BLIS float sasum_blis_impl_(const f77_int *n, const float *sx, const f77_int *incx); + +BLIS_EXPORT_BLIS float SASUM_BLIS_IMPL_(const f77_int *n, const float *sx, const f77_int *incx); + + + +BLIS_EXPORT_BLIS f77_int ISAMAX_BLIS_IMPL(const f77_int *n, const float *sx, const f77_int *incx); + +BLIS_EXPORT_BLIS f77_int isamax_blis_impl_(const f77_int *n, const float *sx, const f77_int *incx); + +BLIS_EXPORT_BLIS f77_int ISAMAX_BLIS_IMPL_(const f77_int *n, const float *sx, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void DROTG_BLIS_IMPL(double *da, double *db, double *c, double *s); + +BLIS_EXPORT_BLIS void drotg_blis_impl_(double *da, double *db, double *c, double *s); + +BLIS_EXPORT_BLIS void DROTG_BLIS_IMPL_(double *da, double *db, double *c, double *s); + + + +BLIS_EXPORT_BLIS void DROTMG_BLIS_IMPL(double *dd1, double *dd2, double *dx1, const double *dy1, double *dparam); + +BLIS_EXPORT_BLIS void drotmg_blis_impl_(double *dd1, double *dd2, double *dx1, const double *dy1, double *dparam); + +BLIS_EXPORT_BLIS void DROTMG_BLIS_IMPL_(double *dd1, double *dd2, double *dx1, const double *dy1, double *dparam); + + + +BLIS_EXPORT_BLIS void DROT_BLIS_IMPL(const f77_int *n, double *dx, const f77_int *incx, double *dy, const f77_int *incy, const double *c, const double *s); + +BLIS_EXPORT_BLIS void drot_blis_impl_(const f77_int *n, double *dx, const f77_int *incx, double *dy, const f77_int *incy, const double *c, const double *s); + +BLIS_EXPORT_BLIS void DROT_BLIS_IMPL_(const f77_int *n, double *dx, const f77_int *incx, double *dy, const f77_int *incy, const double *c, const double *s); + + + +BLIS_EXPORT_BLIS void DROTM_BLIS_IMPL(const f77_int *n, double *dx, const f77_int *incx, double *dy, const f77_int *incy, const double *dparam); + +BLIS_EXPORT_BLIS void drotm_blis_impl_(const f77_int *n, double *dx, const f77_int *incx, double *dy, const f77_int *incy, const double *dparam); + +BLIS_EXPORT_BLIS void DROTM_BLIS_IMPL_(const f77_int *n, double *dx, const f77_int *incx, double *dy, const f77_int *incy, const double *dparam); + + + +BLIS_EXPORT_BLIS void DSWAP_BLIS_IMPL(const f77_int *n, double *dx, const f77_int *incx, double *dy, const f77_int *incy); + +BLIS_EXPORT_BLIS void dswap_blis_impl_(const f77_int *n, double *dx, const f77_int *incx, double *dy, const f77_int *incy); + +BLIS_EXPORT_BLIS void DSWAP_BLIS_IMPL_(const f77_int *n, double *dx, const f77_int *incx, double *dy, const f77_int *incy); + + + +BLIS_EXPORT_BLIS void DSCAL_BLIS_IMPL(const f77_int *n, const double *da, double *dx, const f77_int *incx); + +BLIS_EXPORT_BLIS void dscal_blis_impl_(const f77_int *n, const double *da, double *dx, const f77_int *incx); + +BLIS_EXPORT_BLIS void DSCAL_BLIS_IMPL_(const f77_int *n, const double *da, double *dx, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void DCOPY_BLIS_IMPL(const f77_int *n, const double *dx, const f77_int *incx, double *dy, const f77_int *incy); + +BLIS_EXPORT_BLIS void dcopy_blis_impl_(const f77_int *n, const double *dx, const f77_int *incx, double *dy, const f77_int *incy); + +BLIS_EXPORT_BLIS void DCOPY_BLIS_IMPL_(const f77_int *n, const double *dx, const f77_int *incx, double *dy, const f77_int *incy); + + + +BLIS_EXPORT_BLIS void DAXPY_BLIS_IMPL(const f77_int *n, const double *da, const double *dx, const f77_int *incx, double *dy, const f77_int *incy); + +BLIS_EXPORT_BLIS void daxpy_blis_impl_(const f77_int *n, const double *da, const double *dx, const f77_int *incx, double *dy, const f77_int *incy); + +BLIS_EXPORT_BLIS void DAXPY_BLIS_IMPL_(const f77_int *n, const double *da, const double *dx, const f77_int *incx, double *dy, const f77_int *incy); + + + +BLIS_EXPORT_BLIS double DDOT_BLIS_IMPL(const f77_int *n, const double *dx, const f77_int *incx, const double *dy, const f77_int *incy); + +BLIS_EXPORT_BLIS double ddot_blis_impl_(const f77_int *n, const double *dx, const f77_int *incx, const double *dy, const f77_int *incy); + +BLIS_EXPORT_BLIS double DDOT_BLIS_IMPL_(const f77_int *n, const double *dx, const f77_int *incx, const double *dy, const f77_int *incy); + + + +BLIS_EXPORT_BLIS double DSDOT_BLIS_IMPL(const f77_int *n, const float *sx, const f77_int *incx, const float *sy, const f77_int *incy); + +BLIS_EXPORT_BLIS double dsdot_blis_impl_(const f77_int *n, const float *sx, const f77_int *incx, const float *sy, const f77_int *incy); + +BLIS_EXPORT_BLIS double DSDOT_BLIS_IMPL_(const f77_int *n, const float *sx, const f77_int *incx, const float *sy, const f77_int *incy); + + + +BLIS_EXPORT_BLIS double DNRM2_BLIS_IMPL(const f77_int *n, const double *x, const f77_int *incx); + +BLIS_EXPORT_BLIS double dnrm2_blis_impl_(const f77_int *n, const double *x, const f77_int *incx); + +BLIS_EXPORT_BLIS double DNRM2_BLIS_IMPL_(const f77_int *n, const double *x, const f77_int *incx); + + + +BLIS_EXPORT_BLIS double DZNRM2_BLIS_IMPL(const f77_int *n, const dcomplex *x, const f77_int *incx); + +BLIS_EXPORT_BLIS double dznrm2_blis_impl_(const f77_int *n, const dcomplex *x, const f77_int *incx); + +BLIS_EXPORT_BLIS double DZNRM2_BLIS_IMPL_(const f77_int *n, const dcomplex *x, const f77_int *incx); + + + +BLIS_EXPORT_BLIS double DASUM_BLIS_IMPL(const f77_int *n, const double *dx, const f77_int *incx); + +BLIS_EXPORT_BLIS double dasum_blis_impl_(const f77_int *n, const double *dx, const f77_int *incx); + +BLIS_EXPORT_BLIS double DASUM_BLIS_IMPL_(const f77_int *n, const double *dx, const f77_int *incx); + + + +BLIS_EXPORT_BLIS f77_int IDAMAX_BLIS_IMPL(const f77_int *n, const double *dx, const f77_int *incx); + +BLIS_EXPORT_BLIS f77_int idamax_blis_impl_(const f77_int *n, const double *dx, const f77_int *incx); + +BLIS_EXPORT_BLIS f77_int IDAMAX_BLIS_IMPL_(const f77_int *n, const double *dx, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void CROTG_BLIS_IMPL(scomplex *ca, bla_scomplex *cb, bla_real *c, scomplex *s); + +BLIS_EXPORT_BLIS void crotg_blis_impl_(scomplex *ca, bla_scomplex *cb, bla_real *c, scomplex *s); + +BLIS_EXPORT_BLIS void CROTG_BLIS_IMPL_(scomplex *ca, bla_scomplex *cb, bla_real *c, scomplex *s); + + + +BLIS_EXPORT_BLIS void CSROT_BLIS_IMPL(const f77_int *n, scomplex *cx, const f77_int *incx, scomplex *cy, const f77_int *incy, const float *c, const float *s); + +BLIS_EXPORT_BLIS void csrot_blis_impl_(const f77_int *n, scomplex *cx, const f77_int *incx, scomplex *cy, const f77_int *incy, const float *c, const float *s); + +BLIS_EXPORT_BLIS void CSROT_BLIS_IMPL_(const f77_int *n, scomplex *cx, const f77_int *incx, scomplex *cy, const f77_int *incy, const float *c, const float *s); + + + +BLIS_EXPORT_BLIS void CSWAP_BLIS_IMPL(const f77_int *n, scomplex *cx, const f77_int *incx, scomplex *cy, const f77_int *incy); + +BLIS_EXPORT_BLIS void cswap_blis_impl_(const f77_int *n, scomplex *cx, const f77_int *incx, scomplex *cy, const f77_int *incy); + +BLIS_EXPORT_BLIS void CSWAP_BLIS_IMPL_(const f77_int *n, scomplex *cx, const f77_int *incx, scomplex *cy, const f77_int *incy); + + + +BLIS_EXPORT_BLIS void CSCAL_BLIS_IMPL(const f77_int *n, const scomplex *ca, scomplex *cx, const f77_int *incx); + +BLIS_EXPORT_BLIS void cscal_blis_impl_(const f77_int *n, const scomplex *ca, scomplex *cx, const f77_int *incx); + +BLIS_EXPORT_BLIS void CSCAL_BLIS_IMPL_(const f77_int *n, const scomplex *ca, scomplex *cx, const f77_int *incx); + + +BLIS_EXPORT_BLIS void CSSCAL_BLIS_IMPL(const f77_int *n, const float *sa, scomplex *cx, const f77_int *incx); + +BLIS_EXPORT_BLIS void csscal_blis_impl_(const f77_int *n, const float *sa, scomplex *cx, const f77_int *incx); + +BLIS_EXPORT_BLIS void CSSCAL_BLIS_IMPL_(const f77_int *n, const float *sa, scomplex *cx, const f77_int *incx); + + +BLIS_EXPORT_BLIS void CCOPY_BLIS_IMPL(const f77_int *n, const scomplex *cx, const f77_int *incx, scomplex *cy, const f77_int *incy); + +BLIS_EXPORT_BLIS void ccopy_blis_impl_(const f77_int *n, const scomplex *cx, const f77_int *incx, scomplex *cy, const f77_int *incy); + +BLIS_EXPORT_BLIS void CCOPY_BLIS_IMPL_(const f77_int *n, const scomplex *cx, const f77_int *incx, scomplex *cy, const f77_int *incy); + + +BLIS_EXPORT_BLIS void CAXPY_BLIS_IMPL(const f77_int *n, const scomplex *ca, const scomplex *cx, const f77_int *incx, scomplex *cy, const f77_int *incy); + +BLIS_EXPORT_BLIS void caxpy_blis_impl_(const f77_int *n, const scomplex *ca, const scomplex *cx, const f77_int *incx, scomplex *cy, const f77_int *incy); + +BLIS_EXPORT_BLIS void CAXPY_BLIS_IMPL_(const f77_int *n, const scomplex *ca, const scomplex *cx, const f77_int *incx,scomplex *cy, const f77_int *incy); + + +#ifdef BLIS_DISABLE_COMPLEX_RETURN_INTEL + +BLIS_EXPORT_BLIS scomplex CDOTC_BLIS_IMPL(const f77_int* n, const scomplex* x, const f77_int* incx, const scomplex* y, const f77_int* incy); + +BLIS_EXPORT_BLIS scomplex cdotc_blis_impl_(const f77_int* n, const scomplex* x, const f77_int* incx, const scomplex* y, const f77_int* incy); + +BLIS_EXPORT_BLIS scomplex CDOTC_BLIS_IMPL_(const f77_int* n, const scomplex* x, const f77_int* incx, const scomplex* y, const f77_int* incy); + + + +BLIS_EXPORT_BLIS scomplex CDOTU_BLIS_IMPL(const f77_int* n, const scomplex* x, const f77_int* incx,const scomplex* y, const f77_int* incy); + +BLIS_EXPORT_BLIS scomplex cdotu_blis_impl_(const f77_int* n, const scomplex* x, const f77_int* incx,const scomplex* y, const f77_int* incy); + +BLIS_EXPORT_BLIS scomplex CDOTU_BLIS_IMPL_(const f77_int* n, const scomplex* x, const f77_int* incx,const scomplex* y, const f77_int* incy); + + + +BLIS_EXPORT_BLIS dcomplex ZDOTC_BLIS_IMPL(const f77_int* n, const dcomplex* x, const f77_int* incx, const dcomplex* y, const f77_int* incy); + +BLIS_EXPORT_BLIS dcomplex zdotc_blis_impl_(const f77_int* n, const dcomplex* x, const f77_int* incx, const dcomplex* y, const f77_int* incy); + +BLIS_EXPORT_BLIS dcomplex ZDOTC_BLIS_IMPL_(const f77_int* n, const dcomplex* x, const f77_int* incx, const dcomplex* y, const f77_int* incy); + + + +BLIS_EXPORT_BLIS dcomplex ZDOTU_BLIS_IMPL(const f77_int* n, const dcomplex* x, const f77_int* incx, const dcomplex* y, const f77_int* incy); + +BLIS_EXPORT_BLIS dcomplex zdotu_blis_impl_(const f77_int* n, const dcomplex* x, const f77_int* incx, const dcomplex* y, const f77_int* incy); + +BLIS_EXPORT_BLIS dcomplex ZDOTU_BLIS_IMPL_(const f77_int* n, const dcomplex* x, const f77_int* incx, const dcomplex* y, const f77_int* incy); + +#else + +BLIS_EXPORT_BLIS void CDOTC_BLIS_IMPL(scomplex* retval, const f77_int *n, const scomplex *cx, const f77_int *incx, const scomplex *cy, const f77_int *incy); + +BLIS_EXPORT_BLIS void cdotc_blis_impl_(scomplex* retval, const f77_int *n, const scomplex *cx, const f77_int *incx, const scomplex *cy, const f77_int *incy); + +BLIS_EXPORT_BLIS void CDOTC_BLIS_IMPL_(scomplex* retval, const f77_int *n, const scomplex *cx, const f77_int *incx, const scomplex *cy, const f77_int *incy); + + + +BLIS_EXPORT_BLIS void CDOTU_BLIS_IMPL(scomplex* retval, const f77_int *n, const scomplex *cx, const f77_int *incx, const scomplex *cy, const f77_int *incy); + +BLIS_EXPORT_BLIS void cdotu_blis_impl_(scomplex* retval, const f77_int *n, const scomplex *cx, const f77_int *incx, const scomplex *cy, const f77_int *incy); + +BLIS_EXPORT_BLIS void CDOTU_BLIS_IMPL_(scomplex* retval, const f77_int *n, const scomplex *cx, const f77_int *incx, const scomplex *cy, const f77_int *incy); + + + +BLIS_EXPORT_BLIS void ZDOTC_BLIS_IMPL(dcomplex* retval, const f77_int *n, const dcomplex *zx, const f77_int *incx, const dcomplex *zy, const f77_int *incy); + +BLIS_EXPORT_BLIS void zdotc_blis_impl_(dcomplex* retval, const f77_int *n, const dcomplex *zx, const f77_int *incx, const dcomplex *zy, const f77_int *incy); + +BLIS_EXPORT_BLIS void ZDOTC_BLIS_IMPL_(dcomplex* retval, const f77_int *n, const dcomplex *zx, const f77_int *incx, const dcomplex *zy, const f77_int *incy); + + + +BLIS_EXPORT_BLIS void ZDOTU_BLIS_IMPL(dcomplex* retval, const f77_int *n, const dcomplex *zx, const f77_int *incx, const dcomplex *zy, const f77_int *incy); + +BLIS_EXPORT_BLIS void zdotu_blis_impl_(dcomplex* retval, const f77_int *n, const dcomplex *zx, const f77_int *incx, const dcomplex *zy, const f77_int *incy); + +BLIS_EXPORT_BLIS void ZDOTU_BLIS_IMPL_(dcomplex* retval, const f77_int *n, const dcomplex *zx, const f77_int *incx, const dcomplex *zy, const f77_int *incy); + +#endif + + +BLIS_EXPORT_BLIS float SCASUM_BLIS_IMPL(const f77_int *n, const scomplex *cx, const f77_int *incx); + +BLIS_EXPORT_BLIS float scasum_blis_impl_(const f77_int *n, const scomplex *cx, const f77_int *incx); + +BLIS_EXPORT_BLIS float SCASUM_BLIS_IMPL_(const f77_int *n, const scomplex *cx, const f77_int *incx); + + + +BLIS_EXPORT_BLIS f77_int ICAMAX_BLIS_IMPL(const f77_int *n, const scomplex *cx, const f77_int *incx); + +BLIS_EXPORT_BLIS f77_int icamax_blis_impl_(const f77_int *n, const scomplex *cx, const f77_int *incx); + +BLIS_EXPORT_BLIS f77_int ICAMAX_BLIS_IMPL_(const f77_int *n, const scomplex *cx, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void ZROTG_BLIS_IMPL(dcomplex *ca, bla_dcomplex *cb, bla_double *c, dcomplex *s); + +BLIS_EXPORT_BLIS void zrotg_blis_impl_(dcomplex *ca, bla_dcomplex *cb, bla_double *c, dcomplex *s); + +BLIS_EXPORT_BLIS void ZROTG_BLIS_IMPL_(dcomplex *ca, bla_dcomplex *cb, bla_double *c, dcomplex *s); + + + +BLIS_EXPORT_BLIS void ZDROT_BLIS_IMPL(const f77_int *n, dcomplex *cx, const f77_int *incx, dcomplex *cy, const f77_int *incy, const double *c, const double *s); + +BLIS_EXPORT_BLIS void zdrot_blis_impl_(const f77_int *n, dcomplex *cx, const f77_int *incx, dcomplex *cy, const f77_int *incy, const double *c, const double *s); + +BLIS_EXPORT_BLIS void ZDROT_BLIS_IMPL_(const f77_int *n, dcomplex *cx, const f77_int *incx, dcomplex *cy, const f77_int *incy, const double *c, const double *s); + + + +BLIS_EXPORT_BLIS void ZSWAP_BLIS_IMPL(const f77_int *n, dcomplex *zx, const f77_int *incx, dcomplex *zy, const f77_int *incy); + +BLIS_EXPORT_BLIS void zswap_blis_impl_(const f77_int *n, dcomplex *zx, const f77_int *incx, dcomplex *zy, const f77_int *incy); + +BLIS_EXPORT_BLIS void ZSWAP_BLIS_IMPL_(const f77_int *n, dcomplex *zx, const f77_int *incx, dcomplex *zy, const f77_int *incy); + + + +BLIS_EXPORT_BLIS void ZSCAL_BLIS_IMPL(const f77_int *n, const dcomplex *za, dcomplex *zx, const f77_int *incx); + +BLIS_EXPORT_BLIS void zscal_blis_impl_(const f77_int *n, const dcomplex *za, dcomplex *zx, const f77_int *incx); + +BLIS_EXPORT_BLIS void ZSCAL_BLIS_IMPL_(const f77_int *n, const dcomplex *za, dcomplex *zx, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void ZDSCAL_BLIS_IMPL(const f77_int *n, const double *da, dcomplex *zx, const f77_int *incx); + +BLIS_EXPORT_BLIS void zdscal_blis_impl_(const f77_int *n, const double *da, dcomplex *zx, const f77_int *incx); + +BLIS_EXPORT_BLIS void ZDSCAL_BLIS_IMPL_(const f77_int *n, const double *da, dcomplex *zx, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void ZCOPY_BLIS_IMPL(const f77_int *n, const dcomplex *zx, const f77_int *incx, dcomplex *zy, const f77_int *incy); + +BLIS_EXPORT_BLIS void zcopy_blis_impl_(const f77_int *n, const dcomplex *zx, const f77_int *incx, dcomplex *zy, const f77_int *incy); + +BLIS_EXPORT_BLIS void ZCOPY_BLIS_IMPL_(const f77_int *n, const dcomplex *zx, const f77_int *incx, dcomplex *zy, const f77_int *incy); + + + +BLIS_EXPORT_BLIS void ZAXPY_BLIS_IMPL(const f77_int *n, const dcomplex *za, const dcomplex *zx, const f77_int *incx, dcomplex *zy, const f77_int *incy); + +BLIS_EXPORT_BLIS void zaxpy_blis_impl_(const f77_int *n, const dcomplex *za, const dcomplex *zx, const f77_int *incx, dcomplex *zy, const f77_int *incy); + +BLIS_EXPORT_BLIS void ZAXPY_BLIS_IMPL_(const f77_int *n, const dcomplex *za, const dcomplex *zx, const f77_int *incx, dcomplex *zy, const f77_int *incy); + + + +BLIS_EXPORT_BLIS double DZASUM_BLIS_IMPL(const f77_int *n, const dcomplex *zx, const f77_int *incx); + +BLIS_EXPORT_BLIS double dzasum_blis_impl_(const f77_int *n, const dcomplex *zx, const f77_int *incx); + +BLIS_EXPORT_BLIS double DZASUM_BLIS_IMPL_(const f77_int *n, const dcomplex *zx, const f77_int *incx); + + + +BLIS_EXPORT_BLIS f77_int IZAMAX_BLIS_IMPL(const f77_int *n, const dcomplex *zx, const f77_int *incx); + +BLIS_EXPORT_BLIS f77_int izamax_blis_impl_(const f77_int *n, const dcomplex *zx, const f77_int *incx); + +BLIS_EXPORT_BLIS f77_int IZAMAX_BLIS_IMPL_(const f77_int *n, const dcomplex *zx, const f77_int *incx); + + + +BLIS_EXPORT_BLIS f77_int ICAMIN_BLIS_IMPL( const f77_int* n, const scomplex* x, const f77_int* incx); + +BLIS_EXPORT_BLIS f77_int icamin_blis_impl_( const f77_int* n, const scomplex* x, const f77_int* incx); + +BLIS_EXPORT_BLIS f77_int ICAMIN_BLIS_IMPL_( const f77_int* n, const scomplex* x, const f77_int* incx); + + + +BLIS_EXPORT_BLIS f77_int IDAMIN_BLIS_IMPL( const f77_int* n, const double* x, const f77_int* incx); + +BLIS_EXPORT_BLIS f77_int idamin_blis_impl_( const f77_int* n, const double* x, const f77_int* incx); + +BLIS_EXPORT_BLIS f77_int IDAMIN_BLIS_IMPL_( const f77_int* n, const double* x, const f77_int* incx); + + + +BLIS_EXPORT_BLIS f77_int ISAMIN_BLIS_IMPL( const f77_int* n, const float* x, const f77_int* incx); + +BLIS_EXPORT_BLIS f77_int isamin_blis_impl_( const f77_int* n, const float* x, const f77_int* incx); + +BLIS_EXPORT_BLIS f77_int ISAMIN_BLIS_IMPL_( const f77_int* n, const float* x, const f77_int* incx); + + + +BLIS_EXPORT_BLIS f77_int IZAMIN_BLIS_IMPL( const f77_int* n, const dcomplex* x, const f77_int* incx); + +BLIS_EXPORT_BLIS f77_int izamin_blis_impl_( const f77_int* n, const dcomplex* x, const f77_int* incx); + +BLIS_EXPORT_BLIS f77_int IZAMIN_BLIS_IMPL_( const f77_int* n, const dcomplex* x, const f77_int* incx); + + + +//Level 2 APIs +BLIS_EXPORT_BLIS void SGEMV_BLIS_IMPL(const char *trans, const f77_int *m, const f77_int *n, const float *alpha, const float *a, const f77_int *lda, const float *x, const f77_int *incx, const float *beta, float *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void sgemv_blis_impl_(const char *trans, const f77_int *m, const f77_int *n, const float *alpha, const float *a, const f77_int *lda, const float *x, const f77_int *incx, const float *beta, float *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void SGEMV_BLIS_IMPL_(const char *trans, const f77_int *m, const f77_int *n, const float *alpha, const float *a, const f77_int *lda, const float *x, const f77_int *incx, const float *beta, float *y, const f77_int *incy); + + + +BLIS_EXPORT_BLIS void SGBMV_BLIS_IMPL(const char *trans, const f77_int *m, const f77_int *n, const f77_int *kl, const f77_int *ku, const float *alpha, const float *a, const f77_int *lda, const float *x, const f77_int *incx, const float *beta, float *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void sgbmv_blis_impl_(const char *trans, const f77_int *m, const f77_int *n, const f77_int *kl, const f77_int *ku, const float *alpha, const float *a, const f77_int *lda, const float *x, const f77_int *incx, const float *beta, float *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void SGBMV_BLIS_IMPL_(const char *trans, const f77_int *m, const f77_int *n, const f77_int *kl, const f77_int *ku, const float *alpha, const float *a, const f77_int *lda, const float *x, const f77_int *incx, const float *beta, float *y, const f77_int *incy); + + + +BLIS_EXPORT_BLIS void SSYMV_BLIS_IMPL(const char *uplo, const f77_int *n, const float *alpha, const float *a, const f77_int *lda, const float *x, const f77_int *incx, const float *beta, float *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void ssymv_blis_impl_(const char *uplo, const f77_int *n, const float *alpha, const float *a, const f77_int *lda, const float *x, const f77_int *incx, const float *beta, float *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void SSYMV_BLIS_IMPL_(const char *uplo, const f77_int *n, const float *alpha, const float *a, const f77_int *lda, const float *x, const f77_int *incx, const float *beta, float *y, const f77_int *incy); + + + +BLIS_EXPORT_BLIS void SSBMV_BLIS_IMPL(const char *uplo, const f77_int *n, const f77_int *k, const float *alpha, const float *a, const f77_int *lda, const float *x, const f77_int *incx, const float *beta, float *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void ssbmv_blis_impl_(const char *uplo, const f77_int *n, const f77_int *k, const float *alpha, const float *a, const f77_int *lda, const float *x, const f77_int *incx, const float *beta, float *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void SSBMV_BLIS_IMPL_(const char *uplo, const f77_int *n, const f77_int *k, const float *alpha, const float *a, const f77_int *lda, const float *x, const f77_int *incx, const float *beta, float *y, const f77_int *incy); + + + +BLIS_EXPORT_BLIS void SSPMV_BLIS_IMPL(const char *uplo, const f77_int *n, const float *alpha, const float *ap, const float *x, const f77_int *incx, const float *beta, float *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void sspmv_blis_impl_(const char *uplo, const f77_int *n, const float *alpha, const float *ap, const float *x, const f77_int *incx, const float *beta, float *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void SSPMV_BLIS_IMPL_(const char *uplo, const f77_int *n, const float *alpha, const float *ap, const float *x, const f77_int *incx, const float *beta, float *y, const f77_int *incy); + + + +BLIS_EXPORT_BLIS void STRMV_BLIS_IMPL(const char *uplo, const char *trans, const char *diag, const f77_int *n, const float *a, const f77_int *lda, float *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void strmv_blis_impl_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const float *a, const f77_int *lda, float *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void STRMV_BLIS_IMPL_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const float *a, const f77_int *lda, float *x, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void STBMV_BLIS_IMPL(const char *uplo, const char *trans, const char *diag, const f77_int *n, const f77_int *k, const float *a, const f77_int *lda, float *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void stbmv_blis_impl_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const f77_int *k, const float *a, const f77_int *lda, float *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void STBMV_BLIS_IMPL_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const f77_int *k, const float *a, const f77_int *lda, float *x, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void STPMV_BLIS_IMPL(const char *uplo, const char *trans, const char *diag, const f77_int *n, const float *ap, float *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void stpmv_blis_impl_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const float *ap, float *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void STPMV_BLIS_IMPL_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const float *ap, float *x, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void STRSV_BLIS_IMPL(const char *uplo, const char *trans, const char *diag, const f77_int *n, const float *a, const f77_int *lda, float *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void strsv_blis_impl_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const float *a, const f77_int *lda, float *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void STRSV_BLIS_IMPL_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const float *a, const f77_int *lda, float *x, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void STBSV_BLIS_IMPL(const char *uplo, const char *trans, const char *diag, const f77_int *n, const f77_int *k, const float *a, const f77_int *lda, float *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void stbsv_blis_impl_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const f77_int *k, const float *a, const f77_int *lda, float *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void STBSV_BLIS_IMPL_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const f77_int *k, const float *a, const f77_int *lda, float *x, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void STPSV_BLIS_IMPL(const char *uplo, const char *trans, const char *diag, const f77_int *n, const float *ap, float *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void stpsv_blis_impl_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const float *ap, float *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void STPSV_BLIS_IMPL_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const float *ap, float *x, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void SGER_BLIS_IMPL(const f77_int *m, const f77_int *n, const float *alpha, const float *x, const f77_int *incx, const float *y, const f77_int *incy, float *a, const f77_int *lda); + +BLIS_EXPORT_BLIS void sger_blis_impl_(const f77_int *m, const f77_int *n, const float *alpha, const float *x, const f77_int *incx, const float *y, const f77_int *incy, float *a, const f77_int *lda); + +BLIS_EXPORT_BLIS void SGER_BLIS_IMPL_(const f77_int *m, const f77_int *n, const float *alpha, const float *x, const f77_int *incx, const float *y, const f77_int *incy, float *a, const f77_int *lda); + + + +BLIS_EXPORT_BLIS void SSYR_BLIS_IMPL(const char *uplo, const f77_int *n, const float *alpha, const float *x, const f77_int *incx, float *a, const f77_int *lda); + +BLIS_EXPORT_BLIS void ssyr_blis_impl_(const char *uplo, const f77_int *n, const float *alpha, const float *x, const f77_int *incx, float *a, const f77_int *lda); + +BLIS_EXPORT_BLIS void SSYR_BLIS_IMPL_(const char *uplo, const f77_int *n, const float *alpha, const float *x, const f77_int *incx, float *a, const f77_int *lda); + + + +BLIS_EXPORT_BLIS void SSPR_BLIS_IMPL(const char *uplo, const f77_int *n, const float *alpha, const float *x, const f77_int *incx, float *ap); + +BLIS_EXPORT_BLIS void sspr_blis_impl_(const char *uplo, const f77_int *n, const float *alpha, const float *x, const f77_int *incx, float *ap); + +BLIS_EXPORT_BLIS void SSPR_BLIS_IMPL_(const char *uplo, const f77_int *n, const float *alpha, const float *x, const f77_int *incx, float *ap); + + + +BLIS_EXPORT_BLIS void SSYR2_BLIS_IMPL(const char *uplo, const f77_int *n, const float *alpha, const float *x, const f77_int *incx, const float *y, const f77_int *incy, float *a, const f77_int *lda); + +BLIS_EXPORT_BLIS void ssyr2_blis_impl_(const char *uplo, const f77_int *n, const float *alpha, const float *x, const f77_int *incx, const float *y, const f77_int *incy, float *a, const f77_int *lda); + +BLIS_EXPORT_BLIS void SSYR2_BLIS_IMPL_(const char *uplo, const f77_int *n, const float *alpha, const float *x, const f77_int *incx, const float *y, const f77_int *incy, float *a, const f77_int *lda); + + + +BLIS_EXPORT_BLIS void SSPR2_BLIS_IMPL(const char *uplo, const f77_int *n, const float *alpha, const float *x, const f77_int *incx, const float *y, const f77_int *incy, float *ap); + +BLIS_EXPORT_BLIS void sspr2_blis_impl_(const char *uplo, const f77_int *n, const float *alpha, const float *x, const f77_int *incx, const float *y, const f77_int *incy, float *ap); + +BLIS_EXPORT_BLIS void SSPR2_BLIS_IMPL_(const char *uplo, const f77_int *n, const float *alpha, const float *x, const f77_int *incx, const float *y, const f77_int *incy, float *ap); + + + +BLIS_EXPORT_BLIS void DGEMV_BLIS_IMPL(const char *trans, const f77_int *m, const f77_int *n, const double *alpha, const double *a, const f77_int *lda, const double *x, const f77_int *incx, const double *beta, double *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void dgemv_blis_impl_(const char *trans, const f77_int *m, const f77_int *n, const double *alpha, const double *a, const f77_int *lda, const double *x, const f77_int *incx, const double *beta, double *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void DGEMV_BLIS_IMPL_(const char *trans, const f77_int *m, const f77_int *n, const double *alpha, const double *a, const f77_int *lda, const double *x, const f77_int *incx, const double *beta, double *y, const f77_int *incy); + + + +BLIS_EXPORT_BLIS void DGBMV_BLIS_IMPL(const char *trans, const f77_int *m, const f77_int *n, const f77_int *kl, const f77_int *ku, const double *alpha, const double *a, const f77_int *lda, const double *x, const f77_int *incx, const double *beta, double *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void dgbmv_blis_impl_(const char *trans, const f77_int *m, const f77_int *n, const f77_int *kl, const f77_int *ku, const double *alpha, const double *a, const f77_int *lda, const double *x, const f77_int *incx, const double *beta, double *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void DGBMV_BLIS_IMPL_(const char *trans, const f77_int *m, const f77_int *n, const f77_int *kl, const f77_int *ku, const double *alpha, const double *a, const f77_int *lda, const double *x, const f77_int *incx, const double *beta, double *y, const f77_int *incy); + + + +BLIS_EXPORT_BLIS void DSYMV_BLIS_IMPL(const char *uplo, const f77_int *n, const double *alpha, const double *a, const f77_int *lda, const double *x, const f77_int *incx, const double *beta, double *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void dsymv_blis_impl_(const char *uplo, const f77_int *n, const double *alpha, const double *a, const f77_int *lda, const double *x, const f77_int *incx, const double *beta, double *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void DSYMV_BLIS_IMPL_(const char *uplo, const f77_int *n, const double *alpha, const double *a, const f77_int *lda, const double *x, const f77_int *incx, const double *beta, double *y, const f77_int *incy); + + + +BLIS_EXPORT_BLIS void DSBMV_BLIS_IMPL(const char *uplo, const f77_int *n, const f77_int *k, const double *alpha, const double *a, const f77_int *lda, const double *x, const f77_int *incx, const double *beta, double *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void dsbmv_blis_impl_(const char *uplo, const f77_int *n, const f77_int *k, const double *alpha, const double *a, const f77_int *lda, const double *x, const f77_int *incx, const double *beta, double *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void DSBMV_BLIS_IMPL_(const char *uplo, const f77_int *n, const f77_int *k, const double *alpha, const double *a, const f77_int *lda, const double *x, const f77_int *incx, const double *beta, double *y, const f77_int *incy); + + + +BLIS_EXPORT_BLIS void DSPMV_BLIS_IMPL(const char *uplo, const f77_int *n, const double *alpha, const double *ap, const double *x, const f77_int *incx, const double *beta, double *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void dspmv_blis_impl_(const char *uplo, const f77_int *n, const double *alpha, const double *ap, const double *x, const f77_int *incx, const double *beta, double *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void DSPMV_BLIS_IMPL_(const char *uplo, const f77_int *n, const double *alpha, const double *ap, const double *x, const f77_int *incx, const double *beta, double *y, const f77_int *incy); + + + +BLIS_EXPORT_BLIS void DTRMV_BLIS_IMPL(const char *uplo, const char *trans, const char *diag, const f77_int *n, const double *a, const f77_int *lda, double *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void dtrmv_blis_impl_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const double *a, const f77_int *lda, double *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void DTRMV_BLIS_IMPL_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const double *a, const f77_int *lda, double *x, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void DTBMV_BLIS_IMPL(const char *uplo, const char *trans, const char *diag, const f77_int *n, const f77_int *k, const double *a, const f77_int *lda, double *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void dtbmv_blis_impl_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const f77_int *k, const double *a, const f77_int *lda, double *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void DTBMV_BLIS_IMPL_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const f77_int *k, const double *a, const f77_int *lda, double *x, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void DTPMV_BLIS_IMPL(const char *uplo, const char *trans, const char *diag, const f77_int *n, const double *ap, double *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void dtpmv_blis_impl_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const double *ap, double *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void DTPMV_BLIS_IMPL_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const double *ap, double *x, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void DTRSV_BLIS_IMPL(const char *uplo, const char *trans, const char *diag, const f77_int *n, const double *a, const f77_int *lda, double *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void dtrsv_blis_impl_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const double *a, const f77_int *lda, double *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void DTRSV_BLIS_IMPL_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const double *a, const f77_int *lda, double *x, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void DTBSV_BLIS_IMPL(const char *uplo, const char *trans, const char *diag, const f77_int *n, const f77_int *k, const double *a, const f77_int *lda, double *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void dtbsv_blis_impl_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const f77_int *k, const double *a, const f77_int *lda, double *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void DTBSV_BLIS_IMPL_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const f77_int *k, const double *a, const f77_int *lda, double *x, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void DTPSV_BLIS_IMPL(const char *uplo, const char *trans, const char *diag, const f77_int *n, const double *ap, double *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void dtpsv_blis_impl_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const double *ap, double *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void DTPSV_BLIS_IMPL_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const double *ap, double *x, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void DGER_BLIS_IMPL(const f77_int *m, const f77_int *n, const double *alpha, const double *x, const f77_int *incx, const double *y, const f77_int *incy, double *a, const f77_int *lda); + +BLIS_EXPORT_BLIS void dger_blis_impl_(const f77_int *m, const f77_int *n, const double *alpha, const double *x, const f77_int *incx, const double *y, const f77_int *incy, double *a, const f77_int *lda); + +BLIS_EXPORT_BLIS void DGER_BLIS_IMPL_(const f77_int *m, const f77_int *n, const double *alpha, const double *x, const f77_int *incx, const double *y, const f77_int *incy, double *a, const f77_int *lda); + + + +BLIS_EXPORT_BLIS void DSYR_BLIS_IMPL(const char *uplo, const f77_int *n, const double *alpha, const double *x, const f77_int *incx, double *a, const f77_int *lda); + +BLIS_EXPORT_BLIS void dsyr_blis_impl_(const char *uplo, const f77_int *n, const double *alpha, const double *x, const f77_int *incx, double *a, const f77_int *lda); + +BLIS_EXPORT_BLIS void DSYR_BLIS_IMPL_(const char *uplo, const f77_int *n, const double *alpha, const double *x, const f77_int *incx, double *a, const f77_int *lda); + + + +BLIS_EXPORT_BLIS void DSPR_BLIS_IMPL(const char *uplo, const f77_int *n, const double *alpha, const double *x, const f77_int *incx, double *ap); + +BLIS_EXPORT_BLIS void dspr_blis_impl_(const char *uplo, const f77_int *n, const double *alpha, const double *x, const f77_int *incx, double *ap); + +BLIS_EXPORT_BLIS void DSPR_BLIS_IMPL_(const char *uplo, const f77_int *n, const double *alpha, const double *x, const f77_int *incx, double *ap); + + + +BLIS_EXPORT_BLIS void DSYR2_BLIS_IMPL(const char *uplo, const f77_int *n, const double *alpha, const double *x, const f77_int *incx, const double *y, const f77_int *incy, double *a, const f77_int *lda); + +BLIS_EXPORT_BLIS void dsyr2_blis_impl_(const char *uplo, const f77_int *n, const double *alpha, const double *x, const f77_int *incx, const double *y, const f77_int *incy, double *a, const f77_int *lda); + +BLIS_EXPORT_BLIS void DSYR2_BLIS_IMPL_(const char *uplo, const f77_int *n, const double *alpha, const double *x, const f77_int *incx, const double *y, const f77_int *incy, double *a, const f77_int *lda); + + + +BLIS_EXPORT_BLIS void DSPR2_BLIS_IMPL(const char *uplo, const f77_int *n, const double *alpha, const double *x, const f77_int *incx, const double *y, const f77_int *incy, double *ap); + +BLIS_EXPORT_BLIS void dspr2_blis_impl_(const char *uplo, const f77_int *n, const double *alpha, const double *x, const f77_int *incx, const double *y, const f77_int *incy, double *ap); + +BLIS_EXPORT_BLIS void DSPR2_BLIS_IMPL_(const char *uplo, const f77_int *n, const double *alpha, const double *x, const f77_int *incx, const double *y, const f77_int *incy, double *ap); + + + +BLIS_EXPORT_BLIS void CGEMV_BLIS_IMPL(const char *trans, const f77_int *m, const f77_int *n, const scomplex *alpha, const scomplex *a, const f77_int *lda, const scomplex *x, const f77_int *incx, const scomplex *beta, scomplex *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void cgemv_blis_impl_(const char *trans, const f77_int *m, const f77_int *n, const scomplex *alpha, const scomplex *a, const f77_int *lda, const scomplex *x, const f77_int *incx, const scomplex *beta, scomplex *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void CGEMV_BLIS_IMPL_(const char *trans, const f77_int *m, const f77_int *n, const scomplex *alpha, const scomplex *a, const f77_int *lda, const scomplex *x, const f77_int *incx, const scomplex *beta, scomplex *y, const f77_int *incy); + + + +BLIS_EXPORT_BLIS void CGBMV_BLIS_IMPL(const char *trans, const f77_int *m, const f77_int *n, const f77_int *kl, const f77_int *ku, const scomplex *alpha, const scomplex *a, const f77_int *lda, const scomplex *x, const f77_int *incx, const scomplex *beta, scomplex *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void cgbmv_blis_impl_(const char *trans, const f77_int *m, const f77_int *n, const f77_int *kl, const f77_int *ku, const scomplex *alpha, const scomplex *a, const f77_int *lda, const scomplex *x, const f77_int *incx, const scomplex *beta, scomplex *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void CGBMV_BLIS_IMPL_(const char *trans, const f77_int *m, const f77_int *n, const f77_int *kl, const f77_int *ku, const scomplex *alpha, const scomplex *a, const f77_int *lda, const scomplex *x, const f77_int *incx, const scomplex *beta, scomplex *y, const f77_int *incy); + + + +BLIS_EXPORT_BLIS void CHEMV_BLIS_IMPL(const char *uplo, const f77_int *n, const scomplex *alpha, const scomplex *a, const f77_int *lda, const scomplex *x, const f77_int *incx, const scomplex *beta, scomplex *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void chemv_blis_impl_(const char *uplo, const f77_int *n, const scomplex *alpha, const scomplex *a, const f77_int *lda, const scomplex *x, const f77_int *incx, const scomplex *beta, scomplex *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void CHEMV_BLIS_IMPL_(const char *uplo, const f77_int *n, const scomplex *alpha, const scomplex *a, const f77_int *lda, const scomplex *x, const f77_int *incx, const scomplex *beta, scomplex *y, const f77_int *incy); + + + +BLIS_EXPORT_BLIS void CHBMV_BLIS_IMPL(const char *uplo, const f77_int *n, const f77_int *k, const scomplex *alpha, const scomplex *a, const f77_int *lda, const scomplex *x, const f77_int *incx, const scomplex *beta, scomplex *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void chbmv_blis_impl_(const char *uplo, const f77_int *n, const f77_int *k, const scomplex *alpha, const scomplex *a, const f77_int *lda, const scomplex *x, const f77_int *incx, const scomplex *beta, scomplex *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void CHBMV_BLIS_IMPL_(const char *uplo, const f77_int *n, const f77_int *k, const scomplex *alpha, const scomplex *a,const f77_int *lda, const scomplex *x, const f77_int *incx, const scomplex *beta, scomplex *y, const f77_int *incy); + + + +BLIS_EXPORT_BLIS void CHPMV_BLIS_IMPL(const char *uplo, const f77_int *n, const scomplex *alpha, const scomplex *ap, const scomplex *x, const f77_int *incx, const scomplex *beta, scomplex *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void chpmv_blis_impl_(const char *uplo, const f77_int *n, const scomplex *alpha, const scomplex *ap, const scomplex *x, const f77_int *incx, const scomplex *beta, scomplex *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void CHPMV_BLIS_IMPL_(const char *uplo, const f77_int *n, const scomplex *alpha, const scomplex *ap, const scomplex *x, const f77_int *incx, const scomplex *beta, scomplex *y, const f77_int *incy); + + + +BLIS_EXPORT_BLIS void CTRMV_BLIS_IMPL(const char *uplo, const char *trans, const char *diag, const f77_int *n, const scomplex *a, const f77_int *lda, scomplex *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void ctrmv_blis_impl_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const scomplex *a, const f77_int *lda, scomplex *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void CTRMV_BLIS_IMPL_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const scomplex *a, const f77_int *lda, scomplex *x, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void CTBMV_BLIS_IMPL(const char *uplo, const char *trans, const char *diag, const f77_int *n, const f77_int *k, const scomplex *a, const f77_int *lda, scomplex *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void ctbmv_blis_impl_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const f77_int *k, const scomplex *a, const f77_int *lda, scomplex *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void CTBMV_BLIS_IMPL_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const f77_int *k, const scomplex *a, const f77_int *lda, scomplex *x, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void CTPMV_BLIS_IMPL(const char *uplo, const char *trans, const char *diag, const f77_int *n, const scomplex *ap, scomplex *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void ctpmv_blis_impl_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const scomplex *ap, scomplex *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void CTPMV_BLIS_IMPL_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const scomplex *ap, scomplex *x, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void CTRSV_BLIS_IMPL(const char *uplo, const char *trans, const char *diag, const f77_int *n, const scomplex *a, const f77_int *lda, scomplex *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void ctrsv_blis_impl_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const scomplex *a, const f77_int *lda, scomplex *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void CTRSV_BLIS_IMPL_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const scomplex *a, const f77_int *lda, scomplex *x, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void CTBSV_BLIS_IMPL(const char *uplo, const char *trans, const char *diag, const f77_int *n, const f77_int *k, const scomplex *a, const f77_int *lda, scomplex *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void ctbsv_blis_impl_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const f77_int *k, const scomplex *a, const f77_int *lda, scomplex *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void CTBSV_BLIS_IMPL_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const f77_int *k, const scomplex *a, const f77_int *lda, scomplex *x, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void CTPSV_BLIS_IMPL(const char *uplo, const char *trans, const char *diag, const f77_int *n, const scomplex *ap, scomplex *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void ctpsv_blis_impl_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const scomplex *ap, scomplex *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void CTPSV_BLIS_IMPL_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const scomplex *ap, scomplex *x, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void CGERC_BLIS_IMPL(const f77_int *m, const f77_int *n, const scomplex *alpha, const scomplex *x, const f77_int *incx, const scomplex *y, const f77_int *incy, scomplex *a, const f77_int *lda); + +BLIS_EXPORT_BLIS void cgerc_blis_impl_(const f77_int *m, const f77_int *n, const scomplex *alpha, const scomplex *x, const f77_int *incx, const scomplex *y, const f77_int *incy, scomplex *a, const f77_int *lda); + +BLIS_EXPORT_BLIS void CGERC_BLIS_IMPL_(const f77_int *m, const f77_int *n, const scomplex *alpha, const scomplex *x, const f77_int *incx, const scomplex *y, const f77_int *incy, scomplex *a, const f77_int *lda); + + + +BLIS_EXPORT_BLIS void CGERU_BLIS_IMPL(const f77_int *m, const f77_int *n, const scomplex *alpha, const scomplex *x, const f77_int *incx, const scomplex *y, const f77_int *incy, scomplex *a, const f77_int *lda); + +BLIS_EXPORT_BLIS void cgeru_blis_impl_(const f77_int *m, const f77_int *n, const scomplex *alpha, const scomplex *x, const f77_int *incx, const scomplex *y, const f77_int *incy, scomplex *a, const f77_int *lda); + +BLIS_EXPORT_BLIS void CGERU_BLIS_IMPL_(const f77_int *m, const f77_int *n, const scomplex *alpha, const scomplex *x, const f77_int *incx, const scomplex *y, const f77_int *incy, scomplex *a, const f77_int *lda); + + + +BLIS_EXPORT_BLIS void CHER_BLIS_IMPL(const char *uplo, const f77_int *n, const float *alpha, const scomplex *x, const f77_int *incx, scomplex *a, const f77_int *lda); + +BLIS_EXPORT_BLIS void cher_blis_impl_(const char *uplo, const f77_int *n, const float *alpha, const scomplex *x, const f77_int *incx, scomplex *a, const f77_int *lda); + +BLIS_EXPORT_BLIS void CHER_BLIS_IMPL_(const char *uplo, const f77_int *n, const float *alpha, const scomplex *x, const f77_int *incx, scomplex *a, const f77_int *lda); + + + +BLIS_EXPORT_BLIS void CHPR_BLIS_IMPL(const char *uplo, const f77_int *n, const float *alpha, const scomplex *x, const f77_int *incx, scomplex *ap); + +BLIS_EXPORT_BLIS void chpr_blis_impl_(const char *uplo, const f77_int *n, const float *alpha, const scomplex *x, const f77_int *incx, scomplex *ap); + +BLIS_EXPORT_BLIS void CHPR_BLIS_IMPL_(const char *uplo, const f77_int *n, const float *alpha, const scomplex *x, const f77_int *incx, scomplex *ap); + + + +BLIS_EXPORT_BLIS void CHER2_BLIS_IMPL(const char *uplo, const f77_int *n, const scomplex *alpha, const scomplex *x, const f77_int *incx, const scomplex *y, const f77_int *incy, scomplex *a, const f77_int *lda); + +BLIS_EXPORT_BLIS void cher2_blis_impl_(const char *uplo, const f77_int *n, const scomplex *alpha, const scomplex *x, const f77_int *incx, const scomplex *y, const f77_int *incy, scomplex *a, const f77_int *lda); + +BLIS_EXPORT_BLIS void CHER2_BLIS_IMPL_(const char *uplo, const f77_int *n, const scomplex *alpha, const scomplex *x, const f77_int *incx, const scomplex *y, const f77_int *incy, scomplex *a, const f77_int *lda); + + + +BLIS_EXPORT_BLIS void CHPR2_BLIS_IMPL(const char *uplo, const f77_int *n, const scomplex *alpha, const scomplex *x, const f77_int *incx, const scomplex *y, const f77_int *incy, scomplex *ap); + +BLIS_EXPORT_BLIS void chpr2_blis_impl_(const char *uplo, const f77_int *n, const scomplex *alpha, const scomplex *x, const f77_int *incx, const scomplex *y, const f77_int *incy, scomplex *ap); + +BLIS_EXPORT_BLIS void CHPR2_BLIS_IMPL_(const char *uplo, const f77_int *n, const scomplex *alpha, const scomplex *x, const f77_int *incx, const scomplex *y, const f77_int *incy, scomplex *ap); + + + +BLIS_EXPORT_BLIS void ZGEMV_BLIS_IMPL(const char *trans, const f77_int *m, const f77_int *n, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, const dcomplex *x, const f77_int *incx, const dcomplex *beta, dcomplex *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void zgemv_blis_impl_(const char *trans, const f77_int *m, const f77_int *n, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, const dcomplex *x, const f77_int *incx, const dcomplex *beta, dcomplex *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void ZGEMV_BLIS_IMPL_(const char *trans, const f77_int *m, const f77_int *n, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, const dcomplex *x, const f77_int *incx, const dcomplex *beta, dcomplex *y, const f77_int *incy); + + + +BLIS_EXPORT_BLIS void ZGBMV_BLIS_IMPL(const char *trans, const f77_int *m, const f77_int *n, const f77_int *kl, const f77_int *ku, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, const dcomplex *x, const f77_int *incx, const dcomplex *beta, dcomplex *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void zgbmv_blis_impl_(const char *trans, const f77_int *m, const f77_int *n, const f77_int *kl, const f77_int *ku, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, const dcomplex *x, const f77_int *incx, const dcomplex *beta, dcomplex *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void ZGBMV_BLIS_IMPL_(const char *trans, const f77_int *m, const f77_int *n, const f77_int *kl, const f77_int *ku, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, const dcomplex *x, const f77_int *incx, const dcomplex *beta, dcomplex *y, const f77_int *incy); + + + +BLIS_EXPORT_BLIS void ZHEMV_BLIS_IMPL(const char *uplo, const f77_int *n, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, const dcomplex *x, const f77_int *incx, const dcomplex *beta, dcomplex *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void zhemv_blis_impl_(const char *uplo, const f77_int *n, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, const dcomplex *x, const f77_int *incx, const dcomplex *beta, dcomplex *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void ZHEMV_BLIS_IMPL_(const char *uplo, const f77_int *n, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, const dcomplex *x, const f77_int *incx, const dcomplex *beta, dcomplex *y, const f77_int *incy); + + + +BLIS_EXPORT_BLIS void ZHBMV_BLIS_IMPL(const char *uplo, const f77_int *n, const f77_int *k, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, const dcomplex *x, const f77_int *incx, const dcomplex *beta, dcomplex *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void zhbmv_blis_impl_(const char *uplo, const f77_int *n, const f77_int *k, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, const dcomplex *x, const f77_int *incx, const dcomplex *beta, dcomplex *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void ZHBMV_BLIS_IMPL_(const char *uplo, const f77_int *n, const f77_int *k, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, const dcomplex *x, const f77_int *incx, const dcomplex *beta, dcomplex *y, const f77_int *incy); + + + +BLIS_EXPORT_BLIS void ZHPMV_BLIS_IMPL(const char *uplo, const f77_int *n, const dcomplex *alpha, const dcomplex *ap, const dcomplex *x, const f77_int *incx, const dcomplex *beta, dcomplex *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void zhpmv_blis_impl_(const char *uplo, const f77_int *n, const dcomplex *alpha, const dcomplex *ap, const dcomplex *x, const f77_int *incx, const dcomplex *beta, dcomplex *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void ZHPMV_BLIS_IMPL_(const char *uplo, const f77_int *n, const dcomplex *alpha, const dcomplex *ap, const dcomplex *x, const f77_int *incx, const dcomplex *beta, dcomplex *y, const f77_int *incy); + + + +BLIS_EXPORT_BLIS void ZTRMV_BLIS_IMPL(const char *uplo, const char *trans, const char *diag, const f77_int *n, const dcomplex *a, const f77_int *lda, dcomplex *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void ztrmv_blis_impl_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const dcomplex *a, const f77_int *lda, dcomplex *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void ZTRMV_BLIS_IMPL_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const dcomplex *a, const f77_int *lda, dcomplex *x, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void ZTBMV_BLIS_IMPL(const char *uplo, const char *trans, const char *diag, const f77_int *n, const f77_int *k, const dcomplex *a, const f77_int *lda, dcomplex *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void ztbmv_blis_impl_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const f77_int *k, const dcomplex *a, const f77_int *lda, dcomplex *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void ZTBMV_BLIS_IMPL_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const f77_int *k, const dcomplex *a, const f77_int *lda, dcomplex *x, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void ZTPMV_BLIS_IMPL(const char *uplo, const char *trans, const char *diag, const f77_int *n, const dcomplex *ap, dcomplex *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void ztpmv_blis_impl_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const dcomplex *ap, dcomplex *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void ZTPMV_BLIS_IMPL_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const dcomplex *ap, dcomplex *x, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void ZTRSV_BLIS_IMPL(const char *uplo, const char *trans, const char *diag, const f77_int *n, const dcomplex *a, const f77_int *lda, dcomplex *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void ztrsv_blis_impl_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const dcomplex *a, const f77_int *lda, dcomplex *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void ZTRSV_BLIS_IMPL_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const dcomplex *a, const f77_int *lda, dcomplex *x, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void ZTBSV_BLIS_IMPL(const char *uplo, const char *trans, const char *diag, const f77_int *n, const f77_int *k, const dcomplex *a, const f77_int *lda, dcomplex *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void ztbsv_blis_impl_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const f77_int *k, const dcomplex *a, const f77_int *lda, dcomplex *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void ZTBSV_BLIS_IMPL_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const f77_int *k, const dcomplex *a, const f77_int *lda, dcomplex *x, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void ZTPSV_BLIS_IMPL(const char *uplo, const char *trans, const char *diag, const f77_int *n, const dcomplex *ap, dcomplex *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void ztpsv_blis_impl_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const dcomplex *ap, dcomplex *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void ZTPSV_BLIS_IMPL_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const dcomplex *ap, dcomplex *x, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void ZGERU_BLIS_IMPL(const f77_int *m, const f77_int *n, const dcomplex *alpha, const dcomplex *x, const f77_int *incx, const dcomplex *y, const f77_int *incy, dcomplex *a, const f77_int *lda); + +BLIS_EXPORT_BLIS void zgeru_blis_impl_(const f77_int *m, const f77_int *n, const dcomplex *alpha, const dcomplex *x, const f77_int *incx, const dcomplex *y, const f77_int *incy, dcomplex *a, const f77_int *lda); + +BLIS_EXPORT_BLIS void ZGERU_BLIS_IMPL_(const f77_int *m, const f77_int *n, const dcomplex *alpha, const dcomplex *x, const f77_int *incx, const dcomplex *y, const f77_int *incy, dcomplex *a, const f77_int *lda); + + + +BLIS_EXPORT_BLIS void ZGERC_BLIS_IMPL(const f77_int *m, const f77_int *n, const dcomplex *alpha, const dcomplex *x, const f77_int *incx, const dcomplex *y, const f77_int *incy, dcomplex *a, const f77_int *lda); + +BLIS_EXPORT_BLIS void zgerc_blis_impl_(const f77_int *m, const f77_int *n, const dcomplex *alpha, const dcomplex *x, const f77_int *incx, const dcomplex *y, const f77_int *incy, dcomplex *a, const f77_int *lda); + +BLIS_EXPORT_BLIS void ZGERC_BLIS_IMPL_(const f77_int *m, const f77_int *n, const dcomplex *alpha, const dcomplex *x, const f77_int *incx, const dcomplex *y, const f77_int *incy, dcomplex *a, const f77_int *lda); + + + +BLIS_EXPORT_BLIS void ZHER_BLIS_IMPL(const char *uplo, const f77_int *n, const double *alpha, const dcomplex *x, const f77_int *incx, dcomplex *a, const f77_int *lda); + +BLIS_EXPORT_BLIS void zher_blis_impl_(const char *uplo, const f77_int *n, const double *alpha, const dcomplex *x, const f77_int *incx, dcomplex *a, const f77_int *lda); + +BLIS_EXPORT_BLIS void ZHER_BLIS_IMPL_(const char *uplo, const f77_int *n, const double *alpha, const dcomplex *x, const f77_int *incx, dcomplex *a, const f77_int *lda); + + + +BLIS_EXPORT_BLIS void ZHPR_BLIS_IMPL(const char *uplo, const f77_int *n, const bla_double *alpha, const dcomplex *x, const f77_int *incx, dcomplex *ap); + +BLIS_EXPORT_BLIS void zhpr_blis_impl_(const char *uplo, const f77_int *n, const bla_double *alpha, const dcomplex *x, const f77_int *incx, dcomplex *ap); + +BLIS_EXPORT_BLIS void ZHPR_BLIS_IMPL_(const char *uplo, const f77_int *n, const bla_double *alpha, const dcomplex *x, const f77_int *incx, dcomplex *ap); + + + +BLIS_EXPORT_BLIS void ZHER2_BLIS_IMPL(const char *uplo, const f77_int *n, const dcomplex *alpha, const dcomplex *x, const f77_int *incx, const dcomplex *y, const f77_int *incy, dcomplex *a, const f77_int *lda); + +BLIS_EXPORT_BLIS void zher2_blis_impl_(const char *uplo, const f77_int *n, const dcomplex *alpha, const dcomplex *x, const f77_int *incx, const dcomplex *y, const f77_int *incy, dcomplex *a, const f77_int *lda); + +BLIS_EXPORT_BLIS void ZHER2_BLIS_IMPL_(const char *uplo, const f77_int *n, const dcomplex *alpha, const dcomplex *x, const f77_int *incx, const dcomplex *y, const f77_int *incy, dcomplex *a, const f77_int *lda); + + + +BLIS_EXPORT_BLIS void ZHPR2_BLIS_IMPL(const char *uplo, const f77_int *n, const dcomplex *alpha, const dcomplex *x, const f77_int *incx, const dcomplex *y, const f77_int *incy, dcomplex *ap); + +BLIS_EXPORT_BLIS void zhpr2_blis_impl_(const char *uplo, const f77_int *n, const dcomplex *alpha, const dcomplex *x, const f77_int *incx, const dcomplex *y, const f77_int *incy, dcomplex *ap); + +BLIS_EXPORT_BLIS void ZHPR2_BLIS_IMPL_(const char *uplo, const f77_int *n, const dcomplex *alpha, const dcomplex *x, const f77_int *incx, const dcomplex *y, const f77_int *incy, dcomplex *ap); + + + +//Level 3 APIs +BLIS_EXPORT_BLIS void SGEMM_BLIS_IMPL(const char *transa, const char *transb, const f77_int *m, const f77_int *n, const f77_int *k, const float *alpha, const float *a, const f77_int *lda, const float *b, const f77_int *ldb, const float *beta, float *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void sgemm_blis_impl_(const char *transa, const char *transb, const f77_int *m, const f77_int *n, const f77_int *k, const float *alpha, const float *a, const f77_int *lda, const float *b, const f77_int *ldb, const float *beta, float *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void SGEMM_BLIS_IMPL_(const char *transa, const char *transb, const f77_int *m, const f77_int *n, const f77_int *k, const float *alpha, const float *a, const f77_int *lda, const float *b, const f77_int *ldb, const float *beta, float *c, const f77_int *ldc); + + + +BLIS_EXPORT_BLIS void SSYMM_BLIS_IMPL(const char *side, const char *uplo, const f77_int *m, const f77_int *n, const float *alpha, const float *a, const f77_int *lda, const float *b, const f77_int *ldb, const float *beta, float *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void ssymm_blis_impl_(const char *side, const char *uplo, const f77_int *m, const f77_int *n, const float *alpha, const float *a, const f77_int *lda, const float *b, const f77_int *ldb, const float *beta, float *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void SSYMM_BLIS_IMPL_(const char *side, const char *uplo, const f77_int *m, const f77_int *n, const float *alpha, const float *a, const f77_int *lda, const float *b, const f77_int *ldb, const float *beta, float *c, const f77_int *ldc); + + + +BLIS_EXPORT_BLIS void SSYRK_BLIS_IMPL(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const float *alpha, const float *a, const f77_int *lda, const float *beta, float *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void ssyrk_blis_impl_(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const float *alpha, const float *a, const f77_int *lda, const float *beta, float *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void SSYRK_BLIS_IMPL_(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const float *alpha, const float *a, const f77_int *lda, const float *beta, float *c, const f77_int *ldc); + + + +BLIS_EXPORT_BLIS void SSYR2K_BLIS_IMPL(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const float *alpha, const float *a, const f77_int *lda, const float *b, const f77_int *ldb, const float *beta, float *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void ssyr2k_blis_impl_(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const float *alpha, const float *a, const f77_int *lda, const float *b, const f77_int *ldb, const float *beta, float *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void SSYR2K_BLIS_IMPL_(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const float *alpha, const float *a, const f77_int *lda, const float *b, const f77_int *ldb, const float *beta, float *c, const f77_int *ldc); + + + +BLIS_EXPORT_BLIS void STRMM_BLIS_IMPL(const char *side, const char *uplo, const char *transa, const char *diag, const f77_int *m, const f77_int *n, const float *alpha, const float *a, const f77_int *lda, float *b, const f77_int *ldb); + +BLIS_EXPORT_BLIS void strmm_blis_impl_(const char *side, const char *uplo, const char *transa, const char *diag, const f77_int *m, const f77_int *n, const float *alpha, const float *a, const f77_int *lda, float *b, const f77_int *ldb); + +BLIS_EXPORT_BLIS void STRMM_BLIS_IMPL_(const char *side, const char *uplo, const char *transa, const char *diag, const f77_int *m, const f77_int *n, const float *alpha, const float *a, const f77_int *lda, float *b, const f77_int *ldb); + + + +BLIS_EXPORT_BLIS void STRSM_BLIS_IMPL(const char *side, const char *uplo, const char *transa, const char *diag, const f77_int *m, const f77_int *n, const float *alpha, const float *a, const f77_int *lda, float *b, const f77_int *ldb); + +BLIS_EXPORT_BLIS void strsm_blis_impl_(const char *side, const char *uplo, const char *transa, const char *diag, const f77_int *m, const f77_int *n, const float *alpha, const float *a, const f77_int *lda, float *b, const f77_int *ldb); + +BLIS_EXPORT_BLIS void STRSM_BLIS_IMPL_(const char *side, const char *uplo, const char *transa, const char *diag, const f77_int *m, const f77_int *n, const float *alpha, const float *a, const f77_int *lda, float *b, const f77_int *ldb); + + + +BLIS_EXPORT_BLIS void DGEMM_BLIS_IMPL(const char *transa, const char *transb, const f77_int *m, const f77_int *n, const f77_int *k, const double *alpha, const double *a, const f77_int *lda, const double *b, const f77_int *ldb, const double *beta, double *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void dgemm_blis_impl_(const char *transa, const char *transb, const f77_int *m, const f77_int *n, const f77_int *k, const double *alpha, const double *a, const f77_int *lda, const double *b, const f77_int *ldb, const double *beta, double *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void DGEMM_BLIS_IMPL_(const char *transa, const char *transb, const f77_int *m, const f77_int *n, const f77_int *k, const double *alpha, const double *a, const f77_int *lda, const double *b, const f77_int *ldb, const double *beta, double *c, const f77_int *ldc); + + + +BLIS_EXPORT_BLIS void DZGEMM_BLIS_IMPL( const f77_char *transa, const f77_char *transb, const f77_int *m, const f77_int *n, const f77_int *k, const dcomplex *alpha, const double *a, const f77_int *lda, const dcomplex *b, const f77_int *ldb, const dcomplex *beta, dcomplex *c, const f77_int *ldc ); + +BLIS_EXPORT_BLIS void dzgemm_blis_impl_( const f77_char *transa, const f77_char *transb, const f77_int *m, const f77_int *n, const f77_int *k, const dcomplex *alpha, const double *a, const f77_int *lda, const dcomplex *b, const f77_int *ldb, const dcomplex *beta, dcomplex *c, const f77_int *ldc ); + +BLIS_EXPORT_BLIS void DZGEMM_BLIS_IMPL_( const f77_char *transa, const f77_char *transb, const f77_int *m, const f77_int *n, const f77_int *k, const dcomplex *alpha, const double *a, const f77_int *lda, const dcomplex *b, const f77_int *ldb, const dcomplex *beta, dcomplex *c, const f77_int *ldc ); + + + +BLIS_EXPORT_BLIS void DSYMM_BLIS_IMPL(const char *side, const char *uplo, const f77_int *m, const f77_int *n, const double *alpha, const double *a, const f77_int *lda, const double *b, const f77_int *ldb, const double *beta, double *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void dsymm_blis_impl_(const char *side, const char *uplo, const f77_int *m, const f77_int *n, const double *alpha, const double *a, const f77_int *lda, const double *b, const f77_int *ldb, const double *beta, double *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void DSYMM_BLIS_IMPL_(const char *side, const char *uplo, const f77_int *m, const f77_int *n, const double *alpha, const double *a, const f77_int *lda, const double *b, const f77_int *ldb, const double *beta, double *c, const f77_int *ldc); + + + +BLIS_EXPORT_BLIS void DSYRK_BLIS_IMPL(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const double *alpha, const double *a, const f77_int *lda, const double *beta, double *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void dsyrk_blis_impl_(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const double *alpha, const double *a, const f77_int *lda, const double *beta, double *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void DSYRK_BLIS_IMPL_(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const double *alpha, const double *a, const f77_int *lda, const double *beta, double *c, const f77_int *ldc); + + + +BLIS_EXPORT_BLIS void DSYR2K_BLIS_IMPL(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const double *alpha, const double *a, const f77_int *lda, const double *b, const f77_int *ldb, const double *beta, double *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void dsyr2k_blis_impl_(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const double *alpha, const double *a, const f77_int *lda, const double *b, const f77_int *ldb, const double *beta, double *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void DSYR2K_BLIS_IMPL_(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const double *alpha, const double *a, const f77_int *lda, const double *b, const f77_int *ldb, const double *beta, double *c, const f77_int *ldc); + + + +BLIS_EXPORT_BLIS void DTRMM_BLIS_IMPL(const char *side, const char *uplo, const char *transa, const char *diag, const f77_int *m, const f77_int *n, const double *alpha, const double *a, const f77_int *lda, double *b, const f77_int *ldb); + +BLIS_EXPORT_BLIS void dtrmm_blis_impl_(const char *side, const char *uplo, const char *transa, const char *diag, const f77_int *m, const f77_int *n, const double *alpha, const double *a, const f77_int *lda, double *b, const f77_int *ldb); + +BLIS_EXPORT_BLIS void DTRMM_BLIS_IMPL_(const char *side, const char *uplo, const char *transa, const char *diag, const f77_int *m, const f77_int *n, const double *alpha, const double *a, const f77_int *lda, double *b, const f77_int *ldb); + + + +BLIS_EXPORT_BLIS void DTRSM_BLIS_IMPL(const char *side, const char *uplo, const char *transa, const char *diag, const f77_int *m, const f77_int *n, const double *alpha, const double *a, const f77_int *lda, double *b, const f77_int *ldb); + +BLIS_EXPORT_BLIS void dtrsm_blis_impl_(const char *side, const char *uplo, const char *transa, const char *diag, const f77_int *m, const f77_int *n, const double *alpha, const double *a, const f77_int *lda, double *b, const f77_int *ldb); + +BLIS_EXPORT_BLIS void DTRSM_BLIS_IMPL_(const char *side, const char *uplo, const char *transa, const char *diag, const f77_int *m, const f77_int *n, const double *alpha, const double *a, const f77_int *lda, double *b, const f77_int *ldb); + + + +BLIS_EXPORT_BLIS void CGEMM_BLIS_IMPL(const char *transa, const char *transb, const f77_int *m, const f77_int *n, const f77_int *k, const scomplex *alpha, const scomplex *a, const f77_int *lda, const scomplex *b, const f77_int *ldb, const scomplex *beta, scomplex *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void cgemm_blis_impl_(const char *transa, const char *transb, const f77_int *m, const f77_int *n, const f77_int *k, const scomplex *alpha, const scomplex *a, const f77_int *lda, const scomplex *b, const f77_int *ldb, const scomplex *beta, scomplex *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void CGEMM_BLIS_IMPL_(const char *transa, const char *transb, const f77_int *m, const f77_int *n, const f77_int *k, const scomplex *alpha, const scomplex *a, const f77_int *lda, const scomplex *b, const f77_int *ldb, const scomplex *beta, scomplex *c, const f77_int *ldc); + + + +BLIS_EXPORT_BLIS void CSYMM_BLIS_IMPL(const char *side, const char *uplo, const f77_int *m, const f77_int *n, const scomplex *alpha, const scomplex *a, const f77_int *lda, const scomplex *b, const f77_int *ldb, const scomplex *beta, scomplex *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void csymm_blis_impl_(const char *side, const char *uplo, const f77_int *m, const f77_int *n, const scomplex *alpha, const scomplex *a, const f77_int *lda, const scomplex *b, const f77_int *ldb, const scomplex *beta, scomplex *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void CSYMM_BLIS_IMPL_(const char *side, const char *uplo, const f77_int *m, const f77_int *n, const scomplex *alpha, const scomplex *a, const f77_int *lda, const scomplex *b, const f77_int *ldb, const scomplex *beta, scomplex *c, const f77_int *ldc); + + + +BLIS_EXPORT_BLIS void CHEMM_BLIS_IMPL(const char *side, const char *uplo, const f77_int *m, const f77_int *n, const scomplex *alpha, const scomplex *a, const f77_int *lda, const scomplex *b, const f77_int *ldb, const scomplex *beta, scomplex *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void chemm_blis_impl_(const char *side, const char *uplo, const f77_int *m, const f77_int *n, const scomplex *alpha, const scomplex *a, const f77_int *lda, const scomplex *b, const f77_int *ldb, const scomplex *beta, scomplex *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void CHEMM_BLIS_IMPL_(const char *side, const char *uplo, const f77_int *m, const f77_int *n, const scomplex *alpha, const scomplex *a, const f77_int *lda, const scomplex *b, const f77_int *ldb, const scomplex *beta, scomplex *c, const f77_int *ldc); + + + +BLIS_EXPORT_BLIS void CSYRK_BLIS_IMPL(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const scomplex *alpha, const scomplex *a, const f77_int *lda, const scomplex *beta, scomplex *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void csyrk_blis_impl_(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const scomplex *alpha, const scomplex *a, const f77_int *lda, const scomplex *beta, scomplex *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void CSYRK_BLIS_IMPL_(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const scomplex *alpha, const scomplex *a, const f77_int *lda, const scomplex *beta, scomplex *c, const f77_int *ldc); + + + +BLIS_EXPORT_BLIS void CHERK_BLIS_IMPL(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const float *alpha, const scomplex *a, const f77_int *lda, const float *beta, scomplex *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void cherk_blis_impl_(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const float *alpha, const scomplex *a, const f77_int *lda, const float *beta, scomplex *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void CHERK_BLIS_IMPL_(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const float *alpha, const scomplex *a, const f77_int *lda, const float *beta, scomplex *c, const f77_int *ldc); + + + +BLIS_EXPORT_BLIS void CSYR2K_BLIS_IMPL(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const scomplex *alpha, const scomplex *a, const f77_int *lda, const scomplex *b, const f77_int *ldb, const scomplex *beta, scomplex *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void csyr2k_blis_impl_(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const scomplex *alpha, const scomplex *a, const f77_int *lda, const scomplex *b, const f77_int *ldb, const scomplex *beta, scomplex *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void CSYR2K_BLIS_IMPL_(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const scomplex *alpha, const scomplex *a, const f77_int *lda, const scomplex *b, const f77_int *ldb, const scomplex *beta, scomplex *c, const f77_int *ldc); + + + +BLIS_EXPORT_BLIS void CHER2K_BLIS_IMPL(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const scomplex *alpha, const scomplex *a, const f77_int *lda, const scomplex *b, const f77_int *ldb, const float *beta, scomplex *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void cher2k_blis_impl_(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const scomplex *alpha, const scomplex *a, const f77_int *lda, const scomplex *b, const f77_int *ldb, const float *beta, scomplex *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void CHER2K_BLIS_IMPL_(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const scomplex *alpha, const scomplex *a, const f77_int *lda, const scomplex *b, const f77_int *ldb, const float *beta, scomplex *c, const f77_int *ldc); + + + +BLIS_EXPORT_BLIS void CTRMM_BLIS_IMPL(const char *side, const char *uplo, const char *transa, const char *diag, const f77_int *m, const f77_int *n, const scomplex *alpha, const scomplex *a, const f77_int *lda, scomplex *b, const f77_int *ldb); + +BLIS_EXPORT_BLIS void ctrmm_blis_impl_(const char *side, const char *uplo, const char *transa, const char *diag, const f77_int *m, const f77_int *n, const scomplex *alpha, const scomplex *a, const f77_int *lda, scomplex *b, const f77_int *ldb); + +BLIS_EXPORT_BLIS void CTRMM_BLIS_IMPL_(const char *side, const char *uplo, const char *transa, const char *diag, const f77_int *m, const f77_int *n, const scomplex *alpha, const scomplex *a, const f77_int *lda, scomplex *b, const f77_int *ldb); + + + +BLIS_EXPORT_BLIS void CTRSM_BLIS_IMPL(const char *side, const char *uplo, const char *transa, const char *diag, const f77_int *m, const f77_int *n, const scomplex *alpha, const scomplex *a, const f77_int *lda, scomplex *b, const f77_int *ldb); + +BLIS_EXPORT_BLIS void ctrsm_blis_impl_(const char *side, const char *uplo, const char *transa, const char *diag, const f77_int *m, const f77_int *n, const scomplex *alpha, const scomplex *a, const f77_int *lda, scomplex *b, const f77_int *ldb); + +BLIS_EXPORT_BLIS void CTRSM_BLIS_IMPL_(const char *side, const char *uplo, const char *transa, const char *diag, const f77_int *m, const f77_int *n, const scomplex *alpha, const scomplex *a, const f77_int *lda, scomplex *b, const f77_int *ldb); + + + +BLIS_EXPORT_BLIS void ZGEMM_BLIS_IMPL(const char *transa, const char *transb, const f77_int *m, const f77_int *n, const f77_int *k, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, const dcomplex *b, const f77_int *ldb, const dcomplex *beta, dcomplex *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void zgemm_blis_impl_(const char *transa, const char *transb, const f77_int *m, const f77_int *n, const f77_int *k, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, const dcomplex *b, const f77_int *ldb, const dcomplex *beta, dcomplex *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void ZGEMM_BLIS_IMPL_(const char *transa, const char *transb, const f77_int *m, const f77_int *n, const f77_int *k, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, const dcomplex *b, const f77_int *ldb, const dcomplex *beta, dcomplex *c, const f77_int *ldc); + + + +BLIS_EXPORT_BLIS void ZSYMM_BLIS_IMPL(const char *side, const char *uplo, const f77_int *m, const f77_int *n, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, const dcomplex *b, const f77_int *ldb, const dcomplex *beta, dcomplex *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void zsymm_blis_impl_(const char *side, const char *uplo, const f77_int *m, const f77_int *n, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, const dcomplex *b, const f77_int *ldb, const dcomplex *beta, dcomplex *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void ZSYMM_BLIS_IMPL_(const char *side, const char *uplo, const f77_int *m, const f77_int *n, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, const dcomplex *b, const f77_int *ldb, const dcomplex *beta, dcomplex *c, const f77_int *ldc); + + + +BLIS_EXPORT_BLIS void ZHEMM_BLIS_IMPL(const char *side, const char *uplo, const f77_int *m, const f77_int *n, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, const dcomplex *b, const f77_int *ldb, const dcomplex *beta, dcomplex *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void zhemm_blis_impl_(const char *side, const char *uplo, const f77_int *m, const f77_int *n, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, const dcomplex *b, const f77_int *ldb, const dcomplex *beta, dcomplex *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void ZHEMM_BLIS_IMPL_(const char *side, const char *uplo, const f77_int *m, const f77_int *n, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, const dcomplex *b, const f77_int *ldb, const dcomplex *beta, dcomplex *c, const f77_int *ldc); + + + +BLIS_EXPORT_BLIS void ZSYRK_BLIS_IMPL(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, const dcomplex *beta, dcomplex *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void zsyrk_blis_impl_(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, const dcomplex *beta, dcomplex *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void ZSYRK_BLIS_IMPL_(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, const dcomplex *beta, dcomplex *c, const f77_int *ldc); + + + +BLIS_EXPORT_BLIS void ZHERK_BLIS_IMPL(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const double *alpha, const dcomplex *a, const f77_int *lda, const double *beta, dcomplex *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void zherk_blis_impl_(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const double *alpha, const dcomplex *a, const f77_int *lda, const double *beta, dcomplex *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void ZHERK_BLIS_IMPL_(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const double *alpha, const dcomplex *a, const f77_int *lda, const double *beta, dcomplex *c, const f77_int *ldc); + + + +BLIS_EXPORT_BLIS void ZSYR2K_BLIS_IMPL(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, const dcomplex *b, const f77_int *ldb, const dcomplex *beta, dcomplex *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void zsyr2k_blis_impl_(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, const dcomplex *b, const f77_int *ldb, const dcomplex *beta, dcomplex *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void ZSYR2K_BLIS_IMPL_(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, const dcomplex *b, const f77_int *ldb, const dcomplex *beta, dcomplex *c, const f77_int *ldc); + + + +BLIS_EXPORT_BLIS void ZHER2K_BLIS_IMPL(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, const dcomplex *b, const f77_int *ldb, const double *beta, dcomplex *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void zher2k_blis_impl_(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, const dcomplex *b, const f77_int *ldb, const double *beta, dcomplex *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void ZHER2K_BLIS_IMPL_(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, const dcomplex *b, const f77_int *ldb, const double *beta, dcomplex *c, const f77_int *ldc); + + + +BLIS_EXPORT_BLIS void ZTRMM_BLIS_IMPL(const char *side, const char *uplo, const char *transa, const char *diag, const f77_int *m, const f77_int *n, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, dcomplex *b, const f77_int *ldb); + +BLIS_EXPORT_BLIS void ztrmm_blis_impl_(const char *side, const char *uplo, const char *transa, const char *diag, const f77_int *m, const f77_int *n, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, dcomplex *b, const f77_int *ldb); + +BLIS_EXPORT_BLIS void ZTRMM_BLIS_IMPL_(const char *side, const char *uplo, const char *transa, const char *diag, const f77_int *m, const f77_int *n, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, dcomplex *b, const f77_int *ldb); + + + +BLIS_EXPORT_BLIS void ZTRSM_BLIS_IMPL(const char *side, const char *uplo, const char *transa, const char *diag, const f77_int *m, const f77_int *n, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, dcomplex *b, const f77_int *ldb); + +BLIS_EXPORT_BLIS void ztrsm_blis_impl_(const char *side, const char *uplo, const char *transa, const char *diag, const f77_int *m, const f77_int *n, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, dcomplex *b, const f77_int *ldb); + +BLIS_EXPORT_BLIS void ZTRSM_BLIS_IMPL_(const char *side, const char *uplo, const char *transa, const char *diag, const f77_int *m, const f77_int *n, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, dcomplex *b, const f77_int *ldb); + + + +// Miscellaneous APIs + +#ifdef BLIS_ENABLE_CBLAS + +BLIS_EXPORT_BLIS void CDOTCSUB_BLIS_IMPL( const f77_int* n, const scomplex* x, const f77_int* incx, const scomplex* y, const f77_int* incy, scomplex* rval); + +BLIS_EXPORT_BLIS void cdotcsub_blis_impl_( const f77_int* n, const scomplex* x, const f77_int* incx, const scomplex* y, const f77_int* incy, scomplex* rval); + +BLIS_EXPORT_BLIS void CDOTCSUB_BLIS_IMPL_( const f77_int* n, const scomplex* x, const f77_int* incx, const scomplex* y, const f77_int* incy, scomplex* rval); + + + +BLIS_EXPORT_BLIS void CDOTUSUB_BLIS_IMPL( const f77_int* n, const scomplex* x, const f77_int* incxy, const scomplex* y, const f77_int* incy, scomplex* rval); + +BLIS_EXPORT_BLIS void cdotusub_blis_impl_( const f77_int* n, const scomplex* x, const f77_int* incxy, const scomplex* y, const f77_int* incy, scomplex* rval); + +BLIS_EXPORT_BLIS void CDOTUSUB_BLIS_IMPL_( const f77_int* n, const scomplex* x, const f77_int* incxy, const scomplex* y, const f77_int* incy, scomplex* rval); + + + +BLIS_EXPORT_BLIS void DASUMSUB_BLIS_IMPL(const f77_int* n, const double* x, const f77_int* incx, double* rval); + +BLIS_EXPORT_BLIS void dasumsub_blis_impl_(const f77_int* n, const double* x, const f77_int* incx, double* rval); + +BLIS_EXPORT_BLIS void DASUMSUB_BLIS_IMPL_(const f77_int* n, const double* x, const f77_int* incx, double* rval); + + + +BLIS_EXPORT_BLIS void DDOTSUB_BLIS_IMPL(const f77_int* n, const double* x, const f77_int* incx, const double* y, const f77_int* incy, double* rval); + +BLIS_EXPORT_BLIS void ddotsub_blis_impl_(const f77_int* n, const double* x, const f77_int* incx, const double* y, const f77_int* incy, double* rval); + +BLIS_EXPORT_BLIS void DDOTSUB_BLIS_IMPL_(const f77_int* n, const double* x, const f77_int* incx, const double* y, const f77_int* incy, double* rval); + + + +BLIS_EXPORT_BLIS void DNRM2SUB_BLIS_IMPL(const f77_int* n, const double* x, const f77_int* incx, double *rval); + +BLIS_EXPORT_BLIS void dnrm2sub_blis_impl_(const f77_int* n, const double* x, const f77_int* incx, double *rval); + +BLIS_EXPORT_BLIS void DNRM2SUB_BLIS_IMPL_(const f77_int* n, const double* x, const f77_int* incx, double *rval); + + + +BLIS_EXPORT_BLIS void DZASUMSUB_BLIS_IMPL(const f77_int* n, const dcomplex* x, const f77_int* incx, double* rval); + +BLIS_EXPORT_BLIS void dzasumsub_blis_impl_(const f77_int* n, const dcomplex* x, const f77_int* incx, double* rval); + +BLIS_EXPORT_BLIS void DZASUMSUB_BLIS_IMPL_(const f77_int* n, const dcomplex* x, const f77_int* incx, double* rval); + + + +BLIS_EXPORT_BLIS void DZNRM2SUB_BLIS_IMPL(const f77_int* n, const dcomplex* x, const f77_int* incx, double* rval); + +BLIS_EXPORT_BLIS void dznrm2sub_blis_impl_(const f77_int* n, const dcomplex* x, const f77_int* incx, double* rval); + +BLIS_EXPORT_BLIS void DZNRM2SUB_BLIS_IMPL_(const f77_int* n, const dcomplex* x, const f77_int* incx, double* rval); + + + +BLIS_EXPORT_BLIS void ICAMAXSUB_BLIS_IMPL(const f77_int* n, const scomplex* x, const f77_int* incx, f77_int* rval); + +BLIS_EXPORT_BLIS void icamaxsub_blis_impl_(const f77_int* n, const scomplex* x, const f77_int* incx, f77_int* rval); + +BLIS_EXPORT_BLIS void ICAMAXSUB_BLIS_IMPL_(const f77_int* n, const scomplex* x, const f77_int* incx, f77_int* rval); + + + +BLIS_EXPORT_BLIS void ICAMINSUB_BLIS_IMPL( const f77_int* n, const scomplex* x, const f77_int* incx, f77_int* rval); + +BLIS_EXPORT_BLIS void icaminsub_blis_impl_( const f77_int* n, const scomplex* x, const f77_int* incx, f77_int* rval); + +BLIS_EXPORT_BLIS void ICAMINSUB_BLIS_IMPL_( const f77_int* n, const scomplex* x, const f77_int* incx, f77_int* rval); + + + +BLIS_EXPORT_BLIS void IDAMAXSUB_BLIS_IMPL( const f77_int* n, const double* x, const f77_int* incx, f77_int* rval); + +BLIS_EXPORT_BLIS void idamaxsub_blis_impl_( const f77_int* n, const double* x, const f77_int* incx, f77_int* rval); + +BLIS_EXPORT_BLIS void IDAMAXSUB_BLIS_IMPL_( const f77_int* n, const double* x, const f77_int* incx, f77_int* rval); + + + +BLIS_EXPORT_BLIS void IDAMINSUB_BLIS_IMPL(const f77_int* n, const double* x, const f77_int* incx, f77_int* rval); + +BLIS_EXPORT_BLIS void idaminsub_blis_impl_(const f77_int* n, const double* x, const f77_int* incx, f77_int* rval); + +BLIS_EXPORT_BLIS void IDAMINSUB_BLIS_IMPL_(const f77_int* n, const double* x, const f77_int* incx, f77_int* rval); + + + +BLIS_EXPORT_BLIS void ISAMAXSUB_BLIS_IMPL( const f77_int* n, const float* x, const f77_int* incx, f77_int* rval); + +BLIS_EXPORT_BLIS void isamaxsub_blis_impl_( const f77_int* n, const float* x, const f77_int* incx, f77_int* rval); + +BLIS_EXPORT_BLIS void ISAMAXSUB_BLIS_IMPL_( const f77_int* n, const float* x, const f77_int* incx, f77_int* rval); + + + +BLIS_EXPORT_BLIS void ISAMINSUB_BLIS_IMPL( const f77_int* n, const float* x, const f77_int* incx, f77_int* rval); + +BLIS_EXPORT_BLIS void isaminsub_blis_impl_( const f77_int* n, const float* x, const f77_int* incx, f77_int* rval); + +BLIS_EXPORT_BLIS void ISAMINSUB_BLIS_IMPL_( const f77_int* n, const float* x, const f77_int* incx, f77_int* rval); + + + +BLIS_EXPORT_BLIS void IZAMINSUB_BLIS_IMPL( const f77_int* n, const dcomplex* x, const f77_int* incx, f77_int* rval); + +BLIS_EXPORT_BLIS void izaminsub_blis_impl_( const f77_int* n, const dcomplex* x, const f77_int* incx, f77_int* rval); + +BLIS_EXPORT_BLIS void IZAMINSUB_BLIS_IMPL_( const f77_int* n, const dcomplex* x, const f77_int* incx, f77_int* rval); + + + +BLIS_EXPORT_BLIS void IZAMAXSUB_BLIS_IMPL( const f77_int* n, const dcomplex* x, const f77_int* incx, f77_int* rval); + +BLIS_EXPORT_BLIS void izamaxsub_blis_impl_( const f77_int* n, const dcomplex* x, const f77_int* incx, f77_int* rval); + +BLIS_EXPORT_BLIS void IZAMAXSUB_BLIS_IMPL_( const f77_int* n, const dcomplex* x, const f77_int* incx, f77_int* rval); + + + +BLIS_EXPORT_BLIS void SASUMSUB_BLIS_IMPL( const f77_int* n, const float* x, const f77_int* incx, float* rval); + +BLIS_EXPORT_BLIS void sasumsub_blis_impl_( const f77_int* n, const float* x, const f77_int* incx, float* rval); + +BLIS_EXPORT_BLIS void SASUMSUB_BLIS_IMPL_( const f77_int* n, const float* x, const f77_int* incx, float* rval); + + + +BLIS_EXPORT_BLIS void SCASUMSUB_BLIS_IMPL( const f77_int* n, const scomplex* x, const f77_int* incx, float* rval); + +BLIS_EXPORT_BLIS void scasumsub_blis_impl_( const f77_int* n, const scomplex* x, const f77_int* incx, float* rval); + +BLIS_EXPORT_BLIS void SCASUMSUB_BLIS_IMPL_( const f77_int* n, const scomplex* x, const f77_int* incx, float* rval); + + + +BLIS_EXPORT_BLIS void SCNRM2SUB_BLIS_IMPL( const f77_int* n, const scomplex* x, const f77_int* incx, float* rval); + +BLIS_EXPORT_BLIS void scnrm2sub_blis_impl_( const f77_int* n, const scomplex* x, const f77_int* incx, float* rval); + +BLIS_EXPORT_BLIS void SCNRM2SUB_BLIS_IMPL_( const f77_int* n, const scomplex* x, const f77_int* incx, float* rval); + + + +BLIS_EXPORT_BLIS void SDOTSUB_BLIS_IMPL( const f77_int* n, const float* x, const f77_int* incx, const float* y, const f77_int* incy, float* rval); + +BLIS_EXPORT_BLIS void sdotsub_blis_impl_( const f77_int* n, const float* x, const f77_int* incx, const float* y, const f77_int* incy, float* rval); + +BLIS_EXPORT_BLIS void SDOTSUB_BLIS_IMPL_( const f77_int* n, const float* x, const f77_int* incx, const float* y, const f77_int* incy, float* rval); + + + +BLIS_EXPORT_BLIS void SNRM2SUB_BLIS_IMPL( const f77_int* n, const float* x, const f77_int* incx, float *rval); + +BLIS_EXPORT_BLIS void snrm2sub_blis_impl_( const f77_int* n, const float* x, const f77_int* incx, float *rval); + +BLIS_EXPORT_BLIS void SNRM2SUB_BLIS_IMPL_( const f77_int* n, const float* x, const f77_int* incx, float *rval); + + + +BLIS_EXPORT_BLIS void ZDOTCSUB_BLIS_IMPL( const f77_int* n, const dcomplex* x, const f77_int* incx, const dcomplex* y, const f77_int* incy, dcomplex* rval); + +BLIS_EXPORT_BLIS void zdotcsub_blis_impl_( const f77_int* n, const dcomplex* x, const f77_int* incx, const dcomplex* y, const f77_int* incy, dcomplex* rval); + +BLIS_EXPORT_BLIS void ZDOTCSUB_BLIS_IMPL_( const f77_int* n, const dcomplex* x, const f77_int* incx, const dcomplex* y, const f77_int* incy, dcomplex* rval); + + + +BLIS_EXPORT_BLIS void ZDOTUSUB_BLIS_IMPL( const f77_int* n, const dcomplex* x, const f77_int* incx, const dcomplex* y, const f77_int* incy, dcomplex* rval); + +BLIS_EXPORT_BLIS void zdotusub_blis_impl_( const f77_int* n, const dcomplex* x, const f77_int* incx, const dcomplex* y, const f77_int* incy, dcomplex* rval); + +BLIS_EXPORT_BLIS void ZDOTUSUB_BLIS_IMPL_( const f77_int* n, const dcomplex* x, const f77_int* incx, const dcomplex* y, const f77_int* incy, dcomplex* rval); + + + +BLIS_EXPORT_BLIS void SDSDOTSUB_BLIS_IMPL( const f77_int* n, float* sb, const float* x, const f77_int* incx, const float* y, const f77_int* incy, float* dot); + +BLIS_EXPORT_BLIS void sdsdotsub_blis_impl_( const f77_int* n, float* sb, const float* x, const f77_int* incx, const float* y, const f77_int* incy, float* dot); + +BLIS_EXPORT_BLIS void SDSDOTSUB_BLIS_IMPL_( const f77_int* n, float* sb, const float* x, const f77_int* incx, const float* y, const f77_int* incy, float* dot); + + + +BLIS_EXPORT_BLIS void DSDOTSUB_BLIS_IMPL( const f77_int* n, const float* x, const f77_int* incx, const float* y, const f77_int* incy, double* dot); + +BLIS_EXPORT_BLIS void dsdotsub_blis_impl_( const f77_int* n, const float* x, const f77_int* incx, const float* y, const f77_int* incy, double* dot); + +BLIS_EXPORT_BLIS void DSDOTSUB_BLIS_IMPL_( const f77_int* n, const float* x, const f77_int* incx, const float* y, const f77_int* incy, double* dot); + +#endif // BLIS_ENABLE_CBLAS + + +BLIS_EXPORT_BLIS f77_int LSAME_BLIS_IMPL(const char *ca, const char *cb, const f77_int a, const f77_int b); + +BLIS_EXPORT_BLIS f77_int lsame_blis_impl_(const char *ca, const char *cb, const f77_int a, const f77_int b); + +BLIS_EXPORT_BLIS f77_int LSAME_BLIS_IMPL_(const char *ca, const char *cb, const f77_int a, const f77_int b); + + + +BLIS_EXPORT_BLIS void XERBLA_BLIS_IMPL(const char *srname, const f77_int *info, ftnlen n); + +BLIS_EXPORT_BLIS void xerbla_blis_impl_(const char *srname, const f77_int *info, ftnlen n); + +BLIS_EXPORT_BLIS void XERBLA_BLIS_IMPL_(const char *srname, const f77_int *info, ftnlen n); + + + +//Auxiliary APIs +BLIS_EXPORT_BLIS double DCABS1_BLIS_IMPL(bla_dcomplex *z); + +BLIS_EXPORT_BLIS double dcabs1_blis_impl_(bla_dcomplex *z); + +BLIS_EXPORT_BLIS double DCABS1_BLIS_IMPL_(bla_dcomplex *z); + + + +BLIS_EXPORT_BLIS float SCABS1_BLIS_IMPL(bla_scomplex* z); + +BLIS_EXPORT_BLIS float scabs1_blis_impl_(bla_scomplex* z); + +BLIS_EXPORT_BLIS float SCABS1_BLIS_IMPL_(bla_scomplex* z); + + + +//BLAS Extension APIs +BLIS_EXPORT_BLIS void CAXPBY_BLIS_IMPL( const f77_int* n, const scomplex* alpha, const scomplex *x, const f77_int* incx, const scomplex* beta, scomplex *y, const f77_int* incy); + +BLIS_EXPORT_BLIS void caxpby_blis_impl_( const f77_int* n, const scomplex* alpha, const scomplex *x, const f77_int* incx, const scomplex* beta, scomplex *y, const f77_int* incy); + +BLIS_EXPORT_BLIS void CAXPBY_BLIS_IMPL_( const f77_int* n, const scomplex* alpha, const scomplex *x, const f77_int* incx, const scomplex* beta, scomplex *y, const f77_int* incy); + + + +BLIS_EXPORT_BLIS void CGEMM3M_BLIS_IMPL( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const scomplex* alpha, const scomplex* a, const f77_int* lda, const scomplex* b, const f77_int* ldb, const scomplex* beta, scomplex* c, const f77_int* ldc); + +BLIS_EXPORT_BLIS void cgemm3m_blis_impl_( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const scomplex* alpha, const scomplex* a, const f77_int* lda, const scomplex* b, const f77_int* ldb, const scomplex* beta, scomplex* c, const f77_int* ldc); + +BLIS_EXPORT_BLIS void CGEMM3M_BLIS_IMPL_( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const scomplex* alpha, const scomplex* a, const f77_int* lda, const scomplex* b, const f77_int* ldb, const scomplex* beta, scomplex* c, const f77_int* ldc); + + + +BLIS_EXPORT_BLIS void CGEMM_BATCH_BLIS_IMPL( const f77_char* transa_array, const f77_char* transb_array, const f77_int *m_array, const f77_int *n_array, const f77_int *k_array, const scomplex* alpha_array, const scomplex** a_array, const f77_int *lda_array, const scomplex** b_array, const f77_int *ldb_array, const scomplex* beta_array, scomplex** c_array, const f77_int *ldc_array, const f77_int* group_count, const f77_int *group_size); + +BLIS_EXPORT_BLIS void cgemm_batch_blis_impl_( const f77_char* transa_array, const f77_char* transb_array, const f77_int *m_array, const f77_int *n_array, const f77_int *k_array, const scomplex* alpha_array, const scomplex** a_array, const f77_int *lda_array, const scomplex** b_array, const f77_int *ldb_array, const scomplex* beta_array, scomplex** c_array, const f77_int *ldc_array, const f77_int* group_count, const f77_int *group_size); + +BLIS_EXPORT_BLIS void CGEMM_BATCH_BLIS_IMPL_( const f77_char* transa_array, const f77_char* transb_array, const f77_int *m_array, const f77_int *n_array, const f77_int *k_array, const scomplex* alpha_array, const scomplex** a_array, const f77_int *lda_array, const scomplex** b_array, const f77_int *ldb_array, const scomplex* beta_array, scomplex** c_array, const f77_int *ldc_array, const f77_int* group_count, const f77_int *group_size); + + + +BLIS_EXPORT_BLIS void CGEMMT_BLIS_IMPL( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const scomplex* alpha, const scomplex* a, const f77_int* lda, const scomplex* b, const f77_int* ldb, const scomplex* beta, scomplex* c, const f77_int* ldc); + +BLIS_EXPORT_BLIS void cgemmt_blis_impl_( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const scomplex* alpha, const scomplex* a, const f77_int* lda, const scomplex* b, const f77_int* ldb, const scomplex* beta, scomplex* c, const f77_int* ldc); + +BLIS_EXPORT_BLIS void CGEMMT_BLIS_IMPL_( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const scomplex* alpha, const scomplex* a, const f77_int* lda, const scomplex* b, const f77_int* ldb, const scomplex* beta, scomplex* c, const f77_int* ldc); + + + +BLIS_EXPORT_BLIS void DAXPBY_BLIS_IMPL(const f77_int* n, const double* alpha, const double *x, const f77_int* incx, const double* beta, double *y, const f77_int* incy); + +BLIS_EXPORT_BLIS void daxpby_blis_impl_(const f77_int* n, const double* alpha, const double *x, const f77_int* incx, const double* beta, double *y, const f77_int* incy); + +BLIS_EXPORT_BLIS void DAXPBY_BLIS_IMPL_(const f77_int* n, const double* alpha, const double *x, const f77_int* incx, const double* beta, double *y, const f77_int* incy); + + + +BLIS_EXPORT_BLIS void DGEMM_BATCH_BLIS_IMPL( const f77_char* transa_array, const f77_char* transb_array, const f77_int *m_array, const f77_int *n_array, const f77_int *k_array, const double* alpha_array, const double** a_array, const f77_int *lda_array, const double** b_array, const f77_int *ldb_array, const double* beta_array, double** c_array, const f77_int *ldc_array, const f77_int* group_count, const f77_int *group_size); + +BLIS_EXPORT_BLIS void dgemm_batch_blis_impl_( const f77_char* transa_array, const f77_char* transb_array, const f77_int *m_array, const f77_int *n_array, const f77_int *k_array, const double* alpha_array, const double** a_array, const f77_int *lda_array, const double** b_array, const f77_int *ldb_array, const double* beta_array, double** c_array, const f77_int *ldc_array, const f77_int* group_count, const f77_int *group_size); + +BLIS_EXPORT_BLIS void DGEMM_BATCH_BLIS_IMPL_( const f77_char* transa_array, const f77_char* transb_array, const f77_int *m_array, const f77_int *n_array, const f77_int *k_array, const double* alpha_array, const double** a_array, const f77_int *lda_array, const double** b_array, const f77_int *ldb_array, const double* beta_array, double** c_array, const f77_int *ldc_array, const f77_int* group_count, const f77_int *group_size); + + + +BLIS_EXPORT_BLIS f77_int DGEMM_PACK_GET_SIZE_BLIS_IMPL(const f77_char* identifier, const f77_int* pm, const f77_int* pn, const f77_int* pk); + +BLIS_EXPORT_BLIS f77_int dgemm_pack_get_size_blis_impl_(const f77_char* identifier, const f77_int* pm, const f77_int* pn, const f77_int* pk); + +BLIS_EXPORT_BLIS f77_int DGEMM_PACK_GET_SIZE_BLIS_IMPL_(const f77_char* identifier, const f77_int* pm, const f77_int* pn, const f77_int* pk); + + + +BLIS_EXPORT_BLIS void DGEMM_PACK_BLIS_IMPL( const f77_char* identifier, const f77_char* trans, const f77_int* mm, const f77_int* nn, const f77_int* kk, const double* alpha, const double* src, const f77_int* pld, double* dest ); + +BLIS_EXPORT_BLIS void dgemm_pack_blis_impl_( const f77_char* identifier, const f77_char* trans, const f77_int* mm, const f77_int* nn, const f77_int* kk, const double* alpha, const double* src, const f77_int* pld, double* dest ); + +BLIS_EXPORT_BLIS void DGEMM_PACK_BLIS_IMPL_( const f77_char* identifier, const f77_char* trans, const f77_int* mm, const f77_int* nn, const f77_int* kk, const double* alpha, const double* src, const f77_int* pld, double* dest ); + + + +BLIS_EXPORT_BLIS void DGEMM_COMPUTE_BLIS_IMPL( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const double* a, const f77_int* lda, const double* b, const f77_int* ldb, const double* beta, double* c, const f77_int* ldc ); + +BLIS_EXPORT_BLIS void dgemm_compute_blis_impl_( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const double* a, const f77_int* lda, const double* b, const f77_int* ldb, const double* beta, double* c, const f77_int* ldc ); + +BLIS_EXPORT_BLIS void DGEMM_COMPUTE_BLIS_IMPL_( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const double* a, const f77_int* lda, const double* b, const f77_int* ldb, const double* beta, double* c, const f77_int* ldc ); + + + +BLIS_EXPORT_BLIS void DGEMMT_BLIS_IMPL( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const double* alpha, const double* a, const f77_int* lda, const double* b, const f77_int* ldb, const double* beta, double* c, const f77_int* ldc); + +BLIS_EXPORT_BLIS void dgemmt_blis_impl_( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const double* alpha, const double* a, const f77_int* lda, const double* b, const f77_int* ldb, const double* beta, double* c, const f77_int* ldc); + +BLIS_EXPORT_BLIS void DGEMMT_BLIS_IMPL_( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const double* alpha, const double* a, const f77_int* lda, const double* b, const f77_int* ldb, const double* beta, double* c, const f77_int* ldc); + + + +BLIS_EXPORT_BLIS void SAXPBY_BLIS_IMPL( const f77_int* n, const float* alpha, const float *x, const f77_int* incx, const float* beta, float *y, const f77_int* incy); + +BLIS_EXPORT_BLIS void saxpby_blis_impl_( const f77_int* n, const float* alpha, const float *x, const f77_int* incx, const float* beta, float *y, const f77_int* incy); + +BLIS_EXPORT_BLIS void SAXPBY_BLIS_IMPL_( const f77_int* n, const float* alpha, const float *x, const f77_int* incx, const float* beta, float *y, const f77_int* incy); + + + +BLIS_EXPORT_BLIS void SGEMM_BATCH_BLIS_IMPL(const f77_char* transa_array, const f77_char* transb_array, const f77_int *m_array, const f77_int *n_array, const f77_int *k_array, const float* alpha_array, const float** a_array, const f77_int *lda_array, const float** b_array, const f77_int *ldb_array, const float* beta_array, float** c_array, const f77_int *ldc_array, const f77_int* group_count, const f77_int *group_size); + +BLIS_EXPORT_BLIS void sgemm_batch_blis_impl_(const f77_char* transa_array, const f77_char* transb_array, const f77_int *m_array, const f77_int *n_array, const f77_int *k_array, const float* alpha_array, const float** a_array, const f77_int *lda_array, const float** b_array, const f77_int *ldb_array, const float* beta_array, float** c_array, const f77_int *ldc_array, const f77_int* group_count, const f77_int *group_size); + +BLIS_EXPORT_BLIS void SGEMM_BATCH_BLIS_IMPL_(const f77_char* transa_array, const f77_char* transb_array, const f77_int *m_array, const f77_int *n_array, const f77_int *k_array, const float* alpha_array, const float** a_array, const f77_int *lda_array, const float** b_array, const f77_int *ldb_array, const float* beta_array, float** c_array, const f77_int *ldc_array, const f77_int* group_count, const f77_int *group_size); + + + +BLIS_EXPORT_BLIS f77_int SGEMM_PACK_GET_SIZE_BLIS_IMPL(const f77_char* identifier, const f77_int* pm, const f77_int* pn, const f77_int* pk); + +BLIS_EXPORT_BLIS f77_int sgemm_pack_get_size_blis_impl_(const f77_char* identifier, const f77_int* pm, const f77_int* pn, const f77_int* pk); + +BLIS_EXPORT_BLIS f77_int SGEMM_PACK_GET_SIZE_BLIS_IMPL_(const f77_char* identifier, const f77_int* pm, const f77_int* pn, const f77_int* pk); + + + +BLIS_EXPORT_BLIS void SGEMM_PACK_BLIS_IMPL( const f77_char* identifier, const f77_char* trans, const f77_int* mm, const f77_int* nn, const f77_int* kk, const float* alpha, const float* src, const f77_int* pld, float* dest ); + +BLIS_EXPORT_BLIS void sgemm_pack_blis_impl_( const f77_char* identifier, const f77_char* trans, const f77_int* mm, const f77_int* nn, const f77_int* kk, const float* alpha, const float* src, const f77_int* pld, float* dest ); + +BLIS_EXPORT_BLIS void SGEMM_PACK_BLIS_IMPL_( const f77_char* identifier, const f77_char* trans, const f77_int* mm, const f77_int* nn, const f77_int* kk, const float* alpha, const float* src, const f77_int* pld, float* dest ); + + + +BLIS_EXPORT_BLIS void SGEMM_COMPUTE_BLIS_IMPL( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const float* a, const f77_int* lda, const float* b, const f77_int* ldb, const float* beta, float* c, const f77_int* ldc ); + +BLIS_EXPORT_BLIS void sgemm_compute_blis_impl_( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const float* a, const f77_int* lda, const float* b, const f77_int* ldb, const float* beta, float* c, const f77_int* ldc ); + +BLIS_EXPORT_BLIS void SGEMM_COMPUTE_BLIS_IMPL_( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const float* a, const f77_int* lda, const float* b, const f77_int* ldb, const float* beta, float* c, const f77_int* ldc ); + + + +BLIS_EXPORT_BLIS void SGEMMT_BLIS_IMPL( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const float* alpha, const float* a, const f77_int* lda, const float* b, const f77_int* ldb, const float* beta, float* c, const f77_int* ldc); + +BLIS_EXPORT_BLIS void sgemmt_blis_impl_( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const float* alpha, const float* a, const f77_int* lda, const float* b, const f77_int* ldb, const float* beta, float* c, const f77_int* ldc); + +BLIS_EXPORT_BLIS void SGEMMT_BLIS_IMPL_( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const float* alpha, const float* a, const f77_int* lda, const float* b, const f77_int* ldb, const float* beta, float* c, const f77_int* ldc); + + + +BLIS_EXPORT_BLIS void ZAXPBY_BLIS_IMPL( const f77_int* n, const dcomplex* alpha, const dcomplex *x, const f77_int* incx, const dcomplex* beta, dcomplex *y, const f77_int* incy); + +BLIS_EXPORT_BLIS void zaxpby_blis_impl_( const f77_int* n, const dcomplex* alpha, const dcomplex *x, const f77_int* incx, const dcomplex* beta, dcomplex *y, const f77_int* incy); + +BLIS_EXPORT_BLIS void ZAXPBY_BLIS_IMPL_( const f77_int* n, const dcomplex* alpha, const dcomplex *x, const f77_int* incx, const dcomplex* beta, dcomplex *y, const f77_int* incy); + + + +BLIS_EXPORT_BLIS void ZGEMM3M_BLIS_IMPL( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const dcomplex* alpha, const dcomplex* a, const f77_int* lda, const dcomplex* b, const f77_int* ldb, const dcomplex* beta, dcomplex* c, const f77_int* ldc); + +BLIS_EXPORT_BLIS void zgemm3m_blis_impl_( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const dcomplex* alpha, const dcomplex* a, const f77_int* lda, const dcomplex* b, const f77_int* ldb, const dcomplex* beta, dcomplex* c, const f77_int* ldc); + +BLIS_EXPORT_BLIS void ZGEMM3M_BLIS_IMPL_( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const dcomplex* alpha, const dcomplex* a, const f77_int* lda, const dcomplex* b, const f77_int* ldb, const dcomplex* beta, dcomplex* c, const f77_int* ldc); + + + +BLIS_EXPORT_BLIS void ZGEMM_BATCH_BLIS_IMPL( const f77_char* transa_array, const f77_char* transb_array, const f77_int *m_array, const f77_int *n_array, const f77_int *k_array, const dcomplex* alpha_array, const dcomplex** a_array, const f77_int *lda_array, const dcomplex** b_array, const f77_int *ldb_array, const dcomplex* beta_array, dcomplex** c_array, const f77_int *ldc_array, const f77_int* group_count, const f77_int *group_size); + +BLIS_EXPORT_BLIS void zgemm_batch_blis_impl_( const f77_char* transa_array, const f77_char* transb_array, const f77_int *m_array, const f77_int *n_array, const f77_int *k_array, const dcomplex* alpha_array, const dcomplex** a_array, const f77_int *lda_array, const dcomplex** b_array, const f77_int *ldb_array, const dcomplex* beta_array, dcomplex** c_array, const f77_int *ldc_array, const f77_int* group_count, const f77_int *group_size); + +BLIS_EXPORT_BLIS void ZGEMM_BATCH_BLIS_IMPL_( const f77_char* transa_array, const f77_char* transb_array, const f77_int *m_array, const f77_int *n_array, const f77_int *k_array, const dcomplex* alpha_array, const dcomplex** a_array, const f77_int *lda_array, const dcomplex** b_array, const f77_int *ldb_array, const dcomplex* beta_array, dcomplex** c_array, const f77_int *ldc_array, const f77_int* group_count, const f77_int *group_size); + + + +BLIS_EXPORT_BLIS void ZGEMMT_BLIS_IMPL( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const dcomplex* alpha, const dcomplex* a, const f77_int* lda, const dcomplex* b, const f77_int* ldb, const dcomplex* beta, dcomplex* c, const f77_int* ldc); + +BLIS_EXPORT_BLIS void zgemmt_blis_impl_( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const dcomplex* alpha, const dcomplex* a, const f77_int* lda, const dcomplex* b, const f77_int* ldb, const dcomplex* beta, dcomplex* c, const f77_int* ldc); + +BLIS_EXPORT_BLIS void ZGEMMT_BLIS_IMPL_( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const dcomplex* alpha, const dcomplex* a, const f77_int* lda, const dcomplex* b, const f77_int* ldb, const dcomplex* beta, dcomplex* c, const f77_int* ldc); + +#endif +#endif + +#endif // BLI_UTIL_API_WRAP_BLIS_IMPL_H_ diff --git a/frame/util/bli_util_unb_var1.c b/frame/util/bli_util_unb_var1.c index 22de9b152d..0061c6e8e4 100644 --- a/frame/util/bli_util_unb_var1.c +++ b/frame/util/bli_util_unb_var1.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -325,22 +325,23 @@ void bli_cnormfv_unb_var1 arch_t id = bli_arch_query_id(); switch ( id ) { + case BLIS_ARCH_ZEN5: case BLIS_ARCH_ZEN4: case BLIS_ARCH_ZEN3: case BLIS_ARCH_ZEN2: case BLIS_ARCH_ZEN:; #ifdef BLIS_KERNELS_ZEN - // Memory pool declarations for packing vector X. - // Initialize mem pool buffer to NULL and size to 0. - // "buf" and "size" fields are assigned once memory - // is allocated from the pool in bli_pba_acquire_m(). - // This will ensure bli_mem_is_alloc() will be passed on - // an allocated memory if created or a NULL. - mem_t mem_buf_X = { 0 }; - rntm_t rntm_l; - // Packing for non-unit strided vector x. - if ( incx != 1 ) + // Handling the kernel call in case of non-unit strides + if ( ( incx != 1 ) && ( incx != 0 ) ) { + // Memory pool declarations for packing vector X. + // Initialize mem pool buffer to NULL and size to 0. + // "buf" and "size" fields are assigned once memory + // is allocated from the pool in bli_pba_acquire_m(). + // This will ensure bli_mem_is_alloc() will be passed on + // an allocated memory if created or a NULL. + mem_t mem_buf_X = { 0 }; + rntm_t rntm_l; // In order to get the buffer from pool via rntm access to memory broker // is needed. Following are initializations for rntm. if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); } @@ -376,18 +377,24 @@ void bli_cnormfv_unb_var1 } incx_buf = 1; } - } - bli_scnorm2fv_unb_var1_avx2( n, x_buf, incx_buf, norm, cntx ); + bli_scnorm2fv_unb_var1_avx2( n, x_buf, incx_buf, norm, cntx ); - if ( bli_mem_is_alloc( &mem_buf_X ) ) + if ( bli_mem_is_alloc( &mem_buf_X ) ) + { + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_scnorm2fv_unb_var1_avx2(): releasing mem pool block\n" ); + #endif + // Return the buffer to pool. + bli_pba_release( &rntm_l , &mem_buf_X ); + } + } + else { - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_scnorm2fv_unb_var1_avx2(): releasing mem pool block\n" ); - #endif - // Return the buffer to pool. - bli_pba_release( &rntm_l , &mem_buf_X ); + // Call the kernel with the unit-strided vector x + bli_scnorm2fv_unb_var1_avx2( n, x_buf, incx_buf, norm, cntx ); } + break; #endif default:; @@ -447,10 +454,16 @@ void bli_znormfv_unb_var1 void ( *reduce_fp )( dim_t, double*, inc_t, double*, cntx_t* ) = NULL; dcomplex *x_buf = x; + dim_t fast_path_thresh = 1; +#ifdef BLIS_ENABLE_OPENMP dim_t nt_ideal = -1; + dim_t simd_factor = 1; +#endif + arch_t id = bli_arch_query_id(); switch ( id ) { + case BLIS_ARCH_ZEN5: case BLIS_ARCH_ZEN4: case BLIS_ARCH_ZEN3: case BLIS_ARCH_ZEN2: @@ -459,6 +472,11 @@ void bli_znormfv_unb_var1 norm_fp = bli_dznorm2fv_unb_var1_avx2; reduce_fp = bli_dnorm2fv_unb_var1_avx2; + fast_path_thresh = 2000; + + #ifdef BLIS_ENABLE_OPENMP + simd_factor = 2; + #endif break; #endif @@ -504,37 +522,49 @@ void bli_znormfv_unb_var1 return; /* - When the size is such that nt_ideal is 1, and packing is not - required( incx == 1 ), we can directly call the kernel to - avoid framework overheads( fast-path ). + Call the kernel directly in these two cases : + - When incx == 0, since the norm is based on only one dcomplex + element( two real double precision elements ) + - When the size is such that nt_ideal is 1, and packing is not + required( incx == 1 ), we can directly call the kernel to + avoid framework overheads( fast-path ). */ - else if ( ( incx == 1 ) && ( n < 2000 ) ) + else if ( ( incx == 0 ) || ( ( incx == 1 ) && ( n < fast_path_thresh ) ) ) { norm_fp( n, x, incx, norm, cntx ); return; } - // Setting the ideal number of threads if support is enabled - #if defined( BLIS_ENABLE_OPENMP ) && defined( AOCL_DYNAMIC ) - if ( n < 2000 ) - nt_ideal = 1; - else if ( n < 6500 ) - nt_ideal = 4; - else if ( n < 71000 ) - nt_ideal = 8; - else if ( n < 200000 ) - nt_ideal = 16; - else if ( n < 1530000 ) - nt_ideal = 32; - - #endif - // Initialize a local runtime with global settings if necessary. Note // that in the case that a runtime is passed in, we make a local copy. rntm_t rntm_l; if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); } else { rntm_l = *rntm; } + // Setting the ideal number of threads if support is enabled + #if defined( BLIS_ENABLE_OPENMP ) + + #if defined( AOCL_DYNAMIC ) + aocl_znormfv_dynamic + ( + id, + n, + &nt_ideal + ); + #endif + + // Variable to acquire threads from runtime + dim_t nt; + nt = bli_rntm_num_threads( &rntm_l ); + + // nt is less than 1 if BLIS was configured with default settings for parallelism + nt = ( nt < 1 )? 1 : nt; + + if ( ( nt_ideal == -1 ) || ( nt_ideal > nt ) ) + nt_ideal = nt; + + #endif + /* Initialize mem pool buffer to NULL and size to 0 "buf" and "size" fields are assigned once memory @@ -545,16 +575,6 @@ void bli_znormfv_unb_var1 mem_t mem_buf_X = { 0 }; inc_t incx_buf = incx; - dim_t nt; - - nt = bli_rntm_num_threads( &rntm_l ); - - // nt is less than 1 if BLIS was configured with default settings for parallelism - nt = ( nt < 1 )? 1 : nt; - - // Altering the ideal thread count if it was not set or if it is greater than nt - if ( ( nt_ideal == -1 ) || ( nt_ideal > nt ) ) - nt_ideal = nt; // Packing for non-unit strided vector x. // In order to get the buffer from pool via rntm access to memory broker @@ -562,8 +582,7 @@ void bli_znormfv_unb_var1 bli_rntm_set_num_threads_only( 1, &rntm_l ); bli_pba_rntm_set_pba( &rntm_l ); - if ( incx == 0 ) nt_ideal = 1; - else if ( incx != 1 ) + if ( incx != 1 ) { // Calculate the size required for "n" double elements in vector x. size_t buffer_size = n * sizeof( dcomplex ); @@ -593,10 +612,14 @@ void bli_znormfv_unb_var1 } incx_buf = 1; } + // Resort to using single-threaded kernel call if packing fails, + // since we execute non-unit strided code section. + #ifdef BLIS_ENABLE_OPENMP else { nt_ideal = 1; } + #endif } #ifdef BLIS_ENABLE_OPENMP @@ -684,7 +707,7 @@ void bli_znormfv_unb_var1 // Obtain the job-size and region for compute dim_t job_per_thread, offset; - bli_normfv_thread_partition( n, n_threads, &offset, &job_per_thread, 2, incx_buf, thread_id ); + bli_normfv_thread_partition( n, n_threads, &offset, &job_per_thread, simd_factor, incx_buf, thread_id ); x_start = x_buf + offset; // Call to the kernel with the appropriate starting address @@ -893,22 +916,23 @@ void bli_snormfv_unb_var1 arch_t id = bli_arch_query_id(); switch ( id ) { + case BLIS_ARCH_ZEN5: case BLIS_ARCH_ZEN4: case BLIS_ARCH_ZEN3: case BLIS_ARCH_ZEN2: case BLIS_ARCH_ZEN:; #ifdef BLIS_KERNELS_ZEN - // Memory pool declarations for packing vector X. - // Initialize mem pool buffer to NULL and size to 0. - // "buf" and "size" fields are assigned once memory - // is allocated from the pool in bli_pba_acquire_m(). - // This will ensure bli_mem_is_alloc() will be passed on - // an allocated memory if created or a NULL. - mem_t mem_buf_X = { 0 }; - rntm_t rntm_l; - // Packing for non-unit strided vector x. - if ( incx != 1 ) + // Handling the kernel call in case of non-unit strides + if ( ( incx != 1 ) && ( incx != 0 ) ) { + // Memory pool declarations for packing vector X. + // Initialize mem pool buffer to NULL and size to 0. + // "buf" and "size" fields are assigned once memory + // is allocated from the pool in bli_pba_acquire_m(). + // This will ensure bli_mem_is_alloc() will be passed on + // an allocated memory if created or a NULL. + mem_t mem_buf_X = { 0 }; + rntm_t rntm_l; // Initialize a local runtime with global settings if necessary. Note // that in the case that a runtime is passed in, we make a local copy. if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); } @@ -947,18 +971,24 @@ void bli_snormfv_unb_var1 } incx_buf = 1; } - } - bli_snorm2fv_unb_var1_avx2( n, x_buf, incx_buf, norm, cntx ); + bli_snorm2fv_unb_var1_avx2( n, x_buf, incx_buf, norm, cntx ); - if ( bli_mem_is_alloc( &mem_buf_X ) ) + if ( bli_mem_is_alloc( &mem_buf_X ) ) + { + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_snorm2fv_unb_var1_avx2(): releasing mem pool block\n" ); + #endif + // Return the buffer to pool. + bli_pba_release( &rntm_l , &mem_buf_X ); + } + } + else { - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_snorm2fv_unb_var1_avx2(): releasing mem pool block\n" ); - #endif - // Return the buffer to pool. - bli_pba_release( &rntm_l , &mem_buf_X ); + // Call the kernel with the unit-strided vector x + bli_snorm2fv_unb_var1_avx2( n, x_buf, incx_buf, norm, cntx ); } + break; #endif default:; @@ -1022,17 +1052,39 @@ void bli_dnormfv_unb_var1 void ( *norm_fp )( dim_t, double*, inc_t, double*, cntx_t* ) = NULL; double *x_buf = x; + dim_t fast_path_thresh = 1; +#ifdef BLIS_ENABLE_OPENMP + dim_t simd_factor = 1; dim_t nt_ideal = -1; +#endif + arch_t id = bli_arch_query_id(); switch ( id ) { + case BLIS_ARCH_ZEN5: case BLIS_ARCH_ZEN4: +#if defined(BLIS_KERNELS_ZEN4) + + norm_fp = bli_dnorm2fv_unb_var1_avx512; + fast_path_thresh = 4500; + + #ifdef BLIS_ENABLE_OPENMP + simd_factor = 8; + #endif + + break; +#endif case BLIS_ARCH_ZEN3: case BLIS_ARCH_ZEN2: case BLIS_ARCH_ZEN: #ifdef BLIS_KERNELS_ZEN norm_fp = bli_dnorm2fv_unb_var1_avx2; + fast_path_thresh = 4000; + + #ifdef BLIS_ENABLE_OPENMP + simd_factor = 4; + #endif break; #endif @@ -1077,38 +1129,49 @@ void bli_dnormfv_unb_var1 return; /* - When the size is such that nt_ideal is 1, and packing is not - required( incx == 1 ), we can directly call the kernel to - avoid framework overheads( fast-path ). + Call the kernel directly in these two cases : + - When incx == 0, since the norm is based on only one dcomplex + element( two real double precision elements ) + - When the size is such that nt_ideal is 1, and packing is not + required( incx == 1 ), we can directly call the kernel to + avoid framework overheads( fast-path ). */ - else if ( ( incx == 1 ) && ( n < 4000 ) ) + else if ( ( incx == 0 ) || ( ( incx == 1 ) && ( n < fast_path_thresh ) ) ) { norm_fp( n, x, incx, norm, cntx ); return; } - // Setting the ideal number of threads if support is enabled - #if defined( BLIS_ENABLE_OPENMP ) && defined( AOCL_DYNAMIC ) - - if ( n < 4000 ) - nt_ideal = 1; - else if ( n < 17000 ) - nt_ideal = 4; - else if ( n < 136000 ) - nt_ideal = 8; - else if ( n < 365000 ) - nt_ideal = 16; - else if ( n < 2950000 ) - nt_ideal = 32; - - #endif - // Initialize a local runtime with global settings if necessary. Note // that in the case that a runtime is passed in, we make a local copy. rntm_t rntm_l; if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); } else { rntm_l = *rntm; } + // Setting the ideal number of threads if support is enabled + #if defined( BLIS_ENABLE_OPENMP ) + + #if defined( AOCL_DYNAMIC ) + aocl_dnormfv_dynamic + ( + id, + n, + &nt_ideal + ); + #endif + + // Variable to acquire threads from runtime + dim_t nt; + nt = bli_rntm_num_threads( &rntm_l ); + + // nt is less than 1 if BLIS was configured with default settings for parallelism + nt = ( nt < 1 )? 1 : nt; + + if ( ( nt_ideal == -1 ) || ( nt_ideal > nt ) ) + nt_ideal = nt; + + #endif + /* Initialize mem pool buffer to NULL and size to 0 "buf" and "size" fields are assigned once memory @@ -1119,15 +1182,6 @@ void bli_dnormfv_unb_var1 mem_t mem_buf_X = { 0 }; inc_t incx_buf = incx; - dim_t nt; - - nt = bli_rntm_num_threads( &rntm_l ); - - // nt is less than 1 if BLIS was configured with default settings for parallelism - nt = ( nt < 1 )? 1 : nt; - - if ( ( nt_ideal == -1 ) || ( nt_ideal > nt ) ) - nt_ideal = nt; // Packing for non-unit strided vector x. // In order to get the buffer from pool via rntm access to memory broker @@ -1135,8 +1189,7 @@ void bli_dnormfv_unb_var1 bli_rntm_set_num_threads_only( 1, &rntm_l ); bli_pba_rntm_set_pba( &rntm_l ); - if ( incx == 0 ) nt_ideal = 1; - else if ( incx != 1 ) + if ( incx != 1 ) { // Calculate the size required for "n" double elements in vector x. size_t buffer_size = n * sizeof( double ); @@ -1166,10 +1219,14 @@ void bli_dnormfv_unb_var1 } incx_buf = 1; } + // In case packing fails, we use the original buffer. We have to make sure that + // we reset the number of threads to 1 if we have enabled openmp for multithreading. + #ifdef BLIS_ENABLE_OPENMP else { nt_ideal = 1; } + #endif } #ifdef BLIS_ENABLE_OPENMP @@ -1257,7 +1314,7 @@ void bli_dnormfv_unb_var1 // Obtain the job-size and region for compute dim_t job_per_thread, offset; - bli_normfv_thread_partition( n, n_threads, &offset, &job_per_thread, 4, incx_buf, thread_id ); + bli_normfv_thread_partition( n, n_threads, &offset, &job_per_thread, simd_factor, incx_buf, thread_id ); x_start = x_buf + offset; diff --git a/gtestsuite/CMakeLists.txt b/gtestsuite/CMakeLists.txt index 78d8906a11..6b1339570a 100644 --- a/gtestsuite/CMakeLists.txt +++ b/gtestsuite/CMakeLists.txt @@ -1,21 +1,22 @@ #[=[ + BLIS An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -28,9 +29,10 @@ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + ]=] -cmake_minimum_required(VERSION 3.20.0) +cmake_minimum_required(VERSION 3.22.0) set(CMAKE_CXX_COMPILER ${CXX_COMPILER}) set(CMAKE_CXX_STANDARD 17) @@ -49,9 +51,10 @@ if(APPLE) endif() # Set the path to the BLIS installation. -set(BLIS_PATH "undefined" CACHE STRING "Setting the path to a BLIS installation that needs testing.") -if(BLIS_PATH STREQUAL "undefined") - message(FATAL_ERROR "Need to provide a BLIS installation path during CMake invocation. Please use \ +set(BLIS_PATH $ENV{AOCL_BLAS_PATH} CACHE STRING "Setting the path to a BLIS installation that needs testing.") +if(BLIS_PATH STREQUAL "") + message(FATAL_ERROR "Need to provide a BLIS installation path during CMake invocation.\ + Set environment variable \$AOCL_BLAS_PATH or set the cmake variable directly using\ $ cmake .. -DBLIS_PATH=/home/username/blis_installation") endif() @@ -60,6 +63,11 @@ endif() set(BLIS_INCLUDE ${BLIS_PATH}/include/ ${BLIS_PATH}/include/blis CACHE STRING "Setting the path to the BLIS headers.") set(BLIS_LIB_PATH ${BLIS_PATH}/lib CACHE STRING "Setting the path to the BLIS library.") +# Use REF_BLAS to set the library that will be used for reference results. +set(REF_CBLAS "Netlib" CACHE STRING "Library used to compute reference results.") +# Use REF_LIB to set the library that will be used for reference results. +set(REF_LIB $ENV{CBLAS_REF_LIB} CACHE STRING "Path to a shared library that will be used as a reference.") + # Set OpenMP as the default option set(ENABLE_THREADING "openmp" CACHE STRING "the threading flag") # Set the possible values of theading libraries for cmake-gui @@ -108,8 +116,8 @@ endif() # Set common libraries. if(LINUX) set(COMMON_LIBS pthread m dl) - option(ENABLE_ASAN "Run tests using Address Sanatizer" OFF) - option(ENABLE_COVERAGE "Run tests for Code Coderage" OFF) + option(ENABLE_ASAN "Run tests using Address Sanitizer" OFF) + option(ENABLE_COVERAGE "Run tests for Code Coverage" OFF) endif() # Use INT_SIZE to set the int type used for testing. @@ -124,10 +132,10 @@ endif() # Use TEST_INTERFACE to set which interface, supported by BLIS is meant to be tested. set(TEST_INTERFACE "BLAS" CACHE STRING "Interface of BLIS that is being tested.") # Set the possible values of interfaces for cmake-gui -set_property(CACHE TEST_INTERFACE PROPERTY STRINGS "BLAS" "CBLAS" "BLIS_TYPED") -if( NOT ((TEST_INTERFACE STREQUAL "BLAS") OR (TEST_INTERFACE STREQUAL "CBLAS") OR (TEST_INTERFACE STREQUAL "BLIS_TYPED")) ) +set_property(CACHE TEST_INTERFACE PROPERTY STRINGS "BLAS" "BLAS_BLIS_IMPL" "CBLAS" "BLIS_TYPED") +if( NOT ((TEST_INTERFACE STREQUAL "BLAS") OR (TEST_INTERFACE STREQUAL "BLAS_BLIS_IMPL") OR (TEST_INTERFACE STREQUAL "CBLAS") OR (TEST_INTERFACE STREQUAL "BLIS_TYPED")) ) message(FATAL_ERROR "TEST_INTERFACE option ${TEST_INTERFACE} is not supported. Please use on of the following options \ - during CMake invokation: BLAS, CBLAS, BLIS_TYPED") + during CMake invokation: BLAS, BLAS_BLIS_IMPL, CBLAS, BLIS_TYPED") endif() # Use BLIS_ELEMENT_TYPE to set whether the elements of any matrix/vector tested are integers or floating point values. @@ -139,95 +147,102 @@ if( NOT ((BLIS_ELEMENT_TYPE STREQUAL "f") OR (BLIS_ELEMENT_TYPE STREQUAL "i")) ) during CMake invokation: f, i") endif() -if(LINUX) - if(REF_LIB) - get_filename_component(REFLIB_PATH ${REF_LIB}/.. ABSOLUTE) - get_filename_component(library ${REF_LIB} NAME) - find_library(reflib NAMES ${library} PATHS ${REFLIB_PATH} NO_DEFAULT_PATH) - if(${reflib} STREQUAL reflib-NOTFOUND) - message(FATAL_ERROR "Reference Library not found : " ${REF_LIB}) - else() - message(STATUS "Found Reference Library : " ${reflib}) - endif() +# Option to enable testing with upper case character arguments in BLAS and BLIS calls. +option(TEST_UPPERCASE_ARGS "Test upper case character arguments" OFF) + +# Option to enable testing with thresholds set to zero. +option(THRESHOLD_ZERO "Set thresholds to zero" OFF) + +# Can we test the value of info stored within BLIS and returned by a call to +# bli_info_get_info_value (introduced at AMD BLAS 4.2). +option(CAN_TEST_INFO_VALUE "Can test value of info" ON) + +# Use EXT_VAL to get the extreme value (NaN or Inf) used for testing data that shouldn't be read. +set(EXT_VAL "Inf" CACHE STRING "Extreme value (NaN or Inf) used for testing data that shouldn't be read") +# Set the possible values of reference CBLAS for cmake-gui +set_property(CACHE EXT_VAL PROPERTY STRINGS "NaN" "Inf") +if( NOT ((EXT_VAL STREQUAL "NaN") OR (EXT_VAL STREQUAL "Inf")) ) + message(FATAL_ERROR "EXT_VAL option '${EXT_VAL}' is not supported. Please use one of the following options \ + during CMake invokation: NaN, Inf") +endif() + +# Option to enable testing of input arguments to BLAS APIs. +# Note: This imposes a significant runtime overhead. +option(TEST_INPUT_ARGS "Test input arguments" OFF) + +if(REF_LIB) + get_filename_component(REFLIB_PATH ${REF_LIB}/.. ABSOLUTE) + get_filename_component(library ${REF_LIB} NAME) + find_library(reflib NAMES ${library} PATHS ${REFLIB_PATH} NO_DEFAULT_PATH) + if(${reflib} STREQUAL reflib-NOTFOUND) + message(FATAL_ERROR "Reference Library not found : " ${REF_LIB}) else() - # Use REF_BLAS to set the library that will be used for reference results. - set(REF_CBLAS CACHE STRING "Library used to compute reference results.") - # Set the possible values of theading libraries for cmake-gui - set_property(CACHE REF_CBLAS PROPERTY STRINGS "OpenBLAS" "Netlib" "MKL") - if(NOT ((REF_CBLAS STREQUAL "OpenBLAS") OR (REF_CBLAS STREQUAL "Netlib") OR(REF_CBLAS STREQUAL "MKL"))) - message(FATAL_ERROR "REF_CBLAS option '${REF_CBLAS}' is not supported. Please, use one of the following options \ - during CMake invokation: OpenBLAS, Netlib, MKL or modify CMakeLists.txt to include this option.") - endif() - if(REF_CBLAS STREQUAL "OpenBLAS") - if(NOT(OPENBLAS_PATH)) - message(FATAL_ERROR "Need to provide an OpenBLAS installation path \ - during CMake invokation when OpenBLAS is used for reference results. Please use \ - $ cmake .. -DOPENBLAS_PATH=/home/username/openblas_installation") - endif() - find_library(reflib NAMES openblas PATHS ${OPENBLAS_PATH} NO_DEFAULT_PATH) - if(${reflib} STREQUAL reflib-NOTFOUND) - message(FATAL_ERROR "OpenBLAS Reference Library not found : " ${OPENBLAS_PATH}) - else() - message(STATUS "Found OpenBLAS Reference Library : " ${reflib}) - endif() - set(REF_LIB ${reflib}) - elseif(REF_CBLAS STREQUAL "Netlib") - if(NOT(NETLIB_PATH)) - message(FATAL_ERROR "Need to provide a Netlib installation path \ - during CMake invokation when Netlib is used for reference results. Please use \ - $ cmake .. -DNETLIB_PATH=/home/username/netlib_installation") - endif() - if(INT_SIZE STREQUAL "32") - find_library(netlib NAMES cblas PATHS ${NETLIB_PATH} NO_DEFAULT_PATH) - else() - find_library(netlib NAMES cblas64 PATHS ${NETLIB_PATH} NO_DEFAULT_PATH) - endif() - if(${netlib} STREQUAL netlib-NOTFOUND) - message(FATAL_ERROR "Netlib Reference Library not found : " ${NETLIB_PATH}) - else() - message(STATUS "Found Netlib Reference Library : " ${netlib}) - endif() - set(REF_LIB ${netlib}) - elseif(REF_CBLAS STREQUAL "MKL") - set(MKL_PATH $ENV{MKLROOT}/lib/intel64 - CACHE STRING "The path to MKL.") - find_library(mkllib NAMES mkl_rt PATHS ${MKL_PATH} NO_DEFAULT_PATH) - if(${mkllib} STREQUAL mkllib-NOTFOUND) - message(FATAL_ERROR "MKL Reference Library not found : " ${MKL_PATH}) - else() - message(STATUS "Found MKL Reference Library : " ${mkllib}) - endif() - set(REF_LIB ${mkllib}) - else() - message(FATAL_ERROR "Need to set up a reference library. Please use on of the following options \ - during CMake invokation: -DREF_CBLAS=Netlib or -DREF_CBLAS=OpenBLAS or -DREF_CBLAS=MKL") - endif() + message(STATUS "Found Reference Library : " ${reflib}) endif() -else() #WIN32 - # Use REF_BLAS to set the library that will be used for reference results. - set(REF_CBLAS CACHE STRING "Library used to compute reference results.") +else() # Set the possible values of theading libraries for cmake-gui - set_property(CACHE REF_CBLAS PROPERTY STRINGS "OpenBLAS" "MKL") - if(NOT ((REF_CBLAS STREQUAL "OpenBLAS") OR (REF_CBLAS STREQUAL "MKL"))) + set_property(CACHE REF_CBLAS PROPERTY STRINGS "OpenBLAS" "Netlib" "MKL") + if(NOT ((REF_CBLAS STREQUAL "OpenBLAS") OR (REF_CBLAS STREQUAL "Netlib") OR(REF_CBLAS STREQUAL "MKL"))) message(FATAL_ERROR "REF_CBLAS option '${REF_CBLAS}' is not supported. Please, use one of the following options \ - during CMake invokation: OpenBLAS, MKL or modify CMakeLists.txt to include this option.") + during CMake invokation: OpenBLAS, Netlib, MKL or modify CMakeLists.txt to include this option.") endif() + + if(LINUX) + set(CMAKE_FIND_LIBRARY_PREFIXES "lib") + set(CMAKE_FIND_LIBRARY_SUFFIXES ".so") + set(LIBOpenBLAS openblas) + set(LIBCLAS cblas) + set(LIBCLAS64 cblas64) + set(LIBMKL mkl_rt) + else() + set(CMAKE_FIND_LIBRARY_PREFIXES "") + set(CMAKE_FIND_LIBRARY_SUFFIXES ".dll") + set(LIBOpenBLAS libopenblas) + set(LIBMKL mkl_rt.2) + endif() + if(REF_CBLAS STREQUAL "OpenBLAS") if(NOT(OPENBLAS_PATH)) message(FATAL_ERROR "Need to provide an OpenBLAS installation path \ - during CMake invokation when OpenBLAS is used for reference results. Please use \ - $ cmake .. -DOPENBLAS_PATH=/home/username/openblas_installation") + during CMake invokation when OpenBLAS is used for reference results. Please use \ + $ cmake .. -DOPENBLAS_PATH=/home/username/openblas_installation") + endif() + find_library(reflib NAMES ${LIBOpenBLAS} PATHS ${OPENBLAS_PATH} NO_DEFAULT_PATH) + if(${reflib} STREQUAL reflib-NOTFOUND) + message(FATAL_ERROR "OpenBLAS Reference Library not found : " ${OPENBLAS_PATH}) + else() + message(STATUS "Found OpenBLAS Reference Library : " ${reflib}) + endif() + set(REF_LIB ${reflib}) + elseif(REF_CBLAS STREQUAL "Netlib") + if(NOT(NETLIB_PATH)) + message(FATAL_ERROR "Need to provide a Netlib installation path \ + during CMake invokation when Netlib is used for reference results. Please use \ + $ cmake .. -DNETLIB_PATH=/home/username/netlib_installation") + endif() + if(INT_SIZE STREQUAL "32") + find_library(netlib NAMES ${LIBCLAS} PATHS ${NETLIB_PATH} NO_DEFAULT_PATH) + else() + find_library(netlib NAMES ${LIBCLAS64} PATHS ${NETLIB_PATH} NO_DEFAULT_PATH) endif() - set(REF_LIB "${OPENBLAS_PATH}/libopenblas.dll" CACHE STRING "Reference OpenBLAS Library") - message(STATUS "Found OpenBLAS Reference Library : " ${REF_LIB}) + if(${netlib} STREQUAL netlib-NOTFOUND) + message(FATAL_ERROR "Netlib Reference Library not found : " ${NETLIB_PATH}) + else() + message(STATUS "Found Netlib Reference Library : " ${netlib}) + endif() + set(REF_LIB ${netlib}) elseif(REF_CBLAS STREQUAL "MKL") - if(NOT(MKL_PATH)) - message(FATAL_ERROR "Need to provide an MKL_PATH installation path \ - during CMake invokation when MKL] is used for reference results. Please use \ - $ cmake .. -DMKL_PATH=/home/username/path_to_mkl_rt") + set(MKL_PATH $ENV{MKLROOT}/lib/intel64 CACHE STRING "The path to MKL.") + find_library(mkllib NAMES ${LIBMKL} PATHS ${MKL_PATH} NO_DEFAULT_PATH) + if(${mkllib} STREQUAL mkllib-NOTFOUND) + message(FATAL_ERROR "MKL Reference Library not found : " ${MKL_PATH}) + else() + message(STATUS "Found MKL Reference Library : " ${mkllib}) endif() - set(REF_LIB "${MKL_PATH}/mkl_rt.2.dll" CACHE STRING "Reference MKL Library") - message(STATUS "Found MKL Reference Library : " ${REF_LIB}) + set(REF_LIB ${mkllib}) + else() + message(FATAL_ERROR "Need to set up a reference library. Please use on of the following options \ + during CMake invokation: -DREF_CBLAS=Netlib or -DREF_CBLAS=OpenBLAS or -DREF_CBLAS=MKL") endif() endif() @@ -273,12 +288,13 @@ if(LINUX) add_compile_options(-g -Wall -Wno-unused-function -Wfatal-errors -fPIC ) if(ENABLE_ASAN) - add_compile_options(-fsanitize=address) - add_definitions(-DENABLE_ASAN) + set(ASAN_FLAGS "-fsanitize=address") + list(APPEND CMAKE_C_FLAGS ${ASAN_FLAGS}) endif() if(ENABLE_COVERAGE) - set(CMAKE_CXX_FLAGS "-O0 --coverage") + set(COVERAGE_FLAGS "-O0 --coverage") + list(APPEND CMAKE_C_FLAGS ${COVERAGE_FLAGS}) endif() endif() @@ -292,6 +308,47 @@ if(WIN32) endif() endif() +# The following part will be used to set up a list of defines that dictate +# which kernel tests can be build and run on the current architecture. +# Given that the symbols of kernel functions are not exported for shared libraries +# we only set up those defines for static libs. +# This way, kernel tests won't be compiled/run for shared versions of BLIS. +if(BLIS_LINKING_TYPE STREQUAL "static") + if(ENABLE_THREADING STREQUAL "openmp") + try_run(RUNRESULT COMPILERESULT "${CMAKE_BINARY_DIR}/temp" SOURCES ${CMAKE_SOURCE_DIR}/cmake/config_ukr_tests.cpp + COMPILE_DEFINITIONS -I${BLIS_PATH}/include/ -I${BLIS_PATH}/include/blis + LINK_LIBRARIES ${BLIS_LIBRARY} ${COMMON_LIBS} OpenMP::OpenMP_CXX ${ASAN_FLAGS} ${COVERAGE_FLAGS} + RUN_OUTPUT_VARIABLE UKR_CONFIG + COMPILE_OUTPUT_VARIABLE COMP_VAR + ) + else() + try_run(RUNRESULT COMPILERESULT "${CMAKE_BINARY_DIR}/temp" SOURCES ${CMAKE_SOURCE_DIR}/cmake/config_ukr_tests.cpp + COMPILE_DEFINITIONS -I${BLIS_PATH}/include/ -I${BLIS_PATH}/include/blis + LINK_LIBRARIES ${BLIS_LIBRARY} ${COMMON_LIBS} ${ASAN_FLAGS} ${COVERAGE_FLAGS} + RUN_OUTPUT_VARIABLE UKR_CONFIG + COMPILE_OUTPUT_VARIABLE COMP_VAR + ) + endif() + # Uncomment this to debug this snippet above, if necessary. + if(NOT COMPILERESULT) + message(FATAL_ERROR "Compiling config_ukr_tests.cpp failed with the following error ${COMP_VAR}.") + endif() + # Remove all empty items from the list. + list(REMOVE_ITEM UKR_CONFIG "") + # We iterate through the list returned from the snippet above. + # For example, UKR_CONFIG = AVX2FMA3 for zen3 + # or UKR_CONFIG = AVX2FMA3;AVX512;AVX512VNNI;AVX512BF16 for zen4 + # Depending on the values of this list we define corresponding macros + # -DGTEST_AVX2FMA3 on zen3 + # or -DGTEST_AVX2FMA3;-DGTEST_AVX512;-DGTEST_AVX512VNNI;-DGTEST_AVX512BF16 on zen4 + # Those macros are passed when compiling the tests in testsuite/CMakeLists.txt. + foreach(ukrconf ${UKR_CONFIG}) + list(APPEND UKR_DEFINES "-DGTEST_${ukrconf}") + endforeach() + message(STATUS "Since BLIS GTestSuite is used to check the static version of blis, all kernel tests are enabled.") +else() + message(WARNING "Since BLIS GTestSuite is used to check the shared version of blis, all kernel tests are disabled.") +endif() add_subdirectory(testinghelpers) add_subdirectory(testsuite) diff --git a/gtestsuite/CMakePresets.json b/gtestsuite/CMakePresets.json new file mode 100644 index 0000000000..e1ebfe8495 --- /dev/null +++ b/gtestsuite/CMakePresets.json @@ -0,0 +1,79 @@ +{ + "version": 6, + "include": [ + "cmake/presets/base.json", + "cmake/presets/linux-make.json", + "cmake/presets/linux-ninja.json", + "cmake/presets/win-msvc.json", + "cmake/presets/win-ninja.json" + ], + "configurePresets": [ + { + "name": "linux-base", + "hidden": true + }, + { + "name": "linux-st-lp64-auto-shared", + "description": "Configure for serial LP64 BLIS with on Linux", + "inherits": ["linux-base", "st", "lp64"], + "hidden": false, + "binaryDir": "${sourceDir}/build-${presetName}", + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-linux-st-lp64-auto", + "BLIS_LINKING_TYPE": "shared", + "REF_CBLAS": "MKL" + } + } + ], + "buildPresets": [ + { + "name": "linux-st-lp64-auto-shared", + "description": "Build GTestSuite using serial LP64 BLIS on Linux", + "configurePreset": "linux-st-lp64-auto-shared", + "jobs": 0 + } + ], + "testPresets":[ + { + "name":"testall", + "description": "Run all tests", + "configurePreset": "linux-st-lp64-auto-shared", + "output": {"outputOnFailure": false} + }, + { + "name":"level3", + "description": "Run level3 tests only", + "configurePreset": "linux-st-lp64-auto-shared", + "output": {"outputOnFailure": false}, + "filter": { + "include": { + "name": "level3" + }, + "exclude": { + "name":"gemm|trsm" + } + } + } + ], + + "workflowPresets": [ + { + "name": "linux-st-lp64-auto-shared-check", + "description": "Build and check single-threaded shared BLIS for auto configuration on Linux", + "steps": [ + { + "type": "configure", + "name": "linux-st-lp64-auto-shared" + }, + { + "type": "build", + "name": "linux-st-lp64-auto-shared" + }, + { + "type": "test", + "name": "level3" + } + ] + } + ] +} diff --git a/gtestsuite/README.md b/gtestsuite/README.md index b5d801e56f..801ae01a37 100644 --- a/gtestsuite/README.md +++ b/gtestsuite/README.md @@ -86,7 +86,13 @@ For threaded MKL the following OpenMP runtimes are used: * For testing a 64-bit integer BLIS library, use `-DINT_SIZE=64`. ## Address Sanitizer (Linux Only) * To build using address sanitizer, configure using `-DENABLE_ASAN=ON`. [**OFF by default**] -* An installation to BLIS which was build with ASAN flags[CFLAGS="-O0 -g -fsanitize=address"] needs to be provided. +* An installation to BLIS which was build with ASAN flags needs to be provided. +* Set -DENABLE_ASAN=ON when building BLIS with CMake, or set CFLAGS="-O0 -g -fsanitize=address" when building with make. +* By default redzone size is 16 bytes and can redzone size can be increase to 2048 bytes. +```console +$ ASAN_OPTIONS=redzone=2048 +``` + ## Code Coverage (Only GCC Compiler) * BLIS : Configure BLIS Library with code coverage flags[CFLAGS="-O0 -fprofile-arcs -ftest-coverage"], compile and install. * Gtestsuite : To build for code coverage, configure cmake with `-DENABLE_COVERAGE=ON`. [**OFF by default**] and then compile and run the executable. @@ -101,7 +107,39 @@ For threaded MKL the following OpenMP runtimes are used: ## BLIS Library Interface to be Tested * To build the testsuite using BLAS interface, configure using `-DTEST_INTERFACE=BLAS`. [**Default**] * To build the testsuite using CBLAS interface, configure using `-DTEST_INTERFACE=CBLAS`. +* To build the testsuite using BLAS_BLIS_IMPL wrapper layer (called underneath BLAS and CBLAS interfaces), configure using `-DTEST_INTERFACE=BLAS_BLIS_IMPL`. * To build the testsuite using BLIS-typed interface, configure using `-DTEST_INTERFACE=BLIS_TYPED`. Note that more tests are built for this option, due to the extended APIs. +## Test with upper case character arguments +* To test with upper case character arguments, configure using `-DTEST_UPPERCASE_ARGS=ON`. [**OFF by default**] +## Test with threshold set to zero +* To enable testing with the threshold set to zero, configure using `-DTHRESHOLD_ZERO=ON`. [**OFF by default**] +## Type of Data Generated in Testing +* To generate floating-point numbers in the matrices and vectors that are used in testing, configure using `-DBLIS_ELEMENT_TYPE=f`. [**Default**] +* To generate integers in the matrices and vectors that are used in testing, configure using `-DBLIS_ELEMENT_TYPE=i`. This can be useful for debugging since operating on integers should compute exact results. Note that "integer" here doesn't refer to `int` type, but on the mathematical set Z. +## Extreme value used for testing data that shouldn't be read. +* To test with Inf, configure using `-DEXT_VAL=Inf`. [**Default**] +* To test with NaN, configure using `-DEXT_VAL=NaN`. + +This option is used to set a static constant variable `GenericET` of type `testinghelpers::datagenerators::ElementType` which is in turned used as the default argument in data generator functions such as `get_random_vector`, `get_random_matrix`, etc. To find a full list of APIs that can be used to generate random data we refer to `blis/gtestsuite/testinghelpers/inc/common/data_generators.h`. +### Specifying Types of Data Independent of BLIS_ELEMENT_TYPE +* To generate a vector x with random values in [-10, 10], depending on `BLIS_ELEMENT_TYPE` use +```cpp +std::vector x = testinghelpers::get_random_vector( -10, 10, n, incx ); +``` +* To generate a vector x with floating-point values in [-10, 10], independent of `BLIS_ELEMENT_TYPE` use +```cpp +std::vector x = testinghelpers::get_random_vector( -10, 10, n, incx, testinghelpers::datagenerators::ElementType::FP ); +``` +* To generate a vector x with integer values in [-10, 10], independent of `BLIS_ELEMENT_TYPE` use +```cpp +std::vector x = testinghelpers::get_random_vector( -10, 10, n, incx, testinghelpers::datagenerators::ElementType::INT ); +``` +## Testing value of INFO set within BLIS. This is not returned by BLAS or CBLAS APIs, but AMD BLAS 4.2 and later includes a function bli_info_get_info_value to return this value. +* If using an older version of BLIS, configure using `-DCAN_TEST_INFO_VALUE=OFF`. [**ON by default**] + +## Test BLAS input arguments +* To check input arguments have not been changed by the BLAS routines, configure using `-DTEST_INPUT_ARGS=ON`. [**OFF by default**] +* Note: this will substantially increase the runtime of the tests. # Building the Tests After the successful configuration of CMake, we can build the tests. The following steps are taken by the building process: @@ -162,9 +200,9 @@ You can also find more details in [CMake Documentation](https://cmake.org/cmake/ ## Using the Executables As we mentioned earlier, all cpp files of each API directory are compiled into one executable. This executable can be run separately which can be very useful while developing or debugging. When MKL is used as a reference, the following environment variables need to be set before calling the executables, depending on the configuration. -* MKL_INTERFACE_LAYER=LP64 or MKL_INTERFACE_LAYER=ILP64 depending on whether 32 or 64 bit integers are used, respectivelly. +* MKL_INTERFACE_LAYER=LP64 or MKL_INTERFACE_LAYER=ILP64 depending on whether 32 or 64 bit integers are used, respectively. * MKL_THREADING_LAYER=SEQUENTIAL for sequential MKL. -* MKL_THREADING_LAYER=INTEL or MKL_THREADING_LAYER=GNU depending on whether we execute on Windows or on Linux, respectivelly. +* MKL_THREADING_LAYER=INTEL or MKL_THREADING_LAYER=GNU depending on whether we execute on Windows or on Linux, respectively. ### To run all addv tests use: ```console @@ -174,6 +212,13 @@ $ ./testsuite.level1.addv ```console $ ./testuite.util.nrm2 --gtest_filter="*snrm2*" ``` +Alternatively, use the GTEST_FILTER environment variable. This is particularly useful for +passing gtest filter options to executables run via ctest, e.g.: +```console +$ GTEST_FILTER="*snrm2*" ./testuite.util.nrm2 +$ GTEST_FILTER=-"EVT" ctest -R level2 +``` + ## Running tests using Valgrind We can run any executable using valgrind as usual. For example, use the following command ```console diff --git a/gtestsuite/cmake/config_ukr_tests.cpp b/gtestsuite/cmake/config_ukr_tests.cpp new file mode 100644 index 0000000000..ced6eccad5 --- /dev/null +++ b/gtestsuite/cmake/config_ukr_tests.cpp @@ -0,0 +1,53 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ +#include "blis.h" +#include + +/** + * Small program that uses blis library to check if specific instructions + * are supported. This is compiled and run during CMake configuration and + * the output is used to define macros that are used for kernel testing. + * We MUST use ";" to create a list in CMake so make sure to add them in + * the future if more instructions are added. + * + * Note that this is only available on static blis since those symbols aren't + * exported for shared libraries. +*/ +int main() +{ + if(bli_cpuid_is_avx2fma3_supported()) std::cout<<"AVX2FMA3;"; + if(bli_cpuid_is_avx512_supported()) std::cout<<"AVX512;"; + if(bli_cpuid_is_avx512vnni_supported()) std::cout<<"AVX512VNNI;"; + if(bli_cpuid_is_avx512bf16_supported()) std::cout<<"AVX512BF16"; +} diff --git a/gtestsuite/cmake/presets/base.json b/gtestsuite/cmake/presets/base.json new file mode 100644 index 0000000000..0d3f651125 --- /dev/null +++ b/gtestsuite/cmake/presets/base.json @@ -0,0 +1,67 @@ +{ + "version": 6, + "configurePresets": [ + { + "name": "lp64", + "hidden": true, + "cacheVariables": { + "INT_SIZE": "32" + } + }, + { + "name": "ilp64", + "hidden": true, + "cacheVariables": { + "INT_SIZE": "64" + } + }, + { + "name": "st", + "hidden": true, + "cacheVariables": { + "ENABLE_THREADING": "no" + } + }, + { + "name": "mt", + "hidden": true, + "cacheVariables": { + "ENABLE_THREADING": "openmp" + } + }, + { + "name": "amdzen", + "hidden": true + }, + { + "name": "auto", + "hidden": true + }, + { + "name": "static", + "hidden": true, + "cacheVariables": { + "BLIS_LINKING_TYPE": "static" + } + }, + { + "name": "shared", + "hidden": true, + "cacheVariables": { + "BLIS_LINKING_TYPE": "shared" + } + }, + { + "name": "base", + "hidden": true, + "binaryDir": "${sourceDir}/build-${presetName}" + } + ], + "buildPresets": [ + { + "name": "base", + "configurePreset": "base", + "jobs": 0 + } + ] +} diff --git a/gtestsuite/cmake/presets/linux-make.json b/gtestsuite/cmake/presets/linux-make.json new file mode 100644 index 0000000000..4783a38395 --- /dev/null +++ b/gtestsuite/cmake/presets/linux-make.json @@ -0,0 +1,261 @@ +{ + "version": 6, + "include": [ + "base.json" + ], + "configurePresets": [ + { + "name": "linux-make", + "inherits": "base", + "hidden": true, + "generator": "Unix Makefiles", + "condition": { + "type": "notEquals", + "lhs": "${hostSystemName}", + "rhs": "Windows" + } + }, + { + "name": "linux-make-st-lp64-amdzen-static", + "inherits": ["linux-make", "st", "lp64", "amdzen", "static"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-linux-lp64-amdzen", + "BLIS_LINKING_TYPE": "static", + "REF_CBLAS": "MKL" + } + }, + { + "name": "linux-make-st-lp64-amdzen-shared", + "inherits": ["linux-make", "st", "lp64", "amdzen", "shared"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-linux-lp64-amdzen", + "BLIS_LINKING_TYPE": "shared", + "REF_CBLAS": "MKL" + } + }, + { + "name": "linux-make-mt-lp64-amdzen-static", + "inherits": ["linux-make", "mt", "lp64", "amdzen", "static"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-linux-lp64-amdzen", + "BLIS_LINKING_TYPE": "static", + "REF_CBLAS": "MKL" + } + }, + { + "name": "linux-make-mt-lp64-amdzen-shared", + "inherits": ["linux-make", "mt", "lp64", "amdzen", "shared"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-linux-lp64-amdzen", + "BLIS_LINKING_TYPE": "shared", + "REF_CBLAS": "MKL" + } + }, + { + "name": "linux-make-st-ilp64-amdzen-static", + "inherits": ["linux-make", "st", "ilp64", "amdzen", "static"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-linux-ilp64-amdzen", + "BLIS_LINKING_TYPE": "static", + "REF_CBLAS": "MKL" + } + }, + { + "name": "linux-make-st-ilp64-amdzen-shared", + "inherits": ["linux-make", "st", "ilp64", "amdzen", "shared"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-linux-ilp64-amdzen", + "BLIS_LINKING_TYPE": "shared", + "REF_CBLAS": "MKL" + } + }, + { + "name": "linux-make-mt-ilp64-amdzen-static", + "inherits": ["linux-make", "mt", "ilp64", "amdzen", "static"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-linux-ilp64-amdzen", + "BLIS_LINKING_TYPE": "static", + "REF_CBLAS": "MKL" + } + }, + { + "name": "linux-make-mt-ilp64-amdzen-shared", + "inherits": ["linux-make", "mt", "ilp64", "amdzen", "shared"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-linux-ilp64-amdzen", + "BLIS_LINKING_TYPE": "shared", + "REF_CBLAS": "MKL" + } + }, + { + "name": "linux-make-st-lp64-auto-static", + "inherits": ["linux-make", "st", "lp64", "auto", "static"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-linux-lp64-auto", + "BLIS_LINKING_TYPE": "static", + "REF_CBLAS": "MKL" + } + }, + { + "name": "linux-make-st-lp64-auto-shared", + "inherits": ["linux-make", "st", "lp64", "auto", "shared"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-linux-lp64-auto", + "BLIS_LINKING_TYPE": "shared", + "REF_CBLAS": "MKL" + } + }, + { + "name": "linux-make-mt-lp64-auto-static", + "inherits": ["linux-make", "mt", "lp64", "auto", "static"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-linux-lp64-auto", + "BLIS_LINKING_TYPE": "static", + "REF_CBLAS": "MKL" + } + }, + { + "name": "linux-make-mt-lp64-auto-shared", + "inherits": ["linux-make", "mt", "lp64", "auto", "shared"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-linux-lp64-auto", + "BLIS_LINKING_TYPE": "shared", + "REF_CBLAS": "MKL" + } + }, + { + "name": "linux-make-st-ilp64-auto-static", + "inherits": ["linux-make", "st", "ilp64", "auto", "static"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-linux-ilp64-auto", + "BLIS_LINKING_TYPE": "static", + "REF_CBLAS": "MKL" + } + }, + { + "name": "linux-make-st-ilp64-auto-shared", + "inherits": ["linux-make", "st", "ilp64", "auto", "shared"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-linux-ilp64-auto", + "BLIS_LINKING_TYPE": "shared", + "REF_CBLAS": "MKL" + } + }, + { + "name": "linux-make-mt-ilp64-auto-static", + "inherits": ["linux-make", "mt", "ilp64", "auto", "static"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-linux-ilp64-auto", + "BLIS_LINKING_TYPE": "static", + "REF_CBLAS": "MKL" + } + }, + { + "name": "linux-make-mt-ilp64-auto-shared", + "inherits": ["linux-make", "mt", "ilp64", "auto", "shared"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-linux-ilp64-auto", + "BLIS_LINKING_TYPE": "shared", + "REF_CBLAS": "MKL" + } + } + ], + "buildPresets": [ + { + "name": "linux-make-st-lp64-amdzen-static", + "configurePreset": "linux-make-st-lp64-amdzen-static", + "inherits": "base" + }, + { + "name": "linux-make-st-lp64-amdzen-shared", + "configurePreset": "linux-make-st-lp64-amdzen-shared", + "inherits": "base" + }, + { + "name": "linux-make-mt-lp64-amdzen-static", + "configurePreset": "linux-make-mt-lp64-amdzen-static", + "inherits": "base" + }, + { + "name": "linux-make-mt-lp64-amdzen-shared", + "configurePreset": "linux-make-mt-lp64-amdzen-shared", + "inherits": "base" + }, + { + "name": "linux-make-st-ilp64-amdzen-static", + "configurePreset": "linux-make-st-ilp64-amdzen-static", + "inherits": "base" + }, + { + "name": "linux-make-st-ilp64-amdzen-shared", + "configurePreset": "linux-make-st-ilp64-amdzen-shared", + "inherits": "base" + }, + { + "name": "linux-make-mt-ilp64-amdzen-static", + "configurePreset": "linux-make-mt-ilp64-amdzen-static", + "inherits": "base" + }, + { + "name": "linux-make-mt-ilp64-amdzen-shared", + "configurePreset": "linux-make-mt-ilp64-amdzen-shared", + "inherits": "base" + }, + { + "name": "linux-make-st-lp64-auto-static", + "configurePreset": "linux-make-st-lp64-auto-static", + "inherits": "base" + }, + { + "name": "linux-make-st-lp64-auto-shared", + "configurePreset": "linux-make-st-lp64-auto-shared", + "inherits": "base" + }, + { + "name": "linux-make-mt-lp64-auto-static", + "configurePreset": "linux-make-mt-lp64-auto-static", + "inherits": "base" + }, + { + "name": "linux-make-mt-lp64-auto-shared", + "configurePreset": "linux-make-mt-lp64-auto-shared", + "inherits": "base" + }, + { + "name": "linux-make-st-ilp64-auto-static", + "configurePreset": "linux-make-st-lp64-auto-static", + "inherits": "base" + }, + { + "name": "linux-make-st-ilp64-auto-shared", + "configurePreset": "linux-make-st-lp64-auto-shared", + "inherits": "base" + }, + { + "name": "linux-make-mt-ilp64-auto-static", + "configurePreset": "linux-make-mt-lp64-auto-static", + "inherits": "base" + }, + { + "name": "linux-make-mt-ilp64-auto-shared", + "configurePreset": "linux-make-mt-lp64-auto-shared", + "inherits": "base" + } + ] +} diff --git a/gtestsuite/cmake/presets/linux-ninja.json b/gtestsuite/cmake/presets/linux-ninja.json new file mode 100644 index 0000000000..c3d494decc --- /dev/null +++ b/gtestsuite/cmake/presets/linux-ninja.json @@ -0,0 +1,261 @@ +{ + "version": 6, + "include": [ + "base.json" + ], + "configurePresets": [ + { + "name": "linux-ninja", + "inherits": "base", + "hidden": true, + "generator": "Ninja", + "condition": { + "type": "notEquals", + "lhs": "${hostSystemName}", + "rhs": "Windows" + } + }, + { + "name": "linux-ninja-st-lp64-amdzen-static", + "inherits": ["linux-ninja", "st", "lp64", "amdzen", "static"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-linux-lp64-amdzen", + "BLIS_LINKING_TYPE": "static", + "REF_CBLAS": "MKL" + } + }, + { + "name": "linux-ninja-st-lp64-amdzen-shared", + "inherits": ["linux-ninja", "st", "lp64", "amdzen", "shared"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-linux-lp64-amdzen", + "BLIS_LINKING_TYPE": "shared", + "REF_CBLAS": "MKL" + } + }, + { + "name": "linux-ninja-mt-lp64-amdzen-static", + "inherits": ["linux-ninja", "mt", "lp64", "amdzen", "static"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-linux-lp64-amdzen", + "BLIS_LINKING_TYPE": "static", + "REF_CBLAS": "MKL" + } + }, + { + "name": "linux-ninja-mt-lp64-amdzen-shared", + "inherits": ["linux-ninja", "mt", "lp64", "amdzen", "shared"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-linux-lp64-amdzen", + "BLIS_LINKING_TYPE": "shared", + "REF_CBLAS": "MKL" + } + }, + { + "name": "linux-ninja-st-ilp64-amdzen-static", + "inherits": ["linux-ninja", "st", "ilp64", "amdzen", "static"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-linux-ilp64-amdzen", + "BLIS_LINKING_TYPE": "static", + "REF_CBLAS": "MKL" + } + }, + { + "name": "linux-ninja-st-ilp64-amdzen-shared", + "inherits": ["linux-ninja", "st", "ilp64", "amdzen", "shared"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-linux-ilp64-amdzen", + "BLIS_LINKING_TYPE": "shared", + "REF_CBLAS": "MKL" + } + }, + { + "name": "linux-ninja-mt-ilp64-amdzen-static", + "inherits": ["linux-ninja", "mt", "ilp64", "amdzen", "static"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-linux-ilp64-amdzen", + "BLIS_LINKING_TYPE": "static", + "REF_CBLAS": "MKL" + } + }, + { + "name": "linux-ninja-mt-ilp64-amdzen-shared", + "inherits": ["linux-ninja", "mt", "ilp64", "amdzen", "shared"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-linux-ilp64-amdzen", + "BLIS_LINKING_TYPE": "shared", + "REF_CBLAS": "MKL" + } + }, + { + "name": "linux-ninja-st-lp64-auto-static", + "inherits": ["linux-ninja", "st", "lp64", "auto", "static"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-linux-lp64-auto", + "BLIS_LINKING_TYPE": "static", + "REF_CBLAS": "MKL" + } + }, + { + "name": "linux-ninja-st-lp64-auto-shared", + "inherits": ["linux-ninja", "st", "lp64", "auto", "shared"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-linux-lp64-auto", + "BLIS_LINKING_TYPE": "shared", + "REF_CBLAS": "MKL" + } + }, + { + "name": "linux-ninja-mt-lp64-auto-static", + "inherits": ["linux-ninja", "mt", "lp64", "auto", "static"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-linux-lp64-auto", + "BLIS_LINKING_TYPE": "static", + "REF_CBLAS": "MKL" + } + }, + { + "name": "linux-ninja-mt-lp64-auto-shared", + "inherits": ["linux-ninja", "mt", "lp64", "auto", "shared"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-linux-lp64-auto", + "BLIS_LINKING_TYPE": "shared", + "REF_CBLAS": "MKL" + } + }, + { + "name": "linux-ninja-st-ilp64-auto-static", + "inherits": ["linux-ninja", "st", "ilp64", "auto", "static"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-linux-ilp64-auto", + "BLIS_LINKING_TYPE": "static", + "REF_CBLAS": "MKL" + } + }, + { + "name": "linux-ninja-st-ilp64-auto-shared", + "inherits": ["linux-ninja", "st", "ilp64", "auto", "shared"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-linux-ilp64-auto", + "BLIS_LINKING_TYPE": "shared", + "REF_CBLAS": "MKL" + } + }, + { + "name": "linux-ninja-mt-ilp64-auto-static", + "inherits": ["linux-ninja", "mt", "ilp64", "auto", "static"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-linux-ilp64-auto", + "BLIS_LINKING_TYPE": "static", + "REF_CBLAS": "MKL" + } + }, + { + "name": "linux-ninja-mt-ilp64-auto-shared", + "inherits": ["linux-ninja", "mt", "ilp64", "auto", "shared"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-linux-ilp64-auto", + "BLIS_LINKING_TYPE": "shared", + "REF_CBLAS": "MKL" + } + } + ], + "buildPresets": [ + { + "name": "linux-ninja-st-lp64-amdzen-static", + "configurePreset": "linux-ninja-st-lp64-amdzen-static", + "inherits": "base" + }, + { + "name": "linux-ninja-st-lp64-amdzen-shared", + "configurePreset": "linux-ninja-st-lp64-amdzen-shared", + "inherits": "base" + }, + { + "name": "linux-ninja-mt-lp64-amdzen-static", + "configurePreset": "linux-ninja-mt-lp64-amdzen-static", + "inherits": "base" + }, + { + "name": "linux-ninja-mt-lp64-amdzen-shared", + "configurePreset": "linux-ninja-mt-lp64-amdzen-shared", + "inherits": "base" + }, + { + "name": "linux-ninja-st-ilp64-amdzen-static", + "configurePreset": "linux-ninja-st-ilp64-amdzen-static", + "inherits": "base" + }, + { + "name": "linux-ninja-st-ilp64-amdzen-shared", + "configurePreset": "linux-ninja-st-ilp64-amdzen-shared", + "inherits": "base" + }, + { + "name": "linux-ninja-mt-ilp64-amdzen-static", + "configurePreset": "linux-ninja-mt-ilp64-amdzen-static", + "inherits": "base" + }, + { + "name": "linux-ninja-mt-ilp64-amdzen-shared", + "configurePreset": "linux-ninja-mt-ilp64-amdzen-shared", + "inherits": "base" + }, + { + "name": "linux-ninja-st-lp64-auto-static", + "configurePreset": "linux-ninja-st-lp64-auto-static", + "inherits": "base" + }, + { + "name": "linux-ninja-st-lp64-auto-shared", + "configurePreset": "linux-ninja-st-lp64-auto-shared", + "inherits": "base" + }, + { + "name": "linux-ninja-mt-lp64-auto-static", + "configurePreset": "linux-ninja-mt-lp64-auto-static", + "inherits": "base" + }, + { + "name": "linux-ninja-mt-lp64-auto-shared", + "configurePreset": "linux-ninja-mt-lp64-auto-shared", + "inherits": "base" + }, + { + "name": "linux-ninja-st-ilp64-auto-static", + "configurePreset": "linux-ninja-st-lp64-auto-static", + "inherits": "base" + }, + { + "name": "linux-ninja-st-ilp64-auto-shared", + "configurePreset": "linux-ninja-st-lp64-auto-shared", + "inherits": "base" + }, + { + "name": "linux-ninja-mt-ilp64-auto-static", + "configurePreset": "linux-ninja-mt-lp64-auto-static", + "inherits": "base" + }, + { + "name": "linux-ninja-mt-ilp64-auto-shared", + "configurePreset": "linux-ninja-mt-lp64-auto-shared", + "inherits": "base" + } + ] +} diff --git a/gtestsuite/cmake/presets/win-msvc.json b/gtestsuite/cmake/presets/win-msvc.json new file mode 100644 index 0000000000..d316161d12 --- /dev/null +++ b/gtestsuite/cmake/presets/win-msvc.json @@ -0,0 +1,262 @@ +{ + "version": 6, + "include": [ + "base.json" + ], + "configurePresets": [ + { + "name": "win-msvc", + "inherits": "base", + "hidden": true, + "generator": "Visual Studio 17 2022", + "condition": { + "type": "equals", + "lhs": "${hostSystemName}", + "rhs": "Windows" + }, + "toolset": "ClangCl" + }, + { + "name": "win-msvc-st-lp64-amdzen-static", + "inherits": ["win-msvc", "st", "lp64", "amdzen", "static"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-win-lp64-amdzen", + "BLIS_LINKING_TYPE": "static", + "REF_CBLAS": "MKL" + } + }, + { + "name": "win-msvc-st-lp64-amdzen-shared", + "inherits": ["win-msvc", "st", "lp64", "amdzen", "shared"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-win-lp64-amdzen", + "BLIS_LINKING_TYPE": "shared", + "REF_CBLAS": "MKL" + } + }, + { + "name": "win-msvc-mt-lp64-amdzen-static", + "inherits": ["win-msvc", "mt", "lp64", "amdzen", "static"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-win-lp64-amdzen", + "BLIS_LINKING_TYPE": "static", + "REF_CBLAS": "MKL" + } + }, + { + "name": "win-msvc-mt-lp64-amdzen-shared", + "inherits": ["win-msvc", "mt", "lp64", "amdzen", "shared"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-win-lp64-amdzen", + "BLIS_LINKING_TYPE": "shared", + "REF_CBLAS": "MKL" + } + }, + { + "name": "win-msvc-st-ilp64-amdzen-static", + "inherits": ["win-msvc", "st", "ilp64", "amdzen", "static"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-win-ilp64-amdzen", + "BLIS_LINKING_TYPE": "static", + "REF_CBLAS": "MKL" + } + }, + { + "name": "win-msvc-st-ilp64-amdzen-shared", + "inherits": ["win-msvc", "st", "ilp64", "amdzen", "shared"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-win-ilp64-amdzen", + "BLIS_LINKING_TYPE": "shared", + "REF_CBLAS": "MKL" + } + }, + { + "name": "win-msvc-mt-ilp64-amdzen-static", + "inherits": ["win-msvc", "mt", "ilp64", "amdzen", "static"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-win-ilp64-amdzen", + "BLIS_LINKING_TYPE": "static", + "REF_CBLAS": "MKL" + } + }, + { + "name": "win-msvc-mt-ilp64-amdzen-shared", + "inherits": ["win-msvc", "mt", "ilp64", "amdzen", "shared"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-win-ilp64-amdzen", + "BLIS_LINKING_TYPE": "shared", + "REF_CBLAS": "MKL" + } + }, + { + "name": "win-msvc-st-lp64-auto-static", + "inherits": ["win-msvc", "st", "lp64", "auto", "static"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-win-lp64-auto", + "BLIS_LINKING_TYPE": "static", + "REF_CBLAS": "MKL" + } + }, + { + "name": "win-msvc-st-lp64-auto-shared", + "inherits": ["win-msvc", "st", "lp64", "auto", "shared"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-win-lp64-auto", + "BLIS_LINKING_TYPE": "shared", + "REF_CBLAS": "MKL" + } + }, + { + "name": "win-msvc-mt-lp64-auto-static", + "inherits": ["win-msvc", "mt", "lp64", "auto", "static"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-win-lp64-auto", + "BLIS_LINKING_TYPE": "static", + "REF_CBLAS": "MKL" + } + }, + { + "name": "win-msvc-mt-lp64-auto-shared", + "inherits": ["win-msvc", "mt", "lp64", "auto", "shared"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-win-lp64-auto", + "BLIS_LINKING_TYPE": "shared", + "REF_CBLAS": "MKL" + } + }, + { + "name": "win-msvc-st-ilp64-auto-static", + "inherits": ["win-msvc", "st", "ilp64", "auto", "static"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-win-ilp64-auto", + "BLIS_LINKING_TYPE": "static", + "REF_CBLAS": "MKL" + } + }, + { + "name": "win-msvc-st-ilp64-auto-shared", + "inherits": ["win-msvc", "st", "ilp64", "auto", "shared"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-win-ilp64-auto", + "BLIS_LINKING_TYPE": "shared", + "REF_CBLAS": "MKL" + } + }, + { + "name": "win-msvc-mt-ilp64-auto-static", + "inherits": ["win-msvc", "mt", "ilp64", "auto", "static"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-win-ilp64-auto", + "BLIS_LINKING_TYPE": "static", + "REF_CBLAS": "MKL" + } + }, + { + "name": "win-msvc-mt-ilp64-auto-shared", + "inherits": ["win-msvc", "mt", "ilp64", "auto", "shared"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-win-ilp64-auto", + "BLIS_LINKING_TYPE": "shared", + "REF_CBLAS": "MKL" + } + } + ], + "buildPresets": [ + { + "name": "win-msvc-st-lp64-amdzen-static", + "configurePreset": "win-msvc-st-lp64-amdzen-static", + "inherits": "base" + }, + { + "name": "win-msvc-st-lp64-amdzen-shared", + "configurePreset": "win-msvc-st-lp64-amdzen-shared", + "inherits": "base" + }, + { + "name": "win-msvc-mt-lp64-amdzen-static", + "configurePreset": "win-msvc-mt-lp64-amdzen-static", + "inherits": "base" + }, + { + "name": "win-msvc-mt-lp64-amdzen-shared", + "configurePreset": "win-msvc-mt-lp64-amdzen-shared", + "inherits": "base" + }, + { + "name": "win-msvc-st-ilp64-amdzen-static", + "configurePreset": "win-msvc-st-ilp64-amdzen-static", + "inherits": "base" + }, + { + "name": "win-msvc-st-ilp64-amdzen-shared", + "configurePreset": "win-msvc-st-ilp64-amdzen-shared", + "inherits": "base" + }, + { + "name": "win-msvc-mt-ilp64-amdzen-static", + "configurePreset": "win-msvc-mt-ilp64-amdzen-static", + "inherits": "base" + }, + { + "name": "win-msvc-mt-ilp64-amdzen-shared", + "configurePreset": "win-msvc-mt-ilp64-amdzen-shared", + "inherits": "base" + }, + { + "name": "win-msvc-st-lp64-auto-static", + "configurePreset": "win-msvc-st-lp64-auto-static", + "inherits": "base" + }, + { + "name": "win-msvc-st-lp64-auto-shared", + "configurePreset": "win-msvc-st-lp64-auto-shared", + "inherits": "base" + }, + { + "name": "win-msvc-mt-lp64-auto-static", + "configurePreset": "win-msvc-mt-lp64-auto-static", + "inherits": "base" + }, + { + "name": "win-msvc-mt-lp64-auto-shared", + "configurePreset": "win-msvc-mt-lp64-auto-shared", + "inherits": "base" + }, + { + "name": "win-msvc-st-ilp64-auto-static", + "configurePreset": "win-msvc-st-lp64-auto-static", + "inherits": "base" + }, + { + "name": "win-msvc-st-ilp64-auto-shared", + "configurePreset": "win-msvc-st-lp64-auto-shared", + "inherits": "base" + }, + { + "name": "win-msvc-mt-ilp64-auto-static", + "configurePreset": "win-msvc-mt-lp64-auto-static", + "inherits": "base" + }, + { + "name": "win-msvc-mt-ilp64-auto-shared", + "configurePreset": "win-msvc-mt-lp64-auto-shared", + "inherits": "base" + } + ] +} diff --git a/gtestsuite/cmake/presets/win-ninja.json b/gtestsuite/cmake/presets/win-ninja.json new file mode 100644 index 0000000000..cc47119cde --- /dev/null +++ b/gtestsuite/cmake/presets/win-ninja.json @@ -0,0 +1,261 @@ +{ + "version": 6, + "include": [ + "base.json" + ], + "configurePresets": [ + { + "name": "win-ninja", + "inherits": "base", + "hidden": true, + "generator": "Ninja", + "condition": { + "type": "equals", + "lhs": "${hostSystemName}", + "rhs": "Windows" + } + }, + { + "name": "win-ninja-st-lp64-amdzen-static", + "inherits": ["win-ninja", "st", "lp64", "amdzen", "static"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-linux-lp64-amdzen", + "BLIS_LINKING_TYPE": "static", + "REF_CBLAS": "MKL" + } + }, + { + "name": "win-ninja-st-lp64-amdzen-shared", + "inherits": ["win-ninja", "st", "lp64", "amdzen", "shared"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-linux-lp64-amdzen", + "BLIS_LINKING_TYPE": "shared", + "REF_CBLAS": "MKL" + } + }, + { + "name": "win-ninja-mt-lp64-amdzen-static", + "inherits": ["win-ninja", "mt", "lp64", "amdzen", "static"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-linux-lp64-amdzen", + "BLIS_LINKING_TYPE": "static", + "REF_CBLAS": "MKL" + } + }, + { + "name": "win-ninja-mt-lp64-amdzen-shared", + "inherits": ["win-ninja", "mt", "lp64", "amdzen", "shared"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-linux-lp64-amdzen", + "BLIS_LINKING_TYPE": "shared", + "REF_CBLAS": "MKL" + } + }, + { + "name": "win-ninja-st-ilp64-amdzen-static", + "inherits": ["win-ninja", "st", "ilp64", "amdzen", "static"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-linux-ilp64-amdzen", + "BLIS_LINKING_TYPE": "static", + "REF_CBLAS": "MKL" + } + }, + { + "name": "win-ninja-st-ilp64-amdzen-shared", + "inherits": ["win-ninja", "st", "ilp64", "amdzen", "shared"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-linux-ilp64-amdzen", + "BLIS_LINKING_TYPE": "shared", + "REF_CBLAS": "MKL" + } + }, + { + "name": "win-ninja-mt-ilp64-amdzen-static", + "inherits": ["win-ninja", "mt", "ilp64", "amdzen", "static"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-linux-ilp64-amdzen", + "BLIS_LINKING_TYPE": "static", + "REF_CBLAS": "MKL" + } + }, + { + "name": "win-ninja-mt-ilp64-amdzen-shared", + "inherits": ["win-ninja", "mt", "ilp64", "amdzen", "shared"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-linux-ilp64-amdzen", + "BLIS_LINKING_TYPE": "shared", + "REF_CBLAS": "MKL" + } + }, + { + "name": "win-ninja-st-lp64-auto-static", + "inherits": ["win-ninja", "st", "lp64", "auto", "static"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-linux-lp64-auto", + "BLIS_LINKING_TYPE": "static", + "REF_CBLAS": "MKL" + } + }, + { + "name": "win-ninja-st-lp64-auto-shared", + "inherits": ["win-ninja", "st", "lp64", "auto", "shared"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-linux-lp64-auto", + "BLIS_LINKING_TYPE": "shared", + "REF_CBLAS": "MKL" + } + }, + { + "name": "win-ninja-mt-lp64-auto-static", + "inherits": ["win-ninja", "mt", "lp64", "auto", "static"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-linux-lp64-auto", + "BLIS_LINKING_TYPE": "static", + "REF_CBLAS": "MKL" + } + }, + { + "name": "win-ninja-mt-lp64-auto-shared", + "inherits": ["win-ninja", "mt", "lp64", "auto", "shared"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-linux-lp64-auto", + "BLIS_LINKING_TYPE": "shared", + "REF_CBLAS": "MKL" + } + }, + { + "name": "win-ninja-st-ilp64-auto-static", + "inherits": ["win-ninja", "st", "ilp64", "auto", "static"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-linux-ilp64-auto", + "BLIS_LINKING_TYPE": "static", + "REF_CBLAS": "MKL" + } + }, + { + "name": "win-ninja-st-ilp64-auto-shared", + "inherits": ["win-ninja", "st", "ilp64", "auto", "shared"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-linux-ilp64-auto", + "BLIS_LINKING_TYPE": "shared", + "REF_CBLAS": "MKL" + } + }, + { + "name": "win-ninja-mt-ilp64-auto-static", + "inherits": ["win-ninja", "mt", "ilp64", "auto", "static"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-linux-ilp64-auto", + "BLIS_LINKING_TYPE": "static", + "REF_CBLAS": "MKL" + } + }, + { + "name": "win-ninja-mt-ilp64-auto-shared", + "inherits": ["win-ninja", "mt", "ilp64", "auto", "shared"], + "hidden": false, + "cacheVariables": { + "BLIS_PATH": "${sourceParentDir}//install-linux-ilp64-auto", + "BLIS_LINKING_TYPE": "shared", + "REF_CBLAS": "MKL" + } + } + ], + "buildPresets": [ + { + "name": "win-ninja-st-lp64-amdzen-static", + "configurePreset": "win-ninja-st-lp64-amdzen-static", + "inherits": "base" + }, + { + "name": "win-ninja-st-lp64-amdzen-shared", + "configurePreset": "win-ninja-st-lp64-amdzen-shared", + "inherits": "base" + }, + { + "name": "win-ninja-mt-lp64-amdzen-static", + "configurePreset": "win-ninja-mt-lp64-amdzen-static", + "inherits": "base" + }, + { + "name": "win-ninja-mt-lp64-amdzen-shared", + "configurePreset": "win-ninja-mt-lp64-amdzen-shared", + "inherits": "base" + }, + { + "name": "win-ninja-st-ilp64-amdzen-static", + "configurePreset": "win-ninja-st-ilp64-amdzen-static", + "inherits": "base" + }, + { + "name": "win-ninja-st-ilp64-amdzen-shared", + "configurePreset": "win-ninja-st-ilp64-amdzen-shared", + "inherits": "base" + }, + { + "name": "win-ninja-mt-ilp64-amdzen-static", + "configurePreset": "win-ninja-mt-ilp64-amdzen-static", + "inherits": "base" + }, + { + "name": "win-ninja-mt-ilp64-amdzen-shared", + "configurePreset": "win-ninja-mt-ilp64-amdzen-shared", + "inherits": "base" + }, + { + "name": "win-ninja-st-lp64-auto-static", + "configurePreset": "win-ninja-st-lp64-auto-static", + "inherits": "base" + }, + { + "name": "win-ninja-st-lp64-auto-shared", + "configurePreset": "win-ninja-st-lp64-auto-shared", + "inherits": "base" + }, + { + "name": "win-ninja-mt-lp64-auto-static", + "configurePreset": "win-ninja-mt-lp64-auto-static", + "inherits": "base" + }, + { + "name": "win-ninja-mt-lp64-auto-shared", + "configurePreset": "win-ninja-mt-lp64-auto-shared", + "inherits": "base" + }, + { + "name": "win-ninja-st-ilp64-auto-static", + "configurePreset": "win-ninja-st-lp64-auto-static", + "inherits": "base" + }, + { + "name": "win-ninja-st-ilp64-auto-shared", + "configurePreset": "win-ninja-st-lp64-auto-shared", + "inherits": "base" + }, + { + "name": "win-ninja-mt-ilp64-auto-static", + "configurePreset": "win-ninja-mt-lp64-auto-static", + "inherits": "base" + }, + { + "name": "win-ninja-mt-ilp64-auto-shared", + "configurePreset": "win-ninja-mt-lp64-auto-shared", + "inherits": "base" + } + ] +} diff --git a/gtestsuite/codecov.sh b/gtestsuite/codecov.sh index da8cff3022..33cfe539f3 100755 --- a/gtestsuite/codecov.sh +++ b/gtestsuite/codecov.sh @@ -1,4 +1,36 @@ #!/bin/bash +# +# BLIS +# An object-based framework for developing high-performance BLAS-like +# libraries. +# +# Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# - Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# - Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# - Neither the name(s) of the copyright holder(s) nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# echo "Code Coverage for BLIS" echo "obj_dir_path : $1" diff --git a/gtestsuite/testinghelpers/CMakeLists.txt b/gtestsuite/testinghelpers/CMakeLists.txt index c6cca616ed..78f459e3e7 100644 --- a/gtestsuite/testinghelpers/CMakeLists.txt +++ b/gtestsuite/testinghelpers/CMakeLists.txt @@ -1,21 +1,22 @@ #[=[ + BLIS An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -28,6 +29,7 @@ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + ]=] file(GLOB_RECURSE SOURCES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "src/*/*.cpp") @@ -41,7 +43,9 @@ elseif(REF_CBLAS STREQUAL "OpenBLAS") target_compile_definitions(testinghelpers PUBLIC REF_IS_OPENBLAS) endif() if(TEST_INTERFACE STREQUAL "BLAS") - target_compile_definitions(testinghelpers PUBLIC TEST_BLAS) + target_compile_definitions(testinghelpers PUBLIC TEST_BLAS TEST_BLAS_LIKE) +elseif(TEST_INTERFACE STREQUAL "BLAS_BLIS_IMPL") + target_compile_definitions(testinghelpers PUBLIC TEST_BLAS_BLIS_IMPL TEST_BLAS_LIKE) elseif(TEST_INTERFACE STREQUAL "CBLAS") target_compile_definitions(testinghelpers PUBLIC TEST_CBLAS) else() # BLIS_TYPED option @@ -52,7 +56,9 @@ if(INT_SIZE STREQUAL "32") else() target_compile_definitions(testinghelpers PUBLIC INT_SIZE=64) endif() -target_compile_definitions(testinghelpers PUBLIC BLIS_ELEMENT_TYPE='${BLIS_ELEMENT_TYPE}') +if(${BLIS_ELEMENT_TYPE} STREQUAL "i") + target_compile_definitions(testinghelpers PUBLIC -DBLIS_INT_ELEMENT_TYPE) +endif() target_include_directories(testinghelpers PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/inc ${BLIS_INCLUDE}) if(LINUX) target_link_libraries(testinghelpers pthread) @@ -64,5 +70,5 @@ else() set(threads_spec Threads::Threads) endif() target_link_libraries(testinghelpers PUBLIC ${threads_spec}) - set_target_properties(testinghelpers PROPERTIES POSITION_INDEPENDENT_CODE ON) + set_target_properties(testinghelpers PROPERTIES POSITION_INDEPENDENT_CODE ON) endif() diff --git a/gtestsuite/testinghelpers/inc/common/complex_helpers.h b/gtestsuite/testinghelpers/inc/common/complex_helpers.h index 588144f7f5..8475c3fe81 100644 --- a/gtestsuite/testinghelpers/inc/common/complex_helpers.h +++ b/gtestsuite/testinghelpers/inc/common/complex_helpers.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -56,6 +56,9 @@ dcomplex operator-(const dcomplex x, const dcomplex y); scomplex operator*(const scomplex x, const scomplex y); dcomplex operator*(const dcomplex x, const dcomplex y); +scomplex operator/(const scomplex x, const scomplex y); +dcomplex operator/(const dcomplex x, const dcomplex y); + bool operator== (const scomplex x, const scomplex y); bool operator== (const dcomplex x, const dcomplex y); diff --git a/gtestsuite/testinghelpers/inc/common/data_generators.h b/gtestsuite/testinghelpers/inc/common/data_generators.h index f40eeba018..3f7db7afe4 100644 --- a/gtestsuite/testinghelpers/inc/common/data_generators.h +++ b/gtestsuite/testinghelpers/inc/common/data_generators.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -32,76 +32,728 @@ */ -#pragma once - #include -#include "common/type_info.h" +#include +#include "common/testing_helpers.h" namespace testinghelpers { namespace datagenerators { +// Setting an enum class to make random data generation more robust. +enum class ElementType {FP, INT}; +// Define a static variable to be used as the default argument in +// the generators, depending on CMake configuration. +#ifdef BLIS_INT_ELEMENT_TYPE +// Integer random values will be used in testing. +static const ElementType GenericET = ElementType::INT; +#else +// Floating-point random values will be used in testing. +static const ElementType GenericET = ElementType::FP; +#endif + /*************************************************** - * Random Generators + * Floating Point Generators ****************************************************/ /** - * @brief Returns a random int/float converted to an fp type (float, double, scomplex, dcomplex) + * @brief Returns a random fp type (float, double, scomplex, dcomplex) * that lies in the range [from, to]. * * @param[in, out] alpha the random fp */ -template -void randomgenerators(int from, int to, T* alpha, char fp); +template +void getfp(T2 from, T3 to, T1* alpha) +{ + using real_T = typename testinghelpers::type_info::real_type; + std::mt19937 generator(94); + std::uniform_real_distribution distr(from, to); + if constexpr (testinghelpers::type_info::is_real) + *alpha = distr(generator); + else + *alpha = {distr(generator), distr(generator)}; +} /** - * @brief Returns a random vector (float, double, scomplex, dcomplex) - * with elements that are integers or floats, depending on char, and follow a uniform distribution in the range [from, to]. + * @brief Returns a random fp vector (float, double, scomplex, dcomplex) + * with elements that follow a uniform distribution in the range [from, to]. * @param[in] n length of vector x * @param[in] incx increments of vector x * @param[in, out] x the random fp vector - * @param[in] fp if fp=='i' the elements will have random integer values. - * if fp=='f' the elements will have random float values. */ -template -void randomgenerators(int from, int to, gtint_t n, gtint_t incx, T* x, char fp = BLIS_ELEMENT_TYPE); +template +void getfp(T2 from, T3 to, gtint_t n, gtint_t incx, T1* x) +{ + using real_T = typename testinghelpers::type_info::real_type; + T1* chi; + + if (incx != 1) + { + // First initialize all elements in vector to unusual value to help + // catch if intervening elements have been incorrectly used or modified. + for ( gtint_t i = 0; i < testinghelpers::buff_dim(n, incx); ++i ) + { + chi = x + i; + *chi = T1{-1.2345e38}; + } + } + + // Generate the values from the uniform distribution that + // the BLAS routine should read and/or modify. + std::mt19937 generator(94); + std::uniform_real_distribution distr(from, to); + for ( gtint_t i = 0; i < n; ++i ) + { + chi = x + i*std::abs(incx); + if constexpr (testinghelpers::type_info::is_real) + *chi = distr(generator); + else + *chi = {distr(generator), distr(generator)}; + } +} + +/** + * @brief Returns a random fp vector (float, double, scomplex, dcomplex) + * with elements that follow a uniform distribution in the range [from, to]. + * @param[in] storage storage type of matrix A, row or column major + * @param[in] m, n dimentions of matrix A + * @param[in, out] a the random fp matrix A + * @param[in] lda leading dimension of matrix A + * @param[in] stridea stride between two "continuous" elements in matrix A + */ +template +void getfp(T2 from, T3 to, char storage, gtint_t m, gtint_t n, T1* a, gtint_t lda, gtint_t stridea = 1 ) +{ + using real_T = typename testinghelpers::type_info::real_type; + std::mt19937 generator(1994); + std::uniform_real_distribution distr(from, to); + if((storage == 'c') || (storage == 'C')) + { + if (m > 0) + { + for(gtint_t j=0; j::is_real) + { + for(gtint_t i=0; i 0) + { + for(gtint_t i=0; i::is_real) + { + for(gtint_t j=0; j +void getfp(T2 from, T3 to, char storage, gtint_t m, gtint_t n, T1* a, char transa, gtint_t lda, gtint_t stridea = 1 ) +{ + if( chktrans( transa )) { + swap_dims( &m, &n ); + } + getfp( from, to, storage, m, n, a, lda, stridea ); +} + +/*************************************************** + * Integer Generators +****************************************************/ +/** + * @brief Returns a random integer converted to an fp type (float, double, scomplex, dcomplex) + * that lies in the range [from, to]. + * + * @param[in, out] alpha the random fp + */ +template +void getint(int from, int to, T* alpha) +{ + using real_T = typename testinghelpers::type_info::real_type; + std::mt19937 generator(94); + std::uniform_int_distribution distr(from, to); + if constexpr (testinghelpers::type_info::is_real) + *alpha = real_T(distr(generator)); + else + *alpha = {real_T(distr(generator)), real_T(distr(generator))}; +} +/** + * @brief Returns a random fp vector (float, double, scomplex, dcomplex) + * with elements that are integers and follow a uniform distribution in the range [from, to]. + * @param[in] n length of vector x + * @param[in] incx increments of vector x + * @param[in, out] x the random fp vector + */ template -void randomgenerators(int from, int to, char storage, gtint_t m, gtint_t n, T* a, gtint_t lda, char fp = BLIS_ELEMENT_TYPE); +void getint(int from, int to, gtint_t n, gtint_t incx, T* x) +{ + using real_T = typename testinghelpers::type_info::real_type; + T* chi; + if (incx != 1) + { + // First initialize all elements in vector to unusual value to help + // catch if intervening elements have been incorrectly used or modified. + for ( gtint_t i = 0; i < testinghelpers::buff_dim(n, incx); ++i ) + { + chi = x + i; + *chi = T{-1.2345e38}; + } + } + + // Generate the values from the uniform distribution that + // the BLAS routine should read and/or modify. + std::mt19937 generator(94); + std::uniform_int_distribution distr(from, to); + for ( gtint_t i = 0; i < n; ++i ) + { + chi = x + i*std::abs(incx); + if constexpr (testinghelpers::type_info::is_real) + *chi = real_T(distr(generator)); + else + *chi = {real_T(distr(generator)), real_T(distr(generator))}; + } +} + +/** + * @brief Returns a random fp matrix (float, double, scomplex, dcomplex) + * with elements that are integers and follow a uniform distribution in the range [from, to]. + * @param[in] storage storage type of matrix A, row or column major + * @param[in] m, n dimentions of matrix A + * @param[in, out] a the random fp matrix A + * @param[in] lda leading dimension of matrix A + * @param[in] stridea stride between two "continuous" elements in matrix A + */ template -void randomgenerators(int from, int to, char storage, gtint_t m, gtint_t n, T* a, char transa, gtint_t lda, char fp = BLIS_ELEMENT_TYPE); +void getint(int from, int to, char storage, gtint_t m, gtint_t n, T* a, gtint_t lda, gtint_t stridea = 1 ) +{ + using real_T = typename testinghelpers::type_info::real_type; + std::mt19937 generator(94); + std::uniform_int_distribution distr(from, to); + if((storage == 'c') || (storage == 'C')) + { + if (m > 0) + { + for(gtint_t j=0; j::is_real) + { + for(gtint_t i=0; i 0) + { + for(gtint_t i=0; i::is_real) + { + for(gtint_t j=0; j -void randomgenerators(int from, int to, char storage, char uplo, gtint_t m, - T* a, gtint_t lda, char fp = BLIS_ELEMENT_TYPE ); +void getint(int from, int to, char storage, gtint_t m, gtint_t n, T* a, char transa, gtint_t lda, gtint_t stridea = 1 ) +{ + if( chktrans( transa )) { + swap_dims( &m, &n ); + } + getint( from, to, storage, m, n, a, lda, stridea ); +} + +template +void randomgenerators(T2 from, T3 to, gtint_t n, gtint_t incx, T1* x, ElementType datatype = GenericET) { + + if( datatype == ElementType::INT ) + getint( from, to, n, incx, x ); + else + getfp( from, to, n, incx, x ); +} + +template +void randomgenerators( T2 from, T3 to, char storage, gtint_t m, gtint_t n, + T1* a, gtint_t lda, gtint_t stridea = 1, ElementType datatype = GenericET ) { + + if( datatype == ElementType::INT ) + getint( from, to, storage, m, n, a, lda, stridea ); + else + getfp( from, to, storage, m, n, a, lda, stridea ); +} + +template +void randomgenerators( T2 from, T3 to, char storage, gtint_t m, gtint_t n, + T1* a, char transa, gtint_t lda, gtint_t stridea = 1, ElementType datatype = GenericET ) { + + if( datatype == ElementType::INT ) + getint( from, to, storage, m, n, a, transa, lda, stridea ); + else + getfp( from, to, storage, m, n, a, transa, lda, stridea ); +} + +template +void randomgenerators( T2 from, T3 to, char storage, char uplo, gtint_t k, + T1* a, gtint_t lda, ElementType datatype = GenericET ) { + testinghelpers::datagenerators::randomgenerators(from, to, storage, k, k, a, lda, 1, datatype); + if( (storage=='c')||(storage=='C') ) + { + for(gtint_t j=0; jj) a[i+j*lda] = T1{2.987e38}; + } + else if ( (uplo=='l')||(uplo=='L') ) + { + if (ij) a[j+i*lda] = T1{2.987e38}; + } + else if ( (uplo=='l')||(uplo=='L') ) + { + if (i -std::vector get_random_matrix(int from, int to, char storage, char trans, gtint_t m, gtint_t n, - gtint_t lda, char datatype = BLIS_ELEMENT_TYPE ); +template +std::vector get_random_matrix(T2 from, T3 to, char storage, char trans, gtint_t m, gtint_t n, + gtint_t lda, gtint_t stridea = 1, datagenerators::ElementType datatype = datagenerators::GenericET) +{ + std::vector a(matsize(storage, trans, m, n, lda)); + testinghelpers::datagenerators::randomgenerators( from, to, storage, m, n, a.data(), trans, lda, stridea, datatype ); + return a; +} + +template +std::vector get_random_matrix(T2 from, T3 to, char storage, char uplo, gtint_t k, gtint_t lda, datagenerators::ElementType datatype = datagenerators::GenericET ) +{ + // Create matrix for the given sizes. + std::vector a( testinghelpers::matsize( storage, 'n', k, k, lda ) ); + testinghelpers::datagenerators::randomgenerators( from, to, storage, uplo, k, a.data(), lda, datatype ); + return a; +} + +template +std::vector get_random_vector(T2 from, T3 to, gtint_t n, gtint_t incx, datagenerators::ElementType datatype = datagenerators::GenericET) +{ + // Create vector for the given sizes. + std::vector x( testinghelpers::buff_dim(n, incx) ); + testinghelpers::datagenerators::randomgenerators( from, to, n, incx, x.data(), datatype ); + return x; +} template -std::vector get_random_matrix(int from, int to, char storage, char uplo, gtint_t k, - gtint_t lda, char datatype = BLIS_ELEMENT_TYPE ); +void set_vector( gtint_t n, gtint_t incx, T* x, T value ) +{ + T* chi; + + if (incx != 1) + { + // First initialize all elements in vector to unusual value to help + // catch if intervening elements have been incorrectly used or modified. + for ( gtint_t i = 0; i < testinghelpers::buff_dim(n, incx); ++i ) + { + chi = x + i; + *chi = T{-1.2345e38}; + } + } + + for ( gtint_t i = 0; i < n; ++i ) + { + chi = x + i*std::abs(incx); + *chi = value ; + } +} template -std::vector get_random_vector(int from, int to, gtint_t n, gtint_t incx,char datatype = BLIS_ELEMENT_TYPE); +void set_matrix( char storage, gtint_t m, gtint_t n, T* a, char transa, gtint_t lda, T value ) +{ + if( chktrans( transa )) { + swap_dims( &m, &n ); + } + + if((storage == 'c') || (storage == 'C')) + { + for( gtint_t j = 0 ; j < n ; j++ ) + { + for( gtint_t i = 0 ; i < m ; i++ ) + { + a[i+j*lda] = value ; + } + for(gtint_t i=m; i -std::vector get_vector( gtint_t n, gtint_t incx, T value ); +void set_matrix( char storage, gtint_t n, T* a, char uplo, gtint_t lda, T value ) +{ + testinghelpers::set_matrix(storage, n, n, a, 'n', lda, value ); + if( (storage=='c')||(storage=='C') ) + { + for(gtint_t j=0; jj) a[i+j*lda] = T{2.987e38}; + } + else if ( (uplo=='l')||(uplo=='L') ) + { + if (ij) a[j+i*lda] = T{2.987e38}; + } + else if ( (uplo=='l')||(uplo=='L') ) + { + if (i -std::vector get_matrix( char storage, char trans, gtint_t m, gtint_t n, gtint_t lda, T value ); +std::vector get_vector( gtint_t n, gtint_t incx, T value ) +{ + // Create vector for the given sizes. + std::vector x( testinghelpers::buff_dim(n, incx) ); + testinghelpers::set_vector( n, incx, x.data(), value ); + return x; +} template -void set_vector( gtint_t n, gtint_t incx, T* x, T value ); +std::vector get_matrix( char storage, char trans, gtint_t m, gtint_t n, gtint_t lda, T value ) +{ + std::vector a( matsize( storage, trans, m, n, lda ) ); + testinghelpers::set_matrix( storage, m, n, a.data(), trans, lda, value ); + return a; +} template -void set_matrix( char storage, gtint_t m, gtint_t n, T* a, char transa, gtint_t lda, T value ); +void set_ev_mat( char storage, char trns, gtint_t ld, gtint_t i, gtint_t j, T exval, T* m ) +{ + // Setting the exception values on the indices passed as arguments + if ( storage == 'c' || storage == 'C' ) + { + if ( trns == 'n' || trns == 'N' ) + m[i + j*ld] = exval; + else + m[j + i*ld] = exval; + } + else + { + if ( trns == 'n' || trns == 'N' ) + m[i*ld + j] = exval; + else + m[j*ld + i] = exval; + } +} -// Function template to set the exception value exval on matrix m, at indices (i, j) -// In case of transposition, this function internally swaps the indices, and thus they can be -// passed without swapping on the instantiator. +/* + Function to set few values of a matrix to values relative to DBL_MAX/DBL_MIN + These values are used to create overflow and underflow scenarios +*/ template -void set_ev_mat( char storage, char trns, gtint_t ld, gtint_t i, gtint_t j, T exval, T* m ); +void set_overflow_underflow_mat(char storage, char trns, gtint_t ld, gtint_t i, gtint_t j, T* a, gtint_t mode, gtint_t input_range) +{ + /* Calculate index where overflow/underflow values need to be inserted */ + gtint_t indexA = 0; + + if ( storage == 'c' || storage == 'C' ) + { + if ( trns == 'n' || trns == 'N' ) + { + indexA = i + j*ld; + } + else + { + indexA = j + i*ld; + } + } + else + { + if ( trns == 'n' || trns == 'N' ) + { + indexA = i*ld + j; + } + else + { + indexA = j*ld + i; + } + } + + using RT = typename testinghelpers::type_info::real_type; + std::vector exponent(12); + + if (std::is_same::value) + { + exponent = {23, 203, 18, 180, 123, 130, 185, 178, 108, 158, 185, 220}; + } + else if (std::is_same::value) + { + exponent = {3, 20, 8, 2, 30, 28, 8, 10, 33, 24, 8, 22}; + } + + T limits_val; + + /* When mode is set to 0, values relative to DBL_MAX are inserted into the input matrices */ + if(mode == 0) + { + limits_val = (std::numeric_limits::max)(); + switch(input_range) + { + case -1: + a[0] = limits_val/ pow(10, exponent[0]); + a[indexA] = limits_val/ pow(10, exponent[1]); + break; + + case 0: + a[0] = -(limits_val/ pow(10, exponent[4])); + a[indexA] = -(limits_val/ pow(10, exponent[5])); + break; + + case 1: + a[0] = limits_val/ pow(10, exponent[8]); + a[indexA] = limits_val/ pow(10, exponent[9]); + } + } + /* When mode is set to 1, values relative to DBL_MIN are inserted into the input matrices*/ + else + { + limits_val = (std::numeric_limits::min)(); + switch(input_range) + { + case -1: + a[0] = limits_val * pow(10, exponent[0]); + a[indexA] = limits_val * pow(10, exponent[1]); + break; + + case 0: + a[0] = -(limits_val * pow(10, exponent[4])); + a[indexA] = -(limits_val * pow(10, exponent[5])); + break; + + case 1: + a[0] = limits_val * pow(10, exponent[8]); + a[indexA] = limits_val * pow(10, exponent[9]); + } + + } +} } //end of namespace testinghelpers diff --git a/gtestsuite/testinghelpers/inc/common/error_helpers.h b/gtestsuite/testinghelpers/inc/common/error_helpers.h index c61714d707..edd659e140 100644 --- a/gtestsuite/testinghelpers/inc/common/error_helpers.h +++ b/gtestsuite/testinghelpers/inc/common/error_helpers.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -49,8 +49,12 @@ namespace testinghelpers { template double getEpsilon() { +#ifdef THRESHOLD_ZERO + double eps = 0.0; +#else using RT = typename testinghelpers::type_info::real_type; double eps = std::numeric_limits::epsilon(); +#endif return eps; } diff --git a/gtestsuite/testinghelpers/inc/common/protected_buffer.h b/gtestsuite/testinghelpers/inc/common/protected_buffer.h new file mode 100644 index 0000000000..f66e2bf103 --- /dev/null +++ b/gtestsuite/testinghelpers/inc/common/protected_buffer.h @@ -0,0 +1,79 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + + +#pragma once + +#include "common/type_info.h" + +namespace testinghelpers { + class ProtectedBuffer + { + private: + static const size_t REDZONE_SIZE = 1; + void* redzone_1 = nullptr; + void* redzone_2 = nullptr; + void* mem = nullptr; + bool is_mem_test = false; + + /** + * ========================================================================== + * get_mem + * returns a aligned or unaligned buffer of size "size" + * ========================================================================== + * @param[in] size specifies the size of the buffer to be allocated. + * @param[in] is_aligned specifies if the buffer needs to be aligned or not. + */ + static void* get_mem(dim_t, bool); + + public: + void* greenzone_1 = nullptr; + void* greenzone_2 = nullptr; + + ProtectedBuffer(dim_t size, bool is_aligned = false, bool is_mem_test = false); + ~ProtectedBuffer(); + + static void handle_mem_test_fail(int signal); + + /** + * Adds signal handler for segmentation fault. + */ + static void start_signal_handler(); + + /** + * Removes signal handler for segmentation fault. + */ + static void stop_signal_handler(); + }; +} diff --git a/gtestsuite/testinghelpers/inc/common/refCBLAS.h b/gtestsuite/testinghelpers/inc/common/refCBLAS.h index 0d64594117..d4355daf55 100644 --- a/gtestsuite/testinghelpers/inc/common/refCBLAS.h +++ b/gtestsuite/testinghelpers/inc/common/refCBLAS.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -41,7 +41,7 @@ #include /** - * This is a helper class that we use to load the symbols + * This is a helper class that we use to load the symbols * from the reference library dynamically so that we get * the reference solution. * Since dynamic loading can be time consuming this class works @@ -53,12 +53,12 @@ * loads the library either with a call to dlopen (Linux) or with * a call to LoadLibrary (Windows). * - Similarly the destructor unloads the library. - * - The member function loadSymbol() is used to return the pointer + * - The member function loadSymbol() is used to return the pointer * to that symbol in the library, either with a call to ldsym (Linux) * or with a call to GetProcAddress (Windows). * This means that the library is only loaded once per executable * due to having the global variable refCBLASModule and unloaded once - * at the end. Multiple calls to loadSymbol are used to access the + * at the end. Multiple calls to loadSymbol are used to access the * corresponding API used for reference. */ namespace testinghelpers { diff --git a/gtestsuite/testinghelpers/inc/common/testing_basics.h b/gtestsuite/testinghelpers/inc/common/testing_basics.h index e7f92a9356..a61168d650 100644 --- a/gtestsuite/testinghelpers/inc/common/testing_basics.h +++ b/gtestsuite/testinghelpers/inc/common/testing_basics.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -77,15 +77,16 @@ gtint_t matsize(char storage, char trans, gtint_t m, gtint_t n, gtint_t ldm ); /** * Returns the leading dimension of a matrix depending on the storage type, - * whether it is transpose or not, and the size of rows and columns. + * whether it is transpose or not, and the size of rows and columns, and the stride. * * @param storage specifies the storage format of matrix in memory. * @param trns specifies the form of given matrix. * @param m specifies the number of rows of given matrix. * @param n specifies the number of columns of given matrix. * @param inc specifies the increment of the leading dimension. + * @param stride specifies the stride between two "continuous" elements in the matrix. */ -gtint_t get_leading_dimension(char storage, char trans, gtint_t m, gtint_t n, gtint_t inc); +gtint_t get_leading_dimension( char storage, char trans, gtint_t m, gtint_t n, gtint_t inc, gtint_t stride = 1 ); /** * If T is real, returns NaN. @@ -94,6 +95,13 @@ gtint_t get_leading_dimension(char storage, char trans, gtint_t m, gtint_t n, gt template T getNaN(); +/** + * If T is real, returns NaN. + * If T is complex, returns {NaN, NaN} +*/ +template +T getNaNNaN(); + /** * If T is real, returns inf. * If T is complex, returns {inf, 0.0} @@ -101,6 +109,21 @@ T getNaN(); template T getInf(); +/** + * If T is real, returns inf. + * If T is complex, returns {inf, inf} +*/ +template +T getInfInf(); + +/** + * If T is real, returns extval. + * If T is complex, returns {extval, extval} + * where extval = NaN or Inf +*/ +template +T aocl_extreme(); + /** * @brief Returns the conjugate of a scalar x. * @@ -173,10 +196,28 @@ static void alphax( gtint_t n, T alpha, T *xp, gtint_t incx ) gtint_t ix = 0; for(i = 0 ; i < n ; i++) { xp[ix] = (alpha * xp[ix]); - ix = ix + incx; + // use absolute value of incx to ensure + // correctness when incx < 0 + ix = ix + std::abs(incx); } } +template +static T ONE() { + if constexpr (testinghelpers::type_info::is_real) + return 1.0; + else + return {1.0, 0.0}; +} + +template +static T ZERO() { + if constexpr (testinghelpers::type_info::is_real) + return 0.0; + else + return {0.0, 0.0}; +} + /** * @brief Returns the boolean form of a trans value. * @@ -342,42 +383,30 @@ void make_triangular( char storage, char uplo, gtint_t n, T* a, gtint_t ld ); template void make_diag( char storage, gtint_t m, gtint_t n, T alpha, T *a, gtint_t ld ); -/** - * print scalar value - * @param[in] x specifies the value. - * @param[in] spec specifies the format specifer. - */ -template -void print_scalar( T x, const char *spec ); - /** * print vector of length n - * @param[in] vec specifies the vector name * @param[in] n specifies the length of the given vector. * @param[in] a specifies pointer which points to the first element of a. * @param[in] incx specifies storage spacing between elements of a. - * @param[in] spec specifies the format specifer. */ template -void print_vector( const char *vec, gtint_t n, T *x, gtint_t incx, const char *spec ); +void print_vector( gtint_t n, T *x, gtint_t incx); /** * print matrix of size m x n - * @param[in] mat specifies the matrix name * @param[in] storage specifies the storage format of matrix in memory. * @param[in] m specifies the number of rows of given matrix. * @param[in] n specifies the number of columns of given matrix. * @param[in] a specifies pointer which points to the first element of a. * @param[in] ld specifies leading dimension for a given matrix. - * @param[in] spec specifies the format specifer. */ template -void print_matrix( const char *mat, char storage, gtint_t m, gtint_t n, T *a, gtint_t ld, const char *spec ); +void print_matrix( char storage, gtint_t m, gtint_t n, T *a, gtint_t ld); /** * @brief returns a string with the correct NaN/Inf for printing * - * @tparam T float, double, scomplex, dcomplex. + * @tparam T gtint_t, float, double, scomplex, dcomplex. * @param exval exception value for setting the string. */ template diff --git a/gtestsuite/testinghelpers/inc/common/testing_helpers.h b/gtestsuite/testinghelpers/inc/common/testing_helpers.h index 3720109148..408e91e252 100644 --- a/gtestsuite/testinghelpers/inc/common/testing_helpers.h +++ b/gtestsuite/testinghelpers/inc/common/testing_helpers.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -40,3 +40,4 @@ #include "data_generators.h" #include "error_helpers.h" #include "refCBLAS.h" +#include "protected_buffer.h" diff --git a/gtestsuite/testinghelpers/inc/common/type_info.h b/gtestsuite/testinghelpers/inc/common/type_info.h index 05cb0d1f76..2bf0eebec5 100644 --- a/gtestsuite/testinghelpers/inc/common/type_info.h +++ b/gtestsuite/testinghelpers/inc/common/type_info.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -47,8 +47,8 @@ namespace testinghelpers { // type_info::real_type will return the real type of T. - // If T is float or double, real_type is float or double respectivelly. - // If T is scomplex or dcomplex, real_type is float or double respectivelly. + // If T is float or double, real_type is float or double respectively. + // If T is scomplex or dcomplex, real_type is float or double respectively. template struct type_info { using real_type = T; diff --git a/gtestsuite/testinghelpers/inc/extension/ref_imatcopy.h b/gtestsuite/testinghelpers/inc/extension/ref_imatcopy.h new file mode 100644 index 0000000000..7699649638 --- /dev/null +++ b/gtestsuite/testinghelpers/inc/extension/ref_imatcopy.h @@ -0,0 +1,55 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#pragma once + +#include "common/testing_helpers.h" + +/* + * ========================================================================== + * OMATCOPY performs vector operations + * A := alpha * op(A) + * where A is both the input and output matrix, and alpha is the scaling factor. + * op(A) could be one of the following operations : no-transpose('n'), transpose('t'), + * conjugate('c'), conjugate-transpose('r'). + * ========================================================================== +**/ + +namespace testinghelpers { + +template +void ref_imatcopy( char storage, char trans, gtint_t m, gtint_t n, T alpha, T* A, + gtint_t lda_in, gtint_t lda_out ); + +} //end of namespace testinghelpers diff --git a/gtestsuite/testinghelpers/inc/extension/ref_omatcopy.h b/gtestsuite/testinghelpers/inc/extension/ref_omatcopy.h new file mode 100644 index 0000000000..132a6331c5 --- /dev/null +++ b/gtestsuite/testinghelpers/inc/extension/ref_omatcopy.h @@ -0,0 +1,55 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#pragma once + +#include "common/testing_helpers.h" + +/* + * ========================================================================== + * OMATCOPY performs vector operations + * B := alpha * op(A) + * where A and B are input and output matrices, and alpha is the scaling factor. + * op(A) could be one of the following operations : no-transpose('n'), transpose('t'), + * conjugate('c'), conjugate-transpose('r'). + * ========================================================================== +**/ + +namespace testinghelpers { + +template +void ref_omatcopy( char storage, char trans, gtint_t m, gtint_t n, T alpha, T* A, + gtint_t lda, T* B, gtint_t ldb ); + +} //end of namespace testinghelpers diff --git a/gtestsuite/testinghelpers/inc/extension/ref_omatcopy2.h b/gtestsuite/testinghelpers/inc/extension/ref_omatcopy2.h new file mode 100644 index 0000000000..9860fba3c6 --- /dev/null +++ b/gtestsuite/testinghelpers/inc/extension/ref_omatcopy2.h @@ -0,0 +1,55 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#pragma once + +#include "common/testing_helpers.h" + +/* + * ========================================================================== + * omatcopy2 performs vector operations + * B := alpha * op(A) + * where A and B are input and output matrices, and alpha is the scaling factor. + * op(A) could be one of the following operations : no-transpose('n'), transpose('t'), + * conjugate('c'), conjugate-transpose('r'). + * ========================================================================== +**/ + +namespace testinghelpers { + +template +void ref_omatcopy2( char storage, char trans, gtint_t m, gtint_t n, T alpha, T* A, + gtint_t lda, gtint_t stridea, T* B, gtint_t ldb, gtint_t strideb ); + +} //end of namespace testinghelpers diff --git a/gtestsuite/testinghelpers/inc/level1/ref_addv.h b/gtestsuite/testinghelpers/inc/level1/ref_addv.h index c693369b90..756502a442 100644 --- a/gtestsuite/testinghelpers/inc/level1/ref_addv.h +++ b/gtestsuite/testinghelpers/inc/level1/ref_addv.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/inc/level1/ref_amaxv.h b/gtestsuite/testinghelpers/inc/level1/ref_amaxv.h index a4d2e7fe40..2b0cdf1f0a 100644 --- a/gtestsuite/testinghelpers/inc/level1/ref_amaxv.h +++ b/gtestsuite/testinghelpers/inc/level1/ref_amaxv.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/inc/level1/ref_axpbyv.h b/gtestsuite/testinghelpers/inc/level1/ref_axpbyv.h index 893583638d..a6e0972c39 100644 --- a/gtestsuite/testinghelpers/inc/level1/ref_axpbyv.h +++ b/gtestsuite/testinghelpers/inc/level1/ref_axpbyv.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/inc/level1/ref_axpyf.h b/gtestsuite/testinghelpers/inc/level1/ref_axpyf.h new file mode 100644 index 0000000000..390c589164 --- /dev/null +++ b/gtestsuite/testinghelpers/inc/level1/ref_axpyf.h @@ -0,0 +1,64 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#pragma once + +#include "common/testing_helpers.h" + +/* + * ========================================================================== + * AXPYV performs vector operations + * y := y + alpha * conjx(x) + * where x and y are vectors of length n, and alpha is a scalar + * ========================================================================== +**/ + +namespace testinghelpers { + +template +void ref_axpyf( char conja, + char conjx, + gtint_t m, + gtint_t b_n, + T *alpha, + T* a, + gtint_t inca, + gtint_t lda, + T* x, + gtint_t incx, + T* y, + gtint_t incy + ); + +} //end of namespace testinghelpers diff --git a/gtestsuite/testinghelpers/inc/level1/ref_axpyv.h b/gtestsuite/testinghelpers/inc/level1/ref_axpyv.h index d0cbbbbf5f..9f380132e5 100644 --- a/gtestsuite/testinghelpers/inc/level1/ref_axpyv.h +++ b/gtestsuite/testinghelpers/inc/level1/ref_axpyv.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/inc/level1/ref_copyv.h b/gtestsuite/testinghelpers/inc/level1/ref_copyv.h index 5342ea3526..a1a75fadc7 100644 --- a/gtestsuite/testinghelpers/inc/level1/ref_copyv.h +++ b/gtestsuite/testinghelpers/inc/level1/ref_copyv.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/inc/level1/ref_dotv.h b/gtestsuite/testinghelpers/inc/level1/ref_dotv.h index 2b1f0b4a4d..a26fffd409 100644 --- a/gtestsuite/testinghelpers/inc/level1/ref_dotv.h +++ b/gtestsuite/testinghelpers/inc/level1/ref_dotv.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/inc/level1/ref_dotxf.h b/gtestsuite/testinghelpers/inc/level1/ref_dotxf.h new file mode 100644 index 0000000000..cd89f15c7d --- /dev/null +++ b/gtestsuite/testinghelpers/inc/level1/ref_dotxf.h @@ -0,0 +1,57 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#pragma once + +#include "common/testing_helpers.h" + +namespace testinghelpers { + +template +void ref_dotxf( char conja, + char conjx, + gtint_t m, + gtint_t b_n, + T *alpha, + T* a, + gtint_t inca, + gtint_t lda, + T* x, + gtint_t incx, + T *beta, + T* y, + gtint_t incy + ); + +} //end of namespace testinghelpers diff --git a/gtestsuite/testinghelpers/inc/level1/ref_dotxv.h b/gtestsuite/testinghelpers/inc/level1/ref_dotxv.h index 8b662a05db..38d96e4595 100644 --- a/gtestsuite/testinghelpers/inc/level1/ref_dotxv.h +++ b/gtestsuite/testinghelpers/inc/level1/ref_dotxv.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/inc/level1/ref_scal2v.h b/gtestsuite/testinghelpers/inc/level1/ref_scal2v.h index 88a933d6f4..d116e90f3a 100644 --- a/gtestsuite/testinghelpers/inc/level1/ref_scal2v.h +++ b/gtestsuite/testinghelpers/inc/level1/ref_scal2v.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/inc/level1/ref_scalv.h b/gtestsuite/testinghelpers/inc/level1/ref_scalv.h index 6e52878835..bfeebc6fd8 100644 --- a/gtestsuite/testinghelpers/inc/level1/ref_scalv.h +++ b/gtestsuite/testinghelpers/inc/level1/ref_scalv.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -47,7 +47,7 @@ namespace testinghelpers { -template -void ref_scalv(char conjalpha, gtint_t len, T alpha, T* x, gtint_t incx); +template +void ref_scalv(char conjalpha, gtint_t len, U alpha, T* x, gtint_t incx); } //end of namespace testinghelpers diff --git a/gtestsuite/testinghelpers/inc/level1/ref_subv.h b/gtestsuite/testinghelpers/inc/level1/ref_subv.h index dd49b2571a..8755fade8c 100644 --- a/gtestsuite/testinghelpers/inc/level1/ref_subv.h +++ b/gtestsuite/testinghelpers/inc/level1/ref_subv.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/frame/1m/packm/bli_packm_cxk_3mis.h b/gtestsuite/testinghelpers/inc/level1/ref_swapv.h similarity index 76% rename from frame/1m/packm/bli_packm_cxk_3mis.h rename to gtestsuite/testinghelpers/inc/level1/ref_swapv.h index 358cdcee4e..09ff315655 100644 --- a/frame/1m/packm/bli_packm_cxk_3mis.h +++ b/gtestsuite/testinghelpers/inc/level1/ref_swapv.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -32,22 +32,22 @@ */ +#pragma once -#undef GENTPROTCO -#define GENTPROTCO( ctype, ctype_r, ch, chr, varname ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - conj_t conja, \ - dim_t panel_dim, \ - dim_t panel_dim_max, \ - dim_t panel_len, \ - dim_t panel_len_max, \ - ctype* kappa, \ - ctype* a, inc_t inca, inc_t lda, \ - ctype* p, inc_t is_p, inc_t ldp, \ - cntx_t* cntx \ - ); - -INSERT_GENTPROTCO_BASIC0( packm_cxk_3mis ) +#include "common/testing_helpers.h" +/* + * ===================================== + * SWAPV performs a vector operation + * Swaps contents in x to y and y to x + * x <=> y + * where x & y is a vector of length n + * ===================================== +**/ + +namespace testinghelpers { + +template +void ref_swapv(gtint_t len, T* x, gtint_t incx, T* y, gtint_t incy); + +} //end of namespace testinghelpers diff --git a/gtestsuite/testinghelpers/inc/level1/ref_xpbyv.h b/gtestsuite/testinghelpers/inc/level1/ref_xpbyv.h index 92afc208ee..dbd6da346f 100644 --- a/gtestsuite/testinghelpers/inc/level1/ref_xpbyv.h +++ b/gtestsuite/testinghelpers/inc/level1/ref_xpbyv.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/inc/level2/ref_gemv.h b/gtestsuite/testinghelpers/inc/level2/ref_gemv.h index 6f9a7c88de..a220333caf 100644 --- a/gtestsuite/testinghelpers/inc/level2/ref_gemv.h +++ b/gtestsuite/testinghelpers/inc/level2/ref_gemv.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/inc/level2/ref_ger.h b/gtestsuite/testinghelpers/inc/level2/ref_ger.h index d104c17659..b174b10f5d 100644 --- a/gtestsuite/testinghelpers/inc/level2/ref_ger.h +++ b/gtestsuite/testinghelpers/inc/level2/ref_ger.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/inc/level2/ref_hemv.h b/gtestsuite/testinghelpers/inc/level2/ref_hemv.h index 52100da1f6..f4e1c04dcc 100644 --- a/gtestsuite/testinghelpers/inc/level2/ref_hemv.h +++ b/gtestsuite/testinghelpers/inc/level2/ref_hemv.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/inc/level2/ref_her.h b/gtestsuite/testinghelpers/inc/level2/ref_her.h index 0c403f5e12..98a6f89cb3 100644 --- a/gtestsuite/testinghelpers/inc/level2/ref_her.h +++ b/gtestsuite/testinghelpers/inc/level2/ref_her.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/inc/level2/ref_her2.h b/gtestsuite/testinghelpers/inc/level2/ref_her2.h index ee56f84abb..48aa29ffbd 100644 --- a/gtestsuite/testinghelpers/inc/level2/ref_her2.h +++ b/gtestsuite/testinghelpers/inc/level2/ref_her2.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/inc/level2/ref_symv.h b/gtestsuite/testinghelpers/inc/level2/ref_symv.h index 7d324e99cb..5fbbff62cb 100644 --- a/gtestsuite/testinghelpers/inc/level2/ref_symv.h +++ b/gtestsuite/testinghelpers/inc/level2/ref_symv.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/inc/level2/ref_syr.h b/gtestsuite/testinghelpers/inc/level2/ref_syr.h index 3727ec1aa9..c5ed1f9cd4 100644 --- a/gtestsuite/testinghelpers/inc/level2/ref_syr.h +++ b/gtestsuite/testinghelpers/inc/level2/ref_syr.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/inc/level2/ref_syr2.h b/gtestsuite/testinghelpers/inc/level2/ref_syr2.h index 232171de28..58cac26690 100644 --- a/gtestsuite/testinghelpers/inc/level2/ref_syr2.h +++ b/gtestsuite/testinghelpers/inc/level2/ref_syr2.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/inc/level2/ref_trmv.h b/gtestsuite/testinghelpers/inc/level2/ref_trmv.h index b7d8f1020f..71a22f0fa3 100644 --- a/gtestsuite/testinghelpers/inc/level2/ref_trmv.h +++ b/gtestsuite/testinghelpers/inc/level2/ref_trmv.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/inc/level2/ref_trsv.h b/gtestsuite/testinghelpers/inc/level2/ref_trsv.h index 268b7f381e..f3fa2e8445 100644 --- a/gtestsuite/testinghelpers/inc/level2/ref_trsv.h +++ b/gtestsuite/testinghelpers/inc/level2/ref_trsv.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/inc/level3/ref_gemm.h b/gtestsuite/testinghelpers/inc/level3/ref_gemm.h index 569726cdf9..af5c1451f2 100644 --- a/gtestsuite/testinghelpers/inc/level3/ref_gemm.h +++ b/gtestsuite/testinghelpers/inc/level3/ref_gemm.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/inc/level3/ref_gemm_compute.h b/gtestsuite/testinghelpers/inc/level3/ref_gemm_compute.h index 283a2b06ec..a6b20ad1b8 100644 --- a/gtestsuite/testinghelpers/inc/level3/ref_gemm_compute.h +++ b/gtestsuite/testinghelpers/inc/level3/ref_gemm_compute.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/inc/level3/ref_gemmt.h b/gtestsuite/testinghelpers/inc/level3/ref_gemmt.h index 6c2f58ca3f..14f795da68 100644 --- a/gtestsuite/testinghelpers/inc/level3/ref_gemmt.h +++ b/gtestsuite/testinghelpers/inc/level3/ref_gemmt.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/inc/level3/ref_hemm.h b/gtestsuite/testinghelpers/inc/level3/ref_hemm.h index 40d4178239..fa736a92ad 100644 --- a/gtestsuite/testinghelpers/inc/level3/ref_hemm.h +++ b/gtestsuite/testinghelpers/inc/level3/ref_hemm.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/inc/level3/ref_her2k.h b/gtestsuite/testinghelpers/inc/level3/ref_her2k.h index 3827625036..e2035febf5 100644 --- a/gtestsuite/testinghelpers/inc/level3/ref_her2k.h +++ b/gtestsuite/testinghelpers/inc/level3/ref_her2k.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/inc/level3/ref_herk.h b/gtestsuite/testinghelpers/inc/level3/ref_herk.h index ca29a1217d..b801ba78bd 100644 --- a/gtestsuite/testinghelpers/inc/level3/ref_herk.h +++ b/gtestsuite/testinghelpers/inc/level3/ref_herk.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/inc/level3/ref_symm.h b/gtestsuite/testinghelpers/inc/level3/ref_symm.h index fef81db386..48d29780f3 100644 --- a/gtestsuite/testinghelpers/inc/level3/ref_symm.h +++ b/gtestsuite/testinghelpers/inc/level3/ref_symm.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/inc/level3/ref_syr2k.h b/gtestsuite/testinghelpers/inc/level3/ref_syr2k.h index 4b170d70a8..4acf4a3bb3 100644 --- a/gtestsuite/testinghelpers/inc/level3/ref_syr2k.h +++ b/gtestsuite/testinghelpers/inc/level3/ref_syr2k.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/inc/level3/ref_syrk.h b/gtestsuite/testinghelpers/inc/level3/ref_syrk.h index 3d3b8765ae..89e17bfd11 100644 --- a/gtestsuite/testinghelpers/inc/level3/ref_syrk.h +++ b/gtestsuite/testinghelpers/inc/level3/ref_syrk.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/inc/level3/ref_trmm.h b/gtestsuite/testinghelpers/inc/level3/ref_trmm.h index f75b2356bc..fb92e7d389 100644 --- a/gtestsuite/testinghelpers/inc/level3/ref_trmm.h +++ b/gtestsuite/testinghelpers/inc/level3/ref_trmm.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/inc/level3/ref_trmm3.h b/gtestsuite/testinghelpers/inc/level3/ref_trmm3.h index 975238050a..6195d0f53a 100644 --- a/gtestsuite/testinghelpers/inc/level3/ref_trmm3.h +++ b/gtestsuite/testinghelpers/inc/level3/ref_trmm3.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/inc/level3/ref_trsm.h b/gtestsuite/testinghelpers/inc/level3/ref_trsm.h index df57786f69..47dfe0f934 100644 --- a/gtestsuite/testinghelpers/inc/level3/ref_trsm.h +++ b/gtestsuite/testinghelpers/inc/level3/ref_trsm.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/inc/util/ref_asumv.h b/gtestsuite/testinghelpers/inc/util/ref_asumv.h new file mode 100644 index 0000000000..3c6ad26d3e --- /dev/null +++ b/gtestsuite/testinghelpers/inc/util/ref_asumv.h @@ -0,0 +1,57 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#pragma once + +#include "common/testing_helpers.h" + +/* + * ========================================================================== + * ASUMV computes the sum of the absolute values of the fundamental elements + * of vector x. + * asum = |R(x1)| + |I(x1)| + |R(x2)| + |I(x2)| + ... + |R(xn)| + |I(xn)| + * where, + * x is a vector of size n, + * R(a) is the real component of the complex number a, + * I(a) is the imaginary component of the complex number a, + * |b| represents the absolute value of b. + * ========================================================================== +**/ + +namespace testinghelpers { + +template::real_type> +RT ref_asumv(gtint_t n, T* x, gtint_t incx); + +} //end of namespace testinghelpers diff --git a/gtestsuite/testinghelpers/inc/util/ref_nrm2.h b/gtestsuite/testinghelpers/inc/util/ref_nrm2.h index 3163d46556..8e6a1bdadb 100644 --- a/gtestsuite/testinghelpers/inc/util/ref_nrm2.h +++ b/gtestsuite/testinghelpers/inc/util/ref_nrm2.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/src/common/complex_helpers.cpp b/gtestsuite/testinghelpers/src/common/complex_helpers.cpp index 3f8b9a27fe..c5994f6b10 100644 --- a/gtestsuite/testinghelpers/src/common/complex_helpers.cpp +++ b/gtestsuite/testinghelpers/src/common/complex_helpers.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -87,6 +87,17 @@ dcomplex operator*(const dcomplex x, const dcomplex y) return dcomplex{(( x.real * y.real ) - ( x.imag * y.imag )),(( x.real * y.imag ) + ( x.imag * y.real ))}; } +scomplex operator/(const scomplex x, const scomplex y) +{ + return scomplex{(( x.real * y.real ) + ( x.imag * y.imag )) / (( y.real * y.real ) + ( y.imag * y.imag )), + (( x.imag * y.real ) - ( x.real * y.imag )) / (( y.real * y.real ) + ( y.imag * y.imag ))}; +} +dcomplex operator/(const dcomplex x, const dcomplex y) +{ + return dcomplex{(( x.real * y.real ) + ( x.imag * y.imag )) / (( y.real * y.real ) + ( y.imag * y.imag )), + (( x.imag * y.real ) - ( x.real * y.imag )) / (( y.real * y.real ) + ( y.imag * y.imag ))}; +} + bool operator== (const scomplex x, const scomplex y) { return ((x.real==y.real) && (x.imag==y.imag)); diff --git a/gtestsuite/testinghelpers/src/common/data_generators.cpp b/gtestsuite/testinghelpers/src/common/data_generators.cpp deleted file mode 100644 index 9edf5b5cc8..0000000000 --- a/gtestsuite/testinghelpers/src/common/data_generators.cpp +++ /dev/null @@ -1,530 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include "common/testing_helpers.h" - -namespace testinghelpers { -namespace datagenerators { - -/*************************************************** - * Floating Point Generators -****************************************************/ -/** - * @brief Returns a random fp type (float, double, scomplex, dcomplex) - * that lies in the range [from, to]. - * - * @param[in, out] alpha the random fp - */ -template -void getfp(int from, int to, T* alpha) -{ - using real_T = typename testinghelpers::type_info::real_type; - std::mt19937 generator(94); - std::uniform_real_distribution distr(from, to); - if constexpr (testinghelpers::type_info::is_real) - *alpha = distr(generator); - else - *alpha = {distr(generator), distr(generator)}; -} - -/** - * @brief Returns a random fp vector (float, double, scomplex, dcomplex) - * with elements that follow a uniform distribution in the range [from, to]. - * @param[in] n length of vector x - * @param[in] incx increments of vector x - * @param[in, out] x the random fp vector - */ -template -void getfp(int from, int to, gtint_t n, gtint_t incx, T* x) -{ - using real_T = typename testinghelpers::type_info::real_type; - T* chi; - std::mt19937 generator(94); - std::uniform_real_distribution distr(from, to); - for ( gtint_t i = 0; i < n; ++i ) - { - chi = x + i*std::abs(incx); - if constexpr (testinghelpers::type_info::is_real) - *chi = distr(generator); - else - *chi = {distr(generator), distr(generator)}; - } -} - -template -void getfp(int from, int to, char storage, gtint_t m, gtint_t n, T* a, gtint_t lda ) -{ - T* a_begin; - gtint_t inca; - gtint_t n_iter; - gtint_t n_elem; - gtint_t j; - - // Initialize with optimal values for column-major storage. - inca = 1; - n_iter = n; - n_elem = m; - - // An optimization: if A is row-major, then let's access the matrix by - // rows instead of by columns for increased spatial locality. - if( (storage == 'r') || (storage == 'R') ) - { - swap_dims( &n_iter, &n_elem ); - swap_dims( &lda, &inca ); - } - - for ( j = 0; j < n_iter; j++ ) - { - a_begin = a + j*lda; - getfp( from, to, n_elem, inca, a_begin ); - } -} - -template -void getfp(int from, int to, char storage, gtint_t m, gtint_t n, T* a, char transa, gtint_t lda ) -{ - using real_T = typename testinghelpers::type_info::real_type; - std::mt19937 generator(1994); - std::uniform_real_distribution distr(from, to); - - if( chktrans( transa )) { - swap_dims( &m, &n ); - } - - if((storage == 'c') || (storage == 'C')) - { - for(gtint_t i=0; i::is_real) - a[i+j*lda] = real_T(distr(generator)); - else - a[i+j*lda] = {real_T(distr(generator)), real_T(distr(generator))}; - } - } - } - else if( (storage == 'r') || (storage == 'R') ) - { - for(gtint_t j=0; j::is_real) - a[j+i*lda] = real_T(distr(generator)); - else - a[j+i*lda] = {real_T(distr(generator)), real_T(distr(generator))}; - } - } - } -} - -/*************************************************** - * Integer Generators -****************************************************/ -/** - * @brief Returns a random integer converted to an fp type (float, double, scomplex, dcomplex) - * that lies in the range [from, to]. - * - * @param[in, out] alpha the random fp - */ -template -void getint(int from, int to, T* alpha) -{ - using real_T = typename testinghelpers::type_info::real_type; - std::mt19937 generator(94); - std::uniform_int_distribution distr(from, to); - if constexpr (testinghelpers::type_info::is_real) - *alpha = real_T(distr(generator)); - else - *alpha = {real_T(distr(generator)), real_T(distr(generator))}; -} -/** - * @brief Returns a random fp vector (float, double, scomplex, dcomplex) - * with elements that are integers and follow a uniform distribution in the range [from, to]. - * @param[in] n length of vector x - * @param[in] incx increments of vector x - * @param[in, out] x the random fp vector - */ -template -void getint(int from, int to, gtint_t n, gtint_t incx, T* x) -{ - using real_T = typename testinghelpers::type_info::real_type; - T* chi; - std::mt19937 generator(94); - std::uniform_int_distribution distr(from, to); - for ( gtint_t i = 0; i < n; ++i ) - { - chi = x + i*std::abs(incx); - if constexpr (testinghelpers::type_info::is_real) - *chi = real_T(distr(generator)); - else - *chi = {real_T(distr(generator)), real_T(distr(generator))}; - } -} - -template -void getint(int from, int to, char storage, gtint_t m, gtint_t n, T* a, gtint_t lda ) -{ - T* a_begin; - gtint_t inca; - gtint_t n_iter; - gtint_t n_elem; - gtint_t j; - - // Initialize with optimal values for column-major storage. - inca = 1; - n_iter = n; - n_elem = m; - - // An optimization: if A is row-major, then let's access the matrix by - // rows instead of by columns for increased spatial locality. - if( (storage == 'r') || (storage == 'R') ) - { - swap_dims( &n_iter, &n_elem ); - swap_dims( &lda, &inca ); - } - - for ( j = 0; j < n_iter; j++ ) - { - a_begin = a + j*lda; - getint( from, to, n_elem, inca, a_begin ); - } -} - -/// @brief -/// @tparam T -/// @param from -/// @param to -/// @param storage -/// @param m -/// @param n -/// @param a -/// @param transa -/// @param lda -template -void getint(int from, int to, char storage, gtint_t m, gtint_t n, T* a, char transa, gtint_t lda ) -{ - using real_T = typename testinghelpers::type_info::real_type; - std::mt19937 generator(1994); - std::uniform_int_distribution distr(from, to); - - if( chktrans( transa )) { - swap_dims( &m, &n ); - } - - if((storage == 'c') || (storage == 'C')) - { - for(gtint_t i=0; i::is_real) - a[i+j*lda] = real_T(distr(generator)); - else - a[i+j*lda] = {real_T(distr(generator)), real_T(distr(generator))}; - } - } - } - else if( (storage == 'r') || (storage == 'R') ) - { - for(gtint_t j=0; j::is_real) - a[j+i*lda] = real_T(distr(generator)); - else - a[j+i*lda] = {real_T(distr(generator)), real_T(distr(generator))}; - } - } - } -} - -template -void randomgenerators( int from, int to, T* alpha, char datatype ) { - - if( (datatype == 'i') ||(datatype == 'I') ) - getint( from, to, alpha ); - else /*if( (datatype == 'f') ||(datatype == 'F') ) */ - getfp( from, to, alpha ); -} - -template -void randomgenerators(int from, int to, gtint_t n, gtint_t incx, T* x, char datatype ) { - - if( (datatype == 'i') ||(datatype == 'I') ) - getint( from, to, n, incx, x ); - else /*if( (datatype == 'f') ||(datatype == 'F') ) */ - getfp( from, to, n, incx, x ); -} - -template -void randomgenerators( int from, int to, char storage, gtint_t m, gtint_t n, - T* a, gtint_t lda, char datatype ) { - - if( (datatype == 'i') ||(datatype == 'I') ) - getint( from, to, storage, m, n, a, lda ); - else /*if( (datatype == 'f') ||(datatype == 'F') ) */ - getfp( from, to, storage, m, n, a, lda ); -} - -template -void randomgenerators( int from, int to, char storage, gtint_t m, gtint_t n, - T* a, char transa, gtint_t lda, char datatype ) { - - if( (datatype == 'i') ||(datatype == 'I') ) - getint( from, to, storage, m, n, a, transa, lda ); - else /*if( (datatype == 'f') ||(datatype == 'F') ) */ - getfp( from, to, storage, m, n, a, transa, lda ); -} - -template -void randomgenerators(int from, int to, char storage, char uplo, gtint_t k, - T* a, gtint_t lda, char datatype ) { - randomgenerators(from, to, storage, k, k, a, lda, datatype); - if( (storage=='c')||(storage=='C') ) - { - for(gtint_t j=0; jj) a[i+j*lda] = T{0}; - } - else if ( (uplo=='l')||(uplo=='L') ) - { - if (ij) a[j+i*lda] = T{0}; - } - else if ( (uplo=='l')||(uplo=='L') ) - { - if (i -std::vector get_random_matrix(int from, int to, char storage, char trans, gtint_t m, gtint_t n, - gtint_t lda, char datatype ) -{ - std::vector a(matsize(storage, trans, m, n, lda)); - testinghelpers::datagenerators::randomgenerators( from, to, storage, m, n, a.data(), trans, lda, datatype ); - return a; -} - -template -std::vector get_random_matrix(int from, int to, char storage, char uplo, gtint_t k, gtint_t lda, char datatype ) -{ - // Create matrix for the given sizes. - std::vector a( testinghelpers::matsize( storage, 'n', k, k, lda ) ); - testinghelpers::datagenerators::randomgenerators( from, to, storage, uplo, k, a.data(), lda, datatype ); - return a; -} - -template -std::vector get_random_vector(int from, int to, gtint_t n, gtint_t incx, char datatype ) -{ - // Create vector for the given sizes. - std::vector x( testinghelpers::buff_dim(n, incx) ); - testinghelpers::datagenerators::randomgenerators( from, to, n, incx, x.data(), datatype ); - return x; -} - -template -void set_vector( gtint_t n, gtint_t incx, T* x, T value ) -{ - T* chi; - for ( gtint_t i = 0; i < n; ++i ) - { - chi = x + i*std::abs(incx); - *chi = value ; - } -} - -template -void set_matrix( char storage, gtint_t m, gtint_t n, T* a, char transa, gtint_t lda, T value ) -{ - if( chktrans( transa )) { - swap_dims( &m, &n ); - } - - if((storage == 'c') || (storage == 'C')) - { - for( gtint_t i = 0 ; i < m ; i++ ) - { - for( gtint_t j = 0 ; j < n ; j++ ) - { - a[i+j*lda] = value ; - } - } - } - else if( (storage == 'r') || (storage == 'R') ) - { - for( gtint_t j = 0 ; j < n ; j++ ) - { - for( gtint_t i = 0 ; i < m ; i++ ) - { - a[j+i*lda] = value ; - } - } - } -} - -template -std::vector get_vector( gtint_t n, gtint_t incx, T value ) -{ - // Create vector for the given sizes. - std::vector x( testinghelpers::buff_dim(n, incx) ); - testinghelpers::set_vector( n, incx, x.data(), value ); - return x; -} - -template -std::vector get_matrix( char storage, char trans, gtint_t m, gtint_t n, gtint_t lda, T value ) -{ - std::vector a( matsize( storage, trans, m, n, lda ) ); - testinghelpers::set_matrix( storage, m, n, a.data(), trans, lda, value ); - return a; -} - -template -void set_ev_mat( char storage, char trns, gtint_t ld, gtint_t i, gtint_t j, T exval, T* m ) -{ - // Setting the exception values on the indices passed as arguments - if ( storage == 'c' || storage == 'C' ) - { - if ( trns == 'n' || trns == 'N' ) - m[i + j*ld] = exval; - else - m[j + i*ld] = exval; - } - else - { - if ( trns == 'n' || trns == 'N' ) - m[i*ld + j] = exval; - else - m[j*ld + i] = exval; - } -} - -} //end of namespace testinghelpers - -// Explicit template instantiations -template void testinghelpers::datagenerators::randomgenerators(int, int, float*, char); -template void testinghelpers::datagenerators::randomgenerators(int, int, double*, char); -template void testinghelpers::datagenerators::randomgenerators(int, int, scomplex*, char); -template void testinghelpers::datagenerators::randomgenerators(int, int, dcomplex*, char); - -template void testinghelpers::datagenerators::randomgenerators(int, int, gtint_t, gtint_t, float*, char); -template void testinghelpers::datagenerators::randomgenerators(int, int, gtint_t, gtint_t, double*, char); -template void testinghelpers::datagenerators::randomgenerators(int, int, gtint_t, gtint_t, scomplex*, char); -template void testinghelpers::datagenerators::randomgenerators(int, int, gtint_t, gtint_t, dcomplex*, char); - -template void testinghelpers::datagenerators::randomgenerators(int, int, char, gtint_t, gtint_t, float*, gtint_t, char); -template void testinghelpers::datagenerators::randomgenerators(int, int, char, gtint_t, gtint_t, double*, gtint_t, char); -template void testinghelpers::datagenerators::randomgenerators(int, int, char, gtint_t, gtint_t, scomplex*, gtint_t, char); -template void testinghelpers::datagenerators::randomgenerators(int, int, char, gtint_t, gtint_t, dcomplex*, gtint_t, char); - -template void testinghelpers::datagenerators::randomgenerators(int, int, char, gtint_t, gtint_t, float*, char, gtint_t, char); -template void testinghelpers::datagenerators::randomgenerators(int, int, char, gtint_t, gtint_t, double*, char, gtint_t, char); -template void testinghelpers::datagenerators::randomgenerators(int, int, char, gtint_t, gtint_t, scomplex*, char, gtint_t, char); -template void testinghelpers::datagenerators::randomgenerators(int, int, char, gtint_t, gtint_t, dcomplex*, char, gtint_t, char); - -template void testinghelpers::datagenerators::randomgenerators(int, int, char, char, gtint_t, float*, gtint_t, char); -template void testinghelpers::datagenerators::randomgenerators(int, int, char, char, gtint_t, double*, gtint_t, char); -template void testinghelpers::datagenerators::randomgenerators(int, int, char, char, gtint_t, scomplex*, gtint_t, char); -template void testinghelpers::datagenerators::randomgenerators(int, int, char, char, gtint_t, dcomplex*, gtint_t, char); - -template std::vector testinghelpers::get_random_matrix(int, int, char, char, gtint_t, gtint_t, gtint_t, char); -template std::vector testinghelpers::get_random_matrix(int, int, char, char, gtint_t, gtint_t, gtint_t, char); -template std::vector testinghelpers::get_random_matrix(int, int, char, char, gtint_t, gtint_t, gtint_t, char); -template std::vector testinghelpers::get_random_matrix(int, int, char, char, gtint_t, gtint_t, gtint_t, char); - -template std::vector testinghelpers::get_random_matrix(int, int, char, char, gtint_t, gtint_t, char); -template std::vector testinghelpers::get_random_matrix(int, int, char, char, gtint_t, gtint_t, char); -template std::vector testinghelpers::get_random_matrix(int, int, char, char, gtint_t, gtint_t, char); -template std::vector testinghelpers::get_random_matrix(int, int, char, char, gtint_t, gtint_t, char); - -template std::vector testinghelpers::get_random_vector(int, int, gtint_t, gtint_t, char); -template std::vector testinghelpers::get_random_vector(int, int, gtint_t, gtint_t, char); -template std::vector testinghelpers::get_random_vector(int, int, gtint_t, gtint_t, char); -template std::vector testinghelpers::get_random_vector(int, int, gtint_t, gtint_t, char); - -template std::vector testinghelpers::get_vector(gtint_t, gtint_t, float); -template std::vector testinghelpers::get_vector(gtint_t, gtint_t, double); -template std::vector testinghelpers::get_vector(gtint_t, gtint_t, scomplex); -template std::vector testinghelpers::get_vector(gtint_t, gtint_t, dcomplex); - -template std::vector testinghelpers::get_matrix( char, char, gtint_t, gtint_t, gtint_t, float ); -template std::vector testinghelpers::get_matrix( char, char, gtint_t, gtint_t, gtint_t, double ); -template std::vector testinghelpers::get_matrix( char, char, gtint_t, gtint_t, gtint_t, scomplex ); -template std::vector testinghelpers::get_matrix( char, char, gtint_t, gtint_t, gtint_t, dcomplex ); - -template void testinghelpers::set_vector( gtint_t, gtint_t, float*, float ); -template void testinghelpers::set_vector( gtint_t, gtint_t, double*, double ); -template void testinghelpers::set_vector( gtint_t, gtint_t, scomplex*, scomplex ); -template void testinghelpers::set_vector( gtint_t, gtint_t, dcomplex*, dcomplex ); - -template void testinghelpers::set_matrix( char, gtint_t, gtint_t, float*, char, gtint_t, float ); -template void testinghelpers::set_matrix( char, gtint_t, gtint_t, double*, char, gtint_t, double ); -template void testinghelpers::set_matrix( char, gtint_t, gtint_t, scomplex*, char, gtint_t, scomplex ); -template void testinghelpers::set_matrix( char, gtint_t, gtint_t, dcomplex*, char, gtint_t, dcomplex ); - -template void testinghelpers::set_ev_mat( char, char, gtint_t, gtint_t, gtint_t, float, float* ); -template void testinghelpers::set_ev_mat( char, char, gtint_t, gtint_t, gtint_t, double, double* ); -template void testinghelpers::set_ev_mat( char, char, gtint_t, gtint_t, gtint_t, scomplex, scomplex* ); -template void testinghelpers::set_ev_mat( char, char, gtint_t, gtint_t, gtint_t, dcomplex, dcomplex* ); diff --git a/gtestsuite/testinghelpers/src/common/protected_buffer.cpp b/gtestsuite/testinghelpers/src/common/protected_buffer.cpp new file mode 100644 index 0000000000..093d7fb938 --- /dev/null +++ b/gtestsuite/testinghelpers/src/common/protected_buffer.cpp @@ -0,0 +1,190 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#if defined(__linux__) +#include +#include +#include +#include +#endif + +#include +#include "blis.h" +#include "common/protected_buffer.h" + +/* +* Returns aligned or unaligned memory of required size +*/ +void* testinghelpers::ProtectedBuffer::get_mem(dim_t size, bool is_aligned) +{ + void* mem = nullptr; +#if defined(__linux__) + mem = is_aligned ? aligned_alloc(BLIS_HEAP_STRIDE_ALIGN_SIZE, size) : malloc(size); +#else + mem = is_aligned ? _aligned_malloc(BLIS_HEAP_STRIDE_ALIGN_SIZE, size) : malloc(size); +#endif + if (mem == NULL) + { + printf("Protected Buffer: Memory not allocated.\n"); + exit(EXIT_FAILURE); + } + return mem; +} + +/** + * @brief Allocate memory for greenzones and redzones, and add protection to redzones + * + * @param size size of buffer required + * @param is_aligned should allocated memory be aligned + * @param is_mem_test is memory allocated for memory test. + */ +testinghelpers::ProtectedBuffer::ProtectedBuffer(dim_t size, bool is_aligned, bool is_mem_test) +{ +#if defined(__linux__) + this->is_mem_test = is_mem_test; + if (is_mem_test) + { + // query page size + size_t page_size = sysconf(_SC_PAGESIZE); + + // calculate minimum number of pages needed for requested size + // we make buffer at least twice the requested size to make sure + // that greenzone_1 and greenzone_2 do not overlap + size_t buffer_size = ((( size * 2 ) / page_size) + 1) * page_size; + + // allocate memory (buffer_size + 1 page to ensure 1st redzone can be started at page bounday + // + 2 * REDZONE_SIZE pages for 1 redzone on each end of buffer) + mem = (char*)get_mem(buffer_size + ((1 + (REDZONE_SIZE * 2)) * page_size), is_aligned); + + // set redzone_1 to mem+page_size to make sure that + // atleast one page boundary exist between mem and redzone_1 + redzone_1 = (void*)((char*)mem + page_size); + + // find page boundary ( address which is multiple of pagesize and less than redzone_1 ) + // say page_size is Nth power of 2 therefore only (N+1)th LSB is set in page_size + // (-page_size) implies 2's complement therefore in (-page_size) N LSBs are unset, all + // other bits are set. + // (redzone_1 & -page_size) will unset N LSBs of redzone_1, therefore making redzone_1 a + // multiple of page_size. + // this line is equivalent to (redzone_1 - (redzone_1 % page_size)) + // where page_size is power of two. + redzone_1 = (void*)((uintptr_t)(redzone_1) & -page_size); + + // redzone_2 = redzone_1 + sizeof redzone_1 + sizeof buffer + redzone_2 = (void*)((char*)redzone_1 + (page_size * REDZONE_SIZE) + buffer_size); + + // make redzones read/write/execute protected + int res = mprotect(redzone_1, page_size * REDZONE_SIZE, PROT_NONE); + if (res == -1) + { + do { perror("mprotect"); exit(EXIT_FAILURE); } while (0); + } + res = mprotect(redzone_2, page_size * REDZONE_SIZE, PROT_NONE); + if (res == -1) + { + do { perror("mprotect"); exit(EXIT_FAILURE); } while (0); + } + + // get address to the first "size" bytes of buffer + greenzone_1 = (void*)((char*)redzone_1 + (page_size * REDZONE_SIZE)); + + // get address to the last "size" bytes of buffer + greenzone_2 = (void*)((char*)redzone_2 - size); + } + else +#endif + { + mem = get_mem(size, is_aligned); + greenzone_1 = mem, greenzone_2 = mem; + } + +} + +/** + * @brief Remove Protection from redzones and free allocated memory + */ +testinghelpers::ProtectedBuffer::~ProtectedBuffer() +{ +#if defined(__linux__) + if(is_mem_test) + { + size_t page_size = sysconf(_SC_PAGESIZE); + + int res = mprotect(redzone_1, page_size * REDZONE_SIZE, PROT_READ | PROT_WRITE ); + if (res == -1) + { + do { perror("mprotect"); exit(EXIT_FAILURE); } while (0); + } + res = mprotect(redzone_2, page_size * REDZONE_SIZE, PROT_READ | PROT_WRITE ); + if (res == -1) + { + do { perror("mprotect"); exit(EXIT_FAILURE); } while (0); + } + } +#endif + free(mem); +} + +/** + * Function to handle segfault during memory test and convert it to a exception + */ +void testinghelpers::ProtectedBuffer::handle_mem_test_fail(int signal) +{ +#if defined(__linux__) + // unmask the segmentation fault signal + sigset_t signal_set; + sigemptyset(&signal_set); + sigaddset(&signal_set, SIGSEGV); + sigprocmask(SIG_UNBLOCK, &signal_set, NULL); + + throw std::out_of_range("err invalid"); +#endif +} + +void testinghelpers::ProtectedBuffer::start_signal_handler() +{ +#if defined(__linux__) + // add signal handler for segmentation fault + signal(SIGSEGV, ProtectedBuffer::handle_mem_test_fail); +#endif +} + + +void testinghelpers::ProtectedBuffer::stop_signal_handler() +{ +#if defined(__linux__) + // reset to default signal handler + signal(SIGSEGV, SIG_DFL); +#endif +} diff --git a/gtestsuite/testinghelpers/src/common/refCBLAS.cpp b/gtestsuite/testinghelpers/src/common/refCBLAS.cpp index 12499648e1..0aaf0cdd98 100644 --- a/gtestsuite/testinghelpers/src/common/refCBLAS.cpp +++ b/gtestsuite/testinghelpers/src/common/refCBLAS.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/src/common/testing_basics.cpp b/gtestsuite/testinghelpers/src/common/testing_basics.cpp index 5deec8e5a4..51efb31fdd 100644 --- a/gtestsuite/testinghelpers/src/common/testing_basics.cpp +++ b/gtestsuite/testinghelpers/src/common/testing_basics.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -132,44 +132,52 @@ gtint_t buff_dim( gtint_t n, gtint_t incx ) { gtint_t matsize( char storage, char trans, gtint_t m, gtint_t n, gtint_t ldm ) { - gtint_t km; + gtint_t km, lm; if( (storage == 'c') || (storage == 'C') ) { /*Column_Major*/ km = chktrans( trans ) ? m : n ; + lm = chktrans( trans ) ? n : m ; } else { /*Row_Major*/ km = chktrans( trans ) ? n : m ; + lm = chktrans( trans ) ? m : n ; } - return (km*ldm); + if ( ldm <= 0 || ldm < lm ) + return 0; + else + return (km*ldm); } /** * Returns the leading dimension of a matrix depending on the storage type, - * whether it is transpose or not, and the size of rows and columns. + * whether it is transpose or not, and the size of rows and columns, and the stride. * * @param storage specifies the storage format of matrix in memory. * @param trns specifies the form of given matrix. * @param m specifies the number of rows of given matrix. * @param n specifies the number of columns of given matrix. * @param inc specifies the increment of the leading dimension. + * @param stride specifies the stride between two "continuous" elements in the matrix. */ -gtint_t get_leading_dimension( char storage, char trans, gtint_t m, gtint_t n, gtint_t inc ) +gtint_t get_leading_dimension( char storage, char trans, gtint_t m, gtint_t n, gtint_t inc, gtint_t stride ) { gtint_t lda; + gtint_t m_max = (std::max)(gtint_t(1),m); + gtint_t n_max = (std::max)(gtint_t(1),n); if( (storage == 'c') || (storage == 'C') ) //column-major order { if ((trans == 'n')||(trans == 'N')) - lda = (std::max)(gtint_t(1),m) + inc; + lda = ( ( m_max - 1 ) * stride + 1 ) + inc; else - lda = (std::max)(gtint_t(1),n) + inc; + lda = ( ( n_max - 1 ) * stride + 1 ) + inc; } else //row-major order { if ((trans == 'n')||(trans == 'N')) - lda = (std::max)(gtint_t(1),n) + inc; + lda = ( ( n_max - 1 ) * stride + 1 ) + inc; else - lda = (std::max)(gtint_t(1),m) + inc; + lda = ( ( m_max - 1 ) * stride + 1 ) + inc; } return lda; } @@ -192,6 +200,24 @@ template double getNaN(); template scomplex getNaN(); template dcomplex getNaN(); +/** + * If T is real, returns NaN. + * If T is complex, returns {NaN, NaN} +*/ +template +T getNaNNaN() +{ + using RT = typename testinghelpers::type_info::real_type; + if constexpr (testinghelpers::type_info::is_real) + return std::numeric_limits::quiet_NaN(); + else + return T{std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN()}; +} +template float getNaNNaN(); +template double getNaNNaN(); +template scomplex getNaNNaN(); +template dcomplex getNaNNaN(); + /** * If T is real, returns inf. * If T is complex, returns {inf, 0.0} @@ -210,11 +236,49 @@ template double getInf(); template scomplex getInf(); template dcomplex getInf(); +/** + * If T is real, returns inf. + * If T is complex, returns {inf, inf} +*/ +template +T getInfInf() +{ + using RT = typename testinghelpers::type_info::real_type; + if constexpr (testinghelpers::type_info::is_real) + return std::numeric_limits::infinity(); + else + return T{std::numeric_limits::infinity(), std::numeric_limits::infinity()}; +} +template float getInfInf(); +template double getInfInf(); +template scomplex getInfInf(); +template dcomplex getInfInf(); + +/** + * If T is real, returns extval. + * If T is complex, returns {extval, extval} + * where extval = NaN or Inf +*/ +template +T aocl_extreme() +{ +#if EXT_VAL == NaN + return getNaNNaN(); +#else + return getInfInf(); +#endif +} +template float aocl_extreme(); +template double aocl_extreme(); +template scomplex aocl_extreme(); +template dcomplex aocl_extreme(); bool chktrans( char trns ) { - return (!(trns=='n')); + trans_t trans; + char_to_blis_trans( trns, &trans ); + return ( bool ) !( trans == BLIS_NO_TRANSPOSE ); } bool chknotrans( char trns ) @@ -527,68 +591,39 @@ template void make_diag( char, gtint_t, gtint_t, double, double *, gtint template void make_diag( char, gtint_t, gtint_t, scomplex, scomplex *, gtint_t ); template void make_diag( char, gtint_t, gtint_t, dcomplex, dcomplex *, gtint_t ); -/** - * print scalar value - * @param[in] x specifies the value. - * @param[in] spec specifies the format specifer. - */ -template -void print_scalar( T x, const char *spec ) { - if constexpr (testinghelpers::type_info::is_real) - printf(spec, x); - else { - printf( spec, x.real ); - if(x.imag < 0) printf( "-" ); - else printf( "+" ); - printf( spec, abs(x.imag) ); - printf( " " ); - } -} -template void print_scalar( float x, const char * ); -template void print_scalar( double x, const char * ); -template void print_scalar( scomplex x, const char * ); -template void print_scalar( dcomplex x, const char * ); - /** * print vector of length n - * @param[in] vec specifies the vector name * @param[in] n specifies the length of the given vector. * @param[in] a specifies pointer which points to the first element of a. * @param[in] incx specifies storage spacing between elements of a. - * @param[in] spec specifies the format specifer. */ template -void print_vector( const char *vec, gtint_t n, T *x, gtint_t incx, const char *spec ) +void print_vector( gtint_t n, T *x, gtint_t incx) { gtint_t i, idx; T val; - std::cout << "Vector " << vec << std::endl; for ( i = 0; i < n; i++ ) { idx = (incx > 0) ? (i * incx) : ( - ( n - i - 1 ) * incx ); val = x[idx]; - print_scalar(val,spec); - printf( " " ); + std::cout<( const char *vec, gtint_t, float *, gtint_t, const char * ); -template void print_vector( const char *vec, gtint_t, double *, gtint_t, const char * ); -template void print_vector( const char *vec, gtint_t, scomplex *, gtint_t, const char * ); -template void print_vector( const char *vec, gtint_t, dcomplex *, gtint_t, const char * ); +template void print_vector( gtint_t, float *, gtint_t); +template void print_vector( gtint_t, double *, gtint_t); +template void print_vector( gtint_t, scomplex *, gtint_t); +template void print_vector( gtint_t, dcomplex *, gtint_t); /** * print matrix of size m x n - * @param[in] mat specifies the matrix name * @param[in] storage specifies the storage format of matrix in memory. * @param[in] m specifies the number of rows of given matrix. * @param[in] n specifies the number of columns of given matrix. * @param[in] a specifies pointer which points to the first element of a. * @param[in] ld specifies leading dimension for a given matrix. - * @param[in] spec specifies the format specifer. */ template -void print_matrix( const char *mat, char storage, gtint_t m, gtint_t n, T *a, gtint_t ld, const char *spec ) +void print_matrix( char storage, gtint_t m, gtint_t n, T *a, gtint_t ld) { gtint_t rs,cs; rs=cs=1; @@ -599,25 +634,20 @@ void print_matrix( const char *mat, char storage, gtint_t m, gtint_t n, T *a, gt rs = ld ; gtint_t i, j; - std::cout << "Matrix " << mat << std::endl; for ( i = 0; i < m; i++ ) { for ( j = 0; j < n; j++ ) { val = a[i*rs + j*cs]; - print_scalar(val,spec); - printf( " " ); + std::cout<( const char *mat, char, gtint_t, gtint_t, float *, gtint_t, const char * ); -template void print_matrix( const char *mat, char, gtint_t, gtint_t, double *, gtint_t, const char * ); -template void print_matrix( const char *mat, char, gtint_t, gtint_t, scomplex *, gtint_t, const char * ); -template void print_matrix( const char *mat, char, gtint_t, gtint_t, dcomplex *, gtint_t, const char * ); - - +template void print_matrix( char, gtint_t, gtint_t, float *, gtint_t); +template void print_matrix( char, gtint_t, gtint_t, double *, gtint_t); +template void print_matrix( char, gtint_t, gtint_t, scomplex *, gtint_t); +template void print_matrix( char, gtint_t, gtint_t, dcomplex *, gtint_t); /* Helper function that returns a string based on the value that is passed The return values are as follows : @@ -627,66 +657,46 @@ template void print_matrix( const char *mat, char, gtint_t, gtint_t, d If the datatype is complex : The string is concatenated with both the real and imaginary components values, based on analysis done separately to each of them (similar to real datatype). + + Also handles values of datatype gtint_t. */ template std::string get_value_string(T exval) { std::string exval_str; - if constexpr (testinghelpers::type_info::is_real) + if constexpr (std::is_integral::value) + { + exval_str = ( exval >= 0) ? std::to_string(exval) : "m" + std::to_string(std::abs(exval)); + } + else if constexpr (testinghelpers::type_info::is_real) { if(std::isnan(exval)) exval_str = "nan"; else if(std::isinf(exval)) - exval_str = (exval >= 0) ? "inf" : "minus_inf"; + exval_str = (exval >= testinghelpers::ZERO()) ? "inf" : "minus_inf"; else - exval_str = ( exval >= 0) ? std::to_string(int(exval)) : "minus_" + std::to_string(int(std::abs(exval))); - } - else - { - if(std::isnan(exval.real)) - { - exval_str = "nan"; - if(std::isinf(exval.imag)) - exval_str = exval_str + "pi" + ((exval.imag >= 0) ? "inf" : "minus_inf"); - else - exval_str = exval_str + "pi" + ((exval.imag >= 0)? std::to_string(int(exval.imag)) : "m" + std::to_string(int(std::abs(exval.imag)))); - } - else if(std::isnan(exval.imag)) - { - if(std::isinf(exval.real)) - exval_str = ((exval.real >= 0) ? "inf" : "minus_inf"); - else - exval_str = ((exval.real >= 0)? std::to_string(int(exval.real)) : "m" + std::to_string(int(std::abs(exval.real)))); - exval_str = exval_str + "pinan"; - } - else if(std::isinf(exval.real)) - { - exval_str = ((exval.real >= 0) ? "inf" : "minus_inf"); - if(std::isnan(exval.imag)) - exval_str = exval_str + "pinan"; - else - exval_str = exval_str + "pi" + ((exval.imag >= 0)? std::to_string(int(exval.imag)) : "m" + std::to_string(int(std::abs(exval.imag)))); - } - else if(std::isinf(exval.imag)) { - if(std::isnan(exval.real)) - exval_str = "nan"; + // Handle -0.0 separately + if (exval == -testinghelpers::ZERO()) + exval_str = "m" + std::to_string(std::abs(exval)); else - exval_str = ((exval.real >= 0)? std::to_string(int(exval.real)) : "m" + std::to_string(int(std::abs(exval.real)))); - - exval_str = exval_str + ((exval.imag >= 0) ? "inf" : "minus_inf"); - } - else - { - exval_str = ((exval.real >= 0)? std::to_string(int(exval.real)) : "m" + std::to_string(int(std::abs(exval.real)))); - exval_str = exval_str + "pi" + ((exval.imag >= 0)? std::to_string(int(exval.imag)) : "m" + std::to_string(int(std::abs(exval.imag)))); + exval_str = ( exval >= testinghelpers::ZERO()) ? std::to_string(exval) : "m" + std::to_string(std::abs(exval)); + exval_str = exval_str.substr(0, exval_str.find(".")+2); + exval_str = exval_str.replace(exval_str.find("."),1,"p"); } } + else if constexpr (testinghelpers::type_info::is_complex) + { + using RT = typename testinghelpers::type_info::real_type; + exval_str = get_value_string(exval.real) + std::string{"_"} + get_value_string(exval.imag) + std::string{"i"}; + } + return exval_str; } template std::string testinghelpers::get_value_string( float ); template std::string testinghelpers::get_value_string( double ); template std::string testinghelpers::get_value_string( scomplex ); template std::string testinghelpers::get_value_string( dcomplex ); +template std::string testinghelpers::get_value_string( gtint_t ); } //end of namespace testinghelpers diff --git a/gtestsuite/testinghelpers/src/extension/ref_imatcopy.cpp b/gtestsuite/testinghelpers/src/extension/ref_imatcopy.cpp new file mode 100644 index 0000000000..0b3f69cbca --- /dev/null +++ b/gtestsuite/testinghelpers/src/extension/ref_imatcopy.cpp @@ -0,0 +1,232 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "extension/ref_imatcopy.h" + +namespace testinghelpers { + +#if defined(REF_IS_OPENBLAS) + +// Template function to load and call CBLAS call of OpenBLAS ?imatcopy, only for real datatypes +template +void ref_imatcopy_real( char storage, char trans, gtint_t m, gtint_t n, T alpha, T* A, + gtint_t lda_in, gtint_t lda_out ) { + + // Since CBLAS call does not support plain conjugation, we need to conjugate A + // in case trans == 'r'(only conjugation) + if( trans == 'r' ) + { + gtint_t size_a = testinghelpers::matsize(storage, 'n', m, n, lda_in ); + std::vector A_conj( size_a ); + memcpy( A_conj.data(), A, size_a * sizeof(T) ); + testinghelpers::conj( storage, A_conj.data(), m, n, lda_in ); + memcpy( A, A_conj.data(), size_a * sizeof(T) ); + trans = 'n'; + } + + enum CBLAS_ORDER cblas_order; + enum CBLAS_TRANSPOSE cblas_trans; + + char_to_cblas_order( storage, &cblas_order ); + char_to_cblas_trans( trans, &cblas_trans ); + + // Defining the function pointer type for CBLAS call of imatcopy + typedef void (*Fptr_ref_cblas_imatcopy)( + const CBLAS_ORDER, const CBLAS_TRANSPOSE, + const f77_int, const f77_int, const T, + const T *, const f77_int, const f77_int + ); + + // Function pointer to load the CBLAS symbol + Fptr_ref_cblas_imatcopy ref_cblas_imatcopy = nullptr; + + // Call C function + /* Check the typename T passed to this function template and call respective function.*/ + if (typeid(T) == typeid(float)) + { + ref_cblas_imatcopy = (Fptr_ref_cblas_imatcopy)refCBLASModule.loadSymbol("cblas_simatcopy"); + } + else if (typeid(T) == typeid(double)) + { + ref_cblas_imatcopy = (Fptr_ref_cblas_imatcopy)refCBLASModule.loadSymbol("cblas_dimatcopy"); + } + + if (!ref_cblas_imatcopy) { + throw std::runtime_error("Error in ref_imatcopy.cpp: Function pointer == 0 -- symbol not found."); + } + + ref_cblas_imatcopy( cblas_order, cblas_trans, m, n, alpha, A, lda_in, lda_out ); +} + +// Template function to load and call CBLAS call of OpenBLAS ?imatcopy, only for complex datatypes +template +void ref_imatcopy_complex( char storage, char trans, gtint_t m, gtint_t n, T alpha, T* A, + gtint_t lda_in, gtint_t lda_out ) { + + // Since CBLAS call does not support plain conjugation, we need to conjugate A + // in case trans == 'r'(only conjugation) + if( trans == 'r' ) + { + gtint_t size_a = testinghelpers::matsize(storage, 'n', m, n, lda_in ); + std::vector A_conj( size_a ); + memcpy( A_conj.data(), A, size_a * sizeof(T) ); + testinghelpers::conj( storage, A_conj.data(), m, n, lda_in ); + memcpy( A, A_conj.data(), size_a * sizeof(T) ); + trans = 'n'; + } + + // Getting the real-precision of the complex datatype + using RT = typename testinghelpers::type_info::real_type; + + enum CBLAS_ORDER cblas_order; + enum CBLAS_TRANSPOSE cblas_trans; + + char_to_cblas_order( storage, &cblas_order ); + char_to_cblas_trans( trans, &cblas_trans ); + + // Defining the function pointer type for CBLAS call of imatcopy + typedef void (*Fptr_ref_cblas_imatcopy)( + const CBLAS_ORDER, const CBLAS_TRANSPOSE, + const f77_int, const f77_int, const RT *, + const RT *, const f77_int, const f77_int + ); + + // Function pointer to load the CBLAS symbol + Fptr_ref_cblas_imatcopy ref_cblas_imatcopy = nullptr; + + // Call C function + /* Check the typename T passed to this function template and call respective function.*/ + if (typeid(T) == typeid(scomplex)) + { + ref_cblas_imatcopy = (Fptr_ref_cblas_imatcopy)refCBLASModule.loadSymbol("cblas_cimatcopy"); + } + else if (typeid(T) == typeid(dcomplex)) + { + ref_cblas_imatcopy = (Fptr_ref_cblas_imatcopy)refCBLASModule.loadSymbol("cblas_zimatcopy"); + } + + if (!ref_cblas_imatcopy) { + throw std::runtime_error("Error in ref_imatcopy.cpp: Function pointer == 0 -- symbol not found."); + } + + ref_cblas_imatcopy( cblas_order, cblas_trans, m, n, (RT *)(&alpha), (RT *)A, lda_in, lda_out ); +} + +template +void ref_imatcopy( char storage, char trans, gtint_t m, gtint_t n, T alpha, T* A, + gtint_t lda_in, gtint_t lda_out ) { + + // Due to difference in the CBLAS API signature for OpenBLAS ?imatcopy(among real and complex) + // types, we have two different template functions(front-ends), that will be called based on the + // datatype. + if ((typeid(T) == typeid(float)) || (typeid(T) == typeid(double))) + { + ref_imatcopy_real( storage, trans, m, n, alpha, A, lda_in, lda_out ); + } + else if ((typeid(T) == typeid(scomplex)) || (typeid(T) == typeid(dcomplex))) + { + ref_imatcopy_complex( storage, trans, m, n, alpha, A, lda_in, lda_out ); + } + else + { + throw std::runtime_error("Error in ref_imatcopy.cpp: Invalid typename is passed function template."); + } +} + +#elif defined(REF_IS_MKL) +template +void ref_imatcopy( char storage, char trans, gtint_t m, gtint_t n, T alpha, T* A, + gtint_t lda_in, gtint_t lda_out ) { + + // Defining the function pointer type for the native MKL call of imatcopy + typedef void (*Fptr_ref_mkl_imatcopy)( + char, char, size_t, size_t, + const T, const T *, size_t, + size_t + ); + + // Function pointer to load the MKL symbol + Fptr_ref_mkl_imatcopy ref_mkl_imatcopy = nullptr; + + // Call C function + /* Check the typename T passed to this function template and call respective function.*/ + if (typeid(T) == typeid(float)) + { + ref_mkl_imatcopy = (Fptr_ref_mkl_imatcopy)refCBLASModule.loadSymbol("MKL_Simatcopy"); + } + else if (typeid(T) == typeid(double)) + { + ref_mkl_imatcopy = (Fptr_ref_mkl_imatcopy)refCBLASModule.loadSymbol("MKL_Dimatcopy"); + } + else if (typeid(T) == typeid(scomplex)) + { + ref_mkl_imatcopy = (Fptr_ref_mkl_imatcopy)refCBLASModule.loadSymbol("MKL_Cimatcopy"); + } + else if (typeid(T) == typeid(dcomplex)) + { + ref_mkl_imatcopy = (Fptr_ref_mkl_imatcopy)refCBLASModule.loadSymbol("MKL_Zimatcopy"); + } + else + { + throw std::runtime_error("Error in ref_imatcopy.cpp: Invalid typename is passed function template."); + } + if (!ref_mkl_imatcopy) { + throw std::runtime_error("Error in ref_imatcopy.cpp: Function pointer == 0 -- symbol not found."); + } + + ref_mkl_imatcopy( storage, trans, m, n, alpha, A, lda_in, lda_out ); +} +#else +template +void ref_imatcopy( char storage, char trans, gtint_t m, gtint_t n, T alpha, T* A, + gtint_t lda_in, gtint_t lda_out ) { + throw std::runtime_error("Error in ref_imatcopy.cpp: The provided reference does not support the required operation."); +} +#endif + +// Explicit template instantiations +#if defined(REF_IS_OPENBLAS) +template void ref_imatcopy_real( char, char, gtint_t, gtint_t, float, float*, gtint_t, gtint_t ); +template void ref_imatcopy_real( char, char, gtint_t, gtint_t, double, double*, gtint_t, gtint_t ); +template void ref_imatcopy_complex( char, char, gtint_t, gtint_t, scomplex, scomplex*, gtint_t, gtint_t ); +template void ref_imatcopy_complex( char, char, gtint_t, gtint_t, dcomplex, dcomplex*, gtint_t, gtint_t ); +#endif + +template void ref_imatcopy( char, char, gtint_t, gtint_t, float, float*, gtint_t, gtint_t ); +template void ref_imatcopy( char, char, gtint_t, gtint_t, double, double*, gtint_t, gtint_t ); +template void ref_imatcopy( char, char, gtint_t, gtint_t, scomplex, scomplex*, gtint_t, gtint_t ); +template void ref_imatcopy( char, char, gtint_t, gtint_t, dcomplex, dcomplex*, gtint_t, gtint_t ); + +} //end of namespace testinghelpers diff --git a/gtestsuite/testinghelpers/src/extension/ref_omatcopy.cpp b/gtestsuite/testinghelpers/src/extension/ref_omatcopy.cpp new file mode 100644 index 0000000000..a1c72903fc --- /dev/null +++ b/gtestsuite/testinghelpers/src/extension/ref_omatcopy.cpp @@ -0,0 +1,234 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "extension/ref_omatcopy.h" + +namespace testinghelpers { + +#if defined(REF_IS_OPENBLAS) + +// Template function to load and call CBLAS call of OpenBLAS ?omatcopy, only for real datatypes +template +void ref_omatcopy_real( char storage, char trans, gtint_t m, gtint_t n, T alpha, T* A, + gtint_t lda, T* B, gtint_t ldb ) { + + // Since CBLAS call does not support plain conjugation, we need to conjugate A + // in case trans == 'r'(only conjugation) + if( trans == 'r' ) + { + gtint_t size_a = testinghelpers::matsize(storage, 'n', m, n, lda); + std::vector A_conj( size_a ); + memcpy( A_conj.data(), A, size_a * sizeof(T) ); + testinghelpers::conj( storage, A_conj.data(), m, n, lda ); + memcpy( A, A_conj.data(), size_a * sizeof(T) ); + trans = 'n'; + } + + enum CBLAS_ORDER cblas_order; + enum CBLAS_TRANSPOSE cblas_trans; + + char_to_cblas_order( storage, &cblas_order ); + char_to_cblas_trans( trans, &cblas_trans ); + + // Defining the function pointer type for CBLAS call of OMATCOPY + typedef void (*Fptr_ref_cblas_omatcopy)( + const CBLAS_ORDER, const CBLAS_TRANSPOSE, + const f77_int, const f77_int, const T, + const T *, const f77_int, const T *, + const f77_int + ); + + // Function pointer to load the CBLAS symbol + Fptr_ref_cblas_omatcopy ref_cblas_omatcopy = nullptr; + + // Call C function + /* Check the typename T passed to this function template and call respective function.*/ + if (typeid(T) == typeid(float)) + { + ref_cblas_omatcopy = (Fptr_ref_cblas_omatcopy)refCBLASModule.loadSymbol("cblas_somatcopy"); + } + else if (typeid(T) == typeid(double)) + { + ref_cblas_omatcopy = (Fptr_ref_cblas_omatcopy)refCBLASModule.loadSymbol("cblas_domatcopy"); + } + + if (!ref_cblas_omatcopy) { + throw std::runtime_error("Error in ref_omatcopy.cpp: Function pointer == 0 -- symbol not found."); + } + + ref_cblas_omatcopy( cblas_order, cblas_trans, m, n, alpha, A, lda, B, ldb ); +} + +// Template function to load and call CBLAS call of OpenBLAS ?omatcopy, only for complex datatypes +template +void ref_omatcopy_complex( char storage, char trans, gtint_t m, gtint_t n, T alpha, T* A, + gtint_t lda, T* B, gtint_t ldb ) { + + // Since CBLAS call does not support plain conjugation, we need to conjugate A + // in case trans == 'r'(only conjugation) + if( trans == 'r' ) + { + gtint_t size_a = testinghelpers::matsize(storage, 'n', m, n, lda); + std::vector A_conj( size_a ); + memcpy( A_conj.data(), A, size_a * sizeof(T) ); + testinghelpers::conj( storage, A_conj.data(), m, n, lda ); + memcpy( A, A_conj.data(), size_a * sizeof(T) ); + trans = 'n'; + } + + // Getting the real-precision of the complex datatype + using RT = typename testinghelpers::type_info::real_type; + + enum CBLAS_ORDER cblas_order; + enum CBLAS_TRANSPOSE cblas_trans; + + char_to_cblas_order( storage, &cblas_order ); + char_to_cblas_trans( trans, &cblas_trans ); + + // Defining the function pointer type for CBLAS call of OMATCOPY + typedef void (*Fptr_ref_cblas_omatcopy)( + const CBLAS_ORDER, const CBLAS_TRANSPOSE, + const f77_int, const f77_int, const RT *, + const RT *, const f77_int, const RT *, + const f77_int + ); + + // Function pointer to load the CBLAS symbol + Fptr_ref_cblas_omatcopy ref_cblas_omatcopy = nullptr; + + // Call C function + /* Check the typename T passed to this function template and call respective function.*/ + if (typeid(T) == typeid(scomplex)) + { + ref_cblas_omatcopy = (Fptr_ref_cblas_omatcopy)refCBLASModule.loadSymbol("cblas_comatcopy"); + } + else if (typeid(T) == typeid(dcomplex)) + { + ref_cblas_omatcopy = (Fptr_ref_cblas_omatcopy)refCBLASModule.loadSymbol("cblas_zomatcopy"); + } + + if (!ref_cblas_omatcopy) { + throw std::runtime_error("Error in ref_omatcopy.cpp: Function pointer == 0 -- symbol not found."); + } + + ref_cblas_omatcopy( cblas_order, cblas_trans, m, n, (RT *)(&alpha), (RT *)A, lda, (RT *)B, ldb ); +} + +template +void ref_omatcopy( char storage, char trans, gtint_t m, gtint_t n, T alpha, T* A, + gtint_t lda, T* B, gtint_t ldb ) { + + // Due to difference in the CBLAS API signature for OpenBLAS ?omatcopy(among real and complex) + // types, we have two different template functions(front-ends), that will be called based on the + // datatype. + if ((typeid(T) == typeid(float)) || (typeid(T) == typeid(double))) + { + ref_omatcopy_real( storage, trans, m, n, alpha, A, lda, B, ldb ); + } + else if ((typeid(T) == typeid(scomplex)) || (typeid(T) == typeid(dcomplex))) + { + ref_omatcopy_complex( storage, trans, m, n, alpha, A, lda, B, ldb ); + } + else + { + throw std::runtime_error("Error in ref_omatcopy.cpp: Invalid typename is passed function template."); + } +} + +#elif defined(REF_IS_MKL) +template +void ref_omatcopy( char storage, char trans, gtint_t m, gtint_t n, T alpha, T* A, + gtint_t lda, T* B, gtint_t ldb ) { + + // Defining the function pointer type for the native MKL call of OMATCOPY + typedef void (*Fptr_ref_mkl_omatcopy)( + char, char, size_t, size_t, + const T, const T *, size_t, + T *, size_t + ); + + // Function pointer to load the MKL symbol + Fptr_ref_mkl_omatcopy ref_mkl_omatcopy = nullptr; + + // Call C function + /* Check the typename T passed to this function template and call respective function.*/ + if (typeid(T) == typeid(float)) + { + ref_mkl_omatcopy = (Fptr_ref_mkl_omatcopy)refCBLASModule.loadSymbol("MKL_Somatcopy"); + } + else if (typeid(T) == typeid(double)) + { + ref_mkl_omatcopy = (Fptr_ref_mkl_omatcopy)refCBLASModule.loadSymbol("MKL_Domatcopy"); + } + else if (typeid(T) == typeid(scomplex)) + { + ref_mkl_omatcopy = (Fptr_ref_mkl_omatcopy)refCBLASModule.loadSymbol("MKL_Comatcopy"); + } + else if (typeid(T) == typeid(dcomplex)) + { + ref_mkl_omatcopy = (Fptr_ref_mkl_omatcopy)refCBLASModule.loadSymbol("MKL_Zomatcopy"); + } + else + { + throw std::runtime_error("Error in ref_omatcopy.cpp: Invalid typename is passed function template."); + } + if (!ref_mkl_omatcopy) { + throw std::runtime_error("Error in ref_omatcopy.cpp: Function pointer == 0 -- symbol not found."); + } + + ref_mkl_omatcopy( storage, trans, m, n, alpha, A, lda, B, ldb ); +} +#else +template +void ref_omatcopy( char storage, char trans, gtint_t m, gtint_t n, T alpha, T* A, + gtint_t lda, T* B, gtint_t ldb ) { + throw std::runtime_error("Error in ref_omatcopy.cpp: The provided reference does not support the required operation."); +} +#endif + +// Explicit template instantiations +#if defined(REF_IS_OPENBLAS) +template void ref_omatcopy_real( char, char, gtint_t, gtint_t, float, float*, gtint_t, float*, gtint_t ); +template void ref_omatcopy_real( char, char, gtint_t, gtint_t, double, double*, gtint_t, double*, gtint_t ); +template void ref_omatcopy_complex( char, char, gtint_t, gtint_t, scomplex, scomplex*, gtint_t, scomplex*, gtint_t ); +template void ref_omatcopy_complex( char, char, gtint_t, gtint_t, dcomplex, dcomplex*, gtint_t, dcomplex*, gtint_t ); +#endif + +template void ref_omatcopy( char, char, gtint_t, gtint_t, float, float*, gtint_t, float*, gtint_t ); +template void ref_omatcopy( char, char, gtint_t, gtint_t, double, double*, gtint_t, double*, gtint_t ); +template void ref_omatcopy( char, char, gtint_t, gtint_t, scomplex, scomplex*, gtint_t, scomplex*, gtint_t ); +template void ref_omatcopy( char, char, gtint_t, gtint_t, dcomplex, dcomplex*, gtint_t, dcomplex*, gtint_t ); + +} //end of namespace testinghelpers diff --git a/gtestsuite/testinghelpers/src/extension/ref_omatcopy2.cpp b/gtestsuite/testinghelpers/src/extension/ref_omatcopy2.cpp new file mode 100644 index 0000000000..426b8b7f86 --- /dev/null +++ b/gtestsuite/testinghelpers/src/extension/ref_omatcopy2.cpp @@ -0,0 +1,97 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "extension/ref_omatcopy2.h" + +namespace testinghelpers { + +#if defined(REF_IS_MKL) +template +void ref_omatcopy2( char storage, char trans, gtint_t m, gtint_t n, T alpha, T* A, + gtint_t lda, gtint_t stridea, T* B, gtint_t ldb, gtint_t strideb ) { + + // Defining the function pointer type for the native MKL call of omatcopy2 + typedef void (*Fptr_ref_mkl_omatcopy2)( + char, char, size_t, size_t, const T, + const T *, size_t, size_t, T *, + size_t, size_t + ); + + // Function pointer to load the MKL symbol + Fptr_ref_mkl_omatcopy2 ref_mkl_omatcopy2 = nullptr; + + // Call C function + /* Check the typename T passed to this function template and call respective function.*/ + if (typeid(T) == typeid(float)) + { + ref_mkl_omatcopy2 = (Fptr_ref_mkl_omatcopy2)refCBLASModule.loadSymbol("MKL_Somatcopy2"); + } + else if (typeid(T) == typeid(double)) + { + ref_mkl_omatcopy2 = (Fptr_ref_mkl_omatcopy2)refCBLASModule.loadSymbol("MKL_Domatcopy2"); + } + else if (typeid(T) == typeid(scomplex)) + { + ref_mkl_omatcopy2 = (Fptr_ref_mkl_omatcopy2)refCBLASModule.loadSymbol("MKL_Comatcopy2"); + } + else if (typeid(T) == typeid(dcomplex)) + { + ref_mkl_omatcopy2 = (Fptr_ref_mkl_omatcopy2)refCBLASModule.loadSymbol("MKL_Zomatcopy2"); + } + else + { + throw std::runtime_error("Error in ref_omatcopy2.cpp: Invalid typename is passed function template."); + } + if (!ref_mkl_omatcopy2) { + throw std::runtime_error("Error in ref_omatcopy2.cpp: Function pointer == 0 -- symbol not found."); + } + + ref_mkl_omatcopy2( storage, trans, m, n, alpha, A, lda, stridea, B, ldb, strideb ); +} +#else +template +void ref_omatcopy2( char storage, char trans, gtint_t m, gtint_t n, T alpha, T* A, + gtint_t lda, gtint_t stridea, T* B, gtint_t ldb, gtint_t strideb ) { + throw std::runtime_error("Error in ref_omatcopy2.cpp: The provided reference does not support the required operation."); +} +#endif + +// Explicit template instantiations +template void ref_omatcopy2( char, char, gtint_t, gtint_t, float, float*, gtint_t, gtint_t, float*, gtint_t, gtint_t ); +template void ref_omatcopy2( char, char, gtint_t, gtint_t, double, double*, gtint_t, gtint_t, double*, gtint_t, gtint_t ); +template void ref_omatcopy2( char, char, gtint_t, gtint_t, scomplex, scomplex*, gtint_t, gtint_t, scomplex*, gtint_t, gtint_t ); +template void ref_omatcopy2( char, char, gtint_t, gtint_t, dcomplex, dcomplex*, gtint_t, gtint_t, dcomplex*, gtint_t, gtint_t ); + +} //end of namespace testinghelpers diff --git a/gtestsuite/testinghelpers/src/level1/ref_addv.cpp b/gtestsuite/testinghelpers/src/level1/ref_addv.cpp index 87f4c217d7..aad1ade01e 100644 --- a/gtestsuite/testinghelpers/src/level1/ref_addv.cpp +++ b/gtestsuite/testinghelpers/src/level1/ref_addv.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/src/level1/ref_amaxv.cpp b/gtestsuite/testinghelpers/src/level1/ref_amaxv.cpp index 33007e0fd3..bf033322a7 100644 --- a/gtestsuite/testinghelpers/src/level1/ref_amaxv.cpp +++ b/gtestsuite/testinghelpers/src/level1/ref_amaxv.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/src/level1/ref_axpbyv.cpp b/gtestsuite/testinghelpers/src/level1/ref_axpbyv.cpp index 373d31e0e1..f0615d145f 100644 --- a/gtestsuite/testinghelpers/src/level1/ref_axpbyv.cpp +++ b/gtestsuite/testinghelpers/src/level1/ref_axpbyv.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -37,16 +37,27 @@ namespace testinghelpers { -#if !defined(REF_IS_OPENBLAS) || !defined(REF_IS_MKL) +#if !defined(REF_IS_OPENBLAS) && !defined(REF_IS_MKL) template void ref_axpbyv( char conj_x, gtint_t n, T alpha, const T* x, gtint_t incx, T beta, T* y, gtint_t incy ) { using scalar_t = std::conditional_t::is_complex, T&, T>; + + // Function pointer types to decompose into respective BLAS APIs + // SCALV typedef void (*Fptr_ref_cblas_scal)( f77_int, scalar_t , const T *, f77_int); + // COPYV + typedef void (*Fptr_ref_cblas_copyv)(f77_int, const T*, f77_int, T*, f77_int); + // AXPYV + typedef void (*Fptr_ref_cblas_axpy)( f77_int, scalar_t , const T *, f77_int , T *, f77_int ); + + // Function pointers to load the respective CBLAS symbols Fptr_ref_cblas_scal ref_cblas_scal; + Fptr_ref_cblas_copyv ref_cblas_copyv; + Fptr_ref_cblas_axpy ref_cblas_axpy; - // Call C function + // Loading CBLAS SCALV /* Check the typename T passed to this function template and call respective function.*/ if (typeid(T) == typeid(float)) { @@ -72,49 +83,140 @@ void ref_axpbyv( char conj_x, gtint_t n, T alpha, const T* x, throw std::runtime_error("Error in ref_axpby.cpp: Function pointer == 0 -- symbol not found."); } - ref_cblas_scal( n, beta, y, incy ); - typedef void (*Fptr_ref_cblas_axpby)( f77_int, scalar_t , const T *, f77_int , T *, f77_int ); - Fptr_ref_cblas_axpby ref_cblas_axpby; + // Loading CBLAS COPYV + /* Check the typename T passed to this function template and call respective function.*/ + if (typeid(T) == typeid(float)) + { + ref_cblas_copyv = (Fptr_ref_cblas_copyv)refCBLASModule.loadSymbol("cblas_scopy"); + } + else if (typeid(T) == typeid(double)) + { + ref_cblas_copyv = (Fptr_ref_cblas_copyv)refCBLASModule.loadSymbol("cblas_dcopy"); + } + else if (typeid(T) == typeid(scomplex)) + { + ref_cblas_copyv = (Fptr_ref_cblas_copyv)refCBLASModule.loadSymbol("cblas_ccopy"); + } + else if (typeid(T) == typeid(dcomplex)) + { + ref_cblas_copyv = (Fptr_ref_cblas_copyv)refCBLASModule.loadSymbol("cblas_zcopy"); + } + else + { + throw std::runtime_error("Error in ref_copyv.cpp: Invalid typename is passed function template."); + } + if (!ref_cblas_copyv) { + throw std::runtime_error("Error in ref_copyv.cpp: Function pointer == 0 -- symbol not found."); + } - // Call C function + // Loading CBLAS AXPYV /* Check the typename T passed to this function template and call respective function.*/ if (typeid(T) == typeid(float)) { - ref_cblas_axpby = (Fptr_ref_cblas_axpby)refCBLASModule.loadSymbol("cblas_saxpy"); + ref_cblas_axpy = (Fptr_ref_cblas_axpy)refCBLASModule.loadSymbol("cblas_saxpy"); } else if (typeid(T) == typeid(double)) { - ref_cblas_axpby = (Fptr_ref_cblas_axpby)refCBLASModule.loadSymbol("cblas_daxpy"); + ref_cblas_axpy = (Fptr_ref_cblas_axpy)refCBLASModule.loadSymbol("cblas_daxpy"); } else if (typeid(T) == typeid(scomplex)) { - ref_cblas_axpby = (Fptr_ref_cblas_axpby)refCBLASModule.loadSymbol("cblas_caxpy"); + ref_cblas_axpy = (Fptr_ref_cblas_axpy)refCBLASModule.loadSymbol("cblas_caxpy"); } else if (typeid(T) == typeid(dcomplex)) { - ref_cblas_axpby = (Fptr_ref_cblas_axpby)refCBLASModule.loadSymbol("cblas_zaxpy"); + ref_cblas_axpy = (Fptr_ref_cblas_axpy)refCBLASModule.loadSymbol("cblas_zaxpy"); } else { throw std::runtime_error("Error in ref_axpby.cpp: Invalid typename is passed function template."); } - if (!ref_cblas_axpby) { + if (!ref_cblas_axpy) { throw std::runtime_error("Error in ref_axpby.cpp: Function pointer == 0 -- symbol not found."); } + + // A copy of x to be used for reference computation + std::vector x_copy_vec( testinghelpers::buff_dim(n, incx) ); + memcpy( x_copy_vec.data(), x, testinghelpers::buff_dim(n, incx)*sizeof(T) ); + #ifdef TEST_BLIS_TYPED if( chkconj( conj_x ) ) { - std::vector X( testinghelpers::buff_dim(n, incx) ); - memcpy( X.data(), x, testinghelpers::buff_dim(n, incx)*sizeof(T) ); - testinghelpers::conj( X.data(), n, incx ); - ref_cblas_axpby( n, alpha, X.data(), incx, y, incy ); + testinghelpers::conj( x_copy_vec.data(), n, incx ); } - else #endif + + T * x_copy = x_copy_vec.data(); + // Decomposing using BLAS APIs + if( beta == testinghelpers::ZERO() ) { - ref_cblas_axpby( n, alpha, x, incx, y, incy ); + // Like SETV + if( alpha == testinghelpers::ZERO() ) + { + for( gtint_t i = 0; i < n; i += 1 ) + *( y + i * std::abs( incy ) ) = alpha; + } + // Like COPYV + else if ( alpha == testinghelpers::ONE() ) + { + ref_cblas_copyv( n, x_copy, incx, y, incy ); + } + // Like SCALV + COPYV + else + { + ref_cblas_scal( n, alpha, x_copy, std::abs(incx) ); + ref_cblas_copyv( n, x_copy, incx, y, incy ); + } } + else if( beta == testinghelpers::ONE() ) + { + // ERS condition + if( alpha == testinghelpers::ZERO() ) + { + return; + } + // Like ADDV + else if ( alpha == testinghelpers::ONE() ) + { + // Adjusting the pointers based on the increment sign + T *yp = ( incy < 0 )? y + ( 1 - n )*( incy ) : y; + T *xp = ( incx < 0 )? x_copy + ( 1 - n )*( incx ) : x_copy; + for( gtint_t i = 0; i < n; i += 1 ) + *( yp + i * incy ) = *( xp + i * incx ) + *( yp + i * incy ); + } + // Like AXPYV + else + { + ref_cblas_axpy( n, alpha, x_copy, incx, y, incy ); + } + } + else + { + // Like SCALV + if( alpha == testinghelpers::ZERO() ) + { + ref_cblas_scal( n, beta, y, std::abs(incy) ); + } + // Like SCALV + ADDV + else if ( alpha == testinghelpers::ONE() ) + { + ref_cblas_scal( n, beta, y, std::abs(incy) ); + + // Adjusting the pointers based on the increment sign + T *yp = ( incy < 0 )? y + ( 1 - n )*( incy ) : y; + T *xp = ( incx < 0 )? x_copy + ( 1 - n )*( incx ) : x_copy; + + for( gtint_t i = 0; i < n; i += 1 ) + *( yp + i * incy ) = *( xp + i * incx ) + *( yp + i * incy ); + } + // Like SCALV + AXPYV + else + { + ref_cblas_scal( n, beta, y, std::abs(incy) ); + ref_cblas_axpy( n, alpha, x_copy, incx, y, incy ); + } + } } #else template diff --git a/gtestsuite/testinghelpers/src/level1/ref_axpyf.cpp b/gtestsuite/testinghelpers/src/level1/ref_axpyf.cpp new file mode 100644 index 0000000000..3b87b11b2f --- /dev/null +++ b/gtestsuite/testinghelpers/src/level1/ref_axpyf.cpp @@ -0,0 +1,162 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "level1/ref_axpyv.h" +#include "level1/ref_axpyf.h" + + +namespace testinghelpers { + +float bli_cpyscal(conj_t conjx, float *chi1, float *alpha ) +{ + float alpha_chi1; + bli_scopycjs( conjx, *chi1, alpha_chi1 ); + bli_sscals( *alpha, alpha_chi1 ); + return alpha_chi1; +} + +double bli_cpyscal(conj_t conjx, double *chi1, double *alpha ) +{ + double alpha_chi1; + bli_dcopycjs( conjx, *chi1, alpha_chi1 ); + bli_dscals( *alpha, alpha_chi1 ); + return alpha_chi1; +} + +scomplex bli_cpyscal(conj_t conjx, scomplex *chi1, scomplex *alpha ) +{ + scomplex alpha_chi1; + bli_ccopycjs( conjx, *chi1, alpha_chi1 ); + bli_cscals( *alpha, alpha_chi1 ); + return alpha_chi1; +} + +dcomplex bli_cpyscal(conj_t conjx, dcomplex *chi1, dcomplex *alpha ) +{ + dcomplex alpha_chi1; + bli_zcopycjs( conjx, *chi1, alpha_chi1 ); + bli_zscals( *alpha, alpha_chi1 ); + return alpha_chi1; +} + +template +void ref_axpyf( char conja, + char conjx, + gtint_t m, + gtint_t b, + T *alpha, + T* A, + gtint_t inca, + gtint_t lda, + T* x, + gtint_t incx, + T* y, + gtint_t incy + ) + { + conj_t blis_conjx; + testinghelpers::char_to_blis_conj( conjx, &blis_conjx ); + for (gtint_t i = 0; i < b; ++i ) + { + T* a1 = A + (0 )*inca + (i )*lda; + T* chi1 = x + (i )*incx; + T* y1 = y + (0 )*incy; + + T alpha_chi1 = bli_cpyscal( blis_conjx, chi1, alpha ); + + testinghelpers::ref_axpyv( conja, m, alpha_chi1, a1, inca, y1, incy ); + } + } + +template void ref_axpyf( + char conja, + char conjx, + gtint_t m, + gtint_t b, + float *alpha, + float* A, + gtint_t inca, + gtint_t lda, + float* x, + gtint_t incx, + float* y, + gtint_t incy + ); + +template void ref_axpyf( + char conja, + char conjx, + gtint_t m, + gtint_t b, + double *alpha, + double* A, + gtint_t inca, + gtint_t lda, + double* x, + gtint_t incx, + double* y, + gtint_t incy + ); + +template void ref_axpyf( + char conja, + char conjx, + gtint_t m, + gtint_t b, + scomplex *alpha, + scomplex* A, + gtint_t inca, + gtint_t lda, + scomplex* x, + gtint_t incx, + scomplex* y, + gtint_t incy + ); + +template void ref_axpyf( + char conja, + char conjx, + gtint_t m, + gtint_t b, + dcomplex *alpha, + dcomplex* A, + gtint_t inca, + gtint_t lda, + dcomplex* x, + gtint_t incx, + dcomplex* y, + gtint_t incy + ); +} diff --git a/gtestsuite/testinghelpers/src/level1/ref_axpyv.cpp b/gtestsuite/testinghelpers/src/level1/ref_axpyv.cpp index 750ac04172..9423794139 100644 --- a/gtestsuite/testinghelpers/src/level1/ref_axpyv.cpp +++ b/gtestsuite/testinghelpers/src/level1/ref_axpyv.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/src/level1/ref_copyv.cpp b/gtestsuite/testinghelpers/src/level1/ref_copyv.cpp index 4539ab551c..a93979a81c 100644 --- a/gtestsuite/testinghelpers/src/level1/ref_copyv.cpp +++ b/gtestsuite/testinghelpers/src/level1/ref_copyv.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/src/level1/ref_dotv.cpp b/gtestsuite/testinghelpers/src/level1/ref_dotv.cpp index 35c4b5ec5c..34eaac2789 100644 --- a/gtestsuite/testinghelpers/src/level1/ref_dotv.cpp +++ b/gtestsuite/testinghelpers/src/level1/ref_dotv.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -69,7 +69,6 @@ void ref_dotv(gtint_t len, const T* xp, template void ref_dotv( char conj_x, char conj_y, gtint_t len, const T* xp, gtint_t incx, const T* yp, gtint_t incy, T* rho ) { - typedef void (*Fptr_ref_cblas_dot)(f77_int, const T*, f77_int, const T*, f77_int, T* ); Fptr_ref_cblas_dot ref_cblas_dot; @@ -85,11 +84,11 @@ void ref_dotv( char conj_x, char conj_y, gtint_t len, const T* xp, gtint_t incx, memcpy(Y.data(), yp, svy*sizeof(T)); if( cfx ) { - conj( X.data(), len, incx ); + conj( X.data(), len, abs(incx) ); } if( cfy ) { - conj( Y.data(), len, incy ); + conj( Y.data(), len, abs(incy) ); } // Call C function diff --git a/gtestsuite/testinghelpers/src/level1/ref_dotxf.cpp b/gtestsuite/testinghelpers/src/level1/ref_dotxf.cpp new file mode 100644 index 0000000000..d732723d36 --- /dev/null +++ b/gtestsuite/testinghelpers/src/level1/ref_dotxf.cpp @@ -0,0 +1,144 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "level1/ref_dotxv.h" +#include "level1/ref_dotxf.h" + +/** + * dotxf operation is defined as : + * y := y + alpha * conja(A) * conjx(x) + * where A is an m x b matrix, and y and x are vectors. + */ +namespace testinghelpers { +template +void ref_dotxf( char conj_a, + char conj_x, + gtint_t m, + gtint_t b, + T *alpha, + T* A, + gtint_t inca, + gtint_t lda, + T* x, + gtint_t incx, + T * beta, + T* y, + gtint_t incy + ) + { + for ( dim_t i = 0; i < b; ++i ) + { + T* a1 = A + (0 )*inca + (i )*lda; + T* x1 = x + (0 )*incx; + T* psi1 = y + (i )*incy; + + testinghelpers::ref_dotxv + ( + conj_a, + conj_x, + m, + *alpha, + a1, inca, + x1, incx, + *beta, + psi1 + ); + } + } + +template void ref_dotxf( + char conj_a, + char conj_x, + gtint_t m, + gtint_t b, + double *alpha, + double* A, + gtint_t inca, + gtint_t lda, + double* x, + gtint_t incx, + double *beta, + double* y, + gtint_t incy + ); + +template void ref_dotxf( + char conj_a, + char conj_x, + gtint_t m, + gtint_t b, + float *alpha, + float* A, + gtint_t inca, + gtint_t lda, + float* x, + gtint_t incx, + float *beta, + float* y, + gtint_t incy + ); + +template void ref_dotxf( + char conj_a, + char conj_x, + gtint_t m, + gtint_t b, + scomplex *alpha, + scomplex* A, + gtint_t inca, + gtint_t lda, + scomplex* x, + gtint_t incx, + scomplex *beta, + scomplex* y, + gtint_t incy + ); + +template void ref_dotxf( + char conj_a, + char conj_x, + gtint_t m, + gtint_t b, + dcomplex *alpha, + dcomplex* A, + gtint_t inca, + gtint_t lda, + dcomplex* x, + gtint_t incx, + dcomplex *beta, + dcomplex* y, + gtint_t incy + ); +} diff --git a/gtestsuite/testinghelpers/src/level1/ref_dotxv.cpp b/gtestsuite/testinghelpers/src/level1/ref_dotxv.cpp index 1d08c4d438..76bad1a30a 100644 --- a/gtestsuite/testinghelpers/src/level1/ref_dotxv.cpp +++ b/gtestsuite/testinghelpers/src/level1/ref_dotxv.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/src/level1/ref_scal2v.cpp b/gtestsuite/testinghelpers/src/level1/ref_scal2v.cpp index 34ea17dc1c..47b22f768c 100644 --- a/gtestsuite/testinghelpers/src/level1/ref_scal2v.cpp +++ b/gtestsuite/testinghelpers/src/level1/ref_scal2v.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/src/level1/ref_scalv.cpp b/gtestsuite/testinghelpers/src/level1/ref_scalv.cpp index 5b74b91b25..432304e314 100644 --- a/gtestsuite/testinghelpers/src/level1/ref_scalv.cpp +++ b/gtestsuite/testinghelpers/src/level1/ref_scalv.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -37,43 +37,38 @@ namespace testinghelpers { -template -void ref_scalv(char conjalpha, gtint_t n, T alpha, T* x, gtint_t incx) +template +void ref_scalv(char conjalpha, gtint_t n, U alpha, T* x, gtint_t incx) { - using scalar_t = std::conditional_t::is_complex, T&, T>; + using scalar_t = std::conditional_t::is_complex, U&, U>; typedef void (*Fptr_ref_cblas_scal)( f77_int, scalar_t , T *, f77_int); Fptr_ref_cblas_scal ref_cblas_scal; - // Call C function - /* Check the typename T passed to this function template and call respective function.*/ - if (typeid(T) == typeid(float)) - { - ref_cblas_scal = (Fptr_ref_cblas_scal)refCBLASModule.loadSymbol("cblas_sscal"); - } - else if (typeid(T) == typeid(double)) - { - ref_cblas_scal = (Fptr_ref_cblas_scal)refCBLASModule.loadSymbol("cblas_dscal"); - } - else if (typeid(T) == typeid(scomplex)) - { - ref_cblas_scal = (Fptr_ref_cblas_scal)refCBLASModule.loadSymbol("cblas_cscal"); - } - else if (typeid(T) == typeid(dcomplex)) - { - ref_cblas_scal = (Fptr_ref_cblas_scal)refCBLASModule.loadSymbol("cblas_zscal"); - } + if constexpr (std::is_same::value) + if constexpr (std::is_same::value) + ref_cblas_scal = (Fptr_ref_cblas_scal)refCBLASModule.loadSymbol("cblas_sscal"); + else if constexpr (std::is_same::value) + ref_cblas_scal = (Fptr_ref_cblas_scal)refCBLASModule.loadSymbol("cblas_dscal"); + else if constexpr (std::is_same::value) + ref_cblas_scal = (Fptr_ref_cblas_scal)refCBLASModule.loadSymbol("cblas_cscal"); + else if constexpr (std::is_same::value) + ref_cblas_scal = (Fptr_ref_cblas_scal)refCBLASModule.loadSymbol("cblas_zscal"); + else + throw std::runtime_error("Error in ref_scalv.cpp: Invalid typename is passed function template."); + else if constexpr (std::is_same:: value && std::is_same::value) + ref_cblas_scal = (Fptr_ref_cblas_scal)refCBLASModule.loadSymbol("cblas_csscal"); + else if constexpr (std::is_same:: value && std::is_same::value) + ref_cblas_scal = (Fptr_ref_cblas_scal)refCBLASModule.loadSymbol("cblas_zdscal"); else - { throw std::runtime_error("Error in ref_scalv.cpp: Invalid typename is passed function template."); - } - if (!ref_cblas_scal) { + + if (!ref_cblas_scal) throw std::runtime_error("Error in ref_scalv.cpp: Function pointer == 0 -- symbol not found."); - } #ifdef TEST_BLIS_TYPED if( chkconj( conjalpha ) ) { - T alpha_conj = testinghelpers::conj( alpha ); + U alpha_conj = testinghelpers::conj( alpha ); ref_cblas_scal( n, alpha_conj, x, incx ); } else @@ -81,13 +76,14 @@ void ref_scalv(char conjalpha, gtint_t n, T alpha, T* x, gtint_t incx) { ref_cblas_scal( n, alpha, x, incx ); } - } // Explicit template instantiations -template void ref_scalv(char, gtint_t, float, float*, gtint_t); -template void ref_scalv(char, gtint_t, double, double*, gtint_t); -template void ref_scalv(char, gtint_t, scomplex, scomplex*, gtint_t); -template void ref_scalv(char, gtint_t, dcomplex, dcomplex*, gtint_t); +template void ref_scalv(char, gtint_t, float, float*, gtint_t); +template void ref_scalv(char, gtint_t, double, double*, gtint_t); +template void ref_scalv(char, gtint_t, scomplex, scomplex*, gtint_t); +template void ref_scalv(char, gtint_t, dcomplex, dcomplex*, gtint_t); +template void ref_scalv(char, gtint_t, float, scomplex*, gtint_t); +template void ref_scalv(char, gtint_t, double, dcomplex*, gtint_t); } //end of namespace testinghelpers diff --git a/gtestsuite/testinghelpers/src/level1/ref_subv.cpp b/gtestsuite/testinghelpers/src/level1/ref_subv.cpp index 40ddb3e02c..b9f55d177e 100644 --- a/gtestsuite/testinghelpers/src/level1/ref_subv.cpp +++ b/gtestsuite/testinghelpers/src/level1/ref_subv.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/src/level1/ref_swapv.cpp b/gtestsuite/testinghelpers/src/level1/ref_swapv.cpp new file mode 100644 index 0000000000..e7aee37311 --- /dev/null +++ b/gtestsuite/testinghelpers/src/level1/ref_swapv.cpp @@ -0,0 +1,69 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "level1/ref_swapv.h" + +namespace testinghelpers { + +template +void ref_swapv(gtint_t n, T* x, gtint_t incx, T* y, gtint_t incy) +{ + typedef void (*Fptr_ref_cblas_swapv)( f77_int, T *, f77_int, T *, f77_int); + Fptr_ref_cblas_swapv ref_cblas_swapv; + + if (typeid(T) == typeid(float)) + ref_cblas_swapv = (Fptr_ref_cblas_swapv)refCBLASModule.loadSymbol("cblas_sswap"); + else if (typeid(T) == typeid(double)) + ref_cblas_swapv = (Fptr_ref_cblas_swapv)refCBLASModule.loadSymbol("cblas_dswap"); + else if (typeid(T) == typeid(scomplex)) + ref_cblas_swapv = (Fptr_ref_cblas_swapv)refCBLASModule.loadSymbol("cblas_cswap"); + else if (typeid(T) == typeid(dcomplex)) + ref_cblas_swapv = (Fptr_ref_cblas_swapv)refCBLASModule.loadSymbol("cblas_zswap"); + else + throw std::runtime_error("Error in ref_swapv.cpp: Invalid typename is passed function template."); + + if (!ref_cblas_swapv) + throw std::runtime_error("Error in ref_swapv.cpp: Function pointer == 0 -- symbol not found."); + + ref_cblas_swapv( n, x, incx, y, incy ); +} + +// Explicit template instantiations +template void ref_swapv(gtint_t, float*, gtint_t, float*, gtint_t); +template void ref_swapv(gtint_t, double*, gtint_t, double*, gtint_t); +template void ref_swapv(gtint_t, scomplex*, gtint_t, scomplex*, gtint_t); +template void ref_swapv(gtint_t, dcomplex*, gtint_t, dcomplex*, gtint_t); + +} //end of namespace testinghelpers diff --git a/gtestsuite/testinghelpers/src/level1/ref_xpbyv.cpp b/gtestsuite/testinghelpers/src/level1/ref_xpbyv.cpp index d8f30dea64..549bd3f8e1 100644 --- a/gtestsuite/testinghelpers/src/level1/ref_xpbyv.cpp +++ b/gtestsuite/testinghelpers/src/level1/ref_xpbyv.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/src/level2/ref_gemv.cpp b/gtestsuite/testinghelpers/src/level2/ref_gemv.cpp index fac8e661db..4fc101ec32 100644 --- a/gtestsuite/testinghelpers/src/level2/ref_gemv.cpp +++ b/gtestsuite/testinghelpers/src/level2/ref_gemv.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/src/level2/ref_ger.cpp b/gtestsuite/testinghelpers/src/level2/ref_ger.cpp index 60857cce5c..7c5453d74e 100644 --- a/gtestsuite/testinghelpers/src/level2/ref_ger.cpp +++ b/gtestsuite/testinghelpers/src/level2/ref_ger.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/src/level2/ref_hemv.cpp b/gtestsuite/testinghelpers/src/level2/ref_hemv.cpp index 13e7996ab2..70471c39bb 100644 --- a/gtestsuite/testinghelpers/src/level2/ref_hemv.cpp +++ b/gtestsuite/testinghelpers/src/level2/ref_hemv.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/src/level2/ref_her.cpp b/gtestsuite/testinghelpers/src/level2/ref_her.cpp index b9a078b7f1..1e3bc09945 100644 --- a/gtestsuite/testinghelpers/src/level2/ref_her.cpp +++ b/gtestsuite/testinghelpers/src/level2/ref_her.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/src/level2/ref_her2.cpp b/gtestsuite/testinghelpers/src/level2/ref_her2.cpp index fe078008ce..0c5f4d7d58 100644 --- a/gtestsuite/testinghelpers/src/level2/ref_her2.cpp +++ b/gtestsuite/testinghelpers/src/level2/ref_her2.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/src/level2/ref_symv.cpp b/gtestsuite/testinghelpers/src/level2/ref_symv.cpp index ae976d2580..79c874f925 100644 --- a/gtestsuite/testinghelpers/src/level2/ref_symv.cpp +++ b/gtestsuite/testinghelpers/src/level2/ref_symv.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/src/level2/ref_syr.cpp b/gtestsuite/testinghelpers/src/level2/ref_syr.cpp index c5648cc23f..e8032af587 100644 --- a/gtestsuite/testinghelpers/src/level2/ref_syr.cpp +++ b/gtestsuite/testinghelpers/src/level2/ref_syr.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/src/level2/ref_syr2.cpp b/gtestsuite/testinghelpers/src/level2/ref_syr2.cpp index fe593d1c41..ea9236d3a4 100644 --- a/gtestsuite/testinghelpers/src/level2/ref_syr2.cpp +++ b/gtestsuite/testinghelpers/src/level2/ref_syr2.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/src/level2/ref_trmv.cpp b/gtestsuite/testinghelpers/src/level2/ref_trmv.cpp index 1e18b35e15..f331783322 100644 --- a/gtestsuite/testinghelpers/src/level2/ref_trmv.cpp +++ b/gtestsuite/testinghelpers/src/level2/ref_trmv.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/src/level2/ref_trsv.cpp b/gtestsuite/testinghelpers/src/level2/ref_trsv.cpp index 5d92a3c3e4..72059e2044 100644 --- a/gtestsuite/testinghelpers/src/level2/ref_trsv.cpp +++ b/gtestsuite/testinghelpers/src/level2/ref_trsv.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/src/level3/ref_gemm.cpp b/gtestsuite/testinghelpers/src/level3/ref_gemm.cpp index 52589ff233..a938d0ba49 100644 --- a/gtestsuite/testinghelpers/src/level3/ref_gemm.cpp +++ b/gtestsuite/testinghelpers/src/level3/ref_gemm.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/src/level3/ref_gemm_compute.cpp b/gtestsuite/testinghelpers/src/level3/ref_gemm_compute.cpp index 21c055f9dd..c1bd8e7f73 100644 --- a/gtestsuite/testinghelpers/src/level3/ref_gemm_compute.cpp +++ b/gtestsuite/testinghelpers/src/level3/ref_gemm_compute.cpp @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -82,7 +82,7 @@ void ref_gemm_compute(char storage, char trnsa, char trnsb, char pcka, char pckb using scalar_t = std::conditional_t::is_complex, T&, T>; - typedef gint_t (*Fptr_ref_cblas_gemm_pack_get_size)( const CBLAS_IDENTIFIER, + typedef gtint_t (*Fptr_ref_cblas_gemm_pack_get_size)( const CBLAS_IDENTIFIER, const f77_int, const f77_int, const f77_int ); Fptr_ref_cblas_gemm_pack_get_size ref_cblas_gemm_pack_get_size; diff --git a/gtestsuite/testinghelpers/src/level3/ref_gemmt.cpp b/gtestsuite/testinghelpers/src/level3/ref_gemmt.cpp index 8d260aefb6..8c21cab543 100644 --- a/gtestsuite/testinghelpers/src/level3/ref_gemmt.cpp +++ b/gtestsuite/testinghelpers/src/level3/ref_gemmt.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -49,7 +49,8 @@ **/ namespace testinghelpers { -#if 1 + +#if defined(REF_IS_NETLIB) template void ref_gemmt ( char storage, char uplo, char trnsa, char trnsb, diff --git a/gtestsuite/testinghelpers/src/level3/ref_hemm.cpp b/gtestsuite/testinghelpers/src/level3/ref_hemm.cpp index 45dce9ca43..afb2e3cf7e 100644 --- a/gtestsuite/testinghelpers/src/level3/ref_hemm.cpp +++ b/gtestsuite/testinghelpers/src/level3/ref_hemm.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/src/level3/ref_her2k.cpp b/gtestsuite/testinghelpers/src/level3/ref_her2k.cpp index 25030d7d42..6bfa9fdd59 100644 --- a/gtestsuite/testinghelpers/src/level3/ref_her2k.cpp +++ b/gtestsuite/testinghelpers/src/level3/ref_her2k.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/src/level3/ref_herk.cpp b/gtestsuite/testinghelpers/src/level3/ref_herk.cpp index 6516833d88..064d23bf53 100644 --- a/gtestsuite/testinghelpers/src/level3/ref_herk.cpp +++ b/gtestsuite/testinghelpers/src/level3/ref_herk.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/src/level3/ref_symm.cpp b/gtestsuite/testinghelpers/src/level3/ref_symm.cpp index fa13613327..e232132a13 100644 --- a/gtestsuite/testinghelpers/src/level3/ref_symm.cpp +++ b/gtestsuite/testinghelpers/src/level3/ref_symm.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/src/level3/ref_syr2k.cpp b/gtestsuite/testinghelpers/src/level3/ref_syr2k.cpp index 41ae007f6a..7c4308ef1a 100644 --- a/gtestsuite/testinghelpers/src/level3/ref_syr2k.cpp +++ b/gtestsuite/testinghelpers/src/level3/ref_syr2k.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/src/level3/ref_syrk.cpp b/gtestsuite/testinghelpers/src/level3/ref_syrk.cpp index 6a1d009cb4..f08ac3efb4 100644 --- a/gtestsuite/testinghelpers/src/level3/ref_syrk.cpp +++ b/gtestsuite/testinghelpers/src/level3/ref_syrk.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/src/level3/ref_trmm.cpp b/gtestsuite/testinghelpers/src/level3/ref_trmm.cpp index 0faa1e52fb..305fae7e40 100644 --- a/gtestsuite/testinghelpers/src/level3/ref_trmm.cpp +++ b/gtestsuite/testinghelpers/src/level3/ref_trmm.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/src/level3/ref_trmm3.cpp b/gtestsuite/testinghelpers/src/level3/ref_trmm3.cpp index cb6e1283d2..24e852249b 100644 --- a/gtestsuite/testinghelpers/src/level3/ref_trmm3.cpp +++ b/gtestsuite/testinghelpers/src/level3/ref_trmm3.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/src/level3/ref_trsm.cpp b/gtestsuite/testinghelpers/src/level3/ref_trsm.cpp index 6f56c069e1..c24848093b 100644 --- a/gtestsuite/testinghelpers/src/level3/ref_trsm.cpp +++ b/gtestsuite/testinghelpers/src/level3/ref_trsm.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testinghelpers/src/util/ref_asumv.cpp b/gtestsuite/testinghelpers/src/util/ref_asumv.cpp new file mode 100644 index 0000000000..4051c450a9 --- /dev/null +++ b/gtestsuite/testinghelpers/src/util/ref_asumv.cpp @@ -0,0 +1,88 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "util/ref_asumv.h" + +/* + * ========================================================================== + * ASUMV computes the sum of the absolute values of the fundamental elements + * of vector x. + * ========================================================================== +**/ + +namespace testinghelpers { + +template +RT ref_asumv(gtint_t n, T* x, gtint_t incx) { + + typedef RT (*Fptr_ref_cblas_asum)( f77_int, const T *, f77_int ); + Fptr_ref_cblas_asum ref_cblas_asum; + + // Call C function + /* Check the typename T passed to this function template and call respective function.*/ + if (typeid(T) == typeid(float)) + { + ref_cblas_asum = (Fptr_ref_cblas_asum)refCBLASModule.loadSymbol("cblas_sasum"); + } + else if (typeid(T) == typeid(double)) + { + ref_cblas_asum = (Fptr_ref_cblas_asum)refCBLASModule.loadSymbol("cblas_dasum"); + } + else if (typeid(T) == typeid(scomplex)) + { + ref_cblas_asum = (Fptr_ref_cblas_asum)refCBLASModule.loadSymbol("cblas_scasum"); + } + else if (typeid(T) == typeid(dcomplex)) + { + ref_cblas_asum = (Fptr_ref_cblas_asum)refCBLASModule.loadSymbol("cblas_dzasum"); + } + else + { + throw std::runtime_error("Error in ref_asumv.cpp: Invalid typename is passed function template."); + } + if (!ref_cblas_asum) { + throw std::runtime_error("Error in ref_asumv.cpp: Function pointer == 0 -- symbol not found."); + } + + return ref_cblas_asum(n, x, incx); +} + +// Explicit template instantiations +template float ref_asumv< float, float>(gtint_t n, float* x, gtint_t incx); +template double ref_asumv< double, double>(gtint_t n, double* x, gtint_t incx); +template float ref_asumv(gtint_t n, scomplex* x, gtint_t incx); +template double ref_asumv(gtint_t n, dcomplex* x, gtint_t incx); + +} //end of namespace testinghelpers diff --git a/gtestsuite/testinghelpers/src/util/ref_nrm2.cpp b/gtestsuite/testinghelpers/src/util/ref_nrm2.cpp index 95bc2e1e93..7b2272f784 100644 --- a/gtestsuite/testinghelpers/src/util/ref_nrm2.cpp +++ b/gtestsuite/testinghelpers/src/util/ref_nrm2.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/gtestsuite/testsuite/CMakeLists.txt b/gtestsuite/testsuite/CMakeLists.txt index ece8c8434a..8fc1197376 100644 --- a/gtestsuite/testsuite/CMakeLists.txt +++ b/gtestsuite/testsuite/CMakeLists.txt @@ -1,21 +1,22 @@ #[=[ + BLIS An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -28,6 +29,7 @@ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + ]=] # Fetch and Build GTest at configure time @@ -63,60 +65,119 @@ if(REF_CBLAS STREQUAL "MKL") endif() endif() -# Return the list of the subdirectories in the directory curdir. -MACRO(SUBDIRLIST result curdir) - FILE(GLOB children RELATIVE ${curdir} ${curdir}/*) - SET(dirlist "") - FOREACH(child ${children}) - IF(IS_DIRECTORY ${curdir}/${child}) - LIST(APPEND dirlist ${child}) - ENDIF() - ENDFOREACH() - SET(${result} ${dirlist}) -ENDMACRO() - -SUBDIRLIST(DIRS ${CMAKE_CURRENT_SOURCE_DIR}) - -set(target_name "testsuite") -foreach(dir ${DIRS}) - add_custom_target(${target_name}.${dir}) - SUBDIRLIST(SUBDIRS ${CMAKE_CURRENT_SOURCE_DIR}/${dir}) - foreach(subdir ${SUBDIRS}) - file(GLOB files ${CMAKE_CURRENT_SOURCE_DIR}/${dir}/${subdir}/*.cpp) - if(files) - add_executable(${target_name}.${dir}.${subdir} ${files}) - set_target_properties(${target_name}.${dir}.${subdir} PROPERTIES RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}) - set_target_properties(${target_name}.${dir}.${subdir} PROPERTIES OUTPUT_NAME ${target_name}.${dir}.${subdir}) - target_include_directories(${target_name}.${dir}.${subdir} PUBLIC ${BLIS_INCLUDE} ${CMAKE_SOURCE_DIR}/testinghelpers/inc ${CMAKE_SOURCE_DIR}/testsuite/) - target_link_libraries(${target_name}.${dir}.${subdir} gtest gtest_main testinghelpers ${BLIS_LIBRARY} ${COMMON_LIBS}) - # if we test serial BLIS, but MKL is used as a reference we still need to set up OpenMP. - if( (ENABLE_THREADING STREQUAL "openmp") OR (MKL_ENABLE_THREADING STREQUAL "openmp")) - target_link_libraries(${target_name}.${dir}.${subdir} OpenMP::OpenMP_CXX) +# Note: Once we integrate with the blis CMake system, we will update and use +# this functionality from the build/cmake directory. +#-------------------------------------------- +# Important sets of header files and paths +#-------------------------------------------- +# Get a list of all sub-directories of a given directory +macro(get_dirpaths_with_suffixes result curdir sufflist) + set(dirlist "") + # dirlist will have all files which are below this directory. + file(GLOB_RECURSE children LIST_DIRECTORIES true ${curdir}/*) + # Adding current directory in the list. + list(PREPEND children ${curdir}) + # Filter out anything that is not a directory. + foreach(child ${children}) + if(IS_DIRECTORY ${child}) + set(HAS_SUFF_FILE "false") + foreach(suff ${sufflist}) + file(GLOB suff_files LIST_DIRECTORIES false ${child}/*\.${suff}) + list(LENGTH suff_files list_size) + if(NOT (${list_size} STREQUAL 0)) + set(HAS_SUFF_FILE "true") + # If there is at least one file with a specific suffix break from for-loop. + break() + endif() + endforeach() + # If there is at least one *.suff file, add directory path in the list. + if(HAS_SUFF_FILE STREQUAL "true") + list(APPEND dirlist "${child}") + endif() endif() - if(ENABLE_ASAN) - target_link_libraries(${target_name}.${dir}.${subdir} -fsanitize=address) - endif() - if(ENABLE_COVERAGE) - target_link_libraries(${target_name}.${dir}.${subdir} "--coverage") - endif() - if(TEST_INTERFACE STREQUAL "BLAS") - target_compile_definitions(${target_name}.${dir}.${subdir} PUBLIC TEST_BLAS) - elseif(TEST_INTERFACE STREQUAL "CBLAS") - target_compile_definitions(${target_name}.${dir}.${subdir} PUBLIC TEST_CBLAS) - else() # BLIS_TYPED option - target_compile_definitions(${target_name}.${dir}.${subdir} PUBLIC TEST_BLIS_TYPED) - endif() - target_compile_definitions(${target_name}.${dir}.${subdir} PUBLIC BLIS_ELEMENT_TYPE='${BLIS_ELEMENT_TYPE}') - add_test(NAME ${target_name}.${dir}.${subdir} COMMAND ${target_name}.${dir}.${subdir}) - if(REF_CBLAS STREQUAL "MKL") - set_property(TEST ${target_name}.${dir}.${subdir} PROPERTY ENVIRONMENT ${MKL_ENV}) - endif() - if(BLIS_LINKING_TYPE STREQUAL "shared") - set_property(TEST ${target_name}.${dir}.${subdir} PROPERTY ENVIRONMENT_MODIFICATION "PATH=path_list_prepend:${BLIS_LIB_PATH}") - endif() - add_dependencies(${target_name}.${dir} ${target_name}.${dir}.${subdir}) - endif() endforeach() + # Get the name of the current directory, after removing the source directory + # from the name, so that we can exclude the files that are part of the ignore + # list even if the blis directory is located in a directory with a name that + # would be ignored. + string(REPLACE "${CMAKE_SOURCE_DIR}/" "" curdirsimple ${curdir}) + # Filter out anything that is part of the IGNORE_LIST. + foreach(item ${IGNORE_LIST}) + list(FILTER dirlist EXCLUDE REGEX ${curdirsimple}.*/${item}/) + endforeach() + list(APPEND ${result} ${dirlist}) +endmacro() + +get_dirpaths_with_suffixes(test_files ${CMAKE_CURRENT_SOURCE_DIR} cpp) +set(target_name "testsuite") +foreach(dir ${test_files}) + file(GLOB files ${dir}/*.cpp) + STRING(REPLACE "${CMAKE_CURRENT_SOURCE_DIR}/" "" exec_name ${dir}) + STRING(REPLACE "/" "." exec_name ${exec_name}) + STRING(PREPEND exec_name ${target_name}.) + if(files) + add_executable(${exec_name} ${files}) + set_target_properties(${exec_name} PROPERTIES RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}) + set_target_properties(${exec_name} PROPERTIES OUTPUT_NAME ${exec_name}) + target_include_directories(${exec_name} PUBLIC ${BLIS_INCLUDE} ${CMAKE_SOURCE_DIR}/testinghelpers/inc ${CMAKE_SOURCE_DIR}/testsuite/) + target_link_libraries(${exec_name} gtest gtest_main testinghelpers ${BLIS_LIBRARY} ${COMMON_LIBS}) + # if we test serial BLIS, but MKL is used as a reference we still need to set up OpenMP. + if( (ENABLE_THREADING STREQUAL "openmp") OR (MKL_ENABLE_THREADING STREQUAL "openmp")) + target_link_libraries(${exec_name} OpenMP::OpenMP_CXX) + endif() + target_link_libraries(${exec_name} ${ASAN_FLAGS} ${COVERAGE_FLAGS}) + if(TEST_INTERFACE STREQUAL "BLAS") + target_compile_definitions(${exec_name} PUBLIC TEST_BLAS TEST_BLAS_LIKE API_PRINT="blas") + elseif(TEST_INTERFACE STREQUAL "BLAS_BLIS_IMPL") + target_compile_definitions(${exec_name} PUBLIC TEST_BLAS_BLIS_IMPL TEST_BLAS_LIKE API_PRINT="blas_blis_impl") + elseif(TEST_INTERFACE STREQUAL "CBLAS") + target_compile_definitions(${exec_name} PUBLIC TEST_CBLAS API_PRINT="cblas") + else() # BLIS_TYPED option + target_compile_definitions(${exec_name} PUBLIC TEST_BLIS_TYPED API_PRINT="bli") + endif() + target_compile_definitions(${exec_name} PUBLIC ${UKR_DEFINES}) + if(TEST_UPPERCASE_ARGS) + target_compile_definitions(${exec_name} PUBLIC TEST_UPPERCASE_ARGS) + endif() + if(THRESHOLD_ZERO) + target_compile_definitions(${exec_name} PUBLIC THRESHOLD_ZERO) + endif() + if(CAN_TEST_INFO_VALUE) + target_compile_definitions(${exec_name} PUBLIC CAN_TEST_INFO_VALUE) + endif() + if(TEST_INPUT_ARGS) + target_compile_definitions(${exec_name} PUBLIC TEST_INPUT_ARGS) + endif() + add_test(NAME ${exec_name} COMMAND ${exec_name}) + if(REF_CBLAS STREQUAL "MKL") + set_property(TEST ${exec_name} PROPERTY ENVIRONMENT ${MKL_ENV}) + endif() + if(BLIS_LINKING_TYPE STREQUAL "shared") + set_property(TEST ${exec_name} PROPERTY ENVIRONMENT_MODIFICATION "PATH=path_list_prepend:${BLIS_LIB_PATH}") + endif() + endif() + list(APPEND all_execs ${exec_name}) endforeach() +# Return the list of the subdirectories in the directory curdir. +macro(SUBDIRLIST result curdir) + file(GLOB children RELATIVE ${curdir} ${curdir}/*) + set(dirlist "") + foreach(child ${children}) + if(IS_DIRECTORY ${curdir}/${child}) + list(APPEND dirlist ${child}) + ENDIF() + endforeach() + set(${result} ${dirlist}) +endmacro() +# Add dependencies to build all level1 or level2, etc., tests with one target. +SUBDIRLIST(subdirs ${CMAKE_CURRENT_SOURCE_DIR}) +foreach(dir ${subdirs}) + set(child_execs ${all_execs}) + add_custom_target(${target_name}.${dir}) + list(FILTER child_execs INCLUDE REGEX ${dir}) + foreach(child ${child_execs}) + add_dependencies(${target_name}.${dir} ${child}) + endforeach() +endforeach() diff --git a/gtestsuite/testsuite/extension/imatcopy/cimatcopy_evt.cpp b/gtestsuite/testsuite/extension/imatcopy/cimatcopy_evt.cpp new file mode 100644 index 0000000000..498bd9282c --- /dev/null +++ b/gtestsuite/testsuite/extension/imatcopy/cimatcopy_evt.cpp @@ -0,0 +1,137 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_imatcopy.h" + +class cimatcopyEVT : + public ::testing::TestWithParam> {}; // is_nan_inf_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(cimatcopyEVT); + +// Tests using random numbers as vector elements. +TEST_P( cimatcopyEVT, API ) +{ + using T = scomplex; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes the storage format of the input matrices + char storage = std::get<0>(GetParam()); + // denotes the trans value for the operation + char trans = std::get<1>(GetParam()); + // m dimension + gtint_t m = std::get<2>(GetParam()); + // n dimension + gtint_t n = std::get<3>(GetParam()); + // alpha + T alpha = std::get<4>(GetParam()); + // lda_inc for A + gtint_t lda_inc = std::get<5>(GetParam()); + // ldb_inc for B + gtint_t ldb_inc = std::get<6>(GetParam()); + // exval + T exval = std::get<7>(GetParam()); + // is_nan_inf_test + bool is_nan_inf_test = std::get<8>(GetParam()); + + double thresh = 0.0; + // Set the threshold for the errors + if( ( alpha != testinghelpers::ZERO() || alpha != testinghelpers::ONE() ) && !(std::isnan(alpha.real) || std::isnan(alpha.imag)) ) + thresh = 3 * testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + // Note: is_memory_test is passed as false(hard-coded), since memory tests are done in _generic.cpp files + test_imatcopy( storage, trans, m, n, alpha, lda_inc, ldb_inc, thresh, false, is_nan_inf_test, exval ); +} + +#if defined(TEST_BLAS_LIKE) && (defined(REF_IS_MKL) || defined(REF_IS_OPENBLAS)) + +static float AOCL_NAN = std::numeric_limits::quiet_NaN(); +static float AOCL_INF = std::numeric_limits::infinity(); + +// EVT testing for cimatcopy, with exception values in A matrix +INSTANTIATE_TEST_SUITE_P( + matrixA, + cimatcopyEVT, + ::testing::Combine( + ::testing::Values('c'), // storage format(currently only for BLAS testing) + ::testing::Values('n', 't', 'r', 'c'), // trans(and/or conj) value + // 'n' - no-transpose, 't' - transpose + // 'r' - conjugate, 'c' - conjugate-transpose + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // m + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // n + ::testing::Values(scomplex{2.3, -3.5}, scomplex{1.0, 0.0}, + scomplex{0.0, 0.0}), // alpha + ::testing::Values(gtint_t(0), gtint_t(25)), // increment of lda + ::testing::Values(gtint_t(0), gtint_t(17)), // increment of ldb + ::testing::Values(scomplex{AOCL_INF, 0.0}, scomplex{0.0, -AOCL_INF}, + scomplex{0.0, AOCL_NAN}, scomplex{AOCL_NAN, AOCL_INF}), // exval + ::testing::Values(true) // is_nan_inf_test + ), + ::imatcopyEVTPrint() + ); + +// EVT testing for cimatcopy, with exception values in alpha +INSTANTIATE_TEST_SUITE_P( + alpha, + cimatcopyEVT, + ::testing::Combine( + ::testing::Values('c'), // storage format(currently only for BLAS testing) + ::testing::Values('n', 't', 'r', 'c'), // trans(and/or conj) value + // 'n' - no-transpose, 't' - transpose + // 'r' - conjugate, 'c' - conjugate-transpose + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // m + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // n + ::testing::Values(scomplex{AOCL_INF, 0.0}, scomplex{0.0, -AOCL_INF}, + scomplex{0.0, AOCL_NAN}, scomplex{AOCL_NAN, AOCL_INF}), // alpha + ::testing::Values(gtint_t(0), gtint_t(25)), // increment of lda + ::testing::Values(gtint_t(0), gtint_t(17)), // increment of ldb + ::testing::Values(scomplex{0.0, 0.0}), // exval + ::testing::Values(true) // is_nan_inf_test + ), + ::imatcopyEVTPrint() + ); +#endif diff --git a/gtestsuite/testsuite/extension/imatcopy/cimatcopy_generic.cpp b/gtestsuite/testsuite/extension/imatcopy/cimatcopy_generic.cpp new file mode 100644 index 0000000000..7c99c045e2 --- /dev/null +++ b/gtestsuite/testsuite/extension/imatcopy/cimatcopy_generic.cpp @@ -0,0 +1,106 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_imatcopy.h" + +class cimatcopyGeneric : + public ::testing::TestWithParam> {}; // is_memory_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(cimatcopyGeneric); + +// Tests using random numbers as vector elements. +TEST_P( cimatcopyGeneric, API ) +{ + using T = scomplex; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes the storage format of the input matrices + char storage = std::get<0>(GetParam()); + // denotes the trans value for the operation + char trans = std::get<1>(GetParam()); + // m dimension + gtint_t m = std::get<2>(GetParam()); + // n dimension + gtint_t n = std::get<3>(GetParam()); + // alpha + T alpha = std::get<4>(GetParam()); + // lda_in_inc for A + gtint_t lda_in_inc = std::get<5>(GetParam()); + // ldb_out_inc for A + gtint_t lda_out_inc = std::get<6>(GetParam()); + // is_memory_test + bool is_memory_test = std::get<7>(GetParam()); + + double thresh = 0.0; + // Set the threshold for the errors + if( ( alpha != testinghelpers::ZERO() || alpha != testinghelpers::ONE() ) ) + thresh = 3 * testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_imatcopy( storage, trans, m, n, alpha, lda_in_inc, lda_out_inc, thresh, is_memory_test ); +} + +#if defined(TEST_BLAS_LIKE) && (defined(REF_IS_MKL) || defined(REF_IS_OPENBLAS)) +// Black box testing for generic and main use of cimatcopy. +INSTANTIATE_TEST_SUITE_P( + Blackbox, + cimatcopyGeneric, + ::testing::Combine( + ::testing::Values('c'), // storage format(currently only for BLAS testing) + ::testing::Values('n', 't', 'r', 'c'), // trans(and/or conj) value + // 'n' - no-transpose, 't' - transpose + // 'r' - conjugate, 'c' - conjugate-transpose + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // m + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // n + ::testing::Values(scomplex{2.3, -3.5}, scomplex{-3.1, 1.7}, + scomplex{1.0, 0.0}, scomplex{0.0, 0.0}), // alpha + ::testing::Values(gtint_t(0), gtint_t(25)), // increment of lda + ::testing::Values(gtint_t(0), gtint_t(17)), // increment of ldb + ::testing::Values(false, true) // is_memory_test + ), + ::imatcopyGenericPrint() + ); +#endif diff --git a/gtestsuite/testsuite/extension/imatcopy/dimatcopy_evt.cpp b/gtestsuite/testsuite/extension/imatcopy/dimatcopy_evt.cpp new file mode 100644 index 0000000000..5960c266fa --- /dev/null +++ b/gtestsuite/testsuite/extension/imatcopy/dimatcopy_evt.cpp @@ -0,0 +1,134 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_imatcopy.h" + +class dimatcopyEVT : + public ::testing::TestWithParam> {}; // is_nan_inf_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(dimatcopyEVT); + +// Tests using random numbers as vector elements. +TEST_P( dimatcopyEVT, API ) +{ + using T = double; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes the storage format of the input matrices + char storage = std::get<0>(GetParam()); + // denotes the trans value for the operation + char trans = std::get<1>(GetParam()); + // m dimension + gtint_t m = std::get<2>(GetParam()); + // n dimension + gtint_t n = std::get<3>(GetParam()); + // alpha + T alpha = std::get<4>(GetParam()); + // lda_inc for A + gtint_t lda_inc = std::get<5>(GetParam()); + // ldb_inc for B + gtint_t ldb_inc = std::get<6>(GetParam()); + // exval + T exval = std::get<7>(GetParam()); + // is_nan_inf_test + bool is_nan_inf_test = std::get<8>(GetParam()); + + double thresh = 0.0; + // Set the threshold for the errors + if( ( alpha != testinghelpers::ZERO() || alpha != testinghelpers::ONE() ) && !(std::isnan(alpha)) ) + thresh = testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + // Note: is_memory_test is passed as false(hard-coded), since memory tests are done in _generic.cpp files + test_imatcopy( storage, trans, m, n, alpha, lda_inc, ldb_inc, thresh, false, is_nan_inf_test, exval ); +} + +#if defined(TEST_BLAS_LIKE) && (defined(REF_IS_MKL) || defined(REF_IS_OPENBLAS)) + +static double AOCL_NAN = std::numeric_limits::quiet_NaN(); +static double AOCL_INF = std::numeric_limits::infinity(); + +// EVT testing for dimatcopy, with exception values in A matrix +INSTANTIATE_TEST_SUITE_P( + matrixA, + dimatcopyEVT, + ::testing::Combine( + ::testing::Values('c'), // storage format(currently only for BLAS testing) + ::testing::Values('n', 't', 'r', 'c'), // trans(and/or conj) value + // 'n' - no-transpose, 't' - transpose + // 'r' - conjugate, 'c' - conjugate-transpose + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // m + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // n + ::testing::Values(2.0, -3.0, 1.0, 0.0), // alpha + ::testing::Values(gtint_t(0), gtint_t(25)), // increment of lda + ::testing::Values(gtint_t(0), gtint_t(17)), // increment of ldb + ::testing::Values(AOCL_NAN, AOCL_INF, -AOCL_INF), // exval + ::testing::Values(true) // is_nan_inf_test + ), + ::imatcopyEVTPrint() + ); + +// EVT testing for dimatcopy, with exception values in alpha +INSTANTIATE_TEST_SUITE_P( + alpha, + dimatcopyEVT, + ::testing::Combine( + ::testing::Values('c'), // storage format(currently only for BLAS testing) + ::testing::Values('n', 't', 'r', 'c'), // trans(and/or conj) value + // 'n' - no-transpose, 't' - transpose + // 'r' - conjugate, 'c' - conjugate-transpose + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // m + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // n + ::testing::Values(AOCL_NAN, AOCL_INF, -AOCL_INF), // alpha + ::testing::Values(gtint_t(0), gtint_t(25)), // increment of lda + ::testing::Values(gtint_t(0), gtint_t(17)), // increment of ldb + ::testing::Values(0.0), // exval + ::testing::Values(true) // is_nan_inf_test + ), + ::imatcopyEVTPrint() + ); +#endif diff --git a/gtestsuite/testsuite/extension/imatcopy/dimatcopy_generic.cpp b/gtestsuite/testsuite/extension/imatcopy/dimatcopy_generic.cpp new file mode 100644 index 0000000000..194bda90ff --- /dev/null +++ b/gtestsuite/testsuite/extension/imatcopy/dimatcopy_generic.cpp @@ -0,0 +1,105 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_imatcopy.h" + +class dimatcopyGeneric : + public ::testing::TestWithParam> {}; // is_memory_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(dimatcopyGeneric); + +// Tests using random numbers as vector elements. +TEST_P( dimatcopyGeneric, API ) +{ + using T = double; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes the storage format of the input matrices + char storage = std::get<0>(GetParam()); + // denotes the trans value for the operation + char trans = std::get<1>(GetParam()); + // m dimension + gtint_t m = std::get<2>(GetParam()); + // n dimension + gtint_t n = std::get<3>(GetParam()); + // alpha + T alpha = std::get<4>(GetParam()); + // lda_in_inc for A + gtint_t lda_in_inc = std::get<5>(GetParam()); + // ldb_out_inc for A + gtint_t lda_out_inc = std::get<6>(GetParam()); + // is_memory_test + bool is_memory_test = std::get<7>(GetParam()); + + double thresh = 0.0; + // Set the threshold for the errors + if( ( alpha != testinghelpers::ZERO() || alpha != testinghelpers::ONE() ) ) + thresh = testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_imatcopy( storage, trans, m, n, alpha, lda_in_inc, lda_out_inc, thresh, is_memory_test ); +} + +#if defined(TEST_BLAS_LIKE) && (defined(REF_IS_MKL) || defined(REF_IS_OPENBLAS)) +// Black box testing for generic and main use of dimatcopy. +INSTANTIATE_TEST_SUITE_P( + Blackbox, + dimatcopyGeneric, + ::testing::Combine( + ::testing::Values('c'), // storage format(currently only for BLAS testing) + ::testing::Values('n', 't', 'r', 'c'), // trans(and/or conj) value + // 'n' - no-transpose, 't' - transpose + // 'r' - conjugate, 'c' - conjugate-transpose + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // m + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // n + ::testing::Values(2.0, -3.0, 1.0, 0.0), // alpha + ::testing::Values(gtint_t(0), gtint_t(25)), // increment of lda + ::testing::Values(gtint_t(0), gtint_t(17)), // increment of ldb + ::testing::Values(false, true) // is_memory_test + ), + ::imatcopyGenericPrint() + ); +#endif diff --git a/gtestsuite/testsuite/extension/imatcopy/imatcopy.h b/gtestsuite/testsuite/extension/imatcopy/imatcopy.h new file mode 100644 index 0000000000..b2d648d475 --- /dev/null +++ b/gtestsuite/testsuite/extension/imatcopy/imatcopy.h @@ -0,0 +1,103 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#pragma once + +#include "blis.h" +#include "common/testing_helpers.h" +#include "inc/check_error.h" + +/** + * @brief Performs the operation: + * A := alpha * op(A), + * where op(A) could be A, A(transpose), A(conjugate), A(conjugate-transpose) + * @param[in] m number of rows in A, number of rows/columns in B + * @param[in] m number of columns in A, number of columns/rows in B + * @param[in] alpha scalar + * @param[in] A pointer which points to the first element of A matrix + * @param[in] lda_in leading dimension of A(input) matrix + * @param[in] lda_out leading dimension of A(output) matrix + */ + +template +static void imatcopy_( char trans, gtint_t m, gtint_t n, T alpha, T* A, gtint_t lda_in, gtint_t lda_out ) +{ + if constexpr (std::is_same::value) + simatcopy_( &trans, &m, &n, (const float *)&alpha, A, &lda_in, &lda_out ); + else if constexpr (std::is_same::value) + dimatcopy_( &trans, &m, &n, (const double *)&alpha, A, &lda_in, &lda_out ); + else if constexpr (std::is_same::value) + cimatcopy_( &trans, &m, &n, (const scomplex *)&alpha, A, &lda_in, &lda_out ); + else if constexpr (std::is_same::value) + zimatcopy_( &trans, &m, &n, (const dcomplex *)&alpha, A, &lda_in, &lda_out ); + else + throw std::runtime_error("Error in testsuite/level1/imatcopy.h: Invalid typename in imatcopy_()."); +} + +template +static void imatcopy( char trans, gtint_t m, gtint_t n, T alpha, T* A, gtint_t lda_in, gtint_t lda_out ) +{ +#ifdef TEST_UPPERCASE_ARGS + trans = static_cast(std::toupper(static_cast(trans))); +#endif + +#ifdef TEST_INPUT_ARGS + // Create copy of scalar input values so we can check that they are not altered. + char trans_cpy = trans; + gtint_t m_cpy = m; + gtint_t n_cpy = n; + T alpha_cpy = alpha; + gtint_t lda_in_cpy = lda_in; + gtint_t lda_out_cpy = lda_out; +#endif + +#ifdef TEST_BLAS_LIKE + imatcopy_( trans, m, n, alpha, A, lda_in, lda_out ); +#else + throw std::runtime_error("Error in testsuite/level1/imatcopy.h: No interfaces are set to be tested."); +#endif + +#ifdef TEST_INPUT_ARGS + //---------------------------------------------------------- + // Check scalar inputs have not been modified. + //---------------------------------------------------------- + + computediff( "trans", trans, trans_cpy ); + computediff( "m", m, m_cpy ); + computediff( "n", n, n_cpy ); + computediff( "alpha", alpha, alpha_cpy ); + computediff( "lda_in", lda_in, lda_in_cpy ); + computediff( "lda_out", lda_out, lda_out_cpy ); +#endif +} diff --git a/gtestsuite/testsuite/extension/imatcopy/imatcopy_IIT_ERS.cpp b/gtestsuite/testsuite/extension/imatcopy/imatcopy_IIT_ERS.cpp new file mode 100644 index 0000000000..20af123264 --- /dev/null +++ b/gtestsuite/testsuite/extension/imatcopy/imatcopy_IIT_ERS.cpp @@ -0,0 +1,241 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_imatcopy.h" +#include "common/wrong_inputs_helpers.h" +#include "common/testing_helpers.h" +#include "inc/check_error.h" + +template +class imatcopy_IIT_ERS : public ::testing::Test {}; +typedef ::testing::Types TypeParam; +TYPED_TEST_SUITE(imatcopy_IIT_ERS, TypeParam); + +using namespace testinghelpers::IIT; + +#if defined(TEST_BLAS_LIKE) + +/* + Incorrect Input Testing(IIT) + + The exceptions get triggered in the following cases: + 1. When TRANS != 'n' || TRANS != 't' || TRANS != 'c' || TRANS != 'r' + 2. When m < 0 + 3. When n < 0 + 4. When lda_in < max(1, m). + 5. When lda_out < max(1, thresh), thresh set based on TRANS value +*/ + +// When TRANS is invalid +TYPED_TEST(imatcopy_IIT_ERS, invalid_transa) +{ + using T = TypeParam; + T alpha = T{2.3}; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + imatcopy( 'Q', M, N, alpha, nullptr, LDA, LDA ); + + // Test with all arguments correct except for the value we are choosing to test. + // Defining the A matrix with values for debugging purposes + std::vector A = testinghelpers::get_random_matrix(-10, 10, 'c', 'n', M, N, LDA ); + // Copy so that we check that the elements of A are not modified. + std::vector A_ref(A); + + // Call imatcopy with a invalid value for TRANS value for the operation. + imatcopy( 'Q', M, N, alpha, A.data(), LDA, LDA ); + // Use bitwise comparison (no threshold). + computediff( "A", 'c', M, N, A.data(), A_ref.data(), LDA ); +} + +// When m < 0 +TYPED_TEST(imatcopy_IIT_ERS, m_lt_zero) +{ + using T = TypeParam; + T alpha = T{2.3}; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + imatcopy( TRANS, -1, N, alpha, nullptr, LDA, LDA ); + + // Defining the A matrix with values for debugging purposes + std::vector A = testinghelpers::get_random_matrix(-10, 10, 'c', 'n', M, N, LDA ); + // Copy so that we check that the elements of A are not modified. + std::vector A_ref(A); + + // Call imatcopy with a invalid m for the operation. + imatcopy( TRANS, -1, N, alpha, A.data(), LDA, LDA ); + // Use bitwise comparison (no threshold). + computediff( "A", 'c', M, N, A.data(), A_ref.data(), LDA ); +} + +// When n < 0 +TYPED_TEST(imatcopy_IIT_ERS, n_lt_zero) +{ + using T = TypeParam; + T alpha = T{2.3}; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + imatcopy( TRANS, M, -1, alpha, nullptr, LDA, LDA ); + + // Defining the A matrix with values for debugging purposes + std::vector A = testinghelpers::get_random_matrix(-10, 10, 'c', 'n', M, N, LDA ); + // Copy so that we check that the elements of A are not modified. + std::vector A_ref(A); + + // Call imatcopy with a invalid n for the operation. + imatcopy( TRANS, M, -1, alpha, A.data(), LDA, LDA ); + // Use bitwise comparison (no threshold). + computediff( "A", 'c', M, N, A.data(), A_ref.data(), LDA ); +} + +// When lda < m +TYPED_TEST(imatcopy_IIT_ERS, invalid_lda_in) +{ + using T = TypeParam; + T alpha = T{2.3}; + + // Having different values for m and n + gtint_t m = 10; + gtint_t n = 5; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + imatcopy( TRANS, m, n, alpha, nullptr, m - 1, m ); + + // Defining the A matrix with values for debugging purposes + std::vector A = testinghelpers::get_random_matrix(-10, 10, 'c', 'n', m, n, m ); + // Copy so that we check that the elements of A are not modified. + std::vector A_ref(A); + + // Call imatcopy with a invalid lda for the operation. + imatcopy( 'n', m, n, alpha, A.data(), m - 1, m ); + // Use bitwise comparison (no threshold). + computediff( "A", 'c', m, n, A.data(), A_ref.data(), m ); +} + +// When lda_out < m, with trans == 'n' +TYPED_TEST(imatcopy_IIT_ERS, invalid_lda_out_no_transpose) +{ + using T = TypeParam; + T alpha = T{2.3}; + + // Having different values for m and n + gtint_t m = 10; + gtint_t n = 5; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + imatcopy( 'n', m, n, alpha, nullptr, m, m-1 ); + + // Defining the A matrix with values for debugging purposes + std::vector A = testinghelpers::get_random_matrix(-10, 10, 'c', 'n', m, n, m ); + // Copy so that we check that the elements of A are not modified. + std::vector A_ref(A); + + // Call imatcopy with a invalid lda for the operation. + imatcopy( 'n', m, n, alpha, A.data(), m, m-1 ); + // Use bitwise comparison (no threshold). + computediff( "A", 'c', m, n, A.data(), A_ref.data(), m ); +} + +// When lda_out < m, with trans == 'r' +TYPED_TEST(imatcopy_IIT_ERS, invalid_lda_out_conjugate) +{ + using T = TypeParam; + T alpha = T{2.3}; + + // Having different values for m and n + gtint_t m = 10; + gtint_t n = 5; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + imatcopy( 'r', m, n, alpha, nullptr, m, m-1 ); + + // Defining the A matrix with values for debugging purposes + std::vector A = testinghelpers::get_random_matrix(-10, 10, 'c', 'n', m, n, m ); + // Copy so that we check that the elements of A are not modified. + std::vector A_ref(A); + + // Call imatcopy with a invalid lda for the operation. + imatcopy( 'r', m, n, alpha, A.data(), m, m-1 ); + // Use bitwise comparison (no threshold). + computediff( "A", 'c', m, n, A.data(), A_ref.data(), m ); +} + +// When lda_out < m, with trans == 't' +TYPED_TEST(imatcopy_IIT_ERS, invalid_lda_out_transpose) +{ + using T = TypeParam; + T alpha = T{2.3}; + + // Having different values for m and n + gtint_t m = 10; + gtint_t n = 5; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + imatcopy( 't', m, n, alpha, nullptr, m, n-1 ); + + // Defining the A matrix with values for debugging purposes + std::vector A = testinghelpers::get_random_matrix(-10, 10, 'c', 'n', m, n, m ); + // Copy so that we check that the elements of A are not modified. + std::vector A_ref(A); + + // Call imatcopy with a invalid lda for the operation. + imatcopy( 't', m, n, alpha, A.data(), m, n-1 ); + // Use bitwise comparison (no threshold). + computediff( "A", 'c', m, n, A.data(), A_ref.data(), m ); +} + +// When lda_out < m, with trans == 'c' +TYPED_TEST(imatcopy_IIT_ERS, invalid_lda_out_conjugate_transpose) +{ + using T = TypeParam; + T alpha = T{2.3}; + + // Having different values for m and n + gtint_t m = 10; + gtint_t n = 5; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + imatcopy( 'c', m, n, alpha, nullptr, m, n-1 ); + + // Defining the A matrix with values for debugging purposes + std::vector A = testinghelpers::get_random_matrix(-10, 10, 'c', 'n', m, n, m ); + // Copy so that we check that the elements of A are not modified. + std::vector A_ref(A); + + // Call imatcopy with a invalid lda for the operation. + imatcopy( 'c', m, n, alpha, A.data(), m, n-1 ); + // Use bitwise comparison (no threshold). + computediff( "A", 'c', m, n, A.data(), A_ref.data(), m ); +} +#endif diff --git a/gtestsuite/testsuite/extension/imatcopy/simatcopy_evt.cpp b/gtestsuite/testsuite/extension/imatcopy/simatcopy_evt.cpp new file mode 100644 index 0000000000..255cf89140 --- /dev/null +++ b/gtestsuite/testsuite/extension/imatcopy/simatcopy_evt.cpp @@ -0,0 +1,134 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_imatcopy.h" + +class simatcopyEVT : + public ::testing::TestWithParam> {}; // is_nan_inf_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(simatcopyEVT); + +// Tests using random numbers as vector elements. +TEST_P( simatcopyEVT, API ) +{ + using T = float; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes the storage format of the input matrices + char storage = std::get<0>(GetParam()); + // denotes the trans value for the operation + char trans = std::get<1>(GetParam()); + // m dimension + gtint_t m = std::get<2>(GetParam()); + // n dimension + gtint_t n = std::get<3>(GetParam()); + // alpha + T alpha = std::get<4>(GetParam()); + // lda_inc for A + gtint_t lda_inc = std::get<5>(GetParam()); + // ldb_inc for B + gtint_t ldb_inc = std::get<6>(GetParam()); + // exval + T exval = std::get<7>(GetParam()); + // is_nan_inf_test + bool is_nan_inf_test = std::get<8>(GetParam()); + + double thresh = 0.0; + // Set the threshold for the errors + if( ( alpha != testinghelpers::ZERO() || alpha != testinghelpers::ONE() ) && !(std::isnan(alpha)) ) + thresh = testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + // Note: is_memory_test is passed as false(hard-coded), since memory tests are done in _generic.cpp files + test_imatcopy( storage, trans, m, n, alpha, lda_inc, ldb_inc, thresh, false, is_nan_inf_test, exval ); +} + +#if defined(TEST_BLAS_LIKE) && (defined(REF_IS_MKL) || defined(REF_IS_OPENBLAS)) + +static float AOCL_NAN = std::numeric_limits::quiet_NaN(); +static float AOCL_INF = std::numeric_limits::infinity(); + +// EVT testing for simatcopy, with exception values in A matrix +INSTANTIATE_TEST_SUITE_P( + matrixA, + simatcopyEVT, + ::testing::Combine( + ::testing::Values('c'), // storage format(currently only for BLAS testing) + ::testing::Values('n', 't', 'r', 'c'), // trans(and/or conj) value + // 'n' - no-transpose, 't' - transpose + // 'r' - conjugate, 'c' - conjugate-transpose + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // m + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // n + ::testing::Values(2.0f, -3.0f, 1.0f, 0.0f), // alpha + ::testing::Values(gtint_t(0), gtint_t(25)), // increment of lda + ::testing::Values(gtint_t(0), gtint_t(17)), // increment of ldb + ::testing::Values(AOCL_NAN, AOCL_INF, -AOCL_INF), // exval + ::testing::Values(true) // is_nan_inf_test + ), + ::imatcopyEVTPrint() + ); + +// EVT testing for simatcopy, with exception values in alpha +INSTANTIATE_TEST_SUITE_P( + alpha, + simatcopyEVT, + ::testing::Combine( + ::testing::Values('c'), // storage format(currently only for BLAS testing) + ::testing::Values('n', 't', 'r', 'c'), // trans(and/or conj) value + // 'n' - no-transpose, 't' - transpose + // 'r' - conjugate, 'c' - conjugate-transpose + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // m + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // n + ::testing::Values(AOCL_NAN, AOCL_INF, -AOCL_INF), // alpha + ::testing::Values(gtint_t(0), gtint_t(25)), // increment of lda + ::testing::Values(gtint_t(0), gtint_t(17)), // increment of ldb + ::testing::Values(0.0f), // exval + ::testing::Values(true) // is_nan_inf_test + ), + ::imatcopyEVTPrint() + ); +#endif diff --git a/gtestsuite/testsuite/extension/imatcopy/simatcopy_generic.cpp b/gtestsuite/testsuite/extension/imatcopy/simatcopy_generic.cpp new file mode 100644 index 0000000000..91d0110717 --- /dev/null +++ b/gtestsuite/testsuite/extension/imatcopy/simatcopy_generic.cpp @@ -0,0 +1,105 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_imatcopy.h" + +class simatcopyGeneric : + public ::testing::TestWithParam> {}; // is_memory_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(simatcopyGeneric); + +// Tests using random numbers as vector elements. +TEST_P( simatcopyGeneric, API ) +{ + using T = float; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes the storage format of the input matrices + char storage = std::get<0>(GetParam()); + // denotes the trans value for the operation + char trans = std::get<1>(GetParam()); + // m dimension + gtint_t m = std::get<2>(GetParam()); + // n dimension + gtint_t n = std::get<3>(GetParam()); + // alpha + T alpha = std::get<4>(GetParam()); + // lda_in_inc for A + gtint_t lda_in_inc = std::get<5>(GetParam()); + // ldb_out_inc for A + gtint_t lda_out_inc = std::get<6>(GetParam()); + // is_memory_test + bool is_memory_test = std::get<7>(GetParam()); + + double thresh = 0.0; + // Set the threshold for the errors + if( ( alpha != testinghelpers::ZERO() || alpha != testinghelpers::ONE() ) ) + thresh = testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_imatcopy( storage, trans, m, n, alpha, lda_in_inc, lda_out_inc, thresh, is_memory_test ); +} + +#if defined(TEST_BLAS_LIKE) && (defined(REF_IS_MKL) || defined(REF_IS_OPENBLAS)) +// Black box testing for generic and main use of simatcopy. +INSTANTIATE_TEST_SUITE_P( + Blackbox, + simatcopyGeneric, + ::testing::Combine( + ::testing::Values('c'), // storage format(currently only for BLAS testing) + ::testing::Values('n', 't', 'c', 'r'), // trans(and/or conj) value + // 'n' - no-transpose, 't' - transpose + // 'r' - conjugate, 'c' - conjugate-transpose + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // m + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // n + ::testing::Values(2.0f, -3.0f, 1.0f, 0.0f), // alpha + ::testing::Values(gtint_t(0), gtint_t(25)), // increment of lda_in + ::testing::Values(gtint_t(0), gtint_t(17)), // increment of lda_out + ::testing::Values(false, true) // is_memory_test + ), + ::imatcopyGenericPrint() + ); +#endif diff --git a/gtestsuite/testsuite/extension/imatcopy/test_imatcopy.h b/gtestsuite/testsuite/extension/imatcopy/test_imatcopy.h new file mode 100644 index 0000000000..3eec57f469 --- /dev/null +++ b/gtestsuite/testsuite/extension/imatcopy/test_imatcopy.h @@ -0,0 +1,203 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#pragma once + +#include "imatcopy.h" +#include "extension/ref_imatcopy.h" +#include "inc/check_error.h" + +/** + * @brief Generic test body for imatcopy operation. + */ + +template +static void test_imatcopy( char storage, char trans, gtint_t m, gtint_t n, T alpha, gtint_t lda_in_inc, gtint_t lda_out_inc, + double thresh, bool is_memory_test = false, bool is_nan_inf_test = false, T exval = T{0.0} ) +{ + // Set an alternative trans value that corresponds to only + // whether the A matrix(output) should be mxn or nxm(only transposing) + char A_out_trans; + A_out_trans = ( ( trans == 'n' ) || ( trans == 'r' ) )? 'n' : 't'; + + // Compute the leading dimensions of A(input) and A(output). + gtint_t lda_in = testinghelpers::get_leading_dimension( storage, 'n', m, n, lda_in_inc ); + gtint_t lda_out = testinghelpers::get_leading_dimension( storage, A_out_trans, m, n, lda_out_inc ); + + // Compute sizes of A(input) and A(output), in bytes + gtint_t size_a_in = testinghelpers::matsize( storage, 'n', m, n, lda_in ) * sizeof( T ); + gtint_t size_a_out = testinghelpers::matsize( storage, A_out_trans, m, n, lda_out ) * sizeof( T ); + + // A has to allocated the maximum of input and output sizes, for API compatibility + gtint_t size_a = (std::max)( size_a_in, size_a_out ); + + // Create the objects for the input and output operands + // The API does not expect the memory to be aligned + testinghelpers::ProtectedBuffer A_buf( size_a, false, is_memory_test ); + testinghelpers::ProtectedBuffer A_ref_buf( size_a, false, false ); + + // Pointers to access the memory chunks + T *A, *A_ref; + + // Acquire the first set of greenzones for A and A_ref + A = ( T* )A_buf.greenzone_1; + A_ref = ( T* )A_ref_buf.greenzone_1; // For A_ref, there is no greenzone_2 + + // Initialize the memory with random data + testinghelpers::datagenerators::randomgenerators( -10, 10, storage, m, n, A, 'n', lda_in ); + + if( is_nan_inf_test ) + { + gtint_t rand_m = rand() % m; + gtint_t rand_n = rand() % n; + gtint_t idx = ( storage == 'c' || storage == 'C' )? ( rand_m + rand_n * lda_in ) : ( rand_n + rand_m * lda_in ); + + A[idx] = exval; + } + + // Copying the contents of A to A_ref + memcpy( A_ref, A, size_a ); + + // Add signal handler for segmentation fault + testinghelpers::ProtectedBuffer::start_signal_handler(); + try + { + // Call the API. + // This call is made irrespective of is_memory_test. + // This will check for out of bounds access with first redzone(if memory test is true) + // Else, it will just call the ukr function. + imatcopy( trans, m, n, alpha, A, lda_in, lda_out ); + + if ( is_memory_test ) + { + // Acquire the pointers near the second redzone + A = ( T* )A_buf.greenzone_2; + + // Copy the data for A accordingly + // NOTE : The object for A will have acquired enough memory + // such that the greenzones in each do not overlap. + memcpy( A, A_ref, size_a ); + + // Call the API, to check with the second redzone. + imatcopy( trans, m, n, alpha, A, lda_in, lda_out ); + } + } + catch(const std::exception& e) + { + // Reset to default signal handler + testinghelpers::ProtectedBuffer::stop_signal_handler(); + + // Show failure in case seg fault was detected + FAIL() << "Memory Test Failed"; + } + // Reset to default signal handler + testinghelpers::ProtectedBuffer::stop_signal_handler(); + + //---------------------------------------------------------- + // Call reference implementation to get ref results. + //---------------------------------------------------------- + testinghelpers::ref_imatcopy( storage, trans, m, n, alpha, A_ref, lda_in, lda_out ); + + //---------------------------------------------------------- + // Compute component-wise error. + //---------------------------------------------------------- + + if( A_out_trans == 'n' ) + computediff( "A", storage, m, n, A, A_ref, lda_out, thresh, is_nan_inf_test ); + else + computediff( "A", storage, n, m, A, A_ref, lda_out, thresh, is_nan_inf_test ); + +} + +// Test-case logger : Used to print the test-case details based on parameters +template +class imatcopyGenericPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char storage = std::get<0>(str.param); + char trans = std::get<1>(str.param); + gtint_t m = std::get<2>(str.param); + gtint_t n = std::get<3>(str.param); + T alpha = std::get<4>(str.param); + gtint_t lda_inc = std::get<5>(str.param); + gtint_t ldb_inc = std::get<6>(str.param); + bool is_memory_test = std::get<7>(str.param); + + std::string str_name = API_PRINT; + str_name += "_stor_" + std::string(&storage, 1); + str_name += "_trans_" + std::string(&trans, 1); + str_name += "_m_" + std::to_string(m); + str_name += "_n_" + std::to_string(n); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + char mat_trans = ( ( trans == 'n' ) || ( trans == 'r' ) )? 'n' : 't'; + gtint_t lda_in = testinghelpers::get_leading_dimension( storage, 'n', m, n, lda_inc ); + gtint_t lda_out = testinghelpers::get_leading_dimension( storage, mat_trans, m, n, ldb_inc ); + str_name += "_lda_in_" + std::to_string(lda_in); + str_name += "_lda_out_" + std::to_string(lda_out); + str_name += ( is_memory_test ) ? "_mem_test_enabled" : "_mem_test_disabled"; + + return str_name; + } +}; + +template +class imatcopyEVTPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char storage = std::get<0>(str.param); + char trans = std::get<1>(str.param); + gtint_t m = std::get<2>(str.param); + gtint_t n = std::get<3>(str.param); + T alpha = std::get<4>(str.param); + gtint_t lda_inc = std::get<5>(str.param); + gtint_t ldb_inc = std::get<6>(str.param); + T exval = std::get<7>(str.param); + + std::string str_name = API_PRINT; + str_name += "_stor_" + std::string(&storage, 1); + str_name += "_trans_" + std::string(&trans, 1); + str_name += "_m_" + std::to_string(m); + str_name += "_n_" + std::to_string(n); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + str_name = str_name + "_A_exval_" + testinghelpers::get_value_string(exval); + gtint_t lda = testinghelpers::get_leading_dimension( storage, 'n', m, n, lda_inc ); + gtint_t ldb = testinghelpers::get_leading_dimension( storage, trans, m, n, ldb_inc ); + str_name += "_lda" + std::to_string(lda); + str_name += "_ldb" + std::to_string(ldb); + + return str_name; + } +}; diff --git a/gtestsuite/testsuite/extension/imatcopy/zimatcopy_evt.cpp b/gtestsuite/testsuite/extension/imatcopy/zimatcopy_evt.cpp new file mode 100644 index 0000000000..661f366ade --- /dev/null +++ b/gtestsuite/testsuite/extension/imatcopy/zimatcopy_evt.cpp @@ -0,0 +1,137 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_imatcopy.h" + +class zimatcopyEVT : + public ::testing::TestWithParam> {}; // is_nan_inf_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(zimatcopyEVT); + +// Tests using random numbers as vector elements. +TEST_P( zimatcopyEVT, API ) +{ + using T = dcomplex; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes the storage format of the input matrices + char storage = std::get<0>(GetParam()); + // denotes the trans value for the operation + char trans = std::get<1>(GetParam()); + // m dimension + gtint_t m = std::get<2>(GetParam()); + // n dimension + gtint_t n = std::get<3>(GetParam()); + // alpha + T alpha = std::get<4>(GetParam()); + // lda_inc for A + gtint_t lda_inc = std::get<5>(GetParam()); + // ldb_inc for B + gtint_t ldb_inc = std::get<6>(GetParam()); + // exval + T exval = std::get<7>(GetParam()); + // is_nan_inf_test + bool is_nan_inf_test = std::get<8>(GetParam()); + + double thresh = 0.0; + // Set the threshold for the errors + if( ( alpha != testinghelpers::ZERO() || alpha != testinghelpers::ONE() ) && !(std::isnan(alpha.real) || std::isnan(alpha.imag)) ) + thresh = 3 * testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + // Note: is_memory_test is passed as false(hard-coded), since memory tests are done in _generic.cpp files + test_imatcopy( storage, trans, m, n, alpha, lda_inc, ldb_inc, thresh, false, is_nan_inf_test, exval ); +} + +#if defined(TEST_BLAS_LIKE) && (defined(REF_IS_MKL) || defined(REF_IS_OPENBLAS)) + +static double AOCL_NAN = std::numeric_limits::quiet_NaN(); +static double AOCL_INF = std::numeric_limits::infinity(); + +// EVT testing for zimatcopy, with exception values in A matrix +INSTANTIATE_TEST_SUITE_P( + matrixA, + zimatcopyEVT, + ::testing::Combine( + ::testing::Values('c'), // storage format(currently only for BLAS testing) + ::testing::Values('n', 't', 'r', 'c'), // trans(and/or conj) value + // 'n' - no-transpose, 't' - transpose + // 'r' - conjugate, 'c' - conjugate-transpose + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // m + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // n + ::testing::Values(dcomplex{2.3, -3.5}, dcomplex{1.0, 0.0}, + dcomplex{0.0, 0.0}), // alpha + ::testing::Values(gtint_t(0), gtint_t(25)), // increment of lda + ::testing::Values(gtint_t(0), gtint_t(17)), // increment of ldb + ::testing::Values(dcomplex{AOCL_INF, 0.0}, dcomplex{0.0, -AOCL_INF}, + dcomplex{0.0, AOCL_NAN}, dcomplex{AOCL_NAN, AOCL_INF}), // exval + ::testing::Values(true) // is_nan_inf_test + ), + ::imatcopyEVTPrint() + ); + +// EVT testing for zimatcopy, with exception values in alpha +INSTANTIATE_TEST_SUITE_P( + alpha, + zimatcopyEVT, + ::testing::Combine( + ::testing::Values('c'), // storage format(currently only for BLAS testing) + ::testing::Values('n', 't', 'r', 'c'), // trans(and/or conj) value + // 'n' - no-transpose, 't' - transpose + // 'r' - conjugate, 'c' - conjugate-transpose + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // m + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // n + ::testing::Values(dcomplex{AOCL_INF, 0.0}, dcomplex{0.0, -AOCL_INF}, + dcomplex{0.0, AOCL_NAN}, dcomplex{AOCL_NAN, AOCL_INF}), // alpha + ::testing::Values(gtint_t(0), gtint_t(25)), // increment of lda + ::testing::Values(gtint_t(0), gtint_t(17)), // increment of ldb + ::testing::Values(dcomplex{0.0, 0.0}), // exval + ::testing::Values(true) // is_nan_inf_test + ), + ::imatcopyEVTPrint() + ); +#endif diff --git a/gtestsuite/testsuite/extension/imatcopy/zimatcopy_generic.cpp b/gtestsuite/testsuite/extension/imatcopy/zimatcopy_generic.cpp new file mode 100644 index 0000000000..35a354c29d --- /dev/null +++ b/gtestsuite/testsuite/extension/imatcopy/zimatcopy_generic.cpp @@ -0,0 +1,106 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_imatcopy.h" + +class zimatcopyGeneric : + public ::testing::TestWithParam> {}; // is_memory_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(zimatcopyGeneric); + +// Tests using random numbers as vector elements. +TEST_P( zimatcopyGeneric, API ) +{ + using T = dcomplex; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes the storage format of the input matrices + char storage = std::get<0>(GetParam()); + // denotes the trans value for the operation + char trans = std::get<1>(GetParam()); + // m dimension + gtint_t m = std::get<2>(GetParam()); + // n dimension + gtint_t n = std::get<3>(GetParam()); + // alpha + T alpha = std::get<4>(GetParam()); + // lda_in_inc for A + gtint_t lda_in_inc = std::get<5>(GetParam()); + // ldb_out_inc for A + gtint_t lda_out_inc = std::get<6>(GetParam()); + // is_memory_test + bool is_memory_test = std::get<7>(GetParam()); + + double thresh = 0.0; + // Set the threshold for the errors + if( ( alpha != testinghelpers::ZERO() || alpha != testinghelpers::ONE() ) ) + thresh = 3 * testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_imatcopy( storage, trans, m, n, alpha, lda_in_inc, lda_out_inc, thresh, is_memory_test ); +} + +#if defined(TEST_BLAS_LIKE) && (defined(REF_IS_MKL) || defined(REF_IS_OPENBLAS)) +// Black box testing for generic and main use of zimatcopy. +INSTANTIATE_TEST_SUITE_P( + Blackbox, + zimatcopyGeneric, + ::testing::Combine( + ::testing::Values('c'), // storage format(currently only for BLAS testing) + ::testing::Values('n', 't', 'r', 'c'), // trans(and/or conj) value + // 'n' - no-transpose, 't' - transpose + // 'r' - conjugate, 'c' - conjugate-transpose + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // m + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // n + ::testing::Values(dcomplex{2.3, -3.5}, dcomplex{-3.1, 1.7}, + dcomplex{1.0, 0.0}, dcomplex{0.0, 0.0}), // alpha + ::testing::Values(gtint_t(0), gtint_t(25)), // increment of lda + ::testing::Values(gtint_t(0), gtint_t(17)), // increment of ldb + ::testing::Values(false, true) // is_memory_test + ), + ::imatcopyGenericPrint() + ); +#endif diff --git a/gtestsuite/testsuite/extension/omatcopy/comatcopy_evt.cpp b/gtestsuite/testsuite/extension/omatcopy/comatcopy_evt.cpp new file mode 100644 index 0000000000..85c841aaaf --- /dev/null +++ b/gtestsuite/testsuite/extension/omatcopy/comatcopy_evt.cpp @@ -0,0 +1,137 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_omatcopy.h" + +class comatcopyEVT : + public ::testing::TestWithParam> {}; // is_nan_inf_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(comatcopyEVT); + +// Tests using random numbers as vector elements. +TEST_P( comatcopyEVT, API ) +{ + using T = scomplex; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes the storage format of the input matrices + char storage = std::get<0>(GetParam()); + // denotes the trans value for the operation + char trans = std::get<1>(GetParam()); + // m dimension + gtint_t m = std::get<2>(GetParam()); + // n dimension + gtint_t n = std::get<3>(GetParam()); + // alpha + T alpha = std::get<4>(GetParam()); + // lda_inc for A + gtint_t lda_inc = std::get<5>(GetParam()); + // ldb_inc for B + gtint_t ldb_inc = std::get<6>(GetParam()); + // exval + T exval = std::get<7>(GetParam()); + // is_nan_inf_test + bool is_nan_inf_test = std::get<8>(GetParam()); + + double thresh = 0.0; + // Set the threshold for the errors + if( ( alpha != testinghelpers::ZERO() || alpha != testinghelpers::ONE() ) && !(std::isnan(alpha.real) || std::isnan(alpha.imag)) ) + thresh = 3 * testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + // Note: is_memory_test is passed as false(hard-coded), since memory tests are done in _generic.cpp files + test_omatcopy( storage, trans, m, n, alpha, lda_inc, ldb_inc, thresh, false, is_nan_inf_test, exval ); +} + +#if defined(TEST_BLAS_LIKE) && (defined(REF_IS_MKL) || defined(REF_IS_OPENBLAS)) + +static float AOCL_NAN = std::numeric_limits::quiet_NaN(); +static float AOCL_INF = std::numeric_limits::infinity(); + +// EVT testing for comatcopy, with exception values in A matrix +INSTANTIATE_TEST_SUITE_P( + matrixA, + comatcopyEVT, + ::testing::Combine( + ::testing::Values('c'), // storage format(currently only for BLAS testing) + ::testing::Values('n', 't', 'r', 'c'), // trans(and/or conj) value + // 'n' - no-transpose, 't' - transpose + // 'r' - conjugate, 'c' - conjugate-transpose + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // m + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // n + ::testing::Values(scomplex{2.3, -3.5}, scomplex{1.0, 0.0}, + scomplex{0.0, 0.0}), // alpha + ::testing::Values(gtint_t(0), gtint_t(25)), // increment of lda + ::testing::Values(gtint_t(0), gtint_t(17)), // increment of ldb + ::testing::Values(scomplex{AOCL_INF, 0.0}, scomplex{0.0, -AOCL_INF}, + scomplex{0.0, AOCL_NAN}, scomplex{AOCL_NAN, AOCL_INF}), // exval + ::testing::Values(true) // is_nan_inf_test + ), + ::omatcopyEVTPrint() + ); + +// EVT testing for comatcopy, with exception values in alpha +INSTANTIATE_TEST_SUITE_P( + alpha, + comatcopyEVT, + ::testing::Combine( + ::testing::Values('c'), // storage format(currently only for BLAS testing) + ::testing::Values('n', 't', 'r', 'c'), // trans(and/or conj) value + // 'n' - no-transpose, 't' - transpose + // 'r' - conjugate, 'c' - conjugate-transpose + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // m + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // n + ::testing::Values(scomplex{AOCL_INF, 0.0}, scomplex{0.0, -AOCL_INF}, + scomplex{0.0, AOCL_NAN}, scomplex{AOCL_NAN, AOCL_INF}), // alpha + ::testing::Values(gtint_t(0), gtint_t(25)), // increment of lda + ::testing::Values(gtint_t(0), gtint_t(17)), // increment of ldb + ::testing::Values(scomplex{0.0, 0.0}), // exval + ::testing::Values(true) // is_nan_inf_test + ), + ::omatcopyEVTPrint() + ); +#endif diff --git a/gtestsuite/testsuite/extension/omatcopy/comatcopy_generic.cpp b/gtestsuite/testsuite/extension/omatcopy/comatcopy_generic.cpp new file mode 100644 index 0000000000..022446b67e --- /dev/null +++ b/gtestsuite/testsuite/extension/omatcopy/comatcopy_generic.cpp @@ -0,0 +1,106 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_omatcopy.h" + +class comatcopyGeneric : + public ::testing::TestWithParam> {}; // is_memory_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(comatcopyGeneric); + +// Tests using random numbers as vector elements. +TEST_P( comatcopyGeneric, API ) +{ + using T = scomplex; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes the storage format of the input matrices + char storage = std::get<0>(GetParam()); + // denotes the trans value for the operation + char trans = std::get<1>(GetParam()); + // m dimension + gtint_t m = std::get<2>(GetParam()); + // n dimension + gtint_t n = std::get<3>(GetParam()); + // alpha + T alpha = std::get<4>(GetParam()); + // lda_inc for A + gtint_t lda_inc = std::get<5>(GetParam()); + // ldb_inc for B + gtint_t ldb_inc = std::get<6>(GetParam()); + // is_memory_test + bool is_memory_test = std::get<7>(GetParam()); + + double thresh = 0.0; + // Set the threshold for the errors + if( ( alpha != testinghelpers::ZERO() || alpha != testinghelpers::ONE() ) ) + thresh = 3 * testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_omatcopy( storage, trans, m, n, alpha, lda_inc, ldb_inc, thresh, is_memory_test ); +} + +#if defined(TEST_BLAS_LIKE) && (defined(REF_IS_MKL) || defined(REF_IS_OPENBLAS)) +// Black box testing for generic and main use of comatcopy. +INSTANTIATE_TEST_SUITE_P( + Blackbox, + comatcopyGeneric, + ::testing::Combine( + ::testing::Values('c'), // storage format(currently only for BLAS testing) + ::testing::Values('n', 't', 'r', 'c'), // trans(and/or conj) value + // 'n' - no-transpose, 't' - transpose + // 'r' - conjugate, 'c' - conjugate-transpose + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // m + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // n + ::testing::Values(scomplex{2.3, -3.5}, scomplex{-3.1, 1.7}, + scomplex{1.0, 0.0}, scomplex{0.0, 0.0}), // alpha + ::testing::Values(gtint_t(0), gtint_t(25)), // increment of lda + ::testing::Values(gtint_t(0), gtint_t(17)), // increment of ldb + ::testing::Values(false, true) // is_memory_test + ), + ::omatcopyGenericPrint() + ); +#endif diff --git a/gtestsuite/testsuite/extension/omatcopy/domatcopy_evt.cpp b/gtestsuite/testsuite/extension/omatcopy/domatcopy_evt.cpp new file mode 100644 index 0000000000..5556db7815 --- /dev/null +++ b/gtestsuite/testsuite/extension/omatcopy/domatcopy_evt.cpp @@ -0,0 +1,134 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_omatcopy.h" + +class domatcopyEVT : + public ::testing::TestWithParam> {}; // is_nan_inf_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(domatcopyEVT); + +// Tests using random numbers as vector elements. +TEST_P( domatcopyEVT, API ) +{ + using T = double; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes the storage format of the input matrices + char storage = std::get<0>(GetParam()); + // denotes the trans value for the operation + char trans = std::get<1>(GetParam()); + // m dimension + gtint_t m = std::get<2>(GetParam()); + // n dimension + gtint_t n = std::get<3>(GetParam()); + // alpha + T alpha = std::get<4>(GetParam()); + // lda_inc for A + gtint_t lda_inc = std::get<5>(GetParam()); + // ldb_inc for B + gtint_t ldb_inc = std::get<6>(GetParam()); + // exval + T exval = std::get<7>(GetParam()); + // is_nan_inf_test + bool is_nan_inf_test = std::get<8>(GetParam()); + + double thresh = 0.0; + // Set the threshold for the errors + if( ( alpha != testinghelpers::ZERO() || alpha != testinghelpers::ONE() ) && !(std::isnan(alpha)) ) + thresh = testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + // Note: is_memory_test is passed as false(hard-coded), since memory tests are done in _generic.cpp files + test_omatcopy( storage, trans, m, n, alpha, lda_inc, ldb_inc, thresh, false, is_nan_inf_test, exval ); +} + +#if defined(TEST_BLAS_LIKE) && (defined(REF_IS_MKL) || defined(REF_IS_OPENBLAS)) + +static double AOCL_NAN = std::numeric_limits::quiet_NaN(); +static double AOCL_INF = std::numeric_limits::infinity(); + +// EVT testing for domatcopy, with exception values in A matrix +INSTANTIATE_TEST_SUITE_P( + matrixA, + domatcopyEVT, + ::testing::Combine( + ::testing::Values('c'), // storage format(currently only for BLAS testing) + ::testing::Values('n', 't', 'r', 'c'), // trans(and/or conj) value + // 'n' - no-transpose, 't' - transpose + // 'r' - conjugate, 'c' - conjugate-transpose + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // m + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // n + ::testing::Values(2.0, -3.0, 1.0, 0.0), // alpha + ::testing::Values(gtint_t(0), gtint_t(25)), // increment of lda + ::testing::Values(gtint_t(0), gtint_t(17)), // increment of ldb + ::testing::Values(AOCL_NAN, AOCL_INF, -AOCL_INF), // exval + ::testing::Values(true) // is_nan_inf_test + ), + ::omatcopyEVTPrint() + ); + +// EVT testing for domatcopy, with exception values in alpha +INSTANTIATE_TEST_SUITE_P( + alpha, + domatcopyEVT, + ::testing::Combine( + ::testing::Values('c'), // storage format(currently only for BLAS testing) + ::testing::Values('n', 't', 'r', 'c'), // trans(and/or conj) value + // 'n' - no-transpose, 't' - transpose + // 'r' - conjugate, 'c' - conjugate-transpose + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // m + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // n + ::testing::Values(AOCL_NAN, AOCL_INF, -AOCL_INF), // alpha + ::testing::Values(gtint_t(0), gtint_t(25)), // increment of lda + ::testing::Values(gtint_t(0), gtint_t(17)), // increment of ldb + ::testing::Values(0.0), // exval + ::testing::Values(true) // is_nan_inf_test + ), + ::omatcopyEVTPrint() + ); +#endif diff --git a/gtestsuite/testsuite/extension/omatcopy/domatcopy_generic.cpp b/gtestsuite/testsuite/extension/omatcopy/domatcopy_generic.cpp new file mode 100644 index 0000000000..6a3eb7a4f0 --- /dev/null +++ b/gtestsuite/testsuite/extension/omatcopy/domatcopy_generic.cpp @@ -0,0 +1,105 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_omatcopy.h" + +class domatcopyGeneric : + public ::testing::TestWithParam> {}; // is_memory_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(domatcopyGeneric); + +// Tests using random numbers as vector elements. +TEST_P( domatcopyGeneric, API ) +{ + using T = double; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes the storage format of the input matrices + char storage = std::get<0>(GetParam()); + // denotes the trans value for the operation + char trans = std::get<1>(GetParam()); + // m dimension + gtint_t m = std::get<2>(GetParam()); + // n dimension + gtint_t n = std::get<3>(GetParam()); + // alpha + T alpha = std::get<4>(GetParam()); + // lda_inc for A + gtint_t lda_inc = std::get<5>(GetParam()); + // ldb_inc for B + gtint_t ldb_inc = std::get<6>(GetParam()); + // is_memory_test + bool is_memory_test = std::get<7>(GetParam()); + + double thresh = 0.0; + // Set the threshold for the errors + if( ( alpha != testinghelpers::ZERO() || alpha != testinghelpers::ONE() ) ) + thresh = testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_omatcopy( storage, trans, m, n, alpha, lda_inc, ldb_inc, thresh, is_memory_test ); +} + +#if defined(TEST_BLAS_LIKE) && (defined(REF_IS_MKL) || defined(REF_IS_OPENBLAS)) +// Black box testing for generic and main use of domatcopy. +INSTANTIATE_TEST_SUITE_P( + Blackbox, + domatcopyGeneric, + ::testing::Combine( + ::testing::Values('c'), // storage format(currently only for BLAS testing) + ::testing::Values('n', 't', 'r', 'c'), // trans(and/or conj) value + // 'n' - no-transpose, 't' - transpose + // 'r' - conjugate, 'c' - conjugate-transpose + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // m + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // n + ::testing::Values(2.0, -3.0, 1.0, 0.0), // alpha + ::testing::Values(gtint_t(0), gtint_t(25)), // increment of lda + ::testing::Values(gtint_t(0), gtint_t(17)), // increment of ldb + ::testing::Values(false, true) // is_memory_test + ), + ::omatcopyGenericPrint() + ); +#endif diff --git a/gtestsuite/testsuite/extension/omatcopy/omatcopy.h b/gtestsuite/testsuite/extension/omatcopy/omatcopy.h new file mode 100644 index 0000000000..56792b5e8f --- /dev/null +++ b/gtestsuite/testsuite/extension/omatcopy/omatcopy.h @@ -0,0 +1,125 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#pragma once + +#include "blis.h" +#include "common/testing_helpers.h" +#include "inc/check_error.h" + +/** + * @brief Performs the operation: + * B := alpha * op(A), + * where op(A) could be A, A(transpose), A(conjugate), A(conjugate-transpose) + * @param[in] m number of rows in A, number of rows/columns in B + * @param[in] n number of columns in A, number of columns/rows in B + * @param[in] alpha scalar + * @param[in] A pointer which points to the first element of A matrix + * @param[in] lda leading dimension of A matrix + * @param[in, out] B pointer which points to the first element of B matrix + * @param[in] ldb leading dimension of B matrix + */ + +template +static void omatcopy_( char trans, gtint_t m, gtint_t n, T alpha, T* A, gtint_t lda, T* B, gtint_t ldb ) +{ + if constexpr (std::is_same::value) + somatcopy_( &trans, &m, &n, (const float *)&alpha, A, &lda, B, &ldb ); + else if constexpr (std::is_same::value) + domatcopy_( &trans, &m, &n, (const double *)&alpha, A, &lda, B, &ldb ); + else if constexpr (std::is_same::value) + comatcopy_( &trans, &m, &n, (const scomplex *)&alpha, A, &lda, B, &ldb ); + else if constexpr (std::is_same::value) + zomatcopy_( &trans, &m, &n, (const dcomplex *)&alpha, A, &lda, B, &ldb ); + else + throw std::runtime_error("Error in testsuite/extension/omatcopy.h: Invalid typename in omatcopy_()."); +} + +template +static void omatcopy( char trans, gtint_t m, gtint_t n, T alpha, T* A, gtint_t lda, T* B, gtint_t ldb ) +{ +#ifdef TEST_UPPERCASE_ARGS + trans = static_cast(std::toupper(static_cast(trans))); +#endif + +#ifdef TEST_INPUT_ARGS + // Create copy of scalar input values so we can check that they are not altered. + char trans_cpy = trans; + gtint_t m_cpy = m; + gtint_t n_cpy = n; + T alpha_cpy = alpha; + gtint_t lda_cpy = lda; + gtint_t ldb_cpy = ldb; + + // Create copy of input arrays so we can check that they are not altered. + T* A_cpy = nullptr; + gtint_t size_A = testinghelpers::matsize( 'c', 'n', m, n, lda ); + + if (A && size_A > 0) + { + A_cpy = new T[size_A]; + memcpy( A_cpy, A, size_A * sizeof( T ) ); + } +#endif + +#ifdef TEST_BLAS_LIKE + omatcopy_( trans, m, n, alpha, A, lda, B, ldb ); +#else + throw std::runtime_error("Error in testsuite/extension/omatcopy.h: No interfaces are set to be tested."); +#endif + +#ifdef TEST_INPUT_ARGS + //---------------------------------------------------------- + // Check scalar inputs have not been modified. + //---------------------------------------------------------- + + computediff( "trans", trans, trans_cpy ); + computediff( "m", m, m_cpy ); + computediff( "n", n, n_cpy ); + computediff( "alpha", alpha, alpha_cpy ); + computediff( "lda", lda, lda_cpy ); + computediff( "ldb", ldb, ldb_cpy ); + + //---------------------------------------------------------- + // Bitwise-wise check array inputs have not been modified. + //---------------------------------------------------------- + + if (A && size_A > 0) + { + computediff( "A", 'c', m, n, A, A_cpy, lda, true ); + delete[] A_cpy; + } +#endif +} + diff --git a/gtestsuite/testsuite/extension/omatcopy/omatcopy_IIT_ERS.cpp b/gtestsuite/testsuite/extension/omatcopy/omatcopy_IIT_ERS.cpp new file mode 100644 index 0000000000..b0c47ef9cb --- /dev/null +++ b/gtestsuite/testsuite/extension/omatcopy/omatcopy_IIT_ERS.cpp @@ -0,0 +1,268 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_omatcopy.h" +#include "common/wrong_inputs_helpers.h" +#include "common/testing_helpers.h" +#include "inc/check_error.h" + +template +class omatcopy_IIT_ERS : public ::testing::Test {}; +typedef ::testing::Types TypeParam; +TYPED_TEST_SUITE(omatcopy_IIT_ERS, TypeParam); + +using namespace testinghelpers::IIT; + +#if defined(TEST_BLAS_LIKE) + +/* + Incorrect Input Testing(IIT) + + The exceptions get triggered in the following cases: + 1. When TRANS != 'n' || TRANS != 't' || TRANS != 'c' || TRANS != 'r' + 2. When m < 0 + 3. When n < 0 + 4. When lda < max(1, m). + 5. When ldb < max(1, thresh), thresh set based on TRANS value +*/ + +// When TRANS is invalid +TYPED_TEST(omatcopy_IIT_ERS, invalid_transa) +{ + using T = TypeParam; + T alpha; + testinghelpers::initone( alpha ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + omatcopy( 'Q', M, N, alpha, nullptr, LDA, nullptr, LDB); + + // Test with all arguments correct except for the value we are choosing to test. + // Defining the A and B matrices with values for debugging purposes + std::vector A = testinghelpers::get_random_matrix(-10, 10, 'c', 'n', M, N, LDA ); + std::vector B = testinghelpers::get_random_matrix(-10, 10, 'c', 'n', M, N, LDB ); + // Copy so that we check that the elements of B are not modified. + std::vector B_ref(B); + + // Call OMATCOPY with a invalid value for TRANS value for the operation. + omatcopy( 'Q', M, N, alpha, A.data(), LDA, B.data(), LDB); + // Use bitwise comparison (no threshold). + computediff( "B", 'c', M, N, B.data(), B_ref.data(), LDB ); +} + +// When m < 0 +TYPED_TEST(omatcopy_IIT_ERS, m_lt_zero) +{ + using T = TypeParam; + T alpha; + testinghelpers::initone( alpha ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + omatcopy( TRANS, -1, N, alpha, nullptr, LDA, nullptr, LDB); + + // Test with all arguments correct except for the value we are choosing to test. + // Defining the A and B matrices with values for debugging purposes + std::vector A = testinghelpers::get_random_matrix(-10, 10, 'c', 'n', M, N, LDA ); + std::vector B = testinghelpers::get_random_matrix(-10, 10, 'c', 'n', M, N, LDB ); + // Copy so that we check that the elements of B are not modified. + std::vector B_ref(B); + + // Call OMATCOPY with a invalid m for the operation. + omatcopy( TRANS, -1, N, alpha, A.data(), LDA, B.data(), LDB); + // Use bitwise comparison (no threshold). + computediff( "B", 'c', M, N, B.data(), B_ref.data(), LDB ); +} + +// When n < 0 +TYPED_TEST(omatcopy_IIT_ERS, n_lt_zero) +{ + using T = TypeParam; + T alpha; + testinghelpers::initone( alpha ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + omatcopy( TRANS, M, -1, alpha, nullptr, LDA, nullptr, LDB); + + // Test with all arguments correct except for the value we are choosing to test. + // Defining the A and B matrices with values for debugging purposes + std::vector A = testinghelpers::get_random_matrix(-10, 10, 'c', 'n', M, N, LDA ); + std::vector B = testinghelpers::get_random_matrix(-10, 10, 'c', 'n', M, N, LDB ); + // Copy so that we check that the elements of B are not modified. + std::vector B_ref(B); + + // Call OMATCOPY with a invalid n for the operation. + omatcopy( TRANS, M, -1, alpha, A.data(), LDA, B.data(), LDB); + // Use bitwise comparison (no threshold). + computediff( "B", 'c', M, N, B.data(), B_ref.data(), LDB ); +} + +// When lda < m +TYPED_TEST(omatcopy_IIT_ERS, invalid_lda) +{ + using T = TypeParam; + T alpha; + testinghelpers::initone( alpha ); + + // Having different values for m and n + gtint_t m = 5; + gtint_t n = 10; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + omatcopy( 'n', m, n, alpha, nullptr, m - 1, nullptr, m); + + // Test with all arguments correct except for the value we are choosing to test. + // Defining the A and B matrices with values for debugging purposes + std::vector A = testinghelpers::get_random_matrix(-10, 10, 'c', 'n', m, n, m ); + std::vector B = testinghelpers::get_random_matrix(-10, 10, 'c', 'n', m, n, m ); + // Copy so that we check that the elements of B are not modified. + std::vector B_ref(B); + + // Call OMATCOPY with a invalid lda for the operation. + omatcopy( 'n', m, n, alpha, A.data(), m - 1, B.data(), m); + // Use bitwise comparison (no threshold). + computediff( "B", 'c', m, n, B.data(), B_ref.data(), m ); +} + +// When ldb < m, with trans == 'n' +TYPED_TEST(omatcopy_IIT_ERS, invalid_ldb_no_transpose) +{ + using T = TypeParam; + T alpha; + testinghelpers::initone( alpha ); + + // Having different values for m and n + gtint_t m = 5; + gtint_t n = 10; + char trans = 'n'; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + omatcopy( trans, m, n, alpha, nullptr, m, nullptr, m - 1 ); + + // Test with all arguments correct except for the value we are choosing to test. + // Defining the A and B matrices with values for debugging purposes + std::vector A = testinghelpers::get_random_matrix(-10, 10, 'c', 'n', m, n, m ); + std::vector B = testinghelpers::get_random_matrix(-10, 10, 'c', 'n', m, n, m ); + // Copy so that we check that the elements of B are not modified. + std::vector B_ref(B); + + // Call OMATCOPY with a invalid ldb for the operation. + omatcopy( trans, m, n, alpha, A.data(), m, B.data(), m - 1 ); + // Use bitwise comparison (no threshold). + computediff( "B", 'c', m, n, B.data(), B_ref.data(), m ); +} + +// When ldb < m, with trans == 'r' +TYPED_TEST(omatcopy_IIT_ERS, invalid_ldb_conjugate) +{ + using T = TypeParam; + T alpha; + testinghelpers::initone( alpha ); + + // Having different values for m and n + gtint_t m = 5; + gtint_t n = 10; + char trans = 'r'; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + omatcopy( trans, m, n, alpha, nullptr, m, nullptr, m - 1 ); + + // Test with all arguments correct except for the value we are choosing to test. + // Defining the A and B matrices with values for debugging purposes + std::vector A = testinghelpers::get_random_matrix(-10, 10, 'c', 'n', m, n, m ); + std::vector B = testinghelpers::get_random_matrix(-10, 10, 'c', 'n', m, n, m ); + // Copy so that we check that the elements of B are not modified. + std::vector B_ref(B); + + // Call OMATCOPY with a invalid ldb for the operation. + omatcopy( trans, m, n, alpha, A.data(), m, B.data(), m - 1 ); + // Use bitwise comparison (no threshold). + computediff( "B", 'c', m, n, B.data(), B_ref.data(), m ); +} + +// When ldb < m, with trans == 't' +TYPED_TEST(omatcopy_IIT_ERS, invalid_ldb_transpose) +{ + using T = TypeParam; + T alpha; + testinghelpers::initone( alpha ); + + // Having different values for m and n + gtint_t m = 5; + gtint_t n = 10; + char trans = 't'; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + omatcopy( trans, m, n, alpha, nullptr, m, nullptr, n - 1 ); + + // Test with all arguments correct except for the value we are choosing to test. + // Defining the A and B matrices with values for debugging purposes + std::vector A = testinghelpers::get_random_matrix(-10, 10, 'c', 'n', m, n, m ); + std::vector B = testinghelpers::get_random_matrix(-10, 10, 'c', 't', m, n, n ); + // Copy so that we check that the elements of B are not modified. + std::vector B_ref(B); + + // Call OMATCOPY with a invalid ldb for the operation. + omatcopy( trans, m, n, alpha, A.data(), m, B.data(), n - 1 ); + // Use bitwise comparison (no threshold). + computediff( "B", 'c', n, m, B.data(), B_ref.data(), n ); +} + +// When ldb < m, with trans == 'c' +TYPED_TEST(omatcopy_IIT_ERS, invalid_ldb_conjugate_transpose) +{ + using T = TypeParam; + T alpha; + testinghelpers::initone( alpha ); + + // Having different values for m and n + gtint_t m = 5; + gtint_t n = 10; + char trans = 'c'; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + omatcopy( trans, m, n, alpha, nullptr, m, nullptr, n - 1 ); + + // Test with all arguments correct except for the value we are choosing to test. + // Defining the A and B matrices with values for debugging purposes + std::vector A = testinghelpers::get_random_matrix(-10, 10, 'c', 'n', m, n, m ); + std::vector B = testinghelpers::get_random_matrix(-10, 10, 'c', 't', m, n, n ); + // Copy so that we check that the elements of B are not modified. + std::vector B_ref(B); + + // Call OMATCOPY with a invalid ldb for the operation. + omatcopy( trans, m, n, alpha, A.data(), m, B.data(), n - 1 ); + // Use bitwise comparison (no threshold). + computediff( "B", 'c', n, m, B.data(), B_ref.data(), n ); +} +#endif diff --git a/gtestsuite/testsuite/extension/omatcopy/somatcopy_evt.cpp b/gtestsuite/testsuite/extension/omatcopy/somatcopy_evt.cpp new file mode 100644 index 0000000000..5ff4b59a0d --- /dev/null +++ b/gtestsuite/testsuite/extension/omatcopy/somatcopy_evt.cpp @@ -0,0 +1,134 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_omatcopy.h" + +class somatcopyEVT : + public ::testing::TestWithParam> {}; // is_nan_inf_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(somatcopyEVT); + +// Tests using random numbers as vector elements. +TEST_P( somatcopyEVT, API ) +{ + using T = float; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes the storage format of the input matrices + char storage = std::get<0>(GetParam()); + // denotes the trans value for the operation + char trans = std::get<1>(GetParam()); + // m dimension + gtint_t m = std::get<2>(GetParam()); + // n dimension + gtint_t n = std::get<3>(GetParam()); + // alpha + T alpha = std::get<4>(GetParam()); + // lda_inc for A + gtint_t lda_inc = std::get<5>(GetParam()); + // ldb_inc for B + gtint_t ldb_inc = std::get<6>(GetParam()); + // exval + T exval = std::get<7>(GetParam()); + // is_nan_inf_test + bool is_nan_inf_test = std::get<8>(GetParam()); + + double thresh = 0.0; + // Set the threshold for the errors + if( ( alpha != testinghelpers::ZERO() || alpha != testinghelpers::ONE() ) && !(std::isnan(alpha)) ) + thresh = testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + // Note: is_memory_test is passed as false(hard-coded), since memory tests are done in _generic.cpp files + test_omatcopy( storage, trans, m, n, alpha, lda_inc, ldb_inc, thresh, false, is_nan_inf_test, exval ); +} + +#if defined(TEST_BLAS_LIKE) && (defined(REF_IS_MKL) || defined(REF_IS_OPENBLAS)) + +static float AOCL_NAN = std::numeric_limits::quiet_NaN(); +static float AOCL_INF = std::numeric_limits::infinity(); + +// EVT testing for somatcopy, with exception values in A matrix +INSTANTIATE_TEST_SUITE_P( + matrixA, + somatcopyEVT, + ::testing::Combine( + ::testing::Values('c'), // storage format(currently only for BLAS testing) + ::testing::Values('n', 't', 'r', 'c'), // trans(and/or conj) value + // 'n' - no-transpose, 't' - transpose + // 'r' - conjugate, 'c' - conjugate-transpose + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // m + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // n + ::testing::Values(2.0f, -3.0f, 1.0f, 0.0f), // alpha + ::testing::Values(gtint_t(0), gtint_t(25)), // increment of lda + ::testing::Values(gtint_t(0), gtint_t(17)), // increment of ldb + ::testing::Values(AOCL_NAN, AOCL_INF, -AOCL_INF), // exval + ::testing::Values(true) // is_nan_inf_test + ), + ::omatcopyEVTPrint() + ); + +// EVT testing for somatcopy, with exception values in alpha +INSTANTIATE_TEST_SUITE_P( + alpha, + somatcopyEVT, + ::testing::Combine( + ::testing::Values('c'), // storage format(currently only for BLAS testing) + ::testing::Values('n', 't', 'r', 'c'), // trans(and/or conj) value + // 'n' - no-transpose, 't' - transpose + // 'r' - conjugate, 'c' - conjugate-transpose + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // m + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // n + ::testing::Values(AOCL_NAN, AOCL_INF, -AOCL_INF), // alpha + ::testing::Values(gtint_t(0), gtint_t(25)), // increment of lda + ::testing::Values(gtint_t(0), gtint_t(17)), // increment of ldb + ::testing::Values(0.0f), // exval + ::testing::Values(true) // is_nan_inf_test + ), + ::omatcopyEVTPrint() + ); +#endif diff --git a/gtestsuite/testsuite/extension/omatcopy/somatcopy_generic.cpp b/gtestsuite/testsuite/extension/omatcopy/somatcopy_generic.cpp new file mode 100644 index 0000000000..b74fab5d1f --- /dev/null +++ b/gtestsuite/testsuite/extension/omatcopy/somatcopy_generic.cpp @@ -0,0 +1,105 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_omatcopy.h" + +class somatcopyGeneric : + public ::testing::TestWithParam> {}; // is_memory_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(somatcopyGeneric); + +// Tests using random numbers as vector elements. +TEST_P( somatcopyGeneric, API ) +{ + using T = float; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes the storage format of the input matrices + char storage = std::get<0>(GetParam()); + // denotes the trans value for the operation + char trans = std::get<1>(GetParam()); + // m dimension + gtint_t m = std::get<2>(GetParam()); + // n dimension + gtint_t n = std::get<3>(GetParam()); + // alpha + T alpha = std::get<4>(GetParam()); + // lda_inc for A + gtint_t lda_inc = std::get<5>(GetParam()); + // ldb_inc for B + gtint_t ldb_inc = std::get<6>(GetParam()); + // is_memory_test + bool is_memory_test = std::get<7>(GetParam()); + + double thresh = 0.0; + // Set the threshold for the errors + if( ( alpha != testinghelpers::ZERO() || alpha != testinghelpers::ONE() ) ) + thresh = testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_omatcopy( storage, trans, m, n, alpha, lda_inc, ldb_inc, thresh, is_memory_test ); +} + +#if defined(TEST_BLAS_LIKE) && (defined(REF_IS_MKL) || defined(REF_IS_OPENBLAS)) +// Black box testing for generic and main use of somatcopy. +INSTANTIATE_TEST_SUITE_P( + Blackbox, + somatcopyGeneric, + ::testing::Combine( + ::testing::Values('c'), // storage format(currently only for BLAS testing) + ::testing::Values('n', 't', 'r', 'c'), // trans(and/or conj) value + // 'n' - no-transpose, 't' - transpose + // 'r' - conjugate, 'c' - conjugate-transpose + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // m + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // n + ::testing::Values(2.0f, -3.0f, 1.0f, 0.0f), // alpha + ::testing::Values(gtint_t(0), gtint_t(25)), // increment of lda + ::testing::Values(gtint_t(0), gtint_t(17)), // increment of ldb + ::testing::Values(false, true) // is_memory_test + ), + ::omatcopyGenericPrint() + ); +#endif diff --git a/gtestsuite/testsuite/extension/omatcopy/test_omatcopy.h b/gtestsuite/testsuite/extension/omatcopy/test_omatcopy.h new file mode 100644 index 0000000000..6291cd40b1 --- /dev/null +++ b/gtestsuite/testsuite/extension/omatcopy/test_omatcopy.h @@ -0,0 +1,204 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#pragma once + +#include "omatcopy.h" +#include "extension/ref_omatcopy.h" +#include "inc/check_error.h" +#include + +/** + * @brief Generic test body for omatcopy operation. + */ + +template +static void test_omatcopy( char storage, char trans, gtint_t m, gtint_t n, T alpha, gtint_t lda_inc, gtint_t ldb_inc, + double thresh, bool is_memory_test = false, bool is_nan_inf_test = false, T exval = T{0.0} ) +{ + // Set an alternative trans value that corresponds to only + // whether the B matrix should be mxn or nxm(only transposing) + char B_trans; + B_trans = ( ( trans == 'n' ) || ( trans == 'r' ) )? 'n' : 't'; + + // Compute the leading dimensions of A and B. + gtint_t lda = testinghelpers::get_leading_dimension( storage, 'n', m, n, lda_inc ); + gtint_t ldb = testinghelpers::get_leading_dimension( storage, B_trans, m, n, ldb_inc ); + + // Compute sizes of A and B, in bytes + gtint_t size_a = testinghelpers::matsize( storage, 'n', m, n, lda ) * sizeof( T ); + gtint_t size_b = testinghelpers::matsize( storage, B_trans, m, n, ldb ) * sizeof( T ); + + // Create the objects for the input and output operands + // The API does not expect the memory to be aligned + testinghelpers::ProtectedBuffer A_buf( size_a, false, is_memory_test ); + testinghelpers::ProtectedBuffer B_buf( size_b, false, is_memory_test ); + testinghelpers::ProtectedBuffer B_ref_buf( size_b, false, false ); + + // Pointers to access the memory chunks + T *A, *B, *B_ref; + + // Acquire the first set of greenzones for A and B + A = ( T* )A_buf.greenzone_1; + B = ( T* )B_buf.greenzone_1; + B_ref = ( T* )B_ref_buf.greenzone_1; // For B_ref, there is no greenzone_2 + + // Initialize the memory with random data + testinghelpers::datagenerators::randomgenerators( -10, 10, storage, m, n, A, 'n', lda ); + testinghelpers::datagenerators::randomgenerators( -10, 10, storage, m, n, B, B_trans, ldb ); + + if( is_nan_inf_test ) + { + gtint_t rand_m = rand() % m; + gtint_t rand_n = rand() % n; + gtint_t idx = ( storage == 'c' || storage == 'C' )? ( rand_m + rand_n * lda ) : ( rand_n + rand_m * lda ); + + A[idx] = exval; + } + // Copying the contents of B to B_ref + memcpy( B_ref, B, size_b ); + + // Add signal handler for segmentation fault + testinghelpers::ProtectedBuffer::start_signal_handler(); + try + { + // Call the API. + // This call is made irrespective of is_memory_test. + // This will check for out of bounds access with first redzone(if memory test is true) + // Else, it will just call the ukr function. + omatcopy( trans, m, n, alpha, A, lda, B, ldb); + + if ( is_memory_test ) + { + // Acquire the pointers near the second redzone + A = ( T* )A_buf.greenzone_2; + B = ( T* )B_buf.greenzone_2; + + // Copy the data for A and B accordingly + // NOTE : The objects for A and B will have acquired enough memory + // such that the greenzones in each do not overlap. + memcpy( A, A_buf.greenzone_1, size_a ); + memcpy( B, B_buf.greenzone_1, size_b ); + + // Call the API, to check with the second redzone. + omatcopy( trans, m, n, alpha, A, lda, B, ldb); + } + } + catch(const std::exception& e) + { + // Reset to default signal handler + testinghelpers::ProtectedBuffer::stop_signal_handler(); + + // Show failure in case seg fault was detected + FAIL() << "Memory Test Failed"; + } + // Reset to default signal handler + testinghelpers::ProtectedBuffer::stop_signal_handler(); + + //---------------------------------------------------------- + // Call reference implementation to get ref results. + //---------------------------------------------------------- + testinghelpers::ref_omatcopy( storage, trans, m, n, alpha, A, lda, B_ref, ldb ); + + //---------------------------------------------------------- + // Compute component-wise error. + //---------------------------------------------------------- + + if( B_trans == 'n' ) + computediff( "B", storage, m, n, B, B_ref, ldb, thresh, is_nan_inf_test ); + else + computediff( "B", storage, n, m, B, B_ref, ldb, thresh, is_nan_inf_test ); + +} + +// Test-case logger : Used to print the test-case details based on parameters +template +class omatcopyGenericPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char storage = std::get<0>(str.param); + char trans = std::get<1>(str.param); + gtint_t m = std::get<2>(str.param); + gtint_t n = std::get<3>(str.param); + T alpha = std::get<4>(str.param); + gtint_t lda_inc = std::get<5>(str.param); + gtint_t ldb_inc = std::get<6>(str.param); + bool is_memory_test = std::get<7>(str.param); + + std::string str_name = API_PRINT; + str_name += "_stor_" + std::string(&storage, 1); + str_name += "_trans_" + std::string(&trans, 1); + str_name += "_m_" + std::to_string(m); + str_name += "_n_" + std::to_string(n); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + gtint_t lda = testinghelpers::get_leading_dimension( storage, 'n', m, n, lda_inc ); + gtint_t ldb = testinghelpers::get_leading_dimension( storage, trans, m, n, ldb_inc ); + str_name += "_lda" + std::to_string(lda); + str_name += "_ldb" + std::to_string(ldb); + str_name += ( is_memory_test ) ? "_mem_test_enabled" : "_mem_test_disabled"; + + return str_name; + } +}; + +template +class omatcopyEVTPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char storage = std::get<0>(str.param); + char trans = std::get<1>(str.param); + gtint_t m = std::get<2>(str.param); + gtint_t n = std::get<3>(str.param); + T alpha = std::get<4>(str.param); + gtint_t lda_inc = std::get<5>(str.param); + gtint_t ldb_inc = std::get<6>(str.param); + T exval = std::get<7>(str.param); + + std::string str_name = API_PRINT; + str_name += "_stor_" + std::string(&storage, 1); + str_name += "_trans_" + std::string(&trans, 1); + str_name += "_m_" + std::to_string(m); + str_name += "_n_" + std::to_string(n); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + str_name = str_name + "_A_exval_" + testinghelpers::get_value_string(exval); + gtint_t lda = testinghelpers::get_leading_dimension( storage, 'n', m, n, lda_inc ); + gtint_t ldb = testinghelpers::get_leading_dimension( storage, trans, m, n, ldb_inc ); + str_name += "_lda" + std::to_string(lda); + str_name += "_ldb" + std::to_string(ldb); + + return str_name; + } +}; diff --git a/gtestsuite/testsuite/extension/omatcopy/zomatcopy_evt.cpp b/gtestsuite/testsuite/extension/omatcopy/zomatcopy_evt.cpp new file mode 100644 index 0000000000..a4aa1f5495 --- /dev/null +++ b/gtestsuite/testsuite/extension/omatcopy/zomatcopy_evt.cpp @@ -0,0 +1,137 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_omatcopy.h" + +class zomatcopyEVT : + public ::testing::TestWithParam> {}; // is_nan_inf_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(zomatcopyEVT); + +// Tests using random numbers as vector elements. +TEST_P( zomatcopyEVT, API ) +{ + using T = dcomplex; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes the storage format of the input matrices + char storage = std::get<0>(GetParam()); + // denotes the trans value for the operation + char trans = std::get<1>(GetParam()); + // m dimension + gtint_t m = std::get<2>(GetParam()); + // n dimension + gtint_t n = std::get<3>(GetParam()); + // alpha + T alpha = std::get<4>(GetParam()); + // lda_inc for A + gtint_t lda_inc = std::get<5>(GetParam()); + // ldb_inc for B + gtint_t ldb_inc = std::get<6>(GetParam()); + // exval + T exval = std::get<7>(GetParam()); + // is_nan_inf_test + bool is_nan_inf_test = std::get<8>(GetParam()); + + double thresh = 0.0; + // Set the threshold for the errors + if( ( alpha != testinghelpers::ZERO() || alpha != testinghelpers::ONE() ) && !(std::isnan(alpha.real) || std::isnan(alpha.imag)) ) + thresh = 3 * testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + // Note: is_memory_test is passed as false(hard-coded), since memory tests are done in _generic.cpp files + test_omatcopy( storage, trans, m, n, alpha, lda_inc, ldb_inc, thresh, false, is_nan_inf_test, exval ); +} + +#if defined(TEST_BLAS_LIKE) && (defined(REF_IS_MKL) || defined(REF_IS_OPENBLAS)) + +static double AOCL_NAN = std::numeric_limits::quiet_NaN(); +static double AOCL_INF = std::numeric_limits::infinity(); + +// EVT testing for zomatcopy, with exception values in A matrix +INSTANTIATE_TEST_SUITE_P( + matrixA, + zomatcopyEVT, + ::testing::Combine( + ::testing::Values('c'), // storage format(currently only for BLAS testing) + ::testing::Values('n', 't', 'r', 'c'), // trans(and/or conj) value + // 'n' - no-transpose, 't' - transpose + // 'r' - conjugate, 'c' - conjugate-transpose + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // m + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // n + ::testing::Values(dcomplex{2.3, -3.5}, dcomplex{1.0, 0.0}, + dcomplex{0.0, 0.0}), // alpha + ::testing::Values(gtint_t(0), gtint_t(25)), // increment of lda + ::testing::Values(gtint_t(0), gtint_t(17)), // increment of ldb + ::testing::Values(dcomplex{AOCL_INF, 0.0}, dcomplex{0.0, -AOCL_INF}, + dcomplex{0.0, AOCL_NAN}, dcomplex{AOCL_NAN, AOCL_INF}), // exval + ::testing::Values(true) // is_nan_inf_test + ), + ::omatcopyEVTPrint() + ); + +// EVT testing for zomatcopy, with exception values in alpha +INSTANTIATE_TEST_SUITE_P( + alpha, + zomatcopyEVT, + ::testing::Combine( + ::testing::Values('c'), // storage format(currently only for BLAS testing) + ::testing::Values('n', 't', 'r', 'c'), // trans(and/or conj) value + // 'n' - no-transpose, 't' - transpose + // 'r' - conjugate, 'c' - conjugate-transpose + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // m + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // n + ::testing::Values(dcomplex{AOCL_INF, 0.0}, dcomplex{0.0, -AOCL_INF}, + dcomplex{0.0, AOCL_NAN}, dcomplex{AOCL_NAN, AOCL_INF}), // alpha + ::testing::Values(gtint_t(0), gtint_t(25)), // increment of lda + ::testing::Values(gtint_t(0), gtint_t(17)), // increment of ldb + ::testing::Values(dcomplex{0.0, 0.0}), // exval + ::testing::Values(true) // is_nan_inf_test + ), + ::omatcopyEVTPrint() + ); +#endif diff --git a/gtestsuite/testsuite/extension/omatcopy/zomatcopy_generic.cpp b/gtestsuite/testsuite/extension/omatcopy/zomatcopy_generic.cpp new file mode 100644 index 0000000000..41409094a9 --- /dev/null +++ b/gtestsuite/testsuite/extension/omatcopy/zomatcopy_generic.cpp @@ -0,0 +1,106 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_omatcopy.h" + +class zomatcopyGeneric : + public ::testing::TestWithParam> {}; // is_memory_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(zomatcopyGeneric); + +// Tests using random numbers as vector elements. +TEST_P( zomatcopyGeneric, API ) +{ + using T = dcomplex; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes the storage format of the input matrices + char storage = std::get<0>(GetParam()); + // denotes the trans value for the operation + char trans = std::get<1>(GetParam()); + // m dimension + gtint_t m = std::get<2>(GetParam()); + // n dimension + gtint_t n = std::get<3>(GetParam()); + // alpha + T alpha = std::get<4>(GetParam()); + // lda_inc for A + gtint_t lda_inc = std::get<5>(GetParam()); + // ldb_inc for B + gtint_t ldb_inc = std::get<6>(GetParam()); + // is_memory_test + bool is_memory_test = std::get<7>(GetParam()); + + double thresh = 0.0; + // Set the threshold for the errors + if( ( alpha != testinghelpers::ZERO() || alpha != testinghelpers::ONE() ) ) + thresh = 3 * testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_omatcopy( storage, trans, m, n, alpha, lda_inc, ldb_inc, thresh, is_memory_test ); +} + +#if defined(TEST_BLAS_LIKE) && (defined(REF_IS_MKL) || defined(REF_IS_OPENBLAS)) +// Black box testing for generic and main use of zomatcopy. +INSTANTIATE_TEST_SUITE_P( + Blackbox, + zomatcopyGeneric, + ::testing::Combine( + ::testing::Values('c'), // storage format(currently only for BLAS testing) + ::testing::Values('n', 't', 'r', 'c'), // trans(and/or conj) value + // 'n' - no-transpose, 't' - transpose + // 'r' - conjugate, 'c' - conjugate-transpose + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // m + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // n + ::testing::Values(dcomplex{2.3, -3.5}, dcomplex{-3.1, 1.7}, + dcomplex{1.0, 0.0}, dcomplex{0.0, 0.0}), // alpha + ::testing::Values(gtint_t(0), gtint_t(25)), // increment of lda + ::testing::Values(gtint_t(0), gtint_t(17)), // increment of ldb + ::testing::Values(false, true) // is_memory_test + ), + ::omatcopyGenericPrint() + ); +#endif diff --git a/gtestsuite/testsuite/extension/omatcopy2/comatcopy2_evt.cpp b/gtestsuite/testsuite/extension/omatcopy2/comatcopy2_evt.cpp new file mode 100644 index 0000000000..e4862e4311 --- /dev/null +++ b/gtestsuite/testsuite/extension/omatcopy2/comatcopy2_evt.cpp @@ -0,0 +1,147 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_omatcopy2.h" + +class comatcopy2EVT : + public ::testing::TestWithParam> {}; // is_nan_inf_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(comatcopy2EVT); + +// Tests using random numbers as vector elements. +TEST_P( comatcopy2EVT, API ) +{ + using T = scomplex; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes the storage format of the input matrices + char storage = std::get<0>(GetParam()); + // denotes the trans value for the operation + char trans = std::get<1>(GetParam()); + // m dimension + gtint_t m = std::get<2>(GetParam()); + // n dimension + gtint_t n = std::get<3>(GetParam()); + // alpha + T alpha = std::get<4>(GetParam()); + // lda_inc for A + gtint_t lda_inc = std::get<5>(GetParam()); + // stridea + gtint_t stridea = std::get<6>(GetParam()); + // ldb_inc for B + gtint_t ldb_inc = std::get<7>(GetParam()); + // strideb + gtint_t strideb = std::get<8>(GetParam()); + // exval + T exval = std::get<9>(GetParam()); + // is_nan_inf_test + bool is_nan_inf_test = std::get<8>(GetParam()); + + double thresh = 0.0; + // Set the threshold for the errors + if( ( alpha != testinghelpers::ZERO() || alpha != testinghelpers::ONE() ) && !(std::isnan(alpha.real) || std::isnan(alpha.imag)) ) + thresh = 3 * testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + // Note: is_memory_test is passed as false(hard-coded), since memory tests are done in _generic.cpp files + test_omatcopy2( storage, trans, m, n, alpha, lda_inc, stridea, ldb_inc, strideb, thresh, false, is_nan_inf_test, exval ); +} + +#if defined(TEST_BLAS_LIKE) && defined(REF_IS_MKL) + +static float AOCL_NAN = std::numeric_limits::quiet_NaN(); +static float AOCL_INF = std::numeric_limits::infinity(); + +// EVT testing for comatcopy2, with exception values in A matrix +INSTANTIATE_TEST_SUITE_P( + matrixA, + comatcopy2EVT, + ::testing::Combine( + ::testing::Values('c'), // storage format(currently only for BLAS testing) + ::testing::Values('n', 't', 'r', 'c'), // trans(and/or conj) value + // 'n' - no-transpose, 't' - transpose + // 'r' - conjugate, 'c' - conjugate-transpose + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // m + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // n + ::testing::Values(scomplex{2.3, -3.5}, scomplex{1.0, 0.0}, + scomplex{0.0, 0.0}), // alpha + ::testing::Values(gtint_t(0), gtint_t(25)), // increment of lda + ::testing::Values(gtint_t(1), gtint_t(3)), // stridea + ::testing::Values(gtint_t(0), gtint_t(17)), // increment of ldb + ::testing::Values(gtint_t(1), gtint_t(3)), // strideb + ::testing::Values(scomplex{AOCL_INF, 0.0}, scomplex{0.0, -AOCL_INF}, + scomplex{0.0, AOCL_NAN}, scomplex{AOCL_NAN, AOCL_INF}), // exval + ::testing::Values(true) // is_nan_inf_test + ), + ::comatcopy2EVTPrint() + ); + +// EVT testing for comatcopy2, with exception values in alpha +INSTANTIATE_TEST_SUITE_P( + alpha, + comatcopy2EVT, + ::testing::Combine( + ::testing::Values('c'), // storage format(currently only for BLAS testing) + ::testing::Values('n', 't', 'r', 'c'), // trans(and/or conj) value + // 'n' - no-transpose, 't' - transpose + // 'r' - conjugate, 'c' - conjugate-transpose + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // m + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // n + ::testing::Values(scomplex{AOCL_INF, 0.0}, scomplex{0.0, -AOCL_INF}, + scomplex{0.0, AOCL_NAN}, scomplex{AOCL_NAN, AOCL_INF}), // alpha + ::testing::Values(gtint_t(0), gtint_t(25)), // increment of lda + ::testing::Values(gtint_t(1), gtint_t(3)), // stridea + ::testing::Values(gtint_t(0), gtint_t(17)), // increment of ldb + ::testing::Values(gtint_t(1), gtint_t(3)), // strideb + ::testing::Values(scomplex{0.0, 0.0}), // exval + ::testing::Values(true) // is_nan_inf_test + ), + ::comatcopy2EVTPrint() + ); +#endif diff --git a/gtestsuite/testsuite/extension/omatcopy2/comatcopy2_generic.cpp b/gtestsuite/testsuite/extension/omatcopy2/comatcopy2_generic.cpp new file mode 100644 index 0000000000..4e0b20806f --- /dev/null +++ b/gtestsuite/testsuite/extension/omatcopy2/comatcopy2_generic.cpp @@ -0,0 +1,114 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_omatcopy2.h" + +class comatcopy2Generic : + public ::testing::TestWithParam> {}; // is_memory_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(comatcopy2Generic); + +// Tests using random numbers as vector elements. +TEST_P( comatcopy2Generic, API ) +{ + using T = scomplex; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes the storage format of the input matrices + char storage = std::get<0>(GetParam()); + // denotes the trans value for the operation + char trans = std::get<1>(GetParam()); + // m dimension + gtint_t m = std::get<2>(GetParam()); + // n dimension + gtint_t n = std::get<3>(GetParam()); + // alpha + T alpha = std::get<4>(GetParam()); + // lda_inc for A + gtint_t lda_inc = std::get<5>(GetParam()); + // stridea + gtint_t stridea = std::get<6>(GetParam()); + // ldb_inc for B + gtint_t ldb_inc = std::get<7>(GetParam()); + // strideb + gtint_t strideb = std::get<8>(GetParam()); + // is_memory_test + bool is_memory_test = std::get<9>(GetParam()); + + double thresh = 0.0; + // Set the threshold for the errors + if( ( alpha != testinghelpers::ZERO() || alpha != testinghelpers::ONE() ) ) + thresh = 3 * testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_omatcopy2( storage, trans, m, n, alpha, lda_inc, stridea, ldb_inc, strideb, thresh, is_memory_test ); +} + +#if defined(TEST_BLAS_LIKE) && defined(REF_IS_MKL) +// Black box testing for generic and main use of comatcopy2. +INSTANTIATE_TEST_SUITE_P( + Blackbox, + comatcopy2Generic, + ::testing::Combine( + ::testing::Values('c'), // storage format(currently only for BLAS testing) + ::testing::Values('n', 't', 'r', 'c'), // trans(and/or conj) value + // 'n' - no-transpose, 't' - transpose + // 'r' - conjugate, 'c' - conjugate-transpose + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // m + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // n + ::testing::Values(scomplex{2.3, -3.5}, scomplex{-3.1, 1.7}, + scomplex{1.0, 0.0}, scomplex{0.0, 0.0}), // alpha + ::testing::Values(gtint_t(0), gtint_t(25)), // increment of lda + ::testing::Values(gtint_t(1), gtint_t(3)), // stridea + ::testing::Values(gtint_t(0), gtint_t(25)), // increment of ldb + ::testing::Values(gtint_t(1), gtint_t(3)), // strideb + ::testing::Values(false, true) // is_memory_test + ), + ::omatcopy2GenericPrint() + ); +#endif diff --git a/gtestsuite/testsuite/extension/omatcopy2/domatcopy2_evt.cpp b/gtestsuite/testsuite/extension/omatcopy2/domatcopy2_evt.cpp new file mode 100644 index 0000000000..cbff6377f4 --- /dev/null +++ b/gtestsuite/testsuite/extension/omatcopy2/domatcopy2_evt.cpp @@ -0,0 +1,144 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_omatcopy2.h" + +class domatcopy2EVT : + public ::testing::TestWithParam> {}; // is_nan_inf_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(domatcopy2EVT); + +// Tests using random numbers as vector elements. +TEST_P( domatcopy2EVT, API ) +{ + using T = double; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes the storage format of the input matrices + char storage = std::get<0>(GetParam()); + // denotes the trans value for the operation + char trans = std::get<1>(GetParam()); + // m dimension + gtint_t m = std::get<2>(GetParam()); + // n dimension + gtint_t n = std::get<3>(GetParam()); + // alpha + T alpha = std::get<4>(GetParam()); + // lda_inc for A + gtint_t lda_inc = std::get<5>(GetParam()); + // stridea + gtint_t stridea = std::get<6>(GetParam()); + // ldb_inc for B + gtint_t ldb_inc = std::get<7>(GetParam()); + // strideb + gtint_t strideb = std::get<8>(GetParam()); + // exval + T exval = std::get<9>(GetParam()); + // is_nan_inf_test + bool is_nan_inf_test = std::get<8>(GetParam()); + + double thresh = 0.0; + // Set the threshold for the errors + if( ( alpha != testinghelpers::ZERO() || alpha != testinghelpers::ONE() ) && !(std::isnan(alpha)) ) + thresh = testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + // Note: is_memory_test is passed as false(hard-coded), since memory tests are done in _generic.cpp files + test_omatcopy2( storage, trans, m, n, alpha, lda_inc, stridea, ldb_inc, strideb, thresh, false, is_nan_inf_test, exval ); +} + +#if defined(TEST_BLAS_LIKE) && defined(REF_IS_MKL) + +static double AOCL_NAN = std::numeric_limits::quiet_NaN(); +static double AOCL_INF = std::numeric_limits::infinity(); + +// EVT testing for domatcopy2, with exception values in A matrix +INSTANTIATE_TEST_SUITE_P( + matrixA, + domatcopy2EVT, + ::testing::Combine( + ::testing::Values('c'), // storage format(currently only for BLAS testing) + ::testing::Values('n', 't', 'r', 'c'), // trans(and/or conj) value + // 'n' - no-transpose, 't' - transpose + // 'r' - conjugate, 'c' - conjugate-transpose + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // m + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // n + ::testing::Values(2.0, -3.0, 1.0, 0.0), // alpha + ::testing::Values(gtint_t(0), gtint_t(25)), // increment of lda + ::testing::Values(gtint_t(1), gtint_t(3)), // stridea + ::testing::Values(gtint_t(0), gtint_t(17)), // increment of ldb + ::testing::Values(gtint_t(1), gtint_t(3)), // strideb + ::testing::Values(AOCL_NAN, AOCL_INF, -AOCL_INF), // exval + ::testing::Values(true) // is_nan_inf_test + ), + ::comatcopy2EVTPrint() + ); + +// EVT testing for domatcopy2, with exception values in alpha +INSTANTIATE_TEST_SUITE_P( + alpha, + domatcopy2EVT, + ::testing::Combine( + ::testing::Values('c'), // storage format(currently only for BLAS testing) + ::testing::Values('n', 't', 'r', 'c'), // trans(and/or conj) value + // 'n' - no-transpose, 't' - transpose + // 'r' - conjugate, 'c' - conjugate-transpose + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // m + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // n + ::testing::Values(AOCL_NAN, AOCL_INF, -AOCL_INF), // alpha + ::testing::Values(gtint_t(0), gtint_t(25)), // increment of lda + ::testing::Values(gtint_t(1), gtint_t(3)), // stridea + ::testing::Values(gtint_t(0), gtint_t(17)), // increment of ldb + ::testing::Values(gtint_t(1), gtint_t(3)), // strideb + ::testing::Values(0.0), // exval + ::testing::Values(true) // is_nan_inf_test + ), + ::comatcopy2EVTPrint() + ); +#endif diff --git a/gtestsuite/testsuite/extension/omatcopy2/domatcopy2_generic.cpp b/gtestsuite/testsuite/extension/omatcopy2/domatcopy2_generic.cpp new file mode 100644 index 0000000000..3869e6a83e --- /dev/null +++ b/gtestsuite/testsuite/extension/omatcopy2/domatcopy2_generic.cpp @@ -0,0 +1,113 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_omatcopy2.h" + +class domatcopy2Generic : + public ::testing::TestWithParam> {}; // is_memory_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(domatcopy2Generic); + +// Tests using random numbers as vector elements. +TEST_P( domatcopy2Generic, API ) +{ + using T = double; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes the storage format of the input matrices + char storage = std::get<0>(GetParam()); + // denotes the trans value for the operation + char trans = std::get<1>(GetParam()); + // m dimension + gtint_t m = std::get<2>(GetParam()); + // n dimension + gtint_t n = std::get<3>(GetParam()); + // alpha + T alpha = std::get<4>(GetParam()); + // lda_inc for A + gtint_t lda_inc = std::get<5>(GetParam()); + // stridea + gtint_t stridea = std::get<6>(GetParam()); + // ldb_inc for B + gtint_t ldb_inc = std::get<7>(GetParam()); + // strideb + gtint_t strideb = std::get<8>(GetParam()); + // is_memory_test + bool is_memory_test = std::get<9>(GetParam()); + + double thresh = 0.0; + // Set the threshold for the errors + if( ( alpha != testinghelpers::ZERO() || alpha != testinghelpers::ONE() ) ) + thresh = testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_omatcopy2( storage, trans, m, n, alpha, lda_inc, stridea, ldb_inc, strideb, thresh, is_memory_test ); +} + +#if defined(TEST_BLAS_LIKE) && defined(REF_IS_MKL) +// Black box testing for generic and main use of domatcopy2. +INSTANTIATE_TEST_SUITE_P( + Blackbox, + domatcopy2Generic, + ::testing::Combine( + ::testing::Values('c'), // storage format(currently only for BLAS testing) + ::testing::Values('n', 't', 'r', 'c'), // trans(and/or conj) value + // 'n' - no-transpose, 't' - transpose + // 'r' - conjugate, 'c' - conjugate-transpose + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // m + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // n + ::testing::Values(2.0, -3.0, 1.0, 0.0), // alpha + ::testing::Values(gtint_t(0), gtint_t(25)), // increment of lda + ::testing::Values(gtint_t(1), gtint_t(3)), // stridea + ::testing::Values(gtint_t(0), gtint_t(25)), // increment of ldb + ::testing::Values(gtint_t(1), gtint_t(3)), // strideb + ::testing::Values(false, true) // is_memory_test + ), + ::omatcopy2GenericPrint() + ); +#endif diff --git a/gtestsuite/testsuite/extension/omatcopy2/omatcopy2.h b/gtestsuite/testsuite/extension/omatcopy2/omatcopy2.h new file mode 100644 index 0000000000..4ff6c226ee --- /dev/null +++ b/gtestsuite/testsuite/extension/omatcopy2/omatcopy2.h @@ -0,0 +1,130 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#pragma once + +#include "blis.h" +#include "common/testing_helpers.h" +#include "inc/check_error.h" + +/** + * @brief Performs the operation: + * B := alpha * op(A), + * where op(A) could be A, A(transpose), A(conjugate), A(conjugate-transpose) + * @param[in] m number of rows in A, number of rows/columns in B + * @param[in] n number of columns in A, number of columns/rows in B + * @param[in] alpha scalar + * @param[in] A pointer which points to the first element of A matrix + * @param[in] lda leading dimension of A matrix + * @param[in] stridea stride between two "continuous" elements in A + * @param[in, out] B pointer which points to the first element of B matrix + * @param[in] ldb leading dimension of B matrix + * @param[in] strideb stride between two "continuous" elements in B + */ + +template +static void omatcopy2_( char trans, gtint_t m, gtint_t n, T alpha, T* A, gtint_t lda, gtint_t stridea, T* B, gtint_t ldb, gtint_t strideb ) +{ + if constexpr (std::is_same::value) + somatcopy2_( &trans, &m, &n, (const float *)&alpha, A, &lda, &stridea, B, &ldb, &strideb ); + else if constexpr (std::is_same::value) + domatcopy2_( &trans, &m, &n, (const double *)&alpha, A, &lda, &stridea, B, &ldb, &strideb ); + else if constexpr (std::is_same::value) + comatcopy2_( &trans, &m, &n, (const scomplex *)&alpha, A, &lda, &stridea, B, &ldb, &strideb ); + else if constexpr (std::is_same::value) + zomatcopy2_( &trans, &m, &n, (const dcomplex *)&alpha, A, &lda, &stridea, B, &ldb, &strideb ); + else + throw std::runtime_error("Error in testsuite/extension/omatcopy2.h: Invalid typename in omatcopy2_()."); +} + +template +static void omatcopy2( char trans, gtint_t m, gtint_t n, T alpha, T* A, gtint_t lda, gtint_t stridea, T* B, gtint_t ldb, gtint_t strideb ) +{ +#ifdef TEST_UPPERCASE_ARGS + trans = static_cast(std::toupper(static_cast(trans))); +#endif + +#ifdef TEST_INPUT_ARGS + // Create copy of scalar input values so we can check that they are not altered. + char trans_cpy = trans; + gtint_t m_cpy = m; + gtint_t n_cpy = n; + T alpha_cpy = alpha; + gtint_t lda_cpy = lda; + gtint_t stridea_cpy = stridea; + gtint_t ldb_cpy = ldb; + gtint_t strideb_cpy = strideb; + + // Create copy of input arrays so we can check that they are not altered. + T* A_cpy = nullptr; + gtint_t size_A = testinghelpers::matsize( 'c', trans, m, n, lda ); + if (A && size_A > 0) + { + A_cpy = new T[size_A]; + memcpy( A_cpy, A, size_A * sizeof( T ) ); + } +#endif + +#ifdef TEST_BLAS_LIKE + omatcopy2_( trans, m, n, alpha, A, lda, stridea, B, ldb, strideb ); +#else + throw std::runtime_error("Error in testsuite/extension/omatcopy2.h: No interfaces are set to be tested."); +#endif + +#ifdef TEST_INPUT_ARGS + //---------------------------------------------------------- + // Check scalar inputs have not been modified. + //---------------------------------------------------------- + + computediff( "trans", trans, trans_cpy ); + computediff( "m", m, m_cpy ); + computediff( "n", n, n_cpy ); + computediff( "alpha", alpha, alpha_cpy ); + computediff( "lda", lda, lda_cpy ); + computediff( "stridea", stridea, stridea_cpy ); + computediff( "ldb", ldb, ldb_cpy ); + computediff( "strideb", strideb, strideb_cpy ); + + //---------------------------------------------------------- + // Bitwise-wise check array inputs have not been modified. + //---------------------------------------------------------- + + if (A && size_A > 0) + { + computediff( "A", 'c', m, n, A, A_cpy, lda, true ); + delete[] A_cpy; + } +#endif +} + diff --git a/gtestsuite/testsuite/extension/omatcopy2/omatcopy2_IIT_ERS.cpp b/gtestsuite/testsuite/extension/omatcopy2/omatcopy2_IIT_ERS.cpp new file mode 100644 index 0000000000..0b7d9c1089 --- /dev/null +++ b/gtestsuite/testsuite/extension/omatcopy2/omatcopy2_IIT_ERS.cpp @@ -0,0 +1,316 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_omatcopy2.h" +#include "common/wrong_inputs_helpers.h" +#include "common/testing_helpers.h" +#include "inc/check_error.h" + +template +class omatcopy2_IIT_ERS : public ::testing::Test {}; +typedef ::testing::Types TypeParam; +TYPED_TEST_SUITE(omatcopy2_IIT_ERS, TypeParam); + +using namespace testinghelpers::IIT; + +#if defined(TEST_BLAS_LIKE) + +/* + Incorrect Input Testing(IIT) + + The exceptions get triggered in the following cases: + 1. When TRANS != 'n' || TRANS != 't' || TRANS != 'c' || TRANS != 'r' + 2. When m < 0 + 3. When n < 0 + 4. When lda < max(1, m). + 5. When stridea < 1. + 6. When ldb < max(1, thresh), thresh set based on TRANS value + 7. When strideb < 1. +*/ + +// When TRANS is invalid +TYPED_TEST(omatcopy2_IIT_ERS, invalid_transa) +{ + using T = TypeParam; + T alpha; + testinghelpers::initone( alpha ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + omatcopy2( 'Q', M, N, alpha, nullptr, LDA, 1, nullptr, LDB, 1 ); + + // Test with all arguments correct except for the value we are choosing to test. + // Defining the A and B matrices with values for debugging purposes + std::vector A = testinghelpers::get_random_matrix(-10, 10, 'c', 'n', M, N, LDA ); + std::vector B = testinghelpers::get_random_matrix(-10, 10, 'c', 'n', M, N, LDB ); + // Copy so that we check that the elements of B are not modified. + std::vector B_ref(B); + + // Call OMATCOPY2 with a invalid value for TRANS value for the operation. + omatcopy2( 'Q', M, N, alpha, A.data(), LDA, 1, B.data(), LDB, 1 ); + // Use bitwise comparison (no threshold). + computediff( "B", 'c', M, N, B.data(), B_ref.data(), LDB ); +} + +// When m < 0 +TYPED_TEST(omatcopy2_IIT_ERS, m_lt_zero) +{ + using T = TypeParam; + T alpha; + testinghelpers::initone( alpha ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + omatcopy2( TRANS, -1, N, alpha, nullptr, LDA, 1, nullptr, LDB, 1 ); + + // Test with all arguments correct except for the value we are choosing to test. + // Defining the A and B matrices with values for debugging purposes + std::vector A = testinghelpers::get_random_matrix(-10, 10, 'c', 'n', M, N, LDA ); + std::vector B = testinghelpers::get_random_matrix(-10, 10, 'c', 'n', M, N, LDB ); + // Copy so that we check that the elements of B are not modified. + std::vector B_ref(B); + + // Call OMATCOPY2 with a invalid m for the operation. + omatcopy2( TRANS, -1, N, alpha, A.data(), LDA, 1, B.data(), LDB, 1 ); + // Use bitwise comparison (no threshold). + computediff( "B", 'c', M, N, B.data(), B_ref.data(), LDB ); +} + +// When n < 0 +TYPED_TEST(omatcopy2_IIT_ERS, n_lt_zero) +{ + using T = TypeParam; + T alpha; + testinghelpers::initone( alpha ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + omatcopy2( TRANS, M, -1, alpha, nullptr, LDA, 1, nullptr, LDB, 1 ); + + // Test with all arguments correct except for the value we are choosing to test. + // Defining the A and B matrices with values for debugging purposes + std::vector A = testinghelpers::get_random_matrix(-10, 10, 'c', 'n', M, N, LDA ); + std::vector B = testinghelpers::get_random_matrix(-10, 10, 'c', 'n', M, N, LDB ); + // Copy so that we check that the elements of B are not modified. + std::vector B_ref(B); + + // Call OMATCOPY2 with a invalid n for the operation. + omatcopy2( TRANS, M, -1, alpha, A.data(), LDA, 1, B.data(), LDB, 1 ); + // Use bitwise comparison (no threshold). + computediff( "B", 'c', M, N, B.data(), B_ref.data(), LDB ); +} + +// When lda < m +TYPED_TEST(omatcopy2_IIT_ERS, invalid_lda) +{ + using T = TypeParam; + T alpha; + testinghelpers::initone( alpha ); + + // Having different values for m and n + gtint_t m = 5; + gtint_t n = 10; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + omatcopy2( 'n', m, n, alpha, nullptr, m - 1, 1, nullptr, m, 1 ); + + // Test with all arguments correct except for the value we are choosing to test. + // Defining the A and B matrices with values for debugging purposes + std::vector A = testinghelpers::get_random_matrix(-10, 10, 'c', 'n', m, n, m ); + std::vector B = testinghelpers::get_random_matrix(-10, 10, 'c', 'n', m, n, m ); + // Copy so that we check that the elements of B are not modified. + std::vector B_ref(B); + + // Call OMATCOPY2 with a invalid lda for the operation. + omatcopy2( 'n', m, n, alpha, A.data(), m - 1, 1, B.data(), m, 1 ); + // Use bitwise comparison (no threshold). + computediff( "B", 'c', m, n, B.data(), B_ref.data(), m ); +} + +// When stridea < 1 +TYPED_TEST(omatcopy2_IIT_ERS, invalid_stridea) +{ + using T = TypeParam; + T alpha; + testinghelpers::initone( alpha ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + omatcopy2( TRANS, M, N, alpha, nullptr, LDA, 0, nullptr, LDB, 1 ); + + // Test with all arguments correct except for the value we are choosing to test. + // Defining the A and B matrices with values for debugging purposes + std::vector A = testinghelpers::get_random_matrix(-10, 10, 'c', 'n', M, N, LDA ); + std::vector B = testinghelpers::get_random_matrix(-10, 10, 'c', 'n', M, N, LDB ); + // Copy so that we check that the elements of B are not modified. + std::vector B_ref(B); + + // Call OMATCOPY2 with a invalid n for the operation. + omatcopy2( TRANS, M, N, alpha, A.data(), LDA, 0, B.data(), LDB, 1 ); + // Use bitwise comparison (no threshold). + computediff( "B", 'c', M, N, B.data(), B_ref.data(), LDB ); +} + +// When ldb < m, with trans == 'n' +TYPED_TEST(omatcopy2_IIT_ERS, invalid_ldb_no_transpose) +{ + using T = TypeParam; + T alpha; + testinghelpers::initone( alpha ); + + // Having different values for m and n + gtint_t m = 5; + gtint_t n = 10; + char trans = 'n'; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + omatcopy2( trans, m, n, alpha, nullptr, m, 1, nullptr, m - 1, 1 ); + + // Test with all arguments correct except for the value we are choosing to test. + // Defining the A and B matrices with values for debugging purposes + std::vector A = testinghelpers::get_random_matrix(-10, 10, 'c', 'n', m, n, m ); + std::vector B = testinghelpers::get_random_matrix(-10, 10, 'c', 'n', m, n, m ); + // Copy so that we check that the elements of B are not modified. + std::vector B_ref(B); + + // Call OMATCOPY2 with a invalid ldb for the operation. + omatcopy2( trans, m, n, alpha, A.data(), m, 1, B.data(), m - 1, 1 ); + // Use bitwise comparison (no threshold). + computediff( "B", 'c', m, n, B.data(), B_ref.data(), m ); +} + +// When ldb < m, with trans == 'r' +TYPED_TEST(omatcopy2_IIT_ERS, invalid_ldb_conjugate) +{ + using T = TypeParam; + T alpha; + testinghelpers::initone( alpha ); + + // Having different values for m and n + gtint_t m = 5; + gtint_t n = 10; + char trans = 'r'; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + omatcopy2( trans, m, n, alpha, nullptr, m, 1, nullptr, m - 1, 1 ); + + // Test with all arguments correct except for the value we are choosing to test. + // Defining the A and B matrices with values for debugging purposes + std::vector A = testinghelpers::get_random_matrix(-10, 10, 'c', 'n', m, n, m ); + std::vector B = testinghelpers::get_random_matrix(-10, 10, 'c', 'n', m, n, m ); + // Copy so that we check that the elements of B are not modified. + std::vector B_ref(B); + + // Call OMATCOPY2 with a invalid ldb for the operation. + omatcopy2( trans, m, n, alpha, A.data(), m, 1, B.data(), m - 1, 1 ); + // Use bitwise comparison (no threshold). + computediff( "B", 'c', m, n, B.data(), B_ref.data(), m ); +} + +// When ldb < m, with trans == 't' +TYPED_TEST(omatcopy2_IIT_ERS, invalid_ldb_transpose) +{ + using T = TypeParam; + T alpha; + testinghelpers::initone( alpha ); + + // Having different values for m and n + gtint_t m = 5; + gtint_t n = 10; + char trans = 't'; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + omatcopy2( trans, m, n, alpha, nullptr, m, 1, nullptr, n - 1, 1 ); + + // Test with all arguments correct except for the value we are choosing to test. + // Defining the A and B matrices with values for debugging purposes + std::vector A = testinghelpers::get_random_matrix(-10, 10, 'c', 'n', m, n, m ); + std::vector B = testinghelpers::get_random_matrix(-10, 10, 'c', 't', m, n, n ); + // Copy so that we check that the elements of B are not modified. + std::vector B_ref(B); + + // Call OMATCOPY2 with a invalid ldb for the operation. + omatcopy2( trans, m, n, alpha, A.data(), m, 1, B.data(), n - 1, 1 ); + // Use bitwise comparison (no threshold). + computediff( "B", 'c', n, m, B.data(), B_ref.data(), n ); +} + +// When ldb < m, with trans == 'c' +TYPED_TEST(omatcopy2_IIT_ERS, invalid_ldb_conjugate_transpose) +{ + using T = TypeParam; + T alpha; + testinghelpers::initone( alpha ); + + // Having different values for m and n + gtint_t m = 5; + gtint_t n = 10; + char trans = 'c'; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + omatcopy2( trans, m, n, alpha, nullptr, m, 1, nullptr, n - 1, 1 ); + + // Test with all arguments correct except for the value we are choosing to test. + // Defining the A and B matrices with values for debugging purposes + std::vector A = testinghelpers::get_random_matrix(-10, 10, 'c', 'n', m, n, m ); + std::vector B = testinghelpers::get_random_matrix(-10, 10, 'c', 't', m, n, n ); + // Copy so that we check that the elements of B are not modified. + std::vector B_ref(B); + + // Call OMATCOPY2 with a invalid ldb for the operation. + omatcopy2( trans, m, n, alpha, A.data(), m, 1, B.data(), n - 1, 1 ); + // Use bitwise comparison (no threshold). + computediff( "B", 'c', n, m, B.data(), B_ref.data(), n ); +} + +// When strideb < 1 +TYPED_TEST(omatcopy2_IIT_ERS, invalid_strideb) +{ + using T = TypeParam; + T alpha; + testinghelpers::initone( alpha ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + omatcopy2( TRANS, M, N, alpha, nullptr, LDA, 1, nullptr, LDB, 0 ); + + // Test with all arguments correct except for the value we are choosing to test. + // Defining the A and B matrices with values for debugging purposes + std::vector A = testinghelpers::get_random_matrix(-10, 10, 'c', 'n', M, N, LDA ); + std::vector B = testinghelpers::get_random_matrix(-10, 10, 'c', 'n', M, N, LDB ); + // Copy so that we check that the elements of B are not modified. + std::vector B_ref(B); + + // Call OMATCOPY2 with a invalid n for the operation. + omatcopy2( TRANS, M, N, alpha, A.data(), LDA, 1, B.data(), LDB, 0 ); + // Use bitwise comparison (no threshold). + computediff( "B", 'c', M, N, B.data(), B_ref.data(), LDB ); +} +#endif diff --git a/gtestsuite/testsuite/extension/omatcopy2/somatcopy2_evt.cpp b/gtestsuite/testsuite/extension/omatcopy2/somatcopy2_evt.cpp new file mode 100644 index 0000000000..680db3bc98 --- /dev/null +++ b/gtestsuite/testsuite/extension/omatcopy2/somatcopy2_evt.cpp @@ -0,0 +1,144 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_omatcopy2.h" + +class somatcopy2EVT : + public ::testing::TestWithParam> {}; // is_nan_inf_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(somatcopy2EVT); + +// Tests using random numbers as vector elements. +TEST_P( somatcopy2EVT, API ) +{ + using T = float; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes the storage format of the input matrices + char storage = std::get<0>(GetParam()); + // denotes the trans value for the operation + char trans = std::get<1>(GetParam()); + // m dimension + gtint_t m = std::get<2>(GetParam()); + // n dimension + gtint_t n = std::get<3>(GetParam()); + // alpha + T alpha = std::get<4>(GetParam()); + // lda_inc for A + gtint_t lda_inc = std::get<5>(GetParam()); + // stridea + gtint_t stridea = std::get<6>(GetParam()); + // ldb_inc for B + gtint_t ldb_inc = std::get<7>(GetParam()); + // strideb + gtint_t strideb = std::get<8>(GetParam()); + // exval + T exval = std::get<9>(GetParam()); + // is_nan_inf_test + bool is_nan_inf_test = std::get<8>(GetParam()); + + double thresh = 0.0; + // Set the threshold for the errors + if( ( alpha != testinghelpers::ZERO() || alpha != testinghelpers::ONE() ) && !(std::isnan(alpha)) ) + thresh = testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + // Note: is_memory_test is passed as false(hard-coded), since memory tests are done in _generic.cpp files + test_omatcopy2( storage, trans, m, n, alpha, lda_inc, stridea, ldb_inc, strideb, thresh, false, is_nan_inf_test, exval ); +} + +#if defined(TEST_BLAS_LIKE) && defined(REF_IS_MKL) + +static float AOCL_NAN = std::numeric_limits::quiet_NaN(); +static float AOCL_INF = std::numeric_limits::infinity(); + +// EVT testing for somatcopy2, with exception values in A matrix +INSTANTIATE_TEST_SUITE_P( + matrixA, + somatcopy2EVT, + ::testing::Combine( + ::testing::Values('c'), // storage format(currently only for BLAS testing) + ::testing::Values('n', 't', 'r', 'c'), // trans(and/or conj) value + // 'n' - no-transpose, 't' - transpose + // 'r' - conjugate, 'c' - conjugate-transpose + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // m + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // n + ::testing::Values(2.0f, -3.0f, 1.0f, 0.0f), // alpha + ::testing::Values(gtint_t(0), gtint_t(25)), // increment of lda + ::testing::Values(gtint_t(1), gtint_t(3)), // stridea + ::testing::Values(gtint_t(0), gtint_t(17)), // increment of ldb + ::testing::Values(gtint_t(1), gtint_t(3)), // strideb + ::testing::Values(AOCL_NAN, AOCL_INF, -AOCL_INF), // exval + ::testing::Values(true) // is_nan_inf_test + ), + ::comatcopy2EVTPrint() + ); + +// EVT testing for somatcopy2, with exception values in alpha +INSTANTIATE_TEST_SUITE_P( + alpha, + somatcopy2EVT, + ::testing::Combine( + ::testing::Values('c'), // storage format(currently only for BLAS testing) + ::testing::Values('n', 't', 'r', 'c'), // trans(and/or conj) value + // 'n' - no-transpose, 't' - transpose + // 'r' - conjugate, 'c' - conjugate-transpose + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // m + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // n + ::testing::Values(AOCL_NAN, AOCL_INF, -AOCL_INF), // alpha + ::testing::Values(gtint_t(0), gtint_t(25)), // increment of lda + ::testing::Values(gtint_t(1), gtint_t(3)), // stridea + ::testing::Values(gtint_t(0), gtint_t(17)), // increment of ldb + ::testing::Values(gtint_t(1), gtint_t(3)), // strideb + ::testing::Values(0.0f), // exval + ::testing::Values(true) // is_nan_inf_test + ), + ::comatcopy2EVTPrint() + ); +#endif diff --git a/gtestsuite/testsuite/extension/omatcopy2/somatcopy2_generic.cpp b/gtestsuite/testsuite/extension/omatcopy2/somatcopy2_generic.cpp new file mode 100644 index 0000000000..38d425840f --- /dev/null +++ b/gtestsuite/testsuite/extension/omatcopy2/somatcopy2_generic.cpp @@ -0,0 +1,113 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_omatcopy2.h" + +class somatcopy2Generic : + public ::testing::TestWithParam> {}; // is_memory_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(somatcopy2Generic); + +// Tests using random numbers as vector elements. +TEST_P( somatcopy2Generic, API ) +{ + using T = float; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes the storage format of the input matrices + char storage = std::get<0>(GetParam()); + // denotes the trans value for the operation + char trans = std::get<1>(GetParam()); + // m dimension + gtint_t m = std::get<2>(GetParam()); + // n dimension + gtint_t n = std::get<3>(GetParam()); + // alpha + T alpha = std::get<4>(GetParam()); + // lda_inc for A + gtint_t lda_inc = std::get<5>(GetParam()); + // stridea + gtint_t stridea = std::get<6>(GetParam()); + // ldb_inc for B + gtint_t ldb_inc = std::get<7>(GetParam()); + // strideb + gtint_t strideb = std::get<8>(GetParam()); + // is_memory_test + bool is_memory_test = std::get<9>(GetParam()); + + double thresh = 0.0; + // Set the threshold for the errors + if( ( alpha != testinghelpers::ZERO() || alpha != testinghelpers::ONE() ) ) + thresh = testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_omatcopy2( storage, trans, m, n, alpha, lda_inc, stridea, ldb_inc, strideb, thresh, is_memory_test ); +} + +#if defined(TEST_BLAS_LIKE) && defined(REF_IS_MKL) +// Black box testing for generic and main use of somatcopy2. +INSTANTIATE_TEST_SUITE_P( + Blackbox, + somatcopy2Generic, + ::testing::Combine( + ::testing::Values('c'), // storage format(currently only for BLAS testing) + ::testing::Values('n', 't', 'r', 'c'), // trans(and/or conj) value + // 'n' - no-transpose, 't' - transpose + // 'r' - conjugate, 'c' - conjugate-transpose + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // m + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // n + ::testing::Values(2.0f, -3.0f, 1.0f, 0.0f), // alpha + ::testing::Values(gtint_t(0), gtint_t(25)), // increment of lda + ::testing::Values(gtint_t(1), gtint_t(3)), // stridea + ::testing::Values(gtint_t(0), gtint_t(25)), // increment of ldb + ::testing::Values(gtint_t(1), gtint_t(3)), // strideb + ::testing::Values(false, true) // is_memory_test + ), + ::omatcopy2GenericPrint() + ); +#endif diff --git a/gtestsuite/testsuite/extension/omatcopy2/test_omatcopy2.h b/gtestsuite/testsuite/extension/omatcopy2/test_omatcopy2.h new file mode 100644 index 0000000000..0287d2848b --- /dev/null +++ b/gtestsuite/testsuite/extension/omatcopy2/test_omatcopy2.h @@ -0,0 +1,213 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#pragma once + +#include "omatcopy2.h" +#include "extension/ref_omatcopy2.h" +#include "inc/check_error.h" +#include + +/** + * @brief Generic test body for omatcopy2 operation. + */ + +template +static void test_omatcopy2( char storage, char trans, gtint_t m, gtint_t n, T alpha, gtint_t lda_inc, gtint_t stridea, gtint_t ldb_inc, + gtint_t strideb, double thresh, bool is_memory_test = false, bool is_nan_inf_test = false, T exval = T{0.0} ) +{ + // Set an alternative trans value that corresponds to only + // whether the B matrix should be mxn or nxm(only transposing) + char B_trans; + B_trans = ( ( trans == 'n' ) || ( trans == 'r' ) )? 'n' : 't'; + + // Compute the leading dimensions of A and B. + gtint_t lda = testinghelpers::get_leading_dimension( storage, 'n', m, n, lda_inc, stridea ); + gtint_t ldb = testinghelpers::get_leading_dimension( storage, B_trans, m, n, ldb_inc, strideb ); + + // Compute sizes of A and B, in bytes + gtint_t size_a = testinghelpers::matsize( storage, 'n', m, n, lda ) * sizeof( T ); + gtint_t size_b = testinghelpers::matsize( storage, B_trans, m, n, ldb ) * sizeof( T ); + + // Create the objects for the input and output operands + // The API does not expect the memory to be aligned + testinghelpers::ProtectedBuffer A_buf( size_a, false, is_memory_test ); + testinghelpers::ProtectedBuffer B_buf( size_b, false, is_memory_test ); + testinghelpers::ProtectedBuffer B_ref_buf( size_b, false, false ); + + // Pointers to access the memory chunks + T *A, *B, *B_ref; + + // Acquire the first set of greenzones for A and B + A = ( T* )A_buf.greenzone_1; + B = ( T* )B_buf.greenzone_1; + B_ref = ( T* )B_ref_buf.greenzone_1; // For B_ref, there is no greenzone_2 + + // Initialize the memory with random data + testinghelpers::datagenerators::randomgenerators( -10, 10, storage, m, n, A, 'n', lda, stridea ); + testinghelpers::datagenerators::randomgenerators( -10, 10, storage, m, n, B, B_trans, ldb, strideb ); + + if( is_nan_inf_test ) + { + gtint_t rand_m = rand() % m; + gtint_t rand_n = rand() % n; + gtint_t idx = ( storage == 'c' || storage == 'C' )? ( rand_m * stridea + rand_n * lda ) : ( rand_n * stridea + rand_m * lda ); + + A[idx] = exval; + } + // Copying the contents of B to B_ref + memcpy( B_ref, B, size_b ); + + // Add signal handler for segmentation fault + testinghelpers::ProtectedBuffer::start_signal_handler(); + try + { + // Call the API. + // This call is made irrespective of is_memory_test. + // This will check for out of bounds access with first redzone(if memory test is true) + // Else, it will just call the ukr function. + omatcopy2( trans, m, n, alpha, A, lda, stridea, B, ldb, strideb ); + + if ( is_memory_test ) + { + // Acquire the pointers near the second redzone + A = ( T* )A_buf.greenzone_2; + B = ( T* )B_buf.greenzone_2; + + // Copy the data for A and B accordingly + // NOTE : The objects for A and B will have acquired enough memory + // such that the greenzones in each do not overlap. + memcpy( A, A_buf.greenzone_1, size_a ); + memcpy( B, B_buf.greenzone_1, size_b ); + + // Call the API, to check with the second redzone. + omatcopy2( trans, m, n, alpha, A, lda, stridea, B, ldb, strideb ); + } + } + catch(const std::exception& e) + { + // Reset to default signal handler + testinghelpers::ProtectedBuffer::stop_signal_handler(); + + // Show failure in case seg fault was detected + FAIL() << "Memory Test Failed"; + } + // Reset to default signal handler + testinghelpers::ProtectedBuffer::stop_signal_handler(); + + //---------------------------------------------------------- + // Call reference implementation to get ref results. + //---------------------------------------------------------- + testinghelpers::ref_omatcopy2( storage, trans, m, n, alpha, A, lda, stridea, B_ref, ldb, strideb ); + + //---------------------------------------------------------- + // Compute component-wise error. + //---------------------------------------------------------- + + if( B_trans == 'n' ) + computediff( "B", storage, m, n, B, B_ref, ldb, thresh, is_nan_inf_test ); + else + computediff( "B", storage, n, m, B, B_ref, ldb, thresh, is_nan_inf_test ); + +} + + +// Test-case logger : Used to print the test-case details based on parameters +template +class omatcopy2GenericPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char storage = std::get<0>(str.param); + char trans = std::get<1>(str.param); + gtint_t m = std::get<2>(str.param); + gtint_t n = std::get<3>(str.param); + T alpha = std::get<4>(str.param); + gtint_t lda_inc = std::get<5>(str.param); + gtint_t stridea = std::get<6>(str.param); + gtint_t ldb_inc = std::get<7>(str.param); + gtint_t strideb = std::get<8>(str.param); + bool is_memory_test = std::get<9>(str.param); + + std::string str_name = API_PRINT; + str_name += "_stor_" + std::string(&storage, 1); + str_name += "_trans_" + std::string(&trans, 1); + str_name += "_m_" + std::to_string(m); + str_name += "_n_" + std::to_string(n); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + gtint_t lda = testinghelpers::get_leading_dimension( storage, 'n', m, n, lda_inc ); + gtint_t ldb = testinghelpers::get_leading_dimension( storage, trans, m, n, ldb_inc ); + str_name += "_lda" + std::to_string(lda); + str_name += "_stridea" + std::to_string(stridea); + str_name += "_ldb" + std::to_string(ldb); + str_name += "_strideb" + std::to_string(strideb); + str_name += ( is_memory_test ) ? "_mem_test_enabled" : "_mem_test_disabled"; + + return str_name; + } +}; + +template +class comatcopy2EVTPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char storage = std::get<0>(str.param); + char trans = std::get<1>(str.param); + gtint_t m = std::get<2>(str.param); + gtint_t n = std::get<3>(str.param); + T alpha = std::get<4>(str.param); + gtint_t lda_inc = std::get<5>(str.param); + gtint_t stridea = std::get<6>(str.param); + gtint_t ldb_inc = std::get<7>(str.param); + gtint_t strideb = std::get<8>(str.param); + T exval = std::get<9>(str.param); + + std::string str_name = API_PRINT; + str_name += "_stor_" + std::string(&storage, 1); + str_name += "_trans_" + std::string(&trans, 1); + str_name += "_m_" + std::to_string(m); + str_name += "_n_" + std::to_string(n); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + str_name = str_name + "_A_exval" + testinghelpers::get_value_string(exval); + gtint_t lda = testinghelpers::get_leading_dimension( storage, 'n', m, n, lda_inc ); + gtint_t ldb = testinghelpers::get_leading_dimension( storage, trans, m, n, ldb_inc ); + str_name += "_lda" + std::to_string(lda); + str_name += "_stridea" + std::to_string(stridea); + str_name += "_ldb" + std::to_string(ldb); + str_name += "_stridea" + std::to_string(strideb); + + return str_name; + } +}; diff --git a/gtestsuite/testsuite/extension/omatcopy2/zomatcopy2_evt.cpp b/gtestsuite/testsuite/extension/omatcopy2/zomatcopy2_evt.cpp new file mode 100644 index 0000000000..47b38a0780 --- /dev/null +++ b/gtestsuite/testsuite/extension/omatcopy2/zomatcopy2_evt.cpp @@ -0,0 +1,147 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_omatcopy2.h" + +class zomatcopy2EVT : + public ::testing::TestWithParam> {}; // is_nan_inf_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(zomatcopy2EVT); + +// Tests using random numbers as vector elements. +TEST_P( zomatcopy2EVT, API ) +{ + using T = dcomplex; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes the storage format of the input matrices + char storage = std::get<0>(GetParam()); + // denotes the trans value for the operation + char trans = std::get<1>(GetParam()); + // m dimension + gtint_t m = std::get<2>(GetParam()); + // n dimension + gtint_t n = std::get<3>(GetParam()); + // alpha + T alpha = std::get<4>(GetParam()); + // lda_inc for A + gtint_t lda_inc = std::get<5>(GetParam()); + // stridea + gtint_t stridea = std::get<6>(GetParam()); + // ldb_inc for B + gtint_t ldb_inc = std::get<7>(GetParam()); + // strideb + gtint_t strideb = std::get<8>(GetParam()); + // exval + T exval = std::get<9>(GetParam()); + // is_nan_inf_test + bool is_nan_inf_test = std::get<8>(GetParam()); + + double thresh = 0.0; + // Set the threshold for the errors + if( ( alpha != testinghelpers::ZERO() || alpha != testinghelpers::ONE() ) && !(std::isnan(alpha.real) || std::isnan(alpha.imag)) ) + thresh = 3 * testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + // Note: is_memory_test is passed as false(hard-coded), since memory tests are done in _generic.cpp files + test_omatcopy2( storage, trans, m, n, alpha, lda_inc, stridea, ldb_inc, strideb, thresh, false, is_nan_inf_test, exval ); +} + +#if defined(TEST_BLAS_LIKE) && defined(REF_IS_MKL) + +static float AOCL_NAN = std::numeric_limits::quiet_NaN(); +static float AOCL_INF = std::numeric_limits::infinity(); + +// EVT testing for zomatcopy2, with exception values in A matrix +INSTANTIATE_TEST_SUITE_P( + matrixA, + zomatcopy2EVT, + ::testing::Combine( + ::testing::Values('c'), // storage format(currently only for BLAS testing) + ::testing::Values('n', 't', 'r', 'c'), // trans(and/or conj) value + // 'n' - no-transpose, 't' - transpose + // 'r' - conjugate, 'c' - conjugate-transpose + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // m + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // n + ::testing::Values(dcomplex{2.3, -3.5}, dcomplex{1.0, 0.0}, + dcomplex{0.0, 0.0}), // alpha + ::testing::Values(gtint_t(0), gtint_t(25)), // increment of lda + ::testing::Values(gtint_t(1), gtint_t(3)), // stridea + ::testing::Values(gtint_t(0), gtint_t(17)), // increment of ldb + ::testing::Values(gtint_t(1), gtint_t(3)), // strideb + ::testing::Values(dcomplex{AOCL_INF, 0.0}, dcomplex{0.0, -AOCL_INF}, + dcomplex{0.0, AOCL_NAN}, dcomplex{AOCL_NAN, AOCL_INF}), // exval + ::testing::Values(true) // is_nan_inf_test + ), + ::comatcopy2EVTPrint() + ); + +// EVT testing for zomatcopy2, with exception values in alpha +INSTANTIATE_TEST_SUITE_P( + alpha, + zomatcopy2EVT, + ::testing::Combine( + ::testing::Values('c'), // storage format(currently only for BLAS testing) + ::testing::Values('n', 't', 'r', 'c'), // trans(and/or conj) value + // 'n' - no-transpose, 't' - transpose + // 'r' - conjugate, 'c' - conjugate-transpose + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // m + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // n + ::testing::Values(dcomplex{AOCL_INF, 0.0}, dcomplex{0.0, -AOCL_INF}, + dcomplex{0.0, AOCL_NAN}, dcomplex{AOCL_NAN, AOCL_INF}), // alpha + ::testing::Values(gtint_t(0), gtint_t(25)), // increment of lda + ::testing::Values(gtint_t(1), gtint_t(3)), // stridea + ::testing::Values(gtint_t(0), gtint_t(17)), // increment of ldb + ::testing::Values(gtint_t(1), gtint_t(3)), // strideb + ::testing::Values(dcomplex{0.0, 0.0}), // exval + ::testing::Values(true) // is_nan_inf_test + ), + ::comatcopy2EVTPrint() + ); +#endif diff --git a/gtestsuite/testsuite/extension/omatcopy2/zomatcopy2_generic.cpp b/gtestsuite/testsuite/extension/omatcopy2/zomatcopy2_generic.cpp new file mode 100644 index 0000000000..19fa29e49c --- /dev/null +++ b/gtestsuite/testsuite/extension/omatcopy2/zomatcopy2_generic.cpp @@ -0,0 +1,114 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_omatcopy2.h" + +class zomatcopy2Generic : + public ::testing::TestWithParam> {}; // is_memory_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(zomatcopy2Generic); + +// Tests using random numbers as vector elements. +TEST_P( zomatcopy2Generic, API ) +{ + using T = dcomplex; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes the storage format of the input matrices + char storage = std::get<0>(GetParam()); + // denotes the trans value for the operation + char trans = std::get<1>(GetParam()); + // m dimension + gtint_t m = std::get<2>(GetParam()); + // n dimension + gtint_t n = std::get<3>(GetParam()); + // alpha + T alpha = std::get<4>(GetParam()); + // lda_inc for A + gtint_t lda_inc = std::get<5>(GetParam()); + // stridea + gtint_t stridea = std::get<6>(GetParam()); + // ldb_inc for B + gtint_t ldb_inc = std::get<7>(GetParam()); + // strideb + gtint_t strideb = std::get<8>(GetParam()); + // is_memory_test + bool is_memory_test = std::get<9>(GetParam()); + + double thresh = 0.0; + // Set the threshold for the errors + if( ( alpha != testinghelpers::ZERO() || alpha != testinghelpers::ONE() ) ) + thresh = 3 * testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_omatcopy2( storage, trans, m, n, alpha, lda_inc, stridea, ldb_inc, strideb, thresh, is_memory_test ); +} + +#if defined(TEST_BLAS_LIKE) && defined(REF_IS_MKL) +// Black box testing for generic and main use of zomatcopy2. +INSTANTIATE_TEST_SUITE_P( + Blackbox, + zomatcopy2Generic, + ::testing::Combine( + ::testing::Values('c'), // storage format(currently only for BLAS testing) + ::testing::Values('n', 't', 'r', 'c'), // trans(and/or conj) value + // 'n' - no-transpose, 't' - transpose + // 'r' - conjugate, 'c' - conjugate-transpose + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // m + ::testing::Values(gtint_t(10), gtint_t(55), gtint_t(243)), // n + ::testing::Values(dcomplex{2.3, -3.5}, dcomplex{-3.1, 1.7}, + dcomplex{1.0, 0.0}, dcomplex{0.0, 0.0}), // alpha + ::testing::Values(gtint_t(0), gtint_t(25)), // increment of lda + ::testing::Values(gtint_t(1), gtint_t(3)), // stridea + ::testing::Values(gtint_t(0), gtint_t(25)), // increment of ldb + ::testing::Values(gtint_t(1), gtint_t(3)), // strideb + ::testing::Values(false, true) // is_memory_test + ), + ::omatcopy2GenericPrint() + ); +#endif diff --git a/gtestsuite/testsuite/inc/check_error.h b/gtestsuite/testsuite/inc/check_error.h index 4f6d848855..b5b24ffc21 100644 --- a/gtestsuite/testsuite/inc/check_error.h +++ b/gtestsuite/testsuite/inc/check_error.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -95,7 +95,7 @@ struct ComparisonHelper{ j(-11), binary_comparison(false), nan_inf_check(false) {}; - // Constructor for the generic case where theshold is used. + // Constructor for the generic case where threshold is used. ComparisonHelper(ObjType object_type, double threshold) : threshold(threshold), object_type(object_type), i(-11), @@ -121,10 +121,12 @@ testing::AssertionResult NumericalComparisonFPOnly(const char* blis_sol_char, } else { double error = testinghelpers::getError(blis_sol,ref_sol); - if (error < comp_helper.threshold) return testing::AssertionSuccess(); + if (error <= comp_helper.threshold) return testing::AssertionSuccess(); + using RT = typename testinghelpers::type_info::real_type; return testing::AssertionFailure() << error_message - << ", thesh = " << comp_helper.threshold - << ", error = " << error; + << ", thresh = " << comp_helper.threshold + << ", error = " << error + << " (" << error/std::numeric_limits::epsilon() << " * eps)"; } } @@ -225,17 +227,21 @@ testing::AssertionResult NumericalComparisonInf(const char* blis_sol_char, return testing::AssertionFailure() << error_message; } -// Comparisons that take into account the presence of NaNs and Infs: +// Comparisons that take into account the presence of NaNs and Infs, printing variable name: template::real_type> -testing::AssertionResult NumericalComparison(const char* blis_sol_char, +testing::AssertionResult NumericalComparison(const char* var_name_char, + const char* blis_sol_char, const char* ref_sol_char, const char* comp_helper_char, + std::string var_name, const T blis_sol, const T ref_sol, const ComparisonHelper comp_helper) { // Base error message used for scalar values - std::string error_message = blis_sol_char; + std::string error_message = var_name_char; + error_message += " = " + var_name + ", "; + error_message += blis_sol_char; error_message += " = " + testinghelpers::to_string(blis_sol) + ", "; error_message += ref_sol_char; error_message += " = " + testinghelpers::to_string(ref_sol); @@ -291,34 +297,34 @@ testing::AssertionResult NumericalComparison(const char* blis_sol_char, } /** - * Binary comparison of two scalars. + * Binary comparison of two scalars, printing variable name. */ template -void computediff( T blis_sol, T ref_sol, bool nan_inf_check = false ) +void computediff( std::string var_name, T blis_sol, T ref_sol, bool nan_inf_check = false ) { ComparisonHelper comp_helper(SCALAR); comp_helper.binary_comparison = true; comp_helper.nan_inf_check = nan_inf_check; - ASSERT_PRED_FORMAT3(NumericalComparison, blis_sol, ref_sol, comp_helper); + ASSERT_PRED_FORMAT4(NumericalComparison, var_name, blis_sol, ref_sol, comp_helper); } /** - * Relative comparison of two scalars, using a threshold. + * Relative comparison of two scalars, using a threshold, printing variable name. */ template -void computediff( T blis_sol, T ref_sol, double thresh, bool nan_inf_check = false ) +void computediff( std::string var_name, T blis_sol, T ref_sol, double thresh, bool nan_inf_check = false ) { - ComparisonHelper comp_helper(SCALAR, thresh); + ComparisonHelper comp_helper(SCALAR, thresh); comp_helper.nan_inf_check = nan_inf_check; - ASSERT_PRED_FORMAT3(NumericalComparison, blis_sol, ref_sol, comp_helper); + ASSERT_PRED_FORMAT4(NumericalComparison, var_name, blis_sol, ref_sol, comp_helper); } /** - * Binary comparison of two vectors with length n and increment inc. + * Binary comparison of two vectors with length n and increment inc, printing variable name. */ template -void computediff( gtint_t n, T *blis_sol, T *ref_sol, gtint_t inc, bool nan_inf_check = false ) +void computediff( std::string var_name, gtint_t n, T *blis_sol, T *ref_sol, gtint_t inc, bool nan_inf_check = false ) { gtint_t abs_inc = std::abs(inc); ComparisonHelper comp_helper(VECTOR); @@ -330,24 +336,65 @@ void computediff( gtint_t n, T *blis_sol, T *ref_sol, gtint_t inc, bool nan_inf_ for (gtint_t i = 0; i < n; i++) { comp_helper.i = i; - ASSERT_PRED_FORMAT3(NumericalComparison, blis_sol[i*abs_inc], ref_sol[i*abs_inc], comp_helper) << "inc = " << inc; + ASSERT_PRED_FORMAT4(NumericalComparison, var_name, blis_sol[i*abs_inc], ref_sol[i*abs_inc], comp_helper) << "inc = " << inc; // Go through elements that are part of the array that should not have been modified by the // call to a BLIS API. Use the bitwise comparison for this case. if (i < n-1) { for (gtint_t j = 1; j < abs_inc; j++) { - ASSERT_PRED_FORMAT3(NumericalComparison, blis_sol[i*abs_inc + j], ref_sol[i*abs_inc + j], comp_helper) << "inc = " << inc << " This element is expected to not be modified."; + ASSERT_PRED_FORMAT4(NumericalComparison, var_name, blis_sol[i*abs_inc + j], ref_sol[i*abs_inc + j], comp_helper) << "inc = " << inc << " This element is expected to not be modified."; + } + } + } +} + + +/** + * Binary comparison of two vectors with length n and increment inc, printing variable names. + */ +template +void computediff( gtint_t n, T *blis_x, T *blis_x_ref, T *blis_y, T *blis_y_ref, gtint_t incx, gtint_t incy, bool nan_inf_check = false ) +{ + gtint_t abs_incx = std::abs(incx); + gtint_t abs_incy = std::abs(incy); + int idx, idy; + ComparisonHelper comp_helper(VECTOR); + comp_helper.nan_inf_check = nan_inf_check; + comp_helper.binary_comparison = true; + + // In case inc is negative in a call to BLIS APIs, we just access it from the end to the beginning, + // so practically nothing changes. Access from beginning to end to optimize memory operations. + for (gtint_t i = 0; i < n; i++) + { + comp_helper.i = i; + idx = (incx > 0) ? (i * incx) : ( - ( n - i - 1 ) * incx ); + idy = (incy > 0) ? (i * incy) : ( - ( n - i - 1 ) * incy ); + ASSERT_PRED_FORMAT4(NumericalComparison, "x", blis_x[idx], blis_y_ref[idy], comp_helper) << "incx = " << incx ; + ASSERT_PRED_FORMAT4(NumericalComparison, "y", blis_y[idy], blis_x_ref[idx], comp_helper) << "incy = " << incy; // Go through elements that are part of the array that should not have been modified by the + // call to a BLIS API. Use the bitwise comparison for this case. + // Random generator fills vector with T{-1.2345e38} + if (i < n-1) + { + for (gtint_t j = 1; j < abs_incx; j++) + { + idx = (incx > 0) ? (i * incx) : ( - ( n - i - 1 ) * incx ); + ASSERT_PRED_FORMAT4(NumericalComparison, "x", blis_x[i*abs_incx + j], T{-1.2345e38}, comp_helper) << "incx = " << incx << " This element is expected to not be modified."; + } + for (gtint_t j = 1; j < abs_incy; j++) + { + idy = (incy > 0) ? (i * incy) : ( - ( n - i - 1 ) * incy ); + ASSERT_PRED_FORMAT4(NumericalComparison, "y", blis_y[i*abs_incy + j], T{-1.2345e38}, comp_helper) << "incy = " << incy << " This element is expected to not be modified."; } } } } /** - * Relative comparison of two vectors with length n and increment inc. + * Relative comparison of two vectors with length n and increment inc, printing variable name. */ template -void computediff( gtint_t n, T *blis_sol, T *ref_sol, gtint_t inc, double thresh, bool nan_inf_check = false ) +void computediff( std::string var_name, gtint_t n, T *blis_sol, T *ref_sol, gtint_t inc, double thresh, bool nan_inf_check = false ) { gtint_t abs_inc = std::abs(inc); ComparisonHelper comp_helper(VECTOR, thresh); @@ -358,7 +405,7 @@ void computediff( gtint_t n, T *blis_sol, T *ref_sol, gtint_t inc, double thresh for (gtint_t i = 0; i < n; i++) { comp_helper.i = i; - ASSERT_PRED_FORMAT3(NumericalComparison, blis_sol[i*abs_inc], ref_sol[i*abs_inc], comp_helper) << "inc = " << inc; + ASSERT_PRED_FORMAT4(NumericalComparison, var_name, blis_sol[i*abs_inc], ref_sol[i*abs_inc], comp_helper) << "inc = " << inc; // Go through elements that are part of the array that should not have been modified by the // call to a BLIS API. Use the bitwise comparison for this case. if (i < n-1) @@ -366,7 +413,7 @@ void computediff( gtint_t n, T *blis_sol, T *ref_sol, gtint_t inc, double thresh for (gtint_t j = 1; j < abs_inc; j++) { comp_helper.binary_comparison = true; - ASSERT_PRED_FORMAT3(NumericalComparison, blis_sol[i*abs_inc + j], ref_sol[i*abs_inc + j], comp_helper) << "inc = " << inc << " This element is expected to not be modified."; + ASSERT_PRED_FORMAT4(NumericalComparison, var_name, blis_sol[i*abs_inc + j], ref_sol[i*abs_inc + j], comp_helper) << "inc = " << inc << " This element is expected to not be modified."; } comp_helper.binary_comparison = false; } @@ -374,10 +421,10 @@ void computediff( gtint_t n, T *blis_sol, T *ref_sol, gtint_t inc, double thresh } /** - * Binary comparison of two matrices with dimensions m-by-n and leading dimension ld. + * Binary comparison of two matrices with dimensions m-by-n and leading dimension ld, printing variable name. */ template -void computediff(char storage, gtint_t m, gtint_t n, T *blis_sol, T *ref_sol, gtint_t ld, bool nan_inf_check = false ) +void computediff(std::string var_name, char storage, gtint_t m, gtint_t n, T *blis_sol, T *ref_sol, gtint_t ld, bool nan_inf_check = false ) { gtint_t i,j; ComparisonHelper comp_helper(MATRIX); @@ -394,15 +441,15 @@ void computediff(char storage, gtint_t m, gtint_t n, T *blis_sol, T *ref_sol, gt { comp_helper.i = i; comp_helper.j = j; - ASSERT_PRED_FORMAT3(NumericalComparison, blis_sol[i + j*ld], ref_sol[i + j*ld], comp_helper); + ASSERT_PRED_FORMAT4(NumericalComparison, var_name, blis_sol[i + j*ld], ref_sol[i + j*ld], comp_helper); } // Now iterate through the rest of elements in memory space that are not part of the matrix, // so we use binary comparison to verify that are exactly the same as the reference. // Since to get create the data we use a copy to initialize the reference results, those // elements are expected to identical. - for (i = m; i < ld; i++) + for (i = (std::max)(m,0); i < ld; i++) { - ASSERT_PRED_FORMAT3(NumericalComparison, blis_sol[i + j*ld], ref_sol[i + j*ld], comp_helper) << "This element is expected to not be modified."; + ASSERT_PRED_FORMAT4(NumericalComparison, var_name, blis_sol[i + j*ld], ref_sol[i + j*ld], comp_helper) << "This element is expected to not be modified."; } } } @@ -417,25 +464,25 @@ void computediff(char storage, gtint_t m, gtint_t n, T *blis_sol, T *ref_sol, gt { comp_helper.i = i; comp_helper.j = j; - ASSERT_PRED_FORMAT3(NumericalComparison, blis_sol[i*ld + j], ref_sol[i*ld + j], comp_helper); + ASSERT_PRED_FORMAT4(NumericalComparison, var_name, blis_sol[i*ld + j], ref_sol[i*ld + j], comp_helper); } // Now iterate through the rest of elements in memory space that are not part of the matrix, // so we use binary comparison to verify that are exactly the same as the reference. // Since to get create the data we use a copy to initialize the reference results, those // elements are expected to identical. - for (j = n; j < ld; j++) + for (j = (std::max)(n,0); j < ld; j++) { - ASSERT_PRED_FORMAT3(NumericalComparison, blis_sol[i*ld + j], ref_sol[i*ld + j], comp_helper) << "This element is expected to not be modified."; + ASSERT_PRED_FORMAT4(NumericalComparison, var_name, blis_sol[i*ld + j], ref_sol[i*ld + j], comp_helper) << "This element is expected to not be modified."; } } } } /** - * Relative comparison of two matrices with dimensions m-by-n and leading dimension ld. + * Relative comparison of two matrices with dimensions m-by-n and leading dimension ld, printing variable name. */ template -void computediff(char storage, gtint_t m, gtint_t n, T *blis_sol, T *ref_sol, gtint_t ld, double thresh, bool nan_inf_check = false ) +void computediff(std::string var_name, char storage, gtint_t m, gtint_t n, T *blis_sol, T *ref_sol, gtint_t ld, double thresh, bool nan_inf_check = false ) { gtint_t i,j; ComparisonHelper comp_helper(MATRIX, thresh); @@ -452,16 +499,16 @@ void computediff(char storage, gtint_t m, gtint_t n, T *blis_sol, T *ref_sol, gt { comp_helper.i = i; comp_helper.j = j; - ASSERT_PRED_FORMAT3(NumericalComparison, blis_sol[i + j*ld], ref_sol[i + j*ld], comp_helper); + ASSERT_PRED_FORMAT4(NumericalComparison, var_name, blis_sol[i + j*ld], ref_sol[i + j*ld], comp_helper); } // Now iterate through the rest of elements in memory space that are not part of the matrix, // so we use binary comparison to verify that are exactly the same as the reference. // Since to get create the data we use a copy to initialize the reference results, those // elements are expected to identical. comp_helper.binary_comparison = true; - for (i = m; i < ld; i++) + for (i = (std::max)(m,0); i < ld; i++) { - ASSERT_PRED_FORMAT3(NumericalComparison, blis_sol[i + j*ld], ref_sol[i + j*ld], comp_helper) << "This element is expected to not be modified."; + ASSERT_PRED_FORMAT4(NumericalComparison, var_name, blis_sol[i + j*ld], ref_sol[i + j*ld], comp_helper) << "This element is expected to not be modified."; } // Disable binary comparison before we go through the next column. comp_helper.binary_comparison = false; @@ -478,19 +525,63 @@ void computediff(char storage, gtint_t m, gtint_t n, T *blis_sol, T *ref_sol, gt { comp_helper.i = i; comp_helper.j = j; - ASSERT_PRED_FORMAT3(NumericalComparison, blis_sol[i*ld + j], ref_sol[i*ld + j], comp_helper); + ASSERT_PRED_FORMAT4(NumericalComparison, var_name, blis_sol[i*ld + j], ref_sol[i*ld + j], comp_helper); } // Now iterate through the rest of elements in memory space that are not part of the matrix, // so we use binary comparison to verify that are exactly the same as the reference. // Since to get create the data we use a copy to initialize the reference results, those // elements are expected to identical. comp_helper.binary_comparison = true; - for (j = n; j < ld; j++) + for (j = (std::max)(n,0); j < ld; j++) { - ASSERT_PRED_FORMAT3(NumericalComparison, blis_sol[i*ld + j], ref_sol[i*ld + j], comp_helper) << "This element is expected to not be modified."; + ASSERT_PRED_FORMAT4(NumericalComparison, var_name, blis_sol[i*ld + j], ref_sol[i*ld + j], comp_helper) << "This element is expected to not be modified."; } // Disable binary comparison before we go through the next column. comp_helper.binary_comparison = false; } } } + +// Generic comparison of integer numbers, printing variable name: +template +testing::AssertionResult EqualityComparison(const char* var_name_char, + const char* blis_sol_char, + const char* ref_sol_char, + const char* comp_helper_char, + std::string var_name, + const T blis_sol, + const T ref_sol, + const ComparisonHelper comp_helper) +{ + // Base error message used for scalar values + std::string error_message = var_name_char; + error_message += " = " + var_name + ", "; + error_message += blis_sol_char; + error_message += " = " + testinghelpers::to_string(blis_sol) + ", "; + error_message += ref_sol_char; + error_message += " = " + testinghelpers::to_string(ref_sol); + + if (blis_sol == ref_sol) return testing::AssertionSuccess(); + return testing::AssertionFailure() << error_message; +} + +/** + * Comparison of two integers, printing variable name. + */ +template <> +inline void computediff( std::string var_name, gtint_t blis_sol, gtint_t ref_sol, bool nan_inf_check ) +{ + ComparisonHelper comp_helper(SCALAR); + ASSERT_PRED_FORMAT4(EqualityComparison, var_name, blis_sol, ref_sol, comp_helper); +} + +/** + * Comparison of two characters, printing variable name. + */ +template <> +inline void computediff( std::string var_name, char blis_sol, char ref_sol, bool nan_inf_check ) +{ + ComparisonHelper comp_helper(SCALAR); + ASSERT_PRED_FORMAT4(EqualityComparison, var_name, blis_sol, ref_sol, comp_helper); +} + diff --git a/gtestsuite/testsuite/level1/addv/addv.h b/gtestsuite/testsuite/level1/addv/addv.h index ed392dedc5..825ea014d3 100644 --- a/gtestsuite/testsuite/level1/addv/addv.h +++ b/gtestsuite/testsuite/level1/addv/addv.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -36,6 +36,7 @@ #include "blis.h" #include "common/testing_helpers.h" +#include "inc/check_error.h" /** * @brief Computes @@ -66,12 +67,35 @@ static void typed_addv(char conj_x, gtint_t n, T* x, gtint_t incx, T* y, gtint_t else throw std::runtime_error("Error in testsuite/level1/addv.h: Invalid typename in typed_addv()."); } - template static void addv(char conjx, gtint_t n, T* x, gtint_t incx, T* y, gtint_t incy) { + +#ifdef TEST_UPPERCASE_ARGS + conjx = static_cast(std::toupper(static_cast(conjx))); +#endif + +#ifdef TEST_INPUT_ARGS + // Create copy of scalar input values so we can check that they are not altered. + char conjx_cpy = conjx; + gtint_t n_cpy = n; + gtint_t incx_cpy = incx; + gtint_t incy_cpy = incy; + + // Create copy of input arrays so we can check that they are not altered. + T* x_cpy = nullptr; + gtint_t size_x = testinghelpers::buff_dim( n, incx ); + if (x && size_x > 0) + { + x_cpy = new T[size_x]; + memcpy( x_cpy, x, size_x * sizeof( T ) ); + } +#endif + #ifdef TEST_BLAS throw std::runtime_error("Error in testsuite/level1/addv.h: BLAS interface is not available."); +#elif TEST_BLAS_BLIS_IMPL + throw std::runtime_error("Error in testsuite/level1/addv.h: BLAS_BLIS_IMPL interface is not available."); #elif TEST_CBLAS throw std::runtime_error("Error in testsuite/level1/addv.h: CBLAS interface is not available."); #elif TEST_BLIS_TYPED @@ -79,4 +103,25 @@ static void addv(char conjx, gtint_t n, T* x, gtint_t incx, T* y, gtint_t incy) #else throw std::runtime_error("Error in testsuite/level1/addv.h: No interfaces are set to be tested."); #endif + +#ifdef TEST_INPUT_ARGS + //---------------------------------------------------------- + // Check scalar inputs have not been modified. + //---------------------------------------------------------- + + computediff( "conjx", conjx, conjx_cpy ); + computediff( "n", n, n_cpy ); + computediff( "incx", incx, incx_cpy ); + computediff( "incy", incy, incy_cpy ); + + //---------------------------------------------------------- + // Bitwise-wise check array inputs have not been modified. + //---------------------------------------------------------- + + if (x && size_x > 0) + { + computediff( "x", n, x, x_cpy, incx, true ); + delete[] x_cpy; + } +#endif } diff --git a/gtestsuite/testsuite/level1/addv/caddv_generic.cpp b/gtestsuite/testsuite/level1/addv/caddv_generic.cpp index fe72eee37c..19069bbc18 100644 --- a/gtestsuite/testsuite/level1/addv/caddv_generic.cpp +++ b/gtestsuite/testsuite/level1/addv/caddv_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,12 +35,12 @@ #include #include "test_addv.h" -class caddvGenericTest : +class caddvGeneric : public ::testing::TestWithParam> {}; -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(caddvGenericTest); +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(caddvGeneric); -TEST_P( caddvGenericTest, RandomData ) +TEST_P( caddvGeneric, API ) { using T = scomplex; //---------------------------------------------------------- @@ -57,7 +57,15 @@ TEST_P( caddvGenericTest, RandomData ) gtint_t incy = std::get<3>(GetParam()); // Set the threshold for the errors: - double thresh = testinghelpers::getEpsilon(); + // Check gtestsuite addv.h (no netlib version) for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (n == 0) + thresh = 0.0; + else + thresh = testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call generic test body using those parameters @@ -65,37 +73,17 @@ TEST_P( caddvGenericTest, RandomData ) test_addv( conj_x, n, incx, incy, thresh ); } -// Prints the test case combination -class caddvGenericTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char conj = std::get<0>(str.param); - gtint_t n = std::get<1>(str.param); - gtint_t incx = std::get<2>(str.param); - gtint_t incy = std::get<3>(str.param); - std::string str_name = "bli_caddv"; - str_name += "_" + std::to_string(n); - str_name += "_" + std::string(&conj, 1); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name += "_" + incx_str; - std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); - str_name += "_" + incy_str; - return str_name; - } -}; - #ifdef TEST_BLIS_TYPED // Black box testing. INSTANTIATE_TEST_SUITE_P( Blackbox, - caddvGenericTest, + caddvGeneric, ::testing::Combine( ::testing::Values('n','c'), // n: not transpose for x, c: conjugate for x ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. ::testing::Values(gtint_t(1)), // stride size for x ::testing::Values(gtint_t(1)) // stride size for y ), - ::caddvGenericTestPrint() + ::addvGenericPrint() ); #endif diff --git a/gtestsuite/testsuite/level1/addv/daddv_generic.cpp b/gtestsuite/testsuite/level1/addv/daddv_generic.cpp index 40ac621290..55e0ffc715 100644 --- a/gtestsuite/testsuite/level1/addv/daddv_generic.cpp +++ b/gtestsuite/testsuite/level1/addv/daddv_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,12 +35,12 @@ #include #include "test_addv.h" -class daddvGenericTest : +class daddvGeneric : public ::testing::TestWithParam> {}; -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(daddvGenericTest); +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(daddvGeneric); -TEST_P( daddvGenericTest, RandomData ) +TEST_P( daddvGeneric, API ) { using T = double; //---------------------------------------------------------- @@ -57,7 +57,14 @@ TEST_P( daddvGenericTest, RandomData ) gtint_t incy = std::get<3>(GetParam()); // Set the threshold for the errors: - double thresh = testinghelpers::getEpsilon(); + // Check gtestsuite addv.h (no netlib version) for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (n == 0) + thresh = 0.0; + else + thresh = testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call generic test body using those parameters @@ -65,37 +72,17 @@ TEST_P( daddvGenericTest, RandomData ) test_addv( conj_x, n, incx, incy, thresh ); } -// Prints the test case combination -class daddvGenericTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char conj = std::get<0>(str.param); - gtint_t n = std::get<1>(str.param); - gtint_t incx = std::get<2>(str.param); - gtint_t incy = std::get<3>(str.param); - std::string str_name = "bli_daddv"; - str_name += "_" + std::to_string(n); - str_name += "_" + std::string(&conj, 1); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name += "_" + incx_str; - std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); - str_name += "_" + incy_str; - return str_name; - } -}; - #ifdef TEST_BLIS_TYPED // Black box testing. INSTANTIATE_TEST_SUITE_P( Blackbox, - daddvGenericTest, + daddvGeneric, ::testing::Combine( ::testing::Values('n'), // n: not transpose for x ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. ::testing::Values(gtint_t(1)), // stride size for x ::testing::Values(gtint_t(1)) // stride size for y ), - ::daddvGenericTestPrint() + ::addvGenericPrint() ); #endif diff --git a/gtestsuite/testsuite/level1/addv/saddv_generic.cpp b/gtestsuite/testsuite/level1/addv/saddv_generic.cpp index 8dbdd7e3ea..605f47dcc8 100644 --- a/gtestsuite/testsuite/level1/addv/saddv_generic.cpp +++ b/gtestsuite/testsuite/level1/addv/saddv_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,12 +35,12 @@ #include #include "test_addv.h" -class saddvGenericTest : +class saddvGeneric : public ::testing::TestWithParam> {}; -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(saddvGenericTest); +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(saddvGeneric); -TEST_P( saddvGenericTest, RandomData ) +TEST_P( saddvGeneric, API ) { using T = float; //---------------------------------------------------------- @@ -57,7 +57,14 @@ TEST_P( saddvGenericTest, RandomData ) gtint_t incy = std::get<3>(GetParam()); // Set the threshold for the errors: - double thresh = testinghelpers::getEpsilon(); + // Check gtestsuite addv.h (no netlib version) for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (n == 0) + thresh = 0.0; + else + thresh = testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call generic test body using those parameters @@ -65,37 +72,17 @@ TEST_P( saddvGenericTest, RandomData ) test_addv( conj_x, n, incx, incy, thresh ); } -// Prints the test case combination -class saddvGenericTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char conj = std::get<0>(str.param); - gtint_t n = std::get<1>(str.param); - gtint_t incx = std::get<2>(str.param); - gtint_t incy = std::get<3>(str.param); - std::string str_name = "bli_saddv"; - str_name += "_" + std::to_string(n); - str_name += "_" + std::string(&conj, 1); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name += "_" + incx_str; - std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); - str_name += "_" + incy_str; - return str_name; - } -}; - #ifdef TEST_BLIS_TYPED // Black box testing. INSTANTIATE_TEST_SUITE_P( Blackbox, - saddvGenericTest, + saddvGeneric, ::testing::Combine( ::testing::Values('n'), // n: not transpose for x ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. ::testing::Values(gtint_t(1)), // stride size for x ::testing::Values(gtint_t(1)) // stride size for y ), - ::saddvGenericTestPrint() + ::addvGenericPrint() ); #endif diff --git a/gtestsuite/testsuite/level1/addv/test_addv.h b/gtestsuite/testsuite/level1/addv/test_addv.h index 25c93ac99e..d7c9f32453 100644 --- a/gtestsuite/testsuite/level1/addv/test_addv.h +++ b/gtestsuite/testsuite/level1/addv/test_addv.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -66,5 +66,24 @@ void test_addv( char conjx, gtint_t n, gtint_t incx, gtint_t incy, double thresh //---------------------------------------------------------- // Compute component-wise error. //---------------------------------------------------------- - computediff( n, y.data(), y_ref.data(), incy, thresh ); + computediff( "y", n, y.data(), y_ref.data(), incy, thresh ); } + +// Test-case logger : Used to print the test-case details based on parameters +class addvGenericPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char conjx = std::get<0>(str.param); + gtint_t n = std::get<1>(str.param); + gtint_t incx = std::get<2>(str.param); + gtint_t incy = std::get<3>(str.param); + + std::string str_name = API_PRINT; + str_name += "_n_" + std::to_string(n); + str_name += "_conjx_" + std::string(&conjx, 1); + str_name += "_incx_" + testinghelpers::get_value_string(incx); + str_name += "_incy_" + testinghelpers::get_value_string(incy); + return str_name; + } +}; diff --git a/gtestsuite/testsuite/level1/addv/zaddv_generic.cpp b/gtestsuite/testsuite/level1/addv/zaddv_generic.cpp index 7fde610664..979e05c421 100644 --- a/gtestsuite/testsuite/level1/addv/zaddv_generic.cpp +++ b/gtestsuite/testsuite/level1/addv/zaddv_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,12 +35,12 @@ #include #include "test_addv.h" -class ZAddvGenericTest : +class zaddvGeneric : public ::testing::TestWithParam> {}; -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ZAddvGenericTest); +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(zaddvGeneric); -TEST_P( ZAddvGenericTest, RandomData ) +TEST_P( zaddvGeneric, API ) { using T = dcomplex; //---------------------------------------------------------- @@ -57,7 +57,15 @@ TEST_P( ZAddvGenericTest, RandomData ) gtint_t incy = std::get<3>(GetParam()); // Set the threshold for the errors: - double thresh = testinghelpers::getEpsilon(); + // Check gtestsuite addv.h (no netlib version) for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (n == 0) + thresh = 0.0; + else + thresh = testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call generic test body using those parameters @@ -65,37 +73,17 @@ TEST_P( ZAddvGenericTest, RandomData ) test_addv( conj_x, n, incx, incy, thresh ); } -// Prints the test case combination -class ZAddvGenericTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char conj = std::get<0>(str.param); - gtint_t n = std::get<1>(str.param); - gtint_t incx = std::get<2>(str.param); - gtint_t incy = std::get<3>(str.param); - std::string str_name = "bli_zaddv"; - str_name += "_" + std::to_string(n); - str_name += "_" + std::string(&conj, 1); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name += "_" + incx_str; - std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); - str_name += "_" + incy_str; - return str_name; - } -}; - #ifdef TEST_BLIS_TYPED // Black box testing. INSTANTIATE_TEST_SUITE_P( Blackbox, - ZAddvGenericTest, + zaddvGeneric, ::testing::Combine( ::testing::Values('n','c'), // n: not transpose for x, c: conjugate for x ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. ::testing::Values(gtint_t(1)), // stride size for x ::testing::Values(gtint_t(1)) // stride size for y ), - ::ZAddvGenericTestPrint() + ::addvGenericPrint() ); #endif diff --git a/gtestsuite/testsuite/level1/amaxv/amaxv.h b/gtestsuite/testsuite/level1/amaxv/amaxv.h index 4479263e2b..01729a5c67 100644 --- a/gtestsuite/testsuite/level1/amaxv/amaxv.h +++ b/gtestsuite/testsuite/level1/amaxv/amaxv.h @@ -1,22 +1,22 @@ -/* + /* BLIS An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -36,6 +36,7 @@ #include "blis.h" #include "common/testing_helpers.h" +#include "inc/check_error.h" /** * @brief Finds the index of the first element that has the maximum absolute value. @@ -44,6 +45,7 @@ * @param[in] incx increment of x * * If n < 1 or incx <= 0, return 0. + * If n == 1, return 1(BLAS) or 0(CBLAS). */ template @@ -61,9 +63,25 @@ static gtint_t amaxv_(gtint_t n, T* x, gtint_t incx) { else throw std::runtime_error("Error in testsuite/level1/amaxv.h: Invalid typename in amaxv_()."); - // Since we are comparing against CBLAS which is 0-based and BLAS is 1-based, - // we need to use -1 here. - return (idx-1); + return idx; +} + +template +static gtint_t amaxv_blis_impl(gtint_t n, T* x, gtint_t incx) { + + gtint_t idx; + if constexpr (std::is_same::value) + idx = isamax_blis_impl( &n, x, &incx ); + else if constexpr (std::is_same::value) + idx = idamax_blis_impl( &n, x, &incx ); + else if constexpr (std::is_same::value) + idx = icamax_blis_impl( &n, x, &incx ); + else if constexpr (std::is_same::value) + idx = izamax_blis_impl( &n, x, &incx ); + else + throw std::runtime_error("Error in testsuite/level1/amaxv.h: Invalid typename in amaxv_blis_impl()."); + + return idx; } template @@ -105,8 +123,37 @@ static gtint_t typed_amaxv(gtint_t n, T* x, gtint_t incx) template static gtint_t amaxv(gtint_t n, T* x, gtint_t incx) { -#ifdef TEST_BLAS - return amaxv_(n, x, incx); + +#ifdef TEST_INPUT_ARGS + // Create copy of scalar input values so we can check that they are not altered. + gtint_t n_cpy = n; + gtint_t incx_cpy = incx; + + // Create copy of input arrays so we can check that they are not altered. + T* x_cpy = nullptr; + gtint_t size_x = testinghelpers::buff_dim( n, incx ); + if (x && size_x > 0) + { + x_cpy = new T[size_x]; + memcpy( x_cpy, x, size_x * sizeof( T ) ); + } +#endif + +#ifdef TEST_BLAS_LIKE + // Since we would be comparing against CBLAS which is 0-based and BLAS + // which is 1-based, we need decrement the result of BLAS call by 1. + // Exception is IIT tests which return 0 in both BLAS and CBLAS. + + #ifdef TEST_BLAS + gtint_t idx = amaxv_(n, x, incx); + #elif TEST_BLAS_BLIS_IMPL + gtint_t idx = amaxv_blis_impl(n, x, incx); + #endif + if ( n < 1 || incx <= 0 ) + return idx; + else + return idx - 1; + #elif TEST_CBLAS return cblas_amaxv(n, x, incx); #elif TEST_BLIS_TYPED @@ -114,4 +161,23 @@ static gtint_t amaxv(gtint_t n, T* x, gtint_t incx) #else throw std::runtime_error("Error in testsuite/level1/amaxv.h: No interfaces are set to be tested."); #endif + +#ifdef TEST_INPUT_ARGS + //---------------------------------------------------------- + // Check scalar inputs have not been modified. + //---------------------------------------------------------- + + computediff( "n", n, n_cpy ); + computediff( "incx", incx, incx_cpy ); + + //---------------------------------------------------------- + // Bitwise-wise check array inputs have not been modified. + //---------------------------------------------------------- + + if (x && size_x > 0) + { + computediff( "x", n, x, x_cpy, incx, true ); + delete[] x_cpy; + } +#endif } diff --git a/gtestsuite/testsuite/level1/amaxv/amaxv_IIT_ERS.cpp b/gtestsuite/testsuite/level1/amaxv/amaxv_IIT_ERS.cpp new file mode 100644 index 0000000000..311e4baf23 --- /dev/null +++ b/gtestsuite/testsuite/level1/amaxv/amaxv_IIT_ERS.cpp @@ -0,0 +1,192 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_amaxv.h" +#include "level1/ref_amaxv.h" +#include "common/wrong_inputs_helpers.h" +#include "common/testing_helpers.h" +#include "inc/check_error.h" + +template +class amaxvIIT_ERS : public ::testing::Test {}; +typedef ::testing::Types TypeParam; +TYPED_TEST_SUITE(amaxvIIT_ERS, TypeParam); + +using namespace testinghelpers::IIT; + +#if defined(TEST_BLAS_LIKE) || defined(TEST_CBLAS) +/* + + Early Return Scenarios(ERS) for BLAS/CBLAS compliance : + + The AMAX API is expected to return early in the following cases: + 1. When n < 1. + 2. When incx <= 0. + + The index returned in these cases is expected to be 0. + + Further, the API is expected to return early when: + 3. When n == 1. + + The index returned in this case is expected to be 1(BLAS) + or 0(CBLAS), but we handle all comparisons as if from CBLAS + with the conversion occurring in the amaxv.h header file. +*/ + +// n < 1, with non-unit stride +TYPED_TEST(amaxvIIT_ERS, n_lt_one_nonUnitStride) +{ + using T = TypeParam; + gtint_t n = 0; + gtint_t inc = 5; + gtint_t idx = 42; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + idx = amaxv( n, nullptr, inc ); + + // Computing the difference. + computediff( "idx", idx, gtint_t(0) ); + + // Test with all arguments correct except for the value we are choosing to test. + // Initialize vectors with random numbers. + std::vector x = testinghelpers::get_random_vector( -10, 10, N, inc ); + + // Invoking AMAXV with an invalid value of n. + idx = amaxv( n, x.data(), inc ); + + // Computing the difference. + computediff( "idx", idx, gtint_t(0) ); +} + +// inc == 0, with non-unit stride +TYPED_TEST(amaxvIIT_ERS, incx_eq_zero) +{ + using T = TypeParam; + gtint_t inc = 0; + gtint_t idx = 42; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + idx = amaxv( N, nullptr, inc ); + + // Computing the difference. + computediff( "idx", idx, gtint_t(0) ); + + // Test with all arguments correct except for the value we are choosing to test. + // Initialize vectors with random numbers. + std::vector x = testinghelpers::get_random_vector( -10, 10, N, 1 ); + + // Invoking AMAXV with an invalid value of incx. + idx = amaxv( N, x.data(), inc ); + + // Computing the difference. + computediff( "idx", idx, gtint_t(0) ); +} + +// n < 1, with unit stride +TYPED_TEST(amaxvIIT_ERS, n_lt_one_unitStride) +{ + using T = TypeParam; + gtint_t n = 0; + gtint_t unit_inc = 1; + gtint_t idx = 42; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + idx = amaxv( n, nullptr, unit_inc ); + + // Computing the difference. + computediff( "idx", idx, gtint_t(0) ); + + // Test with all arguments correct except for the value we are choosing to test. + // Initialize vectors with random numbers. + std::vector x = testinghelpers::get_random_vector( -10, 10, N, unit_inc ); + + // Invoking AMAXV with an invalid value of n. + idx = amaxv( n, x.data(), unit_inc ); + + // Computing the difference. + computediff( "idx", idx, gtint_t(0) ); +} + +// n == 1, with unit stride +TYPED_TEST(amaxvIIT_ERS, n_eq_one_unitStride) +{ + using T = TypeParam; + gtint_t n = 1; + gtint_t unit_inc = 1; + gtint_t idx = 42; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + idx = amaxv( n, nullptr, unit_inc ); + + // Computing the difference. + computediff( "idx", idx, gtint_t(0) ); + + // Test with all arguments correct except for the value we are choosing to test. + // Initialize vectors with random numbers. + std::vector x = testinghelpers::get_random_vector( -10, 10, N, unit_inc ); + + // Invoking AMAXV with an invalid value of n. + idx = amaxv( n, x.data(), unit_inc ); + + // Computing the difference. + computediff( "idx", idx, gtint_t(0) ); + +} + +TYPED_TEST(amaxvIIT_ERS, n_eq_one_nonUnitStrides) +{ + using T = TypeParam; + gtint_t n = 1; + gtint_t inc = 5; + gtint_t idx = 42; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + idx = amaxv( n, nullptr, inc ); + + // Computing the difference. + computediff( "idx", idx, gtint_t(0) ); + + // Test with all arguments correct except for the value we are choosing to test. + // Initialize vectors with random numbers. + std::vector x = testinghelpers::get_random_vector( -10, 10, N, inc ); + + // Invoking AMAXV with an invalid value of n. + idx = amaxv( n, x.data(), inc ); + + // Computing the difference. + computediff( "idx", idx, gtint_t(0) ); +} + +#endif diff --git a/gtestsuite/testsuite/level1/amaxv/camaxv_generic.cpp b/gtestsuite/testsuite/level1/amaxv/camaxv_generic.cpp index 1f553cefef..7d3ae36c86 100644 --- a/gtestsuite/testsuite/level1/amaxv/camaxv_generic.cpp +++ b/gtestsuite/testsuite/level1/amaxv/camaxv_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,76 +35,77 @@ #include #include "test_amaxv.h" -class camaxvGenericTest : - public ::testing::TestWithParam> {}; +class camaxvGeneric : + public ::testing::TestWithParam> {}; //incx -// Tests using random integers as vector elements. -TEST_P( camaxvGenericTest, RandomData ) +// Tests using random values as vector elements. +TEST_P( camaxvGeneric, API ) { using T = scomplex; //---------------------------------------------------------- // Initialize values from the parameters passed through // test suite instantiation (INSTANTIATE_TEST_SUITE_P). //---------------------------------------------------------- - // vector length: + // vector length gtint_t n = std::get<0>(GetParam()); - // stride size for x: + // stride size for x gtint_t incx = std::get<1>(GetParam()); - // Set the threshold for the errors: - double thresh = testinghelpers::getEpsilon(); - //---------------------------------------------------------- // Call generic test body using those parameters //---------------------------------------------------------- - test_amaxv( n, incx, thresh ); + test_amaxv( n, incx ); } -// Used to generate a test case with a sensible name. -// Beware that we cannot use fp numbers (e.g., 2.3) in the names, -// so we are only printing int(2.3). This should be enough for debugging purposes. -// If this poses an issue, please reach out. -class camaxvGenericTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - gtint_t n = std::get<0>(str.param); - gtint_t incx = std::get<1>(str.param); -#ifdef TEST_BLAS - std::string str_name = "icamax_"; -#elif TEST_CBLAS - std::string str_name = "cblas_icamax"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_camaxv"; -#endif - str_name += "_" + std::to_string(n); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name += "_" + incx_str; - return str_name; - } -}; +//Black box testing extended for different range of values +INSTANTIATE_TEST_SUITE_P( + Blackbox_Small_Sizes, + camaxvGeneric, + ::testing::Combine( + ::testing::Range(gtint_t(1), gtint_t(11), 1), // n size of vector takes values from 1 to 11 with step size of 1. + ::testing::Values(gtint_t(1)) // stride size for x + ), + ::amaxvGenericPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + Blackbox_Average_Sizes, + camaxvGeneric, + ::testing::Combine( + ::testing::Range(gtint_t(100), gtint_t(502), 50), // n size of vector takes values from 100 to 500 with step size of 50. + ::testing::Values(gtint_t(1)) // stride size for x + ), + ::amaxvGenericPrint() + ); -// Black box testing for generic and main use of camaxv. INSTANTIATE_TEST_SUITE_P( - Blackbox, - camaxvGenericTest, + Blackbox_Max_Sizes, + camaxvGeneric, ::testing::Combine( - ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. + ::testing::Range(gtint_t(1024), gtint_t(65535), 1023), // n size of vector takes values from 2pow10 to 2pow16-1 with step size of 1023. ::testing::Values(gtint_t(1)) // stride size for x ), - ::camaxvGenericTestPrint() + ::amaxvGenericPrint() + ); + +//Non unit testing extended for different stride values +INSTANTIATE_TEST_SUITE_P( + NonUnitIncrements_Stride, + camaxvGeneric, + ::testing::Combine( + ::testing::Values(gtint_t(123), gtint_t(111), gtint_t(20)), // m size of vector + ::testing::Values(gtint_t(4), gtint_t(7)) // stride size for x + ), + ::amaxvGenericPrint() ); -// Test for non-unit increments. -// Only test very few cases as sanity check. -// We can modify the values using implementantion details. INSTANTIATE_TEST_SUITE_P( - NonUnitIncrements, - camaxvGenericTest, + Blackbox_Stride_Greater, + camaxvGeneric, ::testing::Combine( - ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector - ::testing::Values(gtint_t(2), gtint_t(11)) // stride size for x + ::testing::Range(gtint_t(1), gtint_t(10), 1), // n size of vector takes values from 1 to 10 with step size 1 + ::testing::Values(gtint_t(11)) // stride size for x ), - ::camaxvGenericTestPrint() + ::amaxvGenericPrint() ); diff --git a/gtestsuite/testsuite/level1/amaxv/damaxv_evt.cpp b/gtestsuite/testsuite/level1/amaxv/damaxv_evt.cpp new file mode 100644 index 0000000000..12720d4dc0 --- /dev/null +++ b/gtestsuite/testsuite/level1/amaxv/damaxv_evt.cpp @@ -0,0 +1,196 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_amaxv.h" + +class DISABLED_damaxvEVT : + public ::testing::TestWithParam> {}; // xj_exval + +// Tests using random values as vector elements. +TEST_P( DISABLED_damaxvEVT, API ) +{ + using T = double; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // vector length + gtint_t n = std::get<0>(GetParam()); + // stride size for x + gtint_t incx = std::get<1>(GetParam()); + // index for exval in x + gtint_t xi = std::get<2>(GetParam()); + // exval for index xi + T xi_exval = std::get<3>(GetParam()); + // index for exval in x + gtint_t xj = std::get<4>(GetParam()); + // exval for index xj + T xj_exval = std::get<5>(GetParam()); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_amaxv( n, incx, xi, xi_exval, xj, xj_exval ); +} + +static double NaN = std::numeric_limits::quiet_NaN(); +static double Inf = std::numeric_limits::infinity(); + +/* + Exception value testing on vectors(Zen3) : + DAMAXV currently uses the bli_damaxv_zen_int( ... ) kernel for computation on zen3 + machines. + The sizes and indices given in the instantiator are to ensure code coverage inside + the kernel. + + Kernel structure for bli_damaxv_zen_int( ... ) is as follows : + bli_damaxv_zen_int() --> bli_vec_absmax_double() --> bli_vec_search_double() + bli_vec_absmax_double() structure: + For unit strides : + Main loop : In blocks of 48 --> L48 + Fringe loops : In blocks of 32 --> L32 + In blocks of 16 --> L16 + In blocks of 8 --> L8 + In blocks of 4 --> L4 + In blocks of 2 --> L2 + Element-wise loop --> LScalar + + For non-unit strides : A single loop, to process element wise. + + bli_vec_search_double() structure: + For unit strides : + Main loop : In blocks of 4 --> L4 + In blocks of 2 --> L2 + Element-wise loop --> LScalar + + For non-unit strides : A single loop, to process element wise. + + The sizes chosen are as follows(in accordance to the structure in bli_vec_absmax_double()) : + 176 : 3*L48 + L32 + 175 : 3*L48 + L16 + L8 + L4 + L2 + 1(LScalar) + + The following indices are sufficient to ensure code-coverage of loops : + 0 <= idx < 144 - In L48 + 144 <= idx < 160 - In L32(for size 176), in L16(for size 175) + 160 <= idx < 168 - In L8 + 168 <= idx < 172 - In L4 + 172 <= idx < 174 - In L2 + 174 <= idx < 175 - In LScalar + + These sizes and indices also ensure code coverage for bli_vec_search_double(). + The testsuite requires 2 indices(and 2 exception values) to be induced in the vector. +*/ + +// Exception value testing with unit strides +INSTANTIATE_TEST_SUITE_P( + unitStrides_zen3, + DISABLED_damaxvEVT, + ::testing::Combine( + ::testing::Values(gtint_t(175), gtint_t(176)), // n, size of vectors with unit-stride + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(0), gtint_t(143), gtint_t(159), + gtint_t(167), gtint_t(171), gtint_t(173), + gtint_t(174)), // xi, index for exval in xi_exval + ::testing::Values(NaN, -Inf, Inf, double(2.3)), // xi_exval + ::testing::Values(gtint_t(5), gtint_t(140), gtint_t(155), + gtint_t(163), gtint_t(170), gtint_t(172)), // xj, index for exval in xj_exval + ::testing::Values(NaN, -Inf, Inf, double(2.3)) // xj_exval + ), + ::amaxvEVTPrint() + ); + +/* + Exception value testing on vectors(Zen4) : + damaxv currently uses the bli_damaxv_zen_int( ... ) kernel for computation on zen3 + machines. + The sizes and indices given in the instantiator are to ensure code coverage inside + the kernel. + + Kernel structure for bli_damaxv_zen_int( ... ) is as follows : + For unit strides : + Main loop : In blocks of 32 --> L32 + Fringe loops : In blocks of 8 --> L8 + Element-wise loop --> LScalar + + For non-unit strides : A single loop, to process element wise. + + The sizes chosen are as follows : + 367 - 10*L32 + 5*L8 + 7(LScalar) + + The following indices are sufficient to ensure code-coverage of loops : + 0 <= idx < 320 - In L32 + 320 <= idx < 360 - In L8 + 360 <= idx < 367 - In LScalar + + The testsuite requires 2 indices(and 2 exception values) to be induced in the vector. +*/ + +// Exception value testing with unit strides +INSTANTIATE_TEST_SUITE_P( + unitStrides_zen4, + DISABLED_damaxvEVT, + ::testing::Combine( + ::testing::Values(gtint_t(367)), // n, size of vectors with unit-stride + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(0), gtint_t(315), + gtint_t(340), gtint_t(363)), // xi, index for exval in xi_exval + ::testing::Values(NaN, -Inf, Inf, double(2.3)), // xi_exval + ::testing::Values(gtint_t(1), gtint_t(300), + gtint_t(327), gtint_t(366)), // xj, index for exval in xj_exval + ::testing::Values(NaN, -Inf, Inf, double(2.3)) // xj_exval + ), + ::amaxvEVTPrint() + ); + + +// Exception value testing with non-unit strides +INSTANTIATE_TEST_SUITE_P( + nonUnitStrides, + DISABLED_damaxvEVT, + ::testing::Combine( + ::testing::Values(gtint_t(10)), // n, size of vectors with unit-stride + ::testing::Values(gtint_t(3)), // stride size for x + ::testing::Values(gtint_t(0), gtint_t(5)), // xi, index for exval in xi_exval + ::testing::Values(NaN, -Inf, Inf, double(2.3)), // xi_exval + ::testing::Values(gtint_t(5), gtint_t(9)), // xj, index for exval in xj_exval + ::testing::Values(NaN, -Inf, Inf, double(2.3)) // xj_exval + ), + ::amaxvEVTPrint() + ); diff --git a/gtestsuite/testsuite/level1/amaxv/damaxv_generic.cpp b/gtestsuite/testsuite/level1/amaxv/damaxv_generic.cpp index 7646911796..ffd1f7c29c 100644 --- a/gtestsuite/testsuite/level1/amaxv/damaxv_generic.cpp +++ b/gtestsuite/testsuite/level1/amaxv/damaxv_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,76 +35,78 @@ #include #include "test_amaxv.h" -class damaxvGenericTest : - public ::testing::TestWithParam> {}; +class damaxvGeneric : + public ::testing::TestWithParam> {}; //incx -// Tests using random integers as vector elements. -TEST_P( damaxvGenericTest, RandomData ) +// Tests using random values as vector elements. +TEST_P( damaxvGeneric, API ) { using T = double; //---------------------------------------------------------- // Initialize values from the parameters passed through // test suite instantiation (INSTANTIATE_TEST_SUITE_P). //---------------------------------------------------------- - // vector length: + // vector length gtint_t n = std::get<0>(GetParam()); - // stride size for x: + // stride size for x gtint_t incx = std::get<1>(GetParam()); - // Set the threshold for the errors: - double thresh = testinghelpers::getEpsilon(); - //---------------------------------------------------------- // Call generic test body using those parameters //---------------------------------------------------------- - test_amaxv( n, incx, thresh ); + test_amaxv( n, incx ); } -// Used to generate a test case with a sensible name. -// Beware that we cannot use fp numbers (e.g., 2.3) in the names, -// so we are only printing int(2.3). This should be enough for debugging purposes. -// If this poses an issue, please reach out. -class damaxvGenericTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - gtint_t n = std::get<0>(str.param); - gtint_t incx = std::get<1>(str.param); -#ifdef TEST_BLAS - std::string str_name = "idamax_"; -#elif TEST_CBLAS - std::string str_name = "cblas_idamax"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_damaxv"; -#endif - str_name += "_" + std::to_string(n); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name += "_" + incx_str; - return str_name; - } -}; +//Black box testing extended for different range of values +INSTANTIATE_TEST_SUITE_P( + Blackbox_Small_Sizes, + damaxvGeneric, + ::testing::Combine( + ::testing::Range(gtint_t(1), gtint_t(11), 1), // n size of vector takes values from 1 to 11 with step size of 1. + ::testing::Values(gtint_t(1)) // stride size for x + ), + ::amaxvGenericPrint() + ); -// Black box testing for generic and main use of samaxv. INSTANTIATE_TEST_SUITE_P( - Blackbox, - damaxvGenericTest, + Blackbox_Average_Sizes, + damaxvGeneric, ::testing::Combine( - ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. + ::testing::Range(gtint_t(100), gtint_t(502), 50), // n size of vector takes values from 100 to 500 with step size of 50. + ::testing::Values(gtint_t(1)) // stride size for x + ), + ::amaxvGenericPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + Blackbox_Max_Sizes, + damaxvGeneric, + ::testing::Combine( + ::testing::Range(gtint_t(1024), gtint_t(65535), 1023), // n size of vector takes values from 2pow10 to 2pow16-1 with step size of 1023. ::testing::Values(gtint_t(1)) // stride size for x ), - ::damaxvGenericTestPrint() + ::amaxvGenericPrint() + ); + +//Non unit testing extended for different stride values +INSTANTIATE_TEST_SUITE_P( + NonUnitIncrements_Stride, + damaxvGeneric, + ::testing::Combine( + ::testing::Values(gtint_t(123), gtint_t(111), gtint_t(20)), // m size of vector + ::testing::Values(gtint_t(4), gtint_t(8)) // stride size for x + ), + ::amaxvGenericPrint() ); -// Test for non-unit increments. -// Only test very few cases as sanity check. -// We can modify the values using implementantion details. INSTANTIATE_TEST_SUITE_P( - NonUnitIncrements, - damaxvGenericTest, + Blackbox_Stride_Greater, + damaxvGeneric, ::testing::Combine( - ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector - ::testing::Values(gtint_t(2), gtint_t(11)) // stride size for x + ::testing::Range(gtint_t(1), gtint_t(10), 1), // n size of vector takes values from 1 to 10 with step size 1 + ::testing::Values(gtint_t(11)) // stride size for x ), - ::damaxvGenericTestPrint() + ::amaxvGenericPrint() ); + diff --git a/gtestsuite/testsuite/level1/amaxv/samaxv_evt.cpp b/gtestsuite/testsuite/level1/amaxv/samaxv_evt.cpp new file mode 100644 index 0000000000..09e954ab02 --- /dev/null +++ b/gtestsuite/testsuite/level1/amaxv/samaxv_evt.cpp @@ -0,0 +1,172 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_amaxv.h" + +class DISABLED_samaxvEVT : + public ::testing::TestWithParam> {}; // xj_exval + +// Tests using random values as vector elements. +TEST_P( DISABLED_samaxvEVT, API ) +{ + + using T = float; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // vector length + gtint_t n = std::get<0>(GetParam()); + // stride size for x + gtint_t incx = std::get<1>(GetParam()); + // index for exval in x + gtint_t xi = std::get<2>(GetParam()); + // exval for index xi + T xi_exval = std::get<3>(GetParam()); + // index for exval in x + gtint_t xj = std::get<4>(GetParam()); + // exval for index xj + T xj_exval = std::get<5>(GetParam()); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_amaxv( n, incx, xi, xi_exval, xj, xj_exval ); +} + +static float NaN = std::numeric_limits::quiet_NaN(); +static float Inf = std::numeric_limits::infinity(); + +/* + Exception value testing on vectors(Zen3) : + SAMAXV currently uses the bli_samaxv_zen_int( ... ) kernel for computation on zen3 + machines. + The sizes and indices given in the instantiator are to ensure code coverage inside + the kernel. + + Kernel structure for bli_samaxv_zen_int( ... ) is as follows : + Main loop : In blocks of 8 --> L8 + Fringe loops : Element-wise loop --> LScalar + + The sizes chosen are as follows : + 61 - 7*L8 + 5(LScalar) + + The following indices are sufficient to ensure code-coverage of loops : + 0 <= idx < 56 - In L8 + 56 <= idx < 61 - In LScalar + + The testsuite requires 2 indices(and 2 exception values) to set exception values in the vector. +*/ + +// Exception value testing with unit strides +INSTANTIATE_TEST_SUITE_P( + unitStrides_zen3, + DISABLED_samaxvEVT, + ::testing::Combine( + ::testing::Values(gtint_t(61)), // n, size of vectors with unit-stride + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(0), gtint_t(48), + gtint_t(55), gtint_t(57)), // xi, index for exval in xi_exval + ::testing::Values(NaN, -Inf, Inf, float(2.3)), // xi_exval + ::testing::Values(gtint_t(1), gtint_t(33), + gtint_t(50), gtint_t(60)), // xj, index for exval in xj_exval + ::testing::Values(NaN, -Inf, Inf, float(2.3)) // xj_exval + ), + ::amaxvEVTPrint() + ); + +/* + Exception value testing on vectors(Zen4) : + SAMAXV currently uses the bli_samaxv_zen_int_avx512( ... ) kernel for computation on zen3 + machines. + The sizes and indices given in the instantiator are to ensure code coverage inside + the kernel. + + Kernel structure for bli_samaxv_zen_int_avx512( ... ) is as follows : + + For unit strides : + Main loop : In blocks of 80 --> L80 + Fringe loops : In blocks of 16 --> L16 + Element-wise loop --> LScalar + + For non-unit strides : A single loop, to process element wise. + + The sizes chosen are as follows : + 461 - 5*L80 + 3*L16 + 13(LScalar) + + The following indices are sufficient to ensure code-coverage of loops : + 0 <= idx < 400 - In L80 + 400 <= idx < 448 - In L16 + 448 <= idx < 461 - In LScalar + + The testsuite requires 2 indices(and 2 exception values) to set exception values in the vector. +*/ +// Exception value testing with unit strides +INSTANTIATE_TEST_SUITE_P( + unitStrides_zen4, + DISABLED_samaxvEVT, + ::testing::Combine( + ::testing::Values(gtint_t(461)), // n, size of vectors with unit-stride + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(0), gtint_t(347), + gtint_t(420), gtint_t(459)), // xi, index for exval in xi_exval + ::testing::Values(NaN, -Inf, Inf, float(2.3)), // xi_exval + ::testing::Values(gtint_t(101), gtint_t(252), + gtint_t(447), gtint_t(450)), // xj, index for exval in xj_exval + ::testing::Values(NaN, -Inf, Inf, float(2.3)) // xj_exval + ), + ::amaxvEVTPrint() + ); + + +// Exception value testing with non-unit strides +INSTANTIATE_TEST_SUITE_P( + nonUnitStrides, + DISABLED_samaxvEVT, + ::testing::Combine( + ::testing::Values(gtint_t(10)), // n, size of vectors with unit-stride + ::testing::Values(gtint_t(3)), // stride size for x + ::testing::Values(gtint_t(0), gtint_t(5)), // xi, index for exval in xi_exval + ::testing::Values(NaN, Inf, -Inf, float(2.3)), // xi_exval + ::testing::Values(gtint_t(1), gtint_t(9)), // xj, index for exval in xj_exval + ::testing::Values(NaN, -Inf, Inf, float(2.3)) // xj_exval + ), + ::amaxvEVTPrint() + ); diff --git a/gtestsuite/testsuite/level1/amaxv/samaxv_generic.cpp b/gtestsuite/testsuite/level1/amaxv/samaxv_generic.cpp index 111d51423f..a6dd1ab2c5 100644 --- a/gtestsuite/testsuite/level1/amaxv/samaxv_generic.cpp +++ b/gtestsuite/testsuite/level1/amaxv/samaxv_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,76 +35,77 @@ #include #include "test_amaxv.h" -class samaxvGenericTest : - public ::testing::TestWithParam> {}; +class samaxvGeneric : + public ::testing::TestWithParam> {}; //incx -// Tests using random integers as vector elements. -TEST_P( samaxvGenericTest, RandomData ) +// Tests using random values as vector elements. +TEST_P( samaxvGeneric, API ) { using T = float; //---------------------------------------------------------- // Initialize values from the parameters passed through // test suite instantiation (INSTANTIATE_TEST_SUITE_P). //---------------------------------------------------------- - // vector length: + // vector length gtint_t n = std::get<0>(GetParam()); - // stride size for x: + // stride size for x gtint_t incx = std::get<1>(GetParam()); - // Set the threshold for the errors: - double thresh = testinghelpers::getEpsilon(); - //---------------------------------------------------------- // Call generic test body using those parameters //---------------------------------------------------------- - test_amaxv( n, incx, thresh ); + test_amaxv( n, incx ); } -// Used to generate a test case with a sensible name. -// Beware that we cannot use fp numbers (e.g., 2.3) in the names, -// so we are only printing int(2.3). This should be enough for debugging purposes. -// If this poses an issue, please reach out. -class samaxvGenericTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - gtint_t n = std::get<0>(str.param); - gtint_t incx = std::get<1>(str.param); -#ifdef TEST_BLAS - std::string str_name = "isamax_"; -#elif TEST_CBLAS - std::string str_name = "cblas_isamax"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_samaxv"; -#endif - str_name += "_" + std::to_string(n); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name += "_" + incx_str; - return str_name; - } -}; +//Black box testing extended for different range of values +INSTANTIATE_TEST_SUITE_P( + Blackbox_Small_Size, + samaxvGeneric, + ::testing::Combine( + ::testing::Range(gtint_t(1), gtint_t(11), 1), // n size of vector takes values from 1 to 11 with step size of 1. + ::testing::Values(gtint_t(1)) // stride size for x + ), + ::amaxvGenericPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + Blackbox_Average_Size, + samaxvGeneric, + ::testing::Combine( + ::testing::Range(gtint_t(100), gtint_t(502), 50), // n size of vector takes values from 100 to 500 with step size of 50. + ::testing::Values(gtint_t(1)) // stride size for x + ), + ::amaxvGenericPrint() + ); -// Black box testing for generic and main use of samaxv. INSTANTIATE_TEST_SUITE_P( - Blackbox, - samaxvGenericTest, + Blackbox_Max_Size, + samaxvGeneric, ::testing::Combine( - ::testing::Range(gtint_t(10), gtint_t(101), 10), // n size of vector takes values from 10 to 100 with step size of 10. + ::testing::Range(gtint_t(1024), gtint_t(65535), 1023), // n size of vector takes values from 2pow10 to 2pow16-1 with step size of 1023. ::testing::Values(gtint_t(1)) // stride size for x ), - ::samaxvGenericTestPrint() + ::amaxvGenericPrint() + ); + +//Non unit testing extended for different stride values +INSTANTIATE_TEST_SUITE_P( + NonUnitIncrements_Stride, + samaxvGeneric, + ::testing::Combine( + ::testing::Values(gtint_t(123), gtint_t(111), gtint_t(20)), // m size of vector + ::testing::Values(gtint_t(4), gtint_t(8)) // stride size for x + ), + ::amaxvGenericPrint() ); -// Test for non-unit increments. -// Only test very few cases as sanity check. -// We can modify the values using implementantion details. INSTANTIATE_TEST_SUITE_P( - NonUnitIncrements, - samaxvGenericTest, + Blackbox_Stride_Greater, + samaxvGeneric, ::testing::Combine( - ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector - ::testing::Values(gtint_t(2), gtint_t(11)) // stride size for x + ::testing::Range(gtint_t(1), gtint_t(10), 1), // n size of vector takes values from 1 to 10 with step size 1 + ::testing::Values(gtint_t(11)) // stride size for x ), - ::samaxvGenericTestPrint() + ::amaxvGenericPrint() ); diff --git a/gtestsuite/testsuite/level1/amaxv/test_amaxv.h b/gtestsuite/testsuite/level1/amaxv/test_amaxv.h index a02464e8ee..04fe449a8d 100644 --- a/gtestsuite/testsuite/level1/amaxv/test_amaxv.h +++ b/gtestsuite/testsuite/level1/amaxv/test_amaxv.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -43,7 +43,7 @@ */ template -void test_amaxv( gtint_t n, gtint_t incx, double thresh ) +static void test_amaxv( gtint_t n, gtint_t incx ) { //---------------------------------------------------------- // Initialize vectors with random numbers. @@ -63,5 +63,77 @@ void test_amaxv( gtint_t n, gtint_t incx, double thresh ) //---------------------------------------------------------- // Compute component-wise error. //---------------------------------------------------------- - EXPECT_EQ( idx, idx_ref ); + computediff( "idx", idx, idx_ref ); } + +/** + * @brief Generic test body for amaxv operation with extreme values. + */ +template +static void test_amaxv( gtint_t n, gtint_t incx, gtint_t xi, T xi_exval, + gtint_t xj, T xj_exval ) +{ + //---------------------------------------------------------- + // Initialize vectors with random numbers. + //---------------------------------------------------------- + std::vector x = testinghelpers::get_random_vector( -10, 10, n, incx ); + + // Update the value at index xi to an extreme value, x_exval. + if ( -1 < xi && xi < n ) x[xi * abs(incx)] = xi_exval; + else return; + + // Update the value at index yi to an extreme value, y_exval. + if ( -1 < xj && xj < n ) x[xj * abs(incx)] = xj_exval; + else return; + + //---------------------------------------------------------- + // Call reference implementation to get ref results. + //---------------------------------------------------------- + gtint_t idx_ref = testinghelpers::ref_amaxv( n, x.data(), incx ); + + //---------------------------------------------------------- + // Call BLIS function. + //---------------------------------------------------------- + gtint_t idx = amaxv( n, x.data(), incx ); + + //---------------------------------------------------------- + // Compute component-wise error. + //---------------------------------------------------------- + computediff( "idx", idx, idx_ref ); +} + +// Test-case logger : Used to print the test-case details when vectors have exception value. +class amaxvGenericPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + gtint_t n = std::get<0>(str.param); + gtint_t incx = std::get<1>(str.param); + + std::string str_name = API_PRINT; + str_name += "_n_" + std::to_string(n); + str_name += "_incx_" + testinghelpers::get_value_string(incx); + return str_name; + } +}; + +template +class amaxvEVTPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + gtint_t n = std::get<0>(str.param); + gtint_t incx = std::get<1>(str.param); + gtint_t xi = std::get<2>(str.param); + T xi_exval = std::get<3>(str.param); + gtint_t xj = std::get<4>(str.param); + T xj_exval = std::get<5>(str.param); + + std::string str_name = API_PRINT; + str_name += "_n_" + std::to_string(n); + str_name += "_incx_" + testinghelpers::get_value_string(incx); + str_name = str_name + "_X_" + std::to_string(xi) + "_" + testinghelpers::get_value_string(xi_exval); + str_name = str_name + "_" + std::to_string(xj) + "_" + testinghelpers::get_value_string(xj_exval); + return str_name; + } +}; diff --git a/gtestsuite/testsuite/level1/amaxv/zamaxv_generic.cpp b/gtestsuite/testsuite/level1/amaxv/zamaxv_generic.cpp index 9c35ed502b..0b2a0409ba 100644 --- a/gtestsuite/testsuite/level1/amaxv/zamaxv_generic.cpp +++ b/gtestsuite/testsuite/level1/amaxv/zamaxv_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,76 +35,77 @@ #include #include "test_amaxv.h" -class zamaxvGenericTest : - public ::testing::TestWithParam> {}; +class zamaxvGeneric : + public ::testing::TestWithParam> {}; //incx -// Tests using random integers as vector elements. -TEST_P( zamaxvGenericTest, RandomData ) +// Tests using random values as vector elements. +TEST_P( zamaxvGeneric, API ) { using T = dcomplex; //---------------------------------------------------------- // Initialize values from the parameters passed through // test suite instantiation (INSTANTIATE_TEST_SUITE_P). //---------------------------------------------------------- - // vector length: + // vector length gtint_t n = std::get<0>(GetParam()); - // stride size for x: + // stride size for x gtint_t incx = std::get<1>(GetParam()); - // Set the threshold for the errors: - double thresh = testinghelpers::getEpsilon(); - //---------------------------------------------------------- // Call generic test body using those parameters //---------------------------------------------------------- - test_amaxv( n, incx, thresh ); + test_amaxv( n, incx ); } -// Used to generate a test case with a sensible name. -// Beware that we cannot use fp numbers (e.g., 2.3) in the names, -// so we are only printing int(2.3). This should be enough for debugging purposes. -// If this poses an issue, please reach out. -class zamaxvGenericTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - gtint_t n = std::get<0>(str.param); - gtint_t incx = std::get<1>(str.param); -#ifdef TEST_BLAS - std::string str_name = "izamax_"; -#elif TEST_CBLAS - std::string str_name = "cblas_izamax"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_zamaxv"; -#endif - str_name += "_" + std::to_string(n); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name += "_" + incx_str; - return str_name; - } -}; +//Black box testing extended for different range of values +INSTANTIATE_TEST_SUITE_P( + Blackbox_Small_Sizes, + zamaxvGeneric, + ::testing::Combine( + ::testing::Range(gtint_t(1), gtint_t(11), 1), // n size of vector takes values from 1 to 11 with step size of 1. + ::testing::Values(gtint_t(1)) // stride size for x + ), + ::amaxvGenericPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + Blackbox_Average_Sizes, + zamaxvGeneric, + ::testing::Combine( + ::testing::Range(gtint_t(100), gtint_t(502), 50), // n size of vector takes values from 100 to 500 with step size of 50. + ::testing::Values(gtint_t(1)) // stride size for x + ), + ::amaxvGenericPrint() + ); -// Black box testing for generic and main use of zamaxv. INSTANTIATE_TEST_SUITE_P( - Blackbox, - zamaxvGenericTest, + Blackbox_Max_Sizes, + zamaxvGeneric, ::testing::Combine( - ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. + ::testing::Range(gtint_t(1024), gtint_t(65535), 1023), // n size of vector takes values from 2pow10 to 2pow16-1 with step size of 1023. ::testing::Values(gtint_t(1)) // stride size for x ), - ::zamaxvGenericTestPrint() + ::amaxvGenericPrint() + ); + +//Non unit testing extended for different stride values +INSTANTIATE_TEST_SUITE_P( + NonUnitIncrements_Stride, + zamaxvGeneric, + ::testing::Combine( + ::testing::Values(gtint_t(123), gtint_t(111), gtint_t(20)), // m size of vector + ::testing::Values(gtint_t(4), gtint_t(8)) // stride size for x + ), + ::amaxvGenericPrint() ); -// Test for non-unit increments. -// Only test very few cases as sanity check. -// We can modify the values using implementantion details. INSTANTIATE_TEST_SUITE_P( - NonUnitIncrements, - zamaxvGenericTest, + Blackbox_Stride_Greater, + zamaxvGeneric, ::testing::Combine( - ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector - ::testing::Values(gtint_t(2), gtint_t(11)) // stride size for x + ::testing::Range(gtint_t(1), gtint_t(10), 1), // n size of vector takes values from 1 to 10 with step size 1 + ::testing::Values(gtint_t(11)) // stride size for x ), - ::zamaxvGenericTestPrint() + ::amaxvGenericPrint() ); diff --git a/gtestsuite/testsuite/level1/axpbyv/IIT_ERS_test.cpp b/gtestsuite/testsuite/level1/axpbyv/IIT_ERS_test.cpp deleted file mode 100644 index 5e568b0655..0000000000 --- a/gtestsuite/testsuite/level1/axpbyv/IIT_ERS_test.cpp +++ /dev/null @@ -1,96 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include "common/testing_helpers.h" -#include "axpbyv.h" -#include "inc/check_error.h" -#include "common/wrong_inputs_helpers.h" - -template -class Axpby_IIT_ERS_Test : public ::testing::Test {}; -typedef ::testing::Types TypeParam; // The supported datatypes from BLAS calls for AXPBY -TYPED_TEST_SUITE(Axpby_IIT_ERS_Test, TypeParam); // Defining individual testsuites based on the datatype support. - -// Adding namespace to get default parameters(valid case) from testinghelpers/common/wrong_input_helpers.h. -using namespace testinghelpers::IIT; - -/* - Early Return Scenarios(ERS) : - - The AXPBY API is expected to return early in the following cases: - 1. When n < 0. - -*/ - -#ifdef TEST_BLAS - -// When n < 0 -TYPED_TEST(Axpby_IIT_ERS_Test, n_lt_zero) -{ - using T = TypeParam; - // Defining the C matrix with values for debugging purposes - std::vector y = testinghelpers::get_random_vector( -10, 10, N, INC ); - - T alpha, beta; - testinghelpers::initone( alpha ); - testinghelpers::initzero( beta ); - // Copy so that we check that the elements of C are not modified. - std::vector y_ref(y); - - axpbyv( CONJ, -1, alpha, nullptr, INC, beta, y.data(), INC ); - // Use bitwise comparison (no threshold). - computediff( N, y.data(), y_ref.data(), INC ); -} - -// When n = 0 -TYPED_TEST(Axpby_IIT_ERS_Test, n_eq_zero) -{ - using T = TypeParam; - // Defining the C matrix with values for debugging purposes - std::vector y = testinghelpers::get_random_vector( -10, 10, N, INC ); - - T alpha, beta; - testinghelpers::initone( alpha ); - testinghelpers::initzero( beta ); - // Copy so that we check that the elements of C are not modified. - std::vector y_ref(y); - - axpbyv( CONJ, 0, alpha, nullptr, INC, beta, y.data(), INC ); - // Use bitwise comparison (no threshold). - computediff( N, y.data(), y_ref.data(), INC ); -} - -#endif - diff --git a/gtestsuite/testsuite/level1/axpbyv/axpbyv.h b/gtestsuite/testsuite/level1/axpbyv/axpbyv.h index 0c415e1b0c..16c14a6a41 100644 --- a/gtestsuite/testsuite/level1/axpbyv/axpbyv.h +++ b/gtestsuite/testsuite/level1/axpbyv/axpbyv.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -36,6 +36,7 @@ #include "blis.h" #include "common/testing_helpers.h" +#include "inc/check_error.h" /** * @brief Performs the operation: @@ -66,6 +67,21 @@ static void axpbyv_(gtint_t n, T alpha, T* x, gtint_t incx, T beta, T* y, gtint_ throw std::runtime_error("Error in testsuite/level1/axpbyv.h: Invalid typename in axpbyv_()."); } +template +static void axpbyv_blis_impl(gtint_t n, T alpha, T* x, gtint_t incx, T beta, T* y, gtint_t incy) +{ + if constexpr (std::is_same::value) + saxpby_blis_impl( &n, &alpha, x, &incx, &beta, y, &incy ); + else if constexpr (std::is_same::value) + daxpby_blis_impl( &n, &alpha, x, &incx, &beta, y, &incy ); + else if constexpr (std::is_same::value) + caxpby_blis_impl( &n, &alpha, x, &incx, &beta, y, &incy ); + else if constexpr (std::is_same::value) + zaxpby_blis_impl( &n, &alpha, x, &incx, &beta, y, &incy ); + else + throw std::runtime_error("Error in testsuite/level1/axpbyv.h: Invalid typename in axpbyv_blis_impl()."); +} + template static void cblas_axpbyv(gtint_t n, T alpha, T* x, gtint_t incx, T beta, T* y, gtint_t incy) { @@ -102,8 +118,34 @@ static void typed_axpbyv(char conj_x, gtint_t n, T alpha, T* x, gtint_t incx, T template static void axpbyv(char conj_x, gtint_t n, T alpha, T* x, gtint_t incx, T beta, T* y, gtint_t incy) { + +#ifdef TEST_UPPERCASE_ARGS + conj_x = static_cast(std::toupper(static_cast(conj_x))); +#endif + +#ifdef TEST_INPUT_ARGS + // Create copy of scalar input values so we can check that they are not altered. + char conj_x_cpy = conj_x; + gtint_t n_cpy = n; + T alpha_cpy = alpha; + gtint_t incx_cpy = incx; + T beta_cpy = beta; + gtint_t incy_cpy = incy; + + // Create copy of input arrays so we can check that they are not altered. + T* x_cpy = nullptr; + gtint_t size_x = testinghelpers::buff_dim( n, incx ); + if (x && size_x > 0) + { + x_cpy = new T[size_x]; + memcpy( x_cpy, x, size_x * sizeof( T ) ); + } +#endif + #ifdef TEST_BLAS axpbyv_( n, alpha, x, incx, beta, y, incy ); +#elif TEST_BLAS_BLIS_IMPL + axpbyv_blis_impl( n, alpha, x, incx, beta, y, incy ); #elif TEST_CBLAS cblas_axpbyv( n, alpha, x, incx, beta, y, incy ); #elif TEST_BLIS_TYPED @@ -111,4 +153,27 @@ static void axpbyv(char conj_x, gtint_t n, T alpha, T* x, gtint_t incx, T beta, #else throw std::runtime_error("Error in testsuite/level1/axpbyv.h: No interfaces are set to be tested."); #endif + +#ifdef TEST_INPUT_ARGS + //---------------------------------------------------------- + // Check scalar inputs have not been modified. + //---------------------------------------------------------- + + computediff( "conj_x", conj_x, conj_x_cpy ); + computediff( "n", n, n_cpy ); + computediff( "alpha", alpha, alpha_cpy ); + computediff( "incx", incx, incx_cpy ); + computediff( "beta", beta, beta_cpy ); + computediff( "incy", incy, incy_cpy ); + + //---------------------------------------------------------- + // Bitwise-wise check array inputs have not been modified. + //---------------------------------------------------------- + + if (x && size_x > 0) + { + computediff( "x", n, x, x_cpy, incx, true ); + delete[] x_cpy; + } +#endif } diff --git a/gtestsuite/testsuite/level1/axpbyv/axpbyv_IIT_ERS.cpp b/gtestsuite/testsuite/level1/axpbyv/axpbyv_IIT_ERS.cpp new file mode 100644 index 0000000000..f847cb2742 --- /dev/null +++ b/gtestsuite/testsuite/level1/axpbyv/axpbyv_IIT_ERS.cpp @@ -0,0 +1,211 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "common/testing_helpers.h" +#include "axpbyv.h" +#include "inc/check_error.h" +#include "common/wrong_inputs_helpers.h" + +template +class axpbyv_IIT_ERS : public ::testing::Test {}; +typedef ::testing::Types TypeParam; // The supported datatypes from BLAS/CBLAS calls for AXPBY +TYPED_TEST_SUITE(axpbyv_IIT_ERS, TypeParam); // Defining individual testsuites based on the datatype support. + +// Adding namespace to get default parameters(valid case) from testinghelpers/common/wrong_input_helpers.h. +using namespace testinghelpers::IIT; + +#if defined(TEST_BLAS_LIKE) || defined(TEST_CBLAS) +/* + Early Return Scenarios(ERS) : + The early return cases for ?axpbyv are not defined under BLAS compliance. + Thus, the necessary cases to match the other standards are tested. + + The AXPBY API is expected to return early in the following cases: + 1. When n <= 0. + 2. When alpha is 0 and beta is 1. +*/ + +// Early return cases with non-unit strides on vectors +// When n < 0 +TYPED_TEST(axpbyv_IIT_ERS, n_lt_zero_nonUnitStrides) +{ + using T = TypeParam; + T alpha, beta; + testinghelpers::initone( alpha ); + testinghelpers::initzero( beta ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + axpbyv( CONJ, -1, alpha, nullptr, 5, beta, nullptr, 5 ); + + // Test with all arguments correct except for the value we are choosing to test. + // Defining the x vector + std::vector x = testinghelpers::get_random_vector( -10, 10, N, 5 ); + // Defining the y vector with values for debugging purposes + std::vector y = testinghelpers::get_random_vector( -10, 10, N, 5 ); + + // Copy so that we check that the elements of y are not modified. + std::vector y_ref(y); + + axpbyv( CONJ, -1, alpha, x.data(), 5, beta, y.data(), 5 ); + // Use bitwise comparison (no threshold). + computediff( "y", N, y.data(), y_ref.data(), 5 ); +} + +// When n = 0 +TYPED_TEST(axpbyv_IIT_ERS, n_eq_zero_nonUnitStrides) +{ + using T = TypeParam; + T alpha, beta; + testinghelpers::initone( alpha ); + testinghelpers::initzero( beta ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + axpbyv( CONJ, 0, alpha, nullptr, 5, beta, nullptr, 5 ); + + // Test with all arguments correct except for the value we are choosing to test. + // Defining the x vector + std::vector x = testinghelpers::get_random_vector( -10, 10, N, 5 ); + // Defining the y vector with values for debugging purposes + std::vector y = testinghelpers::get_random_vector( -10, 10, N, 5 ); + + // Copy so that we check that the elements of y are not modified. + std::vector y_ref(y); + + axpbyv( CONJ, 0, alpha, x.data(), 5, beta, y.data(), 5 ); + // Use bitwise comparison (no threshold). + computediff( "y", N, y.data(), y_ref.data(), 5 ); +} + + +TYPED_TEST(axpbyv_IIT_ERS, alpha_eq_zero_beta_eq_one_nonUnitStrides) +{ + using T = TypeParam; + T alpha, beta; + testinghelpers::initzero( alpha ); + testinghelpers::initone( beta ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + axpbyv( CONJ, N, alpha, nullptr, 5, beta, nullptr, 5 ); + + // Test with all arguments correct except for the value we are choosing to test. + // Defining the x vector + std::vector x = testinghelpers::get_random_vector( -10, 10, N, 5 ); + // Defining the y vector with values for debugging purposes + std::vector y = testinghelpers::get_random_vector( -10, 10, N, 5 ); + + // Copy so that we check that the elements of y are not modified. + std::vector y_ref(y); + + axpbyv( CONJ, N, alpha, x.data(), 5, beta, y.data(), 5 ); + // Use bitwise comparison (no threshold). + computediff( "y", N, y.data(), y_ref.data(), 5 ); +} + +// Early return cases with unit strides on vectors +// When n < 0 +TYPED_TEST(axpbyv_IIT_ERS, n_lt_zero_unitStrides) +{ + using T = TypeParam; + T alpha, beta; + testinghelpers::initone( alpha ); + testinghelpers::initzero( beta ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + axpbyv( CONJ, -1, alpha, nullptr, 1, beta, nullptr, 1 ); + + // Test with all arguments correct except for the value we are choosing to test. + // Defining the x vector + std::vector x = testinghelpers::get_random_vector( -10, 10, N, 1 ); + // Defining the y vector with values for debugging purposes + std::vector y = testinghelpers::get_random_vector( -10, 10, N, 1 ); + + // Copy so that we check that the elements of y are not modified. + std::vector y_ref(y); + + axpbyv( CONJ, -1, alpha, x.data(), 1, beta, y.data(), 1 ); + // Use bitwise comparison (no threshold). + computediff( "y", N, y.data(), y_ref.data(), 1 ); +} + +// When n = 0 +TYPED_TEST(axpbyv_IIT_ERS, n_eq_zero_unitStrides) +{ + using T = TypeParam; + T alpha, beta; + testinghelpers::initone( alpha ); + testinghelpers::initzero( beta ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + axpbyv( CONJ, 0, alpha, nullptr, 1, beta, nullptr, 1 ); + + // Test with all arguments correct except for the value we are choosing to test. + // Defining the x vector + std::vector x = testinghelpers::get_random_vector( -10, 10, N, 1 ); + // Defining the y vector with values for debugging purposes + std::vector y = testinghelpers::get_random_vector( -10, 10, N, 1 ); + + // Copy so that we check that the elements of y are not modified. + std::vector y_ref(y); + + axpbyv( CONJ, 0, alpha, x.data(), 1, beta, y.data(), 1 ); + // Use bitwise comparison (no threshold). + computediff( "y", N, y.data(), y_ref.data(), 1 ); +} + +// When alpha = 0 and beta = 1 +TYPED_TEST(axpbyv_IIT_ERS, alpha_eq_zero_beta_eq_one_unitStrides) +{ + using T = TypeParam; + T alpha, beta; + testinghelpers::initzero( alpha ); + testinghelpers::initone( beta ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + axpbyv( CONJ, N, alpha, nullptr, 1, beta, nullptr, 1 ); + + // Test with all arguments correct except for the value we are choosing to test. + // Defining the x vector + std::vector x = testinghelpers::get_random_vector( -10, 10, N, 1 ); + // Defining the y vector with values for debugging purposes + std::vector y = testinghelpers::get_random_vector( -10, 10, N, 1 ); + + // Copy so that we check that the elements of y are not modified. + std::vector y_ref(y); + + axpbyv( CONJ, N, alpha, x.data(), 1, beta, y.data(), 1 ); + // Use bitwise comparison (no threshold). + computediff( "y", N, y.data(), y_ref.data(), 1 ); +} +#endif diff --git a/gtestsuite/testsuite/level1/axpbyv/caxpbyv_evt.cpp b/gtestsuite/testsuite/level1/axpbyv/caxpbyv_evt.cpp new file mode 100644 index 0000000000..b400e36b24 --- /dev/null +++ b/gtestsuite/testsuite/level1/axpbyv/caxpbyv_evt.cpp @@ -0,0 +1,360 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_axpbyv.h" + +class caxpbyvEVT : + public ::testing::TestWithParam> {}; // beta + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(caxpbyvEVT); + +// Tests using random integers as vector elements. +TEST_P( caxpbyvEVT, API ) +{ + using T = scomplex; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes whether x or conj(x) will be added to y: + char conj_x = std::get<0>(GetParam()); + // vector length: + gtint_t n = std::get<1>(GetParam()); + // stride size for x: + gtint_t incx = std::get<2>(GetParam()); + // stride size for y: + gtint_t incy = std::get<3>(GetParam()); + // index for exval in x + gtint_t xi = std::get<4>(GetParam()); + // exval for x + T xexval = std::get<5>(GetParam()); + // index for exval in y + gtint_t yj = std::get<6>(GetParam()); + // exval for x + T yexval = std::get<7>(GetParam()); + // alpha + T alpha = std::get<8>(GetParam()); + // beta + T beta = std::get<9>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite axpbyv.h (no netlib version) for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // With adjustment for complex data. + // NOTE : Every mul for complex types involves 3 ops(2 muls + 1 add) + double thresh; + double adj = 3; + if (n == 0) + thresh = 0.0; + else if (beta == testinghelpers::ZERO()) + { + // Like SETV or COPYV(no ops) + if (alpha == testinghelpers::ZERO() || alpha == testinghelpers::ONE()) + thresh = 0.0; + // Like SCAL2V(1 mul) + else + thresh = (1 * adj) * testinghelpers::getEpsilon(); + } + else if (beta == testinghelpers::ONE()) + { + // Like ERS(no ops) + if (alpha == testinghelpers::ZERO()) + thresh = 0.0; + // Like ADDV(1 add) + else if (alpha == testinghelpers::ONE()) + thresh = testinghelpers::getEpsilon(); + // Like AXPYV(1 mul and 1 add) + else + thresh = (1 * adj + 1) * testinghelpers::getEpsilon(); + } + else + { + // Like SCALV(1 mul) + if (alpha == testinghelpers::ZERO()) + thresh = (1 * adj) * testinghelpers::getEpsilon(); + // Like AXPBYV(2 muls and 1 add) + else + thresh = (2 * adj + 1) * testinghelpers::getEpsilon(); + } + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_axpbyv(conj_x, n, incx, incy, alpha, beta, xi, xexval, + yj, yexval, thresh); +} + +#if defined(REF_IS_NETLIB) +static float NaN = std::numeric_limits::quiet_NaN(); +static float Inf = std::numeric_limits::infinity(); + +/* + The code structure for bli_caxpbyv_zen_int( ... ) is as follows : + For unit strides : + Main loop : In blocks of 16 --> L16 + Fringe loops : In blocks of 12 --> L12 + In blocks of 8 --> L8 + In blocks of 4 --> L4 + + For non-unit strides : A single loop, to process element wise. + NOTE : Any size, requiring the fringe case of 1 with unit stride falls to + the non-unit stride loop and executes it once for just the last element. + + The sizes chosen are as follows : + 71 - 4*L16 + L4 + 3(LScalar) + 72 - 4*L16 + L8 + 76 - 4*L16 + L12 + + For size 71 : 4*L16 + L4 + 3(LScalar) + Indices are : 0, 62 -> In L16 + 66 -> In L4 + 69 -> In LScalar + + For size 72 : 4*L16 + L8 + Indices are : 0, 62 -> In L16 + 70 -> In L8 + + For size 76 : 4*L16 + L12 + Indices are : 0, 62 -> In L16 + 74 -> In L12 + + The alpha and beta values are such that they check for compliance against possible + optimizations that might have been done. + + P.S : Some test cases also check whether NaN has to be induced in the computation + such as 0.0 * { {NaN, 0}, {+Inf, 0}, {-Inf, 0}, ... }, and a few more. +*/ + +// Exception value testing(on X vector alone) with unit strides +INSTANTIATE_TEST_SUITE_P( + vecX_unitStrides, + caxpbyvEVT, + ::testing::Combine( + ::testing::Values('n' // n: use x, c: use conj(x) +#ifdef TEST_BLIS_TYPED + , + 'c' // this option is BLIS-api specific. +#endif + ), + ::testing::Values(gtint_t(71), gtint_t(72), gtint_t(76)), // n, size of vectors with unit-stride + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(gtint_t(0), gtint_t(62), gtint_t(66), + gtint_t(69), gtint_t(70), gtint_t(74)), // indices to set exception values on x + ::testing::Values(scomplex{NaN, 0.0}, scomplex{-Inf, 0.0}, + scomplex{0.0, Inf}, scomplex{-2.3, NaN}, + scomplex{4.5, -Inf}, scomplex{NaN, Inf}), // exception values to set on x + ::testing::Values(gtint_t(0)), // dummy index on y + ::testing::Values(scomplex{0.0, 0.0}), // dummy value on y + ::testing::Values(scomplex{0.0, 0.0}, scomplex{1.0, 0.0}, + scomplex{-1.0, 0.0}, scomplex{0.0, 1.0}, + scomplex{0.0, -1.0}, scomplex{-3.3, 1.7}), // alpha + ::testing::Values(scomplex{0.0, 0.0}, scomplex{1.0, 0.0}, + scomplex{-1.0, 0.0}, scomplex{0.0, 1.0}, + scomplex{0.0, -1.0}, scomplex{-3.3, 1.7}) // beta + ), + ::axpbyvEVTPrint()); + +// Exception value testing(on Y vector alone) with unit strides +INSTANTIATE_TEST_SUITE_P( + vecY_unitStrides, + caxpbyvEVT, + ::testing::Combine( + ::testing::Values('n' // n: use x, c: use conj(x) +#ifdef TEST_BLIS_TYPED + , + 'c' // this option is BLIS-api specific. +#endif + ), + ::testing::Values(gtint_t(71), gtint_t(72), gtint_t(76)), // n, size of vectors with unit-stride + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(gtint_t(0)), // dummy index on x + ::testing::Values(scomplex{0.0, 0.0}), // dummy value on x + ::testing::Values(gtint_t(0), gtint_t(62), gtint_t(66), + gtint_t(69), gtint_t(70), gtint_t(74)), // indices to set exception values on y + ::testing::Values(scomplex{NaN, 0.0}, scomplex{-Inf, 0.0}, + scomplex{0.0, Inf}, scomplex{-2.3, NaN}, + scomplex{4.5, -Inf}, scomplex{NaN, Inf}), // exception values to set on y + ::testing::Values(scomplex{0.0, 0.0}, scomplex{1.0, 0.0}, + scomplex{-1.0, 0.0}, scomplex{0.0, 1.0}, + scomplex{0.0, -1.0}, scomplex{-3.3, 1.7}), // alpha + ::testing::Values(scomplex{0.0, 0.0}, scomplex{1.0, 0.0}, + scomplex{-1.0, 0.0}, scomplex{0.0, 1.0}, + scomplex{0.0, -1.0}, scomplex{-3.3, 1.7}) // beta + ), + ::axpbyvEVTPrint()); + +// Exception value testing(on X and Y vectors) with unit strides +INSTANTIATE_TEST_SUITE_P( + vecXY_unitStrides, + caxpbyvEVT, + ::testing::Combine( + ::testing::Values('n' // n: use x, c: use conj(x) +#ifdef TEST_BLIS_TYPED + , + 'c' // this option is BLIS-api specific. +#endif + ), + ::testing::Values(gtint_t(71), gtint_t(72), gtint_t(76)), // n, size of vectors with unit-stride + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(gtint_t(0), gtint_t(62), gtint_t(66), + gtint_t(69), gtint_t(70), gtint_t(74)), // indices to set exception values on x + ::testing::Values(scomplex{NaN, 0.0}, scomplex{-Inf, 0.0}, + scomplex{0.0, Inf}, scomplex{-2.3, NaN}, + scomplex{4.5, -Inf}, scomplex{NaN, Inf}), // exception values to set on x + ::testing::Values(gtint_t(0), gtint_t(62), gtint_t(66), + gtint_t(69), gtint_t(70), gtint_t(74)), // indices to set exception values on y + ::testing::Values(scomplex{NaN, 0.0}, scomplex{-Inf, 0.0}, + scomplex{0.0, Inf}, scomplex{-2.3, NaN}, + scomplex{4.5, -Inf}, scomplex{NaN, Inf}), // exception values to set on y + ::testing::Values(scomplex{0.0, 0.0}, scomplex{1.0, 0.0}, + scomplex{-1.0, 0.0}, scomplex{0.0, 1.0}, + scomplex{0.0, -1.0}, scomplex{-3.3, 1.7}), // alpha + ::testing::Values(scomplex{0.0, 0.0}, scomplex{1.0, 0.0}, + scomplex{-1.0, 0.0}, scomplex{0.0, 1.0}, + scomplex{0.0, -1.0}, scomplex{-3.3, 1.7}) // beta + ), + ::axpbyvEVTPrint()); + +// Exception value testing(on vectors) with non-unit strides +// We have to test a single scalar loop. The indices are such +// that we cover _vecX_, _vecY_ and _vecXY_ cases together. +INSTANTIATE_TEST_SUITE_P( + vecXY_nonUnitStrides, + caxpbyvEVT, + ::testing::Combine( + ::testing::Values('n' // n: use x, c: use conj(x) +#ifdef TEST_BLIS_TYPED + , + 'c' // this option is BLIS-api specific. +#endif + ), + ::testing::Values(gtint_t(50)), // n, size of vectors with non-unit strides + ::testing::Values(gtint_t(3)), // stride size for x + ::testing::Values(gtint_t(5)), // stride size for y + ::testing::Values(gtint_t(1), gtint_t(27), gtint_t(49)), // indices to set exception values on x + ::testing::Values(scomplex{NaN, 0.0}, scomplex{-Inf, 0.0}, + scomplex{0.0, Inf}, scomplex{-2.3, NaN}, + scomplex{4.5, -Inf}, scomplex{NaN, Inf}, + scomplex{2.3, -3.5}), // exception values to set on x + ::testing::Values(gtint_t(0), gtint_t(26), gtint_t(49)), // indices to set exception values on y + ::testing::Values(scomplex{NaN, 0.0}, scomplex{-Inf, 0.0}, + scomplex{0.0, Inf}, scomplex{-2.3, NaN}, + scomplex{4.5, -Inf}, scomplex{NaN, Inf}, + scomplex{2.3, -3.5}), // exception values to set on y + ::testing::Values(scomplex{0.0, 0.0}, scomplex{1.0, 0.0}, + scomplex{-1.0, 0.0}, scomplex{0.0, 1.0}, + scomplex{0.0, -1.0}, scomplex{-3.3, 1.7}), // alpha + ::testing::Values(scomplex{0.0, 0.0}, scomplex{1.0, 0.0}, + scomplex{-1.0, 0.0}, scomplex{0.0, 1.0}, + scomplex{0.0, -1.0}, scomplex{-3.3, 1.7}) // beta + ), + ::axpbyvEVTPrint()); + +/* + Exception value testing on alpha and beta : + Alpha values are set to Nan, +Inf or -Inf. A dummy + value of 0.0 is induced in X and Y vectors, to further + verify the propagation. +*/ +INSTANTIATE_TEST_SUITE_P( + alphaBeta_unitStrides, + caxpbyvEVT, + ::testing::Combine( + ::testing::Values('n' // n: use x, c: use conj(x) +#ifdef TEST_BLIS_TYPED + , + 'c' // this option is BLIS-api specific. +#endif + ), + ::testing::Values(gtint_t(71), gtint_t(72), gtint_t(76)), // n, size of vectors with unit-stride + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(gtint_t(0)), // indices to set zero on x + ::testing::Values(scomplex{0.0, 0.0}), + ::testing::Values(gtint_t(0)), // indices to set zero on y + ::testing::Values(scomplex{0.0, 0.0}), + ::testing::Values(scomplex{NaN, 0.0}, scomplex{-Inf, 0.0}, + scomplex{0.0, Inf}, scomplex{-2.3, NaN}, + scomplex{4.5, -Inf}, scomplex{NaN, Inf}, + scomplex{2.3, -3.7}), // alpha + ::testing::Values(scomplex{NaN, 0.0}, scomplex{-Inf, 0.0}, + scomplex{0.0, Inf}, scomplex{-2.3, NaN}, + scomplex{4.5, -Inf}, scomplex{NaN, Inf}, + scomplex{2.3, -3.7}) // beta + ), + ::axpbyvEVTPrint()); + +// Exception value testing(on alpha) with non-unit strided vectors +INSTANTIATE_TEST_SUITE_P( + alphaBeta_nonUnitStrides, + caxpbyvEVT, + ::testing::Combine( + ::testing::Values('n' // n: use x, c: use conj(x) +#ifdef TEST_BLIS_TYPED + , + 'c' // this option is BLIS-api specific. +#endif + ), + ::testing::Values(gtint_t(50)), // n, size of vectors with non-unit strides + ::testing::Values(gtint_t(3)), // stride size for x + ::testing::Values(gtint_t(5)), // stride size for y + ::testing::Values(gtint_t(0), gtint_t(25)), // indices to set zero on x + ::testing::Values(scomplex{0.0, 0.0}), + ::testing::Values(gtint_t(0), gtint_t(40)), // indices to set zero on y + ::testing::Values(scomplex{0.0, 0.0}), + ::testing::Values(scomplex{NaN, 0.0}, scomplex{-Inf, 0.0}, + scomplex{0.0, Inf}, scomplex{-2.3, NaN}, + scomplex{4.5, -Inf}, scomplex{NaN, Inf}, + scomplex{2.3, -3.7}), // alpha + ::testing::Values(scomplex{NaN, 0.0}, scomplex{-Inf, 0.0}, + scomplex{0.0, Inf}, scomplex{-2.3, NaN}, + scomplex{4.5, -Inf}, scomplex{NaN, Inf}, + scomplex{2.3, -3.7}) // beta + ), + ::axpbyvEVTPrint()); +#endif diff --git a/gtestsuite/testsuite/level1/axpbyv/caxpbyv_generic.cpp b/gtestsuite/testsuite/level1/axpbyv/caxpbyv_generic.cpp index 93f71b3412..6d80f9851d 100644 --- a/gtestsuite/testsuite/level1/axpbyv/caxpbyv_generic.cpp +++ b/gtestsuite/testsuite/level1/axpbyv/caxpbyv_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,7 +35,7 @@ #include #include "test_axpbyv.h" -class caxpbyvGenericTest : +class caxpbyvGeneric : public ::testing::TestWithParam> {}; // Tests using random integers as vector elements. -TEST_P( caxpbyvGenericTest, RandomData ) +TEST_P( caxpbyvGeneric, API ) { using T = scomplex; //---------------------------------------------------------- @@ -64,7 +64,45 @@ TEST_P( caxpbyvGenericTest, RandomData ) T beta = std::get<5>(GetParam()); // Set the threshold for the errors: - double thresh = 2*testinghelpers::getEpsilon(); + // Check gtestsuite axpbyv.h (no netlib version) for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // With adjustment for complex data. + // NOTE : Every mul for complex types involves 3 ops(2 muls + 1 add) + double thresh; + double adj = 3; + if (n == 0) + thresh = 0.0; + else if (beta == testinghelpers::ZERO()) + { + // Like SETV or COPYV(no ops) + if (alpha == testinghelpers::ZERO() || alpha == testinghelpers::ONE()) + thresh = 0.0; + // Like SCAL2V(1 mul) + else + thresh = (1 * adj) * testinghelpers::getEpsilon(); + } + else if (beta == testinghelpers::ONE()) + { + // Like ERS(no ops) + if (alpha == testinghelpers::ZERO()) + thresh = 0.0; + // Like ADDV(1 add) + else if (alpha == testinghelpers::ONE()) + thresh = testinghelpers::getEpsilon(); + // Like AXPYV(1 mul and 1 add) + else + thresh = (1 * adj + 1) * testinghelpers::getEpsilon(); + } + else + { + // Like SCALV(1 mul) + if (alpha == testinghelpers::ZERO()) + thresh = (1 * adj) * testinghelpers::getEpsilon(); + // Like AXPBYV(2 muls and 1 add) + else + thresh = (2 * adj + 1) * testinghelpers::getEpsilon(); + } //---------------------------------------------------------- // Call generic test body using those parameters @@ -72,60 +110,23 @@ TEST_P( caxpbyvGenericTest, RandomData ) test_axpbyv( conj_x, n, incx, incy, alpha, beta, thresh ); } -// Used to generate a test case with a sensible name. -// Beware that we cannot use fp numbers (e.g., 2.3) in the names, -// so we are only printing int(2.3). This should be enough for debugging purposes. -// If this poses an issue, please reach out. -class caxpbyvGenericTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char conj = std::get<0>(str.param); - gtint_t n = std::get<1>(str.param); - gtint_t incx = std::get<2>(str.param); - gtint_t incy = std::get<3>(str.param); - scomplex alpha = std::get<4>(str.param); - scomplex beta = std::get<5>(str.param); -#ifdef TEST_BLAS - std::string str_name = "caxpby_"; -#elif TEST_CBLAS - std::string str_name = "cblas_caxpby"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_caxpbyv"; -#endif - str_name += "_" + std::to_string(n); - str_name += "_" + std::string(&conj, 1); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name += "_" + incx_str; - std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); - str_name += "_" + incy_str; - std::string alpha_str = ( alpha.real > 0) ? std::to_string(int(alpha.real)) : ("m" + std::to_string(int(std::abs(alpha.real)))); - alpha_str = alpha_str + "pi" + (( alpha.imag > 0) ? std::to_string(int(alpha.imag)) : ("m" + std::to_string(int(std::abs(alpha.imag))))); - std::string beta_str = ( beta.real > 0) ? std::to_string(int(beta.real)) : ("m" + std::to_string(int(std::abs(beta.real)))); - beta_str = beta_str + "pi" + (( beta.imag > 0) ? std::to_string(int(beta.imag)) : ("m" + std::to_string(int(std::abs(beta.imag))))); - str_name = str_name + "_a" + alpha_str; - str_name = str_name + "_b" + beta_str; - return str_name; - } -}; - // Black box testing for generic and main use of caxpby. INSTANTIATE_TEST_SUITE_P( Blackbox, - caxpbyvGenericTest, + caxpbyvGeneric, ::testing::Combine( - ::testing::Values('n' // n: use x, c: use conj(x) + ::testing::Values('n' // n: use x, c: use conj(x) #ifdef TEST_BLIS_TYPED - , 'c' // this option is BLIS-api specific. + , 'c' // this option is BLIS-api specific. #endif ), - ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. - ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(scomplex{2.0, -1.0}, scomplex{-2.0, 3.0}), // alpha - ::testing::Values(scomplex{1.0, 2.0}) // beta + ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(scomplex{0.0, 0.0}, scomplex{1.0, 0.0}, scomplex{2.2, -3.3}), // alpha + ::testing::Values(scomplex{0.0, 0.0}, scomplex{1.0, 0.0}, scomplex{1.0, 2.0}) // beta ), - ::caxpbyvGenericTestPrint() + ::axpbyvGenericPrint() ); // Test for non-unit increments. @@ -133,20 +134,20 @@ INSTANTIATE_TEST_SUITE_P( // We can modify the values using implementantion details. INSTANTIATE_TEST_SUITE_P( NonUnitPositiveIncrements, - caxpbyvGenericTest, + caxpbyvGeneric, ::testing::Combine( ::testing::Values('n' #ifdef TEST_BLIS_TYPED - , 'c' // this option is BLIS-api specific. + , 'c' // this option is BLIS-api specific. #endif - ), // n: use x, c: use conj(x) - ::testing::Range(gtint_t(10), gtint_t(31), 10), // m size of vector takes values from 10 to 100 with step size of 10. - ::testing::Values(gtint_t(2)), // stride size for x - ::testing::Values(gtint_t(3)), // stride size for y - ::testing::Values(scomplex{4.0, 3.1}), // alpha - ::testing::Values(scomplex{1.0, -2.0}) // beta + ), // n: use x, c: use conj(x) + ::testing::Range(gtint_t(10), gtint_t(31), 10), // m size of vector takes values from 10 to 100 with step size of 10. + ::testing::Values(gtint_t(2)), // stride size for x + ::testing::Values(gtint_t(3)), // stride size for y + ::testing::Values(scomplex{0.0, 0.0}, scomplex{1.0, 0.0}, scomplex{2.2, -3.3}), // alpha + ::testing::Values(scomplex{0.0, 0.0}, scomplex{1.0, 0.0}, scomplex{1.0, 2.0}) // beta ), - ::caxpbyvGenericTestPrint() + ::axpbyvGenericPrint() ); #ifndef TEST_BLIS_TYPED @@ -155,15 +156,15 @@ INSTANTIATE_TEST_SUITE_P( // We can modify the values using implementantion details. INSTANTIATE_TEST_SUITE_P( NegativeIncrements, - caxpbyvGenericTest, + caxpbyvGeneric, ::testing::Combine( - ::testing::Values('n'), // n: use x - ::testing::Range(gtint_t(10), gtint_t(31), 10), // m size of vector takes values from 10 to 100 with step size of 10. - ::testing::Values(gtint_t(-11), gtint_t(5)), // stride size for x - ::testing::Values(gtint_t(-3), gtint_t(7)), // stride size for y - ::testing::Values(scomplex{4.0, 3.1}), // alpha - ::testing::Values(scomplex{1.0, -2.0}) // beta + ::testing::Values('n'), // n: use x + ::testing::Range(gtint_t(10), gtint_t(31), 10), // m size of vector takes values from 10 to 100 with step size of 10. + ::testing::Values(gtint_t(-11), gtint_t(5)), // stride size for x + ::testing::Values(gtint_t(-3), gtint_t(7)), // stride size for y + ::testing::Values(scomplex{0.0, 0.0}, scomplex{1.0, 0.0}, scomplex{2.2, -3.3}), // alpha + ::testing::Values(scomplex{0.0, 0.0}, scomplex{1.0, 0.0}, scomplex{1.0, 2.0}) // beta ), - ::caxpbyvGenericTestPrint() + ::axpbyvGenericPrint() ); #endif diff --git a/gtestsuite/testsuite/level1/axpbyv/daxpbyv_evt.cpp b/gtestsuite/testsuite/level1/axpbyv/daxpbyv_evt.cpp new file mode 100644 index 0000000000..4168f668b5 --- /dev/null +++ b/gtestsuite/testsuite/level1/axpbyv/daxpbyv_evt.cpp @@ -0,0 +1,287 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_axpbyv.h" + +class daxpbyvEVT : + public ::testing::TestWithParam> {}; // beta + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(daxpbyvEVT); + +// Tests using random values as vector elements, +// with exception values on the passed indices. +TEST_P( daxpbyvEVT, API ) +{ + using T = double; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes whether x or conj(x) will be added to y: + char conj_x = std::get<0>(GetParam()); + // vector length: + gtint_t n = std::get<1>(GetParam()); + // stride size for x: + gtint_t incx = std::get<2>(GetParam()); + // stride size for y: + gtint_t incy = std::get<3>(GetParam()); + // index for exval in x + gtint_t xi = std::get<4>(GetParam()); + // exval for x + T xexval = std::get<5>(GetParam()); + // index for exval in y + gtint_t yj = std::get<6>(GetParam()); + // exval for x + T yexval = std::get<7>(GetParam()); + // alpha + T alpha = std::get<8>(GetParam()); + // beta + T beta = std::get<9>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite axpbyv.h (no netlib version) for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (n == 0) + thresh = 0.0; + else if (beta == testinghelpers::ZERO()) + { + // Like SETV or COPYV(no ops) + if (alpha == testinghelpers::ZERO() || alpha == testinghelpers::ONE()) + thresh = 0.0; + // Like SCAL2V(1 mul) + else + thresh = testinghelpers::getEpsilon(); + } + else if (beta == testinghelpers::ONE()) + { + // Like ERS(no ops) + if (alpha == testinghelpers::ZERO()) + thresh = 0.0; + // Like ADDV(1 add) + else if (alpha == testinghelpers::ONE()) + thresh = testinghelpers::getEpsilon(); + // Like AXPYV(1 mul and 1 add) + else + thresh = 2 * testinghelpers::getEpsilon(); + } + else + { + // Like SCALV(1 mul) + if (alpha == testinghelpers::ZERO()) + thresh = testinghelpers::getEpsilon(); + // Like AXPBYV(2 muls and 1 add) + else + thresh = 3 * testinghelpers::getEpsilon(); + } + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_axpbyv(conj_x, n, incx, incy, alpha, beta, xi, xexval, + yj, yexval, thresh); +} + +#if defined(REF_IS_NETLIB) +static double NaN = std::numeric_limits::quiet_NaN(); +static double Inf = std::numeric_limits::infinity(); + +/* + Exception value testing on vectors : + DAXPBY currently uses the bli_daxpbyv_zen_int10( ... ) kernel for computation. + The size and indices given in the instantiator are to ensure code coverage inside + the kernel, and to verify the compliance accordingly. + + Kernel structure : + Main loop : In blocks of 40 --> L40 + Fringe loops : In blocks of 20 --> L20 + In blocks of 16 --> L16 + In blocks of 8 --> L8 + In blocks of 4 --> L4 + Element-wise loop --> LScalar + + For size 115 : L40*2 + L20 + L8 + L4 + 3(LScalar) + Indices are : 0, 79 -> In L40 + 99 -> In L20 + 107 -> In L8 + 111 -> In L4 + 114 -> In LScalar + + For size 116 : L40*2 + L20 + L16 + Indices are : 0, 79 -> In L40 + 99 -> In L20 + 107 -> In L16 + + The alpha and beta values are such that they check for compliance against possible + optimizations that might have been done. + + P.S : Some test cases also check whether NaN has to be induced in the computation + as a result of 0.0 * { NaN, +Inf, -Inf }. +*/ +// Exception value testing(on X vector alone) with unit strides +INSTANTIATE_TEST_SUITE_P( + vecX_unitStrides, + daxpbyvEVT, + ::testing::Combine( + ::testing::Values('n'), // use conjx as n for real types + ::testing::Values(gtint_t(115), gtint_t(116)), // n, size of vectors with unit-stride + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(gtint_t(0), gtint_t(79), gtint_t(99), + gtint_t(107), gtint_t(111), gtint_t(114)), // indices to set exception values on x + ::testing::Values(NaN, -Inf, Inf), // exception values to set on x + ::testing::Values(gtint_t(0)), // dummy index on y + ::testing::Values(double(0.0)), // dummy value on y + ::testing::Values(double(0.0), double(1.0), double(-1.0), double(-3.3)), // alpha + ::testing::Values(double(0.0), double(1.0), double(-1.0), double(4.5)) // beta + ), + ::axpbyvEVTPrint()); + +// Exception value testing(on Y vector alone) with unit strides +INSTANTIATE_TEST_SUITE_P( + vecY_unitStrides, + daxpbyvEVT, + ::testing::Combine( + ::testing::Values('n'), // use conjx as n for real types + ::testing::Values(gtint_t(115), gtint_t(116)), // n, size of vectors with unit-stride + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(gtint_t(0)), // dummy index on x + ::testing::Values(double(0.0)), // dummy value on x + ::testing::Values(gtint_t(0), gtint_t(79), gtint_t(99), + gtint_t(107), gtint_t(111), gtint_t(114)), // indices to set exception values on y + ::testing::Values(NaN, -Inf, Inf), // exception values to set on y + ::testing::Values(double(0.0), double(1.0), double(-1.0), double(-3.3)), // alpha + ::testing::Values(double(0.0), double(1.0), double(-1.0), double(4.5)) // beta + ), + ::axpbyvEVTPrint()); + +// Exception value testing(on X and Y vectors) with unit strides +INSTANTIATE_TEST_SUITE_P( + vecXY_unitStrides, + daxpbyvEVT, + ::testing::Combine( + ::testing::Values('n'), // use conjx as n for real types + ::testing::Values(gtint_t(115), gtint_t(116)), // n, size of vectors with unit-stride + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(gtint_t(0), gtint_t(79), gtint_t(99), + gtint_t(107), gtint_t(111), gtint_t(114)), // indices to set exception values on x + ::testing::Values(NaN, -Inf, Inf), // exception values to set on x + ::testing::Values(gtint_t(0), gtint_t(79), gtint_t(99), + gtint_t(107), gtint_t(111), gtint_t(114)), // indices to set exception values on y + ::testing::Values(NaN, -Inf, Inf), // exception values to set on y + ::testing::Values(double(0.0), double(1.0), double(-1.0), double(-3.3)), // alpha + ::testing::Values(double(0.0), double(1.0), double(-1.0), double(4.5)) // beta + ), + ::axpbyvEVTPrint()); + +// Exception value testing(on vectors) with non-unit strides +// We have to test a single scalar loop. The indices are such +// that we cover _vecX_, _vecY_ and _vecXY_ cases together. +INSTANTIATE_TEST_SUITE_P( + vec_nonUnitStrides, + daxpbyvEVT, + ::testing::Combine( + ::testing::Values('n'), // use conjx as n for real types + ::testing::Values(gtint_t(50)), // n, size of vectors with non-unit strides + ::testing::Values(gtint_t(3)), // stride size for x + ::testing::Values(gtint_t(5)), // stride size for y + ::testing::Values(gtint_t(1), gtint_t(27), gtint_t(49)), // indices to set exception values on x + ::testing::Values(NaN, -Inf, Inf, 2.9), // exception values to set on x + ::testing::Values(gtint_t(0), gtint_t(26), gtint_t(49)), // indices to set exception values on y + ::testing::Values(NaN, -Inf, Inf, -1.5), // exception values to set on y + ::testing::Values(double(0.0), double(1.0), double(-1.0), double(-3.3)), // alpha + ::testing::Values(double(0.0), double(1.0), double(-1.0), double(4.5)) // beta + ), + ::axpbyvEVTPrint()); + +/* + Exception value testing on alpha and/or beta : + Alpha and/or beta values are set to Nan, +Inf or -Inf. + Also, a normal value is given to alpha and beta to check + for combinations where only X or Y involve scaling by an + exception valued scalar. A dummy value of 0.0 is induced + in X and Y vectors, to further verify the propagation. + + The size for the instantiators is chosen such that + code coverage is ensured in the respective kernel. +*/ +// Exception value testing(on alpha/beta) with unit strided vectors +INSTANTIATE_TEST_SUITE_P( + alphaBeta_unitStrides, + daxpbyvEVT, + ::testing::Combine( + ::testing::Values('n'), // use conjx as n for real types + ::testing::Values(gtint_t(115), gtint_t(116)), // n, size of vectors with unit-stride + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(gtint_t(0)), // indices to set zero on x + ::testing::Values(double(0.0)), + ::testing::Values(gtint_t(0)), // indices to set zero on y + ::testing::Values(double(0.0)), + ::testing::Values(NaN, -Inf, Inf, 2.3), // alpha + ::testing::Values(NaN, -Inf, Inf, -1.9) // beta + ), + ::axpbyvEVTPrint()); + +// Exception value testing(on alpha/beta) with non-unit strided vectors +INSTANTIATE_TEST_SUITE_P( + alphaBeta_nonUnitStrides, + daxpbyvEVT, + ::testing::Combine( + ::testing::Values('n'), // use conjx as n for real types + ::testing::Values(gtint_t(50)), // n, size of vector with non-unit strides + ::testing::Values(gtint_t(3)), // stride size for x + ::testing::Values(gtint_t(5)), // stride size for y + ::testing::Values(gtint_t(1), gtint_t(25)), // indices to set zero on x + ::testing::Values(double(0.0)), + ::testing::Values(gtint_t(0), gtint_t(40)), // indices to set zero on y + ::testing::Values(double(0.0)), + ::testing::Values(NaN, -Inf, Inf, 2.3), // alpha + ::testing::Values(NaN, -Inf, Inf, -1.9) // beta + ), + ::axpbyvEVTPrint()); +#endif diff --git a/gtestsuite/testsuite/level1/axpbyv/daxpbyv_generic.cpp b/gtestsuite/testsuite/level1/axpbyv/daxpbyv_generic.cpp index 96d94cf887..b25af1f6ce 100644 --- a/gtestsuite/testsuite/level1/axpbyv/daxpbyv_generic.cpp +++ b/gtestsuite/testsuite/level1/axpbyv/daxpbyv_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,7 +35,7 @@ #include #include "test_axpbyv.h" -class daxpbyvGenericTest : +class daxpbyvGeneric : public ::testing::TestWithParam> {}; // Tests using random integers as vector elements. -TEST_P( daxpbyvGenericTest, RandomData ) +TEST_P( daxpbyvGeneric, API ) { using T = double; //---------------------------------------------------------- @@ -64,7 +64,42 @@ TEST_P( daxpbyvGenericTest, RandomData ) T beta = std::get<5>(GetParam()); // Set the threshold for the errors: - double thresh = testinghelpers::getEpsilon(); + // Check gtestsuite axpbyv.h (no netlib version) for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (n == 0) + thresh = 0.0; + else if (beta == testinghelpers::ZERO()) + { + // Like SETV or COPYV(no ops) + if (alpha == testinghelpers::ZERO() || alpha == testinghelpers::ONE()) + thresh = 0.0; + // Like SCAL2V(1 mul) + else + thresh = testinghelpers::getEpsilon(); + } + else if (beta == testinghelpers::ONE()) + { + // Like ERS(no ops) + if (alpha == testinghelpers::ZERO()) + thresh = 0.0; + // Like ADDV(1 add) + else if (alpha == testinghelpers::ONE()) + thresh = testinghelpers::getEpsilon(); + // Like AXPYV(1 mul and 1 add) + else + thresh = 2 * testinghelpers::getEpsilon(); + } + else + { + // Like SCALV(1 mul) + if (alpha == testinghelpers::ZERO()) + thresh = testinghelpers::getEpsilon(); + // Like AXPBYV(2 muls and 1 add) + else + thresh = 3 * testinghelpers::getEpsilon(); + } //---------------------------------------------------------- // Call generic test body using those parameters @@ -72,54 +107,21 @@ TEST_P( daxpbyvGenericTest, RandomData ) test_axpbyv( conj_x, n, incx, incy, alpha, beta, thresh ); } -// Used to generate a test case with a sensible name. -// Beware that we cannot use fp numbers (e.g., 2.3) in the names, -// so we are only printing int(2.3). This should be enough for debugging purposes. -// If this poses an issue, please reach out. -class daxpbyvGenericTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char conj = std::get<0>(str.param); - gtint_t n = std::get<1>(str.param); - gtint_t incx = std::get<2>(str.param); - gtint_t incy = std::get<3>(str.param); - double alpha = std::get<4>(str.param); - double beta = std::get<5>(str.param); -#ifdef TEST_BLAS - std::string str_name = "daxpby_"; -#elif TEST_CBLAS - std::string str_name = "cblas_daxpby"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_daxpbyv"; -#endif - str_name += "_" + std::to_string(n); - str_name += "_" + std::string(&conj, 1); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name += "_" + incx_str; - std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); - str_name += "_" + incy_str; - std::string alpha_str = ( alpha > 0) ? std::to_string(int(alpha)) : "m" + std::to_string(int(std::abs(alpha))); - str_name = str_name + "_a" + alpha_str; - std::string beta_str = ( beta > 0) ? std::to_string(int(beta)) : "m" + std::to_string(int(std::abs(beta))); - str_name = str_name + "_b" + beta_str; - return str_name; - } -}; - -// Black box testing for generic and main use of caxpy. +// Black box testing for generic and main use of daxpby. INSTANTIATE_TEST_SUITE_P( Blackbox, - daxpbyvGenericTest, + daxpbyvGeneric, ::testing::Combine( ::testing::Values('n'), // n: use x, not conj(x) (since it is real) ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. ::testing::Values(gtint_t(1)), // stride size for x ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(double(2.0), double(-2.0)), // alpha - ::testing::Values(double(-1.0)) // beta + ::testing::Values(double(2.3), double(1.0), + double(-1.0), double(0.0)), // alpha + ::testing::Values(double(-4.9), double(1.0), + double(-1.0), double(0.0)) // beta ), - ::daxpbyvGenericTestPrint() + ::axpbyvGenericPrint() ); #ifdef TEST_BLIS_TYPED @@ -128,16 +130,18 @@ INSTANTIATE_TEST_SUITE_P( // We can modify the values using implementantion details. INSTANTIATE_TEST_SUITE_P( ConjX, - daxpbyvGenericTest, + daxpbyvGeneric, ::testing::Combine( ::testing::Values('c'), // c: use conj(x) ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector ::testing::Values(gtint_t(1)), // stride size for x ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(double(2.0)), // alpha - ::testing::Values(double(1.0)) // beta + ::testing::Values(double(2.3), double(1.0), + double(-1.0), double(0.0)), // alpha + ::testing::Values(double(-4.9), double(1.0), + double(-1.0), double(0.0)) // beta ), - ::daxpbyvGenericTestPrint() + ::axpbyvGenericPrint() ); #endif @@ -145,8 +149,8 @@ INSTANTIATE_TEST_SUITE_P( // Only test very few cases as sanity check. // We can modify the values using implementantion details. INSTANTIATE_TEST_SUITE_P( - NonUnitPositiveIncrements, - daxpbyvGenericTest, + nonUnitPositiveIncrements, + daxpbyvGeneric, ::testing::Combine( ::testing::Values('n' #ifdef TEST_BLIS_TYPED @@ -156,10 +160,12 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Range(gtint_t(10), gtint_t(31), 10), // m size of vector takes values from 10 to 100 with step size of 10. ::testing::Values(gtint_t(7)), // stride size for x ::testing::Values(gtint_t(3)), // stride size for y - ::testing::Values(4.0), // alpha - ::testing::Values(-2.0) // beta + ::testing::Values(double(2.3), double(1.0), + double(-1.0), double(0.0)), // alpha + ::testing::Values(double(-4.9), double(1.0), + double(-1.0), double(0.0)) // beta ), - ::daxpbyvGenericTestPrint() + ::axpbyvGenericPrint() ); #ifndef TEST_BLIS_TYPED @@ -167,16 +173,18 @@ INSTANTIATE_TEST_SUITE_P( // Only test very few cases as sanity check. // We can modify the values using implementantion details. INSTANTIATE_TEST_SUITE_P( - NegativeIncrements, - daxpbyvGenericTest, + negativeIncrements, + daxpbyvGeneric, ::testing::Combine( ::testing::Values('n'), // n: use x, c: use conj(x) ::testing::Range(gtint_t(10), gtint_t(31), 10), // m size of vector takes values from 10 to 100 with step size of 10. ::testing::Values(gtint_t(11), gtint_t(-11)), // stride size for x ::testing::Values(gtint_t(-3), gtint_t(4)), // stride size for y - ::testing::Values(4.0), // alpha - ::testing::Values(-2.0) // beta + ::testing::Values(double(2.3), double(1.0), + double(-1.0), double(0.0)), // alpha + ::testing::Values(double(-4.9), double(1.0), + double(-1.0), double(0.0)) // beta ), - ::daxpbyvGenericTestPrint() + ::axpbyvGenericPrint() ); #endif diff --git a/gtestsuite/testsuite/level1/axpbyv/saxpbyv_evt.cpp b/gtestsuite/testsuite/level1/axpbyv/saxpbyv_evt.cpp new file mode 100644 index 0000000000..708f7f5cc4 --- /dev/null +++ b/gtestsuite/testsuite/level1/axpbyv/saxpbyv_evt.cpp @@ -0,0 +1,287 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_axpbyv.h" + +class saxpbyvEVT : + public ::testing::TestWithParam> {}; // beta + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(saxpbyvEVT); + +// Tests using random values as vector elements, +// with exception values on the passed indices. +TEST_P( saxpbyvEVT, API ) +{ + using T = float; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes whether x or conj(x) will be added to y: + char conj_x = std::get<0>(GetParam()); + // vector length: + gtint_t n = std::get<1>(GetParam()); + // stride size for x: + gtint_t incx = std::get<2>(GetParam()); + // stride size for y: + gtint_t incy = std::get<3>(GetParam()); + // index for exval in x + gtint_t xi = std::get<4>(GetParam()); + // exval for x + T xexval = std::get<5>(GetParam()); + // index for exval in y + gtint_t yj = std::get<6>(GetParam()); + // exval for x + T yexval = std::get<7>(GetParam()); + // alpha + T alpha = std::get<8>(GetParam()); + // beta + T beta = std::get<9>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite axpbyv.h (no netlib version) for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (n == 0) + thresh = 0.0; + else if (beta == testinghelpers::ZERO()) + { + // Like SETV or COPYV(no ops) + if (alpha == testinghelpers::ZERO() || alpha == testinghelpers::ONE()) + thresh = 0.0; + // Like SCAL2V(1 mul) + else + thresh = testinghelpers::getEpsilon(); + } + else if (beta == testinghelpers::ONE()) + { + // Like ERS(no ops) + if (alpha == testinghelpers::ZERO()) + thresh = 0.0; + // Like ADDV(1 add) + else if (alpha == testinghelpers::ONE()) + thresh = testinghelpers::getEpsilon(); + // Like AXPYV(1 mul and 1 add) + else + thresh = 2 * testinghelpers::getEpsilon(); + } + else + { + // Like SCALV(1 mul) + if (alpha == testinghelpers::ZERO()) + thresh = testinghelpers::getEpsilon(); + // Like AXPBYV(2 muls and 1 add) + else + thresh = 3 * testinghelpers::getEpsilon(); + } + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_axpbyv(conj_x, n, incx, incy, alpha, beta, xi, xexval, + yj, yexval, thresh); +} + +#if defined(REF_IS_NETLIB) +static float NaN = std::numeric_limits::quiet_NaN(); +static float Inf = std::numeric_limits::infinity(); + +/* + Exception value testing on vectors : + DAXPBY currently uses the bli_saxpbyv_zen_int10( ... ) kernel for computation. + The size and indices given in the instantiator are to ensure code coverage inside + the kernel, and to verify the compliance accordingly. + + Kernel structure : + Main loop : In blocks of 80 --> L00 + Fringe loops : In blocks of 20 --> L40 + In blocks of 32 --> L32 + In blocks of 16 --> L16 + In blocks of 8 --> L8 + Element-wise loop --> LScalar + + For size 231 : L80*2 + L40 + L16 + L8 + 7(LScalar) + Indices are : 0, 159 -> In L80 + 199 -> In L40 + 215 -> In L16 + 223 -> In L8 + 230 -> In LScalar + + For size 232 : L80*2 + L40 + L32 + Indices are : 0, 159 -> In L80 + 199 -> In L40 + 215 -> In L32r + + The alpha and beta values are such that they check for compliance against possible + optimizations that might have been done. + + P.S : Some test cases also check whether NaN has to be induced in the computation + as a result of 0.0 * { NaN, +Inf, -Inf }. +*/ +// Exception value testing(on X vector alone) with unit strides +INSTANTIATE_TEST_SUITE_P( + vecX_unitStrides, + saxpbyvEVT, + ::testing::Combine( + ::testing::Values('n'), // use conjx as n for real types + ::testing::Values(gtint_t(231), gtint_t(232)), // n, size of vectors with unit-stride + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(gtint_t(0), gtint_t(159), gtint_t(199), + gtint_t(215), gtint_t(223), gtint_t(230)), // indices to set exception values on x + ::testing::Values(NaN, -Inf, Inf), // exception values to set on x + ::testing::Values(gtint_t(0)), // dummy index on y + ::testing::Values(float(0.0)), // dummy value on y + ::testing::Values(float(0.0), float(1.0), float(-1.0), float(-3.3)), // alpha + ::testing::Values(float(0.0), float(1.0), float(-1.0), float(4.5)) // beta + ), + ::axpbyvEVTPrint()); + +// Exception value testing(on Y vector alone) with unit strides +INSTANTIATE_TEST_SUITE_P( + vecY_unitStrides, + saxpbyvEVT, + ::testing::Combine( + ::testing::Values('n'), // use conjx as n for real types + ::testing::Values(gtint_t(231), gtint_t(232)), // n, size of vectors with unit-stride + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(gtint_t(0)), // dummy index on x + ::testing::Values(float(0.0)), // dummy value on x + ::testing::Values(gtint_t(0), gtint_t(159), gtint_t(199), + gtint_t(215), gtint_t(223), gtint_t(230)), // indices to set exception values on y + ::testing::Values(NaN, -Inf, Inf), // exception values to set on y + ::testing::Values(float(0.0), float(1.0), float(-1.0), float(-3.3)), // alpha + ::testing::Values(float(0.0), float(1.0), float(-1.0), float(4.5)) // beta + ), + ::axpbyvEVTPrint()); + +// Exception value testing(on X and Y vectors) with unit strides +INSTANTIATE_TEST_SUITE_P( + vecXY_unitStrides, + saxpbyvEVT, + ::testing::Combine( + ::testing::Values('n'), // use conjx as n for real types + ::testing::Values(gtint_t(231), gtint_t(232)), // n, size of vectors with unit-stride + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(gtint_t(0), gtint_t(159), gtint_t(199), + gtint_t(215), gtint_t(223), gtint_t(230)), // indices to set exception values on x + ::testing::Values(NaN, -Inf, Inf), // exception values to set on x + ::testing::Values(gtint_t(0), gtint_t(159), gtint_t(199), + gtint_t(215), gtint_t(223), gtint_t(230)), // indices to set exception values on y + ::testing::Values(NaN, -Inf, Inf), // exception values to set on y + ::testing::Values(float(0.0), float(1.0), float(-1.0), float(-3.3)), // alpha + ::testing::Values(float(0.0), float(1.0), float(-1.0), float(4.5)) // beta + ), + ::axpbyvEVTPrint()); + +// Exception value testing(on vectors) with non-unit strides +// We have to test a single scalar loop. The indices are such +// that we cover _vecX_, _vecY_ and _vecXY_ cases together. +INSTANTIATE_TEST_SUITE_P( + vec_nonUnitStrides, + saxpbyvEVT, + ::testing::Combine( + ::testing::Values('n'), // use conjx as n for real types + ::testing::Values(gtint_t(50)), // n, size of vectors with non-unit strides + ::testing::Values(gtint_t(3)), // stride size for x + ::testing::Values(gtint_t(5)), // stride size for y + ::testing::Values(gtint_t(1), gtint_t(27), gtint_t(49)), // indices to set exception values on x + ::testing::Values(NaN, -Inf, Inf, 2.9), // exception values to set on x + ::testing::Values(gtint_t(0), gtint_t(26), gtint_t(49)), // indices to set exception values on y + ::testing::Values(NaN, -Inf, Inf, -1.5), // exception values to set on y + ::testing::Values(float(0.0), float(1.0), float(-1.0), float(-3.3)), // alpha + ::testing::Values(float(0.0), float(1.0), float(-1.0), float(4.5)) // beta + ), + ::axpbyvEVTPrint()); + +/* + Exception value testing on alpha and/or beta : + Alpha and/or beta values are set to Nan, +Inf or -Inf. + Also, a normal value is given to alpha and beta to check + for combinations where only X or Y involve scaling by an + exception valued scalar. A dummy value of 0.0 is induced + in X and Y vectors, to further verify the propagation. + + The size for the instantiators is chosen such that + code coverage is ensured in the respective kernel. +*/ +// Exception value testing(on alpha/beta) with unit strided vectors +INSTANTIATE_TEST_SUITE_P( + alphaBeta_unitStrides, + saxpbyvEVT, + ::testing::Combine( + ::testing::Values('n'), // use conjx as n for real types + ::testing::Values(gtint_t(231), gtint_t(232)), // n, size of vectors with unit-stride + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(gtint_t(0)), // indices to set zero on x + ::testing::Values(float(0.0)), + ::testing::Values(gtint_t(0)), // indices to set zero on y + ::testing::Values(float(0.0)), + ::testing::Values(NaN, -Inf, Inf, 2.3), // alpha + ::testing::Values(NaN, -Inf, Inf, -1.9) // beta + ), + ::axpbyvEVTPrint()); + +// Exception value testing(on alpha/beta) with non-unit strided vectors +INSTANTIATE_TEST_SUITE_P( + alphaBeta_nonUnitStrides, + saxpbyvEVT, + ::testing::Combine( + ::testing::Values('n'), // use conjx as n for real types + ::testing::Values(gtint_t(50)), // n, size of vector with non-unit strides + ::testing::Values(gtint_t(3)), // stride size for x + ::testing::Values(gtint_t(5)), // stride size for y + ::testing::Values(gtint_t(1), gtint_t(25)), // indices to set zero on x + ::testing::Values(float(0.0)), + ::testing::Values(gtint_t(0), gtint_t(40)), // indices to set zero on y + ::testing::Values(float(0.0)), + ::testing::Values(NaN, -Inf, Inf, 2.3), // alpha + ::testing::Values(NaN, -Inf, Inf, -1.9) // beta + ), + ::axpbyvEVTPrint()); +#endif diff --git a/gtestsuite/testsuite/level1/axpbyv/saxpbyv_generic.cpp b/gtestsuite/testsuite/level1/axpbyv/saxpbyv_generic.cpp index a9aeb9f5a8..e44e708185 100644 --- a/gtestsuite/testsuite/level1/axpbyv/saxpbyv_generic.cpp +++ b/gtestsuite/testsuite/level1/axpbyv/saxpbyv_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,7 +35,7 @@ #include #include "test_axpbyv.h" -class saxpbyvGenericTest : +class saxpbyvGeneric : public ::testing::TestWithParam> {}; // Tests using random integers as vector elements. -TEST_P( saxpbyvGenericTest, RandomData ) +TEST_P( saxpbyvGeneric, API ) { using T = float; //---------------------------------------------------------- @@ -64,7 +64,42 @@ TEST_P( saxpbyvGenericTest, RandomData ) T beta = std::get<5>(GetParam()); // Set the threshold for the errors: - float thresh = testinghelpers::getEpsilon(); + // Check gtestsuite axpbyv.h (no netlib version) for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (n == 0) + thresh = 0.0; + else if (beta == testinghelpers::ZERO()) + { + // Like SETV or COPYV(no ops) + if (alpha == testinghelpers::ZERO() || alpha == testinghelpers::ONE()) + thresh = 0.0; + // Like SCAL2V(1 mul) + else + thresh = testinghelpers::getEpsilon(); + } + else if (beta == testinghelpers::ONE()) + { + // Like ERS(no ops) + if (alpha == testinghelpers::ZERO()) + thresh = 0.0; + // Like ADDV(1 add) + else if (alpha == testinghelpers::ONE()) + thresh = testinghelpers::getEpsilon(); + // Like AXPYV(1 mul and 1 add) + else + thresh = 2 * testinghelpers::getEpsilon(); + } + else + { + // Like SCALV(1 mul) + if (alpha == testinghelpers::ZERO()) + thresh = testinghelpers::getEpsilon(); + // Like AXPBYV(2 muls and 1 add) + else + thresh = 3 * testinghelpers::getEpsilon(); + } //---------------------------------------------------------- // Call generic test body using those parameters @@ -72,54 +107,21 @@ TEST_P( saxpbyvGenericTest, RandomData ) test_axpbyv( conj_x, n, incx, incy, alpha, beta, thresh ); } -// Used to generate a test case with a sensible name. -// Beware that we cannot use fp numbers (e.g., 2.3) in the names, -// so we are only printing int(2.3). This should be enough for debugging purposes. -// If this poses an issue, please reach out. -class saxpbyvGenericTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char conj = std::get<0>(str.param); - gtint_t n = std::get<1>(str.param); - gtint_t incx = std::get<2>(str.param); - gtint_t incy = std::get<3>(str.param); - float alpha = std::get<4>(str.param); - float beta = std::get<5>(str.param); -#ifdef TEST_BLAS - std::string str_name = "saxpby_"; -#elif TEST_CBLAS - std::string str_name = "cblas_saxpby"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_saxpbyv"; -#endif - str_name += "_" + std::to_string(n); - str_name += "_" + std::string(&conj, 1); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name += "_" + incx_str; - std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); - str_name += "_" + incy_str; - std::string alpha_str = ( alpha > 0) ? std::to_string(int(alpha)) : "m" + std::to_string(int(std::abs(alpha))); - str_name = str_name + "_a" + alpha_str; - std::string beta_str = ( beta > 0) ? std::to_string(int(beta)) : "m" + std::to_string(int(std::abs(beta))); - str_name = str_name + "_b" + beta_str; - return str_name; - } -}; - // Black box testing for generic and main use of caxpy. INSTANTIATE_TEST_SUITE_P( Blackbox, - saxpbyvGenericTest, + saxpbyvGeneric, ::testing::Combine( ::testing::Values('n'), // n: use x, not conj(x) (since it is real) ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. ::testing::Values(gtint_t(1)), // stride size for x ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(float(2.0), float(-2.0)), // alpha - ::testing::Values(float(-1.0)) // beta + ::testing::Values(float(2.3), float(1.0), + float(-1.0), float(0.0)), // alpha + ::testing::Values(float(-4.9), float(1.0), + float(-1.0), float(0.0)) // beta ), - ::saxpbyvGenericTestPrint() + ::axpbyvGenericPrint() ); #ifdef TEST_BLIS_TYPED @@ -128,16 +130,18 @@ INSTANTIATE_TEST_SUITE_P( // We can modify the values using implementantion details. INSTANTIATE_TEST_SUITE_P( ConjX, - saxpbyvGenericTest, + saxpbyvGeneric, ::testing::Combine( ::testing::Values('c'), // c: use conj(x) ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector ::testing::Values(gtint_t(1)), // stride size for x ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(float(2.0)), // alpha - ::testing::Values(float(1.0)) // beta + ::testing::Values(float(2.3), float(1.0), + float(-1.0), float(0.0)), // alpha + ::testing::Values(float(-4.9), float(1.0), + float(-1.0), float(0.0)) // beta ), - ::saxpbyvGenericTestPrint() + ::axpbyvGenericPrint() ); #endif @@ -146,16 +150,18 @@ INSTANTIATE_TEST_SUITE_P( // We can modify the values using implementantion details. INSTANTIATE_TEST_SUITE_P( NonUnitPositiveIncrements, - saxpbyvGenericTest, + saxpbyvGeneric, ::testing::Combine( ::testing::Values('n'), // use x, not conj(x) (since it is real) ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector ::testing::Values(gtint_t(11)), /*(gtint_t(-5), gtint_t(-17))*/// stride size for x ::testing::Values(gtint_t(3)), /*(gtint_t(-12), gtint_t(-4))*/// stride size for y - ::testing::Values(float(4.0)), // alpha - ::testing::Values(float(2.0)) // beta + ::testing::Values(float(2.3), float(1.0), + float(-1.0), float(0.0)), // alpha + ::testing::Values(float(-4.9), float(1.0), + float(-1.0), float(0.0)) // beta ), - ::saxpbyvGenericTestPrint() + ::axpbyvGenericPrint() ); #ifndef TEST_BLIS_TYPED @@ -164,15 +170,17 @@ INSTANTIATE_TEST_SUITE_P( // We can modify the values using implementantion details. INSTANTIATE_TEST_SUITE_P( NegativeIncrements, - saxpbyvGenericTest, + saxpbyvGeneric, ::testing::Combine( ::testing::Values('n'), // n: use x, c: use conj(x) ::testing::Range(gtint_t(10), gtint_t(31), 10), // m size of vector takes values from 10 to 100 with step size of 10. ::testing::Values(gtint_t(11), gtint_t(-11)), // stride size for x ::testing::Values(gtint_t(-3), gtint_t(4)), // stride size for y - ::testing::Values(4.0), // alpha - ::testing::Values(-2.0) // beta + ::testing::Values(float(2.3), float(1.0), + float(-1.0), float(0.0)), // alpha + ::testing::Values(float(-4.9), float(1.0), + float(-1.0), float(0.0)) // beta ), - ::saxpbyvGenericTestPrint() + ::axpbyvGenericPrint() ); #endif diff --git a/gtestsuite/testsuite/level1/axpbyv/test_axpbyv.h b/gtestsuite/testsuite/level1/axpbyv/test_axpbyv.h index 7c6bf72eb0..7480dda9df 100644 --- a/gtestsuite/testsuite/level1/axpbyv/test_axpbyv.h +++ b/gtestsuite/testsuite/level1/axpbyv/test_axpbyv.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -67,7 +67,7 @@ static void test_axpbyv( char conjx, gtint_t n, gtint_t incx, gtint_t incy, //---------------------------------------------------------- // Compute component-wise error. //---------------------------------------------------------- - computediff( n, y.data(), y_ref.data(), incy, thresh ); + computediff( "y", n, y.data(), y_ref.data(), incy, thresh ); } template @@ -81,8 +81,13 @@ static void test_axpbyv( char conjx, gtint_t n, gtint_t incx, gtint_t incy, std::vector x = testinghelpers::get_random_vector( -10, 10, n, incx ); std::vector y = testinghelpers::get_random_vector( -10, 10, n, incy ); - x[xi*incx] = xexval; - y[yj*incy] = yexval; + // Update the value at index xi to an extreme value, x_exval. + if ( -1 < xi && xi < n ) x[xi * abs(incx)] = xexval; + else return; + + // Update the value at index yi to an extreme value, y_exval. + if ( -1 < yj && yj < n ) y[yj * abs(incy)] = yexval; + else return; //---------------------------------------------------------- // Call reference implementation to get ref results. @@ -99,5 +104,64 @@ static void test_axpbyv( char conjx, gtint_t n, gtint_t incx, gtint_t incy, //---------------------------------------------------------- // Compute component-wise error. //---------------------------------------------------------- - computediff( n, y.data(), y_ref.data(), incy, thresh, true ); + computediff( "y", n, y.data(), y_ref.data(), incy, thresh, true ); } + +// Test-case logger : Used to print the test-case details based on parameters +template +class axpbyvGenericPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char conjx = std::get<0>(str.param); + gtint_t n = std::get<1>(str.param); + gtint_t incx = std::get<2>(str.param); + gtint_t incy = std::get<3>(str.param); + T alpha = std::get<4>(str.param); + T beta = std::get<5>(str.param); + + std::string str_name = API_PRINT; + str_name += "_n_" + std::to_string(n); + str_name += "_conjx_" + std::string(&conjx, 1); + str_name += "_incx_" + testinghelpers::get_value_string(incx); + str_name += "_incy_" + testinghelpers::get_value_string(incy); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + str_name += "_beta_" + testinghelpers::get_value_string(beta); + return str_name; + } +}; + +template +class axpbyvEVTPrint +{ +public: + std::string operator()( + testing::TestParamInfo> str) const + { + char conjx = std::get<0>(str.param); + gtint_t n = std::get<1>(str.param); + gtint_t incx = std::get<2>(str.param); + gtint_t incy = std::get<3>(str.param); + gtint_t xi = std::get<4>(str.param); + T xexval = std::get<5>(str.param); + gtint_t yj = std::get<6>(str.param); + T yexval = std::get<7>(str.param); + T alpha = std::get<8>(str.param); + T beta = std::get<9>(str.param); + + std::string str_name = API_PRINT; + str_name += "_n_" + std::to_string(n); + str_name += "_conjx_" + std::string(&conjx, 1); + str_name += "_incx_" + testinghelpers::get_value_string(incx); + str_name += "_incy_" + testinghelpers::get_value_string(incy); + std::string xexval_str = testinghelpers::get_value_string(xexval); + std::string yexval_str = testinghelpers::get_value_string(yexval); + str_name = str_name + "_X_" + std::to_string(xi); + str_name = str_name + "_" + xexval_str; + str_name = str_name + "_Y_" + std::to_string(yj); + str_name = str_name + "_" + yexval_str; + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + str_name += "_beta_" + testinghelpers::get_value_string(beta); + return str_name; + } +}; diff --git a/gtestsuite/testsuite/level1/axpbyv/zaxpbyv_evt.cpp b/gtestsuite/testsuite/level1/axpbyv/zaxpbyv_evt.cpp new file mode 100644 index 0000000000..63ff1f13f5 --- /dev/null +++ b/gtestsuite/testsuite/level1/axpbyv/zaxpbyv_evt.cpp @@ -0,0 +1,360 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_axpbyv.h" + +class zaxpbyvEVT : + public ::testing::TestWithParam> {}; // beta + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(zaxpbyvEVT); + +// Tests using random integers as vector elements. +TEST_P( zaxpbyvEVT, API ) +{ + using T = dcomplex; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes whether x or conj(x) will be added to y: + char conj_x = std::get<0>(GetParam()); + // vector length: + gtint_t n = std::get<1>(GetParam()); + // stride size for x: + gtint_t incx = std::get<2>(GetParam()); + // stride size for y: + gtint_t incy = std::get<3>(GetParam()); + // index for exval in x + gtint_t xi = std::get<4>(GetParam()); + // exval for x + T xexval = std::get<5>(GetParam()); + // index for exval in y + gtint_t yj = std::get<6>(GetParam()); + // exval for x + T yexval = std::get<7>(GetParam()); + // alpha + T alpha = std::get<8>(GetParam()); + // beta + T beta = std::get<9>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite axpbyv.h (no netlib version) for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // With adjustment for complex data. + // NOTE : Every mul for complex types involves 3 ops(2 muls + 1 add) + double thresh; + double adj = 3; + if (n == 0) + thresh = 0.0; + else if (beta == testinghelpers::ZERO()) + { + // Like SETV or COPYV(no ops) + if (alpha == testinghelpers::ZERO() || alpha == testinghelpers::ONE()) + thresh = 0.0; + // Like SCAL2V(1 mul) + else + thresh = (1 * adj) * testinghelpers::getEpsilon(); + } + else if (beta == testinghelpers::ONE()) + { + // Like ERS(no ops) + if (alpha == testinghelpers::ZERO()) + thresh = 0.0; + // Like ADDV(1 add) + else if (alpha == testinghelpers::ONE()) + thresh = testinghelpers::getEpsilon(); + // Like AXPYV(1 mul and 1 add) + else + thresh = (1 * adj + 1) * testinghelpers::getEpsilon(); + } + else + { + // Like SCALV(1 mul) + if (alpha == testinghelpers::ZERO()) + thresh = (1 * adj) * testinghelpers::getEpsilon(); + // Like AXPBYV(2 muls and 1 add) + else + thresh = (2 * adj + 1) * testinghelpers::getEpsilon(); + } + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_axpbyv(conj_x, n, incx, incy, alpha, beta, xi, xexval, + yj, yexval, thresh); +} + +#if defined(REF_IS_NETLIB) +static double NaN = std::numeric_limits::quiet_NaN(); +static double Inf = std::numeric_limits::infinity(); + +/* + The code structure for bli_zaxpbyv_zen_int( ... ) is as follows : + For unit strides : + Main loop : In blocks of 8 --> L8 + Fringe loops : In blocks of 6 --> L6 + In blocks of 4 --> L4 + In blocks of 2 --> L2 + + For non-unit strides : A single loop, to process element wise. + NOTE : Any size, requiring the fringe case of 1 with unit stride falls to + the non-unit stride loop and executes it once for just the last element. + + The sizes chosen are as follows : + 59 - 7*L8 + L2 + 1(LScalar) + 60 - 7*L8 + L4 + 62 - 7*L8 + L6 + + For size 59 : 7*L8 + L2 + 1(LScalar) + Indices are : 0, 55 -> In L8 + 57 -> In L2 + 58 -> In LScalar + + For size 60 : 7*L8 + L4 + Indices are : 0, 55 -> In L8 + 59 -> In L4 + + For size 62 : 7*L8 + L6 + Indices are : 0, 55 -> In L8 + 61 -> In L6 + + The alpha and beta values are such that they check for compliance against possible + optimizations that might have been done. + + P.S : Some test cases also check whether NaN has to be induced in the computation + such as 0.0 * { {NaN, 0}, {+Inf, 0}, {-Inf, 0}, ... }, and a few more. +*/ + +// Exception value testing(on X vector alone) with unit strides +INSTANTIATE_TEST_SUITE_P( + vecX_unitStrides, + zaxpbyvEVT, + ::testing::Combine( + ::testing::Values('n' // n: use x, c: use conj(x) +#ifdef TEST_BLIS_TYPED + , + 'c' // this option is BLIS-api specific. +#endif + ), + ::testing::Values(gtint_t(59), gtint_t(60), gtint_t(62)), // n, size of vectors with unit-stride + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(gtint_t(0), gtint_t(55), gtint_t(57), + gtint_t(58), gtint_t(59), gtint_t(61)), // indices to set exception values on x + ::testing::Values(dcomplex{NaN, 0.0}, dcomplex{-Inf, 0.0}, + dcomplex{0.0, Inf}, dcomplex{-2.3, NaN}, + dcomplex{4.5, -Inf}, dcomplex{NaN, Inf}), // exception values to set on x + ::testing::Values(gtint_t(0)), // dummy index on y + ::testing::Values(dcomplex{0.0, 0.0}), // dummy value on y + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, + dcomplex{-1.0, 0.0}, dcomplex{0.0, 1.0}, + dcomplex{0.0, -1.0}, dcomplex{-3.3, 1.7}), // alpha + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, + dcomplex{-1.0, 0.0}, dcomplex{0.0, 1.0}, + dcomplex{0.0, -1.0}, dcomplex{-3.3, 1.7}) // beta + ), + ::axpbyvEVTPrint()); + +// Exception value testing(on Y vector alone) with unit strides +INSTANTIATE_TEST_SUITE_P( + vecY_unitStrides, + zaxpbyvEVT, + ::testing::Combine( + ::testing::Values('n' // n: use x, c: use conj(x) +#ifdef TEST_BLIS_TYPED + , + 'c' // this option is BLIS-api specific. +#endif + ), + ::testing::Values(gtint_t(59), gtint_t(60), gtint_t(62)), // n, size of vectors with unit-stride + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(gtint_t(0)), // dummy index on x + ::testing::Values(dcomplex{0.0, 0.0}), // dummy value on x + ::testing::Values(gtint_t(0), gtint_t(55), gtint_t(57), + gtint_t(58), gtint_t(59), gtint_t(61)), // indices to set exception values on y + ::testing::Values(dcomplex{NaN, 0.0}, dcomplex{-Inf, 0.0}, + dcomplex{0.0, Inf}, dcomplex{-2.3, NaN}, + dcomplex{4.5, -Inf}, dcomplex{NaN, Inf}), // exception values to set on y + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, + dcomplex{-1.0, 0.0}, dcomplex{0.0, 1.0}, + dcomplex{0.0, -1.0}, dcomplex{-3.3, 1.7}), // alpha + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, + dcomplex{-1.0, 0.0}, dcomplex{0.0, 1.0}, + dcomplex{0.0, -1.0}, dcomplex{-3.3, 1.7}) // beta + ), + ::axpbyvEVTPrint()); + +// Exception value testing(on X and Y vectors) with unit strides +INSTANTIATE_TEST_SUITE_P( + vecXY_unitStrides, + zaxpbyvEVT, + ::testing::Combine( + ::testing::Values('n' // n: use x, c: use conj(x) +#ifdef TEST_BLIS_TYPED + , + 'c' // this option is BLIS-api specific. +#endif + ), + ::testing::Values(gtint_t(59), gtint_t(60), gtint_t(62)), // n, size of vectors with unit-stride + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(gtint_t(0), gtint_t(55), gtint_t(57), + gtint_t(58), gtint_t(59), gtint_t(61)), // indices to set exception values on x + ::testing::Values(dcomplex{NaN, 0.0}, dcomplex{-Inf, 0.0}, + dcomplex{0.0, Inf}, dcomplex{-2.3, NaN}, + dcomplex{4.5, -Inf}, dcomplex{NaN, Inf}), // exception values to set on x + ::testing::Values(gtint_t(0), gtint_t(55), gtint_t(57), + gtint_t(58), gtint_t(59), gtint_t(61)), // indices to set exception values on y + ::testing::Values(dcomplex{NaN, 0.0}, dcomplex{-Inf, 0.0}, + dcomplex{0.0, Inf}, dcomplex{-2.3, NaN}, + dcomplex{4.5, -Inf}, dcomplex{NaN, Inf}), // exception values to set on y + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, + dcomplex{-1.0, 0.0}, dcomplex{0.0, 1.0}, + dcomplex{0.0, -1.0}, dcomplex{-3.3, 1.7}), // alpha + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, + dcomplex{-1.0, 0.0}, dcomplex{0.0, 1.0}, + dcomplex{0.0, -1.0}, dcomplex{-3.3, 1.7}) // beta + ), + ::axpbyvEVTPrint()); + +// Exception value testing(on vectors) with non-unit strides +// We have to test a single scalar loop. The indices are such +// that we cover _vecX_, _vecY_ and _vecXY_ cases together. +INSTANTIATE_TEST_SUITE_P( + vecXY_nonUnitStrides, + zaxpbyvEVT, + ::testing::Combine( + ::testing::Values('n' // n: use x, c: use conj(x) +#ifdef TEST_BLIS_TYPED + , + 'c' // this option is BLIS-api specific. +#endif + ), + ::testing::Values(gtint_t(50)), // n, size of vectors with non-unit strides + ::testing::Values(gtint_t(3)), // stride size for x + ::testing::Values(gtint_t(5)), // stride size for y + ::testing::Values(gtint_t(1), gtint_t(27), gtint_t(49)), // indices to set exception values on x + ::testing::Values(dcomplex{NaN, 0.0}, dcomplex{-Inf, 0.0}, + dcomplex{0.0, Inf}, dcomplex{-2.3, NaN}, + dcomplex{4.5, -Inf}, dcomplex{NaN, Inf}, + dcomplex{2.3, -3.5}), // exception values to set on x + ::testing::Values(gtint_t(0), gtint_t(26), gtint_t(49)), // indices to set exception values on y + ::testing::Values(dcomplex{NaN, 0.0}, dcomplex{-Inf, 0.0}, + dcomplex{0.0, Inf}, dcomplex{-2.3, NaN}, + dcomplex{4.5, -Inf}, dcomplex{NaN, Inf}, + dcomplex{2.3, -3.5}), // exception values to set on y + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, + dcomplex{-1.0, 0.0}, dcomplex{0.0, 1.0}, + dcomplex{0.0, -1.0}, dcomplex{-3.3, 1.7}), // alpha + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, + dcomplex{-1.0, 0.0}, dcomplex{0.0, 1.0}, + dcomplex{0.0, -1.0}, dcomplex{-3.3, 1.7}) // beta + ), + ::axpbyvEVTPrint()); + +/* + Exception value testing on alpha and beta : + Alpha values are set to Nan, +Inf or -Inf. A dummy + value of 0.0 is induced in X and Y vectors, to further + verify the propagation. +*/ +INSTANTIATE_TEST_SUITE_P( + alphaBeta_unitStrides, + zaxpbyvEVT, + ::testing::Combine( + ::testing::Values('n' // n: use x, c: use conj(x) +#ifdef TEST_BLIS_TYPED + , + 'c' // this option is BLIS-api specific. +#endif + ), + ::testing::Values(gtint_t(59), gtint_t(60), gtint_t(62)), // n, size of vectors with unit-stride + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(gtint_t(0)), // indices to set zero on x + ::testing::Values(dcomplex{0.0, 0.0}), + ::testing::Values(gtint_t(0)), // indices to set zero on y + ::testing::Values(dcomplex{0.0, 0.0}), + ::testing::Values(dcomplex{NaN, 0.0}, dcomplex{-Inf, 0.0}, + dcomplex{0.0, Inf}, dcomplex{-2.3, NaN}, + dcomplex{4.5, -Inf}, dcomplex{NaN, Inf}, + dcomplex{2.3, -3.7}), // alpha + ::testing::Values(dcomplex{NaN, 0.0}, dcomplex{-Inf, 0.0}, + dcomplex{0.0, Inf}, dcomplex{-2.3, NaN}, + dcomplex{4.5, -Inf}, dcomplex{NaN, Inf}, + dcomplex{2.3, -3.7}) // beta + ), + ::axpbyvEVTPrint()); + +// Exception value testing(on alpha) with non-unit strided vectors +INSTANTIATE_TEST_SUITE_P( + alphaBeta_nonUnitStrides, + zaxpbyvEVT, + ::testing::Combine( + ::testing::Values('n' // n: use x, c: use conj(x) +#ifdef TEST_BLIS_TYPED + , + 'c' // this option is BLIS-api specific. +#endif + ), + ::testing::Values(gtint_t(50)), // n, size of vectors with non-unit strides + ::testing::Values(gtint_t(3)), // stride size for x + ::testing::Values(gtint_t(5)), // stride size for y + ::testing::Values(gtint_t(0), gtint_t(25)), // indices to set zero on x + ::testing::Values(dcomplex{0.0, 0.0}), + ::testing::Values(gtint_t(0), gtint_t(40)), // indices to set zero on y + ::testing::Values(dcomplex{0.0, 0.0}), + ::testing::Values(dcomplex{NaN, 0.0}, dcomplex{-Inf, 0.0}, + dcomplex{0.0, Inf}, dcomplex{-2.3, NaN}, + dcomplex{4.5, -Inf}, dcomplex{NaN, Inf}, + dcomplex{2.3, -3.7}), // alpha + ::testing::Values(dcomplex{NaN, 0.0}, dcomplex{-Inf, 0.0}, + dcomplex{0.0, Inf}, dcomplex{-2.3, NaN}, + dcomplex{4.5, -Inf}, dcomplex{NaN, Inf}, + dcomplex{2.3, -3.7}) // beta + ), + ::axpbyvEVTPrint()); +#endif diff --git a/gtestsuite/testsuite/level1/axpbyv/zaxpbyv_evt_testing.cpp b/gtestsuite/testsuite/level1/axpbyv/zaxpbyv_evt_testing.cpp deleted file mode 100644 index 104b5d59c1..0000000000 --- a/gtestsuite/testsuite/level1/axpbyv/zaxpbyv_evt_testing.cpp +++ /dev/null @@ -1,372 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include "test_axpbyv.h" - -class zaxpbyvEVTTest : - public ::testing::TestWithParam> {}; -// Tests using random integers as vector elements. -TEST_P(zaxpbyvEVTTest, RandomData) -{ - using T = dcomplex; - //---------------------------------------------------------- - // Initialize values from the parameters passed through - // test suite instantiation (INSTANTIATE_TEST_SUITE_P). - //---------------------------------------------------------- - // denotes whether x or conj(x) will be added to y: - char conj_x = std::get<0>(GetParam()); - // vector length: - gtint_t n = std::get<1>(GetParam()); - // stride size for x: - gtint_t incx = std::get<2>(GetParam()); - // stride size for y: - gtint_t incy = std::get<3>(GetParam()); - // index for exval in x - gtint_t xi = std::get<4>(GetParam()); - // exval for x - T xexval = std::get<5>(GetParam()); - // index for exval in y - gtint_t yj = std::get<6>(GetParam()); - // exval for x - T yexval = std::get<7>(GetParam()); - // alpha - T alpha = std::get<8>(GetParam()); - // beta - T beta = std::get<9>(GetParam()); - - // Set the threshold for the errors: - double thresh = 20 * testinghelpers::getEpsilon(); - - //---------------------------------------------------------- - // Call generic test body using those parameters - //---------------------------------------------------------- - test_axpbyv(conj_x, n, incx, incy, alpha, beta, xi, xexval, - yj, yexval, thresh); -} - -// Used to generate a test case with a sensible name. -// Beware that we cannot use fp numbers (e.g., 2.3) in the names, -// so we are only printing int(2.3). This should be enough for debugging purposes. -// If this poses an issue, please reach out. -class zaxpbyvEVTVecPrint -{ -public: - std::string operator()( - testing::TestParamInfo> str) const - { - char conj = std::get<0>(str.param); - gtint_t n = std::get<1>(str.param); - gtint_t incx = std::get<2>(str.param); - gtint_t incy = std::get<3>(str.param); - gtint_t xi = std::get<4>(str.param); - dcomplex xexval = std::get<5>(str.param); - gtint_t yj = std::get<6>(str.param); - dcomplex yexval = std::get<7>(str.param); - dcomplex alpha = std::get<8>(str.param); - dcomplex beta = std::get<9>(str.param); -#ifdef TEST_BLAS - std::string str_name = "zaxpby_"; -#elif TEST_CBLAS - std::string str_name = "cblas_zaxpby"; -#else // #elif TEST_BLIS_TYPED - std::string str_name = "bli_zaxpbyv"; -#endif - str_name += "_" + std::to_string(n); - str_name += "_" + std::string(&conj, 1); - std::string incx_str = (incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name += "_" + incx_str; - std::string incy_str = (incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); - str_name += "_" + incy_str; - std::string xexval_str = testinghelpers::get_value_string(xexval); - std::string yexval_str = testinghelpers::get_value_string(yexval); - str_name = str_name + "_X_" + std::to_string(xi); - str_name = str_name + "_" + xexval_str; - str_name = str_name + "_Y_" + std::to_string(yj); - str_name = str_name + "_" + yexval_str; - std::string alpha_str = testinghelpers::get_value_string(alpha); - std::string beta_str = testinghelpers::get_value_string(beta); - str_name = str_name + "_a" + alpha_str; - str_name = str_name + "_b" + beta_str; - return str_name; - } -}; - -class zaxpbyvAlphaBetaPrint -{ -public: - std::string operator()( - testing::TestParamInfo> str) const - { - char conj = std::get<0>(str.param); - gtint_t n = std::get<1>(str.param); - gtint_t incx = std::get<2>(str.param); - gtint_t incy = std::get<3>(str.param); - dcomplex alpha = std::get<8>(str.param); - dcomplex beta = std::get<9>(str.param); -#ifdef TEST_BLAS - std::string str_name = "zaxpby_"; -#elif TEST_CBLAS - std::string str_name = "cblas_zaxpby"; -#else // #elif TEST_BLIS_TYPED - std::string str_name = "bli_zaxpbyv"; -#endif - str_name += "_" + std::to_string(n); - str_name += "_" + std::string(&conj, 1); - std::string incx_str = (incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name += "_" + incx_str; - std::string incy_str = (incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); - str_name += "_" + incy_str; - std::string alpha_str = testinghelpers::get_value_string(alpha); - std::string beta_str = testinghelpers::get_value_string(beta); - str_name = str_name + "_a" + alpha_str; - str_name = str_name + "_b" + beta_str; - return str_name; - } -}; - -static double NaN = std::numeric_limits::quiet_NaN(); -static double Inf = std::numeric_limits::infinity(); - -/* - The code structure for bli_zaxpbyv_zen_int( ... ) is as follows : - For unit strides : - Main loop : In blocks of 8 --> L8 - Fringe loops : In blocks of 6 --> L6 - In blocks of 4 --> L4 - In blocks of 2 --> L2 - - For non-unit strides : A single loop, to process element wise. - NOTE : Any size, requiring the fringe case of 1 with unit stride falls to - the non-unit stride loop and executes it once for just the last element. - - With regards to exception value testing, every loop is tested separately. - The indices for setting exception values on the vectors are such that - every load associated with the loop has an exception value in it. Thus, - every arithmetic instruction associated with each load will be tested - for exception value handling. -*/ - -// Exception value testing(on vectors) for L8 -INSTANTIATE_TEST_SUITE_P( - bli_zaxpbyv_zen_int_evt_vec_L8, - zaxpbyvEVTTest, - ::testing::Combine( - ::testing::Values('n' // n: use x, c: use conj(x) -#ifdef TEST_BLIS_TYPED - , - 'c' // this option is BLIS-api specific. -#endif - ), - ::testing::Values(gtint_t(8)), // m, size of vector to enter L8 directly. - ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(gtint_t(1), gtint_t(3), gtint_t(4), gtint_t(7)), // indices to set exception values on x - ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{NaN, 2.3}, - dcomplex{-Inf, 0.0}, dcomplex{Inf, NaN}, - dcomplex{NaN, -Inf}), // exception values to set on x - ::testing::Values(gtint_t(0), gtint_t(2), gtint_t(5), gtint_t(6)), // indices to set exception values on y - ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{NaN, 2.3}, - dcomplex{-Inf, 0.0}, dcomplex{Inf, NaN}, - dcomplex{NaN, -Inf}), // exception values to set on y - ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{2.2, -3.3}), // alpha - ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{0.9, 4.5}) // beta - ), - ::zaxpbyvEVTVecPrint()); - -// Exception value testing(on vectors) for L6 -INSTANTIATE_TEST_SUITE_P( - bli_zaxpbyv_zen_int_evt_vec_L6, - zaxpbyvEVTTest, - ::testing::Combine( - ::testing::Values('n' // n: use x, c: use conj(x) -#ifdef TEST_BLIS_TYPED - , - 'c' // this option is BLIS-api specific. -#endif - ), - ::testing::Values(gtint_t(6)), // m, size of vector to enter L8 directly. - ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(gtint_t(1), gtint_t(3), gtint_t(4)), // indices to set exception values on x - ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{NaN, 2.3}, - dcomplex{-Inf, 0.0}, dcomplex{Inf, NaN}, - dcomplex{NaN, -Inf}), // exception values to set on x - ::testing::Values(gtint_t(0), gtint_t(2), gtint_t(5)), // indices to set exception values on y - ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{NaN, 2.3}, - dcomplex{-Inf, 0.0}, dcomplex{Inf, NaN}, - dcomplex{NaN, -Inf}), // exception values to set on y - ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{2.2, -3.3}), // alpha - ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{0.9, 4.5}) // beta - ), - ::zaxpbyvEVTVecPrint()); - -// Exception value testing(on vectors) for L4 -INSTANTIATE_TEST_SUITE_P( - bli_zaxpbyv_zen_int_evt_vec_L4, - zaxpbyvEVTTest, - ::testing::Combine( - ::testing::Values('n' // n: use x, c: use conj(x) -#ifdef TEST_BLIS_TYPED - , - 'c' // this option is BLIS-api specific. -#endif - ), - ::testing::Values(gtint_t(4)), // m, size of vector to enter L8 directly. - ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(gtint_t(1), gtint_t(3)), // indices to set exception values on x - ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{NaN, 2.3}, - dcomplex{-Inf, 0.0}, dcomplex{Inf, NaN}, - dcomplex{NaN, -Inf}), // exception values to set on x - ::testing::Values(gtint_t(0), gtint_t(2)), // indices to set exception values on y - ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{NaN, 2.3}, - dcomplex{-Inf, 0.0}, dcomplex{Inf, NaN}, - dcomplex{NaN, -Inf}), // exception values to set on y - ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{2.2, -3.3}), // alpha - ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{0.9, 4.5}) // beta - ), - ::zaxpbyvEVTVecPrint()); - -// Exception value testing(on vectors) for L2 -INSTANTIATE_TEST_SUITE_P( - bli_zaxpbyv_zen_int_evt_vec_L2, - zaxpbyvEVTTest, - ::testing::Combine( - ::testing::Values('n' // n: use x, c: use conj(x) -#ifdef TEST_BLIS_TYPED - , - 'c' // this option is BLIS-api specific. -#endif - ), - ::testing::Values(gtint_t(2)), // m, size of vector to enter L8 directly. - ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(gtint_t(1)), // indices to set exception values on x - ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{NaN, 2.3}, - dcomplex{-Inf, 0.0}, dcomplex{Inf, NaN}, - dcomplex{NaN, -Inf}), // exception values to set on x - ::testing::Values(gtint_t(0)), // indices to set exception values on y - ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{NaN, 2.3}, - dcomplex{-Inf, 0.0}, dcomplex{Inf, NaN}, - dcomplex{NaN, -Inf}), // exception values to set on y - ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{2.2, -3.3}), // alpha - ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{0.9, 4.5}) // beta - ), - ::zaxpbyvEVTVecPrint()); - -// Exception value testing(on vectors) with non unit strides -INSTANTIATE_TEST_SUITE_P( - bli_zaxpbyv_zen_int_evt_vec_NUS, - zaxpbyvEVTTest, - ::testing::Combine( - ::testing::Values('n' // n: use x, c: use conj(x) -#ifdef TEST_BLIS_TYPED - , - 'c' // this option is BLIS-api specific. -#endif - ), - ::testing::Values(gtint_t(1), gtint_t(5)), // m, size of vector to enter NUS loop directly. - ::testing::Values(gtint_t(3)), // stride size for x - ::testing::Values(gtint_t(-4)), // stride size for y - ::testing::Values(gtint_t(0)), // indices to set exception values on x - ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{NaN, 2.3}, - dcomplex{-Inf, 0.0}, dcomplex{Inf, NaN}), // exception values to set on x - ::testing::Values(gtint_t(0)), // indices to set exception values on y - ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{NaN, 2.3}, - dcomplex{-Inf, 0.0}, dcomplex{Inf, NaN}), // exception values to set on y - ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{2.2, -3.3}), // alpha - ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{0.9, 4.5}) // beta - ), - ::zaxpbyvEVTVecPrint()); - -// Exception value testing(on alpha/beta) with unit stride -/* - NOTE : Here, every loop is tested for, with alpha and beta having exception values - Furthermore, the first element of x and second element of y are set to 0, which - includes testing that cover cases where NaN might be induced due to 0 * (Inf or -Inf). -*/ -INSTANTIATE_TEST_SUITE_P( - bli_zaxpbyv_zen_int_evt_alphabeta_US, - zaxpbyvEVTTest, - ::testing::Combine( - ::testing::Values('n' // n: use x, c: use conj(x) -#ifdef TEST_BLIS_TYPED - , - 'c' // this option is BLIS-api specific. -#endif - ), - ::testing::Values(gtint_t(8), gtint_t(6), gtint_t(4), gtint_t(2)), // m size of vector to enter L8, L6, L4 and L2 respectively. - ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(gtint_t(0)), // indices to set exception values on x - ::testing::Values(dcomplex{0.0, 0.0}), // exception values to set on x - ::testing::Values(gtint_t(1)), // indices to set exception values on y - ::testing::Values(dcomplex{0.0, 0.0}), // exception values to set on y - ::testing::Values(dcomplex{NaN, 2.3}, dcomplex{Inf, 0.0}, dcomplex{-Inf, NaN}), // alpha - ::testing::Values(dcomplex{-0.9, NaN}, dcomplex{0.0, -Inf}, dcomplex{NaN, Inf}) // beta - ), - ::zaxpbyvEVTVecPrint()); - -// Exception value testing(on alpha/beta) with non-unit stride -INSTANTIATE_TEST_SUITE_P( - bli_zaxpbyv_zen_int_evt_alphabeta_NUS, - zaxpbyvEVTTest, - ::testing::Combine( - ::testing::Values('n' // n: use x, c: use conj(x) -#ifdef TEST_BLIS_TYPED - , - 'c' // this option is BLIS-api specific. -#endif - ), - ::testing::Values(gtint_t(5)), // m, size of vector to enter NUS loop directly. - ::testing::Values(gtint_t(3)), // stride size for x - ::testing::Values(gtint_t(-4)), // stride size for y - ::testing::Values(gtint_t(0)), // indices to set exception values on x - ::testing::Values(dcomplex{0.0, 0.0}), // exception values to set on x - ::testing::Values(gtint_t(0)), // indices to set exception values on y - ::testing::Values(dcomplex{0.0, 0.0}), // exception values to set on y - ::testing::Values(dcomplex{NaN, 2.3}, dcomplex{Inf, 0.0}, dcomplex{-Inf, NaN}), // alpha - ::testing::Values(dcomplex{-0.9, NaN}, dcomplex{0.0, -Inf}, dcomplex{NaN, Inf}) // beta - ), - ::zaxpbyvEVTVecPrint()); diff --git a/gtestsuite/testsuite/level1/axpbyv/zaxpbyv_generic.cpp b/gtestsuite/testsuite/level1/axpbyv/zaxpbyv_generic.cpp index b69a132796..aa476df48f 100644 --- a/gtestsuite/testsuite/level1/axpbyv/zaxpbyv_generic.cpp +++ b/gtestsuite/testsuite/level1/axpbyv/zaxpbyv_generic.cpp @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -35,7 +35,7 @@ #include #include "test_axpbyv.h" -class zaxpbyvAccTest : +class zaxpbyvGeneric : public ::testing::TestWithParam> {}; // Tests using random integers as vector elements. -TEST_P(zaxpbyvAccTest, RandomData) +TEST_P( zaxpbyvGeneric, API ) { using T = dcomplex; //---------------------------------------------------------- @@ -64,7 +64,45 @@ TEST_P(zaxpbyvAccTest, RandomData) T beta = std::get<5>(GetParam()); // Set the threshold for the errors: - double thresh = 20 * testinghelpers::getEpsilon(); + // Check gtestsuite axpbyv.h (no netlib version) for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // With adjustment for complex data. + // NOTE : Every mul for complex types involves 3 ops(2 muls + 1 add) + double thresh; + double adj = 3; + if (n == 0) + thresh = 0.0; + else if (beta == testinghelpers::ZERO()) + { + // Like SETV or COPYV(no ops) + if (alpha == testinghelpers::ZERO() || alpha == testinghelpers::ONE()) + thresh = 0.0; + // Like SCAL2V(1 mul) + else + thresh = (1 * adj) * testinghelpers::getEpsilon(); + } + else if (beta == testinghelpers::ONE()) + { + // Like ERS(no ops) + if (alpha == testinghelpers::ZERO()) + thresh = 0.0; + // Like ADDV(1 add) + else if (alpha == testinghelpers::ONE()) + thresh = testinghelpers::getEpsilon(); + // Like AXPYV(1 mul and 1 add) + else + thresh = (1 * adj + 1) * testinghelpers::getEpsilon(); + } + else + { + // Like SCALV(1 mul) + if (alpha == testinghelpers::ZERO()) + thresh = (1 * adj) * testinghelpers::getEpsilon(); + // Like AXPBYV(2 muls and 1 add) + else + thresh = (2 * adj + 1) * testinghelpers::getEpsilon(); + } //---------------------------------------------------------- // Call generic test body using those parameters @@ -72,45 +110,6 @@ TEST_P(zaxpbyvAccTest, RandomData) test_axpbyv(conj_x, n, incx, incy, alpha, beta, thresh); } -// Used to generate a test case with a sensible name. -// Beware that we cannot use fp numbers (e.g., 2.3) in the names, -// so we are only printing int(2.3). This should be enough for debugging purposes. -// If this poses an issue, please reach out. -class zaxpbyvAccTestPrint -{ -public: - std::string operator()( - testing::TestParamInfo> str) const - { - char conj = std::get<0>(str.param); - gtint_t n = std::get<1>(str.param); - gtint_t incx = std::get<2>(str.param); - gtint_t incy = std::get<3>(str.param); - dcomplex alpha = std::get<4>(str.param); - dcomplex beta = std::get<5>(str.param); -#ifdef TEST_BLAS - std::string str_name = "zaxpby_"; -#elif TEST_CBLAS - std::string str_name = "cblas_zaxpby"; -#else // #elif TEST_BLIS_TYPED - std::string str_name = "bli_zaxpbyv"; -#endif - str_name += "_" + std::to_string(n); - str_name += "_" + std::string(&conj, 1); - std::string incx_str = (incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name += "_" + incx_str; - std::string incy_str = (incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); - str_name += "_" + incy_str; - std::string alpha_str = (alpha.real > 0) ? std::to_string(int(alpha.real)) : ("m" + std::to_string(int(std::abs(alpha.real)))); - alpha_str = alpha_str + "pi" + ((alpha.imag > 0) ? std::to_string(int(alpha.imag)) : ("m" + std::to_string(int(std::abs(alpha.imag))))); - std::string beta_str = (beta.real > 0) ? std::to_string(int(beta.real)) : ("m" + std::to_string(int(std::abs(beta.real)))); - beta_str = beta_str + "pi" + ((beta.imag > 0) ? std::to_string(int(beta.imag)) : ("m" + std::to_string(int(std::abs(beta.imag))))); - str_name = str_name + "_a" + alpha_str; - str_name = str_name + "_b" + beta_str; - return str_name; - } -}; - /* The code structure for bli_zaxpbyv_zen_int( ... ) is as follows : For unit strides : @@ -126,8 +125,8 @@ class zaxpbyvAccTestPrint // Accuracy testing of the main loop, single and multiple runs INSTANTIATE_TEST_SUITE_P( - bli_zaxpbyv_zen_int_acc_US_main, - zaxpbyvAccTest, + bli_zaxpbyv_zen_int_acc_unitStrides_main, + zaxpbyvGeneric, ::testing::Combine( ::testing::Values('n' // n: use x, c: use conj(x) #ifdef TEST_BLIS_TYPED @@ -141,12 +140,12 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{2.2, -3.3}), // alpha ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{1.0, 2.0}) // beta ), - ::zaxpbyvAccTestPrint()); + ::axpbyvGenericPrint()); // Accuracy testing of different combinations of fringe loops(L6, L4, L2, 1) INSTANTIATE_TEST_SUITE_P( - bli_zaxpbyv_zen_int_acc_US_fringe, - zaxpbyvAccTest, + bli_zaxpbyv_zen_int_acc_unitStrides_fringe, + zaxpbyvGeneric, ::testing::Combine( ::testing::Values('n' // n: use x, c: use conj(x) #ifdef TEST_BLIS_TYPED @@ -160,12 +159,12 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{2.2, -3.3}), // alpha ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{1.0, 2.0}) // beta ), - ::zaxpbyvAccTestPrint()); + ::axpbyvGenericPrint()); // Accuracy testing of 3*L8 + L6 + L4 + L2 + 1, a case of main + all fringe cases taken INSTANTIATE_TEST_SUITE_P( - bli_zaxpbyv_zen_int_acc_US_combine, - zaxpbyvAccTest, + bli_zaxpbyv_zen_int_acc_unitStrides_combine, + zaxpbyvGeneric, ::testing::Combine( ::testing::Values('n' // n: use x, c: use conj(x) #ifdef TEST_BLIS_TYPED @@ -179,12 +178,12 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{2.2, -3.3}), // alpha ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{1.0, 2.0}) // beta ), - ::zaxpbyvAccTestPrint()); + ::axpbyvGenericPrint()); // Accuracy testing with non-unit strides INSTANTIATE_TEST_SUITE_P( - bli_zaxpbyv_zen_int_acc_NUS, - zaxpbyvAccTest, + bli_zaxpbyv_zen_int_acc_nonUnitStrides, + zaxpbyvGeneric, ::testing::Combine( ::testing::Values('n' // n: use x, c: use conj(x) #ifdef TEST_BLIS_TYPED @@ -193,9 +192,17 @@ INSTANTIATE_TEST_SUITE_P( #endif ), ::testing::Values(gtint_t(10), gtint_t(17)), // m - ::testing::Values(gtint_t(-3), gtint_t(4)), // stride size for x - ::testing::Values(gtint_t(6), gtint_t(-2)), // stride size for y + ::testing::Values( +#ifndef TEST_BLIS_TYPED + gtint_t(-3), +#endif + gtint_t(4)), // stride size for x + ::testing::Values( +#ifndef TEST_BLIS_TYPED + gtint_t(-2), +#endif + gtint_t(6)), // stride size for y ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{2.2, -3.3}), // alpha ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{1.0, 2.0}) // beta ), - ::zaxpbyvAccTestPrint()); + ::axpbyvGenericPrint()); diff --git a/gtestsuite/testsuite/level1/axpyf/axpyf.h b/gtestsuite/testsuite/level1/axpyf/axpyf.h new file mode 100644 index 0000000000..d1566df796 --- /dev/null +++ b/gtestsuite/testsuite/level1/axpyf/axpyf.h @@ -0,0 +1,175 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#pragma once + +#include "blis.h" +#include "common/testing_helpers.h" +#include "inc/check_error.h" + +template +static void typed_axpyf( + char conj_a, + char conj_x, + gtint_t m, + gtint_t b, + T *alpha, + T* A, + gtint_t inca, + gtint_t lda, + T* x, + gtint_t incx, + T* y, + gtint_t incy) +{ + conj_t conja; + conj_t conjx; + // Map parameter characters to BLIS constants. + testinghelpers::char_to_blis_conj( conj_a, &conja ); + testinghelpers::char_to_blis_conj( conj_x, &conjx ); + if constexpr (std::is_same::value) + bli_saxpyf( conja, conjx, m, b, alpha, A, inca, lda, x, incx, y, incy ); + else if constexpr (std::is_same::value) + bli_daxpyf( conja, conjx, m, b, alpha, A, inca, lda, x, incx, y, incy ); + else if constexpr (std::is_same::value) + bli_caxpyf( conja, conjx, m, b, alpha, A, inca, lda, x, incx, y, incy ); + else if constexpr (std::is_same::value) + bli_zaxpyf( conja, conjx, m, b, alpha, A, inca, lda, x, incx, y, incy ); + else + throw std::runtime_error("Error in testsuite/level1/axpyv.h: Invalid typename in typed_axpyv()."); +} + +template +static void axpyf( + char conj_a, + char conj_x, + gtint_t m, + gtint_t b, + T *alpha, + T* A, + gtint_t inca, + gtint_t lda, + T* x, + gtint_t incx, + T* y, + gtint_t incy +) +{ + +#ifdef TEST_UPPERCASE_ARGS + conj_a = static_cast(std::toupper(static_cast(conj_a))); + conj_x = static_cast(std::toupper(static_cast(conj_x))); +#endif + +#ifdef TEST_INPUT_ARGS + // Create copy of scalar input values so we can check that they are not altered. + char conj_a_cpy = conj_a; + char conj_x_cpy = conj_x; + gtint_t m_cpy = m; + gtint_t b_cpy = b; + T* alpha_cpy = alpha; + gtint_t inca_cpy = inca; + gtint_t lda_cpy = lda; + gtint_t incx_cpy = incx; + gtint_t incy_cpy = incy; + + // Create copy of input arrays so we can check that they are not altered. + T* A_cpy = nullptr; + gtint_t size_A = testinghelpers::matsize( 'c', 'n', m, b, lda ); + if (A && size_A > 0) + { + A_cpy = new T[size_A]; + memcpy( A_cpy, A, size_A * sizeof( T ) ); + } + T* x_cpy = nullptr; + gtint_t size_x = testinghelpers::buff_dim( m, incx ); + if (x && size_x > 0) + { + x_cpy = new T[size_x]; + memcpy( x_cpy, x, size_x * sizeof( T ) ); + } +#endif + +/** + * axpyf operation is defined as : + * y := y + alpha * conja(A) * conjx(x) + * where A is an m x b matrix, and y and x are vectors. + * Matrix should be represented as "A" instead of "a" to distinguish it from vector. +*/ + typed_axpyf( + conj_a, + conj_x, + m, + b, + alpha, + A, + inca, + lda, + x, + incx, + y, + incy ); + +#ifdef TEST_INPUT_ARGS + //---------------------------------------------------------- + // Check scalar inputs have not been modified. + //---------------------------------------------------------- + + computediff( "conj_a", conj_a, conj_a_cpy ); + computediff( "conj_x", conj_x, conj_x_cpy ); + computediff( "m", m, m_cpy ); + computediff( "b", b, b_cpy ); + if (alpha) computediff( "alpha", *alpha, *alpha_cpy ); + computediff( "inca", inca, inca_cpy ); + computediff( "lda", lda, lda_cpy ); + computediff( "incx", incx, incx_cpy ); + computediff( "incy", incy, incy_cpy ); + + //---------------------------------------------------------- + // Bitwise-wise check array inputs have not been modified. + //---------------------------------------------------------- + + if (A && size_A > 0) + { + computediff( "A", 'c', m, b, A, A_cpy, lda, true ); + delete[] A_cpy; + } + + if (x && size_x > 0) + { + computediff( "x", m, x, x_cpy, incx, true ); + delete[] x_cpy; + } +#endif +} diff --git a/gtestsuite/testsuite/level1/axpyf/daxpyf_generic.cpp b/gtestsuite/testsuite/level1/axpyf/daxpyf_generic.cpp new file mode 100644 index 0000000000..37cd73eae0 --- /dev/null +++ b/gtestsuite/testsuite/level1/axpyf/daxpyf_generic.cpp @@ -0,0 +1,108 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_axpyf.h" + +class daxpyfGeneric : + public ::testing::TestWithParam> {}; +// Tests using random integers as vector elements. +TEST_P( daxpyfGeneric, API ) +{ + using T = double; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes whether x or conj(x) will be added to y: + char conj_x = std::get<0>(GetParam()); + char conj_a = std::get<1>(GetParam()); + gtint_t m = std::get<2>(GetParam()); + gtint_t b = std::get<3>(GetParam()); + T alpha = std::get<4>(GetParam()); + + // stride size for x: + gtint_t inca = std::get<5>(GetParam()); + // stride size for y: + gtint_t lda = std::get<6>(GetParam()); + gtint_t incx = std::get<7>(GetParam()); + gtint_t incy = std::get<8>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite axpyf.h (no netlib version) for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (m == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + thresh = 0.0; + else if (alpha == testinghelpers::ONE()) + thresh = (b+1)*testinghelpers::getEpsilon(); + else + thresh = (2*b+1)*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_axpyf( conj_x, conj_a, m, b, &alpha, inca, lda, incx, incy, thresh ); +} + +// Black box testing for generic and main use of daxpy. +INSTANTIATE_TEST_SUITE_P( + FunctionalTest, + daxpyfGeneric, + ::testing::Combine( + ::testing::Values('n'), // n: use x, not conj(x) (since it is real) + ::testing::Values('n'), // n: use x, not conj(x) (since it is real) + ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of matrix + ::testing::Range(gtint_t(6), gtint_t(10), 1), // b size of matrix + ::testing::Values(double(0.0), double(1.0), double(2.3)), // alpha + ::testing::Values(gtint_t(0)), // lda increment + ::testing::Values(gtint_t(1)), // stride size for a + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)) // stride size for y + ), + ::axpyfGenericPrint() + ); + diff --git a/gtestsuite/testsuite/level1/axpyf/test_axpyf.h b/gtestsuite/testsuite/level1/axpyf/test_axpyf.h new file mode 100644 index 0000000000..fc13f981a7 --- /dev/null +++ b/gtestsuite/testsuite/level1/axpyf/test_axpyf.h @@ -0,0 +1,128 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#pragma once + +#include "axpyf.h" +#include "level1/ref_axpyf.h" +#include "inc/check_error.h" + +/** + * axpyf operation is defined as : + * y := y + alpha * conja(A) * conjx(x) + * where A is an m x b matrix, and y and x are vectors. + * Matrix should be represented as "A" instead of "a" to distinguish it from vector. +*/ +template +static void test_axpyf( + char conj_a, + char conj_x, + gtint_t m, + gtint_t b, + T *alpha, + gtint_t inca, + gtint_t lda_inc, + gtint_t incx, + gtint_t incy, + double thresh + ) +{ + //---------------------------------------------------------- + // Initialize vectors with random numbers. + //---------------------------------------------------------- + + // Compute the leading dimensions of A matrix. + gtint_t lda = testinghelpers::get_leading_dimension( 'c', 'n', m, b, lda_inc ); + + //---------------------------------------------------------- + // Initialize matrics with random numbers + //---------------------------------------------------------- + std::vector A = testinghelpers::get_random_matrix( -2, 8, 'c', 'n', m, b, lda ); + + std::vector x = testinghelpers::get_random_vector( -10, 10, m, incx ); + std::vector y = testinghelpers::get_random_vector( -10, 10, m, incy ); + + //---------------------------------------------------------- + // Call reference implementation to get ref results. + //---------------------------------------------------------- + // Create a copy of y so that we can check reference results. + std::vector y_ref(y); + // char, char, long, long, double, double*, long, long, double*, long, double*, long) + testinghelpers::ref_axpyf( conj_a, conj_x, m, b, alpha, A.data(), inca, lda, x.data(), incx, y_ref.data(), incy ); + + //---------------------------------------------------------- + // Call BLIS function. + //---------------------------------------------------------- + axpyf( conj_a, conj_x, m, b, alpha, A.data(), inca, lda, x.data(), incx, y.data(), incy ); + + //--------------------------------------------------------- + // Compute component-wise error. + //---------------------------------------------------------- + computediff( "y", m, y.data(), y_ref.data(), incy, thresh, true ); +} + + +// Test-case logger : Used to print the test-case details +template +class axpyfGenericPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char conja = std::get<0>(str.param); + char conjx = std::get<1>(str.param); + gtint_t m = std::get<2>(str.param); + gtint_t b = std::get<3>(str.param); + T alpha = std::get<4>(str.param); + gtint_t incx = std::get<7>(str.param); + gtint_t incy = std::get<8>(str.param); + + std::string str_name = "bli_"; + str_name += "_conja_" + std::string(&conja, 1); + str_name += "_conjx_" + std::string(&conjx, 1); + str_name += "_m_" + std::to_string(m); + str_name += "_b_" + std::to_string(b); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + str_name += "_incx_" + testinghelpers::get_value_string(incx); + str_name += "_incy_" + testinghelpers::get_value_string(incy); + return str_name; + } +}; diff --git a/gtestsuite/testsuite/level1/axpyv/IIT_ERS/axpyv_IIT_ERS.cpp b/gtestsuite/testsuite/level1/axpyv/IIT_ERS/axpyv_IIT_ERS.cpp new file mode 100644 index 0000000000..4a983d8016 --- /dev/null +++ b/gtestsuite/testsuite/level1/axpyv/IIT_ERS/axpyv_IIT_ERS.cpp @@ -0,0 +1,204 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "common/testing_helpers.h" +#include "level1/axpyv/axpyv.h" +#include "inc/check_error.h" +#include "common/wrong_inputs_helpers.h" + +template +class axpyv_IIT_ERS : public ::testing::Test {}; +typedef ::testing::Types TypeParam; // The supported datatypes from BLAS/CBLAS calls for AXPY +TYPED_TEST_SUITE(axpyv_IIT_ERS, TypeParam); // Defining individual testsuites based on the datatype support. + +// Adding namespace to get default parameters(valid case) from testinghelpers/common/wrong_input_helpers.h. +using namespace testinghelpers::IIT; + +#if defined(TEST_BLAS_LIKE) || defined(TEST_CBLAS) +/* + Early Return Scenarios(ERS) for BLAS/CBLAS compliance : + + The AXPY API is expected to return early in the following cases: + 1. When n <= 0 (BLAS compliance). + 2. When alpha = 0 (BLAS compliance). +*/ + +// Early return cases with non-unit strides on vectors +// When n < 0 +TYPED_TEST(axpyv_IIT_ERS, n_lt_zero_nonUnitStrides) +{ + using T = TypeParam; + T alpha; + testinghelpers::initone( alpha ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + axpyv( CONJ, -1, alpha, nullptr, 5, nullptr, 5 ); + + // Test with all arguments correct except for the value we are choosing to test. + // Defining the x vector + std::vector x = testinghelpers::get_random_vector( -10, 10, N, 5 ); + // Defining the y vector with values for debugging purposes + std::vector y = testinghelpers::get_random_vector( -10, 10, N, 5 ); + + // Copy so that we check that the elements of y are not modified. + std::vector y_ref(y); + + axpyv( CONJ, -1, alpha, x.data(), 5, y.data(), 5 ); + // Use bitwise comparison (no threshold). + computediff( "y", N, y.data(), y_ref.data(), 5 ); +} + +// When n = 0 +TYPED_TEST(axpyv_IIT_ERS, n_eq_zero_nonUnitStrides) +{ + using T = TypeParam; + T alpha; + testinghelpers::initone( alpha ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + axpyv( CONJ, 0, alpha, nullptr, 5, nullptr, 5 ); + + // Test with all arguments correct except for the value we are choosing to test. + // Defining the x vector + std::vector x = testinghelpers::get_random_vector( -10, 10, N, 5 ); + // Defining the y vector with values for debugging purposes + std::vector y = testinghelpers::get_random_vector( -10, 10, N, 5 ); + + // Copy so that we check that the elements of y are not modified. + std::vector y_ref(y); + + axpyv( CONJ, 0, alpha, x.data(), 5, y.data(), 5 ); + // Use bitwise comparison (no threshold). + computediff( "y", N, y.data(), y_ref.data(), 5 ); +} + +// When alpha = 0 +TYPED_TEST(axpyv_IIT_ERS, alpha_eq_zero_nonUnitStrides) +{ + using T = TypeParam; + T alpha; + testinghelpers::initzero( alpha ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + axpyv( CONJ, N, alpha, nullptr, 5, nullptr, 5 ); + + // Test with all arguments correct except for the value we are choosing to test. + // Defining the x vector + std::vector x = testinghelpers::get_random_vector( -10, 10, N, 5 ); + // Defining the y vector with values for debugging purposes + std::vector y = testinghelpers::get_random_vector( -10, 10, N, 5 ); + + // Copy so that we check that the elements of y are not modified. + std::vector y_ref(y); + + axpyv( CONJ, N, alpha, x.data(), 5, y.data(), 5 ); + // Use bitwise comparison (no threshold). + computediff( "y", N, y.data(), y_ref.data(), 5 ); +} + +// Early return cases with unit strides on vectors +// When n < 0 +TYPED_TEST(axpyv_IIT_ERS, n_lt_zero_unitStrides) +{ + using T = TypeParam; + T alpha; + testinghelpers::initone( alpha ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + axpyv( CONJ, -1, alpha, nullptr, 1, nullptr, 1 ); + + // Test with all arguments correct except for the value we are choosing to test. + // Defining the x vector + std::vector x = testinghelpers::get_random_vector( -10, 10, N, 1 ); + // Defining the y vector with values for debugging purposes + std::vector y = testinghelpers::get_random_vector( -10, 10, N, 1 ); + + // Copy so that we check that the elements of y are not modified. + std::vector y_ref(y); + + axpyv( CONJ, -1, alpha, x.data(), 1, y.data(), 1 ); + // Use bitwise comparison (no threshold). + computediff( "y", N, y.data(), y_ref.data(), 1 ); +} + +// When n = 0 +TYPED_TEST(axpyv_IIT_ERS, n_eq_zero_unitStrides) +{ + using T = TypeParam; + T alpha; + testinghelpers::initone( alpha ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + axpyv( CONJ, 0, alpha, nullptr, 1, nullptr, 1 ); + + // Test with all arguments correct except for the value we are choosing to test. + // Defining the x vector + std::vector x = testinghelpers::get_random_vector( -10, 10, N, 1 ); + // Defining the y vector with values for debugging purposes + std::vector y = testinghelpers::get_random_vector( -10, 10, N, 1 ); + + // Copy so that we check that the elements of y are not modified. + std::vector y_ref(y); + + axpyv( CONJ, 0, alpha, x.data(), 1, y.data(), 1 ); + // Use bitwise comparison (no threshold). + computediff( "y", N, y.data(), y_ref.data(), 1 ); +} + +// When alpha = 0 +TYPED_TEST(axpyv_IIT_ERS, alpha_eq_zero_unitStrides) +{ + using T = TypeParam; + T alpha; + testinghelpers::initzero( alpha ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + axpyv( CONJ, N, alpha, nullptr, 1, nullptr, 1 ); + + // Test with all arguments correct except for the value we are choosing to test. + // Defining the x vector + std::vector x = testinghelpers::get_random_vector( -10, 10, N, 1 ); + // Defining the y vector with values for debugging purposes + std::vector y = testinghelpers::get_random_vector( -10, 10, N, 1 ); + + // Copy so that we check that the elements of y are not modified. + std::vector y_ref(y); + + axpyv( CONJ, N, alpha, x.data(), 1, y.data(), 1 ); + // Use bitwise comparison (no threshold). + computediff( "y", N, y.data(), y_ref.data(), 1 ); +} +#endif + diff --git a/gtestsuite/testsuite/level1/axpyv/axpyv.h b/gtestsuite/testsuite/level1/axpyv/axpyv.h index 10e56cae15..c4c1355369 100644 --- a/gtestsuite/testsuite/level1/axpyv/axpyv.h +++ b/gtestsuite/testsuite/level1/axpyv/axpyv.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -36,6 +36,7 @@ #include "blis.h" #include "common/testing_helpers.h" +#include "inc/check_error.h" /** * @brief Performs the operation: @@ -65,6 +66,21 @@ static void axpyv_(gtint_t n, T alpha, T* x, gtint_t incx, T* y, gtint_t incy) throw std::runtime_error("Error in testsuite/level1/axpyv.h: Invalid typename in axpyv_()."); } +template +static void axpyv_blis_impl(gtint_t n, T alpha, T* x, gtint_t incx, T* y, gtint_t incy) +{ + if constexpr (std::is_same::value) + saxpy_blis_impl( &n, &alpha, x, &incx, y, &incy ); + else if constexpr (std::is_same::value) + daxpy_blis_impl( &n, &alpha, x, &incx, y, &incy ); + else if constexpr (std::is_same::value) + caxpy_blis_impl( &n, &alpha, x, &incx, y, &incy ); + else if constexpr (std::is_same::value) + zaxpy_blis_impl( &n, &alpha, x, &incx, y, &incy ); + else + throw std::runtime_error("Error in testsuite/level1/axpyv.h: Invalid typename in axpyv_blis_impl()."); +} + template static void cblas_axpyv(gtint_t n, T alpha, T* x, gtint_t incx, T* y, gtint_t incy) { @@ -101,8 +117,33 @@ static void typed_axpyv(char conj_x, gtint_t n, T alpha, T* x, gtint_t incx, T* template static void axpyv(char conj_x, gtint_t n, T alpha, T* x, gtint_t incx, T* y, gtint_t incy) { + +#ifdef TEST_UPPERCASE_ARGS + conj_x = static_cast(std::toupper(static_cast(conj_x))); +#endif + +#ifdef TEST_INPUT_ARGS + // Create copy of scalar input values so we can check that they are not altered. + char conj_x_cpy = conj_x; + gtint_t n_cpy = n; + T alpha_cpy = alpha; + gtint_t incx_cpy = incx; + gtint_t incy_cpy = incy; + + // Create copy of input arrays so we can check that they are not altered. + T* x_cpy = nullptr; + gtint_t size_x = testinghelpers::buff_dim( n, incx ); + if (x && size_x > 0) + { + x_cpy = new T[size_x]; + memcpy( x_cpy, x, size_x * sizeof( T ) ); + } +#endif + #ifdef TEST_BLAS axpyv_( n, alpha, x, incx, y, incy ); +#elif TEST_BLAS_BLIS_IMPL + axpyv_blis_impl( n, alpha, x, incx, y, incy ); #elif TEST_CBLAS cblas_axpyv( n, alpha, x, incx, y, incy ); #elif TEST_BLIS_TYPED @@ -110,4 +151,26 @@ static void axpyv(char conj_x, gtint_t n, T alpha, T* x, gtint_t incx, T* y, gti #else throw std::runtime_error("Error in testsuite/level1/axpyv.h: No interfaces are set to be tested."); #endif + +#ifdef TEST_INPUT_ARGS + //---------------------------------------------------------- + // Check scalar inputs have not been modified. + //---------------------------------------------------------- + + computediff( "conj_x", conj_x, conj_x_cpy ); + computediff( "n", n, n_cpy ); + computediff( "alpha", alpha, alpha_cpy ); + computediff( "incx", incx, incx_cpy ); + computediff( "incy", incy, incy_cpy ); + + //---------------------------------------------------------- + // Bitwise-wise check array inputs have not been modified. + //---------------------------------------------------------- + + if (x && size_x > 0) + { + computediff( "x", n, x, x_cpy, incx, true ); + delete[] x_cpy; + } +#endif } diff --git a/gtestsuite/testsuite/level1/axpyv/caxpyv_generic.cpp b/gtestsuite/testsuite/level1/axpyv/caxpyv/caxpyv_generic.cpp similarity index 66% rename from gtestsuite/testsuite/level1/axpyv/caxpyv_generic.cpp rename to gtestsuite/testsuite/level1/axpyv/caxpyv/caxpyv_generic.cpp index ad4db3c95b..cad418a3ee 100644 --- a/gtestsuite/testsuite/level1/axpyv/caxpyv_generic.cpp +++ b/gtestsuite/testsuite/level1/axpyv/caxpyv/caxpyv_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -33,16 +33,16 @@ */ #include -#include "test_axpyv.h" +#include "level1/axpyv/test_axpyv.h" -class caxpyvGenericTest : +class caxpyvGeneric : public ::testing::TestWithParam> {}; // Tests using random integers as vector elements. -TEST_P( caxpyvGenericTest, RandomData ) +TEST_P( caxpyvGeneric, API ) { using T = scomplex; //---------------------------------------------------------- @@ -61,7 +61,19 @@ TEST_P( caxpyvGenericTest, RandomData ) T alpha = std::get<4>(GetParam()); // Set the threshold for the errors: - double thresh = 2*testinghelpers::getEpsilon(); + // Check gtestsuite axpyv.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + thresh = 0.0; + else if (alpha == testinghelpers::ONE()) + thresh = testinghelpers::getEpsilon(); + else + thresh = 2*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call generic test body using those parameters @@ -69,43 +81,10 @@ TEST_P( caxpyvGenericTest, RandomData ) test_axpyv( conj_x, n, incx, incy, alpha, thresh ); } -// Used to generate a test case with a sensible name. -// Beware that we cannot use fp numbers (e.g., 2.3) in the names, -// so we are only printing int(2.3). This should be enough for debugging purposes. -// If this poses an issue, please reach out. -class caxpyvGenericTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char conj = std::get<0>(str.param); - gtint_t n = std::get<1>(str.param); - gtint_t incx = std::get<2>(str.param); - gtint_t incy = std::get<3>(str.param); - scomplex alpha = std::get<4>(str.param); -#ifdef TEST_BLAS - std::string str_name = "caxpy_"; -#elif TEST_CBLAS - std::string str_name = "cblas_caxpy"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_caxpyv"; -#endif - str_name += "_" + std::to_string(n); - str_name += "_" + std::string(&conj, 1); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name += "_" + incx_str; - std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); - str_name += "_" + incy_str; - std::string alpha_str = ( alpha.real > 0) ? std::to_string(int(alpha.real)) : ("m" + std::to_string(int(std::abs(alpha.real)))); - alpha_str = alpha_str + "pi" + (( alpha.imag > 0) ? std::to_string(int(alpha.imag)) : ("m" + std::to_string(int(std::abs(alpha.imag))))); - str_name = str_name + "_a" + alpha_str; - return str_name; - } -}; - // Black box testing for generic and main use of caxpy. INSTANTIATE_TEST_SUITE_P( Blackbox, - caxpyvGenericTest, + caxpyvGeneric, ::testing::Combine( ::testing::Values('n' #ifdef TEST_BLIS_TYPED @@ -117,7 +96,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(1)), // stride size for y ::testing::Values(scomplex{2.0, -1.0}, scomplex{-2.0, 3.0}) // alpha ), - ::caxpyvGenericTestPrint() + ::axpyvGenericPrint() ); // Test for non-unit increments. @@ -125,7 +104,7 @@ INSTANTIATE_TEST_SUITE_P( // We can modify the values using implementantion details. INSTANTIATE_TEST_SUITE_P( NonUnitPositiveIncrements, - caxpyvGenericTest, + caxpyvGeneric, ::testing::Combine( ::testing::Values('n' #ifdef TEST_BLIS_TYPED @@ -137,7 +116,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(3)), // stride size for y ::testing::Values(scomplex{4.0, 3.1}) // alpha ), - ::caxpyvGenericTestPrint() + ::axpyvGenericPrint() ); #ifndef TEST_BLIS_TYPED @@ -146,7 +125,7 @@ INSTANTIATE_TEST_SUITE_P( // We can modify the values using implementantion details. INSTANTIATE_TEST_SUITE_P( NegativeIncrements, - caxpyvGenericTest, + caxpyvGeneric, ::testing::Combine( ::testing::Values('n'), // n: use x, c: use conj(x) ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. @@ -154,6 +133,6 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(-3)), // stride size for y ::testing::Values(scomplex{4.0, 3.1}) // alpha ), - ::caxpyvGenericTestPrint() + ::axpyvGenericPrint() ); #endif diff --git a/gtestsuite/testsuite/level1/axpyv/daxpyv/daxpyv_evt.cpp b/gtestsuite/testsuite/level1/axpyv/daxpyv/daxpyv_evt.cpp new file mode 100644 index 0000000000..ad284bd29a --- /dev/null +++ b/gtestsuite/testsuite/level1/axpyv/daxpyv/daxpyv_evt.cpp @@ -0,0 +1,413 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "level1/axpyv/test_axpyv.h" + +class daxpyvEVT : + public ::testing::TestWithParam> {}; // alpha +// Tests using random values as vector elements, +// with exception values on the passed indices. +TEST_P( daxpyvEVT, API ) +{ + using T = double; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes whether x or conj(x) will be added to y: + char conj_x = std::get<0>(GetParam()); + // vector length: + gtint_t n = std::get<1>(GetParam()); + // stride size for x: + gtint_t incx = std::get<2>(GetParam()); + // stride size for y: + gtint_t incy = std::get<3>(GetParam()); + // index for exval in x + gtint_t xi = std::get<4>(GetParam()); + // exval for x + T xexval = std::get<5>(GetParam()); + // index for exval in y + gtint_t yj = std::get<6>(GetParam()); + // exval for x + T yexval = std::get<7>(GetParam()); + // alpha + T alpha = std::get<8>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite axpyv.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + thresh = 0.0; + else if (alpha == testinghelpers::ONE()) + thresh = testinghelpers::getEpsilon(); + else + thresh = 2*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_axpyv(conj_x, n, incx, incy, alpha, xi, xexval, + yj, yexval, thresh); +} + +static double NaN = std::numeric_limits::quiet_NaN(); +static double Inf = std::numeric_limits::infinity(); + +/* + Exception value testing on vectors(Zen3) : + DAXPBY currently uses the bli_daxpyv_zen_int10( ... ) kernel for computation on zen3 + machines. + The sizes and indices given in the instantiator are to ensure code coverage inside + the kernel, and to verify the compliance accordingly. + + Kernel structure for bli_daxpyv_zen_int10( ... ) : + Main loop : In blocks of 52 --> L52 + Fringe loops : In blocks of 40 --> L40 + In blocks of 20 --> L20 + In blocks of 16 --> L16 + In blocks of 8 --> L8 + In blocks of 4 --> L4 + Element-wise loop --> LScalar + + For size 535 : L52*10 + L8 + L4 + 3(LScalar) + Indices are : 0, 519 -> In L52 + 527 -> In L8 + 531 -> In L4 + 534 -> In LScalar + + + For size 556 : L52*10 + L20 + L16 + Indices are : 0, 519 -> In L52 + 539 -> In L20 + 555 -> In L16 + + + For size 560 : L52*10 + L40 + Indices are : 0, 519 -> In L52 + 559 -> In L40 + + The alpha values are such that they check for compliance against possible + optimizations that might have been done. + + P.S : Some test cases also check whether NaN has to be induced in the computation + as a result of 0.0 * { NaN, +Inf, -Inf }. +*/ + +// Exception value testing(on X vector alone) with unit strides +INSTANTIATE_TEST_SUITE_P( + vecX_unitStrides_zen3, + daxpyvEVT, + ::testing::Combine( + ::testing::Values('n' // n: use x, c: use conj(x) +#ifdef TEST_BLIS_TYPED + , + 'c' // this option is BLIS-api specific. +#endif + ), + ::testing::Values(gtint_t(535), gtint_t(556), gtint_t(560)), // n, size of vectors with unit-stride + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(gtint_t(0), gtint_t(519), gtint_t(527), + gtint_t(531), gtint_t(534), gtint_t(539), + gtint_t(555), gtint_t(559)), // indices to set exception values on x + ::testing::Values(NaN, -Inf, Inf), // exception values to set on x + ::testing::Values(gtint_t(0)), // dummy index on y + ::testing::Values(double(0.0)), // dummy value on y + ::testing::Values(double(0.0), double(1.0), double(-1.0), double(-3.3)) // alpha + ), + ::axpyvEVTPrint()); + +// Exception value testing(on Y vector alone) with unit strides +INSTANTIATE_TEST_SUITE_P( + vecY_unitStrides_zen3, + daxpyvEVT, + ::testing::Combine( + ::testing::Values('n' // n: use x, c: use conj(x) +#ifdef TEST_BLIS_TYPED + , + 'c' // this option is BLIS-api specific. +#endif + ), + ::testing::Values(gtint_t(535), gtint_t(556), gtint_t(560)), // n, size of vectors with unit-stride + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(gtint_t(0)), // dummy index on x + ::testing::Values(double(0.0)), // dummy value on x + ::testing::Values(gtint_t(0), gtint_t(519), gtint_t(527), + gtint_t(531), gtint_t(534), gtint_t(539), + gtint_t(555), gtint_t(559)), // indices to set exception values on y + ::testing::Values(NaN, -Inf, Inf), // exception values to set on y + ::testing::Values(double(0.0), double(1.0), double(-1.0), double(-3.3)) // alpha + ), + ::axpyvEVTPrint()); + +// Exception value testing(on X and Y vectors) with unit strides +INSTANTIATE_TEST_SUITE_P( + vecXY_unitStrides_zen3, + daxpyvEVT, + ::testing::Combine( + ::testing::Values('n' // n: use x, c: use conj(x) +#ifdef TEST_BLIS_TYPED + , + 'c' // this option is BLIS-api specific. +#endif + ), + ::testing::Values(gtint_t(535), gtint_t(556), gtint_t(560)), // n, size of vectors with unit-stride + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(gtint_t(0), gtint_t(519), gtint_t(527), + gtint_t(531), gtint_t(534), gtint_t(539), + gtint_t(555), gtint_t(559)), // indices to set exception values on x + ::testing::Values(NaN, -Inf, Inf), // exception values to set on x + ::testing::Values(gtint_t(0), gtint_t(519), gtint_t(527), + gtint_t(531), gtint_t(534), gtint_t(539), + gtint_t(555), gtint_t(559)), // indices to set exception values on y + ::testing::Values(NaN, -Inf, Inf), // exception values to set on y + ::testing::Values(double(0.0), double(1.0), double(-1.0), double(-3.3)) // alpha + ), + ::axpyvEVTPrint()); + +/* + Exception value testing on vectors(Zen4) : + DAXPY currently uses the bli_daxpyv_zen_int_avx512( ... ) kernel for computation on zen4 + machines. + The sizes and indices given in the instantiator are to ensure code coverage inside + the kernel, and to verify the compliance accordingly. + + Kernel structure for bli_daxpyv_zen_int_avx512( ... ) : + Main loop : In blocks of 64 --> L52 + Fringe loops : In blocks of 32 --> L40 + In blocks of 16 --> L16 + In blocks of 8 --> L8 + In blocks of 4 --> L4 + Element-wise loop --> LScalar + + For size 383 : L64*5 + L32 + L16 + L8 + L4 + 3(LScalar) + Indices are : 0, 319 -> In L64 + 351 -> In L32 + 367 -> In L16 + 375 -> In L8 + 379 -> In L4 + 382 -> In LScalar + + The alpha values are such that they check for compliance against possible + optimizations that might have been done. + + P.S : Some test cases also check whether NaN has to be induced in the computation + as a result of 0.0 * { NaN, +Inf, -Inf }. +*/ +// Exception value testing(on X vector alone) with unit strides +INSTANTIATE_TEST_SUITE_P( + vecX_unitStrides_zen4, + daxpyvEVT, + ::testing::Combine( + ::testing::Values('n' // n: use x, c: use conj(x) +#ifdef TEST_BLIS_TYPED + , + 'c' // this option is BLIS-api specific. +#endif + ), + ::testing::Values(gtint_t(383)), // n, size of vectors with unit-stride + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(gtint_t(0), gtint_t(319), gtint_t(351), + gtint_t(367), gtint_t(375), gtint_t(379), + gtint_t(382)), // indices to set exception values on x + ::testing::Values(NaN, -Inf, Inf), // exception values to set on x + ::testing::Values(gtint_t(0)), // dummy index on y + ::testing::Values(double(0.0)), // dummy value on y + ::testing::Values(double(0.0), double(1.0), double(-1.0), double(-3.3)) // alpha + ), + ::axpyvEVTPrint()); + +// Exception value testing(on Y vector alone) with unit strides +INSTANTIATE_TEST_SUITE_P( + vecY_unitStrides_zen4, + daxpyvEVT, + ::testing::Combine( + ::testing::Values('n' // n: use x, c: use conj(x) +#ifdef TEST_BLIS_TYPED + , + 'c' // this option is BLIS-api specific. +#endif + ), + ::testing::Values(gtint_t(383)), // n, size of vectors with unit-stride + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(gtint_t(0)), // dummy index on x + ::testing::Values(double(0.0)), // dummy value on x + ::testing::Values(gtint_t(0), gtint_t(319), gtint_t(351), + gtint_t(367), gtint_t(375), gtint_t(379), + gtint_t(382)), // indices to set exception values on y + ::testing::Values(NaN, -Inf, Inf), // exception values to set on y + ::testing::Values(double(0.0), double(1.0), double(-1.0), double(-3.3)) // alpha + ), + ::axpyvEVTPrint()); + +// Exception value testing(on X and Y vectors) with unit strides +INSTANTIATE_TEST_SUITE_P( + vecXY_unitStrides_zen4, + daxpyvEVT, + ::testing::Combine( + ::testing::Values('n' // n: use x, c: use conj(x) +#ifdef TEST_BLIS_TYPED + , + 'c' // this option is BLIS-api specific. +#endif + ), + ::testing::Values(gtint_t(383)), // n, size of vectors with unit-stride + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(gtint_t(0), gtint_t(319), gtint_t(351), + gtint_t(367), gtint_t(375), gtint_t(379), + gtint_t(382)), // indices to set exception values on x + ::testing::Values(NaN, -Inf, Inf), // exception values to set on x + ::testing::Values(gtint_t(0), gtint_t(319), gtint_t(351), + gtint_t(367), gtint_t(375), gtint_t(379), + gtint_t(382)), // indices to set exception values on y + ::testing::Values(NaN, -Inf, Inf), // exception values to set on y + ::testing::Values(double(0.0), double(1.0), double(-1.0), double(-3.3)) // alpha + ), + ::axpyvEVTPrint()); + +// Exception value testing(on vectors) with non-unit strides +// We have to test a single scalar loop. The indices are such +// that we cover _vecX_, _vecY_ and _vecXY_ cases together. +INSTANTIATE_TEST_SUITE_P( + vecXY_nonUnitStrides, + daxpyvEVT, + ::testing::Combine( + ::testing::Values('n' // n: use x, c: use conj(x) +#ifdef TEST_BLIS_TYPED + , + 'c' // this option is BLIS-api specific. +#endif + ), + ::testing::Values(gtint_t(50)), // n, size of vectors with non-unit strides + ::testing::Values(gtint_t(3)), // stride size for x + ::testing::Values(gtint_t(5)), // stride size for y + ::testing::Values(gtint_t(1), gtint_t(27), gtint_t(49)), // indices to set exception values on x + ::testing::Values(NaN, -Inf, Inf, 2.9), // exception values to set on x + ::testing::Values(gtint_t(0), gtint_t(26), gtint_t(49)), // indices to set exception values on y + ::testing::Values(NaN, -Inf, Inf, -1.5), // exception values to set on y + ::testing::Values(double(0.0), double(1.0), double(-1.0), double(-3.3)) // alpha + ), + ::axpyvEVTPrint()); + +/* + Exception value testing on alpha : + Alpha values are set to Nan, +Inf or -Inf. A dummy + value of 0.0 is induced in X and Y vectors, to further + verify the propagation. + + The size(s) for _zen3 and _zen4 instantiators are chosen such + that code coverage is ensured in the respective kernels. +*/ +INSTANTIATE_TEST_SUITE_P( + alpha_unitStrides_zen3, + daxpyvEVT, + ::testing::Combine( + ::testing::Values('n' // n: use x, c: use conj(x) +#ifdef TEST_BLIS_TYPED + , + 'c' // this option is BLIS-api specific. +#endif + ), + ::testing::Values(gtint_t(535), gtint_t(556), gtint_t(560)), // n, size of vectors with unit strides + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(gtint_t(0)), // indices to set zero on x + ::testing::Values(double(0.0)), + ::testing::Values(gtint_t(0)), // indices to set zero on y + ::testing::Values(double(0.0)), + ::testing::Values(NaN, -Inf, Inf) // alpha + ), + ::axpyvEVTPrint()); + +// Exception value testing(on alpha) with unit strided vectors +INSTANTIATE_TEST_SUITE_P( + alpha_unitStrides_zen4, + daxpyvEVT, + ::testing::Combine( + ::testing::Values('n' // n: use x, c: use conj(x) +#ifdef TEST_BLIS_TYPED + , + 'c' // this option is BLIS-api specific. +#endif + ), + ::testing::Values(gtint_t(383)), // n, size of vectors with unit strides + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(gtint_t(0)), // indices to set zero on x + ::testing::Values(double(0.0)), + ::testing::Values(gtint_t(0)), // indices to set zero on y + ::testing::Values(double(0.0)), + ::testing::Values(NaN, -Inf, Inf) // alpha + ), + ::axpyvEVTPrint()); + +// Exception value testing(on alpha) with non-unit strided vectors +INSTANTIATE_TEST_SUITE_P( + alpha_nonUnitStrides, + daxpyvEVT, + ::testing::Combine( + ::testing::Values('n' // n: use x, c: use conj(x) +#ifdef TEST_BLIS_TYPED + , + 'c' // this option is BLIS-api specific. +#endif + ), + ::testing::Values(gtint_t(50)), // n, size of vectors with non-unit strides + ::testing::Values(gtint_t(3)), // stride size for x + ::testing::Values(gtint_t(5)), // stride size for y + ::testing::Values(gtint_t(0), gtint_t(25)), // indices to set zero on x + ::testing::Values(double(0.0)), + ::testing::Values(gtint_t(0), gtint_t(40)), // indices to set zero on y + ::testing::Values(double(0.0)), + ::testing::Values(NaN, -Inf, Inf) // alpha + ), + ::axpyvEVTPrint()); diff --git a/gtestsuite/testsuite/level1/axpyv/saxpyv_generic.cpp b/gtestsuite/testsuite/level1/axpyv/daxpyv/daxpyv_generic.cpp similarity index 51% rename from gtestsuite/testsuite/level1/axpyv/saxpyv_generic.cpp rename to gtestsuite/testsuite/level1/axpyv/daxpyv/daxpyv_generic.cpp index 10c1daefa2..fbfc816afc 100644 --- a/gtestsuite/testsuite/level1/axpyv/saxpyv_generic.cpp +++ b/gtestsuite/testsuite/level1/axpyv/daxpyv/daxpyv_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -33,18 +33,18 @@ */ #include -#include "test_axpyv.h" +#include "level1/axpyv/test_axpyv.h" -class saxpyvGenericTest : +class daxpyvGeneric : public ::testing::TestWithParam> {}; + double>> {}; // Tests using random integers as vector elements. -TEST_P( saxpyvGenericTest, RandomData ) +TEST_P( daxpyvGeneric, API ) { - using T = float; + using T = double; //---------------------------------------------------------- // Initialize values from the parameters passed through // test suite instantiation (INSTANTIATE_TEST_SUITE_P). @@ -61,7 +61,18 @@ TEST_P( saxpyvGenericTest, RandomData ) T alpha = std::get<4>(GetParam()); // Set the threshold for the errors: - double thresh = testinghelpers::getEpsilon(); + // Check gtestsuite axpyv.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + thresh = 0.0; + else if (alpha == testinghelpers::ONE()) + thresh = testinghelpers::getEpsilon(); + else + thresh = 2*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call generic test body using those parameters @@ -69,50 +80,19 @@ TEST_P( saxpyvGenericTest, RandomData ) test_axpyv( conj_x, n, incx, incy, alpha, thresh ); } -// Used to generate a test case with a sensible name. -// Beware that we cannot use fp numbers (e.g., 2.3) in the names, -// so we are only printing int(2.3). This should be enough for debugging purposes. -// If this poses an issue, please reach out. -class saxpyvGenericTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char conj = std::get<0>(str.param); - gtint_t n = std::get<1>(str.param); - gtint_t incx = std::get<2>(str.param); - gtint_t incy = std::get<3>(str.param); - float alpha = std::get<4>(str.param); -#ifdef TEST_BLAS - std::string str_name = "saxpy_"; -#elif TEST_CBLAS - std::string str_name = "cblas_saxpy"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_saxpyv"; -#endif - str_name += "_" + std::to_string(n); - str_name += "_" + std::string(&conj, 1); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name += "_" + incx_str; - std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); - str_name += "_" + incy_str; - std::string alpha_str = ( alpha > 0) ? std::to_string(int(alpha)) : "m" + std::to_string(int(std::abs(alpha))); - str_name = str_name + "_a" + alpha_str; - return str_name; - } -}; - -// Black box testing for generic and main use of saxpy. +// Black box testing for generic and main use of daxpy. INSTANTIATE_TEST_SUITE_P( Blackbox, - saxpyvGenericTest, + daxpyvGeneric, ::testing::Combine( ::testing::Values('n'), // n: use x, not conj(x) (since it is real) ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. ::testing::Values(gtint_t(1)), // stride size for x ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(float(2.0), float(-2.0)) // alpha + ::testing::Values(double(0.0), double(1.0), + double(-1.0), double(4.1)) // alpha ), - ::saxpyvGenericTestPrint() + ::axpyvGenericPrint() ); #ifdef TEST_BLIS_TYPED @@ -121,15 +101,16 @@ INSTANTIATE_TEST_SUITE_P( // We can modify the values using implementantion details. INSTANTIATE_TEST_SUITE_P( ConjX, - saxpyvGenericTest, + daxpyvGeneric, ::testing::Combine( ::testing::Values('c'), // c: use conj(x) ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector ::testing::Values(gtint_t(1)), // stride size for x ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(float(2.0)) // alpha + ::testing::Values(double(0.0), double(1.0), + double(-1.0), double(4.1)) // alpha ), - ::saxpyvGenericTestPrint() + ::axpyvGenericPrint() ); #endif @@ -137,16 +118,17 @@ INSTANTIATE_TEST_SUITE_P( // Only test very few cases as sanity check. // We can modify the values using implementantion details. INSTANTIATE_TEST_SUITE_P( - NonUnitPositiveIncrements, - saxpyvGenericTest, + nonUnitPositiveIncrements, + daxpyvGeneric, ::testing::Combine( ::testing::Values('n'), // n: use x, not conj(x) (since it is real) ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector ::testing::Values(gtint_t(2)), // stride size for x ::testing::Values(gtint_t(3)), // stride size for y - ::testing::Values(float(4.0)) // alpha + ::testing::Values(double(0.0), double(1.0), + double(-1.0), double(4.1)) // alpha ), - ::saxpyvGenericTestPrint() + ::axpyvGenericPrint() ); #ifndef TEST_BLIS_TYPED @@ -154,15 +136,64 @@ INSTANTIATE_TEST_SUITE_P( // Only test very few cases as sanity check. // We can modify the values using implementantion details. INSTANTIATE_TEST_SUITE_P( - NegativeIncrements, - saxpyvGenericTest, + negativeIncrements, + daxpyvGeneric, ::testing::Combine( ::testing::Values('n'), // n: use x, c: use conj(x) ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. ::testing::Values(gtint_t(-4)), // stride size for x ::testing::Values(gtint_t(-3)), // stride size for y - ::testing::Values(4.0) // alpha + ::testing::Values(double(0.0), double(1.0), + double(-1.0), double(4.1)) // alpha + ), + ::axpyvGenericPrint() + ); +#endif + +// The following instantiator is enabled only if BLIS has been configured for openmp +// with aocl-dynamic enabled. +#if defined(BLIS_ENABLE_OPENMP) && defined(AOCL_DYNAMIC) +// Checking for the thresholds with unit strides +INSTANTIATE_TEST_SUITE_P( + aoclDynamicThresholds_unitStrides, + daxpyvGeneric, + ::testing::Combine( + ::testing::Values('n'), // n: use x, c: use conj(x) + ::testing::Values(// Sizes are based on the thresholds + gtint_t(4000), // nt_ideal = 1 + gtint_t(11000), // nt_ideal = 4 + gtint_t(300000), // nt_ideal = 8 + gtint_t(750000), // nt_ideal = 16 + gtint_t(2600000), // nt_ideal = 32 + gtint_t(4000000)), // nt_ideal = 64 + + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(double(0.0), double(1.0), + double(-1.0), double(4.1)) // alpha + ), + ::axpyvGenericPrint() + ); + +// Checking for the thresholds with non-unit strides +INSTANTIATE_TEST_SUITE_P( + aoclDynamicThresholds_nonUnitStrides, + daxpyvGeneric, + ::testing::Combine( + ::testing::Values('n'), // n: use x, c: use conj(x) + ::testing::Values(// Sizes are based on the thresholds + gtint_t(4000), // nt_ideal = 1 + gtint_t(11000), // nt_ideal = 4 + gtint_t(300000), // nt_ideal = 8 + gtint_t(750000), // nt_ideal = 16 + gtint_t(2600000), // nt_ideal = 32 + gtint_t(4000000)), // nt_ideal = 64 + + ::testing::Values(gtint_t(3)), // stride size for x + ::testing::Values(gtint_t(3)), // stride size for y + ::testing::Values(double(0.0), double(1.0), + double(-1.0), double(4.1)) // alpha ), - ::saxpyvGenericTestPrint() + ::axpyvGenericPrint() ); #endif diff --git a/gtestsuite/testsuite/level1/axpyv/daxpyv_generic.cpp b/gtestsuite/testsuite/level1/axpyv/daxpyv_generic.cpp deleted file mode 100644 index 19d65ed5a3..0000000000 --- a/gtestsuite/testsuite/level1/axpyv/daxpyv_generic.cpp +++ /dev/null @@ -1,168 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include "test_axpyv.h" - -class daxpyvGenericTest : - public ::testing::TestWithParam> {}; -// Tests using random integers as vector elements. -TEST_P( daxpyvGenericTest, RandomData ) -{ - using T = double; - //---------------------------------------------------------- - // Initialize values from the parameters passed through - // test suite instantiation (INSTANTIATE_TEST_SUITE_P). - //---------------------------------------------------------- - // denotes whether x or conj(x) will be added to y: - char conj_x = std::get<0>(GetParam()); - // vector length: - gtint_t n = std::get<1>(GetParam()); - // stride size for x: - gtint_t incx = std::get<2>(GetParam()); - // stride size for y: - gtint_t incy = std::get<3>(GetParam()); - // alpha - T alpha = std::get<4>(GetParam()); - - // Set the threshold for the errors: - double thresh = testinghelpers::getEpsilon(); - - //---------------------------------------------------------- - // Call generic test body using those parameters - //---------------------------------------------------------- - test_axpyv( conj_x, n, incx, incy, alpha, thresh ); -} - -// Used to generate a test case with a sensible name. -// Beware that we cannot use fp numbers (e.g., 2.3) in the names, -// so we are only printing int(2.3). This should be enough for debugging purposes. -// If this poses an issue, please reach out. -class daxpyvGenericTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char conj = std::get<0>(str.param); - gtint_t n = std::get<1>(str.param); - gtint_t incx = std::get<2>(str.param); - gtint_t incy = std::get<3>(str.param); - double alpha = std::get<4>(str.param); -#ifdef TEST_BLAS - std::string str_name = "daxpy_"; -#elif TEST_CBLAS - std::string str_name = "cblas_daxpy"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_daxpyv"; -#endif - str_name += "_" + std::to_string(n); - str_name += "_" + std::string(&conj, 1); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name += "_" + incx_str; - std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); - str_name += "_" + incy_str; - std::string alpha_str = ( alpha > 0) ? std::to_string(int(alpha)) : "m" + std::to_string(int(std::abs(alpha))); - str_name = str_name + "_a" + alpha_str; - return str_name; - } -}; - -// Black box testing for generic and main use of caxpy. -INSTANTIATE_TEST_SUITE_P( - Blackbox, - daxpyvGenericTest, - ::testing::Combine( - ::testing::Values('n'), // n: use x, not conj(x) (since it is real) - ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. - ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(double(2.0), double(-2.0)) // alpha - ), - ::daxpyvGenericTestPrint() - ); - -#ifdef TEST_BLIS_TYPED -// Test when conjugate of x is used as an argument. This option is BLIS-api specific. -// Only test very few cases as sanity check since conj(x) = x for real types. -// We can modify the values using implementantion details. -INSTANTIATE_TEST_SUITE_P( - ConjX, - daxpyvGenericTest, - ::testing::Combine( - ::testing::Values('c'), // c: use conj(x) - ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector - ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(double(2.0)) // alpha - ), - ::daxpyvGenericTestPrint() - ); -#endif - -// Test for non-unit increments. -// Only test very few cases as sanity check. -// We can modify the values using implementantion details. -INSTANTIATE_TEST_SUITE_P( - NonUnitPositiveIncrements, - daxpyvGenericTest, - ::testing::Combine( - ::testing::Values('n'), // n: use x, not conj(x) (since it is real) - ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector - ::testing::Values(gtint_t(2)), /*(gtint_t(-5), gtint_t(-17))*/// stride size for x - ::testing::Values(gtint_t(3)), /*(gtint_t(-12), gtint_t(-4))*/// stride size for y - ::testing::Values(double(4.0)) // beta - ), - ::daxpyvGenericTestPrint() - ); - -#ifndef TEST_BLIS_TYPED -// Test for negative increments. -// Only test very few cases as sanity check. -// We can modify the values using implementantion details. -INSTANTIATE_TEST_SUITE_P( - NegativeIncrements, - daxpyvGenericTest, - ::testing::Combine( - ::testing::Values('n'), // n: use x, c: use conj(x) - ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. - ::testing::Values(gtint_t(-4)), // stride size for x - ::testing::Values(gtint_t(-3)), // stride size for y - ::testing::Values(4.0) // alpha - ), - ::daxpyvGenericTestPrint() - ); -#endif diff --git a/gtestsuite/testsuite/level1/axpyv/saxpyv/saxpyv_evt.cpp b/gtestsuite/testsuite/level1/axpyv/saxpyv/saxpyv_evt.cpp new file mode 100644 index 0000000000..d385edcce0 --- /dev/null +++ b/gtestsuite/testsuite/level1/axpyv/saxpyv/saxpyv_evt.cpp @@ -0,0 +1,406 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "level1/axpyv/test_axpyv.h" + +class saxpyvEVT : + public ::testing::TestWithParam> {}; // alpha + +// Tests using random values as vector elements, +// with exception values on the passed indices. +TEST_P( saxpyvEVT, API ) +{ + using T = float; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes whether x or conj(x) will be added to y: + char conj_x = std::get<0>(GetParam()); + // vector length + gtint_t n = std::get<1>(GetParam()); + // stride size for x + gtint_t incx = std::get<2>(GetParam()); + // stride size for y + gtint_t incy = std::get<3>(GetParam()); + // index for exval in x + gtint_t xi = std::get<4>(GetParam()); + // exval for x + T xexval = std::get<5>(GetParam()); + // index for exval in y + gtint_t yj = std::get<6>(GetParam()); + // exval for x + T yexval = std::get<7>(GetParam()); + // alpha + T alpha = std::get<8>(GetParam()); + + // Set the threshold for the errors: + double thresh; + if (n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + thresh = 0.0; + else if (alpha == testinghelpers::ONE()) + thresh = testinghelpers::getEpsilon(); + else + thresh = 2*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_axpyv(conj_x, n, incx, incy, alpha, xi, xexval, + yj, yexval, thresh); +} + +static float NaN = std::numeric_limits::quiet_NaN(); +static float Inf = std::numeric_limits::infinity(); + +/* + Exception value testing on vectors(Zen3) : + SAXPY currently uses the bli_saxpyv_zen_int10( ... ) kernel for computation on zen3 + machines. + The sizes and indices given in the instantiator are to ensure code coverage inside + the kernel, and to verify the compliance accordingly. + + Kernel structure for bli_saxpyv_zen_int10( ... ) : + Main loop : In blocks of 120 --> L120 + Fringe loops : In blocks of 80 --> L80 + In blocks of 40 --> L40 + In blocks of 32 --> L32 + In blocks of 16 --> L16 + In blocks of 8 --> L8 + Element-wise loop --> LScalar + + For size 471 : L120*3 + L80 + L16 + 8 + 7(LScalar) + Indices are : 0, 359 -> In L120 + 439 -> In L80 + 455 -> In L16 + 463 -> In L8 + 470 -> In LScalar + + For size 432 : L120*3 + L40 + L32 + Indices are : 0, 359 -> In L52 + 399 -> In L40 + 431 -> In L32 + + The alpha values are such that they check for compliance against possible + optimizations that might have been done. + + P.S : Some test cases also check whether NaN has to be induced in the computation + as a result of 0.0 * { NaN, +Inf, -Inf }. +*/ + +// Exception value testing(on X vector alone) with unit strides +INSTANTIATE_TEST_SUITE_P( + vecX_unitStrides_zen3, + saxpyvEVT, + ::testing::Combine( + ::testing::Values('n' // n: use x, c: use conj(x) +#ifdef TEST_BLIS_TYPED + , + 'c' // this option is BLIS-api specific. +#endif + ), + ::testing::Values(gtint_t(432), gtint_t(471)), // n, size of vectors with unit-stride + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(gtint_t(0), gtint_t(359), gtint_t(399), + gtint_t(431), gtint_t(439), gtint_t(455), + gtint_t(463), gtint_t(470)), // indices to set exception values on x + ::testing::Values(NaN, -Inf, Inf), // exception values to set on x + ::testing::Values(gtint_t(0)), // dummy index on y + ::testing::Values(float(0.0)), // dummy value on y + ::testing::Values(float(0.0), float(1.0), float(-1.0), float(-3.3)) // alpha + ), + ::axpyvEVTPrint()); + +// Exception value testing(on Y vector alone) with unit strides +INSTANTIATE_TEST_SUITE_P( + vecY_unitStrides_zen3, + saxpyvEVT, + ::testing::Combine( + ::testing::Values('n' // n: use x, c: use conj(x) +#ifdef TEST_BLIS_TYPED + , + 'c' // this option is BLIS-api specific. +#endif + ), + ::testing::Values(gtint_t(432), gtint_t(471)), // n, size of vectors with unit-stride + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(gtint_t(0)), // dummy index on x + ::testing::Values(float(0.0)), // dummy value on x + ::testing::Values(gtint_t(0), gtint_t(359), gtint_t(399), + gtint_t(431), gtint_t(439), gtint_t(455), + gtint_t(463), gtint_t(470)), // indices to set exception values on y + ::testing::Values(NaN, -Inf, Inf), // exception values to set on y + ::testing::Values(float(0.0), float(1.0), float(-1.0), float(-3.3)) // alpha + ), + ::axpyvEVTPrint()); + +// Exception value testing(on X and Y vectors) with unit strides +INSTANTIATE_TEST_SUITE_P( + vecXY_unitStrides_zen3, + saxpyvEVT, + ::testing::Combine( + ::testing::Values('n' // n: use x, c: use conj(x) +#ifdef TEST_BLIS_TYPED + , + 'c' // this option is BLIS-api specific. +#endif + ), + ::testing::Values(gtint_t(432), gtint_t(471)), // n, size of vectors with unit-stride + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(gtint_t(0), gtint_t(359), gtint_t(399), + gtint_t(431), gtint_t(439), gtint_t(455), + gtint_t(463), gtint_t(470)), // indices to set exception values on x + ::testing::Values(NaN, -Inf, Inf), // exception values to set on x + ::testing::Values(gtint_t(0), gtint_t(359), gtint_t(399), + gtint_t(431), gtint_t(439), gtint_t(455), + gtint_t(463), gtint_t(470)), // indices to set exception values on y + ::testing::Values(NaN, -Inf, Inf), // exception values to set on y + ::testing::Values(float(0.0), float(1.0), float(-1.0), float(-3.3)) // alpha + ), + ::axpyvEVTPrint()); + +/* + Exception value testing on vectors(Zen4) : + SAXPY currently uses the bli_saxpyv_zen_int_avx512( ... ) kernel for computation on zen4 + machines. + The sizes and indices given in the instantiator are to ensure code coverage inside + the kernel, and to verify the compliance accordingly. + + Kernel structure for bli_saxpyv_zen_int_avx512( ... ) : + Main loop : In blocks of 128 --> L128 + Fringe loops : In blocks of 64 --> L64 + In blocks of 32 --> L32 + In blocks of 16 --> L16 + In blocks of 8 --> L8 + Element-wise loop --> LScalar + + For size 767 : L128*5 + L64 + L32 + + L16 + L8 + 7(LScalar) + Indices are : 0, 639 -> In L128 + 703 -> In L64 + 734 -> In L32 + 751 -> In L16 + 759 -> In L8 + 766 -> In LScalar + + The alpha values are such that they check for compliance against possible + optimizations that might have been done. + + P.S : Some test cases also check whether NaN has to be induced in the computation + as a result of 0.0 * { NaN, +Inf, -Inf }. +*/ +// Exception value testing(on X vector alone) with unit strides +INSTANTIATE_TEST_SUITE_P( + vecX_unitStrides_zen4, + saxpyvEVT, + ::testing::Combine( + ::testing::Values('n' // n: use x, c: use conj(x) +#ifdef TEST_BLIS_TYPED + , + 'c' // this option is BLIS-api specific. +#endif + ), + ::testing::Values(gtint_t(767)), // n, size of vectors with unit-stride + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(gtint_t(0), gtint_t(639), gtint_t(703), + gtint_t(734), gtint_t(751), gtint_t(759), + gtint_t(766)), // indices to set exception values on x + ::testing::Values(NaN, -Inf, Inf), // exception values to set on x + ::testing::Values(gtint_t(0)), // dummy index on y + ::testing::Values(float(0.0)), // dummy value on y + ::testing::Values(float(0.0), float(1.0), float(-1.0), float(-3.3)) // alpha + ), + ::axpyvEVTPrint()); + +// Exception value testing(on Y vector alone) with unit strides +INSTANTIATE_TEST_SUITE_P( + vecY_unitStrides_zen4, + saxpyvEVT, + ::testing::Combine( + ::testing::Values('n' // n: use x, c: use conj(x) +#ifdef TEST_BLIS_TYPED + , + 'c' // this option is BLIS-api specific. +#endif + ), + ::testing::Values(gtint_t(767)), // n, size of vectors with unit-stride + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(gtint_t(0)), // dummy index on x + ::testing::Values(float(0.0)), // dummy value on x + ::testing::Values(gtint_t(0), gtint_t(639), gtint_t(703), + gtint_t(734), gtint_t(751), gtint_t(759), + gtint_t(766)), // indices to set exception values on y + ::testing::Values(NaN, -Inf, Inf), // exception values to set on y + ::testing::Values(float(0.0), float(1.0), float(-1.0), float(-3.3)) // alpha + ), + ::axpyvEVTPrint()); + +// Exception value testing(on X and Y vectors) with unit strides +INSTANTIATE_TEST_SUITE_P( + vecXY_unitStrides_zen4, + saxpyvEVT, + ::testing::Combine( + ::testing::Values('n' // n: use x, c: use conj(x) +#ifdef TEST_BLIS_TYPED + , + 'c' // this option is BLIS-api specific. +#endif + ), + ::testing::Values(gtint_t(767)), // n, size of vectors with unit-stride + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(gtint_t(0), gtint_t(639), gtint_t(703), + gtint_t(734), gtint_t(751), gtint_t(759), + gtint_t(766)), // indices to set exception values on x + ::testing::Values(NaN, -Inf, Inf), // exception values to set on x + ::testing::Values(gtint_t(0), gtint_t(639), gtint_t(703), + gtint_t(734), gtint_t(751), gtint_t(759), + gtint_t(766)), // indices to set exception values on y + ::testing::Values(NaN, -Inf, Inf), // exception values to set on y + ::testing::Values(float(0.0), float(1.0), float(-1.0), float(-3.3)) // alpha + ), + ::axpyvEVTPrint()); + +// Exception value testing(on vectors) with non-unit strides +// We have to test a single scalar loop. The indices are such +// that we cover _vecX_, _vecY_ and _vecXY_ cases together. +INSTANTIATE_TEST_SUITE_P( + vecXY_nonUnitStrides, + saxpyvEVT, + ::testing::Combine( + ::testing::Values('n' // n: use x, c: use conj(x) +#ifdef TEST_BLIS_TYPED + , + 'c' // this option is BLIS-api specific. +#endif + ), + ::testing::Values(gtint_t(50)), // n, size of vectors with non-unit strides + ::testing::Values(gtint_t(3)), // stride size for x + ::testing::Values(gtint_t(5)), // stride size for y + ::testing::Values(gtint_t(1), gtint_t(27), gtint_t(49)), // indices to set exception values on x + ::testing::Values(NaN, -Inf, Inf, 2.9), // exception values to set on x + ::testing::Values(gtint_t(0), gtint_t(26), gtint_t(49)), // indices to set exception values on y + ::testing::Values(NaN, -Inf, Inf, -1.5), // exception values to set on y + ::testing::Values(float(0.0), float(1.0), float(-1.0), float(-3.3)) // alpha + ), + ::axpyvEVTPrint()); + +/* + Exception value testing on alpha : + Alpha values are set to Nan, +Inf or -Inf. A dummy + value of 0.0 is induced in X and Y vectors, to further + verify the propagation. + + The size(s) for _zen3 and _zen4 instantiators are chosen such + that code coverage is ensured in the respective kernels. +*/ +INSTANTIATE_TEST_SUITE_P( + alpha_unitStrides_zen3, + saxpyvEVT, + ::testing::Combine( + ::testing::Values('n' // n: use x, c: use conj(x) +#ifdef TEST_BLIS_TYPED + , + 'c' // this option is BLIS-api specific. +#endif + ), + ::testing::Values(gtint_t(432), gtint_t(471)), // n, size of vectors with unit strides + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(gtint_t(0)), // indices to set zero on x + ::testing::Values(float(0.0)), + ::testing::Values(gtint_t(0)), // indices to set zero on y + ::testing::Values(float(0.0)), + ::testing::Values(NaN, -Inf, Inf) // alpha + ), + ::axpyvEVTPrint()); + +// Exception value testing(on alpha) with unit strided vectors +INSTANTIATE_TEST_SUITE_P( + alpha_unitStrides_zen4, + saxpyvEVT, + ::testing::Combine( + ::testing::Values('n' // n: use x, c: use conj(x) +#ifdef TEST_BLIS_TYPED + , + 'c' // this option is BLIS-api specific. +#endif + ), + ::testing::Values(gtint_t(767)), // n, size of vectors with unit strides + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(gtint_t(0)), // indices to set zero on x + ::testing::Values(float(0.0)), + ::testing::Values(gtint_t(0)), // indices to set zero on y + ::testing::Values(float(0.0)), + ::testing::Values(NaN, -Inf, Inf) // alpha + ), + ::axpyvEVTPrint()); + +// Exception value testing(on alpha) with non-unit strided vectors +INSTANTIATE_TEST_SUITE_P( + alpha_nonUnitStrides, + saxpyvEVT, + ::testing::Combine( + ::testing::Values('n' // n: use x, c: use conj(x) +#ifdef TEST_BLIS_TYPED + , + 'c' // this option is BLIS-api specific. +#endif + ), + ::testing::Values(gtint_t(50)), // n, size of vectors with non-unit strides + ::testing::Values(gtint_t(3)), // stride size for x + ::testing::Values(gtint_t(5)), // stride size for y + ::testing::Values(gtint_t(0), gtint_t(25)), // indices to set zero on x + ::testing::Values(float(0.0)), + ::testing::Values(gtint_t(0), gtint_t(40)), // indices to set zero on y + ::testing::Values(float(0.0)), + ::testing::Values(NaN, -Inf, Inf) // alpha + ), + ::axpyvEVTPrint()); diff --git a/gtestsuite/testsuite/level1/axpyv/saxpyv/saxpyv_generic.cpp b/gtestsuite/testsuite/level1/axpyv/saxpyv/saxpyv_generic.cpp new file mode 100644 index 0000000000..6592843276 --- /dev/null +++ b/gtestsuite/testsuite/level1/axpyv/saxpyv/saxpyv_generic.cpp @@ -0,0 +1,228 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "level1/axpyv/test_axpyv.h" + +class saxpyvGeneric : + public ::testing::TestWithParam> {}; // alpha +// Tests using random integers as vector elements. +TEST_P( saxpyvGeneric, API ) +{ + using T = float; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes whether x or conj(x) will be added to y + char conj_x = std::get<0>(GetParam()); + // vector length + gtint_t n = std::get<1>(GetParam()); + // stride size for x + gtint_t incx = std::get<2>(GetParam()); + // stride size for y + gtint_t incy = std::get<3>(GetParam()); + // alpha + T alpha = std::get<4>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite axpyv.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + thresh = 0.0; + else if (alpha == testinghelpers::ONE()) + thresh = testinghelpers::getEpsilon(); + else + thresh = 2*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_axpyv( conj_x, n, incx, incy, alpha, thresh ); +} + +// Black box testing for generic and main use of saxpy. +INSTANTIATE_TEST_SUITE_P( + unitStrides, + saxpyvGeneric, + ::testing::Combine( + ::testing::Values('n'), // n: use x, not conj(x) (since it is real) + ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(float(2.0), float(-2.0)) // alpha + ), + ::axpyvGenericPrint() + ); + +#ifdef TEST_BLIS_TYPED +// Test when conjugate of x is used as an argument. This option is BLIS-api specific. +// Only test very few cases as sanity check since conj(x) = x for real types. +// We can modify the values using implementantion details. +INSTANTIATE_TEST_SUITE_P( + ConjX, + saxpyvGeneric, + ::testing::Combine( + ::testing::Values('c'), // c: use conj(x) + ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(float(2.5), float(1.0), + float(-1.0), float(0.0)) // alpha + ), + ::axpyvGenericPrint() + ); +#endif + +// Test for non-unit increments. +// Only test very few cases as sanity check. +// We can modify the values using implementantion details. +INSTANTIATE_TEST_SUITE_P( + nonUnitPositiveStrides, + saxpyvGeneric, + ::testing::Combine( + ::testing::Values('n'), // n: use x, not conj(x) (since it is real) + ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector + ::testing::Values(gtint_t(2)), // stride size for x + ::testing::Values(gtint_t(3)), // stride size for y + ::testing::Values(float(2.5), float(1.0), + float(-1.0), float(0.0)) // alpha + ), + ::axpyvGenericPrint() + ); + +#ifndef TEST_BLIS_TYPED +// Test for negative increments. +// Only test very few cases as sanity check. +// We can modify the values using implementantion details. +INSTANTIATE_TEST_SUITE_P( + negativeStrides, + saxpyvGeneric, + ::testing::Combine( + ::testing::Values('n'), // n: use x, c: use conj(x) + ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. + ::testing::Values(gtint_t(-4)), // stride size for x + ::testing::Values(gtint_t(-3)), // stride size for y + ::testing::Values(float(2.5), float(1.0), + float(-1.0), float(0.0)) // alpha + ), + ::axpyvGenericPrint() + ); +#endif +// To cover small, medium and large sizes of M with unit increment. +INSTANTIATE_TEST_SUITE_P( + differentSizesOfM, + saxpyvGeneric, + ::testing::Combine( + ::testing::Values('n'), // n: use x, c: use conj(x) + ::testing::Values(gtint_t(264), // M size of the vector + gtint_t(1600), + gtint_t(1992), + gtint_t(744), + gtint_t(3264), + gtint_t(2599), + gtint_t(4800), + gtint_t(2232), + gtint_t(2080), + gtint_t(1764), + gtint_t(622), + gtint_t(128), + gtint_t(64), + gtint_t(32), + gtint_t(16)), + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(float(2.0), + float(0.0), + float(1.0), + float(-1.0)) // alpha + ), + ::axpyvGenericPrint() + ); +//increment values of x and y are zero +INSTANTIATE_TEST_SUITE_P( + zeroIncrements, + saxpyvGeneric, + ::testing::Combine( + ::testing::Values('n'), // n: use x, c: use conj(x) + ::testing::Values(gtint_t(113)), // m size of vector + ::testing::Values(gtint_t(0),gtint_t(2)), // stride size for x + ::testing::Values(gtint_t(2),gtint_t(0)), // stride size for y + ::testing::Values(float(2.0), + float(0.0), + float(1.0), + float(-1.0)) // alpha + ), + ::axpyvGenericPrint() + ); +//To cover large sizes with non unit increments. +INSTANTIATE_TEST_SUITE_P( + largeSize, + saxpyvGeneric, + ::testing::Combine( + ::testing::Values('n'), // n: use x, c: use conj(x) + ::testing::Values(gtint_t(1000)), // m size of vector + ::testing::Values(gtint_t(2)), // stride size for x + ::testing::Values(gtint_t(2)), // stride size for y + ::testing::Values(float(2.0), + float(0.0), + float(1.0), + float(-1.0)) // alpha + ), + ::axpyvGenericPrint() + ); +//incx and incy is greater than size of a vector m. +INSTANTIATE_TEST_SUITE_P( + strideGreaterThanSize, + saxpyvGeneric, + ::testing::Combine( + ::testing::Values('n'), // n: use x, c: use conj(x) + ::testing::Values(gtint_t(10)), // m size of vector + ::testing::Values(gtint_t(20)), // stride size for x + ::testing::Values(gtint_t(33)), // stride size for y + ::testing::Values(float(2.0), + float(0.0), + float(1.0), + float(-1.0)) // alpha + ), + ::axpyvGenericPrint() + ); diff --git a/gtestsuite/testsuite/level1/axpyv/test_axpyv.h b/gtestsuite/testsuite/level1/axpyv/test_axpyv.h index 1cc375da00..f366c0bd81 100644 --- a/gtestsuite/testsuite/level1/axpyv/test_axpyv.h +++ b/gtestsuite/testsuite/level1/axpyv/test_axpyv.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -67,5 +67,97 @@ static void test_axpyv( char conjx, gtint_t n, gtint_t incx, gtint_t incy, //---------------------------------------------------------- // Compute component-wise error. //---------------------------------------------------------- - computediff( n, y.data(), y_ref.data(), incy, thresh ); + computediff( "y", n, y.data(), y_ref.data(), incy, thresh ); } + +template +static void test_axpyv( char conjx, gtint_t n, gtint_t incx, gtint_t incy, + T alpha, gtint_t xi, T xexval, gtint_t yj, T yexval, + double thresh ) +{ + //---------------------------------------------------------- + // Initialize vectors with random numbers. + //---------------------------------------------------------- + std::vector x = testinghelpers::get_random_vector( -10, 10, n, incx ); + std::vector y = testinghelpers::get_random_vector( -10, 10, n, incy ); + + // Update the value at index xi to an extreme value, x_exval. + if ( -1 < xi && xi < n ) x[xi * abs(incx)] = xexval; + else return; + + // Update the value at index yi to an extreme value, y_exval. + if ( -1 < yj && yj < n ) y[yj * abs(incy)] = yexval; + else return; + + //---------------------------------------------------------- + // Call reference implementation to get ref results. + //---------------------------------------------------------- + // Create a copy of y so that we can check reference results. + std::vector y_ref(y); + testinghelpers::ref_axpyv( conjx, n, alpha, x.data(), incx, y_ref.data(), incy ); + + //---------------------------------------------------------- + // Call BLIS function. + //---------------------------------------------------------- + axpyv( conjx, n, alpha, x.data(), incx, y.data(), incy ); + + //---------------------------------------------------------- + // Compute component-wise error. + //---------------------------------------------------------- + computediff( "y", n, y.data(), y_ref.data(), incy, thresh, true ); +} + +// Test-case logger : Used to print the test-case details based on parameters +template +class axpyvGenericPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char conjx = std::get<0>(str.param); + gtint_t n = std::get<1>(str.param); + gtint_t incx = std::get<2>(str.param); + gtint_t incy = std::get<3>(str.param); + T alpha = std::get<4>(str.param); + + std::string str_name = API_PRINT; + str_name += "_n_" + std::to_string(n); + str_name += "_conjx_" + std::string(&conjx, 1); + str_name += "_incx_" + testinghelpers::get_value_string(incx); + str_name += "_incy_" + testinghelpers::get_value_string(incy); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + return str_name; + } +}; + +template +class axpyvEVTPrint +{ +public: + std::string operator()( + testing::TestParamInfo> str) const + { + char conjx = std::get<0>(str.param); + gtint_t n = std::get<1>(str.param); + gtint_t incx = std::get<2>(str.param); + gtint_t incy = std::get<3>(str.param); + gtint_t xi = std::get<4>(str.param); + T xexval = std::get<5>(str.param); + gtint_t yj = std::get<6>(str.param); + T yexval = std::get<7>(str.param); + T alpha = std::get<8>(str.param); + + std::string str_name = API_PRINT; + str_name += "_n_" + std::to_string(n); + str_name += "_conjx_" + std::string(&conjx, 1); + str_name += "_incx_" + testinghelpers::get_value_string(incx); + str_name += "_incy_" + testinghelpers::get_value_string(incy); + std::string xexval_str = testinghelpers::get_value_string(xexval); + std::string yexval_str = testinghelpers::get_value_string(yexval); + str_name = str_name + "_X_" + std::to_string(xi); + str_name = str_name + "_" + xexval_str; + str_name = str_name + "_Y_" + std::to_string(yj); + str_name = str_name + "_" + yexval_str; + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + return str_name; + } +}; diff --git a/gtestsuite/testsuite/level1/axpyv/zaxpyv/zaxpyv_evt.cpp b/gtestsuite/testsuite/level1/axpyv/zaxpyv/zaxpyv_evt.cpp new file mode 100644 index 0000000000..3b5f072736 --- /dev/null +++ b/gtestsuite/testsuite/level1/axpyv/zaxpyv/zaxpyv_evt.cpp @@ -0,0 +1,310 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "level1/axpyv/test_axpyv.h" + +class zaxpyvEVT : + public ::testing::TestWithParam> {}; // alpha + +// Tests using random values as vector elements, +// with exception values on the passed indices. +TEST_P( zaxpyvEVT, API ) +{ + using T = dcomplex; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes whether x or conj(x) will be added to y: + char conj_x = std::get<0>(GetParam()); + // vector length + gtint_t n = std::get<1>(GetParam()); + // stride size for x + gtint_t incx = std::get<2>(GetParam()); + // stride size for y + gtint_t incy = std::get<3>(GetParam()); + // index for exval in x + gtint_t xi = std::get<4>(GetParam()); + // exval for x + T xexval = std::get<5>(GetParam()); + // index for exval in y + gtint_t yj = std::get<6>(GetParam()); + // exval for x + T yexval = std::get<7>(GetParam()); + // alpha + T alpha = std::get<8>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite subv.h (no netlib version) for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + // Small adjustment has been applied for complex data. + double adj = 1.5; + if (n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + thresh = 0.0; + else if (alpha == testinghelpers::ONE()) + thresh = testinghelpers::getEpsilon(); + else + thresh = adj*2*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_axpyv(conj_x, n, incx, incy, alpha, xi, xexval, + yj, yexval, thresh); +} + +static double NaN = std::numeric_limits::quiet_NaN(); +static double Inf = std::numeric_limits::infinity(); + +/* + Exception value testing on vectors : + SAXPY currently uses the bli_zaxpyv_zen_int5( ... ) kernel for computation. + The sizes and indices given in the instantiator are to ensure code coverage inside + the kernel, and to verify the compliance accordingly. + + Kernel structure for bli_zaxpyv_zen_int5( ... ) : + Main loop : In blocks of 14 --> L14 + Fringe loops : In blocks of 10 --> L10 + In blocks of 6 --> L6 + In blocks of 4 --> L4 + In blocks of 2 --> L2 + Element-wise loop --> LScalar + + The sizes chosen are as follows : + 52 - 3*L14 + L10 + 48 - 3*L14 + L6 + 46 - 3*L14 + L4 + 45 - 3*L14 + L2 + LScalar + + The following indices are sufficient to ensure code-coverage of loops + in these sizes : + 0, 41 - In L14 + 43 - In { L10, L6, L4, L2 }, based on the size + 44 - In { L10, L6, L4, LScalar }, based on the size + + The alpha values are such that they check for compliance against possible + optimizations that might have been done. + + P.S : Some test cases also check whether NaN has to be induced in the computation + such as 0.0 * { {NaN, 0}, {+Inf, 0}, {-Inf, 0}, ... }, and a few more. +*/ + +// Exception value testing(on X vector alone) with unit strides +INSTANTIATE_TEST_SUITE_P( + vecX_unitStrides, + zaxpyvEVT, + ::testing::Combine( + ::testing::Values('n' // n: use x, c: use conj(x) +#ifdef TEST_BLIS_TYPED + , + 'c' // this option is BLIS-api specific. +#endif + ), + ::testing::Values(gtint_t(45), gtint_t(46), + gtint_t(48), gtint_t(52)), // n, size of vectors with unit-stride + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(gtint_t(0), gtint_t(41), + gtint_t(43), gtint_t(44)), // indices to set exception values on x + ::testing::Values(dcomplex{NaN, 0.0}, dcomplex{-Inf, 0.0}, + dcomplex{0.0, Inf}, dcomplex{-2.3, NaN}, + dcomplex{4.5, -Inf}, dcomplex{NaN, Inf}), // exception values to set on x + ::testing::Values(gtint_t(0)), // dummy index on y + ::testing::Values(dcomplex{0.0, 0.0}), // dummy value on y + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, + dcomplex{-1.0, 0.0}, dcomplex{0.0, 1.0}, + dcomplex{0.0, -1.0}, dcomplex{-3.3, 1.7}) // alpha + ), + ::axpyvEVTPrint()); + +// Exception value testing(on Y vector alone) with unit strides +INSTANTIATE_TEST_SUITE_P( + vecY_unitStrides, + zaxpyvEVT, + ::testing::Combine( + ::testing::Values('n' // n: use x, c: use conj(x) +#ifdef TEST_BLIS_TYPED + , + 'c' // this option is BLIS-api specific. +#endif + ), + ::testing::Values(gtint_t(45), gtint_t(46), + gtint_t(48), gtint_t(52)), // n, size of vectors with unit-stride + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(gtint_t(0)), // dummy index on x + ::testing::Values(dcomplex{0.0, 0.0}), // dummy value on x + ::testing::Values(gtint_t(0), gtint_t(41), + gtint_t(43), gtint_t(44)), // indices to set exception values on y + ::testing::Values(dcomplex{NaN, 0.0}, dcomplex{-Inf, 0.0}, + dcomplex{0.0, Inf}, dcomplex{-2.3, NaN}, + dcomplex{4.5, -Inf}, dcomplex{NaN, Inf}), // exception values to set on y + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, + dcomplex{-1.0, 0.0}, dcomplex{0.0, 1.0}, + dcomplex{0.0, -1.0}, dcomplex{-3.3, 1.7}) // alpha + ), + ::axpyvEVTPrint()); + +// Exception value testing(on X and Y vectors) with unit strides +INSTANTIATE_TEST_SUITE_P( + vecXY_unitStrides, + zaxpyvEVT, + ::testing::Combine( + ::testing::Values('n' // n: use x, c: use conj(x) +#ifdef TEST_BLIS_TYPED + , + 'c' // this option is BLIS-api specific. +#endif + ), + ::testing::Values(gtint_t(45), gtint_t(46), + gtint_t(48), gtint_t(52)), // n, size of vectors with unit-stride + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(gtint_t(0), gtint_t(41), + gtint_t(43), gtint_t(44)), // indices to set exception values on x + ::testing::Values(dcomplex{NaN, 0.0}, dcomplex{-Inf, 0.0}, + dcomplex{0.0, Inf}, dcomplex{-2.3, NaN}, + dcomplex{4.5, -Inf}, dcomplex{NaN, Inf}), // exception values to set on x + ::testing::Values(gtint_t(0), gtint_t(41), + gtint_t(43), gtint_t(44)), // indices to set exception values on y + ::testing::Values(dcomplex{NaN, 0.0}, dcomplex{-Inf, 0.0}, + dcomplex{0.0, Inf}, dcomplex{-2.3, NaN}, + dcomplex{4.5, -Inf}, dcomplex{NaN, Inf}), // exception values to set on y + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, + dcomplex{-1.0, 0.0}, dcomplex{0.0, 1.0}, + dcomplex{0.0, -1.0}, dcomplex{-3.3, 1.7}) // alpha + ), + ::axpyvEVTPrint()); + +// Exception value testing(on vectors) with non-unit strides +// We have to test a single scalar loop. The indices are such +// that we cover _vecX_, _vecY_ and _vecXY_ cases together. +INSTANTIATE_TEST_SUITE_P( + vecXY_nonUnitStrides, + zaxpyvEVT, + ::testing::Combine( + ::testing::Values('n' // n: use x, c: use conj(x) +#ifdef TEST_BLIS_TYPED + , + 'c' // this option is BLIS-api specific. +#endif + ), + ::testing::Values(gtint_t(50)), // n, size of vectors with non-unit strides + ::testing::Values(gtint_t(3)), // stride size for x + ::testing::Values(gtint_t(5)), // stride size for y + ::testing::Values(gtint_t(1), gtint_t(27), gtint_t(49)), // indices to set exception values on x + ::testing::Values(dcomplex{NaN, 0.0}, dcomplex{-Inf, 0.0}, + dcomplex{0.0, Inf}, dcomplex{-2.3, NaN}, + dcomplex{4.5, -Inf}, dcomplex{NaN, Inf}, + dcomplex{2.3, -3.5}), // exception values to set on x + ::testing::Values(gtint_t(0), gtint_t(26), gtint_t(49)), // indices to set exception values on y + ::testing::Values(dcomplex{NaN, 0.0}, dcomplex{-Inf, 0.0}, + dcomplex{0.0, Inf}, dcomplex{-2.3, NaN}, + dcomplex{4.5, -Inf}, dcomplex{NaN, Inf}, + dcomplex{2.3, -3.5}), // exception values to set on y + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, + dcomplex{-1.0, 0.0}, dcomplex{0.0, 1.0}, + dcomplex{0.0, -1.0}, dcomplex{-3.3, 1.7}) // alpha + ), + ::axpyvEVTPrint()); + +/* + Exception value testing on alpha : + Alpha values are set to Nan, +Inf or -Inf. A dummy + value of 0.0 is induced in X and Y vectors, to further + verify the propagation. + + The size(s) for _zen3 and _zen4 instantiators are chosen such + that code coverage is ensured in the respective kernels. +*/ +INSTANTIATE_TEST_SUITE_P( + alpha_unitStrides, + zaxpyvEVT, + ::testing::Combine( + ::testing::Values('n' // n: use x, c: use conj(x) +#ifdef TEST_BLIS_TYPED + , + 'c' // this option is BLIS-api specific. +#endif + ), + ::testing::Values(gtint_t(45), gtint_t(46), + gtint_t(48), gtint_t(52)), // n, size of vectors with unit-stride + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(gtint_t(0)), // indices to set zero on x + ::testing::Values(dcomplex{0.0, 0.0}), + ::testing::Values(gtint_t(0)), // indices to set zero on y + ::testing::Values(dcomplex{0.0, 0.0}), + ::testing::Values(dcomplex{NaN, 0.0}, dcomplex{-Inf, 0.0}, + dcomplex{0.0, Inf}, dcomplex{-2.3, NaN}, + dcomplex{4.5, -Inf}, dcomplex{NaN, Inf}) // alpha + ), + ::axpyvEVTPrint()); + +// Exception value testing(on alpha) with non-unit strided vectors +INSTANTIATE_TEST_SUITE_P( + alpha_nonUnitStrides, + zaxpyvEVT, + ::testing::Combine( + ::testing::Values('n' // n: use x, c: use conj(x) +#ifdef TEST_BLIS_TYPED + , + 'c' // this option is BLIS-api specific. +#endif + ), + ::testing::Values(gtint_t(50)), // n, size of vectors with non-unit strides + ::testing::Values(gtint_t(3)), // stride size for x + ::testing::Values(gtint_t(5)), // stride size for y + ::testing::Values(gtint_t(0), gtint_t(25)), // indices to set zero on x + ::testing::Values(dcomplex{0.0, 0.0}), + ::testing::Values(gtint_t(0), gtint_t(40)), // indices to set zero on y + ::testing::Values(dcomplex{0.0, 0.0}), + ::testing::Values(dcomplex{NaN, 0.0}, dcomplex{-Inf, 0.0}, + dcomplex{0.0, Inf}, dcomplex{-2.3, NaN}, + dcomplex{4.5, -Inf}, dcomplex{NaN, Inf}) // alpha + ), + ::axpyvEVTPrint()); diff --git a/gtestsuite/testsuite/level1/axpyv/zaxpyv/zaxpyv_generic.cpp b/gtestsuite/testsuite/level1/axpyv/zaxpyv/zaxpyv_generic.cpp new file mode 100644 index 0000000000..7ac3a10be7 --- /dev/null +++ b/gtestsuite/testsuite/level1/axpyv/zaxpyv/zaxpyv_generic.cpp @@ -0,0 +1,218 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "level1/axpyv/test_axpyv.h" + +class zaxpyvGeneric : + public ::testing::TestWithParam> {}; // alpha +// Tests using random integers as vector elements. +TEST_P( zaxpyvGeneric, API ) +{ + using T = dcomplex; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes whether x or conj(x) will be added to y + char conj_x = std::get<0>(GetParam()); + // vector length + gtint_t n = std::get<1>(GetParam()); + // stride size for x + gtint_t incx = std::get<2>(GetParam()); + // stride size for y + gtint_t incy = std::get<3>(GetParam()); + // alpha + T alpha = std::get<4>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite axpyv.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // With small adjustment applied for complex data. + double thresh; + // Threshold adjustment +#ifdef BLIS_INT_ELEMENT_TYPE + double adj = 1.02; +#else + double adj = 1.0; +#endif + if (n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + thresh = 0.0; + else if (alpha == testinghelpers::ONE()) + thresh = testinghelpers::getEpsilon(); + else + thresh = adj*2*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_axpyv( conj_x, n, incx, incy, alpha, thresh ); +} + +// Black box testing for generic and main use of zaxpy. +INSTANTIATE_TEST_SUITE_P( + unitStrides, + zaxpyvGeneric, + ::testing::Combine( + ::testing::Values('n' +#ifdef TEST_BLIS_TYPED + , 'c' // this option is BLIS-api specific. +#endif + ), // n: use x, c: use conj(x) + ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(dcomplex{-3.7, 1.2}, dcomplex{1.5, 2.6}, + dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, + dcomplex{-1.0, 0.0}) // alpha + ), + ::axpyvGenericPrint() + ); + +// Test for non-unit increments. +// Only test very few cases as sanity check. +// We can modify the values using implementantion details. +INSTANTIATE_TEST_SUITE_P( + nonUnitPositiveStrides, + zaxpyvGeneric, + ::testing::Combine( + ::testing::Values('n' +#ifdef TEST_BLIS_TYPED + , 'c' // this option is BLIS-api specific. +#endif + ), // n: use x, c: use conj(x) + ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. + ::testing::Values(gtint_t(2)), // stride size for x + ::testing::Values(gtint_t(3)), // stride size for y + ::testing::Values(dcomplex{-3.7, 1.2}, dcomplex{1.5, 2.6}, + dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, + dcomplex{-1.0, 0.0}) // alpha + ), + ::axpyvGenericPrint() + ); + +#ifndef TEST_BLIS_TYPED +// Test for negative increments. +// Only test very few cases as sanity check. +// We can modify the values using implementantion details. +INSTANTIATE_TEST_SUITE_P( + negativeStrides, + zaxpyvGeneric, + ::testing::Combine( + ::testing::Values('n'), // n: use x, c: use conj(x) + ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. + ::testing::Values(gtint_t(-4)), // stride size for x + ::testing::Values(gtint_t(-3)), // stride size for y + ::testing::Values(dcomplex{-3.7, 1.2}, dcomplex{1.5, 2.6}, + dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, + dcomplex{-1.0, 0.0}) // alpha + ), + ::axpyvGenericPrint() + ); +#endif +// To cover small, medium and large sizes of M with unit increment. +INSTANTIATE_TEST_SUITE_P( + DifferentSizesOfM, + zaxpyvGeneric, + ::testing::Combine( + ::testing::Values('n'), // n: use x, c: use conj(x) + ::testing::Values(gtint_t(36), //m size of vector + gtint_t(1000), + gtint_t(2999), + gtint_t(3666), + gtint_t(777)), + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(dcomplex{2.0, 1.1}, + dcomplex{0.0, 0.0}, + dcomplex{1.0, 0.0}, + dcomplex{-1.0, 0.0}) // alpha + ), + ::axpyvGenericPrint() + ); +//incx and incy are zero. +INSTANTIATE_TEST_SUITE_P( + ZeroIncrements, + zaxpyvGeneric, + ::testing::Combine( + ::testing::Values('n'), // n: use x, c: use conj(x) + ::testing::Values(gtint_t(10)), // m size of vector + ::testing::Values(gtint_t(0),gtint_t(2)), // stride size for x + ::testing::Values(gtint_t(3),gtint_t(0)), // stride size for y + ::testing::Values(dcomplex{4.0, 3.1}, + dcomplex{0.0, 0.0}, + dcomplex{1.0, 0.0}, + dcomplex{-1.0, 0.0}) // alpha + ), + ::axpyvGenericPrint() + ); +//To cover large sizes with non unit increments. +INSTANTIATE_TEST_SUITE_P( + largeSize, + zaxpyvGeneric, + ::testing::Combine( + ::testing::Values('n'), // n: use x, c: use conj(x) + ::testing::Values(gtint_t(1000)), // m size of vector + ::testing::Values(gtint_t(2)), // stride size for x + ::testing::Values(gtint_t(3)), // stride size for y + ::testing::Values(dcomplex{4.0, 3.1}, + dcomplex{0.0, 0.0}, + dcomplex{1.0, 0.0}, + dcomplex{-1.0, 0.0}) // alpha + ), + ::axpyvGenericPrint() + ); +//incx and incy is greater than size of a vector m. +INSTANTIATE_TEST_SUITE_P( + strideGreaterThanSize, + zaxpyvGeneric, + ::testing::Combine( + ::testing::Values('n'), // n: use x, c: use conj(x) + ::testing::Values(gtint_t(6)), // m size of vector + ::testing::Values(gtint_t(10)), // stride size for x + ::testing::Values(gtint_t(14)), // stride size for y + ::testing::Values(dcomplex{4.0, 3.1}, + dcomplex{0.0, 0.0}, + dcomplex{1.0, 0.0}, + dcomplex{-1.0, 0.0}) // alpha + ), + ::axpyvGenericPrint() + ); diff --git a/gtestsuite/testsuite/level1/axpyv/zaxpyv_generic.cpp b/gtestsuite/testsuite/level1/axpyv/zaxpyv_generic.cpp deleted file mode 100644 index 64b98f1b04..0000000000 --- a/gtestsuite/testsuite/level1/axpyv/zaxpyv_generic.cpp +++ /dev/null @@ -1,158 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include "test_axpyv.h" - -class zaxpyvGenericTest : - public ::testing::TestWithParam> {}; -// Tests using random integers as vector elements. -TEST_P( zaxpyvGenericTest, RandomData ) -{ - using T = dcomplex; - //---------------------------------------------------------- - // Initialize values from the parameters passed through - // test suite instantiation (INSTANTIATE_TEST_SUITE_P). - //---------------------------------------------------------- - // denotes whether x or conj(x) will be added to y: - char conj_x = std::get<0>(GetParam()); - // vector length: - gtint_t n = std::get<1>(GetParam()); - // stride size for x: - gtint_t incx = std::get<2>(GetParam()); - // stride size for y: - gtint_t incy = std::get<3>(GetParam()); - // alpha - T alpha = std::get<4>(GetParam()); - - // Set the threshold for the errors: - double thresh = 2*testinghelpers::getEpsilon(); - //---------------------------------------------------------- - // Call generic test body using those parameters - //---------------------------------------------------------- - test_axpyv( conj_x, n, incx, incy, alpha, thresh ); -} - -// Used to generate a test case with a sensible name. -// Beware that we cannot use fp numbers (e.g., 2.3) in the names, -// so we are only printing int(2.3). This should be enough for debugging purposes. -// If this poses an issue, please reach out. -class zaxpyvGenericTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char conj = std::get<0>(str.param); - gtint_t n = std::get<1>(str.param); - gtint_t incx = std::get<2>(str.param); - gtint_t incy = std::get<3>(str.param); - dcomplex alpha = std::get<4>(str.param); -#ifdef TEST_BLAS - std::string str_name = "zaxpy_"; -#elif TEST_CBLAS - std::string str_name = "cblas_zaxpy"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_zaxpyv"; -#endif - str_name += "_" + std::to_string(n); - str_name += "_" + std::string(&conj, 1); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name += "_" + incx_str; - std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); - str_name += "_" + incy_str; - std::string alpha_str = ( alpha.real > 0) ? std::to_string(int(alpha.real)) : ("m" + std::to_string(int(std::abs(alpha.real)))); - alpha_str = alpha_str + "pi" + (( alpha.imag > 0) ? std::to_string(int(alpha.imag)) : ("m" + std::to_string(int(std::abs(alpha.imag))))); - str_name = str_name + "_a" + alpha_str; - return str_name; - } -}; - -// Black box testing for generic and main use of zaxpy. -INSTANTIATE_TEST_SUITE_P( - Blackbox, - zaxpyvGenericTest, - ::testing::Combine( - ::testing::Values('n' -#ifdef TEST_BLIS_TYPED - , 'c' // this option is BLIS-api specific. -#endif - ), // n: use x, c: use conj(x) - ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. - ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(dcomplex{-3.0, 1.0}, dcomplex{1.0, 2.0}) // alpha - ), - ::zaxpyvGenericTestPrint() - ); - -// Test for non-unit increments. -// Only test very few cases as sanity check. -// We can modify the values using implementantion details. -INSTANTIATE_TEST_SUITE_P( - NonUnitPositiveIncrements, - zaxpyvGenericTest, - ::testing::Combine( - ::testing::Values('n' -#ifdef TEST_BLIS_TYPED - , 'c' // this option is BLIS-api specific. -#endif - ), // n: use x, c: use conj(x) - ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. - ::testing::Values(gtint_t(2)), // stride size for x - ::testing::Values(gtint_t(3)), // stride size for y - ::testing::Values(dcomplex{-1.0, 2.0}) // alpha - ), - ::zaxpyvGenericTestPrint() - ); - -#ifndef TEST_BLIS_TYPED -// Test for negative increments. -// Only test very few cases as sanity check. -// We can modify the values using implementantion details. -INSTANTIATE_TEST_SUITE_P( - NegativeIncrements, - zaxpyvGenericTest, - ::testing::Combine( - ::testing::Values('n'), // n: use x, c: use conj(x) - ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. - ::testing::Values(gtint_t(-4)), // stride size for x - ::testing::Values(gtint_t(-3)), // stride size for y - ::testing::Values(dcomplex{4.0, 3.1}) // alpha - ), - ::zaxpyvGenericTestPrint() - ); -#endif diff --git a/gtestsuite/testsuite/level1/copyv/ccopyv_generic.cpp b/gtestsuite/testsuite/level1/copyv/ccopyv_generic.cpp index 29f988005b..3433a3deb3 100644 --- a/gtestsuite/testsuite/level1/copyv/ccopyv_generic.cpp +++ b/gtestsuite/testsuite/level1/copyv/ccopyv_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,14 +35,14 @@ #include #include "test_copyv.h" -class ccopyvGenericTest : - public ::testing::TestWithParam> {}; +class ccopyvGeneric : + public ::testing::TestWithParam> {}; // stride size for y -// Tests using random integers as vector elements. -TEST_P( ccopyvGenericTest, RandomData ) +// Tests using random values as vector elements. +TEST_P( ccopyvGeneric, API ) { using T = scomplex; //---------------------------------------------------------- @@ -58,48 +58,16 @@ TEST_P( ccopyvGenericTest, RandomData ) // stride size for y: gtint_t incy = std::get<3>(GetParam()); - // Set the threshold for the errors: - double thresh = testinghelpers::getEpsilon(); - //---------------------------------------------------------- // Call generic test body using those parameters //---------------------------------------------------------- - test_copyv( conjx, n, incx, incy, thresh ); + test_copyv( conjx, n, incx, incy ); } -// Used to generate a test case with a sensible name. -// Beware that we cannot use fp numbers (e.g., 2.3) in the names, -// so we are only printing int(2.3). This should be enough for debugging purposes. -// If this poses an issue, please reach out. -class ccopyvGenericTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char conjx = std::get<0>(str.param); - gtint_t n = std::get<1>(str.param); - gtint_t incx = std::get<2>(str.param); - gtint_t incy = std::get<3>(str.param); -#ifdef TEST_BLAS - std::string str_name = "ccopy_"; -#elif TEST_CBLAS - std::string str_name = "cblas_ccopy"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_ccopyv"; -#endif - str_name += "_" + std::to_string(n); - str_name += "_" + std::string(&conjx, 1); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name += "_" + incx_str; - std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); - str_name += "_" + incy_str; - return str_name; - } -}; - // Black box testing for generic and main use of ccopy. INSTANTIATE_TEST_SUITE_P( - Blackbox, - ccopyvGenericTest, + smallSize, + ccopyvGeneric, ::testing::Combine( ::testing::Values('n' #ifdef TEST_BLIS_TYPED @@ -110,7 +78,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(1)), // stride size for x ::testing::Values(gtint_t(1)) // stride size for y ), - ::ccopyvGenericTestPrint() + ::copyvGenericPrint() ); // Test for non-unit increments. @@ -118,7 +86,7 @@ INSTANTIATE_TEST_SUITE_P( // We can modify the values using implementantion details. INSTANTIATE_TEST_SUITE_P( NonUnitIncrements, - ccopyvGenericTest, + ccopyvGeneric, ::testing::Combine( ::testing::Values('n' #ifdef TEST_BLIS_TYPED @@ -129,7 +97,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(2), gtint_t(11)), // stride size for x ::testing::Values(gtint_t(3), gtint_t(33)) // stride size for y ), - ::ccopyvGenericTestPrint() + ::copyvGenericPrint() ); #ifndef TEST_BLIS_TYPED @@ -138,13 +106,64 @@ INSTANTIATE_TEST_SUITE_P( // We can modify the values using implementantion details. INSTANTIATE_TEST_SUITE_P( NegativeIncrements, - ccopyvGenericTest, + ccopyvGeneric, ::testing::Combine( ::testing::Values('n'), // n: use x, c: use conj(x) ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector ::testing::Values(gtint_t(-5), gtint_t(7)), // stride size for x ::testing::Values(gtint_t(13), gtint_t(-9)) // stride size for y ), - ::ccopyvGenericTestPrint() + ::copyvGenericPrint() ); #endif +// To cover small, medium and large sizes of M with unit increment. +INSTANTIATE_TEST_SUITE_P( + differentSizesOfM, + ccopyvGeneric, + ::testing::Combine( + ::testing::Values('n'), // n: use x, c: use conj(x) + ::testing::Values(gtint_t(1760), + gtint_t(255), + gtint_t(1280), + gtint_t(64), + gtint_t(32), + gtint_t(16), + gtint_t(8), + gtint_t(1920), + gtint_t(2240), + gtint_t(5400), + gtint_t(2483), + gtint_t(184), + gtint_t(160), + gtint_t(1916), + gtint_t(908), + gtint_t(732)), // m size of vector + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)) // stride size for y + ), + ::copyvGenericPrint() + ); +//To cover large sizes with non unit increments. +INSTANTIATE_TEST_SUITE_P( + largeSize, + ccopyvGeneric, + ::testing::Combine( + ::testing::Values('n'), // n: use x, c: use conj(x) + ::testing::Values(gtint_t(3000)), // m size of vector + ::testing::Values(gtint_t(5)), // stride size for x + ::testing::Values(gtint_t(2)) // stride size for y + ), + ::copyvGenericPrint() + ); +//incx and incy is greater than size of a vector m. +INSTANTIATE_TEST_SUITE_P( + strideGreaterThanSize, + ccopyvGeneric, + ::testing::Combine( + ::testing::Values('n'), // n: use x, c: use conj(x) + ::testing::Values(gtint_t(3)), // m size of vector + ::testing::Values(gtint_t(55)), // stride size for x + ::testing::Values(gtint_t(66)) // stride size for y + ), + ::copyvGenericPrint() + ); diff --git a/gtestsuite/testsuite/level1/copyv/copyv.h b/gtestsuite/testsuite/level1/copyv/copyv.h index cc8bf85af0..f9947aea99 100644 --- a/gtestsuite/testsuite/level1/copyv/copyv.h +++ b/gtestsuite/testsuite/level1/copyv/copyv.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -36,6 +36,7 @@ #include "blis.h" #include "common/testing_helpers.h" +#include "inc/check_error.h" /** * @brief Performs the operation: @@ -64,6 +65,21 @@ static void copyv_(gtint_t n, T* x, gtint_t incx, T* y, gtint_t incy) { throw std::runtime_error("Error in testsuite/level1/copyv.h: Invalid typename in copyv_()."); } +template +static void copyv_blis_impl(gtint_t n, T* x, gtint_t incx, T* y, gtint_t incy) { + + if constexpr (std::is_same::value) + scopy_blis_impl( &n, x, &incx, y, &incy ); + else if constexpr (std::is_same::value) + dcopy_blis_impl( &n, x, &incx, y, &incy ); + else if constexpr (std::is_same::value) + ccopy_blis_impl( &n, x, &incx, y, &incy ); + else if constexpr (std::is_same::value) + zcopy_blis_impl( &n, x, &incx, y, &incy ); + else + throw std::runtime_error("Error in testsuite/level1/copyv.h: Invalid typename in copyv_blis_impl()."); +} + template static void cblas_copyv(gtint_t n, T* x, gtint_t incx, T* y, gtint_t incy) { @@ -100,8 +116,32 @@ static void typed_copyv(char conjx, gtint_t n, T* x, gtint_t incx, T* y, gtint_t template static void copyv(char conjx, gtint_t n, T* x, gtint_t incx, T* y, gtint_t incy) { + +#ifdef TEST_UPPERCASE_ARGS + conjx = static_cast(std::toupper(static_cast(conjx))); +#endif + +#ifdef TEST_INPUT_ARGS + // Create copy of scalar input values so we can check that they are not altered. + char conjx_cpy = conjx; + gtint_t n_cpy = n; + gtint_t incx_cpy = incx; + gtint_t incy_cpy = incy; + + // Create copy of input arrays so we can check that they are not altered. + T* x_cpy = nullptr; + gtint_t size_x = testinghelpers::buff_dim( n, incx ); + if (x && size_x > 0) + { + x_cpy = new T[size_x]; + memcpy( x_cpy, x, size_x * sizeof( T ) ); + } +#endif + #ifdef TEST_BLAS copyv_(n, x, incx, y, incy); +#elif TEST_BLAS_BLIS_IMPL + copyv_blis_impl(n, x, incx, y, incy); #elif TEST_CBLAS cblas_copyv(n, x, incx, y, incy); #elif TEST_BLIS_TYPED @@ -109,4 +149,25 @@ static void copyv(char conjx, gtint_t n, T* x, gtint_t incx, T* y, gtint_t incy) #else throw std::runtime_error("Error in testsuite/level1/copyv.h: No interfaces are set to be tested."); #endif + +#ifdef TEST_INPUT_ARGS + //---------------------------------------------------------- + // Check scalar inputs have not been modified. + //---------------------------------------------------------- + + computediff( "conjx", conjx, conjx_cpy ); + computediff( "n", n, n_cpy ); + computediff( "incx", incx, incx_cpy ); + computediff( "incy", incy, incy_cpy ); + + //---------------------------------------------------------- + // Bitwise-wise check array inputs have not been modified. + //---------------------------------------------------------- + + if (x && size_x > 0) + { + computediff( "x", n, x, x_cpy, incx, true ); + delete[] x_cpy; + } +#endif } diff --git a/gtestsuite/testsuite/level1/copyv/copyv_IIT_ERS.cpp b/gtestsuite/testsuite/level1/copyv/copyv_IIT_ERS.cpp new file mode 100644 index 0000000000..24418cb31e --- /dev/null +++ b/gtestsuite/testsuite/level1/copyv/copyv_IIT_ERS.cpp @@ -0,0 +1,146 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "common/testing_helpers.h" +#include "copyv.h" +#include "inc/check_error.h" +#include "common/wrong_inputs_helpers.h" + +template +class copyv_IIT_ERS : public ::testing::Test {}; +typedef ::testing::Types TypeParam; // The supported datatypes from BLAS/CBLAS calls for COPYV +TYPED_TEST_SUITE(copyv_IIT_ERS, TypeParam); // Defining individual testsuites based on the datatype support. + +// Adding namespace to get default parameters(valid case) from testinghelpers/common/wrong_input_helpers.h. +using namespace testinghelpers::IIT; + +#if defined(TEST_BLAS_LIKE) || defined(TEST_CBLAS) +/* + Early Return Scenarios(ERS) for BLAS/CBLAS compliance: + + The COPYV API is expected to return early in the following cases: + 1. When n <= 0. +*/ + +// Early return cases with non-unit strides on vectors +// When n < 0 +TYPED_TEST(copyv_IIT_ERS, n_lt_zero_nonUnitStrides) +{ + using T = TypeParam; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + copyv( CONJ, -1, nullptr, 5, nullptr, 5 ); + + // Test with all arguments correct except for the value we are choosing to test. + // Defining the x vector + std::vector x = testinghelpers::get_random_vector( -10, 10, N, 5 ); + // Defining the y_vector with values for debugging purposes + std::vector y = testinghelpers::get_random_vector( -10, 10, N, 5 ); + + // Copy so that we check that the elements of y are not modified. + std::vector y_ref(y); + + copyv( CONJ, -1, x.data(), 5, y.data(), 5 ); + // Use bitwise comparison (no threshold). + computediff( "y", N, y.data(), y_ref.data(), 5 ); +} + +// When n = 0 +TYPED_TEST(copyv_IIT_ERS, n_eq_zero_nonUnitStrides) +{ + using T = TypeParam; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + copyv( CONJ, 0, nullptr, 5, nullptr, 5 ); + + // Test with all arguments correct except for the value we are choosing to test. + // Defining the x vector + std::vector x = testinghelpers::get_random_vector( -10, 10, N, 5 ); + // Defining the y vector with values for debugging purposes + std::vector y = testinghelpers::get_random_vector( -10, 10, N, 5 ); + + // Copy so that we check that the elements of y are not modified. + std::vector y_ref(y); + + copyv( CONJ, 0, x.data(), 5, y.data(), 5 ); + // Use bitwise comparison (no threshold). + computediff( "y", N, y.data(), y_ref.data(), 5 ); +} + +// Early return cases with unit strides on vectors +// When n < 0 +TYPED_TEST(copyv_IIT_ERS, n_lt_zero_unitStrides) +{ + using T = TypeParam; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + copyv( CONJ, -1, nullptr, 1, nullptr, 1 ); + + // Test with all arguments correct except for the value we are choosing to test. + // Defining the x vector + std::vector x = testinghelpers::get_random_vector( -10, 10, N, 1 ); + // Defining the y_vector with values for debugging purposes + std::vector y = testinghelpers::get_random_vector( -10, 10, N, 1 ); + + // Copy so that we check that the elements of y are not modified. + std::vector y_ref(y); + + copyv( CONJ, -1, x.data(), 1, y.data(), 1 ); + // Use bitwise comparison (no threshold). + computediff( "y", N, y.data(), y_ref.data(), 1 ); +} + +// When n = 0 +TYPED_TEST(copyv_IIT_ERS, n_eq_zero_unitStrides) +{ + using T = TypeParam; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + copyv( CONJ, 0, nullptr, 1, nullptr, 1 ); + + // Test with all arguments correct except for the value we are choosing to test. + // Defining the x vector + std::vector x = testinghelpers::get_random_vector( -10, 10, N, 1 ); + // Defining the y vector with values for debugging purposes + std::vector y = testinghelpers::get_random_vector( -10, 10, N, 1 ); + + // Copy so that we check that the elements of y are not modified. + std::vector y_ref(y); + + copyv( CONJ, 0, x.data(), 1, y.data(), 1 ); + // Use bitwise comparison (no threshold). + computediff( "y", N, y.data(), y_ref.data(), 1 ); +} +#endif diff --git a/gtestsuite/testsuite/level1/copyv/dcopyv_generic.cpp b/gtestsuite/testsuite/level1/copyv/dcopyv_generic.cpp index 1c7824b8f4..5c7b219031 100644 --- a/gtestsuite/testsuite/level1/copyv/dcopyv_generic.cpp +++ b/gtestsuite/testsuite/level1/copyv/dcopyv_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,14 +35,14 @@ #include #include "test_copyv.h" -class dcopyvGenericTest : - public ::testing::TestWithParam> {}; +class dcopyvGeneric : + public ::testing::TestWithParam> {}; // stride size for y -// Tests using random integers as vector elements. -TEST_P( dcopyvGenericTest, RandomData ) +// Tests using random values as vector elements. +TEST_P( dcopyvGeneric, API ) { using T = double; //---------------------------------------------------------- @@ -58,55 +58,23 @@ TEST_P( dcopyvGenericTest, RandomData ) // stride size for y: gtint_t incy = std::get<3>(GetParam()); - // Set the threshold for the errors: - double thresh = testinghelpers::getEpsilon(); - //---------------------------------------------------------- // Call generic test body using those parameters //---------------------------------------------------------- - test_copyv( conjx, n, incx, incy, thresh ); + test_copyv( conjx, n, incx, incy ); } -// Used to generate a test case with a sensible name. -// Beware that we cannot use fp numbers (e.g., 2.3) in the names, -// so we are only printing int(2.3). This should be enough for debugging purposes. -// If this poses an issue, please reach out. -class dcopyvGenericTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char conjx = std::get<0>(str.param); - gtint_t n = std::get<1>(str.param); - gtint_t incx = std::get<2>(str.param); - gtint_t incy = std::get<3>(str.param); -#ifdef TEST_BLAS - std::string str_name = "dcopy_"; -#elif TEST_CBLAS - std::string str_name = "cblas_dcopy"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_dcopyv"; -#endif - str_name += "_" + std::to_string(n); - str_name += "_" + std::string(&conjx, 1); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name += "_" + incx_str; - std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); - str_name += "_" + incy_str; - return str_name; - } -}; - // Black box testing for generic and main use of scopy. INSTANTIATE_TEST_SUITE_P( - Blackbox, - dcopyvGenericTest, + smallSize, + dcopyvGeneric, ::testing::Combine( ::testing::Values('n'), // n: use x, not conj(x) (since it is real) ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. ::testing::Values(gtint_t(1)), // stride size for x ::testing::Values(gtint_t(1)) // stride size for y ), - ::dcopyvGenericTestPrint() + ::copyvGenericPrint() ); #ifdef TEST_BLIS_TYPED // BLIS-api specific @@ -115,14 +83,14 @@ INSTANTIATE_TEST_SUITE_P( // We can modify the values using implementantion details. INSTANTIATE_TEST_SUITE_P( ConjX, - dcopyvGenericTest, + dcopyvGeneric, ::testing::Combine( ::testing::Values('c'), // c: use conj(x) ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector ::testing::Values(gtint_t(1)), // stride size for x ::testing::Values(gtint_t(1)) // stride size for y ), - ::dcopyvGenericTestPrint() + ::copyvGenericPrint() ); #endif @@ -131,14 +99,14 @@ INSTANTIATE_TEST_SUITE_P( // We can modify the values using implementantion details. INSTANTIATE_TEST_SUITE_P( NonUnitPositiveIncrements, - dcopyvGenericTest, + dcopyvGeneric, ::testing::Combine( ::testing::Values('n'), // use x, not conj(x) (since it is real) ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector ::testing::Values(gtint_t(2), gtint_t(11)), // stride size for x ::testing::Values(gtint_t(3), gtint_t(33)) // stride size for y ), - ::dcopyvGenericTestPrint() + ::copyvGenericPrint() ); #ifndef TEST_BLIS_TYPED @@ -147,13 +115,65 @@ INSTANTIATE_TEST_SUITE_P( // We can modify the values using implementantion details. INSTANTIATE_TEST_SUITE_P( NegativeIncrements, - dcopyvGenericTest, + dcopyvGeneric, ::testing::Combine( ::testing::Values('n'), // n: use x, c: use conj(x) ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector ::testing::Values(gtint_t(-5), gtint_t(7)), // stride size for x - ::testing::Values(gtint_t(13), gtint_t(-9)) // stride size for y + ::testing::Values(gtint_t(13), gtint_t(-9)) // stride size for y ), - ::dcopyvGenericTestPrint() + ::copyvGenericPrint() ); #endif +// To cover small, medium and large sizes of M with unit increment. +INSTANTIATE_TEST_SUITE_P( + differentSizesOfM, + dcopyvGeneric, + ::testing::Combine( + ::testing::Values('n'), // n: use x, c: use conj(x) + ::testing::Values(gtint_t(1270), + gtint_t(64), + gtint_t(32), + gtint_t(16), + gtint_t(8), + gtint_t(4), + gtint_t(960), + gtint_t(3120), + gtint_t(1900), + gtint_t(124), + gtint_t(880), + gtint_t(80), + gtint_t(256), + gtint_t(480), + gtint_t(788), + gtint_t(36), + gtint_t(24)), // m size of vector + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)) // stride size for y + ), + ::copyvGenericPrint() + ); +//To cover large sizes with non unit increments. +INSTANTIATE_TEST_SUITE_P( + largeSize, + dcopyvGeneric, + ::testing::Combine( + ::testing::Values('n'), // n: use x, c: use conj(x) + ::testing::Values(gtint_t(1000)), // m size of vector + ::testing::Values(gtint_t(2)), // stride size for x + ::testing::Values(gtint_t(3)) // stride size for y + ), + ::copyvGenericPrint() + ); +//incx and incy is greater than size of a vector m. +INSTANTIATE_TEST_SUITE_P( + StrideGreaterThanSize, + dcopyvGeneric, + ::testing::Combine( + ::testing::Values('n'), // n: use x, c: use conj(x) + ::testing::Values(gtint_t(4)), // m size of vector + ::testing::Values(gtint_t(6)), // stride size for x + ::testing::Values(gtint_t(8)) // stride size for y + ), + ::copyvGenericPrint() + ); diff --git a/gtestsuite/testsuite/level1/copyv/scopyv_generic.cpp b/gtestsuite/testsuite/level1/copyv/scopyv_generic.cpp index e86d2f320f..a5699af7ba 100644 --- a/gtestsuite/testsuite/level1/copyv/scopyv_generic.cpp +++ b/gtestsuite/testsuite/level1/copyv/scopyv_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,14 +35,14 @@ #include #include "test_copyv.h" -class scopyvGenericTest : - public ::testing::TestWithParam> {}; +class scopyvGeneric : + public ::testing::TestWithParam> {}; // stride size for y -// Tests using random integers as vector elements. -TEST_P( scopyvGenericTest, RandomData ) +// Tests using random values as vector elements. +TEST_P( scopyvGeneric, API ) { using T = float; //---------------------------------------------------------- @@ -58,55 +58,23 @@ TEST_P( scopyvGenericTest, RandomData ) // stride size for y: gtint_t incy = std::get<3>(GetParam()); - // Set the threshold for the errors: - double thresh = testinghelpers::getEpsilon(); - //---------------------------------------------------------- // Call generic test body using those parameters //---------------------------------------------------------- - test_copyv( conjx, n, incx, incy, thresh ); + test_copyv( conjx, n, incx, incy ); } -// Used to generate a test case with a sensible name. -// Beware that we cannot use fp numbers (e.g., 2.3) in the names, -// so we are only printing int(2.3). This should be enough for debugging purposes. -// If this poses an issue, please reach out. -class scopyvGenericTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char conjx = std::get<0>(str.param); - gtint_t n = std::get<1>(str.param); - gtint_t incx = std::get<2>(str.param); - gtint_t incy = std::get<3>(str.param); -#ifdef TEST_BLAS - std::string str_name = "scopy_"; -#elif TEST_CBLAS - std::string str_name = "cblas_scopy"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_scopyv"; -#endif - str_name += "_" + std::to_string(n); - str_name += "_" + std::string(&conjx, 1); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name += "_" + incx_str; - std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); - str_name += "_" + incy_str; - return str_name; - } -}; - // Black box testing for generic and main use of scopyv. INSTANTIATE_TEST_SUITE_P( - Blackbox, - scopyvGenericTest, + smallSize, + scopyvGeneric, ::testing::Combine( ::testing::Values('n'), // n: use x, not conj(x) (since it is real) ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. ::testing::Values(gtint_t(1)), // stride size for x ::testing::Values(gtint_t(1)) // stride size for y ), - ::scopyvGenericTestPrint() + ::copyvGenericPrint() ); #ifdef TEST_BLIS_TYPED // BLIS-api specific @@ -115,14 +83,14 @@ INSTANTIATE_TEST_SUITE_P( // We can modify the values using implementantion details. INSTANTIATE_TEST_SUITE_P( ConjX, - scopyvGenericTest, + scopyvGeneric, ::testing::Combine( ::testing::Values('c'), // c: use conj(x) ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector ::testing::Values(gtint_t(1)), // stride size for x ::testing::Values(gtint_t(1)) // stride size for y ), - ::scopyvGenericTestPrint() + ::copyvGenericPrint() ); #endif @@ -131,14 +99,14 @@ INSTANTIATE_TEST_SUITE_P( // We can modify the values using implementantion details. INSTANTIATE_TEST_SUITE_P( NonUnitPositiveIncrements, - scopyvGenericTest, + scopyvGeneric, ::testing::Combine( ::testing::Values('n'), // n: use x, not conj(x) (since it is real) ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector ::testing::Values(gtint_t(2), gtint_t(11)), // stride size for x ::testing::Values(gtint_t(3), gtint_t(33)) // stride size for y ), - ::scopyvGenericTestPrint() + ::copyvGenericPrint() ); #ifndef TEST_BLIS_TYPED @@ -147,13 +115,65 @@ INSTANTIATE_TEST_SUITE_P( // We can modify the values using implementantion details. INSTANTIATE_TEST_SUITE_P( NegativeIncrements, - scopyvGenericTest, + scopyvGeneric, ::testing::Combine( ::testing::Values('n'), // n: use x, c: use conj(x) ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector ::testing::Values(gtint_t(-5), gtint_t(7)), // stride size for x - ::testing::Values(gtint_t(13), gtint_t(-9)) // stride size for y + ::testing::Values(gtint_t(13), gtint_t(-9)) // stride size for y ), - ::scopyvGenericTestPrint() + ::copyvGenericPrint() ); #endif +// To cover small, medium and large sizes of M with unit increment. +INSTANTIATE_TEST_SUITE_P( + differentSizesOfM, + scopyvGeneric, + ::testing::Combine( + ::testing::Values('n'), // n: use x, c: use conj(x) + ::testing::Values(gtint_t(1270), + gtint_t(640), + gtint_t(32), + gtint_t(16), + gtint_t(8), + gtint_t(4), + gtint_t(960), + gtint_t(2120), + gtint_t(1000), + gtint_t(1724), + gtint_t(888), + gtint_t(680), + gtint_t(56), + gtint_t(48), + gtint_t(3033), + gtint_t(36), + gtint_t(24)), // m size of vector + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)) // stride size for y + ), + ::copyvGenericPrint() + ); +//To cover large sizes with non unit increments. +INSTANTIATE_TEST_SUITE_P( + largeSize, + scopyvGeneric, + ::testing::Combine( + ::testing::Values('n'), // n: use x, c: use conj(x) + ::testing::Values(gtint_t(2222)), // m size of vector + ::testing::Values(gtint_t(5)), // stride size for x + ::testing::Values(gtint_t(2)) // stride size for y + ), + ::copyvGenericPrint() + ); +//incx and incy is greater than size of a vector m. +INSTANTIATE_TEST_SUITE_P( + strideGreaterThanSize, + scopyvGeneric, + ::testing::Combine( + ::testing::Values('n'), // n: use x, c: use conj(x) + ::testing::Values(gtint_t(2)), // m size of vector + ::testing::Values(gtint_t(50)), // stride size for x + ::testing::Values(gtint_t(75)) // stride size for y + ), + ::copyvGenericPrint() + ); diff --git a/gtestsuite/testsuite/level1/copyv/test_copyv.h b/gtestsuite/testsuite/level1/copyv/test_copyv.h index 6ab5a12bca..f9c1b36eaa 100644 --- a/gtestsuite/testsuite/level1/copyv/test_copyv.h +++ b/gtestsuite/testsuite/level1/copyv/test_copyv.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -43,7 +43,7 @@ */ template -static void test_copyv( char conjx, gtint_t n, gtint_t incx, gtint_t incy, double thresh ) +static void test_copyv( char conjx, gtint_t n, gtint_t incx, gtint_t incy ) { //---------------------------------------------------------- // Initialize vectors with random numbers. @@ -67,5 +67,24 @@ static void test_copyv( char conjx, gtint_t n, gtint_t incx, gtint_t incy, doubl //---------------------------------------------------------- // Compute error. //---------------------------------------------------------- - computediff( n, y.data(), y_ref.data(), incy ); + computediff( "y", n, y.data(), y_ref.data(), incy ); } + +// Test-case logger : Used to print the test-case details based on parameters +class copyvGenericPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char conjx = std::get<0>(str.param); + gtint_t n = std::get<1>(str.param); + gtint_t incx = std::get<2>(str.param); + gtint_t incy = std::get<3>(str.param); + + std::string str_name = API_PRINT; + str_name += "_n_" + std::to_string(n); + str_name += "_conjx_" + std::string(&conjx, 1); + str_name += "_incx_" + testinghelpers::get_value_string(incx); + str_name += "_incy_" + testinghelpers::get_value_string(incy); + return str_name; + } +}; diff --git a/gtestsuite/testsuite/level1/copyv/zcopyv_generic.cpp b/gtestsuite/testsuite/level1/copyv/zcopyv_generic.cpp index eeb9b13e37..839f5a142b 100644 --- a/gtestsuite/testsuite/level1/copyv/zcopyv_generic.cpp +++ b/gtestsuite/testsuite/level1/copyv/zcopyv_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,14 +35,14 @@ #include #include "test_copyv.h" -class zcopyvGenericTest : - public ::testing::TestWithParam> {}; +class zcopyvGeneric : + public ::testing::TestWithParam> {}; // stride size for y -// Tests using random integers as vector elements. -TEST_P( zcopyvGenericTest, RandomData ) +// Tests using random values as vector elements. +TEST_P( zcopyvGeneric, API ) { using T = dcomplex; //---------------------------------------------------------- @@ -58,48 +58,16 @@ TEST_P( zcopyvGenericTest, RandomData ) // stride size for y: gtint_t incy = std::get<3>(GetParam()); - // Set the threshold for the errors: - double thresh = testinghelpers::getEpsilon(); - //---------------------------------------------------------- // Call generic test body using those parameters //---------------------------------------------------------- - test_copyv( conjx, n, incx, incy, thresh ); + test_copyv( conjx, n, incx, incy ); } -// Used to generate a test case with a sensible name. -// Beware that we cannot use fp numbers (e.g., 2.3) in the names, -// so we are only printing int(2.3). This should be enough for debugging purposes. -// If this poses an issue, please reach out. -class zcopyvGenericTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char conjx = std::get<0>(str.param); - gtint_t n = std::get<1>(str.param); - gtint_t incx = std::get<2>(str.param); - gtint_t incy = std::get<3>(str.param); -#ifdef TEST_BLAS - std::string str_name = "zcopy_"; -#elif TEST_CBLAS - std::string str_name = "cblas_zcopy"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_zcopyv"; -#endif - str_name += "_" + std::to_string(n); - str_name += "_" + std::string(&conjx, 1); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name += "_" + incx_str; - std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); - str_name += "_" + incy_str; - return str_name; - } -}; - // Black box testing for generic and main use of zcopy. INSTANTIATE_TEST_SUITE_P( - Blackbox, - zcopyvGenericTest, + smallSize, + zcopyvGeneric, ::testing::Combine( ::testing::Values('n' // n: use x, c: use conj(x) #ifdef TEST_BLIS_TYPED @@ -110,7 +78,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(1)), // stride size for x ::testing::Values(gtint_t(1)) // stride size for y ), - ::zcopyvGenericTestPrint() + ::copyvGenericPrint() ); // Test for non-unit increments. @@ -118,7 +86,7 @@ INSTANTIATE_TEST_SUITE_P( // We can modify the values using implementantion details. INSTANTIATE_TEST_SUITE_P( NonUnitPositiveIncrements, - zcopyvGenericTest, + zcopyvGeneric, ::testing::Combine( ::testing::Values('n' // n: use x, c: use conj(x) #ifdef TEST_BLIS_TYPED @@ -129,7 +97,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(2), gtint_t(11)), // stride size for x ::testing::Values(gtint_t(3), gtint_t(33)) // stride size for y ), - ::zcopyvGenericTestPrint() + ::copyvGenericPrint() ); #ifndef TEST_BLIS_TYPED @@ -138,13 +106,58 @@ INSTANTIATE_TEST_SUITE_P( // We can modify the values using implementantion details. INSTANTIATE_TEST_SUITE_P( NegativeIncrements, - zcopyvGenericTest, + zcopyvGeneric, ::testing::Combine( ::testing::Values('n'), // n: use x, c: use conj(x) ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector ::testing::Values(gtint_t(-5), gtint_t(7)), // stride size for x ::testing::Values(gtint_t(13), gtint_t(-9)) // stride size for y ), - ::zcopyvGenericTestPrint() + ::copyvGenericPrint() ); #endif +//To cover large sizes with non unit increments. +INSTANTIATE_TEST_SUITE_P( + largeSize, + zcopyvGeneric, + ::testing::Combine( + ::testing::Values('n'), // n: use x, c: use conj(x) + ::testing::Values(gtint_t(4444)), // m size of vector + ::testing::Values(gtint_t(4)), // stride size for x + ::testing::Values(gtint_t(3)) // stride size for y + ), + ::copyvGenericPrint() + ); +// To cover small, medium and large sizes of M with unit increment. +INSTANTIATE_TEST_SUITE_P( + DiffSizeOfM, + zcopyvGeneric, + ::testing::Combine( + ::testing::Values('n'), // n: use x, c: use conj(x) + ::testing::Values(gtint_t(1250), + gtint_t(4200), + gtint_t(3344), + gtint_t(2244), + gtint_t(32), + gtint_t(64), + gtint_t(128), + gtint_t(264), + gtint_t(987), + gtint_t(1876)), // m size of vector + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)) // stride size for y + ), + ::copyvGenericPrint() + ); +//incx and incy is greater than size of a vector m. +INSTANTIATE_TEST_SUITE_P( + strideGreaterThanSize, + zcopyvGeneric, + ::testing::Combine( + ::testing::Values('n'), // n: use x, c: use conj(x) + ::testing::Values(gtint_t(4)), // m size of vector + ::testing::Values(gtint_t(88)), // stride size for x + ::testing::Values(gtint_t(99)) // stride size for y + ), + ::copyvGenericPrint() + ); diff --git a/gtestsuite/testsuite/level1/dotv/cdotv_generic.cpp b/gtestsuite/testsuite/level1/dotv/cdotv_generic.cpp index 0a662d96b4..289db862c1 100644 --- a/gtestsuite/testsuite/level1/dotv/cdotv_generic.cpp +++ b/gtestsuite/testsuite/level1/dotv/cdotv_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,7 +35,7 @@ #include #include "test_dotv.h" -class cdotvGenericTest : +class cdotvGeneric : public ::testing::TestWithParam> {}; // Tests using random integers as vector elements. -TEST_P( cdotvGenericTest, RandomData ) +TEST_P( cdotvGeneric, API ) { using T = scomplex; //---------------------------------------------------------- @@ -62,7 +62,15 @@ TEST_P( cdotvGenericTest, RandomData ) gtint_t incy = std::get<4>(GetParam()); // Set the threshold for the errors: - double thresh = 2*n*testinghelpers::getEpsilon(); + // Check gtestsuite dotv.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (n == 0) + thresh = 0.0; + else + thresh = 2*n*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call generic test body using those parameters @@ -70,47 +78,12 @@ TEST_P( cdotvGenericTest, RandomData ) test_dotv( conjx, conjy, n, incx, incy, thresh ); } -// Used to generate a test case with a sensible name. -// Beware that we cannot use fp numbers (e.g., 2.3) in the names, -// so we are only printing int(2.3). This should be enough for debugging purposes. -// If this poses an issue, please reach out. -class cdotvGenericTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char conjx = std::get<0>(str.param); - char conjy = std::get<1>(str.param); - gtint_t n = std::get<2>(str.param); - gtint_t incx = std::get<3>(str.param); - gtint_t incy = std::get<4>(str.param); -#ifdef TEST_BLAS - std::string str_name = "cdotu_"; -#elif TEST_CBLAS - std::string str_name = "cblas_cdotu_sub"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_cdotv"; -#endif - str_name += "_" + std::to_string(n); - str_name += "_" + std::string(&conjx, 1); - str_name += "_" + std::string(&conjy, 1); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name += "_" + incx_str; - std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); - str_name += "_" + incy_str; - return str_name; - } -}; - // Black box testing for generic and main use of cdot. INSTANTIATE_TEST_SUITE_P( Blackbox, - cdotvGenericTest, + cdotvGeneric, ::testing::Combine( - ::testing::Values('n' -#ifdef TEST_BLIS_TYPED - , 'c' // this option is BLIS-api specific. -#endif - ), // n: use x, c: use conj(x) + ::testing::Values('n', 'c'), // 'n': tests cdotu_, 'c': tests cdotc_ ::testing::Values('n' #ifdef TEST_BLIS_TYPED , 'c' // this option is BLIS-api specific. @@ -120,7 +93,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(1)), // stride size for x ::testing::Values(gtint_t(1)) // stride size for y ), - ::cdotvGenericTestPrint() + ::dotvGenericPrint() ); // Test for non-unit increments. @@ -128,13 +101,9 @@ INSTANTIATE_TEST_SUITE_P( // We can modify the values using implementantion details. INSTANTIATE_TEST_SUITE_P( NonUnitPositiveIncrements, - cdotvGenericTest, + cdotvGeneric, ::testing::Combine( - ::testing::Values('n' -#ifdef TEST_BLIS_TYPED - , 'c' // this option is BLIS-api specific. -#endif - ), // n: use x, c: use conj(x) + ::testing::Values('n', 'c'), // 'n': tests cdotu_, 'c': tests cdotc_ ::testing::Values('n' #ifdef TEST_BLIS_TYPED , 'c' // this option is BLIS-api specific. @@ -144,7 +113,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(11)), // stride size for x ::testing::Values(gtint_t(3)) // stride size for y ), - ::cdotvGenericTestPrint() + ::dotvGenericPrint() ); #ifndef TEST_BLIS_TYPED @@ -153,14 +122,14 @@ INSTANTIATE_TEST_SUITE_P( // We can modify the values using implementantion details. INSTANTIATE_TEST_SUITE_P( NegativeIncrements, - cdotvGenericTest, + cdotvGeneric, ::testing::Combine( - ::testing::Values('n'), // n: use x, c: use conj(x) + ::testing::Values('n', 'c'), // 'n': tests cdotu_, 'c': tests cdotc_ ::testing::Values('n'), // n: use y, c: use conj(y) ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector ::testing::Values(gtint_t(-2)), // stride size for x ::testing::Values(gtint_t(-3)) // stride size for y ), - ::cdotvGenericTestPrint() + ::dotvGenericPrint() ); #endif diff --git a/gtestsuite/testsuite/level1/dotv/ddotv_evt.cpp b/gtestsuite/testsuite/level1/dotv/ddotv_evt.cpp new file mode 100644 index 0000000000..cb9eef5d3e --- /dev/null +++ b/gtestsuite/testsuite/level1/dotv/ddotv_evt.cpp @@ -0,0 +1,473 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_dotv.h" + +class ddotvEVT : + public ::testing::TestWithParam> {}; // yexval + +// Tests using random integers as vector elements. +TEST_P( ddotvEVT, API ) +{ + using T = double; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes whether vec x is n,c + char conjx = std::get<0>(GetParam()); + // denotes whether vec y is n,c + char conjy = std::get<1>(GetParam()); + // vector length: + gtint_t n = std::get<2>(GetParam()); + // stride size for x: + gtint_t incx = std::get<3>(GetParam()); + // index of extreme value for x: + gtint_t xi = std::get<4>(GetParam()); + // extreme value for x: + double x_exval = std::get<5>(GetParam()); + // stride size for y: + gtint_t incy = std::get<6>(GetParam()); + // index of extreme value for y: + gtint_t yi = std::get<7>(GetParam()); + // extreme value for y: + double y_exval = std::get<8>(GetParam()); + + // Set the threshold for the errors: + double thresh; + if (n == 0) + thresh = 0.0; + else + thresh = 2*n*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_dotv( conjx, conjy, n, incx, xi, x_exval, incy, yi, y_exval, thresh ); +} + + +static double NaN = std::numeric_limits::quiet_NaN(); +static double Inf = std::numeric_limits::infinity(); + +// Tests for Zen4 Architecture. +/** + * bli_ddotv_zen_int_avx512( ... ) + * Loops: + * L40 - Main loop, handles 40 elements + * L16 - handles 16 elements + * L8 - handles 8 elements + * LScalar - leftover loop + * + * n = 109 : L40*2 + L16 + L8 + LScalar + * Indices - Loop into which extreme value is induced + * 0, 79 - L40 + * 93 - L16 + * 101 - L8 + * 108 - LScalar + */ +// EVT with unit stride X vector containing Infs/NaNs. +// Unit stride Y vector contains random elements. +INSTANTIATE_TEST_SUITE_P( + vecX_unitStride_zen4, + ddotvEVT, + ::testing::Combine( + // conj(x): user n (no_conjugate) since it is real. + ::testing::Values('n' +#ifdef TEST_BLIS_TYPED + , + 'c' // conjugate option is BLIS-api specific. +#endif + ), + // conj(y): user n (no_conjugate) since it is real. + ::testing::Values('n' +#ifdef TEST_BLIS_TYPED + , + 'c' // conjugate option is BLIS-api specific. +#endif + ), + // m: size of vector. + ::testing::Values( + gtint_t(109) + ), + // incx: stride of x vector. + ::testing::Values(gtint_t(1)), // unit stride + // xi: index of extreme value for x. + ::testing::Values( + gtint_t(0), gtint_t(79), gtint_t(93), + gtint_t(101), gtint_t(108) + ), + // x_exval: extreme value for x. + ::testing::Values( NaN, Inf, -Inf ), + // incy: stride of y vector. + ::testing::Values(gtint_t(1)), // unit stride + // yi: index of extreme value for y. + ::testing::Values( gtint_t(0) ), // set as 0 since testing only for x + // y_exval: extreme value for y. + ::testing::Values( double(0.0) ) // dummy value since testing only for x + ), + ::dotvEVTPrint() + ); + + +// EVT with unit stride Y vector containing Infs/NaNs. +// Unit stride X vector contains random elements. +INSTANTIATE_TEST_SUITE_P( + vecY_unitStride_zen4, + ddotvEVT, + ::testing::Combine( + // conj(x): user n (no_conjugate) since it is real. + ::testing::Values('n' +#ifdef TEST_BLIS_TYPED + , + 'c' // conjugate option is BLIS-api specific. +#endif + ), + // conj(y): user n (no_conjugate) since it is real. + ::testing::Values('n' +#ifdef TEST_BLIS_TYPED + , + 'c' // conjugate option is BLIS-api specific. +#endif + ), + // m: size of vector. + ::testing::Values( gtint_t(109) ), + // incx: stride of x vector. + ::testing::Values( gtint_t(1) ), // unit stride + // xi: index of extreme value for x. + ::testing::Values( gtint_t(0) ), // set as 0 since testing only for y + // x_exval: extreme value for x. + ::testing::Values( double(0.0) ), // dummy value since testing only for y + // incy: stride of y vector. + ::testing::Values( gtint_t(1) ), // unit stride + // yi: index of extreme value for y. + ::testing::Values( + gtint_t(0), gtint_t(79), gtint_t(93), + gtint_t(101), gtint_t(108) + ), + // y_exval: extreme value for y. + ::testing::Values( NaN, Inf, -Inf ) + ), + ::dotvEVTPrint() + ); + +// EVT with unit stride vectors X and Y contatining Infs/NaNs. +INSTANTIATE_TEST_SUITE_P( + vecXY_unitStride_zen4, + ddotvEVT, + ::testing::Combine( + // conj(x): user n (no_conjugate) since it is real. + ::testing::Values('n' +#ifdef TEST_BLIS_TYPED + , + 'c' // conjugate option is BLIS-api specific. +#endif + ), + // conj(y): user n (no_conjugate) since it is real. + ::testing::Values('n' +#ifdef TEST_BLIS_TYPED + , + 'c' // conjugate option is BLIS-api specific. +#endif + ), + // m: size of vector. + ::testing::Values( + gtint_t(109) + ), + // incx: stride of x vector. + ::testing::Values(gtint_t(1)), // unit stride + // xi: index of extreme value for x. + ::testing::Values( + gtint_t(0), gtint_t(79), gtint_t(93), + gtint_t(101), gtint_t(108) + ), + // x_exval: extreme value for x. + ::testing::Values( NaN, Inf, -Inf ), + // incy: stride of y vector. + ::testing::Values(gtint_t(1)), // unit stride + // yi: index of extreme value for y. + ::testing::Values( + gtint_t(0), gtint_t(79), gtint_t(93), + gtint_t(101), gtint_t(108) + ), + // y_exval: extreme value for y. + ::testing::Values( NaN, Inf, -Inf ) + ), + ::dotvEVTPrint() + ); + +// Tests for Zen3 Architecture. +/** + * bli_ddotv_zen_int10( ... ) + * Loops: + * L40 - Main loop, handles 40 elements + * L20 - handles 20 elements + * L16 - handles 16 elements + * L8 - handles 8 elements + * L4 - handles 4 elements + * LScalar - leftover loop + * + * n = 119 : L40*2 + L20 + L16 + LScalar + * Indices - Loop into which extreme value is induced + * 0, 78 - L40 + * 94 - L20 + * 101, 110 - L16 + * 112 - L16 + * 118 - LScalar + * + * n = 113 : L40*2 + L20 + L8 + L4 + LScalar + * Indices - Loop into which extreme value is induced + * 0, 78 - L40 + * 94 - L20 + * 101 - L8 + * 110 - L4 + * 112 - LScalar + */ +// EVT with unit stride X vector containing Infs/NaNs. +// Unit stride Y vector contains random elements. +INSTANTIATE_TEST_SUITE_P( + vecX_unitStride_zen3, + ddotvEVT, + ::testing::Combine( + // conj(x): user n (no_conjugate) since it is real. + ::testing::Values('n' +#ifdef TEST_BLIS_TYPED + , + 'c' // conjugate option is BLIS-api specific. +#endif + ), + // conj(y): user n (no_conjugate) since it is real. + ::testing::Values('n' +#ifdef TEST_BLIS_TYPED + , + 'c' // conjugate option is BLIS-api specific. +#endif + ), + // m: size of vector. + ::testing::Values( + gtint_t(119), + gtint_t(113) + ), + // incx: stride of x vector. + ::testing::Values(gtint_t(1)), // unit stride + // xi: index of extreme value for x. + ::testing::Values( + gtint_t(0), gtint_t(78), gtint_t(94), + gtint_t(101), gtint_t(110), gtint_t(112), + gtint_t(118) + ), + // x_exval: extreme value for x. + ::testing::Values( NaN, Inf, -Inf ), + // incy: stride of y vector. + ::testing::Values(gtint_t(1)), // unit stride + // yi: index of extreme value for y. + ::testing::Values( gtint_t(0) ), // set as 0 since testing only for x + // y_exval: extreme value for y. + ::testing::Values( double(0.0) ) // dummy value since testing only for x + ), + ::dotvEVTPrint() + ); + +// EVT with unit stride Y vector containing Infs/NaNs. +// Unit stride X vector contains random elements. +INSTANTIATE_TEST_SUITE_P( + vecY_unitStride_zen3, + ddotvEVT, + ::testing::Combine( + // conj(x): user n (no_conjugate) since it is real. + ::testing::Values('n' +#ifdef TEST_BLIS_TYPED + , + 'c' // conjugate option is BLIS-api specific. +#endif + ), + // conj(y): user n (no_conjugate) since it is real. + ::testing::Values('n' +#ifdef TEST_BLIS_TYPED + , + 'c' // conjugate option is BLIS-api specific. +#endif + ), + // m: size of vector. + ::testing::Values( + gtint_t(119), + gtint_t(113) + ), + // incx: stride of x vector. + ::testing::Values( gtint_t(1) ), // unit stride + // xi: index of extreme value for x. + ::testing::Values( gtint_t(0) ), // set as 0 since testing only for y + // x_exval: extreme value for x. + ::testing::Values( double(0.0) ), // dummy value since testing only for y + // incy: stride of y vector. + ::testing::Values( gtint_t(1) ), // unit stride + // yi: index of extreme value for y. + ::testing::Values( + gtint_t(0), gtint_t(78), gtint_t(94), + gtint_t(110), gtint_t(118) + ), + // y_exval: extreme value for y. + ::testing::Values( NaN, Inf, -Inf ) + ), + ::dotvEVTPrint() + ); + +// EVT with unit stride vectors X and Y contatining Infs/NaNs. +INSTANTIATE_TEST_SUITE_P( + vecXY_unitStride_zen3, + ddotvEVT, + ::testing::Combine( + // conj(x): user n (no_conjugate) since it is real. + ::testing::Values('n' +#ifdef TEST_BLIS_TYPED + , + 'c' // conjugate option is BLIS-api specific. +#endif + ), + // conj(y): user n (no_conjugate) since it is real. + ::testing::Values('n' +#ifdef TEST_BLIS_TYPED + , + 'c' // conjugate option is BLIS-api specific. +#endif + ), + // m: size of vector. + ::testing::Values( + gtint_t(119), + gtint_t(115) + ), + // incx: stride of x vector. + ::testing::Values(gtint_t(1)), // unit stride + // xi: index of extreme value for x. + ::testing::Values( + gtint_t(0), gtint_t(79), gtint_t(93), + gtint_t(101), gtint_t(108) + ), + // x_exval: extreme value for x. + ::testing::Values( NaN, Inf, -Inf ), + // incy: stride of y vector. + ::testing::Values(gtint_t(1)), // unit stride + // yi: index of extreme value for y. + ::testing::Values( + gtint_t(0), gtint_t(78), gtint_t(94), + gtint_t(110), gtint_t(118) + ), + // y_exval: extreme value for y. + ::testing::Values( NaN, Inf, -Inf ) + ), + ::dotvEVTPrint() + ); + +// EVT with non-unit stride vectors X and Y containing Infs/NaNs. +INSTANTIATE_TEST_SUITE_P( + vecXY_nonUnitStride, + ddotvEVT, + ::testing::Combine( + // conj(x): user n (no_conjugate) since it is real. + ::testing::Values('n' +#ifdef TEST_BLIS_TYPED + , + 'c' // conjugate option is BLIS-api specific. +#endif + ), + // conj(y): user n (no_conjugate) since it is real. + ::testing::Values('n' +#ifdef TEST_BLIS_TYPED + , + 'c' // conjugate option is BLIS-api specific. +#endif + ), + // m: size of vector. + ::testing::Values( gtint_t(55) ), + // incx: stride of x vector. + ::testing::Values( gtint_t(3) ), + // xi: index of extreme value for x. + ::testing::Values( gtint_t(1), gtint_t(27), gtint_t(51) ), + // x_exval: extreme value for x. + ::testing::Values( NaN, Inf, -Inf ), + // incy: stride of y vector. + ::testing::Values( gtint_t(7) ), + // yi: index of extreme value for y. + ::testing::Values( gtint_t(3), gtint_t(29), gtint_t(47) ), + // y_exval: extreme value for y. + ::testing::Values( NaN, Inf, -Inf ) + ), + ::dotvEVTPrint() + ); + +// EVT with negative stride vectors X and Y containing Infs/NaNs. +INSTANTIATE_TEST_SUITE_P( + vecXY_negativeStride, + ddotvEVT, + ::testing::Combine( + // conj(x): user n (no_conjugate) since it is real. + ::testing::Values('n' +#ifdef TEST_BLIS_TYPED + , + 'c' // conjugate option is BLIS-api specific. +#endif + ), + // conj(y): user n (no_conjugate) since it is real. + ::testing::Values('n' +#ifdef TEST_BLIS_TYPED + , + 'c' // conjugate option is BLIS-api specific. +#endif + ), + // m: size of vector. + ::testing::Values( gtint_t(55) ), + // incx: stride of x vector. + ::testing::Values( gtint_t(-3) ), + // xi: index of extreme value for x. + ::testing::Values( gtint_t(1), gtint_t(27), gtint_t(51) ), + // x_exval: extreme value for x. + ::testing::Values( NaN, Inf, -Inf ), + // incy: stride of y vector. + ::testing::Values( gtint_t(-7) ), + // yi: index of extreme value for y. + ::testing::Values( gtint_t(3), gtint_t(29), gtint_t(47) ), + // y_exval: extreme value for y. + ::testing::Values( NaN, Inf, -Inf ) + ), + ::dotvEVTPrint() + ); diff --git a/gtestsuite/testsuite/level1/dotv/ddotv_generic.cpp b/gtestsuite/testsuite/level1/dotv/ddotv_generic.cpp index 5af449fb32..d664e89195 100644 --- a/gtestsuite/testsuite/level1/dotv/ddotv_generic.cpp +++ b/gtestsuite/testsuite/level1/dotv/ddotv_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,7 +35,7 @@ #include #include "test_dotv.h" -class ddotvGenericTest : +class ddotvGeneric : public ::testing::TestWithParam> {}; // Tests using random integers as vector elements. -TEST_P( ddotvGenericTest, RandomData ) +TEST_P( ddotvGeneric, API ) { using T = double; //---------------------------------------------------------- @@ -62,7 +62,14 @@ TEST_P( ddotvGenericTest, RandomData ) gtint_t incy = std::get<4>(GetParam()); // Set the threshold for the errors: - double thresh = n*testinghelpers::getEpsilon(); + // Check gtestsuite dotv.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (n == 0) + thresh = 0.0; + else + thresh = 2*n*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call generic test body using those parameters @@ -70,49 +77,23 @@ TEST_P( ddotvGenericTest, RandomData ) test_dotv( conjx, conjy, n, incx, incy, thresh ); } -// Used to generate a test case with a sensible name. -// Beware that we cannot use fp numbers (e.g., 2.3) in the names, -// so we are only printing int(2.3). This should be enough for debugging purposes. -// If this poses an issue, please reach out. -class ddotvGenericTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char conjx = std::get<0>(str.param); - char conjy = std::get<1>(str.param); - gtint_t n = std::get<2>(str.param); - gtint_t incx = std::get<3>(str.param); - gtint_t incy = std::get<4>(str.param); -#ifdef TEST_BLAS - std::string str_name = "ddot_"; -#elif TEST_CBLAS - std::string str_name = "cblas_ddot"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_ddotv"; -#endif - str_name += "_" + std::to_string(n); - str_name += "_" + std::string(&conjx, 1); - str_name += "_" + std::string(&conjy, 1); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name += "_" + incx_str; - std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); - str_name += "_" + incy_str; - return str_name; - } -}; - -// Black box testing for generic and main use of sdot. +// Black box testing for generic use of ddot. INSTANTIATE_TEST_SUITE_P( - Blackbox, - ddotvGenericTest, + unitPositiveStride, + ddotvGeneric, ::testing::Combine( - ::testing::Values('n'), // n: use x, not conj(x) (since it is real) - ::testing::Values('n'), // n: use y, not conj(y) (since it is real) - ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. - ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(1)) // stride size for y + // conj(x): user n (no_conjugate) since it is real. + ::testing::Values('n'), + // conj(y): user n (no_conjugate) since it is real. + ::testing::Values('n'), + // m: size of vector. + ::testing::Range(gtint_t(10), gtint_t(101), 10), + // incx: stride of x vector. + ::testing::Values(gtint_t(1)), // unit stride + // incy: stride of y vector. + ::testing::Values(gtint_t(1)) // unit stride ), - ::ddotvGenericTestPrint() + ::dotvGenericPrint() ); #ifdef TEST_BLIS_TYPED // BLIS-api specific @@ -121,7 +102,7 @@ INSTANTIATE_TEST_SUITE_P( // We can modify the values using implementantion details. INSTANTIATE_TEST_SUITE_P( ConjX, - ddotvGenericTest, + ddotvGeneric, ::testing::Combine( ::testing::Values('c'), // c: use conj(x) ::testing::Values('c'), // c: use conj(y) @@ -129,7 +110,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(1)), // stride size for x ::testing::Values(gtint_t(1)) // stride size for y ), - ::ddotvGenericTestPrint() + ::dotvGenericPrint() ); #endif @@ -137,16 +118,25 @@ INSTANTIATE_TEST_SUITE_P( // Only test very few cases as sanity check. // We can modify the values using implementantion details. INSTANTIATE_TEST_SUITE_P( - NonUnitPositiveIncrements, - ddotvGenericTest, + nonUnitPositiveStrides, + ddotvGeneric, ::testing::Combine( - ::testing::Values('n'), // use x, not conj(x) (since it is real) - ::testing::Values('n'), // use y, not conj(y) (since it is real) - ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector - ::testing::Values(gtint_t(2), gtint_t(11)), // stride size for x - ::testing::Values(gtint_t(3), gtint_t(33)) // stride size for y + // conj(x): user n (no_conjugate) since it is real. + ::testing::Values('n'), + // conj(y): user n (no_conjugate) since it is real. + ::testing::Values('n'), + // m: size of vector. + ::testing::Range(gtint_t(10), gtint_t(101), 10), + // incx: stride of x vector. + ::testing::Values( + gtint_t(3), gtint_t(7) // few non-unit positive strides for sanity check + ), + // incy: stride of y vector. + ::testing::Values( + gtint_t(3), gtint_t(7) // few non-unit positive strides for sanity check + ) ), - ::ddotvGenericTestPrint() + ::dotvGenericPrint() ); #ifndef TEST_BLIS_TYPED @@ -154,15 +144,55 @@ INSTANTIATE_TEST_SUITE_P( // Only test very few cases as sanity check. // We can modify the values using implementantion details. INSTANTIATE_TEST_SUITE_P( - NegativeIncrements, - ddotvGenericTest, + negativeStrides, + ddotvGeneric, ::testing::Combine( - ::testing::Values('n'), // n: use x, c: use conj(x) - ::testing::Values('n'), // n: use y, c: use conj(y) - ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector - ::testing::Values(gtint_t(-2)), // stride size for x - ::testing::Values(gtint_t(-3)) // stride size for y + // conj(x): user n (no_conjugate) since it is real. + ::testing::Values('n'), + // conj(y): user n (no_conjugate) since it is real. + ::testing::Values('n'), + // m: size of vector. + ::testing::Range(gtint_t(10), gtint_t(101), 10), + // incx: stride of x vector. + ::testing::Values( + gtint_t(-1), gtint_t(-3), gtint_t(-7) // few non-unit negative strides for sanity check + ), + // incy: stride of y vector. + ::testing::Values( + gtint_t(-1), gtint_t(-3), gtint_t(-7) // few non-unit negative strides for sanity check + ) + ), + ::dotvGenericPrint() + ); +#endif + +#if defined(BLIS_ENABLE_OPENMP) && defined(AOCL_DYNAMIC) +INSTANTIATE_TEST_SUITE_P( + AOCLDynamicThresholds, + ddotvGeneric, + ::testing::Combine( + // conj(x): user n (no_conjugate) since it is real. + ::testing::Values('n'), + // conj(y): user n (no_conjugate) since it is real. + ::testing::Values('n'), + // m: size of vector. + ::testing::Values( + gtint_t( 2500), // nt_ideal = 1 + gtint_t( 5000), // nt_ideal = 4 + gtint_t( 15000), // nt_ideal = 8 + gtint_t( 40000), // nt_ideal = 16 + gtint_t(200000), // nt_ideal = 32 + gtint_t(250000) // nt_ideal = max_available + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1) // unit stride + ), + // incy: stride of y vector. + ::testing::Values( + gtint_t(1) // unit stride + ) ), - ::ddotvGenericTestPrint() + ::dotvGenericPrint() ); #endif diff --git a/gtestsuite/testsuite/level1/dotv/dotv.h b/gtestsuite/testsuite/level1/dotv/dotv.h index 7917868e56..a2424dfece 100644 --- a/gtestsuite/testsuite/level1/dotv/dotv.h +++ b/gtestsuite/testsuite/level1/dotv/dotv.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -36,6 +36,7 @@ #include "blis.h" #include "common/testing_helpers.h" +#include "inc/check_error.h" /** * @brief Performs the operation: @@ -53,7 +54,6 @@ template static void dotv_(gtint_t n, T* x, gtint_t incx, T* y, gtint_t incy, T* rho) { - if constexpr (std::is_same::value) *rho = sdot_(&n, x, &incx, y, &incy); else if constexpr (std::is_same::value) @@ -74,19 +74,128 @@ static void dotv_(gtint_t n, T* x, gtint_t incx, T* y, gtint_t incy, T* rho) { throw std::runtime_error("Error in testsuite/level1/dotv.h: Invalid typename in dotv_()."); } +template +static void dotu_(gtint_t n, T* x, gtint_t incx, T* y, gtint_t incy, T* rho) { + if constexpr (std::is_same::value) + #ifdef BLIS_DISABLE_COMPLEX_RETURN_INTEL + *rho = cdotu_(&n, x, &incx, y, &incy); + #else + cdotu_(rho, &n, x, &incx, y, &incy); + #endif + else if constexpr (std::is_same::value) + #ifdef BLIS_DISABLE_COMPLEX_RETURN_INTEL + *rho = zdotu_(&n, x, &incx, y, &incy); + #else + zdotu_(rho, &n, x, &incx, y, &incy); + #endif + else + throw std::runtime_error("Error in testsuite/level1/dotv.h: Invalid typename in dotu_()."); +} + +template +static void dotc_(gtint_t n, T* x, gtint_t incx, T* y, gtint_t incy, T* rho) { + if constexpr (std::is_same::value) + #ifdef BLIS_DISABLE_COMPLEX_RETURN_INTEL + *rho = cdotc_(&n, x, &incx, y, &incy); + #else + cdotc_(rho, &n, x, &incx, y, &incy); + #endif + else if constexpr (std::is_same::value) + #ifdef BLIS_DISABLE_COMPLEX_RETURN_INTEL + *rho = zdotc_(&n, x, &incx, y, &incy); + #else + zdotc_(rho, &n, x, &incx, y, &incy); + #endif + else + throw std::runtime_error("Error in testsuite/level1/dotv.h: Invalid typename in dotc_()."); +} + +template +static void dotv_blis_impl(gtint_t n, T* x, gtint_t incx, T* y, gtint_t incy, T* rho) { + if constexpr (std::is_same::value) + *rho = sdot_blis_impl(&n, x, &incx, y, &incy); + else if constexpr (std::is_same::value) + *rho = ddot_blis_impl( &n, x, &incx, y, &incy ); + else if constexpr (std::is_same::value) + #ifdef BLIS_DISABLE_COMPLEX_RETURN_INTEL + *rho = cdotu_blis_impl(&n, x, &incx, y, &incy); + #else + cdotu_blis_impl(rho, &n, x, &incx, y, &incy); + #endif + else if constexpr (std::is_same::value) + #ifdef BLIS_DISABLE_COMPLEX_RETURN_INTEL + *rho = zdotu_blis_impl(&n, x, &incx, y, &incy); + #else + zdotu_blis_impl(rho, &n, x, &incx, y, &incy); + #endif + else + throw std::runtime_error("Error in testsuite/level1/dotv.h: Invalid typename in dotv_blis_impl()."); +} + +template +static void dotu_blis_impl(gtint_t n, T* x, gtint_t incx, T* y, gtint_t incy, T* rho) { + if constexpr (std::is_same::value) + #ifdef BLIS_DISABLE_COMPLEX_RETURN_INTEL + *rho = cdotu_blis_impl(&n, x, &incx, y, &incy); + #else + cdotu_blis_impl(rho, &n, x, &incx, y, &incy); + #endif + else if constexpr (std::is_same::value) + #ifdef BLIS_DISABLE_COMPLEX_RETURN_INTEL + *rho = zdotu_blis_impl(&n, x, &incx, y, &incy); + #else + zdotu_blis_impl(rho, &n, x, &incx, y, &incy); + #endif + else + throw std::runtime_error("Error in testsuite/level1/dotv.h: Invalid typename in dotu_blis_impl()."); +} + +template +static void dotc_blis_impl(gtint_t n, T* x, gtint_t incx, T* y, gtint_t incy, T* rho) { + if constexpr (std::is_same::value) + #ifdef BLIS_DISABLE_COMPLEX_RETURN_INTEL + *rho = cdotc_blis_impl(&n, x, &incx, y, &incy); + #else + cdotc_blis_impl(rho, &n, x, &incx, y, &incy); + #endif + else if constexpr (std::is_same::value) + #ifdef BLIS_DISABLE_COMPLEX_RETURN_INTEL + *rho = zdotc_blis_impl(&n, x, &incx, y, &incy); + #else + zdotc_blis_impl(rho, &n, x, &incx, y, &incy); + #endif + else + throw std::runtime_error("Error in testsuite/level1/dotv.h: Invalid typename in dotc_blis_impl()."); +} + template static void cblas_dotv(gtint_t n, T* x, gtint_t incx, T* y, gtint_t incy, T* rho) { + if constexpr (std::is_same::value) + *rho = cblas_sdot( n, x, incx, y, incy ); + else if constexpr (std::is_same::value) + *rho = cblas_ddot( n, x, incx, y, incy ); + else + throw std::runtime_error("Error in testsuite/level1/dotv.h: Invalid typename in cblas_dotv()."); +} - if constexpr (std::is_same::value) - *rho = cblas_sdot( n, x, incx, y, incy ); - else if constexpr (std::is_same::value) - *rho = cblas_ddot( n, x, incx, y, incy ); - else if constexpr (std::is_same::value) - cblas_cdotu_sub( n, x, incx, y, incy, rho ); - else if constexpr (std::is_same::value) - cblas_zdotu_sub( n, x, incx, y, incy, rho ); - else - throw std::runtime_error("Error in testsuite/level1/dotv.h: Invalid typename in cblas_dotv()."); +template +static void cblas_dotu(gtint_t n, T* x, gtint_t incx, T* y, gtint_t incy, T* rho) { + if constexpr (std::is_same::value) + cblas_cdotu_sub( n, x, incx, y, incy, rho ); + else if constexpr (std::is_same::value) + cblas_zdotu_sub( n, x, incx, y, incy, rho ); + else + throw std::runtime_error("Error in testsuite/level1/dotv.h: Invalid typename in cblas_dotu()."); +} + +template +static void cblas_dotc(gtint_t n, T* x, gtint_t incx, T* y, gtint_t incy, T* rho) { + if constexpr (std::is_same::value) + cblas_cdotc_sub( n, x, incx, y, incy, rho ); + else if constexpr (std::is_same::value) + cblas_zdotc_sub( n, x, incx, y, incy, rho ); + else + throw std::runtime_error("Error in testsuite/level1/dotv.h: Invalid typename in cblas_dotc()."); } template @@ -113,13 +222,97 @@ template static void dotv(char conjx, char conjy, gtint_t n, T* x, gtint_t incx, T* y, gtint_t incy, T* rho) { + +#ifdef TEST_UPPERCASE_ARGS + conjx = static_cast(std::toupper(static_cast(conjx))); + conjy = static_cast(std::toupper(static_cast(conjy))); +#endif + +#ifdef TEST_INPUT_ARGS + // Create copy of scalar input values so we can check that they are not altered. + char conjx_cpy = conjx; + char conjy_cpy = conjy; + gtint_t n_cpy = n; + gtint_t incx_cpy = incx; + gtint_t incy_cpy = incy; + + // Create copy of input arrays so we can check that they are not altered. + T* x_cpy = nullptr; + gtint_t size_x = testinghelpers::buff_dim( n, incx ); + if (x && size_x > 0) + { + x_cpy = new T[size_x]; + memcpy( x_cpy, x, size_x * sizeof( T ) ); + } + T* y_cpy = nullptr; + gtint_t size_y = testinghelpers::buff_dim( n, incy ); + if (y && size_y > 0) + { + y_cpy = new T[size_y]; + memcpy( y_cpy, y, size_y * sizeof( T ) ); + } +#endif + #ifdef TEST_BLAS - dotv_(n, x, incx, y, incy, rho); + if constexpr ( testinghelpers::type_info::is_real ) + dotv_(n, x, incx, y, incy, rho); + else if constexpr ( testinghelpers::type_info::is_complex ) + { + if ( testinghelpers::chkconj(conjx) ) + dotc_(n, x, incx, y, incy, rho); + else + dotu_(n, x, incx, y, incy, rho); + } +#elif TEST_BLAS_BLIS_IMPL + if constexpr ( testinghelpers::type_info::is_real ) + dotv_blis_impl(n, x, incx, y, incy, rho); + else if constexpr ( testinghelpers::type_info::is_complex ) + { + if ( testinghelpers::chkconj(conjx) ) + dotc_blis_impl(n, x, incx, y, incy, rho); + else + dotu_blis_impl(n, x, incx, y, incy, rho); + } #elif TEST_CBLAS - cblas_dotv(n, x, incx, y, incy, rho); + if constexpr ( testinghelpers::type_info::is_real ) + cblas_dotv(n, x, incx, y, incy, rho); + else if constexpr ( testinghelpers::type_info::is_complex ) + { + if ( testinghelpers::chkconj(conjx) ) + cblas_dotc(n, x, incx, y, incy, rho); + else + cblas_dotu(n, x, incx, y, incy, rho); + } #elif TEST_BLIS_TYPED typed_dotv(conjx, conjy, n, x, incx, y, incy, rho); #else throw std::runtime_error("Error in testsuite/level1/dotv.h: No interfaces are set to be tested."); #endif + +#ifdef TEST_INPUT_ARGS + //---------------------------------------------------------- + // Check scalar inputs have not been modified. + //---------------------------------------------------------- + + computediff( "conjx", conjx, conjx_cpy ); + computediff( "conjy", conjy, conjy_cpy ); + computediff( "n", n, n_cpy ); + computediff( "incx", incx, incx_cpy ); + computediff( "incy", incy, incy_cpy ); + + //---------------------------------------------------------- + // Bitwise-wise check array inputs have not been modified. + //---------------------------------------------------------- + + if (x && size_x > 0) + { + computediff( "x", n, x, x_cpy, incx, true ); + delete[] x_cpy; + } + if (y && size_y > 0) + { + computediff( "y", n, y, y_cpy, incy, true ); + delete[] y_cpy; + } +#endif } diff --git a/gtestsuite/testsuite/level1/dotv/dotv_IIT_ERS.cpp b/gtestsuite/testsuite/level1/dotv/dotv_IIT_ERS.cpp new file mode 100644 index 0000000000..f8a3739d8e --- /dev/null +++ b/gtestsuite/testsuite/level1/dotv/dotv_IIT_ERS.cpp @@ -0,0 +1,172 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_dotv.h" +#include "common/wrong_inputs_helpers.h" +#include "common/testing_helpers.h" +#include "inc/check_error.h" + +template +class dotv_IIT_ERS : public ::testing::Test {}; +typedef ::testing::Types TypeParam; +TYPED_TEST_SUITE(dotv_IIT_ERS, TypeParam); + +using namespace testinghelpers::IIT; + +#if defined(TEST_BLAS_LIKE) || defined(TEST_CBLAS) + +/* + BLAS Early Return Scenarios(ERS): + + DOTV is expected to return early in the following cases: + 1. n <= 0 +*/ + +// n < 0, with non-unit stride +TYPED_TEST(dotv_IIT_ERS, n_lt_zero_nonUnitStride) +{ + using T = TypeParam; + gtint_t invalid_n = -1; + gtint_t inc = 5; + // Initialize rho (BLIS output) to garbage value. + T rho = T{-7.3}; + // Initialize the expected output to zero. + T rho_ref; + testinghelpers::initzero(rho_ref); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + dotv( CONJ, CONJ, invalid_n, nullptr, inc, nullptr, inc, &rho ); + // Computing the difference. + computediff( "rho", rho, rho_ref ); + + // Test with all arguments correct except for the value we are choosing to test. + // Initialize vectors with random numbers. + std::vector x = testinghelpers::get_random_vector( -10, 10, N, inc ); + std::vector y = testinghelpers::get_random_vector( -10, 10, N, inc ); + + // Invoking DOTV with an invalid value of n. + dotv( CONJ, CONJ, invalid_n, x.data(), inc, y.data(), inc, &rho ); + + // Computing the difference. + computediff( "rho", rho, rho_ref ); +} + +// n == 0, with non-unit stride +TYPED_TEST(dotv_IIT_ERS, n_eq_zero_nonUnitStride) +{ + using T = TypeParam; + gtint_t invalid_n = 0; + gtint_t inc = 5; + // Initialize rho (BLIS output) to garbage value. + T rho = T{-7.3}; + // Initialize the expected output to zero. + T rho_ref; + testinghelpers::initzero(rho_ref); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + dotv( CONJ, CONJ, invalid_n, nullptr, inc, nullptr, inc, &rho ); + // Computing the difference. + computediff( "rho", rho, rho_ref ); + + // Test with all arguments correct except for the value we are choosing to test. + // Initialize vectors with random numbers. + std::vector x = testinghelpers::get_random_vector( -10, 10, N, inc ); + std::vector y = testinghelpers::get_random_vector( -10, 10, N, inc ); + + // Invoking DOTV with an invalid value of n. + dotv( CONJ, CONJ, invalid_n, x.data(), inc, y.data(), inc, &rho ); + + // Computing the difference. + computediff( "rho", rho, rho_ref ); +} + +// n < 0, with unit stride +TYPED_TEST(dotv_IIT_ERS, n_lt_zero_unitStride) +{ + using T = TypeParam; + gtint_t invalid_n = -1; + gtint_t unit_inc = 1; + // Initialize rho (BLIS output) to garbage value. + T rho = T{-7.3}; + // Initialize the expected output to zero. + T rho_ref; + testinghelpers::initzero(rho_ref); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + dotv( CONJ, CONJ, invalid_n, nullptr, unit_inc, nullptr, unit_inc, &rho ); + // Computing the difference. + computediff( "rho", rho, rho_ref ); + + // Test with all arguments correct except for the value we are choosing to test. + // Initialize vectors with random numbers. + std::vector x = testinghelpers::get_random_vector( -10, 10, N, unit_inc ); + std::vector y = testinghelpers::get_random_vector( -10, 10, N, unit_inc ); + + // Invoking DOTV with an invalid value of n. + dotv( CONJ, CONJ, invalid_n, x.data(), unit_inc, y.data(), unit_inc, &rho ); + + // Computing the difference. + computediff( "rho", rho, rho_ref ); +} + +// n == 0, with unit stride +TYPED_TEST(dotv_IIT_ERS, n_eq_zero_unitStride) +{ + using T = TypeParam; + gtint_t invalid_n = 0; + gtint_t unit_inc = 1; + // Initialize rho (BLIS output) to garbage value. + T rho = T{-7.3}; + // Initialize the expected output to zero. + T rho_ref; + testinghelpers::initzero(rho_ref); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + dotv( CONJ, CONJ, invalid_n, nullptr, unit_inc, nullptr, unit_inc, &rho ); + // Computing the difference. + computediff( "rho", rho, rho_ref ); + + // Test with all arguments correct except for the value we are choosing to test. + // Initialize vectors with random numbers. + std::vector x = testinghelpers::get_random_vector( -10, 10, N, unit_inc ); + std::vector y = testinghelpers::get_random_vector( -10, 10, N, unit_inc ); + + // Invoking DOTV with an invalid value of n. + dotv( CONJ, CONJ, invalid_n, x.data(), unit_inc, y.data(), unit_inc, &rho ); + + // Computing the difference. + computediff( "rho", rho, rho_ref ); +} +#endif diff --git a/gtestsuite/testsuite/level1/dotv/sdotv_generic.cpp b/gtestsuite/testsuite/level1/dotv/sdotv_generic.cpp index 9d69ac6e7a..3ef5f7ba7f 100644 --- a/gtestsuite/testsuite/level1/dotv/sdotv_generic.cpp +++ b/gtestsuite/testsuite/level1/dotv/sdotv_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,7 +35,7 @@ #include #include "test_dotv.h" -class sdotvGenericTest : +class sdotvGeneric : public ::testing::TestWithParam> {}; // Tests using random integers as vector elements. -TEST_P( sdotvGenericTest, RandomData ) +TEST_P( sdotvGeneric, API ) { using T = float; //---------------------------------------------------------- @@ -62,7 +62,14 @@ TEST_P( sdotvGenericTest, RandomData ) gtint_t incy = std::get<4>(GetParam()); // Set the threshold for the errors: - double thresh = n*testinghelpers::getEpsilon(); + // Check gtestsuite dotv.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (n == 0) + thresh = 0.0; + else + thresh = 2*n*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call generic test body using those parameters @@ -70,41 +77,10 @@ TEST_P( sdotvGenericTest, RandomData ) test_dotv( conjx, conjy, n, incx, incy, thresh ); } -// Used to generate a test case with a sensible name. -// Beware that we cannot use fp numbers (e.g., 2.3) in the names, -// so we are only printing int(2.3). This should be enough for debugging purposes. -// If this poses an issue, please reach out. -class sdotvGenericTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char conjx = std::get<0>(str.param); - char conjy = std::get<1>(str.param); - gtint_t n = std::get<2>(str.param); - gtint_t incx = std::get<3>(str.param); - gtint_t incy = std::get<4>(str.param); -#ifdef TEST_BLAS - std::string str_name = "sdot_"; -#elif TEST_CBLAS - std::string str_name = "cblas_sdot"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_sdotv"; -#endif - str_name += "_" + std::to_string(n); - str_name += "_" + std::string(&conjx, 1); - str_name += "_" + std::string(&conjy, 1); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name += "_" + incx_str; - std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); - str_name += "_" + incy_str; - return str_name; - } -}; - // Black box testing for generic and main use of sdotv. INSTANTIATE_TEST_SUITE_P( Blackbox, - sdotvGenericTest, + sdotvGeneric, ::testing::Combine( ::testing::Values('n'), // n: use x, not conj(x) (since it is real) ::testing::Values('n'), // n: use y, not conj(y) (since it is real) @@ -112,7 +88,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(1)), // stride size for x ::testing::Values(gtint_t(1)) // stride size for y ), - ::sdotvGenericTestPrint() + ::dotvGenericPrint() ); #ifdef TEST_BLIS_TYPED // BLIS-api specific @@ -121,7 +97,7 @@ INSTANTIATE_TEST_SUITE_P( // We can modify the values using implementantion details. INSTANTIATE_TEST_SUITE_P( ConjX, - sdotvGenericTest, + sdotvGeneric, ::testing::Combine( ::testing::Values('c'), // c: use conj(x) ::testing::Values('c'), // c: use conj(y) @@ -129,7 +105,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(1)), // stride size for x ::testing::Values(gtint_t(1)) // stride size for y ), - ::sdotvGenericTestPrint() + ::dotvGenericPrint() ); #endif @@ -138,7 +114,7 @@ INSTANTIATE_TEST_SUITE_P( // We can modify the values using implementantion details. INSTANTIATE_TEST_SUITE_P( NonUnitPositiveIncrements, - sdotvGenericTest, + sdotvGeneric, ::testing::Combine( ::testing::Values('n'), // n: use x, not conj(x) (since it is real) ::testing::Values('n'), // n: use y, not conj(y) (since it is real) @@ -146,7 +122,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(2), gtint_t(11)), // stride size for x ::testing::Values(gtint_t(3), gtint_t(33)) // stride size for y ), - ::sdotvGenericTestPrint() + ::dotvGenericPrint() ); #ifndef TEST_BLIS_TYPED @@ -155,7 +131,7 @@ INSTANTIATE_TEST_SUITE_P( // We can modify the values using implementantion details. INSTANTIATE_TEST_SUITE_P( NegativeIncrements, - sdotvGenericTest, + sdotvGeneric, ::testing::Combine( ::testing::Values('n'), // n: use x, c: use conj(x) ::testing::Values('n'), // n: use y, c: use conj(y) @@ -163,6 +139,6 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(-2)), // stride size for x ::testing::Values(gtint_t(-3)) // stride size for y ), - ::sdotvGenericTestPrint() + ::dotvGenericPrint() ); #endif diff --git a/gtestsuite/testsuite/level1/dotv/test_dotv.h b/gtestsuite/testsuite/level1/dotv/test_dotv.h index 3f9610f7da..d0864853cc 100644 --- a/gtestsuite/testsuite/level1/dotv/test_dotv.h +++ b/gtestsuite/testsuite/level1/dotv/test_dotv.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -72,5 +72,104 @@ static void test_dotv( char conjx, char conjy, gtint_t n, gtint_t incx, //---------------------------------------------------------- // Compute error. //---------------------------------------------------------- - computediff( rho, rho_ref, thresh ); + computediff( "rho", rho, rho_ref, thresh ); } + + +/** + * @brief Used to insert Exception Values in vectors x and y. + */ +template +static void test_dotv( char conjx, char conjy, gtint_t n, + gtint_t incx, gtint_t xi, double x_exval, + gtint_t incy, gtint_t yi, double y_exval, + double thresh ) +{ + //---------------------------------------------------------- + // Initialize vectors with random numbers. + //---------------------------------------------------------- + std::vector x = testinghelpers::get_random_vector( -10, 10, n, incx ); + std::vector y = testinghelpers::get_random_vector( -10, 10, n, incy ); + + // Update the value at index xi to an extreme value, x_exval. + if ( -1 < xi && xi < n ) x[xi * abs(incx)] = x_exval; + else return; + + // Update the value at index yi to an extreme value, y_exval. + if ( -1 < yi && yi < n ) y[yi * abs(incy)] = y_exval; + else return; + + //---------------------------------------------------------- + // Call reference implementation to get ref results. + //---------------------------------------------------------- + // Create a copy of y so that we can check reference results. + std::vector y_ref(y); + T rho_ref; + if constexpr (testinghelpers::type_info::is_real) + testinghelpers::ref_dotv( n, x.data(), incx, y_ref.data(), incy, &rho_ref ); + else + testinghelpers::ref_dotv( conjx, conjy, n, x.data(), incx, y_ref.data(), incy, &rho_ref ); + + //---------------------------------------------------------- + // Call BLIS function. + //---------------------------------------------------------- + T rho; + dotv( conjx, conjy, n, x.data(), incx, y.data(), incy, &rho ); + + //---------------------------------------------------------- + // Compute error. + //---------------------------------------------------------- + computediff( "rho", rho, rho_ref, thresh, true); +} + + +// Test-case logger : Used to print the test-case details based on parameters +class dotvGenericPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char conjx = std::get<0>(str.param); + char conjy = std::get<1>(str.param); + gtint_t n = std::get<2>(str.param); + gtint_t incx = std::get<3>(str.param); + gtint_t incy = std::get<4>(str.param); + + std::string str_name = API_PRINT; + str_name += "_n_" + std::to_string(n); + str_name += "_conjx_" + std::string(&conjx, 1); + str_name += "_conjy_" + std::string(&conjy, 1); + str_name += "_incx_" + testinghelpers::get_value_string(incx); + str_name += "_incy_" + testinghelpers::get_value_string(incy); + return str_name; + } +}; + +template +class dotvEVTPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char conjx = std::get<0>(str.param); + char conjy = std::get<1>(str.param); + gtint_t n = std::get<2>(str.param); + gtint_t incx = std::get<3>(str.param); + gtint_t xi = std::get<4>(str.param); + T x_exval = std::get<5>(str.param); + gtint_t incy = std::get<6>(str.param); + gtint_t yi = std::get<7>(str.param); + T y_exval = std::get<8>(str.param); + + std::string str_name = API_PRINT; + str_name += "_n_" + std::to_string(n); + str_name += "_conjx_" + std::string(&conjx, 1); + str_name += "_conjy_" + std::string(&conjy, 1); + str_name += "_incx_" + testinghelpers::get_value_string(incx); + str_name = str_name + "_X_" + std::to_string(xi); + str_name = str_name + "_" + testinghelpers::get_value_string(x_exval); + str_name += "_incy_" + testinghelpers::get_value_string(incy); + str_name = str_name + "_Y_" + std::to_string(yi); + str_name = str_name + "_" + testinghelpers::get_value_string(y_exval); + + return str_name; + } +}; diff --git a/gtestsuite/testsuite/level1/dotv/zdotv_generic.cpp b/gtestsuite/testsuite/level1/dotv/zdotv_generic.cpp index 7d7d3aabd0..82d3aabeae 100644 --- a/gtestsuite/testsuite/level1/dotv/zdotv_generic.cpp +++ b/gtestsuite/testsuite/level1/dotv/zdotv_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,7 +35,7 @@ #include #include "test_dotv.h" -class zdotvGenericTest : +class zdotvGeneric : public ::testing::TestWithParam> {}; // Tests using random integers as vector elements. -TEST_P( zdotvGenericTest, RandomData ) +TEST_P( zdotvGeneric, API ) { using T = dcomplex; //---------------------------------------------------------- @@ -62,7 +62,15 @@ TEST_P( zdotvGenericTest, RandomData ) gtint_t incy = std::get<4>(GetParam()); // Set the threshold for the errors: - double thresh = 2*n*testinghelpers::getEpsilon(); + // Check gtestsuite dotv.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (n == 0) + thresh = 0.0; + else + thresh = 2*n*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call generic test body using those parameters @@ -70,57 +78,22 @@ TEST_P( zdotvGenericTest, RandomData ) test_dotv( conjx, conjy, n, incx, incy, thresh ); } -// Used to generate a test case with a sensible name. -// Beware that we cannot use fp numbers (e.g., 2.3) in the names, -// so we are only printing int(2.3). This should be enough for debugging purposes. -// If this poses an issue, please reach out. -class zdotvGenericTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char conjx = std::get<0>(str.param); - char conjy = std::get<1>(str.param); - gtint_t n = std::get<2>(str.param); - gtint_t incx = std::get<3>(str.param); - gtint_t incy = std::get<4>(str.param); -#ifdef TEST_BLAS - std::string str_name = "zdotu_"; -#elif TEST_CBLAS - std::string str_name = "cblas_zdotu_sub"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_zdotv"; -#endif - str_name += "_" + std::to_string(n); - str_name += "_" + std::string(&conjx, 1); - str_name += "_" + std::string(&conjy, 1); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name += "_" + incx_str; - std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); - str_name += "_" + incy_str; - return str_name; - } -}; - // Black box testing for generic and main use of zdot. INSTANTIATE_TEST_SUITE_P( Blackbox, - zdotvGenericTest, + zdotvGeneric, ::testing::Combine( + ::testing::Values('n', 'c'), // 'n': tests zdotu_, 'c': tests zdotc_ ::testing::Values('n' #ifdef TEST_BLIS_TYPED - , 'c' // this option is BLIS-api specific. + , 'c' // this option is BLIS-api specific. #endif - ), // n: use x, c: use conj(x) - ::testing::Values('n' -#ifdef TEST_BLIS_TYPED - , 'c' // this option is BLIS-api specific. -#endif - ), // n: use y, c: use conj(y) - ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. - ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(1)) // stride size for y + ), // n: use y, c: use conj(y) + ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)) // stride size for y ), - ::zdotvGenericTestPrint() + ::dotvGenericPrint() ); // Test for non-unit increments. @@ -128,23 +101,19 @@ INSTANTIATE_TEST_SUITE_P( // We can modify the values using implementantion details. INSTANTIATE_TEST_SUITE_P( NonUnitPositiveIncrements, - zdotvGenericTest, + zdotvGeneric, ::testing::Combine( + ::testing::Values('n', 'c'), // 'n': tests zdotu_, 'c': tests zdotc_ ::testing::Values('n' #ifdef TEST_BLIS_TYPED - , 'c' // this option is BLIS-api specific. -#endif - ), // n: use x, c: use conj(x) - ::testing::Values('n' -#ifdef TEST_BLIS_TYPED - , 'c' // this option is BLIS-api specific. + , 'c' // this option is BLIS-api specific. #endif - ), // n: use y, c: use conj(y) - ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector - ::testing::Values(gtint_t(2), gtint_t(11)), // stride size for x - ::testing::Values(gtint_t(3), gtint_t(33)) // stride size for y + ), // n: use y, c: use conj(y) + ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector + ::testing::Values(gtint_t(2), gtint_t(11)), // stride size for x + ::testing::Values(gtint_t(3), gtint_t(33)) // stride size for y ), - ::zdotvGenericTestPrint() + ::dotvGenericPrint() ); #ifndef TEST_BLIS_TYPED @@ -153,14 +122,45 @@ INSTANTIATE_TEST_SUITE_P( // We can modify the values using implementantion details. INSTANTIATE_TEST_SUITE_P( NegativeIncrements, - zdotvGenericTest, + zdotvGeneric, + ::testing::Combine( + ::testing::Values('n', 'c'), // 'n': tests zdotu_, 'c': tests zdotc_ + ::testing::Values('n'), // n: use y, c: use conj(y) + ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector + ::testing::Values(gtint_t(-2)), // stride size for x + ::testing::Values(gtint_t(-3)) // stride size for y + ), + ::dotvGenericPrint() + ); +#endif + +#if defined(BLIS_ENABLE_OPENMP) && defined(AOCL_DYNAMIC) +INSTANTIATE_TEST_SUITE_P( + AOCLDynamicThresholds, + zdotvGeneric, ::testing::Combine( - ::testing::Values('n'), // n: use x, c: use conj(x) - ::testing::Values('n'), // n: use y, c: use conj(y) - ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector - ::testing::Values(gtint_t(-2)), // stride size for x - ::testing::Values(gtint_t(-3)) // stride size for y + // conj(x): user n (no_conjugate) since it is real. + ::testing::Values('n', 'c'), + // conj(y): user n (no_conjugate) since it is real. + ::testing::Values('n'), + // m: size of vector. + ::testing::Values( + gtint_t( 2080), // nt_ideal = 1 + gtint_t( 3328), // nt_ideal = 4 + gtint_t( 98304), // nt_ideal = 8 + gtint_t(262144), // nt_ideal = 32 + gtint_t(524288), // nt_ideal = 64 + gtint_t(550000) // nt_ideal = max_available + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1) // unit stride + ), + // incy: stride of y vector. + ::testing::Values( + gtint_t(1) // unit stride + ) ), - ::zdotvGenericTestPrint() + ::dotvGenericPrint() ); #endif diff --git a/gtestsuite/testsuite/level1/dotxf/ddotxf_generic.cpp b/gtestsuite/testsuite/level1/dotxf/ddotxf_generic.cpp new file mode 100644 index 0000000000..71589d317b --- /dev/null +++ b/gtestsuite/testsuite/level1/dotxf/ddotxf_generic.cpp @@ -0,0 +1,147 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_dotxf.h" + +class ddotxfGeneric : + public ::testing::TestWithParam> {}; +// Tests using random integers as vector elements. +TEST_P( ddotxfGeneric, API ) +{ + using T = double; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes whether x or conj(x) will be used + char conj_x = std::get<0>(GetParam()); + // denotes whether A or conj(A) will be used + char conj_a = std::get<1>(GetParam()); + // matrix size m + gtint_t m = std::get<2>(GetParam()); + // matrix size n + gtint_t b = std::get<3>(GetParam()); + // alpha + T alpha = std::get<4>(GetParam()); + // lda increment for A + gtint_t lda_inc = std::get<5>(GetParam()); + // stride size for A + gtint_t inca = std::get<6>(GetParam()); + // stride size for x + gtint_t incx = std::get<7>(GetParam()); + // beta + T beta = std::get<8>(GetParam()); + // stride size for y + gtint_t incy = std::get<9>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite dotxf.h (no netlib version) for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + // Threshold adjustment + if (m == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + if (beta == testinghelpers::ZERO() || beta == testinghelpers::ONE()) + thresh = 0.0; + else + thresh = testinghelpers::getEpsilon(); + else if (alpha == testinghelpers::ONE()) + if (beta == testinghelpers::ZERO()) + thresh = (m)*testinghelpers::getEpsilon(); + else if (beta == testinghelpers::ONE()) + { +#ifdef BLIS_INT_ELEMENT_TYPE + double adj = 1.0; +#else + double adj = 3.9; +#endif + thresh = adj*(m+1)*testinghelpers::getEpsilon(); + } + else + thresh = (m+2)*testinghelpers::getEpsilon(); + else + if (beta == testinghelpers::ZERO()) + thresh = (2*m)*testinghelpers::getEpsilon(); + else if (beta == testinghelpers::ONE()) + { +#ifdef BLIS_INT_ELEMENT_TYPE + double adj = 1.0; +#else + double adj = 5.2; +#endif + thresh = adj*(2*m+1)*testinghelpers::getEpsilon(); + } + else + { + thresh = (2*m+2)*testinghelpers::getEpsilon(); + } + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_dotxf( conj_x, conj_a, m, b, &alpha, inca, lda_inc, incx, &beta, incy, thresh ); +} + +// Black box testing for generic and main use of ddotxf. +INSTANTIATE_TEST_SUITE_P( + FunctionalTest, + ddotxfGeneric, + ::testing::Combine( + ::testing::Values('n'), // n: use x, not conj(x) (since it is real) + ::testing::Values('n'), // n: use x, not conj(x) (since it is real) + ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of matrix + ::testing::Range(gtint_t(6), gtint_t(10), 1), // b size of matrix + ::testing::Values(double(0.0), double(1.0), double(2.3)), // alpha + ::testing::Values(gtint_t(0)), // lda increment + ::testing::Values(gtint_t(1)), // stride size for a + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(double(1.0)), // beta + ::testing::Values(gtint_t(1)) // stride size for y + ), + ::dotxfGenericPrint() + ); + diff --git a/gtestsuite/testsuite/level1/dotxf/dotxf.h b/gtestsuite/testsuite/level1/dotxf/dotxf.h new file mode 100644 index 0000000000..3e5bba3c22 --- /dev/null +++ b/gtestsuite/testsuite/level1/dotxf/dotxf.h @@ -0,0 +1,179 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#pragma once + +#include "blis.h" +#include "common/testing_helpers.h" +#include "inc/check_error.h" + +template +static void typed_dotxf( + char conj_a, + char conj_x, + gtint_t m, + gtint_t b, + T *alpha, + T* A, + gtint_t inca, + gtint_t lda, + T* x, + gtint_t incx, + T *beta, + T* y, + gtint_t incy) +{ + conj_t conja; + conj_t conjx; + // Map parameter characters to BLIS constants. + testinghelpers::char_to_blis_conj( conj_a, &conja ); + testinghelpers::char_to_blis_conj( conj_x, &conjx ); + if constexpr (std::is_same::value) + bli_sdotxf(conja, conjx, m, b, alpha, A, inca, lda, x, incx, beta, y, incy); + else if constexpr (std::is_same::value) + bli_ddotxf( conja, conjx, m, b, alpha, A, inca, lda, x, incx, beta, y, incy ); + else if constexpr (std::is_same::value) + bli_cdotxf( conja, conjx, m, b, alpha, A, inca, lda, x, incx, beta, y, incy ); + else if constexpr (std::is_same::value) + bli_zdotxf( conja, conjx, m, b, alpha, A, inca, lda, x, incx, beta, y, incy ); + else + throw std::runtime_error("Error in testsuite/level1/dotv.h: Invalid typename in typed_dotv()."); +} + +template +static void dotxf( + char conj_a, + char conj_x, + gtint_t m, + gtint_t b, + T *alpha, + T* A, + gtint_t inca, + gtint_t lda, + T* x, + gtint_t incx, + T *beta, + T* y, + gtint_t incy +) +{ + +#ifdef TEST_UPPERCASE_ARGS + conj_a = static_cast(std::toupper(static_cast(conj_a))); + conj_x = static_cast(std::toupper(static_cast(conj_x))); +#endif + +#ifdef TEST_INPUT_ARGS + // Create copy of scalar input values so we can check that they are not altered. + char conj_a_cpy = conj_a; + char conj_x_cpy = conj_x; + gtint_t m_cpy = m; + gtint_t b_cpy = b; + T* alpha_cpy = alpha; + gtint_t inca_cpy = inca; + gtint_t lda_cpy = lda; + gtint_t incx_cpy = incx; + T* beta_cpy = beta; + gtint_t incy_cpy = incy; + + // Create copy of input arrays so we can check that they are not altered. + T* A_cpy = nullptr; + gtint_t size_A = testinghelpers::matsize( 'c', 'n', m, b, lda ); + if (A && size_A > 0) + { + A_cpy = new T[size_A]; + memcpy( A_cpy, A, size_A * sizeof( T ) ); + } + T* x_cpy = nullptr; + gtint_t size_x = testinghelpers::buff_dim( m, incx ); + if (x && size_x > 0) + { + x_cpy = new T[size_x]; + memcpy( x_cpy, x, size_x * sizeof( T ) ); + } +#endif + +/** + * dotxf operation is defined as : + * y := beta * y + alpha * conja(A) * conjx(x) + * where A is an m x b matrix, and y and x are vectors. + */ + typed_dotxf( + conj_a, + conj_x, + m, + b, + alpha, + A, + inca, + lda, + x, + incx, + beta, + y, + incy ); + +#ifdef TEST_INPUT_ARGS + //---------------------------------------------------------- + // Check scalar inputs have not been modified. + //---------------------------------------------------------- + + computediff( "conj_a", conj_a, conj_a_cpy ); + computediff( "conj_x", conj_x, conj_x_cpy ); + computediff( "m", m, m_cpy ); + computediff( "b", b, b_cpy ); + if (alpha) computediff( "alpha", *alpha, *alpha_cpy ); + computediff( "inca", inca, inca_cpy ); + computediff( "lda", lda, lda_cpy ); + computediff( "incx", incx, incx_cpy ); + if (beta) computediff( "beta", *beta, *beta_cpy ); + computediff( "incy", incy, incy_cpy ); + + //---------------------------------------------------------- + // Bitwise-wise check array inputs have not been modified. + //---------------------------------------------------------- + + if (A && size_A > 0) + { + computediff( "A", 'c', m, b, A, A_cpy, lda, true ); + delete[] A_cpy; + } + + if (x && size_x > 0) + { + computediff( "x", m, x, x_cpy, incx, true ); + delete[] x_cpy; + } +#endif +} diff --git a/gtestsuite/testsuite/level1/dotxf/test_dotxf.h b/gtestsuite/testsuite/level1/dotxf/test_dotxf.h new file mode 100644 index 0000000000..ff0024a575 --- /dev/null +++ b/gtestsuite/testsuite/level1/dotxf/test_dotxf.h @@ -0,0 +1,128 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#pragma once + +#include "dotxf.h" +#include "level1/ref_dotxf.h" +#include "inc/check_error.h" + + +template +static void test_dotxf( + char conj_a, + char conj_x, + gtint_t m, + gtint_t b, + T *alpha, + gtint_t inca, + gtint_t lda_inc, + gtint_t incx, + T *beta, + gtint_t incy, + double thresh + ) +{ + //---------------------------------------------------------- + // Initialize vectors with random numbers. + //---------------------------------------------------------- + + // Compute the leading dimensions of a, b, and c. + gtint_t lda = testinghelpers::get_leading_dimension( 'c', 'n', m, b, lda_inc ); + + //---------------------------------------------------------- + // Initialize matrics with random numbers + //---------------------------------------------------------- + std::vector A = testinghelpers::get_random_matrix( -2, 8, 'c', 'n', m, b, lda ); + + std::vector x = testinghelpers::get_random_vector( -10, 10, m, incx ); + std::vector y = testinghelpers::get_random_vector( -10, 10, b, incy ); + + //---------------------------------------------------------- + // Call reference implementation to get ref results. + //---------------------------------------------------------- + // Create a copy of y so that we can check reference results. + std::vector y_ref(y); + + testinghelpers::ref_dotxf( conj_a, conj_x, m, b, alpha, A.data(), inca, lda, x.data(), incx, beta, y_ref.data(), incy ); + + //---------------------------------------------------------- + // Call BLIS function. + //---------------------------------------------------------- + dotxf( conj_a, conj_x, m, b, alpha, A.data(), inca, lda, x.data(), incx, beta, y.data(), incy ); + + //--------------------------------------------------------- + // Compute component-wise error. + //---------------------------------------------------------- + computediff( "y", b, y.data(), y_ref.data(), incy, thresh, true ); +} + + +// Test-case logger : Used to print the test-case details +template +class dotxfGenericPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char conja = std::get<0>(str.param); + char conjx = std::get<1>(str.param); + gtint_t m = std::get<2>(str.param); + gtint_t b = std::get<3>(str.param); + T alpha = std::get<4>(str.param); + gtint_t incx = std::get<7>(str.param); + T beta = std::get<8>(str.param); + gtint_t incy = std::get<9>(str.param); + + std::string str_name = "bli_"; + + str_name += "_conja_" + std::string(&conja, 1); + str_name += "_conjx_" + std::string(&conjx, 1); + str_name += "_m_" + std::to_string(m); + str_name += "_b_" + std::to_string(b); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + str_name += "_beta_" + testinghelpers::get_value_string(beta); + str_name += "_incx_" + testinghelpers::get_value_string(incx); + str_name += "_incy_" + testinghelpers::get_value_string(incy); + return str_name; + } +}; diff --git a/gtestsuite/testsuite/level1/dotxv/cdotxv_generic.cpp b/gtestsuite/testsuite/level1/dotxv/cdotxv_generic.cpp index 5ed6f67d96..6acdfef72d 100644 --- a/gtestsuite/testsuite/level1/dotxv/cdotxv_generic.cpp +++ b/gtestsuite/testsuite/level1/dotxv/cdotxv_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,13 +35,13 @@ #include #include "test_dotxv.h" -class cdotxvGenericTest : +class cdotxvGeneric : public ::testing::TestWithParam> {}; -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(cdotxvGenericTest); +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(cdotxvGeneric); // Tests using random integers as vector elements. -TEST_P( cdotxvGenericTest, RandomData ) +TEST_P( cdotxvGeneric, API ) { using T = scomplex; //---------------------------------------------------------- @@ -64,7 +64,40 @@ TEST_P( cdotxvGenericTest, RandomData ) T beta = std::get<6>(GetParam()); // Set the threshold for the errors: - double thresh = n*testinghelpers::getEpsilon(); + // Check gtestsuite dotxv.h (no netlib version) for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + { + // Like SCALV (for one element) + if (beta == testinghelpers::ZERO() || beta == testinghelpers::ONE()) + thresh = 0.0; + else + thresh = testinghelpers::getEpsilon(); + } + else if (beta == testinghelpers::ZERO()) + { + // Like DOTV but with alpha scaling + if (alpha == testinghelpers::ONE()) + thresh = (2*n)*testinghelpers::getEpsilon(); + else + thresh = (3*n)*testinghelpers::getEpsilon(); + } + else if (beta == testinghelpers::ONE()) + { + if (alpha == testinghelpers::ONE()) + thresh = (2*n+1)*testinghelpers::getEpsilon(); + else + thresh = (3*n+1)*testinghelpers::getEpsilon(); + } + else if (alpha == testinghelpers::ONE()) + thresh = (2*n+2)*testinghelpers::getEpsilon(); + else + thresh = (3*n+2)*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call generic test body using those parameters @@ -72,44 +105,11 @@ TEST_P( cdotxvGenericTest, RandomData ) test_dotxv( n, conj_x, conj_y, alpha, incx, incy, beta, thresh ); } -// Used to generate a test case with a sensible name. -// Beware that we cannot use fp numbers (e.g., 2.3) in the names, -// so we are only printing int(2.3). This should be enough for debugging purposes. -// If this poses an issue, please reach out. -class cdotxvGenericTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - gtint_t n = std::get<0>(str.param); - char conjx = std::get<1>(str.param); - char conjy = std::get<2>(str.param); - gtint_t incx = std::get<3>(str.param); - gtint_t incy = std::get<4>(str.param); - scomplex alpha = std::get<5>(str.param); - scomplex beta = std::get<6>(str.param); - std::string str_name = "bli_cdotxv"; - str_name += "_" + std::to_string(n); - str_name += "_" + std::string(&conjx, 1); - str_name += "_" + std::string(&conjy, 1); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name += "_" + incx_str; - std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); - str_name += "_" + incy_str; - std::string alpha_str = ( alpha.real > 0) ? std::to_string(int(alpha.real)) : ("m" + std::to_string(int(std::abs(alpha.real)))); - alpha_str = alpha_str + "pi" + (( alpha.imag > 0) ? std::to_string(int(alpha.imag)) : ("m" + std::to_string(int(std::abs(alpha.imag))))); - std::string beta_str = ( beta.real > 0) ? std::to_string(int(beta.real)) : ("m" + std::to_string(int(std::abs(beta.real)))); - beta_str = beta_str + "pi" + (( beta.imag > 0) ? std::to_string(int(beta.imag)) : ("m" + std::to_string(int(std::abs(beta.imag))))); - str_name = str_name + "_a" + alpha_str; - str_name = str_name + "_b" + beta_str; - return str_name; - } -}; - #ifdef TEST_BLIS_TYPED // Black box testing for generic and main use of cdotxv. INSTANTIATE_TEST_SUITE_P( Blackbox, - cdotxvGenericTest, + cdotxvGeneric, ::testing::Combine( ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. ::testing::Values('n', 'c'), // n: use x, c: use conj(x) @@ -119,13 +119,13 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(scomplex{1.0, -1.0}), // alpha ::testing::Values(scomplex{-1.0, 1.0}) // beta ), - ::cdotxvGenericTestPrint() + ::dotxvGenericPrint() ); // Black box testing for generic and main use of cdotxv. INSTANTIATE_TEST_SUITE_P( SmallSizesBlackbox, - cdotxvGenericTest, + cdotxvGeneric, ::testing::Combine( ::testing::Range(gtint_t(1), gtint_t(11), 1), // m size of vector takes values from 10 to 100 with step size of 10. ::testing::Values('n', 'c'), // n: use x, c: use conj(x) @@ -135,7 +135,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(scomplex{1.0, -1.0}), // alpha ::testing::Values(scomplex{-1.0, 1.0}) // beta ), - ::cdotxvGenericTestPrint() + ::dotxvGenericPrint() ); // Test for non-unit increments. @@ -143,7 +143,7 @@ INSTANTIATE_TEST_SUITE_P( // We can modify the values using implementantion details. INSTANTIATE_TEST_SUITE_P( NonUnitIncrements, - cdotxvGenericTest, + cdotxvGeneric, ::testing::Combine( ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector ::testing::Values('n', 'c'), // n: use x, c: use conj(x) @@ -153,6 +153,6 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(scomplex{1.0, -1.0}), // alpha ::testing::Values(scomplex{-1.0, 1.0}) // beta ), - ::cdotxvGenericTestPrint() + ::dotxvGenericPrint() ); #endif diff --git a/gtestsuite/testsuite/level1/dotxv/ddotxv_generic.cpp b/gtestsuite/testsuite/level1/dotxv/ddotxv_generic.cpp index 75376ed4b9..16fef2c28a 100644 --- a/gtestsuite/testsuite/level1/dotxv/ddotxv_generic.cpp +++ b/gtestsuite/testsuite/level1/dotxv/ddotxv_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,13 +35,13 @@ #include #include "test_dotxv.h" -class ddotxvGenericTest : +class ddotxvGeneric : public ::testing::TestWithParam> {}; -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ddotxvGenericTest); +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ddotxvGeneric); // Tests using random integers as vector elements. -TEST_P( ddotxvGenericTest, RandomData ) +TEST_P( ddotxvGeneric, API ) { using T = double; //---------------------------------------------------------- @@ -64,7 +64,39 @@ TEST_P( ddotxvGenericTest, RandomData ) T beta = std::get<6>(GetParam()); // Set the threshold for the errors: - double thresh = n*testinghelpers::getEpsilon(); + // Check gtestsuite dotxv.h (no netlib version) for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + { + // Like SCALV (for one element) + if (beta == testinghelpers::ZERO() || beta == testinghelpers::ONE()) + thresh = 0.0; + else + thresh = testinghelpers::getEpsilon(); + } + else if (beta == testinghelpers::ZERO()) + { + // Like DOTV but with alpha scaling + if (alpha == testinghelpers::ONE()) + thresh = (2*n)*testinghelpers::getEpsilon(); + else + thresh = (3*n)*testinghelpers::getEpsilon(); + } + else if (beta == testinghelpers::ONE()) + { + if (alpha == testinghelpers::ONE()) + thresh = (2*n+1)*testinghelpers::getEpsilon(); + else + thresh = (3*n+1)*testinghelpers::getEpsilon(); + } + else if (alpha == testinghelpers::ONE()) + thresh = (2*n+2)*testinghelpers::getEpsilon(); + else + thresh = (3*n+2)*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call generic test body using those parameters @@ -72,42 +104,11 @@ TEST_P( ddotxvGenericTest, RandomData ) test_dotxv(n, conj_x, conj_y, alpha, incx, incy, beta, thresh ); } -// Used to generate a test case with a sensible name. -// Beware that we cannot use fp numbers (e.g., 2.3) in the names, -// so we are only printing int(2.3). This should be enough for debugging purposes. -// If this poses an issue, please reach out. -class ddotxvGenericTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - gtint_t n = std::get<0>(str.param); - char conjx = std::get<1>(str.param); - char conjy = std::get<2>(str.param); - gtint_t incx = std::get<3>(str.param); - gtint_t incy = std::get<4>(str.param); - double alpha = std::get<5>(str.param); - double beta = std::get<6>(str.param); - std::string str_name = "bli_ddotxv"; - str_name += "_" + std::to_string(n); - str_name += "_" + std::string(&conjx, 1); - str_name += "_" + std::string(&conjy, 1); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name += "_" + incx_str; - std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); - str_name += "_" + incy_str; - std::string alpha_str = ( alpha > 0) ? std::to_string(int(alpha)) : "m" + std::to_string(int(std::abs(alpha))); - str_name = str_name + "_a" + alpha_str; - std::string beta_str = ( beta > 0) ? std::to_string(int(beta)) : "m" + std::to_string(int(std::abs(beta))); - str_name = str_name + "_b" + beta_str; - return str_name; - } -}; - #ifdef TEST_BLIS_TYPED // Black box testing for generic and main use of ddotxv. INSTANTIATE_TEST_SUITE_P( Blackbox, - ddotxvGenericTest, + ddotxvGeneric, ::testing::Combine( ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. ::testing::Values('n'), // n: use x, not conj(x) (since it is real) @@ -117,7 +118,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(1.0, 2.0), // alpha ::testing::Values(2.0, 3.0) // beta ), - ::ddotxvGenericTestPrint() + ::dotxvGenericPrint() ); // Test when conjugate of x is used as an argument. @@ -125,7 +126,7 @@ INSTANTIATE_TEST_SUITE_P( // We can modify the values using implementantion details. INSTANTIATE_TEST_SUITE_P( ConjX, - ddotxvGenericTest, + ddotxvGeneric, ::testing::Combine( ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector ::testing::Values('c'), // use x, not conj(x) (since it is real) @@ -135,7 +136,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(1.0, 2.0), // alpha ::testing::Values(2.0, 3.0) // beta ), - ::ddotxvGenericTestPrint() + ::dotxvGenericPrint() ); // Test for non-unit increments. @@ -143,7 +144,7 @@ INSTANTIATE_TEST_SUITE_P( // We can modify the values using implementantion details. INSTANTIATE_TEST_SUITE_P( NonUnitIncrements, - ddotxvGenericTest, + ddotxvGeneric, ::testing::Combine( ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector ::testing::Values('n'), // use x, not conj(x) (since it is real) @@ -153,6 +154,6 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(1.0, 2.0), // alpha ::testing::Values(2.0, 3.0) // beta ), - ::ddotxvGenericTestPrint() + ::dotxvGenericPrint() ); #endif diff --git a/gtestsuite/testsuite/level1/dotxv/dotxv.h b/gtestsuite/testsuite/level1/dotxv/dotxv.h index 3bb01ad0a0..c6092637e3 100644 --- a/gtestsuite/testsuite/level1/dotxv/dotxv.h +++ b/gtestsuite/testsuite/level1/dotxv/dotxv.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -36,6 +36,7 @@ #include "blis.h" #include "common/testing_helpers.h" +#include "inc/check_error.h" /** * @brief Performs the operation: @@ -76,8 +77,36 @@ template static void dotxv( char conjx, char conjy, gtint_t n, T* alpha, T* x, gtint_t incx, T* y, gtint_t incy, T* beta, T* rho ) { + +#ifdef TEST_UPPERCASE_ARGS + conjx = static_cast(std::toupper(static_cast(conjx))); + conjy = static_cast(std::toupper(static_cast(conjy))); +#endif + +#ifdef TEST_INPUT_ARGS + // Create copy of scalar input values so we can check that they are not altered. + char conjx_cpy = conjx; + char conjy_cpy = conjy; + gtint_t n_cpy = n; + T* alpha_cpy = alpha; + gtint_t incx_cpy = incx; + gtint_t incy_cpy = incy; + T* beta_cpy = beta; + + // Create copy of input arrays so we can check that they are not altered. + T* x_cpy = nullptr; + gtint_t size_x = testinghelpers::buff_dim( n, incx ); + if (x && size_x > 0) + { + x_cpy = new T[size_x]; + memcpy( x_cpy, x, size_x * sizeof( T ) ); + } +#endif + #ifdef TEST_BLAS throw std::runtime_error("Error in testsuite/level1/dotxv.h: BLAS interface is not available."); +#elif TEST_BLAS_BLIS_IMPL + throw std::runtime_error("Error in testsuite/level1/dotxv.h: BLAS_BLIS_IMPL interface is not available."); #elif TEST_CBLAS throw std::runtime_error("Error in testsuite/level1/dotxv.h: CBLAS interface is not available."); #elif TEST_BLIS_TYPED @@ -85,4 +114,28 @@ static void dotxv( char conjx, char conjy, gtint_t n, T* alpha, #else throw std::runtime_error("Error in testsuite/level1/dotxv.h: No interfaces are set to be tested."); #endif + +#ifdef TEST_INPUT_ARGS + //---------------------------------------------------------- + // Check scalar inputs have not been modified. + //---------------------------------------------------------- + + computediff( "conjx", conjx, conjx_cpy ); + computediff( "conjy", conjy, conjy_cpy ); + computediff( "n", n, n_cpy ); + if (alpha) computediff( "alpha", *alpha, *alpha_cpy ); + computediff( "incx", incx, incx_cpy ); + computediff( "incy", incy, incy_cpy ); + if (beta) computediff( "beta", *beta, *beta_cpy ); + + //---------------------------------------------------------- + // Bitwise-wise check array inputs have not been modified. + //---------------------------------------------------------- + + if (x && size_x > 0) + { + computediff( "x", n, x, x_cpy, incx, true ); + delete[] x_cpy; + } +#endif } diff --git a/gtestsuite/testsuite/level1/dotxv/sdotxv_generic.cpp b/gtestsuite/testsuite/level1/dotxv/sdotxv_generic.cpp index 9ee47c18a7..35568778eb 100644 --- a/gtestsuite/testsuite/level1/dotxv/sdotxv_generic.cpp +++ b/gtestsuite/testsuite/level1/dotxv/sdotxv_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,13 +35,13 @@ #include #include "test_dotxv.h" -class sdotxvGenericTest : +class sdotxvGeneric : public ::testing::TestWithParam> {}; -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(sdotxvGenericTest); +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(sdotxvGeneric); // Tests using random integers as vector elements. -TEST_P( sdotxvGenericTest, RandomData ) +TEST_P( sdotxvGeneric, API ) { using T = float; //---------------------------------------------------------- @@ -64,7 +64,39 @@ TEST_P( sdotxvGenericTest, RandomData ) T beta = std::get<6>(GetParam()); // Set the threshold for the errors: - double thresh = n*testinghelpers::getEpsilon(); + // Check gtestsuite dotxv.h (no netlib version) for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + { + // Like SCALV (for one element) + if (beta == testinghelpers::ZERO() || beta == testinghelpers::ONE()) + thresh = 0.0; + else + thresh = testinghelpers::getEpsilon(); + } + else if (beta == testinghelpers::ZERO()) + { + // Like DOTV but with alpha scaling + if (alpha == testinghelpers::ONE()) + thresh = (2*n)*testinghelpers::getEpsilon(); + else + thresh = (3*n)*testinghelpers::getEpsilon(); + } + else if (beta == testinghelpers::ONE()) + { + if (alpha == testinghelpers::ONE()) + thresh = (2*n+1)*testinghelpers::getEpsilon(); + else + thresh = (3*n+1)*testinghelpers::getEpsilon(); + } + else if (alpha == testinghelpers::ONE()) + thresh = (2*n+2)*testinghelpers::getEpsilon(); + else + thresh = (3*n+2)*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call generic test body using those parameters @@ -72,42 +104,11 @@ TEST_P( sdotxvGenericTest, RandomData ) test_dotxv( n, conj_x, conj_y, alpha, incx, incy, beta, thresh ); } -// Used to generate a test case with a sensible name. -// Beware that we cannot use fp numbers (e.g., 2.3) in the names, -// so we are only printing int(2.3). This should be enough for debugging purposes. -// If this poses an issue, please reach out. -class sdotxvGenericTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - gtint_t n = std::get<0>(str.param); - char conjx = std::get<1>(str.param); - char conjy = std::get<2>(str.param); - gtint_t incx = std::get<3>(str.param); - gtint_t incy = std::get<4>(str.param); - float alpha = std::get<5>(str.param); - float beta = std::get<6>(str.param); - std::string str_name = "bli_sdotxv"; - str_name += "_" + std::to_string(n); - str_name += "_" + std::string(&conjx, 1); - str_name += "_" + std::string(&conjy, 1); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name += "_" + incx_str; - std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); - str_name += "_" + incy_str; - std::string alpha_str = ( alpha > 0) ? std::to_string(int(alpha)) : "m" + std::to_string(int(std::abs(alpha))); - str_name = str_name + "_a" + alpha_str; - std::string beta_str = ( beta > 0) ? std::to_string(int(beta)) : "m" + std::to_string(int(std::abs(beta))); - str_name = str_name + "_b" + beta_str; - return str_name; - } -}; - #ifdef TEST_BLIS_TYPED // Black box testing for generic and main use of sdotxv. INSTANTIATE_TEST_SUITE_P( Blackbox, - sdotxvGenericTest, + sdotxvGeneric, ::testing::Combine( ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. ::testing::Values('n'), // n: use x, not conj(x) (since it is real) @@ -117,7 +118,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(1.0, 2.0), // alpha ::testing::Values(2.0, 3.0) // beta ), - ::sdotxvGenericTestPrint() + ::dotxvGenericPrint() ); // Test when conjugate of x is used as an argument. @@ -125,7 +126,7 @@ INSTANTIATE_TEST_SUITE_P( // We can modify the values using implementantion details. INSTANTIATE_TEST_SUITE_P( ConjX, - sdotxvGenericTest, + sdotxvGeneric, ::testing::Combine( ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector ::testing::Values('c'), // c: use conj(x) @@ -135,7 +136,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(1.0, 2.0), // alpha ::testing::Values(2.0, 3.0) // beta ), - ::sdotxvGenericTestPrint() + ::dotxvGenericPrint() ); // Test for non-unit increments. @@ -143,7 +144,7 @@ INSTANTIATE_TEST_SUITE_P( // We can modify the values using implementantion details. INSTANTIATE_TEST_SUITE_P( NonUnitIncrements, - sdotxvGenericTest, + sdotxvGeneric, ::testing::Combine( ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector ::testing::Values('n'), // n: use x, not conj(x) (since it is real) @@ -153,6 +154,6 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(1.0, 2.0), // alpha ::testing::Values(2.0, 3.0) // beta ), - ::sdotxvGenericTestPrint() + ::dotxvGenericPrint() ); #endif diff --git a/gtestsuite/testsuite/level1/dotxv/test_dotxv.h b/gtestsuite/testsuite/level1/dotxv/test_dotxv.h index 729e172b8f..a885d92ab5 100644 --- a/gtestsuite/testsuite/level1/dotxv/test_dotxv.h +++ b/gtestsuite/testsuite/level1/dotxv/test_dotxv.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -71,5 +71,31 @@ static void test_dotxv( gtint_t n, char conjx, char conjy, T alpha, //---------------------------------------------------------- // Compute error. //---------------------------------------------------------- - computediff( rho, rho_ref, thresh ); + computediff( "rho", rho, rho_ref, thresh ); } + +// Test-case logger : Used to print the test-case details based on parameters +template +class dotxvGenericPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + gtint_t n = std::get<0>(str.param); + char conjx = std::get<1>(str.param); + char conjy = std::get<2>(str.param); + gtint_t incx = std::get<3>(str.param); + gtint_t incy = std::get<4>(str.param); + T alpha = std::get<5>(str.param); + T beta = std::get<6>(str.param); + + std::string str_name = API_PRINT; + str_name += "_n_" + std::to_string(n); + str_name += "_conjx_" + std::string(&conjx, 1); + str_name += "_conjy_" + std::string(&conjy, 1); + str_name += "_incx_" + testinghelpers::get_value_string(incx); + str_name += "_incy_" + testinghelpers::get_value_string(incy); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + str_name += "_beta_" + testinghelpers::get_value_string(beta); + return str_name; + } +}; diff --git a/gtestsuite/testsuite/level1/dotxv/zdotxv_generic.cpp b/gtestsuite/testsuite/level1/dotxv/zdotxv_generic.cpp index 10bfcac45f..4245225a0f 100644 --- a/gtestsuite/testsuite/level1/dotxv/zdotxv_generic.cpp +++ b/gtestsuite/testsuite/level1/dotxv/zdotxv_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,13 +35,13 @@ #include #include "test_dotxv.h" -class zdotxvGenericTest : +class zdotxvGeneric : public ::testing::TestWithParam> {}; -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(zdotxvGenericTest); +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(zdotxvGeneric); // Tests using random integers as vector elements. -TEST_P( zdotxvGenericTest, RandomData ) +TEST_P( zdotxvGeneric, API ) { using T = dcomplex; //---------------------------------------------------------- @@ -64,7 +64,40 @@ TEST_P( zdotxvGenericTest, RandomData ) T beta = std::get<6>(GetParam()); // Set the threshold for the errors: - double thresh = n*testinghelpers::getEpsilon(); + // Check gtestsuite dotxv.h (no netlib version) for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + { + // Like SCALV (for one element) + if (beta == testinghelpers::ZERO() || beta == testinghelpers::ONE()) + thresh = 0.0; + else + thresh = testinghelpers::getEpsilon(); + } + else if (beta == testinghelpers::ZERO()) + { + // Like DOTV but with alpha scaling + if (alpha == testinghelpers::ONE()) + thresh = (2*n)*testinghelpers::getEpsilon(); + else + thresh = (3*n)*testinghelpers::getEpsilon(); + } + else if (beta == testinghelpers::ONE()) + { + if (alpha == testinghelpers::ONE()) + thresh = (2*n+1)*testinghelpers::getEpsilon(); + else + thresh = (3*n+1)*testinghelpers::getEpsilon(); + } + else if (alpha == testinghelpers::ONE()) + thresh = (2*n+2)*testinghelpers::getEpsilon(); + else + thresh = (3*n+2)*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call generic test body using those parameters @@ -72,44 +105,11 @@ TEST_P( zdotxvGenericTest, RandomData ) test_dotxv(n, conj_x, conj_y, alpha, incx, incy, beta, thresh ); } -// Used to generate a test case with a sensible name. -// Beware that we cannot use fp numbers (e.g., 2.3) in the names, -// so we are only printing int(2.3). This should be enough for debugging purposes. -// If this poses an issue, please reach out. -class zdotxvGenericTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - gtint_t n = std::get<0>(str.param); - char conjx = std::get<1>(str.param); - char conjy = std::get<2>(str.param); - gtint_t incx = std::get<3>(str.param); - gtint_t incy = std::get<4>(str.param); - dcomplex alpha = std::get<5>(str.param); - dcomplex beta = std::get<6>(str.param); - std::string str_name = "bli_zdotxv"; - str_name += "_" + std::to_string(n); - str_name += "_" + std::string(&conjx, 1); - str_name += "_" + std::string(&conjy, 1); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name += "_" + incx_str; - std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); - str_name += "_" + incy_str; - std::string alpha_str = ( alpha.real > 0) ? std::to_string(int(alpha.real)) : ("m" + std::to_string(int(std::abs(alpha.real)))); - alpha_str = alpha_str + "pi" + (( alpha.imag > 0) ? std::to_string(int(alpha.imag)) : ("m" + std::to_string(int(std::abs(alpha.imag))))); - std::string beta_str = ( beta.real > 0) ? std::to_string(int(beta.real)) : ("m" + std::to_string(int(std::abs(beta.real)))); - beta_str = beta_str + "pi" + (( beta.imag > 0) ? std::to_string(int(beta.imag)) : ("m" + std::to_string(int(std::abs(beta.imag))))); - str_name = str_name + "_a" + alpha_str; - str_name = str_name + "_b" + beta_str; - return str_name; - } -}; - #ifdef TEST_BLIS_TYPED // Black box testing for generic and main use of zdotxv. INSTANTIATE_TEST_SUITE_P( Blackbox, - zdotxvGenericTest, + zdotxvGeneric, ::testing::Combine( ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. ::testing::Values('n', 'c'), // n: use x, c: use conj(x) @@ -119,7 +119,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(dcomplex{1.0, -1.0}), // alpha ::testing::Values(dcomplex{-1.0, 1.0}) // beta ), - ::zdotxvGenericTestPrint() + ::dotxvGenericPrint() ); // Test for non-unit increments. @@ -127,7 +127,7 @@ INSTANTIATE_TEST_SUITE_P( // We can modify the values using implementantion details. INSTANTIATE_TEST_SUITE_P( NonUnitIncrements, - zdotxvGenericTest, + zdotxvGeneric, ::testing::Combine( ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector ::testing::Values('n', 'c'), // n: use x, c: use conj(x) @@ -137,6 +137,6 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(dcomplex{1.0, -1.0}), // alpha ::testing::Values(dcomplex{-1.0, 1.0}) // beta ), - ::zdotxvGenericTestPrint() + ::dotxvGenericPrint() ); #endif diff --git a/gtestsuite/testsuite/level1/scal2v/cscal2v_generic.cpp b/gtestsuite/testsuite/level1/scal2v/cscal2v_generic.cpp index e9c1d53189..075ff8e114 100644 --- a/gtestsuite/testsuite/level1/scal2v/cscal2v_generic.cpp +++ b/gtestsuite/testsuite/level1/scal2v/cscal2v_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,17 +35,17 @@ #include #include "test_scal2v.h" -class cscal2vGenericTest : +class cscal2vGeneric : public ::testing::TestWithParam> {}; -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(cscal2vGenericTest); +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(cscal2vGeneric); // Tests using random integers as vector elements. -TEST_P( cscal2vGenericTest, RandomData ) +TEST_P( cscal2vGeneric, API ) { using T = scomplex; //---------------------------------------------------------- @@ -64,44 +64,29 @@ TEST_P( cscal2vGenericTest, RandomData ) T alpha = std::get<4>(GetParam()); // Set the threshold for the errors: - double thresh = testinghelpers::getEpsilon(); + // Check gtestsuite dotxv.h (no netlib version) for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO() || alpha == testinghelpers::ONE()) + thresh = 0.0; + else + thresh = testinghelpers::getEpsilon(); + //---------------------------------------------------------- // Call generic test body using those parameters //---------------------------------------------------------- test_scal2v( conj_alpha, n, incx, incy, alpha, thresh ); } -// Used to generate a test case with a sensible name. -// Beware that we cannot use fp numbers (e.g., 2.3) in the names, -// so we are only printing int(2.3). This should be enough for debugging purposes. -// If this poses an issue, please reach out. -class cscal2vGenericTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char conj = std::get<0>(str.param); - gtint_t n = std::get<1>(str.param); - gtint_t incx = std::get<2>(str.param); - gtint_t incy = std::get<3>(str.param); - scomplex alpha = std::get<4>(str.param); - std::string str_name = "bli_cscal2v"; - str_name += "_" + std::to_string(n); - str_name += "_" + std::string(&conj, 1); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name += "_" + incx_str; - std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); - str_name += "_" + incy_str; - std::string alpha_str = ( alpha.real > 0) ? std::to_string(int(alpha.real)) : ("m" + std::to_string(int(std::abs(alpha.real)))); - alpha_str = alpha_str + "pi" + (( alpha.imag > 0) ? std::to_string(int(alpha.imag)) : ("m" + std::to_string(int(std::abs(alpha.imag))))); - str_name = str_name + "_a" + alpha_str; - return str_name; - } -}; #ifdef TEST_BLIS_TYPED // Black box testing for generic and main use of cscal2. INSTANTIATE_TEST_SUITE_P( Blackbox, - cscal2vGenericTest, + cscal2vGeneric, ::testing::Combine( ::testing::Values('n','c'), // n: use x, c: use conj(x) ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. @@ -109,7 +94,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(1)), // stride size for y ::testing::Values(scomplex{2.0, -1.0}, scomplex{-2.0, 3.0}) // alpha ), - ::cscal2vGenericTestPrint() + ::scal2vGenericPrint() ); @@ -118,7 +103,7 @@ INSTANTIATE_TEST_SUITE_P( // We can modify the values using implementantion details. INSTANTIATE_TEST_SUITE_P( NonUnitIncrements, - cscal2vGenericTest, + cscal2vGeneric, ::testing::Combine( ::testing::Values('n','c'), // n: use x, c: use conj(x) ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. @@ -126,6 +111,6 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(4)), // stride size for y ::testing::Values(scomplex{4.0, 3.1}) // alpha ), - ::cscal2vGenericTestPrint() + ::scal2vGenericPrint() ); #endif diff --git a/gtestsuite/testsuite/level1/scal2v/dscal2v_generic.cpp b/gtestsuite/testsuite/level1/scal2v/dscal2v_generic.cpp index 66b624c382..dbde70eaf1 100644 --- a/gtestsuite/testsuite/level1/scal2v/dscal2v_generic.cpp +++ b/gtestsuite/testsuite/level1/scal2v/dscal2v_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,17 +35,17 @@ #include #include "test_scal2v.h" -class dscal2vGenericTest : +class dscal2vGeneric : public ::testing::TestWithParam> {}; -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(dscal2vGenericTest); +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(dscal2vGeneric); // Tests using random integers as vector elements. -TEST_P( dscal2vGenericTest, RandomData ) +TEST_P( dscal2vGeneric, API ) { using T = double; //---------------------------------------------------------- @@ -64,43 +64,28 @@ TEST_P( dscal2vGenericTest, RandomData ) T alpha = std::get<4>(GetParam()); // Set the threshold for the errors: - float thresh = testinghelpers::getEpsilon(); + // Check gtestsuite dotxv.h (no netlib version) for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO() || alpha == testinghelpers::ONE()) + thresh = 0.0; + else + thresh = testinghelpers::getEpsilon(); + //---------------------------------------------------------- // Call generic test body using those parameters //---------------------------------------------------------- test_scal2v( conj_alpha, n, incx, incy, alpha, thresh ); } -// Used to generate a test case with a sensible name. -// Beware that we cannot use fp numbers (e.g., 2.3) in the names, -// so we are only printing int(2.3). This should be enough for debugging purposes. -// If this poses an issue, please reach out. -class dscal2vGenericTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char conj = std::get<0>(str.param); - gtint_t n = std::get<1>(str.param); - gtint_t incx = std::get<2>(str.param); - gtint_t incy = std::get<3>(str.param); - double alpha = std::get<4>(str.param); - std::string str_name = "bli_dscal2v"; - str_name += "_" + std::to_string(n); - str_name += "_" + std::string(&conj, 1); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name += "_" + incx_str; - std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); - str_name += "_" + incy_str; - std::string alpha_str = ( alpha > 0) ? std::to_string(int(alpha)) : "m" + std::to_string(int(std::abs(alpha))); - str_name = str_name + "_a" + alpha_str; - return str_name; - } -}; #ifdef TEST_BLIS_TYPED // Black box testing for generic and main use of dscal2. INSTANTIATE_TEST_SUITE_P( Blackbox, - dscal2vGenericTest, + dscal2vGeneric, ::testing::Combine( ::testing::Values('n'), // n: use x, not conj(x) (since it is real) ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. @@ -108,7 +93,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(1)), // stride size for y ::testing::Values(double(2.0), double(-3.0)) // alpha ), - ::dscal2vGenericTestPrint() + ::scal2vGenericPrint() ); // Test when conjugate of x is used as an argument. This option is BLIS-api specific. @@ -116,7 +101,7 @@ INSTANTIATE_TEST_SUITE_P( // We can modify the values using implementantion details. INSTANTIATE_TEST_SUITE_P( Conjalpha, - dscal2vGenericTest, + dscal2vGeneric, ::testing::Combine( ::testing::Values('c'), // c: use conjugate ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector takes values from 10 to 100 with step size of 10. @@ -124,7 +109,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(1)), // stride size for y ::testing::Values(double(-3.0)) // alpha ), - ::dscal2vGenericTestPrint() + ::scal2vGenericPrint() ); // Test for non-unit increments. @@ -132,7 +117,7 @@ INSTANTIATE_TEST_SUITE_P( // We can modify the values using implementantion details. INSTANTIATE_TEST_SUITE_P( NonUnitIncrements, - dscal2vGenericTest, + dscal2vGeneric, ::testing::Combine( ::testing::Values('n'), // n: use x ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector takes values from 10 to 100 with step size of 10. @@ -140,6 +125,6 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(5)), // stride size for y ::testing::Values(double(3.0)) // alpha ), - ::dscal2vGenericTestPrint() + ::scal2vGenericPrint() ); #endif diff --git a/gtestsuite/testsuite/level1/scal2v/scal2v.h b/gtestsuite/testsuite/level1/scal2v/scal2v.h index ad1383b712..1afe6ac546 100644 --- a/gtestsuite/testsuite/level1/scal2v/scal2v.h +++ b/gtestsuite/testsuite/level1/scal2v/scal2v.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -36,6 +36,7 @@ #include "blis.h" #include "common/testing_helpers.h" +#include "inc/check_error.h" /** * @brief Performs the operation: @@ -71,13 +72,60 @@ static void typed_scal2v(char conj_x, gtint_t n, T alpha, T* x, gtint_t incx, T* template static void scal2v(char conjx, gtint_t n, T alpha, T* x, gtint_t incx, T* y, gtint_t incy) { + +#ifdef TEST_UPPERCASE_ARGS + conjx = static_cast(std::toupper(static_cast(conjx))); +#endif + +#ifdef TEST_INPUT_ARGS + // Create copy of scalar input values so we can check that they are not altered. + char conjx_cpy = conjx; + gtint_t n_cpy = n; + T alpha_cpy = alpha; + gtint_t incx_cpy = incx; + gtint_t incy_cpy = incy; + + // Create copy of input arrays so we can check that they are not altered. + T* x_cpy = nullptr; + gtint_t size_x = testinghelpers::buff_dim( n, incx ); + if (x && size_x > 0) + { + x_cpy = new T[size_x]; + memcpy( x_cpy, x, size_x * sizeof( T ) ); + } +#endif + #ifdef TEST_BLAS throw std::runtime_error("Error in testsuite/level1/scal2v.h: BLAS interface is not available."); +#elif TEST_BLAS_BLIS_IMPL + throw std::runtime_error("Error in testsuite/level1/scal2v.h: BLAS_BLIS_IMPL interface is not available."); #elif TEST_CBLAS - throw std::runtime_error("Error in testsuite/level1/scal2v.h: BLAS interface is not available."); + throw std::runtime_error("Error in testsuite/level1/scal2v.h: CBLAS interface is not available."); #elif TEST_BLIS_TYPED typed_scal2v( conjx, n, alpha, x, incx, y, incy ); #else throw std::runtime_error("Error in testsuite/level1/scal2v.h: No interfaces are set to be tested."); #endif + +#ifdef TEST_INPUT_ARGS + //---------------------------------------------------------- + // Check scalar inputs have not been modified. + //---------------------------------------------------------- + + computediff( "conjx", conjx, conjx_cpy ); + computediff( "n", n, n_cpy ); + computediff( "alpha", alpha, alpha_cpy ); + computediff( "incx", incx, incx_cpy ); + computediff( "incy", incy, incy_cpy ); + + //---------------------------------------------------------- + // Bitwise-wise check array inputs have not been modified. + //---------------------------------------------------------- + + if (x && size_x > 0) + { + computediff( "x", n, x, x_cpy, incx, true ); + delete[] x_cpy; + } +#endif } diff --git a/gtestsuite/testsuite/level1/scal2v/sscal2v_generic.cpp b/gtestsuite/testsuite/level1/scal2v/sscal2v_generic.cpp index 366d649ead..67bef674f3 100644 --- a/gtestsuite/testsuite/level1/scal2v/sscal2v_generic.cpp +++ b/gtestsuite/testsuite/level1/scal2v/sscal2v_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,17 +35,17 @@ #include #include "test_scal2v.h" -class sscal2vGenericTest : +class sscal2vGeneric : public ::testing::TestWithParam> {}; -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(sscal2vGenericTest); +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(sscal2vGeneric); // Tests using random integers as vector elements. -TEST_P( sscal2vGenericTest, RandomData ) +TEST_P( sscal2vGeneric, API ) { using T = float; //---------------------------------------------------------- @@ -64,44 +64,28 @@ TEST_P( sscal2vGenericTest, RandomData ) T alpha = std::get<4>(GetParam()); // Set the threshold for the errors: - float thresh = testinghelpers::getEpsilon(); + // Check gtestsuite dotxv.h (no netlib version) for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO() || alpha == testinghelpers::ONE()) + thresh = 0.0; + else + thresh = testinghelpers::getEpsilon(); + //---------------------------------------------------------- // Call generic test body using those parameters //---------------------------------------------------------- test_scal2v( conj_alpha, n, incx, incy, alpha, thresh ); } -// Used to generate a test case with a sensible name. -// Beware that we cannot use fp numbers (e.g., 2.3) in the names, -// so we are only printing int(2.3). This should be enough for debugging purposes. -// If this poses an issue, please reach out. -class sscal2vGenericTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char conj = std::get<0>(str.param); - gtint_t n = std::get<1>(str.param); - gtint_t incx = std::get<2>(str.param); - gtint_t incy = std::get<3>(str.param); - float alpha = std::get<4>(str.param); - std::string str_name = "bli_sscal2v"; - str_name += "_" + std::to_string(n); - str_name += "_" + std::string(&conj, 1); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name += "_" + incx_str; - std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); - str_name += "_" + incy_str; - std::string alpha_str = ( alpha > 0) ? std::to_string(int(alpha)) : "m" + std::to_string(int(std::abs(alpha))); - str_name = str_name + "_a" + alpha_str; - return str_name; - } -}; - #ifdef TEST_BLIS_TYPED // Black box testing for generic and main use of sscal2. INSTANTIATE_TEST_SUITE_P( Blackbox, - sscal2vGenericTest, + sscal2vGeneric, ::testing::Combine( ::testing::Values('n'), // n: use x, not conj(x) (since it is real) ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. @@ -109,7 +93,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(1)), // stride size for y ::testing::Values(float(3.0), float(-5.0)) // alpha ), - ::sscal2vGenericTestPrint() + ::scal2vGenericPrint() ); // Test when conjugate of x is used as an argument. This option is BLIS-api specific. @@ -117,7 +101,7 @@ INSTANTIATE_TEST_SUITE_P( // We can modify the values using implementantion details. INSTANTIATE_TEST_SUITE_P( Conjalpha, - sscal2vGenericTest, + sscal2vGeneric, ::testing::Combine( ::testing::Values('c'), // c: use conjugate ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector takes values from 10 to 100 with step size of 10. @@ -125,7 +109,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(1)), // stride size for y ::testing::Values(float(9.0)) // alpha ), - ::sscal2vGenericTestPrint() + ::scal2vGenericPrint() ); // Test for non-unit increments. @@ -133,7 +117,7 @@ INSTANTIATE_TEST_SUITE_P( // We can modify the values using implementantion details. INSTANTIATE_TEST_SUITE_P( NonUnitIncrements, - sscal2vGenericTest, + sscal2vGeneric, ::testing::Combine( ::testing::Values('n'), // n: use x ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector takes values from 10 to 100 with step size of 10. @@ -141,6 +125,6 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(7)), // stride size for y ::testing::Values(float(2.0)) // alpha ), - ::sscal2vGenericTestPrint() + ::scal2vGenericPrint() ); #endif diff --git a/gtestsuite/testsuite/level1/scal2v/test_scal2v.h b/gtestsuite/testsuite/level1/scal2v/test_scal2v.h index 9cb621acb6..8be02dc619 100644 --- a/gtestsuite/testsuite/level1/scal2v/test_scal2v.h +++ b/gtestsuite/testsuite/level1/scal2v/test_scal2v.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -66,5 +66,27 @@ static void test_scal2v(char conjx, gtint_t n, gtint_t incx, gtint_t incy, T alp //---------------------------------------------------------- // Compute component-wise error. //---------------------------------------------------------- - computediff( n, y.data(), y_ref.data(), incy, thresh ); + computediff( "y", n, y.data(), y_ref.data(), incy, thresh ); } + +// Test-case logger : Used to print the test-case details based on parameters +template +class scal2vGenericPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char conjx = std::get<0>(str.param); + gtint_t n = std::get<1>(str.param); + gtint_t incx = std::get<2>(str.param); + gtint_t incy = std::get<3>(str.param); + T alpha = std::get<4>(str.param); + + std::string str_name = API_PRINT; + str_name += "_n_" + std::to_string(n); + str_name += "_conjx_" + std::string(&conjx, 1); + str_name += "_incx_" + testinghelpers::get_value_string(incx); + str_name += "_incy_" + testinghelpers::get_value_string(incy); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + return str_name; + } +}; diff --git a/gtestsuite/testsuite/level1/scal2v/zscal2v_generic.cpp b/gtestsuite/testsuite/level1/scal2v/zscal2v_generic.cpp index 5c413192d6..2249ce4a08 100644 --- a/gtestsuite/testsuite/level1/scal2v/zscal2v_generic.cpp +++ b/gtestsuite/testsuite/level1/scal2v/zscal2v_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,18 +35,18 @@ #include #include "test_scal2v.h" -class zscal2vGenericTest : +class zscal2vGeneric : public ::testing::TestWithParam> {}; -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(zscal2vGenericTest); +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(zscal2vGeneric); // Tests using random integers as vector elements. -TEST_P( zscal2vGenericTest, RandomData ) +TEST_P( zscal2vGeneric, API ) { using T = dcomplex; //---------------------------------------------------------- @@ -65,44 +65,29 @@ TEST_P( zscal2vGenericTest, RandomData ) T alpha = std::get<4>(GetParam()); // Set the threshold for the errors: - float thresh = testinghelpers::getEpsilon(); + // Check gtestsuite dotxv.h (no netlib version) for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO() || alpha == testinghelpers::ONE()) + thresh = 0.0; + else + thresh = testinghelpers::getEpsilon(); + //---------------------------------------------------------- // Call generic test body using those parameters //---------------------------------------------------------- test_scal2v( conj_alpha, n, incx, incy, alpha, thresh ); } -// Used to generate a test case with a sensible name. -// Beware that we cannot use fp numbers (e.g., 2.3) in the names, -// so we are only printing int(2.3). This should be enough for debugging purposes. -// If this poses an issue, please reach out. -class zscal2vGenericTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char conj = std::get<0>(str.param); - gtint_t n = std::get<1>(str.param); - gtint_t incx = std::get<2>(str.param); - gtint_t incy = std::get<3>(str.param); - dcomplex alpha = std::get<4>(str.param); - std::string str_name = "bli_zscal2v"; - str_name += "_" + std::to_string(n); - str_name += "_" + std::string(&conj, 1); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name += "_" + incx_str; - std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); - str_name += "_" + incy_str; - std::string alpha_str = ( alpha.real > 0) ? std::to_string(int(alpha.real)) : ("m" + std::to_string(int(std::abs(alpha.real)))); - alpha_str = alpha_str + "pi" + (( alpha.imag > 0) ? std::to_string(int(alpha.imag)) : ("m" + std::to_string(int(std::abs(alpha.imag))))); - str_name = str_name + "_a" + alpha_str; - return str_name; - } -}; #ifdef TEST_BLIS_TYPED // Black box testing for generic and main use of cscal2. INSTANTIATE_TEST_SUITE_P( Blackbox, - zscal2vGenericTest, + zscal2vGeneric, ::testing::Combine( ::testing::Values('n','c'), // n: use x, c: use conj(x) ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. @@ -110,7 +95,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(1)), // stride size for y ::testing::Values(dcomplex{3.0, -2.0}, dcomplex{-1.0, 4.0}) // alpha ), - ::zscal2vGenericTestPrint() + ::scal2vGenericPrint() ); @@ -119,7 +104,7 @@ INSTANTIATE_TEST_SUITE_P( // We can modify the values using implementantion details. INSTANTIATE_TEST_SUITE_P( NonUnitIncrements, - zscal2vGenericTest, + zscal2vGeneric, ::testing::Combine( ::testing::Values('n','c'), // n: use x, c: use conj(x) ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. @@ -127,6 +112,6 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(3)), // stride size for y ::testing::Values(dcomplex{1.0, 2.1}) // alpha ), - ::zscal2vGenericTestPrint() + ::scal2vGenericPrint() ); #endif diff --git a/gtestsuite/testsuite/level1/scalv/IIT_ERS/scalv_IIT_ERS.cpp b/gtestsuite/testsuite/level1/scalv/IIT_ERS/scalv_IIT_ERS.cpp new file mode 100644 index 0000000000..58b48fdbae --- /dev/null +++ b/gtestsuite/testsuite/level1/scalv/IIT_ERS/scalv_IIT_ERS.cpp @@ -0,0 +1,268 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "level1/scalv/test_scalv.h" +#include "common/wrong_inputs_helpers.h" +#include "common/testing_helpers.h" +#include "inc/check_error.h" + +template +class scalv_IIT_ERS : public ::testing::Test {}; +typedef ::testing::Types< + // std::pair + std::pair< float, float>, + std::pair< double, double>, + std::pair, + std::pair, + std::pair, + std::pair + > TypeParam; +TYPED_TEST_SUITE(scalv_IIT_ERS, TypeParam); + +using namespace testinghelpers::IIT; + +#if defined(TEST_BLAS_LIKE) || defined(TEST_CBLAS) + +/* + BLAS Early Return Scenarios(ERS): + + SCALV is expected to return early in the following cases: + 1. n <= 0 + 2. inc <= 0 + 3. alpha == 1 +*/ + +// n < 0, with non-unit stride +TYPED_TEST(scalv_IIT_ERS, n_lt_zero_nonUnitStride) +{ + using T = typename TypeParam::first_type; + using RT = typename TypeParam::second_type; + gtint_t invalid_n = -1; + gtint_t inc = 5; + // Using alpha = 3 as a valid input since BLAS expects SCALV to return early + // for alpha = 1. + RT alpha = RT{3}; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + scalv( 'n', invalid_n, alpha, nullptr, inc ); + + // Test with all arguments correct except for the value we are choosing to test. + // Initialize x vector with random numbers. + std::vector x = testinghelpers::get_random_vector( -10, 10, N, inc ); + std::vector x_ref(x); // copy x to x_ref to verify elements of x are not modified. + + // Invoking SCALV with an invalid value of n. + scalv( 'n', invalid_n, alpha, x.data(), inc ); + + // Computing bitwise difference. + computediff( "x", N, x.data(), x_ref.data(), inc ); +} + +// n == 0, with non-unit stride +TYPED_TEST(scalv_IIT_ERS, n_eq_zero_nonUnitStride) +{ + using T = typename TypeParam::first_type; + using RT = typename TypeParam::second_type; + gtint_t invalid_n = 0; + gtint_t inc = 5; + // Using alpha = 3 as a valid input since BLAS expects SCALV to return early + // for alpha = 1. + RT alpha = RT{3}; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + scalv( 'n', invalid_n, alpha, nullptr, inc ); + + // Test with all arguments correct except for the value we are choosing to test. + // Initialize x vector with random numbers. + std::vector x = testinghelpers::get_random_vector( -10, 10, N, inc ); + std::vector x_ref(x); // copy x to x_ref to verify elements of x are not modified. + + // Invoking SCALV with an invalid value of n. + scalv( 'n', invalid_n, alpha, x.data(), inc ); + + // Computing bitwise difference. + computediff( "x", N, x.data(), x_ref.data(), inc ); +} + +// n < 0, with unit stride +TYPED_TEST(scalv_IIT_ERS, n_lt_zero_unitStride) +{ + using T = typename TypeParam::first_type; + using RT = typename TypeParam::second_type; + gtint_t invalid_n = -1; + gtint_t unit_inc = 1; + // Using alpha = 3 as a valid input since BLAS expects SCALV to return early + // for alpha = 1. + RT alpha = RT{3}; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + scalv( 'n', invalid_n, alpha, nullptr, unit_inc ); + + // Test with all arguments correct except for the value we are choosing to test. + // Initialize x vector with random numbers. + std::vector x = testinghelpers::get_random_vector( -10, 10, N, unit_inc ); + std::vector x_ref(x); // copy x to x_ref to verify elements of x are not modified. + + // Invoking SCALV with an invalid value of n. + scalv( 'n', invalid_n, alpha, x.data(), unit_inc ); + + // Computing bitwise difference. + computediff( "x", N, x.data(), x_ref.data(), unit_inc ); +} + +// n == 0, with unit stride +TYPED_TEST(scalv_IIT_ERS, n_eq_zero_unitStride) +{ + using T = typename TypeParam::first_type; + using RT = typename TypeParam::second_type; + gtint_t invalid_n = 0; + gtint_t unit_inc = 1; + // Using alpha = 3 as a valid input since BLAS expects SCALV to return early + // for alpha = 1. + RT alpha = RT{3}; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + scalv( 'n', invalid_n, alpha, nullptr, unit_inc ); + + // Test with all arguments correct except for the value we are choosing to test. + // Initialize x vector with random numbers. + std::vector x = testinghelpers::get_random_vector( -10, 10, N, unit_inc ); + std::vector x_ref(x); // copy x to x_ref to verify elements of x are not modified. + + // Invoking SCALV with an invalid value of n. + scalv( 'n', invalid_n, alpha, x.data(), unit_inc ); + + // Computing bitwise difference. + computediff( "x", N, x.data(), x_ref.data(), unit_inc ); +} + +// inc < 0 +TYPED_TEST(scalv_IIT_ERS, inc_lt_0) +{ + using T = typename TypeParam::first_type; + using RT = typename TypeParam::second_type; + gtint_t invalid_inc = -1; + // Using alpha = 3 as a valid input since BLAS expects SCALV to return early + // for alpha = 1. + RT alpha = RT{3}; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + scalv( 'n', N, alpha, nullptr, invalid_inc ); + + // Test with all arguments correct except for the value we are choosing to test. + // Initialize x vector with random numbers. + std::vector x = testinghelpers::get_random_vector( -10, 10, N, INC ); + std::vector x_ref(x); // copy x to x_ref to verify elements of x are not modified. + + // Invoking SCALV with an invalid value of n. + scalv( 'n', N, alpha, x.data(), invalid_inc ); + + // Computing bitwise difference. + computediff( "x", N, x.data(), x_ref.data(), INC ); +} + +// inc == 0 +TYPED_TEST(scalv_IIT_ERS, inc_eq_0) +{ + using T = typename TypeParam::first_type; + using RT = typename TypeParam::second_type; + gtint_t invalid_inc = 0; + // Using alpha = 3 as a valid input since BLAS expects SCALV to return early + // for alpha = 1. + RT alpha = RT{3}; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + scalv( 'n', N, alpha, nullptr, invalid_inc ); + + // Test with all arguments correct except for the value we are choosing to test. + // Initialize x vector with random numbers. + std::vector x = testinghelpers::get_random_vector( -10, 10, N, INC ); + std::vector x_ref(x); // copy x to x_ref to verify elements of x are not modified. + + // Invoking SCALV with an invalid value of n. + scalv( 'n', N, alpha, x.data(), invalid_inc ); + + // Computing bitwise difference. + computediff( "x", N, x.data(), x_ref.data(), INC ); +} + +// alpha == 1, with non-unit stride +TYPED_TEST(scalv_IIT_ERS, alpha_eq_one_nonUnitStride) +{ + using T = typename TypeParam::first_type; + using RT = typename TypeParam::second_type; + gtint_t inc = 5; + RT invalid_alpha; + testinghelpers::initone(invalid_alpha); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + scalv( 'n', N, invalid_alpha, nullptr, inc ); + + // Test with all arguments correct except for the value we are choosing to test. + // Initialize x vector with random numbers. + std::vector x = testinghelpers::get_random_vector( -10, 10, N, inc ); + std::vector x_ref(x); // copy x to x_ref to verify elements of x are not modified. + + // Invoking SCALV with an invalid value of n. + scalv( 'n', N, invalid_alpha, x.data(), inc ); + + // Computing bitwise difference. + computediff( "x", N, x.data(), x_ref.data(), inc ); +} + +// alpha == 1, with unit stride +TYPED_TEST(scalv_IIT_ERS, alpha_eq_one_unitStride) +{ + using T = typename TypeParam::first_type; + using RT = typename TypeParam::second_type; + gtint_t unit_inc = 1; + RT invalid_alpha; + testinghelpers::initone(invalid_alpha); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + scalv( 'n', N, invalid_alpha, nullptr, unit_inc ); + + // Test with all arguments correct except for the value we are choosing to test. + // Initialize x vector with random numbers. + std::vector x = testinghelpers::get_random_vector( -10, 10, N, unit_inc ); + std::vector x_ref(x); // copy x to x_ref to verify elements of x are not modified. + + // Invoking SCALV with an invalid value of n. + scalv( 'n', N, invalid_alpha, x.data(), unit_inc ); + + // Computing bitwise difference. + computediff( "x", N, x.data(), x_ref.data(), unit_inc ); +} +#endif diff --git a/gtestsuite/testsuite/level1/scalv/cscalv/cscalv_generic.cpp b/gtestsuite/testsuite/level1/scalv/cscalv/cscalv_generic.cpp new file mode 100644 index 0000000000..1c35cd9693 --- /dev/null +++ b/gtestsuite/testsuite/level1/scalv/cscalv/cscalv_generic.cpp @@ -0,0 +1,231 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "level1/scalv/test_scalv.h" + +class cscalvGeneric : + public ::testing::TestWithParam> {}; // alpha + + +// Tests using random integers as vector elements. +TEST_P( cscalvGeneric, API ) +{ + using T = scomplex; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes whether alpha or conj(alpha) will be used: + char conj_alpha = std::get<0>(GetParam()); + // vector length: + gtint_t n = std::get<1>(GetParam()); + // stride size for x: + gtint_t incx = std::get<2>(GetParam()); + // alpha + T alpha = std::get<3>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite scalv.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO() || alpha == testinghelpers::ONE()) + thresh = 0.0; + else + thresh = testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_scalv( conj_alpha, n, incx, alpha, thresh ); +} + +// Black box testing for generic use of dscal. +INSTANTIATE_TEST_SUITE_P( + unitPositiveIncrementSmall, + cscalvGeneric, + ::testing::Combine( + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values('n'), + // m: size of vector. + ::testing::Range(gtint_t(1), gtint_t(101), 1), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1) + ), + // alpha: value of scalar. + ::testing::Values( + scomplex{-5.1, -7.3}, + scomplex{ 1.0, 1.0}, + scomplex{ 7.3, 5.1} + ) + ), + (::scalvGenericPrint()) + ); + +// Black box testing for generic use of dscal. +INSTANTIATE_TEST_SUITE_P( + unitPositiveIncrementLarge, + cscalvGeneric, + ::testing::Combine( + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values('n'), + // m: size of vector. + ::testing::Values(gtint_t(111), gtint_t(193), gtint_t(403)), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1) + ), + // alpha: value of scalar. + ::testing::Values( + scomplex{-5.1, -7.3}, + scomplex{ 1.0, 1.0}, + scomplex{ 7.3, 5.1} + ) + ), + (::scalvGenericPrint()) + ); + +INSTANTIATE_TEST_SUITE_P( + nonUnitPositiveIncrementSmall, + cscalvGeneric, + ::testing::Combine( + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values('n'), + // m: size of vector. + ::testing::Range(gtint_t(1), gtint_t(9), 1), + // incx: stride of x vector. + ::testing::Values( + gtint_t(2), + gtint_t(41) + ), + // alpha: value of scalar. + ::testing::Values( + scomplex{-5.1, -7.3}, + scomplex{ 1.0, 1.0}, + scomplex{ 7.3, 5.1} + ) + ), + (::scalvGenericPrint()) + ); + +INSTANTIATE_TEST_SUITE_P( + nonUnitPositiveIncrementLarge, + cscalvGeneric, + ::testing::Combine( + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values('n'), + // m: size of vector. + ::testing::Values(gtint_t(111), gtint_t(193), gtint_t(403)), + // incx: stride of x vector. + ::testing::Values( + gtint_t(2), + gtint_t(41) + ), + // alpha: value of scalar. + ::testing::Values( + scomplex{-5.1, -7.3}, + scomplex{ 1.0, 1.0}, + scomplex{ 7.3, 5.1} + ) + ), + (::scalvGenericPrint()) + ); + +#ifndef TEST_BLIS_TYPED +// alpha=0 testing only for BLAS and CBLAS as +// BLIS uses setv and won't propagate Inf and NaNs +INSTANTIATE_TEST_SUITE_P( + alphaZero, + cscalvGeneric, + ::testing::Combine( + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values('n'), + // m: size of vector. + ::testing::Range(gtint_t(1), gtint_t(101), 1), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1), + gtint_t(2), + gtint_t(41) + ), + // alpha: value of scalar. + ::testing::Values( + scomplex{ 0.0, 0.0} + ) + ), + (::scalvGenericPrint()) + ); +#endif + +#ifdef TEST_BLIS_TYPED +// Test when conjugate of x is used as an argument. This option is BLIS-api specific. +// Only test very few cases as sanity check since conj(x) = x for real types. +// We can modify the values using implementantion details. +INSTANTIATE_TEST_SUITE_P( + conjalpha, + cscalvGeneric, + ::testing::Combine( + ::testing::Values('c'), // c: use conjugate + ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector takes values from 10 to 100 with step size of 10. + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(scomplex{ 7.3, 5.1}) // alpha + ), + (::scalvGenericPrint()) + ); +#endif + +#ifndef TEST_BLIS_TYPED +// Test for negative increments. +// Only test very few cases as sanity check. +// We can modify the values using implementantion details. +INSTANTIATE_TEST_SUITE_P( + NegativeIncrements, + cscalvGeneric, + ::testing::Combine( + ::testing::Values('n'), // n: use x, c: use conj(x) + ::testing::Range(gtint_t(10), gtint_t(31), 10), // m size of vector takes values from 10 to 100 with step size of 10. + ::testing::Values(gtint_t(-2), gtint_t(-1)), // stride size for x + ::testing::Values(scomplex{ 7.3, 5.1}) // alpha + ), + (::scalvGenericPrint()) + ); +#endif diff --git a/gtestsuite/testsuite/level1/scalv/cscalv_generic.cpp b/gtestsuite/testsuite/level1/scalv/cscalv_generic.cpp deleted file mode 100644 index bf367f73d8..0000000000 --- a/gtestsuite/testsuite/level1/scalv/cscalv_generic.cpp +++ /dev/null @@ -1,152 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include "test_scalv.h" - -class cscalvGenericTest : - public ::testing::TestWithParam> {}; - - -// Tests using random integers as vector elements. -TEST_P( cscalvGenericTest, RandomData ) -{ - using T = scomplex; - //---------------------------------------------------------- - // Initialize values from the parameters passed through - // test suite instantiation (INSTANTIATE_TEST_SUITE_P). - //---------------------------------------------------------- - // denotes whether alpha or conj(alpha) will be used: - char conj_alpha = std::get<0>(GetParam()); - // vector length: - gtint_t n = std::get<1>(GetParam()); - // stride size for x: - gtint_t incx = std::get<2>(GetParam()); - // alpha - T alpha = std::get<3>(GetParam()); - - // Set the threshold for the errors: - double thresh = testinghelpers::getEpsilon(); - //---------------------------------------------------------- - // Call generic test body using those parameters - //---------------------------------------------------------- - test_scalv( conj_alpha, n, incx, alpha, thresh ); -} - -// Used to generate a test case with a sensible name. -// Beware that we cannot use fp numbers (e.g., 2.3) in the names, -// so we are only printing int(2.3). This should be enough for debugging purposes. -// If this poses an issue, please reach out. -class cscalvGenericTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char conj = std::get<0>(str.param); - gtint_t n = std::get<1>(str.param); - gtint_t incx = std::get<2>(str.param); - scomplex alpha = std::get<3>(str.param); -#ifdef TEST_BLAS - std::string str_name = "cscal_"; -#elif TEST_CBLAS - std::string str_name = "cblas_cscal"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_cscalv"; -#endif - str_name += "_" + std::to_string(n); - str_name += "_" + std::string(&conj, 1); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name += "_" + incx_str; - std::string alpha_str = ( alpha.real > 0) ? std::to_string(int(alpha.real)) : ("m" + std::to_string(int(std::abs(alpha.real)))); - alpha_str = alpha_str + "pi" + (( alpha.imag > 0) ? std::to_string(int(alpha.imag)) : ("m" + std::to_string(int(std::abs(alpha.imag))))); - str_name = str_name + "_a" + alpha_str; - return str_name; - } -}; - -// Black box testing for generic and main use of cscal. -INSTANTIATE_TEST_SUITE_P( - Blackbox, - cscalvGenericTest, - ::testing::Combine( - ::testing::Values('n' -#ifdef TEST_BLIS_TYPED - , 'c' // this option is BLIS-api specific. -#endif - ), // n: use x, c: use conj(x) - ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. - ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(scomplex{2.0, -1.0}, scomplex{-2.0, 3.0}) // alpha - ), - ::cscalvGenericTestPrint() - ); - - -// Test for non-unit increments. -// Only test very few cases as sanity check. -// We can modify the values using implementantion details. -INSTANTIATE_TEST_SUITE_P( - NonUnitPositiveIncrements, - cscalvGenericTest, - ::testing::Combine( - ::testing::Values('n' -#ifdef TEST_BLIS_TYPED - , 'c' // this option is BLIS-api specific. -#endif - ), // n: use x, c: use conj(x) - ::testing::Range(gtint_t(10), gtint_t(31), 10), // m size of vector takes values from 10 to 100 with step size of 10. - ::testing::Values(gtint_t(2), gtint_t(11)), //(gtint_t(-5), gtint_t(-17)) // stride size for x - ::testing::Values(scomplex{4.0, 3.1}) // alpha - ), - ::cscalvGenericTestPrint() - ); - -#ifndef TEST_BLIS_TYPED -// Test for negative increments. -// Only test very few cases as sanity check. -// We can modify the values using implementantion details. -INSTANTIATE_TEST_SUITE_P( - NegativeIncrements, - cscalvGenericTest, - ::testing::Combine( - ::testing::Values('n'), // n: use x, c: use conj(x) - ::testing::Range(gtint_t(10), gtint_t(31), 10), // m size of vector takes values from 10 to 100 with step size of 10. - ::testing::Values(gtint_t(-2), gtint_t(-1)), // stride size for x - ::testing::Values(scomplex{4.0, 3.1}) // alpha - ), - ::cscalvGenericTestPrint() - ); -#endif diff --git a/gtestsuite/testsuite/level1/scalv/csscalv/csscalv_generic.cpp b/gtestsuite/testsuite/level1/scalv/csscalv/csscalv_generic.cpp new file mode 100644 index 0000000000..13aaff8d0a --- /dev/null +++ b/gtestsuite/testsuite/level1/scalv/csscalv/csscalv_generic.cpp @@ -0,0 +1,219 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "level1/scalv/test_scalv.h" + +class csscalvGeneric : + public ::testing::TestWithParam> {}; // alpha + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(csscalvGeneric); + +// Tests using random integers as vector elements. +TEST_P( csscalvGeneric, API ) +{ + using T = scomplex; + using U = float; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes whether alpha or conj(alpha) will be used: + char conj_alpha = std::get<0>(GetParam()); + // vector length: + gtint_t n = std::get<1>(GetParam()); + // stride size for x: + gtint_t incx = std::get<2>(GetParam()); + // alpha + U alpha = std::get<3>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite scalv.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO() || alpha == testinghelpers::ONE()) + thresh = 0.0; + else + thresh = testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_scalv( conj_alpha, n, incx, alpha, thresh ); +} + +// bli_csscal not present in BLIS +#ifndef TEST_BLIS_TYPED + +// Black box testing for generic use of dscal. +INSTANTIATE_TEST_SUITE_P( + unitPositiveIncrementSmall, + csscalvGeneric, + ::testing::Combine( + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values('n'), + // m: size of vector. + ::testing::Range(gtint_t(1), gtint_t(101), 1), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1) + ), + // alpha: value of scalar. + ::testing::Values( + float( 7.0), + float(-3.0) + ) + ), + (::scalvGenericPrint()) + ); + +// Black box testing for generic use of dscal. +INSTANTIATE_TEST_SUITE_P( + unitPositiveIncrementLarge, + csscalvGeneric, + ::testing::Combine( + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values('n'), + // m: size of vector. + ::testing::Values(gtint_t(111), gtint_t(193), gtint_t(403)), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1) + ), + // alpha: value of scalar. + ::testing::Values( + float( 7.0), + float(-3.0) + ) + ), + (::scalvGenericPrint()) + ); + +INSTANTIATE_TEST_SUITE_P( + nonUnitPositiveIncrementSmall, + csscalvGeneric, + ::testing::Combine( + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values('n'), + // m: size of vector. + ::testing::Range(gtint_t(1), gtint_t(9), 1), + // incx: stride of x vector. + ::testing::Values( + gtint_t(2), + gtint_t(41) + ), + // alpha: value of scalar. + ::testing::Values( + float( 7.0), + float(-3.0) + ) + ), + (::scalvGenericPrint()) + ); + +INSTANTIATE_TEST_SUITE_P( + nonUnitPositiveIncrementLarge, + csscalvGeneric, + ::testing::Combine( + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values('n'), + // m: size of vector. + ::testing::Values(gtint_t(111), gtint_t(193), gtint_t(403)), + // incx: stride of x vector. + ::testing::Values( + gtint_t(2), + gtint_t(41) + ), + // alpha: value of scalar. + ::testing::Values( + float( 7.0), + float(-3.0) + ) + ), + (::scalvGenericPrint()) + ); + +// alpha=0 testing only for BLAS and CBLAS as +// BLIS uses setv and won't propagate Inf and NaNs +INSTANTIATE_TEST_SUITE_P( + alphaZero, + csscalvGeneric, + ::testing::Combine( + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values('n'), + // m: size of vector. + ::testing::Range(gtint_t(1), gtint_t(101), 1), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1), + gtint_t(2), + gtint_t(41) + ), + // alpha: value of scalar. + ::testing::Values( + double( 0.0) + ) + ), + (::scalvGenericPrint()) + ); + +// Test for negative increments. +// Only test very few cases as sanity check. +// We can modify the values using implementantion details. +INSTANTIATE_TEST_SUITE_P( + NegativeIncrements, + csscalvGeneric, + ::testing::Combine( + ::testing::Values('n'), // n: use x, c: use conj(x) + ::testing::Range(gtint_t(10), gtint_t(31), 10), // m size of vector takes values from 10 to 100 with step size of 10. + ::testing::Values(gtint_t(-2), gtint_t(-1)), // stride size for x + ::testing::Values(3) // alpha + ), + (::scalvGenericPrint()) + ); + +#endif // not TEST_BLIS_TYPED + + + + + + diff --git a/gtestsuite/testsuite/level1/scalv/dscalv/dscalv_evt.cpp b/gtestsuite/testsuite/level1/scalv/dscalv/dscalv_evt.cpp new file mode 100644 index 0000000000..58ee91d01c --- /dev/null +++ b/gtestsuite/testsuite/level1/scalv/dscalv/dscalv_evt.cpp @@ -0,0 +1,340 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "level1/scalv/test_scalv.h" + +class dscalvEVT : + public ::testing::TestWithParam> {}; // alpha + + +// Tests using random integers as vector elements. +TEST_P( dscalvEVT, API ) +{ + using T = double; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes whether alpha or conj(alpha) will be used: + char conj_alpha = std::get<0>(GetParam()); + // vector length: + gtint_t n = std::get<1>(GetParam()); + // stride size for x: + gtint_t incx = std::get<2>(GetParam()); + // index of extreme value for x: + gtint_t xi = std::get<3>(GetParam()); + // extreme value for x: + double x_exval = std::get<4>(GetParam()); + // alpha: + T alpha = std::get<5>(GetParam()); + + // Set the threshold for the errors: + double thresh; + if (n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO() || alpha == testinghelpers::ONE()) + thresh = 0.0; + else + thresh = testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_scalv( conj_alpha, n, incx, xi, x_exval, alpha, thresh ); +} + +static double NaN = std::numeric_limits::quiet_NaN(); +static double Inf = std::numeric_limits::infinity(); + +// Tests for Zen4 Architecture. +/** + * bli_dscalv_zen_int_avx512( ... ) + * Loops: + * L64 - Main loop, handles 64 elements + * L32 - handles 32 elements + * L16 - handles 16 elements + * L8 - handles 8 elements + * L4 - handles 4 elements + * L2 - handles 2 elements + * LScalar - leftover loop (also handles non-unit increments) + * + * n = 383 : L64*5 + L20 + L16 + L8 + L4 + L2 + LScalar + * Indices - Loop into which extreme value is induced + * 0, 319 - L64 + * 351 - L32 + * 367 - L16 + * 375 - L8 + * 379 - L4 + * 380 - L2 + * 382 - LScalar + */ +// EVT with unit stride vector containing Infs/NaNs. +INSTANTIATE_TEST_SUITE_P( + vec_unitStride_zen4, + dscalvEVT, + ::testing::Combine( + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values('n' +#ifdef TEST_BLIS_TYPED + , + 'c' // conjugate option is BLIS-api specific. +#endif + ), + // m: size of vector. + ::testing::Values( + gtint_t(383) + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1) + ), + // xi: index of extreme value for x. + ::testing::Values( + gtint_t(0), gtint_t(319), gtint_t(351), + gtint_t(367), gtint_t(375), gtint_t(379), + gtint_t(380), gtint_t(382) + ), + // x_exval: extreme value for x. + ::testing::Values( NaN, Inf, -Inf ), + // alpha: value of scalar. + ::testing::Values( + double(-3.3), + double(-1.0), + double( 0.0), + double( 1.0), + double( 7.3) + ) + ), + (::scalvEVTPrint()) + ); + +// Tests for Zen3 Architecture. +/** + * bli_dscalv_zen_int10( ... ) + * Loops: + * L64 - Main loop, handles 64 elements + * L48 - handles 48 elements + * L32 - handles 32 elements + * L12 - handles 12 elements + * L4 - handles 4 elements + * LScalar - leftover loop + * + * n = 565 : L64*8 + L48 + L4 + LScalar + * Indices - Loop into which extreme value is induced + * 0, 511 - L64 + * 520, 525 - L48 + * 528, 555 - L48 + * 561 - L4 + * 564 - LScalar + * + * n = 556 : L64*8 + L32 + L12 + * Indices - Loop into which extreme value is induced + * 0, 511 - L64 + * 520, 525 - L32 + * 555 - L12 + * + * n = 529 : L64*8 + L12 + L4 + LScalar + * Indices - Loop into which extreme value is induced + * 0, 511 - L64 + * 520 - L12 + * 525 - L4 + * 528 - LScalar + */ +// EVT with unit stride vector containing Infs/NaNs. +INSTANTIATE_TEST_SUITE_P( + vec_unitStride_zen3, + dscalvEVT, + ::testing::Combine( + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values('n' +#ifdef TEST_BLIS_TYPED + , + 'c' // conjugate option is BLIS-api specific. +#endif + ), + // m: size of vector. + ::testing::Values( + gtint_t(565), + gtint_t(556), + gtint_t(529) + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1) + ), + // xi: index of extreme value for x. + ::testing::Values( + gtint_t(0), gtint_t(511), gtint_t(520), + gtint_t(525), gtint_t(528), gtint_t(555), + gtint_t(561), gtint_t(564) + ), + // x_exval: extreme value for x. + ::testing::Values( NaN, Inf, -Inf ), + // alpha: value of scalar. + ::testing::Values( + double(-3.3), + double(-1.0), + double( 0.0), + double( 1.0), + double( 7.3) + ) + ), + (::scalvEVTPrint()) + ); + +// EVT with non-unit stride vector containing Infs/NaNs. +INSTANTIATE_TEST_SUITE_P( + vec_nonUnitStride, + dscalvEVT, + ::testing::Combine( + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values('n' +#ifdef TEST_BLIS_TYPED + , + 'c' // conjugate option is BLIS-api specific. +#endif + ), + // m: size of vector. + ::testing::Values( + gtint_t(55) + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(3) + ), + // xi: index of extreme value for x. + ::testing::Values( + gtint_t(1), gtint_t(27), gtint_t(51) + ), + // x_exval: extreme value for x. + ::testing::Values( NaN, Inf, -Inf ), + // alpha: value of scalar. + ::testing::Values( + double(-3.3), + double(-1.0), + double( 0.0), + double( 1.0), + double( 7.3) + ) + ), + (::scalvEVTPrint()) + ); + +// EVT with alpha containing Infs/NaNs on a unit stride vector. +INSTANTIATE_TEST_SUITE_P( + alpha_unitStride_zen3, + dscalvEVT, + ::testing::Combine( + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values('n' +#ifdef TEST_BLIS_TYPED + , + 'c' // conjugate option is BLIS-api specific. +#endif + ), + // m: size of vector. + ::testing::Values( + gtint_t(565), + gtint_t(556), + gtint_t(529) + ), + // incx: stride of x vector. + ::testing::Values( gtint_t(1) ), + // xi: index of extreme value for x. + ::testing::Values( gtint_t(1) ), + // x_exval: extreme value for x. + ::testing::Values( double(0.0) ), + // alpha: value of scalar. + ::testing::Values( NaN, Inf, -Inf ) + ), + (::scalvEVTPrint()) + ); + +// EVT with alpha containing Infs/NaNs on a unit stride vector. +INSTANTIATE_TEST_SUITE_P( + alpha_unitStride_zen4, + dscalvEVT, + ::testing::Combine( + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values('n' +#ifdef TEST_BLIS_TYPED + , + 'c' // conjugate option is BLIS-api specific. +#endif + ), + // m: size of vector. + ::testing::Values( gtint_t(383) ), + // incx: stride of x vector. + ::testing::Values( gtint_t(1) ), + // xi: index of extreme value for x. + ::testing::Values( gtint_t(1) ), + // x_exval: extreme value for x. + ::testing::Values( double(0.0) ), + // alpha: value of scalar. + ::testing::Values( NaN, Inf, -Inf ) + ), + (::scalvEVTPrint()) + ); + +// EVT with alpha containing Infs/NaNs on a non-unit stride vector. +INSTANTIATE_TEST_SUITE_P( + alpha_nonUnitStride, + dscalvEVT, + ::testing::Combine( + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values('n' +#ifdef TEST_BLIS_TYPED + , + 'c' // conjugate option is BLIS-api specific. +#endif + ), + // m: size of vector. + ::testing::Values( gtint_t(55) ), + // incx: stride of x vector. + ::testing::Values( gtint_t(3) ), + // xi: index of extreme value for x. + ::testing::Values( gtint_t(1) ), + // x_exval: extreme value for x. + ::testing::Values( double(0.0) ), + // alpha: value of scalar. + ::testing::Values( NaN, Inf, -Inf ) + ), + (::scalvEVTPrint()) + ); diff --git a/gtestsuite/testsuite/level1/scalv/dscalv/dscalv_generic.cpp b/gtestsuite/testsuite/level1/scalv/dscalv/dscalv_generic.cpp new file mode 100644 index 0000000000..c6dded5264 --- /dev/null +++ b/gtestsuite/testsuite/level1/scalv/dscalv/dscalv_generic.cpp @@ -0,0 +1,289 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "level1/scalv/test_scalv.h" + +class dscalvGeneric : + public ::testing::TestWithParam> {}; + + +// Tests using random integers as vector elements. +TEST_P( dscalvGeneric, API ) +{ + using T = double; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes whether alpha or conj(alpha) will be used: + char conj_alpha = std::get<0>(GetParam()); + // vector length: + gtint_t n = std::get<1>(GetParam()); + // stride size for x: + gtint_t incx = std::get<2>(GetParam()); + // alpha + T alpha = std::get<3>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite scalv.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO() || alpha == testinghelpers::ONE()) + thresh = 0.0; + else + thresh = testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_scalv( conj_alpha, n, incx, alpha, thresh ); +} + +// Black box testing for generic use of dscal. +INSTANTIATE_TEST_SUITE_P( + unitPositiveIncrementSmall, + dscalvGeneric, + ::testing::Combine( + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values('n'), + // m: size of vector. + ::testing::Range(gtint_t(1), gtint_t(101), 1), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1) + ), + // alpha: value of scalar. + ::testing::Values( + double( 7.0), + double(-3.0) + ) + ), + (::scalvGenericPrint()) + ); + +// Black box testing for generic use of dscal. +INSTANTIATE_TEST_SUITE_P( + unitPositiveIncrementLarge, + dscalvGeneric, + ::testing::Combine( + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values('n'), + // m: size of vector. + ::testing::Values(gtint_t(111), gtint_t(193), gtint_t(403)), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1) + ), + // alpha: value of scalar. + ::testing::Values( + double( 7.0), + double(-3.0) + ) + ), + (::scalvGenericPrint()) + ); + +INSTANTIATE_TEST_SUITE_P( + nonUnitPositiveIncrementSmall, + dscalvGeneric, + ::testing::Combine( + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values('n'), + // m: size of vector. + ::testing::Range(gtint_t(1), gtint_t(9), 1), + // incx: stride of x vector. + ::testing::Values( + gtint_t(2), + gtint_t(41) + ), + // alpha: value of scalar. + ::testing::Values( + double( 7.0), + double(-3.0) + ) + ), + (::scalvGenericPrint()) + ); + +INSTANTIATE_TEST_SUITE_P( + nonUnitPositiveIncrementLarge, + dscalvGeneric, + ::testing::Combine( + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values('n'), + // m: size of vector. + ::testing::Values(gtint_t(111), gtint_t(193), gtint_t(403)), + // incx: stride of x vector. + ::testing::Values( + gtint_t(2), + gtint_t(41) + ), + // alpha: value of scalar. + ::testing::Values( + double( 7.0), + double(-3.0) + ) + ), + (::scalvGenericPrint()) + ); + +#ifndef TEST_BLIS_TYPED +// alpha=0 testing only for BLAS and CBLAS as +// BLIS uses setv and won't propagate Inf and NaNs +INSTANTIATE_TEST_SUITE_P( + alphaZero, + dscalvGeneric, + ::testing::Combine( + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values('n'), + // m: size of vector. + ::testing::Range(gtint_t(1), gtint_t(101), 1), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1), + gtint_t(2), + gtint_t(41) + ), + // alpha: value of scalar. + ::testing::Values( + double( 0.0) + ) + ), + (::scalvGenericPrint()) + ); +#endif + +#ifdef TEST_BLIS_TYPED +// Test when conjugate of x is used as an argument. This option is BLIS-api specific. +// Only test very few cases as sanity check since conj(x) = x for real types. +// We can modify the values using implementantion details. +INSTANTIATE_TEST_SUITE_P( + conjalpha, + dscalvGeneric, + ::testing::Combine( + ::testing::Values('c'), // c: use conjugate + ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector takes values from 10 to 100 with step size of 10. + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(double(-3.0)) // alpha + ), + (::scalvGenericPrint()) + ); +#endif + +#ifndef TEST_BLIS_TYPED +// Test for negative increments. +// Only test very few cases as sanity check. +// We can modify the values using implementantion details. +INSTANTIATE_TEST_SUITE_P( + NegativeIncrements, + dscalvGeneric, + ::testing::Combine( + ::testing::Values('n'), // n: use x, c: use conj(x) + ::testing::Range(gtint_t(10), gtint_t(31), 10), // m size of vector takes values from 10 to 100 with step size of 10. + ::testing::Values(gtint_t(-2), gtint_t(-1)), // stride size for x + ::testing::Values(3) // alpha + ), + (::scalvGenericPrint()) + ); +#endif + +#if defined(BLIS_ENABLE_OPENMP) && defined(AOCL_DYNAMIC) +INSTANTIATE_TEST_SUITE_P( + AOCLDynamic, + dscalvGeneric, + ::testing::Combine( + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values('n'), + // m: size of vector. + ::testing::Values( + gtint_t( 30000), // nt_ideal = 1 + gtint_t( 100000), // nt_ideal = 2 + gtint_t( 486919), // nt_ideal = 8 + gtint_t( 500000), // nt_ideal = 8 + gtint_t( 2500000), // nt_ideal = 12 + gtint_t( 4000000), // nt_ideal = 16 + gtint_t( 7000000), // nt_ideal = 24 + gtint_t(10000000), // nt_ideal = 32 + gtint_t(25000000) // nt_ideal = max_available + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1), + gtint_t(3) + ), + // alpha: value of scalar. + ::testing::Values( + double( 7.0) + ) + ), + (::scalvGenericPrint()) + ); + +#ifndef TEST_BLIS_TYPED +// alpha=0 testing only for BLAS and CBLAS as +// BLIS uses setv and won't propagate Inf and NaNs +INSTANTIATE_TEST_SUITE_P( + AOCLDynamicAlphaZero, + dscalvGeneric, + ::testing::Combine( + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values('n'), + // m: size of vector. + ::testing::Values( + gtint_t( 89), // nt_ideal = 8 + gtint_t( 486919), // nt_ideal = 8 + gtint_t(25000000) // nt_ideal = max_available + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1), + gtint_t(3) + ), + // alpha: value of scalar. + ::testing::Values( + double( 0.0) + ) + ), + (::scalvGenericPrint()) + ); +#endif + +#endif // BLIS_ENABLE_OPENMP && AOCL_DYNAMIC diff --git a/gtestsuite/testsuite/level1/scalv/dscalv_generic.cpp b/gtestsuite/testsuite/level1/scalv/dscalv_generic.cpp deleted file mode 100644 index b73db053c6..0000000000 --- a/gtestsuite/testsuite/level1/scalv/dscalv_generic.cpp +++ /dev/null @@ -1,159 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include "test_scalv.h" - -class dscalvGenericTest : - public ::testing::TestWithParam> {}; - - -// Tests using random integers as vector elements. -TEST_P( dscalvGenericTest, RandomData ) -{ - using T = double; - //---------------------------------------------------------- - // Initialize values from the parameters passed through - // test suite instantiation (INSTANTIATE_TEST_SUITE_P). - //---------------------------------------------------------- - // denotes whether alpha or conj(alpha) will be used: - char conj_alpha = std::get<0>(GetParam()); - // vector length: - gtint_t n = std::get<1>(GetParam()); - // stride size for x: - gtint_t incx = std::get<2>(GetParam()); - // alpha - T alpha = std::get<3>(GetParam()); - - // Set the threshold for the errors: - double thresh = testinghelpers::getEpsilon(); - //---------------------------------------------------------- - // Call generic test body using those parameters - //---------------------------------------------------------- - test_scalv( conj_alpha, n, incx, alpha, thresh ); -} - -// Used to generate a test case with a sensible name. -// Beware that we cannot use fp numbers (e.g., 2.3) in the names, -// so we are only printing int(2.3). This should be enough for debugging purposes. -// If this poses an issue, please reach out. -class dscalvGenericTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char conj = std::get<0>(str.param); - gtint_t n = std::get<1>(str.param); - gtint_t incx = std::get<2>(str.param); - double alpha = std::get<3>(str.param); -#ifdef TEST_BLAS - std::string str_name = "dscal_"; -#elif TEST_CBLAS - std::string str_name = "cblas_dscal"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_dscalv"; -#endif - str_name += "_" + std::to_string(n); - str_name += "_" + std::string(&conj, 1); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name += "_" + incx_str; - std::string alpha_str = ( alpha > 0) ? std::to_string(int(alpha)) : "m" + std::to_string(int(std::abs(alpha))); - str_name = str_name + "_a" + alpha_str; - return str_name; - } -}; - -// Black box testing for generic and main use of dscal. -INSTANTIATE_TEST_SUITE_P( - Blackbox, - dscalvGenericTest, - ::testing::Combine( - ::testing::Values('n'), // n: use x, not conj(x) (since it is real) - ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. - ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(double(2.0), double(-3.0)) // alpha - ), - ::dscalvGenericTestPrint() - ); - -#ifdef TEST_BLIS_TYPED -// Test when conjugate of x is used as an argument. This option is BLIS-api specific. -// Only test very few cases as sanity check since conj(x) = x for real types. -// We can modify the values using implementantion details. -INSTANTIATE_TEST_SUITE_P( - Conjalpha, - dscalvGenericTest, - ::testing::Combine( - ::testing::Values('c'), // c: use conjugate - ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector takes values from 10 to 100 with step size of 10. - ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(double(-3.0)) // alpha - ), - ::dscalvGenericTestPrint() - ); -#endif - -// Test for non-unit increments. -// Only test very few cases as sanity check. -// We can modify the values using implementantion details. -INSTANTIATE_TEST_SUITE_P( - NonUnitPositiveIncrements, - dscalvGenericTest, - ::testing::Combine( - ::testing::Values('n'), // n: use x - ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector takes values from 10 to 100 with step size of 10. - ::testing::Values(gtint_t(2), gtint_t(11)), //(gtint_t(-5), gtint_t(-17)) // stride size for x - ::testing::Values(double(3.0)) // alpha - ), - ::dscalvGenericTestPrint() - ); - -#ifndef TEST_BLIS_TYPED -// Test for negative increments. -// Only test very few cases as sanity check. -// We can modify the values using implementantion details. -INSTANTIATE_TEST_SUITE_P( - NegativeIncrements, - dscalvGenericTest, - ::testing::Combine( - ::testing::Values('n'), // n: use x, c: use conj(x) - ::testing::Range(gtint_t(10), gtint_t(31), 10), // m size of vector takes values from 10 to 100 with step size of 10. - ::testing::Values(gtint_t(-2), gtint_t(-1)), // stride size for x - ::testing::Values(3) // alpha - ), - ::dscalvGenericTestPrint() - ); -#endif diff --git a/gtestsuite/testsuite/level1/scalv/scalv.h b/gtestsuite/testsuite/level1/scalv/scalv.h index 0ae0125f52..ba7641f6cd 100644 --- a/gtestsuite/testsuite/level1/scalv/scalv.h +++ b/gtestsuite/testsuite/level1/scalv/scalv.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -36,6 +36,7 @@ #include "blis.h" #include "common/testing_helpers.h" +#include "inc/check_error.h" /** * @brief Performs the operation: @@ -48,65 +49,145 @@ * @param[in] incx increment of x */ -template -static void scalv_(gtint_t n, T alpha, T* x, gtint_t incx) +template +static void scalv_(gtint_t n, U alpha, T* x, gtint_t incx) { - if constexpr (std::is_same::value) - sscal_( &n, &alpha, x, &incx ); - else if constexpr (std::is_same::value) - dscal_( &n, &alpha, x, &incx ); - else if constexpr (std::is_same::value) - cscal_( &n, &alpha, x, &incx ); - else if constexpr (std::is_same::value) - zscal_( &n, &alpha, x, &incx ); + if constexpr (std::is_same::value) + { + if constexpr (std::is_same::value) + sscal_( &n, &alpha, x, &incx ); + else if constexpr (std::is_same::value) + dscal_( &n, &alpha, x, &incx ); + else if constexpr (std::is_same::value) + cscal_( &n, &alpha, x, &incx ); + else if constexpr (std::is_same::value) + zscal_( &n, &alpha, x, &incx ); + else + throw std::runtime_error("Error in testsuite/level1/scalv.h: Invalid typename in scalv_()."); + } + else if constexpr (std::is_same::value && std::is_same::value ) + csscal_( &n, &alpha, x, &incx ); + else if constexpr (std::is_same::value && std::is_same::value ) + zdscal_( &n, &alpha, x, &incx ); else throw std::runtime_error("Error in testsuite/level1/scalv.h: Invalid typename in scalv_()."); } -template -static void cblas_scalv(gtint_t n, T alpha, T* x, gtint_t incx) +template +static void scalv_blis_impl(gtint_t n, U alpha, T* x, gtint_t incx) { - if constexpr (std::is_same::value) - cblas_sscal( n, alpha, x, incx ); - else if constexpr (std::is_same::value) - cblas_dscal( n, alpha, x, incx ); - else if constexpr (std::is_same::value) - cblas_cscal( n, &alpha, x, incx ); - else if constexpr (std::is_same::value) - cblas_zscal( n, &alpha, x, incx ); + if constexpr (std::is_same::value) + { + if constexpr (std::is_same::value) + sscal_blis_impl( &n, &alpha, x, &incx ); + else if constexpr (std::is_same::value) + dscal_blis_impl( &n, &alpha, x, &incx ); + else if constexpr (std::is_same::value) + cscal_blis_impl( &n, &alpha, x, &incx ); + else if constexpr (std::is_same::value) + zscal_blis_impl( &n, &alpha, x, &incx ); + else + throw std::runtime_error("Error in testsuite/level1/scalv.h: Invalid typename in scalv_blis_impl()."); + } + else if constexpr (std::is_same::value && std::is_same::value ) + csscal_blis_impl( &n, &alpha, x, &incx ); + else if constexpr (std::is_same::value && std::is_same::value ) + zdscal_blis_impl( &n, &alpha, x, &incx ); + else + throw std::runtime_error("Error in testsuite/level1/scalv.h: Invalid typename in scalv_blis_impl()."); +} + +template +static void cblas_scalv(gtint_t n, U alpha, T* x, gtint_t incx) +{ + if constexpr (std::is_same::value) + { + if constexpr (std::is_same::value) + cblas_sscal( n, alpha, x, incx ); + else if constexpr (std::is_same::value) + cblas_dscal( n, alpha, x, incx ); + else if constexpr (std::is_same::value) + cblas_cscal( n, &alpha, x, incx ); + else if constexpr (std::is_same::value) + cblas_zscal( n, &alpha, x, incx ); + else + throw std::runtime_error("Error in testsuite/level1/scalv.h: Invalid typename in cblas_scalv()."); + } + else if constexpr (std::is_same::value && std::is_same::value ) + cblas_csscal( n, alpha, x, incx ); + else if constexpr (std::is_same::value && std::is_same::value ) + cblas_zdscal( n, alpha, x, incx ); else throw std::runtime_error("Error in testsuite/level1/scalv.h: Invalid typename in cblas_scalv()."); } -template -static void typed_scalv(char conj_alpha, gtint_t n, T alpha, T* x, gtint_t incx) +template +static void typed_scalv(char conj_alpha, gtint_t n, U alpha, T* x, gtint_t incx) { conj_t conjalpha; // Map parameter characters to BLIS constants. testinghelpers::char_to_blis_conj( conj_alpha, &conjalpha ); - if constexpr (std::is_same::value) - bli_sscalv( conjalpha, n, &alpha, x, incx ); - else if constexpr (std::is_same::value) - bli_dscalv( conjalpha, n, &alpha, x, incx ); - else if constexpr (std::is_same::value) - bli_cscalv( conjalpha, n, &alpha, x, incx ); - else if constexpr (std::is_same::value) - bli_zscalv( conjalpha, n, &alpha, x, incx ); + + if constexpr (std::is_same::value) + { + if constexpr (std::is_same::value) + bli_sscalv( conjalpha, n, &alpha, x, incx ); + else if constexpr (std::is_same::value) + bli_dscalv( conjalpha, n, &alpha, x, incx ); + else if constexpr (std::is_same::value) + bli_cscalv( conjalpha, n, &alpha, x, incx ); + else if constexpr (std::is_same::value) + bli_zscalv( conjalpha, n, &alpha, x, incx ); + else + throw std::runtime_error("Error in testsuite/level1/scalv.h: Invalid typename in typed_scalv()."); + } + // Disabled BLIS_TYPED tests for mixed-precision SCALV as BLIS isn't exposing these functions. +#if 0 + else if constexpr (std::is_same::value && std::is_same::value ) + bli_csscalv( conjalpha, n, &alpha, x, incx ); + else if constexpr (std::is_same::value && std::is_same::value ) + bli_zdscalv( conjalpha, n, &alpha, x, incx ); +#endif else throw std::runtime_error("Error in testsuite/level1/scalv.h: Invalid typename in typed_scalv()."); } - -template -static void scalv(char conj_alpha, gtint_t n, T alpha, T* x, gtint_t incx) +template +static void scalv(char conj_alpha, gtint_t n, U alpha, T* x, gtint_t incx) { + +#ifdef TEST_UPPERCASE_ARGS + conj_alpha = static_cast(std::toupper(static_cast(conj_alpha))); +#endif + +#ifdef TEST_INPUT_ARGS + // Create copy of scalar input values so we can check that they are not altered. + char conj_alpha_cpy = conj_alpha; + gtint_t n_cpy = n; + U alpha_cpy = alpha; + gtint_t incx_cpy = incx; +#endif + #ifdef TEST_BLAS - scalv_( n, alpha, x, incx ); + scalv_( n, alpha, x, incx ); +#elif TEST_BLAS_BLIS_IMPL + scalv_blis_impl( n, alpha, x, incx ); #elif TEST_CBLAS - cblas_scalv( n, alpha, x, incx ); + cblas_scalv( n, alpha, x, incx ); #elif TEST_BLIS_TYPED - typed_scalv( conj_alpha, n, alpha, x, incx ); + typed_scalv( conj_alpha, n, alpha, x, incx ); #else throw std::runtime_error("Error in testsuite/level1/scalv.h: No interfaces are set to be tested."); #endif + +#ifdef TEST_INPUT_ARGS + //---------------------------------------------------------- + // Check scalar inputs have not been modified. + //---------------------------------------------------------- + + computediff( "conj_alpha", conj_alpha, conj_alpha_cpy ); + computediff( "n", n, n_cpy ); + computediff( "alpha", alpha, alpha_cpy ); + computediff( "incx", incx, incx_cpy ); +#endif } diff --git a/gtestsuite/testsuite/level1/scalv/scalv_extreme_cases.cpp b/gtestsuite/testsuite/level1/scalv/scalv_extreme_cases.cpp deleted file mode 100644 index 9ac6c0d4ed..0000000000 --- a/gtestsuite/testsuite/level1/scalv/scalv_extreme_cases.cpp +++ /dev/null @@ -1,90 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include "test_scalv.h" - -template -class xscalv : public ::testing::Test {}; -typedef ::testing::Types TypeParam; -TYPED_TEST_SUITE(xscalv, TypeParam); - -TYPED_TEST(xscalv, zero_alpha_x_fp) -{ - using T = TypeParam; - gtint_t n = 10, incx = 1; - std::vector x(n); - // Initialize x with random numbers. - testinghelpers::datagenerators::randomgenerators( -10, 10, n, incx, x.data(), BLIS_ELEMENT_TYPE ); - std::vector x_ref(x); - T alpha = T{0}; - - testinghelpers::ref_scalv('n', n, alpha, x_ref.data(), incx); - //---------------------------------------------------------- - // Call BLIS function. - //---------------------------------------------------------- - scalv('n', n, alpha, x.data(), incx); - - //---------------------------------------------------------- - // Compute component-wise error. - //---------------------------------------------------------- - // Set the threshold for the errors: - double thresh = testinghelpers::getEpsilon(); - computediff( n, x.data(), x_ref.data(), incx, thresh ); -} - -TYPED_TEST(xscalv, zero_alpha_x_inf) -{ - using T = TypeParam; - gtint_t n = 10, incx = 1; - std::vector x(n); - // Initialize x with random numbers. - testinghelpers::datagenerators::randomgenerators( -10, 10, n, incx, x.data(), BLIS_ELEMENT_TYPE ); - x[3] = 1.0/0.0; - std::vector x_ref(x); - T alpha = T{0}; - testinghelpers::ref_scalv('n', n, alpha, x_ref.data(), incx); - - //---------------------------------------------------------- - // Call BLIS function. - //---------------------------------------------------------- - scalv('n', n, alpha, x.data(), incx); - - //---------------------------------------------------------- - // Compute component-wise error. - //---------------------------------------------------------- - // Set the threshold for the errors: - double thresh = testinghelpers::getEpsilon(); - computediff( n, x.data(), x_ref.data(), incx, thresh ); -} diff --git a/gtestsuite/testsuite/level1/scalv/sscalv/sscalv_generic.cpp b/gtestsuite/testsuite/level1/scalv/sscalv/sscalv_generic.cpp new file mode 100644 index 0000000000..eeaaac46d8 --- /dev/null +++ b/gtestsuite/testsuite/level1/scalv/sscalv/sscalv_generic.cpp @@ -0,0 +1,226 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "level1/scalv/test_scalv.h" + +class sscalvGeneric : + public ::testing::TestWithParam> {}; + + +// Tests using random integers as vector elements. +TEST_P( sscalvGeneric, API ) +{ + using T = float; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes whether alpha or conj(alpha) will be used: + char conj_alpha = std::get<0>(GetParam()); + // vector length: + gtint_t n = std::get<1>(GetParam()); + // stride size for x: + gtint_t incx = std::get<2>(GetParam()); + // alpha + T alpha = std::get<3>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite scalv.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + float thresh; + if (n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO() || alpha == testinghelpers::ONE()) + thresh = 0.0; + else + thresh = testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_scalv( conj_alpha, n, incx, alpha, thresh ); +} + +// Black box testing for generic use of sscal. +INSTANTIATE_TEST_SUITE_P( + unitPositiveIncrementSmall, + sscalvGeneric, + ::testing::Combine( + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values('n'), + // m: size of vector. + ::testing::Range(gtint_t(1), gtint_t(101), 1), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1) + ), + // alpha: value of scalar. + ::testing::Values( + float( 7.0), + float(-3.0) + ) + ), + (::scalvGenericPrint()) + ); + +// Black box testing for generic use of dscal. +INSTANTIATE_TEST_SUITE_P( + unitPositiveIncrementLarge, + sscalvGeneric, + ::testing::Combine( + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values('n'), + // m: size of vector. + ::testing::Values(gtint_t(111), gtint_t(193), gtint_t(403)), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1) + ), + // alpha: value of scalar. + ::testing::Values( + float( 7.0), + float(-3.0) + ) + ), + (::scalvGenericPrint()) + ); + +INSTANTIATE_TEST_SUITE_P( + nonUnitPositiveIncrementSmall, + sscalvGeneric, + ::testing::Combine( + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values('n'), + // m: size of vector. + ::testing::Range(gtint_t(1), gtint_t(17), 1), + // incx: stride of x vector. + ::testing::Values( + gtint_t(2), + gtint_t(41) + ), + // alpha: value of scalar. + ::testing::Values( + float( 7.0), + float(-3.0) + ) + ), + (::scalvGenericPrint()) + ); + +INSTANTIATE_TEST_SUITE_P( + nonUnitPositiveIncrementLarge, + sscalvGeneric, + ::testing::Combine( + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values('n'), + // m: size of vector. + ::testing::Values(gtint_t(111), gtint_t(193), gtint_t(403)), + // incx: stride of x vector. + ::testing::Values( + gtint_t(2), + gtint_t(41) + ), + // alpha: value of scalar. + ::testing::Values( + float( 7.0), + float(-3.0) + ) + ), + (::scalvGenericPrint()) + ); + +#ifndef TEST_BLIS_TYPED +// alpha=0 testing only for BLAS and CBLAS as +// BLIS uses setv and won't propagate Inf and NaNs +INSTANTIATE_TEST_SUITE_P( + alphaZero, + sscalvGeneric, + ::testing::Combine( + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values('n'), + // m: size of vector. + ::testing::Range(gtint_t(1), gtint_t(101), 1), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1), + gtint_t(2), + gtint_t(41) + ), + // alpha: value of scalar. + ::testing::Values( + float( 0.0) + ) + ), + (::scalvGenericPrint()) + ); +#endif + +#ifdef TEST_BLIS_TYPED +// Test when conjugate of x is used as an argument. This option is BLIS-api specific. +// Only test very few cases as sanity check since conj(x) = x for real types. +// We can modify the values using implementantion details. +INSTANTIATE_TEST_SUITE_P( + Conjalpha, + sscalvGeneric, + ::testing::Combine( + ::testing::Values('c'), // c: use conjugate + ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector takes values from 10 to 100 with step size of 10. + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(float(-3.0)) // alpha + ), + (::scalvGenericPrint()) + ); +#endif + +#ifndef TEST_BLIS_TYPED +// Test for negative increments. +// Only test very few cases as sanity check. +// We can modify the values using implementantion details. +INSTANTIATE_TEST_SUITE_P( + NegativeIncrements, + sscalvGeneric, + ::testing::Combine( + ::testing::Values('n'), // n: use x, c: use conj(x) + ::testing::Range(gtint_t(10), gtint_t(31), 10), // m size of vector takes values from 10 to 100 with step size of 10. + ::testing::Values(gtint_t(-2), gtint_t(-1)), // stride size for x + ::testing::Values(3) // alpha + ), + (::scalvGenericPrint()) + ); +#endif diff --git a/gtestsuite/testsuite/level1/scalv/sscalv_generic.cpp b/gtestsuite/testsuite/level1/scalv/sscalv_generic.cpp deleted file mode 100644 index e00f5effa2..0000000000 --- a/gtestsuite/testsuite/level1/scalv/sscalv_generic.cpp +++ /dev/null @@ -1,160 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include "test_scalv.h" - -class sscalvGenericTest : - public ::testing::TestWithParam> {}; - - -// Tests using random integers as vector elements. -TEST_P( sscalvGenericTest, RandomData ) -{ - using T = float; - //---------------------------------------------------------- - // Initialize values from the parameters passed through - // test suite instantiation (INSTANTIATE_TEST_SUITE_P). - //---------------------------------------------------------- - // denotes whether alpha or conj(alpha) will be used: - char conj_alpha = std::get<0>(GetParam()); - // vector length: - gtint_t n = std::get<1>(GetParam()); - // stride size for x: - gtint_t incx = std::get<2>(GetParam()); - // alpha - T alpha = std::get<3>(GetParam()); - - // Set the threshold for the errors: - double thresh = testinghelpers::getEpsilon(); - //---------------------------------------------------------- - // Call generic test body using those parameters - //---------------------------------------------------------- - test_scalv( conj_alpha, n, incx, alpha, thresh ); -} - -// Used to generate a test case with a sensible name. -// Beware that we cannot use fp numbers (e.g., 2.3) in the names, -// so we are only printing int(2.3). This should be enough for debugging purposes. -// If this poses an issue, please reach out. -class sscalvGenericTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char conj = std::get<0>(str.param); - gtint_t n = std::get<1>(str.param); - gtint_t incx = std::get<2>(str.param); - float alpha = std::get<3>(str.param); - #ifdef TEST_BLAS - std::string str_name = "sscal_"; - #elif TEST_CBLAS - std::string str_name = "cblas_sscal"; - #else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_sscalv"; - #endif - str_name += "_" + std::to_string(n); - str_name += "_" + std::string(&conj, 1); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name += "_" + incx_str; - std::string alpha_str = ( alpha > 0) ? std::to_string(int(alpha)) : "m" + std::to_string(int(std::abs(alpha))); - str_name = str_name + "_a" + alpha_str; - return str_name; - } -}; - -// Black box testing for generic and main use of sscal. -INSTANTIATE_TEST_SUITE_P( - Blackbox, - sscalvGenericTest, - ::testing::Combine( - ::testing::Values('n'), // n: use x, not conj(x) (since it is real) - ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. - ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(float(3.0), float(-5.0)) // alpha - ), - ::sscalvGenericTestPrint() - ); - -#ifdef TEST_BLIS_TYPED -// Test when conjugate of x is used as an argument. This option is BLIS-api specific. -// Only test very few cases as sanity check since conj(x) = x for real types. -// We can modify the values using implementantion details. -INSTANTIATE_TEST_SUITE_P( - Conjalpha, - sscalvGenericTest, - ::testing::Combine( - ::testing::Values('c'), // c: use conjugate - ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector takes values from 10 to 100 with step size of 10. - ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(float(9.0)) // alpha - ), - ::sscalvGenericTestPrint() - ); -#endif - -// Test for non-unit increments. -// Only test very few cases as sanity check. -// We can modify the values using implementantion details. -INSTANTIATE_TEST_SUITE_P( - NonUnitPositiveIncrements, - sscalvGenericTest, - ::testing::Combine( - ::testing::Values('n'), // n: use x - ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector takes values from 10 to 100 with step size of 10. - ::testing::Values(gtint_t(2), gtint_t(11)), //(gtint_t(-5), gtint_t(-17)) // stride size for x - ::testing::Values(float(2.0)) // alpha - ), - ::sscalvGenericTestPrint() - ); - - -#ifndef TEST_BLIS_TYPED -// Test for negative increments. -// Only test very few cases as sanity check. -// We can modify the values using implementantion details. -INSTANTIATE_TEST_SUITE_P( - NegativeIncrements, - sscalvGenericTest, - ::testing::Combine( - ::testing::Values('n'), // n: use x, c: use conj(x) - ::testing::Range(gtint_t(10), gtint_t(31), 10), // m size of vector takes values from 10 to 100 with step size of 10. - ::testing::Values(gtint_t(-2), gtint_t(-1)), // stride size for x - ::testing::Values(3) // alpha - ), - ::sscalvGenericTestPrint() - ); -#endif diff --git a/gtestsuite/testsuite/level1/scalv/test_scalv.h b/gtestsuite/testsuite/level1/scalv/test_scalv.h index 4c5437d722..e4663da970 100644 --- a/gtestsuite/testsuite/level1/scalv/test_scalv.h +++ b/gtestsuite/testsuite/level1/scalv/test_scalv.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -39,31 +39,110 @@ #include "inc/check_error.h" /** - * @brief Generic test body for axpby operation. + * @brief Generic test body for scalv operation. */ +template +static void test_scalv( char conja_alpha, gtint_t n, gtint_t incx, U alpha, double thresh ) +{ + //---------------------------------------------------------- + // Initialize vector with random numbers. + //---------------------------------------------------------- + std::vector x = testinghelpers::get_random_vector( -10, 10, n, incx ); + if (alpha == testinghelpers::ZERO()) + testinghelpers::set_vector( n, incx, x.data(), testinghelpers::aocl_extreme() ); + + //---------------------------------------------------------- + // Call reference implementation to get ref results. + //---------------------------------------------------------- + // Create a copy of y so that we can check reference results. + std::vector x_ref(x); + testinghelpers::ref_scalv( conja_alpha, n, alpha, x_ref.data(), incx ); + + //---------------------------------------------------------- + // Call BLIS function. + //---------------------------------------------------------- + scalv( conja_alpha, n, alpha, x.data(), incx ); -template -static void test_scalv( char conja_alpha, gtint_t n, gtint_t incx, T alpha, double thresh ) + //---------------------------------------------------------- + // Compute component-wise error. + //---------------------------------------------------------- + computediff( "x", n, x.data(), x_ref.data(), incx, thresh, true ); +} + +/** + * @brief Used to insert Exception Values in x vector. + */ +template +static void test_scalv( char conja_alpha, gtint_t n, gtint_t incx, gtint_t xi, + T x_exval, U alpha, double thresh ) { //---------------------------------------------------------- // Initialize vector with random numbers. //---------------------------------------------------------- std::vector x = testinghelpers::get_random_vector( -10, 10, n, incx ); + // Update the value at index xi to an extreme value, x_exval. + if ( -1 < xi && xi < n ) x[xi * incx] = x_exval; + else return; + //---------------------------------------------------------- // Call reference implementation to get ref results. //---------------------------------------------------------- // Create a copy of y so that we can check reference results. std::vector x_ref(x); - testinghelpers::ref_scalv( conja_alpha, n, alpha, x_ref.data(), incx ); + testinghelpers::ref_scalv( conja_alpha, n, alpha, x_ref.data(), incx ); //---------------------------------------------------------- // Call BLIS function. //---------------------------------------------------------- - scalv( conja_alpha, n, alpha, x.data(), incx ); + scalv( conja_alpha, n, alpha, x.data(), incx ); //---------------------------------------------------------- // Compute component-wise error. //---------------------------------------------------------- - computediff( n, x.data(), x_ref.data(), incx, thresh ); + computediff( "x", n, x.data(), x_ref.data(), incx, thresh, true ); } + + +// Test-case logger : Used to print the test-case details based on parameters +template +class scalvGenericPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char conjalpha = std::get<0>(str.param); + gtint_t n = std::get<1>(str.param); + gtint_t incx = std::get<2>(str.param); + U alpha = std::get<3>(str.param); + + std::string str_name = API_PRINT; + str_name += "_n_" + std::to_string(n); + str_name += "_conjalpha_" + std::string(&conjalpha, 1); + str_name += "_incx_" + testinghelpers::get_value_string(incx); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + return str_name; + } +}; + +template +class scalvEVTPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char conjalpha = std::get<0>(str.param); + gtint_t n = std::get<1>(str.param); + gtint_t incx = std::get<2>(str.param); + gtint_t xi = std::get<3>(str.param); + T x_exval = std::get<4>(str.param); + U alpha = std::get<5>(str.param); + + std::string str_name = API_PRINT; + str_name += "_n_" + std::to_string(n); + str_name += "_conjalpha_" + std::string(&conjalpha, 1); + str_name += "_incx_" + testinghelpers::get_value_string(incx); + str_name = str_name + "_X_" + std::to_string(xi); + str_name = str_name + "_" + testinghelpers::get_value_string(x_exval); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + return str_name; + } +}; diff --git a/gtestsuite/testsuite/level1/scalv/zdscalv/zdscalv_evt.cpp b/gtestsuite/testsuite/level1/scalv/zdscalv/zdscalv_evt.cpp new file mode 100644 index 0000000000..fbaaf4dd74 --- /dev/null +++ b/gtestsuite/testsuite/level1/scalv/zdscalv/zdscalv_evt.cpp @@ -0,0 +1,370 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "level1/scalv/test_scalv.h" + +class zdscalvEVT : + public ::testing::TestWithParam> {}; // alpha + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(zdscalvEVT); + +// Tests using random integers as vector elements. +TEST_P( zdscalvEVT, API ) +{ + using T = dcomplex; + using RT = double; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes whether alpha or conj(alpha) will be used: + char conj_alpha = std::get<0>(GetParam()); + // vector length: + gtint_t n = std::get<1>(GetParam()); + // stride size for x: + gtint_t incx = std::get<2>(GetParam()); + // index of extreme value for x: + gtint_t xi = std::get<3>(GetParam()); + // extreme value for x: + T x_exval = std::get<4>(GetParam()); + // alpha: + RT alpha = std::get<5>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite scalv.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO() || alpha == testinghelpers::ONE()) + thresh = 0.0; + else + thresh = testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_scalv( conj_alpha, n, incx, xi, x_exval, alpha, thresh ); +} + +// bli_zdscal not present in BLIS +#ifndef TEST_BLIS_TYPED + +static double NaN = std::numeric_limits::quiet_NaN(); +static double Inf = std::numeric_limits::infinity(); + +// Tests for Zen3 Architecture. +/** + * Tests for bli_zdscalv_zen_int10 (AVX2) kernel. + * Loops: + * L30 - Main loop, handles 30 elements + * L24 - handles 24 elements + * L16 - handles 16 elements + * L8 - handles 8 elements + * L4 - handles 4 elements + * L2 - handles 2 elements + * LScalar - leftover loop (also handles non-unit increments) + * + * n = 105 : L30*3 + L8 + L4 + L2 + LScalar + * Indices - Loop into which extreme value is induced + * 0, 69 - L30 + * 97 - L8 + * 101 - L4 + * 103 - L2 + * 104 - LScalar + * + * n = 79 : L30*2 + L16 + L2 + LScalar + * Indices - Loop into which extreme value is induced + * 0, 58 - L30 + * 69 - L16 + * 77 - L2 + * 78 - LScalar + * + * n = 59 : L30 + L24 + L4 + LScalar + * Indices - Loop into which extreme value is induced + * 0 - L30 + * 51 - L24 + * 55 - L4 + * 58 - LScalar +*/ +// EVT with unit stride vector containing Infs/NaNs. +INSTANTIATE_TEST_SUITE_P( + vec_unitStride_zen3, + zdscalvEVT, + ::testing::Combine( + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values( 'n' +#ifdef TEST_BLIS_TYPED + , 'c' // conjugate option is BLIS-api specific. +#endif + ), + // m: size of vector. + ::testing::Values( gtint_t(105), + gtint_t( 79), + gtint_t( 59) + ), + // incx: stride of x vector. + ::testing::Values( gtint_t(1) ), + // xi: index of extreme value for x. + ::testing::Values( // n = 105 + gtint_t(0), // L30 + gtint_t(97), // L8 + gtint_t(101), // L4 + gtint_t(103), // L2 + gtint_t(104), // LScalar + + // n = 79 + gtint_t(69), // L16 + gtint_t(77), // L2 + gtint_t(78), // LScalar + + // n = 59 + gtint_t(51), // L24 + gtint_t(55), // L4 + gtint_t(58) // LScalar + ), + // x_exval: extreme value for x. + ::testing::Values( dcomplex{ NaN, 0.0}, + dcomplex{ Inf, 0.0}, + dcomplex{-Inf, 0.0}, + dcomplex{ 0.0, Inf}, + dcomplex{-2.1, NaN}, + dcomplex{ 1.2, -Inf}, + dcomplex{ NaN, Inf}, + dcomplex{ Inf, NaN}, + dcomplex{ NaN, NaN}, + dcomplex{ Inf, -Inf} + ), + // alpha: value of scalar. + ::testing::Values( double(-5.1), + double(-1.0), + double( 0.0), + double( 1.0), + double( 7.3) + ) + ), + (::scalvEVTPrint()) +); + +// Tests for Zen4 Architecture. +/** + * Tests for bli_zdscalv_zen_int_avx512 (AVX512) kernel. + * Loops: + * L16 - Main loop, handles 16 elements + * L8 - handles 8 elements + * L4 - handles 4 elements + * L2 - handles 2 elements + * LScalar - leftover loop (also handles non-unit increments) + * + * n = 63 : L16*3 + L8 + L4 + L2 + LScalar + * Indices - Loop into which extreme value is induced + * 0, 31 - L16 + * 48 - L8 + * 56 - L4 + * 60 - L2 + * 62 - LScalar +*/ +// EVT with unit stride vector containing Infs/NaNs. +INSTANTIATE_TEST_SUITE_P( + vec_unitStride_zen4, + zdscalvEVT, + ::testing::Combine( + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values( 'n' +#ifdef TEST_BLIS_TYPED + , 'c' // conjugate option is BLIS-api specific. +#endif + ), + // m: size of vector. + ::testing::Values( gtint_t(63) ), + // incx: stride of x vector. + ::testing::Values( gtint_t(1) ), + // xi: index of extreme value for x. + ::testing::Values( // n = 63 + gtint_t(0), // L16 + gtint_t(31), // l16 + gtint_t(48), // L8 + gtint_t(56), // L4 + gtint_t(60), // L2 + gtint_t(62) // LScalar + ), + // x_exval: extreme value for x. + ::testing::Values( dcomplex{ NaN, 0.0}, + dcomplex{ Inf, 0.0}, + dcomplex{-Inf, 0.0}, + dcomplex{ 0.0, Inf}, + dcomplex{-2.1, NaN}, + dcomplex{ 1.2, -Inf}, + dcomplex{ NaN, Inf}, + dcomplex{ Inf, NaN}, + dcomplex{ NaN, NaN}, + dcomplex{ Inf, -Inf} + ), + // alpha: value of scalar. + ::testing::Values( double(-5.1), + double(-1.0), + double( 0.0), + double( 1.0), + double( 7.3) + ) + ), + (::scalvEVTPrint()) +); + +// EVT with non-unit stride vector containing Infs/NaNs. +INSTANTIATE_TEST_SUITE_P( + vec_nonUnitStride, + zdscalvEVT, + ::testing::Combine( + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values( 'n' +#ifdef TEST_BLIS_TYPED + , 'c' // conjugate option is BLIS-api specific. +#endif + ), + // m: size of vector. + ::testing::Values( gtint_t(55) ), + // incx: stride of x vector. + ::testing::Values( gtint_t(3) ), + // xi: index of extreme value for x. + ::testing::Values( gtint_t(1), gtint_t(27), gtint_t(51) ), + // x_exval: extreme value for x. + ::testing::Values( dcomplex{ NaN, 0.0}, + dcomplex{ Inf, 0.0}, + dcomplex{-Inf, 0.0}, + dcomplex{ 0.0, Inf}, + dcomplex{-2.1, NaN}, + dcomplex{ 1.2, -Inf}, + dcomplex{ NaN, Inf}, + dcomplex{ Inf, NaN}, + dcomplex{ NaN, NaN}, + dcomplex{ Inf, -Inf} + ), + // alpha: value of scalar. + ::testing::Values( double(-5.1), + double(-1.0), + double( 0.0), + double( 1.0), + double( 7.3) + ) + ), + (::scalvEVTPrint()) +); + +// EVT with alpha containing Infs/NaNs on a unit stride vector. +INSTANTIATE_TEST_SUITE_P( + alpha_unitStride_zen3, + zdscalvEVT, + ::testing::Combine( + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values( 'n' +#ifdef TEST_BLIS_TYPED + , 'c' // conjugate option is BLIS-api specific. +#endif + ), + // m: size of vector. + ::testing::Values( gtint_t(105), + gtint_t( 79), + gtint_t( 59) ), + // incx: stride of x vector. + ::testing::Values( gtint_t(1) ), + // xi: index of extreme value for x. + ::testing::Values( gtint_t(0) ), + // x_exval: extreme value for x. + ::testing::Values( dcomplex{0.0, 0.0} ), + // alpha: value of scalar. + ::testing::Values( NaN, Inf, -Inf ) + ), + (::scalvEVTPrint()) +); + +// EVT with alpha containing Infs/NaNs on a unit stride vector. +INSTANTIATE_TEST_SUITE_P( + alpha_unitStride_zen4, + zdscalvEVT, + ::testing::Combine( + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values( 'n' +#ifdef TEST_BLIS_TYPED + , 'c' // conjugate option is BLIS-api specific. +#endif + ), + // m: size of vector. + ::testing::Values( gtint_t(63) ), + // incx: stride of x vector. + ::testing::Values( gtint_t(1) ), + // xi: index of extreme value for x. + ::testing::Values( gtint_t(0) ), + // x_exval: extreme value for x. + ::testing::Values( dcomplex{0.0, 0.0} ), + // alpha: value of scalar. + ::testing::Values( NaN, Inf, -Inf ) + ), + (::scalvEVTPrint()) +); + +// EVT with alpha containing Infs/NaNs on a unit stride vector. +INSTANTIATE_TEST_SUITE_P( + alpha_nonUnitStride, + zdscalvEVT, + ::testing::Combine( + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values( 'n' +#ifdef TEST_BLIS_TYPED + , 'c' // conjugate option is BLIS-api specific. +#endif + ), + // m: size of vector. + ::testing::Values( gtint_t(55) ), + // incx: stride of x vector. + ::testing::Values( gtint_t(3) ), + // xi: index of extreme value for x. + ::testing::Values( gtint_t(0) ), + // x_exval: extreme value for x. + ::testing::Values( dcomplex{0.0, 0.0} ), + // alpha: value of scalar. + ::testing::Values( NaN, Inf, -Inf ) + ), + (::scalvEVTPrint()) +); + +#endif // not TEST_BLIS_TYPED diff --git a/gtestsuite/testsuite/level1/scalv/zdscalv/zdscalv_generic.cpp b/gtestsuite/testsuite/level1/scalv/zdscalv/zdscalv_generic.cpp new file mode 100644 index 0000000000..f879d1afe9 --- /dev/null +++ b/gtestsuite/testsuite/level1/scalv/zdscalv/zdscalv_generic.cpp @@ -0,0 +1,268 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "level1/scalv/test_scalv.h" + +class zdscalvGeneric : + public ::testing::TestWithParam> {}; // alpha + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(zdscalvGeneric); + +// Tests using random integers as vector elements. +TEST_P( zdscalvGeneric, API ) +{ + using T = dcomplex; + using U = double; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes whether alpha or conj(alpha) will be used: + char conj_alpha = std::get<0>(GetParam()); + // vector length: + gtint_t n = std::get<1>(GetParam()); + // stride size for x: + gtint_t incx = std::get<2>(GetParam()); + // alpha + U alpha = std::get<3>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite scalv.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO() || alpha == testinghelpers::ONE()) + thresh = 0.0; + else + thresh = testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_scalv( conj_alpha, n, incx, alpha, thresh ); +} + +// bli_zdscal not present in BLIS +#ifndef TEST_BLIS_TYPED + +// Black box testing for generic use of dscal. +INSTANTIATE_TEST_SUITE_P( + unitPositiveIncrementSmall, + zdscalvGeneric, + ::testing::Combine( + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values('n'), + // m: size of vector. + ::testing::Range(gtint_t(1), gtint_t(101), 1), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1) + ), + // alpha: value of scalar. + ::testing::Values( + double( 7.0), + double(-3.0) + ) + ), + (::scalvGenericPrint()) + ); + +// Black box testing for generic use of dscal. +INSTANTIATE_TEST_SUITE_P( + unitPositiveIncrementLarge, + zdscalvGeneric, + ::testing::Combine( + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values('n'), + // m: size of vector. + ::testing::Values(gtint_t(111), gtint_t(193), gtint_t(403)), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1) + ), + // alpha: value of scalar. + ::testing::Values( + double( 7.0), + double(-3.0) + ) + ), + (::scalvGenericPrint()) + ); + +INSTANTIATE_TEST_SUITE_P( + nonUnitPositiveIncrementSmall, + zdscalvGeneric, + ::testing::Combine( + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values('n'), + // m: size of vector. + ::testing::Range(gtint_t(1), gtint_t(9), 1), + // incx: stride of x vector. + ::testing::Values( + gtint_t(2), + gtint_t(41) + ), + // alpha: value of scalar. + ::testing::Values( + double( 7.0), + double(-3.0) + ) + ), + (::scalvGenericPrint()) + ); + +INSTANTIATE_TEST_SUITE_P( + nonUnitPositiveIncrementLarge, + zdscalvGeneric, + ::testing::Combine( + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values('n'), + // m: size of vector. + ::testing::Values(gtint_t(111), gtint_t(193), gtint_t(403)), + // incx: stride of x vector. + ::testing::Values( + gtint_t(2), + gtint_t(41) + ), + // alpha: value of scalar. + ::testing::Values( + double( 7.0), + double(-3.0) + ) + ), + (::scalvGenericPrint()) + ); + +// alpha=0 testing only for BLAS and CBLAS as +// BLIS uses setv and won't propagate Inf and NaNs +INSTANTIATE_TEST_SUITE_P( + alphaZero, + zdscalvGeneric, + ::testing::Combine( + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values('n'), + // m: size of vector. + ::testing::Range(gtint_t(1), gtint_t(101), 1), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1), + gtint_t(2), + gtint_t(41) + ), + // alpha: value of scalar. + ::testing::Values( + double( 0.0) + ) + ), + (::scalvGenericPrint()) + ); + +// Test for negative increments. +// Only test very few cases as sanity check. +// We can modify the values using implementantion details. +INSTANTIATE_TEST_SUITE_P( + NegativeIncrements, + zdscalvGeneric, + ::testing::Combine( + ::testing::Values('n'), // n: use x, c: use conj(x) + ::testing::Range(gtint_t(10), gtint_t(31), 10), // m size of vector takes values from 10 to 100 with step size of 10. + ::testing::Values(gtint_t(-2), gtint_t(-1)), // stride size for x + ::testing::Values(3) // alpha + ), + (::scalvGenericPrint()) + ); + +#if defined(BLIS_ENABLE_OPENMP) && defined(AOCL_DYNAMIC) +INSTANTIATE_TEST_SUITE_P( + AOCLDynamic, + zdscalvGeneric, + ::testing::Combine( + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values('n'), + // m: size of vector. + ::testing::Values( + gtint_t( 10000), // nt_ideal = 1 + gtint_t( 20000), // nt_ideal = 4 + gtint_t( 486919), // nt_ideal = 8 + gtint_t( 1000000), // nt_ideal = 8 + gtint_t( 2500000), // nt_ideal = 12 + gtint_t( 5000000), // nt_ideal = 32 + gtint_t( 7000000) // nt_ideal = max_available + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1), + gtint_t(3) + ), + // alpha: value of scalar. + ::testing::Values( + double( 7.0) + ) + ), + (::scalvGenericPrint()) + ); + +INSTANTIATE_TEST_SUITE_P( + AOCLDynamicAlphaZero, + zdscalvGeneric, + ::testing::Combine( + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values('n'), + // m: size of vector. + ::testing::Values( + gtint_t( 486919), // nt_ideal = 8 + gtint_t( 7000000) // nt_ideal = max_available + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1), + gtint_t(3) + ), + // alpha: value of scalar. + ::testing::Values( + double( 0.0) + ) + ), + (::scalvGenericPrint()) + ); +#endif + +#endif // not TEST_BLIS_TYPED diff --git a/gtestsuite/testsuite/level1/scalv/zscalv/zscalv_evt.cpp b/gtestsuite/testsuite/level1/scalv/zscalv/zscalv_evt.cpp new file mode 100644 index 0000000000..e597ffb95f --- /dev/null +++ b/gtestsuite/testsuite/level1/scalv/zscalv/zscalv_evt.cpp @@ -0,0 +1,246 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "level1/scalv/test_scalv.h" + +class zscalvEVT : + public ::testing::TestWithParam> {}; // alpha + + +// Tests using random integers as vector elements. +TEST_P( zscalvEVT, API ) +{ + using T = dcomplex; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes whether alpha or conj(alpha) will be used: + char conj_alpha = std::get<0>(GetParam()); + // vector length: + gtint_t n = std::get<1>(GetParam()); + // stride size for x: + gtint_t incx = std::get<2>(GetParam()); + // index of extreme value for x: + gtint_t xi = std::get<3>(GetParam()); + // extreme value for x: + T x_exval = std::get<4>(GetParam()); + // alpha: + T alpha = std::get<5>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite scalv.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO() || alpha == testinghelpers::ONE()) + thresh = 0.0; + else + thresh = testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_scalv( conj_alpha, n, incx, xi, x_exval, alpha, thresh ); +} + +static double NaN = std::numeric_limits::quiet_NaN(); +static double Inf = std::numeric_limits::infinity(); + +// Tests for Zen3 Architecture. +/** + * Tests for bli_zscalv_zen_int (AVX2) kernel. + * Loops: + * L8 - Main loop, handles 8 elements + * L4 - handles 4 elements + * L2 - handles 2 elements + * LScalar - leftover loop (also handles non-unit increments) +*/ +// EVT with unit stride vector containing Infs/NaNs. +INSTANTIATE_TEST_SUITE_P( + vec_unitStride_zen3, + zscalvEVT, + ::testing::Combine( + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values( 'n' +#ifdef TEST_BLIS_TYPED + , 'c' // conjugate option is BLIS-api specific. +#endif + ), + // m: size of vector. + ::testing::Values( gtint_t(71) ), + // incx: stride of x vector. + ::testing::Values( gtint_t(1) ), + // xi: index of extreme value for x. + ::testing::Values( gtint_t(0), gtint_t(64), gtint_t(67), + gtint_t(69), gtint_t(70) + ), + // x_exval: extreme value for x. + ::testing::Values( dcomplex{ NaN, 0.0}, + dcomplex{ Inf, 0.0}, + dcomplex{-Inf, 0.0}, + dcomplex{ 0.0, Inf}, + dcomplex{-2.1, NaN}, + dcomplex{ 1.2, -Inf}, + dcomplex{ NaN, Inf}, + dcomplex{ Inf, NaN}, + dcomplex{ NaN, NaN}, + dcomplex{ Inf, -Inf} + ), + // alpha: value of scalar. + ::testing::Values( dcomplex{-5.1, -7.3}, + dcomplex{-1.0, -1.0}, + dcomplex{ 0.0, 0.0}, + dcomplex{ 1.0, 1.0}, + dcomplex{ 7.3, 5.1} + ) + ), + (::scalvEVTPrint()) +); + +// EVT with non-unit stride vector containing Infs/NaNs. +INSTANTIATE_TEST_SUITE_P( + vec_nonUnitStride, + zscalvEVT, + ::testing::Combine( + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values( 'n' +#ifdef TEST_BLIS_TYPED + , 'c' // conjugate option is BLIS-api specific. +#endif + ), + // m: size of vector. + ::testing::Values( gtint_t(55) ), + // incx: stride of x vector. + ::testing::Values( gtint_t(3) ), + // xi: index of extreme value for x. + ::testing::Values( gtint_t(1), gtint_t(27), gtint_t(51) ), + // x_exval: extreme value for x. + ::testing::Values( dcomplex{ NaN, NaN}, + dcomplex{ NaN, Inf}, + dcomplex{ NaN, -Inf}, + dcomplex{ Inf, NaN}, + dcomplex{ Inf, Inf}, + dcomplex{ Inf, -Inf}, + dcomplex{-Inf, NaN}, + dcomplex{-Inf, Inf}, + dcomplex{-Inf, -Inf} + ), + // alpha: value of scalar. + ::testing::Values( dcomplex{-5.1, -7.3}, + dcomplex{-1.0, -1.0}, + dcomplex{ 0.0, 0.0}, + dcomplex{ 1.0, 1.0}, + dcomplex{ 7.3, 5.1} + ) + ), + (::scalvEVTPrint()) +); + +// EVT with alpha containing Infs/NaNs on a unit stride vector. +INSTANTIATE_TEST_SUITE_P( + alpha_unitStride_zen3, + zscalvEVT, + ::testing::Combine( + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values( 'n' +#ifdef TEST_BLIS_TYPED + , 'c' // conjugate option is BLIS-api specific. +#endif + ), + // m: size of vector. + ::testing::Values( gtint_t(71) ), + // incx: stride of x vector. + ::testing::Values( gtint_t(1) ), + // xi: index of extreme value for x. + ::testing::Values( gtint_t(0) ), + // x_exval: extreme value for x. + ::testing::Values( dcomplex{0.0, 0.0} ), + // alpha: value of scalar. + ::testing::Values( dcomplex{ NaN, NaN}, + dcomplex{ NaN, Inf}, + dcomplex{ NaN, -Inf}, + dcomplex{ Inf, NaN}, + dcomplex{ Inf, Inf}, + dcomplex{ Inf, -Inf}, + dcomplex{-Inf, NaN}, + dcomplex{-Inf, Inf}, + dcomplex{-Inf, -Inf} + ) + ), + (::scalvEVTPrint()) +); + +// EVT with alpha containing Infs/NaNs on a unit stride vector. +INSTANTIATE_TEST_SUITE_P( + alpha_nonUnitStride, + zscalvEVT, + ::testing::Combine( + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values( 'n' +#ifdef TEST_BLIS_TYPED + , 'c' // conjugate option is BLIS-api specific. +#endif + ), + // m: size of vector. + ::testing::Values( gtint_t(55) ), + // incx: stride of x vector. + ::testing::Values( gtint_t(3) ), + // xi: index of extreme value for x. + ::testing::Values( gtint_t(0) ), + // x_exval: extreme value for x. + ::testing::Values( dcomplex{0.0, 0.0} ), + // alpha: value of scalar. + ::testing::Values( dcomplex{ NaN, NaN}, + dcomplex{ NaN, Inf}, + dcomplex{ NaN, -Inf}, + dcomplex{ Inf, NaN}, + dcomplex{ Inf, Inf}, + dcomplex{ Inf, -Inf}, + dcomplex{-Inf, NaN}, + dcomplex{-Inf, Inf}, + dcomplex{-Inf, -Inf} + ) + ), + (::scalvEVTPrint()) +); diff --git a/gtestsuite/testsuite/level1/scalv/zscalv/zscalv_generic.cpp b/gtestsuite/testsuite/level1/scalv/zscalv/zscalv_generic.cpp new file mode 100644 index 0000000000..f908fa8d17 --- /dev/null +++ b/gtestsuite/testsuite/level1/scalv/zscalv/zscalv_generic.cpp @@ -0,0 +1,231 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "level1/scalv/test_scalv.h" + +class zscalvGeneric : + public ::testing::TestWithParam> {}; // alpha + + +// Tests using random integers as vector elements. +TEST_P( zscalvGeneric, API ) +{ + using T = dcomplex; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes whether alpha or conj(alpha) will be used: + char conj_alpha = std::get<0>(GetParam()); + // vector length: + gtint_t n = std::get<1>(GetParam()); + // stride size for x: + gtint_t incx = std::get<2>(GetParam()); + // alpha + T alpha = std::get<3>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite scalv.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO() || alpha == testinghelpers::ONE()) + thresh = 0.0; + else + thresh = testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_scalv( conj_alpha, n, incx, alpha, thresh ); +} + +// Black box testing for generic use of dscal. +INSTANTIATE_TEST_SUITE_P( + unitPositiveIncrementSmall, + zscalvGeneric, + ::testing::Combine( + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values('n'), + // m: size of vector. + ::testing::Range(gtint_t(1), gtint_t(101), 1), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1) + ), + // alpha: value of scalar. + ::testing::Values( + dcomplex{-5.1, -7.3}, + dcomplex{ 1.0, 1.0}, + dcomplex{ 7.3, 5.1} + ) + ), + (::scalvGenericPrint()) + ); + +// Black box testing for generic use of dscal. +INSTANTIATE_TEST_SUITE_P( + unitPositiveIncrementLarge, + zscalvGeneric, + ::testing::Combine( + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values('n'), + // m: size of vector. + ::testing::Values(gtint_t(111), gtint_t(193), gtint_t(403)), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1) + ), + // alpha: value of scalar. + ::testing::Values( + dcomplex{-5.1, -7.3}, + dcomplex{ 1.0, 1.0}, + dcomplex{ 7.3, 5.1} + ) + ), + (::scalvGenericPrint()) + ); + +INSTANTIATE_TEST_SUITE_P( + nonUnitPositiveIncrementSmall, + zscalvGeneric, + ::testing::Combine( + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values('n'), + // m: size of vector. + ::testing::Range(gtint_t(1), gtint_t(9), 1), + // incx: stride of x vector. + ::testing::Values( + gtint_t(2), + gtint_t(41) + ), + // alpha: value of scalar. + ::testing::Values( + dcomplex{-5.1, -7.3}, + dcomplex{ 1.0, 1.0}, + dcomplex{ 7.3, 5.1} + ) + ), + (::scalvGenericPrint()) + ); + +INSTANTIATE_TEST_SUITE_P( + nonUnitPositiveIncrementLarge, + zscalvGeneric, + ::testing::Combine( + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values('n'), + // m: size of vector. + ::testing::Values(gtint_t(111), gtint_t(193), gtint_t(403)), + // incx: stride of x vector. + ::testing::Values( + gtint_t(2), + gtint_t(41) + ), + // alpha: value of scalar. + ::testing::Values( + dcomplex{-5.1, -7.3}, + dcomplex{ 1.0, 1.0}, + dcomplex{ 7.3, 5.1} + ) + ), + (::scalvGenericPrint()) + ); + +#ifndef TEST_BLIS_TYPED +// alpha=0 testing only for BLAS and CBLAS as +// BLIS uses setv and won't propagate Inf and NaNs +INSTANTIATE_TEST_SUITE_P( + alphaZero, + zscalvGeneric, + ::testing::Combine( + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values('n'), + // m: size of vector. + ::testing::Range(gtint_t(1), gtint_t(101), 1), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1), + gtint_t(2), + gtint_t(41) + ), + // alpha: value of scalar. + ::testing::Values( + dcomplex{ 0.0, 0.0} + ) + ), + (::scalvGenericPrint()) + ); +#endif + +#ifdef TEST_BLIS_TYPED +// Test when conjugate of x is used as an argument. This option is BLIS-api specific. +// Only test very few cases as sanity check since conj(x) = x for real types. +// We can modify the values using implementantion details. +INSTANTIATE_TEST_SUITE_P( + conjalpha, + zscalvGeneric, + ::testing::Combine( + ::testing::Values('c'), // c: use conjugate + ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector takes values from 10 to 100 with step size of 10. + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(dcomplex{ 7.3, 5.1}) // alpha + ), + (::scalvGenericPrint()) + ); +#endif + +#ifndef TEST_BLIS_TYPED +// Test for negative increments. +// Only test very few cases as sanity check. +// We can modify the values using implementantion details. +INSTANTIATE_TEST_SUITE_P( + NegativeIncrements, + zscalvGeneric, + ::testing::Combine( + ::testing::Values('n'), // n: use x, c: use conj(x) + ::testing::Range(gtint_t(10), gtint_t(31), 10), // m size of vector takes values from 10 to 100 with step size of 10. + ::testing::Values(gtint_t(-2), gtint_t(-1)), // stride size for x + ::testing::Values(dcomplex{ 7.3, 5.1}) // alpha + ), + (::scalvGenericPrint()) + ); +#endif diff --git a/gtestsuite/testsuite/level1/scalv/zscalv_generic.cpp b/gtestsuite/testsuite/level1/scalv/zscalv_generic.cpp deleted file mode 100644 index 66419cbd4c..0000000000 --- a/gtestsuite/testsuite/level1/scalv/zscalv_generic.cpp +++ /dev/null @@ -1,152 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include "test_scalv.h" - -class zscalvGenericTest : - public ::testing::TestWithParam> {}; - - -// Tests using random integers as vector elements. -TEST_P( zscalvGenericTest, RandomData ) -{ - using T = dcomplex; - //---------------------------------------------------------- - // Initialize values from the parameters passed through - // test suite instantiation (INSTANTIATE_TEST_SUITE_P). - //---------------------------------------------------------- - // denotes whether alpha or conj(alpha) will be used: - char conj_alpha = std::get<0>(GetParam()); - // vector length: - gtint_t n = std::get<1>(GetParam()); - // stride size for x: - gtint_t incx = std::get<2>(GetParam()); - // alpha - T alpha = std::get<3>(GetParam()); - - // Set the threshold for the errors: - double thresh = testinghelpers::getEpsilon(); - //---------------------------------------------------------- - // Call generic test body using those parameters - //---------------------------------------------------------- - test_scalv( conj_alpha, n, incx, alpha, thresh ); -} - -// Used to generate a test case with a sensible name. -// Beware that we cannot use fp numbers (e.g., 2.3) in the names, -// so we are only printing int(2.3). This should be enough for debugging purposes. -// If this poses an issue, please reach out. -class zscalvGenericTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char conj = std::get<0>(str.param); - gtint_t n = std::get<1>(str.param); - gtint_t incx = std::get<2>(str.param); - dcomplex alpha = std::get<3>(str.param); -#ifdef TEST_BLAS - std::string str_name = "zscal_"; -#elif TEST_CBLAS - std::string str_name = "cblas_zscal"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_zscalv"; -#endif - str_name += "_" + std::to_string(n); - str_name += "_" + std::string(&conj, 1); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name += "_" + incx_str; - std::string alpha_str = ( alpha.real > 0) ? std::to_string(int(alpha.real)) : ("m" + std::to_string(int(std::abs(alpha.real)))); - alpha_str = alpha_str + "pi" + (( alpha.imag > 0) ? std::to_string(int(alpha.imag)) : ("m" + std::to_string(int(std::abs(alpha.imag))))); - str_name = str_name + "_a" + alpha_str; - return str_name; - } -}; - -// Black box testing for generic and main use of cscal. -INSTANTIATE_TEST_SUITE_P( - Blackbox, - zscalvGenericTest, - ::testing::Combine( - ::testing::Values('n' -#ifdef TEST_BLIS_TYPED - , 'c' // this option is BLIS-api specific. -#endif - ), // n: use x, c: use conj(x) - ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. - ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(dcomplex{3.0, -2.0}, dcomplex{-1.0, 4.0}) // alpha - ), - ::zscalvGenericTestPrint() - ); - - -// Test for non-unit increments. -// Only test very few cases as sanity check. -// We can modify the values using implementantion details. -INSTANTIATE_TEST_SUITE_P( - NonUnitPositiveIncrements, - zscalvGenericTest, - ::testing::Combine( - ::testing::Values('n' -#ifdef TEST_BLIS_TYPED - , 'c' // this option is BLIS-api specific. -#endif - ), // n: use x, c: use conj(x) - ::testing::Range(gtint_t(10), gtint_t(31), 10), // m size of vector takes values from 10 to 100 with step size of 10. - ::testing::Values(gtint_t(2), gtint_t(11)), //(gtint_t(-5), gtint_t(-17)) // stride size for x - ::testing::Values(dcomplex{1.0, 2.1}) // alpha - ), - ::zscalvGenericTestPrint() - ); - -#ifndef TEST_BLIS_TYPED -// Test for negative increments. -// Only test very few cases as sanity check. -// We can modify the values using implementantion details. -INSTANTIATE_TEST_SUITE_P( - NegativeIncrements, - zscalvGenericTest, - ::testing::Combine( - ::testing::Values('n'), // n: use x, c: use conj(x) - ::testing::Range(gtint_t(10), gtint_t(31), 10), // m size of vector takes values from 10 to 100 with step size of 10. - ::testing::Values(gtint_t(-2), gtint_t(-1)), // stride size for x - ::testing::Values(dcomplex{4.0, 3.1}) // alpha - ), - ::zscalvGenericTestPrint() - ); -#endif diff --git a/gtestsuite/testsuite/level1/setv/csetv_generic.cpp b/gtestsuite/testsuite/level1/setv/csetv_generic.cpp index 2a2daf72fd..18d4d590c8 100644 --- a/gtestsuite/testsuite/level1/setv/csetv_generic.cpp +++ b/gtestsuite/testsuite/level1/setv/csetv_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,12 +35,12 @@ #include #include "test_setv.h" -class csetvGenericTest : +class csetvGeneric : public ::testing::TestWithParam> {}; -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(csetvGenericTest); +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(csetvGeneric); -TEST_P( csetvGenericTest, RandomData ) +TEST_P( csetvGeneric, API ) { using T = scomplex; //---------------------------------------------------------- @@ -61,33 +61,16 @@ TEST_P( csetvGenericTest, RandomData ) test_setv( conjalpha, n, alpha, incx ); } -// Prints the test case combination -class csetvGenericTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char conj = std::get<0>(str.param); - gtint_t n = std::get<1>(str.param); - gtint_t incx = std::get<2>(str.param); - std::string str_name = "bli_csetv"; - str_name += "_" + std::to_string(n); - str_name += "_" + std::string(&conj, 1); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name += "_" + incx_str; - return str_name; - } -}; - #ifdef TEST_BLIS_TYPED // Black box testing. INSTANTIATE_TEST_SUITE_P( Blackbox, - csetvGenericTest, + csetvGeneric, ::testing::Combine( ::testing::Values('n','c'), // n: not transpose for x, c: conjugate for x ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. ::testing::Values(gtint_t(1)) // stride size for x ), - ::csetvGenericTestPrint() + ::setvGenericPrint() ); #endif diff --git a/gtestsuite/testsuite/level1/setv/dsetv_generic.cpp b/gtestsuite/testsuite/level1/setv/dsetv_generic.cpp index 6051169bbc..cf3ce4089f 100644 --- a/gtestsuite/testsuite/level1/setv/dsetv_generic.cpp +++ b/gtestsuite/testsuite/level1/setv/dsetv_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,12 +35,12 @@ #include #include "test_setv.h" -class dsetvGenericTest : +class dsetvGeneric : public ::testing::TestWithParam> {}; -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(dsetvGenericTest); +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(dsetvGeneric); -TEST_P( dsetvGenericTest, RandomData ) +TEST_P( dsetvGeneric, API ) { using T = double; //---------------------------------------------------------- @@ -61,33 +61,16 @@ TEST_P( dsetvGenericTest, RandomData ) test_setv( conjalpha, n, alpha, incx ); } -// Prints the test case combination -class dsetvGenericTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char conj = std::get<0>(str.param); - gtint_t n = std::get<1>(str.param); - gtint_t incx = std::get<2>(str.param); - std::string str_name = "bli_dsetv"; - str_name += "_" + std::to_string(n); - str_name += "_" + std::string(&conj, 1); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name += "_" + incx_str; - return str_name; - } -}; - #ifdef TEST_BLIS_TYPED // Black box testing. INSTANTIATE_TEST_SUITE_P( Blackbox, - dsetvGenericTest, + dsetvGeneric, ::testing::Combine( ::testing::Values('n'), // n: not transpose for x ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. ::testing::Values(gtint_t(1)) // stride size for x ), - ::dsetvGenericTestPrint() + ::setvGenericPrint() ); #endif diff --git a/gtestsuite/testsuite/level1/setv/setv.h b/gtestsuite/testsuite/level1/setv/setv.h index 651ec36b90..c16c35b81d 100644 --- a/gtestsuite/testsuite/level1/setv/setv.h +++ b/gtestsuite/testsuite/level1/setv/setv.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -36,6 +36,7 @@ #include "blis.h" #include "common/testing_helpers.h" +#include "inc/check_error.h" /** * @brief Performs the operation @@ -68,8 +69,23 @@ static void typed_setv(char conjalpha, gtint_t n, T* alpha, T* x, gtint_t incx) template static void setv(char conjalpha, gtint_t n, T* alpha, T* x, gtint_t incx) { + +#ifdef TEST_UPPERCASE_ARGS + conjalpha = static_cast(std::toupper(static_cast(conjalpha))); +#endif + +#ifdef TEST_INPUT_ARGS + // Create copy of scalar input values so we can check that they are not altered. + char conjalpha_cpy = conjalpha; + gtint_t n_cpy = n; + T* alpha_cpy = alpha; + gtint_t incx_cpy = incx; +#endif + #ifdef TEST_BLAS throw std::runtime_error("Error in testsuite/level1/setv.h: BLAS interface is not available."); +#elif TEST_BLAS_BLIS_IMPL + throw std::runtime_error("Error in testsuite/level1/setv.h: BLAS_BLIS_IMPL interface is not available."); #elif TEST_CBLAS throw std::runtime_error("Error in testsuite/level1/setv.h: CBLAS interface is not available."); #elif TEST_BLIS_TYPED @@ -77,4 +93,15 @@ static void setv(char conjalpha, gtint_t n, T* alpha, T* x, gtint_t incx) #else throw std::runtime_error("Error in testsuite/level1/setv.h: No interfaces are set to be tested."); #endif + +#ifdef TEST_INPUT_ARGS + //---------------------------------------------------------- + // Check scalar inputs have not been modified. + //---------------------------------------------------------- + + computediff( "conjalpha", conjalpha, conjalpha_cpy ); + computediff( "n", n, n_cpy ); + if (alpha) computediff( "alpha", *alpha, *alpha_cpy ); + computediff( "incx", incx, incx_cpy ); +#endif } diff --git a/gtestsuite/testsuite/level1/setv/ssetv_generic.cpp b/gtestsuite/testsuite/level1/setv/ssetv_generic.cpp index 2590619ea2..d608834b98 100644 --- a/gtestsuite/testsuite/level1/setv/ssetv_generic.cpp +++ b/gtestsuite/testsuite/level1/setv/ssetv_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,12 +35,12 @@ #include #include "test_setv.h" -class ssetvGenericTest : +class ssetvGeneric : public ::testing::TestWithParam> {}; -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ssetvGenericTest); +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ssetvGeneric); -TEST_P( ssetvGenericTest, RandomData ) +TEST_P( ssetvGeneric, API ) { using T = float; //---------------------------------------------------------- @@ -61,33 +61,16 @@ TEST_P( ssetvGenericTest, RandomData ) test_setv( conjalpha, n, alpha, incx ); } -// Prints the test case combination -class ssetvGenericTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char conj = std::get<0>(str.param); - gtint_t n = std::get<1>(str.param); - gtint_t incx = std::get<2>(str.param); - std::string str_name = "bli_ssetv"; - str_name += "_" + std::to_string(n); - str_name += "_" + std::string(&conj, 1); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name += "_" + incx_str; - return str_name; - } -}; - #ifdef TEST_BLIS_TYPED // Black box testing. INSTANTIATE_TEST_SUITE_P( Blackbox, - ssetvGenericTest, + ssetvGeneric, ::testing::Combine( ::testing::Values('n'), // n: not transpose for x ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. ::testing::Values(gtint_t(1)) // stride size for x ), - ::ssetvGenericTestPrint() + ::setvGenericPrint() ); #endif diff --git a/gtestsuite/testsuite/level1/setv/test_setv.h b/gtestsuite/testsuite/level1/setv/test_setv.h index da98788ecc..cb1eacab3f 100644 --- a/gtestsuite/testsuite/level1/setv/test_setv.h +++ b/gtestsuite/testsuite/level1/setv/test_setv.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -73,3 +73,21 @@ void test_setv( char conjalpha, gtint_t n, T alpha, gtint_t incx ) EXPECT_EQ(x[i], alpha_ref) << "blis_sol[" << i << "]="<< x[i] <<" ref = " << alpha_ref; } } + + +// Test-case logger : Used to print the test-case details based on parameters +class setvGenericPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char conjalpha = std::get<0>(str.param); + gtint_t n = std::get<1>(str.param); + gtint_t incx = std::get<2>(str.param); + + std::string str_name = API_PRINT; + str_name += "_n_" + std::to_string(n); + str_name += "_conjalpha_" + std::string(&conjalpha, 1); + str_name += "_incx_" + testinghelpers::get_value_string(incx); + return str_name; + } +}; diff --git a/gtestsuite/testsuite/level1/setv/zsetv_generic.cpp b/gtestsuite/testsuite/level1/setv/zsetv_generic.cpp index d12271612f..b911e40ab9 100644 --- a/gtestsuite/testsuite/level1/setv/zsetv_generic.cpp +++ b/gtestsuite/testsuite/level1/setv/zsetv_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,12 +35,12 @@ #include #include "test_setv.h" -class zsetvGenericTest : +class zsetvGeneric : public ::testing::TestWithParam> {}; -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(zsetvGenericTest); +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(zsetvGeneric); -TEST_P( zsetvGenericTest, RandomData ) +TEST_P( zsetvGeneric, API ) { using T = dcomplex; //---------------------------------------------------------- @@ -61,33 +61,16 @@ TEST_P( zsetvGenericTest, RandomData ) test_setv( conjalpha, n, alpha, incx ); } -// Prints the test case combination -class zsetvGenericTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char conj = std::get<0>(str.param); - gtint_t n = std::get<1>(str.param); - gtint_t incx = std::get<2>(str.param); - std::string str_name = "bli_zsetv"; - str_name += "_" + std::to_string(n); - str_name += "_" + std::string(&conj, 1); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name += "_" + incx_str; - return str_name; - } -}; - #ifdef TEST_BLIS_TYPED // Black box testing. INSTANTIATE_TEST_SUITE_P( Blackbox, - zsetvGenericTest, + zsetvGeneric, ::testing::Combine( ::testing::Values('n','c'), // n: not transpose for x, c: conjugate for x ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. ::testing::Values(gtint_t(1)) // stride size for x ), - ::zsetvGenericTestPrint() + ::setvGenericPrint() ); #endif diff --git a/gtestsuite/testsuite/level1/subv/csubv_evt.cpp b/gtestsuite/testsuite/level1/subv/csubv_evt.cpp new file mode 100644 index 0000000000..9b36a380db --- /dev/null +++ b/gtestsuite/testsuite/level1/subv/csubv_evt.cpp @@ -0,0 +1,247 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_subv.h" + +class csubvEVT : + public ::testing::TestWithParam> {}; // yexval + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(csubvEVT); + +TEST_P( csubvEVT, API ) +{ + using T = scomplex; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes whether x or conj(x) will be added to y: + char conj_x = std::get<0>(GetParam()); + // vector length: + gtint_t n = std::get<1>(GetParam()); + // stride size for x: + gtint_t incx = std::get<2>(GetParam()); + // stride size for y: + gtint_t incy = std::get<3>(GetParam()); + // index for exval in x + gtint_t xi = std::get<4>(GetParam()); + // exval for x + T xexval = std::get<5>(GetParam()); + // index for exval in y + gtint_t yj = std::get<6>(GetParam()); + // exval for y + T yexval = std::get<7>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite subv.h (no netlib version) for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (n == 0) + thresh = 0.0; + else + thresh = testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_subv( conj_x, n, incx, incy, xi, xexval, + yj, yexval, thresh ); +} + +#ifdef TEST_BLIS_TYPED + +static float NaN = std::numeric_limits::quiet_NaN(); +static float Inf = std::numeric_limits::infinity(); + +// Exception value testing(on X vector alone) with unit strides +INSTANTIATE_TEST_SUITE_P( + vecX_unitStrides, + csubvEVT, + ::testing::Combine( + // n: use x, c: use conj(x) + ::testing::Values('n','c'), + // n: size of vector. + // as we don't have BLIS vectorized kernels for subv, + // having fewer sizes or maybe a Range would be sufficient + // to ensure code coverage of the reference kernel. + ::testing::Values( + gtint_t(100)), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1)), + // incy: stride of y vector. + ::testing::Values( + gtint_t(1)), + // indices to set exception values on x + ::testing::Values(gtint_t(0), gtint_t(2), gtint_t(7), + gtint_t(19), gtint_t(27), gtint_t(38), + gtint_t(69), gtint_t(99)), + // exception values to set on x + ::testing::Values(scomplex{NaN, 0.0}, scomplex{-Inf, 0.0}, + scomplex{0.0, Inf}, scomplex{-2.3, NaN}, + scomplex{4.5, -Inf}, scomplex{NaN, Inf}, + scomplex{NaN, -Inf}), + // index on y + ::testing::Values(gtint_t(0)), + // value on y + ::testing::Values(scomplex{0.0, 0.0}) + ), + ::subvEVTPrint() + ); + +// Exception value testing(on Y vector alone) with unit strides +INSTANTIATE_TEST_SUITE_P( + vecY_unitStrides, + csubvEVT, + ::testing::Combine( + // n: use x, c: use conj(x) + ::testing::Values('n','c'), + // n: size of vector. + // as we don't have BLIS vectorized kernels for subv, + // having fewer sizes or maybe a Range would be sufficient + // to ensure code coverage of the reference kernel. + ::testing::Values( + gtint_t(100)), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1)), + // incy: stride of y vector. + ::testing::Values( + gtint_t(1)), + // index on x + ::testing::Values(gtint_t(0)), + // value on x + ::testing::Values(scomplex{0.0, 0.0}), + // indices to set exception values on y + ::testing::Values(gtint_t(0), gtint_t(2), gtint_t(7), + gtint_t(19), gtint_t(27), gtint_t(38), + gtint_t(69), gtint_t(99)), + // exception values to set on y + ::testing::Values(scomplex{NaN, 0.0}, scomplex{-Inf, 0.0}, + scomplex{0.0, Inf}, scomplex{-2.3, NaN}, + scomplex{4.5, -Inf}, scomplex{NaN, Inf}, + scomplex{NaN, -Inf}) + ), + ::subvEVTPrint() + ); + +// Exception value testing(on X and Y vectors) with unit strides +INSTANTIATE_TEST_SUITE_P( + vecXY_unitStrides, + csubvEVT, + ::testing::Combine( + // n: use x, c: use conj(x) + ::testing::Values('n','c'), + // n: size of vector. + // as we don't have BLIS vectorized kernels for subv, + // having fewer sizes or maybe a Range would be sufficient + // to ensure code coverage of the reference kernel. + ::testing::Values( + gtint_t(100)), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1)), + // incy: stride of y vector. + ::testing::Values( + gtint_t(1)), + // indices to set exception values on x + ::testing::Values(gtint_t(0), gtint_t(2), gtint_t(7), + gtint_t(19), gtint_t(27), gtint_t(38), + gtint_t(69), gtint_t(99)), + // exception values to set on x + ::testing::Values(scomplex{NaN, 0.0}, scomplex{-Inf, 0.0}, + scomplex{0.0, Inf}, scomplex{-2.3, NaN}, + scomplex{4.5, -Inf}, scomplex{NaN, Inf}, + scomplex{NaN, -Inf}), + // indices to set exception values on y + ::testing::Values(gtint_t(0), gtint_t(2), gtint_t(7), + gtint_t(19), gtint_t(27), gtint_t(38), + gtint_t(69), gtint_t(99)), + // exception values to set on y + ::testing::Values(scomplex{NaN, 0.0}, scomplex{-Inf, 0.0}, + scomplex{0.0, Inf}, scomplex{-2.3, NaN}, + scomplex{4.5, -Inf}, scomplex{NaN, Inf}, + scomplex{NaN, -Inf}) + ), + ::subvEVTPrint() + ); + +// Exception value testing(on X & Y vectors) with non-unit strides. +// The indices are such that we cover _vecX_, _vecY_ and _vecXY_ cases together. +INSTANTIATE_TEST_SUITE_P( + vecXY_nonUnitStrides, + csubvEVT, + ::testing::Combine( + // n: use x, c: use conj(x) + ::testing::Values('n','c'), + // n: size of vector. + // as we don't have BLIS vectorized kernels for subv, + // having fewer sizes or maybe a Range would be sufficient + // to ensure code coverage of the reference kernel. + ::testing::Values( + gtint_t(50)), + // incx: stride of x vector. + ::testing::Values( + gtint_t(3)), + // incy: stride of y vector. + ::testing::Values( + gtint_t(5)), + // indices to set exception values on x + ::testing::Values(gtint_t(1), gtint_t(27), gtint_t(49)), + // exception values to set on x + ::testing::Values(scomplex{NaN, 0.0}, scomplex{-Inf, 0.0}, + scomplex{0.0, Inf}, scomplex{-2.3, NaN}, + scomplex{4.5, -Inf}, scomplex{NaN, Inf}, + scomplex{0.0, 0.0}, scomplex{NaN, -Inf}), + // indices to set exception values on y + ::testing::Values(gtint_t(0), gtint_t(26), gtint_t(49)), + // exception values to set on y + ::testing::Values(scomplex{NaN, 0.0}, scomplex{-Inf, 0.0}, + scomplex{0.0, Inf}, scomplex{-2.3, NaN}, + scomplex{4.5, -Inf}, scomplex{NaN, Inf}, + scomplex{0.0, 0.0}, scomplex{NaN, -Inf}) + ), + ::subvEVTPrint() + ); +#endif diff --git a/gtestsuite/testsuite/level1/subv/csubv_generic.cpp b/gtestsuite/testsuite/level1/subv/csubv_generic.cpp index 70797d5e5a..42911ba167 100644 --- a/gtestsuite/testsuite/level1/subv/csubv_generic.cpp +++ b/gtestsuite/testsuite/level1/subv/csubv_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,12 +35,13 @@ #include #include "test_subv.h" -class csubvGenericTest : +class csubvGeneric : + // input params: x or conj(x), vector length, stride size of x, stride size of y public ::testing::TestWithParam> {}; -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(csubvGenericTest); +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(csubvGeneric); -TEST_P( csubvGenericTest, RandomData ) +TEST_P( csubvGeneric, API ) { using T = scomplex; //---------------------------------------------------------- @@ -57,7 +58,15 @@ TEST_P( csubvGenericTest, RandomData ) gtint_t incy = std::get<3>(GetParam()); // Set the threshold for the errors: - double thresh = testinghelpers::getEpsilon(); + // Check gtestsuite subv.h (no netlib version) for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (n == 0) + thresh = 0.0; + else + thresh = testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call generic test body using those parameters @@ -65,37 +74,39 @@ TEST_P( csubvGenericTest, RandomData ) test_subv( conj_x, n, incx, incy, thresh ); } -// Prints the test case combination -class csubvGenericTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char conj = std::get<0>(str.param); - gtint_t n = std::get<1>(str.param); - gtint_t incx = std::get<2>(str.param); - gtint_t incy = std::get<3>(str.param); - std::string str_name = "bli_csubv"; - str_name += "_" + std::to_string(n); - str_name += "_" + std::string(&conj, 1); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name += "_" + incx_str; - std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); - str_name += "_" + incy_str; - return str_name; - } -}; - #ifdef TEST_BLIS_TYPED -// Black box testing. INSTANTIATE_TEST_SUITE_P( - Blackbox, - csubvGenericTest, + PositiveIncrements, + csubvGeneric, ::testing::Combine( - ::testing::Values('n','c'), // n: not transpose for x, c: conjugate for x - ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. - ::testing::Values(gtint_t(1), gtint_t(4)), // stride size for x - ::testing::Values(gtint_t(1), gtint_t(7)) // stride size for y + // n: use x, c: use conj(x) + ::testing::Values('n','c'), + // n: size of vector. + // as we don't have BLIS vectorized kernels for subv, + // having fewer sizes or maybe a Range would be sufficient + // to ensure code coverage of the reference kernel. + ::testing::Values( + gtint_t( 1), + gtint_t( 2), + gtint_t( 3), + gtint_t( 5), + gtint_t( 7), + gtint_t( 9), + gtint_t(10), + gtint_t(15), + gtint_t(20), + gtint_t(55), + gtint_t(99) + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1),gtint_t(5) + ), + // incy: stride of y vector. + ::testing::Values( + gtint_t(1),gtint_t(5) + ) ), - ::csubvGenericTestPrint() + ::subvGenericPrint() ); #endif diff --git a/gtestsuite/testsuite/level1/subv/dsubv_evt.cpp b/gtestsuite/testsuite/level1/subv/dsubv_evt.cpp new file mode 100644 index 0000000000..7e6fd05089 --- /dev/null +++ b/gtestsuite/testsuite/level1/subv/dsubv_evt.cpp @@ -0,0 +1,229 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_subv.h" + +class dsubvEVT : + public ::testing::TestWithParam> {}; // yexval + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(dsubvEVT); + +TEST_P( dsubvEVT, API ) +{ + using T = double; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes whether x or conj(x) will be added to y: + char conj_x = std::get<0>(GetParam()); + // vector length: + gtint_t n = std::get<1>(GetParam()); + // stride size for x: + gtint_t incx = std::get<2>(GetParam()); + // stride size for y: + gtint_t incy = std::get<3>(GetParam()); + // index for exval in x + gtint_t xi = std::get<4>(GetParam()); + // exval for x + T xexval = std::get<5>(GetParam()); + // index for exval in y + gtint_t yj = std::get<6>(GetParam()); + // exval for y + T yexval = std::get<7>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite subv.h (no netlib version) for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (n == 0) + thresh = 0.0; + else + thresh = testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_subv( conj_x, n, incx, incy, xi, xexval, + yj, yexval, thresh ); +} + +#ifdef TEST_BLIS_TYPED + +static double NaN = std::numeric_limits::quiet_NaN(); +static double Inf = std::numeric_limits::infinity(); + +// Exception value testing(on X vector alone) with unit strides on zen3 +INSTANTIATE_TEST_SUITE_P( + vecX_unitStrides, + dsubvEVT, + ::testing::Combine( + // n: use x, c: use conj(x) + ::testing::Values('n','c'), + // n: size of vector. + // as we don't have BLIS vectorized kernels for subv, + // having fewer sizes or maybe a Range would be sufficient + // to ensure code coverage of the reference kernel. + ::testing::Values( + gtint_t(100)), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1)), + // incy: stride of y vector. + ::testing::Values( + gtint_t(1)), + // indices to set exception values on x + ::testing::Values(gtint_t(0), gtint_t(2), gtint_t(7), + gtint_t(19), gtint_t(27), gtint_t(38), + gtint_t(69), gtint_t(99)), + // exception values to set on x + ::testing::Values(NaN, -Inf, Inf), + // index on y + ::testing::Values(gtint_t(0)), + // value on y + ::testing::Values(double(0.0)) + ), + ::subvEVTPrint() + ); + +// Exception value testing(on Y vector alone) with unit strides on zen3 +INSTANTIATE_TEST_SUITE_P( + vecY_unitStrides, + dsubvEVT, + ::testing::Combine( + // n: use x, c: use conj(x) + ::testing::Values('n','c'), + // n: size of vector. + // as we don't have BLIS vectorized kernels for subv, + // having fewer sizes or maybe a Range would be sufficient + // to ensure code coverage of the reference kernel. + ::testing::Values( + gtint_t(100)), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1)), + // incy: stride of y vector. + ::testing::Values( + gtint_t(1)), + // index on x + ::testing::Values(gtint_t(0)), + // value on x + ::testing::Values(double(0.0)), + // indices to set exception values on y + ::testing::Values(gtint_t(0), gtint_t(2), gtint_t(7), + gtint_t(19), gtint_t(27), gtint_t(38), + gtint_t(69), gtint_t(99)), + // exception values to set on y + ::testing::Values(NaN, -Inf, Inf) + ), + ::subvEVTPrint() + ); + +// Exception value testing(on X and Y vectors) with unit strides on zen3 +INSTANTIATE_TEST_SUITE_P( + vecXY_unitStrides, + dsubvEVT, + ::testing::Combine( + // n: use x, c: use conj(x) + ::testing::Values('n','c'), + // n: size of vector. + // as we don't have BLIS vectorized kernels for subv, + // having fewer sizes or maybe a Range would be sufficient + // to ensure code coverage of the reference kernel. + ::testing::Values( + gtint_t(100)), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1)), + // incy: stride of y vector. + ::testing::Values( + gtint_t(1)), + // indices to set exception values on x + ::testing::Values(gtint_t(0), gtint_t(2), gtint_t(7), + gtint_t(19), gtint_t(27), gtint_t(38), + gtint_t(69), gtint_t(99)), + // exception values to set on x + ::testing::Values(NaN, -Inf, Inf), + // indices to set exception values on y + ::testing::Values(gtint_t(0), gtint_t(2), gtint_t(7), + gtint_t(19), gtint_t(27), gtint_t(38), + gtint_t(69), gtint_t(99)), + // exception values to set on y + ::testing::Values(NaN, -Inf, Inf) + ), + ::subvEVTPrint() + ); + +// Exception value testing(on X & Y vectors) with non-unit strides. +// The indices are such that we cover _vecX_, _vecY_ and _vecXY_ cases together. +INSTANTIATE_TEST_SUITE_P( + vecXY_nonUnitStrides, + dsubvEVT, + ::testing::Combine( + // n: use x, c: use conj(x) + ::testing::Values('n','c'), + // n: size of vector. + // as we don't have BLIS vectorized kernels for subv, + // having fewer sizes or maybe a Range would be sufficient + // to ensure code coverage of the reference kernel. + ::testing::Values( + gtint_t(50)), + // incx: stride of x vector. + ::testing::Values( + gtint_t(3)), + // incy: stride of y vector. + ::testing::Values( + gtint_t(5)), + // indices to set exception values on x + ::testing::Values(gtint_t(1), gtint_t(27), gtint_t(49)), + // exception values to set on x + ::testing::Values(NaN, -Inf, Inf, 0.0), + // indices to set exception values on y + ::testing::Values(gtint_t(0), gtint_t(26), gtint_t(49)), + // exception values to set on y + ::testing::Values(NaN, -Inf, Inf, 0.0) + ), + ::subvEVTPrint() + ); +#endif diff --git a/gtestsuite/testsuite/level1/subv/dsubv_generic.cpp b/gtestsuite/testsuite/level1/subv/dsubv_generic.cpp index 63a63a9274..3fbac80e1d 100644 --- a/gtestsuite/testsuite/level1/subv/dsubv_generic.cpp +++ b/gtestsuite/testsuite/level1/subv/dsubv_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,12 +35,13 @@ #include #include "test_subv.h" -class dsubvGenericTest : +class dsubvGeneric : + // input params : x or conj(x), vector length, stride size of x, stride size of y public ::testing::TestWithParam> {}; -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(dsubvGenericTest); +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(dsubvGeneric); -TEST_P( dsubvGenericTest, RandomData ) +TEST_P( dsubvGeneric, API ) { using T = double; //---------------------------------------------------------- @@ -57,7 +58,14 @@ TEST_P( dsubvGenericTest, RandomData ) gtint_t incy = std::get<3>(GetParam()); // Set the threshold for the errors: - double thresh = testinghelpers::getEpsilon(); + // Check gtestsuite subv.h (no netlib version) for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (n == 0) + thresh = 0.0; + else + thresh = testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call generic test body using those parameters @@ -65,37 +73,65 @@ TEST_P( dsubvGenericTest, RandomData ) test_subv( conj_x, n, incx, incy, thresh ); } -// Prints the test case combination -class dsubvGenericTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char conj = std::get<0>(str.param); - gtint_t n = std::get<1>(str.param); - gtint_t incx = std::get<2>(str.param); - gtint_t incy = std::get<3>(str.param); - std::string str_name = "bli_dsubv"; - str_name += "_" + std::to_string(n); - str_name += "_" + std::string(&conj, 1); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name += "_" + incx_str; - std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); - str_name += "_" + incy_str; - return str_name; - } -}; +#ifdef TEST_BLIS_TYPED +INSTANTIATE_TEST_SUITE_P( + PositiveIncrements, + dsubvGeneric, + ::testing::Combine( + // n: use x, c: use conj(x) + ::testing::Values('n'), + // n: size of vector. + // as we don't have BLIS vectorized kernels for subv, + // having fewer sizes or maybe a Range would be sufficient + // to ensure code coverage of the reference kernel. + ::testing::Values( + gtint_t( 1), + gtint_t( 2), + gtint_t( 3), + gtint_t( 5), + gtint_t( 7), + gtint_t( 9), + gtint_t(10), + gtint_t(15), + gtint_t(20), + gtint_t(55), + gtint_t(99) + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1),gtint_t(5) + ), + // incy: stride of y vector. + ::testing::Values( + gtint_t(1),gtint_t(5) + ) + ), + ::subvGenericPrint() + ); +#endif #ifdef TEST_BLIS_TYPED -// Black box testing. INSTANTIATE_TEST_SUITE_P( - Blackbox, - dsubvGenericTest, + PositiveIncrementforConjugate, + dsubvGeneric, ::testing::Combine( - ::testing::Values('n'), // n: not transpose for x - ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. - ::testing::Values(gtint_t(1), gtint_t(4)), // stride size for x - ::testing::Values(gtint_t(1), gtint_t(7)) // stride size for y + // c: conjugate for x + ::testing::Values('c'), + // n: size of vector. + // as conjugate of a real number x is x, + // so adding a single test that uses 'c' as an option for sanity check. + ::testing::Values( + gtint_t( 1),gtint_t( 7) + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1),gtint_t(5) + ), + // incy: stride of y vector. + ::testing::Values( + gtint_t(1),gtint_t(5) + ) ), - ::dsubvGenericTestPrint() + ::subvGenericPrint() ); #endif diff --git a/gtestsuite/testsuite/level1/subv/ssubv_evt.cpp b/gtestsuite/testsuite/level1/subv/ssubv_evt.cpp new file mode 100644 index 0000000000..2c446cfd03 --- /dev/null +++ b/gtestsuite/testsuite/level1/subv/ssubv_evt.cpp @@ -0,0 +1,229 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_subv.h" + +class ssubvEVT : + public ::testing::TestWithParam> {}; // yexval + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ssubvEVT); + +TEST_P( ssubvEVT, API ) +{ + using T = float; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes whether x or conj(x) will be added to y: + char conj_x = std::get<0>(GetParam()); + // vector length: + gtint_t n = std::get<1>(GetParam()); + // stride size for x: + gtint_t incx = std::get<2>(GetParam()); + // stride size for y: + gtint_t incy = std::get<3>(GetParam()); + // index for exval in x + gtint_t xi = std::get<4>(GetParam()); + // exval for x + T xexval = std::get<5>(GetParam()); + // index for exval in y + gtint_t yj = std::get<6>(GetParam()); + // exval for y + T yexval = std::get<7>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite subv.h (no netlib version) for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (n == 0) + thresh = 0.0; + else + thresh = testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_subv( conj_x, n, incx, incy, xi, xexval, + yj, yexval, thresh ); +} + +#ifdef TEST_BLIS_TYPED + +static float NaN = std::numeric_limits::quiet_NaN(); +static float Inf = std::numeric_limits::infinity(); + +// Exception value testing(on X vector alone) with unit strides on zen3 +INSTANTIATE_TEST_SUITE_P( + vecX_unitStrides, + ssubvEVT, + ::testing::Combine( + // n: use x, c: use conj(x) + ::testing::Values('n','c'), + // n: size of vector. + // as we don't have BLIS vectorized kernels for subv, + // having fewer sizes or maybe a Range would be sufficient + // to ensure code coverage of the reference kernel. + ::testing::Values( + gtint_t(10)), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1)), + // incy: stride of y vector. + ::testing::Values( + gtint_t(1)), + // indices to set exception values on x + ::testing::Values(gtint_t(0), gtint_t(2), gtint_t(7), + gtint_t(19), gtint_t(27), gtint_t(38), + gtint_t(69), gtint_t(99)), + // exception values to set on x + ::testing::Values(NaN, -Inf, Inf), + // index on y + ::testing::Values(gtint_t(0)), + // value on y + ::testing::Values(float(0.0)) + ), + ::subvEVTPrint() + ); + +// Exception value testing(on Y vector alone) with unit strides on zen3 +INSTANTIATE_TEST_SUITE_P( + vecY_unitStrides, + ssubvEVT, + ::testing::Combine( + // n: use x, c: use conj(x) + ::testing::Values('n','c'), + // n: size of vector. + // as we don't have BLIS vectorized kernels for subv, + // having fewer sizes or maybe a Range would be sufficient + // to ensure code coverage of the reference kernel. + ::testing::Values( + gtint_t(100)), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1)), + // incy: stride of y vector. + ::testing::Values( + gtint_t(1)), + // index on x + ::testing::Values(gtint_t(0)), + // value on x + ::testing::Values(float(0.0)), + // indices to set exception values on y + ::testing::Values(gtint_t(0), gtint_t(2), gtint_t(7), + gtint_t(19), gtint_t(27), gtint_t(38), + gtint_t(69), gtint_t(99)), + // exception values to set on y + ::testing::Values(NaN, -Inf, Inf) + ), + ::subvEVTPrint() + ); + +// Exception value testing(on X and Y vectors) with unit strides on zen3 +INSTANTIATE_TEST_SUITE_P( + vecXY_unitStrides, + ssubvEVT, + ::testing::Combine( + // n: use x, c: use conj(x) + ::testing::Values('n','c'), + // n: size of vector. + // as we don't have BLIS vectorized kernels for subv, + // having fewer sizes or maybe a Range would be sufficient + // to ensure code coverage of the reference kernel. + ::testing::Values( + gtint_t(100)), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1)), + // incy: stride of y vector. + ::testing::Values( + gtint_t(1)), + // indices to set exception values on x + ::testing::Values(gtint_t(0), gtint_t(2), gtint_t(7), + gtint_t(19), gtint_t(27), gtint_t(38), + gtint_t(69), gtint_t(99)), + // exception values to set on x + ::testing::Values(NaN, -Inf, Inf), + // indices to set exception values on y + ::testing::Values(gtint_t(0), gtint_t(2), gtint_t(7), + gtint_t(19), gtint_t(27), gtint_t(38), + gtint_t(69), gtint_t(99)), + // exception values to set on y + ::testing::Values(NaN, -Inf, Inf) + ), + ::subvEVTPrint() + ); + +// Exception value testing(on X & Y vectors) with non-unit stridesi. +// The indices are such that we cover _vecX_, _vecY_ and _vecXY_ cases together. +INSTANTIATE_TEST_SUITE_P( + vecXY_nonUnitStrides, + ssubvEVT, + ::testing::Combine( + // n: use x, c: use conj(x) + ::testing::Values('n','c'), + // n: size of vector. + // as we don't have BLIS vectorized kernels for subv, + // having fewer sizes or maybe a Range would be sufficient + // to ensure code coverage of the reference kernel. + ::testing::Values( + gtint_t(50)), + // incx: stride of x vector. + ::testing::Values( + gtint_t(3)), + // incy: stride of y vector. + ::testing::Values( + gtint_t(5)), + // indices to set exception values on x + ::testing::Values(gtint_t(1), gtint_t(27), gtint_t(49)), + // exception values to set on x + ::testing::Values(NaN, -Inf, Inf, 0.0), + // indices to set exception values on y + ::testing::Values(gtint_t(0), gtint_t(26), gtint_t(49)), + // exception values to set on y + ::testing::Values(NaN, -Inf, Inf, 0.0) + ), + ::subvEVTPrint() + ); +#endif diff --git a/gtestsuite/testsuite/level1/subv/ssubv_generic.cpp b/gtestsuite/testsuite/level1/subv/ssubv_generic.cpp index 50e004cb07..c0ca7a5821 100644 --- a/gtestsuite/testsuite/level1/subv/ssubv_generic.cpp +++ b/gtestsuite/testsuite/level1/subv/ssubv_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,12 +35,13 @@ #include #include "test_subv.h" -class ssubvGenericTest : +class ssubvGeneric : + // input params: x or conj(x), vector length, stride size of x, stride size of y public ::testing::TestWithParam> {}; -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ssubvGenericTest); +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ssubvGeneric); -TEST_P( ssubvGenericTest, RandomData ) +TEST_P( ssubvGeneric, API ) { using T = float; //---------------------------------------------------------- @@ -57,7 +58,14 @@ TEST_P( ssubvGenericTest, RandomData ) gtint_t incy = std::get<3>(GetParam()); // Set the threshold for the errors: - double thresh = testinghelpers::getEpsilon(); + // Check gtestsuite subv.h (no netlib version) for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (n == 0) + thresh = 0.0; + else + thresh = testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call generic test body using those parameters @@ -65,37 +73,65 @@ TEST_P( ssubvGenericTest, RandomData ) test_subv( conj_x, n, incx, incy, thresh ); } -// Prints the test case combination -class ssubvGenericTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char conj = std::get<0>(str.param); - gtint_t n = std::get<1>(str.param); - gtint_t incx = std::get<2>(str.param); - gtint_t incy = std::get<3>(str.param); - std::string str_name = "bli_ssubv"; - str_name += "_" + std::to_string(n); - str_name += "_" + std::string(&conj, 1); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name += "_" + incx_str; - std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); - str_name += "_" + incy_str; - return str_name; - } -}; +#ifdef TEST_BLIS_TYPED +INSTANTIATE_TEST_SUITE_P( + PositiveIncrements, + ssubvGeneric, + ::testing::Combine( + // n: use x, c: use conj(x) + ::testing::Values('n'), + // n: size of vector. + // as don't have BLIS vectorized kernels for subv, + // having fewer sizes or maybe a Range would be sufficient + // to ensure code coverage of the reference kernel. + ::testing::Values( + gtint_t( 1), + gtint_t( 2), + gtint_t( 3), + gtint_t( 5), + gtint_t( 7), + gtint_t( 9), + gtint_t(10), + gtint_t(15), + gtint_t(20), + gtint_t(55), + gtint_t(99) + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1),gtint_t(5) + ), + // incy: stride of y vector. + ::testing::Values( + gtint_t(1),gtint_t(5) + ) + ), + ::subvGenericPrint() + ); +#endif #ifdef TEST_BLIS_TYPED -// Black box testing. INSTANTIATE_TEST_SUITE_P( - Blackbox, - ssubvGenericTest, + PositiveIncrementforConjugate, + ssubvGeneric, ::testing::Combine( - ::testing::Values('n'), // n: not transpose for x - ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. - ::testing::Values(gtint_t(1), gtint_t(4)), // stride size for x - ::testing::Values(gtint_t(1), gtint_t(7)) // stride size for y + // c: conjugate for x + ::testing::Values('c'), + // n: size of vector. + // as conjugate of a real number x is x, + // so adding a single test that uses 'c' as an option for sanity check. + ::testing::Values( + gtint_t( 1),gtint_t( 7) + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1),gtint_t(5) + ), + // incy: stride of y vector. + ::testing::Values( + gtint_t(1),gtint_t(5) + ) ), - ::ssubvGenericTestPrint() + ::subvGenericPrint() ); #endif diff --git a/gtestsuite/testsuite/level1/subv/subv.h b/gtestsuite/testsuite/level1/subv/subv.h index ff5059d6ff..ed6631e502 100644 --- a/gtestsuite/testsuite/level1/subv/subv.h +++ b/gtestsuite/testsuite/level1/subv/subv.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -36,6 +36,7 @@ #include "blis.h" #include "common/testing_helpers.h" +#include "inc/check_error.h" /** * @brief Performs the operation @@ -69,8 +70,32 @@ static void typed_subv(char conj_x, gtint_t n, T* x, gtint_t incx, T* y, gtint_t template static void subv(char conjx, gtint_t n, T* x, gtint_t incx, T* y, gtint_t incy) { + +#ifdef TEST_UPPERCASE_ARGS + conjx = static_cast(std::toupper(static_cast(conjx))); +#endif + +#ifdef TEST_INPUT_ARGS + // Create copy of scalar input values so we can check that they are not altered. + char conjx_cpy = conjx; + gtint_t n_cpy = n; + gtint_t incx_cpy = incx; + gtint_t incy_cpy = incy; + + // Create copy of input arrays so we can check that they are not altered. + T* x_cpy = nullptr; + gtint_t size_x = testinghelpers::buff_dim( n, incx ); + if (x && size_x > 0) + { + x_cpy = new T[size_x]; + memcpy( x_cpy, x, size_x * sizeof( T ) ); + } +#endif + #ifdef TEST_BLAS throw std::runtime_error("Error in testsuite/level1/subv.h: BLAS interface is not available."); +#elif TEST_BLAS_BLIS_IMPL + throw std::runtime_error("Error in testsuite/level1/subv.h: BLAS_BLIS_IMPL interface is not available."); #elif TEST_CBLAS throw std::runtime_error("Error in testsuite/level1/subv.h: CBLAS interface is not available."); #elif TEST_BLIS_TYPED @@ -78,4 +103,25 @@ static void subv(char conjx, gtint_t n, T* x, gtint_t incx, T* y, gtint_t incy) #else throw std::runtime_error("Error in testsuite/level1/subv.h: No interfaces are set to be tested."); #endif + +#ifdef TEST_INPUT_ARGS + //---------------------------------------------------------- + // Check scalar inputs have not been modified. + //---------------------------------------------------------- + + computediff( "conjx", conjx, conjx_cpy ); + computediff( "n", n, n_cpy ); + computediff( "incx", incx, incx_cpy ); + computediff( "incy", incy, incy_cpy ); + + //---------------------------------------------------------- + // Bitwise-wise check array inputs have not been modified. + //---------------------------------------------------------- + + if (x && size_x > 0) + { + computediff( "x", n, x, x_cpy, incx, true ); + delete[] x_cpy; + } +#endif } diff --git a/gtestsuite/testsuite/level1/subv/subv_IIT_ERS.cpp b/gtestsuite/testsuite/level1/subv/subv_IIT_ERS.cpp new file mode 100644 index 0000000000..c65aa13255 --- /dev/null +++ b/gtestsuite/testsuite/level1/subv/subv_IIT_ERS.cpp @@ -0,0 +1,144 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_subv.h" +#include "common/wrong_inputs_helpers.h" +#include "common/testing_helpers.h" +#include "inc/check_error.h" + +template +class subv_IIT_ERS : public ::testing::Test {}; +typedef ::testing::Types TypeParam; +TYPED_TEST_SUITE(subv_IIT_ERS, TypeParam); + +using namespace testinghelpers::IIT; + +#if defined(TEST_BLIS_TYPED) + +/* + BLIS Early Return Scenarios(ERS): + + SUBV is expected to return early in the following cases: + 1. n <= 0 +*/ + +// n < 0, with non-unit stride +TYPED_TEST(subv_IIT_ERS, n_lt_zero_nonUnitStride) +{ + using T = TypeParam; + gtint_t invalid_n = -1; + gtint_t inc = 5; + + // Test with all arguments correct except for the value we are choosing to test. + // Defining the X & Y vectors with values for debugging purposes + std::vector x = testinghelpers::get_random_vector( -10, 10, N, inc ); + std::vector y = testinghelpers::get_random_vector( -10, 10, N, inc ); + + // Copy so that we check that the elements of Y are not modified. + std::vector y_ref(y); + + // Call BLIS subv with a invalid value for n==-1 & non-unit stride inc = 5. + subv( 'n', invalid_n, x.data(), inc, y.data(), inc ); + + // Use bitwise comparison (no threshold). + computediff( "y", N, y.data(), y_ref.data(), inc ); +} + +// n < 0, with unit stride +TYPED_TEST(subv_IIT_ERS, n_lt_zero_unitStride) +{ + using T = TypeParam; + gtint_t invalid_n = -1; + gtint_t inc = 1; + + // Test with all arguments correct except for the value we are choosing to test. + // Defining the X & Y vectors with values for debugging purposes + std::vector x = testinghelpers::get_random_vector( -10, 10, N, inc ); + std::vector y = testinghelpers::get_random_vector( -10, 10, N, inc ); + + // Copy so that we check that the elements of Y are not modified. + std::vector y_ref(y); + + // Call BLIS subv with a invalid value for n==-1 & unit stride inc = 1. + subv( 'n', invalid_n, x.data(), inc, y.data(), inc ); + + // Use bitwise comparison (no threshold). + computediff( "y", N, y.data(), y_ref.data(), inc ); +} + +// n == 0, with non-unit stride +TYPED_TEST(subv_IIT_ERS, n_eq_zero_nonUnitStride) +{ + using T = TypeParam; + gtint_t invalid_n = 0; + gtint_t inc = 2; + + // Test with all arguments correct except for the value we are choosing to test. + // Defining the X & Y vectors with values for debugging purposes + std::vector x = testinghelpers::get_random_vector( -10, 10, N, inc ); + std::vector y = testinghelpers::get_random_vector( -10, 10, N, inc ); + + // Copy so that we check that the elements of Y are not modified. + std::vector y_ref(y); + + // Call BLIS subv with a invalid value for n==0 & non-unit stride inc = 2. + subv( 'n', invalid_n, x.data(), inc, y.data(), inc ); + + // Use bitwise comparison (no threshold). + computediff( "y", N, y.data(), y_ref.data(), inc ); +} + +// n == 0, with unit stride +TYPED_TEST(subv_IIT_ERS, n_eq_zero_unitStride) +{ + using T = TypeParam; + gtint_t invalid_n = 0; + gtint_t inc = 1; + + // Test with all arguments correct except for the value we are choosing to test. + // Defining the X & Y vectors with values for debugging purposes + std::vector x = testinghelpers::get_random_vector( -10, 10, N, inc ); + std::vector y = testinghelpers::get_random_vector( -10, 10, N, inc ); + + // Copy so that we check that the elements of Y are not modified. + std::vector y_ref(y); + + // Call BLIS subv with a invalid value for n==0 & unit stride inc = 1. + subv( 'n', invalid_n, x.data(), inc, y.data(), inc ); + + // Use bitwise comparison (no threshold). + computediff( "y", N, y.data(), y_ref.data(), inc ); +} +#endif diff --git a/gtestsuite/testsuite/level1/subv/test_subv.h b/gtestsuite/testsuite/level1/subv/test_subv.h index ffdf86a3db..de94e1bcf1 100644 --- a/gtestsuite/testsuite/level1/subv/test_subv.h +++ b/gtestsuite/testsuite/level1/subv/test_subv.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -66,5 +66,86 @@ void test_subv( char conjx, gtint_t n, gtint_t incx, gtint_t incy, double thresh //---------------------------------------------------------- // Compute component-wise error. //---------------------------------------------------------- - computediff( n, y.data(), y_ref.data(), incy, thresh ); + computediff( "y", n, y.data(), y_ref.data(), incy, thresh ); } + +template +static void test_subv( char conjx, gtint_t n, gtint_t incx, gtint_t incy, + gtint_t xi, T xexval, gtint_t yj, T yexval, + double thresh ) +{ + //---------------------------------------------------------- + // Initialize vectors with random numbers. + //---------------------------------------------------------- + std::vector x = testinghelpers::get_random_vector( -10, 10, n, incx ); + std::vector y = testinghelpers::get_random_vector( -10, 10, n, incy ); + // Update the value at index xi to an extreme value, x_exval. + if ( -1 < xi && xi < n ) x[xi * abs(incx)] = xexval; + else return; + // Update the value at index yi to an extreme value, y_exval. + if ( -1 < yj && yj < n ) y[yj * abs(incy)] = yexval; + else return; + //---------------------------------------------------------- + // Call reference implementation to get ref results. + //---------------------------------------------------------- + // Create a copy of y so that we can check reference results. + std::vector y_ref(y); + testinghelpers::ref_subv( conjx, n, x.data(), incx, y_ref.data(), incy ); + //---------------------------------------------------------- + // Call BLIS function. + //---------------------------------------------------------- + subv( conjx, n, x.data(), incx, y.data(), incy ); + //---------------------------------------------------------- + // Compute component-wise error. + //---------------------------------------------------------- + computediff( "y", n, y.data(), y_ref.data(), incy, thresh, true ); +} + + +// Test-case logger : Used to print the test-case details based on parameters +class subvGenericPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char conjx = std::get<0>(str.param); + gtint_t n = std::get<1>(str.param); + gtint_t incx = std::get<2>(str.param); + gtint_t incy = std::get<3>(str.param); + + std::string str_name = API_PRINT; + str_name += "_n_" + std::to_string(n); + str_name += "_conjx_" + std::string(&conjx, 1); + str_name += "_incx_" + testinghelpers::get_value_string(incx); + str_name += "_incy_" + testinghelpers::get_value_string(incy); + return str_name; + } +}; + +template +class subvEVTPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char conjx = std::get<0>(str.param); + gtint_t n = std::get<1>(str.param); + gtint_t incx = std::get<2>(str.param); + gtint_t incy = std::get<3>(str.param); + gtint_t xi = std::get<4>(str.param); + T xexval = std::get<5>(str.param); + gtint_t yj = std::get<6>(str.param); + T yexval = std::get<7>(str.param); + + std::string str_name = API_PRINT; + str_name += "_n_" + std::to_string(n); + str_name += "_conjx_" + std::string(&conjx, 1); + str_name += "_incx_" + testinghelpers::get_value_string(incx); + str_name += "_incy_" + testinghelpers::get_value_string(incy); + std::string xexval_str = testinghelpers::get_value_string(xexval); + std::string yexval_str = testinghelpers::get_value_string(yexval); + str_name = str_name + "_X_" + std::to_string(xi); + str_name = str_name + "_" + xexval_str; + str_name = str_name + "_Y_" + std::to_string(yj); + str_name = str_name + "_" + yexval_str; + return str_name; + } +}; diff --git a/gtestsuite/testsuite/level1/subv/zsubv_evt.cpp b/gtestsuite/testsuite/level1/subv/zsubv_evt.cpp new file mode 100644 index 0000000000..6dc395cdb1 --- /dev/null +++ b/gtestsuite/testsuite/level1/subv/zsubv_evt.cpp @@ -0,0 +1,247 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_subv.h" + +class zsubvEVT : + public ::testing::TestWithParam> {}; // yexval + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(zsubvEVT); + +TEST_P( zsubvEVT, API ) +{ + using T = dcomplex; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes whether x or conj(x) will be added to y: + char conj_x = std::get<0>(GetParam()); + // vector length: + gtint_t n = std::get<1>(GetParam()); + // stride size for x: + gtint_t incx = std::get<2>(GetParam()); + // stride size for y: + gtint_t incy = std::get<3>(GetParam()); + // index for exval in x + gtint_t xi = std::get<4>(GetParam()); + // exval for x + T xexval = std::get<5>(GetParam()); + // index for exval in y + gtint_t yj = std::get<6>(GetParam()); + // exval for y + T yexval = std::get<7>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite subv.h (no netlib version) for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (n == 0) + thresh = 0.0; + else + thresh = testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_subv( conj_x, n, incx, incy, xi, xexval, + yj, yexval, thresh ); +} + +#ifdef TEST_BLIS_TYPED + +static double NaN = std::numeric_limits::quiet_NaN(); +static double Inf = std::numeric_limits::infinity(); + +// Exception value testing(on X vector alone) with unit strides on zen3 +INSTANTIATE_TEST_SUITE_P( + vecX_unitStrides, + zsubvEVT, + ::testing::Combine( + // n: use x, c: use conj(x) + ::testing::Values('n','c'), + // n: size of vector. + // as we don't have BLIS vectorized kernels for subv, + // having fewer sizes or maybe a Range would be sufficient + // to ensure code coverage of the reference kernel. + ::testing::Values( + gtint_t(100)), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1)), + // incy: stride of y vector. + ::testing::Values( + gtint_t(1)), + // indices to set exception values on x + ::testing::Values(gtint_t(0), gtint_t(2), gtint_t(7), + gtint_t(19), gtint_t(27), gtint_t(38), + gtint_t(69), gtint_t(99)), + // exception values to set on x + ::testing::Values(dcomplex{NaN, 0.0}, dcomplex{-Inf, 0.0}, + dcomplex{0.0, Inf}, dcomplex{-2.3, NaN}, + dcomplex{4.5, -Inf}, dcomplex{NaN, Inf}, + dcomplex{NaN, -Inf}), + // index on y + ::testing::Values(gtint_t(0)), + // value on y + ::testing::Values(dcomplex{0.0, 0.0}) + ), + ::subvEVTPrint() + ); + +// Exception value testing(on Y vector alone) with unit strides on zen3 +INSTANTIATE_TEST_SUITE_P( + vecY_unitStrides, + zsubvEVT, + ::testing::Combine( + // n: use x, c: use conj(x) + ::testing::Values('n','c'), + // n: size of vector. + // as we don't have BLIS vectorized kernels for subv, + // having fewer sizes or maybe a Range would be sufficient + // to ensure code coverage of the reference kernel. + ::testing::Values( + gtint_t(100)), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1)), + // incy: stride of y vector. + ::testing::Values( + gtint_t(1)), + // index on x + ::testing::Values(gtint_t(0)), + // value on x + ::testing::Values(dcomplex{0.0, 0.0}), + // indices to set exception values on y + ::testing::Values(gtint_t(0), gtint_t(2), gtint_t(7), + gtint_t(19), gtint_t(27), gtint_t(38), + gtint_t(69), gtint_t(99)), + // exception values to set on y + ::testing::Values(dcomplex{NaN, 0.0}, dcomplex{-Inf, 0.0}, + dcomplex{0.0, Inf}, dcomplex{-2.3, NaN}, + dcomplex{4.5, -Inf}, dcomplex{NaN, Inf}, + dcomplex{NaN, -Inf}) + ), + ::subvEVTPrint() + ); + +// Exception value testing(on X and Y vectors) with unit strides on zen3 +INSTANTIATE_TEST_SUITE_P( + vecXY_unitStrides, + zsubvEVT, + ::testing::Combine( + // n: use x, c: use conj(x) + ::testing::Values('n','c'), + // n: size of vector. + // as we don't have BLIS vectorized kernels for subv, + // having fewer sizes or maybe a Range would be sufficient + // to ensure code coverage of the reference kernel. + ::testing::Values( + gtint_t(100)), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1)), + // incy: stride of y vector. + ::testing::Values( + gtint_t(1)), + // indices to set exception values on x + ::testing::Values(gtint_t(0), gtint_t(2), gtint_t(7), + gtint_t(19), gtint_t(27), gtint_t(38), + gtint_t(69), gtint_t(99)), + // exception values to set on x + ::testing::Values(dcomplex{NaN, 0.0}, dcomplex{-Inf, 0.0}, + dcomplex{0.0, Inf}, dcomplex{-2.3, NaN}, + dcomplex{4.5, -Inf}, dcomplex{NaN, Inf}, + dcomplex{NaN, -Inf}), + // indices to set exception values on y + ::testing::Values(gtint_t(0), gtint_t(2), gtint_t(7), + gtint_t(19), gtint_t(27), gtint_t(38), + gtint_t(69), gtint_t(99)), + // exception values to set on y + ::testing::Values(dcomplex{NaN, 0.0}, dcomplex{-Inf, 0.0}, + dcomplex{0.0, Inf}, dcomplex{-2.3, NaN}, + dcomplex{4.5, -Inf}, dcomplex{NaN, Inf}, + dcomplex{NaN, -Inf}) + ), + ::subvEVTPrint() + ); + +// Exception value testing(on X & Y vectors) with non-unit strides. +// The indices are such that we cover _vecX_, _vecY_ and _vecXY_ cases together. +INSTANTIATE_TEST_SUITE_P( + vecXY_nonUnitStrides, + zsubvEVT, + ::testing::Combine( + // n: use x, c: use conj(x) + ::testing::Values('n','c'), + // n: size of vector. + // as we don't have BLIS vectorized kernels for subv, + // having fewer sizes or maybe a Range would be sufficient + // to ensure code coverage of the reference kernel. + ::testing::Values( + gtint_t(50)), + // incx: stride of x vector. + ::testing::Values( + gtint_t(3)), + // incy: stride of y vector. + ::testing::Values( + gtint_t(5)), + // indices to set exception values on x + ::testing::Values(gtint_t(1), gtint_t(27), gtint_t(49)), + // exception values to set on x + ::testing::Values(dcomplex{NaN, 0.0}, dcomplex{-Inf, 0.0}, + dcomplex{0.0, Inf}, dcomplex{-2.3, NaN}, + dcomplex{4.5, -Inf}, dcomplex{NaN, Inf}, + dcomplex{0.0, 0.0}, dcomplex{NaN, -Inf}), + // indices to set exception values on y + ::testing::Values(gtint_t(0), gtint_t(26), gtint_t(49)), + // exception values to set on y + ::testing::Values(dcomplex{NaN, 0.0}, dcomplex{-Inf, 0.0}, + dcomplex{0.0, Inf}, dcomplex{-2.3, NaN}, + dcomplex{4.5, -Inf}, dcomplex{NaN, Inf}, + dcomplex{0.0, 0.0}, dcomplex{NaN, -Inf}) + ), + ::subvEVTPrint() + ); +#endif diff --git a/gtestsuite/testsuite/level1/subv/zsubv_generic.cpp b/gtestsuite/testsuite/level1/subv/zsubv_generic.cpp index f4e634f4c5..91b6cb8113 100644 --- a/gtestsuite/testsuite/level1/subv/zsubv_generic.cpp +++ b/gtestsuite/testsuite/level1/subv/zsubv_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,12 +35,13 @@ #include #include "test_subv.h" -class zsubvGenericTest : +class zsubvGeneric : + // input params: x or conj(x), vector length, stride size of x, stride size of y public ::testing::TestWithParam> {}; -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(zsubvGenericTest); +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(zsubvGeneric); -TEST_P( zsubvGenericTest, RandomData ) +TEST_P( zsubvGeneric, API ) { using T = dcomplex; //---------------------------------------------------------- @@ -57,7 +58,15 @@ TEST_P( zsubvGenericTest, RandomData ) gtint_t incy = std::get<3>(GetParam()); // Set the threshold for the errors: - double thresh = testinghelpers::getEpsilon(); + // Check gtestsuite subv.h (no netlib version) for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (n == 0) + thresh = 0.0; + else + thresh = testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call generic test body using those parameters @@ -65,37 +74,39 @@ TEST_P( zsubvGenericTest, RandomData ) test_subv( conj_x, n, incx, incy, thresh ); } -// Prints the test case combination -class zsubvGenericTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char conj = std::get<0>(str.param); - gtint_t n = std::get<1>(str.param); - gtint_t incx = std::get<2>(str.param); - gtint_t incy = std::get<3>(str.param); - std::string str_name = "bli_zsubv"; - str_name += "_" + std::to_string(n); - str_name += "_" + std::string(&conj, 1); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name += "_" + incx_str; - std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); - str_name += "_" + incy_str; - return str_name; - } -}; - #ifdef TEST_BLIS_TYPED -// Black box testing. INSTANTIATE_TEST_SUITE_P( - Blackbox, - zsubvGenericTest, + PositiveIncrements, + zsubvGeneric, ::testing::Combine( - ::testing::Values('n','c'), // n: not transpose for x, c: conjugate for x - ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. - ::testing::Values(gtint_t(1), gtint_t(4)), // stride size for x - ::testing::Values(gtint_t(1), gtint_t(7)) // stride size for y + // n: use x, c: use conj(x) + ::testing::Values('n','c'), + // n: size of vector. + // as don't have BLIS vectorized kernels for subv, + // having fewer sizes or maybe a Range would be sufficient + // to ensure code coverage of the reference kernel. + ::testing::Values( + gtint_t( 1), + gtint_t( 2), + gtint_t( 3), + gtint_t( 5), + gtint_t( 7), + gtint_t( 9), + gtint_t(10), + gtint_t(15), + gtint_t(20), + gtint_t(55), + gtint_t(99) + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1),gtint_t(5) + ), + // incy: stride of y vector. + ::testing::Values( + gtint_t(1),gtint_t(5) + ) ), - ::zsubvGenericTestPrint() + ::subvGenericPrint() ); #endif diff --git a/gtestsuite/testsuite/level1/swapv/cswapv_generic.cpp b/gtestsuite/testsuite/level1/swapv/cswapv_generic.cpp new file mode 100644 index 0000000000..c046486691 --- /dev/null +++ b/gtestsuite/testsuite/level1/swapv/cswapv_generic.cpp @@ -0,0 +1,105 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_swapv.h" + +class cswapvGeneric : + // input params : vector length, stride size of x, stride size of y + public ::testing::TestWithParam> {}; + +TEST_P( cswapvGeneric, API ) +{ + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // vector length: + gtint_t n = std::get<0>(GetParam()); + // stride size for x: + gtint_t incx = std::get<1>(GetParam()); + // stride size for y: + gtint_t incy = std::get<2>(GetParam()); + + using T = scomplex; + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_swapv( n, incx, incy ); +} + +INSTANTIATE_TEST_SUITE_P( + UnitIncrements, + cswapvGeneric, + ::testing::Combine( + // n: size of vector. + ::testing::Values( + gtint_t(1), + gtint_t(50), + gtint_t(100) + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1) + ), + // incy: stride of y vector. + ::testing::Values( + gtint_t(1) + ) + ), + ::swapvGenericPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + NonUnitIncrements, + cswapvGeneric, + ::testing::Combine( + // n: size of vector. + ::testing::Values( + gtint_t(1), + gtint_t(9), + gtint_t(55) + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(500), gtint_t(-300) + ), + // incy: stride of y vector. + ::testing::Values( + gtint_t(100), gtint_t(-200) + ) + ), + ::swapvGenericPrint() + ); diff --git a/gtestsuite/testsuite/level1/swapv/dswapv_generic.cpp b/gtestsuite/testsuite/level1/swapv/dswapv_generic.cpp new file mode 100644 index 0000000000..f893773cdd --- /dev/null +++ b/gtestsuite/testsuite/level1/swapv/dswapv_generic.cpp @@ -0,0 +1,117 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_swapv.h" + +class dswapvGeneric : + // input params : vector length, stride size of x, stride size of y + public ::testing::TestWithParam> {}; + +TEST_P( dswapvGeneric, API ) +{ + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // vector length: + gtint_t n = std::get<0>(GetParam()); + // stride size for x: + gtint_t incx = std::get<1>(GetParam()); + // stride size for y: + gtint_t incy = std::get<2>(GetParam()); + + using T = double; + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_swapv( n, incx, incy ); +} + +/*************************************************************************/ +/* When n values are 32, 16, 8, 4 it is avx2 optimised */ +/* Values to be tested to cover all loops */ +/* 1, 2, 4, 8, 16, 32, 64, 128 : L1, L1*2, L4, L8, L16, L32, L64, 2*L64 */ +/* 5, 9, 17, 33, 65, 129 : L1 + ( L4, L8, L16, L32, L64, 2*L64) */ +/* 6, 10, 18, 34, 68, 130 : L1*2 + (L4, L8, L16, L32, L64, 2*L64) */ +/* 12, 24, 40, 72, 136 : L8 + (L4, L16, L32, L64, 2*L64) */ +/* 20, 136 : L16 + (L4, 2*L64) */ +/* 36, 96, 160 : L32 +(L4, L8, L32, L64, 2*L64) */ +/*************************************************************************/ +INSTANTIATE_TEST_SUITE_P( + UnitIncrements, + dswapvGeneric, + ::testing::Combine( + // n: size of vector. + ::testing::Values( + gtint_t(1), gtint_t(2), gtint_t(4), gtint_t(8), gtint_t(16), gtint_t(32), + gtint_t(64), gtint_t(128), gtint_t(5), gtint_t(9), gtint_t(17), gtint_t(33), + gtint_t(65), gtint_t(129), gtint_t(6), gtint_t(10), gtint_t(18), gtint_t(34), + gtint_t(68), gtint_t(130), gtint_t(12), gtint_t(24), gtint_t(40), gtint_t(72), + gtint_t(136), gtint_t(20), gtint_t(36), gtint_t(96), gtint_t(160) + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1) + ), + // incy: stride of y vector. + ::testing::Values( + gtint_t(1) + ) + ), + ::swapvGenericPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + NonUnitIncrements, + dswapvGeneric, + ::testing::Combine( + // n: size of vector. + ::testing::Values( + gtint_t( 1), + gtint_t( 9), + gtint_t(55) + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(500), gtint_t(-600) + ), + // incy: stride of y vector. + ::testing::Values( + gtint_t(100), gtint_t(-500) + ) + ), + ::swapvGenericPrint() + ); diff --git a/gtestsuite/testsuite/level1/swapv/sswapv_generic.cpp b/gtestsuite/testsuite/level1/swapv/sswapv_generic.cpp new file mode 100644 index 0000000000..3522513908 --- /dev/null +++ b/gtestsuite/testsuite/level1/swapv/sswapv_generic.cpp @@ -0,0 +1,117 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_swapv.h" + +class sswapvGeneric : + // input params : vector length, stride size of x, stride size of y + public ::testing::TestWithParam> {}; + +TEST_P( sswapvGeneric, API ) +{ + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // vector length: + gtint_t n = std::get<0>(GetParam()); + // stride size for x: + gtint_t incx = std::get<1>(GetParam()); + // stride size for y: + gtint_t incy = std::get<2>(GetParam()); + + using T = float; + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_swapv( n, incx, incy ); +} + +/*****************************************************************/ +/* When n values are 64, 32, 16, 8 it is avx2 optimised */ +/* Values to be tested to cover all loops */ +/* 1, 2, 8, 16, 32, 64, 128 : L1, L1*2 L8, L16, L32, L64, 2*L64 */ +/* 2, 9, 17, 33, 65, 129 : L1 + (L1, L8, L16, L32, L64, 2*L64) */ +/* 10, 18, 34, 68, 130 : L1*2 + (L8, L16, L32, L64, 2*L64) */ +/* 24, 40, 72, 136 : L8 + (L16, L32, L64, 2*L64) */ +/* 24, 40, 72, 136 : L16 + (L16, L32, L64, 2*L64) */ +/* 96, 160 : L32 + (L64, 2*L64) */ +/*****************************************************************/ +INSTANTIATE_TEST_SUITE_P( + UnitIncrements, + sswapvGeneric, + ::testing::Combine( + // n: size of vector. + ::testing::Values( + gtint_t(1), gtint_t(2), gtint_t(8), gtint_t(16), gtint_t(32), + gtint_t(64), gtint_t(128), gtint_t(9), gtint_t(17), gtint_t(33), + gtint_t(65), gtint_t(129), gtint_t(10), gtint_t(18), gtint_t(34), + gtint_t(68), gtint_t(130), gtint_t(24), gtint_t(40), gtint_t(72), + gtint_t(136), gtint_t(96), gtint_t(160) + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1) + ), + // incy: stride of y vector. + ::testing::Values( + gtint_t(1) + ) + ), + ::swapvGenericPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + NonUnitIncrements, + sswapvGeneric, + ::testing::Combine( + // n: size of vector. + ::testing::Values( + gtint_t(1), + gtint_t(9), + gtint_t(55) + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(100), gtint_t(-300) + ), + // incy: stride of y vector. + ::testing::Values( + gtint_t(500), gtint_t(-200) + ) + ), + ::swapvGenericPrint() + ); diff --git a/gtestsuite/testsuite/level1/swapv/swapv.h b/gtestsuite/testsuite/level1/swapv/swapv.h new file mode 100644 index 0000000000..5e1740b22c --- /dev/null +++ b/gtestsuite/testsuite/level1/swapv/swapv.h @@ -0,0 +1,148 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#pragma once + +#include "blis.h" +#include "common/testing_helpers.h" +#include "inc/check_error.h" + +/** + * @brief Performs the operation: + * x <=> y + * @param[in] n vector length of x and y + * @param[in,out] x pointer which points to the first element of x + * @param[in,out] y pointer which points to the first element of y + * @param[in] incx increment of x + * @param[in] incy increment of y + */ + +template +static void swapv_(gtint_t n, T* x, gtint_t incx, T* y, gtint_t incy) +{ + + if constexpr (std::is_same::value) + sswap_( &n, x, &incx, y, &incy ); + else if constexpr (std::is_same::value) + dswap_( &n, x, &incx, y, &incy ); + else if constexpr (std::is_same::value) + cswap_( &n, x, &incx, y, &incy ); + else if constexpr (std::is_same::value) + zswap_( &n, x, &incx, y, &incy ); + else + throw std::runtime_error("Error in testsuite/level1/swapv.h: Invalid typename in swapv_()."); +} + +template +static void swapv_blis_impl(gtint_t n, T* x, gtint_t incx, T* y, gtint_t incy) +{ + + if constexpr (std::is_same::value) + sswap_blis_impl( &n, x, &incx, y, &incy ); + else if constexpr (std::is_same::value) + dswap_blis_impl( &n, x, &incx, y, &incy ); + else if constexpr (std::is_same::value) + cswap_blis_impl( &n, x, &incx, y, &incy ); + else if constexpr (std::is_same::value) + zswap_blis_impl( &n, x, &incx, y, &incy ); + else + throw std::runtime_error("Error in testsuite/level1/swapv.h: Invalid typename in swapv_blis_impl()."); +} + +template +static void cblas_swapv(gtint_t n, T* x, gtint_t incx, T* y, gtint_t incy) +{ + + if constexpr (std::is_same::value) + cblas_sswap( n, x, incx, y, incy ); + else if constexpr (std::is_same::value) + cblas_dswap( n, x, incx, y, incy ); + else if constexpr (std::is_same::value) + cblas_cswap( n, x, incx, y, incy ); + else if constexpr (std::is_same::value) + cblas_zswap( n, x, incx, y, incy ); + else + throw std::runtime_error("Error in testsuite/level1/swapv.h: Invalid typename in cblas_swapv()."); +} + +template +static void typed_swapv(gtint_t n, T* x, gtint_t incx, T* y, gtint_t incy) +{ + if constexpr (std::is_same::value) + bli_sswapv( n, x, incx, y, incy ); + else if constexpr (std::is_same::value) + bli_dswapv( n, x, incx, y, incy ); + else if constexpr (std::is_same::value) + bli_cswapv( n, x, incx, y, incy ); + else if constexpr (std::is_same::value) + bli_zswapv( n, x, incx, y, incy ); + else + throw std::runtime_error("Error in testsuite/level1/swapv.h: Invalid typename in typed_swapv()."); + +} + +template +static void swapv(gtint_t n, T* x, gtint_t incx, T* y, gtint_t incy) +{ + +#ifdef TEST_INPUT_ARGS + // Create copy of scalar input values so we can check that they are not altered. + gtint_t n_cpy = n; + gtint_t incx_cpy = incx; + gtint_t incy_cpy = incy; +#endif + +#ifdef TEST_BLAS + swapv_( n, x, incx, y, incy ); +#elif TEST_BLAS_BLIS_IMPL + swapv_blis_impl( n, x, incx, y, incy ); +#elif TEST_CBLAS + cblas_swapv( n, x, incx, y, incy ); +#elif TEST_BLIS_TYPED + typed_swapv( n, x, incx, y, incy ); +#else + throw std::runtime_error("Error in testsuite/level1/swapv.h: No interfaces are set to be tested."); +#endif + +#ifdef TEST_INPUT_ARGS + //---------------------------------------------------------- + // Check scalar inputs have not been modified. + //---------------------------------------------------------- + + computediff( "n", n, n_cpy ); + computediff( "incx", incx, incx_cpy ); + computediff( "incy", incy, incy_cpy ); +#endif +} + diff --git a/gtestsuite/testsuite/level1/swapv/swapv_IIT_ERS.cpp b/gtestsuite/testsuite/level1/swapv/swapv_IIT_ERS.cpp new file mode 100644 index 0000000000..6b214b548a --- /dev/null +++ b/gtestsuite/testsuite/level1/swapv/swapv_IIT_ERS.cpp @@ -0,0 +1,157 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_swapv.h" +#include "common/wrong_inputs_helpers.h" +#include "common/testing_helpers.h" +#include "inc/check_error.h" + +template +class swapv_IIT_ERS : public ::testing::Test {}; +typedef ::testing::Types TypeParam; +TYPED_TEST_SUITE(swapv_IIT_ERS, TypeParam); + +using namespace testinghelpers::IIT; + +#if defined(TEST_BLAS_LIKE) || defined(TEST_CBLAS) + +/* + BLIS Early Return Scenarios(ERS): + + swapv is expected to return early in the following cases: + 1. n <= 0 +*/ + +// n < 0, with non-unit stride +TYPED_TEST(swapv_IIT_ERS, n_lt_zero_nonUnitStride) +{ + using T = TypeParam; + gtint_t invalid_n = -1; + gtint_t inc = 5; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + swapv( invalid_n, nullptr, inc, nullptr, inc ); + + // Test with all arguments correct except for the value we are choosing to test. + // Defining the X & Y vectors with values for debugging purposes + std::vector x = testinghelpers::get_random_vector( -10, 10, N, inc ); + std::vector y = testinghelpers::get_random_vector( -10, 10, N, inc ); + + // Copy so that we check that the elements of Y are not modified. + std::vector y_ref(y); + + // Call BLIS swapv with a invalid value for n==-1 & non-unit stride inc = 5. + swapv( invalid_n, x.data(), inc, y.data(), inc ); + + // Use bitwise comparison (no threshold). + computediff( "y", N, y.data(), y_ref.data(), inc ); +} + +// n < 0, with unit stride +TYPED_TEST(swapv_IIT_ERS, n_lt_zero_unitStride) +{ + using T = TypeParam; + gtint_t invalid_n = -1; + gtint_t inc = 1; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + swapv( invalid_n, nullptr, inc, nullptr, inc ); + + // Test with all arguments correct except for the value we are choosing to test. + // Defining the X & Y vectors with values for debugging purposes + std::vector x = testinghelpers::get_random_vector( -10, 10, N, inc ); + std::vector y = testinghelpers::get_random_vector( -10, 10, N, inc ); + + // Copy so that we check that the elements of Y are not modified. + std::vector y_ref(y); + + // Call BLIS swapv with a invalid value for n==-1 & unit stride inc = 1. + swapv( invalid_n, x.data(), inc, y.data(), inc ); + + // Use bitwise comparison (no threshold). + computediff( "y", N, y.data(), y_ref.data(), inc ); +} + +// n == 0, with non-unit stride +TYPED_TEST(swapv_IIT_ERS, n_eq_zero_nonUnitStride) +{ + using T = TypeParam; + gtint_t invalid_n = 0; + gtint_t inc = 2; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + swapv( invalid_n, nullptr, inc, nullptr, inc ); + + // Test with all arguments correct except for the value we are choosing to test. + // Defining the X & Y vectors with values for debugging purposes + std::vector x = testinghelpers::get_random_vector( -10, 10, N, inc ); + std::vector y = testinghelpers::get_random_vector( -10, 10, N, inc ); + + // Copy so that we check that the elements of Y are not modified. + std::vector y_ref(y); + + // Call BLIS swapv with a invalid value for n==0 & non-unit stride inc = 2. + swapv( invalid_n, x.data(), inc, y.data(), inc ); + + // Use bitwise comparison (no threshold). + computediff( "y", N, y.data(), y_ref.data(), inc ); +} + +// n == 0, with unit stride +TYPED_TEST(swapv_IIT_ERS, n_eq_zero_unitStride) +{ + using T = TypeParam; + gtint_t invalid_n = 0; + gtint_t inc = 1; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + swapv( invalid_n, nullptr, inc, nullptr, inc ); + + // Test with all arguments correct except for the value we are choosing to test. + // Defining the X & Y vectors with values for debugging purposes + std::vector x = testinghelpers::get_random_vector( -10, 10, N, inc ); + std::vector y = testinghelpers::get_random_vector( -10, 10, N, inc ); + + // Copy so that we check that the elements of Y are not modified. + std::vector y_ref(y); + + // Call BLIS swapv with a invalid value for n==0 & unit stride inc = 1. + swapv( invalid_n, x.data(), inc, y.data(), inc ); + + // Use bitwise comparison (no threshold). + computediff( "y", N, y.data(), y_ref.data(), inc ); +} + +#endif diff --git a/gtestsuite/testsuite/level1/swapv/test_swapv.h b/gtestsuite/testsuite/level1/swapv/test_swapv.h new file mode 100644 index 0000000000..852672deec --- /dev/null +++ b/gtestsuite/testsuite/level1/swapv/test_swapv.h @@ -0,0 +1,86 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#pragma once + +#include "swapv.h" +#include "inc/check_error.h" + +/** + * @brief Generic test body for swapv operation. + */ +template +static void test_swapv( gtint_t n, gtint_t incx, gtint_t incy ) +{ + //---------------------------------------------------------- + // Initialize vectors with random numbers. + //---------------------------------------------------------- + std::vector x = testinghelpers::get_random_vector( -50, 50, n, incx ); + std::vector y = testinghelpers::get_random_vector( 60, 100, n, incy ); + + //---------------------------------------------------------- + // Call reference implementation to get ref results. + //---------------------------------------------------------- + // Create a copy of y so that we can check reference results. + std::vector x_ref(x); + std::vector y_ref(y); + + //---------------------------------------------------------- + // Call BLIS function. + //---------------------------------------------------------- + swapv( n, x.data(), incx, y.data(), incy ); + + //---------------------------------------------------------- + // Compute binary comparison + //---------------------------------------------------------- + computediff( n, x.data(), x_ref.data(), y.data(), y_ref.data(), incx, incy, false ); + +} + +// Test-case logger : Used to print the test-case details based on parameters +class swapvGenericPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + gtint_t n = std::get<0>(str.param); + gtint_t incx = std::get<1>(str.param); + gtint_t incy = std::get<2>(str.param); + + std::string str_name = API_PRINT; + str_name += "_n_" + std::to_string(n); + str_name += "_incx_" + testinghelpers::get_value_string(incx); + str_name += "_incy_" + testinghelpers::get_value_string(incy); + return str_name; + } +}; diff --git a/gtestsuite/testsuite/level1/swapv/zswapv_generic.cpp b/gtestsuite/testsuite/level1/swapv/zswapv_generic.cpp new file mode 100644 index 0000000000..3d0ce417c0 --- /dev/null +++ b/gtestsuite/testsuite/level1/swapv/zswapv_generic.cpp @@ -0,0 +1,105 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_swapv.h" + +class zswapvGeneric : + // input params : vector length, stride size of x, stride size of y + public ::testing::TestWithParam> {}; + +TEST_P( zswapvGeneric, API ) +{ + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // vector length: + gtint_t n = std::get<0>(GetParam()); + // stride size for x: + gtint_t incx = std::get<1>(GetParam()); + // stride size for y: + gtint_t incy = std::get<2>(GetParam()); + + using T = dcomplex; + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_swapv( n, incx, incy ); +} + +INSTANTIATE_TEST_SUITE_P( + UnitIncrements, + zswapvGeneric, + ::testing::Combine( + // n: size of vector. + ::testing::Values( + gtint_t(1), + gtint_t(50), + gtint_t(100) + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1) + ), + // incy: stride of y vector. + ::testing::Values( + gtint_t(1) + ) + ), + ::swapvGenericPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + NonUnitIncrements, + zswapvGeneric, + ::testing::Combine( + // n: size of vector. + ::testing::Values( + gtint_t(1), + gtint_t(9), + gtint_t(55) + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(500), gtint_t(-100) + ), + // incy: stride of y vector. + ::testing::Values( + gtint_t(100), gtint_t(-200) + ) + ), + ::swapvGenericPrint() + ); diff --git a/gtestsuite/testsuite/level1/xpbyv/cxpbyv_generic.cpp b/gtestsuite/testsuite/level1/xpbyv/cxpbyv_generic.cpp index 6fb81b92aa..b2b28feeb1 100644 --- a/gtestsuite/testsuite/level1/xpbyv/cxpbyv_generic.cpp +++ b/gtestsuite/testsuite/level1/xpbyv/cxpbyv_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,17 +35,17 @@ #include #include "test_xpbyv.h" -class cxpbyvGenericTest : +class cxpbyvGeneric : public ::testing::TestWithParam> {}; -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(cxpbyvGenericTest); +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(cxpbyvGeneric); // Tests using random integers as vector elements. -TEST_P( cxpbyvGenericTest, RandomData ) +TEST_P( cxpbyvGeneric, API ) { using T = scomplex; //---------------------------------------------------------- @@ -64,45 +64,31 @@ TEST_P( cxpbyvGenericTest, RandomData ) T beta = std::get<4>(GetParam()); // Set the threshold for the errors: - double thresh = 2*testinghelpers::getEpsilon(); + // Check gtestsuite xpbyv.h (no netlib version) for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (n == 0) + thresh = 0.0; + else if (beta == testinghelpers::ZERO()) + thresh = 0.0; + else if (beta == testinghelpers::ONE()) + thresh = testinghelpers::getEpsilon(); + else + thresh = 2*testinghelpers::getEpsilon(); + //---------------------------------------------------------- // Call generic test body using those parameters //---------------------------------------------------------- test_xpbyv( conj_x, n, incx, incy, beta, thresh ); } -// Used to generate a test case with a sensible name. -// Beware that we cannot use fp numbers (e.g., 2.3) in the names, -// so we are only printing int(2.3). This should be enough for debugging purposes. -// If this poses an issue, please reach out. -class cxpbyvGenericTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char conj = std::get<0>(str.param); - gtint_t n = std::get<1>(str.param); - gtint_t incx = std::get<2>(str.param); - gtint_t incy = std::get<3>(str.param); - scomplex beta = std::get<4>(str.param); - std::string str_name = "bli_cxpbyv"; - str_name += "_" + std::to_string(n); - str_name += "_" + std::string(&conj, 1); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name += "_" + incx_str; - std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); - str_name += "_" + incy_str; - std::string beta_str = ( beta.real > 0) ? std::to_string(int(beta.real)) : ("m" + std::to_string(int(std::abs(beta.real)))); - beta_str = beta_str + "pi" + (( beta.imag > 0) ? std::to_string(int(beta.imag)) : ("m" + std::to_string(int(std::abs(beta.imag))))); - str_name = str_name + "_b" + beta_str; - return str_name; - } -}; - #ifdef TEST_BLIS_TYPED // Black box testing for generic and main use of cxpby. INSTANTIATE_TEST_SUITE_P( Blackbox, - cxpbyvGenericTest, + cxpbyvGeneric, ::testing::Combine( ::testing::Values('n', 'c'), // n: use x, c: use conj(x) ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. @@ -110,7 +96,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(1)), // stride size for y ::testing::Values(scomplex{2.0, -1.0}, scomplex{-2.0, 3.0}) // beta ), - ::cxpbyvGenericTestPrint() + ::xpbyvGenericPrint() ); // Test for non-unit increments. @@ -118,7 +104,7 @@ INSTANTIATE_TEST_SUITE_P( // We can modify the values using implementantion details. INSTANTIATE_TEST_SUITE_P( NonUnitIncrements, - cxpbyvGenericTest, + cxpbyvGeneric, ::testing::Combine( ::testing::Values('n', 'c'), // n: use x, c: use conj(x) ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. @@ -126,6 +112,6 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(3), gtint_t(33)), /*(gtint_t(-12), gtint_t(-4))*/ // stride size for y ::testing::Values(scomplex{4.0, 3.1}) // beta ), - ::cxpbyvGenericTestPrint() + ::xpbyvGenericPrint() ); #endif diff --git a/gtestsuite/testsuite/level1/xpbyv/dxpbyv_generic.cpp b/gtestsuite/testsuite/level1/xpbyv/dxpbyv_generic.cpp index 079867f1f4..eb84a829c9 100644 --- a/gtestsuite/testsuite/level1/xpbyv/dxpbyv_generic.cpp +++ b/gtestsuite/testsuite/level1/xpbyv/dxpbyv_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,17 +35,17 @@ #include #include "test_xpbyv.h" -class dxpbyvGenericTest : +class dxpbyvGeneric : public ::testing::TestWithParam> {}; -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(dxpbyvGenericTest); +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(dxpbyvGeneric); // Tests using random integers as vector elements. -TEST_P( dxpbyvGenericTest, RandomData ) +TEST_P( dxpbyvGeneric, API ) { using T = double; //---------------------------------------------------------- @@ -64,7 +64,18 @@ TEST_P( dxpbyvGenericTest, RandomData ) T beta = std::get<4>(GetParam()); // Set the threshold for the errors: - double thresh = 2*testinghelpers::getEpsilon(); + // Check gtestsuite xpbyv.h (no netlib version) for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (n == 0) + thresh = 0.0; + else if (beta == testinghelpers::ZERO()) + thresh = 0.0; + else if (beta == testinghelpers::ONE()) + thresh = testinghelpers::getEpsilon(); + else + thresh = 2*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call generic test body using those parameters @@ -72,37 +83,11 @@ TEST_P( dxpbyvGenericTest, RandomData ) test_xpbyv( conj_x, n, incx, incy, beta, thresh ); } -// Used to generate a test case with a sensible name. -// Beware that we cannot use fp numbers (e.g., 2.3) in the names, -// so we are only printing int(2.3). This should be enough for debugging purposes. -// If this poses an issue, please reach out. -class dxpbyvGenericTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char conj = std::get<0>(str.param); - gtint_t n = std::get<1>(str.param); - gtint_t incx = std::get<2>(str.param); - gtint_t incy = std::get<3>(str.param); - double beta = std::get<4>(str.param); - std::string str_name = "bli_dxpbyv"; - str_name += "_" + std::to_string(n); - str_name += "_" + std::string(&conj, 1); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name += "_" + incx_str; - std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); - str_name += "_" + incy_str; - std::string beta_str = ( beta > 0) ? std::to_string(int(beta)) : "m" + std::to_string(int(std::abs(beta))); - str_name = str_name + "_b" + beta_str; - return str_name; - } -}; - #ifdef TEST_BLIS_TYPED // Black box testing for generic and main use of caxpy. INSTANTIATE_TEST_SUITE_P( Blackbox, - dxpbyvGenericTest, + dxpbyvGeneric, ::testing::Combine( ::testing::Values('n'), // n: use x, not conj(x) (since it is real) ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. @@ -110,7 +95,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(1)), // stride size for y ::testing::Values(double(2.0), double(-2.0)) // beta ), - ::dxpbyvGenericTestPrint() + ::xpbyvGenericPrint() ); @@ -119,7 +104,7 @@ INSTANTIATE_TEST_SUITE_P( // We can modify the values using implementantion details. INSTANTIATE_TEST_SUITE_P( ConjX, - dxpbyvGenericTest, + dxpbyvGeneric, ::testing::Combine( ::testing::Values('c'), // c: use conj(x) ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector @@ -127,7 +112,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(1)), // stride size for y ::testing::Values(double(2.0)) // beta ), - ::dxpbyvGenericTestPrint() + ::xpbyvGenericPrint() ); @@ -136,7 +121,7 @@ INSTANTIATE_TEST_SUITE_P( // We can modify the values using implementantion details. INSTANTIATE_TEST_SUITE_P( NonUnitIncrements, - dxpbyvGenericTest, + dxpbyvGeneric, ::testing::Combine( ::testing::Values('n'), // n: use x, not conj(x) (since it is real) ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector @@ -144,6 +129,6 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(3), gtint_t(33)), /*(gtint_t(-12), gtint_t(-4))*/// stride size for y ::testing::Values(double(4.0)) // beta ), - ::dxpbyvGenericTestPrint() + ::xpbyvGenericPrint() ); #endif diff --git a/gtestsuite/testsuite/level1/xpbyv/sxpbyv_generic.cpp b/gtestsuite/testsuite/level1/xpbyv/sxpbyv_generic.cpp index fe33a81cb8..a8fc3f4780 100644 --- a/gtestsuite/testsuite/level1/xpbyv/sxpbyv_generic.cpp +++ b/gtestsuite/testsuite/level1/xpbyv/sxpbyv_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,17 +35,17 @@ #include #include "test_xpbyv.h" -class sxpbyvGenericTest : +class sxpbyvGeneric : public ::testing::TestWithParam> {}; -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(sxpbyvGenericTest); +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(sxpbyvGeneric); // Tests using random integers as vector elements. -TEST_P( sxpbyvGenericTest, RandomData ) +TEST_P( sxpbyvGeneric, API ) { using T = float; //---------------------------------------------------------- @@ -64,7 +64,18 @@ TEST_P( sxpbyvGenericTest, RandomData ) T beta = std::get<4>(GetParam()); // Set the threshold for the errors: - float thresh = 2*testinghelpers::getEpsilon(); + // Check gtestsuite xpbyv.h (no netlib version) for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (n == 0) + thresh = 0.0; + else if (beta == testinghelpers::ZERO()) + thresh = 0.0; + else if (beta == testinghelpers::ONE()) + thresh = testinghelpers::getEpsilon(); + else + thresh = 2*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call generic test body using those parameters @@ -72,37 +83,11 @@ TEST_P( sxpbyvGenericTest, RandomData ) test_xpbyv( conj_x, n, incx, incy, beta, thresh ); } -// Used to generate a test case with a sensible name. -// Beware that we cannot use fp numbers (e.g., 2.3) in the names, -// so we are only printing int(2.3). This should be enough for debugging purposes. -// If this poses an issue, please reach out. -class sxpbyvGenericTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char conj = std::get<0>(str.param); - gtint_t n = std::get<1>(str.param); - gtint_t incx = std::get<2>(str.param); - gtint_t incy = std::get<3>(str.param); - float beta = std::get<4>(str.param); - std::string str_name = "bli_sxpbyv"; - str_name += "_" + std::to_string(n); - str_name += "_" + std::string(&conj, 1); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name += "_" + incx_str; - std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); - str_name += "_" + incy_str; - std::string beta_str = ( beta > 0) ? std::to_string(int(beta)) : "m" + std::to_string(int(std::abs(beta))); - str_name = str_name + "_b" + beta_str; - return str_name; - } -}; - #ifdef TEST_BLIS_TYPED // Black box testing for generic and main use of caxpy. INSTANTIATE_TEST_SUITE_P( Blackbox, - sxpbyvGenericTest, + sxpbyvGeneric, ::testing::Combine( ::testing::Values('n'), // n: use x, not conj(x) (since it is real) ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. @@ -110,7 +95,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(1)), // stride size for y ::testing::Values(float(2.0), float(-2.0)) // beta ), - ::sxpbyvGenericTestPrint() + ::xpbyvGenericPrint() ); // Test when conjugate of x is used as an argument. This option is BLIS-api specific. @@ -118,7 +103,7 @@ INSTANTIATE_TEST_SUITE_P( // We can modify the values using implementantion details. INSTANTIATE_TEST_SUITE_P( ConjX, - sxpbyvGenericTest, + sxpbyvGeneric, ::testing::Combine( ::testing::Values('c'), // c: use conj(x) ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector @@ -126,7 +111,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(1)), // stride size for y ::testing::Values(float(2.0)) // beta ), - ::sxpbyvGenericTestPrint() + ::xpbyvGenericPrint() ); @@ -135,7 +120,7 @@ INSTANTIATE_TEST_SUITE_P( // We can modify the values using implementantion details. INSTANTIATE_TEST_SUITE_P( NonUnitIncrements, - sxpbyvGenericTest, + sxpbyvGeneric, ::testing::Combine( ::testing::Values('n'), // n: use x, not conj(x) (since it is real) ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector @@ -143,6 +128,6 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(3), gtint_t(33)), /*(gtint_t(-12), gtint_t(-4))*/// stride size for y ::testing::Values(float(4.0)) // beta ), - ::sxpbyvGenericTestPrint() + ::xpbyvGenericPrint() ); #endif diff --git a/gtestsuite/testsuite/level1/xpbyv/test_xpbyv.h b/gtestsuite/testsuite/level1/xpbyv/test_xpbyv.h index 1694c2149d..c6be42f729 100644 --- a/gtestsuite/testsuite/level1/xpbyv/test_xpbyv.h +++ b/gtestsuite/testsuite/level1/xpbyv/test_xpbyv.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -67,5 +67,26 @@ static void test_xpbyv( char conjx, gtint_t n, gtint_t incx, gtint_t incy, //---------------------------------------------------------- // Compute component-wise error. //---------------------------------------------------------- - computediff( n, y.data(), y_ref.data(), incy, thresh ); + computediff( "y", n, y.data(), y_ref.data(), incy, thresh ); } + +// Test-case logger : Used to print the test-case details based on parameters +template +class xpbyvGenericPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char conjx = std::get<0>(str.param); + gtint_t n = std::get<1>(str.param); + gtint_t incx = std::get<2>(str.param); + gtint_t incy = std::get<3>(str.param); + T beta = std::get<4>(str.param); + std::string str_name = "bli_cxpbyv"; + str_name += "_n_" + std::to_string(n); + str_name += "_conjx_" + std::string(&conjx, 1); + str_name += "_incx_" + testinghelpers::get_value_string(incx); + str_name += "_incy_" + testinghelpers::get_value_string(incy); + str_name += "_beta_" + testinghelpers::get_value_string(beta); + return str_name; + } +}; diff --git a/gtestsuite/testsuite/level1/xpbyv/xpbyv.h b/gtestsuite/testsuite/level1/xpbyv/xpbyv.h index 2b3a15fbd5..4e32e66525 100644 --- a/gtestsuite/testsuite/level1/xpbyv/xpbyv.h +++ b/gtestsuite/testsuite/level1/xpbyv/xpbyv.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -36,6 +36,7 @@ #include "blis.h" #include "common/testing_helpers.h" +#include "inc/check_error.h" /** * @brief Performs the operation: @@ -70,8 +71,33 @@ static void typed_xpbyv(char conj_x, gtint_t n, T* x, gtint_t incx, T beta, T* y template static void xpbyv(char conj_x, gtint_t n, T* x, gtint_t incx, T beta, T* y, gtint_t incy) { + +#ifdef TEST_UPPERCASE_ARGS + conj_x = static_cast(std::toupper(static_cast(conj_x))); +#endif + +#ifdef TEST_INPUT_ARGS + // Create copy of scalar input values so we can check that they are not altered. + char conj_x_cpy = conj_x; + gtint_t n_cpy = n; + gtint_t incx_cpy = incx; + T beta_cpy = beta; + gtint_t incy_cpy = incy; + + // Create copy of input arrays so we can check that they are not altered. + T* x_cpy = nullptr; + gtint_t size_x = testinghelpers::buff_dim( n, incx ); + if (x && size_x > 0) + { + x_cpy = new T[size_x]; + memcpy( x_cpy, x, size_x * sizeof( T ) ); + } +#endif + #ifdef TEST_BLAS throw std::runtime_error("Error in testsuite/level1/xpbyv.h: BLAS interface is not available."); +#elif TEST_BLAS_BLIS_IMPL + throw std::runtime_error("Error in testsuite/level1/xpbyv.h: BLAS_BLIS_IMPL interface is not available."); #elif TEST_CBLAS throw std::runtime_error("Error in testsuite/level1/xpbyv.h: CBLAS interface is not available."); #elif TEST_BLIS_TYPED @@ -79,4 +105,26 @@ static void xpbyv(char conj_x, gtint_t n, T* x, gtint_t incx, T beta, T* y, gtin #else throw std::runtime_error("Error in testsuite/level1/xpbyv.h: No interfaces are set to be tested."); #endif + +#ifdef TEST_INPUT_ARGS + //---------------------------------------------------------- + // Check scalar inputs have not been modified. + //---------------------------------------------------------- + + computediff( "conj_x", conj_x, conj_x_cpy ); + computediff( "n", n, n_cpy ); + computediff( "incx", incx, incx_cpy ); + computediff( "beta", beta, beta_cpy ); + computediff( "incy", incy, incy_cpy ); + + //---------------------------------------------------------- + // Bitwise-wise check array inputs have not been modified. + //---------------------------------------------------------- + + if (x && size_x > 0) + { + computediff( "x", n, x, x_cpy, incx, true ); + delete[] x_cpy; + } +#endif } diff --git a/gtestsuite/testsuite/level1/xpbyv/zxpbyv_generic.cpp b/gtestsuite/testsuite/level1/xpbyv/zxpbyv_generic.cpp index 04b781da8c..f2c36fd4ec 100644 --- a/gtestsuite/testsuite/level1/xpbyv/zxpbyv_generic.cpp +++ b/gtestsuite/testsuite/level1/xpbyv/zxpbyv_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,17 +35,17 @@ #include #include "test_xpbyv.h" -class zxpbyvGenericTest : +class zxpbyvGeneric : public ::testing::TestWithParam> {}; -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(zxpbyvGenericTest); +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(zxpbyvGeneric); // Tests using random integers as vector elements. -TEST_P( zxpbyvGenericTest, RandomData ) +TEST_P( zxpbyvGeneric, API ) { using T = dcomplex; //---------------------------------------------------------- @@ -64,45 +64,31 @@ TEST_P( zxpbyvGenericTest, RandomData ) T beta = std::get<4>(GetParam()); // Set the threshold for the errors: - double thresh = 2*testinghelpers::getEpsilon(); + // Check gtestsuite xpbyv.h (no netlib version) for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (n == 0) + thresh = 0.0; + else if (beta == testinghelpers::ZERO()) + thresh = 0.0; + else if (beta == testinghelpers::ONE()) + thresh = testinghelpers::getEpsilon(); + else + thresh = 2*testinghelpers::getEpsilon(); + //---------------------------------------------------------- // Call generic test body using those parameters //---------------------------------------------------------- test_xpbyv( conj_x, n, incx, incy, beta, thresh ); } -// Used to generate a test case with a sensible name. -// Beware that we cannot use fp numbers (e.g., 2.3) in the names, -// so we are only printing int(2.3). This should be enough for debugging purposes. -// If this poses an issue, please reach out. -class zxpbyvGenericTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char conj = std::get<0>(str.param); - gtint_t n = std::get<1>(str.param); - gtint_t incx = std::get<2>(str.param); - gtint_t incy = std::get<3>(str.param); - dcomplex beta = std::get<4>(str.param); - std::string str_name = "bli_zxpbyv"; - str_name += "_" + std::to_string(n); - str_name += "_" + std::string(&conj, 1); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name += "_" + incx_str; - std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); - str_name += "_" + incy_str; - std::string beta_str = ( beta.real > 0) ? std::to_string(int(beta.real)) : ("m" + std::to_string(int(std::abs(beta.real)))); - beta_str = beta_str + "pi" + (( beta.imag > 0) ? std::to_string(int(beta.imag)) : ("m" + std::to_string(int(std::abs(beta.imag))))); - str_name = str_name + "_b" + beta_str; - return str_name; - } -}; - #ifdef TEST_BLIS_TYPED // Black box testing for generic and main use of zaxpby. INSTANTIATE_TEST_SUITE_P( Blackbox, - zxpbyvGenericTest, + zxpbyvGeneric, ::testing::Combine( ::testing::Values('n', 'c'), // n: use x, c: use conj(x) ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. @@ -110,7 +96,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(1)), /*(gtint_t(-12), gtint_t(-4))*/ // stride size for y ::testing::Values(dcomplex{2.0, -1.0}, dcomplex{-2.0, 3.0}) // beta ), - ::zxpbyvGenericTestPrint() + ::xpbyvGenericPrint() ); // Test for non-unit increments. @@ -118,7 +104,7 @@ INSTANTIATE_TEST_SUITE_P( // We can modify the values using implementantion details. INSTANTIATE_TEST_SUITE_P( NonUnitIncrements, - zxpbyvGenericTest, + zxpbyvGeneric, ::testing::Combine( ::testing::Values('n', 'c'), // n: use x, c: use conj(x) ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. @@ -126,6 +112,6 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(3), gtint_t(33)), /*(gtint_t(-12), gtint_t(-4))*/ // stride size for y ::testing::Values(dcomplex{4.0, 3.1}) // beta ), - ::zxpbyvGenericTestPrint() + ::xpbyvGenericPrint() ); #endif diff --git a/gtestsuite/testsuite/level2/gemv/IIT_ERS/gemv_IIT_ERS.cpp b/gtestsuite/testsuite/level2/gemv/IIT_ERS/gemv_IIT_ERS.cpp new file mode 100644 index 0000000000..09a4591a59 --- /dev/null +++ b/gtestsuite/testsuite/level2/gemv/IIT_ERS/gemv_IIT_ERS.cpp @@ -0,0 +1,795 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "level2/gemv/test_gemv.h" +#include "common/wrong_inputs_helpers.h" +#include "common/testing_helpers.h" +#include "inc/check_error.h" + +template +class gemv_IIT_ERS : public ::testing::Test {}; +typedef ::testing::Types TypeParam; +TYPED_TEST_SUITE(gemv_IIT_ERS, TypeParam); + +using namespace testinghelpers::IIT; + +#if defined(TEST_CBLAS) +#define INFO_OFFSET 1 +#else +#define INFO_OFFSET 0 +#endif + +#if defined(TEST_CBLAS) +TYPED_TEST(gemv_IIT_ERS, invalid_storage) +{ + using T = TypeParam; + gtint_t incx = 3; + gtint_t incy = 3; + + T alpha, beta; + testinghelpers::initone( alpha ); + testinghelpers::initone( beta ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + gemv( 'x', TRANS, CONJ, M, N, &alpha, nullptr, LDA, + nullptr, incx, &beta, nullptr, incy ); +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 1 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + + //---------------------------------------------------------- + // Initialize matrics with random integer numbers. + //---------------------------------------------------------- + std::vector a = testinghelpers::get_random_matrix( 1, 5, STORAGE, TRANS, M, N, LDA); + std::vector x = testinghelpers::get_random_vector( 1, 3, N, incx ); + std::vector y = testinghelpers::get_random_vector( 1, 3, M, incy ); + + // Create a copy of c so that we can check reference results. + std::vector y_ref(y); + //---------------------------------------------------------- + // Call BLIS function + //---------------------------------------------------------- + gemv( 'x', TRANS, CONJ, M, N, &alpha, a.data(), LDA, + x.data(), incx, &beta, y.data(), incy ); + + //---------------------------------------------------------- + // check component-wise error. + //---------------------------------------------------------- + computediff( "y", N, y.data(), y_ref.data(), incy); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 1 ); +#endif +} + +#endif + +#if defined(TEST_BLAS_LIKE) || defined(TEST_CBLAS) + +/* + Incorrect Input Testing(IIT) + + BLAS exceptions get triggered in the following cases(for GEMM): + 1. When TRANS != 'N' || TRANS != 'T' || TRANS != 'C' (info = 1) + 2. When m < 0 (info = 2) + 3. When n < 0 (info = 3) + 4. When lda < m (info = 6) + 5. When incx = 0 (info = 8) + 6. When incy = 0 (info = 11) + +*/ + +TYPED_TEST(gemv_IIT_ERS, invalid_trans) +{ + using T = TypeParam; + gtint_t incx = 3; + gtint_t incy = 3; + + T alpha, beta; + testinghelpers::initone( alpha ); + testinghelpers::initone( beta ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + gemv( STORAGE, 'p', CONJ, M, N, nullptr, nullptr, LDA, + nullptr, incx, nullptr, nullptr, incy ); +#else + gemv( STORAGE, 'p', CONJ, M, N, &alpha, nullptr, LDA, + nullptr, incx, &beta, nullptr, incy ); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, INFO_OFFSET+1 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + + //---------------------------------------------------------- + // Initialize matrics with random integer numbers. + //---------------------------------------------------------- + std::vector a = testinghelpers::get_random_matrix( 1, 5, STORAGE, TRANS, M, N, LDA); + std::vector x = testinghelpers::get_random_vector( 1, 3, N, incx ); + std::vector y = testinghelpers::get_random_vector( 1, 3, M, incy ); + + // Create a copy of c so that we can check reference results. + std::vector y_ref(y); + //---------------------------------------------------------- + // Call BLIS function + //---------------------------------------------------------- + gemv( STORAGE, 'p', CONJ, M, N, &alpha, a.data(), LDA, + x.data(), incx, &beta, y.data(), incy ); + + //---------------------------------------------------------- + // check component-wise error. + //---------------------------------------------------------- + computediff( "y", N, y.data(), y_ref.data(), incy); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, INFO_OFFSET+1 ); +#endif +} + +TYPED_TEST(gemv_IIT_ERS, m_lt_zero) +{ + using T = TypeParam; + gtint_t invalid_m = -1; + gtint_t incx = 3; + gtint_t incy = 3; + + T alpha, beta; + testinghelpers::initone( alpha ); + testinghelpers::initone( beta ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + gemv( STORAGE, TRANS, CONJ, invalid_m, N, nullptr, nullptr, LDA, + nullptr, incx, nullptr, nullptr, incy ); +#else + gemv( STORAGE, TRANS, CONJ, invalid_m, N, &alpha, nullptr, LDA, + nullptr, incx, &beta, nullptr, incy ); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 2 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + + //---------------------------------------------------------- + // Initialize matrics with random integer numbers. + //---------------------------------------------------------- + std::vector a = testinghelpers::get_random_matrix( 1, 5, STORAGE, TRANS, M, N, LDA); + std::vector x = testinghelpers::get_random_vector( 1, 3, N, incx ); + std::vector y = testinghelpers::get_random_vector( 1, 3, M, incy ); + + // Create a copy of c so that we can check reference results. + std::vector y_ref(y); + //---------------------------------------------------------- + // Call BLIS function + //---------------------------------------------------------- + gemv( STORAGE, TRANS, CONJ, invalid_m, N, &alpha, a.data(), LDA, + x.data(), incx, &beta, y.data(), incy ); + + //---------------------------------------------------------- + // check component-wise error. + //---------------------------------------------------------- + computediff( "y", N, y.data(), y_ref.data(), incy); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 2 ); +#endif +} + +TYPED_TEST(gemv_IIT_ERS, n_lt_zero) +{ + using T = TypeParam; + gtint_t invalid_n = -1; + gtint_t incx = 3; + gtint_t incy = 3; + + T alpha, beta; + testinghelpers::initone( alpha ); + testinghelpers::initone( beta ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + gemv( STORAGE, TRANS, CONJ, M, invalid_n, nullptr, nullptr, LDA, + nullptr, incx, nullptr, nullptr, incy ); +#else + gemv( STORAGE, TRANS, CONJ, M, invalid_n, &alpha, nullptr, LDA, + nullptr, incx, &beta, nullptr, incy ); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 3 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + + //---------------------------------------------------------- + // Initialize matrics with random integer numbers. + //---------------------------------------------------------- + std::vector a = testinghelpers::get_random_matrix( 1, 5, STORAGE, TRANS, M, N, LDA); + std::vector x = testinghelpers::get_random_vector( 1, 3, N, incx ); + std::vector y = testinghelpers::get_random_vector( 1, 3, M, incy ); + + // Create a copy of y so that we can check reference results. + std::vector y_ref(y); + + //---------------------------------------------------------- + // Call BLIS function + //---------------------------------------------------------- + gemv( STORAGE, TRANS, CONJ, M, invalid_n, &alpha, a.data(), LDA, + x.data(), incx, &beta, y.data(), incy ); + + //---------------------------------------------------------- + // check component-wise error. + //---------------------------------------------------------- + computediff( "y", N, y.data(), y_ref.data(), incy); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 3 ); +#endif +} + +TYPED_TEST(gemv_IIT_ERS, invalid_lda) +{ + using T = TypeParam; + gtint_t incx = 3; + gtint_t incy = 3; + + T alpha, beta; + testinghelpers::initone( alpha ); + testinghelpers::initone( beta ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + gemv( STORAGE, TRANS, CONJ, M, N, nullptr, nullptr, LDA - 1, + nullptr, incx, nullptr, nullptr, incy ); +#else + gemv( STORAGE, TRANS, CONJ, M, N, &alpha, nullptr, LDA - 1, + nullptr, incx, &beta, nullptr, incy ); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 6 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + + //---------------------------------------------------------- + // Initialize matrics with random integer numbers. + //---------------------------------------------------------- + std::vector a = testinghelpers::get_random_matrix( 1, 5, STORAGE, TRANS, M, N, LDA); + std::vector x = testinghelpers::get_random_vector( 1, 3, N, incx ); + std::vector y = testinghelpers::get_random_vector( 1, 3, M, incy ); + + // Create a copy of y so that we can check reference results. + std::vector y_ref(y); + + //---------------------------------------------------------- + // Call BLIS function + //---------------------------------------------------------- + gemv( STORAGE, TRANS, CONJ, M, N, &alpha, a.data(), LDA - 1, + x.data(), incx, &beta, y.data(), incy ); + + //---------------------------------------------------------- + // check component-wise error. + //---------------------------------------------------------- + computediff( "y", N, y.data(), y_ref.data(), incy); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 6 ); +#endif +} + +TYPED_TEST(gemv_IIT_ERS, incx_eq_zero) +{ + using T = TypeParam; + gtint_t incx = 3; + gtint_t incy = 3; + + T alpha, beta; + testinghelpers::initone( alpha ); + testinghelpers::initone( beta ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + gemv( STORAGE, TRANS, CONJ, M, N, nullptr, nullptr, LDA, + nullptr, 0, nullptr, nullptr, incy ); +#else + gemv( STORAGE, TRANS, CONJ, M, N, &alpha, nullptr, LDA, + nullptr, 0, &beta, nullptr, incy ); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 8 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + + //---------------------------------------------------------- + // Initialize matrics with random integer numbers. + //---------------------------------------------------------- + std::vector a = testinghelpers::get_random_matrix( 1, 5, STORAGE, TRANS, M, N, LDA); + std::vector x = testinghelpers::get_random_vector( 1, 3, N, incx ); + std::vector y = testinghelpers::get_random_vector( 1, 3, M, incy ); + + // Create a copy of y so that we can check reference results. + std::vector y_ref(y); + + //---------------------------------------------------------- + // Call BLIS function + //---------------------------------------------------------- + gemv( STORAGE, TRANS, CONJ, M, N, &alpha, a.data(), LDA, + x.data(), 0, &beta, y.data(), incy ); + + //---------------------------------------------------------- + // check component-wise error. + //---------------------------------------------------------- + computediff( "y", N, y.data(), y_ref.data(), incy); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 8 ); +#endif +} + +TYPED_TEST(gemv_IIT_ERS, incy_eq_zero) +{ + using T = TypeParam; + gtint_t incx = 3; + gtint_t incy = 3; + + T alpha, beta; + testinghelpers::initone( alpha ); + testinghelpers::initone( beta ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + gemv( STORAGE, TRANS, CONJ, M, N, nullptr, nullptr, LDA, + nullptr, incx, nullptr, nullptr, 0 ); +#else + gemv( STORAGE, TRANS, CONJ, M, N, &alpha, nullptr, LDA, + nullptr, incx, &beta, nullptr, 0 ); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 11 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + + //---------------------------------------------------------- + // Initialize matrics with random integer numbers. + //---------------------------------------------------------- + std::vector a = testinghelpers::get_random_matrix( 1, 5, STORAGE, TRANS, M, N, LDA); + std::vector x = testinghelpers::get_random_vector( 1, 3, N, incx ); + std::vector y = testinghelpers::get_random_vector( 1, 3, M, incy ); + + // Create a copy of y so that we can check reference results. + std::vector y_ref(y); + + //---------------------------------------------------------- + // Call BLIS function + //---------------------------------------------------------- + gemv( STORAGE, TRANS, CONJ, M, N, &alpha, a.data(), LDA, + x.data(), incx, &beta, y.data(), 0 ); + + //---------------------------------------------------------- + // check component-wise error. + //---------------------------------------------------------- + computediff( "y", N, y.data(), y_ref.data(), incy); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 11 ); +#endif +} + +/* + BLAS Early Return Scenarios(ERS): + + GEMV is expected to return early in the following cases: + 1. m || n = 0 + 2. alpha = 0 && beta = 1 +*/ + +// m = 0 +TYPED_TEST(gemv_IIT_ERS, m_eq_zero) +{ + using T = TypeParam; + gtint_t invalid_m = 0; + gtint_t incx = 2; + gtint_t incy = 3; + + T alpha = T{1.3}; + T beta = T{0.7}; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + gemv( STORAGE, TRANS, CONJ, invalid_m, N, nullptr, nullptr, LDA, + nullptr, incx, nullptr, nullptr, incy ); +#else + gemv( STORAGE, TRANS, CONJ, invalid_m, N, &alpha, nullptr, LDA, + nullptr, incx, &beta, nullptr, incy ); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + + //---------------------------------------------------------- + // Initialize matrics with random integer numbers. + //---------------------------------------------------------- + std::vector a = testinghelpers::get_random_matrix( 1, 5, STORAGE, TRANS, M, N, LDA); + std::vector x = testinghelpers::get_random_vector( 1, 3, N, incx ); + std::vector y = testinghelpers::get_random_vector( 1, 3, M, incy ); + + // Create a copy of c so that we can check reference results. + std::vector y_ref(y); + //---------------------------------------------------------- + // Call BLIS function + //---------------------------------------------------------- + gemv( STORAGE, TRANS, CONJ, invalid_m, N, &alpha, a.data(), LDA, + x.data(), incx, &beta, y.data(), incy ); + + //---------------------------------------------------------- + // check component-wise error. + //---------------------------------------------------------- + computediff( "y", N, y.data(), y_ref.data(), incy); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif +} + +// n = 0 +TYPED_TEST(gemv_IIT_ERS, n_eq_zero) +{ + using T = TypeParam; + gtint_t invalid_n = 0; + gtint_t incx = 1; + gtint_t incy = 1; + + T alpha = T{1.3}; + T beta = T{0.7}; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + gemv( STORAGE, TRANS, CONJ, M, invalid_n, nullptr, nullptr, LDA, + nullptr, incx, nullptr, nullptr, incy ); +#else + gemv( STORAGE, TRANS, CONJ, M, invalid_n, &alpha, nullptr, LDA, + nullptr, incx, &beta, nullptr, incy ); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + + //---------------------------------------------------------- + // Initialize matrics with random integer numbers. + //---------------------------------------------------------- + std::vector a = testinghelpers::get_random_matrix( 1, 5, STORAGE, TRANS, M, N, LDA); + std::vector x = testinghelpers::get_random_vector( 1, 3, N, incx ); + std::vector y = testinghelpers::get_random_vector( 1, 3, M, incy ); + + // Create a copy of c so that we can check reference results. + std::vector y_ref(y); + //---------------------------------------------------------- + // Call BLIS function + //---------------------------------------------------------- + gemv( STORAGE, TRANS, CONJ, M, invalid_n, &alpha, a.data(), LDA, + x.data(), incx, &beta, y.data(), incy ); + + //---------------------------------------------------------- + // check component-wise error. + //---------------------------------------------------------- + computediff( "y", N, y.data(), y_ref.data(), incy); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif +} + +// m = 0, with unit alpha +TYPED_TEST(gemv_IIT_ERS, m_eq_zero_Unitbeta) +{ + using T = TypeParam; + gtint_t invalid_m = 0; + gtint_t incx = 2; + gtint_t incy = 3; + + T alpha, beta; + testinghelpers::initzero( alpha ); + testinghelpers::initone( beta ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + gemv( STORAGE, TRANS, CONJ, invalid_m, N, &alpha, nullptr, LDA, + nullptr, incx, nullptr, nullptr, incy ); +#else + gemv( STORAGE, TRANS, CONJ, invalid_m, N, &alpha, nullptr, LDA, + nullptr, incx, &beta, nullptr, incy ); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + + //---------------------------------------------------------- + // Initialize matrics with random integer numbers. + //---------------------------------------------------------- + std::vector a = testinghelpers::get_random_matrix( 1, 5, STORAGE, TRANS, M, N, LDA); + std::vector x = testinghelpers::get_random_vector( 1, 3, N, incx ); + std::vector y = testinghelpers::get_random_vector( 1, 3, M, incy ); + + // Create a copy of c so that we can check reference results. + std::vector y_ref(y); + //---------------------------------------------------------- + // Call BLIS function + //---------------------------------------------------------- + gemv( STORAGE, TRANS, CONJ, invalid_m, N, &alpha, a.data(), LDA, + x.data(), incx, &beta, y.data(), incy ); + + //---------------------------------------------------------- + // check component-wise error. + //---------------------------------------------------------- + computediff( "y", N, y.data(), y_ref.data(), incy); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif +} + +// n = 0, with unit alpha and beta +TYPED_TEST(gemv_IIT_ERS, n_eq_zero_UnitAlphaBeta) +{ + using T = TypeParam; + gtint_t invalid_n = 0; + gtint_t incx = 1; + gtint_t incy = 1; + + T alpha, beta; + testinghelpers::initone( alpha ); + testinghelpers::initone( beta ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + gemv( STORAGE, TRANS, CONJ, M, invalid_n, &alpha, nullptr, LDA, + nullptr, incx, &beta, nullptr, incy ); +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + + //---------------------------------------------------------- + // Initialize matrics with random integer numbers. + //---------------------------------------------------------- + std::vector a = testinghelpers::get_random_matrix( 1, 5, STORAGE, TRANS, M, N, LDA); + std::vector x = testinghelpers::get_random_vector( 1, 3, N, incx ); + std::vector y = testinghelpers::get_random_vector( 1, 3, M, incy ); + + // Create a copy of c so that we can check reference results. + std::vector y_ref(y); + //---------------------------------------------------------- + // Call BLIS function + //---------------------------------------------------------- + gemv( STORAGE, TRANS, CONJ, M, invalid_n, &alpha, a.data(), LDA, + x.data(), incx, &beta, y.data(), incy ); + + //---------------------------------------------------------- + // check component-wise error. + //---------------------------------------------------------- + computediff( "y", N, y.data(), y_ref.data(), incy); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif +} + +// zero alpha and unit beta +TYPED_TEST(gemv_IIT_ERS, ZeroAlpha_UnitBeta) +{ + using T = TypeParam; + gtint_t incx = 1; + gtint_t incy = 1; + + T alpha, beta; + testinghelpers::initzero( alpha ); + testinghelpers::initone( beta ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + gemv( STORAGE, TRANS, CONJ, M, N, &alpha, nullptr, LDA, + nullptr, incx, &beta, nullptr, incy ); +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + + //---------------------------------------------------------- + // Initialize matrics with random integer numbers. + //---------------------------------------------------------- + std::vector a = testinghelpers::get_random_matrix( 1, 5, STORAGE, TRANS, M, N, LDA); + std::vector x = testinghelpers::get_random_vector( 1, 3, N, incx ); + std::vector y = testinghelpers::get_random_vector( 1, 3, M, incy ); + + // Create a copy of c so that we can check reference results. + std::vector y_ref(y); + //---------------------------------------------------------- + // Call BLIS function + //---------------------------------------------------------- + gemv( STORAGE, TRANS, CONJ, M, N, &alpha, a.data(), LDA, + x.data(), incx, &beta, y.data(), incy ); + + //---------------------------------------------------------- + // check component-wise error. + //---------------------------------------------------------- + computediff( "y", N, y.data(), y_ref.data(), incy); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif +} + +// zero alpha and zero beta - set y to zero +TYPED_TEST(gemv_IIT_ERS, ZeroAlpha_ZeroBeta) +{ + using T = TypeParam; + gtint_t incx = 3; + gtint_t incy = 3; + + T alpha, beta; + testinghelpers::initzero( alpha ); + testinghelpers::initzero( beta ); + + std::vector y = testinghelpers::get_random_vector( 0, 1, N, incy ); + std::vector y2(y); + // Create a zero vector, since the output for alpha = beta = 0 should be a + // zero vector. + std::vector zero_vec = testinghelpers::get_random_vector( 0, 0, N, incy ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + gemv( STORAGE, TRANS, CONJ, M, N, &alpha, nullptr, LDA, + nullptr, incx, &beta, y2.data(), incy ); + computediff( "y", N, y2.data(), zero_vec.data(), incy); +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + + //---------------------------------------------------------- + // Initialize matrics with random integer numbers. + //---------------------------------------------------------- + std::vector a = testinghelpers::get_random_matrix( 1, 5, STORAGE, TRANS, M, N, LDA); + std::vector x = testinghelpers::get_random_vector( 0, 1, M, incx ); + + //---------------------------------------------------------- + // Call BLIS function + //---------------------------------------------------------- + gemv( STORAGE, TRANS, CONJ, M, N, &alpha, a.data(), LDA, + x.data(), incx, &beta, y.data(), incy ); + + //---------------------------------------------------------- + // check component-wise error. + //---------------------------------------------------------- + computediff( "y", N, y.data(), zero_vec.data(), incy); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif +} + +// zero alpha and non-zero/non-unit beta - scale y only +TYPED_TEST(gemv_IIT_ERS, ZeroAlpha_OtherBeta) +{ + using T = TypeParam; + gtint_t incx = 3; + gtint_t incy = 3; + + T alpha, beta; + testinghelpers::initzero( alpha ); + beta = T{2.0}; + double thresh = testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Initialize matrics with random integer numbers. + //---------------------------------------------------------- + std::vector a = testinghelpers::get_random_matrix( 1, 5, STORAGE, TRANS, M, N, LDA); + std::vector x = testinghelpers::get_random_vector( 0, 1, M, incx ); + std::vector y = testinghelpers::get_random_vector( 0, 1, N, incy ); + std::vector y_ref(y); + std::vector y2(y); + + testinghelpers::ref_gemv( STORAGE, TRANS, CONJ, M, N, alpha, a.data(), LDA, + x.data(), incx, beta, y_ref.data(), incy ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + gemv( STORAGE, TRANS, CONJ, M, N, &alpha, nullptr, LDA, + nullptr, incx, &beta, y2.data(), incy ); + + computediff( "y", N, y2.data(), y_ref.data(), incy, thresh); + +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + + //---------------------------------------------------------- + // Call BLIS function + //---------------------------------------------------------- + gemv( STORAGE, TRANS, CONJ, M, N, &alpha, a.data(), LDA, + x.data(), incx, &beta, y.data(), incy ); + + //---------------------------------------------------------- + // check component-wise error. + //---------------------------------------------------------- + computediff( "y", N, y.data(), y_ref.data(), incy, thresh); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif +} + +#endif diff --git a/gtestsuite/testsuite/level2/gemv/cgemv/cgemv_evt.cpp b/gtestsuite/testsuite/level2/gemv/cgemv/cgemv_evt.cpp new file mode 100644 index 0000000000..293d53341b --- /dev/null +++ b/gtestsuite/testsuite/level2/gemv/cgemv/cgemv_evt.cpp @@ -0,0 +1,328 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "level2/gemv/test_gemv.h" + +using T = scomplex; +using RT = testinghelpers::type_info::real_type; +static RT AOCL_NaN = std::numeric_limits::quiet_NaN(); +static RT AOCL_Inf = std::numeric_limits::infinity(); + +class cgemvEVT : + public ::testing::TestWithParam> {}; // lda_inc + +TEST_P( cgemvEVT, API ) +{ + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // matrix storage format(row major, column major) + char storage = std::get<0>(GetParam()); + // denotes whether matrix a is n,c,t,h + char transa = std::get<1>(GetParam()); + // denotes whether vector x is n,c + char conjx = std::get<2>(GetParam()); + // matrix size m + gtint_t m = std::get<3>(GetParam()); + // matrix size n + gtint_t n = std::get<4>(GetParam()); + // specifies alpha value + T alpha = std::get<5>(GetParam()); + // specifies beta value + T beta = std::get<6>(GetParam()); + // stride size for x: + gtint_t incx = std::get<7>(GetParam()); + // stride size for y: + gtint_t incy = std::get<8>(GetParam()); + // exception value for a: + T a_exval = std::get<9>(GetParam()); + // exception value for x: + T x_exval = std::get<10>(GetParam()); + // exception value for y: + T y_exval = std::get<11>(GetParam()); + // lda increment. + // If increment is zero, then the array size matches the matrix size. + // If increment are nonnegative, the array size is bigger than the matrix size. + gtint_t lda_inc = std::get<12>(GetParam()); + + bool is_memory_test = false; + bool is_evt_test = true; + + // Set the threshold for the errors: + // Check gtestsuite gemv.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (m == 0 || n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO() && (beta == testinghelpers::ZERO() || beta == testinghelpers::ONE())) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + thresh = testinghelpers::getEpsilon(); + else + if(( transa == 'n' ) || ( transa == 'N' )) + thresh = (3*n+1)*testinghelpers::getEpsilon(); + else + thresh = (3*m+1)*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_gemv( storage, transa, conjx, m, n, alpha, lda_inc, incx, beta, incy, thresh, is_memory_test, is_evt_test, a_exval, x_exval, y_exval ); +} + +INSTANTIATE_TEST_SUITE_P( + matrix_vector_unitStride, + cgemvEVT, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n','t'), // transa + ::testing::Values('n'), // conjx + ::testing::Values(gtint_t(32), + gtint_t(24), + gtint_t(8), + gtint_t(4), + gtint_t(2), + gtint_t(1), + gtint_t(15)), // m + ::testing::Values(gtint_t(32), + gtint_t(24), + gtint_t(8), + gtint_t(4), + gtint_t(2), + gtint_t(1), + gtint_t(15)), // n + ::testing::Values(T{ 0.0, 0.0}, + T{ 1.0, 1.0}, + T{ 2.1, -1.2}, + T{-1.0, 0.0}, + T{ 1.0, 0.0}), // alpha + ::testing::Values(T{ 0.0, 0.0}, + T{ 1.0, 1.0}, + T{ 2.1, -1.2}, + T{-1.0, 0.0}, + T{ 1.0, 0.0}), // beta + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(T{AOCL_NaN, AOCL_NaN}, + T{AOCL_Inf, -AOCL_Inf}, + T{AOCL_NaN, AOCL_Inf}, + T{2.1, AOCL_Inf}, + T{AOCL_Inf, -1.2}, + T{AOCL_Inf, 0.0}, + T{0.0, AOCL_Inf}, + T{0.0, 0.0}), // a_exval + ::testing::Values(T{AOCL_NaN, AOCL_NaN}, + T{AOCL_Inf, -AOCL_Inf}, + T{AOCL_NaN, AOCL_Inf}, + T{2.1, AOCL_Inf}, + T{AOCL_Inf, -1.2}, + T{AOCL_Inf, 0.0}, + T{0.0, AOCL_Inf}, + T{0.0, 0.0}), // x_exval + ::testing::Values(T{AOCL_NaN, AOCL_NaN}, + T{AOCL_Inf, -AOCL_Inf}, + T{AOCL_NaN, AOCL_Inf}, + T{2.1, AOCL_Inf}, + T{AOCL_Inf, -1.2}, + T{AOCL_Inf, 0.0}, + T{0.0, AOCL_Inf}, + T{0.0, 0.0}), // y_exval + ::testing::Values(gtint_t(0)) // increment to the leading dim of a + ), + ::gemvEVTPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + matrix_vector_nonUnitStride, + cgemvEVT, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n','t'), // transa + ::testing::Values('n'), // conjx + ::testing::Values(gtint_t(55)), // m + ::testing::Values(gtint_t(55)), // n + ::testing::Values(T{ 0.0, 0.0}, + T{ 1.0, 1.0}, + T{ 2.1, -1.2}, + T{-1.0, 0.0}, + T{ 1.0, 0.0}), // alpha + ::testing::Values(T{ 0.0, 0.0}, + T{ 1.0, 1.0}, + T{ 2.1, -1.2}, + T{-1.0, 0.0}, + T{ 1.0, 0.0}), // beta + ::testing::Values(gtint_t(3)), // stride size for x + ::testing::Values(gtint_t(5)), // stride size for y + ::testing::Values(T{AOCL_NaN, AOCL_NaN}, + T{AOCL_Inf, -AOCL_Inf}, + T{AOCL_NaN, AOCL_Inf}, + T{2.1, AOCL_Inf}, + T{AOCL_Inf, -1.2}, + T{AOCL_Inf, 0.0}, + T{0.0, AOCL_Inf}, + T{0.0, 0.0}), // a_exval + ::testing::Values(T{AOCL_NaN, AOCL_NaN}, + T{AOCL_Inf, -AOCL_Inf}, + T{AOCL_NaN, AOCL_Inf}, + T{2.1, AOCL_Inf}, + T{AOCL_Inf, -1.2}, + T{AOCL_Inf, 0.0}, + T{0.0, AOCL_Inf}, + T{0.0, 0.0}), // x_exval + ::testing::Values(T{AOCL_NaN, AOCL_NaN}, + T{AOCL_Inf, -AOCL_Inf}, + T{AOCL_NaN, AOCL_Inf}, + T{2.1, AOCL_Inf}, + T{AOCL_Inf, -1.2}, + T{AOCL_Inf, 0.0}, + T{0.0, AOCL_Inf}, + T{0.0, 0.0}), // y_exval + ::testing::Values(gtint_t(7)) // increment to the leading dim of a + ), + ::gemvEVTPrint() + ); + + +INSTANTIATE_TEST_SUITE_P( + alpha_beta_unitStride, + cgemvEVT, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n','t'), // transa + ::testing::Values('n'), // conjx + ::testing::Values(gtint_t(32), + gtint_t(24), + gtint_t(8), + gtint_t(4), + gtint_t(2), + gtint_t(1), + gtint_t(15)), // m + ::testing::Values(gtint_t(32), + gtint_t(24), + gtint_t(8), + gtint_t(4), + gtint_t(2), + gtint_t(1), + gtint_t(15)), // n + ::testing::Values(T{AOCL_NaN, AOCL_NaN}, + T{AOCL_Inf, -AOCL_Inf}, + T{AOCL_NaN, AOCL_Inf}, + T{2.1, AOCL_Inf}, + T{AOCL_Inf, -1.2}, + T{AOCL_Inf, 0.0}, + T{0.0, AOCL_Inf}, + T{0.0, 0.0}), // alpha + ::testing::Values(T{AOCL_NaN, AOCL_NaN}, + T{AOCL_Inf, -AOCL_Inf}, + T{AOCL_NaN, AOCL_Inf}, + T{2.1, AOCL_Inf}, + T{AOCL_Inf, -1.2}, + T{AOCL_Inf, 0.0}, + T{0.0, AOCL_Inf}, + T{0.0, 0.0}), // beta + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(T{0.0, 0.0}), // a_exval + ::testing::Values(T{0.0, 0.0}), // x_exval + ::testing::Values(T{0.0, 0.0}), // y_exval + ::testing::Values(gtint_t(0)) // increment to the leading dim of a + ), + ::gemvEVTPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + alpha_beta_nonUnitStride, + cgemvEVT, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n','t'), // transa + ::testing::Values('n'), // conjx + ::testing::Values(gtint_t(55)), // m + ::testing::Values(gtint_t(55)), // n + ::testing::Values(T{AOCL_NaN, AOCL_NaN}, + T{AOCL_Inf, -AOCL_Inf}, + T{AOCL_NaN, AOCL_Inf}, + T{2.1, AOCL_Inf}, + T{AOCL_Inf, -1.2}, + T{AOCL_Inf, 0.0}, + T{0.0, AOCL_Inf}, + T{0.0, 0.0}), // alpha + ::testing::Values(T{AOCL_NaN, AOCL_NaN}, + T{AOCL_Inf, -AOCL_Inf}, + T{AOCL_NaN, AOCL_Inf}, + T{2.1, AOCL_Inf}, + T{AOCL_Inf, -1.2}, + T{AOCL_Inf, 0.0}, + T{0.0, AOCL_Inf}, + T{0.0, 0.0}), // beta + ::testing::Values(gtint_t(3)), // stride size for x + ::testing::Values(gtint_t(5)), // stride size for y + ::testing::Values(T{0.0, 0.0}), // a_exval + ::testing::Values(T{0.0, 0.0}), // x_exval + ::testing::Values(T{0.0, 0.0}), // y_exval + ::testing::Values(gtint_t(7)) // increment to the leading dim of a + ), + ::gemvEVTPrint() + ); diff --git a/gtestsuite/testsuite/level2/gemv/cgemv/cgemv_generic.cpp b/gtestsuite/testsuite/level2/gemv/cgemv/cgemv_generic.cpp new file mode 100644 index 0000000000..83744db1d4 --- /dev/null +++ b/gtestsuite/testsuite/level2/gemv/cgemv/cgemv_generic.cpp @@ -0,0 +1,246 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "level2/gemv/test_gemv.h" + +using T = scomplex; + +class cgemvGeneric : + public ::testing::TestWithParam> {}; // is_memory_test + +TEST_P( cgemvGeneric, API ) +{ + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // matrix storage format(row major, column major) + char storage = std::get<0>(GetParam()); + // denotes whether matrix a is n,c,t,h + char transa = std::get<1>(GetParam()); + // denotes whether vector x is n,c + char conjx = std::get<2>(GetParam()); + // matrix size m + gtint_t m = std::get<3>(GetParam()); + // matrix size n + gtint_t n = std::get<4>(GetParam()); + // specifies alpha value + T alpha = std::get<5>(GetParam()); + // specifies beta value + T beta = std::get<6>(GetParam()); + // stride size for x: + gtint_t incx = std::get<7>(GetParam()); + // stride size for y: + gtint_t incy = std::get<8>(GetParam()); + // lda increment. + // If increment is zero, then the array size matches the matrix size. + // If increment are nonnegative, the array size is bigger than the matrix size. + gtint_t lda_inc = std::get<9>(GetParam()); + // is_memory_test: + bool is_memory_test = std::get<10>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite gemv.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (m == 0 || n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO() && (beta == testinghelpers::ZERO() || beta == testinghelpers::ONE())) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + thresh = testinghelpers::getEpsilon(); + else + if(( transa == 'n' ) || ( transa == 'N' )) + thresh = (3*n+1)*testinghelpers::getEpsilon(); + else + thresh = (3*m+1)*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_gemv( storage, transa, conjx, m, n, alpha, lda_inc, incx, beta, incy, thresh, is_memory_test ); +} + +// Black box testing. +INSTANTIATE_TEST_SUITE_P( + BlackboxSmall, + cgemvGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n', 'c', 't'), // transa + ::testing::Values('n'), // conjx + ::testing::Range(gtint_t(1), gtint_t(20), 1), // m + ::testing::Range(gtint_t(1), gtint_t(20), 1), // n + ::testing::Values(T{0.0, 0.0}, T{1.0, 0.0}, T{-1.0, 0.0}, + T{1.1, -2.0} ), // alpha + ::testing::Values(T{0.0, 0.0}, T{1.0, 0.0}, T{-1.0, 0.0}, + T{1.1, -2.0} ), // beta + ::testing::Values(gtint_t(1), gtint_t(3), gtint_t(-1)), // stride size for x + ::testing::Values(gtint_t(1), gtint_t(5), gtint_t(-2)), // stride size for y + ::testing::Values(gtint_t(0), gtint_t(7)), // increment to the leading dim of a + ::testing::Values(false, true) // is_memory_test + ), + ::gemvGenericPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + BlackboxMedium, + cgemvGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n', 'c', 't'), // transa + ::testing::Values('n'), // conjx + ::testing::Values(gtint_t(25), + gtint_t(33), + gtint_t(98), + gtint_t(173), + gtint_t(211) + ), // m + ::testing::Values(gtint_t(25), + gtint_t(33), + gtint_t(98), + gtint_t(173), + gtint_t(211) + ), // n + ::testing::Values(T{0.0, 0.0}, T{1.0, 0.0}, T{-1.0, 0.0}, + T{1.1, -2.0} ), // alpha + ::testing::Values(T{0.0, 0.0}, T{1.0, 0.0}, T{-1.0, 0.0}, + T{1.1, -2.0} ), // beta + ::testing::Values(gtint_t(1), gtint_t(3), gtint_t(-1)), // stride size for x + ::testing::Values(gtint_t(1), gtint_t(5), gtint_t(-1)), // stride size for y + ::testing::Values(gtint_t(0), gtint_t(7)), // increment to the leading dim of a + ::testing::Values(false, true) // is_memory_test + ), + ::gemvGenericPrint() + ); + +#if 1 +INSTANTIATE_TEST_SUITE_P( + Blackbox_Large, + cgemvGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n', 'c', 't'), // transa + ::testing::Values('n'), // conjx + ::testing::Values(gtint_t(2127)), // m + ::testing::Values(gtint_t(2127)), // n + ::testing::Values(T{0.0, 0.0}, T{1.0, 0.0}, T{-1.0, 0.0}, + T{1.1, -2.0} ), // alpha + ::testing::Values(T{0.0, 0.0}, T{1.0, 0.0}, T{-1.0, 0.0}, + T{1.1, -2.0} ), // beta + ::testing::Values(gtint_t(1), gtint_t(211)), // stride size for x + ::testing::Values(gtint_t(1), gtint_t(11)), // stride size for y + ::testing::Values(gtint_t(0), gtint_t(57)), // increment to the leading dim of a + ::testing::Values(false, true) // is_memory_test + ), + ::gemvGenericPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + Blackbox_LargeM, + cgemvGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n', 'c', 't'), // transa + ::testing::Values('n'), // conjx + ::testing::Values(gtint_t(5099)), // m + ::testing::Values(gtint_t(1), gtint_t(2), gtint_t(17), + gtint_t(173)), // n + ::testing::Values(T{0.0, 0.0}, T{1.0, 0.0}, T{-1.0, 0.0}, + T{1.1, -2.0} ), // alpha + ::testing::Values(T{0.0, 0.0}, T{1.0, 0.0}, T{-1.0, 0.0}, + T{1.1, -2.0} ), // beta + ::testing::Values(gtint_t(1), gtint_t(211)), // stride size for x + ::testing::Values(gtint_t(1), gtint_t(11)), // stride size for y + ::testing::Values(gtint_t(0), gtint_t(57)), // increment to the leading dim of a + ::testing::Values(false, true) // is_memory_test + ), + ::gemvGenericPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + Blackbox_LargeN, + cgemvGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n', 'c', 't'), // transa + ::testing::Values('n'), // conjx + ::testing::Values(gtint_t(1), gtint_t(2), gtint_t(17), + gtint_t(173)), // m + ::testing::Values(gtint_t(5099)), // n + ::testing::Values(T{0.0, 0.0}, T{1.0, 0.0}, T{-1.0, 0.0}, + T{1.1, -2.0} ), // alpha + ::testing::Values(T{0.0, 0.0}, T{1.0, 0.0}, T{-1.0, 0.0}, + T{1.1, -2.0} ), // beta + ::testing::Values(gtint_t(1), gtint_t(211)), // stride size for x + ::testing::Values(gtint_t(1), gtint_t(11)), // stride size for y + ::testing::Values(gtint_t(0), gtint_t(57)), // increment to the leading dim of a + ::testing::Values(false, true) // is_memory_test + ), + ::gemvGenericPrint() + ); +#endif diff --git a/gtestsuite/testsuite/level2/gemv/cgemv_generic.cpp b/gtestsuite/testsuite/level2/gemv/cgemv_generic.cpp deleted file mode 100644 index 8ba1f7a429..0000000000 --- a/gtestsuite/testsuite/level2/gemv/cgemv_generic.cpp +++ /dev/null @@ -1,150 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include "test_gemv.h" - -class cgemvTest : - public ::testing::TestWithParam> {}; - -TEST_P(cgemvTest, RandomData) -{ - using T = scomplex; - //---------------------------------------------------------- - // Initialize values from the parameters passed through - // test suite instantiation (INSTANTIATE_TEST_SUITE_P). - //---------------------------------------------------------- - // matrix storage format(row major, column major) - char storage = std::get<0>(GetParam()); - // denotes whether matrix a is n,c,t,h - char transa = std::get<1>(GetParam()); - // denotes whether vector x is n,c - char conjx = std::get<2>(GetParam()); - // matrix size m - gtint_t m = std::get<3>(GetParam()); - // matrix size n - gtint_t n = std::get<4>(GetParam()); - // specifies alpha value - T alpha = std::get<5>(GetParam()); - // specifies beta value - T beta = std::get<6>(GetParam()); - // stride size for x: - gtint_t incx = std::get<7>(GetParam()); - // stride size for y: - gtint_t incy = std::get<8>(GetParam()); - // lda increment. - // If increment is zero, then the array size matches the matrix size. - // If increment are nonnegative, the array size is bigger than the matrix size. - gtint_t lda_inc = std::get<9>(GetParam()); - - // Set the threshold for the errors: - double thresh = 2*(std::max)(m,n)*testinghelpers::getEpsilon(); - - //---------------------------------------------------------- - // Call test body using these parameters - //---------------------------------------------------------- - test_gemv( storage, transa, conjx, m, n, alpha, lda_inc, incx, beta, incy, thresh ); -} - -class cgemvTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char transa = std::get<1>(str.param); - char conjx = std::get<2>(str.param); - gtint_t m = std::get<3>(str.param); - gtint_t n = std::get<4>(str.param); - scomplex alpha = std::get<5>(str.param); - scomplex beta = std::get<6>(str.param); - gtint_t incx = std::get<7>(str.param); - gtint_t incy = std::get<8>(str.param); - gtint_t ld_inc = std::get<9>(str.param); -#ifdef TEST_BLAS - std::string str_name = "cgemv_"; -#elif TEST_CBLAS - std::string str_name = "cblas_cgemv"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_cgemv"; -#endif - str_name = str_name + "_" + sfm; - str_name = str_name + "_" + transa+conjx; - str_name = str_name + "_" + std::to_string(m); - str_name = str_name + "_" + std::to_string(n); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); - str_name = str_name + "_" + incx_str; - str_name = str_name + "_" + incy_str; - std::string alpha_str = ( alpha.real > 0) ? std::to_string(int(alpha.real)) : ("m" + std::to_string(int(std::abs(alpha.real)))); - alpha_str = alpha_str + "pi" + (( alpha.imag > 0) ? std::to_string(int(alpha.imag)) : ("m" + std::to_string(int(std::abs(alpha.imag))))); - std::string beta_str = ( beta.real > 0) ? std::to_string(int(beta.real)) : ("m" + std::to_string(int(std::abs(beta.real)))); - beta_str = beta_str + "pi" + (( beta.imag > 0) ? std::to_string(int(beta.imag)) : ("m" + std::to_string(int(std::abs(beta.imag))))); - str_name = str_name + "_a" + alpha_str; - str_name = str_name + "_b" + beta_str; - str_name = str_name + "_" + std::to_string(ld_inc); - return str_name; - } -}; - -// Black box testing. -INSTANTIATE_TEST_SUITE_P( - Blackbox, - cgemvTest, - ::testing::Combine( - ::testing::Values('c' -#ifndef TEST_BLAS - ,'r' -#endif - ), // storage format - ::testing::Values('n','c','t'), // transa - ::testing::Values('n'), // conjx - ::testing::Range(gtint_t(10), gtint_t(31), 10), // m - ::testing::Range(gtint_t(10), gtint_t(31), 10), // n - ::testing::Values(scomplex{1.0, -2.0}), // alpha - ::testing::Values(scomplex{-1.0, 1.0}), // beta - ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(gtint_t(0)) // increment to the leading dim of a - ), - ::cgemvTestPrint() - ); diff --git a/gtestsuite/testsuite/level2/gemv/dgemv/dgemv_evt.cpp b/gtestsuite/testsuite/level2/gemv/dgemv/dgemv_evt.cpp new file mode 100644 index 0000000000..b608772418 --- /dev/null +++ b/gtestsuite/testsuite/level2/gemv/dgemv/dgemv_evt.cpp @@ -0,0 +1,241 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "level2/gemv/test_gemv.h" + +using T = double; +static T AOCL_NaN = std::numeric_limits::quiet_NaN(); +static T AOCL_Inf = std::numeric_limits::infinity(); + +class dgemvEVT : + public ::testing::TestWithParam> {}; // lda_inc + +TEST_P( dgemvEVT, API ) +{ + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // matrix storage format(row major, column major) + char storage = std::get<0>(GetParam()); + // denotes whether matrix a is n,c,t,h + char transa = std::get<1>(GetParam()); + // denotes whether vector x is n,c + char conjx = std::get<2>(GetParam()); + // matrix size m + gtint_t m = std::get<3>(GetParam()); + // matrix size n + gtint_t n = std::get<4>(GetParam()); + // specifies alpha value + T alpha = std::get<5>(GetParam()); + // specifies beta value + T beta = std::get<6>(GetParam()); + // stride size for x: + gtint_t incx = std::get<7>(GetParam()); + // stride size for y: + gtint_t incy = std::get<8>(GetParam()); + // exception value for a: + T a_exval = std::get<9>(GetParam()); + // exception value for x: + T x_exval = std::get<10>(GetParam()); + // exception value for y: + T y_exval = std::get<11>(GetParam()); + // lda increment. + // If increment is zero, then the array size matches the matrix size. + // If increment are nonnegative, the array size is bigger than the matrix size. + gtint_t lda_inc = std::get<12>(GetParam()); + + bool is_memory_test = false; + bool is_evt_test = true; + + // Set the threshold for the errors: + // Check gtestsuite gemv.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (m == 0 || n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO() && (beta == testinghelpers::ZERO() || beta == testinghelpers::ONE())) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + thresh = testinghelpers::getEpsilon(); + else + if(( transa == 'n' ) || ( transa == 'N' )) + thresh = (3*n+1)*testinghelpers::getEpsilon(); + else + thresh = (3*m+1)*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_gemv( storage, transa, conjx, m, n, alpha, lda_inc, incx, beta, incy, thresh, is_memory_test, is_evt_test, a_exval, x_exval, y_exval ); +} + +INSTANTIATE_TEST_SUITE_P( + matrix_vector_unitStride, + dgemvEVT, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n','t'), // transa + ::testing::Values('n'), // conjx + ::testing::Values(gtint_t(32), + gtint_t(24), + gtint_t(8), + gtint_t(4), + gtint_t(2), + gtint_t(1), + gtint_t(15)), // m + ::testing::Values(gtint_t(32), + gtint_t(24), + gtint_t(8), + gtint_t(4), + gtint_t(2), + gtint_t(1), + gtint_t(15)), // n + ::testing::Values(-1.0, 0.0, 1.0, 2.3), // alpha + ::testing::Values(-1.0, 0.0, 1.0, 2.3), // beta + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(AOCL_NaN, AOCL_Inf, -AOCL_Inf, 0), // a_exval + ::testing::Values(AOCL_NaN, AOCL_Inf, -AOCL_Inf, 0), // x_exval + ::testing::Values(AOCL_NaN, AOCL_Inf, -AOCL_Inf, 0), // y_exval + ::testing::Values(gtint_t(0)) // increment to the leading dim of a + ), + ::gemvEVTPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + matrix_vector_nonUnitStride, + dgemvEVT, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n','t'), // transa + ::testing::Values('n'), // conjx + ::testing::Values(gtint_t(55)), // m + ::testing::Values(gtint_t(55)), // n + ::testing::Values(-1.0, 0.0, 1.0, 2.3), // alpha + ::testing::Values(-1.0, 0.0, 1.0, 2.3), // beta + ::testing::Values(gtint_t(3)), // stride size for x + ::testing::Values(gtint_t(5)), // stride size for y + ::testing::Values(AOCL_NaN, AOCL_Inf, -AOCL_Inf, 0), // a_exval + ::testing::Values(AOCL_NaN, AOCL_Inf, -AOCL_Inf, 0), // x_exval + ::testing::Values(AOCL_NaN, AOCL_Inf, -AOCL_Inf, 0), // y_exval + ::testing::Values(gtint_t(7)) // increment to the leading dim of a + ), + ::gemvEVTPrint() + ); + + +INSTANTIATE_TEST_SUITE_P( + alpha_beta_unitStride, + dgemvEVT, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n','t'), // transa + ::testing::Values('n'), // conjx + ::testing::Values(gtint_t(32), + gtint_t(24), + gtint_t(8), + gtint_t(4), + gtint_t(2), + gtint_t(1), + gtint_t(15)), // m + ::testing::Values(gtint_t(32), + gtint_t(24), + gtint_t(8), + gtint_t(4), + gtint_t(2), + gtint_t(1), + gtint_t(15)), // n + ::testing::Values(AOCL_NaN, AOCL_Inf, -AOCL_Inf), // alpha + ::testing::Values(AOCL_NaN, AOCL_Inf, -AOCL_Inf), // beta + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(0), // a_exval + ::testing::Values(0), // x_exval + ::testing::Values(0), // y_exval + ::testing::Values(gtint_t(0)) // increment to the leading dim of a + ), + ::gemvEVTPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + alpha_beta_nonUnitStride, + dgemvEVT, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n','t'), // transa + ::testing::Values('n'), // conjx + ::testing::Values(gtint_t(55)), // m + ::testing::Values(gtint_t(55)), // n + ::testing::Values(AOCL_NaN, AOCL_Inf, -AOCL_Inf), // alpha + ::testing::Values(AOCL_NaN, AOCL_Inf, -AOCL_Inf), // beta + ::testing::Values(gtint_t(3)), // stride size for x + ::testing::Values(gtint_t(5)), // stride size for y + ::testing::Values(0), // a_exval + ::testing::Values(0), // x_exval + ::testing::Values(0), // y_exval + ::testing::Values(gtint_t(7)) // increment to the leading dim of a + ), + ::gemvEVTPrint() + ); diff --git a/gtestsuite/testsuite/level2/gemv/dgemv/dgemv_generic.cpp b/gtestsuite/testsuite/level2/gemv/dgemv/dgemv_generic.cpp new file mode 100644 index 0000000000..fbaa5860cb --- /dev/null +++ b/gtestsuite/testsuite/level2/gemv/dgemv/dgemv_generic.cpp @@ -0,0 +1,235 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "level2/gemv/test_gemv.h" + +using T = double; + +class dgemvGeneric : + public ::testing::TestWithParam> {}; // is_memory_test + +TEST_P( dgemvGeneric, API ) +{ + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // matrix storage format(row major, column major) + char storage = std::get<0>(GetParam()); + // denotes whether matrix a is n,c,t,h + char transa = std::get<1>(GetParam()); + // denotes whether vector x is n,c + char conjx = std::get<2>(GetParam()); + // matrix size m + gtint_t m = std::get<3>(GetParam()); + // matrix size n + gtint_t n = std::get<4>(GetParam()); + // specifies alpha value + T alpha = std::get<5>(GetParam()); + // specifies beta value + T beta = std::get<6>(GetParam()); + // stride size for x: + gtint_t incx = std::get<7>(GetParam()); + // stride size for y: + gtint_t incy = std::get<8>(GetParam()); + // lda increment. + // If increment is zero, then the array size matches the matrix size. + // If increment are nonnegative, the array size is bigger than the matrix size. + gtint_t lda_inc = std::get<9>(GetParam()); + // is_memory_test: + bool is_memory_test = std::get<10>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite gemv.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (m == 0 || n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO() && (beta == testinghelpers::ZERO() || beta == testinghelpers::ONE())) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + thresh = testinghelpers::getEpsilon(); + else + if(( transa == 'n' ) || ( transa == 'N' )) + thresh = (3*n+1)*testinghelpers::getEpsilon(); + else + thresh = (3*m+1)*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_gemv( storage, transa, conjx, m, n, alpha, lda_inc, incx, beta, incy, thresh, is_memory_test ); +} + +// Black box testing. +INSTANTIATE_TEST_SUITE_P( + BlackboxSmall, + dgemvGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n', 'c', 't'), // transa + ::testing::Values('n'), // conjx + ::testing::Range(gtint_t(1), gtint_t(20), 1), // m + ::testing::Range(gtint_t(1), gtint_t(20), 1), // n + ::testing::Values( 0.0, 1.0, -1.0, -1.2 ), // alpha + ::testing::Values( 0.0, 1.0, -1.0, 2.1 ), // beta + ::testing::Values(gtint_t(1), gtint_t(3), gtint_t(-1)), // stride size for x + ::testing::Values(gtint_t(1), gtint_t(5), gtint_t(-2)), // stride size for y + ::testing::Values(gtint_t(0), gtint_t(7)), // increment to the leading dim of a + ::testing::Values(false, true) // is_memory_test + ), + ::gemvGenericPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + BlackboxMedium, + dgemvGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n', 'c', 't'), // transa + ::testing::Values('n'), // conjx + ::testing::Values(gtint_t(25), + gtint_t(33), + gtint_t(98), + gtint_t(173), + gtint_t(211) + ), // m + ::testing::Values(gtint_t(25), + gtint_t(33), + gtint_t(98), + gtint_t(173), + gtint_t(211) + ), // n + ::testing::Values( 0.0, 1.0, -1.0, -1.2 ), // alpha + ::testing::Values( 0.0, 1.0, -1.0, 2.1 ), // beta + ::testing::Values(gtint_t(1), gtint_t(3), gtint_t(-1)), // stride size for x + ::testing::Values(gtint_t(1), gtint_t(5), gtint_t(-1)), // stride size for y + ::testing::Values(gtint_t(0), gtint_t(7)), // increment to the leading dim of a + ::testing::Values(false, true) // is_memory_test + ), + ::gemvGenericPrint() + ); + +#if 1 +INSTANTIATE_TEST_SUITE_P( + Blackbox_Large, + dgemvGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n', 'c', 't'), // transa + ::testing::Values('n'), // conjx + ::testing::Values(gtint_t(2127)), // m + ::testing::Values(gtint_t(2127)), // n + ::testing::Values( 0.0, 1.0, -1.0, -1.2 ), // alpha + ::testing::Values( 0.0, 1.0, -1.0, 2.1 ), // beta + ::testing::Values(gtint_t(1), gtint_t(211)), // stride size for x + ::testing::Values(gtint_t(1), gtint_t(11)), // stride size for y + ::testing::Values(gtint_t(0), gtint_t(57)), // increment to the leading dim of a + ::testing::Values(false, true) // is_memory_test + ), + ::gemvGenericPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + Blackbox_LargeM, + dgemvGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n', 'c', 't'), // transa + ::testing::Values('n'), // conjx + ::testing::Values(gtint_t(5099)), // m + ::testing::Values(gtint_t(1), gtint_t(2), gtint_t(17), + gtint_t(173)), // n + ::testing::Values( 0.0, 1.0, -1.0, -1.2 ), // alpha + ::testing::Values( 0.0, 1.0, -1.0, 2.1 ), // beta + ::testing::Values(gtint_t(1), gtint_t(211)), // stride size for x + ::testing::Values(gtint_t(1), gtint_t(11)), // stride size for y + ::testing::Values(gtint_t(0), gtint_t(57)), // increment to the leading dim of a + ::testing::Values(false, true) // is_memory_test + ), + ::gemvGenericPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + Blackbox_LargeN, + dgemvGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n', 'c', 't'), // transa + ::testing::Values('n'), // conjx + ::testing::Values(gtint_t(1), gtint_t(2), gtint_t(17), + gtint_t(173)), // m + ::testing::Values(gtint_t(5099)), // n + ::testing::Values( 0.0, 1.0, -1.0, -1.2 ), // alpha + ::testing::Values( 0.0, 1.0, -1.0, 2.1 ), // beta + ::testing::Values(gtint_t(1), gtint_t(211)), // stride size for x + ::testing::Values(gtint_t(1), gtint_t(11)), // stride size for y + ::testing::Values(gtint_t(0), gtint_t(57)), // increment to the leading dim of a + ::testing::Values(false, true) // is_memory_test + ), + ::gemvGenericPrint() + ); +#endif diff --git a/gtestsuite/testsuite/level2/gemv/dgemv_generic.cpp b/gtestsuite/testsuite/level2/gemv/dgemv_generic.cpp deleted file mode 100644 index 33cc9fa57b..0000000000 --- a/gtestsuite/testsuite/level2/gemv/dgemv_generic.cpp +++ /dev/null @@ -1,148 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include "test_gemv.h" - -class dgemvTest : - public ::testing::TestWithParam> {}; - -TEST_P(dgemvTest, RandomData) -{ - using T = double; - //---------------------------------------------------------- - // Initialize values from the parameters passed through - // test suite instantiation (INSTANTIATE_TEST_SUITE_P). - //---------------------------------------------------------- - // matrix storage format(row major, column major) - char storage = std::get<0>(GetParam()); - // denotes whether matrix a is n,c,t,h - char transa = std::get<1>(GetParam()); - // denotes whether vector x is n,c - char conjx = std::get<2>(GetParam()); - // matrix size m - gtint_t m = std::get<3>(GetParam()); - // matrix size n - gtint_t n = std::get<4>(GetParam()); - // specifies alpha value - T alpha = std::get<5>(GetParam()); - // specifies beta value - T beta = std::get<6>(GetParam()); - // stride size for x: - gtint_t incx = std::get<7>(GetParam()); - // stride size for y: - gtint_t incy = std::get<8>(GetParam()); - // lda increment. - // If increment is zero, then the array size matches the matrix size. - // If increment are nonnegative, the array size is bigger than the matrix size. - gtint_t lda_inc = std::get<9>(GetParam()); - - // Set the threshold for the errors: - double thresh = 2*(std::max)(m,n)*testinghelpers::getEpsilon(); - - //---------------------------------------------------------- - // Call test body using these parameters - //---------------------------------------------------------- - test_gemv( storage, transa, conjx, m, n, alpha, lda_inc, incx, beta, incy, thresh ); -} - -class dgemvTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char transa = std::get<1>(str.param); - char conjx = std::get<2>(str.param); - gtint_t m = std::get<3>(str.param); - gtint_t n = std::get<4>(str.param); - double alpha = std::get<5>(str.param); - double beta = std::get<6>(str.param); - gtint_t incx = std::get<7>(str.param); - gtint_t incy = std::get<8>(str.param); - gtint_t ld_inc = std::get<9>(str.param); -#ifdef TEST_BLAS - std::string str_name = "dgemv_"; -#elif TEST_CBLAS - std::string str_name = "cblas_dgemv"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_dgemv"; -#endif - str_name = str_name + "_" + sfm; - str_name = str_name + "_" + transa+conjx; - str_name = str_name + "_" + std::to_string(m); - str_name = str_name + "_" + std::to_string(n); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); - str_name = str_name + "_" + incx_str; - str_name = str_name + "_" + incy_str; - std::string alpha_str = ( alpha > 0) ? std::to_string(int(alpha)) : "m" + std::to_string(int(std::abs(alpha))); - std::string beta_str = ( beta > 0) ? std::to_string(int(beta)) : "m" + std::to_string(int(std::abs(beta))); - str_name = str_name + "_a" + alpha_str; - str_name = str_name + "_b" + beta_str; - str_name = str_name + "_" + std::to_string(ld_inc); - return str_name; - } -}; - -// Black box testing. -INSTANTIATE_TEST_SUITE_P( - Blackbox, - dgemvTest, - ::testing::Combine( - ::testing::Values('c' -#ifndef TEST_BLAS - ,'r' -#endif - ), // storage format - ::testing::Values('n','t'), // transa - ::testing::Values('n'), // conjx - ::testing::Range(gtint_t(10), gtint_t(31), 10), // m - ::testing::Range(gtint_t(10), gtint_t(31), 10), // n - ::testing::Values( 1.0 ), // alpha - ::testing::Values(-1.0 ), // beta - ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(gtint_t(0)) // increment to the leading dim of a - ), - ::dgemvTestPrint() - ); diff --git a/gtestsuite/testsuite/level2/gemv/gemv.h b/gtestsuite/testsuite/level2/gemv/gemv.h index d6cc12f2db..06a167a3b0 100644 --- a/gtestsuite/testsuite/level2/gemv/gemv.h +++ b/gtestsuite/testsuite/level2/gemv/gemv.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -36,6 +36,7 @@ #include "blis.h" #include "common/testing_helpers.h" +#include "inc/check_error.h" /** * @brief Performs the operation: @@ -76,6 +77,22 @@ static void gemv_( char transa, gtint_t m, gtint_t n, T* alpha, T* ap, gtint_t l throw std::runtime_error("Error in testsuite/level2/gemv.h: Invalid typename in gemv_()."); } +template +static void gemv_blis_impl( char transa, gtint_t m, gtint_t n, T* alpha, T* ap, gtint_t lda, + T* xp, gtint_t incx, T* beta, T* yp, gtint_t incy ) +{ + if constexpr (std::is_same::value) + sgemv_blis_impl( &transa, &m, &n, alpha, ap, &lda, xp, &incx, beta, yp, &incy ); + else if constexpr (std::is_same::value) + dgemv_blis_impl( &transa, &m, &n, alpha, ap, &lda, xp, &incx, beta, yp, &incy ); + else if constexpr (std::is_same::value) + cgemv_blis_impl( &transa, &m, &n, alpha, ap, &lda, xp, &incx, beta, yp, &incy ); + else if constexpr (std::is_same::value) + zgemv_blis_impl( &transa, &m, &n, alpha, ap, &lda, xp, &incx, beta, yp, &incy ); + else + throw std::runtime_error("Error in testsuite/level2/gemv.h: Invalid typename in gemv_blis_impl()."); +} + template static void cblas_gemv( char storage, char trans, gtint_t m, gtint_t n, T* alpha, T* ap, gtint_t lda, T* xp, gtint_t incx, T* beta, T* yp, gtint_t incy ) @@ -135,11 +152,57 @@ template static void gemv( char storage, char trans, char conj_x, gtint_t m, gtint_t n, T* alpha, T* ap, gtint_t lda, T* xp, gtint_t incx, T* beta, T* yp, gtint_t incy ) { + +#ifdef TEST_UPPERCASE_ARGS + storage = static_cast(std::toupper(static_cast(storage))); + trans = static_cast(std::toupper(static_cast(trans))); + conj_x = static_cast(std::toupper(static_cast(conj_x))); +#endif + +#ifdef TEST_INPUT_ARGS + // Create copy of scalar input values so we can check that they are not altered. + char storage_cpy = storage; + char trans_cpy = trans; + char conj_x_cpy = conj_x; + gtint_t m_cpy = m; + gtint_t n_cpy = n; + T* alpha_cpy = alpha; + gtint_t lda_cpy = lda; + gtint_t incx_cpy = incx; + T* beta_cpy = beta; + gtint_t incy_cpy = incy; + + // Create copy of input arrays so we can check that they are not altered. + T* ap_cpy = nullptr; + gtint_t size_ap = testinghelpers::matsize( storage, trans, m, n, lda ); + if (ap && size_ap > 0) + { + ap_cpy = new T[size_ap]; + memcpy( ap_cpy, ap, size_ap * sizeof( T ) ); + } + T* xp_cpy = nullptr; + gtint_t size_xp; + if(( trans == 'n' ) || ( trans == 'N' )) + size_xp = testinghelpers::buff_dim( n, incx ); + else + size_xp = testinghelpers::buff_dim( m, incx ); + if (xp && size_xp > 0) + { + xp_cpy = new T[size_xp]; + memcpy( xp_cpy, xp, size_xp * sizeof( T ) ); + } +#endif + #ifdef TEST_BLAS if( storage == 'c' || storage == 'C' ) gemv_( trans, m, n, alpha, ap, lda, xp, incx, beta, yp, incy ); else throw std::runtime_error("Error in testsuite/level2/gemv.h: BLAS interface cannot be tested for row-major order."); +#elif TEST_BLAS_BLIS_IMPL + if( storage == 'c' || storage == 'C' ) + gemv_blis_impl( trans, m, n, alpha, ap, lda, xp, incx, beta, yp, incy ); + else + throw std::runtime_error("Error in testsuite/level2/gemv.h: BLAS_BLIS_IMPL interface cannot be tested for row-major order."); #elif TEST_CBLAS cblas_gemv( storage, trans, m, n, alpha, ap, lda, xp, incx, beta, yp, incy ); #elif TEST_BLIS_TYPED @@ -147,4 +210,40 @@ static void gemv( char storage, char trans, char conj_x, gtint_t m, gtint_t n, #else throw std::runtime_error("Error in testsuite/level2/gemv.h: No interfaces are set to be tested."); #endif + +#ifdef TEST_INPUT_ARGS + //---------------------------------------------------------- + // Check scalar inputs have not been modified. + //---------------------------------------------------------- + + computediff( "storage", storage, storage_cpy ); + computediff( "trans", trans, trans_cpy ); + computediff( "conj_x", conj_x, conj_x_cpy ); + computediff( "m", m, m_cpy ); + computediff( "n", n, n_cpy ); + if (alpha) computediff( "alpha", *alpha, *alpha_cpy ); + computediff( "lda", lda, lda_cpy ); + computediff( "incx", incx, incx_cpy ); + if (beta) computediff( "beta", *beta, *beta_cpy ); + computediff( "incy", incy, incy_cpy ); + + //---------------------------------------------------------- + // Bitwise-wise check array inputs have not been modified. + //---------------------------------------------------------- + + if (ap && size_ap > 0) + { + computediff( "A", storage, m, n, ap, ap_cpy, lda, true ); + delete[] ap_cpy; + } + + if (xp && size_xp > 0) + { + if(( trans == 'n' ) || ( trans == 'N' )) + computediff( "x", n, xp, xp_cpy, incx, true ); + else + computediff( "x", m, xp, xp_cpy, incx, true ); + delete[] xp_cpy; + } +#endif } diff --git a/gtestsuite/testsuite/level2/gemv/sgemv/sgemv_evt.cpp b/gtestsuite/testsuite/level2/gemv/sgemv/sgemv_evt.cpp new file mode 100644 index 0000000000..afaf238272 --- /dev/null +++ b/gtestsuite/testsuite/level2/gemv/sgemv/sgemv_evt.cpp @@ -0,0 +1,241 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "level2/gemv/test_gemv.h" + +using T = float; +static T AOCL_NaN = std::numeric_limits::quiet_NaN(); +static T AOCL_Inf = std::numeric_limits::infinity(); + +class sgemvEVT : + public ::testing::TestWithParam> {}; // lda_inc + +TEST_P( sgemvEVT, API ) +{ + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // matrix storage format(row major, column major) + char storage = std::get<0>(GetParam()); + // denotes whether matrix a is n,c,t,h + char transa = std::get<1>(GetParam()); + // denotes whether vector x is n,c + char conjx = std::get<2>(GetParam()); + // matrix size m + gtint_t m = std::get<3>(GetParam()); + // matrix size n + gtint_t n = std::get<4>(GetParam()); + // specifies alpha value + T alpha = std::get<5>(GetParam()); + // specifies beta value + T beta = std::get<6>(GetParam()); + // stride size for x: + gtint_t incx = std::get<7>(GetParam()); + // stride size for y: + gtint_t incy = std::get<8>(GetParam()); + // exception value for a: + T a_exval = std::get<9>(GetParam()); + // exception value for x: + T x_exval = std::get<10>(GetParam()); + // exception value for y: + T y_exval = std::get<11>(GetParam()); + // lda increment. + // If increment is zero, then the array size matches the matrix size. + // If increment are nonnegative, the array size is bigger than the matrix size. + gtint_t lda_inc = std::get<12>(GetParam()); + + bool is_memory_test = false; + bool is_evt_test = true; + + // Set the threshold for the errors: + // Check gtestsuite gemv.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (m == 0 || n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO() && (beta == testinghelpers::ZERO() || beta == testinghelpers::ONE())) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + thresh = testinghelpers::getEpsilon(); + else + if(( transa == 'n' ) || ( transa == 'N' )) + thresh = (3*n+1)*testinghelpers::getEpsilon(); + else + thresh = (3*m+1)*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_gemv( storage, transa, conjx, m, n, alpha, lda_inc, incx, beta, incy, thresh, is_memory_test, is_evt_test, a_exval, x_exval, y_exval ); +} + +INSTANTIATE_TEST_SUITE_P( + matrix_vector_unitStride, + sgemvEVT, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n','t'), // transa + ::testing::Values('n'), // conjx + ::testing::Values(gtint_t(32), + gtint_t(24), + gtint_t(8), + gtint_t(4), + gtint_t(2), + gtint_t(1), + gtint_t(15)), // m + ::testing::Values(gtint_t(32), + gtint_t(24), + gtint_t(8), + gtint_t(4), + gtint_t(2), + gtint_t(1), + gtint_t(15)), // n + ::testing::Values(-1.0, 0.0, 1.0, 2.3), // alpha + ::testing::Values(-1.0, 0.0, 1.0, 2.3), // beta + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(AOCL_NaN, AOCL_Inf, -AOCL_Inf, 0), // a_exval + ::testing::Values(AOCL_NaN, AOCL_Inf, -AOCL_Inf, 0), // x_exval + ::testing::Values(AOCL_NaN, AOCL_Inf, -AOCL_Inf, 0), // y_exval + ::testing::Values(gtint_t(0)) // increment to the leading dim of a + ), + ::gemvEVTPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + matrix_vector_nonUnitStride, + sgemvEVT, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n','t'), // transa + ::testing::Values('n'), // conjx + ::testing::Values(gtint_t(55)), // m + ::testing::Values(gtint_t(55)), // n + ::testing::Values(-1.0, 0.0, 1.0, 2.3), // alpha + ::testing::Values(-1.0, 0.0, 1.0, 2.3), // beta + ::testing::Values(gtint_t(3)), // stride size for x + ::testing::Values(gtint_t(5)), // stride size for y + ::testing::Values(AOCL_NaN, AOCL_Inf, -AOCL_Inf, 0), // a_exval + ::testing::Values(AOCL_NaN, AOCL_Inf, -AOCL_Inf, 0), // x_exval + ::testing::Values(AOCL_NaN, AOCL_Inf, -AOCL_Inf, 0), // y_exval + ::testing::Values(gtint_t(7)) // increment to the leading dim of a + ), + ::gemvEVTPrint() + ); + + +INSTANTIATE_TEST_SUITE_P( + alpha_beta_unitStride, + sgemvEVT, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n','t'), // transa + ::testing::Values('n'), // conjx + ::testing::Values(gtint_t(32), + gtint_t(24), + gtint_t(8), + gtint_t(4), + gtint_t(2), + gtint_t(1), + gtint_t(15)), // m + ::testing::Values(gtint_t(32), + gtint_t(24), + gtint_t(8), + gtint_t(4), + gtint_t(2), + gtint_t(1), + gtint_t(15)), // n + ::testing::Values(AOCL_NaN, AOCL_Inf, -AOCL_Inf), // alpha + ::testing::Values(AOCL_NaN, AOCL_Inf, -AOCL_Inf), // beta + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(0), // a_exval + ::testing::Values(0), // x_exval + ::testing::Values(0), // y_exval + ::testing::Values(gtint_t(0)) // increment to the leading dim of a + ), + ::gemvEVTPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + alpha_beta_nonUnitStride, + sgemvEVT, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n','t'), // transa + ::testing::Values('n'), // conjx + ::testing::Values(gtint_t(55)), // m + ::testing::Values(gtint_t(55)), // n + ::testing::Values(AOCL_NaN, AOCL_Inf, -AOCL_Inf), // alpha + ::testing::Values(AOCL_NaN, AOCL_Inf, -AOCL_Inf), // beta + ::testing::Values(gtint_t(3)), // stride size for x + ::testing::Values(gtint_t(5)), // stride size for y + ::testing::Values(0), // a_exval + ::testing::Values(0), // x_exval + ::testing::Values(0), // y_exval + ::testing::Values(gtint_t(7)) // increment to the leading dim of a + ), + ::gemvEVTPrint() + ); diff --git a/gtestsuite/testsuite/level2/gemv/sgemv/sgemv_generic.cpp b/gtestsuite/testsuite/level2/gemv/sgemv/sgemv_generic.cpp new file mode 100644 index 0000000000..fabc307a77 --- /dev/null +++ b/gtestsuite/testsuite/level2/gemv/sgemv/sgemv_generic.cpp @@ -0,0 +1,236 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "level2/gemv/test_gemv.h" + +using T = float; + +class sgemvGeneric : + public ::testing::TestWithParam> {}; // is_memory_test + +TEST_P( sgemvGeneric, API ) +{ + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // matrix storage format(row major, column major) + char storage = std::get<0>(GetParam()); + // denotes whether matrix a is n,c,t,h + char transa = std::get<1>(GetParam()); + // denotes whether vector x is n,c + char conjx = std::get<2>(GetParam()); + // matrix size m + gtint_t m = std::get<3>(GetParam()); + // matrix size n + gtint_t n = std::get<4>(GetParam()); + // specifies alpha value + T alpha = std::get<5>(GetParam()); + // specifies beta value + T beta = std::get<6>(GetParam()); + // stride size for x: + gtint_t incx = std::get<7>(GetParam()); + // stride size for y: + gtint_t incy = std::get<8>(GetParam()); + // lda increment. + // If increment is zero, then the array size matches the matrix size. + // If increment are nonnegative, the array size is bigger than the matrix size. + gtint_t lda_inc = std::get<9>(GetParam()); + // is_memory_test: + bool is_memory_test = std::get<10>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite gemv.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (m == 0 || n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO() && (beta == testinghelpers::ZERO() || beta == testinghelpers::ONE())) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + thresh = testinghelpers::getEpsilon(); + else + if(( transa == 'n' ) || ( transa == 'N' )) + thresh = (3*n+1)*testinghelpers::getEpsilon(); + else + thresh = (3*m+1)*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_gemv( storage, transa, conjx, m, n, alpha, lda_inc, incx, beta, incy, thresh, is_memory_test ); +} + + +// Black box testing. +INSTANTIATE_TEST_SUITE_P( + BlackboxSmall, + sgemvGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n', 'c', 't'), // transa + ::testing::Values('n'), // conjx + ::testing::Range(gtint_t(1), gtint_t(20), 1), // m + ::testing::Range(gtint_t(1), gtint_t(20), 1), // n + ::testing::Values( 0.0, 1.0, -1.0, -1.2 ), // alpha + ::testing::Values( 0.0, 1.0, -1.0, 2.1 ), // beta + ::testing::Values(gtint_t(1), gtint_t(3), gtint_t(-1)), // stride size for x + ::testing::Values(gtint_t(1), gtint_t(5), gtint_t(-2)), // stride size for y + ::testing::Values(gtint_t(0), gtint_t(7)), // increment to the leading dim of a + ::testing::Values(false, true) // is_memory_test + ), + ::gemvGenericPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + BlackboxMedium, + sgemvGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n', 'c', 't'), // transa + ::testing::Values('n'), // conjx + ::testing::Values(gtint_t(25), + gtint_t(33), + gtint_t(98), + gtint_t(173), + gtint_t(211) + ), // m + ::testing::Values(gtint_t(25), + gtint_t(33), + gtint_t(98), + gtint_t(173), + gtint_t(211) + ), // n + ::testing::Values( 0.0, 1.0, -1.0, -1.2 ), // alpha + ::testing::Values( 0.0, 1.0, -1.0, 2.1 ), // beta + ::testing::Values(gtint_t(1), gtint_t(3), gtint_t(-1)), // stride size for x + ::testing::Values(gtint_t(1), gtint_t(5), gtint_t(-1)), // stride size for y + ::testing::Values(gtint_t(0), gtint_t(7)), // increment to the leading dim of a + ::testing::Values(false, true) // is_memory_test + ), + ::gemvGenericPrint() + ); + +#if 1 +INSTANTIATE_TEST_SUITE_P( + Blackbox_Large, + sgemvGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n', 'c', 't'), // transa + ::testing::Values('n'), // conjx + ::testing::Values(gtint_t(2127)), // m + ::testing::Values(gtint_t(2127)), // n + ::testing::Values( 0.0, 1.0, -1.0, -1.2 ), // alpha + ::testing::Values( 0.0, 1.0, -1.0, 2.1 ), // beta + ::testing::Values(gtint_t(1), gtint_t(211)), // stride size for x + ::testing::Values(gtint_t(1), gtint_t(11)), // stride size for y + ::testing::Values(gtint_t(0), gtint_t(57)), // increment to the leading dim of a + ::testing::Values(false, true) // is_memory_test + ), + ::gemvGenericPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + Blackbox_LargeM, + sgemvGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n', 'c', 't'), // transa + ::testing::Values('n'), // conjx + ::testing::Values(gtint_t(5099)), // m + ::testing::Values(gtint_t(1), gtint_t(2), gtint_t(17), + gtint_t(173)), // n + ::testing::Values( 0.0, 1.0, -1.0, -1.2 ), // alpha + ::testing::Values( 0.0, 1.0, -1.0, 2.1 ), // beta + ::testing::Values(gtint_t(1), gtint_t(211)), // stride size for x + ::testing::Values(gtint_t(1), gtint_t(11)), // stride size for y + ::testing::Values(gtint_t(0), gtint_t(57)), // increment to the leading dim of a + ::testing::Values(false, true) // is_memory_test + ), + ::gemvGenericPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + Blackbox_LargeN, + sgemvGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n', 'c', 't'), // transa + ::testing::Values('n'), // conjx + ::testing::Values(gtint_t(1), gtint_t(2), gtint_t(17), + gtint_t(173)), // m + ::testing::Values(gtint_t(5099)), // n + ::testing::Values( 0.0, 1.0, -1.0, -1.2 ), // alpha + ::testing::Values( 0.0, 1.0, -1.0, 2.1 ), // beta + ::testing::Values(gtint_t(1), gtint_t(211)), // stride size for x + ::testing::Values(gtint_t(1), gtint_t(11)), // stride size for y + ::testing::Values(gtint_t(0), gtint_t(57)), // increment to the leading dim of a + ::testing::Values(false, true) // is_memory_test + ), + ::gemvGenericPrint() + ); +#endif diff --git a/gtestsuite/testsuite/level2/gemv/sgemv_generic.cpp b/gtestsuite/testsuite/level2/gemv/sgemv_generic.cpp deleted file mode 100644 index ec726ff56b..0000000000 --- a/gtestsuite/testsuite/level2/gemv/sgemv_generic.cpp +++ /dev/null @@ -1,148 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include "test_gemv.h" - -class sgemvTest : - public ::testing::TestWithParam> {}; - -TEST_P(sgemvTest, RandomData) -{ - using T = float; - //---------------------------------------------------------- - // Initialize values from the parameters passed through - // test suite instantiation (INSTANTIATE_TEST_SUITE_P). - //---------------------------------------------------------- - // matrix storage format(row major, column major) - char storage = std::get<0>(GetParam()); - // denotes whether matrix a is n,c,t,h - char transa = std::get<1>(GetParam()); - // denotes whether vector x is n,c - char conjx = std::get<2>(GetParam()); - // matrix size m - gtint_t m = std::get<3>(GetParam()); - // matrix size n - gtint_t n = std::get<4>(GetParam()); - // specifies alpha value - T alpha = std::get<5>(GetParam()); - // specifies beta value - T beta = std::get<6>(GetParam()); - // stride size for x: - gtint_t incx = std::get<7>(GetParam()); - // stride size for y: - gtint_t incy = std::get<8>(GetParam()); - // lda increment. - // If increment is zero, then the array size matches the matrix size. - // If increment are nonnegative, the array size is bigger than the matrix size. - gtint_t lda_inc = std::get<9>(GetParam()); - - // Set the threshold for the errors: - double thresh = 2*(std::max)(m,n)*testinghelpers::getEpsilon(); - - //---------------------------------------------------------- - // Call test body using these parameters - //---------------------------------------------------------- - test_gemv( storage, transa, conjx, m, n, alpha, lda_inc, incx, beta, incy, thresh ); -} - -class sgemvTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char transa = std::get<1>(str.param); - char conjx = std::get<2>(str.param); - gtint_t m = std::get<3>(str.param); - gtint_t n = std::get<4>(str.param); - float alpha = std::get<5>(str.param); - float beta = std::get<6>(str.param); - gtint_t incx = std::get<7>(str.param); - gtint_t incy = std::get<8>(str.param); - gtint_t ld_inc = std::get<9>(str.param); -#ifdef TEST_BLAS - std::string str_name = "sgemv_"; -#elif TEST_CBLAS - std::string str_name = "cblas_sgemv"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_sgemv"; -#endif - str_name = str_name + "_" + sfm; - str_name = str_name + "_" + transa+conjx; - str_name = str_name + "_" + std::to_string(m); - str_name = str_name + "_" + std::to_string(n); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); - str_name = str_name + "_" + incx_str; - str_name = str_name + "_" + incy_str; - std::string alpha_str = ( alpha > 0) ? std::to_string(int(alpha)) : "m" + std::to_string(int(std::abs(alpha))); - std::string beta_str = ( beta > 0) ? std::to_string(int(beta)) : "m" + std::to_string(int(std::abs(beta))); - str_name = str_name + "_a" + alpha_str; - str_name = str_name + "_b" + beta_str; - str_name = str_name + "_" + std::to_string(ld_inc); - return str_name; - } -}; - -// Black box testing. -INSTANTIATE_TEST_SUITE_P( - Blackbox, - sgemvTest, - ::testing::Combine( - ::testing::Values('c' -#ifndef TEST_BLAS - ,'r' -#endif - ), // storage format - ::testing::Values('n','t'), // transa - ::testing::Values('n'), // conjx - ::testing::Range(gtint_t(10), gtint_t(31), 10), // m - ::testing::Range(gtint_t(10), gtint_t(31), 10), // n - ::testing::Values( 1.0 ), // alpha - ::testing::Values(-1.0 ), // beta - ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(gtint_t(0)) // increment to the leading dim of a - ), - ::sgemvTestPrint() - ); diff --git a/gtestsuite/testsuite/level2/gemv/test_gemv.h b/gtestsuite/testsuite/level2/gemv/test_gemv.h index 76f8970294..d0ed9fa317 100644 --- a/gtestsuite/testsuite/level2/gemv/test_gemv.h +++ b/gtestsuite/testsuite/level2/gemv/test_gemv.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -37,44 +37,188 @@ #include "gemv.h" #include "level2/ref_gemv.h" #include "inc/check_error.h" +#include "common/testing_helpers.h" #include #include template - -void test_gemv( char storage, char trnsa, char conjx, gtint_t m, gtint_t n, - T alpha, gtint_t lda_inc, gtint_t incx, T beta, gtint_t incy, double thresh ) +void test_gemv( char storage, char transa, char conjx, gtint_t m, gtint_t n, + T alpha, gtint_t lda_inc, gtint_t incx, T beta, gtint_t incy, + double thresh, bool is_memory_test = false, + bool is_evt_test = false, T a_exval = T{0}, T x_exval = T{0}, + T y_exval = T{0} ) { // Compute the leading dimensions for matrix size calculation. gtint_t lda = testinghelpers::get_leading_dimension( storage, 'n', m, n, lda_inc ); + dim_t size_a = testinghelpers::matsize( storage, 'n', m, n, lda ) * sizeof(T); + testinghelpers::ProtectedBuffer a_buf(size_a, false, is_memory_test); + testinghelpers::datagenerators::randomgenerators( 1, 5, storage, m, n, (T*)(a_buf.greenzone_1), 'n', lda ); + // Get correct vector lengths. - gtint_t lenx = ( testinghelpers::chknotrans( trnsa ) ) ? n : m ; - gtint_t leny = ( testinghelpers::chknotrans( trnsa ) ) ? m : n ; + gtint_t lenx = ( testinghelpers::chknotrans( transa ) ) ? n : m ; + gtint_t leny = ( testinghelpers::chknotrans( transa ) ) ? m : n ; - //---------------------------------------------------------- - // Initialize matrics with random integer numbers. - //---------------------------------------------------------- - std::vector a = testinghelpers::get_random_matrix( 1, 5, storage, 'n', m, n, lda ); - std::vector x = testinghelpers::get_random_vector( 1, 3, lenx, incx ); - std::vector y = testinghelpers::get_random_vector( 1, 3, leny, incy ); + dim_t size_x = testinghelpers::buff_dim(lenx, incx) * sizeof(T); + dim_t size_y = testinghelpers::buff_dim(leny, incy) * sizeof(T); + testinghelpers::ProtectedBuffer x_buf(size_x, false, is_memory_test); + testinghelpers::ProtectedBuffer y_buf(size_y, false, is_memory_test); + + // For y_ref, we don't need different greenzones and any redzone. + // Thus, we pass is_memory_test as false + testinghelpers::ProtectedBuffer y_ref_buffer( size_y, false, false ); + + testinghelpers::datagenerators::randomgenerators( 1, 3, lenx, incx, (T*)(x_buf.greenzone_1) ); + if (beta != testinghelpers::ZERO()) + testinghelpers::datagenerators::randomgenerators( 1, 3, leny, incy, (T*)(y_buf.greenzone_1) ); + else + { + // Vector Y should not be read, only set. + testinghelpers::set_vector( leny, incy, (T*)(y_buf.greenzone_1), testinghelpers::aocl_extreme() ); + } + + T* a = (T*)(a_buf.greenzone_1); + T* x = (T*)(x_buf.greenzone_1); + T* y = (T*)(y_buf.greenzone_1); + T* y_ref = ( T* )y_ref_buffer.greenzone_1; // For y_ref, there is no greenzone_2 + + if ( is_evt_test ) + { + // Add extreme value to A matrix + dim_t ai = rand() % m; + dim_t aj = rand() % n; + testinghelpers::set_ev_mat( storage, 'n', lda, ai, aj, a_exval, a ); + + // Add extreme value to x vector + x[ (rand() % lenx) * std::abs(incx) ] = x_exval; + + // Add extreme value to y vector + y[ (rand() % leny) * std::abs(incy) ] = y_exval; + } + + // Copying the contents of y to y_ref + memcpy( y_ref, y, size_y ); - // Create a copy of c so that we can check reference results. - std::vector y_ref(y); //---------------------------------------------------------- // Call BLIS function //---------------------------------------------------------- - gemv( storage, trnsa, conjx, m, n, &alpha, a.data(), lda, - x.data(), incx, &beta, y.data(), incy ); + testinghelpers::ProtectedBuffer::start_signal_handler(); + try + { + gemv( storage, transa, conjx, m, n, &alpha, a, lda, x, incx, &beta, + y, incy ); + + if ( is_memory_test ) + { + memcpy((a_buf.greenzone_2), (a_buf.greenzone_1), size_a); + memcpy((x_buf.greenzone_2), (x_buf.greenzone_1), size_x); + memcpy((y_buf.greenzone_2), y_ref, size_y); + + gemv( storage, transa, conjx, m, n, &alpha, + (T*)(a_buf.greenzone_2), lda, + (T*)(x_buf.greenzone_2), incx, + &beta, + (T*)(y_buf.greenzone_2), incy ); + } + } + catch(const std::exception& e) + { + // reset to default signal handler + testinghelpers::ProtectedBuffer::stop_signal_handler(); + + // show failure in case seg fault was detected + FAIL() << "Memory Test Failed"; + } + // reset to default signal handler + testinghelpers::ProtectedBuffer::stop_signal_handler(); //---------------------------------------------------------- // Call reference implementation. //---------------------------------------------------------- - testinghelpers::ref_gemv( storage, trnsa, conjx, m, n, alpha, a.data(), - lda, x.data(), incx, beta, y_ref.data(), incy ); + testinghelpers::ref_gemv( storage, transa, conjx, m, n, alpha, a, + lda, x, incx, beta, y_ref, incy ); //---------------------------------------------------------- // check component-wise error. //---------------------------------------------------------- - computediff( leny, y.data(), y_ref.data(), incy, thresh ); + computediff( "y", leny, y, y_ref, incy, thresh, is_evt_test ); + +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif } + +// Test-case logger : Used to print the test-case details based on parameters +template +class gemvGenericPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char storage = std::get<0>(str.param); + char transa = std::get<1>(str.param); + char conjx = std::get<2>(str.param); + gtint_t m = std::get<3>(str.param); + gtint_t n = std::get<4>(str.param); + T alpha = std::get<5>(str.param); + T beta = std::get<6>(str.param); + gtint_t incx = std::get<7>(str.param); + gtint_t incy = std::get<8>(str.param); + gtint_t lda_inc = std::get<9>(str.param); + bool is_memory_test = std::get<10>(str.param); + + std::string str_name = API_PRINT; + str_name += "_stor_" + std::string(&storage, 1); + str_name += "_transa_" + std::string(&transa, 1); + str_name += "_conjx_" + std::string(&conjx, 1); + str_name += "_m_" + std::to_string(m); + str_name += "_n_" + std::to_string(n); + str_name += "_incx_" + testinghelpers::get_value_string(incx); + str_name += "_incy_" + testinghelpers::get_value_string(incy); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + str_name += "_beta_" + testinghelpers::get_value_string(beta); + gtint_t lda = testinghelpers::get_leading_dimension( storage, 'n', m, n, lda_inc ); + str_name += "_lda_i" + std::to_string(lda_inc) + "_" + std::to_string(lda); + str_name += ( is_memory_test ) ? "_mem_test_enabled" : "_mem_test_disabled"; + return str_name; + } +}; + +template +class gemvEVTPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char storage = std::get<0>(str.param); + char transa = std::get<1>(str.param); + char conjx = std::get<2>(str.param); + gtint_t m = std::get<3>(str.param); + gtint_t n = std::get<4>(str.param); + T alpha = std::get<5>(str.param); + T beta = std::get<6>(str.param); + gtint_t incx = std::get<7>(str.param); + gtint_t incy = std::get<8>(str.param); + T a_exval = std::get<9>(str.param); + T x_exval = std::get<10>(str.param); + T y_exval = std::get<11>(str.param); + gtint_t lda_inc = std::get<12>(str.param); + + std::string str_name = API_PRINT; + str_name += "_stor_" + std::string(&storage, 1); + str_name += "_transa_" + std::string(&transa, 1); + str_name += "_conjx_" + std::string(&conjx, 1); + str_name += "_m_" + std::to_string(m); + str_name += "_n_" + std::to_string(n); + str_name += "_incx_" + testinghelpers::get_value_string(incx); + str_name += "_incy_" + testinghelpers::get_value_string(incy); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + str_name += "_beta_" + testinghelpers::get_value_string(beta); + gtint_t lda = testinghelpers::get_leading_dimension( storage, 'n', m, n, lda_inc ); + str_name += "_lda_i" + std::to_string(lda_inc) + "_" + std::to_string(lda); + str_name = str_name + "_a_exval_" + testinghelpers::get_value_string(a_exval); + str_name = str_name + "_x_exval_" + testinghelpers::get_value_string(x_exval); + str_name = str_name + "_y_exval_" + testinghelpers::get_value_string(y_exval); + + return str_name; + } +}; diff --git a/gtestsuite/testsuite/level2/gemv/zgemv/zgemv_evt.cpp b/gtestsuite/testsuite/level2/gemv/zgemv/zgemv_evt.cpp new file mode 100644 index 0000000000..f34b7331ea --- /dev/null +++ b/gtestsuite/testsuite/level2/gemv/zgemv/zgemv_evt.cpp @@ -0,0 +1,328 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "level2/gemv/test_gemv.h" + +using T = dcomplex; +using RT = testinghelpers::type_info::real_type; +static RT AOCL_NaN = std::numeric_limits::quiet_NaN(); +static RT AOCL_Inf = std::numeric_limits::infinity(); + +class zgemvEVT : + public ::testing::TestWithParam> {}; // lda_inc + +TEST_P( zgemvEVT, API ) +{ + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // matrix storage format(row major, column major) + char storage = std::get<0>(GetParam()); + // denotes whether matrix a is n,c,t,h + char transa = std::get<1>(GetParam()); + // denotes whether vector x is n,c + char conjx = std::get<2>(GetParam()); + // matrix size m + gtint_t m = std::get<3>(GetParam()); + // matrix size n + gtint_t n = std::get<4>(GetParam()); + // specifies alpha value + T alpha = std::get<5>(GetParam()); + // specifies beta value + T beta = std::get<6>(GetParam()); + // stride size for x: + gtint_t incx = std::get<7>(GetParam()); + // stride size for y: + gtint_t incy = std::get<8>(GetParam()); + // exception value for a: + T a_exval = std::get<9>(GetParam()); + // exception value for x: + T x_exval = std::get<10>(GetParam()); + // exception value for y: + T y_exval = std::get<11>(GetParam()); + // lda increment. + // If increment is zero, then the array size matches the matrix size. + // If increment are nonnegative, the array size is bigger than the matrix size. + gtint_t lda_inc = std::get<12>(GetParam()); + + bool is_memory_test = false; + bool is_evt_test = true; + + // Set the threshold for the errors: + // Check gtestsuite gemv.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (m == 0 || n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO() && (beta == testinghelpers::ZERO() || beta == testinghelpers::ONE())) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + thresh = testinghelpers::getEpsilon(); + else + if(( transa == 'n' ) || ( transa == 'N' )) + thresh = (3*n+1)*testinghelpers::getEpsilon(); + else + thresh = (3*m+1)*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_gemv( storage, transa, conjx, m, n, alpha, lda_inc, incx, beta, incy, thresh, is_memory_test, is_evt_test, a_exval, x_exval, y_exval ); +} + +INSTANTIATE_TEST_SUITE_P( + matrix_vector_unitStride, + zgemvEVT, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n','t'), // transa + ::testing::Values('n'), // conjx + ::testing::Values(gtint_t(32), + gtint_t(24), + gtint_t(8), + gtint_t(4), + gtint_t(2), + gtint_t(1), + gtint_t(15)), // m + ::testing::Values(gtint_t(32), + gtint_t(24), + gtint_t(8), + gtint_t(4), + gtint_t(2), + gtint_t(1), + gtint_t(15)), // n + ::testing::Values(T{ 0.0, 0.0}, + T{ 1.0, 1.0}, + T{ 2.1, -1.2}, + T{-1.0, 0.0}, + T{ 1.0, 0.0}), // alpha + ::testing::Values(T{ 0.0, 0.0}, + T{ 1.0, 1.0}, + T{ 2.1, -1.2}, + T{-1.0, 0.0}, + T{ 1.0, 0.0}), // beta + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(T{AOCL_NaN, AOCL_NaN}, + T{AOCL_Inf, -AOCL_Inf}, + T{AOCL_NaN, AOCL_Inf}, + T{2.1, AOCL_Inf}, + T{AOCL_Inf, -1.2}, + T{AOCL_Inf, 0.0}, + T{0.0, AOCL_Inf}, + T{0.0, 0.0}), // a_exval + ::testing::Values(T{AOCL_NaN, AOCL_NaN}, + T{AOCL_Inf, -AOCL_Inf}, + T{AOCL_NaN, AOCL_Inf}, + T{2.1, AOCL_Inf}, + T{AOCL_Inf, -1.2}, + T{AOCL_Inf, 0.0}, + T{0.0, AOCL_Inf}, + T{0.0, 0.0}), // x_exval + ::testing::Values(T{AOCL_NaN, AOCL_NaN}, + T{AOCL_Inf, -AOCL_Inf}, + T{AOCL_NaN, AOCL_Inf}, + T{2.1, AOCL_Inf}, + T{AOCL_Inf, -1.2}, + T{AOCL_Inf, 0.0}, + T{0.0, AOCL_Inf}, + T{0.0, 0.0}), // y_exval + ::testing::Values(gtint_t(0)) // increment to the leading dim of a + ), + ::gemvEVTPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + matrix_vector_nonUnitStride, + zgemvEVT, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n','t'), // transa + ::testing::Values('n'), // conjx + ::testing::Values(gtint_t(55)), // m + ::testing::Values(gtint_t(55)), // n + ::testing::Values(T{ 0.0, 0.0}, + T{ 1.0, 1.0}, + T{ 2.1, -1.2}, + T{-1.0, 0.0}, + T{ 1.0, 0.0}), // alpha + ::testing::Values(T{ 0.0, 0.0}, + T{ 1.0, 1.0}, + T{ 2.1, -1.2}, + T{-1.0, 0.0}, + T{ 1.0, 0.0}), // beta + ::testing::Values(gtint_t(3)), // stride size for x + ::testing::Values(gtint_t(5)), // stride size for y + ::testing::Values(T{AOCL_NaN, AOCL_NaN}, + T{AOCL_Inf, -AOCL_Inf}, + T{AOCL_NaN, AOCL_Inf}, + T{2.1, AOCL_Inf}, + T{AOCL_Inf, -1.2}, + T{AOCL_Inf, 0.0}, + T{0.0, AOCL_Inf}, + T{0.0, 0.0}), // a_exval + ::testing::Values(T{AOCL_NaN, AOCL_NaN}, + T{AOCL_Inf, -AOCL_Inf}, + T{AOCL_NaN, AOCL_Inf}, + T{2.1, AOCL_Inf}, + T{AOCL_Inf, -1.2}, + T{AOCL_Inf, 0.0}, + T{0.0, AOCL_Inf}, + T{0.0, 0.0}), // x_exval + ::testing::Values(T{AOCL_NaN, AOCL_NaN}, + T{AOCL_Inf, -AOCL_Inf}, + T{AOCL_NaN, AOCL_Inf}, + T{2.1, AOCL_Inf}, + T{AOCL_Inf, -1.2}, + T{AOCL_Inf, 0.0}, + T{0.0, AOCL_Inf}, + T{0.0, 0.0}), // y_exval + ::testing::Values(gtint_t(7)) // increment to the leading dim of a + ), + ::gemvEVTPrint() + ); + + +INSTANTIATE_TEST_SUITE_P( + alpha_beta_unitStride, + zgemvEVT, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n','t'), // transa + ::testing::Values('n'), // conjx + ::testing::Values(gtint_t(32), + gtint_t(24), + gtint_t(8), + gtint_t(4), + gtint_t(2), + gtint_t(1), + gtint_t(15)), // m + ::testing::Values(gtint_t(32), + gtint_t(24), + gtint_t(8), + gtint_t(4), + gtint_t(2), + gtint_t(1), + gtint_t(15)), // n + ::testing::Values(T{AOCL_NaN, AOCL_NaN}, + T{AOCL_Inf, -AOCL_Inf}, + T{AOCL_NaN, AOCL_Inf}, + T{2.1, AOCL_Inf}, + T{AOCL_Inf, -1.2}, + T{AOCL_Inf, 0.0}, + T{0.0, AOCL_Inf}, + T{0.0, 0.0}), // alpha + ::testing::Values(T{AOCL_NaN, AOCL_NaN}, + T{AOCL_Inf, -AOCL_Inf}, + T{AOCL_NaN, AOCL_Inf}, + T{2.1, AOCL_Inf}, + T{AOCL_Inf, -1.2}, + T{AOCL_Inf, 0.0}, + T{0.0, AOCL_Inf}, + T{0.0, 0.0}), // beta + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(T{0.0, 0.0}), // a_exval + ::testing::Values(T{0.0, 0.0}), // x_exval + ::testing::Values(T{0.0, 0.0}), // y_exval + ::testing::Values(gtint_t(0)) // increment to the leading dim of a + ), + ::gemvEVTPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + alpha_beta_nonUnitStride, + zgemvEVT, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n','t'), // transa + ::testing::Values('n'), // conjx + ::testing::Values(gtint_t(55)), // m + ::testing::Values(gtint_t(55)), // n + ::testing::Values(T{AOCL_NaN, AOCL_NaN}, + T{AOCL_Inf, -AOCL_Inf}, + T{AOCL_NaN, AOCL_Inf}, + T{2.1, AOCL_Inf}, + T{AOCL_Inf, -1.2}, + T{AOCL_Inf, 0.0}, + T{0.0, AOCL_Inf}, + T{0.0, 0.0}), // alpha + ::testing::Values(T{AOCL_NaN, AOCL_NaN}, + T{AOCL_Inf, -AOCL_Inf}, + T{AOCL_NaN, AOCL_Inf}, + T{2.1, AOCL_Inf}, + T{AOCL_Inf, -1.2}, + T{AOCL_Inf, 0.0}, + T{0.0, AOCL_Inf}, + T{0.0, 0.0}), // beta + ::testing::Values(gtint_t(3)), // stride size for x + ::testing::Values(gtint_t(5)), // stride size for y + ::testing::Values(T{0.0, 0.0}), // a_exval + ::testing::Values(T{0.0, 0.0}), // x_exval + ::testing::Values(T{0.0, 0.0}), // y_exval + ::testing::Values(gtint_t(7)) // increment to the leading dim of a + ), + ::gemvEVTPrint() + ); diff --git a/gtestsuite/testsuite/level2/gemv/zgemv/zgemv_generic.cpp b/gtestsuite/testsuite/level2/gemv/zgemv/zgemv_generic.cpp new file mode 100644 index 0000000000..35d7089e36 --- /dev/null +++ b/gtestsuite/testsuite/level2/gemv/zgemv/zgemv_generic.cpp @@ -0,0 +1,246 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "level2/gemv/test_gemv.h" + +using T = dcomplex; + +class zgemvGeneric : + public ::testing::TestWithParam> {}; // is_memory_test + +TEST_P( zgemvGeneric, API ) +{ + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // matrix storage format(row major, column major) + char storage = std::get<0>(GetParam()); + // denotes whether matrix a is n,c,t,h + char transa = std::get<1>(GetParam()); + // denotes whether vector x is n,c + char conjx = std::get<2>(GetParam()); + // matrix size m + gtint_t m = std::get<3>(GetParam()); + // matrix size n + gtint_t n = std::get<4>(GetParam()); + // specifies alpha value + T alpha = std::get<5>(GetParam()); + // specifies beta value + T beta = std::get<6>(GetParam()); + // stride size for x: + gtint_t incx = std::get<7>(GetParam()); + // stride size for y: + gtint_t incy = std::get<8>(GetParam()); + // lda increment. + // If increment is zero, then the array size matches the matrix size. + // If increment are nonnegative, the array size is bigger than the matrix size. + gtint_t lda_inc = std::get<9>(GetParam()); + // is_memory_test: + bool is_memory_test = std::get<10>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite gemv.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (m == 0 || n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO() && (beta == testinghelpers::ZERO() || beta == testinghelpers::ONE())) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + thresh = testinghelpers::getEpsilon(); + else + if(( transa == 'n' ) || ( transa == 'N' )) + thresh = (3*n+1)*testinghelpers::getEpsilon(); + else + thresh = (3*m+1)*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_gemv( storage, transa, conjx, m, n, alpha, lda_inc, incx, beta, incy, thresh, is_memory_test ); +} + +// Black box testing. +INSTANTIATE_TEST_SUITE_P( + BlackboxSmall, + zgemvGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n', 'c', 't'), // transa + ::testing::Values('n'), // conjx + ::testing::Range(gtint_t(1), gtint_t(20), 1), // m + ::testing::Range(gtint_t(1), gtint_t(20), 1), // n + ::testing::Values(T{0.0, 0.0}, T{1.0, 0.0}, T{-1.0, 0.0}, + T{1.1, -2.0} ), // alpha + ::testing::Values(T{0.0, 0.0}, T{1.0, 0.0}, T{-1.0, 0.0}, + T{1.1, -2.0} ), // beta + ::testing::Values(gtint_t(1), gtint_t(3), gtint_t(-1)), // stride size for x + ::testing::Values(gtint_t(1), gtint_t(5), gtint_t(-2)), // stride size for y + ::testing::Values(gtint_t(0), gtint_t(7)), // increment to the leading dim of a + ::testing::Values(false, true) // is_memory_test + ), + ::gemvGenericPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + BlackboxMedium, + zgemvGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n', 'c', 't'), // transa + ::testing::Values('n'), // conjx + ::testing::Values(gtint_t(25), + gtint_t(33), + gtint_t(98), + gtint_t(173), + gtint_t(211) + ), // m + ::testing::Values(gtint_t(25), + gtint_t(33), + gtint_t(98), + gtint_t(173), + gtint_t(211) + ), // n + ::testing::Values(T{0.0, 0.0}, T{1.0, 0.0}, T{-1.0, 0.0}, + T{1.1, -2.0} ), // alpha + ::testing::Values(T{0.0, 0.0}, T{1.0, 0.0}, T{-1.0, 0.0}, + T{1.1, -2.0} ), // beta + ::testing::Values(gtint_t(1), gtint_t(3), gtint_t(-1)), // stride size for x + ::testing::Values(gtint_t(1), gtint_t(5), gtint_t(-1)), // stride size for y + ::testing::Values(gtint_t(0), gtint_t(7)), // increment to the leading dim of a + ::testing::Values(false, true) // is_memory_test + ), + ::gemvGenericPrint() + ); + +#if 1 +INSTANTIATE_TEST_SUITE_P( + Blackbox_Large, + zgemvGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n', 'c', 't'), // transa + ::testing::Values('n'), // conjx + ::testing::Values(gtint_t(2127)), // m + ::testing::Values(gtint_t(2127)), // n + ::testing::Values(T{0.0, 0.0}, T{1.0, 0.0}, T{-1.0, 0.0}, + T{1.1, -2.0} ), // alpha + ::testing::Values(T{0.0, 0.0}, T{1.0, 0.0}, T{-1.0, 0.0}, + T{1.1, -2.0} ), // beta + ::testing::Values(gtint_t(1), gtint_t(211)), // stride size for x + ::testing::Values(gtint_t(1), gtint_t(11)), // stride size for y + ::testing::Values(gtint_t(0), gtint_t(57)), // increment to the leading dim of a + ::testing::Values(false, true) // is_memory_test + ), + ::gemvGenericPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + Blackbox_LargeM, + zgemvGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n', 'c', 't'), // transa + ::testing::Values('n'), // conjx + ::testing::Values(gtint_t(5099)), // m + ::testing::Values(gtint_t(1), gtint_t(2), gtint_t(17), + gtint_t(173)), // n + ::testing::Values(T{0.0, 0.0}, T{1.0, 0.0}, T{-1.0, 0.0}, + T{1.1, -2.0} ), // alpha + ::testing::Values(T{0.0, 0.0}, T{1.0, 0.0}, T{-1.0, 0.0}, + T{1.1, -2.0} ), // beta + ::testing::Values(gtint_t(1), gtint_t(211)), // stride size for x + ::testing::Values(gtint_t(1), gtint_t(11)), // stride size for y + ::testing::Values(gtint_t(0), gtint_t(57)), // increment to the leading dim of a + ::testing::Values(false, true) // is_memory_test + ), + ::gemvGenericPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + Blackbox_LargeN, + zgemvGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n', 'c', 't'), // transa + ::testing::Values('n'), // conjx + ::testing::Values(gtint_t(1), gtint_t(2), gtint_t(17), + gtint_t(173)), // m + ::testing::Values(gtint_t(5099)), // n + ::testing::Values(T{0.0, 0.0}, T{1.0, 0.0}, T{-1.0, 0.0}, + T{1.1, -2.0} ), // alpha + ::testing::Values(T{0.0, 0.0}, T{1.0, 0.0}, T{-1.0, 0.0}, + T{1.1, -2.0} ), // beta + ::testing::Values(gtint_t(1), gtint_t(211)), // stride size for x + ::testing::Values(gtint_t(1), gtint_t(11)), // stride size for y + ::testing::Values(gtint_t(0), gtint_t(57)), // increment to the leading dim of a + ::testing::Values(false, true) // is_memory_test + ), + ::gemvGenericPrint() + ); +#endif diff --git a/gtestsuite/testsuite/level2/gemv/zgemv_generic.cpp b/gtestsuite/testsuite/level2/gemv/zgemv_generic.cpp deleted file mode 100644 index 8c27717111..0000000000 --- a/gtestsuite/testsuite/level2/gemv/zgemv_generic.cpp +++ /dev/null @@ -1,150 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include "test_gemv.h" - -class zgemvTest : - public ::testing::TestWithParam> {}; - -TEST_P(zgemvTest, RandomData) -{ - using T = dcomplex; - //---------------------------------------------------------- - // Initialize values from the parameters passed through - // test suite instantiation (INSTANTIATE_TEST_SUITE_P). - //---------------------------------------------------------- - // matrix storage format(row major, column major) - char storage = std::get<0>(GetParam()); - // denotes whether matrix a is n,c,t,h - char transa = std::get<1>(GetParam()); - // denotes whether vector x is n,c - char conjx = std::get<2>(GetParam()); - // matrix size m - gtint_t m = std::get<3>(GetParam()); - // matrix size n - gtint_t n = std::get<4>(GetParam()); - // specifies alpha value - T alpha = std::get<5>(GetParam()); - // specifies beta value - T beta = std::get<6>(GetParam()); - // stride size for x: - gtint_t incx = std::get<7>(GetParam()); - // stride size for y: - gtint_t incy = std::get<8>(GetParam()); - // lda increment. - // If increment is zero, then the array size matches the matrix size. - // If increment are nonnegative, the array size is bigger than the matrix size. - gtint_t lda_inc = std::get<9>(GetParam()); - - // Set the threshold for the errors: - double thresh = 2*(std::max)(m,n)*testinghelpers::getEpsilon(); - - //---------------------------------------------------------- - // Call test body using these parameters - //---------------------------------------------------------- - test_gemv( storage, transa, conjx, m, n, alpha, lda_inc, incx, beta, incy, thresh ); -} - -class zgemvTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char transa = std::get<1>(str.param); - char conjx = std::get<2>(str.param); - gtint_t m = std::get<3>(str.param); - gtint_t n = std::get<4>(str.param); - dcomplex alpha = std::get<5>(str.param); - dcomplex beta = std::get<6>(str.param); - gtint_t incx = std::get<7>(str.param); - gtint_t incy = std::get<8>(str.param); - gtint_t ld_inc = std::get<9>(str.param); -#ifdef TEST_BLAS - std::string str_name = "zgemv_"; -#elif TEST_CBLAS - std::string str_name = "cblas_zgemv"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_zgemv"; -#endif - str_name = str_name + "_" + sfm; - str_name = str_name + "_" + transa+conjx; - str_name = str_name + "_" + std::to_string(m); - str_name = str_name + "_" + std::to_string(n); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); - str_name = str_name + "_" + incx_str; - str_name = str_name + "_" + incy_str; - std::string alpha_str = ( alpha.real > 0) ? std::to_string(int(alpha.real)) : ("m" + std::to_string(int(std::abs(alpha.real)))); - alpha_str = alpha_str + "pi" + (( alpha.imag > 0) ? std::to_string(int(alpha.imag)) : ("m" + std::to_string(int(std::abs(alpha.imag))))); - std::string beta_str = ( beta.real > 0) ? std::to_string(int(beta.real)) : ("m" + std::to_string(int(std::abs(beta.real)))); - beta_str = beta_str + "pi" + (( beta.imag > 0) ? std::to_string(int(beta.imag)) : ("m" + std::to_string(int(std::abs(beta.imag))))); - str_name = str_name + "_a" + alpha_str; - str_name = str_name + "_b" + beta_str; - str_name = str_name + "_" + std::to_string(ld_inc); - return str_name; - } -}; - -// Black box testing. -INSTANTIATE_TEST_SUITE_P( - Blackbox, - zgemvTest, - ::testing::Combine( - ::testing::Values('c' -#ifndef TEST_BLAS - ,'r' -#endif - ), // storage format - ::testing::Values('n','c','t'), // transa - ::testing::Values('n'), // conjx - ::testing::Range(gtint_t(10), gtint_t(31), 10), // m - ::testing::Range(gtint_t(10), gtint_t(31), 10), // n - ::testing::Values(dcomplex{1.0, -2.0}), // alpha - ::testing::Values(dcomplex{-1.0, 1.0}), // beta - ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(gtint_t(0)) // increment to the leading dim of a - ), - ::zgemvTestPrint() - ); diff --git a/gtestsuite/testsuite/level2/ger/IIT_ERS/ger_IIT_ERS.cpp b/gtestsuite/testsuite/level2/ger/IIT_ERS/ger_IIT_ERS.cpp new file mode 100644 index 0000000000..eff1d0fa79 --- /dev/null +++ b/gtestsuite/testsuite/level2/ger/IIT_ERS/ger_IIT_ERS.cpp @@ -0,0 +1,786 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "level2/ger/test_ger.h" +#include "common/wrong_inputs_helpers.h" +#include "common/testing_helpers.h" +#include "inc/check_error.h" + + +template +class ger_IIT_ERS : public ::testing::Test {}; +typedef ::testing::Types TypeParam; + +TYPED_TEST_SUITE(ger_IIT_ERS, TypeParam); + +using namespace testinghelpers::IIT; + +#if defined(TEST_CBLAS) + +// Invalid value of STORAGE +TYPED_TEST(ger_IIT_ERS, invalid_storage) +{ + using T = TypeParam; + gtint_t invalid_m = -1; + gtint_t unit_inc = 1; + // Using a random non-zero value of alpha. + T alpha = T{3}; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + ger( 'x', CONJ, CONJ, M, N, &alpha, nullptr, unit_inc, + nullptr, unit_inc, nullptr, LDA ); +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 1 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector a = testinghelpers::get_random_matrix( -2, 5, STORAGE, 'n', M, N, LDA ); + std::vector x = testinghelpers::get_random_vector( -3, 3, M, unit_inc ); + std::vector y = testinghelpers::get_random_vector( -3, 3, N, unit_inc ); + + // Create a copy of a matrix so that we can check reference results. + std::vector a_ref(a); + + // Invoking GER with an invalid value of n. + ger( 'x', CONJ, CONJ, invalid_m, N, &alpha, x.data(), unit_inc, + y.data(), unit_inc, a.data(), LDA ); + + // Computing bitwise difference. + computediff( "A", STORAGE, M, N, a.data(), a_ref.data(), LDA ); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 1 ); +#endif +} + +#endif + +#if defined(TEST_BLAS_LIKE) || defined(TEST_CBLAS) + +/** + * BLAS Invalid Input Tests(IIT): + * + * Following conditions are considered as Invalid Inputs for GER: + * 1. m < 0 + * 2. n < 0 + * 3. incx = 0 + * 4. incy = 0 + * 5. lda < max(1, m) + */ +// m < 0, with unit stride +TYPED_TEST(ger_IIT_ERS, m_lt_zero_unitStride) +{ + using T = TypeParam; + gtint_t invalid_m = -1; + gtint_t unit_inc = 1; + // Using a random non-zero value of alpha. + T alpha = T{3}; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + ger( STORAGE, CONJ, CONJ, invalid_m, N, nullptr, nullptr, unit_inc, + nullptr, unit_inc, nullptr, LDA ); +#else + ger( STORAGE, CONJ, CONJ, invalid_m, N, &alpha, nullptr, unit_inc, + nullptr, unit_inc, nullptr, LDA ); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 1 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector a = testinghelpers::get_random_matrix( -2, 5, STORAGE, 'n', M, N, LDA ); + std::vector x = testinghelpers::get_random_vector( -3, 3, M, unit_inc ); + std::vector y = testinghelpers::get_random_vector( -3, 3, N, unit_inc ); + + // Create a copy of a matrix so that we can check reference results. + std::vector a_ref(a); + + // Invoking GER with an invalid value of m. + ger( STORAGE, CONJ, CONJ, invalid_m, N, &alpha, x.data(), unit_inc, + y.data(), unit_inc, a.data(), LDA ); + + // Computing bitwise difference. + computediff( "A", STORAGE, M, N, a.data(), a_ref.data(), LDA ); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 1 ); +#endif +} + +// m < 0, with non-unit stride +TYPED_TEST(ger_IIT_ERS, m_lt_zero_nonUnitStride) +{ + using T = TypeParam; + gtint_t invalid_m = -1; + gtint_t inc = 3; + // Using a random non-zero value of alpha. + T alpha = T{3}; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + ger( STORAGE, CONJ, CONJ, invalid_m, N, nullptr, nullptr, inc, + nullptr, inc, nullptr, LDA ); +#else + ger( STORAGE, CONJ, CONJ, invalid_m, N, &alpha, nullptr, inc, + nullptr, inc, nullptr, LDA ); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 1 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector a = testinghelpers::get_random_matrix( -2, 5, STORAGE, 'n', M, N, LDA ); + std::vector x = testinghelpers::get_random_vector( -3, 3, M, inc ); + std::vector y = testinghelpers::get_random_vector( -3, 3, N, inc ); + + // Create a copy of a matrix so that we can check reference results. + std::vector a_ref(a); + + // Invoking GER with an invalid value of m. + ger( STORAGE, CONJ, CONJ, invalid_m, N, &alpha, x.data(), inc, + y.data(), inc, a.data(), LDA ); + + // Computing bitwise difference. + computediff( "A", STORAGE, M, N, a.data(), a_ref.data(), LDA ); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 1 ); +#endif +} + +// n < 0, with unit stride +TYPED_TEST(ger_IIT_ERS, n_lt_zero_unitStride) +{ + using T = TypeParam; + gtint_t invalid_n = -1; + gtint_t unit_inc = 1; + // Using a random non-zero value of alpha. + T alpha = T{3}; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + ger( STORAGE, CONJ, CONJ, M, invalid_n, nullptr, nullptr, unit_inc, + nullptr, unit_inc, nullptr, LDA ); +#else + ger( STORAGE, CONJ, CONJ, M, invalid_n, &alpha, nullptr, unit_inc, + nullptr, unit_inc, nullptr, LDA ); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 2 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector a = testinghelpers::get_random_matrix( -2, 5, STORAGE, 'n', M, N, LDA ); + std::vector x = testinghelpers::get_random_vector( -3, 3, M, unit_inc ); + std::vector y = testinghelpers::get_random_vector( -3, 3, N, unit_inc ); + + // Create a copy of a matrix so that we can check reference results. + std::vector a_ref(a); + + // Invoking GER with an invalid value of n. + ger( STORAGE, CONJ, CONJ, M, invalid_n, &alpha, x.data(), unit_inc, + y.data(), unit_inc, a.data(), LDA ); + + // Computing bitwise difference. + computediff( "A", STORAGE, M, N, a.data(), a_ref.data(), LDA ); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 2 ); +#endif +} + +// n < 0, with non-unit stride +TYPED_TEST(ger_IIT_ERS, n_lt_zero_nonUnitStride) +{ + using T = TypeParam; + gtint_t invalid_n = -1; + gtint_t inc = 3; + // Using a random non-zero value of alpha. + T alpha = T{3}; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + ger( STORAGE, CONJ, CONJ, M, invalid_n, nullptr, nullptr, inc, + nullptr, inc, nullptr, LDA ); +#else + ger( STORAGE, CONJ, CONJ, M, invalid_n, &alpha, nullptr, inc, + nullptr, inc, nullptr, LDA ); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 2 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector a = testinghelpers::get_random_matrix( -2, 5, STORAGE, 'n', M, N, LDA ); + std::vector x = testinghelpers::get_random_vector( -3, 3, M, inc ); + std::vector y = testinghelpers::get_random_vector( -3, 3, N, inc ); + + // Create a copy of a matrix so that we can check reference results. + std::vector a_ref(a); + + // Invoking GER with an invalid value of n. + ger( STORAGE, CONJ, CONJ, M, invalid_n, &alpha, x.data(), inc, + y.data(), inc, a.data(), LDA ); + + // Computing bitwise difference. + computediff( "A", STORAGE, M, N, a.data(), a_ref.data(), LDA ); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 2 ); +#endif +} + +// incx = 0, with unit incy +TYPED_TEST(ger_IIT_ERS, incx_eq_zero_unitStride) +{ + using T = TypeParam; + gtint_t invalid_incx = 0; + gtint_t unit_inc = 1; + // Using a random non-zero value of alpha. + T alpha = T{3}; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + ger( STORAGE, CONJ, CONJ, M, N, nullptr, nullptr, invalid_incx, + nullptr, unit_inc, nullptr, LDA ); +#else + ger( STORAGE, CONJ, CONJ, M, N, &alpha, nullptr, invalid_incx, + nullptr, unit_inc, nullptr, LDA ); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 5 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector a = testinghelpers::get_random_matrix( -2, 5, STORAGE, 'n', M, N, LDA ); + std::vector x = testinghelpers::get_random_vector( -3, 3, M, unit_inc ); + std::vector y = testinghelpers::get_random_vector( -3, 3, N, unit_inc ); + + // Create a copy of a matrix so that we can check reference results. + std::vector a_ref(a); + + // Invoking GER with an invalid value of incx. + ger( STORAGE, CONJ, CONJ, M, N, &alpha, x.data(), invalid_incx, + y.data(), unit_inc, a.data(), LDA ); + + // Computing bitwise difference. + computediff( "A", STORAGE, M, N, a.data(), a_ref.data(), LDA ); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 5 ); +#endif +} + +// incx = 0, with non-unit incy +TYPED_TEST(ger_IIT_ERS, incx_eq_zero_nonUnitStride) +{ + using T = TypeParam; + gtint_t invalid_incx = 0; + gtint_t inc = 3; + // Using a random non-zero value of alpha. + T alpha = T{3}; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + ger( STORAGE, CONJ, CONJ, M, N, nullptr, nullptr, invalid_incx, + nullptr, inc, nullptr, LDA ); +#else + ger( STORAGE, CONJ, CONJ, M, N, &alpha, nullptr, invalid_incx, + nullptr, inc, nullptr, LDA ); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 5 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector a = testinghelpers::get_random_matrix( -2, 5, STORAGE, 'n', M, N, LDA ); + std::vector x = testinghelpers::get_random_vector( -3, 3, M, inc ); + std::vector y = testinghelpers::get_random_vector( -3, 3, N, inc ); + + // Create a copy of a matrix so that we can check reference results. + std::vector a_ref(a); + + // Invoking GER with an invalid value of incx. + ger( STORAGE, CONJ, CONJ, M, N, &alpha, x.data(), invalid_incx, + y.data(), inc, a.data(), LDA ); + + // Computing bitwise difference. + computediff( "A", STORAGE, M, N, a.data(), a_ref.data(), LDA ); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 5 ); +#endif +} + +// incy = 0, with unit incx +TYPED_TEST(ger_IIT_ERS, incy_eq_zero_unitStride) +{ + using T = TypeParam; + gtint_t invalid_incy = 0; + gtint_t unit_inc = 1; + // Using a random non-zero value of alpha. + T alpha = T{3}; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + ger( STORAGE, CONJ, CONJ, M, N, nullptr, nullptr, unit_inc, + nullptr, invalid_incy, nullptr, LDA ); +#else + ger( STORAGE, CONJ, CONJ, M, N, &alpha, nullptr, unit_inc, + nullptr, invalid_incy, nullptr, LDA ); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 7 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector a = testinghelpers::get_random_matrix( -2, 5, STORAGE, 'n', M, N, LDA ); + std::vector x = testinghelpers::get_random_vector( -3, 3, M, unit_inc ); + std::vector y = testinghelpers::get_random_vector( -3, 3, N, unit_inc ); + + // Create a copy of a matrix so that we can check reference results. + std::vector a_ref(a); + + // Invoking GER with an invalid value of incy. + ger( STORAGE, CONJ, CONJ, M, N, &alpha, x.data(), unit_inc, + y.data(), invalid_incy, a.data(), LDA ); + + // Computing bitwise difference. + computediff( "A", STORAGE, M, N, a.data(), a_ref.data(), LDA ); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 7 ); +#endif +} + +// incy = 0, with non-unit incx +TYPED_TEST(ger_IIT_ERS, incy_eq_zero_nonUnitStride) +{ + using T = TypeParam; + gtint_t invalid_incy = 0; + gtint_t inc = 3; + // Using a random non-zero value of alpha. + T alpha = T{3}; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + ger( STORAGE, CONJ, CONJ, M, N, nullptr, nullptr, inc, + nullptr, invalid_incy, nullptr, LDA ); +#else + ger( STORAGE, CONJ, CONJ, M, N, &alpha, nullptr, inc, + nullptr, invalid_incy, nullptr, LDA ); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 7 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector a = testinghelpers::get_random_matrix( -2, 5, STORAGE, 'n', M, N, LDA ); + std::vector x = testinghelpers::get_random_vector( -3, 3, M, inc ); + std::vector y = testinghelpers::get_random_vector( -3, 3, N, inc ); + + // Create a copy of a matrix so that we can check reference results. + std::vector a_ref(a); + + // Invoking GER with an invalid value of incy. + ger( STORAGE, CONJ, CONJ, M, N, &alpha, x.data(), inc, + y.data(), invalid_incy, a.data(), LDA ); + + // Computing bitwise difference. + computediff( "A", STORAGE, M, N, a.data(), a_ref.data(), LDA ); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 7 ); +#endif +} + +// lda < max(1, M), with unit stride +TYPED_TEST(ger_IIT_ERS, lda_lt_max_1_m_unitStride) +{ + using T = TypeParam; + gtint_t invalid_lda = M - 1; + gtint_t unit_inc = 1; + // Using a random non-zero value of alpha. + T alpha = T{3}; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + ger( STORAGE, CONJ, CONJ, M, N, nullptr, nullptr, unit_inc, + nullptr, unit_inc, nullptr, invalid_lda ); +#else + ger( STORAGE, CONJ, CONJ, M, N, &alpha, nullptr, unit_inc, + nullptr, unit_inc, nullptr, invalid_lda ); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 9 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector a = testinghelpers::get_random_matrix( -2, 5, STORAGE, 'n', M, N, LDA ); + std::vector x = testinghelpers::get_random_vector( -3, 3, M, unit_inc ); + std::vector y = testinghelpers::get_random_vector( -3, 3, N, unit_inc ); + + // Create a copy of a matrix so that we can check reference results. + std::vector a_ref(a); + + // Invoking GER with an invalid value of lda. + ger( STORAGE, CONJ, CONJ, M, N, &alpha, x.data(), unit_inc, + y.data(), unit_inc, a.data(), invalid_lda ); + + // Computing bitwise difference. + computediff( "A", STORAGE, M, N, a.data(), a_ref.data(), LDA ); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 9 ); +#endif +} + +// lda < max(1, M), with non-unit stride +TYPED_TEST(ger_IIT_ERS, lda_lt_max_1_m_nonUnitStride) +{ + using T = TypeParam; + gtint_t invalid_lda = LDA - 1; + gtint_t inc = 3; + // Using a random non-zero value of alpha. + T alpha = T{3}; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + ger( STORAGE, CONJ, CONJ, M, N, nullptr, nullptr, inc, + nullptr, inc, nullptr, invalid_lda ); +#else + ger( STORAGE, CONJ, CONJ, M, N, &alpha, nullptr, inc, + nullptr, inc, nullptr, invalid_lda ); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 9 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector a = testinghelpers::get_random_matrix( -2, 5, STORAGE, 'n', M, N, LDA ); + std::vector x = testinghelpers::get_random_vector( -3, 3, M, inc ); + std::vector y = testinghelpers::get_random_vector( -3, 3, N, inc ); + + // Create a copy of a matrix so that we can check reference results. + std::vector a_ref(a); + + // Invoking GER with an invalid value of n. + ger( STORAGE, CONJ, CONJ, M, N, &alpha, x.data(), inc, + y.data(), inc, a.data(), invalid_lda ); + + // Computing bitwise difference. + computediff( "A", STORAGE, M, N, a.data(), a_ref.data(), LDA ); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 9 ); +#endif +} + +/** + * BLAS Early Return Scenarios(ERS): + * + * GER is expected to return early in the following cases: + * 1. m == 0 + * 2. n == 0 + * 3. alpha == 0 + */ +// m == 0, with unit stride +TYPED_TEST(ger_IIT_ERS, m_eq_zero_unitStride) +{ + using T = TypeParam; + gtint_t invalid_m = 0; + gtint_t unit_inc = 1; + // Using a random non-zero value of alpha. + T alpha = T{3}; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + ger( STORAGE, CONJ, CONJ, invalid_m, N, nullptr, nullptr, unit_inc, + nullptr, unit_inc, nullptr, LDA ); +#else + ger( STORAGE, CONJ, CONJ, invalid_m, N, &alpha, nullptr, unit_inc, + nullptr, unit_inc, nullptr, LDA ); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector a = testinghelpers::get_random_matrix( -2, 5, STORAGE, 'n', M, N, LDA ); + std::vector x = testinghelpers::get_random_vector( -3, 3, M, unit_inc ); + std::vector y = testinghelpers::get_random_vector( -3, 3, N, unit_inc ); + + // Create a copy of a matrix so that we can check reference results. + std::vector a_ref(a); + + // Invoking GER with an invalid value of m. + ger( STORAGE, CONJ, CONJ, invalid_m, N, &alpha, x.data(), unit_inc, + y.data(), unit_inc, a.data(), LDA ); + + // Computing bitwise difference. + computediff( "A", STORAGE, M, N, a.data(), a_ref.data(), LDA ); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif +} + +// m == 0, with non-unit stride +TYPED_TEST(ger_IIT_ERS, m_eq_zero_nonUnitStride) +{ + using T = TypeParam; + gtint_t invalid_m = 0; + gtint_t inc = 3; + // Using a random non-zero value of alpha. + T alpha = T{3}; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + ger( STORAGE, CONJ, CONJ, invalid_m, N, nullptr, nullptr, inc, + nullptr, inc, nullptr, LDA ); +#else + ger( STORAGE, CONJ, CONJ, invalid_m, N, &alpha, nullptr, inc, + nullptr, inc, nullptr, LDA ); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector a = testinghelpers::get_random_matrix( -2, 5, STORAGE, 'n', M, N, LDA ); + std::vector x = testinghelpers::get_random_vector( -3, 3, M, inc ); + std::vector y = testinghelpers::get_random_vector( -3, 3, N, inc ); + + // Create a copy of a matrix so that we can check reference results. + std::vector a_ref(a); + + // Invoking GER with an invalid value of m. + ger( STORAGE, CONJ, CONJ, invalid_m, N, &alpha, x.data(), inc, + y.data(), inc, a.data(), LDA ); + + // Computing bitwise difference. + computediff( "A", STORAGE, M, N, a.data(), a_ref.data(), LDA ); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif +} + +// n == 0, with unit stride +TYPED_TEST(ger_IIT_ERS, n_eq_zero_unitStride) +{ + using T = TypeParam; + gtint_t invalid_n = 0; + gtint_t unit_inc = 1; + // Using a random non-zero value of alpha. + T alpha = T{3}; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + ger( STORAGE, CONJ, CONJ, M, invalid_n, nullptr, nullptr, unit_inc, + nullptr, unit_inc, nullptr, LDA ); +#else + ger( STORAGE, CONJ, CONJ, M, invalid_n, &alpha, nullptr, unit_inc, + nullptr, unit_inc, nullptr, LDA ); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector a = testinghelpers::get_random_matrix( -2, 5, STORAGE, 'n', M, N, LDA ); + std::vector x = testinghelpers::get_random_vector( -3, 3, M, unit_inc ); + std::vector y = testinghelpers::get_random_vector( -3, 3, N, unit_inc ); + + // Create a copy of a matrix so that we can check reference results. + std::vector a_ref(a); + + // Invoking GER with an invalid value of n. + ger( STORAGE, CONJ, CONJ, M, invalid_n, &alpha, x.data(), unit_inc, + y.data(), unit_inc, a.data(), LDA ); + + // Computing bitwise difference. + computediff( "A", STORAGE, M, N, a.data(), a_ref.data(), LDA ); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif +} + +// n == 0, with non-unit stride +TYPED_TEST(ger_IIT_ERS, n_eq_zero_nonUnitStride) +{ + using T = TypeParam; + gtint_t invalid_n = 0; + gtint_t inc = 3; + // Using a random non-zero value of alpha. + T alpha = T{3}; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + ger( STORAGE, CONJ, CONJ, M, invalid_n, nullptr, nullptr, inc, + nullptr, inc, nullptr, LDA ); +#else + ger( STORAGE, CONJ, CONJ, M, invalid_n, &alpha, nullptr, inc, + nullptr, inc, nullptr, LDA ); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector a = testinghelpers::get_random_matrix( -2, 5, STORAGE, 'n', M, N, LDA ); + std::vector x = testinghelpers::get_random_vector( -3, 3, M, inc ); + std::vector y = testinghelpers::get_random_vector( -3, 3, N, inc ); + + // Create a copy of a matrix so that we can check reference results. + std::vector a_ref(a); + + // Invoking GER with an invalid value of n. + ger( STORAGE, CONJ, CONJ, M, invalid_n, &alpha, x.data(), inc, + y.data(), inc, a.data(), LDA ); + + // Computing bitwise difference. + computediff( "A", STORAGE, M, N, a.data(), a_ref.data(), LDA ); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif +} + +// alpha == 0, with unit stride +TYPED_TEST(ger_IIT_ERS, alpha_eq_zero_unitStride) +{ + using T = TypeParam; + gtint_t unit_inc = 1; + T zero_alpha = T{0}; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + ger( STORAGE, CONJ, CONJ, M, N, &zero_alpha, nullptr, unit_inc, + nullptr, unit_inc, nullptr, LDA ); +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector a = testinghelpers::get_random_matrix( -2, 5, STORAGE, 'n', M, N, LDA ); + std::vector x = testinghelpers::get_random_vector( -3, 3, M, unit_inc ); + std::vector y = testinghelpers::get_random_vector( -3, 3, N, unit_inc ); + + // Create a copy of a matrix so that we can check reference results. + std::vector a_ref(a); + + // Invoking GER with an invalid value of alpha. + ger( STORAGE, CONJ, CONJ, M, N, &zero_alpha, x.data(), unit_inc, + y.data(), unit_inc, a.data(), LDA ); + + // Computing bitwise difference. + computediff( "A", STORAGE, M, N, a.data(), a_ref.data(), LDA ); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif +} + +// alpha == 0, with non-unit stride +TYPED_TEST(ger_IIT_ERS, alpha_eq_zero_nonUnitStride) +{ + using T = TypeParam; + gtint_t inc = 3; + T zero_alpha = T{0}; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + ger( STORAGE, CONJ, CONJ, M, N, &zero_alpha, nullptr, inc, + nullptr, inc, nullptr, LDA ); +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector a = testinghelpers::get_random_matrix( -2, 5, STORAGE, 'n', M, N, LDA ); + std::vector x = testinghelpers::get_random_vector( -3, 3, M, inc ); + std::vector y = testinghelpers::get_random_vector( -3, 3, N, inc ); + + // Create a copy of a matrix so that we can check reference results. + std::vector a_ref(a); + + // Invoking GER with an invalid value of alpha. + ger( STORAGE, CONJ, CONJ, M, N, &zero_alpha, x.data(), inc, + y.data(), inc, a.data(), LDA ); + + // Computing bitwise difference. + computediff( "A", STORAGE, M, N, a.data(), a_ref.data(), LDA ); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif +} + +#endif diff --git a/gtestsuite/testsuite/level2/ger/cger/cger_evt.cpp b/gtestsuite/testsuite/level2/ger/cger/cger_evt.cpp new file mode 100644 index 0000000000..8a53195e9c --- /dev/null +++ b/gtestsuite/testsuite/level2/ger/cger/cger_evt.cpp @@ -0,0 +1,208 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "level2/ger/test_ger.h" + +using T = scomplex; +using RT = testinghelpers::type_info::real_type; +static RT NaN = std::numeric_limits::quiet_NaN(); +static RT Inf = std::numeric_limits::infinity(); + +class DISABLED_cgerEVT : + public ::testing::TestWithParam> {}; // y_exval + +TEST_P( DISABLED_cgerEVT, API ) +{ + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // matrix storage format(row major, column major) + char storage = std::get<0>(GetParam()); + // denotes whether vector x is n,c + char conjx = std::get<1>(GetParam()); + // denotes whether vector y is n,c + char conjy = std::get<2>(GetParam()); + // matrix size m + gtint_t m = std::get<3>(GetParam()); + // matrix size n + gtint_t n = std::get<4>(GetParam()); + // specifies alpha value + T alpha = std::get<5>(GetParam()); + // stride size for x: + gtint_t incx = std::get<6>(GetParam()); + // stride size for y: + gtint_t incy = std::get<7>(GetParam()); + // lda increment: + // If increment is zero, then the array size matches the matrix size. + // If increment is non-negative, the array size is bigger than the matrix size. + gtint_t lda_inc = std::get<8>(GetParam()); + // ai: + gtint_t ai = std::get<9>(GetParam()); + // aj: + gtint_t aj = std::get<10>(GetParam()); + // a_exval: + T a_exval = std::get<11>(GetParam()); + // xi: + gtint_t xi = std::get<12>(GetParam()); + // x_exval: + T x_exval = std::get<13>(GetParam()); + // yi: + gtint_t yi = std::get<14>(GetParam()); + // y_exval: + T y_exval = std::get<15>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite ger.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (m == 0 || n == 0 || alpha == testinghelpers::ZERO()) + thresh = 0.0; + else + thresh = 7*testinghelpers::getEpsilon(); + + + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_ger( storage, conjx, conjy, m, n, alpha, incx, incy, lda_inc, + ai, aj, a_exval, xi, x_exval, yi, y_exval, thresh ); +} + +INSTANTIATE_TEST_SUITE_P( + unitStride, + DISABLED_cgerEVT, + ::testing::Combine( + // storage scheme: row/col-stored matrix + ::testing::Values( 'c' + // row-stored tests are disabled for BLAS since BLAS only supports col-storage scheme. +#ifndef TEST_BLAS_LIKE + , 'r' +#endif + ), + // conjx: uses n (no_conjugate) since it is real. + ::testing::Values( 'n' ), + // conjy: uses n (no_conjugate) since it is real. + ::testing::Values( 'n' ), + // m + ::testing::Values( gtint_t(55) ), + // n + ::testing::Values( gtint_t(33) ), + // alpha: value of scalar + ::testing::Values( T{1.0, 1.0}, T{2.3, -1.2}, T{NaN, NaN}, T{NaN, Inf}, T{Inf, -Inf} ), + // incx: stride of x vector. + ::testing::Values( gtint_t(1) ), + // incy: stride of y vector. + ::testing::Values( gtint_t(1) ), + // inc_lda: increment to the leading dim of a. + ::testing::Values( gtint_t(0) ), + // ai: index of extreme value for a. + ::testing::Values( gtint_t(0), gtint_t(7) ), + // aj: index of extreme value for a. + ::testing::Values( gtint_t(0), gtint_t(7) ), + // a_exval: extreme value for a. + ::testing::Values( T{0.0, 0.0}, T{NaN, NaN}, T{NaN, Inf}, T{Inf, -Inf} ), + // xi: index of extreme value for x. + ::testing::Values( gtint_t(0), gtint_t(7) ), + // x_exval: extreme value for x. + ::testing::Values( T{0.0, 0.0}, T{NaN, NaN}, T{NaN, Inf}, T{Inf, -Inf} ), + // yi: index of extreme value for y. + ::testing::Values( gtint_t(0), gtint_t(7) ), + // y_exval: extreme value for y. + ::testing::Values( T{0.0, 0.0}, T{NaN, NaN}, T{NaN, Inf}, T{Inf, -Inf} ) + ), + ::gerEVTPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + nonUnitStrides, + DISABLED_cgerEVT, + ::testing::Combine( + // storage scheme: row/col-stored matrix + ::testing::Values( 'c' + // row-stored tests are disabled for BLAS since BLAS only supports col-storage scheme. +#ifndef TEST_BLAS_LIKE + , 'r' +#endif + ), + // conjx: uses n (no_conjugate) since it is real. + ::testing::Values( 'n' ), + // conjy: uses n (no_conjugate) since it is real. + ::testing::Values( 'n' ), + // m + ::testing::Values( gtint_t(55) ), + // n + ::testing::Values( gtint_t(33) ), + // alpha: value of scalar + ::testing::Values( T{1.0, 1.0}, T{2.3, -1.2}, T{NaN, NaN}, T{NaN, Inf}, T{Inf, -Inf} ), + // incx: stride of x vector. + ::testing::Values( gtint_t(3) ), + // incy: stride of y vector. + ::testing::Values( gtint_t(5) ), + // inc_lda: increment to the leading dim of a. + ::testing::Values( gtint_t(7) ), + // ai: index of extreme value for a. + ::testing::Values( gtint_t(0), gtint_t(7) ), + // aj: index of extreme value for a. + ::testing::Values( gtint_t(0), gtint_t(7) ), + // a_exval: extreme value for a. + ::testing::Values( T{0.0, 0.0}, T{NaN, NaN}, T{NaN, Inf}, T{Inf, -Inf} ), + // xi: index of extreme value for x. + ::testing::Values( gtint_t(0), gtint_t(7) ), + // x_exval: extreme value for x. + ::testing::Values( T{0.0, 0.0}, T{NaN, NaN}, T{NaN, Inf}, T{Inf, -Inf} ), + // yi: index of extreme value for y. + ::testing::Values( gtint_t(0), gtint_t(7) ), + // y_exval: extreme value for y. + ::testing::Values( T{0.0, 0.0}, T{NaN, NaN}, T{NaN, Inf}, T{Inf, -Inf} ) + ), + ::gerEVTPrint() + ); diff --git a/gtestsuite/testsuite/level2/ger/cger/cger_generic.cpp b/gtestsuite/testsuite/level2/ger/cger/cger_generic.cpp new file mode 100644 index 0000000000..70579d1b88 --- /dev/null +++ b/gtestsuite/testsuite/level2/ger/cger/cger_generic.cpp @@ -0,0 +1,320 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "level2/ger/test_ger.h" + +class cgerGeneric : + public ::testing::TestWithParam> {}; + +TEST_P( cgerGeneric, API ) +{ + using T = scomplex; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // matrix storage format(row major, column major) + char storage = std::get<0>(GetParam()); + // denotes whether vector x is n,c + char conjx = std::get<1>(GetParam()); + // denotes whether vector y is n,c + char conjy = std::get<2>(GetParam()); + // matrix size m + gtint_t m = std::get<3>(GetParam()); + // matrix size n + gtint_t n = std::get<4>(GetParam()); + // specifies alpha value + T alpha = std::get<5>(GetParam()); + // stride size for x: + gtint_t incx = std::get<6>(GetParam()); + // stride size for y: + gtint_t incy = std::get<7>(GetParam()); + // lda increment. + // If increment is zero, then the array size matches the matrix size. + // If increment is non-negative, the array size is bigger than the matrix size. + gtint_t lda_inc = std::get<8>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite ger.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // With adjustment for complex data. + double thresh; + double adj = 3.0; + if (m == 0 || n == 0 || alpha == testinghelpers::ZERO()) + thresh = 0.0; + else + thresh = adj*3*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_ger( storage, conjx, conjy, m, n, alpha, incx, incy, lda_inc, thresh ); +} + +INSTANTIATE_TEST_SUITE_P( + unitPositiveIncrement, + cgerGeneric, + ::testing::Combine( + // storage scheme: row/col-stored matrix + ::testing::Values( 'c' + // row-stored tests are disabled for BLAS since BLAS only supports col-storage scheme. +#ifndef TEST_BLAS_LIKE + , 'r' +#endif + ), + // conjx: use n for no_conjugate and c for conjugate. + ::testing::Values( 'n', 'c' ), + // conjy: use n for no_conjugate and c for conjugate. + ::testing::Values( 'n', 'c' ), + // m + ::testing::Range( gtint_t(10), gtint_t(101), 10 ), + // n + ::testing::Range( gtint_t(10), gtint_t(101), 10 ), + // alpha: value of scalar + ::testing::Values( scomplex{-1.0, 4.0}, scomplex{1.0, 1.0}, scomplex{3.0, -2.0} ), + // incx: stride of x vector. + ::testing::Values( gtint_t(1) ), + // incy: stride of y vector. + ::testing::Values( gtint_t(1) ), + // inc_lda: increment to the leading dim of a + ::testing::Values( gtint_t(0), gtint_t(3) ) + ), + ::gerGenericPrint() + ); + +#ifdef TEST_BLIS_TYPED +// Test when conjugate of x is used as an argument. This option is BLIS-api specific. +// Only test very few cases as sanity check since conj(x) = x for real types. +// We can modify the values using implementantion details. +INSTANTIATE_TEST_SUITE_P( + conjXY, + cgerGeneric, + ::testing::Combine( + // storage scheme: row/col-stored matrix + ::testing::Values( 'c' + // row-stored tests are disabled for BLAS since BLAS only supports col-storage scheme. +#ifndef TEST_BLAS_LIKE + , 'r' +#endif + ), + // conjx: use n for no_conjugate and c for conjugate. + ::testing::Values( 'n', 'c' ), + // conjy: use n for no_conjugate and c for conjugate. + ::testing::Values( 'n', 'c' ), + // m + ::testing::Values( gtint_t(3), gtint_t(30), gtint_t(112) ), + // n + ::testing::Values( gtint_t(3), gtint_t(30), gtint_t(112) ), + // alpha: value of scalar + ::testing::Values( scomplex{-1.0, 4.0}, scomplex{1.0, 1.0}, scomplex{3.0, -2.0} ), + // incx: stride of x vector. + ::testing::Values( gtint_t(1) ), + // incy: stride of y vector. + ::testing::Values( gtint_t(1) ), + // inc_lda: increment to the leading dim of a + ::testing::Values( gtint_t(0), gtint_t(3) ) + ), + ::gerGenericPrint() + ); +#endif + +INSTANTIATE_TEST_SUITE_P( + nonUnitPositiveIncrements, + cgerGeneric, + ::testing::Combine( + // storage scheme: row/col-stored matrix + ::testing::Values( 'c' + // row-stored tests are disabled for BLAS since BLAS only supports col-storage scheme. +#ifndef TEST_BLAS_LIKE + , 'r' +#endif + ), + // conjx: use n for no_conjugate and c for conjugate. + ::testing::Values( 'n', 'c' ), + // conjy: use n for no_conjugate and c for conjugate. + ::testing::Values( 'n', 'c' ), + // m + ::testing::Values( gtint_t(3), gtint_t(30), gtint_t(112) ), + // n + ::testing::Values( gtint_t(3), gtint_t(30), gtint_t(112) ), + // alpha: value of scalar + ::testing::Values( scomplex{-1.0, 4.0}, scomplex{1.0, 1.0}, scomplex{3.0, -2.0} ), + // incx: stride of x vector. + ::testing::Values( gtint_t(2) ), + // incy: stride of y vector. + ::testing::Values( gtint_t(3) ), + // inc_lda: increment to the leading dim of a + ::testing::Values( gtint_t(0), gtint_t(3) ) + ), + ::gerGenericPrint() + ); + +// @note negativeIncrement tests are resulting in Segmentation Faults when +// BLIS_TYPED interface is being tested. +#ifndef TEST_BLIS_TYPED +INSTANTIATE_TEST_SUITE_P( + negativeIncrements, + cgerGeneric, + ::testing::Combine( + // storage scheme: row/col-stored matrix + ::testing::Values( 'c' + // row-stored tests are disabled for BLAS since BLAS only supports col-storage scheme. +#ifndef TEST_BLAS_LIKE + , 'r' +#endif + ), + // conjx: use n for no_conjugate and c for conjugate. + ::testing::Values( 'n', 'c' ), + // conjy: use n for no_conjugate and c for conjugate. + ::testing::Values( 'n', 'c' ), + // m + ::testing::Values( gtint_t(3), gtint_t(30), gtint_t(112) ), + // n + ::testing::Values( gtint_t(3), gtint_t(30), gtint_t(112) ), + // alpha: value of scalar + ::testing::Values( scomplex{-1.0, 4.0}, scomplex{1.0, 1.0}, scomplex{3.0, -2.0} ), + // incx: stride of x vector. + ::testing::Values( gtint_t(-2) ), + // incy: stride of y vector. + ::testing::Values( gtint_t(-3) ), + // inc_lda: increment to the leading dim of a + ::testing::Values( gtint_t(0), gtint_t(3) ) + ), + ::gerGenericPrint() + ); +#endif + +INSTANTIATE_TEST_SUITE_P( + scalarCombinations, + cgerGeneric, + ::testing::Combine( + // storage scheme: row/col-stored matrix + ::testing::Values( 'c' + // row-stored tests are disabled for BLAS since BLAS only supports col-storage scheme. +#ifndef TEST_BLAS_LIKE + , 'r' +#endif + ), + // conjx: use n for no_conjugate and c for conjugate. + ::testing::Values( 'c' ), + // conjy: use n for no_conjugate and c for conjugate. + ::testing::Values( 'c' ), + // m + ::testing::Values( gtint_t(35) ), + // n + ::testing::Values( gtint_t(40) ), + // alpha: value of scalar + ::testing::Values( scomplex{-100.0, 200.0}, scomplex{200.0, 100.0}, scomplex{-175.0, -143.0},scomplex{187.0, -275.0} ), + // incx: stride of x vector. + ::testing::Values( gtint_t(2) ), + // incy: stride of y vector. + ::testing::Values( gtint_t(3) ), + // inc_lda: increment to the leading dim of a + ::testing::Values( gtint_t(0), gtint_t(3) ) + ), + ::gerGenericPrint() + ); +//large values of m and n +INSTANTIATE_TEST_SUITE_P( + largeSize, + cgerGeneric, + ::testing::Combine( + // storage scheme: row/col-stored matrix + ::testing::Values( 'c' + // row-stored tests are disabled for BLAS since BLAS only supports col-storage scheme. +#ifndef TEST_BLAS_LIKE + , 'r' +#endif + ), + // conjx: use n for no_conjugate and c for conjugate. + ::testing::Values( 'c' ), + // conjy: use n for no_conjugate and c for conjugate. + ::testing::Values( 'c' ), + // m + ::testing::Values( gtint_t(3500) ), + // n + ::testing::Values( gtint_t(4000) ), + // alpha: value of scalar + ::testing::Values( scomplex{-10.0, 8.0} ), + // incx: stride of x vector. + ::testing::Values( gtint_t(2), gtint_t(1) ), + // incy: stride of y vector. + ::testing::Values( gtint_t(3), gtint_t(1) ), + // inc_lda: increment to the leading dim of a + ::testing::Values( gtint_t(0), gtint_t(3) ) + ), + ::gerGenericPrint() + ); +//Stride greater than m and n +INSTANTIATE_TEST_SUITE_P( + strideGreaterThanSize, + cgerGeneric, + ::testing::Combine( + // storage scheme: row/col-stored matrix + ::testing::Values( 'c' + // row-stored tests are disabled for BLAS since BLAS only supports col-storage scheme. +#ifndef TEST_BLAS_LIKE + , 'r' +#endif + ), + // conjx: use n for no_conjugate and c for conjugate. + ::testing::Values( 'c' ), + // conjy: use n for no_conjugate and c for conjugate. + ::testing::Values( 'c' ), + // m + ::testing::Values( gtint_t(3) ), + // n + ::testing::Values( gtint_t(4) ), + // alpha: value of scalar + ::testing::Values( scomplex{-10.0, 8.0} ), + // incx: stride of x vector. + ::testing::Values( gtint_t(15) ), + // incy: stride of y vector. + ::testing::Values( gtint_t(18) ), + // inc_lda: increment to the leading dim of a + ::testing::Values( gtint_t(20) ) + ), + ::gerGenericPrint() + ); + diff --git a/gtestsuite/testsuite/level2/ger/cger_generic.cpp b/gtestsuite/testsuite/level2/ger/cger_generic.cpp deleted file mode 100644 index 024ac6d4da..0000000000 --- a/gtestsuite/testsuite/level2/ger/cger_generic.cpp +++ /dev/null @@ -1,142 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include "test_ger.h" - -class cgerTest : - public ::testing::TestWithParam> {}; - -TEST_P(cgerTest, RandomData) -{ - using T = scomplex; - //---------------------------------------------------------- - // Initialize values from the parameters passed through - // test suite instantiation (INSTANTIATE_TEST_SUITE_P). - //---------------------------------------------------------- - // matrix storage format(row major, column major) - char storage = std::get<0>(GetParam()); - // denotes whether vector x is n,c - char conjx = std::get<1>(GetParam()); - // denotes whether vector y is n,c - char conjy = std::get<2>(GetParam()); - // matrix size m - gtint_t m = std::get<3>(GetParam()); - // matrix size n - gtint_t n = std::get<4>(GetParam()); - // specifies alpha value - T alpha = std::get<5>(GetParam()); - // stride size for x: - gtint_t incx = std::get<6>(GetParam()); - // stride size for y: - gtint_t incy = std::get<7>(GetParam()); - // lda increment. - // If increment is zero, then the array size matches the matrix size. - // If increment are nonnegative, the array size is bigger than the matrix size. - gtint_t lda_inc = std::get<8>(GetParam()); - - // Set the threshold for the errors: - double thresh = 2*(std::max)(m,n)*testinghelpers::getEpsilon(); - - //---------------------------------------------------------- - // Call test body using these parameters - //---------------------------------------------------------- - test_ger( storage, conjx, conjy, m, n, alpha, incx, incy, lda_inc, thresh ); -} - -class cgerTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char conjx = std::get<1>(str.param); - char conjy = std::get<2>(str.param); - gtint_t m = std::get<3>(str.param); - gtint_t n = std::get<4>(str.param); - scomplex alpha = std::get<5>(str.param); - gtint_t incx = std::get<6>(str.param); - gtint_t incy = std::get<7>(str.param); - gtint_t ld_inc = std::get<8>(str.param); -#ifdef TEST_BLAS - std::string str_name = "cger_"; -#elif TEST_CBLAS - std::string str_name = "cblas_cger"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_cger"; -#endif - str_name = str_name + "_" + sfm; - str_name = str_name + "_" + conjx+conjy; - str_name = str_name + "_" + std::to_string(m); - str_name = str_name + "_" + std::to_string(n); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); - str_name = str_name + "_" + incx_str; - str_name = str_name + "_" + incy_str; - std::string alpha_str = ( alpha.real > 0) ? std::to_string(int(alpha.real)) : ("m" + std::to_string(int(std::abs(alpha.real)))); - alpha_str = alpha_str + "pi" + (( alpha.imag > 0) ? std::to_string(int(alpha.imag)) : ("m" + std::to_string(int(std::abs(alpha.imag))))); - str_name = str_name + "_a" + alpha_str; - str_name = str_name + "_" + std::to_string(ld_inc); - return str_name; - } -}; - -// Black box testing. -INSTANTIATE_TEST_SUITE_P( - Blackbox, - cgerTest, - ::testing::Combine( - ::testing::Values('c' -#ifndef TEST_BLAS - ,'r' -#endif - ), // storage format - ::testing::Values('n'), // conjx - ::testing::Values('n','c'), // conjy - ::testing::Range(gtint_t(10), gtint_t(31), 10), // m - ::testing::Range(gtint_t(10), gtint_t(31), 10), // n - ::testing::Values(scomplex{1.0, -2.0}), // alpha - ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(gtint_t(0), gtint_t(2)) // increment to the leading dim of a - ), - ::cgerTestPrint() - ); diff --git a/gtestsuite/testsuite/level2/ger/dger/dger_evt.cpp b/gtestsuite/testsuite/level2/ger/dger/dger_evt.cpp new file mode 100644 index 0000000000..1e60ee43e1 --- /dev/null +++ b/gtestsuite/testsuite/level2/ger/dger/dger_evt.cpp @@ -0,0 +1,207 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "level2/ger/test_ger.h" + +using T = double; +static T NaN = std::numeric_limits::quiet_NaN(); +static T Inf = std::numeric_limits::infinity(); + +class DISABLED_dgerEVT : + public ::testing::TestWithParam> {}; // y_exval + +TEST_P( DISABLED_dgerEVT, API ) +{ + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // matrix storage format(row major, column major) + char storage = std::get<0>(GetParam()); + // denotes whether vector x is n,c + char conjx = std::get<1>(GetParam()); + // denotes whether vector y is n,c + char conjy = std::get<2>(GetParam()); + // matrix size m + gtint_t m = std::get<3>(GetParam()); + // matrix size n + gtint_t n = std::get<4>(GetParam()); + // specifies alpha value + T alpha = std::get<5>(GetParam()); + // stride size for x: + gtint_t incx = std::get<6>(GetParam()); + // stride size for y: + gtint_t incy = std::get<7>(GetParam()); + // lda increment: + // If increment is zero, then the array size matches the matrix size. + // If increment is non-negative, the array size is bigger than the matrix size. + gtint_t lda_inc = std::get<8>(GetParam()); + // ai: + gtint_t ai = std::get<9>(GetParam()); + // aj: + gtint_t aj = std::get<10>(GetParam()); + // a_exval: + T a_exval = std::get<11>(GetParam()); + // xi: + gtint_t xi = std::get<12>(GetParam()); + // x_exval: + T x_exval = std::get<13>(GetParam()); + // yi: + gtint_t yi = std::get<14>(GetParam()); + // y_exval: + T y_exval = std::get<15>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite ger.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (m == 0 || n == 0 || alpha == testinghelpers::ZERO()) + thresh = 0.0; + else + thresh = 3*testinghelpers::getEpsilon(); + + + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_ger( storage, conjx, conjy, m, n, alpha, incx, incy, lda_inc, + ai, aj, a_exval, xi, x_exval, yi, y_exval, thresh ); +} + +INSTANTIATE_TEST_SUITE_P( + unitStride, + DISABLED_dgerEVT, + ::testing::Combine( + // storage scheme: row/col-stored matrix + ::testing::Values( 'c' + // row-stored tests are disabled for BLAS since BLAS only supports col-storage scheme. +#ifndef TEST_BLAS_LIKE + , 'r' +#endif + ), + // conjx: uses n (no_conjugate) since it is real. + ::testing::Values( 'n' ), + // conjy: uses n (no_conjugate) since it is real. + ::testing::Values( 'n' ), + // m + ::testing::Values( gtint_t(55) ), + // n + ::testing::Values( gtint_t(33) ), + // alpha: value of scalar + ::testing::Values( T{1.0}, T{2.3}, NaN, Inf, -Inf ), + // incx: stride of x vector. + ::testing::Values( gtint_t(1) ), + // incy: stride of y vector. + ::testing::Values( gtint_t(1) ), + // inc_lda: increment to the leading dim of a. + ::testing::Values( gtint_t(0) ), + // ai: index of extreme value for a. + ::testing::Values( gtint_t(0), gtint_t(7) ), + // aj: index of extreme value for a. + ::testing::Values( gtint_t(0), gtint_t(7) ), + // a_exval: extreme value for a. + ::testing::Values( T{0.0}, NaN, Inf, -Inf ), + // xi: index of extreme value for x. + ::testing::Values( gtint_t(0), gtint_t(7) ), + // x_exval: extreme value for x. + ::testing::Values( T{0.0}, NaN, Inf, -Inf ), + // yi: index of extreme value for y. + ::testing::Values( gtint_t(0), gtint_t(7) ), + // y_exval: extreme value for y. + ::testing::Values( T{0.0}, NaN, Inf, -Inf ) + ), + ::gerEVTPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + nonUnitStride, + DISABLED_dgerEVT, + ::testing::Combine( + // storage scheme: row/col-stored matrix + ::testing::Values( 'c' + // row-stored tests are disabled for BLAS since BLAS only supports col-storage scheme. +#ifndef TEST_BLAS_LIKE + , 'r' +#endif + ), + // conjx: uses n (no_conjugate) since it is real. + ::testing::Values( 'n' ), + // conjy: uses n (no_conjugate) since it is real. + ::testing::Values( 'n' ), + // m + ::testing::Values( gtint_t(55) ), + // n + ::testing::Values( gtint_t(33) ), + // alpha: value of scalar + ::testing::Values( T{1.0}, T{2.3}, NaN, Inf, -Inf ), + // incx: stride of x vector. + ::testing::Values( gtint_t(3) ), + // incy: stride of y vector. + ::testing::Values( gtint_t(5) ), + // inc_lda: increment to the leading dim of a. + ::testing::Values( gtint_t(7) ), + // ai: index of extreme value for a. + ::testing::Values( gtint_t(0), gtint_t(7) ), + // aj: index of extreme value for a. + ::testing::Values( gtint_t(0), gtint_t(7) ), + // a_exval: extreme value for a. + ::testing::Values( T{0.0}, NaN, Inf, -Inf ), + // xi: index of extreme value for x. + ::testing::Values( gtint_t(0), gtint_t(7) ), + // x_exval: extreme value for x. + ::testing::Values( T{0.0}, NaN, Inf, -Inf ), + // yi: index of extreme value for y. + ::testing::Values( gtint_t(0), gtint_t(7) ), + // y_exval: extreme value for y. + ::testing::Values( T{0.0}, NaN, Inf, -Inf ) + ), + ::gerEVTPrint() + ); diff --git a/gtestsuite/testsuite/level2/ger/dger/dger_generic.cpp b/gtestsuite/testsuite/level2/ger/dger/dger_generic.cpp new file mode 100644 index 0000000000..514a1fd905 --- /dev/null +++ b/gtestsuite/testsuite/level2/ger/dger/dger_generic.cpp @@ -0,0 +1,312 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "level2/ger/test_ger.h" + +class dgerGeneric : + public ::testing::TestWithParam> {}; + +TEST_P( dgerGeneric, API ) +{ + using T = double; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // matrix storage format(row major, column major) + char storage = std::get<0>(GetParam()); + // denotes whether vector x is n,c + char conjx = std::get<1>(GetParam()); + // denotes whether vector y is n,c + char conjy = std::get<2>(GetParam()); + // matrix size m + gtint_t m = std::get<3>(GetParam()); + // matrix size n + gtint_t n = std::get<4>(GetParam()); + // specifies alpha value + T alpha = std::get<5>(GetParam()); + // stride size for x: + gtint_t incx = std::get<6>(GetParam()); + // stride size for y: + gtint_t incy = std::get<7>(GetParam()); + // lda increment. + // If increment is zero, then the array size matches the matrix size. + // If increment is non-negative, the array size is bigger than the matrix size. + gtint_t lda_inc = std::get<8>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite ger.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (m == 0 || n == 0 || alpha == testinghelpers::ZERO()) + thresh = 0.0; + else + thresh = 3*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_ger( storage, conjx, conjy, m, n, alpha, incx, incy, lda_inc, thresh ); +} + +INSTANTIATE_TEST_SUITE_P( + unitPositiveIncrement, + dgerGeneric, + ::testing::Combine( + // storage scheme: row/col-stored matrix + ::testing::Values( 'c' + // row-stored tests are disabled for BLAS since BLAS only supports col-storage scheme. +#ifndef TEST_BLAS_LIKE + , 'r' +#endif + ), + // conjx: uses n (no_conjugate) since it is real. + ::testing::Values( 'n' ), + // conjy: uses n (no_conjugate) since it is real. + ::testing::Values( 'n' ), + // m + ::testing::Range( gtint_t(10), gtint_t(101), 10 ), + // n + ::testing::Range( gtint_t(10), gtint_t(101), 10 ), + // alpha: value of scalar + ::testing::Values( double(-4.1), double(1.0), double(2.3) ), + // incx: stride of x vector. + ::testing::Values( gtint_t(1) ), + // incy: stride of y vector. + ::testing::Values( gtint_t(1) ), + // inc_lda: increment to the leading dim of a + ::testing::Values( gtint_t(0), gtint_t(3) ) + ), + ::gerGenericPrint() + ); + +#ifdef TEST_BLIS_TYPED +// Test when conjugate of x is used as an argument. This option is BLIS-api specific. +// Only test very few cases as sanity check since conj(x) = x for real types. +// We can modify the values using implementantion details. +INSTANTIATE_TEST_SUITE_P( + conjXY, + dgerGeneric, + ::testing::Combine( + // storage scheme: row/col-stored matrix + ::testing::Values( 'c', 'r' ), + // conjx: uses n (no_conjugate) since it is real. + ::testing::Values( 'c' ), + // conjy: uses n (no_conjugate) since it is real. + ::testing::Values( 'c' ), + // m + ::testing::Values( gtint_t(3), gtint_t(30), gtint_t(112) ), + // n + ::testing::Values( gtint_t(3), gtint_t(30), gtint_t(112) ), + // alpha: value of scalar + ::testing::Values( double(-4.1), double(1.0), double(2.3) ), + // incx: stride of x vector. + ::testing::Values( gtint_t(1) ), + // incy: stride of y vector. + ::testing::Values( gtint_t(1) ), + // inc_lda: increment to the leading dim of a + ::testing::Values( gtint_t(0), gtint_t(3) ) + ), + ::gerGenericPrint() + ); +#endif + +INSTANTIATE_TEST_SUITE_P( + nonUnitPositiveIncrements, + dgerGeneric, + ::testing::Combine( + // storage scheme: row/col-stored matrix + ::testing::Values( 'c' + // row-stored tests are disabled for BLAS since BLAS only supports col-storage scheme. +#ifndef TEST_BLAS_LIKE + , 'r' +#endif + ), + // conjx: uses n (no_conjugate) since it is real. + ::testing::Values( 'n' ), + // conjy: uses n (no_conjugate) since it is real. + ::testing::Values( 'n' ), + // m + ::testing::Values( gtint_t(3), gtint_t(30), gtint_t(112) ), + // n + ::testing::Values( gtint_t(3), gtint_t(30), gtint_t(112) ), + // alpha: value of scalar + ::testing::Values( double(-4.1), double(1.0), double(2.3) ), + // incx: stride of x vector. + ::testing::Values( gtint_t(2) ), + // incy: stride of y vector. + ::testing::Values( gtint_t(3) ), + // inc_lda: increment to the leading dim of a + ::testing::Values( gtint_t(0), gtint_t(3) ) + ), + ::gerGenericPrint() + ); + +// @note negativeIncrement tests are resulting in Segmentation Faults when +// BLIS_TYPED interface is being tested. +#ifndef TEST_BLIS_TYPED +INSTANTIATE_TEST_SUITE_P( + negativeIncrements, + dgerGeneric, + ::testing::Combine( + // storage scheme: row/col-stored matrix + ::testing::Values( 'c' + // row-stored tests are disabled for BLAS since BLAS only supports col-storage scheme. +#ifndef TEST_BLAS_LIKE + , 'r' +#endif + ), + // conjx: uses n (no_conjugate) since it is real. + ::testing::Values( 'n' ), + // conjy: uses n (no_conjugate) since it is real. + ::testing::Values( 'n' ), + // m + ::testing::Values( gtint_t(3), gtint_t(30), gtint_t(112) ), + // n + ::testing::Values( gtint_t(3), gtint_t(30), gtint_t(112) ), + // alpha: value of scalar + ::testing::Values( double(-4.1), double(1.0), double(2.3) ), + // incx: stride of x vector. + ::testing::Values( gtint_t(-2) ), + // incy: stride of y vector. + ::testing::Values( gtint_t(-3) ), + // inc_lda: increment to the leading dim of a + ::testing::Values( gtint_t(0), gtint_t(3) ) + ), + ::gerGenericPrint() + ); +#endif + +INSTANTIATE_TEST_SUITE_P( + scalarCombinations, + dgerGeneric, + ::testing::Combine( + // storage scheme: row/col-stored matrix + ::testing::Values( 'c' + // row-stored tests are disabled for BLAS since BLAS only supports col-storage scheme. +#ifndef TEST_BLAS_LIKE + , 'r' +#endif + ), + // conjx: uses n (no_conjugate) since it is real. + ::testing::Values( 'n' ), + // conjy: uses n (no_conjugate) since it is real. + ::testing::Values( 'n' ), + // m + ::testing::Values( gtint_t(3) ), + // n + ::testing::Values( gtint_t(3) ), + // alpha: value of scalar + ::testing::Values( double(-500.1), double(1000.0), double(48.3) ), + // incx: stride of x vector. + ::testing::Values( gtint_t(3) ), + // incy: stride of y vector. + ::testing::Values( gtint_t(1) ), + // inc_lda: increment to the leading dim of a + ::testing::Values( gtint_t(0), gtint_t(3) ) + ), + ::gerGenericPrint() + ); +//large size for m and n +INSTANTIATE_TEST_SUITE_P( + largeSize, + dgerGeneric, + ::testing::Combine( + // storage scheme: row/col-stored matrix + ::testing::Values( 'c' + // row-stored tests are disabled for BLAS since BLAS only supports col-storage scheme. +#ifndef TEST_BLAS_LIKE + , 'r' +#endif + ), + // conjx: uses n (no_conjugate) since it is real. + ::testing::Values( 'n' ), + // conjy: uses n (no_conjugate) since it is real. + ::testing::Values( 'n' ), + // m + ::testing::Values( gtint_t(3000) ), + // n + ::testing::Values( gtint_t(2500) ), + // alpha: value of scalar + ::testing::Values( double(5.1) ), + // incx: stride of x vector. + ::testing::Values( gtint_t(3),gtint_t(1) ), + // incy: stride of y vector. + ::testing::Values( gtint_t(4),gtint_t(1) ), + // inc_lda: increment to the leading dim of a + ::testing::Values( gtint_t(0), gtint_t(3) ) + ), + ::gerGenericPrint() + ); +//incx and incy are greater than m and n. +INSTANTIATE_TEST_SUITE_P( + strideGreaterThanSize, + dgerGeneric, + ::testing::Combine( + // storage scheme: row/col-stored matrix + ::testing::Values( 'c' + // row-stored tests are disabled for BLAS since BLAS only supports col-storage scheme. +#ifndef TEST_BLAS_LIKE + , 'r' +#endif + ), + // conjx: uses n (no_conjugate) since it is real. + ::testing::Values( 'n' ), + // conjy: uses n (no_conjugate) since it is real. + ::testing::Values( 'n' ), + // m + ::testing::Values( gtint_t(3) ), + // n + ::testing::Values( gtint_t(2) ), + // alpha: value of scalar + ::testing::Values( double(5.1) ), + // incx: stride of x vector. + ::testing::Values( gtint_t(10) ), + // incy: stride of y vector. + ::testing::Values( gtint_t(15) ), + // inc_lda: increment to the leading dim of a + ::testing::Values( gtint_t(7) ) + ), + ::gerGenericPrint() + ); diff --git a/gtestsuite/testsuite/level2/ger/dger_generic.cpp b/gtestsuite/testsuite/level2/ger/dger_generic.cpp deleted file mode 100644 index 1fd5efa4f2..0000000000 --- a/gtestsuite/testsuite/level2/ger/dger_generic.cpp +++ /dev/null @@ -1,141 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include "test_ger.h" - -class dgerTest : - public ::testing::TestWithParam> {}; - -TEST_P(dgerTest, RandomData) -{ - using T = double; - //---------------------------------------------------------- - // Initialize values from the parameters passed through - // test suite instantiation (INSTANTIATE_TEST_SUITE_P). - //---------------------------------------------------------- - // matrix storage format(row major, column major) - char storage = std::get<0>(GetParam()); - // denotes whether vector x is n,c - char conjx = std::get<1>(GetParam()); - // denotes whether vector y is n,c - char conjy = std::get<2>(GetParam()); - // matrix size m - gtint_t m = std::get<3>(GetParam()); - // matrix size n - gtint_t n = std::get<4>(GetParam()); - // specifies alpha value - T alpha = std::get<5>(GetParam()); - // stride size for x: - gtint_t incx = std::get<6>(GetParam()); - // stride size for y: - gtint_t incy = std::get<7>(GetParam()); - // lda increment. - // If increment is zero, then the array size matches the matrix size. - // If increment are nonnegative, the array size is bigger than the matrix size. - gtint_t lda_inc = std::get<8>(GetParam()); - - // Set the threshold for the errors: - double thresh = 2*(std::max)(m,n)*testinghelpers::getEpsilon(); - - //---------------------------------------------------------- - // Call test body using these parameters - //---------------------------------------------------------- - test_ger( storage, conjx, conjy, m, n, alpha, incx, incy, lda_inc, thresh ); -} - -class dgerTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char conjx = std::get<1>(str.param); - char conjy = std::get<2>(str.param); - gtint_t m = std::get<3>(str.param); - gtint_t n = std::get<4>(str.param); - double alpha = std::get<5>(str.param); - gtint_t incx = std::get<6>(str.param); - gtint_t incy = std::get<7>(str.param); - gtint_t ld_inc = std::get<8>(str.param); -#ifdef TEST_BLAS - std::string str_name = "dger_"; -#elif TEST_CBLAS - std::string str_name = "cblas_dger"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_dger"; -#endif - str_name = str_name + "_" + sfm; - str_name = str_name + "_" + conjx+conjy; - str_name = str_name + "_" + std::to_string(m); - str_name = str_name + "_" + std::to_string(n); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); - str_name = str_name + "_" + incx_str; - str_name = str_name + "_" + incy_str; - std::string alpha_str = ( alpha > 0) ? std::to_string(int(alpha)) : "m" + std::to_string(int(std::abs(alpha))); - str_name = str_name + "_a" + alpha_str; - str_name = str_name + "_" + std::to_string(ld_inc); - return str_name; - } -}; - -// Black box testing. -INSTANTIATE_TEST_SUITE_P( - Blackbox, - dgerTest, - ::testing::Combine( - ::testing::Values('c' -#ifndef TEST_BLAS - ,'r' -#endif - ), // storage format - ::testing::Values('n'), // conjx - ::testing::Values('n'), // conjy - ::testing::Range(gtint_t(10), gtint_t(31), 10), // m - ::testing::Range(gtint_t(10), gtint_t(31), 10), // n - ::testing::Values( 1.0 ), // alpha - ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(gtint_t(0), gtint_t(2)) // increment to the leading dim of a - ), - ::dgerTestPrint() - ); diff --git a/gtestsuite/testsuite/level2/ger/ger.h b/gtestsuite/testsuite/level2/ger/ger.h index c6747f6c7a..63348a03b0 100644 --- a/gtestsuite/testsuite/level2/ger/ger.h +++ b/gtestsuite/testsuite/level2/ger/ger.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -36,6 +36,7 @@ #include "blis.h" #include "common/testing_helpers.h" +#include "inc/check_error.h" /** * @brief Performs the operation: @@ -80,6 +81,30 @@ static void ger_( char conjy, gtint_t m, gtint_t n, T* alpha, throw std::runtime_error("Error in testsuite/level2/ger.h: Invalid typename in ger_()."); } +template +static void ger_blis_impl( char conjy, gtint_t m, gtint_t n, T* alpha, + T* xp, gtint_t incx, T* yp, gtint_t incy, T* ap, gtint_t lda ) +{ + if constexpr (std::is_same::value) + sger_blis_impl( &m, &n, alpha, xp, &incx, yp, &incy, ap, &lda ); + else if constexpr (std::is_same::value) + dger_blis_impl( &m, &n, alpha, xp, &incx, yp, &incy, ap, &lda ); + else if constexpr (std::is_same::value) { + if( testinghelpers::chkconj( conjy ) ) + cgerc_blis_impl( &m, &n, alpha, xp, &incx, yp, &incy, ap, &lda ); + else + cgeru_blis_impl( &m, &n, alpha, xp, &incx, yp, &incy, ap, &lda ); + } + else if constexpr (std::is_same::value) { + if( testinghelpers::chkconj( conjy ) ) + zgerc_blis_impl( &m, &n, alpha, xp, &incx, yp, &incy, ap, &lda ); + else + zgeru_blis_impl( &m, &n, alpha, xp, &incx, yp, &incy, ap, &lda ); + } + else + throw std::runtime_error("Error in testsuite/level2/ger.h: Invalid typename in ger_blis_impl()."); +} + template static void cblas_ger( char storage, char conjy, gtint_t m, gtint_t n, T* alpha, T* xp, gtint_t incx,T* yp, gtint_t incy, T* ap, gtint_t lda ) @@ -143,11 +168,54 @@ template static void ger( char storage, char conjx, char conjy, gtint_t m, gtint_t n, T* alpha, T* xp, gtint_t incx, T* yp, gtint_t incy, T* ap, gtint_t lda ) { + +#ifdef TEST_UPPERCASE_ARGS + storage = static_cast(std::toupper(static_cast(storage))); + conjx = static_cast(std::toupper(static_cast(conjx))); + conjy = static_cast(std::toupper(static_cast(conjy))); +#endif + +#ifdef TEST_INPUT_ARGS + // Create copy of scalar input values so we can check that they are not altered. + char storage_cpy = storage; + char conjx_cpy = conjx; + char conjy_cpy = conjy; + gtint_t m_cpy = m; + gtint_t n_cpy = n; + T* alpha_cpy = alpha; + gtint_t incx_cpy = incx; + gtint_t incy_cpy = incy; + gtint_t lda_cpy = lda; + + // Create copy of input arrays so we can check that they are not altered. + T* xp_cpy = nullptr; + gtint_t size_xp; + size_xp = testinghelpers::buff_dim( m, incx ); + if (xp && size_xp > 0) + { + xp_cpy = new T[size_xp]; + memcpy( xp_cpy, xp, size_xp * sizeof( T ) ); + } + T* yp_cpy = nullptr; + gtint_t size_yp; + size_yp = testinghelpers::buff_dim( n, incy ); + if (yp && size_yp > 0) + { + yp_cpy = new T[size_yp]; + memcpy( yp_cpy, yp, size_yp * sizeof( T ) ); + } +#endif + #ifdef TEST_BLAS if( storage == 'c' || storage == 'C' ) ger_( conjy, m, n, alpha, xp, incx, yp, incy, ap, lda ); else throw std::runtime_error("Error in testsuite/level2/ger.h: BLAS interface cannot be tested for row-major order."); +#elif TEST_BLAS_BLIS_IMPL + if( storage == 'c' || storage == 'C' ) + ger_blis_impl( conjy, m, n, alpha, xp, incx, yp, incy, ap, lda ); + else + throw std::runtime_error("Error in testsuite/level2/ger.h: BLAS_BLIS_IMPL interface cannot be tested for row-major order."); #elif TEST_CBLAS cblas_ger( storage, conjy, m, n, alpha, xp, incx, yp, incy, ap, lda ); #elif TEST_BLIS_TYPED @@ -155,4 +223,36 @@ static void ger( char storage, char conjx, char conjy, gtint_t m, gtint_t n, #else throw std::runtime_error("Error in testsuite/level2/ger.h: No interfaces are set to be tested."); #endif + +#ifdef TEST_INPUT_ARGS + //---------------------------------------------------------- + // Check scalar inputs have not been modified. + //---------------------------------------------------------- + + computediff( "storage", storage, storage_cpy ); + computediff( "conjx", conjx, conjx_cpy ); + computediff( "conjy", conjy, conjy_cpy ); + computediff( "m", m, m_cpy ); + computediff( "n", n, n_cpy ); + if (alpha) computediff( "alpha", *alpha, *alpha_cpy ); + computediff( "lda", lda, lda_cpy ); + computediff( "incx", incx, incx_cpy ); + computediff( "incy", incy, incy_cpy ); + + //---------------------------------------------------------- + // Bitwise-wise check array inputs have not been modified. + //---------------------------------------------------------- + + if (xp && size_xp > 0) + { + computediff( "x", m, xp, xp_cpy, incx, true ); + delete[] xp_cpy; + } + + if (yp && size_yp > 0) + { + computediff( "y", n, yp, yp_cpy, incy, true ); + delete[] yp_cpy; + } +#endif } diff --git a/gtestsuite/testsuite/level2/ger/sger/sger_evt.cpp b/gtestsuite/testsuite/level2/ger/sger/sger_evt.cpp new file mode 100644 index 0000000000..40fb0c3c22 --- /dev/null +++ b/gtestsuite/testsuite/level2/ger/sger/sger_evt.cpp @@ -0,0 +1,207 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "level2/ger/test_ger.h" + +using T = float; +static T NaN = std::numeric_limits::quiet_NaN(); +static T Inf = std::numeric_limits::infinity(); + +class DISABLED_sgerEVT : + public ::testing::TestWithParam> {}; // y_exval + +TEST_P( DISABLED_sgerEVT, API ) +{ + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // matrix storage format(row major, column major) + char storage = std::get<0>(GetParam()); + // denotes whether vector x is n,c + char conjx = std::get<1>(GetParam()); + // denotes whether vector y is n,c + char conjy = std::get<2>(GetParam()); + // matrix size m + gtint_t m = std::get<3>(GetParam()); + // matrix size n + gtint_t n = std::get<4>(GetParam()); + // specifies alpha value + T alpha = std::get<5>(GetParam()); + // stride size for x: + gtint_t incx = std::get<6>(GetParam()); + // stride size for y: + gtint_t incy = std::get<7>(GetParam()); + // lda increment: + // If increment is zero, then the array size matches the matrix size. + // If increment is non-negative, the array size is bigger than the matrix size. + gtint_t lda_inc = std::get<8>(GetParam()); + // ai: + gtint_t ai = std::get<9>(GetParam()); + // aj: + gtint_t aj = std::get<10>(GetParam()); + // a_exval: + T a_exval = std::get<11>(GetParam()); + // xi: + gtint_t xi = std::get<12>(GetParam()); + // x_exval: + T x_exval = std::get<13>(GetParam()); + // yi: + gtint_t yi = std::get<14>(GetParam()); + // y_exval: + T y_exval = std::get<15>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite ger.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (m == 0 || n == 0 || alpha == testinghelpers::ZERO()) + thresh = 0.0; + else + thresh = 3*testinghelpers::getEpsilon(); + + + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_ger( storage, conjx, conjy, m, n, alpha, incx, incy, lda_inc, + ai, aj, a_exval, xi, x_exval, yi, y_exval, thresh ); +} + +INSTANTIATE_TEST_SUITE_P( + unitStride, + DISABLED_sgerEVT, + ::testing::Combine( + // storage scheme: row/col-stored matrix + ::testing::Values( 'c' + // row-stored tests are disabled for BLAS since BLAS only supports col-storage scheme. +#ifndef TEST_BLAS_LIKE + , 'r' +#endif + ), + // conjx: uses n (no_conjugate) since it is real. + ::testing::Values( 'n' ), + // conjy: uses n (no_conjugate) since it is real. + ::testing::Values( 'n' ), + // m + ::testing::Values( gtint_t(55) ), + // n + ::testing::Values( gtint_t(33) ), + // alpha: value of scalar + ::testing::Values( T{1.0}, T{2.3}, NaN, Inf, -Inf ), + // incx: stride of x vector. + ::testing::Values( gtint_t(1) ), + // incy: stride of y vector. + ::testing::Values( gtint_t(1) ), + // inc_lda: increment to the leading dim of a. + ::testing::Values( gtint_t(0) ), + // ai: index of extreme value for a. + ::testing::Values( gtint_t(0), gtint_t(7) ), + // aj: index of extreme value for a. + ::testing::Values( gtint_t(0), gtint_t(7) ), + // a_exval: extreme value for a. + ::testing::Values( T{0.0}, NaN, Inf, -Inf ), + // xi: index of extreme value for x. + ::testing::Values( gtint_t(0), gtint_t(7) ), + // x_exval: extreme value for x. + ::testing::Values( T{0.0}, NaN, Inf, -Inf ), + // yi: index of extreme value for y. + ::testing::Values( gtint_t(0), gtint_t(7) ), + // y_exval: extreme value for y. + ::testing::Values( T{0.0}, NaN, Inf, -Inf ) + ), + ::gerEVTPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + nonUnitStride, + DISABLED_sgerEVT, + ::testing::Combine( + // storage scheme: row/col-stored matrix + ::testing::Values( 'c' + // row-stored tests are disabled for BLAS since BLAS only supports col-storage scheme. +#ifndef TEST_BLAS_LIKE + , 'r' +#endif + ), + // conjx: uses n (no_conjugate) since it is real. + ::testing::Values( 'n' ), + // conjy: uses n (no_conjugate) since it is real. + ::testing::Values( 'n' ), + // m + ::testing::Values( gtint_t(55) ), + // n + ::testing::Values( gtint_t(33) ), + // alpha: value of scalar + ::testing::Values( T{1.0}, T{2.3}, NaN, Inf, -Inf ), + // incx: stride of x vector. + ::testing::Values( gtint_t(3) ), + // incy: stride of y vector. + ::testing::Values( gtint_t(5) ), + // inc_lda: increment to the leading dim of a. + ::testing::Values( gtint_t(7) ), + // ai: index of extreme value for a. + ::testing::Values( gtint_t(0), gtint_t(7) ), + // aj: index of extreme value for a. + ::testing::Values( gtint_t(0), gtint_t(7) ), + // a_exval: extreme value for a. + ::testing::Values( T{0.0}, NaN, Inf, -Inf ), + // xi: index of extreme value for x. + ::testing::Values( gtint_t(0), gtint_t(7) ), + // x_exval: extreme value for x. + ::testing::Values( T{0.0}, NaN, Inf, -Inf ), + // yi: index of extreme value for y. + ::testing::Values( gtint_t(0), gtint_t(7) ), + // y_exval: extreme value for y. + ::testing::Values( T{0.0}, NaN, Inf, -Inf ) + ), + ::gerEVTPrint() + ); diff --git a/gtestsuite/testsuite/level2/ger/sger/sger_generic.cpp b/gtestsuite/testsuite/level2/ger/sger/sger_generic.cpp new file mode 100644 index 0000000000..2d9283e3fb --- /dev/null +++ b/gtestsuite/testsuite/level2/ger/sger/sger_generic.cpp @@ -0,0 +1,315 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "level2/ger/test_ger.h" + +class sgerGeneric : + public ::testing::TestWithParam> {}; // lda_inc + +TEST_P( sgerGeneric, API ) +{ + using T = float; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // matrix storage format(row major, column major) + char storage = std::get<0>(GetParam()); + // denotes whether vector x is n,c + char conjx = std::get<1>(GetParam()); + // denotes whether vector y is n,c + char conjy = std::get<2>(GetParam()); + // matrix size m + gtint_t m = std::get<3>(GetParam()); + // matrix size n + gtint_t n = std::get<4>(GetParam()); + // specifies alpha value + T alpha = std::get<5>(GetParam()); + // stride size for x: + gtint_t incx = std::get<6>(GetParam()); + // stride size for y: + gtint_t incy = std::get<7>(GetParam()); + // lda increment. + // If increment is zero, then the array size matches the matrix size. + // If increment is non-negative, the array size is bigger than the matrix size. + gtint_t lda_inc = std::get<8>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite ger.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (m == 0 || n == 0 || alpha == testinghelpers::ZERO()) + thresh = 0.0; + else + thresh = 3*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_ger( storage, conjx, conjy, m, n, alpha, incx, incy, lda_inc, thresh ); +} + +INSTANTIATE_TEST_SUITE_P( + unitPositiveIncrement, + sgerGeneric, + ::testing::Combine( + // storage scheme: row/col-stored matrix + ::testing::Values( 'c' + // row-stored tests are disabled for BLAS since BLAS only supports col-storage scheme. +#ifndef TEST_BLAS_LIKE + , 'r' +#endif + ), + // conjx: uses n (no_conjugate) since it is real. + ::testing::Values( 'n' ), + // conjy: uses n (no_conjugate) since it is real. + ::testing::Values( 'n' ), + // m + ::testing::Range( gtint_t(10), gtint_t(101), 10 ), + // n + ::testing::Range( gtint_t(10), gtint_t(101), 10 ), + // alpha: value of scalar + ::testing::Values( float(-4.1), float(1.0), float(2.3) ), + // incx: stride of x vector. + ::testing::Values( gtint_t(1) ), + // incy: stride of y vector. + ::testing::Values( gtint_t(1) ), + // inc_lda: increment to the leading dim of a + ::testing::Values( gtint_t(0), gtint_t(3) ) + ), + ::gerGenericPrint() + ); + +#ifdef TEST_BLIS_TYPED +// Test when conjugate of x is used as an argument. This option is BLIS-api specific. +// Only test very few cases as sanity check since conj(x) = x for real types. +// We can modify the values using implementantion details. +INSTANTIATE_TEST_SUITE_P( + conjXY, + sgerGeneric, + ::testing::Combine( + // storage scheme: row/col-stored matrix + ::testing::Values( 'c' + // row-stored tests are disabled for BLAS since BLAS only supports col-storage scheme. +#ifndef TEST_BLAS_LIKE + , 'r' +#endif + ), + // conjx: uses n (no_conjugate) since it is real. + ::testing::Values( 'c' ), + // conjy: uses n (no_conjugate) since it is real. + ::testing::Values( 'c' ), + // m + ::testing::Values( gtint_t(3), gtint_t(30), gtint_t(112) ), + // n + ::testing::Values( gtint_t(3), gtint_t(30), gtint_t(112) ), + // alpha: value of scalar + ::testing::Values( float(-4.1), float(1.0), float(2.3) ), + // incx: stride of x vector. + ::testing::Values( gtint_t(1) ), + // incy: stride of y vector. + ::testing::Values( gtint_t(1) ), + // inc_lda: increment to the leading dim of a + ::testing::Values( gtint_t(0), gtint_t(3) ) + ), + ::gerGenericPrint() + ); +#endif + +INSTANTIATE_TEST_SUITE_P( + nonUnitPositiveIncrements, + sgerGeneric, + ::testing::Combine( + // storage scheme: row/col-stored matrix + ::testing::Values( 'c' + // row-stored tests are disabled for BLAS since BLAS only supports col-storage scheme. +#ifndef TEST_BLAS_LIKE + , 'r' +#endif + ), + // conjx: uses n (no_conjugate) since it is real. + ::testing::Values( 'n' ), + // conjy: uses n (no_conjugate) since it is real. + ::testing::Values( 'n' ), + // m + ::testing::Values( gtint_t(3), gtint_t(30), gtint_t(112) ), + // n + ::testing::Values( gtint_t(3), gtint_t(30), gtint_t(112) ), + // alpha: value of scalar + ::testing::Values( float(-4.1), float(1.0), float(2.3) ), + // incx: stride of x vector. + ::testing::Values( gtint_t(2) ), + // incy: stride of y vector. + ::testing::Values( gtint_t(3) ), + // inc_lda: increment to the leading dim of a + ::testing::Values( gtint_t(0), gtint_t(3) ) + ), + ::gerGenericPrint() + ); + +// @note negativeIncrement tests are resulting in Segmentation Faults when +// BLIS_TYPED interface is being tested. +#ifndef TEST_BLIS_TYPED +INSTANTIATE_TEST_SUITE_P( + negativeIncrements, + sgerGeneric, + ::testing::Combine( + // storage scheme: row/col-stored matrix + ::testing::Values( 'c' + // row-stored tests are disabled for BLAS since BLAS only supports col-storage scheme. +#ifndef TEST_BLAS_LIKE + , 'r' +#endif + ), + // conjx: uses n (no_conjugate) since it is real. + ::testing::Values( 'n' ), + // conjy: uses n (no_conjugate) since it is real. + ::testing::Values( 'n' ), + // m + ::testing::Values( gtint_t(3), gtint_t(30), gtint_t(112) ), + // n + ::testing::Values( gtint_t(3), gtint_t(30), gtint_t(112) ), + // alpha: value of scalar + ::testing::Values( float(-4.1), float(1.0), float(2.3) ), + // incx: stride of x vector. + ::testing::Values( gtint_t(-2) ), + // incy: stride of y vector. + ::testing::Values( gtint_t(-3) ), + // inc_lda: increment to the leading dim of a + ::testing::Values( gtint_t(0), gtint_t(3) ) + ), + ::gerGenericPrint() + ); +#endif + +INSTANTIATE_TEST_SUITE_P( + scalarCombinations, + sgerGeneric, + ::testing::Combine( + // storage scheme: row/col-stored matrix + ::testing::Values( 'c' + // row-stored tests are disabled for BLAS since BLAS only supports col-storage scheme. +#ifndef TEST_BLAS_LIKE + , 'r' +#endif + ), + // conjx: uses n (no_conjugate) since it is real. + ::testing::Values( 'n' ), + // conjy: uses n (no_conjugate) since it is real. + ::testing::Values( 'n' ), + // m + ::testing::Values( gtint_t(5) ), + // n + ::testing::Values( gtint_t(4) ), + // alpha: value of scalar + ::testing::Values( float(-401.1), float(100.0), float(3.4)), + // incx: stride of x vector. + ::testing::Values( gtint_t(2) ), + // incy: stride of y vector. + ::testing::Values( gtint_t(3) ), + // inc_lda: increment to the leading dim of a + ::testing::Values( gtint_t(1) ) + ), + ::gerGenericPrint() + ); +INSTANTIATE_TEST_SUITE_P( + largeSize, + sgerGeneric, + ::testing::Combine( + // storage scheme: row/col-stored matrix + ::testing::Values( 'c' + // row-stored tests are disabled for BLAS since BLAS only supports col-storage scheme. +#ifndef TEST_BLAS_LIKE + , 'r' +#endif + ), + // conjx: uses n (no_conjugate) since it is real. + ::testing::Values( 'n' ), + // conjy: uses n (no_conjugate) since it is real. + ::testing::Values( 'n' ), + // m + ::testing::Values( gtint_t(5000) ), + // n + ::testing::Values( gtint_t(4000) ), + // alpha: value of scalar + ::testing::Values( float(3.4) ), + // incx: stride of x vector. + ::testing::Values( gtint_t(2), gtint_t(1) ), + // incy: stride of y vector. + ::testing::Values( gtint_t(3), gtint_t(1) ), + // inc_lda: increment to the leading dim of a + ::testing::Values( gtint_t(0), gtint_t(3) ) + ), + ::gerGenericPrint() + ); +INSTANTIATE_TEST_SUITE_P( + strideGreaterThanSize, + sgerGeneric, + ::testing::Combine( + // storage scheme: row/col-stored matrix + ::testing::Values( 'c' + // row-stored tests are disabled for BLAS since BLAS only supports col-storage scheme. +#ifndef TEST_BLAS_LIKE + , 'r' +#endif + ), + // conjx: uses n (no_conjugate) since it is real. + ::testing::Values( 'n' ), + // conjy: uses n (no_conjugate) since it is real. + ::testing::Values( 'n' ), + // m + ::testing::Values( gtint_t(2) ), + // n + ::testing::Values( gtint_t(4) ), + // alpha: value of scalar + ::testing::Values( float(3.4)), + // incx: stride of x vector. + ::testing::Values( gtint_t(10) ), + // incy: stride of y vector. + ::testing::Values( gtint_t(15) ), + // inc_lda: increment to the leading dim of a + ::testing::Values( gtint_t(9) ) + ), + ::gerGenericPrint() + ); diff --git a/gtestsuite/testsuite/level2/ger/sger_generic.cpp b/gtestsuite/testsuite/level2/ger/sger_generic.cpp deleted file mode 100644 index 37c832759d..0000000000 --- a/gtestsuite/testsuite/level2/ger/sger_generic.cpp +++ /dev/null @@ -1,141 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include "test_ger.h" - -class sgerTest : - public ::testing::TestWithParam> {}; - -TEST_P(sgerTest, RandomData) -{ - using T = float; - //---------------------------------------------------------- - // Initialize values from the parameters passed through - // test suite instantiation (INSTANTIATE_TEST_SUITE_P). - //---------------------------------------------------------- - // matrix storage format(row major, column major) - char storage = std::get<0>(GetParam()); - // denotes whether vector x is n,c - char conjx = std::get<1>(GetParam()); - // denotes whether vector y is n,c - char conjy = std::get<2>(GetParam()); - // matrix size m - gtint_t m = std::get<3>(GetParam()); - // matrix size n - gtint_t n = std::get<4>(GetParam()); - // specifies alpha value - T alpha = std::get<5>(GetParam()); - // stride size for x: - gtint_t incx = std::get<6>(GetParam()); - // stride size for y: - gtint_t incy = std::get<7>(GetParam()); - // lda increment. - // If increment is zero, then the array size matches the matrix size. - // If increment are nonnegative, the array size is bigger than the matrix size. - gtint_t lda_inc = std::get<8>(GetParam()); - - // Set the threshold for the errors: - double thresh = 4*(std::max)(m,n)*testinghelpers::getEpsilon(); - - //---------------------------------------------------------- - // Call test body using these parameters - //---------------------------------------------------------- - test_ger( storage, conjx, conjy, m, n, alpha, incx, incy, lda_inc, thresh ); -} - -class sgerTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char conjx = std::get<1>(str.param); - char conjy = std::get<2>(str.param); - gtint_t m = std::get<3>(str.param); - gtint_t n = std::get<4>(str.param); - float alpha = std::get<5>(str.param); - gtint_t incx = std::get<6>(str.param); - gtint_t incy = std::get<7>(str.param); - gtint_t ld_inc = std::get<8>(str.param); -#ifdef TEST_BLAS - std::string str_name = "sger_"; -#elif TEST_CBLAS - std::string str_name = "cblas_sger"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_sger"; -#endif - str_name = str_name + "_" + sfm; - str_name = str_name + "_" + conjx+conjy; - str_name = str_name + "_" + std::to_string(m); - str_name = str_name + "_" + std::to_string(n); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); - str_name = str_name + "_" + incx_str; - str_name = str_name + "_" + incy_str; - std::string alpha_str = ( alpha > 0) ? std::to_string(int(alpha)) : "m" + std::to_string(int(std::abs(alpha))); - str_name = str_name + "_a" + alpha_str; - str_name = str_name + "_" + std::to_string(ld_inc); - return str_name; - } -}; - -// Black box testing. -INSTANTIATE_TEST_SUITE_P( - Blackbox, - sgerTest, - ::testing::Combine( - ::testing::Values('c' -#ifndef TEST_BLAS - ,'r' -#endif - ), // storage format - ::testing::Values('n'), // conjx - ::testing::Values('n'), // conjy - ::testing::Range(gtint_t(10), gtint_t(31), 10), // m - ::testing::Range(gtint_t(10), gtint_t(31), 10), // n - ::testing::Values( 1.0 ), // alpha - ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(gtint_t(0), gtint_t(3)) // increment to the leading dim of a - ), - ::sgerTestPrint() - ); diff --git a/gtestsuite/testsuite/level2/ger/test_ger.h b/gtestsuite/testsuite/level2/ger/test_ger.h index 3e8e7646d8..2db9f10823 100644 --- a/gtestsuite/testsuite/level2/ger/test_ger.h +++ b/gtestsuite/testsuite/level2/ger/test_ger.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -41,7 +41,6 @@ #include template - void test_ger( char storage, char conjx, char conjy, gtint_t m, gtint_t n, T alpha, gtint_t incx, gtint_t incy, gtint_t lda_inc, double thresh ) { @@ -72,5 +71,136 @@ void test_ger( char storage, char conjx, char conjy, gtint_t m, gtint_t n, //---------------------------------------------------------- // check component-wise error. //---------------------------------------------------------- - computediff( storage, m, n, a.data(), a_ref.data(), lda, thresh ); + computediff( "a", storage, m, n, a.data(), a_ref.data(), lda, thresh ); + +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif +} + +template +void test_ger( char storage, char conjx, char conjy, gtint_t m, gtint_t n, + T alpha, gtint_t incx, gtint_t incy, gtint_t lda_inc, gtint_t ai, + gtint_t aj, T a_exval, gtint_t xi, T x_exval, gtint_t yi, + T y_exval, double thresh ) +{ + // Compute the leading dimensions for matrix size calculation. + gtint_t lda = testinghelpers::get_leading_dimension( storage, 'n', m, n, lda_inc ); + + //---------------------------------------------------------- + // Initialize matrics with random integer numbers. + //---------------------------------------------------------- + std::vector a = testinghelpers::get_random_matrix( -2, 5, storage, 'n', m, n, lda ); + std::vector x = testinghelpers::get_random_vector( -3, 3, m, incx ); + std::vector y = testinghelpers::get_random_vector( -3, 3, n, incy ); + + testinghelpers::set_ev_mat( storage, 'n', lda, ai, aj, a_exval, a.data() ); + // Update the value at index xi to an extreme value, x_exval. + if ( -1 < xi && xi < n ) x[xi * abs(incx)] = x_exval; + else return; + + // Update the value at index yi to an extreme value, y_exval. + if ( -1 < yi && yi < n ) y[yi * abs(incy)] = y_exval; + else return; + + // Create a copy of c so that we can check reference results. + std::vector a_ref(a); + //---------------------------------------------------------- + // Call BLIS function + //---------------------------------------------------------- + ger( storage, conjx, conjy, m, n, &alpha, x.data(), incx, + y.data(), incy, a.data(), lda ); + + //---------------------------------------------------------- + // Call reference implementation. + //---------------------------------------------------------- + testinghelpers::ref_ger( storage, conjx, conjy, m, n, alpha, + x.data(), incx, y.data(), incy, a_ref.data(), lda ); + + //---------------------------------------------------------- + // check component-wise error. + //---------------------------------------------------------- + computediff( "A", storage, m, n, a.data(), a_ref.data(), lda, thresh, true ); + +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif } + +// Test-case logger : Used to print the test-case details based on parameters +template +class gerGenericPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char storage = std::get<0>(str.param); + char conjx = std::get<1>(str.param); + char conjy = std::get<2>(str.param); + gtint_t m = std::get<3>(str.param); + gtint_t n = std::get<4>(str.param); + T alpha = std::get<5>(str.param); + gtint_t incx = std::get<6>(str.param); + gtint_t incy = std::get<7>(str.param); + gtint_t lda_inc = std::get<8>(str.param); + + std::string str_name = API_PRINT; + str_name += "_stor_" + std::string(&storage, 1); + str_name += "_conjx_" + std::string(&conjx, 1); + str_name += "_conjy_" + std::string(&conjy, 1); + str_name += "_m_" + std::to_string(m); + str_name += "_n_" + std::to_string(n); + str_name += "_incx_" + testinghelpers::get_value_string(incx); + str_name += "_incy_" + testinghelpers::get_value_string(incy); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + gtint_t lda = testinghelpers::get_leading_dimension( storage, 'n', m, n, lda_inc ); + str_name += "_lda_i" + std::to_string(lda_inc) + "_" + std::to_string(lda); + return str_name; + } +}; + +template +class gerEVTPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char storage = std::get<0>(str.param); + char conjx = std::get<1>(str.param); + char conjy = std::get<2>(str.param); + gtint_t m = std::get<3>(str.param); + gtint_t n = std::get<4>(str.param); + T alpha = std::get<5>(str.param); + gtint_t incx = std::get<6>(str.param); + gtint_t incy = std::get<7>(str.param); + gtint_t lda_inc = std::get<8>(str.param); + gtint_t ai = std::get<9>(str.param); + gtint_t aj = std::get<10>(str.param); + T a_exval = std::get<11>(str.param); + gtint_t xi = std::get<12>(str.param); + T x_exval = std::get<13>(str.param); + gtint_t yi = std::get<14>(str.param); + T y_exval = std::get<15>(str.param); + + std::string str_name = API_PRINT; + str_name += "_stor_" + std::string(&storage, 1); + str_name += "_conjx_" + std::string(&conjx, 1); + str_name += "_conjy_" + std::string(&conjy, 1); + str_name += "_m_" + std::to_string(m); + str_name += "_n_" + std::to_string(n); + str_name += "_incx_" + testinghelpers::get_value_string(incx); + str_name += "_incy_" + testinghelpers::get_value_string(incy); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + gtint_t lda = testinghelpers::get_leading_dimension( storage, 'n', m, n, lda_inc ); + str_name += "_lda_i" + std::to_string(lda_inc) + "_" + std::to_string(lda); + str_name = str_name + "_ai" + std::to_string(ai); + str_name = str_name + "_aj" + std::to_string(aj); + str_name = str_name + "_a_exval_" + testinghelpers::get_value_string(a_exval); + str_name = str_name + "_xi" + std::to_string(xi); + str_name = str_name + "_x_exval_" + testinghelpers::get_value_string(x_exval); + str_name = str_name + "_yi" + std::to_string(yi); + str_name = str_name + "_y_exval_" + testinghelpers::get_value_string(y_exval); + + return str_name; + } +}; diff --git a/gtestsuite/testsuite/level2/ger/zger/zger_evt.cpp b/gtestsuite/testsuite/level2/ger/zger/zger_evt.cpp new file mode 100644 index 0000000000..2ea789ae34 --- /dev/null +++ b/gtestsuite/testsuite/level2/ger/zger/zger_evt.cpp @@ -0,0 +1,208 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "level2/ger/test_ger.h" + +using T = dcomplex; +using RT = testinghelpers::type_info::real_type; +static RT NaN = std::numeric_limits::quiet_NaN(); +static RT Inf = std::numeric_limits::infinity(); + +class DISABLED_zgerEVT : + public ::testing::TestWithParam> {}; // y_exval + +TEST_P( DISABLED_zgerEVT, API ) +{ + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // matrix storage format(row major, column major) + char storage = std::get<0>(GetParam()); + // denotes whether vector x is n,c + char conjx = std::get<1>(GetParam()); + // denotes whether vector y is n,c + char conjy = std::get<2>(GetParam()); + // matrix size m + gtint_t m = std::get<3>(GetParam()); + // matrix size n + gtint_t n = std::get<4>(GetParam()); + // specifies alpha value + T alpha = std::get<5>(GetParam()); + // stride size for x: + gtint_t incx = std::get<6>(GetParam()); + // stride size for y: + gtint_t incy = std::get<7>(GetParam()); + // lda increment: + // If increment is zero, then the array size matches the matrix size. + // If increment is non-negative, the array size is bigger than the matrix size. + gtint_t lda_inc = std::get<8>(GetParam()); + // ai: + gtint_t ai = std::get<9>(GetParam()); + // aj: + gtint_t aj = std::get<10>(GetParam()); + // a_exval: + T a_exval = std::get<11>(GetParam()); + // xi: + gtint_t xi = std::get<12>(GetParam()); + // x_exval: + T x_exval = std::get<13>(GetParam()); + // yi: + gtint_t yi = std::get<14>(GetParam()); + // y_exval: + T y_exval = std::get<15>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite ger.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (m == 0 || n == 0 || alpha == testinghelpers::ZERO()) + thresh = 0.0; + else + thresh = 7*testinghelpers::getEpsilon(); + + + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_ger( storage, conjx, conjy, m, n, alpha, incx, incy, lda_inc, + ai, aj, a_exval, xi, x_exval, yi, y_exval, thresh ); +} + +INSTANTIATE_TEST_SUITE_P( + unitStride, + DISABLED_zgerEVT, + ::testing::Combine( + // storage scheme: row/col-stored matrix + ::testing::Values( 'c' + // row-stored tests are disabled for BLAS since BLAS only supports col-storage scheme. +#ifndef TEST_BLAS_LIKE + , 'r' +#endif + ), + // conjx: uses n (no_conjugate) since it is real. + ::testing::Values( 'n' ), + // conjy: uses n (no_conjugate) since it is real. + ::testing::Values( 'n' ), + // m + ::testing::Values( gtint_t(55) ), + // n + ::testing::Values( gtint_t(33) ), + // alpha: value of scalar + ::testing::Values( T{1.0, 1.0}, T{2.3, -1.2}, T{NaN, NaN}, T{NaN, Inf}, T{Inf, -Inf} ), + // incx: stride of x vector. + ::testing::Values( gtint_t(1) ), + // incy: stride of y vector. + ::testing::Values( gtint_t(1) ), + // inc_lda: increment to the leading dim of a. + ::testing::Values( gtint_t(0) ), + // ai: index of extreme value for a. + ::testing::Values( gtint_t(0), gtint_t(7) ), + // aj: index of extreme value for a. + ::testing::Values( gtint_t(0), gtint_t(7) ), + // a_exval: extreme value for a. + ::testing::Values( T{0.0, 0.0}, T{NaN, NaN}, T{NaN, Inf}, T{Inf, -Inf} ), + // xi: index of extreme value for x. + ::testing::Values( gtint_t(0), gtint_t(7) ), + // x_exval: extreme value for x. + ::testing::Values( T{0.0, 0.0}, T{NaN, NaN}, T{NaN, Inf}, T{Inf, -Inf} ), + // yi: index of extreme value for y. + ::testing::Values( gtint_t(0), gtint_t(7) ), + // y_exval: extreme value for y. + ::testing::Values( T{0.0, 0.0}, T{NaN, NaN}, T{NaN, Inf}, T{Inf, -Inf} ) + ), + ::gerEVTPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + nonUnitStride, + DISABLED_zgerEVT, + ::testing::Combine( + // storage scheme: row/col-stored matrix + ::testing::Values( 'c' + // row-stored tests are disabled for BLAS since BLAS only supports col-storage scheme. +#ifndef TEST_BLAS_LIKE + , 'r' +#endif + ), + // conjx: uses n (no_conjugate) since it is real. + ::testing::Values( 'n' ), + // conjy: uses n (no_conjugate) since it is real. + ::testing::Values( 'n' ), + // m + ::testing::Values( gtint_t(55) ), + // n + ::testing::Values( gtint_t(33) ), + // alpha: value of scalar + ::testing::Values( T{1.0, 1.0}, T{2.3, -1.2}, T{NaN, NaN}, T{NaN, Inf}, T{Inf, -Inf} ), + // incx: stride of x vector. + ::testing::Values( gtint_t(3) ), + // incy: stride of y vector. + ::testing::Values( gtint_t(5) ), + // inc_lda: increment to the leading dim of a. + ::testing::Values( gtint_t(7) ), + // ai: index of extreme value for a. + ::testing::Values( gtint_t(0), gtint_t(7) ), + // aj: index of extreme value for a. + ::testing::Values( gtint_t(0), gtint_t(7) ), + // a_exval: extreme value for a. + ::testing::Values( T{0.0, 0.0}, T{NaN, NaN}, T{NaN, Inf}, T{Inf, -Inf} ), + // xi: index of extreme value for x. + ::testing::Values( gtint_t(0), gtint_t(7) ), + // x_exval: extreme value for x. + ::testing::Values( T{0.0, 0.0}, T{NaN, NaN}, T{NaN, Inf}, T{Inf, -Inf} ), + // yi: index of extreme value for y. + ::testing::Values( gtint_t(0), gtint_t(7) ), + // y_exval: extreme value for y. + ::testing::Values( T{0.0, 0.0}, T{NaN, NaN}, T{NaN, Inf}, T{Inf, -Inf} ) + ), + ::gerEVTPrint() + ); diff --git a/gtestsuite/testsuite/level2/ger/zger/zger_generic.cpp b/gtestsuite/testsuite/level2/ger/zger/zger_generic.cpp new file mode 100644 index 0000000000..7bcd74b8dd --- /dev/null +++ b/gtestsuite/testsuite/level2/ger/zger/zger_generic.cpp @@ -0,0 +1,315 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "level2/ger/test_ger.h" + +class zgerGeneric : + public ::testing::TestWithParam> {}; + +TEST_P( zgerGeneric, API ) +{ + using T = dcomplex; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // matrix storage format(row major, column major) + char storage = std::get<0>(GetParam()); + // denotes whether vector x is n,c + char conjx = std::get<1>(GetParam()); + // denotes whether vector y is n,c + char conjy = std::get<2>(GetParam()); + // matrix size m + gtint_t m = std::get<3>(GetParam()); + // matrix size n + gtint_t n = std::get<4>(GetParam()); + // specifies alpha value + T alpha = std::get<5>(GetParam()); + // stride size for x: + gtint_t incx = std::get<6>(GetParam()); + // stride size for y: + gtint_t incy = std::get<7>(GetParam()); + // lda increment. + // If increment is zero, then the array size matches the matrix size. + // If increment is non-negative, the array size is bigger than the matrix size. + gtint_t lda_inc = std::get<8>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite ger.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (m == 0 || n == 0 || alpha == testinghelpers::ZERO()) + thresh = 0.0; + else + thresh = 7*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_ger( storage, conjx, conjy, m, n, alpha, incx, incy, lda_inc, thresh ); +} + +INSTANTIATE_TEST_SUITE_P( + unitPositiveIncrement, + zgerGeneric, + ::testing::Combine( + // storage scheme: row/col-stored matrix + ::testing::Values( 'c' + // row-stored tests are disabled for BLAS since BLAS only supports col-storage scheme. +#ifndef TEST_BLAS_LIKE + , 'r' +#endif + ), + // conjx: uses n (no_conjugate) since it is real. + ::testing::Values( 'n' ), + // conjy: uses n (no_conjugate) since it is real. + ::testing::Values( 'n' ), + // m + ::testing::Range( gtint_t(10), gtint_t(101), 10 ), + // n + ::testing::Range( gtint_t(10), gtint_t(101), 10 ), + // alpha: value of scalar + ::testing::Values( dcomplex{-1.0, 4.0}, dcomplex{1.0, 1.0}, dcomplex{3.0, -2.0} ), + // incx: stride of x vector. + ::testing::Values( gtint_t(1) ), + // incy: stride of y vector. + ::testing::Values( gtint_t(1) ), + // inc_lda: increment to the leading dim of a + ::testing::Values( gtint_t(0), gtint_t(3) ) + ), + ::gerGenericPrint() + ); + +#ifdef TEST_BLIS_TYPED +// Test when conjugate of x is used as an argument. This option is BLIS-api specific. +// Only test very few cases as sanity check since conj(x) = x for real types. +// We can modify the values using implementantion details. +INSTANTIATE_TEST_SUITE_P( + conjXY, + zgerGeneric, + ::testing::Combine( + // storage scheme: row/col-stored matrix + ::testing::Values( 'c' + // row-stored tests are disabled for BLAS since BLAS only supports col-storage scheme. +#ifndef TEST_BLAS_LIKE + , 'r' +#endif + ), + // conjx: uses n (no_conjugate) since it is real. + ::testing::Values( 'n', 'c' ), + // conjy: uses n (no_conjugate) since it is real. + ::testing::Values( 'n', 'c' ), + // m + ::testing::Values( gtint_t(3), gtint_t(30), gtint_t(112) ), + // n + ::testing::Values( gtint_t(3), gtint_t(30), gtint_t(112) ), + // alpha: value of scalar + ::testing::Values( dcomplex{-1.0, 4.0}, dcomplex{1.0, 1.0}, dcomplex{3.0, -2.0} ), + // incx: stride of x vector. + ::testing::Values( gtint_t(1) ), + // incy: stride of y vector. + ::testing::Values( gtint_t(1) ), + // inc_lda: increment to the leading dim of a + ::testing::Values( gtint_t(0), gtint_t(3) ) + ), + ::gerGenericPrint() + ); +#endif + +INSTANTIATE_TEST_SUITE_P( + nonUnitPositiveIncrements, + zgerGeneric, + ::testing::Combine( + // storage scheme: row/col-stored matrix + ::testing::Values( 'c' + // row-stored tests are disabled for BLAS since BLAS only supports col-storage scheme. +#ifndef TEST_BLAS_LIKE + , 'r' +#endif + ), + // conjx: uses n (no_conjugate) since it is real. + ::testing::Values( 'n' ), + // conjy: uses n (no_conjugate) since it is real. + ::testing::Values( 'n' ), + // m + ::testing::Values( gtint_t(3), gtint_t(30), gtint_t(112) ), + // n + ::testing::Values( gtint_t(3), gtint_t(30), gtint_t(112) ), + // alpha: value of scalar + ::testing::Values( dcomplex{-1.0, 4.0}, dcomplex{1.0, 1.0}, dcomplex{3.0, -2.0} ), + // incx: stride of x vector. + ::testing::Values( gtint_t(2) ), + // incy: stride of y vector. + ::testing::Values( gtint_t(3) ), + // inc_lda: increment to the leading dim of a + ::testing::Values( gtint_t(0), gtint_t(3) ) + ), + ::gerGenericPrint() + ); + +// @note negativeIncrement tests are resulting in Segmentation Faults when +// BLIS_TYPED interface is being tested. +#ifndef TEST_BLIS_TYPED +INSTANTIATE_TEST_SUITE_P( + negativeIncrements, + zgerGeneric, + ::testing::Combine( + // storage scheme: row/col-stored matrix + ::testing::Values( 'c' + // row-stored tests are disabled for BLAS since BLAS only supports col-storage scheme. +#ifndef TEST_BLAS_LIKE + , 'r' +#endif + ), + // conjx: uses n (no_conjugate) since it is real. + ::testing::Values( 'n' ), + // conjy: uses n (no_conjugate) since it is real. + ::testing::Values( 'n' ), + // m + ::testing::Values( gtint_t(3), gtint_t(30), gtint_t(112) ), + // n + ::testing::Values( gtint_t(3), gtint_t(30), gtint_t(112) ), + // alpha: value of scalar + ::testing::Values( dcomplex{-1.0, 4.0}, dcomplex{1.0, 1.0}, dcomplex{3.0, -2.0} ), + // incx: stride of x vector. + ::testing::Values( gtint_t(-2) ), + // incy: stride of y vector. + ::testing::Values( gtint_t(-3) ), + // inc_lda: increment to the leading dim of a + ::testing::Values( gtint_t(0), gtint_t(3) ) + ), + ::gerGenericPrint() + ); +#endif + +INSTANTIATE_TEST_SUITE_P( + scalarCombinations, + zgerGeneric, + ::testing::Combine( + // storage scheme: row/col-stored matrix + ::testing::Values( 'c' + // row-stored tests are disabled for BLAS since BLAS only supports col-storage scheme. +#ifndef TEST_BLAS_LIKE + , 'r' +#endif + ), + // conjx: uses n (no_conjugate) since it is real. + ::testing::Values( 'n' ), + // conjy: uses n (no_conjugate) since it is real. + ::testing::Values( 'n' ), + // m + ::testing::Values( gtint_t(2) ), + // n + ::testing::Values( gtint_t(3) ), + // alpha: value of scalar + ::testing::Values( dcomplex{-102.0, 404.0}, dcomplex{172.0, 138.0}, dcomplex{303.0, -267.0} ), + // incx: stride of x vector. + ::testing::Values( gtint_t(2) ), + // incy: stride of y vector. + ::testing::Values( gtint_t(3) ), + // inc_lda: increment to the leading dim of a + ::testing::Values( gtint_t(0), gtint_t(3) ) + ), + ::gerGenericPrint() + ); +INSTANTIATE_TEST_SUITE_P( + largeSize, + zgerGeneric, + ::testing::Combine( + // storage scheme: row/col-stored matrix + ::testing::Values( 'c' + // row-stored tests are disabled for BLAS since BLAS only supports col-storage scheme. +#ifndef TEST_BLAS_LIKE + , 'r' +#endif + ), + // conjx: uses n (no_conjugate) since it is real. + ::testing::Values( 'n' ), + // conjy: uses n (no_conjugate) since it is real. + ::testing::Values( 'n' ), + // m + ::testing::Values( gtint_t(1111) ), + // n + ::testing::Values( gtint_t(3333) ), + // alpha: value of scalar + ::testing::Values( dcomplex{2.0, 4.0} ), + // incx: stride of x vector. + ::testing::Values( gtint_t(3), gtint_t(1) ), + // incy: stride of y vector. + ::testing::Values( gtint_t(4), gtint_t(1) ), + // inc_lda: increment to the leading dim of a + ::testing::Values( gtint_t(0), gtint_t(3) ) + ), + ::gerGenericPrint() + ); +INSTANTIATE_TEST_SUITE_P( + strideGreaterThanSize, + zgerGeneric, + ::testing::Combine( + // storage scheme: row/col-stored matrix + ::testing::Values( 'c' + // row-stored tests are disabled for BLAS since BLAS only supports col-storage scheme. +#ifndef TEST_BLAS_LIKE + , 'r' +#endif + ), + // conjx: uses n (no_conjugate) since it is real. + ::testing::Values( 'n' ), + // conjy: uses n (no_conjugate) since it is real. + ::testing::Values( 'n' ), + // m + ::testing::Values( gtint_t(1) ), + // n + ::testing::Values( gtint_t(3) ), + // alpha: value of scalar + ::testing::Values( dcomplex{2.0, 4.0} ), + // incx: stride of x vector. + ::testing::Values( gtint_t(11) ), + // incy: stride of y vector. + ::testing::Values( gtint_t(22) ), + // inc_lda: increment to the leading dim of a + ::testing::Values( gtint_t(9) ) + ), + ::gerGenericPrint() + ); diff --git a/gtestsuite/testsuite/level2/ger/zger_generic.cpp b/gtestsuite/testsuite/level2/ger/zger_generic.cpp deleted file mode 100644 index 5847842c30..0000000000 --- a/gtestsuite/testsuite/level2/ger/zger_generic.cpp +++ /dev/null @@ -1,142 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include "test_ger.h" - -class zgerTest : - public ::testing::TestWithParam> {}; - -TEST_P(zgerTest, RandomData) -{ - using T = dcomplex; - //---------------------------------------------------------- - // Initialize values from the parameters passed through - // test suite instantiation (INSTANTIATE_TEST_SUITE_P). - //---------------------------------------------------------- - // matrix storage format(row major, column major) - char storage = std::get<0>(GetParam()); - // denotes whether vector x is n,c - char conjx = std::get<1>(GetParam()); - // denotes whether vector y is n,c - char conjy = std::get<2>(GetParam()); - // matrix size m - gtint_t m = std::get<3>(GetParam()); - // matrix size n - gtint_t n = std::get<4>(GetParam()); - // specifies alpha value - T alpha = std::get<5>(GetParam()); - // stride size for x: - gtint_t incx = std::get<6>(GetParam()); - // stride size for y: - gtint_t incy = std::get<7>(GetParam()); - // lda increment. - // If increment is zero, then the array size matches the matrix size. - // If increment are nonnegative, the array size is bigger than the matrix size. - gtint_t lda_inc = std::get<8>(GetParam()); - - // Set the threshold for the errors: - double thresh = 2*(std::max)(m,n)*testinghelpers::getEpsilon(); - - //---------------------------------------------------------- - // Call test body using these parameters - //---------------------------------------------------------- - test_ger( storage, conjx, conjy, m, n, alpha, incx, incy, lda_inc, thresh ); -} - -class zgerTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char conjx = std::get<1>(str.param); - char conjy = std::get<2>(str.param); - gtint_t m = std::get<3>(str.param); - gtint_t n = std::get<4>(str.param); - dcomplex alpha = std::get<5>(str.param); - gtint_t incx = std::get<6>(str.param); - gtint_t incy = std::get<7>(str.param); - gtint_t ld_inc = std::get<8>(str.param); -#ifdef TEST_BLAS - std::string str_name = "zger_"; -#elif TEST_CBLAS - std::string str_name = "cblas_zger"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_zger"; -#endif - str_name = str_name + "_" + sfm; - str_name = str_name + "_" + conjx+conjy; - str_name = str_name + "_" + std::to_string(m); - str_name = str_name + "_" + std::to_string(n); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); - str_name = str_name + "_" + incx_str; - str_name = str_name + "_" + incy_str; - std::string alpha_str = ( alpha.real > 0) ? std::to_string(int(alpha.real)) : ("m" + std::to_string(int(std::abs(alpha.real)))); - alpha_str = alpha_str + "pi" + (( alpha.imag > 0) ? std::to_string(int(alpha.imag)) : ("m" + std::to_string(int(std::abs(alpha.imag))))); - str_name = str_name + "_a" + alpha_str; - str_name = str_name + "_" + std::to_string(ld_inc); - return str_name; - } -}; - -// Black box testing. -INSTANTIATE_TEST_SUITE_P( - Blackbox, - zgerTest, - ::testing::Combine( - ::testing::Values('c' -#ifndef TEST_BLAS - ,'r' -#endif - ), // storage format - ::testing::Values('n'), // conjx - ::testing::Values('n','c'), // conjy - ::testing::Range(gtint_t(10), gtint_t(31), 10), // m - ::testing::Range(gtint_t(10), gtint_t(31), 10), // n - ::testing::Values(dcomplex{1.0, -2.0}), // alpha - ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(gtint_t(0), gtint_t(2)) // increment to the leading dim of a - ), - ::zgerTestPrint() - ); diff --git a/gtestsuite/testsuite/level2/hemv/chemv_generic.cpp b/gtestsuite/testsuite/level2/hemv/chemv_generic.cpp index ed4b726817..01c876c888 100644 --- a/gtestsuite/testsuite/level2/hemv/chemv_generic.cpp +++ b/gtestsuite/testsuite/level2/hemv/chemv_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,7 +35,7 @@ #include #include "test_hemv.h" -class chemvTest : +class chemvGeneric : public ::testing::TestWithParam> {}; -TEST_P(chemvTest, RandomData) +TEST_P( chemvGeneric, API ) { using T = scomplex; //---------------------------------------------------------- @@ -78,7 +78,19 @@ TEST_P(chemvTest, RandomData) gtint_t lda_inc = std::get<9>(GetParam()); // Set the threshold for the errors: - double thresh = 4*std::sqrt(n)*testinghelpers::getEpsilon(); + // Check gtestsuite hemv.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO() && (beta == testinghelpers::ZERO() || beta == testinghelpers::ONE())) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + thresh = testinghelpers::getEpsilon(); + else + thresh = (3*n+1)*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters @@ -86,64 +98,56 @@ TEST_P(chemvTest, RandomData) test_hemv( storage, uploa, conja, conjx, n, alpha, lda_inc, incx, beta, incy, thresh ); } -class chemvTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char uploa = std::get<1>(str.param); - char conja = std::get<2>(str.param); - char conjx = std::get<3>(str.param); - gtint_t n = std::get<4>(str.param); - scomplex alpha = std::get<5>(str.param); - scomplex beta = std::get<6>(str.param); - gtint_t incx = std::get<7>(str.param); - gtint_t incy = std::get<8>(str.param); - gtint_t ld_inc = std::get<9>(str.param); -#ifdef TEST_BLAS - std::string str_name = "chemv_"; -#elif TEST_CBLAS - std::string str_name = "cblas_chemv"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_chemv"; +// Black box testing. +INSTANTIATE_TEST_SUITE_P( + BlackboxSmall, + chemvGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' #endif - str_name = str_name + "_" + sfm; - str_name = str_name + "_" + uploa+conja+conjx; - str_name = str_name + "_" + std::to_string(n); - std::string alpha_str = ( alpha.real > 0) ? std::to_string(int(alpha.real)) : ("m" + std::to_string(int(std::abs(alpha.real)))); - alpha_str = alpha_str + "pi" + (( alpha.imag > 0) ? std::to_string(int(alpha.imag)) : ("m" + std::to_string(int(std::abs(alpha.imag))))); - std::string beta_str = ( beta.real > 0) ? std::to_string(int(beta.real)) : ("m" + std::to_string(int(std::abs(beta.real)))); - beta_str = beta_str + "pi" + (( beta.imag > 0) ? std::to_string(int(beta.imag)) : ("m" + std::to_string(int(std::abs(beta.imag))))); - str_name = str_name + "_a" + alpha_str; - str_name = str_name + "_a" + beta_str; - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); - str_name = str_name + "_" + incx_str; - str_name = str_name + "_" + incy_str; - str_name = str_name + "_" + std::to_string(ld_inc); - return str_name; - } -}; + ), // storage format + ::testing::Values('u','l'), // uploa + ::testing::Values('n'), // conja + ::testing::Values('n'), // conjx + ::testing::Range(gtint_t(1),gtint_t(21),1), // n + ::testing::Values(scomplex{0.0, 0.0},scomplex{1.0, 0.0}, + scomplex{-1.0, 0.0},scomplex{1.0, -2.0}), // alpha + ::testing::Values(scomplex{0.0, 0.0},scomplex{1.0, 0.0}, + scomplex{-1.0, 0.0},scomplex{1.0, -2.0}), // beta + ::testing::Values(gtint_t(1),gtint_t(-1),gtint_t(2)), // stride size for x + ::testing::Values(gtint_t(1),gtint_t(-1),gtint_t(2)), // stride size for y + ::testing::Values(gtint_t(0), gtint_t(5)) // increment to the leading dim of a + ), + ::hemvGenericPrint() + ); -// Black box testing. INSTANTIATE_TEST_SUITE_P( - Blackbox, - chemvTest, + BlackboxMedium, + chemvGeneric, ::testing::Combine( ::testing::Values('c' -#ifndef TEST_BLAS +#ifndef TEST_BLAS_LIKE ,'r' #endif ), // storage format ::testing::Values('u','l'), // uploa ::testing::Values('n'), // conja ::testing::Values('n'), // conjx - ::testing::Range(gtint_t(10), gtint_t(31), 10), // n - ::testing::Values(scomplex{1.0, -2.0}), // alpha - ::testing::Values(scomplex{2.0, -1.0}), // beta - ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(gtint_t(25), + gtint_t(33), + gtint_t(98), + gtint_t(173), + gtint_t(211) + ), // n + ::testing::Values(scomplex{0.0, 0.0},scomplex{1.0, 0.0}, + scomplex{-1.0, 0.0},scomplex{1.0, -2.0}), // alpha + ::testing::Values(scomplex{0.0, 0.0},scomplex{1.0, 0.0}, + scomplex{-1.0, 0.0},scomplex{1.0, -2.0}), // beta + ::testing::Values(gtint_t(1),gtint_t(-1),gtint_t(2)), // stride size for x + ::testing::Values(gtint_t(1),gtint_t(-1),gtint_t(2)), // stride size for y ::testing::Values(gtint_t(0), gtint_t(5)) // increment to the leading dim of a ), - ::chemvTestPrint() + ::hemvGenericPrint() ); diff --git a/gtestsuite/testsuite/level2/hemv/hemv.h b/gtestsuite/testsuite/level2/hemv/hemv.h index 90086336a7..1c29845965 100644 --- a/gtestsuite/testsuite/level2/hemv/hemv.h +++ b/gtestsuite/testsuite/level2/hemv/hemv.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -36,6 +36,7 @@ #include "blis.h" #include "common/testing_helpers.h" +#include "inc/check_error.h" /** * @brief Performs the operation: @@ -70,6 +71,18 @@ static void hemv_( char uploa, gtint_t n, T* alpha, T* ap, gtint_t lda, throw std::runtime_error("Error in testsuite/level2/hemv.h: Invalid typename in hemv_()."); } +template +static void hemv_blis_impl( char uploa, gtint_t n, T* alpha, T* ap, gtint_t lda, + T* xp, gtint_t incx, T* beta, T* yp, gtint_t incy ) +{ + if constexpr (std::is_same::value) + chemv_blis_impl( &uploa, &n, alpha, ap, &lda, xp, &incx, beta, yp, &incy ); + else if constexpr (std::is_same::value) + zhemv_blis_impl( &uploa, &n, alpha, ap, &lda, xp, &incx, beta, yp, &incy ); + else + throw std::runtime_error("Error in testsuite/level2/hemv.h: Invalid typename in hemv_blis_impl()."); +} + template static void cblas_hemv( char storage, char uploa, gtint_t n, T* alpha, T* ap, gtint_t lda, T* xp, gtint_t incx, T* beta, T* yp, gtint_t incy ) @@ -123,11 +136,54 @@ static void hemv( char storage, char uploa, char conja, char conjx, gtint_t n, T* alpha, T* ap, gtint_t lda, T* xp, gtint_t incx, T* beta, T* yp, gtint_t incy ) { + +#ifdef TEST_UPPERCASE_ARGS + storage = static_cast(std::toupper(static_cast(storage))); + uploa = static_cast(std::toupper(static_cast(uploa))); + conja = static_cast(std::toupper(static_cast(conja))); + conjx = static_cast(std::toupper(static_cast(conjx))); +#endif + +#ifdef TEST_INPUT_ARGS + // Create copy of scalar input values so we can check that they are not altered. + char storage_cpy = storage; + char uploa_cpy = uploa; + char conja_cpy = conja; + char conjx_cpy = conjx; + gtint_t n_cpy = n; + T* alpha_cpy = alpha; + gtint_t lda_cpy = lda; + gtint_t incx_cpy = incx; + T* beta_cpy = beta; + gtint_t incy_cpy = incy; + + // Create copy of input arrays so we can check that they are not altered. + T* ap_cpy = nullptr; + gtint_t size_ap = testinghelpers::matsize( storage, 'n', n, n, lda ); + if (ap && size_ap > 0) + { + ap_cpy = new T[size_ap]; + memcpy( ap_cpy, ap, size_ap * sizeof( T ) ); + } + T* xp_cpy = nullptr; + gtint_t size_xp; + size_xp = testinghelpers::buff_dim( n, incx ); + { + xp_cpy = new T[size_xp]; + memcpy( xp_cpy, xp, size_xp * sizeof( T ) ); + } +#endif + #ifdef TEST_BLAS if( storage == 'c' || storage == 'C' ) hemv_( uploa, n, alpha, ap, lda, xp, incx, beta, yp, incy ); else throw std::runtime_error("Error in testsuite/level2/hemv.h: BLAS interface cannot be tested for row-major order."); +#elif TEST_BLAS_BLIS_IMPL + if( storage == 'c' || storage == 'C' ) + hemv_blis_impl( uploa, n, alpha, ap, lda, xp, incx, beta, yp, incy ); + else + throw std::runtime_error("Error in testsuite/level2/hemv.h: BLAS_BLIS_IMPL interface cannot be tested for row-major order."); #elif TEST_CBLAS cblas_hemv( storage, uploa, n, alpha, ap, lda, xp, incx, beta, yp, incy ); #elif TEST_BLIS_TYPED @@ -135,4 +191,37 @@ static void hemv( char storage, char uploa, char conja, char conjx, gtint_t n, #else throw std::runtime_error("Error in testsuite/level2/hemv.h: No interfaces are set to be tested."); #endif + +#ifdef TEST_INPUT_ARGS + //---------------------------------------------------------- + // Check scalar inputs have not been modified. + //---------------------------------------------------------- + + computediff( "storage", storage, storage_cpy ); + computediff( "uploa", uploa, uploa_cpy ); + computediff( "conja", conja, conja_cpy ); + computediff( "conjx", conjx, conjx_cpy ); + computediff( "n", n, n_cpy ); + if (alpha) computediff( "alpha", *alpha, *alpha_cpy ); + computediff( "lda", lda, lda_cpy ); + computediff( "incx", incx, incx_cpy ); + if (beta) computediff( "beta", *beta, *beta_cpy ); + computediff( "incy", incy, incy_cpy ); + + //---------------------------------------------------------- + // Bitwise-wise check array inputs have not been modified. + //---------------------------------------------------------- + + if (ap && size_ap > 0) + { + computediff( "A", storage, n, n, ap, ap_cpy, lda, true ); + delete[] ap_cpy; + } + + if (xp && size_xp > 0) + { + computediff( "x", n, xp, xp_cpy, incx, true ); + delete[] xp_cpy; + } +#endif } diff --git a/gtestsuite/testsuite/level2/hemv/test_hemv.h b/gtestsuite/testsuite/level2/hemv/test_hemv.h index a7243cbd2e..125f9ca1d5 100644 --- a/gtestsuite/testsuite/level2/hemv/test_hemv.h +++ b/gtestsuite/testsuite/level2/hemv/test_hemv.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -51,13 +51,20 @@ void test_hemv( char storage, char uploa, char conja, char conjx, gtint_t n, // Initialize matrics with random integer numbers. //---------------------------------------------------------- std::vector a = testinghelpers::get_random_matrix( -2, 5, storage, 'n', n, n, lda ); - std::vector x = testinghelpers::get_random_vector( -3, 3, n, incx ); - std::vector y = testinghelpers::get_random_vector( -3, 3, n, incy ); - testinghelpers::make_herm( storage, uploa, n, a.data(), lda ); testinghelpers::make_triangular( storage, uploa, n, a.data(), lda ); - // Create a copy of c so that we can check reference results. + std::vector x = testinghelpers::get_random_vector( -3, 3, n, incx ); + std::vector y( testinghelpers::buff_dim(n, incy) ); + if (beta != testinghelpers::ZERO()) + testinghelpers::datagenerators::randomgenerators( -3, 3, n, incy, y.data() ); + else + { + // Vector Y should not be read, only set. + testinghelpers::set_vector( n, incy, y.data(), testinghelpers::aocl_extreme() ); + } + + // Create a copy of y so that we can check reference results. std::vector y_ref(y); //---------------------------------------------------------- // Call BLIS function @@ -74,5 +81,43 @@ void test_hemv( char storage, char uploa, char conja, char conjx, gtint_t n, //---------------------------------------------------------- // check component-wise error. //---------------------------------------------------------- - computediff( n, y.data(), y_ref.data(), incy, thresh ); + computediff( "y", n, y.data(), y_ref.data(), incy, thresh ); + +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif } + +// Test-case logger : Used to print the test-case details based on parameters +template +class hemvGenericPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char storage = std::get<0>(str.param); + char uploa = std::get<1>(str.param); + char conja = std::get<2>(str.param); + char conjx = std::get<3>(str.param); + gtint_t n = std::get<4>(str.param); + T alpha = std::get<5>(str.param); + T beta = std::get<6>(str.param); + gtint_t incx = std::get<7>(str.param); + gtint_t incy = std::get<8>(str.param); + gtint_t lda_inc = std::get<9>(str.param); + + std::string str_name = API_PRINT; + str_name += "_stor_" + std::string(&storage, 1); + str_name += "_uploa_" + std::string(&uploa, 1); + str_name += "_conja_" + std::string(&conja, 1); + str_name += "_conjx_" + std::string(&conjx, 1); + str_name += "_n_" + std::to_string(n); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + gtint_t lda = testinghelpers::get_leading_dimension( storage, 'n', n, n, lda_inc ); + str_name += "_lda_i" + std::to_string(lda_inc) + "_" + std::to_string(lda); + str_name += "_incx_" + testinghelpers::get_value_string(incx); + str_name += "_beta_" + testinghelpers::get_value_string(beta); + str_name += "_incy_" + testinghelpers::get_value_string(incy); + return str_name; + } +}; diff --git a/gtestsuite/testsuite/level2/hemv/zhemv_generic.cpp b/gtestsuite/testsuite/level2/hemv/zhemv_generic.cpp index 81ee763b24..0fbba22554 100644 --- a/gtestsuite/testsuite/level2/hemv/zhemv_generic.cpp +++ b/gtestsuite/testsuite/level2/hemv/zhemv_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,7 +35,7 @@ #include #include "test_hemv.h" -class zhemvTest : +class zhemvGeneric : public ::testing::TestWithParam> {}; -TEST_P(zhemvTest, RandomData) +TEST_P( zhemvGeneric, API ) { using T = dcomplex; //---------------------------------------------------------- @@ -78,7 +78,24 @@ TEST_P(zhemvTest, RandomData) gtint_t lda_inc = std::get<9>(GetParam()); // Set the threshold for the errors: - double thresh = 8*std::sqrt(n)*testinghelpers::getEpsilon(); + // Check gtestsuite hemv.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // With adjustment applied for complex data. + double thresh; +#ifdef BLIS_INT_ELEMENT_TYPE + double adj = 1.0; +#else + double adj = 2.4; +#endif + if (n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO() && (beta == testinghelpers::ZERO() || beta == testinghelpers::ONE())) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + thresh = testinghelpers::getEpsilon(); + else + thresh = adj*(3*n+1)*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters @@ -86,64 +103,56 @@ TEST_P(zhemvTest, RandomData) test_hemv( storage, uploa, conja, conjx, n, alpha, lda_inc, incx, beta, incy, thresh ); } -class zhemvTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char uploa = std::get<1>(str.param); - char conja = std::get<2>(str.param); - char conjx = std::get<3>(str.param); - gtint_t n = std::get<4>(str.param); - dcomplex alpha = std::get<5>(str.param); - dcomplex beta = std::get<6>(str.param); - gtint_t incx = std::get<7>(str.param); - gtint_t incy = std::get<8>(str.param); - gtint_t ld_inc = std::get<9>(str.param); -#ifdef TEST_BLAS - std::string str_name = "zhemv_"; -#elif TEST_CBLAS - std::string str_name = "cblas_zhemv"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_zhemv"; +// Black box testing. +INSTANTIATE_TEST_SUITE_P( + BlackboxSmall, + zhemvGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' #endif - str_name = str_name + "_" + sfm; - str_name = str_name + "_" + uploa+conja+conjx; - str_name = str_name + "_" + std::to_string(n); - std::string alpha_str = ( alpha.real > 0) ? std::to_string(int(alpha.real)) : ("m" + std::to_string(int(std::abs(alpha.real)))); - alpha_str = alpha_str + "pi" + (( alpha.imag > 0) ? std::to_string(int(alpha.imag)) : ("m" + std::to_string(int(std::abs(alpha.imag))))); - std::string beta_str = ( beta.real > 0) ? std::to_string(int(beta.real)) : ("m" + std::to_string(int(std::abs(beta.real)))); - beta_str = beta_str + "pi" + (( beta.imag > 0) ? std::to_string(int(beta.imag)) : ("m" + std::to_string(int(std::abs(beta.imag))))); - str_name = str_name + "_a" + alpha_str; - str_name = str_name + "_a" + beta_str; - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); - str_name = str_name + "_" + incx_str; - str_name = str_name + "_" + incy_str; - str_name = str_name + "_" + std::to_string(ld_inc); - return str_name; - } -}; + ), // storage format + ::testing::Values('u','l'), // uploa + ::testing::Values('n'), // conja + ::testing::Values('n'), // conjx + ::testing::Range(gtint_t(1),gtint_t(21),1), // n + ::testing::Values(dcomplex{0.0, 0.0},dcomplex{1.0, 0.0}, + dcomplex{-1.0, 0.0},dcomplex{1.0, -2.0}), // alpha + ::testing::Values(dcomplex{0.0, 0.0},dcomplex{1.0, 0.0}, + dcomplex{-1.0, 0.0},dcomplex{1.0, -2.0}), // beta + ::testing::Values(gtint_t(1),gtint_t(-1),gtint_t(2)), // stride size for x + ::testing::Values(gtint_t(1),gtint_t(-1),gtint_t(2)), // stride size for y + ::testing::Values(gtint_t(0), gtint_t(5)) // increment to the leading dim of a + ), + ::hemvGenericPrint() + ); -// Black box testing. INSTANTIATE_TEST_SUITE_P( - Blackbox, - zhemvTest, + BlackboxMedium, + zhemvGeneric, ::testing::Combine( ::testing::Values('c' -#ifndef TEST_BLAS +#ifndef TEST_BLAS_LIKE ,'r' #endif ), // storage format ::testing::Values('u','l'), // uploa ::testing::Values('n'), // conja ::testing::Values('n'), // conjx - ::testing::Range(gtint_t(10), gtint_t(31), 10), // n - ::testing::Values(dcomplex{1.0, -2.0}), // alpha - ::testing::Values(dcomplex{2.0, -1.0}), // beta - ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(gtint_t(25), + gtint_t(33), + gtint_t(98), + gtint_t(173), + gtint_t(211) + ), // n + ::testing::Values(dcomplex{0.0, 0.0},dcomplex{1.0, 0.0}, + dcomplex{-1.0, 0.0},dcomplex{1.0, -2.0}), // alpha + ::testing::Values(dcomplex{0.0, 0.0},dcomplex{1.0, 0.0}, + dcomplex{-1.0, 0.0},dcomplex{1.0, -2.0}), // beta + ::testing::Values(gtint_t(1),gtint_t(-1),gtint_t(2)), // stride size for x + ::testing::Values(gtint_t(1),gtint_t(-1),gtint_t(2)), // stride size for y ::testing::Values(gtint_t(0), gtint_t(5)) // increment to the leading dim of a ), - ::zhemvTestPrint() + ::hemvGenericPrint() ); diff --git a/gtestsuite/testsuite/level2/her/cher_generic.cpp b/gtestsuite/testsuite/level2/her/cher_generic.cpp index 8be6c2ed49..b309baa058 100644 --- a/gtestsuite/testsuite/level2/her/cher_generic.cpp +++ b/gtestsuite/testsuite/level2/her/cher_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,7 +35,7 @@ #include #include "test_her.h" -class cherTest : +class cherGeneric : public ::testing::TestWithParam> {}; -TEST_P(cherTest, RandomData) +TEST_P( cherGeneric, API ) { using T = scomplex; //---------------------------------------------------------- @@ -69,7 +69,20 @@ TEST_P(cherTest, RandomData) gtint_t lda_inc = std::get<6>(GetParam()); // Set the threshold for the errors: - double thresh = 4*std::sqrt(n)*testinghelpers::getEpsilon(); + // Check gtestsuite her.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // With adjustment for complex data. + double thresh; +#ifdef BLIS_INT_ELEMENT_TYPE + double adj = 1.0; +#else + double adj = 2.0; +#endif + if (n == 0 || alpha == 0.0f) + thresh = 0.0; + else + thresh = adj*3*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters @@ -77,52 +90,46 @@ TEST_P(cherTest, RandomData) test_her( storage, uploa, conjx, n, alpha, incx, lda_inc, thresh ); } -class cherTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char uploa = std::get<1>(str.param); - char conjx = std::get<2>(str.param); - gtint_t n = std::get<3>(str.param); - float alpha = std::get<4>(str.param); - gtint_t incx = std::get<5>(str.param); - gtint_t ld_inc = std::get<6>(str.param); -#ifdef TEST_BLAS - std::string str_name = "cher_"; -#elif TEST_CBLAS - std::string str_name = "cblas_cher"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_cher"; +// Black box testing. +INSTANTIATE_TEST_SUITE_P( + BlackboxSmall, + cherGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' #endif - str_name = str_name + "_" + sfm; - str_name = str_name + "_" + uploa+conjx; - str_name = str_name + "_" + std::to_string(n); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name = str_name + "_" + incx_str; - std::string alpha_str = ( alpha > 0) ? std::to_string(int(alpha)) : ("m" + std::to_string(int(std::abs(alpha)))); - str_name = str_name + "_a" + alpha_str; - str_name = str_name + "_" + std::to_string(ld_inc); - return str_name; - } -}; + ), // storage format + ::testing::Values('u','l'), // uploa + ::testing::Values('n'), // conjx + ::testing::Range(gtint_t(1),gtint_t(21),1), // n + ::testing::Values( 0.0, 1.0, -1.0, 2.7 ), // alpha + ::testing::Values(gtint_t(1),gtint_t(-1),gtint_t(2)), // stride size for x + ::testing::Values(gtint_t(0), gtint_t(5)) // increment to the leading dim of a + ), + ::herGenericPrint() + ); -// Black box testing. INSTANTIATE_TEST_SUITE_P( - Blackbox, - cherTest, + BlackboxMedium, + cherGeneric, ::testing::Combine( ::testing::Values('c' -#ifndef TEST_BLAS +#ifndef TEST_BLAS_LIKE ,'r' #endif ), // storage format ::testing::Values('u','l'), // uploa ::testing::Values('n'), // conjx - ::testing::Range(gtint_t(10), gtint_t(31), 10), // n - ::testing::Values(1.0), // alpha - ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(0), gtint_t(3)) // increment to the leading dim of a + ::testing::Values(gtint_t(25), + gtint_t(33), + gtint_t(98), + gtint_t(173), + gtint_t(211) + ), // n + ::testing::Values( 0.0, 1.0, -1.0, 2.7 ), // alpha + ::testing::Values(gtint_t(1),gtint_t(-1),gtint_t(2)), // stride size for x + ::testing::Values(gtint_t(0), gtint_t(5)) // increment to the leading dim of a ), - ::cherTestPrint() + ::herGenericPrint() ); diff --git a/gtestsuite/testsuite/level2/her/her.h b/gtestsuite/testsuite/level2/her/her.h index ea7d3008c7..6c5d1e1c9b 100644 --- a/gtestsuite/testsuite/level2/her/her.h +++ b/gtestsuite/testsuite/level2/her/her.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -36,6 +36,7 @@ #include "blis.h" #include "common/testing_helpers.h" +#include "inc/check_error.h" /** * @brief Performs the operation: @@ -61,6 +62,18 @@ static void her_( char uploa, gtint_t n, Tr* alpha, T* xp, gtint_t incx, throw std::runtime_error("Error in testsuite/level2/her.h: Invalid typename in her_()."); } +template +static void her_blis_impl( char uploa, gtint_t n, Tr* alpha, T* xp, gtint_t incx, + T* ap, gtint_t lda ) +{ + if constexpr (std::is_same::value) + cher_blis_impl( &uploa, &n, alpha, xp, &incx, ap, &lda ); + else if constexpr (std::is_same::value) + zher_blis_impl( &uploa, &n, alpha, xp, &incx, ap, &lda ); + else + throw std::runtime_error("Error in testsuite/level2/her.h: Invalid typename in her_blis_impl()."); +} + template static void cblas_her( char storage, char uploa, gtint_t n, Tr* alpha, T* xp, gtint_t incx, T* ap, gtint_t lda ) @@ -111,11 +124,43 @@ template static void her( char storage, char uploa, char conj_x, gtint_t n, Tr* alpha, T* xp, gtint_t incx, T* ap, gtint_t lda ) { + +#ifdef TEST_UPPERCASE_ARGS + storage = static_cast(std::toupper(static_cast(storage))); + uploa = static_cast(std::toupper(static_cast(uploa))); + conj_x = static_cast(std::toupper(static_cast(conj_x))); +#endif + +#ifdef TEST_INPUT_ARGS + // Create copy of scalar input values so we can check that they are not altered. + char storage_cpy = storage; + char uploa_cpy = uploa; + char conj_x_cpy = conj_x; + gtint_t n_cpy = n; + Tr* alpha_cpy = alpha; + gtint_t incx_cpy = incx; + gtint_t lda_cpy = lda; + + // Create copy of input arrays so we can check that they are not altered. + T* xp_cpy = nullptr; + gtint_t size_xp; + size_xp = testinghelpers::buff_dim( n, incx ); + { + xp_cpy = new T[size_xp]; + memcpy( xp_cpy, xp, size_xp * sizeof( T ) ); + } +#endif + #ifdef TEST_BLAS if( storage == 'c' || storage == 'C' ) her_( uploa, n, alpha, xp, incx, ap, lda ); else throw std::runtime_error("Error in testsuite/level2/her.h: BLAS interface cannot be tested for row-major order."); +#elif TEST_BLAS_BLIS_IMPL + if( storage == 'c' || storage == 'C' ) + her_blis_impl( uploa, n, alpha, xp, incx, ap, lda ); + else + throw std::runtime_error("Error in testsuite/level2/her.h: BLAS_BLIS_IMPL interface cannot be tested for row-major order."); #elif TEST_CBLAS cblas_her( storage, uploa, n, alpha, xp, incx, ap, lda ); #elif TEST_BLIS_TYPED @@ -123,4 +168,28 @@ static void her( char storage, char uploa, char conj_x, gtint_t n, #else throw std::runtime_error("Error in testsuite/level2/her.h: No interfaces are set to be tested."); #endif + +#ifdef TEST_INPUT_ARGS + //---------------------------------------------------------- + // Check scalar inputs have not been modified. + //---------------------------------------------------------- + + computediff( "storage", storage, storage_cpy ); + computediff( "uploa", uploa, uploa_cpy ); + computediff( "conj_x", conj_x, conj_x_cpy ); + computediff( "n", n, n_cpy ); + computediff( "alpha", *alpha, *alpha_cpy ); + computediff( "incx", incx, incx_cpy ); + computediff( "lda", lda, lda_cpy ); + + //---------------------------------------------------------- + // Bitwise-wise check array inputs have not been modified. + //---------------------------------------------------------- + + if (xp && size_xp > 0) + { + computediff( "x", n, xp, xp_cpy, incx, true ); + delete[] xp_cpy; + } +#endif } diff --git a/gtestsuite/testsuite/level2/her/test_her.h b/gtestsuite/testsuite/level2/her/test_her.h index db41652975..b60d2e3650 100644 --- a/gtestsuite/testsuite/level2/her/test_her.h +++ b/gtestsuite/testsuite/level2/her/test_her.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -71,5 +71,37 @@ void test_her( char storage, char uploa, char conjx, gtint_t n, Tr alpha, //---------------------------------------------------------- // check component-wise error. //---------------------------------------------------------- - computediff( storage, n, n, a.data(), a_ref.data(), lda, thresh ); + computediff( "A", storage, n, n, a.data(), a_ref.data(), lda, thresh ); + +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif } + +// Test-case logger : Used to print the test-case details based on parameters +template +class herGenericPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char storage = std::get<0>(str.param); + char uploa = std::get<1>(str.param); + char conjx = std::get<2>(str.param); + gtint_t n = std::get<3>(str.param); + T alpha = std::get<4>(str.param); + gtint_t incx = std::get<5>(str.param); + gtint_t lda_inc = std::get<6>(str.param); + + std::string str_name = API_PRINT; + str_name += "_stor_" + std::string(&storage, 1); + str_name += "_uploa_" + std::string(&uploa, 1); + str_name += "_conjx_" + std::string(&conjx, 1); + str_name += "_n_" + std::to_string(n); + str_name += "_incx_" + testinghelpers::get_value_string(incx); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + gtint_t lda = testinghelpers::get_leading_dimension( storage, 'n', n, n, lda_inc ); + str_name += "_lda_i" + std::to_string(lda_inc) + "_" + std::to_string(lda); + return str_name; + } +}; diff --git a/gtestsuite/testsuite/level2/her/zher_generic.cpp b/gtestsuite/testsuite/level2/her/zher_generic.cpp index 8db149caa5..9a9eb90d9d 100644 --- a/gtestsuite/testsuite/level2/her/zher_generic.cpp +++ b/gtestsuite/testsuite/level2/her/zher_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,7 +35,7 @@ #include #include "test_her.h" -class zherTest : +class zherGeneric : public ::testing::TestWithParam> {}; -TEST_P(zherTest, RandomData) +TEST_P( zherGeneric, API ) { using T = dcomplex; //---------------------------------------------------------- @@ -69,7 +69,20 @@ TEST_P(zherTest, RandomData) gtint_t lda_inc = std::get<6>(GetParam()); // Set the threshold for the errors: - double thresh = 4*std::sqrt(n)*testinghelpers::getEpsilon(); + // Check gtestsuite her.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // With adjustment for complex data. + double thresh; +#ifdef BLIS_INT_ELEMENT_TYPE + double adj = 1.0; +#else + double adj = 2.0; +#endif + if (n == 0 || alpha == 0.0) + thresh = 0.0; + else + thresh = adj*3*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters @@ -77,52 +90,46 @@ TEST_P(zherTest, RandomData) test_her( storage, uploa, conjx, n, alpha, incx, lda_inc, thresh ); } -class zherTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char uploa = std::get<1>(str.param); - char conjx = std::get<2>(str.param); - gtint_t n = std::get<3>(str.param); - double alpha = std::get<4>(str.param); - gtint_t incx = std::get<5>(str.param); - gtint_t ld_inc = std::get<6>(str.param); -#ifdef TEST_BLAS - std::string str_name = "zher_"; -#elif TEST_CBLAS - std::string str_name = "cblas_zher"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_zher"; +// Black box testing. +INSTANTIATE_TEST_SUITE_P( + BlackboxSmall, + zherGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' #endif - str_name = str_name + "_" + sfm; - str_name = str_name + "_" + uploa+conjx; - str_name = str_name + "_" + std::to_string(n); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name = str_name + "_" + incx_str; - std::string alpha_str = ( alpha > 0) ? std::to_string(int(alpha)) : ("m" + std::to_string(int(std::abs(alpha)))); - str_name = str_name + "_a" + alpha_str; - str_name = str_name + "_" + std::to_string(ld_inc); - return str_name; - } -}; + ), // storage format + ::testing::Values('u','l'), // uploa + ::testing::Values('n'), // conjx + ::testing::Range(gtint_t(1),gtint_t(21),1), // n + ::testing::Values( 0.0, 1.0, -1.0, 2.7 ), // alpha + ::testing::Values(gtint_t(1),gtint_t(-1),gtint_t(2)), // stride size for x + ::testing::Values(gtint_t(0), gtint_t(5)) // increment to the leading dim of a + ), + ::herGenericPrint() + ); -// Black box testing. INSTANTIATE_TEST_SUITE_P( - Blackbox, - zherTest, + BlackboxMedium, + zherGeneric, ::testing::Combine( ::testing::Values('c' -#ifndef TEST_BLAS +#ifndef TEST_BLAS_LIKE ,'r' #endif ), // storage format ::testing::Values('u','l'), // uploa ::testing::Values('n'), // conjx - ::testing::Range(gtint_t(10), gtint_t(31), 10), // n - ::testing::Values(1.0), // alpha - ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(0), gtint_t(2)) // increment to the leading dim of a + ::testing::Values(gtint_t(25), + gtint_t(33), + gtint_t(98), + gtint_t(173), + gtint_t(211) + ), // n + ::testing::Values( 0.0, 1.0, -1.0, 2.7 ), // alpha + ::testing::Values(gtint_t(1),gtint_t(-1),gtint_t(2)), // stride size for x + ::testing::Values(gtint_t(0), gtint_t(5)) // increment to the leading dim of a ), - ::zherTestPrint() + ::herGenericPrint() ); diff --git a/gtestsuite/testsuite/level2/her2/cher2_generic.cpp b/gtestsuite/testsuite/level2/her2/cher2_generic.cpp index f6bbd15a06..4ffbaa1a91 100644 --- a/gtestsuite/testsuite/level2/her2/cher2_generic.cpp +++ b/gtestsuite/testsuite/level2/her2/cher2_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,7 +35,7 @@ #include #include "test_her2.h" -class cher2Test : +class cher2Generic : public ::testing::TestWithParam> {}; -TEST_P(cher2Test, RandomData) +TEST_P( cher2Generic, API ) { using T = scomplex; //---------------------------------------------------------- @@ -75,7 +75,23 @@ TEST_P(cher2Test, RandomData) gtint_t lda_inc = std::get<8>(GetParam()); // Set the threshold for the errors: - double thresh = 4*n*testinghelpers::getEpsilon(); + // Check gtestsuite her2.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // With adjustment for complex data. + double thresh; +#ifdef BLIS_INT_ELEMENT_TYPE + double adj = 1.0; +#else + double adj = 4.0; + #ifdef REF_IS_MKL + adj = 6.0; + #endif +#endif + if (n == 0 || alpha == testinghelpers::ZERO()) + thresh = 0.0; + else + thresh = adj*6*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters @@ -83,59 +99,52 @@ TEST_P(cher2Test, RandomData) test_her2( storage, uploa, conjx, conjy, n, alpha, incx, incy, lda_inc, thresh ); } -class cher2TestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char uploa = std::get<1>(str.param); - char conjx = std::get<2>(str.param); - char conjy = std::get<3>(str.param); - gtint_t n = std::get<4>(str.param); - scomplex alpha = std::get<5>(str.param); - gtint_t incx = std::get<6>(str.param); - gtint_t incy = std::get<7>(str.param); - gtint_t ld_inc = std::get<8>(str.param); -#ifdef TEST_BLAS - std::string str_name = "cher2_"; -#elif TEST_CBLAS - std::string str_name = "cblas_cher2"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_cher2"; +// Black box testing. +INSTANTIATE_TEST_SUITE_P( + BlackboxSmall, + cher2Generic, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' #endif - str_name = str_name + "_" + sfm; - str_name = str_name + "_" + uploa+conjx+conjy; - str_name = str_name + "_" + std::to_string(n); - std::string alpha_str = ( alpha.real > 0) ? std::to_string(int(alpha.real)) : ("m" + std::to_string(int(std::abs(alpha.real)))); - alpha_str = alpha_str + "pi" + (( alpha.imag > 0) ? std::to_string(int(alpha.imag)) : ("m" + std::to_string(int(std::abs(alpha.imag))))); - str_name = str_name + "_a" + alpha_str; - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); - str_name = str_name + "_" + incx_str; - str_name = str_name + "_" + incy_str; - str_name = str_name + "_" + std::to_string(ld_inc); - return str_name; - } -}; + ), // storage format + ::testing::Values('u','l'), // uploa + ::testing::Values('n'), // conja + ::testing::Values('n'), // conjx + ::testing::Range(gtint_t(1),gtint_t(21),1), // n + ::testing::Values(scomplex{0.0, 0.0},scomplex{1.0, 0.0}, + scomplex{-1.0, 0.0},scomplex{1.0, -2.0}), // alpha + ::testing::Values(gtint_t(1),gtint_t(-1),gtint_t(2)), // stride size for x + ::testing::Values(gtint_t(1),gtint_t(-1),gtint_t(2)), // stride size for y + ::testing::Values(gtint_t(0), gtint_t(5)) // increment to the leading dim of a + ), + ::her2GenericPrint() + ); -// Black box testing. INSTANTIATE_TEST_SUITE_P( - Blackbox, - cher2Test, + BlackboxMedium, + cher2Generic, ::testing::Combine( ::testing::Values('c' -#ifndef TEST_BLAS +#ifndef TEST_BLAS_LIKE ,'r' #endif ), // storage format ::testing::Values('u','l'), // uploa + ::testing::Values('n'), // conja ::testing::Values('n'), // conjx - ::testing::Values('n'), // conjy - ::testing::Range(gtint_t(10), gtint_t(31), 10), // n - ::testing::Values(scomplex{1.0, -2.0}), // alpha - ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(gtint_t(0), gtint_t(2)) // increment to the leading dim of a + ::testing::Values(gtint_t(25), + gtint_t(33), + gtint_t(98), + gtint_t(173), + gtint_t(211) + ), // n + ::testing::Values(scomplex{0.0, 0.0},scomplex{1.0, 0.0}, + scomplex{-1.0, 0.0},scomplex{1.0, -2.0}), // alpha + ::testing::Values(gtint_t(1),gtint_t(-1),gtint_t(2)), // stride size for x + ::testing::Values(gtint_t(1),gtint_t(-1),gtint_t(2)), // stride size for y + ::testing::Values(gtint_t(0), gtint_t(5)) // increment to the leading dim of a ), - ::cher2TestPrint() + ::her2GenericPrint() ); diff --git a/gtestsuite/testsuite/level2/her2/her2.h b/gtestsuite/testsuite/level2/her2/her2.h index 759b2d90d2..09b71533bc 100644 --- a/gtestsuite/testsuite/level2/her2/her2.h +++ b/gtestsuite/testsuite/level2/her2/her2.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -36,6 +36,7 @@ #include "blis.h" #include "common/testing_helpers.h" +#include "inc/check_error.h" /** * @brief Performs the operation: @@ -64,6 +65,18 @@ static void her2_( char uploa, gtint_t n, T* alpha, T* xp, gtint_t incx, throw std::runtime_error("Error in testsuite/level2/her2.h: Invalid typename in her2_()."); } +template +static void her2_blis_impl( char uploa, gtint_t n, T* alpha, T* xp, gtint_t incx, + T* yp, gtint_t incy, T* ap, gtint_t lda ) +{ + if constexpr (std::is_same::value) + cher2_blis_impl( &uploa, &n, alpha, xp, &incx, yp, &incy, ap, &lda ); + else if constexpr (std::is_same::value) + zher2_blis_impl( &uploa, &n, alpha, xp, &incx, yp, &incy, ap, &lda ); + else + throw std::runtime_error("Error in testsuite/level2/her2.h: Invalid typename in her2_blis_impl()."); +} + template static void cblas_her2( char storage, char uploa, gtint_t n, T* alpha, T* xp, gtint_t incx, T* yp, gtint_t incy, T* ap, gtint_t lda ) @@ -116,11 +129,53 @@ template static void her2( char storage, char uploa, char conj_x, char conj_y, gtint_t n, T* alpha, T* xp, gtint_t incx, T* yp, gtint_t incy, T* ap, gtint_t lda ) { + +#ifdef TEST_UPPERCASE_ARGS + storage = static_cast(std::toupper(static_cast(storage))); + uploa = static_cast(std::toupper(static_cast(uploa))); + conj_x = static_cast(std::toupper(static_cast(conj_x))); + conj_y = static_cast(std::toupper(static_cast(conj_y))); +#endif + +#ifdef TEST_INPUT_ARGS + // Create copy of scalar input values so we can check that they are not altered. + char storage_cpy = storage; + char uploa_cpy = uploa; + char conj_x_cpy = conj_x; + char conj_y_cpy = conj_y; + gtint_t n_cpy = n; + T* alpha_cpy = alpha; + gtint_t incx_cpy = incx; + gtint_t incy_cpy = incy; + gtint_t lda_cpy = lda; + + // Create copy of input arrays so we can check that they are not altered. + T* xp_cpy = nullptr; + gtint_t size_xp; + size_xp = testinghelpers::buff_dim( n, incx ); + { + xp_cpy = new T[size_xp]; + memcpy( xp_cpy, xp, size_xp * sizeof( T ) ); + } + T* yp_cpy = nullptr; + gtint_t size_yp; + size_yp = testinghelpers::buff_dim( n, incy ); + { + yp_cpy = new T[size_yp]; + memcpy( yp_cpy, yp, size_yp * sizeof( T ) ); + } +#endif + #ifdef TEST_BLAS if( storage == 'c' || storage == 'C' ) her2_( uploa, n, alpha, xp, incx, yp, incy, ap, lda ); else throw std::runtime_error("Error in testsuite/level2/her2.h: BLAS interface cannot be tested for row-major order."); +#elif TEST_BLAS_BLIS_IMPL + if( storage == 'c' || storage == 'C' ) + her2_blis_impl( uploa, n, alpha, xp, incx, yp, incy, ap, lda ); + else + throw std::runtime_error("Error in testsuite/level2/her2.h: BLAS_BLIS_IMPL interface cannot be tested for row-major order."); #elif TEST_CBLAS cblas_her2( storage, uploa, n, alpha, xp, incx, yp, incy, ap, lda ); #elif TEST_BLIS_TYPED @@ -128,4 +183,36 @@ static void her2( char storage, char uploa, char conj_x, char conj_y, gtint_t n, #else throw std::runtime_error("Error in testsuite/level2/her2.h: No interfaces are set to be tested."); #endif + +#ifdef TEST_INPUT_ARGS + //---------------------------------------------------------- + // Check scalar inputs have not been modified. + //---------------------------------------------------------- + + computediff( "storage", storage, storage_cpy ); + computediff( "uploa", uploa, uploa_cpy ); + computediff( "conj_x", conj_x, conj_x_cpy ); + computediff( "conj_y", conj_y, conj_y_cpy ); + computediff( "n", n, n_cpy ); + if (alpha) computediff( "alpha", *alpha, *alpha_cpy ); + computediff( "lda", lda, lda_cpy ); + computediff( "incx", incx, incx_cpy ); + computediff( "incy", incy, incy_cpy ); + + //---------------------------------------------------------- + // Bitwise-wise check array inputs have not been modified. + //---------------------------------------------------------- + + if (xp && size_xp > 0) + { + computediff( "x", n, xp, xp_cpy, incx, true ); + delete[] xp_cpy; + } + + if (yp && size_yp > 0) + { + computediff( "y", n, yp, yp_cpy, incy, true ); + delete[] yp_cpy; + } +#endif } diff --git a/gtestsuite/testsuite/level2/her2/test_her2.h b/gtestsuite/testsuite/level2/her2/test_her2.h index b0802d64b4..5634ed8733 100644 --- a/gtestsuite/testsuite/level2/her2/test_her2.h +++ b/gtestsuite/testsuite/level2/her2/test_her2.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -74,5 +74,41 @@ void test_her2( char storage, char uploa, char conjx, char conjy, gtint_t n, //---------------------------------------------------------- // check component-wise error. //---------------------------------------------------------- - computediff( storage, n, n, a.data(), a_ref.data(), lda, thresh ); + computediff( "A", storage, n, n, a.data(), a_ref.data(), lda, thresh ); + +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif } + +// Test-case logger : Used to print the test-case details based on parameters +template +class her2GenericPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char storage = std::get<0>(str.param); + char uploa = std::get<1>(str.param); + char conjx = std::get<2>(str.param); + char conjy = std::get<3>(str.param); + gtint_t n = std::get<4>(str.param); + T alpha = std::get<5>(str.param); + gtint_t incx = std::get<6>(str.param); + gtint_t incy = std::get<7>(str.param); + gtint_t lda_inc = std::get<8>(str.param); + + std::string str_name = API_PRINT; + str_name += "_stor_" + std::string(&storage, 1); + str_name += "_uploa_" + std::string(&uploa, 1); + str_name += "_conjx_" + std::string(&conjx, 1); + str_name += "_conjy_" + std::string(&conjy, 1); + str_name += "_n_" + std::to_string(n); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + str_name += "_incx_" + testinghelpers::get_value_string(incx); + str_name += "_incy_" + testinghelpers::get_value_string(incy); + gtint_t lda = testinghelpers::get_leading_dimension( storage, 'n', n, n, lda_inc ); + str_name += "_lda_i" + std::to_string(lda_inc) + "_" + std::to_string(lda); + return str_name; + } +}; diff --git a/gtestsuite/testsuite/level2/her2/zher2_generic.cpp b/gtestsuite/testsuite/level2/her2/zher2_generic.cpp index acd8b4465a..c8cdcc7262 100644 --- a/gtestsuite/testsuite/level2/her2/zher2_generic.cpp +++ b/gtestsuite/testsuite/level2/her2/zher2_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,7 +35,7 @@ #include #include "test_her2.h" -class zher2Test : +class zher2Generic : public ::testing::TestWithParam> {}; -TEST_P(zher2Test, RandomData) +TEST_P( zher2Generic, API ) { using T = dcomplex; //---------------------------------------------------------- @@ -75,7 +75,20 @@ TEST_P(zher2Test, RandomData) gtint_t lda_inc = std::get<8>(GetParam()); // Set the threshold for the errors: - double thresh = 6*std::sqrt(n)*testinghelpers::getEpsilon(); + // Check gtestsuite her2.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // With adjustment for complex data. + double thresh; +#ifdef BLIS_INT_ELEMENT_TYPE + double adj = 1.0; +#else + double adj = 6.0; +#endif + if (n == 0 || alpha == testinghelpers::ZERO()) + thresh = 0.0; + else + thresh = adj*6*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters @@ -83,59 +96,52 @@ TEST_P(zher2Test, RandomData) test_her2( storage, uploa, conjx, conjy, n, alpha, incx, incy, lda_inc, thresh ); } -class zher2TestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char uploa = std::get<1>(str.param); - char conjx = std::get<2>(str.param); - char conjy = std::get<3>(str.param); - gtint_t n = std::get<4>(str.param); - dcomplex alpha = std::get<5>(str.param); - gtint_t incx = std::get<6>(str.param); - gtint_t incy = std::get<7>(str.param); - gtint_t ld_inc = std::get<8>(str.param); -#ifdef TEST_BLAS - std::string str_name = "zher2_"; -#elif TEST_CBLAS - std::string str_name = "cblas_zher2"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_zher2"; +// Black box testing. +INSTANTIATE_TEST_SUITE_P( + BlackboxSmall, + zher2Generic, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' #endif - str_name = str_name + "_" + sfm; - str_name = str_name + "_" + uploa+conjx+conjy; - str_name = str_name + "_" + std::to_string(n); - std::string alpha_str = ( alpha.real > 0) ? std::to_string(int(alpha.real)) : ("m" + std::to_string(int(std::abs(alpha.real)))); - alpha_str = alpha_str + "pi" + (( alpha.imag > 0) ? std::to_string(int(alpha.imag)) : ("m" + std::to_string(int(std::abs(alpha.imag))))); - str_name = str_name + "_a" + alpha_str; - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); - str_name = str_name + "_" + incx_str; - str_name = str_name + "_" + incy_str; - str_name = str_name + "_" + std::to_string(ld_inc); - return str_name; - } -}; + ), // storage format + ::testing::Values('u','l'), // uploa + ::testing::Values('n'), // conja + ::testing::Values('n'), // conjx + ::testing::Range(gtint_t(1),gtint_t(21),1), // n + ::testing::Values(dcomplex{0.0, 0.0},dcomplex{1.0, 0.0}, + dcomplex{-1.0, 0.0},dcomplex{1.0, -2.0}), // alpha + ::testing::Values(gtint_t(1),gtint_t(-1),gtint_t(2)), // stride size for x + ::testing::Values(gtint_t(1),gtint_t(-1),gtint_t(2)), // stride size for y + ::testing::Values(gtint_t(0), gtint_t(5)) // increment to the leading dim of a + ), + ::her2GenericPrint() + ); -// Black box testing. INSTANTIATE_TEST_SUITE_P( - Blackbox, - zher2Test, + BlackboxMedium, + zher2Generic, ::testing::Combine( ::testing::Values('c' -#ifndef TEST_BLAS +#ifndef TEST_BLAS_LIKE ,'r' #endif ), // storage format ::testing::Values('u','l'), // uploa + ::testing::Values('n'), // conja ::testing::Values('n'), // conjx - ::testing::Values('n'), // conjy - ::testing::Range(gtint_t(10), gtint_t(31), 10), // n - ::testing::Values(dcomplex{1.0, -2.0}), // alpha - ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(gtint_t(0), gtint_t(5)) // increment to the leading dim of a + ::testing::Values(gtint_t(25), + gtint_t(33), + gtint_t(98), + gtint_t(173), + gtint_t(211) + ), // n + ::testing::Values(dcomplex{0.0, 0.0},dcomplex{1.0, 0.0}, + dcomplex{-1.0, 0.0},dcomplex{1.0, -2.0}), // alpha + ::testing::Values(gtint_t(1),gtint_t(-1),gtint_t(2)), // stride size for x + ::testing::Values(gtint_t(1),gtint_t(-1),gtint_t(2)), // stride size for y + ::testing::Values(gtint_t(0), gtint_t(5)) // increment to the leading dim of a ), - ::zher2TestPrint() + ::her2GenericPrint() ); diff --git a/gtestsuite/testsuite/level2/symv/dsymv_generic.cpp b/gtestsuite/testsuite/level2/symv/dsymv_generic.cpp index a62f20996d..6669c353af 100644 --- a/gtestsuite/testsuite/level2/symv/dsymv_generic.cpp +++ b/gtestsuite/testsuite/level2/symv/dsymv_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,7 +35,7 @@ #include #include "test_symv.h" -class dsymvTest : +class dsymvGeneric : public ::testing::TestWithParam> {}; -TEST_P(dsymvTest, RandomData) +TEST_P( dsymvGeneric, API ) { using T = double; //---------------------------------------------------------- @@ -78,7 +78,26 @@ TEST_P(dsymvTest, RandomData) gtint_t lda_inc = std::get<9>(GetParam()); // Set the threshold for the errors: - double thresh = 10*n*testinghelpers::getEpsilon(); + // Check gtestsuite symv.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; +#ifdef BLIS_INT_ELEMENT_TYPE + double adj = 1.4; +#else + double adj = 1.7; + #ifdef REF_IS_MKL + adj = 1.4; + #endif +#endif + if (n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO() && (beta == testinghelpers::ZERO() || beta == testinghelpers::ONE())) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + thresh = testinghelpers::getEpsilon(); + else + thresh = adj*(3*n+1)*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters @@ -86,62 +105,52 @@ TEST_P(dsymvTest, RandomData) test_symv( storage, uploa, conja, conjx, n, alpha, lda_inc, incx, beta, incy, thresh ); } -class dsymvTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char uploa = std::get<1>(str.param); - char conja = std::get<2>(str.param); - char conjx = std::get<3>(str.param); - gtint_t n = std::get<4>(str.param); - double alpha = std::get<5>(str.param); - double beta = std::get<6>(str.param); - gtint_t incx = std::get<7>(str.param); - gtint_t incy = std::get<8>(str.param); - gtint_t ld_inc = std::get<9>(str.param); -#ifdef TEST_BLAS - std::string str_name = "dsymv_"; -#elif TEST_CBLAS - std::string str_name = "cblas_dsymv"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_dsymv"; +// Black box testing. +INSTANTIATE_TEST_SUITE_P( + BlackboxSmall, + dsymvGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' #endif - str_name = str_name + "_" + sfm; - str_name = str_name + "_" + uploa+conja+conjx; - str_name = str_name + "_" + std::to_string(n); - std::string alpha_str = ( alpha > 0) ? std::to_string(int(alpha)) : ("m" + std::to_string(int(std::abs(alpha)))); - std::string beta_str = ( beta > 0) ? std::to_string(int(beta)) : ("m" + std::to_string(int(std::abs(beta)))); - str_name = str_name + "_a" + alpha_str; - str_name = str_name + "_a" + beta_str; - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); - str_name = str_name + "_" + incx_str; - str_name = str_name + "_" + incy_str; - str_name = str_name + "_" + std::to_string(ld_inc); - return str_name; - } -}; + ), // storage format + ::testing::Values('u','l'), // uploa + ::testing::Values('n'), // conja + ::testing::Values('n'), // conjx + ::testing::Range(gtint_t(1),gtint_t(21),1), // n + ::testing::Values( 0.0, 1.0, -1.0, 2.7 ), // alpha + ::testing::Values( 0.0, 1.0, -1.0, 2.7 ), // beta + ::testing::Values(gtint_t(1),gtint_t(-1),gtint_t(2)), // stride size for x + ::testing::Values(gtint_t(1),gtint_t(-1),gtint_t(2)), // stride size for y + ::testing::Values(gtint_t(0), gtint_t(5)) // increment to the leading dim of a + ), + ::symvGenericPrint() + ); -// Black box testing. INSTANTIATE_TEST_SUITE_P( - Blackbox, - dsymvTest, + BlackboxMedium, + dsymvGeneric, ::testing::Combine( ::testing::Values('c' -#ifndef TEST_BLAS +#ifndef TEST_BLAS_LIKE ,'r' #endif ), // storage format ::testing::Values('u','l'), // uploa ::testing::Values('n'), // conja ::testing::Values('n'), // conjx - ::testing::Range(gtint_t(10), gtint_t(31), 10), // n - ::testing::Values( 1.0, -2.0 ), // alpha - ::testing::Values( 2.0, -1.0 ), // beta - ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(gtint_t(0), gtint_t(3)) // increment to the leading dim of a + ::testing::Values(gtint_t(25), + gtint_t(33), + gtint_t(98), + gtint_t(173), + gtint_t(211) + ), // n + ::testing::Values( 0.0, 1.0, -1.0, 2.7 ), // alpha + ::testing::Values( 0.0, 1.0, -1.0, 2.7 ), // beta + ::testing::Values(gtint_t(1),gtint_t(-1),gtint_t(2)), // stride size for x + ::testing::Values(gtint_t(1),gtint_t(-1),gtint_t(2)), // stride size for y + ::testing::Values(gtint_t(0), gtint_t(5)) // increment to the leading dim of a ), - ::dsymvTestPrint() + ::symvGenericPrint() ); diff --git a/gtestsuite/testsuite/level2/symv/ssymv_generic.cpp b/gtestsuite/testsuite/level2/symv/ssymv_generic.cpp index d83d75b7dc..6b4dd5cfba 100644 --- a/gtestsuite/testsuite/level2/symv/ssymv_generic.cpp +++ b/gtestsuite/testsuite/level2/symv/ssymv_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,7 +35,7 @@ #include #include "test_symv.h" -class ssymvTest : +class ssymvGeneric : public ::testing::TestWithParam> {}; -TEST_P(ssymvTest, RandomData) +TEST_P( ssymvGeneric, API ) { using T = float; //---------------------------------------------------------- @@ -78,7 +78,26 @@ TEST_P(ssymvTest, RandomData) gtint_t lda_inc = std::get<9>(GetParam()); // Set the threshold for the errors: - double thresh = 10*n*testinghelpers::getEpsilon(); + // Check gtestsuite symv.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; +#ifdef BLIS_INT_ELEMENT_TYPE + double adj = 3.4; +#else + double adj = 2.0; + #ifdef REF_IS_MKL + adj = 1.4; + #endif +#endif + if (n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO() && (beta == testinghelpers::ZERO() || beta == testinghelpers::ONE())) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + thresh = testinghelpers::getEpsilon(); + else + thresh = adj*(3*n+1)*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters @@ -86,62 +105,52 @@ TEST_P(ssymvTest, RandomData) test_symv( storage, uploa, conja, conjx, n, alpha, lda_inc, incx, beta, incy, thresh ); } -class ssymvTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char uploa = std::get<1>(str.param); - char conja = std::get<2>(str.param); - char conjx = std::get<3>(str.param); - gtint_t n = std::get<4>(str.param); - float alpha = std::get<5>(str.param); - float beta = std::get<6>(str.param); - gtint_t incx = std::get<7>(str.param); - gtint_t incy = std::get<8>(str.param); - gtint_t ld_inc = std::get<9>(str.param); -#ifdef TEST_BLAS - std::string str_name = "ssymv_"; -#elif TEST_CBLAS - std::string str_name = "cblas_ssymv"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_ssymv"; +// Black box testing. +INSTANTIATE_TEST_SUITE_P( + BlackboxSmall, + ssymvGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' #endif - str_name = str_name + "_" + sfm; - str_name = str_name + "_" + uploa+conja+conjx; - str_name = str_name + "_" + std::to_string(n); - std::string alpha_str = ( alpha > 0) ? std::to_string(int(alpha)) : ("m" + std::to_string(int(std::abs(alpha)))); - std::string beta_str = ( beta > 0) ? std::to_string(int(beta)) : ("m" + std::to_string(int(std::abs(beta)))); - str_name = str_name + "_a" + alpha_str; - str_name = str_name + "_a" + beta_str; - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); - str_name = str_name + "_" + incx_str; - str_name = str_name + "_" + incy_str; - str_name = str_name + "_" + std::to_string(ld_inc); - return str_name; - } -}; + ), // storage format + ::testing::Values('u','l'), // uploa + ::testing::Values('n'), // conja + ::testing::Values('n'), // conjx + ::testing::Range(gtint_t(1),gtint_t(21),1), // n + ::testing::Values( 0.0, 1.0, -1.0, 2.7 ), // alpha + ::testing::Values( 0.0, 1.0, -1.0, 2.7 ), // beta + ::testing::Values(gtint_t(1),gtint_t(-1),gtint_t(2)), // stride size for x + ::testing::Values(gtint_t(1),gtint_t(-1),gtint_t(2)), // stride size for y + ::testing::Values(gtint_t(0), gtint_t(5)) // increment to the leading dim of a + ), + ::symvGenericPrint() + ); -// Black box testing. INSTANTIATE_TEST_SUITE_P( - Blackbox, - ssymvTest, + BlackboxMedium, + ssymvGeneric, ::testing::Combine( ::testing::Values('c' -#ifndef TEST_BLAS +#ifndef TEST_BLAS_LIKE ,'r' #endif ), // storage format ::testing::Values('u','l'), // uploa ::testing::Values('n'), // conja ::testing::Values('n'), // conjx - ::testing::Range(gtint_t(10), gtint_t(31), 10), // n - ::testing::Values( 1.0, -2.0 ), // alpha - ::testing::Values( 2.0, -1.0 ), // beta - ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(gtint_t(0), gtint_t(5)) // increment to the leading dim of a + ::testing::Values(gtint_t(25), + gtint_t(33), + gtint_t(98), + gtint_t(173), + gtint_t(211) + ), // n + ::testing::Values( 0.0, 1.0, -1.0, 2.7 ), // alpha + ::testing::Values( 0.0, 1.0, -1.0, 2.7 ), // beta + ::testing::Values(gtint_t(1),gtint_t(-1),gtint_t(2)), // stride size for x + ::testing::Values(gtint_t(1),gtint_t(-1),gtint_t(2)), // stride size for y + ::testing::Values(gtint_t(0), gtint_t(5)) // increment to the leading dim of a ), - ::ssymvTestPrint() + ::symvGenericPrint() ); diff --git a/gtestsuite/testsuite/level2/symv/symv.h b/gtestsuite/testsuite/level2/symv/symv.h index 2d77b25de4..79f4cae790 100644 --- a/gtestsuite/testsuite/level2/symv/symv.h +++ b/gtestsuite/testsuite/level2/symv/symv.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -36,6 +36,7 @@ #include "blis.h" #include "common/testing_helpers.h" +#include "inc/check_error.h" /** * @brief Performs the operation: @@ -65,6 +66,18 @@ static void symv_( char uploa, gtint_t n, T* alpha, T* ap, gtint_t lda, throw std::runtime_error("Error in testsuite/level2/symv.h: Invalid typename in symv_()."); } +template +static void symv_blis_impl( char uploa, gtint_t n, T* alpha, T* ap, gtint_t lda, + T* xp, gtint_t incx, T* beta, T* yp, gtint_t incy ) +{ + if constexpr (std::is_same::value) + ssymv_blis_impl( &uploa, &n, alpha, ap, &lda, xp, &incx, beta, yp, &incy ); + else if constexpr (std::is_same::value) + dsymv_blis_impl( &uploa, &n, alpha, ap, &lda, xp, &incx, beta, yp, &incy ); + else + throw std::runtime_error("Error in testsuite/level2/symv.h: Invalid typename in symv_blis_impl()."); +} + template static void cblas_symv( char storage, char uploa, gtint_t n, T* alpha, T* ap, gtint_t lda, T* xp, gtint_t incx, T* beta, T* yp, gtint_t incy ) @@ -118,11 +131,54 @@ static void symv( char storage, char uploa, char conja, char conjx, gtint_t n, T* alpha, T* ap, gtint_t lda, T* xp, gtint_t incx, T* beta, T* yp, gtint_t incy ) { + +#ifdef TEST_UPPERCASE_ARGS + storage = static_cast(std::toupper(static_cast(storage))); + uploa = static_cast(std::toupper(static_cast(uploa))); + conja = static_cast(std::toupper(static_cast(conja))); + conjx = static_cast(std::toupper(static_cast(conjx))); +#endif + +#ifdef TEST_INPUT_ARGS + // Create copy of scalar input values so we can check that they are not altered. + char storage_cpy = storage; + char uploa_cpy = uploa; + char conja_cpy = conja; + char conjx_cpy = conjx; + gtint_t n_cpy = n; + T* alpha_cpy = alpha; + gtint_t lda_cpy = lda; + gtint_t incx_cpy = incx; + T* beta_cpy = beta; + gtint_t incy_cpy = incy; + + // Create copy of input arrays so we can check that they are not altered. + T* ap_cpy = nullptr; + gtint_t size_ap = testinghelpers::matsize( storage, 'n', n, n, lda ); + if (ap && size_ap > 0) + { + ap_cpy = new T[size_ap]; + memcpy( ap_cpy, ap, size_ap * sizeof( T ) ); + } + T* xp_cpy = nullptr; + gtint_t size_xp; + size_xp = testinghelpers::buff_dim( n, incx ); + { + xp_cpy = new T[size_xp]; + memcpy( xp_cpy, xp, size_xp * sizeof( T ) ); + } +#endif + #ifdef TEST_BLAS if( storage == 'c' || storage == 'C' ) symv_( uploa, n, alpha, ap, lda, xp, incx, beta, yp, incy ); else throw std::runtime_error("Error in testsuite/level2/symv.h: BLAS interface cannot be tested for row-major order."); +#elif TEST_BLAS_BLIS_IMPL + if( storage == 'c' || storage == 'C' ) + symv_blis_impl( uploa, n, alpha, ap, lda, xp, incx, beta, yp, incy ); + else + throw std::runtime_error("Error in testsuite/level2/symv.h: BLAS_BLIS_IMPL interface cannot be tested for row-major order."); #elif TEST_CBLAS cblas_symv( storage, uploa, n, alpha, ap, lda, xp, incx, beta, yp, incy ); #elif TEST_BLIS_TYPED @@ -130,4 +186,37 @@ static void symv( char storage, char uploa, char conja, char conjx, gtint_t n, #else throw std::runtime_error("Error in testsuite/level2/symv.h: No interfaces are set to be tested."); #endif + +#ifdef TEST_INPUT_ARGS + //---------------------------------------------------------- + // Check scalar inputs have not been modified. + //---------------------------------------------------------- + + computediff( "storage", storage, storage_cpy ); + computediff( "uploa", uploa, uploa_cpy ); + computediff( "conja", conja, conja_cpy ); + computediff( "conjx", conjx, conjx_cpy ); + computediff( "n", n, n_cpy ); + if (alpha) computediff( "alpha", *alpha, *alpha_cpy ); + computediff( "lda", lda, lda_cpy ); + computediff( "incx", incx, incx_cpy ); + if (beta) computediff( "beta", *beta, *beta_cpy ); + computediff( "incy", incy, incy_cpy ); + + //---------------------------------------------------------- + // Bitwise-wise check array inputs have not been modified. + //---------------------------------------------------------- + + if (ap && size_ap > 0) + { + computediff( "A", storage, n, n, ap, ap_cpy, lda, true ); + delete[] ap_cpy; + } + + if (xp && size_xp > 0) + { + computediff( "x", n, xp, xp_cpy, incx, true ); + delete[] xp_cpy; + } +#endif } diff --git a/gtestsuite/testsuite/level2/symv/test_symv.h b/gtestsuite/testsuite/level2/symv/test_symv.h index f0df77c18b..c2adfb6767 100644 --- a/gtestsuite/testsuite/level2/symv/test_symv.h +++ b/gtestsuite/testsuite/level2/symv/test_symv.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -51,13 +51,20 @@ void test_symv( char storage, char uploa, char conja, char conjx, gtint_t n, // Initialize matrics with random integer numbers. //---------------------------------------------------------- std::vector a = testinghelpers::get_random_matrix( -2, 5, storage, 'n', n, n, lda ); - std::vector x = testinghelpers::get_random_vector( -3, 3, n, incx ); - std::vector y = testinghelpers::get_random_vector( -2, 5, n, incy ); - testinghelpers::make_symm( storage, uploa, n, a.data(), lda ); testinghelpers::make_triangular( storage, uploa, n, a.data(), lda ); - // Create a copy of c so that we can check reference results. + std::vector x = testinghelpers::get_random_vector( -3, 3, n, incx ); + std::vector y( testinghelpers::buff_dim(n, incy) ); + if (beta != testinghelpers::ZERO()) + testinghelpers::datagenerators::randomgenerators( -2, 5, n, incy, y.data() ); + else + { + // Vector Y should not be read, only set. + testinghelpers::set_vector( n, incy, y.data(), testinghelpers::aocl_extreme() ); + } + + // Create a copy of y so that we can check reference results. std::vector y_ref(y); //---------------------------------------------------------- // Call BLIS function @@ -74,5 +81,43 @@ void test_symv( char storage, char uploa, char conja, char conjx, gtint_t n, //---------------------------------------------------------- // check component-wise error. //---------------------------------------------------------- - computediff( n, y.data(), y_ref.data(), incy, thresh ); + computediff( "y", n, y.data(), y_ref.data(), incy, thresh ); + +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif } + +// Test-case logger : Used to print the test-case details based on parameters +template +class symvGenericPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char storage = std::get<0>(str.param); + char uploa = std::get<1>(str.param); + char conja = std::get<2>(str.param); + char conjx = std::get<3>(str.param); + gtint_t n = std::get<4>(str.param); + T alpha = std::get<5>(str.param); + T beta = std::get<6>(str.param); + gtint_t incx = std::get<7>(str.param); + gtint_t incy = std::get<8>(str.param); + gtint_t lda_inc = std::get<9>(str.param); + + std::string str_name = API_PRINT; + str_name += "_stor_" + std::string(&storage, 1); + str_name += "_uploa_" + std::string(&uploa, 1); + str_name += "_conja_" + std::string(&conja, 1); + str_name += "_conjx_" + std::string(&conjx, 1); + str_name += "_n_" + std::to_string(n); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + gtint_t lda = testinghelpers::get_leading_dimension( storage, 'n', n, n, lda_inc ); + str_name += "_lda_i" + std::to_string(lda_inc) + "_" + std::to_string(lda); + str_name += "_incx_" + testinghelpers::get_value_string(incx); + str_name += "_beta_" + testinghelpers::get_value_string(beta); + str_name += "_incy_" + testinghelpers::get_value_string(incy); + return str_name; + } +}; diff --git a/gtestsuite/testsuite/level2/syr/dsyr_generic.cpp b/gtestsuite/testsuite/level2/syr/dsyr_generic.cpp index 3d755586a8..60f3090d9e 100644 --- a/gtestsuite/testsuite/level2/syr/dsyr_generic.cpp +++ b/gtestsuite/testsuite/level2/syr/dsyr_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,7 +35,7 @@ #include #include "test_syr.h" -class dsyrTest : +class dsyrGeneric : public ::testing::TestWithParam> {}; -TEST_P(dsyrTest, RandomData) +TEST_P( dsyrGeneric, API ) { using T = double; //---------------------------------------------------------- @@ -69,7 +69,14 @@ TEST_P(dsyrTest, RandomData) gtint_t lda_inc = std::get<6>(GetParam()); // Set the threshold for the errors: - double thresh = 2*n*testinghelpers::getEpsilon(); + // Check gtestsuite syr.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (n == 0 || alpha == testinghelpers::ZERO()) + thresh = 0.0; + else + thresh = 3*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters @@ -77,52 +84,46 @@ TEST_P(dsyrTest, RandomData) test_syr( storage, uploa, conjx, n, alpha, incx, lda_inc, thresh ); } -class dsyrTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char uploa = std::get<1>(str.param); - char conjx = std::get<2>(str.param); - gtint_t n = std::get<3>(str.param); - double alpha = std::get<4>(str.param); - gtint_t incx = std::get<5>(str.param); - gtint_t ld_inc = std::get<6>(str.param); -#ifdef TEST_BLAS - std::string str_name = "dsyr_"; -#elif TEST_CBLAS - std::string str_name = "cblas_dsyr"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_dsyr"; +// Black box testing. +INSTANTIATE_TEST_SUITE_P( + BlackboxSmall, + dsyrGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' #endif - str_name = str_name + "_" + sfm; - str_name = str_name + "_" + uploa+conjx; - str_name = str_name + "_" + std::to_string(n); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name = str_name + "_" + incx_str; - std::string alpha_str = ( alpha > 0) ? std::to_string(int(alpha)) : ("m" + std::to_string(int(std::abs(alpha)))); - str_name = str_name + "_a" + alpha_str; - str_name = str_name + "_" + std::to_string(ld_inc); - return str_name; - } -}; + ), // storage format + ::testing::Values('u','l'), // uploa + ::testing::Values('n'), // conjx + ::testing::Range(gtint_t(1),gtint_t(21),1), // n + ::testing::Values( 0.0, 1.0, -1.0, 2.7 ), // alpha + ::testing::Values(gtint_t(1),gtint_t(-1),gtint_t(2)), // stride size for x + ::testing::Values(gtint_t(0), gtint_t(5)) // increment to the leading dim of a + ), + ::syrGenericPrint() + ); -// Black box testing. INSTANTIATE_TEST_SUITE_P( - Blackbox, - dsyrTest, + BlackboxMedium, + dsyrGeneric, ::testing::Combine( ::testing::Values('c' -#ifndef TEST_BLAS +#ifndef TEST_BLAS_LIKE ,'r' #endif ), // storage format ::testing::Values('u','l'), // uploa ::testing::Values('n'), // conjx - ::testing::Range(gtint_t(10), gtint_t(31), 10), // n - ::testing::Values(1.0), // alpha - ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(0), gtint_t(3)) // increment to the leading dim of a + ::testing::Values(gtint_t(25), + gtint_t(33), + gtint_t(98), + gtint_t(173), + gtint_t(211) + ), // n + ::testing::Values( 0.0, 1.0, -1.0, 2.7 ), // alpha + ::testing::Values(gtint_t(1),gtint_t(-1),gtint_t(2)), // stride size for x + ::testing::Values(gtint_t(0), gtint_t(5)) // increment to the leading dim of a ), - ::dsyrTestPrint() + ::syrGenericPrint() ); diff --git a/gtestsuite/testsuite/level2/syr/ssyr_generic.cpp b/gtestsuite/testsuite/level2/syr/ssyr_generic.cpp index 446c2f4743..59cbd9d3a2 100644 --- a/gtestsuite/testsuite/level2/syr/ssyr_generic.cpp +++ b/gtestsuite/testsuite/level2/syr/ssyr_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,7 +35,7 @@ #include #include "test_syr.h" -class ssyrTest : +class ssyrGeneric : public ::testing::TestWithParam> {}; -TEST_P(ssyrTest, RandomData) +TEST_P( ssyrGeneric, API ) { using T = float; //---------------------------------------------------------- @@ -69,7 +69,14 @@ TEST_P(ssyrTest, RandomData) gtint_t lda_inc = std::get<6>(GetParam()); // Set the threshold for the errors: - double thresh = 2*n*testinghelpers::getEpsilon(); + // Check gtestsuite syr.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (n == 0 || alpha == testinghelpers::ZERO()) + thresh = 0.0; + else + thresh = 3*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters @@ -77,52 +84,46 @@ TEST_P(ssyrTest, RandomData) test_syr( storage, uploa, conjx, n, alpha, incx, lda_inc, thresh ); } -class ssyrTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char uploa = std::get<1>(str.param); - char conjx = std::get<2>(str.param); - gtint_t n = std::get<3>(str.param); - float alpha = std::get<4>(str.param); - gtint_t incx = std::get<5>(str.param); - gtint_t ld_inc = std::get<6>(str.param); -#ifdef TEST_BLAS - std::string str_name = "ssyr_"; -#elif TEST_CBLAS - std::string str_name = "cblas_ssyr"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_ssyr"; +// Black box testing. +INSTANTIATE_TEST_SUITE_P( + BlackboxSmall, + ssyrGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' #endif - str_name = str_name + "_" + sfm; - str_name = str_name + "_" + uploa+conjx; - str_name = str_name + "_" + std::to_string(n); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name = str_name + "_" + incx_str; - std::string alpha_str = ( alpha > 0) ? std::to_string(int(alpha)) : ("m" + std::to_string(int(std::abs(alpha)))); - str_name = str_name + "_a" + alpha_str; - str_name = str_name + "_" + std::to_string(ld_inc); - return str_name; - } -}; + ), // storage format + ::testing::Values('u','l'), // uploa + ::testing::Values('n'), // conjx + ::testing::Range(gtint_t(1),gtint_t(21),1), // n + ::testing::Values( 0.0, 1.0, -1.0, 2.7 ), // alpha + ::testing::Values(gtint_t(1),gtint_t(-1),gtint_t(2)), // stride size for x + ::testing::Values(gtint_t(0), gtint_t(5)) // increment to the leading dim of a + ), + ::syrGenericPrint() + ); -// Black box testing. INSTANTIATE_TEST_SUITE_P( - Blackbox, - ssyrTest, + BlackboxMedium, + ssyrGeneric, ::testing::Combine( ::testing::Values('c' -#ifndef TEST_BLAS +#ifndef TEST_BLAS_LIKE ,'r' #endif ), // storage format ::testing::Values('u','l'), // uploa ::testing::Values('n'), // conjx - ::testing::Range(gtint_t(10), gtint_t(31), 10), // n - ::testing::Values(1.0), // alpha - ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(0), gtint_t(3)) // increment to the leading dim of a + ::testing::Values(gtint_t(25), + gtint_t(33), + gtint_t(98), + gtint_t(173), + gtint_t(211) + ), // n + ::testing::Values( 0.0, 1.0, -1.0, 2.7 ), // alpha + ::testing::Values(gtint_t(1),gtint_t(-1),gtint_t(2)), // stride size for x + ::testing::Values(gtint_t(0), gtint_t(5)) // increment to the leading dim of a ), - ::ssyrTestPrint() + ::syrGenericPrint() ); diff --git a/gtestsuite/testsuite/level2/syr/syr.h b/gtestsuite/testsuite/level2/syr/syr.h index e16d5c5322..f824e4fe82 100644 --- a/gtestsuite/testsuite/level2/syr/syr.h +++ b/gtestsuite/testsuite/level2/syr/syr.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -36,6 +36,7 @@ #include "blis.h" #include "common/testing_helpers.h" +#include "inc/check_error.h" /** * @brief Performs the operation: @@ -63,6 +64,18 @@ static void syr_( char uploa, gtint_t n, T* alpha, T* xp, gtint_t incx, throw std::runtime_error("Error in testsuite/level2/syr.h: Invalid typename in syr_()."); } +template +static void syr_blis_impl( char uploa, gtint_t n, T* alpha, T* xp, gtint_t incx, + T* ap, gtint_t lda ) +{ + if constexpr (std::is_same::value) + ssyr_blis_impl( &uploa, &n, alpha, xp, &incx, ap, &lda ); + else if constexpr (std::is_same::value) + dsyr_blis_impl( &uploa, &n, alpha, xp, &incx, ap, &lda ); + else + throw std::runtime_error("Error in testsuite/level2/syr.h: Invalid typename in syr_blis_impl()."); +} + template static void cblas_syr( char storage, char uploa, gtint_t n, T* alpha, T* xp, gtint_t incx, T* ap, gtint_t lda ) @@ -113,11 +126,43 @@ template static void syr( char storage, char uploa, char conj_x, gtint_t n, T* alpha, T* xp, gtint_t incx, T* ap, gtint_t lda ) { + +#ifdef TEST_UPPERCASE_ARGS + storage = static_cast(std::toupper(static_cast(storage))); + uploa = static_cast(std::toupper(static_cast(uploa))); + conj_x = static_cast(std::toupper(static_cast(conj_x))); +#endif + +#ifdef TEST_INPUT_ARGS + // Create copy of scalar input values so we can check that they are not altered. + char storage_cpy = storage; + char uploa_cpy = uploa; + char conj_x_cpy = conj_x; + gtint_t n_cpy = n; + T* alpha_cpy = alpha; + gtint_t incx_cpy = incx; + gtint_t lda_cpy = lda; + + // Create copy of input arrays so we can check that they are not altered. + T* xp_cpy = nullptr; + gtint_t size_xp; + size_xp = testinghelpers::buff_dim( n, incx ); + { + xp_cpy = new T[size_xp]; + memcpy( xp_cpy, xp, size_xp * sizeof( T ) ); + } +#endif + #ifdef TEST_BLAS if( storage == 'c' || storage == 'C' ) syr_( uploa, n, alpha, xp, incx, ap, lda ); else throw std::runtime_error("Error in testsuite/level2/syr.h: BLAS interface cannot be tested for row-major order."); +#elif TEST_BLAS_BLIS_IMPL + if( storage == 'c' || storage == 'C' ) + syr_blis_impl( uploa, n, alpha, xp, incx, ap, lda ); + else + throw std::runtime_error("Error in testsuite/level2/syr.h: BLAS_BLIS_IMPL interface cannot be tested for row-major order."); #elif TEST_CBLAS cblas_syr( storage, uploa, n, alpha, xp, incx, ap, lda ); #elif TEST_BLIS_TYPED @@ -125,4 +170,28 @@ static void syr( char storage, char uploa, char conj_x, gtint_t n, T* alpha, #else throw std::runtime_error("Error in testsuite/level2/syr.h: No interfaces are set to be tested."); #endif + +#ifdef TEST_INPUT_ARGS + //---------------------------------------------------------- + // Check scalar inputs have not been modified. + //---------------------------------------------------------- + + computediff( "storage", storage, storage_cpy ); + computediff( "uploa", uploa, uploa_cpy ); + computediff( "conj_x", conj_x, conj_x_cpy ); + computediff( "n", n, n_cpy ); + if (alpha) computediff( "alpha", *alpha, *alpha_cpy ); + computediff( "incx", incx, incx_cpy ); + computediff( "lda", lda, lda_cpy ); + + //---------------------------------------------------------- + // Bitwise-wise check array inputs have not been modified. + //---------------------------------------------------------- + + if (xp && size_xp > 0) + { + computediff( "x", n, xp, xp_cpy, incx, true ); + delete[] xp_cpy; + } +#endif } diff --git a/gtestsuite/testsuite/level2/syr/test_syr.h b/gtestsuite/testsuite/level2/syr/test_syr.h index 125445fa19..1f0a3fcdfa 100644 --- a/gtestsuite/testsuite/level2/syr/test_syr.h +++ b/gtestsuite/testsuite/level2/syr/test_syr.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -71,5 +71,37 @@ void test_syr( char storage, char uploa, char conjx, gtint_t n, T alpha, //---------------------------------------------------------- // check component-wise error. //---------------------------------------------------------- - computediff( storage, n, n, a.data(), a_ref.data(), lda, thresh ); + computediff( "A", storage, n, n, a.data(), a_ref.data(), lda, thresh ); + +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif } + +// Test-case logger : Used to print the test-case details based on parameters +template +class syrGenericPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char storage = std::get<0>(str.param); + char uploa = std::get<1>(str.param); + char conjx = std::get<2>(str.param); + gtint_t n = std::get<3>(str.param); + T alpha = std::get<4>(str.param); + gtint_t incx = std::get<5>(str.param); + gtint_t lda_inc = std::get<6>(str.param); + + std::string str_name = API_PRINT; + str_name += "_stor_" + std::string(&storage, 1); + str_name += "_uploa_" + std::string(&uploa, 1); + str_name += "_conjx_" + std::string(&conjx, 1); + str_name += "_n_" + std::to_string(n); + str_name += "_incx_" + testinghelpers::get_value_string(incx); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + gtint_t lda = testinghelpers::get_leading_dimension( storage, 'n', n, n, lda_inc ); + str_name += "_lda_i" + std::to_string(lda_inc) + "_" + std::to_string(lda); + return str_name; + } +}; diff --git a/gtestsuite/testsuite/level2/syr2/dsyr2_generic.cpp b/gtestsuite/testsuite/level2/syr2/dsyr2_generic.cpp index 2a021ea6d8..e10cea2b26 100644 --- a/gtestsuite/testsuite/level2/syr2/dsyr2_generic.cpp +++ b/gtestsuite/testsuite/level2/syr2/dsyr2_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,7 +35,7 @@ #include #include "test_syr2.h" -class dsyr2Test : +class dsyr2Generic : public ::testing::TestWithParam> {}; -TEST_P(dsyr2Test, RandomData) +TEST_P( dsyr2Generic, API ) { using T = double; //---------------------------------------------------------- @@ -75,7 +75,19 @@ TEST_P(dsyr2Test, RandomData) gtint_t lda_inc = std::get<8>(GetParam()); // Set the threshold for the errors: - double thresh = 3*n*testinghelpers::getEpsilon(); + // Check gtestsuite syr2.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; +#ifdef BLIS_INT_ELEMENT_TYPE + double adj = 1.34; +#else + double adj = 4.0; +#endif + if (n == 0 || alpha == testinghelpers::ZERO()) + thresh = 0.0; + else + thresh = adj*6*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters @@ -83,58 +95,50 @@ TEST_P(dsyr2Test, RandomData) test_syr2( storage, uploa, conjx, conjy, n, alpha, incx, incy, lda_inc, thresh ); } -class dsyr2TestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char uploa = std::get<1>(str.param); - char conjx = std::get<2>(str.param); - char conjy = std::get<3>(str.param); - gtint_t n = std::get<4>(str.param); - double alpha = std::get<5>(str.param); - gtint_t incx = std::get<6>(str.param); - gtint_t incy = std::get<7>(str.param); - gtint_t ld_inc = std::get<8>(str.param); -#ifdef TEST_BLAS - std::string str_name = "dsyr2_"; -#elif TEST_CBLAS - std::string str_name = "cblas_dsyr2"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_dsyr2"; +// Black box testing. +INSTANTIATE_TEST_SUITE_P( + BlackboxSmall, + dsyr2Generic, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' #endif - str_name = str_name + "_" + sfm; - str_name = str_name + "_" + uploa+conjx+conjy; - str_name = str_name + "_" + std::to_string(n); - std::string alpha_str = ( alpha > 0) ? std::to_string(int(alpha)) : ("m" + std::to_string(int(std::abs(alpha)))); - str_name = str_name + "_a" + alpha_str; - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); - str_name = str_name + "_" + incx_str; - str_name = str_name + "_" + incy_str; - str_name = str_name + "_" + std::to_string(ld_inc); - return str_name; - } -}; + ), // storage format + ::testing::Values('u','l'), // uploa + ::testing::Values('n'), // conja + ::testing::Values('n'), // conjx + ::testing::Range(gtint_t(1),gtint_t(21),1), // n + ::testing::Values( 0.0, 1.0, -1.0, 2.7 ), // alpha + ::testing::Values(gtint_t(1),gtint_t(-1),gtint_t(2)), // stride size for x + ::testing::Values(gtint_t(1),gtint_t(-1),gtint_t(2)), // stride size for y + ::testing::Values(gtint_t(0), gtint_t(5)) // increment to the leading dim of a + ), + ::syr2GenericPrint() + ); -// Black box testing. INSTANTIATE_TEST_SUITE_P( - Blackbox, - dsyr2Test, + BlackboxMedium, + dsyr2Generic, ::testing::Combine( ::testing::Values('c' -#ifndef TEST_BLAS +#ifndef TEST_BLAS_LIKE ,'r' #endif ), // storage format ::testing::Values('u','l'), // uploa + ::testing::Values('n'), // conja ::testing::Values('n'), // conjx - ::testing::Values('n'), // conjy - ::testing::Range(gtint_t(10), gtint_t(31), 10), // n - ::testing::Values(1.0, -2.0), // alpha - ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(gtint_t(0), gtint_t(3)) // increment to the leading dim of a + ::testing::Values(gtint_t(25), + gtint_t(33), + gtint_t(98), + gtint_t(173), + gtint_t(211) + ), // n + ::testing::Values( 0.0, 1.0, -1.0, 2.7 ), // alpha + ::testing::Values(gtint_t(1),gtint_t(-1),gtint_t(2)), // stride size for x + ::testing::Values(gtint_t(1),gtint_t(-1),gtint_t(2)), // stride size for y + ::testing::Values(gtint_t(0), gtint_t(5)) // increment to the leading dim of a ), - ::dsyr2TestPrint() + ::syr2GenericPrint() ); diff --git a/gtestsuite/testsuite/level2/syr2/ssyr2_generic.cpp b/gtestsuite/testsuite/level2/syr2/ssyr2_generic.cpp index 75df2d0367..b40100c307 100644 --- a/gtestsuite/testsuite/level2/syr2/ssyr2_generic.cpp +++ b/gtestsuite/testsuite/level2/syr2/ssyr2_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,7 +35,7 @@ #include #include "test_syr2.h" -class ssyr2Test : +class ssyr2Generic : public ::testing::TestWithParam> {}; -TEST_P(ssyr2Test, RandomData) +TEST_P( ssyr2Generic, API ) { using T = float; //---------------------------------------------------------- @@ -75,7 +75,19 @@ TEST_P(ssyr2Test, RandomData) gtint_t lda_inc = std::get<8>(GetParam()); // Set the threshold for the errors: - double thresh = 3*n*testinghelpers::getEpsilon(); + // Check gtestsuite syr2.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; +#ifdef BLIS_INT_ELEMENT_TYPE + double adj = 1.0; +#else + double adj = 3.0; +#endif + if (n == 0 || alpha == testinghelpers::ZERO()) + thresh = 0.0; + else + thresh = adj*6*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters @@ -83,58 +95,50 @@ TEST_P(ssyr2Test, RandomData) test_syr2( storage, uploa, conjx, conjy, n, alpha, incx, incy, lda_inc, thresh ); } -class ssyr2TestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char uploa = std::get<1>(str.param); - char conjx = std::get<2>(str.param); - char conjy = std::get<3>(str.param); - gtint_t n = std::get<4>(str.param); - float alpha = std::get<5>(str.param); - gtint_t incx = std::get<6>(str.param); - gtint_t incy = std::get<7>(str.param); - gtint_t ld_inc = std::get<8>(str.param); -#ifdef TEST_BLAS - std::string str_name = "ssyr2_"; -#elif TEST_CBLAS - std::string str_name = "cblas_ssyr2"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_ssyr2"; +// Black box testing. +INSTANTIATE_TEST_SUITE_P( + BlackboxSmall, + ssyr2Generic, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' #endif - str_name = str_name + "_" + sfm; - str_name = str_name + "_" + uploa+conjx+conjy; - str_name = str_name + "_" + std::to_string(n); - std::string alpha_str = ( alpha > 0) ? std::to_string(int(alpha)) : ("m" + std::to_string(int(std::abs(alpha)))); - str_name = str_name + "_a" + alpha_str; - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); - str_name = str_name + "_" + incx_str; - str_name = str_name + "_" + incy_str; - str_name = str_name + "_" + std::to_string(ld_inc); - return str_name; - } -}; + ), // storage format + ::testing::Values('u','l'), // uploa + ::testing::Values('n'), // conja + ::testing::Values('n'), // conjx + ::testing::Range(gtint_t(1),gtint_t(21),1), // n + ::testing::Values( 0.0, 1.0, -1.0, 2.7 ), // alpha + ::testing::Values(gtint_t(1),gtint_t(-1),gtint_t(2)), // stride size for x + ::testing::Values(gtint_t(1),gtint_t(-1),gtint_t(2)), // stride size for y + ::testing::Values(gtint_t(0), gtint_t(5)) // increment to the leading dim of a + ), + ::syr2GenericPrint() + ); -// Black box testing. INSTANTIATE_TEST_SUITE_P( - Blackbox, - ssyr2Test, + BlackboxMedium, + ssyr2Generic, ::testing::Combine( ::testing::Values('c' -#ifndef TEST_BLAS +#ifndef TEST_BLAS_LIKE ,'r' #endif ), // storage format ::testing::Values('u','l'), // uploa + ::testing::Values('n'), // conja ::testing::Values('n'), // conjx - ::testing::Values('n'), // conjy - ::testing::Range(gtint_t(10), gtint_t(31), 10), // n - ::testing::Values(1.0, -2.0), // alpha - ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(gtint_t(0), gtint_t(5)) // increment to the leading dim of a + ::testing::Values(gtint_t(25), + gtint_t(33), + gtint_t(98), + gtint_t(173), + gtint_t(211) + ), // n + ::testing::Values( 0.0, 1.0, -1.0, 2.7 ), // alpha + ::testing::Values(gtint_t(1),gtint_t(-1),gtint_t(2)), // stride size for x + ::testing::Values(gtint_t(1),gtint_t(-1),gtint_t(2)), // stride size for y + ::testing::Values(gtint_t(0), gtint_t(5)) // increment to the leading dim of a ), - ::ssyr2TestPrint() + ::syr2GenericPrint() ); diff --git a/gtestsuite/testsuite/level2/syr2/syr2.h b/gtestsuite/testsuite/level2/syr2/syr2.h index dd51b5497b..1f3538d8f8 100644 --- a/gtestsuite/testsuite/level2/syr2/syr2.h +++ b/gtestsuite/testsuite/level2/syr2/syr2.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -36,6 +36,7 @@ #include "blis.h" #include "common/testing_helpers.h" +#include "inc/check_error.h" /** * @brief Performs the operation: @@ -64,6 +65,18 @@ static void syr2_( char uploa, gtint_t n, T* alpha, T* xp, gtint_t incx, throw std::runtime_error("Error in testsuite/level2/syr2.h: Invalid typename in syr2_()."); } +template +static void syr2_blis_impl( char uploa, gtint_t n, T* alpha, T* xp, gtint_t incx, + T* yp, gtint_t incy, T* ap, gtint_t lda ) +{ + if constexpr (std::is_same::value) + ssyr2_blis_impl( &uploa, &n, alpha, xp, &incx, yp, &incy, ap, &lda ); + else if constexpr (std::is_same::value) + dsyr2_blis_impl( &uploa, &n, alpha, xp, &incx, yp, &incy, ap, &lda ); + else + throw std::runtime_error("Error in testsuite/level2/syr2.h: Invalid typename in syr2_blis_impl()."); +} + template static void cblas_syr2( char storage, char uploa, gtint_t n, T* alpha, T* xp, gtint_t incx, T* yp, gtint_t incy, T* ap, gtint_t lda ) @@ -116,11 +129,53 @@ template static void syr2( char storage, char uploa, char conj_x, char conj_y, gtint_t n, T* alpha, T* xp, gtint_t incx, T* yp, gtint_t incy, T* ap, gtint_t lda ) { + +#ifdef TEST_UPPERCASE_ARGS + storage = static_cast(std::toupper(static_cast(storage))); + uploa = static_cast(std::toupper(static_cast(uploa))); + conj_x = static_cast(std::toupper(static_cast(conj_x))); + conj_y = static_cast(std::toupper(static_cast(conj_y))); +#endif + +#ifdef TEST_INPUT_ARGS + // Create copy of scalar input values so we can check that they are not altered. + char storage_cpy = storage; + char uploa_cpy = uploa; + char conj_x_cpy = conj_x; + char conj_y_cpy = conj_y; + gtint_t n_cpy = n; + T* alpha_cpy = alpha; + gtint_t incx_cpy = incx; + gtint_t incy_cpy = incy; + gtint_t lda_cpy = lda; + + // Create copy of input arrays so we can check that they are not altered. + T* xp_cpy = nullptr; + gtint_t size_xp; + size_xp = testinghelpers::buff_dim( n, incx ); + { + xp_cpy = new T[size_xp]; + memcpy( xp_cpy, xp, size_xp * sizeof( T ) ); + } + T* yp_cpy = nullptr; + gtint_t size_yp; + size_yp = testinghelpers::buff_dim( n, incy ); + { + yp_cpy = new T[size_yp]; + memcpy( yp_cpy, yp, size_yp * sizeof( T ) ); + } +#endif + #ifdef TEST_BLAS if( storage == 'c' || storage == 'C' ) syr2_( uploa, n, alpha, xp, incx, yp, incy, ap, lda ); else throw std::runtime_error("Error in testsuite/level2/syr2.h: BLAS interface cannot be tested for row-major order."); +#elif TEST_BLAS_BLIS_IMPL + if( storage == 'c' || storage == 'C' ) + syr2_blis_impl( uploa, n, alpha, xp, incx, yp, incy, ap, lda ); + else + throw std::runtime_error("Error in testsuite/level2/syr2.h: BLAS_BLIS_IMPL interface cannot be tested for row-major order."); #elif TEST_CBLAS cblas_syr2( storage, uploa, n, alpha, xp, incx, yp, incy, ap, lda ); #elif TEST_BLIS_TYPED @@ -128,4 +183,36 @@ static void syr2( char storage, char uploa, char conj_x, char conj_y, gtint_t n, #else throw std::runtime_error("Error in testsuite/level2/syr2.h: No interfaces are set to be tested."); #endif + +#ifdef TEST_INPUT_ARGS + //---------------------------------------------------------- + // Check scalar inputs have not been modified. + //---------------------------------------------------------- + + computediff( "storage", storage, storage_cpy ); + computediff( "uploa", uploa, uploa_cpy ); + computediff( "conj_x", conj_x, conj_x_cpy ); + computediff( "conj_y", conj_y, conj_y_cpy ); + computediff( "n", n, n_cpy ); + if (alpha) computediff( "alpha", *alpha, *alpha_cpy ); + computediff( "lda", lda, lda_cpy ); + computediff( "incx", incx, incx_cpy ); + computediff( "incy", incy, incy_cpy ); + + //---------------------------------------------------------- + // Bitwise-wise check array inputs have not been modified. + //---------------------------------------------------------- + + if (xp && size_xp > 0) + { + computediff( "x", n, xp, xp_cpy, incx, true ); + delete[] xp_cpy; + } + + if (yp && size_yp > 0) + { + computediff( "y", n, yp, yp_cpy, incy, true ); + delete[] yp_cpy; + } +#endif } diff --git a/gtestsuite/testsuite/level2/syr2/test_syr2.h b/gtestsuite/testsuite/level2/syr2/test_syr2.h index a4a623b6ea..a03d8350fb 100644 --- a/gtestsuite/testsuite/level2/syr2/test_syr2.h +++ b/gtestsuite/testsuite/level2/syr2/test_syr2.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -74,5 +74,41 @@ void test_syr2( char storage, char uploa, char conjx, char conjy, gtint_t n, //---------------------------------------------------------- // check component-wise error. //---------------------------------------------------------- - computediff( storage, n, n, a.data(), a_ref.data(), lda, thresh ); + computediff( "A", storage, n, n, a.data(), a_ref.data(), lda, thresh ); + +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif } + +// Test-case logger : Used to print the test-case details based on parameters +template +class syr2GenericPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char storage = std::get<0>(str.param); + char uploa = std::get<1>(str.param); + char conjx = std::get<2>(str.param); + char conjy = std::get<3>(str.param); + gtint_t n = std::get<4>(str.param); + T alpha = std::get<5>(str.param); + gtint_t incx = std::get<6>(str.param); + gtint_t incy = std::get<7>(str.param); + gtint_t lda_inc = std::get<8>(str.param); + + std::string str_name = API_PRINT; + str_name += "_stor_" + std::string(&storage, 1); + str_name += "_uploa_" + std::string(&uploa, 1); + str_name += "_conjx_" + std::string(&conjx, 1); + str_name += "_conjy_" + std::string(&conjy, 1); + str_name += "_n_" + std::to_string(n); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + str_name += "_incx_" + testinghelpers::get_value_string(incx); + str_name += "_incy_" + testinghelpers::get_value_string(incy); + gtint_t lda = testinghelpers::get_leading_dimension( storage, 'n', n, n, lda_inc ); + str_name += "_lda_i" + std::to_string(lda_inc) + "_" + std::to_string(lda); + return str_name; + } +}; diff --git a/gtestsuite/testsuite/level2/trmv/IIT_ERS/trmv_IIT_ERS_test.cpp b/gtestsuite/testsuite/level2/trmv/IIT_ERS/trmv_IIT_ERS_test.cpp new file mode 100644 index 0000000000..59a648d786 --- /dev/null +++ b/gtestsuite/testsuite/level2/trmv/IIT_ERS/trmv_IIT_ERS_test.cpp @@ -0,0 +1,329 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "level2/trmv/test_trmv.h" +#include "inc/check_error.h" +#include "common/testing_helpers.h" +#include "common/wrong_inputs_helpers.h" +#include +#include +#include + +template +class trmv_IIT_ERS : public ::testing::Test {}; +typedef ::testing::Types TypeParam; +TYPED_TEST_SUITE(trmv_IIT_ERS, TypeParam); + +// Adding namespace to get default parameters(valid case) from testinghelpers/common/wrong_input_helpers.h. +using namespace testinghelpers::IIT; + +#if defined(TEST_CBLAS) +#define INFO_OFFSET 1 +#else +#define INFO_OFFSET 0 +#endif + +#if defined(TEST_CBLAS) + +/** + * @brief Test trmv when STORAGE argument is incorrect + * when info == 1 + * + */ +TYPED_TEST(trmv_IIT_ERS, invalid_storage) +{ + using T = TypeParam; + T alpha = T{1}; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + trmv( 'x', UPLO, TRANS, DIAG, N, &alpha, nullptr, LDA, nullptr, INC); +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 1 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector a = testinghelpers::get_random_matrix( 1, 5, STORAGE, TRANS, M, N, LDA); + std::vector x = testinghelpers::get_random_vector(0, 1, N, INC); + std::vector x_ref(x); + + trmv( 'x', UPLO, TRANS, DIAG, N, &alpha, a.data(), LDA, x.data(), INC); + computediff( "x", N, x.data(), x_ref.data(), INC ); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 1 ); +#endif +} + +#endif + +#if defined(TEST_BLAS_LIKE) || defined(TEST_CBLAS) + +/* + Incorrect Input Testing(IIT) + + BLAS exceptions get triggered in the following cases(for trmv): + 1. When UPLO != 'L' || UPLO != 'U' (info = 1) + 2. When TRANS != 'N' || TRANS != 'T' || TRANS != 'C' (info = 2) + 3. When DIAG != 'U' || DIAG != 'N' (info = 3) + 4. When n < 0 (info = 4) + 5. When lda < N (info = 6) + 6. When incx == 0 (info = 8) +*/ + + +/** + * @brief Test trmv when UPLO argument is incorrect + * when info == 1 + * + */ +TYPED_TEST(trmv_IIT_ERS, invalid_UPLO) +{ + using T = TypeParam; + T alpha = T{1}; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + trmv( STORAGE, 'A', TRANS, DIAG, N, &alpha, nullptr, LDA, nullptr, INC); +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, INFO_OFFSET+1 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector a = testinghelpers::get_random_matrix( 1, 5, STORAGE, TRANS, M, N, LDA); + std::vector x = testinghelpers::get_random_vector(0, 1, N, INC); + std::vector x_ref(x); + + trmv( STORAGE, 'A', TRANS, DIAG, N, &alpha, a.data(), LDA, x.data(), INC); + computediff( "x", N, x.data(), x_ref.data(), INC ); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, INFO_OFFSET+1 ); +#endif +} + +/** + * @brief Test trmv when TRANS argument is incorrect + * when info == 2 + * + */ +TYPED_TEST(trmv_IIT_ERS, invalid_TRANS) +{ + using T = TypeParam; + T alpha = T{1}; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + trmv( STORAGE, UPLO, 'A', DIAG, N, &alpha, nullptr, LDA, nullptr, INC); +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, INFO_OFFSET+2 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector a = testinghelpers::get_random_matrix( 1, 5, STORAGE, TRANS, M, N, LDA); + std::vector x = testinghelpers::get_random_vector(0, 1, N, INC); + std::vector x_ref(x); + + trmv( STORAGE, UPLO, 'A', DIAG, N, &alpha, a.data(), LDA, x.data(), INC); + computediff( "x", N, x.data(), x_ref.data(), INC ); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, INFO_OFFSET+2 ); +#endif +} + +/** + * @brief Test trmv when DIAG argument is incorrect + * when info == 3 + */ +TYPED_TEST(trmv_IIT_ERS, invalid_DIAG) +{ + using T = TypeParam; + T alpha = T{1}; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + trmv( STORAGE, UPLO, TRANS, 'A', N, &alpha, nullptr, LDA, nullptr, INC); +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, INFO_OFFSET+3 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector a = testinghelpers::get_random_matrix( 1, 5, STORAGE, TRANS, M, N, LDA); + std::vector x = testinghelpers::get_random_vector(0, 1, N, INC); + std::vector x_ref(x); + + trmv( STORAGE, UPLO, TRANS, 'A', N, &alpha, a.data(), LDA, x.data(), INC); + computediff( "x", N, x.data(), x_ref.data(), INC ); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, INFO_OFFSET+3 ); +#endif +} + +/** + * @brief Test trmv when N is negative + * when info == 4 + */ +TYPED_TEST(trmv_IIT_ERS, invalid_n) +{ + using T = TypeParam; + T alpha = T{1}; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + trmv( STORAGE, UPLO, TRANS, DIAG, -1, &alpha, nullptr, LDA, nullptr, INC); +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 4 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector a = testinghelpers::get_random_matrix( 1, 5, STORAGE, TRANS, M, N, LDA); + std::vector x = testinghelpers::get_random_vector(0, 1, N, INC); + std::vector x_ref(x); + + trmv( STORAGE, UPLO, TRANS, DIAG, -1, &alpha, a.data(), LDA, x.data(), INC); + computediff( "x", N, x.data(), x_ref.data(), INC ); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 4 ); +#endif +} + + +/** + * @brief Test trmv when lda < max(1, N) + * when info == 6 + */ +TYPED_TEST(trmv_IIT_ERS, invalid_lda) +{ + using T = TypeParam; + T alpha = T{1}; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + trmv( STORAGE, UPLO, TRANS, DIAG, N, &alpha, nullptr, LDA - 1, nullptr, INC); +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 6 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector a = testinghelpers::get_random_matrix( 1, 5, STORAGE, TRANS, M, N, LDA); + std::vector x = testinghelpers::get_random_vector(0, 1, N, INC); + std::vector x_ref(x); + + trmv( STORAGE, UPLO, TRANS, DIAG, N, &alpha, a.data(), LDA - 1, x.data(), INC); + computediff( "x", N, x.data(), x_ref.data(), INC ); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 6 ); +#endif +} + +/** + * @brief Test trmv when INCX == 0 + * when info == 8 + */ +TYPED_TEST(trmv_IIT_ERS, invalid_incx) +{ + using T = TypeParam; + T alpha = T{1}; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + trmv( STORAGE, UPLO, TRANS, DIAG, N, &alpha, nullptr, LDA, nullptr, 0); +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 8 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector a = testinghelpers::get_random_matrix( 1, 5, STORAGE, TRANS, M, N, LDA); + std::vector x = testinghelpers::get_random_vector(0, 1, N, INC); + std::vector x_ref(x); + + trmv( STORAGE, UPLO, TRANS, DIAG, N, &alpha, a.data(), LDA, x.data(), 0); + computediff( "x", N, x.data(), x_ref.data(), INC ); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 8 ); +#endif +} + + +/* + Early Return Scenarios(ERS) : + + The trmv API is expected to return early in the following cases: + + 1. When n == 0. + +*/ + +/** + * @brief Test trmv when N is zero + */ +TYPED_TEST(trmv_IIT_ERS, n_eq_zero) +{ + using T = TypeParam; + T alpha = T{1}; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + trmv( STORAGE, UPLO, TRANS, DIAG, 0, &alpha, nullptr, LDA, nullptr, INC); +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector a = testinghelpers::get_random_matrix( 1, 5, STORAGE, TRANS, M, N, LDA); + std::vector x = testinghelpers::get_random_vector(0, 1, N, INC); + std::vector x_ref(x); + + trmv( STORAGE, UPLO, TRANS, DIAG, 0, &alpha, a.data(), LDA, x.data(), INC); + computediff( "x", N, x.data(), x_ref.data(), INC ); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif +} + +#endif diff --git a/gtestsuite/testsuite/level2/trmv/ctrmv/ctrmv_generic.cpp b/gtestsuite/testsuite/level2/trmv/ctrmv/ctrmv_generic.cpp new file mode 100644 index 0000000000..ac9065f5ec --- /dev/null +++ b/gtestsuite/testsuite/level2/trmv/ctrmv/ctrmv_generic.cpp @@ -0,0 +1,158 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "level2/trmv/test_trmv.h" + +class ctrmvGeneric : + public ::testing::TestWithParam> {}; // is memory test + +TEST_P( ctrmvGeneric, API ) +{ + using T = scomplex; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // matrix storage format(row major, column major) + char storage = std::get<0>(GetParam()); + // denotes whether matrix A is u,l + char uploa = std::get<1>(GetParam()); + // denotes whether matrix A is n,c,t,h + char transa = std::get<2>(GetParam()); + // denotes whether matrix diag is u,n + char diaga = std::get<3>(GetParam()); + // matrix size n + gtint_t n = std::get<4>(GetParam()); + // specifies alpha value + T alpha = std::get<5>(GetParam()); + // increment for x (incx): + gtint_t incx = std::get<6>(GetParam()); + // lda increment. + // If increment is zero, then the array size matches the matrix size. + // If increment are nonnegative, the array size is bigger than the matrix size. + gtint_t lda_inc = std::get<7>(GetParam()); + bool is_mem_test = std::get<8>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite trmv.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // With adjustment for complex data. + double thresh; +#ifdef BLIS_INT_ELEMENT_TYPE + double adj = 1.0; +#else + double adj = 1.0; +#endif + if (n == 0 || alpha == T{0.0}) + thresh = 0.0; + else + if(alpha == T{1.0}) + thresh = adj*2*n*testinghelpers::getEpsilon(); + else + thresh = adj*3*n*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_trmv( storage, uploa, transa, diaga, n, alpha, lda_inc, incx, thresh, is_mem_test ); +} + +// Black box testing. +INSTANTIATE_TEST_SUITE_P( + BlackboxSmall, + ctrmvGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('u','l'), // uploa + ::testing::Values('n','t','c'), // transa + ::testing::Values('n','u'), // diaga , n=NONUNIT_DIAG u=UNIT_DIAG + ::testing::Range(gtint_t(1),gtint_t(21),1), // n + ::testing::Values(scomplex{1.0, 0.0} // Only blis typed api supports +#ifdef TEST_BLIS_TYPED // values of alpha other than 1 + ,scomplex{6.1, -2.9}, scomplex{-3.3, -1.4} + ,scomplex{-1.0, 0.0}, scomplex{0.0, 0.0} +#endif + ), // alpha + ::testing::Values(gtint_t(-1),gtint_t(1), gtint_t(33)), // incx + ::testing::Values(gtint_t(0), gtint_t(11)), // increment to the leading dim of a + ::testing::Values(false, true) // is memory test + ), + ::trmvGenericPrint() + ); + +// Black box testing. +INSTANTIATE_TEST_SUITE_P( + BlackboxMedium, + ctrmvGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('u','l'), // uploa + ::testing::Values('n','t','c'), // transa + ::testing::Values('n','u'), // diaga , n=NONUNIT_DIAG u=UNIT_DIAG + ::testing::Values(gtint_t(25), + gtint_t(33), + gtint_t(98), + gtint_t(173), + gtint_t(211) + ), // n + ::testing::Values(scomplex{1.0, 0.0} // Only blis typed api supports +#ifdef TEST_BLIS_TYPED // values of alpha other than 1 + ,scomplex{6.1, -2.9}, scomplex{-3.3, -1.4} + ,scomplex{-1.0, 0.0}, scomplex{0.0, 0.0} +#endif + ), // alpha + ::testing::Values(gtint_t(-1),gtint_t(1), gtint_t(33)), // incx + ::testing::Values(gtint_t(0), gtint_t(11)), // increment to the leading dim of a + ::testing::Values(false, true) // is memory test + ), + ::trmvGenericPrint() + ); diff --git a/gtestsuite/testsuite/level2/trmv/ctrmv_generic.cpp b/gtestsuite/testsuite/level2/trmv/ctrmv_generic.cpp deleted file mode 100644 index a82fafcc2b..0000000000 --- a/gtestsuite/testsuite/level2/trmv/ctrmv_generic.cpp +++ /dev/null @@ -1,139 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include "test_trmv.h" - -class ctrmvTest : - public ::testing::TestWithParam> {}; - -TEST_P(ctrmvTest, RandomData) -{ - using T = scomplex; - //---------------------------------------------------------- - // Initialize values from the parameters passed through - // test suite instantiation (INSTANTIATE_TEST_SUITE_P). - //---------------------------------------------------------- - // matrix storage format(row major, column major) - char storage = std::get<0>(GetParam()); - // denotes whether matrix a is u,l - char uploa = std::get<1>(GetParam()); - // denotes whether matrix a is n,c,t,h - char transa = std::get<2>(GetParam()); - // denotes whether matrix diag is u,n - char diaga = std::get<3>(GetParam()); - // matrix size n - gtint_t n = std::get<4>(GetParam()); - // specifies alpha value - T alpha = std::get<5>(GetParam()); - // stride size for x: - gtint_t incx = std::get<6>(GetParam()); - // lda increment. - // If increment is zero, then the array size matches the matrix size. - // If increment are nonnegative, the array size is bigger than the matrix size. - gtint_t lda_inc = std::get<7>(GetParam()); - - // Set the threshold for the errors: - double thresh = 10*n*testinghelpers::getEpsilon(); - - //---------------------------------------------------------- - // Call test body using these parameters - //---------------------------------------------------------- - test_trmv( storage, uploa, transa, diaga, n, alpha, lda_inc, incx, thresh ); -} - -class ctrmvTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char uploa = std::get<1>(str.param); - char transa = std::get<2>(str.param); - char diaga = std::get<3>(str.param); - gtint_t n = std::get<4>(str.param); - scomplex alpha = std::get<5>(str.param); - gtint_t incx = std::get<6>(str.param); - gtint_t ld_inc = std::get<7>(str.param); -#ifdef TEST_BLAS - std::string str_name = "ctrmv_"; -#elif TEST_CBLAS - std::string str_name = "cblas_ctrmv"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_ctrmv"; -#endif - str_name = str_name + "_" + sfm; - str_name = str_name + "_" + uploa+transa; - str_name = str_name + "_d" + diaga; - str_name = str_name + "_" + std::to_string(n); - std::string alpha_str = ( alpha.real > 0) ? std::to_string(int(alpha.real)) : ("m" + std::to_string(int(std::abs(alpha.real)))); - alpha_str = alpha_str + "pi" + (( alpha.imag > 0) ? std::to_string(int(alpha.imag)) : ("m" + std::to_string(int(std::abs(alpha.imag))))); - str_name = str_name + "_a" + alpha_str; - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name = str_name + "_" + incx_str; - str_name = str_name + "_" + std::to_string(ld_inc); - return str_name; - } -}; - -// Black box testing. -INSTANTIATE_TEST_SUITE_P( - Blackbox, - ctrmvTest, - ::testing::Combine( - ::testing::Values('c' -#ifndef TEST_BLAS - ,'r' -#endif - ), // storage format - ::testing::Values('u','l'), // uploa - ::testing::Values('n','c','t'), // transa - ::testing::Values('n','u'), // diaga , n=NONUNIT_DIAG u=UNIT_DIAG - ::testing::Range(gtint_t(10), gtint_t(31), 10), // n - ::testing::Values(scomplex{1.0, 0.0} -#ifdef TEST_BLIS_TYPED - , scomplex{1.0, -2.0} -#endif - ), // alpha - ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(0), gtint_t(9)) // increment to the leading dim of a - ), - ::ctrmvTestPrint() - ); diff --git a/gtestsuite/testsuite/level2/trmv/dtrmv/dtrmv_generic.cpp b/gtestsuite/testsuite/level2/trmv/dtrmv/dtrmv_generic.cpp new file mode 100644 index 0000000000..8d090a2373 --- /dev/null +++ b/gtestsuite/testsuite/level2/trmv/dtrmv/dtrmv_generic.cpp @@ -0,0 +1,156 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "level2/trmv/test_trmv.h" + +class dtrmvGeneric : + public ::testing::TestWithParam> {}; // is memory test + +TEST_P( dtrmvGeneric, API ) +{ + using T = double; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // matrix storage format(row major, column major) + char storage = std::get<0>(GetParam()); + // denotes whether matrix A is u,l + char uploa = std::get<1>(GetParam()); + // denotes whether matrix A is n,c,t,h + char transa = std::get<2>(GetParam()); + // denotes whether matrix diag is u,n + char diaga = std::get<3>(GetParam()); + // matrix size n + gtint_t n = std::get<4>(GetParam()); + // specifies alpha value + T alpha = std::get<5>(GetParam()); + // increment for x (incx): + gtint_t incx = std::get<6>(GetParam()); + // lda increment. + // If increment is zero, then the array size matches the matrix size. + // If increment are nonnegative, the array size is bigger than the matrix size. + gtint_t lda_inc = std::get<7>(GetParam()); + bool is_mem_test = std::get<8>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite trmv.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + // Threshold adjustment +#ifdef BLIS_INT_ELEMENT_TYPE + double adj = 1.0; +#else + double adj = 1.0; +#endif + if (n == 0 || alpha == T{0.0}) + thresh = 0.0; + else + if(alpha == T{1.0}) + thresh = adj*2*n*testinghelpers::getEpsilon(); + else + thresh = adj*3*n*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_trmv( storage, uploa, transa, diaga, n, alpha, lda_inc, incx, thresh, is_mem_test); +} + +// Black box testing. +INSTANTIATE_TEST_SUITE_P( + BlackboxSmall, + dtrmvGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('u','l'), // uploa + ::testing::Values('n','t','c'), // transa + ::testing::Values('n','u'), // diaga , n=NONUNIT_DIAG u=UNIT_DIAG + ::testing::Range(gtint_t(1),gtint_t(21),1), // n + ::testing::Values(1.0 // Only blis typed api supports +#ifdef TEST_BLIS_TYPED // values of alpha other than 1 + , -2.2, 5.4, -1.0, 0.0 +#endif + ), // alpha + ::testing::Values(gtint_t(-1),gtint_t(1), gtint_t(33)), // incx + ::testing::Values(gtint_t(0), gtint_t(11)), // increment to the leading dim of a + ::testing::Values(false, true) // is memory test + ), + ::trmvGenericPrint() + ); + +// Black box testing. +INSTANTIATE_TEST_SUITE_P( + BlackboxMedium, + dtrmvGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('u','l'), // uploa + ::testing::Values('n','t','c'), // transa + ::testing::Values('n','u'), // diaga , n=NONUNIT_DIAG u=UNIT_DIAG + ::testing::Values(gtint_t(25), + gtint_t(33), + gtint_t(98), + gtint_t(173), + gtint_t(211) + ), // n + ::testing::Values(1.0 // Only blis typed api supports +#ifdef TEST_BLIS_TYPED // values of alpha other than 1 + , -2.2, 5.4, -1.0, 0.0 +#endif + ), // alpha + ::testing::Values(gtint_t(-1),gtint_t(1), gtint_t(33)), // incx + ::testing::Values(gtint_t(0), gtint_t(11)), // increment to the leading dim of a + ::testing::Values(false, true) // is memory test + ), + ::trmvGenericPrint() + ); diff --git a/gtestsuite/testsuite/level2/trmv/dtrmv_generic.cpp b/gtestsuite/testsuite/level2/trmv/dtrmv_generic.cpp deleted file mode 100644 index e7e9e325b9..0000000000 --- a/gtestsuite/testsuite/level2/trmv/dtrmv_generic.cpp +++ /dev/null @@ -1,138 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include "test_trmv.h" - -class dtrmvTest : - public ::testing::TestWithParam> {}; - -TEST_P(dtrmvTest, RandomData) -{ - using T = double; - //---------------------------------------------------------- - // Initialize values from the parameters passed through - // test suite instantiation (INSTANTIATE_TEST_SUITE_P). - //---------------------------------------------------------- - // matrix storage format(row major, column major) - char storage = std::get<0>(GetParam()); - // denotes whether matrix a is u,l - char uploa = std::get<1>(GetParam()); - // denotes whether matrix a is n,c,t,h - char transa = std::get<2>(GetParam()); - // denotes whether matrix diag is u,n - char diaga = std::get<3>(GetParam()); - // matrix size n - gtint_t n = std::get<4>(GetParam()); - // specifies alpha value - T alpha = std::get<5>(GetParam()); - // stride size for x: - gtint_t incx = std::get<6>(GetParam()); - // lda increment. - // If increment is zero, then the array size matches the matrix size. - // If increment are nonnegative, the array size is bigger than the matrix size. - gtint_t lda_inc = std::get<7>(GetParam()); - - // Set the threshold for the errors: - double thresh = 20*n*testinghelpers::getEpsilon(); - - //---------------------------------------------------------- - // Call test body using these parameters - //---------------------------------------------------------- - test_trmv( storage, uploa, transa, diaga, n, alpha, lda_inc, incx, thresh ); -} - -class dtrmvTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char uploa = std::get<1>(str.param); - char transa = std::get<2>(str.param); - char diaga = std::get<3>(str.param); - gtint_t n = std::get<4>(str.param); - double alpha = std::get<5>(str.param); - gtint_t incx = std::get<6>(str.param); - gtint_t ld_inc = std::get<7>(str.param); -#ifdef TEST_BLAS - std::string str_name = "dtrmv_"; -#elif TEST_CBLAS - std::string str_name = "cblas_dtrmv"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_dtrmv"; -#endif - str_name = str_name + "_" + sfm; - str_name = str_name + "_" + uploa+transa; - str_name = str_name + "_d" + diaga; - str_name = str_name + "_" + std::to_string(n); - std::string alpha_str = ( alpha > 0) ? std::to_string(int(alpha)) : ("m" + std::to_string(int(std::abs(alpha)))); - str_name = str_name + "_a" + alpha_str; - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name = str_name + "_" + incx_str; - str_name = str_name + "_" + std::to_string(ld_inc); - return str_name; - } -}; - -// Black box testing. -INSTANTIATE_TEST_SUITE_P( - Blackbox, - dtrmvTest, - ::testing::Combine( - ::testing::Values('c' -#ifndef TEST_BLAS - ,'r' -#endif - ), // storage format - ::testing::Values('u','l'), // uploa - ::testing::Values('n','t'), // transa - ::testing::Values('n','u'), // diaga - ::testing::Range(gtint_t(10), gtint_t(31), 10), // n - ::testing::Values( 1.0 -#ifdef TEST_BLIS_TYPED - , -2.0 -#endif - ), // alpha - ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(0), gtint_t(3)) // increment to the leading dim of a - ), - ::dtrmvTestPrint() - ); diff --git a/gtestsuite/testsuite/level2/trmv/strmv/strmv_generic.cpp b/gtestsuite/testsuite/level2/trmv/strmv/strmv_generic.cpp new file mode 100644 index 0000000000..cae4c44dc2 --- /dev/null +++ b/gtestsuite/testsuite/level2/trmv/strmv/strmv_generic.cpp @@ -0,0 +1,156 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "level2/trmv/test_trmv.h" + +class strmvGeneric : + public ::testing::TestWithParam> {}; // is memory test + +TEST_P( strmvGeneric, API ) +{ + using T = float; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // matrix storage format(row major, column major) + char storage = std::get<0>(GetParam()); + // denotes whether matrix A is u,l + char uploa = std::get<1>(GetParam()); + // denotes whether matrix A is n,c,t,h + char transa = std::get<2>(GetParam()); + // denotes whether matrix diag is u,n + char diaga = std::get<3>(GetParam()); + // matrix size n + gtint_t n = std::get<4>(GetParam()); + // specifies alpha value + T alpha = std::get<5>(GetParam()); + // increment for x (incx): + gtint_t incx = std::get<6>(GetParam()); + // lda increment. + // If increment is zero, then the array size matches the matrix size. + // If increment are nonnegative, the array size is bigger than the matrix size. + gtint_t lda_inc = std::get<7>(GetParam()); + bool is_mem_test = std::get<8>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite trmv.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + // Threshold adjustment +#ifdef BLIS_INT_ELEMENT_TYPE + double adj = 1.0; +#else + double adj = 1.0; +#endif + if (n == 0 || alpha == T{0.0}) + thresh = 0.0; + else + if(alpha == T{1.0}) + thresh = adj*2*n*testinghelpers::getEpsilon(); + else + thresh = adj*3*n*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_trmv( storage, uploa, transa, diaga, n, alpha, lda_inc, incx, thresh, is_mem_test); +} + +// Black box testing. +INSTANTIATE_TEST_SUITE_P( + BlackboxSmall, + strmvGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('u','l'), // uploa + ::testing::Values('n','t','c'), // transa + ::testing::Values('n','u'), // diaga , n=NONUNIT_DIAG u=UNIT_DIAG + ::testing::Range(gtint_t(1),gtint_t(21),1), // n + ::testing::Values(1.0 // Only blis typed api supports +#ifdef TEST_BLIS_TYPED // values of alpha other than 1 + , -2.2, 5.4, -1.0, 0.0 +#endif + ), // alpha + ::testing::Values(gtint_t(-1),gtint_t(1), gtint_t(33)), // incx + ::testing::Values(gtint_t(0), gtint_t(11)), // increment to the leading dim of a + ::testing::Values(false, true) // is memory test + ), + ::trmvGenericPrint() + ); + +// Black box testing. +INSTANTIATE_TEST_SUITE_P( + BlackboxMedium, + strmvGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('u','l'), // uploa + ::testing::Values('n','t','c'), // transa + ::testing::Values('n','u'), // diaga , n=NONUNIT_DIAG u=UNIT_DIAG + ::testing::Values(gtint_t(25), + gtint_t(33), + gtint_t(98), + gtint_t(173), + gtint_t(211) + ), // n + ::testing::Values(1.0 // Only blis typed api supports +#ifdef TEST_BLIS_TYPED // values of alpha other than 1 + , -2.2, 5.4, -1.0, 0.0 +#endif + ), // alpha + ::testing::Values(gtint_t(-1),gtint_t(1), gtint_t(33)), // incx + ::testing::Values(gtint_t(0), gtint_t(11)), // increment to the leading dim of a + ::testing::Values(false, true) // is memory test + ), + ::trmvGenericPrint() + ); diff --git a/gtestsuite/testsuite/level2/trmv/strmv_generic.cpp b/gtestsuite/testsuite/level2/trmv/strmv_generic.cpp deleted file mode 100644 index 470e556814..0000000000 --- a/gtestsuite/testsuite/level2/trmv/strmv_generic.cpp +++ /dev/null @@ -1,138 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include "test_trmv.h" - -class strmvTest : - public ::testing::TestWithParam> {}; - -TEST_P(strmvTest, RandomData) -{ - using T = float; - //---------------------------------------------------------- - // Initialize values from the parameters passed through - // test suite instantiation (INSTANTIATE_TEST_SUITE_P). - //---------------------------------------------------------- - // matrix storage format(row major, column major) - char storage = std::get<0>(GetParam()); - // denotes whether matrix a is u,l - char uploa = std::get<1>(GetParam()); - // denotes whether matrix a is n,c,t,h - char transa = std::get<2>(GetParam()); - // denotes whether matrix diag is u,n - char diaga = std::get<3>(GetParam()); - // matrix size n - gtint_t n = std::get<4>(GetParam()); - // specifies alpha value - T alpha = std::get<5>(GetParam()); - // stride size for x: - gtint_t incx = std::get<6>(GetParam()); - // lda increment. - // If increment is zero, then the array size matches the matrix size. - // If increment are nonnegative, the array size is bigger than the matrix size. - gtint_t lda_inc = std::get<7>(GetParam()); - - // Set the threshold for the errors: - double thresh = 10*n*testinghelpers::getEpsilon(); - - //---------------------------------------------------------- - // Call test body using these parameters - //---------------------------------------------------------- - test_trmv( storage, uploa, transa, diaga, n, alpha, lda_inc, incx, thresh ); -} - -class strmvTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char uploa = std::get<1>(str.param); - char transa = std::get<2>(str.param); - char diaga = std::get<3>(str.param); - gtint_t n = std::get<4>(str.param); - float alpha = std::get<5>(str.param); - gtint_t incx = std::get<6>(str.param); - gtint_t ld_inc = std::get<7>(str.param); -#ifdef TEST_BLAS - std::string str_name = "strmv_"; -#elif TEST_CBLAS - std::string str_name = "cblas_strmv"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_strmv"; -#endif - str_name = str_name + "_" + sfm; - str_name = str_name + "_" + uploa+transa; - str_name = str_name + "_d" + diaga; - str_name = str_name + "_" + std::to_string(n); - std::string alpha_str = ( alpha > 0) ? std::to_string(int(alpha)) : ("m" + std::to_string(int(std::abs(alpha)))); - str_name = str_name + "_a" + alpha_str; - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name = str_name + "_" + incx_str; - str_name = str_name + "_" + std::to_string(ld_inc); - return str_name; - } -}; - -// Black box testing. -INSTANTIATE_TEST_SUITE_P( - Blackbox, - strmvTest, - ::testing::Combine( - ::testing::Values('c' -#ifndef TEST_BLAS - ,'r' -#endif - ), // storage format - ::testing::Values('u','l'), // uploa - ::testing::Values('n','t'), // transa - ::testing::Values('n','u'), // diaga , n=NONUNIT_DIAG u=UNIT_DIAG - ::testing::Range(gtint_t(10), gtint_t(31), 10), // n - ::testing::Values( 1.0 -#ifdef TEST_BLIS_TYPED - , -2.0 -#endif - ), // alpha - ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(0), gtint_t(1)) // increment to the leading dim of a - ), - ::strmvTestPrint() - ); diff --git a/gtestsuite/testsuite/level2/trmv/test_trmv.h b/gtestsuite/testsuite/level2/trmv/test_trmv.h index d59f4412f7..a1b829edb6 100644 --- a/gtestsuite/testsuite/level2/trmv/test_trmv.h +++ b/gtestsuite/testsuite/level2/trmv/test_trmv.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -39,36 +39,147 @@ #include "inc/check_error.h" #include #include +#include "common/testing_helpers.h" template -void test_trmv( char storage, char uploa, char transa, char diaga, gtint_t n, - T alpha, gtint_t lda_inc, gtint_t incx, double thresh ) +void test_trmv( + char storage, + char uploa, + char transa, + char diaga, + gtint_t n, + T alpha, + gtint_t lda_inc, + gtint_t incx, + double thresh, + bool is_memory_test = false, + bool is_evt_test = false, + T evt_x = T{0}, + T evt_a = T{0} + ) { + using RT = typename testinghelpers::type_info::real_type; // Compute the leading dimensions for matrix size calculation. gtint_t lda = testinghelpers::get_leading_dimension( storage, transa, n, n, lda_inc ); //---------------------------------------------------------- // Initialize matrics with random integer numbers. //---------------------------------------------------------- - std::vector a = testinghelpers::get_random_matrix( -2, 8, storage, transa, n, n, lda ); - std::vector x = testinghelpers::get_random_vector( -10, 10, n, incx ); - testinghelpers::make_triangular( storage, uploa, n, a.data(), lda ); + dim_t size_a = testinghelpers::matsize(storage, transa, n, n, lda) * sizeof(T); - // Create a copy of c so that we can check reference results. - std::vector x_ref(x); + // Buffers for A matrix and X vector are always unaligned + testinghelpers::ProtectedBuffer a(size_a, false, is_memory_test ); + testinghelpers::datagenerators::randomgenerators( 0, 1, storage, n, n, (T*)(a.greenzone_1), transa, lda ); + + dim_t size_x = testinghelpers::buff_dim(n, incx) * sizeof(T); + testinghelpers::ProtectedBuffer x(size_x, false, is_memory_test ); + testinghelpers::datagenerators::randomgenerators( 1, 3, n, incx, (T*)(x.greenzone_1) ); + + T* a_ptr = (T*)(a.greenzone_1); + T* x_ptr = (T*)(x.greenzone_1); + + // Make A matix diagonal dominant to make sure that algorithm doesn't diverge + // This makes sure that the trmv problem is solvable + for ( dim_t a_dim = 0; a_dim < n; ++a_dim ) + { + a_ptr[ a_dim + (a_dim* lda) ] = a_ptr[ a_dim + (a_dim* lda) ] + T{RT(n)}; + } + + // add extreme values to the X vector + if ( is_evt_test ) + { + x_ptr[ (rand() % n) * std::abs(incx) ] = evt_x; + } + + // add extreme values to the A matrix + if ( is_evt_test ) + { + dim_t n_idx = rand() % n; + dim_t m_idx = (std::max)((dim_t)0, n_idx - 1); + a_ptr[ m_idx + (n_idx * lda) ] = evt_a; + a_ptr[ m_idx + (m_idx *lda) ] = evt_a; + } + + // skipped making A triangular + // A matrix being a non triangular matrix could be a better test + // because we are exepcted to read only from the upper or lower triangular + // part of the data, contents of the rest of the matrix should not change the + // result. + // testinghelpers::make_triangular( storage, uploa, n, a_ptr, lda ); + + // Create a copy of x so that we can check reference results. + std::vector x_ref(testinghelpers::buff_dim(n, incx)); + memcpy(x_ref.data(), x_ptr, size_x); //---------------------------------------------------------- // Call BLIS function //---------------------------------------------------------- - trmv( storage, uploa, transa, diaga, n, &alpha, a.data(), lda, x.data(), incx ); + // add signal handler for segmentation fault + testinghelpers::ProtectedBuffer::start_signal_handler(); + try + { + trmv( storage, uploa, transa, diaga, n, &alpha, a_ptr, lda, x_ptr, incx ); + if ( is_memory_test ) + { + memcpy(a.greenzone_2, a.greenzone_1, size_a); + memcpy(x.greenzone_2, x_ref.data(), size_x); + trmv( storage, uploa, transa, diaga, n, &alpha, (T*)a.greenzone_2, lda, (T*)x.greenzone_2, incx ); + } + } + catch(const std::exception& e) + { + // reset to default signal handler + testinghelpers::ProtectedBuffer::stop_signal_handler(); + + // show failure in case seg fault was detected + FAIL() << "Memory Test Failed"; + } + // reset to default signal handler + testinghelpers::ProtectedBuffer::stop_signal_handler(); //---------------------------------------------------------- // Call reference implementation. //---------------------------------------------------------- - testinghelpers::ref_trmv( storage, uploa, transa, diaga, n, &alpha, a.data(), lda, x_ref.data(), incx ); + testinghelpers::ref_trmv( storage, uploa, transa, diaga, n, &alpha, a_ptr, lda, x_ref.data(), incx ); //---------------------------------------------------------- // check component-wise error. //---------------------------------------------------------- - computediff( n, x.data(), x_ref.data(), incx, thresh ); + computediff( "x", n, x_ptr, x_ref.data(), incx, thresh, is_evt_test ); + +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif } + +// Test-case logger : Used to print the test-case details based on parameters +template +class trmvGenericPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char storage = std::get<0>(str.param); + char uploa = std::get<1>(str.param); + char transa = std::get<2>(str.param); + char diaga = std::get<3>(str.param); + gtint_t n = std::get<4>(str.param); + T alpha = std::get<5>(str.param); + gtint_t incx = std::get<6>(str.param); + gtint_t lda_inc = std::get<7>(str.param); + bool is_mem_test = std::get<8>(str.param); + + std::string str_name = API_PRINT; + str_name += "_stor_" + std::string(&storage, 1); + str_name += "_uploa_" + std::string(&uploa, 1); + str_name += "_transa_" + std::string(&transa, 1); + str_name += "_diaga_" + std::string(&diaga, 1); + str_name += "_n_" + std::to_string(n); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + str_name += "_incx_" + testinghelpers::get_value_string(incx); + gtint_t lda = testinghelpers::get_leading_dimension( storage, transa, n, n, lda_inc ); + str_name += "_lda_i" + std::to_string(lda_inc) + "_" + std::to_string(lda); + str_name = str_name + (is_mem_test ? "_mem_test_enabled" : "_mem_test_disabled"); + return str_name; + } +}; diff --git a/gtestsuite/testsuite/level2/trmv/trmv.h b/gtestsuite/testsuite/level2/trmv/trmv.h index 8ee3750a62..bcebc97997 100644 --- a/gtestsuite/testsuite/level2/trmv/trmv.h +++ b/gtestsuite/testsuite/level2/trmv/trmv.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -36,6 +36,7 @@ #include "blis.h" #include "common/testing_helpers.h" +#include "inc/check_error.h" /** * @brief Performs the operation @@ -69,6 +70,22 @@ static void trmv_( char uploa, char transa, char diaga, gtint_t n, throw std::runtime_error("Error in testsuite/level2/trmv.h: Invalid typename in trmv_()."); } +template +static void trmv_blis_impl( char uploa, char transa, char diaga, gtint_t n, + T *ap, gtint_t lda, T *xp, gtint_t incx ) +{ + if constexpr (std::is_same::value) + strmv_blis_impl( &uploa, &transa, &diaga, &n, ap, &lda, xp, &incx ); + else if constexpr (std::is_same::value) + dtrmv_blis_impl( &uploa, &transa, &diaga, &n, ap, &lda, xp, &incx ); + else if constexpr (std::is_same::value) + ctrmv_blis_impl( &uploa, &transa, &diaga, &n, ap, &lda, xp, &incx ); + else if constexpr (std::is_same::value) + ztrmv_blis_impl( &uploa, &transa, &diaga, &n, ap, &lda, xp, &incx ); + else + throw std::runtime_error("Error in testsuite/level2/trmv.h: Invalid typename in trmv_blis_impl()."); +} + template static void cblas_trmv( char storage, char uploa, char transa, char diaga, gtint_t n, T *ap, gtint_t lda, T *xp, gtint_t incx ) @@ -134,11 +151,39 @@ template static void trmv( char storage, char uploa, char transa, char diaga, gtint_t n, T *alpha, T *ap, gtint_t lda, T *xp, gtint_t incx ) { -#if (defined TEST_BLAS || defined TEST_CBLAS) +#if (defined TEST_BLAS_LIKE || defined TEST_CBLAS) T one; testinghelpers::initone(one); #endif +#ifdef TEST_UPPERCASE_ARGS + storage = static_cast(std::toupper(static_cast(storage))); + uploa = static_cast(std::toupper(static_cast(uploa))); + transa = static_cast(std::toupper(static_cast(transa))); + diaga = static_cast(std::toupper(static_cast(diaga))); +#endif + +#ifdef TEST_INPUT_ARGS + // Create copy of scalar input values so we can check that they are not altered. + char storage_cpy = storage; + char uploa_cpy = uploa; + char transa_cpy = transa; + char diaga_cpy = diaga; + gtint_t n_cpy = n; + T* alpha_cpy = alpha; + gtint_t lda_cpy = lda; + gtint_t incx_cpy = incx; + + // Create copy of input arrays so we can check that they are not altered. + T* ap_cpy = nullptr; + gtint_t size_ap = testinghelpers::matsize( storage, transa, n, n, lda ); + if (ap && size_ap > 0) + { + ap_cpy = new T[size_ap]; + memcpy( ap_cpy, ap, size_ap * sizeof( T ) ); + } +#endif + #ifdef TEST_BLAS if(( storage == 'c' || storage == 'C' )) if( *alpha == one ) @@ -147,6 +192,14 @@ static void trmv( char storage, char uploa, char transa, char diaga, throw std::runtime_error("Error in testsuite/level2/trmv.h: BLAS interface cannot be tested for alpha != one."); else throw std::runtime_error("Error in testsuite/level2/trmv.h: BLAS interface cannot be tested for row-major order."); +#elif TEST_BLAS_BLIS_IMPL + if(( storage == 'c' || storage == 'C' )) + if( *alpha == one ) + trmv_blis_impl( uploa, transa, diaga, n, ap, lda, xp, incx ); + else + throw std::runtime_error("Error in testsuite/level2/trmv.h: BLAS_BLIS_IMPL interface cannot be tested for alpha != one."); + else + throw std::runtime_error("Error in testsuite/level2/trmv.h: BLAS_BLIS_IMPL interface cannot be tested for row-major order."); #elif TEST_CBLAS if( *alpha == one ) cblas_trmv( storage, uploa, transa, diaga, n, ap, lda, xp, incx ); @@ -157,4 +210,29 @@ static void trmv( char storage, char uploa, char transa, char diaga, #else throw std::runtime_error("Error in testsuite/level2/trmv.h: No interfaces are set to be tested."); #endif + +#ifdef TEST_INPUT_ARGS + //---------------------------------------------------------- + // Check scalar inputs have not been modified. + //---------------------------------------------------------- + + computediff( "storage", storage, storage_cpy ); + computediff( "uploa", uploa, uploa_cpy ); + computediff( "transa", transa, transa_cpy ); + computediff( "diaga", diaga, diaga_cpy ); + computediff( "n", n, n_cpy ); + if (alpha) computediff( "alpha", *alpha, *alpha_cpy ); + computediff( "lda", lda, lda_cpy ); + computediff( "incx", incx, incx_cpy ); + + //---------------------------------------------------------- + // Bitwise-wise check array inputs have not been modified. + //---------------------------------------------------------- + + if (ap && size_ap > 0) + { + computediff( "A", storage, n, n, ap, ap_cpy, lda, true ); + delete[] ap_cpy; + } +#endif } diff --git a/gtestsuite/testsuite/level2/trmv/ztrmv/ztrmv_generic.cpp b/gtestsuite/testsuite/level2/trmv/ztrmv/ztrmv_generic.cpp new file mode 100644 index 0000000000..3248ec7167 --- /dev/null +++ b/gtestsuite/testsuite/level2/trmv/ztrmv/ztrmv_generic.cpp @@ -0,0 +1,158 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "level2/trmv/test_trmv.h" + +class ztrmvGeneric : + public ::testing::TestWithParam> {}; // is memory test + +TEST_P( ztrmvGeneric, API ) +{ + using T = dcomplex; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // matrix storage format(row major, column major) + char storage = std::get<0>(GetParam()); + // denotes whether matrix A is u,l + char uploa = std::get<1>(GetParam()); + // denotes whether matrix A is n,c,t,h + char transa = std::get<2>(GetParam()); + // denotes whether matrix diag is u,n + char diaga = std::get<3>(GetParam()); + // matrix size n + gtint_t n = std::get<4>(GetParam()); + // specifies alpha value + T alpha = std::get<5>(GetParam()); + // increment for x (incx): + gtint_t incx = std::get<6>(GetParam()); + // lda increment. + // If increment is zero, then the array size matches the matrix size. + // If increment are nonnegative, the array size is bigger than the matrix size. + gtint_t lda_inc = std::get<7>(GetParam()); + bool is_mem_test = std::get<8>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite trmv.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // With adjustment for complex data. + double thresh; +#ifdef BLIS_INT_ELEMENT_TYPE + double adj = 1.0; +#else + double adj = 1.0; +#endif + if (n == 0 || alpha == T{0.0}) + thresh = 0.0; + else + if(alpha == T{1.0}) + thresh = adj*2*n*testinghelpers::getEpsilon(); + else + thresh = adj*3*n*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_trmv( storage, uploa, transa, diaga, n, alpha, lda_inc, incx, thresh, is_mem_test ); +} + +// Black box testing. +INSTANTIATE_TEST_SUITE_P( + BlackboxSmall, + ztrmvGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('u','l'), // uploa + ::testing::Values('n','t','c'), // transa + ::testing::Values('n','u'), // diaga , n=NONUNIT_DIAG u=UNIT_DIAG + ::testing::Range(gtint_t(1),gtint_t(21),1), // n + ::testing::Values(dcomplex{1.0, 0.0} // Only blis typed api supports +#ifdef TEST_BLIS_TYPED // values of alpha other than 1 + ,dcomplex{6.1, -2.9}, dcomplex{-3.3, -1.4} + ,dcomplex{-1.0, 0.0}, dcomplex{0.0, 0.0} +#endif + ), // alpha + ::testing::Values(gtint_t(-1),gtint_t(1), gtint_t(33)), // incx + ::testing::Values(gtint_t(0), gtint_t(11)), // increment to the leading dim of a + ::testing::Values(false, true) // is memory test + ), + ::trmvGenericPrint() + ); + +// Black box testing. +INSTANTIATE_TEST_SUITE_P( + BlackboxMedium, + ztrmvGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('u','l'), // uploa + ::testing::Values('n','t','c'), // transa + ::testing::Values('n','u'), // diaga , n=NONUNIT_DIAG u=UNIT_DIAG + ::testing::Values(gtint_t(25), + gtint_t(33), + gtint_t(98), + gtint_t(173), + gtint_t(211) + ), // n + ::testing::Values(dcomplex{1.0, 0.0} // Only blis typed api supports +#ifdef TEST_BLIS_TYPED // values of alpha other than 1 + ,dcomplex{6.1, -2.9}, dcomplex{-3.3, -1.4} + ,dcomplex{-1.0, 0.0}, dcomplex{0.0, 0.0} +#endif + ), // alpha + ::testing::Values(gtint_t(-1),gtint_t(1), gtint_t(33)), // incx + ::testing::Values(gtint_t(0), gtint_t(11)), // increment to the leading dim of a + ::testing::Values(false, true) // is memory test + ), + ::trmvGenericPrint() + ); diff --git a/gtestsuite/testsuite/level2/trmv/ztrmv_generic.cpp b/gtestsuite/testsuite/level2/trmv/ztrmv_generic.cpp deleted file mode 100644 index 1fb53d2b7d..0000000000 --- a/gtestsuite/testsuite/level2/trmv/ztrmv_generic.cpp +++ /dev/null @@ -1,139 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include "test_trmv.h" - -class ztrmvTest : - public ::testing::TestWithParam> {}; - -TEST_P(ztrmvTest, RandomData) -{ - using T = dcomplex; - //---------------------------------------------------------- - // Initialize values from the parameters passed through - // test suite instantiation (INSTANTIATE_TEST_SUITE_P). - //---------------------------------------------------------- - // matrix storage format(row major, column major) - char storage = std::get<0>(GetParam()); - // denotes whether matrix a is u,l - char uploa = std::get<1>(GetParam()); - // denotes whether matrix a is n,c,t,h - char transa = std::get<2>(GetParam()); - // denotes whether matrix diag is u,n - char diaga = std::get<3>(GetParam()); - // matrix size n - gtint_t n = std::get<4>(GetParam()); - // specifies alpha value - T alpha = std::get<5>(GetParam()); - // stride size for x: - gtint_t incx = std::get<6>(GetParam()); - // lda increment. - // If increment is zero, then the array size matches the matrix size. - // If increment are nonnegative, the array size is bigger than the matrix size. - gtint_t lda_inc = std::get<7>(GetParam()); - - // Set the threshold for the errors: - double thresh = 10*n*testinghelpers::getEpsilon(); - - //---------------------------------------------------------- - // Call test body using these parameters - //---------------------------------------------------------- - test_trmv( storage, uploa, transa, diaga, n, alpha, lda_inc, incx, thresh ); -} - -class ztrmvTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char uploa = std::get<1>(str.param); - char transa = std::get<2>(str.param); - char diaga = std::get<3>(str.param); - gtint_t n = std::get<4>(str.param); - dcomplex alpha = std::get<5>(str.param); - gtint_t incx = std::get<6>(str.param); - gtint_t ld_inc = std::get<7>(str.param); -#ifdef TEST_BLAS - std::string str_name = "ztrmv_"; -#elif TEST_CBLAS - std::string str_name = "cblas_ztrmv"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_ztrmv"; -#endif - str_name = str_name + "_" + sfm; - str_name = str_name + "_" + uploa+transa; - str_name = str_name + "_d" + diaga; - str_name = str_name + "_" + std::to_string(n); - std::string alpha_str = ( alpha.real > 0) ? std::to_string(int(alpha.real)) : ("m" + std::to_string(int(std::abs(alpha.real)))); - alpha_str = alpha_str + "pi" + (( alpha.imag > 0) ? std::to_string(int(alpha.imag)) : ("m" + std::to_string(int(std::abs(alpha.imag))))); - str_name = str_name + "_a" + alpha_str; - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name = str_name + "_" + incx_str; - str_name = str_name + "_" + std::to_string(ld_inc); - return str_name; - } -}; - -// Black box testing. -INSTANTIATE_TEST_SUITE_P( - Blackbox, - ztrmvTest, - ::testing::Combine( - ::testing::Values('c' -#ifndef TEST_BLAS - ,'r' -#endif - ), // storage format - ::testing::Values('u','l'), // uploa - ::testing::Values('n','c','t'), // transa - ::testing::Values('n','u'), // diaga , n=NONUNIT_DIAG u=UNIT_DIAG - ::testing::Range(gtint_t(10), gtint_t(31), 10), // n - ::testing::Values(dcomplex{1.0, 0.0} -#ifdef TEST_BLIS_TYPED - ,dcomplex{1.0, -2.0} -#endif - ), // alpha - ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(0), gtint_t(5)) // increment to the leading dim of a - ), - ::ztrmvTestPrint() - ); diff --git a/gtestsuite/testsuite/level2/trsv/IIT_ERS/trsv_IIT_ERS_test.cpp b/gtestsuite/testsuite/level2/trsv/IIT_ERS/trsv_IIT_ERS_test.cpp new file mode 100644 index 0000000000..46f8ad8df8 --- /dev/null +++ b/gtestsuite/testsuite/level2/trsv/IIT_ERS/trsv_IIT_ERS_test.cpp @@ -0,0 +1,329 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "level2/trsv/test_trsv.h" +#include "inc/check_error.h" +#include "common/testing_helpers.h" +#include "common/wrong_inputs_helpers.h" +#include +#include +#include + +template +class trsv_IIT_ERS : public ::testing::Test {}; +typedef ::testing::Types TypeParam; +TYPED_TEST_SUITE(trsv_IIT_ERS, TypeParam); + +// Adding namespace to get default parameters(valid case) from testinghelpers/common/wrong_input_helpers.h. +using namespace testinghelpers::IIT; + +#if defined(TEST_CBLAS) +#define INFO_OFFSET 1 +#else +#define INFO_OFFSET 0 +#endif + +#if defined(TEST_CBLAS) + +/** + * @brief Test TRSV when STORAGE argument is incorrect + * when info == 1 + * + */ +TYPED_TEST(trsv_IIT_ERS, invalid_storage) +{ + using T = TypeParam; + T alpha = T{1}; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + trsv( 'x', UPLO, TRANS, DIAG, N, &alpha, nullptr, LDA, nullptr, INC); +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 1 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector a = testinghelpers::get_random_matrix( 1, 5, STORAGE, TRANS, M, N, LDA); + std::vector x = testinghelpers::get_random_vector(0, 1, N, INC); + std::vector x_ref(x); + + trsv( 'x', UPLO, TRANS, DIAG, N, &alpha, a.data(), LDA, x.data(), INC); + computediff( "x", N, x.data(), x_ref.data(), INC ); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 1 ); +#endif +} + +#endif + +#if defined(TEST_BLAS_LIKE) || defined(TEST_CBLAS) + +/* + Incorrect Input Testing(IIT) + + BLAS exceptions get triggered in the following cases(for TRSV): + 1. When UPLO != 'L' || UPLO != 'U' (info = 1) + 2. When TRANS != 'N' || TRANS != 'T' || TRANS != 'C' (info = 2) + 3. When DIAG != 'U' || DIAG != 'N' (info = 3) + 4. When n < 0 (info = 4) + 5. When lda < N (info = 6) + 6. When incx == 0 (info = 8) +*/ + + +/** + * @brief Test TRSV when UPLO argument is incorrect + * when info == 1 + * + */ +TYPED_TEST(trsv_IIT_ERS, invalid_UPLO) +{ + using T = TypeParam; + T alpha = T{1}; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + trsv( STORAGE, 'A', TRANS, DIAG, N, &alpha, nullptr, LDA, nullptr, INC); +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, INFO_OFFSET+1 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector a = testinghelpers::get_random_matrix( 1, 5, STORAGE, TRANS, M, N, LDA); + std::vector x = testinghelpers::get_random_vector(0, 1, N, INC); + std::vector x_ref(x); + + trsv( STORAGE, 'A', TRANS, DIAG, N, &alpha, a.data(), LDA, x.data(), INC); + computediff( "x", N, x.data(), x_ref.data(), INC ); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, INFO_OFFSET+1 ); +#endif +} + +/** + * @brief Test TRSV when TRANS argument is incorrect + * when info == 2 + * + */ +TYPED_TEST(trsv_IIT_ERS, invalid_TRANS) +{ + using T = TypeParam; + T alpha = T{1}; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + trsv( STORAGE, UPLO, 'A', DIAG, N, &alpha, nullptr, LDA, nullptr, INC); +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, INFO_OFFSET+2 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector a = testinghelpers::get_random_matrix( 1, 5, STORAGE, TRANS, M, N, LDA); + std::vector x = testinghelpers::get_random_vector(0, 1, N, INC); + std::vector x_ref(x); + + trsv( STORAGE, UPLO, 'A', DIAG, N, &alpha, a.data(), LDA, x.data(), INC); + computediff( "x", N, x.data(), x_ref.data(), INC ); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, INFO_OFFSET+2 ); +#endif +} + +/** + * @brief Test TRSV when DIAG argument is incorrect + * when info == 3 + */ +TYPED_TEST(trsv_IIT_ERS, invalid_DIAG) +{ + using T = TypeParam; + T alpha = T{1}; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + trsv( STORAGE, UPLO, TRANS, 'A', N, &alpha, nullptr, LDA, nullptr, INC); +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, INFO_OFFSET+3 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector a = testinghelpers::get_random_matrix( 1, 5, STORAGE, TRANS, M, N, LDA); + std::vector x = testinghelpers::get_random_vector(0, 1, N, INC); + std::vector x_ref(x); + + trsv( STORAGE, UPLO, TRANS, 'A', N, &alpha, a.data(), LDA, x.data(), INC); + computediff( "x", N, x.data(), x_ref.data(), INC ); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, INFO_OFFSET+3 ); +#endif +} + +/** + * @brief Test TRSV when N is negative + * when info == 4 + */ +TYPED_TEST(trsv_IIT_ERS, invalid_n) +{ + using T = TypeParam; + T alpha = T{1}; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + trsv( STORAGE, UPLO, TRANS, DIAG, -1, &alpha, nullptr, LDA, nullptr, INC); +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 4 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector a = testinghelpers::get_random_matrix( 1, 5, STORAGE, TRANS, M, N, LDA); + std::vector x = testinghelpers::get_random_vector(0, 1, N, INC); + std::vector x_ref(x); + + trsv( STORAGE, UPLO, TRANS, DIAG, -1, &alpha, a.data(), LDA, x.data(), INC); + computediff( "x", N, x.data(), x_ref.data(), INC ); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 4 ); +#endif +} + + +/** + * @brief Test TRSV when lda < max(1, N) + * when info == 6 + */ +TYPED_TEST(trsv_IIT_ERS, invalid_lda) +{ + using T = TypeParam; + T alpha = T{1}; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + trsv( STORAGE, UPLO, TRANS, DIAG, N, &alpha, nullptr, LDA - 1, nullptr, INC); +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 6 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector a = testinghelpers::get_random_matrix( 1, 5, STORAGE, TRANS, M, N, LDA); + std::vector x = testinghelpers::get_random_vector(0, 1, N, INC); + std::vector x_ref(x); + + trsv( STORAGE, UPLO, TRANS, DIAG, N, &alpha, a.data(), LDA - 1, x.data(), INC); + computediff( "x", N, x.data(), x_ref.data(), INC ); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 6 ); +#endif +} + +/** + * @brief Test TRSV when INCX == 0 + * when info == 8 + */ +TYPED_TEST(trsv_IIT_ERS, invalid_incx) +{ + using T = TypeParam; + T alpha = T{1}; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + trsv( STORAGE, UPLO, TRANS, DIAG, N, &alpha, nullptr, LDA, nullptr, 0); +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 8 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector a = testinghelpers::get_random_matrix( 1, 5, STORAGE, TRANS, M, N, LDA); + std::vector x = testinghelpers::get_random_vector(0, 1, N, INC); + std::vector x_ref(x); + + trsv( STORAGE, UPLO, TRANS, DIAG, N, &alpha, a.data(), LDA, x.data(), 0); + computediff( "x", N, x.data(), x_ref.data(), INC ); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 8 ); +#endif +} + + +/* + Early Return Scenarios(ERS) : + + The TRSV API is expected to return early in the following cases: + + 1. When n == 0. + +*/ + +/** + * @brief Test TRSV when N is zero + */ +TYPED_TEST(trsv_IIT_ERS, n_eq_zero) +{ + using T = TypeParam; + T alpha = T{1}; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + trsv( STORAGE, UPLO, TRANS, DIAG, 0, &alpha, nullptr, LDA, nullptr, INC); +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector a = testinghelpers::get_random_matrix( 1, 5, STORAGE, TRANS, M, N, LDA); + std::vector x = testinghelpers::get_random_vector(0, 1, N, INC); + std::vector x_ref(x); + + trsv( STORAGE, UPLO, TRANS, DIAG, 0, &alpha, a.data(), LDA, x.data(), INC); + computediff( "x", N, x.data(), x_ref.data(), INC ); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif +} + +#endif diff --git a/gtestsuite/testsuite/level2/trsv/ctrsv/ctrsv_generic.cpp b/gtestsuite/testsuite/level2/trsv/ctrsv/ctrsv_generic.cpp new file mode 100644 index 0000000000..b38a731c94 --- /dev/null +++ b/gtestsuite/testsuite/level2/trsv/ctrsv/ctrsv_generic.cpp @@ -0,0 +1,158 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "level2/trsv/test_trsv.h" + +class ctrsvGeneric : + public ::testing::TestWithParam> {}; // is memory test + +TEST_P( ctrsvGeneric, API ) +{ + using T = scomplex; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // matrix storage format(row major, column major) + char storage = std::get<0>(GetParam()); + // denotes whether matrix A is u,l + char uploa = std::get<1>(GetParam()); + // denotes whether matrix A is n,c,t,h + char transa = std::get<2>(GetParam()); + // denotes whether matrix diag is u,n + char diaga = std::get<3>(GetParam()); + // matrix size n + gtint_t n = std::get<4>(GetParam()); + // specifies alpha value + T alpha = std::get<5>(GetParam()); + // increment for x (incx): + gtint_t incx = std::get<6>(GetParam()); + // lda increment. + // If increment is zero, then the array size matches the matrix size. + // If increment are nonnegative, the array size is bigger than the matrix size. + gtint_t lda_inc = std::get<7>(GetParam()); + bool is_mem_test = std::get<8>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite trsv.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // With adjustment for complex data. + double thresh; +#ifdef BLIS_INT_ELEMENT_TYPE + double adj = 1.0; +#else + double adj = 1.5; +#endif + if (n == 0 || alpha == T{0.0}) + thresh = 0.0; + else + if(alpha == T{1.0}) + thresh = adj*2*n*testinghelpers::getEpsilon(); + else + thresh = adj*3*n*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_trsv( storage, uploa, transa, diaga, n, alpha, lda_inc, incx, thresh, is_mem_test ); +} + +// Black box testing. +INSTANTIATE_TEST_SUITE_P( + BlackboxSmall, + ctrsvGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('u','l'), // uploa + ::testing::Values('n','t','c'), // transa + ::testing::Values('n','u'), // diaga , n=NONUNIT_DIAG u=UNIT_DIAG + ::testing::Range(gtint_t(1),gtint_t(21),1), // n + ::testing::Values(scomplex{1.0, 0.0} // Only blis typed api supports +#ifdef TEST_BLIS_TYPED // values of alpha other than 1 + ,scomplex{6.1, -2.9}, scomplex{-3.3, -1.4} + ,scomplex{-1.0, 0.0}, scomplex{0.0, 0.0} +#endif + ), // alpha + ::testing::Values(gtint_t(-1),gtint_t(1), gtint_t(33)), // incx + ::testing::Values(gtint_t(0), gtint_t(11)), // increment to the leading dim of a + ::testing::Values(false, true) // is memory test + ), + ::trsvGenericPrint() + ); + +// Black box testing. +INSTANTIATE_TEST_SUITE_P( + BlackboxMedium, + ctrsvGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('u','l'), // uploa + ::testing::Values('n','t','c'), // transa + ::testing::Values('n','u'), // diaga , n=NONUNIT_DIAG u=UNIT_DIAG + ::testing::Values(gtint_t(25), + gtint_t(33), + gtint_t(98), + gtint_t(173), + gtint_t(211) + ), // n + ::testing::Values(scomplex{1.0, 0.0} // Only blis typed api supports +#ifdef TEST_BLIS_TYPED // values of alpha other than 1 + ,scomplex{6.1, -2.9}, scomplex{-3.3, -1.4} + ,scomplex{-1.0, 0.0}, scomplex{0.0, 0.0} +#endif + ), // alpha + ::testing::Values(gtint_t(-1),gtint_t(1), gtint_t(33)), // incx + ::testing::Values(gtint_t(0), gtint_t(11)), // increment to the leading dim of a + ::testing::Values(false, true) // is memory test + ), + ::trsvGenericPrint() + ); diff --git a/gtestsuite/testsuite/level2/trsv/ctrsv_generic.cpp b/gtestsuite/testsuite/level2/trsv/ctrsv_generic.cpp deleted file mode 100644 index 1639e7202c..0000000000 --- a/gtestsuite/testsuite/level2/trsv/ctrsv_generic.cpp +++ /dev/null @@ -1,139 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include "test_trsv.h" - -class ctrsvTest : - public ::testing::TestWithParam> {}; - -TEST_P(ctrsvTest, RandomData) -{ - using T = scomplex; - //---------------------------------------------------------- - // Initialize values from the parameters passed through - // test suite instantiation (INSTANTIATE_TEST_SUITE_P). - //---------------------------------------------------------- - // matrix storage format(row major, column major) - char storage = std::get<0>(GetParam()); - // denotes whether matrix a is u,l - char uploa = std::get<1>(GetParam()); - // denotes whether matrix a is n,c,t,h - char transa = std::get<2>(GetParam()); - // denotes whether matrix diag is u,n - char diaga = std::get<3>(GetParam()); - // matrix size n - gtint_t n = std::get<4>(GetParam()); - // specifies alpha value - T alpha = std::get<5>(GetParam()); - // stride size for x: - gtint_t incx = std::get<6>(GetParam()); - // lda increment. - // If increment is zero, then the array size matches the matrix size. - // If increment are nonnegative, the array size is bigger than the matrix size. - gtint_t lda_inc = std::get<7>(GetParam()); - - // Set the threshold for the errors: - double thresh = 5*n*testinghelpers::getEpsilon(); - - //---------------------------------------------------------- - // Call test body using these parameters - //---------------------------------------------------------- - test_trsv( storage, uploa, transa, diaga, n, alpha, lda_inc, incx, thresh ); -} - -class ctrsvTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char uploa = std::get<1>(str.param); - char transa = std::get<2>(str.param); - char diaga = std::get<3>(str.param); - gtint_t n = std::get<4>(str.param); - scomplex alpha = std::get<5>(str.param); - gtint_t incx = std::get<6>(str.param); - gtint_t ld_inc = std::get<7>(str.param); -#ifdef TEST_BLAS - std::string str_name = "ctrsv_"; -#elif TEST_CBLAS - std::string str_name = "cblas_ctrsv"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_ctrsv"; -#endif - str_name = str_name + "_" + sfm; - str_name = str_name + "_" + uploa+transa; - str_name = str_name + "_d" + diaga; - str_name = str_name + "_" + std::to_string(n); - std::string alpha_str = ( alpha.real > 0) ? std::to_string(int(alpha.real)) : ("m" + std::to_string(int(std::abs(alpha.real)))); - alpha_str = alpha_str + "pi" + (( alpha.imag > 0) ? std::to_string(int(alpha.imag)) : ("m" + std::to_string(int(std::abs(alpha.imag))))); - str_name = str_name + "_a" + alpha_str; - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name = str_name + "_" + incx_str; - str_name = str_name + "_" + std::to_string(ld_inc); - return str_name; - } -}; - -// Black box testing. -INSTANTIATE_TEST_SUITE_P( - Blackbox, - ctrsvTest, - ::testing::Combine( - ::testing::Values('c' -#ifndef TEST_BLAS - ,'r' -#endif - ), // storage format - ::testing::Values('u','l'), // uploa - ::testing::Values('n','c','t'), // transa - ::testing::Values('n','u'), // diaga , n=NONUNIT_DIAG u=UNIT_DIAG - ::testing::Range(gtint_t(10), gtint_t(31), 10), // n - ::testing::Values(scomplex{1.0, 0.0} -#ifdef TEST_BLIS_TYPED - , scomplex{1.0, -2.0} -#endif - ), // alpha - ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(0), gtint_t(3)) // increment to the leading dim of a - ), - ::ctrsvTestPrint() - ); diff --git a/gtestsuite/testsuite/level2/trsv/dtrsv/dtrsv_evt.cpp b/gtestsuite/testsuite/level2/trsv/dtrsv/dtrsv_evt.cpp new file mode 100644 index 0000000000..5bc011dee8 --- /dev/null +++ b/gtestsuite/testsuite/level2/trsv/dtrsv/dtrsv_evt.cpp @@ -0,0 +1,134 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "level2/trsv/test_trsv.h" + +class dtrsvEVT : + public ::testing::TestWithParam> {}; // ld_inc + +TEST_P( dtrsvEVT, API ) +{ + using T = double; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // matrix storage format(row major, column major) + char storage = std::get<0>(GetParam()); + // denotes whether matrix a is u,l + char uploa = std::get<1>(GetParam()); + // denotes whether matrix a is n,c,t,h + char transa = std::get<2>(GetParam()); + // denotes whether matrix diag is u,n + char diaga = std::get<3>(GetParam()); + // matrix size n + gtint_t n = std::get<4>(GetParam()); + // specifies alpha value + T alpha = std::get<5>(GetParam()); + // stride size for x: + gtint_t incx = std::get<6>(GetParam()); + // extreme value for x + double xexval = std::get<7>(GetParam()); + // extreme value for A + double aexval = std::get<8>(GetParam()); + // lda increment. + // If increment is zero, then the array size matches the matrix size. + // If increment are nonnegative, the array size is bigger than the matrix size. + gtint_t lda_inc = std::get<9>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite trsv.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (n == 0 || alpha == T{0.0}) + thresh = 0.0; + else + if(alpha == T{1.0}) + thresh = 2*n*testinghelpers::getEpsilon(); + else + thresh = 3*n*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_trsv( storage, uploa, transa, diaga, n, alpha, lda_inc, incx, thresh, false, true, xexval, aexval); +} + +static double AOCL_NAN = std::numeric_limits::quiet_NaN(); +static double AOCL_INF = std::numeric_limits::infinity(); + +INSTANTIATE_TEST_SUITE_P( + Native, + dtrsvEVT, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('u','l'), // uploa + ::testing::Values('n','t'), // transa + ::testing::Values('n','u'), // diaga , n=NONUNIT_DIAG u=UNIT_DIAG + ::testing::Values(gtint_t(32), + gtint_t(24), + gtint_t(8), + gtint_t(4), + gtint_t(2), + gtint_t(1), + gtint_t(15) + ), // n (random values) + ::testing::Values( 1.0 +#ifdef TEST_BLIS_TYPED + , -2.2, 5.4, -1.0, 0.0 +#endif + ), // alpha + ::testing::Values(gtint_t(-2), gtint_t(-1), + gtint_t( 1), gtint_t( 2)), // stride size for x + ::testing::Values(AOCL_NAN, -AOCL_INF, AOCL_INF, 1 /*,0 <-fail*/),// exception value for x + ::testing::Values(AOCL_NAN, -AOCL_INF, AOCL_INF, 0), // exception value for A + ::testing::Values(gtint_t(0), gtint_t(10)) // increment to the leading dim of a + ), + ::trsvEVTPrint() + ); diff --git a/gtestsuite/testsuite/level2/trsv/dtrsv/dtrsv_generic.cpp b/gtestsuite/testsuite/level2/trsv/dtrsv/dtrsv_generic.cpp new file mode 100644 index 0000000000..cf90ffda0a --- /dev/null +++ b/gtestsuite/testsuite/level2/trsv/dtrsv/dtrsv_generic.cpp @@ -0,0 +1,159 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "level2/trsv/test_trsv.h" + +class dtrsvGeneric : + public ::testing::TestWithParam> {}; // is memory test + +TEST_P( dtrsvGeneric, API ) +{ + using T = double; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // matrix storage format(row major, column major) + char storage = std::get<0>(GetParam()); + // denotes whether matrix A is u,l + char uploa = std::get<1>(GetParam()); + // denotes whether matrix A is n,c,t,h + char transa = std::get<2>(GetParam()); + // denotes whether matrix diag is u,n + char diaga = std::get<3>(GetParam()); + // matrix size n + gtint_t n = std::get<4>(GetParam()); + // specifies alpha value + T alpha = std::get<5>(GetParam()); + // increment for x (incx): + gtint_t incx = std::get<6>(GetParam()); + // lda increment. + // If increment is zero, then the array size matches the matrix size. + // If increment are nonnegative, the array size is bigger than the matrix size. + gtint_t lda_inc = std::get<7>(GetParam()); + bool is_mem_test = std::get<8>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite trsv.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + // Threshold adjustment +#ifdef BLIS_INT_ELEMENT_TYPE + double adj = 1.0; +#else + double adj = 7.5; + #ifdef REF_IS_MKL + adj = 8.3; + #endif +#endif + if (n == 0 || alpha == T{0.0}) + thresh = 0.0; + else + if(alpha == T{1.0}) + thresh = adj*2*n*testinghelpers::getEpsilon(); + else + thresh = adj*3*n*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_trsv( storage, uploa, transa, diaga, n, alpha, lda_inc, incx, thresh, is_mem_test); +} + +// Black box testing. +INSTANTIATE_TEST_SUITE_P( + BlackboxSmall, + dtrsvGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('u','l'), // uploa + ::testing::Values('n','t','c'), // transa + ::testing::Values('n','u'), // diaga , n=NONUNIT_DIAG u=UNIT_DIAG + ::testing::Range(gtint_t(1),gtint_t(21),1), // n + ::testing::Values(1.0 // Only blis typed api supports +#ifdef TEST_BLIS_TYPED // values of alpha other than 1 + , -2.2, 5.4, -1.0, 0.0 +#endif + ), // alpha + ::testing::Values(gtint_t(-1),gtint_t(1), gtint_t(33)), // incx + ::testing::Values(gtint_t(0), gtint_t(11)), // increment to the leading dim of a + ::testing::Values(false, true) // is memory test + ), + ::trsvGenericPrint() + ); + +// Black box testing. +INSTANTIATE_TEST_SUITE_P( + BlackboxMedium, + dtrsvGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('u','l'), // uploa + ::testing::Values('n','t','c'), // transa + ::testing::Values('n','u'), // diaga , n=NONUNIT_DIAG u=UNIT_DIAG + ::testing::Values(gtint_t(25), + gtint_t(33), + gtint_t(98), + gtint_t(173), + gtint_t(211) + ), // n + ::testing::Values(1.0 // Only blis typed api supports +#ifdef TEST_BLIS_TYPED // values of alpha other than 1 + , -2.2, 5.4, -1.0, 0.0 +#endif + ), // alpha + ::testing::Values(gtint_t(-1),gtint_t(1), gtint_t(33)), // incx + ::testing::Values(gtint_t(0), gtint_t(11)), // increment to the leading dim of a + ::testing::Values(false, true) // is memory test + ), + ::trsvGenericPrint() + ); diff --git a/gtestsuite/testsuite/level2/trsv/dtrsv_generic.cpp b/gtestsuite/testsuite/level2/trsv/dtrsv_generic.cpp deleted file mode 100644 index 3ebf2f6076..0000000000 --- a/gtestsuite/testsuite/level2/trsv/dtrsv_generic.cpp +++ /dev/null @@ -1,138 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include "test_trsv.h" - -class dtrsvTest : - public ::testing::TestWithParam> {}; - -TEST_P(dtrsvTest, RandomData) -{ - using T = double; - //---------------------------------------------------------- - // Initialize values from the parameters passed through - // test suite instantiation (INSTANTIATE_TEST_SUITE_P). - //---------------------------------------------------------- - // matrix storage format(row major, column major) - char storage = std::get<0>(GetParam()); - // denotes whether matrix a is u,l - char uploa = std::get<1>(GetParam()); - // denotes whether matrix a is n,c,t,h - char transa = std::get<2>(GetParam()); - // denotes whether matrix diag is u,n - char diaga = std::get<3>(GetParam()); - // matrix size n - gtint_t n = std::get<4>(GetParam()); - // specifies alpha value - T alpha = std::get<5>(GetParam()); - // stride size for x: - gtint_t incx = std::get<6>(GetParam()); - // lda increment. - // If increment is zero, then the array size matches the matrix size. - // If increment are nonnegative, the array size is bigger than the matrix size. - gtint_t lda_inc = std::get<7>(GetParam()); - - // Set the threshold for the errors: - double thresh = 100*n*testinghelpers::getEpsilon(); - - //---------------------------------------------------------- - // Call test body using these parameters - //---------------------------------------------------------- - test_trsv( storage, uploa, transa, diaga, n, alpha, lda_inc, incx, thresh ); -} - -class dtrsvTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char uploa = std::get<1>(str.param); - char transa = std::get<2>(str.param); - char diaga = std::get<3>(str.param); - gtint_t n = std::get<4>(str.param); - double alpha = std::get<5>(str.param); - gtint_t incx = std::get<6>(str.param); - gtint_t ld_inc = std::get<7>(str.param); -#ifdef TEST_BLAS - std::string str_name = "dtrsv_"; -#elif TEST_CBLAS - std::string str_name = "cblas_dtrsv"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_dtrsv"; -#endif - str_name = str_name + "_" + sfm; - str_name = str_name + "_" + uploa+transa; - str_name = str_name + "_d" + diaga; - str_name = str_name + "_" + std::to_string(n); - std::string alpha_str = ( alpha > 0) ? std::to_string(int(alpha)) : ("m" + std::to_string(int(std::abs(alpha)))); - str_name = str_name + "_a" + alpha_str; - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name = str_name + "_" + incx_str; - str_name = str_name + "_" + std::to_string(ld_inc); - return str_name; - } -}; - -// Black box testing. -INSTANTIATE_TEST_SUITE_P( - Blackbox, - dtrsvTest, - ::testing::Combine( - ::testing::Values('c' -#ifndef TEST_BLAS - ,'r' -#endif - ), // storage format - ::testing::Values('u','l'), // uploa - ::testing::Values('n','t'), // transa - ::testing::Values('n','u'), // diaga , n=NONUNIT_DIAG u=UNIT_DIAG - ::testing::Range(gtint_t(10), gtint_t(31), 10), // n - ::testing::Values( 1.0 -#ifdef TEST_BLIS_TYPED - , -2.0 -#endif - ), // alpha - ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(0), gtint_t(2)) // increment to the leading dim of a - ), - ::dtrsvTestPrint() - ); diff --git a/gtestsuite/testsuite/level2/trsv/strsv/strsv_generic.cpp b/gtestsuite/testsuite/level2/trsv/strsv/strsv_generic.cpp new file mode 100644 index 0000000000..7af25d85df --- /dev/null +++ b/gtestsuite/testsuite/level2/trsv/strsv/strsv_generic.cpp @@ -0,0 +1,162 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "level2/trsv/test_trsv.h" + +class strsvGeneric : + public ::testing::TestWithParam> {}; // is memory test + +TEST_P( strsvGeneric, API ) +{ + using T = float; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // matrix storage format(row major, column major) + char storage = std::get<0>(GetParam()); + // denotes whether matrix A is u,l + char uploa = std::get<1>(GetParam()); + // denotes whether matrix A is n,c,t,h + char transa = std::get<2>(GetParam()); + // denotes whether matrix diag is u,n + char diaga = std::get<3>(GetParam()); + // matrix size n + gtint_t n = std::get<4>(GetParam()); + // specifies alpha value + T alpha = std::get<5>(GetParam()); + // increment for x (incx): + gtint_t incx = std::get<6>(GetParam()); + // lda increment. + // If increment is zero, then the array size matches the matrix size. + // If increment are nonnegative, the array size is bigger than the matrix size. + gtint_t lda_inc = std::get<7>(GetParam()); + bool is_mem_test = std::get<8>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite trsv.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + // Threshold adjustment +#ifdef BLIS_INT_ELEMENT_TYPE + double adj = 9.0; + #ifdef REF_IS_MKL + adj = 12.0; + #endif +#else + double adj = 12.0; + #ifdef REF_IS_MKL + adj = 14.0; + #endif +#endif + if (n == 0 || alpha == T{0.0}) + thresh = 0.0; + else + if(alpha == T{1.0}) + thresh = adj*2*n*testinghelpers::getEpsilon(); + else + thresh = adj*3*n*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_trsv( storage, uploa, transa, diaga, n, alpha, lda_inc, incx, thresh, is_mem_test); +} + +// Black box testing. +INSTANTIATE_TEST_SUITE_P( + BlackboxSmall, + strsvGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('u','l'), // uploa + ::testing::Values('n','t','c'), // transa + ::testing::Values('n','u'), // diaga , n=NONUNIT_DIAG u=UNIT_DIAG + ::testing::Range(gtint_t(1),gtint_t(21),1), // n + ::testing::Values(1.0 // Only blis typed api supports +#ifdef TEST_BLIS_TYPED // values of alpha other than 1 + , -2.2, 5.4, -1.0, 0.0 +#endif + ), // alpha + ::testing::Values(gtint_t(-1),gtint_t(1), gtint_t(33)), // incx + ::testing::Values(gtint_t(0), gtint_t(11)), // increment to the leading dim of a + ::testing::Values(false, true) // is memory test + ), + ::trsvGenericPrint() + ); + +// Black box testing. +INSTANTIATE_TEST_SUITE_P( + BlackboxMedium, + strsvGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('u','l'), // uploa + ::testing::Values('n','t','c'), // transa + ::testing::Values('n','u'), // diaga , n=NONUNIT_DIAG u=UNIT_DIAG + ::testing::Values(gtint_t(25), + gtint_t(33), + gtint_t(98), + gtint_t(173), + gtint_t(211) + ), // n + ::testing::Values(1.0 // Only blis typed api supports +#ifdef TEST_BLIS_TYPED // values of alpha other than 1 + , -2.2, 5.4, -1.0, 0.0 +#endif + ), // alpha + ::testing::Values(gtint_t(-1),gtint_t(1), gtint_t(33)), // incx + ::testing::Values(gtint_t(0), gtint_t(11)), // increment to the leading dim of a + ::testing::Values(false, true) // is memory test + ), + ::trsvGenericPrint() + ); diff --git a/gtestsuite/testsuite/level2/trsv/strsv_generic.cpp b/gtestsuite/testsuite/level2/trsv/strsv_generic.cpp deleted file mode 100644 index 201223b134..0000000000 --- a/gtestsuite/testsuite/level2/trsv/strsv_generic.cpp +++ /dev/null @@ -1,138 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include "test_trsv.h" - -class strsvTest : - public ::testing::TestWithParam> {}; - -TEST_P(strsvTest, RandomData) -{ - using T = float; - //---------------------------------------------------------- - // Initialize values from the parameters passed through - // test suite instantiation (INSTANTIATE_TEST_SUITE_P). - //---------------------------------------------------------- - // matrix storage format(row major, column major) - char storage = std::get<0>(GetParam()); - // denotes whether matrix a is u,l - char uploa = std::get<1>(GetParam()); - // denotes whether matrix a is n,c,t,h - char transa = std::get<2>(GetParam()); - // denotes whether matrix diag is u,n - char diaga = std::get<3>(GetParam()); - // matrix size n - gtint_t n = std::get<4>(GetParam()); - // specifies alpha value - T alpha = std::get<5>(GetParam()); - // stride size for x: - gtint_t incx = std::get<6>(GetParam()); - // lda increment. - // If increment is zero, then the array size matches the matrix size. - // If increment are nonnegative, the array size is bigger than the matrix size. - gtint_t lda_inc = std::get<7>(GetParam()); - - // Set the threshold for the errors: - double thresh = 20*n*testinghelpers::getEpsilon(); - - //---------------------------------------------------------- - // Call test body using these parameters - //---------------------------------------------------------- - test_trsv( storage, uploa, transa, diaga, n, alpha, lda_inc, incx, thresh ); -} - -class strsvTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char uploa = std::get<1>(str.param); - char transa = std::get<2>(str.param); - char diaga = std::get<3>(str.param); - gtint_t n = std::get<4>(str.param); - float alpha = std::get<5>(str.param); - gtint_t incx = std::get<6>(str.param); - gtint_t ld_inc = std::get<7>(str.param); -#ifdef TEST_BLAS - std::string str_name = "strsv_"; -#elif TEST_CBLAS - std::string str_name = "cblas_strsv"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_strsv"; -#endif - str_name = str_name + "_" + sfm; - str_name = str_name + "_" + uploa+transa; - str_name = str_name + "_d" + diaga; - str_name = str_name + "_" + std::to_string(n); - std::string alpha_str = ( alpha > 0) ? std::to_string(int(alpha)) : ("m" + std::to_string(int(std::abs(alpha)))); - str_name = str_name + "_a" + alpha_str; - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name = str_name + "_" + incx_str; - str_name = str_name + "_" + std::to_string(ld_inc); - return str_name; - } -}; - -// Black box testing. -INSTANTIATE_TEST_SUITE_P( - Blackbox, - strsvTest, - ::testing::Combine( - ::testing::Values('c' -#ifndef TEST_BLAS - ,'r' -#endif - ), // storage format - ::testing::Values('u','l'), // uploa - ::testing::Values('n','t'), // transa - ::testing::Values('n','u'), // diaga , n=NONUNIT_DIAG u=UNIT_DIAG - ::testing::Range(gtint_t(10), gtint_t(31), 10), // n - ::testing::Values( 1.0 -#ifdef TEST_BLIS_TYPED - , -2.0 -#endif - ), // alpha - ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(0), gtint_t(7)) // increment to the leading dim of a - ), - ::strsvTestPrint() - ); diff --git a/gtestsuite/testsuite/level2/trsv/test_trsv.h b/gtestsuite/testsuite/level2/trsv/test_trsv.h index 2266397200..dfb7a685ae 100644 --- a/gtestsuite/testsuite/level2/trsv/test_trsv.h +++ b/gtestsuite/testsuite/level2/trsv/test_trsv.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -39,36 +39,180 @@ #include "inc/check_error.h" #include #include +#include "common/testing_helpers.h" template -void test_trsv( char storage, char uploa, char transa, char diaga, gtint_t n, - T alpha, gtint_t lda_inc, gtint_t incx, double thresh ) +void test_trsv( + char storage, + char uploa, + char transa, + char diaga, + gtint_t n, + T alpha, + gtint_t lda_inc, + gtint_t incx, + double thresh, + bool is_memory_test = false, + bool is_evt_test = false, + T evt_x = T{0}, + T evt_a = T{0} + ) { + using RT = typename testinghelpers::type_info::real_type; // Compute the leading dimensions for matrix size calculation. gtint_t lda = testinghelpers::get_leading_dimension( storage, transa, n, n, lda_inc ); //---------------------------------------------------------- // Initialize matrics with random integer numbers. //---------------------------------------------------------- - std::vector a = testinghelpers::get_random_matrix( 1, 5, storage, transa, n, n, lda ); - std::vector x = testinghelpers::get_random_vector( 1, 3, n, incx ); - testinghelpers::make_triangular( storage, uploa, n, a.data(), lda ); + dim_t size_a = testinghelpers::matsize(storage, transa, n, n, lda) * sizeof(T); - // Create a copy of c so that we can check reference results. - std::vector x_ref(x); + // Buffers for A matrix and X vector are always unaligned + testinghelpers::ProtectedBuffer a(size_a, false, is_memory_test ); + testinghelpers::datagenerators::randomgenerators( 0, 1, storage, n, n, (T*)(a.greenzone_1), transa, lda ); + + dim_t size_x = testinghelpers::buff_dim(n, incx) * sizeof(T); + testinghelpers::ProtectedBuffer x(size_x, false, is_memory_test ); + testinghelpers::datagenerators::randomgenerators( 1, 3, n, incx, (T*)(x.greenzone_1) ); + + T* a_ptr = (T*)(a.greenzone_1); + T* x_ptr = (T*)(x.greenzone_1); + + // Make A matix diagonal dominant to make sure that algorithm doesn't diverge + // This makes sure that the TRSV problem is solvable + for ( dim_t a_dim = 0; a_dim < n; ++a_dim ) + { + a_ptr[ a_dim + (a_dim* lda) ] = a_ptr[ a_dim + (a_dim* lda) ] + T{RT(n)}; + } + + // add extreme values to the X vector + if ( is_evt_test ) + { + x_ptr[ (rand() % n) * std::abs(incx) ] = evt_x; + } + + // add extreme values to the A matrix + if ( is_evt_test ) + { + dim_t n_idx = rand() % n; + dim_t m_idx = (std::max)((dim_t)0, n_idx - 1); + a_ptr[ m_idx + (n_idx * lda) ] = evt_a; + a_ptr[ m_idx + (m_idx *lda) ] = evt_a; + } + + // skipped making A triangular + // A matrix being a non triangular matrix could be a better test + // because we are exepcted to read only from the upper or lower triangular + // part of the data, contents of the rest of the matrix should not change the + // result. + // testinghelpers::make_triangular( storage, uploa, n, a_ptr, lda ); + + // Create a copy of x so that we can check reference results. + std::vector x_ref(testinghelpers::buff_dim(n, incx)); + memcpy(x_ref.data(), x_ptr, size_x); //---------------------------------------------------------- // Call BLIS function //---------------------------------------------------------- - trsv( storage, uploa, transa, diaga, n, &alpha, a.data(), lda, x.data(), incx ); + // add signal handler for segmentation fault + testinghelpers::ProtectedBuffer::start_signal_handler(); + try + { + trsv( storage, uploa, transa, diaga, n, &alpha, a_ptr, lda, x_ptr, incx ); + if ( is_memory_test ) + { + memcpy(a.greenzone_2, a.greenzone_1, size_a); + memcpy(x.greenzone_2, x_ref.data(), size_x); + trsv( storage, uploa, transa, diaga, n, &alpha, (T*)a.greenzone_2, lda, (T*)x.greenzone_2, incx ); + } + } + catch(const std::exception& e) + { + // reset to default signal handler + testinghelpers::ProtectedBuffer::stop_signal_handler(); + + // show failure in case seg fault was detected + FAIL() << "Memory Test Failed"; + } + // reset to default signal handler + testinghelpers::ProtectedBuffer::stop_signal_handler(); //---------------------------------------------------------- // Call reference implementation. //---------------------------------------------------------- - testinghelpers::ref_trsv( storage, uploa, transa, diaga, n, &alpha, a.data(), lda, x_ref.data(), incx ); + testinghelpers::ref_trsv( storage, uploa, transa, diaga, n, &alpha, a_ptr, lda, x_ref.data(), incx ); //---------------------------------------------------------- // check component-wise error. //---------------------------------------------------------- - computediff( n, x.data(), x_ref.data(), incx, thresh ); + computediff( "x", n, x_ptr, x_ref.data(), incx, thresh, is_evt_test ); + +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif } + +// Test-case logger : Used to print the test-case details based on parameters +template +class trsvGenericPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char storage = std::get<0>(str.param); + char uploa = std::get<1>(str.param); + char transa = std::get<2>(str.param); + char diaga = std::get<3>(str.param); + gtint_t n = std::get<4>(str.param); + T alpha = std::get<5>(str.param); + gtint_t incx = std::get<6>(str.param); + gtint_t lda_inc = std::get<7>(str.param); + bool is_mem_test = std::get<8>(str.param); + + std::string str_name = API_PRINT; + str_name += "_stor_" + std::string(&storage, 1); + str_name += "_uploa_" + std::string(&uploa, 1); + str_name += "_transa_" + std::string(&transa, 1); + str_name += "_diaga_" + std::string(&diaga, 1); + str_name += "_n_" + std::to_string(n); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + str_name += "_incx_" + testinghelpers::get_value_string(incx); + gtint_t lda = testinghelpers::get_leading_dimension( storage, transa, n, n, lda_inc ); + str_name += "_lda_i" + std::to_string(lda_inc) + "_" + std::to_string(lda); + str_name = str_name + (is_mem_test ? "_mem_test_enabled" : "_mem_test_disabled"); + return str_name; + } +}; + +template +class trsvEVTPrint +{ +public: + std::string operator()( + testing::TestParamInfo> str) const { + char storage = std::get<0>(str.param); + char uploa = std::get<1>(str.param); + char transa = std::get<2>(str.param); + char diaga = std::get<3>(str.param); + gtint_t n = std::get<4>(str.param); + T alpha = std::get<5>(str.param); + gtint_t incx = std::get<6>(str.param); + T xexval = std::get<7>(str.param); + T aexval = std::get<8>(str.param); + gtint_t lda_inc = std::get<9>(str.param); + + std::string str_name = API_PRINT; + str_name += "_stor_" + std::string(&storage, 1); + str_name += "_uploa_" + std::string(&uploa, 1); + str_name += "_transa_" + std::string(&transa, 1); + str_name += "_diaga_" + std::string(&diaga, 1); + str_name += "_n_" + std::to_string(n); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + str_name += "_incx_" + testinghelpers::get_value_string(incx); + str_name = str_name + "_ex_x_" + testinghelpers::get_value_string(xexval); + str_name = str_name + "_ex_a_" + testinghelpers::get_value_string(aexval); + gtint_t lda = testinghelpers::get_leading_dimension( storage, transa, n, n, lda_inc ); + str_name += "_lda_i" + std::to_string(lda_inc) + "_" + std::to_string(lda); + return str_name; + } +}; diff --git a/gtestsuite/testsuite/level2/trsv/trsv.h b/gtestsuite/testsuite/level2/trsv/trsv.h index 65ca33112a..95b23f1103 100644 --- a/gtestsuite/testsuite/level2/trsv/trsv.h +++ b/gtestsuite/testsuite/level2/trsv/trsv.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -36,6 +36,7 @@ #include "blis.h" #include "common/testing_helpers.h" +#include "inc/check_error.h" /** * @brief Performs the operation: @@ -69,6 +70,22 @@ static void trsv_( char uploa, char transa, char diaga, gtint_t n, throw std::runtime_error("Error in testsuite/level2/trsv.h: Invalid typename in trsv_()."); } +template +static void trsv_blis_impl( char uploa, char transa, char diaga, gtint_t n, + T *ap, gtint_t lda, T *xp, gtint_t incx ) +{ + if constexpr (std::is_same::value) + strsv_blis_impl( &uploa, &transa, &diaga, &n, ap, &lda, xp, &incx ); + else if constexpr (std::is_same::value) + dtrsv_blis_impl( &uploa, &transa, &diaga, &n, ap, &lda, xp, &incx ); + else if constexpr (std::is_same::value) + ctrsv_blis_impl( &uploa, &transa, &diaga, &n, ap, &lda, xp, &incx ); + else if constexpr (std::is_same::value) + ztrsv_blis_impl( &uploa, &transa, &diaga, &n, ap, &lda, xp, &incx ); + else + throw std::runtime_error("Error in testsuite/level2/trsv.h: Invalid typename in trsv_blis_impl()."); +} + template static void cblas_trsv( char storage, char uploa, char transa, char diaga, gtint_t n, T *ap, gtint_t lda, T *xp, gtint_t incx ) @@ -134,11 +151,39 @@ template static void trsv( char storage, char uploa, char transa, char diaga, gtint_t n, T *alpha, T *ap, gtint_t lda, T *xp, gtint_t incx ) { -#if (defined TEST_BLAS || defined TEST_CBLAS) +#if (defined TEST_BLAS_LIKE || defined TEST_CBLAS) T one; testinghelpers::initone(one); #endif +#ifdef TEST_UPPERCASE_ARGS + storage = static_cast(std::toupper(static_cast(storage))); + uploa = static_cast(std::toupper(static_cast(uploa))); + transa = static_cast(std::toupper(static_cast(transa))); + diaga = static_cast(std::toupper(static_cast(diaga))); +#endif + +#ifdef TEST_INPUT_ARGS + // Create copy of scalar input values so we can check that they are not altered. + char storage_cpy = storage; + char uploa_cpy = uploa; + char transa_cpy = transa; + char diaga_cpy = diaga; + gtint_t n_cpy = n; + T* alpha_cpy = alpha; + gtint_t lda_cpy = lda; + gtint_t incx_cpy = incx; + + // Create copy of input arrays so we can check that they are not altered. + T* ap_cpy = nullptr; + gtint_t size_ap = testinghelpers::matsize( storage, transa, n, n, lda ); + if (ap && size_ap > 0) + { + ap_cpy = new T[size_ap]; + memcpy( ap_cpy, ap, size_ap * sizeof( T ) ); + } +#endif + #ifdef TEST_BLAS if(( storage == 'c' || storage == 'C' )) if( *alpha == one ) @@ -147,6 +192,14 @@ static void trsv( char storage, char uploa, char transa, char diaga, throw std::runtime_error("Error in testsuite/level2/trsv.h: BLAS interface cannot be tested for alpha != one."); else throw std::runtime_error("Error in testsuite/level2/trsv.h: BLAS interface cannot be tested for row-major order."); +#elif TEST_BLAS_BLIS_IMPL + if(( storage == 'c' || storage == 'C' )) + if( *alpha == one ) + trsv_blis_impl( uploa, transa, diaga, n, ap, lda, xp, incx ); + else + throw std::runtime_error("Error in testsuite/level2/trsv.h: BLAS_BLIS_IMPL interface cannot be tested for alpha != one."); + else + throw std::runtime_error("Error in testsuite/level2/trsv.h: BLAS_BLIS_IMPL interface cannot be tested for row-major order."); #elif TEST_CBLAS if( *alpha == one ) cblas_trsv( storage, uploa, transa, diaga, n, ap, lda, xp, incx ); @@ -157,4 +210,29 @@ static void trsv( char storage, char uploa, char transa, char diaga, #else throw std::runtime_error("Error in testsuite/level2/trsv.h: No interfaces are set to be tested."); #endif + +#ifdef TEST_INPUT_ARGS + //---------------------------------------------------------- + // Check scalar inputs have not been modified. + //---------------------------------------------------------- + + computediff( "storage", storage, storage_cpy ); + computediff( "uploa", uploa, uploa_cpy ); + computediff( "transa", transa, transa_cpy ); + computediff( "diaga", diaga, diaga_cpy ); + computediff( "n", n, n_cpy ); + if (alpha) computediff( "alpha", *alpha, *alpha_cpy ); + computediff( "lda", lda, lda_cpy ); + computediff( "incx", incx, incx_cpy ); + + //---------------------------------------------------------- + // Bitwise-wise check array inputs have not been modified. + //---------------------------------------------------------- + + if (ap && size_ap > 0) + { + computediff( "A", storage, n, n, ap, ap_cpy, lda, true ); + delete[] ap_cpy; + } +#endif } diff --git a/gtestsuite/testsuite/level2/trsv/ztrsv/ztrsv_evt.cpp b/gtestsuite/testsuite/level2/trsv/ztrsv/ztrsv_evt.cpp new file mode 100644 index 0000000000..5aa0c51e61 --- /dev/null +++ b/gtestsuite/testsuite/level2/trsv/ztrsv/ztrsv_evt.cpp @@ -0,0 +1,151 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "level2/trsv/test_trsv.h" + +class DISABLED_ztrsvEVT : + public ::testing::TestWithParam> {}; // ld_inc + +TEST_P( DISABLED_ztrsvEVT, API ) +{ + using T = dcomplex; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // matrix storage format(row major, column major) + char storage = std::get<0>(GetParam()); + // denotes whether matrix a is u,l + char uploa = std::get<1>(GetParam()); + // denotes whether matrix a is n,c,t,h + char transa = std::get<2>(GetParam()); + // denotes whether matrix diag is u,n + char diaga = std::get<3>(GetParam()); + // matrix size n + gtint_t n = std::get<4>(GetParam()); + // specifies alpha value + T alpha = std::get<5>(GetParam()); + // stride size for x: + gtint_t incx = std::get<6>(GetParam()); + // extreme value for x + dcomplex xexval = std::get<7>(GetParam()); + // extreme value for A + dcomplex aexval = std::get<8>(GetParam()); + // lda increment. + // If increment is zero, then the array size matches the matrix size. + // If increment are nonnegative, the array size is bigger than the matrix size. + gtint_t lda_inc = std::get<9>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite trsv.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (n == 0 || alpha == T{0.0}) + thresh = 0.0; + else + if(alpha == T{1.0}) + thresh = 2*n*testinghelpers::getEpsilon(); + else + thresh = 3*n*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_trsv( storage, uploa, transa, diaga, n, alpha, lda_inc, incx, thresh, false, true, xexval, aexval); +} + +static double AOCL_NAN = std::numeric_limits::quiet_NaN(); +static double AOCL_INF = std::numeric_limits::infinity(); + +INSTANTIATE_TEST_SUITE_P( + Native, + DISABLED_ztrsvEVT, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('u','l'), // uploa + ::testing::Values('n','t'), // transa + ::testing::Values('n','u'), // diaga , n=NONUNIT_DIAG u=UNIT_DIAG + ::testing::Values(gtint_t(32), + gtint_t(24), + gtint_t(8), + gtint_t(4), + gtint_t(2), + gtint_t(1), + gtint_t(15) + ), // n + ::testing::Values(dcomplex{1.0, 0.0} +#ifdef TEST_BLIS_TYPED + ,dcomplex{6.1, -2.9}, dcomplex{-3.3, -1.4}, + dcomplex{-1.0, 0.0}, dcomplex{0.0, 0.0} +#endif + ), // alpha + ::testing::Values(gtint_t(-2), gtint_t(-1), + gtint_t( 1), gtint_t( 2)), // stride size for x + ::testing::Values( + dcomplex{AOCL_NAN, 2.1}, + dcomplex{2.1, AOCL_NAN}, + dcomplex{AOCL_NAN, AOCL_INF}, + // dcomplex{2.3, AOCL_INF}, // fail + // dcomplex{AOCL_INF, 2.3}, // fail + // dcomplex{0.0, AOCL_INF}, // fail + // dcomplex{AOCL_INF, 0.0}, // fail + // dcomplex{0.0, -AOCL_INF}, // fail + // dcomplex{-AOCL_INF, 0.0}, // fail + dcomplex{1, 0} ), // exception value for x + ::testing::Values( + dcomplex{AOCL_NAN, 3.2}, + dcomplex{2.1, AOCL_NAN}, + dcomplex{AOCL_NAN, AOCL_INF}, + // dcomplex{2.3, AOCL_INF}, // fail + // dcomplex{AOCL_INF, 6.1}, // fail + dcomplex{1, 0}), // exception value for A + ::testing::Values(gtint_t(0), gtint_t(10)) // increment to the leading dim of a + ), + ::trsvEVTPrint() + ); diff --git a/gtestsuite/testsuite/level2/trsv/ztrsv/ztrsv_generic.cpp b/gtestsuite/testsuite/level2/trsv/ztrsv/ztrsv_generic.cpp new file mode 100644 index 0000000000..f9b0a9f87e --- /dev/null +++ b/gtestsuite/testsuite/level2/trsv/ztrsv/ztrsv_generic.cpp @@ -0,0 +1,158 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "level2/trsv/test_trsv.h" + +class ztrsvGeneric : + public ::testing::TestWithParam> {}; // is memory test + +TEST_P( ztrsvGeneric, API ) +{ + using T = dcomplex; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // matrix storage format(row major, column major) + char storage = std::get<0>(GetParam()); + // denotes whether matrix A is u,l + char uploa = std::get<1>(GetParam()); + // denotes whether matrix A is n,c,t,h + char transa = std::get<2>(GetParam()); + // denotes whether matrix diag is u,n + char diaga = std::get<3>(GetParam()); + // matrix size n + gtint_t n = std::get<4>(GetParam()); + // specifies alpha value + T alpha = std::get<5>(GetParam()); + // increment for x (incx): + gtint_t incx = std::get<6>(GetParam()); + // lda increment. + // If increment is zero, then the array size matches the matrix size. + // If increment are nonnegative, the array size is bigger than the matrix size. + gtint_t lda_inc = std::get<7>(GetParam()); + bool is_mem_test = std::get<8>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite trsv.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // With adjustment for complex data. + double thresh; +#ifdef BLIS_INT_ELEMENT_TYPE + double adj = 1.0; +#else + double adj = 2.0; +#endif + if (n == 0 || alpha == T{0.0}) + thresh = 0.0; + else + if(alpha == T{1.0}) + thresh = adj*2*n*testinghelpers::getEpsilon(); + else + thresh = adj*3*n*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_trsv( storage, uploa, transa, diaga, n, alpha, lda_inc, incx, thresh, is_mem_test ); +} + +// Black box testing. +INSTANTIATE_TEST_SUITE_P( + BlackboxSmall, + ztrsvGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('u','l'), // uploa + ::testing::Values('n','t','c'), // transa + ::testing::Values('n','u'), // diaga , n=NONUNIT_DIAG u=UNIT_DIAG + ::testing::Range(gtint_t(1),gtint_t(21),1), // n + ::testing::Values(dcomplex{1.0, 0.0} // Only blis typed api supports +#ifdef TEST_BLIS_TYPED // values of alpha other than 1 + ,dcomplex{6.1, -2.9}, dcomplex{-3.3, -1.4} + ,dcomplex{-1.0, 0.0}, dcomplex{0.0, 0.0} +#endif + ), // alpha + ::testing::Values(gtint_t(-1),gtint_t(1), gtint_t(33)), // incx + ::testing::Values(gtint_t(0), gtint_t(11)), // increment to the leading dim of a + ::testing::Values(false, true) // is memory test + ), + ::trsvGenericPrint() + ); + +// Black box testing. +INSTANTIATE_TEST_SUITE_P( + BlackboxMedium, + ztrsvGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('u','l'), // uploa + ::testing::Values('n','t','c'), // transa + ::testing::Values('n','u'), // diaga , n=NONUNIT_DIAG u=UNIT_DIAG + ::testing::Values(gtint_t(25), + gtint_t(33), + gtint_t(98), + gtint_t(173), + gtint_t(211) + ), // n + ::testing::Values(dcomplex{1.0, 0.0} // Only blis typed api supports +#ifdef TEST_BLIS_TYPED // values of alpha other than 1 + ,dcomplex{6.1, -2.9}, dcomplex{-3.3, -1.4} + ,dcomplex{-1.0, 0.0}, dcomplex{0.0, 0.0} +#endif + ), // alpha + ::testing::Values(gtint_t(-1),gtint_t(1), gtint_t(33)), // incx + ::testing::Values(gtint_t(0), gtint_t(11)), // increment to the leading dim of a + ::testing::Values(false, true) // is memory test + ), + ::trsvGenericPrint() + ); diff --git a/gtestsuite/testsuite/level2/trsv/ztrsv_generic.cpp b/gtestsuite/testsuite/level2/trsv/ztrsv_generic.cpp deleted file mode 100644 index dc8b004575..0000000000 --- a/gtestsuite/testsuite/level2/trsv/ztrsv_generic.cpp +++ /dev/null @@ -1,139 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include "test_trsv.h" - -class ztrsvTest : - public ::testing::TestWithParam> {}; - -TEST_P(ztrsvTest, RandomData) -{ - using T = dcomplex; - //---------------------------------------------------------- - // Initialize values from the parameters passed through - // test suite instantiation (INSTANTIATE_TEST_SUITE_P). - //---------------------------------------------------------- - // matrix storage format(row major, column major) - char storage = std::get<0>(GetParam()); - // denotes whether matrix a is u,l - char uploa = std::get<1>(GetParam()); - // denotes whether matrix a is n,c,t,h - char transa = std::get<2>(GetParam()); - // denotes whether matrix diag is u,n - char diaga = std::get<3>(GetParam()); - // matrix size n - gtint_t n = std::get<4>(GetParam()); - // specifies alpha value - T alpha = std::get<5>(GetParam()); - // stride size for x: - gtint_t incx = std::get<6>(GetParam()); - // lda increment. - // If increment is zero, then the array size matches the matrix size. - // If increment are nonnegative, the array size is bigger than the matrix size. - gtint_t lda_inc = std::get<7>(GetParam()); - - // Set the threshold for the errors: - double thresh = 10*n*testinghelpers::getEpsilon(); - - //---------------------------------------------------------- - // Call test body using these parameters - //---------------------------------------------------------- - test_trsv( storage, uploa, transa, diaga, n, alpha, lda_inc, incx, thresh ); -} - -class ztrsvTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char uploa = std::get<1>(str.param); - char transa = std::get<2>(str.param); - char diaga = std::get<3>(str.param); - gtint_t n = std::get<4>(str.param); - dcomplex alpha = std::get<5>(str.param); - gtint_t incx = std::get<6>(str.param); - gtint_t ld_inc = std::get<7>(str.param); -#ifdef TEST_BLAS - std::string str_name = "ztrsv_"; -#elif TEST_CBLAS - std::string str_name = "cblas_ztrsv"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_ztrsv"; -#endif - str_name = str_name + "_" + sfm; - str_name = str_name + "_" + uploa+transa; - str_name = str_name + "_d" + diaga; - str_name = str_name + "_" + std::to_string(n); - std::string alpha_str = ( alpha.real > 0) ? std::to_string(int(alpha.real)) : ("m" + std::to_string(int(std::abs(alpha.real)))); - alpha_str = alpha_str + "pi" + (( alpha.imag > 0) ? std::to_string(int(alpha.imag)) : ("m" + std::to_string(int(std::abs(alpha.imag))))); - str_name = str_name + "_a" + alpha_str; - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name = str_name + "_" + incx_str; - str_name = str_name + "_" + std::to_string(ld_inc); - return str_name; - } -}; - -// Black box testing. -INSTANTIATE_TEST_SUITE_P( - Blackbox, - ztrsvTest, - ::testing::Combine( - ::testing::Values('c' -#ifndef TEST_BLAS - ,'r' -#endif - ), // storage format - ::testing::Values('u','l'), // uploa - ::testing::Values('n','c','t'), // transa - ::testing::Values('n','u'), // diaga , n=NONUNIT_DIAG u=UNIT_DIAG - ::testing::Range(gtint_t(10), gtint_t(31), 10), // n - ::testing::Values(dcomplex{1.0, 0.0} -#ifdef TEST_BLIS_TYPED - ,dcomplex{1.0, -2.0} -#endif - ), // alpha - ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(0), gtint_t(1)) // increment to the leading dim of a - ), - ::ztrsvTestPrint() - ); diff --git a/gtestsuite/testsuite/level3/gemm/IIT_ERS/gemm_IIT_ERS.cpp b/gtestsuite/testsuite/level3/gemm/IIT_ERS/gemm_IIT_ERS.cpp new file mode 100644 index 0000000000..88e471fda3 --- /dev/null +++ b/gtestsuite/testsuite/level3/gemm/IIT_ERS/gemm_IIT_ERS.cpp @@ -0,0 +1,702 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "common/testing_helpers.h" +#include "level3/gemm/test_gemm.h" +#include "inc/check_error.h" +#include "common/wrong_inputs_helpers.h" + +template +class gemm_IIT_ERS : public ::testing::Test {}; +typedef ::testing::Types TypeParam; // The supported datatypes from BLAS calls for GEMM +TYPED_TEST_SUITE(gemm_IIT_ERS, TypeParam); // Defining individual testsuites based on the datatype support. + +// Adding namespace to get default parameters(valid case) from testinghelpers/common/wrong_input_helpers.h. +using namespace testinghelpers::IIT; + +#if defined(TEST_CBLAS) +#define INFO_OFFSET 1 +#else +#define INFO_OFFSET 0 +#endif + +#if defined(TEST_CBLAS) + +// When info == 1 +TYPED_TEST(gemm_IIT_ERS, invalid_storage) +{ + using T = TypeParam; + + T alpha, beta; + testinghelpers::initone( alpha ); + testinghelpers::initone( beta ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + gemm( 'x', TRANS, TRANS, M, N, K, &alpha, nullptr, LDA, nullptr, LDB, &beta, nullptr, LDC ); +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 1 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, N, LDC); + std::vector a = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, K, LDA); + std::vector b = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', K, N, LDB); + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + + // Call BLIS Gemm with a invalid value for TRANS value for A. + gemm( 'x', TRANS, TRANS, M, N, K, &alpha, a.data(), LDA, b.data(), LDB, &beta, c.data(), LDC ); + // Use bitwise comparison (no threshold). + computediff( "C", STORAGE, N, N, c.data(), c_ref.data(), LDC); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 1 ); +#endif +} + +#endif + +#if defined(TEST_BLAS_LIKE) || defined(TEST_CBLAS) + +/* + Incorrect Input Testing(IIT) + + BLAS exceptions get triggered in the following cases(for GEMM): + 1. When TRANSA != 'N' || TRANSA != 'T' || TRANSA != 'C' (info = 1) + 2. When TRANSB != 'N' || TRANSB != 'T' || TRANSB != 'C' (info = 2) + 3. When m < 0 (info = 3) + 4. When n < 0 (info = 4) + 5. When k < 0 (info = 5) + 6. When lda < max(1, thresh) (info = 8), thresh set based on TRANSA value + 7. When ldb < max(1, thresh) (info = 10), thresh set based on TRANSB value + 8. When ldc < max(1, n) (info = 13) + +*/ + +// When info == 1 +TYPED_TEST(gemm_IIT_ERS, invalid_transa) +{ + using T = TypeParam; + + T alpha, beta; + testinghelpers::initone( alpha ); + testinghelpers::initone( beta ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + gemm( STORAGE, 'p', TRANS, M, N, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC ); +#else + gemm( STORAGE, 'p', TRANS, M, N, K, &alpha, nullptr, LDA, nullptr, LDB, &beta, nullptr, LDC ); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, INFO_OFFSET+1 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, N, LDC); + std::vector a = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, K, LDA); + std::vector b = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', K, N, LDB); + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + + // Call BLIS Gemm with a invalid value for TRANS value for A. + gemm( STORAGE, 'p', TRANS, M, N, K, &alpha, a.data(), LDA, b.data(), LDB, &beta, c.data(), LDC ); + // Use bitwise comparison (no threshold). + computediff( "C", STORAGE, N, N, c.data(), c_ref.data(), LDC); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, INFO_OFFSET+1 ); +#endif +} + +// When info == 2 +TYPED_TEST(gemm_IIT_ERS, invalid_transb) +{ + using T = TypeParam; + + T alpha, beta; + testinghelpers::initone( alpha ); + testinghelpers::initone( beta ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + gemm( STORAGE, TRANS, 'p', M, N, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC ); +#else + gemm( STORAGE, TRANS, 'p', M, N, K, &alpha, nullptr, LDA, nullptr, LDB, &beta, nullptr, LDC ); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, INFO_OFFSET+2 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, N, LDC); + std::vector a = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, K, LDA); + std::vector b = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', K, N, LDB); + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + + // Call BLIS Gemm with a invalid value for TRANS value for B. + gemm( STORAGE, TRANS, 'p', M, N, K, &alpha, a.data(), LDA, b.data(), LDB, &beta, c.data(), LDC ); + // Use bitwise comparison (no threshold). + computediff( "C", STORAGE, N, N, c.data(), c_ref.data(), LDC); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, INFO_OFFSET+2 ); +#endif +} + +// When info == 3 +TYPED_TEST(gemm_IIT_ERS, m_lt_zero) +{ + using T = TypeParam; + + T alpha, beta; + testinghelpers::initone( alpha ); + testinghelpers::initone( beta ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + gemm( STORAGE, TRANS, TRANS, -1, N, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC ); +#else + gemm( STORAGE, TRANS, TRANS, -1, N, K, &alpha, nullptr, LDA, nullptr, LDB, &beta, nullptr, LDC ); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 3 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, N, LDC); + std::vector a = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, K, LDA); + std::vector b = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', K, N, LDB); + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + + // Call BLIS Gemm with a invalid value for m. + gemm( STORAGE, TRANS, TRANS, -1, N, K, &alpha, a.data(), LDA, b.data(), LDB, &beta, c.data(), LDC ); + // Use bitwise comparison (no threshold). + computediff( "C", STORAGE, N, N, c.data(), c_ref.data(), LDC); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 3 ); +#endif +} + +// When info == 4 +TYPED_TEST(gemm_IIT_ERS, n_lt_zero) +{ + using T = TypeParam; + + T alpha, beta; + testinghelpers::initone( alpha ); + testinghelpers::initone( beta ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + gemm( STORAGE, TRANS, TRANS, M, -1, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC ); +#else + gemm( STORAGE, TRANS, TRANS, M, -1, K, &alpha, nullptr, LDA, nullptr, LDB, &beta, nullptr, LDC ); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 4 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, N, LDC); + std::vector a = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, K, LDA); + std::vector b = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', K, N, LDB); + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + + // Call BLIS Gemm with a invalid value for n. + gemm( STORAGE, TRANS, TRANS, M, -1, K, &alpha, a.data(), LDA, b.data(), LDB, &beta, c.data(), LDC ); + // Use bitwise comparison (no threshold). + computediff( "C", STORAGE, N, N, c.data(), c_ref.data(), LDC); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 4 ); +#endif +} + +// When info == 5 +TYPED_TEST(gemm_IIT_ERS, k_lt_zero) +{ + using T = TypeParam; + + T alpha, beta; + testinghelpers::initone( alpha ); + testinghelpers::initone( beta ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + gemm( STORAGE, TRANS, TRANS, M, N, -1, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC ); +#else + gemm( STORAGE, TRANS, TRANS, M, N, -1, &alpha, nullptr, LDA, nullptr, LDB, &beta, nullptr, LDC ); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 5 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, N, LDC); + std::vector a = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, K, LDA); + std::vector b = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', K, N, LDB); + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + + // Call BLIS Gemm with a invalid value for k. + gemm( STORAGE, TRANS, TRANS, M, N, -1, &alpha, a.data(), LDA, b.data(), LDB, &beta, c.data(), LDC ); + // Use bitwise comparison (no threshold). + computediff( "C", STORAGE, N, N, c.data(), c_ref.data(), LDC); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 5 ); +#endif +} + +// When info == 8 +TYPED_TEST(gemm_IIT_ERS, invalid_lda) +{ + using T = TypeParam; + + T alpha, beta; + testinghelpers::initone( alpha ); + testinghelpers::initone( beta ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + gemm( STORAGE, TRANS, TRANS, M, N, K, nullptr, nullptr, LDA - 1, nullptr, LDB, nullptr, nullptr, LDC ); +#else + gemm( STORAGE, TRANS, TRANS, M, N, K, &alpha, nullptr, LDA - 1, nullptr, LDB, &beta, nullptr, LDC ); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 8 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, N, LDC); + std::vector a = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, K, LDA); + std::vector b = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', K, N, LDB); + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + + // Call BLIS Gemm with a invalid value for lda. + gemm( STORAGE, TRANS, TRANS, M, N, K, &alpha, a.data(), LDA - 1, b.data(), LDB, &beta, c.data(), LDC ); + // Use bitwise comparison (no threshold). + computediff( "C", STORAGE, N, N, c.data(), c_ref.data(), LDC); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 8 ); +#endif +} + +// When info == 10 +TYPED_TEST(gemm_IIT_ERS, invalid_ldb) +{ + using T = TypeParam; + + T alpha, beta; + testinghelpers::initone( alpha ); + testinghelpers::initone( beta ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + gemm( STORAGE, TRANS, TRANS, M, N, K, nullptr, nullptr, LDA, nullptr, LDB - 1, nullptr, nullptr, LDC ); +#else + gemm( STORAGE, TRANS, TRANS, M, N, K, &alpha, nullptr, LDA, nullptr, LDB - 1, &beta, nullptr, LDC ); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 10 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, N, LDC); + std::vector a = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, K, LDA); + std::vector b = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', K, N, LDB); + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + + // Call BLIS Gemm with a invalid value for ldb. + gemm( STORAGE, TRANS, TRANS, M, N, K, &alpha, a.data(), LDA, b.data(), LDB - 1, &beta, c.data(), LDC ); + // Use bitwise comparison (no threshold). + computediff( "C", STORAGE, N, N, c.data(), c_ref.data(), LDC); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 10 ); +#endif +} + +// When info == 13 +TYPED_TEST(gemm_IIT_ERS, invalid_ldc) +{ + using T = TypeParam; + + T alpha, beta; + testinghelpers::initone( alpha ); + testinghelpers::initone( beta ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + gemm( STORAGE, TRANS, TRANS, M, N, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC - 1 ); +#else + gemm( STORAGE, TRANS, TRANS, M, N, K, &alpha, nullptr, LDA, nullptr, LDB, &beta, nullptr, LDC - 1 ); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 13 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, N, LDC); + std::vector a = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, K, LDA); + std::vector b = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', K, N, LDB); + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + + // Call BLIS Gemm with a invalid value for ldc. + gemm( STORAGE, TRANS, TRANS, M, N, K, &alpha, a.data(), LDA, b.data(), LDB, &beta, c.data(), LDC - 1 ); + // Use bitwise comparison (no threshold). + computediff( "C", STORAGE, N, N, c.data(), c_ref.data(), LDC); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 13 ); +#endif +} + +/* + Early Return Scenarios(ERS) : + + The GEMM API is expected to return early in the following cases: + + 1. When m == 0. + 2. When n == 0. + 3. When (alpha == 0 or k == 0) and beta == 1. + 4. When alpha == 0 and beta == 0, set C = 0 only + 5. When alpha == 0 and beta /= 0 or 1, scale C by beta only + +*/ + +// When m is 0 +TYPED_TEST(gemm_IIT_ERS, m_eq_zero) +{ + using T = TypeParam; + + T alpha, beta; + testinghelpers::initone( alpha ); + testinghelpers::initone( beta ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + gemm( STORAGE, TRANS, TRANS, 0, N, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC ); +#else + gemm( STORAGE, TRANS, TRANS, 0, N, K, &alpha, nullptr, LDA, nullptr, LDB, &beta, nullptr, LDC ); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, N, LDC); + std::vector a = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, K, LDA); + std::vector b = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', K, N, LDB); + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + + gemm( STORAGE, TRANS, TRANS, 0, N, K, &alpha, a.data(), LDA, b.data(), LDB, &beta, c.data(), LDC ); + // Use bitwise comparison (no threshold). + computediff( "C", STORAGE, N, N, c.data(), c_ref.data(), LDC); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif +} + +// When n is 0 +TYPED_TEST(gemm_IIT_ERS, n_eq_zero) +{ + using T = TypeParam; + + T alpha, beta; + testinghelpers::initone( alpha ); + testinghelpers::initone( beta ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + gemm( STORAGE, TRANS, TRANS, M, 0, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC ); +#else + gemm( STORAGE, TRANS, TRANS, M, 0, K, &alpha, nullptr, LDA, nullptr, LDB, &beta, nullptr, LDC ); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, N, LDC); + std::vector a = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, K, LDA); + std::vector b = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', K, N, LDB); + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + + gemm( STORAGE, TRANS, TRANS, M, 0, K, &alpha, a.data(), LDA, b.data(), LDB, &beta, c.data(), LDC ); + // Use bitwise comparison (no threshold). + computediff( "C", STORAGE, N, N, c.data(), c_ref.data(), LDC); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif +} + +// When alpha is 0 and beta is 1 +TYPED_TEST(gemm_IIT_ERS, alpha_zero_beta_one) +{ + using T = TypeParam; + + T alpha, beta; + testinghelpers::initzero( alpha ); + testinghelpers::initone( beta ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + gemm( STORAGE, TRANS, TRANS, M, N, K, &alpha, nullptr, LDA, nullptr, LDB, &beta, nullptr, LDC ); +#else + gemm( STORAGE, TRANS, TRANS, M, N, K, &alpha, nullptr, LDA, nullptr, LDB, &beta, nullptr, LDC ); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, N, LDC); + std::vector a = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, K, LDA); + std::vector b = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', K, N, LDB); + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + + gemm( STORAGE, TRANS, TRANS, M, N, K, &alpha, a.data(), LDA, b.data(), LDB, &beta, c.data(), LDC ); + // Use bitwise comparison (no threshold). + computediff( "C", STORAGE, N, N, c.data(), c_ref.data(), LDC); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif +} + +// When k is 0 and beta is 1 +TYPED_TEST(gemm_IIT_ERS, k_zero_beta_one) +{ + using T = TypeParam; + + T alpha, beta; + testinghelpers::initone( alpha ); + testinghelpers::initone( beta ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + gemm( STORAGE, TRANS, TRANS, M, N, 0, &alpha, nullptr, LDA, nullptr, LDB, &beta, nullptr, LDC ); +#else + gemm( STORAGE, TRANS, TRANS, M, N, 0, &alpha, nullptr, LDA, nullptr, LDB, &beta, nullptr, LDC ); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, N, LDC); + std::vector a = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, K, LDA); + std::vector b = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', K, N, LDB); + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + + gemm( STORAGE, TRANS, TRANS, M, N, 0, &alpha, a.data(), LDA, b.data(), LDB, &beta, c.data(), LDC ); + // Use bitwise comparison (no threshold). + computediff( "C", STORAGE, N, N, c.data(), c_ref.data(), LDC); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif +} + +// zero alpha and zero beta - set C to 0 +TYPED_TEST(gemm_IIT_ERS, ZeroAlpha_ZeroBeta) +{ + using T = TypeParam; + + T alpha, beta; + testinghelpers::initzero( alpha ); + testinghelpers::initzero( beta ); + + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, N, LDC); + // Copy so that we check that the elements of C are not modified. + std::vector c2(c); + std::vector zero_mat = testinghelpers::get_random_matrix(0, 0, STORAGE, 'n', M, N, LDB); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + gemm( STORAGE, TRANS, TRANS, M, N, K, &alpha, nullptr, LDA, nullptr, LDB, &beta, c2.data(), LDC ); + computediff( "C", STORAGE, N, N, c2.data(), zero_mat.data(), LDC); +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector a = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, K, LDA); + std::vector b = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', K, N, LDB); + gemm( STORAGE, TRANS, TRANS, M, N, K, &alpha, a.data(), LDA, b.data(), LDB, &beta, c.data(), LDC ); + // Use bitwise comparison (no threshold). + computediff( "C", STORAGE, N, N, c.data(), zero_mat.data(), LDC); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif +} + +// zero alpha and non-zero/non-unit beta - scale C only +TYPED_TEST(gemm_IIT_ERS, ZeroAlpha_OtherBeta) +{ + using T = TypeParam; + + T alpha, beta; + testinghelpers::initzero( alpha ); + beta = T{2.0}; + double thresh = testinghelpers::getEpsilon(); + + std::vector a = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, K, LDA); + std::vector b = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', K, N, LDB); + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, N, LDC); + // Copy so that we check that the elements of C are not modified. + std::vector c2(c); + std::vector c_ref(c); + + testinghelpers::ref_gemm( STORAGE, TRANS, TRANS, M, N, K, alpha, + a.data(), LDA, b.data(), LDB, beta, c_ref.data(), LDC ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + gemm( STORAGE, TRANS, TRANS, M, N, K, &alpha, nullptr, LDA, nullptr, LDB, &beta, c2.data(), LDC ); + computediff( "C", STORAGE, N, N, c2.data(), c_ref.data(), LDC, thresh); +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + gemm( STORAGE, TRANS, TRANS, M, N, K, &alpha, a.data(), LDA, b.data(), LDB, &beta, c.data(), LDC ); + // Use bitwise comparison (no threshold). + computediff( "C", STORAGE, N, N, c.data(), c_ref.data(), LDC, thresh); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif +} + +#if 0 +/** + * These testcases are disabled as blis aborts for null buffers. + * Once respective blis framework changes are done to simply pass down + * the error to the top level these testcases can be enabled. +*/ +// When a matrix is null +TYPED_TEST(gemm_IIT_ERS, null_a_matrix) +{ + using T = TypeParam; + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, N, LDC); + std::vector b = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', K, N, LDB); + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + T alpha, beta; + + testinghelpers::initone( alpha ); + testinghelpers::initone( beta ); + + gemm( STORAGE, TRANS, TRANS, M, N, K, &alpha, nullptr, LDA, b.data(), LDB, &beta, c.data(), LDC ); + // Use bitwise comparison (no threshold). + computediff( "C", STORAGE, N, N, c.data(), c_ref.data(), LDC); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif +} + +// When b matrix is null +TYPED_TEST(gemm_IIT_ERS, null_b_matrix) +{ + using T = TypeParam; + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, N, LDC); + std::vector a = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, K, LDA); + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + T alpha, beta; + + testinghelpers::initone( alpha ); + testinghelpers::initone( beta ); + + gemm( STORAGE, TRANS, TRANS, M, N, K, &alpha, a.data(), LDA, nullptr, LDB, &beta, c.data(), LDC ); + // Use bitwise comparison (no threshold). + computediff( "C", STORAGE, N, N, c.data(), c_ref.data(), LDC); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif +} +#endif /* #IF 0 ENDS HERE */ +#endif + diff --git a/gtestsuite/testsuite/level3/gemm/IIT_ERS_test.cpp b/gtestsuite/testsuite/level3/gemm/IIT_ERS_test.cpp deleted file mode 100644 index 9e8ea79d4e..0000000000 --- a/gtestsuite/testsuite/level3/gemm/IIT_ERS_test.cpp +++ /dev/null @@ -1,264 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include "common/testing_helpers.h" -#include "gemm.h" -#include "inc/check_error.h" -#include "common/wrong_inputs_helpers.h" - -template -class Gemm_IIT_ERS_Test : public ::testing::Test {}; -typedef ::testing::Types TypeParam; // The supported datatypes from BLAS calls for GEMM -TYPED_TEST_SUITE(Gemm_IIT_ERS_Test, TypeParam); // Defining individual testsuites based on the datatype support. - -// Adding namespace to get default parameters(valid case) from testinghelpers/common/wrong_input_helpers.h. -using namespace testinghelpers::IIT; - -#ifdef TEST_BLAS - -/* - Incorrect Input Testing(IIT) - - BLAS exceptions get triggered in the following cases(for GEMM): - 1. When TRANSA != 'N' || TRANSA != 'T' || TRANSA != 'C' (info = 1) - 2. When TRANSB != 'N' || TRANSB != 'T' || TRANSB != 'C' (info = 2) - 3. When m < 0 (info = 3) - 4. When n < 0 (info = 4) - 5. When k < 0 (info = 5) - 6. When lda < max(1, thresh) (info = 8), thresh set based on TRANSA value - 7. When ldb < max(1, thresh) (info = 10), thresh set based on TRANSB value - 8. When ldc < max(1, n) (info = 13) - -*/ - -// When info == 1 -TYPED_TEST(Gemm_IIT_ERS_Test, invalid_transa) -{ - using T = TypeParam; - // Defining the C matrix with values for debugging purposes - std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, N, LDC, 'f'); - - // Copy so that we check that the elements of C are not modified. - std::vector c_ref(c); - // Call BLIS Gemm with a invalid value for TRANS value for A. - gemm( STORAGE, 'p', TRANS, M, N, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC ); - // Use bitwise comparison (no threshold). - computediff( STORAGE, N, N, c.data(), c_ref.data(), LDC); -} - -// When info == 2 -TYPED_TEST(Gemm_IIT_ERS_Test, invalid_transb) -{ - using T = TypeParam; - // Defining the C matrix with values for debugging purposes - std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, N, LDC, 'f'); - - // Copy so that we check that the elements of C are not modified. - std::vector c_ref(c); - // Call BLIS Gemm with a invalid value for TRANS value for B. - gemm( STORAGE, TRANS, 'p', M, N, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC ); - // Use bitwise comparison (no threshold). - computediff( STORAGE, N, N, c.data(), c_ref.data(), LDC); -} - -// When info == 3 -TYPED_TEST(Gemm_IIT_ERS_Test, m_lt_zero) -{ - using T = TypeParam; - // Defining the C matrix with values for debugging purposes - std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, N, LDC, 'f'); - - // Copy so that we check that the elements of C are not modified. - std::vector c_ref(c); - // Call BLIS Gemm with a invalid value for m. - gemm( STORAGE, TRANS, TRANS, -1, N, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC ); - // Use bitwise comparison (no threshold). - computediff( STORAGE, N, N, c.data(), c_ref.data(), LDC); -} - -// When info == 4 -TYPED_TEST(Gemm_IIT_ERS_Test, n_lt_zero) -{ - using T = TypeParam; - // Defining the C matrix with values for debugging purposes - std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, N, LDC, 'f'); - - // Copy so that we check that the elements of C are not modified. - std::vector c_ref(c); - // Call BLIS Gemm with a invalid value for n. - gemm( STORAGE, TRANS, TRANS, M, -1, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC ); - // Use bitwise comparison (no threshold). - computediff( STORAGE, N, N, c.data(), c_ref.data(), LDC); -} - -// When info == 5 -TYPED_TEST(Gemm_IIT_ERS_Test, k_lt_zero) -{ - using T = TypeParam; - // Defining the C matrix with values for debugging purposes - std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, N, LDC, 'f'); - - // Copy so that we check that the elements of C are not modified. - std::vector c_ref(c); - // Call BLIS Gemm with a invalid value for k. - gemm( STORAGE, TRANS, TRANS, M, N, -1, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC ); - // Use bitwise comparison (no threshold). - computediff( STORAGE, N, N, c.data(), c_ref.data(), LDC); -} - -// When info == 8 -TYPED_TEST(Gemm_IIT_ERS_Test, invalid_lda) -{ - using T = TypeParam; - // Defining the C matrix with values for debugging purposes - std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, N, LDC, 'f'); - - // Copy so that we check that the elements of C are not modified. - std::vector c_ref(c); - // Call BLIS Gemm with a invalid value for lda. - gemm( STORAGE, TRANS, TRANS, M, N, K, nullptr, nullptr, LDA - 1, nullptr, LDB, nullptr, nullptr, LDC ); - // Use bitwise comparison (no threshold). - computediff( STORAGE, N, N, c.data(), c_ref.data(), LDC); -} - -// When info == 10 -TYPED_TEST(Gemm_IIT_ERS_Test, invalid_ldb) -{ - using T = TypeParam; - // Defining the C matrix with values for debugging purposes - std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, N, LDC, 'f'); - - // Copy so that we check that the elements of C are not modified. - std::vector c_ref(c); - // Call BLIS Gemm with a invalid value for ldb. - gemm( STORAGE, TRANS, TRANS, M, N, K, nullptr, nullptr, LDA, nullptr, LDB - 1, nullptr, nullptr, LDC ); - // Use bitwise comparison (no threshold). - computediff( STORAGE, N, N, c.data(), c_ref.data(), LDC); -} - -// When info == 13 -TYPED_TEST(Gemm_IIT_ERS_Test, invalid_ldc) -{ - using T = TypeParam; - // Defining the C matrix with values for debugging purposes - std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, N, LDC, 'f'); - - // Copy so that we check that the elements of C are not modified. - std::vector c_ref(c); - // Call BLIS Gemm with a invalid value for ldc. - gemm( STORAGE, TRANS, TRANS, M, N, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC - 1 ); - // Use bitwise comparison (no threshold). - computediff( STORAGE, N, N, c.data(), c_ref.data(), LDC); -} - -/* - Early Return Scenarios(ERS) : - - The GEMM API is expected to return early in the following cases: - - 1. When m == 0. - 2. When n == 0. - 3. When (alpha == 0 or k == 0) and beta == 1. - -*/ - -// When m is 0 -TYPED_TEST(Gemm_IIT_ERS_Test, m_eq_zero) -{ - using T = TypeParam; - // Defining the C matrix with values for debugging purposes - std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, N, LDC, 'f'); - - // Copy so that we check that the elements of C are not modified. - std::vector c_ref(c); - gemm( STORAGE, TRANS, TRANS, 0, N, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC ); - // Use bitwise comparison (no threshold). - computediff( STORAGE, N, N, c.data(), c_ref.data(), LDC); -} - -// When n is 0 -TYPED_TEST(Gemm_IIT_ERS_Test, n_eq_zero) -{ - using T = TypeParam; - // Defining the C matrix with values for debugging purposes - std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, N, LDC, 'f'); - - // Copy so that we check that the elements of C are not modified. - std::vector c_ref(c); - gemm( STORAGE, TRANS, TRANS, M, 0, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC ); - // Use bitwise comparison (no threshold). - computediff( STORAGE, N, N, c.data(), c_ref.data(), LDC); -} - -// When alpha is 0 and beta is 1 -TYPED_TEST(Gemm_IIT_ERS_Test, alpha_zero_beta_one) -{ - using T = TypeParam; - // Defining the C matrix with values for debugging purposes - std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, N, LDC, 'f'); - - // Copy so that we check that the elements of C are not modified. - std::vector c_ref(c); - T alpha, beta; - - testinghelpers::initzero( alpha ); - testinghelpers::initone( beta ); - - gemm( STORAGE, TRANS, TRANS, M, N, K, &alpha, nullptr, LDA, nullptr, LDB, &beta, nullptr, LDC ); - // Use bitwise comparison (no threshold). - computediff( STORAGE, N, N, c.data(), c_ref.data(), LDC); -} - -// When k is 0 and beta is 1 -TYPED_TEST(Gemm_IIT_ERS_Test, k_zero_beta_one) -{ - using T = TypeParam; - // Defining the C matrix with values for debugging purposes - std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, N, LDC, 'f'); - - // Copy so that we check that the elements of C are not modified. - std::vector c_ref(c); - T alpha, beta; - - testinghelpers::initone( alpha ); - testinghelpers::initone( beta ); - - gemm( STORAGE, TRANS, TRANS, M, N, 0, &alpha, nullptr, LDA, nullptr, LDB, &beta, nullptr, LDC ); - // Use bitwise comparison (no threshold). - computediff( STORAGE, N, N, c.data(), c_ref.data(), LDC); -} - - -#endif diff --git a/gtestsuite/testsuite/level3/gemm/cgemm/cgemm_evt.cpp b/gtestsuite/testsuite/level3/gemm/cgemm/cgemm_evt.cpp new file mode 100644 index 0000000000..0f13589234 --- /dev/null +++ b/gtestsuite/testsuite/level3/gemm/cgemm/cgemm_evt.cpp @@ -0,0 +1,490 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "level3/gemm/test_gemm.h" + +using T = scomplex; + +static float AOCL_NAN = std::numeric_limits::quiet_NaN(); +static float AOCL_INF = std::numeric_limits::infinity(); + +class DISABLED_cgemmEVT : + public ::testing::TestWithParam> {}; + +TEST_P( DISABLED_cgemmEVT, API ) +{ + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // matrix storage format(row major, column major) + char storage = std::get<0>(GetParam()); + // denotes whether matrix a is n,c,t,h + char transa = std::get<1>(GetParam()); + // denotes whether matrix b is n,c,t,h + char transb = std::get<2>(GetParam()); + // matrix size m + gtint_t m = std::get<3>(GetParam()); + // matrix size n + gtint_t n = std::get<4>(GetParam()); + // matrix size k + gtint_t k = std::get<5>(GetParam()); + + // ai, aj, bi, bj, ci, cj - Indices of all Matrices where + // EV to be inserted + gtint_t ai, aj, bi, bj, ci, cj; + + // aex, bex, cex - Exception value(EV) for each Matrix + T aex, bex, cex; + ai = std::get<6>(GetParam()); + aj = std::get<7>(GetParam()); + aex = std::get<8>(GetParam()); + + bi = std::get<9>(GetParam()); + bj = std::get<10>(GetParam()); + bex = std::get<11>(GetParam()); + + ci = std::get<12>(GetParam()); + cj = std::get<13>(GetParam()); + cex = std::get<14>(GetParam()); + + // specifies alpha value + T alpha = std::get<15>(GetParam()); + // specifies beta value + T beta = std::get<16>(GetParam()); + + // lda, ldb, ldc increments. + // If increments are zero, then the array size matches the matrix size. + // If increments are nonnegative, + // the array size is bigger than the matrix size. + gtint_t lda_inc = std::get<17>(GetParam()); + gtint_t ldb_inc = std::get<18>(GetParam()); + gtint_t ldc_inc = std::get<19>(GetParam()); + + // Check gtestsuite gemm.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (m == 0 || n == 0) + thresh = 0.0; + else if ((alpha == testinghelpers::ZERO() || k == 0) && + (beta == testinghelpers::ZERO() || beta == testinghelpers::ONE())) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + thresh = testinghelpers::getEpsilon(); + else + thresh = (3*k+1)*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_gemm( storage, transa, transb, m, n, k, lda_inc, ldb_inc, ldc_inc, + alpha, beta, ai, aj, aex, bi, bj, bex, ci, cj, cex, thresh ); +} + +/********************************************************************/ +/* Testing ExceptionValue testing for SUP and Native implementation */ +/* of cgemm API */ +/********************************************************************/ +/* Exception Values are AOCL_NAN, AOCL_INF, -AOCL_INF */ +/* 1. Matrix: */ +/* These values are inserted in user provided (i,j)th indices of */ +/* Matrix A, B, C */ +/* 2. Scaling Values: */ +/* These values are inserted as alpha, beta values */ +/********************************************************************/ + +//Failures observed for EV: T{AOCL_INF, 0.0} +INSTANTIATE_TEST_SUITE_P( + Skinny_Matrix_No_Trans, + DISABLED_cgemmEVT, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n'), // transa + ::testing::Values('n'), // transb + ::testing::Values(gtint_t(300)), // m + ::testing::Values(gtint_t(210)), // n + ::testing::Values(gtint_t(150)), // k + ::testing::Values(gtint_t(3)), // ai + ::testing::Values(gtint_t(0)), // aj + ::testing::Values(T{AOCL_NAN, 2.2}, T{AOCL_INF, 5.2}, + T{-3.4, AOCL_NAN}, T{AOCL_NAN, -AOCL_INF}), // aexval + ::testing::Values(gtint_t(0)), // bi + ::testing::Values(gtint_t(0), gtint_t(2)), // bj + ::testing::Values(T{AOCL_NAN, -2.3}, T{AOCL_INF, 8.9}, + T{-3.4, AOCL_NAN}, T{AOCL_NAN, -AOCL_INF}), // bexval + ::testing::Values(gtint_t(2)), // ci + ::testing::Values(gtint_t(1)), // cj + ::testing::Values(T{AOCL_NAN, 1.3}, T{AOCL_INF, 7.4}, + T{3.3, AOCL_NAN}, T{AOCL_NAN, -AOCL_INF}), // cexval + ::testing::Values(T{-1.0, -2.0}, T{0.0, 0.0}, + T{1.0, 0.0}, T{-1.0, 0.0}, + T{91.0, 0.0}, T{0.0, 1.0}), // alpha + ::testing::Values(T{12.0, 2.3}, T{0.0, 0.0}, + T{1.0, 0.0}, T{-1.0, 0.0}, + T{12.0, 0.0}, T{0.0, 1.0}), // beta + ::testing::Values(gtint_t(0)), // increment to the leading dim of a + ::testing::Values(gtint_t(0)), // increment to the leading dim of b + ::testing::Values(gtint_t(0)) // increment to the leading dim of c + ), + ::gemmEVTPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + Skinny_Matrix_Trans, + DISABLED_cgemmEVT, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('t'), // transa + ::testing::Values('t'), // transb + ::testing::Values(gtint_t(300)), // m + ::testing::Values(gtint_t(210)), // n + ::testing::Values(gtint_t(150)), // k + ::testing::Values(gtint_t(1)), // ai + ::testing::Values(gtint_t(0)), // aj + ::testing::Values(T{AOCL_NAN, 2.2}, T{AOCL_INF, -9.0}, + T{-3.4, AOCL_NAN}, T{AOCL_NAN, -AOCL_INF}), // aexval + ::testing::Values(gtint_t(0)), // bi + ::testing::Values(gtint_t(0)), // bj + ::testing::Values(T{AOCL_NAN, -2.3}, T{AOCL_INF, -6.7}, + T{-3.4, AOCL_NAN}, T{AOCL_NAN, -AOCL_INF}), // bexval + ::testing::Values(gtint_t(2)), // ci + ::testing::Values(gtint_t(3)), // cj + ::testing::Values(T{AOCL_NAN, 1.3}, T{AOCL_INF, 5.6}, + T{3.3, AOCL_NAN}, T{AOCL_NAN, -AOCL_INF}), // cexval + ::testing::Values(T{-1.0, -2.0}, T{0.0, 0.0}, + T{1.0, 0.0}, T{-1.0, 0.0}, + T{12.0, 0.0}, T{0.0, 1.0}), // alpha + ::testing::Values(T{12.0, 2.3}, T{0.0, 0.0}, + T{1.0, 0.0}, T{-1.0, 0.0}, + T{12.0, 0.0}, T{0.0, 1.0}), // beta + ::testing::Values(gtint_t(0)), // increment to the leading dim of a + ::testing::Values(gtint_t(0)), // increment to the leading dim of b + ::testing::Values(gtint_t(0)) // increment to the leading dim of c + ), + ::gemmEVTPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + Skinny_Matrix_zeros_And_ExceptionValues, + DISABLED_cgemmEVT, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n', 't', 'c'), // transa + ::testing::Values('n', 't', 'c'), // transb + ::testing::Values(gtint_t(200)), // m + ::testing::Values(gtint_t(100)), // n + ::testing::Values(gtint_t(150)), // k + ::testing::Values(gtint_t(3)), // ai + ::testing::Values(gtint_t(0)), // aj + ::testing::Values(T{AOCL_NAN, 0}, T{AOCL_INF, 0.0}, + T{0, AOCL_NAN}, T{0, -AOCL_INF}), // aexval + ::testing::Values(gtint_t(0)), // bi + ::testing::Values(gtint_t(2)), // bj + ::testing::Values(T{AOCL_NAN, 0}, T{AOCL_INF, 0.0}, + T{0, AOCL_NAN}, T{0, -AOCL_INF}), // bexval + ::testing::Values(gtint_t(2)), // ci + ::testing::Values(gtint_t(3)), // cj + ::testing::Values(T{AOCL_NAN, 0}, T{AOCL_INF, 0.0}, + T{0, AOCL_NAN}, T{0, -AOCL_INF}), // cexval + ::testing::Values(T{-1.0, -2.0}, T{0.0, 0.0}, + T{1.0, 0.0}, T{-1.0, 0.0}, + T{2.3, 0.0}, T{0.0, 1.0}), // alpha + ::testing::Values(T{12.0, 2.3}, T{0.0, 0.0}, + T{1.0, 0.0}, T{-1.0, 0.0}, + T{3.2, 0.0}, T{0.0, 1.0}), // beta + ::testing::Values(gtint_t(0)), // increment to the leading dim of a + ::testing::Values(gtint_t(0)), // increment to the leading dim of b + ::testing::Values(gtint_t(0)) // increment to the leading dim of c + ), + ::gemmEVTPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + Skinny_Matrix_Alpha_Beta, + DISABLED_cgemmEVT, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n', 't', 'c'), // transa + ::testing::Values('n', 't', 'c'), // transb + ::testing::Values(gtint_t(210)), // m + ::testing::Values(gtint_t(100)), // n + ::testing::Values(gtint_t(50)), // k + ::testing::Values(gtint_t(3)), // ai + ::testing::Values(gtint_t(0)), // aj + ::testing::Values(T{1.2, 2.3}), // aexval + ::testing::Values(gtint_t(0)), // bi + ::testing::Values(gtint_t(2)), // bj + ::testing::Values(T{-2.3, -12}), // bexval + ::testing::Values(gtint_t(2)), // ci + ::testing::Values(gtint_t(1)), // cj + ::testing::Values(T{-0.7, 3.2}), // cexval + ::testing::Values(T{AOCL_NAN, 1.4}, T{AOCL_INF, 7.4}, + T{4.2, AOCL_NAN}, T{AOCL_NAN, -AOCL_INF}, + T{AOCL_NAN, 0}, T{AOCL_INF, 0.0}, + T{0, AOCL_NAN}, T{0, -AOCL_INF}), // alpha + ::testing::Values(T{AOCL_NAN, 5.2}, T{AOCL_INF, 3.4}, + T{1.6, AOCL_NAN}, T{AOCL_NAN, -AOCL_INF}, + T{AOCL_NAN, 0}, T{AOCL_INF, 0.0}, + T{0, AOCL_NAN}, T{0, -AOCL_INF}), // beta + ::testing::Values(gtint_t(0)), // increment to the leading dim of a + ::testing::Values(gtint_t(0)), // increment to the leading dim of b + ::testing::Values(gtint_t(0)) // increment to the leading dim of c + ), + ::gemmEVTPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + Large_Matrix_No_Trans, + DISABLED_cgemmEVT, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n'), // transa + ::testing::Values('n'), // transb + ::testing::Values(gtint_t(500)), // m + ::testing::Values(gtint_t(680)), // n + ::testing::Values(gtint_t(370)), // k + ::testing::Values(gtint_t(3)), // ai + ::testing::Values(gtint_t(0)), // aj + ::testing::Values(T{AOCL_NAN, 9.3}, T{AOCL_INF, 3.9}, + T{13.4, AOCL_NAN}, T{AOCL_NAN, -AOCL_INF}), // aexval + ::testing::Values(gtint_t(0)), // bi + ::testing::Values(gtint_t(2)), // bj + ::testing::Values(T{AOCL_NAN, -5.6}, T{AOCL_INF, -3.1}, + T{9.7, AOCL_NAN}, T{AOCL_NAN, -AOCL_INF}), // bexval + ::testing::Values(gtint_t(0)), // ci + ::testing::Values(gtint_t(1)), // cj + ::testing::Values(T{AOCL_NAN, 7.8}, T{AOCL_INF, -6.7}, + T{-3.6, AOCL_NAN}, T{AOCL_NAN, -AOCL_INF}), // cexval + ::testing::Values(T{-21.0, -12.0}), // alpha + ::testing::Values(T{1.0, 2.13}), // beta + ::testing::Values(gtint_t(0)), // increment to the leading dim of a + ::testing::Values(gtint_t(0)), // increment to the leading dim of b + ::testing::Values(gtint_t(0)) // increment to the leading dim of c + ), + ::gemmEVTPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + Large_Matrix_Trans, + DISABLED_cgemmEVT, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('t'), // transa + ::testing::Values('t'), // transb + ::testing::Values(gtint_t(595)), // m + ::testing::Values(gtint_t(880)), // n + ::testing::Values(gtint_t(470)), // k + ::testing::Values(gtint_t(1)), // ai + ::testing::Values(gtint_t(0)), // aj + ::testing::Values(T{AOCL_NAN, 9.3}, T{AOCL_INF, -5.6}, + T{13.4, AOCL_NAN}, T{AOCL_NAN, -AOCL_INF}), // aexval + ::testing::Values(gtint_t(0)), // bi + ::testing::Values(gtint_t(2)), // bj + ::testing::Values(T{AOCL_NAN, -5.6}, T{AOCL_INF, 3.2}, + T{9.7, AOCL_NAN}, T{AOCL_NAN, -AOCL_INF}), // bexval + ::testing::Values(gtint_t(2)), // ci + ::testing::Values(gtint_t(3)), // cj + ::testing::Values(T{AOCL_NAN, 7.8}, T{AOCL_INF, -6.7}, + T{-3.6, AOCL_NAN}, T{AOCL_NAN, -AOCL_INF}), // cexval + ::testing::Values(T{-21.0, -12.0}), // alpha + ::testing::Values(T{1.0, 2.13}), // beta + ::testing::Values(gtint_t(0)), // increment to the leading dim of a + ::testing::Values(gtint_t(0)), // increment to the leading dim of b + ::testing::Values(gtint_t(0)) // increment to the leading dim of c + ), + ::gemmEVTPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + Large_Matrix_Conj, + DISABLED_cgemmEVT, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('c'), // transa + ::testing::Values('c'), // transb + ::testing::Values(gtint_t(700)), // m + ::testing::Values(gtint_t(990)), // n + ::testing::Values(gtint_t(475)), // k + ::testing::Values(gtint_t(3)), // ai + ::testing::Values(gtint_t(0)), // aj + ::testing::Values(T{AOCL_NAN, 9.3}, T{AOCL_INF, -3.2}, + T{13.4, AOCL_NAN}, T{AOCL_NAN, -AOCL_INF}), // aexval + ::testing::Values(gtint_t(0)), // bi + ::testing::Values(gtint_t(0), gtint_t(2)), // bj + ::testing::Values(T{AOCL_NAN, -5.6}, T{AOCL_INF, 5.2}, + T{9.7, AOCL_NAN}, T{AOCL_NAN, -AOCL_INF}), // bexval + ::testing::Values(gtint_t(2)), // ci + ::testing::Values(gtint_t(1)), // cj + ::testing::Values(T{AOCL_NAN, 7.8}, T{AOCL_INF, 7.6}, + T{-3.6, AOCL_NAN}, T{AOCL_NAN, -AOCL_INF}), // cexval + ::testing::Values(T{-21.0, -12.0}, T{0.0, 0.0}, + T{1.0, 0.0}, T{-1.0, 0.0}, + T{9.8, 0.0}, T{0.0, 1.0}), // alpha + ::testing::Values(T{1.0, 2.13}, T{0.0, 0.0}, + T{1.0, 0.0}, T{-1.0, 0.0}, + T{4.3, 0.0}, T{0.0, 1.0}), // beta + ::testing::Values(gtint_t(0)), // increment to the leading dim of a + ::testing::Values(gtint_t(0)), // increment to the leading dim of b + ::testing::Values(gtint_t(0)) // increment to the leading dim of c + ), + ::gemmEVTPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + Large_Matrix_zeros_And_ExcpetionValues, + DISABLED_cgemmEVT, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n', 't', 'c'), // transa + ::testing::Values('n', 't', 'c'), // transb + ::testing::Values(gtint_t(800)), // m + ::testing::Values(gtint_t(1100)), // n + ::testing::Values(gtint_t(475)), // k + ::testing::Values(gtint_t(3)), // ai + ::testing::Values(gtint_t(0)), // aj + ::testing::Values(T{AOCL_NAN, 0}, T{AOCL_INF, 0.0}, + T{0, AOCL_NAN}, T{0, -AOCL_INF}), // aexval + ::testing::Values(gtint_t(0)), // bi + ::testing::Values(gtint_t(2)), // bj + ::testing::Values(T{AOCL_NAN, 0}, T{AOCL_INF, 0.0}, + T{0, AOCL_NAN}, T{0, -AOCL_INF}), // bexval + ::testing::Values(gtint_t(2)), // ci + ::testing::Values(gtint_t(3)), // cj + ::testing::Values(T{AOCL_NAN, 0}, T{AOCL_INF, 0.0}, + T{0, AOCL_NAN}, T{0, -AOCL_INF}), // cexval + ::testing::Values(T{-21.0, -12.0}, T{0.0, 0.0}, + T{1.0, 0.0}, T{-1.0, 0.0}, + T{2.4, 0.0}, T{0.0, 1.0}), // alpha + ::testing::Values(T{1.0, 2.13}, T{0.0, 0.0}, + T{1.0, 0.0}, T{-1.0, 0.0}, + T{4.5, 0.0}, T{0.0, 1.0}), // beta + ::testing::Values(gtint_t(0)), // increment to the leading dim of a + ::testing::Values(gtint_t(0)), // increment to the leading dim of b + ::testing::Values(gtint_t(0)) // increment to the leading dim of c + ), + ::gemmEVTPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + Large_Matrix_Alpha_Beta, + DISABLED_cgemmEVT, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n', 't', 'c'), // transa + ::testing::Values('n', 't', 'c'), // transb + ::testing::Values(gtint_t(700)), // m + ::testing::Values(gtint_t(1000)), // n + ::testing::Values(gtint_t(475)), // k + ::testing::Values(gtint_t(3)), // ai + ::testing::Values(gtint_t(0)), // aj + ::testing::Values(T{1.12, 12.3}), // aexval + ::testing::Values(gtint_t(0)), // bi + ::testing::Values(gtint_t(2)), // bj + ::testing::Values(T{-12.3, -2}), // bexval + ::testing::Values(gtint_t(2)), // ci + ::testing::Values(gtint_t(3)), // cj + ::testing::Values(T{-1.7, -3.12}), // cexval + ::testing::Values(T{AOCL_NAN, 2.3}, T{AOCL_INF, 8.9}, + T{3.4, AOCL_NAN}, T{AOCL_NAN, -AOCL_INF}, + T{AOCL_NAN, 0}, T{AOCL_INF, 0.0}, + T{0, AOCL_NAN}, T{0, -AOCL_INF}), // alpha + ::testing::Values(T{AOCL_NAN, 5.3}, T{AOCL_INF, 3.5}, + T{2.9, AOCL_NAN}, T{AOCL_NAN, -AOCL_INF}, + T{AOCL_NAN, 0}, T{AOCL_INF, 0.0}, + T{0, AOCL_NAN}, T{0, -AOCL_INF}), // beta + ::testing::Values(gtint_t(0)), // increment to the leading dim of a + ::testing::Values(gtint_t(0)), // increment to the leading dim of b + ::testing::Values(gtint_t(0)) // increment to the leading dim of c + ), + ::gemmEVTPrint() + ); diff --git a/gtestsuite/testsuite/level3/gemm/cgemm/cgemm_generic.cpp b/gtestsuite/testsuite/level3/gemm/cgemm/cgemm_generic.cpp new file mode 100644 index 0000000000..759d33231d --- /dev/null +++ b/gtestsuite/testsuite/level3/gemm/cgemm/cgemm_generic.cpp @@ -0,0 +1,253 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "level3/gemm/test_gemm.h" + +class cgemmGeneric : + public ::testing::TestWithParam> {}; +TEST_P( cgemmGeneric, API ) +{ + using T = scomplex; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // matrix storage format(row major, column major) + char storage = std::get<0>(GetParam()); + // denotes whether matrix a is n,c,t,h + char transa = std::get<1>(GetParam()); + // denotes whether matrix b is n,c,t,h + char transb = std::get<2>(GetParam()); + // matrix size m + gtint_t m = std::get<3>(GetParam()); + // matrix size n + gtint_t n = std::get<4>(GetParam()); + // matrix size k + gtint_t k = std::get<5>(GetParam()); + // specifies alpha value + T alpha = std::get<6>(GetParam()); + // specifies beta value + T beta = std::get<7>(GetParam()); + // lda, ldb, ldc increments. + // If increments are zero, then the array size matches the matrix size. + // If increments are nonnegative, the array size is bigger than the matrix size. + gtint_t lda_inc = std::get<8>(GetParam()); + gtint_t ldb_inc = std::get<9>(GetParam()); + gtint_t ldc_inc = std::get<10>(GetParam()); + // Set the threshold for the errors: + + // Check gtestsuite gemm.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (m == 0 || n == 0) + thresh = 0.0; + else if ((alpha == testinghelpers::ZERO() || k == 0) && + (beta == testinghelpers::ZERO() || beta == testinghelpers::ONE())) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + thresh = testinghelpers::getEpsilon(); + else + thresh = (3*k+1)*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_gemm( storage, transa, transb, m, n, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh ); +} + +/********************************************************************/ +/* Testing SUP and Native implementation of cgemm API */ +/********************************************************************/ +/************************** SCALM************************************/ +/* Scaling of C matrix for below conditions */ +/* 1. When alpha is zero */ +/* 2. When Matrix A or Matrix B has zero dimension */ +/* Scale Matrix C by Beta and return */ +/********************************************************************/ +/************************** SUP *************************************/ +/* Current SUP implmentation does not support below parameters */ +/* 1. General Stride */ +/* 2. Conjugate */ +/* 3. Input dimensions greater than below thresholds */ +/* m > 380 || n > 256 || k > 220 */ +/* SUP implementations is suitable for Skinny Matrices */ +/* List of API's: */ +/* 1. bli_cgemmsup_rv_zen_asm_3x8m: M preferred kernel */ +/* 2. bli_cgemmsup_rv_zen_asm_3x8n: N preferred kernel */ +/********************************************************************/ +/************************** NATIVE***********************************/ +/* When SUP method does not support given input arguments, */ +/* Native implmentation will be invoked, it is well suited for */ +/* square, large sizes */ +/* API Name: bli_cgemm_haswell_asm_3x8 */ +/********************************************************************/ + +INSTANTIATE_TEST_SUITE_P( + Alpha_zero, + cgemmGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n', 'c', 't'), // transa + ::testing::Values('n', 'c', 't'), // transb + ::testing::Values(gtint_t(300), gtint_t(17)), // m + ::testing::Values(gtint_t(200), gtint_t(18)), // n + ::testing::Values(gtint_t(150), gtint_t(19)), // k + ::testing::Values(scomplex{0.0, 0.0}), // alpha + ::testing::Values(scomplex{12.9, 12.3}, scomplex{0.0, 1.9}, + scomplex{1.0, 0.0}, scomplex{-1.0, 0.0}, + scomplex{5.2, 0.0}), // beta + ::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of a + ::testing::Values(gtint_t(0), gtint_t(5)), // increment to the leading dim of b + ::testing::Values(gtint_t(0), gtint_t(3)) // increment to the leading dim of c + ), + ::gemmGenericPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + Skinny_Matrix, + cgemmGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n', 't', 'c'), // transa + ::testing::Values('n', 't', 'c'), // transb + ::testing::Values(gtint_t(300), gtint_t(320)), // m + ::testing::Values(gtint_t(200), gtint_t(220)), // n + ::testing::Values(gtint_t(150), gtint_t(160)), // k + ::testing::Values(scomplex{-1.0, -2.0}), // alpha + ::testing::Values(scomplex{12.0, 2.3}), // beta + ::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of a + ::testing::Values(gtint_t(0), gtint_t(1)), // increment to the leading dim of b + ::testing::Values(gtint_t(0), gtint_t(3)) // increment to the leading dim of c + ), + ::gemmGenericPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + Skinny_Matrix_Alpha_Beta, + cgemmGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n', 't', 'c'), // transa + ::testing::Values('n', 't', 'c'), // transb + ::testing::Values(gtint_t(300), gtint_t(304)), // m + ::testing::Values(gtint_t(200), gtint_t(209)), // n + ::testing::Values(gtint_t(150)), // k + ::testing::Values(scomplex{0.0, -30.0}, + scomplex{1.0, 0.0}, scomplex{-1.0, 0.0}, + scomplex{5.0, 0.0}), // alpha + ::testing::Values(scomplex{0.0, 1.3}, + scomplex{1.0, 0.0}, scomplex{-1.0, 0.0}, + scomplex{5.0, 0.0}, scomplex{0.0, 0.0}), // beta + ::testing::Values(gtint_t(0), gtint_t(5)), // increment to the leading dim of a + ::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of b + ::testing::Values(gtint_t(0), gtint_t(6)) // increment to the leading dim of c + ), + ::gemmGenericPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + Large_Matrix, + cgemmGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n', 't', 'c'), // transa + ::testing::Values('n', 't', 'c'), // transb + ::testing::Values(gtint_t(400), gtint_t(700)), // m + ::testing::Values(gtint_t(380), gtint_t(1000)), // n + ::testing::Values(gtint_t(270), gtint_t(280)), // k + ::testing::Values(scomplex{1.5, 3.5}), // alpha + ::testing::Values(scomplex{2.0, 4.1}), // beta + ::testing::Values(gtint_t(0)), // increment to the leading dim of a + ::testing::Values(gtint_t(0)), // increment to the leading dim of b + ::testing::Values(gtint_t(0)) // increment to the leading dim of c + ), + ::gemmGenericPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + Large_Matrix_Alpha_Beta, + cgemmGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n', 't', 'c'), // transa + ::testing::Values('n', 't', 'c'), // transb + ::testing::Values(gtint_t(400), gtint_t(700)), // m + ::testing::Values(gtint_t(380), gtint_t(1000)), // n + ::testing::Values(gtint_t(270)), // k + ::testing::Values(scomplex{0.0, -10.0}, + scomplex{1.0, 0.0}, scomplex{-1.0, 0.0}, + scomplex{2.0, 0.0}), // alpha + ::testing::Values(scomplex{0.0, 3.4}, + scomplex{1.0, 0.0}, scomplex{-1.0, 0.0}, + scomplex{3.3, 0.0}, scomplex{0.0, 0.0}), // beta + ::testing::Values(gtint_t(0)), // increment to the leading dim of a + ::testing::Values(gtint_t(0)), // increment to the leading dim of b + ::testing::Values(gtint_t(0)) // increment to the leading dim of c + ), + ::gemmGenericPrint() + ); diff --git a/gtestsuite/testsuite/level3/gemm/cgemm_generic.cpp b/gtestsuite/testsuite/level3/gemm/cgemm_generic.cpp deleted file mode 100644 index 5043dc44a7..0000000000 --- a/gtestsuite/testsuite/level3/gemm/cgemm_generic.cpp +++ /dev/null @@ -1,152 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include "test_gemm.h" - -class CGemmTest : - public ::testing::TestWithParam> {}; - -TEST_P(CGemmTest, RandomData) -{ - using T = scomplex; - //---------------------------------------------------------- - // Initialize values from the parameters passed through - // test suite instantiation (INSTANTIATE_TEST_SUITE_P). - //---------------------------------------------------------- - // matrix storage format(row major, column major) - char storage = std::get<0>(GetParam()); - // denotes whether matrix a is n,c,t,h - char transa = std::get<1>(GetParam()); - // denotes whether matrix b is n,c,t,h - char transb = std::get<2>(GetParam()); - // matrix size m - gtint_t m = std::get<3>(GetParam()); - // matrix size n - gtint_t n = std::get<4>(GetParam()); - // matrix size k - gtint_t k = std::get<5>(GetParam()); - // specifies alpha value - T alpha = std::get<6>(GetParam()); - // specifies beta value - T beta = std::get<7>(GetParam()); - // lda, ldb, ldc increments. - // If increments are zero, then the array size matches the matrix size. - // If increments are nonnegative, the array size is bigger than the matrix size. - gtint_t lda_inc = std::get<8>(GetParam()); - gtint_t ldb_inc = std::get<9>(GetParam()); - gtint_t ldc_inc = std::get<10>(GetParam()); - - // Set the threshold for the errors: - double thresh = 10*m*n*testinghelpers::getEpsilon(); - - //---------------------------------------------------------- - // Call test body using these parameters - //---------------------------------------------------------- - test_gemm( storage, transa, transb, m, n, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh ); -} - -class CGemmTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char tsa = std::get<1>(str.param); - char tsb = std::get<2>(str.param); - gtint_t m = std::get<3>(str.param); - gtint_t n = std::get<4>(str.param); - gtint_t k = std::get<5>(str.param); - scomplex alpha = std::get<6>(str.param); - scomplex beta = std::get<7>(str.param); - gtint_t lda_inc = std::get<8>(str.param); - gtint_t ldb_inc = std::get<9>(str.param); - gtint_t ldc_inc = std::get<10>(str.param); -#ifdef TEST_BLAS - std::string str_name = "cgemm_"; -#elif TEST_CBLAS - std::string str_name = "cblas_cgemm"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_cgemm"; -#endif - str_name = str_name + "_" + sfm+sfm+sfm; - str_name = str_name + "_" + tsa + tsb; - str_name = str_name + "_" + std::to_string(m); - str_name = str_name + "_" + std::to_string(n); - str_name = str_name + "_" + std::to_string(k); - std::string alpha_str = ( alpha.real > 0) ? std::to_string(int(alpha.real)) : ("m" + std::to_string(int(std::abs(alpha.real)))); - alpha_str = alpha_str + "pi" + (( alpha.imag > 0) ? std::to_string(int(alpha.imag)) : ("m" + std::to_string(int(std::abs(alpha.imag))))); - std::string beta_str = ( beta.real > 0) ? std::to_string(int(beta.real)) : ("m" + std::to_string(int(std::abs(beta.real)))); - beta_str = beta_str + "pi" + (( beta.imag > 0) ? std::to_string(int(beta.imag)) : ("m" + std::to_string(int(std::abs(beta.imag))))); - str_name = str_name + "_a" + alpha_str; - str_name = str_name + "_b" + beta_str; - str_name = str_name + "_" + std::to_string(lda_inc); - str_name = str_name + "_" + std::to_string(ldb_inc); - str_name = str_name + "_" + std::to_string(ldc_inc); - return str_name; - } -}; - -// Black box testing. -INSTANTIATE_TEST_SUITE_P( - Blackbox, - CGemmTest, - ::testing::Combine( - ::testing::Values('c' -#ifndef TEST_BLAS - ,'r' -#endif - ), // storage format - ::testing::Values('n','c','t'), // transa - ::testing::Values('n','c','t'), // transb - ::testing::Range(gtint_t(10), gtint_t(31), 10), // m - ::testing::Range(gtint_t(10), gtint_t(31), 10), // n - ::testing::Range(gtint_t(10), gtint_t(31), 10), // k - ::testing::Values(scomplex{2.0,-1.0}), // alpha - ::testing::Values(scomplex{1.0,2.0}), // beta - ::testing::Values(gtint_t(0), gtint_t(3)), // increment to the leading dim of a - ::testing::Values(gtint_t(0), gtint_t(4)), // increment to the leading dim of b - ::testing::Values(gtint_t(0), gtint_t(2)) // increment to the leading dim of c - ), - ::CGemmTestPrint() - ); diff --git a/gtestsuite/testsuite/level3/gemm/dgemm/dgemm_evt.cpp b/gtestsuite/testsuite/level3/gemm/dgemm/dgemm_evt.cpp new file mode 100644 index 0000000000..b0baa1b79c --- /dev/null +++ b/gtestsuite/testsuite/level3/gemm/dgemm/dgemm_evt.cpp @@ -0,0 +1,438 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "level3/gemm/test_gemm.h" + +class dgemmEVT : + public ::testing::TestWithParam> {}; + +TEST_P( dgemmEVT, API ) +{ + using T = double; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // matrix storage format(row major, column major) + char storage = std::get<0>(GetParam()); + // denotes whether matrix a is n,c,t,h + char transa = std::get<1>(GetParam()); + // denotes whether matrix b is n,c,t,h + char transb = std::get<2>(GetParam()); + // matrix size m + gtint_t m = std::get<3>(GetParam()); + // matrix size n + gtint_t n = std::get<4>(GetParam()); + // matrix size k + gtint_t k = std::get<5>(GetParam()); + + gtint_t ai = std::get<6>(GetParam()); + gtint_t aj = std::get<7>(GetParam()); + T aex = std::get<8>(GetParam()); + + gtint_t bi = std::get<9>(GetParam()); + gtint_t bj = std::get<10>(GetParam()); + T bex = std::get<11>(GetParam()); + + gtint_t ci = std::get<12>(GetParam()); + gtint_t cj = std::get<13>(GetParam()); + T cex = std::get<14>(GetParam()); + + // specifies alpha value + T alpha = std::get<15>(GetParam()); + // specifies beta value + T beta = std::get<16>(GetParam()); + // lda, ldb, ldc increments. + // If increments are zero, then the array size matches the matrix size. + // If increments are nonnegative, the array size is bigger than the matrix size. + gtint_t lda_inc = std::get<17>(GetParam()); + gtint_t ldb_inc = std::get<18>(GetParam()); + gtint_t ldc_inc = std::get<19>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite gemm.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (m == 0 || n == 0) + thresh = 0.0; + else if ((alpha == testinghelpers::ZERO() || k == 0) && + (beta == testinghelpers::ZERO() || beta == testinghelpers::ONE())) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + thresh = testinghelpers::getEpsilon(); + else + thresh = (3*k+1)*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_gemm( storage, transa, transb, m, n, k, lda_inc, ldb_inc, ldc_inc, + alpha, beta, ai, aj, aex, bi, bj, bex, ci, cj, cex, thresh ); +} + +/* + It contains both the exception value testing(EVT) and the + positive accuracy testing of the bli_DGEMM_4x4_avx2_k1_nn( ... ) computational + kernel. This kernel is invoked from the BLAS layer, and inputs are given + in a manner so as to avoid the other code-paths and test only the required + kernel. + +*/ + +static double NaN = std::numeric_limits::quiet_NaN(); +static double Inf = std::numeric_limits::infinity(); + +// Exception value testing(on matrices) + +/* + For the bli_DGEMM_8x6_avx2_k1_nn & bli_DGEMM_24x8_avx512_k1_nn kernel, the main and fringe dimensions are as follows: + For m : Main = { 8, 24 }, fringe = { 7 to 1, 23 to 1 } + For n : Main = { 6, 8 }, fringe = { 4 to 1, 7 to 1 } + + Without any changes to the BLAS layer in BLIS, the fringe case of 1 cannot be touched + separately, since if m/n is 1, the inputs are redirected to ZGEMV. + +*/ + +// Testing for the main loop case for m and n +// The exception values are induced in load and broadcast +INSTANTIATE_TEST_SUITE_P( + K1_transA_N_transB_N_main, + dgemmEVT, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n'), // transa + ::testing::Values('n'), // transb + ::testing::Values(gtint_t(8),gtint_t(20)), // m + ::testing::Values(gtint_t(6),gtint_t(8)), // n + ::testing::Values(gtint_t(1)), // k + ::testing::Values(gtint_t(3)), // ai + ::testing::Values(gtint_t(0)), // aj + ::testing::Values(NaN, Inf, -Inf), // aexval + ::testing::Values(gtint_t(0)), // bi + ::testing::Values(gtint_t(2)), // bj + ::testing::Values(NaN, Inf, -Inf), // bexval + ::testing::Values(gtint_t(2)), // ci + ::testing::Values(gtint_t(3)), // cj + ::testing::Values(NaN, Inf, -Inf), // cexval + ::testing::Values(double(-2.2)), // alpha + ::testing::Values(double(1.2)), // beta + ::testing::Values(gtint_t(0)), // increment to the leading dim of a + ::testing::Values(gtint_t(0)), // increment to the leading dim of b + ::testing::Values(gtint_t(0)) // increment to the leading dim of c + ), + ::gemmEVTPrint() + ); + +// Testing the fringe cases +// Fringe case along both m and n. +INSTANTIATE_TEST_SUITE_P( + DISABLED_K1_transA_N_transB_N_fringe, + dgemmEVT, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n'), // transa + ::testing::Values('n'), // transb + ::testing::Values(gtint_t(2), gtint_t(13), gtint_t(24)), // m + ::testing::Values(gtint_t(2), gtint_t(5), gtint_t(8)), // n + ::testing::Values(gtint_t(1)), // k + ::testing::Values(gtint_t(1)), // ai + ::testing::Values(gtint_t(0)), // aj + ::testing::Values(double(NaN), double(Inf), double(-Inf)), // aexval + ::testing::Values(gtint_t(0)), // bi + ::testing::Values(gtint_t(1)), // bj + ::testing::Values(double(NaN), double(Inf), double(-Inf)), // bexval + ::testing::Values(gtint_t(1)), // ci + ::testing::Values(gtint_t(0)), // cj + ::testing::Values(double(NaN), double(Inf), double(-Inf)), // cexval + ::testing::Values(double(-2.2)), // alpha + ::testing::Values(double(1.2)), // beta + ::testing::Values(gtint_t(0)), // increment to the leading dim of a + ::testing::Values(gtint_t(0)), // increment to the leading dim of b + ::testing::Values(gtint_t(0)) // increment to the leading dim of c + ), + ::gemmEVTPrint() + ); + +// Exception value testing(on alpha and beta) +// Alpha and beta are set to exception values +INSTANTIATE_TEST_SUITE_P( + K1_transA_N_transB_N_alpha_beta, + dgemmEVT, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n'), // transa + ::testing::Values('n'), // transb + ::testing::Values(gtint_t(2), gtint_t(15), gtint_t(24)), // m + ::testing::Values(gtint_t(2), gtint_t(11), gtint_t(8)), // n + ::testing::Values(gtint_t(1)), // k + ::testing::Values(gtint_t(0)), // ai + ::testing::Values(gtint_t(0)), // aj + ::testing::Values(double(0.0)), + ::testing::Values(gtint_t(0)), // bi + ::testing::Values(gtint_t(0)), // bj + ::testing::Values(double(0.0)), + ::testing::Values(gtint_t(0)), // ci + ::testing::Values(gtint_t(0)), // cj + ::testing::Values(double(0.0)), + ::testing::Values(double(NaN), double(Inf), double(-Inf)), // alpha + ::testing::Values(double(NaN), double(Inf), double(-Inf)), // beta + ::testing::Values(gtint_t(0)), // increment to the leading dim of a + ::testing::Values(gtint_t(0)), // increment to the leading dim of b + ::testing::Values(gtint_t(0)) // increment to the leading dim of c + ), + ::gemmEVTPrint() + ); + +/********************************************************/ +/* Testing for small code paths */ +/* m,n,k is choosen such that small code path is called */ +/* Matrix A, B, C are filled with Infs and Nans */ +/********************************************************/ +INSTANTIATE_TEST_SUITE_P( + SMALL_Matrix, + dgemmEVT, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n','t'), // transa + ::testing::Values('n','t'), // transb + ::testing::Values(gtint_t(4)), // m + ::testing::Values(gtint_t(4)), // n + ::testing::Values(gtint_t(10)), // k + ::testing::Values(gtint_t(3)), // ai + ::testing::Values(gtint_t(0)), // aj + ::testing::Values(NaN, Inf, -Inf), // aexval + ::testing::Values(gtint_t(0)), // bi + ::testing::Values(gtint_t(2)), // bj + ::testing::Values(NaN, Inf, -Inf), // bexval + ::testing::Values(gtint_t(2)), // ci + ::testing::Values(gtint_t(3)), // cj + ::testing::Values(NaN, Inf, -Inf), // cexval + ::testing::Values(double(-2.2)), // alpha + ::testing::Values(double(1.2)), // beta + ::testing::Values(gtint_t(0)), // increment to the leading dim of a + ::testing::Values(gtint_t(0)), // increment to the leading dim of b + ::testing::Values(gtint_t(0)) // increment to the leading dim of c + ), + ::gemmEVTPrint() + ); + +/******************************************************/ +/* Testing for SUP code paths */ +/* m,n,k is choosen such that SUP code path is called */ +/* Matrix A, B, C are filled with Infs and Nans */ +/******************************************************/ +INSTANTIATE_TEST_SUITE_P( + Skinny_Matrix, + dgemmEVT, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n'), // transa + ::testing::Values('n'), // transb + ::testing::Values(gtint_t(90)), // m + ::testing::Values(gtint_t(80)), // n + ::testing::Values(gtint_t(1080)), // k + ::testing::Values(gtint_t(3)), // ai + ::testing::Values(gtint_t(0)), // aj + ::testing::Values(NaN, Inf, -Inf), // aexval + ::testing::Values(gtint_t(0)), // bi + ::testing::Values(gtint_t(0)), // bj + ::testing::Values(NaN, Inf, -Inf), // bexval + ::testing::Values(gtint_t(2)), // ci + ::testing::Values(gtint_t(1)), // cj + ::testing::Values(NaN, Inf, -Inf), // cexval + ::testing::Values(double(3.6)), // alpha + ::testing::Values(double(-5.)), // beta + ::testing::Values(gtint_t(0)), // increment to the leading dim of a + ::testing::Values(gtint_t(0)), // increment to the leading dim of b + ::testing::Values(gtint_t(0)) // increment to the leading dim of c + ), + ::gemmEVTPrint() + ); + +/*********************************************************/ +/* Testing for native code paths */ +/* m,n,k is choosen such that Native code path is called */ +/* Matrix A, B, C are filled with Infs and Nans */ +/*********************************************************/ +INSTANTIATE_TEST_SUITE_P( + Large_Matrix, + dgemmEVT, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n'), // transa + ::testing::Values('n'), // transb + ::testing::Values(gtint_t(1001)), // m + ::testing::Values(gtint_t(1001)), // n + ::testing::Values(gtint_t(260)), // k + ::testing::Values(gtint_t(1)), // ai + ::testing::Values(gtint_t(0)), // aj + ::testing::Values(NaN, Inf, -Inf), // aexval + ::testing::Values(gtint_t(0)), // bi + ::testing::Values(gtint_t(0)), // bj + ::testing::Values(NaN, Inf, -Inf), // bexval + ::testing::Values(gtint_t(0)), // ci + ::testing::Values(gtint_t(1)), // cj + ::testing::Values(NaN, Inf, -Inf), // cexval + ::testing::Values(double(-2.2)), // alpha + ::testing::Values(double(1.2)), // beta + ::testing::Values(gtint_t(0)), // increment to the leading dim of a + ::testing::Values(gtint_t(0)), // increment to the leading dim of b + ::testing::Values(gtint_t(0)) // increment to the leading dim of c + ), + ::gemmEVTPrint() + ); + +/********************************************************/ +/* Testing for small & sup code paths */ +/* m,n,k is choosen such that small & sup code path */ +/* are covered. */ +/* Matrix A, B, C are filled valid integers or floats */ +/* Alpha and beta are assigned with Infs and Nans */ +/********************************************************/ +INSTANTIATE_TEST_SUITE_P( + alpha_beta, + dgemmEVT, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n'), // transa + ::testing::Values('n'), // transb + ::testing::Values(gtint_t(14), gtint_t(100)), // m + ::testing::Values(gtint_t(10), gtint_t(90)), // n + ::testing::Values(gtint_t(20), gtint_t(1005)), // k + ::testing::Values(gtint_t(0)), // ai + ::testing::Values(gtint_t(0)), // aj + ::testing::Values(double(0.0)), + ::testing::Values(gtint_t(0)), // bi + ::testing::Values(gtint_t(0)), // bj + ::testing::Values(double(0.0)), + ::testing::Values(gtint_t(0)), // ci + ::testing::Values(gtint_t(0)), // cj + ::testing::Values(double(0.0)), + ::testing::Values(NaN), //Failures , Inf, -Inf), // alpha + ::testing::Values(NaN, Inf, -Inf), // beta + ::testing::Values(gtint_t(0)), // increment to the leading dim of a + ::testing::Values(gtint_t(0)), // increment to the leading dim of b + ::testing::Values(gtint_t(0)) // increment to the leading dim of c + ), + ::gemmEVTPrint() + ); + +/********************************************************/ +/* Testing for Native code paths */ +/* m,n,k is choosen such that nat code path are covered */ +/* Matrix A, B, C are filled valid integers or floats */ +/* Alpha and beta are assigned with Infs and Nans */ +/********************************************************/ +INSTANTIATE_TEST_SUITE_P( + Large_Matrix_alpha_beta, + dgemmEVT, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n'), // transa + ::testing::Values('n'), // transb + ::testing::Values(gtint_t(1001)), // m + ::testing::Values(gtint_t(1001)), // n + ::testing::Values(gtint_t(260)), // k + ::testing::Values(gtint_t(0)), // ai + ::testing::Values(gtint_t(0)), // aj + ::testing::Values(double(0.0)), + ::testing::Values(gtint_t(0)), // bi + ::testing::Values(gtint_t(0)), // bj + ::testing::Values(double(0.0)), + ::testing::Values(gtint_t(0)), // ci + ::testing::Values(gtint_t(0)), // cj + ::testing::Values(double(0.0)), + ::testing::Values(NaN), //Failures , Inf, -Inf), // alpha + ::testing::Values(NaN, Inf, -Inf), // beta + ::testing::Values(gtint_t(0)), // increment to the leading dim of a + ::testing::Values(gtint_t(0)), // increment to the leading dim of b + ::testing::Values(gtint_t(0)) // increment to the leading dim of c + ), + ::gemmEVTPrint() + ); diff --git a/gtestsuite/testsuite/level3/gemm/dgemm/dgemm_generic.cpp b/gtestsuite/testsuite/level3/gemm/dgemm/dgemm_generic.cpp new file mode 100644 index 0000000000..b1eff156ac --- /dev/null +++ b/gtestsuite/testsuite/level3/gemm/dgemm/dgemm_generic.cpp @@ -0,0 +1,225 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "level3/gemm/test_gemm.h" + +class dgemmGeneric : + public ::testing::TestWithParam> {}; + + +//matrix storage format, transA, transB, m, n, k, alpha, beta, lda, ldb, ldc +TEST_P( dgemmGeneric, API ) +{ + using T = double; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // matrix storage format(row major, column major) + char storage = std::get<0>(GetParam()); + // denotes whether matrix a is n,c,t,h + char transa = std::get<1>(GetParam()); + // denotes whether matrix b is n,c,t,h + char transb = std::get<2>(GetParam()); + // matrix size m + gtint_t m = std::get<3>(GetParam()); + // matrix size n + gtint_t n = std::get<4>(GetParam()); + // matrix size k + gtint_t k = std::get<5>(GetParam()); + // specifies alpha value + T alpha = std::get<6>(GetParam()); + // specifies beta value + T beta = std::get<7>(GetParam()); + // lda, ldb, ldc increments. + // If increments are zero, then the array size matches the matrix size. + // If increments are nonnegative, the array size is bigger than the matrix size. + gtint_t lda_inc = std::get<8>(GetParam()); + gtint_t ldb_inc = std::get<9>(GetParam()); + gtint_t ldc_inc = std::get<10>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite gemm.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (m == 0 || n == 0) + thresh = 0.0; + else if ((alpha == testinghelpers::ZERO() || k == 0) && + (beta == testinghelpers::ZERO() || beta == testinghelpers::ONE())) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + thresh = testinghelpers::getEpsilon(); + else + thresh = (3*k+1)*testinghelpers::getEpsilon(); + //thresh = (15*k+1)*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_gemm( storage, transa, transb, m, n, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh ); +} + +INSTANTIATE_TEST_SUITE_P( + expect_dgemm_k1_path, + dgemmGeneric, + ::testing::Combine( + // No condition based on storage scheme of matrices + ::testing::Values('c'), // storage format + // No conditions based on trans of matrices + ::testing::Values('n'), // transa + ::testing::Values('n'), // transb + ::testing::Values(3, 17, 103, 178), // m + ::testing::Values(2, 26, 79), // n + ::testing::Values(1), // k + // No condition based on alpha + ::testing::Values(0.0, -1.0, 1.7), // alpha + // No condition based on beta + ::testing::Values(0.0, -1.0, 1.0, 2.3), // beta + ::testing::Values(0, 3), // increment to the leading dim of a + ::testing::Values(0, 3), // increment to the leading dim of b + ::testing::Values(0, 3) // increment to the leading dim of c + ), + ::gemmGenericPrint() + ); + +//----------------------------- bli_dgemm_tiny kernel ------------------------------------ +INSTANTIATE_TEST_SUITE_P( + expect_dgemm_tiny_path, + dgemmGeneric, + ::testing::Combine( + // No condition based on storage scheme of matrices + ::testing::Values('c'), // storage format + // No conditions based on trans of matrices + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(3, 81, 138), // m + ::testing::Values(2, 35, 100), // n + ::testing::Values(5, 12, 24), // k + // No condition based on alpha + ::testing::Values(0.0, -1.0, 1.7), // alpha + // No condition based on beta + ::testing::Values(0.0, -1.0, 1.0, 2.3), // beta + ::testing::Values(0, 3), // increment to the leading dim of a + ::testing::Values(0, 3), // increment to the leading dim of b + ::testing::Values(0, 3) // increment to the leading dim of c + ), + ::gemmGenericPrint() + ); + +//----------------------------- dgemm_small kernel ------------------------------------ + + +// Tests both bli_dgemm_small and bli_dgemm_small_At +INSTANTIATE_TEST_SUITE_P( + expect_dgemm_small_path, + dgemmGeneric, + ::testing::Combine( + // Test both storage types + ::testing::Values('c'), // storage format + // Covers all possible combinations of storage schemes + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(5, 19, 32, 44), // m + ::testing::Values(25, 27, 32), // n + // k-unroll factor = KR = 1 + ::testing::Values(5, 17, 24), // k + // No condition based on alpha + ::testing::Values(0.0, -1.0, 1.7), // alpha + // No condition based on beta + ::testing::Values(0.0, -1.0, 1.0, 2.3), // beta + ::testing::Values(0, 3), // increment to the leading dim of a + ::testing::Values(0, 3), // increment to the leading dim of b + ::testing::Values(0, 3) // increment to the leading dim of c + ), + ::gemmGenericPrint() + ); + +// ----------------------------- SUP implementation -------------------------------------- +INSTANTIATE_TEST_SUITE_P( + expect_dgemm_sup_path, + dgemmGeneric, + ::testing::Combine( + // Storage of A and B is handled by packing + ::testing::Values('c'), // storage format + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1002, 1377), // m + ::testing::Values(453, 567), // n + ::testing::Values(105, 124), // k + // No condition based on alpha + ::testing::Values(0.0, -1.0, 1.7), // alpha + // No condition based on beta + ::testing::Values(0.0, -1.0, 1.0, 2.3), // beta + ::testing::Values(0, 3), // increment to the leading dim of a + ::testing::Values(0, 3), // increment to the leading dim of b + ::testing::Values(0, 3) // increment to the leading dim of c + ), + ::gemmGenericPrint() + ); + +// ----------------------------- Native implementation -------------------------------------- +INSTANTIATE_TEST_SUITE_P( + expect_dgemm_native_path, + dgemmGeneric, + ::testing::Combine( + // Storage of A and B is handled by packing + ::testing::Values('c'), // storage format + // Covers vectorized section of 8xk and 6xk pack kernels for both storage formats + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(5017, 5061), // m + ::testing::Values(709, 5417), // n + ::testing::Values(515, 604), // k + // No condition based on alpha + ::testing::Values(0.0, -1.0, 1.7), // alpha + // No condition based on beta + ::testing::Values(0.0, -1.0, 1.0, 2.3), // beta + ::testing::Values(0, 3), // increment to the leading dim of a + ::testing::Values(0, 3), // increment to the leading dim of b + ::testing::Values(0, 3) // increment to the leading dim of c + ), + ::gemmGenericPrint() + ); diff --git a/gtestsuite/testsuite/level3/gemm/dgemm/dgemm_underflow_overflow.cpp b/gtestsuite/testsuite/level3/gemm/dgemm/dgemm_underflow_overflow.cpp new file mode 100644 index 0000000000..cc4f391b27 --- /dev/null +++ b/gtestsuite/testsuite/level3/gemm/dgemm/dgemm_underflow_overflow.cpp @@ -0,0 +1,432 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "level3/gemm/test_gemm.h" + +class dgemmUOT : + public ::testing::TestWithParam> {}; + +TEST_P( dgemmUOT, API ) +{ + using T = double; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // matrix storage format(row major, column major) + char storage = std::get<0>(GetParam()); + // denotes whether matrix a is n,t + char transa = std::get<1>(GetParam()); + // denotes whether matrix b is n,t + char transb = std::get<2>(GetParam()); + // over_under denotes whether overflow or underflow is to be tested + gtint_t over_under = std::get<3>(GetParam()); + // input_range denotes the range of values that would be used to populate the matrices + gtint_t input_range = std::get<4>(GetParam()); + // matrix size m + gtint_t m = std::get<5>(GetParam()); + // matrix size n + gtint_t n = std::get<6>(GetParam()); + // matrix size k + gtint_t k = std::get<7>(GetParam()); + // specifies alpha value + T alpha = std::get<8>(GetParam()); + // specifies beta value + T beta = std::get<9>(GetParam()); + // lda, ldb, ldc increments. + // If increments are zero, then the array size matches the matrix size. + // If increments are nonnegative, the array size is bigger than the matrix size. + gtint_t lda_inc = std::get<10>(GetParam()); + gtint_t ldb_inc = std::get<11>(GetParam()); + gtint_t ldc_inc = std::get<12>(GetParam()); + + // ai, aj, bi, bj are the indices where overflow/underflow values need to be inserted + gtint_t ai = std::get<13>(GetParam()); + gtint_t aj = std::get<14>(GetParam()); + gtint_t bi = std::get<15>(GetParam()); + gtint_t bj = std::get<16>(GetParam()); + + // Check gtestsuite gemm.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (m == 0 || n == 0) + thresh = 0.0; + else if ((alpha == testinghelpers::ZERO() || k == 0) && + (beta == testinghelpers::ZERO() || beta == testinghelpers::ONE())) + thresh = 0.0; + else + thresh = (3*k+1)*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_gemm( storage, transa, transb, over_under, input_range, m, n, k, lda_inc, ldb_inc, ldc_inc, ai, aj, bi, bj, alpha, beta, thresh ); + +} + +/* + Tests for Overflow + + An Overflow condition occurs when the result of an operation or computation is larger than the + maximum representable floating point value. For double precision floating points, the largest + representable number is + DBL_MAX = 1.7976931348623158e+308 + + This test populates matrices with values close to DBL_MAX so that the subsequent operations lead + to values larger than DBL_MAX and hence causes a floating point overflow. + + The argument over_under is used to indicate whether the test is an overflow or an underflow test. + over_under = 0 indicates an overflow test + + The argument input_range is used to choose the range of values used to populate input matrices + input_range = -1 for values < DBL_MAX + input_range = 0 for values close to DBL_MAX + input_range = 1 for values > DBL_MAX +*/ + +/* Overflow test for values much less than DBL_MAX */ +INSTANTIATE_TEST_SUITE_P( + overflow_within_limit, + dgemmUOT, + ::testing::Combine( + // No condition based on storage scheme of matrices + ::testing::Values('c'), // storage format + // No conditions based on trans of matrices + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + + ::testing::Values(0), // over_under = 0 for overflow + ::testing::Values(-1), // input_range = -1 to test values less than DBL_MAX + ::testing::Values(120, 256, 512), // m + + ::testing::Values(144, 237, 680), // n + + ::testing::Values(128, 557, 680), // k + // No condition based on alpha + ::testing::Values( -1.0), // alpha + // No condition based on beta + ::testing::Values(-1.0), // beta + ::testing::Values(3), // increment to the leading dim of a + ::testing::Values(3), // increment to the leading dim of b + ::testing::Values(3), // increment to the leading dim of c + + ::testing::Values(100), // ai + ::testing::Values(120), // aj + ::testing::Values(140), // bi + ::testing::Values(110) // bj + ), + ::gemmOUTPrint() + ); + +/* Overflow test for values close to DBL_MAX */ +INSTANTIATE_TEST_SUITE_P( + overflow_close_to_limit, + dgemmUOT, + ::testing::Combine( + // No condition based on storage scheme of matrices + ::testing::Values('c'), // storage format + // No conditions based on trans of matrices + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + + ::testing::Values(0), // over_under = 0 for overflow + ::testing::Values(0), // input_range = 0 to test values close to DBL_MAX + ::testing::Values(120, 256, 512), // m + + ::testing::Values(144, 237, 680), // n + + ::testing::Values(128, 557, 680), // k + // No condition based on alpha + ::testing::Values( -1.0), // alpha + // No condition based on beta + ::testing::Values(-1.0), // beta + ::testing::Values(0), // increment to the leading dim of a + ::testing::Values(0), // increment to the leading dim of b + ::testing::Values(0), // increment to the leading dim of c + + ::testing::Values(110), // ai + ::testing::Values(130), // aj + ::testing::Values(140), // bi + ::testing::Values(120) // bj + ), + ::gemmOUTPrint() + ); + + +/* Overflow test for values close to DBL_MAX and aplha = 0*/ +INSTANTIATE_TEST_SUITE_P( + overflow_close_to_limit_alpha0, + dgemmUOT, + ::testing::Combine( + // No condition based on storage scheme of matrices + ::testing::Values('c'), // storage format + // No conditions based on trans of matrices + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + + ::testing::Values(0), // over_under = 0 for overflow + ::testing::Values(0), // input_range = 0 to test values close to DBL_MAX + ::testing::Values(120, 256, 512), // m + + ::testing::Values(144, 237, 680), // n + + ::testing::Values(128, 557, 680), // k + // No condition based on alpha + ::testing::Values(0), // alpha + // No condition based on beta + ::testing::Values(-1.0), // beta + ::testing::Values(5), // increment to the leading dim of a + ::testing::Values(5), // increment to the leading dim of b + ::testing::Values(5), // increment to the leading dim of c + + ::testing::Values(108), // ai + ::testing::Values(122), // aj + ::testing::Values(145), // bi + ::testing::Values(108) // bj + ), + ::gemmOUTPrint() + ); + +/* Overflow test for values larger than DBL_MAX */ +INSTANTIATE_TEST_SUITE_P( + overflow_beyond_limit, + dgemmUOT, + ::testing::Combine( + // No condition based on storage scheme of matrices + ::testing::Values('c'), // storage format + // No conditions based on trans of matrices + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + + ::testing::Values(0), // over_under = 0 for overflow + ::testing::Values(1), // input_range = 1 to test values larger than DBL_MAX + ::testing::Values(120, 256, 512), // m + + ::testing::Values(144, 237, 680), // n + + ::testing::Values(128, 557, 680), // k + // No condition based on alpha + ::testing::Values( -1.0), // alpha + // No condition based on beta + ::testing::Values(-1.0), // beta + ::testing::Values(0), // increment to the leading dim of a + ::testing::Values(0), // increment to the leading dim of b + ::testing::Values(0), // increment to the leading dim of c + + ::testing::Values(110), // ai + ::testing::Values(140), // aj + ::testing::Values(130), // bi + ::testing::Values(100) // bj + ), + ::gemmOUTPrint() + ); + + +/* + Tests for Underflow + + An underflow occurs when the result of an operation or a computation is smaller than the + smallest representable floating point number. For double-precision floating points, + the smallest representable number is + DBL_MIN = 2.2250738585072014e-308 + + This test populates matrices with values close to DBL_MIN so that the subsequent operations + lead to values smaller than DBL_MIN and hence results in a floating point underflow. + + The argument over_under is used to indicate whether a test is an overflow or an underflow test. + over_under=1 indicates an underflow test + + The argument input_range is used to choose the range of values used to populate input matrices + input_range = -1 for values > DBL_MIN + input_range = 0 for values close to DBL_MIN + input_range = 1 for values < DBL_MIN + +*/ + +/* Underflow test for values larger than DBL_MIN */ +INSTANTIATE_TEST_SUITE_P( + underflow_within_limit, + dgemmUOT, + ::testing::Combine( + // No condition based on storage scheme of matrices + ::testing::Values('c'), // storage format + // No conditions based on trans of matrices + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + + ::testing::Values(1), // over_under = 1 for underflow + ::testing::Values(-1), // input_range = -1 to test values larger than DBL_MIN + ::testing::Values(120, 256, 512), // m + + ::testing::Values(144, 237, 680), // n + + ::testing::Values(128, 557, 680), // k + // No condition based on alpha + ::testing::Values( -1.0), // alpha + // No condition based on beta + ::testing::Values(-1.0), // beta + ::testing::Values(3), // increment to the leading dim of a + ::testing::Values(3), // increment to the leading dim of b + ::testing::Values(3), // increment to the leading dim of c + + ::testing::Values(100), // ai + ::testing::Values(120), // aj + ::testing::Values(140), // bi + ::testing::Values(110) // bj + ), + ::gemmOUTPrint() + ); + +/* Underflow test for values close to DBL_MIN */ +INSTANTIATE_TEST_SUITE_P( + underflow_close_to_limit, + dgemmUOT, + ::testing::Combine( + // No condition based on storage scheme of matrices + ::testing::Values('c'), // storage format + // No conditions based on trans of matrices + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + + ::testing::Values(1), // over_under = 1 for underflow + ::testing::Values(0), // input_range = 0 to test values close to DBL_MIN + ::testing::Values(120, 256, 512), // m + + ::testing::Values(144, 237, 680), // n + + ::testing::Values(128, 557, 680), // k + // No condition based on alpha + ::testing::Values( -1.0), // alpha + // No condition based on beta + ::testing::Values(-1.0), // beta + ::testing::Values(5), // increment to the leading dim of a + ::testing::Values(5), // increment to the leading dim of b + ::testing::Values(5), // increment to the leading dim of c + + ::testing::Values(101), // ai + ::testing::Values(118), // aj + ::testing::Values(132), // bi + ::testing::Values(110) // bj + ), + ::gemmOUTPrint() + ); + +/* Underflow test for values close to DBL_MIN and alpha = 0 */ +INSTANTIATE_TEST_SUITE_P( + underflow_close_to_limit_alpha0, + dgemmUOT, + ::testing::Combine( + // No condition based on storage scheme of matrices + ::testing::Values('c'), // storage format + // No conditions based on trans of matrices + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + + ::testing::Values(1), // over_under = 1 for underflow + ::testing::Values(0), // input_range = 0 to test values close to DBL_MIN + ::testing::Values(120, 256, 512), // m + + ::testing::Values(144, 237, 680), // n + + ::testing::Values(128, 557, 680), // k + // No condition based on alpha + ::testing::Values(0), // alpha + // No condition based on beta + ::testing::Values(-1.0), // beta + ::testing::Values(0), // increment to the leading dim of a + ::testing::Values(0), // increment to the leading dim of b + ::testing::Values(0), // increment to the leading dim of c + + ::testing::Values(117), // ai + ::testing::Values(122), // aj + ::testing::Values(88), // bi + ::testing::Values(42) // bj + ), + ::gemmOUTPrint() + ); + + + +/* Underflow test for values smaller than DBL_MIN */ +INSTANTIATE_TEST_SUITE_P( + underflow_beyond_limit, + dgemmUOT, + ::testing::Combine( + // No condition based on storage scheme of matrices + ::testing::Values('c'), // storage format + // No conditions based on trans of matrices + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + + ::testing::Values(1), // over_under = 1 for underflow + ::testing::Values(1), // input_range = 1 to test values smaller than DBL_MIN + ::testing::Values(120, 256, 512), // m + + ::testing::Values(144, 237, 680), // n + + ::testing::Values(128, 557, 680), // k + // No condition based on alpha + ::testing::Values(-1.0), // alpha + // No condition based on beta + ::testing::Values(-1.0), // beta + ::testing::Values(3), // increment to the leading dim of a + ::testing::Values(3), // increment to the leading dim of b + ::testing::Values(3), // increment to the leading dim of c + + ::testing::Values(44), // ai + ::testing::Values(135), // aj + ::testing::Values(100), // bi + ::testing::Values(105) // bj + ), + ::gemmOUTPrint() + ); diff --git a/gtestsuite/testsuite/level3/gemm/dgemm_generic.cpp b/gtestsuite/testsuite/level3/gemm/dgemm_generic.cpp deleted file mode 100644 index 8d07668cc4..0000000000 --- a/gtestsuite/testsuite/level3/gemm/dgemm_generic.cpp +++ /dev/null @@ -1,310 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include "test_gemm.h" - -class DGemmTest : - public ::testing::TestWithParam> {}; - -TEST_P(DGemmTest, RandomData) -{ - using T = double; - //---------------------------------------------------------- - // Initialize values from the parameters passed through - // test suite instantiation (INSTANTIATE_TEST_SUITE_P). - //---------------------------------------------------------- - // matrix storage format(row major, column major) - char storage = std::get<0>(GetParam()); - // denotes whether matrix a is n,c,t,h - char transa = std::get<1>(GetParam()); - // denotes whether matrix b is n,c,t,h - char transb = std::get<2>(GetParam()); - // matrix size m - gtint_t m = std::get<3>(GetParam()); - // matrix size n - gtint_t n = std::get<4>(GetParam()); - // matrix size k - gtint_t k = std::get<5>(GetParam()); - // specifies alpha value - T alpha = std::get<6>(GetParam()); - // specifies beta value - T beta = std::get<7>(GetParam()); - // lda, ldb, ldc increments. - // If increments are zero, then the array size matches the matrix size. - // If increments are nonnegative, the array size is bigger than the matrix size. - gtint_t lda_inc = std::get<8>(GetParam()); - gtint_t ldb_inc = std::get<9>(GetParam()); - gtint_t ldc_inc = std::get<10>(GetParam()); - - // Set the threshold for the errors: - double thresh = 10*m*n*testinghelpers::getEpsilon(); - - //---------------------------------------------------------- - // Call test body using these parameters - //---------------------------------------------------------- - test_gemm( storage, transa, transb, m, n, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh ); -} - -class DGemmTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char tsa = std::get<1>(str.param); - char tsb = std::get<2>(str.param); - gtint_t m = std::get<3>(str.param); - gtint_t n = std::get<4>(str.param); - gtint_t k = std::get<5>(str.param); - double alpha = std::get<6>(str.param); - double beta = std::get<7>(str.param); - gtint_t lda_inc = std::get<8>(str.param); - gtint_t ldb_inc = std::get<9>(str.param); - gtint_t ldc_inc = std::get<10>(str.param); -#ifdef TEST_BLAS - std::string str_name = "dgemm_"; -#elif TEST_CBLAS - std::string str_name = "cblas_dgemm"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_dgemm"; -#endif - str_name = str_name + "_" + sfm+sfm+sfm; - str_name = str_name + "_" + tsa + tsb; - str_name = str_name + "_" + std::to_string(m); - str_name = str_name + "_" + std::to_string(n); - str_name = str_name + "_" + std::to_string(k); - std::string alpha_str = ( alpha > 0) ? std::to_string(int(alpha)) : "m" + std::to_string(int(std::abs(alpha))); - str_name = str_name + "_a" + alpha_str; - std::string beta_str = ( beta > 0) ? std::to_string(int(beta)) : "m" + std::to_string(int(std::abs(beta))); - str_name = str_name + "_b" + beta_str; - str_name = str_name + "_" + std::to_string(lda_inc); - str_name = str_name + "_" + std::to_string(ldb_inc); - str_name = str_name + "_" + std::to_string(ldc_inc); - return str_name; - } -}; - -// Black box testing. -INSTANTIATE_TEST_SUITE_P( - Blackbox, - DGemmTest, - ::testing::Combine( - ::testing::Values('c' -#ifndef TEST_BLAS - ,'r' -#endif - ), // storage format - ::testing::Values('n','t'), // transa - ::testing::Values('n','t'), // transb - ::testing::Range(gtint_t(10), gtint_t(31), 10), // m - ::testing::Range(gtint_t(10), gtint_t(31), 10), // n - ::testing::Range(gtint_t(10), gtint_t(31), 10), // k - ::testing::Values( 1.0, -2.0), // alpha - ::testing::Values(-1.0, 1.0), // beta - ::testing::Values(gtint_t(0), gtint_t(4)), // increment to the leading dim of a - ::testing::Values(gtint_t(0), gtint_t(7)), // increment to the leading dim of b - ::testing::Values(gtint_t(0), gtint_t(2)) // increment to the leading dim of c - ), - ::DGemmTestPrint() - ); - - -// Tests 5 loops -INSTANTIATE_TEST_SUITE_P( - tiny_dgemm_kernel, - DGemmTest, - ::testing::Combine( - // No condition based on storage scheme of matrices - ::testing::Values('c'), // storage format - // No conditions based on trans of matrices - ::testing::Values('n', 't'), // transa - ::testing::Values('n', 't'), // transb - - ::testing::Values(13, 25, 48, 60, 256, 512, 1000), // m - - ::testing::Values(8, 48, 72, 144, 237), // n - - ::testing::Values(16, 24, 48, 64, 128, 557), // k - // No condition based on alpha - ::testing::Values( -1.0), // alpha - // No condition based on betaa - ::testing::Values(-1.0), // beta - ::testing::Values(0,3), // increment to the leading dim of a - ::testing::Values(0,3), // increment to the leading dim of b - ::testing::Values(0,3) // increment to the leading dim of c - ), - ::DGemmTestPrint() - ); - -//zero beta test case -INSTANTIATE_TEST_SUITE_P( - zero_beta, - DGemmTest, - ::testing::Combine( - // No condition based on storage scheme of matrices - ::testing::Values('c'), // storage format - // No conditions based on trans of matrices - ::testing::Values('n', 't'), // transa - ::testing::Values('n', 't'), // transb - - ::testing::Values(13, 25, 48, 60, 256, 512, 1000), // m - - ::testing::Values(8, 48, 72, 144, 237), // n - - ::testing::Values(16, 24, 48, 64, 128, 557), // k - - ::testing::Values( -1.0), // alpha - ::testing::Values(0.0), // beta - ::testing::Values(0,3), // increment to the leading dim of a - ::testing::Values(0,3), // increment to the leading dim of b - ::testing::Values(0,3) // increment to the leading dim of c - ), - ::DGemmTestPrint() - ); - -//zero alpha test case -INSTANTIATE_TEST_SUITE_P( - zero_alpha, - DGemmTest, - ::testing::Combine( - // No condition based on storage scheme of matrices - ::testing::Values('c'), // storage format - // No conditions based on trans of matrices - ::testing::Values('n', 't'), // transa - ::testing::Values('n', 't'), // transb - - ::testing::Values(13, 25, 48, 60, 256, 512, 1000), // m - - ::testing::Values(8, 48, 72, 144, 237), // n - - ::testing::Values(16, 24, 48, 64, 128, 557), // k - - ::testing::Values( 0.0), // alpha - ::testing::Values(-1.0), // beta - ::testing::Values(0,3), // increment to the leading dim of a - ::testing::Values(0,3), // increment to the leading dim of b - ::testing::Values(0,3) // increment to the leading dim of c - ), - ::DGemmTestPrint() - ); - -//unit beta test case -INSTANTIATE_TEST_SUITE_P( - unit_beta, - DGemmTest, - ::testing::Combine( - // No condition based on storage scheme of matrices - ::testing::Values('c'), // storage format - // No conditions based on trans of matrices - ::testing::Values('n', 't'), // transa - ::testing::Values('n', 't'), // transb - - ::testing::Values(13, 25, 48, 60, 256, 512, 1000), // m - - ::testing::Values(8, 48, 72, 144, 237), // n - - ::testing::Values(16, 24, 48, 64, 128, 557), // k - - ::testing::Values( -1.0), // alpha - ::testing::Values(1.0), // beta - ::testing::Values(0,3), // increment to the leading dim of a - ::testing::Values(0,3), // increment to the leading dim of b - ::testing::Values(0,3) // increment to the leading dim of c - ), - ::DGemmTestPrint() - ); - -// Covers all corner cases of tiny dgemm kernel -INSTANTIATE_TEST_SUITE_P( - tiny_edge_kernels, - DGemmTest, - ::testing::Combine( - // To test col storage of C - // Storage of A and B is handled by packing - ::testing::Values('c'), // storage format - // Tests scalar code of 8xk and 6xk pack kernels for both storage formats - ::testing::Values('n','t'), // transa - ::testing::Values('n','t'), // transb - - ::testing::Range(gtint_t(1), gtint_t(23), 1), // m - ::testing::Range(gtint_t(1), gtint_t(7), 1), // n - - ::testing::Values(24), // k - // No condition based on alpha - ::testing::Values( -1.0, 1.0), // alpha - // checks for beta-zero and beta non-zero cases - ::testing::Values(0.0, 1.0, -1.0), // beta - ::testing::Values(23), // increment to the leading dim of a - ::testing::Values(23), // increment to the leading dim of b - ::testing::Values(23) // increment to the leading dim of c - ), - ::DGemmTestPrint() - ); - - -//m = 0, n = 0 k = 0 testcase -INSTANTIATE_TEST_SUITE_P( - mnkzero, - DGemmTest, - ::testing::Combine( - // No condition based on storage scheme of matrices - ::testing::Values('c'), // storage format - // No conditions based on trans of matrices - ::testing::Values('n', 't'), // transa - ::testing::Values('n', 't'), // transb - - ::testing::Values(0, 8, 24), // m - - ::testing::Values(0, 6, 8), // n - - ::testing::Values(3), // k - - ::testing::Values( -1.0), // alpha - ::testing::Values(1.0), // beta - ::testing::Values(0,3), // increment to the leading dim of a - ::testing::Values(0,3), // increment to the leading dim of b - ::testing::Values(0,3) // increment to the leading dim of c - ), - ::DGemmTestPrint() - ); diff --git a/gtestsuite/testsuite/level3/gemm/gemm.h b/gtestsuite/testsuite/level3/gemm/gemm.h index 907f078848..23b59a2bb6 100644 --- a/gtestsuite/testsuite/level3/gemm/gemm.h +++ b/gtestsuite/testsuite/level3/gemm/gemm.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -36,6 +36,7 @@ #include "blis.h" #include "common/testing_helpers.h" +#include "inc/check_error.h" /** * @brief Performs the operation: @@ -81,6 +82,22 @@ static void gemm_(char transa, char transb, gtint_t m, gtint_t n, gtint_t k, T* throw std::runtime_error("Error in testsuite/level3/gemm.h: Invalid typename in gemm_()."); } +template +static void gemm_blis_impl(char transa, char transb, gtint_t m, gtint_t n, gtint_t k, T* alpha, + T* ap, gtint_t lda, T* bp, gtint_t ldb, T* beta, T* cp, gtint_t ldc ) +{ + if constexpr (std::is_same::value) + sgemm_blis_impl( &transa, &transb, &m, &n, &k, alpha, ap, &lda, bp, &ldb, beta, cp, &ldc ); + else if constexpr (std::is_same::value) + dgemm_blis_impl( &transa, &transb, &m, &n, &k, alpha, ap, &lda, bp, &ldb, beta, cp, &ldc ); + else if constexpr (std::is_same::value) + cgemm_blis_impl( &transa, &transb, &m, &n, &k, alpha, ap, &lda, bp, &ldb, beta, cp, &ldc ); + else if constexpr (std::is_same::value) + zgemm_blis_impl( &transa, &transb, &m, &n, &k, alpha, ap, &lda, bp, &ldb, beta, cp, &ldc ); + else + throw std::runtime_error("Error in testsuite/level3/gemm.h: Invalid typename in gemm_blis_impl()."); +} + template static void cblas_gemm(char storage, char transa, char transb, gtint_t m, gtint_t n, gtint_t k, T* alpha, T* ap, gtint_t lda, @@ -151,12 +168,54 @@ template static void gemm( char storage, char transa, char transb, gtint_t m, gtint_t n, gtint_t k, T* alpha, T* ap, gtint_t lda, T* bp, gtint_t ldb, T* beta, T* cp, gtint_t ldc ) { + +#ifdef TEST_UPPERCASE_ARGS + storage = static_cast(std::toupper(static_cast(storage))); + transa = static_cast(std::toupper(static_cast(transa))); + transb = static_cast(std::toupper(static_cast(transb))); +#endif + +#ifdef TEST_INPUT_ARGS + // Create copy of scalar input values so we can check that they are not altered. + char storage_cpy = storage; + char transa_cpy = transa; + char transb_cpy = transb; + gtint_t m_cpy = m; + gtint_t n_cpy = n; + gtint_t k_cpy = k; + T* alpha_cpy = alpha; + gtint_t lda_cpy = lda; + gtint_t ldb_cpy = ldb; + T* beta_cpy = beta; + gtint_t ldc_cpy = ldc; + + // Create copy of input arrays so we can check that they are not altered. + T* ap_cpy = nullptr; + gtint_t size_ap = testinghelpers::matsize( storage, transa, m, k, lda ); + if (ap && size_ap > 0) + { + ap_cpy = new T[size_ap]; + memcpy( ap_cpy, ap, size_ap * sizeof( T ) ); + } + T* bp_cpy = nullptr; + gtint_t size_bp = testinghelpers::matsize( storage, transb, k, n, ldb ); + if (bp && size_bp > 0) + { + bp_cpy = new T[size_bp]; + memcpy( bp_cpy, bp, size_bp * sizeof( T ) ); + } +#endif + #ifdef TEST_BLAS if( storage == 'c' || storage == 'C' ) gemm_( transa, transb, m, n, k, alpha, ap, lda, bp, ldb, beta, cp, ldc ); else throw std::runtime_error("Error in testsuite/level3/gemm.h: BLAS interface cannot be tested for row-major order."); - +#elif TEST_BLAS_BLIS_IMPL + if( storage == 'c' || storage == 'C' ) + gemm_blis_impl( transa, transb, m, n, k, alpha, ap, lda, bp, ldb, beta, cp, ldc ); + else + throw std::runtime_error("Error in testsuite/level3/gemm.h: BLAS_BLIS_IMPL interface cannot be tested for row-major order."); #elif TEST_CBLAS cblas_gemm( storage, transa, transb, m, n, k, alpha, ap, lda, bp, ldb, beta, cp, ldc ); #elif TEST_BLIS_TYPED @@ -164,4 +223,44 @@ static void gemm( char storage, char transa, char transb, gtint_t m, gtint_t n, #else throw std::runtime_error("Error in testsuite/level3/gemm.h: No interfaces are set to be tested."); #endif + +#ifdef TEST_INPUT_ARGS + //---------------------------------------------------------- + // Check scalar inputs have not been modified. + //---------------------------------------------------------- + + computediff( "storage", storage, storage_cpy ); + computediff( "transa", transa, transa_cpy ); + computediff( "transb", transb, transb_cpy ); + computediff( "m", m, m_cpy ); + computediff( "n", n, n_cpy ); + computediff( "k", k, k_cpy ); + if (alpha) computediff( "alpha", *alpha, *alpha_cpy ); + computediff( "lda", lda, lda_cpy ); + computediff( "ldb", ldb, ldb_cpy ); + if (beta) computediff( "beta", *beta, *beta_cpy ); + computediff( "ldc", ldc, ldc_cpy ); + + //---------------------------------------------------------- + // Bitwise-wise check array inputs have not been modified. + //---------------------------------------------------------- + + if (ap && size_ap > 0) + { + if(( transa == 'n' ) || ( transa == 'N' )) + computediff( "A", storage, m, k, ap, ap_cpy, lda, true ); + else + computediff( "A", storage, k, m, ap, ap_cpy, lda, true ); + delete[] ap_cpy; + } + + if (bp && size_bp > 0) + { + if(( transb == 'n' ) || ( transb == 'N' )) + computediff( "B", storage, k, n, bp, bp_cpy, ldb, true ); + else + computediff( "B", storage, n, k, bp, bp_cpy, ldb, true ); + delete[] bp_cpy; + } +#endif } diff --git a/gtestsuite/testsuite/level3/gemm/sgemm/sgemm_evt.cpp b/gtestsuite/testsuite/level3/gemm/sgemm/sgemm_evt.cpp new file mode 100644 index 0000000000..caa1932d10 --- /dev/null +++ b/gtestsuite/testsuite/level3/gemm/sgemm/sgemm_evt.cpp @@ -0,0 +1,312 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "level3/gemm/test_gemm.h" + +class sgemmEVT : + public ::testing::TestWithParam> {}; +TEST_P( sgemmEVT, API ) +{ + using T = float; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // matrix storage format(row major, column major) + char storage = std::get<0>(GetParam()); + // denotes whether matrix a is n,c,t,h + char transa = std::get<1>(GetParam()); + // denotes whether matrix b is n,c,t,h + char transb = std::get<2>(GetParam()); + // matrix size m + gtint_t m = std::get<3>(GetParam()); + // matrix size n + gtint_t n = std::get<4>(GetParam()); + // matrix size k + gtint_t k = std::get<5>(GetParam()); + gtint_t ai, aj, bi, bj, ci, cj; + T aex, bex, cex; + ai = std::get<6>(GetParam()); + aj = std::get<7>(GetParam()); + aex = std::get<8>(GetParam()); + bi = std::get<9>(GetParam()); + bj = std::get<10>(GetParam()); + bex = std::get<11>(GetParam()); + ci = std::get<12>(GetParam()); + cj = std::get<13>(GetParam()); + cex = std::get<14>(GetParam()); + // specifies alpha value + T alpha = std::get<15>(GetParam()); + // specifies beta value + T beta = std::get<16>(GetParam()); + // lda, ldb, ldc increments. + // If increments are zero, then the array size matches the matrix size. + // If increments are nonnegative, the array size is bigger than the matrix size. + gtint_t lda_inc = std::get<17>(GetParam()); + gtint_t ldb_inc = std::get<18>(GetParam()); + gtint_t ldc_inc = std::get<19>(GetParam()); + + // Check gtestsuite gemm.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (m == 0 || n == 0) + thresh = 0.0; + else if ((alpha == testinghelpers::ZERO() || k == 0) && + (beta == testinghelpers::ZERO() || beta == testinghelpers::ONE())) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + thresh = testinghelpers::getEpsilon(); + else + thresh = (3*k+1)*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_gemm( storage, transa, transb, m, n, k, lda_inc, ldb_inc, ldc_inc, + alpha, beta, ai, aj, aex, bi, bj, bex, ci, cj, cex, thresh ); +} +/* + It contains the exception value testing(EVT). +*/ +static float NaN = std::numeric_limits::quiet_NaN(); +static float Inf = std::numeric_limits::infinity(); +// Exception value testing(on matrices) + + +/********************************************************/ +/* Testing for small code paths */ +/* m,n,k is choosen such that small code path is called */ +/* Matrix A, B, C are filled with Infs and Nans */ +/********************************************************/ +INSTANTIATE_TEST_SUITE_P( + DISABLED_SMALL_Matrix, + sgemmEVT, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n','t'), // transa + ::testing::Values('n','t'), // transb + ::testing::Values(5, 19, 35, 48), // m + ::testing::Values(13, 45), // n + ::testing::Values(gtint_t(2), gtint_t(25)), // k + ::testing::Values(gtint_t(1), gtint_t(3)), // ai + ::testing::Values(gtint_t(0)), // aj + ::testing::Values(NaN, Inf, -Inf), // aexval + ::testing::Values(gtint_t(0)), // bi + ::testing::Values(gtint_t(0), gtint_t(2)), // bj + ::testing::Values(NaN, Inf, -Inf), // bexval + ::testing::Values(gtint_t(0), gtint_t(2)), // ci + ::testing::Values(gtint_t(1), gtint_t(3)), // cj + ::testing::Values(NaN, Inf, -Inf), // cexval + ::testing::Values(float(-2.2)), // alpha + ::testing::Values(float(1.2)), // beta + ::testing::Values(gtint_t(0)), // increment to the leading dim of a + ::testing::Values(gtint_t(0)), // increment to the leading dim of b + ::testing::Values(gtint_t(0)) // increment to the leading dim of c + ), + ::gemmEVTPrint() + ); +/******************************************************/ +/* Testing for SUP code paths */ +/* m,n,k is choosen such that SUP code path is called */ +/* Matrix A, B, C are filled with Infs and Nans */ +/******************************************************/ +INSTANTIATE_TEST_SUITE_P( + Skinny_Matrix, + sgemmEVT, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n'), // transa + ::testing::Values('n'), // transb + ::testing::Values(1002, 1327), // m + ::testing::Values(453, 531), // n + ::testing::Values(gtint_t(250), gtint_t(261)), // k + ::testing::Values(gtint_t(1), gtint_t(3)), // ai + ::testing::Values(gtint_t(0)), // aj + ::testing::Values(NaN, Inf, -Inf), // aexval + ::testing::Values(gtint_t(0)), // bi + ::testing::Values(gtint_t(0), gtint_t(2)), // bj + ::testing::Values(NaN, Inf, -Inf), // bexval + ::testing::Values(gtint_t(0), gtint_t(2)), // ci + ::testing::Values(gtint_t(1), gtint_t(3)), // cj + ::testing::Values(NaN, Inf, -Inf), // cexval + ::testing::Values(float(3.6)), // alpha + ::testing::Values(float(-5.1)), // beta + ::testing::Values(gtint_t(0)), // increment to the leading dim of a + ::testing::Values(gtint_t(0)), // increment to the leading dim of b + ::testing::Values(gtint_t(0)) // increment to the leading dim of c + ), + ::gemmEVTPrint() + ); +/*********************************************************/ +/* Testing for native code paths */ +/* m,n,k is choosen such that Native code path is called */ +/* Matrix A, B, C are filled with Infs and Nans */ +/*********************************************************/ +INSTANTIATE_TEST_SUITE_P( + Large_Matrix, + sgemmEVT, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n'), // transa + ::testing::Values('n'), // transb + ::testing::Values(gtint_t(1001)), // m + ::testing::Values(gtint_t(1001)), // n + ::testing::Values(gtint_t(260)), // k + ::testing::Values(gtint_t(1)), // ai + ::testing::Values(gtint_t(0)), // aj + ::testing::Values(NaN, Inf, -Inf), // aexval + ::testing::Values(gtint_t(0)), // bi + ::testing::Values(gtint_t(0)), // bj + ::testing::Values(NaN, Inf, -Inf), // bexval + ::testing::Values(gtint_t(0)), // ci + ::testing::Values(gtint_t(1)), // cj + ::testing::Values(NaN, Inf, -Inf), // cexval + ::testing::Values(float(-2.2)), // alpha + ::testing::Values(float(1.2)), // beta + ::testing::Values(gtint_t(0)), // increment to the leading dim of a + ::testing::Values(gtint_t(0)), // increment to the leading dim of b + ::testing::Values(gtint_t(0)) // increment to the leading dim of c + ), + ::gemmEVTPrint() + ); +/********************************************************/ +/* Testing for small & sup code paths */ +/* m,n,k is choosen such that small & sup code path */ +/* are covered. */ +/* Matrix A, B, C are filled valid integers or floats */ +/* Alpha and beta are assigned with Infs and Nans */ +/********************************************************/ +INSTANTIATE_TEST_SUITE_P( + alpha_beta, + sgemmEVT, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n'), // transa + ::testing::Values('n'), // transb + ::testing::Values(gtint_t(14), gtint_t(100)), // m + ::testing::Values(gtint_t(10), gtint_t(90)), // n + ::testing::Values(gtint_t(20), gtint_t(105)), // k + ::testing::Values(gtint_t(0)), // ai + ::testing::Values(gtint_t(0)), // aj + ::testing::Values(float(0.0)), + ::testing::Values(gtint_t(0)), // bi + ::testing::Values(gtint_t(0)), // bj + ::testing::Values(float(0.0)), + ::testing::Values(gtint_t(0)), // ci + ::testing::Values(gtint_t(0)), // cj + ::testing::Values(float(0.0)), + ::testing::Values(NaN), //Failures , Inf, -Inf), // alpha + ::testing::Values(NaN, Inf, -Inf), // beta + ::testing::Values(gtint_t(0)), // increment to the leading dim of a + ::testing::Values(gtint_t(0)), // increment to the leading dim of b + ::testing::Values(gtint_t(0)) // increment to the leading dim of c + ), + ::gemmEVTPrint() + ); +/********************************************************/ +/* Testing for Native code paths */ +/* m,n,k is choosen such that nat code path are covered */ +/* Matrix A, B, C are filled valid integers or floats */ +/* Alpha and beta are assigned with Infs and Nans */ +/********************************************************/ +INSTANTIATE_TEST_SUITE_P( + Large_Matrix_alpha_beta, + sgemmEVT, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n'), // transa + ::testing::Values('n'), // transb + ::testing::Values(gtint_t(1001)), // m + ::testing::Values(gtint_t(1001)), // n + ::testing::Values(gtint_t(260)), // k + ::testing::Values(gtint_t(0)), // ai + ::testing::Values(gtint_t(0)), // aj + ::testing::Values(float(0.0)), + ::testing::Values(gtint_t(0)), // bi + ::testing::Values(gtint_t(0)), // bj + ::testing::Values(float(0.0)), + ::testing::Values(gtint_t(0)), // ci + ::testing::Values(gtint_t(0)), // cj + ::testing::Values(float(0.0)), + ::testing::Values(NaN), //Failures , Inf, -Inf), // alpha + ::testing::Values(NaN), //Failure Inf, -Inf), // beta + ::testing::Values(gtint_t(0)), // increment to the leading dim of a + ::testing::Values(gtint_t(0)), // increment to the leading dim of b + ::testing::Values(gtint_t(0)) // increment to the leading dim of c + ), + ::gemmEVTPrint() + ); diff --git a/gtestsuite/testsuite/level3/gemm/sgemm/sgemm_generic.cpp b/gtestsuite/testsuite/level3/gemm/sgemm/sgemm_generic.cpp new file mode 100644 index 0000000000..5e813582a5 --- /dev/null +++ b/gtestsuite/testsuite/level3/gemm/sgemm/sgemm_generic.cpp @@ -0,0 +1,197 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "level3/gemm/test_gemm.h" + +class sgemmGeneric : + public ::testing::TestWithParam> {}; + +//matrix storage format, transA, transB, m, n, k, alpha, beta, lda, ldb, ldc + +TEST_P( sgemmGeneric, API ) +{ + using T = float; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // matrix storage format(row major, column major) + char storage = std::get<0>(GetParam()); + // denotes whether matrix a is n,c,t,h + char transa = std::get<1>(GetParam()); + // denotes whether matrix b is n,c,t,h + char transb = std::get<2>(GetParam()); + // matrix size m + gtint_t m = std::get<3>(GetParam()); + // matrix size n + gtint_t n = std::get<4>(GetParam()); + // matrix size k + gtint_t k = std::get<5>(GetParam()); + // specifies alpha value + T alpha = std::get<6>(GetParam()); + // specifies beta value + T beta = std::get<7>(GetParam()); + // lda, ldb, ldc increments. + // If increments are zero, then the array size matches the matrix size. + // If increments are nonnegative, the array size is bigger than the matrix size. + gtint_t lda_inc = std::get<8>(GetParam()); + gtint_t ldb_inc = std::get<9>(GetParam()); + gtint_t ldc_inc = std::get<10>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite gemm.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (m == 0 || n == 0) + thresh = 0.0; + else if ((alpha == testinghelpers::ZERO() || k == 0) && + (beta == testinghelpers::ZERO() || beta == testinghelpers::ONE())) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + thresh = testinghelpers::getEpsilon(); + else + thresh = (3*k+1)*testinghelpers::getEpsilon(); + //thresh = (24*k+1)*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_gemm( storage, transa, transb, m, n, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh ); +} + +INSTANTIATE_TEST_SUITE_P( + expect_sgemv_path, + sgemmGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n','t'), // transa + ::testing::Values('n','t'), // transb + ::testing::Values(gtint_t(1), gtint_t(4), gtint_t(6)), // m + ::testing::Values(gtint_t(1), gtint_t(5), gtint_t(6)), // n + ::testing::Values(gtint_t(1), gtint_t(3), gtint_t(6)), // k + ::testing::Values(5.3, -1.0, 1.0), // alpha + ::testing::Values(6.4, 1.0, -1.0, 0.0), // beta + ::testing::Values(0, 13), // increment to the leading dim of a + ::testing::Values(0, 15), // increment to the leading dim of b + ::testing::Values(0, 17) // increment to the leading dim of c + ), + ::gemmGenericPrint() + ); + +//----------------------------- sgemmGeneric_small kernel ------------------------------------ +INSTANTIATE_TEST_SUITE_P( + expect_sgemmGeneric_small_path, + sgemmGeneric, + ::testing::Combine( + // Test both storage types + ::testing::Values('c'), // storage format + // Covers all possible combinations of storage schemes + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(5, 20, 32, 44), // m + ::testing::Values(25, 37, 42), // n + // k-unroll factor = KR = 1 + ::testing::Values(2, 13, 24), // k + // No condition based on alpha + ::testing::Values(0.0, -1.0, 1.0, 1.7), // alpha + // No condition based on beta + ::testing::Values(0.0, -1.0, 1.0, 2.3), // beta + ::testing::Values(0, 13), // increment to the leading dim of a + ::testing::Values(0, 15), // increment to the leading dim of b + ::testing::Values(0, 17) // increment to the leading dim of c + ), + ::gemmGenericPrint() + ); + +// ----------------------------- SUP implementation -------------------------------------- +INSTANTIATE_TEST_SUITE_P( + expect_sgemmGeneric_sup_path, + sgemmGeneric, + ::testing::Combine( + // Storage of A and B is handled by packing + ::testing::Values('c'), // storage format + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1002, 1083, 1378), // m + ::testing::Values(453, 504, 567), // n + ::testing::Values(250, 155, 260), // k + // No condition based on alpha + ::testing::Values(0.0, -1.0, 1.0, 1.7), // alpha + // No condition based on beta + ::testing::Values(0.0, -1.0, 1.0, 2.3), // beta + ::testing::Values(0, 13), // increment to the leading dim of a + ::testing::Values(0, 15), // increment to the leading dim of b + ::testing::Values(0, 17) // increment to the leading dim of c + ), + ::gemmGenericPrint() + ); + +// ----------------------------- Native implementation -------------------------------------- +INSTANTIATE_TEST_SUITE_P( + expect_sgemmGeneric_native_path, + sgemmGeneric, + ::testing::Combine( + // Storage of A and B is handled by packing + ::testing::Values('c'), // storage format + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(5017, 5327), // m + ::testing::Values(1709, 5417), // n + ::testing::Values(515, 604), // k + // No condition based on alpha + ::testing::Values(0.0, -1.0, 1.0, 1.7), // alpha + // No condition based on beta + ::testing::Values(0.0, -1.0, 1.0, 2.3), // beta + ::testing::Values(0, 13), // increment to the leading dim of a + ::testing::Values(0, 15), // increment to the leading dim of b + ::testing::Values(0, 17) // increment to the leading dim of c + ), + ::gemmGenericPrint() + ); \ No newline at end of file diff --git a/gtestsuite/testsuite/level3/gemm/sgemm_generic.cpp b/gtestsuite/testsuite/level3/gemm/sgemm_generic.cpp deleted file mode 100644 index 2adbe2968a..0000000000 --- a/gtestsuite/testsuite/level3/gemm/sgemm_generic.cpp +++ /dev/null @@ -1,199 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include "test_gemm.h" - -class SGemmTest : - public ::testing::TestWithParam> {}; - -TEST_P(SGemmTest, RandomData) -{ - using T = float; - //---------------------------------------------------------- - // Initialize values from the parameters passed through - // test suite instantiation (INSTANTIATE_TEST_SUITE_P). - //---------------------------------------------------------- - // matrix storage format(row major, column major) - char storage = std::get<0>(GetParam()); - // denotes whether matrix a is n,c,t,h - char transa = std::get<1>(GetParam()); - // denotes whether matrix b is n,c,t,h - char transb = std::get<2>(GetParam()); - // matrix size m - gtint_t m = std::get<3>(GetParam()); - // matrix size n - gtint_t n = std::get<4>(GetParam()); - // matrix size k - gtint_t k = std::get<5>(GetParam()); - // specifies alpha value - T alpha = std::get<6>(GetParam()); - // specifies beta value - T beta = std::get<7>(GetParam()); - // lda, ldb, ldc increments. - // If increments are zero, then the array size matches the matrix size. - // If increments are nonnegative, the array size is bigger than the matrix size. - gtint_t lda_inc = std::get<8>(GetParam()); - gtint_t ldb_inc = std::get<9>(GetParam()); - gtint_t ldc_inc = std::get<10>(GetParam()); - - // Set the threshold for the errors: - double thresh = 10*m*n*testinghelpers::getEpsilon(); - - //---------------------------------------------------------- - // Call test body using these parameters - //---------------------------------------------------------- - test_gemm( storage, transa, transb, m, n, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh ); -} - -class SGemmTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char tsa = std::get<1>(str.param); - char tsb = std::get<2>(str.param); - gtint_t m = std::get<3>(str.param); - gtint_t n = std::get<4>(str.param); - gtint_t k = std::get<5>(str.param); - float alpha = std::get<6>(str.param); - float beta = std::get<7>(str.param); - gtint_t lda_inc = std::get<8>(str.param); - gtint_t ldb_inc = std::get<9>(str.param); - gtint_t ldc_inc = std::get<10>(str.param); -#ifdef TEST_BLAS - std::string str_name = "sgemm_"; -#elif TEST_CBLAS - std::string str_name = "cblas_sgemm"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_sgemm"; -#endif - str_name = str_name + "_" + sfm+sfm+sfm; - str_name = str_name + "_" + tsa + tsb; - str_name = str_name + "_" + std::to_string(m); - str_name = str_name + "_" + std::to_string(n); - str_name = str_name + "_" + std::to_string(k); - std::string alpha_str = ( alpha > 0) ? std::to_string(int(alpha)) : "m" + std::to_string(int(std::abs(alpha))); - str_name = str_name + "_a" + alpha_str; - std::string beta_str = ( beta > 0) ? std::to_string(int(beta)) : "m" + std::to_string(int(std::abs(beta))); - str_name = str_name + "_b" + beta_str; - str_name = str_name + "_" + std::to_string(lda_inc); - str_name = str_name + "_" + std::to_string(ldb_inc); - str_name = str_name + "_" + std::to_string(ldc_inc); - return str_name; - } -}; - -// Black box testing. -INSTANTIATE_TEST_SUITE_P( - sgemm_sup_10_30, - SGemmTest, - ::testing::Combine( - ::testing::Values('c' -#ifndef TEST_BLAS - ,'r' -#endif - ), // storage format - ::testing::Values('n','t'), // transa - ::testing::Values('n','t'), // transb - ::testing::Range(gtint_t(10), gtint_t(31), 10), // m - ::testing::Range(gtint_t(10), gtint_t(31), 10), // n - ::testing::Range(gtint_t(10), gtint_t(31), 10), // k - ::testing::Values( 1.0, -2.0), // alpha - ::testing::Values(-1.0, 1.0), // beta - ::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of a - ::testing::Values(gtint_t(0), gtint_t(3)), // increment to the leading dim of b - ::testing::Values(gtint_t(0), gtint_t(7)) // increment to the leading dim of c - ), - ::SGemmTestPrint() - ); - -// Black box testing. -INSTANTIATE_TEST_SUITE_P( - sgemm_sup_alpha_beta, - SGemmTest, - ::testing::Combine( - ::testing::Values('c' -#ifndef TEST_BLAS - ,'r' -#endif - ), // storage format - ::testing::Values('n','t'), // transa - ::testing::Values('n','t'), // transb - ::testing::Range(gtint_t(1), gtint_t(20), 1), // m - ::testing::Range(gtint_t(1), gtint_t(50), 1), // n - ::testing::Range(gtint_t(1), gtint_t(10), 1), // k - ::testing::Values(0.0, 1.0, -1.0, 5.3, -10.0), // alpha - ::testing::Values(0.0, 1.0, -1.0, 6.4, -19.0), // beta - ::testing::Values(gtint_t(2)), // increment to the leading dim of a - ::testing::Values(gtint_t(3)), // increment to the leading dim of b - ::testing::Values(gtint_t(7)) // increment to the leading dim of c - ), - ::SGemmTestPrint() - ); - - -// Black box testing. -INSTANTIATE_TEST_SUITE_P( - sgemm_sup_m_n_k_100, - SGemmTest, - ::testing::Combine( - ::testing::Values('c' -#ifndef TEST_BLAS - ,'r' -#endif - ), // storage format - ::testing::Values('n'), // transa - ::testing::Values('n'), // transb - ::testing::Range(gtint_t(1), gtint_t(20), 1), // m - ::testing::Range(gtint_t(1), gtint_t(50), 1), // n - ::testing::Range(gtint_t(1), gtint_t(20), 1), // k - ::testing::Values( -2.0), // alpha - ::testing::Values( 5.0), // beta - ::testing::Values(gtint_t(2)), // increment to the leading dim of a - ::testing::Values(gtint_t(3)), // increment to the leading dim of b - ::testing::Values(gtint_t(7)) // increment to the leading dim of c - ), - ::SGemmTestPrint() - ); diff --git a/gtestsuite/testsuite/level3/gemm/test_gemm.h b/gtestsuite/testsuite/level3/gemm/test_gemm.h index 147bcdab50..f88348f65a 100644 --- a/gtestsuite/testsuite/level3/gemm/test_gemm.h +++ b/gtestsuite/testsuite/level3/gemm/test_gemm.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -39,6 +39,7 @@ #include "inc/check_error.h" #include #include +#include template void test_gemm( char storage, char trnsa, char trnsb, gtint_t m, gtint_t n, @@ -55,7 +56,14 @@ void test_gemm( char storage, char trnsa, char trnsb, gtint_t m, gtint_t n, //---------------------------------------------------------- std::vector a = testinghelpers::get_random_matrix( -2, 8, storage, trnsa, m, k, lda ); std::vector b = testinghelpers::get_random_matrix( -5, 2, storage, trnsb, k, n, ldb ); - std::vector c = testinghelpers::get_random_matrix( -3, 5, storage, 'n', m, n, ldc ); + std::vector c( testinghelpers::matsize( storage, 'n', m, n, ldc ) ); + if (beta != testinghelpers::ZERO()) + testinghelpers::datagenerators::randomgenerators( -3, 5, storage, m, n, c.data(), 'n', ldc ); + else + { + // Matrix C should not be read, only set. + testinghelpers::set_matrix( storage, m, n, c.data(), 'n', ldc, testinghelpers::aocl_extreme() ); + } // Create a copy of c so that we can check reference results. std::vector c_ref(c); @@ -75,7 +83,12 @@ void test_gemm( char storage, char trnsa, char trnsb, gtint_t m, gtint_t n, //---------------------------------------------------------- // check component-wise error. //---------------------------------------------------------- - computediff( storage, m, n, c.data(), c_ref.data(), ldc, thresh ); + computediff( "c", storage, m, n, c.data(), c_ref.data(), ldc, thresh ); + +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif } // Test body used for exception value testing, by inducing an exception value @@ -134,5 +147,265 @@ void test_gemm( char storage, char trnsa, char trnsb, gtint_t m, gtint_t n, //---------------------------------------------------------- // check component-wise error. //---------------------------------------------------------- - computediff( storage, m, n, c.data(), c_ref.data(), ldc, thresh, true ); + computediff( "c", storage, m, n, c.data(), c_ref.data(), ldc, thresh, true ); + +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif +} + +// Test body used for overflow and underflow checks +template +void test_gemm( char storage, char trnsa, char trnsb, gtint_t over_under, gtint_t input_range, + gtint_t m, gtint_t n, gtint_t k, gtint_t lda_inc, gtint_t ldb_inc, + gtint_t ldc_inc, gtint_t ai, gtint_t aj, gtint_t bi, gtint_t bj, T alpha, + T beta, double thresh ) +{ + // Compute the leading dimensions of a, b, and c. + gtint_t lda = testinghelpers::get_leading_dimension( storage, trnsa, m, k, lda_inc ); + gtint_t ldb = testinghelpers::get_leading_dimension( storage, trnsb, k, n, ldb_inc ); + gtint_t ldc = testinghelpers::get_leading_dimension( storage, 'n', m, n, ldc_inc ); + + //---------------------------------------------------------- + // Initialize matrices with random numbers + //---------------------------------------------------------- + std::vector a,b,c; + + /* + Testing for Overflow + ====================== + For double-precision floating point, the maximum representable number is + DBL_MAX = 1.7976931348623158e+308 + + Any value higher than DBL_MAX is considered to be an overflow. + + over_under=0 indicates Overflow testing + The input matrices are populated with 3 different value ranges based on input_range + + |****************************************************************| + | input_range | Expected Input | Expected Output | + |*************|*************************|************************| + | -1 | Values much less than | Exact floating point | + | | DBL_MAX | values | + |*************|*************************|************************| + | 0 | Values close to | Exact floating point | + | | DBL_MAX | values upto DBL_MAX | + | | | | + | | | +/-INF for values | + | | | higher than +/-DBL_MAX | + |*************|*************************|************************| + | 1 | Values much higher than | +/-INF for values | + | | DBL_MAX | higher than +/-DBL_MAX | + | | | | + ****************************************************************** + + Testing for Underflow + ======================== + For double-precision floating point, the minimum representable number is + DBL_MIN = 2.2250738585072014e-308 + + Any value lower than DBL_MIN is considered to be an underflow + + over_under=1 indicates Underflow testing + The input matrices are populated with 3 different value ranges based on input_range + + |******************************************************************| + | input_range | Expected Input | Expected Output | + |*************|**************************|*************************| + | -1 | Values much larger | Exact floating point | + | | than DBL_MIN | values | + |*************|**************************|*************************| + | 0 | Values close to | Exact floating point | + | | DBL_MIN | values upto DBL_MIN | + | | | | + | | | +0 for values | + | | | lower than DBL_MIN | + |*************|**************************|*************************| + | 1 | Values much smaller than | +0 for values | + | | DBL_MIN | smaller than +/-DBL_MIN | + | | | | + ******************************************************************** + + */ + a = testinghelpers::get_random_matrix( 5.5, 10.5, storage, trnsa, m, k, lda, 1, + testinghelpers::datagenerators::ElementType::FP ); + b = testinghelpers::get_random_matrix( 3.2, 5.6, storage, trnsb, k, n, ldb, 1, + testinghelpers::datagenerators::ElementType::FP ); + c = testinghelpers::get_random_matrix( -5, -2, storage, 'n', m, n, ldc, 1, + testinghelpers::datagenerators::ElementType::FP ); + /* + Based on the value of over_under, overflow/underflow values are inserted to the input matrices + at the indices passed as arguments. + */ + testinghelpers::set_overflow_underflow_mat( storage, trnsa, lda, ai, aj, a.data(), over_under, input_range); + testinghelpers::set_overflow_underflow_mat( storage, trnsb, lda, bi, bj, b.data(), over_under, input_range); + + std::vector c_ref(c); + + // Create a copy of c so that we can check reference results. + //---------------------------------------------------------- + // Call BLIS function + //---------------------------------------------------------- + gemm( storage, trnsa, trnsb, m, n, k, &alpha, a.data(), lda, + b.data(), ldb, &beta, c.data(), ldc ); + + //---------------------------------------------------------- + // Call reference implementation. + //---------------------------------------------------------- + testinghelpers::ref_gemm( storage, trnsa, trnsb, m, n, k, alpha, + a.data(), lda, b.data(), ldb, beta, c_ref.data(), ldc ); + + //---------------------------------------------------------- + // check component-wise error. + //---------------------------------------------------------- + computediff( "C", storage, m, n, c.data(), c_ref.data(), ldc, thresh, true ); + +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif } + +// Test-case logger : Used to print the test-case details based on parameters +template +class gemmGenericPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char storage = std::get<0>(str.param); + char transa = std::get<1>(str.param); + char transb = std::get<2>(str.param); + gtint_t m = std::get<3>(str.param); + gtint_t n = std::get<4>(str.param); + gtint_t k = std::get<5>(str.param); + T alpha = std::get<6>(str.param); + T beta = std::get<7>(str.param); + gtint_t lda_inc = std::get<8>(str.param); + gtint_t ldb_inc = std::get<9>(str.param); + gtint_t ldc_inc = std::get<10>(str.param); + + std::string str_name = API_PRINT; + str_name += "_stor_" + std::string(&storage, 1); + str_name += "_transa_" + std::string(&transa, 1); + str_name += "_transb_" + std::string(&transb, 1); + str_name += "_m_" + std::to_string(m); + str_name += "_n_" + std::to_string(n); + str_name += "_k_" + std::to_string(k); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + str_name += "_beta_" + testinghelpers::get_value_string(beta); + gtint_t lda = testinghelpers::get_leading_dimension( storage, transa, m, k, lda_inc ); + gtint_t ldb = testinghelpers::get_leading_dimension( storage, transb, k, n, ldb_inc ); + gtint_t ldc = testinghelpers::get_leading_dimension( storage, 'n', m, n, ldc_inc ); + str_name += "_lda_i" + std::to_string(lda_inc) + "_" + std::to_string(lda); + str_name += "_ldb_i" + std::to_string(ldb_inc) + "_" + std::to_string(ldb); + str_name += "_ldc_i" + std::to_string(ldc_inc) + "_" + std::to_string(ldc); + return str_name; + } +}; + +template +class gemmEVTPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char storage = std::get<0>(str.param); + char transa = std::get<1>(str.param); + char transb = std::get<2>(str.param); + gtint_t m = std::get<3>(str.param); + gtint_t n = std::get<4>(str.param); + gtint_t k = std::get<5>(str.param); + gtint_t ai, aj, bi, bj, ci, cj; + T aex, bex, cex; + ai = std::get<6>(str.param); + aj = std::get<7>(str.param); + aex = std::get<8>(str.param); + + bi = std::get<9>(str.param); + bj = std::get<10>(str.param); + bex = std::get<11>(str.param); + + ci = std::get<12>(str.param); + cj = std::get<13>(str.param); + cex = std::get<14>(str.param); + + T alpha = std::get<15>(str.param); + T beta = std::get<16>(str.param); + gtint_t lda_inc = std::get<17>(str.param); + gtint_t ldb_inc = std::get<18>(str.param); + gtint_t ldc_inc = std::get<19>(str.param); + + std::string str_name = API_PRINT; + str_name += "_stor_" + std::string(&storage, 1); + str_name += "_transa_" + std::string(&transa, 1); + str_name += "_transb_" + std::string(&transb, 1); + str_name += "_m_" + std::to_string(m); + str_name += "_n_" + std::to_string(n); + str_name += "_k_" + std::to_string(k); + str_name = str_name + "_A" + std::to_string(ai) + std::to_string(aj); + str_name = str_name + "_" + testinghelpers::get_value_string(aex); + str_name = str_name + "_B" + std::to_string(bi) + std::to_string(bj); + str_name = str_name + "_" + testinghelpers::get_value_string(bex); + str_name = str_name + "_C" + std::to_string(ci) + std::to_string(cj); + str_name = str_name + "_" + testinghelpers::get_value_string(cex); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + str_name += "_beta_" + testinghelpers::get_value_string(beta); + gtint_t lda = testinghelpers::get_leading_dimension( storage, transa, m, k, lda_inc ); + gtint_t ldb = testinghelpers::get_leading_dimension( storage, transb, k, n, ldb_inc ); + gtint_t ldc = testinghelpers::get_leading_dimension( storage, 'n', m, n, ldc_inc ); + str_name += "_lda_i" + std::to_string(lda_inc) + "_" + std::to_string(lda); + str_name += "_ldb_i" + std::to_string(ldb_inc) + "_" + std::to_string(ldb); + str_name += "_ldc_i" + std::to_string(ldc_inc) + "_" + std::to_string(ldc); + return str_name; + } +}; + +template +class gemmOUTPrint { + public: + std::string operator()( + testing::TestParamInfo> str) const { + char storage = std::get<0>(str.param); + char transa = std::get<1>(str.param); + char transb = std::get<2>(str.param); + gtint_t over_under = std::get<3>(str.param); + gtint_t input_range = std::get<4>(str.param); + gtint_t m = std::get<5>(str.param); + gtint_t n = std::get<6>(str.param); + gtint_t k = std::get<7>(str.param); + T alpha = std::get<8>(str.param); + T beta = std::get<9>(str.param); + gtint_t lda_inc = std::get<10>(str.param); + gtint_t ldb_inc = std::get<11>(str.param); + gtint_t ldc_inc = std::get<12>(str.param); + gtint_t ai = std::get<13>(str.param); + gtint_t aj = std::get<14>(str.param); + gtint_t bi = std::get<15>(str.param); + gtint_t bj = std::get<16>(str.param); + + std::string str_name = API_PRINT; + str_name += "_stor_" + std::string(&storage, 1); + str_name += "_transa_" + std::string(&transa, 1); + str_name += "_transb_" + std::string(&transb, 1); + std::string over_under_str = ( over_under > 0) ? "underflow": "overflow"; + str_name = str_name + "_" + over_under_str; + std::string input_range_str = (input_range < 0) ? "within_limit": (input_range > 0) ? "beyond_limit" : "close_to_limit"; + str_name = str_name + "_" + input_range_str; + str_name += "_m_" + std::to_string(m); + str_name += "_n_" + std::to_string(n); + str_name += "_k_" + std::to_string(k); + str_name = str_name + "_A_" + std::to_string(ai) + "_" + std::to_string(aj); + str_name = str_name + "_B_" + std::to_string(bi) + "_" + std::to_string(bj); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + str_name += "_beta_" + testinghelpers::get_value_string(beta); + gtint_t lda = testinghelpers::get_leading_dimension( storage, transa, m, k, lda_inc ); + gtint_t ldb = testinghelpers::get_leading_dimension( storage, transb, k, n, ldb_inc ); + gtint_t ldc = testinghelpers::get_leading_dimension( storage, 'n', m, n, ldc_inc ); + str_name += "_lda_i" + std::to_string(lda_inc) + "_" + std::to_string(lda); + str_name += "_ldb_i" + std::to_string(ldb_inc) + "_" + std::to_string(ldb); + str_name += "_ldc_i" + std::to_string(ldc_inc) + "_" + std::to_string(ldc); + return str_name; + } +}; diff --git a/gtestsuite/testsuite/level3/gemm/zgemm/zgemm_evt.cpp b/gtestsuite/testsuite/level3/gemm/zgemm/zgemm_evt.cpp new file mode 100644 index 0000000000..45950e16b7 --- /dev/null +++ b/gtestsuite/testsuite/level3/gemm/zgemm/zgemm_evt.cpp @@ -0,0 +1,444 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "level3/gemm/test_gemm.h" + +using T = dcomplex; + +static float AOCL_NAN = std::numeric_limits::quiet_NaN(); +static float AOCL_INF = std::numeric_limits::infinity(); + +class zgemmEVT : + public ::testing::TestWithParam> {}; + +TEST_P( zgemmEVT, API ) +{ + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // matrix storage format(row major, column major) + char storage = std::get<0>(GetParam()); + // denotes whether matrix a is n,c,t,h + char transa = std::get<1>(GetParam()); + // denotes whether matrix b is n,c,t,h + char transb = std::get<2>(GetParam()); + // matrix size m + gtint_t m = std::get<3>(GetParam()); + // matrix size n + gtint_t n = std::get<4>(GetParam()); + // matrix size k + gtint_t k = std::get<5>(GetParam()); + + gtint_t ai = std::get<6>(GetParam()); + gtint_t aj = std::get<7>(GetParam()); + T aex = std::get<8>(GetParam()); + + gtint_t bi = std::get<9>(GetParam()); + gtint_t bj = std::get<10>(GetParam()); + T bex = std::get<11>(GetParam()); + + gtint_t ci = std::get<12>(GetParam()); + gtint_t cj = std::get<13>(GetParam()); + T cex = std::get<14>(GetParam()); + + // specifies alpha value + T alpha = std::get<15>(GetParam()); + // specifies beta value + T beta = std::get<16>(GetParam()); + // lda, ldb, ldc increments. + // If increments are zero, then the array size matches the matrix size. + // If increments are nonnegative, the array size is bigger than the matrix size. + gtint_t lda_inc = std::get<17>(GetParam()); + gtint_t ldb_inc = std::get<18>(GetParam()); + gtint_t ldc_inc = std::get<19>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite gemm.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (m == 0 || n == 0) + thresh = 0.0; + else if ((alpha == testinghelpers::ZERO() || k == 0) && + (beta == testinghelpers::ZERO() || beta == testinghelpers::ONE())) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + thresh = testinghelpers::getEpsilon(); + else + thresh = (3*k+1)*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_gemm( storage, transa, transb, m, n, k, lda_inc, ldb_inc, ldc_inc, + alpha, beta, ai, aj, aex, bi, bj, bex, ci, cj, cex, thresh ); +} + +// Exception value testing(on matrices) + +/* + It contains both the exception value testing(EVT) and the + positive accuracy testing of the bli_ZGEMM_4x4_avx2_k1_nn( ... ) computational + kernel. This kernel is invoked from the BLAS layer, and inputs are given + in a manner so as to avoid the other code-paths and test only the required + kernel. + +*/ +/* + For the bli_ZGEMM_4x4_avx2_k1_nn kernel, the main and fringe dimensions are as follows: + For m : Main = { 4 }, fringe = { 2, 1 } + For n : Main = { 4 }, fringe = { 2, 1 } + + Without any changes to the BLAS layer in BLIS, the fringe case of 1 cannot be touched + separately, since if m/n is 1, the inputs are redirected to ZGEMV. + +*/ + +// Testing for the main loop case for m and n +// The kernel uses 2 loads and 4 broadcasts. The exception values +// are induced at one index individually for each of the loads. +// They are also induced in the broadcast direction at two places. +INSTANTIATE_TEST_SUITE_P( + K1_transA_N_transB_N_main, + zgemmEVT, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n'), // transa + ::testing::Values('n'), // transb + ::testing::Values(gtint_t(4)), // m + ::testing::Values(gtint_t(4)), // n + ::testing::Values(gtint_t(1)), // k + ::testing::Values(gtint_t(3)), // ai + ::testing::Values(gtint_t(0)), // aj + ::testing::Values(T{AOCL_NAN, 2.3}, T{AOCL_INF, 0.0}, + T{3.4, AOCL_NAN}, T{AOCL_NAN, -AOCL_INF}), // aexval + ::testing::Values(gtint_t(0)), // bi + ::testing::Values(gtint_t(2)), // bj + ::testing::Values(T{AOCL_NAN, 2.3}, T{AOCL_INF, 0.0}, + T{3.4, AOCL_NAN}, T{AOCL_NAN, -AOCL_INF}), // bexval + ::testing::Values(gtint_t(2)), // ci + ::testing::Values(gtint_t(1)), // cj + ::testing::Values(T{AOCL_NAN, 2.3}, T{AOCL_INF, 0.0}, + T{3.4, AOCL_NAN}, T{AOCL_NAN, -AOCL_INF}), // cexval + ::testing::Values(T{-2.2, 3.3}, T{0.0, 0.0}, + T{1.0, 0.0}, T{-1.0, 0.0}, + T{3.4, 0.0}, T{0.0, 1.0}), // alpha + ::testing::Values(T{1.2, -2.3}, T{0.0, 0.0}, + T{1.0, 0.0}, T{-1.0, 0.0}, + T{3.1, 0.0}, T{0.0, 1.0}), // beta + ::testing::Values(gtint_t(0)), // increment to the leading dim of a + ::testing::Values(gtint_t(0)), // increment to the leading dim of b + ::testing::Values(gtint_t(0)) // increment to the leading dim of c + ), + ::gemmEVTPrint() + ); + +// Testing the fringe cases +// Fringe case minimum size is 2 along both m and n. +// Invloves only one load(AVX2 or (AVX2+SSE)). Thus, +// the exception values are induced at the first and second indices of the +// column vector A and row vector B. +INSTANTIATE_TEST_SUITE_P( + K1_transA_N_transB_N_fringe, + zgemmEVT, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n'), // transa + ::testing::Values('n'), // transb + ::testing::Values(gtint_t(2), gtint_t(3)), // m + ::testing::Values(gtint_t(2), gtint_t(3)), // n + ::testing::Values(gtint_t(1)), // k + ::testing::Values(gtint_t(0)), // ai + ::testing::Values(gtint_t(0)), // aj + ::testing::Values(T{AOCL_NAN, 2.3}, T{AOCL_INF, 0.0}, + T{3.4, AOCL_NAN}, T{AOCL_NAN, -AOCL_INF}), // aexval + ::testing::Values(gtint_t(0)), // bi + ::testing::Values(gtint_t(1)), // bj + ::testing::Values(T{AOCL_NAN, 2.3}, T{AOCL_INF, 0.0}, + T{3.4, AOCL_NAN}, T{AOCL_NAN, -AOCL_INF}), // bexval + ::testing::Values(gtint_t(1)), // ci + ::testing::Values(gtint_t(0)), // cj + ::testing::Values(T{AOCL_NAN, 2.3}, T{AOCL_INF, 0.0}, + T{3.4, AOCL_NAN}, T{AOCL_NAN, -AOCL_INF}), // cexval + ::testing::Values(T{-2.2, 3.3}, T{0.0, 0.0}, + T{1.0, 0.0}, T{-1.0, 0.0}, + T{2.3, 0.0}, T{0.0, 1.0}), // alpha + ::testing::Values(T{1.2, -2.3}, T{0.0, 0.0}, + T{1.0, 0.0}, T{-1.0, 0.0}, + T{5.6, 0.0}, T{0.0, 1.0}), // beta + ::testing::Values(gtint_t(0)), // increment to the leading dim of a + ::testing::Values(gtint_t(0)), // increment to the leading dim of b + ::testing::Values(gtint_t(0)) // increment to the leading dim of c + ), + ::gemmEVTPrint() + ); + +// Exception value testing(on alpha and beta) +// Alpha and beta are set to exception values +INSTANTIATE_TEST_SUITE_P( + K1_transA_N_transB_N_alphabeta, + zgemmEVT, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n'), // transa + ::testing::Values('n'), // transb + ::testing::Values(gtint_t(2), gtint_t(4)), // m + ::testing::Values(gtint_t(2), gtint_t(4)), // n + ::testing::Values(gtint_t(1)), // k + ::testing::Values(gtint_t(0)), // ai + ::testing::Values(gtint_t(0)), // aj + ::testing::Values(T{0.0, 0.0}), + ::testing::Values(gtint_t(0)), // bi + ::testing::Values(gtint_t(0)), // bj + ::testing::Values(T{0.0, 0.0}), + ::testing::Values(gtint_t(0)), // ci + ::testing::Values(gtint_t(0)), // cj + ::testing::Values(T{0.0, 0.0}), + ::testing::Values(T{AOCL_NAN, 2.3}, T{AOCL_INF, 0.0}, + T{3.4, AOCL_NAN}, T{AOCL_NAN, -AOCL_INF}), // alpha + ::testing::Values(T{AOCL_NAN, 2.3}, T{AOCL_INF, 0.0}, + T{3.4, AOCL_NAN}, T{AOCL_NAN, -AOCL_INF}), // beta + ::testing::Values(gtint_t(0)), // increment to the leading dim of a + ::testing::Values(gtint_t(0)), // increment to the leading dim of b + ::testing::Values(gtint_t(0)) // increment to the leading dim of c + ), + ::gemmEVTPrint() + ); + +/********************************************************/ +/* Testing for small code paths */ +/* m,n,k is choosen such that small code path is called */ +/* Matrix A, B, C are filled with Infs and Nans */ +/********************************************************/ +INSTANTIATE_TEST_SUITE_P( + DISABLED_Small_Matrix, + zgemmEVT, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n', 't', 'c'), // transa + ::testing::Values('n', 't', 'c'), // transb + ::testing::Values(gtint_t(4)), // m + ::testing::Values(gtint_t(4)), // n + ::testing::Values(gtint_t(10)), // k + ::testing::Values(gtint_t(3)), // ai + ::testing::Values(gtint_t(0)), // aj + ::testing::Values(T{AOCL_NAN, 2.3}, /*T{AOCL_INF, 0.0},*/ + T{3.4, AOCL_NAN}, T{AOCL_NAN, -AOCL_INF}), // aexval + ::testing::Values(gtint_t(0)), // bi + ::testing::Values(gtint_t(2)), // bj + ::testing::Values(T{AOCL_NAN, 2.3}, /*T{AOCL_INF, 0.0},*/ //Failures + T{3.4, AOCL_NAN}, T{AOCL_NAN, -AOCL_INF}), // bexval + ::testing::Values(gtint_t(2)), // ci + ::testing::Values(gtint_t(1)), // cj + ::testing::Values(T{AOCL_NAN, 2.3}, T{AOCL_INF, 0.0}, + T{3.4, AOCL_NAN}, T{AOCL_NAN, -AOCL_INF}), // cexval + ::testing::Values(T{-2.2, 3.3}, T{0.0, 0.0}, + T{1.0, 0.0}, T{-1.0, 0.0}, + T{6.0, 0.0}, T{0.0, 1.0}), // alpha + ::testing::Values(T{1.2, -2.3}, T{0.0, 0.0}, + T{1.0, 0.0}, T{-1.0, 0.0}, + T{5.6, 0.0}, T{0.0, 1.0}), // beta + ::testing::Values(gtint_t(0)), // increment to the leading dim of a + ::testing::Values(gtint_t(0)), // increment to the leading dim of b + ::testing::Values(gtint_t(0)) // increment to the leading dim of c + ), + ::gemmEVTPrint() + ); + +/******************************************************/ +/* Testing for SUP code paths */ +/* m,n,k is choosen such that SUP code path is called */ +/* Matrix A, B, C are filled with Infs and Nans */ +/******************************************************/ +INSTANTIATE_TEST_SUITE_P( + DISABLED_Skinny_Matrix, + zgemmEVT, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(gtint_t(90)), // m + ::testing::Values(gtint_t(80)), // n + ::testing::Values(gtint_t(1080)), // k + ::testing::Values(gtint_t(3)), // ai + ::testing::Values(gtint_t(0)), // aj + ::testing::Values(T{AOCL_NAN, 2.3}, /*T{AOCL_INF, 0.0},*/ //Failure + T{3.4, AOCL_NAN}, T{AOCL_NAN, -AOCL_INF}), // aexval + ::testing::Values(gtint_t(0)), // bi + ::testing::Values(gtint_t(2)), // bj + ::testing::Values(T{AOCL_NAN, 2.3}, /*T{AOCL_INF, 0.0},*/ + T{3.4, AOCL_NAN}, T{AOCL_NAN, -AOCL_INF}), // bexval + ::testing::Values(gtint_t(0)), // ci + ::testing::Values(gtint_t(1)), // cj + ::testing::Values(T{AOCL_NAN, 2.3}, T{AOCL_INF, 0.0}, + T{3.4, AOCL_NAN}, T{AOCL_NAN, -AOCL_INF}), // cexval + ::testing::Values(T{3.6, -1.0}, T{0.0, 0.0}, + T{1.0, 0.0}, T{-1.0, 0.0}, + T{34.0, 0.0}, T{0.0, 1.0}), // alpha + ::testing::Values(T{-5.7, 1.2}, T{0.0, 0.0}, + T{1.0, 0.0}, T{-1.0, 0.0}, + T{3.0, 0.0}, T{0.0, 1.0}), // beta + ::testing::Values(gtint_t(0)), // increment to the leading dim of a + ::testing::Values(gtint_t(0)), // increment to the leading dim of b + ::testing::Values(gtint_t(0)) // increment to the leading dim of c + ), + ::gemmEVTPrint() + ); + +/*********************************************************/ +/* Testing for Native code paths */ +/* m,n,k is choosen such that Native code path is called */ +/* Matrix A, B, C are filled with Infs and Nans */ +/*********************************************************/ +INSTANTIATE_TEST_SUITE_P( + DISABLED_Large_Matrix, + zgemmEVT, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n', 't', 'c'), // transa + ::testing::Values('n', 't', 'c'), // transb + ::testing::Values(gtint_t(200)), // m + ::testing::Values(gtint_t(200)), // n + ::testing::Values(gtint_t(130)), // k + ::testing::Values(gtint_t(1)), // ai + ::testing::Values(gtint_t(0)), // aj + ::testing::Values(T{AOCL_NAN, 2.3}, /*T{AOCL_INF, 0.0},*/ //Failures + T{3.4, AOCL_NAN}, T{AOCL_NAN, -AOCL_INF}), // aexval + ::testing::Values(gtint_t(0)), // bi + ::testing::Values(gtint_t(0), gtint_t(2)), // bj + ::testing::Values(T{AOCL_NAN, 2.3}, /*T{AOCL_INF, 0.0},*/ + T{3.4, AOCL_NAN}, T{AOCL_NAN, -AOCL_INF}), // bexval + ::testing::Values(gtint_t(2)), // ci + ::testing::Values(gtint_t(3)), // cj + ::testing::Values(T{AOCL_NAN, 2.3}, T{AOCL_INF, 0.0}, + T{3.4, AOCL_NAN}, T{AOCL_NAN, -AOCL_INF}), // cexval + ::testing::Values(T{-2.2, 3.3}, T{0.0, 0.0}, + T{1.0, 0.0}, T{-1.0, 0.0}, + T{4.1, 0.0}, T{0.0, 1.0}), // alpha + ::testing::Values(T{1.2, -2.3}, T{0.0, 0.0}, + T{1.0, 0.0}, T{-1.0, 0.0}, + T{4.3, 0.0}, T{0.0, 1.0}), // beta + ::testing::Values(gtint_t(0)), // increment to the leading dim of a + ::testing::Values(gtint_t(0)), // increment to the leading dim of b + ::testing::Values(gtint_t(0)) // increment to the leading dim of c + ), + ::gemmEVTPrint() + ); + +/********************************************************/ +/* Testing for all code paths */ +/* m,n,k is choosen such that all code path are covered */ +/* Matrix A, B, C are filled valid integers or floats */ +/* Matrix A, B, C are filled with Infs and Nans */ +/********************************************************/ +INSTANTIATE_TEST_SUITE_P( + alpha_beta, + zgemmEVT, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n', 't', 'c'), // transa + ::testing::Values('n', 't', 'c'), // transb + ::testing::Values(gtint_t(14), gtint_t(200)), // m + ::testing::Values(gtint_t(10), gtint_t(300)), // n + ::testing::Values(gtint_t(20), gtint_t(1005)), // k + ::testing::Values(gtint_t(0)), // ai + ::testing::Values(gtint_t(0)), // aj + ::testing::Values(T{0.0, 0.0}), + ::testing::Values(gtint_t(0)), // bi + ::testing::Values(gtint_t(0)), // bj + ::testing::Values(T{0.0, 0.0}), + ::testing::Values(gtint_t(0)), // ci + ::testing::Values(gtint_t(0)), // cj + ::testing::Values(T{0.0, 0.0}), + ::testing::Values(T{AOCL_NAN, 2.3}, /* T{AOCL_INF, 0.0}, */ + T{3.4, AOCL_NAN}, T{AOCL_NAN, -AOCL_INF}), // alpha + ::testing::Values(T{AOCL_NAN, 2.3}, /* T{AOCL_INF, 0.0}, */ + T{3.4, AOCL_NAN}, T{AOCL_NAN, -AOCL_INF}), // beta + ::testing::Values(gtint_t(0)), // increment to the leading dim of a + ::testing::Values(gtint_t(0)), // increment to the leading dim of b + ::testing::Values(gtint_t(0)) // increment to the leading dim of c + ), + ::gemmEVTPrint() + ); diff --git a/gtestsuite/testsuite/level3/gemm/zgemm/zgemm_generic.cpp b/gtestsuite/testsuite/level3/gemm/zgemm/zgemm_generic.cpp new file mode 100644 index 0000000000..5f16f11b4c --- /dev/null +++ b/gtestsuite/testsuite/level3/gemm/zgemm/zgemm_generic.cpp @@ -0,0 +1,356 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "level3/gemm/test_gemm.h" + +class zgemmGeneric : + public ::testing::TestWithParam> {}; + +TEST_P( zgemmGeneric, API ) +{ + using T = dcomplex; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // matrix storage format(row major, column major) + char storage = std::get<0>(GetParam()); + // denotes whether matrix a is n,c,t,h + char transa = std::get<1>(GetParam()); + // denotes whether matrix b is n,c,t,h + char transb = std::get<2>(GetParam()); + // matrix size m + gtint_t m = std::get<3>(GetParam()); + // matrix size n + gtint_t n = std::get<4>(GetParam()); + // matrix size k + gtint_t k = std::get<5>(GetParam()); + // specifies alpha value + T alpha = std::get<6>(GetParam()); + // specifies beta value + T beta = std::get<7>(GetParam()); + // lda, ldb, ldc increments. + // If increments are zero, then the array size matches the matrix size. + // If increments are nonnegative, the array size is bigger than the matrix size. + gtint_t lda_inc = std::get<8>(GetParam()); + gtint_t ldb_inc = std::get<9>(GetParam()); + gtint_t ldc_inc = std::get<10>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite gemm.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (m == 0 || n == 0) + thresh = 0.0; + else if ((alpha == testinghelpers::ZERO() || k == 0) && + (beta == testinghelpers::ZERO() || beta == testinghelpers::ONE())) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + thresh = testinghelpers::getEpsilon(); + else + thresh = (7*k+3)*testinghelpers::getEpsilon(); + //thresh = (15*k+1)*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_gemm( storage, transa, transb, m, n, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh ); +} + +/********************************************************************/ +/* Blas interface testing as per the code sequence */ +/* Below API's will be invoked if input condition is satisified */ +/* List of API's - Input conditions */ +/* SCALM : alpha = 0 */ +/* GEMV : m = 1 or n = 1 */ +/* K1 : k = 1 & tranaA = 'n' & transB = 'n; */ +/* Small ST : ((m0*k0) <= 16384) || ((n0*k0) <= 16384))) */ +/* SUP AVX2 : (m & n & k) <= 128 */ +/* SUP AVX512 : (m & k) <= 128 & n <= 110 */ +/* Native : Default path, */ +/* : when none of the above API's are invoked */ +/********************************************************************/ +INSTANTIATE_TEST_SUITE_P( + SCALM, + zgemmGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n','c','t'), // transa + ::testing::Values('n','c','t'), // transb + ::testing::Values(gtint_t(10)), // m + ::testing::Values(gtint_t(10)), // n + ::testing::Values(gtint_t(10)), // k + ::testing::Values(dcomplex{0.0, 0.0}), // alpha + ::testing::Values(dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, + dcomplex{0.0, 1.0}, dcomplex{3.1, 15.9}, + dcomplex{0.0, 0.0}), //beta + ::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of a + ::testing::Values(gtint_t(0), gtint_t(1)), // increment to the leading dim of b + ::testing::Values(gtint_t(0), gtint_t(5)) // increment to the leading dim of c + ), + ::gemmGenericPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + GEMV_M1_N1, + zgemmGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n', 'c', 't'), // transa + ::testing::Values('n', 'c', 't'), // transb + ::testing::Values(gtint_t(1)), // m + ::testing::Values(gtint_t(1)), // n + ::testing::Values(gtint_t(100), gtint_t(200)), // k + ::testing::Values(dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, + dcomplex{0.0, 1.0}, dcomplex{2.1, -1.9}, + dcomplex{0.0, 0.0}), // alpha + ::testing::Values(dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, + dcomplex{0.0, 1.0}, dcomplex{2.1, -1.9}, + dcomplex{0.0, 0.0}), // beta + ::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of a + ::testing::Values(gtint_t(0), gtint_t(3)), // increment to the leading dim of b + ::testing::Values(gtint_t(0), gtint_t(5)) // increment to the leading dim of c + ), + ::gemmGenericPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + GEMV_M1, + zgemmGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n', 'c', 't'), // transa + ::testing::Values('n', 'c', 't'), // transb + ::testing::Values(gtint_t(1)), // m + ::testing::Values(gtint_t(2), gtint_t(89), gtint_t(197)), // n + ::testing::Values(gtint_t(100), gtint_t(200)), // k + ::testing::Values(dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, + dcomplex{0.0, 1.0}, dcomplex{2.1, -1.9}, + dcomplex{0.0, 0.0}), // alpha + ::testing::Values(dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, + dcomplex{0.0, 1.0}, dcomplex{2.1, -1.9}, + dcomplex{0.0, 0.0}), // beta + ::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of a + ::testing::Values(gtint_t(0), gtint_t(3)), // increment to the leading dim of b + ::testing::Values(gtint_t(0), gtint_t(5)) // increment to the leading dim of c + ), + ::gemmGenericPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + GEMV_N1, + zgemmGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n', 'c', 't'), // transa + ::testing::Values('n', 'c', 't'), // transb + ::testing::Values(gtint_t(1), gtint_t(100), gtint_t(47)), // m + ::testing::Values(gtint_t(1)), // n + ::testing::Values(gtint_t(100), gtint_t(200)), // k + ::testing::Values(dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, + dcomplex{0.0, 1.0}, dcomplex{3.1, -1.5}, + dcomplex{0.0, 0.0}), // alpha + ::testing::Values(dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, + dcomplex{0.0, 1.0}, dcomplex{2.3, -2.9}, + dcomplex{0.0, 0.0}), // beta + ::testing::Values(gtint_t(0), gtint_t(3)), // increment to the leading dim of a + ::testing::Values(gtint_t(0), gtint_t(1)), // increment to the leading dim of b + ::testing::Values(gtint_t(0), gtint_t(7)) // increment to the leading dim of c + ), + ::gemmGenericPrint() + ); + +// Unit testing for bli_zgemm_4x4_avx2_k1_nn kernel +/* From the BLAS layer(post parameter checking), the inputs will be redirected to this kernel + if m != 1, n !=1 and k == 1 */ + +INSTANTIATE_TEST_SUITE_P( + K_1, + zgemmGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n'), // transa + ::testing::Values('n'), // transb + ::testing::Values(gtint_t(2), gtint_t(9), gtint_t(16)), // m + ::testing::Values(gtint_t(2), gtint_t(7)), // n + ::testing::Values(gtint_t(1)), // k + ::testing::Values(dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, + dcomplex{0.0, 1.0}, dcomplex{2.1, -1.9}, + dcomplex{0.0, 0.0}), // alpha + ::testing::Values(dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, + dcomplex{0.0, 1.0}, dcomplex{2.1, -1.9}, + dcomplex{0.0, 0.0}), // beta + ::testing::Values(gtint_t(0), gtint_t(5)), // increment to the leading dim of a + ::testing::Values(gtint_t(0), gtint_t(9)), // increment to the leading dim of b + ::testing::Values(gtint_t(0), gtint_t(2)) // increment to the leading dim of c + ), + ::gemmGenericPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + SMALL_Matrix_ST, + zgemmGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n', 'c', 't'), // transa + ::testing::Values('n', 'c', 't'), // transb + ::testing::Values(gtint_t(2), gtint_t(3), gtint_t(7), gtint_t(8)), // m + ::testing::Values(gtint_t(2), gtint_t(3), gtint_t(7), gtint_t(8)), // n + ::testing::Values(gtint_t(2), gtint_t(4), gtint_t(10)), // k + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0}, dcomplex{0, 1.0}, dcomplex{-1.0, -2.0}), // alpha + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0}, dcomplex{0, 1.0}, dcomplex{1.0, 2.0}), // beta + ::testing::Values(gtint_t(0), gtint_t(1)), // increment to the leading dim of a + ::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of b + ::testing::Values(gtint_t(0), gtint_t(3)) // increment to the leading dim of c + ), + ::gemmGenericPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + Skinny_Matrix_Trans_N, + zgemmGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n'), // transa + ::testing::Values('n'), // transb + ::testing::Values(gtint_t(100), gtint_t(105)), // m + ::testing::Values(gtint_t(80), gtint_t(85)), // n + ::testing::Values(gtint_t(1000), gtint_t(1010)), // k + ::testing::Values(dcomplex{-1.0, -2.0}, dcomplex{0.0, -30.0}, + dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, + dcomplex{5.0, 0.0}), // alpha + ::testing::Values(dcomplex{12.0, 2.3}, dcomplex{0.0, 1.3}, + dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, + dcomplex{5.0, 0.0}), // beta + ::testing::Values(gtint_t(540)), // increment to the leading dim of a + ::testing::Values(gtint_t(940)), // increment to the leading dim of b + ::testing::Values(gtint_t(240)) // increment to the leading dim of c + ), + ::gemmGenericPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + SKinny_Matrix_Trans_T, + zgemmGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('t'), // transa + ::testing::Values('t'), // transb + ::testing::Values(gtint_t(105)), // m + ::testing::Values(gtint_t(190)), // n + ::testing::Values(gtint_t(500)), // k + ::testing::Values(dcomplex{-1.8, -21.0}, dcomplex{0.0, -33.0}, + dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, + dcomplex{5.3, 0.0}), // alpha + ::testing::Values(dcomplex{1.8, 9.3}, dcomplex{0.0, 3.3}, + dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, + dcomplex{2.91, 0.0}, dcomplex{0.0, 0.0}), // beta + ::testing::Values(gtint_t(0)), // increment to the leading dim of a + ::testing::Values(gtint_t(0)), // increment to the leading dim of b + ::testing::Values(gtint_t(0)) // increment to the leading dim of c + ), + ::gemmGenericPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + Large_Matrix_Trans_N_C_T, + zgemmGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('n', 'c', 't'), // transa + ::testing::Values('n', 'c', 't'), // transb + ::testing::Values(gtint_t(200)), // m + ::testing::Values(gtint_t(180)), // n + ::testing::Values(gtint_t(170)), // k + ::testing::Values(dcomplex{1.5, 3.5}, dcomplex{0.0, -10.0}, + dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, + dcomplex{2.0, 0.0}), // alpha + ::testing::Values(dcomplex{2.0, 4.1}, dcomplex{0.0, 3.4}, + dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, + dcomplex{3.3, 0.0}, dcomplex{0.0, 0.0}), // beta + ::testing::Values(gtint_t(0), gtint_t(300)), // increment to the leading dim of a + ::testing::Values(gtint_t(0), gtint_t(200)), // increment to the leading dim of b + ::testing::Values(gtint_t(0), gtint_t(500)) // increment to the leading dim of c + ), + ::gemmGenericPrint() + ); diff --git a/gtestsuite/testsuite/level3/gemm/zgemm_evt_testing.cpp b/gtestsuite/testsuite/level3/gemm/zgemm_evt_testing.cpp deleted file mode 100644 index 3b0f05ab9b..0000000000 --- a/gtestsuite/testsuite/level3/gemm/zgemm_evt_testing.cpp +++ /dev/null @@ -1,356 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -/* - The following file contains both the exception value testing(EVT) and the - positive accuracy testing of the bli_zgemm_4x4_avx2_k1_nn( ... ) computational - kernel. This kernel is invoked from the BLAS layer, and inputs are given - in a manner so as to avoid the other code-paths and test only the required - kernel. - -*/ - -#include -#include "test_gemm.h" - -class ZGemmEVTTest : - public ::testing::TestWithParam> {}; - -TEST_P(ZGemmEVTTest, Unit_Tester) -{ - using T = dcomplex; - //---------------------------------------------------------- - // Initialize values from the parameters passed through - // test suite instantiation (INSTANTIATE_TEST_SUITE_P). - //---------------------------------------------------------- - // matrix storage format(row major, column major) - char storage = std::get<0>(GetParam()); - // denotes whether matrix a is n,c,t,h - char transa = std::get<1>(GetParam()); - // denotes whether matrix b is n,c,t,h - char transb = std::get<2>(GetParam()); - // matrix size m - gtint_t m = std::get<3>(GetParam()); - // matrix size n - gtint_t n = std::get<4>(GetParam()); - // matrix size k - gtint_t k = std::get<5>(GetParam()); - - gtint_t ai, aj, bi, bj, ci, cj; - T aex, bex, cex; - ai = std::get<6>(GetParam()); - aj = std::get<7>(GetParam()); - aex = std::get<8>(GetParam()); - - bi = std::get<9>(GetParam()); - bj = std::get<10>(GetParam()); - bex = std::get<11>(GetParam()); - - ci = std::get<12>(GetParam()); - cj = std::get<13>(GetParam()); - cex = std::get<14>(GetParam()); - - // specifies alpha value - T alpha = std::get<15>(GetParam()); - // specifies beta value - T beta = std::get<16>(GetParam()); - // lda, ldb, ldc increments. - // If increments are zero, then the array size matches the matrix size. - // If increments are nonnegative, the array size is bigger than the matrix size. - gtint_t lda_inc = std::get<17>(GetParam()); - gtint_t ldb_inc = std::get<18>(GetParam()); - gtint_t ldc_inc = std::get<19>(GetParam()); - - // Set the threshold for the errors: - double thresh = 10*m*n*testinghelpers::getEpsilon(); - - //---------------------------------------------------------- - // Call test body using these parameters - //---------------------------------------------------------- - test_gemm( storage, transa, transb, m, n, k, lda_inc, ldb_inc, ldc_inc, - alpha, beta, ai, aj, aex, bi, bj, bex, ci, cj, cex, thresh ); -} - -// Helper classes for printing the test case parameters based on the instantiator -// These are mainly used to help with debugging, in case of failures - -// Utility to print the test-case in case of exception value on matrices -class ZGemmEVMatPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char tsa = std::get<1>(str.param); - char tsb = std::get<2>(str.param); - gtint_t m = std::get<3>(str.param); - gtint_t n = std::get<4>(str.param); - gtint_t k = std::get<5>(str.param); - gtint_t ai, aj, bi, bj, ci, cj; - dcomplex aex, bex, cex; - ai = std::get<6>(str.param); - aj = std::get<7>(str.param); - aex = std::get<8>(str.param); - - bi = std::get<9>(str.param); - bj = std::get<10>(str.param); - bex = std::get<11>(str.param); - - ci = std::get<12>(str.param); - cj = std::get<13>(str.param); - cex = std::get<14>(str.param); - - dcomplex alpha = std::get<15>(str.param); - dcomplex beta = std::get<16>(str.param); - gtint_t lda_inc = std::get<17>(str.param); - gtint_t ldb_inc = std::get<18>(str.param); - gtint_t ldc_inc = std::get<19>(str.param); - -#ifdef TEST_BLAS - std::string str_name = "zgemm_"; -#elif TEST_CBLAS - std::string str_name = "cblas_zgemm"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "blis_zgemm"; -#endif - str_name = str_name + "_" + sfm+sfm+sfm; - str_name = str_name + "_" + tsa + tsb; - str_name = str_name + "_" + std::to_string(m); - str_name = str_name + "_" + std::to_string(n); - str_name = str_name + "_" + std::to_string(k); - str_name = str_name + "_A" + std::to_string(ai) + std::to_string(aj); - str_name = str_name + "_" + testinghelpers::get_value_string(aex); - str_name = str_name + "_B" + std::to_string(bi) + std::to_string(bj); - str_name = str_name + "_" + testinghelpers::get_value_string(bex); - str_name = str_name + "_C" + std::to_string(ci) + std::to_string(cj); - str_name = str_name + "_" + testinghelpers::get_value_string(cex); - str_name = str_name + "_a" + testinghelpers::get_value_string(alpha); - str_name = str_name + "_b" + testinghelpers::get_value_string(beta); - str_name = str_name + "_" + std::to_string(lda_inc); - str_name = str_name + "_" + std::to_string(ldb_inc); - str_name = str_name + "_" + std::to_string(ldc_inc); - return str_name; - } -}; - -// Utility to print the test-case in case of exception value on matrices -class ZGemmEVAlphaBetaPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char tsa = std::get<1>(str.param); - char tsb = std::get<2>(str.param); - gtint_t m = std::get<3>(str.param); - gtint_t n = std::get<4>(str.param); - gtint_t k = std::get<5>(str.param); - - dcomplex alpha = std::get<15>(str.param); - dcomplex beta = std::get<16>(str.param); - gtint_t lda_inc = std::get<17>(str.param); - gtint_t ldb_inc = std::get<18>(str.param); - gtint_t ldc_inc = std::get<19>(str.param); - -#ifdef TEST_BLAS - std::string str_name = "zgemm_"; -#elif TEST_CBLAS - std::string str_name = "cblas_zgemm"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "blis_zgemm"; -#endif - str_name = str_name + "_" + sfm+sfm+sfm; - str_name = str_name + "_" + tsa + tsb; - str_name = str_name + "_" + std::to_string(m); - str_name = str_name + "_" + std::to_string(n); - str_name = str_name + "_" + std::to_string(k); - str_name = str_name + "_a" + testinghelpers::get_value_string(alpha); - str_name = str_name + "_b" + testinghelpers::get_value_string(beta); - str_name = str_name + "_" + std::to_string(lda_inc); - str_name = str_name + "_" + std::to_string(ldb_inc); - str_name = str_name + "_" + std::to_string(ldc_inc); - return str_name; - } -}; - -static double NaN = std::numeric_limits::quiet_NaN(); -static double Inf = std::numeric_limits::infinity(); - -// Exception value testing(on matrices) - -/* - For the bli_zgemm_4x4_avx2_k1_nn kernel, the main and fringe dimensions are as follows: - For m : Main = { 4 }, fringe = { 2, 1 } - For n : Main = { 4 }, fringe = { 2, 1 } - - Without any changes to the BLAS layer in BLIS, the fringe case of 1 cannot be touched - separately, since if m/n is 1, the inputs are redirected to ZGEMV. - -*/ - -// Testing for the main loop case for m and n -// The kernel uses 2 loads and 4 broadcasts. The exception values -// are induced at one index individually for each of the loads. -// They are also induced in the broadcast direction at two places. -INSTANTIATE_TEST_SUITE_P( - bli_zgemm_4x4_avx2_k1_nn_evt_mat_main, - ZGemmEVTTest, - ::testing::Combine( - ::testing::Values('c' -#ifndef TEST_BLAS - ,'r' -#endif - ), // storage format - ::testing::Values('n'), // transa - ::testing::Values('n'), // transb - ::testing::Values(gtint_t(4)), // m - ::testing::Values(gtint_t(4)), // n - ::testing::Values(gtint_t(1)), // k - ::testing::Values(gtint_t(1), gtint_t(3)), // ai - ::testing::Values(gtint_t(0)), // aj - ::testing::Values(dcomplex{NaN, 2.3}, dcomplex{Inf, 0.0}, - dcomplex{3.4, NaN}, dcomplex{NaN, -Inf}), // aexval - ::testing::Values(gtint_t(0)), // bi - ::testing::Values(gtint_t(0), gtint_t(2)), // bj - ::testing::Values(dcomplex{NaN, 2.3}, dcomplex{Inf, 0.0}, - dcomplex{3.4, NaN}, dcomplex{NaN, -Inf}), // bexval - ::testing::Values(gtint_t(0), gtint_t(2)), // ci - ::testing::Values(gtint_t(1), gtint_t(3)), // cj - ::testing::Values(dcomplex{NaN, 2.3}, dcomplex{Inf, 0.0}, - dcomplex{3.4, NaN}, dcomplex{NaN, -Inf}), // cexval - ::testing::Values(dcomplex{-2.2, 3.3}), // alpha - ::testing::Values(dcomplex{1.2, -2.3}), // beta - ::testing::Values(gtint_t(0)), // increment to the leading dim of a - ::testing::Values(gtint_t(0)), // increment to the leading dim of b - ::testing::Values(gtint_t(0)) // increment to the leading dim of c - ), - ::ZGemmEVMatPrint() - ); - -// Testing the fringe cases -// Fringe case minimum size is 2 along both m and n. -// Invloves only one load(AVX2 or (AVX2+SSE)). Thus, -// the exception values are induced at the first and second indices of the -// column vector A and row vector B. -INSTANTIATE_TEST_SUITE_P( - bli_zgemm_4x4_avx2_k1_nn_evt_mat_fringe, - ZGemmEVTTest, - ::testing::Combine( - ::testing::Values('c' -#ifndef TEST_BLAS - ,'r' -#endif - ), // storage format - ::testing::Values('n'), // transa - ::testing::Values('n'), // transb - ::testing::Values(gtint_t(2), gtint_t(3)), // m - ::testing::Values(gtint_t(2), gtint_t(3)), // n - ::testing::Values(gtint_t(1)), // k - ::testing::Values(gtint_t(0), gtint_t(1)), // ai - ::testing::Values(gtint_t(0)), // aj - ::testing::Values(dcomplex{NaN, 2.3}, dcomplex{Inf, 0.0}, - dcomplex{3.4, NaN}, dcomplex{NaN, -Inf}), // aexval - ::testing::Values(gtint_t(0)), // bi - ::testing::Values(gtint_t(0), gtint_t(1)), // bj - ::testing::Values(dcomplex{NaN, 2.3}, dcomplex{Inf, 0.0}, - dcomplex{3.4, NaN}, dcomplex{NaN, -Inf}), // bexval - ::testing::Values(gtint_t(0), gtint_t(1)), // ci - ::testing::Values(gtint_t(0), gtint_t(1)), // cj - ::testing::Values(dcomplex{NaN, 2.3}, dcomplex{Inf, 0.0}, - dcomplex{3.4, NaN}, dcomplex{NaN, -Inf}), // cexval - ::testing::Values(dcomplex{-2.2, 3.3}), // alpha - ::testing::Values(dcomplex{1.2, -2.3}), // beta - ::testing::Values(gtint_t(0)), // increment to the leading dim of a - ::testing::Values(gtint_t(0)), // increment to the leading dim of b - ::testing::Values(gtint_t(0)) // increment to the leading dim of c - ), - ::ZGemmEVMatPrint() - ); - -// Exception value testing(on alpha and beta) -// Alpha and beta are set to exception values -INSTANTIATE_TEST_SUITE_P( - bli_zgemm_4x4_avx2_k1_nn_evt_alphabeta, - ZGemmEVTTest, - ::testing::Combine( - ::testing::Values('c' -#ifndef TEST_BLAS - ,'r' -#endif - ), // storage format - ::testing::Values('n'), // transa - ::testing::Values('n'), // transb - ::testing::Values(gtint_t(2), gtint_t(3), gtint_t(4)), // m - ::testing::Values(gtint_t(2), gtint_t(3), gtint_t(4)), // n - ::testing::Values(gtint_t(1)), // k - ::testing::Values(gtint_t(0)), // ai - ::testing::Values(gtint_t(0)), // aj - ::testing::Values(dcomplex{0.0, 0.0}), - ::testing::Values(gtint_t(0)), // bi - ::testing::Values(gtint_t(0)), // bj - ::testing::Values(dcomplex{0.0, 0.0}), - ::testing::Values(gtint_t(0)), // ci - ::testing::Values(gtint_t(0)), // cj - ::testing::Values(dcomplex{0.0, 0.0}), - ::testing::Values(dcomplex{NaN, 2.3}, dcomplex{Inf, 0.0}, - dcomplex{3.4, NaN}, dcomplex{NaN, -Inf}), // alpha - ::testing::Values(dcomplex{NaN, 2.3}, dcomplex{Inf, 0.0}, - dcomplex{3.4, NaN}, dcomplex{NaN, -Inf}), // beta - ::testing::Values(gtint_t(0)), // increment to the leading dim of a - ::testing::Values(gtint_t(0)), // increment to the leading dim of b - ::testing::Values(gtint_t(0)) // increment to the leading dim of c - ), - ::ZGemmEVAlphaBetaPrint() - ); diff --git a/gtestsuite/testsuite/level3/gemm/zgemm_generic.cpp b/gtestsuite/testsuite/level3/gemm/zgemm_generic.cpp deleted file mode 100644 index 6bdb2d63e8..0000000000 --- a/gtestsuite/testsuite/level3/gemm/zgemm_generic.cpp +++ /dev/null @@ -1,179 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include "test_gemm.h" - -class ZGemmAccTest : - public ::testing::TestWithParam> {}; - -TEST_P(ZGemmAccTest, Unit_Tester) -{ - using T = dcomplex; - //---------------------------------------------------------- - // Initialize values from the parameters passed through - // test suite instantiation (INSTANTIATE_TEST_SUITE_P). - //---------------------------------------------------------- - // matrix storage format(row major, column major) - char storage = std::get<0>(GetParam()); - // denotes whether matrix a is n,c,t,h - char transa = std::get<1>(GetParam()); - // denotes whether matrix b is n,c,t,h - char transb = std::get<2>(GetParam()); - // matrix size m - gtint_t m = std::get<3>(GetParam()); - // matrix size n - gtint_t n = std::get<4>(GetParam()); - // matrix size k - gtint_t k = std::get<5>(GetParam()); - // specifies alpha value - T alpha = std::get<6>(GetParam()); - // specifies beta value - T beta = std::get<7>(GetParam()); - // lda, ldb, ldc increments. - // If increments are zero, then the array size matches the matrix size. - // If increments are nonnegative, the array size is bigger than the matrix size. - gtint_t lda_inc = std::get<8>(GetParam()); - gtint_t ldb_inc = std::get<9>(GetParam()); - gtint_t ldc_inc = std::get<10>(GetParam()); - - // Set the threshold for the errors: - double thresh = 10*m*n*testinghelpers::getEpsilon(); - - //---------------------------------------------------------- - // Call test body using these parameters - //---------------------------------------------------------- - test_gemm( storage, transa, transb, m, n, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh ); -} - -class ZGemmAccPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char tsa = std::get<1>(str.param); - char tsb = std::get<2>(str.param); - gtint_t m = std::get<3>(str.param); - gtint_t n = std::get<4>(str.param); - gtint_t k = std::get<5>(str.param); - dcomplex alpha = std::get<6>(str.param); - dcomplex beta = std::get<7>(str.param); - gtint_t lda_inc = std::get<8>(str.param); - gtint_t ldb_inc = std::get<9>(str.param); - gtint_t ldc_inc = std::get<10>(str.param); -#ifdef TEST_BLAS - std::string str_name = "zgemm_"; -#elif TEST_CBLAS - std::string str_name = "cblas_zgemm"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_zgemm"; -#endif - str_name = str_name + "_" + sfm+sfm+sfm; - str_name = str_name + "_" + tsa + tsb; - str_name = str_name + "_" + std::to_string(m); - str_name = str_name + "_" + std::to_string(n); - str_name = str_name + "_" + std::to_string(k); - str_name = str_name + "_a" + testinghelpers::get_value_string(alpha);; - str_name = str_name + "_b" + testinghelpers::get_value_string(beta);; - str_name = str_name + "_" + std::to_string(lda_inc); - str_name = str_name + "_" + std::to_string(ldb_inc); - str_name = str_name + "_" + std::to_string(ldc_inc); - return str_name; - } -}; - -// Unit testing for bli_zgemm_4x4_avx2_k1_nn kernel -/* From the BLAS layer(post parameter checking), the inputs will be redirected to this kernel - if m != 1, n !=1 and k == 1 */ - -INSTANTIATE_TEST_SUITE_P( - bli_zgemm_4x4_avx2_k1_nn, - ZGemmAccTest, - ::testing::Combine( - ::testing::Values('c' -#ifndef TEST_BLAS - ,'r' -#endif - ), // storage format - ::testing::Values('n'), // transa - ::testing::Values('n'), // transb - ::testing::Range(gtint_t(2), gtint_t(8), 1), // m - ::testing::Range(gtint_t(2), gtint_t(8), 1), // n - ::testing::Values(gtint_t(1)), // k - ::testing::Values(dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, - dcomplex{0.0, 1.0}, dcomplex{2.1, -1.9}, - dcomplex{0.0, 0.0}), // alpha - ::testing::Values(dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, - dcomplex{0.0, 1.0}, dcomplex{2.1, -1.9}, - dcomplex{0.0, 0.0}), // beta - ::testing::Values(gtint_t(0), gtint_t(3)), // increment to the leading dim of a - ::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of b - ::testing::Values(gtint_t(0), gtint_t(5)) // increment to the leading dim of c - ), - ::ZGemmAccPrint() - ); - -// Black box testing. -INSTANTIATE_TEST_SUITE_P( - Blackbox, - ZGemmAccTest, - ::testing::Combine( - ::testing::Values('c' -#ifndef TEST_BLAS - ,'r' -#endif - ), // storage format - ::testing::Values('n','c','t'), // transa - ::testing::Values('n','c','t'), // transb - ::testing::Range(gtint_t(10), gtint_t(31), 10), // m - ::testing::Range(gtint_t(10), gtint_t(31), 10), // n - ::testing::Range(gtint_t(10), gtint_t(31), 10), // k - ::testing::Values(dcomplex{2.0,-1.0}), // alpha - ::testing::Values(dcomplex{1.0,2.0}), // beta - ::testing::Values(gtint_t(0), gtint_t(3)), // increment to the leading dim of a - ::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of b - ::testing::Values(gtint_t(0), gtint_t(5)) // increment to the leading dim of c - ), - ::ZGemmAccPrint() - ); diff --git a/gtestsuite/testsuite/level3/gemm_compute/dgemm_compute_generic.cpp b/gtestsuite/testsuite/level3/gemm_compute/dgemm_compute_generic.cpp index a648f53bc1..7232fa4eb9 100644 --- a/gtestsuite/testsuite/level3/gemm_compute/dgemm_compute_generic.cpp +++ b/gtestsuite/testsuite/level3/gemm_compute/dgemm_compute_generic.cpp @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -35,7 +35,7 @@ #include #include "test_gemm_compute.h" -class DGemmComputeTest : +class dgemmComputeGeneric : public ::testing::TestWithParam> {}; -TEST_P(DGemmComputeTest, RandomData) +TEST_P( dgemmComputeGeneric, API ) { using T = double; //---------------------------------------------------------- @@ -85,63 +85,32 @@ TEST_P(DGemmComputeTest, RandomData) gtint_t ldc_inc = std::get<12>(GetParam()); // Set the threshold for the errors: - double intermediate = (double)m*n*k; - double thresh = 10*intermediate*testinghelpers::getEpsilon(); + // Check gtestsuite gemm.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (m == 0 || n == 0) + thresh = 0.0; + else if ((alpha == testinghelpers::ZERO() || k == 0) && + (beta == testinghelpers::ZERO() || beta == testinghelpers::ONE())) + thresh = 0.0; + else + thresh = (3*k+1)*testinghelpers::getEpsilon(); + //thresh = (7*k+1)*testinghelpers::getEpsilon(); + //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- test_gemm_compute( storage, transa, transb, packa, packb, m, n, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh ); } -class DGemmComputeTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char tsa = std::get<1>(str.param); - char tsb = std::get<2>(str.param); - char pka = std::get<3>(str.param); - char pkb = std::get<4>(str.param); - gtint_t m = std::get<5>(str.param); - gtint_t n = std::get<6>(str.param); - gtint_t k = std::get<7>(str.param); - double alpha = std::get<8>(str.param); - double beta = std::get<9>(str.param); - gtint_t lda_inc = std::get<10>(str.param); - gtint_t ldb_inc = std::get<11>(str.param); - gtint_t ldc_inc = std::get<12>(str.param); -#ifdef TEST_BLAS - std::string str_name = "dgemm_compute_"; -#elif TEST_CBLAS - std::string str_name = "cblas_dgemm_compute"; -#else //#elif TEST_BLIS_TYPED - // BLIS interface not yet implemented for pack and compute APIs. - std::string str_name = "blis_dgemm_compute"; -#endif - str_name = str_name + "_" + sfm+sfm+sfm; - str_name = str_name + "_" + tsa + tsb; - str_name = str_name + "_" + pka + pkb; - str_name = str_name + "_" + std::to_string(m); - str_name = str_name + "_" + std::to_string(n); - str_name = str_name + "_" + std::to_string(k); - std::string alpha_str = ( alpha > 0) ? std::to_string(int(alpha)) : "m" + std::to_string(int(std::abs(alpha))); - str_name = str_name + "_a" + alpha_str; - std::string beta_str = ( beta > 0) ? std::to_string(int(beta)) : "m" + std::to_string(int(std::abs(beta))); - str_name = str_name + "_b" + beta_str; - str_name = str_name + "_" + std::to_string(lda_inc); - str_name = str_name + "_" + std::to_string(ldb_inc); - str_name = str_name + "_" + std::to_string(ldc_inc); - return str_name; - } -}; - // Black box testing. INSTANTIATE_TEST_SUITE_P( Blackbox, - DGemmComputeTest, + dgemmComputeGeneric, ::testing::Combine( ::testing::Values('c' -#ifndef TEST_BLAS +#ifndef TEST_BLAS_LIKE ,'r' #endif ), // storage format @@ -158,15 +127,15 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(0)), // increment to the leading dim of b ::testing::Values(gtint_t(0)) // increment to the leading dim of c ), - ::DGemmComputeTestPrint() + ::gemm_computeGeneticPrint() ); INSTANTIATE_TEST_SUITE_P( TinySizes, - DGemmComputeTest, + dgemmComputeGeneric, ::testing::Combine( ::testing::Values('c' -#ifndef TEST_BLAS +#ifndef TEST_BLAS_LIKE ,'r' #endif ), // storage format @@ -183,15 +152,15 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(0)), // increment to the leading dim of b ::testing::Values(gtint_t(0)) // increment to the leading dim of c ), - ::DGemmComputeTestPrint() + ::gemm_computeGeneticPrint() ); INSTANTIATE_TEST_SUITE_P( DimensionsGtBlocksizes, // Dimensions > SUP Blocksizes - DGemmComputeTest, + dgemmComputeGeneric, ::testing::Combine( ::testing::Values('c' -#ifndef TEST_BLAS +#ifndef TEST_BLAS_LIKE ,'r' #endif ), // storage format @@ -208,5 +177,5 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(0)), // increment to the leading dim of b ::testing::Values(gtint_t(0)) // increment to the leading dim of c ), - ::DGemmComputeTestPrint() + ::gemm_computeGeneticPrint() ); diff --git a/gtestsuite/testsuite/level3/gemm_compute/gemm_compute.h b/gtestsuite/testsuite/level3/gemm_compute/gemm_compute.h index 1d168df634..41bdb0aec5 100644 --- a/gtestsuite/testsuite/level3/gemm_compute/gemm_compute.h +++ b/gtestsuite/testsuite/level3/gemm_compute/gemm_compute.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -36,6 +36,7 @@ #include "blis.h" #include "common/testing_helpers.h" +#include "inc/check_error.h" /** * @brief Performs the operation: @@ -67,6 +68,7 @@ * @param[in] ldc specifies the leading dimension of cp. */ +#ifdef TEST_BLAS template static void gemm_compute_(char transa, char transb, char packa, char packb, gtint_t m, gtint_t n, gtint_t k, T* alpha, T* ap, gtint_t lda, T* bp, gtint_t ldb, T* beta, T* cp, gtint_t ldc ) @@ -210,6 +212,7 @@ static void gemm_compute_(char transa, char transb, char packa, char packb, gtin bBuffer ); dgemm_compute_( &packa, &packb, &m, &n, &k, aBuffer, &lda, bBuffer, &ldb, beta, cp, &ldc ); + bli_free_user( aBuffer ); bli_free_user( bBuffer ); } @@ -267,6 +270,209 @@ static void gemm_compute_(char transa, char transb, char packa, char packb, gtin else throw std::runtime_error("Error in testsuite/level3/gemm.h: Invalid typename in gemm_compute_()."); } +#endif + +template +static void gemm_compute_blis_impl(char transa, char transb, char packa, char packb, gtint_t m, gtint_t n, gtint_t k, T* alpha, + T* ap, gtint_t lda, T* bp, gtint_t ldb, T* beta, T* cp, gtint_t ldc ) +{ + T unit_alpha = 1.0; + err_t err = BLIS_SUCCESS; + if constexpr (std::is_same::value) + { + if ( ( packa == 'P' || packa == 'p' ) && ( packb == 'P' || packb == 'p' ) ) + { + // Reorder A + char identifierA = 'A'; + gtint_t bufSizeA = sgemm_pack_get_size_blis_impl( &identifierA, + &m, + &n, + &k ); + + float* aBuffer = (float*) bli_malloc_user( bufSizeA, &err ); + sgemm_pack_blis_impl( &identifierA, + &transa, + &m, + &n, + &k, + &unit_alpha, + ap, + &lda, + aBuffer ); + + // Reorder B + char identifierB = 'B'; + gtint_t bufSizeB = sgemm_pack_get_size_blis_impl( &identifierB, + &m, + &n, + &k ); + + float* bBuffer = (float*) bli_malloc_user( bufSizeB, &err ); + sgemm_pack_blis_impl( &identifierB, + &transb, + &m, + &n, + &k, + alpha, + bp, + &ldb, + bBuffer ); + + sgemm_compute_blis_impl( &packa, &packb, &m, &n, &k, aBuffer, &lda, bBuffer, &ldb, beta, cp, &ldc ); + + bli_free_user( aBuffer ); + bli_free_user( bBuffer ); + } + else if ( ( packa == 'P' || packa == 'p' ) ) + { + // Reorder A + char identifierA = 'A'; + gtint_t bufSizeA = sgemm_pack_get_size_blis_impl( &identifierA, + &m, + &n, + &k ); + + float* aBuffer = (float*) bli_malloc_user( bufSizeA, &err ); + sgemm_pack_blis_impl( &identifierA, + &transa, + &m, + &n, + &k, + alpha, + ap, + &lda, + aBuffer ); + + sgemm_compute_blis_impl( &packa, &transb, &m, &n, &k, aBuffer, &lda, bp, &ldb, beta, cp, &ldc ); + bli_free_user( aBuffer ); + } + else if ( ( packb == 'P' || packb == 'p' ) ) + { + // Reorder B + char identifierB = 'B'; + gtint_t bufSizeB = sgemm_pack_get_size_blis_impl( &identifierB, + &m, + &n, + &k ); + + float* bBuffer = (float*) bli_malloc_user( bufSizeB, &err ); + sgemm_pack_blis_impl( &identifierB, + &transb, + &m, + &n, + &k, + alpha, + bp, + &ldb, + bBuffer ); + + sgemm_compute_blis_impl( &transa, &packb, &m, &n, &k, ap, &lda, bBuffer, &ldb, beta, cp, &ldc ); + bli_free_user( bBuffer ); + } + else + { + sgemm_compute_blis_impl( &transa, &transb, &m, &n, &k, ap, &lda, bp, &ldb, beta, cp, &ldc ); + } + } + else if constexpr (std::is_same::value) + { + if ( ( packa == 'P' || packa == 'p' ) && ( packb == 'P' || packb == 'p' ) ) + { + // Reorder A + char identifierA = 'A'; + gtint_t bufSizeA = dgemm_pack_get_size_blis_impl( &identifierA, + &m, + &n, + &k ); + + double* aBuffer = (double*) bli_malloc_user( bufSizeA, &err ); + dgemm_pack_blis_impl( &identifierA, + &transa, + &m, + &n, + &k, + &unit_alpha, + ap, + &lda, + aBuffer ); + + // Reorder B + char identifierB = 'B'; + gtint_t bufSizeB = dgemm_pack_get_size_blis_impl( &identifierB, + &m, + &n, + &k ); + + double* bBuffer = (double*) bli_malloc_user( bufSizeB, &err ); + dgemm_pack_blis_impl( &identifierB, + &transb, + &m, + &n, + &k, + alpha, + bp, + &ldb, + bBuffer ); + + dgemm_compute_blis_impl( &packa, &packb, &m, &n, &k, aBuffer, &lda, bBuffer, &ldb, beta, cp, &ldc ); + + bli_free_user( aBuffer ); + bli_free_user( bBuffer ); + } + else if ( ( packa == 'P' || packa == 'p' ) ) + { + // Reorder A + char identifierA = 'A'; + gtint_t bufSizeA = dgemm_pack_get_size_blis_impl( &identifierA, + &m, + &n, + &k ); + + double* aBuffer = (double*) bli_malloc_user( bufSizeA, &err ); + dgemm_pack_blis_impl( &identifierA, + &transa, + &m, + &n, + &k, + alpha, + ap, + &lda, + aBuffer ); + + dgemm_compute_blis_impl( &packa, &transb, &m, &n, &k, aBuffer, &lda, bp, &ldb, beta, cp, &ldc ); + bli_free_user( aBuffer ); + } + else if ( ( packb == 'P' || packb == 'p' ) ) + { + // Reorder B + char identifierB = 'B'; + gtint_t bufSizeB = dgemm_pack_get_size_blis_impl( &identifierB, + &m, + &n, + &k ); + + double* bBuffer = (double*) bli_malloc_user( bufSizeB, &err ); + dgemm_pack_blis_impl( &identifierB, + &transb, + &m, + &n, + &k, + alpha, + bp, + &ldb, + bBuffer ); + + dgemm_compute_blis_impl( &transa, &packb, &m, &n, &k, ap, &lda, bBuffer, &ldb, beta, cp, &ldc ); + bli_free_user( bBuffer ); + } + else + { + dgemm_compute_blis_impl( &transa, &transb, &m, &n, &k, ap, &lda, bp, &ldb, beta, cp, &ldc ); + } + } + else + throw std::runtime_error("Error in testsuite/level3/gemm.h: Invalid typename in gemm_compute_blis_impl()."); +} template static void cblas_gemm_compute(char storage, char transa, char transb, char pcka, char pckb, @@ -440,12 +646,58 @@ template static void gemm_compute( char storage, char transa, char transb, char packa, char packb, gtint_t m, gtint_t n, gtint_t k, T* alpha, T* ap, gtint_t lda, T* bp, gtint_t ldb, T* beta, T* cp, gtint_t ldc ) { + +#ifdef TEST_UPPERCASE_ARGS + storage = static_cast(std::toupper(static_cast(storage))); + transa = static_cast(std::toupper(static_cast(transa))); + transb = static_cast(std::toupper(static_cast(transb))); + packa = static_cast(std::toupper(static_cast(packa))); + packb = static_cast(std::toupper(static_cast(packb))); +#endif + +#ifdef TEST_INPUT_ARGS + // Create copy of scalar input values so we can check that they are not altered. + char storage_cpy = storage; + char transa_cpy = transa; + char transb_cpy = transb; + char packa_cpy = packa; + char packb_cpy = packb; + gtint_t m_cpy = m; + gtint_t n_cpy = n; + gtint_t k_cpy = k; + T* alpha_cpy = alpha; + gtint_t lda_cpy = lda; + gtint_t ldb_cpy = ldb; + T* beta_cpy = beta; + gtint_t ldc_cpy = ldc; + + // Create copy of input arrays so we can check that they are not altered. + T* ap_cpy = nullptr; + gtint_t size_ap = testinghelpers::matsize( storage, transa, m, k, lda ); + if (ap && size_ap > 0) + { + ap_cpy = new T[size_ap]; + memcpy( ap_cpy, ap, size_ap * sizeof( T ) ); + } + T* bp_cpy = nullptr; + gtint_t size_bp = testinghelpers::matsize( storage, transb, k, n, ldb ); + if (bp && size_bp > 0) + { + bp_cpy = new T[size_bp]; + memcpy( bp_cpy, bp, size_bp * sizeof( T ) ); + } +#endif + #ifdef TEST_BLAS if( storage == 'c' || storage == 'C' ) gemm_compute_( transa, transb, packa, packb, m, n, k, alpha, ap, lda, bp, ldb, beta, cp, ldc ); else throw std::runtime_error("Error in testsuite/level3/gemm_compute.h: BLAS interface cannot be tested for row-major order."); - +#elif TEST_BLAS_BLIS_IMPL + if( storage == 'c' || storage == 'C' ) + gemm_compute_blis_impl( transa, transb, packa, packb, m, n, k, alpha, ap, lda, bp, ldb, beta, cp, ldc ); + else + throw std::runtime_error("Error in testsuite/level3/gemm_compute.h: BLAS_BLIS_IMPL interface cannot be tested for row-major order."); #elif TEST_CBLAS cblas_gemm_compute( storage, transa, transb, packa, packb, m, n, k, alpha, ap, lda, bp, ldb, beta, cp, ldc ); #elif TEST_BLIS_TYPED @@ -453,4 +705,46 @@ static void gemm_compute( char storage, char transa, char transb, char packa, ch #else throw std::runtime_error("Error in testsuite/level3/gemm_compute.h: No interfaces are set to be tested."); #endif + +#ifdef TEST_INPUT_ARGS + //---------------------------------------------------------- + // Check scalar inputs have not been modified. + //---------------------------------------------------------- + + computediff( "storage", storage, storage_cpy ); + computediff( "transa", transa, transa_cpy ); + computediff( "transb", transb, transb_cpy ); + computediff( "packa", packa, packa_cpy ); + computediff( "packb", packb, packb_cpy ); + computediff( "m", m, m_cpy ); + computediff( "n", n, n_cpy ); + computediff( "k", k, k_cpy ); + if (alpha) computediff( "alpha", *alpha, *alpha_cpy ); + computediff( "lda", lda, lda_cpy ); + computediff( "ldb", ldb, ldb_cpy ); + if (beta) computediff( "beta", *beta, *beta_cpy ); + computediff( "ldc", ldc, ldc_cpy ); + + //---------------------------------------------------------- + // Bitwise-wise check array inputs have not been modified. + //---------------------------------------------------------- + + if (ap && size_ap > 0) + { + if(( transa == 'n' ) || ( transa == 'N' )) + computediff( "A", storage, m, k, ap, ap_cpy, lda, true ); + else + computediff( "A", storage, k, m, ap, ap_cpy, lda, true ); + delete[] ap_cpy; + } + + if (bp && size_bp > 0) + { + if(( transb == 'n' ) || ( transb == 'N' )) + computediff( "B", storage, k, n, bp, bp_cpy, ldb, true ); + else + computediff( "B", storage, n, k, bp, bp_cpy, ldb, true ); + delete[] bp_cpy; + } +#endif } diff --git a/gtestsuite/testsuite/level3/gemm_compute/gemm_compute_IIT_ERS.cpp b/gtestsuite/testsuite/level3/gemm_compute/gemm_compute_IIT_ERS.cpp index db293c0433..cd05b1fc8f 100644 --- a/gtestsuite/testsuite/level3/gemm_compute/gemm_compute_IIT_ERS.cpp +++ b/gtestsuite/testsuite/level3/gemm_compute/gemm_compute_IIT_ERS.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -39,13 +39,61 @@ #include "inc/check_error.h" template -class GEMM_Compute_IIT_ERS_Test : public ::testing::Test {}; +class gemm_compute_IIT_ERS : public ::testing::Test {}; typedef ::testing::Types TypeParam; -TYPED_TEST_SUITE(GEMM_Compute_IIT_ERS_Test, TypeParam); +TYPED_TEST_SUITE(gemm_compute_IIT_ERS, TypeParam); using namespace testinghelpers::IIT; -#ifdef TEST_BLAS +#if defined(TEST_CBLAS) +#define INFO_OFFSET 1 +#else +#define INFO_OFFSET 0 +#endif + +#if defined(TEST_CBLAS) + +// When info == 1 +TYPED_TEST(gemm_compute_IIT_ERS, invalid_storage) +{ + using T = TypeParam; + T alpha, beta; + testinghelpers::initone( alpha ); + testinghelpers::initone( beta ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + gemm_compute( 'x', TRANS, TRANS, 'U', 'U', M, N, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC ); +#else + gemm_compute( 'x', TRANS, TRANS, 'U', 'U', M, N, K, &alpha, nullptr, LDA, nullptr, LDB, &beta, nullptr, LDC ); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 1 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector a = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, K, LDA); + std::vector b = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', K, N, LDB); + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, N, LDC); + + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + + // Call BLIS Gemm with a invalid value for storage. + gemm_compute( 'x', TRANS, TRANS, 'U', 'U', M, N, K, &alpha, a.data(), LDA, b.data(), LDB, &beta, c.data(), LDC ); + // Use bitwise comparison (no threshold). + computediff( "C", STORAGE, N, N, c.data(), c_ref.data(), LDC); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 1 ); +#endif +} + +#endif + +#if defined(TEST_BLAS_LIKE) || defined(TEST_CBLAS) /* Incorrect Input Testing(IIT) @@ -62,138 +110,343 @@ using namespace testinghelpers::IIT; */ // When info == 1 -TYPED_TEST(GEMM_Compute_IIT_ERS_Test, invalid_transa) +TYPED_TEST(gemm_compute_IIT_ERS, invalid_transa) { - using T = TypeParam; - // Defining the C matrix with values for debugging purposes - std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, N, LDC, 'f'); - - // Copy so that we check that the elements of C are not modified. - std::vector c_ref(c); - // Call BLIS Gemm with a invalid value for TRANS value for A. - gemm_compute( STORAGE, 'x', TRANS, 'U', 'U', M, N, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC ); - // Use bitwise comparison (no threshold). - computediff( STORAGE, N, N, c.data(), c_ref.data(), LDC); + using T = TypeParam; + T alpha, beta; + testinghelpers::initone( alpha ); + testinghelpers::initone( beta ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + gemm_compute( STORAGE, 'x', TRANS, 'U', 'U', M, N, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC ); +#else + gemm_compute( STORAGE, 'x', TRANS, 'U', 'U', M, N, K, &alpha, nullptr, LDA, nullptr, LDB, &beta, nullptr, LDC ); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, INFO_OFFSET+1 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector a = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, K, LDA); + std::vector b = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', K, N, LDB); + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, N, LDC); + + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + // Call BLIS Gemm with a invalid value for TRANS value for A. + gemm_compute( STORAGE, 'x', TRANS, 'U', 'U', M, N, K, &alpha, a.data(), LDA, b.data(), LDB, &beta, c.data(), LDC ); + // Use bitwise comparison (no threshold). + computediff( "C", STORAGE, N, N, c.data(), c_ref.data(), LDC); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, INFO_OFFSET+1 ); +#endif } // When info == 2 -TYPED_TEST(GEMM_Compute_IIT_ERS_Test, invalid_transb) +TYPED_TEST(gemm_compute_IIT_ERS, invalid_transb) { - using T = TypeParam; - // Defining the C matrix with values for debugging purposes - std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, N, LDC, 'f'); - - // Copy so that we check that the elements of C are not modified. - std::vector c_ref(c); - // Call BLIS Gemm with a invalid value for TRANS value for A. - gemm_compute( STORAGE, TRANS, 'x', 'U', 'U', M, N, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC ); - // Use bitwise comparison (no threshold). - computediff( STORAGE, N, N, c.data(), c_ref.data(), LDC); + using T = TypeParam; + T alpha, beta; + testinghelpers::initone( alpha ); + testinghelpers::initone( beta ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + gemm_compute( STORAGE, TRANS, 'x', 'U', 'U', M, N, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC ); +#else + gemm_compute( STORAGE, TRANS, 'x', 'U', 'U', M, N, K, &alpha, nullptr, LDA, nullptr, LDB, &beta, nullptr, LDC ); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, INFO_OFFSET+2 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector a = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, K, LDA); + std::vector b = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', K, N, LDB); + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, N, LDC); + + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + // Call BLIS Gemm with a invalid value for TRANS value for B. + gemm_compute( STORAGE, TRANS, 'x', 'U', 'U', M, N, K, &alpha, a.data(), LDA, b.data(), LDB, &beta, c.data(), LDC ); + // Use bitwise comparison (no threshold). + computediff( "C", STORAGE, N, N, c.data(), c_ref.data(), LDC); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, INFO_OFFSET+2 ); +#endif } // When info == 3 -TYPED_TEST(GEMM_Compute_IIT_ERS_Test, m_lt_zero) +TYPED_TEST(gemm_compute_IIT_ERS, m_lt_zero) { - using T = TypeParam; - // Defining the C matrix with values for debugging purposes - std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, N, LDC, 'f'); - - // Copy so that we check that the elements of C are not modified. - std::vector c_ref(c); - // Call BLIS Gemm with a invalid value for m. - gemm_compute( STORAGE, TRANS, TRANS, 'U', 'U', -1, N, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC ); - // Use bitwise comparison (no threshold). - computediff( STORAGE, N, N, c.data(), c_ref.data(), LDC); + using T = TypeParam; + T alpha, beta; + testinghelpers::initone( alpha ); + testinghelpers::initone( beta ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + gemm_compute( STORAGE, TRANS, TRANS, 'U', 'U', -1, N, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC ); +#else + gemm_compute( STORAGE, TRANS, TRANS, 'U', 'U', -1, N, K, &alpha, nullptr, LDA, nullptr, LDB, &beta, nullptr, LDC ); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 3 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector a = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, K, LDA); + std::vector b = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', K, N, LDB); + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, N, LDC); + + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + + // Call BLIS Gemm with a invalid value for m. + gemm_compute( STORAGE, TRANS, TRANS, 'U', 'U', -1, N, K, &alpha, a.data(), LDA, b.data(), LDB, &beta, c.data(), LDC ); + // Use bitwise comparison (no threshold). + computediff( "C", STORAGE, N, N, c.data(), c_ref.data(), LDC); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 3 ); +#endif } // When info == 4 -TYPED_TEST(GEMM_Compute_IIT_ERS_Test, n_lt_zero) +TYPED_TEST(gemm_compute_IIT_ERS, n_lt_zero) { - using T = TypeParam; - // Defining the C matrix with values for debugging purposes - std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, N, LDC, 'f'); - - // Copy so that we check that the elements of C are not modified. - std::vector c_ref(c); - // Call BLIS Gemm with a invalid value for m. - gemm_compute( STORAGE, TRANS, TRANS, 'U', 'U', M, -1, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC ); - // Use bitwise comparison (no threshold). - computediff( STORAGE, N, N, c.data(), c_ref.data(), LDC); + using T = TypeParam; + T alpha, beta; + testinghelpers::initone( alpha ); + testinghelpers::initone( beta ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + gemm_compute( STORAGE, TRANS, TRANS, 'U', 'U', M, -1, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC ); +#else + gemm_compute( STORAGE, TRANS, TRANS, 'U', 'U', M, -1, K, &alpha, nullptr, LDA, nullptr, LDB, &beta, nullptr, LDC ); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 4 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector a = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, K, LDA); + std::vector b = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', K, N, LDB); + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, N, LDC); + + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + + // Call BLIS Gemm with a invalid value for n. + gemm_compute( STORAGE, TRANS, TRANS, 'U', 'U', M, -1, K, &alpha, a.data(), LDA, b.data(), LDB, &beta, c.data(), LDC ); + // Use bitwise comparison (no threshold). + computediff( "C", STORAGE, N, N, c.data(), c_ref.data(), LDC); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 4 ); +#endif } // When info == 5 -TYPED_TEST(GEMM_Compute_IIT_ERS_Test, k_lt_zero) +TYPED_TEST(gemm_compute_IIT_ERS, k_lt_zero) { - using T = TypeParam; - // Defining the C matrix with values for debugging purposes - std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, N, LDC, 'f'); - - // Copy so that we check that the elements of C are not modified. - std::vector c_ref(c); - // Call BLIS Gemm with a invalid value for m. - gemm_compute( STORAGE, TRANS, TRANS, 'U', 'U', M, N, -1, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC ); - // Use bitwise comparison (no threshold). - computediff( STORAGE, N, N, c.data(), c_ref.data(), LDC); + using T = TypeParam; + T alpha, beta; + testinghelpers::initone( alpha ); + testinghelpers::initone( beta ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + gemm_compute( STORAGE, TRANS, TRANS, 'U', 'U', M, N, -1, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC ); +#else + gemm_compute( STORAGE, TRANS, TRANS, 'U', 'U', M, N, -1, &alpha, nullptr, LDA, nullptr, LDB, &beta, nullptr, LDC ); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 5 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector a = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, K, LDA); + std::vector b = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', K, N, LDB); + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, N, LDC); + + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + + // Call BLIS Gemm with a invalid value for k. + gemm_compute( STORAGE, TRANS, TRANS, 'U', 'U', M, N, -1, &alpha, a.data(), LDA, b.data(), LDB, &beta, c.data(), LDC ); + // Use bitwise comparison (no threshold). + computediff( "C", STORAGE, N, N, c.data(), c_ref.data(), LDC); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 5 ); +#endif } // When info == 7 -TYPED_TEST(GEMM_Compute_IIT_ERS_Test, invalid_lda) +TYPED_TEST(gemm_compute_IIT_ERS, invalid_lda) { - using T = TypeParam; - // Defining the C matrix with values for debugging purposes - std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, N, LDC, 'f'); - - // Copy so that we check that the elements of C are not modified. - std::vector c_ref(c); - // Call BLIS Gemm with a invalid value for m. - gemm_compute( STORAGE, TRANS, TRANS, 'U', 'U', M, N, K, nullptr, nullptr, LDA - 1, nullptr, LDB, nullptr, nullptr, LDC ); - // Use bitwise comparison (no threshold). - computediff( STORAGE, N, N, c.data(), c_ref.data(), LDC); + using T = TypeParam; + T alpha, beta; + testinghelpers::initone( alpha ); + testinghelpers::initone( beta ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + gemm_compute( STORAGE, TRANS, TRANS, 'U', 'U', M, N, K, nullptr, nullptr, LDA - 1, nullptr, LDB, nullptr, nullptr, LDC ); +#else + gemm_compute( STORAGE, TRANS, TRANS, 'U', 'U', M, N, K, &alpha, nullptr, LDA - 1, nullptr, LDB, &beta, nullptr, LDC ); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 7 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector a = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, K, LDA); + std::vector b = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', K, N, LDB); + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, N, LDC); + + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + + // Call BLIS Gemm with a invalid value for lda. + gemm_compute( STORAGE, TRANS, TRANS, 'U', 'U', M, N, K, &alpha, nullptr, LDA - 1, nullptr, LDB, &beta, c.data(), LDC ); + // Use bitwise comparison (no threshold). + computediff( "C", STORAGE, N, N, c.data(), c_ref.data(), LDC); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 7 ); +#endif } // When info == 9 -TYPED_TEST(GEMM_Compute_IIT_ERS_Test, invalid_ldb) +TYPED_TEST(gemm_compute_IIT_ERS, invalid_ldb) { - using T = TypeParam; - // Defining the C matrix with values for debugging purposes - std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, N, LDC, 'f'); - - // Copy so that we check that the elements of C are not modified. - std::vector c_ref(c); - // Call BLIS Gemm with a invalid value for m. - gemm_compute( STORAGE, TRANS, TRANS, 'U', 'U', M, N, K, nullptr, nullptr, LDA, nullptr, LDB - 1, nullptr, nullptr, LDC ); - // Use bitwise comparison (no threshold). - computediff( STORAGE, N, N, c.data(), c_ref.data(), LDC); + using T = TypeParam; + T alpha, beta; + testinghelpers::initone( alpha ); + testinghelpers::initone( beta ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + gemm_compute( STORAGE, TRANS, TRANS, 'U', 'U', M, N, K, nullptr, nullptr, LDA, nullptr, LDB - 1, nullptr, nullptr, LDC ); +#else + gemm_compute( STORAGE, TRANS, TRANS, 'U', 'U', M, N, K, &alpha, nullptr, LDA, nullptr, LDB - 1, &beta, nullptr, LDC ); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 9 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector a = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, K, LDA); + std::vector b = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', K, N, LDB); + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, N, LDC); + + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + + // Call BLIS Gemm with a invalid value for ldb. + gemm_compute( STORAGE, TRANS, TRANS, 'U', 'U', M, N, K, &alpha, a.data(), LDA, b.data(), LDB - 1, &beta, c.data(), LDC ); + // Use bitwise comparison (no threshold). + computediff( "C", STORAGE, N, N, c.data(), c_ref.data(), LDC); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 9 ); +#endif } // When info == 12 -TYPED_TEST(GEMM_Compute_IIT_ERS_Test, invalid_ldc_lt_zero) +TYPED_TEST(gemm_compute_IIT_ERS, invalid_ldc_lt_zero) { - using T = TypeParam; - // Defining the C matrix with values for debugging purposes - std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, N, LDC, 'f'); - - // Copy so that we check that the elements of C are not modified. - std::vector c_ref(c); - // Call BLIS Gemm with a invalid value for m. - gemm_compute( STORAGE, TRANS, TRANS, 'U', 'U', M, N, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, -1 ); - // Use bitwise comparison (no threshold). - computediff( STORAGE, N, N, c.data(), c_ref.data(), LDC); + using T = TypeParam; + T alpha, beta; + testinghelpers::initone( alpha ); + testinghelpers::initone( beta ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + gemm_compute( STORAGE, TRANS, TRANS, 'U', 'U', M, N, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, -1 ); +#else + gemm_compute( STORAGE, TRANS, TRANS, 'U', 'U', M, N, K, &alpha, nullptr, LDA, nullptr, LDB, &beta, nullptr, -1); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 12 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector a = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, K, LDA); + std::vector b = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', K, N, LDB); + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, N, LDC); + + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + + // Call BLIS Gemm with a invalid value for ldc. + gemm_compute( STORAGE, TRANS, TRANS, 'U', 'U', M, N, K, &alpha, a.data(), LDA, b.data(), LDB, &beta, c.data(), -1 ); + // Use bitwise comparison (no threshold). + computediff( "C", STORAGE, N, N, c.data(), c_ref.data(), LDC); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 12 ); +#endif } // When info == 12 -TYPED_TEST(GEMM_Compute_IIT_ERS_Test, invalid_ldc) +TYPED_TEST(gemm_compute_IIT_ERS, invalid_ldc) { - using T = TypeParam; - // Defining the C matrix with values for debugging purposes - std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, N, LDC, 'f'); - - // Copy so that we check that the elements of C are not modified. - std::vector c_ref(c); - // Call BLIS Gemm with a invalid value for m. - gemm_compute( STORAGE, TRANS, TRANS, 'U', 'U', M, N, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC - 1 ); - // Use bitwise comparison (no threshold). - computediff( STORAGE, N, N, c.data(), c_ref.data(), LDC); + using T = TypeParam; + T alpha, beta; + testinghelpers::initone( alpha ); + testinghelpers::initone( beta ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + gemm_compute( STORAGE, TRANS, TRANS, 'U', 'U', M, N, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC - 1); +#else + gemm_compute( STORAGE, TRANS, TRANS, 'U', 'U', M, N, K, &alpha, nullptr, LDA, nullptr, LDB, &beta, nullptr, LDC - 1); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 12 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector a = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, K, LDA); + std::vector b = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', K, N, LDB); + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, N, LDC); + + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + + // Call BLIS Gemm with a invalid value for ldc. + gemm_compute( STORAGE, TRANS, TRANS, 'U', 'U', M, N, K, &alpha, a.data(), LDA, b.data(), LDB, &beta, c.data(), LDC - 1 ); + // Use bitwise comparison (no threshold). + computediff( "C", STORAGE, N, N, c.data(), c_ref.data(), LDC); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 12 ); +#endif } /* @@ -206,32 +459,171 @@ TYPED_TEST(GEMM_Compute_IIT_ERS_Test, invalid_ldc) */ // When m = 0 -TYPED_TEST(GEMM_Compute_IIT_ERS_Test, m_eq_zero) +TYPED_TEST(gemm_compute_IIT_ERS, m_eq_zero) { - using T = TypeParam; - // Defining the C matrix with values for debugging purposes - std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, N, LDC, 'f'); - - // Copy so that we check that the elements of C are not modified. - std::vector c_ref(c); - // Call BLIS Gemm with a invalid value for m. - gemm_compute( STORAGE, TRANS, TRANS, 'U', 'U', 0, N, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC ); - // Use bitwise comparison (no threshold). - computediff( STORAGE, N, N, c.data(), c_ref.data(), LDC); + using T = TypeParam; + T alpha, beta; + testinghelpers::initone( alpha ); + testinghelpers::initone( beta ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + gemm_compute( STORAGE, TRANS, TRANS, 'U', 'U', 0, N, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC ); +#else + gemm_compute( STORAGE, TRANS, TRANS, 'U', 'U', 0, N, K, &alpha, nullptr, LDA, nullptr, LDB, &beta, nullptr, LDC ); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector a = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, K, LDA); + std::vector b = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', K, N, LDB); + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, N, LDC); + + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + + // Call BLIS Gemm with a invalid value for m. + gemm_compute( STORAGE, TRANS, TRANS, 'U', 'U', 0, N, K, &alpha, a.data(), LDA, b.data(), LDB, &beta, c.data(), LDC ); + // Use bitwise comparison (no threshold). + computediff( "C", STORAGE, N, N, c.data(), c_ref.data(), LDC); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif } // When n = 0 -TYPED_TEST(GEMM_Compute_IIT_ERS_Test, n_eq_zero) +TYPED_TEST(gemm_compute_IIT_ERS, n_eq_zero) { - using T = TypeParam; - // Defining the C matrix with values for debugging purposes - std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, N, LDC, 'f'); - - // Copy so that we check that the elements of C are not modified. - std::vector c_ref(c); - // Call BLIS Gemm with a invalid value for m. - gemm_compute( STORAGE, TRANS, TRANS, 'U', 'U', M, 0, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC ); - // Use bitwise comparison (no threshold). - computediff( STORAGE, N, N, c.data(), c_ref.data(), LDC); + using T = TypeParam; + T alpha, beta; + testinghelpers::initone( alpha ); + testinghelpers::initone( beta ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + gemm_compute( STORAGE, TRANS, TRANS, 'U', 'U', M, 0, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC ); +#else + gemm_compute( STORAGE, TRANS, TRANS, 'U', 'U', M, 0, K, &alpha, nullptr, LDA, nullptr, LDB, &beta, nullptr, LDC ); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector a = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, K, LDA); + std::vector b = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', K, N, LDB); + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, N, LDC); + + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + + // Call BLIS Gemm with a invalid value for m. + gemm_compute( STORAGE, TRANS, TRANS, 'U', 'U', M, 0, K, &alpha, a.data(), LDA, b.data(), LDB, &beta, c.data(), LDC ); + // Use bitwise comparison (no threshold). + computediff( "C", STORAGE, N, N, c.data(), c_ref.data(), LDC); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif +} + +// When k is 0 and beta is 1 +TYPED_TEST(gemm_compute_IIT_ERS, k_zero_beta_one) +{ + using T = TypeParam; + T alpha, beta; + testinghelpers::initone( alpha ); + testinghelpers::initone( beta ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + gemm_compute( STORAGE, TRANS, TRANS, 'U', 'U', M, N, 0, &alpha, nullptr, LDA, nullptr, LDB, &beta, nullptr, LDC ); +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector a = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, K, LDA); + std::vector b = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', K, N, LDB); + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, N, LDC); + + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + + // Call BLIS Gemm with a invalid value for m. + gemm_compute( STORAGE, TRANS, TRANS, 'U', 'U', M, N, 0, &alpha, a.data(), LDA, b.data(), LDB, &beta, c.data(), LDC ); + // Use bitwise comparison (no threshold). + computediff( "C", STORAGE, N, N, c.data(), c_ref.data(), LDC); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif } + +// zero alpha and zero beta - set C to 0 +TYPED_TEST(gemm_compute_IIT_ERS, ZeroAlpha_ZeroBeta) +{ + using T = TypeParam; + + T alpha, beta; + testinghelpers::initzero( alpha ); + testinghelpers::initzero( beta ); + + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, N, LDC); + // Copy so that we check that the elements of C are not modified. + std::vector zero_mat = testinghelpers::get_random_matrix(0, 0, STORAGE, 'n', M, N, LDB); + + // Test with all arguments correct except for the value we are choosing to test. + std::vector a = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, K, LDA); + std::vector b = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', K, N, LDB); + + // Enable packing of A matrix to accound for alpha = 0 scaling. + gemm_compute( STORAGE, TRANS, TRANS, 'P', 'U', M, N, K, &alpha, a.data(), LDA, b.data(), LDB, &beta, c.data(), LDC ); + // Use bitwise comparison (no threshold). + computediff( "C", STORAGE, N, N, c.data(), zero_mat.data(), LDC); + +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif +} + +// zero alpha and non-zero/non-unit beta - scale C only +TYPED_TEST(gemm_compute_IIT_ERS, ZeroAlpha_OtherBeta) +{ + using T = TypeParam; + + T alpha, beta; + testinghelpers::initzero( alpha ); + beta = T{2.0}; + double thresh = testinghelpers::getEpsilon(); + + std::vector a = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, K, LDA); + std::vector b = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', K, N, LDB); + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', M, N, LDC); + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + + testinghelpers::ref_gemm_compute( STORAGE, TRANS, TRANS, 'U', 'U', M, N, K, alpha, + a.data(), LDA, b.data(), LDB, beta, c_ref.data(), LDC ); + + // Test with all arguments correct except for the value we are choosing to test. + gemm_compute( STORAGE, TRANS, TRANS, 'U', 'U', M, N, K, &alpha, a.data(), LDA, b.data(), LDB, &beta, c.data(), LDC ); + // Use bitwise comparison (no threshold). + computediff( "C", STORAGE, N, N, c.data(), c_ref.data(), LDC, thresh); + +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif +} + #endif diff --git a/gtestsuite/testsuite/level3/gemm_compute/sgemm_compute_generic.cpp b/gtestsuite/testsuite/level3/gemm_compute/sgemm_compute_generic.cpp index ea574eb723..fb6ff1dc87 100644 --- a/gtestsuite/testsuite/level3/gemm_compute/sgemm_compute_generic.cpp +++ b/gtestsuite/testsuite/level3/gemm_compute/sgemm_compute_generic.cpp @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -35,7 +35,7 @@ #include #include "test_gemm_compute.h" -class SGemmComputeTest : +class sgemmComputeGeneric : public ::testing::TestWithParam> {}; -TEST_P(SGemmComputeTest, RandomData) +TEST_P( sgemmComputeGeneric, API ) { // printf("SGemmCompute_test!!\n"); using T = float; @@ -86,8 +86,18 @@ TEST_P(SGemmComputeTest, RandomData) gtint_t ldc_inc = std::get<12>(GetParam()); // Set the threshold for the errors: - float intermediate = (float)m*n*k; - float thresh = 10*intermediate*testinghelpers::getEpsilon(); + // Check gtestsuite gemm.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (m == 0 || n == 0) + thresh = 0.0; + else if ((alpha == testinghelpers::ZERO() || k == 0) && + (beta == testinghelpers::ZERO() || beta == testinghelpers::ONE())) + thresh = 0.0; + else + thresh = (3*k+1)*testinghelpers::getEpsilon(); + //thresh = (8*k+1)*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters @@ -95,55 +105,13 @@ TEST_P(SGemmComputeTest, RandomData) test_gemm_compute( storage, transa, transb, packa, packb, m, n, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh ); } -class SGemmComputeTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char tsa = std::get<1>(str.param); - char tsb = std::get<2>(str.param); - char pka = std::get<3>(str.param); - char pkb = std::get<4>(str.param); - gtint_t m = std::get<5>(str.param); - gtint_t n = std::get<6>(str.param); - gtint_t k = std::get<7>(str.param); - float alpha = std::get<8>(str.param); - float beta = std::get<9>(str.param); - gtint_t lda_inc = std::get<10>(str.param); - gtint_t ldb_inc = std::get<11>(str.param); - gtint_t ldc_inc = std::get<12>(str.param); -#ifdef TEST_BLAS - std::string str_name = "sgemm_compute_"; -#elif TEST_CBLAS - std::string str_name = "cblas_sgemm_compute"; -#else //#elif TEST_BLIS_TYPED - // BLIS interface not yet implemented for pack and compute APIs. - std::string str_name = "blis_sgemm_compute"; -#endif - str_name = str_name + "_" + sfm+sfm+sfm; - str_name = str_name + "_" + tsa + tsb; - str_name = str_name + "_" + pka + pkb; - str_name = str_name + "_" + std::to_string(m); - str_name = str_name + "_" + std::to_string(n); - str_name = str_name + "_" + std::to_string(k); - std::string alpha_str = ( alpha > 0) ? std::to_string(int(alpha)) : "m" + std::to_string(int(std::abs(alpha))); - str_name = str_name + "_a" + alpha_str; - std::string beta_str = ( beta > 0) ? std::to_string(int(beta)) : "m" + std::to_string(int(std::abs(beta))); - str_name = str_name + "_b" + beta_str; - str_name = str_name + "_" + std::to_string(lda_inc); - str_name = str_name + "_" + std::to_string(ldb_inc); - str_name = str_name + "_" + std::to_string(ldc_inc); - return str_name; - } -}; - // Black box testing. INSTANTIATE_TEST_SUITE_P( Blackbox, - SGemmComputeTest, + sgemmComputeGeneric, ::testing::Combine( ::testing::Values('c' -#ifndef TEST_BLAS +#ifndef TEST_BLAS_LIKE ,'r' #endif ), // storage format @@ -160,15 +128,15 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(0)), // increment to the leading dim of b ::testing::Values(gtint_t(0)) // increment to the leading dim of c ), - ::SGemmComputeTestPrint() + ::gemm_computeGeneticPrint() ); INSTANTIATE_TEST_SUITE_P( TinySizes, - SGemmComputeTest, + sgemmComputeGeneric, ::testing::Combine( ::testing::Values('c' -#ifndef TEST_BLAS +#ifndef TEST_BLAS_LIKE ,'r' #endif ), // storage format @@ -185,15 +153,15 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(0)), // increment to the leading dim of b ::testing::Values(gtint_t(0)) // increment to the leading dim of c ), - ::SGemmComputeTestPrint() + ::gemm_computeGeneticPrint() ); INSTANTIATE_TEST_SUITE_P( DimensionsGtBlocksizes, // Dimensions > SUP Blocksizes - SGemmComputeTest, + sgemmComputeGeneric, ::testing::Combine( ::testing::Values('c' -#ifndef TEST_BLAS +#ifndef TEST_BLAS_LIKE ,'r' #endif ), // storage format @@ -210,5 +178,5 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(0)), // increment to the leading dim of b ::testing::Values(gtint_t(0)) // increment to the leading dim of c ), - ::SGemmComputeTestPrint() + ::gemm_computeGeneticPrint() ); diff --git a/gtestsuite/testsuite/level3/gemm_compute/test_gemm_compute.h b/gtestsuite/testsuite/level3/gemm_compute/test_gemm_compute.h index a9109d5abc..708cb401a2 100644 --- a/gtestsuite/testsuite/level3/gemm_compute/test_gemm_compute.h +++ b/gtestsuite/testsuite/level3/gemm_compute/test_gemm_compute.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -55,7 +55,14 @@ void test_gemm_compute( char storage, char trnsa, char trnsb, char pcka, char pc //---------------------------------------------------------- std::vector a = testinghelpers::get_random_matrix( -2, 8, storage, trnsa, m, k, lda ); std::vector b = testinghelpers::get_random_matrix( -5, 2, storage, trnsb, k, n, ldb ); - std::vector c = testinghelpers::get_random_matrix( -3, 5, storage, 'n', m, n, ldc ); + std::vector c( testinghelpers::matsize( storage, 'n', m, n, ldc ) ); + if (beta != testinghelpers::ZERO()) + testinghelpers::datagenerators::randomgenerators( -3, 5, storage, m, n, c.data(), 'n', ldc ); + else + { + // Matrix C should not be read, only set. + testinghelpers::set_matrix( storage, m, n, c.data(), 'n', ldc, testinghelpers::aocl_extreme() ); + } // Create a copy of c so that we can check reference results. std::vector c_ref(c); @@ -75,5 +82,51 @@ void test_gemm_compute( char storage, char trnsa, char trnsb, char pcka, char pc //---------------------------------------------------------- // check component-wise error. //---------------------------------------------------------- - computediff( storage, m, n, c.data(), c_ref.data(), ldc, thresh ); + computediff( "C", storage, m, n, c.data(), c_ref.data(), ldc, thresh ); + +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif } + +// Test-case logger : Used to print the test-case details based on parameters +template +class gemm_computeGeneticPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char storage = std::get<0>(str.param); + char transa = std::get<1>(str.param); + char transb = std::get<2>(str.param); + char pka = std::get<3>(str.param); + char pkb = std::get<4>(str.param); + gtint_t m = std::get<5>(str.param); + gtint_t n = std::get<6>(str.param); + gtint_t k = std::get<7>(str.param); + T alpha = std::get<8>(str.param); + T beta = std::get<9>(str.param); + gtint_t lda_inc = std::get<10>(str.param); + gtint_t ldb_inc = std::get<11>(str.param); + gtint_t ldc_inc = std::get<12>(str.param); + + std::string str_name = API_PRINT; + str_name += "_stor_" + std::string(&storage, 1); + str_name += "_transa_" + std::string(&transa, 1); + str_name += "_transb_" + std::string(&transb, 1); + str_name += "_pka_" + std::string(&pka, 1); + str_name += "_pkb_" + std::string(&pkb, 1); + str_name += "_m_" + std::to_string(m); + str_name += "_n_" + std::to_string(n); + str_name += "_k_" + std::to_string(k); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + str_name += "_beta_" + testinghelpers::get_value_string(beta); + gtint_t lda = testinghelpers::get_leading_dimension( storage, transa, m, k, lda_inc ); + gtint_t ldb = testinghelpers::get_leading_dimension( storage, transb, k, n, ldb_inc ); + gtint_t ldc = testinghelpers::get_leading_dimension( storage, 'n', m, n, ldc_inc ); + str_name += "_lda_i" + std::to_string(lda_inc) + "_" + std::to_string(lda); + str_name += "_ldb_i" + std::to_string(ldb_inc) + "_" + std::to_string(ldb); + str_name += "_ldc_i" + std::to_string(ldc_inc) + "_" + std::to_string(ldc); + return str_name; + } +}; diff --git a/gtestsuite/testsuite/level3/gemmt/cgemmt_generic.cpp b/gtestsuite/testsuite/level3/gemmt/cgemmt_generic.cpp index 07aed996bb..47eaf09e46 100644 --- a/gtestsuite/testsuite/level3/gemmt/cgemmt_generic.cpp +++ b/gtestsuite/testsuite/level3/gemmt/cgemmt_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,7 +35,7 @@ #include #include "test_gemmt.h" -class cgemmtTest : +class cgemmtGeneric : public ::testing::TestWithParam> {}; -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(cgemmtTest); +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(cgemmtGeneric); -TEST_P(cgemmtTest, RandomData) +TEST_P( cgemmtGeneric, API ) { using T = scomplex; //---------------------------------------------------------- @@ -81,7 +81,18 @@ TEST_P(cgemmtTest, RandomData) gtint_t ldc_inc = std::get<10>(GetParam()); // Set the threshold for the errors: - double thresh = 10*n*k*testinghelpers::getEpsilon(); + // Check gtestsuite gemmt.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (n == 0) + thresh = 0.0; + else if ((alpha == testinghelpers::ZERO() || k == 0) && + (beta == testinghelpers::ZERO() || beta == testinghelpers::ONE())) + thresh = 0.0; + else + thresh = (3*k+1)*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters @@ -89,55 +100,15 @@ TEST_P(cgemmtTest, RandomData) test_gemmt( storage, uplo, transa, transb, n, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh ); } -class cgemmtTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char uplo = std::get<1>(str.param); - char tsa = std::get<2>(str.param); - char tsb = std::get<3>(str.param); - gtint_t n = std::get<4>(str.param); - gtint_t k = std::get<5>(str.param); - scomplex alpha = std::get<6>(str.param); - scomplex beta = std::get<7>(str.param); - gtint_t lda_inc = std::get<8>(str.param); - gtint_t ldb_inc = std::get<9>(str.param); - gtint_t ldc_inc = std::get<10>(str.param); -#ifdef TEST_BLAS - std::string str_name = "cgemmt_"; -#elif TEST_CBLAS - std::string str_name = "cblas_cgemmt"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_cgemmt"; -#endif - str_name = str_name + "_" + sfm+sfm+sfm; - str_name = str_name + "_" + uplo; - str_name = str_name + "_" + tsa + tsb; - str_name = str_name + "_" + std::to_string(n); - str_name = str_name + "_" + std::to_string(k); - std::string alpha_str = ( alpha.real > 0) ? std::to_string(int(alpha.real)) : ("m" + std::to_string(int(std::abs(alpha.real)))); - alpha_str = alpha_str + "pi" + (( alpha.imag > 0) ? std::to_string(int(alpha.imag)) : ("m" + std::to_string(int(std::abs(alpha.imag))))); - std::string beta_str = ( beta.real > 0) ? std::to_string(int(beta.real)) : ("m" + std::to_string(int(std::abs(beta.real)))); - beta_str = beta_str + "pi" + (( beta.imag > 0) ? std::to_string(int(beta.imag)) : ("m" + std::to_string(int(std::abs(beta.imag))))); - str_name = str_name + "_a" + alpha_str; - str_name = str_name + "_b" + beta_str; - str_name = str_name + "_" + std::to_string(lda_inc); - str_name = str_name + "_" + std::to_string(ldb_inc); - str_name = str_name + "_" + std::to_string(ldc_inc); - return str_name; - } -}; - // Disable tests for BLIS_TYPED case due to compiler errors. #ifndef TEST_BLIS_TYPED // Black box testing. INSTANTIATE_TEST_SUITE_P( Blackbox, - cgemmtTest, + cgemmtGeneric, ::testing::Combine( ::testing::Values('c' -#ifndef TEST_BLAS +#ifndef TEST_BLAS_LIKE ,'r' #endif ), // storage format @@ -152,6 +123,6 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(0), gtint_t(3)), // increment to the leading dim of b ::testing::Values(gtint_t(0), gtint_t(5)) // increment to the leading dim of c ), - ::cgemmtTestPrint() + ::gemmtGenericPrint() ); #endif diff --git a/gtestsuite/testsuite/level3/gemmt/dgemmt_evt.cpp b/gtestsuite/testsuite/level3/gemmt/dgemmt_evt.cpp new file mode 100644 index 0000000000..0fc2f1d948 --- /dev/null +++ b/gtestsuite/testsuite/level3/gemmt/dgemmt_evt.cpp @@ -0,0 +1,136 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_gemmt.h" + +class dgemmtEVT : + public ::testing::TestWithParam> {}; // exception value for C matrix + +TEST_P( dgemmtEVT, API ) +{ + using T = double; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // matrix storage format(row major, column major) + char storage = std::get<0>(GetParam()); + // specifies if the upper or lower triangular part of C is used + char uplo = std::get<1>(GetParam()); + // denotes whether matrix a is n,c,t,h + char transa = std::get<2>(GetParam()); + // denotes whether matrix b is n,c,t,h + char transb = std::get<3>(GetParam()); + // matrix size n + gtint_t n = std::get<4>(GetParam()); + // matrix size k + gtint_t k = std::get<5>(GetParam()); + // specifies alpha value + T alpha = std::get<6>(GetParam()); + // specifies beta value + T beta = std::get<7>(GetParam()); + // lda, ldb, ldc increments. + // If increments are zero, then the array size matches the matrix size. + // If increments are nonnegative, the array size is bigger than the matrix size. + gtint_t lda_inc = std::get<8>(GetParam()); + gtint_t ldb_inc = std::get<9>(GetParam()); + gtint_t ldc_inc = std::get<10>(GetParam()); + T aexval = std::get<11>(GetParam()); + T bexval = std::get<12>(GetParam()); + T cexval = std::get<13>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite gemmt.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (n == 0) + thresh = 0.0; + else if ((alpha == testinghelpers::ZERO() || k == 0) && + (beta == testinghelpers::ZERO() || beta == testinghelpers::ONE())) + thresh = 0.0; + else + thresh = (3*k+1)*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_gemmt( storage, uplo, transa, transb, n, k, lda_inc, ldb_inc, ldc_inc, + alpha, beta, thresh, false, true, aexval, bexval, cexval ); +} + +static double AOCL_NAN = std::numeric_limits::quiet_NaN(); +static double AOCL_INF = std::numeric_limits::infinity(); + +#ifndef TEST_BLIS_TYPED +INSTANTIATE_TEST_SUITE_P( + Native, + dgemmtEVT, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('u','l'), // uplo u:upper, l:lower + ::testing::Values('n','t'), // transa + ::testing::Values('n','t'), // transb + ::testing::Values(7, 800), // n + ::testing::Values(7, 800), // k + ::testing::Values(2.4, AOCL_NAN/*, AOCL_INF, -AOCL_INF*/), // alpha //commented values fail + ::testing::Values(2.4/*, AOCL_NAN*/, AOCL_INF, -AOCL_INF), // beta //commented values fail + ::testing::Values(gtint_t(0)), // increment to the leading dim of a + ::testing::Values(gtint_t(0)), // increment to the leading dim of b + ::testing::Values(gtint_t(0)), // increment to the leading dim of c + ::testing::Values(0.0, AOCL_NAN, AOCL_INF, -AOCL_INF), // extreme value for A matrix + ::testing::Values(0.0, AOCL_NAN, AOCL_INF, -AOCL_INF), // extreme value for B matrix + ::testing::Values(0.0, AOCL_NAN, AOCL_INF, -AOCL_INF) // extreme value for B matrix + ), + ::gemmtEVTPrint() + ); +#endif diff --git a/gtestsuite/testsuite/level3/gemmt/dgemmt_generic.cpp b/gtestsuite/testsuite/level3/gemmt/dgemmt_generic.cpp index c31260def4..9d5c627eac 100644 --- a/gtestsuite/testsuite/level3/gemmt/dgemmt_generic.cpp +++ b/gtestsuite/testsuite/level3/gemmt/dgemmt_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,22 +35,23 @@ #include #include "test_gemmt.h" -class dgemmtTest : - public ::testing::TestWithParam> {}; +class dgemmtGeneric : + public ::testing::TestWithParam> {}; // is memory test -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(dgemmtTest); +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(dgemmtGeneric); -TEST_P(dgemmtTest, RandomData) +TEST_P( dgemmtGeneric, API ) { using T = double; //---------------------------------------------------------- @@ -79,76 +80,97 @@ TEST_P(dgemmtTest, RandomData) gtint_t lda_inc = std::get<8>(GetParam()); gtint_t ldb_inc = std::get<9>(GetParam()); gtint_t ldc_inc = std::get<10>(GetParam()); + bool is_mem_test = std::get<11>(GetParam()); // Set the threshold for the errors: - double thresh = 10*n*k*testinghelpers::getEpsilon(); + // Check gtestsuite gemmt.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (n == 0) + thresh = 0.0; + else if ((alpha == testinghelpers::ZERO() || k == 0) && + (beta == testinghelpers::ZERO() || beta == testinghelpers::ONE())) + thresh = 0.0; + else + thresh = (3*k+1)*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_gemmt( storage, uplo, transa, transb, n, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh ); + test_gemmt( storage, uplo, transa, transb, n, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh, is_mem_test ); } -class dgemmtTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char tsa = std::get<1>(str.param); - char tsb = std::get<2>(str.param); - char uplo = std::get<3>(str.param); - gtint_t n = std::get<4>(str.param); - gtint_t k = std::get<5>(str.param); - double alpha = std::get<6>(str.param); - double beta = std::get<7>(str.param); - gtint_t lda_inc = std::get<8>(str.param); - gtint_t ldb_inc = std::get<9>(str.param); - gtint_t ldc_inc = std::get<10>(str.param); -#ifdef TEST_BLAS - std::string str_name = "dgemmt_"; -#elif TEST_CBLAS - std::string str_name = "cblas_dgemmt"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_dgemmt"; -#endif - str_name = str_name + "_" + sfm+sfm+sfm; - str_name = str_name + "_" + tsa + tsb; - str_name = str_name + "_" + uplo; - str_name = str_name + "_" + std::to_string(n); - str_name = str_name + "_" + std::to_string(k); - std::string alpha_str = ( alpha > 0) ? std::to_string(int(alpha)) : "m" + std::to_string(int(std::abs(alpha))); - str_name = str_name + "_a" + alpha_str; - std::string beta_str = ( beta > 0) ? std::to_string(int(beta)) : "m" + std::to_string(int(std::abs(beta))); - str_name = str_name + "_b" + beta_str; - str_name = str_name + "_" + std::to_string(lda_inc); - str_name = str_name + "_" + std::to_string(ldb_inc); - str_name = str_name + "_" + std::to_string(ldc_inc); - return str_name; - } -}; -// Disable tests for BLIS_TYPED case due to compiler errors. #ifndef TEST_BLIS_TYPED -// Black box testing. INSTANTIATE_TEST_SUITE_P( - Blackbox, - dgemmtTest, + skinny_fringe_cases, + dgemmtGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('u','l'), // uplo u:upper, l:lower + ::testing::Values('n','t'), // transa + ::testing::Values('n','t'), // transb + ::testing::Range(gtint_t(1), gtint_t(30), 1), // n + ::testing::Range(gtint_t(1), gtint_t(30), 1), // k + ::testing::Values(1.0, 0.0, -2.4, 3.1), // alpha + ::testing::Values(1.0, 0.0, -2.4, 3.1), // beta + ::testing::Values(gtint_t(0), gtint_t(153)), // increment to the leading dim of a + ::testing::Values(gtint_t(0), gtint_t(122)), // increment to the leading dim of b + ::testing::Values(gtint_t(0), gtint_t(195)), // increment to the leading dim of c + ::testing::Values(true, false) // is memory test + ), + ::gemmtMemGenericPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + skinny, + dgemmtGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('u','l'), // uplo u:upper, l:lower + ::testing::Values('n','t'), // transa + ::testing::Values('n','t'), // transb + ::testing::Values(35, 537, 799), // n + ::testing::Values(35, 537, 799), // k + ::testing::Values(1.0, 0.0, -2.4, 3.1), // alpha + ::testing::Values(1.0, 0.0, -2.4, 3.1), // beta + ::testing::Values(gtint_t(0), gtint_t(153)), // increment to the leading dim of a + ::testing::Values(gtint_t(0), gtint_t(122)), // increment to the leading dim of b + ::testing::Values(gtint_t(0), gtint_t(195)), // increment to the leading dim of c + ::testing::Values(true, false) // is memory test + ), + ::gemmtMemGenericPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + large, + dgemmtGeneric, ::testing::Combine( ::testing::Values('c' -#ifndef TEST_BLAS - ,'r' +#ifndef TEST_BLAS_LIKE + ,'r' #endif ), // storage format ::testing::Values('u','l'), // uplo u:upper, l:lower - ::testing::Values('n','c','t'), // transa - ::testing::Values('n','c','t'), // transb - ::testing::Range(gtint_t(10), gtint_t(31), 10), // n - ::testing::Range(gtint_t(10), gtint_t(31), 10), // k - ::testing::Values(2.0), // alpha - ::testing::Values(3.0), // beta - ::testing::Values(gtint_t(0), gtint_t(4)), // increment to the leading dim of a - ::testing::Values(gtint_t(0), gtint_t(1)), // increment to the leading dim of b - ::testing::Values(gtint_t(0), gtint_t(2)) // increment to the leading dim of c + ::testing::Values('n','t'), // transa + ::testing::Values('n','t'), // transb + ::testing::Values(800, 1500), // n + ::testing::Values(800, 1500), // k + ::testing::Values(1.0, 0.0, -2.4, 3.1), // alpha + ::testing::Values(1.0, 0.0, -2.4, 3.1), // beta + ::testing::Values(gtint_t(0), gtint_t(153)), // increment to the leading dim of a + ::testing::Values(gtint_t(0), gtint_t(122)), // increment to the leading dim of b + ::testing::Values(gtint_t(0), gtint_t(195)), // increment to the leading dim of c + ::testing::Values(true, false) // is memory test ), - ::dgemmtTestPrint() + ::gemmtMemGenericPrint() ); #endif diff --git a/gtestsuite/testsuite/level3/gemmt/gemmt.h b/gtestsuite/testsuite/level3/gemmt/gemmt.h index a9a92821e0..8ea2bcaa2a 100644 --- a/gtestsuite/testsuite/level3/gemmt/gemmt.h +++ b/gtestsuite/testsuite/level3/gemmt/gemmt.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -36,6 +36,7 @@ #include "blis.h" #include "common/testing_helpers.h" +#include "inc/check_error.h" /** * @brief Performs the operation: @@ -83,6 +84,22 @@ static void gemmt_(char uplo, char transa, char transb, gtint_t n, gtint_t k, T* throw std::runtime_error("Error in testsuite/level3/gemmt.h: Invalid typename in gemmt_()."); } +template +static void gemmt_blis_impl(char uplo, char transa, char transb, gtint_t n, gtint_t k, T* alpha, + T* ap, gtint_t lda, T* bp, gtint_t ldb, T* beta, T* cp, gtint_t ldc ) +{ + if constexpr (std::is_same::value) + sgemmt_blis_impl( &uplo, &transa, &transb, &n, &k, alpha, ap, &lda, bp, &ldb, beta, cp, &ldc ); + else if constexpr (std::is_same::value) + dgemmt_blis_impl( &uplo, &transa, &transb, &n, &k, alpha, ap, &lda, bp, &ldb, beta, cp, &ldc ); + else if constexpr (std::is_same::value) + cgemmt_blis_impl( &uplo, &transa, &transb, &n, &k, alpha, ap, &lda, bp, &ldb, beta, cp, &ldc ); + else if constexpr (std::is_same::value) + zgemmt_blis_impl( &uplo, &transa, &transb, &n, &k, alpha, ap, &lda, bp, &ldb, beta, cp, &ldc ); + else + throw std::runtime_error("Error in testsuite/level3/gemmt.h: Invalid typename in gemmt_blis_impl()."); +} + template static void cblas_gemmt(char storage, char uplo, char transa, char transb, gtint_t n, gtint_t k, T* alpha, T* ap, gtint_t lda, @@ -159,12 +176,55 @@ template static void gemmt( char storage, char uplo, char transa, char transb, gtint_t n, gtint_t k, T* alpha, T* ap, gtint_t lda, T* bp, gtint_t ldb, T* beta, T* cp, gtint_t ldc ) { + +#ifdef TEST_UPPERCASE_ARGS + storage = static_cast(std::toupper(static_cast(storage))); + uplo = static_cast(std::toupper(static_cast(uplo))); + transa = static_cast(std::toupper(static_cast(transa))); + transb = static_cast(std::toupper(static_cast(transb))); +#endif + +#ifdef TEST_INPUT_ARGS + // Create copy of scalar input values so we can check that they are not altered. + char storage_cpy = storage; + char uplo_cpy = uplo; + char transa_cpy = transa; + char transb_cpy = transb; + gtint_t n_cpy = n; + gtint_t k_cpy = k; + T* alpha_cpy = alpha; + gtint_t lda_cpy = lda; + gtint_t ldb_cpy = ldb; + T* beta_cpy = beta; + gtint_t ldc_cpy = ldc; + + // Create copy of input arrays so we can check that they are not altered. + T* ap_cpy = nullptr; + gtint_t size_ap = testinghelpers::matsize( storage, transa, n, k, lda ); + if (ap && size_ap > 0) + { + ap_cpy = new T[size_ap]; + memcpy( ap_cpy, ap, size_ap * sizeof( T ) ); + } + T* bp_cpy = nullptr; + gtint_t size_bp = testinghelpers::matsize( storage, transb, k, n, ldb ); + if (bp && size_bp > 0) + { + bp_cpy = new T[size_bp]; + memcpy( bp_cpy, bp, size_bp * sizeof( T ) ); + } +#endif + #ifdef TEST_BLAS if( storage == 'c' || storage == 'C' ) gemmt_( uplo, transa, transb, n, k, alpha, ap, lda, bp, ldb, beta, cp, ldc ); else throw std::runtime_error("Error in testsuite/level3/gemmt.h: BLAS interface cannot be tested for row-major order."); - +#elif TEST_BLAS_BLIS_IMPL + if( storage == 'c' || storage == 'C' ) + gemmt_blis_impl( uplo, transa, transb, n, k, alpha, ap, lda, bp, ldb, beta, cp, ldc ); + else + throw std::runtime_error("Error in testsuite/level3/gemmt.h: BLAS_BLIS_IMPL interface cannot be tested for row-major order."); #elif TEST_CBLAS cblas_gemmt( storage, uplo, transa, transb, n, k, alpha, ap, lda, bp, ldb, beta, cp, ldc ); #elif TEST_BLIS_TYPED @@ -173,4 +233,44 @@ static void gemmt( char storage, char uplo, char transa, char transb, gtint_t n, #else throw std::runtime_error("Error in testsuite/level3/gemmt.h: No interfaces are set to be tested."); #endif + +#ifdef TEST_INPUT_ARGS + //---------------------------------------------------------- + // Check scalar inputs have not been modified. + //---------------------------------------------------------- + + computediff( "storage", storage, storage_cpy ); + computediff( "uplo", uplo, uplo_cpy ); + computediff( "transa", transa, transa_cpy ); + computediff( "transb", transb, transb_cpy ); + computediff( "n", n, n_cpy ); + computediff( "k", k, k_cpy ); + if (alpha) computediff( "alpha", *alpha, *alpha_cpy ); + computediff( "lda", lda, lda_cpy ); + computediff( "ldb", ldb, ldb_cpy ); + if (beta) computediff( "beta", *beta, *beta_cpy ); + computediff( "ldc", ldc, ldc_cpy ); + + //---------------------------------------------------------- + // Bitwise-wise check array inputs have not been modified. + //---------------------------------------------------------- + + if (ap && size_ap > 0) + { + if(( transa == 'n' ) || ( transa == 'N' )) + computediff( "A", storage, n, k, ap, ap_cpy, lda, true ); + else + computediff( "A", storage, k, n, ap, ap_cpy, lda, true ); + delete[] ap_cpy; + } + + if (bp && size_bp > 0) + { + if(( transb == 'n' ) || ( transb == 'N' )) + computediff( "B", storage, k, n, bp, bp_cpy, ldb, true ); + else + computediff( "B", storage, n, k, bp, bp_cpy, ldb, true ); + delete[] bp_cpy; + } +#endif } diff --git a/gtestsuite/testsuite/level3/gemmt/gemmt_IIT_ERS.cpp b/gtestsuite/testsuite/level3/gemmt/gemmt_IIT_ERS.cpp new file mode 100644 index 0000000000..93d41927e2 --- /dev/null +++ b/gtestsuite/testsuite/level3/gemmt/gemmt_IIT_ERS.cpp @@ -0,0 +1,507 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "common/testing_helpers.h" +#include "gemmt.h" +#include "inc/check_error.h" +#include "common/wrong_inputs_helpers.h" + +template +class gemmt_IIT_ERS : public ::testing::Test {}; +typedef ::testing::Types TypeParam; // The supported datatypes from BLAS calls for GEMMT +TYPED_TEST_SUITE(gemmt_IIT_ERS, TypeParam); // Defining individual testsuites based on the datatype support. + +// Adding namespace to get default parameters(valid case) from testinghelpers/common/wrong_input_helpers.h. +using namespace testinghelpers::IIT; + +#if defined(TEST_CBLAS) +#define INFO_OFFSET 1 +#else +#define INFO_OFFSET 0 +#endif + +#if defined(TEST_CBLAS) + +// When info == 1 +TYPED_TEST(gemmt_IIT_ERS, invalid_storage) +{ + using T = TypeParam; + T alpha, beta; + testinghelpers::initone( alpha ); + testinghelpers::initone( beta ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + gemmt( 'x', UPLO, TRANS, TRANS, N, K, &alpha, nullptr, LDA, nullptr, LDB, &beta, nullptr, LDC ); +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 1 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + // Defining the C matrix with values for debugging purposes + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, N, LDC); + std::vector a = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, K, LDA); + std::vector b = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', K, N, LDB); + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + + gemmt( 'x', UPLO, TRANS, TRANS, N, K, &alpha, a.data(), LDA, b.data(), LDB, &beta, c.data(), LDC ); + computediff( "C", STORAGE, N, N, c.data(), c_ref.data(), LDC ); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 1 ); +#endif +} + +#endif + +#if defined(TEST_BLAS_LIKE) || defined(TEST_CBLAS) + +/* + Incorrect Input Testing(IIT) + + BLAS exceptions get triggered in the following cases(for GEMM): + 1. When UPLO != 'L' || UPLO != 'U' (info = 1) + 2. When TRANSA != 'N' || TRANSA != 'T' || TRANSA != 'C' (info = 2) + 3. When TRANSB != 'N' || TRANSB != 'T' || TRANSB != 'C' (info = 3) + 4. When n < 0 (info = 4) + 5. When k < 0 (info = 5) + 6. When lda < max(1, thresh) (info = 8), thresh set based on TRANSA value + 7. When ldb < max(1, thresh) (info = 10), thresh set based on TRANSB value + 8. When ldc < max(1, n) (info = 13) + +*/ + +// When info == 1 +TYPED_TEST(gemmt_IIT_ERS, invalid_uploa) +{ + using T = TypeParam; + T alpha, beta; + testinghelpers::initone( alpha ); + testinghelpers::initone( beta ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + gemmt( STORAGE, 'A', TRANS, TRANS, N, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC ); +#else + gemmt( STORAGE, 'A', TRANS, TRANS, N, K, &alpha, nullptr, LDA, nullptr, LDB, &beta, nullptr, LDC ); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, INFO_OFFSET+1 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + // Defining the C matrix with values for debugging purposes + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, N, LDC); + std::vector a = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, K, LDA); + std::vector b = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', K, N, LDB); + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + + gemmt( STORAGE, 'A', TRANS, TRANS, N, K, &alpha, a.data(), LDA, b.data(), LDB, &beta, c.data(), LDC ); + computediff( "C", STORAGE, N, N, c.data(), c_ref.data(), LDC ); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, INFO_OFFSET+1 ); +#endif +} + +// When info == 2 +TYPED_TEST(gemmt_IIT_ERS, invalid_transa) +{ + using T = TypeParam; + T alpha, beta; + testinghelpers::initone( alpha ); + testinghelpers::initone( beta ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + gemmt( STORAGE, UPLO, 'A', TRANS, N, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC ); +#else + gemmt( STORAGE, UPLO, 'A', TRANS, N, K, &alpha, nullptr, LDA, nullptr, LDB, &beta, nullptr, LDC ); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, INFO_OFFSET+2 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + // Defining the C matrix with values for debugging purposes + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, N, LDC); + std::vector a = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, K, LDA); + std::vector b = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', K, N, LDB); + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + + gemmt( STORAGE, UPLO, 'A', TRANS, N, K, &alpha, a.data(), LDA, b.data(), LDB, &beta, c.data(), LDC ); + computediff( "C", STORAGE, N, N, c.data(), c_ref.data(), LDC ); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, INFO_OFFSET+2 ); +#endif +} + +// When info == 3 +TYPED_TEST(gemmt_IIT_ERS, invalid_transb) +{ + using T = TypeParam; + T alpha, beta; + testinghelpers::initone( alpha ); + testinghelpers::initone( beta ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + gemmt( STORAGE, UPLO, TRANS, 'A', N, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC ); +#else + gemmt( STORAGE, UPLO, TRANS, 'A', N, K, &alpha, nullptr, LDA, nullptr, LDB, &beta, nullptr, LDC ); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, INFO_OFFSET+3 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + // Defining the C matrix with values for debugging purposes + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, N, LDC); + std::vector a = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, K, LDA); + std::vector b = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', K, N, LDB); + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + + gemmt( STORAGE, UPLO, TRANS, 'A', N, K, &alpha, a.data(), LDA, b.data(), LDB, &beta, c.data(), LDC ); + computediff( "C", STORAGE, N, N, c.data(), c_ref.data(), LDC ); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, INFO_OFFSET+3 ); +#endif +} + +// When info == 4 +TYPED_TEST(gemmt_IIT_ERS, n_lt_zero) +{ + using T = TypeParam; + T alpha, beta; + testinghelpers::initone( alpha ); + testinghelpers::initone( beta ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + gemmt( STORAGE, UPLO, TRANS, TRANS, -1, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC ); +#else + gemmt( STORAGE, UPLO, TRANS, TRANS, -1, K, &alpha, nullptr, LDA, nullptr, LDB, &beta, nullptr, LDC ); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 4 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + // Defining the C matrix with values for debugging purposes + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, N, LDC); + std::vector a = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, K, LDA); + std::vector b = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', K, N, LDB); + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + + gemmt( STORAGE, UPLO, TRANS, TRANS, -1, K, &alpha, a.data(), LDA, b.data(), LDB, &beta, c.data(), LDC ); + computediff( "C", STORAGE, N, N, c.data(), c_ref.data(), LDC ); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 4 ); +#endif +} + +// When info == 5 +TYPED_TEST(gemmt_IIT_ERS, k_lt_zero) +{ + using T = TypeParam; + T alpha, beta; + testinghelpers::initone( alpha ); + testinghelpers::initone( beta ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + gemmt( STORAGE, UPLO, TRANS, TRANS, N, -1, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC ); +#else + gemmt( STORAGE, UPLO, TRANS, TRANS, N, -1, &alpha, nullptr, LDA, nullptr, LDB, &beta, nullptr, LDC ); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 5 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + // Defining the C matrix with values for debugging purposes + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, N, LDC); + std::vector a = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, K, LDA); + std::vector b = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', K, N, LDB); + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + + gemmt( STORAGE, UPLO, TRANS, TRANS, N, -1, &alpha, a.data(), LDA, b.data(), LDB, &beta, c.data(), LDC ); + computediff( "C", STORAGE, N, N, c.data(), c_ref.data(), LDC ); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 5 ); +#endif +} + +// When info == 8 +TYPED_TEST(gemmt_IIT_ERS, invalid_lda) +{ + using T = TypeParam; + T alpha, beta; + testinghelpers::initone( alpha ); + testinghelpers::initone( beta ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + gemmt( STORAGE, UPLO, TRANS, TRANS, N, K, nullptr, nullptr, LDA - 1, nullptr, LDB, nullptr, nullptr, LDC ); +#else + gemmt( STORAGE, UPLO, TRANS, TRANS, N, K, &alpha, nullptr, LDA - 1, nullptr, LDB, &beta, nullptr, LDC ); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 8 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + // Defining the C matrix with values for debugging purposes + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, N, LDC); + std::vector a = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, K, LDA); + std::vector b = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', K, N, LDB); + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + + gemmt( STORAGE, UPLO, TRANS, TRANS, N, K, &alpha, a.data(), LDA - 1, b.data(), LDB, &beta, c.data(), LDC ); + computediff( "C", STORAGE, N, N, c.data(), c_ref.data(), LDC ); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 8 ); +#endif +} + +// When info == 10 +TYPED_TEST(gemmt_IIT_ERS, invalid_ldb) +{ + using T = TypeParam; + T alpha, beta; + testinghelpers::initone( alpha ); + testinghelpers::initone( beta ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + gemmt( STORAGE, UPLO, TRANS, TRANS, N, K, nullptr, nullptr, LDA, nullptr, LDB - 1, nullptr, nullptr, LDC ); +#else + gemmt( STORAGE, UPLO, TRANS, TRANS, N, K, &alpha, nullptr, LDA, nullptr, LDB - 1, &beta, nullptr, LDC ); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 10 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + // Defining the C matrix with values for debugging purposes + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, N, LDC); + std::vector a = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, K, LDA); + std::vector b = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', K, N, LDB); + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + + gemmt( STORAGE, UPLO, TRANS, TRANS, N, K, &alpha, a.data(), LDA, b.data(), LDB - 1, &beta, c.data(), LDC ); + computediff( "C", STORAGE, N, N, c.data(), c_ref.data(), LDC ); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 10 ); +#endif +} + +// When info == 13 +TYPED_TEST(gemmt_IIT_ERS, invalid_ldc) +{ + using T = TypeParam; + T alpha, beta; + testinghelpers::initone( alpha ); + testinghelpers::initone( beta ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + gemmt( STORAGE, UPLO, TRANS, TRANS, N, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC - 1 ); +#else + gemmt( STORAGE, UPLO, TRANS, TRANS, N, K, &alpha, nullptr, LDA, nullptr, LDB, &beta, nullptr, LDC - 1 ); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 13 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + // Defining the C matrix with values for debugging purposes + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, N, LDC); + std::vector a = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, K, LDA); + std::vector b = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', K, N, LDB); + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + + gemmt( STORAGE, UPLO, TRANS, TRANS, N, K, &alpha, a.data(), LDA, b.data(), LDB, &beta, c.data(), LDC - 1 ); + computediff( "C", STORAGE, N, N, c.data(), c_ref.data(), LDC ); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 13 ); +#endif +} + +/* + Early Return Scenarios(ERS) : + + The GEMMt API is expected to return early in the following cases: + + 1. When n == 0. + 2. When (alpha == 0 or k == 0) and beta == 1. + +*/ + +// When n is 0 +TYPED_TEST(gemmt_IIT_ERS, n_eq_zero) +{ + using T = TypeParam; + T alpha, beta; + testinghelpers::initone( alpha ); + testinghelpers::initone( beta ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + gemmt( STORAGE, UPLO, TRANS, TRANS, 0, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC ); +#else + gemmt( STORAGE, UPLO, TRANS, TRANS, 0, K, &alpha, nullptr, LDA, nullptr, LDB, &beta, nullptr, LDC ); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + // Defining the C matrix with values for debugging purposes + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, N, LDC); + std::vector a = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, K, LDA); + std::vector b = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', K, N, LDB); + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + + gemmt( STORAGE, UPLO, TRANS, TRANS, 0, K, &alpha, a.data(), LDA, b.data(), LDB, &beta, c.data(), LDC ); + computediff( "C", STORAGE, N, N, c.data(), c_ref.data(), LDC ); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif +} + +// When alpha is 0 and beta is 1 +TYPED_TEST(gemmt_IIT_ERS, alpha_zero_beta_one) +{ + using T = TypeParam; + T alpha, beta; + testinghelpers::initzero( alpha ); + testinghelpers::initone( beta ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + gemmt( STORAGE, UPLO, TRANS, TRANS, N, K, &alpha, nullptr, LDA, nullptr, LDB, &beta, nullptr, LDC ); +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + // Defining the C matrix with values for debugging purposes + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, N, LDC); + std::vector a = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, K, LDA); + std::vector b = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', K, N, LDB); + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + + gemmt( STORAGE, UPLO, TRANS, TRANS, N, K, &alpha, a.data(), LDA, b.data(), LDB, &beta, c.data(), LDC ); + computediff( "C", STORAGE, N, N, c.data(), c_ref.data(), LDC ); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif +} + +// When k is 0 and beta is 1 +TYPED_TEST(gemmt_IIT_ERS, k_zero_beta_one) +{ + using T = TypeParam; + T alpha, beta; + testinghelpers::initone( alpha ); + testinghelpers::initone( beta ); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + gemmt( STORAGE, UPLO, TRANS, TRANS, N, 0, &alpha, nullptr, LDA, nullptr, LDB, &beta, nullptr, LDC ); +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + // Defining the C matrix with values for debugging purposes + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, N, LDC); + std::vector a = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, K, LDA); + std::vector b = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', K, N, LDB); + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + + gemmt( STORAGE, UPLO, TRANS, TRANS, N, 0, &alpha, a.data(), LDA, b.data(), LDB, &beta, c.data(), LDC ); + computediff( "C", STORAGE, N, N, c.data(), c_ref.data(), LDC ); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif +} + +#endif + diff --git a/gtestsuite/testsuite/level3/gemmt/sgemmt_generic.cpp b/gtestsuite/testsuite/level3/gemmt/sgemmt_generic.cpp index e067a684e7..cbe54ef327 100644 --- a/gtestsuite/testsuite/level3/gemmt/sgemmt_generic.cpp +++ b/gtestsuite/testsuite/level3/gemmt/sgemmt_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,7 +35,7 @@ #include #include "test_gemmt.h" -class sgemmtTest : +class sgemmtGeneric : public ::testing::TestWithParam> {}; -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(sgemmtTest); +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(sgemmtGeneric); -TEST_P(sgemmtTest, RandomData) +TEST_P( sgemmtGeneric, API ) { using T = float; //---------------------------------------------------------- @@ -81,7 +81,17 @@ TEST_P(sgemmtTest, RandomData) gtint_t ldc_inc = std::get<10>(GetParam()); // Set the threshold for the errors: - double thresh = 10*n*k*testinghelpers::getEpsilon(); + // Check gtestsuite gemmt.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (n == 0) + thresh = 0.0; + else if ((alpha == testinghelpers::ZERO() || k == 0) && + (beta == testinghelpers::ZERO() || beta == testinghelpers::ONE())) + thresh = 0.0; + else + thresh = (3*k+1)*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters @@ -89,53 +99,15 @@ TEST_P(sgemmtTest, RandomData) test_gemmt( storage, uplo, transa, transb, n, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh ); } -class sgemmtTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char tsa = std::get<1>(str.param); - char tsb = std::get<2>(str.param); - char uplo = std::get<3>(str.param); - gtint_t n = std::get<4>(str.param); - gtint_t k = std::get<5>(str.param); - float alpha = std::get<6>(str.param); - float beta = std::get<7>(str.param); - gtint_t lda_inc = std::get<8>(str.param); - gtint_t ldb_inc = std::get<9>(str.param); - gtint_t ldc_inc = std::get<10>(str.param); -#ifdef TEST_BLAS - std::string str_name = "sgemmt_"; -#elif TEST_CBLAS - std::string str_name = "cblas_sgemmt"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_sgemmt"; -#endif - str_name = str_name + "_" + sfm+sfm+sfm; - str_name = str_name + "_" + tsa + tsb; - str_name = str_name + "_" + uplo; - str_name = str_name + "_" + std::to_string(n); - str_name = str_name + "_" + std::to_string(k); - std::string alpha_str = ( alpha > 0) ? std::to_string(int(alpha)) : "m" + std::to_string(int(std::abs(alpha))); - str_name = str_name + "_a" + alpha_str; - std::string beta_str = ( beta > 0) ? std::to_string(int(beta)) : "m" + std::to_string(int(std::abs(beta))); - str_name = str_name + "_b" + beta_str; - str_name = str_name + "_" + std::to_string(lda_inc); - str_name = str_name + "_" + std::to_string(ldb_inc); - str_name = str_name + "_" + std::to_string(ldc_inc); - return str_name; - } -}; - // Disable tests for BLIS_TYPED case due to compiler errors. #ifndef TEST_BLIS_TYPED // Black box testing. INSTANTIATE_TEST_SUITE_P( Blackbox, - sgemmtTest, + sgemmtGeneric, ::testing::Combine( ::testing::Values('c' -#ifndef TEST_BLAS +#ifndef TEST_BLAS_LIKE ,'r' #endif ), // storage format @@ -150,6 +122,6 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(0), gtint_t(4)), // increment to the leading dim of b ::testing::Values(gtint_t(0), gtint_t(2)) // increment to the leading dim of c ), - ::sgemmtTestPrint() + ::gemmtGenericPrint() ); #endif diff --git a/gtestsuite/testsuite/level3/gemmt/test_gemmt.h b/gtestsuite/testsuite/level3/gemmt/test_gemmt.h index 2afaba222d..d28cf4f388 100644 --- a/gtestsuite/testsuite/level3/gemmt/test_gemmt.h +++ b/gtestsuite/testsuite/level3/gemmt/test_gemmt.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -39,41 +39,233 @@ #include "inc/check_error.h" #include #include +#include "common/testing_helpers.h" template -void test_gemmt( char storage, char uplo, char trnsa, char trnsb, gtint_t n, +void test_gemmt( char storage, char uploc, char transa, char transb, gtint_t n, gtint_t k, gtint_t lda_inc, gtint_t ldb_inc, gtint_t ldc_inc, T alpha, - T beta, double thresh ) + T beta, double thresh, bool is_mem_test=false, bool is_evt_test=false, + T evt_a=T{0.0}, T evt_b=T{0.0}, T evt_c=T{0.0} ) { // Compute the leading dimensions of a, b, and c. - gtint_t lda = testinghelpers::get_leading_dimension( storage, trnsa, n, k, lda_inc ); - gtint_t ldb = testinghelpers::get_leading_dimension( storage, trnsb, k, n, ldb_inc ); + gtint_t lda = testinghelpers::get_leading_dimension( storage, transa, n, k, lda_inc ); + gtint_t ldb = testinghelpers::get_leading_dimension( storage, transb, k, n, ldb_inc ); gtint_t ldc = testinghelpers::get_leading_dimension( storage, 'n', n, n, ldc_inc ); //---------------------------------------------------------- // Initialize matrics with random numbers //---------------------------------------------------------- - std::vector a = testinghelpers::get_random_matrix( -2, 8, storage, trnsa, n, k, lda ); - std::vector b = testinghelpers::get_random_matrix( -5, 2, storage, trnsb, k, n, ldb ); - std::vector c = testinghelpers::get_random_matrix( -3, 5, storage, 'n', n, n, ldc ); + T *a_ptr, *b_ptr, *c_ptr; + dim_t size_a = testinghelpers::matsize(storage, transa, n, k, lda) * sizeof(T); + testinghelpers::ProtectedBuffer a(size_a, false, is_mem_test ); + a_ptr = (T*)a.greenzone_1; + testinghelpers::datagenerators::randomgenerators( -2, 8, storage, n, k, a_ptr, transa, lda); + + if ( is_evt_test ) + { + dim_t n_rand = rand() % (std::min)(n, k); + dim_t k_rand = rand() % (std::min)(n, k); + a_ptr[n_rand + k_rand * lda] = evt_a; + } + + dim_t size_b = testinghelpers::matsize(storage, transb, k, n, ldb) * sizeof(T); + testinghelpers::ProtectedBuffer b(size_b, false, is_mem_test ); + b_ptr = (T*)b.greenzone_1; + testinghelpers::datagenerators::randomgenerators( -5, 2, storage, k, n, b_ptr, transb, ldb); + + if ( is_evt_test ) + { + dim_t n_rand = rand() % (std::min)(k, n); + dim_t k_rand = rand() % (std::min)(k, n); + b_ptr[n_rand + k_rand * ldb] = evt_b; + } + + dim_t size_c = testinghelpers::matsize(storage, 'n', n, n, ldc) * sizeof(T); + testinghelpers::ProtectedBuffer c(size_c, false, is_mem_test ); + c_ptr = (T*)c.greenzone_1; + if (beta != testinghelpers::ZERO()) + { + testinghelpers::datagenerators::randomgenerators( -3, 5, storage, n, n, c_ptr, 'n', ldc); + if ( is_evt_test ) + { + dim_t n_rand = rand() % n; + dim_t k_rand = rand() % n; + c_ptr[n_rand + k_rand * ldc] = evt_c; + } + } + else + { + // Matrix C should not be read, only set. + testinghelpers::set_matrix( storage, n, n, c_ptr, 'n', ldc, testinghelpers::aocl_extreme() ); + } // Create a copy of c so that we can check reference results. - std::vector c_ref(c); + std::vector c_ref(testinghelpers::matsize(storage, 'n', n, n, ldc)); + memcpy(c_ref.data(), c_ptr, size_c); - //---------------------------------------------------------- - // Call BLIS function - //---------------------------------------------------------- - gemmt( storage, uplo, trnsa, trnsb, n, k, &alpha, a.data(), lda, - b.data(), ldb, &beta, c.data(), ldc ); + // add signal handler for segmentation fault + testinghelpers::ProtectedBuffer::start_signal_handler(); + try + { + //---------------------------------------------------------- + // Call BLIS function + //---------------------------------------------------------- + gemmt( storage, uploc, transa, transb, n, k, &alpha, a_ptr, lda, + b_ptr, ldb, &beta, c_ptr, ldc ); + if (is_mem_test) + { + memcpy(a.greenzone_2, a.greenzone_1, size_a); + memcpy(b.greenzone_2, b.greenzone_1, size_b); + memcpy(c.greenzone_2, c_ref.data(), size_c); + + gemmt( storage, uploc, transa, transb, n, k, &alpha, (T*)a.greenzone_2, lda, + (T*)b.greenzone_2, ldb, &beta, (T*)c.greenzone_2, ldc ); + } + + } + catch(const std::exception& e) + { + // reset to default signal handler + testinghelpers::ProtectedBuffer::stop_signal_handler(); + + // show failure in case seg fault was detected + FAIL() << "Memory Test Failed"; + } + // reset to default signal handler + testinghelpers::ProtectedBuffer::stop_signal_handler(); //---------------------------------------------------------- // Call reference implementation. //---------------------------------------------------------- - testinghelpers::ref_gemmt( storage, uplo, trnsa, trnsb, n, k, alpha, - a.data(), lda, b.data(), ldb, beta, c_ref.data(), ldc ); + testinghelpers::ref_gemmt( storage, uploc, transa, transb, n, k, alpha, + a_ptr, lda, b_ptr, ldb, beta, c_ref.data(), ldc ); //---------------------------------------------------------- // check component-wise error. //---------------------------------------------------------- - computediff( storage, n, n, c.data(), c_ref.data(), ldc, thresh ); + computediff( "C", storage, n, n, c_ptr, c_ref.data(), ldc, thresh, is_evt_test ); + +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif } + +// Test-case logger : Used to print the test-case details based on parameters +template +class gemmtGenericPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char storage = std::get<0>(str.param); + char uploc = std::get<1>(str.param); + char transa = std::get<2>(str.param); + char transb = std::get<3>(str.param); + gtint_t n = std::get<4>(str.param); + gtint_t k = std::get<5>(str.param); + T alpha = std::get<6>(str.param); + T beta = std::get<7>(str.param); + gtint_t lda_inc = std::get<8>(str.param); + gtint_t ldb_inc = std::get<9>(str.param); + gtint_t ldc_inc = std::get<10>(str.param); + + std::string str_name = API_PRINT; + str_name += "_stor_" + std::string(&storage, 1); + str_name += "_uploc_" + std::string(&uploc, 1); + str_name += "_transa_" + std::string(&transa, 1); + str_name += "_transb_" + std::string(&transb, 1); + str_name += "_n_" + std::to_string(n); + str_name += "_k_" + std::to_string(k); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + str_name += "_beta_" + testinghelpers::get_value_string(beta); + gtint_t lda = testinghelpers::get_leading_dimension( storage, transa, n, k, lda_inc ); + gtint_t ldb = testinghelpers::get_leading_dimension( storage, transb, k, n, ldb_inc ); + gtint_t ldc = testinghelpers::get_leading_dimension( storage, 'n', n, n, ldc_inc ); + str_name += "_lda_i" + std::to_string(lda_inc) + "_" + std::to_string(lda); + str_name += "_ldb_i" + std::to_string(ldb_inc) + "_" + std::to_string(ldb); + str_name += "_ldc_i" + std::to_string(ldc_inc) + "_" + std::to_string(ldc); + return str_name; + } +}; + +template +class gemmtMemGenericPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char storage = std::get<0>(str.param); + char uploc = std::get<1>(str.param); + char transa = std::get<2>(str.param); + char transb = std::get<3>(str.param); + gtint_t n = std::get<4>(str.param); + gtint_t k = std::get<5>(str.param); + T alpha = std::get<6>(str.param); + T beta = std::get<7>(str.param); + gtint_t lda_inc = std::get<8>(str.param); + gtint_t ldb_inc = std::get<9>(str.param); + gtint_t ldc_inc = std::get<10>(str.param); + bool is_mem_test = std::get<11>(str.param); + + std::string str_name = API_PRINT; + str_name += "_stor_" + std::string(&storage, 1); + str_name += "_uploc_" + std::string(&uploc, 1); + str_name += "_transa_" + std::string(&transa, 1); + str_name += "_transb_" + std::string(&transb, 1); + str_name += "_n_" + std::to_string(n); + str_name += "_k_" + std::to_string(k); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + str_name += "_beta_" + testinghelpers::get_value_string(beta); + gtint_t lda = testinghelpers::get_leading_dimension( storage, transa, n, k, lda_inc ); + gtint_t ldb = testinghelpers::get_leading_dimension( storage, transb, k, n, ldb_inc ); + gtint_t ldc = testinghelpers::get_leading_dimension( storage, 'n', n, n, ldc_inc ); + str_name += "_lda_i" + std::to_string(lda_inc) + "_" + std::to_string(lda); + str_name += "_ldb_i" + std::to_string(ldb_inc) + "_" + std::to_string(ldb); + str_name += "_ldc_i" + std::to_string(ldc_inc) + "_" + std::to_string(ldc); + str_name = str_name + (is_mem_test ? "_mem_test_enabled" : "_mem_test_disabled"); + return str_name; + } +}; + +// Test-case logger : Used to print the test-case details based on parameters +template +class gemmtEVTPrint +{ +public: + std::string operator()( + testing::TestParamInfo> str) const { + char storage = std::get<0>(str.param); + char uploc = std::get<1>(str.param); + char transa = std::get<2>(str.param); + char transb = std::get<3>(str.param); + gtint_t n = std::get<4>(str.param); + gtint_t k = std::get<5>(str.param); + T alpha = std::get<6>(str.param); + T beta = std::get<7>(str.param); + gtint_t lda_inc = std::get<8>(str.param); + gtint_t ldb_inc = std::get<9>(str.param); + gtint_t ldc_inc = std::get<10>(str.param); + T aexval = std::get<11>(str.param); + T bexval = std::get<12>(str.param); + T cexval = std::get<13>(str.param); + + std::string str_name = API_PRINT; + str_name += "_stor_" + std::string(&storage, 1); + str_name += "_uploc_" + std::string(&uploc, 1); + str_name += "_transa_" + std::string(&transa, 1); + str_name += "_transb_" + std::string(&transb, 1); + str_name += "_n_" + std::to_string(n); + str_name += "_k_" + std::to_string(k); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + str_name += "_beta_" + testinghelpers::get_value_string(beta); + str_name = str_name + "_ex_a_" + testinghelpers::get_value_string(aexval); + str_name = str_name + "_ex_b_" + testinghelpers::get_value_string(bexval); + str_name = str_name + "_ex_c_" + testinghelpers::get_value_string(cexval); + gtint_t lda = testinghelpers::get_leading_dimension( storage, transa, n, k, lda_inc ); + gtint_t ldb = testinghelpers::get_leading_dimension( storage, transb, k, n, ldb_inc ); + gtint_t ldc = testinghelpers::get_leading_dimension( storage, 'n', n, n, ldc_inc ); + str_name += "_lda_i" + std::to_string(lda_inc) + "_" + std::to_string(lda); + str_name += "_ldb_i" + std::to_string(ldb_inc) + "_" + std::to_string(ldb); + str_name += "_ldc_i" + std::to_string(ldc_inc) + "_" + std::to_string(ldc); + return str_name; + } +}; diff --git a/gtestsuite/testsuite/level3/gemmt/zgemmt_generic.cpp b/gtestsuite/testsuite/level3/gemmt/zgemmt_generic.cpp index 7c8a4c8ecf..6dd4bd7820 100644 --- a/gtestsuite/testsuite/level3/gemmt/zgemmt_generic.cpp +++ b/gtestsuite/testsuite/level3/gemmt/zgemmt_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,7 +35,7 @@ #include #include "test_gemmt.h" -class zgemmtTest : +class zgemmtGeneric : public ::testing::TestWithParam> {}; -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(zgemmtTest); +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(zgemmtGeneric); -TEST_P(zgemmtTest, RandomData) +TEST_P( zgemmtGeneric, API ) { using T = dcomplex; //---------------------------------------------------------- @@ -81,7 +81,18 @@ TEST_P(zgemmtTest, RandomData) gtint_t ldc_inc = std::get<10>(GetParam()); // Set the threshold for the errors: - double thresh = (std::max)(n,k)*testinghelpers::getEpsilon(); + // Check gtestsuite gemmt.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (n == 0) + thresh = 0.0; + else if ((alpha == testinghelpers::ZERO() || k == 0) && + (beta == testinghelpers::ZERO() || beta == testinghelpers::ONE())) + thresh = 0.0; + else + thresh = (3*k+1)*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters @@ -89,55 +100,15 @@ TEST_P(zgemmtTest, RandomData) test_gemmt( storage, uplo, transa, transb, n, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh ); } -class zgemmtTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char uplo = std::get<1>(str.param); - char tsa = std::get<2>(str.param); - char tsb = std::get<3>(str.param); - gtint_t n = std::get<4>(str.param); - gtint_t k = std::get<5>(str.param); - dcomplex alpha = std::get<6>(str.param); - dcomplex beta = std::get<7>(str.param); - gtint_t lda_inc = std::get<8>(str.param); - gtint_t ldb_inc = std::get<9>(str.param); - gtint_t ldc_inc = std::get<10>(str.param); -#ifdef TEST_BLAS - std::string str_name = "zgemmt_"; -#elif TEST_CBLAS - std::string str_name = "cblas_zgemmt"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_zgemmt"; -#endif - str_name = str_name + "_" + sfm+sfm+sfm; - str_name = str_name + "_" + uplo; - str_name = str_name + "_" + tsa + tsb; - str_name = str_name + "_" + std::to_string(n); - str_name = str_name + "_" + std::to_string(k); - std::string alpha_str = ( alpha.real > 0) ? std::to_string(int(alpha.real)) : ("m" + std::to_string(int(std::abs(alpha.real)))); - alpha_str = alpha_str + "pi" + (( alpha.imag > 0) ? std::to_string(int(alpha.imag)) : ("m" + std::to_string(int(std::abs(alpha.imag))))); - std::string beta_str = ( beta.real > 0) ? std::to_string(int(beta.real)) : ("m" + std::to_string(int(std::abs(beta.real)))); - beta_str = beta_str + "pi" + (( beta.imag > 0) ? std::to_string(int(beta.imag)) : ("m" + std::to_string(int(std::abs(beta.imag))))); - str_name = str_name + "_a" + alpha_str; - str_name = str_name + "_b" + beta_str; - str_name = str_name + "_" + std::to_string(lda_inc); - str_name = str_name + "_" + std::to_string(ldb_inc); - str_name = str_name + "_" + std::to_string(ldc_inc); - return str_name; - } -}; - // Disable tests for BLIS_TYPED case due to compiler errors. #ifndef TEST_BLIS_TYPED // Black box testing. INSTANTIATE_TEST_SUITE_P( Blackbox, - zgemmtTest, + zgemmtGeneric, ::testing::Combine( ::testing::Values('c' -#ifndef TEST_BLAS +#ifndef TEST_BLAS_LIKE ,'r' #endif ), // storage format @@ -152,6 +123,6 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of b ::testing::Values(gtint_t(0), gtint_t(9)) // increment to the leading dim of c ), - ::zgemmtTestPrint() + ::gemmtGenericPrint() ); #endif diff --git a/gtestsuite/testsuite/level3/hemm/chemm_generic.cpp b/gtestsuite/testsuite/level3/hemm/chemm_generic.cpp index 173aa8777b..c4bcfbdb6a 100644 --- a/gtestsuite/testsuite/level3/hemm/chemm_generic.cpp +++ b/gtestsuite/testsuite/level3/hemm/chemm_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,7 +35,7 @@ #include #include "test_hemm.h" -class chemmTest : +class chemmGeneric : public ::testing::TestWithParam> {}; -TEST_P(chemmTest, RandomData) +TEST_P( chemmGeneric, API ) { using T = scomplex; //---------------------------------------------------------- @@ -83,7 +83,22 @@ TEST_P(chemmTest, RandomData) gtint_t ldc_inc = std::get<11>(GetParam()); // Set the threshold for the errors: - double thresh = 10*m*n*testinghelpers::getEpsilon(); + // Check gtestsuite hemm.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // With adjustment for complex data. + double thresh; + double adj = 2.5; + if (m == 0 || n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO() && + (beta == testinghelpers::ZERO() || beta == testinghelpers::ONE())) + thresh = 0.0; + else + if ( side == 'l' || side == 'L' ) + thresh = adj*(3*m+1)*testinghelpers::getEpsilon(); + else + thresh = adj*(3*n+1)*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters @@ -91,53 +106,13 @@ TEST_P(chemmTest, RandomData) test_hemm( storage, side, uplo, conja, transb, m, n, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh ); } -class chemmTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char side = std::get<1>(str.param); - char uplo = std::get<2>(str.param); - char conja = std::get<3>(str.param); - char tsb = std::get<4>(str.param); - gtint_t m = std::get<5>(str.param); - gtint_t n = std::get<6>(str.param); - scomplex alpha = std::get<7>(str.param); - scomplex beta = std::get<8>(str.param); - gtint_t lda_inc = std::get<9>(str.param); - gtint_t ldb_inc = std::get<10>(str.param); - gtint_t ldc_inc = std::get<11>(str.param); -#ifdef TEST_BLAS - std::string str_name = "chemm_"; -#elif TEST_CBLAS - std::string str_name = "cblas_chemm"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_chemm"; -#endif - str_name = str_name + "_" + sfm+sfm+sfm; - str_name = str_name + "_" + side + uplo; - str_name = str_name + "_" + conja + tsb; - str_name = str_name + "_" + std::to_string(m); - str_name = str_name + "_" + std::to_string(n); - std::string alpha_str = ( alpha.real > 0) ? std::to_string(int(alpha.real)) : ("m" + std::to_string(int(std::abs(alpha.real)))); - alpha_str = alpha_str + "pi" + (( alpha.imag > 0) ? std::to_string(int(alpha.imag)) : ("m" + std::to_string(int(std::abs(alpha.imag))))); - str_name = str_name + "_a" + alpha_str; - std::string beta_str = ( beta.real > 0) ? std::to_string(int(beta.real)) : ("m" + std::to_string(int(std::abs(beta.real)))); - beta_str = beta_str + "pi" + (( beta.imag > 0) ? std::to_string(int(beta.imag)) : ("m" + std::to_string(int(std::abs(beta.imag))))); - str_name = str_name + "_" + std::to_string(lda_inc); - str_name = str_name + "_" + std::to_string(ldb_inc); - str_name = str_name + "_" + std::to_string(ldc_inc); - return str_name; - } -}; - // Black box testing. INSTANTIATE_TEST_SUITE_P( Blackbox, - chemmTest, + chemmGeneric, ::testing::Combine( ::testing::Values('c' -#ifndef TEST_BLAS +#ifndef TEST_BLAS_LIKE ,'r' #endif ), // storage format @@ -153,5 +128,5 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(0), gtint_t(3)), // increment to the leading dim of b ::testing::Values(gtint_t(0), gtint_t(2)) // increment to the leading dim of c ), - ::chemmTestPrint() + ::hemmGenericPrint() ); diff --git a/gtestsuite/testsuite/level3/hemm/hemm.h b/gtestsuite/testsuite/level3/hemm/hemm.h index 1cc0ca1473..427dec7bc0 100644 --- a/gtestsuite/testsuite/level3/hemm/hemm.h +++ b/gtestsuite/testsuite/level3/hemm/hemm.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -36,6 +36,7 @@ #include "blis.h" #include "common/testing_helpers.h" +#include "inc/check_error.h" /** * @brief Performs the operation: @@ -82,6 +83,18 @@ static void hemm_(char side, char uplo, gtint_t m, gtint_t n, T* alpha, throw std::runtime_error("Error in testsuite/level3/hemm.h: Invalid typename in hemm_()."); } +template +static void hemm_blis_impl(char side, char uplo, gtint_t m, gtint_t n, T* alpha, + T* ap, gtint_t lda, T* bp, gtint_t ldb, T* beta, T* cp, gtint_t ldc ) +{ + if constexpr (std::is_same::value) + chemm_blis_impl( &side, &uplo, &m, &n, alpha, ap, &lda, bp, &ldb, beta, cp, &ldc ); + else if constexpr (std::is_same::value) + zhemm_blis_impl( &side, &uplo, &m, &n, alpha, ap, &lda, bp, &ldb, beta, cp, &ldc ); + else + throw std::runtime_error("Error in testsuite/level3/hemm.h: Invalid typename in hemm_blis_impl()."); +} + template static void cblas_hemm(char storage, char side, char uplo, gtint_t m, gtint_t n, T* alpha, T* ap, gtint_t lda, @@ -151,12 +164,59 @@ template static void hemm( char storage, char side, char uplo, char conja, char transb, gtint_t m, gtint_t n, T* alpha, T* ap, gtint_t lda, T* bp, gtint_t ldb, T* beta, T* cp, gtint_t ldc ) { + +#ifdef TEST_UPPERCASE_ARGS + storage = static_cast(std::toupper(static_cast(storage))); + side = static_cast(std::toupper(static_cast(side))); + uplo = static_cast(std::toupper(static_cast(uplo))); + conja = static_cast(std::toupper(static_cast(conja))); + transb = static_cast(std::toupper(static_cast(transb))); +#endif + +#ifdef TEST_INPUT_ARGS + // Create copy of scalar input values so we can check that they are not altered. + char storage_cpy = storage; + char side_cpy = side; + char uplo_cpy = uplo; + char conja_cpy = conja; + char transb_cpy = transb; + gtint_t m_cpy = m; + gtint_t n_cpy = n; + T* alpha_cpy = alpha; + gtint_t lda_cpy = lda; + gtint_t ldb_cpy = ldb; + T* beta_cpy = beta; + gtint_t ldc_cpy = ldc; + gtint_t mn; + testinghelpers::set_dim_with_side( side, m, n, &mn ); + + // Create copy of input arrays so we can check that they are not altered. + T* ap_cpy = nullptr; + gtint_t size_ap = testinghelpers::matsize( storage, 'n', mn, mn, lda ); + if (ap && size_ap > 0) + { + ap_cpy = new T[size_ap]; + memcpy( ap_cpy, ap, size_ap * sizeof( T ) ); + } + T* bp_cpy = nullptr; + gtint_t size_bp = testinghelpers::matsize( storage, transb, m, n, ldb ); + if (bp && size_bp > 0) + { + bp_cpy = new T[size_bp]; + memcpy( bp_cpy, bp, size_bp * sizeof( T ) ); + } +#endif + #ifdef TEST_BLAS if( storage == 'c' || storage == 'C' ) hemm_( side, uplo, m, n, alpha, ap, lda, bp, ldb, beta, cp, ldc ); else throw std::runtime_error("Error in testsuite/level3/hemm.h: BLAS interface cannot be tested for row-major order."); - +#elif TEST_BLAS_BLIS_IMPL + if( storage == 'c' || storage == 'C' ) + hemm_blis_impl( side, uplo, m, n, alpha, ap, lda, bp, ldb, beta, cp, ldc ); + else + throw std::runtime_error("Error in testsuite/level3/hemm.h: BLAS_BLIS_IMPL interface cannot be tested for row-major order."); #elif TEST_CBLAS cblas_hemm( storage, side, uplo, m, n, alpha, ap, lda, bp, ldb, beta, cp, ldc ); #elif TEST_BLIS_TYPED @@ -164,4 +224,42 @@ static void hemm( char storage, char side, char uplo, char conja, char transb, g #else throw std::runtime_error("Error in testsuite/level3/hemm.h: No interfaces are set to be tested."); #endif + +#ifdef TEST_INPUT_ARGS + //---------------------------------------------------------- + // Check scalar inputs have not been modified. + //---------------------------------------------------------- + + computediff( "storage", storage, storage_cpy ); + computediff( "side", side, side_cpy ); + computediff( "uplo", uplo, uplo_cpy ); + computediff( "conja", conja, conja_cpy ); + computediff( "transb", transb, transb_cpy ); + computediff( "m", m, m_cpy ); + computediff( "n", n, n_cpy ); + if (alpha) computediff( "alpha", *alpha, *alpha_cpy ); + computediff( "lda", lda, lda_cpy ); + computediff( "ldb", ldb, ldb_cpy ); + if (beta) computediff( "beta", *beta, *beta_cpy ); + computediff( "ldc", ldc, ldc_cpy ); + + //---------------------------------------------------------- + // Bitwise-wise check array inputs have not been modified. + //---------------------------------------------------------- + + if (ap && size_ap > 0) + { + computediff( "A", storage, mn, mn, ap, ap_cpy, lda, true ); + delete[] ap_cpy; + } + + if (bp && size_bp > 0) + { + if(( transb == 'n' ) || ( transb == 'N' )) + computediff( "B", storage, m, n, bp, bp_cpy, ldb, true ); + else + computediff( "B", storage, n, m, bp, bp_cpy, ldb, true ); + delete[] bp_cpy; + } +#endif } diff --git a/gtestsuite/testsuite/level3/hemm/test_hemm.h b/gtestsuite/testsuite/level3/hemm/test_hemm.h index a55510bf04..e64798ba0e 100644 --- a/gtestsuite/testsuite/level3/hemm/test_hemm.h +++ b/gtestsuite/testsuite/level3/hemm/test_hemm.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -60,7 +60,15 @@ void test_hemm( char storage, char side, char uplo, char conja, char transb, // that code operates as expected. std::vector a = testinghelpers::get_random_matrix( -5, 2, storage, uplo, k, lda ); std::vector b = testinghelpers::get_random_matrix( -5, 2, storage, transb, m, n, ldb ); - std::vector c = testinghelpers::get_random_matrix( -3, 5, storage, 'n', m, n, ldc ); + std::vector c( testinghelpers::matsize( storage, 'n', m, n, ldc ) ); + if (beta != testinghelpers::ZERO()) + testinghelpers::datagenerators::randomgenerators( -3, 5, storage, m, n, c.data(), 'n', ldc ); + else + { + // Matrix C should not be read, only set. + testinghelpers::set_matrix( storage, m, n, c.data(), 'n', ldc, testinghelpers::aocl_extreme() ); + } + // Create a copy of c so that we can check reference results. std::vector c_ref(c); @@ -79,5 +87,50 @@ void test_hemm( char storage, char side, char uplo, char conja, char transb, //---------------------------------------------------------- // check component-wise error. //---------------------------------------------------------- - computediff( storage, m, n, c.data(), c_ref.data(), ldc, thresh ); + computediff( "C", storage, m, n, c.data(), c_ref.data(), ldc, thresh ); + +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif } + +// Test-case logger : Used to print the test-case details based on parameters +template +class hemmGenericPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char storage = std::get<0>(str.param); + char side = std::get<1>(str.param); + char uplo = std::get<2>(str.param); + char conja = std::get<3>(str.param); + char transb = std::get<4>(str.param); + gtint_t m = std::get<5>(str.param); + gtint_t n = std::get<6>(str.param); + T alpha = std::get<7>(str.param); + T beta = std::get<8>(str.param); + gtint_t lda_inc = std::get<9>(str.param); + gtint_t ldb_inc = std::get<10>(str.param); + gtint_t ldc_inc = std::get<11>(str.param); + + std::string str_name = API_PRINT; + str_name += "_stor_" + std::string(&storage, 1); + str_name += "_side_" + std::string(&side, 1); + str_name += "_uplo_" + std::string(&uplo, 1); + str_name += "_conja_" + std::string(&conja, 1); + str_name += "_transb_" + std::string(&transb, 1); + str_name += "_m_" + std::to_string(m); + str_name += "_n_" + std::to_string(n); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + str_name += "_beta_" + testinghelpers::get_value_string(beta); + gtint_t k = ((side == 'l')||(side == 'L'))? m : n; + gtint_t lda = testinghelpers::get_leading_dimension( storage, 'n', k, k, lda_inc ); + gtint_t ldb = testinghelpers::get_leading_dimension( storage, 'n', m, n, ldb_inc ); + gtint_t ldc = testinghelpers::get_leading_dimension( storage, 'n', m, n, ldc_inc ); + str_name += "_lda_i" + std::to_string(lda_inc) + "_" + std::to_string(lda); + str_name += "_ldb_i" + std::to_string(ldb_inc) + "_" + std::to_string(ldb); + str_name += "_ldc_i" + std::to_string(ldc_inc) + "_" + std::to_string(ldc); + return str_name; + } +}; diff --git a/gtestsuite/testsuite/level3/hemm/zhemm_generic.cpp b/gtestsuite/testsuite/level3/hemm/zhemm_generic.cpp index f509cb8881..217a90d0c5 100644 --- a/gtestsuite/testsuite/level3/hemm/zhemm_generic.cpp +++ b/gtestsuite/testsuite/level3/hemm/zhemm_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,7 +35,7 @@ #include #include "test_hemm.h" -class zhemmTest : +class zhemmGeneric : public ::testing::TestWithParam> {}; -TEST_P(zhemmTest, RandomData) +TEST_P( zhemmGeneric, API ) { using T = dcomplex; //---------------------------------------------------------- @@ -83,7 +83,21 @@ TEST_P(zhemmTest, RandomData) gtint_t ldc_inc = std::get<11>(GetParam()); // Set the threshold for the errors: - double thresh = 10*m*n*testinghelpers::getEpsilon(); + // Check gtestsuite hemm.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (m == 0 || n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO() && + (beta == testinghelpers::ZERO() || beta == testinghelpers::ONE())) + thresh = 0.0; + else + if ( side == 'l' || side == 'L' ) + thresh = (3*m+1)*testinghelpers::getEpsilon(); + else + thresh = (3*n+1)*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters @@ -91,53 +105,13 @@ TEST_P(zhemmTest, RandomData) test_hemm( storage, side, uplo, conja, transb, m, n, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh ); } -class zhemmTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char side = std::get<1>(str.param); - char uplo = std::get<2>(str.param); - char conja = std::get<3>(str.param); - char tsb = std::get<4>(str.param); - gtint_t m = std::get<5>(str.param); - gtint_t n = std::get<6>(str.param); - dcomplex alpha = std::get<7>(str.param); - dcomplex beta = std::get<8>(str.param); - gtint_t lda_inc = std::get<9>(str.param); - gtint_t ldb_inc = std::get<10>(str.param); - gtint_t ldc_inc = std::get<11>(str.param); -#ifdef TEST_BLAS - std::string str_name = "zhemm_"; -#elif TEST_CBLAS - std::string str_name = "cblas_zhemm"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_zhemm"; -#endif - str_name = str_name + "_" + sfm+sfm+sfm; - str_name = str_name + "_" + side + uplo; - str_name = str_name + "_" + conja + tsb; - str_name = str_name + "_" + std::to_string(m); - str_name = str_name + "_" + std::to_string(n); - std::string alpha_str = ( alpha.real > 0) ? std::to_string(int(alpha.real)) : ("m" + std::to_string(int(std::abs(alpha.real)))); - alpha_str = alpha_str + "pi" + (( alpha.imag > 0) ? std::to_string(int(alpha.imag)) : ("m" + std::to_string(int(std::abs(alpha.imag))))); - str_name = str_name + "_a" + alpha_str; - std::string beta_str = ( beta.real > 0) ? std::to_string(int(beta.real)) : ("m" + std::to_string(int(std::abs(beta.real)))); - beta_str = beta_str + "pi" + (( beta.imag > 0) ? std::to_string(int(beta.imag)) : ("m" + std::to_string(int(std::abs(beta.imag))))); - str_name = str_name + "_" + std::to_string(lda_inc); - str_name = str_name + "_" + std::to_string(ldb_inc); - str_name = str_name + "_" + std::to_string(ldc_inc); - return str_name; - } -}; - // Black box testing. INSTANTIATE_TEST_SUITE_P( Blackbox, - zhemmTest, + zhemmGeneric, ::testing::Combine( ::testing::Values('c' -#ifndef TEST_BLAS +#ifndef TEST_BLAS_LIKE ,'r' #endif ), // storage format @@ -153,5 +127,5 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(0), gtint_t(5)), // increment to the leading dim of b ::testing::Values(gtint_t(0), gtint_t(6)) // increment to the leading dim of c ), - ::zhemmTestPrint() + ::hemmGenericPrint() ); diff --git a/gtestsuite/testsuite/level3/her2k/cher2k_generic.cpp b/gtestsuite/testsuite/level3/her2k/cher2k_generic.cpp index b87a833950..137bb70d9c 100644 --- a/gtestsuite/testsuite/level3/her2k/cher2k_generic.cpp +++ b/gtestsuite/testsuite/level3/her2k/cher2k_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,7 +35,7 @@ #include #include "test_her2k.h" -class cher2kTest : +class cher2kGeneric : public ::testing::TestWithParam> {}; -TEST_P(cher2kTest, RandomData) +TEST_P( cher2kGeneric, API ) { using T = scomplex; using RT = typename testinghelpers::type_info::real_type; @@ -64,9 +64,9 @@ TEST_P(cher2kTest, RandomData) char transa = std::get<2>(GetParam()); // denotes whether matrix b is n,c,t,h char transb = std::get<3>(GetParam()); - // matrix size m - gtint_t m = std::get<4>(GetParam()); // matrix size n + gtint_t n = std::get<4>(GetParam()); + // matrix size k gtint_t k = std::get<5>(GetParam()); // specifies alpha value T alpha = std::get<6>(GetParam()); @@ -80,73 +80,45 @@ TEST_P(cher2kTest, RandomData) gtint_t ldc_inc = std::get<10>(GetParam()); // Set the threshold for the errors: - double thresh = 2*m*k*testinghelpers::getEpsilon(); + // Check gtestsuite her2k.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // With adjustment for complex data. + double thresh; + double adj = 2.5; + if (n == 0) + thresh = 0.0; + else if ((alpha == testinghelpers::ZERO() || k == 0) && (beta == 0.0f || beta == 1.0f)) + thresh = 0.0; + else + thresh = adj*(6*k+1)*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_her2k( storage, uplo, transa, transb, m, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh ); + test_her2k( storage, uplo, transa, transb, n, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh ); } -class cher2kTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char uplo = std::get<1>(str.param); - char tsa = std::get<2>(str.param); - char tsb = std::get<3>(str.param); - gtint_t m = std::get<4>(str.param); - gtint_t k = std::get<5>(str.param); - scomplex alpha = std::get<6>(str.param); - float beta = std::get<7>(str.param); - gtint_t lda_inc = std::get<8>(str.param); - gtint_t ldb_inc = std::get<9>(str.param); - gtint_t ldc_inc = std::get<10>(str.param); -#ifdef TEST_BLAS - std::string str_name = "cher2k_"; -#elif TEST_CBLAS - std::string str_name = "cblas_cher2k"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_cher2k"; -#endif - str_name = str_name + "_" + sfm+sfm+sfm; - str_name = str_name + "_" + uplo; - str_name = str_name + "_" + tsa + tsb; - str_name = str_name + "_" + std::to_string(m); - str_name = str_name + "_" + std::to_string(k); - std::string alpha_str = ( alpha.real > 0) ? std::to_string(int(alpha.real)) : ("m" + std::to_string(int(std::abs(alpha.real)))); - alpha_str = alpha_str + "pi" + (( alpha.imag > 0) ? std::to_string(int(alpha.imag)) : ("m" + std::to_string(int(std::abs(alpha.imag))))); - str_name = str_name + "_a" + alpha_str; - std::string beta_str = ( beta > 0) ? std::to_string(int(beta)) : "m" + std::to_string(int(std::abs(beta))); - str_name = str_name + "_b" + beta_str; - str_name = str_name + "_" + std::to_string(lda_inc); - str_name = str_name + "_" + std::to_string(ldb_inc); - str_name = str_name + "_" + std::to_string(ldc_inc); - return str_name; - } -}; - // Black box testing. INSTANTIATE_TEST_SUITE_P( Blackbox, - cher2kTest, + cher2kGeneric, ::testing::Combine( ::testing::Values('c' -#ifndef TEST_BLAS +#ifndef TEST_BLAS_LIKE ,'r' #endif ), // storage format ::testing::Values('u','l'), // u:upper, l:lower ::testing::Values('n'), // transa ::testing::Values('n'), // transb - ::testing::Range(gtint_t(10), gtint_t(31), 10), // m ::testing::Range(gtint_t(10), gtint_t(31), 10), // n + ::testing::Range(gtint_t(10), gtint_t(31), 10), // k ::testing::Values(scomplex{2.0, -1.0}, scomplex{-2.0, 3.0}), // alpha ::testing::Values(-3.0, 2.0), // beta ::testing::Values(gtint_t(0), gtint_t(5)), // increment to the leading dim of a ::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of b ::testing::Values(gtint_t(0), gtint_t(1)) // increment to the leading dim of c ), - ::cher2kTestPrint() + ::her2kGenericPrint() ); diff --git a/gtestsuite/testsuite/level3/her2k/her2k.h b/gtestsuite/testsuite/level3/her2k/her2k.h index 76ea95f3b4..7ffc1ff1c5 100644 --- a/gtestsuite/testsuite/level3/her2k/her2k.h +++ b/gtestsuite/testsuite/level3/her2k/her2k.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -36,6 +36,7 @@ #include "blis.h" #include "common/testing_helpers.h" +#include "inc/check_error.h" /** * @brief Performs the operation: @@ -47,7 +48,7 @@ the matrix multiplication * @param[in] transb specifies the form of op( B ) to be used in the matrix multiplication - * @param[in] m specifies the number of rows and cols of the matrix + * @param[in] n specifies the number of rows and cols of the matrix op( A ) and rows of the matrix C and B * @param[in] k specifies the number of columns of the matrix op( B ) and the number of columns of the matrix C @@ -65,20 +66,32 @@ */ template::real_type> -static void her2k_(char uplo, char transa, gtint_t m, gtint_t k, T* alpha, +static void her2k_(char uplo, char transa, gtint_t n, gtint_t k, T* alpha, T* ap, gtint_t lda, T* bp, gtint_t ldb, RT* beta, T* cp, gtint_t ldc ) { if constexpr (std::is_same::value) - cher2k_( &uplo, &transa, &m, &k, alpha, ap, &lda, bp, &ldb, beta, cp, &ldc ); + cher2k_( &uplo, &transa, &n, &k, alpha, ap, &lda, bp, &ldb, beta, cp, &ldc ); else if constexpr (std::is_same::value) - zher2k_( &uplo, &transa, &m, &k, alpha, ap, &lda, bp, &ldb, beta, cp, &ldc ); + zher2k_( &uplo, &transa, &n, &k, alpha, ap, &lda, bp, &ldb, beta, cp, &ldc ); else throw std::runtime_error("Error in testsuite/level3/her2k.h: Invalid typename in her2k_()."); } +template::real_type> +static void her2k_blis_impl(char uplo, char transa, gtint_t n, gtint_t k, T* alpha, + T* ap, gtint_t lda, T* bp, gtint_t ldb, RT* beta, T* cp, gtint_t ldc ) +{ + if constexpr (std::is_same::value) + cher2k_blis_impl( &uplo, &transa, &n, &k, alpha, ap, &lda, bp, &ldb, beta, cp, &ldc ); + else if constexpr (std::is_same::value) + zher2k_blis_impl( &uplo, &transa, &n, &k, alpha, ap, &lda, bp, &ldb, beta, cp, &ldc ); + else + throw std::runtime_error("Error in testsuite/level3/her2k.h: Invalid typename in her2k_blis_impl()."); +} + template::real_type> static void cblas_her2k(char storage, char uplo, char transa, - gtint_t m, gtint_t k, T* alpha, T* ap, gtint_t lda, + gtint_t n, gtint_t k, T* alpha, T* ap, gtint_t lda, T* bp, gtint_t ldb, RT* beta, T* cp, gtint_t ldc) { enum CBLAS_ORDER cblas_order; @@ -90,16 +103,16 @@ static void cblas_her2k(char storage, char uplo, char transa, testinghelpers::char_to_cblas_trans( transa, &cblas_transa ); if constexpr (std::is_same::value) - cblas_cher2k( cblas_order, cblas_uplo, cblas_transa, m, k, alpha, ap, lda, bp, ldb, *beta, cp, ldc ); + cblas_cher2k( cblas_order, cblas_uplo, cblas_transa, n, k, alpha, ap, lda, bp, ldb, *beta, cp, ldc ); else if constexpr (std::is_same::value) - cblas_zher2k( cblas_order, cblas_uplo, cblas_transa, m, k, alpha, ap, lda, bp, ldb, *beta, cp, ldc ); + cblas_zher2k( cblas_order, cblas_uplo, cblas_transa, n, k, alpha, ap, lda, bp, ldb, *beta, cp, ldc ); else throw std::runtime_error("Error in testsuite/level3/her2k.h: Invalid typename in cblas_her2k()."); } template::real_type> static void typed_her2k(char storage, char uplo, char trnsa, char trnsb, - gtint_t m, gtint_t k, T* alpha, T* ap, gtint_t lda, + gtint_t n, gtint_t k, T* alpha, T* ap, gtint_t lda, T* bp, gtint_t ldb, RT* beta, T* cp, gtint_t ldc) { trans_t transa, transb; @@ -114,7 +127,7 @@ static void typed_her2k(char storage, char uplo, char trnsa, char trnsb, rsa=rsb=rsc=1; csa=csb=csc=1; - /* a = m x k b = k x n c = m x n */ + /* a = n x k b = k x n c = n x n */ if( (storage == 'c') || (storage == 'C') ) { csa = lda ; csb = ldb ; @@ -127,32 +140,113 @@ static void typed_her2k(char storage, char uplo, char trnsa, char trnsb, } if constexpr (std::is_same::value) - bli_sher2k( blis_uplo, transa, transb, m, k, alpha, ap, rsa, csa, bp, rsb, csb, beta, cp, rsc, csc ); + bli_sher2k( blis_uplo, transa, transb, n, k, alpha, ap, rsa, csa, bp, rsb, csb, beta, cp, rsc, csc ); else if constexpr (std::is_same::value) - bli_dher2k( blis_uplo, transa, transb, m, k, alpha, ap, rsa, csa, bp, rsb, csb, beta, cp, rsc, csc ); + bli_dher2k( blis_uplo, transa, transb, n, k, alpha, ap, rsa, csa, bp, rsb, csb, beta, cp, rsc, csc ); else if constexpr (std::is_same::value) - bli_cher2k( blis_uplo, transa, transb, m, k, alpha, ap, rsa, csa, bp, rsb, csb, beta, cp, rsc, csc ); + bli_cher2k( blis_uplo, transa, transb, n, k, alpha, ap, rsa, csa, bp, rsb, csb, beta, cp, rsc, csc ); else if constexpr (std::is_same::value) - bli_zher2k( blis_uplo, transa, transb, m, k, alpha, ap, rsa, csa, bp, rsb, csb, beta, cp, rsc, csc ); + bli_zher2k( blis_uplo, transa, transb, n, k, alpha, ap, rsa, csa, bp, rsb, csb, beta, cp, rsc, csc ); else throw std::runtime_error("Error in testsuite/level3/her2k.h: Invalid typename in typed_her2k()."); } template::real_type> -static void her2k( char storage, char uplo, char transa, char transb, gtint_t m, gtint_t k, +static void her2k( char storage, char uplo, char transa, char transb, gtint_t n, gtint_t k, T* alpha, T* ap, gtint_t lda, T* bp, gtint_t ldb, RT* beta, T* cp, gtint_t ldc ) { + +#ifdef TEST_UPPERCASE_ARGS + storage = static_cast(std::toupper(static_cast(storage))); + uplo = static_cast(std::toupper(static_cast(uplo))); + transa = static_cast(std::toupper(static_cast(transa))); + transb = static_cast(std::toupper(static_cast(transb))); +#endif + +#ifdef TEST_INPUT_ARGS + // Create copy of scalar input values so we can check that they are not altered. + char storage_cpy = storage; + char uplo_cpy = uplo; + char transa_cpy = transa; + gtint_t n_cpy = n; + gtint_t k_cpy = k; + T* alpha_cpy = alpha; + gtint_t lda_cpy = lda; + gtint_t ldb_cpy = ldb; + RT* beta_cpy = beta; + gtint_t ldc_cpy = ldc; + + // Create copy of input arrays so we can check that they are not altered. + T* ap_cpy = nullptr; + gtint_t size_ap = testinghelpers::matsize( storage, transa, n, k, lda ); + if (ap && size_ap > 0) + { + ap_cpy = new T[size_ap]; + memcpy( ap_cpy, ap, size_ap * sizeof( T ) ); + } + T* bp_cpy = nullptr; + gtint_t size_bp = testinghelpers::matsize( storage, transb, n, k, ldb ); + if (bp && size_bp > 0) + { + bp_cpy = new T[size_bp]; + memcpy( bp_cpy, bp, size_bp * sizeof( T ) ); + } +#endif + #ifdef TEST_BLAS if( storage == 'c' || storage == 'C' ) - her2k_( uplo, transa, m, k, alpha, ap, lda, bp, ldb, beta, cp, ldc ); + her2k_( uplo, transa, n, k, alpha, ap, lda, bp, ldb, beta, cp, ldc ); else throw std::runtime_error("Error in testsuite/level3/her2k.h: BLAS interface cannot be tested for row-major order."); - +#elif TEST_BLAS_BLIS_IMPL + if( storage == 'c' || storage == 'C' ) + her2k_blis_impl( uplo, transa, n, k, alpha, ap, lda, bp, ldb, beta, cp, ldc ); + else + throw std::runtime_error("Error in testsuite/level3/her2k.h: BLAS_BLIS_IMPL interface cannot be tested for row-major order."); #elif TEST_CBLAS - cblas_her2k( storage, uplo, transa, m, k, alpha, ap, lda, bp, ldb, beta, cp, ldc ); + cblas_her2k( storage, uplo, transa, n, k, alpha, ap, lda, bp, ldb, beta, cp, ldc ); #elif TEST_BLIS_TYPED - typed_her2k( storage, uplo, transa, transb, m, k, alpha, ap, lda, bp, ldb, beta, cp, ldc ); + typed_her2k( storage, uplo, transa, transb, n, k, alpha, ap, lda, bp, ldb, beta, cp, ldc ); #else throw std::runtime_error("Error in testsuite/level3/her2k.h: No interfaces are set to be tested."); #endif + +#ifdef TEST_INPUT_ARGS + //---------------------------------------------------------- + // Check scalar inputs have not been modified. + //---------------------------------------------------------- + + computediff( "storage", storage, storage_cpy ); + computediff( "uplo", uplo, uplo_cpy ); + computediff( "transa", transa, transa_cpy ); + computediff( "n", n, n_cpy ); + computediff( "k", k, k_cpy ); + if (alpha) computediff( "alpha", *alpha, *alpha_cpy ); + computediff( "lda", lda, lda_cpy ); + computediff( "ldb", ldb, ldb_cpy ); + if (beta) computediff( "beta", *beta, *beta_cpy ); + computediff( "ldc", ldc, ldc_cpy ); + + //---------------------------------------------------------- + // Bitwise-wise check array inputs have not been modified. + //---------------------------------------------------------- + + if (ap && size_ap > 0) + { + if(( transa == 'n' ) || ( transa == 'N' )) + computediff( "A", storage, n, k, ap, ap_cpy, lda, true ); + else + computediff( "A", storage, k, n, ap, ap_cpy, lda, true ); + delete[] ap_cpy; + } + + if (bp && size_bp > 0) + { + if(( transb == 'n' ) || ( transb == 'N' )) + computediff( "B", storage, n, k, bp, bp_cpy, ldb, true ); + else + computediff( "B", storage, k, n, bp, bp_cpy, ldb, true ); + delete[] bp_cpy; + } +#endif } diff --git a/gtestsuite/testsuite/level3/her2k/test_her2k.h b/gtestsuite/testsuite/level3/her2k/test_her2k.h index 18ab391cd7..3302e67f1a 100644 --- a/gtestsuite/testsuite/level3/her2k/test_her2k.h +++ b/gtestsuite/testsuite/level3/her2k/test_her2k.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -42,23 +42,30 @@ template::real_type> void test_her2k( char storage, char uplo, char transa, char transb, - gtint_t m, gtint_t k, gtint_t lda_inc, gtint_t ldb_inc, gtint_t ldc_inc, + gtint_t n, gtint_t k, gtint_t lda_inc, gtint_t ldb_inc, gtint_t ldc_inc, T alpha, RT beta, double thresh ) { // Compute the leading dimensions of a, b, and c. - gtint_t lda = testinghelpers::get_leading_dimension( storage, transa, m, k, lda_inc ); - gtint_t ldb = testinghelpers::get_leading_dimension( storage, transb, m, k, ldb_inc ); - gtint_t ldc = testinghelpers::get_leading_dimension( storage, 'n', m, m, ldc_inc ); + gtint_t lda = testinghelpers::get_leading_dimension( storage, transa, n, k, lda_inc ); + gtint_t ldb = testinghelpers::get_leading_dimension( storage, transb, n, k, ldb_inc ); + gtint_t ldc = testinghelpers::get_leading_dimension( storage, 'n', n, n, ldc_inc ); //---------------------------------------------------------- // Initialize matrics with random numbers //---------------------------------------------------------- - std::vector a = testinghelpers::get_random_matrix( -2, 8, storage, transa, m, k, lda ); - std::vector b = testinghelpers::get_random_matrix( -5, 2, storage, transb, m, k, ldb ); - // Since matrix C, stored in c, is symmetric and we only use the upper or lower - // part in the computation of her2k and zero-out the rest to ensure - // that code operates as expected. - std::vector c = testinghelpers::get_random_matrix(-3, 5, storage, uplo, m, ldc ); + std::vector a = testinghelpers::get_random_matrix( -2, 8, storage, transa, n, k, lda ); + std::vector b = testinghelpers::get_random_matrix( -5, 2, storage, transb, n, k, ldb ); + std::vector c( testinghelpers::matsize( storage, 'n', n, n, ldc ) ); + if (beta != testinghelpers::ZERO()) + // Since matrix C, stored in c, is symmetric and we only use the upper or lower + // part in the computation of her2k and zero-out the rest to ensure + // that code operates as expected. + testinghelpers::datagenerators::randomgenerators( -3, 5, storage, uplo, n, c.data(), ldc ); + else + { + // Matrix C should not be read, only set. + testinghelpers::set_matrix( storage, n, c.data(), uplo, ldc, testinghelpers::aocl_extreme() ); + } // Create a copy of c so that we can check reference results. std::vector c_ref(c); @@ -66,17 +73,59 @@ void test_her2k( char storage, char uplo, char transa, char transb, //---------------------------------------------------------- // Call BLIS function //---------------------------------------------------------- - her2k( storage, uplo, transa, transb, m, k, &alpha, a.data(), lda, + her2k( storage, uplo, transa, transb, n, k, &alpha, a.data(), lda, b.data(), ldb, &beta, c.data(), ldc ); //---------------------------------------------------------- // Call reference implementation. //---------------------------------------------------------- - testinghelpers::ref_her2k( storage, uplo, transa, transb, m, k, &alpha, + testinghelpers::ref_her2k( storage, uplo, transa, transb, n, k, &alpha, a.data(), lda, b.data(), ldb, beta, c_ref.data(), ldc ); //---------------------------------------------------------- // check component-wise error. //---------------------------------------------------------- - computediff( storage, m, m, c.data(), c_ref.data(), ldc, thresh ); + computediff( "C", storage, n, n, c.data(), c_ref.data(), ldc, thresh ); + +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif } + +// Test-case logger : Used to print the test-case details based on parameters +template ::real_type> +class her2kGenericPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char storage = std::get<0>(str.param); + char uplo = std::get<1>(str.param); + char transa = std::get<2>(str.param); + char transb = std::get<3>(str.param); + gtint_t n = std::get<4>(str.param); + gtint_t k = std::get<5>(str.param); + T alpha = std::get<6>(str.param); + RT beta = std::get<7>(str.param); + gtint_t lda_inc = std::get<8>(str.param); + gtint_t ldb_inc = std::get<9>(str.param); + gtint_t ldc_inc = std::get<10>(str.param); + + std::string str_name = API_PRINT; + str_name += "_stor_" + std::string(&storage, 1); + str_name += "_uplo_" + std::string(&uplo, 1); + str_name += "_transa_" + std::string(&transa, 1); + str_name += "_transb_" + std::string(&transb, 1); + str_name += "_n_" + std::to_string(n); + str_name += "_k_" + std::to_string(k); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + str_name += "_beta_" + testinghelpers::get_value_string(beta); + gtint_t lda = testinghelpers::get_leading_dimension( storage, transa, n, k, lda_inc ); + gtint_t ldb = testinghelpers::get_leading_dimension( storage, transb, n, k, ldb_inc ); + gtint_t ldc = testinghelpers::get_leading_dimension( storage, 'n', n, n, ldc_inc ); + str_name += "_lda_i" + std::to_string(lda_inc) + "_" + std::to_string(lda); + str_name += "_ldb_i" + std::to_string(ldb_inc) + "_" + std::to_string(ldb); + str_name += "_ldc_i" + std::to_string(ldc_inc) + "_" + std::to_string(ldc); + return str_name; + } +}; diff --git a/gtestsuite/testsuite/level3/her2k/zher2k_generic.cpp b/gtestsuite/testsuite/level3/her2k/zher2k_generic.cpp index 2ae305c086..fb0109e43a 100644 --- a/gtestsuite/testsuite/level3/her2k/zher2k_generic.cpp +++ b/gtestsuite/testsuite/level3/her2k/zher2k_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,7 +35,7 @@ #include #include "test_her2k.h" -class zher2kTest : +class zher2kGeneric : public ::testing::TestWithParam> {}; -TEST_P(zher2kTest, RandomData) +TEST_P( zher2kGeneric, API ) { using T = dcomplex; using RT = typename testinghelpers::type_info::real_type; @@ -64,9 +64,9 @@ TEST_P(zher2kTest, RandomData) char transa = std::get<2>(GetParam()); // denotes whether matrix b is n,c,t,h char transb = std::get<3>(GetParam()); - // matrix size m - gtint_t m = std::get<4>(GetParam()); // matrix size n + gtint_t n = std::get<4>(GetParam()); + // matrix size k gtint_t k = std::get<5>(GetParam()); // specifies alpha value T alpha = std::get<6>(GetParam()); @@ -80,73 +80,45 @@ TEST_P(zher2kTest, RandomData) gtint_t ldc_inc = std::get<10>(GetParam()); // Set the threshold for the errors: - double thresh = 2*m*k*testinghelpers::getEpsilon(); + // Check gtestsuite her2k.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // With adjustment for complex data. + double thresh; + double adj = 2.5; + if (n == 0) + thresh = 0.0; + else if ((alpha == testinghelpers::ZERO() || k == 0) && (beta == 0.0 || beta == 1.0)) + thresh = 0.0; + else + thresh = adj*(6*k+1)*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_her2k( storage, uplo, transa, transb, m, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh ); + test_her2k( storage, uplo, transa, transb, n, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh ); } -class zher2kTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char uplo = std::get<1>(str.param); - char tsa = std::get<2>(str.param); - char tsb = std::get<3>(str.param); - gtint_t m = std::get<4>(str.param); - gtint_t k = std::get<5>(str.param); - dcomplex alpha = std::get<6>(str.param); - double beta = std::get<7>(str.param); - gtint_t lda_inc = std::get<8>(str.param); - gtint_t ldb_inc = std::get<9>(str.param); - gtint_t ldc_inc = std::get<10>(str.param); -#ifdef TEST_BLAS - std::string str_name = "zher2k_"; -#elif TEST_CBLAS - std::string str_name = "cblas_zher2k"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_zher2k"; -#endif - str_name = str_name + "_" + sfm+sfm+sfm; - str_name = str_name + "_" + uplo; - str_name = str_name + "_" + tsa + tsb; - str_name = str_name + "_" + std::to_string(m); - str_name = str_name + "_" + std::to_string(k); - std::string alpha_str = ( alpha.real > 0) ? std::to_string(int(alpha.real)) : ("m" + std::to_string(int(std::abs(alpha.real)))); - alpha_str = alpha_str + "pi" + (( alpha.imag > 0) ? std::to_string(int(alpha.imag)) : ("m" + std::to_string(int(std::abs(alpha.imag))))); - str_name = str_name + "_a" + alpha_str; - std::string beta_str = ( beta > 0) ? std::to_string(int(beta)) : "m" + std::to_string(int(std::abs(beta))); - str_name = str_name + "_b" + beta_str; - str_name = str_name + "_" + std::to_string(lda_inc); - str_name = str_name + "_" + std::to_string(ldb_inc); - str_name = str_name + "_" + std::to_string(ldc_inc); - return str_name; - } -}; - // Black box testing. INSTANTIATE_TEST_SUITE_P( Blackbox, - zher2kTest, + zher2kGeneric, ::testing::Combine( ::testing::Values('c' -#ifndef TEST_BLAS +#ifndef TEST_BLAS_LIKE ,'r' #endif ), // storage format ::testing::Values('u','l'), // u:upper, l:lower ::testing::Values('n'), // transa ::testing::Values('n'), // transb - ::testing::Range(gtint_t(10), gtint_t(31), 10), // m ::testing::Range(gtint_t(10), gtint_t(31), 10), // n + ::testing::Range(gtint_t(10), gtint_t(31), 10), // k ::testing::Values(dcomplex{2.0, -1.0}, dcomplex{-2.0, 3.0}), // alpha ::testing::Values(4.0, -1.0), // beta ::testing::Values(gtint_t(0), gtint_t(3)), // increment to the leading dim of a ::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of b ::testing::Values(gtint_t(0), gtint_t(1)) // increment to the leading dim of c ), - ::zher2kTestPrint() + ::her2kGenericPrint() ); diff --git a/gtestsuite/testsuite/level3/herk/cherk_generic.cpp b/gtestsuite/testsuite/level3/herk/cherk_generic.cpp index 868b637d3a..c604598779 100644 --- a/gtestsuite/testsuite/level3/herk/cherk_generic.cpp +++ b/gtestsuite/testsuite/level3/herk/cherk_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,7 +35,7 @@ #include #include "test_herk.h" -class cherkTest : +class cherkGeneric : public ::testing::TestWithParam> {}; -TEST_P(cherkTest, RandomData) +TEST_P( cherkGeneric, API ) { using T = scomplex; using RT = typename testinghelpers::type_info::real_type; @@ -60,8 +60,8 @@ TEST_P(cherkTest, RandomData) char uplo = std::get<1>(GetParam()); // denotes whether matrix a is n,c,t,h char transa = std::get<2>(GetParam()); - // matrix size m - gtint_t m = std::get<3>(GetParam()); + // matrix size n + gtint_t n = std::get<3>(GetParam()); // matrix size k gtint_t k = std::get<4>(GetParam()); // specifies alpha value @@ -75,67 +75,42 @@ TEST_P(cherkTest, RandomData) gtint_t ldc_inc = std::get<8>(GetParam()); // Set the threshold for the errors: - double thresh = m*k*testinghelpers::getEpsilon(); + // Check gtestsuite herk.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (n == 0) + thresh = 0.0; + else if ((alpha == 0.0f || k == 0) && (beta == 0.0f || beta == 1.0f)) + thresh = 0.0; + else + thresh = (3*k+1)*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_herk( storage, uplo, transa, m, k, lda_inc, ldc_inc, alpha, beta, thresh ); + test_herk( storage, uplo, transa, n, k, lda_inc, ldc_inc, alpha, beta, thresh ); } -class cherkTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char uplo = std::get<1>(str.param); - char tsa = std::get<2>(str.param); - gtint_t m = std::get<3>(str.param); - gtint_t k = std::get<4>(str.param); - float alpha = std::get<5>(str.param); - float beta = std::get<6>(str.param); - gtint_t lda_inc = std::get<7>(str.param); - gtint_t ldc_inc = std::get<8>(str.param); -#ifdef TEST_BLAS - std::string str_name = "cherk_"; -#elif TEST_CBLAS - std::string str_name = "cblas_cherk"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_cherk"; -#endif - str_name = str_name + "_" + sfm+sfm+sfm; - str_name = str_name + "_" + uplo; - str_name = str_name + "_" + tsa; - str_name = str_name + "_" + std::to_string(m); - str_name = str_name + "_" + std::to_string(k); - std::string alpha_str = ( alpha > 0) ? std::to_string(int(alpha)) : "m" + std::to_string(int(std::abs(alpha))); - str_name = str_name + "_a" + alpha_str; - std::string beta_str = ( beta > 0) ? std::to_string(int(beta)) : "m" + std::to_string(int(std::abs(beta))); - str_name = str_name + "_b" + beta_str; - str_name = str_name + "_" + std::to_string(lda_inc); - str_name = str_name + "_" + std::to_string(ldc_inc); - return str_name; - } -}; - // Black box testing. INSTANTIATE_TEST_SUITE_P( Blackbox, - cherkTest, + cherkGeneric, ::testing::Combine( ::testing::Values('c' -#ifndef TEST_BLAS +#ifndef TEST_BLAS_LIKE ,'r' #endif ), // storage format ::testing::Values('u','l'), // u:upper, l:lower ::testing::Values('n','c'), // transa - ::testing::Range(gtint_t(10), gtint_t(31), 10), // m ::testing::Range(gtint_t(10), gtint_t(31), 10), // n + ::testing::Range(gtint_t(10), gtint_t(31), 10), // k ::testing::Values(-2.0, 3.0), // alpha ::testing::Values(4.0, -1.0), // beta ::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of a ::testing::Values(gtint_t(0), gtint_t(4)) // increment to the leading dim of b ), - ::cherkTestPrint() + ::herkGenericPrint() ); diff --git a/gtestsuite/testsuite/level3/herk/herk.h b/gtestsuite/testsuite/level3/herk/herk.h index 6aab4355dc..23539adf59 100644 --- a/gtestsuite/testsuite/level3/herk/herk.h +++ b/gtestsuite/testsuite/level3/herk/herk.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -36,6 +36,7 @@ #include "blis.h" #include "common/testing_helpers.h" +#include "inc/check_error.h" /** * @brief Performs the operation: @@ -59,20 +60,32 @@ */ template::real_type> -static void herk_(char uplo, char transa, gtint_t m, gtint_t k, RT* alpha, +static void herk_(char uplo, char transa, gtint_t n, gtint_t k, RT* alpha, T* ap, gtint_t lda, RT* beta, T* cp, gtint_t ldc ) { if constexpr (std::is_same::value) - cherk_( &uplo, &transa, &m, &k, alpha, ap, &lda, beta, cp, &ldc ); + cherk_( &uplo, &transa, &n, &k, alpha, ap, &lda, beta, cp, &ldc ); else if constexpr (std::is_same::value) - zherk_( &uplo, &transa, &m, &k, alpha, ap, &lda, beta, cp, &ldc ); + zherk_( &uplo, &transa, &n, &k, alpha, ap, &lda, beta, cp, &ldc ); else throw std::runtime_error("Error in testsuite/level3/herk.h: Invalid typename in herk_()."); } +template::real_type> +static void herk_blis_impl(char uplo, char transa, gtint_t n, gtint_t k, RT* alpha, + T* ap, gtint_t lda, RT* beta, T* cp, gtint_t ldc ) +{ + if constexpr (std::is_same::value) + cherk_blis_impl( &uplo, &transa, &n, &k, alpha, ap, &lda, beta, cp, &ldc ); + else if constexpr (std::is_same::value) + zherk_blis_impl( &uplo, &transa, &n, &k, alpha, ap, &lda, beta, cp, &ldc ); + else + throw std::runtime_error("Error in testsuite/level3/herk.h: Invalid typename in herk_blis_impl()."); +} + template::real_type> static void cblas_herk(char storage, char uplo, char trnsa, - gtint_t m, gtint_t k, RT* alpha, T* ap, gtint_t lda, + gtint_t n, gtint_t k, RT* alpha, T* ap, gtint_t lda, RT* beta, T* cp, gtint_t ldc) { enum CBLAS_ORDER cblas_order; @@ -84,16 +97,16 @@ static void cblas_herk(char storage, char uplo, char trnsa, testinghelpers::char_to_cblas_trans( trnsa, &cblas_transa ); if constexpr (std::is_same::value) - cblas_cherk( cblas_order, cblas_uplo, cblas_transa, m, k, *alpha, ap, lda, *beta, cp, ldc ); + cblas_cherk( cblas_order, cblas_uplo, cblas_transa, n, k, *alpha, ap, lda, *beta, cp, ldc ); else if constexpr (std::is_same::value) - cblas_zherk( cblas_order, cblas_uplo, cblas_transa, m, k, *alpha, ap, lda, *beta, cp, ldc ); + cblas_zherk( cblas_order, cblas_uplo, cblas_transa, n, k, *alpha, ap, lda, *beta, cp, ldc ); else throw std::runtime_error("Error in testsuite/level3/herk.h: Invalid typename in cblas_herk()."); } template::real_type> static void typed_herk(char storage, char uplo, char trnsa, - gtint_t m, gtint_t k, RT* alpha, T* ap, gtint_t lda, + gtint_t n, gtint_t k, RT* alpha, T* ap, gtint_t lda, RT* beta, T* cp, gtint_t ldc) { trans_t transa; @@ -106,7 +119,7 @@ static void typed_herk(char storage, char uplo, char trnsa, rsa=rsc=1; csa=csc=1; - /* a = m x k c = m x m */ + /* a = n x k c = n x n */ if( (storage == 'c') || (storage == 'C') ) { csa = lda ; csc = ldc ; @@ -117,31 +130,94 @@ static void typed_herk(char storage, char uplo, char trnsa, } if constexpr (std::is_same::value) - bli_sherk( blis_uplo, transa, m, k, alpha, ap, rsa, csa, beta, cp, rsc, csc ); + bli_sherk( blis_uplo, transa, n, k, alpha, ap, rsa, csa, beta, cp, rsc, csc ); else if constexpr (std::is_same::value) - bli_dherk( blis_uplo, transa, m, k, alpha, ap, rsa, csa, beta, cp, rsc, csc ); + bli_dherk( blis_uplo, transa, n, k, alpha, ap, rsa, csa, beta, cp, rsc, csc ); else if constexpr (std::is_same::value) - bli_cherk( blis_uplo, transa, m, k, alpha, ap, rsa, csa, beta, cp, rsc, csc ); + bli_cherk( blis_uplo, transa, n, k, alpha, ap, rsa, csa, beta, cp, rsc, csc ); else if constexpr (std::is_same::value) - bli_zherk( blis_uplo, transa, m, k, alpha, ap, rsa, csa, beta, cp, rsc, csc ); + bli_zherk( blis_uplo, transa, n, k, alpha, ap, rsa, csa, beta, cp, rsc, csc ); else throw std::runtime_error("Error in testsuite/level3/herk.h: Invalid typename in typed_herk()."); } template::real_type> -static void herk( char storage, char uplo, char transa, gtint_t m, gtint_t k, +static void herk( char storage, char uplo, char transa, gtint_t n, gtint_t k, RT* alpha, T* ap, gtint_t lda, RT* beta, T* cp, gtint_t ldc ) { + +#ifdef TEST_UPPERCASE_ARGS + storage = static_cast(std::toupper(static_cast(storage))); + uplo = static_cast(std::toupper(static_cast(uplo))); + transa = static_cast(std::toupper(static_cast(transa))); +#endif + +#ifdef TEST_INPUT_ARGS + // Create copy of scalar input values so we can check that they are not altered. + char storage_cpy = storage; + char uplo_cpy = uplo; + char transa_cpy = transa; + gtint_t n_cpy = n; + gtint_t k_cpy = k; + RT* alpha_cpy = alpha; + gtint_t lda_cpy = lda; + RT* beta_cpy = beta; + gtint_t ldc_cpy = ldc; + + // Create copy of input arrays so we can check that they are not altered. + T* ap_cpy = nullptr; + gtint_t size_ap = testinghelpers::matsize( storage, transa, n, k, lda ); + if (ap && size_ap > 0) + { + ap_cpy = new T[size_ap]; + memcpy( ap_cpy, ap, size_ap * sizeof( T ) ); + } +#endif + #ifdef TEST_BLAS if( storage == 'c' || storage == 'C' ) - herk_( uplo, transa, m, k, alpha, ap, lda, beta, cp, ldc ); + herk_( uplo, transa, n, k, alpha, ap, lda, beta, cp, ldc ); else throw std::runtime_error("Error in testsuite/level3/herk.h: BLAS interface cannot be tested for row-major order."); +#elif TEST_BLAS_BLIS_IMPL + if( storage == 'c' || storage == 'C' ) + herk_blis_impl( uplo, transa, n, k, alpha, ap, lda, beta, cp, ldc ); + else + throw std::runtime_error("Error in testsuite/level3/herk.h: BLAS_BLIS_IMPL interface cannot be tested for row-major order."); #elif TEST_CBLAS - cblas_herk( storage, uplo, transa, m, k, alpha, ap, lda, beta, cp, ldc ); + cblas_herk( storage, uplo, transa, n, k, alpha, ap, lda, beta, cp, ldc ); #elif TEST_BLIS_TYPED - typed_herk( storage, uplo, transa, m, k, alpha, ap, lda, beta, cp, ldc ); + typed_herk( storage, uplo, transa, n, k, alpha, ap, lda, beta, cp, ldc ); #else throw std::runtime_error("Error in testsuite/level3/herk.h: No interfaces are set to be tested."); #endif + +#ifdef TEST_INPUT_ARGS + //---------------------------------------------------------- + // Check scalar inputs have not been modified. + //---------------------------------------------------------- + + computediff( "storage", storage, storage_cpy ); + computediff( "uplo", uplo, uplo_cpy ); + computediff( "transa", transa, transa_cpy ); + computediff( "n", n, n_cpy ); + computediff( "k", k, k_cpy ); + if (alpha) computediff( "alpha", *alpha, *alpha_cpy ); + computediff( "lda", lda, lda_cpy ); + if (beta) computediff( "beta", *beta, *beta_cpy ); + computediff( "ldc", ldc, ldc_cpy ); + + //---------------------------------------------------------- + // Bitwise-wise check array inputs have not been modified. + //---------------------------------------------------------- + + if (ap && size_ap > 0) + { + if(( transa == 'n' ) || ( transa == 'N' )) + computediff( "A", storage, n, k, ap, ap_cpy, lda, true ); + else + computediff( "A", storage, k, n, ap, ap_cpy, lda, true ); + delete[] ap_cpy; + } +#endif } diff --git a/gtestsuite/testsuite/level3/herk/test_herk.h b/gtestsuite/testsuite/level3/herk/test_herk.h index a283366566..bac8ab0263 100644 --- a/gtestsuite/testsuite/level3/herk/test_herk.h +++ b/gtestsuite/testsuite/level3/herk/test_herk.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -41,39 +41,84 @@ #include template::real_type> -void test_herk( char storage, char uplo, char transa, gtint_t m, gtint_t k, +void test_herk( char storage, char uplo, char transa, gtint_t n, gtint_t k, gtint_t lda_inc, gtint_t ldc_inc, RT alpha, RT beta, double thresh ) { // Compute the leading dimensions of a, b, and c. - gtint_t lda = testinghelpers::get_leading_dimension( storage, transa, m, k, lda_inc ); - gtint_t ldc = testinghelpers::get_leading_dimension( storage, 'n', m, m, ldc_inc ); + gtint_t lda = testinghelpers::get_leading_dimension( storage, transa, n, k, lda_inc ); + gtint_t ldc = testinghelpers::get_leading_dimension( storage, 'n', n, n, ldc_inc ); //---------------------------------------------------------- // Initialize matrics with random integer numbers. //---------------------------------------------------------- - std::vector a = testinghelpers::get_random_matrix( -5, 2, storage, transa, m, k, lda ); - // Since matrix C, stored in c, is symmetric, we only use the upper or lower - // part in the computation of herk and zero-out the rest to ensure - // that code operates as expected. - std::vector c = testinghelpers::get_random_matrix( -8, 12, storage, uplo, m, ldc ); + std::vector a = testinghelpers::get_random_matrix( -5, 2, storage, transa, n, k, lda ); + std::vector c( testinghelpers::matsize( storage, 'n', n, n, ldc ) ); + if (beta != testinghelpers::ZERO()) + // Since matrix C, stored in c, is symmetric, we only use the upper or lower + // part in the computation of herk and zero-out the rest to ensure + // that code operates as expected. + testinghelpers::datagenerators::randomgenerators( -8, 12, storage, uplo, n, c.data(), ldc ); + else + { + // Matrix C should not be read, only set. + testinghelpers::set_matrix( storage, n, c.data(), uplo, ldc, testinghelpers::aocl_extreme() ); + } // Create a copy of c so that we can check reference results. std::vector c_ref(c); + //---------------------------------------------------------- // Call BLIS function //---------------------------------------------------------- - herk( storage, uplo, transa, m, k, &alpha, a.data(), lda, + herk( storage, uplo, transa, n, k, &alpha, a.data(), lda, &beta, c.data(), ldc ); //---------------------------------------------------------- // Call reference implementation. //---------------------------------------------------------- - testinghelpers::ref_herk( storage, uplo, transa, m, k, alpha, + testinghelpers::ref_herk( storage, uplo, transa, n, k, alpha, a.data(), lda, beta, c_ref.data(), ldc ); //---------------------------------------------------------- // check component-wise error. //---------------------------------------------------------- - computediff( storage, m, m, c.data(), c_ref.data(), ldc, thresh ); + computediff( "C", storage, n, n, c.data(), c_ref.data(), ldc, thresh ); + +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif } + +// Test-case logger : Used to print the test-case details based on parameters +template ::real_type> +class herkGenericPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char storage = std::get<0>(str.param); + char uplo = std::get<1>(str.param); + char transa = std::get<2>(str.param); + gtint_t n = std::get<3>(str.param); + gtint_t k = std::get<4>(str.param); + RT alpha = std::get<5>(str.param); + RT beta = std::get<6>(str.param); + gtint_t lda_inc = std::get<7>(str.param); + gtint_t ldc_inc = std::get<8>(str.param); + + std::string str_name = API_PRINT; + str_name += "_stor_" + std::string(&storage, 1); + str_name += "_uplo_" + std::string(&uplo, 1); + str_name += "_transa_" + std::string(&transa, 1); + str_name += "_n_" + std::to_string(n); + str_name += "_k_" + std::to_string(k); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + str_name += "_beta_" + testinghelpers::get_value_string(beta); + gtint_t lda = testinghelpers::get_leading_dimension( storage, transa, n, k, lda_inc ); + gtint_t ldc = testinghelpers::get_leading_dimension( storage, 'n', n, n, ldc_inc ); + str_name += "_lda_i" + std::to_string(lda_inc) + "_" + std::to_string(lda); + str_name += "_ldc_i" + std::to_string(ldc_inc) + "_" + std::to_string(ldc); + return str_name; + } +}; diff --git a/gtestsuite/testsuite/level3/herk/zherk_generic.cpp b/gtestsuite/testsuite/level3/herk/zherk_generic.cpp index b3d89854c6..672a6a519d 100644 --- a/gtestsuite/testsuite/level3/herk/zherk_generic.cpp +++ b/gtestsuite/testsuite/level3/herk/zherk_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,7 +35,7 @@ #include #include "test_herk.h" -class zherkTest : +class zherkGeneric : public ::testing::TestWithParam> {}; -TEST_P(zherkTest, RandomData) +TEST_P( zherkGeneric, API ) { using T = dcomplex; using RT = typename testinghelpers::type_info::real_type; @@ -60,8 +60,8 @@ TEST_P(zherkTest, RandomData) char uplo = std::get<1>(GetParam()); // denotes whether matrix a is n,c,t,h char transa = std::get<2>(GetParam()); - // matrix size m - gtint_t m = std::get<3>(GetParam()); + // matrix size n + gtint_t n = std::get<3>(GetParam()); // matrix size k gtint_t k = std::get<4>(GetParam()); // specifies alpha value @@ -75,67 +75,42 @@ TEST_P(zherkTest, RandomData) gtint_t ldc_inc = std::get<8>(GetParam()); // Set the threshold for the errors: - double thresh = m*k*testinghelpers::getEpsilon(); + // Check gtestsuite herk.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (n == 0) + thresh = 0.0; + else if ((alpha == 0.0 || k == 0) && (beta == 0.0 || beta == 1.0)) + thresh = 0.0; + else + thresh = (3*k+1)*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_herk( storage, uplo, transa, m, k, lda_inc, ldc_inc, alpha, beta, thresh ); + test_herk( storage, uplo, transa, n, k, lda_inc, ldc_inc, alpha, beta, thresh ); } -class zherkTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char uplo = std::get<1>(str.param); - char tsa = std::get<2>(str.param); - gtint_t m = std::get<3>(str.param); - gtint_t k = std::get<4>(str.param); - double alpha = std::get<5>(str.param); - double beta = std::get<6>(str.param); - gtint_t lda_inc = std::get<7>(str.param); - gtint_t ldc_inc = std::get<8>(str.param); -#ifdef TEST_BLAS - std::string str_name = "zherk_"; -#elif TEST_CBLAS - std::string str_name = "cblas_zherk"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_zherk"; -#endif - str_name = str_name + "_" + sfm+sfm+sfm; - str_name = str_name + "_" + uplo; - str_name = str_name + "_" + tsa; - str_name = str_name + "_" + std::to_string(m); - str_name = str_name + "_" + std::to_string(k); - std::string alpha_str = ( alpha > 0) ? std::to_string(int(alpha)) : "m" + std::to_string(int(std::abs(alpha))); - str_name = str_name + "_a" + alpha_str; - std::string beta_str = ( beta > 0) ? std::to_string(int(beta)) : "m" + std::to_string(int(std::abs(beta))); - str_name = str_name + "_b" + beta_str; - str_name = str_name + "_" + std::to_string(lda_inc); - str_name = str_name + "_" + std::to_string(ldc_inc); - return str_name; - } -}; - // Black box testing. INSTANTIATE_TEST_SUITE_P( Blackbox, - zherkTest, + zherkGeneric, ::testing::Combine( ::testing::Values('c' -#ifndef TEST_BLAS +#ifndef TEST_BLAS_LIKE ,'r' #endif ), // storage format ::testing::Values('u','l'), // u:upper, l:lower ::testing::Values('n','c'), // transa - ::testing::Range(gtint_t(10), gtint_t(31), 10), // m ::testing::Range(gtint_t(10), gtint_t(31), 10), // n + ::testing::Range(gtint_t(10), gtint_t(31), 10), // k ::testing::Values(2.0, -1.0), // alpha ::testing::Values(-3.0, 2.0), // beta ::testing::Values(gtint_t(0), gtint_t(4)), // increment to the leading dim of a ::testing::Values(gtint_t(0), gtint_t(2)) // increment to the leading dim of c ), - ::zherkTestPrint() + ::herkGenericPrint() ); diff --git a/gtestsuite/testsuite/level3/symm/csymm_generic.cpp b/gtestsuite/testsuite/level3/symm/csymm_generic.cpp index 72e84c9069..e1e5137c6b 100644 --- a/gtestsuite/testsuite/level3/symm/csymm_generic.cpp +++ b/gtestsuite/testsuite/level3/symm/csymm_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,7 +35,7 @@ #include #include "test_symm.h" -class csymmTest : +class csymmGeneric : public ::testing::TestWithParam> {}; -TEST_P(csymmTest, RandomData) +TEST_P( csymmGeneric, API ) { using T = scomplex; //---------------------------------------------------------- @@ -83,7 +83,24 @@ TEST_P(csymmTest, RandomData) gtint_t ldc_inc = std::get<11>(GetParam()); // Set the threshold for the errors: - double thresh = m*n*testinghelpers::getEpsilon(); + // Check gtestsuite symm.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // With adjustment for complex data. + double thresh; + double adj = 1.5; + if (m == 0 || n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO() && + (beta == testinghelpers::ZERO() || beta == testinghelpers::ONE())) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + thresh = testinghelpers::getEpsilon(); + else + if ( side == 'l' || side == 'L' ) + thresh = adj*(3*m+1)*testinghelpers::getEpsilon(); + else + thresh = adj*(3*n+1)*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters @@ -91,54 +108,13 @@ TEST_P(csymmTest, RandomData) test_symm( storage, side, uplo, conja, transb, m, n, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh ); } -class csymmTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char side = std::get<1>(str.param); - char uplo = std::get<2>(str.param); - char conja = std::get<3>(str.param); - char tsb = std::get<4>(str.param); - gtint_t m = std::get<5>(str.param); - gtint_t n = std::get<6>(str.param); - scomplex alpha = std::get<7>(str.param); - scomplex beta = std::get<8>(str.param); - gtint_t lda_inc = std::get<9>(str.param); - gtint_t ldb_inc = std::get<10>(str.param); - gtint_t ldc_inc = std::get<11>(str.param); -#ifdef TEST_BLAS - std::string str_name = "csymm_"; -#elif TEST_CBLAS - std::string str_name = "cblas_csymm"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_csymm"; -#endif - str_name = str_name + "_" + sfm+sfm+sfm; - str_name = str_name + "_" + side + uplo; - str_name = str_name + "_" + conja + tsb; - str_name = str_name + "_" + std::to_string(m); - str_name = str_name + "_" + std::to_string(n); - std::string alpha_str = ( alpha.real > 0) ? std::to_string(int(alpha.real)) : ("m" + std::to_string(int(std::abs(alpha.real)))); - alpha_str = alpha_str + "pi" + (( alpha.imag > 0) ? std::to_string(int(alpha.imag)) : ("m" + std::to_string(int(std::abs(alpha.imag))))); - str_name = str_name + "_a" + alpha_str; - std::string beta_str = ( beta.real > 0) ? std::to_string(int(beta.real)) : ("m" + std::to_string(int(std::abs(beta.real)))); - beta_str = beta_str + "pi" + (( beta.imag > 0) ? std::to_string(int(beta.imag)) : ("m" + std::to_string(int(std::abs(beta.imag))))); - str_name = str_name + "_b" + beta_str; - str_name = str_name + "_" + std::to_string(lda_inc); - str_name = str_name + "_" + std::to_string(ldb_inc); - str_name = str_name + "_" + std::to_string(ldc_inc); - return str_name; - } -}; - // Black box testing. INSTANTIATE_TEST_SUITE_P( Blackbox, - csymmTest, + csymmGeneric, ::testing::Combine( ::testing::Values('c' -#ifndef TEST_BLAS +#ifndef TEST_BLAS_LIKE ,'r' #endif ), // storage format @@ -154,5 +130,5 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of b ::testing::Values(gtint_t(0), gtint_t(4)) // increment to the leading dim of c ), - ::csymmTestPrint() + ::symmGenericPrint() ); diff --git a/gtestsuite/testsuite/level3/symm/dsymm_generic.cpp b/gtestsuite/testsuite/level3/symm/dsymm_generic.cpp index 34d4fdb474..10b90ce0d1 100644 --- a/gtestsuite/testsuite/level3/symm/dsymm_generic.cpp +++ b/gtestsuite/testsuite/level3/symm/dsymm_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,7 +35,7 @@ #include #include "test_symm.h" -class dsymmTest : +class dsymmGeneric : public ::testing::TestWithParam> {}; -TEST_P(dsymmTest, RandomData) +TEST_P( dsymmGeneric, API ) { using T = double; //---------------------------------------------------------- @@ -83,7 +83,22 @@ TEST_P(dsymmTest, RandomData) gtint_t ldc_inc = std::get<11>(GetParam()); // Set the threshold for the errors: - double thresh = 30*m*n*testinghelpers::getEpsilon(); + // Check gtestsuite symm.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (m == 0 || n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO() && + (beta == testinghelpers::ZERO() || beta == testinghelpers::ONE())) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + thresh = testinghelpers::getEpsilon(); + else + if ( side == 'l' || side == 'L' ) + thresh = (3*m+1)*testinghelpers::getEpsilon(); + else + thresh = (3*n+1)*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters @@ -91,52 +106,13 @@ TEST_P(dsymmTest, RandomData) test_symm( storage, side, uplo, conja, transb, m, n, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh ); } -class dsymmTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char side = std::get<1>(str.param); - char uplo = std::get<2>(str.param); - char conja = std::get<3>(str.param); - char tsb = std::get<4>(str.param); - gtint_t m = std::get<5>(str.param); - gtint_t n = std::get<6>(str.param); - double alpha = std::get<7>(str.param); - double beta = std::get<8>(str.param); - gtint_t lda_inc = std::get<9>(str.param); - gtint_t ldb_inc = std::get<10>(str.param); - gtint_t ldc_inc = std::get<11>(str.param); -#ifdef TEST_BLAS - std::string str_name = "dsymm_"; -#elif TEST_CBLAS - std::string str_name = "cblas_dsymm"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_dsymm"; -#endif - str_name = str_name + "_" + sfm+sfm+sfm; - str_name = str_name + "_" + side + uplo; - str_name = str_name + "_" + conja + tsb; - str_name = str_name + "_" + std::to_string(m); - str_name = str_name + "_" + std::to_string(n); - std::string alpha_str = ( alpha > 0) ? std::to_string(int(alpha)) : "m" + std::to_string(int(std::abs(alpha))); - str_name = str_name + "_a" + alpha_str; - std::string beta_str = ( beta > 0) ? std::to_string(int(beta)) : "m" + std::to_string(int(std::abs(beta))); - str_name = str_name + "_b" + beta_str; - str_name = str_name + "_" + std::to_string(lda_inc); - str_name = str_name + "_" + std::to_string(ldb_inc); - str_name = str_name + "_" + std::to_string(ldc_inc); - return str_name; - } -}; - // Black box testing. INSTANTIATE_TEST_SUITE_P( Blackbox, - dsymmTest, + dsymmGeneric, ::testing::Combine( ::testing::Values('c' -#ifndef TEST_BLAS +#ifndef TEST_BLAS_LIKE ,'r' #endif ), // storage format @@ -152,5 +128,5 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(0), gtint_t(4)), // increment to the leading dim of b ::testing::Values(gtint_t(0), gtint_t(1)) // increment to the leading dim of c ), - ::dsymmTestPrint() + ::symmGenericPrint() ); diff --git a/gtestsuite/testsuite/level3/symm/ssymm_generic.cpp b/gtestsuite/testsuite/level3/symm/ssymm_generic.cpp index 749b7a7fce..86db0240ea 100644 --- a/gtestsuite/testsuite/level3/symm/ssymm_generic.cpp +++ b/gtestsuite/testsuite/level3/symm/ssymm_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,7 +35,7 @@ #include #include "test_symm.h" -class ssymmTest : +class ssymmGeneric : public ::testing::TestWithParam> {}; -TEST_P(ssymmTest, RandomData) +TEST_P( ssymmGeneric, API ) { using T = float; //---------------------------------------------------------- @@ -83,7 +83,22 @@ TEST_P(ssymmTest, RandomData) gtint_t ldc_inc = std::get<11>(GetParam()); // Set the threshold for the errors: - double thresh = 8*m*n*testinghelpers::getEpsilon(); + // Check gtestsuite symm.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (m == 0 || n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO() && + (beta == testinghelpers::ZERO() || beta == testinghelpers::ONE())) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + thresh = testinghelpers::getEpsilon(); + else + if ( side == 'l' || side == 'L' ) + thresh = (3*m+1)*testinghelpers::getEpsilon(); + else + thresh = (3*n+1)*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters @@ -91,52 +106,13 @@ TEST_P(ssymmTest, RandomData) test_symm( storage, side, uplo, conja, transb, m, n, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh ); } -class ssymmTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char side = std::get<1>(str.param); - char uplo = std::get<2>(str.param); - char conja = std::get<3>(str.param); - char tsb = std::get<4>(str.param); - gtint_t m = std::get<5>(str.param); - gtint_t n = std::get<6>(str.param); - float alpha = std::get<7>(str.param); - float beta = std::get<8>(str.param); - gtint_t lda_inc = std::get<9>(str.param); - gtint_t ldb_inc = std::get<10>(str.param); - gtint_t ldc_inc = std::get<11>(str.param); -#ifdef TEST_BLAS - std::string str_name = "ssymm_"; -#elif TEST_CBLAS - std::string str_name = "cblas_ssymm"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_ssymm"; -#endif - str_name = str_name + "_" + sfm+sfm+sfm; - str_name = str_name + "_" + side + uplo; - str_name = str_name + "_" + conja + tsb; - str_name = str_name + "_" + std::to_string(m); - str_name = str_name + "_" + std::to_string(n); - std::string alpha_str = ( alpha > 0) ? std::to_string(int(alpha)) : "m" + std::to_string(int(std::abs(alpha))); - str_name = str_name + "_a" + alpha_str; - std::string beta_str = ( beta > 0) ? std::to_string(int(beta)) : "m" + std::to_string(int(std::abs(beta))); - str_name = str_name + "_b" + beta_str; - str_name = str_name + "_" + std::to_string(lda_inc); - str_name = str_name + "_" + std::to_string(ldb_inc); - str_name = str_name + "_" + std::to_string(ldc_inc); - return str_name; - } -}; - // Black box testing. INSTANTIATE_TEST_SUITE_P( Blackbox, - ssymmTest, + ssymmGeneric, ::testing::Combine( ::testing::Values('c' -#ifndef TEST_BLAS +#ifndef TEST_BLAS_LIKE ,'r' #endif ), // storage format @@ -152,5 +128,5 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(0), gtint_t(1)), // increment to the leading dim of b ::testing::Values(gtint_t(0), gtint_t(9)) // increment to the leading dim of c ), - ::ssymmTestPrint() + ::symmGenericPrint() ); diff --git a/gtestsuite/testsuite/level3/symm/symm.h b/gtestsuite/testsuite/level3/symm/symm.h index cc97c9304f..fc1faf5b6a 100644 --- a/gtestsuite/testsuite/level3/symm/symm.h +++ b/gtestsuite/testsuite/level3/symm/symm.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -36,6 +36,7 @@ #include "blis.h" #include "common/testing_helpers.h" +#include "inc/check_error.h" /** * @brief Performs the operation: @@ -86,6 +87,22 @@ static void symm_(char side, char uplo, gtint_t m, gtint_t n, T* alpha, throw std::runtime_error("Error in testsuite/level3/symm.h: Invalid typename in symm_()."); } +template +static void symm_blis_impl(char side, char uplo, gtint_t m, gtint_t n, T* alpha, + T* ap, gtint_t lda, T* bp, gtint_t ldb, T* beta, T* cp, gtint_t ldc ) +{ + if constexpr (std::is_same::value) + ssymm_blis_impl( &side, &uplo, &m, &n, alpha, ap, &lda, bp, &ldb, beta, cp, &ldc ); + else if constexpr (std::is_same::value) + dsymm_blis_impl( &side, &uplo, &m, &n, alpha, ap, &lda, bp, &ldb, beta, cp, &ldc ); + else if constexpr (std::is_same::value) + csymm_blis_impl( &side, &uplo, &m, &n, alpha, ap, &lda, bp, &ldb, beta, cp, &ldc ); + else if constexpr (std::is_same::value) + zsymm_blis_impl( &side, &uplo, &m, &n, alpha, ap, &lda, bp, &ldb, beta, cp, &ldc ); + else + throw std::runtime_error("Error in testsuite/level3/symm.h: Invalid typename in symm_blis_impl()."); +} + template static void cblas_symm(char storage, char side, char uplo, gtint_t m, gtint_t n, T* alpha, T* ap, gtint_t lda, @@ -159,12 +176,59 @@ template static void symm( char storage, char side, char uplo, char conja, char transb, gtint_t m, gtint_t n, T* alpha, T* ap, gtint_t lda, T* bp, gtint_t ldb, T* beta, T* cp, gtint_t ldc ) { + +#ifdef TEST_UPPERCASE_ARGS + storage = static_cast(std::toupper(static_cast(storage))); + side = static_cast(std::toupper(static_cast(side))); + uplo = static_cast(std::toupper(static_cast(uplo))); + conja = static_cast(std::toupper(static_cast(conja))); + transb = static_cast(std::toupper(static_cast(transb))); +#endif + +#ifdef TEST_INPUT_ARGS + // Create copy of scalar input values so we can check that they are not altered. + char storage_cpy = storage; + char side_cpy = side; + char uplo_cpy = uplo; + char conja_cpy = conja; + char transb_cpy = transb; + gtint_t m_cpy = m; + gtint_t n_cpy = n; + T* alpha_cpy = alpha; + gtint_t lda_cpy = lda; + gtint_t ldb_cpy = ldb; + T* beta_cpy = beta; + gtint_t ldc_cpy = ldc; + gtint_t mn; + testinghelpers::set_dim_with_side( side, m, n, &mn ); + + // Create copy of input arrays so we can check that they are not altered. + T* ap_cpy = nullptr; + gtint_t size_ap = testinghelpers::matsize( storage, 'n', mn, mn, lda ); + if (ap && size_ap > 0) + { + ap_cpy = new T[size_ap]; + memcpy( ap_cpy, ap, size_ap * sizeof( T ) ); + } + T* bp_cpy = nullptr; + gtint_t size_bp = testinghelpers::matsize( storage, transb, m, n, ldb ); + if (bp && size_bp > 0) + { + bp_cpy = new T[size_bp]; + memcpy( bp_cpy, bp, size_bp * sizeof( T ) ); + } +#endif + #ifdef TEST_BLAS if( storage == 'c' || storage == 'C' ) symm_( side, uplo, m, n, alpha, ap, lda, bp, ldb, beta, cp, ldc ); else throw std::runtime_error("Error in testsuite/level3/symm.h: BLAS interface cannot be tested for row-major order."); - +#elif TEST_BLAS_BLIS_IMPL + if( storage == 'c' || storage == 'C' ) + symm_blis_impl( side, uplo, m, n, alpha, ap, lda, bp, ldb, beta, cp, ldc ); + else + throw std::runtime_error("Error in testsuite/level3/symm.h: BLAS_BLIS_IMPL interface cannot be tested for row-major order."); #elif TEST_CBLAS cblas_symm( storage, side, uplo, m, n, alpha, ap, lda, bp, ldb, beta, cp, ldc ); #elif TEST_BLIS_TYPED @@ -172,4 +236,42 @@ static void symm( char storage, char side, char uplo, char conja, char transb, g #else throw std::runtime_error("Error in testsuite/level3/symm.h: No interfaces are set to be tested."); #endif + +#ifdef TEST_INPUT_ARGS + //---------------------------------------------------------- + // Check scalar inputs have not been modified. + //---------------------------------------------------------- + + computediff( "storage", storage, storage_cpy ); + computediff( "side", side, side_cpy ); + computediff( "uplo", uplo, uplo_cpy ); + computediff( "conja", conja, conja_cpy ); + computediff( "transb", transb, transb_cpy ); + computediff( "m", m, m_cpy ); + computediff( "n", n, n_cpy ); + if (alpha) computediff( "alpha", *alpha, *alpha_cpy ); + computediff( "lda", lda, lda_cpy ); + computediff( "ldb", ldb, ldb_cpy ); + if (beta) computediff( "beta", *beta, *beta_cpy ); + computediff( "ldc", ldc, ldc_cpy ); + + //---------------------------------------------------------- + // Bitwise-wise check array inputs have not been modified. + //---------------------------------------------------------- + + if (ap && size_ap > 0) + { + computediff( "A", storage, mn, mn, ap, ap_cpy, lda, true ); + delete[] ap_cpy; + } + + if (bp && size_bp > 0) + { + if(( transb == 'n' ) || ( transb == 'N' )) + computediff( "B", storage, m, n, bp, bp_cpy, ldb, true ); + else + computediff( "B", storage, n, m, bp, bp_cpy, ldb, true ); + delete[] bp_cpy; + } +#endif } diff --git a/gtestsuite/testsuite/level3/symm/test_symm.h b/gtestsuite/testsuite/level3/symm/test_symm.h index cc90d7f52a..402cff1841 100644 --- a/gtestsuite/testsuite/level3/symm/test_symm.h +++ b/gtestsuite/testsuite/level3/symm/test_symm.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -60,7 +60,14 @@ void test_symm( char storage, char side, char uplo, char conja, char transb, // that code operates as expected. std::vector a = testinghelpers::get_random_matrix( -5, 2, storage, uplo, k, lda ); std::vector b = testinghelpers::get_random_matrix( -5, 2, storage, transb, m, n, ldb ); - std::vector c = testinghelpers::get_random_matrix( -3, 5, storage, 'n', m, n, ldc ); + std::vector c( testinghelpers::matsize( storage, 'n', m, n, ldc ) ); + if (beta != testinghelpers::ZERO()) + testinghelpers::datagenerators::randomgenerators( -3, 5, storage, m, n, c.data(), 'n', ldc ); + else + { + // Matrix C should not be read, only set. + testinghelpers::set_matrix( storage, m, n, c.data(), 'n', ldc, testinghelpers::aocl_extreme() ); + } // Create a copy of c so that we can check reference results. std::vector c_ref(c); @@ -80,5 +87,50 @@ void test_symm( char storage, char side, char uplo, char conja, char transb, //---------------------------------------------------------- // check component-wise error. //---------------------------------------------------------- - computediff( storage, m, n, c.data(), c_ref.data(), ldc, thresh ); + computediff( "C", storage, m, n, c.data(), c_ref.data(), ldc, thresh ); + +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif } + +// Test-case logger : Used to print the test-case details based on parameters +template +class symmGenericPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char storage = std::get<0>(str.param); + char side = std::get<1>(str.param); + char uplo = std::get<2>(str.param); + char conja = std::get<3>(str.param); + char transb = std::get<4>(str.param); + gtint_t m = std::get<5>(str.param); + gtint_t n = std::get<6>(str.param); + T alpha = std::get<7>(str.param); + T beta = std::get<8>(str.param); + gtint_t lda_inc = std::get<9>(str.param); + gtint_t ldb_inc = std::get<10>(str.param); + gtint_t ldc_inc = std::get<11>(str.param); + + std::string str_name = API_PRINT; + str_name += "_stor_" + std::string(&storage, 1); + str_name += "_side_" + std::string(&side, 1); + str_name += "_uplo_" + std::string(&uplo, 1); + str_name += "_conja_" + std::string(&conja, 1); + str_name += "_transb_" + std::string(&transb, 1); + str_name += "_m_" + std::to_string(m); + str_name += "_n_" + std::to_string(n); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + str_name += "_beta_" + testinghelpers::get_value_string(beta); + gtint_t k = ((side == 'l')||(side == 'L'))? m : n; + gtint_t lda = testinghelpers::get_leading_dimension( storage, conja, k, k, lda_inc ); + gtint_t ldb = testinghelpers::get_leading_dimension( storage, transb, m, n, ldb_inc ); + gtint_t ldc = testinghelpers::get_leading_dimension( storage, 'n', m, n, ldc_inc ); + str_name += "_lda_i" + std::to_string(lda_inc) + "_" + std::to_string(lda); + str_name += "_ldb_i" + std::to_string(ldb_inc) + "_" + std::to_string(ldb); + str_name += "_ldc_i" + std::to_string(ldc_inc) + "_" + std::to_string(ldc); + return str_name; + } +}; diff --git a/gtestsuite/testsuite/level3/symm/zsymm_generic.cpp b/gtestsuite/testsuite/level3/symm/zsymm_generic.cpp index a6c163816a..a09e205a5a 100644 --- a/gtestsuite/testsuite/level3/symm/zsymm_generic.cpp +++ b/gtestsuite/testsuite/level3/symm/zsymm_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,7 +35,7 @@ #include #include "test_symm.h" -class zsymmTest : +class zsymmGeneric : public ::testing::TestWithParam> {}; -TEST_P(zsymmTest, RandomData) +TEST_P( zsymmGeneric, API ) { using T = dcomplex; //---------------------------------------------------------- @@ -83,7 +83,23 @@ TEST_P(zsymmTest, RandomData) gtint_t ldc_inc = std::get<11>(GetParam()); // Set the threshold for the errors: - double thresh = m*n*testinghelpers::getEpsilon(); + // Check gtestsuite symm.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (m == 0 || n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO() && + (beta == testinghelpers::ZERO() || beta == testinghelpers::ONE())) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + thresh = testinghelpers::getEpsilon(); + else + if ( side == 'l' || side == 'L' ) + thresh = (3*m+1)*testinghelpers::getEpsilon(); + else + thresh = (3*n+1)*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters @@ -91,54 +107,13 @@ TEST_P(zsymmTest, RandomData) test_symm( storage, side, uplo, conja, transb, m, n, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh ); } -class zsymmTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char side = std::get<1>(str.param); - char uplo = std::get<2>(str.param); - char conja = std::get<3>(str.param); - char tsb = std::get<4>(str.param); - gtint_t m = std::get<5>(str.param); - gtint_t n = std::get<6>(str.param); - dcomplex alpha = std::get<7>(str.param); - dcomplex beta = std::get<8>(str.param); - gtint_t lda_inc = std::get<9>(str.param); - gtint_t ldb_inc = std::get<10>(str.param); - gtint_t ldc_inc = std::get<11>(str.param); -#ifdef TEST_BLAS - std::string str_name = "zsymm_"; -#elif TEST_CBLAS - std::string str_name = "cblas_zsymm"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_zsymm"; -#endif - str_name = str_name + "_" + sfm+sfm+sfm; - str_name = str_name + "_" + side + uplo; - str_name = str_name + "_" + conja + tsb; - str_name = str_name + "_" + std::to_string(m); - str_name = str_name + "_" + std::to_string(n); - std::string alpha_str = ( alpha.real > 0) ? std::to_string(int(alpha.real)) : ("m" + std::to_string(int(std::abs(alpha.real)))); - alpha_str = alpha_str + "pi" + (( alpha.imag > 0) ? std::to_string(int(alpha.imag)) : ("m" + std::to_string(int(std::abs(alpha.imag))))); - str_name = str_name + "_a" + alpha_str; - std::string beta_str = ( beta.real > 0) ? std::to_string(int(beta.real)) : ("m" + std::to_string(int(std::abs(beta.real)))); - beta_str = beta_str + "pi" + (( beta.imag > 0) ? std::to_string(int(beta.imag)) : ("m" + std::to_string(int(std::abs(beta.imag))))); - str_name = str_name + "_b" + beta_str; - str_name = str_name + "_" + std::to_string(lda_inc); - str_name = str_name + "_" + std::to_string(ldb_inc); - str_name = str_name + "_" + std::to_string(ldc_inc); - return str_name; - } -}; - // Black box testing. INSTANTIATE_TEST_SUITE_P( Blackbox, - zsymmTest, + zsymmGeneric, ::testing::Combine( ::testing::Values('c' -#ifndef TEST_BLAS +#ifndef TEST_BLAS_LIKE ,'r' #endif ), // storage format @@ -154,5 +129,5 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(0), gtint_t(3)), // increment to the leading dim of b ::testing::Values(gtint_t(0), gtint_t(5)) // increment to the leading dim of c ), - ::zsymmTestPrint() + ::symmGenericPrint() ); diff --git a/gtestsuite/testsuite/level3/syr2k/csyr2k_generic.cpp b/gtestsuite/testsuite/level3/syr2k/csyr2k_generic.cpp index 2ee7903302..af8786bee6 100644 --- a/gtestsuite/testsuite/level3/syr2k/csyr2k_generic.cpp +++ b/gtestsuite/testsuite/level3/syr2k/csyr2k_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,7 +35,7 @@ #include #include "test_syr2k.h" -class csyr2kTest : +class csyr2kGeneric : public ::testing::TestWithParam> {}; -TEST_P(csyr2kTest, RandomData) +TEST_P( csyr2kGeneric, API ) { using T = scomplex; //---------------------------------------------------------- @@ -63,9 +63,9 @@ TEST_P(csyr2kTest, RandomData) char transa = std::get<2>(GetParam()); // denotes whether matrix b is n,c,t,h char transb = std::get<3>(GetParam()); - // matrix size m - gtint_t m = std::get<4>(GetParam()); // matrix size n + gtint_t n = std::get<4>(GetParam()); + // matrix size k gtint_t k = std::get<5>(GetParam()); // specifies alpha value T alpha = std::get<6>(GetParam()); @@ -79,74 +79,47 @@ TEST_P(csyr2kTest, RandomData) gtint_t ldc_inc = std::get<10>(GetParam()); // Set the threshold for the errors: - double thresh = m*k*testinghelpers::getEpsilon(); + // Check gtestsuite syr2k.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (n == 0) + thresh = 0.0; + else if ((alpha == testinghelpers::ZERO() || k == 0) && + (beta == testinghelpers::ZERO() || beta == testinghelpers::ONE())) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + thresh = testinghelpers::getEpsilon(); + else + thresh = (6*k+1)*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_syr2k( storage, uplo, transa, transb, m, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh ); + test_syr2k( storage, uplo, transa, transb, n, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh ); } -class csyr2kTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char uplo = std::get<1>(str.param); - char tsa = std::get<2>(str.param); - char tsb = std::get<3>(str.param); - gtint_t m = std::get<4>(str.param); - gtint_t k = std::get<5>(str.param); - scomplex alpha = std::get<6>(str.param); - scomplex beta = std::get<7>(str.param); - gtint_t lda_inc = std::get<8>(str.param); - gtint_t ldb_inc = std::get<9>(str.param); - gtint_t ldc_inc = std::get<10>(str.param); -#ifdef TEST_BLAS - std::string str_name = "csyr2k_"; -#elif TEST_CBLAS - std::string str_name = "cblas_csyr2k"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_csyr2k"; -#endif - str_name = str_name + "_" + sfm+sfm+sfm; - str_name = str_name + "_" + uplo; - str_name = str_name + "_" + tsa + tsb; - str_name = str_name + "_" + std::to_string(m); - str_name = str_name + "_" + std::to_string(k); - std::string alpha_str = ( alpha.real > 0) ? std::to_string(int(alpha.real)) : ("m" + std::to_string(int(std::abs(alpha.real)))); - alpha_str = alpha_str + "pi" + (( alpha.imag > 0) ? std::to_string(int(alpha.imag)) : ("m" + std::to_string(int(std::abs(alpha.imag))))); - str_name = str_name + "_a" + alpha_str; - std::string beta_str = ( beta.real > 0) ? std::to_string(int(beta.real)) : ("m" + std::to_string(int(std::abs(beta.real)))); - beta_str = beta_str + "pi" + (( beta.imag > 0) ? std::to_string(int(beta.imag)) : ("m" + std::to_string(int(std::abs(beta.imag))))); - str_name = str_name + "_a" + beta_str; - str_name = str_name + "_" + std::to_string(lda_inc); - str_name = str_name + "_" + std::to_string(ldb_inc); - str_name = str_name + "_" + std::to_string(ldc_inc); - return str_name; - } -}; - // Black box testing. INSTANTIATE_TEST_SUITE_P( Blackbox, - csyr2kTest, + csyr2kGeneric, ::testing::Combine( ::testing::Values('c' -#ifndef TEST_BLAS +#ifndef TEST_BLAS_LIKE ,'r' #endif ), // storage format ::testing::Values('u','l'), // u:upper, l:lower ::testing::Values('n'), // transa ::testing::Values('n'), // transb - ::testing::Range(gtint_t(10), gtint_t(31), 10), // m ::testing::Range(gtint_t(10), gtint_t(31), 10), // n + ::testing::Range(gtint_t(10), gtint_t(31), 10), // k ::testing::Values(scomplex{2.0, -1.0}, scomplex{-2.0, 3.0}), // alpha ::testing::Values(scomplex{-3.0, 2.0}, scomplex{4.0, -1.0}), // beta ::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of a ::testing::Values(gtint_t(0), gtint_t(4)), // increment to the leading dim of b ::testing::Values(gtint_t(0), gtint_t(1)) // increment to the leading dim of c ), - ::csyr2kTestPrint() + ::syr2kGenericPrint() ); diff --git a/gtestsuite/testsuite/level3/syr2k/dsyr2k_generic.cpp b/gtestsuite/testsuite/level3/syr2k/dsyr2k_generic.cpp index f990ef6ac3..c38a317da9 100644 --- a/gtestsuite/testsuite/level3/syr2k/dsyr2k_generic.cpp +++ b/gtestsuite/testsuite/level3/syr2k/dsyr2k_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,7 +35,7 @@ #include #include "test_syr2k.h" -class dsyr2kTest : +class dsyr2kGeneric : public ::testing::TestWithParam> {}; -TEST_P(dsyr2kTest, RandomData) +TEST_P( dsyr2kGeneric, API ) { using T = double; //---------------------------------------------------------- @@ -63,9 +63,9 @@ TEST_P(dsyr2kTest, RandomData) char transa = std::get<2>(GetParam()); // denotes whether matrix b is n,c,t,h char transb = std::get<3>(GetParam()); - // matrix size m - gtint_t m = std::get<4>(GetParam()); // matrix size n + gtint_t n = std::get<4>(GetParam()); + // matrix size k gtint_t k = std::get<5>(GetParam()); // specifies alpha value T alpha = std::get<6>(GetParam()); @@ -79,72 +79,46 @@ TEST_P(dsyr2kTest, RandomData) gtint_t ldc_inc = std::get<10>(GetParam()); // Set the threshold for the errors: - double thresh = m*k*testinghelpers::getEpsilon(); + // Check gtestsuite syr2k.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (n == 0) + thresh = 0.0; + else if ((alpha == testinghelpers::ZERO() || k == 0) && + (beta == testinghelpers::ZERO() || beta == testinghelpers::ONE())) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + thresh = testinghelpers::getEpsilon(); + else + thresh = (6*k+1)*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_syr2k( storage, uplo, transa, transb, m, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh ); + test_syr2k( storage, uplo, transa, transb, n, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh ); } -class dsyr2kTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char uplo = std::get<1>(str.param); - char tsa = std::get<2>(str.param); - char tsb = std::get<3>(str.param); - gtint_t m = std::get<4>(str.param); - gtint_t k = std::get<5>(str.param); - double alpha = std::get<6>(str.param); - double beta = std::get<7>(str.param); - gtint_t lda_inc = std::get<8>(str.param); - gtint_t ldb_inc = std::get<9>(str.param); - gtint_t ldc_inc = std::get<10>(str.param); -#ifdef TEST_BLAS - std::string str_name = "dsyr2k_"; -#elif TEST_CBLAS - std::string str_name = "cblas_dsyr2k"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_dsyr2k"; -#endif - str_name = str_name + "_" + sfm+sfm+sfm; - str_name = str_name + "_" + uplo; - str_name = str_name + "_" + tsa + tsb; - str_name = str_name + "_" + std::to_string(m); - str_name = str_name + "_" + std::to_string(k); - std::string alpha_str = ( alpha > 0) ? std::to_string(int(alpha)) : "m" + std::to_string(int(std::abs(alpha))); - str_name = str_name + "_a" + alpha_str; - std::string beta_str = ( beta > 0) ? std::to_string(int(beta)) : "m" + std::to_string(int(std::abs(beta))); - str_name = str_name + "_b" + beta_str; - str_name = str_name + "_" + std::to_string(lda_inc); - str_name = str_name + "_" + std::to_string(ldb_inc); - str_name = str_name + "_" + std::to_string(ldc_inc); - return str_name; - } -}; - // Black box testing. INSTANTIATE_TEST_SUITE_P( Blackbox, - dsyr2kTest, + dsyr2kGeneric, ::testing::Combine( ::testing::Values('c' -#ifndef TEST_BLAS +#ifndef TEST_BLAS_LIKE ,'r' #endif ), // storage format ::testing::Values('u','l'), // u:upper, l:lower ::testing::Values('n'), // transa ::testing::Values('n'), // transb - ::testing::Range(gtint_t(10), gtint_t(31), 10), // m ::testing::Range(gtint_t(10), gtint_t(31), 10), // n + ::testing::Range(gtint_t(10), gtint_t(31), 10), // k ::testing::Values( 1.0, -2.0), // alpha ::testing::Values(-1.0, 1.0), // beta ::testing::Values(gtint_t(0), gtint_t(4)), // increment to the leading dim of a ::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of b ::testing::Values(gtint_t(0), gtint_t(7)) // increment to the leading dim of c ), - ::dsyr2kTestPrint() + ::syr2kGenericPrint() ); diff --git a/gtestsuite/testsuite/level3/syr2k/ssyr2k_generic.cpp b/gtestsuite/testsuite/level3/syr2k/ssyr2k_generic.cpp index 4b4cc8ccdd..2273dfc913 100644 --- a/gtestsuite/testsuite/level3/syr2k/ssyr2k_generic.cpp +++ b/gtestsuite/testsuite/level3/syr2k/ssyr2k_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,7 +35,7 @@ #include #include "test_syr2k.h" -class ssyr2kTest : +class ssyr2kGeneric : public ::testing::TestWithParam> {}; -TEST_P(ssyr2kTest, RandomData) +TEST_P( ssyr2kGeneric, API ) { using T = float; //---------------------------------------------------------- @@ -63,9 +63,9 @@ TEST_P(ssyr2kTest, RandomData) char transa = std::get<2>(GetParam()); // denotes whether matrix b is n,c,t,h char transb = std::get<3>(GetParam()); - // matrix size m - gtint_t m = std::get<4>(GetParam()); // matrix size n + gtint_t n = std::get<4>(GetParam()); + // matrix size k gtint_t k = std::get<5>(GetParam()); // specifies alpha value T alpha = std::get<6>(GetParam()); @@ -79,72 +79,46 @@ TEST_P(ssyr2kTest, RandomData) gtint_t ldc_inc = std::get<10>(GetParam()); // Set the threshold for the errors: - double thresh = 10*m*k*testinghelpers::getEpsilon(); + // Check gtestsuite syr2k.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (n == 0) + thresh = 0.0; + else if ((alpha == testinghelpers::ZERO() || k == 0) && + (beta == testinghelpers::ZERO() || beta == testinghelpers::ONE())) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + thresh = testinghelpers::getEpsilon(); + else + thresh = (6*k+1)*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_syr2k( storage, uplo, transa, transb, m, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh ); + test_syr2k( storage, uplo, transa, transb, n, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh ); } -class ssyr2kTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char uplo = std::get<1>(str.param); - char tsa = std::get<2>(str.param); - char tsb = std::get<3>(str.param); - gtint_t m = std::get<4>(str.param); - gtint_t k = std::get<5>(str.param); - float alpha = std::get<6>(str.param); - float beta = std::get<7>(str.param); - gtint_t lda_inc = std::get<8>(str.param); - gtint_t ldb_inc = std::get<9>(str.param); - gtint_t ldc_inc = std::get<10>(str.param); -#ifdef TEST_BLAS - std::string str_name = "ssyr2k_"; -#elif TEST_CBLAS - std::string str_name = "cblas_ssyr2k"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_ssyr2k"; -#endif - str_name = str_name + "_" + sfm+sfm+sfm; - str_name = str_name + "_" + uplo; - str_name = str_name + "_" + tsa + tsb; - str_name = str_name + "_" + std::to_string(m); - str_name = str_name + "_" + std::to_string(k); - std::string alpha_str = ( alpha > 0) ? std::to_string(int(alpha)) : "m" + std::to_string(int(std::abs(alpha))); - str_name = str_name + "_a" + alpha_str; - std::string beta_str = ( beta > 0) ? std::to_string(int(beta)) : "m" + std::to_string(int(std::abs(beta))); - str_name = str_name + "_b" + beta_str; - str_name = str_name + "_" + std::to_string(lda_inc); - str_name = str_name + "_" + std::to_string(ldb_inc); - str_name = str_name + "_" + std::to_string(ldc_inc); - return str_name; - } -}; - // Black box testing. INSTANTIATE_TEST_SUITE_P( Blackbox, - ssyr2kTest, + ssyr2kGeneric, ::testing::Combine( ::testing::Values('c' -#ifndef TEST_BLAS +#ifndef TEST_BLAS_LIKE ,'r' #endif ), // storage format ::testing::Values('u','l'), // u:upper, l:lower ::testing::Values('n'), // transa ::testing::Values('n'), // transb - ::testing::Range(gtint_t(10), gtint_t(31), 10), // m ::testing::Range(gtint_t(10), gtint_t(31), 10), // n + ::testing::Range(gtint_t(10), gtint_t(31), 10), // k ::testing::Values( 1.0, -2.0), // alpha ::testing::Values(-1.0, 1.0), // beta ::testing::Values(gtint_t(0), gtint_t(7)), // increment to the leading dim of a ::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of b ::testing::Values(gtint_t(0), gtint_t(4)) // increment to the leading dim of c ), - ::ssyr2kTestPrint() + ::syr2kGenericPrint() ); diff --git a/gtestsuite/testsuite/level3/syr2k/syr2k.h b/gtestsuite/testsuite/level3/syr2k/syr2k.h index 58b59923e5..5f64129197 100644 --- a/gtestsuite/testsuite/level3/syr2k/syr2k.h +++ b/gtestsuite/testsuite/level3/syr2k/syr2k.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -36,6 +36,7 @@ #include "blis.h" #include "common/testing_helpers.h" +#include "inc/check_error.h" /** * @brief Performs the operation: @@ -47,7 +48,7 @@ the matrix multiplication * @param[in] transb specifies the form of op( B ) to be used in the matrix multiplication - * @param[in] m specifies the number of rows and cols of the matrix + * @param[in] n specifies the number of rows and cols of the matrix op( A ) and rows of the matrix C and B * @param[in] k specifies the number of columns of the matrix op( B ) and the number of columns of the matrix C @@ -65,24 +66,40 @@ */ template -static void syr2k_(char uplo, char transa, gtint_t m, gtint_t k, T* alpha, +static void syr2k_(char uplo, char transa, gtint_t n, gtint_t k, T* alpha, T* ap, gtint_t lda, T* bp, gtint_t ldb, T* beta, T* cp, gtint_t ldc ) { if constexpr (std::is_same::value) - ssyr2k_( &uplo, &transa, &m, &k, alpha, ap, &lda, bp, &ldb, beta, cp, &ldc ); + ssyr2k_( &uplo, &transa, &n, &k, alpha, ap, &lda, bp, &ldb, beta, cp, &ldc ); else if constexpr (std::is_same::value) - dsyr2k_( &uplo, &transa, &m, &k, alpha, ap, &lda, bp, &ldb, beta, cp, &ldc ); + dsyr2k_( &uplo, &transa, &n, &k, alpha, ap, &lda, bp, &ldb, beta, cp, &ldc ); else if constexpr (std::is_same::value) - csyr2k_( &uplo, &transa, &m, &k, alpha, ap, &lda, bp, &ldb, beta, cp, &ldc ); + csyr2k_( &uplo, &transa, &n, &k, alpha, ap, &lda, bp, &ldb, beta, cp, &ldc ); else if constexpr (std::is_same::value) - zsyr2k_( &uplo, &transa, &m, &k, alpha, ap, &lda, bp, &ldb, beta, cp, &ldc ); + zsyr2k_( &uplo, &transa, &n, &k, alpha, ap, &lda, bp, &ldb, beta, cp, &ldc ); else throw std::runtime_error("Error in testsuite/level3/syr2k.h: Invalid typename in syr2k_()."); } +template +static void syr2k_blis_impl(char uplo, char transa, gtint_t n, gtint_t k, T* alpha, + T* ap, gtint_t lda, T* bp, gtint_t ldb, T* beta, T* cp, gtint_t ldc ) +{ + if constexpr (std::is_same::value) + ssyr2k_blis_impl( &uplo, &transa, &n, &k, alpha, ap, &lda, bp, &ldb, beta, cp, &ldc ); + else if constexpr (std::is_same::value) + dsyr2k_blis_impl( &uplo, &transa, &n, &k, alpha, ap, &lda, bp, &ldb, beta, cp, &ldc ); + else if constexpr (std::is_same::value) + csyr2k_blis_impl( &uplo, &transa, &n, &k, alpha, ap, &lda, bp, &ldb, beta, cp, &ldc ); + else if constexpr (std::is_same::value) + zsyr2k_blis_impl( &uplo, &transa, &n, &k, alpha, ap, &lda, bp, &ldb, beta, cp, &ldc ); + else + throw std::runtime_error("Error in testsuite/level3/syr2k.h: Invalid typename in syr2k_blis_impl()."); +} + template static void cblas_syr2k(char storage, char uplo, char transa, - gtint_t m, gtint_t k, T* alpha, T* ap, gtint_t lda, + gtint_t n, gtint_t k, T* alpha, T* ap, gtint_t lda, T* bp, gtint_t ldb, T* beta, T* cp, gtint_t ldc) { enum CBLAS_ORDER cblas_order; @@ -94,20 +111,20 @@ static void cblas_syr2k(char storage, char uplo, char transa, testinghelpers::char_to_cblas_trans( transa, &cblas_transa ); if constexpr (std::is_same::value) - cblas_ssyr2k( cblas_order, cblas_uplo, cblas_transa, m, k, *alpha, ap, lda, bp, ldb, *beta, cp, ldc ); + cblas_ssyr2k( cblas_order, cblas_uplo, cblas_transa, n, k, *alpha, ap, lda, bp, ldb, *beta, cp, ldc ); else if constexpr (std::is_same::value) - cblas_dsyr2k( cblas_order, cblas_uplo, cblas_transa, m, k, *alpha, ap, lda, bp, ldb, *beta, cp, ldc ); + cblas_dsyr2k( cblas_order, cblas_uplo, cblas_transa, n, k, *alpha, ap, lda, bp, ldb, *beta, cp, ldc ); else if constexpr (std::is_same::value) - cblas_csyr2k( cblas_order, cblas_uplo, cblas_transa, m, k, alpha, ap, lda, bp, ldb, beta, cp, ldc ); + cblas_csyr2k( cblas_order, cblas_uplo, cblas_transa, n, k, alpha, ap, lda, bp, ldb, beta, cp, ldc ); else if constexpr (std::is_same::value) - cblas_zsyr2k( cblas_order, cblas_uplo, cblas_transa, m, k, alpha, ap, lda, bp, ldb, beta, cp, ldc ); + cblas_zsyr2k( cblas_order, cblas_uplo, cblas_transa, n, k, alpha, ap, lda, bp, ldb, beta, cp, ldc ); else throw std::runtime_error("Error in testsuite/level3/syr2k.h: Invalid typename in cblas_syr2k()."); } template static void typed_syr2k(char storage, char uplo, char trnsa, char trnsb, - gtint_t m, gtint_t k, T* alpha, T* ap, gtint_t lda, + gtint_t n, gtint_t k, T* alpha, T* ap, gtint_t lda, T* bp, gtint_t ldb, T* beta, T* cp, gtint_t ldc) { trans_t transa, transb; @@ -122,7 +139,7 @@ static void typed_syr2k(char storage, char uplo, char trnsa, char trnsb, rsa=rsb=rsc=1; csa=csb=csc=1; - /* a = m x k b = k x n c = m x n */ + /* a = n x k b = k x n c = n x n */ if( (storage == 'c') || (storage == 'C') ) { csa = lda ; csb = ldb ; @@ -135,32 +152,113 @@ static void typed_syr2k(char storage, char uplo, char trnsa, char trnsb, } if constexpr (std::is_same::value) - bli_ssyr2k( blis_uplo, transa, transb, m, k, alpha, ap, rsa, csa, bp, rsb, csb, beta, cp, rsc, csc ); + bli_ssyr2k( blis_uplo, transa, transb, n, k, alpha, ap, rsa, csa, bp, rsb, csb, beta, cp, rsc, csc ); else if constexpr (std::is_same::value) - bli_dsyr2k( blis_uplo, transa, transb, m, k, alpha, ap, rsa, csa, bp, rsb, csb, beta, cp, rsc, csc ); + bli_dsyr2k( blis_uplo, transa, transb, n, k, alpha, ap, rsa, csa, bp, rsb, csb, beta, cp, rsc, csc ); else if constexpr (std::is_same::value) - bli_csyr2k( blis_uplo, transa, transb, m, k, alpha, ap, rsa, csa, bp, rsb, csb, beta, cp, rsc, csc ); + bli_csyr2k( blis_uplo, transa, transb, n, k, alpha, ap, rsa, csa, bp, rsb, csb, beta, cp, rsc, csc ); else if constexpr (std::is_same::value) - bli_zsyr2k( blis_uplo, transa, transb, m, k, alpha, ap, rsa, csa, bp, rsb, csb, beta, cp, rsc, csc ); + bli_zsyr2k( blis_uplo, transa, transb, n, k, alpha, ap, rsa, csa, bp, rsb, csb, beta, cp, rsc, csc ); else throw std::runtime_error("Error in testsuite/level3/syr2k.h: Invalid typename in typed_syr2k()."); } template -static void syr2k( char storage, char uplo, char transa, char transb, gtint_t m, gtint_t k, +static void syr2k( char storage, char uplo, char transa, char transb, gtint_t n, gtint_t k, T* alpha, T* ap, gtint_t lda, T* bp, gtint_t ldb, T* beta, T* cp, gtint_t ldc ) { + +#ifdef TEST_UPPERCASE_ARGS + storage = static_cast(std::toupper(static_cast(storage))); + uplo = static_cast(std::toupper(static_cast(uplo))); + transa = static_cast(std::toupper(static_cast(transa))); + transb = static_cast(std::toupper(static_cast(transb))); +#endif + +#ifdef TEST_INPUT_ARGS + // Create copy of scalar input values so we can check that they are not altered. + char storage_cpy = storage; + char uplo_cpy = uplo; + char transa_cpy = transa; + gtint_t n_cpy = n; + gtint_t k_cpy = k; + T* alpha_cpy = alpha; + gtint_t lda_cpy = lda; + gtint_t ldb_cpy = ldb; + T* beta_cpy = beta; + gtint_t ldc_cpy = ldc; + + // Create copy of input arrays so we can check that they are not altered. + T* ap_cpy = nullptr; + gtint_t size_ap = testinghelpers::matsize( storage, transa, n, k, lda ); + if (ap && size_ap > 0) + { + ap_cpy = new T[size_ap]; + memcpy( ap_cpy, ap, size_ap * sizeof( T ) ); + } + T* bp_cpy = nullptr; + gtint_t size_bp = testinghelpers::matsize( storage, transb, n, k, ldb ); + if (bp && size_bp > 0) + { + bp_cpy = new T[size_bp]; + memcpy( bp_cpy, bp, size_bp * sizeof( T ) ); + } +#endif + #ifdef TEST_BLAS if( storage == 'c' || storage == 'C' ) - syr2k_( uplo, transa, m, k, alpha, ap, lda, bp, ldb, beta, cp, ldc ); + syr2k_( uplo, transa, n, k, alpha, ap, lda, bp, ldb, beta, cp, ldc ); else throw std::runtime_error("Error in testsuite/level3/syr2k.h: BLAS interface cannot be tested for row-major order."); - +#elif TEST_BLAS_BLIS_IMPL + if( storage == 'c' || storage == 'C' ) + syr2k_blis_impl( uplo, transa, n, k, alpha, ap, lda, bp, ldb, beta, cp, ldc ); + else + throw std::runtime_error("Error in testsuite/level3/syr2k.h: BLAS_BLIS_IMPL interface cannot be tested for row-major order."); #elif TEST_CBLAS - cblas_syr2k( storage, uplo, transa, m, k, alpha, ap, lda, bp, ldb, beta, cp, ldc ); + cblas_syr2k( storage, uplo, transa, n, k, alpha, ap, lda, bp, ldb, beta, cp, ldc ); #elif TEST_BLIS_TYPED - typed_syr2k( storage, uplo, transa, transb, m, k, alpha, ap, lda, bp, ldb, beta, cp, ldc ); + typed_syr2k( storage, uplo, transa, transb, n, k, alpha, ap, lda, bp, ldb, beta, cp, ldc ); #else throw std::runtime_error("Error in testsuite/level3/syr2k.h: No interfaces are set to be tested."); #endif + +#ifdef TEST_INPUT_ARGS + //---------------------------------------------------------- + // Check scalar inputs have not been modified. + //---------------------------------------------------------- + + computediff( "storage", storage, storage_cpy ); + computediff( "uplo", uplo, uplo_cpy ); + computediff( "transa", transa, transa_cpy ); + computediff( "n", n, n_cpy ); + computediff( "k", k, k_cpy ); + if (alpha) computediff( "alpha", *alpha, *alpha_cpy ); + computediff( "lda", lda, lda_cpy ); + computediff( "ldb", ldb, ldb_cpy ); + if (beta) computediff( "beta", *beta, *beta_cpy ); + computediff( "ldc", ldc, ldc_cpy ); + + //---------------------------------------------------------- + // Bitwise-wise check array inputs have not been modified. + //---------------------------------------------------------- + + if (ap && size_ap > 0) + { + if(( transa == 'n' ) || ( transa == 'N' )) + computediff( "A", storage, n, k, ap, ap_cpy, lda, true ); + else + computediff( "A", storage, k, n, ap, ap_cpy, lda, true ); + delete[] ap_cpy; + } + + if (bp && size_bp > 0) + { + if(( transb == 'n' ) || ( transb == 'N' )) + computediff( "B", storage, n, k, bp, bp_cpy, ldb, true ); + else + computediff( "B", storage, k, n, bp, bp_cpy, ldb, true ); + delete[] bp_cpy; + } +#endif } diff --git a/gtestsuite/testsuite/level3/syr2k/test_syr2k.h b/gtestsuite/testsuite/level3/syr2k/test_syr2k.h index da2dabb0a9..01d2334b3b 100644 --- a/gtestsuite/testsuite/level3/syr2k/test_syr2k.h +++ b/gtestsuite/testsuite/level3/syr2k/test_syr2k.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -41,24 +41,31 @@ #include template -void test_syr2k( char storage, char uplo, char transa, char transb, gtint_t m, +void test_syr2k( char storage, char uplo, char transa, char transb, gtint_t n, gtint_t k, gtint_t lda_inc, gtint_t ldb_inc, gtint_t ldc_inc, T alpha, T beta, double thresh ) { // Compute the leading dimensions of a, b, and c. - gtint_t lda = testinghelpers::get_leading_dimension( storage, transa, m, k, lda_inc ); - gtint_t ldb = testinghelpers::get_leading_dimension( storage, transb, m, k, ldb_inc ); - gtint_t ldc = testinghelpers::get_leading_dimension( storage, 'n', m, m, ldc_inc ); + gtint_t lda = testinghelpers::get_leading_dimension( storage, transa, n, k, lda_inc ); + gtint_t ldb = testinghelpers::get_leading_dimension( storage, transb, n, k, ldb_inc ); + gtint_t ldc = testinghelpers::get_leading_dimension( storage, 'n', n, n, ldc_inc ); //---------------------------------------------------------- // Initialize matrics with random integer numbers. //---------------------------------------------------------- - std::vector a = testinghelpers::get_random_matrix( -2, 8, storage, transa, m, k, lda ); - std::vector b = testinghelpers::get_random_matrix( -5, 2, storage, transb, m, k, ldb ); - // Since matrix C, stored in c, is symmetric and we only use the upper or lower - // part in the computation of her2k and zero-out the rest to ensure - // that code operates as expected. - std::vector c = testinghelpers::get_random_matrix(-3, 5, storage, uplo, m, ldc ); + std::vector a = testinghelpers::get_random_matrix( -2, 8, storage, transa, n, k, lda ); + std::vector b = testinghelpers::get_random_matrix( -5, 2, storage, transb, n, k, ldb ); + std::vector c( testinghelpers::matsize( storage, 'n', n, n, ldc ) ); + if (beta != testinghelpers::ZERO()) + // Since matrix C, stored in c, is symmetric and we only use the upper or lower + // part in the computation of her2k and zero-out the rest to ensure + // that code operates as expected. + testinghelpers::datagenerators::randomgenerators( -3, 5, storage, uplo, n, c.data(), ldc ); + else + { + // Matrix C should not be read, only set. + testinghelpers::set_matrix( storage, n, c.data(), uplo, ldc, testinghelpers::aocl_extreme() ); + } // Create a copy of c so that we can check reference results. std::vector c_ref(c); @@ -66,17 +73,59 @@ void test_syr2k( char storage, char uplo, char transa, char transb, gtint_t m, //---------------------------------------------------------- // Call BLIS function //---------------------------------------------------------- - syr2k( storage, uplo, transa, transb, m, k, &alpha, a.data(), lda, + syr2k( storage, uplo, transa, transb, n, k, &alpha, a.data(), lda, b.data(), ldb, &beta, c.data(), ldc ); //---------------------------------------------------------- // Call reference implementation. //---------------------------------------------------------- - testinghelpers::ref_syr2k( storage, uplo, transa, transb, m, k, alpha, + testinghelpers::ref_syr2k( storage, uplo, transa, transb, n, k, alpha, a.data(), lda, b.data(), ldb, beta, c_ref.data(), ldc ); //---------------------------------------------------------- // check component-wise error. //---------------------------------------------------------- - computediff( storage, m, m, c.data(), c_ref.data(), ldc, thresh ); + computediff( "C", storage, n, n, c.data(), c_ref.data(), ldc, thresh ); + +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif } + +// Test-case logger : Used to print the test-case details based on parameters +template +class syr2kGenericPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char storage = std::get<0>(str.param); + char uplo = std::get<1>(str.param); + char transa = std::get<2>(str.param); + char transb = std::get<3>(str.param); + gtint_t n = std::get<4>(str.param); + gtint_t k = std::get<5>(str.param); + T alpha = std::get<6>(str.param); + T beta = std::get<7>(str.param); + gtint_t lda_inc = std::get<8>(str.param); + gtint_t ldb_inc = std::get<9>(str.param); + gtint_t ldc_inc = std::get<10>(str.param); + + std::string str_name = API_PRINT; + str_name += "_stor_" + std::string(&storage, 1); + str_name += "_uplo_" + std::string(&uplo, 1); + str_name += "_transa_" + std::string(&transa, 1); + str_name += "_transb_" + std::string(&transb, 1); + str_name += "_n_" + std::to_string(n); + str_name += "_k_" + std::to_string(k); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + str_name += "_beta_" + testinghelpers::get_value_string(beta); + gtint_t lda = testinghelpers::get_leading_dimension( storage, transa, n, k, lda_inc ); + gtint_t ldb = testinghelpers::get_leading_dimension( storage, transb, n, k, ldb_inc ); + gtint_t ldc = testinghelpers::get_leading_dimension( storage, 'n', n, n, ldc_inc ); + str_name += "_lda_i" + std::to_string(lda_inc) + "_" + std::to_string(lda); + str_name += "_ldb_i" + std::to_string(ldb_inc) + "_" + std::to_string(ldb); + str_name += "_ldc_i" + std::to_string(ldc_inc) + "_" + std::to_string(ldc); + return str_name; + } +}; diff --git a/gtestsuite/testsuite/level3/syr2k/zsyr2k_generic.cpp b/gtestsuite/testsuite/level3/syr2k/zsyr2k_generic.cpp index 3600872367..0066895e56 100644 --- a/gtestsuite/testsuite/level3/syr2k/zsyr2k_generic.cpp +++ b/gtestsuite/testsuite/level3/syr2k/zsyr2k_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,7 +35,7 @@ #include #include "test_syr2k.h" -class zsyr2kTest : +class zsyr2kGeneric : public ::testing::TestWithParam> {}; -TEST_P(zsyr2kTest, RandomData) +TEST_P( zsyr2kGeneric, API ) { using T = dcomplex; //---------------------------------------------------------- @@ -63,9 +63,9 @@ TEST_P(zsyr2kTest, RandomData) char transa = std::get<2>(GetParam()); // denotes whether matrix b is n,c,t,h char transb = std::get<3>(GetParam()); - // matrix size m - gtint_t m = std::get<4>(GetParam()); // matrix size n + gtint_t n = std::get<4>(GetParam()); + // matrix size k gtint_t k = std::get<5>(GetParam()); // specifies alpha value T alpha = std::get<6>(GetParam()); @@ -79,74 +79,47 @@ TEST_P(zsyr2kTest, RandomData) gtint_t ldc_inc = std::get<10>(GetParam()); // Set the threshold for the errors: - double thresh = m*k*testinghelpers::getEpsilon(); + // Check gtestsuite syr2k.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (n == 0) + thresh = 0.0; + else if ((alpha == testinghelpers::ZERO() || k == 0) && + (beta == testinghelpers::ZERO() || beta == testinghelpers::ONE())) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + thresh = testinghelpers::getEpsilon(); + else + thresh = (6*k+1)*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_syr2k( storage, uplo, transa, transb, m, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh ); + test_syr2k( storage, uplo, transa, transb, n, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh ); } -class zsyr2kTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char uplo = std::get<1>(str.param); - char tsa = std::get<2>(str.param); - char tsb = std::get<3>(str.param); - gtint_t m = std::get<4>(str.param); - gtint_t k = std::get<5>(str.param); - dcomplex alpha = std::get<6>(str.param); - dcomplex beta = std::get<7>(str.param); - gtint_t lda_inc = std::get<8>(str.param); - gtint_t ldb_inc = std::get<9>(str.param); - gtint_t ldc_inc = std::get<10>(str.param); -#ifdef TEST_BLAS - std::string str_name = "zsyr2k_"; -#elif TEST_CBLAS - std::string str_name = "cblas_zsyr2k"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_zsyr2k"; -#endif - str_name = str_name + "_" + sfm+sfm+sfm; - str_name = str_name + "_" + uplo; - str_name = str_name + "_" + tsa + tsb; - str_name = str_name + "_" + std::to_string(m); - str_name = str_name + "_" + std::to_string(k); - std::string alpha_str = ( alpha.real > 0) ? std::to_string(int(alpha.real)) : ("m" + std::to_string(int(std::abs(alpha.real)))); - alpha_str = alpha_str + "pi" + (( alpha.imag > 0) ? std::to_string(int(alpha.imag)) : ("m" + std::to_string(int(std::abs(alpha.imag))))); - str_name = str_name + "_a" + alpha_str; - std::string beta_str = ( beta.real > 0) ? std::to_string(int(beta.real)) : ("m" + std::to_string(int(std::abs(beta.real)))); - beta_str = beta_str + "pi" + (( beta.imag > 0) ? std::to_string(int(beta.imag)) : ("m" + std::to_string(int(std::abs(beta.imag))))); - str_name = str_name + "_a" + beta_str; - str_name = str_name + "_" + std::to_string(lda_inc); - str_name = str_name + "_" + std::to_string(ldb_inc); - str_name = str_name + "_" + std::to_string(ldc_inc); - return str_name; - } -}; - // Black box testing. INSTANTIATE_TEST_SUITE_P( Blackbox, - zsyr2kTest, + zsyr2kGeneric, ::testing::Combine( ::testing::Values('c' -#ifndef TEST_BLAS +#ifndef TEST_BLAS_LIKE ,'r' #endif ), // storage format ::testing::Values('u','l'), // u:upper, l:lower ::testing::Values('n'), // transa ::testing::Values('n'), // transb - ::testing::Range(gtint_t(10), gtint_t(31), 10), // m ::testing::Range(gtint_t(10), gtint_t(31), 10), // n + ::testing::Range(gtint_t(10), gtint_t(31), 10), // k ::testing::Values(dcomplex{2.0, -1.0}, dcomplex{-2.0, 3.0}), // alpha ::testing::Values(dcomplex{-3.0, 2.0}, dcomplex{4.0, -1.0}), // beta ::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of a ::testing::Values(gtint_t(0), gtint_t(5)), // increment to the leading dim of b ::testing::Values(gtint_t(0), gtint_t(6)) // increment to the leading dim of c ), - ::zsyr2kTestPrint() + ::syr2kGenericPrint() ); diff --git a/gtestsuite/testsuite/level3/syrk/csyrk_generic.cpp b/gtestsuite/testsuite/level3/syrk/csyrk_generic.cpp index c876843931..479fad0f1a 100644 --- a/gtestsuite/testsuite/level3/syrk/csyrk_generic.cpp +++ b/gtestsuite/testsuite/level3/syrk/csyrk_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,7 +35,7 @@ #include #include "test_syrk.h" -class csyrkTest : +class csyrkGeneric : public ::testing::TestWithParam> {}; -TEST_P(csyrkTest, RandomData) +TEST_P( csyrkGeneric, API ) { using T = scomplex; //---------------------------------------------------------- @@ -59,8 +59,8 @@ TEST_P(csyrkTest, RandomData) char uplo = std::get<1>(GetParam()); // denotes whether matrix a is n,c,t,h char transa = std::get<2>(GetParam()); - // matrix size m - gtint_t m = std::get<3>(GetParam()); + // matrix size n + gtint_t n = std::get<3>(GetParam()); // matrix size k gtint_t k = std::get<4>(GetParam()); // specifies alpha value @@ -74,69 +74,45 @@ TEST_P(csyrkTest, RandomData) gtint_t ldc_inc = std::get<8>(GetParam()); // Set the threshold for the errors: - double thresh = m*k*testinghelpers::getEpsilon(); + // Check gtestsuite syrk.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (n == 0) + thresh = 0.0; + else if ((alpha == testinghelpers::ZERO() || k == 0) && + (beta == testinghelpers::ZERO() || beta == testinghelpers::ONE())) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + thresh = testinghelpers::getEpsilon(); + else + thresh = (3*k+1)*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_syrk( storage, uplo, transa, m, k, lda_inc, ldc_inc, alpha, beta, thresh ); + test_syrk( storage, uplo, transa, n, k, lda_inc, ldc_inc, alpha, beta, thresh ); } -class csyrkTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char uplo = std::get<1>(str.param); - char tsa = std::get<2>(str.param); - gtint_t m = std::get<3>(str.param); - gtint_t k = std::get<4>(str.param); - scomplex alpha = std::get<5>(str.param); - scomplex beta = std::get<6>(str.param); - gtint_t lda_inc = std::get<7>(str.param); - gtint_t ldc_inc = std::get<8>(str.param); -#ifdef TEST_BLAS - std::string str_name = "csyrk_"; -#elif TEST_CBLAS - std::string str_name = "cblas_csyrk"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_csyrk"; -#endif - str_name = str_name + "_" + sfm+sfm+sfm; - str_name = str_name + "_" + uplo; - str_name = str_name + "_" + tsa; - str_name = str_name + "_" + std::to_string(m); - str_name = str_name + "_" + std::to_string(k); - std::string alpha_str = ( alpha.real > 0) ? std::to_string(int(alpha.real)) : ("m" + std::to_string(int(std::abs(alpha.real)))); - alpha_str = alpha_str + "pi" + (( alpha.imag > 0) ? std::to_string(int(alpha.imag)) : ("m" + std::to_string(int(std::abs(alpha.imag))))); - str_name = str_name + "_a" + alpha_str; - std::string beta_str = ( beta.real > 0) ? std::to_string(int(beta.real)) : ("m" + std::to_string(int(std::abs(beta.real)))); - beta_str = beta_str + "pi" + (( beta.imag > 0) ? std::to_string(int(beta.imag)) : ("m" + std::to_string(int(std::abs(beta.imag))))); - str_name = str_name + "_a" + beta_str; - str_name = str_name + "_" + std::to_string(lda_inc); - str_name = str_name + "_" + std::to_string(ldc_inc); - return str_name; - } -}; - // Black box testing. INSTANTIATE_TEST_SUITE_P( Blackbox, - csyrkTest, + csyrkGeneric, ::testing::Combine( ::testing::Values('c' -#ifndef TEST_BLAS +#ifndef TEST_BLAS_LIKE ,'r' #endif ), // storage format ::testing::Values('u','l'), // u:upper, l:lower ::testing::Values('n','t'), // transa - ::testing::Range(gtint_t(10), gtint_t(31), 10), // m + ::testing::Range(gtint_t(10), gtint_t(31), 10), // n ::testing::Range(gtint_t(10), gtint_t(31), 10), // k ::testing::Values(scomplex{2.0, -1.0}, scomplex{-2.0, 3.0}), // alpha ::testing::Values(scomplex{-3.0, 2.0}, scomplex{4.0, -1.0}), // beta ::testing::Values(gtint_t(0), gtint_t(3)), // increment to the leading dim of a ::testing::Values(gtint_t(0), gtint_t(2)) // increment to the leading dim of c ), - ::csyrkTestPrint() + ::syrkGenericPrint() ); diff --git a/gtestsuite/testsuite/level3/syrk/dsyrk_generic.cpp b/gtestsuite/testsuite/level3/syrk/dsyrk_generic.cpp index 05f1dc0229..5e62c7f2b2 100644 --- a/gtestsuite/testsuite/level3/syrk/dsyrk_generic.cpp +++ b/gtestsuite/testsuite/level3/syrk/dsyrk_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,7 +35,7 @@ #include #include "test_syrk.h" -class dsyrkTest : +class dsyrkGeneric : public ::testing::TestWithParam> {}; -TEST_P(dsyrkTest, RandomData) +TEST_P( dsyrkGeneric, API ) { using T = double; //---------------------------------------------------------- @@ -59,8 +59,8 @@ TEST_P(dsyrkTest, RandomData) char uplo = std::get<1>(GetParam()); // denotes whether matrix a is n,c,t,h char transa = std::get<2>(GetParam()); - // matrix size m - gtint_t m = std::get<3>(GetParam()); + // matrix size n + gtint_t n = std::get<3>(GetParam()); // matrix size k gtint_t k = std::get<4>(GetParam()); // specifies alpha value @@ -74,67 +74,44 @@ TEST_P(dsyrkTest, RandomData) gtint_t ldc_inc = std::get<8>(GetParam()); // Set the threshold for the errors: - double thresh = m*k*testinghelpers::getEpsilon(); + // Check gtestsuite syrk.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (n == 0) + thresh = 0.0; + else if ((alpha == testinghelpers::ZERO() || k == 0) && + (beta == testinghelpers::ZERO() || beta == testinghelpers::ONE())) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + thresh = testinghelpers::getEpsilon(); + else + thresh = (3*k+1)*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_syrk( storage, uplo, transa, m, k, lda_inc, ldc_inc, alpha, beta, thresh ); + test_syrk( storage, uplo, transa, n, k, lda_inc, ldc_inc, alpha, beta, thresh ); } -class dsyrkTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char uplo = std::get<1>(str.param); - char tsa = std::get<2>(str.param); - gtint_t m = std::get<3>(str.param); - gtint_t k = std::get<4>(str.param); - double alpha = std::get<5>(str.param); - double beta = std::get<6>(str.param); - gtint_t lda_inc = std::get<7>(str.param); - gtint_t ldc_inc = std::get<8>(str.param); -#ifdef TEST_BLAS - std::string str_name = "dsyrk_"; -#elif TEST_CBLAS - std::string str_name = "cblas_dsyrk"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_dsyrk"; -#endif - str_name = str_name + "_" + sfm+sfm+sfm; - str_name = str_name + "_" + uplo; - str_name = str_name + "_" + tsa; - str_name = str_name + "_" + std::to_string(m); - str_name = str_name + "_" + std::to_string(k); - std::string alpha_str = ( alpha > 0) ? std::to_string(int(alpha)) : "m" + std::to_string(int(std::abs(alpha))); - str_name = str_name + "_a" + alpha_str; - std::string beta_str = ( beta > 0) ? std::to_string(int(beta)) : "m" + std::to_string(int(std::abs(beta))); - str_name = str_name + "_b" + beta_str; - str_name = str_name + "_" + std::to_string(lda_inc); - str_name = str_name + "_" + std::to_string(ldc_inc); - return str_name; - } -}; - // Black box testing. INSTANTIATE_TEST_SUITE_P( Blackbox, - dsyrkTest, + dsyrkGeneric, ::testing::Combine( ::testing::Values('c' -#ifndef TEST_BLAS +#ifndef TEST_BLAS_LIKE ,'r' #endif ), // storage format ::testing::Values('u','l'), // u:upper, l:lower ::testing::Values('n','t','c'), // transa - ::testing::Range(gtint_t(10), gtint_t(31), 10), // m + ::testing::Range(gtint_t(10), gtint_t(31), 10), // n ::testing::Range(gtint_t(10), gtint_t(31), 10), // k ::testing::Values( 1.0, -2.0), // alpha ::testing::Values(-1.0, 1.0), // beta ::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of a ::testing::Values(gtint_t(0), gtint_t(9)) // increment to the leading dim of c ), - ::dsyrkTestPrint() + ::syrkGenericPrint() ); diff --git a/gtestsuite/testsuite/level3/syrk/ssyrk_generic.cpp b/gtestsuite/testsuite/level3/syrk/ssyrk_generic.cpp index 6ce9ab89bf..5e202521e7 100644 --- a/gtestsuite/testsuite/level3/syrk/ssyrk_generic.cpp +++ b/gtestsuite/testsuite/level3/syrk/ssyrk_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,7 +35,7 @@ #include #include "test_syrk.h" -class ssyrkTest : +class ssyrkGeneric : public ::testing::TestWithParam> {}; -TEST_P(ssyrkTest, RandomData) +TEST_P( ssyrkGeneric, API ) { using T = float; //---------------------------------------------------------- @@ -59,8 +59,8 @@ TEST_P(ssyrkTest, RandomData) char uplo = std::get<1>(GetParam()); // denotes whether matrix a is n,c,t,h char transa = std::get<2>(GetParam()); - // matrix size m - gtint_t m = std::get<3>(GetParam()); + // matrix size n + gtint_t n = std::get<3>(GetParam()); // matrix size k gtint_t k = std::get<4>(GetParam()); // specifies alpha value @@ -74,67 +74,44 @@ TEST_P(ssyrkTest, RandomData) gtint_t ldc_inc = std::get<8>(GetParam()); // Set the threshold for the errors: - double thresh = m*k*testinghelpers::getEpsilon(); + // Check gtestsuite syrk.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (n == 0) + thresh = 0.0; + else if ((alpha == testinghelpers::ZERO() || k == 0) && + (beta == testinghelpers::ZERO() || beta == testinghelpers::ONE())) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + thresh = testinghelpers::getEpsilon(); + else + thresh = (3*k+1)*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_syrk( storage, uplo, transa, m, k, lda_inc, ldc_inc, alpha, beta, thresh ); + test_syrk( storage, uplo, transa, n, k, lda_inc, ldc_inc, alpha, beta, thresh ); } -class ssyrkTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char uplo = std::get<1>(str.param); - char tsa = std::get<2>(str.param); - gtint_t m = std::get<3>(str.param); - gtint_t k = std::get<4>(str.param); - float alpha = std::get<5>(str.param); - float beta = std::get<6>(str.param); - gtint_t lda_inc = std::get<7>(str.param); - gtint_t ldc_inc = std::get<8>(str.param); -#ifdef TEST_BLAS - std::string str_name = "ssyrk_"; -#elif TEST_CBLAS - std::string str_name = "cblas_ssyrk"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_ssyrk"; -#endif - str_name = str_name + "_" + sfm+sfm+sfm; - str_name = str_name + "_" + uplo; - str_name = str_name + "_" + tsa; - str_name = str_name + "_" + std::to_string(m); - str_name = str_name + "_" + std::to_string(k); - std::string alpha_str = ( alpha > 0) ? std::to_string(int(alpha)) : "m" + std::to_string(int(std::abs(alpha))); - str_name = str_name + "_a" + alpha_str; - std::string beta_str = ( beta > 0) ? std::to_string(int(beta)) : "m" + std::to_string(int(std::abs(beta))); - str_name = str_name + "_b" + beta_str; - str_name = str_name + "_" + std::to_string(lda_inc); - str_name = str_name + "_" + std::to_string(ldc_inc); - return str_name; - } -}; - // Black box testing. INSTANTIATE_TEST_SUITE_P( Blackbox, - ssyrkTest, + ssyrkGeneric, ::testing::Combine( ::testing::Values('c' -#ifndef TEST_BLAS +#ifndef TEST_BLAS_LIKE ,'r' #endif ), // storage format ::testing::Values('u','l'), // u:upper, l:lower ::testing::Values('n','t','c'), // transa - ::testing::Range(gtint_t(10), gtint_t(31), 10), // m + ::testing::Range(gtint_t(10), gtint_t(31), 10), // n ::testing::Range(gtint_t(10), gtint_t(31), 10), // k ::testing::Values( 1.0, -2.0), // alpha ::testing::Values(-1.0, 1.0), // beta ::testing::Values(gtint_t(0), gtint_t(3)), // increment to the leading dim of a ::testing::Values(gtint_t(0), gtint_t(1)) // increment to the leading dim of c ), - ::ssyrkTestPrint() + ::syrkGenericPrint() ); diff --git a/gtestsuite/testsuite/level3/syrk/syrk.h b/gtestsuite/testsuite/level3/syrk/syrk.h index ecbea4725e..bcf70e05f5 100644 --- a/gtestsuite/testsuite/level3/syrk/syrk.h +++ b/gtestsuite/testsuite/level3/syrk/syrk.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -36,6 +36,7 @@ #include "blis.h" #include "common/testing_helpers.h" +#include "inc/check_error.h" /** * @brief Performs the operation: @@ -60,24 +61,40 @@ */ template -static void syrk_(char uplo, char transa, gtint_t m, gtint_t k, T* alpha, +static void syrk_(char uplo, char transa, gtint_t n, gtint_t k, T* alpha, T* ap, gtint_t lda, T* beta, T* cp, gtint_t ldc ) { if constexpr (std::is_same::value) - ssyrk_( &uplo, &transa, &m, &k, alpha, ap, &lda, beta, cp, &ldc ); + ssyrk_( &uplo, &transa, &n, &k, alpha, ap, &lda, beta, cp, &ldc ); else if constexpr (std::is_same::value) - dsyrk_( &uplo, &transa, &m, &k, alpha, ap, &lda, beta, cp, &ldc ); + dsyrk_( &uplo, &transa, &n, &k, alpha, ap, &lda, beta, cp, &ldc ); else if constexpr (std::is_same::value) - csyrk_( &uplo, &transa, &m, &k, alpha, ap, &lda, beta, cp, &ldc ); + csyrk_( &uplo, &transa, &n, &k, alpha, ap, &lda, beta, cp, &ldc ); else if constexpr (std::is_same::value) - zsyrk_( &uplo, &transa, &m, &k, alpha, ap, &lda, beta, cp, &ldc ); + zsyrk_( &uplo, &transa, &n, &k, alpha, ap, &lda, beta, cp, &ldc ); else throw std::runtime_error("Error in testsuite/level3/syrk.h: Invalid typename in syrk_()."); } +template +static void syrk_blis_impl(char uplo, char transa, gtint_t n, gtint_t k, T* alpha, + T* ap, gtint_t lda, T* beta, T* cp, gtint_t ldc ) +{ + if constexpr (std::is_same::value) + ssyrk_blis_impl( &uplo, &transa, &n, &k, alpha, ap, &lda, beta, cp, &ldc ); + else if constexpr (std::is_same::value) + dsyrk_blis_impl( &uplo, &transa, &n, &k, alpha, ap, &lda, beta, cp, &ldc ); + else if constexpr (std::is_same::value) + csyrk_blis_impl( &uplo, &transa, &n, &k, alpha, ap, &lda, beta, cp, &ldc ); + else if constexpr (std::is_same::value) + zsyrk_blis_impl( &uplo, &transa, &n, &k, alpha, ap, &lda, beta, cp, &ldc ); + else + throw std::runtime_error("Error in testsuite/level3/syrk.h: Invalid typename in syrk_blis_impl()."); +} + template static void cblas_syrk(char storage, char uplo, char trnsa, - gtint_t m, gtint_t k, T* alpha, T* ap, gtint_t lda, + gtint_t n, gtint_t k, T* alpha, T* ap, gtint_t lda, T* beta, T* cp, gtint_t ldc) { enum CBLAS_ORDER cblas_order; @@ -89,20 +106,20 @@ static void cblas_syrk(char storage, char uplo, char trnsa, testinghelpers::char_to_cblas_trans( trnsa, &cblas_transa ); if constexpr (std::is_same::value) - cblas_ssyrk( cblas_order, cblas_uplo, cblas_transa, m, k, *alpha, ap, lda, *beta, cp, ldc ); + cblas_ssyrk( cblas_order, cblas_uplo, cblas_transa, n, k, *alpha, ap, lda, *beta, cp, ldc ); else if constexpr (std::is_same::value) - cblas_dsyrk( cblas_order, cblas_uplo, cblas_transa, m, k, *alpha, ap, lda, *beta, cp, ldc ); + cblas_dsyrk( cblas_order, cblas_uplo, cblas_transa, n, k, *alpha, ap, lda, *beta, cp, ldc ); else if constexpr (std::is_same::value) - cblas_csyrk( cblas_order, cblas_uplo, cblas_transa, m, k, alpha, ap, lda, beta, cp, ldc ); + cblas_csyrk( cblas_order, cblas_uplo, cblas_transa, n, k, alpha, ap, lda, beta, cp, ldc ); else if constexpr (std::is_same::value) - cblas_zsyrk( cblas_order, cblas_uplo, cblas_transa, m, k, alpha, ap, lda, beta, cp, ldc ); + cblas_zsyrk( cblas_order, cblas_uplo, cblas_transa, n, k, alpha, ap, lda, beta, cp, ldc ); else throw std::runtime_error("Error in testsuite/level3/syrk.h: Invalid typename in cblas_syrk()."); } template static void typed_syrk(char storage, char uplo, char trnsa, - gtint_t m, gtint_t k, T* alpha, T* ap, gtint_t lda, + gtint_t n, gtint_t k, T* alpha, T* ap, gtint_t lda, T* beta, T* cp, gtint_t ldc) { trans_t transa; @@ -115,7 +132,7 @@ static void typed_syrk(char storage, char uplo, char trnsa, rsa=rsc=1; csa=csc=1; - /* a = m x k c = m x m */ + /* a = n x k c = n x n */ if( (storage == 'c') || (storage == 'C') ) { csa = lda ; csc = ldc ; @@ -126,31 +143,94 @@ static void typed_syrk(char storage, char uplo, char trnsa, } if constexpr (std::is_same::value) - bli_ssyrk( blis_uplo, transa, m, k, alpha, ap, rsa, csa, beta, cp, rsc, csc ); + bli_ssyrk( blis_uplo, transa, n, k, alpha, ap, rsa, csa, beta, cp, rsc, csc ); else if constexpr (std::is_same::value) - bli_dsyrk( blis_uplo, transa, m, k, alpha, ap, rsa, csa, beta, cp, rsc, csc ); + bli_dsyrk( blis_uplo, transa, n, k, alpha, ap, rsa, csa, beta, cp, rsc, csc ); else if constexpr (std::is_same::value) - bli_csyrk( blis_uplo, transa, m, k, alpha, ap, rsa, csa, beta, cp, rsc, csc ); + bli_csyrk( blis_uplo, transa, n, k, alpha, ap, rsa, csa, beta, cp, rsc, csc ); else if constexpr (std::is_same::value) - bli_zsyrk( blis_uplo, transa, m, k, alpha, ap, rsa, csa, beta, cp, rsc, csc ); + bli_zsyrk( blis_uplo, transa, n, k, alpha, ap, rsa, csa, beta, cp, rsc, csc ); else throw std::runtime_error("Error in testsuite/level3/syrk.h: Invalid typename in typed_syrk()."); } template -static void syrk( char storage, char uplo, char transa, gtint_t m, gtint_t k, +static void syrk( char storage, char uplo, char transa, gtint_t n, gtint_t k, T* alpha, T* ap, gtint_t lda, T* beta, T* cp, gtint_t ldc ) { + +#ifdef TEST_UPPERCASE_ARGS + storage = static_cast(std::toupper(static_cast(storage))); + uplo = static_cast(std::toupper(static_cast(uplo))); + transa = static_cast(std::toupper(static_cast(transa))); +#endif + +#ifdef TEST_INPUT_ARGS + // Create copy of scalar input values so we can check that they are not altered. + char storage_cpy = storage; + char uplo_cpy = uplo; + char transa_cpy = transa; + gtint_t n_cpy = n; + gtint_t k_cpy = k; + T* alpha_cpy = alpha; + gtint_t lda_cpy = lda; + T* beta_cpy = beta; + gtint_t ldc_cpy = ldc; + + // Create copy of input arrays so we can check that they are not altered. + T* ap_cpy = nullptr; + gtint_t size_ap = testinghelpers::matsize( storage, transa, n, k, lda ); + if (ap && size_ap > 0) + { + ap_cpy = new T[size_ap]; + memcpy( ap_cpy, ap, size_ap * sizeof( T ) ); + } +#endif + #ifdef TEST_BLAS if( storage == 'c' || storage == 'C' ) - syrk_( uplo, transa, m, k, alpha, ap, lda, beta, cp, ldc ); + syrk_( uplo, transa, n, k, alpha, ap, lda, beta, cp, ldc ); else throw std::runtime_error("Error in testsuite/level3/syrk.h: BLAS interface cannot be tested for row-major order."); +#elif TEST_BLAS_BLIS_IMPL + if( storage == 'c' || storage == 'C' ) + syrk_blis_impl( uplo, transa, n, k, alpha, ap, lda, beta, cp, ldc ); + else + throw std::runtime_error("Error in testsuite/level3/syrk.h: BLAS_BLIS_IMPL interface cannot be tested for row-major order."); #elif TEST_CBLAS - cblas_syrk( storage, uplo, transa, m, k, alpha, ap, lda, beta, cp, ldc ); + cblas_syrk( storage, uplo, transa, n, k, alpha, ap, lda, beta, cp, ldc ); #elif TEST_BLIS_TYPED - typed_syrk( storage, uplo, transa, m, k, alpha, ap, lda, beta, cp, ldc ); + typed_syrk( storage, uplo, transa, n, k, alpha, ap, lda, beta, cp, ldc ); #else throw std::runtime_error("Error in testsuite/level3/syrk.h: No interfaces are set to be tested."); #endif + +#ifdef TEST_INPUT_ARGS + //---------------------------------------------------------- + // Check scalar inputs have not been modified. + //---------------------------------------------------------- + + computediff( "storage", storage, storage_cpy ); + computediff( "uplo", uplo, uplo_cpy ); + computediff( "transa", transa, transa_cpy ); + computediff( "n", n, n_cpy ); + computediff( "k", k, k_cpy ); + if (alpha) computediff( "alpha", *alpha, *alpha_cpy ); + computediff( "lda", lda, lda_cpy ); + if (beta) computediff( "beta", *beta, *beta_cpy ); + computediff( "ldc", ldc, ldc_cpy ); + + //---------------------------------------------------------- + // Bitwise-wise check array inputs have not been modified. + //---------------------------------------------------------- + + if (ap && size_ap > 0) + { + if(( transa == 'n' ) || ( transa == 'N' )) + computediff( "A", storage, n, k, ap, ap_cpy, lda, true ); + else + computediff( "A", storage, k, n, ap, ap_cpy, lda, true ); + delete[] ap_cpy; + } +#endif } diff --git a/gtestsuite/testsuite/level3/syrk/test_syrk.h b/gtestsuite/testsuite/level3/syrk/test_syrk.h index 464f608827..5a5bfd2dc1 100644 --- a/gtestsuite/testsuite/level3/syrk/test_syrk.h +++ b/gtestsuite/testsuite/level3/syrk/test_syrk.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -41,36 +41,83 @@ #include template -void test_syrk( char storage, char uplo, char transa, gtint_t m, gtint_t k, +void test_syrk( char storage, char uplo, char transa, gtint_t n, gtint_t k, gtint_t lda_inc, gtint_t ldc_inc, T alpha, T beta, double thresh ) { // Compute the leading dimensions of a, b, and c. - gtint_t lda = testinghelpers::get_leading_dimension( storage, transa, m, k, lda_inc ); - gtint_t ldc = testinghelpers::get_leading_dimension( storage, 'n', m, m, ldc_inc ); + gtint_t lda = testinghelpers::get_leading_dimension( storage, transa, n, k, lda_inc ); + gtint_t ldc = testinghelpers::get_leading_dimension( storage, 'n', n, n, ldc_inc ); //---------------------------------------------------------- // Initialize matrics with random integer numbers. //---------------------------------------------------------- - std::vector a = testinghelpers::get_random_matrix( -2, 8, storage, transa, m, k, lda ); - // Since matrix C, stored in c, is symmetric, we only use the upper or lower - // part in the computation of syrk and zero-out the rest to ensure - // that code operates as expected. - std::vector c = testinghelpers::get_random_matrix( -3, 5, storage, uplo, m, ldc ); + std::vector a = testinghelpers::get_random_matrix( -2, 8, storage, transa, n, k, lda ); + std::vector c( testinghelpers::matsize( storage, 'n', n, n, ldc ) ); + if (beta != testinghelpers::ZERO()) + // Since matrix C, stored in c, is symmetric, we only use the upper or lower + // part in the computation of syrk and zero-out the rest to ensure + // that code operates as expected. + testinghelpers::datagenerators::randomgenerators( -3, 5, storage, uplo, n, c.data(), ldc ); + else + { + // Matrix C should not be read, only set. + testinghelpers::set_matrix( storage, n, c.data(), uplo, ldc, testinghelpers::aocl_extreme() ); + } // Create a copy of c so that we can check reference results. std::vector c_ref(c); + //---------------------------------------------------------- // Call BLIS function //---------------------------------------------------------- - syrk( storage, uplo, transa, m, k, &alpha, a.data(), lda, + syrk( storage, uplo, transa, n, k, &alpha, a.data(), lda, &beta, c.data(), ldc ); + //---------------------------------------------------------- // Call reference implementation. //---------------------------------------------------------- - testinghelpers::ref_syrk( storage, uplo, transa, m, k, alpha, + testinghelpers::ref_syrk( storage, uplo, transa, n, k, alpha, a.data(), lda, beta, c_ref.data(), ldc ); + //---------------------------------------------------------- // check component-wise error. //---------------------------------------------------------- - computediff( storage, m, m, c.data(), c_ref.data(), ldc, thresh ); + computediff( "C", storage, n, n, c.data(), c_ref.data(), ldc, thresh ); + +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif } + +// Test-case logger : Used to print the test-case details based on parameters +template +class syrkGenericPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char storage = std::get<0>(str.param); + char uplo = std::get<1>(str.param); + char transa = std::get<2>(str.param); + gtint_t n = std::get<3>(str.param); + gtint_t k = std::get<4>(str.param); + T alpha = std::get<5>(str.param); + T beta = std::get<6>(str.param); + gtint_t lda_inc = std::get<7>(str.param); + gtint_t ldc_inc = std::get<8>(str.param); + + std::string str_name = API_PRINT; + str_name += "_stor_" + std::string(&storage, 1); + str_name += "_uplo_" + std::string(&uplo, 1); + str_name += "_transa_" + std::string(&transa, 1); + str_name += "_n_" + std::to_string(n); + str_name += "_k_" + std::to_string(k); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + str_name += "_beta_" + testinghelpers::get_value_string(beta); + gtint_t lda = testinghelpers::get_leading_dimension( storage, transa, n, k, lda_inc ); + gtint_t ldc = testinghelpers::get_leading_dimension( storage, 'n', n, n, ldc_inc ); + str_name += "_lda_i" + std::to_string(lda_inc) + "_" + std::to_string(lda); + str_name += "_ldc_i" + std::to_string(ldc_inc) + "_" + std::to_string(ldc); + return str_name; + } +}; diff --git a/gtestsuite/testsuite/level3/syrk/zsyrk_generic.cpp b/gtestsuite/testsuite/level3/syrk/zsyrk_generic.cpp index 406d137d43..febeb3e459 100644 --- a/gtestsuite/testsuite/level3/syrk/zsyrk_generic.cpp +++ b/gtestsuite/testsuite/level3/syrk/zsyrk_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,7 +35,7 @@ #include #include "test_syrk.h" -class zsyrkTest : +class zsyrkGeneric : public ::testing::TestWithParam> {}; -TEST_P(zsyrkTest, RandomData) +TEST_P( zsyrkGeneric, API ) { using T = dcomplex; //---------------------------------------------------------- @@ -59,8 +59,8 @@ TEST_P(zsyrkTest, RandomData) char uplo = std::get<1>(GetParam()); // denotes whether matrix a is n,c,t,h char transa = std::get<2>(GetParam()); - // matrix size m - gtint_t m = std::get<3>(GetParam()); + // matrix size n + gtint_t n = std::get<3>(GetParam()); // matrix size k gtint_t k = std::get<4>(GetParam()); // specifies alpha value @@ -74,69 +74,45 @@ TEST_P(zsyrkTest, RandomData) gtint_t ldc_inc = std::get<8>(GetParam()); // Set the threshold for the errors: - double thresh = m*k*testinghelpers::getEpsilon(); + // Check gtestsuite syrk.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (n == 0) + thresh = 0.0; + else if ((alpha == testinghelpers::ZERO() || k == 0) && + (beta == testinghelpers::ZERO() || beta == testinghelpers::ONE())) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + thresh = testinghelpers::getEpsilon(); + else + thresh = (3*k+1)*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_syrk( storage, uplo, transa, m, k, lda_inc, ldc_inc, alpha, beta, thresh ); + test_syrk( storage, uplo, transa, n, k, lda_inc, ldc_inc, alpha, beta, thresh ); } -class zsyrkTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char uplo = std::get<1>(str.param); - char tsa = std::get<2>(str.param); - gtint_t m = std::get<3>(str.param); - gtint_t k = std::get<4>(str.param); - dcomplex alpha = std::get<5>(str.param); - dcomplex beta = std::get<6>(str.param); - gtint_t lda_inc = std::get<7>(str.param); - gtint_t ldc_inc = std::get<8>(str.param); -#ifdef TEST_BLAS - std::string str_name = "zsyrk_"; -#elif TEST_CBLAS - std::string str_name = "cblas_zsyrk"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_zsyrk"; -#endif - str_name = str_name + "_" + sfm+sfm+sfm; - str_name = str_name + "_" + uplo; - str_name = str_name + "_" + tsa; - str_name = str_name + "_" + std::to_string(m); - str_name = str_name + "_" + std::to_string(k); - std::string alpha_str = ( alpha.real > 0) ? std::to_string(int(alpha.real)) : ("m" + std::to_string(int(std::abs(alpha.real)))); - alpha_str = alpha_str + "pi" + (( alpha.imag > 0) ? std::to_string(int(alpha.imag)) : ("m" + std::to_string(int(std::abs(alpha.imag))))); - str_name = str_name + "_a" + alpha_str; - std::string beta_str = ( beta.real > 0) ? std::to_string(int(beta.real)) : ("m" + std::to_string(int(std::abs(beta.real)))); - beta_str = beta_str + "pi" + (( beta.imag > 0) ? std::to_string(int(beta.imag)) : ("m" + std::to_string(int(std::abs(beta.imag))))); - str_name = str_name + "_a" + beta_str; - str_name = str_name + "_" + std::to_string(lda_inc); - str_name = str_name + "_" + std::to_string(ldc_inc); - return str_name; - } -}; - // Black box testing. INSTANTIATE_TEST_SUITE_P( Blackbox, - zsyrkTest, + zsyrkGeneric, ::testing::Combine( ::testing::Values('c' -#ifndef TEST_BLAS +#ifndef TEST_BLAS_LIKE ,'r' #endif ), // storage format ::testing::Values('u','l'), // u:upper, l:lower ::testing::Values('n','t'), // transa - ::testing::Range(gtint_t(10), gtint_t(31), 10), // m + ::testing::Range(gtint_t(10), gtint_t(31), 10), // n ::testing::Range(gtint_t(10), gtint_t(31), 10), // k ::testing::Values(dcomplex{2.0, -1.0}, dcomplex{-2.0, 3.0}), // alpha ::testing::Values(dcomplex{-3.0, 2.0}, dcomplex{4.0, -1.0}), // beta ::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of a ::testing::Values(gtint_t(0), gtint_t(5)) // increment to the leading dim of c ), - ::zsyrkTestPrint() + ::syrkGenericPrint() ); diff --git a/gtestsuite/testsuite/level3/trmm/ctrmm_generic.cpp b/gtestsuite/testsuite/level3/trmm/ctrmm_generic.cpp index 5887027a58..3e71a8dd74 100644 --- a/gtestsuite/testsuite/level3/trmm/ctrmm_generic.cpp +++ b/gtestsuite/testsuite/level3/trmm/ctrmm_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,7 +35,7 @@ #include #include "test_trmm.h" -class ctrmmTest : +class ctrmmGeneric : public ::testing::TestWithParam> {}; -TEST_P(ctrmmTest, RandomData) +TEST_P( ctrmmGeneric, API ) { using T = scomplex; //---------------------------------------------------------- @@ -78,7 +78,18 @@ TEST_P(ctrmmTest, RandomData) gtint_t ldb_inc = std::get<9>(GetParam()); // Set the threshold for the errors: - double thresh = m*n*testinghelpers::getEpsilon(); + // Check gtestsuite trmm.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (m == 0 || n == 0 || alpha == testinghelpers::ZERO()) + thresh = 0.0; + else + if ( side == 'l' || side == 'L' ) + thresh = 3*m*testinghelpers::getEpsilon(); + else + thresh = 3*n*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters @@ -86,48 +97,13 @@ TEST_P(ctrmmTest, RandomData) test_trmm( storage, side, uploa, transa, diaga, m, n, alpha, lda_inc, ldb_inc, thresh ); } -class ctrmmTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char side = std::get<1>(str.param); - char uploa = std::get<2>(str.param); - char transa = std::get<3>(str.param); - char diaga = std::get<4>(str.param); - gtint_t m = std::get<5>(str.param); - gtint_t n = std::get<6>(str.param); - scomplex alpha = std::get<7>(str.param); - gtint_t lda_inc = std::get<8>(str.param); - gtint_t ldb_inc = std::get<9>(str.param); -#ifdef TEST_BLAS - std::string str_name = "ctrmm_"; -#elif TEST_CBLAS - std::string str_name = "cblas_ctrmm"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_ctrmm"; -#endif - str_name = str_name + "_" + sfm+sfm+sfm; - str_name = str_name + "_" + side + uploa + transa; - str_name = str_name + "_d" + diaga; - str_name = str_name + "_" + std::to_string(m); - str_name = str_name + "_" + std::to_string(n); - std::string alpha_str = ( alpha.real > 0) ? std::to_string(int(alpha.real)) : ("m" + std::to_string(int(std::abs(alpha.real)))); - alpha_str = alpha_str + "pi" + (( alpha.imag > 0) ? std::to_string(int(alpha.imag)) : ("m" + std::to_string(int(std::abs(alpha.imag))))); - str_name = str_name + "_a" + alpha_str; - str_name = str_name + "_" + std::to_string(lda_inc); - str_name = str_name + "_" + std::to_string(ldb_inc); - return str_name; - } -}; - // Black box testing. INSTANTIATE_TEST_SUITE_P( Blackbox, - ctrmmTest, + ctrmmGeneric, ::testing::Combine( ::testing::Values('c' -#ifndef TEST_BLAS +#ifndef TEST_BLAS_LIKE ,'r' #endif ), // storage format @@ -141,5 +117,5 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(0), gtint_t(4)), // increment to the leading dim of a ::testing::Values(gtint_t(0), gtint_t(3)) // increment to the leading dim of b ), - ::ctrmmTestPrint() + ::trmmGenericPrint() ); diff --git a/gtestsuite/testsuite/level3/trmm/dtrmm_generic.cpp b/gtestsuite/testsuite/level3/trmm/dtrmm_generic.cpp index 1c9c251bdf..062fed57a2 100644 --- a/gtestsuite/testsuite/level3/trmm/dtrmm_generic.cpp +++ b/gtestsuite/testsuite/level3/trmm/dtrmm_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,7 +35,7 @@ #include #include "test_trmm.h" -class dtrmmTest : +class dtrmmGeneric : public ::testing::TestWithParam> {}; -TEST_P(dtrmmTest, RandomData) +TEST_P( dtrmmGeneric, API ) { using T = double; //---------------------------------------------------------- @@ -78,7 +78,17 @@ TEST_P(dtrmmTest, RandomData) gtint_t ldb_inc = std::get<9>(GetParam()); // Set the threshold for the errors: - double thresh = m*n*testinghelpers::getEpsilon(); + // Check gtestsuite trmm.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (m == 0 || n == 0 || alpha == testinghelpers::ZERO()) + thresh = 0.0; + else + if ( side == 'l' || side == 'L' ) + thresh = 3*m*testinghelpers::getEpsilon(); + else + thresh = 3*n*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters @@ -86,47 +96,13 @@ TEST_P(dtrmmTest, RandomData) test_trmm( storage, side, uploa, transa, diaga, m, n, alpha, lda_inc, ldb_inc, thresh ); } -class dtrmmTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char side = std::get<1>(str.param); - char uploa = std::get<2>(str.param); - char transa = std::get<3>(str.param); - char diaga = std::get<4>(str.param); - gtint_t m = std::get<5>(str.param); - gtint_t n = std::get<6>(str.param); - double alpha = std::get<7>(str.param); - gtint_t lda_inc = std::get<8>(str.param); - gtint_t ldb_inc = std::get<9>(str.param); -#ifdef TEST_BLAS - std::string str_name = "dtrmm_"; -#elif TEST_CBLAS - std::string str_name = "cblas_dtrmm"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_dtrmm"; -#endif - str_name = str_name + "_" + sfm+sfm+sfm; - str_name = str_name + "_" + side + uploa + transa; - str_name = str_name + "_d" + diaga; - str_name = str_name + "_" + std::to_string(m); - str_name = str_name + "_" + std::to_string(n); - std::string alpha_str = ( alpha > 0) ? std::to_string(int(alpha)) : "m" + std::to_string(int(std::abs(alpha))); - str_name = str_name + "_a" + alpha_str; - str_name = str_name + "_" + std::to_string(lda_inc); - str_name = str_name + "_" + std::to_string(ldb_inc); - return str_name; - } -}; - // Black box testing. INSTANTIATE_TEST_SUITE_P( Blackbox, - dtrmmTest, + dtrmmGeneric, ::testing::Combine( ::testing::Values('c' -#ifndef TEST_BLAS +#ifndef TEST_BLAS_LIKE ,'r' #endif ), // storage format @@ -140,5 +116,5 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of a ::testing::Values(gtint_t(0), gtint_t(3)) // increment to the leading dim of b ), - ::dtrmmTestPrint() + ::trmmGenericPrint() ); diff --git a/gtestsuite/testsuite/level3/trmm/strmm_generic.cpp b/gtestsuite/testsuite/level3/trmm/strmm_generic.cpp index 6851e1f52c..4815898ca5 100644 --- a/gtestsuite/testsuite/level3/trmm/strmm_generic.cpp +++ b/gtestsuite/testsuite/level3/trmm/strmm_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,7 +35,7 @@ #include #include "test_trmm.h" -class strmmTest : +class strmmGeneric : public ::testing::TestWithParam> {}; -TEST_P(strmmTest, RandomData) +TEST_P( strmmGeneric, API ) { using T = float; //---------------------------------------------------------- @@ -78,7 +78,17 @@ TEST_P(strmmTest, RandomData) gtint_t ldb_inc = std::get<9>(GetParam()); // Set the threshold for the errors: - double thresh = 20*m*n*testinghelpers::getEpsilon(); + // Check gtestsuite trmm.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (m == 0 || n == 0 || alpha == testinghelpers::ZERO()) + thresh = 0.0; + else + if ( side == 'l' || side == 'L' ) + thresh = 3*m*testinghelpers::getEpsilon(); + else + thresh = 3*n*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters @@ -86,47 +96,13 @@ TEST_P(strmmTest, RandomData) test_trmm( storage, side, uploa, transa, diaga, m, n, alpha, lda_inc, ldb_inc, thresh ); } -class strmmTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char side = std::get<1>(str.param); - char uploa = std::get<2>(str.param); - char transa = std::get<3>(str.param); - char diaga = std::get<4>(str.param); - gtint_t m = std::get<5>(str.param); - gtint_t n = std::get<6>(str.param); - float alpha = std::get<7>(str.param); - gtint_t lda_inc = std::get<8>(str.param); - gtint_t ldb_inc = std::get<9>(str.param); -#ifdef TEST_BLAS - std::string str_name = "strmm_"; -#elif TEST_CBLAS - std::string str_name = "cblas_strmm"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_strmm"; -#endif - str_name = str_name + "_" + sfm+sfm+sfm; - str_name = str_name + "_" + side + uploa + transa; - str_name = str_name + "_d" + diaga; - str_name = str_name + "_" + std::to_string(m); - str_name = str_name + "_" + std::to_string(n); - std::string alpha_str = ( alpha > 0) ? std::to_string(int(alpha)) : "m" + std::to_string(int(std::abs(alpha))); - str_name = str_name + "_a" + alpha_str; - str_name = str_name + "_" + std::to_string(lda_inc); - str_name = str_name + "_" + std::to_string(ldb_inc); - return str_name; - } -}; - // Black box testing. INSTANTIATE_TEST_SUITE_P( Blackbox, - strmmTest, + strmmGeneric, ::testing::Combine( ::testing::Values('c' -#ifndef TEST_BLAS +#ifndef TEST_BLAS_LIKE ,'r' #endif ), // storage format @@ -140,5 +116,5 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of a ::testing::Values(gtint_t(0), gtint_t(4)) // increment to the leading dim of b ), - ::strmmTestPrint() + ::trmmGenericPrint() ); diff --git a/gtestsuite/testsuite/level3/trmm/test_trmm.h b/gtestsuite/testsuite/level3/trmm/test_trmm.h index 4ba801d937..7334cb5739 100644 --- a/gtestsuite/testsuite/level3/trmm/test_trmm.h +++ b/gtestsuite/testsuite/level3/trmm/test_trmm.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -53,9 +53,16 @@ void test_trmm( char storage, char side, char uploa, char transa, char diaga, // Initialize matrics with random values. //---------------------------------------------------------- std::vector a = testinghelpers::get_random_matrix( -2, 8, storage, transa, mn, mn, lda ); - std::vector b = testinghelpers::get_random_matrix( -5, 2, storage, 'n', m, n, ldb ); + std::vector b( testinghelpers::matsize( storage, 'n', m, n, ldb ) ); + if (alpha != testinghelpers::ZERO()) + testinghelpers::datagenerators::randomgenerators( -5, 2, storage, m, n, b.data(), 'n', ldb ); + else + { + // Matrix B should not be read, only set. + testinghelpers::set_matrix( storage, m, n, b.data(), 'n', ldb, testinghelpers::aocl_extreme() ); + } - // Create a copy of v so that we can check reference results. + // Create a copy of b so that we can check reference results. std::vector b_ref(b); testinghelpers::make_triangular( storage, uploa, mn, a.data(), lda ); @@ -72,5 +79,46 @@ void test_trmm( char storage, char side, char uploa, char transa, char diaga, //---------------------------------------------------------- // check component-wise error. //---------------------------------------------------------- - computediff( storage, m, n, b.data(), b_ref.data(), ldb, thresh ); + computediff( "B", storage, m, n, b.data(), b_ref.data(), ldb, thresh ); + +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif } + +// Test-case logger : Used to print the test-case details based on parameters +template +class trmmGenericPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char storage = std::get<0>(str.param); + char side = std::get<1>(str.param); + char uploa = std::get<2>(str.param); + char transa = std::get<3>(str.param); + char diaga = std::get<4>(str.param); + gtint_t m = std::get<5>(str.param); + gtint_t n = std::get<6>(str.param); + T alpha = std::get<7>(str.param); + gtint_t lda_inc = std::get<8>(str.param); + gtint_t ldb_inc = std::get<9>(str.param); + + std::string str_name = API_PRINT; + str_name += "_stor_" + std::string(&storage, 1); + str_name += "_side_" + std::string(&side, 1); + str_name += "_uploa_" + std::string(&uploa, 1); + str_name += "_transa_" + std::string(&transa, 1); + str_name += "_diaga_" + std::string(&diaga, 1); + str_name += "_m_" + std::to_string(m); + str_name += "_n_" + std::to_string(n); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + gtint_t mn; + testinghelpers::set_dim_with_side( side, m, n, &mn ); + gtint_t lda = testinghelpers::get_leading_dimension( storage, transa, mn, mn, lda_inc ); + gtint_t ldb = testinghelpers::get_leading_dimension( storage, 'n', m, n, ldb_inc ); + str_name += "_lda_i" + std::to_string(lda_inc) + "_" + std::to_string(lda); + str_name += "_ldb_i" + std::to_string(ldb_inc) + "_" + std::to_string(ldb); + return str_name; + } +}; diff --git a/gtestsuite/testsuite/level3/trmm/trmm.h b/gtestsuite/testsuite/level3/trmm/trmm.h index 267aa41e7e..958bf7171c 100644 --- a/gtestsuite/testsuite/level3/trmm/trmm.h +++ b/gtestsuite/testsuite/level3/trmm/trmm.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -36,12 +36,13 @@ #include "blis.h" #include "common/testing_helpers.h" +#include "inc/check_error.h" /** * @brief Performs the operation: - * op( A )*X = alpha*B, or X*op( A ) = alpha*B, + * B := alpha*op( A )*B, or B := alpha*B*op( A ) * where op( A ) is one of - * op( A ) = A or op( A ) = A**T, + * op( A ) = A or op( A ) = A**T or op( A ) = A**H, * @param[in] storage specifies storage format used for the matrices * @param[in] side specifies if the symmetric matrix A appears left or right in the matrix multiplication @@ -78,6 +79,22 @@ static void trmm_( char side, char uploa, char transa, char diaga, gtint_t m, throw std::runtime_error("Error in testsuite/level3/trmm.h: Invalid typename in trmm_()."); } +template +static void trmm_blis_impl( char side, char uploa, char transa, char diaga, gtint_t m, + gtint_t n, T* alpha, T* ap, gtint_t lda, T* bp, gtint_t ldb ) +{ + if constexpr (std::is_same::value) + strmm_blis_impl( &side, &uploa, &transa, &diaga, &m, &n, alpha, ap, &lda, bp, &ldb ); + else if constexpr (std::is_same::value) + dtrmm_blis_impl( &side, &uploa, &transa, &diaga, &m, &n, alpha, ap, &lda, bp, &ldb ); + else if constexpr (std::is_same::value) + ctrmm_blis_impl( &side, &uploa, &transa, &diaga, &m, &n, alpha, ap, &lda, bp, &ldb ); + else if constexpr (std::is_same::value) + ztrmm_blis_impl( &side, &uploa, &transa, &diaga, &m, &n, alpha, ap, &lda, bp, &ldb ); + else + throw std::runtime_error("Error in testsuite/level3/trmm.h: Invalid typename in trmm_blis_impl()."); +} + template static void cblas_trmm( char storage, char side, char uploa, char transa, char diaga, gtint_t m, gtint_t n, T* alpha, T* ap, gtint_t lda, @@ -154,12 +171,50 @@ template static void trmm( char storage, char side, char uploa, char transa, char diaga, gtint_t m, gtint_t n, T *alpha, T *ap, gtint_t lda, T *bp, gtint_t ldb ) { + +#ifdef TEST_UPPERCASE_ARGS + storage = static_cast(std::toupper(static_cast(storage))); + side = static_cast(std::toupper(static_cast(side))); + uploa = static_cast(std::toupper(static_cast(uploa))); + transa = static_cast(std::toupper(static_cast(transa))); + diaga = static_cast(std::toupper(static_cast(diaga))); +#endif + +#ifdef TEST_INPUT_ARGS + // Create copy of scalar input values so we can check that they are not altered. + char storage_cpy = storage; + char side_cpy = side; + char uploa_cpy = uploa; + char transa_cpy = transa; + char diaga_cpy = diaga; + gtint_t m_cpy = m; + gtint_t n_cpy = n; + T* alpha_cpy = alpha; + gtint_t lda_cpy = lda; + gtint_t ldb_cpy = ldb; + + gtint_t mn; + testinghelpers::set_dim_with_side( side, m, n, &mn ); + // Create copy of input arrays so we can check that they are not altered. + T* ap_cpy = nullptr; + gtint_t size_ap = testinghelpers::matsize( storage, transa, mn, mn, lda ); + if (ap && size_ap > 0) + { + ap_cpy = new T[size_ap]; + memcpy( ap_cpy, ap, size_ap * sizeof( T ) ); + } +#endif + #ifdef TEST_BLAS if( storage == 'c' || storage == 'C' ) trmm_( side, uploa, transa, diaga, m, n, alpha, ap, lda, bp, ldb ); else throw std::runtime_error("Error in testsuite/level3/trmm.h: BLAS interface cannot be tested for row-major order."); - +#elif TEST_BLAS_BLIS_IMPL + if( storage == 'c' || storage == 'C' ) + trmm_blis_impl( side, uploa, transa, diaga, m, n, alpha, ap, lda, bp, ldb ); + else + throw std::runtime_error("Error in testsuite/level3/trmm.h: BLAS_BLIS_IMPL interface cannot be tested for row-major order."); #elif TEST_CBLAS cblas_trmm( storage, side, uploa, transa, diaga, m, n, alpha, ap, lda, bp, ldb ); #elif TEST_BLIS_TYPED @@ -167,4 +222,31 @@ static void trmm( char storage, char side, char uploa, char transa, char diaga, #else throw std::runtime_error("Error in testsuite/level3/trmm.h: No interfaces are set to be tested."); #endif + +#ifdef TEST_INPUT_ARGS + //---------------------------------------------------------- + // Check scalar inputs have not been modified. + //---------------------------------------------------------- + + computediff( "storage", storage, storage_cpy ); + computediff( "side", side, side_cpy ); + computediff( "uploa", uploa, uploa_cpy ); + computediff( "transa", transa, transa_cpy ); + computediff( "diaga", diaga, diaga_cpy ); + computediff( "m", m, m_cpy ); + computediff( "n", n, n_cpy ); + if (alpha) computediff( "alpha", *alpha, *alpha_cpy ); + computediff( "lda", lda, lda_cpy ); + computediff( "ldb", ldb, ldb_cpy ); + + //---------------------------------------------------------- + // Bitwise-wise check array inputs have not been modified. + //---------------------------------------------------------- + + if (ap && size_ap > 0) + { + computediff( "A", storage, mn, mn, ap, ap_cpy, lda, true ); + delete[] ap_cpy; + } +#endif } diff --git a/gtestsuite/testsuite/level3/trmm/ztrmm_generic.cpp b/gtestsuite/testsuite/level3/trmm/ztrmm_generic.cpp index d6ad3e02ca..138e2a0187 100644 --- a/gtestsuite/testsuite/level3/trmm/ztrmm_generic.cpp +++ b/gtestsuite/testsuite/level3/trmm/ztrmm_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,7 +35,7 @@ #include #include "test_trmm.h" -class ztrmmTest : +class ztrmmGeneric : public ::testing::TestWithParam> {}; -TEST_P(ztrmmTest, RandomData) +TEST_P( ztrmmGeneric, API ) { using T = dcomplex; //---------------------------------------------------------- @@ -78,7 +78,18 @@ TEST_P(ztrmmTest, RandomData) gtint_t ldb_inc = std::get<9>(GetParam()); // Set the threshold for the errors: - double thresh = m*n*testinghelpers::getEpsilon(); + // Check gtestsuite trmm.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (m == 0 || n == 0 || alpha == testinghelpers::ZERO()) + thresh = 0.0; + else + if ( side == 'l' || side == 'L' ) + thresh = 3*m*testinghelpers::getEpsilon(); + else + thresh = 3*n*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters @@ -86,48 +97,13 @@ TEST_P(ztrmmTest, RandomData) test_trmm( storage, side, uploa, transa, diaga, m, n, alpha, lda_inc, ldb_inc, thresh ); } -class ztrmmTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char side = std::get<1>(str.param); - char uploa = std::get<2>(str.param); - char transa = std::get<3>(str.param); - char diaga = std::get<4>(str.param); - gtint_t m = std::get<5>(str.param); - gtint_t n = std::get<6>(str.param); - dcomplex alpha = std::get<7>(str.param); - gtint_t lda_inc = std::get<8>(str.param); - gtint_t ldb_inc = std::get<9>(str.param); -#ifdef TEST_BLAS - std::string str_name = "ztrmm_"; -#elif TEST_CBLAS - std::string str_name = "cblas_ztrmm"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_ztrmm"; -#endif - str_name = str_name + "_" + sfm+sfm+sfm; - str_name = str_name + "_" + side + uploa + transa; - str_name = str_name + "_d" + diaga; - str_name = str_name + "_" + std::to_string(m); - str_name = str_name + "_" + std::to_string(n); - std::string alpha_str = ( alpha.real > 0) ? std::to_string(int(alpha.real)) : ("m" + std::to_string(int(std::abs(alpha.real)))); - alpha_str = alpha_str + "pi" + (( alpha.imag > 0) ? std::to_string(int(alpha.imag)) : ("m" + std::to_string(int(std::abs(alpha.imag))))); - str_name = str_name + "_a" + alpha_str; - str_name = str_name + "_" + std::to_string(lda_inc); - str_name = str_name + "_" + std::to_string(ldb_inc); - return str_name; - } -}; - // Black box testing. INSTANTIATE_TEST_SUITE_P( Blackbox, - ztrmmTest, + ztrmmGeneric, ::testing::Combine( ::testing::Values('c' -#ifndef TEST_BLAS +#ifndef TEST_BLAS_LIKE ,'r' #endif ), // storage format @@ -141,5 +117,5 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of a ::testing::Values(gtint_t(0), gtint_t(1)) // increment to the leading dim of b ), - ::ztrmmTestPrint() + ::trmmGenericPrint() ); diff --git a/gtestsuite/testsuite/level3/trmm3/ctrmm3_generic.cpp b/gtestsuite/testsuite/level3/trmm3/ctrmm3_generic.cpp index 839c472988..2fc7174472 100644 --- a/gtestsuite/testsuite/level3/trmm3/ctrmm3_generic.cpp +++ b/gtestsuite/testsuite/level3/trmm3/ctrmm3_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,7 +35,7 @@ #include #include "test_trmm3.h" -class ctrmm3Test : +class ctrmm3Generic : public ::testing::TestWithParam> {}; -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ctrmm3Test); +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ctrmm3Generic); -TEST_P(ctrmm3Test, RandomData) +TEST_P( ctrmm3Generic, API ) { using T = scomplex; //---------------------------------------------------------- @@ -88,7 +88,21 @@ TEST_P(ctrmm3Test, RandomData) gtint_t ldc_inc = std::get<12>(GetParam()); // Set the threshold for the errors: - double thresh = m*n*testinghelpers::getEpsilon(); + // Check gtestsuite trmm3.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (m == 0 || n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO() && + (beta == testinghelpers::ZERO() || beta == testinghelpers::ONE())) + thresh = 0.0; + else + if ( side == 'l' || side == 'L' ) + thresh = (3*m+1)*testinghelpers::getEpsilon(); + else + thresh = (3*n+1)*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters @@ -96,47 +110,11 @@ TEST_P(ctrmm3Test, RandomData) test_trmm3( storage, side, uploa, transa, diaga, transb, m, n, alpha, lda_inc, ldb_inc, beta, ldc_inc, thresh ); } -class ctrmm3TestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char side = std::get<1>(str.param); - char uploa = std::get<2>(str.param); - char transa = std::get<3>(str.param); - char transb = std::get<4>(str.param); - char diaga = std::get<5>(str.param); - gtint_t m = std::get<6>(str.param); - gtint_t n = std::get<7>(str.param); - scomplex alpha = std::get<8>(str.param); - scomplex beta = std::get<9>(str.param); - gtint_t lda_inc = std::get<10>(str.param); - gtint_t ldb_inc = std::get<11>(str.param); - gtint_t ldc_inc = std::get<12>(str.param); - std::string str_name = "bli_ctrmm3"; - str_name = str_name + "_" + sfm+sfm+sfm; - str_name = str_name + "_" + side + uploa + transa + transb; - str_name = str_name + "_d" + diaga; - str_name = str_name + "_" + std::to_string(m); - str_name = str_name + "_" + std::to_string(n); - std::string alpha_str = ( alpha.real > 0) ? std::to_string(int(alpha.real)) : ("m" + std::to_string(int(std::abs(alpha.real)))); - alpha_str = alpha_str + "pi" + (( alpha.imag > 0) ? std::to_string(int(alpha.imag)) : ("m" + std::to_string(int(std::abs(alpha.imag))))); - std::string beta_str = ( beta.real > 0) ? std::to_string(int(beta.real)) : ("m" + std::to_string(int(std::abs(beta.real)))); - beta_str = beta_str + "pi" + (( beta.imag > 0) ? std::to_string(int(beta.imag)) : ("m" + std::to_string(int(std::abs(beta.imag))))); - str_name = str_name + "_a" + alpha_str; - str_name = str_name + "_b" + beta_str; - str_name = str_name + "_" + std::to_string(lda_inc); - str_name = str_name + "_" + std::to_string(ldb_inc); - str_name = str_name + "_" + std::to_string(ldc_inc); - return str_name; - } -}; - #ifdef TEST_BLIS_TYPED // Black box testing. INSTANTIATE_TEST_SUITE_P( Blackbox, - ctrmm3Test, + ctrmm3Generic, ::testing::Combine( ::testing::Values('c','r'), // storage format ::testing::Values('l','r'), // side l:left, r:right @@ -152,6 +130,6 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(0)), // increment to the leading dim of b ::testing::Values(gtint_t(0)) // increment to the leading dim of c ), - ::ctrmm3TestPrint() + ::trmm3GenericPrint() ); #endif diff --git a/gtestsuite/testsuite/level3/trmm3/dtrmm3_generic.cpp b/gtestsuite/testsuite/level3/trmm3/dtrmm3_generic.cpp index 343a573666..17a1de4a87 100644 --- a/gtestsuite/testsuite/level3/trmm3/dtrmm3_generic.cpp +++ b/gtestsuite/testsuite/level3/trmm3/dtrmm3_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,7 +35,7 @@ #include #include "test_trmm3.h" -class dtrmm3Test : +class dtrmm3Generic : public ::testing::TestWithParam> {}; -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(dtrmm3Test); +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(dtrmm3Generic); -TEST_P(dtrmm3Test, RandomData) +TEST_P( dtrmm3Generic, API ) { using T = double; //---------------------------------------------------------- @@ -88,7 +88,20 @@ TEST_P(dtrmm3Test, RandomData) gtint_t ldc_inc = std::get<12>(GetParam()); // Set the threshold for the errors: - double thresh = m*n*testinghelpers::getEpsilon(); + // Check gtestsuite trmm3.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (m == 0 || n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO() && + (beta == testinghelpers::ZERO() || beta == testinghelpers::ONE())) + thresh = 0.0; + else + if ( side == 'l' || side == 'L' ) + thresh = (3*m+1)*testinghelpers::getEpsilon(); + else + thresh = (3*n+1)*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters @@ -96,45 +109,11 @@ TEST_P(dtrmm3Test, RandomData) test_trmm3( storage, side, uploa, transa, diaga, transb, m, n, alpha, lda_inc, ldb_inc, beta, ldc_inc, thresh ); } -class dtrmm3TestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char side = std::get<1>(str.param); - char uploa = std::get<2>(str.param); - char transa = std::get<3>(str.param); - char transb = std::get<4>(str.param); - char diaga = std::get<5>(str.param); - gtint_t m = std::get<6>(str.param); - gtint_t n = std::get<7>(str.param); - double alpha = std::get<8>(str.param); - double beta = std::get<9>(str.param); - gtint_t lda_inc = std::get<10>(str.param); - gtint_t ldb_inc = std::get<11>(str.param); - gtint_t ldc_inc = std::get<12>(str.param); - std::string str_name = "bli_dtrmm3"; - str_name = str_name + "_" + sfm+sfm+sfm; - str_name = str_name + "_" + side + uploa + transa + transb; - str_name = str_name + "_d" + diaga; - str_name = str_name + "_" + std::to_string(m); - str_name = str_name + "_" + std::to_string(n); - std::string alpha_str = ( alpha > 0) ? std::to_string(int(alpha)) : "m" + std::to_string(int(std::abs(alpha))); - std::string beta_str = ( beta > 0) ? std::to_string(int(beta)) : "m" + std::to_string(int(std::abs(beta))); - str_name = str_name + "_a" + alpha_str; - str_name = str_name + "_b" + beta_str; - str_name = str_name + "_" + std::to_string(lda_inc); - str_name = str_name + "_" + std::to_string(ldb_inc); - str_name = str_name + "_" + std::to_string(ldc_inc); - return str_name; - } -}; - #ifdef TEST_BLIS_TYPED // Black box testing. INSTANTIATE_TEST_SUITE_P( Blackbox, - dtrmm3Test, + dtrmm3Generic, ::testing::Combine( ::testing::Values('c','r'), // storage format ::testing::Values('l','r'), // side l:left, r:right @@ -150,6 +129,6 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(0)), // increment to the leading dim of b ::testing::Values(gtint_t(0)) // increment to the leading dim of c ), - ::dtrmm3TestPrint() + ::trmm3GenericPrint() ); #endif diff --git a/gtestsuite/testsuite/level3/trmm3/strmm3_generic.cpp b/gtestsuite/testsuite/level3/trmm3/strmm3_generic.cpp index 2d52b620e8..7de8bcee70 100644 --- a/gtestsuite/testsuite/level3/trmm3/strmm3_generic.cpp +++ b/gtestsuite/testsuite/level3/trmm3/strmm3_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,7 +35,7 @@ #include #include "test_trmm3.h" -class strmm3Test : +class strmm3Generic : public ::testing::TestWithParam> {}; -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(strmm3Test); +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(strmm3Generic); -TEST_P(strmm3Test, RandomData) +TEST_P( strmm3Generic, API ) { using T = float; //---------------------------------------------------------- @@ -88,7 +88,20 @@ TEST_P(strmm3Test, RandomData) gtint_t ldc_inc = std::get<12>(GetParam()); // Set the threshold for the errors: - double thresh = m*n*testinghelpers::getEpsilon(); + // Check gtestsuite trmm3.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (m == 0 || n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO() && + (beta == testinghelpers::ZERO() || beta == testinghelpers::ONE())) + thresh = 0.0; + else + if ( side == 'l' || side == 'L' ) + thresh = (3*m+1)*testinghelpers::getEpsilon(); + else + thresh = (3*n+1)*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters @@ -96,45 +109,11 @@ TEST_P(strmm3Test, RandomData) test_trmm3( storage, side, uploa, transa, diaga, transb, m, n, alpha, lda_inc, ldb_inc, beta, ldc_inc, thresh ); } -class strmm3TestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char side = std::get<1>(str.param); - char uploa = std::get<2>(str.param); - char transa = std::get<3>(str.param); - char transb = std::get<4>(str.param); - char diaga = std::get<5>(str.param); - gtint_t m = std::get<6>(str.param); - gtint_t n = std::get<7>(str.param); - float alpha = std::get<8>(str.param); - float beta = std::get<9>(str.param); - gtint_t lda_inc = std::get<10>(str.param); - gtint_t ldb_inc = std::get<11>(str.param); - gtint_t ldc_inc = std::get<12>(str.param); - std::string str_name = "bli_strmm3"; - str_name = str_name + "_" + sfm+sfm+sfm; - str_name = str_name + "_" + side + uploa + transa + transb; - str_name = str_name + "_d" + diaga; - str_name = str_name + "_" + std::to_string(m); - str_name = str_name + "_" + std::to_string(n); - std::string alpha_str = ( alpha > 0) ? std::to_string(int(alpha)) : "m" + std::to_string(int(std::abs(alpha))); - std::string beta_str = ( beta > 0) ? std::to_string(int(beta)) : "m" + std::to_string(int(std::abs(beta))); - str_name = str_name + "_a" + alpha_str; - str_name = str_name + "_b" + beta_str; - str_name = str_name + "_" + std::to_string(lda_inc); - str_name = str_name + "_" + std::to_string(ldb_inc); - str_name = str_name + "_" + std::to_string(ldc_inc); - return str_name; - } -}; - #ifdef TEST_BLIS_TYPED // Black box testing. INSTANTIATE_TEST_SUITE_P( Blackbox, - strmm3Test, + strmm3Generic, ::testing::Combine( ::testing::Values('c','r'), // storage format ::testing::Values('l','r'), // side l:left, r:right @@ -150,6 +129,6 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(0)), // increment to the leading dim of b ::testing::Values(gtint_t(0)) // increment to the leading dim of c ), - ::strmm3TestPrint() + ::trmm3GenericPrint() ); #endif diff --git a/gtestsuite/testsuite/level3/trmm3/test_trmm3.h b/gtestsuite/testsuite/level3/trmm3/test_trmm3.h index 8203a0cb6b..1371f779da 100644 --- a/gtestsuite/testsuite/level3/trmm3/test_trmm3.h +++ b/gtestsuite/testsuite/level3/trmm3/test_trmm3.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -56,7 +56,14 @@ void test_trmm3( char storage, char side, char uploa, char transa, char diaga, //---------------------------------------------------------- std::vector a = testinghelpers::get_random_matrix( -2, 8, storage, transa, mn, mn, lda ); std::vector b = testinghelpers::get_random_matrix( -5, 2, storage, transb, m, n, ldb ); - std::vector c = testinghelpers::get_random_matrix( -3, 5, storage, 'n', m, n, ldc ); + std::vector c( testinghelpers::matsize( storage, 'n', m, n, ldc ) ); + if (beta != testinghelpers::ZERO()) + testinghelpers::datagenerators::randomgenerators( -3, 5, storage, m, n, c.data(), 'n', ldc ); + else + { + // Matrix C should not be read, only set. + testinghelpers::set_matrix( storage, m, n, c.data(), 'n', ldc, testinghelpers::aocl_extreme() ); + } // Create a copy of v so that we can check reference results. std::vector c_ref(c); @@ -76,5 +83,53 @@ void test_trmm3( char storage, char side, char uploa, char transa, char diaga, //---------------------------------------------------------- // check component-wise error. //---------------------------------------------------------- - computediff( storage, m, n, c.data(), c_ref.data(), ldb, thresh ); + computediff( "C", storage, m, n, c.data(), c_ref.data(), ldb, thresh ); + +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif } + +// Test-case logger : Used to print the test-case details based on parameters +template +class trmm3GenericPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char storage = std::get<0>(str.param); + char side = std::get<1>(str.param); + char uploa = std::get<2>(str.param); + char transa = std::get<3>(str.param); + char transb = std::get<4>(str.param); + char diaga = std::get<5>(str.param); + gtint_t m = std::get<6>(str.param); + gtint_t n = std::get<7>(str.param); + T alpha = std::get<8>(str.param); + T beta = std::get<9>(str.param); + gtint_t lda_inc = std::get<10>(str.param); + gtint_t ldb_inc = std::get<11>(str.param); + gtint_t ldc_inc = std::get<12>(str.param); + + std::string str_name = API_PRINT; + str_name += "_stor_" + std::string(&storage, 1); + str_name += "_side_" + std::string(&side, 1); + str_name += "_uploa_" + std::string(&uploa, 1); + str_name += "_transa_" + std::string(&transa, 1); + str_name += "_transb_" + std::string(&transb, 1); + str_name += "_diaga_" + std::string(&diaga, 1); + str_name += "_m_" + std::to_string(m); + str_name += "_n_" + std::to_string(n); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + str_name += "_beta_" + testinghelpers::get_value_string(beta); + gtint_t mn; + testinghelpers::set_dim_with_side( side, m, n, &mn ); + gtint_t lda = testinghelpers::get_leading_dimension( storage, transa, mn, mn, lda_inc ); + gtint_t ldb = testinghelpers::get_leading_dimension( storage, transb, m, n, ldb_inc ); + gtint_t ldc = testinghelpers::get_leading_dimension( storage, 'n', m, n, ldc_inc ); + str_name += "_lda_i" + std::to_string(lda_inc) + "_" + std::to_string(lda); + str_name += "_ldb_i" + std::to_string(ldb_inc) + "_" + std::to_string(ldb); + str_name += "_ldc_i" + std::to_string(ldc_inc) + "_" + std::to_string(ldc); + return str_name; + } +}; diff --git a/gtestsuite/testsuite/level3/trmm3/trmm3.h b/gtestsuite/testsuite/level3/trmm3/trmm3.h index 2bd52db11a..3fa865aff4 100644 --- a/gtestsuite/testsuite/level3/trmm3/trmm3.h +++ b/gtestsuite/testsuite/level3/trmm3/trmm3.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -36,6 +36,7 @@ #include "blis.h" #include "common/testing_helpers.h" +#include "inc/check_error.h" /** * @brief Performs the operation: @@ -126,14 +127,100 @@ static void trmm3( char storage, char side, char uploa, char transa, char diaga, char transb, gtint_t m, gtint_t n, T *alpha, T *ap, gtint_t lda, T *bp, gtint_t ldb, T *beta, T *c, gtint_t ldc ) { + +#ifdef TEST_UPPERCASE_ARGS + storage = static_cast(std::toupper(static_cast(storage))); + side = static_cast(std::toupper(static_cast(side))); + uploa = static_cast(std::toupper(static_cast(uploa))); + transa = static_cast(std::toupper(static_cast(transa))); + diaga = static_cast(std::toupper(static_cast(diaga))); + transb = static_cast(std::toupper(static_cast(transb))); +#endif + +#ifdef TEST_INPUT_ARGS + // Create copy of scalar input values so we can check that they are not altered. + char storage_cpy = storage; + char side_cpy = side; + char uploa_cpy = uploa; + char transa_cpy = transa; + char diaga_cpy = diaga; + char transb_cpy = transb; + gtint_t m_cpy = m; + gtint_t n_cpy = n; + T* alpha_cpy = alpha; + gtint_t lda_cpy = lda; + gtint_t ldb_cpy = ldb; + T* beta_cpy = beta; + gtint_t ldc_cpy = ldc; + + gtint_t mn; + testinghelpers::set_dim_with_side( side, m, n, &mn ); + // Create copy of input arrays so we can check that they are not altered. + T* ap_cpy = nullptr; + gtint_t size_ap = testinghelpers::matsize( storage, transa, mn, mn, lda ); + if (ap && size_ap > 0) + { + ap_cpy = new T[size_ap]; + memcpy( ap_cpy, ap, size_ap * sizeof( T ) ); + } + T* bp_cpy = nullptr; + gtint_t size_bp = testinghelpers::matsize( storage, transb, m, n, ldb ); + if (bp && size_bp > 0) + { + bp_cpy = new T[size_bp]; + memcpy( bp_cpy, bp, size_bp * sizeof( T ) ); + } +#endif + #ifdef TEST_BLAS throw std::runtime_error("Error in testsuite/level3/trmm3.h: BLAS interface is not available."); +#elif TEST_BLAS_BLIS_IMPL + throw std::runtime_error("Error in testsuite/level3/trmm3.h: BLAS_BLIS_IMPL interface is not available."); #elif TEST_CBLAS - throw std::runtime_error("Error in testsuite/level3/trmm3.h: BLAS interface is not available."); + throw std::runtime_error("Error in testsuite/level3/trmm3.h: CBLAS interface is not available."); #elif TEST_BLIS_TYPED typed_trmm3( storage, side, uploa, transa, diaga, transb, m, n, alpha, ap, lda, bp, ldb, beta, c, ldc ); #else throw std::runtime_error("Error in testsuite/level3/trmm3.h: No interfaces are set to be tested."); #endif + +#ifdef TEST_INPUT_ARGS + //---------------------------------------------------------- + // Check scalar inputs have not been modified. + //---------------------------------------------------------- + + computediff( "storage", storage, storage_cpy ); + computediff( "side", side, side_cpy ); + computediff( "uploa", uploa, uploa_cpy ); + computediff( "transa", transa, transa_cpy ); + computediff( "diaga", diaga, diaga_cpy ); + computediff( "transb", transb, transb_cpy ); + computediff( "m", m, m_cpy ); + computediff( "n", n, n_cpy ); + if (alpha) computediff( "alpha", *alpha, *alpha_cpy ); + computediff( "lda", lda, lda_cpy ); + computediff( "ldb", ldb, ldb_cpy ); + if (beta) computediff( "beta", *beta, *beta_cpy ); + computediff( "ldc", ldc, ldc_cpy ); + + //---------------------------------------------------------- + // Bitwise-wise check array inputs have not been modified. + //---------------------------------------------------------- + + if (ap && size_ap > 0) + { + computediff( "A", storage, mn, mn, ap, ap_cpy, lda, true ); + delete[] ap_cpy; + } + + if (bp && size_bp > 0) + { + if(( transb == 'n' ) || ( transb == 'N' )) + computediff( "B", storage, m, n, bp, bp_cpy, ldb, true ); + else + computediff( "B", storage, n, m, bp, bp_cpy, ldb, true ); + delete[] bp_cpy; + } +#endif } diff --git a/gtestsuite/testsuite/level3/trmm3/ztrmm3_generic.cpp b/gtestsuite/testsuite/level3/trmm3/ztrmm3_generic.cpp index 6ef3931d72..31e7c12e65 100644 --- a/gtestsuite/testsuite/level3/trmm3/ztrmm3_generic.cpp +++ b/gtestsuite/testsuite/level3/trmm3/ztrmm3_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,7 +35,7 @@ #include #include "test_trmm3.h" -class ztrmm3Test : +class ztrmm3Generic : public ::testing::TestWithParam> {}; -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ztrmm3Test); +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ztrmm3Generic); -TEST_P(ztrmm3Test, RandomData) +TEST_P( ztrmm3Generic, API ) { using T = dcomplex; //---------------------------------------------------------- @@ -88,7 +88,21 @@ TEST_P(ztrmm3Test, RandomData) gtint_t ldc_inc = std::get<12>(GetParam()); // Set the threshold for the errors: - double thresh = m*n*testinghelpers::getEpsilon(); + // Check gtestsuite trmm3.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (m == 0 || n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO() && + (beta == testinghelpers::ZERO() || beta == testinghelpers::ONE())) + thresh = 0.0; + else + if ( side == 'l' || side == 'L' ) + thresh = (3*m+1)*testinghelpers::getEpsilon(); + else + thresh = (3*n+1)*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters @@ -96,47 +110,11 @@ TEST_P(ztrmm3Test, RandomData) test_trmm3( storage, side, uploa, transa, diaga, transb, m, n, alpha, lda_inc, ldb_inc, beta, ldc_inc, thresh ); } -class ztrmm3TestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char side = std::get<1>(str.param); - char uploa = std::get<2>(str.param); - char transa = std::get<3>(str.param); - char transb = std::get<4>(str.param); - char diaga = std::get<5>(str.param); - gtint_t m = std::get<6>(str.param); - gtint_t n = std::get<7>(str.param); - dcomplex alpha = std::get<8>(str.param); - dcomplex beta = std::get<9>(str.param); - gtint_t lda_inc = std::get<10>(str.param); - gtint_t ldb_inc = std::get<11>(str.param); - gtint_t ldc_inc = std::get<12>(str.param); - std::string str_name = "bli_ztrmm3"; - str_name = str_name + "_" + sfm+sfm+sfm; - str_name = str_name + "_" + side + uploa + transa + transb; - str_name = str_name + "_d" + diaga; - str_name = str_name + "_" + std::to_string(m); - str_name = str_name + "_" + std::to_string(n); - std::string alpha_str = ( alpha.real > 0) ? std::to_string(int(alpha.real)) : ("m" + std::to_string(int(std::abs(alpha.real)))); - alpha_str = alpha_str + "pi" + (( alpha.imag > 0) ? std::to_string(int(alpha.imag)) : ("m" + std::to_string(int(std::abs(alpha.imag))))); - std::string beta_str = ( beta.real > 0) ? std::to_string(int(beta.real)) : ("m" + std::to_string(int(std::abs(beta.real)))); - beta_str = beta_str + "pi" + (( beta.imag > 0) ? std::to_string(int(beta.imag)) : ("m" + std::to_string(int(std::abs(beta.imag))))); - str_name = str_name + "_a" + alpha_str; - str_name = str_name + "_b" + beta_str; - str_name = str_name + "_" + std::to_string(lda_inc); - str_name = str_name + "_" + std::to_string(ldb_inc); - str_name = str_name + "_" + std::to_string(ldc_inc); - return str_name; - } -}; - #ifdef TEST_BLIS_TYPED // Black box testing. INSTANTIATE_TEST_SUITE_P( Blackbox, - ztrmm3Test, + ztrmm3Generic, ::testing::Combine( ::testing::Values('c','r'), // storage format ::testing::Values('l','r'), // side l:left, r:right @@ -152,6 +130,6 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(0)), // increment to the leading dim of b ::testing::Values(gtint_t(0)) // increment to the leading dim of c ), - ::ztrmm3TestPrint() + ::trmm3GenericPrint() ); #endif diff --git a/gtestsuite/testsuite/level3/trsm/IIT_ERS/trsm_IIT_ERS.cpp b/gtestsuite/testsuite/level3/trsm/IIT_ERS/trsm_IIT_ERS.cpp new file mode 100644 index 0000000000..ed1f14c8b6 --- /dev/null +++ b/gtestsuite/testsuite/level3/trsm/IIT_ERS/trsm_IIT_ERS.cpp @@ -0,0 +1,480 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "level3/trsm/trsm.h" +#include "inc/check_error.h" +#include "common/testing_helpers.h" +#include "common/wrong_inputs_helpers.h" +#include +#include +#include + + +template +class trsm_IIT_ERS : public ::testing::Test {}; +typedef ::testing::Types TypeParam; +TYPED_TEST_SUITE(trsm_IIT_ERS, TypeParam); + +// Adding namespace to get default parameters(valid case) from testinghelpers/common/wrong_input_helpers.h. +using namespace testinghelpers::IIT; + +#if defined(TEST_CBLAS) +#define INFO_OFFSET 1 +#else +#define INFO_OFFSET 0 +#endif + +#if defined(TEST_CBLAS) + +/** + * @brief Test TRSM when storage argument is incorrect + * when info == 1 + */ +TYPED_TEST(trsm_IIT_ERS, invalid_storage) +{ + using T = TypeParam; + T ALPHA = T{2.3}; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + trsm( 'x', SIDE, UPLO, TRANS, DIAG, M, N, &ALPHA, nullptr, LDA, nullptr, LDB); + +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 1 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector a = testinghelpers::get_random_matrix(0, 1, STORAGE, 'n', M, N, LDB); + std::vector b = testinghelpers::get_random_matrix(0, 1, STORAGE, 'n', M, N, LDB); + std::vector b_ref(b); + + trsm( 'x', SIDE, UPLO, TRANS, DIAG, M, N, &ALPHA, a.data(), LDA, b.data(), LDB); + computediff( "B", STORAGE, M, N, b.data(), b_ref.data(), LDB ); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 1 ); +#endif +} + +#endif + +#if defined(TEST_BLAS_LIKE) || defined(TEST_CBLAS) + +/** + * @brief Test TRSM when side argument is incorrect + * when info == 1 + */ +TYPED_TEST(trsm_IIT_ERS, invalid_side) +{ + using T = TypeParam; + T ALPHA = T{2.3}; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + trsm( STORAGE, 'a', UPLO, TRANS, DIAG, M, N, nullptr, nullptr, LDA, nullptr, LDB); +#else + trsm( STORAGE, 'a', UPLO, TRANS, DIAG, M, N, &ALPHA, nullptr, LDA, nullptr, LDB); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, INFO_OFFSET+1 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector a = testinghelpers::get_random_matrix(0, 1, STORAGE, 'n', M, N, LDB); + std::vector b = testinghelpers::get_random_matrix(0, 1, STORAGE, 'n', M, N, LDB); + std::vector b_ref(b); + + trsm( STORAGE, 'a', UPLO, TRANS, DIAG, M, N, &ALPHA, nullptr, LDA, b.data(), LDB); + computediff( "B", STORAGE, M, N, b.data(), b_ref.data(), LDB ); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, INFO_OFFSET+1 ); +#endif +} + +/** + * @brief Test TRSM when UPLO argument is incorrect + * when info == 2 + * + */ +TYPED_TEST(trsm_IIT_ERS, invalid_UPLO) +{ + using T = TypeParam; + T ALPHA = T{2.3}; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + trsm( STORAGE, SIDE, 'a', TRANS, DIAG, M, N, nullptr, nullptr, LDA, nullptr, LDB); +#else + trsm( STORAGE, SIDE, 'a', TRANS, DIAG, M, N, &ALPHA, nullptr, LDA, nullptr, LDB); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, INFO_OFFSET+2 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector a = testinghelpers::get_random_matrix(0, 1, STORAGE, 'n', M, N, LDB); + std::vector b = testinghelpers::get_random_matrix(0, 1, STORAGE, 'n', M, N, LDB); + std::vector b_ref(b); + + trsm( STORAGE, SIDE, 'a', TRANS, DIAG, M, N, &ALPHA, a.data(), LDA, b.data(), LDB); + computediff( "B", STORAGE, M, N, b.data(), b_ref.data(), LDB ); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, INFO_OFFSET+2 ); +#endif +} + +/** + * @brief Test TRSM when TRANS argument is incorrect + * when info == 3 + * + */ +TYPED_TEST(trsm_IIT_ERS, invalid_TRANS) +{ + using T = TypeParam; + T ALPHA = T{2.3}; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + trsm( STORAGE, SIDE, UPLO, 'a', DIAG, M, N, nullptr, nullptr, LDA, nullptr, LDB); +#else + trsm( STORAGE, SIDE, UPLO, 'a', DIAG, M, N, &ALPHA, nullptr, LDA, nullptr, LDB); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, INFO_OFFSET+3 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector a = testinghelpers::get_random_matrix(0, 1, STORAGE, 'n', M, N, LDB); + std::vector b = testinghelpers::get_random_matrix(0, 1, STORAGE, 'n', M, N, LDB); + std::vector b_ref(b); + + trsm( STORAGE, SIDE, UPLO, 'a', DIAG, M, N, &ALPHA, a.data(), LDA, b.data(), LDB); + computediff( "B", STORAGE, M, N, b.data(), b_ref.data(), LDB ); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, INFO_OFFSET+3 ); +#endif +} + +/** + * @brief Test TRSM when DIAG argument is incorrect + * when info == 4 + */ +TYPED_TEST(trsm_IIT_ERS, invalid_DIAG) +{ + using T = TypeParam; + T ALPHA = T{2.3}; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + trsm( STORAGE, SIDE, UPLO, TRANS, 'a', M, N, nullptr, nullptr, LDA, nullptr, LDB); +#else + trsm( STORAGE, SIDE, UPLO, TRANS, 'a', M, N, &ALPHA, nullptr, LDA, nullptr, LDB); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, INFO_OFFSET+4 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector a = testinghelpers::get_random_matrix(0, 1, STORAGE, 'n', M, N, LDB); + std::vector b = testinghelpers::get_random_matrix(0, 1, STORAGE, 'n', M, N, LDB); + std::vector b_ref(b); + + trsm( STORAGE, SIDE, UPLO, TRANS, 'a', M, N, &ALPHA, a.data(), LDA, b.data(), LDB); + computediff( "B", STORAGE, M, N, b.data(), b_ref.data(), LDB ); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, INFO_OFFSET+4 ); +#endif +} + +/** + * @brief Test TRSM when m is negative + * when info == 5 + */ +TYPED_TEST(trsm_IIT_ERS, invalid_m) +{ + using T = TypeParam; + T ALPHA = T{2.3}; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + trsm( STORAGE, SIDE, UPLO, TRANS, DIAG, -1, N, nullptr, nullptr, LDA, nullptr, LDB); +#else + trsm( STORAGE, SIDE, UPLO, TRANS, DIAG, -1, N, &ALPHA, nullptr, LDA, nullptr, LDB); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 5 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector a = testinghelpers::get_random_matrix(0, 1, STORAGE, 'n', M, N, LDB); + std::vector b = testinghelpers::get_random_matrix(0, 1, STORAGE, 'n', M, N, LDB); + std::vector b_ref(b); + + trsm( STORAGE, SIDE, UPLO, TRANS, DIAG, -1, N, &ALPHA, a.data(), LDA, b.data(), LDB); + computediff( "B", STORAGE, M, N, b.data(), b_ref.data(), LDB ); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 5 ); +#endif +} + +/** + * @brief Test TRSM when n is negative + * when info == 6 + */ +TYPED_TEST(trsm_IIT_ERS, invalid_n) +{ + using T = TypeParam; + T ALPHA = T{2.3}; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + trsm( STORAGE, SIDE, UPLO, TRANS, DIAG, M, -1, nullptr, nullptr, LDA, nullptr, LDB); +#else + trsm( STORAGE, SIDE, UPLO, TRANS, DIAG, M, -1, &ALPHA, nullptr, LDA, nullptr, LDB); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 6 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector a = testinghelpers::get_random_matrix(0, 1, STORAGE, 'n', M, N, LDB); + std::vector b = testinghelpers::get_random_matrix(0, 1, STORAGE, 'n', M, N, LDB); + std::vector b_ref(b); + + trsm( STORAGE, SIDE, UPLO, TRANS, DIAG, M, -1, &ALPHA, a.data(), LDA, b.data(), LDB); + computediff( "B", STORAGE, M, N, b.data(), b_ref.data(), LDB ); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 6 ); +#endif +} + +/** + * @brief Test TRSM when lda is incorrect + * when info == 9 + */ +TYPED_TEST(trsm_IIT_ERS, invalid_lda) +{ + using T = TypeParam; + T ALPHA = T{2.3}; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + trsm( STORAGE, SIDE, UPLO, TRANS, DIAG, M, N, nullptr, nullptr, LDA - 1, nullptr, LDB); +#else + trsm( STORAGE, SIDE, UPLO, TRANS, DIAG, M, N, &ALPHA, nullptr, LDA - 1, nullptr, LDB); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 9 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector a = testinghelpers::get_random_matrix(0, 1, STORAGE, 'n', M, N, LDB); + std::vector b = testinghelpers::get_random_matrix(0, 1, STORAGE, 'n', M, N, LDB); + std::vector b_ref(b); + + trsm( STORAGE, SIDE, UPLO, TRANS, DIAG, M, N, &ALPHA, a.data(), LDA - 1, b.data(), LDB); + computediff( "B", STORAGE, M, N, b.data(), b_ref.data(), LDB ); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 9 ); +#endif +} + +/** + * @brief Test TRSM when ldb is incorrect + * when info == 11 + */ +TYPED_TEST(trsm_IIT_ERS, invalid_ldb) +{ + using T = TypeParam; + T ALPHA = T{2.3}; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + trsm( STORAGE, SIDE, UPLO, TRANS, DIAG, M, N, nullptr, nullptr, LDA, nullptr, LDB - 1); +#else + trsm( STORAGE, SIDE, UPLO, TRANS, DIAG, M, N, &ALPHA, nullptr, LDA, nullptr, LDB - 1); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 11 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector a = testinghelpers::get_random_matrix(0, 1, STORAGE, 'n', M, N, LDB); + std::vector b = testinghelpers::get_random_matrix(0, 1, STORAGE, 'n', M, N, LDB); + std::vector b_ref(b); + + trsm( STORAGE, SIDE, UPLO, TRANS, DIAG, M, N, &ALPHA, a.data(), LDA, b.data(), LDB - 1); + computediff( "B", STORAGE, M, N, b.data(), b_ref.data(), LDB ); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 11 ); +#endif +} + + +/* + Early Return Scenarios(ERS) : + + The TRSM API is expected to return early in the following cases: + + 1. When m == 0. + 2. When n == 0. + 3. When alpha == 0, set B to 0 only. + +*/ + +/** + * @brief Test TRSM when M is zero + */ +TYPED_TEST(trsm_IIT_ERS, m_eq_zero) +{ + using T = TypeParam; + T ALPHA = T{2.3}; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + trsm( STORAGE, SIDE, UPLO, TRANS, DIAG, 0, N, nullptr, nullptr, LDA, nullptr, LDB); +#else + trsm( STORAGE, SIDE, UPLO, TRANS, DIAG, 0, N, &ALPHA, nullptr, LDA, nullptr, LDB); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector a = testinghelpers::get_random_matrix(0, 1, STORAGE, 'n', M, N, LDB); + std::vector b = testinghelpers::get_random_matrix(0, 1, STORAGE, 'n', M, N, LDB); + std::vector b_ref(b); + + trsm( STORAGE, SIDE, UPLO, TRANS, DIAG, 0, N, &ALPHA, a.data(), LDA, b.data(), LDB ); + computediff( "B", STORAGE, M, N, b.data(), b_ref.data(), LDB ); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif +} + +/** + * @brief Test TRSM when N is zero + */ +TYPED_TEST(trsm_IIT_ERS, n_eq_zero) +{ + using T = TypeParam; + T ALPHA = T{2.3}; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. +#if defined(TEST_BLAS_LIKE) + trsm( STORAGE, SIDE, UPLO, TRANS, DIAG, M, 0, nullptr, nullptr, LDA, nullptr, LDB); +#else + trsm( STORAGE, SIDE, UPLO, TRANS, DIAG, M, 0, &ALPHA, nullptr, LDA, nullptr, LDB); +#endif +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector a = testinghelpers::get_random_matrix(0, 1, STORAGE, 'n', M, N, LDB); + std::vector b = testinghelpers::get_random_matrix(0, 1, STORAGE, 'n', M, N, LDB); + std::vector b_ref(b); + + trsm( STORAGE, SIDE, UPLO, TRANS, DIAG, M, 0, &ALPHA, a.data(), LDA, b.data(), LDB ); + computediff( "B", STORAGE, M, N, b.data(), b_ref.data(), LDB ); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif +} + +/** + * @brief Test TRSM when alpha is zero + */ +TYPED_TEST(trsm_IIT_ERS, alpha_eq_zero) +{ + using T = TypeParam; + T ALPHA; + testinghelpers::initzero( ALPHA ); + + std::vector b = testinghelpers::get_random_matrix(0, 1, STORAGE, 'n', M, N, LDB); + std::vector b2(b); + std::vector zero_mat = testinghelpers::get_random_matrix(0, 0, STORAGE, 'n', M, N, LDB); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + trsm( STORAGE, SIDE, UPLO, TRANS, DIAG, M, N, &ALPHA, nullptr, LDA, b2.data(), LDB); + computediff( "B", STORAGE, M, N, b2.data(), zero_mat.data(), LDB ); +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif + + // Test with all arguments correct except for the value we are choosing to test. + std::vector a = testinghelpers::get_random_matrix(0, 1, STORAGE, 'n', M, N, LDB); + + trsm( STORAGE, SIDE, UPLO, TRANS, DIAG, M, N, &ALPHA, a.data(), LDA, b.data(), LDB ); + computediff( "B", STORAGE, M, N, b.data(), zero_mat.data(), LDB ); + +#ifdef CAN_TEST_INFO_VALUE + info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif +} + +#endif diff --git a/gtestsuite/testsuite/level3/trsm/ctrsm/ctrsm_evt.cpp b/gtestsuite/testsuite/level3/trsm/ctrsm/ctrsm_evt.cpp new file mode 100644 index 0000000000..c452044a63 --- /dev/null +++ b/gtestsuite/testsuite/level3/trsm/ctrsm/ctrsm_evt.cpp @@ -0,0 +1,169 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "level3/trsm/test_trsm.h" + +class ctrsmEVT : + public ::testing::TestWithParam> {}; // EVT test for B + + +TEST_P( ctrsmEVT, API ) +{ + using T = scomplex; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // matrix storage format(row major, column major) + char storage = std::get<0>(GetParam()); + // specifies matrix A appears left or right in + // the matrix multiplication + char side = std::get<1>(GetParam()); + // specifies upper or lower triangular part of A is used + char uploa = std::get<2>(GetParam()); + // denotes whether matrix a is n,c,t,h + char transa = std::get<3>(GetParam()); + // denotes whether matrix a in unit or non-unit diagonal + char diaga = std::get<4>(GetParam()); + // matrix size m + gtint_t m = std::get<5>(GetParam()); + // matrix size n + gtint_t n = std::get<6>(GetParam()); + // specifies alpha value + T alpha = std::get<7>(GetParam()); + // lda, ldb, ldc increments. + // If increments are zero, then the array size matches the matrix size. + // If increments are nonnegative, the array size is bigger than the matrix size. + gtint_t lda_inc = std::get<8>(GetParam()); + gtint_t ldb_inc = std::get<9>(GetParam()); + + EVT_TYPE a_init = std::get<10>(GetParam()); + EVT_TYPE b_init = std::get<11>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite trsm.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (m == 0 || n == 0 || alpha == testinghelpers::ZERO()) + thresh = 0.0; + else + if ( side == 'l' || side == 'L' ) + thresh = 3*m*testinghelpers::getEpsilon(); + else + thresh = 3*n*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_trsm( storage, side, uploa, transa, diaga, m, n, alpha, lda_inc, ldb_inc, thresh, a_init, b_init ); +} + +/** + * @brief Test CTRSM for extreme values + * Code paths taken for: + * TRSV -> 1 + * AVX2 Small -> 301, 324 + * Native -> 1051, 1176 + */ +INSTANTIATE_TEST_SUITE_P( + evt, + ctrsmEVT, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('l','r'), // side l:left, r:right + ::testing::Values('u','l'), // uplo u:upper, l:lower + ::testing::Values('n','c', 't'), // transa + ::testing::Values('n','u'), // diaga , n=nonunit u=unit + ::testing::Values(1, 301, 1051), // m + ::testing::Values(1, 324, 1176), // n + ::testing::Values(scomplex{-2.4, 2.0}, + scomplex{-0.0, 2.3}, + scomplex{-2.4, 0.0}, + scomplex{ 0.0, 0.0}), // alpha + ::testing::Values(gtint_t(0)), // increment to the leading dim of a + ::testing::Values(gtint_t(0)), // increment to the leading dim of b + ::testing::Values(NO_EVT, NaN, INF, NaN_INF, DIAG_NaN, DIAG_INF, + NEG_INF, NEG_NaN), // EVT test for A + ::testing::Values(NO_EVT, NaN, INF, NaN_INF, NEG_INF, NEG_NaN) // EVT test for B + ), + ::trsmEVTPrint() + ); + +/** + * @brief Test CTRSM with differnt values of alpha + * code paths covered: + * TRSV -> 1 + * TRSM_AVX2_small -> 3 + * TRSM_NATIVE -> 1001 + */ +INSTANTIATE_TEST_SUITE_P( + Alpha, + ctrsmEVT, + ::testing::Combine( + ::testing::Values('c'), // storage format + ::testing::Values('l','r'), // side l:left, r:right + ::testing::Values('u','l'), // uplo u:upper, l:lower + ::testing::Values('n', 'c', 't'), // transa + ::testing::Values('n','u'), // diaga , n=nonunit u=unit + ::testing::Values(1, 3, 1001), // n + ::testing::Values(1, 3, 1001), // m + ::testing::Values(scomplex{NAN, -2.0}, + scomplex{-2.0, NAN}, + scomplex{INFINITY, 3.1f}, + scomplex{NAN, -INFINITY}), // alpha + ::testing::Values(gtint_t(0), gtint_t(5)), // increment to the leading dim of a + ::testing::Values(gtint_t(0), gtint_t(3)), // increment to the leading dim of b + ::testing::Values(NO_EVT), // EVT test for A + ::testing::Values(NO_EVT) // EVT test for B + ), + ::trsmEVTPrint() + ); diff --git a/gtestsuite/testsuite/level3/trsm/ctrsm/ctrsm_generic.cpp b/gtestsuite/testsuite/level3/trsm/ctrsm/ctrsm_generic.cpp new file mode 100644 index 0000000000..1ec39b767f --- /dev/null +++ b/gtestsuite/testsuite/level3/trsm/ctrsm/ctrsm_generic.cpp @@ -0,0 +1,194 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "level3/trsm/test_trsm.h" + +class ctrsmGeneric : + public ::testing::TestWithParam> {}; // ldb_inc + +TEST_P( ctrsmGeneric, API ) +{ + using T = scomplex; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // matrix storage format(row major, column major) + char storage = std::get<0>(GetParam()); + // specifies matrix A appears left or right in + // the matrix multiplication + char side = std::get<1>(GetParam()); + // specifies upper or lower triangular part of A is used + char uploa = std::get<2>(GetParam()); + // denotes whether matrix a is n,c,t,h + char transa = std::get<3>(GetParam()); + // denotes whether matrix a in unit or non-unit diagonal + char diaga = std::get<4>(GetParam()); + // matrix size m + gtint_t m = std::get<5>(GetParam()); + // matrix size n + gtint_t n = std::get<6>(GetParam()); + // specifies alpha value + T alpha = std::get<7>(GetParam()); + // lda, ldb, ldc increments. + // If increments are zero, then the array size matches the matrix size. + // If increments are nonnegative, the array size is bigger than the matrix size. + gtint_t lda_inc = std::get<8>(GetParam()); + gtint_t ldb_inc = std::get<9>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite trsm.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (m == 0 || n == 0 || alpha == testinghelpers::ZERO()) + thresh = 0.0; + else + if ( side == 'l' || side == 'L' ) + thresh = 3*m*testinghelpers::getEpsilon(); + else + thresh = 3*n*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_trsm( storage, side, uploa, transa, diaga, m, n, alpha, lda_inc, ldb_inc, thresh ); +} + +/** + * @brief Test CTRSM native path, which starts from size 1001 for BLAS api + * and starts from size 0 for BLIS api. + */ +INSTANTIATE_TEST_SUITE_P( + Native, + ctrsmGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('l','r'), // side l:left, r:right + ::testing::Values('u','l'), // uplo u:upper, l:lower + ::testing::Values('n','c','t'), // transa + ::testing::Values('n','u'), // diaga , n=nonunit u=unit + ::testing::Values(1, 112, 1200), // m + ::testing::Values(1, 154, 1317), // n + ::testing::Values(scomplex{2.0,-1.0}), // alpha + ::testing::Values(gtint_t(31)), // increment to the leading dim of a + ::testing::Values(gtint_t(45)) // increment to the leading dim of b + ), + ::trsmGenericPrint() + ); + +/** + * @brief Test CTRSM small avx2 path all fringe cases + * Kernel size for avx2 small path is 8x3, testing in range of + * 1 to 8 ensures all finge cases are being tested. + */ +INSTANTIATE_TEST_SUITE_P( + Small_AVX2_fringe, + ctrsmGeneric, + ::testing::Combine( + ::testing::Values('c'), // storage format + ::testing::Values('l','r'), // side l:left, r:right + ::testing::Values('u','l'), // uplo u:upper, l:lower + ::testing::Values('n', 'c', 't'), // transa + ::testing::Values('n','u'), // diaga , n=nonunit u=unit + ::testing::Range(gtint_t(1), gtint_t(9), 1), // m + ::testing::Range(gtint_t(1), gtint_t(9), 1), // n + ::testing::Values(scomplex{2.0,-3.4}), // alpha + ::testing::Values(gtint_t(58)), // increment to the leading dim of a + ::testing::Values(gtint_t(32)) // increment to the leading dim of b + ), + ::trsmGenericPrint() + ); + +/** + * @brief Test CTRSM small avx2 path, this code path is used in range 0 to 1000 + */ +INSTANTIATE_TEST_SUITE_P( + Small_AVX2, + ctrsmGeneric, + ::testing::Combine( + ::testing::Values('c'), // storage format + ::testing::Values('l','r'), // side l:left, r:right + ::testing::Values('u','l'), // uplo u:upper, l:lower + ::testing::Values('n', 'c', 't'), // transa + ::testing::Values('n','u'), // diaga , n=nonunit u=unit + ::testing::Values(17, 1000), // m + ::testing::Values(48, 1000), // n + ::testing::Values(scomplex{2.0,-3.4}), // alpha + ::testing::Values(gtint_t(85)), // increment to the leading dim of a + ::testing::Values(gtint_t(33)) // increment to the leading dim of b + ), + ::trsmGenericPrint() + ); + +/** + * @brief Test CTRSM with differnt values of alpha + * code paths covered: + * TRSV -> 1 + * TRSM_AVX2_small -> 3 + * TRSM_NATIVE -> 1001 + */ +INSTANTIATE_TEST_SUITE_P( + Alpha, + ctrsmGeneric, + ::testing::Combine( + ::testing::Values('c'), // storage format + ::testing::Values('l','r'), // side l:left, r:right + ::testing::Values('u','l'), // uplo u:upper, l:lower + ::testing::Values('n', 'c', 't'), // transa + ::testing::Values('n','u'), // diaga , n=nonunit u=unit + ::testing::Values(1, 3, 1001), // n + ::testing::Values(1, 3, 1001), // m + ::testing::Values(scomplex{2.0, 0.0}, scomplex{0.0, -10.0}, + scomplex{1.0, 0.0}, scomplex{-1.0, 0.0}), // alpha + ::testing::Values(gtint_t(0), gtint_t(45)), // increment to the leading dim of a + ::testing::Values(gtint_t(0), gtint_t(93)) // increment to the leading dim of b + ), + ::trsmGenericPrint() + ); diff --git a/gtestsuite/testsuite/level3/trsm/ctrsm_generic.cpp b/gtestsuite/testsuite/level3/trsm/ctrsm_generic.cpp deleted file mode 100644 index 85c3917a39..0000000000 --- a/gtestsuite/testsuite/level3/trsm/ctrsm_generic.cpp +++ /dev/null @@ -1,145 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include "test_trsm.h" - -class ctrsmTest : - public ::testing::TestWithParam> {}; - -TEST_P(ctrsmTest, RandomData) -{ - using T = scomplex; - //---------------------------------------------------------- - // Initialize values from the parameters passed through - // test suite instantiation (INSTANTIATE_TEST_SUITE_P). - //---------------------------------------------------------- - // matrix storage format(row major, column major) - char storage = std::get<0>(GetParam()); - // specifies matrix A appears left or right in - // the matrix multiplication - char side = std::get<1>(GetParam()); - // specifies upper or lower triangular part of A is used - char uploa = std::get<2>(GetParam()); - // denotes whether matrix a is n,c,t,h - char transa = std::get<3>(GetParam()); - // denotes whether matrix a in unit or non-unit diagonal - char diaga = std::get<4>(GetParam()); - // matrix size m - gtint_t m = std::get<5>(GetParam()); - // matrix size n - gtint_t n = std::get<6>(GetParam()); - // specifies alpha value - T alpha = std::get<7>(GetParam()); - // lda, ldb, ldc increments. - // If increments are zero, then the array size matches the matrix size. - // If increments are nonnegative, the array size is bigger than the matrix size. - gtint_t lda_inc = std::get<8>(GetParam()); - gtint_t ldb_inc = std::get<9>(GetParam()); - - // Set the threshold for the errors: - double thresh = (std::max)(m, n)*testinghelpers::getEpsilon(); - - //---------------------------------------------------------- - // Call test body using these parameters - //---------------------------------------------------------- - test_trsm( storage, side, uploa, transa, diaga, m, n, alpha, lda_inc, ldb_inc, thresh ); -} - -class ctrsmTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char side = std::get<1>(str.param); - char uploa = std::get<2>(str.param); - char transa = std::get<3>(str.param); - char diaga = std::get<4>(str.param); - gtint_t m = std::get<5>(str.param); - gtint_t n = std::get<6>(str.param); - scomplex alpha = std::get<7>(str.param); - gtint_t lda_inc = std::get<8>(str.param); - gtint_t ldb_inc = std::get<9>(str.param); -#ifdef TEST_BLAS - std::string str_name = "ctrsm_"; -#elif TEST_CBLAS - std::string str_name = "cblas_ctrsm"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_ctrsm"; -#endif - str_name = str_name + "_" + sfm+sfm+sfm; - str_name = str_name + "_" + side + uploa + transa; - str_name = str_name + "_d" + diaga; - str_name = str_name + "_" + std::to_string(m); - str_name = str_name + "_" + std::to_string(n); - std::string alpha_str = ( alpha.real > 0) ? std::to_string(int(alpha.real)) : ("m" + std::to_string(int(std::abs(alpha.real)))); - alpha_str = alpha_str + "pi" + (( alpha.imag > 0) ? std::to_string(int(alpha.imag)) : ("m" + std::to_string(int(std::abs(alpha.imag))))); - str_name = str_name + "_a" + alpha_str; - str_name = str_name + "_" + std::to_string(lda_inc); - str_name = str_name + "_" + std::to_string(ldb_inc); - return str_name; - } -}; - -// Black box testing. -INSTANTIATE_TEST_SUITE_P( - Blackbox, - ctrsmTest, - ::testing::Combine( - ::testing::Values('c' -#ifndef TEST_BLAS - ,'r' -#endif - ), // storage format - ::testing::Values('l','r'), // side l:left, r:right - ::testing::Values('u','l'), // uplo u:upper, l:lower - ::testing::Values('n','c','t'), // transa - ::testing::Values('n','u'), // diaga , n=nonunit u=unit - ::testing::Range(gtint_t(10), gtint_t(31), 10), // m - ::testing::Range(gtint_t(10), gtint_t(31), 10), // n - ::testing::Values(scomplex{2.0,-1.0}), // alpha - ::testing::Values(gtint_t(0), gtint_t(3)), // increment to the leading dim of a - ::testing::Values(gtint_t(0), gtint_t(4)) // increment to the leading dim of b - ), - ::ctrsmTestPrint() - ); diff --git a/gtestsuite/testsuite/level3/trsm/dtrsm_generic.cpp b/gtestsuite/testsuite/level3/trsm/dtrsm/dtrsm_evt.cpp similarity index 59% rename from gtestsuite/testsuite/level3/trsm/dtrsm_generic.cpp rename to gtestsuite/testsuite/level3/trsm/dtrsm/dtrsm_evt.cpp index 87b841defd..0bc0c2a0f7 100644 --- a/gtestsuite/testsuite/level3/trsm/dtrsm_generic.cpp +++ b/gtestsuite/testsuite/level3/trsm/dtrsm/dtrsm_evt.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -33,9 +33,9 @@ */ #include -#include "test_trsm.h" +#include "level3/trsm/test_trsm.h" -class dtrsmTest : +class dtrsmEVT : public ::testing::TestWithParam> {}; + gtint_t, + EVT_TYPE, + EVT_TYPE>> {}; + -TEST_P(dtrsmTest, RandomData) +TEST_P( dtrsmEVT, API ) { using T = double; //---------------------------------------------------------- @@ -77,56 +80,42 @@ TEST_P(dtrsmTest, RandomData) gtint_t lda_inc = std::get<8>(GetParam()); gtint_t ldb_inc = std::get<9>(GetParam()); + EVT_TYPE a_init = std::get<10>(GetParam()); + EVT_TYPE b_init = std::get<11>(GetParam()); + // Set the threshold for the errors: - double thresh = (std::max)(m, n)*testinghelpers::getEpsilon(); + // Check gtestsuite trsm.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (m == 0 || n == 0 || alpha == testinghelpers::ZERO()) + thresh = 0.0; + else + if ( side == 'l' || side == 'L' ) + thresh = 3*m*testinghelpers::getEpsilon(); + else + thresh = 3*n*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_trsm( storage, side, uploa, transa, diaga, m, n, alpha, lda_inc, ldb_inc, thresh ); + test_trsm( storage, side, uploa, transa, diaga, m, n, alpha, lda_inc, ldb_inc, thresh, a_init, b_init ); } -class dtrsmTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char side = std::get<1>(str.param); - char uploa = std::get<2>(str.param); - char transa = std::get<3>(str.param); - char diaga = std::get<4>(str.param); - gtint_t m = std::get<5>(str.param); - gtint_t n = std::get<6>(str.param); - double alpha = std::get<7>(str.param); - gtint_t lda_inc = std::get<8>(str.param); - gtint_t ldb_inc = std::get<9>(str.param); -#ifdef TEST_BLAS - std::string str_name = "dtrsm_"; -#elif TEST_CBLAS - std::string str_name = "cblas_dtrsm"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_dtrsm"; -#endif - str_name = str_name + "_" + sfm+sfm+sfm; - str_name = str_name + "_" + side + uploa + transa; - str_name = str_name + "_d" + diaga; - str_name = str_name + "_" + std::to_string(m); - str_name = str_name + "_" + std::to_string(n); - std::string alpha_str = ( alpha > 0) ? std::to_string(int(alpha)) : "m" + std::to_string(int(std::abs(alpha))); - str_name = str_name + "_a" + alpha_str; - str_name = str_name + "_" + std::to_string(lda_inc); - str_name = str_name + "_" + std::to_string(ldb_inc); - return str_name; - } -}; - -// Black box testing. +/** + * @brief Test DTRSM for extreme values + * Code paths taken for: + * TRSV -> 1 + * AVX2 Small -> 2 + * AVX512 Small -> 301, 324 + * Native -> 1551, 1676 + */ INSTANTIATE_TEST_SUITE_P( - Blackbox, - dtrsmTest, + Native, + dtrsmEVT, ::testing::Combine( ::testing::Values('c' -#ifndef TEST_BLAS +#ifndef TEST_BLAS_LIKE ,'r' #endif ), // storage format @@ -134,11 +123,13 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values('u','l'), // uplo u:upper, l:lower ::testing::Values('n','t'), // transa ::testing::Values('n','u'), // diaga , n=nonunit u=unit - ::testing::Range(gtint_t(10), gtint_t(11), 10), // m - ::testing::Range(gtint_t(10), gtint_t(11), 10), // n - ::testing::Values( 1.0, -2.0), // alpha - ::testing::Values(gtint_t(0), gtint_t(5)), // increment to the leading dim of a - ::testing::Values(gtint_t(0), gtint_t(3)) // increment to the leading dim of b + ::testing::Values(1, 2, 301, 1551), // m + ::testing::Values(1, 2, 324, 1676), // n + ::testing::Values(-2.4, 0), // alpha + ::testing::Values(gtint_t(0)), // increment to the leading dim of a + ::testing::Values(gtint_t(0)), // increment to the leading dim of b + ::testing::Values(NO_EVT, NaN, INF, NaN_INF, DIAG_NaN, DIAG_INF),// EVT test for A + ::testing::Values(NO_EVT, NaN, INF, NaN_INF) // EVT test for B ), - ::dtrsmTestPrint() + ::trsmEVTPrint() ); diff --git a/gtestsuite/testsuite/level3/trsm/dtrsm/dtrsm_generic.cpp b/gtestsuite/testsuite/level3/trsm/dtrsm/dtrsm_generic.cpp new file mode 100644 index 0000000000..05fa45c426 --- /dev/null +++ b/gtestsuite/testsuite/level3/trsm/dtrsm/dtrsm_generic.cpp @@ -0,0 +1,240 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "level3/trsm/test_trsm.h" + +class dtrsmGeneric : + public ::testing::TestWithParam> {}; // ldb_inc + +TEST_P( dtrsmGeneric, API ) +{ + using T = double; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // matrix storage format(row major, column major) + char storage = std::get<0>(GetParam()); + // specifies matrix A appears left or right in + // the matrix multiplication + char side = std::get<1>(GetParam()); + // specifies upper or lower triangular part of A is used + char uploa = std::get<2>(GetParam()); + // denotes whether matrix a is n,c,t,h + char transa = std::get<3>(GetParam()); + // denotes whether matrix a in unit or non-unit diagonal + char diaga = std::get<4>(GetParam()); + // matrix size m + gtint_t m = std::get<5>(GetParam()); + // matrix size n + gtint_t n = std::get<6>(GetParam()); + // specifies alpha value + T alpha = std::get<7>(GetParam()); + // lda, ldb, ldc increments. + // If increments are zero, then the array size matches the matrix size. + // If increments are nonnegative, the array size is bigger than the matrix size. + gtint_t lda_inc = std::get<8>(GetParam()); + gtint_t ldb_inc = std::get<9>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite trsm.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (m == 0 || n == 0 || alpha == testinghelpers::ZERO()) + thresh = 0.0; + else + if ( side == 'l' || side == 'L' ) + thresh = 3*m*testinghelpers::getEpsilon(); + else + thresh = 3*n*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_trsm( storage, side, uploa, transa, diaga, m, n, alpha, lda_inc, ldb_inc, thresh ); +} + +/** + * @brief Test DTRSM native path, which starts from size 1500 for BLAS api + * and starts from size 0 for BLIS api. + */ +INSTANTIATE_TEST_SUITE_P( + Native, + dtrsmGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('l','r'), // side l:left, r:right + ::testing::Values('u','l'), // uplo u:upper, l:lower + ::testing::Values('n','t'), // transa + ::testing::Values('n','u'), // diaga , n=nonunit u=unit + ::testing::Values(1, 2, 112, 1551), // m + ::testing::Values(1, 2, 154, 1676), // n + ::testing::Values(-2.4), // alpha + ::testing::Values(gtint_t(5)), // increment to the leading dim of a + ::testing::Values(gtint_t(3)) // increment to the leading dim of b + ), + ::trsmGenericPrint() + ); + +/** + * @brief Test DTRSM small avx2 path all fringe cases + * Kernel size for avx2 small path is 6x8, testing in range of + * 1 to 8 ensures all finge cases are being tested. + */ +INSTANTIATE_TEST_SUITE_P( + Small_AVX2_fringe, + dtrsmGeneric, + ::testing::Combine( + ::testing::Values('c'), // storage format + ::testing::Values('l','r'), // side l:left, r:right + ::testing::Values('u','l'), // uplo u:upper, l:lower + ::testing::Values('n','t'), // transa + ::testing::Values('n','u'), // diaga , n=nonunit u=unit + ::testing::Range(gtint_t(1), gtint_t(9), 1), // m + ::testing::Range(gtint_t(1), gtint_t(9), 1), // n + ::testing::Values(-2.4), // alpha + ::testing::Values(gtint_t(5)), // increment to the leading dim of a + ::testing::Values(gtint_t(3)) // increment to the leading dim of b + ), + ::trsmGenericPrint() + ); + +/** + * @brief Test DTRSM small avx2 path which is used in + * range [0, 50] for genoa and [0, 1499] for milan + */ +INSTANTIATE_TEST_SUITE_P( + Small_AVX2, + dtrsmGeneric, + ::testing::Combine( + ::testing::Values('c'), // storage format + ::testing::Values('l','r'), // side l:left, r:right + ::testing::Values('u','l'), // uplo u:upper, l:lower + ::testing::Values('n','t'), // transa + ::testing::Values('n','u'), // diaga , n=nonunit u=unit + ::testing::Values(17, 110, 51, 1499), // m + ::testing::Values(17, 48 , 51, 1499), // n + ::testing::Values(-2.4), // alpha + ::testing::Values(gtint_t(5)), // increment to the leading dim of a + ::testing::Values(gtint_t(3)) // increment to the leading dim of b + ), + ::trsmGenericPrint() + ); + +/** + * @brief Test DTRSM small avx512 path all fringe cases + * small avx512 is used in range [51, 1499] + * Kernel size for avx512 small path is 8x8, therefore + * testing in range of 51 to 58 covers all fringe cases. + */ +INSTANTIATE_TEST_SUITE_P( + Small_AVX512_fringe, + dtrsmGeneric, + ::testing::Combine( + ::testing::Values('c'), // storage format + ::testing::Values('l','r'), // side l:left, r:right + ::testing::Values('u','l'), // uplo u:upper, l:lower + ::testing::Values('n','t'), // transa + ::testing::Values('n','u'), // diaga , n=nonunit u=unit + ::testing::Range(gtint_t(51), gtint_t(59), 1), // m + ::testing::Range(gtint_t(51), gtint_t(59), 1), // n + ::testing::Values(-2.4), // alpha + ::testing::Values(gtint_t(5)), // increment to the leading dim of a + ::testing::Values(gtint_t(3)) // increment to the leading dim of b + ), + ::trsmGenericPrint() + ); + +/** + * @brief Test DTRSM small avx512 path + * small avx512 is used in range [51, 1499] + */ +INSTANTIATE_TEST_SUITE_P( + Small_AVX512, + dtrsmGeneric, + ::testing::Combine( + ::testing::Values('c'), // storage format + ::testing::Values('l','r'), // side l:left, r:right + ::testing::Values('u','l'), // uplo u:upper, l:lower + ::testing::Values('n','t'), // transa + ::testing::Values('n','u'), // diaga , n=nonunit u=unit + ::testing::Values(51, 410, 1499), // n + ::testing::Values(51, 531, 1499), // m + ::testing::Values(-2.4), // alpha + ::testing::Values(gtint_t(5)), // increment to the leading dim of a + ::testing::Values(gtint_t(3)) // increment to the leading dim of b + ), + ::trsmGenericPrint() + ); + +/** + * @brief Test DTRSM with differnt values of alpha + * code paths covered: + * TRSV -> 1 + * TRSM_AVX2_small -> 2 + * TRSM_AVX512_small -> 300 + * TRSM_NATIVE -> 1500 + */ +INSTANTIATE_TEST_SUITE_P( + Alpha, + dtrsmGeneric, + ::testing::Combine( + ::testing::Values('c'), // storage format + ::testing::Values('l','r'), // side l:left, r:right + ::testing::Values('u','l'), // uplo u:upper, l:lower + ::testing::Values('n','t'), // transa + ::testing::Values('n','u'), // diaga , n=nonunit u=unit + ::testing::Values(1, 2, 300, 1500), // n + ::testing::Values(1, 2, 300, 1500), // m + ::testing::Values(-2.4, 0.0, 1.0, 3.1, NAN, INFINITY), // alpha + ::testing::Values(gtint_t(0), gtint_t(5)), // increment to the leading dim of a + ::testing::Values(gtint_t(0), gtint_t(3)) // increment to the leading dim of b + ), + ::trsmGenericPrint() + ); diff --git a/gtestsuite/testsuite/level3/trsm/strsm/strsm_evt.cpp b/gtestsuite/testsuite/level3/trsm/strsm/strsm_evt.cpp new file mode 100644 index 0000000000..4c41ce7080 --- /dev/null +++ b/gtestsuite/testsuite/level3/trsm/strsm/strsm_evt.cpp @@ -0,0 +1,159 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "level3/trsm/test_trsm.h" + +class strsmEVT : + public ::testing::TestWithParam> {}; // EVT type for B + + +TEST_P( strsmEVT, API ) +{ + using T = float; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // matrix storage format(row major, column major) + char storage = std::get<0>(GetParam()); + // specifies matrix A appears left or right in + // the matrix multiplication + char side = std::get<1>(GetParam()); + // specifies upper or lower triangular part of A is used + char uploa = std::get<2>(GetParam()); + // denotes whether matrix a is n,c,t,h + char transa = std::get<3>(GetParam()); + // denotes whether matrix a in unit or non-unit diagonal + char diaga = std::get<4>(GetParam()); + // matrix size m + gtint_t m = std::get<5>(GetParam()); + // matrix size n + gtint_t n = std::get<6>(GetParam()); + // specifies alpha value + T alpha = std::get<7>(GetParam()); + // lda, ldb, ldc increments. + // If increments are zero, then the array size matches the matrix size. + // If increments are nonnegative, the array size is bigger than the matrix size. + gtint_t lda_inc = std::get<8>(GetParam()); + gtint_t ldb_inc = std::get<9>(GetParam()); + + EVT_TYPE a_init = std::get<10>(GetParam()); + EVT_TYPE b_init = std::get<11>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite trsm.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (m == 0 || n == 0 || alpha == testinghelpers::ZERO()) + thresh = 0.0; + else + if ( side == 'l' || side == 'L' ) + thresh = 3*m*testinghelpers::getEpsilon(); + else + thresh = 3*n*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_trsm( storage, side, uploa, transa, diaga, m, n, alpha, lda_inc, ldb_inc, thresh, a_init, b_init ); +} + +/** + * @brief Test STRSM for extreme values + * Code paths taken for: + * TRSV -> 1 + * AVX2 Small -> 301, 324 + * Native -> 1051, 1176 + */ +INSTANTIATE_TEST_SUITE_P( + Native, + strsmEVT, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('l','r'), // side l:left, r:right + ::testing::Values('u','l'), // uplo u:upper, l:lower + ::testing::Values('n','t'), // transa + ::testing::Values('n','u'), // diaga , n=nonunit u=unit + ::testing::Values(1, 301, 1051), // m + ::testing::Values(1, 324, 1176), // n + ::testing::Values(-2.4, 0.0, 1.0, -1.0), // alpha + ::testing::Values(gtint_t(0)), // increment to the leading dim of a + ::testing::Values(gtint_t(0)), // increment to the leading dim of b + ::testing::Values(NO_EVT, NaN, INF, NaN_INF, DIAG_NaN, DIAG_INF, + NEG_INF, NEG_NaN), // EVT test for A + ::testing::Values(NO_EVT, NaN, INF, NaN_INF, NEG_INF, NEG_NaN) // EVT test for B + ), + ::trsmEVTPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + Alpha, + strsmEVT, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('l','r'), // side l:left, r:right + ::testing::Values('u','l'), // uplo u:upper, l:lower + ::testing::Values('n','t'), // transa + ::testing::Values('n','u'), // diaga , n=nonunit u=unit + ::testing::Values(1, 301, 1051), // m + ::testing::Values(1, 324, 1176), // n + ::testing::Values(NAN, INFINITY, -INFINITY), // alpha + ::testing::Values(gtint_t(0)), // increment to the leading dim of a + ::testing::Values(gtint_t(0)), // increment to the leading dim of b + ::testing::Values(NO_EVT), // EVT test for A + ::testing::Values(NO_EVT) // EVT test for B + ), + ::trsmEVTPrint() + ); diff --git a/gtestsuite/testsuite/level3/trsm/strsm/strsm_generic.cpp b/gtestsuite/testsuite/level3/trsm/strsm/strsm_generic.cpp new file mode 100644 index 0000000000..4234277bd3 --- /dev/null +++ b/gtestsuite/testsuite/level3/trsm/strsm/strsm_generic.cpp @@ -0,0 +1,194 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "level3/trsm/test_trsm.h" + +class strsmGeneric : + public ::testing::TestWithParam> {}; // ldb_inc + +TEST_P( strsmGeneric, API ) +{ + using T = float; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // matrix storage format(row major, column major) + char storage = std::get<0>(GetParam()); + // specifies matrix A appears left or right in + // the matrix multiplication + char side = std::get<1>(GetParam()); + // specifies upper or lower triangular part of A is used + char uploa = std::get<2>(GetParam()); + // denotes whether matrix a is n,c,t,h + char transa = std::get<3>(GetParam()); + // denotes whether matrix a in unit or non-unit diagonal + char diaga = std::get<4>(GetParam()); + // matrix size m + gtint_t m = std::get<5>(GetParam()); + // matrix size n + gtint_t n = std::get<6>(GetParam()); + // specifies alpha value + T alpha = std::get<7>(GetParam()); + // lda, ldb, ldc increments. + // If increments are zero, then the array size matches the matrix size. + // If increments are nonnegative, the array size is bigger than the matrix size. + gtint_t lda_inc = std::get<8>(GetParam()); + gtint_t ldb_inc = std::get<9>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite trsm.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (m == 0 || n == 0 || alpha == testinghelpers::ZERO()) + thresh = 0.0; + else + if ( side == 'l' || side == 'L' ) + thresh = 3*m*testinghelpers::getEpsilon(); + else + thresh = 3*n*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_trsm( storage, side, uploa, transa, diaga, m, n, alpha, lda_inc, ldb_inc, thresh ); +} + +/** + * @brief Test STRSM native path, which starts from size 1000 for BLAS api + * and starts from size 0 for BLIS api. + */ +INSTANTIATE_TEST_SUITE_P( + Native, + strsmGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('l','r'), // side l:left, r:right + ::testing::Values('u','l'), // uplo u:upper, l:lower + ::testing::Values('n','t'), // transa + ::testing::Values('n','u'), // diaga , n=nonunit u=unit + ::testing::Values(1, 2, 112, 1200), // m + ::testing::Values(1, 2, 154, 1317), // n + ::testing::Values(-2.0f), // alpha + ::testing::Values(gtint_t(45)), // increment to the leading dim of a + ::testing::Values(gtint_t(38)) // increment to the leading dim of b + ), + ::trsmGenericPrint() + ); + +/** + * @brief Test STRSM small avx2 path all fringe cases + * Kernel size for avx2 small path is 16x6, testing in range of + * 1 to 16 ensures all finge cases are being tested. + */ +INSTANTIATE_TEST_SUITE_P( + Small_AVX2_fringe, + strsmGeneric, + ::testing::Combine( + ::testing::Values('c'), // storage format + ::testing::Values('l','r'), // side l:left, r:right + ::testing::Values('u','l'), // uplo u:upper, l:lower + ::testing::Values('n','t'), // transa + ::testing::Values('n','u'), // diaga , n=nonunit u=unit + ::testing::Range(gtint_t(1), gtint_t(17), 1), // m + ::testing::Range(gtint_t(1), gtint_t(17), 1), // n + ::testing::Values(-2.4f), // alpha + ::testing::Values(gtint_t(58)), // increment to the leading dim of a + ::testing::Values(gtint_t(31)) // increment to the leading dim of b + ), + ::trsmGenericPrint() + ); + + +/** + * @brief Test STRSM small avx2 path, this code path is used in range 0 to 1000 + */ +INSTANTIATE_TEST_SUITE_P( + Small_AVX2, + strsmGeneric, + ::testing::Combine( + ::testing::Values('c'), // storage format + ::testing::Values('l','r'), // side l:left, r:right + ::testing::Values('u','l'), // uplo u:upper, l:lower + ::testing::Values('n','t'), // transa + ::testing::Values('n','u'), // diaga , n=nonunit u=unit + ::testing::Values(17, 110, 51, 1000), // m + ::testing::Values(17, 48 , 51, 1000), // n + ::testing::Values(-2.4f), // alpha + ::testing::Values(gtint_t(95)), // increment to the leading dim of a + ::testing::Values(gtint_t(83)) // increment to the leading dim of b + ), + ::trsmGenericPrint() + ); + + +/** + * @brief Test STRSM with differnt values of alpha + * code paths covered: + * TRSV -> 1 + * TRSM_AVX2_small -> 3 + * TRSM_NATIVE -> 1001 + */ +INSTANTIATE_TEST_SUITE_P( + Alpha, + strsmGeneric, + ::testing::Combine( + ::testing::Values('c'), // storage format + ::testing::Values('l','r'), // side l:left, r:right + ::testing::Values('u','l'), // uplo u:upper, l:lower + ::testing::Values('n','t'), // transa + ::testing::Values('n','u'), // diaga , n=nonunit u=unit + ::testing::Values(1, 3, 1001), // n + ::testing::Values(1, 3, 1001), // m + ::testing::Values(-2.4f, 0.0f, 1.0f, 3.1f), // alpha + ::testing::Values(gtint_t(0), gtint_t(35)), // increment to the leading dim of a + ::testing::Values(gtint_t(0), gtint_t(39)) // increment to the leading dim of b + ), + ::trsmGenericPrint() + ); diff --git a/gtestsuite/testsuite/level3/trsm/strsm_generic.cpp b/gtestsuite/testsuite/level3/trsm/strsm_generic.cpp deleted file mode 100644 index 2e197c104f..0000000000 --- a/gtestsuite/testsuite/level3/trsm/strsm_generic.cpp +++ /dev/null @@ -1,144 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include "test_trsm.h" - -class strsmTest : - public ::testing::TestWithParam> {}; - -TEST_P(strsmTest, RandomData) -{ - using T = float; - //---------------------------------------------------------- - // Initialize values from the parameters passed through - // test suite instantiation (INSTANTIATE_TEST_SUITE_P). - //---------------------------------------------------------- - // matrix storage format(row major, column major) - char storage = std::get<0>(GetParam()); - // specifies matrix A appears left or right in - // the matrix multiplication - char side = std::get<1>(GetParam()); - // specifies upper or lower triangular part of A is used - char uploa = std::get<2>(GetParam()); - // denotes whether matrix a is n,c,t,h - char transa = std::get<3>(GetParam()); - // denotes whether matrix a in unit or non-unit diagonal - char diaga = std::get<4>(GetParam()); - // matrix size m - gtint_t m = std::get<5>(GetParam()); - // matrix size n - gtint_t n = std::get<6>(GetParam()); - // specifies alpha value - T alpha = std::get<7>(GetParam()); - // lda, ldb, ldc increments. - // If increments are zero, then the array size matches the matrix size. - // If increments are nonnegative, the array size is bigger than the matrix size. - gtint_t lda_inc = std::get<8>(GetParam()); - gtint_t ldb_inc = std::get<9>(GetParam()); - - // Set the threshold for the errors: - double thresh = (std::max)(m, n)*testinghelpers::getEpsilon(); - - //---------------------------------------------------------- - // Call test body using these parameters - //---------------------------------------------------------- - test_trsm( storage, side, uploa, transa, diaga, m, n, alpha, lda_inc, ldb_inc, thresh ); -} - -class strsmTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char side = std::get<1>(str.param); - char uploa = std::get<2>(str.param); - char transa = std::get<3>(str.param); - char diaga = std::get<4>(str.param); - gtint_t m = std::get<5>(str.param); - gtint_t n = std::get<6>(str.param); - float alpha = std::get<7>(str.param); - gtint_t lda_inc = std::get<8>(str.param); - gtint_t ldb_inc = std::get<9>(str.param); -#ifdef TEST_BLAS - std::string str_name = "strsm_"; -#elif TEST_CBLAS - std::string str_name = "cblas_strsm"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_strsm"; -#endif - str_name = str_name + "_" + sfm+sfm+sfm; - str_name = str_name + "_" + side + uploa + transa; - str_name = str_name + "_d" + diaga; - str_name = str_name + "_" + std::to_string(m); - str_name = str_name + "_" + std::to_string(n); - std::string alpha_str = ( alpha > 0) ? std::to_string(int(alpha)) : "m" + std::to_string(int(std::abs(alpha))); - str_name = str_name + "_a" + alpha_str; - str_name = str_name + "_" + std::to_string(lda_inc); - str_name = str_name + "_" + std::to_string(ldb_inc); - return str_name; - } -}; - -// Black box testing. -INSTANTIATE_TEST_SUITE_P( - Blackbox, - strsmTest, - ::testing::Combine( - ::testing::Values('c' -#ifndef TEST_BLAS - ,'r' -#endif - ), // storage format - ::testing::Values('l','r'), // side l:left, r:right - ::testing::Values('u','l'), // uplo u:upper, l:lower - ::testing::Values('n','t'), // transa - ::testing::Values('n','u'), // diaga , n=nonunit u=unit - ::testing::Range(gtint_t(10), gtint_t(11), 10), // m - ::testing::Range(gtint_t(10), gtint_t(11), 10), // n - ::testing::Values( 1.0, -2.0), // alpha - ::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of a - ::testing::Values(gtint_t(0), gtint_t(4)) // increment to the leading dim of b - ), - ::strsmTestPrint() - ); diff --git a/gtestsuite/testsuite/level3/trsm/test_trsm.h b/gtestsuite/testsuite/level3/trsm/test_trsm.h index df0502b060..ed07569c8b 100644 --- a/gtestsuite/testsuite/level3/trsm/test_trsm.h +++ b/gtestsuite/testsuite/level3/trsm/test_trsm.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -37,12 +37,157 @@ #include "trsm.h" #include "level3/ref_trsm.h" #include "inc/check_error.h" +#include "common/testing_helpers.h" #include #include +// ENUM for extreme value testing +typedef enum +{ + ZERO, + NaN, + NEG_NaN, + INF, + NEG_INF, + NaN_INF, + DIAG_NaN, + DIAG_INF, + NO_EVT +} EVT_TYPE; + + +/** + * @brief Insert NaN/Inf in the matrix for extreme value testing + * + * @tparam T + * @param mat input matrix where NAN/Inf needs to be inserted + * @param uploa specify if input matrix in uppper or lower triangular + * @param m size of the input matrix + * @param ld leading dimension of input matrix + * @param type type of extreme value to be inserted ( EVT_TYPE ) + * @param is_a is the input matrix traingular( matrix A in TRSM ) + * @param is_diag insert extreme value in diagonal element + */ +template +void generate_NAN_INF( T* mat, char uploa, gtint_t m, gtint_t ld, EVT_TYPE type, bool is_a, bool is_diag = false) +{ + // RT contains the real type of T. + using RT = typename testinghelpers::type_info::real_type; + // inf_nan will contain either inf or nan depending on requirement + RT inf_nan = std::numeric_limits::quiet_NaN(); + + if(type == INF) + { + inf_nan = std::numeric_limits::infinity(); + } + else if (type == NEG_INF) + { + inf_nan = RT{-1} * std::numeric_limits::infinity(); + } + else if (type == NEG_NaN) + { + inf_nan = RT{-1} * std::numeric_limits::quiet_NaN(); + } + else // type == NaN + { + inf_nan = std::numeric_limits::quiet_NaN(); + } + + // exval will contain the exception value to be injected in the matrix. + T exval; + if constexpr ( testinghelpers::type_info::is_real ) exval = T{inf_nan}; + else exval = T{inf_nan, inf_nan}; + + // if size is one, then set the only element in matrix + // to inf or nan + if (m <= 1) + { + *(mat) = exval; + } + else + { + // get a random number in range of 1 to m; + gtint_t mn = (std::max)(gtint_t(1), gtint_t(rand()) % m); + if( uploa == 'l' || uploa == 'L') + { + // set one element to inf/nan in lower half of matrix + *(mat + mn + ((mn - (!is_diag)) * ld) ) = exval; + } + else + { + // set one element to inf/nan in upper half of matrix + *(mat + (mn - (!is_diag)) + (mn * ld) ) = exval; + } + } + + /* // Make All elements NaN\INF + // This test is commented out inorder to reduce the + // testing time. + // It is not needed to cover all the test cases, but + // it can be enabled in future if the need arises. + for (gtint_t i=0; i +void random_generator_with_INF_NAN( T* mat, char uploa, char storage, char trans, double from, double to, gtint_t m, +gtint_t n, gtint_t ld, EVT_TYPE type = NO_EVT, bool is_a = false ) +{ + switch( type ) + { + case ZERO: + testinghelpers::datagenerators::randomgenerators( 0, 0, storage, m, n, mat, ld); + break; + case NaN: + case INF: + testinghelpers::datagenerators::randomgenerators( from, to, storage, m, n, mat, ld); + generate_NAN_INF(mat, uploa, (std::min)(m, n), ld, type, is_a); + break; + case DIAG_INF: + case DIAG_NaN: + testinghelpers::datagenerators::randomgenerators( from, to, storage, m, n, mat, ld); + generate_NAN_INF(mat, uploa, (std::min)(m, n), ld, type, is_a, true); + break; + case NaN_INF: + testinghelpers::datagenerators::randomgenerators( from, to, storage, m, n, mat, ld); + generate_NAN_INF(mat, uploa, (std::min)(m, n), ld, type, is_a); + generate_NAN_INF(mat, uploa, (std::min)(m, n), ld, INF, is_a); + break; + case NO_EVT: + testinghelpers::datagenerators::randomgenerators( from, to, storage, m, n, mat, ld); + break; + default: ; + } +} + template void test_trsm( char storage, char side, char uploa, char transa, char diaga, - gtint_t m, gtint_t n, T alpha, gtint_t lda_inc, gtint_t ldb_inc, double thresh ) + gtint_t m, gtint_t n, T alpha, gtint_t lda_inc, gtint_t ldb_inc, double thresh, + EVT_TYPE a_init = NO_EVT, EVT_TYPE b_init = NO_EVT) { gtint_t mn; testinghelpers::set_dim_with_side( side, m, n, &mn ); @@ -54,18 +199,43 @@ void test_trsm( char storage, char side, char uploa, char transa, char diaga, //---------------------------------------------------------- gtint_t lower = (diaga = 'n')||(diaga = 'N') ? 3 : 0; gtint_t upper = (diaga = 'n')||(diaga = 'N') ? 10 : 1; - std::vector a = testinghelpers::get_random_matrix( lower, upper, storage, transa, mn, mn, lda ); - std::vector b = testinghelpers::get_random_matrix( 3, 10, storage, 'n', m, n, ldb ); + std::vector a( testinghelpers::matsize(storage, transa, mn, mn, lda) ); + std::vector b( testinghelpers::matsize(storage, 'n', m, n, ldb) ); + srand(time(0)); + random_generator_with_INF_NAN( a.data(), uploa, storage, transa, lower, upper, mn, mn, lda, NO_EVT, true); - // Making A diagonally dominant so that the condition number is good and - // the algorithm doesn't diverge. - for (gtint_t i=0; i()) + random_generator_with_INF_NAN( b.data(), uploa, storage, 'n', 3, 10, m, n, ldb, b_init, false); + else + { + // Matrix B should not be read, only set. + testinghelpers::set_matrix( storage, m, n, b.data(), 'n', ldb, testinghelpers::aocl_extreme() ); + } + + // Create a copy of b so that we can check reference results. std::vector b_ref(b); + bool nan_inf_check = false; + // Setting the nan_inf_check boolean to true if alpha has + // Nan/Inf in it + if constexpr (testinghelpers::type_info::is_real) + { + nan_inf_check = (isnan(alpha) || isinf(alpha)); + } + else + { + nan_inf_check = (isnan(alpha.real + alpha.imag) || isinf(alpha.real + alpha.imag)); + } + nan_inf_check = ( nan_inf_check || + ((a_init != NO_EVT) && (a_init != ZERO)) || + ((b_init != NO_EVT) && (a_init != ZERO)) ); + testinghelpers::make_triangular( storage, uploa, mn, a.data(), lda ); //---------------------------------------------------------- // Call BLIS function @@ -81,5 +251,85 @@ void test_trsm( char storage, char side, char uploa, char transa, char diaga, //---------------------------------------------------------- // check component-wise error. //---------------------------------------------------------- - computediff( storage, m, n, b.data(), b_ref.data(), ldb, thresh ); + computediff( "B", storage, m, n, b.data(), b_ref.data(), ldb, thresh, nan_inf_check ); + +#ifdef CAN_TEST_INFO_VALUE + gtint_t info = bli_info_get_info_value(); + computediff( "info", info, 0 ); +#endif } + +// Test-case logger : Used to print the test-case details based on parameters +template +class trsmGenericPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char storage = std::get<0>(str.param); + char side = std::get<1>(str.param); + char uploa = std::get<2>(str.param); + char transa = std::get<3>(str.param); + char diaga = std::get<4>(str.param); + gtint_t m = std::get<5>(str.param); + gtint_t n = std::get<6>(str.param); + T alpha = std::get<7>(str.param); + gtint_t lda_inc = std::get<8>(str.param); + gtint_t ldb_inc = std::get<9>(str.param); + + std::string str_name = API_PRINT; + str_name += "_stor_" + std::string(&storage, 1); + str_name += "_side_" + std::string(&side, 1); + str_name += "_uploa_" + std::string(&uploa, 1); + str_name += "_transa_" + std::string(&transa, 1); + str_name += "_diag_" + std::string(&diaga, 1); + str_name += "_m_" + std::to_string(m); + str_name += "_n_" + std::to_string(n); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + gtint_t mn; + testinghelpers::set_dim_with_side( side, m, n, &mn ); + gtint_t lda = testinghelpers::get_leading_dimension( storage, transa, mn, mn, lda_inc ); + gtint_t ldb = testinghelpers::get_leading_dimension( storage, 'n', m, n, ldb_inc ); + str_name += "_lda_i" + std::to_string(lda_inc) + "_" + std::to_string(lda); + str_name += "_ldb_i" + std::to_string(ldb_inc) + "_" + std::to_string(ldb); + return str_name; + } +}; + +template +class trsmEVTPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char storage = std::get<0>(str.param); + char side = std::get<1>(str.param); + char uploa = std::get<2>(str.param); + char transa = std::get<3>(str.param); + char diaga = std::get<4>(str.param); + gtint_t m = std::get<5>(str.param); + gtint_t n = std::get<6>(str.param); + T alpha = std::get<7>(str.param); + gtint_t lda_inc = std::get<8>(str.param); + gtint_t ldb_inc = std::get<9>(str.param); + EVT_TYPE a_encode = std::get<10>(str.param); + EVT_TYPE b_encode = std::get<11>(str.param); + + std::string str_name = API_PRINT; + str_name += "_stor_" + std::string(&storage, 1); + str_name += "_side_" + std::string(&side, 1); + str_name += "_uploa_" + std::string(&uploa, 1); + str_name += "_transa_" + std::string(&transa, 1); + str_name += "_diag_" + std::string(&diaga, 1); + str_name += "_m_" + std::to_string(m); + str_name += "_n_" + std::to_string(n); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + gtint_t mn; + testinghelpers::set_dim_with_side( side, m, n, &mn ); + gtint_t lda = testinghelpers::get_leading_dimension( storage, transa, mn, mn, lda_inc ); + gtint_t ldb = testinghelpers::get_leading_dimension( storage, 'n', m, n, ldb_inc ); + str_name += "_lda_i" + std::to_string(lda_inc) + "_" + std::to_string(lda); + str_name += "_ldb_i" + std::to_string(ldb_inc) + "_" + std::to_string(ldb); + str_name = str_name + "_a_evt_" + std::to_string(a_encode); + str_name = str_name + "_b_evt_" + std::to_string(b_encode); + return str_name; + } +}; diff --git a/gtestsuite/testsuite/level3/trsm/trsm.h b/gtestsuite/testsuite/level3/trsm/trsm.h index bb7f0469e2..edd36fc883 100644 --- a/gtestsuite/testsuite/level3/trsm/trsm.h +++ b/gtestsuite/testsuite/level3/trsm/trsm.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -36,12 +36,14 @@ #include "blis.h" #include "common/testing_helpers.h" +#include "inc/check_error.h" /** * @brief Performs the operation: - * B := alpha*op( A )*B, or B := alpha*B*op( A ) + * op( A )*X = alpha*B, or X*op( A ) = alpha*B, * where op( A ) is one of * op( A ) = A or op( A ) = A**T or op( A ) = A**H, + * The matrix X is overwritten on B. * @param[in] storage specifies storage format used for the matrices * @param[in] side specifies if the symmetric matrix A appears left or right in the matrix multiplication @@ -78,6 +80,22 @@ static void trsm_( char side, char uploa, char transa, char diaga, gtint_t m, throw std::runtime_error("Error in testsuite/level3/trsm.h: Invalid typename in trsm_()."); } +template +static void trsm_blis_impl( char side, char uploa, char transa, char diaga, gtint_t m, + gtint_t n, T* alpha, T* ap, gtint_t lda, T* bp, gtint_t ldb ) +{ + if constexpr (std::is_same::value) + strsm_blis_impl( &side, &uploa, &transa, &diaga, &m, &n, alpha, ap, &lda, bp, &ldb ); + else if constexpr (std::is_same::value) + dtrsm_blis_impl( &side, &uploa, &transa, &diaga, &m, &n, alpha, ap, &lda, bp, &ldb ); + else if constexpr (std::is_same::value) + ctrsm_blis_impl( &side, &uploa, &transa, &diaga, &m, &n, alpha, ap, &lda, bp, &ldb ); + else if constexpr (std::is_same::value) + ztrsm_blis_impl( &side, &uploa, &transa, &diaga, &m, &n, alpha, ap, &lda, bp, &ldb ); + else + throw std::runtime_error("Error in testsuite/level3/trsm.h: Invalid typename in trsm_blis_impl()."); +} + template static void cblas_trsm( char storage, char side, char uploa, char transa, char diaga, gtint_t m, gtint_t n, T* alpha, T* ap, gtint_t lda, @@ -154,12 +172,50 @@ template static void trsm( char storage, char side, char uploa, char transa, char diaga, gtint_t m, gtint_t n, T *alpha, T *ap, gtint_t lda, T *bp, gtint_t ldb ) { + +#ifdef TEST_UPPERCASE_ARGS + storage = static_cast(std::toupper(static_cast(storage))); + side = static_cast(std::toupper(static_cast(side))); + uploa = static_cast(std::toupper(static_cast(uploa))); + transa = static_cast(std::toupper(static_cast(transa))); + diaga = static_cast(std::toupper(static_cast(diaga))); +#endif + +#ifdef TEST_INPUT_ARGS + // Create copy of scalar input values so we can check that they are not altered. + char storage_cpy = storage; + char side_cpy = side; + char uploa_cpy = uploa; + char transa_cpy = transa; + char diaga_cpy = diaga; + gtint_t m_cpy = m; + gtint_t n_cpy = n; + T* alpha_cpy = alpha; + gtint_t lda_cpy = lda; + gtint_t ldb_cpy = ldb; + + gtint_t mn; + testinghelpers::set_dim_with_side( side, m, n, &mn ); + // Create copy of input arrays so we can check that they are not altered. + T* ap_cpy = nullptr; + gtint_t size_ap = testinghelpers::matsize( storage, transa, mn, mn, lda ); + if (ap && size_ap > 0) + { + ap_cpy = new T[size_ap]; + memcpy( ap_cpy, ap, size_ap * sizeof( T ) ); + } +#endif + #ifdef TEST_BLAS if( storage == 'c' || storage == 'C' ) trsm_( side, uploa, transa, diaga, m, n, alpha, ap, lda, bp, ldb ); else throw std::runtime_error("Error in testsuite/level3/trsm.h: BLAS interface cannot be tested for row-major order."); - +#elif TEST_BLAS_BLIS_IMPL + if( storage == 'c' || storage == 'C' ) + trsm_blis_impl( side, uploa, transa, diaga, m, n, alpha, ap, lda, bp, ldb ); + else + throw std::runtime_error("Error in testsuite/level3/trsm.h: BLAS_BLIS_IMPL interface cannot be tested for row-major order."); #elif TEST_CBLAS cblas_trsm( storage, side, uploa, transa, diaga, m, n, alpha, ap, lda, bp, ldb ); #elif TEST_BLIS_TYPED @@ -167,4 +223,31 @@ static void trsm( char storage, char side, char uploa, char transa, char diaga, #else throw std::runtime_error("Error in testsuite/level3/trsm.h: No interfaces are set to be tested."); #endif + +#ifdef TEST_INPUT_ARGS + //---------------------------------------------------------- + // Check scalar inputs have not been modified. + //---------------------------------------------------------- + + computediff( "storage", storage, storage_cpy ); + computediff( "side", side, side_cpy ); + computediff( "uploa", uploa, uploa_cpy ); + computediff( "transa", transa, transa_cpy ); + computediff( "diaga", diaga, diaga_cpy ); + computediff( "m", m, m_cpy ); + computediff( "n", n, n_cpy ); + if (alpha) computediff( "alpha", *alpha, *alpha_cpy ); + computediff( "lda", lda, lda_cpy ); + computediff( "ldb", ldb, ldb_cpy ); + + //---------------------------------------------------------- + // Bitwise-wise check array inputs have not been modified. + //---------------------------------------------------------- + + if (ap && size_ap > 0) + { + computediff( "A", storage, mn, mn, ap, ap_cpy, lda, true ); + delete[] ap_cpy; + } +#endif } diff --git a/gtestsuite/testsuite/level3/trsm/ztrsm/ztrsm_evt.cpp b/gtestsuite/testsuite/level3/trsm/ztrsm/ztrsm_evt.cpp new file mode 100644 index 0000000000..257928ac54 --- /dev/null +++ b/gtestsuite/testsuite/level3/trsm/ztrsm/ztrsm_evt.cpp @@ -0,0 +1,169 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "level3/trsm/test_trsm.h" + +class ztrsmEVT : + public ::testing::TestWithParam> {}; // EVT test for B + + +TEST_P( ztrsmEVT, API ) +{ + using T = dcomplex; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // matrix storage format(row major, column major) + char storage = std::get<0>(GetParam()); + // specifies matrix A appears left or right in + // the matrix multiplication + char side = std::get<1>(GetParam()); + // specifies upper or lower triangular part of A is used + char uploa = std::get<2>(GetParam()); + // denotes whether matrix a is n,c,t,h + char transa = std::get<3>(GetParam()); + // denotes whether matrix a in unit or non-unit diagonal + char diaga = std::get<4>(GetParam()); + // matrix size m + gtint_t m = std::get<5>(GetParam()); + // matrix size n + gtint_t n = std::get<6>(GetParam()); + // specifies alpha value + T alpha = std::get<7>(GetParam()); + // lda, ldb, ldc increments. + // If increments are zero, then the array size matches the matrix size. + // If increments are nonnegative, the array size is bigger than the matrix size. + gtint_t lda_inc = std::get<8>(GetParam()); + gtint_t ldb_inc = std::get<9>(GetParam()); + + EVT_TYPE a_init = std::get<10>(GetParam()); + EVT_TYPE b_init = std::get<11>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite trsm.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (m == 0 || n == 0 || alpha == testinghelpers::ZERO()) + thresh = 0.0; + else + if ( side == 'l' || side == 'L' ) + thresh = 3*m*testinghelpers::getEpsilon(); + else + thresh = 3*n*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_trsm( storage, side, uploa, transa, diaga, m, n, alpha, lda_inc, ldb_inc, thresh, a_init, b_init ); +} + +/** + * @brief Test ZTRSM for extreme values + * Code paths taken for: + * TRSV -> 1 + * AVX2 Small -> 151, 82 + * Native -> 503, 512 + */ +INSTANTIATE_TEST_SUITE_P( + evt, + ztrsmEVT, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('l','r'), // side l:left, r:right + ::testing::Values('u','l'), // uplo u:upper, l:lower + ::testing::Values('n','c', 't'), // transa + ::testing::Values('n','u'), // diaga , n=nonunit u=unit + ::testing::Values(1, 151, 503), // m + ::testing::Values(1, 82, 512), // n + ::testing::Values(dcomplex{-2.4, 2.0}, + dcomplex{-0.0, 2.3}, + dcomplex{-2.4, 0.0}, + dcomplex{ 0.0, 0.0}), // alpha + ::testing::Values(gtint_t(0)), // increment to the leading dim of a + ::testing::Values(gtint_t(0)), // increment to the leading dim of b + ::testing::Values(NO_EVT, NaN, INF, NaN_INF, DIAG_NaN, DIAG_INF, + NEG_INF, NEG_NaN), // EVT test for A + ::testing::Values(NO_EVT, NaN, INF, NaN_INF, NEG_INF, NEG_NaN) // EVT test for B + ), + ::trsmEVTPrint() + ); + +/** + * @brief Test ZTRSM with differnt values of alpha + * code paths covered: + * TRSV -> 1 + * TRSM_AVX2_small -> 3 + * TRSM_NATIVE -> 501 + */ +INSTANTIATE_TEST_SUITE_P( + Alpha, + ztrsmEVT, + ::testing::Combine( + ::testing::Values('c'), // storage format + ::testing::Values('l','r'), // side l:left, r:right + ::testing::Values('u','l'), // uplo u:upper, l:lower + ::testing::Values('n', 'c', 't'), // transa + ::testing::Values('n','u'), // diaga , n=nonunit u=unit + ::testing::Values(1, 3, 501), // n + ::testing::Values(1, 3, 501), // m + ::testing::Values(dcomplex{NAN, -2.0}, + dcomplex{-2.0, NAN}, + dcomplex{INFINITY, 3.1f}, + dcomplex{NAN, -INFINITY}), // alpha + ::testing::Values(gtint_t(0), gtint_t(5)), // increment to the leading dim of a + ::testing::Values(gtint_t(0), gtint_t(3)), // increment to the leading dim of b + ::testing::Values(NO_EVT), // EVT test for A + ::testing::Values(NO_EVT) // EVT test for B + ), + ::trsmEVTPrint() + ); diff --git a/gtestsuite/testsuite/level3/trsm/ztrsm/ztrsm_generic.cpp b/gtestsuite/testsuite/level3/trsm/ztrsm/ztrsm_generic.cpp new file mode 100644 index 0000000000..2b4fe6aaca --- /dev/null +++ b/gtestsuite/testsuite/level3/trsm/ztrsm/ztrsm_generic.cpp @@ -0,0 +1,194 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "level3/trsm/test_trsm.h" + +class ztrsmGeneric : + public ::testing::TestWithParam> {}; + +TEST_P( ztrsmGeneric, API ) +{ + using T = dcomplex; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // matrix storage format(row major, column major) + char storage = std::get<0>(GetParam()); + // specifies matrix A appears left or right in + // the matrix multiplication + char side = std::get<1>(GetParam()); + // specifies upper or lower triangular part of A is used + char uploa = std::get<2>(GetParam()); + // denotes whether matrix a is n,c,t,h + char transa = std::get<3>(GetParam()); + // denotes whether matrix a in unit or non-unit diagonal + char diaga = std::get<4>(GetParam()); + // matrix size m + gtint_t m = std::get<5>(GetParam()); + // matrix size n + gtint_t n = std::get<6>(GetParam()); + // specifies alpha value + T alpha = std::get<7>(GetParam()); + // lda, ldb, ldc increments. + // If increments are zero, then the array size matches the matrix size. + // If increments are nonnegative, the array size is bigger than the matrix size. + gtint_t lda_inc = std::get<8>(GetParam()); + gtint_t ldb_inc = std::get<9>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite trsm.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (m == 0 || n == 0 || alpha == testinghelpers::ZERO()) + thresh = 0.0; + else + if ( side == 'l' || side == 'L' ) + thresh = 3*m*testinghelpers::getEpsilon(); + else + thresh = 3*n*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_trsm( storage, side, uploa, transa, diaga, m, n, alpha, lda_inc, ldb_inc, thresh ); +} + +/** + * @brief Test ZTRSM native path, which starts from size 501 for BLAS api + * and starts from size 0 for BLIS api. + */ +INSTANTIATE_TEST_SUITE_P( + Native, + ztrsmGeneric, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS_LIKE + ,'r' +#endif + ), // storage format + ::testing::Values('l','r'), // side l:left, r:right + ::testing::Values('u','l'), // uplo u:upper, l:lower + ::testing::Values('n','c','t'), // transa + ::testing::Values('n','u'), // diaga , n=nonunit u=unit + ::testing::Values(1, 53, 520), // m + ::testing::Values(1, 38, 511), // n + ::testing::Values(dcomplex{2.0,-1.0}), // alpha + ::testing::Values(gtint_t(20)), // increment to the leading dim of a + ::testing::Values(gtint_t(33)) // increment to the leading dim of b + ), + ::trsmGenericPrint() + ); + +/** + * @brief Test ZTRSM small avx2 path all fringe cases + * Kernel size for avx2 small path is 4x3, testing in range of + * 1 to 4 ensures all finge cases are being tested. + */ +INSTANTIATE_TEST_SUITE_P( + Small_AVX2_fringe, + ztrsmGeneric, + ::testing::Combine( + ::testing::Values('c'), // storage format + ::testing::Values('l','r'), // side l:left, r:right + ::testing::Values('u','l'), // uplo u:upper, l:lower + ::testing::Values('n', 'c', 't'), // transa + ::testing::Values('n','u'), // diaga , n=nonunit u=unit + ::testing::Range(gtint_t(1), gtint_t(5), 1), // m + ::testing::Range(gtint_t(1), gtint_t(5), 1), // n + ::testing::Values(dcomplex{2.0,-3.4}), // alpha + ::testing::Values(gtint_t(56)), // increment to the leading dim of a + ::testing::Values(gtint_t(33)) // increment to the leading dim of b + ), + ::trsmGenericPrint() + ); + +/** + * @brief Test ZTRSM small avx2 path, this code path is used in range 0 to 500 + */ +INSTANTIATE_TEST_SUITE_P( + Small_AVX2, + ztrsmGeneric, + ::testing::Combine( + ::testing::Values('c'), // storage format + ::testing::Values('l','r'), // side l:left, r:right + ::testing::Values('u','l'), // uplo u:upper, l:lower + ::testing::Values('n', 'c', 't'), // transa + ::testing::Values('n','u'), // diaga , n=nonunit u=unit + ::testing::Values(17, 500), // m + ::testing::Values(48, 500), // n + ::testing::Values(dcomplex{2.0,-3.4}), // alpha + ::testing::Values(gtint_t(54)), // increment to the leading dim of a + ::testing::Values(gtint_t(37)) // increment to the leading dim of b + ), + ::trsmGenericPrint() + ); + +/** + * @brief Test ZTRSM with differnt values of alpha + * code paths covered: + * TRSV -> 1 + * TRSM_AVX2_small -> 3 + * TRSM_NATIVE -> 501 + */ +INSTANTIATE_TEST_SUITE_P( + Alpha, + ztrsmGeneric, + ::testing::Combine( + ::testing::Values('c'), // storage format + ::testing::Values('l','r'), // side l:left, r:right + ::testing::Values('u','l'), // uplo u:upper, l:lower + ::testing::Values('n', 'c', 't'), // transa + ::testing::Values('n','u'), // diaga , n=nonunit u=unit + ::testing::Values(1, 3, 501), // n + ::testing::Values(1, 3, 501), // m + ::testing::Values(dcomplex{2.0, 0.0}, dcomplex{0.0, -10.0}, + dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}), // alpha + ::testing::Values(gtint_t(0), gtint_t(65)), // increment to the leading dim of a + ::testing::Values(gtint_t(0), gtint_t(23)) // increment to the leading dim of b + ), + ::trsmGenericPrint() + ); diff --git a/gtestsuite/testsuite/level3/trsm/ztrsm_generic.cpp b/gtestsuite/testsuite/level3/trsm/ztrsm_generic.cpp deleted file mode 100644 index 830b9081b5..0000000000 --- a/gtestsuite/testsuite/level3/trsm/ztrsm_generic.cpp +++ /dev/null @@ -1,145 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include "test_trsm.h" - -class ztrsmTest : - public ::testing::TestWithParam> {}; - -TEST_P(ztrsmTest, RandomData) -{ - using T = dcomplex; - //---------------------------------------------------------- - // Initialize values from the parameters passed through - // test suite instantiation (INSTANTIATE_TEST_SUITE_P). - //---------------------------------------------------------- - // matrix storage format(row major, column major) - char storage = std::get<0>(GetParam()); - // specifies matrix A appears left or right in - // the matrix multiplication - char side = std::get<1>(GetParam()); - // specifies upper or lower triangular part of A is used - char uploa = std::get<2>(GetParam()); - // denotes whether matrix a is n,c,t,h - char transa = std::get<3>(GetParam()); - // denotes whether matrix a in unit or non-unit diagonal - char diaga = std::get<4>(GetParam()); - // matrix size m - gtint_t m = std::get<5>(GetParam()); - // matrix size n - gtint_t n = std::get<6>(GetParam()); - // specifies alpha value - T alpha = std::get<7>(GetParam()); - // lda, ldb, ldc increments. - // If increments are zero, then the array size matches the matrix size. - // If increments are nonnegative, the array size is bigger than the matrix size. - gtint_t lda_inc = std::get<8>(GetParam()); - gtint_t ldb_inc = std::get<9>(GetParam()); - - // Set the threshold for the errors: - double thresh = (std::max)(m, n)*testinghelpers::getEpsilon(); - - //---------------------------------------------------------- - // Call test body using these parameters - //---------------------------------------------------------- - test_trsm( storage, side, uploa, transa, diaga, m, n, alpha, lda_inc, ldb_inc, thresh ); -} - -class ztrsmTestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - char sfm = std::get<0>(str.param); - char side = std::get<1>(str.param); - char uploa = std::get<2>(str.param); - char transa = std::get<3>(str.param); - char diaga = std::get<4>(str.param); - gtint_t m = std::get<5>(str.param); - gtint_t n = std::get<6>(str.param); - dcomplex alpha = std::get<7>(str.param); - gtint_t lda_inc = std::get<8>(str.param); - gtint_t ldb_inc = std::get<9>(str.param); -#ifdef TEST_BLAS - std::string str_name = "ztrsm_"; -#elif TEST_CBLAS - std::string str_name = "cblas_ztrsm"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_ztrsm"; -#endif - str_name = str_name + "_" + sfm+sfm+sfm; - str_name = str_name + "_" + side + uploa + transa; - str_name = str_name + "_d" + diaga; - str_name = str_name + "_" + std::to_string(m); - str_name = str_name + "_" + std::to_string(n); - std::string alpha_str = ( alpha.real > 0) ? std::to_string(int(alpha.real)) : ("m" + std::to_string(int(std::abs(alpha.real)))); - alpha_str = alpha_str + "pi" + (( alpha.imag > 0) ? std::to_string(int(alpha.imag)) : ("m" + std::to_string(int(std::abs(alpha.imag))))); - str_name = str_name + "_a" + alpha_str; - str_name = str_name + "_" + std::to_string(lda_inc); - str_name = str_name + "_" + std::to_string(ldb_inc); - return str_name; - } -}; - -// Black box testing. -INSTANTIATE_TEST_SUITE_P( - Blackbox, - ztrsmTest, - ::testing::Combine( - ::testing::Values('c' -#ifndef TEST_BLAS - ,'r' -#endif - ), // storage format - ::testing::Values('l','r'), // side l:left, r:right - ::testing::Values('u','l'), // uplo u:upper, l:lower - ::testing::Values('n','c','t'), // transa - ::testing::Values('n','u'), // diaga , n=nonunit u=unit - ::testing::Range(gtint_t(10), gtint_t(11), 10), // m - ::testing::Range(gtint_t(10), gtint_t(11), 10), // n - ::testing::Values(dcomplex{1.0,2.0}), // alpha - ::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of a - ::testing::Values(gtint_t(0), gtint_t(3)) // increment to the leading dim of b - ), - ::ztrsmTestPrint() - ); diff --git a/gtestsuite/testsuite/ukr/addv/caddv_ukr.cpp b/gtestsuite/testsuite/ukr/addv/caddv_ukr.cpp new file mode 100644 index 0000000000..a981fee590 --- /dev/null +++ b/gtestsuite/testsuite/ukr/addv/caddv_ukr.cpp @@ -0,0 +1,146 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + Portions of this file consist of AI-generated content. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_addv_ukr.h" + +class caddvGeneric : + public ::testing::TestWithParam> {}; // is_memory_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(caddvGeneric); + +// Defining the testsuite to check the accuracy of caddv micro-kernels +TEST_P( caddvGeneric, UKR ) +{ + using T = scomplex; + + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + + // Assign the kernel address to the function pointer + caddv_ker_ft ukr_fp = std::get<0>(GetParam()); + // denotes whether x or conj(x) will be added to y + char conj_x = std::get<1>(GetParam()); + // vector length + gtint_t n = std::get<2>(GetParam()); + // stride size for x + gtint_t incx = std::get<3>(GetParam()); + // stride size for y + gtint_t incy = std::get<4>(GetParam()); + // is_memory_test + bool is_memory_test = std::get<5>(GetParam()); + + // Set the threshold for the errors + double threshold = 2 * testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_addv_ukr( ukr_fp, conj_x, n, incx, incy, threshold, is_memory_test ); +} + +// ---------------------------------------------- +// ----- Begin ZEN1/2/3 (AVX2) Kernel Tests ----- +// ---------------------------------------------- +#if defined(BLIS_KERNELS_ZEN) && defined(GTEST_AVX2FMA3) +/* + Unit testing for functionality of bli_caddv_zen_int kernel. + The code structure for bli_caddv_zen_int( ... ) is as follows : + For unit strides : + Main loop : In blocks of 48 --> L48 + Fringe loops : In blocks of 32 --> L32 + In blocks of 16 --> L16 + In blocks of 8 --> L8 + In blocks of 4 --> L4 + Element-wise loop --> LScalar + + For non-unit strides : A single loop, to process element wise. +*/ +INSTANTIATE_TEST_SUITE_P( + bli_caddv_zen_int_unitStrides, + caddvGeneric, + ::testing::Combine( + ::testing::Values(bli_caddv_zen_int), // kernel address + ::testing::Values('n' // conjx +#ifdef TEST_BLIS_TYPED + , 'c' +#endif + ), + ::testing::Values(// Testing the loops standalone + gtint_t(48), // size n, for L48 + gtint_t(32), // L32 + gtint_t(16), // L16 + gtint_t(8), // L8 + gtint_t(4), // L4 + gtint_t(3), // LScalar + gtint_t(128), // 2*L48 + L32 + gtint_t(127)), // 2*L48 + L16 + L8 + L4 + 3(LScalar) + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(false, true) // is_memory_test + ), + (::addvUKRPrint()) + ); + +INSTANTIATE_TEST_SUITE_P( + bli_caddv_zen_int_nonUnitStrides, + caddvGeneric, + ::testing::Combine( + ::testing::Values(bli_caddv_zen_int), // kernel address + ::testing::Values('n' // conjx +#ifdef TEST_BLIS_TYPED + , 'c' +#endif + ), + ::testing::Values(// Testing the loops standalone + gtint_t(7), // size n, for LScalar + gtint_t(15)), + ::testing::Values(gtint_t(3), gtint_t(5)), // stride size for x + ::testing::Values(gtint_t(2), gtint_t(4)), // stride size for y + ::testing::Values(false, true) // is_memory_test + ), + (::addvUKRPrint()) + ); +#endif +// ---------------------------------------------- +// ----- End ZEN1/2/3 (AVX2) Kernel Tests ----- +// ---------------------------------------------- diff --git a/gtestsuite/testsuite/ukr/addv/daddv_ukr.cpp b/gtestsuite/testsuite/ukr/addv/daddv_ukr.cpp new file mode 100644 index 0000000000..f0d91f1394 --- /dev/null +++ b/gtestsuite/testsuite/ukr/addv/daddv_ukr.cpp @@ -0,0 +1,193 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + Portions of this file consist of AI-generated content. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_addv_ukr.h" + +class daddvGeneric : + public ::testing::TestWithParam> {}; // is_memory_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(daddvGeneric); + +// Defining the testsuite to check the accuracy of daddv micro-kernels +TEST_P( daddvGeneric, UKR ) +{ + using T = double; + + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + + // Assign the kernel address to the function pointer + daddv_ker_ft ukr_fp = std::get<0>(GetParam()); + // denotes whether x or conj(x) will be added to y + char conj_x = std::get<1>(GetParam()); + // vector length + gtint_t n = std::get<2>(GetParam()); + // stride size for x + gtint_t incx = std::get<3>(GetParam()); + // stride size for y + gtint_t incy = std::get<4>(GetParam()); + // is_memory_test + bool is_memory_test = std::get<5>(GetParam()); + + // Set the threshold for the errors + double threshold = 2 * testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_addv_ukr( ukr_fp, conj_x, n, incx, incy, threshold, is_memory_test ); +} + +// ---------------------------------------------- +// ----- Begin ZEN1/2/3 (AVX2) Kernel Tests ----- +// ---------------------------------------------- +#if defined(BLIS_KERNELS_ZEN) && defined(GTEST_AVX2FMA3) +/* + Unit testing for functionality of bli_daddv_zen_int kernel. + The code structure for bli_daddv_zen_int( ... ) is as follows : + For unit strides : + Main loop : In blocks of 64 --> L64 + Fringe loops : In blocks of 32 --> L32 + In blocks of 16 --> L16 + In blocks of 8 --> L8 + In blocks of 4 --> L4 + Element-wise loop --> LScalar + + For non-unit strides : A single loop, to process element wise. +*/ +INSTANTIATE_TEST_SUITE_P( + bli_daddv_zen_int_unitStrides, + daddvGeneric, + ::testing::Combine( + ::testing::Values(bli_daddv_zen_int), // kernel address + ::testing::Values('n'), // use x, not conj(x) (since it is real) + ::testing::Values(// Testing the loops standalone + gtint_t(64), // size n, for L64 + gtint_t(32), // L32 + gtint_t(16), // L16 + gtint_t(8), // L8 + gtint_t(4), // L4 + gtint_t(3), // LScalar + gtint_t(191)), // 2*L64 + L32 + L16 + L8 + L4 + 3(LScalar) + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(false, true) // is_memory_test + ), + (::addvUKRPrint()) + ); + +INSTANTIATE_TEST_SUITE_P( + bli_daddv_zen_int_nonUnitStrides, + daddvGeneric, + ::testing::Combine( + ::testing::Values(bli_daddv_zen_int), // kernel address + ::testing::Values('n'), // use x, not conj(x) (since it is real) + ::testing::Values(// Testing the loops standalone + gtint_t(7), // size n, for LScalar + gtint_t(15)), + ::testing::Values(gtint_t(3), gtint_t(5)), // stride size for x + ::testing::Values(gtint_t(2), gtint_t(4)), // stride size for y + ::testing::Values(false, true) // is_memory_test + ), + (::addvUKRPrint()) + ); +#endif +// ---------------------------------------------- +// ----- End ZEN1/2/3 (AVX2) Kernel Tests ----- +// ---------------------------------------------- + +// ---------------------------------------------- +// ----- Begin ZEN4/5 (AVX512) Kernel Tests ----- +// ---------------------------------------------- +#if defined(BLIS_KERNELS_ZEN4) && defined(GTEST_AVX512) +/* + Unit testing for functionality of bli_daddv_zen_int_avx512 kernel. + The code structure for bli_daddv_zen_int_avx512( ... ) is as follows : + For unit strides : + Main loop : In blocks of 64 --> L64 + Fringe loops : In blocks of 32 --> L32 + In blocks of 16 --> L16 + In blocks of 8 --> L8 + Element-wise loop --> LScalar + + For non-unit strides : A single loop, to process element wise. +*/ +INSTANTIATE_TEST_SUITE_P( + bli_daddv_zen_int_avx512_unitStrides, + daddvGeneric, + ::testing::Combine( + ::testing::Values(bli_daddv_zen_int_avx512), // kernel address + ::testing::Values('n'), // use x, not conj(x) (since it is real) + ::testing::Values(// Testing the loops standalone + gtint_t(64), // size n, for L64 + gtint_t(32), // L32 + gtint_t(16), // L16 + gtint_t(8), // L8 + gtint_t(7), // LScalar + gtint_t(191)), // 2*L64 + L32 + L16 + L8 + 7(LScalar) + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(false, true) // is_memory_test + ), + (::addvUKRPrint()) + ); + +INSTANTIATE_TEST_SUITE_P( + bli_daddv_zen_int_avx512_nonUnitStrides, + daddvGeneric, + ::testing::Combine( + ::testing::Values(bli_daddv_zen_int_avx512), // kernel address + ::testing::Values('n'), // use x, not conj(x) (since it is real) + ::testing::Values(// Testing the loops standalone + gtint_t(7), // size n, for LScalar + gtint_t(15)), + ::testing::Values(gtint_t(3), gtint_t(5)), // stride size for x + ::testing::Values(gtint_t(2), gtint_t(4)), // stride size for y + ::testing::Values(false, true) // is_memory_test + ), + (::addvUKRPrint()) + ); +#endif +// ---------------------------------------------- +// ----- End ZEN4/5 (AVX512) Kernel Tests ----- +// ---------------------------------------------- diff --git a/gtestsuite/testsuite/ukr/addv/saddv_ukr.cpp b/gtestsuite/testsuite/ukr/addv/saddv_ukr.cpp new file mode 100644 index 0000000000..748e70f4b9 --- /dev/null +++ b/gtestsuite/testsuite/ukr/addv/saddv_ukr.cpp @@ -0,0 +1,137 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + Portions of this file consist of AI-generated content. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_addv_ukr.h" + +class saddvGeneric : + public ::testing::TestWithParam> {}; // is_memory_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(saddvGeneric); + +// Defining the testsuite to check the accuracy of saddv micro-kernels +TEST_P( saddvGeneric, UKR ) +{ + using T = float; + + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + + // Assign the kernel address to the function pointer + saddv_ker_ft ukr_fp = std::get<0>(GetParam()); + // denotes whether x or conj(x) will be added to y + char conj_x = std::get<1>(GetParam()); + // vector length + gtint_t n = std::get<2>(GetParam()); + // stride size for x + gtint_t incx = std::get<3>(GetParam()); + // stride size for y + gtint_t incy = std::get<4>(GetParam()); + // is_memory_test + bool is_memory_test = std::get<5>(GetParam()); + + // Set the threshold for the errors + double threshold = 2 * testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_addv_ukr( ukr_fp, conj_x, n, incx, incy, threshold, is_memory_test ); +} + +// ---------------------------------------------- +// ----- Begin ZEN1/2/3 (AVX2) Kernel Tests ----- +// ---------------------------------------------- +#if defined(BLIS_KERNELS_ZEN) && defined(GTEST_AVX2FMA3) +/* + Unit testing for functionality of bli_saddv_zen_int kernel. + The code structure for bli_saddv_zen_int( ... ) is as follows : + For unit strides : + Main loop : In blocks of 128 --> L128 + Fringe loops : In blocks of 64 --> L64 + In blocks of 32 --> L32 + In blocks of 16 --> L16 + In blocks of 8 --> L8 + Element-wise loop --> LScalar + + For non-unit strides : A single loop, to process element wise. +*/ +INSTANTIATE_TEST_SUITE_P( + bli_saddv_zen_int_unitStrides, + saddvGeneric, + ::testing::Combine( + ::testing::Values(bli_saddv_zen_int), // kernel address + ::testing::Values('n'), // use x, not conj(x) (since it is real) + ::testing::Values(// Testing the loops standalone + gtint_t(128), // size n, for L128 + gtint_t(64), // L64 + gtint_t(32), // L32 + gtint_t(16), // L16 + gtint_t(8), // L8 + gtint_t(7), // LScalar + gtint_t(383)), // 2*L128 + L64 + L32 + L16 + L8 + 7(LScalar) + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(false, true) // is_memory_test + ), + (::addvUKRPrint()) + ); + +INSTANTIATE_TEST_SUITE_P( + bli_saddv_zen_int_nonUnitStrides, + saddvGeneric, + ::testing::Combine( + ::testing::Values(bli_saddv_zen_int), // kernel address + ::testing::Values('n'), // use x, not conj(x) (since it is real) + ::testing::Values(// Testing the loops standalone + gtint_t(7), // size n, for LScalar + gtint_t(15)), + ::testing::Values(gtint_t(3), gtint_t(5)), // stride size for x + ::testing::Values(gtint_t(2), gtint_t(4)), // stride size for y + ::testing::Values(false, true) // is_memory_test + ), + (::addvUKRPrint()) + ); +#endif +// ---------------------------------------------- +// ----- End ZEN1/2/3 (AVX2) Kernel Tests ----- +// ---------------------------------------------- diff --git a/gtestsuite/testsuite/ukr/addv/test_addv_ukr.h b/gtestsuite/testsuite/ukr/addv/test_addv_ukr.h new file mode 100644 index 0000000000..7623108347 --- /dev/null +++ b/gtestsuite/testsuite/ukr/addv/test_addv_ukr.h @@ -0,0 +1,150 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#pragma once + +#include +#include "level1/addv/addv.h" +#include "level1/ref_addv.h" +#include "inc/check_error.h" +#include "common/testing_helpers.h" + +/** + * @brief Generic test body for copyv operation. + */ + +template +void test_addv_ukr( FT ukr_fp, char conjx, gtint_t n, gtint_t incx, gtint_t incy, double thresh, bool is_memory_test = false ) +{ + // Pointers to obtain the required memory. + T *x, *y, *y_ref; + + // Sizes of x and y vectors + gtint_t size_x = testinghelpers::buff_dim( n, incx ) * sizeof( T ); + gtint_t size_y = testinghelpers::buff_dim( n, incy ) * sizeof( T ); + + // Create the object for the required operands + // The kernel does not expect the memory to be aligned + testinghelpers::ProtectedBuffer x_buffer( size_x, false, is_memory_test ); + testinghelpers::ProtectedBuffer y_buffer( size_y, false, is_memory_test ); + + // For y_ref, we don't need different greenzones and any redzone. + // Thus, we pass is_memory_test as false + testinghelpers::ProtectedBuffer y_ref_buffer( size_y, false, false ); + + // Acquire the first greenzone for x + x = ( T* )x_buffer.greenzone_1; + y = ( T* )y_buffer.greenzone_1; + y_ref = ( T* )y_ref_buffer.greenzone_1; // y_ref does not have multiple greenzones + + // Initialize the memory with random data + testinghelpers::datagenerators::randomgenerators( -10, 10, n, incx, x ); + testinghelpers::datagenerators::randomgenerators( -10, 10, n, incy, y ); + + // Copying the contents of y to y_ref + memcpy( y_ref, y, size_y ); + + // Char conjx to BLIS conjx conversion + conj_t blis_conjx; + testinghelpers::char_to_blis_conj( conjx, &blis_conjx ); + + // Add signal handler for segmentation fault + testinghelpers::ProtectedBuffer::start_signal_handler(); + try + { + // Call the ukr function. + // This call is made irrespective of is_memory_test. + // This will check for out of bounds access with first redzone(if memory test is true) + // Else, it will just call the ukr function. + ukr_fp( blis_conjx, n, x, incx, y, incy, nullptr ); + + if ( is_memory_test ) + { + // Acquire the pointers near the second redzone + x = ( T* )x_buffer.greenzone_2; + y = ( T* )y_buffer.greenzone_2; + + // Copy the data for x and y accordingly + memcpy( x, x_buffer.greenzone_1, size_x ); + memcpy( y, y_ref, size_y ); + + // Call the ukr function, to check with the second redzone. + ukr_fp( blis_conjx, n, x, incx, y, incy, nullptr ); + } + } + catch(const std::exception& e) + { + // Reset to default signal handler + testinghelpers::ProtectedBuffer::stop_signal_handler(); + + // Show failure in case seg fault was detected + FAIL() << "Memory Test Failed"; + } + // Reset to default signal handler + testinghelpers::ProtectedBuffer::stop_signal_handler(); + + //---------------------------------------------------------- + // Call reference implementation to get ref results. + //---------------------------------------------------------- + testinghelpers::ref_addv( conjx, n, x, incx, y_ref, incy ); + + //---------------------------------------------------------- + // Compute component-wise error. + //---------------------------------------------------------- + computediff( "y", n, y, y_ref, incy, thresh ); +} + +// Test-case logger : Used to print the test-case details for unit testing the kernels. +// NOTE : The kernel name is the prefix in instantiator name, and thus is not printed +// with this logger. +template +class addvUKRPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char conjx = std::get<1>(str.param); + gtint_t n = std::get<2>(str.param); + gtint_t incx = std::get<3>(str.param); + gtint_t incy = std::get<4>(str.param); + bool is_memory_test = std::get<5>(str.param); + + std::string str_name = ""; + str_name += "_n_" + std::to_string(n); + str_name += "_conjx_" + std::string(&conjx, 1); + str_name += "_incx_" + testinghelpers::get_value_string(incx); + str_name += "_incy_" + testinghelpers::get_value_string(incy); + str_name += ( is_memory_test ) ? "_mem_test_enabled" : "_mem_test_disabled"; + return str_name; + } +}; diff --git a/gtestsuite/testsuite/ukr/addv/zaddv_ukr.cpp b/gtestsuite/testsuite/ukr/addv/zaddv_ukr.cpp new file mode 100644 index 0000000000..88fc82398e --- /dev/null +++ b/gtestsuite/testsuite/ukr/addv/zaddv_ukr.cpp @@ -0,0 +1,146 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + Portions of this file consist of AI-generated content. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_addv_ukr.h" + +class zaddvGeneric : + public ::testing::TestWithParam> {}; // is_memory_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(zaddvGeneric); + +// Defining the testsuite to check the accuracy of zaddv micro-kernels +TEST_P( zaddvGeneric, UKR ) +{ + using T = dcomplex; + + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + + // Assign the kernel address to the function pointer + zaddv_ker_ft ukr_fp = std::get<0>(GetParam()); + // denotes whether x or conj(x) will be added to y + char conj_x = std::get<1>(GetParam()); + // vector length + gtint_t n = std::get<2>(GetParam()); + // stride size for x + gtint_t incx = std::get<3>(GetParam()); + // stride size for y + gtint_t incy = std::get<4>(GetParam()); + // is_memory_test + bool is_memory_test = std::get<5>(GetParam()); + + // Set the threshold for the errors + double threshold = 2 * testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_addv_ukr( ukr_fp, conj_x, n, incx, incy, threshold, is_memory_test ); +} + +// ---------------------------------------------- +// ----- Begin ZEN1/2/3 (AVX2) Kernel Tests ----- +// ---------------------------------------------- +#if defined(BLIS_KERNELS_ZEN) && defined(GTEST_AVX2FMA3) +/* + Unit testing for functionality of bli_zaddv_zen_int kernel. + The code structure for bli_zaddv_zen_int( ... ) is as follows : + For unit strides : + Main loop : In blocks of 24 --> L24 + Fringe loops : In blocks of 16 --> L16 + In blocks of 8 --> L8 + In blocks of 4 --> L4 + In blocks of 2 --> L2 + Element-wise loop --> LScalar + + For non-unit strides : A single loop, to process element wise. +*/ +INSTANTIATE_TEST_SUITE_P( + bli_zaddv_zen_int_unitStrides, + zaddvGeneric, + ::testing::Combine( + ::testing::Values(bli_zaddv_zen_int), // kernel address + ::testing::Values('n' // conjx +#ifdef TEST_BLIS_TYPED + , 'c' +#endif + ), + ::testing::Values(// Testing the loops standalone + gtint_t(24), // size n, for L24 + gtint_t(16), // L16 + gtint_t(8), // L8 + gtint_t(4), // L4 + gtint_t(2), // L2 + gtint_t(1), // LScalar + gtint_t(64), // 2*L24 + L16 + gtint_t(63)), // 2*L24 + L8 + L4 + 3(LScalar) + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(false, true) // is_memory_test + ), + (::addvUKRPrint()) + ); + +INSTANTIATE_TEST_SUITE_P( + bli_zaddv_zen_int_nonUnitStrides, + zaddvGeneric, + ::testing::Combine( + ::testing::Values(bli_zaddv_zen_int), // kernel address + ::testing::Values('n' // conjx +#ifdef TEST_BLIS_TYPED + , 'c' +#endif + ), + ::testing::Values(// Testing the loops standalone + gtint_t(7), // size n, for LScalar + gtint_t(15)), + ::testing::Values(gtint_t(3), gtint_t(5)), // stride size for x + ::testing::Values(gtint_t(2), gtint_t(4)), // stride size for y + ::testing::Values(false, true) // is_memory_test + ), + (::addvUKRPrint()) + ); +#endif +// ---------------------------------------------- +// ----- End ZEN1/2/3 (AVX2) Kernel Tests ----- +// ---------------------------------------------- diff --git a/gtestsuite/testsuite/ukr/amaxv/damaxv_ukr.cpp b/gtestsuite/testsuite/ukr/amaxv/damaxv_ukr.cpp new file mode 100644 index 0000000000..10bc9c6bde --- /dev/null +++ b/gtestsuite/testsuite/ukr/amaxv/damaxv_ukr.cpp @@ -0,0 +1,182 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_amaxv_ukr.h" + +class damaxvGeneric : + public ::testing::TestWithParam> {}; // is_memory_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(damaxvGeneric); + +// Tests using random integers as vector elements. +TEST_P( damaxvGeneric, UKR ) +{ + using T = double; + + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + + // Assign the kernel address to the function pointer + damaxv_ker_ft ukr_fp = std::get<0>(GetParam()); + // vector length: + gtint_t n = std::get<1>(GetParam()); + // stride size for x: + gtint_t incx = std::get<2>(GetParam()); + // is_memory_test + bool is_memory_test = std::get<3>(GetParam()); + + // Set the threshold for the errors: + double thresh = testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_amaxv_ukr( ukr_fp, n, incx, thresh, is_memory_test ); +} + +#if defined(BLIS_KERNELS_ZEN) && defined(GTEST_AVX2FMA3) +/* + Unit testing for functionality of bli_damaxv_zen_int kernel. + The code structure for bli_damaxv_zen_int( ... ) is as follows : + + bli_damaxv_zen_int() --> bli_vec_absmax_double() --> bli_vec_search_double() + bli_vec_absmax_double() structure: + For unit strides : + Main loop : In blocks of 48 --> L48 + Fringe loops : In blocks of 32 --> L32 + In blocks of 16 --> L16 + In blocks of 8 --> L8 + In blocks of 4 --> L4 + In blocks of 2 --> L2 + Element-wise loop --> LScalar + + For non-unit strides : A single loop, to process element wise. + + bli_vec_search_double() structure: + For unit strides : + Main loop : In blocks of 4 --> L4 + In blocks of 2 --> L2 + Element-wise loop --> LScalar + + For non-unit strides : A single loop, to process element wise. +*/ +// Unit testing with unit strides, across all loops. +INSTANTIATE_TEST_SUITE_P( + bli_damaxv_zen_int_unitStrides, + damaxvGeneric, + ::testing::Combine( + ::testing::Values(bli_damaxv_zen_int), // kernel address + ::testing::Values(gtint_t(48), // for size n, L48 + gtint_t(32), // L32 + gtint_t(16), // L16 + gtint_t(8), // L8 + gtint_t(4), // L4 + gtint_t(2), // L2 + gtint_t(1), // LScalar + gtint_t(144), // 3*L48 + gtint_t(176), // 3*L48 + L32 + gtint_t(175)), // 3*L48 + L16 + L8 + L4 + L2 + LScalar + ::testing::Values(gtint_t(1)), // incx + ::testing::Values(false, true) // is_memory_test + ), + ::amaxvUKRPrint() + ); + +// Unit testing with non-unit strides. +INSTANTIATE_TEST_SUITE_P( + bli_damaxv_zen_int_nonUnitStrides, + damaxvGeneric, + ::testing::Combine( + ::testing::Values(bli_damaxv_zen_int), // kernel address + ::testing::Values(gtint_t(10), // n, size of the vector + gtint_t(25)), + ::testing::Values(gtint_t(5)), // incx + ::testing::Values(false, true) // is_memory_test + ), + ::amaxvUKRPrint() + ); +#endif + +#if defined(BLIS_KERNELS_ZEN4) && defined(GTEST_AVX512) +/* + Unit testing for functionality of bli_damaxv_zen_int_avx512 kernel. + The code structure for bli_damaxv_zen_int_avx512( ... ) is as follows : + + For unit strides : + Main loop : In blocks of 32 --> L32 + Fringe loops : In blocks of 8 --> L8 + Element-wise loop --> LScalar + + For non-unit strides : A single loop, to process element wise. +*/ +// Unit testing with unit strides, across all loops. +INSTANTIATE_TEST_SUITE_P( + bli_damaxv_zen_int_avx512_unitStrides, + damaxvGeneric, + ::testing::Combine( + ::testing::Values(bli_damaxv_zen_int_avx512), // kernel address + ::testing::Values(gtint_t(32), // for size n, L32 + gtint_t(16), // 2*L8 + gtint_t(8), // L8 + gtint_t(7), // LScalar + gtint_t(160), // 5*L32 + gtint_t(168), // 5*L32 + L8 + gtint_t(175), // 5*L32 + L8 + 7(LScalar) + gtint_t(191)), // 5*L32 + 3*L8 + 7(LScalar) + ::testing::Values(gtint_t(1)), // incx + ::testing::Values(false, true) // is_memory_test + ), + ::amaxvUKRPrint() + ); + +// Unit testing with non-unit strides. +INSTANTIATE_TEST_SUITE_P( + bli_damaxv_zen_int_avx512_nonUnitStrides, + damaxvGeneric, + ::testing::Combine( + ::testing::Values(bli_damaxv_zen_int_avx512), // kernel address + ::testing::Values(gtint_t(10), // n, size of the vector + gtint_t(25)), + ::testing::Values(gtint_t(5)), // incx + ::testing::Values(false, true) // is_memory_test + ), + ::amaxvUKRPrint() + ); +#endif diff --git a/gtestsuite/testsuite/ukr/amaxv/samaxv_ukr.cpp b/gtestsuite/testsuite/ukr/amaxv/samaxv_ukr.cpp new file mode 100644 index 0000000000..e6c1010959 --- /dev/null +++ b/gtestsuite/testsuite/ukr/amaxv/samaxv_ukr.cpp @@ -0,0 +1,158 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_amaxv_ukr.h" + +class samaxvGeneric : + public ::testing::TestWithParam> {}; // is_memory_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(samaxvGeneric); + +// Tests using random integers as vector elements. +TEST_P( samaxvGeneric, UKR ) +{ + using T = float; + + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + + // Assign the kernel address to the function pointer + samaxv_ker_ft ukr_fp = std::get<0>(GetParam()); + // vector length: + gtint_t n = std::get<1>(GetParam()); + // stride size for x: + gtint_t incx = std::get<2>(GetParam()); + // is_memory_test + bool is_memory_test = std::get<3>(GetParam()); + + // Set the threshold for the errors: + double thresh = testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_amaxv_ukr( ukr_fp, n, incx, thresh, is_memory_test ); +} + +#if defined(BLIS_KERNELS_ZEN) && defined(GTEST_AVX2FMA3) +/* + Unit testing for functionality of bli_samaxv_zen_int kernel. + The code structure for bli_samaxv_zen_int( ... ) is as follows : + + For unit strides : + Main loop : In blocks of 8 --> L8 + Fringe loops : Element-wise loop --> LScalar + + For non-unit strides : A single loop, to process element wise. +*/ +// Unit testing with unit strides, across all loops. +INSTANTIATE_TEST_SUITE_P( + bli_samaxv_zen_int_unitStrides, + samaxvGeneric, + ::testing::Combine( + ::testing::Values(bli_samaxv_zen_int), // kernel address + ::testing::Values(gtint_t(8), // for size n, L8 + gtint_t(7), // LScalar + gtint_t(40), // 5*L8 + gtint_t(47)), // 5*L8 + LScalar + ::testing::Values(gtint_t(1)), // incx + ::testing::Values(false, true) // is_memory_test + ), + ::amaxvUKRPrint() + ); + +// Unit testing with non-unit strides. +INSTANTIATE_TEST_SUITE_P( + bli_samaxv_zen_int_nonUnitStrides, + samaxvGeneric, + ::testing::Combine( + ::testing::Values(bli_samaxv_zen_int), // kernel address + ::testing::Values(gtint_t(10), // n, size of the vector + gtint_t(25)), + ::testing::Values(gtint_t(5)), // incx + ::testing::Values(false, true) // is_memory_test + ), + ::amaxvUKRPrint() + ); +#endif + +#if defined(BLIS_KERNELS_ZEN4) && defined(GTEST_AVX512) +/* + Unit testing for functionality of bli_samaxv_zen_int_avx512 kernel. + The code structure for bli_samaxv_zen_int_avx512( ... ) is as follows : + + For unit strides : + Main loop : In blocks of 80 --> L80 + Fringe loops : In blocks of 16 --> L16 + Element-wise loop --> LScalar + + For non-unit strides : A single loop, to process element wise. +*/ +// Unit testing with unit strides, across all loops. +INSTANTIATE_TEST_SUITE_P( + bli_samaxv_zen_int_avx512_unitStrides, + samaxvGeneric, + ::testing::Combine( + ::testing::Values(bli_samaxv_zen_int_avx512), // kernel address + ::testing::Values(gtint_t(80), // for size n, L80 + gtint_t(48), // 3*L16 + gtint_t(16), // L16 + gtint_t(11), // 11(LScalar) + gtint_t(317)), // 3*L80 + 4*L16 + 13(LScalar) + ::testing::Values(gtint_t(1)), // incx + ::testing::Values(false, true) // is_memory_test + ), + ::amaxvUKRPrint() + ); + +// Unit testing with non-unit strides. +INSTANTIATE_TEST_SUITE_P( + bli_samaxv_zen_int_avx512_nonUnitStrides, + samaxvGeneric, + ::testing::Combine( + ::testing::Values(bli_samaxv_zen_int_avx512), // kernel address + ::testing::Values(gtint_t(10), // n, size of the vector + gtint_t(25)), + ::testing::Values(gtint_t(5)), // incx + ::testing::Values(false, true) // is_memory_test + ), + ::amaxvUKRPrint() + ); +#endif diff --git a/gtestsuite/testsuite/ukr/amaxv/test_amaxv_ukr.h b/gtestsuite/testsuite/ukr/amaxv/test_amaxv_ukr.h new file mode 100644 index 0000000000..9118bc57a3 --- /dev/null +++ b/gtestsuite/testsuite/ukr/amaxv/test_amaxv_ukr.h @@ -0,0 +1,134 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#pragma once + +#include +#include "level1/amaxv/amaxv.h" +#include "level1/ref_amaxv.h" +#include "inc/check_error.h" +#include "common/testing_helpers.h" + +/** + * @brief Test body for amaxv micro-kernels + */ + +template +void test_amaxv_ukr( FT ukr_fp, gtint_t n, gtint_t incx, double thresh, bool is_memory_test = false ) +{ + // Pointers to obtain the required memory. + T *x, *x_copy; + gtint_t size_x = testinghelpers::buff_dim( n, incx ) * sizeof( T ); + + // Create the objects for the input operand + // The kernel does not expect the memory to be aligned + testinghelpers::ProtectedBuffer x_buffer( size_x, false, is_memory_test ); + + // Creating x_copy, to save the contents of x(without any redzones) + testinghelpers::ProtectedBuffer x_copy_buffer( size_x, false, false ); + + // Acquire the first set of greenzones for x and y + x = ( T* )x_buffer.greenzone_1; + x_copy = ( T* )x_copy_buffer.greenzone_1; // For x_copy, there is no greenzone_2 + + // Initialize the memory with random data + testinghelpers::datagenerators::randomgenerators( -10, 10, n, incx, x ); + + // Copying the contents of x to x_copy + memcpy( x_copy, x, size_x ); + + dim_t idx; + + // Add signal handler for segmentation fault + testinghelpers::ProtectedBuffer::start_signal_handler(); + try + { + // Call the ukr function. + // This call is made irrespective of is_memory_test. + // This will check for out of bounds access with first redzone(if memory test is true) + // Else, it will just call the ukr function. + ukr_fp( n, x, incx, &idx, nullptr ); + + if ( is_memory_test ) + { + // Acquire the pointers near the second redzone + x = ( T* )x_buffer.greenzone_2; + + // Copy the data for x and y accordingly + memcpy( x, x_copy, size_x ); + + // Call the ukr function, to check with the second redzone. + ukr_fp( n, x, incx, &idx, nullptr ); + } + } + catch(const std::exception& e) + { + // Reset to default signal handler + testinghelpers::ProtectedBuffer::stop_signal_handler(); + + // Show failure in case seg fault was detected + FAIL() << "Memory Test Failed"; + } + // Reset to default signal handler + testinghelpers::ProtectedBuffer::stop_signal_handler(); + + //---------------------------------------------------------- + // Call reference implementation to get ref results. + //---------------------------------------------------------- + dim_t idx_ref = testinghelpers::ref_amaxv( n, x, incx ); + + //---------------------------------------------------------- + // Compute component-wise error. + //---------------------------------------------------------- + computediff( "idx", idx, idx_ref ); +} + +// Test-case logger : Used to print the test-case details for unit testing the kernels. +// NOTE : The kernel name is the prefix in instantiator name, and thus is not printed +// with this logger. +template +class amaxvUKRPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + gtint_t n = std::get<1>(str.param); + gtint_t incx = std::get<2>(str.param); + bool is_memory_test = std::get<3>(str.param); + + std::string str_name = "_n_" + std::to_string(n); + str_name += "_incx_" + testinghelpers::get_value_string(incx); + str_name += ( is_memory_test ) ? "_mem_test_enabled" : "_mem_test_disabled"; + return str_name; + } +}; diff --git a/gtestsuite/testsuite/ukr/axpbyv/caxpbyv_ukr.cpp b/gtestsuite/testsuite/ukr/axpbyv/caxpbyv_ukr.cpp new file mode 100644 index 0000000000..dfdb42b96e --- /dev/null +++ b/gtestsuite/testsuite/ukr/axpbyv/caxpbyv_ukr.cpp @@ -0,0 +1,186 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_axpbyv_ukr.h" + +class caxpbyvGeneric : + public ::testing::TestWithParam> {}; // is_memory_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(caxpbyvGeneric); + +// Tests using random integers as vector elements. +TEST_P( caxpbyvGeneric, UKR ) +{ + using T = scomplex; + + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + + // Assign the kernel address to the function pointer + caxpbyv_ker_ft ukr_fp = std::get<0>(GetParam()); + // denotes whether x or conj(x) will be added to y: + char conj_x = std::get<1>(GetParam()); + // vector length: + gtint_t n = std::get<2>(GetParam()); + // stride size for x: + gtint_t incx = std::get<3>(GetParam()); + // stride size for y: + gtint_t incy = std::get<4>(GetParam()); + // alpha + T alpha = std::get<5>(GetParam()); + // beta + T beta = std::get<6>(GetParam()); + // is_memory_test + bool is_memory_test = std::get<7>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite axpbyv.h (no netlib version) for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + // Like SCALV + if (beta == testinghelpers::ZERO() || beta == testinghelpers::ONE()) + thresh = 0.0; + else + thresh = testinghelpers::getEpsilon(); + else if (beta == testinghelpers::ZERO()) + // Like SCAL2V + if (alpha == testinghelpers::ZERO() || alpha == testinghelpers::ONE()) + thresh = 0.0; + else + thresh = testinghelpers::getEpsilon(); + else if (beta == testinghelpers::ONE()) + // Like AXPYV + if (alpha == testinghelpers::ZERO()) + thresh = 0.0; + else + thresh = 2*testinghelpers::getEpsilon(); + else if (alpha == testinghelpers::ONE()) + thresh = 2*testinghelpers::getEpsilon(); + else + thresh = 3*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_axpbyv_ukr( ukr_fp, conj_x, n, incx, incy, alpha, beta, thresh, is_memory_test ); +} + +#if defined(BLIS_KERNELS_ZEN) && defined(GTEST_AVX2FMA3) +/* + Unit testing for functionality of bli_caxpbyv_zen_int kernel. + The code structure for bli_caxpbyv_zen_int( ... ) is as follows : + For unit strides : + Main loop : In blocks of 16 --> L16 + Fringe loops : In blocks of 12 --> L12 + In blocks of 8 --> L8 + In blocks of 4 --> L4 + Element-wise loop --> LScalar + + For non-unit strides : A single loop, to process element wise. +*/ + +INSTANTIATE_TEST_SUITE_P( + bli_caxpbyv_zen_int_unitStrides, + caxpbyvGeneric, + ::testing::Combine( + ::testing::Values(bli_caxpbyv_zen_int), // kernel address + ::testing::Values('n' +#ifdef TEST_BLIS_TYPED + , 'c' // conjx +#endif + ), + ::testing::Values(// Testing the loops standalone + gtint_t(16), // size n, for L16 + gtint_t(12), // L12 + gtint_t(8), // L8 + gtint_t(4), // L4 + gtint_t(3), // LScalar + gtint_t(112), // 7*L16 + gtint_t(124), // 7*L16 + L12 + gtint_t(120), // 7*L16 + L8 + gtint_t(119)), // 7*L16 + L4 + 3(LScalar) + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(scomplex{1.0, 0.0}, scomplex{-1.0, 0.0}, + scomplex{0.0, 1.0}, scomplex{0.0, -1.0}, + scomplex{0.0, 0.0}, scomplex{2.3, -3.7}), // alpha + ::testing::Values(scomplex{1.0, 0.0}, scomplex{-1.0, 0.0}, + scomplex{0.0, 1.0}, scomplex{0.0, -1.0}, + scomplex{0.0, 0.0}, scomplex{2.3, -3.7}), // beta + ::testing::Values(false, true) // is_memory_test + ), + (::axpbyvMemUKRPrint()) + + ); + +INSTANTIATE_TEST_SUITE_P( + bli_caxpbyv_zen_int_nonUnitStrides, + caxpbyvGeneric, + ::testing::Combine( + ::testing::Values(bli_caxpbyv_zen_int), // kernel address + ::testing::Values('n' +#ifdef TEST_BLIS_TYPED + , 'c' // conjx +#endif + ), + ::testing::Values(gtint_t(10), // n, size of the vector + gtint_t(25)), + ::testing::Values(gtint_t(5)), // stride size for x + ::testing::Values(gtint_t(3)), // stride size for y + ::testing::Values(scomplex{1.0, 0.0}, scomplex{-1.0, 0.0}, + scomplex{0.0, 1.0}, scomplex{0.0, -1.0}, + scomplex{0.0, 0.0}, scomplex{2.3, -3.7}), // alpha + ::testing::Values(scomplex{1.0, 0.0}, scomplex{-1.0, 0.0}, + scomplex{0.0, 1.0}, scomplex{0.0, -1.0}, + scomplex{0.0, 0.0}, scomplex{2.3, -3.7}), // beta + ::testing::Values(false, true) // is_memory_test + ), + (::axpbyvMemUKRPrint()) + ); +#endif diff --git a/gtestsuite/testsuite/ukr/axpbyv/daxpbyv_ukr.cpp b/gtestsuite/testsuite/ukr/axpbyv/daxpbyv_ukr.cpp new file mode 100644 index 0000000000..74b0ad5b22 --- /dev/null +++ b/gtestsuite/testsuite/ukr/axpbyv/daxpbyv_ukr.cpp @@ -0,0 +1,298 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_axpbyv_ukr.h" + +class daxpbyvGeneric : + public ::testing::TestWithParam> {}; // is_memory_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(daxpbyvGeneric); + +// Tests using random integers as vector elements. +TEST_P( daxpbyvGeneric, UKR ) +{ + using T = double; + + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + + // Assign the kernel address to the function pointer + daxpbyv_ker_ft ukr_fp = std::get<0>(GetParam()); + // denotes whether x or conj(x) will be added to y: + char conj_x = std::get<1>(GetParam()); + // vector length: + gtint_t n = std::get<2>(GetParam()); + // stride size for x: + gtint_t incx = std::get<3>(GetParam()); + // stride size for y: + gtint_t incy = std::get<4>(GetParam()); + // alpha + T alpha = std::get<5>(GetParam()); + // beta + T beta = std::get<6>(GetParam()); + // is_memory_test + bool is_memory_test = std::get<7>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite axpbyv.h (no netlib version) for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + // Like SCALV + if (beta == testinghelpers::ZERO() || beta == testinghelpers::ONE()) + thresh = 0.0; + else + thresh = testinghelpers::getEpsilon(); + else if (beta == testinghelpers::ZERO()) + // Like SCAL2V + if (alpha == testinghelpers::ZERO() || alpha == testinghelpers::ONE()) + thresh = 0.0; + else + thresh = testinghelpers::getEpsilon(); + else if (beta == testinghelpers::ONE()) + // Like AXPYV + if (alpha == testinghelpers::ZERO()) + thresh = 0.0; + else + thresh = 2*testinghelpers::getEpsilon(); + else if (alpha == testinghelpers::ONE()) + thresh = 2*testinghelpers::getEpsilon(); + else + thresh = 3*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_axpbyv_ukr( ukr_fp, conj_x, n, incx, incy, alpha, beta, thresh, is_memory_test ); +} + +#if defined(BLIS_KERNELS_ZEN) && defined(GTEST_AVX2FMA3) +/* + Unit testing for functionality of bli_daxpbyv_zen_int10 kernel. + The code structure for bli_daxpbyv_zen_int10( ... ) is as follows : + For unit strides : + Main loop : In blocks of 40 --> L40 + Fringe loops : In blocks of 20 --> L20 + In blocks of 8 --> L8 + In blocks of 4 --> L4 + Element-wise loop --> LScalar + + For non-unit strides : A single loop, to process element wise. +*/ + +// Unit testing with unit stride, across all loops. +INSTANTIATE_TEST_SUITE_P( + bli_daxpbyv_zen_int10_unitStrides, + daxpbyvGeneric, + ::testing::Combine( + ::testing::Values(bli_daxpbyv_zen_int10), // kernel address + ::testing::Values('n'), // use x, not conj(x) (since it is real) + ::testing::Values(// Testing the loops standalone + gtint_t(40), // size n, for L40 + gtint_t(20), // L20 + gtint_t(8), // L8 + gtint_t(4), // L4 + gtint_t(2), // LScalar + // Testing the loops with combination + gtint_t(120), // 3*L40 + gtint_t(140), // 3*L40 + L20 + gtint_t(148), // 3*L40 + L20 + L8 + gtint_t(152), // 3*L40 + L20 + L8 + L4 + gtint_t(155)), // 3*L40 + L20 + L8 + L4 + 3(LScalar) + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(double(1.0), double(-1.0), + double(2.2), double(-4.1), + double(0.0)), // alpha + ::testing::Values(double(1.0), double(-1.0), + double(2.2), double(-4.1), + double(0.0)), // beta + ::testing::Values(false, true) // is_memory_test + ), + ((::axpbyvMemUKRPrint())) + ); + +// Unit testing for non unit strides +INSTANTIATE_TEST_SUITE_P( + bli_daxpbyv_zen_int10_nonUnitStrides, + daxpbyvGeneric, + ::testing::Combine( + ::testing::Values(bli_daxpbyv_zen_int10), // kernel address + ::testing::Values('n'), // use x, not conj(x) (since it is real) + ::testing::Values(gtint_t(10), // n, size of the vector + gtint_t(25)), + ::testing::Values(gtint_t(5)), // stride size for x + ::testing::Values(gtint_t(3)), // stride size for y + ::testing::Values(double(1.0), double(-1.0), + double(2.2), double(-4.1), + double(0.0)), // alpha + ::testing::Values(double(1.0), double(-1.0), + double(2.2), double(-4.1), + double(0.0)), // beta + ::testing::Values(false, true) // is_memory_test + ), + (::axpbyvMemUKRPrint()) + ); + +/* + Unit testing for functionality of bli_daxpbyv_zen_int kernel. + The code structure for bli_daxpbyv_zen_int( ... ) is as follows : + For unit strides : + Main loop : In blocks of 16 --> L16 + Element-wise loop --> LScalar + + For non-unit strides : A single loop, to process element wise. +*/ +// Unit testing with Unit Strides, across all loops. +INSTANTIATE_TEST_SUITE_P( + bli_daxpbyv_zen_int_unitStrides, + daxpbyvGeneric, + ::testing::Combine( + ::testing::Values(bli_daxpbyv_zen_int), // kernel address + ::testing::Values('n'), // use x, not conj(x) (since it is real) + ::testing::Values(gtint_t(16), // size n, for L16 + gtint_t(48), // 3*L16 + gtint_t(57)), // 3*L16 + 9(LScalar) + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(double(1.0), double(-1.0), + double(2.2), double(-4.1), + double(0.0)), // alpha + ::testing::Values(double(1.0), double(-1.0), + double(2.2), double(-4.1), + double(0.0)), // beta + ::testing::Values(false, true) // is_memory_test + ), + (::axpbyvMemUKRPrint()) + ); + +// Unit testing for Non-Unit Stride +INSTANTIATE_TEST_SUITE_P( + bli_daxpbyv_zen_int_nonUnitStrides, + daxpbyvGeneric, + ::testing::Combine( + ::testing::Values(bli_daxpbyv_zen_int), // kernel address + ::testing::Values('n'), // use x, not conj(x) (since it is real) + ::testing::Values(gtint_t(10), // n, size of the vector + gtint_t(25)), + ::testing::Values(gtint_t(5)), // stride size for x + ::testing::Values(gtint_t(3)), // stride size for y + ::testing::Values(double(1.0), double(-1.0), + double(2.2), double(-4.1), + double(0.0)), // alpha + ::testing::Values(double(1.0), double(-1.0), + double(2.2), double(-4.1), + double(0.0)), // beta + ::testing::Values(false, true) // is_memory_test + ), + (::axpbyvMemUKRPrint()) + ); +#endif + +#if defined(BLIS_KERNELS_ZEN4) && defined(GTEST_AVX512) +/* + Unit testing for functionality of bli_daxpbyv_zen_int_avx512 kernel. + The code structure for bli_daxpbyv_zen_int_avx512( ... ) is as follows : + For unit strides : + Main loop : In blocks of 64 --> L64 + Fringe loops : In blocks of 32 --> L32 + In blocks of 16 --> L16 + In blocks of 8 --> L8 + Element-wise loop --> LScalar + + For non-unit strides : A single loop, to process element wise. +*/ + +// Unit testing with unit stride, across all loops. +INSTANTIATE_TEST_SUITE_P( + bli_daxpbyv_zen_int_avx512_unitStrides, + daxpbyvGeneric, + ::testing::Combine( + ::testing::Values(bli_daxpbyv_zen_int_avx512), // kernel address + ::testing::Values('n'), // use x, not conj(x) (since it is real) + ::testing::Values(// Testing the loops standalone + gtint_t(64), // size n, for L64 + gtint_t(32), // L32 + gtint_t(16), // L16 + gtint_t(8), // L8 + gtint_t(7), // LScalar + gtint_t(191)), // 2*L64 + L32 + L16 + L8 + 7(LScalar) + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(double(1.0), double(-1.0), + double(2.2), double(-4.1), + double(0.0)), // alpha + ::testing::Values(double(1.0), double(-1.0), + double(2.2), double(-4.1), + double(0.0)), // beta + ::testing::Values(false, true) // is_memory_test + ), + ((::axpbyvMemUKRPrint())) + ); + +// Unit testing for non unit strides +INSTANTIATE_TEST_SUITE_P( + bli_daxpbyv_zen_int_avx512_nonUnitStrides, + daxpbyvGeneric, + ::testing::Combine( + ::testing::Values(bli_daxpbyv_zen_int_avx512), // kernel address + ::testing::Values('n'), // use x, not conj(x) (since it is real) + ::testing::Values(gtint_t(10), // n, size of the vector + gtint_t(25)), + ::testing::Values(gtint_t(5)), // stride size for x + ::testing::Values(gtint_t(3)), // stride size for y + ::testing::Values(double(1.0), double(-1.0), + double(2.2), double(-4.1), + double(0.0)), // alpha + ::testing::Values(double(1.0), double(-1.0), + double(2.2), double(-4.1), + double(0.0)), // beta + ::testing::Values(false, true) // is_memory_test + ), + (::axpbyvMemUKRPrint()) + ); +#endif diff --git a/gtestsuite/testsuite/ukr/axpbyv/saxpbyv_ukr.cpp b/gtestsuite/testsuite/ukr/axpbyv/saxpbyv_ukr.cpp new file mode 100644 index 0000000000..a0a5c38f15 --- /dev/null +++ b/gtestsuite/testsuite/ukr/axpbyv/saxpbyv_ukr.cpp @@ -0,0 +1,235 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_axpbyv_ukr.h" + +class saxpbyvGeneric : + public ::testing::TestWithParam> {}; // is_memory_test + + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(saxpbyvGeneric); + +// Tests using random integers as vector elements. +TEST_P( saxpbyvGeneric, UKR ) +{ + using T = float; + + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + + // Assign the kernel address to the function pointer + saxpbyv_ker_ft ukr_fp = std::get<0>(GetParam()); + // denotes whether x or conj(x) will be added to y: + char conj_x = std::get<1>(GetParam()); + // vector length: + gtint_t n = std::get<2>(GetParam()); + // stride size for x: + gtint_t incx = std::get<3>(GetParam()); + // stride size for y: + gtint_t incy = std::get<4>(GetParam()); + // alpha + T alpha = std::get<5>(GetParam()); + // beta + T beta = std::get<6>(GetParam()); + // is_memory_test + bool is_memory_test = std::get<7>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite axpbyv.h (no netlib version) for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + // Like SCALV + if (beta == testinghelpers::ZERO() || beta == testinghelpers::ONE()) + thresh = 0.0; + else + thresh = testinghelpers::getEpsilon(); + else if (beta == testinghelpers::ZERO()) + // Like SCAL2V + if (alpha == testinghelpers::ZERO() || alpha == testinghelpers::ONE()) + thresh = 0.0; + else + thresh = testinghelpers::getEpsilon(); + else if (beta == testinghelpers::ONE()) + // Like AXPYV + if (alpha == testinghelpers::ZERO()) + thresh = 0.0; + else + thresh = 2*testinghelpers::getEpsilon(); + else if (alpha == testinghelpers::ONE()) + thresh = 2*testinghelpers::getEpsilon(); + else + thresh = 3*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_axpbyv_ukr( ukr_fp, conj_x, n, incx, incy, alpha, beta, thresh, is_memory_test ); +} + +#if defined(BLIS_KERNELS_ZEN) && defined(GTEST_AVX2FMA3) +/* + Unit testing for functionality of bli_saxpbyv_zen_int10 kernel. + The code structure for bli_saxpbyv_zen_int10( ... ) is as follows : + For unit strides : + Main loop : In blocks of 80 --> L80 + Fringe loops : In blocks of 40 --> L40 + In blocks of 32 --> L32 + In blocks of 16 --> L16 + In blocks of 8 --> L8 + Element-wise loop --> LScalar + + For non-unit strides : A single loop, to process element wise. +*/ + +// Unit testing with unit stride, across all loops. +INSTANTIATE_TEST_SUITE_P( + bli_saxpbyv_zen_int10_unitStride, + saxpbyvGeneric, + ::testing::Combine( + ::testing::Values(bli_saxpbyv_zen_int10), // kernel address + ::testing::Values('n'), // use x, not conj(x) (since it is real) + ::testing::Values(// Testing the loops standalone + gtint_t(80), // size n, for L80 + gtint_t(40), // L40 + gtint_t(32), // L32 + gtint_t(16), // L16 + gtint_t(8), // L8 + gtint_t(7), // LScalar + // Testing the loops with combination + gtint_t(240), // 3*L80 + gtint_t(312), // 3*L80 + L40 + L32 + gtint_t(271)), // 3*L80 + L16 + L8 + 7(LScalar) + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(float(1.0), float(-1.0), + float(2.2), float(-4.1), + float(0.0)), // alpha + ::testing::Values(float(1.0), float(-1.0), + float(2.2), float(-4.1), + float(0.0)), // beta + ::testing::Values(false, true) // is_memory_test + ), + ((::axpbyvMemUKRPrint())) + ); + +// Unit testing for non unit strides +INSTANTIATE_TEST_SUITE_P( + bli_saxpbyv_zen_int_unitStride, + saxpbyvGeneric, + ::testing::Combine( + ::testing::Values(bli_saxpbyv_zen_int10), // kernel address + ::testing::Values('n'), // use x, not conj(x) (since it is real) + ::testing::Values(gtint_t(10), // n, size of the vector + gtint_t(25)), + ::testing::Values(gtint_t(5)), // stride size for x + ::testing::Values(gtint_t(3)), // stride size for y + ::testing::Values(float(1.0), float(-1.0), + float(2.2), float(-4.1), + float(0.0)), // alpha + ::testing::Values(float(1.0), float(-1.0), + float(2.2), float(-4.1), + float(0.0)), // beta + ::testing::Values(false, true) // is_memory_test + ), + (::axpbyvMemUKRPrint()) + ); + +/* + Unit testing for functionality of bli_saxpbyv_zen_int kernel. + The code structure for bli_saxpbyv_zen_int( ... ) is as follows : + For unit strides : + Main loop : In blocks of 32 --> L32 + Element-wise loop --> LScalar + + For non-unit strides : A single loop, to process element wise. +*/ +// Unit testing with Unit Strides, across all loops. +INSTANTIATE_TEST_SUITE_P( + bli_saxpbyv_zen_int_unitStrides, + saxpbyvGeneric, + ::testing::Combine( + ::testing::Values(bli_saxpbyv_zen_int), // kernel address + ::testing::Values('n'), // use x, not conj(x) (since it is real) + ::testing::Values(gtint_t(32), // size n, for L32 + gtint_t(96), // 3*L32 + gtint_t(111)), // 3*L32 + 15(LScalar) + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(float(1.0), float(-1.0), + float(2.2), float(-4.1), + float(0.0)), // alpha + ::testing::Values(float(1.0), float(-1.0), + float(2.2), float(-4.1), + float(0.0)), // beta + ::testing::Values(false, true) // is_memory_test + ), + (::axpbyvMemUKRPrint()) + ); + +// Unit testing for Non-Unit Stride +INSTANTIATE_TEST_SUITE_P( + bli_saxpbyv_zen_int_nonUnitStrides, + saxpbyvGeneric, + ::testing::Combine( + ::testing::Values(bli_saxpbyv_zen_int), // kernel address + ::testing::Values('n'), // use x, not conj(x) (since it is real) + ::testing::Values(gtint_t(10), // n, size of the vector + gtint_t(25)), + ::testing::Values(gtint_t(5)), // stride size for x + ::testing::Values(gtint_t(3)), // stride size for y + ::testing::Values(float(1.0), float(-1.0), + float(2.2), float(-4.1), + float(0.0)), // alpha + ::testing::Values(float(1.0), float(-1.0), + float(2.2), float(-4.1), + float(0.0)), // beta + ::testing::Values(false, true) // is_memory_test + ), + (::axpbyvMemUKRPrint()) + ); +#endif diff --git a/gtestsuite/testsuite/ukr/axpbyv/test_axpbyv_ukr.h b/gtestsuite/testsuite/ukr/axpbyv/test_axpbyv_ukr.h new file mode 100644 index 0000000000..b6eae7a8c4 --- /dev/null +++ b/gtestsuite/testsuite/ukr/axpbyv/test_axpbyv_ukr.h @@ -0,0 +1,176 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#pragma once + +#include +#include "level1/axpbyv/axpbyv.h" +#include "level1/ref_axpbyv.h" +#include "inc/check_error.h" +#include "common/testing_helpers.h" + +/** + * @brief Generic test body for axpby operation. + */ + +// The function is templatized based on the datatype and function-pointer type to the kernel. +template +static void test_axpbyv_ukr( FT ukr_fp, char conjx, gtint_t n, gtint_t incx, gtint_t incy, + T alpha, T beta, double thresh, bool is_memory_test = false ) +{ + // Pointers to obtain the required memory. + T *x, *y, *y_ref; + gtint_t size_x = testinghelpers::buff_dim( n, incx ) * sizeof( T ); + gtint_t size_y = testinghelpers::buff_dim( n, incy ) * sizeof( T ); + + // Create the objects for the input and output operands + // The kernel does not expect the memory to be aligned + testinghelpers::ProtectedBuffer x_buffer( size_x, false, is_memory_test ); + testinghelpers::ProtectedBuffer y_buffer( size_y, false, is_memory_test ); + + // For y_ref, we don't need different greenzones and any redzone. + // Thus, we pass is_memory_test as false + testinghelpers::ProtectedBuffer y_ref_buffer( size_y, false, false ); + + // Acquire the first set of greenzones for x and y + x = ( T* )x_buffer.greenzone_1; + y = ( T* )y_buffer.greenzone_1; + y_ref = ( T* )y_ref_buffer.greenzone_1; // For y_ref, there is no greenzone_2 + + // Initialize the memory with random data + testinghelpers::datagenerators::randomgenerators( -10, 10, n, incx, x ); + testinghelpers::datagenerators::randomgenerators( -10, 10, n, incy, y ); + + // Copying the contents of y to y_ref + memcpy( y_ref, y, size_y ); + + // Char conjx to BLIS conjx conversion + conj_t blis_conjx; + testinghelpers::char_to_blis_conj( conjx, &blis_conjx ); + + // Add signal handler for segmentation fault + testinghelpers::ProtectedBuffer::start_signal_handler(); + try + { + // Call the ukr function. + // This call is made irrespective of is_memory_test. + // This will check for out of bounds access with first redzone(if memory test is true) + // Else, it will just call the ukr function. + ukr_fp( blis_conjx, n, &alpha, x, incx, &beta, y, incy, nullptr ); + + if ( is_memory_test ) + { + // Acquire the pointers near the second redzone + x = ( T* )x_buffer.greenzone_2; + y = ( T* )y_buffer.greenzone_2; + + // Copy the data for x and y accordingly + memcpy( x, x_buffer.greenzone_1, size_x ); + memcpy( y, y_ref, size_y ); + + // Call the ukr function, to check with the second redzone. + ukr_fp( blis_conjx, n, &alpha, x, incx, &beta, y, incy, nullptr ); + } + } + catch(const std::exception& e) + { + // Reset to default signal handler + testinghelpers::ProtectedBuffer::stop_signal_handler(); + + // Show failure in case seg fault was detected + FAIL() << "Memory Test Failed"; + } + // Reset to default signal handler + testinghelpers::ProtectedBuffer::stop_signal_handler(); + + //---------------------------------------------------------- + // Call reference implementation to get ref results. + //---------------------------------------------------------- + testinghelpers::ref_axpbyv( conjx, n, alpha, x, incx, beta, y_ref, incy ); + + //---------------------------------------------------------- + // Compute component-wise error. + //---------------------------------------------------------- + computediff( "y", n, y, y_ref, incy, thresh ); +} + + +// Test-case logger : Used to print the test-case details for unit testing the kernels. +// NOTE : The kernel name is the prefix in instantiator name, and thus is not printed +// with this logger. +template +class axpbyvUKRPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char conjx = std::get<1>(str.param); + gtint_t n = std::get<2>(str.param); + gtint_t incx = std::get<3>(str.param); + gtint_t incy = std::get<4>(str.param); + T1 alpha = std::get<5>(str.param); + T1 beta = std::get<6>(str.param); + + std::string str_name = "_n_" + std::to_string(n); + str_name += "_conjx_" + std::string(&conjx, 1); + str_name += "_incx_" + testinghelpers::get_value_string(incx); + str_name += "_incy_" + testinghelpers::get_value_string(incy); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + str_name += "_beta_" + testinghelpers::get_value_string(beta); + return str_name; + } +}; + +template +class axpbyvMemUKRPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char conjx = std::get<1>(str.param); + gtint_t n = std::get<2>(str.param); + gtint_t incx = std::get<3>(str.param); + gtint_t incy = std::get<4>(str.param); + T1 alpha = std::get<5>(str.param); + T1 beta = std::get<6>(str.param); + bool is_memory_test = std::get<7>(str.param); + + std::string str_name = "_n_" + std::to_string(n); + str_name += "_conjx_" + std::string(&conjx, 1); + str_name += "_incx_" + testinghelpers::get_value_string(incx); + str_name += "_incy_" + testinghelpers::get_value_string(incy); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + str_name += "_beta_" + testinghelpers::get_value_string(beta); + str_name += ( is_memory_test ) ? "_mem_test_enabled" : "_mem_test_disabled"; + return str_name; + } +}; diff --git a/gtestsuite/testsuite/ukr/axpbyv/zaxpbyv_ukr.cpp b/gtestsuite/testsuite/ukr/axpbyv/zaxpbyv_ukr.cpp new file mode 100644 index 0000000000..9eb87cc6f7 --- /dev/null +++ b/gtestsuite/testsuite/ukr/axpbyv/zaxpbyv_ukr.cpp @@ -0,0 +1,190 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_axpbyv_ukr.h" + +class zaxpbyvGeneric : + public ::testing::TestWithParam> {}; // is_memory_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(zaxpbyvGeneric); + +// Tests using random integers as vector elements. +TEST_P( zaxpbyvGeneric, UKR ) +{ + using T = dcomplex; + + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + + // Assign the kernel address to the function pointer + zaxpbyv_ker_ft ukr_fp = std::get<0>(GetParam()); + // denotes whether x or conj(x) will be added to y: + char conj_x = std::get<1>(GetParam()); + // vector length: + gtint_t n = std::get<2>(GetParam()); + // stride size for x: + gtint_t incx = std::get<3>(GetParam()); + // stride size for y: + gtint_t incy = std::get<4>(GetParam()); + // alpha + T alpha = std::get<5>(GetParam()); + // beta + T beta = std::get<6>(GetParam()); + // is_memory_test + bool is_memory_test = std::get<7>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite axpbyv.h (no netlib version) for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + // Like SCALV + if (beta == testinghelpers::ZERO() || beta == testinghelpers::ONE()) + thresh = 0.0; + else + thresh = testinghelpers::getEpsilon(); + else if (beta == testinghelpers::ZERO()) + // Like SCAL2V + if (alpha == testinghelpers::ZERO() || alpha == testinghelpers::ONE()) + thresh = 0.0; + else + thresh = testinghelpers::getEpsilon(); + else if (beta == testinghelpers::ONE()) + // Like AXPYV + if (alpha == testinghelpers::ZERO()) + thresh = 0.0; + else + thresh = 2*testinghelpers::getEpsilon(); + else if (alpha == testinghelpers::ONE()) + thresh = 2*testinghelpers::getEpsilon(); + else + thresh = 3*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_axpbyv_ukr( ukr_fp, conj_x, n, incx, incy, alpha, beta, thresh, is_memory_test ); +} + +#if defined(BLIS_KERNELS_ZEN) && defined(GTEST_AVX2FMA3) +/* + Unit testing for functionality of bli_zaxpbyv_zen_int kernel. + The code structure for bli_zaxpbyv_zen_int( ... ) is as follows : + For unit strides : + Main loop : In blocks of 8 --> L8 + Fringe loops : In blocks of 6 --> L6 + In blocks of 4 --> L4 + In blocks of 2 --> L2 + Element-wise loop --> LScalar + + For non-unit strides : A single loop, to process element wise. +*/ + +INSTANTIATE_TEST_SUITE_P( + bli_zaxpbyv_zen_int_unitStrides, + zaxpbyvGeneric, + ::testing::Combine( + ::testing::Values(bli_zaxpbyv_zen_int), // kernel address + ::testing::Values('n' +#ifdef TEST_BLIS_TYPED + , 'c' // conjx +#endif + ), + ::testing::Values(// Testing the loops standalone + gtint_t(8), // size n, for L8 + gtint_t(6), // L6 + gtint_t(4), // L4 + gtint_t(2), // L2 + gtint_t(1), // L1 + gtint_t(56), // 7*L8 + gtint_t(62), // 7*L8 + L6 + gtint_t(60), // 7*L8 + L4 + gtint_t(58), // 7*L8 + L2 + gtint_t(57), // 7*L8 + 1(LScalar) + gtint_t(59), // 7*L8 + L2 + 1(LScalar) + gtint_t(61), // 7*L8 + L4 + 1(LScalar) + gtint_t(63)), // 7*L8 + L6 + 1(LScalar) + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, + dcomplex{0.0, 1.0}, dcomplex{0.0, -1.0}, + dcomplex{0.0, 0.0}, dcomplex{2.3, -3.7}), // alpha + ::testing::Values(dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, + dcomplex{0.0, 1.0}, dcomplex{0.0, -1.0}, + dcomplex{0.0, 0.0}, dcomplex{2.3, -3.7}), // beta + ::testing::Values(false, true) // is_memory_test + ), + (::axpbyvMemUKRPrint()) + + ); + +INSTANTIATE_TEST_SUITE_P( + bli_zaxpbyv_zen_int_nonUnitStrides, + zaxpbyvGeneric, + ::testing::Combine( + ::testing::Values(bli_zaxpbyv_zen_int), // kernel address + ::testing::Values('n' +#ifdef TEST_BLIS_TYPED + , 'c' // conjx +#endif + ), + ::testing::Values(gtint_t(10), // n, size of the vector + gtint_t(25)), + ::testing::Values(gtint_t(5)), // stride size for x + ::testing::Values(gtint_t(3)), // stride size for y + ::testing::Values(dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, + dcomplex{0.0, 1.0}, dcomplex{0.0, -1.0}, + dcomplex{0.0, 0.0}, dcomplex{2.3, -3.7}), // alpha + ::testing::Values(dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, + dcomplex{0.0, 1.0}, dcomplex{0.0, -1.0}, + dcomplex{0.0, 0.0}, dcomplex{2.3, -3.7}), // beta + ::testing::Values(false, true) // is_memory_test + ), + (::axpbyvMemUKRPrint()) + ); +#endif diff --git a/gtestsuite/testsuite/ukr/axpyf/daxpyf_ukr.cpp b/gtestsuite/testsuite/ukr/axpyf/daxpyf_ukr.cpp new file mode 100644 index 0000000000..a9a3f9db2d --- /dev/null +++ b/gtestsuite/testsuite/ukr/axpyf/daxpyf_ukr.cpp @@ -0,0 +1,192 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + Portions of this file consist of AI-generated content. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_axpyf_ukr.h" + +using T = double; +using FT = daxpyf_ker_ft; + +class daxpyfGeneric : + public ::testing::TestWithParam> {}; // is_memory_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(daxpyfGeneric); + +// Tests using random integers as vector elements. +TEST_P( daxpyfGeneric, UKR ) +{ + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + + // Assign the kernel address to the function pointer + FT ukr_fp = std::get<0>(GetParam()); + // denotes conjugate for A + char conjA = std::get<1>(GetParam()); + // denotes conjugate for x + char conjx = std::get<2>(GetParam()); + // rows of matrix + gtint_t m = std::get<3>(GetParam()); + // fuse factor + gtint_t b_fuse = std::get<4>(GetParam()); + // alpha + T alpha = std::get<5>(GetParam()); + // stride size for A + gtint_t inca = std::get<6>(GetParam()); + // lda_inc for A + gtint_t lda_inc = std::get<7>(GetParam()); + // stride size for x + gtint_t incx = std::get<8>(GetParam()); + // stride size for y + gtint_t incy = std::get<9>(GetParam()); + // is_memory_test + bool is_memory_test = std::get<10>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite axpyf.h (no netlib version) for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (m == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + thresh = 0.0; + else if (alpha == testinghelpers::ONE()) + { + // Threshold adjustment +#ifdef BLIS_INT_ELEMENT_TYPE + double adj = 1.0; +#else + double adj = 4.0; +#endif + + thresh = adj*(2*b_fuse)*testinghelpers::getEpsilon(); + } + else + { + // Threshold adjustment +#ifdef BLIS_INT_ELEMENT_TYPE + double adj = 2.0; +#else + double adj = 4.7; +#endif + thresh = adj*(3*b_fuse)*testinghelpers::getEpsilon(); + } + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_axpyf_ukr( ukr_fp, conjA, conjx, m, b_fuse, alpha, inca, lda_inc, incx, incy, thresh, is_memory_test ); +} + +#if defined(BLIS_KERNELS_ZEN4) && defined(GTEST_AVX512) +/* + Unit testing for functionality of bli_daxpyf_zen_int_avx512 kernel. +*/ +// Unit testing with unit strides, across all fuse-factors. +INSTANTIATE_TEST_SUITE_P( + bli_daxpyf_zen_int_avx512_unitStrides, + daxpyfGeneric, + ::testing::Combine( + ::testing::Values(bli_daxpyf_zen_int_avx512), // kernel address + ::testing::Values('n'), // use x, not conj(x) (since it is real) + ::testing::Values('n'), // use x, not conj(x) (since it is real) + ::testing::Values(gtint_t(1), + gtint_t(3), + gtint_t(5), + gtint_t(8), + gtint_t(16), + gtint_t(32), + gtint_t(55)), + ::testing::Values(// b_fuse + gtint_t(2), // bli_daxpyf_zen_int2_avx512 + gtint_t(4), // bli_daxpyf_zen_int4_avx512 + gtint_t(6), // bli_daxpyf_zen_int6_avx512 + gtint_t(8), // bli_daxpyf_zen_int8_avx512 + gtint_t(12), // bli_daxpyf_zen_int12_avx512 + gtint_t(16), // bli_daxpyf_zen_int16_avx512 + gtint_t(32), // bli_daxpyf_zen_int32_avx512 + gtint_t(30), // Combination of fuse factors 16, 8, 6 + gtint_t(28), // Combination of fuse factors 16, 8, 4 + gtint_t(26) // Combination of fuse factors 16, 8, 2 + ), + ::testing::Values( -2.1, -1.0, 0.0, 1.0, 2.1 ), // alpha + ::testing::Values(gtint_t(1)), // inca + ::testing::Values(gtint_t(0), gtint_t(1)), // lda_inc + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(false, true) // is_memory_test + ), + (::axpyfUkrPrint()) + ); + +// Unit testing with non-unit strides, across all fuse-factors. +INSTANTIATE_TEST_SUITE_P( + bli_daxpyf_zen_int_avx512_nonUnitStrides, + daxpyfGeneric, + ::testing::Combine( + ::testing::Values(bli_daxpyf_zen_int_avx512), // kernel address + ::testing::Values('n'), // use x, not conj(x) (since it is real) + ::testing::Values('n'), // use x, not conj(x) (since it is real) + ::testing::Values(gtint_t(15), gtint_t(27)), // for size n + ::testing::Values(// b_fuse + gtint_t(2), // bli_daxpyf_zen_int2_avx512 + gtint_t(4), // bli_daxpyf_zen_int4_avx512 + gtint_t(6), // bli_daxpyf_zen_int6_avx512 + gtint_t(8), // bli_daxpyf_zen_int8_avx512 + gtint_t(16), // bli_daxpyf_zen_int16_avx512 + gtint_t(32) // bli_daxpyf_zen_int32_avx512 + ), + ::testing::Values( -2.1, 0.0, 1.0, 2.1 ), // alpha + ::testing::Values(gtint_t(2)), // inca + ::testing::Values(gtint_t(3)), // lda_inc + ::testing::Values(gtint_t(2)), // stride size for x + ::testing::Values(gtint_t(3)), // stride size for y + ::testing::Values(false, true) // is_memory_test + ), + (::axpyfUkrPrint()) + ); +#endif diff --git a/gtestsuite/testsuite/ukr/axpyf/test_axpyf_ukr.h b/gtestsuite/testsuite/ukr/axpyf/test_axpyf_ukr.h new file mode 100644 index 0000000000..e33026136d --- /dev/null +++ b/gtestsuite/testsuite/ukr/axpyf/test_axpyf_ukr.h @@ -0,0 +1,191 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#pragma once + +#include +#include "level1/axpyf/axpyf.h" +#include "level1/ref_axpyf.h" +#include "inc/check_error.h" +#include "common/testing_helpers.h" + +/** + * @brief Generic test body for axpby operation. + */ + +// The function is templatized based on the datatype and function-pointer type to the kernel. +template +static void test_axpyf_ukr( FT ukr_fp, char conjA, char conjx, gtint_t m, gtint_t b_fuse, + T alpha, gtint_t inca, gtint_t lda_inc, gtint_t incx, gtint_t incy, + double thresh, bool is_memory_test = false ) +{ + // Pointers to obtain the required memory. + T *A, *x, *y, *y_ref; + + // Compute the leading dimensions of A matrix. + gtint_t lda = testinghelpers::get_leading_dimension( 'c', 'n', m, b_fuse, lda_inc, inca ); + + // Compute the sizes required to allocate memory for the operands + gtint_t size_A = lda * b_fuse * sizeof( T ); + gtint_t size_x = testinghelpers::buff_dim( b_fuse, incx ) * sizeof( T ); + gtint_t size_y = testinghelpers::buff_dim( m, incy ) * sizeof( T ); + + // Create the objects for the input and output operands + // The kernel does not expect the memory to be aligned + testinghelpers::ProtectedBuffer A_buffer( size_A, false, false ); + testinghelpers::ProtectedBuffer x_buffer( size_x, false, is_memory_test ); + testinghelpers::ProtectedBuffer y_buffer( size_y, false, is_memory_test ); + + // For y_ref, we don't need different greenzones and any redzone. + // Thus, we pass is_memory_test as false + testinghelpers::ProtectedBuffer y_ref_buffer( size_y, false, false ); + + // Acquire the first set of greenzones for A, x and y + A = ( T* )A_buffer.greenzone_1; + x = ( T* )x_buffer.greenzone_1; + y = ( T* )y_buffer.greenzone_1; + y_ref = ( T* )y_ref_buffer.greenzone_1; // For y_ref, there is no greenzone_2 + + // Initialize the memory with random data + testinghelpers::datagenerators::randomgenerators( -2, 8, 'c', m, b_fuse, A, 'n', lda ); + testinghelpers::datagenerators::randomgenerators( -10, 10, b_fuse, incx, x ); + testinghelpers::datagenerators::randomgenerators( -10, 10, m, incy, y ); + + // Copying the contents of y to y_ref + memcpy( y_ref, y, size_y ); + + // Char conjA and conjx to BLIS conjA and conjx conversion + conj_t blis_conjA, blis_conjx; + testinghelpers::char_to_blis_conj( conjA, &blis_conjA ); + testinghelpers::char_to_blis_conj( conjx, &blis_conjx ); + + // Add signal handler for segmentation fault + testinghelpers::ProtectedBuffer::start_signal_handler(); + try + { + // Call the ukr function. + // This call is made irrespective of is_memory_test. + // This will check for out of bounds access with first redzone(if memory test is true) + // Else, it will just call the ukr function. + ukr_fp + ( + blis_conjA, blis_conjx, + m, b_fuse, &alpha, + A, inca, lda, x, incx, + y, incy, nullptr + ); + + if ( is_memory_test ) + { + // Acquire the pointers near the second redzone + A = ( T* )A_buffer.greenzone_2; + x = ( T* )x_buffer.greenzone_2; + y = ( T* )y_buffer.greenzone_2; + + // Copy the data for A, x and y accordingly + memcpy( A, A_buffer.greenzone_1, size_A ); + memcpy( x, x_buffer.greenzone_1, size_x ); + memcpy( y, y_ref, size_y ); + + // Call the ukr function, to check with the second redzone. + ukr_fp + ( + blis_conjA, blis_conjx, + m, b_fuse, &alpha, + A, inca, lda, x, incx, + y, incy, nullptr + ); + } + } + catch(const std::exception& e) + { + // Reset to default signal handler + testinghelpers::ProtectedBuffer::stop_signal_handler(); + + // Show failure in case seg fault was detected + FAIL() << "Memory Test Failed"; + } + // Reset to default signal handler + testinghelpers::ProtectedBuffer::stop_signal_handler(); + + //---------------------------------------------------------- + // Call reference implementation to get ref results. + //---------------------------------------------------------- + testinghelpers::ref_axpyf + ( + conjA, conjx, m, b_fuse, + &alpha, A, inca, lda, + x, incx, y_ref, incy + ); + + //---------------------------------------------------------- + // Compute component-wise error. + //---------------------------------------------------------- + computediff( "y", m, y, y_ref, incy, thresh ); +} + +// Test-case logger : Used to print the test-case details for unit testing the kernels. +// NOTE : The kernel name is the prefix in instantiator name, and thus is not printed +// with this logger. +template +class axpyfUkrPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char conjA = std::get<1>(str.param); + char conjx = std::get<2>(str.param); + gtint_t m = std::get<3>(str.param); + gtint_t b_fuse = std::get<4>(str.param); + T alpha = std::get<5>(str.param); + gtint_t inca = std::get<6>(str.param); + gtint_t lda_inc = std::get<7>(str.param); + gtint_t incx = std::get<8>(str.param); + gtint_t incy = std::get<9>(str.param); + bool is_memory_test = std::get<10>(str.param); + + std::string str_name = ""; + str_name += "_m_" + std::to_string(m); + str_name += "_bf_" + std::to_string(b_fuse); + str_name += "_conja_" + std::string(&conjA, 1); + str_name += "_conjx_" + std::string(&conjx, 1); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + str_name += "_inca_" + testinghelpers::get_value_string(inca); + gtint_t lda = testinghelpers::get_leading_dimension( 'c', 'n', m, b_fuse, lda_inc, inca ); + str_name += "_lda_i" + testinghelpers::get_value_string(lda_inc) + "_" + std::to_string(lda);; + str_name += "_incx_" + testinghelpers::get_value_string(incx); + str_name += "_incy_" + testinghelpers::get_value_string(incy); + str_name += ( is_memory_test ) ? "_mem_test_enabled" : "_mem_test_disabled"; + return str_name; + } +}; diff --git a/gtestsuite/testsuite/ukr/axpyf/zaxpyf_ukr.cpp b/gtestsuite/testsuite/ukr/axpyf/zaxpyf_ukr.cpp new file mode 100644 index 0000000000..da932f5e07 --- /dev/null +++ b/gtestsuite/testsuite/ukr/axpyf/zaxpyf_ukr.cpp @@ -0,0 +1,321 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + Portions of this file consist of AI-generated content. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_axpyf_ukr.h" + +using T = dcomplex; +using FT = zaxpyf_ker_ft; + +class zaxpyfGeneric : + public ::testing::TestWithParam> {}; // is_memory_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(zaxpyfGeneric); + +// Tests using random integers as vector elements. +TEST_P( zaxpyfGeneric, UKR ) +{ + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + + // Assign the kernel address to the function pointer + FT ukr_fp = std::get<0>(GetParam()); + // denotes conjugate for A + char conjA = std::get<1>(GetParam()); + // denotes conjugate for x + char conjx = std::get<2>(GetParam()); + // rows of matrix + gtint_t m = std::get<3>(GetParam()); + // fuse factor + gtint_t b_fuse = std::get<4>(GetParam()); + // alpha + T alpha = std::get<5>(GetParam()); + // stride size for A + gtint_t inca = std::get<6>(GetParam()); + // lda_inc for A + gtint_t lda_inc = std::get<7>(GetParam()); + // stride size for x + gtint_t incx = std::get<8>(GetParam()); + // stride size for y + gtint_t incy = std::get<9>(GetParam()); + // is_memory_test + bool is_memory_test = std::get<10>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite axpyf.h (no netlib version) for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + + // NOTE : Each multiplication of dcomplex elements results in three + // ops(two muls and 1 add) for real and imag part of the result. + double thresh; + if (m == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + thresh = 0.0; + else if (alpha == testinghelpers::ONE()) + thresh = (4*b_fuse)*testinghelpers::getEpsilon(); + else + thresh = (7*b_fuse)*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_axpyf_ukr( ukr_fp, conjA, conjx, m, b_fuse, alpha, inca, lda_inc, incx, incy, thresh, is_memory_test ); +} + +#if defined(BLIS_KERNELS_ZEN4) && defined(GTEST_AVX512) +/* + Unit testing for functionality of bli_zaxpyf_zen_int_2_avx512 kernel. + The code structure for bli_zaxpyf_zen_int_2_avx512( ... ) is as follows : + For unit strides : + Main loop : In blocks of 8 --> L8 + Fringe loops : In blocks of 4 --> L4 + Masked loop ---> LScalar + + For non-unit strides : A single loop, to process element wise. +*/ +// Unit testing with unit strides, across all loops. +INSTANTIATE_TEST_SUITE_P( + bli_zaxpyf_zen_int_2_avx512_unitStrides, + zaxpyfGeneric, + ::testing::Combine( + ::testing::Values(bli_zaxpyf_zen_int_8_avx512), // kernel address + ::testing::Values('n' +#if defined(TEST_BLIS_TYPED) + ,'c' +#endif + ), // conjA + ::testing::Values('n', 'c'), // conjx + ::testing::Values(// Testing the loops standalone + gtint_t(8), // for size n, L8 + gtint_t(4), // L4 + gtint_t(3), // LScalar + gtint_t(24), // 3*L8 + gtint_t(28), // 3*L8 + L4 + gtint_t(31)), // 3*L8 + L4 + LScalar + ::testing::Values(gtint_t(2)), // b_fuse + ::testing::Values(dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, + dcomplex{0.0, 1.0}, dcomplex{0.0, -1.0}, + dcomplex{0.0, -3.3}, dcomplex{4.3,-2.1}, + dcomplex{0.0, 0.0}), // alpha + ::testing::Values(gtint_t(1)), // inca + ::testing::Values(gtint_t(1)), // lda_inc + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(false, true) // is_memory_test + ), + (::axpyfUkrPrint()) + ); + +// Unit testing with non-unit strides, across all loops. +INSTANTIATE_TEST_SUITE_P( + bli_zaxpyf_zen_int_2_avx512_nonUnitStrides, + zaxpyfGeneric, + ::testing::Combine( + ::testing::Values(bli_zaxpyf_zen_int_8_avx512), // kernel address + ::testing::Values('n' +#if defined(TEST_BLIS_TYPED) + ,'c' +#endif + ), // conjA + ::testing::Values('n', 'c'), // conjx + ::testing::Values(gtint_t(15), gtint_t(27)), // for size n + ::testing::Values(gtint_t(2)), // b_fuse + ::testing::Values(dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, + dcomplex{0.0, 1.0}, dcomplex{0.0, -1.0}, + dcomplex{0.0, -3.3}, dcomplex{4.3,-2.1}, + dcomplex{0.0, 0.0}), // alpha + ::testing::Values(gtint_t(2)), // inca + ::testing::Values(gtint_t(3)), // lda_inc + ::testing::Values(gtint_t(2)), // stride size for x + ::testing::Values(gtint_t(3)), // stride size for y + ::testing::Values(false, true) // is_memory_test + ), + (::axpyfUkrPrint()) + ); + +/* + Unit testing for functionality of bli_zaxpyf_zen_int_4_avx512 kernel. + The code structure for bli_zaxpyf_zen_int_4_avx512( ... ) is as follows : + For unit strides : + Main loop : In blocks of 8 --> L8 + Fringe loops : In blocks of 4 --> L4 + Masked loop ---> LScalar + + For non-unit strides : A single loop, to process element wise. +*/ +// Unit testing with unit strides, across all loops. +INSTANTIATE_TEST_SUITE_P( + bli_zaxpyf_zen_int_4_avx512_unitStrides, + zaxpyfGeneric, + ::testing::Combine( + ::testing::Values(bli_zaxpyf_zen_int_8_avx512), // kernel address + ::testing::Values('n' +#if defined(TEST_BLIS_TYPED) + ,'c' +#endif + ), // conjA + ::testing::Values('n', 'c'), // conjx + ::testing::Values(// Testing the loops standalone + gtint_t(8), // for size n, L8 + gtint_t(4), // L4 + gtint_t(3), // LScalar + gtint_t(24), // 3*L8 + gtint_t(28), // 3*L8 + L4 + gtint_t(31)), // 3*L8 + L4 + LScalar + ::testing::Values(gtint_t(4)), // b_fuse + ::testing::Values(dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, + dcomplex{0.0, 1.0}, dcomplex{0.0, -1.0}, + dcomplex{0.0, -3.3}, dcomplex{4.3,-2.1}, + dcomplex{0.0, 0.0}), // alpha + ::testing::Values(gtint_t(1)), // inca + ::testing::Values(gtint_t(1)), // lda_inc + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(false, true) // is_memory_test + ), + (::axpyfUkrPrint()) + ); + +// Unit testing with non-unit strides, across all loops. +INSTANTIATE_TEST_SUITE_P( + bli_zaxpyf_zen_int_4_avx512_nonUnitStrides, + zaxpyfGeneric, + ::testing::Combine( + ::testing::Values(bli_zaxpyf_zen_int_8_avx512), // kernel address + ::testing::Values('n' +#if defined(TEST_BLIS_TYPED) + ,'c' +#endif + ), // conjA + ::testing::Values('n', 'c'), // conjx + ::testing::Values(gtint_t(15), gtint_t(27)), // for size n + ::testing::Values(gtint_t(4)), // b_fuse + ::testing::Values(dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, + dcomplex{0.0, 1.0}, dcomplex{0.0, -1.0}, + dcomplex{0.0, -3.3}, dcomplex{4.3,-2.1}, + dcomplex{0.0, 0.0}), // alpha + ::testing::Values(gtint_t(2)), // inca + ::testing::Values(gtint_t(3)), // lda_inc + ::testing::Values(gtint_t(2)), // stride size for x + ::testing::Values(gtint_t(3)), // stride size for y + ::testing::Values(false, true) // is_memory_test + ), + (::axpyfUkrPrint()) + ); + +/* + Unit testing for functionality of bli_zaxpyf_zen_int_8_avx512 kernel. + The code structure for bli_zaxpyf_zen_int_8_avx512( ... ) is as follows : + For unit strides : + Main loop : In blocks of 8 --> L8 + Fringe loops : In blocks of 4 --> L4 + Masked loop ---> LScalar + + For non-unit strides : A single loop, to process element wise. +*/ +// Unit testing with unit strides, across all loops. +INSTANTIATE_TEST_SUITE_P( + bli_zaxpyf_zen_int_8_avx512_unitStrides, + zaxpyfGeneric, + ::testing::Combine( + ::testing::Values(bli_zaxpyf_zen_int_8_avx512), // kernel address + ::testing::Values('n' +#if defined(TEST_BLIS_TYPED) + ,'c' +#endif + ), // conjA + ::testing::Values('n', 'c'), // conjx + ::testing::Values(// Testing the loops standalone + gtint_t(8), // for size n, L8 + gtint_t(4), // L4 + gtint_t(3), // LScalar + gtint_t(24), // 3*L8 + gtint_t(28), // 3*L8 + L4 + gtint_t(31)), // 3*L8 + L4 + LScalar + ::testing::Values(gtint_t(8)), // b_fuse + ::testing::Values(dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, + dcomplex{0.0, 1.0}, dcomplex{0.0, -1.0}, + dcomplex{0.0, -3.3}, dcomplex{4.3,-2.1}, + dcomplex{0.0, 0.0}), // alpha + ::testing::Values(gtint_t(1)), // inca + ::testing::Values(gtint_t(1)), // lda_inc + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(false, true) // is_memory_test + ), + (::axpyfUkrPrint()) + ); + +// Unit testing with non-unit strides, across all loops. +INSTANTIATE_TEST_SUITE_P( + bli_zaxpyf_zen_int_8_avx512_nonUnitStrides, + zaxpyfGeneric, + ::testing::Combine( + ::testing::Values(bli_zaxpyf_zen_int_8_avx512), // kernel address + ::testing::Values('n' +#if defined(TEST_BLIS_TYPED) + ,'c' +#endif + ), // conjA + ::testing::Values('n', 'c'), // conjx + ::testing::Values(gtint_t(15), gtint_t(27)), // for size n + ::testing::Values(gtint_t(8)), // b_fuse + ::testing::Values(dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, + dcomplex{0.0, 1.0}, dcomplex{0.0, -1.0}, + dcomplex{0.0, -3.3}, dcomplex{4.3,-2.1}, + dcomplex{0.0, 0.0}), // alpha + ::testing::Values(gtint_t(2)), // inca + ::testing::Values(gtint_t(3)), // lda_inc + ::testing::Values(gtint_t(2)), // stride size for x + ::testing::Values(gtint_t(3)), // stride size for y + ::testing::Values(false, true) // is_memory_test + ), + (::axpyfUkrPrint()) + ); +#endif diff --git a/gtestsuite/testsuite/ukr/axpyv/caxpyv_ukr.cpp b/gtestsuite/testsuite/ukr/axpyv/caxpyv_ukr.cpp new file mode 100644 index 0000000000..9fcb7dc387 --- /dev/null +++ b/gtestsuite/testsuite/ukr/axpyv/caxpyv_ukr.cpp @@ -0,0 +1,160 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + Portions of this file consist of AI-generated content. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_axpyv_ukr.h" + +class caxpyvGeneric : + public ::testing::TestWithParam> {}; // is_memory_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(caxpyvGeneric); + +// Tests using random integers as vector elements. +TEST_P( caxpyvGeneric, UKR ) +{ + using T = scomplex; + + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + + // Assign the kernel address to the function pointer + caxpyv_ker_ft ukr_fp = std::get<0>(GetParam()); + // denotes whether x or conj(x) will be added to y: + char conj_x = std::get<1>(GetParam()); + // vector length + gtint_t n = std::get<2>(GetParam()); + // stride size for x + gtint_t incx = std::get<3>(GetParam()); + // stride size for y + gtint_t incy = std::get<4>(GetParam()); + // alpha + T alpha = std::get<5>(GetParam()); + // is_memory_test + bool is_memory_test = std::get<6>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite axpbyv.h (no netlib version) for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO() || alpha == testinghelpers::ONE()) + thresh = 0.0; + else + thresh = 2*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_axpyv_ukr( ukr_fp, conj_x, n, incx, incy, alpha, thresh, is_memory_test ); +} + +#if defined(BLIS_KERNELS_ZEN) && defined(GTEST_AVX2FMA3) +/* + Unit testing for functionality of bli_caxpyv_zen_int5 kernel. + The code structure for bli_caxpyv_zen_int5( ... ) is as follows : + For unit strides : + Main loop : In blocks of 20 --> L20 + Fringe loops : In blocks of 8 --> L8 + In blocks of 4 --> L4 + Element-wise loop --> LScalar + + For non-unit strides : A single loop, to process element wise. +*/ +// Unit testing with unit strides, across all loops. +INSTANTIATE_TEST_SUITE_P( + bli_caxpyv_zen_int5_unitStrides, + caxpyvGeneric, + ::testing::Combine( + ::testing::Values(bli_caxpyv_zen_int5), // kernel address + ::testing::Values('n' +#ifdef TEST_BLIS_TYPED + , 'c' // conjx +#endif + ), + ::testing::Values(// Testing the loops standalone + gtint_t(20), // size n, for L20 + gtint_t(8), // L8 + gtint_t(4), // L4 + gtint_t(3), // LScalar + // Testing the loops with combination + gtint_t(60), // 3*L20 + gtint_t(68), // 3*L20 + L8 + gtint_t(67)), // 3*L20 + L4 + LScalar + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(scomplex{1.0, 0.0}, scomplex{-1.0, 0.0}, + scomplex{0.0, 1.0}, scomplex{0.0, -1.0}, + scomplex{0.0, -3.3}, scomplex{4.3,-2.1}, + scomplex{0.0, 0.0}), // alpha + ::testing::Values(false, true) // is_memory_test + ), + (::axpyvUKRPrint()) + ); + +// Unit testing for non unit strides +INSTANTIATE_TEST_SUITE_P( + bli_caxpyv_zen_int5_nonUnitStrides, + caxpyvGeneric, + ::testing::Combine( + ::testing::Values(bli_caxpyv_zen_int5), // kernel address + ::testing::Values('n' +#ifdef TEST_BLIS_TYPED + , 'c' // conjx +#endif + ), + ::testing::Values(gtint_t(2)), // n, size of the vector + ::testing::Values(gtint_t(5)), // stride size for x + ::testing::Values(gtint_t(3)), // stride size for y + ::testing::Values(scomplex{1.0, 0.0}, scomplex{-1.0, 0.0}, + scomplex{0.0, 1.0}, scomplex{0.0, -1.0}, + scomplex{0.0, -3.3}, scomplex{4.3,-2.1}, + scomplex{0.0, 0.0}), // alpha + ::testing::Values(false, true) // is_memory_test + ), + (::axpyvUKRPrint()) + ); + +#endif diff --git a/gtestsuite/testsuite/ukr/axpyv/daxpyv_ukr.cpp b/gtestsuite/testsuite/ukr/axpyv/daxpyv_ukr.cpp new file mode 100644 index 0000000000..6e5832c767 --- /dev/null +++ b/gtestsuite/testsuite/ukr/axpyv/daxpyv_ukr.cpp @@ -0,0 +1,271 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_axpyv_ukr.h" + +class daxpyvGeneric : + public ::testing::TestWithParam> {}; // is_memory_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(daxpyvGeneric); + +// Defining the testsuite to check the accuracy of daxpyv micro-kernels +TEST_P( daxpyvGeneric, UKR ) +{ + using T = double; + + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + + // Assign the kernel address to the function pointer + daxpyv_ker_ft ukr_fp = std::get<0>(GetParam()); + // denotes whether x or conj(x) will be added to y: + char conj_x = std::get<1>(GetParam()); + // vector length: + gtint_t n = std::get<2>(GetParam()); + // stride size for x: + gtint_t incx = std::get<3>(GetParam()); + // stride size for y: + gtint_t incy = std::get<4>(GetParam()); + // alpha + T alpha = std::get<5>(GetParam()); + // is_memory_test + bool is_memory_test = std::get<6>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite axpbyv.h (no netlib version) for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO() || alpha == testinghelpers::ONE()) + thresh = 0.0; + else + thresh = 2*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_axpyv_ukr( ukr_fp, conj_x, n, incx, incy, alpha, thresh, is_memory_test ); +} + +#if defined(BLIS_KERNELS_ZEN) && defined(GTEST_AVX2FMA3) +/* + Unit testing for functionality of bli_daxpyv_zen_int10 kernel. + The code structure for bli_daxpyv_zen_int10( ... ) is as follows : + For unit strides : + Main loop : In blocks of 52 --> L52 + Fringe loops : In blocks of 40 --> L40 + In blocks of 20 --> L20 + In blocks of 16 --> L16 + In blocks of 8 --> L8 + In blocks of 4 --> L4 + Element-wise loop --> LScalar + + For non-unit strides : A single loop, to process element wise. +*/ +// Unit testing with unit strides, across all loops. +INSTANTIATE_TEST_SUITE_P( + bli_daxpyv_zen_int10_unitStrides, + daxpyvGeneric, + ::testing::Combine( + ::testing::Values(bli_daxpyv_zen_int10), // kernel address + ::testing::Values('n'), // use x, not conj(x) (since it is real) + ::testing::Values(// Testing the loops standalone + gtint_t(52), // size n, for L52 + gtint_t(40), // L40 + gtint_t(20), // L20 + gtint_t(16), // L16 + gtint_t(8), // L8 + gtint_t(4), // L4 + gtint_t(2), // LScalar + // Testing the loops with combination + gtint_t(156), // 3*L52 + gtint_t(196), // 3*L52 + L40 + gtint_t(204), // 3*L52 + L40 + L8 + gtint_t(203), // 3*L52 + L40 + L4 + 3(LScalar) + gtint_t(176), // 3*L52 + L20 + gtint_t(192), // 3*L52 + L20 + L16 + gtint_t(191)), // 3*L52 + L20 + L8 + L4 + 3(LScalar) + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(double(1.0), double(-1.0), + double(2.2), double(-4.1), + double(0.0)), // alpha + ::testing::Values(false, true) // is_memory_test + ), + (::axpyvUKRPrint()) + ); + +// Unit testing for non unit strides +INSTANTIATE_TEST_SUITE_P( + bli_daxpyv_zen_int10_nonUnitStrides, + daxpyvGeneric, + ::testing::Combine( + ::testing::Values(bli_daxpyv_zen_int10), // kernel address + ::testing::Values('n'), // use x, not conj(x) (since it is real) + ::testing::Values(gtint_t(10), // n, size of the vector + gtint_t(25)), + ::testing::Values(gtint_t(5)), // stride size for x + ::testing::Values(gtint_t(3)), // stride size for y + ::testing::Values(double(1.0), double(-1.0), + double(2.2), double(-4.1), + double(0.0)), // alpha + ::testing::Values(false, true) // is_memory_test + ), + (::axpyvUKRPrint()) + ); + +/* + Unit testing for functionality of bli_daxpyv_zen_int kernel. + The code structure for bli_daxpyv_zen_int10( ... ) is as follows : + For unit strides : + Main loop : In blocks of 16 --> L16 + Element wise loop post all these loops. + + For non-unit strides : A single loop, to process element wise. +*/ +// Unit testing with unit strides, across all loops. +INSTANTIATE_TEST_SUITE_P( + bli_daxpyv_zen_int_unitStrides, + daxpyvGeneric, + ::testing::Combine( + ::testing::Values(bli_daxpyv_zen_int), // kernel address + ::testing::Values('n'), // use x, not conj(x) (since it is real) + ::testing::Values(gtint_t(16), // size n, for L16 + gtint_t(48), // 3*L16 + gtint_t(89)), // 5*L16 + 9(scalar) + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(double(1.0), double(-1.0), + double(2.2), double(-4.1), + double(0.0)), // alpha + ::testing::Values(false, true) // is_memory_test + ), + (::axpyvUKRPrint()) + ); + +// Unit testing for non unit strides +INSTANTIATE_TEST_SUITE_P( + bli_daxpyv_zen_int_nonUnitStrides, + daxpyvGeneric, + ::testing::Combine( + ::testing::Values(bli_daxpyv_zen_int), // kernel address + ::testing::Values('n'), // use x, not conj(x) (since it is real) + ::testing::Values(gtint_t(10), // n, size of the vector + gtint_t(25)), + ::testing::Values(gtint_t(5)), // stride size for x + ::testing::Values(gtint_t(3)), // stride size for y + ::testing::Values(double(1.0), double(-1.0), + double(2.2), double(-4.1), + double(0.0)), // alpha + ::testing::Values(false, true) // is_memory_test + ), + (::axpyvUKRPrint()) + ); +#endif + +#if defined(BLIS_KERNELS_ZEN4) && defined(GTEST_AVX512) +/* + Unit testing for functionality of bli_daxpyv_zen_int_avx512 kernel. + The code structure for bli_daxpyv_zen_int_avx512( ... ) is as follows : + For unit strides : + Main loop : In blocks of 64 --> L64 + Fringe loops : In blocks of 32 --> L32 + In blocks of 16 --> L16 + In blocks of 8 --> L8 + In blocks of 4 --> L4 + Element-wise loop --> LScalar + + For non-unit strides : A single loop, to process element wise. +*/ +// Unit testing with unit strides, across all loops. +INSTANTIATE_TEST_SUITE_P( + bli_daxpyv_zen_int_avx512_unitStrides, + daxpyvGeneric, + ::testing::Combine( + ::testing::Values(bli_daxpyv_zen_int_avx512), // kernel address + ::testing::Values('n'), // use x, not conj(x) (since it is real) + ::testing::Values(// Testing the loops standalone + gtint_t(64), // size n, for L64 + gtint_t(32), // L32 + gtint_t(16), // L16 + gtint_t(8), // L8 + gtint_t(4), // L4 + gtint_t(3), // LScalar + // Testing the loops with combinations + gtint_t(320), // 5*L64 + gtint_t(352), // 5*L64 + L32 + gtint_t(368), // 5*L64 + L32 + L16 + gtint_t(376), // 5*L64 + L32 + L16 + L8 + gtint_t(380), // 5*L64 + L32 + L16 + L8 + L4 + gtint_t(383)), // 5*L64 + L32 + L16 + L8 + L4 + 3(LScalar) + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(double(1.0), double(-1.0), + double(2.2), double(-4.1), + double(0.0)), // alpha + ::testing::Values(false, true) // is_memory_test + ), + (::axpyvUKRPrint()) + ); + +// Unit testing for non unit strides +INSTANTIATE_TEST_SUITE_P( + bli_daxpyv_zen_int_avx512_nonUnitStrides, + daxpyvGeneric, + ::testing::Combine( + ::testing::Values(bli_daxpyv_zen_int_avx512), // kernel address + ::testing::Values('n'), // use x, not conj(x) (since it is real) + ::testing::Values(gtint_t(10), // n, size of the vector + gtint_t(25)), + ::testing::Values(gtint_t(5)), // stride size for x + ::testing::Values(gtint_t(3)), // stride size for y + ::testing::Values(double(1.0), double(-1.0), + double(2.2), double(-4.1), + double(0.0)), // alpha + ::testing::Values(false, true) // is_memory_test + ), + (::axpyvUKRPrint()) + ); +#endif diff --git a/gtestsuite/testsuite/ukr/axpyv/saxpyv_ukr.cpp b/gtestsuite/testsuite/ukr/axpyv/saxpyv_ukr.cpp new file mode 100644 index 0000000000..afa2eb7297 --- /dev/null +++ b/gtestsuite/testsuite/ukr/axpyv/saxpyv_ukr.cpp @@ -0,0 +1,254 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + Portions of this file consist of AI-generated content. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_axpyv_ukr.h" + +class saxpyvGeneric : + public ::testing::TestWithParam> {}; // is_memory_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(saxpyvGeneric); + +// Defining the testsuite to check the accuracy of saxpyv micro-kernels +TEST_P( saxpyvGeneric, UKR ) +{ + using T = float; + + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + + // Assign the kernel address to the function pointer + saxpyv_ker_ft ukr_fp = std::get<0>(GetParam()); + // denotes whether x or conj(x) will be added to y + char conj_x = std::get<1>(GetParam()); + // vector length + gtint_t n = std::get<2>(GetParam()); + // stride size for x + gtint_t incx = std::get<3>(GetParam()); + // stride size for y + gtint_t incy = std::get<4>(GetParam()); + // alpha + T alpha = std::get<5>(GetParam()); + // is_memory_test + bool is_memory_test = std::get<6>(GetParam()); + + // Set the threshold for the errors + double threshold = 2 * testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_axpyv_ukr( ukr_fp, conj_x, n, incx, incy, alpha, threshold, is_memory_test ); +} + +#if defined(BLIS_KERNELS_ZEN) && defined(GTEST_AVX2FMA3) +/* + Unit testing for functionality of bli_saxpyv_zen_int10 kernel. + The code structure for bli_saxpyv_zen_int10( ... ) is as follows : + For unit strides : + Main loop : In blocks of 120 --> L120 + Fringe loops : In blocks of 80 --> L80 + In blocks of 40 --> L40 + In blocks of 32 --> L32 + In blocks of 16 --> L16 + In blocks of 8 --> L8 + Element-wise loop --> LScalar + + For non-unit strides : A single loop, to process element wise. +*/ + +INSTANTIATE_TEST_SUITE_P( + bli_saxpyv_zen_int10_unitStrides, + saxpyvGeneric, + ::testing::Combine( + ::testing::Values(bli_saxpyv_zen_int10), // kernel address + ::testing::Values('n'), // use x, not conj(x) (since it is real) + ::testing::Values(// Testing the loops standalone + gtint_t(120), // size n, for L120 + gtint_t(80), // L80 + gtint_t(40), // L40 + gtint_t(32), // L32 + gtint_t(16), // L16 + gtint_t(8), // L8 + gtint_t(7), // LScalar + gtint_t(240), // 2*L120 + gtint_t(320), // 2*L120 + L80 + gtint_t(312), // 2*L120 + L40 + L32 + gtint_t(271)), // 2*L120 + L16 + L8 + LScalar + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(float(1.0), float(-1.0), + float(2.3), float(-4.5), + float(0.0)), // alpha + ::testing::Values(false, true) // is_memory_test + ), + (::axpyvUKRPrint()) + ); + +INSTANTIATE_TEST_SUITE_P( + bli_saxpyv_zen_int10_nonUnitStrides, + saxpyvGeneric, + ::testing::Combine( + ::testing::Values(bli_saxpyv_zen_int10), // kernel address + ::testing::Values('n'), // use x, not conj(x) (since it is real) + ::testing::Values(// Testing the loops standalone + gtint_t(7), // size n, for LScalar + gtint_t(15)), + ::testing::Values(gtint_t(3), gtint_t(5)), // stride size for x + ::testing::Values(gtint_t(3), gtint_t(5)), // stride size for y + ::testing::Values(float(1.0), float(-1.0), + float(2.3), float(-4.5), + float(0.0)), // alpha + ::testing::Values(false, true) // is_memory_test + ), + (::axpyvUKRPrint()) + ); + +/* + Unit testing for functionality of bli_saxpyv_zen_int kernel. + The code structure for bli_saxpyv_zen_int( ... ) is as follows : + For unit strides : + Main loop : In blocks of 32 --> L32 + Element-wise loop --> LScalar + + For non-unit strides : A single loop, to process element wise. +*/ + +INSTANTIATE_TEST_SUITE_P( + bli_saxpyv_zen_int_unitStrides, + saxpyvGeneric, + ::testing::Combine( + ::testing::Values(bli_saxpyv_zen_int), // kernel address + ::testing::Values('n'), // use x, not conj(x) (since it is real) + ::testing::Values(// Testing the loops standalone + gtint_t(32), // size n, for L32 + gtint_t(15), // LScalar + gtint_t(79)), // 2*L32 + LScalar + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(float(1.0), float(-1.0), + float(2.3), float(-4.5), + float(0.0)), // alpha + ::testing::Values(false, true) // is_memory_test + ), + (::axpyvUKRPrint()) + ); + +INSTANTIATE_TEST_SUITE_P( + bli_saxpyv_zen_int_nonUnitStrides, + saxpyvGeneric, + ::testing::Combine( + ::testing::Values(bli_saxpyv_zen_int), // kernel address + ::testing::Values('n'), // use x, not conj(x) (since it is real) + ::testing::Values(// Testing the loops standalone + gtint_t(7), // size n, for LScalar + gtint_t(10)), + ::testing::Values(gtint_t(3), gtint_t(5)), // stride size for x + ::testing::Values(gtint_t(3), gtint_t(5)), // stride size for y + ::testing::Values(float(1.0), float(-1.0), + float(2.3), float(-4.5), + float(0.0)), // alpha + ::testing::Values(false, true) // is_memory_test + ), + (::axpyvUKRPrint()) + ); +#endif + +#if defined(BLIS_KERNELS_ZEN4) && defined(GTEST_AVX512) +/* + Unit testing for functionality of bli_saxpyv_zen_int_avx512 kernel. + The code structure for bli_saxpyv_zen_int_avx512( ... ) is as follows : + For unit strides : + Main loop : In blocks of 128 --> L128 + Fringe loops : In blocks of 64 --> L64 + In blocks of 32 --> L32 + In blocks of 16 --> L16 + In blocks of 8 --> L8 + Element-wise loop --> LScalar + + For non-unit strides : A single loop, to process element wise. +*/ + +INSTANTIATE_TEST_SUITE_P( + bli_saxpyv_zen_int_avx512_unitStrides, + saxpyvGeneric, + ::testing::Combine( + ::testing::Values(bli_saxpyv_zen_int_avx512), // kernel address + ::testing::Values('n'), // use x, not conj(x) (since it is real) + ::testing::Values(// Testing the loops standalone + gtint_t(128), // size n, for L128 + gtint_t(64), // L64 + gtint_t(32), // L32 + gtint_t(16), // L16 + gtint_t(8), // L8 + gtint_t(7), // LScalar + gtint_t(383)), // 2*L128 + L64 + L32 + L16 + L8 + L7 + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(float(1.0), float(-1.0), + float(2.3), float(-4.5), + float(0.0)), // alpha + ::testing::Values(false, true) // is_memory_test + ), + (::axpyvUKRPrint()) + ); + +INSTANTIATE_TEST_SUITE_P( + bli_saxpyv_zen_int_avx512_nonUnitStrides, + saxpyvGeneric, + ::testing::Combine( + ::testing::Values(bli_saxpyv_zen_int_avx512), // kernel address + ::testing::Values('n'), // use x, not conj(x) (since it is real) + ::testing::Values(// Testing the loops standalone + gtint_t(7), // size n, for LScalar + gtint_t(15)), + ::testing::Values(gtint_t(3), gtint_t(5)), // stride size for x + ::testing::Values(gtint_t(3), gtint_t(5)), // stride size for y + ::testing::Values(float(1.0), float(-1.0), + float(2.3), float(-4.5), + float(0.0)), // alpha + ::testing::Values(false, true) // is_memory_test + ), + (::axpyvUKRPrint()) + ); +#endif diff --git a/gtestsuite/testsuite/ukr/axpyv/test_axpyv_ukr.h b/gtestsuite/testsuite/ukr/axpyv/test_axpyv_ukr.h new file mode 100644 index 0000000000..648562de23 --- /dev/null +++ b/gtestsuite/testsuite/ukr/axpyv/test_axpyv_ukr.h @@ -0,0 +1,152 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#pragma once + +#include +#include "level1/axpyv/axpyv.h" +#include "level1/ref_axpyv.h" +#include "inc/check_error.h" +#include "common/testing_helpers.h" + +/** + * @brief Generic test body for axpby operation. + */ + +// The function is templatized based on the datatype and function-pointer type to the kernel. +template +static void test_axpyv_ukr( FT ukr_fp, char conjx, gtint_t n, gtint_t incx, gtint_t incy, + T alpha, double thresh, bool is_memory_test = false ) +{ + // Pointers to obtain the required memory. + T *x, *y, *y_ref; + gtint_t size_x = testinghelpers::buff_dim( n, incx ) * sizeof( T ); + gtint_t size_y = testinghelpers::buff_dim( n, incy ) * sizeof( T ); + + // Create the objects for the input and output operands + // The kernel does not expect the memory to be aligned + testinghelpers::ProtectedBuffer x_buffer( size_x, false, is_memory_test ); + testinghelpers::ProtectedBuffer y_buffer( size_y, false, is_memory_test ); + + // For y_ref, we don't need different greenzones and any redzone. + // Thus, we pass is_memory_test as false + testinghelpers::ProtectedBuffer y_ref_buffer( size_y, false, false ); + + // Acquire the first set of greenzones for x and y + x = ( T* )x_buffer.greenzone_1; + y = ( T* )y_buffer.greenzone_1; + y_ref = ( T* )y_ref_buffer.greenzone_1; // For y_ref, there is no greenzone_2 + + // Initialize the memory with random data + testinghelpers::datagenerators::randomgenerators( -10, 10, n, incx, x ); + testinghelpers::datagenerators::randomgenerators( -10, 10, n, incy, y ); + + // Copying the contents of y to y_ref + memcpy( y_ref, y, size_y ); + + // Char conjx to BLIS conjx conversion + conj_t blis_conjx; + testinghelpers::char_to_blis_conj( conjx, &blis_conjx ); + + // Add signal handler for segmentation fault + testinghelpers::ProtectedBuffer::start_signal_handler(); + try + { + // Call the ukr function. + // This call is made irrespective of is_memory_test. + // This will check for out of bounds access with first redzone(if memory test is true) + // Else, it will just call the ukr function. + ukr_fp( blis_conjx, n, &alpha, x, incx, y, incy, nullptr ); + + if ( is_memory_test ) + { + // Acquire the pointers near the second redzone + x = ( T* )x_buffer.greenzone_2; + y = ( T* )y_buffer.greenzone_2; + + // Copy the data for x and y accordingly + memcpy( x, x_buffer.greenzone_1, size_x ); + memcpy( y, y_ref, size_y ); + + // Call the ukr function, to check with the second redzone. + ukr_fp( blis_conjx, n, &alpha, x, incx, y, incy, nullptr ); + } + } + catch(const std::exception& e) + { + // Reset to default signal handler + testinghelpers::ProtectedBuffer::stop_signal_handler(); + + // Show failure in case seg fault was detected + FAIL() << "Memory Test Failed"; + } + // Reset to default signal handler + testinghelpers::ProtectedBuffer::stop_signal_handler(); + + //---------------------------------------------------------- + // Call reference implementation to get ref results. + //---------------------------------------------------------- + testinghelpers::ref_axpyv( conjx, n, alpha, x, incx, y_ref, incy ); + + //---------------------------------------------------------- + // Compute component-wise error. + //---------------------------------------------------------- + computediff( "y", n, y, y_ref, incy, thresh ); + +} + +// Test-case logger : Used to print the test-case details for unit testing the kernels. +// NOTE : The kernel name is the prefix in instantiator name, and thus is not printed +// with this logger. +template +class axpyvUKRPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char conjx = std::get<1>(str.param); + gtint_t n = std::get<2>(str.param); + gtint_t incx = std::get<3>(str.param); + gtint_t incy = std::get<4>(str.param); + T1 alpha = std::get<5>(str.param); + bool is_memory_test = std::get<6>(str.param); + + std::string str_name = "_n_" + std::to_string(n); + str_name += "_conjx_" + std::string(&conjx, 1); + str_name += "_incx_" + testinghelpers::get_value_string(incx); + str_name += "_incy_" + testinghelpers::get_value_string(incy); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + str_name += ( is_memory_test ) ? "_mem_test_enabled" : "_mem_test_disabled"; + return str_name; + } +}; diff --git a/gtestsuite/testsuite/ukr/axpyv/zaxpyv_ukr.cpp b/gtestsuite/testsuite/ukr/axpyv/zaxpyv_ukr.cpp new file mode 100644 index 0000000000..f2bb26a2fd --- /dev/null +++ b/gtestsuite/testsuite/ukr/axpyv/zaxpyv_ukr.cpp @@ -0,0 +1,238 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + Portions of this file consist of AI-generated content. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_axpyv_ukr.h" + +class zaxpyvGeneric : + public ::testing::TestWithParam> {}; // is_memory_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(zaxpyvGeneric); + +// Defining the testsuite to check the accuracy of zaxpyv micro-kernels +TEST_P( zaxpyvGeneric, UKR ) +{ + using T = dcomplex; + + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + + // Assign the kernel address to the function pointer + zaxpyv_ker_ft ukr_fp = std::get<0>(GetParam()); + // denotes whether x or conj(x) will be added to y: + char conj_x = std::get<1>(GetParam()); + // vector length + gtint_t n = std::get<2>(GetParam()); + // stride size for x + gtint_t incx = std::get<3>(GetParam()); + // stride size for y + gtint_t incy = std::get<4>(GetParam()); + // alpha + T alpha = std::get<5>(GetParam()); + // is_memory_test + bool is_memory_test = std::get<6>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite axpbyv.h (no netlib version) for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO() || alpha == testinghelpers::ONE()) + thresh = 0.0; + else + thresh = 2*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_axpyv_ukr( ukr_fp, conj_x, n, incx, incy, alpha, thresh, is_memory_test ); +} + +#if defined(BLIS_KERNELS_ZEN) && defined(GTEST_AVX2FMA3) +/* + Unit testing for functionality of bli_zaxpyv_zen_int5 kernel. + The code structure for bli_zaxpyv_zen_int10( ... ) is as follows : + For unit strides : + Main loop : In blocks of 14 --> L14 + Fringe loops : In blocks of 10 --> L10 + In blocks of 6 --> L6 + In blocks of 4 --> L4 + In blocks of 2 --> L2 + Element-wise loop --> LScalar + + For non-unit strides : A single loop, to process element wise. +*/ +// Unit testing with unit strides, across all loops. +INSTANTIATE_TEST_SUITE_P( + bli_zaxpyv_zen_int5_unitStrides, + zaxpyvGeneric, + ::testing::Combine( + ::testing::Values(bli_zaxpyv_zen_int5), // kernel address + ::testing::Values('n' +#ifdef TEST_BLIS_TYPED + , 'c' // conjx +#endif + ), + ::testing::Values(// Testing the loops standalone + gtint_t(14), // size n, for L14 + gtint_t(10), // L10 + gtint_t(6), // L6 + gtint_t(4), // L4 + gtint_t(2), // L2 + gtint_t(1), // LScalar + // Testing the loops with combination + gtint_t(42), // 3*L14 + gtint_t(52), // 3*L14 + L10 + gtint_t(48), // 3*L14 + L6 + gtint_t(46), // 3*L14 + L4 + gtint_t(45)), // 3*L14 + L2 + LScalar + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, + dcomplex{0.0, 1.0}, dcomplex{0.0, -1.0}, + dcomplex{0.0, -3.3}, dcomplex{4.3,-2.1}, + dcomplex{0.0, 0.0}), // alpha + ::testing::Values(false, true) // is_memory_test + ), + (::axpyvUKRPrint()) + ); + +// Unit testing for non unit strides +INSTANTIATE_TEST_SUITE_P( + bli_zaxpyv_zen_int5_nonUnitStrides, + zaxpyvGeneric, + ::testing::Combine( + ::testing::Values(bli_zaxpyv_zen_int5), // kernel address + ::testing::Values('n' +#ifdef TEST_BLIS_TYPED + , 'c' // conjx +#endif + ), + ::testing::Values(gtint_t(2)), // n, size of the vector + ::testing::Values(gtint_t(5)), // stride size for x + ::testing::Values(gtint_t(3)), // stride size for y + ::testing::Values(dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, + dcomplex{0.0, 1.0}, dcomplex{0.0, -1.0}, + dcomplex{0.0, -3.3}, dcomplex{4.3,-2.1}, + dcomplex{0.0, 0.0}), // alpha + ::testing::Values(false, true) // is_memory_test + ), + (::axpyvUKRPrint()) + ); + +#endif + +#if defined(BLIS_KERNELS_ZEN4) && defined(GTEST_AVX512) +/* + Unit testing for functionality of bli_zaxpyv_zen_int_avx512 kernel. + The code structure for bli_zaxpyv_zen_int_avx512( ... ) is as follows : + For unit strides : + Main loop : In blocks of 32 --> L32 + Fringe loops : In blocks of 16 --> L16 + In blocks of 8 --> L8 + In blocks of 4 --> L4 + Masked loop ---> LScalar + + For non-unit strides : A single loop, to process element wise. +*/ +// Unit testing with unit strides, across all loops. +INSTANTIATE_TEST_SUITE_P( + bli_zaxpyv_zen_int_avx512_unitStrides, + zaxpyvGeneric, + ::testing::Combine( + ::testing::Values(bli_zaxpyv_zen_int_avx512), // kernel address + ::testing::Values('n' +#ifdef TEST_BLIS_TYPED + , 'c' // conjx +#endif + ), + ::testing::Values(// Testing the loops standalone + gtint_t(32), // size n, for L32 + gtint_t(16), // L16 + gtint_t(8), // L8 + gtint_t(4), // L4 + gtint_t(3), // LScalar + // Testing the loops with combination + gtint_t(96), // 3*L32 + gtint_t(112), // 3*L32 + L116 + gtint_t(120), // 3*L32 + L16 + L8 + gtint_t(124), // 3*L32 + L16 + L8 + L4 + gtint_t(127)), // 3*L32 + L16 + L8 + L4 + LScalar + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, + dcomplex{0.0, 1.0}, dcomplex{0.0, -1.0}, + dcomplex{0.0, -3.3}, dcomplex{4.3,-2.1}, + dcomplex{0.0, 0.0}), // alpha + ::testing::Values(false, true) // is_memory_test + ), + (::axpyvUKRPrint()) + ); + +// Unit testing for non unit strides +INSTANTIATE_TEST_SUITE_P( + bli_zaxpyv_zen_int_avx512_nonUnitStrides, + zaxpyvGeneric, + ::testing::Combine( + ::testing::Values(bli_zaxpyv_zen_int_avx512), // kernel address + ::testing::Values('n' +#ifdef TEST_BLIS_TYPED + , 'c' // conjx +#endif + ), + ::testing::Values(gtint_t(13)), // n, size of the vector + ::testing::Values(gtint_t(5)), // stride size for x + ::testing::Values(gtint_t(3)), // stride size for y + ::testing::Values(dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, + dcomplex{0.0, 1.0}, dcomplex{0.0, -1.0}, + dcomplex{0.0, -3.3}, dcomplex{4.3,-2.1}, + dcomplex{0.0, 0.0}), // alpha + ::testing::Values(false, true) // is_memory_test + ), + (::axpyvUKRPrint()) + ); + +#endif diff --git a/gtestsuite/testsuite/ukr/copyv/ccopyv_ukr.cpp b/gtestsuite/testsuite/ukr/copyv/ccopyv_ukr.cpp new file mode 100644 index 0000000000..2bd4b86138 --- /dev/null +++ b/gtestsuite/testsuite/ukr/copyv/ccopyv_ukr.cpp @@ -0,0 +1,135 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_copyv_ukr.h" + +class ccopyvGeneric : + public ::testing::TestWithParam> {}; // is_memory_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ccopyvGeneric); + +// Tests using random integers as vector elements. +TEST_P( ccopyvGeneric, UKR ) +{ + using T = scomplex; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + ccopyv_ker_ft ukr_fp = std::get<0>(GetParam()); + // denotes whether vec x is n,c + char conjx = std::get<1>(GetParam()); + // vector length: + gtint_t n = std::get<2>(GetParam()); + // stride size for x: + gtint_t incx = std::get<3>(GetParam()); + // stride size for y: + gtint_t incy = std::get<4>(GetParam()); + // is_memory_test + bool is_memory_test = std::get<5>(GetParam()); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_copyv_ukr( ukr_fp, conjx, n, incx, incy, is_memory_test ); +} + +#if defined(BLIS_KERNELS_ZEN) && defined(GTEST_AVX2FMA3) +/* + Unit testing for functionality of bli_ccopyv_zen_int kernel. + The code structure for bli_ccopyv_zen_int( ... ) is as follows : + For unit strides : + Main loop : In blocks of 32 --> L32 + Fringe loops : In blocks of 16 --> L16 + In blocks of 8 --> L8 + In blocks of 4 --> L4 + Element-wise loop --> LScalar + + For non-unit strides : A single loop, to process element wise. +*/ +// Unit testing with Unit Strides(US), across all loops. +INSTANTIATE_TEST_SUITE_P( + bli_ccopyv_zen_int_unitStrides, + ccopyvGeneric, + ::testing::Combine( + ::testing::Values(bli_ccopyv_zen_int), + ::testing::Values('n' // n: use x, c: use conj(x) +#ifdef TEST_BLIS_TYPED + , 'c' // this option is BLIS-api specific. +#endif + ), + ::testing::Values(// Testing the loops standalone + gtint_t(32), // size n, for L32 + gtint_t(16), // L16 + gtint_t(8), // L8 + gtint_t(4), // L4 + gtint_t(3), // LScalar + // Testing the loops with combinations + gtint_t(160), // 5*L32 + gtint_t(192), // 5*L32 + L16 + gtint_t(200), // 5*L32 + L16 + L8 + gtint_t(204), // 5*L32 + L16 + L8 + L4 + gtint_t(207)), // 5*L32 + L16 + L8 + L4 + 1(LScalar) + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(false, true) // is_memory_test + ), + ::copyvUKRPrint() + ); + +// Unit testing with Non-Unit Strides(US), across all loops. +INSTANTIATE_TEST_SUITE_P( + bli_ccopyv_zen_int_nonUnitStrides, + ccopyvGeneric, + ::testing::Combine( + ::testing::Values(bli_ccopyv_zen_int), + ::testing::Values('n' // n: use x, c: use conj(x) +#ifdef TEST_BLIS_TYPED + , 'c' // this option is BLIS-api specific. +#endif + ), + ::testing::Values(gtint_t(25), gtint_t(37)), // size of the vector + ::testing::Values(gtint_t(5)), // stride size for x + ::testing::Values(gtint_t(3)), // stride size for y + ::testing::Values(false, true) // is_memory_test + ), + ::copyvUKRPrint() + ); +#endif diff --git a/gtestsuite/testsuite/ukr/copyv/dcopyv_ukr.cpp b/gtestsuite/testsuite/ukr/copyv/dcopyv_ukr.cpp new file mode 100644 index 0000000000..c8c9e3f5ce --- /dev/null +++ b/gtestsuite/testsuite/ukr/copyv/dcopyv_ukr.cpp @@ -0,0 +1,191 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_copyv_ukr.h" + +class dcopyvGeneric : + public ::testing::TestWithParam> {}; // is_memory_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(dcopyvGeneric); + +// Tests using random integers as vector elements. +TEST_P( dcopyvGeneric, UKR ) +{ + using T = double; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + dcopyv_ker_ft ukr_fp = std::get<0>(GetParam()); + // denotes whether vec x is n,c + char conjx = std::get<1>(GetParam()); + // vector length: + gtint_t n = std::get<2>(GetParam()); + // stride size for x: + gtint_t incx = std::get<3>(GetParam()); + // stride size for y: + gtint_t incy = std::get<4>(GetParam()); + // is_memory_test + bool is_memory_test = std::get<5>(GetParam()); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_copyv_ukr( ukr_fp, conjx, n, incx, incy, is_memory_test ); +} + +#if defined(BLIS_KERNELS_ZEN) && defined(GTEST_AVX2FMA3) +/* + Unit testing for functionality of bli_dcopyv_zen_int kernel. + The code structure for bli_dcopyv_zen_int( ... ) is as follows : + For unit strides : + Main loop : In blocks of 64 --> L64 + Fringe loops : In blocks of 32 --> L32 + In blocks of 16 --> L16 + In blocks of 8 --> L8 + In blocks of 4 --> L4 + Element-wise loop --> LScalar + + For non-unit strides : A single loop, to process element wise. +*/ +// Unit testing with Unit Strides(US), across all loops. +INSTANTIATE_TEST_SUITE_P( + bli_dcopyv_zen_int_unitStrides, + dcopyvGeneric, + ::testing::Combine( + ::testing::Values(bli_dcopyv_zen_int), + ::testing::Values('n'), // conjugate parameter, 'n' for dcopyv + ::testing::Values(// Testing the loops standalone + gtint_t(64), // size n, for L64 + gtint_t(32), // L32 + gtint_t(16), // L16 + gtint_t(8), // L8 + gtint_t(4), // L4 + gtint_t(3), // LScalar + // Testing the loops with combinations + gtint_t(320), // 5*L64 + gtint_t(352), // 5*L64 + L32 + gtint_t(368), // 5*L64 + L32 + L16 + gtint_t(376), // 5*L64 + L32 + L16 + L8 + gtint_t(380), // 5*L64 + L32 + L16 + L8 + L4 + gtint_t(383)), // 5*L64 + L32 + L16 + L8 + L4 + 3(LScalar) + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(false, true) // is_memory_test + ), + ::copyvUKRPrint() + ); + +// Unit testing with Non-Unit Strides(US), across all loops. +INSTANTIATE_TEST_SUITE_P( + bli_dcopyv_zen_int_nonUnitStrides, + dcopyvGeneric, + ::testing::Combine( + ::testing::Values(bli_dcopyv_zen_int), + ::testing::Values('n'), // conjugate parameter, 'n' for dcopyv + ::testing::Values(gtint_t(25), gtint_t(37)), // size of the vector + ::testing::Values(gtint_t(5)), // stride size for x + ::testing::Values(gtint_t(3)), // stride size for y + ::testing::Values(false, true) // is_memory_test + ), + ::copyvUKRPrint() + ); +#endif + +#if defined(BLIS_KERNELS_ZEN4) && defined(GTEST_AVX512) +/* + Unit testing for functionality of bli_dcopyv_zen4_asm_avx512 kernel. + The code structure for bli_dcopyv_zen4_asm_avx512( ... ) is as follows : + For unit strides : + Main loop : In blocks of 256 --> L256 + Fringe loops : In blocks of 128 --> L128 + In blocks of 64 --> L64 + In blocks of 32 --> L32 + In blocks of 16 --> L16 + In blocks of 8 --> L8 + Element-wise loop --> LScalar + + For non-unit strides : A single loop, to process element wise. +*/ +// Unit testing with Unit Strides(US), across all loops. +INSTANTIATE_TEST_SUITE_P( + bli_dcopyv_zen4_asm_avx512_unitStrides, + dcopyvGeneric, + ::testing::Combine( + ::testing::Values(bli_dcopyv_zen4_asm_avx512), + ::testing::Values('n'), // conjugate parameter, 'n' for dcopyv + ::testing::Values(// Testing the loops standalone + gtint_t(256), // size n, for L256 + gtint_t(128), // L128 + gtint_t(64), // L64 + gtint_t(32), // L32 + gtint_t(16), // L16 + gtint_t(8), // L8 + gtint_t(7), // LScalar + // Testing the loops with combinations + gtint_t(1280), // 5*L256 + gtint_t(1408), // 5*L256 + L128 + gtint_t(1472), // 5*L256 + L128 + L64 + gtint_t(1504), // 5*L256 + L128 + L64 + L32 + gtint_t(1520), // 5*L256 + L128 + L64 + L32 + L16 + gtint_t(1528), // 5*L256 + L128 + L64 + L32 + L16 + L8 + gtint_t(1535)), // 5*L256 + L128 + L64 + L32 + L16 + L8 + 7(LScalar) + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(false, true) // is_memory_test + ), + ::copyvUKRPrint() + ); + +// Unit testing with Non-Unit Strides(US), across all loops. +INSTANTIATE_TEST_SUITE_P( + bli_dcopyv_zen4_asm_avx512_nonUnitStrides, + dcopyvGeneric, + ::testing::Combine( + ::testing::Values(bli_dcopyv_zen4_asm_avx512), + ::testing::Values('n'), // conjugate parameter, 'n' for dcopyv + ::testing::Values(gtint_t(25), gtint_t(37)), // size of the vector + ::testing::Values(gtint_t(5)), // stride size for x + ::testing::Values(gtint_t(3)), // stride size for y + ::testing::Values(false, true) // is_memory_test + ), + ::copyvUKRPrint() + ); +#endif diff --git a/gtestsuite/testsuite/ukr/copyv/scopyv_ukr.cpp b/gtestsuite/testsuite/ukr/copyv/scopyv_ukr.cpp new file mode 100644 index 0000000000..906513f153 --- /dev/null +++ b/gtestsuite/testsuite/ukr/copyv/scopyv_ukr.cpp @@ -0,0 +1,191 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_copyv_ukr.h" + +class scopyvGeneric : + public ::testing::TestWithParam> {}; // is_memory_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(scopyvGeneric); + +// Tests using random integers as vector elements. +TEST_P( scopyvGeneric, UKR ) +{ + using T = float; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + scopyv_ker_ft ukr_fp = std::get<0>(GetParam()); + // denotes whether vec x is n,c + char conjx = std::get<1>(GetParam()); + // vector length: + gtint_t n = std::get<2>(GetParam()); + // stride size for x: + gtint_t incx = std::get<3>(GetParam()); + // stride size for y: + gtint_t incy = std::get<4>(GetParam()); + // is_memory_test + bool is_memory_test = std::get<5>(GetParam()); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_copyv_ukr( ukr_fp, conjx, n, incx, incy, is_memory_test ); +} + +#if defined(BLIS_KERNELS_ZEN) && defined(GTEST_AVX2FMA3) +/* + Unit testing for functionality of bli_scopyv_zen_int kernel. + The code structure for bli_scopyv_zen_int( ... ) is as follows : + For unit strides : + Main loop : In blocks of 128 --> L128 + Fringe loops : In blocks of 64 --> L64 + In blocks of 32 --> L32 + In blocks of 16 --> L16 + In blocks of 8 --> L8 + Element-wise loop --> LScalar + + For non-unit strides : A single loop, to process element wise. +*/ +// Unit testing with Unit Strides(US), across all loops. +INSTANTIATE_TEST_SUITE_P( + bli_scopyv_zen_int_unitStrides, + scopyvGeneric, + ::testing::Combine( + ::testing::Values(bli_scopyv_zen_int), + ::testing::Values('n'), // conjugate parameter, 'n' for scopyv + ::testing::Values(// Testing the loops standalone + gtint_t(128), // size n, for L128 + gtint_t(64), // L64 + gtint_t(32), // L32 + gtint_t(16), // L16 + gtint_t(8), // L8 + gtint_t(7), // LScalar + // Testing the loops with combinations + gtint_t(640), // 5*L128 + gtint_t(704), // 5*L128 + L64 + gtint_t(736), // 5*L128 + L64 + L32 + gtint_t(752), // 5*L128 + L64 + L32 + L16 + gtint_t(760), // 5*L128 + L64 + L32 + L16 + L8 + gtint_t(767)), // 5*L128 + L64 + L32 + L16 + L8 + 7(LScalar) + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(false, true) // is_memory_test + ), + ::copyvUKRPrint() + ); + +// Unit testing with Non-Unit Strides(US), across all loops. +INSTANTIATE_TEST_SUITE_P( + bli_scopyv_zen_int_nonUnitStrides, + scopyvGeneric, + ::testing::Combine( + ::testing::Values(bli_scopyv_zen_int), + ::testing::Values('n'), // conjugate parameter, 'n' for scopyv + ::testing::Values(gtint_t(25), gtint_t(37)), // size of the vector + ::testing::Values(gtint_t(5)), // stride size for x + ::testing::Values(gtint_t(3)), // stride size for y + ::testing::Values(false, true) // is_memory_test + ), + ::copyvUKRPrint() + ); +#endif + +#if defined(BLIS_KERNELS_ZEN4) && defined(GTEST_AVX512) +/* + Unit testing for functionality of bli_scopyv_zen4_asm_avx512 kernel. + The code structure for bli_scopyv_zen4_asm_avx512( ... ) is as follows : + For unit strides : + Main loop : In blocks of 512 --> L512 + Fringe loops : In blocks of 256 --> L256 + In blocks of 128 --> L128 + In blocks of 64 --> L64 + In blocks of 32 --> L32 + In blocks of 16 --> L16 + Element-wise loop --> LScalar + + For non-unit strides : A single loop, to process element wise. +*/ +// Unit testing with Unit Strides(US), across all loops. +INSTANTIATE_TEST_SUITE_P( + bli_scopyv_zen4_asm_avx512_unitStrides, + scopyvGeneric, + ::testing::Combine( + ::testing::Values(bli_scopyv_zen4_asm_avx512), + ::testing::Values('n'), // conjugate parameter, 'n' for scopyv + ::testing::Values(// Testing the loops standalone + gtint_t(512), // size n, for L512 + gtint_t(256), // L256 + gtint_t(128), // L128 + gtint_t(64), // L64 + gtint_t(32), // L32 + gtint_t(16), // L16 + gtint_t(15), // LScalar + // Testing the loops with combinations + gtint_t(2560), // 5*L512 + gtint_t(2816), // 5*L512 + L256 + gtint_t(2944), // 5*L512 + L256 + L128 + gtint_t(3008), // 5*L512 + L256 + L128 + L64 + gtint_t(3040), // 5*L512 + L256 + L128 + L64 + L32 + gtint_t(3056), // 5*L512 + L256 + L128 + L64 + L32 + L16 + gtint_t(3071)), // 5*L512 + L256 + L128 + L64 + L32 + L16 + 15(LScalar) + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(false, true) // is_memory_test + ), + ::copyvUKRPrint() + ); + +// Unit testing with Non-Unit Strides(US), across all loops. +INSTANTIATE_TEST_SUITE_P( + bli_scopyv_zen4_asm_avx512_nonUnitStrides, + scopyvGeneric, + ::testing::Combine( + ::testing::Values(bli_scopyv_zen4_asm_avx512), + ::testing::Values('n'), // conjugate parameter, 'n' for scopyv + ::testing::Values(gtint_t(25), gtint_t(37)), // size of the vector + ::testing::Values(gtint_t(5)), // stride size for x + ::testing::Values(gtint_t(3)), // stride size for y + ::testing::Values(false, true) // is_memory_test + ), + ::copyvUKRPrint() + ); +#endif diff --git a/gtestsuite/testsuite/ukr/copyv/test_copyv_ukr.h b/gtestsuite/testsuite/ukr/copyv/test_copyv_ukr.h new file mode 100644 index 0000000000..7aef5e78b8 --- /dev/null +++ b/gtestsuite/testsuite/ukr/copyv/test_copyv_ukr.h @@ -0,0 +1,146 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#pragma once + +#include +#include "level1/copyv/copyv.h" +#include "level1/ref_copyv.h" +#include "inc/check_error.h" +#include "common/testing_helpers.h" + +/** + * @brief Generic test body for copyv operation. + */ + +template +static void test_copyv_ukr( FT ukr_fp, char conjx, gtint_t n, gtint_t incx, gtint_t incy, bool is_memory_test = false ) +{ + // Pointers to obtain the required memory. + T *x, *y, *y_ref; + gtint_t size_x = testinghelpers::buff_dim( n, incx ) * sizeof( T ); + gtint_t size_y = testinghelpers::buff_dim( n, incy ) * sizeof( T ); + + // Create the objects for the input and output operands + // The kernel does not expect the memory to be aligned + testinghelpers::ProtectedBuffer x_buffer( size_x, false, is_memory_test ); + testinghelpers::ProtectedBuffer y_buffer( size_y, false, is_memory_test ); + + // For y_ref, we don't need different greenzones and any redzone. + // Thus, we pass is_memory_test as false + testinghelpers::ProtectedBuffer y_ref_buffer( size_y, false, false ); + + // Acquire the first set of greenzones for x and y + x = ( T* )x_buffer.greenzone_1; + y = ( T* )y_buffer.greenzone_1; + y_ref = ( T* )y_ref_buffer.greenzone_1; // For y_ref, there is no greenzone_2 + + // Initialize the memory with random data + testinghelpers::datagenerators::randomgenerators( -10, 10, n, incx, x ); + testinghelpers::datagenerators::randomgenerators( -10, 10, n, incy, y ); + + // Copying the contents of y to y_ref + memcpy( y_ref, y, size_y ); + + // Char conjx to BLIS conjx conversion + conj_t blis_conjx; + testinghelpers::char_to_blis_conj( conjx, &blis_conjx ); + + // Add signal handler for segmentation fault + testinghelpers::ProtectedBuffer::start_signal_handler(); + try + { + // Call the ukr function. + // This call is made irrespective of is_memory_test. + // This will check for out of bounds access with first redzone(if memory test is true) + // Else, it will just call the ukr function. + ukr_fp( blis_conjx, n, x, incx, y, incy, nullptr ); + + if ( is_memory_test ) + { + // Acquire the pointers near the second redzone + x = ( T* )x_buffer.greenzone_2; + y = ( T* )y_buffer.greenzone_2; + + // Copy the data for x and y accordingly + memcpy( x, x_buffer.greenzone_1, size_x ); + memcpy( y, y_ref, size_y ); + + // Call the ukr function, to check with the second redzone. + ukr_fp( blis_conjx, n, x, incx, y, incy, nullptr ); + } + } + catch(const std::exception& e) + { + // Reset to default signal handler + testinghelpers::ProtectedBuffer::stop_signal_handler(); + + // Show failure in case seg fault was detected + FAIL() << "Memory Test Failed"; + } + // Reset to default signal handler + testinghelpers::ProtectedBuffer::stop_signal_handler(); + + //---------------------------------------------------------- + // Call reference implementation to get ref results. + //---------------------------------------------------------- + + testinghelpers::ref_copyv( conjx, n, x, incx, y_ref, incy ); + + //---------------------------------------------------------- + // Compute error. + //---------------------------------------------------------- + computediff( "y", n, y, y_ref, incy ); +} + +// Test-case logger : Used to print the test-case details based on parameters +template +class copyvUKRPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char conjx = std::get<1>(str.param); + gtint_t n = std::get<2>(str.param); + gtint_t incx = std::get<3>(str.param); + gtint_t incy = std::get<4>(str.param); + bool is_memory_test = std::get<5>(str.param); + + std::string str_name = "_n_" + std::to_string(n); + str_name += "_conjx_" + std::string(&conjx, 1); + str_name += "_incx_" + testinghelpers::get_value_string(incx); + str_name += "_incy_" + testinghelpers::get_value_string(incy); + str_name += ( is_memory_test ) ? "_mem_test_enabled" : "_mem_test_disabled"; + return str_name; + } +}; diff --git a/gtestsuite/testsuite/ukr/copyv/zcopyv_ukr.cpp b/gtestsuite/testsuite/ukr/copyv/zcopyv_ukr.cpp new file mode 100644 index 0000000000..83965b1f9e --- /dev/null +++ b/gtestsuite/testsuite/ukr/copyv/zcopyv_ukr.cpp @@ -0,0 +1,204 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_copyv_ukr.h" + +class zcopyvGeneric : + public ::testing::TestWithParam> {}; // is_memory_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(zcopyvGeneric); + +// Tests using random integers as vector elements. +TEST_P( zcopyvGeneric, UKR ) +{ + using T = dcomplex; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + zcopyv_ker_ft ukr_fp = std::get<0>(GetParam()); + // denotes whether vec x is n,c + char conjx = std::get<1>(GetParam()); + // vector length: + gtint_t n = std::get<2>(GetParam()); + // stride size for x: + gtint_t incx = std::get<3>(GetParam()); + // stride size for y: + gtint_t incy = std::get<4>(GetParam()); + // is_memory_test + bool is_memory_test = std::get<5>(GetParam()); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_copyv_ukr( ukr_fp, conjx, n, incx, incy, is_memory_test ); +} + +#if defined(BLIS_KERNELS_ZEN) && defined(GTEST_AVX2FMA3) +/* + Unit testing for functionality of bli_zcopyv_zen_int kernel. + The code structure for bli_zcopyv_zen_int( ... ) is as follows : + For unit strides : + Main loop : In blocks of 16 --> L16 + Fringe loops : In blocks of 8 --> L8 + In blocks of 4 --> L4 + In blocks of 2 --> L2 + Element-wise loop --> LScalar + + For non-unit strides : A single loop, to process element wise. +*/ +// Unit testing with Unit Strides(US), across all loops. +INSTANTIATE_TEST_SUITE_P( + bli_zcopyv_zen_int_unitStrides, + zcopyvGeneric, + ::testing::Combine( + ::testing::Values(bli_zcopyv_zen_int), + ::testing::Values('n' // n: use x, c: use conj(x) +#ifdef TEST_BLIS_TYPED + , 'c' // this option is BLIS-api specific. +#endif + ), + ::testing::Values(// Testing the loops standalone + gtint_t(16), // size n, for L16 + gtint_t(8), // L8 + gtint_t(4), // L4 + gtint_t(2), // L2 + gtint_t(1), // LScalar + // Testing the loops with combinations + gtint_t(80), // 5*L16 + gtint_t(88), // 5*L16 + L8 + gtint_t(92), // 5*L16 + L8 + L4 + gtint_t(94), // 5*L16 + L8 + L4 + L2 + gtint_t(95)), // 5*L16 + L8 + L4 + L2 + 1(LScalar) + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(false, true) // is_memory_test + ), + ::copyvUKRPrint() + ); + +// Unit testing with Non-Unit Strides(US), across all loops. +INSTANTIATE_TEST_SUITE_P( + bli_zcopyv_zen_int_nonUnitStrides, + zcopyvGeneric, + ::testing::Combine( + ::testing::Values(bli_zcopyv_zen_int), + ::testing::Values('n' // n: use x, c: use conj(x) +#ifdef TEST_BLIS_TYPED + , 'c' // this option is BLIS-api specific. +#endif + ), + ::testing::Values(gtint_t(25), gtint_t(37)), // size of the vector + ::testing::Values(gtint_t(5)), // stride size for x + ::testing::Values(gtint_t(3)), // stride size for y + ::testing::Values(false, true) // is_memory_test + ), + ::copyvUKRPrint() + ); +#endif + +#if defined(BLIS_KERNELS_ZEN4) && defined(GTEST_AVX512) +/* + Unit testing for functionality of bli_zcopyv_zen4_asm_avx512 kernel. + The code structure for bli_zcopyv_zen4_asm_avx512( ... ) is as follows : + For unit strides : + Main loop : In blocks of 128 --> L128 + Fringe loops : In blocks of 64 --> L64 + In blocks of 32 --> L32 + In blocks of 16 --> L16 + In blocks of 8 --> L8 + In blocks of 4 --> L4 + Element-wise loop --> LScalar + + For non-unit strides : A single loop, to process element wise. +*/ +// Unit testing with Unit Strides(US), across all loops. +INSTANTIATE_TEST_SUITE_P( + bli_zcopyv_zen4_asm_avx512_unitStrides, + zcopyvGeneric, + ::testing::Combine( + ::testing::Values(bli_zcopyv_zen4_asm_avx512), + ::testing::Values('n' // n: use x, c: use conj(x) +#ifdef TEST_BLIS_TYPED + , 'c' // this option is BLIS-api specific. +#endif + ), + ::testing::Values(// Testing the loops standalone + gtint_t(128), // size n, for L128 + gtint_t(64), // L64 + gtint_t(32), // L32 + gtint_t(16), // L16 + gtint_t(8), // L8 + gtint_t(4), // L4 + gtint_t(3), // LScalar + // Testing the loops with combinations + gtint_t(1280), // 5*L256 + gtint_t(1408), // 5*L256 + L128 + gtint_t(1472), // 5*L256 + L128 + L32 + gtint_t(1504), // 5*L256 + L128 + L32 + L16 + gtint_t(1520), // 5*L258 + L128 + L32 + L16 + L8 + gtint_t(1528), // 5*L258 + L128 + L32 + L16 + L8 + L4 + gtint_t(1531)), // 5*L258 + L128 + L32 + L16 + L8 + L4 + 3(LScalar) + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(false, true) // is_memory_test + ), + ::copyvUKRPrint() + ); + +// Unit testing with Non-Unit Strides(US), across all loops. +INSTANTIATE_TEST_SUITE_P( + bli_zcopyv_zen4_asm_avx512_nonUnitStrides, + zcopyvGeneric, + ::testing::Combine( + ::testing::Values(bli_zcopyv_zen4_asm_avx512), + ::testing::Values('n' // n: use x, c: use conj(x) +#ifdef TEST_BLIS_TYPED + , 'c' // this option is BLIS-api specific. +#endif + ), + ::testing::Values(gtint_t(25), gtint_t(37)), // size of the vector + ::testing::Values(gtint_t(5)), // stride size for x + ::testing::Values(gtint_t(3)), // stride size for y + ::testing::Values(false, true) // is_memory_test + ), + ::copyvUKRPrint() + ); +#endif diff --git a/gtestsuite/testsuite/ukr/dotv/ddotv_ukr.cpp b/gtestsuite/testsuite/ukr/dotv/ddotv_ukr.cpp new file mode 100644 index 0000000000..19c4957423 --- /dev/null +++ b/gtestsuite/testsuite/ukr/dotv/ddotv_ukr.cpp @@ -0,0 +1,338 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_dotv_ukr.h" + +class ddotvGeneric : + public ::testing::TestWithParam> {}; // is_memory_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ddotvGeneric); + +// Tests using random integers as vector elements. +TEST_P( ddotvGeneric, UKR ) +{ + using T = double; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes the kernel to be tested: + ddotv_ker_ft ukr = std::get<0>(GetParam()); + // denotes whether vec x is n,c + char conjx = std::get<1>(GetParam()); + // denotes whether vec y is n,c + char conjy = std::get<2>(GetParam()); + // vector length: + gtint_t n = std::get<3>(GetParam()); + // stride size for x: + gtint_t incx = std::get<4>(GetParam()); + // stride size for y: + gtint_t incy = std::get<5>(GetParam()); + // enable/disable memory test: + bool is_memory_test = std::get<6>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite level1/dotv/dotv.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (n == 0) + thresh = 0.0; + else + thresh = 2*n*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_dotv_ukr( ukr, conjx, conjy, n, incx, incy, thresh, is_memory_test ); +} + +// ---------------------------------------------- +// ----- Begin ZEN1/2/3 (AVX2) Kernel Tests ----- +// ---------------------------------------------- +#if defined(BLIS_KERNELS_ZEN) && defined(GTEST_AVX2FMA3) +// Tests for bli_ddotv_zen_int (AVX2) kernel. +/** + * Loops: + * L16 - handles 16 elements + * LScalar - leftover loop (also handles non-unit increments) +*/ +INSTANTIATE_TEST_SUITE_P( + bli_ddotv_zen_int_unitStride, + ddotvGeneric, + ::testing::Combine( + ::testing::Values(bli_ddotv_zen_int), + // conj(x): use n (no_conjugate) since it is real. + ::testing::Values('n'), + // conj(y): use n (no_conjugate) since it is real. + ::testing::Values('n'), + // m: size of vector. + ::testing::Values( + // testing each loop individually. + gtint_t(32), // L16, executed twice + gtint_t(16), // L16 + gtint_t( 8), // LScalar, executed 8 times + gtint_t( 1), // LScalar + + // testing entire set of loops. + gtint_t(33), // L16 (executed twice) + LScalar + gtint_t(17), // L16 and LScalar + gtint_t(18) // L16 and LScalar (executed twice) + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1) // unit stride + ), + // incy: stride of y vector. + ::testing::Values( + gtint_t(1) // unit stride + ), + // is_memory_test: enable/disable memory tests + ::testing::Values( false, true ) + ), + ::dotvUKRPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + bli_ddotv_zen_int_nonUnitPositiveStrides, + ddotvGeneric, + ::testing::Combine( + ::testing::Values(bli_ddotv_zen_int), + // conj(x): uses n (no_conjugate) since it is real. + ::testing::Values('n'), + // conj(y): uses n (no_conjugate) since it is real. + ::testing::Values('n'), + // m: size of vector. + ::testing::Values( + gtint_t(3), gtint_t(30), gtint_t(112) + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(3), gtint_t(7) // few non-unit strides for sanity check + ), + // incy: stride of y vector. + ::testing::Values( + gtint_t(3), gtint_t(7) // few non-unit strides for sanity check + ), + // is_memory_test: enable/disable memory tests + ::testing::Values( false, true ) + ), + ::dotvUKRPrint() + ); + +// Tests for bli_ddotv_zen_int10 (AVX2) kernel. +/** + * Loops: + * L40 - Main loop, handles 40 elements + * L20 - handles 20 elements + * L16 - handles 16 elements + * L8 - handles 8 elements + * L4 - handles 4 elements + * LScalar - leftover loop + * + * LNUnit - loop for non-unit increments +*/ +INSTANTIATE_TEST_SUITE_P( + bli_ddotv_zen_int10_unitStride, + ddotvGeneric, + ::testing::Combine( + ::testing::Values(bli_ddotv_zen_int10), + // conj(x): uses n (no_conjugate) since it is real. + ::testing::Values('n'), + // conj(y): uses n (no_conjugate) since it is real. + ::testing::Values('n'), + // m: size of vector. + ::testing::Values( + // testing each loop individually. + gtint_t(80), // L40, executed twice + gtint_t(40), // L40 + gtint_t(20), // L20 + gtint_t(16), // L16 + gtint_t( 8), // L8 + gtint_t( 4), // L4 + gtint_t( 2), // LScalar + gtint_t( 1), // LScalar + + // testing entire set of loops starting from loop m to n. + gtint_t(73), // L40 through LScalar, excludes L16 + gtint_t(33), // L20 through LScalar, excludes L16 + gtint_t(13), // L8 through LScalar + gtint_t( 5), // L4 through LScalar + + // testing few combinations including L16. + gtint_t(77), // L40 + L20 + L16 + LScalar + gtint_t(76), // L40 + L20 + L16 + gtint_t(57), // L40 + L16 + LScalar + gtint_t(37) // L20 + L16 + LScalar + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1) // unit stride + ), + // incy: stride of y vector. + ::testing::Values( + gtint_t(1) // unit stride + ), + // is_memory_test: enable/disable memory tests + ::testing::Values( false, true ) + ), + ::dotvUKRPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + bli_ddotv_zen_int10_nonUnitPositiveStrides, + ddotvGeneric, + ::testing::Combine( + ::testing::Values(bli_ddotv_zen_int10), + // conj(x): uses n (no_conjugate) since it is real. + ::testing::Values('n'), + // conj(y): uses n (no_conjugate) since it is real. + ::testing::Values('n'), + // m: size of vector. + ::testing::Values( + gtint_t(3), gtint_t(30), gtint_t(112) + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(3), gtint_t(7) // few non-unit strides for sanity check + ), + // incy: stride of y vector. + ::testing::Values( + gtint_t(3), gtint_t(7) // few non-unit strides for sanity check + ), + // is_memory_test: enable/disable memory tests + ::testing::Values( false, true ) + ), + ::dotvUKRPrint() + ); +#endif +// ---------------------------------------------- +// ----- End ZEN1/2/3 (AVX2) Kernel Tests ----- +// ---------------------------------------------- + + +// ---------------------------------------------- +// ----- Begin ZEN4 (AVX512) Kernel Tests ----- +// ---------------------------------------------- +#if defined(BLIS_KERNELS_ZEN4) && defined(GTEST_AVX512) +// Tests for bli_ddotv_zen_int_avx512 (AVX512) kernel. +/** + * Loops & If conditions: + * L40 - Main loop, handles 40 elements + * L16 - handles 16 elements + * I8 - handles 8 elements + * IScalar - handles upto 8 leftover elements + * + * LNUnit - loop for non-unit increments +*/ +INSTANTIATE_TEST_SUITE_P( + bli_ddotv_zen_int_avx512_unitStride, + ddotvGeneric, + ::testing::Combine( + ::testing::Values(bli_ddotv_zen_int_avx512), + // conj(x): uses n (no_conjugate) since it is real. + ::testing::Values('n'), + // conj(y): uses n (no_conjugate) since it is real. + ::testing::Values('n'), + // m: size of vector. + ::testing::Values( + // Individual Loop Tests + // testing each loop and if individually. + gtint_t(80), // L40, executed twice + gtint_t(40), // L40 + gtint_t(16), // L16 + gtint_t( 8), // I8 + gtint_t( 7), // IScalar + gtint_t( 6), // IScalar + gtint_t( 5), // IScalar + gtint_t( 4), // IScalar + gtint_t( 3), // IScalar + gtint_t( 2), // IScalar + gtint_t( 1), // IScalar + + // Waterfall Tests + // testing the entire set of loops and ifs. + gtint_t(65) + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1) // unit stride + ), + // incy: stride of y vector. + ::testing::Values( + gtint_t(1) // unit stride + ), + // is_memory_test: enable/disable memory tests + ::testing::Values( false, true ) + ), + ::dotvUKRPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + bli_ddotv_zen_int_avx512_nonUnitPositiveStrides, + ddotvGeneric, + ::testing::Combine( + ::testing::Values(bli_ddotv_zen_int_avx512), + // conj(x): uses n (no_conjugate) since it is real. + ::testing::Values('n'), + // conj(y): uses n (no_conjugate) since it is real. + ::testing::Values('n'), + // m: size of vector. + ::testing::Values( + gtint_t(3), gtint_t(30), gtint_t(112) + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(3), gtint_t(7) // few non-unit strides for sanity check + ), + // incy: stride of y vector. + ::testing::Values( + gtint_t(3), gtint_t(7) // few non-unit strides for sanity check + ), + // is_memory_test: enable/disable memory tests + ::testing::Values( false, true ) + ), + ::dotvUKRPrint() + ); +#endif +// ---------------------------------------------- +// ----- End ZEN4 (AVX512) Kernel Tests ----- +// ---------------------------------------------- diff --git a/gtestsuite/testsuite/ukr/dotv/test_dotv_ukr.h b/gtestsuite/testsuite/ukr/dotv/test_dotv_ukr.h new file mode 100644 index 0000000000..ca056edcc1 --- /dev/null +++ b/gtestsuite/testsuite/ukr/dotv/test_dotv_ukr.h @@ -0,0 +1,151 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#pragma once + +#include +#include "level1/dotv/dotv.h" +#include "level1/ref_dotv.h" +#include "inc/check_error.h" +#include "common/testing_helpers.h" + +/** + * @brief Microkernel test body for dotv operation. + */ + +template +static void test_dotv_ukr( FT ukr, char conjx, char conjy, gtint_t n, gtint_t incx, + gtint_t incy, double thresh, bool is_memory_test = false ) +{ + // Obtain and allocate memory for vectors. + T *x, *y, *y_ref; + + gtint_t size_x = testinghelpers::buff_dim( n, incx ); + gtint_t size_y = testinghelpers::buff_dim( n, incy ); + + testinghelpers::ProtectedBuffer x_buf( size_x * sizeof( T ), false, is_memory_test ); + testinghelpers::ProtectedBuffer y_buf( size_y * sizeof( T ), false, is_memory_test ); + + // No redzones are required for y_ref buffer thus, we pass is_memory_test = false. + testinghelpers::ProtectedBuffer y_ref_buf( size_y * sizeof( T ), false, false ); + + // Acquire the first set of greenzones for x and y + x = ( T* )x_buf.greenzone_1; + y = ( T* )y_buf.greenzone_1; + y_ref = ( T* )y_ref_buf.greenzone_1; // For y_ref, there is no greenzone_2 + + // Initialize the vectors with random data. + testinghelpers::datagenerators::randomgenerators( -10, 10, n, incx, x ); + testinghelpers::datagenerators::randomgenerators( -10, 10, n, incy, y ); + + // Copying the contents of y to y_ref, for comparision after computation. + memcpy( y_ref, y, size_y * sizeof( T ) ); + + T rho; + // Create a copy of rho so that we can check reference results. + T rho_ref; + + // conj? conversion to BLIS conjugate type. + conj_t blis_conjx, blis_conjy; + testinghelpers::char_to_blis_conj( conjx, &blis_conjx ); + testinghelpers::char_to_blis_conj( conjy, &blis_conjy ); + + // Add signal handler for Segmentation Faults. + testinghelpers::ProtectedBuffer::start_signal_handler(); + try + { + // Invoking BLIS ukr. + // This will check for out of bounds access within first redzone. + ukr( blis_conjx, blis_conjy, n, x, incx, y, incy, &rho, nullptr ); + + if ( is_memory_test ) + { + // Acquire the pointers near the second redzone. + x = ( T* )x_buf.greenzone_2; + y = ( T* )y_buf.greenzone_2; + + // Copy the data for x and y accordingly. + memcpy( x, x_buf.greenzone_1, size_x * sizeof( T ) ); + memcpy( y, y_ref_buf.greenzone_1, size_y * sizeof( T ) ); + + // Invoking BLIS ukr to check with the second redzone. + ukr( blis_conjx, blis_conjy, n, x, incx, y, incy, &rho, nullptr ); + } + } + catch( const std::exception& e ) + { + // Reset to default signal handler. + testinghelpers::ProtectedBuffer::stop_signal_handler(); + + // Show failure in case Segmentation Fault was detected. + FAIL() << "Memory Test Failed"; + } + + // Reset to default signal handler. + testinghelpers::ProtectedBuffer::stop_signal_handler(); + + // Invoking the reference implementation to get reference results. + if constexpr (testinghelpers::type_info::is_real) + testinghelpers::ref_dotv( n, x, incx, y_ref, incy, &rho_ref ); + else + testinghelpers::ref_dotv( conjx, conjy, n, x, incx, y_ref, incy, &rho_ref ); + + // Compute component-wise error. + computediff( "rho", rho, rho_ref, thresh ); +} + + +// Test-case logger : Used to print the test-case details based on parameters +template +class dotvUKRPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char conjx = std::get<1>(str.param); + char conjy = std::get<2>(str.param); + gtint_t n = std::get<3>(str.param); + gtint_t incx = std::get<4>(str.param); + gtint_t incy = std::get<5>(str.param); + bool is_memory_test = std::get<6>(str.param); + + std::string str_name = "_n_" + std::to_string(n); + str_name += "_conjx_" + std::string(&conjx, 1); + str_name += "_conjy_" + std::string(&conjy, 1); + str_name += "_incx_" + testinghelpers::get_value_string(incx); + str_name += "_incy_" + testinghelpers::get_value_string(incy); + str_name += ( is_memory_test ) ? "_mem_test_enabled" : "_mem_test_disabled"; + + return str_name; + } +}; diff --git a/gtestsuite/testsuite/ukr/dotv/zdotv_ukr.cpp b/gtestsuite/testsuite/ukr/dotv/zdotv_ukr.cpp new file mode 100644 index 0000000000..a16a66b619 --- /dev/null +++ b/gtestsuite/testsuite/ukr/dotv/zdotv_ukr.cpp @@ -0,0 +1,258 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_dotv_ukr.h" + +using T = dcomplex; +class zdotvGeneric : + public ::testing::TestWithParam> {}; // is_memory_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(zdotvGeneric); + +// Tests using random integers as vector elements. +TEST_P( zdotvGeneric, UKR ) +{ + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes the kernel to be tested: + zdotv_ker_ft ukr = std::get<0>(GetParam()); + // denotes whether vec x is n,c + char conjx = std::get<1>(GetParam()); + // denotes whether vec y is n,c + char conjy = std::get<2>(GetParam()); + // vector length: + gtint_t n = std::get<3>(GetParam()); + // stride size for x: + gtint_t incx = std::get<4>(GetParam()); + // stride size for y: + gtint_t incy = std::get<5>(GetParam()); + // enable/disable memory test: + bool is_memory_test = std::get<6>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite level1/dotv/dotv.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (n == 0) + thresh = 0.0; + else + thresh = 2*n*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_dotv_ukr( ukr, conjx, conjy, n, incx, incy, thresh, is_memory_test ); +} + +// ---------------------------------------------- +// ----- Begin ZEN4 (AVX512) Kernel Tests ----- +// ---------------------------------------------- +#if defined(BLIS_KERNELS_ZEN4) && defined(GTEST_AVX512) +// Tests for bli_zdotv_zen_int_avx512 (AVX512) kernel. +/** + * Loops & If conditions: + * L32 - Main loop, handles 32 elements + * L16 - handles 16 elements + * L12 - handles 12 elements + * L8 - handles 8 elements + * L4 - handles 4 elements + * LFringe - handles upto 4 leftover elements + * + * LNUnit - loop for non-unit increments +*/ +INSTANTIATE_TEST_SUITE_P( + bli_zdotv_zen_int_avx512_unitStride, + zdotvGeneric, + ::testing::Combine( + ::testing::Values(bli_zdotv_zen_int_avx512), + // conj(x): use n (no_conjugate) or c (conjugate). + ::testing::Values('n', 'c'), + // conj(y): use n (no_conjugate) or c (conjugate). + ::testing::Values('n', 'c'), + // m: size of vector. + ::testing::Values( + // Individual Loop Tests + // testing each loop and if individually. + gtint_t(64), // L32, executed twice + gtint_t(32), // L32 + gtint_t(16), // L16 + gtint_t(12), // L12 + gtint_t( 8), // L8 + gtint_t( 4), // LFringe + gtint_t( 3), // LFringe + gtint_t( 2), // LFringe + gtint_t( 1), // LFringe + + // Waterfall Tests + // testing the entire set of loops and ifs. + gtint_t(92), // L32 * 2 + L16 + L12 + gtint_t(91), // L32 * 2 + L16 + L8 + L4 + LFringe * 3 + gtint_t(79) // L32 * 2 + L12 + LFringe + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1) // unit stride + ), + // incy: stride of y vector. + ::testing::Values( + gtint_t(1) // unit stride + ), + // is_memory_test: enable/disable memory tests + ::testing::Values( false, true ) + ), + ::dotvUKRPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + bli_zdotv_zen_int_avx512_nonUnitPositiveStrides, + zdotvGeneric, + ::testing::Combine( + ::testing::Values(bli_zdotv_zen_int_avx512), + // conj(x): uses n (no_conjugate) since it is real. + ::testing::Values('n'), + // conj(y): uses n (no_conjugate) since it is real. + ::testing::Values('n'), + // m: size of vector. + ::testing::Values( + gtint_t(3), gtint_t(30), gtint_t(112) + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(3), gtint_t(7) // few non-unit strides for sanity check + ), + // incy: stride of y vector. + ::testing::Values( + gtint_t(3), gtint_t(7) // few non-unit strides for sanity check + ), + // is_memory_test: enable/disable memory tests + ::testing::Values( false, true ) + ), + ::dotvUKRPrint() + ); + +// Tests for bli_zdotv_zen_int_avx512 (AVX512) kernel. +/** + * Loops & If conditions: + * L32 - Main loop, handles 32 elements + * L16 - handles 16 elements + * L12 - handles 12 elements + * L8 - handles 8 elements + * L4 - handles 4 elements + * LFringe - handles upto 4 leftover elements + * + * LNUnit - loop for non-unit increments +*/ +INSTANTIATE_TEST_SUITE_P( + DISABLED_bli_zdotv_zen4_asm_avx512_unitStride, + zdotvGeneric, + ::testing::Combine( + ::testing::Values(bli_zdotv_zen4_asm_avx512), + // conj(x): use n (no_conjugate) or c (conjugate). + ::testing::Values('n', 'c'), + // conj(y): use n (no_conjugate) or c (conjugate). + ::testing::Values('n', 'c'), + // m: size of vector. + ::testing::Values( + // Individual Loop Tests + // testing each loop and if individually. + gtint_t(64), // L40, executed twice + gtint_t(32), // L40 + gtint_t(16), // L16 + gtint_t(12), // L12 + gtint_t( 8), // L8 + gtint_t( 4), // LFringe + gtint_t( 3), // LFringe + gtint_t( 2), // LFringe + gtint_t( 1), // LFringe + + // Waterfall Tests + // testing the entire set of loops and ifs. + gtint_t(92), // L32 * 2 + L16 + L12 + gtint_t(91), // L32 * 2 + L16 + L8 + L4 + LFringe * 3 + gtint_t(79) // L32 * 2 + L12 + LFringe + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1) // unit stride + ), + // incy: stride of y vector. + ::testing::Values( + gtint_t(1) // unit stride + ), + // is_memory_test: enable/disable memory tests + ::testing::Values( false, true ) + ), + ::dotvUKRPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + bli_zdotv_zen4_asm_avx512_nonUnitPositiveStrides, + zdotvGeneric, + ::testing::Combine( + ::testing::Values(bli_zdotv_zen4_asm_avx512), + // conj(x): uses n (no_conjugate) since it is real. + ::testing::Values('n'), + // conj(y): uses n (no_conjugate) since it is real. + ::testing::Values('n'), + // m: size of vector. + ::testing::Values( + gtint_t(3), gtint_t(30), gtint_t(112) + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(3), gtint_t(7) // few non-unit strides for sanity check + ), + // incy: stride of y vector. + ::testing::Values( + gtint_t(3), gtint_t(7) // few non-unit strides for sanity check + ), + // is_memory_test: enable/disable memory tests + ::testing::Values( false, true ) + ), + ::dotvUKRPrint() + ); +#endif +// ---------------------------------------------- +// ----- End ZEN4 (AVX512) Kernel Tests ----- +// ---------------------------------------------- diff --git a/gtestsuite/testsuite/ukr/gemm/cgemm/cgemm_ukernel.cpp b/gtestsuite/testsuite/ukr/gemm/cgemm/cgemm_ukernel.cpp new file mode 100644 index 0000000000..f967787bb2 --- /dev/null +++ b/gtestsuite/testsuite/ukr/gemm/cgemm/cgemm_ukernel.cpp @@ -0,0 +1,723 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "blis.h" +#include "common/testing_helpers.h" +#include "ukr/gemm/test_complex_gemm_ukr.h" + +/*******************************************************/ +/* SUP Kernel testing */ +/*******************************************************/ +class cgemmGenericSUP: + public ::testing::TestWithParam> {}; + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(cgemmGenericSUP); + +TEST_P( cgemmGenericSUP, UKR ) +{ + using T = scomplex; + gtint_t m = std::get<0>(GetParam()); // dimension m + gtint_t n = std::get<1>(GetParam()); // dimension n + gtint_t k = std::get<2>(GetParam()); // dimension k + T alpha = std::get<3>(GetParam()); // alpha + T beta = std::get<4>(GetParam()); // beta + char storageC = std::get<5>(GetParam()); // storage scheme for C matrix + cgemmsup_ker_ft kern_ptr = std::get<6>(GetParam()); // pointer to the gemm kernel + char transa = std::get<7>(GetParam()); // transa + char transb = (storageC == 'r')? 'n' : 't'; // transb + bool is_memory_test = std::get<8>(GetParam()); // is_memory_test + + // Set the threshold for the errors: + // Check gtestsuite gemm.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (m == 0 || n == 0) + thresh = 0.0; + else if ((alpha == testinghelpers::ZERO() || k == 0) && (beta == testinghelpers::ZERO() || + beta == testinghelpers::ONE())) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + thresh = testinghelpers::getEpsilon(); + else + thresh = (3*k+1)*testinghelpers::getEpsilon(); + + test_complex_gemmsup_ukr (storageC, transa, transb, m, n, k, alpha, beta, thresh, kern_ptr, is_memory_test); +}// end of function + +class cgemmGenericSUPPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + + gtint_t m = std::get<0>(str.param); + gtint_t n = std::get<1>(str.param); + gtint_t k = std::get<2>(str.param); + scomplex alpha = std::get<3>(str.param); + scomplex beta = std::get<4>(str.param); + char storageC = std::get<5>(str.param); + char transa = std::get<7>(str.param); + char transb = (storageC == 'r')? 'n' : 't'; + bool is_memory_test = std::get<8>(str.param); + + std::string str_name; + str_name += "_stor_" + std::string(&storageC, 1); + str_name += "_transa_" + std::string(&transa, 1); + str_name += "_transb_" + std::string(&transb, 1); + str_name += "_m_" + std::to_string(m); + str_name += "_n_" + std::to_string(n); + str_name += "_k_" + std::to_string(k); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + str_name += "_beta_" + testinghelpers::get_value_string(beta); + str_name += ( is_memory_test ) ? "_mem_test_enabled" : "_mem_test_disabled"; + + return str_name; + } +}; + +/*********************************************************/ +/* Stroage Formats For SUP Kernels */ +/* A Matrix: Broadcast instruction is applied on Matrix */ +/* hence it can be row or col stored */ +/* trana = 'n' or 't' */ +/* B Matrix: Load instruction is applied on Matrix */ +/* hence it has to be row stored */ +/* When storage = r, transb = 'n' */ +/* When storage = c, transb = 't' */ +/* C Matrix: Supports row or col storage */ +/*********************************************************/ + +#if defined(BLIS_KERNELS_ZEN) && defined(GTEST_AVX2FMA3) + +/*************************************************/ +/***********Choosing values of m, n, k************/ +/* m is vectorised for 3 */ +/* - main kernel : 3, 6 (3x2) */ +/* - fringe case : 1, 2 */ +/* - main kernel and fringe case: */ +/* 4(3+1), 5(3+2), 7(3x2+1), 8(3x2+2) */ +/* n is vectorised for 4 and 2 */ +/* - main kernel : 4, 2, 1(gemv) */ +/* - main kernel and fringe case: */ +/* 3(2+1), 5(4+1), 6(4+2), 7(4+2+1) */ +/* k is unrolled 4 times */ +/* - main loop : 4, 8 */ +/* - fringe loop : 1, 2 */ +/* - main and fringe 5, 6, 9, 10 */ +/*************************************************/ + +/*Failures*/ +/* 1. blis_sol[i*ld + j] = (0.856704, 0.625597), ref_sol[i*ld + j] = (0.856718, 0.625608), i = 5, j = 0, thresh = 9.5367431640625e-06, error = 1.7269374438910745e-05 (144.86601257324219 * eps) +[ FAILED ] bli_cgemmsup_rv_zen_asm_3x8m/cgemmGenericSUP.FunctionalTest/StorageOfMatrix_r_transA_t_transB_n_m_6_n_8_k_4_alpha_3i4_beta_m7i6_mem_test_disabled, where GetParam() = (6, 8, 4, (3, 4.5), (-7.3, 6.7), 'r' (114, 0x72), 0x5576cdf96cc7, 't' (116, 0x74), 'n' (110, 0x6E), false) (0 ms) */ + +INSTANTIATE_TEST_SUITE_P ( + bli_cgemmsup_rv_zen_asm_3x8m, + cgemmGenericSUP, + ::testing::Combine( + ::testing::Range(gtint_t(1), gtint_t(8), 1), // values of m + ::testing::Range(gtint_t(1), gtint_t(9), 1), // values of n + ::testing::Range(gtint_t(0), gtint_t(10), 1), // values of k + ::testing::Values(scomplex{3, 4}), // alpha value + ::testing::Values(scomplex{-7.3, 6.7}), // beta value + ::testing::Values('r', 'c'), // storage + ::testing::Values(bli_cgemmsup_rv_zen_asm_3x8m), // cgemm_sup kernel + ::testing::Values('n', 't'), // transa + ::testing::Values(false, true) // is_memory_test + ), + ::cgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_cgemmsup_rv_zen_asm_3x8m_alpha_beta, + cgemmGenericSUP, + ::testing::Combine( + ::testing::Range(gtint_t(1), gtint_t(8), 1), // values of m + ::testing::Range(gtint_t(1), gtint_t(9), 1), // values of n + ::testing::Values(gtint_t(10)), // values of k + ::testing::Values(scomplex{0.0, 0.0}, scomplex{1.0, 0.0}, scomplex{-1.0, 0.0}, scomplex{4.0, 0.0}, scomplex{0.0, -5.0}, scomplex{3, 4}), // alpha value + ::testing::Values(scomplex{0.0, 0.0}, scomplex{1.0, 0.0}, scomplex{-1.0, 0.0}, scomplex{-5.0, 0.0}, scomplex{0.0, -5.0}, scomplex{-7.3, 6.7}), // beta value + ::testing::Values('r', 'c'), // storage + ::testing::Values(bli_cgemmsup_rv_zen_asm_3x8m), // cgemm_sup kernel + ::testing::Values('n', 't'), // transa + ::testing::Values(false, true) // is_memory_test + ), + ::cgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_cgemmsup_rv_zen_asm_3x4m, + cgemmGenericSUP, + ::testing::Combine( + ::testing::Range(gtint_t(1), gtint_t(8), 1), // values of m + ::testing::Values(gtint_t(4)), // values of n + ::testing::Range(gtint_t(0), gtint_t(10), 1), // values of k + ::testing::Values(scomplex{3, 4}), // alpha value + ::testing::Values(scomplex{-7.3, 6.7}), // beta value + ::testing::Values('r', 'c'), // storage + ::testing::Values(bli_cgemmsup_rv_zen_asm_3x4m), // cgemm_sup kernel + ::testing::Values('n', 't'), // transa + ::testing::Values(false, true) // is_memory_test + ), + ::cgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_cgemmsup_rv_zen_asm_3x4m_alpha_beta, + cgemmGenericSUP, + ::testing::Combine( + ::testing::Range(gtint_t(1), gtint_t(8), 1), // values of m + ::testing::Values(gtint_t(4)), // values of n + ::testing::Values(gtint_t(10)), // values of k + ::testing::Values(scomplex{0.0, 0.0}, scomplex{1.0, 0.0}, scomplex{-1.0, 0.0}, scomplex{4.0, 0.0}, scomplex{0.0, -5.0}, scomplex{3, 4}), // alpha value + ::testing::Values(scomplex{0.0, 0.0}, scomplex{1.0, 0.0}, scomplex{-1.0, 0.0}, scomplex{-5.0, 0.0}, scomplex{0.0, -5.0}, scomplex{-7.3, 6.7}), // beta value + ::testing::Values('r', 'c'), // storage + ::testing::Values(bli_cgemmsup_rv_zen_asm_3x4m), // cgemm_sup kernel + ::testing::Values('n', 't'), // transa + ::testing::Values(false, true) // is_memory_test + ), + ::cgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_cgemmsup_rv_zen_asm_3x2m, + cgemmGenericSUP, + ::testing::Combine( + ::testing::Range(gtint_t(1), gtint_t(8), 1), // values of m + ::testing::Values(gtint_t(2)), // values of n + ::testing::Range(gtint_t(0), gtint_t(10), 1), // values of k + ::testing::Values(scomplex{3, 4}), // alpha value + ::testing::Values(scomplex{-7.3, 6.7}), // beta value + ::testing::Values('r', 'c'), // storage + ::testing::Values(bli_cgemmsup_rv_zen_asm_3x2m), // cgemm_sup kernel + ::testing::Values('n', 't'), // transa + ::testing::Values(false, true) // is_memory_test + ), + ::cgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_cgemmsup_rv_zen_asm_3x2m_alpha_beta, + cgemmGenericSUP, + ::testing::Combine( + ::testing::Range(gtint_t(1), gtint_t(8), 1), // values of m + ::testing::Values(gtint_t(2)), // values of n + ::testing::Values(gtint_t(10)), // values of k + ::testing::Values(scomplex{0.0, 0.0}, scomplex{1.0, 0.0}, scomplex{-1.0, 0.0}, scomplex{4.0, 0.0}, scomplex{0.0, -5.0}, scomplex{3, 4}), // alpha value + ::testing::Values(scomplex{0.0, 0.0}, scomplex{1.0, 0.0}, scomplex{-1.0, 0.0}, scomplex{-5.0, 0.0}, scomplex{0.0, -5.0}, scomplex{-7.3, 6.7}), // beta value + ::testing::Values('r', 'c'), // storage + ::testing::Values(bli_cgemmsup_rv_zen_asm_3x2m), // cgemm_sup kernel + ::testing::Values('n', 't'), // transa + ::testing::Values(false, true) // is_memory_test + ), + ::cgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_cgemmsup_rv_zen_asm_3x8n, + cgemmGenericSUP, + ::testing::Combine( + ::testing::Range(gtint_t(1), gtint_t(4), 1), // values of m + ::testing::Range(gtint_t(1), gtint_t(16), 1), // values of n + ::testing::Range(gtint_t(0), gtint_t(10), 1), // values of k + ::testing::Values(scomplex{3, 4}), // alpha value + ::testing::Values(scomplex{-7.3, 6.7}), // beta value + ::testing::Values('r', 'c'), // storage + ::testing::Values(bli_cgemmsup_rv_zen_asm_3x8n), // cgemm_sup kernel + ::testing::Values('n', 't'), // transa + ::testing::Values(false, true) // is_memory_test + ), + ::cgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_cgemmsup_rv_zen_asm_3x8n_alpha_beta, + cgemmGenericSUP, + ::testing::Combine( + ::testing::Range(gtint_t(1), gtint_t(4), 1), // values of m + ::testing::Range(gtint_t(1), gtint_t(16), 1), // values of n + ::testing::Values(gtint_t(10)), // values of k + ::testing::Values(scomplex{0.0, 0.0}, scomplex{1.0, 0.0}, scomplex{-1.0, 0.0}, scomplex{4.0, 0.0}, scomplex{0.0, -5.0}, scomplex{3, 4}), // alpha value + ::testing::Values(scomplex{0.0, 0.0}, scomplex{1.0, 0.0}, scomplex{-1.0, 0.0}, scomplex{-5.0, 0.0}, scomplex{0.0, -5.0}, scomplex{-7.3, 6.7}), // beta value + ::testing::Values('r', 'c'), // storage + ::testing::Values(bli_cgemmsup_rv_zen_asm_3x8n), // cgemm_sup kernel + ::testing::Values('n', 't'), // transa + ::testing::Values(false, true) // is_memory_test + ), + ::cgemmGenericSUPPrint() + ); + +#if 0 +//Memtest fails +//Memtest diabled free(): invalid next size (fast) +INSTANTIATE_TEST_SUITE_P ( + bli_cgemmsup_rv_zen_asm_2x8n, + cgemmGenericSUP, + ::testing::Combine( + ::testing::Range(gtint_t(1), gtint_t(3), 1), // values of m + ::testing::Range(gtint_t(1), gtint_t(16), 1), // values of n + ::testing::Range(gtint_t(0), gtint_t(10), 1), // values of k + ::testing::Values(scomplex{3, 4}), // alpha value + ::testing::Values(scomplex{-7.3, 6.7}), // beta value + ::testing::Values('r', 'c'), // storage + ::testing::Values(bli_cgemmsup_rv_zen_asm_2x8n), // cgemm_sup kernel + ::testing::Values('n', 't'), // transa + ::testing::Values(false) // is_memory_test + ), + ::cgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_cgemmsup_rv_zen_asm_2x8n_alpha_beta, + cgemmGenericSUP, + ::testing::Combine( + ::testing::Range(gtint_t(1), gtint_t(3), 1), // values of m + ::testing::Range(gtint_t(1), gtint_t(16), 1), // values of n + ::testing::Values(gtint_t(10)), // values of k + ::testing::Values(scomplex{0.0, 0.0}, scomplex{1.0, 0.0}, scomplex{-1.0, 0.0}, scomplex{4.0, 0.0}, scomplex{0.0, -5.0}, scomplex{3, 4}), // alpha value + ::testing::Values(scomplex{0.0, 0.0}, scomplex{1.0, 0.0}, scomplex{-1.0, 0.0}, scomplex{-5.0, 0.0}, scomplex{0.0, -5.0}, scomplex{-7.3, 6.7}), // beta value + ::testing::Values('r', 'c'), // storage + ::testing::Values(bli_cgemmsup_rv_zen_asm_2x8n), // cgemm_sup kernel + ::testing::Values('n', 't'), // transa + ::testing::Values(false) // is_memory_test + ), + ::cgemmGenericSUPPrint() + ); + +#endif +INSTANTIATE_TEST_SUITE_P ( + bli_cgemmsup_rv_zen_asm_1x8n, + cgemmGenericSUP, + ::testing::Combine( + ::testing::Values(gtint_t(1)), // values of m + ::testing::Range(gtint_t(1), gtint_t(16), 1), // values of n + ::testing::Range(gtint_t(0), gtint_t(10), 1), // values of k + ::testing::Values(scomplex{3, 4}), // alpha value + ::testing::Values(scomplex{-7.3, 6.7}), // beta value + ::testing::Values('r', 'c'), // storage + ::testing::Values(bli_cgemmsup_rv_zen_asm_1x8n), // cgemm_sup kernel + ::testing::Values('n', 't'), // transa + ::testing::Values(false, true) // is_memory_test + ), + ::cgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_cgemmsup_rv_zen_asm_1x8n_alpha_beta, + cgemmGenericSUP, + ::testing::Combine( + ::testing::Values(gtint_t(1)), // values of m + ::testing::Range(gtint_t(1), gtint_t(16), 1), // values of n + ::testing::Values(gtint_t(10)), // values of k + ::testing::Values(scomplex{0.0, 0.0}, scomplex{1.0, 0.0}, scomplex{-1.0, 0.0}, scomplex{4.0, 0.0}, scomplex{0.0, -5.0}, scomplex{3, 4}), // alpha value + ::testing::Values(scomplex{0.0, 0.0}, scomplex{1.0, 0.0}, scomplex{-1.0, 0.0}, scomplex{-5.0, 0.0}, scomplex{0.0, -5.0}, scomplex{-7.3, 6.7}), // beta value + ::testing::Values('r', 'c'), // storage + ::testing::Values(bli_cgemmsup_rv_zen_asm_1x8n), // cgemm_sup kernel + ::testing::Values('n', 't'), // transa + ::testing::Values(false, true) // is_memory_test + ), + ::cgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_cgemmsup_rv_zen_asm_3x4, + cgemmGenericSUP, + ::testing::Combine( + ::testing::Values(gtint_t(3)), // values of m + ::testing::Values(gtint_t(4)), // values of n + ::testing::Range(gtint_t(0), gtint_t(10), 1), // values of k + ::testing::Values(scomplex{3, 4}), // alpha value + ::testing::Values(scomplex{-7.3, 6.7}), // beta value + ::testing::Values('r', 'c'), // storage + ::testing::Values(bli_cgemmsup_rv_zen_asm_3x4), // cgemm_sup kernel + ::testing::Values('n', 't'), // transa + ::testing::Values(false, true) // is_memory_test + ), + ::cgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_cgemmsup_rv_zen_asm_3x4_alpha_beta, + cgemmGenericSUP, + ::testing::Combine( + ::testing::Values(gtint_t(3)), // values of m + ::testing::Values(gtint_t(4)), // values of n + ::testing::Values(gtint_t(10)), // values of k + ::testing::Values(scomplex{0.0, 0.0}, scomplex{1.0, 0.0}, scomplex{-1.0, 0.0}, scomplex{4.0, 0.0}, scomplex{0.0, -5.0}, scomplex{3, 4}), // alpha value + ::testing::Values(scomplex{0.0, 0.0}, scomplex{1.0, 0.0}, scomplex{-1.0, 0.0}, scomplex{-5.0, 0.0}, scomplex{0.0, -5.0}, scomplex{-7.3, 6.7}), // beta value + ::testing::Values('r', 'c'), // storage + ::testing::Values(bli_cgemmsup_rv_zen_asm_3x4), // cgemm_sup kernel + ::testing::Values('n', 't'), // transa + ::testing::Values(false, true) // is_memory_test + ), + ::cgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_cgemmsup_rv_zen_asm_3x2, + cgemmGenericSUP, + ::testing::Combine( + ::testing::Values(gtint_t(3)), // values of m + ::testing::Values(gtint_t(2)), // values of n + ::testing::Range(gtint_t(0), gtint_t(10), 1), // values of k + ::testing::Values(scomplex{3, 4}), // alpha value + ::testing::Values(scomplex{-7.3, 6.7}), // beta value + ::testing::Values('r', 'c'), // storage + ::testing::Values(bli_cgemmsup_rv_zen_asm_3x2), // cgemm_sup kernel + ::testing::Values('n', 't'), // transa + ::testing::Values(false, true) // is_memory_test + ), + ::cgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_cgemmsup_rv_zen_asm_3x2_alpha_beta, + cgemmGenericSUP, + ::testing::Combine( + ::testing::Values(gtint_t(3)), // values of m + ::testing::Values(gtint_t(2)), // values of n + ::testing::Values(gtint_t(10)), // values of k + ::testing::Values(scomplex{0.0, 0.0}, scomplex{1.0, 0.0}, scomplex{-1.0, 0.0}, scomplex{4.0, 0.0}, scomplex{0.0, -5.0}, scomplex{3, 4}), // alpha value + ::testing::Values(scomplex{0.0, 0.0}, scomplex{1.0, 0.0}, scomplex{-1.0, 0.0}, scomplex{-5.0, 0.0}, scomplex{0.0, -5.0}, scomplex{-7.3, 6.7}), // beta value + ::testing::Values('r', 'c'), // storage + ::testing::Values(bli_cgemmsup_rv_zen_asm_3x2), // cgemm_sup kernel + ::testing::Values('n', 't'), // transa + ::testing::Values(false, true) // is_memory_test + ), + ::cgemmGenericSUPPrint() + ); + + INSTANTIATE_TEST_SUITE_P ( + bli_cgemmsup_rv_zen_asm_2x8, + cgemmGenericSUP, + ::testing::Combine( + ::testing::Values(gtint_t(2)), // values of m + ::testing::Values(gtint_t(8)), // values of n + ::testing::Range(gtint_t(0), gtint_t(10), 1), // values of k + ::testing::Values(scomplex{3, 4}), // alpha value + ::testing::Values(scomplex{-7.3, 6.7}), // beta value + ::testing::Values('r', 'c'), // storage + ::testing::Values(bli_cgemmsup_rv_zen_asm_2x8), // cgemm_sup kernel + ::testing::Values('n', 't'), // transa + ::testing::Values(false, true) // is_memory_test + ), + ::cgemmGenericSUPPrint() + ); + + INSTANTIATE_TEST_SUITE_P ( + bli_cgemmsup_rv_zen_asm_2x8_alpha_beta, + cgemmGenericSUP, + ::testing::Combine( + ::testing::Values(gtint_t(2)), // values of m + ::testing::Values(gtint_t(8)), // values of n + ::testing::Values(gtint_t(10)), // values of k + ::testing::Values(scomplex{0.0, 0.0}, scomplex{1.0, 0.0}, scomplex{-1.0, 0.0}, scomplex{4.0, 0.0}, scomplex{0.0, -5.0}, scomplex{3, 4}), // alpha value + ::testing::Values(scomplex{0.0, 0.0}, scomplex{1.0, 0.0}, scomplex{-1.0, 0.0}, scomplex{-5.0, 0.0}, scomplex{0.0, -5.0}, scomplex{-7.3, 6.7}), // beta value + ::testing::Values('r', 'c'), // storage + ::testing::Values(bli_cgemmsup_rv_zen_asm_2x8), // cgemm_sup kernel + ::testing::Values('n', 't'), // transa + ::testing::Values(false, true) // is_memory_test + ), + ::cgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_cgemmsup_rv_zen_asm_1x8, + cgemmGenericSUP, + ::testing::Combine( + ::testing::Values(gtint_t(1)), // values of m + ::testing::Values(gtint_t(8)), // values of n + ::testing::Range(gtint_t(0), gtint_t(10), 1), // values of k + ::testing::Values(scomplex{3, 4}), // alpha value + ::testing::Values(scomplex{-7.3, 6.7}), // beta value + ::testing::Values('r', 'c'), // storage + ::testing::Values(bli_cgemmsup_rv_zen_asm_1x8), // cgemm_sup kernel + ::testing::Values('n', 't'), // transa + ::testing::Values(false, true) // is_memory_test + ), + ::cgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_cgemmsup_rv_zen_asm_1x8_alpha_beta, + cgemmGenericSUP, + ::testing::Combine( + ::testing::Values(gtint_t(1)), // values of m + ::testing::Values(gtint_t(8)), // values of n + ::testing::Values(gtint_t(10)), // values of k + ::testing::Values(scomplex{0.0, 0.0}, scomplex{1.0, 0.0}, scomplex{-1.0, 0.0}, scomplex{4.0, 0.0}, scomplex{0.0, -5.0}, scomplex{3, 4}), // alpha value + ::testing::Values(scomplex{0.0, 0.0}, scomplex{1.0, 0.0}, scomplex{-1.0, 0.0}, scomplex{-5.0, 0.0}, scomplex{0.0, -5.0}, scomplex{-7.3, 6.7}), // beta value + ::testing::Values('r', 'c'), // storage + ::testing::Values(bli_cgemmsup_rv_zen_asm_1x8), // cgemm_sup kernel + ::testing::Values('n', 't'), // transa + ::testing::Values(false, true) // is_memory_test + ), + ::cgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_cgemmsup_rv_zen_asm_2x4, + cgemmGenericSUP, + ::testing::Combine( + ::testing::Values(gtint_t(2)), // values of m + ::testing::Values(gtint_t(4)), // values of n + ::testing::Range(gtint_t(0), gtint_t(10), 1), // values of k + ::testing::Values(scomplex{3, 4}), // alpha value + ::testing::Values(scomplex{-7.3, 6.7}), // beta value + ::testing::Values('r', 'c'), // storage + ::testing::Values(bli_cgemmsup_rv_zen_asm_2x4), // cgemm_sup kernel + ::testing::Values('n', 't'), // transa + ::testing::Values(false, true) // is_memory_test + ), + ::cgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_cgemmsup_rv_zen_asm_2x4_alpha_beta, + cgemmGenericSUP, + ::testing::Combine( + ::testing::Values(gtint_t(2)), // values of m + ::testing::Values(gtint_t(4)), // values of n + ::testing::Values(gtint_t(10)), // values of k + ::testing::Values(scomplex{0.0, 0.0}, scomplex{1.0, 0.0}, scomplex{-1.0, 0.0}, scomplex{4.0, 0.0}, scomplex{0.0, -5.0}, scomplex{3, 4}), // alpha value + ::testing::Values(scomplex{0.0, 0.0}, scomplex{1.0, 0.0}, scomplex{-1.0, 0.0}, scomplex{-5.0, 0.0}, scomplex{0.0, -5.0}, scomplex{-7.3, 6.7}), // beta value + ::testing::Values('r', 'c'), // storage + ::testing::Values(bli_cgemmsup_rv_zen_asm_2x4), // cgemm_sup kernel + ::testing::Values('n', 't'), // transa + ::testing::Values(false, true) // is_memory_test + ), + ::cgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_cgemmsup_rv_zen_asm_1x4, + cgemmGenericSUP, + ::testing::Combine( + ::testing::Values(gtint_t(1)), // values of m + ::testing::Values(gtint_t(4)), // values of n + ::testing::Range(gtint_t(0), gtint_t(10), 1), // values of k + ::testing::Values(scomplex{3, 4}), // alpha value + ::testing::Values(scomplex{-7.3, 6.7}), // beta value + ::testing::Values('r', 'c'), // storage + ::testing::Values(bli_cgemmsup_rv_zen_asm_1x4), // cgemm_sup kernel + ::testing::Values('n', 't'), // transa + ::testing::Values(false, true) // is_memory_test + ), + ::cgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_cgemmsup_rv_zen_asm_1x4_alpha_beta, + cgemmGenericSUP, + ::testing::Combine( + ::testing::Values(gtint_t(1)), // values of m + ::testing::Values(gtint_t(4)), // values of n + ::testing::Values(gtint_t(10)), // values of k + ::testing::Values(scomplex{0.0, 0.0}, scomplex{1.0, 0.0}, scomplex{-1.0, 0.0}, scomplex{4.0, 0.0}, scomplex{0.0, -5.0}, scomplex{3, 4}), // alpha value + ::testing::Values(scomplex{0.0, 0.0}, scomplex{1.0, 0.0}, scomplex{-1.0, 0.0}, scomplex{-5.0, 0.0}, scomplex{0.0, -5.0}, scomplex{-7.3, 6.7}), // beta value + ::testing::Values('r', 'c'), // storage + ::testing::Values(bli_cgemmsup_rv_zen_asm_1x4), // cgemm_sup kernel + ::testing::Values('n', 't'), // transa + ::testing::Values(false, true) // is_memory_test + ), + ::cgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_cgemmsup_rv_zen_asm_2x2, + cgemmGenericSUP, + ::testing::Combine( + ::testing::Values(gtint_t(2)), // values of m + ::testing::Values(gtint_t(2)), // values of n + ::testing::Range(gtint_t(0), gtint_t(10), 1), // values of k + ::testing::Values(scomplex{3, 4}), // alpha value + ::testing::Values(scomplex{-7.3, 6.7}), // beta value + ::testing::Values('r', 'c'), // storage + ::testing::Values(bli_cgemmsup_rv_zen_asm_2x2), // cgemm_sup kernel + ::testing::Values('n', 't'), // transa + ::testing::Values(false, true) // is_memory_test + ), + ::cgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_cgemmsup_rv_zen_asm_2x2_alpha_beta, + cgemmGenericSUP, + ::testing::Combine( + ::testing::Values(gtint_t(2)), // values of m + ::testing::Values(gtint_t(2)), // values of n + ::testing::Values(gtint_t(10)), // values of k + ::testing::Values(scomplex{0.0, 0.0}, scomplex{1.0, 0.0}, scomplex{-1.0, 0.0}, scomplex{4.0, 0.0}, scomplex{0.0, -5.0}, scomplex{3, 4}), // alpha value + ::testing::Values(scomplex{0.0, 0.0}, scomplex{1.0, 0.0}, scomplex{-1.0, 0.0}, scomplex{-5.0, 0.0}, scomplex{0.0, -5.0}, scomplex{-7.3, 6.7}), // beta value + ::testing::Values('r', 'c'), // storage + ::testing::Values(bli_cgemmsup_rv_zen_asm_2x2), // cgemm_sup kernel + ::testing::Values('n', 't'), // transa + ::testing::Values(false, true) // is_memory_test + ), + ::cgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_cgemmsup_rv_zen_asm_1x2, + cgemmGenericSUP, + ::testing::Combine( + ::testing::Values(gtint_t(1)), // values of m + ::testing::Values(gtint_t(2)), // values of n + ::testing::Range(gtint_t(0), gtint_t(10), 1), // values of k + ::testing::Values(scomplex{3, 4}), // alpha value + ::testing::Values(scomplex{-7.3, 6.7}), // beta value + ::testing::Values('r', 'c'), // storage + ::testing::Values(bli_cgemmsup_rv_zen_asm_1x2), // cgemm_sup kernel + ::testing::Values('n', 't'), // transa + ::testing::Values(false, true) // is_memory_test + ), + ::cgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_cgemmsup_rv_zen_asm_1x2_alpha_beta, + cgemmGenericSUP, + ::testing::Combine( + ::testing::Values(gtint_t(1)), // values of m + ::testing::Values(gtint_t(2)), // values of n + ::testing::Values(gtint_t(10)), // values of k + ::testing::Values(scomplex{0.0, 0.0}, scomplex{1.0, 0.0}, scomplex{-1.0, 0.0}, scomplex{4.0, 0.0}, scomplex{0.0, -5.0}, scomplex{3, 4}), // alpha value + ::testing::Values(scomplex{0.0, 0.0}, scomplex{1.0, 0.0}, scomplex{-1.0, 0.0}, scomplex{-5.0, 0.0}, scomplex{0.0, -5.0}, scomplex{-7.3, 6.7}), // beta value + ::testing::Values('r', 'c'), // storage + ::testing::Values(bli_cgemmsup_rv_zen_asm_1x2), // cgemm_sup kernel + ::testing::Values('n', 't'), // transa + ::testing::Values(false, true) // is_memory_test + ), + ::cgemmGenericSUPPrint() + ); + +#endif + +/*******************************************************/ +/* Native Kernel testing */ +/*******************************************************/ +class cgemmGenericNat : + public ::testing::TestWithParam> {}; + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(cgemmGenericNat); +TEST_P( cgemmGenericNat, UKR ) +{ + using T = scomplex; + gtint_t k = std::get<0>(GetParam()); // dimension k + T alpha = std::get<1>(GetParam()); // alpha + T beta = std::get<2>(GetParam()); // beta + char storageC = std::get<3>(GetParam()); // indicates storage of all matrix operands + // Fix m and n to MR and NR respectively. + gtint_t m = std::get<4>(GetParam()); // m + gtint_t n = std::get<5>(GetParam()); // n + cgemm_ukr_ft kern_ptr = std::get<6>(GetParam()); // pointer to the gemm kernel + bool is_memory_test = std::get<7>(GetParam()); // is_memory_test + + // Set the threshold for the errors: + // Check gtestsuite gemm.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (m == 0 || n == 0) + thresh = 0.0; + else if ((alpha == testinghelpers::ZERO() || k == 0) && (beta == testinghelpers::ZERO() || + beta == testinghelpers::ONE())) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + thresh = testinghelpers::getEpsilon(); + else + thresh = (3*k+1)*testinghelpers::getEpsilon(); + + test_gemmnat_ukr(storageC, m, n, k, alpha, beta, thresh, kern_ptr, is_memory_test); +}// end of function + +class cgemmGenericNatPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + + gtint_t k = std::get<0>(str.param); + scomplex alpha = std::get<1>(str.param); + scomplex beta = std::get<2>(str.param); + char storageC = std::get<3>(str.param); + bool is_memory_test = std::get<7>(str.param); + + std::string str_name ; + str_name += "_stor_" + std::string(&storageC, 1); + str_name += "_k_" + std::to_string(k); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + str_name += "_beta_" + testinghelpers::get_value_string(beta); + str_name += ( is_memory_test ) ? "_mem_test_enabled" : "_mem_test_disabled"; + + return str_name; + } +}; + +#if defined(BLIS_KERNELS_HASWELL) && defined(GTEST_AVX2FMA3) +INSTANTIATE_TEST_SUITE_P ( + bli_cgemm_haswell_asm_3x8, + cgemmGenericNat, + ::testing::Combine( + ::testing::Range(gtint_t(1), gtint_t(20), 1), // values of k + ::testing::Values(scomplex{0.0, 0.0}, scomplex{1.0, 0.0}, scomplex{-1.0, 0.0}, scomplex{4.0, 0.0}, scomplex{0.0, -0.2}, scomplex{3.5, 4.5}), // alpha value + ::testing::Values(scomplex{0.0, 0.0}, scomplex{1.0, 0.0}, scomplex{-1.0, 0.0}, scomplex{-5.0, 0.0}, scomplex{0.0, -2.1}, scomplex{-7.3, 6.7}), // beta value + ::testing::Values('r', 'c'), // storage + ::testing::Values(3), // values of m + ::testing::Values(8), // values of n + ::testing::Values(bli_cgemm_haswell_asm_3x8), // cgemm_nat kernel + ::testing::Values(false, true) // is_memory_test + ), + ::cgemmGenericNatPrint() +); +#endif diff --git a/gtestsuite/testsuite/ukr/gemm/dgemm/dgemm_ukernel.cpp b/gtestsuite/testsuite/ukr/gemm/dgemm/dgemm_ukernel.cpp new file mode 100644 index 0000000000..4908e08ea3 --- /dev/null +++ b/gtestsuite/testsuite/ukr/gemm/dgemm/dgemm_ukernel.cpp @@ -0,0 +1,806 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "blis.h" +#include "common/testing_helpers.h" +#include "ukr/gemm/test_gemm_ukr.h" + +/*******************************************************/ +/* SUP Kernel testing */ +/*******************************************************/ +class dgemmGenericSUP : + public ::testing::TestWithParam> {}; + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(dgemmGenericSUP); + +TEST_P( dgemmGenericSUP, sup_kernel) +{ + using T = double; + gtint_t m = std::get<0>(GetParam()); // dimension m + gtint_t n = std::get<1>(GetParam()); // dimension n + gtint_t k = std::get<2>(GetParam()); // dimension k + T alpha = std::get<3>(GetParam()); // alpha + T beta = std::get<4>(GetParam()); // beta + char storageC = std::get<5>(GetParam()); // storage scheme for C matrix + dgemmsup_ker_ft kern_ptr = std::get<6>(GetParam()); // pointer to the gemm kernel + gtint_t MR = std::get<7>(GetParam()); // Micro-kernel tile size + char transa = std::get<8>(GetParam()); // transa + char transb = std::get<9>(GetParam()); // transb + bool row_pref = std::get<10>(GetParam()); // kernel transpose + bool is_memory_test = std::get<11>(GetParam()); // memory test + + // Set the threshold for the errors: + // Check gtestsuite gemm.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (m == 0 || n == 0) + thresh = 0.0; + else if ((alpha == testinghelpers::ZERO() || k == 0) && (beta == testinghelpers::ZERO() || + beta == testinghelpers::ONE())) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + thresh = testinghelpers::getEpsilon(); + else + thresh = (3*k+1)*testinghelpers::getEpsilon(); + + test_gemmsup_ukr(kern_ptr, transa, transb, m, n, k, alpha, beta, storageC, MR, row_pref, thresh, is_memory_test); + +}// end of function + + +class dgemmGenericSUPPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + + gtint_t m = std::get<0>(str.param); + gtint_t n = std::get<1>(str.param); + gtint_t k = std::get<2>(str.param); + double alpha = std::get<3>(str.param); + double beta = std::get<4>(str.param); + char storageC = std::get<5>(str.param); + char transa = std::get<8>(str.param); + char transb = std::get<9>(str.param); + bool is_memory_test = std::get<11>(str.param); + + std::string str_name; + str_name += "_stor_" + std::string(&storageC, 1); + str_name += "_transa_" + std::string(&transa, 1); + str_name += "_transb_" + std::string(&transb, 1); + str_name += "_m_" + std::to_string(m); + str_name += "_n_" + std::to_string(n); + str_name += "_k_" + std::to_string(k); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + str_name += "_beta_" + testinghelpers::get_value_string(beta); + str_name += ( is_memory_test ) ? "_mem_test_enabled" : "_mem_test_disabled"; + + return str_name; + } +}; + +#if defined(BLIS_KERNELS_HASWELL) && defined(GTEST_AVX2FMA3) + +INSTANTIATE_TEST_SUITE_P ( + bli_dgemmsup_rv_haswell_asm_6x8m_row_stored_c, + dgemmGenericSUP, + ::testing::Combine( + ::testing::Range(gtint_t(1), gtint_t(7), 1), // values of m + ::testing::Range(gtint_t(1), gtint_t(9), 1), // values of n + ::testing::Range(gtint_t(0), gtint_t(17), 1), // values of k + ::testing::Values(2.0, 1.0, -1.0), // alpha value + ::testing::Values(1.0, 0.0, -1.0, 2.3), // beta value + ::testing::Values('r'), // storage of c + ::testing::Values(bli_dgemmsup_rv_haswell_asm_6x8m), // dgemm_sup kernel + ::testing::Values(gtint_t(6)), // Micro kernel block MR + ::testing::Values('t'), // transa + ::testing::Values('n'), // transb + ::testing::Values(true), // row preferred kernel? + ::testing::Values(true, false) // memory test + ), + ::dgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_dgemmsup_rv_haswell_asm_6x8m_col_stored_c, + dgemmGenericSUP, + ::testing::Combine( + ::testing::Range(gtint_t(1), gtint_t(7), 1), // values of m + ::testing::Range(gtint_t(1), gtint_t(9), 1), // values of n + ::testing::Range(gtint_t(0), gtint_t(17), 1), // values of k + ::testing::Values(2.0, 1.0, -1.0), // alpha value + ::testing::Values(1.0, 0.0, -1.0, 2.3), // beta value + ::testing::Values('c'), // storage of c + ::testing::Values(bli_dgemmsup_rv_haswell_asm_6x8m), // dgemm_sup kernel + ::testing::Values(gtint_t(6)), // Micro kernel block MR + ::testing::Values('n'), // transa + ::testing::Values('t'), // transb + ::testing::Values(true), // row preferred kernel? + ::testing::Values(true, false) // memory test + ), + ::dgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_dgemmsup_rd_haswell_asm_6x8m_col_stored_c, + dgemmGenericSUP, + ::testing::Combine( + ::testing::Range(gtint_t(1), gtint_t(7), 1), // values of m + ::testing::Range(gtint_t(1), gtint_t(9), 1), // values of n + ::testing::Range(gtint_t(0), gtint_t(17), 1), // values of k + ::testing::Values(2.0, 1.0, -1.0), // alpha value + ::testing::Values(1.0, 0.0, -1.0, 2.3), // beta value + ::testing::Values('c'), // storage of c + ::testing::Values(bli_dgemmsup_rd_haswell_asm_6x8m), // dgemm_sup kernel + ::testing::Values(gtint_t(6)), // Micro kernel block MR + ::testing::Values('t'), // transa + ::testing::Values('n'), // transb + ::testing::Values(true), // row preferred kernel? + ::testing::Values(true, false) // memory test + ), + ::dgemmGenericSUPPrint() + ); + + +INSTANTIATE_TEST_SUITE_P ( + bli_dgemmsup_rv_haswell_asm_6x8n_col_stored_c, + dgemmGenericSUP, + ::testing::Combine( + ::testing::Range(gtint_t(1), gtint_t(7), 1), // values of m + ::testing::Range(gtint_t(1), gtint_t(9), 1), // values of n + ::testing::Range(gtint_t(0), gtint_t(17), 1), // values of k + ::testing::Values(2.0, 1.0, -1.0), // alpha value + ::testing::Values(1.0, 0.0, -1.0, 2.3), // beta value + ::testing::Values('c'), // storage of c + ::testing::Values(bli_dgemmsup_rv_haswell_asm_6x8n), // dgemm_sup kernel + ::testing::Values(gtint_t(6)), // Micro kernel block MR + ::testing::Values('n'), // transa + ::testing::Values('n'), // transb + ::testing::Values(true), // row preferred kernel? + ::testing::Values(true, false) // memory test + ), + ::dgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_dgemmsup_rv_haswell_asm_6x8n_row_stored_c, + dgemmGenericSUP, + ::testing::Combine( + ::testing::Range(gtint_t(1), gtint_t(7), 1), // values of m + ::testing::Range(gtint_t(1), gtint_t(9), 1), // values of n + ::testing::Range(gtint_t(0), gtint_t(17), 1), // values of k + ::testing::Values(2.0, 1.0, -1.0), // alpha value + ::testing::Values(1.0, 0.0, -1.0, 2.3), // beta value + ::testing::Values('r'), // storage of c + ::testing::Values(bli_dgemmsup_rv_haswell_asm_6x8n), // dgemm_sup kernel + ::testing::Values(gtint_t(6)), // Micro kernel block MR + ::testing::Values('t'), // transa + ::testing::Values('n'), // transb + ::testing::Values(true), // row preferred kernel? + ::testing::Values(true, false) // memory test + ), + ::dgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_dgemmsup_rd_haswell_asm_6x8n_col_stored_c, + dgemmGenericSUP, + ::testing::Combine( + ::testing::Range(gtint_t(1), gtint_t(7), 1), // values of m + ::testing::Range(gtint_t(1), gtint_t(9), 1), // values of n + ::testing::Range(gtint_t(0), gtint_t(17), 1), // values of k + ::testing::Values(2.0, 1.0, -1.0), // alpha value + ::testing::Values(1.0, 0.0, -1.0, 2.3), // beta value + ::testing::Values('c'), // storage of c + ::testing::Values(bli_dgemmsup_rd_haswell_asm_6x8n), // dgemm_sup kernel + ::testing::Values(gtint_t(6)), // Micro kernel block MR + ::testing::Values('t'), // transa + ::testing::Values('n'), // transb + ::testing::Values(true), // row preferred kernel? + ::testing::Values(true, false) // memory test + ), + ::dgemmGenericSUPPrint() + ); +#endif + +#if defined(BLIS_KERNELS_ZEN4) && defined(GTEST_AVX512) + + INSTANTIATE_TEST_SUITE_P ( + bli_dgemmsup_rv_zen4_asm_24x8m_col_stored_c, + dgemmGenericSUP, + ::testing::Combine( + ::testing::Range(gtint_t(1), gtint_t(25), 1), // values of m + ::testing::Range(gtint_t(1), gtint_t(9), 1), // values of n + ::testing::Range(gtint_t(0), gtint_t(25), 1), // values of k + ::testing::Values(2.0, 1.0, -1.0), // alpha value + ::testing::Values(1.0, 0.0, -1.0, 2.3), // beta value + ::testing::Values('c'), // storage of c + ::testing::Values(bli_dgemmsup_rv_zen4_asm_24x8m), // dgemm_sup kernel + ::testing::Values(gtint_t(8)), // Micro kernel block MR + ::testing::Values('n'), // transa + ::testing::Values('n'), // transb + ::testing::Values(false), // row preferred kernel? + ::testing::Values(true, false) // memory test + ), + ::dgemmGenericSUPPrint() + ); + + INSTANTIATE_TEST_SUITE_P ( + bli_dgemmsup_rv_zen4_asm_24x8m_row_stored_c, + dgemmGenericSUP, + ::testing::Combine( + ::testing::Range(gtint_t(1), gtint_t(25), 1), // values of m + ::testing::Range(gtint_t(1), gtint_t(9), 1), // values of n + ::testing::Range(gtint_t(0), gtint_t(25), 1), // values of k + ::testing::Values(2.0, 1.0, -1.0), // alpha value + ::testing::Values(1.0, 0.0, -1.0, 2.3), // beta value + ::testing::Values('r'), // storage of c + ::testing::Values(bli_dgemmsup_rv_zen4_asm_24x8m), // dgemm_sup kernel + ::testing::Values(gtint_t(8)), // Micro kernel block MR + ::testing::Values('t'), // transa + ::testing::Values('n'), // transb + ::testing::Values(false), // row preferred kernel? + ::testing::Values(true, false) // memory test + ), + ::dgemmGenericSUPPrint() + ); +#endif + +#if defined(BLIS_KERNELS_ZEN5) && defined(GTEST_AVX512) + +INSTANTIATE_TEST_SUITE_P ( + bli_dgemmsup_rv_zen5_asm_24x8m_col_stored_c, + dgemmGenericSUP, + ::testing::Combine( + ::testing::Range(gtint_t(1), gtint_t(25), 1), // values of m + ::testing::Range(gtint_t(1), gtint_t(9), 1), // values of n + ::testing::Range(gtint_t(0), gtint_t(25), 1), // values of k + ::testing::Values(2.0, 1.0, -1.0), // alpha value + ::testing::Values(1.0, 0.0, -1.0, 2.3), // beta value + ::testing::Values('c'), // storage of c + ::testing::Values(bli_dgemmsup_rv_zen5_asm_24x8m), // dgemm_sup kernel + ::testing::Values(gtint_t(8)), // Micro kernel block MR + ::testing::Values('n'), // transa + ::testing::Values('n'), // transb + ::testing::Values(false), // row preferred kernel? + ::testing::Values(true, false) // memory test + ), + ::dgemmGenericSUPPrint() + ); + + INSTANTIATE_TEST_SUITE_P ( + bli_dgemmsup_rv_zen5_asm_24x8m_row_stored_c, + dgemmGenericSUP, + ::testing::Combine( + ::testing::Range(gtint_t(1), gtint_t(25), 1), // values of m + ::testing::Range(gtint_t(1), gtint_t(9), 1), // values of n + ::testing::Range(gtint_t(0), gtint_t(25), 1), // values of k + ::testing::Values(2.0, 1.0, -1.0), // alpha value + ::testing::Values(1.0, 0.0, -1.0, 2.3), // beta value + ::testing::Values('r'), // storage of c + ::testing::Values(bli_dgemmsup_rv_zen5_asm_24x8m), // dgemm_sup kernel + ::testing::Values(gtint_t(8)), // Micro kernel block MR + ::testing::Values('t'), // transa + ::testing::Values('n'), // transb + ::testing::Values(false), // row preferred kernel? + ::testing::Values(true, false) // memory test + ), + ::dgemmGenericSUPPrint() + ); + +#endif + +/*******************************************************/ +/* Native Kernel testing */ +/*******************************************************/ +class dgemmGenericNat : +// public ::testing::TestWithParam> {}; +// k, alpha, beta, storage of c, m, n, dgemm native kernel, memory test + + public ::testing::TestWithParam> {}; + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(dgemmGenericNat); + +TEST_P( dgemmGenericNat, native_kernel_testing) +{ + using T = double; + gtint_t k = std::get<0>(GetParam()); // dimension k + T alpha = std::get<1>(GetParam()); // alpha + T beta = std::get<2>(GetParam()); // beta + char storageC = std::get<3>(GetParam()); // indicates storage of all matrix operands + // Fix m and n to MR and NR respectively. + gtint_t m = std::get<4>(GetParam()); // m + gtint_t n = std::get<5>(GetParam()); // n + dgemm_ukr_ft kern_ptr = std::get<6>(GetParam()); // pointer to the gemm kernel + bool is_memory_test = std::get<7>(GetParam()); // is_memory_test + + // Set the threshold for the errors: + // Check gtestsuite gemm.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (m == 0 || n == 0) + thresh = 0.0; + else if ((alpha == testinghelpers::ZERO() || k == 0) && (beta == testinghelpers::ZERO() || + beta == testinghelpers::ONE())) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + thresh = testinghelpers::getEpsilon(); + else + thresh = (3*k+1)*testinghelpers::getEpsilon(); + + test_gemmnat_ukr(storageC, m, n, k, alpha, beta, kern_ptr, thresh, is_memory_test); + +}// end of function + + + +class dgemmGenericNatPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + gtint_t k = std::get<0>(str.param); + double alpha = std::get<1>(str.param); + double beta = std::get<2>(str.param); + char storageC = std::get<3>(str.param); + bool is_memory_test = std::get<7>(str.param); + + std::string str_name; + str_name += "_stor_" + std::string(&storageC, 1); + str_name += "_k_" + std::to_string(k); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha);; + str_name += "_beta_" + testinghelpers::get_value_string(beta);; + str_name += ( is_memory_test ) ? "_mem_test_enabled" : "_mem_test_disabled"; + + return str_name; + } +}; + +#if defined(BLIS_KERNELS_ZEN4) && defined(GTEST_AVX512) +INSTANTIATE_TEST_SUITE_P ( + bli_dgemm_zen4_asm_32x6, + dgemmGenericNat, + ::testing::Combine( + ::testing::Range(gtint_t(0), gtint_t(17), 1), // values of k + ::testing::Values(2.0, 1.0, -1.0), // alpha value + ::testing::Values(1.0, 0.0, -1.0, 2.3), // beta value + ::testing::Values('r', 'c'), // storage + ::testing::Values(32), // values of m + ::testing::Values(6), // values of n + ::testing::Values(bli_dgemm_zen4_asm_32x6), + ::testing::Values(true, false) // memory test + ), + ::dgemmGenericNatPrint() +); + +INSTANTIATE_TEST_SUITE_P ( + bli_dgemm_zen4_asm_8x24, + dgemmGenericNat, + ::testing::Combine( + ::testing::Range(gtint_t(0), gtint_t(17), 1), // values of k + ::testing::Values(2.0, 1.0, -1.0), // alpha value + ::testing::Values(1.0, 0.0, -1.0, 2.3), // beta value + ::testing::Values('r', 'c'), // storage + ::testing::Values(8), // values of m + ::testing::Values(24), // values of n + ::testing::Values(bli_dgemm_zen4_asm_8x24), + ::testing::Values(true, false) // memory test + ), + ::dgemmGenericNatPrint() +); +#endif + +#if defined(BLIS_KERNELS_ZEN) && defined(GTEST_AVX2FMA3) +INSTANTIATE_TEST_SUITE_P ( + bli_dgemm_haswell_asm_6x8, + dgemmGenericNat, + ::testing::Combine( + ::testing::Range(gtint_t(0), gtint_t(17), 1), // values of k + ::testing::Values(2.0, 1.0, -1.0), // alpha value + ::testing::Values(1.0, 0.0, -1.0, 2.3), // beta value + ::testing::Values('r', 'c'), // storage + ::testing::Values(6), // values of m + ::testing::Values(8), // values of n + ::testing::Values(bli_dgemm_haswell_asm_6x8), + ::testing::Values(true, false) // memory test + ), + ::dgemmGenericNatPrint() +); +#endif + +//Function pointer specific to dgemm kernel that handles +//special case where k=1. +typedef err_t (*gemm_k1_kernel) + ( + dim_t m, + dim_t n, + dim_t k, + double* alpha, + double* a, const inc_t lda, + double* b, const inc_t ldb, + double* beta, + double* c, const inc_t ldc + ); + +//Since AOCL BLAS is having separate kernel optimized to handle k=1 cases +//dgemm computation, a micro-kernel testing added that validates dgemm kernel +//for k=1 case. + +class dgemmGenericK1 : + public ::testing::TestWithParam> {}; +// k, alpha, beta, storage of c, m, n, dgemm k1 kernel, memory test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(dgemmGenericK1); + +TEST_P( dgemmGenericK1, k1_kernel_testing) +{ + using T = double; + gtint_t k = 1; + T alpha = std::get<0>(GetParam()); // alpha + T beta = std::get<1>(GetParam()); // beta + char storageC = std::get<2>(GetParam()); // indicates storage of all matrix operands + // Fix m and n to MR and NR respectively. + gtint_t m = std::get<3>(GetParam()); // dimension m + gtint_t n = std::get<4>(GetParam()); // dimension n + gemm_k1_kernel kern_ptr = std::get<5>(GetParam()); // Function pointer type for dgemm kernel + bool is_memory_test = std::get<6>(GetParam()); // is_memory_test + + // Set the threshold for the errors: + // Check gtestsuite gemm.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (m == 0 || n == 0) + thresh = 0.0; + else if ((alpha == testinghelpers::ZERO() || k == 0) && (beta == testinghelpers::ZERO() || + beta == testinghelpers::ONE())) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + thresh = testinghelpers::getEpsilon(); + else + thresh = (3*k+1)*testinghelpers::getEpsilon(); + + test_gemmk1_ukr(kern_ptr, m, n, k, storageC, alpha, beta, thresh, is_memory_test); + +}// end of function + + + +class dgemmGenericK1Print { +public: + std::string operator()( + testing::TestParamInfo> str) const { + gtint_t k = 1; + double alpha = std::get<0>(str.param); + double beta = std::get<1>(str.param); + char storageC = std::get<2>(str.param); + gtint_t m = std::get<3>(str.param); + gtint_t n = std::get<4>(str.param); + bool is_memory_test = std::get<6>(str.param); + + std::string str_name; + str_name += "_stor_" + std::string(&storageC, 1); + str_name += "_m_" + std::to_string(m); + str_name += "_n_" + std::to_string(n); + str_name += "_k_" + std::to_string(k); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + str_name += "_beta_" + testinghelpers::get_value_string(beta); + str_name += ( is_memory_test ) ? "_mem_test_enabled" : "_mem_test_disabled"; + + return str_name; + } +}; + + +#if defined(BLIS_KERNELS_ZEN4) && defined(GTEST_AVX512) +INSTANTIATE_TEST_SUITE_P ( + bli_dgemm_24x8_avx512_k1_nn, + dgemmGenericK1, + ::testing::Combine( + + ::testing::Values(2.0, 1.0, -1.0), // alpha value + ::testing::Values(1.0, 0.0, -1.0, 2.3), // beta value + ::testing::Values('c'), // storage + ::testing::Range(gtint_t(1), gtint_t(25), 1), // values of m + ::testing::Range(gtint_t(1), gtint_t(9), 1), // values of n + ::testing::Values(bli_dgemm_24x8_avx512_k1_nn), + ::testing::Values(true, false) // memory test + ), + ::dgemmGenericK1Print() +); + +#endif + +#if defined(BLIS_KERNELS_ZEN) && defined(GTEST_AVX2FMA3) +INSTANTIATE_TEST_SUITE_P ( + bli_dgemm_8x6_avx2_k1_nn, + dgemmGenericK1, + ::testing::Combine( + ::testing::Values(2.0, 1.0, -1.0), // alpha value + ::testing::Values(1.0, 0.0, -1.0, 2.3), // beta value + ::testing::Values('c'), // storage + ::testing::Range(gtint_t(1), gtint_t(9), 1), // values of m + ::testing::Range(gtint_t(1), gtint_t(7), 1), // values of n + ::testing::Values(bli_dgemm_8x6_avx2_k1_nn), + ::testing::Values(true, false) // memory test + ), + ::dgemmGenericK1Print() +); +#endif + +#ifdef BLIS_ENABLE_SMALL_MATRIX + +#if defined(BLIS_KERNELS_ZEN) && defined(GTEST_AVX2FMA3) +class dgemmGenericSmall : + public ::testing::TestWithParam> {}; + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(dgemmGenericSmall); + +TEST_P( dgemmGenericSmall, gemm_small) +{ + using T = double; + gtint_t m = std::get<0>(GetParam()); // dimension m + gtint_t n = std::get<1>(GetParam()); // dimension n + gtint_t k = std::get<2>(GetParam()); // dimension k + T alpha = std::get<3>(GetParam()); // alpha + T beta = std::get<4>(GetParam()); // beta + char storageC = std::get<5>(GetParam()); // indicates storage of all matrix operands + bool is_memory_test = std::get<6>(GetParam()); // memory test enable or disable + + + gtint_t lda = testinghelpers::get_leading_dimension( storageC, 'n', m, k, 0 ); + gtint_t ldb = testinghelpers::get_leading_dimension( storageC, 'n', k, n, 0 ); + gtint_t ldc = testinghelpers::get_leading_dimension( storageC, 'n', m, n, 0 ); + + const num_t dt = BLIS_DOUBLE; + + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t ao = BLIS_OBJECT_INITIALIZER; + obj_t bo = BLIS_OBJECT_INITIALIZER; + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t co = BLIS_OBJECT_INITIALIZER; + + dim_t m0_a, n0_a; + dim_t m0_b, n0_b; + + bli_set_dims_with_trans(BLIS_NO_TRANSPOSE, m, k, &m0_a, &n0_a); + bli_set_dims_with_trans(BLIS_NO_TRANSPOSE, k, n, &m0_b, &n0_b); + + bli_obj_init_finish_1x1(dt, (double*)&alpha, &alphao); + bli_obj_init_finish_1x1(dt, (double*)&beta, &betao); + + // Set the threshold for the errors: + // Check gtestsuite gemm.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (m == 0 || n == 0) + thresh = 0.0; + else if ((alpha == testinghelpers::ZERO() || k == 0) && (beta == testinghelpers::ZERO() || + beta == testinghelpers::ONE())) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + thresh = testinghelpers::getEpsilon(); + else + thresh = (3*k+1)*testinghelpers::getEpsilon(); + + if ( is_memory_test ) + { + srand(time(NULL)); + double *a, *b, *c, *cref = NULL; + // Allocate memory for A + testinghelpers::ProtectedBuffer a_buf( m * k * lda * sizeof(double), false, is_memory_test ); + // Allocate memory for B + testinghelpers::ProtectedBuffer b_buf( k * n * ldb * sizeof(double), false, is_memory_test ); + testinghelpers::ProtectedBuffer c_buf( m * n * ldc * sizeof(double), false, is_memory_test ); + + a = (double*)a_buf.greenzone_1; + b = (double*)b_buf.greenzone_1; + c = (double*)c_buf.greenzone_1; + + cref = (double*)malloc(m * n * ldc * sizeof(double)); + + testinghelpers::datagenerators::randomgenerators( -2, 8, 'c', m, k, (a), 'n', lda); + memset(b, rand() % 5, n*k*ldb*sizeof(double)); + memset(cref, rand() % 3, m*n*ldc*sizeof(double)); + memcpy(c, cref, m*n*ldc*sizeof(double)); + + bli_obj_init_finish(dt, m, k, (double*)a, 1, lda, &ao); + bli_obj_init_finish(dt, k, n, (double*)b, 1, ldb, &bo); + bli_obj_init_finish(dt, m, n, (double*)c, 1, ldc, &co); + + bli_obj_set_conjtrans(BLIS_NO_TRANSPOSE, &ao); + bli_obj_set_conjtrans(BLIS_NO_TRANSPOSE, &bo); + + // add signal handler for segmentation fault + testinghelpers::ProtectedBuffer::start_signal_handler(); + try + { + bli_dgemm_small ( &alphao, + &ao, + &bo, + &betao, + &co, + NULL, + NULL + ); + + if ( is_memory_test ) + { + a = (double*)a_buf.greenzone_2; + b = (double*)b_buf.greenzone_2; + c = (double*)c_buf.greenzone_2; + + memcpy(a, a_buf.greenzone_1, m * k * lda * sizeof(double)); + memcpy(b, b_buf.greenzone_1, n * k * ldb * sizeof(double)); + memcpy(c, cref, m * n * ldc * sizeof(double)); + + bli_dgemm_small ( &alphao, + &ao, + &bo, + &betao, + &co, + NULL, + NULL + ); + } + } + catch(const std::exception& e) + { + // reset to default signal handler + testinghelpers::ProtectedBuffer::stop_signal_handler(); + + // show failure in case seg fault was detected + FAIL() << "Memory Test Failed"; + } + // reset to default signal handler + testinghelpers::ProtectedBuffer::stop_signal_handler(); + + // call reference implementation + testinghelpers::ref_gemm( storageC, 'n', 'n', m, n, k, alpha, + a, lda, b, ldb, beta, cref, ldc); + // Check component-wise error + computediff( "C", storageC, m, n, c, cref, ldc, thresh ); + + free(cref); + } + else + { + //---------------------------------------------------------- + // Initialize matrics with random numbers + //---------------------------------------------------------- + std::vector a = testinghelpers::get_random_matrix( -2, 8, storageC, 'n', m, k, lda ); + std::vector b = testinghelpers::get_random_matrix( -5, 2, storageC, 'n', k, n, ldb ); + std::vector c = testinghelpers::get_random_matrix( -3, 5, storageC, 'n', m, n, ldc ); + + std::vector c_ref(c); + + bli_obj_init_finish(dt, m0_a, n0_a, (double*)a.data(), 1, lda, &ao); + bli_obj_init_finish(dt, m0_b, n0_b, (double*)b.data(), 1, ldb, &bo); + bli_obj_init_finish(dt, m, n, (double*)c.data(), 1, ldc, &co); + + bli_obj_set_conjtrans(BLIS_NO_TRANSPOSE, &ao); + bli_obj_set_conjtrans(BLIS_NO_TRANSPOSE, &bo); + + bli_dgemm_small ( &alphao, + &ao, + &bo, + &betao, + &co, + NULL, + NULL + ); + + // call reference implementation + testinghelpers::ref_gemm( storageC, 'n', 'n', m, n, k, alpha, + a.data(), lda, b.data(), ldb, beta, c_ref.data(), ldc); + // Check component-wise error + computediff( "C", storageC, m, n, c.data(), c_ref.data(), ldc, thresh ); + } + +}// end of function + + + +class dgemmGenericSmallPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + gtint_t m = std::get<0>(str.param); + gtint_t n = std::get<1>(str.param); + gtint_t k = std::get<2>(str.param); + double alpha = std::get<3>(str.param); + double beta = std::get<4>(str.param); + char storageC = std::get<5>(str.param); + bool is_memory_test = std::get<6>(str.param); + + std::string str_name; + str_name += "_stor_" + std::string(&storageC, 1); + str_name += "_m_" + std::to_string(m); + str_name += "_n_" + std::to_string(n); + str_name += "_k_" + std::to_string(k); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + str_name += "_beta_" + testinghelpers::get_value_string(beta); + str_name += ( is_memory_test ) ? "_mem_test_enabled" : "_mem_test_disabled"; + + return str_name; + } +}; + +INSTANTIATE_TEST_SUITE_P ( + bli_dgemm_small, + dgemmGenericSmall, + ::testing::Combine( + ::testing::Range(gtint_t(1), gtint_t(21), 1), // values of m + ::testing::Range(gtint_t(1), gtint_t(11), 1), // values of n + ::testing::Range(gtint_t(1), gtint_t(20), 1), // values of k + ::testing::Values(2.0, 1.0, -1.0), // alpha value + ::testing::Values(1.0, 0.0, -1.0, 2.3), // beta value + ::testing::Values('c'), // storage + ::testing::Values(true, false) // memory test + ), + ::dgemmGenericSmallPrint() + ); +#endif + +#endif diff --git a/gtestsuite/testsuite/ukr/gemm/sgemm/sgemm_ukernel.cpp b/gtestsuite/testsuite/ukr/gemm/sgemm/sgemm_ukernel.cpp new file mode 100644 index 0000000000..4439aa64b0 --- /dev/null +++ b/gtestsuite/testsuite/ukr/gemm/sgemm/sgemm_ukernel.cpp @@ -0,0 +1,675 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "blis.h" +#include "common/testing_helpers.h" +#include "ukr/gemm/test_gemm_ukr.h" + +/*******************************************************/ +/* SUP Kernel testing */ +/*******************************************************/ +class sgemmGenericSUP : + public ::testing::TestWithParam> {}; + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(sgemmGenericSUP); + +TEST_P( sgemmGenericSUP, functionality_testing) +{ + using T = float; + gtint_t m = std::get<0>(GetParam()); // dimension m + gtint_t n = std::get<1>(GetParam()); // dimension n + gtint_t k = std::get<2>(GetParam()); // dimension k + T alpha = std::get<3>(GetParam()); // alpha + T beta = std::get<4>(GetParam()); // beta + char storageC = std::get<5>(GetParam()); // storage scheme for C matrix + sgemmsup_ker_ft kern_ptr = std::get<6>(GetParam()); // pointer to the gemm kernel + gtint_t MR = std::get<7>(GetParam()); // Micro-kernel tile size + char transa = std::get<8>(GetParam()); // transa + char transb = std::get<9>(GetParam()); // transb + bool row_pref = std::get<10>(GetParam()); // kernel transpose + bool is_memory_test = std::get<11>(GetParam()); // memory test + + // Set the threshold for the errors: + // Check gtestsuite gemm.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (m == 0 || n == 0) + thresh = 0.0; + else if ((alpha == testinghelpers::ZERO() || k == 0) && (beta == testinghelpers::ZERO() || + beta == testinghelpers::ONE())) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + thresh = testinghelpers::getEpsilon(); + else + thresh = (3*k+1)*testinghelpers::getEpsilon(); + + test_gemmsup_ukr(kern_ptr, transa, transb, m, n, k, alpha, beta, storageC, MR, row_pref, thresh, is_memory_test); + +}// end of function + + +class sgemmGenericSUPPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + + gtint_t m = std::get<0>(str.param); + gtint_t n = std::get<1>(str.param); + gtint_t k = std::get<2>(str.param); + float alpha = std::get<3>(str.param); + float beta = std::get<4>(str.param); + char storageC = std::get<5>(str.param); + char transa = std::get<8>(str.param); + char transb = std::get<9>(str.param); + bool is_memory_test = std::get<11>(str.param); + + std::string str_name; + str_name += "_stor_" + std::string(&storageC, 1); + str_name += "_transa_" + std::string(&transa, 1); + str_name += "_transb_" + std::string(&transb, 1); + str_name += "_m_" + std::to_string(m); + str_name += "_n_" + std::to_string(n); + str_name += "_k_" + std::to_string(k); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + str_name += "_beta_" + testinghelpers::get_value_string(beta); + str_name += ( is_memory_test ) ? "_mem_test_enabled" : "_mem_test_disabled"; + + return str_name; + } +}; + +#if defined(BLIS_KERNELS_ZEN) && defined(GTEST_AVX2FMA3) + +INSTANTIATE_TEST_SUITE_P ( + bli_sgemmsup_rv_zen_asm_6x16m_row_stored_c, + sgemmGenericSUP, + ::testing::Combine( + ::testing::Range(gtint_t(1), gtint_t(7), 1), // values of m + ::testing::Range(gtint_t(1), gtint_t(17), 1), // values of n + ::testing::Range(gtint_t(0), gtint_t(17), 1), // values of k + ::testing::Values(2.0, 1.0, -1.0), // alpha value + ::testing::Values(1.0, 0.0, -1.0, 2.3), // beta value + ::testing::Values('r'), // storage of c + ::testing::Values(bli_sgemmsup_rv_zen_asm_6x16m), // sgemm_sup kernel + ::testing::Values(gtint_t(6)), // Micro kernel block MR + ::testing::Values('t'), // transa + ::testing::Values('n'), // transb + ::testing::Values(true), // kernel pref + ::testing::Values(true, false) // memory test + ), + ::sgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_sgemmsup_rv_zen_asm_6x16m_col_stored_c, + sgemmGenericSUP, + ::testing::Combine( + ::testing::Range(gtint_t(1), gtint_t(7), 1), // values of m + ::testing::Range(gtint_t(1), gtint_t(17), 1), // values of n + ::testing::Range(gtint_t(1), gtint_t(17), 1), // values of k + ::testing::Values(2.0, 1.0, -1.0), // alpha value + ::testing::Values(1.0, 0.0, -1.0, 2.3), // beta value + ::testing::Values('c'), // storage of c + ::testing::Values(bli_sgemmsup_rv_zen_asm_6x16m), // sgemm_sup kernel + ::testing::Values(gtint_t(6)), // Micro kernel block MR + ::testing::Values('n'), // transa + ::testing::Values('t'), // transb + ::testing::Values(true), // kernel pref + ::testing::Values(true, false) // memory test + ), + ::sgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_sgemmsup_rd_zen_asm_6x16m_col_stored_c, + sgemmGenericSUP, + ::testing::Combine( + ::testing::Range(gtint_t(1), gtint_t(7), 1), // values of m + ::testing::Range(gtint_t(1), gtint_t(17), 1), // values of n + ::testing::Range(gtint_t(0), gtint_t(17), 1), // values of k + ::testing::Values(2.0, 1.0, -1.0), // alpha value + ::testing::Values(1.0, 0.0, -1.0, 2.3), // beta value + ::testing::Values('c'), // storage of c + ::testing::Values(bli_sgemmsup_rd_zen_asm_6x16m), // sgemm_sup kernel + ::testing::Values(gtint_t(6)), // Micro kernel block MR + ::testing::Values('t'), // transa + ::testing::Values('n'), // transb + ::testing::Values(true), // kernel pref + ::testing::Values(true, false) // memory test + ), + ::sgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_sgemmsup_rv_zen_asm_6x16n_col_stored_c, + sgemmGenericSUP, + ::testing::Combine( + ::testing::Range(gtint_t(1), gtint_t(7), 1), // values of m + ::testing::Range(gtint_t(1), gtint_t(17), 1), // values of n + ::testing::Range(gtint_t(0), gtint_t(17), 1), // values of k + ::testing::Values(2.0, 1.0, -1.0), // alpha value + ::testing::Values(1.0, 0.0, -1.0, 2.3), // beta value + ::testing::Values('c'), // storage of c + ::testing::Values(bli_sgemmsup_rv_zen_asm_6x16n), // sgemm_sup kernel + ::testing::Values(gtint_t(6)), // Micro kernel block MR + ::testing::Values('n'), // transa + ::testing::Values('t'), // transb + ::testing::Values(false), // kernel pref + ::testing::Values(true, false) // memory test + ), + ::sgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_sgemmsup_rv_zen_asm_6x16n_row_stored_c, + sgemmGenericSUP, + ::testing::Combine( + ::testing::Range(gtint_t(1), gtint_t(7), 1), // values of m + ::testing::Range(gtint_t(1), gtint_t(17), 1), // values of n + ::testing::Range(gtint_t(0), gtint_t(17), 1), // values of k + ::testing::Values(2.0, 1.0, -1.0), // alpha value + ::testing::Values(1.0, 0.0, -1.0, 2.3), // beta value + ::testing::Values('r'), // storage of c + ::testing::Values(bli_sgemmsup_rv_zen_asm_6x16n), // sgemm_sup kernel + ::testing::Values(gtint_t(6)), // Micro kernel block MR + ::testing::Values('t'), // transa + ::testing::Values('n'), // transb + ::testing::Values(true), // kernel pref + ::testing::Values(true, false) // memory test + ), + ::sgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_sgemmsup_rd_zen_asm_6x16n_row_stored_c, + sgemmGenericSUP, + ::testing::Combine( + ::testing::Range(gtint_t(1), gtint_t(7), 1), // values of m + ::testing::Range(gtint_t(1), gtint_t(17), 1), // values of n + ::testing::Range(gtint_t(0), gtint_t(17), 1), // values of k + ::testing::Values(2.0, 1.0, -1.0), // alpha value + ::testing::Values(1.0, 0.0, -1.0, 2.3), // beta value + ::testing::Values('r'), // storage of c + ::testing::Values(bli_sgemmsup_rd_zen_asm_6x16n), // sgemm_sup kernel + ::testing::Values(gtint_t(6)), // Micro kernel block MR + ::testing::Values('n'), // transa + ::testing::Values('t'), // transb + ::testing::Values(false), // kernel pref + ::testing::Values(true, false) // memory test + ), + ::sgemmGenericSUPPrint() + ); + +#endif + +#if defined(BLIS_KERNELS_ZEN4) && defined(GTEST_AVX512) +INSTANTIATE_TEST_SUITE_P ( + bli_sgemmsup_rv_zen_asm_6x64m_row_stored_c, + sgemmGenericSUP, + ::testing::Combine( + ::testing::Range(gtint_t(1), gtint_t(7), 1), // values of m + ::testing::Range(gtint_t(1), gtint_t(65), 1), // values of n + ::testing::Range(gtint_t(0), gtint_t(17), 1), // values of k + ::testing::Values(2.0, 1.0, -1.0), // alpha value + ::testing::Values(1.0, 0.0, -1.0, 2.3), // beta value + ::testing::Values('r'), // storage of c + ::testing::Values(bli_sgemmsup_rv_zen_asm_6x64m_avx512), // sgemm_sup kernel + ::testing::Values(gtint_t(6)), // Micro kernel block MR + ::testing::Values('t'), // transa + ::testing::Values('n'), // transb + ::testing::Values(true), // kernel pref + ::testing::Values(true, false) // memory test + ), + ::sgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_sgemmsup_rv_zen_asm_6x64m_col_stored_c, + sgemmGenericSUP, + ::testing::Combine( + ::testing::Range(gtint_t(1), gtint_t(7), 1), // values of m + ::testing::Range(gtint_t(1), gtint_t(65), 1), // values of n + ::testing::Range(gtint_t(1), gtint_t(17), 1), // values of k + ::testing::Values(2.0, 1.0, -1.0), // alpha value + ::testing::Values(1.0, 0.0, -1.0, 2.3), // beta value + ::testing::Values('c'), // storage of c + ::testing::Values(bli_sgemmsup_rv_zen_asm_6x64m_avx512), // sgemm_sup_kernel + ::testing::Values(gtint_t(6)), // Micro kernel block MR + ::testing::Values('n'), // transa + ::testing::Values('t'), // transb + ::testing::Values(true), // kernel pref + ::testing::Values(true, false) // memory test + ), + ::sgemmGenericSUPPrint() + ); + +/* + The bli_sgemmsup_rd_zen_asm_6x64m_avx512(standalone), accepts inputs with the + following contingency for n. + n <= NR, where NR is 64 + The code structure for the sgemm_sup rd kernels(m-var) are as follows: + In m direction : + Main kernel : Blocks of 6(L6_M) + Fringe kernels : 5 ... 1(L5_M ... L1_M) + In k direction : + Main loop : Blocks of 64(L64_K) + Fringe loop : Blocks of 32, 8, 1(L32_K ... L1_K) + In n direction : + Main kernel : NR = 64(L64_N) + Fringe kernels : With n being 48, 32(AVX512 kernels)(L48_N, L32_N) + With n being 16, 8, 4, 2, 1(Reusing AVX2 kernels)(L16_N ... L1_N) + + The inherent storage scheme format for the kernel is RRC, for C, A and B. + The testing interface allows for testing row-storage(inherent) and col-storage(operation transpose) + of C. We still need to pass the right transpose value pair for A and B, as per the kernel requirement. +*/ + +// Checking with row storage of C +INSTANTIATE_TEST_SUITE_P ( + bli_sgemmsup_rd_zen_asm_6x64m_row_stored_c, + sgemmGenericSUP, + ::testing::Combine( + ::testing::Range(gtint_t(1), gtint_t(7), gtint_t(1)), // values of m(L6_M to L1_M) + ::testing::Values(gtint_t(64), // values of n, L64_N + gtint_t(48), // L48_N + gtint_t(32), // L32_N + gtint_t(8), // L8_N + gtint_t(7), // 7 * L1_N + gtint_t(63)), // Combination of fringe cases for N + ::testing::Values(gtint_t(64), // values of k, L64_K + gtint_t(32), // L32_K + gtint_t(8), // L8_K + gtint_t(7), // 7 * L1_K + gtint_t(256), // 4 * L64_K + gtint_t(303)), // Combination of main and fringe cases for K + ::testing::Values(2.0, 1.0, -1.0), // alpha value + ::testing::Values(1.0, 0.0, -1.0, 2.3), // beta value + ::testing::Values('r'), // storage of c + ::testing::Values(bli_sgemmsup_rd_zen_asm_6x64m_avx512), // sgemm_sup_kernel + ::testing::Values(gtint_t(6)), // Micro kernel block MR + ::testing::Values('n'), // transa, has to be N for row storage + ::testing::Values('t'), // transb, has to be T for row storage + ::testing::Values(true), // kernel pref + ::testing::Values(true, false) // memory test + ), + ::sgemmGenericSUPPrint() + ); + +// Checking with col storage of C +// NOTE : Since we are inducing transpose at opertaion level, for code coverage, we +// have to interchange m and n instantiations +INSTANTIATE_TEST_SUITE_P ( + bli_sgemmsup_rd_zen_asm_6x64m_col_stored_c, + sgemmGenericSUP, + ::testing::Combine( + ::testing::Values(gtint_t(64), // values of m, L64_N + gtint_t(48), // L48_N + gtint_t(32), // L32_N + gtint_t(8), // L8_N + gtint_t(7), // 7 * L1_N + gtint_t(63)), // Combination of fringe cases + ::testing::Range(gtint_t(1), gtint_t(7), gtint_t(1)), // values of n(L6_M to L1_M) + ::testing::Values(gtint_t(64), // values of k, L64_K + gtint_t(32), // L32_K + gtint_t(8), // L8_K + gtint_t(7), // 7 * L1_K + gtint_t(256), // 4 * L64_K + gtint_t(303)), // Combination of main and fringe cases for K + ::testing::Values(2.0, 1.0, -1.0), // alpha value + ::testing::Values(1.0, 0.0, -1.0, 2.3), // beta value + ::testing::Values('c'), // storage of c + ::testing::Values(bli_sgemmsup_rd_zen_asm_6x64m_avx512), // sgemm_sup_kernel + ::testing::Values(gtint_t(6)), // Micro kernel block MR + ::testing::Values('t'), // transa, has to be T for row storage + ::testing::Values('n'), // transb, has to be N for row storage + ::testing::Values(true), // kernel pref + ::testing::Values(true, false) // memory test + ), + ::sgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_sgemmsup_rv_zen_asm_6x64n_row_stored_c, + sgemmGenericSUP, + ::testing::Combine( + ::testing::Range(gtint_t(1), gtint_t(7), 1), // values of m + ::testing::Range(gtint_t(1), gtint_t(65), 1), // values of n + ::testing::Range(gtint_t(0), gtint_t(17), 1), // values of k + ::testing::Values(2.0, 1.0, -1.0), // alpha value + ::testing::Values(1.0, 0.0, -1.0, 2.3), // beta value + ::testing::Values('r'), // storage of c + ::testing::Values(bli_sgemmsup_rv_zen_asm_6x64n_avx512), // sgemm_sup_kernel + ::testing::Values(gtint_t(6)), // Micro kernel block MR + ::testing::Values('t'), // transa + ::testing::Values('n'), // transb + ::testing::Values(true), // kernel pref + ::testing::Values(true, false) // memory test + ), + ::sgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_sgemmsup_rd_zen_asm_6x64n_row_stored_c, + sgemmGenericSUP, + ::testing::Combine( + ::testing::Range(gtint_t(1), gtint_t(7), 1), // values of m + ::testing::Range(gtint_t(1), gtint_t(65), 1), // values of n + ::testing::Range(gtint_t(0), gtint_t(17), 1), // values of k + ::testing::Values(2.0, 1.0, -1.0), // alpha value + ::testing::Values(1.0, 0.0, -1.0, 2.3), // beta value + ::testing::Values('r'), // storage of c + ::testing::Values(bli_sgemmsup_rd_zen_asm_6x64n_avx512), // sgemm_sup_kernel + ::testing::Values(gtint_t(6)), // Micro kernel block MR + ::testing::Values('n'), // transa + ::testing::Values('t'), // transb + ::testing::Values(false), // kernel pref + ::testing::Values(true, false) // memory test + ), + ::sgemmGenericSUPPrint() + ); +#endif + +/*******************************************************/ +/* Native Kernel testing */ +/*******************************************************/ +class sgemmGenericNat : +// public ::testing::TestWithParam> {}; +//sgemm native kernel, k, alpha, beta, storage of c, m, n, memory test + + public ::testing::TestWithParam> {}; + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(sgemmGenericNat); + +TEST_P( sgemmGenericNat, functionality_testing) +{ + using T = float; + gtint_t k = std::get<0>(GetParam()); // dimension k + T alpha = std::get<1>(GetParam()); // alpha + T beta = std::get<2>(GetParam()); // beta + char storageC = std::get<3>(GetParam()); // indicates storage of all matrix operands + // Fix m and n to MR and NR respectively. + gtint_t m = std::get<4>(GetParam()); // m + gtint_t n = std::get<5>(GetParam()); // n + sgemm_ukr_ft kern_ptr = std::get<6>(GetParam()); // pointer to the gemm kernel + bool is_memory_test = std::get<7>(GetParam()); // is_memory_test + + // Set the threshold for the errors: + // Check gtestsuite gemm.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (m == 0 || n == 0) + thresh = 0.0; + else if ((alpha == testinghelpers::ZERO() || k == 0) && (beta == testinghelpers::ZERO() || + beta == testinghelpers::ONE())) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + thresh = testinghelpers::getEpsilon(); + else + thresh = (3*k+1)*testinghelpers::getEpsilon(); + + test_gemmnat_ukr(storageC, m, n, k, alpha, beta, kern_ptr, thresh, is_memory_test); + +}// end of function + + + +class sgemmGenericNatPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + + gtint_t k = std::get<0>(str.param); + float alpha = std::get<1>(str.param); + float beta = std::get<2>(str.param); + char storageC = std::get<3>(str.param); + bool is_memory_test = std::get<7>(str.param); + + std::string str_name; + str_name += "_stor_" + std::string(&storageC, 1); + str_name += "_k_" + std::to_string(k); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + str_name += "_beta_" + testinghelpers::get_value_string(beta); + str_name += ( is_memory_test ) ? "_mem_test_enabled" : "_mem_test_disabled"; + + return str_name; + } +}; + +#if defined(BLIS_KERNELS_ZEN4) && defined(GTEST_AVX512) +INSTANTIATE_TEST_SUITE_P ( + bli_sgemm_skx_asm_32x12_l2, + sgemmGenericNat, + ::testing::Combine( + ::testing::Range(gtint_t(0), gtint_t(17), 1), // values of k + ::testing::Values(2.0, 1.0, -1.0), // alpha value + ::testing::Values(1.0, 0.0, -1.0, 2.3), // beta value + ::testing::Values('r', 'c'), // storage + ::testing::Values(32), // values of m + ::testing::Values(12), // values of n + ::testing::Values(bli_sgemm_skx_asm_32x12_l2), + ::testing::Values(true, false) // memory test + ), + ::sgemmGenericNatPrint() +); + + +#endif + +#if defined(BLIS_KERNELS_HASWELL) && defined(GTEST_AVX2FMA3) +INSTANTIATE_TEST_SUITE_P ( + bli_sgemm_haswell_asm_6x16, + sgemmGenericNat, + ::testing::Combine( + ::testing::Range(gtint_t(0), gtint_t(17), 1), // values of k + ::testing::Values(2.0, 1.0, -1.0), // alpha value + ::testing::Values(1.0, 0.0, -1.0, 2.3), // beta value + ::testing::Values('r', 'c'), // storage + ::testing::Values(6), // values of m + ::testing::Values(16), // values of n + ::testing::Values(bli_sgemm_haswell_asm_6x16), + ::testing::Values(true, false) // memory test + ), + ::sgemmGenericNatPrint() +); +#endif + +#if 0 +/** + * sgemm_small microkernel testing disable because sgemm_small is static local + * function. Once it is made global, this testcase can be enabled. + * As of now for the compilation sake, this testcase is kept disabled. +*/ +#ifdef BLIS_ENABLE_SMALL_MATRIX + +class sgemmGenericSmallTest : + public ::testing::TestWithParam> {}; + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(sgemmGenericSmallTest); + +TEST_P( sgemmGenericSmallTest, gemm_small) +{ + using T = float; + gtint_t m = std::get<0>(GetParam()); // dimension m + gtint_t n = std::get<1>(GetParam()); // dimension n + gtint_t k = std::get<2>(GetParam()); // dimension k + T alpha = std::get<3>(GetParam()); // alpha + T beta = std::get<4>(GetParam()); // beta + char storageC = std::get<5>(GetParam()); // indicates storage of all matrix operands + + + gtint_t lda = testinghelpers::get_leading_dimension( storageC, 'n', m, k, 0 ); + gtint_t ldb = testinghelpers::get_leading_dimension( storageC, 'n', k, n, 0 ); + gtint_t ldc = testinghelpers::get_leading_dimension( storageC, 'n', m, n, 0 ); + + //---------------------------------------------------------- + // Initialize matrics with random numbers + //---------------------------------------------------------- + std::vector a = testinghelpers::get_random_matrix( -2, 8, storageC, 'n', m, k, lda ); + std::vector b = testinghelpers::get_random_matrix( -5, 2, storageC, 'n', k, n, ldb ); + std::vector c = testinghelpers::get_random_matrix( -3, 5, storageC, 'n', m, n, ldc ); + + std::vector c_ref(c); + + const num_t dt = BLIS_FLOAT; + + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t ao = BLIS_OBJECT_INITIALIZER; + obj_t bo = BLIS_OBJECT_INITIALIZER; + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t co = BLIS_OBJECT_INITIALIZER; + + dim_t m0_a, n0_a; + dim_t m0_b, n0_b; + + bli_set_dims_with_trans(BLIS_NO_TRANSPOSE, m, k, &m0_a, &n0_a); + bli_set_dims_with_trans(BLIS_NO_TRANSPOSE, k, n, &m0_b, &n0_b); + + bli_obj_init_finish_1x1(dt, (float*)&alpha, &alphao); + bli_obj_init_finish_1x1(dt, (float*)&beta, &betao); + + bli_obj_init_finish(dt, m0_a, n0_a, (float*)a.data(), 1, lda, &ao); + bli_obj_init_finish(dt, m0_b, n0_b, (float*)b.data(), 1, ldb, &bo); + bli_obj_init_finish(dt, m, n, (float*)c.data(), 1, ldc, &co); + + bli_obj_set_conjtrans(BLIS_NO_TRANSPOSE, &ao); + bli_obj_set_conjtrans(BLIS_NO_TRANSPOSE, &bo); + + + bli_sgemm_small ( &alphao, + &ao, + &bo, + &betao, + &co, + NULL, + NULL + ); + + + // Set the threshold for the errors: + // Check gtestsuite gemm.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (m == 0 || n == 0) + thresh = 0.0; + else if ((alpha == testinghelpers::ZERO() || k == 0) && (beta == testinghelpers::ZERO() || + beta == testinghelpers::ONE())) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + thresh = testinghelpers::getEpsilon(); + else + thresh = (3*k+1)*testinghelpers::getEpsilon(); + + // call reference implementation + testinghelpers::ref_gemm( storageC, 'n', 'n', m, n, k, alpha, + a.data(), lda, b.data(), ldb, beta, c_ref.data(), ldc); + + // Check component-wise error + computediff( "C", storageC, m, n, c.data(), c_ref.data(), ldc, thresh ); + +}// end of function + + + +class sgemmGenericSmallTestPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + + gtint_t m = std::get<0>(str.param); + gtint_t n = std::get<1>(str.param); + gtint_t k = std::get<2>(str.param); + float alpha = std::get<3>(str.param); + float beta = std::get<4>(str.param); + char storageC = std::get<5>(str.param); + + std::string str_name; + str_name += "_stor_" + std::string(&storageC, 1); + str_name += "_m_" + std::to_string(m); + str_name += "_n_" + std::to_string(n); + str_name += "_k_" + std::to_string(k); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + str_name += "_beta_" + testinghelpers::get_value_string(beta); + + return str_name; + } +}; + + +INSTANTIATE_TEST_SUITE_P ( + bli_sgemm_small, + sgemmGenericSmallTest, + ::testing::Combine( + ::testing::Range(gtint_t(1), gtint_t(71), 1), // values of m + ::testing::Range(gtint_t(1), gtint_t(21), 1), // values of n + ::testing::Range(gtint_t(1), gtint_t(20), 1), // values of k + ::testing::Values(2.0, 1.0, -1.0), // alpha value + ::testing::Values(1.0, 0.0, -1.0, 2.3), // beta value + ::testing::Values('c') // storage + ), + ::sgemmGenericSmallTestPrint() + ); + +#endif +#endif diff --git a/gtestsuite/testsuite/ukr/gemm/test_complex_gemm_ukr.h b/gtestsuite/testsuite/ukr/gemm/test_complex_gemm_ukr.h new file mode 100644 index 0000000000..7698c5da77 --- /dev/null +++ b/gtestsuite/testsuite/ukr/gemm/test_complex_gemm_ukr.h @@ -0,0 +1,540 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#pragma once +#include +#include +#include "level3/ref_gemm.h" +#include "inc/check_error.h" +#include "blis.h" +#include "common/testing_helpers.h" + +/**********************************************************************/ +/************ Code path when memory test is disabled **************/ +/* 1. Compute Leading dimension of all matrix based on */ +/* storage, size and trans parameters */ +/* 2. Compute size of matrices for which memory needs to be allocated */ +/* 3. Allocate memory for all matrices */ +/* 4. Initialise matrices with random numbers */ +/* 5. Copy blis output matrix content to reference output matrix */ +/* 6. Call blis micro kernel with output matrix */ +/* 7. Call reference kernel with reference output matrix */ +/* 8. Compute difference of blis and reference output */ +/* based on threshold set */ +/**********************************************************************/ +/************ Code path when memory test is enabled **************/ +/* 1. Compute Leading dimension of all matrix based on */ +/* storage, size and trans parameters */ +/* 2. Compute size of matrices for which memory needs to be allocated */ +/* 3. Allocate 2 set of memories for A, B, C matrix */ +/* green_zone1: Memory near red_zone1 */ +/* green_zone2: Memory near red_zone2 */ +/* 2 set of memory is required to check memory leaks */ +/* before starting of buffer or after end of buffer */ +/* 4. Initialise matrices with random numbers */ +/* 5. Call blis micro kernel with output matrix with green_zone1 ptr */ +/* 6. Call blis micro kernel again with green_zone2 ptr */ +/* 7. Failure is reported if there is out of bound read/write error */ +/* 8. Call reference kernel with reference output matrix to */ +/* check for any accuracy failures */ +/* 9. Compute difference of blis and reference output */ +/* based on threshold set */ +/**********************************************************************/ + +template +static void test_complex_gemmsup_ukr( char storage, char trnsa, char trnsb, gtint_t m, gtint_t n, gtint_t k, T alpha, T beta, double thresh, FT ukr_fp, bool is_memory_test = false ) +{ + // Compute the leading dimensions of a, b, and c. + gtint_t lda = testinghelpers::get_leading_dimension( storage, trnsa, m, k, 0 ); + gtint_t ldb = testinghelpers::get_leading_dimension( storage, trnsb, k, n, 0 ); + gtint_t ldc = testinghelpers::get_leading_dimension( storage, 'n', m, n, 0 ); + + //---------------------------------------------------------- + // Compute size of Matrix: A, B, C + //---------------------------------------------------------- + gtint_t sizea = testinghelpers::matsize( storage, trnsa, m, k, lda ) * sizeof(T); + gtint_t sizeb = testinghelpers::matsize( storage, trnsb, k, n, ldb ) * sizeof(T); + gtint_t sizec = testinghelpers::matsize( storage, 'n', m, n, ldc ) * sizeof(T); + + // Allocate memory for Matrix: A, B, C, CRef + testinghelpers::ProtectedBuffer buf_a_ptrs( sizea, false, is_memory_test ); + testinghelpers::ProtectedBuffer buf_b_ptrs( sizeb, false , is_memory_test ); + testinghelpers::ProtectedBuffer buf_c_ptrs( sizec, false , is_memory_test ); + + /* No need to check for memory errors for reference code path, */ + /* hence is_memory_test is set to false */ + testinghelpers::ProtectedBuffer buf_cref_ptrs( sizec, false , false ); + + T* buf_a = (T*)buf_a_ptrs.greenzone_1; + T* buf_b = (T*)buf_b_ptrs.greenzone_1; + T* buf_c = (T*)buf_c_ptrs.greenzone_1; + T* buf_cref = (T*)buf_cref_ptrs.greenzone_1; + + testinghelpers::datagenerators::randomgenerators( -2, 8, storage, m, k, (T*)(buf_a), trnsa, lda); + testinghelpers::datagenerators::randomgenerators( -5, 2, storage, k, n, (T*)(buf_b), trnsb, ldb); + testinghelpers::datagenerators::randomgenerators( -3, 5, storage, m, n, (T*)(buf_c), 'n', ldc); + + // Create a copy of c so that we can check reference results. + memcpy(buf_cref, buf_c, sizec); + + gtint_t rs_a = 1, cs_a = 1, rs_b = 1, cs_b = 1, rs_c = 1, cs_c = 1; + gtint_t rs_a0 = 1, cs_a0 = 1, rs_b0 = 1, cs_b0 = 1; + + if(storage == 'r') + { + rs_a = lda; + rs_b = ldb; + rs_c = ldc; + + cs_a = 1; + cs_b = 1; + cs_c = 1; + + rs_a0 = lda; + rs_b0 = ldb; + + cs_a0 = 1; + cs_b0 = 1; + } + else + { + cs_a = lda; + cs_b = ldb; + cs_c = ldc; + + rs_a = 1; + rs_b = 1; + rs_c = 1; + + cs_a0 = lda; + cs_b0 = ldb; + + rs_a0 = 1; + rs_b0 = 1; + } + + if(trnsb == 't' || trnsb == 'T') + { + rs_b = cs_b0; + cs_b = rs_b0; + } + + if(trnsa == 't' || trnsa == 'T') + { + rs_a = cs_a0; + cs_a = rs_a0; + } + // add signal handler for segmentation fault + testinghelpers::ProtectedBuffer::start_signal_handler(); + try + { + auxinfo_t data; + //Panel stride update is required only for zen4 sup kernels + inc_t ps_a_use = (12 * rs_a); //12 = MR + bli_auxinfo_set_ps_a( ps_a_use, &data ); + + ukr_fp( + BLIS_NO_CONJUGATE, + BLIS_NO_CONJUGATE, + m, + n, + k, + &alpha, + buf_a, rs_a, cs_a, + buf_b, rs_b, cs_b, + &beta, + buf_c, rs_c, cs_c, + &data, + NULL + ); + + if ( is_memory_test ) + { + // set pointers to second buffer + buf_a = (T*)buf_a_ptrs.greenzone_2; + buf_b = (T*)buf_b_ptrs.greenzone_2; + buf_c = (T*)buf_c_ptrs.greenzone_2; + + // copy data from 1st buffer of A and B to second buffer + memcpy(buf_a, buf_a_ptrs.greenzone_1, sizea); + memcpy(buf_b, buf_b_ptrs.greenzone_1, sizeb); + + //buf_c_ptrs.greenzone_1 has been updated with output from previous + // gemm call, hence use buf_cref + memcpy(buf_c, buf_cref, sizec); + + // second call to ukr + auxinfo_t data; + inc_t ps_a_use = (12 * rs_a); //12 = MR + bli_auxinfo_set_ps_a( ps_a_use, &data ); + + ukr_fp( + BLIS_NO_CONJUGATE, + BLIS_NO_CONJUGATE, + m, + n, + k, + &alpha, + buf_a, rs_a, cs_a, + buf_b, rs_b, cs_b, + &beta, + buf_c, rs_c, cs_c, + &data, + NULL + ); + } + } + catch(const std::exception& e) + { + // reset to default signal handler + testinghelpers::ProtectedBuffer::stop_signal_handler(); + + // show failure in case seg fault was detected + FAIL() << "Memory Test Failed"; + } + // reset to default signal handler + testinghelpers::ProtectedBuffer::stop_signal_handler(); + + // call reference implementation + testinghelpers::ref_gemm( storage, trnsa, trnsb, m, n, k, alpha, + buf_a, lda, buf_b, ldb, beta, buf_cref, ldc); + + // Check component-wise error + computediff( "C", storage, m, n, buf_c, buf_cref, ldc, thresh ); + +} + +// The function is templatized based on the datatype and function-pointer type to the kernel. +template +static void test_gemmnat_ukr( char storage, gtint_t m, gtint_t n, gtint_t k, T alpha, T beta, double thresh, FT ukr_fp, bool is_memory_test = false ) +{ + + /*************Memory requirement*****************************/ + /* General requirement of memory allocation: */ + /* Block Microkernel */ + /* A = MC * KC A = MR * k */ + /* B = NC * KC B = NR * k */ + /* C = MC * NC C = MR * NR */ + /* Native kernel works on packed buffer for A and B matrix */ + /* Memory requirement for input matrix for a block: */ + /* A = (MC + max(MR, NR)) * (KC + max(MR, NR)) */ + /* B = (NC + max(MR, NR)) * (KC + max(MR, NR)) */ + /* Memory requirement for input matrix for a microkernel: */ + /* A = max(MR, NR) * (k + max(MR, NR)) */ + /* B = max(MR, NR) * (k + max(MR, NR)) */ + /* MC, NC, KC - Cache block sizes */ + /* MR, NR - Micro kernel sizes */ + /* To support preloading feature inside microkernel, */ + /* allocation of extra memory is must */ + /************************************************************/ + + obj_t a, b; + num_t dt = BLIS_DCOMPLEX; + gtint_t maxmn = (std::max)(m,n); + bli_obj_create(dt, m, k, 1, m, &a); + bli_obj_create(dt, k, n, n, 1, &b); + + // Create test operands + // matrix A will be in col-storage + // matrix B will be in row-storage + // column * row = matrix -- rank-k update + + // Set matrix A dimensions + gtint_t rs = 1; + gtint_t cs = m; + gtint_t lda = cs; + gtint_t sizea = maxmn * (k+maxmn) * sizeof(T); + + // Set matrix B dimensions + rs = n; + cs = 1; + gtint_t ldb = rs; + gtint_t sizeb = (k+maxmn) * maxmn * sizeof(T); + + // Set matrix C dimensions + gtint_t ldc = m; + if(storage == 'r' || storage == 'R') + { + rs = n; + cs = 1; + ldc = rs; + } + else + { + rs = 1; + cs = m; + ldc = cs; + } + gtint_t sizec = m * n * sizeof(T); + + // Allocating aligned memory for A and B matrix as Native microkernel issues + // VMOVAPD which expects memory to be accessed to be aligned. + // Matrix C need not be aligned + testinghelpers::ProtectedBuffer buf_a_ptrs( sizea, true, is_memory_test ); + testinghelpers::ProtectedBuffer buf_b_ptrs( sizeb, true, is_memory_test ); + testinghelpers::ProtectedBuffer buf_c_ptrs( sizec, false, is_memory_test ); + + // Allocate memory for C Matrix used for reference computation + testinghelpers::ProtectedBuffer buf_c_ref_ptrs( sizec, false , false ); + + + T* buf_a = (T*)buf_a_ptrs.greenzone_1; + T* buf_b = (T*)buf_b_ptrs.greenzone_1; + T* buf_c = (T*)buf_c_ptrs.greenzone_1; + T* buf_cref = (T*)buf_c_ref_ptrs.greenzone_1; + + /* Initialize Matrices with random numbers */ + testinghelpers::datagenerators::randomgenerators( -2, 8, 'c', m, k, (T*)(buf_a), 'n', lda); + testinghelpers::datagenerators::randomgenerators( -5, 2, 'r', k, n, (T*)(buf_b), 'n', ldb); + testinghelpers::datagenerators::randomgenerators( -5, 2, storage , m, n, (T*)(buf_c), 'n', ldc); + + // Create a copy of c so that we can check reference results. + memcpy(buf_cref, buf_c, sizec); + + /* Fill the auxinfo_t struct in case the micro-kernel uses it. */ + auxinfo_t data; + bli_auxinfo_set_ps_a(0, &data); + + // add signal handler for segmentation fault + testinghelpers::ProtectedBuffer::start_signal_handler(); + try + { + // call micro-kernel + ukr_fp ( + k, + &alpha, + buf_a, + buf_b, + &beta, + buf_c, + rs, + cs, + &data, + NULL + ); + if ( is_memory_test ) + { + // set pointers to second buffer + buf_a = (T*)buf_a_ptrs.greenzone_2; + buf_b = (T*)buf_b_ptrs.greenzone_2; + buf_c = (T*)buf_c_ptrs.greenzone_2; + + // copy data from 1st buffer of A and B to second buffer + memcpy(buf_a, buf_a_ptrs.greenzone_1, sizea); + memcpy(buf_b, buf_b_ptrs.greenzone_1, sizeb); + + //buf_c_ptrs.greenzone_1 has been updated with output from previous + // gemm call, hence use buf_cref + memcpy(buf_c, buf_cref, sizec); + + ukr_fp ( + k, + &alpha, + buf_a, + buf_b, + &beta, + buf_c, + rs, + cs, + &data, + NULL + ); + } + } + catch(const std::exception& e) + { + // reset to default signal handler + testinghelpers::ProtectedBuffer::stop_signal_handler(); + + // show failure in case seg fault was detected + FAIL() << "Memory Test Failed"; + } + // reset to default signal handler + testinghelpers::ProtectedBuffer::stop_signal_handler(); + + // In native micro-kernel + // op(A) = No transpose & op(B) = transpose + // for column-storage + char transa = 'n'; + char transb = 't'; + + // The objective here is to make storage of all matrices same + // To do this we set transpose of A and B appropriately. + if (storage == 'r' || storage == 'R') + { + // if row-storage + transa = 't'; + transb = 'n'; + // because matrix A is created with col-storage + // and matrix B is created with row-storage + // Generally storage parameter in cblas signifies + // storage of all matrices A, B and C. + // since A is col-storage, A' will be row-storage + } + + // call reference implementation + testinghelpers::ref_gemm( storage, transa, transb, m, n, k, alpha, + buf_a, lda, buf_b, ldb, beta, (T*)buf_cref, ldc); + + // Check component-wise error + computediff( "C", storage, m, n, (T*)buf_c, (T*)buf_cref, ldc, thresh ); + +} + +// The function is templatized based on the datatype and function-pointer type to the kernel. +template +static void test_gemmk1_ukr( FT ukr_fp, gtint_t m, gtint_t n, gtint_t k, char storage, T alpha, T beta, bool memory_test = false ) +{ + // Compute the leading dimensions of a, b, and c. + //char storage = storageC; + gtint_t lda = testinghelpers::get_leading_dimension( storage, 'n', m, k, 0 ); + gtint_t ldb = testinghelpers::get_leading_dimension( storage, 'n', k, n, 0 ); + gtint_t ldc = testinghelpers::get_leading_dimension( storage, 'n', m, n, 0 ); + + //---------------------------------------------------------- + // Initialize matrices with random numbers + //---------------------------------------------------------- + gtint_t sizea = testinghelpers::matsize( storage, 'n', m, k, lda ) * sizeof(T); + gtint_t sizeb = testinghelpers::matsize( storage, 'n', k, n, ldb ) * sizeof(T); + gtint_t sizec = testinghelpers::matsize( storage, 'n', m, n, ldc ) * sizeof(T); + + testinghelpers::ProtectedBuffer mat_a(sizea, false, memory_test); + testinghelpers::ProtectedBuffer mat_b(sizeb, false, memory_test); + testinghelpers::ProtectedBuffer mat_c(sizec, false, memory_test); + testinghelpers::ProtectedBuffer mat_cref(sizec, false, false); + + T *buf_a = (T*)mat_a.greenzone_1; + T *buf_b = (T*)mat_b.greenzone_1; + T *buf_c = (T*)mat_c.greenzone_1; + T* buf_cref = (T*)mat_cref.greenzone_1; + + // Check if the memory has been successfully allocated + if ((buf_a == NULL) ||(buf_b == NULL) ||(buf_c == NULL) ||(buf_cref == NULL)) { + printf("Memory not allocated for input and output Matrix.\n"); + return ; + } + testinghelpers::datagenerators::randomgenerators( -2, 8, storage, m, k, (T*)(buf_a), 'n', lda); + testinghelpers::datagenerators::randomgenerators( -5, 2, storage, k, n, (T*)(buf_b), 'n', ldb); + testinghelpers::datagenerators::randomgenerators( -3, 5, storage, m, n, (T*)(buf_c), 'n', ldc); + + // Create a copy of c so that we can check reference results. + memcpy(buf_cref, buf_c, sizec); + + // add signal handler for segmentation fault + testinghelpers::ProtectedBuffer::start_signal_handler(); + try + { + // call micro-kernel + ukr_fp ( + m, + n, + k, + &alpha, + buf_a, + lda, + buf_b, + ldb, + &beta, + buf_c, + ldc + ); + + if(memory_test == true) + { + // set pointers to second buffer + buf_a = (T*)mat_a.greenzone_2; + buf_b = (T*)mat_b.greenzone_2; + buf_c = (T*)mat_c.greenzone_2; + + // Check if the memory has been successfully allocated + if ((buf_a == NULL) || (buf_b == NULL) || (buf_c == NULL)) { + printf("Memory not allocated for input or output Matrix for memory test.\n"); + return ; + } + + // copy data from 1st buffer of A and B to second buffer + memcpy(buf_a, mat_a.greenzone_1, sizea); + memcpy(buf_b, mat_b.greenzone_1, sizeb); + + //buf_c_ptrs.greenzone_1 has been updated with output from previous + // gemm call, hence use buf_cref + memcpy(buf_c, buf_cref, sizec); + + // call micro-kernel + ukr_fp ( + m, + n, + k, + &alpha, + buf_a, + lda, + buf_b, + ldb, + &beta, + buf_c, + ldc + ); + } + } + catch(const std::exception& e) + { + // reset to default signal handler + testinghelpers::ProtectedBuffer::stop_signal_handler(); + + // show failure in case seg fault was detected + FAIL() << "Memory Test Failed"; + } + // reset to default signal handler + testinghelpers::ProtectedBuffer::stop_signal_handler(); + + // Set the threshold for the errors: + // Check gtestsuite gemm.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (m == 0 || n == 0) + thresh = 0.0; + else if ((alpha == testinghelpers::ZERO() || k == 0) && (beta == testinghelpers::ZERO() || + beta == testinghelpers::ONE())) + thresh = 0.0; + else + thresh = (7*k+3)*testinghelpers::getEpsilon(); + + // call reference implementation + testinghelpers::ref_gemm( storage, 'n', 'n', m, n, k, alpha, + buf_a, lda, buf_b, ldb, beta, buf_cref, ldc); + + // Check component-wise error + computediff( "C", storage, m, n, buf_c, buf_cref, ldc, thresh ); +} diff --git a/gtestsuite/testsuite/ukr/gemm/test_gemm_ukr.h b/gtestsuite/testsuite/ukr/gemm/test_gemm_ukr.h new file mode 100644 index 0000000000..ce81bd6b55 --- /dev/null +++ b/gtestsuite/testsuite/ukr/gemm/test_gemm_ukr.h @@ -0,0 +1,578 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#pragma once +#include "level3/ref_gemm.h" +#include "inc/check_error.h" +#include +#include +#include "blis.h" + +/** + * @brief Generic test body for gemm operation. + */ + +// The function is templatized based on the datatype and function-pointer type to the kernel. +template +static void test_gemmnat_ukr( + char storage, gtint_t m, gtint_t n, gtint_t k, T alpha, T beta, FT ukr_fp, double thresh, bool is_memory_test = false ) +{ + // In case of memory test: + // Allocate packed buffer size for Matrix A, B native kernel works on packed buffer + // Native kernel has preload or prebroadcase design + // If we allocate size required by dimension then memtest fails + obj_t a, b; + obj_t ap, bp; // for packed buffers + cntx_t* cntx; + num_t dt = BLIS_DOUBLE; + cntx = bli_gks_query_cntx(); + bli_obj_create(dt, m, k, 1, m, &a); + bli_obj_create(dt, k, n, n, 1, &b); + + bli_obj_create(dt, m, k, 1, m, &ap); + bli_obj_create(dt, k, n, n, 1, &bp); + + gtint_t sizea = bli_packm_init_pack( BLIS_NO_INVERT_DIAG, BLIS_GEMM, BLIS_PACKED_ROW_PANELS, + BLIS_PACK_FWD_IF_UPPER, BLIS_PACK_FWD_IF_LOWER, + BLIS_MR, BLIS_KR, &a, &ap, cntx) * sizeof(T); + gtint_t sizeb = bli_packm_init_pack( BLIS_NO_INVERT_DIAG, BLIS_GEMM, BLIS_PACKED_COL_PANELS, + BLIS_PACK_FWD_IF_UPPER, BLIS_PACK_FWD_IF_LOWER, + BLIS_KR, BLIS_NR, &b, &bp, cntx ) * sizeof(T); + + // Create test operands + // matrix A will be in col-storage + // matrix B will be in row-storage + // column * row = matrix -- rank-k update + + // Set matrix A dimensions + gtint_t rs = 1; + gtint_t cs = m; + gtint_t lda = cs; + //gtint_t sizea = m * k * sizeof(T); + + // Set matrix B dimensions + rs = n; + cs = 1; + gtint_t ldb = rs; + //gtint_t sizeb = k * n * sizeof(T); + + // Set matrix C dimensions + gtint_t ldc = m; + if(storage == 'r' || storage == 'R') + { + rs = n; + cs = 1; + ldc = rs; + } + else + { + rs = 1; + cs = m; + ldc = cs; + } + gtint_t sizec = m * n * sizeof(T); + + // Allocating aligned memory for A and B matrix as Native microkernel issues + // VMOVAPD which expects memory to be accessed to be aligned. + // Matrix C need not be aligned + testinghelpers::ProtectedBuffer buf_a_ptrs( sizea, true, is_memory_test ); + testinghelpers::ProtectedBuffer buf_b_ptrs( sizeb, true, is_memory_test ); + testinghelpers::ProtectedBuffer buf_c_ptrs( sizec, false, is_memory_test ); + + // Allocate memory for C Matrix used for reference computation + testinghelpers::ProtectedBuffer buf_c_ref_ptrs( sizec, false , false ); + + + T* buf_a = (T*)buf_a_ptrs.greenzone_1; + T* buf_b = (T*)buf_b_ptrs.greenzone_1; + T* buf_c = (T*)buf_c_ptrs.greenzone_1; + T* buf_cref = (T*)buf_c_ref_ptrs.greenzone_1; + + /* Initialize Matrices with random numbers */ + testinghelpers::datagenerators::randomgenerators( -2, 8, 'c', m, k, (T*)(buf_a), 'n', lda); + testinghelpers::datagenerators::randomgenerators( -5, 2, 'r', k, n, (T*)(buf_b), 'n', ldb); + + if (beta != testinghelpers::ZERO()) + testinghelpers::datagenerators::randomgenerators( -5, 2, storage , m, n, (T*)(buf_c), 'n', ldc); + else + { + // Matrix C should not be read, only set. + testinghelpers::set_matrix( storage, m, n, (T*)(buf_c), 'n', ldc, testinghelpers::aocl_extreme() ); + } + + // Create a copy of c so that we can check reference results. + memcpy(buf_cref, buf_c, sizec); + + /* Fill the auxinfo_t struct in case the micro-kernel uses it. */ + auxinfo_t data; + bli_auxinfo_set_ps_a(0, &data); + + // add signal handler for segmentation fault + testinghelpers::ProtectedBuffer::start_signal_handler(); + try + { + // call micro-kernel + ukr_fp ( + k, + &alpha, + buf_a, + buf_b, + &beta, + buf_c, + rs, + cs, + &data, + NULL + ); + if ( is_memory_test ) + { + // set pointers to second buffer + buf_a = (T*)buf_a_ptrs.greenzone_2; + buf_b = (T*)buf_b_ptrs.greenzone_2; + buf_c = (T*)buf_c_ptrs.greenzone_2; + + // copy data from 1st buffer of A and B to second buffer + memcpy(buf_a, buf_a_ptrs.greenzone_1, sizea); + memcpy(buf_b, buf_b_ptrs.greenzone_1, sizeb); + + //buf_c_ptrs.greenzone_1 has been updated with output from previous + // gemm call, hence use buf_cref + memcpy(buf_c, buf_cref, sizec); + + ukr_fp ( + k, + &alpha, + buf_a, + buf_b, + &beta, + buf_c, + rs, + cs, + &data, + NULL + ); + } + } + catch(const std::exception& e) + { + // reset to default signal handler + testinghelpers::ProtectedBuffer::stop_signal_handler(); + + // show failure in case seg fault was detected + FAIL() << "Memory Test Failed"; + } + // reset to default signal handler + testinghelpers::ProtectedBuffer::stop_signal_handler(); + + // In native micro-kernel + // op(A) = No transpose & op(B) = transpose + // for column-storage + char transa = 'n'; + char transb = 't'; + + // The objective here is to make storage of all matrices same + // To do this we set transpose of A and B appropriately. + if (storage == 'r' || storage == 'R') + { + // if row-storage + transa = 't'; + transb = 'n'; + // because matrix A is created with col-storage + // and matrix B is created with row-storage + // Generally storage parameter in cblas signifies + // storage of all matrices A, B and C. + // since A is col-storage, A' will be row-storage + } + + // call reference implementation + testinghelpers::ref_gemm( storage, transa, transb, m, n, k, alpha, + buf_a, lda, buf_b, ldb, beta, (T*)buf_cref, ldc); + + // Check component-wise error + computediff( "C", storage, m, n, (T*)buf_c, (T*)buf_cref, ldc, thresh ); + +} + +// The function is templatized based on the datatype and function-pointer type to the kernel. +template +static void test_gemmk1_ukr( FT ukr_fp, gtint_t m, gtint_t n, gtint_t k, char storage, T alpha, T beta, double thresh, bool is_memory_test = false ) +{ + // Compute the leading dimensions of a, b, and c. + //char storage = storageC; + gtint_t lda = testinghelpers::get_leading_dimension( storage, 'n', m, k, 0 ); + gtint_t ldb = testinghelpers::get_leading_dimension( storage, 'n', k, n, 0 ); + gtint_t ldc = testinghelpers::get_leading_dimension( storage, 'n', m, n, 0 ); + + //---------------------------------------------------------- + // Initialize matrices with random numbers + //---------------------------------------------------------- + gtint_t sizea = testinghelpers::matsize( storage, 'n', m, k, lda ) * sizeof(T); + gtint_t sizeb = testinghelpers::matsize( storage, 'n', k, n, ldb ) * sizeof(T); + gtint_t sizec = testinghelpers::matsize( storage, 'n', m, n, ldc ) * sizeof(T); + + testinghelpers::ProtectedBuffer mat_a(sizea, false, is_memory_test); + testinghelpers::ProtectedBuffer mat_b(sizeb, false, is_memory_test); + testinghelpers::ProtectedBuffer mat_c(sizec, false, is_memory_test); + testinghelpers::ProtectedBuffer mat_cref(sizec, false, false); + + T *buf_a = (T*)mat_a.greenzone_1; + T *buf_b = (T*)mat_b.greenzone_1; + T *buf_c = (T*)mat_c.greenzone_1; + T* buf_cref = (T*)mat_cref.greenzone_1; + + // Check if the memory has been successfully allocated + if ((buf_a == NULL) ||(buf_b == NULL) ||(buf_c == NULL) ||(buf_cref == NULL)) { + printf("Memory not allocated for input and output Matrix.\n"); + return ; + } + testinghelpers::datagenerators::randomgenerators( -2, 8, storage, m, k, (T*)(buf_a), 'n', lda); + testinghelpers::datagenerators::randomgenerators( -5, 2, storage, k, n, (T*)(buf_b), 'n', ldb); + + if (beta != testinghelpers::ZERO()) + testinghelpers::datagenerators::randomgenerators( -3, 5, storage , m, n, (T*)(buf_c), 'n', ldc); + else + { + // Matrix C should not be read, only set. + testinghelpers::set_matrix( storage, m, n, (T*)(buf_c), 'n', ldc, testinghelpers::aocl_extreme() ); + } + + // Create a copy of c so that we can check reference results. + memcpy(buf_cref, buf_c, sizec); + + // add signal handler for segmentation fault + testinghelpers::ProtectedBuffer::start_signal_handler(); + try + { + // call micro-kernel + ukr_fp ( + m, + n, + k, + &alpha, + buf_a, + lda, + buf_b, + ldb, + &beta, + buf_c, + ldc + ); + + if ( is_memory_test ) + { + // set pointers to second buffer + buf_a = (T*)mat_a.greenzone_2; + buf_b = (T*)mat_b.greenzone_2; + buf_c = (T*)mat_c.greenzone_2; + + // Check if the memory has been successfully allocated + if ((buf_a == NULL) || (buf_b == NULL) || (buf_c == NULL)) { + printf("Memory not allocated for input or output Matrix for memory test.\n"); + return ; + } + + // copy data from 1st buffer of A and B to second buffer + memcpy(buf_a, mat_a.greenzone_1, sizea); + memcpy(buf_b, mat_b.greenzone_1, sizeb); + + //buf_c_ptrs.greenzone_1 has been updated with output from previous + // gemm call, hence use buf_cref + memcpy(buf_c, buf_cref, sizec); + + // call micro-kernel + ukr_fp ( + m, + n, + k, + &alpha, + buf_a, + lda, + buf_b, + ldb, + &beta, + buf_c, + ldc + ); + } + } + catch(const std::exception& e) + { + // reset to default signal handler + testinghelpers::ProtectedBuffer::stop_signal_handler(); + + // show failure in case seg fault was detected + FAIL() << "Memory Test Failed"; + } + // reset to default signal handler + testinghelpers::ProtectedBuffer::stop_signal_handler(); + + // call reference implementation + testinghelpers::ref_gemm( storage, 'n', 'n', m, n, k, alpha, + buf_a, lda, buf_b, ldb, beta, buf_cref, ldc); + + // Check component-wise error + computediff( "C", storage, m, n, buf_c, buf_cref, ldc, thresh ); +} + +template +static void test_gemmsup_ukr( FT ukr_fp, char trnsa, char trnsb, gtint_t m, gtint_t n, gtint_t k, T alpha, T beta, + char storageC, gtint_t MR, bool row_pref, double thresh, bool is_memory_test = false) +{ + // Compute the leading dimensions of a, b, and c. + char storage = storageC; + gtint_t lda = testinghelpers::get_leading_dimension( storage, trnsa, m, k, 0 ); + gtint_t ldb = testinghelpers::get_leading_dimension( storage, trnsb, k, n, 0 ); + gtint_t ldc = testinghelpers::get_leading_dimension( storage, 'n', m, n, 0 ); + + //---------------------------------------------------------- + // Initialize matrices with random numbers + //---------------------------------------------------------- + gtint_t sizea = testinghelpers::matsize( storage, trnsa, m, k, lda ) * sizeof(T); + gtint_t sizeb = testinghelpers::matsize( storage, trnsb, k, n, ldb ) * sizeof(T); + gtint_t sizec = testinghelpers::matsize( storage, 'n', m, n, ldc ) * sizeof(T); + + testinghelpers::ProtectedBuffer mat_a(sizea, false, is_memory_test); + testinghelpers::ProtectedBuffer mat_b(sizeb, false, is_memory_test); + testinghelpers::ProtectedBuffer mat_c(sizec, false, is_memory_test); + testinghelpers::ProtectedBuffer mat_cref(sizec, false, false); + + T *buf_a = (T*)mat_a.greenzone_1; + T *buf_b = (T*)mat_b.greenzone_1; + T *buf_c = (T*)mat_c.greenzone_1; + T *ref_c = (T*)mat_cref.greenzone_1; + + // Check if the memory has been successfully allocated + if ((buf_a == NULL) ||(buf_b == NULL) ||(buf_c == NULL) ||(ref_c == NULL)) { + printf("Memory not allocated for input and output Matrix.\n"); + return ; + } + testinghelpers::datagenerators::randomgenerators( -2, 8, storage, m, k, (T*)(buf_a), trnsa, lda); + testinghelpers::datagenerators::randomgenerators( -5, 2, storage, k, n, (T*)(buf_b), trnsb, ldb); + + if (beta != testinghelpers::ZERO()) + testinghelpers::datagenerators::randomgenerators( -3, 5, storage , m, n, (T*)(buf_c), 'n', ldc); + else + { + // Matrix C should not be read, only set. + testinghelpers::set_matrix( storage, m, n, (T*)(buf_c), 'n', ldc, testinghelpers::aocl_extreme() ); + } + + // Create a copy of c so that we can check reference results. + memset(buf_c, 0, sizec); + memset(ref_c, 0, sizec); + inc_t str_id = 0; + gtint_t rs_a = 1, cs_a = 1, rs_b = 1, cs_b = 1, rs_c = 1, cs_c = 1; + gtint_t rs_a0 = 1, cs_a0 = 1, rs_b0 = 1, cs_b0 = 1; + + if(storage == 'r') + { + rs_a = lda; + rs_b = ldb; + rs_c = ldc; + + cs_a = 1; + cs_b = 1; + cs_c = 1; + + rs_a0 = lda; + rs_b0 = ldb; + + cs_a0 = 1; + cs_b0 = 1; + } + else + { + cs_a = lda; + cs_b = ldb; + cs_c = ldc; + + rs_a = 1; + rs_b = 1; + rs_c = 1; + + cs_a0 = lda; + cs_b0 = ldb; + + rs_a0 = 1; + rs_b0 = 1; + } + + if(trnsb == 'n' || trnsb == 'N') + { + str_id = 1 * (rs_b == 1); //1st bit + } + else if(trnsb == 't' || trnsb == 'T') + { + str_id = 1 * (cs_b == 1); //1st bit + rs_b = cs_b0; + cs_b = rs_b0; + } + + if(trnsa == 'n' || trnsa == 'N') + { + str_id |= ((1 * (rs_a == 1)) << 1); //2nd bit + } + else if(trnsa == 't' || trnsa == 'T') + { + str_id |= ((1 * (cs_a == 1)) << 1); //2nd bit + rs_a = cs_a0; + cs_a = rs_a0; + } + + bool is_primary = false; + + str_id |= ((1 * (rs_c == 1)) << 2); //3rd bit + + if(str_id == 0 || str_id == 1 || str_id == 2 || str_id == 4) + { + is_primary = true; + } + + auxinfo_t data; + inc_t ps_a_use = (MR * rs_a); + bli_auxinfo_set_ps_a( ps_a_use, &data ); + + testinghelpers::ProtectedBuffer::start_signal_handler(); + try + { + if(is_primary == false && row_pref == true) + { + ukr_fp( + BLIS_NO_CONJUGATE, + BLIS_NO_CONJUGATE, + n, + m, + k, + &alpha, + buf_b, cs_b, rs_b, + buf_a, cs_a, rs_a, + &beta, + buf_c, cs_c, rs_c, + &data, + NULL + ); + } + else + { + ukr_fp( + BLIS_NO_CONJUGATE, + BLIS_NO_CONJUGATE, + m, + n, + k, + &alpha, + buf_a, rs_a, cs_a, + buf_b, rs_b, cs_b, + &beta, + buf_c, rs_c, cs_c, + &data, + NULL + ); + } + + if ( is_memory_test ) + { + // set pointers to second buffer + buf_a = (T*)mat_a.greenzone_2; + buf_b = (T*)mat_b.greenzone_2; + buf_c = (T*)mat_c.greenzone_2; + + // Check if the memory has been successfully allocated + if ((buf_a == NULL) || (buf_b == NULL) || (buf_c == NULL)) { + printf("Memory not allocated for input or output Matrix for memory test.\n"); + return ; + } + + // copy data from 1st buffer of A and B to second buffer + memcpy(buf_a, mat_a.greenzone_1, sizea); + memcpy(buf_b, mat_b.greenzone_1, sizeb); + + //buf_c_ptrs.greenzone_1 has been updated with output from previous + // gemm call, hence use buf_cref + memcpy(buf_c, ref_c, sizec); + + if(is_primary == false && row_pref == true) + { + ukr_fp( + BLIS_NO_CONJUGATE, + BLIS_NO_CONJUGATE, + n, + m, + k, + &alpha, + buf_b, cs_b, rs_b, + buf_a, cs_a, rs_a, + &beta, + buf_c, cs_c, rs_c, + &data, + NULL + ); + } + else + { + ukr_fp( + BLIS_NO_CONJUGATE, + BLIS_NO_CONJUGATE, + m, + n, + k, + &alpha, + buf_a, rs_a, cs_a, + buf_b, rs_b, cs_b, + &beta, + buf_c, rs_c, cs_c, + &data, + NULL + ); + } + } + } + catch(const std::exception& e) + { + // reset to default signal handler + testinghelpers::ProtectedBuffer::stop_signal_handler(); + + // show failure in case seg fault was detected + FAIL() << "Memory Test Failed"; + } + // reset to default signal handler + testinghelpers::ProtectedBuffer::stop_signal_handler(); + + // call reference implementation + testinghelpers::ref_gemm( storage, trnsa, trnsb, m, n, k, alpha, + buf_a, lda, buf_b, ldb, beta, ref_c, ldc); + + // Check component-wise error + computediff( "C", storage, m, n, buf_c, ref_c, ldc, thresh ); +} diff --git a/gtestsuite/testsuite/ukr/gemm/zgemm/zgemm_ukernel.cpp b/gtestsuite/testsuite/ukr/gemm/zgemm/zgemm_ukernel.cpp new file mode 100644 index 0000000000..f633a41444 --- /dev/null +++ b/gtestsuite/testsuite/ukr/gemm/zgemm/zgemm_ukernel.cpp @@ -0,0 +1,1303 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "blis.h" +#include "common/testing_helpers.h" +#include "ukr/gemm/test_complex_gemm_ukr.h" + +/*******************************************************/ +/* SUP Kernel testing */ +/*******************************************************/ +class zgemmGenericSUP: + public ::testing::TestWithParam> {}; + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(zgemmGenericSUP); + +TEST_P( zgemmGenericSUP, UKR ) +{ + using T = dcomplex; + gtint_t m = std::get<0>(GetParam()); // dimension m + gtint_t n = std::get<1>(GetParam()); // dimension n + gtint_t k = std::get<2>(GetParam()); // dimension k + T alpha = std::get<3>(GetParam()); // alpha + T beta = std::get<4>(GetParam()); // beta + char storageC = std::get<5>(GetParam()); // storage scheme for C matrix + zgemmsup_ker_ft kern_ptr = std::get<6>(GetParam()); // pointer to the gemm kernel + char transa = std::get<7>(GetParam()); // transa + char transb = std::get<8>(GetParam()); // transb + bool is_memory_test = std::get<9>(GetParam()); // is_memory_test + + // Set the threshold for the errors: + // Check gtestsuite gemm.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (m == 0 || n == 0) + thresh = 0.0; + else if ((alpha == testinghelpers::ZERO() || k == 0) && (beta == testinghelpers::ZERO() || + beta == testinghelpers::ONE())) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + thresh = testinghelpers::getEpsilon(); + else + thresh = (3*k+1)*testinghelpers::getEpsilon(); + + test_complex_gemmsup_ukr(storageC, transa, transb, m, n, k, alpha, beta, thresh, kern_ptr, is_memory_test); +}// end of function + +class zgemmGenericSUPPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + + gtint_t m = std::get<0>(str.param); + gtint_t n = std::get<1>(str.param); + gtint_t k = std::get<2>(str.param); + dcomplex alpha = std::get<3>(str.param); + dcomplex beta = std::get<4>(str.param); + char storageC = std::get<5>(str.param); + char transa = std::get<7>(str.param); + char transb = std::get<8>(str.param); + bool is_memory_test = std::get<9>(str.param); + + std::string str_name; + str_name += "_stor_" + std::string(&storageC, 1); + str_name += "_transa_" + std::string(&transa, 1); + str_name += "_transb_" + std::string(&transb, 1); + str_name += "_m_" + std::to_string(m); + str_name += "_n_" + std::to_string(n); + str_name += "_k_" + std::to_string(k); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + str_name += "_beta_" + testinghelpers::get_value_string(beta); + str_name += ( is_memory_test ) ? "_mem_test_enabled" : "_mem_test_disabled"; + + return str_name; + } +}; + +#if defined(BLIS_KERNELS_ZEN) && defined(GTEST_AVX2FMA3) +INSTANTIATE_TEST_SUITE_P ( + bli_zgemmsup_rv_zen_asm_3x4m_row_stored_c, + zgemmGenericSUP, + ::testing::Combine( + ::testing::Range(gtint_t(1), gtint_t(10), 1), // values of m + ::testing::Range(gtint_t(1), gtint_t(5), 1), // values of n + ::testing::Range(gtint_t(0), gtint_t(15), 1), // values of k + //alpha values dcomplex{0.0, 0.0} failure observed + ::testing::Values(dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, -5.0}, dcomplex{3, 4.5}), // alpha value + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, -5.0}, dcomplex{-7.3, 6.7}), // beta value + ::testing::Values('r'), // storage of c + ::testing::Values(bli_zgemmsup_rv_zen_asm_3x4m), // zgemm_sup kernel + ::testing::Values('t'), // transa + ::testing::Values('n'), // transb + ::testing::Values(false, true) // is_memory_test + ), + ::zgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_zgemmsup_rv_zen_asm_2x4_row_stored_c, + zgemmGenericSUP, + ::testing::Combine( + ::testing::Values(gtint_t(2)), // values of m + ::testing::Values(gtint_t(4)), // values of n + ::testing::Range(gtint_t(0), gtint_t(19), 1), // values of k + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, 5.0}, dcomplex{3.5, 4.5}), // alpha value + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, 5.0}, dcomplex{-7.3, 6.7}), // beta value + ::testing::Values('r'), // storage of c + ::testing::Values(bli_zgemmsup_rv_zen_asm_2x4), // zgemm_sup kernel + ::testing::Values('t'), // transa + ::testing::Values('n'), // transb + ::testing::Values(false, true) // is_memory_test + ), + ::zgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_zgemmsup_rv_zen_asm_1x4_row_stored_c, + zgemmGenericSUP, + ::testing::Combine( + ::testing::Values(gtint_t(1)), // values of m + ::testing::Values(gtint_t(4)), // values of n + ::testing::Range(gtint_t(0), gtint_t(18), 1), // values of k + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, 5.5}, dcomplex{3.5, 4.5}), // alpha value + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, 5.4}, dcomplex{-7.3, 6.7}), // beta value + ::testing::Values('r'), // storage of c + ::testing::Values(bli_zgemmsup_rv_zen_asm_1x4), // zgemm_sup kernel + ::testing::Values('t'), // transa + ::testing::Values('n'), // transb + ::testing::Values(false, true) // is_memory_test + ), + ::zgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_zgemmsup_rv_zen_asm_3x2m_row_stored_c, + zgemmGenericSUP, + ::testing::Combine( + ::testing::Range(gtint_t(1), gtint_t(20), 1), // values of m + ::testing::Values(gtint_t(2)), // values of n + ::testing::Range(gtint_t(0), gtint_t(13), 1), // values of k + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, 2}, dcomplex{3.5, 4.5}), // alpha value + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, 9}, dcomplex{-7.3, 6.7}), // beta value + ::testing::Values('r'), // storage of c + ::testing::Values(bli_zgemmsup_rv_zen_asm_3x2m), // zgemm_sup kernel + ::testing::Values('t'), // transa + ::testing::Values('n'), // transb + ::testing::Values(false, true) // is_memory_test + ), + ::zgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_zgemmsup_rv_zen_asm_3x2_row_stored_c, + zgemmGenericSUP, + ::testing::Combine( + ::testing::Values(gtint_t(3)), // values of m + ::testing::Values(gtint_t(2)), // values of n + ::testing::Range(gtint_t(0), gtint_t(5), 1), // values of k + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0,15.0}, dcomplex{3.5, 4.5}), // alpha value + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, 2.3}, dcomplex{-7.3, 6.7}), // beta value + ::testing::Values('r'), // storage of c + ::testing::Values(bli_zgemmsup_rv_zen_asm_3x2), // zgemm_sup kernel + ::testing::Values('t'), // transa + ::testing::Values('n'), // transb + ::testing::Values(false, true) // is_memory_test + ), + ::zgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_zgemmsup_rv_zen_asm_2x2_row_stored_c, + zgemmGenericSUP, + ::testing::Combine( + ::testing::Values(gtint_t(2)), // values of m + ::testing::Values(gtint_t(2)), // values of n + ::testing::Range(gtint_t(0), gtint_t(12), 1), // values of k + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, 12}, dcomplex{3.5, 4.5}), // alpha value + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, 13}, dcomplex{-7.3, 6.7}), // beta value + ::testing::Values('r'), // storage of c + ::testing::Values(bli_zgemmsup_rv_zen_asm_2x2), // zgemm_sup kernel + ::testing::Values('t'), // transa + ::testing::Values('n'), // transb + ::testing::Values(false, true) // is_memory_test + ), + ::zgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_zgemmsup_rv_zen_asm_1x2_row_stored_c, + zgemmGenericSUP, + ::testing::Combine( + ::testing::Values(gtint_t(1)), // values of m + ::testing::Values(gtint_t(2)), // values of n + ::testing::Range(gtint_t(0), gtint_t(8), 1), // values of k + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, 6}, dcomplex{3.5, 4.5}), // alpha value + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, 3}, dcomplex{-7.3, 6.7}), // beta value + ::testing::Values('r'), // storage of c + ::testing::Values(bli_zgemmsup_rv_zen_asm_1x2), // zgemm_sup kernel + ::testing::Values('t'), // transa + ::testing::Values('n'), // transb + ::testing::Values(false, true) // is_memory_test + ), + ::zgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_zgemmsup_rv_zen_asm_3x4m_col_stored_c, + zgemmGenericSUP, + ::testing::Combine( + ::testing::Range(gtint_t(1), gtint_t(14), 1), // values of m + ::testing::Range(gtint_t(1), gtint_t(5), 1), // values of n + ::testing::Range(gtint_t(0), gtint_t(22), 1), // values of k + ::testing::Values(dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, -15.0}, dcomplex{3.5, 4.5}), // alpha value + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, 1.9}, dcomplex{-7.3, 6.7}), // beta value + ::testing::Values('c'), // storage of c + ::testing::Values(bli_zgemmsup_rv_zen_asm_3x4m), // zgemm_sup kernel + ::testing::Values('n'), // transa + ::testing::Values('t'), // transb + ::testing::Values(false, true) // is_memory_test + ), + ::zgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_zgemmsup_rv_zen_asm_3x2m_col_stored_c, + zgemmGenericSUP, + ::testing::Combine( + ::testing::Range(gtint_t(1), gtint_t(14), 1), // values of m + ::testing::Values(gtint_t(2)), // values of n + ::testing::Range(gtint_t(0), gtint_t(20), 1), // values of k + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, 1.9}, dcomplex{3.5, 4.5}), // alpha value + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, 3.9}, dcomplex{-7.3, 6.7}), // beta value + ::testing::Values('c'), // storage of c + ::testing::Values(bli_zgemmsup_rv_zen_asm_3x2m), // zgemm_sup kernel + ::testing::Values('n'), // transa + ::testing::Values('t'), // transb + ::testing::Values(false, true) // is_memory_test + ), + ::zgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_zgemmsup_rv_zen_asm_3x2_col_stored_c, + zgemmGenericSUP, + ::testing::Combine( + ::testing::Values(gtint_t(3)), // values of m + ::testing::Values(gtint_t(2)), // values of n + ::testing::Range(gtint_t(0), gtint_t(19), 1), // values of k + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, 2.3}, dcomplex{3.5, 4.5}), // alpha value + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, 1.4}, dcomplex{-7.3, 6.7}), // beta value + ::testing::Values('c'), // storage of c + ::testing::Values(bli_zgemmsup_rv_zen_asm_3x2), // zgemm_sup kernel + ::testing::Values('n'), // transa + ::testing::Values('t'), // transb + ::testing::Values(false, true) // is_memory_test + ), + ::zgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_zgemmsup_rv_zen_asm_2x4_col_stored_c, + zgemmGenericSUP, + ::testing::Combine( + ::testing::Values(gtint_t(2)), // values of m + ::testing::Values(gtint_t(4)), // values of n + ::testing::Range(gtint_t(0), gtint_t(7), 1), // values of k + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, 19.9}, dcomplex{3.5, 4.5}), // alpha value + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, 1.99}, dcomplex{-7.3, 6.7}), // beta value + ::testing::Values('c'), // storage of c + ::testing::Values(bli_zgemmsup_rv_zen_asm_2x4), // zgemm_sup kernel + ::testing::Values('n'), // transa + ::testing::Values('t'), // transb + ::testing::Values(false, true) // is_memory_test + ), + ::zgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_zgemmsup_rv_zen_asm_1x4_col_stored_c, + zgemmGenericSUP, + ::testing::Combine( + ::testing::Values(gtint_t(1)), // values of m + ::testing::Values(gtint_t(4)), // values of n + ::testing::Range(gtint_t(0), gtint_t(8), 1), // values of k + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, 1.9}, dcomplex{3.5, 4.5}), // alpha value + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0},dcomplex{0.0, 1.9}, dcomplex{-7.3, 6.7}), // beta value + ::testing::Values('c'), // storage of c + ::testing::Values(bli_zgemmsup_rv_zen_asm_1x4), // zgemm_sup kernel + ::testing::Values('n'), // transa + ::testing::Values('t'), // transb + ::testing::Values(false, true) // is_memory_test + ), + ::zgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_zgemmsup_rv_zen_asm_2x2_col_stored_c, + zgemmGenericSUP, + ::testing::Combine( + ::testing::Values(gtint_t(2)), // values of m + ::testing::Values(gtint_t(2)), // values of n + ::testing::Range(gtint_t(0), gtint_t(17), 1), // values of k + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, -1.5}, dcomplex{3.5, 4.5}), // alpha value + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, -1.3}, dcomplex{-7.3, 6.7}), // beta value + ::testing::Values('c'), // storage of c + ::testing::Values(bli_zgemmsup_rv_zen_asm_2x2), // zgemm_sup kernel + ::testing::Values('n'), // transa + ::testing::Values('t'), // transb + ::testing::Values(false, true) // is_memory_test + ), + ::zgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_zgemmsup_rv_zen_asm_1x2_col_stored_c, + zgemmGenericSUP, + ::testing::Combine( + ::testing::Values(gtint_t(1)), // values of m + ::testing::Values(gtint_t(2)), // values of n + ::testing::Range(gtint_t(0), gtint_t(8), 1), // values of k + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, 1.9}, dcomplex{3.5, 4.5}), // alpha value + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, 2.3}, dcomplex{-7.3, 6.7}), // beta value + ::testing::Values('c'), // storage of c + ::testing::Values(bli_zgemmsup_rv_zen_asm_1x2), // zgemm_sup kernel + ::testing::Values('n'), // transa + ::testing::Values('t'), // transb + ::testing::Values(false, true) // is_memory_test + ), + ::zgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_zgemmsup_rd_zen_asm_3x4m_row_stored_c, + zgemmGenericSUP, + ::testing::Combine( + ::testing::Range(gtint_t(1), gtint_t(12), 1), // values of m + ::testing::Range(gtint_t(1), gtint_t(5), 1), // values of n + ::testing::Range(gtint_t(0), gtint_t(17), 1), // values of k + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, 1.5}, dcomplex{3.5, 4.5}), // alpha value + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, 2.9}, dcomplex{-7.3, 6.7}), // beta value + ::testing::Values('r'), // storage of c + ::testing::Values(bli_zgemmsup_rd_zen_asm_3x4m), // zgemm_sup kernel + ::testing::Values('n'), // transa + ::testing::Values('t'), // transb + ::testing::Values(false, true) // is_memory_test + ), + ::zgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_zgemmsup_rd_zen_asm_3x2m_row_stored_c, + zgemmGenericSUP, + ::testing::Combine( + ::testing::Range(gtint_t(1), gtint_t(11), 1), // values of m + ::testing::Values(gtint_t(2)), // values of n + ::testing::Range(gtint_t(0), gtint_t(17), 1), // values of k + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, -1.9}, dcomplex{3.5, 4.5}), // alpha value + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, 1.19}, dcomplex{-7.3, 6.7}), // beta value + ::testing::Values('r'), // storage of c + ::testing::Values(bli_zgemmsup_rd_zen_asm_3x2m), // zgemm_sup kernel + ::testing::Values('n'), // transa + ::testing::Values('t'), // transb + ::testing::Values(false, true) // is_memory_test + ), + ::zgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_zgemmsup_rd_zen_asm_3x4n_row_stored_c, + zgemmGenericSUP, + ::testing::Combine( + ::testing::Range(gtint_t(1), gtint_t(4), 1), // values of m + ::testing::Range(gtint_t(1), gtint_t(10), 1), // values of n + ::testing::Range(gtint_t(0), gtint_t(16),1), // values of k + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, 1.0}, dcomplex{3.5, 4.5}), // alpha value + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, 2.9}, dcomplex{-7.3, 6.7}), // beta value + ::testing::Values('r'), // storage of c + ::testing::Values(bli_zgemmsup_rd_zen_asm_3x4n), // zgemm_sup kernel + ::testing::Values('n'), // transa + ::testing::Values('t'), // transb + ::testing::Values(false, true) // is_memory_test + ), + ::zgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_zgemmsup_rd_zen_asm_2x4n_row_stored_c, + zgemmGenericSUP, + ::testing::Combine( + ::testing::Values(gtint_t(2)), // values of m + ::testing::Range(gtint_t(1), gtint_t(12), 1), // values of n + ::testing::Range(gtint_t(0), gtint_t(14), 1), // values of k + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, -1.9}, dcomplex{3.5, 4.5}), // alpha value + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, 1.23}, dcomplex{-7.3, 6.7}), // beta value + ::testing::Values('r'), // storage of c + ::testing::Values(bli_zgemmsup_rd_zen_asm_2x4n), // zgemm_sup kernel + ::testing::Values('n'), // transa + ::testing::Values('t'), // transb + ::testing::Values(false, true) // is_memory_test + ), + ::zgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_zgemmsup_rd_zen_asm_2x4_row_stored_c, + zgemmGenericSUP, + ::testing::Combine( + ::testing::Values(gtint_t(2)), // values of m + ::testing::Values(gtint_t(4)), // values of n + ::testing::Range(gtint_t(0), gtint_t(14), 1), // values of k + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, 1.34}, dcomplex{3.5, 4.5}), // alpha value + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, 2.9}, dcomplex{-7.3, 6.7}), // beta value + ::testing::Values('r'), // storage of c + ::testing::Values(bli_zgemmsup_rd_zen_asm_2x4), // zgemm_sup kernel + ::testing::Values('n'), // transa + ::testing::Values('t'), // transb + ::testing::Values(false, true) // is_memory_test + ), + ::zgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_zgemmsup_rd_zen_asm_1x4_row_stored_c, + zgemmGenericSUP, + ::testing::Combine( + ::testing::Values(gtint_t(1)), // values of m + ::testing::Values(gtint_t(4)), // values of n + ::testing::Range(gtint_t(0), gtint_t(9), 1), // values of k + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, 1.56}, dcomplex{3.5, 4.5}), // alpha value + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, 21.9}, dcomplex{-7.3, 6.7}), // beta value + ::testing::Values('r'), // storage of c + ::testing::Values(bli_zgemmsup_rd_zen_asm_1x4), // zgemm_sup kernel + ::testing::Values('n'), // transa + ::testing::Values('t'), // transb + ::testing::Values(false, true) // is_memory_test + ), + ::zgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_zgemmsup_rd_zen_asm_1x2_row_stored_c, + zgemmGenericSUP, + ::testing::Combine( + ::testing::Values(gtint_t(1)), // values of m + ::testing::Values(gtint_t(2)), // values of n + ::testing::Range(gtint_t(0), gtint_t(8), 1), // values of k + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, 1.99}, dcomplex{3.5, 4.5}), // alpha value + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, -21.9}, dcomplex{-7.3, 6.7}), // beta value + ::testing::Values('r'), // storage of c + ::testing::Values(bli_zgemmsup_rd_zen_asm_1x2), // zgemm_sup kernel + ::testing::Values('n'), // transa + ::testing::Values('t'), // transb + ::testing::Values(false, true) // is_memory_test + ), + ::zgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_zgemmsup_rd_zen_asm_2x2_row_stored_c, + zgemmGenericSUP, + ::testing::Combine( + ::testing::Values(gtint_t(2)), // values of m + ::testing::Values(gtint_t(2)), // values of n + ::testing::Range(gtint_t(0), gtint_t(10), 1), // values of k + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, 91.9}, dcomplex{3.5, 4.5}), // alpha value + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, -2.3}, dcomplex{-7.3, 6.7}), // beta value + ::testing::Values('r'), // storage of c + ::testing::Values(bli_zgemmsup_rd_zen_asm_2x2), // zgemm_sup kernel + ::testing::Values('n'), // transa + ::testing::Values('t'), // transb + ::testing::Values(false, true) // is_memory_test + ), + ::zgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_zgemmsup_rv_zen_asm_3x4n_col_stored_c, + zgemmGenericSUP, + ::testing::Combine( + ::testing::Range(gtint_t(1), gtint_t(4), 1), // values of m + ::testing::Range(gtint_t(1), gtint_t(15), 1), // values of n + ::testing::Range(gtint_t(0), gtint_t(12), 1), // values of k + ::testing::Values(dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, -2}, dcomplex{3.5, 4.5}), // alpha value + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, -3}, dcomplex{-7.3, 6.7}), // beta value + ::testing::Values('c'), // storage of c + ::testing::Values(bli_zgemmsup_rv_zen_asm_3x4n), // zgemm_sup kernel + ::testing::Values('n'), // transa + ::testing::Values('t'), // transb + ::testing::Values(false, true) // is_memory_test + ), + ::zgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_zgemmsup_rv_zen_asm_2x4n_col_stored_c, + zgemmGenericSUP, + ::testing::Combine( + ::testing::Values(gtint_t(2)), // values of m + ::testing::Range(gtint_t(1), gtint_t(13), 1), // values of n + ::testing::Range(gtint_t(0), gtint_t(20), 1), // values of k + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, 8.9}, dcomplex{3.5, 4.5}), // alpha value + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, -1.9}, dcomplex{-7.3, 6.7}), // beta value + ::testing::Values('c'), // storage of c + ::testing::Values(bli_zgemmsup_rv_zen_asm_2x4n), // zgemm_sup kernel + ::testing::Values('n'), // transa + ::testing::Values('t'), // transb + ::testing::Values(false, true) // is_memory_test + ), + ::zgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_zgemmsup_rv_zen_asm_1x4n_col_stored_c, + zgemmGenericSUP, + ::testing::Combine( + ::testing::Values(gtint_t(1)), // values of m + ::testing::Range(gtint_t(1), gtint_t(8), 1), // values of n + ::testing::Range(gtint_t(0), gtint_t(20), 1), // values of k + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, -1.3}, dcomplex{3.5, 4.5}), // alpha value + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, 5.6}, dcomplex{-7.3, 6.7}), // beta value + ::testing::Values('c'), // storage of c + ::testing::Values(bli_zgemmsup_rv_zen_asm_1x4n), // zgemm_sup kernel + ::testing::Values('n'), // transa + ::testing::Values('t'), // transb + ::testing::Values(false, true) // is_memory_test + ), + ::zgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_zgemmsup_rv_zen_asm_3x4n_row_stored_c, + zgemmGenericSUP, + ::testing::Combine( + ::testing::Range(gtint_t(1), gtint_t(4), 1), // values of m + ::testing::Range(gtint_t(1), gtint_t(18), 1), // values of n + ::testing::Range(gtint_t(0), gtint_t(20), 1), // values of k + ::testing::Values(dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0,.0}, dcomplex{0.0, 2.9}, dcomplex{3.5, 4.5}), // alpha value + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, 1.3}, dcomplex{-7.3, 6.7}), // beta value + ::testing::Values('r'), // storage of c + ::testing::Values(bli_zgemmsup_rv_zen_asm_3x4n), // zgemm_sup kernel + ::testing::Values('t'), // transa + ::testing::Values('n'), // transb + ::testing::Values(false, true) // is_memory_test + ), + ::zgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_zgemmsup_rv_zen_asm_2x4n_row_stored_c, + zgemmGenericSUP, + ::testing::Combine( + ::testing::Values(gtint_t(2)), // values of m + ::testing::Range(gtint_t(1), gtint_t(6), 1), // values of n + ::testing::Range(gtint_t(0), gtint_t(20), 1), // values of k + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, -5.6}, dcomplex{3.5, 4.5}), // alpha value + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, 1.9}, dcomplex{-7.3, 6.7}), // beta value + ::testing::Values('r'), // storage of c + ::testing::Values(bli_zgemmsup_rv_zen_asm_2x4n), // zgemm_sup kernel + ::testing::Values('t'), // transa + ::testing::Values('n'), // transb + ::testing::Values(false, true) // is_memory_test + ), + ::zgemmGenericSUPPrint() + ); + +INSTANTIATE_TEST_SUITE_P ( + bli_zgemmsup_rv_zen_asm_1x4n_row_stored_c, + zgemmGenericSUP, + ::testing::Combine( + ::testing::Values(gtint_t(1)), // values of m + ::testing::Range(gtint_t(1), gtint_t(9), 1), // values of n + ::testing::Range(gtint_t(0), gtint_t(20), 1), // values of k + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, -1.9}, dcomplex{3.5, 4.5}), // alpha value + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, -1.3}, dcomplex{-7.3, 6.7}), // beta value + ::testing::Values('r'), // storage of c + ::testing::Values(bli_zgemmsup_rv_zen_asm_1x4n), // zgemm_sup kernel + ::testing::Values('t'), // transa + ::testing::Values('n'), // transb + ::testing::Values(false, true) // is_memory_test + ), + ::zgemmGenericSUPPrint() + ); +#endif + +#if defined(BLIS_KERNELS_ZEN4) && defined(GTEST_AVX512) + INSTANTIATE_TEST_SUITE_P ( + bli_zgemmsup_cv_zen4_asm_12x4m_col_stored_c, + zgemmGenericSUP, + ::testing::Combine( + ::testing::Range(gtint_t(1), gtint_t(28), 1), // values of m + ::testing::Range(gtint_t(1), gtint_t(5), 1), // values of n + ::testing::Range(gtint_t(0), gtint_t(19), 1), // values of k + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, -8}, dcomplex{3.5, 4.5}), // alpha value + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, -9}, dcomplex{-7.3, 6.7}), // beta value + ::testing::Values('c'), // storage of c + ::testing::Values(bli_zgemmsup_cv_zen4_asm_12x4m), // zgemm_sup kernel + ::testing::Values('n'), // transa + ::testing::Values('n'), // transb + ::testing::Values(false, true) // is_memory_test + ), + ::zgemmGenericSUPPrint() + ); + + INSTANTIATE_TEST_SUITE_P ( + bli_zgemmsup_cv_zen4_asm_12x3m_col_stored_c, + zgemmGenericSUP, + ::testing::Combine( + ::testing::Range(gtint_t(1), gtint_t(25), 1), // values of m + ::testing::Values(gtint_t(3)), // values of n + ::testing::Range(gtint_t(0), gtint_t(10), 1), // values of k + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, 1.9}, dcomplex{3.5, 4.5}), // alpha value + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, 9}, dcomplex{-7.3, 6.7}), // beta value + ::testing::Values('c'), // storage of c + ::testing::Values(bli_zgemmsup_cv_zen4_asm_12x3m), // zgemm_sup kernel + ::testing::Values('n'), // transa + ::testing::Values('n'), // transb + ::testing::Values(false, true) // is_memory_test + ), + ::zgemmGenericSUPPrint() + ); + + INSTANTIATE_TEST_SUITE_P ( + bli_zgemmsup_cv_zen4_asm_12x2m_col_stored_c, + zgemmGenericSUP, + ::testing::Combine( + ::testing::Range(gtint_t(1), gtint_t(20), 1), // values of m + ::testing::Values(gtint_t(2)), // values of n + ::testing::Range(gtint_t(0), gtint_t(13), 1), // values of k + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, -0.9}, dcomplex{3.5, 4.5}), // alpha value + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, -21.9}, dcomplex{-7.3, 6.7}), // beta value + ::testing::Values('c'), // storage of c + ::testing::Values(bli_zgemmsup_cv_zen4_asm_12x2m), // zgemm_sup kernel + ::testing::Values('n'), // transa + ::testing::Values('n'), // transb + ::testing::Values(false, true) // is_memory_test + ), + ::zgemmGenericSUPPrint() + ); + + INSTANTIATE_TEST_SUITE_P ( + bli_zgemmsup_cv_zen4_asm_12x1m_col_stored_c, + zgemmGenericSUP, + ::testing::Combine( + ::testing::Range(gtint_t(1), gtint_t(25), 1), // values of m + ::testing::Values(gtint_t(1)), // values of n + ::testing::Range(gtint_t(0), gtint_t(22), 1), // values of k + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, -31.9}, dcomplex{3.5, 4.5}), // alpha value + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, 1.4}, dcomplex{-7.3, 6.7}), // beta value + ::testing::Values('c'), // storage of c + ::testing::Values(bli_zgemmsup_cv_zen4_asm_12x1m), // zgemm_sup kernel + ::testing::Values('n'), // transa + ::testing::Values('n'), // transb + ::testing::Values(false, true) // is_memory_test + ), + ::zgemmGenericSUPPrint() + ); + + INSTANTIATE_TEST_SUITE_P ( + bli_zgemmsup_cv_zen4_asm_8x4_col_stored_c, + zgemmGenericSUP, + ::testing::Combine( + ::testing::Values(gtint_t(8)), // values of m + ::testing::Values(gtint_t(4)), // values of n + ::testing::Range(gtint_t(0), gtint_t(17), 1), // values of k + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, 9}, dcomplex{3.5, 4.5}), // alpha value + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, 8}, dcomplex{-7.3, 6.7}), // beta value + ::testing::Values('c'), // storage of c + ::testing::Values(bli_zgemmsup_cv_zen4_asm_8x4), // zgemm_sup kernel + ::testing::Values('n'), // transa + ::testing::Values('n'), // transb + ::testing::Values(false, true) // is_memory_test + ), + ::zgemmGenericSUPPrint() + ); + + INSTANTIATE_TEST_SUITE_P ( + bli_zgemmsup_cv_zen4_asm_8x3_col_stored_c, + zgemmGenericSUP, + ::testing::Combine( + ::testing::Values(gtint_t(8)), // values of m + ::testing::Values(gtint_t(3)), // values of n + ::testing::Range(gtint_t(0), gtint_t(16), 1), // values of k + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, 1.2}, dcomplex{3.5, 4.5}), // alpha value + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, -1.8}, dcomplex{-7.3, 6.7}), // beta value + ::testing::Values('c'), // storage of c + ::testing::Values(bli_zgemmsup_cv_zen4_asm_8x3), // zgemm_sup kernel + ::testing::Values('n'), // transa + ::testing::Values('n'), // transb + ::testing::Values(false, true) // is_memory_test + ), + ::zgemmGenericSUPPrint() + ); + + INSTANTIATE_TEST_SUITE_P ( + bli_zgemmsup_cv_zen4_asm_8x2_col_stored_c, + zgemmGenericSUP, + ::testing::Combine( + ::testing::Values(gtint_t(8)), // values of m + ::testing::Values(gtint_t(2)), // values of n + ::testing::Range(gtint_t(0), gtint_t(14), 1), // values of k + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, -1}, dcomplex{3.5, 4.5}), // alpha value + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, 9}, dcomplex{-7.3, 6.7}), // beta value + ::testing::Values('c'), // storage of c + ::testing::Values(bli_zgemmsup_cv_zen4_asm_8x2), // zgemm_sup kernel + ::testing::Values('n'), // transa + ::testing::Values('n'), // transb + ::testing::Values(false, true) // is_memory_test + ), + ::zgemmGenericSUPPrint() + ); + + INSTANTIATE_TEST_SUITE_P ( + bli_zgemmsup_cv_zen4_asm_8x1_col_stored_c, + zgemmGenericSUP, + ::testing::Combine( + ::testing::Values(gtint_t(8)), // values of m + ::testing::Values(gtint_t(1)), // values of n + ::testing::Range(gtint_t(0), gtint_t(10), 1), // values of k + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, -9}, dcomplex{3.5, 4.5}), // alpha value + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, -2}, dcomplex{-7.3, 6.7}), // beta value + ::testing::Values('c'), // storage of c + ::testing::Values(bli_zgemmsup_cv_zen4_asm_8x1), // zgemm_sup kernel + ::testing::Values('n'), // transa + ::testing::Values('n'), // transb + ::testing::Values(false, true) // is_memory_test + ), + ::zgemmGenericSUPPrint() + ); + + INSTANTIATE_TEST_SUITE_P ( + bli_zgemmsup_cv_zen4_asm_4x4_col_stored_c, + zgemmGenericSUP, + ::testing::Combine( + ::testing::Values(gtint_t(4)), // values of m + ::testing::Values(gtint_t(4)), // values of n + ::testing::Range(gtint_t(0), gtint_t(9), 1), // values of k + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, 3}, dcomplex{3.5, 4.5}), // alpha value + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, 9}, dcomplex{-7.3, 6.7}), // beta value + ::testing::Values('c'), // storage of c + ::testing::Values(bli_zgemmsup_cv_zen4_asm_4x4), // zgemm_sup kernel + ::testing::Values('n'), // transa + ::testing::Values('n'), // transb + ::testing::Values(false, true) // is_memory_test + ), + ::zgemmGenericSUPPrint() + ); + + INSTANTIATE_TEST_SUITE_P ( + bli_zgemmsup_cv_zen4_asm_4x3_col_stored_c, + zgemmGenericSUP, + ::testing::Combine( + ::testing::Values(gtint_t(4)), // values of m + ::testing::Values(gtint_t(3)), // values of n + ::testing::Range(gtint_t(0), gtint_t(19), 1), // values of k + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, -1.9}, dcomplex{3.5, 4.5}), // alpha value + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, 1.5}, dcomplex{-7.3, 6.7}), // beta value + ::testing::Values('c'), // storage of c + ::testing::Values(bli_zgemmsup_cv_zen4_asm_4x3), // zgemm_sup kernel + ::testing::Values('n'), // transa + ::testing::Values('n'), // transb + ::testing::Values(false, true) // is_memory_test + ), + ::zgemmGenericSUPPrint() + ); + + INSTANTIATE_TEST_SUITE_P ( + bli_zgemmsup_cv_zen4_asm_4x2_col_stored_c, + zgemmGenericSUP, + ::testing::Combine( + ::testing::Values(gtint_t(4)), // values of m + ::testing::Values(gtint_t(2)), // values of n + ::testing::Range(gtint_t(0), gtint_t(14), 1), // values of k + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, -19}, dcomplex{3.5, 4.5}), // alpha value + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, -9}, dcomplex{-7.3, 6.7}), // beta value + ::testing::Values('c'), // storage of c + ::testing::Values(bli_zgemmsup_cv_zen4_asm_4x2), // zgemm_sup kernel + ::testing::Values('n'), // transa + ::testing::Values('n'), // transb + ::testing::Values(false, true) // is_memory_test + ), + ::zgemmGenericSUPPrint() + ); + + INSTANTIATE_TEST_SUITE_P ( + bli_zgemmsup_cv_zen4_asm_4x1_col_stored_c, + zgemmGenericSUP, + ::testing::Combine( + ::testing::Values(gtint_t(4)), // values of m + ::testing::Values(gtint_t(1)), // values of n + ::testing::Range(gtint_t(0), gtint_t(12), 1), // values of k + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, -19}, dcomplex{3.5, 4.5}), // alpha value + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, 1}, dcomplex{-7.3, 6.7}), // beta value + ::testing::Values('c'), // storage of c + ::testing::Values(bli_zgemmsup_cv_zen4_asm_4x1), // zgemm_sup kernel + ::testing::Values('n'), // transa + ::testing::Values('n'), // transb + ::testing::Values(false, true) // is_memory_test + ), + ::zgemmGenericSUPPrint() + ); + + INSTANTIATE_TEST_SUITE_P ( + bli_zgemmsup_cv_zen4_asm_2x4_col_stored_c, + zgemmGenericSUP, + ::testing::Combine( + ::testing::Values(gtint_t(2)), // values of m + ::testing::Values(gtint_t(4)), // values of n + ::testing::Range(gtint_t(0), gtint_t(16), 1), // values of k + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, 1.9}, dcomplex{3.5, 4.5}), // alpha value + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, 1.8}, dcomplex{-7.3, 6.7}), // beta value + ::testing::Values('c'), // storage of c + ::testing::Values(bli_zgemmsup_cv_zen4_asm_2x4), // zgemm_sup kernel + ::testing::Values('n'), // transa + ::testing::Values('n'), // transb + ::testing::Values(false, true) // is_memory_test + ), + ::zgemmGenericSUPPrint() + ); + + INSTANTIATE_TEST_SUITE_P ( + bli_zgemmsup_cv_zen4_asm_2x3_col_stored_c, + zgemmGenericSUP, + ::testing::Combine( + ::testing::Values(gtint_t(2)), // values of m + ::testing::Values(gtint_t(3)), // values of n + ::testing::Range(gtint_t(0), gtint_t(5), 1), // values of k + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, 18}, dcomplex{3.5, 4.5}), // alpha value + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, 1}, dcomplex{-7.3, 6.7}), // beta value + ::testing::Values('c'), // storage of c + ::testing::Values(bli_zgemmsup_cv_zen4_asm_2x3), // zgemm_sup kernel + ::testing::Values('n'), // transa + ::testing::Values('n'), // transb + ::testing::Values(false, true) // is_memory_test + ), + ::zgemmGenericSUPPrint() + ); + + INSTANTIATE_TEST_SUITE_P ( + bli_zgemmsup_cv_zen4_asm_2x2_col_stored_c, + zgemmGenericSUP, + ::testing::Combine( + ::testing::Values(gtint_t(2)), // values of m + ::testing::Values(gtint_t(2)), // values of n + ::testing::Range(gtint_t(0), gtint_t(9), 1), // values of k + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, -19}, dcomplex{3.5, 4.5}), // alpha value + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, 9}, dcomplex{-7.3, 6.7}), // beta value + ::testing::Values('c'), // storage of c + ::testing::Values(bli_zgemmsup_cv_zen4_asm_2x2), // zgemm_sup kernel + ::testing::Values('n'), // transa + ::testing::Values('n'), // transb + ::testing::Values(false, true) // is_memory_test + ), + ::zgemmGenericSUPPrint() + ); + + INSTANTIATE_TEST_SUITE_P ( + bli_zgemmsup_cv_zen4_asm_2x1_col_stored_c, + zgemmGenericSUP, + ::testing::Combine( + ::testing::Values(gtint_t(2)), // values of m + ::testing::Values(gtint_t(1)), // values of n + ::testing::Range(gtint_t(0), gtint_t(15), 1), // values of k + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, 9}, dcomplex{3.5, 4.5}), // alpha value + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, -9}, dcomplex{-7.3, 6.7}), // beta value + ::testing::Values('c'), // storage of c + ::testing::Values(bli_zgemmsup_cv_zen4_asm_2x1), // zgemm_sup kernel + ::testing::Values('n'), // transa + ::testing::Values('n'), // transb + ::testing::Values(false, true) // is_memory_test + ), + ::zgemmGenericSUPPrint() + ); + + INSTANTIATE_TEST_SUITE_P ( + bli_zgemmsup_cv_zen4_asm_12x4m_row_stored_c, + zgemmGenericSUP, + ::testing::Combine( + ::testing::Range(gtint_t(1), gtint_t(13), 1), // values of m + ::testing::Range(gtint_t(1), gtint_t(5), 1), // values of n + ::testing::Range(gtint_t(0), gtint_t(14), 1), // values of k + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, 7}, dcomplex{3.5, 4.5}), // alpha value + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, -9}, dcomplex{-7.3, 6.7}), // beta value + ::testing::Values('r'), // storage of c + ::testing::Values(bli_zgemmsup_cv_zen4_asm_12x4m), // zgemm_sup kernel + ::testing::Values('t'), // transa + ::testing::Values('n'), // transb + ::testing::Values(false, true) // is_memory_test + ), + ::zgemmGenericSUPPrint() + ); + + INSTANTIATE_TEST_SUITE_P ( + bli_zgemmsup_cv_zen4_asm_12x3m_row_stored_c, + zgemmGenericSUP, + ::testing::Combine( + ::testing::Range(gtint_t(1), gtint_t(33), 1), // values of m + ::testing::Values(gtint_t(3)), // values of n + ::testing::Range(gtint_t(0), gtint_t(12), 1), // values of k + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, -9.7}, dcomplex{3.5, 4.5}), // alpha value + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, 1.2}, dcomplex{-7.3, 6.7}), // beta value + ::testing::Values('r'), // storage of c + ::testing::Values(bli_zgemmsup_cv_zen4_asm_12x3m), // zgemm_sup kernel + ::testing::Values('t'), // transa + ::testing::Values('n'), // transb + ::testing::Values(false, true) // is_memory_test + ), + ::zgemmGenericSUPPrint() + ); + + INSTANTIATE_TEST_SUITE_P ( + bli_zgemmsup_cv_zen4_asm_12x2m_row_stored_c, + zgemmGenericSUP, + ::testing::Combine( + ::testing::Range(gtint_t(1), gtint_t(21), 1), // values of m + ::testing::Values(gtint_t(2)), // values of n + ::testing::Range(gtint_t(0), gtint_t(12), 1), // values of k + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, 1.4}, dcomplex{3.5, 4.5}), // alpha value + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, 8.9}, dcomplex{-7.3, 6.7}), // beta value + ::testing::Values('r'), // storage of c + ::testing::Values(bli_zgemmsup_cv_zen4_asm_12x2m), // zgemm_sup kernel + ::testing::Values('t'), // transa + ::testing::Values('n'), // transb + ::testing::Values(false, true) // is_memory_test + ), + ::zgemmGenericSUPPrint() + ); + + INSTANTIATE_TEST_SUITE_P ( + bli_zgemmsup_cv_zen4_asm_12x1m_row_stored_c, + zgemmGenericSUP, + ::testing::Combine( + ::testing::Range(gtint_t(1), gtint_t(20), 1), // values of m + ::testing::Values(gtint_t(1)), // values of n + ::testing::Range(gtint_t(0), gtint_t(10), 1), // values of k + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, 9}, dcomplex{3.5, 4.5}), // alpha value + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, 19}, dcomplex{-7.3, 6.7}), // beta value + ::testing::Values('r'), // storage of c + ::testing::Values(bli_zgemmsup_cv_zen4_asm_12x1m), // zgemm_sup kernel + ::testing::Values('t'), // transa + ::testing::Values('n'), // transb + ::testing::Values(false, true) // is_memory_test + ), + ::zgemmGenericSUPPrint() + ); +#endif + +/*******************************************************/ +/* Native Kernel testing */ +/*******************************************************/ +class zgemmGenericNat : + public ::testing::TestWithParam> {}; + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(zgemmGenericNat); +TEST_P( zgemmGenericNat, MicroKernelTest) +{ + using T = dcomplex; + gtint_t k = std::get<0>(GetParam()); // dimension k + T alpha = std::get<1>(GetParam()); // alpha + T beta = std::get<2>(GetParam()); // beta + char storageC = std::get<3>(GetParam()); // indicates storage of all matrix operands + // Fix m and n to MR and NR respectively. + gtint_t m = std::get<4>(GetParam()); // m + gtint_t n = std::get<5>(GetParam()); // n + zgemm_ukr_ft kern_ptr = std::get<6>(GetParam()); // pointer to the gemm kernel + bool is_memory_test = std::get<7>(GetParam()); // is_memory_test + + // Set the threshold for the errors: + // Check gtestsuite gemm.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (m == 0 || n == 0) + thresh = 0.0; + else if ((alpha == testinghelpers::ZERO() || k == 0) && (beta == testinghelpers::ZERO() || + beta == testinghelpers::ONE())) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO()) + thresh = testinghelpers::getEpsilon(); + else + thresh = (3*k+1)*testinghelpers::getEpsilon(); + + test_gemmnat_ukr(storageC, m, n, k, alpha, beta, thresh, kern_ptr, is_memory_test); + +}// end of function + +class zgemmGenericNatPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + + gtint_t k = std::get<0>(str.param); + dcomplex alpha = std::get<1>(str.param); + dcomplex beta = std::get<2>(str.param); + char storageC = std::get<3>(str.param); + bool is_memory_test = std::get<7>(str.param); + + std::string str_name; + str_name += "_stor_" + std::string(&storageC, 1); + str_name += "_k_" + std::to_string(k); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + str_name += "_beta_" + testinghelpers::get_value_string(beta); + str_name += ( is_memory_test ) ? "_mem_test_enabled" : "_mem_test_disabled"; + + return str_name; + } +}; + +#if defined(BLIS_KERNELS_ZEN4) && defined(GTEST_AVX512) +INSTANTIATE_TEST_SUITE_P ( + bli_zgemm_zen4_asm_12x4, + zgemmGenericNat, + ::testing::Combine( //Failure observed for this case zgemmnat_ukr_1_a0pi2_bm7pi6_r + ::testing::Range(gtint_t(1), gtint_t(15), 1), // values of k + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, 2.3}, dcomplex{3.5, 4.5}), // alpha value + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, 1.0}, dcomplex{-3, 6.7}), // beta value + ::testing::Values('r', 'c'), // storage + ::testing::Values(12), // values of m + ::testing::Values(4), // values of n + ::testing::Values(bli_zgemm_zen4_asm_12x4), // zgemm_nat kernel + ::testing::Values(false, true) // is_memory_test + ), + ::zgemmGenericNatPrint() +); + +INSTANTIATE_TEST_SUITE_P ( + bli_zgemm_zen4_asm_12x4_k0, + zgemmGenericNat, + ::testing::Combine( //Failure observed for this case zgemmnat_ukr_1_a0pi2_bm7pi6_r + ::testing::Values(gtint_t(0)), // values of k + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, 2.3}, dcomplex{3.5, 4.5}), // alpha value + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, 1.0}, dcomplex{-3, 6.7}), // beta value + ::testing::Values('r', 'c'), // storage + ::testing::Values(12), // values of m + ::testing::Values(4), // values of n + ::testing::Values(bli_zgemm_zen4_asm_12x4), // zgemm_nat kernel + ::testing::Values(false, true) // is_memory_test + ), + ::zgemmGenericNatPrint() +); + +/*Kernel reqired for trsm computation*/ +INSTANTIATE_TEST_SUITE_P ( + bli_zgemm_zen4_asm_4x12, + zgemmGenericNat, + ::testing::Combine( + ::testing::Range(gtint_t(1), gtint_t(10), 1), // values of k + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, 2.3}, dcomplex{3.5, 4.5}), // alpha value + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, 3.3}, dcomplex{-7.3, 6.7}), // beta value + ::testing::Values('r', 'c'), // storage + ::testing::Values(4), // values of m + ::testing::Values(12), // values of n + ::testing::Values(bli_zgemm_zen4_asm_4x12), // zgemm_nat kernel + ::testing::Values(false, true) // is_memory_test + ), + ::zgemmGenericNatPrint() +); + +INSTANTIATE_TEST_SUITE_P ( + bli_zgemm_zen4_asm_4x12_k0, + zgemmGenericNat, + ::testing::Combine( + ::testing::Values(gtint_t(0)), // values of k + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, 2.3}, dcomplex{3.5, 4.5}), // alpha value + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, 3.3}, dcomplex{-7.3, 6.7}), // beta value + ::testing::Values('r', 'c'), // storage + ::testing::Values(4), // values of m + ::testing::Values(12), // values of n + ::testing::Values(bli_zgemm_zen4_asm_4x12), // zgemm_nat kernel + ::testing::Values(false, true) // is_memory_test + ), + ::zgemmGenericNatPrint() +); +#endif + +#if defined(BLIS_KERNELS_HASWELL) && defined(GTEST_AVX2FMA3) +INSTANTIATE_TEST_SUITE_P ( + bli_zgemm_haswell_asm_3x4, + zgemmGenericNat, + ::testing::Combine( + ::testing::Range(gtint_t(1), gtint_t(20), 1), // values of k + ::testing::Values(dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, -0.2}, dcomplex{3.5, 4.5}), // alpha value + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, -2.1}, dcomplex{-7.3, 6.7}), // beta value + ::testing::Values('r', 'c'), // storage + ::testing::Values(3), // values of m + ::testing::Values(4), // values of n + ::testing::Values(bli_zgemm_haswell_asm_3x4), // zgemm_nat kernel + ::testing::Values(false, true) // is_memory_test + ), + ::zgemmGenericNatPrint() +); + +INSTANTIATE_TEST_SUITE_P ( + bli_zgemm_haswell_asm_3x4_k0, + zgemmGenericNat, + ::testing::Combine( + ::testing::Values(gtint_t(0)), // values of k + ::testing::Values(dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, -0.2}, dcomplex{3.5, 4.5}), // alpha value + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, -2.1}, dcomplex{-7.3, 6.7}), // beta value + ::testing::Values('r', 'c'), // storage + ::testing::Values(3), // values of m + ::testing::Values(4), // values of n + ::testing::Values(bli_zgemm_haswell_asm_3x4), // zgemm_nat kernel + ::testing::Values(false, true) // is_memory_test + ), + ::zgemmGenericNatPrint() +); +#endif + +#if defined(BLIS_KERNELS_ZEN) && defined(GTEST_AVX2FMA3) +/*Kernel reqired for trsm computation*/ +INSTANTIATE_TEST_SUITE_P ( + bli_zgemm_zen_asm_2x6, + zgemmGenericNat, + ::testing::Combine( + ::testing::Range(gtint_t(1), gtint_t(10), 1), // values of k + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, -0.3}, dcomplex{3.5, 4.5}), // alpha value + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, -2.0}, dcomplex{-7.3, 6.7}), // beta value + ::testing::Values('r', 'c'), // storage + ::testing::Values(2), // values of m + ::testing::Values(6), // values of n + ::testing::Values(bli_zgemm_zen_asm_2x6), // zgemm_nat kernel + ::testing::Values(false, true) // is_memory_test + ), + ::zgemmGenericNatPrint() +); + +INSTANTIATE_TEST_SUITE_P ( + bli_zgemm_zen_asm_2x6_k0, + zgemmGenericNat, + ::testing::Combine( + ::testing::Values(gtint_t(0)), // values of k + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{4.0, 0.0}, dcomplex{0.0, -0.3}, dcomplex{3.5, 4.5}), // alpha value + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, dcomplex{-5.0, 0.0}, dcomplex{0.0, -2.0}, dcomplex{-7.3, 6.7}), // beta value + ::testing::Values('r', 'c'), // storage + ::testing::Values(2), // values of m + ::testing::Values(6), // values of n + ::testing::Values(bli_zgemm_zen_asm_2x6), // zgemm_nat kernel + ::testing::Values(false, true) // is_memory_test + ), + ::zgemmGenericNatPrint() +); +#endif + +// Function pointer specific to zgemm kernel that handles +// special case where k=1. +typedef err_t (*zgemm_k1_kernel) + ( + dim_t m, + dim_t n, + dim_t k, + dcomplex* alpha, + dcomplex* a, const inc_t lda, + dcomplex* b, const inc_t ldb, + dcomplex* beta, + dcomplex* c, const inc_t ldc + ); + +// AOCL-BLAS has a set of kernels(AVX2 and AVX512) that separately handle +// k=1 cases for ZGEMM. Thus, we need to define a test-fixture class for testing +// these kernels +class zgemmUkrk1 : + public ::testing::TestWithParam> {}; // is_mem_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(zgemmUkrk1); + +TEST_P(zgemmUkrk1, FunctionalTest) +{ + using T = dcomplex; + gtint_t k = 1; + T alpha = std::get<0>(GetParam()); // alpha + T beta = std::get<1>(GetParam()); // beta + char storage = std::get<2>(GetParam()); // indicates storage of all matrix operands + gtint_t m = std::get<3>(GetParam()); // m + gtint_t n = std::get<4>(GetParam()); // n + zgemm_k1_kernel kern_ptr = std::get<5>(GetParam()); // kernel address + bool memory_test = std::get<6>(GetParam()); // is_mem_test + + // Call to the testing interface(specific to k=1 cases) + test_gemmk1_ukr(kern_ptr, m, n, k, storage, alpha, beta, memory_test); +} + +class zgemmUkrk1Print { +public: + std::string operator()( + testing::TestParamInfo> str) const { + gtint_t k = 1; + dcomplex alpha = std::get<0>(str.param); + dcomplex beta = std::get<1>(str.param); + char storage = std::get<2>(str.param); + gtint_t m = std::get<3>(str.param); + gtint_t n = std::get<4>(str.param); + bool memory_test = std::get<6>(str.param); + + std::string str_name; + str_name += "_k_" + std::to_string(k); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + str_name += "_beta_" + testinghelpers::get_value_string(beta); + str_name += "_m_" + std::to_string(m); + str_name += "_n_" + std::to_string(n); + str_name = str_name + "_" + storage; + str_name += ( memory_test ) ? "_mem_test_enabled" : "_mem_test_disabled"; + + return str_name; + } +}; + +#if defined(BLIS_KERNELS_ZEN4) && defined(GTEST_AVX512) +INSTANTIATE_TEST_SUITE_P ( + bli_zgemm_16x4_avx512_k1_nn, + zgemmUkrk1, + ::testing::Combine( + + ::testing::Values(dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, + dcomplex{0.0, 0.0}, dcomplex{1.2, 2.3}), // alpha value + ::testing::Values(dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, + dcomplex{0.0, 0.0}, dcomplex{1.2, 2.3}), // beta value + ::testing::Values('c'), // storage + ::testing::Range(gtint_t(1), gtint_t(33), 1), // values of m + ::testing::Range(gtint_t(1), gtint_t(9), 1), // values of n + ::testing::Values(bli_zgemm_16x4_avx512_k1_nn), + ::testing::Values(true, false) // memory test + ), + ::zgemmUkrk1Print() +); +#endif + +#if defined(BLIS_KERNELS_ZEN) && defined(GTEST_AVX2FMA3) +INSTANTIATE_TEST_SUITE_P ( + bli_zgemm_4x4_avx2_k1_nn, + zgemmUkrk1, + ::testing::Combine( + ::testing::Values(dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, + dcomplex{0.0, 0.0}, dcomplex{1.2, 2.3}), // alpha value + ::testing::Values(dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, + dcomplex{0.0, 0.0}, dcomplex{1.2, 2.3}), // beta value + ::testing::Values('c'), // storage + ::testing::Range(gtint_t(1), gtint_t(9), 1), // values of m + ::testing::Range(gtint_t(1), gtint_t(9), 1), // values of n + ::testing::Values(bli_zgemm_4x4_avx2_k1_nn), + ::testing::Values(true, false) // memory test + ), + ::zgemmUkrk1Print() +); +#endif diff --git a/gtestsuite/testsuite/ukr/nrm2/dnrm2_ukr.cpp b/gtestsuite/testsuite/ukr/nrm2/dnrm2_ukr.cpp new file mode 100644 index 0000000000..6aec5bdc46 --- /dev/null +++ b/gtestsuite/testsuite/ukr/nrm2/dnrm2_ukr.cpp @@ -0,0 +1,174 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_nrm2_ukr.h" + +using T = double; +using RT = typename testinghelpers::type_info::real_type; + +class dnrm2Generic : + public ::testing::TestWithParam, // Kernel pointer type + gtint_t, // n + gtint_t, // incx + bool>> {}; // is_memory_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(dnrm2Generic); + +TEST_P( dnrm2Generic, UKR ) +{ + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + nrm2_ker_ft ukr_fp = std::get<0>(GetParam()); + // vector length + gtint_t n = std::get<1>(GetParam()); + // stride size for x + gtint_t incx = std::get<2>(GetParam()); + // is_memory_test + bool is_memory_test = std::get<3>(GetParam()); + + // Set the threshold for the errors: + double thresh = std::sqrt(n)*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_nrm2_ukr( ukr_fp, n, incx, thresh, is_memory_test ); +} + +#if defined(BLIS_KERNELS_ZEN) && defined(GTEST_AVX2FMA3) +/* + Unit testing for functionality of bli_dnorm2fv_unb_var1_avx2 kernel. + The code structure for bli_dnorm2fv_unb_var1_avx2( ... ) is as follows : + For unit strides : + Main loop : In blocks of 8 --> L8 + Fringe loops : In blocks of 4 --> L4 + Element-wise loop --> LScalar + + For non-unit strides : A single loop, to process element wise. +*/ +// Unit testing with unit strides, across all loops. +INSTANTIATE_TEST_SUITE_P( + bli_dnorm2fv_unb_var1_avx2_unitStrides, + dnrm2Generic, + ::testing::Combine( + ::testing::Values(bli_dnorm2fv_unb_var1_avx2), // ukr function + // m size of vector + ::testing::Values(// Testing the loops standalone + gtint_t(8), // size n, for L8 + gtint_t(4), // L4 + gtint_t(3), // 3(LScalar) + gtint_t(40), // 5*L8 + gtint_t(43), // 5*L8 + 3(LScalar) + gtint_t(44), // 5*L8 + L4 + gtint_t(47)), // 5*L8 + L4 + 3(LScalar) + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(true, false) // is_memory_test + ), + ::nrm2UKRPrint() + ); + +// Unit testing with non-unit strides. +INSTANTIATE_TEST_SUITE_P( + bli_dnorm2fv_unb_var1_avx2_nonUnitStrides, + dnrm2Generic, + ::testing::Combine( + ::testing::Values(bli_dnorm2fv_unb_var1_avx2), // ukr function + // m size of vector + ::testing::Values(// Testing the loops standalone + gtint_t(25), // n, size of the vector + gtint_t(41), + gtint_t(17), + gtint_t(9)), + ::testing::Values(gtint_t(3), gtint_t(5)), // stride size for x + ::testing::Values(true, false) // is_memory_test + ), + ::nrm2UKRPrint() + ); +#endif + +#if defined(BLIS_KERNELS_ZEN4) && defined(GTEST_AVX512) +/* + Unit testing for functionality of bli_dnorm2fv_unb_var1_avx512 kernel. + The code structure for bli_dnorm2fv_unb_var1_avx512( ... ) is as follows : + For unit strides : + Main loop : In blocks of 32 --> L32 + Fringe loops : In blocks of 16 --> L16 + In blocks of 8 --> L8 + Masked loop --> LMask + + For non-unit strides : A single loop, to process element wise. +*/ +// Unit testing with unit strides, across all loops. +INSTANTIATE_TEST_SUITE_P( + bli_dnorm2fv_unb_var1_avx512_unitStrides, + dnrm2Generic, + ::testing::Combine( + ::testing::Values(bli_dnorm2fv_unb_var1_avx512), // ukr function + // m size of vector + ::testing::Values(// Testing the loops standalone + gtint_t(32), // size n, for L32 + gtint_t(16), // L16 + gtint_t(8), // L8 + gtint_t(7), // LMask + gtint_t(160), // 5*L32 + gtint_t(176), // 5*L32 + L16 + gtint_t(184), // 5*L32 + L16 + L8 + gtint_t(191)), // 5*L32 + L16 + L8 + 7(LMask) + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(true, false) // is_memory_test + ), + ::nrm2UKRPrint() + ); + +// Unit testing with non-unit strides. +INSTANTIATE_TEST_SUITE_P( + bli_dnorm2fv_unb_var1_avx512_nonUnitStrides, + dnrm2Generic, + ::testing::Combine( + ::testing::Values(bli_dnorm2fv_unb_var1_avx512), // ukr function + // m size of vector + ::testing::Values(// Testing the loops standalone + gtint_t(25), // n, size of the vector + gtint_t(41), + gtint_t(17), + gtint_t(9)), + ::testing::Values(gtint_t(3), gtint_t(5)), // stride size for x + ::testing::Values(true, false) // is_memory_test + ), + ::nrm2UKRPrint() + ); +#endif diff --git a/gtestsuite/testsuite/ukr/nrm2/dznrm2_ukr.cpp b/gtestsuite/testsuite/ukr/nrm2/dznrm2_ukr.cpp new file mode 100644 index 0000000000..4387ad415e --- /dev/null +++ b/gtestsuite/testsuite/ukr/nrm2/dznrm2_ukr.cpp @@ -0,0 +1,121 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_nrm2_ukr.h" + +using T = dcomplex; +using RT = typename testinghelpers::type_info::real_type; + +class dznrm2Generic : + public ::testing::TestWithParam, // Kernel pointer type + gtint_t, // n + gtint_t, // incx + bool>> {}; // is_memory_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(dznrm2Generic); + +TEST_P( dznrm2Generic, UKR ) +{ + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + nrm2_ker_ft ukr_fp = std::get<0>(GetParam()); + // vector length + gtint_t n = std::get<1>(GetParam()); + // stride size for x + gtint_t incx = std::get<2>(GetParam()); + // is_memory_test + bool is_memory_test = std::get<3>(GetParam()); + + // Set the threshold for the errors: + double thresh = std::sqrt(n)*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_nrm2_ukr( ukr_fp, n, incx, thresh, is_memory_test ); +} + +#if defined(BLIS_KERNELS_ZEN) && defined(GTEST_AVX2FMA3) +/* + Unit testing for functionality of bli_dznorm2fv_unb_var1_avx2 kernel. + The code structure for bli_dznorm2fv_unb_var1_avx2( ... ) is as follows : + For unit strides : + Main loop : In blocks of 4 --> L4 + Fringe loops : In blocks of 2 --> L2 + Element-wise loop --> LScalar + + For non-unit strides : A single loop, to process element wise. +*/ +// Unit testing with unit strides, across all loops. +INSTANTIATE_TEST_SUITE_P( + bli_dznorm2fv_unb_var1_avx2_unitStrides, + dznrm2Generic, + ::testing::Combine( + ::testing::Values(bli_dznorm2fv_unb_var1_avx2), // ukr function + // m size of vector + ::testing::Values(// Testing the loops standalone + gtint_t(4), // size n, for L4 + gtint_t(2), // L2 + gtint_t(1), // 1(LScalar) + gtint_t(40), // 10*L4 + gtint_t(41), // 10*L4 + 1(LScalar) + gtint_t(42), // 10*L4 + L2 + gtint_t(43)), // 10*L4 + L2 + 1(LScalar) + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(true, false) // is_memory_test + ), + ::nrm2UKRPrint() + ); + +// Unit testing with non-unit strides. +INSTANTIATE_TEST_SUITE_P( + bli_dznorm2fv_unb_var1_avx2_nonUnitStrides, + dznrm2Generic, + ::testing::Combine( + ::testing::Values(bli_dznorm2fv_unb_var1_avx2), // ukr function + // m size of vector + ::testing::Values(// Testing the loops standalone + gtint_t(25), // n, size of the vector + gtint_t(41), + gtint_t(17), + gtint_t(9)), + ::testing::Values(gtint_t(3), gtint_t(5)), // stride size for x + ::testing::Values(true, false) // is_memory_test + ), + ::nrm2UKRPrint() + ); +#endif diff --git a/gtestsuite/testsuite/ukr/nrm2/scnrm2_ukr.cpp b/gtestsuite/testsuite/ukr/nrm2/scnrm2_ukr.cpp new file mode 100644 index 0000000000..160b3a91c4 --- /dev/null +++ b/gtestsuite/testsuite/ukr/nrm2/scnrm2_ukr.cpp @@ -0,0 +1,122 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_nrm2_ukr.h" + +using T = scomplex; +using RT = typename testinghelpers::type_info::real_type; + +class scnrm2Generic : + public ::testing::TestWithParam, // Kernel pointer type + gtint_t, // n + gtint_t, // incx + bool>> {}; // is_memory_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(scnrm2Generic); + +TEST_P( scnrm2Generic, UKR ) +{ + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + nrm2_ker_ft ukr_fp = std::get<0>(GetParam()); + // vector length + gtint_t n = std::get<1>(GetParam()); + // stride size for x + gtint_t incx = std::get<2>(GetParam()); + // is_memory_test + bool is_memory_test = std::get<3>(GetParam()); + + // Set the threshold for the errors: + double thresh = std::sqrt(n)*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_nrm2_ukr( ukr_fp, n, incx, thresh, is_memory_test ); +} + +#if defined(BLIS_KERNELS_ZEN) && defined(GTEST_AVX2FMA3) +/* + Unit testing for functionality of bli_scnorm2fv_unb_var1_avx2 kernel. + The code structure for bli_scnorm2fv_unb_var1_avx2( ... ) is as follows : + For unit strides : + Main loop : In blocks of 16 --> L16 + Fringe loops : In blocks of 12 --> L12 + In blocks of 8 --> L8 + In blocks of 4 --> L4(Currently disabled) + Element-wise loop --> LScalar + NOTE : The code to handle unit-strides is taken only if n >= 64. + + For non-unit strides : A single loop, to process element wise. +*/ +// Unit testing with unit strides, across all loops. +INSTANTIATE_TEST_SUITE_P( + bli_scnorm2fv_unb_var1_avx2_unitStrides, + scnrm2Generic, + ::testing::Combine( + ::testing::Values(bli_scnorm2fv_unb_var1_avx2), // ukr function + // m size of vector + ::testing::Values(// Testing the loops standalone + gtint_t(64), // size n, for L16 + gtint_t(76), // 4*L16 + L12 + gtint_t(72), // 4*L16 + L8 + gtint_t(68), // 4*L16 + L4 + gtint_t(67)), // 4*L16 + 3(LScalar) + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(true, false) // is_memory_test + ), + ::nrm2UKRPrint() + ); + +// Unit testing with non-unit strides. +INSTANTIATE_TEST_SUITE_P( + bli_scnorm2fv_unb_var1_avx2_nonUnitStrides, + scnrm2Generic, + ::testing::Combine( + ::testing::Values(bli_scnorm2fv_unb_var1_avx2), // ukr function + // m size of vector + ::testing::Values(// Testing the loops standalone + gtint_t(25), // n, size of the vector + gtint_t(41), + gtint_t(17), + gtint_t(9)), + ::testing::Values(gtint_t(3), gtint_t(5)), // stride size for x + ::testing::Values(true, false) // is_memory_test + ), + ::nrm2UKRPrint() + ); +#endif diff --git a/gtestsuite/testsuite/ukr/nrm2/snrm2_ukr.cpp b/gtestsuite/testsuite/ukr/nrm2/snrm2_ukr.cpp new file mode 100644 index 0000000000..731644b5e3 --- /dev/null +++ b/gtestsuite/testsuite/ukr/nrm2/snrm2_ukr.cpp @@ -0,0 +1,122 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_nrm2_ukr.h" + +using T = float; +using RT = typename testinghelpers::type_info::real_type; + +class snrm2Generic : + public ::testing::TestWithParam, // Kernel pointer type + gtint_t, // n + gtint_t, // incx + bool>> {}; // is_memory_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(snrm2Generic); + +TEST_P( snrm2Generic, UKR ) +{ + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + nrm2_ker_ft ukr_fp = std::get<0>(GetParam()); + // vector length + gtint_t n = std::get<1>(GetParam()); + // stride size for x + gtint_t incx = std::get<2>(GetParam()); + // is_memory_test + bool is_memory_test = std::get<3>(GetParam()); + + // Set the threshold for the errors: + double thresh = std::sqrt(n)*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_nrm2_ukr( ukr_fp, n, incx, thresh, is_memory_test ); +} + +#if defined(BLIS_KERNELS_ZEN) && defined(GTEST_AVX2FMA3) +/* + Unit testing for functionality of bli_snorm2fv_unb_var1_avx2 kernel. + The code structure for bli_snorm2fv_unb_var1_avx2( ... ) is as follows : + For unit strides : + Main loop : In blocks of 32 --> L32 + Fringe loops : In blocks of 24 --> L24 + In blocks of 16 --> L16 + In blocks of 8 --> L8(Currently disabled) + Element-wise loop --> LScalar + NOTE : The code to handle unit-strides is taken only if n >= 64. + + For non-unit strides : A single loop, to process element wise. +*/ +// Unit testing with unit strides, across all loops. +INSTANTIATE_TEST_SUITE_P( + bli_snorm2fv_unb_var1_avx2_unitStrides, + snrm2Generic, + ::testing::Combine( + ::testing::Values(bli_snorm2fv_unb_var1_avx2), // ukr function + // m size of vector + ::testing::Values(// Testing the loops standalone + gtint_t(64), // size n, for L32 + gtint_t(88), // 2*L32 + L24 + gtint_t(80), // 2*L32 + L16 + gtint_t(72), // 2*L32 + L8 + gtint_t(71)), // 2*L32 + 7(LScalar) + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(true, false) // is_memory_test + ), + ::nrm2UKRPrint() + ); + +// Unit testing with non-unit strides. +INSTANTIATE_TEST_SUITE_P( + bli_snorm2fv_unb_var1_avx2_nonUnitStrides, + snrm2Generic, + ::testing::Combine( + ::testing::Values(bli_snorm2fv_unb_var1_avx2), // ukr function + // m size of vector + ::testing::Values(// Testing the loops standalone + gtint_t(25), // n, size of the vector + gtint_t(41), + gtint_t(17), + gtint_t(9)), + ::testing::Values(gtint_t(3), gtint_t(5)), // stride size for x + ::testing::Values(true, false) // is_memory_test + ), + ::nrm2UKRPrint() + ); +#endif diff --git a/gtestsuite/testsuite/ukr/nrm2/test_nrm2_ukr.h b/gtestsuite/testsuite/ukr/nrm2/test_nrm2_ukr.h new file mode 100644 index 0000000000..ea23732dbf --- /dev/null +++ b/gtestsuite/testsuite/ukr/nrm2/test_nrm2_ukr.h @@ -0,0 +1,137 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#pragma once + +#include "util/nrm2/nrm2.h" +#include +#include "util/ref_nrm2.h" +#include "inc/check_error.h" + +// Defining the function pointer type for ?norm2fv vectorized kernels +// It is based on two template parameters : +// T : datatype of input vector x +// RT : datatype of output norm +template +using nrm2_ker_ft = void (*) + ( + dim_t n, + T* x, inc_t incx, + RT* norm, + cntx_t* cntx + ); + +// Function to test the ?norm2fv micro-kernels +// The function is templatized based on the datatype of the input and output operands. +// The first parameter(function pointer) uses these template parameters to take the appropriate type. +template +static void test_nrm2_ukr( nrm2_ker_ft ukr_fp, gtint_t n, gtint_t incx, double thresh, + bool is_memory_test = false) +{ + // Pointers to obtain the required memory. + T *x; + gtint_t size_x = testinghelpers::buff_dim( n, incx ) * sizeof( T ); + + // Create the objects for the input and output operands + // The kernel does not expect the memory to be aligned + testinghelpers::ProtectedBuffer x_buffer( size_x, false, is_memory_test ); + + // Acquire the first greenzone for x + x = ( T* )x_buffer.greenzone_1; + + // Initialize the memory with random data + testinghelpers::datagenerators::randomgenerators( -10, 10, n, incx, x ); + + RT norm = 0.0; + // Add signal handler for segmentation fault + testinghelpers::ProtectedBuffer::start_signal_handler(); + try + { + // Call the ukr function. + // This call is made irrespective of is_memory_test. + // This will check for out of bounds access with first redzone(if memory test is true) + // Else, it will just call the ukr function. + ukr_fp( n, x, incx, &norm, NULL ); + + if ( is_memory_test ) + { + // Acquire the pointers near the second redzone + x = ( T* )x_buffer.greenzone_2; + + // copy data from 1st buffer of x to second buffer + memcpy( x, x_buffer.greenzone_1, size_x ); + + norm = 0.0; + ukr_fp( n, x, incx, &norm, NULL ); + } + } + catch(const std::exception& e) + { + // Reset to default signal handler + testinghelpers::ProtectedBuffer::stop_signal_handler(); + + // Show failure in case seg fault was detected + FAIL() << "Memory Test Failed"; + } + // Reset to default signal handler + testinghelpers::ProtectedBuffer::stop_signal_handler(); + + //---------------------------------------------------------- + // Call reference implementation to get ref results. + //---------------------------------------------------------- + RT norm_ref = testinghelpers::ref_nrm2( n, x, incx ); + + //---------------------------------------------------------- + // Compute error. + //---------------------------------------------------------- + computediff( "norm", norm, norm_ref, thresh ); + +} + +// Test-case logger : Used to print the test-case details based on parameters +template ::real_type> +class nrm2UKRPrint { +public: + std::string operator()( + testing::TestParamInfo, gtint_t, gtint_t, bool>> str) const { + gtint_t n = std::get<1>(str.param); + gtint_t incx = std::get<2>(str.param); + bool is_memory_test = std::get<3>(str.param); + + std::string str_name = "_n_" + std::to_string(n); + str_name += "_incx_" + testinghelpers::get_value_string(incx); + str_name += ( is_memory_test ) ? "_mem_test_enabled" : "_mem_test_disabled"; + return str_name; + } +}; diff --git a/gtestsuite/testsuite/ukr/scal2v/cscal2v_ukr.cpp b/gtestsuite/testsuite/ukr/scal2v/cscal2v_ukr.cpp new file mode 100644 index 0000000000..383a9f6085 --- /dev/null +++ b/gtestsuite/testsuite/ukr/scal2v/cscal2v_ukr.cpp @@ -0,0 +1,159 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_scal2v_ukr.h" + +class cscal2vGeneric : + public ::testing::TestWithParam> {}; // is_memory_test +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(cscal2vGeneric); + +// Tests using random integers as vector elements. +TEST_P( cscal2vGeneric, UKR ) +{ + using T = scomplex; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes the kernel to be tested: + cscal2v_ker_ft ukr = std::get<0>(GetParam()); + // denotes whether alpha or conjx will be used: + char conjx = std::get<1>(GetParam()); + // vector length: + gtint_t n = std::get<2>(GetParam()); + // stride size for x: + gtint_t incx = std::get<3>(GetParam()); + // stride size for y: + gtint_t incy = std::get<4>(GetParam()); + // alpha: + T alpha = std::get<5>(GetParam()); + // is_memory_test: + bool is_memory_test = std::get<6>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite scal2v.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO() || alpha == testinghelpers::ONE()) + thresh = 0.0; + else + thresh = testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_scal2v_ukr( ukr, conjx, n, incx, incy, alpha, thresh, is_memory_test ); +} + +// ---------------------------------------------- +// ----- Begin ZEN1/2/3 (AVX2) Kernel Tests ----- +// ---------------------------------------------- +#if defined(BLIS_KERNELS_ZEN) && defined(GTEST_AVX2FMA3) +/* + Unit testing for functionality of bli_cscal2v_zen_int kernel. + The code structure for bli_cscal2v_zen_int( ... ) is as follows : + For unit strides : + Main loop : In blocks of 16 --> L16 + Fringe loops : In blocks of 8 --> L8 + In blocks of 4 --> L4 + Element-wise loop --> LScalar + + For non-unit strides : A single loop, to process element wise. +*/ +INSTANTIATE_TEST_SUITE_P( + bli_cscal2v_zen_int_unitPositiveStride, + cscal2vGeneric, + ::testing::Combine( + ::testing::Values(bli_cscal2v_zen_int), + // conjx + ::testing::Values('n' +#ifdef TEST_BLIS_TYPED + , 'c' +#endif + ), + ::testing::Values(// Testing the loops standalone + gtint_t(16), // size n, for L16 + gtint_t(8), // L8 + gtint_t(4), // L4 + gtint_t(3), // LScalar + gtint_t(79)), // 4*L16 + L8 + L4 + 3(LScalar) + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(scomplex{1.0, 0.0}, scomplex{-1.0, 0.0}, + scomplex{0.0, 1.0}, scomplex{0.0, -1.0}, + scomplex{0.0, -3.3}, scomplex{4.3,-2.1}, + scomplex{0.0, 0.0}), // alpha + ::testing::Values(false, true) // is_memory_test + ), + (::scal2vUKRPrint()) + ); + +INSTANTIATE_TEST_SUITE_P( + bli_cscal2v_zen_int_nonUnitPositiveStrides, + cscal2vGeneric, + ::testing::Combine( + ::testing::Values(bli_cscal2v_zen_int), + // conjx + ::testing::Values('n' +#ifdef TEST_BLIS_TYPED + , 'c' +#endif + ), + ::testing::Values(// Testing the loops standalone + gtint_t(7), // size n, for LScalar + gtint_t(15)), + ::testing::Values(gtint_t(3), gtint_t(5)), // stride size for x + ::testing::Values(gtint_t(2), gtint_t(4)), // stride size for y + ::testing::Values(scomplex{1.0, 0.0}, scomplex{-1.0, 0.0}, + scomplex{0.0, 1.0}, scomplex{0.0, -1.0}, + scomplex{0.0, -3.3}, scomplex{4.3,-2.1}, + scomplex{0.0, 0.0}), // alpha + ::testing::Values(false, true) // is_memory_test + ), + (::scal2vUKRPrint()) + ); +#endif +// ---------------------------------------------- +// ----- End ZEN1/2/3 (AVX2) Kernel Tests ----- +// ---------------------------------------------- diff --git a/gtestsuite/testsuite/ukr/scal2v/dscal2v_ukr.cpp b/gtestsuite/testsuite/ukr/scal2v/dscal2v_ukr.cpp new file mode 100644 index 0000000000..346bf9c270 --- /dev/null +++ b/gtestsuite/testsuite/ukr/scal2v/dscal2v_ukr.cpp @@ -0,0 +1,219 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_scal2v_ukr.h" + +class dscal2vGeneric : + public ::testing::TestWithParam> {}; // is_memory_test +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(dscal2vGeneric); + +// Tests using random integers as vector elements. +TEST_P( dscal2vGeneric, UKR ) +{ + using T = double; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes the kernel to be tested: + dscal2v_ker_ft ukr = std::get<0>(GetParam()); + // denotes whether alpha or conjx will be used: + char conjx = std::get<1>(GetParam()); + // vector length: + gtint_t n = std::get<2>(GetParam()); + // stride size for x: + gtint_t incx = std::get<3>(GetParam()); + // stride size for y: + gtint_t incy = std::get<4>(GetParam()); + // alpha: + T alpha = std::get<5>(GetParam()); + // is_memory_test: + bool is_memory_test = std::get<6>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite scal2v.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO() || alpha == testinghelpers::ONE()) + thresh = 0.0; + else + thresh = testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_scal2v_ukr( ukr, conjx, n, incx, incy, alpha, thresh, is_memory_test ); +} + +// ---------------------------------------------- +// ----- Begin ZEN1/2/3 (AVX2) Kernel Tests ----- +// ---------------------------------------------- +#if defined(BLIS_KERNELS_ZEN) && defined(GTEST_AVX2FMA3) +/* + Unit testing for functionality of bli_dscal2v_zen_int kernel. + The code structure for bli_dscal2v_zen_int( ... ) is as follows : + For unit strides : + Main loop : In blocks of 48 --> L48 + Fringe loops : In blocks of 32 --> L32 + In blocks of 16 --> L16 + In blocks of 8 --> L8 + In blocks of 4 --> L4 + Element-wise loop --> LScalar + + For non-unit strides : A single loop, to process element wise. +*/ +INSTANTIATE_TEST_SUITE_P( + bli_dscal2v_zen_int_unitPositiveStride, + dscal2vGeneric, + ::testing::Combine( + ::testing::Values(bli_dscal2v_zen_int), + // conjx: uses n (no_conjugate) since it is real. + ::testing::Values('n'), + ::testing::Values(// Testing the loops standalone + gtint_t(96), // size n, for L48 + gtint_t(32), // L32 + gtint_t(16), // L16 + gtint_t(8), // L8 + gtint_t(4), // L4 + gtint_t(3), // LScalar + gtint_t(128), // 2*L48 + L32 + gtint_t(127)), // 2*L48 + L16 + L8 + L4 + 3(LScalar) + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(double(1.0), double(-1.0), + double(2.3), double(-4.5), + double(0.0)), // alpha + ::testing::Values(false, true) // is_memory_test + ), + (::scal2vUKRPrint()) + ); + +INSTANTIATE_TEST_SUITE_P( + bli_dscal2v_zen_int_nonUnitPositiveStrides, + dscal2vGeneric, + ::testing::Combine( + ::testing::Values(bli_dscal2v_zen_int), + // conjx: uses n (no_conjugate) since it is real. + ::testing::Values('n'), + ::testing::Values(// Testing the loops standalone + gtint_t(7), // size n, for LScalar + gtint_t(15)), + ::testing::Values(gtint_t(3), gtint_t(5)), // stride size for x + ::testing::Values(gtint_t(2), gtint_t(4)), // stride size for y + ::testing::Values(double(1.0), double(-1.0), + double(2.3), double(-4.5), + double(0.0)), // alpha + ::testing::Values(false, true) // is_memory_test + ), + (::scal2vUKRPrint()) + ); +#endif +// ---------------------------------------------- +// ----- End ZEN1/2/3 (AVX2) Kernel Tests ----- +// ---------------------------------------------- + +// ---------------------------------------------- +// ----- Begin ZEN4/5 (AVX512) Kernel Tests ----- +// ---------------------------------------------- +#if defined(BLIS_KERNELS_ZEN4) && defined(GTEST_AVX512) +/* + Unit testing for functionality of bli_dscal2v_zen_int_avx512 kernel. + The code structure for bli_dscal2v_zen_int_avx512( ... ) is as follows : + For unit strides : + Main loop : In blocks of 64 --> L64 + Fringe loops : In blocks of 32 --> L32 + In blocks of 16 --> L16 + In blocks of 8 --> L8 + In blocks of 4 --> L4 + Element-wise loop --> LScalar + + For non-unit strides : A single loop, to process element wise. +*/ +INSTANTIATE_TEST_SUITE_P( + bli_dscal2v_zen_int_avx512_unitPositiveStride, + dscal2vGeneric, + ::testing::Combine( + ::testing::Values(bli_dscal2v_zen_int_avx512), + // conjx: uses n (no_conjugate) since it is real. + ::testing::Values('n'), + ::testing::Values(// Testing the loops standalone + gtint_t(64), // size n, for L64 + gtint_t(32), // L32 + gtint_t(16), // L16 + gtint_t(8), // L8 + gtint_t(7), // LScalar + gtint_t(191)), // 2*L64 + L32 + L16 + L8 + 7(LScalar) + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(double(1.0), double(-1.0), + double(2.3), double(-4.5), + double(0.0)), // alpha + ::testing::Values(false, true) // is_memory_test + ), + (::scal2vUKRPrint()) + ); + +INSTANTIATE_TEST_SUITE_P( + bli_dscal2v_zen_int_avx512_nonUnitPositiveStrides, + dscal2vGeneric, + ::testing::Combine( + ::testing::Values(bli_dscal2v_zen_int_avx512), + // conjx: uses n (no_conjugate) since it is real. + ::testing::Values('n'), + ::testing::Values(// Testing the loops standalone + gtint_t(7), // size n, for LScalar + gtint_t(15)), + ::testing::Values(gtint_t(3), gtint_t(5)), // stride size for x + ::testing::Values(gtint_t(2), gtint_t(4)), // stride size for y + ::testing::Values(double(1.0), double(-1.0), + double(2.3), double(-4.5), + double(0.0)), // alpha + ::testing::Values(false, true) // is_memory_test + ), + (::scal2vUKRPrint()) + ); +#endif +// ---------------------------------------------- +// ----- End ZEN4/5 (AVX512) Kernel Tests ----- +// ---------------------------------------------- diff --git a/gtestsuite/testsuite/ukr/scal2v/sscal2v_ukr.cpp b/gtestsuite/testsuite/ukr/scal2v/sscal2v_ukr.cpp new file mode 100644 index 0000000000..3d13fec613 --- /dev/null +++ b/gtestsuite/testsuite/ukr/scal2v/sscal2v_ukr.cpp @@ -0,0 +1,154 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_scal2v_ukr.h" + +class sscal2vGeneric : + public ::testing::TestWithParam> {}; // is_memory_test +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(sscal2vGeneric); + +// Tests using random integers as vector elements. +TEST_P( sscal2vGeneric, UKR ) +{ + using T = float; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes the kernel to be tested: + sscal2v_ker_ft ukr = std::get<0>(GetParam()); + // denotes whether alpha or conjx will be used: + char conjx = std::get<1>(GetParam()); + // vector length: + gtint_t n = std::get<2>(GetParam()); + // stride size for x: + gtint_t incx = std::get<3>(GetParam()); + // stride size for y: + gtint_t incy = std::get<4>(GetParam()); + // alpha: + T alpha = std::get<5>(GetParam()); + // is_memory_test: + bool is_memory_test = std::get<6>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite scal2v.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO() || alpha == testinghelpers::ONE()) + thresh = 0.0; + else + thresh = testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_scal2v_ukr( ukr, conjx, n, incx, incy, alpha, thresh, is_memory_test ); +} + +// ---------------------------------------------- +// ----- Begin ZEN1/2/3 (AVX2) Kernel Tests ----- +// ---------------------------------------------- +#if defined(BLIS_KERNELS_ZEN) && defined(GTEST_AVX2FMA3) +/* + Unit testing for functionality of bli_sscal2v_zen_int kernel. + The code structure for bli_sscal2v_zen_int( ... ) is as follows : + For unit strides : + Main loop : In blocks of 96 --> L96 + Fringe loops : In blocks of 64 --> L64 + In blocks of 32 --> L32 + In blocks of 16 --> L16 + In blocks of 8 --> L8 + Element-wise loop --> LScalar + + For non-unit strides : A single loop, to process element wise. +*/ +INSTANTIATE_TEST_SUITE_P( + bli_sscal2v_zen_int_unitPositiveStride, + sscal2vGeneric, + ::testing::Combine( + ::testing::Values(bli_sscal2v_zen_int), + // conjx: uses n (no_conjugate) since it is real. + ::testing::Values('n'), + ::testing::Values(// Testing the loops standalone + gtint_t(96), // size n, for L96 + gtint_t(64), // L64 + gtint_t(32), // L32 + gtint_t(16), // L16 + gtint_t(8), // L8 + gtint_t(7), // LScalar + gtint_t(256), // 2*L96 + L64 + gtint_t(255)), // 2*L96 + L32 + L16 + L8 + 7(LScalar) + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(float(1.0), float(-1.0), + float(2.3), float(-4.5), + float(0.0)), // alpha + ::testing::Values(false, true) // is_memory_test + ), + (::scal2vUKRPrint()) + ); + +INSTANTIATE_TEST_SUITE_P( + bli_sscal2v_zen_int_nonUnitPositiveStrides, + sscal2vGeneric, + ::testing::Combine( + ::testing::Values(bli_sscal2v_zen_int), + // conjx: uses n (no_conjugate) since it is real. + ::testing::Values('n'), + ::testing::Values(// Testing the loops standalone + gtint_t(7), // size n, for LScalar + gtint_t(15)), + ::testing::Values(gtint_t(3), gtint_t(5)), // stride size for x + ::testing::Values(gtint_t(2), gtint_t(4)), // stride size for y + ::testing::Values(float(1.0), float(-1.0), + float(2.3), float(-4.5), + float(0.0)), // alpha + ::testing::Values(false, true) // is_memory_test + ), + (::scal2vUKRPrint()) + ); +#endif +// ---------------------------------------------- +// ----- End ZEN1/2/3 (AVX2) Kernel Tests ----- +// ---------------------------------------------- diff --git a/gtestsuite/testsuite/ukr/scal2v/test_scal2v_ukr.h b/gtestsuite/testsuite/ukr/scal2v/test_scal2v_ukr.h new file mode 100644 index 0000000000..1d18ef308e --- /dev/null +++ b/gtestsuite/testsuite/ukr/scal2v/test_scal2v_ukr.h @@ -0,0 +1,149 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#pragma once + +#include + +#include "level1/scal2v/scal2v.h" +#include "level1/ref_scal2v.h" +#include "inc/check_error.h" +#include "common/testing_helpers.h" + +/** + * @brief Microkernel test body for scal2v operation. + */ +template +static void test_scal2v_ukr( FT ukr, char conjx, gtint_t n, gtint_t incx, gtint_t incy, + T alpha, double thresh, bool is_memory_test = false ) +{ + // Obtain and allocate memory for vectors. + T *x, *y, *y_ref; + + // Sizes of x and y vectors + gtint_t size_x = testinghelpers::buff_dim( n, incx ) * sizeof( T ); + gtint_t size_y = testinghelpers::buff_dim( n, incy ) * sizeof( T ); + + // Create the object for the required operands + // The kernel does not expect the memory to be aligned + testinghelpers::ProtectedBuffer x_buffer( size_x, false, is_memory_test ); + testinghelpers::ProtectedBuffer y_buffer( size_y, false, is_memory_test ); + + // For y_ref, we don't need different greenzones and any redzone. + // Thus, we pass is_memory_test as false + testinghelpers::ProtectedBuffer y_ref_buffer( size_y, false, false ); + + // Acquire the first set of greenzones for x and y. + x = ( T* )x_buffer.greenzone_1; + y = ( T* )y_buffer.greenzone_1; + + // There is no greenzone_2 for y_ref. + y_ref = ( T* )y_ref_buffer.greenzone_1; + + // Initialize x and y with random data. + testinghelpers::datagenerators::randomgenerators( -10, 10, n, incx, x ); + testinghelpers::datagenerators::randomgenerators( -10, 10, n, incy, y ); + + // Copying y to y_ref, for comparision after computation + memcpy( y_ref, y, size_y ); + + // Char conjx to BLIS conjx conversion + conj_t blis_conjx; + testinghelpers::char_to_blis_conj( conjx, &blis_conjx ); + + testinghelpers::ProtectedBuffer::start_signal_handler(); + try + { + // Invoking BLIS ukr. + // This will check for out of bounds access within first redzone. + ukr( blis_conjx, n, &alpha, x, incx, y, incy, nullptr ); + + if ( is_memory_test ) + { + // Acquire the pointers near the second redzone. + x = ( T* )x_buffer.greenzone_2; + y = ( T* )y_buffer.greenzone_2; + + // Copy the data for x and y accordingly + memcpy( x, x_buffer.greenzone_1, size_x ); + memcpy( y, y_ref, size_y ); + + // Invoking BLIS ukr to check with the second redzone. + ukr( blis_conjx, n, &alpha, x, incx, y, incy, nullptr ); + } + } + catch(const std::exception& e) + { + // Reset to default signal handler + testinghelpers::ProtectedBuffer::stop_signal_handler(); + + // Show failure in case seg fault was detected + FAIL() << "Memory Test Failed"; + } + + //---------------------------------------------------------- + // Call reference implementation to get ref results. + //---------------------------------------------------------- + testinghelpers::ref_scal2v( conjx, n, alpha, x, incx, y_ref, incy ); + + //---------------------------------------------------------- + // Compute component-wise error. + //---------------------------------------------------------- + computediff( "y", n, y, y_ref, incy, thresh ); +} + + +// Test-case logger : Used to print the test-case details based on parameters +template +class scal2vUKRPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char conjx = std::get<1>(str.param); + gtint_t n = std::get<2>(str.param); + gtint_t incx = std::get<3>(str.param); + gtint_t incy = std::get<4>(str.param); + T alpha = std::get<5>(str.param); + bool is_memory_test = std::get<6>(str.param); + + std::string str_name = "_n_" + std::to_string(n); + str_name += "_conjx_" + std::string(&conjx, 1); + str_name += "_incx_" + testinghelpers::get_value_string(incx); + str_name += "_incy_" + testinghelpers::get_value_string(incy); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + str_name += ( is_memory_test ) ? "_mem_test_enabled" : "_mem_test_disabled"; + + return str_name; + } +}; diff --git a/gtestsuite/testsuite/ukr/scal2v/zscal2v_ukr.cpp b/gtestsuite/testsuite/ukr/scal2v/zscal2v_ukr.cpp new file mode 100644 index 0000000000..ca818c5501 --- /dev/null +++ b/gtestsuite/testsuite/ukr/scal2v/zscal2v_ukr.cpp @@ -0,0 +1,164 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_scal2v_ukr.h" + +class zscal2vGeneric : + public ::testing::TestWithParam> {}; // is_memory_test +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(zscal2vGeneric); + +// Tests using random integers as vector elements. +TEST_P( zscal2vGeneric, UKR ) +{ + using T = dcomplex; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes the kernel to be tested: + zscal2v_ker_ft ukr = std::get<0>(GetParam()); + // denotes whether alpha or conjx will be used: + char conjx = std::get<1>(GetParam()); + // vector length: + gtint_t n = std::get<2>(GetParam()); + // stride size for x: + gtint_t incx = std::get<3>(GetParam()); + // stride size for y: + gtint_t incy = std::get<4>(GetParam()); + // alpha: + T alpha = std::get<5>(GetParam()); + // is_memory_test: + bool is_memory_test = std::get<6>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite scal2v.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO() || alpha == testinghelpers::ONE()) + thresh = 0.0; + else + thresh = testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_scal2v_ukr( ukr, conjx, n, incx, incy, alpha, thresh, is_memory_test ); +} + +// ---------------------------------------------- +// ----- Begin ZEN1/2/3 (AVX2) Kernel Tests ----- +// ---------------------------------------------- +#if defined(BLIS_KERNELS_ZEN) && defined(GTEST_AVX2FMA3) +/* + Unit testing for functionality of bli_zscal2v_zen_int kernel. + The code structure for bli_zscal2v_zen_int( ... ) is as follows : + For unit strides : + Main loop : In blocks of 8 --> L8 + Fringe loops : In blocks of 4 --> L4 + In blocks of 2 --> L2 + Element-wise loop --> LScalar + + For non-unit strides : + Main loop : In blocks of 4 --> L4 + Fringe loops : In blocks of 2 --> L2 + Element-wise loop --> LScalar +*/ +INSTANTIATE_TEST_SUITE_P( + bli_zscal2v_zen_int_unitPositiveStride, + zscal2vGeneric, + ::testing::Combine( + ::testing::Values(bli_zscal2v_zen_int), + // conjx + ::testing::Values('n' +#ifdef TEST_BLIS_TYPED + , 'c' +#endif + ), + ::testing::Values(// Testing the loops standalone + gtint_t(8), // size n, for L8 + gtint_t(4), // L4 + gtint_t(2), // L2 + gtint_t(1), // LScalar + gtint_t(49)), // 4*L8 + L4 + L2 + 1(LScalar) + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, + dcomplex{0.0, 1.0}, dcomplex{0.0, -1.0}, + dcomplex{0.0, -3.3}, dcomplex{4.3,-2.1}, + dcomplex{0.0, 0.0}), // alpha + ::testing::Values(false, true) // is_memory_test + ), + (::scal2vUKRPrint()) + ); + +INSTANTIATE_TEST_SUITE_P( + bli_zscal2v_zen_int_nonUnitPositiveStrides, + zscal2vGeneric, + ::testing::Combine( + ::testing::Values(bli_zscal2v_zen_int), + // conjx + ::testing::Values('n' +#ifdef TEST_BLIS_TYPED + , 'c' +#endif + ), + ::testing::Values(// Testing the loops standalone + gtint_t(4), // size n, for L4 + gtint_t(2), // L2 + gtint_t(1), // LScalar + gtint_t(11)), // 2*L4 + L2 + 1(LScalar) + ::testing::Values(gtint_t(3), gtint_t(5)), // stride size for x + ::testing::Values(gtint_t(2), gtint_t(4)), // stride size for y + ::testing::Values(dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, + dcomplex{0.0, 1.0}, dcomplex{0.0, -1.0}, + dcomplex{0.0, -3.3}, dcomplex{4.3,-2.1}, + dcomplex{0.0, 0.0}), // alpha + ::testing::Values(false, true) // is_memory_test + ), + (::scal2vUKRPrint()) + ); +#endif +// ---------------------------------------------- +// ----- End ZEN1/2/3 (AVX2) Kernel Tests ----- +// ---------------------------------------------- diff --git a/gtestsuite/testsuite/ukr/scalv/cscalv_ukr.cpp b/gtestsuite/testsuite/ukr/scalv/cscalv_ukr.cpp new file mode 100644 index 0000000000..d802b47b00 --- /dev/null +++ b/gtestsuite/testsuite/ukr/scalv/cscalv_ukr.cpp @@ -0,0 +1,261 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_scalv_ukr.h" + +class cscalvGeneric : + public ::testing::TestWithParam> {}; // is_memory_test +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(cscalvGeneric); + +// Tests using random integers as vector elements. +TEST_P( cscalvGeneric, UKR ) +{ + using T = scomplex; + + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + + // denotes the kernel to be tested: + cscalv_ker_ft ukr = std::get<0>(GetParam()); + // denotes whether alpha or conj(alpha) will be used: + char conj_alpha = std::get<1>(GetParam()); + // vector length: + gtint_t n = std::get<2>(GetParam()); + // stride size for x: + gtint_t incx = std::get<3>(GetParam()); + // alpha: + T alpha = std::get<4>(GetParam()); + // is_memory_test: + bool is_memory_test = std::get<5>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite scalv.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO() || alpha == testinghelpers::ONE()) + thresh = 0.0; + else + thresh = testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_scalv_ukr( ukr, conj_alpha, n, incx, alpha, thresh, is_memory_test ); +} + +// ---------------------------------------------- +// ----- Begin ZEN1/2/3 (AVX2) Kernel Tests ----- +// ---------------------------------------------- +#if defined(BLIS_KERNELS_ZEN) && defined(GTEST_AVX2FMA3) +// Tests for bli_cscalv_zen_int (AVX2) kernel. +/** + * Loops: + * L16 - Main loop, handles 16 elements + * L8 - handles 8 elements + * L4 - handles 4 elements + * LScalar - leftover loop (also handles non-unit increments) +*/ +INSTANTIATE_TEST_SUITE_P( + bli_cscalv_zen_int_unitPositiveStride, + cscalvGeneric, + ::testing::Combine( + ::testing::Values(bli_cscalv_zen_int), + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values('n' +#ifdef TEST_BLIS_TYPED + , 'c' // conjx +#endif + ), + // m: size of vector. + ::testing::Values( + gtint_t(16), // L16 + gtint_t( 8), // L8 + gtint_t( 4), // L4 + gtint_t( 3), // LScalar + gtint_t(32), // 2*L16 + gtint_t(47) // 2*L16 + L8 + L4 + 3(LScalar) + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1) // unit stride + ), + // alpha: value of scalar. + ::testing::Values( + scomplex{-5.1, -7.3}, + scomplex{ 0.0, 0.0}, + scomplex{ 7.3, 5.1} + ), + ::testing::Values(false, true) // is_memory_test + ), + (::scalvUKRPrint()) + ); + +INSTANTIATE_TEST_SUITE_P( + bli_cscalv_zen_int_nonUnitPositiveStrides, + cscalvGeneric, + ::testing::Combine( + ::testing::Values(bli_cscalv_zen_int), + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values('n' +#ifdef TEST_BLIS_TYPED + , 'c' // conjx +#endif + ), + // m: size of vector. + ::testing::Values( + gtint_t(3), gtint_t(30), gtint_t(112) + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(3), gtint_t(7) // few non-unit strides for sanity check + ), + // alpha: value of scalar. + ::testing::Values( + scomplex{-5.1, -7.3}, + scomplex{ 0.0, 0.0}, + scomplex{ 7.3, 5.1} + ), + ::testing::Values(false, true) // is_memory_test + ), + (::scalvUKRPrint()) + ); +#endif +// ---------------------------------------------- +// ----- End ZEN1/2/3 (AVX2) Kernel Tests ----- +// ---------------------------------------------- + +// ---------------------------------------------- +// ----- Begin ZEN4 (AVX512) Kernel Tests ----- +// ---------------------------------------------- +#if defined(BLIS_KERNELS_ZEN4) && defined(GTEST_AVX512) +// Tests for bli_cscalv_zen_int_avx512 (AVX512) kernel. +/** + * Loops: + * L96 - Main loop, handles 96 scomplex elements + * L64 - handles 64 scomplex elements + * L32 - handles 32 scomplex elements + * L16 - handles 16 scomplex elements + * L8 - handles 8 scomplex elements + * L4 - handles 4 scomplex elements + * LMasked - leftover loop + * + * LScalar - handles non-unit increments +*/ +INSTANTIATE_TEST_SUITE_P( + bli_cscalv_zen_int_avx512_unitPositiveStride, + cscalvGeneric, + ::testing::Combine( + ::testing::Values(bli_cscalv_zen_int_avx512), + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values('n' +#ifdef TEST_BLIS_TYPED + , 'c' // conjx +#endif + ), + // m: size of vector. + ::testing::Values( + gtint_t(285), // 2*L96 + L64 + L16 + L8 + L4 + LMasked + gtint_t(255), // 2*L96 + L32 + L16 + L8 + L4 + LMasked + gtint_t( 96), // L96 + gtint_t( 64), // L64 + gtint_t( 32), // L32 + gtint_t( 16), // L16 + gtint_t( 8), // L8 + gtint_t( 4), // L4 + gtint_t( 3), // LMasked + gtint_t( 2), // LMasked + gtint_t( 1) // LMasked + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1) // unit stride + ), + // alpha: value of scalar. + ::testing::Values( + scomplex{-5.1, -7.3}, + scomplex{ 0.0, 0.0}, + scomplex{ 1.0, 1.0}, + scomplex{ 7.3, 5.1} + ), + ::testing::Values(false, true) // is_memory_test + ), + (::scalvUKRPrint()) + ); + +INSTANTIATE_TEST_SUITE_P( + bli_cscalv_zen_int_avx512_nonUnitPositiveStrides, + cscalvGeneric, + ::testing::Combine( + ::testing::Values(bli_cscalv_zen_int_avx512), + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values('n' +#ifdef TEST_BLIS_TYPED + , 'c' // conjx +#endif + ), + // m: size of vector. + ::testing::Values( + gtint_t(3), gtint_t(30), gtint_t(112) + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(3), gtint_t(7) // few non-unit strides for sanity check + ), + // alpha: value of scalar. + ::testing::Values( + scomplex{-5.1, -7.3}, + scomplex{ 0.0, 0.0}, + scomplex{ 1.0, 1.0}, + scomplex{ 7.3, 5.1} + ), + ::testing::Values(false, true) // is_memory_test + ), + (::scalvUKRPrint()) + ); +#endif +// ---------------------------------------------- +// ----- End ZEN4 (AVX512) Kernel Tests ----- +// ---------------------------------------------- diff --git a/gtestsuite/testsuite/ukr/scalv/dscalv_ukr.cpp b/gtestsuite/testsuite/ukr/scalv/dscalv_ukr.cpp new file mode 100644 index 0000000000..e1d91e3570 --- /dev/null +++ b/gtestsuite/testsuite/ukr/scalv/dscalv_ukr.cpp @@ -0,0 +1,355 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_scalv_ukr.h" + +class dscalvGeneric : + public ::testing::TestWithParam> {}; // is_memory_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(dscalvGeneric); + +// Tests using random integers as vector elements. +TEST_P( dscalvGeneric, UKR ) +{ + using T = double; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes the kernel to be tested: + dscalv_ker_ft ukr = std::get<0>(GetParam()); + // denotes whether alpha or conj(alpha) will be used: + char conj_alpha = std::get<1>(GetParam()); + // vector length: + gtint_t n = std::get<2>(GetParam()); + // stride size for x: + gtint_t incx = std::get<3>(GetParam()); + // alpha: + T alpha = std::get<4>(GetParam()); + // is_memory_test: + bool is_memory_test = std::get<5>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite scalv.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO() || alpha == testinghelpers::ONE()) + thresh = 0.0; + else + thresh = testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_scalv_ukr( ukr, conj_alpha, n, incx, alpha, thresh, is_memory_test ); +} + +// ---------------------------------------------- +// ----- Begin ZEN1/2/3 (AVX2) Kernel Tests ----- +// ---------------------------------------------- +#if defined(BLIS_KERNELS_ZEN) && defined(GTEST_AVX2FMA3) +// Tests for bli_dscalv_zen_int (AVX2) kernel. +/** + * Loops: + * L16 - Main loop, handles 16 elements + * LScalar - leftover loop (also handles non-unit increments) +*/ +INSTANTIATE_TEST_SUITE_P( + bli_dscalv_zen_int_unitPositiveStride, + dscalvGeneric, + ::testing::Combine( + ::testing::Values(bli_dscalv_zen_int), + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values('n'), + // m: size of vector. + ::testing::Values( + gtint_t(32), // L16 (executed twice) + gtint_t(17), // L16 + Ln_left + gtint_t(16), // L16 + gtint_t( 1) // LScalar + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1) // unit stride + ), + // alpha: value of scalar. + ::testing::Values( + // @note: disabling alpha = 0 test for bli_dscalv_zen_int. + // Segmentation Fault is being observed for alpha = 0 since the + // kernel isn't handling the condition where cntx = NULL. + // double( 0.0), + double( 7.0), + double(-3.0) + ), + ::testing::Values(false, true) // is_memory_test + ), + (::scalvUKRPrint()) + ); + +INSTANTIATE_TEST_SUITE_P( + bli_dscalv_zen_int_nonUnitPositiveStrides, + dscalvGeneric, + ::testing::Combine( + ::testing::Values(bli_dscalv_zen_int), + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values('n'), + // m: size of vector. + ::testing::Values( + gtint_t(3), gtint_t(30), gtint_t(112) + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(3), gtint_t(7) // few non-unit strides for sanity check + ), + // alpha: value of scalar. + ::testing::Values( + // @note: disabling alpha = 0 test for bli_dscalv_zen_int. + // Segmentation Fault is being observed for alpha = 0 since the + // kernel isn't handling the condition where cntx = NULL. + // double( 0.0), + double( 7.0), + double(-3.0) + ), + ::testing::Values(false, true) // is_memory_test + ), + (::scalvUKRPrint()) + ); + +// Tests for bli_dscalv_zen_int10 (AVX2) kernel. +/** + * Cases and Loops: + * C0 L64 - Main loop, handles 64 elements + * C0 L48 - handles 48 elements + * C1 L32 - handles 32 elements + * C2 L12 - handles 12 elements + * C2 L4 - handles 4 elements + * C2 LScalar - leftover loop + * + * LNUnit - loop for non-unit increments +*/ +INSTANTIATE_TEST_SUITE_P( + bli_dscalv_zen_int10_unitPositiveStride, + dscalvGeneric, + ::testing::Combine( + ::testing::Values(bli_dscalv_zen_int10), + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values('n'), + // m: size of vector. + ::testing::Values( + // testing case 0 (n > 500) + gtint_t(512), // C0 L64 + gtint_t(560), // C0 + gtint_t(544), // C0 L0 + C1 + gtint_t(572), // C0 + C2 (L12) + gtint_t(564), // C0 + C2 (L4) + gtint_t(573), // C0 + C2 (L12 + LScalar) + gtint_t(565), // C0 + C2 (L4 + LScalar) + gtint_t(561), // C0 + C2 (LScalar) + gtint_t(556), // C0 L64 + C1 + C2 (L12) + gtint_t(557), // C0 L64 + C1 + C2 (L12 + LScalar) + gtint_t(548), // C0 L64 + C1 + C2 (L4) + gtint_t(549), // C0 L64 + C1 + C2 (L4 + LScalar) + + // testing case 1 (200 < n < 500) + gtint_t(224), // C1 + gtint_t(236), // C1 + C2 (L12) + gtint_t(240), // C1 + C2 (L12 + L4) + gtint_t(241), // C1 + C2 (L12 + L4 + LScalar) + + // testing case 2 (n < 200) + gtint_t(12), // C2 (L12) + gtint_t(16), // C2 (L12 + L4) + gtint_t(17) // C2 (L12 + L4 + LScalar) + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1) // unit stride + ), + // alpha: value of scalar. + ::testing::Values( + double( 0.0), + double( 7.0), + double(-3.0) + ), + ::testing::Values(false, true) // is_memory_test + ), + (::scalvUKRPrint()) + ); + +INSTANTIATE_TEST_SUITE_P( + bli_dscalv_zen_int10_nonUnitPositiveStrides, + dscalvGeneric, + ::testing::Combine( + ::testing::Values(bli_dscalv_zen_int10), + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values('n'), + // m: size of vector. + ::testing::Values( + gtint_t(3), gtint_t(30), gtint_t(112) + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(3), gtint_t(7) // few non-unit strides for sanity check + ), + // alpha: value of scalar. + ::testing::Values( + double( 0.0), + double( 7.0), + double(-3.0) + ), + ::testing::Values(false, true) // is_memory_test + ), + (::scalvUKRPrint()) + ); +#endif +// ---------------------------------------------- +// ----- End ZEN1/2/3 (AVX2) Kernel Tests ----- +// ---------------------------------------------- + + +// ---------------------------------------------- +// ----- Begin ZEN4 (AVX512) Kernel Tests ----- +// ---------------------------------------------- +#if defined(BLIS_KERNELS_ZEN4) && defined(GTEST_AVX512) +// Tests for bli_dscalv_zen_int_avx512 (AVX512) kernel. +/** + * Loops: + * L64 - Main loop, handles 64 elements + * L32 - handles 32 elements + * L16 - handles 16 elements + * L8 - handles 8 elements + * L4 - handles 4 elements + * L2 - handles 2 elements + * LScalar - leftover loop (also handles non-unit increments) +*/ +INSTANTIATE_TEST_SUITE_P( + bli_dscalv_zen_int_avx512_unitPositiveStride, + dscalvGeneric, + ::testing::Combine( + ::testing::Values(bli_dscalv_zen_int_avx512), + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values('n'), + // m: size of vector. + ::testing::Values( + // testing each loop individually + gtint_t(128), // L64 (executed twice) + gtint_t( 64), // L64 + gtint_t( 32), // L32 + gtint_t( 16), // L16 + gtint_t( 8), // L8 + gtint_t( 4), // L4 + gtint_t( 2), // L2 + gtint_t( 1), // LScalar + + // testing all loops from top to bottom + gtint_t(123), // L64 to LScalar + gtint_t(126), // L64 to L2 + gtint_t(124), // L64 to L4 + gtint_t(120), // L64 to L8 + gtint_t(112), // L64 to L16 + gtint_t( 96), // L64 to L32 + + gtint_t( 63), // L32 to LScalar + gtint_t( 62), // L32 to L2 + gtint_t( 60), // L32 to L4 + gtint_t( 56), // L32 to L8 + gtint_t( 48), // L32 to L16 + + gtint_t( 31), // L16 - LScalar + gtint_t( 30), // L16 - L2 + gtint_t( 28), // L16 - L4 + gtint_t( 24), // L16 - L8 + + gtint_t( 15), // L8 to LScalar + gtint_t( 14), // L8 to L2 + gtint_t( 12), // L8 to L4 + + gtint_t( 7), // L4 to LScalar + gtint_t( 6), // L4 to L2 + + gtint_t( 3) // L2 to LScalar + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1) // unit stride + ), + // alpha: value of scalar. + ::testing::Values( + double( 0.0), + double( 7.0), + double(-3.0) + ), + ::testing::Values(false, true) // is_memory_test + ), + (::scalvUKRPrint()) + ); + +INSTANTIATE_TEST_SUITE_P( + bli_dscalv_zen_int_avx512_nonUnitPositiveStrides, + dscalvGeneric, + ::testing::Combine( + ::testing::Values(bli_dscalv_zen_int_avx512), + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values('n'), + // m: size of vector. + ::testing::Values( + gtint_t(3), gtint_t(30), gtint_t(112) + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(3), gtint_t(7) // few non-unit strides for sanity check + ), + // alpha: value of scalar. + ::testing::Values( + double( 0.0), + double( 7.0), + double(-3.0) + ), + ::testing::Values(false, true) // is_memory_test + ), + (::scalvUKRPrint()) + ); +#endif +// ---------------------------------------------- +// ----- End ZEN4 (AVX512) Kernel Tests ----- +// ---------------------------------------------- diff --git a/gtestsuite/testsuite/ukr/scalv/sscalv_ukr.cpp b/gtestsuite/testsuite/ukr/scalv/sscalv_ukr.cpp new file mode 100644 index 0000000000..d92a1d7093 --- /dev/null +++ b/gtestsuite/testsuite/ukr/scalv/sscalv_ukr.cpp @@ -0,0 +1,243 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_scalv_ukr.h" + +class sscalvGeneric : + public ::testing::TestWithParam> {}; // is_memory_test +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(sscalvGeneric); + +// Tests using random integers as vector elements. +TEST_P( sscalvGeneric, UKR ) +{ + using T = float; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes the kernel to be tested: + sscalv_ker_ft ukr = std::get<0>(GetParam()); + // denotes whether alpha or conj(alpha) will be used: + char conj_alpha = std::get<1>(GetParam()); + // vector length: + gtint_t n = std::get<2>(GetParam()); + // stride size for x: + gtint_t incx = std::get<3>(GetParam()); + // alpha: + T alpha = std::get<4>(GetParam()); + // is_memory_test: + bool is_memory_test = std::get<5>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite scalv.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + float thresh; + if (n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO() || alpha == testinghelpers::ONE()) + thresh = 0.0; + else + thresh = testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_scalv_ukr( ukr, conj_alpha, n, incx, alpha, thresh, is_memory_test ); +} + +// ---------------------------------------------- +// ----- Begin ZEN1/2/3 (AVX2) Kernel Tests ----- +// ---------------------------------------------- +#if defined(BLIS_KERNELS_ZEN) && defined(GTEST_AVX2FMA3) +// Tests for bli_sscalv_zen_int (AVX2) kernel. +/** + * Loops: + * L32 - Main loop, handles 32 elements + * LScalar - leftover loop (also handles non-unit increments) +*/ +INSTANTIATE_TEST_SUITE_P( + bli_sscalv_zen_int_unitPositiveStride, + sscalvGeneric, + ::testing::Combine( + ::testing::Values(bli_sscalv_zen_int), + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values('n'), + // m: size of vector. + ::testing::Values( + gtint_t(32), // L32 + gtint_t(15), // LScalar + gtint_t(96), // 3*L32 + gtint_t(111) // 3*L32 + 15(LScalar) + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1) // unit stride + ), + // alpha: value of scalar. + ::testing::Values( + // @note: disabling alpha = 0 test for bli_sscalv_zen_int. + // Segmentation Fault is being observed for alpha = 0 since the + // kernel isn't handling the condition where cntx = NULL. + // float( 0.0), + float( 7.0), + float(-3.0) + ), + ::testing::Values(false, true) // is_memory_test + ), + (::scalvUKRPrint()) + ); + +INSTANTIATE_TEST_SUITE_P( + bli_sscalv_zen_int_nonUnitPositiveStrides, + sscalvGeneric, + ::testing::Combine( + ::testing::Values(bli_sscalv_zen_int), + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values('n'), + // m: size of vector. + ::testing::Values( + gtint_t(3), gtint_t(30), gtint_t(112) + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(3), gtint_t(7) // few non-unit strides for sanity check + ), + // alpha: value of scalar. + ::testing::Values( + // @note: disabling alpha = 0 test for bli_sscalv_zen_int. + // Segmentation Fault is being observed for alpha = 0 since the + // kernel isn't handling the condition where cntx = NULL. + // float( 0.0), + float( 7.0), + float(-3.0) + ), + ::testing::Values(false, true) // is_memory_test + ), + (::scalvUKRPrint()) + ); + +// Tests for bli_sscalv_zen_int10 (AVX2) kernel. +/** + * Cases and Loops: + * C0 L128 - Main loop, handles 128 elements + * C0 L96 - handles 96 elements + * C1 L48 - handles 48 elements + * C2 L24 - handles 24 elements + * C2 L8 - handles 8 elements + * C2 LScalar - leftover loop + * + * The switch cases are cascading, and the order + * is C0 --> C1 --> C2 + * + * LNUnit - loop for non-unit increments +*/ +INSTANTIATE_TEST_SUITE_P( + bli_sscalv_zen_int10_unitPositiveStride, + sscalvGeneric, + ::testing::Combine( + ::testing::Values(bli_sscalv_zen_int10), + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values('n'), + // m: size of vector. + ::testing::Values( + // testing case 0 (n >= 500) + gtint_t(512), // C0 4*L128 + gtint_t(608), // C0 4*L128 + C1 L96 + gtint_t(599), // C0 4*L128 + C2 (L48 + L24 + L8 + 7(LSscalar)) + gtint_t(623), // C0 4*L128 + C1 L96 + C2 (L8 + 7(LScalar)) + + // testing case 1 (300 <= n < 500) + gtint_t(384), // C1 4*L96 + gtint_t(432), // C1 4*L96 + C2 L48 + gtint_t(456), // C1 4*L96 + C2 (L48 + L24) + gtint_t(464), // C1 4*L96 + C2 (L48 + L24 + L8) + gtint_t(471), // C1 4*L96 + C2 (L48 + L24 + L8 + 7(LScalar)) + + // testing case 2 (n < 300) + gtint_t(192), // C2 4*L48 + gtint_t(216), // C2 (4*L48 + L24) + gtint_t(224), // C2 (4*L48 + L24 + L8) + gtint_t(231) // C2 (4*L48 + L24 + L8 + 7(LScalar)) + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1) // unit stride + ), + // alpha: value of scalar. + ::testing::Values( + float( 0.0), + float( 7.0), + float(-3.0) + ), + ::testing::Values(false, true) // is_memory_test + ), + (::scalvUKRPrint()) + ); + +INSTANTIATE_TEST_SUITE_P( + bli_sscalv_zen_int10_nonUnitPositiveStrides, + sscalvGeneric, + ::testing::Combine( + ::testing::Values(bli_sscalv_zen_int10), + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values('n'), + // m: size of vector. + ::testing::Values( + gtint_t(3), gtint_t(30), gtint_t(112) + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(3), gtint_t(7) // few non-unit strides for sanity check + ), + // alpha: value of scalar. + ::testing::Values( + float( 0.0), + float( 7.0), + float(-3.0) + ), + ::testing::Values(false, true) // is_memory_test + ), + (::scalvUKRPrint()) + ); +#endif +// ---------------------------------------------- +// ----- End ZEN1/2/3 (AVX2) Kernel Tests ----- +// ---------------------------------------------- diff --git a/gtestsuite/testsuite/ukr/scalv/test_scalv_ukr.h b/gtestsuite/testsuite/ukr/scalv/test_scalv_ukr.h new file mode 100644 index 0000000000..62e2e754fb --- /dev/null +++ b/gtestsuite/testsuite/ukr/scalv/test_scalv_ukr.h @@ -0,0 +1,140 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#pragma once + +#include + +#include "level1/scalv/scalv.h" +#include "level1/ref_scalv.h" +#include "inc/check_error.h" +#include "common/testing_helpers.h" + +/** + * @brief Microkernel test body for scalv operation. + */ +template +static void test_scalv_ukr( FT ukr, char conja_alpha, gtint_t n, gtint_t incx, + T alpha, double thresh, bool is_memory_test = false ) +{ + // Obtain and allocate memory for vectors. + T *x, *x_ref; + + gtint_t size_x = testinghelpers::buff_dim( n, incx ) * sizeof( T ); + + testinghelpers::ProtectedBuffer x_buffer( size_x, false, is_memory_test ); + + // is_memory_test = false for x_ref since we don't require different green + // or red zones. + testinghelpers::ProtectedBuffer x_ref_buffer( size_x, false, false ); + + // Acquire the first set of greenzones for x. + x = ( T* )x_buffer.greenzone_1; + // There is no greenzone_2 for x_ref. + x_ref = ( T* )x_ref_buffer.greenzone_1; + + // Initialize x with random data. + testinghelpers::datagenerators::randomgenerators( -10, 10, n, incx, x ); + + // Copying x to x_ref, for comparision after computation + memcpy( x_ref, x, size_x ); + + // Char conjx to BLIS conjx conversion + conj_t blis_conjalpha; + testinghelpers::char_to_blis_conj( conja_alpha, &blis_conjalpha ); + + testinghelpers::ProtectedBuffer::start_signal_handler(); + try + { + // Invoking BLIS ukr. + // This will check for out of bounds access within first redzone. + ukr( blis_conjalpha, n, &alpha, x, incx, nullptr ); + + if ( is_memory_test ) + { + // Acquire the pointers near the second redzone. + x = ( T* )x_buffer.greenzone_2; + + // Copy the data for x accordingly + memcpy( x, x_ref, size_x ); + + // Invoking BLIS ukr to check with the second redzone. + ukr( blis_conjalpha, n, &alpha, x, incx, nullptr ); + } + } + catch(const std::exception& e) + { + // Reset to default signal handler + testinghelpers::ProtectedBuffer::stop_signal_handler(); + + // Show failure in case seg fault was detected + FAIL() << "Memory Test Failed"; + } + + // Reset to default signal handler + testinghelpers::ProtectedBuffer::stop_signal_handler(); + + // Invoking the reference implementation to get reference results. + if constexpr ( testinghelpers::type_info::is_complex && + testinghelpers::type_info::is_real ) + testinghelpers::ref_scalv( conja_alpha, n, alpha.real, x_ref, incx ); + else // if constexpr ( std::is_same::value ) + testinghelpers::ref_scalv( conja_alpha, n, alpha, x_ref, incx ); + + // Compute component-wise error. + computediff( "x", n, x, x_ref, incx, thresh ); +} + + +// Test-case logger : Used to print the test-case details based on parameters +template +class scalvUKRPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char conjx = std::get<1>(str.param); + gtint_t n = std::get<2>(str.param); + gtint_t incx = std::get<3>(str.param); + T1 alpha = std::get<4>(str.param); + bool is_memory_test = std::get<5>(str.param); + + std::string str_name = "_n_" + std::to_string(n); + str_name += "_conjx_" + std::string(&conjx, 1); + str_name += "_incx_" + testinghelpers::get_value_string(incx); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + str_name += ( is_memory_test ) ? "_mem_test_enabled" : "_mem_test_disabled"; + + return str_name; + } +}; diff --git a/gtestsuite/testsuite/ukr/scalv/zdscalv_ukr.cpp b/gtestsuite/testsuite/ukr/scalv/zdscalv_ukr.cpp new file mode 100644 index 0000000000..7f8f964725 --- /dev/null +++ b/gtestsuite/testsuite/ukr/scalv/zdscalv_ukr.cpp @@ -0,0 +1,265 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_scalv_ukr.h" + +class zdscalvGeneric : + public ::testing::TestWithParam> {}; // is_memory_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(zdscalvGeneric); + +// Tests using random integers as vector elements. +TEST_P( zdscalvGeneric, UKR ) +{ + using T = dcomplex; + using U = double; + + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + + // denotes the kernel to be tested: + zscalv_ker_ft ukr = std::get<0>(GetParam()); + // denotes whether alpha or conj(alpha) will be used: + char conj_alpha = std::get<1>(GetParam()); + // vector length: + gtint_t n = std::get<2>(GetParam()); + // stride size for x: + gtint_t incx = std::get<3>(GetParam()); + // alpha: + T alpha = std::get<4>(GetParam()); + // is_memory_test: + bool is_memory_test = std::get<5>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite scalv.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO() || alpha == testinghelpers::ONE()) + thresh = 0.0; + else + thresh = testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_scalv_ukr( ukr, conj_alpha, n, incx, alpha, thresh, is_memory_test ); +} + +// ---------------------------------------------- +// ----- Begin ZEN1/2/3 (AVX2) Kernel Tests ----- +// ---------------------------------------------- +#if defined(BLIS_KERNELS_ZEN) && defined(GTEST_AVX2FMA3) +// Tests for bli_zdscalv_zen_int10 (AVX2) kernel. +/** + * Loops: + * L30 - Main loop, handles 30 elements + * L24 - handles 24 elements + * L16 - handles 16 elements + * L8 - handles 8 elements + * L4 - handles 4 elements + * L2 - handles 2 elements + * LScalar - leftover loop (also handles non-unit increments) +*/ +INSTANTIATE_TEST_SUITE_P( + bli_zdscalv_zen_int10_unitPositiveStride, + zdscalvGeneric, + ::testing::Combine( + ::testing::Values(bli_zdscalv_zen_int10), + // conj(alpha): specify if alpha needs to be conjugated. + ::testing::Values( + 'n', + 'c' + ), + // m: size of vector. + ::testing::Values( + gtint_t(75), // L30x2, L8 upto LScalar + gtint_t(49), // L30, L16, L4, L2, LScalar + gtint_t(29), // L24, L4, LScalar + gtint_t(23), // L16 upto LScalar + gtint_t(30), // L30 + gtint_t(24), // L24 + gtint_t(16), // L16 + gtint_t( 8), // L8 + gtint_t( 4), // L4 + gtint_t( 2), // L2 + gtint_t( 1) // LScalar + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1) // unit stride + ), + // alpha: value of scalar. + ::testing::Values( + dcomplex{-5.1, -7.3}, + dcomplex{-1.0, -1.0}, + dcomplex{ 0.0, 0.0}, + dcomplex{ 1.0, 1.0}, // ZDSCAL is expected to return early for unit alpha. + dcomplex{ 7.3, 5.1} + ), + ::testing::Values(false, true) // is_memory_test + ), + (::scalvUKRPrint()) + ); + +INSTANTIATE_TEST_SUITE_P( + bli_zdscalv_zen_int10_nonUnitPositiveStride, + zdscalvGeneric, + ::testing::Combine( + ::testing::Values(bli_zdscalv_zen_int10), + // conj(alpha): specify if alpha needs to be conjugated. + ::testing::Values( + 'n', + 'c' + ), + // m: size of vector. + ::testing::Values( + gtint_t(3), gtint_t(30), gtint_t(112) + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(3), gtint_t(7) // few non-unit strides for sanity check + ), + // alpha: value of scalar. + ::testing::Values( + dcomplex{-5.1, -7.3}, + dcomplex{-1.0, -1.0}, + dcomplex{ 0.0, 0.0}, + dcomplex{ 1.0, 1.0}, // ZDSCAL is expected to return early for unit alpha. + dcomplex{ 7.3, 5.1} + ), + ::testing::Values(false, true) // is_memory_test + ), + (::scalvUKRPrint()) + ); +#endif +// ---------------------------------------------- +// ----- End ZEN1/2/3 (AVX2) Kernel Tests ----- +// ---------------------------------------------- + + +// ---------------------------------------------- +// ----- Begin ZEN4 (AVX512) Kernel Tests ----- +// ---------------------------------------------- +#if defined(BLIS_KERNELS_ZEN4) && defined(GTEST_AVX512) +// Tests for bli_zdscalv_zen_int_avx512 (AVX512) kernel. +/** + * Loops: + * L16 - Main loop, handles 16 elements + * L8 - handles 8 elements + * L4 - handles 4 elements + * L2 - handles 2 elements + * LScalar - leftover loop (also handles non-unit increments) +*/ +INSTANTIATE_TEST_SUITE_P( + bli_zdscalv_zen_int_avx512_unitPositiveStride, + zdscalvGeneric, + ::testing::Combine( + ::testing::Values(bli_zdscalv_zen_int_avx512), + // conj(alpha): specify if alpha needs to be conjugated. + ::testing::Values( + 'n', + 'c' + ), + // m: size of vector. + ::testing::Values( + gtint_t(47), // L16x2 upto LScalar + gtint_t(16), // L16 + gtint_t( 8), // L8 + gtint_t( 4), // L4 + gtint_t( 2), // L2 + gtint_t( 1) // LScalar + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1) // unit stride + ), + // alpha: value of scalar. + ::testing::Values( + dcomplex{-5.1, -7.3}, + dcomplex{-1.0, -1.0}, + dcomplex{ 0.0, 0.0}, + dcomplex{ 1.0, 1.0}, // ZDSCAL is expected to return early for unit alpha. + dcomplex{ 7.3, 5.1} + ), + ::testing::Values(false, true) // is_memory_test + ), + (::scalvUKRPrint()) + ); + +INSTANTIATE_TEST_SUITE_P( + bli_zdscalv_zen_int_avx512_nonUnitPositiveStrides, + zdscalvGeneric, + ::testing::Combine( + ::testing::Values(bli_zdscalv_zen_int_avx512), + // conj(alpha): specify if alpha needs to be conjugated. + ::testing::Values( + 'n', + 'c' + ), + // m: size of vector. + ::testing::Values( + gtint_t(3), gtint_t(30), gtint_t(112) + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(3), gtint_t(7) // few non-unit strides for sanity check + ), + // alpha: value of scalar. + ::testing::Values( + dcomplex{-5.1, -7.3}, + dcomplex{-1.0, -1.0}, + dcomplex{ 0.0, 0.0}, + dcomplex{ 1.0, 1.0}, // ZDSCAL is expected to return early for unit alpha. + dcomplex{ 7.3, 5.1} + ), + ::testing::Values(false, true) // is_memory_test + ), + (::scalvUKRPrint()) + ); +#endif +// ---------------------------------------------- +// ----- End ZEN4 (AVX512) Kernel Tests ----- +// ---------------------------------------------- diff --git a/gtestsuite/testsuite/ukr/scalv/zscalv_ukr.cpp b/gtestsuite/testsuite/ukr/scalv/zscalv_ukr.cpp new file mode 100644 index 0000000000..ade45336b4 --- /dev/null +++ b/gtestsuite/testsuite/ukr/scalv/zscalv_ukr.cpp @@ -0,0 +1,255 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_scalv_ukr.h" + +class zscalvGeneric : + public ::testing::TestWithParam> {}; // is_memory_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(zscalvGeneric); + +// Tests using random integers as vector elements. +TEST_P( zscalvGeneric, UKR ) +{ + using T = dcomplex; + + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + + // denotes the kernel to be tested: + zscalv_ker_ft ukr = std::get<0>(GetParam()); + // denotes whether alpha or conj(alpha) will be used: + char conj_alpha = std::get<1>(GetParam()); + // vector length: + gtint_t n = std::get<2>(GetParam()); + // stride size for x: + gtint_t incx = std::get<3>(GetParam()); + // alpha: + T alpha = std::get<4>(GetParam()); + // is_memory_test: + bool is_memory_test = std::get<5>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite scalv.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (n == 0) + thresh = 0.0; + else if (alpha == testinghelpers::ZERO() || alpha == testinghelpers::ONE()) + thresh = 0.0; + else + thresh = testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_scalv_ukr( ukr, conj_alpha, n, incx, alpha, thresh, is_memory_test ); +} + +// ---------------------------------------------- +// ----- Begin ZEN1/2/3 (AVX2) Kernel Tests ----- +// ---------------------------------------------- +#if defined(BLIS_KERNELS_ZEN) && defined(GTEST_AVX2FMA3) +// Tests for bli_zscalv_zen_int (AVX2) kernel. +/** + * Loops: + * L8 - Main loop, handles 8 elements + * L4 - handles 4 elements + * L2 - handles 2 elements + * LScalar - leftover loop (also handles non-unit increments) +*/ +INSTANTIATE_TEST_SUITE_P( + bli_zscalv_zen_int_unitPositiveStride, + zscalvGeneric, + ::testing::Combine( + ::testing::Values(bli_zscalv_zen_int), + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values('n' +#ifdef TEST_BLIS_TYPED + , 'c' // conjx +#endif + ), + // m: size of vector. + ::testing::Values( + gtint_t(16), // L8 (executed twice) + gtint_t(15), // L8 upto LScalar + gtint_t( 8), // L8 + gtint_t( 4), // L4 + gtint_t( 2), // L2 + gtint_t( 1) // LScalar + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1) // unit stride + ), + // alpha: value of scalar. + ::testing::Values( + dcomplex{-5.1, -7.3}, + dcomplex{ 0.0, 0.0}, + dcomplex{ 7.3, 5.1} + ), + ::testing::Values(false, true) // is_memory_test + ), + (::scalvUKRPrint()) + ); + +INSTANTIATE_TEST_SUITE_P( + bli_zscalv_zen_int_nonUnitPositiveStrides, + zscalvGeneric, + ::testing::Combine( + ::testing::Values(bli_zscalv_zen_int), + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values('n' +#ifdef TEST_BLIS_TYPED + , 'c' // conjx +#endif + ), + // m: size of vector. + ::testing::Values( + gtint_t(3), gtint_t(30), gtint_t(112) + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(3), gtint_t(7) // few non-unit strides for sanity check + ), + // alpha: value of scalar. + ::testing::Values( + dcomplex{-5.1, -7.3}, + dcomplex{ 0.0, 0.0}, + dcomplex{ 7.3, 5.1} + ), + ::testing::Values(false, true) // is_memory_test + ), + (::scalvUKRPrint()) + ); +#endif +// ---------------------------------------------- +// ----- End ZEN1/2/3 (AVX2) Kernel Tests ----- +// ---------------------------------------------- + + +// ---------------------------------------------- +// ----- Begin ZEN4 (AVX512) Kernel Tests ----- +// ---------------------------------------------- +#if defined(BLIS_KERNELS_ZEN4) && defined(GTEST_AVX512) +// Tests for bli_zscalv_zen_int_avx512 (AVX512) kernel. +/** + * Loops: + * L48 - Main loop, handles 48 elements + * L32 - Main loop, handles 32 elements + * L16 - Main loop, handles 16 elements + * L8 - Main loop, handles 8 elements + * L4 - handles 4 elements + * L2 - handles 2 elements + * LScalar - leftover loop (also handles non-unit increments) +*/ +INSTANTIATE_TEST_SUITE_P( + bli_zscalv_zen_int_avx512_unitPositiveStride, + zscalvGeneric, + ::testing::Combine( + ::testing::Values(bli_zscalv_zen_int_avx512), + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values('n' +#ifdef TEST_BLIS_TYPED + , 'c' // conjx +#endif + ), + // m: size of vector. + ::testing::Values( + gtint_t(143), // L48 x2 + L32 + L8 + L4 + L2 + LScalar + gtint_t(127), // L48 x2 + L16 + L8 + L4 + L2 + LScalar + gtint_t(48), // L48 + gtint_t(47), // L32 + L16 + L8 + L4 + L2 + LScalar + gtint_t(32), // L32 + gtint_t(16), // L16 + gtint_t( 8), // L8 + gtint_t( 4), // L4 + gtint_t( 2), // L2 + gtint_t( 1) // LScalar + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1) // unit stride + ), + // alpha: value of scalar. + ::testing::Values( + dcomplex{-5.1, -7.3}, + dcomplex{ 0.0, 0.0}, + dcomplex{ 7.3, 5.1} + ), + ::testing::Values(false, true) // is_memory_test + ), + (::scalvUKRPrint()) + ); + +INSTANTIATE_TEST_SUITE_P( + bli_zscalv_zen_int_avx512_nonUnitPositiveStrides, + zscalvGeneric, + ::testing::Combine( + ::testing::Values(bli_zscalv_zen_int_avx512), + // conj(alpha): uses n (no_conjugate) since it is real. + ::testing::Values('n' +#ifdef TEST_BLIS_TYPED + , 'c' // conjx +#endif + ), + // m: size of vector. + ::testing::Values( + gtint_t(3), gtint_t(30), gtint_t(112) + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(3), gtint_t(7) // few non-unit strides for sanity check + ), + // alpha: value of scalar. + ::testing::Values( + dcomplex{-5.1, -7.3}, + dcomplex{ 0.0, 0.0}, + dcomplex{ 7.3, 5.1} + ), + ::testing::Values(false, true) // is_memory_test + ), + (::scalvUKRPrint()) + ); +#endif diff --git a/gtestsuite/testsuite/ukr/setv/csetv_ukr.cpp b/gtestsuite/testsuite/ukr/setv/csetv_ukr.cpp new file mode 100644 index 0000000000..6aec8ad414 --- /dev/null +++ b/gtestsuite/testsuite/ukr/setv/csetv_ukr.cpp @@ -0,0 +1,146 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_setv_ukr.h" + +using T = scomplex; +using FT = csetv_ker_ft; + +class csetvGeneric : + public ::testing::TestWithParam> {}; // is_memory_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(csetvGeneric); + +// Tests using random integers as vector elements. +TEST_P( csetvGeneric, UKR ) +{ + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + FT ukr_fp = std::get<0>(GetParam()); + // denotes conjalpha + char conjalpha = std::get<1>(GetParam()); + // denotes alpha + T alpha = std::get<2>(GetParam()); + // vector length + gtint_t n = std::get<3>(GetParam()); + // stride size for x + gtint_t incx = std::get<4>(GetParam()); + // is_memory_test + bool is_memory_test = std::get<5>(GetParam()); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_setv_ukr( ukr_fp, conjalpha, alpha, n, incx, is_memory_test ); +} + +#if defined(BLIS_KERNELS_ZEN) && defined(GTEST_AVX2FMA3) +/* + Unit testing for functionality of bli_csetv_zen_int kernel. + The code structure for bli_csetv_zen_int( ... ) is as follows : + For unit strides : + Main loop : In blocks of 64 --> L64 + Fringe loops : In blocks of 32 --> L32 + In blocks of 16 --> L16 + In blocks of 8 --> L8 + In blocks of 4 --> L4 + Element-wise loop --> LScalar + + For non-unit strides : A single loop, to process element wise. +*/ +// Unit testing with unit strides, across all loops. +INSTANTIATE_TEST_SUITE_P( + bli_csetv_zen_int_unitStrides, + csetvGeneric, + ::testing::Combine( + ::testing::Values(bli_csetv_zen_int), + ::testing::Values('n' // conjalpha +#ifdef TEST_BLIS_TYPED + , 'c' +#endif + ), + ::testing::Values(scomplex{2.2, -1.8}), // alpha + ::testing::Values(// Testing the loops standalone + gtint_t(64), // for size n, L64 + gtint_t(32), // L32 + gtint_t(16), // L16 + gtint_t(8), // L8 + gtint_t(4), // L4 + gtint_t(3), // LScalar + // Testing the loops with combinations + // 5*L64 + gtint_t(320), + // 5*L64 + L32 + gtint_t(352), + // 5*L64 + L32 + L16 + gtint_t(368), + // 5*L64 + L32 + L16 + L8 + gtint_t(376), + // 5*L64 + L32 + L16 + L8 + L4 + gtint_t(380), + // 5*L64 + L32 + L16 + L8 + L4 + 3(LScalar) + gtint_t(383)), + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(false, true) // is_memory_test + ), + (::setvUkrPrint()) + ); + +// Unit testing with non-unit strides, across all loops. +INSTANTIATE_TEST_SUITE_P( + bli_csetv_zen_int_nonUnitStrides, + csetvGeneric, + ::testing::Combine( + ::testing::Values(bli_csetv_zen_int), + ::testing::Values('n' // conjalpha +#ifdef TEST_BLIS_TYPED + , 'c' +#endif + ), + ::testing::Values(scomplex{2.2, -1.8}), // alpha + ::testing::Values(gtint_t(25), gtint_t(37)), // size of the vector + ::testing::Values(gtint_t(5)), // stride size for x + ::testing::Values(false, true) // is_memory_test + ), + (::setvUkrPrint()) + ); +#endif diff --git a/gtestsuite/testsuite/ukr/setv/dsetv_ukr.cpp b/gtestsuite/testsuite/ukr/setv/dsetv_ukr.cpp new file mode 100644 index 0000000000..363498eff6 --- /dev/null +++ b/gtestsuite/testsuite/ukr/setv/dsetv_ukr.cpp @@ -0,0 +1,210 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_setv_ukr.h" + +using T = double; +using FT = dsetv_ker_ft; + +class dsetvGeneric : + public ::testing::TestWithParam> {}; // is_memory_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(dsetvGeneric); + +// Tests using random integers as vector elements. +TEST_P( dsetvGeneric, UKR ) +{ + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + FT ukr_fp = std::get<0>(GetParam()); + // denotes conjalpha + char conjalpha = std::get<1>(GetParam()); + // denotes alpha + T alpha = std::get<2>(GetParam()); + // vector length + gtint_t n = std::get<3>(GetParam()); + // stride size for x + gtint_t incx = std::get<4>(GetParam()); + // is_memory_test + bool is_memory_test = std::get<5>(GetParam()); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_setv_ukr( ukr_fp, conjalpha, alpha, n, incx, is_memory_test ); +} + +#if defined(BLIS_KERNELS_ZEN) && defined(GTEST_AVX2FMA3) +/* + Unit testing for functionality of bli_dsetv_zen_int kernel. + The code structure for bli_dsetv_zen_int( ... ) is as follows : + For unit strides : + Main loop : In blocks of 64 --> L64 + Fringe loops : In blocks of 32 --> L32 + In blocks of 16 --> L16 + In blocks of 8 --> L8 + In blocks of 4 --> L4 + Element-wise loop --> LScalar + + For non-unit strides : A single loop, to process element wise. +*/ +// Unit testing with Unit Strides(US), across all loops. +INSTANTIATE_TEST_SUITE_P( + bli_dsetv_zen_int_unitStrides, + dsetvGeneric, + ::testing::Combine( + ::testing::Values(bli_dsetv_zen_int), + ::testing::Values('n', 'c'), // conjalpha + ::testing::Values(double(2.2)), // alpha + ::testing::Values(// Testing the loops standalone + gtint_t(64), // size n, for L64 + gtint_t(32), // L32 + gtint_t(16), // L16 + gtint_t(8), // L8 + gtint_t(4), // L4 + gtint_t(3), // LScalar + // Testing the loops with combinations + // 5*L64 + gtint_t(320), + // 5*L64 + L32 + gtint_t(352), + // 5*L64 + L32 + L16 + gtint_t(368), + // 5*L64 + L32 + L16 + L8 + gtint_t(376), + // 5*L64 + L32 + L16 + L8 + L4 + gtint_t(380), + // 5*L64 + L32 + L16 + L8 + L4 + 3(LScalar) + gtint_t(383)), + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(false, true) // is_memory_test + ), + (::setvUkrPrint()) + ); + +// Unit testing with Non-Unit Strides(US), across all loops. +INSTANTIATE_TEST_SUITE_P( + bli_dsetv_zen_int_nonUnitStrides, + dsetvGeneric, + ::testing::Combine( + ::testing::Values(bli_dsetv_zen_int), + ::testing::Values('n', 'c'), // conjalpha + ::testing::Values(double(2.2)), // alpha + ::testing::Values(gtint_t(25), gtint_t(37)), // size of the vector + ::testing::Values(gtint_t(5)), // stride size for x + ::testing::Values(false, true) // is_memory_test + ), + (::setvUkrPrint()) + ); +#endif + +#if defined(BLIS_KERNELS_ZEN4) && defined(GTEST_AVX512) +/* + Unit testing for functionality of bli_dsetv_zen_int_avx512 kernel. + The code structure for bli_dsetv_zen_int_avx512( ... ) is as follows : + For unit strides : + Main loop : In blocks of 256 --> L256 + Fringe loops : In blocks of 128 --> L128 + In blocks of 64 --> L64 + In blocks of 32 --> L32 + In blocks of 16 --> L16 + In blocks of 8 --> L8 + In blocks of 4 --> L4 + Masked loop --> LScalar + + For non-unit strides : A single loop, to process element wise. +*/ +// Unit testing with Unit Strides(US), across all loops. +INSTANTIATE_TEST_SUITE_P( + bli_dsetv_zen_int_avx512_unitStrides, + dsetvGeneric, + ::testing::Combine( + ::testing::Values(bli_dsetv_zen_int_avx512), + ::testing::Values('n', 'c'), // conjalpha + ::testing::Values(double(2.2)), // alpha + ::testing::Values(// Testing the loops standalone + gtint_t(256), // size n, for L256 + gtint_t(128), // L128 + gtint_t(64), // L64 + gtint_t(32), // L32 + gtint_t(16), // L16 + gtint_t(8), // L8 + gtint_t(4), // L4 + gtint_t(3), // LScalar + // Testing the loops with combinations + // 2*L256 + gtint_t(512), + // 2*L256 + L128 + gtint_t(640), + // 2*L256 + L128 + L64 + gtint_t(704), + // 2*L256 + L128 + L64 + L32 + gtint_t(736), + // 2*L256 + L128 + L64 + L32 + L16 + gtint_t(752), + // 2*L256 + L128 + L64 + L32 + L16 + L8 + gtint_t(760), + // 2*L256 + L128 + L64 + L32 + L16 + L8 + L4 + gtint_t(764), + // 2*L256 + L128 + L64 + L32 + L16 + L8 + L4 + LScalar + gtint_t(767)), + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(false, true) // is_memory_test + ), + (::setvUkrPrint()) + ); + +// Unit testing with Non-Unit Strides(US), across all loops. +INSTANTIATE_TEST_SUITE_P( + bli_dsetv_zen_int_avx512_nonUnitStrides, + dsetvGeneric, + ::testing::Combine( + ::testing::Values(bli_dsetv_zen_int_avx512), + ::testing::Values('n', 'c'), // conjalpha + ::testing::Values(double(2.2)), // alpha + ::testing::Values(gtint_t(25), gtint_t(37)), // size of the vector + ::testing::Values(gtint_t(5)), // stride size for x + ::testing::Values(false, true) // is_memory_test + ), + (::setvUkrPrint()) + ); +#endif diff --git a/gtestsuite/testsuite/ukr/setv/ssetv_ukr.cpp b/gtestsuite/testsuite/ukr/setv/ssetv_ukr.cpp new file mode 100644 index 0000000000..823f62b4d3 --- /dev/null +++ b/gtestsuite/testsuite/ukr/setv/ssetv_ukr.cpp @@ -0,0 +1,206 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_setv_ukr.h" + +using T = float; +using FT = ssetv_ker_ft; + +class ssetvGeneric : + public ::testing::TestWithParam> {}; // is_memory_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ssetvGeneric); + +// Tests using random integers as vector elements. +TEST_P( ssetvGeneric, UKR ) +{ + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + FT ukr_fp = std::get<0>(GetParam()); + // denotes conjalpha + char conjalpha = std::get<1>(GetParam()); + // denotes alpha + T alpha = std::get<2>(GetParam()); + // vector length + gtint_t n = std::get<3>(GetParam()); + // stride size for x + gtint_t incx = std::get<4>(GetParam()); + // is_memory_test + bool is_memory_test = std::get<5>(GetParam()); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_setv_ukr( ukr_fp, conjalpha, alpha, n, incx, is_memory_test ); +} + +#if defined(BLIS_KERNELS_ZEN) && defined(GTEST_AVX2FMA3) +/* + Unit testing for functionality of bli_ssetv_zen_int kernel. + The code structure for bli_ssetv_zen_int( ... ) is as follows : + For unit strides : + Main loop : In blocks of 128 --> L128 + Fringe loops : In blocks of 64 --> L64 + In blocks of 32 --> L32 + In blocks of 16 --> L16 + In blocks of 8 --> L8 + Element-wise loop --> LScalar + + For non-unit strides : A single loop, to process element wise. +*/ +// Unit testing with unit strides, across all loops. +INSTANTIATE_TEST_SUITE_P( + bli_ssetv_zen_int_unitStrides, + ssetvGeneric, + ::testing::Combine( + ::testing::Values(bli_ssetv_zen_int), + ::testing::Values('n', 'c'), // conjalpha + ::testing::Values(float(1.2)), // alpha + ::testing::Values(// Testing the loops standalone + gtint_t(128), // for size n, L128 + gtint_t(64), // L64 + gtint_t(32), // L32 + gtint_t(16), // L16 + gtint_t(8), // L8 + gtint_t(7), // LScalar + // Testing the loops with combinations + // 2*L128 + gtint_t(256), + // 2*L128 + L64 + gtint_t(320), + // 2*L128 + L64 + L32 + gtint_t(352), + // 2*L128 + L64 + L32 + L16 + gtint_t(368), + // 2*L128 + L64 + L32 + L16 + L8 + gtint_t(376), + // 2*L128 + L64 + L32 + L16 + L8 + LScalar + gtint_t(383)), + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(false, true) // is_memory_test + ), + (::setvUkrPrint()) + ); + +// Unit testing with non-unit strides, across all loops. +INSTANTIATE_TEST_SUITE_P( + bli_ssetv_zen_int_nonUnitStrides, + ssetvGeneric, + ::testing::Combine( + ::testing::Values(bli_ssetv_zen_int), + ::testing::Values('n', 'c'), // conjalpha + ::testing::Values(float(1.2)), // alpha + ::testing::Values(gtint_t(25), gtint_t(37)), // size of the vector + ::testing::Values(gtint_t(5)), // stride size for x + ::testing::Values(false, true) // is_memory_test + ), + (::setvUkrPrint()) + ); +#endif + +#if defined(BLIS_KERNELS_ZEN4) && defined(GTEST_AVX512) +/* + Unit testing for functionality of bli_ssetv_zen_int_avx512 kernel. + The code structure for bli_ssetv_zen_int_avx512( ... ) is as follows : + For unit strides : + Main loop : In blocks of 512 --> L512 + Fringe loops : In blocks of 256 --> L256 + In blocks of 128 --> L128 + In blocks of 64 --> L64 + In blocks of 32 --> L32 + In blocks of 16 --> L16 + Masked loop --> LScalar + + For non-unit strides : A single loop, to process element wise. +*/ +// Unit testing with unit strides, across all loops. +INSTANTIATE_TEST_SUITE_P( + bli_ssetv_zen_int_avx512_unitStrides, + ssetvGeneric, + ::testing::Combine( + ::testing::Values(bli_ssetv_zen_int_avx512), + ::testing::Values('n', 'c'), // conjalpha + ::testing::Values(float(1.2)), // alpha + ::testing::Values(// Testing the loops standalone + gtint_t(512), // for size n, L512 + gtint_t(256), // L64 + gtint_t(128), // L128 + gtint_t(64), // L64 + gtint_t(32), // L32 + gtint_t(16), // L16 + gtint_t(15), // LScalar + // Testing the loops with combinations + // 2*L512 + gtint_t(1024), + // 2*L512 + L256 + gtint_t(1280), + // 2*L512 + L256 + L128 + gtint_t(1408), + // 2*L512 + L256 + L128 + L64 + gtint_t(1472), + // 2*L512 + L256 + L128 + L64 + L32 + gtint_t(1504), + // 2*L512 + L256 + L128 + L64 + L32 + L16 + gtint_t(1520), + // 2*L512 + L256 + L128 + L64 + L32 + L16 + LScalar + gtint_t(1535)), + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(false, true) // is_memory_test + ), + (::setvUkrPrint()) + ); + +// Unit testing with non-unit strides, across all loops. +INSTANTIATE_TEST_SUITE_P( + bli_ssetv_zen_int_avx512_nonUnitStrides, + ssetvGeneric, + ::testing::Combine( + ::testing::Values(bli_ssetv_zen_int_avx512), + ::testing::Values('n', 'c'), // conjalpha + ::testing::Values(float(1.2)), // alpha + ::testing::Values(gtint_t(25), gtint_t(37)), // size of the vector + ::testing::Values(gtint_t(5)), // stride size for x + ::testing::Values(false, true) // is_memory_test + ), + (::setvUkrPrint()) + ); +#endif diff --git a/gtestsuite/testsuite/ukr/setv/test_setv_ukr.h b/gtestsuite/testsuite/ukr/setv/test_setv_ukr.h new file mode 100644 index 0000000000..d60d829b51 --- /dev/null +++ b/gtestsuite/testsuite/ukr/setv/test_setv_ukr.h @@ -0,0 +1,158 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#pragma once + +#include +#include "level1/setv/setv.h" +#include "inc/check_error.h" +#include "common/testing_helpers.h" + +/** + * @brief Generic test body for copyv operation. + */ + +template +void test_setv_ukr( FT ukr_fp, char conjalpha, T alpha, gtint_t n, gtint_t incx, bool is_memory_test = false ) +{ + // Pointers to obtain the required memory. + T *x, *x_copy; + // Copying alpha to a local variable, since we pass by reference to kernel + T alpha_copy = alpha; + gtint_t size_x = testinghelpers::buff_dim( n, incx ) * sizeof( T ); + + // Create the object for the required operand + // The kernel does not expect the memory to be aligned + testinghelpers::ProtectedBuffer x_buffer( size_x, false, is_memory_test ); + + // For x_copy, we don't need different greenzones and any redzone. + // Thus, we pass is_memory_test as false + testinghelpers::ProtectedBuffer x_copy_buffer( size_x, false, false ); + + // Acquire the first greenzone for x + x = ( T* )x_buffer.greenzone_1; + x_copy = ( T* )x_copy_buffer.greenzone_1; // For x_copy, there is no greenzone_2 + + // Initialize the memory with random data + testinghelpers::datagenerators::randomgenerators( -10, 10, n, incx, x ); + + // Copying the contents of y to y_ref + memcpy( x_copy, x, size_x ); + + // Char conjalpha to BLIS conjalpha conversion + conj_t blis_conjalpha; + testinghelpers::char_to_blis_conj( conjalpha, &blis_conjalpha ); + + // Add signal handler for segmentation fault + testinghelpers::ProtectedBuffer::start_signal_handler(); + try + { + // Call the ukr function. + // This call is made irrespective of is_memory_test. + // This will check for out of bounds access with first redzone(if memory test is true) + // Else, it will just call the ukr function. + ukr_fp( blis_conjalpha, n, &alpha, x, incx, nullptr ); + + if ( is_memory_test ) + { + // Acquire the pointers near the second redzone + x = ( T* )x_buffer.greenzone_2; + + // Copy the data for x accordingly + memcpy( x, x_copy, size_x ); + + alpha = alpha_copy; + + // Call the ukr function, to check with the second redzone. + ukr_fp( blis_conjalpha, n, &alpha, x, incx, nullptr ); + } + } + catch(const std::exception& e) + { + // Reset to default signal handler + testinghelpers::ProtectedBuffer::stop_signal_handler(); + + // Show failure in case seg fault was detected + FAIL() << "Memory Test Failed"; + } + // Reset to default signal handler + testinghelpers::ProtectedBuffer::stop_signal_handler(); + + T alpha_ref = alpha_copy; +#ifdef TEST_BLIS_TYPED + if( testinghelpers::chkconj( conjalpha ) ) + { + alpha_ref = testinghelpers::conj( alpha_copy ); + } +#endif + + //---------------------------------------------------------- + // Reference computation + //---------------------------------------------------------- + gtint_t i, idx; + for( idx = 0 ; idx < n ; idx++ ) + { + i = (incx > 0) ? (idx * incx) : ( - ( n - idx - 1 ) * incx ); + x_copy[i] = alpha_ref; + } + + //---------------------------------------------------------- + // Compute component-wise error. + //---------------------------------------------------------- + computediff( "x", n, x, x_copy, incx ); +} + +// Test-case logger : Used to print the test-case details for unit testing the kernels. +// NOTE : The kernel name is the prefix in instantiator name, and thus is not printed +// with this logger. +template +class setvUkrPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char conjalpha = std::get<1>(str.param); + T alpha = std::get<2>(str.param); + gtint_t n = std::get<3>(str.param); + gtint_t incx = std::get<4>(str.param); + bool is_memory_test = std::get<5>(str.param); + + std::string str_name = ""; + str_name += "_n_" + std::to_string(n); + str_name += "_conjalpha_" + std::string(&conjalpha, 1); + str_name += "_incx_" + testinghelpers::get_value_string(incx); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + str_name += ( is_memory_test ) ? "_mem_test_enabled" : "_mem_test_disabled"; + return str_name; + } +}; diff --git a/gtestsuite/testsuite/ukr/setv/zsetv_ukr.cpp b/gtestsuite/testsuite/ukr/setv/zsetv_ukr.cpp new file mode 100644 index 0000000000..d0a97bfc98 --- /dev/null +++ b/gtestsuite/testsuite/ukr/setv/zsetv_ukr.cpp @@ -0,0 +1,226 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_setv_ukr.h" + +using T = dcomplex; +using FT = zsetv_ker_ft; + +class zsetvGeneric : + public ::testing::TestWithParam> {}; // is_memory_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(zsetvGeneric); + +// Tests using random integers as vector elements. +TEST_P( zsetvGeneric, UKR ) +{ + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + FT ukr_fp = std::get<0>(GetParam()); + // denotes conjalpha + char conjalpha = std::get<1>(GetParam()); + // denotes alpha + T alpha = std::get<2>(GetParam()); + // vector length + gtint_t n = std::get<3>(GetParam()); + // stride size for x + gtint_t incx = std::get<4>(GetParam()); + // is_memory_test + bool is_memory_test = std::get<5>(GetParam()); + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_setv_ukr( ukr_fp, conjalpha, alpha, n, incx, is_memory_test ); +} + +#if defined(BLIS_KERNELS_ZEN) && defined(GTEST_AVX2FMA3) +/* + Unit testing for functionality of bli_zsetv_zen_int kernel. + The code structure for bli_zsetv_zen_int( ... ) is as follows : + For unit strides : + Main loop : In blocks of 32 --> L32 + Fringe loops : In blocks of 16 --> L16 + In blocks of 8 --> L8 + In blocks of 4 --> L4 + In blocks of 2 --> L2 + Element-wise loop --> LScalar + + For non-unit strides : A single loop, to process element wise. +*/ +// Unit testing with unit strides, across all loops. +INSTANTIATE_TEST_SUITE_P( + bli_zsetv_zen_int_unitStrides, + zsetvGeneric, + ::testing::Combine( + ::testing::Values(bli_zsetv_zen_int), + ::testing::Values('n' // conjx +#ifdef TEST_BLIS_TYPED + , 'c' +#endif + ), + ::testing::Values(dcomplex{2.2, -1.8}), // alpha + ::testing::Values(// Testing the loops standalone + gtint_t(32), // for size n, L32 + gtint_t(16), // L16 + gtint_t(8), // L8 + gtint_t(4), // L4 + gtint_t(2), // L2 + gtint_t(1), // LScalar + // Testing the loops with combinations + // 5*L32 + gtint_t(160), + // 5*L32 + L16 + gtint_t(176), + // 5*L32 + L16 + L8 + gtint_t(184), + // 5*L32 + L16 + L8 + L4 + gtint_t(188), + // 5*L32 + L16 + L8 + L4 + L2 + gtint_t(190), + // 5*L32 + L16 + L8 + L4 + L2 + 1(LScalar) + gtint_t(191)), + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(false, true) // is_memory_test + ), + (::setvUkrPrint()) + ); + +// Unit testing with non-unit strides, across all loops. +INSTANTIATE_TEST_SUITE_P( + bli_zsetv_zen_int_nonUnitStrides, + zsetvGeneric, + ::testing::Combine( + ::testing::Values(bli_zsetv_zen_int), + ::testing::Values('n' // conjx +#ifdef TEST_BLIS_TYPED + , 'c' +#endif + ), + ::testing::Values(dcomplex{2.2, -1.8}), // alpha + ::testing::Values(gtint_t(25), gtint_t(37)), // size of the vector + ::testing::Values(gtint_t(5)), // stride size for x + ::testing::Values(false, true) // is_memory_test + ), + (::setvUkrPrint()) + ); +#endif + +#if defined(BLIS_KERNELS_ZEN4) && defined(GTEST_AVX512) +/* + Unit testing for functionality of bli_zsetv_zen_int_avx512 kernel. + The code structure for bli_zsetv_zen_int_avx512( ... ) is as follows : + For unit strides : + Main loop : In blocks of 128 --> L128 + Fringe loops : In blocks of 64 --> L64 + In blocks of 32 --> L32 + In blocks of 16 --> L16 + In blocks of 8 --> L8 + In blocks of 4 --> L4 + In blocks of 2 --> L2 + Masked loop --> LScalar + + For non-unit strides : A single loop, to process element wise. +*/ +// Unit testing with unit strides, across all loops. +INSTANTIATE_TEST_SUITE_P( + bli_zsetv_zen_int_avx512_unitStrides, + zsetvGeneric, + ::testing::Combine( + ::testing::Values(bli_zsetv_zen_int_avx512), + ::testing::Values('n' // conjx +#ifdef TEST_BLIS_TYPED + , 'c' +#endif + ), + ::testing::Values(dcomplex{2.2, -1.8}), // alpha + ::testing::Values(// Testing the loops standalone + gtint_t(128), // for size n, L128 + gtint_t(64), // L64 + gtint_t(32), // L32 + gtint_t(16), // L16 + gtint_t(8), // L8 + gtint_t(4), // L4 + gtint_t(2), // L2 + gtint_t(1), // LScalar + // Testing the loops with combinations + // 2*L128 + gtint_t(256), + // 2*L128 + L64 + gtint_t(320), + // 2*L128 + L64 + L32 + gtint_t(352), + // 2*L128 + L64 + L32 + L16 + gtint_t(368), + // 2*L128 + L64 + L32 + L16 + L8 + gtint_t(376), + // 2*L128 + L64 + L32 + L16 + L8 + L4 + gtint_t(380), + // 2*L128 + L64 + L32 + L16 + L8 + L4 + L2 + gtint_t(382), + // 2*L128 + L64 + L32 + L16 + L8 + L4 + L2 + LScalar + gtint_t(383)), + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(false, true) // is_memory_test + ), + (::setvUkrPrint()) + ); + +// Unit testing with non-unit strides, across all loops. +INSTANTIATE_TEST_SUITE_P( + bli_zsetv_zen_int_avx512_nonUnitStrides, + zsetvGeneric, + ::testing::Combine( + ::testing::Values(bli_zsetv_zen_int_avx512), + ::testing::Values('n' // conjx +#ifdef TEST_BLIS_TYPED + , 'c' +#endif + ), + ::testing::Values(dcomplex{2.2, -1.8}), // alpha + ::testing::Values(gtint_t(25), gtint_t(37)), // size of the vector + ::testing::Values(gtint_t(5)), // stride size for x + ::testing::Values(false, true) // is_memory_test + ), + (::setvUkrPrint()) + ); +#endif diff --git a/gtestsuite/testsuite/ukr/swapv/dswapv_ukr.cpp b/gtestsuite/testsuite/ukr/swapv/dswapv_ukr.cpp new file mode 100644 index 0000000000..95ed3868f0 --- /dev/null +++ b/gtestsuite/testsuite/ukr/swapv/dswapv_ukr.cpp @@ -0,0 +1,132 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_swapv_ukr.h" + +class dswapvGeneric : + public ::testing::TestWithParam> {}; // is_memory_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(dswapvGeneric); + +TEST_P( dswapvGeneric, UKR ) +{ + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes the kernel to be tested: + dswapv_ker_ft ukr = std::get<0>(GetParam()); + // vector length: + gtint_t n = std::get<1>(GetParam()); + // stride size for x: + gtint_t incx = std::get<2>(GetParam()); + // stride size for y: + gtint_t incy = std::get<3>(GetParam()); + // is_memory_test: + bool is_memory_test = std::get<4>(GetParam()); + + using T = double; + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_swapv_ukr( ukr, n, incx, incy, is_memory_test ); +} + +// ---------------------------------------------- +// ----- Begin ZEN1/2/3 (AVX2) Kernel Tests ----- +// ---------------------------------------------- +#if defined(BLIS_KERNELS_ZEN) && defined(GTEST_AVX2FMA3) + +// Tests for bli_dswapv_zen_int8 (AVX2) kernel. +// For unit inc on x and y: +// Optimised code is avialble for n = 32, 16, 8, 4 + +INSTANTIATE_TEST_SUITE_P( + UnitIncrements, + dswapvGeneric, + ::testing::Combine( + ::testing::Values(bli_dswapv_zen_int8), + // n: size of vector. + ::testing::Values( + gtint_t(1), gtint_t(2), gtint_t(4), gtint_t(8), gtint_t(16), gtint_t(32), + gtint_t(64), gtint_t(128), gtint_t(5), gtint_t(9), gtint_t(17), gtint_t(33), + gtint_t(65), gtint_t(129), gtint_t(6), gtint_t(10), gtint_t(18), gtint_t(34), + gtint_t(68), gtint_t(130), gtint_t(12), gtint_t(24), gtint_t(40), gtint_t(72), + gtint_t(136), gtint_t(20), gtint_t(36), gtint_t(96), gtint_t(160) + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1) + ), + // incy: stride of y vector. + ::testing::Values( + gtint_t(1) + ), + // is_memory_test + ::testing::Values(false, true) + ), + ::swapvUKRPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + NonUnitIncrements, + dswapvGeneric, + ::testing::Combine( + ::testing::Values(bli_dswapv_zen_int8), + // n: size of vector. + ::testing::Values( + gtint_t(1), + gtint_t(9), + gtint_t(55) + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(500) + ), + // incy: stride of y vector. + ::testing::Values( + gtint_t(500) + ), + // is_memory_test + ::testing::Values(false, true) + ), + ::swapvUKRPrint() + ); +#endif diff --git a/gtestsuite/testsuite/ukr/swapv/sswapv_ukr.cpp b/gtestsuite/testsuite/ukr/swapv/sswapv_ukr.cpp new file mode 100644 index 0000000000..4d1a5a9b6f --- /dev/null +++ b/gtestsuite/testsuite/ukr/swapv/sswapv_ukr.cpp @@ -0,0 +1,132 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_swapv_ukr.h" + +class sswapvGeneric : + public ::testing::TestWithParam> {}; // is_memory_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(sswapvGeneric); + +TEST_P( sswapvGeneric, UKR ) +{ + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes the kernel to be tested: + sswapv_ker_ft ukr = std::get<0>(GetParam()); + // vector length: + gtint_t n = std::get<1>(GetParam()); + // stride size for x: + gtint_t incx = std::get<2>(GetParam()); + // stride size for y: + gtint_t incy = std::get<3>(GetParam()); + // is_memory_test: + bool is_memory_test = std::get<4>(GetParam()); + + using T = float; + + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_swapv_ukr( ukr, n, incx, incy, is_memory_test ); +} + +// ---------------------------------------------- +// ----- Begin ZEN1/2/3 (AVX2) Kernel Tests ----- +// ---------------------------------------------- +#if defined(BLIS_KERNELS_ZEN) && defined(GTEST_AVX2FMA3) + +// Tests for bli_dswapv_zen_int8 (AVX2) kernel. +// For unit inc on x and y: +// When n values are 64, 32, 16, 8, 4 it is avx2 optimised + +INSTANTIATE_TEST_SUITE_P( + UnitIncrements, + sswapvGeneric, + ::testing::Combine( + ::testing::Values(bli_sswapv_zen_int8), + // n: size of vector. + ::testing::Values( + gtint_t(1), gtint_t(2), gtint_t(8), gtint_t(16), gtint_t(32), + gtint_t(64), gtint_t(128), gtint_t(9), gtint_t(17), gtint_t(33), + gtint_t(65), gtint_t(129), gtint_t(10), gtint_t(18), gtint_t(34), + gtint_t(68), gtint_t(130), gtint_t(24), gtint_t(40), gtint_t(72), + gtint_t(136), gtint_t(96), gtint_t(160) + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1) + ), + // incy: stride of y vector. + ::testing::Values( + gtint_t(1) + ), + // is_memory_test + ::testing::Values(false, true) + ), + ::swapvUKRPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + NonUnitIncrements, + sswapvGeneric, + ::testing::Combine( + ::testing::Values(bli_sswapv_zen_int8), + // n: size of vector. + ::testing::Values( + gtint_t(1), + gtint_t(9), + gtint_t(55) + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(500) + ), + // incy: stride of y vector. + ::testing::Values( + gtint_t(500) + ), + // is_memory_test + ::testing::Values(false, true) + ), + ::swapvUKRPrint() + ); +#endif diff --git a/gtestsuite/testsuite/ukr/swapv/test_swapv_ukr.h b/gtestsuite/testsuite/ukr/swapv/test_swapv_ukr.h new file mode 100644 index 0000000000..530d626ea9 --- /dev/null +++ b/gtestsuite/testsuite/ukr/swapv/test_swapv_ukr.h @@ -0,0 +1,136 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#pragma once + +#include + +#include "level1/swapv/swapv.h" +#include "inc/check_error.h" + +/** + * @brief Microkernel test body for swapv operation. + */ +template +static void test_swapv_ukr( FT ukr, gtint_t n, gtint_t incx, gtint_t incy, + bool is_memory_test = false ) +{ + // Obtain and allocate memory for vectors. + T *x, *y, *x_ref, *y_ref; + + gtint_t size_x = testinghelpers::buff_dim( n, incx ) * sizeof( T ); + gtint_t size_y = testinghelpers::buff_dim( n, incy ) * sizeof( T ); + + testinghelpers::ProtectedBuffer x_buffer( size_x, false, is_memory_test ); + testinghelpers::ProtectedBuffer y_buffer( size_y, false, is_memory_test ); + + // is_memory_test = false for x_ref & y_ref since we don't require + // different green or red zones. + testinghelpers::ProtectedBuffer x_ref_buffer( size_x, false, false ); + testinghelpers::ProtectedBuffer y_ref_buffer( size_y, false, false ); + + // Acquire the first set of greenzones for x. + x = ( T* )x_buffer.greenzone_1; + y = ( T* )y_buffer.greenzone_1; + + // There is no greenzone_2 for x_ref & y_ref + x_ref = ( T* )x_ref_buffer.greenzone_1; + y_ref = ( T* )y_ref_buffer.greenzone_1; + + // Initialize x with random data. + testinghelpers::datagenerators::randomgenerators( -100, 100, n, incx, x ); + testinghelpers::datagenerators::randomgenerators( 110, 200, n, incy, y ); + + // Copying x to x_ref & y to y_ref, for comparision after computation + memcpy( x_ref, x, size_x ); + memcpy( y_ref, y, size_y ); + + testinghelpers::ProtectedBuffer::start_signal_handler(); + try + { + // This will check for out of bounds access within first redzone. + swapv( n, x, incx, y, incy ); + + if ( is_memory_test ) + { + // Acquire the pointers near the second redzone. + x = ( T* )x_buffer.greenzone_2; + y = ( T* )y_buffer.greenzone_2; + + // Copy the data for x and y accordingly + memcpy( x, x_ref, size_x ); + memcpy( y, y_ref, size_y ); + + // Invoking ukr to check with the second redzone. + swapv( n, x, incx, y, incy ); + } + } + catch(const std::exception& e) + { + // Reset to default signal handler + testinghelpers::ProtectedBuffer::stop_signal_handler(); + + // Show failure in case seg fault was detected + FAIL() << "Memory Test Failed"; + } + + // Reset to default signal handler + testinghelpers::ProtectedBuffer::stop_signal_handler(); + + //---------------------------------------------------------- + // Compute binary comparison + //---------------------------------------------------------- + computediff( n, x, x_ref, y, y_ref, incx, incy, false ); + +} + + +// Test-case logger : Used to print the test-case details based on parameters +template +class swapvUKRPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + gtint_t n = std::get<1>(str.param); + gtint_t incx = std::get<2>(str.param); + gtint_t incy = std::get<3>(str.param); + bool is_memory_test = std::get<4>(str.param); + + std::string str_name = "_n_" + std::to_string(n); + str_name += "_incx_" + testinghelpers::get_value_string(incx); + str_name += "_incy_" + testinghelpers::get_value_string(incy); + str_name += ( is_memory_test ) ? "_mem_test_enabled" : "_mem_test_disabled"; + return str_name; + } +}; diff --git a/gtestsuite/testsuite/ukr/trsm/ctrsm/ctrsm_ukr.cpp b/gtestsuite/testsuite/ukr/trsm/ctrsm/ctrsm_ukr.cpp new file mode 100644 index 0000000000..7086e98840 --- /dev/null +++ b/gtestsuite/testsuite/ukr/trsm/ctrsm/ctrsm_ukr.cpp @@ -0,0 +1,113 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "common/testing_helpers.h" +#include "level3/ref_gemm.h" +#include "ukr/trsm/test_trsm_ukr.h" +#include "level3/trsm/test_trsm.h" + +class ctrsmGenericSmall : + public ::testing::TestWithParam> {}; // is_memory_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ctrsmGenericSmall); + +#ifndef BLIS_INT_ELEMENT_TYPE + +TEST_P( ctrsmGenericSmall, UKR ) +{ + using T = scomplex; + trsm_small_ker_ft ukr_fp = std::get<0>(GetParam()); + char side = std::get<1>(GetParam()); + char uploa = std::get<2>(GetParam()); + char diaga = std::get<3>(GetParam()); + char transa = std::get<4>(GetParam()); + gtint_t m = std::get<5>(GetParam()); + gtint_t n = std::get<6>(GetParam()); + T alpha = std::get<7>(GetParam()); + gtint_t lda = std::get<8>(GetParam()); + gtint_t ldb = std::get<9>(GetParam()); + bool is_memory_test = std::get<10>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite trsm.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (m == 0 || n == 0 || alpha == testinghelpers::ZERO()) + thresh = 0.0; + else + thresh = 3*m*testinghelpers::getEpsilon(); + + test_trsm_small_ukr( ukr_fp, side, uploa, diaga, transa, m, n, alpha, lda, ldb, thresh, is_memory_test, BLIS_SCOMPLEX); +} + +#if defined(BLIS_KERNELS_ZEN) && defined(GTEST_AVX2FMA3) +#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM +INSTANTIATE_TEST_SUITE_P ( + bli_trsm_small, + ctrsmGenericSmall, + ::testing::Combine( + ::testing::Values(bli_trsm_small), // ker_ptr + ::testing::Values('l', 'r'), // side + ::testing::Values('l', 'u'), // uplo + ::testing::Values('n', 'u'), // diaga + ::testing::Values('n', 'c', 't'), // transa + ::testing::Range(gtint_t(1), gtint_t(9), 1), // m + ::testing::Range(gtint_t(1), gtint_t(9), 1), // n + ::testing::Values(scomplex{-1.4, 3.2}, + scomplex{ 2.8, -0.5}, + scomplex{-1.4, 0.0}, + scomplex{ 0.0, -1.9}), // alpha + ::testing::Values(0, 10, 194), // lda_inc + ::testing::Values(0, 10, 194), // ldb_inc + ::testing::Values(false, true) // is_memory_test + ), + (::trsmSmallUKRPrint()) +); +#endif +#endif + +#endif // ifndef BLIS_INT_ELEMENT_TYPE diff --git a/gtestsuite/testsuite/ukr/trsm/dtrsm/dtrsm_ukr.cpp b/gtestsuite/testsuite/ukr/trsm/dtrsm/dtrsm_ukr.cpp new file mode 100644 index 0000000000..095a9cab7f --- /dev/null +++ b/gtestsuite/testsuite/ukr/trsm/dtrsm/dtrsm_ukr.cpp @@ -0,0 +1,241 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "common/testing_helpers.h" +#include "level3/ref_gemm.h" +#include "ukr/trsm/test_trsm_ukr.h" +#include "level3/trsm/test_trsm.h" + +class dtrsmGenericNat : + public ::testing::TestWithParam> {}; // is_memory_test + +class dtrsmGenericSmall : + public ::testing::TestWithParam> {}; // is_memory_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(dtrsmGenericNat); +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(dtrsmGenericSmall); + +TEST_P( dtrsmGenericNat, native_kernel) +{ + using T = double; + dgemmtrsm_ukr_ft ukr_fp = std::get<0>(GetParam()); + char storage = std::get<1>(GetParam()); + char uploa = std::get<2>(GetParam()); + char diaga = std::get<3>(GetParam()); + gtint_t m = std::get<4>(GetParam()); + gtint_t n = std::get<5>(GetParam()); + gtint_t k = std::get<6>(GetParam()); + T alpha = std::get<7>(GetParam()); + gtint_t ldc = std::get<8>(GetParam()); + bool is_memory_test = std::get<9>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite trsm.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (m == 0 || n == 0 || alpha == testinghelpers::ZERO()) + thresh = 0.0; + else + thresh = 3*m*testinghelpers::getEpsilon(); + + test_trsm_ukr( ukr_fp, storage, uploa, diaga, m, n, k, alpha, ldc, thresh, is_memory_test ); +} + +TEST_P( dtrsmGenericSmall, small_kernel) +{ + using T = double; + trsm_small_ker_ft ukr_fp = std::get<0>(GetParam()); + char side = std::get<1>(GetParam()); + char uploa = std::get<2>(GetParam()); + char diaga = std::get<3>(GetParam()); + char transa = std::get<4>(GetParam()); + gtint_t m = std::get<5>(GetParam()); + gtint_t n = std::get<6>(GetParam()); + T alpha = std::get<7>(GetParam()); + gtint_t lda = std::get<8>(GetParam()); + gtint_t ldb = std::get<9>(GetParam()); + bool is_memory_test = std::get<10>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite trsm.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (m == 0 || n == 0 || alpha == testinghelpers::ZERO()) + thresh = 0.0; + else + thresh = 3*m*testinghelpers::getEpsilon(); + + test_trsm_small_ukr( ukr_fp, side, uploa, diaga, transa, m, n, alpha, lda, ldb, thresh, is_memory_test, BLIS_DOUBLE); +} + +#if defined(BLIS_KERNELS_ZEN4) && defined(GTEST_AVX512) +INSTANTIATE_TEST_SUITE_P ( + bli_dgemmtrsm_l_zen4_asm_8x24, + dtrsmGenericNat, + ::testing::Combine( + ::testing::Values(bli_dgemmtrsm_l_zen4_asm_8x24), // ker_ptr + ::testing::Values('c', 'r', 'g'), // stor + ::testing::Values('l'), // uplo + ::testing::Values('u', 'n'), // diaga + ::testing::Values(8), // m + ::testing::Values(24), // n + ::testing::Values(0, 1, 2, 8, 9, 10, 500, 1000), // k + ::testing::Values(-1, -5.2, 1, 8.9), // alpha + ::testing::Values(0, 9, 53), // ldc_inc + ::testing::Values(false, true) // is_memory_test + ), + (::trsmNatUKRPrint()) +); + +INSTANTIATE_TEST_SUITE_P ( + bli_dgemmtrsm_u_zen4_asm_8x24, + dtrsmGenericNat, + ::testing::Combine( + ::testing::Values(bli_dgemmtrsm_u_zen4_asm_8x24), // ker_ptr + ::testing::Values('c', 'r', 'g'), // stor + ::testing::Values('u'), // uplo + ::testing::Values('u', 'n'), // diaga + ::testing::Values(8), // m + ::testing::Values(24), // n + ::testing::Values(0, 1, 2, 8, 9, 10, 500, 1000), // k + ::testing::Values(-1, -5.2, 1, 8.9), // alpha + ::testing::Values(0, 9, 53), // ldc_inc + ::testing::Values(false, true) // is_memory_test + ), + (::trsmNatUKRPrint()) +); + +INSTANTIATE_TEST_SUITE_P ( + bli_trsm_small_AVX512, + dtrsmGenericSmall, + ::testing::Combine( + ::testing::Values(bli_trsm_small_AVX512), // ker_ptr + ::testing::Values('l', 'r'), // side + ::testing::Values('l', 'u'), // uplo + ::testing::Values('n', 'u'), // diaga + ::testing::Values('n', 't'), // transa + ::testing::Range(gtint_t(1), gtint_t(9), 1), // m + ::testing::Range(gtint_t(1), gtint_t(9), 1), // n + ::testing::Values(-3, 3), // alpha + ::testing::Values(0, 10), // lda_inc + ::testing::Values(0, 10), // ldb_inc + ::testing::Values(false, true) // is_memory_test + ), + (::trsmSmallUKRPrint()) +); +#endif + + +#if defined(BLIS_KERNELS_HASWELL) && defined(GTEST_AVX2FMA3) +INSTANTIATE_TEST_SUITE_P ( + bli_dgemmtrsm_l_haswell_asm_6x8, + dtrsmGenericNat, + ::testing::Combine( + ::testing::Values(bli_dgemmtrsm_l_haswell_asm_6x8), // ker_ptr + ::testing::Values('c', 'r', 'g'), // stor + ::testing::Values('l'), // uplo + ::testing::Values('u', 'n'), // diaga + ::testing::Values(6), // m + ::testing::Values(8), // n + ::testing::Values(0, 1, 2, 8, 9, 10, 500, 1000), // k + ::testing::Values(-1, -5.2, 1, 8.9), // alpha + ::testing::Values(0, 9, 53), // ldc_inc + ::testing::Values(false, true) // is_memory_test + ), + (::trsmNatUKRPrint()) +); + +INSTANTIATE_TEST_SUITE_P ( + bli_dgemmtrsm_u_haswell_asm_6x8, + dtrsmGenericNat, + ::testing::Combine( + ::testing::Values(bli_dgemmtrsm_u_haswell_asm_6x8), // ker_ptr + ::testing::Values('c', 'r', 'g'), // stor + ::testing::Values('u'), // uplo + ::testing::Values('u', 'n'), // diaga + ::testing::Values(6), // m + ::testing::Values(8), // n + ::testing::Values(0, 1, 2, 8, 9, 10, 500, 1000), // k + ::testing::Values(-1, -5.2, 1, 8.9), // alpha + ::testing::Values(0, 9, 53), // ldc_inc + ::testing::Values(false, true) // is_memory_test + ), + (::trsmNatUKRPrint()) +); +#endif + +#if defined(BLIS_KERNELS_ZEN) && defined(GTEST_AVX2FMA3) +#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM +INSTANTIATE_TEST_SUITE_P ( + bli_trsm_small, + dtrsmGenericSmall, + ::testing::Combine( + ::testing::Values(bli_trsm_small), // ker_ptr + ::testing::Values('l', 'r'), // side + ::testing::Values('l', 'u'), // uplo + ::testing::Values('n', 'u'), // diaga + ::testing::Values('n', 't'), // transa + ::testing::Range(gtint_t(1), gtint_t(9), 1), // m + ::testing::Range(gtint_t(1), gtint_t(9), 1), // n + ::testing::Values(-3, 3), // alpha + ::testing::Values(0, 10), // lda_inc + ::testing::Values(0, 10), // ldb_inc + ::testing::Values(false, true) // is_memory_test + ), + (::trsmSmallUKRPrint()) +); +#endif +#endif diff --git a/gtestsuite/testsuite/ukr/trsm/strsm/strsm_ukr.cpp b/gtestsuite/testsuite/ukr/trsm/strsm/strsm_ukr.cpp new file mode 100644 index 0000000000..ff88d433bb --- /dev/null +++ b/gtestsuite/testsuite/ukr/trsm/strsm/strsm_ukr.cpp @@ -0,0 +1,183 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "common/testing_helpers.h" +#include "level3/ref_gemm.h" +#include "ukr/trsm/test_trsm_ukr.h" +#include "level3/trsm/test_trsm.h" + +class strsmGenericNat : + public ::testing::TestWithParam> {}; // is_memory_test + +class strsmGenericSmall : + public ::testing::TestWithParam> {}; // is_memory_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(strsmGenericNat); +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(strsmGenericSmall); + +TEST_P( strsmGenericNat, UKR ) +{ + using T = float; + sgemmtrsm_ukr_ft ukr_fp = std::get<0>(GetParam()); + char storage = std::get<1>(GetParam()); + char uploa = std::get<2>(GetParam()); + char diaga = std::get<3>(GetParam()); + gtint_t m = std::get<4>(GetParam()); + gtint_t n = std::get<5>(GetParam()); + gtint_t k = std::get<6>(GetParam()); + T alpha = std::get<7>(GetParam()); + gtint_t ldc = std::get<8>(GetParam()); + bool is_memory_test = std::get<9>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite trsm.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (m == 0 || n == 0 || alpha == testinghelpers::ZERO()) + thresh = 0.0; + else + thresh = 3*m*testinghelpers::getEpsilon(); + + test_trsm_ukr( ukr_fp, storage, uploa, diaga, m, n, k, alpha, ldc, thresh, is_memory_test); +} + +TEST_P( strsmGenericSmall, UKR ) +{ + using T = float; + trsm_small_ker_ft ukr_fp = std::get<0>(GetParam()); + char side = std::get<1>(GetParam()); + char uploa = std::get<2>(GetParam()); + char diaga = std::get<3>(GetParam()); + char transa = std::get<4>(GetParam()); + gtint_t m = std::get<5>(GetParam()); + gtint_t n = std::get<6>(GetParam()); + T alpha = std::get<7>(GetParam()); + gtint_t lda = std::get<8>(GetParam()); + gtint_t ldb = std::get<9>(GetParam()); + bool is_memory_test = std::get<10>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite trsm.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (m == 0 || n == 0 || alpha == testinghelpers::ZERO()) + thresh = 0.0; + else + thresh = 3*m*testinghelpers::getEpsilon(); + + test_trsm_small_ukr( ukr_fp, side, uploa, diaga, transa, m, n, alpha, lda, ldb, thresh, is_memory_test, BLIS_FLOAT); +} + +#if defined(BLIS_KERNELS_HASWELL) && defined(GTEST_AVX2FMA3) +INSTANTIATE_TEST_SUITE_P ( + bli_sgemmtrsm_l_haswell_asm_6x16, + strsmGenericNat, + ::testing::Combine( + ::testing::Values(bli_sgemmtrsm_l_haswell_asm_6x16), // ker_ptr + ::testing::Values('c', 'r', 'g'), // stor + ::testing::Values('l'), // uplo + ::testing::Values('u', 'n'), // diaga + ::testing::Values(6), // m + ::testing::Values(16), // n + ::testing::Values(0, 1, 2, 8, 9, 10, 500, 1000), // k + ::testing::Values(-1, -5.2, 1, 8.9), // alpha + ::testing::Values(0, 9, 53), // ldc + ::testing::Values(false, true) // is_memory_test + ), + (::trsmNatUKRPrint()) +); + +INSTANTIATE_TEST_SUITE_P ( + bli_sgemmtrsm_u_haswell_asm_6x16, + strsmGenericNat, + ::testing::Combine( + ::testing::Values(bli_sgemmtrsm_u_haswell_asm_6x16), // ker_ptr + ::testing::Values('c', 'r', 'g'), // stor + ::testing::Values('u'), // uplo + ::testing::Values('u', 'n'), // diaga + ::testing::Values(6), // m + ::testing::Values(16), // n + ::testing::Values(0, 1, 2, 8, 9, 10, 500, 1000), // k + ::testing::Values(-1, -5.2, 1, 8.9), // alpha + ::testing::Values(0, 9, 53), // ldc + ::testing::Values(false, true) // is_memory_test + ), + (::trsmNatUKRPrint()) +); +#endif + +#if defined(BLIS_KERNELS_ZEN) && defined(GTEST_AVX2FMA3) +#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM +INSTANTIATE_TEST_SUITE_P ( + bli_trsm_small, + strsmGenericSmall, + ::testing::Combine( + ::testing::Values(bli_trsm_small), // ker_ptr + ::testing::Values('l', 'r'), // side + ::testing::Values('l', 'u'), // uplo + ::testing::Values('n', 'u'), // diaga + ::testing::Values('n', 't'), // transa + ::testing::Range(gtint_t(1), gtint_t(17), 1), // m + ::testing::Range(gtint_t(1), gtint_t(17), 1), // n + ::testing::Values(-3, 3), // alpha + ::testing::Values(0, 10), // lda_inc + ::testing::Values(0, 10), // ldb_inc + ::testing::Values(false, true) // is_memory_test + ), + (::trsmSmallUKRPrint()) +); +#endif +#endif diff --git a/gtestsuite/testsuite/ukr/trsm/test_trsm_ukr.h b/gtestsuite/testsuite/ukr/trsm/test_trsm_ukr.h new file mode 100644 index 0000000000..4acd3affb1 --- /dev/null +++ b/gtestsuite/testsuite/ukr/trsm/test_trsm_ukr.h @@ -0,0 +1,494 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#pragma once + +#include +#include "blis.h" +#include "level3/trsm/trsm.h" +#include "level3/ref_trsm.h" +#include "inc/check_error.h" +#include "common/testing_helpers.h" +#include "level3/trsm/test_trsm.h" + + +// function pointer for TRSM small kernels +typedef err_t (*trsm_small_ker_ft) +( + side_t side, + obj_t* alpha, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl, + bool is_parallel +); + +/* +* Function to test gemmtrsm ukr +*/ +template +static void test_trsm_ukr( FT ukr_fp, char storage, char uploa, char diaga, + gtint_t m, gtint_t n, gtint_t k, T alpha, + gtint_t ldc_inc, double thresh, bool is_memory_test) +{ + gtint_t lda = m, ldb = n; + gtint_t ldc = ldc_inc; + + + // Allocate memory for A10(k*lda) and A11(m*lda) + testinghelpers::ProtectedBuffer a10_buffer( (k+m) * lda * sizeof(T), false, is_memory_test ); + // Allocate aligned memory for B01(k*ldb) and B11(m*ldb) + testinghelpers::ProtectedBuffer b01_buffer( (k+m) * ldb * sizeof(T), true , is_memory_test ); + + + T* a10 = (T*)a10_buffer.greenzone_1; // column major + T* b01 = (T*)b01_buffer.greenzone_1; // row major + + // Initialize vectors with random numbers. + random_generator_with_INF_NAN( a10, uploa, 'c', 'n', -0.1, 0.1, m, (k+m), lda); + random_generator_with_INF_NAN( b01, uploa, 'r', 'n', -0.1, 0.1, (k+m), n, ldb); + + // Get A11(A10 + sizeof(A01)) and B11(B10 + sizeof(B10)) + T* a11 = a10 + (k*lda); + T* b11 = b01 + (k*ldb); + + // make A11 triangular for trsm + testinghelpers::make_triangular( 'c', uploa, m, a11, lda ); + + T* c, *c_ref, *b11_copy; + gtint_t rs_c, cs_c, rs_c_ref, cs_c_ref; + gtint_t size_c, size_c_ref; + + // allocate memory for C according to the storage scheme + if (storage == 'r' || storage == 'R') + { + ldc += n; + rs_c = ldc; + cs_c = 1; + rs_c_ref = rs_c; + cs_c_ref = cs_c; + size_c = ldc * m * sizeof(T); + size_c_ref = size_c; + } + else if (storage == 'c' || storage == 'C') + { + ldc += m; + rs_c = 1; + cs_c = ldc; + rs_c_ref = rs_c; + cs_c_ref = cs_c; + size_c = ldc * n * sizeof(T); + size_c_ref = size_c; + } + else // general storage + { + ldc += m; + + // reference does not support general stride, therefore + // reference is set as column major + rs_c_ref = 1, + cs_c_ref = ldc; + + // for general stride, rs_c and cs_c both are non unit stride + // ldc is used to derieve both rs_c and cs_c + rs_c = ldc; + cs_c = ldc*ldc; + size_c = ldc * n * ldc * sizeof(T); + size_c_ref = ldc * n * 1 * sizeof(T); + } + + // get memory for C and c_ref + testinghelpers::ProtectedBuffer c_buffer(size_c, false, is_memory_test); + c = (T*)c_buffer.greenzone_1; + c_ref = (T*)malloc( size_c_ref ); + + // set c buffers to zero to ensure the unused region of C matrix (extra ldb) is zero + memset(c, 0, size_c); + memset(c_ref, 0, size_c_ref); + + // copy contents of B11 to C and C_ref + for (gtint_t i = 0; i < m; ++i) + { + for (gtint_t j = 0; j < n; ++j) + { + c[j*cs_c + i*rs_c] = b11[i*ldb + j]; + c_ref[j*cs_c_ref + i*rs_c_ref] = b11[i*ldb + j]; + } + } + + // Make A11 diagonal dominant in order to make sure that + // input matrics are solvable + // In case BLIS_ENABLE_TRSM_PREINVERSION is enabled, + // diagonal elements of A11 have to be inverted twice, + // once for making it diagonal dominant, and once for packing with + // inversion, inverting it twice is equivalent to not inverting it at all. + // Therefore, in case of BLIS_ENABLE_TRSM_PREINVERSION, diagonal elements + // of A11 are not inverted. +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + for (gtint_t i =0;i< m; i++) + { + a11[i+i*lda] = T{1} / a11[i+i*lda]; + } +#endif + + // If A is unit diagonal, set diagonal elements of A11 to 1 + if (diaga == 'u' || diaga == 'U') + { + for (gtint_t i =0;i< m; i++) + { + a11[i+i*lda] = T{1}; + } + } + + // add signal handler for segmentation fault + testinghelpers::ProtectedBuffer::start_signal_handler(); + try + { + if ( is_memory_test ) + { + // calling gemmtrsm ukr will modify b11 buffer + // create a copy of B11 so that it can be restored + // for the second call of gemmtrsm ukr + b11_copy = (T*)malloc( m*ldb*sizeof(T) ); + memcpy(b11_copy, b11, m*ldb*sizeof(T)); + } + + // Call ukr + ukr_fp + ( + k, + &alpha, + a10, a11, + b01, b11, + c, + rs_c, cs_c, + nullptr, nullptr + ); + if ( is_memory_test ) + { + // set pointers to second buffer + c = (T*)c_buffer.greenzone_2; + a10 = (T*)a10_buffer.greenzone_2; + b01 = (T*)b01_buffer.greenzone_2; + a11 = a10 + (k*lda); + b11 = b01 + (k*ldb); + + // copy data from 1st buffer of A and B to second buffer + memcpy(a10, a10_buffer.greenzone_1, (k+m) * lda * sizeof(T)); + memcpy(b01, b01_buffer.greenzone_1, k * ldb * sizeof(T)); + + memset(c, 0, size_c); + // restore B11 and copy contents of B11 to C + for (gtint_t i = 0; i < m; ++i) + { + for (gtint_t j = 0; j < n; ++j) + { + b11[i*ldb + j] = b11_copy[i*ldb + j]; + c[j*cs_c + i*rs_c] = b11_copy[i*ldb + j]; + } + } + // free b11_copy + free(b11_copy); + + // second call to ukr + ukr_fp( k, &alpha, a10, a11, b01, b11, c, rs_c, cs_c, nullptr, nullptr ); + } + } + catch(const std::exception& e) + { + // reset to default signal handler + testinghelpers::ProtectedBuffer::stop_signal_handler(); + + // show failure in case seg fault was detected + FAIL() << "Memory Test Failed"; + } + // reset to default signal handler + testinghelpers::ProtectedBuffer::stop_signal_handler(); + + +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + // compensate for the trsm per-inversion + for (gtint_t i =0;i< m; i++) + { + a11[i+i*lda] = T{1.0} / a11[i+i*lda]; + } +#endif + + // Call reference implementation to get ref results. + if (storage == 'c' || storage == 'C') + { + testinghelpers::ref_gemm( storage, 'n', 't', m, n, k, T{-1}, + a10, lda, b01, ldb, alpha, c_ref, ldc); + testinghelpers::ref_trsm( storage, 'l', uploa, 'n', diaga, m, n, T{1}, a11, + lda, c_ref, ldc ); + } + else if (storage == 'r' || storage == 'R')// row major + { + testinghelpers::ref_gemm( storage, 't', 'n', m, n, k, T{-1}, + a10, lda, b01, ldb, alpha, c_ref, ldc); + + // convert col major A11 to row Major for TRSM + T temp = T{0}; + for(gtint_t i = 0; i < m; ++i) + { + for(gtint_t j = i; j< m; ++j) + { + temp = a11[i+j*lda]; + a11[i+j*lda] = a11[j+i*lda]; + a11[j+i*lda] = temp; + } + } + + testinghelpers::ref_trsm( storage, 'l', uploa, 'n', diaga, m, n, T{1}, a11, + lda, c_ref, ldc ); + } + else + { + testinghelpers::ref_gemm( 'c', 'n', 't', m, n, k, T{-1}, + a10, lda, b01, ldb, alpha, c_ref, ldc); + testinghelpers::ref_trsm( 'c', 'l', uploa, 'n', diaga, m, n, T{1}, a11, + lda, c_ref, ldc ); + + // there is no equivalent blas call for gen storage, + // in order to compare the gen stored C and column major stored + // create a column major copy of C + T* c_gs = (T*)malloc( ldc * n * 1 * sizeof(T) ); + memset(c_gs, 0, ldc * n * 1 * sizeof(T)); + + for (gtint_t i = 0; i < m; ++i) + { + for (gtint_t j = 0; j < n; ++j) + { + c_gs[i*rs_c_ref + j*cs_c_ref] = c[i*rs_c + j*cs_c]; + } + } + + c = c_gs; + } + + // Compute component-wise error. + computediff( "C", storage, m, n, c, c_ref, ldc, thresh ); + + if(storage != 'r' && storage != 'R' && storage != 'c' && storage != 'C') + { + // free c_gs in case of general stride + free(c); + } + + // free buffers + free(c_ref); +} + +template +static void test_trsm_small_ukr( FT ukr_fp, char side, char uploa, char diaga, + char transa, gtint_t m, gtint_t n, T alpha, gtint_t lda, + gtint_t ldb, double thresh, bool is_memory_test, num_t dt) +{ + // create blis objects + obj_t ao = BLIS_OBJECT_INITIALIZER; + obj_t bo = BLIS_OBJECT_INITIALIZER; + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; + + inc_t rs_a = 1; + inc_t cs_a = lda; + inc_t rs_b = 1; + inc_t cs_b = ldb; + + side_t blis_side; + uplo_t blis_uploa; + trans_t blis_transa; + diag_t blis_diaga; + dim_t m0, n0; + dim_t mn0_a; + bli_convert_blas_dim1( m, m0 ); + bli_convert_blas_dim1( n, n0 ); + + bli_param_map_netlib_to_blis_side( side, &blis_side ); + bli_param_map_netlib_to_blis_uplo( uploa, &blis_uploa ); + bli_param_map_netlib_to_blis_trans( transa, &blis_transa ); + bli_param_map_netlib_to_blis_diag( diaga, &blis_diaga ); + + bli_set_dim_with_side( blis_side, m0, n0, &mn0_a ); + bli_obj_init_finish_1x1( dt, (T*)&alpha, &alphao ); + + cs_a += mn0_a; + cs_b += m; + + // Allocate memory for A (col major) + testinghelpers::ProtectedBuffer a_buf( mn0_a * cs_a * sizeof(T), false, is_memory_test ); + // Allocate memory for B (col major) + testinghelpers::ProtectedBuffer b_buf( n * cs_b * sizeof(T), false, is_memory_test ); + + T* a = (T*)a_buf.greenzone_1; + T* b = (T*)b_buf.greenzone_1; + T* b_ref = (T*)malloc( n * cs_b * sizeof(T) ); // col major + + // Initialize buffers with random numbers. + random_generator_with_INF_NAN( a, uploa, 'c', 'n', -0.1, 0.1, mn0_a, mn0_a, cs_a); + random_generator_with_INF_NAN( b, uploa, 'c', 'n', -0.1, 0.1, m, n, cs_b); + + // copy contents of b to b_ref + memcpy(b_ref, b, n * cs_b * sizeof(T)); + + // make A triangular + testinghelpers::make_triangular( 'c', uploa, mn0_a, a, cs_a ); + + // Make A11 diagonal dominant in order to make sure that + // input matrics are solvable + for (gtint_t i = 0; i < mn0_a; i++) + { + a[i+i*cs_a] = T{1} / a[i+i*cs_a]; + } + + bli_obj_init_finish( dt, mn0_a, mn0_a, (T*)a, rs_a, cs_a, &ao ); + bli_obj_init_finish( dt, m0, n0, (T*)b, rs_b, cs_b, &bo ); + + const struc_t struca = BLIS_TRIANGULAR; + + bli_obj_set_uplo( blis_uploa, &ao ); + bli_obj_set_diag( blis_diaga, &ao ); + bli_obj_set_conjtrans( blis_transa, &ao ); + bli_obj_set_struc( struca, &ao ); + + // add signal handler for segmentation fault + testinghelpers::ProtectedBuffer::start_signal_handler(); + try + { + // call trsm small kernel + ukr_fp(blis_side, &alphao, &ao, &bo, NULL, NULL, false); + if ( is_memory_test ) + { + // set A and B pointers to second buffer + a = (T*)a_buf.greenzone_2; + b = (T*)b_buf.greenzone_2; + + // copy data from first buffers of A and B to second buffer + memcpy(b, b_ref, n * cs_b * sizeof(T)); + memcpy(a, (T*)a_buf.greenzone_1, mn0_a * cs_a * sizeof(T)); + bli_obj_init_finish( dt, m0, n0, (T*)b, rs_b, cs_b, &bo ); + bli_obj_init_finish( dt, mn0_a, mn0_a, (T*)a, rs_a, cs_a, &ao ); + + // call trsm small kernel + ukr_fp(blis_side, &alphao, &ao, &bo, NULL, NULL, false); + } + } + catch(const std::exception& e) + { + // reset to default signal handler + testinghelpers::ProtectedBuffer::stop_signal_handler(); + + // show failure in case seg fault was detected + FAIL() << "Memory Test Failed"; + } + // reset to default signal handler + testinghelpers::ProtectedBuffer::stop_signal_handler(); + + // call to reference trsm + testinghelpers::ref_trsm( 'c', side, uploa, transa, diaga, m, n, alpha, a, + cs_a, b_ref, cs_b ); + + computediff( "B", 'c', m, n, b, b_ref, cs_b, thresh ); + + // free memory + free(b_ref); +} + +// Test-case logger : Used to print the test-case details based on parameters +template +class trsmSmallUKRPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const{ + char side = std::get<1>(str.param); + char uploa = std::get<2>(str.param); + char diaga = std::get<3>(str.param); + char transa = std::get<4>(str.param); + gtint_t m = std::get<5>(str.param); + gtint_t n = std::get<6>(str.param); + T1 alpha = std::get<7>(str.param); + gtint_t lda_inc = std::get<8>(str.param); + gtint_t ldb_inc = std::get<9>(str.param); + bool is_memory_test = std::get<10>(str.param); + + std::string str_name = ""; + str_name += "_side_" + std::string(&side, 1); + str_name += "_uplo_" + std::string(&uploa, 1); + str_name += "_transa_" + std::string(&transa, 1); + str_name += "_diag_" + std::string(&diaga, 1); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + str_name += "_m_" + std::to_string(m); + str_name += "_n_" + std::to_string(n); + gtint_t mn; + testinghelpers::set_dim_with_side( side, m, n, &mn ); + gtint_t lda = lda_inc + mn; + gtint_t ldb = ldb_inc + m; + str_name += "_lda_i" + std::to_string(lda_inc) + "_" + std::to_string(lda); + str_name += "_ldb_i" + std::to_string(ldb_inc) + "_" + std::to_string(ldb); + str_name += ( is_memory_test ) ? "_mem_test_enabled" : "_mem_test_disabled"; + return str_name; + } +}; + +template +class trsmNatUKRPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const{ + char storage = std::get<1>(str.param); + char uploa = std::get<2>(str.param); + char diaga = std::get<3>(str.param); + gtint_t m = std::get<4>(str.param); + gtint_t n = std::get<5>(str.param); + gtint_t k = std::get<6>(str.param); + T1 alpha = std::get<7>(str.param); + gtint_t ldc_inc = std::get<8>(str.param); + bool is_memory_test = std::get<9>(str.param); + + std::string str_name = ""; + str_name += "_stor_" + std::string(&storage, 1); + str_name += "_uplo_" + std::string(&uploa, 1); + str_name += "_diag_" + std::string(&diaga, 1); + str_name += "_alpha_" + testinghelpers::get_value_string(alpha); + str_name += "_m_" + std::to_string(m); + str_name += "_n_" + std::to_string(n); + str_name += "_k_" + std::to_string(k); + gtint_t ldc = testinghelpers::get_leading_dimension( storage, 'n', m, n, ldc_inc ); + str_name += "_ldc_i" + std::to_string(ldc_inc) + "_" + std::to_string(ldc); + str_name += ( is_memory_test ) ? "_mem_test_enabled" : "_mem_test_disabled"; + return str_name; + } +}; diff --git a/gtestsuite/testsuite/ukr/trsm/ztrsm/ztrsm_ukr.cpp b/gtestsuite/testsuite/ukr/trsm/ztrsm/ztrsm_ukr.cpp new file mode 100644 index 0000000000..42778dba04 --- /dev/null +++ b/gtestsuite/testsuite/ukr/trsm/ztrsm/ztrsm_ukr.cpp @@ -0,0 +1,272 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "common/testing_helpers.h" +#include "level3/ref_gemm.h" +#include "ukr/trsm/test_trsm_ukr.h" +#include "level3/trsm/test_trsm.h" + +class ztrsmGenericNat : + public ::testing::TestWithParam> {}; // is_memory_test + +class ztrsmGenericSmall : + public ::testing::TestWithParam> {}; // is_memory_test + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ztrsmGenericNat); +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ztrsmGenericSmall); + +#ifndef BLIS_INT_ELEMENT_TYPE + +TEST_P( ztrsmGenericNat, UKR ) +{ + using T = dcomplex; + zgemmtrsm_ukr_ft ukr_fp = std::get<0>(GetParam()); + char storage = std::get<1>(GetParam()); + char uploa = std::get<2>(GetParam()); + char diaga = std::get<3>(GetParam()); + gtint_t m = std::get<4>(GetParam()); + gtint_t n = std::get<5>(GetParam()); + gtint_t k = std::get<6>(GetParam()); + T alpha = std::get<7>(GetParam()); + gtint_t ldc = std::get<8>(GetParam()); + bool is_memory_test = std::get<9>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite trsm.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + // Threshold adjustment +#ifdef BLIS_INT_ELEMENT_TYPE + double adj = 1.0; +#else + double adj = 1.6; +#endif + if (m == 0 || n == 0 || alpha == testinghelpers::ZERO()) + thresh = 0.0; + else + thresh = adj*3*m*testinghelpers::getEpsilon(); + + test_trsm_ukr( ukr_fp, storage, uploa, diaga, m, n, k, alpha, ldc, thresh, is_memory_test); +} + +TEST_P( ztrsmGenericSmall, UKR ) +{ + using T = dcomplex; + trsm_small_ker_ft ukr_fp = std::get<0>(GetParam()); + char side = std::get<1>(GetParam()); + char uploa = std::get<2>(GetParam()); + char diaga = std::get<3>(GetParam()); + char transa = std::get<4>(GetParam()); + gtint_t m = std::get<5>(GetParam()); + gtint_t n = std::get<6>(GetParam()); + T alpha = std::get<7>(GetParam()); + gtint_t lda = std::get<8>(GetParam()); + gtint_t ldb = std::get<9>(GetParam()); + bool is_memory_test = std::get<10>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite trsm.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (m == 0 || n == 0 || alpha == testinghelpers::ZERO()) + thresh = 0.0; + else + thresh = 3*m*testinghelpers::getEpsilon(); + + test_trsm_small_ukr( ukr_fp, side, uploa, diaga, transa, m, n, alpha, lda, ldb, thresh, is_memory_test, BLIS_DCOMPLEX); +} + +#if defined(BLIS_KERNELS_ZEN4) && defined(GTEST_AVX512) +INSTANTIATE_TEST_SUITE_P ( + bli_zgemmtrsm_l_zen4_asm_4x12, + ztrsmGenericNat, + ::testing::Combine( + ::testing::Values(bli_zgemmtrsm_l_zen4_asm_4x12), // ker_ptr + ::testing::Values('c', 'r', 'g'), // stor + ::testing::Values('l'), // uplo + ::testing::Values('u', 'n'), // diaga + ::testing::Values(4), // m + ::testing::Values(12), // n + ::testing::Values(0, 1, 2, 8, 9, 10, 500, 1000), // k + ::testing::Values(dcomplex{-1.4, 3.2}, + dcomplex{ 2.8, -0.5}, + dcomplex{-1.4, 0.0}, + dcomplex{ 0.0, -1.9}), // alpha + ::testing::Values(0, 9, 53), // ldc + ::testing::Values(false, true) // is_memory_test + ), + (::trsmNatUKRPrint()) +); + +INSTANTIATE_TEST_SUITE_P ( + bli_zgemmtrsm_u_zen4_asm_4x12, + ztrsmGenericNat, + ::testing::Combine( + ::testing::Values(bli_zgemmtrsm_u_zen4_asm_4x12), // ker_ptr + ::testing::Values('c', 'r', 'g'), // stor + ::testing::Values('u'), // uplo + ::testing::Values('u', 'n'), // diaga + ::testing::Values(4), // m + ::testing::Values(12), // n + ::testing::Values(0, 1, 2, 8, 9, 10, 500, 1000), // k + ::testing::Values(dcomplex{-1.4, 3.2}, + dcomplex{ 2.8, -0.5}, + dcomplex{-1.4, 0.0}, + dcomplex{ 0.0, -1.9}), // alpha + ::testing::Values(0, 9, 53), // ldc + ::testing::Values(false, true) // is_memory_test + ), + (::trsmNatUKRPrint()) +); + +#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM +INSTANTIATE_TEST_SUITE_P ( + bli_trsm_small_AVX512, + ztrsmGenericSmall, + ::testing::Combine( + ::testing::Values(bli_trsm_small_AVX512), // ker_ptr + ::testing::Values('l', 'r'), // side + ::testing::Values('l', 'u'), // uplo + ::testing::Values('n', 'u'), // diaga + ::testing::Values('n', 't'), // transa + ::testing::Range(gtint_t(1), gtint_t(5), 1), // m + ::testing::Range(gtint_t(1), gtint_t(5), 1), // n + ::testing::Values(dcomplex{-1.4, 3.2}, + dcomplex{ 2.8, -0.5}, + dcomplex{-1.4, 0.0}, + dcomplex{ 0.0, -1.9}), // alpha + ::testing::Values(0, 10, 194), // lda_inc + ::testing::Values(0, 10, 194), // ldb_inc + ::testing::Values(false, true) // is_memory_test + ), + (::trsmSmallUKRPrint()) +); +#endif + +#endif + + +#if defined(BLIS_KERNELS_ZEN) && defined(GTEST_AVX2FMA3) +INSTANTIATE_TEST_SUITE_P ( + bli_zgemmtrsm_l_zen_asm_2x6, + ztrsmGenericNat, + ::testing::Combine( + ::testing::Values(bli_zgemmtrsm_l_zen_asm_2x6), // ker_ptr + ::testing::Values('c', 'r', 'g'), // stor + ::testing::Values('l'), // uplo + ::testing::Values('u', 'n'), // diaga + ::testing::Values(2), // m + ::testing::Values(6), // n + ::testing::Values(0, 1, 2, 8, 9, 10, 500, 1000), // k + ::testing::Values(dcomplex{-1.4, 3.2}, + dcomplex{ 2.8, -0.5}, + dcomplex{-1.4, 0.0}, + dcomplex{ 0.0, -1.9}), // alpha + ::testing::Values(0, 9, 53), // ldc + ::testing::Values(false, true) // is_memory_test + ), + (::trsmNatUKRPrint()) +); + +INSTANTIATE_TEST_SUITE_P ( + bli_zgemmtrsm_u_zen_asm_2x6, + ztrsmGenericNat, + ::testing::Combine( + ::testing::Values(bli_zgemmtrsm_u_zen_asm_2x6), // ker_ptr + ::testing::Values('c', 'r', 'g'), // stor + ::testing::Values('u'), // uplo + ::testing::Values('u', 'n'), // diaga + ::testing::Values(2), // m + ::testing::Values(6), // n + ::testing::Values(0, 1, 2, 8, 9, 10, 500, 1000), // k + ::testing::Values(dcomplex{-1.4, 3.2}, + dcomplex{ 2.8, -0.5}, + dcomplex{-1.4, 0.0}, + dcomplex{ 0.0, -1.9}), // alpha + ::testing::Values(0, 9, 53), // ldc + ::testing::Values(false, true) // is_memory_test + ), + (::trsmNatUKRPrint()) +); + +#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM +INSTANTIATE_TEST_SUITE_P ( + bli_trsm_small, + ztrsmGenericSmall, + ::testing::Combine( + ::testing::Values(bli_trsm_small), // ker_ptr + ::testing::Values('l', 'r'), // side + ::testing::Values('l', 'u'), // uplo + ::testing::Values('n', 'u'), // diaga + ::testing::Values('n', 'c', 't'), // transa + ::testing::Range(gtint_t(1), gtint_t(5), 1), // m + ::testing::Range(gtint_t(1), gtint_t(5), 1), // n + ::testing::Values(dcomplex{-1.4, 3.2}, + dcomplex{ 2.8, -0.5}, + dcomplex{-1.4, 0.0}, + dcomplex{ 0.0, -1.9}), // alpha + ::testing::Values(0, 10, 194), // lda_inc + ::testing::Values(0, 10, 194), // ldb_inc + ::testing::Values(false, true) // is_memory_test + ), + (::trsmSmallUKRPrint()) +); +#endif +#endif + +#endif // ifndef BLIS_INT_ELEMENT_TYPE diff --git a/gtestsuite/testsuite/util/asumv/asumv.h b/gtestsuite/testsuite/util/asumv/asumv.h new file mode 100644 index 0000000000..2025b1709f --- /dev/null +++ b/gtestsuite/testsuite/util/asumv/asumv.h @@ -0,0 +1,160 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#pragma once + +#include "blis.h" +#include "common/testing_helpers.h" +#include "inc/check_error.h" + +/** + * @brief computes the sum of the absolute values of the fundamental elements + * of vector x. + * + * @param[in] n vector length + * @param[in] x pointer which points to the first element of x + * @param[in] incx increment of x + * @return sum of the absolute values of the fundamental elements of x + * + * + */ + +template::real_type> +static RT asumv_(gtint_t n, T* x, gtint_t incx){ + if constexpr (std::is_same::value) + return sasum_( &n, x, &incx ); + else if constexpr (std::is_same::value) + return dasum_( &n, x, &incx ); + else if constexpr (std::is_same::value) + return scasum_( &n, x, &incx ); + else if constexpr (std::is_same::value) + return dzasum_( &n, x, &incx ); + else + throw std::runtime_error("Error in testsuite/util/asumv.h: Invalid typename in asumv_()."); +} + +template::real_type> +static RT asumv_blis_impl(gtint_t n, T* x, gtint_t incx){ + if constexpr (std::is_same::value) + return sasum_blis_impl( &n, x, &incx ); + else if constexpr (std::is_same::value) + return dasum_blis_impl( &n, x, &incx ); + else if constexpr (std::is_same::value) + return scasum_blis_impl( &n, x, &incx ); + else if constexpr (std::is_same::value) + return dzasum_blis_impl( &n, x, &incx ); + else + throw std::runtime_error("Error in testsuite/util/asumv.h: Invalid typename in asumv_blis_impl()."); +} + +template::real_type> +static RT cblas_asumv(gtint_t n, T* x, gtint_t incx){ + if constexpr (std::is_same::value) + return cblas_sasum( n, x, incx ); + else if constexpr (std::is_same::value) + return cblas_dasum( n, x, incx ); + else if constexpr (std::is_same::value) + return cblas_scasum( n, x, incx ); + else if constexpr (std::is_same::value) + return cblas_dzasum( n, x, incx ); + else + throw std::runtime_error("Error in testsuite/util/asumv.h: Invalid typename in cblas_asumv()."); +} + +template::real_type> +static RT typed_asumv(gtint_t n, T* x, gtint_t incx){ + RT asum; + if constexpr (std::is_same::value) + bli_sasumv(n, x, incx, &asum); + else if constexpr (std::is_same::value) + bli_dasumv(n, x, incx, &asum); + else if constexpr (std::is_same::value) + bli_casumv(n, x, incx, &asum); + else if constexpr (std::is_same::value) + bli_zasumv(n, x, incx, &asum); + else + throw std::runtime_error("Error in testsuite/util/asumv.h: Invalid typename in cblas_asumv()."); + return asum; +} + +template::real_type> +static RT asumv(gtint_t n, T* x, gtint_t incx) +{ + +#ifdef TEST_INPUT_ARGS + // Create copy of scalar input values so we can check that they are not altered. + gtint_t n_cpy = n; + gtint_t incx_cpy = incx; + + // Create copy of input arrays so we can check that they are not altered. + T* x_cpy = nullptr; + gtint_t size_x = testinghelpers::buff_dim( n, incx ); + if (x && size_x > 0) + { + x_cpy = new T[size_x]; + memcpy( x_cpy, x, size_x * sizeof( T ) ); + } +#endif + +#ifdef TEST_BLAS + return asumv_(n, x, incx); +#elif TEST_BLAS_BLIS_IMPL + return asumv_blis_impl(n, x, incx); +#elif TEST_CBLAS + return cblas_asumv(n, x, incx); +#elif TEST_BLIS_TYPED + return typed_asumv(n, x, incx); +#else + throw std::runtime_error("Error in testsuite/util/asumv.h: No interfaces are set to be tested."); +#endif + +#ifdef TEST_INPUT_ARGS + //---------------------------------------------------------- + // Check scalar inputs have not been modified. + //---------------------------------------------------------- + + computediff( "n", n, n_cpy ); + computediff( "incx", incx, incx_cpy ); + + //---------------------------------------------------------- + // Bitwise-wise check array inputs have not been modified. + //---------------------------------------------------------- + + if (x && size_x > 0) + { + computediff( "x", n, x, x_cpy, incx, true ); + delete[] x_cpy; + } +#endif +} diff --git a/gtestsuite/testsuite/util/asumv/asumv_IIT_ERS.cpp b/gtestsuite/testsuite/util/asumv/asumv_IIT_ERS.cpp new file mode 100644 index 0000000000..16d153c195 --- /dev/null +++ b/gtestsuite/testsuite/util/asumv/asumv_IIT_ERS.cpp @@ -0,0 +1,229 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_asumv.h" +#include "common/wrong_inputs_helpers.h" +#include "common/testing_helpers.h" +#include "inc/check_error.h" + +template +class asumv_IIT_ERS : public ::testing::Test {}; +typedef ::testing::Types TypeParam; +TYPED_TEST_SUITE(asumv_IIT_ERS, TypeParam); + +using namespace testinghelpers::IIT; + +#if defined(TEST_BLAS_LIKE) || defined(TEST_CBLAS) + +/* + BLAS Early Return Scenarios(ERS): + + ASUMV is expected to return early in the following cases: + 1. n <= 0 + 2. inc <= 0 +*/ + +// n < 0, with non-unit stride +TYPED_TEST(asumv_IIT_ERS, n_lt_zero_nonUnitStride) +{ + using T = TypeParam; + using RT = typename testinghelpers::type_info::real_type; + gtint_t invalid_n = -1; + gtint_t inc = 5; + // Initialize asum (BLIS output) to garbage value. + RT asum = RT{-7.3}; + // Initialize the expected output to zero. + RT asum_ref; + testinghelpers::initzero(asum_ref); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + asum = asumv( invalid_n, nullptr, inc ); + // Computing the difference. + computediff( "asum", asum, asum_ref ); + + // Test with all arguments correct except for the value we are choosing to test. + // Initialize x vector with random numbers. + std::vector x = testinghelpers::get_random_vector( -10, 10, N, inc ); + + // Invoking asumV with an invalid value of n. + asum = asumv( invalid_n, x.data(), inc ); + + // Computing the difference. + computediff( "asum", asum, asum_ref ); +} + +// n == 0, with non-unit stride +TYPED_TEST(asumv_IIT_ERS, n_eq_zero_nonUnitStride) +{ + using T = TypeParam; + using RT = typename testinghelpers::type_info::real_type; + gtint_t invalid_n = 0; + gtint_t inc = 5; + // Initialize asum (BLIS output) to garbage value. + RT asum = RT{-7.3}; + // Initialize the expected output to zero. + RT asum_ref; + testinghelpers::initzero(asum_ref); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + asum = asumv( invalid_n, nullptr, inc ); + // Computing the difference. + computediff( "asum", asum, asum_ref ); + + // Test with all arguments correct except for the value we are choosing to test. + // Initialize x vector with random numbers. + std::vector x = testinghelpers::get_random_vector( -10, 10, N, inc ); + + // Invoking asumV with an invalid value of n. + asum = asumv( invalid_n, x.data(), inc ); + + // Computing the difference. + computediff( "asum", asum, asum_ref ); +} + +// n < 0, with unit stride +TYPED_TEST(asumv_IIT_ERS, n_lt_zero_unitStride) +{ + using T = TypeParam; + using RT = typename testinghelpers::type_info::real_type; + gtint_t invalid_n = -1; + gtint_t unit_inc = 1; + // Initialize asum (BLIS output) to garbage value. + RT asum = RT{-7.3}; + // Initialize the expected output to zero. + RT asum_ref; + testinghelpers::initzero(asum_ref); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + asum = asumv( invalid_n, nullptr, unit_inc ); + // Computing the difference. + computediff( "asum", asum, asum_ref ); + + // Test with all arguments correct except for the value we are choosing to test. + // Initialize x vector with random numbers. + std::vector x = testinghelpers::get_random_vector( -10, 10, N, unit_inc ); + + // Invoking asumV with an invalid value of n. + asum = asumv( invalid_n, x.data(), unit_inc ); + + // Computing the difference. + computediff( "asum", asum, asum_ref ); +} + +// n == 0, with unit stride +TYPED_TEST(asumv_IIT_ERS, n_eq_zero_unitStride) +{ + using T = TypeParam; + using RT = typename testinghelpers::type_info::real_type; + gtint_t invalid_n = 0; + gtint_t unit_inc = 1; + // Initialize asum (BLIS output) to garbage value. + RT asum = RT{-7.3}; + // Initialize the expected output to zero. + RT asum_ref; + testinghelpers::initzero(asum_ref); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + asum = asumv( invalid_n, nullptr, unit_inc ); + // Computing the difference. + computediff( "asum", asum, asum_ref ); + + // Test with all arguments correct except for the value we are choosing to test. + // Initialize x vector with random numbers. + std::vector x = testinghelpers::get_random_vector( -10, 10, N, unit_inc ); + + // Invoking asumV with an invalid value of n. + asum = asumv( invalid_n, x.data(), unit_inc ); + + // Computing the difference. + computediff( "asum", asum, asum_ref ); +} + +// inc < 0 +TYPED_TEST(asumv_IIT_ERS, inc_lt_0) +{ + using T = TypeParam; + using RT = typename testinghelpers::type_info::real_type; + gtint_t invalid_inc = -1; + // Initialize asum (BLIS output) to garbage value. + RT asum = RT{-7.3}; + // Initialize the expected output to zero. + RT asum_ref; + testinghelpers::initzero(asum_ref); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + asum = asumv( N, nullptr, invalid_inc ); + // Computing the difference. + computediff( "asum", asum, asum_ref ); + + // Test with all arguments correct except for the value we are choosing to test. + // Initialize x vector with random numbers. + std::vector x = testinghelpers::get_random_vector( -10, 10, N, INC ); + + // Invoking asumV with an invalid value of n. + asum = asumv( N, x.data(), invalid_inc ); + + // Computing the difference. + computediff( "asum", asum, asum_ref ); +} + +// inc == 0 +TYPED_TEST(asumv_IIT_ERS, inc_eq_0) +{ + using T = TypeParam; + using RT = typename testinghelpers::type_info::real_type; + gtint_t invalid_inc = 0; + // Initialize asum (BLIS output) to garbage value. + RT asum = RT{-7.3}; + // Initialize the expected output to zero. + RT asum_ref; + testinghelpers::initzero(asum_ref); + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + asum = asumv( N, nullptr, invalid_inc ); + // Computing the difference. + computediff( "asum", asum, asum_ref ); + + // Test with all arguments correct except for the value we are choosing to test. + // Initialize x vector with random numbers. + std::vector x = testinghelpers::get_random_vector( -10, 10, N, INC ); + + // Invoking asumV with an invalid value of n. + asum = asumv( N, x.data(), invalid_inc ); + + // Computing the difference. + computediff( "asum", asum, asum_ref ); +} +#endif diff --git a/gtestsuite/testsuite/util/asumv/dasumv_evt.cpp b/gtestsuite/testsuite/util/asumv/dasumv_evt.cpp new file mode 100644 index 0000000000..192a283b06 --- /dev/null +++ b/gtestsuite/testsuite/util/asumv/dasumv_evt.cpp @@ -0,0 +1,127 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_asumv.h" + +class dasumvEVT : + public ::testing::TestWithParam> {}; // jx_exval + +TEST_P( dasumvEVT, API ) +{ + using T = double; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // vector length: + gtint_t n = std::get<0>(GetParam()); + // stride size for x: + gtint_t incx = std::get<1>(GetParam()); + // index of extreme value for x: + gtint_t xi = std::get<2>(GetParam()); + // extreme value for x: + double ix_exval = std::get<3>(GetParam()); + // index of extreme value for x: + gtint_t xj = std::get<4>(GetParam()); + // extreme value for x: + double jx_exval = std::get<5>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite asumv.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (n == 0 || incx <= 0) + thresh = 0.0; + else + thresh = n*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_asumv( n, incx, xi, ix_exval, xj, jx_exval, thresh ); +} + +static double NaN = std::numeric_limits::quiet_NaN(); +static double Inf = std::numeric_limits::infinity(); + +// EVT with unit stride vector containing Infs/NaNs. +INSTANTIATE_TEST_SUITE_P( + vec_unitStride, + dasumvEVT, + ::testing::Combine( + // n: size of vector. + ::testing::Values( gtint_t(55) ), + // incx: stride of x vector. + ::testing::Values( gtint_t(1) ), + // xi: first index to set extreme value in x. + ::testing::Values( gtint_t(1), gtint_t(27), gtint_t(51) ), + // ix_exval: extreme value for x. + ::testing::Values( NaN, Inf, -Inf ), + // xj: second index to set extreme value in x. + ::testing::Values( gtint_t(13) ), + // jx_exval: extreme value for x. + // jx_exval = 1.0 tests for the vector with only one extreme value. + ::testing::Values( 1.0, NaN, Inf, -Inf ) + ), + ::asumvEVTPrint() + ); + +// EVT with non-unit stride vector containing Infs/NaNs. +INSTANTIATE_TEST_SUITE_P( + vec_nonUnitStride, + dasumvEVT, + ::testing::Combine( + // n: size of vector. + ::testing::Values( gtint_t(55) ), + // incx: stride of x vector. + ::testing::Values( gtint_t(3) ), + // xi: first index to set extreme value in x. + ::testing::Values( gtint_t(1), gtint_t(27), gtint_t(51) ), + // ix_exval: extreme value for x. + ::testing::Values( NaN, Inf, -Inf ), + // xj: second index to set extreme value in x. + ::testing::Values( gtint_t(13) ), + // jx_exval: extreme value for x. + // jx_exval = 1.0 tests for the vector with only one extreme value. + ::testing::Values( 1.0, NaN, Inf, -Inf ) + ), + ::asumvEVTPrint() + ); diff --git a/gtestsuite/testsuite/util/asumv/dasumv_generic.cpp b/gtestsuite/testsuite/util/asumv/dasumv_generic.cpp new file mode 100644 index 0000000000..344029c23d --- /dev/null +++ b/gtestsuite/testsuite/util/asumv/dasumv_generic.cpp @@ -0,0 +1,152 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_asumv.h" + +class dasumvGeneric : + public ::testing::TestWithParam> {}; + +TEST_P( dasumvGeneric, API ) +{ + using T = double; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // vector length: + gtint_t n = std::get<0>(GetParam()); + // stride size for x: + gtint_t incx = std::get<1>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite asumv.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (n == 0 || incx <= 0) + thresh = 0.0; + else + thresh = n*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_asumv( n, incx, thresh ); +} + +INSTANTIATE_TEST_SUITE_P( + unitPositiveIncrement, + dasumvGeneric, + ::testing::Combine( + // m: size of vector. + ::testing::Values( + gtint_t( 1), + gtint_t( 2), + gtint_t( 3), + gtint_t( 5), + gtint_t( 7), + gtint_t( 9), + gtint_t(10), + gtint_t(15), + gtint_t(20), + gtint_t(55), + gtint_t(99) + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1) + ) + ), + ::asumvGenericPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + nonUnitPositiveIncrement, + dasumvGeneric, + ::testing::Combine( + // m: size of vector. + ::testing::Values( + gtint_t( 1), + gtint_t( 2), + gtint_t( 3), + gtint_t( 5), + gtint_t( 7), + gtint_t( 9), + gtint_t(10), + gtint_t(15), + gtint_t(20), + gtint_t(55), + gtint_t(99) + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(2), + gtint_t(3) + ) + ), + ::asumvGenericPrint() + ); + +// @note: ASUMV is supposed to set sum as 0 and return early in case incx <= 0, +// but since it is currently not following this, failures are being observed. +#ifndef TEST_BLIS_TYPED +INSTANTIATE_TEST_SUITE_P( + negativeIncrement, + dasumvGeneric, + ::testing::Combine( + // m: size of vector. + ::testing::Values( + gtint_t( 1), + gtint_t( 2), + gtint_t( 3), + gtint_t( 5), + gtint_t( 7), + gtint_t( 9), + gtint_t(10), + gtint_t(15), + gtint_t(20), + gtint_t(55), + gtint_t(99) + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(-1), + gtint_t(-2), + gtint_t(-3) + ) + ), + ::asumvGenericPrint() + ); +#endif diff --git a/gtestsuite/testsuite/util/asumv/dzasumv_generic.cpp b/gtestsuite/testsuite/util/asumv/dzasumv_generic.cpp new file mode 100644 index 0000000000..439c39cd1c --- /dev/null +++ b/gtestsuite/testsuite/util/asumv/dzasumv_generic.cpp @@ -0,0 +1,153 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_asumv.h" + +class dzasumvGeneric : + public ::testing::TestWithParam> {}; + +TEST_P( dzasumvGeneric, API ) +{ + using T = dcomplex; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // vector length: + gtint_t n = std::get<0>(GetParam()); + // stride size for x: + gtint_t incx = std::get<1>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite asumv.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (n == 0 || incx <= 0) + thresh = 0.0; + else + thresh = n*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_asumv( n, incx, thresh ); +} + +INSTANTIATE_TEST_SUITE_P( + unitPositiveIncrement, + dzasumvGeneric, + ::testing::Combine( + // m: size of vector. + ::testing::Values( + gtint_t( 1), + gtint_t( 2), + gtint_t( 3), + gtint_t( 5), + gtint_t( 7), + gtint_t( 9), + gtint_t(10), + gtint_t(15), + gtint_t(20), + gtint_t(55), + gtint_t(99) + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1) + ) + ), + ::asumvGenericPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + nonUnitPositiveIncrement, + dzasumvGeneric, + ::testing::Combine( + // m: size of vector. + ::testing::Values( + gtint_t( 1), + gtint_t( 2), + gtint_t( 3), + gtint_t( 5), + gtint_t( 7), + gtint_t( 9), + gtint_t(10), + gtint_t(15), + gtint_t(20), + gtint_t(55), + gtint_t(99) + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(2), + gtint_t(3) + ) + ), + ::asumvGenericPrint() + ); + +// @note: ASUMV is supposed to set sum as 0 and return early in case incx <= 0, +// but since it is currently not following this, failures are being observed. +#ifndef TEST_BLIS_TYPED +INSTANTIATE_TEST_SUITE_P( + negativeIncrement, + dzasumvGeneric, + ::testing::Combine( + // m: size of vector. + ::testing::Values( + gtint_t( 1), + gtint_t( 2), + gtint_t( 3), + gtint_t( 5), + gtint_t( 7), + gtint_t( 9), + gtint_t(10), + gtint_t(15), + gtint_t(20), + gtint_t(55), + gtint_t(99) + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(-1), + gtint_t(-2), + gtint_t(-3) + ) + ), + ::asumvGenericPrint() + ); +#endif diff --git a/gtestsuite/testsuite/util/asumv/sasumv_generic.cpp b/gtestsuite/testsuite/util/asumv/sasumv_generic.cpp new file mode 100644 index 0000000000..02367c7611 --- /dev/null +++ b/gtestsuite/testsuite/util/asumv/sasumv_generic.cpp @@ -0,0 +1,152 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_asumv.h" + +class sasumvGeneric : + public ::testing::TestWithParam> {}; + +TEST_P( sasumvGeneric, API ) +{ + using T = double; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // vector length: + gtint_t n = std::get<0>(GetParam()); + // stride size for x: + gtint_t incx = std::get<1>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite asumv.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (n == 0 || incx <= 0) + thresh = 0.0; + else + thresh = n*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_asumv( n, incx, thresh ); +} + +INSTANTIATE_TEST_SUITE_P( + unitPositiveIncrement, + sasumvGeneric, + ::testing::Combine( + // m: size of vector. + ::testing::Values( + gtint_t( 1), + gtint_t( 2), + gtint_t( 3), + gtint_t( 5), + gtint_t( 7), + gtint_t( 9), + gtint_t(10), + gtint_t(15), + gtint_t(20), + gtint_t(55), + gtint_t(99) + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1) + ) + ), + ::asumvGenericPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + nonUnitPositiveIncrement, + sasumvGeneric, + ::testing::Combine( + // m: size of vector. + ::testing::Values( + gtint_t( 1), + gtint_t( 2), + gtint_t( 3), + gtint_t( 5), + gtint_t( 7), + gtint_t( 9), + gtint_t(10), + gtint_t(15), + gtint_t(20), + gtint_t(55), + gtint_t(99) + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(2), + gtint_t(3) + ) + ), + ::asumvGenericPrint() + ); + +// @note: ASUMV is supposed to set sum as 0 and return early in case incx <= 0, +// but since it is currently not following this, failures are being observed. +#ifndef TEST_BLIS_TYPED +INSTANTIATE_TEST_SUITE_P( + negativeIncrement, + sasumvGeneric, + ::testing::Combine( + // m: size of vector. + ::testing::Values( + gtint_t( 1), + gtint_t( 2), + gtint_t( 3), + gtint_t( 5), + gtint_t( 7), + gtint_t( 9), + gtint_t(10), + gtint_t(15), + gtint_t(20), + gtint_t(55), + gtint_t(99) + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(-1), + gtint_t(-2), + gtint_t(-3) + ) + ), + ::asumvGenericPrint() + ); +#endif diff --git a/gtestsuite/testsuite/util/asumv/scasumv_generic.cpp b/gtestsuite/testsuite/util/asumv/scasumv_generic.cpp new file mode 100644 index 0000000000..a978a35fb5 --- /dev/null +++ b/gtestsuite/testsuite/util/asumv/scasumv_generic.cpp @@ -0,0 +1,153 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_asumv.h" + +class scasumvGeneric : + public ::testing::TestWithParam> {}; + +TEST_P( scasumvGeneric, API ) +{ + using T = scomplex; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // vector length: + gtint_t n = std::get<0>(GetParam()); + // stride size for x: + gtint_t incx = std::get<1>(GetParam()); + + // Set the threshold for the errors: + // Check gtestsuite asumv.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (n == 0 || incx <= 0) + thresh = 0.0; + else + thresh = n*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_asumv( n, incx, thresh ); +} + +INSTANTIATE_TEST_SUITE_P( + unitPositiveIncrement, + scasumvGeneric, + ::testing::Combine( + // m: size of vector. + ::testing::Values( + gtint_t( 1), + gtint_t( 2), + gtint_t( 3), + gtint_t( 5), + gtint_t( 7), + gtint_t( 9), + gtint_t(10), + gtint_t(15), + gtint_t(20), + gtint_t(55), + gtint_t(99) + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(1) + ) + ), + ::asumvGenericPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + nonUnitPositiveIncrement, + scasumvGeneric, + ::testing::Combine( + // m: size of vector. + ::testing::Values( + gtint_t( 1), + gtint_t( 2), + gtint_t( 3), + gtint_t( 5), + gtint_t( 7), + gtint_t( 9), + gtint_t(10), + gtint_t(15), + gtint_t(20), + gtint_t(55), + gtint_t(99) + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(2), + gtint_t(3) + ) + ), + ::asumvGenericPrint() + ); + +// @note: ASUMV is supposed to set sum as 0 and return early in case incx <= 0, +// but since it is currently not following this, failures are being observed. +#ifndef TEST_BLIS_TYPED +INSTANTIATE_TEST_SUITE_P( + negativeIncrement, + scasumvGeneric, + ::testing::Combine( + // m: size of vector. + ::testing::Values( + gtint_t( 1), + gtint_t( 2), + gtint_t( 3), + gtint_t( 5), + gtint_t( 7), + gtint_t( 9), + gtint_t(10), + gtint_t(15), + gtint_t(20), + gtint_t(55), + gtint_t(99) + ), + // incx: stride of x vector. + ::testing::Values( + gtint_t(-1), + gtint_t(-2), + gtint_t(-3) + ) + ), + ::asumvGenericPrint() + ); +#endif diff --git a/gtestsuite/testsuite/util/asumv/test_asumv.h b/gtestsuite/testsuite/util/asumv/test_asumv.h new file mode 100644 index 0000000000..0590a99793 --- /dev/null +++ b/gtestsuite/testsuite/util/asumv/test_asumv.h @@ -0,0 +1,146 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#pragma once + +#include "asumv.h" +#include +#include "util/ref_asumv.h" +#include "inc/check_error.h" + +/** + * @brief Used for generic tests with random values in x. + */ +template +void test_asumv( gtint_t n, gtint_t incx, double thresh ) +{ + // Get real type from T. + using RT = typename testinghelpers::type_info::real_type; + //---------------------------------------------------------- + // Initialize vectors with random numbers. + //---------------------------------------------------------- + std::vector x = testinghelpers::get_random_vector( -10, 10, n, incx ); + + //---------------------------------------------------------- + // Call reference implementation to get ref results. + //---------------------------------------------------------- + RT asum_ref = testinghelpers::ref_asumv( n, x.data(), incx ); + + //---------------------------------------------------------- + // Call BLIS function. + //---------------------------------------------------------- + RT asum = asumv(n, x.data(), incx); + + //---------------------------------------------------------- + // Compute error. + //---------------------------------------------------------- + computediff( "asum", asum, asum_ref, thresh ); +} + +/** + * @brief Used to insert Exception Values in x vector. + */ +template +void test_asumv( gtint_t n, gtint_t incx, gtint_t xi, T ix_exval, + gtint_t xj, T jx_exval, double thresh ) +{ + // Get real type from T. + using RT = typename testinghelpers::type_info::real_type; + //---------------------------------------------------------- + // Initialize vectors with random numbers. + //---------------------------------------------------------- + std::vector x = testinghelpers::get_random_vector( -10, 10, n, incx ); + + // Update the value at index xi to an extreme value, ix_exval. + if ( -1 < xi && xi < n ) x[xi * incx] = ix_exval; + else return; + + // Update the value at index xj to an extreme value, jx_exval. + if ( -1 < xi && xi < n ) x[xj * incx] = jx_exval; + else return; + + //---------------------------------------------------------- + // Call reference implementation to get ref results. + //---------------------------------------------------------- + RT asum_ref = testinghelpers::ref_asumv( n, x.data(), incx ); + + //---------------------------------------------------------- + // Call BLIS function. + //---------------------------------------------------------- + RT asum = asumv(n, x.data(), incx); + + //---------------------------------------------------------- + // Compute error. + //---------------------------------------------------------- + computediff( "asum", asum, asum_ref, thresh, true ); +} + + +// Test-case logger : Used to print the test-case details based on parameters +class asumvGenericPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + gtint_t n = std::get<0>(str.param); + gtint_t incx = std::get<1>(str.param); + + std::string str_name = API_PRINT; + str_name += "_n_" + std::to_string(n); + str_name += "_incx_" + testinghelpers::get_value_string(incx); + return str_name; + } +}; + +template +class asumvEVTPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + gtint_t n = std::get<0>(str.param); + gtint_t incx = std::get<1>(str.param); + gtint_t xi = std::get<2>(str.param); + T ix_exval = std::get<3>(str.param); + gtint_t xj = std::get<4>(str.param); + T jx_exval = std::get<5>(str.param); + + std::string str_name = API_PRINT; + str_name += "_n_" + std::to_string(n); + str_name += "_incx_" + testinghelpers::get_value_string(incx); + str_name = str_name + "_X_" + std::to_string(xi); + str_name = str_name + "_" + testinghelpers::get_value_string(ix_exval); + str_name = str_name + "_X_" + std::to_string(xj); + str_name = str_name + "_" + testinghelpers::get_value_string(jx_exval); + return str_name; + } +}; diff --git a/gtestsuite/testsuite/util/nrm2/dnrm2_extreme_values.cpp b/gtestsuite/testsuite/util/nrm2/dnrm2_evt.cpp similarity index 75% rename from gtestsuite/testsuite/util/nrm2/dnrm2_extreme_values.cpp rename to gtestsuite/testsuite/util/nrm2/dnrm2_evt.cpp index 32386593d0..f23a5611f7 100644 --- a/gtestsuite/testsuite/util/nrm2/dnrm2_extreme_values.cpp +++ b/gtestsuite/testsuite/util/nrm2/dnrm2_evt.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,10 +35,10 @@ #include #include "test_nrm2.h" -class dnrm2_EVT : +class dnrm2EVT : public ::testing::TestWithParam> {}; -TEST_P( dnrm2_EVT, EVT ) +TEST_P( dnrm2EVT, API ) { using T = double; //---------------------------------------------------------- @@ -62,41 +62,6 @@ TEST_P( dnrm2_EVT, EVT ) test_nrm2(n, incx, i, iexval, j, jexval); } -// Prints the test case combination -class dnrm2_TestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - // vector length: - gtint_t n = std::get<0>(str.param); - // stride size for x: - gtint_t incx = std::get<1>(str.param); - // index with extreme value iexval. - gtint_t i = std::get<2>(str.param); - double iexval = std::get<3>(str.param); - // index with extreme value jexval. - gtint_t j = std::get<4>(str.param); - double jexval = std::get<5>(str.param); -#ifdef TEST_BLAS - std::string str_name = "dnrm2_"; -#elif TEST_CBLAS - std::string str_name = "cblas_dnrm2"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_dnormfv"; -#endif - str_name = str_name + "_" + std::to_string(n); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name = str_name + "_" + incx_str; - str_name = str_name + "_i" + std::to_string(i); - std::string iexval_str = testinghelpers::get_value_string(iexval); - str_name = str_name + "_" + iexval_str; - str_name = str_name + "_j" + std::to_string(j); - std::string jexval_str = testinghelpers::get_value_string(jexval); - str_name = str_name + "_" + jexval_str; - return str_name; - } -}; - static double NaN = std::numeric_limits::quiet_NaN(); static double Inf = std::numeric_limits::infinity(); @@ -114,7 +79,7 @@ static double Inf = std::numeric_limits::infinity(); // of having first a NaN and then an Inf and so on. INSTANTIATE_TEST_SUITE_P( scalar, - dnrm2_EVT, + dnrm2EVT, ::testing::Combine( // m size of vector ::testing::Values(gtint_t(3)), @@ -127,12 +92,12 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(2), ::testing::Values(1.0, NaN, Inf, -Inf) ), - ::dnrm2_TestPrint() + ::nrm2EVTPrint() ); INSTANTIATE_TEST_SUITE_P( vector_F8, - dnrm2_EVT, + dnrm2EVT, ::testing::Combine( // m size of vector ::testing::Values(gtint_t(8)), @@ -145,14 +110,14 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(6), ::testing::Values(1.0, NaN, Inf, -Inf) ), - ::dnrm2_TestPrint() + ::nrm2EVTPrint() ); // To test the second for-loop (F4), we use n = 12 // and ensure that the extreme values are on or after index 8. INSTANTIATE_TEST_SUITE_P( vector_F4, - dnrm2_EVT, + dnrm2EVT, ::testing::Combine( // m size of vector ::testing::Values(gtint_t(12)), @@ -165,7 +130,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(11), ::testing::Values(1.0, NaN, Inf, -Inf) ), - ::dnrm2_TestPrint() + ::nrm2EVTPrint() ); // Now let's check the combination of a vectorized path and @@ -173,7 +138,7 @@ INSTANTIATE_TEST_SUITE_P( // to check that the checks are integrated correctly. INSTANTIATE_TEST_SUITE_P( vector_scalar, - dnrm2_EVT, + dnrm2EVT, ::testing::Combine( // m size of vector ::testing::Values(gtint_t(10)), @@ -186,7 +151,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(8), ::testing::Values(1.0, NaN, Inf, -Inf) ), - ::dnrm2_TestPrint() + ::nrm2EVTPrint() ); // Multithreading unit tester @@ -210,7 +175,7 @@ INSTANTIATE_TEST_SUITE_P( */ INSTANTIATE_TEST_SUITE_P( EVT_MT_Unit_Tester, - dnrm2_EVT, + dnrm2EVT, ::testing::Combine( // m size of vector ::testing::Values(gtint_t(256), @@ -234,7 +199,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(4, 17, 125, 201), ::testing::Values(1.0, NaN, Inf, -Inf) ), - ::dnrm2_TestPrint() + ::nrm2EVTPrint() ); // Instantiator if AOCL_DYNAMIC is enabled @@ -245,7 +210,7 @@ INSTANTIATE_TEST_SUITE_P( */ INSTANTIATE_TEST_SUITE_P( EVT_MT_AOCL_DYNAMIC, - dnrm2_EVT, + dnrm2EVT, ::testing::Combine( // m size of vector ::testing::Values(gtint_t(2950000), @@ -262,5 +227,5 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(1500000, 2500000), ::testing::Values(-Inf, NaN) ), - ::dnrm2_TestPrint() + ::nrm2EVTPrint() ); diff --git a/gtestsuite/testsuite/util/nrm2/dnrm2_generic.cpp b/gtestsuite/testsuite/util/nrm2/dnrm2_generic.cpp index 422f5bfe76..aded14ffa4 100644 --- a/gtestsuite/testsuite/util/nrm2/dnrm2_generic.cpp +++ b/gtestsuite/testsuite/util/nrm2/dnrm2_generic.cpp @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -35,10 +35,10 @@ #include #include "test_nrm2.h" -class dnrm2Test : +class dnrm2Generic : public ::testing::TestWithParam> {}; -TEST_P( dnrm2Test, RandomData ) +TEST_P( dnrm2Generic, API ) { using T = double; //---------------------------------------------------------- @@ -51,7 +51,14 @@ TEST_P( dnrm2Test, RandomData ) gtint_t incx = std::get<1>(GetParam()); // Set the threshold for the errors: - double thresh = std::sqrt(n)*testinghelpers::getEpsilon(); + // Check gtestsuite nrm2.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (n == 0) + thresh = 0.0; + else + thresh = std::sqrt(n)*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters @@ -59,27 +66,6 @@ TEST_P( dnrm2Test, RandomData ) test_nrm2( n, incx, thresh ); } -// Prints the test case combination -class dnrm2TestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - gtint_t n = std::get<0>(str.param); - gtint_t incx = std::get<1>(str.param); -#ifdef TEST_BLAS - std::string str_name = "dnrm2_"; -#elif TEST_CBLAS - std::string str_name = "cblas_dnrm2"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_dnormfv"; -#endif - str_name = str_name + "_" + std::to_string(n); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name = str_name + "_" + incx_str; - return str_name; - } -}; - /** * dnrm2 implementation is composed by two parts: * - vectorized path for n>4 @@ -90,7 +76,7 @@ class dnrm2TestPrint { INSTANTIATE_TEST_SUITE_P( AT_1T, - dnrm2Test, + dnrm2Generic, ::testing::Combine( // m size of vector ::testing::Values(gtint_t(1), // trivial case n=1 @@ -112,7 +98,7 @@ INSTANTIATE_TEST_SUITE_P( #endif ) ), - ::dnrm2TestPrint() + ::nrm2GenericPrint() ); // Multithreading unit tester @@ -139,7 +125,7 @@ INSTANTIATE_TEST_SUITE_P( */ INSTANTIATE_TEST_SUITE_P( AT_MT_Unit_Tester, - dnrm2Test, + dnrm2Generic, ::testing::Combine( // m size of vector ::testing::Values(gtint_t(256), @@ -162,7 +148,7 @@ INSTANTIATE_TEST_SUITE_P( #endif ) ), - ::dnrm2TestPrint() + ::nrm2GenericPrint() ); // Instantiator if AOCL_DYNAMIC is enabled @@ -173,7 +159,7 @@ INSTANTIATE_TEST_SUITE_P( */ INSTANTIATE_TEST_SUITE_P( AT_MT_AOCL_DYNAMIC, - dnrm2Test, + dnrm2Generic, ::testing::Combine( // m size of vector ::testing::Values(gtint_t(2950000), @@ -188,5 +174,5 @@ INSTANTIATE_TEST_SUITE_P( #endif ) ), - ::dnrm2TestPrint() + ::nrm2GenericPrint() ); diff --git a/gtestsuite/testsuite/util/nrm2/dznrm2_extreme_values.cpp b/gtestsuite/testsuite/util/nrm2/dznrm2_evt.cpp similarity index 76% rename from gtestsuite/testsuite/util/nrm2/dznrm2_extreme_values.cpp rename to gtestsuite/testsuite/util/nrm2/dznrm2_evt.cpp index 993859265c..98065557b8 100644 --- a/gtestsuite/testsuite/util/nrm2/dznrm2_extreme_values.cpp +++ b/gtestsuite/testsuite/util/nrm2/dznrm2_evt.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,10 +35,10 @@ #include #include "test_nrm2.h" -class dznrm2_EVT : +class dznrm2EVT : public ::testing::TestWithParam>{}; -TEST_P( dznrm2_EVT, EVT ) +TEST_P( dznrm2EVT, API ) { using T = dcomplex; //---------------------------------------------------------- @@ -62,41 +62,6 @@ TEST_P( dznrm2_EVT, EVT ) test_nrm2(n, incx, i, iexval, j, jexval); } -// Prints the test case combination -class dznrm2_TestPrint{ -public: - std::string operator()( - testing::TestParamInfo> str) const { - // vector length: - gtint_t n = std::get<0>(str.param); - // stride size for x: - gtint_t incx = std::get<1>(str.param); - // index with extreme value iexval. - gtint_t i = std::get<2>(str.param); - dcomplex iexval = std::get<3>(str.param); - // index with extreme value jexval. - gtint_t j = std::get<4>(str.param); - dcomplex jexval = std::get<5>(str.param); -#ifdef TEST_BLAS - std::string str_name = "dznrm2_"; -#elif TEST_CBLAS - std::string str_name = "cblas_dznrm2"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_znormfv"; -#endif - str_name = str_name + "_" + std::to_string(n); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name = str_name + "_" + incx_str; - str_name = str_name + "_i" + std::to_string(i); - std::string iexval_str = "_Re_" + testinghelpers::get_value_string(iexval.real) + "_Im_" + testinghelpers::get_value_string(iexval.imag); - str_name = str_name + iexval_str; - str_name = str_name + "_j" + std::to_string(j); - std::string jexval_str = "_Re_" + testinghelpers::get_value_string(jexval.real) + "_Im_" + testinghelpers::get_value_string(jexval.imag); - str_name = str_name + jexval_str; - return str_name; - } -}; - static double NaN = std::numeric_limits::quiet_NaN(); static double Inf = std::numeric_limits::infinity(); /** @@ -113,7 +78,7 @@ static double Inf = std::numeric_limits::infinity(); // of having first a NaN and then an Inf and so on. INSTANTIATE_TEST_SUITE_P( scalar, - dznrm2_EVT, + dznrm2EVT, ::testing::Combine( // m size of vector ::testing::Values(gtint_t(2)), @@ -126,12 +91,12 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(1), ::testing::Values(dcomplex{1.0, 2.0}, dcomplex{NaN, 1.0}, dcomplex{Inf, 9.0}, dcomplex{-1.0, -Inf}, dcomplex{2.0, NaN}) ), - ::dznrm2_TestPrint() + ::nrm2EVTPrint() ); INSTANTIATE_TEST_SUITE_P( vector_F4, - dznrm2_EVT, + dznrm2EVT, ::testing::Combine( // m size of vector ::testing::Values(gtint_t(4)), @@ -144,14 +109,14 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(3), ::testing::Values(dcomplex{1.0, 2.0}, dcomplex{NaN, 1.0}, dcomplex{Inf, 9.0}, dcomplex{-1.0, -Inf}, dcomplex{2.0, NaN}) ), - ::dznrm2_TestPrint() + ::nrm2EVTPrint() ); // To test the second for-loop (F2), we use n = 6 // and ensure that the extreme values are on or after index 4. INSTANTIATE_TEST_SUITE_P( vector_F2, - dznrm2_EVT, + dznrm2EVT, ::testing::Combine( // m size of vector ::testing::Values(gtint_t(6)), @@ -164,7 +129,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(5), ::testing::Values(dcomplex{1.0, 2.0}, dcomplex{NaN, 1.0}, dcomplex{Inf, 9.0}, dcomplex{-1.0, -Inf}, dcomplex{2.0, NaN}) ), - ::dznrm2_TestPrint() + ::nrm2EVTPrint() ); // Now let's check the combination of a vectorized path and @@ -172,7 +137,7 @@ INSTANTIATE_TEST_SUITE_P( // to check that the checks are integrated correctly. INSTANTIATE_TEST_SUITE_P( vector_scalar, - dznrm2_EVT, + dznrm2EVT, ::testing::Combine( // m size of vector ::testing::Values(gtint_t(7)), @@ -185,7 +150,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(6), ::testing::Values(dcomplex{NaN, 1.0}, dcomplex{Inf, 9.0}, dcomplex{-1.0, -Inf}, dcomplex{2.0, NaN}) ), - ::dznrm2_TestPrint() + ::nrm2EVTPrint() ); // Mutlthreading Unit Tester @@ -210,7 +175,7 @@ INSTANTIATE_TEST_SUITE_P( */ INSTANTIATE_TEST_SUITE_P( EVT_MT_Unit_Tester, - dznrm2_EVT, + dznrm2EVT, ::testing::Combine( // m size of vector ::testing::Values(gtint_t(128), @@ -234,7 +199,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(6, 25, 64, 127), ::testing::Values(dcomplex{NaN, 1.0}, dcomplex{Inf, 9.0}, dcomplex{-1.0, -Inf}, dcomplex{2.0, NaN}) ), - ::dznrm2_TestPrint() + ::nrm2EVTPrint() ); // Instantiator if AOCL_DYNAMIC is enabled @@ -245,7 +210,7 @@ INSTANTIATE_TEST_SUITE_P( */ INSTANTIATE_TEST_SUITE_P( EVT_MT_AOCL_DYNAMIC, - dznrm2_EVT, + dznrm2EVT, ::testing::Combine( // m size of vector ::testing::Values(gtint_t(1530000), @@ -260,5 +225,5 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(1100000, 1500000), ::testing::Values(dcomplex{NaN, Inf}, dcomplex{-Inf, NaN}, dcomplex{Inf, 0.0}) ), - ::dznrm2_TestPrint() + ::nrm2EVTPrint() ); diff --git a/gtestsuite/testsuite/util/nrm2/dznrm2_generic.cpp b/gtestsuite/testsuite/util/nrm2/dznrm2_generic.cpp index a0fb186ccc..24f70881e3 100644 --- a/gtestsuite/testsuite/util/nrm2/dznrm2_generic.cpp +++ b/gtestsuite/testsuite/util/nrm2/dznrm2_generic.cpp @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -35,10 +35,10 @@ #include #include "test_nrm2.h" -class dznrm2Test : +class dznrm2Generic : public ::testing::TestWithParam> {}; -TEST_P( dznrm2Test, RandomData ) +TEST_P( dznrm2Generic, API ) { using T = dcomplex; //---------------------------------------------------------- @@ -51,7 +51,15 @@ TEST_P( dznrm2Test, RandomData ) gtint_t incx = std::get<1>(GetParam()); // Set the threshold for the errors: - double thresh = 3*testinghelpers::getEpsilon(); + // Check gtestsuite asumv.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (n == 0) + thresh = 0.0; + else + thresh = std::sqrt(n)*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters @@ -59,27 +67,6 @@ TEST_P( dznrm2Test, RandomData ) test_nrm2(n, incx, thresh); } -// Prints the test case combination -class dznrm2TestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - gtint_t n = std::get<0>(str.param); - gtint_t incx = std::get<1>(str.param); -#ifdef TEST_BLAS - std::string str_name = "dznrm2_"; -#elif TEST_CBLAS - std::string str_name = "cblas_dznrm2"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_znormfv"; -#endif - str_name = str_name + "_" + std::to_string(n); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name = str_name + "_" + incx_str; - return str_name; - } -}; - /** * dznrm2 implementation is composed by two parts: * - vectorized path for n>2 @@ -89,7 +76,7 @@ class dznrm2TestPrint { */ INSTANTIATE_TEST_SUITE_P( AT_1T, - dznrm2Test, + dznrm2Generic, ::testing::Combine( // m size of vector ::testing::Values(gtint_t(1), // trivial case n=1 @@ -111,7 +98,7 @@ INSTANTIATE_TEST_SUITE_P( #endif ) ), - ::dznrm2TestPrint() + ::nrm2GenericPrint() ); // Multithreading unit tester @@ -133,7 +120,7 @@ INSTANTIATE_TEST_SUITE_P( */ INSTANTIATE_TEST_SUITE_P( AT_MT_Unit_Tester, - dznrm2Test, + dznrm2Generic, ::testing::Combine( // m size of vector ::testing::Values(gtint_t(128), @@ -155,7 +142,7 @@ INSTANTIATE_TEST_SUITE_P( #endif ) ), - ::dznrm2TestPrint() + ::nrm2GenericPrint() ); // Instantiator if AOCL_DYNAMIC is enabled @@ -166,7 +153,7 @@ INSTANTIATE_TEST_SUITE_P( */ INSTANTIATE_TEST_SUITE_P( AT_MT_AOCL_DYNAMIC, - dznrm2Test, + dznrm2Generic, ::testing::Combine( // m size of vector ::testing::Values(gtint_t(1530000), @@ -179,5 +166,5 @@ INSTANTIATE_TEST_SUITE_P( #endif ) ), - ::dznrm2TestPrint() + ::nrm2GenericPrint() ); diff --git a/gtestsuite/testsuite/util/nrm2/nrm2.h b/gtestsuite/testsuite/util/nrm2/nrm2.h index 9693a70aa0..45780998b5 100644 --- a/gtestsuite/testsuite/util/nrm2/nrm2.h +++ b/gtestsuite/testsuite/util/nrm2/nrm2.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -36,22 +36,23 @@ #include "blis.h" #include "common/testing_helpers.h" +#include "inc/check_error.h" /** * @brief Computes the Euclidean norm of x. - * + * * Euclidean norm of a vector x is defined as nrm2 = sqrt(x'*x). * In case a vector element is NaN, nrm2 must be NaN. * In case a vector element is inf, and there is no element which is NaN, nrm2 must be inf. * If n <= 0, nrm2 returns zero. * If incx = 0, nrm2 returns sqrt(n*abs(x[0])**2). - * + * * @param[in] n vector length * @param[in] x pointer which points to the first element of x * @param[in] incx increment of x * @return the Euclidean norm of x - * - * + * + * */ template::real_type> @@ -68,6 +69,20 @@ static RT nrm2_(gtint_t n, T* x, gtint_t incx){ throw std::runtime_error("Error in testsuite/level1/nrm2.h: Invalid typename in nrm2_()."); } +template::real_type> +static RT nrm2_blis_impl(gtint_t n, T* x, gtint_t incx){ + if constexpr (std::is_same::value) + return snrm2_blis_impl( &n, x, &incx ); + else if constexpr (std::is_same::value) + return dnrm2_blis_impl( &n, x, &incx ); + else if constexpr (std::is_same::value) + return scnrm2_blis_impl( &n, x, &incx ); + else if constexpr (std::is_same::value) + return dznrm2_blis_impl( &n, x, &incx ); + else + throw std::runtime_error("Error in testsuite/level1/nrm2.h: Invalid typename in nrm2_blis_impl()."); +} + template::real_type> static RT cblas_nrm2(gtint_t n, T* x, gtint_t incx){ if constexpr (std::is_same::value) @@ -101,8 +116,26 @@ static RT typed_nrm2(gtint_t n, T* x, gtint_t incx){ template::real_type> static RT nrm2(gtint_t n, T* x, gtint_t incx) { + +#ifdef TEST_INPUT_ARGS + // Create copy of scalar input values so we can check that they are not altered. + gtint_t n_cpy = n; + gtint_t incx_cpy = incx; + + // Create copy of input arrays so we can check that they are not altered. + T* x_cpy = nullptr; + gtint_t size_x = testinghelpers::buff_dim( n, incx ); + if (x && size_x > 0) + { + x_cpy = new T[size_x]; + memcpy( x_cpy, x, size_x * sizeof( T ) ); + } +#endif + #ifdef TEST_BLAS return nrm2_(n, x, incx); +#elif TEST_BLAS_BLIS_IMPL + return nrm2_blis_impl(n, x, incx); #elif TEST_CBLAS return cblas_nrm2(n, x, incx); #elif TEST_BLIS_TYPED @@ -110,4 +143,23 @@ static RT nrm2(gtint_t n, T* x, gtint_t incx) #else throw std::runtime_error("Error in testsuite/level1/axpyv.h: No interfaces are set to be tested."); #endif + +#ifdef TEST_INPUT_ARGS + //---------------------------------------------------------- + // Check scalar inputs have not been modified. + //---------------------------------------------------------- + + computediff( "n", n, n_cpy ); + computediff( "incx", incx, incx_cpy ); + + //---------------------------------------------------------- + // Bitwise-wise check array inputs have not been modified. + //---------------------------------------------------------- + + if (x && size_x > 0) + { + computediff( "x", n, x, x_cpy, incx, true ); + delete[] x_cpy; + } +#endif } diff --git a/gtestsuite/testsuite/util/nrm2/nrm2_IIT_ERS.cpp b/gtestsuite/testsuite/util/nrm2/nrm2_IIT_ERS.cpp new file mode 100644 index 0000000000..bd8699be07 --- /dev/null +++ b/gtestsuite/testsuite/util/nrm2/nrm2_IIT_ERS.cpp @@ -0,0 +1,128 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_nrm2.h" +#include "common/wrong_inputs_helpers.h" + +/* + Early Return Scenarios(ERS) for BLAS/CBLAS compliance : + + The NRM2 API is expected to return early in the following cases: + 1. When n <= 0 (BLAS compliance). +*/ + +template +class nrm2_IIT_ERS : public ::testing::Test {}; +typedef ::testing::Types TypeParam; +TYPED_TEST_SUITE(nrm2_IIT_ERS, TypeParam); + +// Adding namespace to get default parameters from testinghelpers/common/wrong_input_helpers.h. +using namespace testinghelpers::IIT; + +// Early return n < 0. +TYPED_TEST(nrm2_IIT_ERS, n_lt_zero_nonUnitStrides) { + using T = TypeParam; + using RT = typename testinghelpers::type_info::real_type; + gtint_t invalid_n = -1; + // initialize norm to ensure that it is set to zero from nrm2 and it does not simply return. + RT blis_norm = -4.2; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + blis_norm = nrm2(invalid_n, nullptr, INC); + computediff("norm", blis_norm, 0.0); + + // Test with all arguments correct except for the value we are choosing to test. + // Defining the x vector + std::vector x = testinghelpers::get_random_vector( -10, 10, N, INC ); + blis_norm = nrm2(invalid_n, x.data(), INC); + computediff("norm", blis_norm, 0.0); +} + +// Early return n = 0. +TYPED_TEST(nrm2_IIT_ERS, n_eq_zero_nonUnitStrides) { + using T = TypeParam; + using RT = typename testinghelpers::type_info::real_type; + gtint_t invalid_n = 0; + // initialize norm to ensure that it is set to zero from nrm2 and it does not simply return. + RT blis_norm = 19.0; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + blis_norm = nrm2(invalid_n, nullptr, INC); + computediff("norm", blis_norm, 0.0); + + // Test with all arguments correct except for the value we are choosing to test. + // Defining the x vector + std::vector x = testinghelpers::get_random_vector( -10, 10, N, INC ); + blis_norm = nrm2(invalid_n, x.data(), INC); + computediff("norm", blis_norm, 0.0); +} + +// Early return n < 0. +TYPED_TEST(nrm2_IIT_ERS, n_lt_zero_unitStrides) { + using T = TypeParam; + using RT = typename testinghelpers::type_info::real_type; + gtint_t invalid_n = -1; + // initialize norm to ensure that it is set to zero from nrm2 and it does not simply return. + RT blis_norm = -4.2; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + blis_norm = nrm2(invalid_n, nullptr, 1); + computediff("norm", blis_norm, 0.0); + + // Test with all arguments correct except for the value we are choosing to test. + // Defining the x vector + std::vector x = testinghelpers::get_random_vector( -10, 10, N, 1 ); + blis_norm = nrm2(invalid_n, x.data(), 1); + computediff("norm", blis_norm, 0.0); +} + +// Early return n = 0. +TYPED_TEST(nrm2_IIT_ERS, n_eq_zero_unitStrides) { + using T = TypeParam; + using RT = typename testinghelpers::type_info::real_type; + gtint_t invalid_n = 0; + // initialize norm to ensure that it is set to zero from nrm2 and it does not simply return. + RT blis_norm = 19.0; + + // Test with nullptr for all suitable arguments that shouldn't be accessed. + blis_norm = nrm2(invalid_n, nullptr, 1); + computediff("norm", blis_norm, 0.0); + + // Test with all arguments correct except for the value we are choosing to test. + // Defining the x vector + std::vector x = testinghelpers::get_random_vector( -10, 10, N, 1 ); + blis_norm = nrm2(invalid_n, x.data(), 1); + computediff("norm", blis_norm, 0.0); +} diff --git a/gtestsuite/testsuite/util/nrm2/nrm2_corner_cases.cpp b/gtestsuite/testsuite/util/nrm2/nrm2_extreme_cases.cpp similarity index 68% rename from gtestsuite/testsuite/util/nrm2/nrm2_corner_cases.cpp rename to gtestsuite/testsuite/util/nrm2/nrm2_extreme_cases.cpp index c4e09cd83e..edb2383613 100644 --- a/gtestsuite/testsuite/util/nrm2/nrm2_corner_cases.cpp +++ b/gtestsuite/testsuite/util/nrm2/nrm2_extreme_cases.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -37,31 +37,10 @@ /** * Testing edge input parameters. - * - * zero n should return 0. + * * zero incx should return sqrt(n*abs(x[0])**2). */ -// Early return. -template -class nrm2_ERS : public ::testing::Test {}; -typedef ::testing::Types TypeParam; -TYPED_TEST_SUITE(nrm2_ERS, TypeParam); - -TYPED_TEST(nrm2_ERS, zero_n) { - using T = TypeParam; - using RT = typename testinghelpers::type_info::real_type; - gtint_t n = 0; - gtint_t incx = 1; - // initialize norm to ensure that it is set to zero from nrm2 and it does not simply return. - RT blis_norm = 19.0; - // using nullptr since x should not be accessed anyway. - // If "x" is accessed before return then nrm2 would segfault. - blis_norm = nrm2(n, nullptr, incx); - RT ref_norm = testinghelpers::ref_nrm2(n, nullptr, incx); - computediff(blis_norm, ref_norm); -} - // Edge case where it actually does not return early. // Since there are 2 different paths, vectorized and scalar, // we break this into 2 tests, once for each case. @@ -72,8 +51,8 @@ TYPED_TEST_SUITE(nrm2_EIC, TypeParam); TYPED_TEST(nrm2_EIC, zero_incx_scalar) { using T = TypeParam; - using RT = typename testinghelpers::type_info::real_type; - gtint_t n = 2; + using RT = typename testinghelpers::type_info::real_type; + gtint_t n = 2; gtint_t incx = 0; std::vector x(n); for (auto &xi : x) @@ -85,13 +64,13 @@ TYPED_TEST(nrm2_EIC, zero_incx_scalar) { RT blis_norm = 19.0; blis_norm = nrm2(n, x.data(), incx); RT ref_norm = testinghelpers::ref_nrm2(n, x.data(), incx); - computediff(blis_norm, ref_norm); + computediff("norm", blis_norm, ref_norm); } TYPED_TEST(nrm2_EIC, zero_incx_vectorized) { using T = TypeParam; - using RT = typename testinghelpers::type_info::real_type; - gtint_t n = 64; + using RT = typename testinghelpers::type_info::real_type; + gtint_t n = 64; gtint_t incx = 0; std::vector x(n); for (auto &xi : x) @@ -103,7 +82,7 @@ TYPED_TEST(nrm2_EIC, zero_incx_vectorized) { RT blis_norm = 19.0; blis_norm = nrm2(n, x.data(), incx); RT ref_norm = testinghelpers::ref_nrm2(n, x.data(), incx); - computediff(blis_norm, ref_norm); + computediff("norm", blis_norm, ref_norm); } /* @@ -126,5 +105,5 @@ TYPED_TEST( nrm2_EIC, zero_incx_MT ) { x[0] = T{2.0}*x[0]; RT blis_norm = nrm2(n, x.data(), incx); RT ref_norm = testinghelpers::ref_nrm2(n, x.data(), incx); - computediff(blis_norm, ref_norm); + computediff("norm", blis_norm, ref_norm); } diff --git a/gtestsuite/testsuite/util/nrm2/nrm2_invalid_inputs.cpp b/gtestsuite/testsuite/util/nrm2/nrm2_invalid_inputs.cpp deleted file mode 100644 index 3a702de62b..0000000000 --- a/gtestsuite/testsuite/util/nrm2/nrm2_invalid_inputs.cpp +++ /dev/null @@ -1,61 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include "test_nrm2.h" -#include "common/wrong_inputs_helpers.h" - -/** - * Testing invalid/incorrect input parameters. - * - * That is only negative n for this API. Zero incx and zero n is allowed. -*/ -template -class nrm2_IIT : public ::testing::Test {}; -typedef ::testing::Types TypeParam; -TYPED_TEST_SUITE(nrm2_IIT, TypeParam); - -// Adding namespace to get default parameters from testinghelpers/common/wrong_input_helpers.h. -using namespace testinghelpers::IIT; - -TYPED_TEST(nrm2_IIT, negative_n) { - using T = TypeParam; - using RT = typename testinghelpers::type_info::real_type; - T x = T{-3.7}; - // initialize blis norm with garbage. - RT blis_norm = -4.2; - blis_norm = nrm2(-2, &x, INC); - - computediff(blis_norm, 0.0); -} diff --git a/gtestsuite/testsuite/util/nrm2/nrm2_underflow_overflow.cpp b/gtestsuite/testsuite/util/nrm2/nrm2_underflow_overflow.cpp index 22e0141292..3fdec5078d 100644 --- a/gtestsuite/testsuite/util/nrm2/nrm2_underflow_overflow.cpp +++ b/gtestsuite/testsuite/util/nrm2/nrm2_underflow_overflow.cpp @@ -1,13 +1,47 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + #include #include "test_nrm2.h" template -class OUT_nrm2 : public ::testing::Test {}; +class nrm2UOT : public ::testing::Test {}; typedef ::testing::Types TypeParam; -TYPED_TEST_SUITE(OUT_nrm2, TypeParam); +TYPED_TEST_SUITE(nrm2UOT, TypeParam); // Testing for max representable number to see if overflow is handled correctly. -TYPED_TEST(OUT_nrm2, maxFP_scalar) { +TYPED_TEST(nrm2UOT, maxFP_scalar) { using T = TypeParam; using RT = typename testinghelpers::type_info::real_type; @@ -15,9 +49,9 @@ TYPED_TEST(OUT_nrm2, maxFP_scalar) { T x = T{maxval}; RT norm = nrm2(1, &x, 1); - computediff(maxval, norm); + computediff("norm", norm, maxval); } -TYPED_TEST(OUT_nrm2, maxFP_vectorized) { +TYPED_TEST(nrm2UOT, maxFP_vectorized) { using T = TypeParam; using RT = typename testinghelpers::type_info::real_type; gtint_t n = 64; @@ -25,20 +59,20 @@ TYPED_TEST(OUT_nrm2, maxFP_vectorized) { RT maxval = (std::numeric_limits::max)(); x[17] = T{maxval}; RT norm = nrm2(n, x.data(), 1); - computediff(maxval, norm); + computediff("norm", norm, maxval); } // Testing for min representable number to see if underflow is handled correctly. -TYPED_TEST(OUT_nrm2, minFP_scalar) { +TYPED_TEST(nrm2UOT, minFP_scalar) { using T = TypeParam; using RT = typename testinghelpers::type_info::real_type; RT minval = (std::numeric_limits::min)(); T x = T{minval}; RT norm = nrm2(1, &x, 1); - computediff(minval, norm); + computediff("norm", norm, minval); } -TYPED_TEST(OUT_nrm2, minFP_vectorized) { +TYPED_TEST(nrm2UOT, minFP_vectorized) { using T = TypeParam; using RT = typename testinghelpers::type_info::real_type; gtint_t n = 64; @@ -46,27 +80,27 @@ TYPED_TEST(OUT_nrm2, minFP_vectorized) { RT minval = (std::numeric_limits::min)(); x[17] = T{minval}; RT norm = nrm2(n, x.data(), 1); - computediff(minval, norm); + computediff("norm", norm, minval); } // Since there are 2 different paths, vectorized and scalar, // we break this into 2 tests, once for each case. -TYPED_TEST(OUT_nrm2, zeroFP_scalar) { +TYPED_TEST(nrm2UOT, zeroFP_scalar) { using T = TypeParam; using RT = typename testinghelpers::type_info::real_type; T x = T{0}; RT norm = nrm2(1, &x, 1); - computediff(0, norm); + computediff("norm", norm, 0); } -TYPED_TEST(OUT_nrm2, zeroFP_vectorized) { +TYPED_TEST(nrm2UOT, zeroFP_vectorized) { using T = TypeParam; using RT = typename testinghelpers::type_info::real_type; gtint_t n = 64; std::vector x(n, T{0}); RT norm = nrm2(n, x.data(), 1); - computediff(0, norm); + computediff("norm", norm, 0); } /* @@ -77,7 +111,7 @@ TYPED_TEST(OUT_nrm2, zeroFP_vectorized) { */ // Checking only for overflow, based on the threshold -TYPED_TEST( OUT_nrm2, OFlow_MT ) { +TYPED_TEST( nrm2UOT, OFlow_MT ) { using T = TypeParam; using RT = typename testinghelpers::type_info::real_type; gtint_t n = 2950000; @@ -101,11 +135,11 @@ TYPED_TEST( OUT_nrm2, OFlow_MT ) { RT norm = nrm2( n, x.data(), 1 ); RT ref_norm = testinghelpers::ref_nrm2( n, x.data(), 1 ); - computediff( norm, ref_norm, thresh ); + computediff( "norm", norm, ref_norm, thresh ); } // Checking only for underflow, based on the threshold -TYPED_TEST( OUT_nrm2, UFlow_MT ) { +TYPED_TEST( nrm2UOT, UFlow_MT ) { using T = TypeParam; using RT = typename testinghelpers::type_info::real_type; gtint_t n = 2950000; @@ -129,11 +163,11 @@ TYPED_TEST( OUT_nrm2, UFlow_MT ) { RT norm = nrm2( n, x.data(), 1 ); RT ref_norm = testinghelpers::ref_nrm2( n, x.data(), 1 ); - computediff( norm, ref_norm, thresh ); + computediff( "norm", norm, ref_norm, thresh ); } // Checking for both overflow and underflow, based on the thresholds -TYPED_TEST( OUT_nrm2, OUFlow_MT ) { +TYPED_TEST( nrm2UOT, OUFlow_MT ) { using T = TypeParam; using RT = typename testinghelpers::type_info::real_type; gtint_t n = 2950000; @@ -159,7 +193,7 @@ TYPED_TEST( OUT_nrm2, OUFlow_MT ) { RT norm = nrm2( n, x.data(), 1 ); RT ref_norm = testinghelpers::ref_nrm2( n, x.data(), 1 ); - computediff( norm, ref_norm, thresh ); + computediff( "norm", norm, ref_norm, thresh ); } // Specific test case used by an ISV. @@ -170,8 +204,8 @@ TEST(dnrm2, largeDouble) { std::vector x{3e300, 4e300}, y{-4e300, -3e300}; T norm = nrm2(n, x.data(), 1); - computediff(5e300, norm); + computediff( "norm", norm, 5e300 ); norm = nrm2(n, y.data(), 1); - computediff(5e300, norm); + computediff( "norm", norm, 5e300 ); } diff --git a/gtestsuite/testsuite/util/nrm2/scnrm2_extreme_values.cpp b/gtestsuite/testsuite/util/nrm2/scnrm2_evt.cpp similarity index 67% rename from gtestsuite/testsuite/util/nrm2/scnrm2_extreme_values.cpp rename to gtestsuite/testsuite/util/nrm2/scnrm2_evt.cpp index 52ba4f8647..d7331e7c90 100644 --- a/gtestsuite/testsuite/util/nrm2/scnrm2_extreme_values.cpp +++ b/gtestsuite/testsuite/util/nrm2/scnrm2_evt.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,10 +35,10 @@ #include #include "test_nrm2.h" -class scnrm2_EVT : +class scnrm2EVT : public ::testing::TestWithParam>{}; -TEST_P( scnrm2_EVT, EVT ) +TEST_P( scnrm2EVT, API ) { using T = scomplex; //---------------------------------------------------------- @@ -62,41 +62,6 @@ TEST_P( scnrm2_EVT, EVT ) test_nrm2(n, incx, i, iexval, j, jexval); } -// Prints the test case combination -class scnrm2_TestPrint{ -public: - std::string operator()( - testing::TestParamInfo> str) const { - // vector length: - gtint_t n = std::get<0>(str.param); - // stride size for x: - gtint_t incx = std::get<1>(str.param); - // index with extreme value iexval. - gtint_t i = std::get<2>(str.param); - scomplex iexval = std::get<3>(str.param); - // index with extreme value jexval. - gtint_t j = std::get<4>(str.param); - scomplex jexval = std::get<5>(str.param); -#ifdef TEST_BLAS - std::string str_name = "scnrm2_"; -#elif TEST_CBLAS - std::string str_name = "cblas_scnrm2"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_cnormfv"; -#endif - str_name = str_name + "_" + std::to_string(n); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name = str_name + "_" + incx_str; - str_name = str_name + "_i" + std::to_string(i); - std::string iexval_str = "_Re_" + testinghelpers::get_value_string(iexval.real) + "_Im_" + testinghelpers::get_value_string(iexval.imag); - str_name = str_name + iexval_str; - str_name = str_name + "_j" + std::to_string(j); - std::string jexval_str = "_Re_" + testinghelpers::get_value_string(jexval.real) + "_Im_" + testinghelpers::get_value_string(jexval.imag); - str_name = str_name + jexval_str; - return str_name; - } -}; - static float NaN = std::numeric_limits::quiet_NaN(); static float Inf = std::numeric_limits::infinity(); /** @@ -114,98 +79,98 @@ static float Inf = std::numeric_limits::infinity(); // of having first a NaN and then an Inf and so on. INSTANTIATE_TEST_SUITE_P( scalar, - scnrm2_EVT, + scnrm2EVT, ::testing::Combine( // m size of vector ::testing::Values(gtint_t(2)), // stride size for x ::testing::Values(gtint_t(1)), - // i : index of x that has value iexval + // i : index of x that has value iexval ::testing::Values(0), // iexval ::testing::Values(scomplex{NaN, 1.0}, scomplex{Inf, 9.0}, scomplex{-1.0, -Inf}, scomplex{2.0, NaN}, scomplex{NaN, Inf}, scomplex{Inf, NaN}), ::testing::Values(1), ::testing::Values(scomplex{1.0, 2.0}, scomplex{NaN, 1.0}, scomplex{Inf, 9.0}, scomplex{-1.0, -Inf}, scomplex{2.0, NaN}) ), - ::scnrm2_TestPrint() + ::nrm2EVTPrint() ); INSTANTIATE_TEST_SUITE_P( vector_F16, - scnrm2_EVT, + scnrm2EVT, ::testing::Combine( // m size of vector ::testing::Values(gtint_t(64)), // stride size for x ::testing::Values(gtint_t(1)), - // i : index of x that has value iexval + // i : index of x that has value iexval ::testing::Values(10), // iexval ::testing::Values(scomplex{NaN, 1.0}, scomplex{Inf, 9.0}, scomplex{-1.0, -Inf}, scomplex{2.0, NaN}, scomplex{NaN, Inf}, scomplex{Inf, NaN}), ::testing::Values(30), ::testing::Values(scomplex{1.0, 2.0}, scomplex{NaN, 1.0}, scomplex{Inf, 9.0}, scomplex{-1.0, -Inf}, scomplex{2.0, NaN}) ), - ::scnrm2_TestPrint() + ::nrm2EVTPrint() ); // To test the second for-loop (F12), we use n = 76 = 4*16+12 // and ensure that the extreme values are on or after index 64. INSTANTIATE_TEST_SUITE_P( vector_F12, - scnrm2_EVT, + scnrm2EVT, ::testing::Combine( // m size of vector ::testing::Values(gtint_t(76)), // stride size for x ::testing::Values(gtint_t(1)), - // i : index of x that has value iexval + // i : index of x that has value iexval ::testing::Values(68), // iexval ::testing::Values(scomplex{NaN, 1.0}, scomplex{Inf, 9.0}, scomplex{-1.0, -Inf}, scomplex{2.0, NaN}, scomplex{NaN, Inf}, scomplex{Inf, NaN}), ::testing::Values(70), ::testing::Values(scomplex{1.0, 2.0}, scomplex{NaN, 1.0}, scomplex{Inf, 9.0}, scomplex{-1.0, -Inf}, scomplex{2.0, NaN}) ), - ::scnrm2_TestPrint() + ::nrm2EVTPrint() ); // To test the second for-loop (F8), we use n = 72 = 4*16+8 // and ensure that the extreme values are on or after index 64. INSTANTIATE_TEST_SUITE_P( vector_F8, - scnrm2_EVT, + scnrm2EVT, ::testing::Combine( // m size of vector ::testing::Values(gtint_t(72)), // stride size for x ::testing::Values(gtint_t(1)), - // i : index of x that has value iexval + // i : index of x that has value iexval ::testing::Values(66), // iexval ::testing::Values(scomplex{NaN, 1.0}, scomplex{Inf, 9.0}, scomplex{-1.0, -Inf}, scomplex{2.0, NaN}, scomplex{NaN, Inf}, scomplex{Inf, NaN}), ::testing::Values(70), ::testing::Values(scomplex{1.0, 2.0}, scomplex{NaN, 1.0}, scomplex{Inf, 9.0}, scomplex{-1.0, -Inf}, scomplex{2.0, NaN}) ), - ::scnrm2_TestPrint() + ::nrm2EVTPrint() ); -// Now let's check the combination of a vectorized path and +// Now let's check the combination of a vectorized path and // the scalar path, by putting an extreme value in each // to check that the checks are integrated correctly. INSTANTIATE_TEST_SUITE_P( vector_scalar, - scnrm2_EVT, + scnrm2EVT, ::testing::Combine( // m size of vector ::testing::Values(gtint_t(79)), // stride size for x ::testing::Values(gtint_t(1)), - // i : index of x that has value iexval + // i : index of x that has value iexval ::testing::Values(25), // iexval ::testing::Values(scomplex{NaN, 1.0}, scomplex{Inf, 9.0}, scomplex{-1.0, -Inf}, scomplex{2.0, NaN}, scomplex{NaN, Inf}, scomplex{Inf, NaN}), ::testing::Values(68), ::testing::Values(scomplex{NaN, 1.0}, scomplex{Inf, 9.0}, scomplex{-1.0, -Inf}, scomplex{2.0, NaN}) ), - ::scnrm2_TestPrint() + ::nrm2EVTPrint() ); diff --git a/gtestsuite/testsuite/util/nrm2/scnrm2_generic.cpp b/gtestsuite/testsuite/util/nrm2/scnrm2_generic.cpp index d27f5c50b5..1838085dcb 100644 --- a/gtestsuite/testsuite/util/nrm2/scnrm2_generic.cpp +++ b/gtestsuite/testsuite/util/nrm2/scnrm2_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,10 +35,10 @@ #include #include "test_nrm2.h" -class scnrm2Test : +class scnrm2Generic : public ::testing::TestWithParam> {}; -TEST_P( scnrm2Test, RandomData ) +TEST_P( scnrm2Generic, API ) { using T = scomplex; //---------------------------------------------------------- @@ -51,7 +51,15 @@ TEST_P( scnrm2Test, RandomData ) gtint_t incx = std::get<1>(GetParam()); // Set the threshold for the errors: - double thresh = std::sqrt(n)*testinghelpers::getEpsilon(); + // Check gtestsuite asumv.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + // No adjustment applied yet for complex data. + double thresh; + if (n == 0) + thresh = 0.0; + else + thresh = std::sqrt(n)*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters @@ -59,27 +67,6 @@ TEST_P( scnrm2Test, RandomData ) test_nrm2(n, incx, thresh); } -// Prints the test case combination -class scnrm2TestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - gtint_t n = std::get<0>(str.param); - gtint_t incx = std::get<1>(str.param); -#ifdef TEST_BLAS - std::string str_name = "scnrm2_"; -#elif TEST_CBLAS - std::string str_name = "cblas_scnrm2"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_cnormfv"; -#endif - str_name = str_name + "_" + std::to_string(n); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name = str_name + "_" + incx_str; - return str_name; - } -}; - /** * scnrm2 implementation is composed by two parts: * - vectorized path for n>=64 @@ -90,7 +77,7 @@ class scnrm2TestPrint { */ INSTANTIATE_TEST_SUITE_P( AT, - scnrm2Test, + scnrm2Generic, ::testing::Combine( // m size of vector ::testing::Values(gtint_t(1), // trivial case n=1 @@ -112,5 +99,5 @@ INSTANTIATE_TEST_SUITE_P( #endif ) ), - ::scnrm2TestPrint() + ::nrm2GenericPrint() ); diff --git a/gtestsuite/testsuite/util/nrm2/snrm2_extreme_values.cpp b/gtestsuite/testsuite/util/nrm2/snrm2_evt.cpp similarity index 65% rename from gtestsuite/testsuite/util/nrm2/snrm2_extreme_values.cpp rename to gtestsuite/testsuite/util/nrm2/snrm2_evt.cpp index 5bfa83a346..9f4b9a3f2c 100644 --- a/gtestsuite/testsuite/util/nrm2/snrm2_extreme_values.cpp +++ b/gtestsuite/testsuite/util/nrm2/snrm2_evt.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,10 +35,10 @@ #include #include "test_nrm2.h" -class snrm2_EVT : +class snrm2EVT : public ::testing::TestWithParam> {}; -TEST_P( snrm2_EVT, EVT ) +TEST_P( snrm2EVT, API ) { using T = float; //---------------------------------------------------------- @@ -62,48 +62,13 @@ TEST_P( snrm2_EVT, EVT ) test_nrm2(n, incx, i, iexval, j, jexval); } -// Prints the test case combination -class snrm2_TestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - // vector length: - gtint_t n = std::get<0>(str.param); - // stride size for x: - gtint_t incx = std::get<1>(str.param); - // index with extreme value iexval. - gtint_t i = std::get<2>(str.param); - float iexval = std::get<3>(str.param); - // index with extreme value jexval. - gtint_t j = std::get<4>(str.param); - float jexval = std::get<5>(str.param); -#ifdef TEST_BLAS - std::string str_name = "snrm2_"; -#elif TEST_CBLAS - std::string str_name = "cblas_snrm2"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_snormfv"; -#endif - str_name = str_name + "_" + std::to_string(n); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name = str_name + "_" + incx_str; - str_name = str_name + "_i" + std::to_string(i); - std::string iexval_str = testinghelpers::get_value_string(iexval); - str_name = str_name + "_" + iexval_str; - str_name = str_name + "_j" + std::to_string(j); - std::string jexval_str = testinghelpers::get_value_string(jexval); - str_name = str_name + "_" + jexval_str; - return str_name; - } -}; - static float NaN = std::numeric_limits::quiet_NaN(); static float Inf = std::numeric_limits::infinity(); /** - * Note: snrm2 scalar ONLY implementation is used, but we write the test + * Note: snrm2 scalar ONLY implementation is used, but we write the test * using values that worked for the vectorized path for the future. - * + * * scnrm2 implementation is composed by two parts: * - vectorized path for n>=64 * - for-loop for multiples of 32 (F32) @@ -118,98 +83,98 @@ static float Inf = std::numeric_limits::infinity(); // of having first a NaN and then an Inf and so on. INSTANTIATE_TEST_SUITE_P( scalar, - snrm2_EVT, + snrm2EVT, ::testing::Combine( // m size of vector ::testing::Values(gtint_t(3)), // stride size for x ::testing::Values(gtint_t(1)), - // i : index of x that has value iexval + // i : index of x that has value iexval ::testing::Values(0), // iexval ::testing::Values(NaN, Inf, -Inf), ::testing::Values(2), ::testing::Values(1.0, NaN, Inf, -Inf) ), - ::snrm2_TestPrint() + ::nrm2EVTPrint() ); INSTANTIATE_TEST_SUITE_P( vector_F32, - snrm2_EVT, + snrm2EVT, ::testing::Combine( // m size of vector ::testing::Values(gtint_t(64)), // stride size for x ::testing::Values(gtint_t(1)), - // i : index of x that has value iexval + // i : index of x that has value iexval ::testing::Values(13), // iexval ::testing::Values(NaN, Inf, -Inf), ::testing::Values(26), ::testing::Values(1.0, NaN, Inf, -Inf) ), - ::snrm2_TestPrint() + ::nrm2EVTPrint() ); // To test the second for-loop (F24), we use n = 88 = 2*32+24 // and ensure that the extreme values are on or after index 64. INSTANTIATE_TEST_SUITE_P( vector_F24, - snrm2_EVT, + snrm2EVT, ::testing::Combine( // m size of vector ::testing::Values(gtint_t(88)), // stride size for x ::testing::Values(gtint_t(1)), - // i : index of x that has value iexval + // i : index of x that has value iexval ::testing::Values(70), // iexval ::testing::Values(NaN, Inf, -Inf), ::testing::Values(80), ::testing::Values(1.0, NaN, Inf, -Inf) ), - ::snrm2_TestPrint() + ::nrm2EVTPrint() ); // To test the second for-loop (F16), we use n = 80 = 2*32+16 // and ensure that the extreme values are on or after index 64. INSTANTIATE_TEST_SUITE_P( vector_F16, - snrm2_EVT, + snrm2EVT, ::testing::Combine( // m size of vector ::testing::Values(gtint_t(80)), // stride size for x ::testing::Values(gtint_t(1)), - // i : index of x that has value iexval + // i : index of x that has value iexval ::testing::Values(70), // iexval ::testing::Values(NaN, Inf, -Inf), ::testing::Values(75), ::testing::Values(1.0, NaN, Inf, -Inf) ), - ::snrm2_TestPrint() + ::nrm2EVTPrint() ); -// Now let's check the combination of a vectorized path and +// Now let's check the combination of a vectorized path and // the scalar path, by putting an extreme value in each // to check that the checks are integrated correctly. INSTANTIATE_TEST_SUITE_P( vector_scalar, - snrm2_EVT, + snrm2EVT, ::testing::Combine( // m size of vector ::testing::Values(gtint_t(68)), // stride size for x ::testing::Values(gtint_t(1)), - // i : index of x that has value iexval + // i : index of x that has value iexval ::testing::Values(5), // iexval ::testing::Values(NaN, Inf, -Inf), ::testing::Values(65), ::testing::Values(NaN, Inf, -Inf) ), - ::snrm2_TestPrint() + ::nrm2EVTPrint() ); diff --git a/gtestsuite/testsuite/util/nrm2/snrm2_generic.cpp b/gtestsuite/testsuite/util/nrm2/snrm2_generic.cpp index eac411d12d..acd2f4bb71 100644 --- a/gtestsuite/testsuite/util/nrm2/snrm2_generic.cpp +++ b/gtestsuite/testsuite/util/nrm2/snrm2_generic.cpp @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,10 +35,10 @@ #include #include "test_nrm2.h" -class snrm2Test : +class snrm2Generic : public ::testing::TestWithParam> {}; -TEST_P( snrm2Test, RandomData ) +TEST_P( snrm2Generic, API ) { using T = float; //---------------------------------------------------------- @@ -51,7 +51,14 @@ TEST_P( snrm2Test, RandomData ) gtint_t incx = std::get<1>(GetParam()); // Set the threshold for the errors: - double thresh = 2*n*testinghelpers::getEpsilon(); + // Check gtestsuite asumv.h or netlib source code for reminder of the + // functionality from which we estimate operation count per element + // of output, and hence the multipler for epsilon. + double thresh; + if (n == 0) + thresh = 0.0; + else + thresh = std::sqrt(n)*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters @@ -59,31 +66,10 @@ TEST_P( snrm2Test, RandomData ) test_nrm2( n, incx, thresh ); } -// Prints the test case combination -class snrm2TestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - gtint_t n = std::get<0>(str.param); - gtint_t incx = std::get<1>(str.param); -#ifdef TEST_BLAS - std::string str_name = "snrm2_"; -#elif TEST_CBLAS - std::string str_name = "cblas_snrm2"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_snormfv"; -#endif - str_name = str_name + "_" + std::to_string(n); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name = str_name + "_" + incx_str; - return str_name; - } -}; - /** - * Note: snrm2 scalar ONLY implementation is used, but we write the test + * Note: snrm2 scalar ONLY implementation is used, but we write the test * using values that worked for the vectorized path for the future. - * + * * scnrm2 implementation is composed by two parts: * - vectorized path for n>=64 * - for-loop for multiples of 32 (F32) @@ -93,7 +79,7 @@ class snrm2TestPrint { */ INSTANTIATE_TEST_SUITE_P( AT, - snrm2Test, + snrm2Generic, ::testing::Combine( // m size of vector ::testing::Values(gtint_t(1), // trivial case n=1 @@ -115,5 +101,5 @@ INSTANTIATE_TEST_SUITE_P( #endif ) // stride size for x ), - ::snrm2TestPrint() + ::nrm2GenericPrint() ); diff --git a/gtestsuite/testsuite/util/nrm2/test_nrm2.h b/gtestsuite/testsuite/util/nrm2/test_nrm2.h index def4551929..48e33e99c2 100644 --- a/gtestsuite/testsuite/util/nrm2/test_nrm2.h +++ b/gtestsuite/testsuite/util/nrm2/test_nrm2.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -49,7 +49,7 @@ void test_nrm2( gtint_t n, gtint_t incx, double thresh ) // Initialize vectors with random numbers. //---------------------------------------------------------- std::vector x = testinghelpers::get_random_vector( -10, -10, n, incx ); - + //---------------------------------------------------------- // Call reference implementation to get ref results. //---------------------------------------------------------- @@ -63,7 +63,7 @@ void test_nrm2( gtint_t n, gtint_t incx, double thresh ) //---------------------------------------------------------- // Compute error. //---------------------------------------------------------- - computediff( norm, norm_ref, thresh ); + computediff( "norm", norm, norm_ref, thresh ); } // Test body used for extreme value testing, where we want to test @@ -97,5 +97,51 @@ void test_nrm2( gtint_t n, gtint_t incx, gtint_t i, T iexval, gtint_t j = 0, T j // Compute error. //---------------------------------------------------------- // Compare using NaN/Inf checks. - computediff( norm, norm_ref, true ); + computediff( "norm", norm, norm_ref, true ); } + +// Test-case logger : Used to print the test-case details based on parameters +class nrm2GenericPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + gtint_t n = std::get<0>(str.param); + gtint_t incx = std::get<1>(str.param); + + std::string str_name = API_PRINT; + str_name += "_n_" + std::to_string(n); + str_name += "_incx_" + testinghelpers::get_value_string(incx); + return str_name; + } +}; + + +// Test-case logger : Used to print the test-case details based on parameters +template +class nrm2EVTPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + // vector length: + gtint_t n = std::get<0>(str.param); + // stride size for x: + gtint_t incx = std::get<1>(str.param); + // index with extreme value iexval. + gtint_t i = std::get<2>(str.param); + T iexval = std::get<3>(str.param); + // index with extreme value jexval. + gtint_t j = std::get<4>(str.param); + T jexval = std::get<5>(str.param); + + std::string str_name = API_PRINT; + str_name += "_n_" + std::to_string(n); + str_name += "_incx_" + testinghelpers::get_value_string(incx); + str_name = str_name + "_i" + std::to_string(i); + std::string iexval_str = testinghelpers::get_value_string(iexval); + str_name = str_name + "_" + iexval_str; + str_name = str_name + "_j" + std::to_string(j); + std::string jexval_str = testinghelpers::get_value_string(jexval); + str_name = str_name + "_" + jexval_str; + return str_name; + } +}; diff --git a/kernels/CMakeLists.txt b/kernels/CMakeLists.txt index fa15654125..89aef2bd15 100644 --- a/kernels/CMakeLists.txt +++ b/kernels/CMakeLists.txt @@ -1,4 +1,36 @@ -##Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.## +#[=[ + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +]=] # Writing a function that will be used to generate the required object # libraries for the required kernels. @@ -9,6 +41,11 @@ function(generate_kernel_targets kernel_target) # Choose correct sub-configurarion name for the given kernel set. get_config_for_kernel_from_kconfig_map(LOCAL_CONFIG ${kernel_target} "${KCONFIG_MAP}") + # filter the lpgemm source files to a different array + set(LOCAL_LPEGMM_SOURCE_FILES ${LOCAL_SOURCE_FILES}) + list(FILTER LOCAL_SOURCE_FILES EXCLUDE REGEX ".*/lpgemm/.*") + list(FILTER LOCAL_LPEGMM_SOURCE_FILES INCLUDE REGEX ".*/lpgemm/.*") + # Only generate the object library if there is at least one source file. list(LENGTH LOCAL_SOURCE_FILES size) if(size GREATER 0) @@ -19,7 +56,8 @@ function(generate_kernel_targets kernel_target) ) # Include the corresponding make_defs.cmake that holds the required compiler options. include(${CMAKE_SOURCE_DIR}/config/${LOCAL_CONFIG}/make_defs.cmake) - # Use PRIVATE keyword for option setting since we do not want the properties to propagate in other targets. + # Use PRIVATE keyword for option setting since we do not want the + # properties to propagate in other targets. # mimicing get-kernel-cflags-for target_compile_options(${kernel_target}_KERNELS PRIVATE @@ -62,15 +100,77 @@ function(generate_kernel_targets kernel_target) # in get-noopt-cflags-for target_compile_options(${kernel_target}_KERNELS PRIVATE ${CTHREADFLAGS}) endif() - if(BUILD_SHARED_LIBS) - # Equivalent to CPICFLAGS in get-noopt-cflags-for - set_target_properties(${kernel_target}_KERNELS PROPERTIES POSITION_INDEPENDENT_CODE ON) - endif() + # Equivalent to CPICFLAGS in get-noopt-cflags-for + set_target_properties(${kernel_target}_KERNELS PROPERTIES POSITION_INDEPENDENT_CODE ON) add_dependencies(${kernel_target}_KERNELS flat-header) # Put all those targets under object-libs-targets folder name so that they appear all together in IDE. set_target_properties(${kernel_target}_KERNELS PROPERTIES FOLDER object-libs-targets) endif() -endfunction() + + # Only generate the object library if there is at least one source file. + list(LENGTH LOCAL_LPEGMM_SOURCE_FILES size_lpgemm) + if (size_lpgemm GREATER 0) + # Create an object library using the source file list above. + add_library(${kernel_target}_LPGEMM_KERNELS + OBJECT + ${LOCAL_LPEGMM_SOURCE_FILES} + ) + # Include the corresponding make_defs.cmake that holds the required compiler options. + include(${CMAKE_SOURCE_DIR}/config/${LOCAL_CONFIG}/make_defs.cmake) + # Use PRIVATE keyword for option setting since we do not want the + # properties to propagate in other targets. + # mimicing get-kernel-cflags-for + target_compile_options(${kernel_target}_LPGEMM_KERNELS + PRIVATE + # load-var-for,CKOPTFLAGS + ${CKOPTFLAGS} + # load-var-for,CKLPOPTFLAGS + ${CKLPOPTFLAGS} + # load-var-for,CKVECFLAGS + ${CKVECFLAGS} + # get-noopt-cflags-for + ${CDBGFLAGS} + # get-noopt-cflags-for + ${CWARNFLAGS} + # get-noopt-cflags-for + ${CMISCFLAGS} + # get-noopt-cflags-for + ${CLANGFLAGS} + # in get-kernel-cflags-for + ${COMPSIMDFLAGS} + # in get-kernel-cflags-for + ${BUILD_SYMFLAGS} + ) + target_compile_definitions(${kernel_target}_LPGEMM_KERNELS + PRIVATE + # in get-noopt-cflags-for + ${CPPROCFLAGS} + # in get-noopt-cflags-for + ${VERS_DEF} + # in get-kernel-cflags-for + ${BUILD_CPPFLAGS} + ) + target_include_directories(${kernel_target}_LPGEMM_KERNELS + BEFORE + PRIVATE + # in get-noopt-cflags-for + ${CINFLAGS} + ) + if(THREADING_MODEL STREQUAL "openmp") + # Equivalent to CTHREADFLAGS in get-noopt-cflags-for + target_link_libraries(${kernel_target}_LPGEMM_KERNELS PRIVATE OpenMP::OpenMP_C) + elseif(THREADING_MODEL STREQUAL "pthreads") + # in get-noopt-cflags-for + target_compile_options(${kernel_target}_LPGEMM_KERNELS PRIVATE ${CTHREADFLAGS}) + endif() + # Equivalent to CPICFLAGS in get-noopt-cflags-for + set_target_properties(${kernel_target}_LPGEMM_KERNELS PROPERTIES POSITION_INDEPENDENT_CODE ON) + add_dependencies(${kernel_target}_LPGEMM_KERNELS flat-header) + # Put all those targets under object-libs-targets folder name so that they appear + # all together in IDE. + set_target_properties(${kernel_target}_LPGEMM_KERNELS PROPERTIES FOLDER object-libs-targets) + endif() + endfunction() # Generate targets for each of the kernels present # in the kernel list. diff --git a/kernels/armsve/1m/bli_dpackm_armsve256_asm_8xk.c b/kernels/armsve/1m/bli_dpackm_armsve256_int_8xk.c similarity index 98% rename from kernels/armsve/1m/bli_dpackm_armsve256_asm_8xk.c rename to kernels/armsve/1m/bli_dpackm_armsve256_int_8xk.c index a9b3d0af8a..85dfaa9c0e 100644 --- a/kernels/armsve/1m/bli_dpackm_armsve256_asm_8xk.c +++ b/kernels/armsve/1m/bli_dpackm_armsve256_int_8xk.c @@ -35,17 +35,14 @@ #include "blis.h" -#ifdef __ARM_FEATURE_SVE +#if (defined(BLIS_FAMILY_ARMSVE) && !defined(BLIS_FAMILY_A64FX)) #include -#else -#error "No Arm SVE intrinsics support in compiler" -#endif // __ARM_FEATURE_SVE // assumption: // SVE vector length = 256 bits. // -void bli_dpackm_armsve256_asm_8xk +void bli_dpackm_armsve256_int_8xk ( conj_t conja, pack_t schema, @@ -230,3 +227,5 @@ void bli_dpackm_armsve256_asm_8xk ); } } + +#endif // __has_include() diff --git a/kernels/armsve/1m/bli_dpackm_armsve512_asm_10xk.c b/kernels/armsve/1m/bli_dpackm_armsve512_asm_10xk.c index 851363a9e0..44718fa578 100644 --- a/kernels/armsve/1m/bli_dpackm_armsve512_asm_10xk.c +++ b/kernels/armsve/1m/bli_dpackm_armsve512_asm_10xk.c @@ -64,12 +64,11 @@ void bli_dpackm_armsve512_asm_10xk const bool unitk = bli_deq1( *kappa ); #ifdef _A64FX - if ( bli_cntx_schema_a_block(cntx) != bli_cntx_schema_b_panel(cntx) ) { - // A twisted way to infer whether A or B is being packed. - if ( schema == bli_cntx_schema_a_block(cntx) ) + // Infer whether A or B is being packed. + if ( schema == BLIS_PACKED_ROWS ) p = ( (uint64_t)0x1 << 56 ) | (uint64_t)p; - if ( schema == bli_cntx_schema_b_panel(cntx) ) + if ( schema == BLIS_PACKED_COLUMNS ) p = ( (uint64_t)0x2 << 56 ) | (uint64_t)p; } #endif diff --git a/kernels/armsve/1m/bli_dpackm_armsve512_asm_16xk.c b/kernels/armsve/1m/bli_dpackm_armsve512_asm_16xk.c index 38fb0b9125..f02b87a7a0 100644 --- a/kernels/armsve/1m/bli_dpackm_armsve512_asm_16xk.c +++ b/kernels/armsve/1m/bli_dpackm_armsve512_asm_16xk.c @@ -63,12 +63,11 @@ void bli_dpackm_armsve512_asm_16xk const bool unitk = bli_deq1( *kappa ); #ifdef _A64FX - if ( bli_cntx_schema_a_block(cntx) != bli_cntx_schema_b_panel(cntx) ) { - // A twisted way to infer whether A or B is being packed. - if ( schema == bli_cntx_schema_a_block(cntx) ) + // Infer whether A or B is being packed. + if ( schema == BLIS_PACKED_ROWS ) p = ( (uint64_t)0x1 << 56 ) | (uint64_t)p; - if ( schema == bli_cntx_schema_b_panel(cntx) ) + if ( schema == BLIS_PACKED_COLUMNS ) p = ( (uint64_t)0x2 << 56 ) | (uint64_t)p; } #endif diff --git a/kernels/armsve/1m/bli_dpackm_armsve512_asm_12xk.c b/kernels/armsve/1m/old/bli_dpackm_armsve512_int_12xk.c similarity index 98% rename from kernels/armsve/1m/bli_dpackm_armsve512_asm_12xk.c rename to kernels/armsve/1m/old/bli_dpackm_armsve512_int_12xk.c index 9f943fcd66..966b0c134f 100644 --- a/kernels/armsve/1m/bli_dpackm_armsve512_asm_12xk.c +++ b/kernels/armsve/1m/old/bli_dpackm_armsve512_int_12xk.c @@ -36,11 +36,8 @@ #include "blis.h" #include -#ifdef __ARM_FEATURE_SVE +#if (defined(BLIS_FAMILY_ARMSVE) && !defined(BLIS_FAMILY_A64FX)) #include -#else -#error "No Arm SVE intrinsics support in compiler" -#endif // __ARM_FEATURE_SVE // assumption: // SVE vector length = 512 bits. @@ -48,7 +45,7 @@ // 2-rows -> 3 vectors packing and use predicator only in odd num of rows to be packed. // prefetching is needed. -void bli_dpackm_armsve512_asm_12xk +void bli_dpackm_armsve512_int_12xk ( conj_t conja, pack_t schema, @@ -357,3 +354,5 @@ void bli_dpackm_armsve512_asm_12xk ); } } + +#endif // __has_include() diff --git a/kernels/armsve/3/armsve_asm_2vx10.h b/kernels/armsve/3/armsve_asm_2vx10.h index 8e37585cba..ae89fa1ece 100644 --- a/kernels/armsve/3/armsve_asm_2vx10.h +++ b/kernels/armsve/3/armsve_asm_2vx10.h @@ -130,6 +130,13 @@ SCALE_COL4(Z12,Z13,Z14,Z15,ZFACTOR) \ SCALE_COL4(Z16,Z17,Z18,Z19,ZFACTOR) +#define GEMM_C_FMLA_UKER(C0FH,C1FH,C2FH,C3FH,C4FH,C0LH,C1LH,C2LH,C3LH,C4LH,PT,Z0FH,Z1FH,Z2FH,Z3FH,Z4FH,Z0LH,Z1LH,Z2LH,Z3LH,Z4LH,ZSCALE) \ + GEMM_FMLA2(C0FH,C0LH,PT,Z0FH,Z0LH,ZSCALE) \ + GEMM_FMLA2(C1FH,C1LH,PT,Z1FH,Z1LH,ZSCALE) \ + GEMM_FMLA2(C2FH,C2LH,PT,Z2FH,Z2LH,ZSCALE) \ + GEMM_FMLA2(C3FH,C3LH,PT,Z3FH,Z3LH,ZSCALE) \ + GEMM_FMLA2(C4FH,C4LH,PT,Z4FH,Z4LH,ZSCALE) + #define GEMM_C_FMAD_UKER(Z0FH,Z1FH,Z2FH,Z3FH,Z4FH,Z0LH,Z1LH,Z2LH,Z3LH,Z4LH,PFH,PLH,C0FH,C1FH,C2FH,C3FH,C4FH,C0LH,C1LH,C2LH,C3LH,C4LH,ZSCALE) \ GEMM_CCOL_FMAD(Z0FH,Z0LH,PFH,PLH,C0FH,C0LH,ZSCALE) \ GEMM_CCOL_FMAD(Z1FH,Z1LH,PFH,PLH,C1FH,C1LH,ZSCALE) \ diff --git a/kernels/armsve/3/armsve_asm_2vx10cmplx.h b/kernels/armsve/3/armsve_asm_2vx10cmplx.h new file mode 100644 index 0000000000..1b67d0d169 --- /dev/null +++ b/kernels/armsve/3/armsve_asm_2vx10cmplx.h @@ -0,0 +1,130 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ +#define GEMM_2VX10CMPLX_MKER_LOOP_PLAIN_C_1(C0Re,C1Re,C2Re,C3Re,C4Re,C5Re,C6Re,C7Re,C8Re,C9Re,C0Im,C1Im,C2Im,C3Im,C4Im,C5Im,C6Im,C7Im,C8Im,C9Im,PT,AColRe,AColIm,BV0,BV1,BV2,BV3,BV4,BV5,BV6,BV7,BAddr,BRSBit) \ + GEMM_FMLA2_LD1R(C0Re,C0Im,PT,AColRe,AColIm,BV0,BAddr,16) \ + GEMM_FMLA2_LD1R(C1Re,C1Im,PT,AColRe,AColIm,BV1,BAddr,18) \ + GEMM_FMLA2_LD1R(C2Re,C2Im,PT,AColRe,AColIm,BV2,BAddr,1) \ + GEMM_FMLA2_LD1R(C3Re,C3Im,PT,AColRe,AColIm,BV3,BAddr,3) \ + GEMM_FMLA2_LD1R(C4Re,C4Im,PT,AColRe,AColIm,BV4,BAddr,5) \ + GEMM_FMLA2_LD1R(C5Re,C5Im,PT,AColRe,AColIm,BV5,BAddr,7) \ + GEMM_FMLA2_LD1R(C6Re,C6Im,PT,AColRe,AColIm,BV6,BAddr,9) \ + GEMM_FMLA2_LD1R(C7Re,C7Im,PT,AColRe,AColIm,BV7,BAddr,11) \ + GEMM_FMLA2_LD1R(C8Re,C8Im,PT,AColRe,AColIm,BV0,BAddr,13) \ + GEMM_FMLA2_LD1R(C9Re,C9Im,PT,AColRe,AColIm,BV1,BAddr,15) \ + \ + GEMM_FMLX2_LD1R(C0Im,C0Re,PT,AColRe,AColIm,BV2,BAddr,17) \ + GEMM_FMLX2_LD1R(C1Im,C1Re,PT,AColRe,AColIm,BV3,BAddr,19) \ +" add "#BAddr", "#BRSBit", "#BAddr" \n\t" /* B address forward */ \ + GEMM_FMLX2_LD1R(C2Im,C2Re,PT,AColRe,AColIm,BV4,BAddr,0) \ + GEMM_FMLX2_LD1R(C3Im,C3Re,PT,AColRe,AColIm,BV5,BAddr,2) \ + GEMM_FMLX2_LD1R(C4Im,C4Re,PT,AColRe,AColIm,BV6,BAddr,4) \ + GEMM_FMLX2_LD1R(C5Im,C5Re,PT,AColRe,AColIm,BV7,BAddr,6) \ + GEMM_FMLX2_LD1R(C6Im,C6Re,PT,AColRe,AColIm,BV0,BAddr,8) \ + GEMM_FMLX2_LD1R(C7Im,C7Re,PT,AColRe,AColIm,BV1,BAddr,10) \ + GEMM_FMLX2_LD1R(C8Im,C8Re,PT,AColRe,AColIm,BV2,BAddr,12) \ + GEMM_FMLX2_LD1R(C9Im,C9Re,PT,AColRe,AColIm,BV3,BAddr,14) + +#define GEMM_2VX10CMPLX_MKER_LOOP_PLAIN_C_2(C0Re,C1Re,C2Re,C3Re,C4Re,C5Re,C6Re,C7Re,C8Re,C9Re,C0Im,C1Im,C2Im,C3Im,C4Im,C5Im,C6Im,C7Im,C8Im,C9Im,PT,AColRe,AColIm,BV0,BV1,BV2,BV3,BV4,BV5,BV6,BV7,BAddr,BRSBit) \ + GEMM_2VX10CMPLX_MKER_LOOP_PLAIN_C_1(C0Re,C1Re,C2Re,C3Re,C4Re,C5Re,C6Re,C7Re,C8Re,C9Re,C0Im,C1Im,C2Im,C3Im,C4Im,C5Im,C6Im,C7Im,C8Im,C9Im,PT,AColRe,AColIm,BV4,BV5,BV6,BV7,BV0,BV1,BV2,BV3,BAddr,BRSBit) + +#define GEMM_2VX10CMPLX_MKER_LOOP_PLAIN_C_1_RESIDUAL(C0Re,C1Re,C2Re,C3Re,C4Re,C5Re,C6Re,C7Re,C8Re,C9Re,C0Im,C1Im,C2Im,C3Im,C4Im,C5Im,C6Im,C7Im,C8Im,C9Im,PT,AColRe,AColIm,BV0,BV1,BV2,BV3,BV4,BV5,BV6,BV7,BAddr,BRSBit) \ + GEMM_FMLA2_LD1R(C0Re,C0Im,PT,AColRe,AColIm,BV0,BAddr,16) \ + GEMM_FMLA2_LD1R(C1Re,C1Im,PT,AColRe,AColIm,BV1,BAddr,18) \ + GEMM_FMLA2_LD1R(C2Re,C2Im,PT,AColRe,AColIm,BV2,BAddr,1) \ + GEMM_FMLA2_LD1R(C3Re,C3Im,PT,AColRe,AColIm,BV3,BAddr,3) \ + GEMM_FMLA2_LD1R(C4Re,C4Im,PT,AColRe,AColIm,BV4,BAddr,5) \ + GEMM_FMLA2_LD1R(C5Re,C5Im,PT,AColRe,AColIm,BV5,BAddr,7) \ + GEMM_FMLA2_LD1R(C6Re,C6Im,PT,AColRe,AColIm,BV6,BAddr,9) \ + GEMM_FMLA2_LD1R(C7Re,C7Im,PT,AColRe,AColIm,BV7,BAddr,11) \ + GEMM_FMLA2_LD1R(C8Re,C8Im,PT,AColRe,AColIm,BV0,BAddr,13) \ + GEMM_FMLA2_LD1R(C9Re,C9Im,PT,AColRe,AColIm,BV1,BAddr,15) \ + \ + GEMM_FMLX2_LD1R(C0Im,C0Re,PT,AColRe,AColIm,BV2,BAddr,17) \ + GEMM_FMLX2_LD1R(C1Im,C1Re,PT,AColRe,AColIm,BV3,BAddr,19) \ +" add "#BAddr", "#BRSBit", "#BAddr" \n\t" /* B address forward */ \ + GEMM_FMLX2(C2Im,C2Re,PT,AColRe,AColIm,BV4) \ + GEMM_FMLX2(C3Im,C3Re,PT,AColRe,AColIm,BV5) \ + GEMM_FMLX2(C4Im,C4Re,PT,AColRe,AColIm,BV6) \ + GEMM_FMLX2(C5Im,C5Re,PT,AColRe,AColIm,BV7) \ + GEMM_FMLX2(C6Im,C6Re,PT,AColRe,AColIm,BV0) \ + GEMM_FMLX2(C7Im,C7Re,PT,AColRe,AColIm,BV1) \ + GEMM_FMLX2(C8Im,C8Re,PT,AColRe,AColIm,BV2) \ + GEMM_FMLX2(C9Im,C9Re,PT,AColRe,AColIm,BV3) + +#define GEMM_2VX10CMPLX_MKER_LOOP_PLAIN_C_2_RESIDUAL(C0Re,C1Re,C2Re,C3Re,C4Re,C5Re,C6Re,C7Re,C8Re,C9Re,C0Im,C1Im,C2Im,C3Im,C4Im,C5Im,C6Im,C7Im,C8Im,C9Im,PT,AColRe,AColIm,BV0,BV1,BV2,BV3,BV4,BV5,BV6,BV7,BAddr,BRSBit) \ + GEMM_2VX10CMPLX_MKER_LOOP_PLAIN_C_1_RESIDUAL(C0Re,C1Re,C2Re,C3Re,C4Re,C5Re,C6Re,C7Re,C8Re,C9Re,C0Im,C1Im,C2Im,C3Im,C4Im,C5Im,C6Im,C7Im,C8Im,C9Im,PT,AColRe,AColIm,BV4,BV5,BV6,BV7,BV0,BV1,BV2,BV3,BAddr,BRSBit) + +#define CLEAR_COL20(Z00,Z01,Z02,Z03,Z04,Z05,Z06,Z07,Z08,Z09,Z10,Z11,Z12,Z13,Z14,Z15,Z16,Z17,Z18,Z19) \ + CLEAR_COL4(Z00,Z01,Z02,Z03) \ + CLEAR_COL4(Z04,Z05,Z06,Z07) \ + CLEAR_COL4(Z08,Z09,Z10,Z11) \ + CLEAR_COL4(Z12,Z13,Z14,Z15) \ + CLEAR_COL4(Z16,Z17,Z18,Z19) + +// Moving is always .d. +// Never use .DT here! +#define MOV_COL2(ZD0Re,ZD0Im,ZD1Re,ZD1Im,Z0Re,Z0Im,Z1Re,Z1Im) \ +" mov "#ZD0Re".d, "#Z0Re".d \n\t" \ +" mov "#ZD0Im".d, "#Z0Im".d \n\t" \ +" mov "#ZD1Re".d, "#Z1Re".d \n\t" \ +" mov "#ZD1Im".d, "#Z1Im".d \n\t" + +#define GEMM_FMULCMPLX_COL2(ZD0Re,ZD0Im,ZD1Re,ZD1Im,PT,Z0Re,Z0Im,Z1Re,Z1Im,ZFactorRe,ZFactorIm) \ + FMUL_COL2(ZD0Re,ZD0Im,Z0Re,Z0Im,ZFactorRe) \ + FMUL_COL2(ZD1Re,ZD1Im,Z1Re,Z1Im,ZFactorRe) \ + GEMM_FMLX2(ZD0Im,ZD0Re,PT,Z0Re,Z0Im,ZFactorIm) \ + GEMM_FMLX2(ZD1Im,ZD1Re,PT,Z1Re,Z1Im,ZFactorIm) + +#define GEMM_FMLACMPLX_COL2(ZD0Re,ZD0Im,ZD1Re,ZD1Im,PT,Z0Re,Z0Im,Z1Re,Z1Im,ZFactorRe,ZFactorIm) \ + GEMM_FMLACMPLX(ZD0Re,ZD0Im,PT,Z0Re,Z0Im,ZFactorRe,ZFactorIm) \ + GEMM_FMLACMPLX(ZD1Re,ZD1Im,PT,Z1Re,Z1Im,ZFactorRe,ZFactorIm) + +#define GEMM_CCMPLX_LOAD_COL2_C(Z0Re,Z0Im,Z1Re,Z1Im,PT,CAddr,CCS) \ + GEMM_CCOLCMPLX_CONTIGUOUS_LOAD_FWD(Z0Re,Z0Im,PT,CAddr,CCS) \ + GEMM_CCOLCMPLX_CONTIGUOUS_LOAD_FWD(Z1Re,Z1Im,PT,CAddr,CCS) + +#define GEMM_CCMPLX_STORE_COL2_C(Z0Re,Z0Im,Z1Re,Z1Im,PT,CAddr,CCS) \ + GEMM_CCOLCMPLX_CONTIGUOUS_STORE_FWD(Z0Re,Z0Im,PT,CAddr,CCS) \ + GEMM_CCOLCMPLX_CONTIGUOUS_STORE_FWD(Z1Re,Z1Im,PT,CAddr,CCS) + +#define GEMM_CCMPLX_LOAD_COL2_G(Z0Re,Z0Im,Z1Re,Z1Im,PT,ZIndex,CAddr,CCS,CTemp) \ + GEMM_CCOLCMPLX_GATHER_LOAD_FWD(Z0Re,Z0Im,ZIndex,PT,PT,CAddr,CCS,CTemp) \ + GEMM_CCOLCMPLX_GATHER_LOAD_FWD(Z1Re,Z1Im,ZIndex,PT,PT,CAddr,CCS,CTemp) + +#define GEMM_CCMPLX_STORE_COL2_G(Z0Re,Z0Im,Z1Re,Z1Im,PT,ZIndex,CAddr,CCS,CTemp) \ + GEMM_CCOLCMPLX_SCATTER_STORE_FWD(Z0Re,Z0Im,ZIndex,PT,PT,CAddr,CCS,CTemp) \ + GEMM_CCOLCMPLX_SCATTER_STORE_FWD(Z1Re,Z1Im,ZIndex,PT,PT,CAddr,CCS,CTemp) + diff --git a/kernels/armsve/3/armsve_asm_2vx7cmplx.h b/kernels/armsve/3/armsve_asm_2vx7cmplx.h new file mode 100644 index 0000000000..43997deef4 --- /dev/null +++ b/kernels/armsve/3/armsve_asm_2vx7cmplx.h @@ -0,0 +1,135 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ +#define GEMM_2VX7CMPLX_MKER_LOOP_PLAIN_C(C0Re,C1Re,C2Re,C3Re,C4Re,C5Re,C6Re,C0Im,C1Im,C2Im,C3Im,C4Im,C5Im,C6Im,PT,AColRe,AColIm,B0Re,B1Re,B2Re,B3Re,B4Re,B5Re,B6Re,B0Im,B1Im,B2Im,B3Im,B4Im,B5Im,B6Im,BAddr,BRSBit) \ + GEMM_FMLA2_LD1R(C0Re,C0Im,PT,AColRe,AColIm,B0Re,BAddr,0) \ + GEMM_FMLA2_LD1R(C1Re,C1Im,PT,AColRe,AColIm,B1Re,BAddr,2) \ + GEMM_FMLA2_LD1R(C2Re,C2Im,PT,AColRe,AColIm,B2Re,BAddr,4) \ + GEMM_FMLA2_LD1R(C3Re,C3Im,PT,AColRe,AColIm,B3Re,BAddr,6) \ + GEMM_FMLA2_LD1R(C4Re,C4Im,PT,AColRe,AColIm,B4Re,BAddr,8) \ + GEMM_FMLA2_LD1R(C5Re,C5Im,PT,AColRe,AColIm,B5Re,BAddr,10) \ + GEMM_FMLA2_LD1R(C6Re,C6Im,PT,AColRe,AColIm,B6Re,BAddr,12) \ + GEMM_FMLX2_LD1R(C0Im,C0Re,PT,AColRe,AColIm,B0Im,BAddr,1) \ + GEMM_FMLX2_LD1R(C1Im,C1Re,PT,AColRe,AColIm,B1Im,BAddr,3) \ + GEMM_FMLX2_LD1R(C2Im,C2Re,PT,AColRe,AColIm,B2Im,BAddr,5) \ + GEMM_FMLX2_LD1R(C3Im,C3Re,PT,AColRe,AColIm,B3Im,BAddr,7) \ + GEMM_FMLX2_LD1R(C4Im,C4Re,PT,AColRe,AColIm,B4Im,BAddr,9) \ + GEMM_FMLX2_LD1R(C5Im,C5Re,PT,AColRe,AColIm,B5Im,BAddr,11) \ + GEMM_FMLX2_LD1R(C6Im,C6Re,PT,AColRe,AColIm,B6Im,BAddr,13) \ +" add "#BAddr", "#BRSBit", "#BAddr" \n\t" + +#define GEMM_2VX7CMPLX_MKER_LOOP_PLAIN_C_RESIDUAL(C0Re,C1Re,C2Re,C3Re,C4Re,C5Re,C6Re,C0Im,C1Im,C2Im,C3Im,C4Im,C5Im,C6Im,PT,AColRe,AColIm,B0Re,B1Re,B2Re,B3Re,B4Re,B5Re,B6Re,B0Im,B1Im,B2Im,B3Im,B4Im,B5Im,B6Im,BAddr,BRSBit) \ + GEMM_FMLA2(C0Re,C0Im,PT,AColRe,AColIm,B0Re) \ + GEMM_FMLA2(C1Re,C1Im,PT,AColRe,AColIm,B1Re) \ + GEMM_FMLA2(C2Re,C2Im,PT,AColRe,AColIm,B2Re) \ + GEMM_FMLA2(C3Re,C3Im,PT,AColRe,AColIm,B3Re) \ + GEMM_FMLA2(C4Re,C4Im,PT,AColRe,AColIm,B4Re) \ + GEMM_FMLA2(C5Re,C5Im,PT,AColRe,AColIm,B5Re) \ + GEMM_FMLA2(C6Re,C6Im,PT,AColRe,AColIm,B6Re) \ + GEMM_FMLX2(C0Im,C0Re,PT,AColRe,AColIm,B0Im) \ + GEMM_FMLX2(C1Im,C1Re,PT,AColRe,AColIm,B1Im) \ + GEMM_FMLX2(C2Im,C2Re,PT,AColRe,AColIm,B2Im) \ + GEMM_FMLX2(C3Im,C3Re,PT,AColRe,AColIm,B3Im) \ + GEMM_FMLX2(C4Im,C4Re,PT,AColRe,AColIm,B4Im) \ + GEMM_FMLX2(C5Im,C5Re,PT,AColRe,AColIm,B5Im) \ + GEMM_FMLX2(C6Im,C6Re,PT,AColRe,AColIm,B6Im) + +#define CLEAR_COL14(Z00,Z01,Z02,Z03,Z04,Z05,Z06,Z07,Z08,Z09,Z10,Z11,Z12,Z13) \ + CLEAR_COL4(Z00,Z01,Z02,Z03) \ + CLEAR_COL4(Z04,Z05,Z06,Z07) \ + CLEAR_COL4(Z08,Z09,Z10,Z11) \ + CLEAR_COL2(Z12,Z13) + +#define GEMM_FMULCMPLX_COL7(ZD0Re,ZD0Im,ZD1Re,ZD1Im,ZD2Re,ZD2Im,ZD3Re,ZD3Im,ZD4Re,ZD4Im,ZD5Re,ZD5Im,ZD6Re,ZD6Im,PT,Z0Re,Z0Im,Z1Re,Z1Im,Z2Re,Z2Im,Z3Re,Z3Im,Z4Re,Z4Im,Z5Re,Z5Im,Z6Re,Z6Im,ZFactorRe,ZFactorIm) \ + FMUL_COL2(ZD0Re,ZD0Im,Z0Re,Z0Im,ZFactorRe) \ + FMUL_COL2(ZD1Re,ZD1Im,Z1Re,Z1Im,ZFactorRe) \ + FMUL_COL2(ZD2Re,ZD2Im,Z2Re,Z2Im,ZFactorRe) \ + FMUL_COL2(ZD3Re,ZD3Im,Z3Re,Z3Im,ZFactorRe) \ + FMUL_COL2(ZD4Re,ZD4Im,Z4Re,Z4Im,ZFactorRe) \ + FMUL_COL2(ZD5Re,ZD5Im,Z5Re,Z5Im,ZFactorRe) \ + FMUL_COL2(ZD6Re,ZD6Im,Z6Re,Z6Im,ZFactorRe) \ + GEMM_FMLX2(ZD0Im,ZD0Re,PT,Z0Re,Z0Im,ZFactorIm) \ + GEMM_FMLX2(ZD1Im,ZD1Re,PT,Z1Re,Z1Im,ZFactorIm) \ + GEMM_FMLX2(ZD2Im,ZD2Re,PT,Z2Re,Z2Im,ZFactorIm) \ + GEMM_FMLX2(ZD3Im,ZD3Re,PT,Z3Re,Z3Im,ZFactorIm) \ + GEMM_FMLX2(ZD4Im,ZD4Re,PT,Z4Re,Z4Im,ZFactorIm) \ + GEMM_FMLX2(ZD5Im,ZD5Re,PT,Z5Re,Z5Im,ZFactorIm) \ + GEMM_FMLX2(ZD6Im,ZD6Re,PT,Z6Re,Z6Im,ZFactorIm) + +#define GEMM_FMLACMPLX_COL7(ZD0Re,ZD0Im,ZD1Re,ZD1Im,ZD2Re,ZD2Im,ZD3Re,ZD3Im,ZD4Re,ZD4Im,ZD5Re,ZD5Im,ZD6Re,ZD6Im,PT,Z0Re,Z0Im,Z1Re,Z1Im,Z2Re,Z2Im,Z3Re,Z3Im,Z4Re,Z4Im,Z5Re,Z5Im,Z6Re,Z6Im,ZFactorRe,ZFactorIm) \ + GEMM_FMLACMPLX(ZD0Re,ZD0Im,PT,Z0Re,Z0Im,ZFactorRe,ZFactorIm) \ + GEMM_FMLACMPLX(ZD1Re,ZD1Im,PT,Z1Re,Z1Im,ZFactorRe,ZFactorIm) \ + GEMM_FMLACMPLX(ZD2Re,ZD2Im,PT,Z2Re,Z2Im,ZFactorRe,ZFactorIm) \ + GEMM_FMLACMPLX(ZD3Re,ZD3Im,PT,Z3Re,Z3Im,ZFactorRe,ZFactorIm) \ + GEMM_FMLACMPLX(ZD4Re,ZD4Im,PT,Z4Re,Z4Im,ZFactorRe,ZFactorIm) \ + GEMM_FMLACMPLX(ZD5Re,ZD5Im,PT,Z5Re,Z5Im,ZFactorRe,ZFactorIm) \ + GEMM_FMLACMPLX(ZD6Re,ZD6Im,PT,Z6Re,Z6Im,ZFactorRe,ZFactorIm) + +#define GEMM_CCMPLX_LOAD_COL7_C(Z0Re,Z0Im,Z1Re,Z1Im,Z2Re,Z2Im,Z3Re,Z3Im,Z4Re,Z4Im,Z5Re,Z5Im,Z6Re,Z6Im,PT,CAddr,CCS) \ + GEMM_CCOLCMPLX_CONTIGUOUS_LOAD_FWD(Z0Re,Z0Im,PT,CAddr,CCS) \ + GEMM_CCOLCMPLX_CONTIGUOUS_LOAD_FWD(Z1Re,Z1Im,PT,CAddr,CCS) \ + GEMM_CCOLCMPLX_CONTIGUOUS_LOAD_FWD(Z2Re,Z2Im,PT,CAddr,CCS) \ + GEMM_CCOLCMPLX_CONTIGUOUS_LOAD_FWD(Z3Re,Z3Im,PT,CAddr,CCS) \ + GEMM_CCOLCMPLX_CONTIGUOUS_LOAD_FWD(Z4Re,Z4Im,PT,CAddr,CCS) \ + GEMM_CCOLCMPLX_CONTIGUOUS_LOAD_FWD(Z5Re,Z5Im,PT,CAddr,CCS) \ + GEMM_CCOLCMPLX_CONTIGUOUS_LOAD_FWD(Z6Re,Z6Im,PT,CAddr,CCS) + +#define GEMM_CCMPLX_STORE_COL7_C(Z0Re,Z0Im,Z1Re,Z1Im,Z2Re,Z2Im,Z3Re,Z3Im,Z4Re,Z4Im,Z5Re,Z5Im,Z6Re,Z6Im,PT,CAddr,CCS) \ + GEMM_CCOLCMPLX_CONTIGUOUS_STORE_FWD(Z0Re,Z0Im,PT,CAddr,CCS) \ + GEMM_CCOLCMPLX_CONTIGUOUS_STORE_FWD(Z1Re,Z1Im,PT,CAddr,CCS) \ + GEMM_CCOLCMPLX_CONTIGUOUS_STORE_FWD(Z2Re,Z2Im,PT,CAddr,CCS) \ + GEMM_CCOLCMPLX_CONTIGUOUS_STORE_FWD(Z3Re,Z3Im,PT,CAddr,CCS) \ + GEMM_CCOLCMPLX_CONTIGUOUS_STORE_FWD(Z4Re,Z4Im,PT,CAddr,CCS) \ + GEMM_CCOLCMPLX_CONTIGUOUS_STORE_FWD(Z5Re,Z5Im,PT,CAddr,CCS) \ + GEMM_CCOLCMPLX_CONTIGUOUS_STORE_FWD(Z6Re,Z6Im,PT,CAddr,CCS) + +#define GEMM_CCMPLX_LOAD_COL7_G(Z0Re,Z0Im,Z1Re,Z1Im,Z2Re,Z2Im,Z3Re,Z3Im,Z4Re,Z4Im,Z5Re,Z5Im,Z6Re,Z6Im,PT,ZIndex,CAddr,CCS,CTemp) \ + GEMM_CCOLCMPLX_GATHER_LOAD_FWD(Z0Re,Z0Im,ZIndex,PT,PT,CAddr,CCS,CTemp) \ + GEMM_CCOLCMPLX_GATHER_LOAD_FWD(Z1Re,Z1Im,ZIndex,PT,PT,CAddr,CCS,CTemp) \ + GEMM_CCOLCMPLX_GATHER_LOAD_FWD(Z2Re,Z2Im,ZIndex,PT,PT,CAddr,CCS,CTemp) \ + GEMM_CCOLCMPLX_GATHER_LOAD_FWD(Z3Re,Z3Im,ZIndex,PT,PT,CAddr,CCS,CTemp) \ + GEMM_CCOLCMPLX_GATHER_LOAD_FWD(Z4Re,Z4Im,ZIndex,PT,PT,CAddr,CCS,CTemp) \ + GEMM_CCOLCMPLX_GATHER_LOAD_FWD(Z5Re,Z5Im,ZIndex,PT,PT,CAddr,CCS,CTemp) \ + GEMM_CCOLCMPLX_GATHER_LOAD_FWD(Z6Re,Z6Im,ZIndex,PT,PT,CAddr,CCS,CTemp) + +#define GEMM_CCMPLX_STORE_COL7_G(Z0Re,Z0Im,Z1Re,Z1Im,Z2Re,Z2Im,Z3Re,Z3Im,Z4Re,Z4Im,Z5Re,Z5Im,Z6Re,Z6Im,PT,ZIndex,CAddr,CCS,CTemp) \ + GEMM_CCOLCMPLX_SCATTER_STORE_FWD(Z0Re,Z0Im,ZIndex,PT,PT,CAddr,CCS,CTemp) \ + GEMM_CCOLCMPLX_SCATTER_STORE_FWD(Z1Re,Z1Im,ZIndex,PT,PT,CAddr,CCS,CTemp) \ + GEMM_CCOLCMPLX_SCATTER_STORE_FWD(Z2Re,Z2Im,ZIndex,PT,PT,CAddr,CCS,CTemp) \ + GEMM_CCOLCMPLX_SCATTER_STORE_FWD(Z3Re,Z3Im,ZIndex,PT,PT,CAddr,CCS,CTemp) \ + GEMM_CCOLCMPLX_SCATTER_STORE_FWD(Z4Re,Z4Im,ZIndex,PT,PT,CAddr,CCS,CTemp) \ + GEMM_CCOLCMPLX_SCATTER_STORE_FWD(Z5Re,Z5Im,ZIndex,PT,PT,CAddr,CCS,CTemp) \ + GEMM_CCOLCMPLX_SCATTER_STORE_FWD(Z6Re,Z6Im,ZIndex,PT,PT,CAddr,CCS,CTemp) + diff --git a/kernels/armsve/3/armsve_asm_2vx8cmplx.h b/kernels/armsve/3/armsve_asm_2vx8cmplx.h new file mode 100644 index 0000000000..16711930a4 --- /dev/null +++ b/kernels/armsve/3/armsve_asm_2vx8cmplx.h @@ -0,0 +1,116 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ +#define GEMM_2VX8CMPLX_MKER_LOOP_PLAIN_C_1(C0Re,C1Re,C2Re,C3Re,C4Re,C5Re,C6Re,C7Re,C0Im,C1Im,C2Im,C3Im,C4Im,C5Im,C6Im,C7Im,PT,AColRe,AColIm,BV0,BV1,BV2,BV3,BV4,BV5,BV6,BV7,BV8,BV9,BV10,BV11,BAddr,BRSBit) \ + GEMM_FMLA2_LD1R(C0Re,C0Im,PT,AColRe,AColIm,BV0,BAddr,9) \ + GEMM_FMLA2_LD1R(C1Re,C1Im,PT,AColRe,AColIm,BV1,BAddr,11) \ + GEMM_FMLA2_LD1R(C2Re,C2Im,PT,AColRe,AColIm,BV2,BAddr,13) \ + GEMM_FMLA2_LD1R(C3Re,C3Im,PT,AColRe,AColIm,BV3,BAddr,15) \ +" add "#BAddr", "#BRSBit", "#BAddr" \n\t" /* B address forward */ \ + GEMM_FMLA2_LD1R(C4Re,C4Im,PT,AColRe,AColIm,BV4,BAddr,0) \ + GEMM_FMLA2_LD1R(C5Re,C5Im,PT,AColRe,AColIm,BV5,BAddr,2) \ + GEMM_FMLA2_LD1R(C6Re,C6Im,PT,AColRe,AColIm,BV6,BAddr,4) \ + GEMM_FMLA2_LD1R(C7Re,C7Im,PT,AColRe,AColIm,BV7,BAddr,6) \ + \ + GEMM_FMLX2_LD1R(C0Im,C0Re,PT,AColRe,AColIm,BV8,BAddr,8) \ + GEMM_FMLX2_LD1R(C1Im,C1Re,PT,AColRe,AColIm,BV9,BAddr,10) \ + GEMM_FMLX2_LD1R(C2Im,C2Re,PT,AColRe,AColIm,BV10,BAddr,12) \ + GEMM_FMLX2_LD1R(C3Im,C3Re,PT,AColRe,AColIm,BV11,BAddr,14) \ + GEMM_FMLX2_LD1R(C4Im,C4Re,PT,AColRe,AColIm,BV0,BAddr,1) \ + GEMM_FMLX2_LD1R(C5Im,C5Re,PT,AColRe,AColIm,BV1,BAddr,3) \ + GEMM_FMLX2_LD1R(C6Im,C6Re,PT,AColRe,AColIm,BV2,BAddr,5) \ + GEMM_FMLX2_LD1R(C7Im,C7Re,PT,AColRe,AColIm,BV3,BAddr,7) + +#define GEMM_2VX8CMPLX_MKER_LOOP_PLAIN_C_2(C0Re,C1Re,C2Re,C3Re,C4Re,C5Re,C6Re,C7Re,C0Im,C1Im,C2Im,C3Im,C4Im,C5Im,C6Im,C7Im,PT,AColRe,AColIm,BV0,BV1,BV2,BV3,BV4,BV5,BV6,BV7,BV8,BV9,BV10,BV11,BAddr,BRSBit) \ + GEMM_2VX8CMPLX_MKER_LOOP_PLAIN_C_1(C0Re,C1Re,C2Re,C3Re,C4Re,C5Re,C6Re,C7Re,C0Im,C1Im,C2Im,C3Im,C4Im,C5Im,C6Im,C7Im,PT,AColRe,AColIm,BV4,BV5,BV6,BV7,BV8,BV9,BV10,BV11,BV0,BV1,BV2,BV3,BAddr,BRSBit) + +#define GEMM_2VX8CMPLX_MKER_LOOP_PLAIN_C_3(C0Re,C1Re,C2Re,C3Re,C4Re,C5Re,C6Re,C7Re,C0Im,C1Im,C2Im,C3Im,C4Im,C5Im,C6Im,C7Im,PT,AColRe,AColIm,BV0,BV1,BV2,BV3,BV4,BV5,BV6,BV7,BV8,BV9,BV10,BV11,BAddr,BRSBit) \ + GEMM_2VX8CMPLX_MKER_LOOP_PLAIN_C_1(C0Re,C1Re,C2Re,C3Re,C4Re,C5Re,C6Re,C7Re,C0Im,C1Im,C2Im,C3Im,C4Im,C5Im,C6Im,C7Im,PT,AColRe,AColIm,BV8,BV9,BV10,BV11,BV0,BV1,BV2,BV3,BV4,BV5,BV6,BV7,BAddr,BRSBit) + +#define GEMM_2VX8CMPLX_MKER_LOOP_PLAIN_C_1_RESIDUAL(C0Re,C1Re,C2Re,C3Re,C4Re,C5Re,C6Re,C7Re,C0Im,C1Im,C2Im,C3Im,C4Im,C5Im,C6Im,C7Im,PT,AColRe,AColIm,BV0,BV1,BV2,BV3,BV4,BV5,BV6,BV7,BV8,BV9,BV10,BV11,BAddr,BRSBit) \ + GEMM_FMLA2_LD1R(C0Re,C0Im,PT,AColRe,AColIm,BV0,BAddr,9) \ + GEMM_FMLA2_LD1R(C1Re,C1Im,PT,AColRe,AColIm,BV1,BAddr,11) \ + GEMM_FMLA2_LD1R(C2Re,C2Im,PT,AColRe,AColIm,BV2,BAddr,13) \ + GEMM_FMLA2_LD1R(C3Re,C3Im,PT,AColRe,AColIm,BV3,BAddr,15) \ +" add "#BAddr", "#BRSBit", "#BAddr" \n\t" /* B address forward */ \ + GEMM_FMLA2(C4Re,C4Im,PT,AColRe,AColIm,BV4) \ + GEMM_FMLA2(C5Re,C5Im,PT,AColRe,AColIm,BV5) \ + GEMM_FMLA2(C6Re,C6Im,PT,AColRe,AColIm,BV6) \ + GEMM_FMLA2(C7Re,C7Im,PT,AColRe,AColIm,BV7) \ + \ + GEMM_FMLX2(C0Im,C0Re,PT,AColRe,AColIm,BV8) \ + GEMM_FMLX2(C1Im,C1Re,PT,AColRe,AColIm,BV9) \ + GEMM_FMLX2(C2Im,C2Re,PT,AColRe,AColIm,BV10) \ + GEMM_FMLX2(C3Im,C3Re,PT,AColRe,AColIm,BV11) \ + GEMM_FMLX2(C4Im,C4Re,PT,AColRe,AColIm,BV0) \ + GEMM_FMLX2(C5Im,C5Re,PT,AColRe,AColIm,BV1) \ + GEMM_FMLX2(C6Im,C6Re,PT,AColRe,AColIm,BV2) \ + GEMM_FMLX2(C7Im,C7Re,PT,AColRe,AColIm,BV3) + +#define GEMM_2VX8CMPLX_MKER_LOOP_PLAIN_C_3_RESIDUAL(C0Re,C1Re,C2Re,C3Re,C4Re,C5Re,C6Re,C7Re,C0Im,C1Im,C2Im,C3Im,C4Im,C5Im,C6Im,C7Im,PT,AColRe,AColIm,BV0,BV1,BV2,BV3,BV4,BV5,BV6,BV7,BV8,BV9,BV10,BV11,BAddr,BRSBit) \ + GEMM_2VX8CMPLX_MKER_LOOP_PLAIN_C_1_RESIDUAL(C0Re,C1Re,C2Re,C3Re,C4Re,C5Re,C6Re,C7Re,C0Im,C1Im,C2Im,C3Im,C4Im,C5Im,C6Im,C7Im,PT,AColRe,AColIm,BV8,BV9,BV10,BV11,BV0,BV1,BV2,BV3,BV4,BV5,BV6,BV7,BAddr,BRSBit) + +#define CLEAR_COL16(Z00,Z01,Z02,Z03,Z04,Z05,Z06,Z07,Z08,Z09,Z10,Z11,Z12,Z13,Z14,Z15) \ + CLEAR_COL4(Z00,Z01,Z02,Z03) \ + CLEAR_COL4(Z04,Z05,Z06,Z07) \ + CLEAR_COL4(Z08,Z09,Z10,Z11) \ + CLEAR_COL4(Z12,Z13,Z14,Z15) + +#define GEMM_FMULCMPLX_COL2(ZD0Re,ZD0Im,ZD1Re,ZD1Im,PT,Z0Re,Z0Im,Z1Re,Z1Im,ZFactorRe,ZFactorIm) \ + FMUL_COL2(ZD0Re,ZD0Im,Z0Re,Z0Im,ZFactorRe) \ + FMUL_COL2(ZD1Re,ZD1Im,Z1Re,Z1Im,ZFactorRe) \ + GEMM_FMLX2(ZD0Im,ZD0Re,PT,Z0Re,Z0Im,ZFactorIm) \ + GEMM_FMLX2(ZD1Im,ZD1Re,PT,Z1Re,Z1Im,ZFactorIm) + +#define GEMM_FMLACMPLX_COL2(ZD0Re,ZD0Im,ZD1Re,ZD1Im,PT,Z0Re,Z0Im,Z1Re,Z1Im,ZFactorRe,ZFactorIm) \ + GEMM_FMLACMPLX(ZD0Re,ZD0Im,PT,Z0Re,Z0Im,ZFactorRe,ZFactorIm) \ + GEMM_FMLACMPLX(ZD1Re,ZD1Im,PT,Z1Re,Z1Im,ZFactorRe,ZFactorIm) + +#define GEMM_CCMPLX_LOAD_COL2_C(Z0Re,Z0Im,Z1Re,Z1Im,PT,CAddr,CCS) \ + GEMM_CCOLCMPLX_CONTIGUOUS_LOAD_FWD(Z0Re,Z0Im,PT,CAddr,CCS) \ + GEMM_CCOLCMPLX_CONTIGUOUS_LOAD_FWD(Z1Re,Z1Im,PT,CAddr,CCS) + +#define GEMM_CCMPLX_STORE_COL2_C(Z0Re,Z0Im,Z1Re,Z1Im,PT,CAddr,CCS) \ + GEMM_CCOLCMPLX_CONTIGUOUS_STORE_FWD(Z0Re,Z0Im,PT,CAddr,CCS) \ + GEMM_CCOLCMPLX_CONTIGUOUS_STORE_FWD(Z1Re,Z1Im,PT,CAddr,CCS) + +#define GEMM_CCMPLX_LOAD_COL2_G(Z0Re,Z0Im,Z1Re,Z1Im,PT,ZIndex,CAddr,CCS,CTemp) \ + GEMM_CCOLCMPLX_GATHER_LOAD_FWD(Z0Re,Z0Im,ZIndex,PT,PT,CAddr,CCS,CTemp) \ + GEMM_CCOLCMPLX_GATHER_LOAD_FWD(Z1Re,Z1Im,ZIndex,PT,PT,CAddr,CCS,CTemp) + +#define GEMM_CCMPLX_STORE_COL2_G(Z0Re,Z0Im,Z1Re,Z1Im,PT,ZIndex,CAddr,CCS,CTemp) \ + GEMM_CCOLCMPLX_SCATTER_STORE_FWD(Z0Re,Z0Im,ZIndex,PT,PT,CAddr,CCS,CTemp) \ + GEMM_CCOLCMPLX_SCATTER_STORE_FWD(Z1Re,Z1Im,ZIndex,PT,PT,CAddr,CCS,CTemp) + diff --git a/kernels/armsve/3/armsve_asm_macros_cmplx.h b/kernels/armsve/3/armsve_asm_macros_cmplx.h new file mode 100644 index 0000000000..10097700c8 --- /dev/null +++ b/kernels/armsve/3/armsve_asm_macros_cmplx.h @@ -0,0 +1,89 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ +#include "armsve_asm_macros.h" + +#define FMUL_COL2(ZD0,ZD1,Z0,Z1,ZFACTOR) \ +" fmul "#ZD0"."DT", "#Z0"."DT", "#ZFACTOR"."DT" \n\t" \ +" fmul "#ZD1"."DT", "#Z1"."DT", "#ZFACTOR"."DT" \n\t" \ + +#define GEMM_FMLX2(CCOLFH,CCOLLH,PT,ACOLFH,ACOLLH,BV) \ +" fmla "#CCOLFH"."DT", "#PT"/m, "#ACOLFH"."DT", "#BV"."DT" \n\t" \ +" fmls "#CCOLLH"."DT", "#PT"/m, "#ACOLLH"."DT", "#BV"."DT" \n\t" + +#define GEMM_FMLX2_LD1R(CCOLFH,CCOLLH,PT,ACOLFH,ACOLLH,BV,BADDR,NSHIFT) \ + GEMM_FMLX2(CCOLFH,CCOLLH,PT,ACOLFH,ACOLLH,BV) \ +" "LD1R" "#BV"."DT", "#PT"/z, ["#BADDR", #"#NSHIFT"*"SZ"]\n\t" + +#define GEMM_FMULCMPLX(ZDRe,ZDIm,PT,Z0Re,Z0Im,Z1Re,Z1Im) \ + FMUL_COL2(ZDRe,ZDIm,Z0Re,Z0Im,Z1Re) \ + GEMM_FMLX2(ZDIm,ZDRe,PT,Z0Re,Z0Im,Z1Im) + +#define GEMM_FMLACMPLX(ZDRe,ZDIm,PT,Z0Re,Z0Im,Z1Re,Z1Im) \ + GEMM_FMLA2(ZDRe,ZDIm,PT,Z0Re,Z0Im,Z1Re) \ + GEMM_FMLX2(ZDIm,ZDRe,PT,Z0Re,Z0Im,Z1Im) + +#define GEMM_ACOLCMPLX_CONTIGUOUS_LOAD(ZRe,ZIm,PT,AAddr) \ +" "LD2" {"#ZRe"."DT", "#ZIm"."DT"}, "#PT"/z, ["#AAddr"] \n\t" + +#define GEMM_ACOLCMPLX_CONTIGUOUS_STORE(ZRe,ZIm,PT,AAddr) \ +" "ST2" {"#ZRe"."DT", "#ZIm"."DT"}, "#PT", ["#AAddr"] \n\t" + +#define GEMM_ACOLCMPLX_CONTIGUOUS_LOAD_FWD(ZRe,ZIm,PT,AAddr,ACS) \ + GEMM_ACOLCMPLX_CONTIGUOUS_LOAD(ZRe,ZIm,PT,AAddr) \ +" add "#AAddr", "#AAddr", "#ACS" \n\t" /* Forward A address (load) to next column. */ + +#define GEMM_CCOLCMPLX_CONTIGUOUS_LOAD_FWD(ZRe,ZIm,PT,CAddr,CCS) \ + GEMM_ACOLCMPLX_CONTIGUOUS_LOAD_FWD(ZRe,ZIm,PT,CAddr,CCS) + +#define GEMM_ACOLCMPLX_CONTIGUOUS_STORE_FWD(ZRe,ZIm,PT,AAddr,ACS) \ + GEMM_ACOLCMPLX_CONTIGUOUS_STORE(ZRe,ZIm,PT,AAddr) \ +" add "#AAddr", "#AAddr", "#ACS" \n\t" /* Forward A address (load) to next column. */ + +#define GEMM_CCOLCMPLX_CONTIGUOUS_STORE_FWD(ZRe,ZIm,PT,CAddr,CCS) \ + GEMM_ACOLCMPLX_CONTIGUOUS_STORE_FWD(ZRe,ZIm,PT,CAddr,CCS) + +#define GEMM_CCOLCMPLX_GATHER_LOAD_FWD(ZRe,ZIm,ZIndex,PRe,PIm,CAddr,CCS,CTemp) \ +" add "#CTemp", "#CAddr", #"SZ" \n\t" /* Imaginary skip */ \ +" "LD1" "#ZRe"."DT", "#PRe"/z, ["#CAddr", "#ZIndex"."DT", "OFFS"]\n\t" \ +" "LD1" "#ZIm"."DT", "#PRe"/z, ["#CTemp", "#ZIndex"."DT", "OFFS"]\n\t" \ +" add "#CAddr", "#CAddr", "#CCS" \n\t" + +#define GEMM_CCOLCMPLX_SCATTER_STORE_FWD(ZRe,ZIm,ZIndex,PRe,PIm,CAddr,CCS,CTemp) \ +" add "#CTemp", "#CAddr", #"SZ" \n\t" /* Imaginary skip */ \ +" "ST1" "#ZRe"."DT", "#PRe", ["#CAddr", "#ZIndex"."DT", "OFFS"]\n\t" \ +" "ST1" "#ZIm"."DT", "#PRe", ["#CTemp", "#ZIndex"."DT", "OFFS"]\n\t" \ +" add "#CAddr", "#CAddr", "#CCS" \n\t" + diff --git a/frame/3/bli_l3_tapi_ba.c b/kernels/armsve/3/armsve_asm_macros_dcomplex.h similarity index 83% rename from frame/3/bli_l3_tapi_ba.c rename to kernels/armsve/3/armsve_asm_macros_dcomplex.h index 748863f844..0beb5d2316 100644 --- a/frame/3/bli_l3_tapi_ba.c +++ b/kernels/armsve/3/armsve_asm_macros_dcomplex.h @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, The University of Tokyo Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -30,17 +31,18 @@ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -*/ - -#include "blis.h" - -// Include cpp macros that instantiate the API definition templates as -// omitting expert parameters. -#include "bli_tapi_ba.h" -// Define the macro protecting the typed API definitions. -#define BLIS_ENABLE_TAPI - -// Include the typed API definitions here. -#include "bli_l3_tapi.c" +*/ +// Specify to use double precision. +#define DT "d" +#define LD1 "ld1d" +#define ST1 "st1d" +#define LD2 "ld2d" +#define ST2 "st2d" +#define LD1R "ld1rd" +#define PRFG "prfd" +#define SZ "8" +#define OFFS "lsl #3" +// Include macros. +#include "armsve_asm_macros_cmplx.h" diff --git a/frame/3/bli_l3_oapi_ba.c b/kernels/armsve/3/armsve_asm_macros_scomplex.h similarity index 83% rename from frame/3/bli_l3_oapi_ba.c rename to kernels/armsve/3/armsve_asm_macros_scomplex.h index d6e3b2f3d5..f49cfedfba 100644 --- a/frame/3/bli_l3_oapi_ba.c +++ b/kernels/armsve/3/armsve_asm_macros_scomplex.h @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, The University of Tokyo Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -30,17 +31,18 @@ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -*/ - -#include "blis.h" - -// Include cpp macros that instantiate the API definition templates as -// omitting expert parameters. -#include "bli_oapi_ba.h" -// Define the macro protecting the object API definitions. -#define BLIS_ENABLE_OAPI - -// Include the object API definitions here. -#include "bli_l3_oapi.c" +*/ +// Specify to use single precision. +#define DT "s" +#define LD1 "ld1w" +#define ST1 "st1w" +#define LD2 "ld2w" +#define ST2 "st2w" +#define LD1R "ld1rw" +#define PRFG "prfw" +#define SZ "4" +#define OFFS "uxtw #2" +// Include macros. +#include "armsve_asm_macros_cmplx.h" diff --git a/kernels/armsve/3/bli_gemm_armsve_asm_c2vx10_unindexed.c b/kernels/armsve/3/bli_gemm_armsve_asm_c2vx10_unindexed.c new file mode 100644 index 0000000000..66337e0b73 --- /dev/null +++ b/kernels/armsve/3/bli_gemm_armsve_asm_c2vx10_unindexed.c @@ -0,0 +1,314 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Forschunszentrum Juelich + Copyright (C) 2020, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ +#include "blis.h" + +// Single-precision composite instructions. +#include "armsve_asm_macros_scomplex.h" + +// 2vx10 microkernels. +#include "armsve_asm_2vx10cmplx.h" + +void bli_cgemm_armsve_asm_2vx10_unindexed + ( + dim_t k0, + scomplex* restrict alpha, + scomplex* restrict a, + scomplex* restrict b, + scomplex* restrict beta, + scomplex* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + void* a_next = bli_auxinfo_next_a( data ); + void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_mker = k0 / 4; + uint64_t k_left = k0 % 4; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + uint64_t info = 0; + + __asm__ volatile ( +// " ldr x0, %[a] \n\t" +// " ldr x1, %[b] \n\t" +" mov x2, xzr \n\t" +" incw x2, ALL, MUL #1 \n\t" // Column-skip of A. +" mov x3, #10 \n\t" // Row-skip of B. +" \n\t" +// " ldr x2, %[c] \n\t" +// " ldr x3, %[rs_c] \n\t" // Row-skip of C. +// " ldr x4, %[cs_c] \n\t" // Column-skip of C. +#ifdef _A64FX +" mov x16, 0x1 \n\t" // Tag A address. +" lsl x16, x16, #56 \n\t" +" orr %0, %0, x16 \n\t" +" mov x16, 0x2 \n\t" // Tag B address. +" lsl x16, x16, #56 \n\t" +" orr %1, %1, x16 \n\t" +" mov x16, 0x3 \n\t" // Tag C address. +" lsl x16, x16, #56 \n\t" +" orr %2, %2, x16 \n\t" +#endif +" \n\t" +" mov x16, #8 \n\t" // Multiply some address skips by sizeof(scomplex). +" madd x2, x16, x2, xzr \n\t" // cs_a +" madd x3, x16, x3, xzr \n\t" // rs_b +" madd %4, x16, %4, xzr \n\t" // cs_c +" ptrue p0.s \n\t" +" \n\t" +// " ldr x5, %[k_mker] \n\t" // Number of loops. +// " ldr x6, %[k_left] \n\t" +" \n\t" +" LOAD_ABC: \n\t" +" cmp %5, #0 \n\t" // Don't preload if no microkernel there. +" b.eq END_CCOL_PRFM \n\t" +" \n\t" +" ld1rw z20.s, p0/z, [%1, 4*0] \n\t" // Load B's real 8/10, no imaginary. +" ld1rw z21.s, p0/z, [%1, 4*2] \n\t" +" ld1rw z22.s, p0/z, [%1, 4*4] \n\t" +" ld1rw z23.s, p0/z, [%1, 4*6] \n\t" +" ld1rw z24.s, p0/z, [%1, 4*8] \n\t" +" ld1rw z25.s, p0/z, [%1, 4*10] \n\t" +" ld1rw z26.s, p0/z, [%1, 4*12] \n\t" +" ld1rw z27.s, p0/z, [%1, 4*14] \n\t" +" \n\t" +GEMM_ACOLCMPLX_CONTIGUOUS_LOAD_FWD(z28,z29,p0,%0,x2) +" \n\t" +" CCOL_PRFM: \n\t" +" cmp %3, #1 \n\t" +" b.ne END_CCOL_PRFM \n\t" // Do not prefetch for generic C storage. +" mov x16, %2 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" END_CCOL_PRFM: \n\t" +" \n\t" +CLEAR_COL20(z0,z1,z2,z3,z4,z5,z6,z7,z8,z9,z10,z11,z12,z13,z14,z15,z16,z17,z18,z19) +" \n\t" +" cmp %5, #0 \n\t" // If no 4-microkernel can be applied. +" b.eq K_LEFT_LOOP \n\t" +" \n\t" +" K_MKER_LOOP: \n\t" +" \n\t" +GEMM_ACOLCMPLX_CONTIGUOUS_LOAD_FWD(z30,z31,p0,%0,x2) +GEMM_2VX10CMPLX_MKER_LOOP_PLAIN_C_1(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z28,z29,z20,z21,z22,z23,z24,z25,z26,z27,%1,x3) +" \n\t" +GEMM_ACOLCMPLX_CONTIGUOUS_LOAD_FWD(z28,z29,p0,%0,x2) +GEMM_2VX10CMPLX_MKER_LOOP_PLAIN_C_2(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z30,z31,z20,z21,z22,z23,z24,z25,z26,z27,%1,x3) +" \n\t" +GEMM_ACOLCMPLX_CONTIGUOUS_LOAD_FWD(z30,z31,p0,%0,x2) +GEMM_2VX10CMPLX_MKER_LOOP_PLAIN_C_1(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z28,z29,z20,z21,z22,z23,z24,z25,z26,z27,%1,x3) +" \n\t" +" subs %5, %5, #1 \n\t" // Decrease counter before final replica. +" b.eq FIN_MKER_LOOP \n\t" // Branch early to avoid reading excess mem. +" \n\t" +GEMM_ACOLCMPLX_CONTIGUOUS_LOAD_FWD(z28,z29,p0,%0,x2) +GEMM_2VX10CMPLX_MKER_LOOP_PLAIN_C_2(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z30,z31,z20,z21,z22,z23,z24,z25,z26,z27,%1,x3) +" b K_MKER_LOOP \n\t" +" \n\t" +" FIN_MKER_LOOP: \n\t" +GEMM_2VX10CMPLX_MKER_LOOP_PLAIN_C_2_RESIDUAL(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z30,z31,z20,z21,z22,z23,z24,z25,z26,z27,%1,x3) +" \n\t" +" K_LEFT_LOOP: \n\t" +" cmp %6, #0 \n\t" // End of execution. +" b.eq WRITE_MEM_PREP \n\t" +" \n\t" +GEMM_ACOLCMPLX_CONTIGUOUS_LOAD_FWD(z28,z29,p0,%0,x2) +" ld1rw z20.s, p0/z, [%1, 4*0] \n\t" // Load B's real 8/10, no imaginary. +" ld1rw z21.s, p0/z, [%1, 4*2] \n\t" +" ld1rw z22.s, p0/z, [%1, 4*4] \n\t" +" ld1rw z23.s, p0/z, [%1, 4*6] \n\t" +" ld1rw z24.s, p0/z, [%1, 4*8] \n\t" +" ld1rw z25.s, p0/z, [%1, 4*10] \n\t" +" ld1rw z26.s, p0/z, [%1, 4*12] \n\t" +" ld1rw z27.s, p0/z, [%1, 4*14] \n\t" +GEMM_2VX10CMPLX_MKER_LOOP_PLAIN_C_1_RESIDUAL(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z28,z29,z20,z21,z22,z23,z24,z25,z26,z27,%1,x3) +" sub %6, %6, #1 \n\t" +" b K_LEFT_LOOP \n\t" // Next column / row. +" \n\t" +" WRITE_MEM_PREP: \n\t" +" \n\t" +// " ldr x7, %[alpha] \n\t" // Load alpha & beta (address). +// " ldr x8, %[beta] \n\t" +" ld1rw z28.s, p0/z, [%7] \n\t" // Real(alpha). +" ld1rw z29.s, p0/z, [%7, 4] \n\t" // Imag(alpha). +" ld1rw z30.s, p0/z, [%8] \n\t" // Real(beta). +" ld1rw z31.s, p0/z, [%8, 4] \n\t" // Imag(beta). +" \n\t" +" PREFETCH_ABNEXT: \n\t" +// " ldr x9, %[a_next] \n\t" +// " ldr x10, %[b_next] \n\t" +#ifdef _A64FX +" mov x16, 0x1 \n\t" // Tag A address. +" lsl x16, x16, #56 \n\t" +" orr %9, %9, x16 \n\t" +" mov x16, 0x2 \n\t" // Tag B address. +" lsl x16, x16, #56 \n\t" +" orr %10, %10, x16 \n\t" +#endif +" prfm PLDL1STRM, [%9] \n\t" +" prfm PLDL1STRM, [%9, 256*1] \n\t" +" prfm PLDL1STRM, [%10] \n\t" +" prfm PLDL1STRM, [%10, 256*1] \n\t" +" \n\t" +" WRITE_MEM: \n\t" +" fmov s27, #1.0 \n\t" +" fcmp s29, #0.0 \n\t" // Whether Imag(alpha) == 0. +" fccmp s28, s27, 0, eq \n\t" // Whether Real(alpha) == 1. +" b.eq UNIT_ALPHA \n\t" +" \n\t" +GEMM_FMULCMPLX_COL2(z20,z21,z22,z23,p0,z0 ,z1 ,z2 ,z3 ,z28,z29) +GEMM_FMULCMPLX_COL2(z24,z25,z26,z27,p0,z4 ,z5 ,z6 ,z7 ,z28,z29) +GEMM_FMULCMPLX_COL2(z0 ,z1 ,z2 ,z3 ,p0,z8, z9, z10,z11,z28,z29) +GEMM_FMULCMPLX_COL2(z4 ,z5 ,z6 ,z7 ,p0,z12,z13,z14,z15,z28,z29) +GEMM_FMULCMPLX_COL2(z8 ,z9 ,z10,z11,p0,z16,z17,z18,z19,z28,z29) +" b WRITE_MEM_EXEC \n\t" +" \n\t" +" UNIT_ALPHA: \n\t" +MOV_COL2(z20,z21,z22,z23,z0 ,z1 ,z2 ,z3 ) +MOV_COL2(z24,z25,z26,z27,z4 ,z5 ,z6 ,z7 ) +MOV_COL2(z0 ,z1 ,z2 ,z3 ,z8, z9, z10,z11) +MOV_COL2(z4 ,z5 ,z6 ,z7 ,z12,z13,z14,z15) +MOV_COL2(z8 ,z9 ,z10,z11,z16,z17,z18,z19) +" \n\t" +" WRITE_MEM_EXEC: \n\t" +" mov x9, %2 \n\t" // C address for loading. +" \n\t" // C address for storing is %2 itself. +" cmp %3, #1 \n\t" +" b.ne WRITE_MEM_G \n\t" +" \n\t" +" WRITE_MEM_C: \n\t" +" fmov s29, wzr \n\t" +" fcmp s31, #0.0 \n\t" // Whether Imag(beta) == 0. +" fccmp s30, s29, 0, eq \n\t" // Whether Real(beta) == 0. +" b.eq ZERO_BETA_C_0_1_2_3 \n\t" +GEMM_CCMPLX_LOAD_COL2_C(z12,z13,z14,z15,p0,x9,%4) +GEMM_CCMPLX_LOAD_COL2_C(z16,z17,z18,z19,p0,x9,%4) +GEMM_FMLACMPLX_COL2(z20,z21,z22,z23,p0,z12,z13,z14,z15,z30,z31) +GEMM_FMLACMPLX_COL2(z24,z25,z26,z27,p0,z16,z17,z18,z19,z30,z31) +" ZERO_BETA_C_0_1_2_3: \n\t" +GEMM_CCMPLX_STORE_COL2_C(z20,z21,z22,z23,p0,%2,%4) +GEMM_CCMPLX_STORE_COL2_C(z24,z25,z26,z27,p0,%2,%4) +" \n\t" +" b.eq ZERO_BETA_C_4_5_6_7_8_9 \n\t" +GEMM_CCMPLX_LOAD_COL2_C(z12,z13,z14,z15,p0,x9,%4) +GEMM_CCMPLX_LOAD_COL2_C(z16,z17,z18,z19,p0,x9,%4) +GEMM_CCMPLX_LOAD_COL2_C(z20,z21,z22,z23,p0,x9,%4) +GEMM_FMLACMPLX_COL2(z0 ,z1 ,z2 ,z3 ,p0,z12,z13,z14,z15,z30,z31) +GEMM_FMLACMPLX_COL2(z4 ,z5 ,z6 ,z7 ,p0,z16,z17,z18,z19,z30,z31) +GEMM_FMLACMPLX_COL2(z8 ,z9 ,z10,z11,p0,z20,z21,z22,z23,z30,z31) +" ZERO_BETA_C_4_5_6_7_8_9: \n\t" +GEMM_CCMPLX_STORE_COL2_C(z0 ,z1 ,z2 ,z3 ,p0,%2,%4) +GEMM_CCMPLX_STORE_COL2_C(z4 ,z5 ,z6 ,z7 ,p0,%2,%4) +GEMM_CCMPLX_STORE_COL2_C(z8 ,z9 ,z10,z11,p0,%2,%4) +" b END_WRITE_MEM \n\t" +" \n\t" +" WRITE_MEM_G: \n\t" +" add %3, %3, %3 \n\t" // Skips passed to index is multiplied by 2, +" mov x3, %3 \n\t" // s.t. 2*sizeof(float) = 2*4 = 8. +" index z28.s, wzr, w3 \n\t" +" fmov s29, wzr \n\t" +" fcmp s31, #0.0 \n\t" // Whether Imag(beta) == 0. +" fccmp s30, s29, 0, eq \n\t" // Whether Real(beta) == 0. +" b.eq ZERO_BETA_G_0_1_2_3 \n\t" +GEMM_CCMPLX_LOAD_COL2_G(z12,z13,z14,z15,p0,z28,x9,%4,x16) +GEMM_CCMPLX_LOAD_COL2_G(z16,z17,z18,z19,p0,z28,x9,%4,x16) +GEMM_FMLACMPLX_COL2(z20,z21,z22,z23,p0,z12,z13,z14,z15,z30,z31) +GEMM_FMLACMPLX_COL2(z24,z25,z26,z27,p0,z16,z17,z18,z19,z30,z31) +" ZERO_BETA_G_0_1_2_3: \n\t" +GEMM_CCMPLX_STORE_COL2_G(z20,z21,z22,z23,p0,z28,%2,%4,x16) +GEMM_CCMPLX_STORE_COL2_G(z24,z25,z26,z27,p0,z28,%2,%4,x16) +" \n\t" +" b.eq ZERO_BETA_G_4_5_6_7_8_9 \n\t" +GEMM_CCMPLX_LOAD_COL2_G(z12,z13,z14,z15,p0,z28,x9,%4,x16) +GEMM_CCMPLX_LOAD_COL2_G(z16,z17,z18,z19,p0,z28,x9,%4,x16) +GEMM_CCMPLX_LOAD_COL2_G(z20,z21,z22,z23,p0,z28,x9,%4,x16) +GEMM_FMLACMPLX_COL2(z0 ,z1 ,z2 ,z3 ,p0,z12,z13,z14,z15,z30,z31) +GEMM_FMLACMPLX_COL2(z4 ,z5 ,z6 ,z7 ,p0,z16,z17,z18,z19,z30,z31) +GEMM_FMLACMPLX_COL2(z8 ,z9 ,z10,z11,p0,z20,z21,z22,z23,z30,z31) +" ZERO_BETA_G_4_5_6_7_8_9: \n\t" +GEMM_CCMPLX_STORE_COL2_G(z0 ,z1 ,z2 ,z3 ,p0,z28,%2,%4,x16) +GEMM_CCMPLX_STORE_COL2_G(z4 ,z5 ,z6 ,z7 ,p0,z28,%2,%4,x16) +GEMM_CCMPLX_STORE_COL2_G(z8 ,z9 ,z10,z11,p0,z28,%2,%4,x16) +" \n\t" +" END_WRITE_MEM: \n\t" +" b END_EXEC \n\t" +" \n\t" +" END_EXEC: \n\t" +" mov %11, #0 \n\t" // Return normal. +: "+r" (a), // %0 + "+r" (b), // %1 + "+r" (c), // %2 + "+r" (rs_c), // %3 + "+r" (cs_c), // %4 + "+r" (k_mker), // %5 + "+r" (k_left), // %6 + "+r" (alpha), // %7 + "+r" (beta), // %8 + "+r" (a_next), // %9 + "+r" (b_next), // %10 + "=r" (info) // %11 +: +: "x2","x3","x9","x16", + "z0","z1","z2","z3","z4","z5","z6","z7", + "z8","z9","z10","z11","z12","z13","z14","z15", + "z16","z17","z18","z19", + "z20","z21","z22","z23", + "z24","z25","z26","z27", + "z28","z29","z30","z31" + ); +} + diff --git a/kernels/armsve/3/bli_gemm_armsve_asm_d2vx10_unindexed.c b/kernels/armsve/3/bli_gemm_armsve_asm_d2vx10_unindexed.c index 5824d2d550..e5b78a5921 100644 --- a/kernels/armsve/3/bli_gemm_armsve_asm_d2vx10_unindexed.c +++ b/kernels/armsve/3/bli_gemm_armsve_asm_d2vx10_unindexed.c @@ -264,12 +264,17 @@ SCALE_COL20(z0,z1,z2,z3,z4,z5,z6,z7,z8,z9,z10,z11,z12,z13,z14,z15,z16,z17,z18,z1 " \n\t" " WRITE_MEM_C: \n\t" // Available scratch: Z[20-30]. " \n\t" // Here used scratch: Z[20-29]. +" fcmp d31, #0.0 \n\t" // Skip loading if *beta == 0 to override NaN. +" b.eq BETA_ZERO_C \n\t" // First half of C is already loaded in this case. -GEMM_C_FMAD_LOAD_UKER_C(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p0,p0,z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,z31,x9,x7) +// GEMM_C_FMAD_LOAD_UKER_C(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p0,p0,z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,z31,x9,x7) +GEMM_C_FMLA_UKER(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p0,z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z31) +GEMM_C_LOAD_UKER_C(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p0,p0,x9,x7) +GEMM_C_FMLA_UKER(z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,p0,z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z31) " \n\t" -GEMM_C_STORE_UKER_C(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p0,p0,x5,x7) -GEMM_C_FMAD_UKER(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p0,p0,z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,z31) +" BETA_ZERO_C: \n\t" GEMM_C_STORE_UKER_C(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p0,p0,x5,x7) +GEMM_C_STORE_UKER_C(z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,p0,p0,x5,x7) " b END_WRITE_MEM \n\t" " \n\t" " WRITE_MEM_G: \n\t" // Available scratch: Z[20-30]. @@ -278,13 +283,18 @@ GEMM_C_STORE_UKER_C(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p0,p0,x5,x7) " incb x8 \n\t" " madd x8, x8, x6, xzr \n\t" // C-column's logical 1-vector skip. " index z30.d, xzr, x6 \n\t" // Skips passed to index is not multiplied by 8. +" \n\t" +" fcmp d31, #0.0 \n\t" // Skip loading if *beta == 0 to override NaN. +" b.eq BETA_ZERO_G \n\t" +" \n\t" +GEMM_C_LOAD_UKER_G(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z30,p0,p0,x9,x7,x8,x16) +GEMM_C_FMLA_UKER(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p0,z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z31) GEMM_C_LOAD_UKER_G(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z30,p0,p0,x9,x7,x8,x16) -GEMM_C_FMAD_UKER(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p0,p0,z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,z31) -GEMM_C_LOAD_UKER_G(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,z30,p0,p0,x9,x7,x8,x16) +GEMM_C_FMLA_UKER(z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,p0,z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z31) " \n\t" -GEMM_C_STORE_UKER_G(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z30,p0,p0,x5,x7,x8,x16) -GEMM_C_FMAD_UKER(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p0,p0,z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,z31) +" BETA_ZERO_G: \n\t" GEMM_C_STORE_UKER_G(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,z30,p0,p0,x5,x7,x8,x16) +GEMM_C_STORE_UKER_G(z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,z30,p0,p0,x5,x7,x8,x16) " \n\t" " END_WRITE_MEM: \n\t" " b END_EXEC \n\t" diff --git a/kernels/armsve/3/bli_gemm_armsve_asm_s2vx10_unindexed.c b/kernels/armsve/3/bli_gemm_armsve_asm_s2vx10_unindexed.c index 8659e8b7ee..00b3f20b44 100644 --- a/kernels/armsve/3/bli_gemm_armsve_asm_s2vx10_unindexed.c +++ b/kernels/armsve/3/bli_gemm_armsve_asm_s2vx10_unindexed.c @@ -252,13 +252,16 @@ SCALE_COL20(z0,z1,z2,z3,z4,z5,z6,z7,z8,z9,z10,z11,z12,z13,z14,z15,z16,z17,z18,z1 " \n\t" " WRITE_MEM_C: \n\t" // Available scratch: Z[20-30]. " \n\t" // Here used scratch: Z[20-29]. +" fcmp s31, #0.0 \n\t" +" b.eq BETA_ZERO_C \n\t" GEMM_C_LOAD_UKER_C(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p0,p0,x9,x7) -GEMM_C_FMAD_UKER(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p0,p0,z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,z31) -GEMM_C_LOAD_UKER_C(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p0,p0,x9,x7) +GEMM_C_FMLA_UKER(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p0,z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z31) +GEMM_C_LOAD_UKER_C(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p0,p0,x9,x7) +GEMM_C_FMLA_UKER(z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,p0,z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z31) " \n\t" -GEMM_C_STORE_UKER_C(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p0,p0,x5,x7) -GEMM_C_FMAD_UKER(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p0,p0,z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,z31) +" BETA_ZERO_C: \n\t" GEMM_C_STORE_UKER_C(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p0,p0,x5,x7) +GEMM_C_STORE_UKER_C(z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,p0,p0,x5,x7) " b END_WRITE_MEM \n\t" " \n\t" " WRITE_MEM_G: \n\t" // Available scratch: Z[20-30]. @@ -267,13 +270,17 @@ GEMM_C_STORE_UKER_C(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p0,p0,x5,x7) " incb x8 \n\t" " madd x8, x8, x6, xzr \n\t" // C-column's logical 1-vector skip. " index z30.s, wzr, w6 \n\t" // Skips passed to index is not multiplied by 8. +" \n\t" +" fcmp s31, #0.0 \n\t" +" b.eq BETA_ZERO_G \n\t" +GEMM_C_LOAD_UKER_G(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z30,p0,p0,x9,x7,x8,x16) +GEMM_C_FMLA_UKER(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p0,z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z31) GEMM_C_LOAD_UKER_G(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z30,p0,p0,x9,x7,x8,x16) -GEMM_C_FMAD_UKER(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p0,p0,z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,z31) -GEMM_C_LOAD_UKER_G(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,z30,p0,p0,x9,x7,x8,x16) +GEMM_C_FMLA_UKER(z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,p0,z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z31) " \n\t" -GEMM_C_STORE_UKER_G(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z30,p0,p0,x5,x7,x8,x16) -GEMM_C_FMAD_UKER(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p0,p0,z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,z31) +" BETA_ZERO_G: \n\t" GEMM_C_STORE_UKER_G(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,z30,p0,p0,x5,x7,x8,x16) +GEMM_C_STORE_UKER_G(z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,z30,p0,p0,x5,x7,x8,x16) " \n\t" " END_WRITE_MEM: \n\t" " b END_EXEC \n\t" diff --git a/kernels/armsve/3/bli_gemm_armsve_asm_z2vx10_unindexed.c b/kernels/armsve/3/bli_gemm_armsve_asm_z2vx10_unindexed.c new file mode 100644 index 0000000000..2fa37664ae --- /dev/null +++ b/kernels/armsve/3/bli_gemm_armsve_asm_z2vx10_unindexed.c @@ -0,0 +1,313 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Forschunszentrum Juelich + Copyright (C) 2020, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ +#include "blis.h" + +// Double-precision composite instructions. +#include "armsve_asm_macros_dcomplex.h" + +// 2vx10 microkernels. +#include "armsve_asm_2vx10cmplx.h" + +void bli_zgemm_armsve_asm_2vx10_unindexed + ( + dim_t k0, + dcomplex* restrict alpha, + dcomplex* restrict a, + dcomplex* restrict b, + dcomplex* restrict beta, + dcomplex* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + void* a_next = bli_auxinfo_next_a( data ); + void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_mker = k0 / 4; + uint64_t k_left = k0 % 4; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + uint64_t info = 0; + + __asm__ volatile ( +// " ldr x0, %[a] \n\t" +// " ldr x1, %[b] \n\t" +" mov x2, xzr \n\t" +" incd x2, ALL, MUL #1 \n\t" // Column-skip of A. +" mov x3, #10 \n\t" // Row-skip of B. +" \n\t" +// " ldr x2, %[c] \n\t" +// " ldr x3, %[rs_c] \n\t" // Row-skip of C. +// " ldr x4, %[cs_c] \n\t" // Column-skip of C. +#ifdef _A64FX +" mov x16, 0x1 \n\t" // Tag A address. +" lsl x16, x16, #56 \n\t" +" orr %0, %0, x16 \n\t" +" mov x16, 0x2 \n\t" // Tag B address. +" lsl x16, x16, #56 \n\t" +" orr %1, %1, x16 \n\t" +" mov x16, 0x3 \n\t" // Tag C address. +" lsl x16, x16, #56 \n\t" +" orr %2, %2, x16 \n\t" +#endif +" \n\t" +" mov x16, #16 \n\t" // Multiply some address skips by sizeof(dcomplex). +" madd x2, x16, x2, xzr \n\t" // cs_a +" madd x3, x16, x3, xzr \n\t" // rs_b +" madd %4, x16, %4, xzr \n\t" // cs_c +" ptrue p0.d \n\t" +" \n\t" +// " ldr x5, %[k_mker] \n\t" // Number of loops. +// " ldr x6, %[k_left] \n\t" +" \n\t" +" LOAD_ABC: \n\t" +" cmp %5, #0 \n\t" // Don't preload if no microkernel there. +" b.eq END_CCOL_PRFM \n\t" +" \n\t" +" ld1rd z20.d, p0/z, [%1, 8*0] \n\t" // Load B's real 8/10, no imaginary. +" ld1rd z21.d, p0/z, [%1, 8*2] \n\t" +" ld1rd z22.d, p0/z, [%1, 8*4] \n\t" +" ld1rd z23.d, p0/z, [%1, 8*6] \n\t" +" ld1rd z24.d, p0/z, [%1, 8*8] \n\t" +" ld1rd z25.d, p0/z, [%1, 8*10] \n\t" +" ld1rd z26.d, p0/z, [%1, 8*12] \n\t" +" ld1rd z27.d, p0/z, [%1, 8*14] \n\t" +" \n\t" +GEMM_ACOLCMPLX_CONTIGUOUS_LOAD_FWD(z28,z29,p0,%0,x2) +" \n\t" +" CCOL_PRFM: \n\t" +" cmp %3, #1 \n\t" +" b.ne END_CCOL_PRFM \n\t" // Do not prefetch for generic C storage. +" mov x16, %2 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" END_CCOL_PRFM: \n\t" +" \n\t" +CLEAR_COL20(z0,z1,z2,z3,z4,z5,z6,z7,z8,z9,z10,z11,z12,z13,z14,z15,z16,z17,z18,z19) +" \n\t" +" cmp %5, #0 \n\t" // If no 4-microkernel can be applied. +" b.eq K_LEFT_LOOP \n\t" +" \n\t" +" K_MKER_LOOP: \n\t" +" \n\t" +GEMM_ACOLCMPLX_CONTIGUOUS_LOAD_FWD(z30,z31,p0,%0,x2) +GEMM_2VX10CMPLX_MKER_LOOP_PLAIN_C_1(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z28,z29,z20,z21,z22,z23,z24,z25,z26,z27,%1,x3) +" \n\t" +GEMM_ACOLCMPLX_CONTIGUOUS_LOAD_FWD(z28,z29,p0,%0,x2) +GEMM_2VX10CMPLX_MKER_LOOP_PLAIN_C_2(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z30,z31,z20,z21,z22,z23,z24,z25,z26,z27,%1,x3) +" \n\t" +GEMM_ACOLCMPLX_CONTIGUOUS_LOAD_FWD(z30,z31,p0,%0,x2) +GEMM_2VX10CMPLX_MKER_LOOP_PLAIN_C_1(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z28,z29,z20,z21,z22,z23,z24,z25,z26,z27,%1,x3) +" \n\t" +" subs %5, %5, #1 \n\t" // Decrease counter before final replica. +" b.eq FIN_MKER_LOOP \n\t" // Branch early to avoid reading excess mem. +" \n\t" +GEMM_ACOLCMPLX_CONTIGUOUS_LOAD_FWD(z28,z29,p0,%0,x2) +GEMM_2VX10CMPLX_MKER_LOOP_PLAIN_C_2(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z30,z31,z20,z21,z22,z23,z24,z25,z26,z27,%1,x3) +" b K_MKER_LOOP \n\t" +" \n\t" +" FIN_MKER_LOOP: \n\t" +GEMM_2VX10CMPLX_MKER_LOOP_PLAIN_C_2_RESIDUAL(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z30,z31,z20,z21,z22,z23,z24,z25,z26,z27,%1,x3) +" \n\t" +" K_LEFT_LOOP: \n\t" +" cmp %6, #0 \n\t" // End of execution. +" b.eq WRITE_MEM_PREP \n\t" +" \n\t" +GEMM_ACOLCMPLX_CONTIGUOUS_LOAD_FWD(z28,z29,p0,%0,x2) +" ld1rd z20.d, p0/z, [%1, 8*0] \n\t" // Load B's real 8/10, no imaginary. +" ld1rd z21.d, p0/z, [%1, 8*2] \n\t" +" ld1rd z22.d, p0/z, [%1, 8*4] \n\t" +" ld1rd z23.d, p0/z, [%1, 8*6] \n\t" +" ld1rd z24.d, p0/z, [%1, 8*8] \n\t" +" ld1rd z25.d, p0/z, [%1, 8*10] \n\t" +" ld1rd z26.d, p0/z, [%1, 8*12] \n\t" +" ld1rd z27.d, p0/z, [%1, 8*14] \n\t" +GEMM_2VX10CMPLX_MKER_LOOP_PLAIN_C_1_RESIDUAL(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z28,z29,z20,z21,z22,z23,z24,z25,z26,z27,%1,x3) +" sub %6, %6, #1 \n\t" +" b K_LEFT_LOOP \n\t" // Next column / row. +" \n\t" +" WRITE_MEM_PREP: \n\t" +" \n\t" +// " ldr x7, %[alpha] \n\t" // Load alpha & beta (address). +// " ldr x8, %[beta] \n\t" +" ld1rd z28.d, p0/z, [%7] \n\t" // Real(alpha). +" ld1rd z29.d, p0/z, [%7, 8] \n\t" // Imag(alpha). +" ld1rd z30.d, p0/z, [%8] \n\t" // Real(beta). +" ld1rd z31.d, p0/z, [%8, 8] \n\t" // Imag(beta). +" \n\t" +" PREFETCH_ABNEXT: \n\t" +// " ldr x9, %[a_next] \n\t" +// " ldr x10, %[b_next] \n\t" +#ifdef _A64FX +" mov x16, 0x1 \n\t" // Tag A address. +" lsl x16, x16, #56 \n\t" +" orr %9, %9, x16 \n\t" +" mov x16, 0x2 \n\t" // Tag B address. +" lsl x16, x16, #56 \n\t" +" orr %10, %10, x16 \n\t" +#endif +" prfm PLDL1STRM, [%9] \n\t" +" prfm PLDL1STRM, [%9, 256*1] \n\t" +" prfm PLDL1STRM, [%10] \n\t" +" prfm PLDL1STRM, [%10, 256*1] \n\t" +" \n\t" +" WRITE_MEM: \n\t" +" fmov d27, #1.0 \n\t" +" fcmp d29, #0.0 \n\t" // Whether Imag(alpha) == 0. +" fccmp d28, d27, 0, eq \n\t" // Whether Real(alpha) == 1. +" b.eq UNIT_ALPHA \n\t" +" \n\t" +GEMM_FMULCMPLX_COL2(z20,z21,z22,z23,p0,z0 ,z1 ,z2 ,z3 ,z28,z29) +GEMM_FMULCMPLX_COL2(z24,z25,z26,z27,p0,z4 ,z5 ,z6 ,z7 ,z28,z29) +GEMM_FMULCMPLX_COL2(z0 ,z1 ,z2 ,z3 ,p0,z8, z9, z10,z11,z28,z29) +GEMM_FMULCMPLX_COL2(z4 ,z5 ,z6 ,z7 ,p0,z12,z13,z14,z15,z28,z29) +GEMM_FMULCMPLX_COL2(z8 ,z9 ,z10,z11,p0,z16,z17,z18,z19,z28,z29) +" b WRITE_MEM_EXEC \n\t" +" \n\t" +" UNIT_ALPHA: \n\t" +MOV_COL2(z20,z21,z22,z23,z0 ,z1 ,z2 ,z3 ) +MOV_COL2(z24,z25,z26,z27,z4 ,z5 ,z6 ,z7 ) +MOV_COL2(z0 ,z1 ,z2 ,z3 ,z8, z9, z10,z11) +MOV_COL2(z4 ,z5 ,z6 ,z7 ,z12,z13,z14,z15) +MOV_COL2(z8 ,z9 ,z10,z11,z16,z17,z18,z19) +" \n\t" +" WRITE_MEM_EXEC: \n\t" +" mov x9, %2 \n\t" // C address for loading. +" \n\t" // C address for storing is %2 itself. +" cmp %3, #1 \n\t" +" b.ne WRITE_MEM_G \n\t" +" \n\t" +" WRITE_MEM_C: \n\t" +" fmov d29, xzr \n\t" +" fcmp d31, #0.0 \n\t" // Whether Imag(beta) == 0. +" fccmp d30, d29, 0, eq \n\t" // Whether Real(beta) == 0. +" b.eq ZERO_BETA_C_0_1_2_3 \n\t" +GEMM_CCMPLX_LOAD_COL2_C(z12,z13,z14,z15,p0,x9,%4) +GEMM_CCMPLX_LOAD_COL2_C(z16,z17,z18,z19,p0,x9,%4) +GEMM_FMLACMPLX_COL2(z20,z21,z22,z23,p0,z12,z13,z14,z15,z30,z31) +GEMM_FMLACMPLX_COL2(z24,z25,z26,z27,p0,z16,z17,z18,z19,z30,z31) +" ZERO_BETA_C_0_1_2_3: \n\t" +GEMM_CCMPLX_STORE_COL2_C(z20,z21,z22,z23,p0,%2,%4) +GEMM_CCMPLX_STORE_COL2_C(z24,z25,z26,z27,p0,%2,%4) +" \n\t" +" b.eq ZERO_BETA_C_4_5_6_7_8_9 \n\t" +GEMM_CCMPLX_LOAD_COL2_C(z12,z13,z14,z15,p0,x9,%4) +GEMM_CCMPLX_LOAD_COL2_C(z16,z17,z18,z19,p0,x9,%4) +GEMM_CCMPLX_LOAD_COL2_C(z20,z21,z22,z23,p0,x9,%4) +GEMM_FMLACMPLX_COL2(z0 ,z1 ,z2 ,z3 ,p0,z12,z13,z14,z15,z30,z31) +GEMM_FMLACMPLX_COL2(z4 ,z5 ,z6 ,z7 ,p0,z16,z17,z18,z19,z30,z31) +GEMM_FMLACMPLX_COL2(z8 ,z9 ,z10,z11,p0,z20,z21,z22,z23,z30,z31) +" ZERO_BETA_C_4_5_6_7_8_9: \n\t" +GEMM_CCMPLX_STORE_COL2_C(z0 ,z1 ,z2 ,z3 ,p0,%2,%4) +GEMM_CCMPLX_STORE_COL2_C(z4 ,z5 ,z6 ,z7 ,p0,%2,%4) +GEMM_CCMPLX_STORE_COL2_C(z8 ,z9 ,z10,z11,p0,%2,%4) +" b END_WRITE_MEM \n\t" +" \n\t" +" WRITE_MEM_G: \n\t" +" add %3, %3, %3 \n\t" // Skips passed to index is multiplied by 2, +" index z28.d, xzr, %3 \n\t" // s.t. 2*sizeof(double) = 2*8 = 16. +" fmov d29, xzr \n\t" +" fcmp d31, #0.0 \n\t" // Whether Imag(beta) == 0. +" fccmp d30, d29, 0, eq \n\t" // Whether Real(beta) == 0. +" b.eq ZERO_BETA_G_0_1_2_3 \n\t" +GEMM_CCMPLX_LOAD_COL2_G(z12,z13,z14,z15,p0,z28,x9,%4,x16) +GEMM_CCMPLX_LOAD_COL2_G(z16,z17,z18,z19,p0,z28,x9,%4,x16) +GEMM_FMLACMPLX_COL2(z20,z21,z22,z23,p0,z12,z13,z14,z15,z30,z31) +GEMM_FMLACMPLX_COL2(z24,z25,z26,z27,p0,z16,z17,z18,z19,z30,z31) +" ZERO_BETA_G_0_1_2_3: \n\t" +GEMM_CCMPLX_STORE_COL2_G(z20,z21,z22,z23,p0,z28,%2,%4,x16) +GEMM_CCMPLX_STORE_COL2_G(z24,z25,z26,z27,p0,z28,%2,%4,x16) +" \n\t" +" b.eq ZERO_BETA_G_4_5_6_7_8_9 \n\t" +GEMM_CCMPLX_LOAD_COL2_G(z12,z13,z14,z15,p0,z28,x9,%4,x16) +GEMM_CCMPLX_LOAD_COL2_G(z16,z17,z18,z19,p0,z28,x9,%4,x16) +GEMM_CCMPLX_LOAD_COL2_G(z20,z21,z22,z23,p0,z28,x9,%4,x16) +GEMM_FMLACMPLX_COL2(z0 ,z1 ,z2 ,z3 ,p0,z12,z13,z14,z15,z30,z31) +GEMM_FMLACMPLX_COL2(z4 ,z5 ,z6 ,z7 ,p0,z16,z17,z18,z19,z30,z31) +GEMM_FMLACMPLX_COL2(z8 ,z9 ,z10,z11,p0,z20,z21,z22,z23,z30,z31) +" ZERO_BETA_G_4_5_6_7_8_9: \n\t" +GEMM_CCMPLX_STORE_COL2_G(z0 ,z1 ,z2 ,z3 ,p0,z28,%2,%4,x16) +GEMM_CCMPLX_STORE_COL2_G(z4 ,z5 ,z6 ,z7 ,p0,z28,%2,%4,x16) +GEMM_CCMPLX_STORE_COL2_G(z8 ,z9 ,z10,z11,p0,z28,%2,%4,x16) +" \n\t" +" END_WRITE_MEM: \n\t" +" b END_EXEC \n\t" +" \n\t" +" END_EXEC: \n\t" +" mov %11, #0 \n\t" // Return normal. +: "+r" (a), // %0 + "+r" (b), // %1 + "+r" (c), // %2 + "+r" (rs_c), // %3 + "+r" (cs_c), // %4 + "+r" (k_mker), // %5 + "+r" (k_left), // %6 + "+r" (alpha), // %7 + "+r" (beta), // %8 + "+r" (a_next), // %9 + "+r" (b_next), // %10 + "=r" (info) // %11 +: +: "x2","x3","x9","x16", + "z0","z1","z2","z3","z4","z5","z6","z7", + "z8","z9","z10","z11","z12","z13","z14","z15", + "z16","z17","z18","z19", + "z20","z21","z22","z23", + "z24","z25","z26","z27", + "z28","z29","z30","z31" + ); +} + diff --git a/kernels/armsve/3/bli_gemm_armsve_asm_z2vx7_unindexed.c b/kernels/armsve/3/bli_gemm_armsve_asm_z2vx7_unindexed.c new file mode 100644 index 0000000000..3d25719d92 --- /dev/null +++ b/kernels/armsve/3/bli_gemm_armsve_asm_z2vx7_unindexed.c @@ -0,0 +1,266 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Forschunszentrum Juelich + Copyright (C) 2020, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ +#include "blis.h" + +// Double-precision composite instructions. +#include "armsve_asm_macros_dcomplex.h" + +// 2vx7 microkernels. +#include "armsve_asm_2vx7cmplx.h" + +void bli_zgemm_armsve_asm_2vx7_unindexed + ( + dim_t k0, + dcomplex* restrict alpha, + dcomplex* restrict a, + dcomplex* restrict b, + dcomplex* restrict beta, + dcomplex* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + void* a_next = bli_auxinfo_next_a( data ); + void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_mker = k0 / 4; + uint64_t k_left = k0 % 4; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + uint64_t info = 0; + + __asm__ volatile ( +// " ldr x0, %[a] \n\t" +// " ldr x1, %[b] \n\t" +" mov x2, xzr \n\t" +" incd x2, ALL, MUL #1 \n\t" // Column-skip of A. +" mov x3, #7 \n\t" // Row-skip of B. +" \n\t" +// " ldr x2, %[c] \n\t" +// " ldr x3, %[rs_c] \n\t" // Row-skip of C. +// " ldr x4, %[cs_c] \n\t" // Column-skip of C. +#ifdef _A64FX +" mov x16, 0x1 \n\t" // Tag A address. +" lsl x16, x16, #56 \n\t" +" orr %0, %0, x16 \n\t" +" mov x16, 0x2 \n\t" // Tag B address. +" lsl x16, x16, #56 \n\t" +" orr %1, %1, x16 \n\t" +" mov x16, 0x3 \n\t" // Tag C address. +" lsl x16, x16, #56 \n\t" +" orr %2, %2, x16 \n\t" +#endif +" \n\t" +" mov x16, #16 \n\t" // Multiply some address skips by sizeof(dcomplex). +" madd x2, x16, x2, xzr \n\t" // cs_a +" madd x3, x16, x3, xzr \n\t" // rs_b +" madd %4, x16, %4, xzr \n\t" // cs_c +" ptrue p0.d \n\t" +" \n\t" +// " ldr x5, %[k_mker] \n\t" // Number of loops. +// " ldr x6, %[k_left] \n\t" +" \n\t" +" LOAD_ABC: \n\t" +" cmp %5, #0 \n\t" // Don't preload if no microkernel there. +" b.eq END_CCOL_PRFM \n\t" +" \n\t" +GEMM_ACOLCMPLX_CONTIGUOUS_LOAD_FWD(z28,z29,p0,%0,x2) +" \n\t" +" ld1rd z14.d, p0/z, [%1, 8*0] \n\t" // Load B's real & imaginary. +" ld1rd z15.d, p0/z, [%1, 8*2] \n\t" +" ld1rd z16.d, p0/z, [%1, 8*4] \n\t" +" ld1rd z17.d, p0/z, [%1, 8*6] \n\t" +" ld1rd z18.d, p0/z, [%1, 8*8] \n\t" +" ld1rd z19.d, p0/z, [%1, 8*10] \n\t" +" ld1rd z20.d, p0/z, [%1, 8*12] \n\t" +" ld1rd z21.d, p0/z, [%1, 8*1] \n\t" +" ld1rd z22.d, p0/z, [%1, 8*3] \n\t" +" ld1rd z23.d, p0/z, [%1, 8*5] \n\t" +" ld1rd z24.d, p0/z, [%1, 8*7] \n\t" +" ld1rd z25.d, p0/z, [%1, 8*9] \n\t" +" ld1rd z26.d, p0/z, [%1, 8*11] \n\t" +" ld1rd z27.d, p0/z, [%1, 8*13] \n\t" +" add %1, %1, x3 \n\t" +" \n\t" +" CCOL_PRFM: \n\t" +" cmp %3, #1 \n\t" +" b.ne END_CCOL_PRFM \n\t" // Do not prefetch for generic C storage. +" mov x16, %2 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" END_CCOL_PRFM: \n\t" +" \n\t" +CLEAR_COL14(z0,z1,z2,z3,z4,z5,z6,z7,z8,z9,z10,z11,z12,z13) +" \n\t" +" cmp %5, #0 \n\t" // If no 4-microkernel can be applied +" b.eq K_LEFT_LOOP \n\t" +" \n\t" +" K_MKER_LOOP: \n\t" +" \n\t" +GEMM_ACOLCMPLX_CONTIGUOUS_LOAD_FWD(z30,z31,p0,%0,x2) +GEMM_2VX7CMPLX_MKER_LOOP_PLAIN_C(z0,z2,z4,z6,z8,z10,z12,z1,z3,z5,z7,z9,z11,z13,p0,z28,z29,z14,z15,z16,z17,z18,z19,z20,z21,z22,z23,z24,z25,z26,z27,%1,x3) +" \n\t" +GEMM_ACOLCMPLX_CONTIGUOUS_LOAD_FWD(z28,z29,p0,%0,x2) +GEMM_2VX7CMPLX_MKER_LOOP_PLAIN_C(z0,z2,z4,z6,z8,z10,z12,z1,z3,z5,z7,z9,z11,z13,p0,z30,z31,z14,z15,z16,z17,z18,z19,z20,z21,z22,z23,z24,z25,z26,z27,%1,x3) +" \n\t" +GEMM_ACOLCMPLX_CONTIGUOUS_LOAD_FWD(z30,z31,p0,%0,x2) +GEMM_2VX7CMPLX_MKER_LOOP_PLAIN_C(z0,z2,z4,z6,z8,z10,z12,z1,z3,z5,z7,z9,z11,z13,p0,z28,z29,z14,z15,z16,z17,z18,z19,z20,z21,z22,z23,z24,z25,z26,z27,%1,x3) +" \n\t" +" subs %5, %5, #1 \n\t" // Decrease counter before final replica. +" b.eq FIN_MKER_LOOP \n\t" // Branch early to avoid reading excess mem. +" \n\t" +GEMM_ACOLCMPLX_CONTIGUOUS_LOAD_FWD(z28,z29,p0,%0,x2) +GEMM_2VX7CMPLX_MKER_LOOP_PLAIN_C(z0,z2,z4,z6,z8,z10,z12,z1,z3,z5,z7,z9,z11,z13,p0,z30,z31,z14,z15,z16,z17,z18,z19,z20,z21,z22,z23,z24,z25,z26,z27,%1,x3) +" b K_MKER_LOOP \n\t" +" \n\t" +" FIN_MKER_LOOP: \n\t" +GEMM_2VX7CMPLX_MKER_LOOP_PLAIN_C_RESIDUAL(z0,z2,z4,z6,z8,z10,z12,z1,z3,z5,z7,z9,z11,z13,p0,z30,z31,z14,z15,z16,z17,z18,z19,z20,z21,z22,z23,z24,z25,z26,z27,%1,x3) +" \n\t" +" K_LEFT_LOOP: \n\t" +" cmp %6, #0 \n\t" // End of execution. +" b.eq WRITE_MEM_PREP \n\t" +" \n\t" +GEMM_ACOLCMPLX_CONTIGUOUS_LOAD_FWD(z28,z29,p0,%0,x2) +" ld1rd z14.d, p0/z, [%1, 8*0] \n\t" +" ld1rd z15.d, p0/z, [%1, 8*2] \n\t" +" ld1rd z16.d, p0/z, [%1, 8*4] \n\t" +" ld1rd z17.d, p0/z, [%1, 8*6] \n\t" +" ld1rd z18.d, p0/z, [%1, 8*8] \n\t" +" ld1rd z19.d, p0/z, [%1, 8*10] \n\t" +" ld1rd z20.d, p0/z, [%1, 8*12] \n\t" +" ld1rd z21.d, p0/z, [%1, 8*1] \n\t" +" ld1rd z22.d, p0/z, [%1, 8*3] \n\t" +" ld1rd z23.d, p0/z, [%1, 8*5] \n\t" +" ld1rd z24.d, p0/z, [%1, 8*7] \n\t" +" ld1rd z25.d, p0/z, [%1, 8*9] \n\t" +" ld1rd z26.d, p0/z, [%1, 8*11] \n\t" +" ld1rd z27.d, p0/z, [%1, 8*13] \n\t" +" add %1, %1, x3 \n\t" +GEMM_2VX7CMPLX_MKER_LOOP_PLAIN_C_RESIDUAL(z0,z2,z4,z6,z8,z10,z12,z1,z3,z5,z7,z9,z11,z13,p0,z28,z29,z14,z15,z16,z17,z18,z19,z20,z21,z22,z23,z24,z25,z26,z27,%1,x3) +" sub %6, %6, #1 \n\t" +" b K_LEFT_LOOP \n\t" // Next column / row. +" \n\t" +" WRITE_MEM_PREP: \n\t" +" \n\t" +// " ldr x7, %[alpha] \n\t" // Load alpha & beta (address). +// " ldr x8, %[beta] \n\t" +" ld1rd z28.d, p0/z, [%7] \n\t" // Real(alpha). +" ld1rd z29.d, p0/z, [%7, 8] \n\t" // Imag(alpha). +" ld1rd z30.d, p0/z, [%8] \n\t" // Real(beta). +" ld1rd z31.d, p0/z, [%8, 8] \n\t" // Imag(beta). +" \n\t" +" PREFETCH_ABNEXT: \n\t" +// " ldr x9, %[a_next] \n\t" +// " ldr x10, %[b_next] \n\t" +#ifdef _A64FX +" mov x16, 0x1 \n\t" // Tag A address. +" lsl x16, x16, #56 \n\t" +" orr %9, %9, x16 \n\t" +" mov x16, 0x2 \n\t" // Tag B address. +" lsl x16, x16, #56 \n\t" +" orr %10, %10, x16 \n\t" +#endif +" prfm PLDL1STRM, [%9] \n\t" +" prfm PLDL1STRM, [%9, 256*1] \n\t" +" prfm PLDL1STRM, [%10] \n\t" +" prfm PLDL1STRM, [%10, 256*1] \n\t" +" \n\t" +" WRITE_MEM: \n\t" +" \n\t" +GEMM_FMULCMPLX_COL7(z14,z15,z16,z17,z18,z19,z20,z21,z22,z23,z24,z25,z26,z27,p0,z0,z1,z2,z3,z4,z5,z6,z7,z8,z9,z10,z11,z12,z13,z28,z29) +" \n\t" +" UNIT_ALPHA: \n\t" +" mov x9, %2 \n\t" // C address for loading. +" \n\t" // C address for storing is %2 itself. +" cmp %3, #1 \n\t" +" b.ne WRITE_MEM_G \n\t" +" \n\t" +" WRITE_MEM_C: \n\t" +GEMM_CCMPLX_LOAD_COL7_C(z0,z1,z2,z3,z4,z5,z6,z7,z8,z9,z10,z11,z12,z13,p0,x9,%4) +GEMM_FMLACMPLX_COL7(z14,z15,z16,z17,z18,z19,z20,z21,z22,z23,z24,z25,z26,z27,p0,z0,z1,z2,z3,z4,z5,z6,z7,z8,z9,z10,z11,z12,z13,z30,z31) +GEMM_CCMPLX_STORE_COL7_C(z14,z15,z16,z17,z18,z19,z20,z21,z22,z23,z24,z25,z26,z27,p0,%2,%4) +" b END_WRITE_MEM \n\t" +" \n\t" +" WRITE_MEM_G: \n\t" +" add %3, %3, %3 \n\t" // Skips passed to index is multiplied by 2, +" index z28.d, xzr, %3 \n\t" // s.t. 2*sizeof(double) = 2*8 = 16. +GEMM_CCMPLX_LOAD_COL7_G(z0,z1,z2,z3,z4,z5,z6,z7,z8,z9,z10,z11,z12,z13,p0,z28,x9,%4,x16) +GEMM_FMLACMPLX_COL7(z14,z15,z16,z17,z18,z19,z20,z21,z22,z23,z24,z25,z26,z27,p0,z0,z1,z2,z3,z4,z5,z6,z7,z8,z9,z10,z11,z12,z13,z30,z31) +GEMM_CCMPLX_STORE_COL7_G(z14,z15,z16,z17,z18,z19,z20,z21,z22,z23,z24,z25,z26,z27,p0,z28,%2,%4,x16) +" \n\t" +" END_WRITE_MEM: \n\t" +" b END_EXEC \n\t" +" \n\t" +" END_EXEC: \n\t" +" mov %11, #0 \n\t" // Return normal. +: "+r" (a), // %0 + "+r" (b), // %1 + "+r" (c), // %2 + "+r" (rs_c), // %3 + "+r" (cs_c), // %4 + "+r" (k_mker), // %5 + "+r" (k_left), // %6 + "+r" (alpha), // %7 + "+r" (beta), // %8 + "+r" (a_next), // %9 + "+r" (b_next), // %10 + "=r" (info) // %11 +: +: "x2","x3","x9","x16", + "z0","z1","z2","z3","z4","z5","z6","z7", + "z8","z9","z10","z11","z12","z13","z14","z15", + "z16","z17","z18","z19", + "z20","z21","z22","z23", + "z24","z25","z26","z27", + "z28","z29","z30","z31" + ); +} + + diff --git a/kernels/armsve/3/bli_gemm_armsve_asm_z2vx8_unindexed.c b/kernels/armsve/3/bli_gemm_armsve_asm_z2vx8_unindexed.c new file mode 100644 index 0000000000..d0eef4a8ca --- /dev/null +++ b/kernels/armsve/3/bli_gemm_armsve_asm_z2vx8_unindexed.c @@ -0,0 +1,290 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Forschunszentrum Juelich + Copyright (C) 2020, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ +#include "blis.h" + +// Double-precision composite instructions. +#include "armsve_asm_macros_dcomplex.h" + +// 2vx8 microkernels. +#include "armsve_asm_2vx8cmplx.h" + +void bli_zgemm_armsve_asm_2vx8_unindexed + ( + dim_t k0, + dcomplex* restrict alpha, + dcomplex* restrict a, + dcomplex* restrict b, + dcomplex* restrict beta, + dcomplex* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + void* a_next = bli_auxinfo_next_a( data ); + void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_mker = k0 / 6; + uint64_t k_left = k0 % 6; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + uint64_t info = 0; + + __asm__ volatile ( +// " ldr x0, %[a] \n\t" +// " ldr x1, %[b] \n\t" +" mov x2, xzr \n\t" +" incd x2, ALL, MUL #1 \n\t" // Column-skip of A. +" mov x3, #8 \n\t" // Row-skip of B. +" \n\t" +// " ldr x2, %[c] \n\t" +// " ldr x3, %[rs_c] \n\t" // Row-skip of C. +// " ldr x4, %[cs_c] \n\t" // Column-skip of C. +#ifdef _A64FX +" mov x16, 0x1 \n\t" // Tag A address. +" lsl x16, x16, #56 \n\t" +" orr %0, %0, x16 \n\t" +" mov x16, 0x2 \n\t" // Tag B address. +" lsl x16, x16, #56 \n\t" +" orr %1, %1, x16 \n\t" +" mov x16, 0x3 \n\t" // Tag C address. +" lsl x16, x16, #56 \n\t" +" orr %2, %2, x16 \n\t" +#endif +" \n\t" +" mov x16, #16 \n\t" // Multiply some address skips by sizeof(dcomplex). +" madd x2, x16, x2, xzr \n\t" // cs_a +" madd x3, x16, x3, xzr \n\t" // rs_b +" madd %4, x16, %4, xzr \n\t" // cs_c +" ptrue p0.d \n\t" +" \n\t" +// " ldr x5, %[k_mker] \n\t" // Number of loops. +// " ldr x6, %[k_left] \n\t" +" \n\t" +" LOAD_ABC: \n\t" +" cmp %5, #0 \n\t" // Don't preload if no microkernel there. +" b.eq END_CCOL_PRFM \n\t" +" \n\t" +" ld1rd z20.d, p0/z, [%1, 8*0] \n\t" // Load B's real & half of imaginary. +" ld1rd z21.d, p0/z, [%1, 8*2] \n\t" +" ld1rd z22.d, p0/z, [%1, 8*4] \n\t" +" ld1rd z23.d, p0/z, [%1, 8*6] \n\t" +" ld1rd z24.d, p0/z, [%1, 8*8] \n\t" +" ld1rd z25.d, p0/z, [%1, 8*10] \n\t" +" ld1rd z26.d, p0/z, [%1, 8*12] \n\t" +" ld1rd z27.d, p0/z, [%1, 8*14] \n\t" +" ld1rd z28.d, p0/z, [%1, 8*1] \n\t" +" ld1rd z29.d, p0/z, [%1, 8*3] \n\t" +" ld1rd z30.d, p0/z, [%1, 8*5] \n\t" +" ld1rd z31.d, p0/z, [%1, 8*7] \n\t" +" \n\t" +GEMM_ACOLCMPLX_CONTIGUOUS_LOAD_FWD(z16,z17,p0,%0,x2) +" \n\t" +" CCOL_PRFM: \n\t" +" cmp %3, #1 \n\t" +" b.ne END_CCOL_PRFM \n\t" // Do not prefetch for generic C storage. +" mov x16, %2 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, %4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" END_CCOL_PRFM: \n\t" +" \n\t" +CLEAR_COL16(z0,z1,z2,z3,z4,z5,z6,z7,z8,z9,z10,z11,z12,z13,z14,z15) +" \n\t" +" cmp %5, #0 \n\t" // If no 6-microkernel can be applied +" b.eq K_LEFT_LOOP \n\t" +" \n\t" +" K_MKER_LOOP: \n\t" +" \n\t" +GEMM_ACOLCMPLX_CONTIGUOUS_LOAD_FWD(z18,z19,p0,%0,x2) +GEMM_2VX8CMPLX_MKER_LOOP_PLAIN_C_1(z0,z2,z4,z6,z8,z10,z12,z14,z1,z3,z5,z7,z9,z11,z13,z15,p0,z16,z17,z20,z21,z22,z23,z24,z25,z26,z27,z28,z29,z30,z31,%1,x3) +" \n\t" +GEMM_ACOLCMPLX_CONTIGUOUS_LOAD_FWD(z16,z17,p0,%0,x2) +GEMM_2VX8CMPLX_MKER_LOOP_PLAIN_C_2(z0,z2,z4,z6,z8,z10,z12,z14,z1,z3,z5,z7,z9,z11,z13,z15,p0,z18,z19,z20,z21,z22,z23,z24,z25,z26,z27,z28,z29,z30,z31,%1,x3) +" \n\t" +GEMM_ACOLCMPLX_CONTIGUOUS_LOAD_FWD(z18,z19,p0,%0,x2) +GEMM_2VX8CMPLX_MKER_LOOP_PLAIN_C_3(z0,z2,z4,z6,z8,z10,z12,z14,z1,z3,z5,z7,z9,z11,z13,z15,p0,z16,z17,z20,z21,z22,z23,z24,z25,z26,z27,z28,z29,z30,z31,%1,x3) +" \n\t" +GEMM_ACOLCMPLX_CONTIGUOUS_LOAD_FWD(z16,z17,p0,%0,x2) +GEMM_2VX8CMPLX_MKER_LOOP_PLAIN_C_1(z0,z2,z4,z6,z8,z10,z12,z14,z1,z3,z5,z7,z9,z11,z13,z15,p0,z18,z19,z20,z21,z22,z23,z24,z25,z26,z27,z28,z29,z30,z31,%1,x3) +" \n\t" +GEMM_ACOLCMPLX_CONTIGUOUS_LOAD_FWD(z18,z19,p0,%0,x2) +GEMM_2VX8CMPLX_MKER_LOOP_PLAIN_C_2(z0,z2,z4,z6,z8,z10,z12,z14,z1,z3,z5,z7,z9,z11,z13,z15,p0,z16,z17,z20,z21,z22,z23,z24,z25,z26,z27,z28,z29,z30,z31,%1,x3) +" \n\t" +" subs %5, %5, #1 \n\t" // Decrease counter before final replica. +" b.eq FIN_MKER_LOOP \n\t" // Branch early to avoid reading excess mem. +" \n\t" +GEMM_ACOLCMPLX_CONTIGUOUS_LOAD_FWD(z16,z17,p0,%0,x2) +GEMM_2VX8CMPLX_MKER_LOOP_PLAIN_C_3(z0,z2,z4,z6,z8,z10,z12,z14,z1,z3,z5,z7,z9,z11,z13,z15,p0,z18,z19,z20,z21,z22,z23,z24,z25,z26,z27,z28,z29,z30,z31,%1,x3) +" b K_MKER_LOOP \n\t" +" \n\t" +" FIN_MKER_LOOP: \n\t" +GEMM_2VX8CMPLX_MKER_LOOP_PLAIN_C_3_RESIDUAL(z0,z2,z4,z6,z8,z10,z12,z14,z1,z3,z5,z7,z9,z11,z13,z15,p0,z18,z19,z20,z21,z22,z23,z24,z25,z26,z27,z28,z29,z30,z31,%1,x3) +" \n\t" +" K_LEFT_LOOP: \n\t" +" cmp %6, #0 \n\t" // End of execution. +" b.eq WRITE_MEM_PREP \n\t" +" \n\t" +GEMM_ACOLCMPLX_CONTIGUOUS_LOAD_FWD(z16,z17,p0,%0,x2) +" ld1rd z20.d, p0/z, [%1, 8*0] \n\t" // Reload B's real & half of imaginary. +" ld1rd z21.d, p0/z, [%1, 8*2] \n\t" +" ld1rd z22.d, p0/z, [%1, 8*4] \n\t" +" ld1rd z23.d, p0/z, [%1, 8*6] \n\t" +" ld1rd z24.d, p0/z, [%1, 8*8] \n\t" +" ld1rd z25.d, p0/z, [%1, 8*10] \n\t" +" ld1rd z26.d, p0/z, [%1, 8*12] \n\t" +" ld1rd z27.d, p0/z, [%1, 8*14] \n\t" +" ld1rd z28.d, p0/z, [%1, 8*1] \n\t" +" ld1rd z29.d, p0/z, [%1, 8*3] \n\t" +" ld1rd z30.d, p0/z, [%1, 8*5] \n\t" +" ld1rd z31.d, p0/z, [%1, 8*7] \n\t" +GEMM_2VX8CMPLX_MKER_LOOP_PLAIN_C_1_RESIDUAL(z0,z2,z4,z6,z8,z10,z12,z14,z1,z3,z5,z7,z9,z11,z13,z15,p0,z16,z17,z20,z21,z22,z23,z24,z25,z26,z27,z28,z29,z30,z31,%1,x3) +" sub %6, %6, #1 \n\t" +" b K_LEFT_LOOP \n\t" // Next column / row. +" \n\t" +" WRITE_MEM_PREP: \n\t" +" \n\t" +// " ldr x7, %[alpha] \n\t" // Load alpha & beta (address). +// " ldr x8, %[beta] \n\t" +" ld1rd z16.d, p0/z, [%7] \n\t" // Real(alpha). +" ld1rd z17.d, p0/z, [%7, 8] \n\t" // Imag(alpha). +" ld1rd z18.d, p0/z, [%8] \n\t" // Real(beta). +" ld1rd z19.d, p0/z, [%8, 8] \n\t" // Imag(beta). +" \n\t" +" PREFETCH_ABNEXT: \n\t" +// " ldr x9, %[a_next] \n\t" +// " ldr x10, %[b_next] \n\t" +#ifdef _A64FX +" mov x16, 0x1 \n\t" // Tag A address. +" lsl x16, x16, #56 \n\t" +" orr %9, %9, x16 \n\t" +" mov x16, 0x2 \n\t" // Tag B address. +" lsl x16, x16, #56 \n\t" +" orr %10, %10, x16 \n\t" +#endif +" prfm PLDL1STRM, [%9] \n\t" +" prfm PLDL1STRM, [%9, 256*1] \n\t" +" prfm PLDL1STRM, [%10] \n\t" +" prfm PLDL1STRM, [%10, 256*1] \n\t" +" \n\t" +" WRITE_MEM: \n\t" +" \n\t" +GEMM_FMULCMPLX_COL2(z20,z21,z22,z23,p0,z0 ,z1 ,z2 ,z3 ,z16,z17) +GEMM_FMULCMPLX_COL2(z0 ,z1 ,z2 ,z3 ,p0,z4 ,z5 ,z6 ,z7 ,z16,z17) +GEMM_FMULCMPLX_COL2(z4 ,z5 ,z6 ,z7 ,p0,z8 ,z9 ,z10,z11,z16,z17) +GEMM_FMULCMPLX_COL2(z8 ,z9 ,z10,z11,p0,z12,z13,z14,z15,z16,z17) +" \n\t" +" UNIT_ALPHA: \n\t" +" mov x9, %2 \n\t" // C address for loading. +" \n\t" // C address for storing is %2 itself. +" cmp %3, #1 \n\t" +" b.ne WRITE_MEM_G \n\t" +" \n\t" +" WRITE_MEM_C: \n\t" +GEMM_CCMPLX_LOAD_COL2_C(z12,z13,z14,z15,p0,x9,%4) +GEMM_CCMPLX_LOAD_COL2_C(z24,z25,z26,z27,p0,x9,%4) +GEMM_FMLACMPLX_COL2(z20,z21,z22,z23,p0,z12,z13,z14,z15,z18,z19) +GEMM_FMLACMPLX_COL2(z0 ,z1 ,z2 ,z3 ,p0,z24,z25,z26,z27,z18,z19) +GEMM_CCMPLX_STORE_COL2_C(z20,z21,z22,z23,p0,%2,%4) +GEMM_CCMPLX_STORE_COL2_C(z0 ,z1 ,z2 ,z3 ,p0,%2,%4) +" \n\t" +GEMM_CCMPLX_LOAD_COL2_C(z12,z13,z14,z15,p0,x9,%4) +GEMM_CCMPLX_LOAD_COL2_C(z24,z25,z26,z27,p0,x9,%4) +GEMM_FMLACMPLX_COL2(z4 ,z5 ,z6 ,z7 ,p0,z12,z13,z14,z15,z18,z19) +GEMM_FMLACMPLX_COL2(z8 ,z9 ,z10,z11,p0,z24,z25,z26,z27,z18,z19) +GEMM_CCMPLX_STORE_COL2_C(z4 ,z5 ,z6 ,z7 ,p0,%2,%4) +GEMM_CCMPLX_STORE_COL2_C(z8 ,z9 ,z10,z11,p0,%2,%4) +" b END_WRITE_MEM \n\t" +" \n\t" +" WRITE_MEM_G: \n\t" +" add %3, %3, %3 \n\t" // Skips passed to index is multiplied by 2, +" index z16.d, xzr, %3 \n\t" // s.t. 2*sizeof(double) = 2*8 = 16. +GEMM_CCMPLX_LOAD_COL2_G(z12,z13,z14,z15,p0,z16,x9,%4,x16) +GEMM_CCMPLX_LOAD_COL2_G(z24,z25,z26,z27,p0,z16,x9,%4,x16) +GEMM_FMLACMPLX_COL2(z20,z21,z22,z23,p0,z12,z13,z14,z15,z18,z19) +GEMM_FMLACMPLX_COL2(z0 ,z1 ,z2 ,z3 ,p0,z24,z25,z26,z27,z18,z19) +GEMM_CCMPLX_STORE_COL2_G(z20,z21,z22,z23,p0,z16,%2,%4,x16) +GEMM_CCMPLX_STORE_COL2_G(z0 ,z1 ,z2 ,z3 ,p0,z16,%2,%4,x16) +" \n\t" +GEMM_CCMPLX_LOAD_COL2_G(z12,z13,z14,z15,p0,z16,x9,%4,x16) +GEMM_CCMPLX_LOAD_COL2_G(z24,z25,z26,z27,p0,z16,x9,%4,x16) +GEMM_FMLACMPLX_COL2(z4 ,z5 ,z6 ,z7 ,p0,z12,z13,z14,z15,z18,z19) +GEMM_FMLACMPLX_COL2(z8 ,z9 ,z10,z11,p0,z24,z25,z26,z27,z18,z19) +GEMM_CCMPLX_STORE_COL2_G(z4 ,z5 ,z6 ,z7 ,p0,z16,%2,%4,x16) +GEMM_CCMPLX_STORE_COL2_G(z8 ,z9 ,z10,z11,p0,z16,%2,%4,x16) +" \n\t" +" END_WRITE_MEM: \n\t" +" b END_EXEC \n\t" +" \n\t" +" END_EXEC: \n\t" +" mov %11, #0 \n\t" // Return normal. +: "+r" (a), // %0 + "+r" (b), // %1 + "+r" (c), // %2 + "+r" (rs_c), // %3 + "+r" (cs_c), // %4 + "+r" (k_mker), // %5 + "+r" (k_left), // %6 + "+r" (alpha), // %7 + "+r" (beta), // %8 + "+r" (a_next), // %9 + "+r" (b_next), // %10 + "=r" (info) // %11 +: +: "x2","x3","x9","x16", + "z0","z1","z2","z3","z4","z5","z6","z7", + "z8","z9","z10","z11","z12","z13","z14","z15", + "z16","z17","z18","z19", + "z20","z21","z22","z23", + "z24","z25","z26","z27", + "z28","z29","z30","z31" + ); +} + diff --git a/kernels/armsve/3/bli_gemm_armsve256_asm_d8x8.c b/kernels/armsve/3/old/bli_gemm_armsve256_asm_d8x8.c similarity index 100% rename from kernels/armsve/3/bli_gemm_armsve256_asm_d8x8.c rename to kernels/armsve/3/old/bli_gemm_armsve256_asm_d8x8.c diff --git a/kernels/armsve/3/bli_gemm_armsve_asm_sh2vx10_unindexed.c b/kernels/armsve/3/old/bli_gemm_armsve_asm_sh2vx10_unindexed.c similarity index 100% rename from kernels/armsve/3/bli_gemm_armsve_asm_sh2vx10_unindexed.c rename to kernels/armsve/3/old/bli_gemm_armsve_asm_sh2vx10_unindexed.c diff --git a/kernels/armsve/3/sup/bli_gemmsup_armsve_ref.c b/kernels/armsve/3/old/sup/bli_gemmsup_armsve_ref.c similarity index 100% rename from kernels/armsve/3/sup/bli_gemmsup_armsve_ref.c rename to kernels/armsve/3/old/sup/bli_gemmsup_armsve_ref.c diff --git a/kernels/armsve/3/sup/bli_gemmsup_cv_armsve_asm_d2vx10_unindexed.c b/kernels/armsve/3/old/sup/bli_gemmsup_cv_armsve_asm_d2vx10_unindexed.c similarity index 100% rename from kernels/armsve/3/sup/bli_gemmsup_cv_armsve_asm_d2vx10_unindexed.c rename to kernels/armsve/3/old/sup/bli_gemmsup_cv_armsve_asm_d2vx10_unindexed.c diff --git a/kernels/armsve/3/sup/bli_gemmsup_rv_armsve_asm_d2vx10_unindexed.c b/kernels/armsve/3/old/sup/bli_gemmsup_rv_armsve_asm_d2vx10_unindexed.c similarity index 100% rename from kernels/armsve/3/sup/bli_gemmsup_rv_armsve_asm_d2vx10_unindexed.c rename to kernels/armsve/3/old/sup/bli_gemmsup_rv_armsve_asm_d2vx10_unindexed.c diff --git a/kernels/armsve/bli_kernels_armsve.h b/kernels/armsve/bli_kernels_armsve.h index 3ccd79b68e..0d5c5dc472 100644 --- a/kernels/armsve/bli_kernels_armsve.h +++ b/kernels/armsve/bli_kernels_armsve.h @@ -35,11 +35,18 @@ GEMM_UKR_PROT( double, d, gemm_armsve256_asm_8x8 ) GEMM_UKR_PROT( double, d, gemm_armsve_asm_2vx10_unindexed ) GEMM_UKR_PROT( float, s, gemm_armsve_asm_2vx10_unindexed ) -GEMMSUP_KER_PROT( double, d, gemmsup_rv_armsve_2vx10_unindexed ) -GEMMSUP_KER_PROT( double, d, gemmsup_cv_armsve_2vx10_unindexed ) -GEMMSUP_KER_PROT( double, d, gemmsup_rv_armsve_10x2v_unindexed ) +GEMM_UKR_PROT( scomplex, c, gemm_armsve_asm_2vx10_unindexed ) +GEMM_UKR_PROT( dcomplex, z, gemm_armsve_asm_2vx10_unindexed ) +GEMM_UKR_PROT( dcomplex, z, gemm_armsve_asm_2vx8_unindexed ) +GEMM_UKR_PROT( dcomplex, z, gemm_armsve_asm_2vx7_unindexed ) +//GEMMSUP_KER_PROT( double, d, gemmsup_rv_armsve_2vx10_unindexed ) +//GEMMSUP_KER_PROT( double, d, gemmsup_cv_armsve_2vx10_unindexed ) +//GEMMSUP_KER_PROT( double, d, gemmsup_rv_armsve_10x2v_unindexed ) -PACKM_KER_PROT( double, d, packm_armsve256_asm_8xk ) +// Use SVE intrinsics only for referred cases. +#if (defined(BLIS_FAMILY_ARMSVE) && !defined(BLIS_FAMILY_A64FX)) +PACKM_KER_PROT( double, d, packm_armsve256_int_8xk ) +PACKM_KER_PROT( double, d, packm_armsve512_int_12xk ) +#endif PACKM_KER_PROT( double, d, packm_armsve512_asm_16xk ) -PACKM_KER_PROT( double, d, packm_armsve512_asm_12xk ) PACKM_KER_PROT( double, d, packm_armsve512_asm_10xk ) diff --git a/kernels/armv7a/3/bli_gemm_armv7a_int_d4x4.c b/kernels/armv7a/3/bli_gemm_armv7a_int_d4x4.c index e502a34ed6..b9db587266 100644 --- a/kernels/armv7a/3/bli_gemm_armv7a_int_d4x4.c +++ b/kernels/armv7a/3/bli_gemm_armv7a_int_d4x4.c @@ -330,53 +330,53 @@ void bli_dgemm_armv7a_int_4x4 double b0, b1, b2, b3; double B0, B1, B2, B3; - double ab00, ab01, ab02, ab03; - double ab10, ab11, ab12, ab13; + double ab00, ab01, ab02, ab03; + double ab10, ab11, ab12, ab13; double ab20, ab21, ab22, ab23; - double ab30, ab31, ab32, ab33; + double ab30, ab31, ab32, ab33; - double* restrict c00, * restrict c01, * restrict c02, * restrict c03; + double* restrict c00, * restrict c01, * restrict c02, * restrict c03; double* restrict c10, * restrict c11, * restrict c12, * restrict c13; double* restrict c20, * restrict c21, * restrict c22, * restrict c23; - double* restrict c30, * restrict c31, * restrict c32, * restrict c33; + double* restrict c30, * restrict c31, * restrict c32, * restrict c33; double* restrict ap = a; - double* restrict bp = b; + double* restrict bp = b; double* restrict Ap = a + 4; - double* restrict Bp = b + 4; + double* restrict Bp = b + 4; - c00 = (c + 0*rs_c + 0*cs_c); - c10 = (c + 1*rs_c + 0*cs_c); - c20 = (c + 2*rs_c + 0*cs_c); - c30 = (c + 3*rs_c + 0*cs_c); + c00 = (c + 0*rs_c + 0*cs_c); + c10 = (c + 1*rs_c + 0*cs_c); + c20 = (c + 2*rs_c + 0*cs_c); + c30 = (c + 3*rs_c + 0*cs_c); - c01 = (c + 0*rs_c + 1*cs_c); - c11 = (c + 1*rs_c + 1*cs_c); - c21 = (c + 2*rs_c + 1*cs_c); - c31 = (c + 3*rs_c + 1*cs_c); + c01 = (c + 0*rs_c + 1*cs_c); + c11 = (c + 1*rs_c + 1*cs_c); + c21 = (c + 2*rs_c + 1*cs_c); + c31 = (c + 3*rs_c + 1*cs_c); - c02 = (c + 0*rs_c + 2*cs_c); - c12 = (c + 1*rs_c + 2*cs_c); - c22 = (c + 2*rs_c + 2*cs_c); - c32 = (c + 3*rs_c + 2*cs_c); + c02 = (c + 0*rs_c + 2*cs_c); + c12 = (c + 1*rs_c + 2*cs_c); + c22 = (c + 2*rs_c + 2*cs_c); + c32 = (c + 3*rs_c + 2*cs_c); - c03 = (c + 0*rs_c + 3*cs_c); - c13 = (c + 1*rs_c + 3*cs_c); - c23 = (c + 2*rs_c + 3*cs_c); - c33 = (c + 3*rs_c + 3*cs_c); + c03 = (c + 0*rs_c + 3*cs_c); + c13 = (c + 1*rs_c + 3*cs_c); + c23 = (c + 2*rs_c + 3*cs_c); + c33 = (c + 3*rs_c + 3*cs_c); ab00 = 0.0; ab10 = 0.0; ab20 = 0.0; ab30 = 0.0; ab01 = 0.0; ab11 = 0.0; ab21 = 0.0; ab31 = 0.0; ab02 = 0.0; ab12 = 0.0; ab22 = 0.0; ab32 = 0.0; ab03 = 0.0; ab13 = 0.0; ab23 = 0.0; ab33 = 0.0; - A0 = *(Ap + 0); - A1 = *(Ap + 1); - A2 = *(Ap + 2); - A3 = *(Ap + 3); + A0 = *(Ap + 0); + A1 = *(Ap + 1); + A2 = *(Ap + 2); + A3 = *(Ap + 3); - a0 = *(ap + 0); + a0 = *(ap + 0); a1 = *(ap + 1); a2 = *(ap + 2); @@ -389,11 +389,11 @@ void bli_dgemm_armv7a_int_4x4 b1 = *(bp + 1); b2 = *(bp + 2); - double *Aplast = (Ap + 4*(k-k_left)); + double *Aplast = (Ap + 4*(k-k_left)); //for ( i = 0; i < k_iter; ++i ) // Unroll by factor 4. for ( ; Ap != Aplast ; ) // Unroll by factor 4. - { + { /* Prefetch */ //__asm__ ("pld\t[%0],#100\n\t" : :"r"(Ap) : ); __builtin_prefetch( ap + 112 ); @@ -452,7 +452,7 @@ void bli_dgemm_armv7a_int_4x4 b2 = *(bp + 10); ab03 += a0 * b3; - a0 = *(ap + 8); + a0 = *(ap + 8); ab13 += a1 * b3; a1 = *(ap + 9); ab23 += a2 * b3; @@ -460,17 +460,17 @@ void bli_dgemm_armv7a_int_4x4 ab33 += a3 * b3; //a3 = *(ap + 11); - ap += 8; - Ap += 8; - bp += 8; - Bp += 8; + ap += 8; + Ap += 8; + bp += 8; + Bp += 8; - } + } - for ( i = 0; i < k_left; ++i ) - { - a0 = *(ap + 0); + for ( i = 0; i < k_left; ++i ) + { + a0 = *(ap + 0); a1 = *(ap + 1); a2 = *(ap + 2); a3 = *(ap + 3); @@ -500,48 +500,73 @@ void bli_dgemm_armv7a_int_4x4 ab23 += a2 * b3; ab33 += a3 * b3; - ap += 4; - bp += 4; - } - - *c00 = *c00 * *beta; - *c10 = *c10 * *beta; - *c20 = *c20 * *beta; - *c30 = *c30 * *beta; - - *c01 = *c01 * *beta; - *c11 = *c11 * *beta; - *c21 = *c21 * *beta; - *c31 = *c31 * *beta; - - *c02 = *c02 * *beta; - *c12 = *c12 * *beta; - *c22 = *c22 * *beta; - *c32 = *c32 * *beta; - - *c03 = *c03 * *beta; - *c13 = *c13 * *beta; - *c23 = *c23 * *beta; - *c33 = *c33 * *beta; - - *c00 += ab00 * *alpha; - *c10 += ab10 * *alpha; - *c20 += ab20 * *alpha; - *c30 += ab30 * *alpha; - - *c01 += ab01 * *alpha; - *c11 += ab11 * *alpha; - *c21 += ab21 * *alpha; - *c31 += ab31 * *alpha; - - *c02 += ab02 * *alpha; - *c12 += ab12 * *alpha; - *c22 += ab22 * *alpha; - *c32 += ab32 * *alpha; - - *c03 += ab03 * *alpha; - *c13 += ab13 * *alpha; - *c23 += ab23 * *alpha; - *c33 += ab33 * *alpha; + ap += 4; + bp += 4; + } + + if ( *beta == 0.0 ) + { + *c00 = ab00 * *alpha; + *c10 = ab10 * *alpha; + *c20 = ab20 * *alpha; + *c30 = ab30 * *alpha; + + *c01 = ab01 * *alpha; + *c11 = ab11 * *alpha; + *c21 = ab21 * *alpha; + *c31 = ab31 * *alpha; + + *c02 = ab02 * *alpha; + *c12 = ab12 * *alpha; + *c22 = ab22 * *alpha; + *c32 = ab32 * *alpha; + + *c03 = ab03 * *alpha; + *c13 = ab13 * *alpha; + *c23 = ab23 * *alpha; + *c33 = ab33 * *alpha; + } + else + { + *c00 = *c00 * *beta; + *c10 = *c10 * *beta; + *c20 = *c20 * *beta; + *c30 = *c30 * *beta; + + *c01 = *c01 * *beta; + *c11 = *c11 * *beta; + *c21 = *c21 * *beta; + *c31 = *c31 * *beta; + + *c02 = *c02 * *beta; + *c12 = *c12 * *beta; + *c22 = *c22 * *beta; + *c32 = *c32 * *beta; + + *c03 = *c03 * *beta; + *c13 = *c13 * *beta; + *c23 = *c23 * *beta; + *c33 = *c33 * *beta; + + *c00 += ab00 * *alpha; + *c10 += ab10 * *alpha; + *c20 += ab20 * *alpha; + *c30 += ab30 * *alpha; + + *c01 += ab01 * *alpha; + *c11 += ab11 * *alpha; + *c21 += ab21 * *alpha; + *c31 += ab31 * *alpha; + + *c02 += ab02 * *alpha; + *c12 += ab12 * *alpha; + *c22 += ab22 * *alpha; + *c32 += ab32 * *alpha; + + *c03 += ab03 * *alpha; + *c13 += ab13 * *alpha; + *c23 += ab23 * *alpha; + *c33 += ab33 * *alpha; + } } diff --git a/kernels/armv8a/1m/bli_packm_armv8a_int_d6xk.c b/kernels/armv8a/1m/bli_packm_armv8a_int_d6xk.c new file mode 100644 index 0000000000..301b8ad790 --- /dev/null +++ b/kernels/armv8a/1m/bli_packm_armv8a_int_d6xk.c @@ -0,0 +1,323 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Linaro Limited + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include + +#if defined(__clang__) +#define PRAGMA_NOUNROLL _Pragma("nounroll") +#define PRAGMA_UNROLL_2 _Pragma("unroll 2") +#elif defined(__GNUC__) +#define PRAGMA_NOUNROLL _Pragma("GCC unroll 1") +#define PRAGMA_UNROLL_2 _Pragma("GCC unroll 2") +#else +#define PRAGMA_NOUNROLL +#define PRAGMA_UNROLL_2 +#endif + +void bli_dpackm_armv8a_int_6xk + ( + conj_t conja, + pack_t schema, + dim_t cdim0, + dim_t k0, + dim_t k0_max, + double* restrict kappa, + double* restrict a, inc_t inca0, inc_t lda0, + double* restrict p, inc_t ldp0, + cntx_t* restrict cntx + ) +{ + // This is the panel dimension assumed by the packm kernel. + const dim_t mnr = 6; + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 2; + uint64_t k_left = k0 % 2; + double* a_loc = a; + double* p_loc = p; + + // NOTE: For the purposes of the comments in this packm kernel, we + // interpret inca and lda as rs_a and cs_a, respectively, and similarly + // interpret ldp as cs_p (with rs_p implicitly unit). Thus, when reading + // this packm kernel, you should think of the operation as packing an + // m x n micropanel, where m and n are tiny and large, respectively, and + // where elements of each column of the packed matrix P are contiguous. + // (This packm kernel can still be used to pack micropanels of matrix B + // in a gemm operation.) + const uint64_t inca = inca0; + const uint64_t lda = lda0; + const uint64_t ldp = ldp0; + + const bool gs = ( inca0 != 1 && lda0 != 1 ); + + // NOTE: If/when this kernel ever supports scaling by kappa within the + // assembly region, this constraint should be lifted. + const bool unitk = bli_deq1( *kappa ); + + + // ------------------------------------------------------------------------- + + if ( cdim0 == mnr && !gs ) + { + if ( unitk ) + { + if ( inca == 1 ) + { + // No need to use k-loops here. + // Simply let compiler to expand loops. + PRAGMA_UNROLL_2 + for ( dim_t ik = k_iter * 2 + k_left; ik > 0; --ik ) + { + float64x2_t v0 = vld1q_f64( a_loc + 0 ); + float64x2_t v1 = vld1q_f64( a_loc + 2 ); + float64x2_t v2 = vld1q_f64( a_loc + 4 ); + + vst1q_f64( p_loc + 0, v0 ); + vst1q_f64( p_loc + 2, v1 ); + vst1q_f64( p_loc + 4, v2 ); + + a_loc += lda; + p_loc += ldp; + } + } + else // if ( lda == 1 ) + { + float64x2_t v0 = (float64x2_t)vdupq_n_u64( 0 ); + float64x2_t v1 = (float64x2_t)vdupq_n_u64( 0 ); + float64x2_t v2 = (float64x2_t)vdupq_n_u64( 0 ); + float64x2_t v3 = (float64x2_t)vdupq_n_u64( 0 ); + float64x2_t v4 = (float64x2_t)vdupq_n_u64( 0 ); + float64x2_t v5 = (float64x2_t)vdupq_n_u64( 0 ); + + PRAGMA_NOUNROLL + for ( ; k_iter > 0; --k_iter ) + { + v0 = vld1q_f64( a_loc + inca * 0 ); + v1 = vld1q_f64( a_loc + inca * 1 ); + v2 = vld1q_f64( a_loc + inca * 2 ); + v3 = vld1q_f64( a_loc + inca * 3 ); + v4 = vld1q_f64( a_loc + inca * 4 ); + v5 = vld1q_f64( a_loc + inca * 5 ); + + // In-register transpose. + float64x2_t vd0_1 = vtrn1q_f64( v0, v1 ); + float64x2_t vd1_1 = vtrn1q_f64( v2, v3 ); + float64x2_t vd2_1 = vtrn1q_f64( v4, v5 ); + float64x2_t vd0_2 = vtrn2q_f64( v0, v1 ); + float64x2_t vd1_2 = vtrn2q_f64( v2, v3 ); + float64x2_t vd2_2 = vtrn2q_f64( v4, v5 ); + + vst1q_f64( p_loc + 0, vd0_1 ); + vst1q_f64( p_loc + 2, vd1_1 ); + vst1q_f64( p_loc + 4, vd2_1 ); + p_loc += ldp; + + vst1q_f64( p_loc + 0, vd0_2 ); + vst1q_f64( p_loc + 2, vd1_2 ); + vst1q_f64( p_loc + 4, vd2_2 ); + p_loc += ldp; + a_loc += 2 * lda; // 2; + } + for ( ; k_left > 0; --k_left ) + { + v0 = vld1q_lane_f64( a_loc + inca * 0, v0, 0 ); + v0 = vld1q_lane_f64( a_loc + inca * 1, v0, 1 ); + v1 = vld1q_lane_f64( a_loc + inca * 2, v1, 0 ); + v1 = vld1q_lane_f64( a_loc + inca * 3, v1, 1 ); + v2 = vld1q_lane_f64( a_loc + inca * 4, v2, 0 ); + v2 = vld1q_lane_f64( a_loc + inca * 5, v2, 1 ); + + vst1q_f64( p_loc + 0, v0 ); + vst1q_f64( p_loc + 2, v1 ); + vst1q_f64( p_loc + 4, v2 ); + p_loc += ldp; + a_loc += lda; // 1; + } + } + } + else // if ( !unitk ) + { + float64x2_t vkappa = vld1q_dup_f64( kappa ); + + if ( inca == 1 ) + { + // No need to use k-loops here. + // Simply let compiler to expand loops. + PRAGMA_UNROLL_2 + for ( dim_t ik = k_iter * 2 + k_left; ik > 0; --ik ) + { + float64x2_t v0 = vld1q_f64( a_loc + 0 ); + float64x2_t v1 = vld1q_f64( a_loc + 2 ); + float64x2_t v2 = vld1q_f64( a_loc + 4 ); + + // Scale by kappa. + v0 = vmulq_f64( v0, vkappa ); + v1 = vmulq_f64( v1, vkappa ); + v2 = vmulq_f64( v2, vkappa ); + + vst1q_f64( p_loc + 0, v0 ); + vst1q_f64( p_loc + 2, v1 ); + vst1q_f64( p_loc + 4, v2 ); + + a_loc += lda; + p_loc += ldp; + } + } + else // if ( lda == 1 ) + { + float64x2_t v0 = (float64x2_t)vdupq_n_u64( 0 ); + float64x2_t v1 = (float64x2_t)vdupq_n_u64( 0 ); + float64x2_t v2 = (float64x2_t)vdupq_n_u64( 0 ); + float64x2_t v3 = (float64x2_t)vdupq_n_u64( 0 ); + float64x2_t v4 = (float64x2_t)vdupq_n_u64( 0 ); + float64x2_t v5 = (float64x2_t)vdupq_n_u64( 0 ); + + PRAGMA_NOUNROLL + for ( ; k_iter > 0; --k_iter ) + { + v0 = vld1q_f64( a_loc + inca * 0 ); + v1 = vld1q_f64( a_loc + inca * 1 ); + v2 = vld1q_f64( a_loc + inca * 2 ); + v3 = vld1q_f64( a_loc + inca * 3 ); + v4 = vld1q_f64( a_loc + inca * 4 ); + v5 = vld1q_f64( a_loc + inca * 5 ); + + // Scale by kappa. + v0 = vmulq_f64( v0, vkappa ); + v1 = vmulq_f64( v1, vkappa ); + v2 = vmulq_f64( v2, vkappa ); + v3 = vmulq_f64( v3, vkappa ); + v4 = vmulq_f64( v4, vkappa ); + v5 = vmulq_f64( v5, vkappa ); + + // In-register transpose. + float64x2_t vd0_1 = vtrn1q_f64( v0, v1 ); + float64x2_t vd1_1 = vtrn1q_f64( v2, v3 ); + float64x2_t vd2_1 = vtrn1q_f64( v4, v5 ); + float64x2_t vd0_2 = vtrn2q_f64( v0, v1 ); + float64x2_t vd1_2 = vtrn2q_f64( v2, v3 ); + float64x2_t vd2_2 = vtrn2q_f64( v4, v5 ); + + vst1q_f64( p_loc + 0, vd0_1 ); + vst1q_f64( p_loc + 2, vd1_1 ); + vst1q_f64( p_loc + 4, vd2_1 ); + p_loc += ldp; + + vst1q_f64( p_loc + 0, vd0_2 ); + vst1q_f64( p_loc + 2, vd1_2 ); + vst1q_f64( p_loc + 4, vd2_2 ); + p_loc += ldp; + a_loc += 2 * lda; // 2; + } + for ( ; k_left > 0; --k_left ) + { + v0 = vld1q_lane_f64( a_loc + inca * 0, v0, 0 ); + v0 = vld1q_lane_f64( a_loc + inca * 1, v0, 1 ); + v1 = vld1q_lane_f64( a_loc + inca * 2, v1, 0 ); + v1 = vld1q_lane_f64( a_loc + inca * 3, v1, 1 ); + v2 = vld1q_lane_f64( a_loc + inca * 4, v2, 0 ); + v2 = vld1q_lane_f64( a_loc + inca * 5, v2, 1 ); + + // Scale by kappa. + v0 = vmulq_f64( v0, vkappa ); + v1 = vmulq_f64( v1, vkappa ); + v2 = vmulq_f64( v2, vkappa ); + + vst1q_f64( p_loc + 0, v0 ); + vst1q_f64( p_loc + 2, v1 ); + vst1q_f64( p_loc + 4, v2 ); + p_loc += ldp; + a_loc += lda; // 1; + } + } + } + } + else // if ( cdim0 < mnr || gs ) + { + PASTEMAC(dscal2m,BLIS_TAPI_EX_SUF) + ( + 0, + BLIS_NONUNIT_DIAG, + BLIS_DENSE, + ( trans_t )conja, + cdim0, + k0, + kappa, + a, inca0, lda0, + p, 1, ldp0, + cntx, + NULL + ); + + if ( cdim0 < mnr ) + { + // Handle zero-filling along the "long" edge of the micropanel. + + const dim_t i = cdim0; + const dim_t m_edge = mnr - cdim0; + const dim_t n_edge = k0_max; + double* restrict p_edge = p + (i )*1; + + bli_dset0s_mxn + ( + m_edge, + n_edge, + p_edge, 1, ldp + ); + } + } + +//bli_dfprintm( stdout, "packm 6xk ker: a_packed", cdim0, k0_max, p, 1, ldp0, "%5.2f", "" ); + + if ( k0 < k0_max ) + { + // Handle zero-filling along the "short" (far) edge of the micropanel. + + const dim_t j = k0; + const dim_t m_edge = mnr; + const dim_t n_edge = k0_max - k0; + double* restrict p_edge = p + (j )*ldp; + + bli_dset0s_mxn + ( + m_edge, + n_edge, + p_edge, 1, ldp + ); + } +} + diff --git a/kernels/armv8a/1m/bli_packm_armv8a_int_d8xk.c b/kernels/armv8a/1m/bli_packm_armv8a_int_d8xk.c new file mode 100644 index 0000000000..321fa5403b --- /dev/null +++ b/kernels/armv8a/1m/bli_packm_armv8a_int_d8xk.c @@ -0,0 +1,353 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Linaro Limited + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include + +#if defined(__clang__) +#define PRAGMA_NOUNROLL _Pragma("nounroll") +#define PRAGMA_UNROLL_2 _Pragma("unroll 2") +#elif defined(__GNUC__) +#define PRAGMA_NOUNROLL _Pragma("GCC unroll 1") +#define PRAGMA_UNROLL_2 _Pragma("GCC unroll 2") +#else +#define PRAGMA_NOUNROLL +#define PRAGMA_UNROLL_2 +#endif + +void bli_dpackm_armv8a_int_8xk + ( + conj_t conja, + pack_t schema, + dim_t cdim0, + dim_t k0, + dim_t k0_max, + double* restrict kappa, + double* restrict a, inc_t inca0, inc_t lda0, + double* restrict p, inc_t ldp0, + cntx_t* restrict cntx + ) +{ + // This is the panel dimension assumed by the packm kernel. + const dim_t mnr = 8; + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 2; + uint64_t k_left = k0 % 2; + double* a_loc = a; + double* p_loc = p; + + // NOTE: For the purposes of the comments in this packm kernel, we + // interpret inca and lda as rs_a and cs_a, respectively, and similarly + // interpret ldp as cs_p (with rs_p implicitly unit). Thus, when reading + // this packm kernel, you should think of the operation as packing an + // m x n micropanel, where m and n are tiny and large, respectively, and + // where elements of each column of the packed matrix P are contiguous. + // (This packm kernel can still be used to pack micropanels of matrix B + // in a gemm operation.) + const uint64_t inca = inca0; + const uint64_t lda = lda0; + const uint64_t ldp = ldp0; + + const bool gs = ( inca0 != 1 && lda0 != 1 ); + + // NOTE: If/when this kernel ever supports scaling by kappa within the + // assembly region, this constraint should be lifted. + const bool unitk = bli_deq1( *kappa ); + + + // ------------------------------------------------------------------------- + + if ( cdim0 == mnr && !gs ) + { + if ( unitk ) + { + if ( inca == 1 ) + { + // No need to use k-loops here. + // Simply let compiler to expand loops. + PRAGMA_UNROLL_2 + for ( dim_t ik = k_iter * 2 + k_left; ik > 0; --ik ) + { + float64x2_t v0 = vld1q_f64( a_loc + 0 ); + float64x2_t v1 = vld1q_f64( a_loc + 2 ); + float64x2_t v2 = vld1q_f64( a_loc + 4 ); + float64x2_t v3 = vld1q_f64( a_loc + 6 ); + + vst1q_f64( p_loc + 0, v0 ); + vst1q_f64( p_loc + 2, v1 ); + vst1q_f64( p_loc + 4, v2 ); + vst1q_f64( p_loc + 6, v3 ); + + a_loc += lda; + p_loc += ldp; + } + } + else // if ( lda == 1 ) + { + float64x2_t v0 = (float64x2_t)vdupq_n_u64( 0 ); + float64x2_t v1 = (float64x2_t)vdupq_n_u64( 0 ); + float64x2_t v2 = (float64x2_t)vdupq_n_u64( 0 ); + float64x2_t v3 = (float64x2_t)vdupq_n_u64( 0 ); + float64x2_t v4 = (float64x2_t)vdupq_n_u64( 0 ); + float64x2_t v5 = (float64x2_t)vdupq_n_u64( 0 ); + float64x2_t v6 = (float64x2_t)vdupq_n_u64( 0 ); + float64x2_t v7 = (float64x2_t)vdupq_n_u64( 0 ); + + PRAGMA_NOUNROLL + for ( ; k_iter > 0; --k_iter ) + { + v0 = vld1q_f64( a_loc + inca * 0 ); + v1 = vld1q_f64( a_loc + inca * 1 ); + v2 = vld1q_f64( a_loc + inca * 2 ); + v3 = vld1q_f64( a_loc + inca * 3 ); + v4 = vld1q_f64( a_loc + inca * 4 ); + v5 = vld1q_f64( a_loc + inca * 5 ); + v6 = vld1q_f64( a_loc + inca * 6 ); + v7 = vld1q_f64( a_loc + inca * 7 ); + + // In-register transpose. + float64x2_t vd0_1 = vtrn1q_f64( v0, v1 ); + float64x2_t vd1_1 = vtrn1q_f64( v2, v3 ); + float64x2_t vd2_1 = vtrn1q_f64( v4, v5 ); + float64x2_t vd3_1 = vtrn1q_f64( v6, v7 ); + float64x2_t vd0_2 = vtrn2q_f64( v0, v1 ); + float64x2_t vd1_2 = vtrn2q_f64( v2, v3 ); + float64x2_t vd2_2 = vtrn2q_f64( v4, v5 ); + float64x2_t vd3_2 = vtrn2q_f64( v6, v7 ); + + vst1q_f64( p_loc + 0, vd0_1 ); + vst1q_f64( p_loc + 2, vd1_1 ); + vst1q_f64( p_loc + 4, vd2_1 ); + vst1q_f64( p_loc + 6, vd3_1 ); + p_loc += ldp; + + vst1q_f64( p_loc + 0, vd0_2 ); + vst1q_f64( p_loc + 2, vd1_2 ); + vst1q_f64( p_loc + 4, vd2_2 ); + vst1q_f64( p_loc + 6, vd3_2 ); + p_loc += ldp; + a_loc += 2 * lda; // 2; + } + for ( ; k_left > 0; --k_left ) + { + v0 = vld1q_lane_f64( a_loc + inca * 0, v0, 0 ); + v0 = vld1q_lane_f64( a_loc + inca * 1, v0, 1 ); + v1 = vld1q_lane_f64( a_loc + inca * 2, v1, 0 ); + v1 = vld1q_lane_f64( a_loc + inca * 3, v1, 1 ); + v2 = vld1q_lane_f64( a_loc + inca * 4, v2, 0 ); + v2 = vld1q_lane_f64( a_loc + inca * 5, v2, 1 ); + v3 = vld1q_lane_f64( a_loc + inca * 6, v3, 0 ); + v3 = vld1q_lane_f64( a_loc + inca * 7, v3, 1 ); + + vst1q_f64( p_loc + 0, v0 ); + vst1q_f64( p_loc + 2, v1 ); + vst1q_f64( p_loc + 4, v2 ); + vst1q_f64( p_loc + 6, v3 ); + p_loc += ldp; + a_loc += lda; // 1; + } + } + } + else // if ( !unitk ) + { + float64x2_t vkappa = vld1q_dup_f64( kappa ); + + if ( inca == 1 ) + { + // No need to use k-loops here. + // Simply let compiler to expand loops. + PRAGMA_UNROLL_2 + for ( dim_t ik = k_iter * 2 + k_left; ik > 0; --ik ) + { + float64x2_t v0 = vld1q_f64( a_loc + 0 ); + float64x2_t v1 = vld1q_f64( a_loc + 2 ); + float64x2_t v2 = vld1q_f64( a_loc + 4 ); + float64x2_t v3 = vld1q_f64( a_loc + 6 ); + + // Scale by kappa. + v0 = vmulq_f64( v0, vkappa ); + v1 = vmulq_f64( v1, vkappa ); + v2 = vmulq_f64( v2, vkappa ); + v3 = vmulq_f64( v3, vkappa ); + + vst1q_f64( p_loc + 0, v0 ); + vst1q_f64( p_loc + 2, v1 ); + vst1q_f64( p_loc + 4, v2 ); + vst1q_f64( p_loc + 6, v3 ); + + a_loc += lda; + p_loc += ldp; + } + } + else // if ( lda == 1 ) + { + float64x2_t v0 = (float64x2_t)vdupq_n_u64( 0 ); + float64x2_t v1 = (float64x2_t)vdupq_n_u64( 0 ); + float64x2_t v2 = (float64x2_t)vdupq_n_u64( 0 ); + float64x2_t v3 = (float64x2_t)vdupq_n_u64( 0 ); + float64x2_t v4 = (float64x2_t)vdupq_n_u64( 0 ); + float64x2_t v5 = (float64x2_t)vdupq_n_u64( 0 ); + float64x2_t v6 = (float64x2_t)vdupq_n_u64( 0 ); + float64x2_t v7 = (float64x2_t)vdupq_n_u64( 0 ); + + PRAGMA_NOUNROLL + for ( ; k_iter > 0; --k_iter ) + { + v0 = vld1q_f64( a_loc + inca * 0 ); + v1 = vld1q_f64( a_loc + inca * 1 ); + v2 = vld1q_f64( a_loc + inca * 2 ); + v3 = vld1q_f64( a_loc + inca * 3 ); + v4 = vld1q_f64( a_loc + inca * 4 ); + v5 = vld1q_f64( a_loc + inca * 5 ); + v6 = vld1q_f64( a_loc + inca * 6 ); + v7 = vld1q_f64( a_loc + inca * 7 ); + + // Scale by kappa. + v0 = vmulq_f64( v0, vkappa ); + v1 = vmulq_f64( v1, vkappa ); + v2 = vmulq_f64( v2, vkappa ); + v3 = vmulq_f64( v3, vkappa ); + v4 = vmulq_f64( v4, vkappa ); + v5 = vmulq_f64( v5, vkappa ); + v6 = vmulq_f64( v6, vkappa ); + v7 = vmulq_f64( v7, vkappa ); + + // In-register transpose. + float64x2_t vd0_1 = vtrn1q_f64( v0, v1 ); + float64x2_t vd1_1 = vtrn1q_f64( v2, v3 ); + float64x2_t vd2_1 = vtrn1q_f64( v4, v5 ); + float64x2_t vd3_1 = vtrn1q_f64( v6, v7 ); + float64x2_t vd0_2 = vtrn2q_f64( v0, v1 ); + float64x2_t vd1_2 = vtrn2q_f64( v2, v3 ); + float64x2_t vd2_2 = vtrn2q_f64( v4, v5 ); + float64x2_t vd3_2 = vtrn2q_f64( v6, v7 ); + + vst1q_f64( p_loc + 0, vd0_1 ); + vst1q_f64( p_loc + 2, vd1_1 ); + vst1q_f64( p_loc + 4, vd2_1 ); + vst1q_f64( p_loc + 6, vd3_1 ); + p_loc += ldp; + + vst1q_f64( p_loc + 0, vd0_2 ); + vst1q_f64( p_loc + 2, vd1_2 ); + vst1q_f64( p_loc + 4, vd2_2 ); + vst1q_f64( p_loc + 6, vd3_2 ); + p_loc += ldp; + a_loc += 2 * lda; // 2; + } + for ( ; k_left > 0; --k_left ) + { + v0 = vld1q_lane_f64( a_loc + inca * 0, v0, 0 ); + v0 = vld1q_lane_f64( a_loc + inca * 1, v0, 1 ); + v1 = vld1q_lane_f64( a_loc + inca * 2, v1, 0 ); + v1 = vld1q_lane_f64( a_loc + inca * 3, v1, 1 ); + v2 = vld1q_lane_f64( a_loc + inca * 4, v2, 0 ); + v2 = vld1q_lane_f64( a_loc + inca * 5, v2, 1 ); + v3 = vld1q_lane_f64( a_loc + inca * 6, v3, 0 ); + v3 = vld1q_lane_f64( a_loc + inca * 7, v3, 1 ); + + // Scale by kappa. + v0 = vmulq_f64( v0, vkappa ); + v1 = vmulq_f64( v1, vkappa ); + v2 = vmulq_f64( v2, vkappa ); + v3 = vmulq_f64( v3, vkappa ); + + vst1q_f64( p_loc + 0, v0 ); + vst1q_f64( p_loc + 2, v1 ); + vst1q_f64( p_loc + 4, v2 ); + vst1q_f64( p_loc + 6, v3 ); + p_loc += ldp; + a_loc += lda; // 1; + } + } + } + } + else // if ( cdim0 < mnr || gs ) + { + PASTEMAC(dscal2m,BLIS_TAPI_EX_SUF) + ( + 0, + BLIS_NONUNIT_DIAG, + BLIS_DENSE, + ( trans_t )conja, + cdim0, + k0, + kappa, + a, inca0, lda0, + p, 1, ldp0, + cntx, + NULL + ); + + if ( cdim0 < mnr ) + { + // Handle zero-filling along the "long" edge of the micropanel. + + const dim_t i = cdim0; + const dim_t m_edge = mnr - cdim0; + const dim_t n_edge = k0_max; + double* restrict p_edge = p + (i )*1; + + bli_dset0s_mxn + ( + m_edge, + n_edge, + p_edge, 1, ldp + ); + } + } + +//bli_dfprintm( stdout, "packm 8xk ker: a_packed", cdim0, k0_max, p, 1, ldp0, "%5.2f", "" ); + + if ( k0 < k0_max ) + { + // Handle zero-filling along the "short" (far) edge of the micropanel. + + const dim_t j = k0; + const dim_t m_edge = mnr; + const dim_t n_edge = k0_max - k0; + double* restrict p_edge = p + (j )*ldp; + + bli_dset0s_mxn + ( + m_edge, + n_edge, + p_edge, 1, ldp + ); + } +} + diff --git a/kernels/armv8a/1m/bli_packm_armv8a_int_s12xk.c b/kernels/armv8a/1m/bli_packm_armv8a_int_s12xk.c new file mode 100644 index 0000000000..3718772473 --- /dev/null +++ b/kernels/armv8a/1m/bli_packm_armv8a_int_s12xk.c @@ -0,0 +1,435 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Linaro Limited + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include + +#if defined(__clang__) +#define PRAGMA_NOUNROLL _Pragma("nounroll") +#define PRAGMA_UNROLL_2 _Pragma("unroll 2") +#elif defined(__GNUC__) +#define PRAGMA_NOUNROLL _Pragma("GCC unroll 1") +#define PRAGMA_UNROLL_2 _Pragma("GCC unroll 2") +#else +#define PRAGMA_NOUNROLL +#define PRAGMA_UNROLL_2 +#endif + +void bli_spackm_armv8a_int_12xk + ( + conj_t conja, + pack_t schema, + dim_t cdim0, + dim_t k0, + dim_t k0_max, + float* restrict kappa, + float* restrict a, inc_t inca0, inc_t lda0, + float* restrict p, inc_t ldp0, + cntx_t* restrict cntx + ) +{ + // This is the panel dimension assumed by the packm kernel. + const dim_t mnr = 12; + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + float* a_loc = a; + float* p_loc = p; + + // NOTE: For the purposes of the comments in this packm kernel, we + // interpret inca and lda as rs_a and cs_a, respectively, and similarly + // interpret ldp as cs_p (with rs_p implicitly unit). Thus, when reading + // this packm kernel, you should think of the operation as packing an + // m x n micropanel, where m and n are tiny and large, respectively, and + // where elements of each column of the packed matrix P are contiguous. + // (This packm kernel can still be used to pack micropanels of matrix B + // in a gemm operation.) + const uint64_t inca = inca0; + const uint64_t lda = lda0; + const uint64_t ldp = ldp0; + + const bool gs = ( inca0 != 1 && lda0 != 1 ); + + // NOTE: If/when this kernel ever supports scaling by kappa within the + // assembly region, this constraint should be lifted. + const bool unitk = bli_seq1( *kappa ); + + + // ------------------------------------------------------------------------- + + if ( cdim0 == mnr && !gs ) + { + if ( unitk ) + { + if ( inca == 1 ) + { + // No need to use k-loops here. + // Simply let compiler to expand loops. + PRAGMA_UNROLL_2 + for ( dim_t ik = k_iter * 4 + k_left; ik > 0; --ik ) + { + float32x4_t v0 = vld1q_f32( a_loc + 0 ); + float32x4_t v1 = vld1q_f32( a_loc + 4 ); + float32x4_t v2 = vld1q_f32( a_loc + 8 ); + + vst1q_f32( p_loc + 0, v0 ); + vst1q_f32( p_loc + 4, v1 ); + vst1q_f32( p_loc + 8, v2 ); + + a_loc += lda; + p_loc += ldp; + } + } + else // if ( lda == 1 ) + { + float32x4_t v0 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v1 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v2 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v3 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v4 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v5 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v6 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v7 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v8 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v9 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v10 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v11 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t vt0; + float32x4_t vt1; + float32x4_t vt2; + float32x4_t vt3; + + PRAGMA_NOUNROLL + for ( ; k_iter > 0; --k_iter ) + { + v0 = vld1q_f32( a_loc + inca * 0 ); + v1 = vld1q_f32( a_loc + inca * 1 ); + v2 = vld1q_f32( a_loc + inca * 2 ); + v3 = vld1q_f32( a_loc + inca * 3 ); + v4 = vld1q_f32( a_loc + inca * 4 ); + v5 = vld1q_f32( a_loc + inca * 5 ); + v6 = vld1q_f32( a_loc + inca * 6 ); + v7 = vld1q_f32( a_loc + inca * 7 ); + v8 = vld1q_f32( a_loc + inca * 8 ); + v9 = vld1q_f32( a_loc + inca * 9 ); + v10 = vld1q_f32( a_loc + inca * 10 ); + v11 = vld1q_f32( a_loc + inca * 11 ); + + // In-register transpose. + // + // Column 0-3 + vt0 = vtrn1q_f32( v0, v1 ); + vt1 = vtrn2q_f32( v0, v1 ); + vt2 = vtrn1q_f32( v2, v3 ); + vt3 = vtrn2q_f32( v2, v3 ); + v0 = (float32x4_t)vtrn1q_f64( (float64x2_t)vt0, (float64x2_t)vt2 ); + v1 = (float32x4_t)vtrn1q_f64( (float64x2_t)vt1, (float64x2_t)vt3 ); + v2 = (float32x4_t)vtrn2q_f64( (float64x2_t)vt0, (float64x2_t)vt2 ); + v3 = (float32x4_t)vtrn2q_f64( (float64x2_t)vt1, (float64x2_t)vt3 ); + // Column 4-7 + vt0 = vtrn1q_f32( v4, v5 ); + vt1 = vtrn2q_f32( v4, v5 ); + vt2 = vtrn1q_f32( v6, v7 ); + vt3 = vtrn2q_f32( v6, v7 ); + v4 = (float32x4_t)vtrn1q_f64( (float64x2_t)vt0, (float64x2_t)vt2 ); + v5 = (float32x4_t)vtrn1q_f64( (float64x2_t)vt1, (float64x2_t)vt3 ); + v6 = (float32x4_t)vtrn2q_f64( (float64x2_t)vt0, (float64x2_t)vt2 ); + v7 = (float32x4_t)vtrn2q_f64( (float64x2_t)vt1, (float64x2_t)vt3 ); + // Column 8-11 + vt0 = vtrn1q_f32( v8, v9 ); + vt1 = vtrn2q_f32( v8, v9 ); + vt2 = vtrn1q_f32( v10, v11 ); + vt3 = vtrn2q_f32( v10, v11 ); + v8 = (float32x4_t)vtrn1q_f64( (float64x2_t)vt0, (float64x2_t)vt2 ); + v9 = (float32x4_t)vtrn1q_f64( (float64x2_t)vt1, (float64x2_t)vt3 ); + v10 = (float32x4_t)vtrn2q_f64( (float64x2_t)vt0, (float64x2_t)vt2 ); + v11 = (float32x4_t)vtrn2q_f64( (float64x2_t)vt1, (float64x2_t)vt3 ); + + vst1q_f32( p_loc + 0, v0 ); + vst1q_f32( p_loc + 4, v4 ); + vst1q_f32( p_loc + 8, v8 ); + p_loc += ldp; + + vst1q_f32( p_loc + 0, v1 ); + vst1q_f32( p_loc + 4, v5 ); + vst1q_f32( p_loc + 8, v9 ); + p_loc += ldp; + + vst1q_f32( p_loc + 0, v2 ); + vst1q_f32( p_loc + 4, v6 ); + vst1q_f32( p_loc + 8, v10 ); + p_loc += ldp; + + vst1q_f32( p_loc + 0, v3 ); + vst1q_f32( p_loc + 4, v7 ); + vst1q_f32( p_loc + 8, v11 ); + p_loc += ldp; + a_loc += 4 * lda; // 4; + } + for ( ; k_left > 0; --k_left ) + { + v0 = vld1q_lane_f32( a_loc + inca * 0 , v0, 0 ); + v0 = vld1q_lane_f32( a_loc + inca * 1 , v0, 1 ); + v0 = vld1q_lane_f32( a_loc + inca * 2 , v0, 2 ); + v0 = vld1q_lane_f32( a_loc + inca * 3 , v0, 3 ); + v1 = vld1q_lane_f32( a_loc + inca * 4 , v1, 0 ); + v1 = vld1q_lane_f32( a_loc + inca * 5 , v1, 1 ); + v1 = vld1q_lane_f32( a_loc + inca * 6 , v1, 2 ); + v1 = vld1q_lane_f32( a_loc + inca * 7 , v1, 3 ); + v2 = vld1q_lane_f32( a_loc + inca * 8 , v2, 0 ); + v2 = vld1q_lane_f32( a_loc + inca * 9 , v2, 1 ); + v2 = vld1q_lane_f32( a_loc + inca * 10, v2, 2 ); + v2 = vld1q_lane_f32( a_loc + inca * 11, v2, 3 ); + + vst1q_f32( p_loc + 0, v0 ); + vst1q_f32( p_loc + 4, v1 ); + vst1q_f32( p_loc + 8, v2 ); + p_loc += ldp; + a_loc += lda; // 1; + } + } + } + else // if ( !unitk ) + { + float32x4_t vkappa = vld1q_dup_f32( kappa ); + + if ( inca == 1 ) + { + // No need to use k-loops here. + // Simply let compiler to expand loops. + PRAGMA_UNROLL_2 + for ( dim_t ik = k_iter * 4 + k_left; ik > 0; --ik ) + { + float32x4_t v0 = vld1q_f32( a_loc + 0 ); + float32x4_t v1 = vld1q_f32( a_loc + 4 ); + float32x4_t v2 = vld1q_f32( a_loc + 8 ); + + // Scale by kappa. + v0 = vmulq_f32( v0, vkappa ); + v1 = vmulq_f32( v1, vkappa ); + v2 = vmulq_f32( v2, vkappa ); + + vst1q_f32( p_loc + 0, v0 ); + vst1q_f32( p_loc + 4, v1 ); + vst1q_f32( p_loc + 8, v2 ); + + a_loc += lda; + p_loc += ldp; + } + } + else // if ( lda == 1 ) + { + float32x4_t v0 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v1 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v2 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v3 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v4 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v5 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v6 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v7 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v8 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v9 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v10 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v11 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t vt0; + float32x4_t vt1; + float32x4_t vt2; + float32x4_t vt3; + + PRAGMA_NOUNROLL + for ( ; k_iter > 0; --k_iter ) + { + v0 = vld1q_f32( a_loc + inca * 0 ); + v1 = vld1q_f32( a_loc + inca * 1 ); + v2 = vld1q_f32( a_loc + inca * 2 ); + v3 = vld1q_f32( a_loc + inca * 3 ); + v4 = vld1q_f32( a_loc + inca * 4 ); + v5 = vld1q_f32( a_loc + inca * 5 ); + v6 = vld1q_f32( a_loc + inca * 6 ); + v7 = vld1q_f32( a_loc + inca * 7 ); + v8 = vld1q_f32( a_loc + inca * 8 ); + v9 = vld1q_f32( a_loc + inca * 9 ); + v10 = vld1q_f32( a_loc + inca * 10 ); + v11 = vld1q_f32( a_loc + inca * 11 ); + + // Scale by kappa. + v0 = vmulq_f32( v0, vkappa ); + v1 = vmulq_f32( v1, vkappa ); + v2 = vmulq_f32( v2, vkappa ); + v3 = vmulq_f32( v3, vkappa ); + v4 = vmulq_f32( v4, vkappa ); + v5 = vmulq_f32( v5, vkappa ); + v6 = vmulq_f32( v6, vkappa ); + v7 = vmulq_f32( v7, vkappa ); + v8 = vmulq_f32( v8, vkappa ); + v9 = vmulq_f32( v9, vkappa ); + v10 = vmulq_f32( v10, vkappa ); + v11 = vmulq_f32( v11, vkappa ); + + // In-register transpose. + // + // Column 0-3 + vt0 = vtrn1q_f32( v0, v1 ); + vt1 = vtrn2q_f32( v0, v1 ); + vt2 = vtrn1q_f32( v2, v3 ); + vt3 = vtrn2q_f32( v2, v3 ); + v0 = (float32x4_t)vtrn1q_f64( (float64x2_t)vt0, (float64x2_t)vt2 ); + v1 = (float32x4_t)vtrn1q_f64( (float64x2_t)vt1, (float64x2_t)vt3 ); + v2 = (float32x4_t)vtrn2q_f64( (float64x2_t)vt0, (float64x2_t)vt2 ); + v3 = (float32x4_t)vtrn2q_f64( (float64x2_t)vt1, (float64x2_t)vt3 ); + // Column 4-7 + vt0 = vtrn1q_f32( v4, v5 ); + vt1 = vtrn2q_f32( v4, v5 ); + vt2 = vtrn1q_f32( v6, v7 ); + vt3 = vtrn2q_f32( v6, v7 ); + v4 = (float32x4_t)vtrn1q_f64( (float64x2_t)vt0, (float64x2_t)vt2 ); + v5 = (float32x4_t)vtrn1q_f64( (float64x2_t)vt1, (float64x2_t)vt3 ); + v6 = (float32x4_t)vtrn2q_f64( (float64x2_t)vt0, (float64x2_t)vt2 ); + v7 = (float32x4_t)vtrn2q_f64( (float64x2_t)vt1, (float64x2_t)vt3 ); + // Column 8-11 + vt0 = vtrn1q_f32( v8, v9 ); + vt1 = vtrn2q_f32( v8, v9 ); + vt2 = vtrn1q_f32( v10, v11 ); + vt3 = vtrn2q_f32( v10, v11 ); + v8 = (float32x4_t)vtrn1q_f64( (float64x2_t)vt0, (float64x2_t)vt2 ); + v9 = (float32x4_t)vtrn1q_f64( (float64x2_t)vt1, (float64x2_t)vt3 ); + v10 = (float32x4_t)vtrn2q_f64( (float64x2_t)vt0, (float64x2_t)vt2 ); + v11 = (float32x4_t)vtrn2q_f64( (float64x2_t)vt1, (float64x2_t)vt3 ); + + vst1q_f32( p_loc + 0, v0 ); + vst1q_f32( p_loc + 4, v4 ); + vst1q_f32( p_loc + 8, v8 ); + p_loc += ldp; + + vst1q_f32( p_loc + 0, v1 ); + vst1q_f32( p_loc + 4, v5 ); + vst1q_f32( p_loc + 8, v9 ); + p_loc += ldp; + + vst1q_f32( p_loc + 0, v2 ); + vst1q_f32( p_loc + 4, v6 ); + vst1q_f32( p_loc + 8, v10 ); + p_loc += ldp; + + vst1q_f32( p_loc + 0, v3 ); + vst1q_f32( p_loc + 4, v7 ); + vst1q_f32( p_loc + 8, v11 ); + p_loc += ldp; + a_loc += 4 * lda; // 4; + } + for ( ; k_left > 0; --k_left ) + { + v0 = vld1q_lane_f32( a_loc + inca * 0 , v0, 0 ); + v0 = vld1q_lane_f32( a_loc + inca * 1 , v0, 1 ); + v0 = vld1q_lane_f32( a_loc + inca * 2 , v0, 2 ); + v0 = vld1q_lane_f32( a_loc + inca * 3 , v0, 3 ); + v1 = vld1q_lane_f32( a_loc + inca * 4 , v1, 0 ); + v1 = vld1q_lane_f32( a_loc + inca * 5 , v1, 1 ); + v1 = vld1q_lane_f32( a_loc + inca * 6 , v1, 2 ); + v1 = vld1q_lane_f32( a_loc + inca * 7 , v1, 3 ); + v2 = vld1q_lane_f32( a_loc + inca * 8 , v2, 0 ); + v2 = vld1q_lane_f32( a_loc + inca * 9 , v2, 1 ); + v2 = vld1q_lane_f32( a_loc + inca * 10, v2, 2 ); + v2 = vld1q_lane_f32( a_loc + inca * 11, v2, 3 ); + + // Scale by kappa. + v0 = vmulq_f32( v0, vkappa ); + v1 = vmulq_f32( v1, vkappa ); + v2 = vmulq_f32( v2, vkappa ); + + vst1q_f32( p_loc + 0, v0 ); + vst1q_f32( p_loc + 4, v1 ); + vst1q_f32( p_loc + 8, v2 ); + p_loc += ldp; + a_loc += lda; // 1; + } + } + } + } + else // if ( cdim0 < mnr || gs ) + { + PASTEMAC(sscal2m,BLIS_TAPI_EX_SUF) + ( + 0, + BLIS_NONUNIT_DIAG, + BLIS_DENSE, + ( trans_t )conja, + cdim0, + k0, + kappa, + a, inca0, lda0, + p, 1, ldp0, + cntx, + NULL + ); + + if ( cdim0 < mnr ) + { + // Handle zero-filling along the "long" edge of the micropanel. + + const dim_t i = cdim0; + const dim_t m_edge = mnr - cdim0; + const dim_t n_edge = k0_max; + float* restrict p_edge = p + (i )*1; + + bli_sset0s_mxn + ( + m_edge, + n_edge, + p_edge, 1, ldp + ); + } + } + + if ( k0 < k0_max ) + { + // Handle zero-filling along the "short" (far) edge of the micropanel. + + const dim_t j = k0; + const dim_t m_edge = mnr; + const dim_t n_edge = k0_max - k0; + float* restrict p_edge = p + (j )*ldp; + + bli_sset0s_mxn + ( + m_edge, + n_edge, + p_edge, 1, ldp + ); + } +} + diff --git a/kernels/armv8a/1m/bli_packm_armv8a_int_s8xk.c b/kernels/armv8a/1m/bli_packm_armv8a_int_s8xk.c new file mode 100644 index 0000000000..3d363c2d8d --- /dev/null +++ b/kernels/armv8a/1m/bli_packm_armv8a_int_s8xk.c @@ -0,0 +1,373 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Linaro Limited + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include + +#if defined(__clang__) +#define PRAGMA_NOUNROLL _Pragma("nounroll") +#define PRAGMA_UNROLL_4 _Pragma("unroll 4") +#elif defined(__GNUC__) +#define PRAGMA_NOUNROLL _Pragma("GCC unroll 1") +#define PRAGMA_UNROLL_4 _Pragma("GCC unroll 4") +#else +#define PRAGMA_NOUNROLL +#define PRAGMA_UNROLL_4 +#endif + +void bli_spackm_armv8a_int_8xk + ( + conj_t conja, + pack_t schema, + dim_t cdim0, + dim_t k0, + dim_t k0_max, + float* restrict kappa, + float* restrict a, inc_t inca0, inc_t lda0, + float* restrict p, inc_t ldp0, + cntx_t* restrict cntx + ) +{ + // This is the panel dimension assumed by the packm kernel. + const dim_t mnr = 8; + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + float* a_loc = a; + float* p_loc = p; + + // NOTE: For the purposes of the comments in this packm kernel, we + // interpret inca and lda as rs_a and cs_a, respectively, and similarly + // interpret ldp as cs_p (with rs_p implicitly unit). Thus, when reading + // this packm kernel, you should think of the operation as packing an + // m x n micropanel, where m and n are tiny and large, respectively, and + // where elements of each column of the packed matrix P are contiguous. + // (This packm kernel can still be used to pack micropanels of matrix B + // in a gemm operation.) + const uint64_t inca = inca0; + const uint64_t lda = lda0; + const uint64_t ldp = ldp0; + + const bool gs = ( inca0 != 1 && lda0 != 1 ); + + // NOTE: If/when this kernel ever supports scaling by kappa within the + // assembly region, this constraint should be lifted. + const bool unitk = bli_seq1( *kappa ); + + + // ------------------------------------------------------------------------- + + if ( cdim0 == mnr && !gs ) + { + if ( unitk ) + { + if ( inca == 1 ) + { + // No need to use k-loops here. + // Simply let compiler to expand loops. + PRAGMA_UNROLL_4 + for ( dim_t ik = k_iter * 4 + k_left; ik > 0; --ik ) + { + float32x4_t v0 = vld1q_f32( a_loc + 0 ); + float32x4_t v1 = vld1q_f32( a_loc + 4 ); + + vst1q_f32( p_loc + 0, v0 ); + vst1q_f32( p_loc + 4, v1 ); + + a_loc += lda; + p_loc += ldp; + } + } + else // if ( lda == 1 ) + { + float32x4_t v0 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v1 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v2 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v3 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v4 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v5 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v6 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v7 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t vt0; + float32x4_t vt1; + float32x4_t vt2; + float32x4_t vt3; + + PRAGMA_NOUNROLL + for ( ; k_iter > 0; --k_iter ) + { + v0 = vld1q_f32( a_loc + inca * 0 ); + v1 = vld1q_f32( a_loc + inca * 1 ); + v2 = vld1q_f32( a_loc + inca * 2 ); + v3 = vld1q_f32( a_loc + inca * 3 ); + v4 = vld1q_f32( a_loc + inca * 4 ); + v5 = vld1q_f32( a_loc + inca * 5 ); + v6 = vld1q_f32( a_loc + inca * 6 ); + v7 = vld1q_f32( a_loc + inca * 7 ); + + // In-register transpose. + // + // Column 0-3 + vt0 = vtrn1q_f32( v0, v1 ); + vt1 = vtrn2q_f32( v0, v1 ); + vt2 = vtrn1q_f32( v2, v3 ); + vt3 = vtrn2q_f32( v2, v3 ); + v0 = (float32x4_t)vtrn1q_f64( (float64x2_t)vt0, (float64x2_t)vt2 ); + v1 = (float32x4_t)vtrn1q_f64( (float64x2_t)vt1, (float64x2_t)vt3 ); + v2 = (float32x4_t)vtrn2q_f64( (float64x2_t)vt0, (float64x2_t)vt2 ); + v3 = (float32x4_t)vtrn2q_f64( (float64x2_t)vt1, (float64x2_t)vt3 ); + // Column 4-7 + vt0 = vtrn1q_f32( v4, v5 ); + vt1 = vtrn2q_f32( v4, v5 ); + vt2 = vtrn1q_f32( v6, v7 ); + vt3 = vtrn2q_f32( v6, v7 ); + v4 = (float32x4_t)vtrn1q_f64( (float64x2_t)vt0, (float64x2_t)vt2 ); + v5 = (float32x4_t)vtrn1q_f64( (float64x2_t)vt1, (float64x2_t)vt3 ); + v6 = (float32x4_t)vtrn2q_f64( (float64x2_t)vt0, (float64x2_t)vt2 ); + v7 = (float32x4_t)vtrn2q_f64( (float64x2_t)vt1, (float64x2_t)vt3 ); + + vst1q_f32( p_loc + 0, v0 ); + vst1q_f32( p_loc + 4, v4 ); + p_loc += ldp; + + vst1q_f32( p_loc + 0, v1 ); + vst1q_f32( p_loc + 4, v5 ); + p_loc += ldp; + + vst1q_f32( p_loc + 0, v2 ); + vst1q_f32( p_loc + 4, v6 ); + p_loc += ldp; + + vst1q_f32( p_loc + 0, v3 ); + vst1q_f32( p_loc + 4, v7 ); + p_loc += ldp; + a_loc += 4 * lda; // 4; + } + for ( ; k_left > 0; --k_left ) + { + v0 = vld1q_lane_f32( a_loc + inca * 0 , v0, 0 ); + v0 = vld1q_lane_f32( a_loc + inca * 1 , v0, 1 ); + v0 = vld1q_lane_f32( a_loc + inca * 2 , v0, 2 ); + v0 = vld1q_lane_f32( a_loc + inca * 3 , v0, 3 ); + v1 = vld1q_lane_f32( a_loc + inca * 4 , v1, 0 ); + v1 = vld1q_lane_f32( a_loc + inca * 5 , v1, 1 ); + v1 = vld1q_lane_f32( a_loc + inca * 6 , v1, 2 ); + v1 = vld1q_lane_f32( a_loc + inca * 7 , v1, 3 ); + + vst1q_f32( p_loc + 0, v0 ); + vst1q_f32( p_loc + 4, v1 ); + p_loc += ldp; + a_loc += lda; // 1; + } + } + } + else // if ( !unitk ) + { + float32x4_t vkappa = vld1q_dup_f32( kappa ); + + if ( inca == 1 ) + { + // No need to use k-loops here. + // Simply let compiler to expand loops. + PRAGMA_UNROLL_4 + for ( dim_t ik = k_iter * 4 + k_left; ik > 0; --ik ) + { + float32x4_t v0 = vld1q_f32( a_loc + 0 ); + float32x4_t v1 = vld1q_f32( a_loc + 4 ); + + // Scale by kappa. + v0 = vmulq_f32( v0, vkappa ); + v1 = vmulq_f32( v1, vkappa ); + + vst1q_f32( p_loc + 0, v0 ); + vst1q_f32( p_loc + 4, v1 ); + + a_loc += lda; + p_loc += ldp; + } + } + else // if ( lda == 1 ) + { + float32x4_t v0 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v1 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v2 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v3 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v4 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v5 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v6 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t v7 = (float32x4_t)vdupq_n_u32( 0 ); + float32x4_t vt0; + float32x4_t vt1; + float32x4_t vt2; + float32x4_t vt3; + + PRAGMA_NOUNROLL + for ( ; k_iter > 0; --k_iter ) + { + v0 = vld1q_f32( a_loc + inca * 0 ); + v1 = vld1q_f32( a_loc + inca * 1 ); + v2 = vld1q_f32( a_loc + inca * 2 ); + v3 = vld1q_f32( a_loc + inca * 3 ); + v4 = vld1q_f32( a_loc + inca * 4 ); + v5 = vld1q_f32( a_loc + inca * 5 ); + v6 = vld1q_f32( a_loc + inca * 6 ); + v7 = vld1q_f32( a_loc + inca * 7 ); + + // Scale by kappa. + v0 = vmulq_f32( v0, vkappa ); + v1 = vmulq_f32( v1, vkappa ); + v2 = vmulq_f32( v2, vkappa ); + v3 = vmulq_f32( v3, vkappa ); + v4 = vmulq_f32( v4, vkappa ); + v5 = vmulq_f32( v5, vkappa ); + v6 = vmulq_f32( v6, vkappa ); + v7 = vmulq_f32( v7, vkappa ); + + // In-register transpose. + // + // Column 0-3 + vt0 = vtrn1q_f32( v0, v1 ); + vt1 = vtrn2q_f32( v0, v1 ); + vt2 = vtrn1q_f32( v2, v3 ); + vt3 = vtrn2q_f32( v2, v3 ); + v0 = (float32x4_t)vtrn1q_f64( (float64x2_t)vt0, (float64x2_t)vt2 ); + v1 = (float32x4_t)vtrn1q_f64( (float64x2_t)vt1, (float64x2_t)vt3 ); + v2 = (float32x4_t)vtrn2q_f64( (float64x2_t)vt0, (float64x2_t)vt2 ); + v3 = (float32x4_t)vtrn2q_f64( (float64x2_t)vt1, (float64x2_t)vt3 ); + // Column 4-7 + vt0 = vtrn1q_f32( v4, v5 ); + vt1 = vtrn2q_f32( v4, v5 ); + vt2 = vtrn1q_f32( v6, v7 ); + vt3 = vtrn2q_f32( v6, v7 ); + v4 = (float32x4_t)vtrn1q_f64( (float64x2_t)vt0, (float64x2_t)vt2 ); + v5 = (float32x4_t)vtrn1q_f64( (float64x2_t)vt1, (float64x2_t)vt3 ); + v6 = (float32x4_t)vtrn2q_f64( (float64x2_t)vt0, (float64x2_t)vt2 ); + v7 = (float32x4_t)vtrn2q_f64( (float64x2_t)vt1, (float64x2_t)vt3 ); + + vst1q_f32( p_loc + 0, v0 ); + vst1q_f32( p_loc + 4, v4 ); + p_loc += ldp; + + vst1q_f32( p_loc + 0, v1 ); + vst1q_f32( p_loc + 4, v5 ); + p_loc += ldp; + + vst1q_f32( p_loc + 0, v2 ); + vst1q_f32( p_loc + 4, v6 ); + p_loc += ldp; + + vst1q_f32( p_loc + 0, v3 ); + vst1q_f32( p_loc + 4, v7 ); + p_loc += ldp; + a_loc += 4 * lda; // 4; + } + for ( ; k_left > 0; --k_left ) + { + v0 = vld1q_lane_f32( a_loc + inca * 0 , v0, 0 ); + v0 = vld1q_lane_f32( a_loc + inca * 1 , v0, 1 ); + v0 = vld1q_lane_f32( a_loc + inca * 2 , v0, 2 ); + v0 = vld1q_lane_f32( a_loc + inca * 3 , v0, 3 ); + v1 = vld1q_lane_f32( a_loc + inca * 4 , v1, 0 ); + v1 = vld1q_lane_f32( a_loc + inca * 5 , v1, 1 ); + v1 = vld1q_lane_f32( a_loc + inca * 6 , v1, 2 ); + v1 = vld1q_lane_f32( a_loc + inca * 7 , v1, 3 ); + + // Scale by kappa. + v0 = vmulq_f32( v0, vkappa ); + v1 = vmulq_f32( v1, vkappa ); + + vst1q_f32( p_loc + 0, v0 ); + vst1q_f32( p_loc + 4, v1 ); + p_loc += ldp; + a_loc += lda; // 1; + } + } + } + } + else // if ( cdim0 < mnr || gs ) + { + PASTEMAC(sscal2m,BLIS_TAPI_EX_SUF) + ( + 0, + BLIS_NONUNIT_DIAG, + BLIS_DENSE, + ( trans_t )conja, + cdim0, + k0, + kappa, + a, inca0, lda0, + p, 1, ldp0, + cntx, + NULL + ); + + if ( cdim0 < mnr ) + { + // Handle zero-filling along the "long" edge of the micropanel. + + const dim_t i = cdim0; + const dim_t m_edge = mnr - cdim0; + const dim_t n_edge = k0_max; + float* restrict p_edge = p + (i )*1; + + bli_sset0s_mxn + ( + m_edge, + n_edge, + p_edge, 1, ldp + ); + } + } + + if ( k0 < k0_max ) + { + // Handle zero-filling along the "short" (far) edge of the micropanel. + + const dim_t j = k0; + const dim_t m_edge = mnr; + const dim_t n_edge = k0_max - k0; + float* restrict p_edge = p + (j )*ldp; + + bli_sset0s_mxn + ( + m_edge, + n_edge, + p_edge, 1, ldp + ); + } +} + diff --git a/frame/1m/packm/bli_packm_cxk_rih.h b/kernels/armv8a/3/armv8a_asm_d2x2.h similarity index 76% rename from frame/1m/packm/bli_packm_cxk_rih.h rename to kernels/armv8a/3/armv8a_asm_d2x2.h index c1d2ba9fe3..5bb0bb4d39 100644 --- a/frame/1m/packm/bli_packm_cxk_rih.h +++ b/kernels/armv8a/3/armv8a_asm_d2x2.h @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2021, The University of Tokyo Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -30,25 +31,25 @@ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -*/ +*/ -#undef GENTPROTCO -#define GENTPROTCO( ctype, ctype_r, ch, chr, varname ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - conj_t conja, \ - pack_t schema, \ - dim_t panel_dim, \ - dim_t panel_dim_max, \ - dim_t panel_len, \ - dim_t panel_len_max, \ - ctype* kappa, \ - ctype* a, inc_t inca, inc_t lda, \ - ctype* p, inc_t ldp, \ - cntx_t* cntx \ - ); - -INSERT_GENTPROTCO_BASIC0( packm_cxk_rih ) +/* C A B + * || <- | * -- + * || | + * + * or: + * C B * A + * -- <- | -- + * -- | + */ +#define DGEMM_2X2_NANOKERNEL(C0,C1,A,B) \ +" fmla v"#C0".2d, v"#A".2d, v"#B".d[0] \n\t" \ +" fmla v"#C1".2d, v"#A".2d, v"#B".d[1] \n\t" + +#define SGEMM_4X4_NANOKERNEL(C0,C1,C2,C3,A,B) \ +" fmla v"#C0".4s, v"#A".4s, v"#B".s[0] \n\t" \ +" fmla v"#C1".4s, v"#A".4s, v"#B".s[1] \n\t" \ +" fmla v"#C2".4s, v"#A".4s, v"#B".s[2] \n\t" \ +" fmla v"#C3".4s, v"#A".4s, v"#B".s[3] \n\t" diff --git a/kernels/armv8a/3/armv8a_asm_utils.h b/kernels/armv8a/3/armv8a_asm_utils.h index 7bf97d555c..5cb0bad69c 100644 --- a/kernels/armv8a/3/armv8a_asm_utils.h +++ b/kernels/armv8a/3/armv8a_asm_utils.h @@ -47,3 +47,73 @@ #define BRANCH(str) "b ." #str" \n\t" #endif +// Clear vectors. +#define CLEAR1V(V) \ +" dup v"#V".2d, xzr \n\t" +#define CLEAR2V(V0,V1) \ + CLEAR1V(V0) \ + CLEAR1V(V1) +#define CLEAR4V(V0,V1,V2,V3) \ + CLEAR2V(V0,V1) \ + CLEAR2V(V2,V3) +#define CLEAR8V(V0,V1,V2,V3,V4,V5,V6,V7) \ + CLEAR4V(V0,V1,V2,V3) \ + CLEAR4V(V4,V5,V6,V7) + +// Scale vectors. +#define DSCALE1V(V,A,IDX) \ +" fmul v"#V".2d, v"#V".2d, v"#A".d["#IDX"] \n\t" +#define DSCALE2V(V0,V1,A,IDX) \ + DSCALE1V(V0,A,IDX) \ + DSCALE1V(V1,A,IDX) +#define DSCALE4V(V0,V1,V2,V3,A,IDX) \ + DSCALE2V(V0,V1,A,IDX) \ + DSCALE2V(V2,V3,A,IDX) +#define DSCALE8V(V0,V1,V2,V3,V4,V5,V6,V7,A,IDX) \ + DSCALE4V(V0,V1,V2,V3,A,IDX) \ + DSCALE4V(V4,V5,V6,V7,A,IDX) + +// Scale-accumulate. +#define DSCALEA1V(D,S,A,IDX) \ +" fmla v"#D".2d, v"#S".2d, v"#A".d["#IDX"] \n\t" +#define DSCALEA2V(D0,D1,S0,S1,A,IDX) \ + DSCALEA1V(D0,S0,A,IDX) \ + DSCALEA1V(D1,S1,A,IDX) +#define DSCALEA4V(D0,D1,D2,D3,S0,S1,S2,S3,A,IDX) \ + DSCALEA2V(D0,D1,S0,S1,A,IDX) \ + DSCALEA2V(D2,D3,S2,S3,A,IDX) +#define DSCALEA8V(D0,D1,D2,D3,D4,D5,D6,D7,S0,S1,S2,S3,S4,S5,S6,S7,A,IDX) \ + DSCALEA4V(D0,D1,D2,D3,S0,S1,S2,S3,A,IDX) \ + DSCALEA4V(D4,D5,D6,D7,S4,S5,S6,S7,A,IDX) + +// Load one line. +#define DLOAD1V(V,ADDR,SHIFT) \ +" ldr q"#V", ["#ADDR", #"#SHIFT"] \n\t" +#define DLOAD2V(V0,V1,ADDR,SHIFT) \ + DLOAD1V(V0,ADDR,SHIFT) \ + DLOAD1V(V1,ADDR,SHIFT+16) +#define DLOAD4V(V0,V1,V2,V3,ADDR,SHIFT) \ + DLOAD2V(V0,V1,ADDR,SHIFT) \ + DLOAD2V(V2,V3,ADDR,SHIFT+32) + +// Generic: load one line. +#define DLOAD1V_GATHER_ELMFWD(V,ADDR,INC) \ +" ld1 {v"#V".d}[0], ["#ADDR"], "#INC" \n\t" \ +" ld1 {v"#V".d}[1], ["#ADDR"], "#INC" \n\t" + +// Store one line. +#define DSTORE1V(V,ADDR,SHIFT) \ +" str q"#V", ["#ADDR", #"#SHIFT"] \n\t" +#define DSTORE2V(V0,V1,ADDR,SHIFT) \ + DSTORE1V(V0,ADDR,SHIFT) \ + DSTORE1V(V1,ADDR,SHIFT+16) +#define DSTORE4V(V0,V1,V2,V3,ADDR,SHIFT) \ + DSTORE2V(V0,V1,ADDR,SHIFT) \ + DSTORE2V(V2,V3,ADDR,SHIFT+32) + +// Generic: store one line. +#define DSTORE1V_SCATTER_ELMFWD(V,ADDR,INC) \ +" st1 {v"#V".d}[0], ["#ADDR"], "#INC" \n\t" \ +" st1 {v"#V".d}[1], ["#ADDR"], "#INC" \n\t" + + diff --git a/kernels/armv8a/3/old/bli_gemm_armv8a_asm_d4x4.c b/kernels/armv8a/3/old/bli_gemm_armv8a_asm_d4x4.c new file mode 100644 index 0000000000..0dbfbcf6b1 --- /dev/null +++ b/kernels/armv8a/3/old/bli_gemm_armv8a_asm_d4x4.c @@ -0,0 +1,265 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2021, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ + +#include "blis.h" +#include "assert.h" + +// Label locality & misc. +#include "armv8a_asm_utils.h" + +// Nanokernel operations. +#include "armv8a_asm_d2x2.h" + +#define DGEMM_4X4_MKER_LOOP_PLAIN(C00,C10,C01,C11,C02,C12,C03,C13,A0,A1,B0,B1) \ + DGEMM_2X2_NANOKERNEL(C00,C01,A0,B0) \ + DGEMM_2X2_NANOKERNEL(C10,C11,A1,B0) \ + DGEMM_2X2_NANOKERNEL(C02,C03,A0,B1) \ + DGEMM_2X2_NANOKERNEL(C12,C13,A1,B1) + +// For contiguous storage of C. +#define DLOADC_2V_C_FWD(C0,C1,CADDR,CSHIFT,LDC) \ + DLOAD2V(C0,C1,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#LDC" \n\t" +#define DSTOREC_2V_C_FWD(C0,C1,CADDR,CSHIFT,LDC) \ + DSTORE2V(C0,C1,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#LDC" \n\t" + +void bli_dgemm_armv8a_asm_4x4 + ( + dim_t k0, + double* restrict alpha, + double* restrict a, + double* restrict b, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + void* a_next = bli_auxinfo_next_a( data ); + void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_mker = k0 / 6; + uint64_t k_left = k0 % 6; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + __asm__ volatile + ( +" ldr x0, %[a] \n\t" +" ldr x1, %[b] \n\t" +" mov x2, #4 \n\t" // Column-skip of A. +" mov x3, #4 \n\t" // Row-skip of B. +" \n\t" +" ldr x5, %[c] \n\t" +" ldr x6, %[rs_c] \n\t" // Row-skip of C. +" ldr x7, %[cs_c] \n\t" // Column-skip of C. +" \n\t" +" mov x8, #8 \n\t" // Multiply some address skips by sizeof(double). +" madd x2, x8, x2, xzr \n\t" // cs_a +" madd x3, x8, x3, xzr \n\t" // rs_b +" madd x7, x8, x7, xzr \n\t" // cs_c +" \n\t" +" ldr x4, %[k_mker] \n\t" // Number of loops. +" ldr x8, %[k_left] \n\t" +" \n\t" +// Storage scheme: +// V[ 0:7 ] <- C +// V[ 8:19] <- B +// V[20:31] <- A +// Under this scheme, the following is defined: +#define DGEMM_4X4_MKER_LOOP_PLAIN_LOC(A0,A1,B0,B1) \ + DGEMM_4X4_MKER_LOOP_PLAIN(0,1,2,3,4,5,6,7,A0,A1,B0,B1) +// TODO: Prefetch C. +// Load from memory. +LABEL(LOAD_ABC) +" \n\t" // No-microkernel early return is a must +" cmp x4, #0 \n\t" // to avoid out-of-boundary read. +BEQ(CLEAR_CCOLS) +" \n\t" +" ldr q20, [x0, #16*0] \n\t" +" ldr q21, [x0, #16*1] \n\t" +" ldr q22, [x0, #16*2] \n\t" +" ldr q23, [x0, #16*3] \n\t" +" ldr q24, [x0, #16*4] \n\t" +" ldr q25, [x0, #16*5] \n\t" +" add x0, x0, x2 \n\t" +" add x0, x0, x2 \n\t" +" add x0, x0, x2 \n\t" +" ldr q26, [x0, #16*0] \n\t" +" ldr q27, [x0, #16*1] \n\t" +" ldr q28, [x0, #16*2] \n\t" +" ldr q29, [x0, #16*3] \n\t" +" ldr q30, [x0, #16*4] \n\t" +" ldr q31, [x0, #16*5] \n\t" +" add x0, x0, x2 \n\t" +" add x0, x0, x2 \n\t" +" add x0, x0, x2 \n\t" +" \n\t" +" ldr q8, [x1, #16*0] \n\t" +" ldr q9, [x1, #16*1] \n\t" +" ldr q10, [x1, #16*2] \n\t" +" ldr q11, [x1, #16*3] \n\t" +" ldr q12, [x1, #16*4] \n\t" +" ldr q13, [x1, #16*5] \n\t" +" add x1, x1, x3 \n\t" +" add x1, x1, x3 \n\t" +" add x1, x1, x3 \n\t" +" ldr q14, [x1, #16*0] \n\t" +" ldr q15, [x1, #16*1] \n\t" +" ldr q16, [x1, #16*2] \n\t" +" ldr q17, [x1, #16*3] \n\t" +" ldr q18, [x1, #16*4] \n\t" +" ldr q19, [x1, #16*5] \n\t" +" add x1, x1, x3 \n\t" +" add x1, x1, x3 \n\t" +" add x1, x1, x3 \n\t" +" \n\t" +LABEL(CLEAR_CCOLS) +CLEAR8V(0,1,2,3,4,5,6,7) +// No-microkernel early return, once again. +BEQ(K_LEFT_LOOP) +// +// Microkernel is defined here as: +#define DGEMM_4X4_MKER_LOOP_PLAIN_LOC_FWD(A0,A1,B0,B1) \ + DGEMM_4X4_MKER_LOOP_PLAIN_LOC(A0,A1,B0,B1) \ + "ldr q"#A0", [x0, #16*0] \n\t" \ + "ldr q"#A1", [x0, #16*1] \n\t" \ + "add x0, x0, x2 \n\t" \ + "ldr q"#B0", [x1, #16*0] \n\t" \ + "ldr q"#B1", [x1, #16*1] \n\t" \ + "add x1, x1, x3 \n\t" +// Start microkernel loop. +LABEL(K_MKER_LOOP) +" \n\t" // Decrease counter before final replica. +" subs x4, x4, #1 \n\t" // Branch early to avoid reading excess mem. +BEQ(FIN_MKER_LOOP) +DGEMM_4X4_MKER_LOOP_PLAIN_LOC_FWD(20,21,8,9) +DGEMM_4X4_MKER_LOOP_PLAIN_LOC_FWD(22,23,10,11) +DGEMM_4X4_MKER_LOOP_PLAIN_LOC_FWD(24,25,12,13) +DGEMM_4X4_MKER_LOOP_PLAIN_LOC_FWD(26,27,14,15) +DGEMM_4X4_MKER_LOOP_PLAIN_LOC_FWD(28,29,16,17) +DGEMM_4X4_MKER_LOOP_PLAIN_LOC_FWD(30,31,18,19) +BRANCH(K_MKER_LOOP) +// +// Final microkernel loop. +LABEL(FIN_MKER_LOOP) +DGEMM_4X4_MKER_LOOP_PLAIN_LOC(20,21,8,9) +DGEMM_4X4_MKER_LOOP_PLAIN_LOC(22,23,10,11) +DGEMM_4X4_MKER_LOOP_PLAIN_LOC(24,25,12,13) +DGEMM_4X4_MKER_LOOP_PLAIN_LOC(26,27,14,15) +DGEMM_4X4_MKER_LOOP_PLAIN_LOC(28,29,16,17) +DGEMM_4X4_MKER_LOOP_PLAIN_LOC(30,31,18,19) +// +// Loops left behind microkernels. +LABEL(K_LEFT_LOOP) +" cmp x8, #0 \n\t" // End of exec. +BEQ(WRITE_MEM_PREP) +" ldr q20, [x0, #16*0] \n\t" +" ldr q21, [x0, #16*1] \n\t" +" add x0, x0, x2 \n\t" +" ldr q8, [x1, #16*0] \n\t" +" ldr q9, [x1, #16*1] \n\t" +" add x1, x1, x3 \n\t" +" sub x8, x8, #1 \n\t" +DGEMM_4X4_MKER_LOOP_PLAIN_LOC(20,21,8,9) +BRANCH(K_LEFT_LOOP) +// +// Scale and write to memory. +LABEL(WRITE_MEM_PREP) +" ldr x4, %[alpha] \n\t" // Load alpha & beta (address). +" ldr x8, %[beta] \n\t" +" ldr d8, [x4] \n\t" // Load alpha & beta (value). +" ldr d9, [x8] \n\t" +" \n\t" +LABEL(PREFETCH_ABNEXT) +" ldr x0, %[a_next] \n\t" +" ldr x1, %[b_next] \n\t" +" prfm PLDL1STRM, [x0, 64*0] \n\t" // Do not know cache line size, +" prfm PLDL1STRM, [x0, 64*1] \n\t" // issue some number of prfm instructions +" prfm PLDL1STRM, [x0, 64*2] \n\t" // to try to activate hardware prefetcher. +" prfm PLDL1STRM, [x1, 64*0] \n\t" +" prfm PLDL1STRM, [x1, 64*1] \n\t" +" prfm PLDL1STRM, [x1, 64*3] \n\t" +" \n\t" +" mov x9, x5 \n\t" // C address for loading. +" \n\t" // C address for storing is x5 itself. +" cmp x6, #1 \n\t" // Check for generic storage. +BNE(WRITE_MEM_G) +// +// Contiguous C-storage. +LABEL(WRITE_MEM_C) +DLOADC_2V_C_FWD(10,11,x9,0,x7) +DLOADC_2V_C_FWD(12,13,x9,0,x7) +DLOADC_2V_C_FWD(14,15,x9,0,x7) +DLOADC_2V_C_FWD(16,17,x9,0,x7) +DSCALE8V(10,11,12,13,14,15,16,17,9,0) +DSCALEA8V(10,11,12,13,14,15,16,17,0,1,2,3,4,5,6,7,8,0) +DSTOREC_2V_C_FWD(10,11,x5,0,x7) +DSTOREC_2V_C_FWD(12,13,x5,0,x7) +DSTOREC_2V_C_FWD(14,15,x5,0,x7) +DSTOREC_2V_C_FWD(16,17,x5,0,x7) +BRANCH(END_WRITE_MEM) +// +// Generic-strided C-storage. +LABEL(WRITE_MEM_G) +// TODO: Implement. +LABEL(END_WRITE_MEM) +: +: [a] "m" (a), + [b] "m" (b), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + [k_mker] "m" (k_mker), + [k_left] "m" (k_left), + [alpha] "m" (alpha), + [beta] "m" (beta), + [a_next] "m" (a_next), + [b_next] "m" (b_next) +: "x0","x1","x2","x3","x4","x5","x6","x7","x8", + "x9","x16", + "v0","v1","v2","v3","v4","v5","v6","v7", + "v8","v9","v10","v11","v12","v13","v14","v15", + "v16","v17","v18","v19", + "v20","v21","v22","v23", + "v24","v25","v26","v27", + "v28","v29","v30","v31" + ); + +} diff --git a/kernels/armv8a/3/old/bli_gemm_armv8a_asm_d6x8r.c b/kernels/armv8a/3/old/bli_gemm_armv8a_asm_d6x8r.c new file mode 100644 index 0000000000..2fe83438f5 --- /dev/null +++ b/kernels/armv8a/3/old/bli_gemm_armv8a_asm_d6x8r.c @@ -0,0 +1,356 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2021, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ + +#include "blis.h" + +// Label locality & misc. +#include "armv8a_asm_utils.h" + +// Nanokernel operations. +#include "armv8a_asm_d2x2.h" + +/* Order of row-major DGEMM_6x8's execution in 2x2 blocks: + * + * +---+ +---+ +---+ +---+ + * | 0 | | 1 | | 6 | | 7 | + * +---+ +---+ +---+ +---+ + * +---+ +---+ +---+ +---+ + * | 2 | | 3 | | 8 | | 9 | + * +---+ +---+ +---+ +---+ + * +---+ +---+ +---+ +---+ + * | 4 | | 5 | | 10| | 11| + * +---+ +---+ +---+ +---+ + * + */ +#define DGEMM_6X8_MKER_LOOP_PLAIN(C00,C01,C02,C03,C10,C11,C12,C13,C20,C21,C22,C23,C30,C31,C32,C33,C40,C41,C42,C43,C50,C51,C52,C53,A0,A1,A2,B0,B1,B2,B3,AADDR,ASHIFT,BADDR,BSHIFT,LOADNEXT) \ + DGEMM_2X2_NANOKERNEL(C00,C10,B0,A0) \ + DGEMM_2X2_NANOKERNEL(C01,C11,B1,A0) \ + DGEMM_2X2_NANOKERNEL(C20,C30,B0,A1) \ + DGEMM_2X2_NANOKERNEL(C21,C31,B1,A1) \ + DGEMM_2X2_NANOKERNEL(C40,C50,B0,A2) \ + DGEMM_2X2_NANOKERNEL(C41,C51,B1,A2) \ + DGEMM_LOAD2V_ ##LOADNEXT (B0,B1,BADDR,BSHIFT) \ + DGEMM_2X2_NANOKERNEL(C02,C12,B2,A0) \ + DGEMM_2X2_NANOKERNEL(C03,C13,B3,A0) \ + DGEMM_LOAD1V_ ##LOADNEXT (A0,AADDR,ASHIFT) \ + DGEMM_2X2_NANOKERNEL(C22,C32,B2,A1) \ + DGEMM_2X2_NANOKERNEL(C23,C33,B3,A1) \ + DGEMM_LOAD1V_ ##LOADNEXT (A1,AADDR,ASHIFT+16) \ + DGEMM_2X2_NANOKERNEL(C42,C52,B2,A2) \ + DGEMM_2X2_NANOKERNEL(C43,C53,B3,A2) + +// Interleaving load or not. +#define DGEMM_LOAD1V_noload(V1,ADDR,IMM) +#define DGEMM_LOAD1V_load(V1,ADDR,IMM) \ +" ldr q"#V1", ["#ADDR", #"#IMM"] \n\t" + +#define DGEMM_LOAD2V_noload(V1,V2,ADDR,IMM) +#define DGEMM_LOAD2V_load(V1,V2,ADDR,IMM) \ + DGEMM_LOAD1V_load(V1,ADDR,IMM) \ + DGEMM_LOAD1V_load(V2,ADDR,IMM+16) + +// For contiguous storage of C. +#define DLOADC_4V_R_FWD(C0,C1,C2,C3,CADDR,CSHIFT,RSC) \ + DLOAD4V(C0,C1,C2,C3,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#RSC" \n\t" +#define DSTOREC_4V_R_FWD(C0,C1,C2,C3,CADDR,CSHIFT,RSC) \ + DSTORE4V(C0,C1,C2,C3,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#RSC" \n\t" +#define DPRFMC_FWD(CADDR,RSC) \ +" prfm PLDL1KEEP, ["#CADDR"] \n\t" \ +" add "#CADDR", "#CADDR", "#RSC" \n\t" + +// For scattered storage of C. +#define DLOADC_GATHER_4V_R_FWD(C0,C1,C2,C3,CADDR,CELEM,CSC,RSC) \ +" mov "#CELEM", "#CADDR" \n\t" \ + DLOAD1V_GATHER_ELMFWD(C0,CELEM,CSC) \ + DLOAD1V_GATHER_ELMFWD(C1,CELEM,CSC) \ + DLOAD1V_GATHER_ELMFWD(C2,CELEM,CSC) \ + DLOAD1V_GATHER_ELMFWD(C3,CELEM,CSC) \ +" add "#CADDR", "#CADDR", "#RSC" \n\t" + +#define DSTOREC_SCATTER_4V_R_FWD(C0,C1,C2,C3,CADDR,CELEM,CSC,RSC) \ +" mov "#CELEM", "#CADDR" \n\t" \ + DSTORE1V_SCATTER_ELMFWD(C0,CELEM,CSC) \ + DSTORE1V_SCATTER_ELMFWD(C1,CELEM,CSC) \ + DSTORE1V_SCATTER_ELMFWD(C2,CELEM,CSC) \ + DSTORE1V_SCATTER_ELMFWD(C3,CELEM,CSC) \ +" add "#CADDR", "#CADDR", "#RSC" \n\t" + + +void bli_dgemm_armv8a_asm_6x8r + ( + dim_t k0, + double* restrict alpha, + double* restrict a, + double* restrict b, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + void* a_next = bli_auxinfo_next_a( data ); + void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_mker = k0 / 4; + uint64_t k_left = k0 % 4; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + __asm__ volatile + ( +" ldr x0, %[a] \n\t" +" ldr x1, %[b] \n\t" +" mov x2, #6 \n\t" // Column-skip of A. +" mov x3, #8 \n\t" // Row-skip of B. +" \n\t" +" ldr x5, %[c] \n\t" +" ldr x6, %[rs_c] \n\t" // Row-skip of C. +" ldr x7, %[cs_c] \n\t" // Column-skip of C. +" \n\t" +" \n\t" // Multiply some address skips by sizeof(double). +" lsl x2, x2, #3 \n\t" // cs_a +" lsl x3, x3, #3 \n\t" // rs_b +" lsl x6, x6, #3 \n\t" // rs_c +" lsl x7, x7, #3 \n\t" // cs_c +" \n\t" +" mov x9, x5 \n\t" +" cmp x7, #8 \n\t" // Do not prefetch C for generic strided. +BNE(C_PREFETCH_END) +DPRFMC_FWD(x9,x6) +DPRFMC_FWD(x9,x6) +DPRFMC_FWD(x9,x6) +DPRFMC_FWD(x9,x6) +DPRFMC_FWD(x9,x6) +DPRFMC_FWD(x9,x6) +LABEL(C_PREFETCH_END) +" \n\t" +" ldr x4, %[k_mker] \n\t" // Number of loops. +" ldr x8, %[k_left] \n\t" +" \n\t" +// Storage scheme: +// V[ 0:23] <- C +// V[24:27] <- A +// V[28:31] <- B +// Under this scheme, the following is defined: +#define DGEMM_6X8_MKER_LOOP_PLAIN_LOC(A0,A1,A2,B0,B1,B2,B3,AADDR,ASHIFT,BADDR,BSHIFT,LOADNEXT) \ + DGEMM_6X8_MKER_LOOP_PLAIN(0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,A0,A1,A2,B0,B1,B2,B3,AADDR,ASHIFT,BADDR,BSHIFT,LOADNEXT) +// Load from memory. +LABEL(LOAD_ABC) +" \n\t" // No-microkernel early return is a must +" cmp x4, #0 \n\t" // to avoid out-of-boundary read. +BEQ(CLEAR_CCOLS) +" \n\t" +" ldr q24, [x0, #16*0] \n\t" // Load A. +" ldr q25, [x0, #16*1] \n\t" +" ldr q26, [x0, #16*2] \n\t" +" add x0, x0, x2 \n\t" +" ldr q27, [x0, #16*0] \n\t" +" \n\t" +" ldr q28, [x1, #16*0] \n\t" // Load B. +" ldr q29, [x1, #16*1] \n\t" +" ldr q30, [x1, #16*2] \n\t" +" ldr q31, [x1, #16*3] \n\t" +" add x1, x1, x3 \n\t" +LABEL(CLEAR_CCOLS) +CLEAR8V(0,1,2,3,4,5,6,7) +CLEAR8V(8,9,10,11,12,13,14,15) +CLEAR8V(16,17,18,19,20,21,22,23) +// No-microkernel early return, once again. +BEQ(K_LEFT_LOOP) +// +// Microkernel is defined here as: +#define DGEMM_6X8_MKER_LOOP_PLAIN_LOC_FWD(A0,A1,A2,B0,B1,B2,B3) \ + DGEMM_6X8_MKER_LOOP_PLAIN_LOC(A0,A1,A2,B0,B1,B2,B3,x0,1*16,x1,0,load) \ + "add x0, x0, x2 \n\t" \ + "ldr q"#A2", [x0, #16*0] \n\t" \ + "ldr q"#B2", [x1, #16*2] \n\t" \ + "ldr q"#B3", [x1, #16*3] \n\t" \ + "add x1, x1, x3 \n\t" +// Start microkernel loop. +LABEL(K_MKER_LOOP) +DGEMM_6X8_MKER_LOOP_PLAIN_LOC_FWD(24,25,26,28,29,30,31) +DGEMM_6X8_MKER_LOOP_PLAIN_LOC_FWD(27,24,25,28,29,30,31) +" \n\t" // Decrease counter before final replica. +" subs x4, x4, #1 \n\t" // Branch early to avoid reading excess mem. +BEQ(FIN_MKER_LOOP) +DGEMM_6X8_MKER_LOOP_PLAIN_LOC_FWD(26,27,24,28,29,30,31) +DGEMM_6X8_MKER_LOOP_PLAIN_LOC_FWD(25,26,27,28,29,30,31) +BRANCH(K_MKER_LOOP) +// +// Final microkernel loop. +LABEL(FIN_MKER_LOOP) +DGEMM_6X8_MKER_LOOP_PLAIN_LOC(26,27,24,28,29,30,31,x0,1*16,x1,0,load) +" add x0, x0, x2 \n\t" +" ldr q30, [x1, #16*2] \n\t" +" ldr q31, [x1, #16*3] \n\t" +" add x1, x1, x3 \n\t" +DGEMM_6X8_MKER_LOOP_PLAIN_LOC(25,26,27,28,29,30,31,xzr,-1,xzr,-1,noload) +// +// Loops left behind microkernels. +LABEL(K_LEFT_LOOP) +" cmp x8, #0 \n\t" // End of exec. +BEQ(WRITE_MEM_PREP) +" ldr q24, [x0, #16*0] \n\t" // Load A col. +" ldr q25, [x0, #16*1] \n\t" +" ldr q26, [x0, #16*2] \n\t" +" add x0, x0, x2 \n\t" +" ldr q28, [x1, #16*0] \n\t" // Load B row. +" ldr q29, [x1, #16*1] \n\t" +" ldr q30, [x1, #16*2] \n\t" +" ldr q31, [x1, #16*3] \n\t" +" add x1, x1, x3 \n\t" +" sub x8, x8, #1 \n\t" +DGEMM_6X8_MKER_LOOP_PLAIN_LOC(24,25,26,28,29,30,31,xzr,-1,xzr,-1,noload) +BRANCH(K_LEFT_LOOP) +// +// Scale and write to memory. +LABEL(WRITE_MEM_PREP) +" ldr x4, %[alpha] \n\t" // Load alpha & beta (address). +" ldr x8, %[beta] \n\t" +" ld1r {v24.2d}, [x4] \n\t" // Load alpha & beta. +" ld1r {v25.2d}, [x8] \n\t" +" \n\t" +LABEL(PREFETCH_ABNEXT) +" ldr x0, %[a_next] \n\t" +" ldr x1, %[b_next] \n\t" +" prfm PLDL1STRM, [x0, 64*0] \n\t" // Do not know cache line size, +" prfm PLDL1STRM, [x0, 64*1] \n\t" // issue some number of prfm instructions +" prfm PLDL1STRM, [x0, 64*2] \n\t" // to try to activate hardware prefetcher. +" prfm PLDL1STRM, [x1, 64*0] \n\t" +" prfm PLDL1STRM, [x1, 64*1] \n\t" +" prfm PLDL1STRM, [x1, 64*3] \n\t" +" \n\t" +" fmov d26, #1.0 \n\t" +" fcmp d24, d26 \n\t" +BEQ(UNIT_ALPHA) +DSCALE8V(0,1,2,3,4,5,6,7,24,0) +DSCALE8V(8,9,10,11,12,13,14,15,24,0) +DSCALE8V(16,17,18,19,20,21,22,23,24,0) +LABEL(UNIT_ALPHA) +" \n\t" +" mov x9, x5 \n\t" // C address for loading. +" \n\t" // C address for storing is x5 itself. +" cmp x7, #8 \n\t" // Check for generic storage. +BNE(WRITE_MEM_G) +// +// Contiguous C-storage. +LABEL(WRITE_MEM_R) +" fcmp d25, #0.0 \n\t" // Sets conditional flag whether *beta == 0. +" \n\t" // This conditional flag will be used +" \n\t" // multiple times for skipping load. +// Row 0: +BEQ(ZERO_BETA_R_0) +DLOADC_4V_R_FWD(26,27,28,29,x9,0,x6) +DSCALEA4V(0,1,2,3,26,27,28,29,25,0) +LABEL(ZERO_BETA_R_0) +DSTOREC_4V_R_FWD(0,1,2,3,x5,0,x6) +// Row 1 & 2: +BEQ(ZERO_BETA_R_1_2) +DLOADC_4V_R_FWD(26,27,28,29,x9,0,x6) +DLOADC_4V_R_FWD(0,1,2,3,x9,0,x6) +DSCALEA8V(4,5,6,7,8,9,10,11,26,27,28,29,0,1,2,3,25,0) +LABEL(ZERO_BETA_R_1_2) +DSTOREC_4V_R_FWD(4,5,6,7,x5,0,x6) +DSTOREC_4V_R_FWD(8,9,10,11,x5,0,x6) +// Row 3 & 4 & 5: +BEQ(ZERO_BETA_R_3_4_5) +DLOADC_4V_R_FWD(0,1,2,3,x9,0,x6) +DLOADC_4V_R_FWD(4,5,6,7,x9,0,x6) +DLOADC_4V_R_FWD(8,9,10,11,x9,0,x6) +DSCALEA8V(12,13,14,15,16,17,18,19,0,1,2,3,4,5,6,7,25,0) +DSCALEA4V(20,21,22,23,8,9,10,11,25,0) +LABEL(ZERO_BETA_R_3_4_5) +DSTOREC_4V_R_FWD(12,13,14,15,x5,0,x6) +DSTOREC_4V_R_FWD(16,17,18,19,x5,0,x6) +DSTOREC_4V_R_FWD(20,21,22,23,x5,0,x6) +BRANCH(END_WRITE_MEM) +// +// Generic-strided C-storage. +LABEL(WRITE_MEM_G) +" fcmp d25, #0.0 \n\t" // Sets conditional flag whether *beta == 0. +" \n\t" +// Row 0: +BEQ(ZERO_BETA_G_0) +DLOADC_GATHER_4V_R_FWD(26,27,28,29,x9,x0,x7,x6) +DSCALEA4V(0,1,2,3,26,27,28,29,25,0) +LABEL(ZERO_BETA_G_0) +DSTOREC_SCATTER_4V_R_FWD(0,1,2,3,x5,x1,x7,x6) +// Row 1 & 2: +BEQ(ZERO_BETA_G_1_2) +DLOADC_GATHER_4V_R_FWD(26,27,28,29,x9,x0,x7,x6) +DLOADC_GATHER_4V_R_FWD(0,1,2,3,x9,x0,x7,x6) +DSCALEA8V(4,5,6,7,8,9,10,11,26,27,28,29,0,1,2,3,25,0) +LABEL(ZERO_BETA_G_1_2) +DSTOREC_SCATTER_4V_R_FWD(4,5,6,7,x5,x1,x7,x6) +DSTOREC_SCATTER_4V_R_FWD(8,9,10,11,x5,x1,x7,x6) +// Row 3 & 4 & 5: +BEQ(ZERO_BETA_G_3_4_5) +DLOADC_GATHER_4V_R_FWD(0,1,2,3,x9,x0,x7,x6) +DLOADC_GATHER_4V_R_FWD(4,5,6,7,x9,x0,x7,x6) +DLOADC_GATHER_4V_R_FWD(8,9,10,11,x9,x0,x7,x6) +DSCALEA8V(12,13,14,15,16,17,18,19,0,1,2,3,4,5,6,7,25,0) +DSCALEA4V(20,21,22,23,8,9,10,11,25,0) +LABEL(ZERO_BETA_G_3_4_5) +DSTOREC_SCATTER_4V_R_FWD(12,13,14,15,x5,x1,x7,x6) +DSTOREC_SCATTER_4V_R_FWD(16,17,18,19,x5,x1,x7,x6) +DSTOREC_SCATTER_4V_R_FWD(20,21,22,23,x5,x1,x7,x6) +LABEL(END_WRITE_MEM) +: +: [a] "m" (a), + [b] "m" (b), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + [k_mker] "m" (k_mker), + [k_left] "m" (k_left), + [alpha] "m" (alpha), + [beta] "m" (beta), + [a_next] "m" (a_next), + [b_next] "m" (b_next) +: "x0","x1","x2","x3","x4","x5","x6","x7","x8","x9", + "v0","v1","v2","v3","v4","v5","v6","v7", + "v8","v9","v10","v11","v12","v13","v14","v15", + "v16","v17","v18","v19", + "v20","v21","v22","v23", + "v24","v25","v26","v27", + "v28","v29","v30","v31" + ); +} + diff --git a/kernels/armv8a/3/old/bli_gemm_armv8a_asm_d8x4.c b/kernels/armv8a/3/old/bli_gemm_armv8a_asm_d8x4.c new file mode 100644 index 0000000000..129c3613ac --- /dev/null +++ b/kernels/armv8a/3/old/bli_gemm_armv8a_asm_d8x4.c @@ -0,0 +1,294 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2021, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ + +#include "blis.h" +#include "assert.h" + +// Label locality & misc. +#include "armv8a_asm_utils.h" + +// Nanokernel operations. +#include "armv8a_asm_d2x2.h" + +/* Order of DGEMM_8x4's execution in 2x2 blocks: + * + * +---+ +---+ + * | 0 | | 2 | + * +---+ +---+ + * +---+ +---+ + * | 1 | | 3 | + * +---+ +---+ + * +---+ +---+ + * | 4 | | 6 | + * +---+ +---+ + * +---+ +---+ + * | 5 | | 7 | + * +---+ +---+ + * + */ +#define DGEMM_8X4_MKER_LOOP_PLAIN(C00,C10,C20,C30,C01,C11,C21,C31,C02,C12,C22,C32,C03,C13,C23,C33,A0,A1,A2,A3,B0,B1,AADDR,ASHIFT,BADDR,BSHIFT,LOADNEXT) \ + DGEMM_2X2_NANOKERNEL(C00,C01,A0,B0) \ + DGEMM_2X2_NANOKERNEL(C10,C11,A1,B0) \ + DGEMM_2X2_NANOKERNEL(C02,C03,A0,B1) \ + DGEMM_2X2_NANOKERNEL(C12,C13,A1,B1) \ + DGEMM_LOAD2V_ ##LOADNEXT (A0,A1,AADDR,ASHIFT) \ + DGEMM_2X2_NANOKERNEL(C20,C21,A2,B0) \ + DGEMM_2X2_NANOKERNEL(C30,C31,A3,B0) \ + DGEMM_LOAD1V_ ##LOADNEXT (B0,BADDR,BSHIFT) \ + DGEMM_2X2_NANOKERNEL(C22,C23,A2,B1) \ + DGEMM_2X2_NANOKERNEL(C32,C33,A3,B1) + +// Interleaving load or not. +#define DGEMM_LOAD1V_noload(V1,ADDR,IMM) +#define DGEMM_LOAD1V_load(V1,ADDR,IMM) \ +" ldr q"#V1", ["#ADDR", #"#IMM"] \n\t" + +#define DGEMM_LOAD2V_noload(V1,V2,ADDR,IMM) +#define DGEMM_LOAD2V_load(V1,V2,ADDR,IMM) \ + DGEMM_LOAD1V_load(V1,ADDR,IMM) \ + DGEMM_LOAD1V_load(V2,ADDR,IMM+16) + +// For contiguous storage of C. +#define DLOADC_4V_C_FWD(C0,C1,C2,C3,CADDR,CSHIFT,LDC) \ + DLOAD4V(C0,C1,C2,C3,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#LDC" \n\t" +#define DSTOREC_4V_C_FWD(C0,C1,C2,C3,CADDR,CSHIFT,LDC) \ + DSTORE4V(C0,C1,C2,C3,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#LDC" \n\t" + +void bli_dgemm_armv8a_asm_8x4 + ( + dim_t k0, + double* restrict alpha, + double* restrict a, + double* restrict b, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + void* a_next = bli_auxinfo_next_a( data ); + void* b_next = bli_auxinfo_next_b( data ); + + // This kernel is a WIP. + // I have no generic stride support at this moment. + assert( rs_c0 == 1 ); + // if ( rs_c0 != 1 ) return ; + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_mker = k0 / 6; + uint64_t k_left = k0 % 6; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + __asm__ volatile + ( +" ldr x0, %[a] \n\t" +" ldr x1, %[b] \n\t" +" mov x2, #8 \n\t" // Column-skip of A. +" mov x3, #4 \n\t" // Row-skip of B. +" \n\t" +" ldr x5, %[c] \n\t" +" ldr x6, %[rs_c] \n\t" // Row-skip of C. +" ldr x7, %[cs_c] \n\t" // Column-skip of C. +" \n\t" +" mov x8, #8 \n\t" // Multiply some address skips by sizeof(double). +" madd x2, x8, x2, xzr \n\t" // cs_a +" madd x3, x8, x3, xzr \n\t" // rs_b +" madd x7, x8, x7, xzr \n\t" // cs_c +" \n\t" +" ldr x4, %[k_mker] \n\t" // Number of loops. +" ldr x8, %[k_left] \n\t" +" \n\t" +// Storage scheme: +// V[ 0:15] <- C +// V[16:21] <- B +// V[22:29] <- A +// Under this scheme, the following is defined: +#define DGEMM_8X4_MKER_LOOP_PLAIN_LOC(A0,A1,A2,A3,B0,B1,AADDR,ASHIFT,BADDR,BSHIFT,LOADNEXT) \ + DGEMM_8X4_MKER_LOOP_PLAIN(0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,A0,A1,A2,A3,B0,B1,AADDR,ASHIFT,BADDR,BSHIFT,LOADNEXT) +// TODO: Prefetch C. +// Load from memory. +LABEL(LOAD_ABC) +" \n\t" // No-microkernel early return is a must +" cmp x4, #0 \n\t" // to avoid out-of-boundary read. +BEQ(CLEAR_CCOLS) +" \n\t" +" ldr q22, [x0, #16*0] \n\t" +" ldr q23, [x0, #16*1] \n\t" +" ldr q24, [x0, #16*2] \n\t" +" ldr q25, [x0, #16*3] \n\t" +" add x0, x0, x2 \n\t" +" ldr q26, [x0, #16*0] \n\t" +" ldr q27, [x0, #16*1] \n\t" +" ldr q28, [x0, #16*2] \n\t" +" ldr q29, [x0, #16*3] \n\t" +" add x0, x0, x2 \n\t" +" \n\t" +" ldr q16, [x1, #16*0] \n\t" +" ldr q17, [x1, #16*1] \n\t" +" add x1, x1, x3 \n\t" +" ldr q18, [x1, #16*0] \n\t" +" ldr q19, [x1, #16*1] \n\t" +" add x1, x1, x3 \n\t" +" ldr q20, [x1, #16*0] \n\t" +" ldr q21, [x1, #16*1] \n\t" +" add x1, x1, x3 \n\t" +" \n\t" +LABEL(CLEAR_CCOLS) +CLEAR8V(0,1,2,3,4,5,6,7) +CLEAR8V(8,9,10,11,12,13,14,15) +// No-microkernel early return, once again. +BEQ(K_LEFT_LOOP) +// +// Microkernel is defined here as: +#define DGEMM_8X4_MKER_LOOP_PLAIN_LOC_FWD(A0,A1,A2,A3,B0,B1) \ + DGEMM_8X4_MKER_LOOP_PLAIN_LOC(A0,A1,A2,A3,B0,B1,x0,0,x1,0,load) \ + "ldr q"#B1", [x1, #16*1] \n\t" \ + "ldr q"#A2", [x0, #16*2] \n\t" \ + "ldr q"#A3", [x0, #16*3] \n\t" \ + "add x1, x1, x3 \n\t" \ + "add x0, x0, x2 \n\t" +// Start microkernel loop. +LABEL(K_MKER_LOOP) +DGEMM_8X4_MKER_LOOP_PLAIN_LOC_FWD(22,23,24,25,16,17) +DGEMM_8X4_MKER_LOOP_PLAIN_LOC_FWD(26,27,28,29,18,19) +DGEMM_8X4_MKER_LOOP_PLAIN_LOC_FWD(22,23,24,25,20,21) +" \n\t" // Decrease counter before final replica. +" subs x4, x4, #1 \n\t" // Branch early to avoid reading excess mem. +BEQ(FIN_MKER_LOOP) +DGEMM_8X4_MKER_LOOP_PLAIN_LOC_FWD(26,27,28,29,16,17) +DGEMM_8X4_MKER_LOOP_PLAIN_LOC_FWD(22,23,24,25,18,19) +DGEMM_8X4_MKER_LOOP_PLAIN_LOC_FWD(26,27,28,29,20,21) +BRANCH(K_MKER_LOOP) +// +// Final microkernel loop. +LABEL(FIN_MKER_LOOP) +DGEMM_8X4_MKER_LOOP_PLAIN_LOC(26,27,28,29,16,17,x0,0,x1,0,noload) +" ldr q26, [x0, #16*0] \n\t" +" ldr q27, [x0, #16*1] \n\t" +" ldr q28, [x0, #16*2] \n\t" +" ldr q29, [x0, #16*3] \n\t" +" add x0, x0, x2 \n\t" +DGEMM_8X4_MKER_LOOP_PLAIN_LOC(22,23,24,25,18,19,xzr,-1,xzr,-1,noload) +DGEMM_8X4_MKER_LOOP_PLAIN_LOC(26,27,28,29,20,21,xzr,-1,xzr,-1,noload) +// +// Loops left behind microkernels. +LABEL(K_LEFT_LOOP) +" cmp x8, #0 \n\t" // End of exec. +BEQ(WRITE_MEM_PREP) +" ldr q22, [x0, #16*0] \n\t" // Load A col. +" ldr q23, [x0, #16*1] \n\t" +" ldr q24, [x0, #16*2] \n\t" +" ldr q25, [x0, #16*3] \n\t" +" add x0, x0, x2 \n\t" +" ldr q16, [x1, #16*0] \n\t" // Load B col. +" ldr q17, [x1, #16*1] \n\t" +" add x1, x1, x3 \n\t" +" sub x8, x8, #1 \n\t" +DGEMM_8X4_MKER_LOOP_PLAIN_LOC(22,23,24,25,16,17,xzr,-1,xzr,-1,noload) +BRANCH(K_LEFT_LOOP) +// +// Scale and write to memory. +LABEL(WRITE_MEM_PREP) +" ldr x4, %[alpha] \n\t" // Load alpha & beta (address). +" ldr x8, %[beta] \n\t" +" ldr d16, [x4] \n\t" // Load alpha & beta (value). +" ldr d17, [x8] \n\t" +" \n\t" +LABEL(PREFETCH_ABNEXT) +" ldr x0, %[a_next] \n\t" +" ldr x1, %[b_next] \n\t" +" prfm PLDL1STRM, [x0, 64*0] \n\t" // Do not know cache line size, +" prfm PLDL1STRM, [x0, 64*1] \n\t" // issue some number of prfm instructions +" prfm PLDL1STRM, [x0, 64*2] \n\t" // to try to activate hardware prefetcher. +" prfm PLDL1STRM, [x1, 64*0] \n\t" +" prfm PLDL1STRM, [x1, 64*1] \n\t" +" prfm PLDL1STRM, [x1, 64*3] \n\t" +" \n\t" +" mov x9, x5 \n\t" // C address for loading. +" \n\t" // C address for storing is x5 itself. +" cmp x6, #1 \n\t" // Check for generic storage. +BNE(WRITE_MEM_G) +// +// Contiguous C-storage. +LABEL(WRITE_MEM_C) +DLOADC_4V_C_FWD(20,21,22,23,x9,0,x7) +DLOADC_4V_C_FWD(24,25,26,27,x9,0,x7) +DSCALE8V(20,21,22,23,24,25,26,27,17,0) +DSCALEA8V(20,21,22,23,24,25,26,27,0,1,2,3,4,5,6,7,16,0) +// +DLOADC_4V_C_FWD(0,1,2,3,x9,0,x7) +DLOADC_4V_C_FWD(4,5,6,7,x9,0,x7) +DSCALE8V(0,1,2,3,4,5,6,7,17,0) +DSCALEA8V(0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,0) +// +DSTOREC_4V_C_FWD(20,21,22,23,x5,0,x7) +DSTOREC_4V_C_FWD(24,25,26,27,x5,0,x7) +DSTOREC_4V_C_FWD(0,1,2,3,x5,0,x7) +DSTOREC_4V_C_FWD(4,5,6,7,x5,0,x7) +BRANCH(END_WRITE_MEM) +// +// Generic-strided C-storage. +LABEL(WRITE_MEM_G) +// TODO: Implement. +LABEL(END_WRITE_MEM) +: +: [a] "m" (a), + [b] "m" (b), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + [k_mker] "m" (k_mker), + [k_left] "m" (k_left), + [alpha] "m" (alpha), + [beta] "m" (beta), + [a_next] "m" (a_next), + [b_next] "m" (b_next) +: "x0","x1","x2","x3","x4","x5","x6","x7","x8", + "x9","x16", + "v0","v1","v2","v3","v4","v5","v6","v7", + "v8","v9","v10","v11","v12","v13","v14","v15", + "v16","v17","v18","v19", + "v20","v21","v22","v23", + "v24","v25","v26","v27", + "v28","v29","v30","v31" + ); + +} + diff --git a/kernels/armv8a/3/sup/bli_gemmsup_armv8a_ref.c b/kernels/armv8a/3/sup/bli_gemmsup_armv8a_ref.c new file mode 100644 index 0000000000..c87ff1feb6 --- /dev/null +++ b/kernels/armv8a/3/sup/bli_gemmsup_armv8a_ref.c @@ -0,0 +1,450 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +// Separate instantiation for Armv8-A reference kernels. +// Temporary workaround. Will be removed after upstream has switched to a better way +// of exposing gemmsup interface. + +// +// -- Row storage case --------------------------------------------------------- +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname, arch, suf ) \ +\ +void PASTEMAC3(ch,opname,arch,suf) \ + ( \ + conj_t conja, \ + conj_t conjb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + ctype* restrict alpha, \ + ctype* restrict a, inc_t rs_a, inc_t cs_a, \ + ctype* restrict b, inc_t rs_b, inc_t cs_b, \ + ctype* restrict beta, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + auxinfo_t* restrict data, \ + cntx_t* restrict cntx \ + ) \ +{ \ + /* NOTE: This microkernel can actually handle arbitrarily large + values of m, n, and k. */ \ +\ + if ( bli_is_noconj( conja ) && bli_is_noconj( conjb ) ) \ + { \ + /* Traverse c by rows. */ \ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + ctype* restrict ci = &c[ i*rs_c ]; \ + ctype* restrict ai = &a[ i*rs_a ]; \ +\ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + ctype* restrict cij = &ci[ j*cs_c ]; \ + ctype* restrict bj = &b [ j*cs_b ]; \ + ctype ab; \ +\ + PASTEMAC(ch,set0s)( ab ); \ +\ + /* Perform a dot product to update the (i,j) element of c. */ \ + for ( dim_t l = 0; l < k; ++l ) \ + { \ + ctype* restrict aij = &ai[ l*cs_a ]; \ + ctype* restrict bij = &bj[ l*rs_b ]; \ +\ + PASTEMAC(ch,dots)( *aij, *bij, ab ); \ + } \ +\ + /* If beta is one, add ab into c. If beta is zero, overwrite c + with the result in ab. Otherwise, scale by beta and accumulate + ab to c. */ \ + if ( PASTEMAC(ch,eq1)( *beta ) ) \ + { \ + PASTEMAC(ch,axpys)( *alpha, ab, *cij ); \ + } \ + else if ( PASTEMAC(ch,eq0)( *beta ) ) \ + { \ + PASTEMAC(ch,scal2s)( *alpha, ab, *cij ); \ + } \ + else \ + { \ + PASTEMAC(ch,axpbys)( *alpha, ab, *beta, *cij ); \ + } \ + } \ + } \ + } \ + else if ( bli_is_noconj( conja ) && bli_is_conj( conjb ) ) \ + { \ + /* Traverse c by rows. */ \ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + ctype* restrict ci = &c[ i*rs_c ]; \ + ctype* restrict ai = &a[ i*rs_a ]; \ +\ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + ctype* restrict cij = &ci[ j*cs_c ]; \ + ctype* restrict bj = &b [ j*cs_b ]; \ + ctype ab; \ +\ + PASTEMAC(ch,set0s)( ab ); \ +\ + /* Perform a dot product to update the (i,j) element of c. */ \ + for ( dim_t l = 0; l < k; ++l ) \ + { \ + ctype* restrict aij = &ai[ l*cs_a ]; \ + ctype* restrict bij = &bj[ l*rs_b ]; \ +\ + PASTEMAC(ch,axpyjs)( *aij, *bij, ab ); \ + } \ +\ + /* If beta is one, add ab into c. If beta is zero, overwrite c + with the result in ab. Otherwise, scale by beta and accumulate + ab to c. */ \ + if ( PASTEMAC(ch,eq1)( *beta ) ) \ + { \ + PASTEMAC(ch,axpys)( *alpha, ab, *cij ); \ + } \ + else if ( PASTEMAC(ch,eq0)( *beta ) ) \ + { \ + PASTEMAC(ch,scal2s)( *alpha, ab, *cij ); \ + } \ + else \ + { \ + PASTEMAC(ch,axpbys)( *alpha, ab, *beta, *cij ); \ + } \ + } \ + } \ + } \ + else if ( bli_is_conj( conja ) && bli_is_noconj( conjb ) ) \ + { \ + /* Traverse c by rows. */ \ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + ctype* restrict ci = &c[ i*rs_c ]; \ + ctype* restrict ai = &a[ i*rs_a ]; \ +\ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + ctype* restrict cij = &ci[ j*cs_c ]; \ + ctype* restrict bj = &b [ j*cs_b ]; \ + ctype ab; \ +\ + PASTEMAC(ch,set0s)( ab ); \ +\ + /* Perform a dot product to update the (i,j) element of c. */ \ + for ( dim_t l = 0; l < k; ++l ) \ + { \ + ctype* restrict aij = &ai[ l*cs_a ]; \ + ctype* restrict bij = &bj[ l*rs_b ]; \ +\ + PASTEMAC(ch,dotjs)( *aij, *bij, ab ); \ + } \ +\ + /* If beta is one, add ab into c. If beta is zero, overwrite c + with the result in ab. Otherwise, scale by beta and accumulate + ab to c. */ \ + if ( PASTEMAC(ch,eq1)( *beta ) ) \ + { \ + PASTEMAC(ch,axpys)( *alpha, ab, *cij ); \ + } \ + else if ( PASTEMAC(ch,eq0)( *beta ) ) \ + { \ + PASTEMAC(ch,scal2s)( *alpha, ab, *cij ); \ + } \ + else \ + { \ + PASTEMAC(ch,axpbys)( *alpha, ab, *beta, *cij ); \ + } \ + } \ + } \ + } \ + else /* if ( bli_is_conj( conja ) && bli_is_conj( conjb ) ) */ \ + { \ + /* Traverse c by rows. */ \ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + ctype* restrict ci = &c[ i*rs_c ]; \ + ctype* restrict ai = &a[ i*rs_a ]; \ +\ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + ctype* restrict cij = &ci[ j*cs_c ]; \ + ctype* restrict bj = &b [ j*cs_b ]; \ + ctype ab; \ +\ + PASTEMAC(ch,set0s)( ab ); \ +\ + /* Perform a dot product to update the (i,j) element of c. */ \ + for ( dim_t l = 0; l < k; ++l ) \ + { \ + ctype* restrict aij = &ai[ l*cs_a ]; \ + ctype* restrict bij = &bj[ l*rs_b ]; \ +\ + PASTEMAC(ch,dots)( *aij, *bij, ab ); \ + } \ +\ + /* Conjugate the result to simulate conj(a^T) * conj(b). */ \ + PASTEMAC(ch,conjs)( ab ); \ +\ + /* If beta is one, add ab into c. If beta is zero, overwrite c + with the result in ab. Otherwise, scale by beta and accumulate + ab to c. */ \ + if ( PASTEMAC(ch,eq1)( *beta ) ) \ + { \ + PASTEMAC(ch,axpys)( *alpha, ab, *cij ); \ + } \ + else if ( PASTEMAC(ch,eq0)( *beta ) ) \ + { \ + PASTEMAC(ch,scal2s)( *alpha, ab, *cij ); \ + } \ + else \ + { \ + PASTEMAC(ch,axpbys)( *alpha, ab, *beta, *cij ); \ + } \ + } \ + } \ + } \ +} + +INSERT_GENTFUNC_BASIC2( gemmsup_r, _armv8a, _ref2 ) + +// +// -- Column storage case ------------------------------------------------------ +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname, arch, suf ) \ +\ +void PASTEMAC3(ch,opname,arch,suf) \ + ( \ + conj_t conja, \ + conj_t conjb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + ctype* restrict alpha, \ + ctype* restrict a, inc_t rs_a, inc_t cs_a, \ + ctype* restrict b, inc_t rs_b, inc_t cs_b, \ + ctype* restrict beta, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + auxinfo_t* restrict data, \ + cntx_t* restrict cntx \ + ) \ +{ \ + /* NOTE: This microkernel can actually handle arbitrarily large + values of m, n, and k. */ \ +\ + if ( bli_is_noconj( conja ) && bli_is_noconj( conjb ) ) \ + { \ + /* Traverse c by columns. */ \ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + ctype* restrict cj = &c[ j*cs_c ]; \ + ctype* restrict bj = &b[ j*cs_b ]; \ +\ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + ctype* restrict cij = &cj[ i*rs_c ]; \ + ctype* restrict ai = &a [ i*rs_a ]; \ + ctype ab; \ +\ + PASTEMAC(ch,set0s)( ab ); \ +\ + /* Perform a dot product to update the (i,j) element of c. */ \ + for ( dim_t l = 0; l < k; ++l ) \ + { \ + ctype* restrict aij = &ai[ l*cs_a ]; \ + ctype* restrict bij = &bj[ l*rs_b ]; \ +\ + PASTEMAC(ch,dots)( *aij, *bij, ab ); \ + } \ +\ + /* If beta is one, add ab into c. If beta is zero, overwrite c + with the result in ab. Otherwise, scale by beta and accumulate + ab to c. */ \ + if ( PASTEMAC(ch,eq1)( *beta ) ) \ + { \ + PASTEMAC(ch,axpys)( *alpha, ab, *cij ); \ + } \ + else if ( PASTEMAC(ch,eq0)( *beta ) ) \ + { \ + PASTEMAC(ch,scal2s)( *alpha, ab, *cij ); \ + } \ + else \ + { \ + PASTEMAC(ch,axpbys)( *alpha, ab, *beta, *cij ); \ + } \ + } \ + } \ + } \ + else if ( bli_is_noconj( conja ) && bli_is_conj( conjb ) ) \ + { \ + /* Traverse c by columns. */ \ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + ctype* restrict cj = &c[ j*cs_c ]; \ + ctype* restrict bj = &b[ j*cs_b ]; \ +\ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + ctype* restrict cij = &cj[ i*rs_c ]; \ + ctype* restrict ai = &a [ i*rs_a ]; \ + ctype ab; \ +\ + PASTEMAC(ch,set0s)( ab ); \ +\ + /* Perform a dot product to update the (i,j) element of c. */ \ + for ( dim_t l = 0; l < k; ++l ) \ + { \ + ctype* restrict aij = &ai[ l*cs_a ]; \ + ctype* restrict bij = &bj[ l*rs_b ]; \ +\ + PASTEMAC(ch,axpyjs)( *aij, *bij, ab ); \ + } \ +\ + /* If beta is one, add ab into c. If beta is zero, overwrite c + with the result in ab. Otherwise, scale by beta and accumulate + ab to c. */ \ + if ( PASTEMAC(ch,eq1)( *beta ) ) \ + { \ + PASTEMAC(ch,axpys)( *alpha, ab, *cij ); \ + } \ + else if ( PASTEMAC(ch,eq0)( *beta ) ) \ + { \ + PASTEMAC(ch,scal2s)( *alpha, ab, *cij ); \ + } \ + else \ + { \ + PASTEMAC(ch,axpbys)( *alpha, ab, *beta, *cij ); \ + } \ + } \ + } \ + } \ + else if ( bli_is_conj( conja ) && bli_is_noconj( conjb ) ) \ + { \ + /* Traverse c by columns. */ \ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + ctype* restrict cj = &c[ j*cs_c ]; \ + ctype* restrict bj = &b[ j*cs_b ]; \ +\ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + ctype* restrict cij = &cj[ i*rs_c ]; \ + ctype* restrict ai = &a [ i*rs_a ]; \ + ctype ab; \ +\ + PASTEMAC(ch,set0s)( ab ); \ +\ + /* Perform a dot product to update the (i,j) element of c. */ \ + for ( dim_t l = 0; l < k; ++l ) \ + { \ + ctype* restrict aij = &ai[ l*cs_a ]; \ + ctype* restrict bij = &bj[ l*rs_b ]; \ +\ + PASTEMAC(ch,dotjs)( *aij, *bij, ab ); \ + } \ +\ + /* If beta is one, add ab into c. If beta is zero, overwrite c + with the result in ab. Otherwise, scale by beta and accumulate + ab to c. */ \ + if ( PASTEMAC(ch,eq1)( *beta ) ) \ + { \ + PASTEMAC(ch,axpys)( *alpha, ab, *cij ); \ + } \ + else if ( PASTEMAC(ch,eq0)( *beta ) ) \ + { \ + PASTEMAC(ch,scal2s)( *alpha, ab, *cij ); \ + } \ + else \ + { \ + PASTEMAC(ch,axpbys)( *alpha, ab, *beta, *cij ); \ + } \ + } \ + } \ + } \ + else /* if ( bli_is_conj( conja ) && bli_is_conj( conjb ) ) */ \ + { \ + /* Traverse c by columns. */ \ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + ctype* restrict cj = &c[ j*cs_c ]; \ + ctype* restrict bj = &b[ j*cs_b ]; \ +\ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + ctype* restrict cij = &cj[ i*rs_c ]; \ + ctype* restrict ai = &a [ i*rs_a ]; \ + ctype ab; \ +\ + PASTEMAC(ch,set0s)( ab ); \ +\ + /* Perform a dot product to update the (i,j) element of c. */ \ + for ( dim_t l = 0; l < k; ++l ) \ + { \ + ctype* restrict aij = &ai[ l*cs_a ]; \ + ctype* restrict bij = &bj[ l*rs_b ]; \ +\ + PASTEMAC(ch,dots)( *aij, *bij, ab ); \ + } \ +\ + /* Conjugate the result to simulate conj(a^T) * conj(b). */ \ + PASTEMAC(ch,conjs)( ab ); \ +\ + /* If beta is one, add ab into c. If beta is zero, overwrite c + with the result in ab. Otherwise, scale by beta and accumulate + ab to c. */ \ + if ( PASTEMAC(ch,eq1)( *beta ) ) \ + { \ + PASTEMAC(ch,axpys)( *alpha, ab, *cij ); \ + } \ + else if ( PASTEMAC(ch,eq0)( *beta ) ) \ + { \ + PASTEMAC(ch,scal2s)( *alpha, ab, *cij ); \ + } \ + else \ + { \ + PASTEMAC(ch,axpbys)( *alpha, ab, *beta, *cij ); \ + } \ + } \ + } \ + } \ +} + +INSERT_GENTFUNC_BASIC2( gemmsup_c, _armv8a, _ref2 ) + diff --git a/kernels/armv8a/3/sup/bli_gemmsup_rd_armv8a_asm_d6x8m.c b/kernels/armv8a/3/sup/bli_gemmsup_rd_armv8a_asm_d6x8m.c new file mode 100644 index 0000000000..630459db73 --- /dev/null +++ b/kernels/armv8a/3/sup/bli_gemmsup_rd_armv8a_asm_d6x8m.c @@ -0,0 +1,509 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2021, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ + +#include "blis.h" +#include "assert.h" + +GEMMSUP_KER_PROT( double, d, gemmsup_r_armv8a_ref2 ) + +// Label locality & misc. +#include "../armv8a_asm_utils.h" + +#define DGEMM_3X1X2_NKER_SUBLOOP(C0,C1,C2,A0,A1,A2,B) \ +" fmla v"#C0".2d, v"#A0".2d, v"#B".2d \n\t" \ +" fmla v"#C1".2d, v"#A1".2d, v"#B".2d \n\t" \ +" fmla v"#C2".2d, v"#A2".2d, v"#B".2d \n\t" + +#define DGEMM_3X8X2_K_MKER_LOOP_PLAIN(C00,C01,C02,C03,C04,C05,C06,C07,C10,C11,C12,C13,C14,C15,C16,C17,C20,C21,C22,C23,C24,C25,C26,C27,A0,A1,A2,B0,B1,B2,B3,BADDR,BELEMADDR,BELEMST,LOADNEXT) \ + /* Always load before forwarding to the next line. */ \ + DGEMM_3X1X2_NKER_SUBLOOP(C00,C10,C20,A0,A1,A2,B0) \ + DGEMM_LOAD1V_K_load(B0,BELEMADDR,BELEMST) \ + DGEMM_3X1X2_NKER_SUBLOOP(C01,C11,C21,A0,A1,A2,B1) \ + DGEMM_LOAD1V_K_load(B1,BELEMADDR,BELEMST) \ + DGEMM_3X1X2_NKER_SUBLOOP(C02,C12,C22,A0,A1,A2,B2) \ + DGEMM_LOAD1V_K_load(B2,BELEMADDR,BELEMST) \ + DGEMM_3X1X2_NKER_SUBLOOP(C03,C13,C23,A0,A1,A2,B3) \ + DGEMM_LOAD1V_K_load(B3,BELEMADDR,BELEMST) \ + \ +" add "#BADDR", "#BADDR", #16 \n\t" \ +" mov "#BELEMADDR", "#BADDR" \n\t" \ + DGEMM_3X1X2_NKER_SUBLOOP(C04,C14,C24,A0,A1,A2,B0) \ + DGEMM_LOAD1V_K_ ##LOADNEXT (B0,BELEMADDR,BELEMST) \ + DGEMM_3X1X2_NKER_SUBLOOP(C05,C15,C25,A0,A1,A2,B1) \ + DGEMM_LOAD1V_K_ ##LOADNEXT (B1,BELEMADDR,BELEMST) \ + DGEMM_3X1X2_NKER_SUBLOOP(C06,C16,C26,A0,A1,A2,B2) \ + DGEMM_LOAD1V_K_ ##LOADNEXT (B2,BELEMADDR,BELEMST) \ + DGEMM_3X1X2_NKER_SUBLOOP(C07,C17,C27,A0,A1,A2,B3) \ + DGEMM_LOAD1V_K_ ##LOADNEXT (B3,BELEMADDR,BELEMST) + +#define DGEMM_LOAD1V_K_noload(V,ELEMADDR,ELEMST) +#define DGEMM_LOAD1V_K_load(V,ELEMADDR,ELEMST) \ +" ldr q"#V", [ "#ELEMADDR" ] \n\t" \ +" add "#ELEMADDR", "#ELEMADDR", "#ELEMST" \n\t" + +// For row-storage of C. +#define DLOADC_4V_R_FWD(C0,C1,C2,C3,CADDR,CSHIFT,RSC) \ + DLOAD4V(C0,C1,C2,C3,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#RSC" \n\t" +#define DSTOREC_4V_R_FWD(C0,C1,C2,C3,CADDR,CSHIFT,RSC) \ + DSTORE4V(C0,C1,C2,C3,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#RSC" \n\t" + +// For column-storage of C. +#define DLOADC_1V_1ELM_C_FWD(C0,CSCALAR,CIDX,CADDR,CSHIFT,CSC) \ + DLOAD1V(C0,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", #"#CSHIFT"+16 \n\t" \ +" ld1 {v"#CSCALAR".d}["#CIDX"], ["#CADDR"] \n\t" \ +" sub "#CADDR", "#CADDR", #"#CSHIFT"+16 \n\t" \ +" add "#CADDR", "#CADDR", "#CSC" \n\t" +#define DSTOREC_1V_1ELM_C_FWD(C0,CSCALAR,CIDX,CADDR,CSHIFT,CSC) \ + DSTORE1V(C0,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", #"#CSHIFT"+16 \n\t" \ +" st1 {v"#CSCALAR".d}["#CIDX"], ["#CADDR"] \n\t" \ +" sub "#CADDR", "#CADDR", #"#CSHIFT"+16 \n\t" \ +" add "#CADDR", "#CADDR", "#CSC" \n\t" + +#define DSCALE12V(V0,V1,V2,V3,V4,V5,V6,V7,V8,V9,V10,V11,A,IDX) \ + DSCALE4V(V0,V1,V2,V3,A,IDX) \ + DSCALE4V(V4,V5,V6,V7,A,IDX) \ + DSCALE4V(V8,V9,V10,V11,A,IDX) +#define DSCALEA12V(D0,D1,D2,D3,D4,D5,D6,D7,D8,D9,D10,D11,S0,S1,S2,S3,S4,S5,S6,S7,S8,S9,S10,S11,A,IDX) \ + DSCALEA4V(D0,D1,D2,D3,S0,S1,S2,S3,A,IDX) \ + DSCALEA4V(D4,D5,D6,D7,S4,S5,S6,S7,A,IDX) \ + DSCALEA4V(D8,D9,D10,D11,S8,S9,S10,S11,A,IDX) + +#define DPRFMC_FWD(CADDR,DLONGC) \ +" prfm PLDL1KEEP, ["#CADDR"] \n\t" \ +" add "#CADDR", "#CADDR", "#DLONGC" \n\t" + +void bli_dgemmsup_rd_armv8a_asm_6x8m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + if ( n0 != 8 ) + { + if ( n0 < 8 ) + { + for ( ; n0 >= 4; n0 -= 4 ) + { + dim_t m = m0; + double *a_loc = a; + double *c_loc = c; + + for ( ; m >= 3; m -= 3 ) + { + bli_dgemmsup_rd_armv8a_asm_3x4 + ( + conja, conjb, 3, 4, k0, + alpha, a_loc, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c_loc, rs_c0, cs_c0, data, cntx + ); + a_loc += 3 * rs_a0; + c_loc += 3 * rs_c0; + } + + if ( m > 0 ) + { + bli_dgemmsup_rd_armv8a_int_3x4 + ( + conja, conjb, m, 4, k0, + alpha, a_loc, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c_loc, rs_c0, cs_c0, data, cntx + ); + } + b += 4 * cs_b0; + c += 4 * cs_c0; + } + + for ( ; m0 > 0; m0 -= 3 ) + { + dim_t m_loc = ( m0 < 3 ) ? m0 : 3; + + bli_dgemmsup_rd_armv8a_int_3x4 + ( + conja, conjb, m_loc, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + + a += 3 * rs_a0; + c += 3 * rs_c0; + } + } + else + { + assert( FALSE ); + } + return; + } + + // LLVM has very bad routing ability for inline asm. + // Limit number of registers in case of Clang compilation. +#ifndef __clang__ + void* a_next = bli_auxinfo_next_a( data ); + void* b_next = bli_auxinfo_next_b( data ); +#endif + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_mker = k0 / 4; + uint64_t k_left = k0 % 4; + + int64_t m_iter = m0 / 3; + int64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + assert( cs_a0 == 1 ); + assert( rs_b0 == 1 ); + + if ( m_iter == 0 ) goto consider_edge_cases; + + __asm__ volatile + ( +" ldr x10, %[a] \n\t" +" ldr x13, %[c] \n\t" +" ldr x12, %[m_iter] \n\t" +" ldr x2, %[rs_a] \n\t" // Row-skip of A. +" ldr x3, %[cs_b] \n\t" // Column-skip of B. +" \n\t" +" ldr x6, %[rs_c] \n\t" // Row-skip of C. +" ldr x7, %[cs_c] \n\t" // Column-skip of C. +" \n\t" +" \n\t" // Multiply some address skips by sizeof(double). +" lsl x2, x2, #3 \n\t" // rs_a +" lsl x3, x3, #3 \n\t" // cs_b +" lsl x6, x6, #3 \n\t" // rs_c +" lsl x7, x7, #3 \n\t" // cs_c +" \n\t" +" mov x1, x5 \n\t" +" cmp x7, #8 \n\t" // Prefetch column-strided C. +BEQ(C_PREFETCH_COLS) +DPRFMC_FWD(x1,x6) +DPRFMC_FWD(x1,x6) +DPRFMC_FWD(x1,x6) +DPRFMC_FWD(x1,x6) +DPRFMC_FWD(x1,x6) +DPRFMC_FWD(x1,x6) +BRANCH(C_PREFETCH_END) +LABEL(C_PREFETCH_COLS) +DPRFMC_FWD(x1,x7) +DPRFMC_FWD(x1,x7) +DPRFMC_FWD(x1,x7) +DPRFMC_FWD(x1,x7) +DPRFMC_FWD(x1,x7) +DPRFMC_FWD(x1,x7) +DPRFMC_FWD(x1,x7) +DPRFMC_FWD(x1,x7) +LABEL(C_PREFETCH_END) +// +// Millikernel. +LABEL(MILLIKER_MLOOP) +" \n\t" +" mov x0, x10 \n\t" // Parameters to be reloaded +" mov x5, x13 \n\t" // within each millikernel loop. +" ldr x1, %[b] \n\t" +" ldr x4, %[k_mker] \n\t" +" ldr x8, %[k_left] \n\t" +" \n\t" +// Storage scheme: +// V[ 0:23] <- C +// V[24:26] <- A +// V[28:31] <- B +// V[ 27 ] <- Not used. +// Under this scheme, the following is defined: +#define DGEMM_3X8X2_K_MKER_LOOP_PLAIN_LOC(A0,A1,A2,B0,B1,B2,B3,BADDR,BELEMADDR,BELEMST,LOADNEXT) \ + DGEMM_3X8X2_K_MKER_LOOP_PLAIN(0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,A0,A1,A2,B0,B1,B2,B3,BADDR,BELEMADDR,BELEMST,LOADNEXT) +// Load from memory. +LABEL(LOAD_ABC) +" \n\t" // No-microkernel early return is a must +" cmp x4, #0 \n\t" // to avoid out-of-boundary read. +BEQ(CLEAR_CCOLS) +" \n\t" +" mov x11, x1 \n\t" // Load B. +" ldr q28, [x11] \n\t" +" add x11, x11, x3 \n\t" +" ldr q29, [x11] \n\t" +" add x11, x11, x3 \n\t" +" ldr q30, [x11] \n\t" +" add x11, x11, x3 \n\t" +" ldr q31, [x11] \n\t" +" add x11, x11, x3 \n\t" +" \n\t" +" mov x14, x0 \n\t" // Load A. +" ldr q24, [x14] \n\t" +" add x14, x14, x2 \n\t" +" ldr q25, [x14] \n\t" +" add x14, x14, x2 \n\t" +" ldr q26, [x14] \n\t" +// " add x14, x14, x2 \n\t" +" add x0, x0, #16 \n\t" +LABEL(CLEAR_CCOLS) +CLEAR8V(0,1,2,3,4,5,6,7) +CLEAR8V(8,9,10,11,12,13,14,15) +CLEAR8V(16,17,18,19,20,21,22,23) +// No-microkernel early return, once again. +BEQ(K_LEFT_LOOP) +// +// Microkernel is defined here as: +#define DGEMM_3X8X2_K_MKER_LOOP_PLAIN_LOC_FWD(A0,A1,A2,B0,B1,B2,B3) \ + DGEMM_3X8X2_K_MKER_LOOP_PLAIN_LOC(A0,A1,A2,B0,B1,B2,B3,x1,x11,x3,load) \ + "mov x14, x0 \n\t" \ + "ldr q24, [x14] \n\t" \ + "add x14, x14, x2 \n\t" \ + "ldr q25, [x14] \n\t" \ + "add x14, x14, x2 \n\t" \ + "ldr q26, [x14] \n\t" \ + /*"add x14, x14, x2 \n\t"*/ \ + "add x0, x0, #16 \n\t" +// Start microkernel loop. +LABEL(K_MKER_LOOP) +DGEMM_3X8X2_K_MKER_LOOP_PLAIN_LOC_FWD(24,25,26,28,29,30,31) +" \n\t" // Decrease counter before final replica. +" subs x4, x4, #1 \n\t" // Branch early to avoid reading excess mem. +BEQ(FIN_MKER_LOOP) +DGEMM_3X8X2_K_MKER_LOOP_PLAIN_LOC_FWD(24,25,26,28,29,30,31) +BRANCH(K_MKER_LOOP) +// +// Final microkernel loop. +LABEL(FIN_MKER_LOOP) +DGEMM_3X8X2_K_MKER_LOOP_PLAIN_LOC(24,25,26,28,29,30,31,x1,x11,x3,noload) +// +// If major kernel is executed, +// an additional depth-summation is required. +" faddp v0.2d, v0.2d, v1.2d \n\t" // Line 0. +" faddp v1.2d, v2.2d, v3.2d \n\t" +" faddp v2.2d, v4.2d, v5.2d \n\t" +" faddp v3.2d, v6.2d, v7.2d \n\t" +" faddp v4.2d, v8.2d, v9.2d \n\t" // Line 1. +" faddp v5.2d, v10.2d, v11.2d \n\t" +" faddp v6.2d, v12.2d, v13.2d \n\t" +" faddp v7.2d, v14.2d, v15.2d \n\t" +" faddp v8.2d, v16.2d, v17.2d \n\t" // Line 2. +" faddp v9.2d, v18.2d, v19.2d \n\t" +" faddp v10.2d, v20.2d, v21.2d \n\t" +" faddp v11.2d, v22.2d, v23.2d \n\t" +" \n\t" +// Loops left behind microkernels. +LABEL(K_LEFT_LOOP) +" cmp x8, #0 \n\t" // End of exec. +BEQ(WRITE_MEM_PREP) +" mov x11, x1 \n\t" // Load B row. +" ld1 {v28.d}[0], [x11], x3 \n\t" +" ld1 {v28.d}[1], [x11], x3 \n\t" +" ld1 {v29.d}[0], [x11], x3 \n\t" +" ld1 {v29.d}[1], [x11], x3 \n\t" +" ld1 {v30.d}[0], [x11], x3 \n\t" +" ld1 {v30.d}[1], [x11], x3 \n\t" +" ld1 {v31.d}[0], [x11], x3 \n\t" +" ld1 {v31.d}[1], [x11], x3 \n\t" +" add x1, x1, #8 \n\t" +" mov x14, x0 \n\t" // Load A column. +" ld1 {v24.d}[0], [x14], x2 \n\t" +" ld1 {v24.d}[1], [x14], x2 \n\t" +" ld1 {v25.d}[0], [x14], x2 \n\t" +" add x0, x0, #8 \n\t" +" fmla v0.2d, v28.2d, v24.d[0] \n\t" +" fmla v1.2d, v29.2d, v24.d[0] \n\t" +" fmla v2.2d, v30.2d, v24.d[0] \n\t" +" fmla v3.2d, v31.2d, v24.d[0] \n\t" +" fmla v4.2d, v28.2d, v24.d[1] \n\t" +" fmla v5.2d, v29.2d, v24.d[1] \n\t" +" fmla v6.2d, v30.2d, v24.d[1] \n\t" +" fmla v7.2d, v31.2d, v24.d[1] \n\t" +" fmla v8.2d, v28.2d, v25.d[0] \n\t" +" fmla v9.2d, v29.2d, v25.d[0] \n\t" +" fmla v10.2d, v30.2d, v25.d[0] \n\t" +" fmla v11.2d, v31.2d, v25.d[0] \n\t" +" sub x8, x8, #1 \n\t" +BRANCH(K_LEFT_LOOP) +// +// Scale and write to memory. +LABEL(WRITE_MEM_PREP) +" ldr x4, %[alpha] \n\t" // Load alpha & beta (address). +" ldr x8, %[beta] \n\t" +" ld1r {v30.2d}, [x4] \n\t" // Load alpha & beta (value). +" ld1r {v31.2d}, [x8] \n\t" +" \n\t" +" fmov d28, #1.0 \n\t" // Don't scale for unit alpha. +" fcmp d30, d28 \n\t" +BEQ(UNIT_ALPHA) +DSCALE12V(0,1,2,3,4,5,6,7,8,9,10,11,30,0) +LABEL(UNIT_ALPHA) +" \n\t" +" mov x1, x5 \n\t" // C address for loading. +" \n\t" // C address for storing is x5 itself. +" cmp x7, #8 \n\t" // Check for column-storage. +BNE(WRITE_MEM_C) +// +// C storage in rows. +LABEL(WRITE_MEM_R) +" fcmp d31, #0.0 \n\t" // Don't load for zero beta. +BEQ(ZERO_BETA_R) +DLOADC_4V_R_FWD(12,13,14,15,x1,0,x6) +DLOADC_4V_R_FWD(16,17,18,19,x1,0,x6) +DLOADC_4V_R_FWD(20,21,22,23,x1,0,x6) +DSCALEA12V(0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,31,0) +LABEL(ZERO_BETA_R) +#ifndef __clang__ +" cmp x12, #1 \n\t" +BRANCH(PRFM_END_R) +" prfm PLDL1KEEP, [%[a_next], #16*0] \n\t" +" prfm PLDL1KEEP, [%[a_next], #16*1] \n\t" +" prfm PLDL1STRM, [%[b_next], #16*0] \n\t" +" prfm PLDL1STRM, [%[b_next], #16*1] \n\t" +LABEL(PRFM_END_R) +#endif +DSTOREC_4V_R_FWD(0,1,2,3,x5,0,x6) +DSTOREC_4V_R_FWD(4,5,6,7,x5,0,x6) +DSTOREC_4V_R_FWD(8,9,10,11,x5,0,x6) +BRANCH(END_WRITE_MEM) +// +// C storage in columns. +LABEL(WRITE_MEM_C) +" trn1 v12.2d, v0.2d, v4.2d \n\t" +" trn2 v13.2d, v0.2d, v4.2d \n\t" +" trn1 v14.2d, v1.2d, v5.2d \n\t" +" trn2 v15.2d, v1.2d, v5.2d \n\t" +" trn1 v16.2d, v2.2d, v6.2d \n\t" +" trn2 v17.2d, v2.2d, v6.2d \n\t" +" trn1 v18.2d, v3.2d, v7.2d \n\t" +" trn2 v19.2d, v3.2d, v7.2d \n\t" +" fcmp d31, #0.0 \n\t" // Don't load for zero beta. +BEQ(ZERO_BETA_C) +DLOADC_1V_1ELM_C_FWD(0,20,0,x1,0,x7) +DLOADC_1V_1ELM_C_FWD(1,20,1,x1,0,x7) +DLOADC_1V_1ELM_C_FWD(2,21,0,x1,0,x7) +DLOADC_1V_1ELM_C_FWD(3,21,1,x1,0,x7) +DLOADC_1V_1ELM_C_FWD(4,22,0,x1,0,x7) +DLOADC_1V_1ELM_C_FWD(5,22,1,x1,0,x7) +DLOADC_1V_1ELM_C_FWD(6,23,0,x1,0,x7) +DLOADC_1V_1ELM_C_FWD(7,23,1,x1,0,x7) +DSCALEA12V(12,13,14,15,16,17,18,19,8,9,10,11,0,1,2,3,4,5,6,7,20,21,22,23,31,0) +LABEL(ZERO_BETA_C) +#ifndef __clang__ +" cmp x12, #1 \n\t" +BRANCH(PRFM_END_C) +" prfm PLDL1KEEP, [%[a_next], #16*0] \n\t" +" prfm PLDL1KEEP, [%[a_next], #16*1] \n\t" +" prfm PLDL1STRM, [%[b_next], #16*0] \n\t" +" prfm PLDL1STRM, [%[b_next], #16*1] \n\t" +LABEL(PRFM_END_C) +#endif +DSTOREC_1V_1ELM_C_FWD(12,8,0,x5,0,x7) +DSTOREC_1V_1ELM_C_FWD(13,8,1,x5,0,x7) +DSTOREC_1V_1ELM_C_FWD(14,9,0,x5,0,x7) +DSTOREC_1V_1ELM_C_FWD(15,9,1,x5,0,x7) +DSTOREC_1V_1ELM_C_FWD(16,10,0,x5,0,x7) +DSTOREC_1V_1ELM_C_FWD(17,10,1,x5,0,x7) +DSTOREC_1V_1ELM_C_FWD(18,11,0,x5,0,x7) +DSTOREC_1V_1ELM_C_FWD(19,11,1,x5,0,x7) +// +// End of this microkernel. +LABEL(END_WRITE_MEM) +" \n\t" +" subs x12, x12, #1 \n\t" +BEQ(END_EXEC) +" \n\t" +" mov x8, #3 \n\t" +" madd x13, x6, x8, x13 \n\t" // Forward C's base address to the next logic panel. +" madd x10, x2, x8, x10 \n\t" // Forward A's base address to the next logic panel. +BRANCH(MILLIKER_MLOOP) +// +// End of execution. +LABEL(END_EXEC) +: +: [a] "m" (a), + [b] "m" (b), + [c] "m" (c), + [rs_a] "m" (rs_a), + [cs_b] "m" (cs_b), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + // In Clang, even "m"-passed parameter takes 1 register. + // Have to disable prefetching to pass compilation. +#ifndef __clang__ + [a_next] "r" (a_next), + [b_next] "r" (b_next), +#endif + [m_iter] "m" (m_iter), + [k_mker] "m" (k_mker), + [k_left] "m" (k_left), + [alpha] "m" (alpha), + [beta] "m" (beta) +: "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", + "x8", "x9", "x10","x11","x12","x13","x14", + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", + "v8", "v9", "v10","v11","v12","v13","v14","v15", + "v16","v17","v18","v19","v20","v21","v22","v23", + "v24","v25","v26","v27","v28","v29","v30","v31" + ); + +consider_edge_cases: + // TODO: Implement optimized kernel for this. + // + // Forward address. + a = a + m_iter * 3 * rs_a; + c = c + m_iter * 3 * rs_c; + for ( ; m_left > 0; m_left -= 2 ) + { + dim_t m_loc = ( m_left < 2 ) ? m_left : 2; + + bli_dgemmsup_rd_armv8a_int_2x8 + ( + conja, conjb, m_loc, 8, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + a += 2 * rs_a0; + c += 2 * rs_c0; + } +} + diff --git a/kernels/armv8a/3/sup/bli_gemmsup_rd_armv8a_asm_d6x8n.c b/kernels/armv8a/3/sup/bli_gemmsup_rd_armv8a_asm_d6x8n.c new file mode 100644 index 0000000000..e13dd668ea --- /dev/null +++ b/kernels/armv8a/3/sup/bli_gemmsup_rd_armv8a_asm_d6x8n.c @@ -0,0 +1,586 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2021, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ + +#include "blis.h" +#include "assert.h" + +GEMMSUP_KER_PROT( double, d, gemmsup_r_armv8a_ref2 ) + +// Label locality & misc. +#include "../armv8a_asm_utils.h" + +#define DGEMM_1X4X2_NKER_SUBLOOP(C0,C1,C2,C3,A,B0,B1,B2,B3) \ +" fmla v"#C0".2d, v"#A".2d, v"#B0".2d \n\t" \ +" fmla v"#C1".2d, v"#A".2d, v"#B1".2d \n\t" \ +" fmla v"#C2".2d, v"#A".2d, v"#B2".2d \n\t" \ +" fmla v"#C3".2d, v"#A".2d, v"#B3".2d \n\t" + +#define DGEMM_6X4X2_K_MKER_LOOP_PLAIN(C00,C01,C02,C03,C10,C11,C12,C13,C20,C21,C22,C23,C30,C31,C32,C33,C40,C41,C42,C43,C50,C51,C52,C53,A0,A1,A2,A3,B0,B1,B2,B3,AADDR,AELEMADDR,AELEMST,LOADNEXT) \ + /* Always load before forwarding to the next line. */ \ + DGEMM_1X4X2_NKER_SUBLOOP(C00,C01,C02,C03,A0,B0,B1,B2,B3) \ + DGEMM_LOAD1V_K_load(A0,AELEMADDR,AELEMST) \ + DGEMM_1X4X2_NKER_SUBLOOP(C10,C11,C12,C13,A1,B0,B1,B2,B3) \ + DGEMM_LOAD1V_K_load(A1,AELEMADDR,AELEMST) \ +" add "#AADDR", "#AADDR", #16 \n\t" \ +" mov "#AELEMADDR", "#AADDR" \n\t" \ + DGEMM_1X4X2_NKER_SUBLOOP(C20,C21,C22,C23,A2,B0,B1,B2,B3) \ + DGEMM_LOAD1V_K_load(A2,AELEMADDR,AELEMST) \ + DGEMM_1X4X2_NKER_SUBLOOP(C30,C31,C32,C33,A3,B0,B1,B2,B3) \ + DGEMM_LOAD1V_K_load(A3,AELEMADDR,AELEMST) \ + \ + DGEMM_1X4X2_NKER_SUBLOOP(C40,C41,C42,C43,A0,B0,B1,B2,B3) \ + DGEMM_LOAD1V_K_ ##LOADNEXT (A0,AELEMADDR,AELEMST) \ + DGEMM_1X4X2_NKER_SUBLOOP(C50,C51,C52,C53,A1,B0,B1,B2,B3) \ + DGEMM_LOAD1V_K_ ##LOADNEXT (A1,AELEMADDR,AELEMST) + +#define DGEMM_LOAD1V_K_noload(V,ELEMADDR,ELEMST) +#define DGEMM_LOAD1V_K_load(V,ELEMADDR,ELEMST) \ +" ldr q"#V", [ "#ELEMADDR" ] \n\t" \ +" add "#ELEMADDR", "#ELEMADDR", "#ELEMST" \n\t" + +// For row-storage of C. +#define DLOADC_2V_R_FWD(C0,C1,CADDR,CSHIFT,RSC) \ + DLOAD2V(C0,C1,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#RSC" \n\t" +#define DSTOREC_2V_R_FWD(C0,C1,CADDR,CSHIFT,RSC) \ + DSTORE2V(C0,C1,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#RSC" \n\t" + +// For column-storage of C. +#define DLOADC_3V_C_FWD(C0,C1,C2,CADDR,CSHIFT,CSC) \ + DLOAD2V(C0,C1,CADDR,CSHIFT) \ + DLOAD1V(C2,CADDR,CSHIFT+32) \ +" add "#CADDR", "#CADDR", "#CSC" \n\t" +#define DSTOREC_3V_C_FWD(C0,C1,C2,CADDR,CSHIFT,CSC) \ + DSTORE2V(C0,C1,CADDR,CSHIFT) \ + DSTORE1V(C2,CADDR,CSHIFT+32) \ +" add "#CADDR", "#CADDR", "#CSC" \n\t" + +#define DSCALE12V(V0,V1,V2,V3,V4,V5,V6,V7,V8,V9,V10,V11,A,IDX) \ + DSCALE4V(V0,V1,V2,V3,A,IDX) \ + DSCALE4V(V4,V5,V6,V7,A,IDX) \ + DSCALE4V(V8,V9,V10,V11,A,IDX) +#define DSCALEA12V(D0,D1,D2,D3,D4,D5,D6,D7,D8,D9,D10,D11,S0,S1,S2,S3,S4,S5,S6,S7,S8,S9,S10,S11,A,IDX) \ + DSCALEA4V(D0,D1,D2,D3,S0,S1,S2,S3,A,IDX) \ + DSCALEA4V(D4,D5,D6,D7,S4,S5,S6,S7,A,IDX) \ + DSCALEA4V(D8,D9,D10,D11,S8,S9,S10,S11,A,IDX) + +#define DPRFMC_FWD(CADDR,DLONGC) \ +" prfm PLDL1KEEP, ["#CADDR"] \n\t" \ +" add "#CADDR", "#CADDR", "#DLONGC" \n\t" + +void bli_dgemmsup_rd_armv8a_asm_6x8n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + if ( m0 != 6 ) + { + if ( m0 < 6 ) + { + if ( m0 == 5 ) + { + // 3xk calls. + dim_t n = n0; + double *b_loc = b; + double *c_loc = c; + for ( ; n >= 4; n -= 4 ) + { + bli_dgemmsup_rd_armv8a_asm_3x4 + ( + conja, conjb, 3, 4, k0, + alpha, a, rs_a0, cs_a0, b_loc, rs_b0, cs_b0, + beta, c_loc, rs_c0, cs_c0, data, cntx + ); + b_loc += 4 * cs_b0; + c_loc += 4 * cs_c0; + } + if ( n > 0 ) + { + bli_dgemmsup_rd_armv8a_int_3x4 + ( + conja, conjb, 3, n, k0, + alpha, a, rs_a0, cs_a0, b_loc, rs_b0, cs_b0, + beta, c_loc, rs_c0, cs_c0, data, cntx + ); + } + a += 3 * rs_a0; + c += 3 * rs_c0; + + // 2xk calls. + for ( ; n0 > 0; n0 -= 8 ) + { + dim_t n_loc = ( n0 < 8 ) ? n0 : 8; + bli_dgemmsup_rd_armv8a_int_2x8 + ( + conja, conjb, 2, n_loc, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + b += 8 * cs_b0; + c += 8 * cs_c0; + } + return; + } + else if ( m0 == 4 ) + { + for ( ; n0 > 0; n0 -= 8 ) + { + dim_t n_loc = ( n0 < 8 ) ? n0 : 8; + bli_dgemmsup_rd_armv8a_int_2x8 + ( + conja, conjb, 2, n_loc, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + bli_dgemmsup_rd_armv8a_int_2x8 + ( + conja, conjb, 2, n_loc, k0, + alpha, a + 2 * rs_a0, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c + 2 * rs_c0, rs_c0, cs_c0, data, cntx + ); + b += 8 * cs_b0; + c += 8 * cs_c0; + } + } + else if ( m0 == 3 ) + { + for ( ; n0 >= 4; n0 -= 4 ) + { + bli_dgemmsup_rd_armv8a_asm_3x4 + ( + conja, conjb, 3, 4, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + b += 4 * cs_b0; + c += 4 * cs_c0; + } + if ( n0 > 0 ) + { + bli_dgemmsup_rd_armv8a_int_3x4 + ( + conja, conjb, 3, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + } + } + else // m0 == 2 or 1. + { + for ( ; n0 > 0; n0 -= 8 ) + { + dim_t n_loc = ( n0 < 8 ) ? n0 : 8; + bli_dgemmsup_rd_armv8a_int_2x8 + ( + conja, conjb, m0, n_loc, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + b += 8 * cs_b0; + c += 8 * cs_c0; + } + } + } + else + { + assert( FALSE ); + } + return; + } + + // LLVM has very bad routing ability for inline asm. + // Limit number of registers in case of Clang compilation. +#ifndef __clang__ + void* a_next = bli_auxinfo_next_a( data ); + void* b_next = bli_auxinfo_next_b( data ); +#endif + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_mker = k0 / 4; + uint64_t k_left = k0 % 4; + + int64_t n_iter = n0 / 4; + int64_t n_left = n0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + assert( cs_a0 == 1 ); + assert( rs_b0 == 1 ); + + if ( n_iter == 0 ) goto consider_edge_cases; + + __asm__ volatile + ( +" ldr x10, %[b] \n\t" +" ldr x13, %[c] \n\t" +" ldr x12, %[n_iter] \n\t" +" ldr x2, %[rs_a] \n\t" // Row-skip of A. +" ldr x3, %[cs_b] \n\t" // Column-skip of B. +" \n\t" +" ldr x6, %[rs_c] \n\t" // Row-skip of C. +" ldr x7, %[cs_c] \n\t" // Column-skip of C. +" \n\t" +" \n\t" // Multiply some address skips by sizeof(double). +" lsl x2, x2, #3 \n\t" // rs_a +" lsl x3, x3, #3 \n\t" // cs_b +" lsl x6, x6, #3 \n\t" // rs_c +" lsl x7, x7, #3 \n\t" // cs_c +" \n\t" +" mov x1, x5 \n\t" +" cmp x7, #8 \n\t" // Prefetch column-strided C. +BEQ(C_PREFETCH_COLS) +DPRFMC_FWD(x1,x6) +DPRFMC_FWD(x1,x6) +DPRFMC_FWD(x1,x6) +DPRFMC_FWD(x1,x6) +DPRFMC_FWD(x1,x6) +DPRFMC_FWD(x1,x6) +BRANCH(C_PREFETCH_END) +LABEL(C_PREFETCH_COLS) +DPRFMC_FWD(x1,x7) +DPRFMC_FWD(x1,x7) +DPRFMC_FWD(x1,x7) +DPRFMC_FWD(x1,x7) +DPRFMC_FWD(x1,x7) +DPRFMC_FWD(x1,x7) +DPRFMC_FWD(x1,x7) +DPRFMC_FWD(x1,x7) +LABEL(C_PREFETCH_END) +// +// Millikernel. +LABEL(MILLIKER_MLOOP) +" \n\t" +" mov x1, x10 \n\t" // Parameters to be reloaded +" mov x5, x13 \n\t" // within each millikernel loop. +" ldr x0, %[a] \n\t" +" ldr x4, %[k_mker] \n\t" +" ldr x8, %[k_left] \n\t" +" \n\t" +// Storage scheme: +// V[ 0:23] <- C +// V[24:27] <- A +// V[28:31] <- B +// Under this scheme, the following is defined: +#define DGEMM_6X4X2_K_MKER_LOOP_PLAIN_LOC(A0,A1,A2,A3,B0,B1,B2,B3,AADDR,AELEMADDR,AELEMST,LOADNEXT) \ + DGEMM_6X4X2_K_MKER_LOOP_PLAIN(0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,A0,A1,A2,A3,B0,B1,B2,B3,AADDR,AELEMADDR,AELEMST,LOADNEXT) +// Load from memory. +LABEL(LOAD_ABC) +" \n\t" // No-microkernel early return is a must +" cmp x4, #0 \n\t" // to avoid out-of-boundary read. +BEQ(CLEAR_CCOLS) +" \n\t" +" mov x11, x1 \n\t" // Load B. +" ldr q28, [x11] \n\t" +" add x11, x11, x3 \n\t" +" ldr q29, [x11] \n\t" +" add x11, x11, x3 \n\t" +" ldr q30, [x11] \n\t" +" add x11, x11, x3 \n\t" +" ldr q31, [x11] \n\t" +// " add x11, x11, x3 \n\t" +" add x1, x1, #16 \n\t" +" \n\t" +" mov x14, x0 \n\t" // Load A. +" ldr q24, [x14] \n\t" +" add x14, x14, x2 \n\t" +" ldr q25, [x14] \n\t" +" add x14, x14, x2 \n\t" +" ldr q26, [x14] \n\t" +" add x14, x14, x2 \n\t" +" ldr q27, [x14] \n\t" +" add x14, x14, x2 \n\t" +LABEL(CLEAR_CCOLS) +CLEAR8V(0,1,2,3,4,5,6,7) +CLEAR8V(8,9,10,11,12,13,14,15) +CLEAR8V(16,17,18,19,20,21,22,23) +// No-microkernel early return, once again. +BEQ(K_LEFT_LOOP) +// +// Microkernel is defined here as: +#define DGEMM_6X4X2_K_MKER_LOOP_PLAIN_LOC_FWD(A0,A1,A2,A3,B0,B1,B2,B3) \ + DGEMM_6X4X2_K_MKER_LOOP_PLAIN_LOC(A0,A1,A2,A3,B0,B1,B2,B3,x0,x14,x2,load) \ + /* A already loaded and forwarded. Process B only. */ \ + "mov x11, x1 \n\t" \ + "ldr q28, [x11] \n\t" \ + "add x11, x11, x3 \n\t" \ + "ldr q29, [x11] \n\t" \ + "add x11, x11, x3 \n\t" \ + "ldr q30, [x11] \n\t" \ + "add x11, x11, x3 \n\t" \ + "ldr q31, [x11] \n\t" \ + /*"add x11, x11, x3 \n\t"*/ \ + "add x1, x1, #16 \n\t" +// Start microkernel loop. +LABEL(K_MKER_LOOP) +DGEMM_6X4X2_K_MKER_LOOP_PLAIN_LOC_FWD(24,25,26,27,28,29,30,31) +" \n\t" // Decrease counter before final replica. +" subs x4, x4, #1 \n\t" // Branch early to avoid reading excess mem. +BEQ(FIN_MKER_LOOP) +DGEMM_6X4X2_K_MKER_LOOP_PLAIN_LOC_FWD(26,27,24,25,28,29,30,31) +BRANCH(K_MKER_LOOP) +// +// Final microkernel loop. +LABEL(FIN_MKER_LOOP) +DGEMM_6X4X2_K_MKER_LOOP_PLAIN_LOC(26,27,24,25,28,29,30,31,x0,x14,x2,noload) +// +// If major kernel is executed, +// an additional depth-summation is required. +" faddp v0.2d, v0.2d, v1.2d \n\t" // Line 0. +" faddp v1.2d, v2.2d, v3.2d \n\t" +" faddp v2.2d, v4.2d, v5.2d \n\t" // Line 1. +" faddp v3.2d, v6.2d, v7.2d \n\t" +" faddp v4.2d, v8.2d, v9.2d \n\t" // Line 2. +" faddp v5.2d, v10.2d, v11.2d \n\t" +" faddp v6.2d, v12.2d, v13.2d \n\t" // Line 3. +" faddp v7.2d, v14.2d, v15.2d \n\t" +" faddp v8.2d, v16.2d, v17.2d \n\t" // Line 4. +" faddp v9.2d, v18.2d, v19.2d \n\t" +" faddp v10.2d, v20.2d, v21.2d \n\t" // Line 5. +" faddp v11.2d, v22.2d, v23.2d \n\t" +" \n\t" +// Loops left behind microkernels. +LABEL(K_LEFT_LOOP) +" cmp x8, #0 \n\t" // End of exec. +BEQ(WRITE_MEM_PREP) +" mov x11, x1 \n\t" // Load B row. +" ld1 {v28.d}[0], [x11], x3 \n\t" +" ld1 {v28.d}[1], [x11], x3 \n\t" +" ld1 {v29.d}[0], [x11], x3 \n\t" +" ld1 {v29.d}[1], [x11], x3 \n\t" +" add x1, x1, #8 \n\t" +" mov x14, x0 \n\t" // Load A column. +" ld1 {v24.d}[0], [x14], x2 \n\t" +" ld1 {v24.d}[1], [x14], x2 \n\t" +" ld1 {v25.d}[0], [x14], x2 \n\t" +" ld1 {v25.d}[1], [x14], x2 \n\t" +" ld1 {v26.d}[0], [x14], x2 \n\t" +" ld1 {v26.d}[1], [x14], x2 \n\t" +" add x0, x0, #8 \n\t" +" fmla v0.2d, v28.2d, v24.d[0] \n\t" +" fmla v1.2d, v29.2d, v24.d[0] \n\t" +" fmla v2.2d, v28.2d, v24.d[1] \n\t" +" fmla v3.2d, v29.2d, v24.d[1] \n\t" +" fmla v4.2d, v28.2d, v25.d[0] \n\t" +" fmla v5.2d, v29.2d, v25.d[0] \n\t" +" fmla v6.2d, v28.2d, v25.d[1] \n\t" +" fmla v7.2d, v29.2d, v25.d[1] \n\t" +" fmla v8.2d, v28.2d, v26.d[0] \n\t" +" fmla v9.2d, v29.2d, v26.d[0] \n\t" +" fmla v10.2d, v28.2d, v26.d[1] \n\t" +" fmla v11.2d, v29.2d, v26.d[1] \n\t" +" sub x8, x8, #1 \n\t" +BRANCH(K_LEFT_LOOP) +// +// Scale and write to memory. +LABEL(WRITE_MEM_PREP) +" ldr x4, %[alpha] \n\t" // Load alpha & beta (address). +" ldr x8, %[beta] \n\t" +" ld1r {v30.2d}, [x4] \n\t" // Load alpha & beta (value). +" ld1r {v31.2d}, [x8] \n\t" +" \n\t" +" fmov d28, #1.0 \n\t" // Don't scale for unit alpha. +" fcmp d30, d28 \n\t" +BEQ(UNIT_ALPHA) +DSCALE12V(0,1,2,3,4,5,6,7,8,9,10,11,30,0) +LABEL(UNIT_ALPHA) +" \n\t" +" mov x1, x5 \n\t" // C address for loading. +" \n\t" // C address for storing is x5 itself. +" cmp x7, #8 \n\t" // Check for column-storage. +BNE(WRITE_MEM_C) +// +// C storage in rows. +LABEL(WRITE_MEM_R) +" fcmp d31, #0.0 \n\t" // Don't load for zero beta. +BEQ(ZERO_BETA_R) +DLOADC_2V_R_FWD(12,13,x1,0,x6) +DLOADC_2V_R_FWD(14,15,x1,0,x6) +DLOADC_2V_R_FWD(16,17,x1,0,x6) +DLOADC_2V_R_FWD(18,19,x1,0,x6) +DLOADC_2V_R_FWD(20,21,x1,0,x6) +DLOADC_2V_R_FWD(22,23,x1,0,x6) +DSCALEA12V(0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,31,0) +LABEL(ZERO_BETA_R) +#ifndef __clang__ +" cmp x12, #1 \n\t" +BRANCH(PRFM_END_R) +" prfm PLDL1KEEP, [%[a_next], #16*0] \n\t" +" prfm PLDL1KEEP, [%[a_next], #16*1] \n\t" +" prfm PLDL1STRM, [%[b_next], #16*0] \n\t" +" prfm PLDL1STRM, [%[b_next], #16*1] \n\t" +LABEL(PRFM_END_R) +#endif +DSTOREC_2V_R_FWD(0,1,x5,0,x6) +DSTOREC_2V_R_FWD(2,3,x5,0,x6) +DSTOREC_2V_R_FWD(4,5,x5,0,x6) +DSTOREC_2V_R_FWD(6,7,x5,0,x6) +DSTOREC_2V_R_FWD(8,9,x5,0,x6) +DSTOREC_2V_R_FWD(10,11,x5,0,x6) +BRANCH(END_WRITE_MEM) +// +// C storage in columns. +LABEL(WRITE_MEM_C) +" trn1 v12.2d, v0.2d, v2.2d \n\t" +" trn1 v13.2d, v4.2d, v6.2d \n\t" +" trn1 v14.2d, v8.2d, v10.2d \n\t" +" trn2 v15.2d, v0.2d, v2.2d \n\t" +" trn2 v16.2d, v4.2d, v6.2d \n\t" +" trn2 v17.2d, v8.2d, v10.2d \n\t" +" trn1 v18.2d, v1.2d, v3.2d \n\t" +" trn1 v19.2d, v5.2d, v7.2d \n\t" +" trn1 v20.2d, v9.2d, v11.2d \n\t" +" trn2 v21.2d, v1.2d, v3.2d \n\t" +" trn2 v22.2d, v5.2d, v7.2d \n\t" +" trn2 v23.2d, v9.2d, v11.2d \n\t" +" fcmp d31, #0.0 \n\t" // Don't load for zero beta. +BEQ(ZERO_BETA_C) +DLOADC_3V_C_FWD(0,1,2,x1,0,x7) +DLOADC_3V_C_FWD(3,4,5,x1,0,x7) +DLOADC_3V_C_FWD(6,7,8,x1,0,x7) +DLOADC_3V_C_FWD(9,10,11,x1,0,x7) +DSCALEA12V(12,13,14,15,16,17,18,19,20,21,22,23,0,1,2,3,4,5,6,7,8,9,10,11,31,0) +LABEL(ZERO_BETA_C) +#ifndef __clang__ +" cmp x12, #1 \n\t" +BRANCH(PRFM_END_C) +" prfm PLDL1KEEP, [%[a_next], #16*0] \n\t" +" prfm PLDL1KEEP, [%[a_next], #16*1] \n\t" +" prfm PLDL1STRM, [%[b_next], #16*0] \n\t" +" prfm PLDL1STRM, [%[b_next], #16*1] \n\t" +LABEL(PRFM_END_C) +#endif +DSTOREC_3V_C_FWD(12,13,14,x5,0,x7) +DSTOREC_3V_C_FWD(15,16,17,x5,0,x7) +DSTOREC_3V_C_FWD(18,19,20,x5,0,x7) +DSTOREC_3V_C_FWD(21,22,23,x5,0,x7) +// +// End of this microkernel. +LABEL(END_WRITE_MEM) +" \n\t" +" subs x12, x12, #1 \n\t" +BEQ(END_EXEC) +" \n\t" +" mov x8, #4 \n\t" +" madd x13, x7, x8, x13 \n\t" // Forward C's base address to the next logic panel. +" madd x10, x3, x8, x10 \n\t" // Forward B's base address to the next logic panel. +BRANCH(MILLIKER_MLOOP) +// +// End of execution. +LABEL(END_EXEC) +: +: [a] "m" (a), + [b] "m" (b), + [c] "m" (c), + [rs_a] "m" (rs_a), + [cs_b] "m" (cs_b), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + // In Clang, even "m"-passed parameter takes 1 register. + // Have to disable prefetching to pass compilation. +#ifndef __clang__ + [a_next] "r" (a_next), + [b_next] "r" (b_next), +#endif + [n_iter] "m" (n_iter), + [k_mker] "m" (k_mker), + [k_left] "m" (k_left), + [alpha] "m" (alpha), + [beta] "m" (beta) +: "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", + "x8", "x9", "x10","x11","x12","x13","x14", + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", + "v8", "v9", "v10","v11","v12","v13","v14","v15", + "v16","v17","v18","v19","v20","v21","v22","v23", + "v24","v25","v26","v27","v28","v29","v30","v31" + ); + +consider_edge_cases: + // TODO: Implement optimized kernel for this. + // + // Forward address. + b = b + n_iter * 4 * cs_b; + c = c + n_iter * 4 * cs_c; + if ( n_left >= 3 ) + { + bli_dgemmsup_rd_armv8a_asm_6x3 + ( + conja, conjb, 6, 3, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + b = b + 3 * cs_b; + c = c + 3 * cs_c; + n_left -= 3; + } + + if ( n_left ) + { + // n_left < 3; + // + // Slice in rows. + bli_dgemmsup_rd_armv8a_int_3x4 + ( + conja, conjb, 3, n_left, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + a = a + 3 * rs_a; + c = c + 3 * rs_c; + + bli_dgemmsup_rd_armv8a_int_3x4 + ( + conja, conjb, 3, n_left, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + } +} + diff --git a/kernels/armv8a/3/sup/bli_gemmsup_rv_armv8a_asm_d4x8m.c b/kernels/armv8a/3/sup/bli_gemmsup_rv_armv8a_asm_d4x8m.c new file mode 100644 index 0000000000..16001a73ce --- /dev/null +++ b/kernels/armv8a/3/sup/bli_gemmsup_rv_armv8a_asm_d4x8m.c @@ -0,0 +1,455 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "assert.h" + +GEMMSUP_KER_PROT( double, d, gemmsup_r_armv8a_ref2 ) + +// Label locality & misc. +#include "../armv8a_asm_utils.h" + +// Nanokernel operations. +#include "../armv8a_asm_d2x2.h" + +/* + * +---+ +---+ +---+ +---+ + * | 0 | | 2 | | 4 | | 6 | + * +---+ +---+ +---+ +---+ + * +---+ +---+ +---+ +---+ + * | 1 | | 3 | | 5 | | 7 | + * +---+ +---+ +---+ +---+ + */ +#define DGEMM_4X8_MKER_LOOP_PLAIN(C00,C01,C02,C03,C10,C11,C12,C13,C20,C21,C22,C23,C30,C31,C32,C33,A0,A1,B0,B1,B2,B3,BADDR,BSHIFT0,BSHIFT1,BSHIFT2,LOADNEXT) \ + DGEMM_2X2_NANOKERNEL(C00,C10,B0,A0) \ + DGEMM_2X2_NANOKERNEL(C20,C30,B0,A1) \ + DGEMM_LOAD1V_ ##LOADNEXT (B0,BADDR,BSHIFT0) \ + DGEMM_2X2_NANOKERNEL(C01,C11,B1,A0) \ + DGEMM_2X2_NANOKERNEL(C21,C31,B1,A1) \ + DGEMM_LOAD1V_ ##LOADNEXT (B1,BADDR,BSHIFT1) \ + DGEMM_2X2_NANOKERNEL(C02,C12,B2,A0) \ + DGEMM_2X2_NANOKERNEL(C22,C32,B2,A1) \ + DGEMM_LOAD1V_ ##LOADNEXT (B2,BADDR,BSHIFT2) \ + DGEMM_2X2_NANOKERNEL(C03,C13,B3,A0) \ + DGEMM_2X2_NANOKERNEL(C23,C33,B3,A1) + + +// Interleaving load or not. +#define DGEMM_LOAD1V_noload(V1,ADDR,IMM) +#define DGEMM_LOAD1V_load(V1,ADDR,IMM) \ +" ldr q"#V1", ["#ADDR", #"#IMM"] \n\t" + +// Prefetch C in the long direction. +#define DPRFMC_FWD(CADDR,DLONGC) \ +" prfm PLDL1KEEP, ["#CADDR"] \n\t" \ +" add "#CADDR", "#CADDR", "#DLONGC" \n\t" + +#define DLOADC_4V_R_FWD(C0,C1,C2,C3,CADDR,CSHIFT,RSC) \ + DLOAD4V(C0,C1,C2,C3,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#RSC" \n\t" +#define DSTOREC_4V_R_FWD(C0,C1,C2,C3,CADDR,CSHIFT,RSC) \ + DSTORE4V(C0,C1,C2,C3,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#RSC" \n\t" + +#define DLOADC_4V_C_FWD(C00,C10,C01,C11,CADDR,CSHIFT,CSC) \ + DLOAD2V(C00,C10,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#CSC" \n\t" \ + DLOAD2V(C01,C11,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#CSC" \n\t" +#define DSTOREC_4V_C_FWD(C00,C10,C01,C11,CADDR,CSHIFT,CSC) \ + DSTORE2V(C00,C10,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#CSC" \n\t" \ + DSTORE2V(C01,C11,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#CSC" \n\t" + + +/* + * 4x8 dgemmsup kernel with extending 1st dimension. + * + * Recommanded usage case: + * o 16 < (L1 cache latency) * (Num. FPU) < 25. + * o L1 cache has a bandwidth not too low (true in most cases). + * o (FMLA latency) * (Num. FPU) < 32 cycles (true in almost all cases). + */ +void bli_dgemmsup_rv_armv8a_asm_4x8m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + // Fixme: This uker has no dispatching for unalighed sizes. + // Currently it only serves as a dispatch target for other kernels + // and cannot be registered in configurations. + assert( n0 == 8 ); + + // LLVM has very bad routing ability for inline asm. + // Limit number of registers in case of Clang compilation. +#ifndef __clang__ + void* a_next = bli_auxinfo_next_a( data ); + void* b_next = bli_auxinfo_next_b( data ); +#endif + uint64_t ps_a = bli_auxinfo_ps_a( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_mker = k0 / 4; + uint64_t k_left = k0 % 4; + + int64_t m_iter = m0 / 4; + int64_t m_left = m0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + // uint64_t cs_b = cs_b0; + assert( cs_b0 == 1 ); + + if ( m_iter == 0 ) goto consider_edge_cases; + + __asm__ volatile + ( +" ldr x10, %[a] \n\t" +" ldr x13, %[c] \n\t" +" ldr x12, %[m_iter] \n\t" +" ldr x11, %[ps_a] \n\t" // Panel-skip of A. +" ldr x9, %[rs_a] \n\t" // Row-skip of A. +" ldr x2, %[cs_a] \n\t" // Column-skip of A. +" ldr x3, %[rs_b] \n\t" // Row-skip of B. +" \n\t" +" ldr x6, %[rs_c] \n\t" // Row-skip of C. +" ldr x7, %[cs_c] \n\t" // Column-skip of C. +" \n\t" +" \n\t" // Multiply some address skips by sizeof(double). +" lsl x11, x11, #3 \n\t" // ps_a +" lsl x9, x9, #3 \n\t" // rs_a +" lsl x2, x2, #3 \n\t" // cs_a +" lsl x3, x3, #3 \n\t" // rs_b +" lsl x6, x6, #3 \n\t" // rs_c +" lsl x7, x7, #3 \n\t" // cs_c +" \n\t" +" mov x1, x5 \n\t" +" cmp x7, #8 \n\t" // Prefetch column-strided C. +BEQ(C_PREFETCH_COLS) +// This prefetch will not cover further mker perts. Skip. +// +// DPRFMC_FWD(x1,x6) +// DPRFMC_FWD(x1,x6) +// DPRFMC_FWD(x1,x6) +// DPRFMC_FWD(x1,x6) +BRANCH(C_PREFETCH_END) +LABEL(C_PREFETCH_COLS) +DPRFMC_FWD(x1,x7) +DPRFMC_FWD(x1,x7) +DPRFMC_FWD(x1,x7) +DPRFMC_FWD(x1,x7) +DPRFMC_FWD(x1,x7) +DPRFMC_FWD(x1,x7) +DPRFMC_FWD(x1,x7) +DPRFMC_FWD(x1,x7) +LABEL(C_PREFETCH_END) +// +// Millikernel. +LABEL(MILLIKER_MLOOP) +" \n\t" +" mov x0, x10 \n\t" // Parameters to be reloaded +" mov x5, x13 \n\t" // within each millikernel loop. +" ldr x1, %[b] \n\t" +" ldr x4, %[k_mker] \n\t" +" ldr x8, %[k_left] \n\t" +" \n\t" +// Storage scheme: +// V[ 0:15] <- C +// V[16:23] <- A; Allowed latency: 48 cycles / # of FPUs. +// V[24:31] <- B; Allowed latency: 28 cycles / # of FPUs. +// Under this scheme, the following is defined: +#define DGEMM_4X8_MKER_LOOP_PLAIN_LOC(A0,A1,B0,B1,B2,B3,BADDR,BSHIFT0,BSHIFT1,BSHIFT2,LOADNEXT) \ + DGEMM_4X8_MKER_LOOP_PLAIN(0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,A0,A1,B0,B1,B2,B3,BADDR,BSHIFT0,BSHIFT1,BSHIFT2,LOADNEXT) +LABEL(LOAD_ABC) +" \n\t" // No-microkernel early return is a must +" cmp x4, #0 \n\t" // to avoid out-of-boundary read. +BEQ(CLEAR_CCOLS) +" \n\t" +" mov x14, x0 \n\t" // Load A. +" ld1 {v16.d}[0], [x14], x9 \n\t" +" ld1 {v16.d}[1], [x14], x9 \n\t" +" ld1 {v17.d}[0], [x14], x9 \n\t" +" ld1 {v17.d}[1], [x14], x9 \n\t" +" add x0, x0, x2 \n\t" +" mov x14, x0 \n\t" +" ld1 {v18.d}[0], [x14], x9 \n\t" +" ld1 {v18.d}[1], [x14], x9 \n\t" +" ld1 {v19.d}[0], [x14], x9 \n\t" +" ld1 {v19.d}[1], [x14], x9 \n\t" +" add x0, x0, x2 \n\t" +" mov x14, x0 \n\t" +" ld1 {v20.d}[0], [x14], x9 \n\t" +" ld1 {v20.d}[1], [x14], x9 \n\t" +" ld1 {v21.d}[0], [x14], x9 \n\t" +" ld1 {v21.d}[1], [x14], x9 \n\t" +" add x0, x0, x2 \n\t" +" mov x14, x0 \n\t" +" ld1 {v22.d}[0], [x14], x9 \n\t" +" ld1 {v22.d}[1], [x14], x9 \n\t" +" ld1 {v23.d}[0], [x14], x9 \n\t" +" ld1 {v23.d}[1], [x14], x9 \n\t" +" add x0, x0, x2 \n\t" +" \n\t" +" ldr q24, [x1, #16*0] \n\t" // Load B. +" ldr q25, [x1, #16*1] \n\t" +" ldr q26, [x1, #16*2] \n\t" +" ldr q27, [x1, #16*3] \n\t" +" add x1, x1, x3 \n\t" +" ldr q28, [x1, #16*0] \n\t" +" ldr q29, [x1, #16*1] \n\t" +" ldr q30, [x1, #16*2] \n\t" +" ldr q31, [x1, #16*3] \n\t" +" add x1, x1, x3 \n\t" +LABEL(CLEAR_CCOLS) +CLEAR8V(0,1,2,3,4,5,6,7) +CLEAR8V(8,9,10,11,12,13,14,15) +// No-microkernel early return, once again. +BEQ(K_LEFT_LOOP) +// +// Microkernel is defined here as: +#define DGEMM_4X8_MKER_LOOP_PLAIN_LOC_FWD(A0,A1,B0,B1,B2,B3) \ + DGEMM_4X8_MKER_LOOP_PLAIN_LOC(A0,A1,B0,B1,B2,B3,x1,0,16*1,16*2,load) \ + "ldr q"#B3", [x1, #16*3] \n\t" \ + "mov x14, x0 \n\t" \ + "ld1 {v"#A0".d}[0], [x14], x9 \n\t" \ + "ld1 {v"#A0".d}[1], [x14], x9 \n\t" \ + "ld1 {v"#A1".d}[0], [x14], x9 \n\t" \ + "ld1 {v"#A1".d}[1], [x14], x9 \n\t" \ + "add x0, x0, x2 \n\t" \ + "add x1, x1, x3 \n\t" +// Start microkernel loop. +LABEL(K_MKER_LOOP) +" \n\t" // Decrease counter before final replica. +" subs x4, x4, #1 \n\t" // Branch early to avoid reading excess mem. +BEQ(FIN_MKER_LOOP) +DGEMM_4X8_MKER_LOOP_PLAIN_LOC_FWD(16,17,24,25,26,27) +DGEMM_4X8_MKER_LOOP_PLAIN_LOC_FWD(18,19,28,29,30,31) +DGEMM_4X8_MKER_LOOP_PLAIN_LOC_FWD(20,21,24,25,26,27) +DGEMM_4X8_MKER_LOOP_PLAIN_LOC_FWD(22,23,28,29,30,31) +BRANCH(K_MKER_LOOP) +// +// Final microkernel loop. +LABEL(FIN_MKER_LOOP) +DGEMM_4X8_MKER_LOOP_PLAIN_LOC(16,17,24,25,26,27,x1,0,16*1,16*2,load) +" ldr q27, [x1, #16*3] \n\t" +" add x1, x1, x3 \n\t" +DGEMM_4X8_MKER_LOOP_PLAIN_LOC(18,19,28,29,30,31,x1,0,16*1,16*2,load) +" ldr q31, [x1, #16*3] \n\t" +" add x1, x1, x3 \n\t" +DGEMM_4X8_MKER_LOOP_PLAIN_LOC(20,21,24,25,26,27,xzr,-1,-1,-1,noload) +DGEMM_4X8_MKER_LOOP_PLAIN_LOC(22,23,28,29,30,31,xzr,-1,-1,-1,noload) +// +// Loops left behind microkernels. +LABEL(K_LEFT_LOOP) +" cmp x8, #0 \n\t" // End of exec. +BEQ(WRITE_MEM_PREP) +" mov x14, x0 \n\t" // Load A col. +" ld1 {v16.d}[0], [x14], x9 \n\t" +" ld1 {v16.d}[1], [x14], x9 \n\t" +" ld1 {v17.d}[0], [x14], x9 \n\t" +" ld1 {v17.d}[1], [x14], x9 \n\t" +" add x0, x0, x2 \n\t" +" ldr q24, [x1, #16*0] \n\t" // Load B row. +" ldr q25, [x1, #16*1] \n\t" +" ldr q26, [x1, #16*2] \n\t" +" ldr q27, [x1, #16*3] \n\t" +" add x1, x1, x3 \n\t" +" sub x8, x8, #1 \n\t" +DGEMM_4X8_MKER_LOOP_PLAIN_LOC(16,17,24,25,26,27,xzr,-1,-1,-1,noload) +BRANCH(K_LEFT_LOOP) +// +// Scale and write to memory. +LABEL(WRITE_MEM_PREP) +" ldr x4, %[alpha] \n\t" // Load alpha & beta (address). +" ldr x8, %[beta] \n\t" +" \n\t" +" mov x1, x5 \n\t" // C address for loading. +" \n\t" // C address for storing is x5 itself. +" cmp x7, #8 \n\t" // Check for column-storage. +BNE(WRITE_MEM_C) +// +// C storage in rows. +LABEL(WRITE_MEM_R) +" ld1r {v16.2d}, [x4] \n\t" // Load alpha & beta. +" ld1r {v17.2d}, [x8] \n\t" +" fcmp d17, #0.0 \n\t" +DSCALE8V(0,1,2,3,4,5,6,7,16,0) +DSCALE8V(8,9,10,11,12,13,14,15,16,0) +BEQ(ZERO_BETA_R) +DLOADC_4V_R_FWD(20,21,22,23,x1,0,x6) +DLOADC_4V_R_FWD(24,25,26,27,x1,0,x6) +DSCALEA8V(0,1,2,3,4,5,6,7,20,21,22,23,24,25,26,27,17,0) +// +DLOADC_4V_R_FWD(20,21,22,23,x1,0,x6) +DLOADC_4V_R_FWD(24,25,26,27,x1,0,x6) +DSCALEA8V(8,9,10,11,12,13,14,15,20,21,22,23,24,25,26,27,17,0) +LABEL(ZERO_BETA_R) +#ifndef __clang__ +" cmp x12, #1 \n\t" +BRANCH(PRFM_END_R) +" prfm PLDL1KEEP, [%[a_next], #16*0] \n\t" +" prfm PLDL1KEEP, [%[a_next], #16*1] \n\t" +" prfm PLDL1STRM, [%[b_next], #16*0] \n\t" +" prfm PLDL1STRM, [%[b_next], #16*1] \n\t" +LABEL(PRFM_END_R) +#endif +// +DSTOREC_4V_R_FWD(0,1,2,3,x5,0,x6) +DSTOREC_4V_R_FWD(4,5,6,7,x5,0,x6) +DSTOREC_4V_R_FWD(8,9,10,11,x5,0,x6) +DSTOREC_4V_R_FWD(12,13,14,15,x5,0,x6) +BRANCH(END_WRITE_MEM) +// +// C storage in columns. +LABEL(WRITE_MEM_C) +// In-register transpose. +" trn1 v16.2d, v0.2d, v4.2d \n\t" // Column 0. +" trn1 v17.2d, v8.2d, v12.2d \n\t" +" trn2 v18.2d, v0.2d, v4.2d \n\t" // Column 1. +" trn2 v19.2d, v8.2d, v12.2d \n\t" +" trn1 v20.2d, v1.2d, v5.2d \n\t" // Column 2. +" trn1 v21.2d, v9.2d, v13.2d \n\t" +" trn2 v22.2d, v1.2d, v5.2d \n\t" // Column 3. +" trn2 v23.2d, v9.2d, v13.2d \n\t" +" trn1 v24.2d, v2.2d, v6.2d \n\t" // Column 4. +" trn1 v25.2d, v10.2d, v14.2d \n\t" +" trn2 v26.2d, v2.2d, v6.2d \n\t" // Column 5. +" trn2 v27.2d, v10.2d, v14.2d \n\t" +" trn1 v28.2d, v3.2d, v7.2d \n\t" // Column 6. +" trn1 v29.2d, v11.2d, v15.2d \n\t" +" trn2 v30.2d, v3.2d, v7.2d \n\t" // Column 7. +" trn2 v31.2d, v11.2d, v15.2d \n\t" +" ld1r {v14.2d}, [x4] \n\t" // Load alpha & beta. +" ld1r {v15.2d}, [x8] \n\t" +DSCALE8V(16,17,18,19,20,21,22,23,14,0) +DSCALE8V(24,25,26,27,28,29,30,31,14,0) +DLOADC_4V_C_FWD(0,1,2,3,x1,0,x7) +DLOADC_4V_C_FWD(4,5,6,7,x1,0,x7) +DSCALEA8V(16,17,18,19,20,21,22,23,0,1,2,3,4,5,6,7,15,0) +// +DLOADC_4V_C_FWD(0,1,2,3,x1,0,x7) +DLOADC_4V_C_FWD(4,5,6,7,x1,0,x7) +DSCALEA8V(24,25,26,27,28,29,30,31,0,1,2,3,4,5,6,7,15,0) +#ifndef __clang__ +" cmp x12, #1 \n\t" +BRANCH(PRFM_END_C) +" prfm PLDL1KEEP, [%[a_next], #16*0] \n\t" +" prfm PLDL1KEEP, [%[a_next], #16*1] \n\t" +" prfm PLDL1STRM, [%[b_next], #16*0] \n\t" +" prfm PLDL1STRM, [%[b_next], #16*1] \n\t" +LABEL(PRFM_END_C) +#endif +// +DSTOREC_4V_C_FWD(16,17,18,19,x5,0,x7) +DSTOREC_4V_C_FWD(20,21,22,23,x5,0,x7) +DSTOREC_4V_C_FWD(24,25,26,27,x5,0,x7) +DSTOREC_4V_C_FWD(28,29,30,31,x5,0,x7) +// +// End of this microkernel. +LABEL(END_WRITE_MEM) +" \n\t" +" subs x12, x12, #1 \n\t" +BEQ(END_EXEC) +" \n\t" +" mov x8, #4 \n\t" +" madd x13, x6, x8, x13 \n\t" // Forward C's base address to the next logic panel. +" add x10, x10, x11 \n\t" // Forward A's base address to the next logic panel. +BRANCH(MILLIKER_MLOOP) +// +// End of execution. +LABEL(END_EXEC) +: +: [a] "m" (a), + [b] "m" (b), + [c] "m" (c), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a] "m" (ps_a), + [rs_b] "m" (rs_b), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + // In Clang, even "m"-passed parameter takes 1 register. + // Have to disable prefetching to pass compilation. +#ifndef __clang__ + [a_next] "r" (a_next), + [b_next] "r" (b_next), +#endif + [m_iter] "m" (m_iter), + [k_mker] "m" (k_mker), + [k_left] "m" (k_left), + [alpha] "m" (alpha), + [beta] "m" (beta) +: "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", + "x8", "x9", "x10","x11","x12","x13","x14", + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", + "v8", "v9", "v10","v11","v12","v13","v14","v15", + "v16","v17","v18","v19","v20","v21","v22","v23", + "v24","v25","v26","v27","v28","v29","v30","v31" + ); + +consider_edge_cases: + // TODO: Implement optimized kernel for this. + // + // Forward address. + a = a + m_iter * ps_a; + c = c + m_iter * 4 * rs_c; + if ( m_left ) + { + bli_dgemmsup_r_armv8a_ref2 + ( + conja, conjb, m_left, 8, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + } + +} + diff --git a/kernels/armv8a/3/sup/bli_gemmsup_rv_armv8a_asm_d4x8n.c b/kernels/armv8a/3/sup/bli_gemmsup_rv_armv8a_asm_d4x8n.c new file mode 100644 index 0000000000..43913cd38d --- /dev/null +++ b/kernels/armv8a/3/sup/bli_gemmsup_rv_armv8a_asm_d4x8n.c @@ -0,0 +1,458 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "assert.h" + +GEMMSUP_KER_PROT( double, d, gemmsup_r_armv8a_ref2 ) + +// Label locality & misc. +#include "../armv8a_asm_utils.h" + +// Nanokernel operations. +#include "../armv8a_asm_d2x2.h" + +/* + * +---+ +---+ +---+ +---+ + * | 0 | | 2 | | 4 | | 6 | + * +---+ +---+ +---+ +---+ + * +---+ +---+ +---+ +---+ + * | 1 | | 3 | | 5 | | 7 | + * +---+ +---+ +---+ +---+ + */ +#define DGEMM_4X8_MKER_LOOP_PLAIN(C00,C01,C02,C03,C10,C11,C12,C13,C20,C21,C22,C23,C30,C31,C32,C33,A0,A1,B0,B1,B2,B3,BADDR,BSHIFT0,BSHIFT1,BSHIFT2,LOADNEXT) \ + DGEMM_2X2_NANOKERNEL(C00,C10,B0,A0) \ + DGEMM_2X2_NANOKERNEL(C20,C30,B0,A1) \ + DGEMM_LOAD1V_ ##LOADNEXT (B0,BADDR,BSHIFT0) \ + DGEMM_2X2_NANOKERNEL(C01,C11,B1,A0) \ + DGEMM_2X2_NANOKERNEL(C21,C31,B1,A1) \ + DGEMM_LOAD1V_ ##LOADNEXT (B1,BADDR,BSHIFT1) \ + DGEMM_2X2_NANOKERNEL(C02,C12,B2,A0) \ + DGEMM_2X2_NANOKERNEL(C22,C32,B2,A1) \ + DGEMM_LOAD1V_ ##LOADNEXT (B2,BADDR,BSHIFT2) \ + DGEMM_2X2_NANOKERNEL(C03,C13,B3,A0) \ + DGEMM_2X2_NANOKERNEL(C23,C33,B3,A1) + + +// Interleaving load or not. +#define DGEMM_LOAD1V_noload(V1,ADDR,IMM) +#define DGEMM_LOAD1V_load(V1,ADDR,IMM) \ +" ldr q"#V1", ["#ADDR", #"#IMM"] \n\t" + +// Prefetch C in the long direction. +#define DPRFMC_FWD(CADDR,DLONGC) \ +" prfm PLDL1KEEP, ["#CADDR"] \n\t" \ +" add "#CADDR", "#CADDR", "#DLONGC" \n\t" + +#define DLOADC_4V_R_FWD(C0,C1,C2,C3,CADDR,CSHIFT,RSC) \ + DLOAD4V(C0,C1,C2,C3,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#RSC" \n\t" +#define DSTOREC_4V_R_FWD(C0,C1,C2,C3,CADDR,CSHIFT,RSC) \ + DSTORE4V(C0,C1,C2,C3,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#RSC" \n\t" + +#define DLOADC_4V_C_FWD(C00,C10,C01,C11,CADDR,CSHIFT,CSC) \ + DLOAD2V(C00,C10,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#CSC" \n\t" \ + DLOAD2V(C01,C11,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#CSC" \n\t" +#define DSTOREC_4V_C_FWD(C00,C10,C01,C11,CADDR,CSHIFT,CSC) \ + DSTORE2V(C00,C10,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#CSC" \n\t" \ + DSTORE2V(C01,C11,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#CSC" \n\t" + + +/* + * 4x8 dgemmsup kernel with extending 2nd dimension. + * + * Recommanded usage case: + * o 16 < (L1 cache latency) * (Num. FPU) < 25. + * o L1 cache has a bandwidth not too low (true in most cases). + * o (FMLA latency) * (Num. FPU) < 32 cycles (true in almost all cases). + */ +void bli_dgemmsup_rv_armv8a_asm_4x8n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + // Fixme: This uker has no dispatching for unalighed sizes. + // Currently it only serves as a dispatch target for other kernels + // and cannot be registered in configurations. + assert( m0 == 4 ); + + // LLVM has very bad routing ability for inline asm. + // Limit number of registers in case of Clang compilation. +#ifndef __clang__ + void* a_next = bli_auxinfo_next_a( data ); + void* b_next = bli_auxinfo_next_b( data ); +#endif + uint64_t ps_b = bli_auxinfo_ps_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_mker = k0 / 4; + uint64_t k_left = k0 % 4; + + int64_t n_iter = n0 / 8; + int64_t n_left = n0 % 8; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + // uint64_t cs_b = cs_b0; + assert( cs_b0 == 1 ); + + if ( n_iter == 0 ) goto consider_edge_cases; + + __asm__ volatile + ( +" ldr x10, %[b] \n\t" +" ldr x13, %[c] \n\t" +" ldr x12, %[n_iter] \n\t" +" ldr x11, %[ps_b] \n\t" // Panel-skip of B. +" ldr x3, %[rs_b] \n\t" // Row-skip of B. +" ldr x9, %[rs_a] \n\t" // Row-skip of A. +" ldr x2, %[cs_a] \n\t" // Column-skip of A. +" \n\t" +" ldr x6, %[rs_c] \n\t" // Row-skip of C. +" ldr x7, %[cs_c] \n\t" // Column-skip of C. +" \n\t" +" \n\t" // Multiply some address skips by sizeof(double). +" lsl x11, x11, #3 \n\t" // ps_b +" lsl x9, x9, #3 \n\t" // rs_a +" lsl x2, x2, #3 \n\t" // cs_a +" lsl x3, x3, #3 \n\t" // rs_b +" lsl x6, x6, #3 \n\t" // rs_c +" lsl x7, x7, #3 \n\t" // cs_c +" \n\t" +" mov x1, x5 \n\t" +" cmp x7, #8 \n\t" // Prefetch column-strided C. +BEQ(C_PREFETCH_COLS) +DPRFMC_FWD(x1,x6) +DPRFMC_FWD(x1,x6) +DPRFMC_FWD(x1,x6) +DPRFMC_FWD(x1,x6) +BRANCH(C_PREFETCH_END) +LABEL(C_PREFETCH_COLS) +// This prefetch will not cover further mker perts. Skip. +// +// DPRFMC_FWD(x1,x7) +// DPRFMC_FWD(x1,x7) +// DPRFMC_FWD(x1,x7) +// DPRFMC_FWD(x1,x7) +// DPRFMC_FWD(x1,x7) +// DPRFMC_FWD(x1,x7) +// DPRFMC_FWD(x1,x7) +// DPRFMC_FWD(x1,x7) +LABEL(C_PREFETCH_END) +// +// Millikernel. +LABEL(MILLIKER_MLOOP) +" \n\t" +" mov x1, x10 \n\t" // Parameters to be reloaded +" mov x5, x13 \n\t" // within each millikernel loop. +" ldr x0, %[a] \n\t" +" ldr x4, %[k_mker] \n\t" +" ldr x8, %[k_left] \n\t" +" \n\t" +// Storage scheme: +// V[ 0:15] <- C +// V[16:23] <- A; Allowed latency: 48 cycles / # of FPUs. +// V[24:31] <- B; Allowed latency: 28 cycles / # of FPUs. +// Under this scheme, the following is defined: +#define DGEMM_4X8_MKER_LOOP_PLAIN_LOC(A0,A1,B0,B1,B2,B3,BADDR,BSHIFT0,BSHIFT1,BSHIFT2,LOADNEXT) \ + DGEMM_4X8_MKER_LOOP_PLAIN(0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,A0,A1,B0,B1,B2,B3,BADDR,BSHIFT0,BSHIFT1,BSHIFT2,LOADNEXT) +LABEL(LOAD_ABC) +" \n\t" // No-microkernel early return is a must +" cmp x4, #0 \n\t" // to avoid out-of-boundary read. +BEQ(CLEAR_CCOLS) +" \n\t" +" ldr q24, [x1, #16*0] \n\t" // Load B first. +" ldr q25, [x1, #16*1] \n\t" +" ldr q26, [x1, #16*2] \n\t" +" ldr q27, [x1, #16*3] \n\t" +" add x1, x1, x3 \n\t" +" ldr q28, [x1, #16*0] \n\t" +" ldr q29, [x1, #16*1] \n\t" +" ldr q30, [x1, #16*2] \n\t" +" ldr q31, [x1, #16*3] \n\t" +" add x1, x1, x3 \n\t" +" \n\t" +" mov x14, x0 \n\t" // Load A. +" ld1 {v16.d}[0], [x14], x9 \n\t" // We want A to be kept in L1. +" ld1 {v16.d}[1], [x14], x9 \n\t" +" ld1 {v17.d}[0], [x14], x9 \n\t" +" ld1 {v17.d}[1], [x14], x9 \n\t" +" add x0, x0, x2 \n\t" +" mov x14, x0 \n\t" +" ld1 {v18.d}[0], [x14], x9 \n\t" +" ld1 {v18.d}[1], [x14], x9 \n\t" +" ld1 {v19.d}[0], [x14], x9 \n\t" +" ld1 {v19.d}[1], [x14], x9 \n\t" +" add x0, x0, x2 \n\t" +" mov x14, x0 \n\t" +" ld1 {v20.d}[0], [x14], x9 \n\t" +" ld1 {v20.d}[1], [x14], x9 \n\t" +" ld1 {v21.d}[0], [x14], x9 \n\t" +" ld1 {v21.d}[1], [x14], x9 \n\t" +" add x0, x0, x2 \n\t" +" mov x14, x0 \n\t" +" ld1 {v22.d}[0], [x14], x9 \n\t" +" ld1 {v22.d}[1], [x14], x9 \n\t" +" ld1 {v23.d}[0], [x14], x9 \n\t" +" ld1 {v23.d}[1], [x14], x9 \n\t" +" add x0, x0, x2 \n\t" +LABEL(CLEAR_CCOLS) +CLEAR8V(0,1,2,3,4,5,6,7) +CLEAR8V(8,9,10,11,12,13,14,15) +// No-microkernel early return, once again. +BEQ(K_LEFT_LOOP) +// +// Microkernel is defined here as: +#define DGEMM_4X8_MKER_LOOP_PLAIN_LOC_FWD(A0,A1,B0,B1,B2,B3) \ + DGEMM_4X8_MKER_LOOP_PLAIN_LOC(A0,A1,B0,B1,B2,B3,x1,0,16*1,16*2,load) \ + "ldr q"#B3", [x1, #16*3] \n\t" \ + "mov x14, x0 \n\t" \ + "ld1 {v"#A0".d}[0], [x14], x9 \n\t" \ + "ld1 {v"#A0".d}[1], [x14], x9 \n\t" \ + "ld1 {v"#A1".d}[0], [x14], x9 \n\t" \ + "ld1 {v"#A1".d}[1], [x14], x9 \n\t" \ + "add x0, x0, x2 \n\t" \ + "add x1, x1, x3 \n\t" +// Start microkernel loop. +LABEL(K_MKER_LOOP) +" \n\t" // Decrease counter before final replica. +" subs x4, x4, #1 \n\t" // Branch early to avoid reading excess mem. +BEQ(FIN_MKER_LOOP) +DGEMM_4X8_MKER_LOOP_PLAIN_LOC_FWD(16,17,24,25,26,27) +DGEMM_4X8_MKER_LOOP_PLAIN_LOC_FWD(18,19,28,29,30,31) +DGEMM_4X8_MKER_LOOP_PLAIN_LOC_FWD(20,21,24,25,26,27) +DGEMM_4X8_MKER_LOOP_PLAIN_LOC_FWD(22,23,28,29,30,31) +BRANCH(K_MKER_LOOP) +// +// Final microkernel loop. +LABEL(FIN_MKER_LOOP) +DGEMM_4X8_MKER_LOOP_PLAIN_LOC(16,17,24,25,26,27,x1,0,16*1,16*2,load) +" ldr q27, [x1, #16*3] \n\t" +" add x1, x1, x3 \n\t" +DGEMM_4X8_MKER_LOOP_PLAIN_LOC(18,19,28,29,30,31,x1,0,16*1,16*2,load) +" ldr q31, [x1, #16*3] \n\t" +" add x1, x1, x3 \n\t" +DGEMM_4X8_MKER_LOOP_PLAIN_LOC(20,21,24,25,26,27,xzr,-1,-1,-1,noload) +DGEMM_4X8_MKER_LOOP_PLAIN_LOC(22,23,28,29,30,31,xzr,-1,-1,-1,noload) +// +// Loops left behind microkernels. +LABEL(K_LEFT_LOOP) +" cmp x8, #0 \n\t" // End of exec. +BEQ(WRITE_MEM_PREP) +" ldr q24, [x1, #16*0] \n\t" // Load B row. +" ldr q25, [x1, #16*1] \n\t" +" ldr q26, [x1, #16*2] \n\t" +" ldr q27, [x1, #16*3] \n\t" +" add x1, x1, x3 \n\t" +" mov x14, x0 \n\t" // Load A col. +" ld1 {v16.d}[0], [x14], x9 \n\t" +" ld1 {v16.d}[1], [x14], x9 \n\t" +" ld1 {v17.d}[0], [x14], x9 \n\t" +" ld1 {v17.d}[1], [x14], x9 \n\t" +" add x0, x0, x2 \n\t" +" sub x8, x8, #1 \n\t" +DGEMM_4X8_MKER_LOOP_PLAIN_LOC(16,17,24,25,26,27,xzr,-1,-1,-1,noload) +BRANCH(K_LEFT_LOOP) +// +// Scale and write to memory. +LABEL(WRITE_MEM_PREP) +" ldr x4, %[alpha] \n\t" // Load alpha & beta (address). +" ldr x8, %[beta] \n\t" +" \n\t" +" mov x1, x5 \n\t" // C address for loading. +" \n\t" // C address for storing is x5 itself. +" cmp x7, #8 \n\t" // Check for column-storage. +BNE(WRITE_MEM_C) +// +// C storage in rows. +LABEL(WRITE_MEM_R) +" ld1r {v16.2d}, [x4] \n\t" // Load alpha & beta. +" ld1r {v17.2d}, [x8] \n\t" +" fcmp d17, #0.0 \n\t" +DSCALE8V(0,1,2,3,4,5,6,7,16,0) +DSCALE8V(8,9,10,11,12,13,14,15,16,0) +BEQ(ZERO_BETA_R) +DLOADC_4V_R_FWD(20,21,22,23,x1,0,x6) +DLOADC_4V_R_FWD(24,25,26,27,x1,0,x6) +DSCALEA8V(0,1,2,3,4,5,6,7,20,21,22,23,24,25,26,27,17,0) +// +DLOADC_4V_R_FWD(20,21,22,23,x1,0,x6) +DLOADC_4V_R_FWD(24,25,26,27,x1,0,x6) +DSCALEA8V(8,9,10,11,12,13,14,15,20,21,22,23,24,25,26,27,17,0) +LABEL(ZERO_BETA_R) +#ifndef __clang__ +" cmp x12, #1 \n\t" +BRANCH(PRFM_END_R) +" prfm PLDL1KEEP, [%[a_next], #16*0] \n\t" +" prfm PLDL1KEEP, [%[a_next], #16*1] \n\t" +" prfm PLDL1STRM, [%[b_next], #16*0] \n\t" +" prfm PLDL1STRM, [%[b_next], #16*1] \n\t" +LABEL(PRFM_END_R) +#endif +// +DSTOREC_4V_R_FWD(0,1,2,3,x5,0,x6) +DSTOREC_4V_R_FWD(4,5,6,7,x5,0,x6) +DSTOREC_4V_R_FWD(8,9,10,11,x5,0,x6) +DSTOREC_4V_R_FWD(12,13,14,15,x5,0,x6) +BRANCH(END_WRITE_MEM) +// +// C storage in columns. +LABEL(WRITE_MEM_C) +// In-register transpose. +" trn1 v16.2d, v0.2d, v4.2d \n\t" // Column 0. +" trn1 v17.2d, v8.2d, v12.2d \n\t" +" trn2 v18.2d, v0.2d, v4.2d \n\t" // Column 1. +" trn2 v19.2d, v8.2d, v12.2d \n\t" +" trn1 v20.2d, v1.2d, v5.2d \n\t" // Column 2. +" trn1 v21.2d, v9.2d, v13.2d \n\t" +" trn2 v22.2d, v1.2d, v5.2d \n\t" // Column 3. +" trn2 v23.2d, v9.2d, v13.2d \n\t" +" trn1 v24.2d, v2.2d, v6.2d \n\t" // Column 4. +" trn1 v25.2d, v10.2d, v14.2d \n\t" +" trn2 v26.2d, v2.2d, v6.2d \n\t" // Column 5. +" trn2 v27.2d, v10.2d, v14.2d \n\t" +" trn1 v28.2d, v3.2d, v7.2d \n\t" // Column 6. +" trn1 v29.2d, v11.2d, v15.2d \n\t" +" trn2 v30.2d, v3.2d, v7.2d \n\t" // Column 7. +" trn2 v31.2d, v11.2d, v15.2d \n\t" +" ld1r {v14.2d}, [x4] \n\t" // Load alpha & beta. +" ld1r {v15.2d}, [x8] \n\t" +DSCALE8V(16,17,18,19,20,21,22,23,14,0) +DSCALE8V(24,25,26,27,28,29,30,31,14,0) +DLOADC_4V_C_FWD(0,1,2,3,x1,0,x7) +DLOADC_4V_C_FWD(4,5,6,7,x1,0,x7) +DSCALEA8V(16,17,18,19,20,21,22,23,0,1,2,3,4,5,6,7,15,0) +// +DLOADC_4V_C_FWD(0,1,2,3,x1,0,x7) +DLOADC_4V_C_FWD(4,5,6,7,x1,0,x7) +DSCALEA8V(24,25,26,27,28,29,30,31,0,1,2,3,4,5,6,7,15,0) +#ifndef __clang__ +" cmp x12, #1 \n\t" +BRANCH(PRFM_END_C) +" prfm PLDL1KEEP, [%[a_next], #16*0] \n\t" +" prfm PLDL1KEEP, [%[a_next], #16*1] \n\t" +" prfm PLDL1STRM, [%[b_next], #16*0] \n\t" +" prfm PLDL1STRM, [%[b_next], #16*1] \n\t" +LABEL(PRFM_END_C) +#endif +// +DSTOREC_4V_C_FWD(16,17,18,19,x5,0,x7) +DSTOREC_4V_C_FWD(20,21,22,23,x5,0,x7) +DSTOREC_4V_C_FWD(24,25,26,27,x5,0,x7) +DSTOREC_4V_C_FWD(28,29,30,31,x5,0,x7) +// +// End of this microkernel. +LABEL(END_WRITE_MEM) +" \n\t" +" subs x12, x12, #1 \n\t" +BEQ(END_EXEC) +" \n\t" +" mov x8, #8 \n\t" +" madd x13, x7, x8, x13 \n\t" // Forward C's base address to the next logic panel. +" add x10, x10, x11 \n\t" // Forward B's base address to the next logic panel. +BRANCH(MILLIKER_MLOOP) +// +// End of execution. +LABEL(END_EXEC) +: +: [a] "m" (a), + [b] "m" (b), + [c] "m" (c), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_b] "m" (ps_b), + [rs_b] "m" (rs_b), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + // In Clang, even "m"-passed parameter takes 1 register. + // Have to disable prefetching to pass compilation. +#ifndef __clang__ + [a_next] "r" (a_next), + [b_next] "r" (b_next), +#endif + [n_iter] "m" (n_iter), + [k_mker] "m" (k_mker), + [k_left] "m" (k_left), + [alpha] "m" (alpha), + [beta] "m" (beta) +: "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", + "x8", "x9", "x10","x11","x12","x13","x14", + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", + "v8", "v9", "v10","v11","v12","v13","v14","v15", + "v16","v17","v18","v19","v20","v21","v22","v23", + "v24","v25","v26","v27","v28","v29","v30","v31" + ); + +consider_edge_cases: + // TODO: Implement optimized kernel for this. + // + // Forward address. + b = b + n_iter * ps_b; + c = c + n_iter * 8 * cs_c; + if ( n_left ) + { + auxinfo_t data_d6x4mn = *data; + bli_auxinfo_set_ps_b( 4 * cs_b0, &data_d6x4mn ); + + bli_dgemmsup_rv_armv8a_int_6x4mn + ( + conja, conjb, 4, n_left, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, &data_d6x4mn, cntx + ); + } + +} + diff --git a/kernels/armv8a/3/sup/bli_gemmsup_rv_armv8a_asm_d6x8m.c b/kernels/armv8a/3/sup/bli_gemmsup_rv_armv8a_asm_d6x8m.c new file mode 100644 index 0000000000..3100112d3f --- /dev/null +++ b/kernels/armv8a/3/sup/bli_gemmsup_rv_armv8a_asm_d6x8m.c @@ -0,0 +1,575 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2021, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ + +#include "blis.h" +#include "assert.h" + +GEMMSUP_KER_PROT( double, d, gemmsup_r_armv8a_ref2 ) + +// Label locality & misc. +#include "../armv8a_asm_utils.h" + +// Nanokernel operations. +#include "../armv8a_asm_d2x2.h" + +/* Order of row-major DGEMM_6x8's execution in 2x2 blocks: + * + * +---+ +---+ +---+ +---+ + * | 0 | | 1 | | 6 | | 7 | + * +---+ +---+ +---+ +---+ + * +---+ +---+ +---+ +---+ + * | 2 | | 3 | | 8 | | 9 | + * +---+ +---+ +---+ +---+ + * +---+ +---+ +---+ +---+ + * | 4 | | 5 | | 10| | 11| + * +---+ +---+ +---+ +---+ + * + */ +#define DGEMM_6X8_MKER_LOOP_PLAIN(C00,C01,C02,C03,C10,C11,C12,C13,C20,C21,C22,C23,C30,C31,C32,C33,C40,C41,C42,C43,C50,C51,C52,C53,A0,A1,A2,B0,B1,B2,B3,AELEMADDR,AELEMST,BADDR,BSHIFT,LOADNEXT) \ + DGEMM_2X2_NANOKERNEL(C00,C10,B0,A0) \ + DGEMM_2X2_NANOKERNEL(C01,C11,B1,A0) \ + DGEMM_2X2_NANOKERNEL(C20,C30,B0,A1) \ + DGEMM_2X2_NANOKERNEL(C21,C31,B1,A1) \ + DGEMM_2X2_NANOKERNEL(C40,C50,B0,A2) \ + DGEMM_2X2_NANOKERNEL(C41,C51,B1,A2) \ + DGEMM_LOAD2V_ ##LOADNEXT (B0,B1,BADDR,BSHIFT) \ + DGEMM_2X2_NANOKERNEL(C02,C12,B2,A0) \ + DGEMM_2X2_NANOKERNEL(C03,C13,B3,A0) \ + DGEMM_LOAD1V_G_ ##LOADNEXT (A0,AELEMADDR,AELEMST) \ + DGEMM_2X2_NANOKERNEL(C22,C32,B2,A1) \ + DGEMM_2X2_NANOKERNEL(C23,C33,B3,A1) \ + DGEMM_LOAD1V_G_ ##LOADNEXT (A1,AELEMADDR,AELEMST) \ + DGEMM_2X2_NANOKERNEL(C42,C52,B2,A2) \ + DGEMM_2X2_NANOKERNEL(C43,C53,B3,A2) + +// Interleaving load or not. +#define DGEMM_LOAD1V_noload(V1,ADDR,IMM) +#define DGEMM_LOAD1V_load(V1,ADDR,IMM) \ +" ldr q"#V1", ["#ADDR", #"#IMM"] \n\t" + +#define DGEMM_LOAD2V_noload(V1,V2,ADDR,IMM) +#define DGEMM_LOAD2V_load(V1,V2,ADDR,IMM) \ + DGEMM_LOAD1V_load(V1,ADDR,IMM) \ + DGEMM_LOAD1V_load(V2,ADDR,IMM+16) + +#define DGEMM_LOAD1V_G_noload(V1,ADDR,ST) +#define DGEMM_LOAD1V_G_load(V1,ADDR,ST) \ +" ld1 {v"#V1".d}[0], ["#ADDR"], "#ST" \n\t" \ +" ld1 {v"#V1".d}[1], ["#ADDR"], "#ST" \n\t" + +// Prefetch C in the long direction. +#define DPRFMC_FWD(CADDR,DLONGC) \ +" prfm PLDL1KEEP, ["#CADDR"] \n\t" \ +" add "#CADDR", "#CADDR", "#DLONGC" \n\t" + +// For row-storage of C. +#define DLOADC_4V_R_FWD(C0,C1,C2,C3,CADDR,CSHIFT,RSC) \ + DLOAD4V(C0,C1,C2,C3,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#RSC" \n\t" +#define DSTOREC_4V_R_FWD(C0,C1,C2,C3,CADDR,CSHIFT,RSC) \ + DSTORE4V(C0,C1,C2,C3,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#RSC" \n\t" + +// For column-storage of C. +#define DLOADC_3V_C_FWD(C0,C1,C2,CADDR,CSHIFT,CSC) \ + DLOAD2V(C0,C1,CADDR,CSHIFT) \ + DLOAD1V(C2,CADDR,CSHIFT+32) \ +" add "#CADDR", "#CADDR", "#CSC" \n\t" +#define DSTOREC_3V_C_FWD(C0,C1,C2,CADDR,CSHIFT,CSC) \ + DSTORE2V(C0,C1,CADDR,CSHIFT) \ + DSTORE1V(C2,CADDR,CSHIFT+32) \ +" add "#CADDR", "#CADDR", "#CSC" \n\t" + +#define DSCALE6V(V0,V1,V2,V3,V4,V5,A,IDX) \ + DSCALE4V(V0,V1,V2,V3,A,IDX) \ + DSCALE2V(V4,V5,A,IDX) +#define DSCALEA6V(D0,D1,D2,D3,D4,D5,S0,S1,S2,S3,S4,S5,A,IDX) \ + DSCALEA4V(D0,D1,D2,D3,S0,S1,S2,S3,A,IDX) \ + DSCALEA2V(D4,D5,S4,S5,A,IDX) + + +/* + * 6x8 dgemmsup kernel with extending 1st dimension. + * + * Recommanded usage case: (L1 cache latency) * (Num. FPU) < 17 cycles. + * + * Calls 4x8 for edge cases. + */ +void bli_dgemmsup_rv_armv8a_asm_6x8m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + if ( n0 != 8 ) + { + if ( n0 < 8 ) + { + for ( ; n0 >= 4; n0 -= 4 ) + { + dgemmsup_ker_ft ukr_fp; + auxinfo_t data_d8xkm = *data; + if ( bli_auxinfo_ps_a( data ) == 6 * rs_a0 ) + { + // Use 8x4 Asm kernel for the unpacked case. + bli_auxinfo_set_ps_a( 8 * rs_a0, &data_d8xkm ); + ukr_fp = bli_dgemmsup_rv_armv8a_asm_8x4m; + } + else + { + // Cannot change dimension for m when A is packed. + ukr_fp = bli_dgemmsup_rv_armv8a_int_6x4mn; + } + + ukr_fp + ( + conja, conjb, m0, 4, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, &data_d8xkm, cntx + ); + b += 4 * cs_b0; + c += 4 * cs_c0; + } + if ( n0 > 0 ) + { + bli_dgemmsup_rv_armv8a_int_6x4mn + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + } + } + else + { + assert( FALSE ); + } + return; + } + + // LLVM has very bad routing ability for inline asm. + // Limit number of registers in case of Clang compilation. +#ifndef __clang__ + void* a_next = bli_auxinfo_next_a( data ); + void* b_next = bli_auxinfo_next_b( data ); +#endif + uint64_t ps_a = bli_auxinfo_ps_a( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_mker = k0 / 4; + uint64_t k_left = k0 % 4; + + int64_t m_iter = m0 / 6; + int64_t m_left = m0 % 6; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + // uint64_t cs_b = cs_b0; + assert( cs_b0 == 1 ); + + if ( m_iter == 0 ) goto consider_edge_cases; + + __asm__ volatile + ( +" ldr x10, %[a] \n\t" +" ldr x13, %[c] \n\t" +" ldr x12, %[m_iter] \n\t" +" ldr x11, %[ps_a] \n\t" // Panel-skip of A. +" ldr x9, %[rs_a] \n\t" // Row-skip of A. +" ldr x2, %[cs_a] \n\t" // Column-skip of A. +" ldr x3, %[rs_b] \n\t" // Row-skip of B. +" \n\t" +" ldr x6, %[rs_c] \n\t" // Row-skip of C. +" ldr x7, %[cs_c] \n\t" // Column-skip of C. +" \n\t" +" \n\t" // Multiply some address skips by sizeof(double). +" lsl x11, x11, #3 \n\t" // ps_a +" lsl x9, x9, #3 \n\t" // rs_a +" lsl x2, x2, #3 \n\t" // cs_a +" lsl x3, x3, #3 \n\t" // rs_b +" lsl x6, x6, #3 \n\t" // rs_c +" lsl x7, x7, #3 \n\t" // cs_c +" \n\t" +" mov x1, x5 \n\t" +" cmp x7, #8 \n\t" // Prefetch column-strided C. +BEQ(C_PREFETCH_COLS) +// This prefetch will not cover further mker perts. Skip. +// +// DPRFMC_FWD(x1,x6) +// DPRFMC_FWD(x1,x6) +// DPRFMC_FWD(x1,x6) +// DPRFMC_FWD(x1,x6) +// DPRFMC_FWD(x1,x6) +// DPRFMC_FWD(x1,x6) +BRANCH(C_PREFETCH_END) +LABEL(C_PREFETCH_COLS) +DPRFMC_FWD(x1,x7) +DPRFMC_FWD(x1,x7) +DPRFMC_FWD(x1,x7) +DPRFMC_FWD(x1,x7) +DPRFMC_FWD(x1,x7) +DPRFMC_FWD(x1,x7) +DPRFMC_FWD(x1,x7) +DPRFMC_FWD(x1,x7) +LABEL(C_PREFETCH_END) +// +// Millikernel. +LABEL(MILLIKER_MLOOP) +" \n\t" +" mov x0, x10 \n\t" // Parameters to be reloaded +" mov x5, x13 \n\t" // within each millikernel loop. +" ldr x1, %[b] \n\t" +" ldr x4, %[k_mker] \n\t" +" ldr x8, %[k_left] \n\t" +" \n\t" +// Storage scheme: +// V[ 0:23] <- C +// V[24:27] <- A +// V[28:31] <- B +// Under this scheme, the following is defined: +#define DGEMM_6X8_MKER_LOOP_PLAIN_LOC(A0,A1,A2,B0,B1,B2,B3,AELEMADDR,AELEMST,BADDR,BSHIFT,LOADNEXT) \ + DGEMM_6X8_MKER_LOOP_PLAIN(0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,A0,A1,A2,B0,B1,B2,B3,AELEMADDR,AELEMST,BADDR,BSHIFT,LOADNEXT) +// Load from memory. +LABEL(LOAD_ABC) +" \n\t" // No-microkernel early return is a must +" cmp x4, #0 \n\t" // to avoid out-of-boundary read. +BEQ(CLEAR_CCOLS) +" \n\t" +" mov x14, x0 \n\t" // Load A. +" ld1 {v24.d}[0], [x14], x9 \n\t" +" ld1 {v24.d}[1], [x14], x9 \n\t" +" ld1 {v25.d}[0], [x14], x9 \n\t" +" ld1 {v25.d}[1], [x14], x9 \n\t" +" ld1 {v26.d}[0], [x14], x9 \n\t" +" ld1 {v26.d}[1], [x14], x9 \n\t" +" add x0, x0, x2 \n\t" +" mov x14, x0 \n\t" +" ld1 {v27.d}[0], [x14], x9 \n\t" +" ld1 {v27.d}[1], [x14], x9 \n\t" +" \n\t" +" ldr q28, [x1, #16*0] \n\t" // Load B. +" ldr q29, [x1, #16*1] \n\t" +" ldr q30, [x1, #16*2] \n\t" +" ldr q31, [x1, #16*3] \n\t" +" add x1, x1, x3 \n\t" +LABEL(CLEAR_CCOLS) +CLEAR8V(0,1,2,3,4,5,6,7) +CLEAR8V(8,9,10,11,12,13,14,15) +CLEAR8V(16,17,18,19,20,21,22,23) +// No-microkernel early return, once again. +BEQ(K_LEFT_LOOP) +// +// Microkernel is defined here as: +#define DGEMM_6X8_MKER_LOOP_PLAIN_LOC_FWD(A0,A1,A2,B0,B1,B2,B3) \ + DGEMM_6X8_MKER_LOOP_PLAIN_LOC(A0,A1,A2,B0,B1,B2,B3,x14,x9,x1,0,load) \ + "add x0, x0, x2 \n\t" \ + "mov x14, x0 \n\t" \ + "ld1 {v"#A2".d}[0], [x14], x9 \n\t" \ + "ld1 {v"#A2".d}[1], [x14], x9 \n\t" \ + "ldr q"#B2", [x1, #16*2] \n\t" \ + "ldr q"#B3", [x1, #16*3] \n\t" \ + "add x1, x1, x3 \n\t" +// Start microkernel loop. +LABEL(K_MKER_LOOP) +DGEMM_6X8_MKER_LOOP_PLAIN_LOC_FWD(24,25,26,28,29,30,31) +DGEMM_6X8_MKER_LOOP_PLAIN_LOC_FWD(27,24,25,28,29,30,31) +" \n\t" // Decrease counter before final replica. +" subs x4, x4, #1 \n\t" // Branch early to avoid reading excess mem. +BEQ(FIN_MKER_LOOP) +DGEMM_6X8_MKER_LOOP_PLAIN_LOC_FWD(26,27,24,28,29,30,31) +DGEMM_6X8_MKER_LOOP_PLAIN_LOC_FWD(25,26,27,28,29,30,31) +BRANCH(K_MKER_LOOP) +// +// Final microkernel loop. +LABEL(FIN_MKER_LOOP) +DGEMM_6X8_MKER_LOOP_PLAIN_LOC(26,27,24,28,29,30,31,x14,x9,x1,0,load) +" add x0, x0, x2 \n\t" +" ldr q30, [x1, #16*2] \n\t" +" ldr q31, [x1, #16*3] \n\t" +" add x1, x1, x3 \n\t" +DGEMM_6X8_MKER_LOOP_PLAIN_LOC(25,26,27,28,29,30,31,xzr,-1,xzr,-1,noload) +// +// Loops left behind microkernels. +LABEL(K_LEFT_LOOP) +" cmp x8, #0 \n\t" // End of exec. +BEQ(WRITE_MEM_PREP) +" mov x14, x0 \n\t" +" ld1 {v24.d}[0], [x14], x9 \n\t" // Load A col. +" ld1 {v24.d}[1], [x14], x9 \n\t" +" ld1 {v25.d}[0], [x14], x9 \n\t" +" ld1 {v25.d}[1], [x14], x9 \n\t" +" ld1 {v26.d}[0], [x14], x9 \n\t" +" ld1 {v26.d}[1], [x14], x9 \n\t" +" add x0, x0, x2 \n\t" +" ldr q28, [x1, #16*0] \n\t" // Load B row. +" ldr q29, [x1, #16*1] \n\t" +" ldr q30, [x1, #16*2] \n\t" +" ldr q31, [x1, #16*3] \n\t" +" add x1, x1, x3 \n\t" +" sub x8, x8, #1 \n\t" +DGEMM_6X8_MKER_LOOP_PLAIN_LOC(24,25,26,28,29,30,31,xzr,-1,xzr,-1,noload) +BRANCH(K_LEFT_LOOP) +// +// Scale and write to memory. +LABEL(WRITE_MEM_PREP) +" ldr x4, %[alpha] \n\t" // Load alpha & beta (address). +" ldr x8, %[beta] \n\t" +" \n\t" +" mov x1, x5 \n\t" // C address for loading. +" \n\t" // C address for storing is x5 itself. +" cmp x7, #8 \n\t" // Check for column-storage. +BNE(WRITE_MEM_C) +// +// C storage in rows. +LABEL(WRITE_MEM_R) +" ld1r {v24.2d}, [x4] \n\t" // Load alpha & beta. +" ld1r {v25.2d}, [x8] \n\t" +" fmov d26, #1.0 \n\t" +" fcmp d24, d26 \n\t" +BEQ(UNIT_ALPHA_R) +DSCALE8V(0,1,2,3,4,5,6,7,24,0) +DSCALE8V(8,9,10,11,12,13,14,15,24,0) +DSCALE8V(16,17,18,19,20,21,22,23,24,0) +LABEL(UNIT_ALPHA_R) +" fcmp d25, #0.0 \n\t" +BEQ(ZERO_BETA_R_1) +DLOADC_4V_R_FWD(26,27,28,29,x1,0,x6) +DSCALEA4V(0,1,2,3,26,27,28,29,25,0) +DLOADC_4V_R_FWD(26,27,28,29,x1,0,x6) +DSCALEA4V(4,5,6,7,26,27,28,29,25,0) +LABEL(ZERO_BETA_R_1) +DSTOREC_4V_R_FWD(0,1,2,3,x5,0,x6) +BEQ(ZERO_BETA_R_2) +DLOADC_4V_R_FWD(26,27,28,29,x1,0,x6) +DLOADC_4V_R_FWD(0,1,2,3,x1,0,x6) +DSCALEA8V(8,9,10,11,12,13,14,15,26,27,28,29,0,1,2,3,25,0) +DLOADC_4V_R_FWD(26,27,28,29,x1,0,x6) +DLOADC_4V_R_FWD(0,1,2,3,x1,0,x6) +DSCALEA8V(16,17,18,19,20,21,22,23,26,27,28,29,0,1,2,3,25,0) +LABEL(ZERO_BETA_R_2) +#ifndef __clang__ +" cmp x12, #1 \n\t" +BRANCH(PRFM_END_R) +" prfm PLDL1KEEP, [%[a_next], #16*0] \n\t" +" prfm PLDL1KEEP, [%[a_next], #16*1] \n\t" +" prfm PLDL1STRM, [%[b_next], #16*0] \n\t" +" prfm PLDL1STRM, [%[b_next], #16*1] \n\t" +LABEL(PRFM_END_R) +#endif +DSTOREC_4V_R_FWD(4,5,6,7,x5,0,x6) +DSTOREC_4V_R_FWD(8,9,10,11,x5,0,x6) +DSTOREC_4V_R_FWD(12,13,14,15,x5,0,x6) +DSTOREC_4V_R_FWD(16,17,18,19,x5,0,x6) +DSTOREC_4V_R_FWD(20,21,22,23,x5,0,x6) +BRANCH(END_WRITE_MEM) +// +// C storage in columns. +LABEL(WRITE_MEM_C) +// In-register transpose, +// do transposition in row-order. +" trn1 v24.2d, v0.2d, v4.2d \n\t" // Row 0-1. +" trn2 v25.2d, v0.2d, v4.2d \n\t" +" trn1 v26.2d, v1.2d, v5.2d \n\t" +" trn2 v27.2d, v1.2d, v5.2d \n\t" +" trn1 v28.2d, v2.2d, v6.2d \n\t" +" trn2 v29.2d, v2.2d, v6.2d \n\t" +" trn1 v30.2d, v3.2d, v7.2d \n\t" +" trn2 v31.2d, v3.2d, v7.2d \n\t" +" \n\t" +" trn1 v0.2d, v8.2d, v12.2d \n\t" // Row 2-3. +" trn2 v1.2d, v8.2d, v12.2d \n\t" +" trn1 v2.2d, v9.2d, v13.2d \n\t" +" trn2 v3.2d, v9.2d, v13.2d \n\t" +" trn1 v4.2d, v10.2d, v14.2d \n\t" +" trn2 v5.2d, v10.2d, v14.2d \n\t" +" trn1 v6.2d, v11.2d, v15.2d \n\t" +" trn2 v7.2d, v11.2d, v15.2d \n\t" +" \n\t" +" trn1 v8.2d, v16.2d, v20.2d \n\t" // Row 4-5. +" trn2 v9.2d, v16.2d, v20.2d \n\t" +" trn1 v10.2d, v17.2d, v21.2d \n\t" // AMARI +" trn2 v11.2d, v17.2d, v21.2d \n\t" // AMARI +" trn1 v12.2d, v18.2d, v22.2d \n\t" // AMARI +" trn2 v13.2d, v18.2d, v22.2d \n\t" // AMARI +" trn1 v14.2d, v19.2d, v23.2d \n\t" // AMARI +" trn2 v15.2d, v19.2d, v23.2d \n\t" // AMARI +" \n\t" +" ld1r {v16.2d}, [x4] \n\t" // Load alpha & beta. +" ld1r {v17.2d}, [x8] \n\t" +" fmov d18, #1.0 \n\t" +" fcmp d16, d18 \n\t" +BEQ(UNIT_ALPHA_C) +DSCALE8V(24,25,26,27,28,29,30,31,16,0) +DSCALE8V(0,1,2,3,4,5,6,7,16,0) +DSCALE8V(8,9,10,11,12,13,14,15,16,0) +LABEL(UNIT_ALPHA_C) +" fcmp d17, #0.0 \n\t" +BEQ(ZERO_BETA_C_1) +DLOADC_3V_C_FWD(18,19,20,x1,0,x7) +DLOADC_3V_C_FWD(21,22,23,x1,0,x7) +DSCALEA6V(24,0,8,25,1,9,18,19,20,21,22,23,17,0) +LABEL(ZERO_BETA_C_1) +DSTOREC_3V_C_FWD(24,0,8,x5,0,x7) +DSTOREC_3V_C_FWD(25,1,9,x5,0,x7) +BEQ(ZERO_BETA_C_2) +DLOADC_3V_C_FWD(18,19,20,x1,0,x7) +DLOADC_3V_C_FWD(21,22,23,x1,0,x7) +DLOADC_3V_C_FWD(24,0,8,x1,0,x7) +DLOADC_3V_C_FWD(25,1,9,x1,0,x7) +DSCALEA6V(26,2,10,27,3,11,18,19,20,21,22,23,17,0) +DSCALEA6V(28,4,12,29,5,13,24,0,8,25,1,9,17,0) +LABEL(ZERO_BETA_C_2) +#ifndef __clang__ +" cmp x12, #1 \n\t" +BRANCH(PRFM_END_C) +" prfm PLDL1KEEP, [%[a_next], #16*0] \n\t" +" prfm PLDL1KEEP, [%[a_next], #16*1] \n\t" +" prfm PLDL1STRM, [%[b_next], #16*0] \n\t" +" prfm PLDL1STRM, [%[b_next], #16*1] \n\t" +LABEL(PRFM_END_C) +" fcmp d17, #0.0 \n\t" // Not the end. Reset branching reg. +#endif +DSTOREC_3V_C_FWD(26,2,10,x5,0,x7) +DSTOREC_3V_C_FWD(27,3,11,x5,0,x7) +BEQ(ZERO_BETA_C_3) +DLOADC_3V_C_FWD(18,19,20,x1,0,x7) +DLOADC_3V_C_FWD(21,22,23,x1,0,x7) +DSCALEA6V(30,6,14,31,7,15,18,19,20,21,22,23,17,0) +LABEL(ZERO_BETA_C_3) +DSTOREC_3V_C_FWD(28,4,12,x5,0,x7) +DSTOREC_3V_C_FWD(29,5,13,x5,0,x7) +DSTOREC_3V_C_FWD(30,6,14,x5,0,x7) +DSTOREC_3V_C_FWD(31,7,15,x5,0,x7) +// +// End of this microkernel. +LABEL(END_WRITE_MEM) +" \n\t" +" subs x12, x12, #1 \n\t" +BEQ(END_EXEC) +" \n\t" +" mov x8, #6 \n\t" +" madd x13, x6, x8, x13 \n\t" // Forward C's base address to the next logic panel. +" add x10, x10, x11 \n\t" // Forward A's base address to the next logic panel. +BRANCH(MILLIKER_MLOOP) +// +// End of execution. +LABEL(END_EXEC) +: +: [a] "m" (a), + [b] "m" (b), + [c] "m" (c), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a] "m" (ps_a), + [rs_b] "m" (rs_b), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + // In Clang, even "m"-passed parameter takes 1 register. + // Have to disable prefetching to pass compilation. +#ifndef __clang__ + [a_next] "r" (a_next), + [b_next] "r" (b_next), +#endif + [m_iter] "m" (m_iter), + [k_mker] "m" (k_mker), + [k_left] "m" (k_left), + [alpha] "m" (alpha), + [beta] "m" (beta) +: "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", + "x8", "x9", "x10","x11","x12","x13","x14", + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", + "v8", "v9", "v10","v11","v12","v13","v14","v15", + "v16","v17","v18","v19","v20","v21","v22","v23", + "v24","v25","v26","v27","v28","v29","v30","v31" + ); + +consider_edge_cases: + // Forward address. + a = a + m_iter * ps_a; + c = c + m_iter * 6 * rs_c; +#if 1 + auxinfo_t data_d6x4mn = *data; + bli_auxinfo_set_ps_b( 4 * cs_b0, &data_d6x4mn ); + bli_dgemmsup_rv_armv8a_int_6x4mn + ( + conja, conjb, m_left, 8, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, &data_d6x4mn, cntx + ); +#else + if ( m_left >= 4 ) + { + // Calls 4x8m with only 1 outermost loop. + // As only 1 outermost loop is called, + // ps_a needs not being set here. + // + bli_dgemmsup_rv_armv8a_asm_4x8m + ( + conja, conjb, 4, 8, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + m_left -= 4; + a = a + 4 * rs_a; + c = c + 4 * rs_c; + } + if ( m_left ) + { + bli_dgemmsup_r_armv8a_ref2 + ( + conja, conjb, m_left, 8, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + } +#endif + +} + diff --git a/kernels/armv8a/3/sup/bli_gemmsup_rv_armv8a_asm_d6x8n.c b/kernels/armv8a/3/sup/bli_gemmsup_rv_armv8a_asm_d6x8n.c new file mode 100644 index 0000000000..fb9357c11e --- /dev/null +++ b/kernels/armv8a/3/sup/bli_gemmsup_rv_armv8a_asm_d6x8n.c @@ -0,0 +1,539 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2021, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ + +#include "blis.h" +#include "assert.h" + +GEMMSUP_KER_PROT( double, d, gemmsup_r_armv8a_ref2 ) + +// Label locality & misc. +#include "../armv8a_asm_utils.h" + +// Nanokernel operations. +#include "../armv8a_asm_d2x2.h" + +/* Order of row-major DGEMM_6x8's execution in 2x2 blocks: + * + * +---+ +---+ +---+ +---+ + * | 0 | | 1 | | 6 | | 7 | + * +---+ +---+ +---+ +---+ + * +---+ +---+ +---+ +---+ + * | 2 | | 3 | | 8 | | 9 | + * +---+ +---+ +---+ +---+ + * +---+ +---+ +---+ +---+ + * | 4 | | 5 | | 10| | 11| + * +---+ +---+ +---+ +---+ + * + */ +#define DGEMM_6X8_MKER_LOOP_PLAIN(C00,C01,C02,C03,C10,C11,C12,C13,C20,C21,C22,C23,C30,C31,C32,C33,C40,C41,C42,C43,C50,C51,C52,C53,A0,A1,A2,B0,B1,B2,B3,AELEMADDR,AELEMST,BADDR,BSHIFT,LOADNEXT) \ + DGEMM_2X2_NANOKERNEL(C00,C10,B0,A0) \ + DGEMM_2X2_NANOKERNEL(C01,C11,B1,A0) \ + DGEMM_2X2_NANOKERNEL(C20,C30,B0,A1) \ + DGEMM_2X2_NANOKERNEL(C21,C31,B1,A1) \ + DGEMM_2X2_NANOKERNEL(C40,C50,B0,A2) \ + DGEMM_2X2_NANOKERNEL(C41,C51,B1,A2) \ + DGEMM_LOAD2V_ ##LOADNEXT (B0,B1,BADDR,BSHIFT) \ + DGEMM_2X2_NANOKERNEL(C02,C12,B2,A0) \ + DGEMM_2X2_NANOKERNEL(C03,C13,B3,A0) \ + DGEMM_LOAD1V_G_ ##LOADNEXT (A0,AELEMADDR,AELEMST) \ + DGEMM_2X2_NANOKERNEL(C22,C32,B2,A1) \ + DGEMM_2X2_NANOKERNEL(C23,C33,B3,A1) \ + DGEMM_LOAD1V_G_ ##LOADNEXT (A1,AELEMADDR,AELEMST) \ + DGEMM_2X2_NANOKERNEL(C42,C52,B2,A2) \ + DGEMM_2X2_NANOKERNEL(C43,C53,B3,A2) + +// Interleaving load or not. +#define DGEMM_LOAD1V_noload(V1,ADDR,IMM) +#define DGEMM_LOAD1V_load(V1,ADDR,IMM) \ +" ldr q"#V1", ["#ADDR", #"#IMM"] \n\t" + +#define DGEMM_LOAD2V_noload(V1,V2,ADDR,IMM) +#define DGEMM_LOAD2V_load(V1,V2,ADDR,IMM) \ + DGEMM_LOAD1V_load(V1,ADDR,IMM) \ + DGEMM_LOAD1V_load(V2,ADDR,IMM+16) + +#define DGEMM_LOAD1V_G_noload(V1,ADDR,ST) +#define DGEMM_LOAD1V_G_load(V1,ADDR,ST) \ +" ld1 {v"#V1".d}[0], ["#ADDR"], "#ST" \n\t" \ +" ld1 {v"#V1".d}[1], ["#ADDR"], "#ST" \n\t" + +// Prefetch C in the long direction. +#define DPRFMC_FWD(CADDR,DLONGC) \ +" prfm PLDL1KEEP, ["#CADDR"] \n\t" \ +" add "#CADDR", "#CADDR", "#DLONGC" \n\t" + +// For row-storage of C. +#define DLOADC_4V_R_FWD(C0,C1,C2,C3,CADDR,CSHIFT,RSC) \ + DLOAD4V(C0,C1,C2,C3,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#RSC" \n\t" +#define DSTOREC_4V_R_FWD(C0,C1,C2,C3,CADDR,CSHIFT,RSC) \ + DSTORE4V(C0,C1,C2,C3,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#RSC" \n\t" + +// For column-storage of C. +#define DLOADC_3V_C_FWD(C0,C1,C2,CADDR,CSHIFT,CSC) \ + DLOAD2V(C0,C1,CADDR,CSHIFT) \ + DLOAD1V(C2,CADDR,CSHIFT+32) \ +" add "#CADDR", "#CADDR", "#CSC" \n\t" +#define DSTOREC_3V_C_FWD(C0,C1,C2,CADDR,CSHIFT,CSC) \ + DSTORE2V(C0,C1,CADDR,CSHIFT) \ + DSTORE1V(C2,CADDR,CSHIFT+32) \ +" add "#CADDR", "#CADDR", "#CSC" \n\t" + +#define DSCALE6V(V0,V1,V2,V3,V4,V5,A,IDX) \ + DSCALE4V(V0,V1,V2,V3,A,IDX) \ + DSCALE2V(V4,V5,A,IDX) +#define DSCALEA6V(D0,D1,D2,D3,D4,D5,S0,S1,S2,S3,S4,S5,A,IDX) \ + DSCALEA4V(D0,D1,D2,D3,S0,S1,S2,S3,A,IDX) \ + DSCALEA2V(D4,D5,S4,S5,A,IDX) + + +/* + * 6x8 dgemmsup kernel with extending 2nd dimension. + * + * Recommanded usage case: (L1 cache latency) * (Num. FPU) < 17 cycles. + * + * Calls 4x8n for edge cases. + */ +void bli_dgemmsup_rv_armv8a_asm_6x8n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + if ( m0 != 6 ) + { + // 5 = 4 + 1; + // 4; + // + while ( m0 >= 4 ) + { + bli_dgemmsup_rv_armv8a_asm_4x8n + ( + conja, conjb, 4, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + m0 -= 4; + a += 4 * rs_a0; + c += 4 * rs_c0; + } + + // 3, 2, 1; + // + if ( m0 > 0 ) + { + bli_dgemmsup_rv_armv8a_int_3x8mn + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + } + return; + } + + // LLVM has very bad routing ability for inline asm. + // Limit number of registers in case of Clang compilation. +#ifndef __clang__ + void* a_next = bli_auxinfo_next_a( data ); + void* b_next = bli_auxinfo_next_b( data ); +#endif + uint64_t ps_b = bli_auxinfo_ps_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_mker = k0 / 4; + uint64_t k_left = k0 % 4; + + int64_t n_iter = n0 / 8; + int64_t n_left = n0 % 8; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + // uint64_t cs_b = cs_b0; + assert( cs_b0 == 1 ); + + if ( n_iter == 0 ) goto consider_edge_cases; + + __asm__ volatile + ( +" ldr x10, %[b] \n\t" +" ldr x13, %[c] \n\t" +" ldr x12, %[n_iter] \n\t" +" ldr x11, %[ps_b] \n\t" // Panel-skip of B. +" ldr x3, %[rs_b] \n\t" // Row-skip of B. +" ldr x9, %[rs_a] \n\t" // Row-skip of A. +" ldr x2, %[cs_a] \n\t" // Column-skip of A. +" \n\t" +" ldr x6, %[rs_c] \n\t" // Row-skip of C. +" ldr x7, %[cs_c] \n\t" // Column-skip of C. +" \n\t" +" \n\t" // Multiply some address skips by sizeof(double). +" lsl x11, x11, #3 \n\t" // ps_b +" lsl x9, x9, #3 \n\t" // rs_a +" lsl x2, x2, #3 \n\t" // cs_a +" lsl x3, x3, #3 \n\t" // rs_b +" lsl x6, x6, #3 \n\t" // rs_c +" lsl x7, x7, #3 \n\t" // cs_c +" \n\t" +" mov x1, x5 \n\t" +" cmp x7, #8 \n\t" // Prefetch column-strided C. +BEQ(C_PREFETCH_COLS) +DPRFMC_FWD(x1,x6) +DPRFMC_FWD(x1,x6) +DPRFMC_FWD(x1,x6) +DPRFMC_FWD(x1,x6) +DPRFMC_FWD(x1,x6) +DPRFMC_FWD(x1,x6) +BRANCH(C_PREFETCH_END) +LABEL(C_PREFETCH_COLS) +// This prefetch will not cover further mker perts. Skip. +// +// DPRFMC_FWD(x1,x7) +// DPRFMC_FWD(x1,x7) +// DPRFMC_FWD(x1,x7) +// DPRFMC_FWD(x1,x7) +// DPRFMC_FWD(x1,x7) +// DPRFMC_FWD(x1,x7) +// DPRFMC_FWD(x1,x7) +// DPRFMC_FWD(x1,x7) +LABEL(C_PREFETCH_END) +// +// Millikernel. +LABEL(MILLIKER_MLOOP) +" \n\t" +" mov x1, x10 \n\t" // Parameters to be reloaded +" mov x5, x13 \n\t" // within each millikernel loop. +" ldr x0, %[a] \n\t" +" ldr x4, %[k_mker] \n\t" +" ldr x8, %[k_left] \n\t" +" \n\t" +// Storage scheme: +// V[ 0:23] <- C +// V[24:27] <- A +// V[28:31] <- B +// Under this scheme, the following is defined: +#define DGEMM_6X8_MKER_LOOP_PLAIN_LOC(A0,A1,A2,B0,B1,B2,B3,AELEMADDR,AELEMST,BADDR,BSHIFT,LOADNEXT) \ + DGEMM_6X8_MKER_LOOP_PLAIN(0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,A0,A1,A2,B0,B1,B2,B3,AELEMADDR,AELEMST,BADDR,BSHIFT,LOADNEXT) +// Load from memory. +LABEL(LOAD_ABC) +" \n\t" // No-microkernel early return is a must +" cmp x4, #0 \n\t" // to avoid out-of-boundary read. +BEQ(CLEAR_CCOLS) +" \n\t" +" ldr q28, [x1, #16*0] \n\t" // Load B first. +" ldr q29, [x1, #16*1] \n\t" +" ldr q30, [x1, #16*2] \n\t" +" ldr q31, [x1, #16*3] \n\t" +" add x1, x1, x3 \n\t" +" \n\t" +" mov x14, x0 \n\t" // Load A. +" ld1 {v24.d}[0], [x14], x9 \n\t" // We want A to be kept in L1. +" ld1 {v24.d}[1], [x14], x9 \n\t" +" ld1 {v25.d}[0], [x14], x9 \n\t" +" ld1 {v25.d}[1], [x14], x9 \n\t" +" ld1 {v26.d}[0], [x14], x9 \n\t" +" ld1 {v26.d}[1], [x14], x9 \n\t" +" add x0, x0, x2 \n\t" +" mov x14, x0 \n\t" +" ld1 {v27.d}[0], [x14], x9 \n\t" +" ld1 {v27.d}[1], [x14], x9 \n\t" +LABEL(CLEAR_CCOLS) +CLEAR8V(0,1,2,3,4,5,6,7) +CLEAR8V(8,9,10,11,12,13,14,15) +CLEAR8V(16,17,18,19,20,21,22,23) +// No-microkernel early return, once again. +BEQ(K_LEFT_LOOP) +// +// Microkernel is defined here as: +#define DGEMM_6X8_MKER_LOOP_PLAIN_LOC_FWD(A0,A1,A2,B0,B1,B2,B3) \ + DGEMM_6X8_MKER_LOOP_PLAIN_LOC(A0,A1,A2,B0,B1,B2,B3,x14,x9,x1,0,load) \ + "add x0, x0, x2 \n\t" \ + "mov x14, x0 \n\t" \ + "ld1 {v"#A2".d}[0], [x14], x9 \n\t" \ + "ld1 {v"#A2".d}[1], [x14], x9 \n\t" \ + "ldr q"#B2", [x1, #16*2] \n\t" \ + "ldr q"#B3", [x1, #16*3] \n\t" \ + "add x1, x1, x3 \n\t" +// Start microkernel loop. +LABEL(K_MKER_LOOP) +DGEMM_6X8_MKER_LOOP_PLAIN_LOC_FWD(24,25,26,28,29,30,31) +DGEMM_6X8_MKER_LOOP_PLAIN_LOC_FWD(27,24,25,28,29,30,31) +" \n\t" // Decrease counter before final replica. +" subs x4, x4, #1 \n\t" // Branch early to avoid reading excess mem. +BEQ(FIN_MKER_LOOP) +DGEMM_6X8_MKER_LOOP_PLAIN_LOC_FWD(26,27,24,28,29,30,31) +DGEMM_6X8_MKER_LOOP_PLAIN_LOC_FWD(25,26,27,28,29,30,31) +BRANCH(K_MKER_LOOP) +// +// Final microkernel loop. +LABEL(FIN_MKER_LOOP) +DGEMM_6X8_MKER_LOOP_PLAIN_LOC(26,27,24,28,29,30,31,x14,x9,x1,0,load) +" add x0, x0, x2 \n\t" +" ldr q30, [x1, #16*2] \n\t" +" ldr q31, [x1, #16*3] \n\t" +" add x1, x1, x3 \n\t" +DGEMM_6X8_MKER_LOOP_PLAIN_LOC(25,26,27,28,29,30,31,xzr,-1,xzr,-1,noload) +// +// Loops left behind microkernels. +LABEL(K_LEFT_LOOP) +" cmp x8, #0 \n\t" // End of exec. +BEQ(WRITE_MEM_PREP) +" ldr q28, [x1, #16*0] \n\t" // Load B row. +" ldr q29, [x1, #16*1] \n\t" +" ldr q30, [x1, #16*2] \n\t" +" ldr q31, [x1, #16*3] \n\t" +" add x1, x1, x3 \n\t" +" mov x14, x0 \n\t" +" ld1 {v24.d}[0], [x14], x9 \n\t" // Load A col. +" ld1 {v24.d}[1], [x14], x9 \n\t" +" ld1 {v25.d}[0], [x14], x9 \n\t" +" ld1 {v25.d}[1], [x14], x9 \n\t" +" ld1 {v26.d}[0], [x14], x9 \n\t" +" ld1 {v26.d}[1], [x14], x9 \n\t" +" add x0, x0, x2 \n\t" +" sub x8, x8, #1 \n\t" +DGEMM_6X8_MKER_LOOP_PLAIN_LOC(24,25,26,28,29,30,31,xzr,-1,xzr,-1,noload) +BRANCH(K_LEFT_LOOP) +// +// Scale and write to memory. +LABEL(WRITE_MEM_PREP) +" ldr x4, %[alpha] \n\t" // Load alpha & beta (address). +" ldr x8, %[beta] \n\t" +" \n\t" +" mov x1, x5 \n\t" // C address for loading. +" \n\t" // C address for storing is x5 itself. +" cmp x7, #8 \n\t" // Check for column-storage. +BNE(WRITE_MEM_C) +// +// C storage in rows. +LABEL(WRITE_MEM_R) +" ld1r {v24.2d}, [x4] \n\t" // Load alpha & beta. +" ld1r {v25.2d}, [x8] \n\t" +" fmov d26, #1.0 \n\t" +" fcmp d24, d26 \n\t" +BEQ(UNIT_ALPHA_R) +DSCALE8V(0,1,2,3,4,5,6,7,24,0) +DSCALE8V(8,9,10,11,12,13,14,15,24,0) +DSCALE8V(16,17,18,19,20,21,22,23,24,0) +LABEL(UNIT_ALPHA_R) +" fcmp d25, #0.0 \n\t" +BEQ(ZERO_BETA_R_1) +DLOADC_4V_R_FWD(26,27,28,29,x1,0,x6) +DSCALEA4V(0,1,2,3,26,27,28,29,25,0) +DLOADC_4V_R_FWD(26,27,28,29,x1,0,x6) +DSCALEA4V(4,5,6,7,26,27,28,29,25,0) +LABEL(ZERO_BETA_R_1) +DSTOREC_4V_R_FWD(0,1,2,3,x5,0,x6) +BEQ(ZERO_BETA_R_2) +DLOADC_4V_R_FWD(26,27,28,29,x1,0,x6) +DLOADC_4V_R_FWD(0,1,2,3,x1,0,x6) +DSCALEA8V(8,9,10,11,12,13,14,15,26,27,28,29,0,1,2,3,25,0) +DLOADC_4V_R_FWD(26,27,28,29,x1,0,x6) +DLOADC_4V_R_FWD(0,1,2,3,x1,0,x6) +DSCALEA8V(16,17,18,19,20,21,22,23,26,27,28,29,0,1,2,3,25,0) +LABEL(ZERO_BETA_R_2) +#ifndef __clang__ +" cmp x12, #1 \n\t" +BRANCH(PRFM_END_R) +" prfm PLDL1KEEP, [%[a_next], #16*0] \n\t" +" prfm PLDL1KEEP, [%[a_next], #16*1] \n\t" +" prfm PLDL1STRM, [%[b_next], #16*0] \n\t" +" prfm PLDL1STRM, [%[b_next], #16*1] \n\t" +LABEL(PRFM_END_R) +#endif +DSTOREC_4V_R_FWD(4,5,6,7,x5,0,x6) +DSTOREC_4V_R_FWD(8,9,10,11,x5,0,x6) +DSTOREC_4V_R_FWD(12,13,14,15,x5,0,x6) +DSTOREC_4V_R_FWD(16,17,18,19,x5,0,x6) +DSTOREC_4V_R_FWD(20,21,22,23,x5,0,x6) +BRANCH(END_WRITE_MEM) +// +// C storage in columns. +LABEL(WRITE_MEM_C) +// In-register transpose, +// do transposition in row-order. +" trn1 v24.2d, v0.2d, v4.2d \n\t" // Row 0-1. +" trn2 v25.2d, v0.2d, v4.2d \n\t" +" trn1 v26.2d, v1.2d, v5.2d \n\t" +" trn2 v27.2d, v1.2d, v5.2d \n\t" +" trn1 v28.2d, v2.2d, v6.2d \n\t" +" trn2 v29.2d, v2.2d, v6.2d \n\t" +" trn1 v30.2d, v3.2d, v7.2d \n\t" +" trn2 v31.2d, v3.2d, v7.2d \n\t" +" \n\t" +" trn1 v0.2d, v8.2d, v12.2d \n\t" // Row 2-3. +" trn2 v1.2d, v8.2d, v12.2d \n\t" +" trn1 v2.2d, v9.2d, v13.2d \n\t" +" trn2 v3.2d, v9.2d, v13.2d \n\t" +" trn1 v4.2d, v10.2d, v14.2d \n\t" +" trn2 v5.2d, v10.2d, v14.2d \n\t" +" trn1 v6.2d, v11.2d, v15.2d \n\t" +" trn2 v7.2d, v11.2d, v15.2d \n\t" +" \n\t" +" trn1 v8.2d, v16.2d, v20.2d \n\t" // Row 4-5. +" trn2 v9.2d, v16.2d, v20.2d \n\t" +" trn1 v10.2d, v17.2d, v21.2d \n\t" // AMARI +" trn2 v11.2d, v17.2d, v21.2d \n\t" // AMARI +" trn1 v12.2d, v18.2d, v22.2d \n\t" // AMARI +" trn2 v13.2d, v18.2d, v22.2d \n\t" // AMARI +" trn1 v14.2d, v19.2d, v23.2d \n\t" // AMARI +" trn2 v15.2d, v19.2d, v23.2d \n\t" // AMARI +" \n\t" +" ld1r {v16.2d}, [x4] \n\t" // Load alpha & beta. +" ld1r {v17.2d}, [x8] \n\t" +" fmov d18, #1.0 \n\t" +" fcmp d16, d18 \n\t" +BEQ(UNIT_ALPHA_C) +DSCALE8V(24,25,26,27,28,29,30,31,16,0) +DSCALE8V(0,1,2,3,4,5,6,7,16,0) +DSCALE8V(8,9,10,11,12,13,14,15,16,0) +LABEL(UNIT_ALPHA_C) +" fcmp d17, #0.0 \n\t" +BEQ(ZERO_BETA_C_1) +DLOADC_3V_C_FWD(18,19,20,x1,0,x7) +DLOADC_3V_C_FWD(21,22,23,x1,0,x7) +DSCALEA6V(24,0,8,25,1,9,18,19,20,21,22,23,17,0) +LABEL(ZERO_BETA_C_1) +DSTOREC_3V_C_FWD(24,0,8,x5,0,x7) +DSTOREC_3V_C_FWD(25,1,9,x5,0,x7) +BEQ(ZERO_BETA_C_2) +DLOADC_3V_C_FWD(18,19,20,x1,0,x7) +DLOADC_3V_C_FWD(21,22,23,x1,0,x7) +DLOADC_3V_C_FWD(24,0,8,x1,0,x7) +DLOADC_3V_C_FWD(25,1,9,x1,0,x7) +DSCALEA6V(26,2,10,27,3,11,18,19,20,21,22,23,17,0) +DSCALEA6V(28,4,12,29,5,13,24,0,8,25,1,9,17,0) +LABEL(ZERO_BETA_C_2) +#ifndef __clang__ +" cmp x12, #1 \n\t" +BRANCH(PRFM_END_C) +" prfm PLDL1KEEP, [%[a_next], #16*0] \n\t" +" prfm PLDL1KEEP, [%[a_next], #16*1] \n\t" +" prfm PLDL1STRM, [%[b_next], #16*0] \n\t" +" prfm PLDL1STRM, [%[b_next], #16*1] \n\t" +LABEL(PRFM_END_C) +" fcmp d17, #0.0 \n\t" // Not the end. Reset branching reg. +#endif +DSTOREC_3V_C_FWD(26,2,10,x5,0,x7) +DSTOREC_3V_C_FWD(27,3,11,x5,0,x7) +BEQ(ZERO_BETA_C_3) +DLOADC_3V_C_FWD(18,19,20,x1,0,x7) +DLOADC_3V_C_FWD(21,22,23,x1,0,x7) +DSCALEA6V(30,6,14,31,7,15,18,19,20,21,22,23,17,0) +LABEL(ZERO_BETA_C_3) +DSTOREC_3V_C_FWD(28,4,12,x5,0,x7) +DSTOREC_3V_C_FWD(29,5,13,x5,0,x7) +DSTOREC_3V_C_FWD(30,6,14,x5,0,x7) +DSTOREC_3V_C_FWD(31,7,15,x5,0,x7) +// +// End of this microkernel. +LABEL(END_WRITE_MEM) +" \n\t" +" subs x12, x12, #1 \n\t" +BEQ(END_EXEC) +" \n\t" +" mov x8, #8 \n\t" +" madd x13, x7, x8, x13 \n\t" // Forward C's base address to the next logic panel. +" add x10, x10, x11 \n\t" // Forward B's base address to the next logic panel. +BRANCH(MILLIKER_MLOOP) +// +// End of execution. +LABEL(END_EXEC) +: +: [a] "m" (a), + [b] "m" (b), + [c] "m" (c), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_b] "m" (ps_b), + [rs_b] "m" (rs_b), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + // In Clang, even "m"-passed parameter takes 1 register. + // Have to disable prefetching to pass compilation. +#ifndef __clang__ + [a_next] "r" (a_next), + [b_next] "r" (b_next), +#endif + [n_iter] "m" (n_iter), + [k_mker] "m" (k_mker), + [k_left] "m" (k_left), + [alpha] "m" (alpha), + [beta] "m" (beta) +: "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", + "x8", "x9", "x10","x11","x12","x13","x14", + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", + "v8", "v9", "v10","v11","v12","v13","v14","v15", + "v16","v17","v18","v19","v20","v21","v22","v23", + "v24","v25","v26","v27","v28","v29","v30","v31" + ); + +consider_edge_cases: + // Forward address. + b = b + n_iter * ps_b; + c = c + n_iter * 8 * cs_c; + if ( n_left ) + { + // Set panel stride to unpacked mode. + // Only 1 millikernel w.r.t. 6x8 is executed. + auxinfo_t data_d6x4mn = *data; + bli_auxinfo_set_ps_b( 4 * cs_b0, &data_d6x4mn ); + // + bli_dgemmsup_rv_armv8a_int_6x4mn + ( + conja, conjb, 6, n_left, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, &data_d6x4mn, cntx + ); + } + +} + diff --git a/kernels/armv8a/3/sup/bli_gemmsup_rv_armv8a_asm_d8x4m.c b/kernels/armv8a/3/sup/bli_gemmsup_rv_armv8a_asm_d8x4m.c new file mode 100644 index 0000000000..5b0e9b062f --- /dev/null +++ b/kernels/armv8a/3/sup/bli_gemmsup_rv_armv8a_asm_d8x4m.c @@ -0,0 +1,431 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "assert.h" + +GEMMSUP_KER_PROT( double, d, gemmsup_r_armv8a_ref2 ) + +// Label locality & misc. +#include "../armv8a_asm_utils.h" + +// Nanokernel operations. +#include "../armv8a_asm_d2x2.h" + +/* + * +---+ +---+ + * | 0 | | 4 | + * +---+ +---+ + * +---+ +---+ + * | 1 | | 5 | + * +---+ +---+ + * +---+ +---+ + * | 2 | | 6 | + * +---+ +---+ + * +---+ +---+ + * | 3 | | 7 | + * +---+ +---+ + * + */ +#define DGEMM_8X4_MKER_LOOP_PLAIN(C00,C10,C20,C30,C01,C11,C21,C31,C02,C12,C22,C32,C03,C13,C23,C33,A0,A1,A2,A3,B0,B1,BADDR,BSHIFT,LOADNEXT) \ + DGEMM_2X2_NANOKERNEL(C00,C01,A0,B0) \ + DGEMM_2X2_NANOKERNEL(C10,C11,A1,B0) \ + DGEMM_2X2_NANOKERNEL(C20,C21,A2,B0) \ + DGEMM_2X2_NANOKERNEL(C30,C31,A3,B0) \ + DGEMM_LOAD1V_ ##LOADNEXT (B0,BADDR,BSHIFT) \ + DGEMM_2X2_NANOKERNEL(C02,C03,A0,B1) \ + DGEMM_2X2_NANOKERNEL(C12,C13,A1,B1) \ + DGEMM_2X2_NANOKERNEL(C22,C23,A2,B1) \ + DGEMM_2X2_NANOKERNEL(C32,C33,A3,B1) + +// Interleaving load or not. +#define DGEMM_LOAD1V_noload(V1,ADDR,IMM) +#define DGEMM_LOAD1V_load(V1,ADDR,IMM) \ +" ldr q"#V1", ["#ADDR", #"#IMM"] \n\t" + +#define DLOADC_4V_C_FWD(C0,C1,C2,C3,CADDR,CSHIFT,LDC) \ + DLOAD4V(C0,C1,C2,C3,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#LDC" \n\t" +#define DSTOREC_4V_C_FWD(C0,C1,C2,C3,CADDR,CSHIFT,LDC) \ + DSTORE4V(C0,C1,C2,C3,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#LDC" \n\t" + +#define DLOADC_4V_R_FWD(C00,C01,C10,C11,CADDR,CSHIFT,RSC) \ + DLOAD2V(C00,C01,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#RSC" \n\t" \ + DLOAD2V(C10,C11,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#RSC" \n\t" +#define DSTOREC_4V_R_FWD(C00,C01,C10,C11,CADDR,CSHIFT,RSC) \ + DSTORE2V(C00,C01,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#RSC" \n\t" \ + DSTORE2V(C10,C11,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#RSC" \n\t" + +/* + * 8x4 kernel for dgemmsup. + * + * R-dimension too short. + * Not recommanded for use. + */ +void bli_dgemmsup_rv_armv8a_asm_8x4m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + // Fixme: This uker has no dispatching for unalighed sizes. + // Currently it only serves as a dispatch target for other kernels + // and cannot be registered in configurations. + assert( n0 == 4 ); + + void* a_next = bli_auxinfo_next_a( data ); + void* b_next = bli_auxinfo_next_b( data ); + uint64_t ps_a = bli_auxinfo_ps_a( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_mker = k0 / 6; + uint64_t k_left = k0 % 6; + + int64_t m_iter = m0 / 8; + int64_t m_left = m0 % 8; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + // uint64_t cs_b = cs_b0; + assert( cs_b0 == 1 ); + + if ( m_iter == 0 ) goto consider_edge_cases; + + __asm__ volatile + ( +" ldr x10, %[a] \n\t" +" ldr x13, %[c] \n\t" +" ldr x12, %[m_iter] \n\t" +" ldr x11, %[ps_a] \n\t" // Panel-skip of A. +" ldr x2, %[cs_a] \n\t" // Column-skip of A. +" ldr x9, %[rs_a] \n\t" // Row-skip of A. +" ldr x3, %[rs_b] \n\t" // Row-skip of B. +" \n\t" +" ldr x6, %[rs_c] \n\t" // Row-skip of C. +" ldr x7, %[cs_c] \n\t" // Column-skip of C. +" \n\t" +" \n\t" // Multiply some address skips by sizeof(double). +" lsl x11, x11, #3 \n\t" // ps_a +" lsl x9, x9, #3 \n\t" // rs_a +" lsl x2, x2, #3 \n\t" // cs_a +" lsl x3, x3, #3 \n\t" // rs_b +" lsl x6, x6, #3 \n\t" // rs_c +" lsl x7, x7, #3 \n\t" // cs_c +" \n\t" +LABEL(MILLIKER_MLOOP) +" \n\t" +" mov x0, x10 \n\t" // Parameters to be reloaded +" mov x5, x13 \n\t" // within each millikernel loop. +" ldr x1, %[b] \n\t" +" ldr x4, %[k_mker] \n\t" +" ldr x8, %[k_left] \n\t" +" \n\t" +// Storage scheme: +// V[ 0:15] <- C +// V[16:19] <- B; Allowed latency: 24 cycles / # of FPUs. +// V[20:31] <- A; Allowed latency: 32 cycles / # of FPUs. +// Under this scheme, the following is defined: +#define DGEMM_8X4_MKER_LOOP_PLAIN_LOC(A0,A1,A2,A3,B0,B1,BADDR,BSHIFT,LOADNEXT) \ + DGEMM_8X4_MKER_LOOP_PLAIN(0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,A0,A1,A2,A3,B0,B1,BADDR,BSHIFT,LOADNEXT) +LABEL(LOAD_ABC) +" \n\t" // No-microkernel early return is a must +" cmp x4, #0 \n\t" // to avoid out-of-boundary read. +BEQ(CLEAR_CCOLS) +" \n\t" +" mov x14, x0 \n\t" +" ld1 {v20.d}[0], [x14], x9 \n\t" +" ld1 {v20.d}[1], [x14], x9 \n\t" +" ld1 {v21.d}[0], [x14], x9 \n\t" +" ld1 {v21.d}[1], [x14], x9 \n\t" +" ld1 {v22.d}[0], [x14], x9 \n\t" +" ld1 {v22.d}[1], [x14], x9 \n\t" +" ld1 {v23.d}[0], [x14], x9 \n\t" +" ld1 {v23.d}[1], [x14], x9 \n\t" +" add x0, x0, x2 \n\t" +" mov x14, x0 \n\t" +" ld1 {v24.d}[0], [x14], x9 \n\t" +" ld1 {v24.d}[1], [x14], x9 \n\t" +" ld1 {v25.d}[0], [x14], x9 \n\t" +" ld1 {v25.d}[1], [x14], x9 \n\t" +" ld1 {v26.d}[0], [x14], x9 \n\t" +" ld1 {v26.d}[1], [x14], x9 \n\t" +" ld1 {v27.d}[0], [x14], x9 \n\t" +" ld1 {v27.d}[1], [x14], x9 \n\t" +" add x0, x0, x2 \n\t" +" mov x14, x0 \n\t" +" ld1 {v28.d}[0], [x14], x9 \n\t" +" ld1 {v28.d}[1], [x14], x9 \n\t" +" ld1 {v29.d}[0], [x14], x9 \n\t" +" ld1 {v29.d}[1], [x14], x9 \n\t" +" ld1 {v30.d}[0], [x14], x9 \n\t" +" ld1 {v30.d}[1], [x14], x9 \n\t" +" ld1 {v31.d}[0], [x14], x9 \n\t" +" ld1 {v31.d}[1], [x14], x9 \n\t" +" add x0, x0, x2 \n\t" +" \n\t" +" ldr q16, [x1, #16*0] \n\t" +" ldr q17, [x1, #16*1] \n\t" +" add x1, x1, x3 \n\t" +" ldr q18, [x1, #16*0] \n\t" +" ldr q19, [x1, #16*1] \n\t" +" add x1, x1, x3 \n\t" +LABEL(CLEAR_CCOLS) +CLEAR8V(0,1,2,3,4,5,6,7) +CLEAR8V(8,9,10,11,12,13,14,15) +// No-microkernel early return, once again. +BEQ(K_LEFT_LOOP) +// +// Microkernel is defined here as: +#define DGEMM_8X4_MKER_LOOP_PLAIN_LOC_FWD(A0,A1,A2,A3,B0,B1) \ + DGEMM_8X4_MKER_LOOP_PLAIN_LOC(A0,A1,A2,A3,B0,B1,x1,0,load) \ + "mov x14, x0 \n\t" \ + "ld1 {v"#A0".d}[0], [x14], x9 \n\t" \ + "ld1 {v"#A0".d}[1], [x14], x9 \n\t" \ + "ld1 {v"#A1".d}[0], [x14], x9 \n\t" \ + "ld1 {v"#A1".d}[1], [x14], x9 \n\t" \ + "ld1 {v"#A2".d}[0], [x14], x9 \n\t" \ + "ld1 {v"#A2".d}[1], [x14], x9 \n\t" \ + "ld1 {v"#A3".d}[0], [x14], x9 \n\t" \ + "ld1 {v"#A3".d}[1], [x14], x9 \n\t" \ + "ldr q"#B1", [x1, #16*1] \n\t" \ + "add x1, x1, x3 \n\t" \ + "add x0, x0, x2 \n\t" +// Start microkernel loop. +LABEL(K_MKER_LOOP) +DGEMM_8X4_MKER_LOOP_PLAIN_LOC_FWD(20,21,22,23,16,17) +DGEMM_8X4_MKER_LOOP_PLAIN_LOC_FWD(24,25,26,27,18,19) +DGEMM_8X4_MKER_LOOP_PLAIN_LOC_FWD(28,29,30,31,16,17) +" \n\t" // Decrease counter before final replica. +" subs x4, x4, #1 \n\t" // Branch early to avoid reading excess mem. +BEQ(FIN_MKER_LOOP) +DGEMM_8X4_MKER_LOOP_PLAIN_LOC_FWD(20,21,22,23,18,19) +DGEMM_8X4_MKER_LOOP_PLAIN_LOC_FWD(24,25,26,27,16,17) +DGEMM_8X4_MKER_LOOP_PLAIN_LOC_FWD(28,29,30,31,18,19) +BRANCH(K_MKER_LOOP) +// +// Final microkernel loop. +LABEL(FIN_MKER_LOOP) +DGEMM_8X4_MKER_LOOP_PLAIN_LOC(20,21,22,23,18,19,x1,0,load) +" ldr q19, [x1, #16*1] \n\t" +" add x1, x1, x3 \n\t" +DGEMM_8X4_MKER_LOOP_PLAIN_LOC(24,25,26,27,16,17,xzr,-1,noload) +DGEMM_8X4_MKER_LOOP_PLAIN_LOC(28,29,30,31,18,19,xzr,-1,noload) +// +// Loops left behind microkernels. +LABEL(K_LEFT_LOOP) +" cmp x8, #0 \n\t" // End of exec. +BEQ(WRITE_MEM_PREP) +" mov x14, x0 \n\t" +" ld1 {v20.d}[0], [x14], x9 \n\t" // Load A col. +" ld1 {v20.d}[1], [x14], x9 \n\t" +" ld1 {v21.d}[0], [x14], x9 \n\t" +" ld1 {v21.d}[1], [x14], x9 \n\t" +" ld1 {v22.d}[0], [x14], x9 \n\t" +" ld1 {v22.d}[1], [x14], x9 \n\t" +" ld1 {v23.d}[0], [x14], x9 \n\t" +" ld1 {v23.d}[1], [x14], x9 \n\t" +" add x0, x0, x2 \n\t" +" ldr q16, [x1, #16*0] \n\t" // Load B col. +" ldr q17, [x1, #16*1] \n\t" +" add x1, x1, x3 \n\t" +" sub x8, x8, #1 \n\t" +DGEMM_8X4_MKER_LOOP_PLAIN_LOC(20,21,22,23,16,17,xzr,-1,noload) +BRANCH(K_LEFT_LOOP) +// +// Scale and write to memory. +LABEL(WRITE_MEM_PREP) +" ldr x4, %[alpha] \n\t" // Load alpha & beta (address). +" ldr x8, %[beta] \n\t" +" ld1r {v16.2d}, [x4] \n\t" // Load alpha & beta (value). +" ld1r {v17.2d}, [x8] \n\t" +" fmov d18, #1.0 \n\t" +" fcmp d16, d18 \n\t" +BEQ(UNIT_ALPHA) +DSCALE8V(0,1,2,3,4,5,6,7,16,0) +DSCALE8V(8,9,10,11,12,13,14,15,16,0) +LABEL(UNIT_ALPHA) +" \n\t" +" mov x1, x5 \n\t" // C address for loading. +" \n\t" // C address for storing is x5 itself. +" cmp x6, #8 \n\t" // Check for row-storage. +BNE(WRITE_MEM_R) +// +// C storage in columns. +LABEL(WRITE_MEM_C) +" fcmp d17, #0.0 \n\t" +BEQ(ZERO_BETA_C) +DLOADC_4V_C_FWD(20,21,22,23,x1,0,x7) +DLOADC_4V_C_FWD(24,25,26,27,x1,0,x7) +DSCALEA8V(0,1,2,3,4,5,6,7,20,21,22,23,24,25,26,27,17,0) +// +DLOADC_4V_C_FWD(20,21,22,23,x1,0,x7) +DLOADC_4V_C_FWD(24,25,26,27,x1,0,x7) +DSCALEA8V(8,9,10,11,12,13,14,15,20,21,22,23,24,25,26,27,17,0) +LABEL(ZERO_BETA_C) +// +DSTOREC_4V_C_FWD(0,1,2,3,x5,0,x7) +DSTOREC_4V_C_FWD(4,5,6,7,x5,0,x7) +DSTOREC_4V_C_FWD(8,9,10,11,x5,0,x7) +DSTOREC_4V_C_FWD(12,13,14,15,x5,0,x7) +BRANCH(END_WRITE_MEM) +// +// C storage in rows. +LABEL(WRITE_MEM_R) +// In-register transpose. +" trn1 v16.2d, v0.2d, v4.2d \n\t" // Row 0. +" trn1 v17.2d, v8.2d, v12.2d \n\t" +" trn2 v18.2d, v0.2d, v4.2d \n\t" // Row 1. +" trn2 v19.2d, v8.2d, v12.2d \n\t" +" trn1 v20.2d, v1.2d, v5.2d \n\t" // Row 2. +" trn1 v21.2d, v9.2d, v13.2d \n\t" +" trn2 v22.2d, v1.2d, v5.2d \n\t" // Row 3. +" trn2 v23.2d, v9.2d, v13.2d \n\t" +" trn1 v24.2d, v2.2d, v6.2d \n\t" // Row 4. +" trn1 v25.2d, v10.2d, v14.2d \n\t" +" trn2 v26.2d, v2.2d, v6.2d \n\t" // Row 5. +" trn2 v27.2d, v10.2d, v14.2d \n\t" +" trn1 v28.2d, v3.2d, v7.2d \n\t" // Row 6. +" trn1 v29.2d, v11.2d, v15.2d \n\t" +" trn2 v30.2d, v3.2d, v7.2d \n\t" // Row 7. +" trn2 v31.2d, v11.2d, v15.2d \n\t" +// " ld1r {v14.2d}, [x4] \n\t" // Reload alpha & beta (value). +" ld1r {v15.2d}, [x8] \n\t" +" fcmp d15, #0.0 \n\t" +BEQ(ZERO_BETA_R) +DLOADC_4V_R_FWD(0,1,2,3,x1,0,x6) +DLOADC_4V_R_FWD(4,5,6,7,x1,0,x6) +DSCALEA8V(16,17,18,19,20,21,22,23,0,1,2,3,4,5,6,7,15,0) +// +DLOADC_4V_R_FWD(0,1,2,3,x1,0,x6) +DLOADC_4V_R_FWD(4,5,6,7,x1,0,x6) +DSCALEA8V(24,25,26,27,28,29,30,31,0,1,2,3,4,5,6,7,15,0) +LABEL(ZERO_BETA_R) +// +DSTOREC_4V_R_FWD(16,17,18,19,x5,0,x6) +DSTOREC_4V_R_FWD(20,21,22,23,x5,0,x6) +DSTOREC_4V_R_FWD(24,25,26,27,x5,0,x6) +DSTOREC_4V_R_FWD(28,29,30,31,x5,0,x6) +// +// End of this microkernel. +LABEL(END_WRITE_MEM) +" \n\t" +" subs x12, x12, #1 \n\t" +BEQ(END_EXEC) +" \n\t" +" mov x8, #8 \n\t" +" madd x13, x6, x8, x13 \n\t" // Forward C's base address to the next logic panel. +" add x10, x10, x11 \n\t" // Forward A's base address to the next logic panel. +BRANCH(MILLIKER_MLOOP) +// +// End of execution. +LABEL(END_EXEC) +: +: [a] "m" (a), + [b] "m" (b), + [c] "m" (c), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a] "m" (ps_a), + [rs_b] "m" (rs_b), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + [m_iter] "m" (m_iter), + [k_mker] "m" (k_mker), + [k_left] "m" (k_left), + [alpha] "m" (alpha), + [beta] "m" (beta) +: "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", + "x8", "x9", "x10","x11","x12","x13","x14", + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", + "v8", "v9", "v10","v11","v12","v13","v14","v15", + "v16","v17","v18","v19","v20","v21","v22","v23", + "v24","v25","v26","v27","v28","v29","v30","v31" + ); + +consider_edge_cases: + a = a + m_iter * ps_a; + c = c + m_iter * 8 * rs_c; + // Edge case is within 1 millikernel loop of THIS kernel. + // Regarding the 6x?m kernel, the panel stride should be always local. + auxinfo_t data_6xkm = *data; + bli_auxinfo_set_ps_a( 6 * rs_a, &data_6xkm ); + if ( m_left ) + { + bli_dgemmsup_rv_armv8a_int_6x4mn + ( + conja, conjb, m_left, 4, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, &data_6xkm, cntx + ); + } + + // Issue prefetch instructions only after + // execution is done. + __asm__ + ( +" mov x0, %[a_next] \n\t" +" mov x1, %[b_next] \n\t" +" prfm PLDL1STRM, [x0, #16*0] \n\t" +" prfm PLDL1STRM, [x0, #16*1] \n\t" +" prfm PLDL1STRM, [x0, #16*2] \n\t" +" prfm PLDL1KEEP, [x1, #16*0] \n\t" +" prfm PLDL1KEEP, [x1, #16*1] \n\t" +" prfm PLDL1KEEP, [x1, #16*2] \n\t" +: +: [a_next] "r" (a_next), + [b_next] "r" (b_next) +: "x0", "x1" + ); +} + diff --git a/kernels/armv8a/3/sup/d3x4/bli_gemmsup_rd_armv8a_asm_d3x4.c b/kernels/armv8a/3/sup/d3x4/bli_gemmsup_rd_armv8a_asm_d3x4.c new file mode 100644 index 0000000000..84c7c4a7d2 --- /dev/null +++ b/kernels/armv8a/3/sup/d3x4/bli_gemmsup_rd_armv8a_asm_d3x4.c @@ -0,0 +1,309 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2021, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ + +// Supplimentary fixed-size gemmsup. + +#include "blis.h" +#include "assert.h" + +// Label locality & misc. +#include "../../armv8a_asm_utils.h" + +#define DGEMM_3X1X2_NKER_SUBLOOP(C0,C1,C2,A0,A1,A2,B) \ +" fmla v"#C0".2d, v"#A0".2d, v"#B".2d \n\t" \ +" fmla v"#C1".2d, v"#A1".2d, v"#B".2d \n\t" \ +" fmla v"#C2".2d, v"#A2".2d, v"#B".2d \n\t" + +#define DGEMM_3X4X2_K_MKER_LOOP_PLAIN(C00,C01,C02,C03,C10,C11,C12,C13,C20,C21,C22,C23,A0,A1,A2,B0,B1,B2,B3) \ + DGEMM_3X1X2_NKER_SUBLOOP(C00,C10,C20,A0,A1,A2,B0) \ + DGEMM_3X1X2_NKER_SUBLOOP(C01,C11,C21,A0,A1,A2,B1) \ + DGEMM_3X1X2_NKER_SUBLOOP(C02,C12,C22,A0,A1,A2,B2) \ + DGEMM_3X1X2_NKER_SUBLOOP(C03,C13,C23,A0,A1,A2,B3) + +// For row-storage of C. +#define DLOADC_2V_R_FWD(C0,C1,CADDR,CSHIFT,RSC) \ + DLOAD2V(C0,C1,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#RSC" \n\t" +#define DSTOREC_2V_R_FWD(C0,C1,CADDR,CSHIFT,RSC) \ + DSTORE2V(C0,C1,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", "#RSC" \n\t" + +// For column-storage of C. +#define DLOADC_1V_1ELM_C_FWD(C0,CSCALAR,CIDX,CADDR,CSHIFT,CSC) \ + DLOAD1V(C0,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", #"#CSHIFT"+16 \n\t" \ +" ld1 {v"#CSCALAR".d}["#CIDX"], ["#CADDR"] \n\t" \ +" sub "#CADDR", "#CADDR", #"#CSHIFT"+16 \n\t" \ +" add "#CADDR", "#CADDR", "#CSC" \n\t" +#define DSTOREC_1V_1ELM_C_FWD(C0,CSCALAR,CIDX,CADDR,CSHIFT,CSC) \ + DSTORE1V(C0,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", #"#CSHIFT"+16 \n\t" \ +" st1 {v"#CSCALAR".d}["#CIDX"], ["#CADDR"] \n\t" \ +" sub "#CADDR", "#CADDR", #"#CSHIFT"+16 \n\t" \ +" add "#CADDR", "#CADDR", "#CSC" \n\t" + +#define DSCALE6V(V0,V1,V2,V3,V4,V5,A,IDX) \ + DSCALE4V(V0,V1,V2,V3,A,IDX) \ + DSCALE2V(V4,V5,A,IDX) +#define DSCALEA6V(D0,D1,D2,D3,D4,D5,S0,S1,S2,S3,S4,S5,A,IDX) \ + DSCALEA4V(D0,D1,D2,D3,S0,S1,S2,S3,A,IDX) \ + DSCALEA2V(D4,D5,S4,S5,A,IDX) + +void bli_dgemmsup_rd_armv8a_asm_3x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + assert( m0 == 3 ); + assert( n0 == 4 ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_mker = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + assert( cs_a0 == 1 ); + assert( rs_b0 == 1 ); + + __asm__ volatile + ( +" ldr x0, %[a] \n\t" +" ldr x1, %[b] \n\t" +" ldr x2, %[rs_a] \n\t" // Row-skip of A. +" ldr x3, %[cs_b] \n\t" // Column-skip of B. +" \n\t" +" ldr x5, %[c] \n\t" +" ldr x6, %[rs_c] \n\t" // Row-skip of C. +" ldr x7, %[cs_c] \n\t" // Column-skip of C. +" \n\t" +" \n\t" // Multiply some address skips by sizeof(double). +" lsl x2, x2, #3 \n\t" // rs_a +" lsl x3, x3, #3 \n\t" // cs_b +" lsl x6, x6, #3 \n\t" // rs_c +" lsl x7, x7, #3 \n\t" // cs_c +" \n\t" +" ldr x4, %[k_mker] \n\t" +" ldr x8, %[k_left] \n\t" +" \n\t" +// Storage scheme: +// V[ 0:11] <- C +// V[12:14] <- A +// V[16:19] <- B +// Under this scheme, the following is defined: +#define DGEMM_3X4X2_K_MKER_LOOP_PLAIN_LOC(A0,A1,A2,B0,B1,B2,B3) \ + DGEMM_3X4X2_K_MKER_LOOP_PLAIN(0,1,2,3,4,5,6,7,8,9,10,11,A0,A1,A2,B0,B1,B2,B3) +// Load from memory. +LABEL(LOAD_ABC) +" \n\t" // No-microkernel early return is a must +" cmp x4, #0 \n\t" // to avoid out-of-boundary read. +BEQ(CLEAR_CCOLS) +" \n\t" +" mov x11, x1 \n\t" // Load B. +" ldr q16, [x11] \n\t" +" add x11, x11, x3 \n\t" +" ldr q17, [x11] \n\t" +" add x11, x11, x3 \n\t" +" ldr q18, [x11] \n\t" +" add x11, x11, x3 \n\t" +" ldr q19, [x11] \n\t" +" add x1, x1, #16 \n\t" +" \n\t" +" mov x14, x0 \n\t" // Load A. +" ldr q12, [x14] \n\t" +" add x14, x14, x2 \n\t" +" ldr q13, [x14] \n\t" +" add x14, x14, x2 \n\t" +" ldr q14, [x14] \n\t" +" add x0, x0, #16 \n\t" +LABEL(CLEAR_CCOLS) +CLEAR8V(0,1,2,3,4,5,6,7) +CLEAR4V(8,9,10,11) +// No-microkernel early return, once again. +BEQ(K_LEFT_LOOP) +// +// Microkernel is defined here as: +#define DGEMM_3X4X2_K_MKER_LOOP_PLAIN_LOC_FWD(A0,A1,A2,B0,B1,B2,B3) \ + DGEMM_3X4X2_K_MKER_LOOP_PLAIN_LOC(A0,A1,A2,B0,B1,B2,B3) \ + "mov x11, x1 \n\t" \ + "ldr q"#B0", [x11] \n\t" \ + "add x11, x11, x3 \n\t" \ + "ldr q"#B1", [x11] \n\t" \ + "add x11, x11, x3 \n\t" \ + "ldr q"#B2", [x11] \n\t" \ + "add x11, x11, x3 \n\t" \ + "ldr q"#B3", [x11] \n\t" \ + "add x1, x1, #16 \n\t" \ + "mov x14, x0 \n\t" \ + "ldr q"#A0", [x14] \n\t" \ + "add x14, x14, x2 \n\t" \ + "ldr q"#A1", [x14] \n\t" \ + "add x14, x14, x2 \n\t" \ + "ldr q"#A2", [x14] \n\t" \ + "add x0, x0, #16 \n\t" +// Start microkernel loop. +LABEL(K_MKER_LOOP) +DGEMM_3X4X2_K_MKER_LOOP_PLAIN_LOC_FWD(12,13,14,16,17,18,19) +" \n\t" // Decrease counter before final replica. +" subs x4, x4, #1 \n\t" // Branch early to avoid reading excess mem. +BEQ(FIN_MKER_LOOP) +DGEMM_3X4X2_K_MKER_LOOP_PLAIN_LOC_FWD(12,13,14,16,17,18,19) +BRANCH(K_MKER_LOOP) +// +// Final microkernel loop. +LABEL(FIN_MKER_LOOP) +DGEMM_3X4X2_K_MKER_LOOP_PLAIN_LOC(12,13,14,16,17,18,19) +// +// If major kernel is executed, +// an additional depth-summation is required. +" faddp v0.2d, v0.2d, v1.2d \n\t" // Line 0. +" faddp v1.2d, v2.2d, v3.2d \n\t" +" faddp v2.2d, v4.2d, v5.2d \n\t" // Line 1. +" faddp v3.2d, v6.2d, v7.2d \n\t" +" faddp v4.2d, v8.2d, v9.2d \n\t" // Line 2. +" faddp v5.2d, v10.2d, v11.2d \n\t" +" \n\t" +// Loops left behind microkernels. +LABEL(K_LEFT_LOOP) +" cmp x8, #0 \n\t" // End of exec. +BEQ(WRITE_MEM_PREP) +" mov x11, x1 \n\t" // Load B row. +" ld1 {v28.d}[0], [x11], x3 \n\t" +" ld1 {v28.d}[1], [x11], x3 \n\t" +" ld1 {v29.d}[0], [x11], x3 \n\t" +" ld1 {v29.d}[1], [x11], x3 \n\t" +" add x1, x1, #8 \n\t" +" mov x14, x0 \n\t" // Load A column. +" ld1 {v24.d}[0], [x14], x2 \n\t" +" ld1 {v24.d}[1], [x14], x2 \n\t" +" ld1 {v25.d}[0], [x14], x2 \n\t" +" add x0, x0, #8 \n\t" +" fmla v0.2d, v28.2d, v24.d[0] \n\t" +" fmla v1.2d, v29.2d, v24.d[0] \n\t" +" fmla v2.2d, v28.2d, v24.d[1] \n\t" +" fmla v3.2d, v29.2d, v24.d[1] \n\t" +" fmla v4.2d, v28.2d, v25.d[0] \n\t" +" fmla v5.2d, v29.2d, v25.d[0] \n\t" +" sub x8, x8, #1 \n\t" +BRANCH(K_LEFT_LOOP) +// +// Scale and write to memory. +LABEL(WRITE_MEM_PREP) +" ldr x4, %[alpha] \n\t" // Load alpha & beta (address). +" ldr x8, %[beta] \n\t" +" ld1r {v30.2d}, [x4] \n\t" // Load alpha & beta (value). +" ld1r {v31.2d}, [x8] \n\t" +DSCALE6V(0,1,2,3,4,5,30,0) +" \n\t" +" mov x9, x5 \n\t" // C address for loading. +" \n\t" // C address for storing is x5 itself. +" cmp x7, #8 \n\t" // Check for column-storage. +BNE(WRITE_MEM_C) +// +// C storage in rows. +LABEL(WRITE_MEM_R) +" fcmp d31, #0.0 \n\t" +BEQ(ZERO_BETA_R) +DLOADC_2V_R_FWD(12,13,x9,0,x6) +DLOADC_2V_R_FWD(14,15,x9,0,x6) +DLOADC_2V_R_FWD(16,17,x9,0,x6) +DSCALEA6V(0,1,2,3,4,5,12,13,14,15,16,17,31,0) +LABEL(ZERO_BETA_R) +DSTOREC_2V_R_FWD(0,1,x5,0,x6) +DSTOREC_2V_R_FWD(2,3,x5,0,x6) +DSTOREC_2V_R_FWD(4,5,x5,0,x6) +BRANCH(END_WRITE_MEM) +// +// C storage in columns. +LABEL(WRITE_MEM_C) +" trn1 v6.2d, v0.2d, v2.2d \n\t" +" trn2 v7.2d, v0.2d, v2.2d \n\t" +" trn1 v8.2d, v1.2d, v3.2d \n\t" +" trn2 v9.2d, v1.2d, v3.2d \n\t" +" fcmp d31, #0.0 \n\t" +BEQ(ZERO_BETA_C) +DLOADC_1V_1ELM_C_FWD(12,20,0,x9,0,x7) +DLOADC_1V_1ELM_C_FWD(13,20,1,x9,0,x7) +DLOADC_1V_1ELM_C_FWD(14,21,0,x9,0,x7) +DLOADC_1V_1ELM_C_FWD(15,21,1,x9,0,x7) +DSCALEA6V(6,7,8,9,4,5,12,13,14,15,20,21,31,0) +LABEL(ZERO_BETA_C) +DSTOREC_1V_1ELM_C_FWD(6,4,0,x5,0,x7) +DSTOREC_1V_1ELM_C_FWD(7,4,1,x5,0,x7) +DSTOREC_1V_1ELM_C_FWD(8,5,0,x5,0,x7) +DSTOREC_1V_1ELM_C_FWD(9,5,1,x5,0,x7) +// +// End of this microkernel. +LABEL(END_WRITE_MEM) +// +// End of execution. +LABEL(END_EXEC) +: +: [a] "m" (a), + [b] "m" (b), + [c] "m" (c), + [rs_a] "m" (rs_a), + [cs_b] "m" (cs_b), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + [k_mker] "m" (k_mker), + [k_left] "m" (k_left), + [alpha] "m" (alpha), + [beta] "m" (beta) +: "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", + "x8", "x9", "x10","x11","x12","x13","x14", + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", + "v8", "v9", "v10","v11","v12","v13","v14","v15", + "v16","v17","v18","v19","v20","v21","v22","v23", + "v24","v25","v26","v27","v28","v29","v30","v31" + ); + +} + diff --git a/kernels/armv8a/3/sup/d3x4/bli_gemmsup_rd_armv8a_asm_d6x3.c b/kernels/armv8a/3/sup/d3x4/bli_gemmsup_rd_armv8a_asm_d6x3.c new file mode 100644 index 0000000000..abbb6fb4d9 --- /dev/null +++ b/kernels/armv8a/3/sup/d3x4/bli_gemmsup_rd_armv8a_asm_d6x3.c @@ -0,0 +1,359 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2021, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ + +// Supplimentary fixed-size gemmsup. + +#include "blis.h" +#include "assert.h" + +// Label locality & misc. +#include "../../armv8a_asm_utils.h" + +#define DGEMM_1X3X2_NKER_SUBLOOP(C0,C1,C2,A,B0,B1,B2) \ +" fmla v"#C0".2d, v"#A".2d, v"#B0".2d \n\t" \ +" fmla v"#C1".2d, v"#A".2d, v"#B1".2d \n\t" \ +" fmla v"#C2".2d, v"#A".2d, v"#B2".2d \n\t" + +#define DGEMM_6X3X2_K_MKER_LOOP_PLAIN(C00,C01,C02,C10,C11,C12,C20,C21,C22,C30,C31,C32,C40,C41,C42,C50,C51,C52,A0,A1,A2,A3,A4,A5,B0,B1,B2,AADDR,AELEMADDR,AELEMST,LOAD0,LOAD1) \ + DGEMM_1X3X2_NKER_SUBLOOP(C00,C01,C02,A0,B0,B1,B2) \ + DGEMM_LOAD1V_K_ ##LOAD0 (A0,AELEMADDR,AELEMST) \ + DGEMM_1X3X2_NKER_SUBLOOP(C10,C11,C12,A1,B0,B1,B2) \ + DGEMM_LOAD1V_K_ ##LOAD0 (A1,AELEMADDR,AELEMST) \ + DGEMM_1X3X2_NKER_SUBLOOP(C20,C21,C22,A2,B0,B1,B2) \ + DGEMM_LOAD1V_K_ ##LOAD0 (A2,AELEMADDR,AELEMST) \ + DGEMM_1X3X2_NKER_SUBLOOP(C30,C31,C32,A3,B0,B1,B2) \ + DGEMM_LOAD1V_K_ ##LOAD0 (A3,AELEMADDR,AELEMST) \ + DGEMM_FWDA_K_ ##LOAD0 (AADDR) \ +" mov "#AELEMADDR", "#AADDR" \n\t" \ + DGEMM_1X3X2_NKER_SUBLOOP(C40,C41,C42,A4,B0,B1,B2) \ + DGEMM_LOAD1V_K_ ##LOAD1 (A4,AELEMADDR,AELEMST) \ + DGEMM_1X3X2_NKER_SUBLOOP(C50,C51,C52,A5,B0,B1,B2) \ + DGEMM_LOAD1V_K_ ##LOAD1 (A5,AELEMADDR,AELEMST) + +#define DGEMM_LOAD1V_K_noload(V,ELEMADDR,ELEMST) +#define DGEMM_LOAD1V_K_load(V,ELEMADDR,ELEMST) \ +" ldr q"#V", [ "#ELEMADDR" ] \n\t" \ +" add "#ELEMADDR", "#ELEMADDR", "#ELEMST" \n\t" + +#define DGEMM_FWDA_K_noload(ADDR) +#define DGEMM_FWDA_K_load(ADDR) \ +" add "#ADDR", "#ADDR", #16 \n\t" + +// For row-storage of C. +#define DLOADC_1V_1ELM_R_FWD(C0,CSCALAR,CIDX,CADDR,CSHIFT,RSC) \ + DLOAD1V(C0,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", #"#CSHIFT"+16 \n\t" \ +" ld1 {v"#CSCALAR".d}["#CIDX"], ["#CADDR"] \n\t" \ +" sub "#CADDR", "#CADDR", #"#CSHIFT"+16 \n\t" \ +" add "#CADDR", "#CADDR", "#RSC" \n\t" +#define DSTOREC_1V_1ELM_R_FWD(C0,CSCALAR,CIDX,CADDR,CSHIFT,RSC) \ + DSTORE1V(C0,CADDR,CSHIFT) \ +" add "#CADDR", "#CADDR", #"#CSHIFT"+16 \n\t" \ +" st1 {v"#CSCALAR".d}["#CIDX"], ["#CADDR"] \n\t" \ +" sub "#CADDR", "#CADDR", #"#CSHIFT"+16 \n\t" \ +" add "#CADDR", "#CADDR", "#RSC" \n\t" + +// For column-storage of C. +#define DLOADC_3V_C_FWD(C0,C1,C2,CADDR,CSHIFT,CSC) \ + DLOAD2V(C0,C1,CADDR,CSHIFT) \ + DLOAD1V(C2,CADDR,CSHIFT+32) \ +" add "#CADDR", "#CADDR", "#CSC" \n\t" +#define DSTOREC_3V_C_FWD(C0,C1,C2,CADDR,CSHIFT,CSC) \ + DSTORE2V(C0,C1,CADDR,CSHIFT) \ + DSTORE1V(C2,CADDR,CSHIFT+32) \ +" add "#CADDR", "#CADDR", "#CSC" \n\t" + +#define DSCALE9V(V0,V1,V2,V3,V4,V5,V6,V7,V8,A,IDX) \ + DSCALE4V(V0,V1,V2,V3,A,IDX) \ + DSCALE4V(V4,V5,V6,V7,A,IDX) \ + DSCALE1V(V8,A,IDX) +#define DSCALEA9V(D0,D1,D2,D3,D4,D5,D6,D7,D8,S0,S1,S2,S3,S4,S5,S6,S7,S8,A,IDX) \ + DSCALEA4V(D0,D1,D2,D3,S0,S1,S2,S3,A,IDX) \ + DSCALEA4V(D4,D5,D6,D7,S4,S5,S6,S7,A,IDX) \ + DSCALEA1V(D8,S8,A,IDX) + + +void bli_dgemmsup_rd_armv8a_asm_6x3 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + assert( m0 == 6 ); + assert( n0 == 3 ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_mker = k0 / 8; + uint64_t k_left = k0 % 8; + + uint64_t rs_a = rs_a0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + assert( cs_a0 == 1 ); + assert( rs_b0 == 1 ); + + __asm__ volatile + ( +" ldr x0, %[a] \n\t" +" ldr x1, %[b] \n\t" +" ldr x2, %[rs_a] \n\t" // Row-skip of A. +" ldr x3, %[cs_b] \n\t" // Column-skip of B. +" \n\t" +" ldr x5, %[c] \n\t" +" ldr x6, %[rs_c] \n\t" // Row-skip of C. +" ldr x7, %[cs_c] \n\t" // Column-skip of C. +" \n\t" +" \n\t" // Multiply some address skips by sizeof(double). +" lsl x2, x2, #3 \n\t" // rs_a +" lsl x3, x3, #3 \n\t" // cs_b +" lsl x6, x6, #3 \n\t" // rs_c +" lsl x7, x7, #3 \n\t" // cs_c +" \n\t" +" ldr x4, %[k_mker] \n\t" +" ldr x8, %[k_left] \n\t" +" \n\t" +// Storage scheme: +// V[ 0:17] <- C +// V[18:23] <- B +// V[24:31] <- A +// Under this scheme, the following is defined: +#define DGEMM_6X3X2_K_MKER_LOOP_PLAIN_LOC(A0,A1,A2,A3,A4,A5,B0,B1,B2,AADDR,AELEMADDR,AELEMST,LOAD0,LOAD1) \ + DGEMM_6X3X2_K_MKER_LOOP_PLAIN(0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,A0,A1,A2,A3,A4,A5,B0,B1,B2,AADDR,AELEMADDR,AELEMST,LOAD0,LOAD1) +// Load from memory. +LABEL(LOAD_ABC) +" \n\t" // No-microkernel early return is a must +" cmp x4, #0 \n\t" // to avoid out-of-boundary read. +BEQ(CLEAR_CCOLS) +" \n\t" +" mov x14, x0 \n\t" // Load A. +" ldr q24, [x14] \n\t" +" add x14, x14, x2 \n\t" +" ldr q25, [x14] \n\t" +" add x14, x14, x2 \n\t" +" ldr q26, [x14] \n\t" +" add x14, x14, x2 \n\t" +" ldr q27, [x14] \n\t" +" add x14, x14, x2 \n\t" +" ldr q28, [x14] \n\t" +" add x14, x14, x2 \n\t" +" ldr q29, [x14] \n\t" +" add x0, x0, #16 \n\t" +" mov x14, x0 \n\t" +" ldr q30, [x14] \n\t" +" add x14, x14, x2 \n\t" +" ldr q31, [x14] \n\t" +" add x14, x14, x2 \n\t" +" \n\t" +" mov x11, x1 \n\t" // Load B. +" ldr q18, [x11] \n\t" +" add x11, x11, x3 \n\t" +" ldr q19, [x11] \n\t" +" add x11, x11, x3 \n\t" +" ldr q20, [x11] \n\t" +" add x1, x1, #16 \n\t" +" mov x11, x1 \n\t" +" ldr q21, [x11] \n\t" +" add x11, x11, x3 \n\t" +" ldr q22, [x11] \n\t" +" add x11, x11, x3 \n\t" +" ldr q23, [x11] \n\t" +" add x1, x1, #16 \n\t" +LABEL(CLEAR_CCOLS) +CLEAR8V(0,1,2,3,4,5,6,7) +CLEAR8V(8,9,10,11,12,13,14,15) +CLEAR2V(16,17) +// No-microkernel early return, once again. +BEQ(K_LEFT_LOOP) +// +// Microkernel is defined here as: +#define DGEMM_6X3X2_K_MKER_LOOP_PLAIN_LOC_FWD(A0,A1,A2,A3,A4,A5,B0,B1,B2) \ + DGEMM_6X3X2_K_MKER_LOOP_PLAIN_LOC(A0,A1,A2,A3,A4,A5,B0,B1,B2,x0,x14,x2,load,load) \ + "mov x11, x1 \n\t" \ + "ldr q"#B0", [x11] \n\t" \ + "add x11, x11, x3 \n\t" \ + "ldr q"#B1", [x11] \n\t" \ + "add x11, x11, x3 \n\t" \ + "ldr q"#B2", [x11] \n\t" \ + "add x1, x1, #16 \n\t" \ +// Start microkernel loop. +LABEL(K_MKER_LOOP) +DGEMM_6X3X2_K_MKER_LOOP_PLAIN_LOC_FWD(24,25,26,27,28,29,18,19,20) +DGEMM_6X3X2_K_MKER_LOOP_PLAIN_LOC_FWD(30,31,24,25,26,27,21,22,23) +" \n\t" // Decrease counter before final replica. +" subs x4, x4, #1 \n\t" // Branch early to avoid reading excess mem. +BEQ(FIN_MKER_LOOP) +DGEMM_6X3X2_K_MKER_LOOP_PLAIN_LOC_FWD(28,29,30,31,24,25,18,19,20) +DGEMM_6X3X2_K_MKER_LOOP_PLAIN_LOC_FWD(26,27,28,29,30,31,21,22,23) +BRANCH(K_MKER_LOOP) +// +// Final microkernel loop. +LABEL(FIN_MKER_LOOP) +DGEMM_6X3X2_K_MKER_LOOP_PLAIN_LOC(28,29,30,31,24,25,18,19,20,x0,x14,x2,load,noload) +DGEMM_6X3X2_K_MKER_LOOP_PLAIN_LOC(26,27,28,29,30,31,21,22,23,xzr,xzr,xzr,noload,noload) +// +// If major kernel is executed, +// an additional depth-summation is required. +" faddp v0.2d, v0.2d, v3.2d \n\t" // Column 0 Prt 0. +" faddp v1.2d, v1.2d, v4.2d \n\t" // Column 1 Prt 0. +" faddp v2.2d, v2.2d, v5.2d \n\t" // Column 2 Prt 0. +" faddp v3.2d, v6.2d, v9.2d \n\t" // Column 0 Prt 1. +" faddp v4.2d, v7.2d, v10.2d \n\t" // Column 1 Prt 1. +" faddp v5.2d, v8.2d, v11.2d \n\t" // Column 2 Prt 1. +" faddp v6.2d, v12.2d, v15.2d \n\t" // Column 0 Prt 2. +" faddp v7.2d, v13.2d, v16.2d \n\t" // Column 1 Prt 2. +" faddp v8.2d, v14.2d, v17.2d \n\t" // Column 2 Prt 2. +" \n\t" +// Loops left behind microkernels. +LABEL(K_LEFT_LOOP) +" cmp x8, #0 \n\t" // End of exec. +BEQ(WRITE_MEM_PREP) +" mov x14, x0 \n\t" // Load A column. +" ld1 {v24.d}[0], [x14], x2 \n\t" +" ld1 {v24.d}[1], [x14], x2 \n\t" +" ld1 {v25.d}[0], [x14], x2 \n\t" +" ld1 {v25.d}[1], [x14], x2 \n\t" +" ld1 {v26.d}[0], [x14], x2 \n\t" +" ld1 {v26.d}[1], [x14], x2 \n\t" +" add x0, x0, #8 \n\t" +" mov x11, x1 \n\t" // Load B row. +" ld1 {v28.d}[0], [x11], x3 \n\t" +" ld1 {v28.d}[1], [x11], x3 \n\t" +" ld1 {v29.d}[0], [x11], x3 \n\t" +" add x1, x1, #8 \n\t" +" fmla v0.2d, v24.2d, v28.d[0] \n\t" +" fmla v3.2d, v25.2d, v28.d[0] \n\t" +" fmla v6.2d, v26.2d, v28.d[0] \n\t" +" fmla v1.2d, v24.2d, v28.d[1] \n\t" +" fmla v4.2d, v25.2d, v28.d[1] \n\t" +" fmla v7.2d, v26.2d, v28.d[1] \n\t" +" fmla v2.2d, v24.2d, v29.d[0] \n\t" +" fmla v5.2d, v25.2d, v29.d[0] \n\t" +" fmla v8.2d, v26.2d, v29.d[0] \n\t" +" sub x8, x8, #1 \n\t" +BRANCH(K_LEFT_LOOP) +// +// Scale and write to memory. +LABEL(WRITE_MEM_PREP) +" ldr x4, %[alpha] \n\t" // Load alpha & beta (address). +" ldr x8, %[beta] \n\t" +" ld1r {v30.2d}, [x4] \n\t" // Load alpha & beta (value). +" ld1r {v31.2d}, [x8] \n\t" +DSCALE9V(0,1,2,3,4,5,6,7,8,30,0) +" \n\t" +" mov x9, x5 \n\t" // C address for loading. +" \n\t" // C address for storing is x5 itself. +" cmp x7, #8 \n\t" // Check for column-storage. +BNE(WRITE_MEM_C) +// +// C storage in rows. +LABEL(WRITE_MEM_R) +" trn1 v20.2d, v0.2d, v1.2d \n\t" +" trn2 v21.2d, v0.2d, v1.2d \n\t" +" trn1 v22.2d, v3.2d, v4.2d \n\t" +" trn2 v23.2d, v3.2d, v4.2d \n\t" +" trn1 v24.2d, v6.2d, v7.2d \n\t" +" trn2 v25.2d, v6.2d, v7.2d \n\t" +" fcmp d31, #0.0 \n\t" +BEQ(ZERO_BETA_R) +DLOADC_1V_1ELM_R_FWD(10,26,0,x9,0,x6) +DLOADC_1V_1ELM_R_FWD(11,26,1,x9,0,x6) +DLOADC_1V_1ELM_R_FWD(12,27,0,x9,0,x6) +DLOADC_1V_1ELM_R_FWD(13,27,1,x9,0,x6) +DLOADC_1V_1ELM_R_FWD(14,28,0,x9,0,x6) +DLOADC_1V_1ELM_R_FWD(15,28,1,x9,0,x6) +DSCALEA9V(20,21,22,23,24,25,2,5,8,10,11,12,13,14,15,26,27,28,31,0) +LABEL(ZERO_BETA_R) +DSTOREC_1V_1ELM_R_FWD(20,2,0,x5,0,x6) +DSTOREC_1V_1ELM_R_FWD(21,2,1,x5,0,x6) +DSTOREC_1V_1ELM_R_FWD(22,5,0,x5,0,x6) +DSTOREC_1V_1ELM_R_FWD(23,5,1,x5,0,x6) +DSTOREC_1V_1ELM_R_FWD(24,8,0,x5,0,x6) +DSTOREC_1V_1ELM_R_FWD(25,8,1,x5,0,x6) +BRANCH(END_WRITE_MEM) +// +// C storage in columns. +LABEL(WRITE_MEM_C) +" fcmp d31, #0.0 \n\t" +BEQ(ZERO_BETA_C) +DLOADC_3V_C_FWD(12,15,18,x9,0,x7) +DLOADC_3V_C_FWD(13,16,19,x9,0,x7) +DLOADC_3V_C_FWD(14,17,20,x9,0,x7) +DSCALEA9V(0,1,2,3,4,5,6,7,8,12,13,14,15,16,17,18,19,20,31,0) +LABEL(ZERO_BETA_C) +DSTOREC_3V_C_FWD(0,3,6,x5,0,x7) +DSTOREC_3V_C_FWD(1,4,7,x5,0,x7) +DSTOREC_3V_C_FWD(2,5,8,x5,0,x7) +// +// End of this microkernel. +LABEL(END_WRITE_MEM) +// +// End of execution. +LABEL(END_EXEC) +: +: [a] "m" (a), + [b] "m" (b), + [c] "m" (c), + [rs_a] "m" (rs_a), + [cs_b] "m" (cs_b), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + [k_mker] "m" (k_mker), + [k_left] "m" (k_left), + [alpha] "m" (alpha), + [beta] "m" (beta) +: "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", + "x8", "x9", "x10","x11","x12","x13","x14", + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", + "v8", "v9", "v10","v11","v12","v13","v14","v15", + "v16","v17","v18","v19","v20","v21","v22","v23", + "v24","v25","v26","v27","v28","v29","v30","v31" + ); + +} + + diff --git a/kernels/armv8a/3/sup/d3x4/bli_gemmsup_rd_armv8a_int_d2x8.c b/kernels/armv8a/3/sup/d3x4/bli_gemmsup_rd_armv8a_int_d2x8.c new file mode 100644 index 0000000000..43880063eb --- /dev/null +++ b/kernels/armv8a/3/sup/d3x4/bli_gemmsup_rd_armv8a_int_d2x8.c @@ -0,0 +1,383 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2021, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ + +// Supplimentary dynamic-size gemmsup. + +#include "blis.h" +#include "assert.h" +#include + +#if defined(__clang__) +#define PRAGMA_NOUNROLL _Pragma("nounroll") +#define PRAGMA_UNROLL _Pragma("unroll") +#elif defined(__GNUC__) +#define PRAGMA_NOUNROLL _Pragma("GCC unroll 1") +#define PRAGMA_UNROLL _Pragma("GCC unroll 2") +#else +#define PRAGMA_NOUNROLL +#define PRAGMA_UNROLL +#endif + +/* + * As these kernels requires num. of vregs about half of the total 32, + * it should be all right to implement w/ intrinsics. + * + * c.f. https://www.youtube.com/watch?v=R2hQOVjRwVE . + */ +void bli_dgemmsup_rd_armv8a_int_2x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a, inc_t cs_a, + double* restrict b, inc_t rs_b, inc_t cs_b, + double* restrict beta, + double* restrict c, inc_t rs_c, inc_t cs_c, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + assert( m0 <= 2 ); + assert( n0 <= 8 ); + + double *a_loc = a; + double *b_loc = b; + double *c_loc = c; + + uint64_t k_mker = k0 / 2; + uint64_t k_left = k0 % 2; + uint64_t b_iszr = ( *beta == 0.0 ); + + assert( cs_a == 1 ); + assert( rs_b == 1 ); + + // Registers used to store a 2x8x2 block of C (summing the last dimension). + // Total: 22 specified. + float64x2_t vc_00, vc_01, vc_02, vc_03, vc_04, vc_05, vc_06, vc_07; + float64x2_t vc_10, vc_11, vc_12, vc_13, vc_14, vc_15, vc_16, vc_17; + float64x2_t va_0, va_1; + float64x2_t vb_0, vb_1, vb_2, vb_3; + + vc_00 = (float64x2_t)vdupq_n_f64( 0 ); + vc_01 = (float64x2_t)vdupq_n_f64( 0 ); + vc_02 = (float64x2_t)vdupq_n_f64( 0 ); + vc_03 = (float64x2_t)vdupq_n_f64( 0 ); + vc_04 = (float64x2_t)vdupq_n_f64( 0 ); + vc_05 = (float64x2_t)vdupq_n_f64( 0 ); + vc_06 = (float64x2_t)vdupq_n_f64( 0 ); + vc_07 = (float64x2_t)vdupq_n_f64( 0 ); + vc_10 = (float64x2_t)vdupq_n_f64( 0 ); + vc_11 = (float64x2_t)vdupq_n_f64( 0 ); + vc_12 = (float64x2_t)vdupq_n_f64( 0 ); + vc_13 = (float64x2_t)vdupq_n_f64( 0 ); + vc_14 = (float64x2_t)vdupq_n_f64( 0 ); + vc_15 = (float64x2_t)vdupq_n_f64( 0 ); + vc_16 = (float64x2_t)vdupq_n_f64( 0 ); + vc_17 = (float64x2_t)vdupq_n_f64( 0 ); + + PRAGMA_UNROLL + for ( ; k_mker > 0; --k_mker ) + { + // if ( m0 > 0 ) + va_0 = vld1q_f64( a_loc + rs_a * 0 ); + if ( m0 > 1 ) va_1 = vld1q_f64( a_loc + rs_a * 1 ); + // if ( n0 > 0 ) + vb_0 = vld1q_f64( b_loc + cs_b * 0 ); + if ( n0 > 1 ) vb_1 = vld1q_f64( b_loc + cs_b * 1 ); + if ( n0 > 2 ) vb_2 = vld1q_f64( b_loc + cs_b * 2 ); + if ( n0 > 3 ) vb_3 = vld1q_f64( b_loc + cs_b * 3 ); + + vc_00 = vfmaq_f64( vc_00, va_0, vb_0 ); + vc_01 = vfmaq_f64( vc_01, va_0, vb_1 ); + vc_02 = vfmaq_f64( vc_02, va_0, vb_2 ); + vc_03 = vfmaq_f64( vc_03, va_0, vb_3 ); + if ( m0 > 1 ) + { + vc_10 = vfmaq_f64( vc_10, va_1, vb_0 ); + vc_11 = vfmaq_f64( vc_11, va_1, vb_1 ); + vc_12 = vfmaq_f64( vc_12, va_1, vb_2 ); + vc_13 = vfmaq_f64( vc_13, va_1, vb_3 ); + } + + if ( n0 > 4 ) { + vb_0 = vld1q_f64( b_loc + cs_b * 4 ); + if ( n0 > 5 ) vb_1 = vld1q_f64( b_loc + cs_b * 5 ); + if ( n0 > 6 ) vb_2 = vld1q_f64( b_loc + cs_b * 6 ); + if ( n0 > 7 ) vb_3 = vld1q_f64( b_loc + cs_b * 7 ); + + vc_04 = vfmaq_f64( vc_04, va_0, vb_0 ); + vc_05 = vfmaq_f64( vc_05, va_0, vb_1 ); + if ( n0 > 6 ) + { + vc_06 = vfmaq_f64( vc_06, va_0, vb_2 ); + vc_07 = vfmaq_f64( vc_07, va_0, vb_3 ); + } + if ( m0 > 1 ) + { + vc_14 = vfmaq_f64( vc_14, va_1, vb_0 ); + vc_15 = vfmaq_f64( vc_15, va_1, vb_1 ); + if ( n0 > 6 ) + { + vc_16 = vfmaq_f64( vc_16, va_1, vb_2 ); + vc_17 = vfmaq_f64( vc_17, va_1, vb_3 ); + } + } + } + + a_loc += 2; + b_loc += 2; + } + + // Pay no care for O(1) details. + va_0 = (float64x2_t)vdupq_n_f64( 0 ); + va_1 = (float64x2_t)vdupq_n_f64( 0 ); + vb_0 = (float64x2_t)vdupq_n_f64( 0 ); + vb_1 = (float64x2_t)vdupq_n_f64( 0 ); + vb_2 = (float64x2_t)vdupq_n_f64( 0 ); + vb_3 = (float64x2_t)vdupq_n_f64( 0 ); + PRAGMA_NOUNROLL + for ( ; k_left > 0; --k_left ) + { + // if ( m0 > 0 ) + va_0 = vld1q_lane_f64( a_loc + rs_a * 0, va_0, 0 ); + if ( m0 > 1 ) va_1 = vld1q_lane_f64( a_loc + rs_a * 1, va_1, 0 ); + // if ( n0 > 0 ) + vb_0 = vld1q_lane_f64( b_loc + cs_b * 0, vb_0, 0 ); + if ( n0 > 1 ) vb_1 = vld1q_lane_f64( b_loc + cs_b * 1, vb_1, 0 ); + if ( n0 > 2 ) vb_2 = vld1q_lane_f64( b_loc + cs_b * 2, vb_2, 0 ); + if ( n0 > 3 ) vb_3 = vld1q_lane_f64( b_loc + cs_b * 3, vb_3, 0 ); + + vc_00 = vfmaq_f64( vc_00, va_0, vb_0 ); + vc_01 = vfmaq_f64( vc_01, va_0, vb_1 ); + vc_02 = vfmaq_f64( vc_02, va_0, vb_2 ); + vc_03 = vfmaq_f64( vc_03, va_0, vb_3 ); + vc_10 = vfmaq_f64( vc_10, va_1, vb_0 ); + vc_11 = vfmaq_f64( vc_11, va_1, vb_1 ); + vc_12 = vfmaq_f64( vc_12, va_1, vb_2 ); + vc_13 = vfmaq_f64( vc_13, va_1, vb_3 ); + + if ( n0 > 4 ) vb_0 = vld1q_lane_f64( b_loc + cs_b * 4, vb_0, 0 ); + if ( n0 > 5 ) vb_1 = vld1q_lane_f64( b_loc + cs_b * 5, vb_1, 0 ); + if ( n0 > 6 ) vb_2 = vld1q_lane_f64( b_loc + cs_b * 6, vb_2, 0 ); + if ( n0 > 7 ) vb_3 = vld1q_lane_f64( b_loc + cs_b * 7, vb_3, 0 ); + + vc_04 = vfmaq_f64( vc_04, va_0, vb_0 ); + vc_05 = vfmaq_f64( vc_05, va_0, vb_1 ); + vc_06 = vfmaq_f64( vc_06, va_0, vb_2 ); + vc_07 = vfmaq_f64( vc_07, va_0, vb_3 ); + vc_14 = vfmaq_f64( vc_14, va_1, vb_0 ); + vc_15 = vfmaq_f64( vc_15, va_1, vb_1 ); + vc_16 = vfmaq_f64( vc_16, va_1, vb_2 ); + vc_17 = vfmaq_f64( vc_17, va_1, vb_3 ); + + a_loc += 1; + b_loc += 1; + } + + // Load alpha and beta. + // Note that here vb is used for alpha, in contrast to other kernels. + vb_0 = vld1q_dup_f64( alpha ); + va_0 = vld1q_dup_f64( beta ); + + // Scale. + vc_00 = vmulq_f64( vc_00, vb_0 ); + vc_01 = vmulq_f64( vc_01, vb_0 ); + vc_02 = vmulq_f64( vc_02, vb_0 ); + vc_03 = vmulq_f64( vc_03, vb_0 ); + vc_04 = vmulq_f64( vc_04, vb_0 ); + vc_05 = vmulq_f64( vc_05, vb_0 ); + vc_06 = vmulq_f64( vc_06, vb_0 ); + vc_07 = vmulq_f64( vc_07, vb_0 ); + vc_10 = vmulq_f64( vc_10, vb_0 ); + vc_11 = vmulq_f64( vc_11, vb_0 ); + vc_12 = vmulq_f64( vc_12, vb_0 ); + vc_13 = vmulq_f64( vc_13, vb_0 ); + vc_14 = vmulq_f64( vc_14, vb_0 ); + vc_15 = vmulq_f64( vc_15, vb_0 ); + vc_16 = vmulq_f64( vc_16, vb_0 ); + vc_17 = vmulq_f64( vc_17, vb_0 ); + + if ( cs_c == 1 ) + { + // Row-storage. + vc_00 = vpaddq_f64( vc_00, vc_01 ); + vc_02 = vpaddq_f64( vc_02, vc_03 ); + vc_04 = vpaddq_f64( vc_04, vc_05 ); + vc_06 = vpaddq_f64( vc_06, vc_07 ); + vc_10 = vpaddq_f64( vc_10, vc_11 ); + vc_12 = vpaddq_f64( vc_12, vc_13 ); + vc_14 = vpaddq_f64( vc_14, vc_15 ); + vc_16 = vpaddq_f64( vc_16, vc_17 ); + + if ( n0 > 1 ) vb_0 = vld1q_f64 ( c_loc + 0 * rs_c + 0 ); + else if ( n0 > 0 ) vb_0 = vld1q_lane_f64( c_loc + 0 * rs_c + 0, vb_0, 0 ); + if ( n0 > 3 ) vb_1 = vld1q_f64 ( c_loc + 0 * rs_c + 2 ); + else if ( n0 > 2 ) vb_1 = vld1q_lane_f64( c_loc + 0 * rs_c + 2, vb_1, 0 ); + if ( n0 > 5 ) vb_2 = vld1q_f64 ( c_loc + 0 * rs_c + 4 ); + else if ( n0 > 4 ) vb_2 = vld1q_lane_f64( c_loc + 0 * rs_c + 4, vb_2, 0 ); + if ( n0 > 7 ) vb_3 = vld1q_f64 ( c_loc + 0 * rs_c + 6 ); + else if ( n0 > 6 ) vb_3 = vld1q_lane_f64( c_loc + 0 * rs_c + 6, vb_3, 0 ); + if ( !b_iszr ) + { + vc_00 = vfmaq_f64( vc_00, va_0, vb_0 ); + vc_02 = vfmaq_f64( vc_02, va_0, vb_1 ); + vc_04 = vfmaq_f64( vc_04, va_0, vb_2 ); + vc_06 = vfmaq_f64( vc_06, va_0, vb_3 ); + } + if ( n0 > 1 ) vst1q_f64 ( c_loc + 0 * rs_c + 0, vc_00 ); + else if ( n0 > 0 ) vst1q_lane_f64( c_loc + 0 * rs_c + 0, vc_00, 0 ); + if ( n0 > 3 ) vst1q_f64 ( c_loc + 0 * rs_c + 2, vc_02 ); + else if ( n0 > 2 ) vst1q_lane_f64( c_loc + 0 * rs_c + 2, vc_02, 0 ); + if ( n0 > 5 ) vst1q_f64 ( c_loc + 0 * rs_c + 4, vc_04 ); + else if ( n0 > 4 ) vst1q_lane_f64( c_loc + 0 * rs_c + 4, vc_04, 0 ); + if ( n0 > 7 ) vst1q_f64 ( c_loc + 0 * rs_c + 6, vc_06 ); + else if ( n0 > 6 ) vst1q_lane_f64( c_loc + 0 * rs_c + 6, vc_06, 0 ); + + if ( m0 > 1 ) + { + if ( n0 > 1 ) vb_0 = vld1q_f64 ( c_loc + 1 * rs_c + 0 ); + else if ( n0 > 0 ) vb_0 = vld1q_lane_f64( c_loc + 1 * rs_c + 0, vb_0, 0 ); + if ( n0 > 3 ) vb_1 = vld1q_f64 ( c_loc + 1 * rs_c + 2 ); + else if ( n0 > 2 ) vb_1 = vld1q_lane_f64( c_loc + 1 * rs_c + 2, vb_1, 0 ); + if ( n0 > 5 ) vb_2 = vld1q_f64 ( c_loc + 1 * rs_c + 4 ); + else if ( n0 > 4 ) vb_2 = vld1q_lane_f64( c_loc + 1 * rs_c + 4, vb_2, 0 ); + if ( n0 > 7 ) vb_3 = vld1q_f64 ( c_loc + 1 * rs_c + 6 ); + else if ( n0 > 6 ) vb_3 = vld1q_lane_f64( c_loc + 1 * rs_c + 6, vb_3, 0 ); + if ( !b_iszr ) + { + vc_10 = vfmaq_f64( vc_10, va_0, vb_0 ); + vc_12 = vfmaq_f64( vc_12, va_0, vb_1 ); + vc_14 = vfmaq_f64( vc_14, va_0, vb_2 ); + vc_16 = vfmaq_f64( vc_16, va_0, vb_3 ); + } + if ( n0 > 1 ) vst1q_f64 ( c_loc + 1 * rs_c + 0, vc_10 ); + else if ( n0 > 0 ) vst1q_lane_f64( c_loc + 1 * rs_c + 0, vc_10, 0 ); + if ( n0 > 3 ) vst1q_f64 ( c_loc + 1 * rs_c + 2, vc_12 ); + else if ( n0 > 2 ) vst1q_lane_f64( c_loc + 1 * rs_c + 2, vc_12, 0 ); + if ( n0 > 5 ) vst1q_f64 ( c_loc + 1 * rs_c + 4, vc_14 ); + else if ( n0 > 4 ) vst1q_lane_f64( c_loc + 1 * rs_c + 4, vc_14, 0 ); + if ( n0 > 7 ) vst1q_f64 ( c_loc + 1 * rs_c + 6, vc_16 ); + else if ( n0 > 6 ) vst1q_lane_f64( c_loc + 1 * rs_c + 6, vc_16, 0 ); + } + } + else + { + // Column-storage. + vc_00 = vpaddq_f64( vc_00, vc_10 ); + vc_01 = vpaddq_f64( vc_01, vc_11 ); + vc_02 = vpaddq_f64( vc_02, vc_12 ); + vc_03 = vpaddq_f64( vc_03, vc_13 ); + vc_04 = vpaddq_f64( vc_04, vc_14 ); + vc_05 = vpaddq_f64( vc_05, vc_15 ); + vc_06 = vpaddq_f64( vc_06, vc_16 ); + vc_07 = vpaddq_f64( vc_07, vc_17 ); + + if ( m0 > 1 ) + { + // if ( n0 > 0 ) + vb_0 = vld1q_f64( c_loc + 0 + 0 * cs_c ); + if ( n0 > 1 ) vb_1 = vld1q_f64( c_loc + 0 + 1 * cs_c ); + if ( n0 > 2 ) vb_2 = vld1q_f64( c_loc + 0 + 2 * cs_c ); + if ( n0 > 3 ) vb_3 = vld1q_f64( c_loc + 0 + 3 * cs_c ); + if ( !b_iszr ) + { + vc_00 = vfmaq_f64( vc_00, va_0, vb_0 ); + vc_01 = vfmaq_f64( vc_01, va_0, vb_1 ); + vc_02 = vfmaq_f64( vc_02, va_0, vb_2 ); + vc_03 = vfmaq_f64( vc_03, va_0, vb_3 ); + } + vst1q_f64( c_loc + 0 + 0 * cs_c, vc_00 ); + if ( n0 > 1 ) vst1q_f64( c_loc + 0 + 1 * cs_c, vc_01 ); + if ( n0 > 2 ) vst1q_f64( c_loc + 0 + 2 * cs_c, vc_02 ); + if ( n0 > 3 ) vst1q_f64( c_loc + 0 + 3 * cs_c, vc_03 ); + + if ( n0 > 4 ) vb_0 = vld1q_f64( c_loc + 0 + 4 * cs_c ); + if ( n0 > 5 ) vb_1 = vld1q_f64( c_loc + 0 + 5 * cs_c ); + if ( n0 > 6 ) vb_2 = vld1q_f64( c_loc + 0 + 6 * cs_c ); + if ( n0 > 7 ) vb_3 = vld1q_f64( c_loc + 0 + 7 * cs_c ); + if ( !b_iszr ) + { + vc_04 = vfmaq_f64( vc_04, va_0, vb_0 ); + vc_05 = vfmaq_f64( vc_05, va_0, vb_1 ); + vc_06 = vfmaq_f64( vc_06, va_0, vb_2 ); + vc_07 = vfmaq_f64( vc_07, va_0, vb_3 ); + } + if ( n0 > 4 ) vst1q_f64( c_loc + 0 + 4 * cs_c, vc_04 ); + if ( n0 > 5 ) vst1q_f64( c_loc + 0 + 5 * cs_c, vc_05 ); + if ( n0 > 6 ) vst1q_f64( c_loc + 0 + 6 * cs_c, vc_06 ); + if ( n0 > 7 ) vst1q_f64( c_loc + 0 + 7 * cs_c, vc_07 ); + } + else + { + // if ( n0 > 0 ) + vb_0 = vld1q_lane_f64( c_loc + 0 + 0 * cs_c, vb_0, 0 ); + if ( n0 > 1 ) vb_1 = vld1q_lane_f64( c_loc + 0 + 1 * cs_c, vb_1, 0 ); + if ( n0 > 2 ) vb_2 = vld1q_lane_f64( c_loc + 0 + 2 * cs_c, vb_2, 0 ); + if ( n0 > 3 ) vb_3 = vld1q_lane_f64( c_loc + 0 + 3 * cs_c, vb_3, 0 ); + if ( !b_iszr ) + { + vc_00 = vfmaq_f64( vc_00, va_0, vb_0 ); + vc_01 = vfmaq_f64( vc_01, va_0, vb_1 ); + vc_02 = vfmaq_f64( vc_02, va_0, vb_2 ); + vc_03 = vfmaq_f64( vc_03, va_0, vb_3 ); + } + vst1q_lane_f64( c_loc + 0 + 0 * cs_c, vc_00, 0 ); + if ( n0 > 1 ) vst1q_lane_f64( c_loc + 0 + 1 * cs_c, vc_01, 0 ); + if ( n0 > 2 ) vst1q_lane_f64( c_loc + 0 + 2 * cs_c, vc_02, 0 ); + if ( n0 > 3 ) vst1q_lane_f64( c_loc + 0 + 3 * cs_c, vc_03, 0 ); + + if ( n0 > 4 ) vb_0 = vld1q_lane_f64( c_loc + 0 + 4 * cs_c, vb_0, 0 ); + if ( n0 > 5 ) vb_1 = vld1q_lane_f64( c_loc + 0 + 5 * cs_c, vb_1, 0 ); + if ( n0 > 6 ) vb_2 = vld1q_lane_f64( c_loc + 0 + 6 * cs_c, vb_2, 0 ); + if ( n0 > 7 ) vb_3 = vld1q_lane_f64( c_loc + 0 + 7 * cs_c, vb_3, 0 ); + if ( !b_iszr ) + { + vc_04 = vfmaq_f64( vc_04, va_0, vb_0 ); + vc_05 = vfmaq_f64( vc_05, va_0, vb_1 ); + vc_06 = vfmaq_f64( vc_06, va_0, vb_2 ); + vc_07 = vfmaq_f64( vc_07, va_0, vb_3 ); + } + if ( n0 > 4 ) vst1q_lane_f64( c_loc + 0 + 4 * cs_c, vc_04, 0 ); + if ( n0 > 5 ) vst1q_lane_f64( c_loc + 0 + 5 * cs_c, vc_05, 0 ); + if ( n0 > 6 ) vst1q_lane_f64( c_loc + 0 + 6 * cs_c, vc_06, 0 ); + if ( n0 > 7 ) vst1q_lane_f64( c_loc + 0 + 7 * cs_c, vc_07, 0 ); + } + } + +} diff --git a/kernels/armv8a/3/sup/d3x4/bli_gemmsup_rd_armv8a_int_d3x4.c b/kernels/armv8a/3/sup/d3x4/bli_gemmsup_rd_armv8a_int_d3x4.c new file mode 100644 index 0000000000..73e5f20fb7 --- /dev/null +++ b/kernels/armv8a/3/sup/d3x4/bli_gemmsup_rd_armv8a_int_d3x4.c @@ -0,0 +1,341 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2021, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ + +// Supplimentary dynamic-size gemmsup. + +#include "blis.h" +#include "assert.h" +#include + +#if defined(__clang__) +#define PRAGMA_NOUNROLL _Pragma("nounroll") +#define PRAGMA_UNROLL _Pragma("unroll") +#elif defined(__GNUC__) +#define PRAGMA_NOUNROLL _Pragma("GCC unroll 1") +#define PRAGMA_UNROLL _Pragma("GCC unroll 2") +#else +#define PRAGMA_NOUNROLL +#define PRAGMA_UNROLL +#endif + +/* + * As these kernels requires num. of vregs about half of the total 32, + * it should be all right to implement w/ intrinsics. + * + * c.f. https://www.youtube.com/watch?v=R2hQOVjRwVE . + */ +void bli_dgemmsup_rd_armv8a_int_3x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a, inc_t cs_a, + double* restrict b, inc_t rs_b, inc_t cs_b, + double* restrict beta, + double* restrict c, inc_t rs_c, inc_t cs_c, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + // if ( m0 == 3 && n0 == 4 ) + // { + // // Use fixed-size version if it is full 3x4. + // bli_dgemmsup_rd_armv8a_asm_3x4 + // ( + // conja, conjb, m0, n0, k0, + // alpha, a, rs_a, cs_a, b, rs_b, cs_b, + // beta, c, rs_c, cs_c, data, cntx + // ); + // return; + // } + + assert( m0 <= 3 ); + assert( n0 <= 4 ); + + double *a_loc = a; + double *b_loc = b; + double *c_loc = c; + + uint64_t k_mker = k0 / 2; + uint64_t k_left = k0 % 2; + uint64_t b_iszr = ( *beta == 0.0 ); + + assert( cs_a == 1 ); + assert( rs_b == 1 ); + + // Registers used to store a 3x4x2 block of C (summing the last dimension). + float64x2_t vc_00, vc_01, vc_02, vc_03; + float64x2_t vc_10, vc_11, vc_12, vc_13; + float64x2_t vc_20, vc_21, vc_22, vc_23; + float64x2_t va_0, va_1, va_2; + float64x2_t vb_0, vb_1, vb_2, vb_3; + + vc_00 = (float64x2_t)vdupq_n_f64( 0 ); + vc_01 = (float64x2_t)vdupq_n_f64( 0 ); + vc_02 = (float64x2_t)vdupq_n_f64( 0 ); + vc_03 = (float64x2_t)vdupq_n_f64( 0 ); + vc_10 = (float64x2_t)vdupq_n_f64( 0 ); + vc_11 = (float64x2_t)vdupq_n_f64( 0 ); + vc_12 = (float64x2_t)vdupq_n_f64( 0 ); + vc_13 = (float64x2_t)vdupq_n_f64( 0 ); + vc_20 = (float64x2_t)vdupq_n_f64( 0 ); + vc_21 = (float64x2_t)vdupq_n_f64( 0 ); + vc_22 = (float64x2_t)vdupq_n_f64( 0 ); + vc_23 = (float64x2_t)vdupq_n_f64( 0 ); + + PRAGMA_UNROLL + for ( ; k_mker > 0; --k_mker ) + { + // if ( m0 > 0 ) + va_0 = vld1q_f64( a_loc + rs_a * 0 ); + if ( m0 > 1 ) va_1 = vld1q_f64( a_loc + rs_a * 1 ); + if ( m0 > 2 ) va_2 = vld1q_f64( a_loc + rs_a * 2 ); + // if ( n0 > 0 ) + vb_0 = vld1q_f64( b_loc + cs_b * 0 ); + if ( n0 > 1 ) vb_1 = vld1q_f64( b_loc + cs_b * 1 ); + if ( n0 > 2 ) vb_2 = vld1q_f64( b_loc + cs_b * 2 ); + if ( n0 > 3 ) vb_3 = vld1q_f64( b_loc + cs_b * 3 ); + a_loc += 2; + b_loc += 2; + + // 1-column case. + if ( n0 == 1 ) { + vc_00 = vfmaq_f64( vc_00, va_0, vb_0 ); + vc_10 = vfmaq_f64( vc_10, va_1, vb_0 ); + vc_20 = vfmaq_f64( vc_20, va_2, vb_0 ); + continue; + } + + vc_00 = vfmaq_f64( vc_00, va_0, vb_0 ); + vc_01 = vfmaq_f64( vc_01, va_0, vb_1 ); + vc_02 = vfmaq_f64( vc_02, va_0, vb_2 ); + vc_03 = vfmaq_f64( vc_03, va_0, vb_3 ); + if ( m0 > 1 ) + { + vc_10 = vfmaq_f64( vc_10, va_1, vb_0 ); + vc_11 = vfmaq_f64( vc_11, va_1, vb_1 ); + vc_12 = vfmaq_f64( vc_12, va_1, vb_2 ); + vc_13 = vfmaq_f64( vc_13, va_1, vb_3 ); + } + if ( m0 > 2 ) { + vc_20 = vfmaq_f64( vc_20, va_2, vb_0 ); + vc_21 = vfmaq_f64( vc_21, va_2, vb_1 ); + vc_22 = vfmaq_f64( vc_22, va_2, vb_2 ); + vc_23 = vfmaq_f64( vc_23, va_2, vb_3 ); + } + } + + // Pay no care for O(1) details. + va_0 = (float64x2_t)vdupq_n_f64( 0 ); + va_1 = (float64x2_t)vdupq_n_f64( 0 ); + va_2 = (float64x2_t)vdupq_n_f64( 0 ); + vb_0 = (float64x2_t)vdupq_n_f64( 0 ); + vb_1 = (float64x2_t)vdupq_n_f64( 0 ); + vb_2 = (float64x2_t)vdupq_n_f64( 0 ); + vb_3 = (float64x2_t)vdupq_n_f64( 0 ); + PRAGMA_NOUNROLL + for ( ; k_left > 0; --k_left ) + { + // if ( m0 > 0 ) + va_0 = vld1q_lane_f64( a_loc + rs_a * 0, va_0, 0 ); + if ( m0 > 1 ) va_1 = vld1q_lane_f64( a_loc + rs_a * 1, va_1, 0 ); + if ( m0 > 2 ) va_2 = vld1q_lane_f64( a_loc + rs_a * 2, va_2, 0 ); + // if ( n0 > 0 ) + vb_0 = vld1q_lane_f64( b_loc + cs_b * 0, vb_0, 0 ); + if ( n0 > 1 ) vb_1 = vld1q_lane_f64( b_loc + cs_b * 1, vb_1, 0 ); + if ( n0 > 2 ) vb_2 = vld1q_lane_f64( b_loc + cs_b * 2, vb_2, 0 ); + if ( n0 > 3 ) vb_3 = vld1q_lane_f64( b_loc + cs_b * 3, vb_3, 0 ); + + vc_00 = vfmaq_f64( vc_00, va_0, vb_0 ); + vc_01 = vfmaq_f64( vc_01, va_0, vb_1 ); + vc_02 = vfmaq_f64( vc_02, va_0, vb_2 ); + vc_03 = vfmaq_f64( vc_03, va_0, vb_3 ); + vc_10 = vfmaq_f64( vc_10, va_1, vb_0 ); + vc_11 = vfmaq_f64( vc_11, va_1, vb_1 ); + vc_12 = vfmaq_f64( vc_12, va_1, vb_2 ); + vc_13 = vfmaq_f64( vc_13, va_1, vb_3 ); + vc_20 = vfmaq_f64( vc_20, va_2, vb_0 ); + vc_21 = vfmaq_f64( vc_21, va_2, vb_1 ); + vc_22 = vfmaq_f64( vc_22, va_2, vb_2 ); + vc_23 = vfmaq_f64( vc_23, va_2, vb_3 ); + + a_loc += 1; + b_loc += 1; + } + + // Reduce. + vc_00 = vpaddq_f64( vc_00, vc_01 ); + vc_02 = vpaddq_f64( vc_02, vc_03 ); + vc_10 = vpaddq_f64( vc_10, vc_11 ); + vc_12 = vpaddq_f64( vc_12, vc_13 ); + vc_20 = vpaddq_f64( vc_20, vc_21 ); + vc_22 = vpaddq_f64( vc_22, vc_23 ); + + // Load alpha and beta. + va_0 = vld1q_dup_f64( alpha ); + vb_0 = vld1q_dup_f64( beta ); + + // Scale. + vc_00 = vmulq_f64( vc_00, va_0 ); + vc_02 = vmulq_f64( vc_02, va_0 ); + vc_10 = vmulq_f64( vc_10, va_0 ); + vc_12 = vmulq_f64( vc_12, va_0 ); + vc_20 = vmulq_f64( vc_20, va_0 ); + vc_22 = vmulq_f64( vc_22, va_0 ); + + if ( cs_c == 1 ) + { + // Row-storage. + // if ( m0 > 0 ) + { + if ( n0 > 1 ) va_0 = vld1q_f64 ( c_loc + 0 * rs_c + 0 ); + else if ( n0 > 0 ) va_0 = vld1q_lane_f64( c_loc + 0 * rs_c + 0, va_0, 0 ); + if ( n0 > 3 ) va_1 = vld1q_f64 ( c_loc + 0 * rs_c + 2 ); + else if ( n0 > 2 ) va_1 = vld1q_lane_f64( c_loc + 0 * rs_c + 2, va_1, 0 ); + + if ( !b_iszr ) + { + vc_00 = vfmaq_f64( vc_00, va_0, vb_0 ); + vc_02 = vfmaq_f64( vc_02, va_1, vb_0 ); + } + + if ( n0 > 1 ) vst1q_f64 ( c_loc + 0 * rs_c + 0, vc_00 ); + else if ( n0 > 0 ) vst1q_lane_f64( c_loc + 0 * rs_c + 0, vc_00, 0 ); + if ( n0 > 3 ) vst1q_f64 ( c_loc + 0 * rs_c + 2, vc_02 ); + else if ( n0 > 2 ) vst1q_lane_f64( c_loc + 0 * rs_c + 2, vc_02, 0 ); + } + if ( m0 > 1 ) + { + if ( n0 > 1 ) va_0 = vld1q_f64 ( c_loc + 1 * rs_c + 0 ); + else if ( n0 > 0 ) va_0 = vld1q_lane_f64( c_loc + 1 * rs_c + 0, va_0, 0 ); + if ( n0 > 3 ) va_1 = vld1q_f64 ( c_loc + 1 * rs_c + 2 ); + else if ( n0 > 2 ) va_1 = vld1q_lane_f64( c_loc + 1 * rs_c + 2, va_1, 0 ); + + if ( !b_iszr ) + { + vc_10 = vfmaq_f64( vc_10, va_0, vb_0 ); + vc_12 = vfmaq_f64( vc_12, va_1, vb_0 ); + } + + if ( n0 > 1 ) vst1q_f64 ( c_loc + 1 * rs_c + 0, vc_10 ); + else if ( n0 > 0 ) vst1q_lane_f64( c_loc + 1 * rs_c + 0, vc_10, 0 ); + if ( n0 > 3 ) vst1q_f64 ( c_loc + 1 * rs_c + 2, vc_12 ); + else if ( n0 > 2 ) vst1q_lane_f64( c_loc + 1 * rs_c + 2, vc_12, 0 ); + } + if ( m0 > 2 ) + { + if ( n0 > 1 ) va_0 = vld1q_f64 ( c_loc + 2 * rs_c + 0 ); + else if ( n0 > 0 ) va_0 = vld1q_lane_f64( c_loc + 2 * rs_c + 0, va_0, 0 ); + if ( n0 > 3 ) va_1 = vld1q_f64 ( c_loc + 2 * rs_c + 2 ); + else if ( n0 > 2 ) va_1 = vld1q_lane_f64( c_loc + 2 * rs_c + 2, va_1, 0 ); + + if ( !b_iszr ) + { + vc_20 = vfmaq_f64( vc_20, va_0, vb_0 ); + vc_22 = vfmaq_f64( vc_22, va_1, vb_0 ); + } + + if ( n0 > 1 ) vst1q_f64 ( c_loc + 2 * rs_c + 0, vc_20 ); + else if ( n0 > 0 ) vst1q_lane_f64( c_loc + 2 * rs_c + 0, vc_20, 0 ); + if ( n0 > 3 ) vst1q_f64 ( c_loc + 2 * rs_c + 2, vc_22 ); + else if ( n0 > 2 ) vst1q_lane_f64( c_loc + 2 * rs_c + 2, vc_22, 0 ); + } + } + else + { + // Column-storage. + if ( m0 > 0 ) va_0 = vld1q_lane_f64( c_loc + 0 + 0 * cs_c, va_0, 0 ); + if ( m0 > 1 ) va_1 = vld1q_lane_f64( c_loc + 1 + 0 * cs_c, va_1, 0 ); + if ( m0 > 2 ) va_2 = vld1q_lane_f64( c_loc + 2 + 0 * cs_c, va_2, 0 ); + if ( n0 > 1 ) + { + if ( m0 > 0 ) va_0 = vld1q_lane_f64( c_loc + 0 + 1 * cs_c, va_0, 1 ); + if ( m0 > 1 ) va_1 = vld1q_lane_f64( c_loc + 1 + 1 * cs_c, va_1, 1 ); + if ( m0 > 2 ) va_2 = vld1q_lane_f64( c_loc + 2 + 1 * cs_c, va_2, 1 ); + } + if ( !b_iszr ) + { + vc_00 = vfmaq_f64( vc_00, va_0, vb_0 ); + vc_10 = vfmaq_f64( vc_10, va_1, vb_0 ); + vc_20 = vfmaq_f64( vc_20, va_2, vb_0 ); + } + if ( m0 > 0 ) vst1q_lane_f64( c_loc + 0 + 0 * cs_c, vc_00, 0 ); + if ( m0 > 1 ) vst1q_lane_f64( c_loc + 1 + 0 * cs_c, vc_10, 0 ); + if ( m0 > 2 ) vst1q_lane_f64( c_loc + 2 + 0 * cs_c, vc_20, 0 ); + if ( n0 > 1 ) + { + if ( m0 > 0 ) vst1q_lane_f64( c_loc + 0 + 1 * cs_c, vc_00, 1 ); + if ( m0 > 1 ) vst1q_lane_f64( c_loc + 1 + 1 * cs_c, vc_10, 1 ); + if ( m0 > 2 ) vst1q_lane_f64( c_loc + 2 + 1 * cs_c, vc_20, 1 ); + } + + if ( n0 > 2 ) + { + if ( m0 > 0 ) va_0 = vld1q_lane_f64( c_loc + 0 + 2 * cs_c, va_0, 0 ); + if ( m0 > 1 ) va_1 = vld1q_lane_f64( c_loc + 1 + 2 * cs_c, va_1, 0 ); + if ( m0 > 2 ) va_2 = vld1q_lane_f64( c_loc + 2 + 2 * cs_c, va_2, 0 ); + } + if ( n0 > 3 ) + { + if ( m0 > 0 ) va_0 = vld1q_lane_f64( c_loc + 0 + 3 * cs_c, va_0, 1 ); + if ( m0 > 1 ) va_1 = vld1q_lane_f64( c_loc + 1 + 3 * cs_c, va_1, 1 ); + if ( m0 > 2 ) va_2 = vld1q_lane_f64( c_loc + 2 + 3 * cs_c, va_2, 1 ); + } + if ( !b_iszr ) + { + vc_02 = vfmaq_f64( vc_02, va_0, vb_0 ); + vc_12 = vfmaq_f64( vc_12, va_1, vb_0 ); + vc_22 = vfmaq_f64( vc_22, va_2, vb_0 ); + } + if ( n0 > 2 ) + { + if ( m0 > 0 ) vst1q_lane_f64( c_loc + 0 + 2 * cs_c, vc_02, 0 ); + if ( m0 > 1 ) vst1q_lane_f64( c_loc + 1 + 2 * cs_c, vc_12, 0 ); + if ( m0 > 2 ) vst1q_lane_f64( c_loc + 2 + 2 * cs_c, vc_22, 0 ); + } + if ( n0 > 3 ) + { + if ( m0 > 0 ) vst1q_lane_f64( c_loc + 0 + 3 * cs_c, vc_02, 1 ); + if ( m0 > 1 ) vst1q_lane_f64( c_loc + 1 + 3 * cs_c, vc_12, 1 ); + if ( m0 > 2 ) vst1q_lane_f64( c_loc + 2 + 3 * cs_c, vc_22, 1 ); + } + } + +} + diff --git a/kernels/armv8a/3/sup/d6x4/bli_gemmsup_rv_armv8a_int_d3x8mn.c b/kernels/armv8a/3/sup/d6x4/bli_gemmsup_rv_armv8a_int_d3x8mn.c new file mode 100644 index 0000000000..16af42ade6 --- /dev/null +++ b/kernels/armv8a/3/sup/d6x4/bli_gemmsup_rv_armv8a_int_d3x8mn.c @@ -0,0 +1,393 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2021, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ + +// Supplimentary dynamic-size gemmsup. + +#include "blis.h" +#include "assert.h" +#include + +#if defined(__clang__) +#define PRAGMA_NOUNROLL _Pragma("nounroll") +#define PRAGMA_UNROLL _Pragma("unroll") +#elif defined(__GNUC__) +#define PRAGMA_NOUNROLL _Pragma("GCC unroll 1") +#define PRAGMA_UNROLL _Pragma("GCC unroll 2") +#else +#define PRAGMA_NOUNROLL +#define PRAGMA_UNROLL +#endif + +/* + * As these kernels requires num. of vregs about half of the total 32, + * it should be all right to implement w/ intrinsics. + * + * c.f. https://www.youtube.com/watch?v=R2hQOVjRwVE . + */ +void bli_dgemmsup_rv_armv8a_int_3x8mn + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a0, inc_t rs_a, inc_t cs_a, + double* restrict b0, inc_t rs_b, inc_t cs_b, + double* restrict beta, + double* restrict c0, inc_t rs_c, inc_t cs_c, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + // Unlike the rd case, this rv case does not impose restriction upon + // maximal m & n. + + double *a_loc; + double *b_loc, *b_in; + double *c_loc, *c_in; + + dim_t n; + dim_t k; + uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_b = bli_auxinfo_ps_b( data ); + uint64_t b_iszr = ( *beta == 0.0 ); + assert( cs_b == 1 ); + + // Registers used to store a 3x8 block of C. + float64x2_t vc_00, vc_01, vc_02, vc_03; + float64x2_t vc_10, vc_11, vc_12, vc_13; + float64x2_t vc_20, vc_21, vc_22, vc_23; + float64x2_t va_0, va_1; + float64x2_t vb_0, vb_1, vb_2, vb_3; + + PRAGMA_NOUNROLL + for ( ; m0 > 0; m0 -= 3 ) + { + n = n0; + b_in = b0; + c_in = c0; + + PRAGMA_NOUNROLL + for ( ; n > 0; n -= 8 ) + { + a_loc = a0; + b_loc = b_in; + c_loc = c_in; + k = k0; + + vc_00 = (float64x2_t)vdupq_n_f64( 0 ); + vc_01 = (float64x2_t)vdupq_n_f64( 0 ); + vc_02 = (float64x2_t)vdupq_n_f64( 0 ); + vc_03 = (float64x2_t)vdupq_n_f64( 0 ); + vc_10 = (float64x2_t)vdupq_n_f64( 0 ); + vc_11 = (float64x2_t)vdupq_n_f64( 0 ); + vc_12 = (float64x2_t)vdupq_n_f64( 0 ); + vc_13 = (float64x2_t)vdupq_n_f64( 0 ); + vc_20 = (float64x2_t)vdupq_n_f64( 0 ); + vc_21 = (float64x2_t)vdupq_n_f64( 0 ); + vc_22 = (float64x2_t)vdupq_n_f64( 0 ); + vc_23 = (float64x2_t)vdupq_n_f64( 0 ); + + PRAGMA_UNROLL + for ( ; k > 0; --k ) + { + // A columns. + // if ( m0 > 0 ) + va_0 = vld1q_lane_f64( a_loc + rs_a * 0, va_0, 0 ); + if ( m0 > 1 ) va_0 = vld1q_lane_f64( a_loc + rs_a * 1, va_0, 1 ); + if ( m0 > 2 ) va_1 = vld1q_lane_f64( a_loc + rs_a * 2, va_1, 0 ); + // B rows. + if ( n > 1 ) vb_0 = vld1q_f64 ( b_loc + 0 ); + else vb_0 = vld1q_lane_f64( b_loc + 0, vb_0, 0 ); + if ( n > 3 ) vb_1 = vld1q_f64 ( b_loc + 2 ); + else if ( n > 2 ) vb_1 = vld1q_lane_f64( b_loc + 2, vb_1, 0 ); + if ( n > 5 ) vb_2 = vld1q_f64 ( b_loc + 4 ); + else if ( n > 4 ) vb_2 = vld1q_lane_f64( b_loc + 4, vb_2, 0 ); + if ( n > 7 ) vb_3 = vld1q_f64 ( b_loc + 6 ); + else if ( n > 6 ) vb_3 = vld1q_lane_f64( b_loc + 6, vb_3, 0 ); + a_loc += cs_a; + b_loc += rs_b; + + // if ( m0 > 0 ) + { + vc_00 = vfmaq_laneq_f64( vc_00, vb_0, va_0, 0 ); + vc_01 = vfmaq_laneq_f64( vc_01, vb_1, va_0, 0 ); + vc_02 = vfmaq_laneq_f64( vc_02, vb_2, va_0, 0 ); + vc_03 = vfmaq_laneq_f64( vc_03, vb_3, va_0, 0 ); + } + if ( m0 > 1 ) + { + vc_10 = vfmaq_laneq_f64( vc_10, vb_0, va_0, 1 ); + vc_11 = vfmaq_laneq_f64( vc_11, vb_1, va_0, 1 ); + vc_12 = vfmaq_laneq_f64( vc_12, vb_2, va_0, 1 ); + vc_13 = vfmaq_laneq_f64( vc_13, vb_3, va_0, 1 ); + } + if ( m0 > 2 ) + { + vc_20 = vfmaq_laneq_f64( vc_20, vb_0, va_1, 0 ); + vc_21 = vfmaq_laneq_f64( vc_21, vb_1, va_1, 0 ); + vc_22 = vfmaq_laneq_f64( vc_22, vb_2, va_1, 0 ); + vc_23 = vfmaq_laneq_f64( vc_23, vb_3, va_1, 0 ); + } + } + + // Load alpha and beta. + // Note that here vb is used for alpha, in contrast to other kernels. + vb_0 = vld1q_dup_f64( alpha ); + va_0 = vld1q_dup_f64( beta ); + + // Scale. + vc_00 = vmulq_f64( vc_00, vb_0 ); + vc_01 = vmulq_f64( vc_01, vb_0 ); + vc_02 = vmulq_f64( vc_02, vb_0 ); + vc_03 = vmulq_f64( vc_03, vb_0 ); + vc_10 = vmulq_f64( vc_10, vb_0 ); + vc_11 = vmulq_f64( vc_11, vb_0 ); + vc_12 = vmulq_f64( vc_12, vb_0 ); + vc_13 = vmulq_f64( vc_13, vb_0 ); + vc_20 = vmulq_f64( vc_20, vb_0 ); + vc_21 = vmulq_f64( vc_21, vb_0 ); + vc_22 = vmulq_f64( vc_22, vb_0 ); + vc_23 = vmulq_f64( vc_23, vb_0 ); + + if ( cs_c == 1 ) + { + // Store in rows. + // + // if ( m0 > 0 ) + { + // Load. + if ( n > 1 ) vb_0 = vld1q_f64 ( c_loc + 0 * rs_c + 0 ); + else vb_0 = vld1q_lane_f64( c_loc + 0 * rs_c + 0, vb_0, 0 ); + if ( n > 3 ) vb_1 = vld1q_f64 ( c_loc + 0 * rs_c + 2 ); + else if ( n > 2 ) vb_1 = vld1q_lane_f64( c_loc + 0 * rs_c + 2, vb_1, 0 ); + if ( n > 5 ) vb_2 = vld1q_f64 ( c_loc + 0 * rs_c + 4 ); + else if ( n > 4 ) vb_2 = vld1q_lane_f64( c_loc + 0 * rs_c + 4, vb_2, 0 ); + if ( n > 7 ) vb_3 = vld1q_f64 ( c_loc + 0 * rs_c + 6 ); + else if ( n > 6 ) vb_3 = vld1q_lane_f64( c_loc + 0 * rs_c + 6, vb_3, 0 ); + + // Scale. + if ( !b_iszr ) + { + vc_00 = vfmaq_f64( vc_00, vb_0, va_0 ); + vc_01 = vfmaq_f64( vc_01, vb_1, va_0 ); + vc_02 = vfmaq_f64( vc_02, vb_2, va_0 ); + vc_03 = vfmaq_f64( vc_03, vb_3, va_0 ); + } + + // Store. + if ( n > 1 ) vst1q_f64 ( c_loc + 0 * rs_c + 0, vc_00 ); + else vst1q_lane_f64( c_loc + 0 * rs_c + 0, vc_00, 0 ); + if ( n > 3 ) vst1q_f64 ( c_loc + 0 * rs_c + 2, vc_01 ); + else if ( n > 2 ) vst1q_lane_f64( c_loc + 0 * rs_c + 2, vc_01, 0 ); + if ( n > 5 ) vst1q_f64 ( c_loc + 0 * rs_c + 4, vc_02 ); + else if ( n > 4 ) vst1q_lane_f64( c_loc + 0 * rs_c + 4, vc_02, 0 ); + if ( n > 7 ) vst1q_f64 ( c_loc + 0 * rs_c + 6, vc_03 ); + else if ( n > 6 ) vst1q_lane_f64( c_loc + 0 * rs_c + 6, vc_03, 0 ); + } + if ( m0 > 1 ) + { + // Load. + if ( n > 1 ) vb_0 = vld1q_f64 ( c_loc + 1 * rs_c + 0 ); + else vb_0 = vld1q_lane_f64( c_loc + 1 * rs_c + 0, vb_0, 0 ); + if ( n > 3 ) vb_1 = vld1q_f64 ( c_loc + 1 * rs_c + 2 ); + else if ( n > 2 ) vb_1 = vld1q_lane_f64( c_loc + 1 * rs_c + 2, vb_1, 0 ); + if ( n > 5 ) vb_2 = vld1q_f64 ( c_loc + 1 * rs_c + 4 ); + else if ( n > 4 ) vb_2 = vld1q_lane_f64( c_loc + 1 * rs_c + 4, vb_2, 0 ); + if ( n > 7 ) vb_3 = vld1q_f64 ( c_loc + 1 * rs_c + 6 ); + else if ( n > 6 ) vb_3 = vld1q_lane_f64( c_loc + 1 * rs_c + 6, vb_3, 0 ); + + // Scale. + if ( !b_iszr ) + { + vc_10 = vfmaq_f64( vc_10, vb_0, va_0 ); + vc_11 = vfmaq_f64( vc_11, vb_1, va_0 ); + vc_12 = vfmaq_f64( vc_12, vb_2, va_0 ); + vc_13 = vfmaq_f64( vc_13, vb_3, va_0 ); + } + + // Store. + if ( n > 1 ) vst1q_f64 ( c_loc + 1 * rs_c + 0, vc_10 ); + else vst1q_lane_f64( c_loc + 1 * rs_c + 0, vc_10, 0 ); + if ( n > 3 ) vst1q_f64 ( c_loc + 1 * rs_c + 2, vc_11 ); + else if ( n > 2 ) vst1q_lane_f64( c_loc + 1 * rs_c + 2, vc_11, 0 ); + if ( n > 5 ) vst1q_f64 ( c_loc + 1 * rs_c + 4, vc_12 ); + else if ( n > 4 ) vst1q_lane_f64( c_loc + 1 * rs_c + 4, vc_12, 0 ); + if ( n > 7 ) vst1q_f64 ( c_loc + 1 * rs_c + 6, vc_13 ); + else if ( n > 6 ) vst1q_lane_f64( c_loc + 1 * rs_c + 6, vc_13, 0 ); + } + if ( m0 > 2 ) + { + // Load. + if ( n > 1 ) vb_0 = vld1q_f64 ( c_loc + 2 * rs_c + 0 ); + else vb_0 = vld1q_lane_f64( c_loc + 2 * rs_c + 0, vb_0, 0 ); + if ( n > 3 ) vb_1 = vld1q_f64 ( c_loc + 2 * rs_c + 2 ); + else if ( n > 2 ) vb_1 = vld1q_lane_f64( c_loc + 2 * rs_c + 2, vb_1, 0 ); + if ( n > 5 ) vb_2 = vld1q_f64 ( c_loc + 2 * rs_c + 4 ); + else if ( n > 4 ) vb_2 = vld1q_lane_f64( c_loc + 2 * rs_c + 4, vb_2, 0 ); + if ( n > 7 ) vb_3 = vld1q_f64 ( c_loc + 2 * rs_c + 6 ); + else if ( n > 6 ) vb_3 = vld1q_lane_f64( c_loc + 2 * rs_c + 6, vb_3, 0 ); + + // Scale. + if ( !b_iszr ) + { + vc_20 = vfmaq_f64( vc_20, vb_0, va_0 ); + vc_21 = vfmaq_f64( vc_21, vb_1, va_0 ); + vc_22 = vfmaq_f64( vc_22, vb_2, va_0 ); + vc_23 = vfmaq_f64( vc_23, vb_3, va_0 ); + } + + // Store. + if ( n > 1 ) vst1q_f64 ( c_loc + 2 * rs_c + 0, vc_20 ); + else vst1q_lane_f64( c_loc + 2 * rs_c + 0, vc_20, 0 ); + if ( n > 3 ) vst1q_f64 ( c_loc + 2 * rs_c + 2, vc_21 ); + else if ( n > 2 ) vst1q_lane_f64( c_loc + 2 * rs_c + 2, vc_21, 0 ); + if ( n > 5 ) vst1q_f64 ( c_loc + 2 * rs_c + 4, vc_22 ); + else if ( n > 4 ) vst1q_lane_f64( c_loc + 2 * rs_c + 4, vc_22, 0 ); + if ( n > 7 ) vst1q_f64 ( c_loc + 2 * rs_c + 6, vc_23 ); + else if ( n > 6 ) vst1q_lane_f64( c_loc + 2 * rs_c + 6, vc_23, 0 ); + } + } + else + { + // Store in columns. + // No in-reg transpose here. + // + // if ( m0 > 0 ) + { + // Load. + if ( n > 0 ) vb_0 = vld1q_lane_f64( c_loc + 0 + 0 * cs_c, vb_0, 0 ); + if ( n > 1 ) vb_0 = vld1q_lane_f64( c_loc + 0 + 1 * cs_c, vb_0, 1 ); + if ( n > 2 ) vb_1 = vld1q_lane_f64( c_loc + 0 + 2 * cs_c, vb_1, 0 ); + if ( n > 3 ) vb_1 = vld1q_lane_f64( c_loc + 0 + 3 * cs_c, vb_1, 1 ); + if ( n > 4 ) vb_2 = vld1q_lane_f64( c_loc + 0 + 4 * cs_c, vb_2, 0 ); + if ( n > 5 ) vb_2 = vld1q_lane_f64( c_loc + 0 + 5 * cs_c, vb_2, 1 ); + if ( n > 6 ) vb_3 = vld1q_lane_f64( c_loc + 0 + 6 * cs_c, vb_3, 0 ); + if ( n > 7 ) vb_3 = vld1q_lane_f64( c_loc + 0 + 7 * cs_c, vb_3, 1 ); + + // Scale. + if ( !b_iszr ) + { + vc_00 = vfmaq_f64( vc_00, vb_0, va_0 ); + vc_01 = vfmaq_f64( vc_01, vb_1, va_0 ); + vc_02 = vfmaq_f64( vc_02, vb_2, va_0 ); + vc_03 = vfmaq_f64( vc_03, vb_3, va_0 ); + } + + // Store. + if ( n > 0 ) vst1q_lane_f64( c_loc + 0 + 0 * cs_c, vc_00, 0 ); + if ( n > 1 ) vst1q_lane_f64( c_loc + 0 + 1 * cs_c, vc_00, 1 ); + if ( n > 2 ) vst1q_lane_f64( c_loc + 0 + 2 * cs_c, vc_01, 0 ); + if ( n > 3 ) vst1q_lane_f64( c_loc + 0 + 3 * cs_c, vc_01, 1 ); + if ( n > 4 ) vst1q_lane_f64( c_loc + 0 + 4 * cs_c, vc_02, 0 ); + if ( n > 5 ) vst1q_lane_f64( c_loc + 0 + 5 * cs_c, vc_02, 1 ); + if ( n > 6 ) vst1q_lane_f64( c_loc + 0 + 6 * cs_c, vc_03, 0 ); + if ( n > 7 ) vst1q_lane_f64( c_loc + 0 + 7 * cs_c, vc_03, 1 ); + } + if ( m0 > 1 ) + { + // Load. + if ( n > 0 ) vb_0 = vld1q_lane_f64( c_loc + 1 + 0 * cs_c, vb_0, 0 ); + if ( n > 1 ) vb_0 = vld1q_lane_f64( c_loc + 1 + 1 * cs_c, vb_0, 1 ); + if ( n > 2 ) vb_1 = vld1q_lane_f64( c_loc + 1 + 2 * cs_c, vb_1, 0 ); + if ( n > 3 ) vb_1 = vld1q_lane_f64( c_loc + 1 + 3 * cs_c, vb_1, 1 ); + if ( n > 4 ) vb_2 = vld1q_lane_f64( c_loc + 1 + 4 * cs_c, vb_2, 0 ); + if ( n > 5 ) vb_2 = vld1q_lane_f64( c_loc + 1 + 5 * cs_c, vb_2, 1 ); + if ( n > 6 ) vb_3 = vld1q_lane_f64( c_loc + 1 + 6 * cs_c, vb_3, 0 ); + if ( n > 7 ) vb_3 = vld1q_lane_f64( c_loc + 1 + 7 * cs_c, vb_3, 1 ); + + // Scale. + if ( !b_iszr ) + { + vc_10 = vfmaq_f64( vc_10, vb_0, va_0 ); + vc_11 = vfmaq_f64( vc_11, vb_1, va_0 ); + vc_12 = vfmaq_f64( vc_12, vb_2, va_0 ); + vc_13 = vfmaq_f64( vc_13, vb_3, va_0 ); + } + + // Store. + if ( n > 0 ) vst1q_lane_f64( c_loc + 1 + 0 * cs_c, vc_10, 0 ); + if ( n > 1 ) vst1q_lane_f64( c_loc + 1 + 1 * cs_c, vc_10, 1 ); + if ( n > 2 ) vst1q_lane_f64( c_loc + 1 + 2 * cs_c, vc_11, 0 ); + if ( n > 3 ) vst1q_lane_f64( c_loc + 1 + 3 * cs_c, vc_11, 1 ); + if ( n > 4 ) vst1q_lane_f64( c_loc + 1 + 4 * cs_c, vc_12, 0 ); + if ( n > 5 ) vst1q_lane_f64( c_loc + 1 + 5 * cs_c, vc_12, 1 ); + if ( n > 6 ) vst1q_lane_f64( c_loc + 1 + 6 * cs_c, vc_13, 0 ); + if ( n > 7 ) vst1q_lane_f64( c_loc + 1 + 7 * cs_c, vc_13, 1 ); + } + if ( m0 > 2 ) + { + // Load. + if ( n > 0 ) vb_0 = vld1q_lane_f64( c_loc + 2 + 0 * cs_c, vb_0, 0 ); + if ( n > 1 ) vb_0 = vld1q_lane_f64( c_loc + 2 + 1 * cs_c, vb_0, 1 ); + if ( n > 2 ) vb_1 = vld1q_lane_f64( c_loc + 2 + 2 * cs_c, vb_1, 0 ); + if ( n > 3 ) vb_1 = vld1q_lane_f64( c_loc + 2 + 3 * cs_c, vb_1, 1 ); + if ( n > 4 ) vb_2 = vld1q_lane_f64( c_loc + 2 + 4 * cs_c, vb_2, 0 ); + if ( n > 5 ) vb_2 = vld1q_lane_f64( c_loc + 2 + 5 * cs_c, vb_2, 1 ); + if ( n > 6 ) vb_3 = vld1q_lane_f64( c_loc + 2 + 6 * cs_c, vb_3, 0 ); + if ( n > 7 ) vb_3 = vld1q_lane_f64( c_loc + 2 + 7 * cs_c, vb_3, 1 ); + + // Scale. + if ( !b_iszr ) + { + vc_20 = vfmaq_f64( vc_20, vb_0, va_0 ); + vc_21 = vfmaq_f64( vc_21, vb_1, va_0 ); + vc_22 = vfmaq_f64( vc_22, vb_2, va_0 ); + vc_23 = vfmaq_f64( vc_23, vb_3, va_0 ); + } + + // Store. + if ( n > 0 ) vst1q_lane_f64( c_loc + 2 + 0 * cs_c, vc_20, 0 ); + if ( n > 1 ) vst1q_lane_f64( c_loc + 2 + 1 * cs_c, vc_20, 1 ); + if ( n > 2 ) vst1q_lane_f64( c_loc + 2 + 2 * cs_c, vc_21, 0 ); + if ( n > 3 ) vst1q_lane_f64( c_loc + 2 + 3 * cs_c, vc_21, 1 ); + if ( n > 4 ) vst1q_lane_f64( c_loc + 2 + 4 * cs_c, vc_22, 0 ); + if ( n > 5 ) vst1q_lane_f64( c_loc + 2 + 5 * cs_c, vc_22, 1 ); + if ( n > 6 ) vst1q_lane_f64( c_loc + 2 + 6 * cs_c, vc_23, 0 ); + if ( n > 7 ) vst1q_lane_f64( c_loc + 2 + 7 * cs_c, vc_23, 1 ); + } + } + + b_in += ps_b; + c_in += 8 * cs_c; + } + + a0 += ps_a; + c0 += 3 * rs_c; + } +} + diff --git a/kernels/armv8a/3/sup/d6x4/bli_gemmsup_rv_armv8a_int_d6x4mn.c b/kernels/armv8a/3/sup/d6x4/bli_gemmsup_rv_armv8a_int_d6x4mn.c new file mode 100644 index 0000000000..8bbd87f1f6 --- /dev/null +++ b/kernels/armv8a/3/sup/d6x4/bli_gemmsup_rv_armv8a_int_d6x4mn.c @@ -0,0 +1,481 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2021, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ + +// Supplimentary dynamic-size gemmsup. + +#include "blis.h" +#include "assert.h" +#include + +#if defined(__clang__) +#define PRAGMA_NOUNROLL _Pragma("nounroll") +#define PRAGMA_UNROLL _Pragma("unroll") +#elif defined(__GNUC__) +#define PRAGMA_NOUNROLL _Pragma("GCC unroll 1") +#define PRAGMA_UNROLL _Pragma("GCC unroll 2") +#else +#define PRAGMA_NOUNROLL +#define PRAGMA_UNROLL +#endif + +/* + * As these kernels requires num. of vregs about half of the total 32, + * it should be all right to implement w/ intrinsics. + * + * c.f. https://www.youtube.com/watch?v=R2hQOVjRwVE . + */ +void bli_dgemmsup_rv_armv8a_int_6x4mn + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a0, inc_t rs_a, inc_t cs_a, + double* restrict b0, inc_t rs_b, inc_t cs_b, + double* restrict beta, + double* restrict c0, inc_t rs_c, inc_t cs_c, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + // Unlike the rd case, this rv case does not impose restriction upon + // maximal m & n. + + double *a_loc; + double *b_loc, *b_in; + double *c_loc, *c_in; + + dim_t n; + dim_t k; + uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_b = bli_auxinfo_ps_b( data ); + uint64_t b_iszr = ( *beta == 0.0 ); + assert( cs_b == 1 ); + + // Registers used to store a 6x4 block of C. + float64x2_t vc_00, vc_01; + float64x2_t vc_10, vc_11; + float64x2_t vc_20, vc_21; + float64x2_t vc_30, vc_31; + float64x2_t vc_40, vc_41; + float64x2_t vc_50, vc_51; + float64x2_t va_0, va_1, va_2; + float64x2_t vb_0, vb_1; + + PRAGMA_NOUNROLL + for ( ; m0 > 0; m0 -= 6 ) + { + n = n0; + b_in = b0; + c_in = c0; + + PRAGMA_NOUNROLL + for ( ; n > 0; n -= 4 ) + { + a_loc = a0; + b_loc = b_in; + c_loc = c_in; + k = k0; + + vc_00 = (float64x2_t)vdupq_n_f64( 0 ); vc_01 = (float64x2_t)vdupq_n_f64( 0 ); + vc_10 = (float64x2_t)vdupq_n_f64( 0 ); vc_11 = (float64x2_t)vdupq_n_f64( 0 ); + vc_20 = (float64x2_t)vdupq_n_f64( 0 ); vc_21 = (float64x2_t)vdupq_n_f64( 0 ); + vc_30 = (float64x2_t)vdupq_n_f64( 0 ); vc_31 = (float64x2_t)vdupq_n_f64( 0 ); + vc_40 = (float64x2_t)vdupq_n_f64( 0 ); vc_41 = (float64x2_t)vdupq_n_f64( 0 ); + vc_50 = (float64x2_t)vdupq_n_f64( 0 ); vc_51 = (float64x2_t)vdupq_n_f64( 0 ); + + PRAGMA_UNROLL + for ( ; k > 0; --k ) + { + // A columns. + // if ( m0 > 0 ) + va_0 = vld1q_lane_f64( a_loc + rs_a * 0, va_0, 0 ); + if ( m0 > 1 ) va_0 = vld1q_lane_f64( a_loc + rs_a * 1, va_0, 1 ); + if ( m0 > 2 ) va_1 = vld1q_lane_f64( a_loc + rs_a * 2, va_1, 0 ); + if ( m0 > 3 ) va_1 = vld1q_lane_f64( a_loc + rs_a * 3, va_1, 1 ); + if ( m0 > 4 ) va_2 = vld1q_lane_f64( a_loc + rs_a * 4, va_2, 0 ); + if ( m0 > 5 ) va_2 = vld1q_lane_f64( a_loc + rs_a * 5, va_2, 1 ); + // B rows. + if ( n > 1 ) vb_0 = vld1q_f64 ( b_loc + 0 ); + else vb_0 = vld1q_lane_f64( b_loc + 0, vb_0, 0 ); + if ( n > 3 ) vb_1 = vld1q_f64 ( b_loc + 2 ); + else if ( n > 2 ) vb_1 = vld1q_lane_f64( b_loc + 2, vb_1, 0 ); + a_loc += cs_a; + b_loc += rs_b; + + // One or two-column case. + if ( n <= 2 ) + { + // if ( m0 > 0 ) + { + vc_00 = vfmaq_laneq_f64( vc_00, vb_0, va_0, 0 ); + vc_10 = vfmaq_laneq_f64( vc_10, vb_0, va_0, 1 ); + vc_20 = vfmaq_laneq_f64( vc_20, vb_0, va_1, 0 ); + } + if ( m0 > 3 ) + { + vc_30 = vfmaq_laneq_f64( vc_30, vb_0, va_1, 1 ); + vc_40 = vfmaq_laneq_f64( vc_40, vb_0, va_2, 0 ); + vc_50 = vfmaq_laneq_f64( vc_50, vb_0, va_2, 1 ); + } + continue; + } + + // Three or four-column case. Moderately decrease num. of FMLA instructions + // according to m and n. + // if ( m0 > 0 ) + { + vc_00 = vfmaq_laneq_f64( vc_00, vb_0, va_0, 0 ); + vc_01 = vfmaq_laneq_f64( vc_01, vb_1, va_0, 0 ); + vc_10 = vfmaq_laneq_f64( vc_10, vb_0, va_0, 1 ); + vc_11 = vfmaq_laneq_f64( vc_11, vb_1, va_0, 1 ); + } + if ( m0 > 2 ) + { + vc_20 = vfmaq_laneq_f64( vc_20, vb_0, va_1, 0 ); + vc_21 = vfmaq_laneq_f64( vc_21, vb_1, va_1, 0 ); + vc_30 = vfmaq_laneq_f64( vc_30, vb_0, va_1, 1 ); + vc_31 = vfmaq_laneq_f64( vc_31, vb_1, va_1, 1 ); + } + if ( m0 > 4 ) + { + vc_40 = vfmaq_laneq_f64( vc_40, vb_0, va_2, 0 ); + vc_41 = vfmaq_laneq_f64( vc_41, vb_1, va_2, 0 ); + vc_50 = vfmaq_laneq_f64( vc_50, vb_0, va_2, 1 ); + vc_51 = vfmaq_laneq_f64( vc_51, vb_1, va_2, 1 ); + } + } + + // Load alpha and beta. + va_0 = vld1q_dup_f64( alpha ); + vb_0 = vld1q_dup_f64( beta ); + + // Scale. + vc_00 = vmulq_f64( vc_00, va_0 ); vc_01 = vmulq_f64( vc_01, va_0 ); + vc_10 = vmulq_f64( vc_10, va_0 ); vc_11 = vmulq_f64( vc_11, va_0 ); + vc_20 = vmulq_f64( vc_20, va_0 ); vc_21 = vmulq_f64( vc_21, va_0 ); + vc_30 = vmulq_f64( vc_30, va_0 ); vc_31 = vmulq_f64( vc_31, va_0 ); + vc_40 = vmulq_f64( vc_40, va_0 ); vc_41 = vmulq_f64( vc_41, va_0 ); + vc_50 = vmulq_f64( vc_50, va_0 ); vc_51 = vmulq_f64( vc_51, va_0 ); + + if ( cs_c == 1 ) + { + // Store in rows. + // if ( m0 > 0 ) + { + // Load. + if ( n > 1 ) va_0 = vld1q_f64 ( c_loc + 0 * rs_c + 0 ); + else va_0 = vld1q_lane_f64( c_loc + 0 * rs_c + 0, va_0, 0 ); + if ( n > 3 ) va_1 = vld1q_f64 ( c_loc + 0 * rs_c + 2 ); + else if ( n > 2 ) va_1 = vld1q_lane_f64( c_loc + 0 * rs_c + 2, va_1, 0 ); + + // Scale. + if ( !b_iszr ) + { + vc_00 = vfmaq_f64( vc_00, va_0, vb_0 ); + vc_01 = vfmaq_f64( vc_01, va_1, vb_0 ); + } + + // Store. + if ( n > 1 ) vst1q_f64 ( c_loc + 0 * rs_c + 0, vc_00 ); + else vst1q_lane_f64( c_loc + 0 * rs_c + 0, vc_00, 0 ); + if ( n > 3 ) vst1q_f64 ( c_loc + 0 * rs_c + 2, vc_01 ); + else if ( n > 2 ) vst1q_lane_f64( c_loc + 0 * rs_c + 2, vc_01, 0 ); + } + if ( m0 > 1 ) + { + // Load. + if ( n > 1 ) va_0 = vld1q_f64 ( c_loc + 1 * rs_c + 0 ); + else va_0 = vld1q_lane_f64( c_loc + 1 * rs_c + 0, va_0, 0 ); + if ( n > 3 ) va_1 = vld1q_f64 ( c_loc + 1 * rs_c + 2 ); + else if ( n > 2 ) va_1 = vld1q_lane_f64( c_loc + 1 * rs_c + 2, va_1, 0 ); + + // Scale. + if ( !b_iszr ) + { + vc_10 = vfmaq_f64( vc_10, va_0, vb_0 ); + vc_11 = vfmaq_f64( vc_11, va_1, vb_0 ); + } + + // Store. + if ( n > 1 ) vst1q_f64 ( c_loc + 1 * rs_c + 0, vc_10 ); + else vst1q_lane_f64( c_loc + 1 * rs_c + 0, vc_10, 0 ); + if ( n > 3 ) vst1q_f64 ( c_loc + 1 * rs_c + 2, vc_11 ); + else if ( n > 2 ) vst1q_lane_f64( c_loc + 1 * rs_c + 2, vc_11, 0 ); + } + if ( m0 > 2 ) + { + // Load. + if ( n > 1 ) va_0 = vld1q_f64 ( c_loc + 2 * rs_c + 0 ); + else va_0 = vld1q_lane_f64( c_loc + 2 * rs_c + 0, va_0, 0 ); + if ( n > 3 ) va_1 = vld1q_f64 ( c_loc + 2 * rs_c + 2 ); + else if ( n > 2 ) va_1 = vld1q_lane_f64( c_loc + 2 * rs_c + 2, va_1, 0 ); + + // Scale. + if ( !b_iszr ) + { + vc_20 = vfmaq_f64( vc_20, va_0, vb_0 ); + vc_21 = vfmaq_f64( vc_21, va_1, vb_0 ); + } + + // Store. + if ( n > 1 ) vst1q_f64 ( c_loc + 2 * rs_c + 0, vc_20 ); + else vst1q_lane_f64( c_loc + 2 * rs_c + 0, vc_20, 0 ); + if ( n > 3 ) vst1q_f64 ( c_loc + 2 * rs_c + 2, vc_21 ); + else if ( n > 2 ) vst1q_lane_f64( c_loc + 2 * rs_c + 2, vc_21, 0 ); + } + if ( m0 > 3 ) + { + // Load. + if ( n > 1 ) va_0 = vld1q_f64 ( c_loc + 3 * rs_c + 0 ); + else va_0 = vld1q_lane_f64( c_loc + 3 * rs_c + 0, va_0, 0 ); + if ( n > 3 ) va_1 = vld1q_f64 ( c_loc + 3 * rs_c + 2 ); + else if ( n > 2 ) va_1 = vld1q_lane_f64( c_loc + 3 * rs_c + 2, va_1, 0 ); + + // Scale. + if ( !b_iszr ) + { + vc_30 = vfmaq_f64( vc_30, va_0, vb_0 ); + vc_31 = vfmaq_f64( vc_31, va_1, vb_0 ); + } + + // Store. + if ( n > 1 ) vst1q_f64 ( c_loc + 3 * rs_c + 0, vc_30 ); + else vst1q_lane_f64( c_loc + 3 * rs_c + 0, vc_30, 0 ); + if ( n > 3 ) vst1q_f64 ( c_loc + 3 * rs_c + 2, vc_31 ); + else if ( n > 2 ) vst1q_lane_f64( c_loc + 3 * rs_c + 2, vc_31, 0 ); + } + if ( m0 > 4 ) + { + // Load. + if ( n > 1 ) va_0 = vld1q_f64 ( c_loc + 4 * rs_c + 0 ); + else va_0 = vld1q_lane_f64( c_loc + 4 * rs_c + 0, va_0, 0 ); + if ( n > 3 ) va_1 = vld1q_f64 ( c_loc + 4 * rs_c + 2 ); + else if ( n > 2 ) va_1 = vld1q_lane_f64( c_loc + 4 * rs_c + 2, va_1, 0 ); + + // Scale. + if ( !b_iszr ) + { + vc_40 = vfmaq_f64( vc_40, va_0, vb_0 ); + vc_41 = vfmaq_f64( vc_41, va_1, vb_0 ); + } + + // Store. + if ( n > 1 ) vst1q_f64 ( c_loc + 4 * rs_c + 0, vc_40 ); + else vst1q_lane_f64( c_loc + 4 * rs_c + 0, vc_40, 0 ); + if ( n > 3 ) vst1q_f64 ( c_loc + 4 * rs_c + 2, vc_41 ); + else if ( n > 2 ) vst1q_lane_f64( c_loc + 4 * rs_c + 2, vc_41, 0 ); + } + if ( m0 > 5 ) + { + // Load. + if ( n > 1 ) va_0 = vld1q_f64 ( c_loc + 5 * rs_c + 0 ); + else va_0 = vld1q_lane_f64( c_loc + 5 * rs_c + 0, va_0, 0 ); + if ( n > 3 ) va_1 = vld1q_f64 ( c_loc + 5 * rs_c + 2 ); + else if ( n > 2 ) va_1 = vld1q_lane_f64( c_loc + 5 * rs_c + 2, va_1, 0 ); + + // Scale. + if ( !b_iszr ) + { + vc_50 = vfmaq_f64( vc_50, va_0, vb_0 ); + vc_51 = vfmaq_f64( vc_51, va_1, vb_0 ); + } + + // Store. + if ( n > 1 ) vst1q_f64 ( c_loc + 5 * rs_c + 0, vc_50 ); + else vst1q_lane_f64( c_loc + 5 * rs_c + 0, vc_50, 0 ); + if ( n > 3 ) vst1q_f64 ( c_loc + 5 * rs_c + 2, vc_51 ); + else if ( n > 2 ) vst1q_lane_f64( c_loc + 5 * rs_c + 2, vc_51, 0 ); + } + } + else + { + // Store in columns. + + // Rename some vectors. +#define VCOL0 va_0 +#define VCOL1 va_1 +#define VCOL2 va_2 +#define VCOL3 vb_1 +#define VTMP0 vc_00 +#define VTMP1 vc_01 +#define VTMP2 vc_10 +#define VTMP3 vc_11 + // if ( m0 > 0 ) + { + VCOL0 = vtrn1q_f64(vc_00, vc_10); + VCOL1 = vtrn2q_f64(vc_00, vc_10); + VCOL2 = vtrn1q_f64(vc_01, vc_11); + VCOL3 = vtrn2q_f64(vc_01, vc_11); + + if ( m0 > 1 ) + { + if ( n > 0 ) VTMP0 = vld1q_f64( c_loc + 0 * cs_c + 0 ); + if ( n > 1 ) VTMP1 = vld1q_f64( c_loc + 1 * cs_c + 0 ); + if ( n > 2 ) VTMP2 = vld1q_f64( c_loc + 2 * cs_c + 0 ); + if ( n > 3 ) VTMP3 = vld1q_f64( c_loc + 3 * cs_c + 0 ); + if ( !b_iszr ) + { + VCOL0 = vfmaq_f64( VCOL0, VTMP0, vb_0 ); + VCOL1 = vfmaq_f64( VCOL1, VTMP1, vb_0 ); + VCOL2 = vfmaq_f64( VCOL2, VTMP2, vb_0 ); + VCOL3 = vfmaq_f64( VCOL3, VTMP3, vb_0 ); + } + if ( n > 0 ) vst1q_f64( c_loc + 0 * cs_c + 0, VCOL0 ); + if ( n > 1 ) vst1q_f64( c_loc + 1 * cs_c + 0, VCOL1 ); + if ( n > 2 ) vst1q_f64( c_loc + 2 * cs_c + 0, VCOL2 ); + if ( n > 3 ) vst1q_f64( c_loc + 3 * cs_c + 0, VCOL3 ); + } + else + { + if ( n > 0 ) VTMP0 = vld1q_lane_f64( c_loc + 0 * cs_c + 0, VTMP0, 0 ); + if ( n > 1 ) VTMP1 = vld1q_lane_f64( c_loc + 1 * cs_c + 0, VTMP1, 0 ); + if ( n > 2 ) VTMP2 = vld1q_lane_f64( c_loc + 2 * cs_c + 0, VTMP2, 0 ); + if ( n > 3 ) VTMP3 = vld1q_lane_f64( c_loc + 3 * cs_c + 0, VTMP3, 0 ); + if ( !b_iszr ) + { + VCOL0 = vfmaq_f64( VCOL0, VTMP0, vb_0 ); + VCOL1 = vfmaq_f64( VCOL1, VTMP1, vb_0 ); + VCOL2 = vfmaq_f64( VCOL2, VTMP2, vb_0 ); + VCOL3 = vfmaq_f64( VCOL3, VTMP3, vb_0 ); + } + if ( n > 0 ) vst1q_lane_f64( c_loc + 0 * cs_c + 0, VCOL0, 0 ); + if ( n > 1 ) vst1q_lane_f64( c_loc + 1 * cs_c + 0, VCOL1, 0 ); + if ( n > 2 ) vst1q_lane_f64( c_loc + 2 * cs_c + 0, VCOL2, 0 ); + if ( n > 3 ) vst1q_lane_f64( c_loc + 3 * cs_c + 0, VCOL3, 0 ); + } + } + if ( m0 > 2 ) + { + VCOL0 = vtrn1q_f64(vc_20, vc_30); + VCOL1 = vtrn2q_f64(vc_20, vc_30); + VCOL2 = vtrn1q_f64(vc_21, vc_31); + VCOL3 = vtrn2q_f64(vc_21, vc_31); + + if ( m0 > 3 ) + { + if ( n > 0 ) VTMP0 = vld1q_f64( c_loc + 0 * cs_c + 2 ); + if ( n > 1 ) VTMP1 = vld1q_f64( c_loc + 1 * cs_c + 2 ); + if ( n > 2 ) VTMP2 = vld1q_f64( c_loc + 2 * cs_c + 2 ); + if ( n > 3 ) VTMP3 = vld1q_f64( c_loc + 3 * cs_c + 2 ); + if ( !b_iszr ) + { + VCOL0 = vfmaq_f64( VCOL0, VTMP0, vb_0 ); + VCOL1 = vfmaq_f64( VCOL1, VTMP1, vb_0 ); + VCOL2 = vfmaq_f64( VCOL2, VTMP2, vb_0 ); + VCOL3 = vfmaq_f64( VCOL3, VTMP3, vb_0 ); + } + if ( n > 0 ) vst1q_f64( c_loc + 0 * cs_c + 2, VCOL0 ); + if ( n > 1 ) vst1q_f64( c_loc + 1 * cs_c + 2, VCOL1 ); + if ( n > 2 ) vst1q_f64( c_loc + 2 * cs_c + 2, VCOL2 ); + if ( n > 3 ) vst1q_f64( c_loc + 3 * cs_c + 2, VCOL3 ); + } + else + { + if ( n > 0 ) VTMP0 = vld1q_lane_f64( c_loc + 0 * cs_c + 2, VTMP0, 0 ); + if ( n > 1 ) VTMP1 = vld1q_lane_f64( c_loc + 1 * cs_c + 2, VTMP1, 0 ); + if ( n > 2 ) VTMP2 = vld1q_lane_f64( c_loc + 2 * cs_c + 2, VTMP2, 0 ); + if ( n > 3 ) VTMP3 = vld1q_lane_f64( c_loc + 3 * cs_c + 2, VTMP3, 0 ); + if ( !b_iszr ) + { + VCOL0 = vfmaq_f64( VCOL0, VTMP0, vb_0 ); + VCOL1 = vfmaq_f64( VCOL1, VTMP1, vb_0 ); + VCOL2 = vfmaq_f64( VCOL2, VTMP2, vb_0 ); + VCOL3 = vfmaq_f64( VCOL3, VTMP3, vb_0 ); + } + if ( n > 0 ) vst1q_lane_f64( c_loc + 0 * cs_c + 2, VCOL0, 0 ); + if ( n > 1 ) vst1q_lane_f64( c_loc + 1 * cs_c + 2, VCOL1, 0 ); + if ( n > 2 ) vst1q_lane_f64( c_loc + 2 * cs_c + 2, VCOL2, 0 ); + if ( n > 3 ) vst1q_lane_f64( c_loc + 3 * cs_c + 2, VCOL3, 0 ); + } + } + if ( m0 > 4 ) + { + VCOL0 = vtrn1q_f64(vc_40, vc_50); + VCOL1 = vtrn2q_f64(vc_40, vc_50); + VCOL2 = vtrn1q_f64(vc_41, vc_51); + VCOL3 = vtrn2q_f64(vc_41, vc_51); + + if ( m0 > 5 ) + { + if ( n > 0 ) VTMP0 = vld1q_f64( c_loc + 0 * cs_c + 4 ); + if ( n > 1 ) VTMP1 = vld1q_f64( c_loc + 1 * cs_c + 4 ); + if ( n > 2 ) VTMP2 = vld1q_f64( c_loc + 2 * cs_c + 4 ); + if ( n > 3 ) VTMP3 = vld1q_f64( c_loc + 3 * cs_c + 4 ); + if ( !b_iszr ) + { + VCOL0 = vfmaq_f64( VCOL0, VTMP0, vb_0 ); + VCOL1 = vfmaq_f64( VCOL1, VTMP1, vb_0 ); + VCOL2 = vfmaq_f64( VCOL2, VTMP2, vb_0 ); + VCOL3 = vfmaq_f64( VCOL3, VTMP3, vb_0 ); + } + if ( n > 0 ) vst1q_f64( c_loc + 0 * cs_c + 4, VCOL0 ); + if ( n > 1 ) vst1q_f64( c_loc + 1 * cs_c + 4, VCOL1 ); + if ( n > 2 ) vst1q_f64( c_loc + 2 * cs_c + 4, VCOL2 ); + if ( n > 3 ) vst1q_f64( c_loc + 3 * cs_c + 4, VCOL3 ); + } + else + { + if ( n > 0 ) VTMP0 = vld1q_lane_f64( c_loc + 0 * cs_c + 4, VTMP0, 0 ); + if ( n > 1 ) VTMP1 = vld1q_lane_f64( c_loc + 1 * cs_c + 4, VTMP1, 0 ); + if ( n > 2 ) VTMP2 = vld1q_lane_f64( c_loc + 2 * cs_c + 4, VTMP2, 0 ); + if ( n > 3 ) VTMP3 = vld1q_lane_f64( c_loc + 3 * cs_c + 4, VTMP3, 0 ); + if ( !b_iszr ) + { + VCOL0 = vfmaq_f64( VCOL0, VTMP0, vb_0 ); + VCOL1 = vfmaq_f64( VCOL1, VTMP1, vb_0 ); + VCOL2 = vfmaq_f64( VCOL2, VTMP2, vb_0 ); + VCOL3 = vfmaq_f64( VCOL3, VTMP3, vb_0 ); + } + if ( n > 0 ) vst1q_lane_f64( c_loc + 0 * cs_c + 4, VCOL0, 0 ); + if ( n > 1 ) vst1q_lane_f64( c_loc + 1 * cs_c + 4, VCOL1, 0 ); + if ( n > 2 ) vst1q_lane_f64( c_loc + 2 * cs_c + 4, VCOL2, 0 ); + if ( n > 3 ) vst1q_lane_f64( c_loc + 3 * cs_c + 4, VCOL3, 0 ); + } + } + } + + b_in += ps_b; + c_in += 4 * cs_c; + } + + a0 += ps_a; + c0 += 6 * rs_c; + } +} + diff --git a/kernels/armv8a/bli_kernels_armv8a.h b/kernels/armv8a/bli_kernels_armv8a.h index f3c01985a9..b7ab755412 100644 --- a/kernels/armv8a/bli_kernels_armv8a.h +++ b/kernels/armv8a/bli_kernels_armv8a.h @@ -32,5 +32,30 @@ */ +PACKM_KER_PROT( float, s, packm_armv8a_int_8xk ) +PACKM_KER_PROT( float, s, packm_armv8a_int_12xk ) +PACKM_KER_PROT( double, d, packm_armv8a_int_6xk ) +PACKM_KER_PROT( double, d, packm_armv8a_int_8xk ) + GEMM_UKR_PROT( float, s, gemm_armv8a_asm_8x12 ) GEMM_UKR_PROT( double, d, gemm_armv8a_asm_6x8 ) +// GEMM_UKR_PROT( double, d, gemm_armv8a_asm_6x8r ) +// GEMM_UKR_PROT( double, d, gemm_armv8a_asm_8x4 ) +// GEMM_UKR_PROT( double, d, gemm_armv8a_asm_4x4 ) + +GEMMSUP_KER_PROT( double, d, gemmsup_rd_armv8a_asm_6x8n ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_armv8a_asm_6x8m ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_armv8a_asm_6x8n ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_armv8a_asm_6x8m ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_armv8a_asm_4x8n ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_armv8a_asm_4x8m ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_armv8a_asm_8x4m ) + +GEMMSUP_KER_PROT( double, d, gemmsup_rd_armv8a_int_2x8 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_armv8a_int_3x4 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_armv8a_asm_3x4 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_armv8a_asm_6x3 ) + +GEMMSUP_KER_PROT( double, d, gemmsup_rv_armv8a_int_6x4mn ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_armv8a_int_3x8mn ) + diff --git a/kernels/bulldozer/3/bli_gemm_bulldozer_asm_d4x6_fma4.c b/kernels/bulldozer/3/bli_gemm_bulldozer_asm_d4x6_fma4.c index 8d0060b2f5..3227858b6a 100644 --- a/kernels/bulldozer/3/bli_gemm_bulldozer_asm_d4x6_fma4.c +++ b/kernels/bulldozer/3/bli_gemm_bulldozer_asm_d4x6_fma4.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -15,7 +15,7 @@ - Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. - - Neither the name of The University of Texas nor the names of its + - Neither the name(s) of the copyright holder(s) nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. diff --git a/kernels/haswell/1m/CMakeLists.txt b/kernels/haswell/1m/CMakeLists.txt index 9130e97f15..1fdada82fa 100644 --- a/kernels/haswell/1m/CMakeLists.txt +++ b/kernels/haswell/1m/CMakeLists.txt @@ -1,4 +1,36 @@ -##Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved.## +#[=[ + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +]=] add_library(haswell_1m OBJECT diff --git a/kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c b/kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c index 7a3478cb29..2c59c34a7f 100644 --- a/kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c +++ b/kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c @@ -2246,7 +2246,6 @@ void bli_zgemm_haswell_asm_3x4 { if(alpha->real == 1.0) alpha_mul_type = BLIS_MUL_ONE; else if(alpha->real == -1.0) alpha_mul_type = BLIS_MUL_MINUS_ONE; - else if(alpha->real == 0.0) alpha_mul_type = BLIS_MUL_ZERO; } if(beta->imag == 0.0)// (beta is real) diff --git a/kernels/knl/3/bli_dgemm_knl_asm_24x8.c b/kernels/knl/3/bli_dgemm_knl_asm_24x8.c index 82e5a25435..e79f5ccfac 100644 --- a/kernels/knl/3/bli_dgemm_knl_asm_24x8.c +++ b/kernels/knl/3/bli_dgemm_knl_asm_24x8.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -20,14 +20,14 @@ from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - AS IS AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY - OF TEXAS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, - EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, - PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR - PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY - OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/kernels/knl/3/bli_sgemm_knl_asm_24x16.c b/kernels/knl/3/bli_sgemm_knl_asm_24x16.c index b1ed2abf74..22e29115e9 100644 --- a/kernels/knl/3/bli_sgemm_knl_asm_24x16.c +++ b/kernels/knl/3/bli_sgemm_knl_asm_24x16.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -20,14 +20,14 @@ from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - AS IS AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY - OF TEXAS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, - EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, - PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR - PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY - OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/kernels/piledriver/3/CMakeLists.txt b/kernels/piledriver/3/CMakeLists.txt deleted file mode 100644 index 877419489f..0000000000 --- a/kernels/piledriver/3/CMakeLists.txt +++ /dev/null @@ -1,3 +0,0 @@ -target_sources("${PROJECT_NAME}" - PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm_piledriver_asm_d8x3.c) diff --git a/kernels/piledriver/CMakeLists.txt b/kernels/piledriver/CMakeLists.txt deleted file mode 100644 index 3c25f4b48e..0000000000 --- a/kernels/piledriver/CMakeLists.txt +++ /dev/null @@ -1,5 +0,0 @@ -target_sources("${PROJECT_NAME}" - PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/bli_kernels_piledriver.h) - -#add_subdirectory(3) diff --git a/kernels/skx/3/bli_dgemm_skx_asm_16x12_l2.c b/kernels/skx/3/bli_dgemm_skx_asm_16x12_l2.c index 5735a5911a..8233e53ac4 100644 --- a/kernels/skx/3/bli_dgemm_skx_asm_16x12_l2.c +++ b/kernels/skx/3/bli_dgemm_skx_asm_16x12_l2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -20,14 +20,14 @@ from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - AS IS AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY - OF TEXAS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, - EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, - PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR - PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY - OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/kernels/skx/3/bli_dgemm_skx_asm_16x14.c b/kernels/skx/3/bli_dgemm_skx_asm_16x14.c index 038920b834..8d4f484311 100644 --- a/kernels/skx/3/bli_dgemm_skx_asm_16x14.c +++ b/kernels/skx/3/bli_dgemm_skx_asm_16x14.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -20,14 +20,14 @@ from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - AS IS AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY - OF TEXAS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, - EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, - PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR - PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY - OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/kernels/skx/3/bli_sgemm_skx_asm_32x12_l2.c b/kernels/skx/3/bli_sgemm_skx_asm_32x12_l2.c index 572045832d..750fcbd633 100644 --- a/kernels/skx/3/bli_sgemm_skx_asm_32x12_l2.c +++ b/kernels/skx/3/bli_sgemm_skx_asm_32x12_l2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -20,14 +20,14 @@ from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - AS IS AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY - OF TEXAS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, - EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, - PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR - PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY - OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/kernels/zen/1/bli_addv_zen_int.c b/kernels/zen/1/bli_addv_zen_int.c new file mode 100644 index 0000000000..e64462520f --- /dev/null +++ b/kernels/zen/1/bli_addv_zen_int.c @@ -0,0 +1,1825 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "immintrin.h" +#include "blis.h" + +void bli_saddv_zen_int + ( + conj_t conjx, + dim_t n, + float* restrict x, inc_t incx, + float* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + const dim_t num_elem_per_reg = 8; + dim_t i = 0; + __m256 yv[16]; + + // If the vector dimension is zero return early. + if ( bli_zero_dim1( n ) ) return; + + float *x0 = x; + float *y0 = y; + + if ( incx == 1 && incy ==1 ) + { + // For loop with n & ~0x7F => n & 0xFFFFFF80 masks the lower bits and results in multiples of 128 + // for example if n = 255 + // n & ~0x7F results in 128: copy from 0 to 128 happens in first loop + // n & ~0x3F results in 192: copy from 128 to 192 happens in second loop + // n & ~0x1F results in 224: copy from 128 to 192 happens in third loop and so on. + for ( ; i < (n & (~0x7F)); i += 128 ) + { + // Loading input values + yv[0] = _mm256_loadu_ps( y0 + 0*num_elem_per_reg ); + yv[1] = _mm256_loadu_ps( y0 + 1*num_elem_per_reg ); + yv[2] = _mm256_loadu_ps( y0 + 2*num_elem_per_reg ); + yv[3] = _mm256_loadu_ps( y0 + 3*num_elem_per_reg ); + yv[4] = _mm256_loadu_ps( y0 + 4*num_elem_per_reg ); + yv[5] = _mm256_loadu_ps( y0 + 5*num_elem_per_reg ); + yv[6] = _mm256_loadu_ps( y0 + 6*num_elem_per_reg ); + yv[7] = _mm256_loadu_ps( y0 + 7*num_elem_per_reg ); + + // y := y + x + yv[0] = _mm256_add_ps + ( + _mm256_loadu_ps( x0 + 0*num_elem_per_reg ), + yv[0] + ); + yv[1] = _mm256_add_ps + ( + _mm256_loadu_ps( x0 + 1*num_elem_per_reg ), + yv[1] + ); + yv[2] = _mm256_add_ps + ( + _mm256_loadu_ps( x0 + 2*num_elem_per_reg ), + yv[2] + ); + yv[3] = _mm256_add_ps + ( + _mm256_loadu_ps( x0 + 3*num_elem_per_reg ), + yv[3] + ); + yv[4] = _mm256_add_ps + ( + _mm256_loadu_ps( x0 + 4*num_elem_per_reg ), + yv[4] + ); + yv[5] = _mm256_add_ps + ( + _mm256_loadu_ps( x0 + 5*num_elem_per_reg ), + yv[5] + ); + yv[6] = _mm256_add_ps + ( + _mm256_loadu_ps( x0 + 6*num_elem_per_reg ), + yv[6] + ); + yv[7] = _mm256_add_ps + ( + _mm256_loadu_ps( x0 + 7*num_elem_per_reg ), + yv[7] + ); + + _mm256_storeu_ps( ( y0 + 0*num_elem_per_reg ), yv[0] ); + _mm256_storeu_ps( ( y0 + 1*num_elem_per_reg ), yv[1] ); + _mm256_storeu_ps( ( y0 + 2*num_elem_per_reg ), yv[2] ); + _mm256_storeu_ps( ( y0 + 3*num_elem_per_reg ), yv[3] ); + _mm256_storeu_ps( ( y0 + 4*num_elem_per_reg ), yv[4] ); + _mm256_storeu_ps( ( y0 + 5*num_elem_per_reg ), yv[5] ); + _mm256_storeu_ps( ( y0 + 6*num_elem_per_reg ), yv[6] ); + _mm256_storeu_ps( ( y0 + 7*num_elem_per_reg ), yv[7] ); + + yv[8] = _mm256_loadu_ps( y0 + 8*num_elem_per_reg ); + yv[9] = _mm256_loadu_ps( y0 + 9*num_elem_per_reg ); + yv[10] = _mm256_loadu_ps( y0 + 10*num_elem_per_reg ); + yv[11] = _mm256_loadu_ps( y0 + 11*num_elem_per_reg ); + yv[12] = _mm256_loadu_ps( y0 + 12*num_elem_per_reg ); + yv[13] = _mm256_loadu_ps( y0 + 13*num_elem_per_reg ); + yv[14] = _mm256_loadu_ps( y0 + 14*num_elem_per_reg ); + yv[15] = _mm256_loadu_ps( y0 + 15*num_elem_per_reg ); + + yv[8] = _mm256_add_ps + ( + _mm256_loadu_ps( x0 + 8*num_elem_per_reg ), + yv[8] + ); + yv[9] = _mm256_add_ps + ( + _mm256_loadu_ps( x0 + 9*num_elem_per_reg ), + yv[9] + ); + yv[10] = _mm256_add_ps + ( + _mm256_loadu_ps( x0 + 10*num_elem_per_reg ), + yv[10] + ); + yv[11] = _mm256_add_ps + ( + _mm256_loadu_ps( x0 + 11*num_elem_per_reg ), + yv[11] + ); + yv[12] = _mm256_add_ps + ( + _mm256_loadu_ps( x0 + 12*num_elem_per_reg ), + yv[12] + ); + yv[13] = _mm256_add_ps + ( + _mm256_loadu_ps( x0 + 13*num_elem_per_reg ), + yv[13] + ); + yv[14] = _mm256_add_ps + ( + _mm256_loadu_ps( x0 + 14*num_elem_per_reg ), + yv[14] + ); + yv[15] = _mm256_add_ps + ( + _mm256_loadu_ps( x0 + 15*num_elem_per_reg ), + yv[15] + ); + + _mm256_storeu_ps( ( y0 + 8*num_elem_per_reg ), yv[8] ); + _mm256_storeu_ps( ( y0 + 9*num_elem_per_reg ), yv[9] ); + _mm256_storeu_ps( ( y0 + 10*num_elem_per_reg ), yv[10] ); + _mm256_storeu_ps( ( y0 + 11*num_elem_per_reg ), yv[11] ); + _mm256_storeu_ps( ( y0 + 12*num_elem_per_reg ), yv[12] ); + _mm256_storeu_ps( ( y0 + 13*num_elem_per_reg ), yv[13] ); + _mm256_storeu_ps( ( y0 + 14*num_elem_per_reg ), yv[14] ); + _mm256_storeu_ps( ( y0 + 15*num_elem_per_reg ), yv[15] ); + + x0 += 16 * num_elem_per_reg; + y0 += 16 * num_elem_per_reg; + } + + for ( ; i < (n & (~0x3F)); i += 64 ) + { + // Loading input values + yv[0] = _mm256_loadu_ps( y0 + 0*num_elem_per_reg ); + yv[1] = _mm256_loadu_ps( y0 + 1*num_elem_per_reg ); + yv[2] = _mm256_loadu_ps( y0 + 2*num_elem_per_reg ); + yv[3] = _mm256_loadu_ps( y0 + 3*num_elem_per_reg ); + yv[4] = _mm256_loadu_ps( y0 + 4*num_elem_per_reg ); + yv[5] = _mm256_loadu_ps( y0 + 5*num_elem_per_reg ); + yv[6] = _mm256_loadu_ps( y0 + 6*num_elem_per_reg ); + yv[7] = _mm256_loadu_ps( y0 + 7*num_elem_per_reg ); + + // y := y + x + yv[0] = _mm256_add_ps + ( + _mm256_loadu_ps( x0 + 0*num_elem_per_reg ), + yv[0] + ); + yv[1] = _mm256_add_ps + ( + _mm256_loadu_ps( x0 + 1*num_elem_per_reg ), + yv[1] + ); + yv[2] = _mm256_add_ps + ( + _mm256_loadu_ps( x0 + 2*num_elem_per_reg ), + yv[2] + ); + yv[3] = _mm256_add_ps + ( + _mm256_loadu_ps( x0 + 3*num_elem_per_reg ), + yv[3] + ); + yv[4] = _mm256_add_ps + ( + _mm256_loadu_ps( x0 + 4*num_elem_per_reg ), + yv[4] + ); + yv[5] = _mm256_add_ps + ( + _mm256_loadu_ps( x0 + 5*num_elem_per_reg ), + yv[5] + ); + yv[6] = _mm256_add_ps + ( + _mm256_loadu_ps( x0 + 6*num_elem_per_reg ), + yv[6] + ); + yv[7] = _mm256_add_ps + ( + _mm256_loadu_ps( x0 + 7*num_elem_per_reg ), + yv[7] + ); + + _mm256_storeu_ps( ( y0 + 0*num_elem_per_reg ), yv[0] ); + _mm256_storeu_ps( ( y0 + 1*num_elem_per_reg ), yv[1] ); + _mm256_storeu_ps( ( y0 + 2*num_elem_per_reg ), yv[2] ); + _mm256_storeu_ps( ( y0 + 3*num_elem_per_reg ), yv[3] ); + _mm256_storeu_ps( ( y0 + 4*num_elem_per_reg ), yv[4] ); + _mm256_storeu_ps( ( y0 + 5*num_elem_per_reg ), yv[5] ); + _mm256_storeu_ps( ( y0 + 6*num_elem_per_reg ), yv[6] ); + _mm256_storeu_ps( ( y0 + 7*num_elem_per_reg ), yv[7] ); + + x0 += 8 * num_elem_per_reg; + y0 += 8 * num_elem_per_reg; + } + + for ( ; i < (n & (~0x1F)); i += 32 ) + { + // Loading input values + yv[0] = _mm256_loadu_ps( y0 + 0*num_elem_per_reg ); + yv[1] = _mm256_loadu_ps( y0 + 1*num_elem_per_reg ); + yv[2] = _mm256_loadu_ps( y0 + 2*num_elem_per_reg ); + yv[3] = _mm256_loadu_ps( y0 + 3*num_elem_per_reg ); + + // y := y + x + yv[0] = _mm256_add_ps + ( + _mm256_loadu_ps( x0 + 0*num_elem_per_reg ), + yv[0] + ); + yv[1] = _mm256_add_ps + ( + _mm256_loadu_ps( x0 + 1*num_elem_per_reg ), + yv[1] + ); + yv[2] = _mm256_add_ps + ( + _mm256_loadu_ps( x0 + 2*num_elem_per_reg ), + yv[2] + ); + yv[3] = _mm256_add_ps + ( + _mm256_loadu_ps( x0 + 3*num_elem_per_reg ), + yv[3] + ); + + _mm256_storeu_ps( ( y0 + 0*num_elem_per_reg ), yv[0] ); + _mm256_storeu_ps( ( y0 + 1*num_elem_per_reg ), yv[1] ); + _mm256_storeu_ps( ( y0 + 2*num_elem_per_reg ), yv[2] ); + _mm256_storeu_ps( ( y0 + 3*num_elem_per_reg ), yv[3] ); + + x0 += 4 * num_elem_per_reg; + y0 += 4 * num_elem_per_reg; + } + + for ( ; i < (n & (~0x0F)); i += 16 ) + { + // Loading input values + yv[0] = _mm256_loadu_ps( y0 + 0*num_elem_per_reg ); + yv[1] = _mm256_loadu_ps( y0 + 1*num_elem_per_reg ); + + // y := y + x + yv[0] = _mm256_add_ps + ( + _mm256_loadu_ps( x0 + 0*num_elem_per_reg ), + yv[0] + ); + yv[1] = _mm256_add_ps + ( + _mm256_loadu_ps( x0 + 1*num_elem_per_reg ), + yv[1] + ); + + _mm256_storeu_ps( ( y0 + 0*num_elem_per_reg ), yv[0] ); + _mm256_storeu_ps( ( y0 + 1*num_elem_per_reg ), yv[1] ); + + x0 += 2 * num_elem_per_reg; + y0 += 2 * num_elem_per_reg; + } + + for ( ; i < (n & (~0x07)); i += 8 ) + { + // Loading input values + yv[0] = _mm256_loadu_ps( y0 + 0*num_elem_per_reg ); + + // y := y + x + yv[0] = _mm256_add_ps + ( + _mm256_loadu_ps( x0 + 0*num_elem_per_reg ), + yv[0] + ); + + _mm256_storeu_ps( ( y0 + 0*num_elem_per_reg ), yv[0] ); + + x0 += num_elem_per_reg; + y0 += num_elem_per_reg; + } + } + + // Handling fringe cases or non-unit strided vectors + for ( ; i < n; i += 1 ) + { + *y0 += *x0; + + x0 += incx; + y0 += incy; + } +} + +void bli_daddv_zen_int + ( + conj_t conjx, + dim_t n, + double* restrict x, inc_t incx, + double* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + const dim_t num_elem_per_reg = 4; + dim_t i = 0; + __m256d yv[16]; + + // If the vector dimension is zero return early. + if ( bli_zero_dim1( n ) ) return; + + double *x0 = x; + double *y0 = y; + + if ( incx == 1 && incy ==1 ) + { + // n & (~0x3F) = n & 0xFFFFFFC0 -> this masks the numbers less than 64, + // the copy operation will be done for the multiples of 64 + for ( ; i < (n & (~0x3F)); i += 64 ) + { + // Loading input values + yv[0] = _mm256_loadu_pd( y0 + 0*num_elem_per_reg ); + yv[1] = _mm256_loadu_pd( y0 + 1*num_elem_per_reg ); + yv[2] = _mm256_loadu_pd( y0 + 2*num_elem_per_reg ); + yv[3] = _mm256_loadu_pd( y0 + 3*num_elem_per_reg ); + yv[4] = _mm256_loadu_pd( y0 + 4*num_elem_per_reg ); + yv[5] = _mm256_loadu_pd( y0 + 5*num_elem_per_reg ); + yv[6] = _mm256_loadu_pd( y0 + 6*num_elem_per_reg ); + yv[7] = _mm256_loadu_pd( y0 + 7*num_elem_per_reg ); + + // y := y + x + yv[0] = _mm256_add_pd + ( + _mm256_loadu_pd( x0 + 0*num_elem_per_reg ), + yv[0] + ); + yv[1] = _mm256_add_pd + ( + _mm256_loadu_pd( x0 + 1*num_elem_per_reg ), + yv[1] + ); + yv[2] = _mm256_add_pd + ( + _mm256_loadu_pd( x0 + 2*num_elem_per_reg ), + yv[2] + ); + yv[3] = _mm256_add_pd + ( + _mm256_loadu_pd( x0 + 3*num_elem_per_reg ), + yv[3] + ); + yv[4] = _mm256_add_pd + ( + _mm256_loadu_pd( x0 + 4*num_elem_per_reg ), + yv[4] + ); + yv[5] = _mm256_add_pd + ( + _mm256_loadu_pd( x0 + 5*num_elem_per_reg ), + yv[5] + ); + yv[6] = _mm256_add_pd + ( + _mm256_loadu_pd( x0 + 6*num_elem_per_reg ), + yv[6] + ); + yv[7] = _mm256_add_pd + ( + _mm256_loadu_pd( x0 + 7*num_elem_per_reg ), + yv[7] + ); + + _mm256_storeu_pd( ( y0 + 0*num_elem_per_reg ), yv[0] ); + _mm256_storeu_pd( ( y0 + 1*num_elem_per_reg ), yv[1] ); + _mm256_storeu_pd( ( y0 + 2*num_elem_per_reg ), yv[2] ); + _mm256_storeu_pd( ( y0 + 3*num_elem_per_reg ), yv[3] ); + _mm256_storeu_pd( ( y0 + 4*num_elem_per_reg ), yv[4] ); + _mm256_storeu_pd( ( y0 + 5*num_elem_per_reg ), yv[5] ); + _mm256_storeu_pd( ( y0 + 6*num_elem_per_reg ), yv[6] ); + _mm256_storeu_pd( ( y0 + 7*num_elem_per_reg ), yv[7] ); + + yv[8] = _mm256_loadu_pd( y0 + 8*num_elem_per_reg ); + yv[9] = _mm256_loadu_pd( y0 + 9*num_elem_per_reg ); + yv[10] = _mm256_loadu_pd( y0 + 10*num_elem_per_reg ); + yv[11] = _mm256_loadu_pd( y0 + 11*num_elem_per_reg ); + yv[12] = _mm256_loadu_pd( y0 + 12*num_elem_per_reg ); + yv[13] = _mm256_loadu_pd( y0 + 13*num_elem_per_reg ); + yv[14] = _mm256_loadu_pd( y0 + 14*num_elem_per_reg ); + yv[15] = _mm256_loadu_pd( y0 + 15*num_elem_per_reg ); + + yv[8] = _mm256_add_pd + ( + _mm256_loadu_pd( x0 + 8*num_elem_per_reg ), + yv[8] + ); + yv[9] = _mm256_add_pd + ( + _mm256_loadu_pd( x0 + 9*num_elem_per_reg ), + yv[9] + ); + yv[10] = _mm256_add_pd + ( + _mm256_loadu_pd( x0 + 10*num_elem_per_reg ), + yv[10] + ); + yv[11] = _mm256_add_pd + ( + _mm256_loadu_pd( x0 + 11*num_elem_per_reg ), + yv[11] + ); + yv[12] = _mm256_add_pd + ( + _mm256_loadu_pd( x0 + 12*num_elem_per_reg ), + yv[12] + ); + yv[13] = _mm256_add_pd + ( + _mm256_loadu_pd( x0 + 13*num_elem_per_reg ), + yv[13] + ); + yv[14] = _mm256_add_pd + ( + _mm256_loadu_pd( x0 + 14*num_elem_per_reg ), + yv[14] + ); + yv[15] = _mm256_add_pd + ( + _mm256_loadu_pd( x0 + 15*num_elem_per_reg ), + yv[15] + ); + + _mm256_storeu_pd( ( y0 + 8*num_elem_per_reg ), yv[8] ); + _mm256_storeu_pd( ( y0 + 9*num_elem_per_reg ), yv[9] ); + _mm256_storeu_pd( ( y0 + 10*num_elem_per_reg ), yv[10] ); + _mm256_storeu_pd( ( y0 + 11*num_elem_per_reg ), yv[11] ); + _mm256_storeu_pd( ( y0 + 12*num_elem_per_reg ), yv[12] ); + _mm256_storeu_pd( ( y0 + 13*num_elem_per_reg ), yv[13] ); + _mm256_storeu_pd( ( y0 + 14*num_elem_per_reg ), yv[14] ); + _mm256_storeu_pd( ( y0 + 15*num_elem_per_reg ), yv[15] ); + + x0 += 16 * num_elem_per_reg; + y0 += 16 * num_elem_per_reg; + } + + for ( ; i < (n & (~0x1F)); i += 32 ) + { + // Loading input values + yv[0] = _mm256_loadu_pd( y0 + 0*num_elem_per_reg ); + yv[1] = _mm256_loadu_pd( y0 + 1*num_elem_per_reg ); + yv[2] = _mm256_loadu_pd( y0 + 2*num_elem_per_reg ); + yv[3] = _mm256_loadu_pd( y0 + 3*num_elem_per_reg ); + yv[4] = _mm256_loadu_pd( y0 + 4*num_elem_per_reg ); + yv[5] = _mm256_loadu_pd( y0 + 5*num_elem_per_reg ); + yv[6] = _mm256_loadu_pd( y0 + 6*num_elem_per_reg ); + yv[7] = _mm256_loadu_pd( y0 + 7*num_elem_per_reg ); + + // y := y + x + yv[0] = _mm256_add_pd + ( + _mm256_loadu_pd( x0 + 0*num_elem_per_reg ), + yv[0] + ); + yv[1] = _mm256_add_pd + ( + _mm256_loadu_pd( x0 + 1*num_elem_per_reg ), + yv[1] + ); + yv[2] = _mm256_add_pd + ( + _mm256_loadu_pd( x0 + 2*num_elem_per_reg ), + yv[2] + ); + yv[3] = _mm256_add_pd + ( + _mm256_loadu_pd( x0 + 3*num_elem_per_reg ), + yv[3] + ); + yv[4] = _mm256_add_pd + ( + _mm256_loadu_pd( x0 + 4*num_elem_per_reg ), + yv[4] + ); + yv[5] = _mm256_add_pd + ( + _mm256_loadu_pd( x0 + 5*num_elem_per_reg ), + yv[5] + ); + yv[6] = _mm256_add_pd + ( + _mm256_loadu_pd( x0 + 6*num_elem_per_reg ), + yv[6] + ); + yv[7] = _mm256_add_pd + ( + _mm256_loadu_pd( x0 + 7*num_elem_per_reg ), + yv[7] + ); + + _mm256_storeu_pd( ( y0 + 0*num_elem_per_reg ), yv[0] ); + _mm256_storeu_pd( ( y0 + 1*num_elem_per_reg ), yv[1] ); + _mm256_storeu_pd( ( y0 + 2*num_elem_per_reg ), yv[2] ); + _mm256_storeu_pd( ( y0 + 3*num_elem_per_reg ), yv[3] ); + _mm256_storeu_pd( ( y0 + 4*num_elem_per_reg ), yv[4] ); + _mm256_storeu_pd( ( y0 + 5*num_elem_per_reg ), yv[5] ); + _mm256_storeu_pd( ( y0 + 6*num_elem_per_reg ), yv[6] ); + _mm256_storeu_pd( ( y0 + 7*num_elem_per_reg ), yv[7] ); + + x0 += 8 * num_elem_per_reg; + y0 += 8 * num_elem_per_reg; + } + + for ( ; i < (n & (~0x0F)); i += 16 ) + { + // Loading input values + yv[0] = _mm256_loadu_pd( y0 + 0*num_elem_per_reg ); + yv[1] = _mm256_loadu_pd( y0 + 1*num_elem_per_reg ); + yv[2] = _mm256_loadu_pd( y0 + 2*num_elem_per_reg ); + yv[3] = _mm256_loadu_pd( y0 + 3*num_elem_per_reg ); + + // y := y + x + yv[0] = _mm256_add_pd + ( + _mm256_loadu_pd( x0 + 0*num_elem_per_reg ), + yv[0] + ); + yv[1] = _mm256_add_pd + ( + _mm256_loadu_pd( x0 + 1*num_elem_per_reg ), + yv[1] + ); + yv[2] = _mm256_add_pd + ( + _mm256_loadu_pd( x0 + 2*num_elem_per_reg ), + yv[2] + ); + yv[3] = _mm256_add_pd + ( + _mm256_loadu_pd( x0 + 3*num_elem_per_reg ), + yv[3] + ); + + _mm256_storeu_pd( ( y0 + 0*num_elem_per_reg ), yv[0] ); + _mm256_storeu_pd( ( y0 + 1*num_elem_per_reg ), yv[1] ); + _mm256_storeu_pd( ( y0 + 2*num_elem_per_reg ), yv[2] ); + _mm256_storeu_pd( ( y0 + 3*num_elem_per_reg ), yv[3] ); + + x0 += 4 * num_elem_per_reg; + y0 += 4 * num_elem_per_reg; + } + + for ( ; i < (n & (~0x07)); i += 8 ) + { + // Loading input values + yv[0] = _mm256_loadu_pd( y0 + 0*num_elem_per_reg ); + yv[1] = _mm256_loadu_pd( y0 + 1*num_elem_per_reg ); + + // y := y + x + yv[0] = _mm256_add_pd + ( + _mm256_loadu_pd( x0 + 0*num_elem_per_reg ), + yv[0] + ); + yv[1] = _mm256_add_pd + ( + _mm256_loadu_pd( x0 + 1*num_elem_per_reg ), + yv[1] + ); + + _mm256_storeu_pd( ( y0 + 0*num_elem_per_reg ), yv[0] ); + _mm256_storeu_pd( ( y0 + 1*num_elem_per_reg ), yv[1] ); + + x0 += 2 * num_elem_per_reg; + y0 += 2 * num_elem_per_reg; + } + + for ( ; i < (n & (~0x03)); i += 4 ) + { + // Loading input values + yv[0] = _mm256_loadu_pd( y0 + 0*num_elem_per_reg ); + + // y := y + x + yv[0] = _mm256_add_pd + ( + _mm256_loadu_pd( x0 + 0*num_elem_per_reg ), + yv[0] + ); + + _mm256_storeu_pd( ( y0 + 0*num_elem_per_reg ), yv[0] ); + + x0 += num_elem_per_reg; + y0 += num_elem_per_reg; + } + } + + // Handling fringe cases or non-unit strided vectors + for ( ; i < n; i += 1 ) + { + *y0 += *x0; + + x0 += incx; + y0 += incy; + } +} + +void bli_caddv_zen_int + ( + conj_t conjx, + dim_t n, + scomplex* restrict x, inc_t incx, + scomplex* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + const dim_t num_elem_per_reg = 8; + dim_t i = 0; + __m256 yv[12]; + + // If the vector dimension is zero return early. + if ( bli_zero_dim1( n ) ) return; + + float *x0 = (float *)x; + float *y0 = (float *)y; + + if( bli_is_conj( conjx ) ) + { + __m256 conjv = _mm256_set1_ps(1.0f); + if ( incx == 1 && incy ==1 ) + { + for ( ; (i + 47) < n; i += 48 ) + { + // Loading input values + yv[0] = _mm256_loadu_ps( y0 + 0*num_elem_per_reg ); + yv[1] = _mm256_loadu_ps( y0 + 1*num_elem_per_reg ); + yv[2] = _mm256_loadu_ps( y0 + 2*num_elem_per_reg ); + yv[3] = _mm256_loadu_ps( y0 + 3*num_elem_per_reg ); + yv[4] = _mm256_loadu_ps( y0 + 4*num_elem_per_reg ); + yv[5] = _mm256_loadu_ps( y0 + 5*num_elem_per_reg ); + yv[6] = _mm256_loadu_ps( y0 + 6*num_elem_per_reg ); + yv[7] = _mm256_loadu_ps( y0 + 7*num_elem_per_reg ); + + // y := y + x + yv[0] = _mm256_fmsubadd_ps + ( + conjv, + yv[0], + _mm256_loadu_ps( x0 + 0*num_elem_per_reg ) + ); + yv[1] = _mm256_fmsubadd_ps + ( + conjv, + yv[1], + _mm256_loadu_ps( x0 + 1*num_elem_per_reg ) + ); + yv[2] = _mm256_fmsubadd_ps + ( + conjv, + yv[2], + _mm256_loadu_ps( x0 + 2*num_elem_per_reg ) + ); + yv[3] = _mm256_fmsubadd_ps + ( + conjv, + yv[3], + _mm256_loadu_ps( x0 + 3*num_elem_per_reg ) + ); + yv[4] = _mm256_fmsubadd_ps + ( + conjv, + yv[4], + _mm256_loadu_ps( x0 + 4*num_elem_per_reg ) + ); + yv[5] = _mm256_fmsubadd_ps + ( + conjv, + yv[5], + _mm256_loadu_ps( x0 + 5*num_elem_per_reg ) + ); + yv[6] = _mm256_fmsubadd_ps + ( + conjv, + yv[6], + _mm256_loadu_ps( x0 + 6*num_elem_per_reg ) + ); + yv[7] = _mm256_fmsubadd_ps + ( + conjv, + yv[7], + _mm256_loadu_ps( x0 + 7*num_elem_per_reg ) + ); + + _mm256_storeu_ps( y0 + 0*num_elem_per_reg, yv[0] ); + _mm256_storeu_ps( y0 + 1*num_elem_per_reg, yv[1] ); + _mm256_storeu_ps( y0 + 2*num_elem_per_reg, yv[2] ); + _mm256_storeu_ps( y0 + 3*num_elem_per_reg, yv[3] ); + _mm256_storeu_ps( y0 + 4*num_elem_per_reg, yv[4] ); + _mm256_storeu_ps( y0 + 5*num_elem_per_reg, yv[5] ); + _mm256_storeu_ps( y0 + 6*num_elem_per_reg, yv[6] ); + _mm256_storeu_ps( y0 + 7*num_elem_per_reg, yv[7] ); + + yv[8] = _mm256_loadu_ps( y0 + 8*num_elem_per_reg ); + yv[9] = _mm256_loadu_ps( y0 + 9*num_elem_per_reg ); + yv[10] = _mm256_loadu_ps( y0 + 10*num_elem_per_reg ); + yv[11] = _mm256_loadu_ps( y0 + 11*num_elem_per_reg ); + + yv[8] = _mm256_fmsubadd_ps + ( + conjv, + yv[8], + _mm256_loadu_ps( x0 + 8*num_elem_per_reg ) + ); + yv[9] = _mm256_fmsubadd_ps + ( + conjv, + yv[9], + _mm256_loadu_ps( x0 + 9*num_elem_per_reg ) + ); + yv[10] = _mm256_fmsubadd_ps + ( + conjv, + yv[10], + _mm256_loadu_ps( x0 + 10*num_elem_per_reg ) + ); + yv[11] = _mm256_fmsubadd_ps + ( + conjv, + yv[11], + _mm256_loadu_ps( x0 + 11*num_elem_per_reg ) + ); + + _mm256_storeu_ps( y0 + 8*num_elem_per_reg, yv[8] ); + _mm256_storeu_ps( y0 + 9*num_elem_per_reg, yv[9] ); + _mm256_storeu_ps( y0 + 10*num_elem_per_reg, yv[10] ); + _mm256_storeu_ps( y0 + 11*num_elem_per_reg, yv[11] ); + + x0 += 12 * num_elem_per_reg; + y0 += 12 * num_elem_per_reg; + } + + for ( ; (i + 31) < n; i += 32 ) + { + // Loading input values + yv[0] = _mm256_loadu_ps( y0 + 0*num_elem_per_reg ); + yv[1] = _mm256_loadu_ps( y0 + 1*num_elem_per_reg ); + yv[2] = _mm256_loadu_ps( y0 + 2*num_elem_per_reg ); + yv[3] = _mm256_loadu_ps( y0 + 3*num_elem_per_reg ); + yv[4] = _mm256_loadu_ps( y0 + 4*num_elem_per_reg ); + yv[5] = _mm256_loadu_ps( y0 + 5*num_elem_per_reg ); + yv[6] = _mm256_loadu_ps( y0 + 6*num_elem_per_reg ); + yv[7] = _mm256_loadu_ps( y0 + 7*num_elem_per_reg ); + + // y := y + x + yv[0] = _mm256_fmsubadd_ps + ( + conjv, + yv[0], + _mm256_loadu_ps( x0 + 0*num_elem_per_reg ) + ); + yv[1] = _mm256_fmsubadd_ps + ( + conjv, + yv[1], + _mm256_loadu_ps( x0 + 1*num_elem_per_reg ) + ); + yv[2] = _mm256_fmsubadd_ps + ( + conjv, + yv[2], + _mm256_loadu_ps( x0 + 2*num_elem_per_reg ) + ); + yv[3] = _mm256_fmsubadd_ps + ( + conjv, + yv[3], + _mm256_loadu_ps( x0 + 3*num_elem_per_reg ) + ); + yv[4] = _mm256_fmsubadd_ps + ( + conjv, + yv[4], + _mm256_loadu_ps( x0 + 4*num_elem_per_reg ) + ); + yv[5] = _mm256_fmsubadd_ps + ( + conjv, + yv[5], + _mm256_loadu_ps( x0 + 5*num_elem_per_reg ) + ); + yv[6] = _mm256_fmsubadd_ps + ( + conjv, + yv[6], + _mm256_loadu_ps( x0 + 6*num_elem_per_reg ) + ); + yv[7] = _mm256_fmsubadd_ps + ( + conjv, + yv[7], + _mm256_loadu_ps( x0 + 7*num_elem_per_reg ) + ); + + _mm256_storeu_ps( y0 + 0*num_elem_per_reg, yv[0] ); + _mm256_storeu_ps( y0 + 1*num_elem_per_reg, yv[1] ); + _mm256_storeu_ps( y0 + 2*num_elem_per_reg, yv[2] ); + _mm256_storeu_ps( y0 + 3*num_elem_per_reg, yv[3] ); + _mm256_storeu_ps( y0 + 4*num_elem_per_reg, yv[4] ); + _mm256_storeu_ps( y0 + 5*num_elem_per_reg, yv[5] ); + _mm256_storeu_ps( y0 + 6*num_elem_per_reg, yv[6] ); + _mm256_storeu_ps( y0 + 7*num_elem_per_reg, yv[7] ); + + x0 += 8 * num_elem_per_reg; + y0 += 8 * num_elem_per_reg; + } + + for ( ; (i + 15) < n; i += 16 ) + { + // Loading input values + yv[0] = _mm256_loadu_ps( y0 + 0*num_elem_per_reg ); + yv[1] = _mm256_loadu_ps( y0 + 1*num_elem_per_reg ); + yv[2] = _mm256_loadu_ps( y0 + 2*num_elem_per_reg ); + yv[3] = _mm256_loadu_ps( y0 + 3*num_elem_per_reg ); + + // y := y + x + yv[0] = _mm256_fmsubadd_ps + ( + conjv, + yv[0], + _mm256_loadu_ps( x0 + 0*num_elem_per_reg ) + ); + yv[1] = _mm256_fmsubadd_ps + ( + conjv, + yv[1], + _mm256_loadu_ps( x0 + 1*num_elem_per_reg ) + ); + yv[2] = _mm256_fmsubadd_ps + ( + conjv, + yv[2], + _mm256_loadu_ps( x0 + 2*num_elem_per_reg ) + ); + yv[3] = _mm256_fmsubadd_ps + ( + conjv, + yv[3], + _mm256_loadu_ps( x0 + 3*num_elem_per_reg ) + ); + + _mm256_storeu_ps( y0 + 0*num_elem_per_reg, yv[0] ); + _mm256_storeu_ps( y0 + 1*num_elem_per_reg, yv[1] ); + _mm256_storeu_ps( y0 + 2*num_elem_per_reg, yv[2] ); + _mm256_storeu_ps( y0 + 3*num_elem_per_reg, yv[3] ); + + x0 += 4 * num_elem_per_reg; + y0 += 4 * num_elem_per_reg; + } + + for ( ; (i + 7) < n; i += 8 ) + { + // Loading input values + yv[0] = _mm256_loadu_ps( y0 + 0*num_elem_per_reg ); + yv[1] = _mm256_loadu_ps( y0 + 1*num_elem_per_reg ); + + // y := y + x + yv[0] = _mm256_fmsubadd_ps + ( + conjv, + yv[0], + _mm256_loadu_ps( x0 + 0*num_elem_per_reg ) + ); + yv[1] = _mm256_fmsubadd_ps + ( + conjv, + yv[1], + _mm256_loadu_ps( x0 + 1*num_elem_per_reg ) + ); + + _mm256_storeu_ps( y0 + 0*num_elem_per_reg, yv[0] ); + _mm256_storeu_ps( y0 + 1*num_elem_per_reg, yv[1] ); + + x0 += 2 * num_elem_per_reg; + y0 += 2 * num_elem_per_reg; + } + + for ( ; (i + 3) < n; i += 4 ) + { + // Loading input values + yv[0] = _mm256_loadu_ps( y0 + 0*num_elem_per_reg ); + + // y := y + x + yv[0] = _mm256_fmsubadd_ps + ( + conjv, + yv[0], + _mm256_loadu_ps( x0 + 0*num_elem_per_reg ) + ); + + _mm256_storeu_ps( y0 + 0*num_elem_per_reg, yv[0] ); + + x0 += num_elem_per_reg; + y0 += num_elem_per_reg; + } + } + + // Handling fringe cases or non-unit strided vectors + for ( ; i < n; i += 1 ) + { + *y0 += *x0; + *(y0 + 1) -= *(x0 + 1); + + x0 += 2 * incx; + y0 += 2 * incy; + } + } + else + { + if ( incx == 1 && incy ==1 ) + { + for ( ; (i + 47) < n; i += 48 ) + { + // Loading input values + yv[0] = _mm256_loadu_ps( y0 + 0*num_elem_per_reg ); + yv[1] = _mm256_loadu_ps( y0 + 1*num_elem_per_reg ); + yv[2] = _mm256_loadu_ps( y0 + 2*num_elem_per_reg ); + yv[3] = _mm256_loadu_ps( y0 + 3*num_elem_per_reg ); + yv[4] = _mm256_loadu_ps( y0 + 4*num_elem_per_reg ); + yv[5] = _mm256_loadu_ps( y0 + 5*num_elem_per_reg ); + yv[6] = _mm256_loadu_ps( y0 + 6*num_elem_per_reg ); + yv[7] = _mm256_loadu_ps( y0 + 7*num_elem_per_reg ); + + // y := y + x + yv[0] = _mm256_add_ps + ( + _mm256_loadu_ps( x0 + 0*num_elem_per_reg ), + yv[0] + ); + yv[1] = _mm256_add_ps + ( + _mm256_loadu_ps( x0 + 1*num_elem_per_reg ), + yv[1] + ); + yv[2] = _mm256_add_ps + ( + _mm256_loadu_ps( x0 + 2*num_elem_per_reg ), + yv[2] + ); + yv[3] = _mm256_add_ps + ( + _mm256_loadu_ps( x0 + 3*num_elem_per_reg ), + yv[3] + ); + yv[4] = _mm256_add_ps + ( + _mm256_loadu_ps( x0 + 4*num_elem_per_reg ), + yv[4] + ); + yv[5] = _mm256_add_ps + ( + _mm256_loadu_ps( x0 + 5*num_elem_per_reg ), + yv[5] + ); + yv[6] = _mm256_add_ps + ( + _mm256_loadu_ps( x0 + 6*num_elem_per_reg ), + yv[6] + ); + yv[7] = _mm256_add_ps + ( + _mm256_loadu_ps( x0 + 7*num_elem_per_reg ), + yv[7] + ); + + _mm256_storeu_ps( y0 + 0*num_elem_per_reg, yv[0] ); + _mm256_storeu_ps( y0 + 1*num_elem_per_reg, yv[1] ); + _mm256_storeu_ps( y0 + 2*num_elem_per_reg, yv[2] ); + _mm256_storeu_ps( y0 + 3*num_elem_per_reg, yv[3] ); + _mm256_storeu_ps( y0 + 4*num_elem_per_reg, yv[4] ); + _mm256_storeu_ps( y0 + 5*num_elem_per_reg, yv[5] ); + _mm256_storeu_ps( y0 + 6*num_elem_per_reg, yv[6] ); + _mm256_storeu_ps( y0 + 7*num_elem_per_reg, yv[7] ); + + yv[8] = _mm256_loadu_ps( y0 + 8*num_elem_per_reg ); + yv[9] = _mm256_loadu_ps( y0 + 9*num_elem_per_reg ); + yv[10] = _mm256_loadu_ps( y0 + 10*num_elem_per_reg ); + yv[11] = _mm256_loadu_ps( y0 + 11*num_elem_per_reg ); + + yv[8] = _mm256_add_ps + ( + _mm256_loadu_ps( x0 + 8*num_elem_per_reg ), + yv[8] + ); + yv[9] = _mm256_add_ps + ( + _mm256_loadu_ps( x0 + 9*num_elem_per_reg ), + yv[9] + ); + yv[10] = _mm256_add_ps + ( + _mm256_loadu_ps( x0 + 10*num_elem_per_reg ), + yv[10] + ); + yv[11] = _mm256_add_ps + ( + _mm256_loadu_ps( x0 + 11*num_elem_per_reg ), + yv[11] + ); + + _mm256_storeu_ps( y0 + 8*num_elem_per_reg, yv[8] ); + _mm256_storeu_ps( y0 + 9*num_elem_per_reg, yv[9] ); + _mm256_storeu_ps( y0 + 10*num_elem_per_reg, yv[10] ); + _mm256_storeu_ps( y0 + 11*num_elem_per_reg, yv[11] ); + + x0 += 12 * num_elem_per_reg; + y0 += 12 * num_elem_per_reg; + } + + for ( ; (i + 31) < n; i += 32 ) + { + // Loading input values + yv[0] = _mm256_loadu_ps( y0 + 0*num_elem_per_reg ); + yv[1] = _mm256_loadu_ps( y0 + 1*num_elem_per_reg ); + yv[2] = _mm256_loadu_ps( y0 + 2*num_elem_per_reg ); + yv[3] = _mm256_loadu_ps( y0 + 3*num_elem_per_reg ); + yv[4] = _mm256_loadu_ps( y0 + 4*num_elem_per_reg ); + yv[5] = _mm256_loadu_ps( y0 + 5*num_elem_per_reg ); + yv[6] = _mm256_loadu_ps( y0 + 6*num_elem_per_reg ); + yv[7] = _mm256_loadu_ps( y0 + 7*num_elem_per_reg ); + + // y := y + x + yv[0] = _mm256_add_ps + ( + _mm256_loadu_ps( x0 + 0*num_elem_per_reg ), + yv[0] + ); + yv[1] = _mm256_add_ps + ( + _mm256_loadu_ps( x0 + 1*num_elem_per_reg ), + yv[1] + ); + yv[2] = _mm256_add_ps + ( + _mm256_loadu_ps( x0 + 2*num_elem_per_reg ), + yv[2] + ); + yv[3] = _mm256_add_ps + ( + _mm256_loadu_ps( x0 + 3*num_elem_per_reg ), + yv[3] + ); + yv[4] = _mm256_add_ps + ( + _mm256_loadu_ps( x0 + 4*num_elem_per_reg ), + yv[4] + ); + yv[5] = _mm256_add_ps + ( + _mm256_loadu_ps( x0 + 5*num_elem_per_reg ), + yv[5] + ); + yv[6] = _mm256_add_ps + ( + _mm256_loadu_ps( x0 + 6*num_elem_per_reg ), + yv[6] + ); + yv[7] = _mm256_add_ps + ( + _mm256_loadu_ps( x0 + 7*num_elem_per_reg ), + yv[7] + ); + + _mm256_storeu_ps( y0 + 0*num_elem_per_reg, yv[0] ); + _mm256_storeu_ps( y0 + 1*num_elem_per_reg, yv[1] ); + _mm256_storeu_ps( y0 + 2*num_elem_per_reg, yv[2] ); + _mm256_storeu_ps( y0 + 3*num_elem_per_reg, yv[3] ); + _mm256_storeu_ps( y0 + 4*num_elem_per_reg, yv[4] ); + _mm256_storeu_ps( y0 + 5*num_elem_per_reg, yv[5] ); + _mm256_storeu_ps( y0 + 6*num_elem_per_reg, yv[6] ); + _mm256_storeu_ps( y0 + 7*num_elem_per_reg, yv[7] ); + + x0 += 8 * num_elem_per_reg; + y0 += 8 * num_elem_per_reg; + } + + for ( ; (i + 15) < n; i += 16 ) + { + // Loading input values + yv[0] = _mm256_loadu_ps( y0 + 0*num_elem_per_reg ); + yv[1] = _mm256_loadu_ps( y0 + 1*num_elem_per_reg ); + yv[2] = _mm256_loadu_ps( y0 + 2*num_elem_per_reg ); + yv[3] = _mm256_loadu_ps( y0 + 3*num_elem_per_reg ); + + // y := y + x + yv[0] = _mm256_add_ps + ( + _mm256_loadu_ps( x0 + 0*num_elem_per_reg ), + yv[0] + ); + yv[1] = _mm256_add_ps + ( + _mm256_loadu_ps( x0 + 1*num_elem_per_reg ), + yv[1] + ); + yv[2] = _mm256_add_ps + ( + _mm256_loadu_ps( x0 + 2*num_elem_per_reg ), + yv[2] + ); + yv[3] = _mm256_add_ps + ( + _mm256_loadu_ps( x0 + 3*num_elem_per_reg ), + yv[3] + ); + + _mm256_storeu_ps( y0 + 0*num_elem_per_reg, yv[0] ); + _mm256_storeu_ps( y0 + 1*num_elem_per_reg, yv[1] ); + _mm256_storeu_ps( y0 + 2*num_elem_per_reg, yv[2] ); + _mm256_storeu_ps( y0 + 3*num_elem_per_reg, yv[3] ); + + x0 += 4 * num_elem_per_reg; + y0 += 4 * num_elem_per_reg; + } + + for ( ; (i + 7) < n; i += 8 ) + { + // Loading input values + yv[0] = _mm256_loadu_ps( y0 + 0*num_elem_per_reg ); + yv[1] = _mm256_loadu_ps( y0 + 1*num_elem_per_reg ); + + // y := y + x + yv[0] = _mm256_add_ps + ( + _mm256_loadu_ps( x0 + 0*num_elem_per_reg ), + yv[0] + ); + yv[1] = _mm256_add_ps + ( + _mm256_loadu_ps( x0 + 1*num_elem_per_reg ), + yv[1] + ); + + _mm256_storeu_ps( y0 + 0*num_elem_per_reg, yv[0] ); + _mm256_storeu_ps( y0 + 1*num_elem_per_reg, yv[1] ); + + x0 += 2 * num_elem_per_reg; + y0 += 2 * num_elem_per_reg; + } + + for ( ; (i + 3) < n; i += 4 ) + { + // Loading input values + yv[0] = _mm256_loadu_ps( y0 + 0*num_elem_per_reg ); + + // y := y + x + yv[0] = _mm256_add_ps + ( + _mm256_loadu_ps( x0 + 0*num_elem_per_reg ), + yv[0] + ); + + _mm256_storeu_ps( y0 + 0*num_elem_per_reg, yv[0] ); + + x0 += num_elem_per_reg; + y0 += num_elem_per_reg; + } + } + + // Handling fringe cases or non-unit strided vectors + for ( ; i < n; i += 1 ) + { + *y0 += *x0; + *(y0 + 1) += *(x0 + 1); + + x0 += 2 * incx; + y0 += 2 * incy; + } + } +} + +void bli_zaddv_zen_int + ( + conj_t conjx, + dim_t n, + dcomplex* restrict x, inc_t incx, + dcomplex* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + const dim_t num_elem_per_reg = 4; + dim_t i = 0; + + // If the vector dimension is zero return early. + if ( bli_zero_dim1( n ) ) return; + + double *x0 = (double *)x; + double *y0 = (double *)y; + + if( bli_is_conj( conjx ) ) + { + __m256d yv[12]; + __m256d conjv = _mm256_set1_pd(1.0); + if ( incx == 1 && incy ==1 ) + { + for ( ; (i + 23) < n; i += 24 ) + { + // Loading input values + yv[0] = _mm256_loadu_pd( y0 + 0*num_elem_per_reg ); + yv[1] = _mm256_loadu_pd( y0 + 1*num_elem_per_reg ); + yv[2] = _mm256_loadu_pd( y0 + 2*num_elem_per_reg ); + yv[3] = _mm256_loadu_pd( y0 + 3*num_elem_per_reg ); + yv[4] = _mm256_loadu_pd( y0 + 4*num_elem_per_reg ); + yv[5] = _mm256_loadu_pd( y0 + 5*num_elem_per_reg ); + yv[6] = _mm256_loadu_pd( y0 + 6*num_elem_per_reg ); + yv[7] = _mm256_loadu_pd( y0 + 7*num_elem_per_reg ); + + // y := y + x + yv[0] = _mm256_fmsubadd_pd + ( + conjv, + yv[0], + _mm256_loadu_pd( x0 + 0*num_elem_per_reg ) + ); + yv[1] = _mm256_fmsubadd_pd + ( + conjv, + yv[1], + _mm256_loadu_pd( x0 + 1*num_elem_per_reg ) + ); + yv[2] = _mm256_fmsubadd_pd + ( + conjv, + yv[2], + _mm256_loadu_pd( x0 + 2*num_elem_per_reg ) + ); + yv[3] = _mm256_fmsubadd_pd + ( + conjv, + yv[3], + _mm256_loadu_pd( x0 + 3*num_elem_per_reg ) + ); + yv[4] = _mm256_fmsubadd_pd + ( + conjv, + yv[4], + _mm256_loadu_pd( x0 + 4*num_elem_per_reg ) + ); + yv[5] = _mm256_fmsubadd_pd + ( + conjv, + yv[5], + _mm256_loadu_pd( x0 + 5*num_elem_per_reg ) + ); + yv[6] = _mm256_fmsubadd_pd + ( + conjv, + yv[6], + _mm256_loadu_pd( x0 + 6*num_elem_per_reg ) + ); + yv[7] = _mm256_fmsubadd_pd + ( + conjv, + yv[7], + _mm256_loadu_pd( x0 + 7*num_elem_per_reg ) + ); + + _mm256_storeu_pd( y0 + 0*num_elem_per_reg, yv[0] ); + _mm256_storeu_pd( y0 + 1*num_elem_per_reg, yv[1] ); + _mm256_storeu_pd( y0 + 2*num_elem_per_reg, yv[2] ); + _mm256_storeu_pd( y0 + 3*num_elem_per_reg, yv[3] ); + _mm256_storeu_pd( y0 + 4*num_elem_per_reg, yv[4] ); + _mm256_storeu_pd( y0 + 5*num_elem_per_reg, yv[5] ); + _mm256_storeu_pd( y0 + 6*num_elem_per_reg, yv[6] ); + _mm256_storeu_pd( y0 + 7*num_elem_per_reg, yv[7] ); + + yv[8] = _mm256_loadu_pd( y0 + 8*num_elem_per_reg ); + yv[9] = _mm256_loadu_pd( y0 + 9*num_elem_per_reg ); + yv[10] = _mm256_loadu_pd( y0 + 10*num_elem_per_reg ); + yv[11] = _mm256_loadu_pd( y0 + 11*num_elem_per_reg ); + + yv[8] = _mm256_fmsubadd_pd + ( + conjv, + yv[8], + _mm256_loadu_pd( x0 + 8*num_elem_per_reg ) + ); + yv[9] = _mm256_fmsubadd_pd + ( + conjv, + yv[9], + _mm256_loadu_pd( x0 + 9*num_elem_per_reg ) + ); + yv[10] = _mm256_fmsubadd_pd + ( + conjv, + yv[10], + _mm256_loadu_pd( x0 + 10*num_elem_per_reg ) + ); + yv[11] = _mm256_fmsubadd_pd + ( + conjv, + yv[11], + _mm256_loadu_pd( x0 + 11*num_elem_per_reg ) + ); + + _mm256_storeu_pd( y0 + 8*num_elem_per_reg, yv[8] ); + _mm256_storeu_pd( y0 + 9*num_elem_per_reg, yv[9] ); + _mm256_storeu_pd( y0 + 10*num_elem_per_reg, yv[10] ); + _mm256_storeu_pd( y0 + 11*num_elem_per_reg, yv[11] ); + + x0 += 12 * num_elem_per_reg; + y0 += 12 * num_elem_per_reg; + } + + for ( ; (i + 15) < n; i += 16 ) + { + // Loading input values + yv[0] = _mm256_loadu_pd( y0 + 0*num_elem_per_reg ); + yv[1] = _mm256_loadu_pd( y0 + 1*num_elem_per_reg ); + yv[2] = _mm256_loadu_pd( y0 + 2*num_elem_per_reg ); + yv[3] = _mm256_loadu_pd( y0 + 3*num_elem_per_reg ); + yv[4] = _mm256_loadu_pd( y0 + 4*num_elem_per_reg ); + yv[5] = _mm256_loadu_pd( y0 + 5*num_elem_per_reg ); + yv[6] = _mm256_loadu_pd( y0 + 6*num_elem_per_reg ); + yv[7] = _mm256_loadu_pd( y0 + 7*num_elem_per_reg ); + + // y := y + x + yv[0] = _mm256_fmsubadd_pd + ( + conjv, + yv[0], + _mm256_loadu_pd( x0 + 0*num_elem_per_reg ) + ); + yv[1] = _mm256_fmsubadd_pd + ( + conjv, + yv[1], + _mm256_loadu_pd( x0 + 1*num_elem_per_reg ) + ); + yv[2] = _mm256_fmsubadd_pd + ( + conjv, + yv[2], + _mm256_loadu_pd( x0 + 2*num_elem_per_reg ) + ); + yv[3] = _mm256_fmsubadd_pd + ( + conjv, + yv[3], + _mm256_loadu_pd( x0 + 3*num_elem_per_reg ) + ); + yv[4] = _mm256_fmsubadd_pd + ( + conjv, + yv[4], + _mm256_loadu_pd( x0 + 4*num_elem_per_reg ) + ); + yv[5] = _mm256_fmsubadd_pd + ( + conjv, + yv[5], + _mm256_loadu_pd( x0 + 5*num_elem_per_reg ) + ); + yv[6] = _mm256_fmsubadd_pd + ( + conjv, + yv[6], + _mm256_loadu_pd( x0 + 6*num_elem_per_reg ) + ); + yv[7] = _mm256_fmsubadd_pd + ( + conjv, + yv[7], + _mm256_loadu_pd( x0 + 7*num_elem_per_reg ) + ); + + _mm256_storeu_pd( y0 + 0*num_elem_per_reg, yv[0] ); + _mm256_storeu_pd( y0 + 1*num_elem_per_reg, yv[1] ); + _mm256_storeu_pd( y0 + 2*num_elem_per_reg, yv[2] ); + _mm256_storeu_pd( y0 + 3*num_elem_per_reg, yv[3] ); + _mm256_storeu_pd( y0 + 4*num_elem_per_reg, yv[4] ); + _mm256_storeu_pd( y0 + 5*num_elem_per_reg, yv[5] ); + _mm256_storeu_pd( y0 + 6*num_elem_per_reg, yv[6] ); + _mm256_storeu_pd( y0 + 7*num_elem_per_reg, yv[7] ); + + x0 += 8 * num_elem_per_reg; + y0 += 8 * num_elem_per_reg; + } + + for ( ; (i + 7) < n; i += 8 ) + { + // Loading input values + yv[0] = _mm256_loadu_pd( y0 + 0*num_elem_per_reg ); + yv[1] = _mm256_loadu_pd( y0 + 1*num_elem_per_reg ); + yv[2] = _mm256_loadu_pd( y0 + 2*num_elem_per_reg ); + yv[3] = _mm256_loadu_pd( y0 + 3*num_elem_per_reg ); + + // y := y + x + yv[0] = _mm256_fmsubadd_pd + ( + conjv, + yv[0], + _mm256_loadu_pd( x0 + 0*num_elem_per_reg ) + ); + yv[1] = _mm256_fmsubadd_pd + ( + conjv, + yv[1], + _mm256_loadu_pd( x0 + 1*num_elem_per_reg ) + ); + yv[2] = _mm256_fmsubadd_pd + ( + conjv, + yv[2], + _mm256_loadu_pd( x0 + 2*num_elem_per_reg ) + ); + yv[3] = _mm256_fmsubadd_pd + ( + conjv, + yv[3], + _mm256_loadu_pd( x0 + 3*num_elem_per_reg ) + ); + + _mm256_storeu_pd( y0 + 0*num_elem_per_reg, yv[0] ); + _mm256_storeu_pd( y0 + 1*num_elem_per_reg, yv[1] ); + _mm256_storeu_pd( y0 + 2*num_elem_per_reg, yv[2] ); + _mm256_storeu_pd( y0 + 3*num_elem_per_reg, yv[3] ); + + x0 += 4 * num_elem_per_reg; + y0 += 4 * num_elem_per_reg; + } + + for ( ; (i + 3) < n; i += 4 ) + { + // Loading input values + yv[0] = _mm256_loadu_pd( y0 + 0*num_elem_per_reg ); + yv[1] = _mm256_loadu_pd( y0 + 1*num_elem_per_reg ); + + // y := y + x + yv[0] = _mm256_fmsubadd_pd + ( + conjv, + yv[0], + _mm256_loadu_pd( x0 + 0*num_elem_per_reg ) + ); + yv[1] = _mm256_fmsubadd_pd + ( + conjv, + yv[1], + _mm256_loadu_pd( x0 + 1*num_elem_per_reg ) + ); + + _mm256_storeu_pd( y0 + 0*num_elem_per_reg, yv[0] ); + _mm256_storeu_pd( y0 + 1*num_elem_per_reg, yv[1] ); + + x0 += 2 * num_elem_per_reg; + y0 += 2 * num_elem_per_reg; + } + + for ( ; (i + 1) < n; i += 2 ) + { + // Loading input values + yv[0] = _mm256_loadu_pd( y0 + 0*num_elem_per_reg ); + + // y := y + x + yv[0] = _mm256_fmsubadd_pd + ( + conjv, + yv[0], + _mm256_loadu_pd( x0 + 0*num_elem_per_reg ) + ); + + _mm256_storeu_pd( y0 + 0*num_elem_per_reg, yv[0] ); + + x0 += num_elem_per_reg; + y0 += num_elem_per_reg; + } + + _mm256_zeroupper(); + } + + __m128d x_vec, y_vec; + x_vec = _mm_setzero_pd(); + y_vec = _mm_setzero_pd(); + + for( ; i < n; i += 1 ) + { + x_vec = _mm_loadu_pd( x0 ); + y_vec = _mm_loadu_pd( y0 ); + + x_vec = _mm_shuffle_pd(x_vec, x_vec, 0x1); + y_vec = _mm_shuffle_pd(y_vec, y_vec, 0x1); + + y_vec =_mm_addsub_pd(y_vec, x_vec); + + y_vec = _mm_shuffle_pd(y_vec, y_vec, 0x1); + + _mm_storeu_pd(y0, y_vec); + + x0 += 2 * incx; + y0 += 2 * incy; + } + } + else + { + __m256d yv[12]; + if ( incx == 1 && incy ==1 ) + { + for ( ; (i + 23) < n; i += 24 ) + { + // Loading input values + yv[0] = _mm256_loadu_pd( y0 + 0*num_elem_per_reg ); + yv[1] = _mm256_loadu_pd( y0 + 1*num_elem_per_reg ); + yv[2] = _mm256_loadu_pd( y0 + 2*num_elem_per_reg ); + yv[3] = _mm256_loadu_pd( y0 + 3*num_elem_per_reg ); + yv[4] = _mm256_loadu_pd( y0 + 4*num_elem_per_reg ); + yv[5] = _mm256_loadu_pd( y0 + 5*num_elem_per_reg ); + yv[6] = _mm256_loadu_pd( y0 + 6*num_elem_per_reg ); + yv[7] = _mm256_loadu_pd( y0 + 7*num_elem_per_reg ); + + // y := y + x + yv[0] = _mm256_add_pd + ( + _mm256_loadu_pd( x0 + 0*num_elem_per_reg ), + yv[0] + ); + yv[1] = _mm256_add_pd + ( + _mm256_loadu_pd( x0 + 1*num_elem_per_reg ), + yv[1] + ); + yv[2] = _mm256_add_pd + ( + _mm256_loadu_pd( x0 + 2*num_elem_per_reg ), + yv[2] + ); + yv[3] = _mm256_add_pd + ( + _mm256_loadu_pd( x0 + 3*num_elem_per_reg ), + yv[3] + ); + yv[4] = _mm256_add_pd + ( + _mm256_loadu_pd( x0 + 4*num_elem_per_reg ), + yv[4] + ); + yv[5] = _mm256_add_pd + ( + _mm256_loadu_pd( x0 + 5*num_elem_per_reg ), + yv[5] + ); + yv[6] = _mm256_add_pd + ( + _mm256_loadu_pd( x0 + 6*num_elem_per_reg ), + yv[6] + ); + yv[7] = _mm256_add_pd + ( + _mm256_loadu_pd( x0 + 7*num_elem_per_reg ), + yv[7] + ); + + _mm256_storeu_pd( y0 + 0*num_elem_per_reg, yv[0] ); + _mm256_storeu_pd( y0 + 1*num_elem_per_reg, yv[1] ); + _mm256_storeu_pd( y0 + 2*num_elem_per_reg, yv[2] ); + _mm256_storeu_pd( y0 + 3*num_elem_per_reg, yv[3] ); + _mm256_storeu_pd( y0 + 4*num_elem_per_reg, yv[4] ); + _mm256_storeu_pd( y0 + 5*num_elem_per_reg, yv[5] ); + _mm256_storeu_pd( y0 + 6*num_elem_per_reg, yv[6] ); + _mm256_storeu_pd( y0 + 7*num_elem_per_reg, yv[7] ); + + yv[8] = _mm256_loadu_pd( y0 + 8*num_elem_per_reg ); + yv[9] = _mm256_loadu_pd( y0 + 9*num_elem_per_reg ); + yv[10] = _mm256_loadu_pd( y0 + 10*num_elem_per_reg ); + yv[11] = _mm256_loadu_pd( y0 + 11*num_elem_per_reg ); + + yv[8] = _mm256_add_pd + ( + _mm256_loadu_pd( x0 + 8*num_elem_per_reg ), + yv[8] + ); + yv[9] = _mm256_add_pd + ( + _mm256_loadu_pd( x0 + 9*num_elem_per_reg ), + yv[9] + ); + yv[10] = _mm256_add_pd + ( + _mm256_loadu_pd( x0 + 10*num_elem_per_reg ), + yv[10] + ); + yv[11] = _mm256_add_pd + ( + _mm256_loadu_pd( x0 + 11*num_elem_per_reg ), + yv[11] + ); + + _mm256_storeu_pd( y0 + 8*num_elem_per_reg, yv[8] ); + _mm256_storeu_pd( y0 + 9*num_elem_per_reg, yv[9] ); + _mm256_storeu_pd( y0 + 10*num_elem_per_reg, yv[10] ); + _mm256_storeu_pd( y0 + 11*num_elem_per_reg, yv[11] ); + + x0 += 12 * num_elem_per_reg; + y0 += 12 * num_elem_per_reg; + } + + for ( ; (i + 15) < n; i += 16 ) + { + // Loading input values + yv[0] = _mm256_loadu_pd( y0 + 0*num_elem_per_reg ); + yv[1] = _mm256_loadu_pd( y0 + 1*num_elem_per_reg ); + yv[2] = _mm256_loadu_pd( y0 + 2*num_elem_per_reg ); + yv[3] = _mm256_loadu_pd( y0 + 3*num_elem_per_reg ); + yv[4] = _mm256_loadu_pd( y0 + 4*num_elem_per_reg ); + yv[5] = _mm256_loadu_pd( y0 + 5*num_elem_per_reg ); + yv[6] = _mm256_loadu_pd( y0 + 6*num_elem_per_reg ); + yv[7] = _mm256_loadu_pd( y0 + 7*num_elem_per_reg ); + + // y := y + x + yv[0] = _mm256_add_pd + ( + _mm256_loadu_pd( x0 + 0*num_elem_per_reg ), + yv[0] + ); + yv[1] = _mm256_add_pd + ( + _mm256_loadu_pd( x0 + 1*num_elem_per_reg ), + yv[1] + ); + yv[2] = _mm256_add_pd + ( + _mm256_loadu_pd( x0 + 2*num_elem_per_reg ), + yv[2] + ); + yv[3] = _mm256_add_pd + ( + _mm256_loadu_pd( x0 + 3*num_elem_per_reg ), + yv[3] + ); + yv[4] = _mm256_add_pd + ( + _mm256_loadu_pd( x0 + 4*num_elem_per_reg ), + yv[4] + ); + yv[5] = _mm256_add_pd + ( + _mm256_loadu_pd( x0 + 5*num_elem_per_reg ), + yv[5] + ); + yv[6] = _mm256_add_pd + ( + _mm256_loadu_pd( x0 + 6*num_elem_per_reg ), + yv[6] + ); + yv[7] = _mm256_add_pd + ( + _mm256_loadu_pd( x0 + 7*num_elem_per_reg ), + yv[7] + ); + + _mm256_storeu_pd( y0 + 0*num_elem_per_reg, yv[0] ); + _mm256_storeu_pd( y0 + 1*num_elem_per_reg, yv[1] ); + _mm256_storeu_pd( y0 + 2*num_elem_per_reg, yv[2] ); + _mm256_storeu_pd( y0 + 3*num_elem_per_reg, yv[3] ); + _mm256_storeu_pd( y0 + 4*num_elem_per_reg, yv[4] ); + _mm256_storeu_pd( y0 + 5*num_elem_per_reg, yv[5] ); + _mm256_storeu_pd( y0 + 6*num_elem_per_reg, yv[6] ); + _mm256_storeu_pd( y0 + 7*num_elem_per_reg, yv[7] ); + + x0 += 8 * num_elem_per_reg; + y0 += 8 * num_elem_per_reg; + } + + for ( ; (i + 7) < n; i += 8 ) + { + // Loading input values + yv[0] = _mm256_loadu_pd( y0 + 0*num_elem_per_reg ); + yv[1] = _mm256_loadu_pd( y0 + 1*num_elem_per_reg ); + yv[2] = _mm256_loadu_pd( y0 + 2*num_elem_per_reg ); + yv[3] = _mm256_loadu_pd( y0 + 3*num_elem_per_reg ); + + // y := y + x + yv[0] = _mm256_add_pd + ( + _mm256_loadu_pd( x0 + 0*num_elem_per_reg ), + yv[0] + ); + yv[1] = _mm256_add_pd + ( + _mm256_loadu_pd( x0 + 1*num_elem_per_reg ), + yv[1] + ); + yv[2] = _mm256_add_pd + ( + _mm256_loadu_pd( x0 + 2*num_elem_per_reg ), + yv[2] + ); + yv[3] = _mm256_add_pd + ( + _mm256_loadu_pd( x0 + 3*num_elem_per_reg ), + yv[3] + ); + + _mm256_storeu_pd( y0 + 0*num_elem_per_reg, yv[0] ); + _mm256_storeu_pd( y0 + 1*num_elem_per_reg, yv[1] ); + _mm256_storeu_pd( y0 + 2*num_elem_per_reg, yv[2] ); + _mm256_storeu_pd( y0 + 3*num_elem_per_reg, yv[3] ); + + x0 += 4 * num_elem_per_reg; + y0 += 4 * num_elem_per_reg; + } + + for ( ; (i + 3) < n; i += 4 ) + { + // Loading input values + yv[0] = _mm256_loadu_pd( y0 + 0*num_elem_per_reg ); + yv[1] = _mm256_loadu_pd( y0 + 1*num_elem_per_reg ); + + // y := y + x + yv[0] = _mm256_add_pd + ( + _mm256_loadu_pd( x0 + 0*num_elem_per_reg ), + yv[0] + ); + yv[1] = _mm256_add_pd + ( + _mm256_loadu_pd( x0 + 1*num_elem_per_reg ), + yv[1] + ); + + _mm256_storeu_pd( y0 + 0*num_elem_per_reg, yv[0] ); + _mm256_storeu_pd( y0 + 1*num_elem_per_reg, yv[1] ); + + x0 += 2 * num_elem_per_reg; + y0 += 2 * num_elem_per_reg; + } + + for ( ; (i + 1) < n; i += 2 ) + { + // Loading input values + yv[0] = _mm256_loadu_pd( y0 + 0*num_elem_per_reg ); + + // y := y + x + yv[0] = _mm256_add_pd + ( + _mm256_loadu_pd( x0 + 0*num_elem_per_reg ), + yv[0] + ); + + _mm256_storeu_pd( y0 + 0*num_elem_per_reg, yv[0] ); + + x0 += num_elem_per_reg; + y0 += num_elem_per_reg; + } + } + + __m128d x_vec, y_vec; + x_vec = _mm_setzero_pd(); + y_vec = _mm_setzero_pd(); + + for( ; i < n; i += 1 ) + { + x_vec = _mm_loadu_pd( x0 ); + y_vec = _mm_loadu_pd( y0 ); + + y_vec =_mm_add_pd(y_vec, x_vec); + + _mm_storeu_pd(y0, y_vec); + + x0 += 2 * incx; + y0 += 2 * incy; + } + } +} diff --git a/kernels/zen/1/bli_amaxv_zen_int.c b/kernels/zen/1/bli_amaxv_zen_int.c index 5c9e7af81b..e9b392aed5 100644 --- a/kernels/zen/1/bli_amaxv_zen_int.c +++ b/kernels/zen/1/bli_amaxv_zen_int.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2016 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2016 - 2024, Advanced Micro Devices, Inc. All rights reserved. Copyright (C) 2018, The University of Texas at Austin Redistribution and use in source and binary forms, with or without @@ -838,7 +838,7 @@ static void bli_vec_search_double 1. The function results in undefined behaviour when NaN elements are present in the array. This behaviour is BLAS complaint. */ -void bli_damaxv_zen_int +BLIS_EXPORT_BLIS void bli_damaxv_zen_int ( dim_t n, double* restrict x, inc_t incx, @@ -846,7 +846,7 @@ void bli_damaxv_zen_int cntx_t* restrict cntx ) { - // Temproray pointer used inside the function + // Temporary pointer used inside the function double *x_temp = x; // Will hold the absolute largest element in the array diff --git a/kernels/zen/1/bli_axpbyv_zen_int.c b/kernels/zen/1/bli_axpbyv_zen_int.c index 23748ab992..5e2094a6d3 100644 --- a/kernels/zen/1/bli_axpbyv_zen_int.c +++ b/kernels/zen/1/bli_axpbyv_zen_int.c @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -55,7 +55,7 @@ typedef union * y := beta * y + alpha * conjx(x) * where, * x & y are single precision vectors of length n. - * alpha & beta are scalers. + * alpha & beta are scalars. */ void bli_saxpbyv_zen_int ( @@ -69,96 +69,225 @@ void bli_saxpbyv_zen_int ) { AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_4) + + // Redirecting to other L1 kernels based on alpha and beta values + // If alpha is 0, we call SSCALV + // This kernel would further reroute based on few other combinations + // of alpha and beta. They are as follows : + // When alpha = 0 : + // When beta = 0 --> SSETV + // When beta = 1 --> Early return + // When beta = !( 0 or 1 ) --> SSCALV + if ( bli_seq0( *alpha ) ) + { + bli_sscalv_zen_int10 + ( + BLIS_NO_CONJUGATE, + n, + beta, + y, incy, + cntx + ); + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) + return; + } + + // If beta is 0, we call SSCAL2V + // This kernel would further reroute based on few other combinations + // of alpha and beta. They are as follows : + // When beta = 0 : + // When alpha = 0 --> SSETV + // When alpha = 1 --> SCOPYV + // When alpha = !( 0 or 1 ) --> SSCAL2V + else if ( bli_seq0( *beta ) ) + { + bli_sscal2v_zen_int + ( + conjx, + n, + alpha, + x, incx, + y, incy, + cntx + ); + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) + return; + } + + // If beta is 1, we have 2 scenarios for rerouting + // When alpha = 1 --> SADDV + // When alpha = !( 0 or 1 ) --> SAXPYV + else if ( bli_seq1( *beta ) ) + { + if( bli_seq1( *alpha ) ) + { + bli_saddv_zen_int + ( + conjx, + n, + x, incx, + y, incy, + cntx + ); + } + else + { + bli_saxpyv_zen_int + ( + conjx, + n, + alpha, + x, incx, + y, incy, + cntx + ); + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) + return; + } + const dim_t n_elem_per_reg = 8; // number of elements per register const dim_t n_iter_unroll = 4; // num of registers per iteration - dim_t i; // iterator + dim_t i = 0; // iterator float* restrict x0; float* restrict y0; v8sf_t alphav; v8sf_t betav; - v8sf_t y0v, y1v, y2v, y3v; + v8sf_t yv[4]; - /* if the vector dimension is zero, or if alpha & beta are zero, - return early. */ - if ( bli_zero_dim1( n ) || - ( PASTEMAC( s, eq0 )( *alpha ) && PASTEMAC( s, eq0 )( *beta ) ) ) - return; + bool is_alpha_one = bli_seq1( *alpha ); // initialize local pointers x0 = x; y0 = y; - if ( incx == 1 && incy == 1 ) + if( incx == 1 && incy == 1 ) { - // broadcast alpha & beta to all elements of respective vector registers - alphav.v = _mm256_broadcast_ss( alpha ); - betav.v = _mm256_broadcast_ss( beta ); + // Broadcasting beta onto a YMM register + betav.v = _mm256_broadcast_ss( beta ); - // unrolling and vectorizing - for ( i = 0; ( i + 31 ) < n; i += 32 ) + if( is_alpha_one ) // Scale y with beta and add x to it + { + for ( ; ( i + 31 ) < n; i += 32 ) + { + // Loading input values + yv[0].v = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); + yv[1].v = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); + yv[2].v = _mm256_loadu_ps( x0 + 2*n_elem_per_reg ); + yv[3].v = _mm256_loadu_ps( x0 + 3*n_elem_per_reg ); + + // y := beta * y + x + yv[0].v = _mm256_fmadd_ps + ( + betav.v, + _mm256_loadu_ps( y0 + 0*n_elem_per_reg ), + yv[0].v + ); + yv[1].v = _mm256_fmadd_ps + ( + betav.v, + _mm256_loadu_ps( y0 + 1*n_elem_per_reg ), + yv[1].v + ); + yv[2].v = _mm256_fmadd_ps + ( + betav.v, + _mm256_loadu_ps( y0 + 2*n_elem_per_reg ), + yv[2].v + ); + yv[3].v = _mm256_fmadd_ps + ( + betav.v, + _mm256_loadu_ps( y0 + 3*n_elem_per_reg ), + yv[3].v + ); + + // Storing the output + _mm256_storeu_ps( ( y0 + 0*n_elem_per_reg ), yv[0].v ); + _mm256_storeu_ps( ( y0 + 1*n_elem_per_reg ), yv[1].v ); + _mm256_storeu_ps( ( y0 + 2*n_elem_per_reg ), yv[2].v ); + _mm256_storeu_ps( ( y0 + 3*n_elem_per_reg ), yv[3].v ); + + x0 += n_elem_per_reg * n_iter_unroll; + y0 += n_elem_per_reg * n_iter_unroll; + } + } + else { - // loading input y - y0v.v = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); - y1v.v = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); - y2v.v = _mm256_loadu_ps( y0 + 2*n_elem_per_reg ); - y3v.v = _mm256_loadu_ps( y0 + 3*n_elem_per_reg ); - - // y' := y := beta * y - y0v.v = _mm256_mul_ps( betav.v, y0v.v ); - y1v.v = _mm256_mul_ps( betav.v, y1v.v ); - y2v.v = _mm256_mul_ps( betav.v, y2v.v ); - y3v.v = _mm256_mul_ps( betav.v, y3v.v ); - - // y := y' + alpha * x - y0v.v = _mm256_fmadd_ps - ( - alphav.v, - _mm256_loadu_ps( x0 + 0*n_elem_per_reg ), - y0v.v - ); - y1v.v = _mm256_fmadd_ps - ( - alphav.v, - _mm256_loadu_ps( x0 + 1*n_elem_per_reg ), - y1v.v - ); - y2v.v = _mm256_fmadd_ps - ( - alphav.v, - _mm256_loadu_ps( x0 + 2*n_elem_per_reg ), - y2v.v - ); - y3v.v = _mm256_fmadd_ps - ( - alphav.v, - _mm256_loadu_ps( x0 + 3*n_elem_per_reg ), - y3v.v - ); - - // storing the output - _mm256_storeu_ps( ( y0 + 0*n_elem_per_reg ), y0v.v ); - _mm256_storeu_ps( ( y0 + 1*n_elem_per_reg ), y1v.v ); - _mm256_storeu_ps( ( y0 + 2*n_elem_per_reg ), y2v.v ); - _mm256_storeu_ps( ( y0 + 3*n_elem_per_reg ), y3v.v ); - - x0 += n_elem_per_reg * n_iter_unroll; - y0 += n_elem_per_reg * n_iter_unroll; + // Broadcasting alpha onto a YMM register + alphav.v = _mm256_broadcast_ss( alpha ); + + for ( ; ( i + 31 ) < n; i += 32 ) + { + // loading input values + yv[0].v = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); + yv[1].v = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); + yv[2].v = _mm256_loadu_ps( y0 + 2*n_elem_per_reg ); + yv[3].v = _mm256_loadu_ps( y0 + 3*n_elem_per_reg ); + + // y' := beta * y + yv[0].v = _mm256_mul_ps( betav.v, yv[0].v ); + yv[1].v = _mm256_mul_ps( betav.v, yv[1].v ); + yv[2].v = _mm256_mul_ps( betav.v, yv[2].v ); + yv[3].v = _mm256_mul_ps( betav.v, yv[3].v ); + + // y := y' + alpha * x + yv[0].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 0*n_elem_per_reg ), + yv[0].v + ); + yv[1].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 1*n_elem_per_reg ), + yv[1].v + ); + yv[2].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 2*n_elem_per_reg ), + yv[2].v + ); + yv[3].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 3*n_elem_per_reg ), + yv[3].v + ); + + // storing the output + _mm256_storeu_ps( ( y0 + 0*n_elem_per_reg ), yv[0].v ); + _mm256_storeu_ps( ( y0 + 1*n_elem_per_reg ), yv[1].v ); + _mm256_storeu_ps( ( y0 + 2*n_elem_per_reg ), yv[2].v ); + _mm256_storeu_ps( ( y0 + 3*n_elem_per_reg ), yv[3].v ); + + x0 += n_elem_per_reg * n_iter_unroll; + y0 += n_elem_per_reg * n_iter_unroll; + } } - + // Issue vzeroupper instruction to clear upper lanes of ymm registers. // This avoids a performance penalty caused by false dependencies when // transitioning from AVX to SSE instructions (which may occur as soon // as the n_left cleanup loop below if BLIS is compiled with // -mfpmath=sse). _mm256_zeroupper(); + } - // if there are leftover iterations, perform them with scaler code + // Handling fringe cases or non-unit strides + if( is_alpha_one ) + { for ( ; i < n; ++i ) { - *y0 = ( (*alpha) * (*x0) ) + ( (*beta) * (*y0) ); + *y0 = (*beta) * (*y0) + (*x0); x0 += incx; y0 += incy; @@ -166,15 +295,15 @@ void bli_saxpbyv_zen_int } else { - // for non-unit increments, use scaler code - for ( i = 0; i < n; ++i ) + for ( ; i < n; ++i ) { - *y0 = ( (*alpha) * (*x0) ) + ( (*beta) * (*y0) ); + *y0 = (*beta) * (*y0) + (*alpha) * (*x0); x0 += incx; y0 += incy; } } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) } @@ -183,7 +312,7 @@ void bli_saxpbyv_zen_int * y := beta * y + alpha * conjx(x) * where, * x & y are double precision vectors of length n. - * alpha & beta are scalers. + * alpha & beta are scalars. */ void bli_daxpbyv_zen_int ( @@ -197,26 +326,99 @@ void bli_daxpbyv_zen_int ) { AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_4) + + // Redirecting to other L1 kernels based on alpha and beta values + // If alpha is 0, we call DSCALV + // This kernel would further reroute based on few other combinations + // of alpha and beta. They are as follows : + // When alpha = 0 : + // When beta = 0 --> DSETV + // When beta = 1 --> Early return + // When beta = !( 0 or 1 ) --> DSCALV + if ( bli_deq0( *alpha ) ) + { + bli_dscalv_zen_int10 + ( + BLIS_NO_CONJUGATE, + n, + beta, + y, incy, + cntx + ); + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) + return; + } + + // If beta is 0, we call DSCAL2V + // This kernel would further reroute based on few other combinations + // of alpha and beta. They are as follows : + // When beta = 0 : + // When alpha = 0 --> DSETV + // When alpha = 1 --> DCOPYV + // When alpha = !( 0 or 1 ) --> DSCAL2V + else if ( bli_deq0( *beta ) ) + { + bli_dscal2v_zen_int + ( + conjx, + n, + alpha, + x, incx, + y, incy, + cntx + ); + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) + return; + } + + // If beta is 1, we have 2 scenarios for rerouting + // When alpha = 1 --> DADDV + // When alpha = !( 0 or 1 ) --> DAXPYV + else if ( bli_deq1( *beta ) ) + { + if( bli_deq1( *alpha ) ) + { + bli_daddv_zen_int + ( + conjx, + n, + x, incx, + y, incy, + cntx + ); + } + else + { + bli_daxpyv_zen_int + ( + conjx, + n, + alpha, + x, incx, + y, incy, + cntx + ); + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) + return; + } + const dim_t n_elem_per_reg = 4; // number of elements per register const dim_t n_iter_unroll = 4; // number of registers per iteration - dim_t i; // iterator + dim_t i = 0; // iterator double* restrict x0; double* restrict y0; v4df_t alphav; v4df_t betav; - v4df_t y0v, y1v, y2v, y3v; + v4df_t yv[4]; - /* if the vector dimension is zero, or if alpha & beta are zero, - return early. */ - if ( bli_zero_dim1( n ) || - ( PASTEMAC( s, eq0 )( *alpha ) && PASTEMAC( s, eq0 )( *beta ) ) ) - { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) - return; - } + bool is_alpha_one = bli_deq1( *alpha ); // initialize local pointers x0 = x; @@ -224,60 +426,109 @@ void bli_daxpbyv_zen_int if ( incx == 1 && incy == 1 ) { - // broadcast alpha & beta to all elements of respective vector registers - alphav.v = _mm256_broadcast_sd( alpha ); - betav.v = _mm256_broadcast_sd( beta ); + // Broadcasting beta onto a YMM register + betav.v = _mm256_broadcast_sd( beta ); - // unrolling and vectorizing - for ( i = 0; ( i + 15 ) < n; i += 16 ) + if( is_alpha_one ) // Scale y with beta and add x to it { - // loading input y - y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); - y1v.v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); - y2v.v = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); - y3v.v = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); - - // y' := y := beta * y - y0v.v = _mm256_mul_pd( betav.v, y0v.v ); - y1v.v = _mm256_mul_pd( betav.v, y1v.v ); - y2v.v = _mm256_mul_pd( betav.v, y2v.v ); - y3v.v = _mm256_mul_pd( betav.v, y3v.v ); - - // y := y' + alpha * x - // := beta * y + alpha * x - y0v.v = _mm256_fmadd_pd - ( - alphav.v, - _mm256_loadu_pd( x0 + 0*n_elem_per_reg ), - y0v.v - ); - y1v.v = _mm256_fmadd_pd - ( - alphav.v, - _mm256_loadu_pd( x0 + 1*n_elem_per_reg ), - y1v.v - ); - y2v.v = _mm256_fmadd_pd - ( - alphav.v, - _mm256_loadu_pd( x0 + 2*n_elem_per_reg ), - y2v.v - ); - y3v.v = _mm256_fmadd_pd - ( - alphav.v, - _mm256_loadu_pd( x0 + 3*n_elem_per_reg ), - y3v.v - ); - - // storing the output - _mm256_storeu_pd( ( y0 + 0*n_elem_per_reg ), y0v.v ); - _mm256_storeu_pd( ( y0 + 1*n_elem_per_reg ), y1v.v ); - _mm256_storeu_pd( ( y0 + 2*n_elem_per_reg ), y2v.v ); - _mm256_storeu_pd( ( y0 + 3*n_elem_per_reg ), y3v.v ); - - x0 += n_elem_per_reg * n_iter_unroll; - y0 += n_elem_per_reg * n_iter_unroll; + for ( ; ( i + 15 ) < n; i += 16 ) + { + // Loading input values + yv[0].v = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + yv[1].v = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); + yv[2].v = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); + yv[3].v = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); + + // y := beta * y + x + yv[0].v = _mm256_fmadd_pd + ( + betav.v, + _mm256_loadu_pd( y0 + 0*n_elem_per_reg ), + yv[0].v + ); + yv[1].v = _mm256_fmadd_pd + ( + betav.v, + _mm256_loadu_pd( y0 + 1*n_elem_per_reg ), + yv[1].v + ); + yv[2].v = _mm256_fmadd_pd + ( + betav.v, + _mm256_loadu_pd( y0 + 2*n_elem_per_reg ), + yv[2].v + ); + yv[3].v = _mm256_fmadd_pd + ( + betav.v, + _mm256_loadu_pd( y0 + 3*n_elem_per_reg ), + yv[3].v + ); + + // Storing the output + _mm256_storeu_pd( ( y0 + 0*n_elem_per_reg ), yv[0].v ); + _mm256_storeu_pd( ( y0 + 1*n_elem_per_reg ), yv[1].v ); + _mm256_storeu_pd( ( y0 + 2*n_elem_per_reg ), yv[2].v ); + _mm256_storeu_pd( ( y0 + 3*n_elem_per_reg ), yv[3].v ); + + x0 += n_elem_per_reg * n_iter_unroll; + y0 += n_elem_per_reg * n_iter_unroll; + } + } + else + { + // Broadcasting alpha onto a YMM register + alphav.v = _mm256_broadcast_sd( alpha ); + + for ( ; ( i + 15 ) < n; i += 16 ) + { + // loading input values + yv[0].v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1].v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + yv[2].v = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); + yv[3].v = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); + + // y' := beta * y + yv[0].v = _mm256_mul_pd( betav.v, yv[0].v ); + yv[1].v = _mm256_mul_pd( betav.v, yv[1].v ); + yv[2].v = _mm256_mul_pd( betav.v, yv[2].v ); + yv[3].v = _mm256_mul_pd( betav.v, yv[3].v ); + + // y := y' + alpha * x + yv[0].v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 0*n_elem_per_reg ), + yv[0].v + ); + yv[1].v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 1*n_elem_per_reg ), + yv[1].v + ); + yv[2].v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 2*n_elem_per_reg ), + yv[2].v + ); + yv[3].v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 3*n_elem_per_reg ), + yv[3].v + ); + + // storing the output + _mm256_storeu_pd( ( y0 + 0*n_elem_per_reg ), yv[0].v ); + _mm256_storeu_pd( ( y0 + 1*n_elem_per_reg ), yv[1].v ); + _mm256_storeu_pd( ( y0 + 2*n_elem_per_reg ), yv[2].v ); + _mm256_storeu_pd( ( y0 + 3*n_elem_per_reg ), yv[3].v ); + + x0 += n_elem_per_reg * n_iter_unroll; + y0 += n_elem_per_reg * n_iter_unroll; + } } // Issue vzeroupper instruction to clear upper lanes of ymm registers. @@ -286,11 +537,14 @@ void bli_daxpbyv_zen_int // as the n_left cleanup loop below if BLIS is compiled with // -mfpmath=sse). _mm256_zeroupper(); + } - // if there are leftover iterations, perform them with scaler code + // Handling fringe cases or non-unit strided inputs + if( is_alpha_one ) + { for ( ; i < n; ++i ) { - *y0 = ( (*alpha) * (*x0) ) + ( (*beta) * (*y0) ); + *y0 = (*beta) * (*y0) + (*x0); x0 += incx; y0 += incy; @@ -298,15 +552,16 @@ void bli_daxpbyv_zen_int } else { - // for non-unit increments, use scaler code - for ( i = 0; i < n; ++i ) + for ( ; i < n; ++i ) { - *y0 = ( (*alpha) * (*x0) ) + ( (*beta) * (*y0) ); + *y0 = (*beta) * (*y0) + (*alpha) * (*x0); x0 += incx; y0 += incy; } } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) } /** @@ -314,7 +569,7 @@ void bli_daxpbyv_zen_int * y := beta * y + alpha * conjx(x) * where, * x & y are simple complex vectors of length n. - * alpha & beta are scalers. + * alpha & beta are scalars. */ void bli_caxpbyv_zen_int ( @@ -328,390 +583,667 @@ void bli_caxpbyv_zen_int ) { AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_4) - const dim_t n_elem_per_reg = 8; // number of elements per register - dim_t i; // iterator + // Redirecting to other L1 kernels based on alpha and beta values + // If alpha is 0, we call CSCALV + // This kernel would further reroute based on few other combinations + // of alpha and beta. They are as follows : + // When alpha = 0 : + // When beta = 0 --> CSETV + // When beta = 1 --> Early return + // When beta = !( 0 or 1 ) --> CSCALV + if ( bli_ceq0( *alpha ) ) + { + bli_cscalv_zen_int + ( + BLIS_NO_CONJUGATE, + n, + beta, + y, incy, + cntx + ); - float* restrict x0; - float* restrict y0; + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) + return; + } - float alphaR, alphaI, betaR, betaI; + // If beta is 0, we call CSCAL2V + // This kernel would further reroute based on few other combinations + // of alpha and beta. They are as follows : + // When beta = 0 : + // When alpha = 0 --> CSETV + // When alpha = 1 --> CCOPYV + // When alpha = !( 0 or 1 ) --> CSCAL2V + else if ( bli_ceq0( *beta ) ) + { + bli_cscal2v_zen_int + ( + conjx, + n, + alpha, + x, incx, + y, incy, + cntx + ); - __m256 alphaRv; - __m256 alphaIv; - __m256 betaRv; - __m256 betaIv; - __m256 xv[4]; - __m256 yv[4]; - __m256 iv[4]; // intermediate registers + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) + return; + } - conj_t conjx_use = conjx; - - /* if the vector dimension is zero, or if alpha & beta are zero, - return early. */ - if ( bli_zero_dim1( n ) || - ( PASTEMAC( c, eq0 )( *alpha ) && PASTEMAC( c, eq0 )( *beta ) ) ) + // If beta is 1, we have 2 scenarios for rerouting + // When alpha = 1 --> CADDV + // When alpha = !( 0 or 1 ) --> CAXPYV + else if ( bli_ceq1( *beta ) ) { + if( bli_ceq1( *alpha ) ) + { + bli_caddv_zen_int + ( + conjx, + n, + x, incx, + y, incy, + cntx + ); + } + else + { + bli_caxpyv_zen_int5 + ( + conjx, + n, + alpha, + x, incx, + y, incy, + cntx + ); + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) return; } - // initialize local pointers - x0 = ( float* ) x; - y0 = ( float* ) y; + dim_t i = 0; // iterator + + // Local pointers to x and y vectors + float* restrict x0; + float* restrict y0; + + // Boolean to check if alpha is 1 + bool is_alpha_one = bli_ceq1( *alpha ); + + // Variables to store real and imaginary components of alpha and beta + float alphaR, alphaI, betaR, betaI; + + // Initializing the local pointers + x0 = ( float* ) x; + y0 = ( float* ) y; alphaR = alpha->real; alphaI = alpha->imag; betaR = beta->real; betaI = beta->imag; + // In case of unit strides for x and y vectors if ( incx == 1 && incy == 1 ) { - //---------- Scalar algorithm BLIS_NO_CONJUGATE ------------- - // y = beta*y + alpha*x - // y = ( bR + ibI ) * ( yR + iyI ) + ( aR + iaI ) * ( xR + ixI ) - // y = bR.yR + ibR.yI + ibI.yR - ibIyI + aR.xR + iaR.xI + iaI.xR - aI.xI - // y = ( bR.yR - bI.yI + aR.xR - aI.xI ) + - // i ( bR.yI + bI.yR + aR.xI + aI.xR ) - - // SIMD Algorithm BLIS_NO_CONJUGATE - // yv = yR1 yI1 yR2 yI2 yR3 yI3 yR4 yI4 - // yv' = yI1 yR1 yI2 yR2 yI3 yR3 yI4 yR4 - // xv = xR1 xI1 xR2 xI2 xR3 xI3 xR4 xI4 - // xv' = xI1 xR1 xI2 xR2 xI3 xR3 xI4 xR4 - // arv = aR aR aR aR aR aR aR aR - // aiv = -aI aI -aI aI -aI aI -aI aI - // brv = bR bR bR bR bR bR bR bR - // biv = -bI bI -bI bI -bI bI -bI bI - - // step 1: iv = brv * iv - // step 2: shuffle yv -> yv' - // step 3: FMA yv = biv * yv' + iv - // step 4: iv = arv * xv - // step 5: shuffle xv -> xv' - // step 6: FMA yv = aiv * xv' + iv - - //---------- Scalar algorithm BLIS_CONJUGATE ------------- - // y = beta*y + alpha*conj(x) - // y = ( bR + ibI ) * ( yR + iyI ) + ( aR + iaI ) * ( xR - ixI ) - // y = bR.yR + ibR.yI + ibI.yR - bI.yI + aR.xR - iaR.xI + iaI.xR + aI.xI - // y = ( bR.yR - bI.yI + aR.xR + aI.xI ) + - // i ( bR.yI + bI.yR - aR.xI + aI.xR ) - - // SIMD Algorithm BLIS_CONJUGATE - // yv = yR1 yI1 yR2 yI2 yR3 yI3 yR4 yI4 - // yv' = yI1 yR1 yI2 yR2 yI3 yR3 yI4 yR4 - // xv = xR1 xI1 xR2 xI2 xR3 xI3 xR4 xI4 - // xv' = xI1 xR1 xI2 xR2 xI3 xR3 xI4 xR4 - // arv = aR -aR aR -aR aR -aR aR -aR - // aiv = aI aI aI aI aI aI aI aI - // brv = bR bR bR bR bR bR bR bR - // biv = -bI bI -bI bI -bI bI -bI bI - // - // step 1: iv = brv * iv - // step 2: shuffle yv -> yv' - // step 3: FMA yv = biv * yv' + iv - // step 4: iv = arv * xv - // step 5: shuffle xv -> xv' - // step 6: FMA yv = aiv * xv' + iv - - // broadcast alpha & beta to all elements of respective vector registers - if ( !bli_is_conj( conjx ) ) // If BLIS_NO_CONJUGATE + // Number of float precision elements in a YMM register + const dim_t n_elem_per_reg = 8; + + // Scratch registers + __m256 xv[4]; + __m256 yv[4]; + __m256 iv[4]; + + // Vectors to store real and imaginary components of beta + __m256 betaRv, betaIv; + + // Broadcasting real and imaginary components of beta onto the registers + betaRv = _mm256_broadcast_ss( &betaR ); + betaIv = _mm256_broadcast_ss( &betaI ); + + if( is_alpha_one ) { - // alphaRv = aR aR aR aR aR aR aR aR - // alphaIv = -aI aI -aI aI -aI aI -aI aI - // betaRv = bR bR bR bR bR bR bR bR - // betaIv = -bI bI -bI bI -bI bI -bI bI - alphaRv = _mm256_broadcast_ss( &alphaR ); - alphaIv = _mm256_set_ps - ( - alphaI, -alphaI, alphaI, -alphaI, - alphaI, -alphaI, alphaI, -alphaI - ); - betaRv = _mm256_broadcast_ss( &betaR ); - betaIv = _mm256_set_ps - ( - betaI, -betaI, betaI, -betaI, - betaI, -betaI, betaI, -betaI - ); + __m256 reg_one = _mm256_set1_ps(1.0f); + iv[0] = _mm256_setzero_ps(); + + // Converting reg_one to have {1.0, -1.0, 1.0, -1.0, ...} + // This is needed in case we have t0 conjugate X vector + if( bli_is_conj( conjx ) ) + { + reg_one = _mm256_fmsubadd_ps( reg_one, iv[0], reg_one ); + } + // Processing 16 elements per loop, 8 FMAs + for ( ; ( i + 15 ) < n; i += 16 ) + { + // Load the y vector, 16 elements in total + // yv = yR1 yI1 yR2 yI2 ... + yv[0] = _mm256_loadu_ps( y0 ); + yv[1] = _mm256_loadu_ps( y0 + 1 * n_elem_per_reg ); + yv[2] = _mm256_loadu_ps( y0 + 2 * n_elem_per_reg ); + yv[3] = _mm256_loadu_ps( y0 + 3 * n_elem_per_reg ); + + // Load the x vector, 16 elements in total + // xv = xR1 xI1 xR2 xI2 ... + xv[0] = _mm256_loadu_ps( x0 ); + xv[1] = _mm256_loadu_ps( x0 + 1 * n_elem_per_reg ); + xv[2] = _mm256_loadu_ps( x0 + 2 * n_elem_per_reg ); + xv[3] = _mm256_loadu_ps( x0 + 3 * n_elem_per_reg ); + + // Permute the vectors from y for the required compute + // iv = yI1 yR1 yI2 yR2 ... + iv[0] = _mm256_permute_ps( yv[0], 0xB1 ); + iv[1] = _mm256_permute_ps( yv[1], 0xB1 ); + iv[2] = _mm256_permute_ps( yv[2], 0xB1 ); + iv[3] = _mm256_permute_ps( yv[3], 0xB1 ); + + // Scale the permuted vectors with imaginary component of beta + // iv = betaIv * yv + // = yI1.bI, yR1.bI, yI2.bI, yR2.bI, ... + iv[0] = _mm256_mul_ps( betaIv, iv[0] ); + iv[1] = _mm256_mul_ps( betaIv, iv[1] ); + iv[2] = _mm256_mul_ps( betaIv, iv[2] ); + iv[3] = _mm256_mul_ps( betaIv, iv[3] ); + + // Using fmaddsub to scale with real component of beta + // and sub/add to iv + // yv = betaRv * yv -/+ iv + // = yR1.bR - yI1.bI, yI1.bR + yR1.bI, ... + yv[0] = _mm256_fmaddsub_ps( betaRv, yv[0], iv[0] ); + yv[1] = _mm256_fmaddsub_ps( betaRv, yv[1], iv[1] ); + yv[2] = _mm256_fmaddsub_ps( betaRv, yv[2], iv[2] ); + yv[3] = _mm256_fmaddsub_ps( betaRv, yv[3], iv[3] ); + + // Adding X conjugate to it + yv[0] = _mm256_fmadd_ps( reg_one, xv[0], yv[0] ); + yv[1] = _mm256_fmadd_ps( reg_one, xv[1], yv[1] ); + yv[2] = _mm256_fmadd_ps( reg_one, xv[2], yv[2] ); + yv[3] = _mm256_fmadd_ps( reg_one, xv[3], yv[3] ); + + // Storing the result to memory + _mm256_storeu_ps( ( y0 ), yv[0] ); + _mm256_storeu_ps( ( y0 + 1 * n_elem_per_reg ), yv[1] ); + _mm256_storeu_ps( ( y0 + 2 * n_elem_per_reg ), yv[2] ); + _mm256_storeu_ps( ( y0 + 3 * n_elem_per_reg ), yv[3] ); + + // Adjusting the pointers for the next iteration + y0 += 4 * n_elem_per_reg; + x0 += 4 * n_elem_per_reg; + } + + // Processing 12 elements per loop, 12 FMAs + for ( ; ( i + 11 ) < n; i += 12 ) + { + // Load the y vector, 12 elements in total + // yv = yR1 yI1 yR2 yI2 ... + yv[0] = _mm256_loadu_ps( y0 ); + yv[1] = _mm256_loadu_ps( y0 + 1 * n_elem_per_reg ); + yv[2] = _mm256_loadu_ps( y0 + 2 * n_elem_per_reg ); + + // Load the x vector, 12 elements in total + // xv = xR1 xI1 xR2 xI2 + xv[0] = _mm256_loadu_ps( x0 ); + xv[1] = _mm256_loadu_ps( x0 + 1 * n_elem_per_reg ); + xv[2] = _mm256_loadu_ps( x0 + 2 * n_elem_per_reg ); + + // Permute the vectors from y for the required compute + // iv = yI1 yR1 yI2 yR2 ... + iv[0] = _mm256_permute_ps( yv[0], 0xB1 ); + iv[1] = _mm256_permute_ps( yv[1], 0xB1 ); + iv[2] = _mm256_permute_ps( yv[2], 0xB1 ); + + // Scale the permuted vectors with imaginary component of beta + // iv = betaIv * yv + // = yI1.bI, yR1.bI, yI2.bI, yR2.bI, ... + iv[0] = _mm256_mul_ps( betaIv, iv[0] ); + iv[1] = _mm256_mul_ps( betaIv, iv[1] ); + iv[2] = _mm256_mul_ps( betaIv, iv[2] ); + + // Using fmaddsub to scale with real component of beta + // and sub/add to iv + // yv = betaRv * yv -/+ iv + // = yR1.bR - yI1.bI, yI1.bR + yR1.bI, ... + yv[0] = _mm256_fmaddsub_ps( betaRv, yv[0], iv[0] ); + yv[1] = _mm256_fmaddsub_ps( betaRv, yv[1], iv[1] ); + yv[2] = _mm256_fmaddsub_ps( betaRv, yv[2], iv[2] ); + + // Adding X conjugate to it + yv[0] = _mm256_fmadd_ps( reg_one, xv[0], yv[0] ); + yv[1] = _mm256_fmadd_ps( reg_one, xv[1], yv[1] ); + yv[2] = _mm256_fmadd_ps( reg_one, xv[2], yv[2] ); + + // Storing the result to memory + _mm256_storeu_ps( ( y0 ), yv[0] ); + _mm256_storeu_ps( ( y0 + 1 * n_elem_per_reg ), yv[1] ); + _mm256_storeu_ps( ( y0 + 2 * n_elem_per_reg ), yv[2] ); + + // Adjusting the pointers for the next iteration + y0 += 3 * n_elem_per_reg; + x0 += 3 * n_elem_per_reg; + } + + // Processing 8 elements per loop, 8 FMAs + for ( ; ( i + 7 ) < n; i += 8 ) + { + // Load the y vector, 8 elements in total + // yv = yR1 yI1 yR2 yI2 ... + yv[0] = _mm256_loadu_ps( y0 ); + yv[1] = _mm256_loadu_ps( y0 + 1 * n_elem_per_reg ); + + // Load the x vector, 8 elements in total + // xv = xR1 xI1 xR2 xI2 ... + xv[0] = _mm256_loadu_ps( x0 ); + xv[1] = _mm256_loadu_ps( x0 + 1 * n_elem_per_reg ); + + // Permute the vectors from y for the required compute + // iv = yI1 yR1 yI2 yR2 + iv[0] = _mm256_permute_ps( yv[0], 0xB1 ); + iv[1] = _mm256_permute_ps( yv[1], 0xB1 ); + + // Scale the permuted vectors with imaginary component of beta + // iv = betaIv * yv + // = yI1.bI, yR1.bI, yI2.bI, yR2.bI, ... + iv[0] = _mm256_mul_ps( betaIv, iv[0] ); + iv[1] = _mm256_mul_ps( betaIv, iv[1] ); + + // Using fmaddsub to scale with real component of beta + // and sub/add to iv + // yv = betaRv * yv -/+ iv + // = yR1.bR - yI1.bI, yI1.bR + yR1.bI, ... + yv[0] = _mm256_fmaddsub_ps( betaRv, yv[0], iv[0] ); + yv[1] = _mm256_fmaddsub_ps( betaRv, yv[1], iv[1] ); + + // Adding X conjugate to it + yv[0] = _mm256_fmadd_ps( reg_one, xv[0], yv[0] ); + yv[1] = _mm256_fmadd_ps( reg_one, xv[1], yv[1] ); + + // Storing the result to memory + _mm256_storeu_ps( ( y0 ), yv[0] ); + _mm256_storeu_ps( ( y0 + 1 * n_elem_per_reg ), yv[1] ); + + // Adjusting the pointers for the next iteration + y0 += 2 * n_elem_per_reg; + x0 += 2 * n_elem_per_reg; + } + + // Processing 4 elements per loop, 4 FMAs + for ( ; ( i + 3 ) < n; i += 4 ) + { + // Load the y vector, 4 elements in total + // yv = yR1 yI1 yR2 yI2 ... + yv[0] = _mm256_loadu_ps( y0 ); + + // Load the x vector, 4 elements in total + // xv = xR1 xI1 xR2 xI2 ... + xv[0] = _mm256_loadu_ps( x0 ); + + // Permute the vectors from y for the required compute + // iv = yI1 yR1 yI2 yR2 ... + iv[0] = _mm256_permute_ps( yv[0], 0xB1 ); + + // Scale the permuted vectors with imaginary component of beta + // iv = betaIv * yv + // = yI1.bI, yR1.bI, yI2.bI, yR2.bI, ... + iv[0] = _mm256_mul_ps( betaIv, iv[0] ); + + // Using fmaddsub to scale with real component of beta + // and sub/add to iv + // yv = betaRv * yv -/+ iv + // = yR1.bR - yI1.bI, yI1.bR + yR1.bI, ... + yv[0] = _mm256_fmaddsub_ps( betaRv, yv[0], iv[0] ); + + // Adding X conjugate to it + yv[0] = _mm256_fmadd_ps( reg_one, xv[0], yv[0] ); + + // Storing the result to memory + _mm256_storeu_ps( ( y0 ), yv[0] ); + + // Adjusting the pointers for the next iteration + y0 += 1 * n_elem_per_reg; + x0 += 1 * n_elem_per_reg; + } } else { - // alphaRv = aR -aR aR -aR aR -aR aR -aR - // alphaIv = aI aI aI aI aI aI aI aI - // betaRv = bR bR bR bR bR bR bR bR - // betaIv = -bI bI -bI bI -bI bI -bI bI - alphaRv = _mm256_set_ps - ( - -alphaR, alphaR, -alphaR, alphaR, - -alphaR, alphaR, -alphaR, alphaR - ); + // Scratch registers for storing real and imaginary components of alpha + __m256 alphaRv, alphaIv; + + iv[0] = _mm256_setzero_ps(); + + alphaRv = _mm256_broadcast_ss( &alphaR ); alphaIv = _mm256_broadcast_ss( &alphaI ); - betaRv = _mm256_broadcast_ss( &betaR ); - betaIv = _mm256_set_ps - ( - betaI, -betaI, betaI, -betaI, - betaI, -betaI, betaI, -betaI - ); - } - // Processing 16 elements per loop, 8 FMAs - for ( i = 0; ( i + 15 ) < n; i += 16 ) - { - // xv = xR1 xI1 xR2 xI2 xR3 xI3 xR4 xI4 - xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); - xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); - xv[2] = _mm256_loadu_ps( x0 + 2*n_elem_per_reg ); - xv[3] = _mm256_loadu_ps( x0 + 3*n_elem_per_reg ); - - // yv = yR1 yI1 yR2 yI2 yR3 yI3 yR4 yI4 - yv[0] = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); - yv[1] = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); - yv[2] = _mm256_loadu_ps( y0 + 2*n_elem_per_reg ); - yv[3] = _mm256_loadu_ps( y0 + 3*n_elem_per_reg ); - - // iv = betaRv * yv - // = yR1.bR, yI1.bR, yR2.bR, yI2.bR, ... - iv[0] = _mm256_mul_ps( betaRv, yv[0] ); - iv[1] = _mm256_mul_ps( betaRv, yv[1] ); - iv[2] = _mm256_mul_ps( betaRv, yv[2] ); - iv[3] = _mm256_mul_ps( betaRv, yv[3] ); - - // yv' = yI1 yR1 yI2 yR2 yI3 yR3 yI4 yR4 - yv[0] = _mm256_permute_ps( yv[0], 0xB1); - yv[1] = _mm256_permute_ps( yv[1], 0xB1); - yv[2] = _mm256_permute_ps( yv[2], 0xB1); - yv[3] = _mm256_permute_ps( yv[3], 0xB1); - - // yv = betaIv * yv' + iv - // = yR1.bR - yI1.bI, yI1.bR + yR1.bI, ... - yv[0] = _mm256_fmadd_ps( betaIv, yv[0], iv[0] ); - yv[1] = _mm256_fmadd_ps( betaIv, yv[1], iv[1] ); - yv[2] = _mm256_fmadd_ps( betaIv, yv[2], iv[2] ); - yv[3] = _mm256_fmadd_ps( betaIv, yv[3], iv[3] ); - - // iv = alphaRv * xv - // = xR1.aR, xI1.aR, xR2.aR, xI2.aR, ... - iv[0] = _mm256_mul_ps( alphaRv, xv[0] ); - iv[1] = _mm256_mul_ps( alphaRv, xv[1] ); - iv[2] = _mm256_mul_ps( alphaRv, xv[2] ); - iv[3] = _mm256_mul_ps( alphaRv, xv[3] ); - - // xv' = xI1 xR1 xI2 xR2 xI3 xR3 xI4 xR4 - xv[0] = _mm256_permute_ps( xv[0], 0xB1); - xv[1] = _mm256_permute_ps( xv[1], 0xB1); - xv[2] = _mm256_permute_ps( xv[2], 0xB1); - xv[3] = _mm256_permute_ps( xv[3], 0xB1); - - // yv = alphaIv * xv + yv - // = yR1.bR - yR1.bI - xR1.aI, yI1.bR + yI1.bI + xI1.aI, ... - iv[0] = _mm256_fmadd_ps( alphaIv, xv[0], iv[0] ); - iv[1] = _mm256_fmadd_ps( alphaIv, xv[1], iv[1] ); - iv[2] = _mm256_fmadd_ps( alphaIv, xv[2], iv[2] ); - iv[3] = _mm256_fmadd_ps( alphaIv, xv[3], iv[3] ); - - yv[0] = _mm256_add_ps( yv[0], iv[0] ); - yv[1] = _mm256_add_ps( yv[1], iv[1] ); - yv[2] = _mm256_add_ps( yv[2], iv[2] ); - yv[3] = _mm256_add_ps( yv[3], iv[3] ); - - _mm256_storeu_ps( (y0 + 0*n_elem_per_reg), yv[0] ); - _mm256_storeu_ps( (y0 + 1*n_elem_per_reg), yv[1] ); - _mm256_storeu_ps( (y0 + 2*n_elem_per_reg), yv[2] ); - _mm256_storeu_ps( (y0 + 3*n_elem_per_reg), yv[3] ); - - y0 += 4*n_elem_per_reg; - x0 += 4*n_elem_per_reg; - } + // The changes on alphaRv and alphaIv are as follows : + // If conjugate is required: + // alphaRv = aR -aR aR -aR + // Else : + // alphaIv = -aI aI -aI aI + if( bli_is_conj( conjx ) ) + { + alphaRv = _mm256_fmsubadd_ps( iv[0], iv[0], alphaRv ); + } + else + { + alphaIv = _mm256_addsub_ps( iv[0], alphaIv ); + } - // Processing 12 elements per loop, 6 FMAs - for ( ; ( i + 11 ) < n; i += 12 ) - { - // xv = xR1 xI1 xR2 xI2 xR3 xI3 xR4 xI4 - xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); - xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); - xv[2] = _mm256_loadu_ps( x0 + 2*n_elem_per_reg ); - - // yv = yR1 yI1 yR2 yI2 yR3 yI3 yR4 yI4 - yv[0] = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); - yv[1] = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); - yv[2] = _mm256_loadu_ps( y0 + 2*n_elem_per_reg ); - - // iv = betaRv * yv - // = yR1.bR, yI1.bR, yR2.bR, yI2.bR, ... - iv[0] = _mm256_mul_ps( betaRv, yv[0] ); - iv[1] = _mm256_mul_ps( betaRv, yv[1] ); - iv[2] = _mm256_mul_ps( betaRv, yv[2] ); - - // yv' = yI1 yR1 yI2 yR2 yI3 yR3 yI4 yR4 - yv[0] = _mm256_permute_ps( yv[0], 0xB1); - yv[1] = _mm256_permute_ps( yv[1], 0xB1); - yv[2] = _mm256_permute_ps( yv[2], 0xB1); - - // yv = betaIv * yv' + iv - // = yR1.bR - yI1.bI, yI1.bR + yR1.bI, ... - yv[0] = _mm256_fmadd_ps( betaIv, yv[0], iv[0] ); - yv[1] = _mm256_fmadd_ps( betaIv, yv[1], iv[1] ); - yv[2] = _mm256_fmadd_ps( betaIv, yv[2], iv[2] ); - - // iv = alphaRv * xv - // = xR1.aR, xI1.aR, xR2.aR, xI2.aR, ... - iv[0] = _mm256_mul_ps( alphaRv, xv[0] ); - iv[1] = _mm256_mul_ps( alphaRv, xv[1] ); - iv[2] = _mm256_mul_ps( alphaRv, xv[2] ); - - // xv' = xI1 xR1 xI2 xR2 xI3 xR3 xI4 xR4 - xv[0] = _mm256_permute_ps( xv[0], 0xB1); - xv[1] = _mm256_permute_ps( xv[1], 0xB1); - xv[2] = _mm256_permute_ps( xv[2], 0xB1); - - // yv = alphaIv * xv + yv - // = yR1.bR - yR1.bI - xR1.aI, yI1.bR + yI1.bI + xI1.aI, ... - iv[0] = _mm256_fmadd_ps( alphaIv, xv[0], iv[0] ); - iv[1] = _mm256_fmadd_ps( alphaIv, xv[1], iv[1] ); - iv[2] = _mm256_fmadd_ps( alphaIv, xv[2], iv[2] ); - - yv[0] = _mm256_add_ps( yv[0], iv[0] ); - yv[1] = _mm256_add_ps( yv[1], iv[1] ); - yv[2] = _mm256_add_ps( yv[2], iv[2] ); - - _mm256_storeu_ps( (y0 + 0*n_elem_per_reg), yv[0] ); - _mm256_storeu_ps( (y0 + 1*n_elem_per_reg), yv[1] ); - _mm256_storeu_ps( (y0 + 2*n_elem_per_reg), yv[2] ); - - y0 += 3*n_elem_per_reg; - x0 += 3*n_elem_per_reg; - } + // Processing 16 elements per loop, 16 FMAs + for ( i = 0; ( i + 15 ) < n; i += 16 ) + { + // Load the y vector, 16 elements in total + // yv = yR1 yI1 yR2 yI2 ... + yv[0] = _mm256_loadu_ps( y0 ); + yv[1] = _mm256_loadu_ps( y0 + 1 * n_elem_per_reg ); + yv[2] = _mm256_loadu_ps( y0 + 2 * n_elem_per_reg ); + yv[3] = _mm256_loadu_ps( y0 + 3 * n_elem_per_reg ); + + // Load the x vector, 16 elements in total + // xv = xR1 xI1 xR2 xI2 ... + xv[0] = _mm256_loadu_ps( x0 ); + xv[1] = _mm256_loadu_ps( x0 + 1 * n_elem_per_reg ); + xv[2] = _mm256_loadu_ps( x0 + 2 * n_elem_per_reg ); + xv[3] = _mm256_loadu_ps( x0 + 3 * n_elem_per_reg ); - // Processing 16 elements per loop, 8 FMAs - for ( ; ( i + 7 ) < n; i += 8 ) - { - // xv = xR1 xI1 xR2 xI2 xR3 xI3 xR4 xI4 - xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); - xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); - - // yv = yR1 yI1 yR2 yI2 yR3 yI3 yR4 yI4 - yv[0] = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); - yv[1] = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); - - // iv = betaRv * yv - // = yR1.bR, yI1.bR, yR2.bR, yI2.bR, ... - iv[0] = _mm256_mul_ps( betaRv, yv[0] ); - iv[1] = _mm256_mul_ps( betaRv, yv[1] ); - - // yv' = yI1 yR1 yI2 yR2 yI3 yR3 yI4 yR4 - yv[0] = _mm256_permute_ps( yv[0], 0xB1); - yv[1] = _mm256_permute_ps( yv[1], 0xB1); - - // yv = betaIv * yv' + iv - // = yR1.bR - yI1.bI, yI1.bR + yR1.bI, ... - yv[0] = _mm256_fmadd_ps( betaIv, yv[0], iv[0] ); - yv[1] = _mm256_fmadd_ps( betaIv, yv[1], iv[1] ); - - // iv = alphaRv * xv - // = xR1.aR, xI1.aR, xR2.aR, xI2.aR, ... - iv[0] = _mm256_mul_ps( alphaRv, xv[0] ); - iv[1] = _mm256_mul_ps( alphaRv, xv[1] ); - - // xv' = xI1 xR1 xI2 xR2 xI3 xR3 xI4 xR4 - xv[0] = _mm256_permute_ps( xv[0], 0xB1); - xv[1] = _mm256_permute_ps( xv[1], 0xB1); - - // yv = alphaIv * xv + yv - // = yR1.bR - yR1.bI - xR1.aI, yI1.bR + yI1.bI + xI1.aI, ... - iv[0] = _mm256_fmadd_ps( alphaIv, xv[0], iv[0] ); - iv[1] = _mm256_fmadd_ps( alphaIv, xv[1], iv[1] ); - - yv[0] = _mm256_add_ps( yv[0], iv[0] ); - yv[1] = _mm256_add_ps( yv[1], iv[1] ); - - _mm256_storeu_ps( (y0 + 0*n_elem_per_reg), yv[0] ); - _mm256_storeu_ps( (y0 + 1*n_elem_per_reg), yv[1] ); - - y0 += 2*n_elem_per_reg; - x0 += 2*n_elem_per_reg; + // Permute the vectors from y for the required compute + // iv = yI1 yR1 yI2 yR2 ... + iv[0] = _mm256_permute_ps( yv[0], 0xB1 ); + iv[1] = _mm256_permute_ps( yv[1], 0xB1 ); + iv[2] = _mm256_permute_ps( yv[2], 0xB1 ); + iv[3] = _mm256_permute_ps( yv[3], 0xB1 ); + + // Scale the permuted vectors with imaginary component of beta + // iv = betaIv * yv + // = yI1.bI, yR1.bI, yI2.bI, yR2.bI, ... + iv[0] = _mm256_mul_ps( betaIv, iv[0] ); + iv[1] = _mm256_mul_ps( betaIv, iv[1] ); + iv[2] = _mm256_mul_ps( betaIv, iv[2] ); + iv[3] = _mm256_mul_ps( betaIv, iv[3] ); + + // Using fmaddsub to scale with real component of beta + // and sub/add to iv + // yv = betaRv * yv -/+ iv + // = yR1.bR - yI1.bI, yI1.bR + yR1.bI, ... + yv[0] = _mm256_fmaddsub_ps( betaRv, yv[0], iv[0] ); + yv[1] = _mm256_fmaddsub_ps( betaRv, yv[1], iv[1] ); + yv[2] = _mm256_fmaddsub_ps( betaRv, yv[2], iv[2] ); + yv[3] = _mm256_fmaddsub_ps( betaRv, yv[3], iv[3] ); + + // Permute the loaded vectors from x for the required compute + // xv' = xI1 xR1 xI2 xR2 ... + iv[0] = _mm256_permute_ps( xv[0], 0xB1 ); + iv[1] = _mm256_permute_ps( xv[1], 0xB1 ); + iv[2] = _mm256_permute_ps( xv[2], 0xB1 ); + iv[3] = _mm256_permute_ps( xv[3], 0xB1 ); + + // yv = alphaRv * xv + yv + // = yR1.bR - yR1.bI + xR1.aR, yI1.bR + yI1.bI + xI1.aR, ... + yv[0] = _mm256_fmadd_ps( alphaRv, xv[0], yv[0] ); + yv[1] = _mm256_fmadd_ps( alphaRv, xv[1], yv[1] ); + yv[2] = _mm256_fmadd_ps( alphaRv, xv[2], yv[2] ); + yv[3] = _mm256_fmadd_ps( alphaRv, xv[3], yv[3] ); + + // yv = alphaIv * iv + yv + // = yR1.bR - yR1.bI - xI1.aI, yI1.bR + yI1.bI + xR1.aI, ... + yv[0] = _mm256_fmadd_ps( alphaIv, iv[0], yv[0] ); + yv[1] = _mm256_fmadd_ps( alphaIv, iv[1], yv[1] ); + yv[2] = _mm256_fmadd_ps( alphaIv, iv[2], yv[2] ); + yv[3] = _mm256_fmadd_ps( alphaIv, iv[3], yv[3] ); + + // Storing the result to memory + _mm256_storeu_ps( ( y0 ), yv[0] ); + _mm256_storeu_ps( ( y0 + 1 * n_elem_per_reg ), yv[1] ); + _mm256_storeu_ps( ( y0 + 2 * n_elem_per_reg ), yv[2] ); + _mm256_storeu_ps( ( y0 + 3 * n_elem_per_reg ), yv[3] ); + + // Adjusting the pointers for the next iteration + y0 += 4 * n_elem_per_reg; + x0 += 4 * n_elem_per_reg; + } + + // Processing 12 elements per loop, 12 FMAs + for ( ; ( i + 11 ) < n; i += 12 ) + { + // Load the y vector, 12 elements in total + // yv = yR1 yI1 yR2 yI2 ... + yv[0] = _mm256_loadu_ps( y0 ); + yv[1] = _mm256_loadu_ps( y0 + 1 * n_elem_per_reg ); + yv[2] = _mm256_loadu_ps( y0 + 2 * n_elem_per_reg ); + + // Load the x vector, 12 elements in total + // xv = xR1 xI1 xR2 xI2 ... + xv[0] = _mm256_loadu_ps( x0 ); + xv[1] = _mm256_loadu_ps( x0 + 1 * n_elem_per_reg ); + xv[2] = _mm256_loadu_ps( x0 + 2 * n_elem_per_reg ); + + // Permute the vectors from y for the required compute + // iv = yI1 yR1 yI2 yR2 ... + iv[0] = _mm256_permute_ps( yv[0], 0xB1 ); + iv[1] = _mm256_permute_ps( yv[1], 0xB1 ); + iv[2] = _mm256_permute_ps( yv[2], 0xB1 ); + + // Scale the permuted vectors with imaginary component of beta + // iv = betaIv * yv + // = yI1.bI, yR1.bI, yI2.bI, yR2.bI, ...` + iv[0] = _mm256_mul_ps( betaIv, iv[0] ); + iv[1] = _mm256_mul_ps( betaIv, iv[1] ); + iv[2] = _mm256_mul_ps( betaIv, iv[2] ); + + // Using fmaddsub to scale with real component of beta + // and sub/add to iv + // yv = betaRv * yv -/+ iv + // = yR1.bR - yI1.bI, yI1.bR + yR1.bI, ... + yv[0] = _mm256_fmaddsub_ps( betaRv, yv[0], iv[0] ); + yv[1] = _mm256_fmaddsub_ps( betaRv, yv[1], iv[1] ); + yv[2] = _mm256_fmaddsub_ps( betaRv, yv[2], iv[2] ); + + // Permute the loaded vectors from x for the required compute + // xv' = xI1 xR1 xI2 xR2 ... + iv[0] = _mm256_permute_ps( xv[0], 0xB1 ); + iv[1] = _mm256_permute_ps( xv[1], 0xB1 ); + iv[2] = _mm256_permute_ps( xv[2], 0xB1 ); + + // yv = alphaRv * xv + yv + // = yR1.bR - yR1.bI + xR1.aR, yI1.bR + yI1.bI + xI1.aR, ... + yv[0] = _mm256_fmadd_ps( alphaRv, xv[0], yv[0] ); + yv[1] = _mm256_fmadd_ps( alphaRv, xv[1], yv[1] ); + yv[2] = _mm256_fmadd_ps( alphaRv, xv[2], yv[2] ); + + // yv = alphaIv * iv + yv + // = yR1.bR - yR1.bI - xI1.aI, yI1.bR + yI1.bI + xR1.aI, ... + yv[0] = _mm256_fmadd_ps( alphaIv, iv[0], yv[0] ); + yv[1] = _mm256_fmadd_ps( alphaIv, iv[1], yv[1] ); + yv[2] = _mm256_fmadd_ps( alphaIv, iv[2], yv[2] ); + + // Storing the result to memory + _mm256_storeu_ps( ( y0 ), yv[0] ); + _mm256_storeu_ps( ( y0 + 1 * n_elem_per_reg ), yv[1] ); + _mm256_storeu_ps( ( y0 + 2 * n_elem_per_reg ), yv[2] ); + + // Adjusting the pointers for the next iteration + y0 += 3 * n_elem_per_reg; + x0 += 3 * n_elem_per_reg; + } + + // Processing 8 elements per loop, 8 FMAs + for ( ; ( i + 7 ) < n; i += 8 ) + { + // Load the y vector, 8 elements in total + // yv = yR1 yI1 yR2 yI2 ... + yv[0] = _mm256_loadu_ps( y0 ); + yv[1] = _mm256_loadu_ps( y0 + 1 * n_elem_per_reg ); + + // Load the x vector, 8 elements in total + // xv = xR1 xI1 xR2 xI2 ... + xv[0] = _mm256_loadu_ps( x0 ); + xv[1] = _mm256_loadu_ps( x0 + 1 * n_elem_per_reg ); + + // Permute the vectors from y for the required compute + // iv = yI1 yR1 yI2 yR2 ... + iv[0] = _mm256_permute_ps( yv[0], 0xB1 ); + iv[1] = _mm256_permute_ps( yv[1], 0xB1 ); + + // Scale the permuted vectors with imaginary component of beta + // iv = betaIv * yv + // = yI1.bI, yR1.bI, yI2.bI, yR2.bI, ... + iv[0] = _mm256_mul_ps( betaIv, iv[0] ); + iv[1] = _mm256_mul_ps( betaIv, iv[1] ); + + // Using fmaddsub to scale with real component of beta + // and sub/add to iv + // yv = betaRv * yv -/+ iv + // = yR1.bR - yI1.bI, yI1.bR + yR1.bI, ... + yv[0] = _mm256_fmaddsub_ps( betaRv, yv[0], iv[0] ); + yv[1] = _mm256_fmaddsub_ps( betaRv, yv[1], iv[1] ); + + // Permute the loaded vectors from x for the required compute + // xv' = xI1 xR1 xI2 xR2 + iv[0] = _mm256_permute_ps( xv[0], 0xB1 ); + iv[1] = _mm256_permute_ps( xv[1], 0xB1 ); + + // yv = alphaRv * xv + yv + // = yR1.bR - yR1.bI + xR1.aR, yI1.bR + yI1.bI + xI1.aR, ... + yv[0] = _mm256_fmadd_ps( alphaRv, xv[0], yv[0] ); + yv[1] = _mm256_fmadd_ps( alphaRv, xv[1], yv[1] ); + + // yv = alphaIv * iv + yv + // = yR1.bR - yR1.bI - xI1.aI, yI1.bR + yI1.bI + xR1.aI, ... + yv[0] = _mm256_fmadd_ps( alphaIv, iv[0], yv[0] ); + yv[1] = _mm256_fmadd_ps( alphaIv, iv[1], yv[1] ); + + // Storing the result to memory + _mm256_storeu_ps( ( y0 ), yv[0] ); + _mm256_storeu_ps( ( y0 + 1 * n_elem_per_reg ), yv[1] ); + + // Adjusting the pointers for the next iteration + y0 += 2 * n_elem_per_reg; + x0 += 2 * n_elem_per_reg; + } + + // Processing 4 elements per loop, 4 FMAs + for ( ; ( i + 3 ) < n; i += 4 ) + { + // Load the y vector, 4 elements in total + // yv = yR1 yI1 yR2 yI2 ... + yv[0] = _mm256_loadu_ps( y0 ); + + // Load the x vector, 4 elements in total + // xv = xR1 xI1 xR2 xI2 ... + xv[0] = _mm256_loadu_ps( x0 ); + + // Permute the vectors from y for the required compute + // iv = yI1 yR1 yI2 yR2 ... + iv[0] = _mm256_permute_ps( yv[0], 0xB1 ); + + // Scale the permuted vectors with imaginary component of beta + // iv = betaIv * yv + // = yI1.bI, yR1.bI, yI2.bI, yR2.bI, ... + iv[0] = _mm256_mul_ps( betaIv, iv[0] ); + + // Using fmaddsub to scale with real component of beta + // and sub/add to iv + // yv = betaRv * yv -/+ iv + // = yR1.bR - yI1.bI, yI1.bR + yR1.bI, ... + yv[0] = _mm256_fmaddsub_ps( betaRv, yv[0], iv[0] ); + + // Permute the loaded vectors from x for the required compute + // xv' = xI1 xR1 xI2 xR2 ... + iv[0] = _mm256_permute_ps( xv[0], 0xB1 ); + + // yv = alphaRv * xv + yv + // = yR1.bR - yR1.bI + xR1.aR, yI1.bR + yI1.bI + xI1.aR, ... + yv[0] = _mm256_fmadd_ps( alphaRv, xv[0], yv[0] ); + + // yv = alphaIv * iv + yv + // = yR1.bR - yR1.bI - xI1.aI, yI1.bR + yI1.bI + xR1.aI, ... + yv[0] = _mm256_fmadd_ps( alphaIv, iv[0], yv[0] ); + + // Storing the result to memory + _mm256_storeu_ps( ( y0 ), yv[0] ); + + // Adjusting the pointers for the next iteration + y0 += 1 * n_elem_per_reg; + x0 += 1 * n_elem_per_reg; + } } // Issue vzeroupper instruction to clear upper lanes of ymm registers. // This avoids a performance penalty caused by false dependencies when - // transitioning from AVX to SSE instructions (which may occur as soon - // as the n_left cleanup loop below if BLIS is compiled with - // -mfpmath=sse). + // transitioning from AVX to SSE instructions. _mm256_zeroupper(); + } - if ( !bli_is_conj( conjx_use ) ) + // Handling fringe cases or non-unit-strides + if ( is_alpha_one ) + { + if( bli_is_conj( conjx ) ) { - for ( ; i < n ; ++i ) + for( ; i < n; i += 1 ) { - const float yRc = *y0; - const float yIc = *( y0 + 1 ); - - *y0 = ( betaR * yRc ) - ( betaI * yIc ) + - ( alphaR * (*x0) ) - ( alphaI * (*(x0 + 1)) ); - *(y0 + 1) = ( betaR * yIc ) + ( betaI * yRc ) + - ( alphaR * (*(x0 + 1)) ) + ( alphaI * (*x0) ); - - x0 += 2; - y0 += 2; + scomplex temp; + temp.real = ( betaR * (*y0) ) - ( betaI * (*(y0 + 1)) ) + (*x0); + temp.imag = ( betaR * (*(y0 + 1)) ) + ( betaI * (*y0) ) - (*(x0 + 1)); + + (*y0) = temp.real; + (*(y0 + 1)) = temp.imag; + + x0 += 2 * incx; + y0 += 2 * incy; } } else { - for ( ; i < n ; ++i ) + for( ; i < n; i += 1 ) { - const float yRc = *y0; - const float yIc = *( y0 + 1 ); - - *y0 = ( betaR * yRc ) - ( betaI * yIc ) + - ( alphaR * (*x0) ) + ( alphaI * (*(x0 + 1)) ); - *(y0 + 1) = ( betaR * yIc ) + ( betaI * yRc ) - - ( alphaR * (*(x0 + 1)) ) + ( alphaI * (*x0) ); - - x0 += 2; - y0 += 2; + scomplex temp; + temp.real = ( betaR * (*y0) ) - ( betaI * (*(y0 + 1)) ) + (*x0); + temp.imag = ( betaR * (*(y0 + 1)) ) + ( betaI * (*y0) ) + (*(x0 + 1)); + + (*y0) = temp.real; + (*(y0 + 1)) = temp.imag; + + x0 += 2 * incx; + y0 += 2 * incy; } } } else { - // for non-unit increments, use scaler code - if ( !bli_is_conj( conjx_use ) ) + if( bli_is_conj( conjx ) ) { - for ( i = 0; i < n ; ++i ) + for( ; i < n; i += 1 ) { - const float yRc = *y0; - const float yIc = *( y0 + 1 ); - - // yReal = ( bR.yR - bI.yI + aR.xR - aI.xI ) - *y0 = ( betaR * yRc ) - ( betaI * yIc ) + - ( alphaR * (*x0) ) - ( alphaI * (*(x0 + 1)) ); - // yImag = ( bR.yI + bI.yR + aR.xI + aI.xR ) - *(y0 + 1) = ( betaR * yIc ) + ( betaI * yRc ) + - ( alphaR * (*(x0 + 1)) ) + ( alphaI * (*x0) ); - - x0 += incx * 2; - y0 += incy * 2; + scomplex temp; + temp.real = ( betaR * (*y0) ) - ( betaI * (*(y0 + 1)) ) + + ( alphaR * (*x0) ) + ( alphaI * (*(x0 + 1)) ); + temp.imag = ( betaR * (*(y0 + 1)) ) + ( betaI * (*y0) ) - + ( alphaR * (*(x0 + 1)) ) + ( alphaI * (*x0) ); + + (*y0) = temp.real; + (*(y0 + 1)) = temp.imag; + + x0 += 2 * incx; + y0 += 2 * incy; } } else { - for ( i = 0; i < n ; ++i ) + for( ; i < n; i += 1 ) { - const float yRc = *y0; - const float yIc = *( y0 + 1 ); - - // yReal = ( bR.yR - bI.yI + aR.xR - aI.xI ) - *y0 = ( betaR * yRc ) - ( betaI * yIc ) + - ( alphaR * (*x0) ) + ( alphaI * (*(x0 + 1)) ); - // yImag = ( bR.yI + bI.yR + aR.xI + aI.xR ) - *(y0 + 1) = ( betaR * yIc ) + ( betaI * yRc ) - - ( alphaR * (*(x0 + 1)) ) + ( alphaI * (*x0) ); - - x0 += incx * 2; - y0 += incy * 2; + scomplex temp; + temp.real = ( betaR * (*y0) ) - ( betaI * (*(y0 + 1)) ) + + ( alphaR * (*x0) ) - ( alphaI * (*(x0 + 1)) ); + temp.imag = ( betaR * (*(y0 + 1)) ) + ( betaI * (*y0) ) + + ( alphaR * (*(x0 + 1)) ) + ( alphaI * (*x0) ); + + (*y0) = temp.real; + (*(y0 + 1)) = temp.imag; + + x0 += 2 * incx; + y0 += 2 * incy; } } } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) } @@ -720,7 +1252,7 @@ void bli_caxpbyv_zen_int * y := beta * y + alpha * conjx(x) * where, * x & y are double complex vectors of length n. - * alpha & beta are scalers. + * alpha & beta are scalars. */ void bli_zaxpbyv_zen_int ( @@ -735,25 +1267,97 @@ void bli_zaxpbyv_zen_int { AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_4) - dim_t i = 0; // iterator + // Redirecting to other L1 kernels based on alpha and beta values + // If alpha is 0, we call ZSCALV + // This kernel would further reroute based on few other combinations + // of alpha and beta. They are as follows : + // When alpha = 0 : + // When beta = 0 --> ZSETV + // When beta = 1 --> Early return + // When beta = !( 0 or 1 ) --> ZSCALV + if ( bli_ceq0( *alpha ) ) + { + bli_zscalv_zen_int + ( + BLIS_NO_CONJUGATE, + n, + beta, + y, incy, + cntx + ); - // Local pointers to x and y vectors - double* restrict x0; - double* restrict y0; + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) + return; + } - // Variables to store real and imaginary components of alpha and beta - double alphaR, alphaI, betaR, betaI; + // If beta is 0, we call ZSCAL2V + // This kernel would further reroute based on few other combinations + // of alpha and beta. They are as follows : + // When beta = 0 : + // When alpha = 0 --> ZSETV + // When alpha = 1 --> ZCOPYV + // When alpha = !( 0 or 1 ) --> ZSCAL2V + else if ( bli_ceq0( *beta ) ) + { + bli_zscal2v_zen_int + ( + conjx, + n, + alpha, + x, incx, + y, incy, + cntx + ); - // Local variable to store the conjugate type - conj_t conjx_use = conjx; + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) + return; + } - /* If the vector dimension is zero, return early. */ - if ( bli_zero_dim1( n ) ) + // If beta is 1, we have 2 scenarios for rerouting + // When alpha = 1 --> ZADDV + // When alpha = !( 0 or 1 ) --> ZAXPYV + else if ( bli_ceq1( *beta ) ) { + if( bli_ceq1( *alpha ) ) + { + bli_zaddv_zen_int + ( + conjx, + n, + x, incx, + y, incy, + cntx + ); + } + else + { + bli_zaxpyv_zen_int5 + ( + conjx, + n, + alpha, + x, incx, + y, incy, + cntx + ); + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) return; } + dim_t i = 0; // iterator + + // Local pointers to x and y vectors + double* restrict x0; + double* restrict y0; + + // Boolean to check if alpha is 1 + bool is_alpha_one = bli_zeq1( *alpha ); + + // Variables to store real and imaginary components of alpha and beta + double alphaR, alphaI, betaR, betaI; + // Initializing the local pointers x0 = ( double* ) x; y0 = ( double* ) y; @@ -763,16 +1367,6 @@ void bli_zaxpbyv_zen_int betaR = beta->real; betaI = beta->imag; - // Vectors to store real and imaginary components of beta - __m256d betaRv, betaIv; - - // Broadcasting real and imaginary components of beta onto the registers - betaRv = _mm256_broadcast_sd( &betaR ); - betaIv = _mm256_broadcast_sd( &betaI ); - - // Initializing a variable to classify the type of the computation - bool is_alpha_zero = bli_zeq0( *alpha ); - // In case of unit strides for x and y vectors if ( incx == 1 && incy == 1 ) { @@ -783,10 +1377,24 @@ void bli_zaxpbyv_zen_int __m256d xv[4]; __m256d yv[4]; __m256d iv[4]; + // Vectors to store real and imaginary components of beta + __m256d betaRv, betaIv; + + // Broadcasting real and imaginary components of beta onto the registers + betaRv = _mm256_broadcast_sd( &betaR ); + betaIv = _mm256_broadcast_sd( &betaI ); - // In case of alpha being 0, we just need to scale y by beta - if( is_alpha_zero ) + if( is_alpha_one ) { + __m256d reg_one = _mm256_set1_pd(1.0); + iv[0] = _mm256_setzero_pd(); + + // Converting reg_one to have {1.0, -1.0, 1.0, -1.0} + // This is needed in case we have t0 conjugate X vector + if( bli_is_conj( conjx ) ) + { + reg_one = _mm256_fmsubadd_pd( reg_one, iv[0], reg_one ); + } // Processing 8 elements per loop, 8 FMAs for ( i = 0; ( i + 7 ) < n; i += 8 ) { @@ -797,21 +1405,30 @@ void bli_zaxpbyv_zen_int yv[2] = _mm256_loadu_pd( y0 + 2 * n_elem_per_reg ); yv[3] = _mm256_loadu_pd( y0 + 3 * n_elem_per_reg ); - // Permute the loaded vectors for the required compute - // xv = yI1 yR1 yI2 yR2 - xv[0] = _mm256_permute_pd( yv[0], 5 ); - xv[1] = _mm256_permute_pd( yv[1], 5 ); - xv[2] = _mm256_permute_pd( yv[2], 5 ); - xv[3] = _mm256_permute_pd( yv[3], 5 ); + // Load the x vector, 8 elements in total + // xv = xR1 xI1 xR2 xI2 + xv[0] = _mm256_loadu_pd( x0 ); + xv[1] = _mm256_loadu_pd( x0 + 1 * n_elem_per_reg ); + xv[2] = _mm256_loadu_pd( x0 + 2 * n_elem_per_reg ); + xv[3] = _mm256_loadu_pd( x0 + 3 * n_elem_per_reg ); - // Scale the permuted vectors with imaginary component of beta + // Permute the vectors from y for the required compute // iv = yI1 yR1 yI2 yR2 - iv[0] = _mm256_mul_pd( betaIv, xv[0] ); - iv[1] = _mm256_mul_pd( betaIv, xv[1] ); - iv[2] = _mm256_mul_pd( betaIv, xv[2] ); - iv[3] = _mm256_mul_pd( betaIv, xv[3] ); + iv[0] = _mm256_permute_pd( yv[0], 0x5 ); + iv[1] = _mm256_permute_pd( yv[1], 0x5 ); + iv[2] = _mm256_permute_pd( yv[2], 0x5 ); + iv[3] = _mm256_permute_pd( yv[3], 0x5 ); + + // Scale the permuted vectors with imaginary component of beta + // iv = betaIv * yv + // = yI1.bI, yR1.bI, yI2.bI, yR2.bI, ... + iv[0] = _mm256_mul_pd( betaIv, iv[0] ); + iv[1] = _mm256_mul_pd( betaIv, iv[1] ); + iv[2] = _mm256_mul_pd( betaIv, iv[2] ); + iv[3] = _mm256_mul_pd( betaIv, iv[3] ); - // Using fmaddsub to scale with real component of beta and sub/add to iv + // Using fmaddsub to scale with real component of beta + // and sub/add to iv // yv = betaRv * yv -/+ iv // = yR1.bR - yI1.bI, yI1.bR + yR1.bI, ... yv[0] = _mm256_fmaddsub_pd( betaRv, yv[0], iv[0] ); @@ -819,6 +1436,12 @@ void bli_zaxpbyv_zen_int yv[2] = _mm256_fmaddsub_pd( betaRv, yv[2], iv[2] ); yv[3] = _mm256_fmaddsub_pd( betaRv, yv[3], iv[3] ); + // Adding X conjugate to it + yv[0] = _mm256_fmadd_pd( reg_one, xv[0], yv[0] ); + yv[1] = _mm256_fmadd_pd( reg_one, xv[1], yv[1] ); + yv[2] = _mm256_fmadd_pd( reg_one, xv[2], yv[2] ); + yv[3] = _mm256_fmadd_pd( reg_one, xv[3], yv[3] ); + // Storing the result to memory _mm256_storeu_pd( ( y0 ), yv[0] ); _mm256_storeu_pd( ( y0 + 1 * n_elem_per_reg ), yv[1] ); @@ -839,17 +1462,24 @@ void bli_zaxpbyv_zen_int yv[1] = _mm256_loadu_pd( y0 + 1 * n_elem_per_reg ); yv[2] = _mm256_loadu_pd( y0 + 2 * n_elem_per_reg ); - // Permute the loaded vectors for the required compute - // xv = yI1 yR1 yI2 yR2 - xv[0] = _mm256_permute_pd( yv[0], 5 ); - xv[1] = _mm256_permute_pd( yv[1], 5 ); - xv[2] = _mm256_permute_pd( yv[2], 5 ); + // Load the x vector, 6 elements in total + // xv = xR1 xI1 xR2 xI2 + xv[0] = _mm256_loadu_pd( x0 ); + xv[1] = _mm256_loadu_pd( x0 + 1 * n_elem_per_reg ); + xv[2] = _mm256_loadu_pd( x0 + 2 * n_elem_per_reg ); - // Scale the permuted vectors with imaginary component of beta + // Permute the vectors from y for the required compute // iv = yI1 yR1 yI2 yR2 - iv[0] = _mm256_mul_pd( betaIv, xv[0] ); - iv[1] = _mm256_mul_pd( betaIv, xv[1] ); - iv[2] = _mm256_mul_pd( betaIv, xv[2] ); + iv[0] = _mm256_permute_pd( yv[0], 0x5 ); + iv[1] = _mm256_permute_pd( yv[1], 0x5 ); + iv[2] = _mm256_permute_pd( yv[2], 0x5 ); + + // Scale the permuted vectors with imaginary component of beta + // iv = betaIv * yv + // = yI1.bI, yR1.bI, yI2.bI, yR2.bI, ... + iv[0] = _mm256_mul_pd( betaIv, iv[0] ); + iv[1] = _mm256_mul_pd( betaIv, iv[1] ); + iv[2] = _mm256_mul_pd( betaIv, iv[2] ); // Using fmaddsub to scale with real component of beta // and sub/add to iv @@ -859,6 +1489,11 @@ void bli_zaxpbyv_zen_int yv[1] = _mm256_fmaddsub_pd( betaRv, yv[1], iv[1] ); yv[2] = _mm256_fmaddsub_pd( betaRv, yv[2], iv[2] ); + // Adding X conjugate to it + yv[0] = _mm256_fmadd_pd( reg_one, xv[0], yv[0] ); + yv[1] = _mm256_fmadd_pd( reg_one, xv[1], yv[1] ); + yv[2] = _mm256_fmadd_pd( reg_one, xv[2], yv[2] ); + // Storing the result to memory _mm256_storeu_pd( ( y0 ), yv[0] ); _mm256_storeu_pd( ( y0 + 1 * n_elem_per_reg ), yv[1] ); @@ -877,15 +1512,21 @@ void bli_zaxpbyv_zen_int yv[0] = _mm256_loadu_pd( y0 ); yv[1] = _mm256_loadu_pd( y0 + 1 * n_elem_per_reg ); - // Permute the loaded vectors for the required compute - // xv = yI1 yR1 yI2 yR2 - xv[0] = _mm256_permute_pd( yv[0], 5 ); - xv[1] = _mm256_permute_pd( yv[1], 5 ); + // Load the x vector, 4 elements in total + // xv = xR1 xI1 xR2 xI2 + xv[0] = _mm256_loadu_pd( x0 ); + xv[1] = _mm256_loadu_pd( x0 + 1 * n_elem_per_reg ); + + // Permute the vectors from y for the required compute + // iv = yI1 yR1 yI2 yR2 + iv[0] = _mm256_permute_pd( yv[0], 0x5 ); + iv[1] = _mm256_permute_pd( yv[1], 0x5 ); // Scale the permuted vectors with imaginary component of beta - // iv = yI1.bI, yR1.bI, yI2.bI, yR2.bI - iv[0] = _mm256_mul_pd( betaIv, xv[0] ); - iv[1] = _mm256_mul_pd( betaIv, xv[1] ); + // iv = betaIv * yv + // = yI1.bI, yR1.bI, yI2.bI, yR2.bI, ... + iv[0] = _mm256_mul_pd( betaIv, iv[0] ); + iv[1] = _mm256_mul_pd( betaIv, iv[1] ); // Using fmaddsub to scale with real component of beta // and sub/add to iv @@ -894,6 +1535,10 @@ void bli_zaxpbyv_zen_int yv[0] = _mm256_fmaddsub_pd( betaRv, yv[0], iv[0] ); yv[1] = _mm256_fmaddsub_pd( betaRv, yv[1], iv[1] ); + // Adding X conjugate to it + yv[0] = _mm256_fmadd_pd( reg_one, xv[0], yv[0] ); + yv[1] = _mm256_fmadd_pd( reg_one, xv[1], yv[1] ); + // Storing the result to memory _mm256_storeu_pd( ( y0 ), yv[0] ); _mm256_storeu_pd( ( y0 + 1 * n_elem_per_reg ), yv[1] ); @@ -903,20 +1548,25 @@ void bli_zaxpbyv_zen_int x0 += 2 * n_elem_per_reg; } - // Processing 2 elements per loop, 3 FMAs + // Processing 2 elements per loop, 2 FMAs for ( ; ( i + 1 ) < n; i += 2 ) { // Load the y vector, 2 elements in total // yv = yR1 yI1 yR2 yI2 yv[0] = _mm256_loadu_pd( y0 ); - // Permute the loaded vectors for the required compute - // xv = yI1 yR1 yI2 yR2 - xv[0] = _mm256_permute_pd( yv[0], 5 ); + // Load the x vector, 2 elements in total + // xv = xR1 xI1 xR2 xI2 + xv[0] = _mm256_loadu_pd( x0 ); - // Scale the permuted vectors with imaginary component of beta + // Permute the vectors from y for the required compute // iv = yI1 yR1 yI2 yR2 - iv[0] = _mm256_mul_pd( betaIv, xv[0] ); + iv[0] = _mm256_permute_pd( yv[0], 0x5 ); + + // Scale the permuted vectors with imaginary component of beta + // iv = betaIv * yv + // = yI1.bI, yR1.bI, yI2.bI, yR2.bI, ... + iv[0] = _mm256_mul_pd( betaIv, iv[0] ); // Using fmaddsub to scale with real component of beta // and sub/add to iv @@ -924,6 +1574,9 @@ void bli_zaxpbyv_zen_int // = yR1.bR - yI1.bI, yI1.bR + yR1.bI, ... yv[0] = _mm256_fmaddsub_pd( betaRv, yv[0], iv[0] ); + // Adding X conjugate to it + yv[0] = _mm256_fmadd_pd( reg_one, xv[0], yv[0] ); + // Storing the result to memory _mm256_storeu_pd( ( y0 ), yv[0] ); @@ -932,7 +1585,6 @@ void bli_zaxpbyv_zen_int x0 += 1 * n_elem_per_reg; } } - else { // Scratch registers for storing real and imaginary components of alpha @@ -948,7 +1600,7 @@ void bli_zaxpbyv_zen_int // alphaRv = aR -aR aR -aR // Else : // alphaIv = -aI aI -aI aI - if( bli_is_conj( conjx_use ) ) + if( bli_is_conj( conjx ) ) { alphaRv = _mm256_fmsubadd_pd( iv[0], iv[0], alphaRv ); } @@ -960,14 +1612,14 @@ void bli_zaxpbyv_zen_int // Processing 8 elements per loop, 8 FMAs for ( i = 0; ( i + 7 ) < n; i += 8 ) { - // Load the y vector, 6 elements in total + // Load the y vector, 8 elements in total // yv = yR1 yI1 yR2 yI2 yv[0] = _mm256_loadu_pd( y0 ); yv[1] = _mm256_loadu_pd( y0 + 1 * n_elem_per_reg ); yv[2] = _mm256_loadu_pd( y0 + 2 * n_elem_per_reg ); yv[3] = _mm256_loadu_pd( y0 + 3 * n_elem_per_reg ); - // Load the x vector, 6 elements in total + // Load the x vector, 8 elements in total // xv = xR1 xI1 xR2 xI2 xv[0] = _mm256_loadu_pd( x0 ); xv[1] = _mm256_loadu_pd( x0 + 1 * n_elem_per_reg ); @@ -976,10 +1628,10 @@ void bli_zaxpbyv_zen_int // Permute the vectors from y for the required compute // iv = yI1 yR1 yI2 yR2 - iv[0] = _mm256_permute_pd( yv[0], 5 ); - iv[1] = _mm256_permute_pd( yv[1], 5 ); - iv[2] = _mm256_permute_pd( yv[2], 5 ); - iv[3] = _mm256_permute_pd( yv[3], 5 ); + iv[0] = _mm256_permute_pd( yv[0], 0x5 ); + iv[1] = _mm256_permute_pd( yv[1], 0x5 ); + iv[2] = _mm256_permute_pd( yv[2], 0x5 ); + iv[3] = _mm256_permute_pd( yv[3], 0x5 ); // Scale the permuted vectors with imaginary component of beta // iv = betaIv * yv @@ -1000,10 +1652,10 @@ void bli_zaxpbyv_zen_int // Permute the loaded vectors from x for the required compute // xv' = xI1 xR1 xI2 xR2 - iv[0] = _mm256_permute_pd( xv[0], 5 ); - iv[1] = _mm256_permute_pd( xv[1], 5 ); - iv[2] = _mm256_permute_pd( xv[2], 5 ); - iv[3] = _mm256_permute_pd( xv[3], 5 ); + iv[0] = _mm256_permute_pd( xv[0], 0x5 ); + iv[1] = _mm256_permute_pd( xv[1], 0x5 ); + iv[2] = _mm256_permute_pd( xv[2], 0x5 ); + iv[3] = _mm256_permute_pd( xv[3], 0x5 ); // yv = alphaRv * xv + yv // = yR1.bR - yR1.bI + xR1.aR, yI1.bR + yI1.bI + xI1.aR, ... @@ -1047,9 +1699,9 @@ void bli_zaxpbyv_zen_int // Permute the vectors from y for the required compute // iv = yI1 yR1 yI2 yR2 - iv[0] = _mm256_permute_pd( yv[0], 5 ); - iv[1] = _mm256_permute_pd( yv[1], 5 ); - iv[2] = _mm256_permute_pd( yv[2], 5 ); + iv[0] = _mm256_permute_pd( yv[0], 0x5 ); + iv[1] = _mm256_permute_pd( yv[1], 0x5 ); + iv[2] = _mm256_permute_pd( yv[2], 0x5 ); // Scale the permuted vectors with imaginary component of beta // iv = betaIv * yv @@ -1068,9 +1720,9 @@ void bli_zaxpbyv_zen_int // Permute the loaded vectors from x for the required compute // xv' = xI1 xR1 xI2 xR2 - iv[0] = _mm256_permute_pd( xv[0], 5 ); - iv[1] = _mm256_permute_pd( xv[1], 5 ); - iv[2] = _mm256_permute_pd( xv[2], 5 ); + iv[0] = _mm256_permute_pd( xv[0], 0x5 ); + iv[1] = _mm256_permute_pd( xv[1], 0x5 ); + iv[2] = _mm256_permute_pd( xv[2], 0x5 ); // yv = alphaRv * xv + yv // = yR1.bR - yR1.bI + xR1.aR, yI1.bR + yI1.bI + xI1.aR, ... @@ -1097,20 +1749,20 @@ void bli_zaxpbyv_zen_int // Processing 4 elements per loop, 4 FMAs for ( ; ( i + 3 ) < n; i += 4 ) { - // Load the y vector, 6 elements in total + // Load the y vector, 4 elements in total // yv = yR1 yI1 yR2 yI2 yv[0] = _mm256_loadu_pd( y0 ); yv[1] = _mm256_loadu_pd( y0 + 1 * n_elem_per_reg ); - // Load the x vector, 6 elements in total + // Load the x vector, 4 elements in total // xv = xR1 xI1 xR2 xI2 xv[0] = _mm256_loadu_pd( x0 ); xv[1] = _mm256_loadu_pd( x0 + 1 * n_elem_per_reg ); // Permute the vectors from y for the required compute // iv = yI1 yR1 yI2 yR2 - iv[0] = _mm256_permute_pd( yv[0], 5 ); - iv[1] = _mm256_permute_pd( yv[1], 5 ); + iv[0] = _mm256_permute_pd( yv[0], 0x5 ); + iv[1] = _mm256_permute_pd( yv[1], 0x5 ); // Scale the permuted vectors with imaginary component of beta // iv = betaIv * yv @@ -1127,8 +1779,8 @@ void bli_zaxpbyv_zen_int // Permute the loaded vectors from x for the required compute // xv' = xI1 xR1 xI2 xR2 - iv[0] = _mm256_permute_pd( xv[0], 5 ); - iv[1] = _mm256_permute_pd( xv[1], 5 ); + iv[0] = _mm256_permute_pd( xv[0], 0x5 ); + iv[1] = _mm256_permute_pd( xv[1], 0x5 ); // yv = alphaRv * xv + yv // = yR1.bR - yR1.bI + xR1.aR, yI1.bR + yI1.bI + xI1.aR, ... @@ -1149,20 +1801,20 @@ void bli_zaxpbyv_zen_int x0 += 2 * n_elem_per_reg; } - // Processing 2 elements per loop, 3 FMAs + // Processing 2 elements per loop, 2 FMAs for ( ; ( i + 1 ) < n; i += 2 ) { - // Load the y vector, 6 elements in total + // Load the y vector, 2 elements in total // yv = yR1 yI1 yR2 yI2 yv[0] = _mm256_loadu_pd( y0 ); - // Load the x vector, 6 elements in total + // Load the x vector, 2 elements in total // xv = xR1 xI1 xR2 xI2 xv[0] = _mm256_loadu_pd( x0 ); // Permute the vectors from y for the required compute // iv = yI1 yR1 yI2 yR2 - iv[0] = _mm256_permute_pd( yv[0], 5 ); + iv[0] = _mm256_permute_pd( yv[0], 0x5 ); // Scale the permuted vectors with imaginary component of beta // iv = betaIv * yv @@ -1177,7 +1829,7 @@ void bli_zaxpbyv_zen_int // Permute the loaded vectors from x for the required compute // xv' = xI1 xR1 xI2 xR2 - iv[0] = _mm256_permute_pd( xv[0], 5 ); + iv[0] = _mm256_permute_pd( xv[0], 0x5 ); // yv = alphaRv * xv + yv // = yR1.bR - yR1.bI + xR1.aR, yI1.bR + yI1.bI + xI1.aR, ... @@ -1194,43 +1846,45 @@ void bli_zaxpbyv_zen_int y0 += 1 * n_elem_per_reg; x0 += 1 * n_elem_per_reg; } - } // Issue vzeroupper instruction to clear upper lanes of ymm registers. // This avoids a performance penalty caused by false dependencies when // transitioning from AVX to SSE instructions. - _mm256_zeroupper(); + _mm256_zeroupper(); } // Scratch registers to be used in case of non-unit strides or fringe case of 1. __m128d x_elem, y_elem, x_perm, y_perm; - __m128d betaRv_128, betaIv_128; + __m128d betaRv, betaIv; - // Casting the lower 128-bit lanes from betaRv and betaIv to its 128-bit alternative - // registers to avoid redundant broadcasts. - betaRv_128 = _mm256_castpd256_pd128( betaRv ); - betaIv_128 = _mm256_castpd256_pd128( betaIv ); + // Broadcasting real and imag parts of beta onto 128 bit registers + betaRv = _mm_set1_pd( betaR ); + betaIv = _mm_set1_pd( betaI ); - // NOTE : We cannot similarly use _mm256_castpd256_pd128 to avoid loading alpha - // since alpha is loaded onto its YMM rgeisters on requirement basis. - // In case of directly falling to this compute(non-unit stride cases), - // alpha wouldn't have been loaded onto any YMM reigsters. - - // Changing betaIv_128 to { -bI bI } for the compute + // Changing betaIv to { -bI bI } for the compute x_elem = _mm_setzero_pd(); - betaIv_128 = _mm_addsub_pd( x_elem, betaIv_128 ); + betaIv = _mm_addsub_pd( x_elem, betaIv ); - // In case of alpha being 0, we just need to scale y by beta - if ( is_alpha_zero ) + if ( is_alpha_one ) { + __m128d reg_one = _mm_set1_pd(1.0); + + if( bli_is_conj( conjx ) ) + { + reg_one = _mm_addsub_pd( x_elem, reg_one ); + reg_one = _mm_permute_pd( reg_one, 0x1 ); + } + // Iterate over y, one element at a time for ( ; i < n; i += 1 ) { - // Load an element from y + // Load an element from x and y // y_elem = yR1 yI1 + // x_elem = xR1 xI1 y_elem = _mm_loadu_pd( y0 ); + x_elem = _mm_loadu_pd( x0 ); // Permute y in accordance to its compute // y_perm = yI1 yR1 @@ -1239,17 +1893,20 @@ void bli_zaxpbyv_zen_int // Scale y_perm by the imaginary // component of beta // y_perm = -yI1.bI, yR1.bI - y_perm = _mm_mul_pd( betaIv_128, y_perm ); + y_perm = _mm_mul_pd( betaIv, y_perm ); // Use fmadd to scale with real component of // beta and add with intermediate result // y_elem = yR1.bR - yI1.bI, yI1.bR + yR1.bI - y_elem = _mm_fmadd_pd( betaRv_128, y_elem, y_perm ); + y_elem = _mm_fmadd_pd( betaRv, y_elem, y_perm ); + + y_elem = _mm_fmadd_pd( reg_one, x_elem, y_elem ); // Storing the result to memory _mm_storeu_pd( y0, y_elem ); // Adjusting the pointer for the next iteration + x0 += incx * 2; y0 += incy * 2; } } @@ -1257,26 +1914,26 @@ void bli_zaxpbyv_zen_int { // Scratch registers to store real and imaginary components // of alpha onto XMM registers - __m128d alphaRv_128, alphaIv_128; + __m128d alphaRv, alphaIv; // Broadcasting real and imaginary components of alpha x_elem = _mm_setzero_pd(); - alphaRv_128 = _mm_loaddup_pd( &alphaR ); - alphaIv_128 = _mm_loaddup_pd( &alphaI ); + alphaRv = _mm_loaddup_pd( &alphaR ); + alphaIv = _mm_loaddup_pd( &alphaI ); - // The changes on alphaRv_128 and alphaIv_128 are as follows : + // The changes on alphaRv and alphaIv are as follows : // If conjugate is required: - // alphaRv_128 = aR -aR + // alphaRv = aR -aR // Else : - // alphaIv_128 = -aI aI - if( bli_is_conj( conjx_use ) ) + // alphaIv = -aI aI + if( bli_is_conj( conjx ) ) { - alphaRv_128 = _mm_addsub_pd( x_elem, alphaRv_128 ); - alphaRv_128 = _mm_permute_pd( alphaRv_128, 0x1 ); + alphaRv = _mm_addsub_pd( x_elem, alphaRv ); + alphaRv = _mm_permute_pd( alphaRv, 0x1 ); } else { - alphaIv_128 = _mm_addsub_pd( x_elem, alphaIv_128 ); + alphaIv = _mm_addsub_pd( x_elem, alphaIv ); } // Iterating over x and y vectors, on element at a time @@ -1298,8 +1955,8 @@ void bli_zaxpbyv_zen_int // component of beta and alpha // y_perm = -yI1.bI, yR1.bI // x_perm = -xI1.aI, xR1.aI - y_perm = _mm_mul_pd( betaIv_128, y_perm ); - x_perm = _mm_mul_pd( alphaIv_128, x_perm ); + y_perm = _mm_mul_pd( betaIv, y_perm ); + x_perm = _mm_mul_pd( alphaIv, x_perm ); // Use fmadd to scale with y_elem with // real component of beta and add with @@ -1307,8 +1964,8 @@ void bli_zaxpbyv_zen_int // for x_elem. // y_elem = yR1.bR - yI1.bI, yI1.bR + yR1.bI // x_elem = xR1.aR - xI1.aI, xI1.aR + xR1.aI - y_elem = _mm_fmadd_pd( betaRv_128, y_elem, y_perm ); - x_elem = _mm_fmadd_pd( alphaRv_128, x_elem, x_perm ); + y_elem = _mm_fmadd_pd( betaRv, y_elem, y_perm ); + x_elem = _mm_fmadd_pd( alphaRv, x_elem, x_perm ); // Add the computed x and y vectors, store on y. y_elem = _mm_add_pd( y_elem, x_elem ); diff --git a/kernels/zen/1/bli_axpbyv_zen_int10.c b/kernels/zen/1/bli_axpbyv_zen_int10.c index 02abdb4f2a..ee5523d63b 100644 --- a/kernels/zen/1/bli_axpbyv_zen_int10.c +++ b/kernels/zen/1/bli_axpbyv_zen_int10.c @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -55,7 +55,7 @@ typedef union * y := beta * y + alpha * conjx(x) * where, * x & y are single precision vectors of length n. - * alpha & beta are scalers. + * alpha & beta are scalars. */ void bli_saxpbyv_zen_int10 ( @@ -69,307 +69,629 @@ void bli_saxpbyv_zen_int10 ) { AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_4) - const dim_t n_elem_per_reg = 8; // number of elements per register - - dim_t i; // iterator - float* restrict x0; - float* restrict y0; + // Redirecting to other L1 kernels based on alpha and beta values + // If alpha is 0, we call SSCALV + // This kernel would further reroute based on few other combinations + // of alpha and beta. They are as follows : + // When alpha = 0 : + // When beta = 0 --> SSETV + // When beta = 1 --> Early return + // When beta = !( 0 or 1 ) --> SSCALV + if ( bli_seq0( *alpha ) ) + { + bli_sscalv_zen_int10 + ( + BLIS_NO_CONJUGATE, + n, + beta, + y, incy, + cntx + ); - v8sf_t alphav; - v8sf_t betav; - v8sf_t yv[10]; + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) + return; + } - /* if the vector dimension is zero, or if alpha & beta are zero, - return early. */ - if ( bli_zero_dim1( n ) || - ( PASTEMAC( s, eq0 )( *alpha ) && PASTEMAC( s, eq0 )( *beta ) ) ) + // If beta is 0, we call SSCAL2V + // This kernel would further reroute based on few other combinations + // of alpha and beta. They are as follows : + // When beta = 0 : + // When alpha = 0 --> SSETV + // When alpha = 1 --> SCOPYV + // When alpha = !( 0 or 1 ) --> SSCAL2V + else if ( bli_seq0( *beta ) ) { + bli_sscal2v_zen_int + ( + conjx, + n, + alpha, + x, incx, + y, incy, + cntx + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) return; } - - // initialize local pointers - x0 = x; - y0 = y; - if ( incx == 1 && incy == 1 ) + // If beta is 1, we have 2 scenarios for rerouting + // When alpha = 1 --> SADDV + // When alpha = !( 0 or 1 ) --> SAXPYV + else if ( bli_seq1( *beta ) ) { - // broadcast alpha & beta to all elements of respective vector registers - alphav.v = _mm256_broadcast_ss( alpha ); - betav.v = _mm256_broadcast_ss( beta ); - - // Processing 80 elements per loop, 10 FMAs - for ( i = 0; ( i + 79 ) < n; i += 80 ) + if( bli_seq1( *alpha ) ) { - // loading input values - yv[0].v = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); - yv[1].v = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); - yv[2].v = _mm256_loadu_ps( y0 + 2*n_elem_per_reg ); - yv[3].v = _mm256_loadu_ps( y0 + 3*n_elem_per_reg ); - yv[4].v = _mm256_loadu_ps( y0 + 4*n_elem_per_reg ); - yv[5].v = _mm256_loadu_ps( y0 + 5*n_elem_per_reg ); - yv[6].v = _mm256_loadu_ps( y0 + 6*n_elem_per_reg ); - yv[7].v = _mm256_loadu_ps( y0 + 7*n_elem_per_reg ); - yv[8].v = _mm256_loadu_ps( y0 + 8*n_elem_per_reg ); - yv[9].v = _mm256_loadu_ps( y0 + 9*n_elem_per_reg ); - - // y' := y := beta * y - yv[0].v = _mm256_mul_ps( betav.v, yv[0].v ); - yv[1].v = _mm256_mul_ps( betav.v, yv[1].v ); - yv[2].v = _mm256_mul_ps( betav.v, yv[2].v ); - yv[3].v = _mm256_mul_ps( betav.v, yv[3].v ); - yv[4].v = _mm256_mul_ps( betav.v, yv[4].v ); - yv[5].v = _mm256_mul_ps( betav.v, yv[5].v ); - yv[6].v = _mm256_mul_ps( betav.v, yv[6].v ); - yv[7].v = _mm256_mul_ps( betav.v, yv[7].v ); - yv[8].v = _mm256_mul_ps( betav.v, yv[8].v ); - yv[9].v = _mm256_mul_ps( betav.v, yv[9].v ); - - // y := y' + alpha * x - yv[0].v = _mm256_fmadd_ps - ( - alphav.v, - _mm256_loadu_ps( x0 + 0*n_elem_per_reg ), - yv[0].v - ); - yv[1].v = _mm256_fmadd_ps - ( - alphav.v, - _mm256_loadu_ps( x0 + 1*n_elem_per_reg ), - yv[1].v - ); - yv[2].v = _mm256_fmadd_ps - ( - alphav.v, - _mm256_loadu_ps( x0 + 2*n_elem_per_reg ), - yv[2].v - ); - yv[3].v = _mm256_fmadd_ps - ( - alphav.v, - _mm256_loadu_ps( x0 + 3*n_elem_per_reg ), - yv[3].v - ); - yv[4].v = _mm256_fmadd_ps - ( - alphav.v, - _mm256_loadu_ps( x0 + 4*n_elem_per_reg ), - yv[4].v - ); - yv[5].v = _mm256_fmadd_ps - ( - alphav.v, - _mm256_loadu_ps( x0 + 5*n_elem_per_reg ), - yv[5].v - ); - yv[6].v = _mm256_fmadd_ps - ( - alphav.v, - _mm256_loadu_ps( x0 + 6*n_elem_per_reg ), - yv[6].v - ); - yv[7].v = _mm256_fmadd_ps - ( - alphav.v, - _mm256_loadu_ps( x0 + 7*n_elem_per_reg ), - yv[7].v - ); - yv[8].v = _mm256_fmadd_ps - ( - alphav.v, - _mm256_loadu_ps( x0 + 8*n_elem_per_reg ), - yv[8].v - ); - yv[9].v = _mm256_fmadd_ps - ( - alphav.v, - _mm256_loadu_ps( x0 + 9*n_elem_per_reg ), - yv[9].v - ); - - // storing the output - _mm256_storeu_ps( ( y0 + 0*n_elem_per_reg ), yv[0].v ); - _mm256_storeu_ps( ( y0 + 1*n_elem_per_reg ), yv[1].v ); - _mm256_storeu_ps( ( y0 + 2*n_elem_per_reg ), yv[2].v ); - _mm256_storeu_ps( ( y0 + 3*n_elem_per_reg ), yv[3].v ); - _mm256_storeu_ps( ( y0 + 4*n_elem_per_reg ), yv[4].v ); - _mm256_storeu_ps( ( y0 + 5*n_elem_per_reg ), yv[5].v ); - _mm256_storeu_ps( ( y0 + 6*n_elem_per_reg ), yv[6].v ); - _mm256_storeu_ps( ( y0 + 7*n_elem_per_reg ), yv[7].v ); - _mm256_storeu_ps( ( y0 + 8*n_elem_per_reg ), yv[8].v ); - _mm256_storeu_ps( ( y0 + 9*n_elem_per_reg ), yv[9].v ); - - x0 += 10 * n_elem_per_reg; - y0 += 10 * n_elem_per_reg; + bli_saddv_zen_int + ( + conjx, + n, + x, incx, + y, incy, + cntx + ); } - - // Processing 40 elements per loop, 5 FMAs - for ( ; ( i + 39 ) < n; i += 40 ) + else { - // loading input values - yv[0].v = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); - yv[1].v = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); - yv[2].v = _mm256_loadu_ps( y0 + 2*n_elem_per_reg ); - yv[3].v = _mm256_loadu_ps( y0 + 3*n_elem_per_reg ); - yv[4].v = _mm256_loadu_ps( y0 + 4*n_elem_per_reg ); - - // y' := y := beta * y - yv[0].v = _mm256_mul_ps( betav.v, yv[0].v ); - yv[1].v = _mm256_mul_ps( betav.v, yv[1].v ); - yv[2].v = _mm256_mul_ps( betav.v, yv[2].v ); - yv[3].v = _mm256_mul_ps( betav.v, yv[3].v ); - yv[4].v = _mm256_mul_ps( betav.v, yv[4].v ); - - // y := y' + alpha * x - yv[0].v = _mm256_fmadd_ps - ( - alphav.v, - _mm256_loadu_ps( x0 + 0*n_elem_per_reg ), - yv[0].v - ); - yv[1].v = _mm256_fmadd_ps - ( - alphav.v, - _mm256_loadu_ps( x0 + 1*n_elem_per_reg ), - yv[1].v - ); - yv[2].v = _mm256_fmadd_ps - ( - alphav.v, - _mm256_loadu_ps( x0 + 2*n_elem_per_reg ), - yv[2].v - ); - yv[3].v = _mm256_fmadd_ps - ( - alphav.v, - _mm256_loadu_ps( x0 + 3*n_elem_per_reg ), - yv[3].v - ); - yv[4].v = _mm256_fmadd_ps - ( - alphav.v, - _mm256_loadu_ps( x0 + 4*n_elem_per_reg ), - yv[4].v - ); - - // storing the output - _mm256_storeu_ps( ( y0 + 0*n_elem_per_reg ), yv[0].v ); - _mm256_storeu_ps( ( y0 + 1*n_elem_per_reg ), yv[1].v ); - _mm256_storeu_ps( ( y0 + 2*n_elem_per_reg ), yv[2].v ); - _mm256_storeu_ps( ( y0 + 3*n_elem_per_reg ), yv[3].v ); - _mm256_storeu_ps( ( y0 + 4*n_elem_per_reg ), yv[4].v ); - - x0 += 5 * n_elem_per_reg; - y0 += 5 * n_elem_per_reg; + bli_saxpyv_zen_int + ( + conjx, + n, + alpha, + x, incx, + y, incy, + cntx + ); } - // Processing 32 elements per loop, 4 FMAs - for ( ; ( i + 31 ) < n; i += 32 ) - { - // loading input values - yv[0].v = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); - yv[1].v = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); - yv[2].v = _mm256_loadu_ps( y0 + 2*n_elem_per_reg ); - yv[3].v = _mm256_loadu_ps( y0 + 3*n_elem_per_reg ); - - // y' := y := beta * y - yv[0].v = _mm256_mul_ps( betav.v, yv[0].v ); - yv[1].v = _mm256_mul_ps( betav.v, yv[1].v ); - yv[2].v = _mm256_mul_ps( betav.v, yv[2].v ); - yv[3].v = _mm256_mul_ps( betav.v, yv[3].v ); - - // y := y' + alpha * x - yv[0].v = _mm256_fmadd_ps - ( - alphav.v, - _mm256_loadu_ps( x0 + 0*n_elem_per_reg ), - yv[0].v - ); - yv[1].v = _mm256_fmadd_ps - ( - alphav.v, - _mm256_loadu_ps( x0 + 1*n_elem_per_reg ), - yv[1].v - ); - yv[2].v = _mm256_fmadd_ps - ( - alphav.v, - _mm256_loadu_ps( x0 + 2*n_elem_per_reg ), - yv[2].v - ); - yv[3].v = _mm256_fmadd_ps - ( - alphav.v, - _mm256_loadu_ps( x0 + 3*n_elem_per_reg ), - yv[3].v - ); - - // storing the output - _mm256_storeu_ps( ( y0 + 0*n_elem_per_reg ), yv[0].v ); - _mm256_storeu_ps( ( y0 + 1*n_elem_per_reg ), yv[1].v ); - _mm256_storeu_ps( ( y0 + 2*n_elem_per_reg ), yv[2].v ); - _mm256_storeu_ps( ( y0 + 3*n_elem_per_reg ), yv[3].v ); - - x0 += 4 * n_elem_per_reg; - y0 += 4 * n_elem_per_reg; - } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) + return; + } - // Processing 16 elements per loop, 2 FMAs - for ( ; ( i + 15 ) < n; i += 16 ) - { - // loading input values - yv[0].v = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); - yv[1].v = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); - - // y' := y := beta * y - yv[0].v = _mm256_mul_ps( betav.v, yv[0].v ); - yv[1].v = _mm256_mul_ps( betav.v, yv[1].v ); - - // y := y' + alpha * x - yv[0].v = _mm256_fmadd_ps - ( - alphav.v, - _mm256_loadu_ps( x0 + 0*n_elem_per_reg ), - yv[0].v - ); - yv[1].v = _mm256_fmadd_ps - ( - alphav.v, - _mm256_loadu_ps( x0 + 1*n_elem_per_reg ), - yv[1].v - ); - - // storing the output - _mm256_storeu_ps( ( y0 + 0*n_elem_per_reg ), yv[0].v ); - _mm256_storeu_ps( ( y0 + 1*n_elem_per_reg ), yv[1].v ); - - x0 += 2 * n_elem_per_reg; - y0 += 2 * n_elem_per_reg; - } + const dim_t n_elem_per_reg = 8; // number of elements per register - // Processing 8 elements per loop, 1 FMA - for ( ; ( i + 7 ) < n; i += 8 ) - { - // loading input values - yv[0].v = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); + dim_t i = 0; // iterator + + float* restrict x0; + float* restrict y0; + + v8sf_t alphav; + v8sf_t betav; + v8sf_t yv[10]; - // y' := y := beta * y - yv[0].v = _mm256_mul_ps( betav.v, yv[0].v ); + bool is_alpha_one = bli_seq1( *alpha ); - // y := y' + alpha * x - yv[0].v = _mm256_fmadd_ps - ( - alphav.v, - _mm256_loadu_ps( x0 + 0*n_elem_per_reg ), - yv[0].v - ); + // initialize local pointers + x0 = x; + y0 = y; - // storing the output - _mm256_storeu_ps( ( y0 + 0*n_elem_per_reg ), yv[0].v ); + if( incx == 1 && incy == 1 ) + { + // Broadcasting beta onto a YMM register + betav.v = _mm256_broadcast_ss( beta ); - x0 += 1 * n_elem_per_reg; - y0 += 1 * n_elem_per_reg; + if( is_alpha_one ) // Scale y with beta and add x to it + { + // Processing 80 elements per loop, 10 FMAs + for ( ; ( i + 79 ) < n; i += 80 ) + { + // Loading input values + yv[0].v = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); + yv[1].v = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); + yv[2].v = _mm256_loadu_ps( x0 + 2*n_elem_per_reg ); + yv[3].v = _mm256_loadu_ps( x0 + 3*n_elem_per_reg ); + yv[4].v = _mm256_loadu_ps( x0 + 4*n_elem_per_reg ); + yv[5].v = _mm256_loadu_ps( x0 + 5*n_elem_per_reg ); + yv[6].v = _mm256_loadu_ps( x0 + 6*n_elem_per_reg ); + yv[7].v = _mm256_loadu_ps( x0 + 7*n_elem_per_reg ); + yv[8].v = _mm256_loadu_ps( x0 + 8*n_elem_per_reg ); + yv[9].v = _mm256_loadu_ps( x0 + 9*n_elem_per_reg ); + + // y := beta * y + x + yv[0].v = _mm256_fmadd_ps + ( + betav.v, + _mm256_loadu_ps( y0 + 0*n_elem_per_reg ), + yv[0].v + ); + yv[1].v = _mm256_fmadd_ps + ( + betav.v, + _mm256_loadu_ps( y0 + 1*n_elem_per_reg ), + yv[1].v + ); + yv[2].v = _mm256_fmadd_ps + ( + betav.v, + _mm256_loadu_ps( y0 + 2*n_elem_per_reg ), + yv[2].v + ); + yv[3].v = _mm256_fmadd_ps + ( + betav.v, + _mm256_loadu_ps( y0 + 3*n_elem_per_reg ), + yv[3].v + ); + yv[4].v = _mm256_fmadd_ps + ( + betav.v, + _mm256_loadu_ps( y0 + 4*n_elem_per_reg ), + yv[4].v + ); + yv[5].v = _mm256_fmadd_ps + ( + betav.v, + _mm256_loadu_ps( y0 + 5*n_elem_per_reg ), + yv[5].v + ); + yv[6].v = _mm256_fmadd_ps + ( + betav.v, + _mm256_loadu_ps( y0 + 6*n_elem_per_reg ), + yv[6].v + ); + yv[7].v = _mm256_fmadd_ps + ( + betav.v, + _mm256_loadu_ps( y0 + 7*n_elem_per_reg ), + yv[7].v + ); + yv[8].v = _mm256_fmadd_ps + ( + betav.v, + _mm256_loadu_ps( y0 + 8*n_elem_per_reg ), + yv[8].v + ); + yv[9].v = _mm256_fmadd_ps + ( + betav.v, + _mm256_loadu_ps( y0 + 9*n_elem_per_reg ), + yv[9].v + ); + + // Storing the output + _mm256_storeu_ps( ( y0 + 0*n_elem_per_reg ), yv[0].v ); + _mm256_storeu_ps( ( y0 + 1*n_elem_per_reg ), yv[1].v ); + _mm256_storeu_ps( ( y0 + 2*n_elem_per_reg ), yv[2].v ); + _mm256_storeu_ps( ( y0 + 3*n_elem_per_reg ), yv[3].v ); + _mm256_storeu_ps( ( y0 + 4*n_elem_per_reg ), yv[4].v ); + _mm256_storeu_ps( ( y0 + 5*n_elem_per_reg ), yv[5].v ); + _mm256_storeu_ps( ( y0 + 6*n_elem_per_reg ), yv[6].v ); + _mm256_storeu_ps( ( y0 + 7*n_elem_per_reg ), yv[7].v ); + _mm256_storeu_ps( ( y0 + 8*n_elem_per_reg ), yv[8].v ); + _mm256_storeu_ps( ( y0 + 9*n_elem_per_reg ), yv[9].v ); + + x0 += 10 * n_elem_per_reg; + y0 += 10 * n_elem_per_reg; + } + + // Processing 40 elements per loop, 5 FMAs + for ( ; ( i + 39 ) < n; i += 40 ) + { + // Loading input values + yv[0].v = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); + yv[1].v = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); + yv[2].v = _mm256_loadu_ps( x0 + 2*n_elem_per_reg ); + yv[3].v = _mm256_loadu_ps( x0 + 3*n_elem_per_reg ); + yv[4].v = _mm256_loadu_ps( x0 + 4*n_elem_per_reg ); + + // y := beta * y + x + yv[0].v = _mm256_fmadd_ps + ( + betav.v, + _mm256_loadu_ps( y0 + 0*n_elem_per_reg ), + yv[0].v + ); + yv[1].v = _mm256_fmadd_ps + ( + betav.v, + _mm256_loadu_ps( y0 + 1*n_elem_per_reg ), + yv[1].v + ); + yv[2].v = _mm256_fmadd_ps + ( + betav.v, + _mm256_loadu_ps( y0 + 2*n_elem_per_reg ), + yv[2].v + ); + yv[3].v = _mm256_fmadd_ps + ( + betav.v, + _mm256_loadu_ps( y0 + 3*n_elem_per_reg ), + yv[3].v + ); + yv[4].v = _mm256_fmadd_ps + ( + betav.v, + _mm256_loadu_ps( y0 + 4*n_elem_per_reg ), + yv[4].v + ); + + // Storing the output + _mm256_storeu_ps( ( y0 + 0*n_elem_per_reg ), yv[0].v ); + _mm256_storeu_ps( ( y0 + 1*n_elem_per_reg ), yv[1].v ); + _mm256_storeu_ps( ( y0 + 2*n_elem_per_reg ), yv[2].v ); + _mm256_storeu_ps( ( y0 + 3*n_elem_per_reg ), yv[3].v ); + _mm256_storeu_ps( ( y0 + 4*n_elem_per_reg ), yv[4].v ); + + x0 += 5 * n_elem_per_reg; + y0 += 5 * n_elem_per_reg; + } + + // Processing 32 elements per loop, 4 FMAs + for ( ; ( i + 31 ) < n; i += 32 ) + { + // Loading input values + yv[0].v = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); + yv[1].v = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); + yv[2].v = _mm256_loadu_ps( x0 + 2*n_elem_per_reg ); + yv[3].v = _mm256_loadu_ps( x0 + 3*n_elem_per_reg ); + + // y := beta * y + x + yv[0].v = _mm256_fmadd_ps + ( + betav.v, + _mm256_loadu_ps( y0 + 0*n_elem_per_reg ), + yv[0].v + ); + yv[1].v = _mm256_fmadd_ps + ( + betav.v, + _mm256_loadu_ps( y0 + 1*n_elem_per_reg ), + yv[1].v + ); + yv[2].v = _mm256_fmadd_ps + ( + betav.v, + _mm256_loadu_ps( y0 + 2*n_elem_per_reg ), + yv[2].v + ); + yv[3].v = _mm256_fmadd_ps + ( + betav.v, + _mm256_loadu_ps( y0 + 3*n_elem_per_reg ), + yv[3].v + ); + + // Storing the output + _mm256_storeu_ps( ( y0 + 0*n_elem_per_reg ), yv[0].v ); + _mm256_storeu_ps( ( y0 + 1*n_elem_per_reg ), yv[1].v ); + _mm256_storeu_ps( ( y0 + 2*n_elem_per_reg ), yv[2].v ); + _mm256_storeu_ps( ( y0 + 3*n_elem_per_reg ), yv[3].v ); + + x0 += 4 * n_elem_per_reg; + y0 += 4 * n_elem_per_reg; + } + + // Processing 16 elements per loop, 2 FMAs + for ( ; ( i + 15 ) < n; i += 16 ) + { + // Loading input values + yv[0].v = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); + yv[1].v = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); + + // y := beta * y + x + yv[0].v = _mm256_fmadd_ps + ( + betav.v, + _mm256_loadu_ps( y0 + 0*n_elem_per_reg ), + yv[0].v + ); + yv[1].v = _mm256_fmadd_ps + ( + betav.v, + _mm256_loadu_ps( y0 + 1*n_elem_per_reg ), + yv[1].v + ); + + // Storing the output + _mm256_storeu_ps( ( y0 + 0*n_elem_per_reg ), yv[0].v ); + _mm256_storeu_ps( ( y0 + 1*n_elem_per_reg ), yv[1].v ); + + + x0 += 2 * n_elem_per_reg; + y0 += 2 * n_elem_per_reg; + } + + // Processing 8 elements per loop, 1 FMA + for ( ; ( i + 7 ) < n; i += 8 ) + { + // Loading input values + yv[0].v = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); + + // y := beta * y + x + yv[0].v = _mm256_fmadd_ps + ( + betav.v, + _mm256_loadu_ps( y0 + 0*n_elem_per_reg ), + yv[0].v + ); + + // Storing the output + _mm256_storeu_ps( ( y0 + 0*n_elem_per_reg ), yv[0].v ); + + x0 += 1 * n_elem_per_reg; + y0 += 1 * n_elem_per_reg; + } + } + else + { + // Broadcasting alpha onto a YMM register + alphav.v = _mm256_broadcast_ss( alpha ); + + // Processing 80 elements per loop, 10 FMAs and MULs + for ( i = 0; ( i + 79 ) < n; i += 80 ) + { + // loading input values + yv[0].v = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); + yv[1].v = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); + yv[2].v = _mm256_loadu_ps( y0 + 2*n_elem_per_reg ); + yv[3].v = _mm256_loadu_ps( y0 + 3*n_elem_per_reg ); + yv[4].v = _mm256_loadu_ps( y0 + 4*n_elem_per_reg ); + yv[5].v = _mm256_loadu_ps( y0 + 5*n_elem_per_reg ); + yv[6].v = _mm256_loadu_ps( y0 + 6*n_elem_per_reg ); + yv[7].v = _mm256_loadu_ps( y0 + 7*n_elem_per_reg ); + yv[8].v = _mm256_loadu_ps( y0 + 8*n_elem_per_reg ); + yv[9].v = _mm256_loadu_ps( y0 + 9*n_elem_per_reg ); + + // y' := beta * y + yv[0].v = _mm256_mul_ps( betav.v, yv[0].v ); + yv[1].v = _mm256_mul_ps( betav.v, yv[1].v ); + yv[2].v = _mm256_mul_ps( betav.v, yv[2].v ); + yv[3].v = _mm256_mul_ps( betav.v, yv[3].v ); + yv[4].v = _mm256_mul_ps( betav.v, yv[4].v ); + yv[5].v = _mm256_mul_ps( betav.v, yv[5].v ); + yv[6].v = _mm256_mul_ps( betav.v, yv[6].v ); + yv[7].v = _mm256_mul_ps( betav.v, yv[7].v ); + yv[8].v = _mm256_mul_ps( betav.v, yv[8].v ); + yv[9].v = _mm256_mul_ps( betav.v, yv[9].v ); + + // y := y' + alpha * x + yv[0].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 0*n_elem_per_reg ), + yv[0].v + ); + yv[1].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 1*n_elem_per_reg ), + yv[1].v + ); + yv[2].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 2*n_elem_per_reg ), + yv[2].v + ); + yv[3].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 3*n_elem_per_reg ), + yv[3].v + ); + yv[4].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 4*n_elem_per_reg ), + yv[4].v + ); + yv[5].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 5*n_elem_per_reg ), + yv[5].v + ); + yv[6].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 6*n_elem_per_reg ), + yv[6].v + ); + yv[7].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 7*n_elem_per_reg ), + yv[7].v + ); + yv[8].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 8*n_elem_per_reg ), + yv[8].v + ); + yv[9].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 9*n_elem_per_reg ), + yv[9].v + ); + + // storing the output + _mm256_storeu_ps( ( y0 + 0*n_elem_per_reg ), yv[0].v ); + _mm256_storeu_ps( ( y0 + 1*n_elem_per_reg ), yv[1].v ); + _mm256_storeu_ps( ( y0 + 2*n_elem_per_reg ), yv[2].v ); + _mm256_storeu_ps( ( y0 + 3*n_elem_per_reg ), yv[3].v ); + _mm256_storeu_ps( ( y0 + 4*n_elem_per_reg ), yv[4].v ); + _mm256_storeu_ps( ( y0 + 5*n_elem_per_reg ), yv[5].v ); + _mm256_storeu_ps( ( y0 + 6*n_elem_per_reg ), yv[6].v ); + _mm256_storeu_ps( ( y0 + 7*n_elem_per_reg ), yv[7].v ); + _mm256_storeu_ps( ( y0 + 8*n_elem_per_reg ), yv[8].v ); + _mm256_storeu_ps( ( y0 + 9*n_elem_per_reg ), yv[9].v ); + + x0 += 10 * n_elem_per_reg; + y0 += 10 * n_elem_per_reg; + } + + // Processing 40 elements per loop, 5 FMAs and MULs + for ( ; ( i + 39 ) < n; i += 40 ) + { + // loading input values + yv[0].v = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); + yv[1].v = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); + yv[2].v = _mm256_loadu_ps( y0 + 2*n_elem_per_reg ); + yv[3].v = _mm256_loadu_ps( y0 + 3*n_elem_per_reg ); + yv[4].v = _mm256_loadu_ps( y0 + 4*n_elem_per_reg ); + + // y' := beta * y + yv[0].v = _mm256_mul_ps( betav.v, yv[0].v ); + yv[1].v = _mm256_mul_ps( betav.v, yv[1].v ); + yv[2].v = _mm256_mul_ps( betav.v, yv[2].v ); + yv[3].v = _mm256_mul_ps( betav.v, yv[3].v ); + yv[4].v = _mm256_mul_ps( betav.v, yv[4].v ); + + // y := y' + alpha * x + yv[0].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 0*n_elem_per_reg ), + yv[0].v + ); + yv[1].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 1*n_elem_per_reg ), + yv[1].v + ); + yv[2].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 2*n_elem_per_reg ), + yv[2].v + ); + yv[3].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 3*n_elem_per_reg ), + yv[3].v + ); + yv[4].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 4*n_elem_per_reg ), + yv[4].v + ); + + // storing the output + _mm256_storeu_ps( ( y0 + 0*n_elem_per_reg ), yv[0].v ); + _mm256_storeu_ps( ( y0 + 1*n_elem_per_reg ), yv[1].v ); + _mm256_storeu_ps( ( y0 + 2*n_elem_per_reg ), yv[2].v ); + _mm256_storeu_ps( ( y0 + 3*n_elem_per_reg ), yv[3].v ); + _mm256_storeu_ps( ( y0 + 4*n_elem_per_reg ), yv[4].v ); + + x0 += 5 * n_elem_per_reg; + y0 += 5 * n_elem_per_reg; + } + + // Processing 32 elements per loop, 4 FMAs and MULs + for ( ; ( i + 31 ) < n; i += 32 ) + { + // loading input values + yv[0].v = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); + yv[1].v = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); + yv[2].v = _mm256_loadu_ps( y0 + 2*n_elem_per_reg ); + yv[3].v = _mm256_loadu_ps( y0 + 3*n_elem_per_reg ); + + // y' := beta * y + yv[0].v = _mm256_mul_ps( betav.v, yv[0].v ); + yv[1].v = _mm256_mul_ps( betav.v, yv[1].v ); + yv[2].v = _mm256_mul_ps( betav.v, yv[2].v ); + yv[3].v = _mm256_mul_ps( betav.v, yv[3].v ); + + // y := y' + alpha * x + yv[0].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 0*n_elem_per_reg ), + yv[0].v + ); + yv[1].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 1*n_elem_per_reg ), + yv[1].v + ); + yv[2].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 2*n_elem_per_reg ), + yv[2].v + ); + yv[3].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 3*n_elem_per_reg ), + yv[3].v + ); + + // storing the output + _mm256_storeu_ps( ( y0 + 0*n_elem_per_reg ), yv[0].v ); + _mm256_storeu_ps( ( y0 + 1*n_elem_per_reg ), yv[1].v ); + _mm256_storeu_ps( ( y0 + 2*n_elem_per_reg ), yv[2].v ); + _mm256_storeu_ps( ( y0 + 3*n_elem_per_reg ), yv[3].v ); + + x0 += 4 * n_elem_per_reg; + y0 += 4 * n_elem_per_reg; + } + + // Processing 16 elements per loop, 2 FMAs and MULs + for ( ; ( i + 15 ) < n; i += 16 ) + { + // loading input values + yv[0].v = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); + yv[1].v = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); + + // y' := beta * y + yv[0].v = _mm256_mul_ps( betav.v, yv[0].v ); + yv[1].v = _mm256_mul_ps( betav.v, yv[1].v ); + + // y := y' + alpha * x + yv[0].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 0*n_elem_per_reg ), + yv[0].v + ); + yv[1].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 1*n_elem_per_reg ), + yv[1].v + ); + + // storing the output + _mm256_storeu_ps( ( y0 + 0*n_elem_per_reg ), yv[0].v ); + _mm256_storeu_ps( ( y0 + 1*n_elem_per_reg ), yv[1].v ); + + x0 += 2 * n_elem_per_reg; + y0 += 2 * n_elem_per_reg; + } + + // Processing 8 elements per loop, 1 FMA and MUL + for ( ; ( i + 7 ) < n; i += 8 ) + { + // loading input values + yv[0].v = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); + + // y' := y := beta * y + yv[0].v = _mm256_mul_ps( betav.v, yv[0].v ); + + // y := y' + alpha * x + yv[0].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 0*n_elem_per_reg ), + yv[0].v + ); + + // storing the output + _mm256_storeu_ps( ( y0 + 0*n_elem_per_reg ), yv[0].v ); + + x0 += 1 * n_elem_per_reg; + y0 += 1 * n_elem_per_reg; + } } // Issue vzeroupper instruction to clear upper lanes of ymm registers. @@ -378,11 +700,13 @@ void bli_saxpbyv_zen_int10 // as the n_left cleanup loop below if BLIS is compiled with // -mfpmath=sse). _mm256_zeroupper(); + } - // if there are leftover iterations, perform them with scaler code - for ( ; i < n; i++ ) + if( is_alpha_one ) + { + for ( ; i < n; ++i ) { - *y0 = ( (*alpha) * (*x0) ) + ( (*beta) * (*y0) ); + *y0 = (*beta) * (*y0) + (*x0); x0 += incx; y0 += incy; @@ -390,15 +714,15 @@ void bli_saxpbyv_zen_int10 } else { - // for non-unit increments, use scaler code - for ( i = 0; i < n; ++i ) + for ( ; i < n; ++i ) { - *y0 = ( (*alpha) * (*x0) ) + ( (*beta) * (*y0) ); + *y0 = (*beta) * (*y0) + (*alpha) * (*x0); x0 += incx; y0 += incy; } } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) } @@ -407,7 +731,7 @@ void bli_saxpbyv_zen_int10 * y := beta * y + alpha * conjx(x) * where, * x & y are double precision vectors of length n. - * alpha & beta are scalers. + * alpha & beta are scalars. */ void bli_daxpbyv_zen_int10 ( @@ -421,261 +745,629 @@ void bli_daxpbyv_zen_int10 ) { AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_4) - const dim_t n_elem_per_reg = 4; // number of elements per register - const dim_t n_iter_unroll = 10; // number of registers per iteration - dim_t i; // iterator - - double* restrict x0; - double* restrict y0; + // Redirecting to other L1 kernels based on alpha and beta values + // If alpha is 0, we call DSCALV + // This kernel would further reroute based on few other combinations + // of alpha and beta. They are as follows : + // When alpha = 0 : + // When beta = 0 --> DSETV + // When beta = 1 --> Early return + // When beta = !( 0 or 1 ) --> DSCALV + if ( bli_deq0( *alpha ) ) + { + bli_dscalv_zen_int10 + ( + BLIS_NO_CONJUGATE, + n, + beta, + y, incy, + cntx + ); - v4df_t alphav; - v4df_t betav; - v4df_t y0v, y1v, y2v, y3v, y4v, y5v, y6v, y7v, y8v, y9v; + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) + return; + } - /* if the vector dimension is zero, or if alpha & beta are zero, - return early. */ - if ( bli_zero_dim1( n ) || - ( PASTEMAC( s, eq0 )( *alpha ) && PASTEMAC( s, eq0 )( *beta ) ) ) + // If beta is 0, we call DSCAL2V + // This kernel would further reroute based on few other combinations + // of alpha and beta. They are as follows : + // When beta = 0 : + // When alpha = 0 --> DSETV + // When alpha = 1 --> DCOPYV + // When alpha = !( 0 or 1 ) --> DSCAL2V + else if ( bli_deq0( *beta ) ) { + bli_dscal2v_zen_int + ( + conjx, + n, + alpha, + x, incx, + y, incy, + cntx + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) return; } - // initialize local pointers - x0 = x; - y0 = y; - - if ( incx == 1 && incy == 1 ) + // If beta is 1, we have 2 scenarios for rerouting + // When alpha = 1 --> DADDV + // When alpha = !( 0 or 1 ) --> DAXPYV + else if ( bli_deq1( *beta ) ) { - // broadcast alpha & beta to all elements of respective vector registers - alphav.v = _mm256_broadcast_sd( alpha ); - betav.v = _mm256_broadcast_sd( beta ); - - // Using 10 FMAs per loop - for ( i = 0; ( i + 39 ) < n; i += 40 ) + if( bli_deq1( *alpha ) ) { - // loading input y - y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); - y1v.v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); - y2v.v = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); - y3v.v = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); - y4v.v = _mm256_loadu_pd( y0 + 4*n_elem_per_reg ); - y5v.v = _mm256_loadu_pd( y0 + 5*n_elem_per_reg ); - y6v.v = _mm256_loadu_pd( y0 + 6*n_elem_per_reg ); - y7v.v = _mm256_loadu_pd( y0 + 7*n_elem_per_reg ); - y8v.v = _mm256_loadu_pd( y0 + 8*n_elem_per_reg ); - y9v.v = _mm256_loadu_pd( y0 + 9*n_elem_per_reg ); - - // y' := y := beta * y - y0v.v = _mm256_mul_pd( betav.v, y0v.v ); - y1v.v = _mm256_mul_pd( betav.v, y1v.v ); - y2v.v = _mm256_mul_pd( betav.v, y2v.v ); - y3v.v = _mm256_mul_pd( betav.v, y3v.v ); - y4v.v = _mm256_mul_pd( betav.v, y4v.v ); - y5v.v = _mm256_mul_pd( betav.v, y5v.v ); - y6v.v = _mm256_mul_pd( betav.v, y6v.v ); - y7v.v = _mm256_mul_pd( betav.v, y7v.v ); - y8v.v = _mm256_mul_pd( betav.v, y8v.v ); - y9v.v = _mm256_mul_pd( betav.v, y9v.v ); - - // y := y' + alpha * x - // := beta * y + alpha * x - y0v.v = _mm256_fmadd_pd - ( - alphav.v, - _mm256_loadu_pd( x0 + 0*n_elem_per_reg ), - y0v.v - ); - y1v.v = _mm256_fmadd_pd - ( - alphav.v, - _mm256_loadu_pd( x0 + 1*n_elem_per_reg ), - y1v.v - ); - y2v.v = _mm256_fmadd_pd - ( - alphav.v, - _mm256_loadu_pd( x0 + 2*n_elem_per_reg ), - y2v.v - ); - y3v.v = _mm256_fmadd_pd - ( - alphav.v, - _mm256_loadu_pd( x0 + 3*n_elem_per_reg ), - y3v.v - ); - y4v.v = _mm256_fmadd_pd - ( - alphav.v, - _mm256_loadu_pd( x0 + 4*n_elem_per_reg ), - y4v.v - ); - y5v.v = _mm256_fmadd_pd - ( - alphav.v, - _mm256_loadu_pd( x0 + 5*n_elem_per_reg ), - y5v.v - ); - y6v.v = _mm256_fmadd_pd - ( - alphav.v, - _mm256_loadu_pd( x0 + 6*n_elem_per_reg ), - y6v.v - ); - y7v.v = _mm256_fmadd_pd - ( - alphav.v, - _mm256_loadu_pd( x0 + 7*n_elem_per_reg ), - y7v.v - ); - y8v.v = _mm256_fmadd_pd - ( - alphav.v, - _mm256_loadu_pd( x0 + 8*n_elem_per_reg ), - y8v.v - ); - y9v.v = _mm256_fmadd_pd - ( - alphav.v, - _mm256_loadu_pd( x0 + 9*n_elem_per_reg ), - y9v.v - ); - - // storing the output - _mm256_storeu_pd( ( y0 + 0*n_elem_per_reg ), y0v.v ); - _mm256_storeu_pd( ( y0 + 1*n_elem_per_reg ), y1v.v ); - _mm256_storeu_pd( ( y0 + 2*n_elem_per_reg ), y2v.v ); - _mm256_storeu_pd( ( y0 + 3*n_elem_per_reg ), y3v.v ); - _mm256_storeu_pd( ( y0 + 4*n_elem_per_reg ), y4v.v ); - _mm256_storeu_pd( ( y0 + 5*n_elem_per_reg ), y5v.v ); - _mm256_storeu_pd( ( y0 + 6*n_elem_per_reg ), y6v.v ); - _mm256_storeu_pd( ( y0 + 7*n_elem_per_reg ), y7v.v ); - _mm256_storeu_pd( ( y0 + 8*n_elem_per_reg ), y8v.v ); - _mm256_storeu_pd( ( y0 + 9*n_elem_per_reg ), y9v.v ); - - x0 += n_elem_per_reg * n_iter_unroll; - y0 += n_elem_per_reg * n_iter_unroll; + bli_daddv_zen_int + ( + conjx, + n, + x, incx, + y, incy, + cntx + ); } - - // Using 5 FMAs per loop - for ( ; ( i + 19 ) < n; i += 20 ) + else { - // loading input y - y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); - y1v.v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); - y2v.v = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); - y3v.v = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); - y4v.v = _mm256_loadu_pd( y0 + 4*n_elem_per_reg ); - - // y' := y := beta * y - y0v.v = _mm256_mul_pd( betav.v, y0v.v ); - y1v.v = _mm256_mul_pd( betav.v, y1v.v ); - y2v.v = _mm256_mul_pd( betav.v, y2v.v ); - y3v.v = _mm256_mul_pd( betav.v, y3v.v ); - y4v.v = _mm256_mul_pd( betav.v, y4v.v ); - - // y := y' + alpha * x - // := beta * y + alpha * x - y0v.v = _mm256_fmadd_pd - ( - alphav.v, - _mm256_loadu_pd( x0 + 0*n_elem_per_reg ), - y0v.v - ); - y1v.v = _mm256_fmadd_pd - ( - alphav.v, - _mm256_loadu_pd( x0 + 1*n_elem_per_reg ), - y1v.v - ); - y2v.v = _mm256_fmadd_pd - ( - alphav.v, - _mm256_loadu_pd( x0 + 2*n_elem_per_reg ), - y2v.v - ); - y3v.v = _mm256_fmadd_pd - ( - alphav.v, - _mm256_loadu_pd( x0 + 3*n_elem_per_reg ), - y3v.v - ); - y4v.v = _mm256_fmadd_pd - ( - alphav.v, - _mm256_loadu_pd( x0 + 4*n_elem_per_reg ), - y4v.v - ); - - // storing the output - _mm256_storeu_pd( ( y0 + 0*n_elem_per_reg ), y0v.v ); - _mm256_storeu_pd( ( y0 + 1*n_elem_per_reg ), y1v.v ); - _mm256_storeu_pd( ( y0 + 2*n_elem_per_reg ), y2v.v ); - _mm256_storeu_pd( ( y0 + 3*n_elem_per_reg ), y3v.v ); - _mm256_storeu_pd( ( y0 + 4*n_elem_per_reg ), y4v.v ); - - x0 += n_elem_per_reg * 5; - y0 += n_elem_per_reg * 5; + bli_daxpyv_zen_int + ( + conjx, + n, + alpha, + x, incx, + y, incy, + cntx + ); } - // Using 2 FMAs per loop - for ( ; ( i + 7 ) < n; i += 8 ) + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) + return; + } + + const dim_t n_elem_per_reg = 4; // number of elements per register + + dim_t i = 0; // iterator + + double* restrict x0; + double* restrict y0; + + v4df_t alphav; + v4df_t betav; + v4df_t yv[10]; + + bool is_alpha_one = bli_seq1( *alpha ); + + // initialize local pointers + x0 = x; + y0 = y; + + if( incx == 1 && incy == 1 ) + { + // Broadcasting beta onto a YMM register + betav.v = _mm256_broadcast_sd( beta ); + + if( is_alpha_one ) // Scale y with beta and add x to it { - // loading input y - y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); - y1v.v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); - - // y' := y := beta * y - y0v.v = _mm256_mul_pd( betav.v, y0v.v ); - y1v.v = _mm256_mul_pd( betav.v, y1v.v ); - - // y := y' + alpha * x - // := beta * y + alpha * x - y0v.v = _mm256_fmadd_pd - ( - alphav.v, - _mm256_loadu_pd( x0 + 0*n_elem_per_reg ), - y0v.v - ); - y1v.v = _mm256_fmadd_pd - ( - alphav.v, - _mm256_loadu_pd( x0 + 1*n_elem_per_reg ), - y1v.v - ); - - // storing the output - _mm256_storeu_pd( ( y0 + 0*n_elem_per_reg ), y0v.v ); - _mm256_storeu_pd( ( y0 + 1*n_elem_per_reg ), y1v.v ); - - x0 += n_elem_per_reg * 2; - y0 += n_elem_per_reg * 2; + // Processing 40 elements per loop, 10 FMAs + for ( ; ( i + 39 ) < n; i += 40 ) + { + // Loading input values + yv[0].v = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + yv[1].v = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); + yv[2].v = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); + yv[3].v = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); + yv[4].v = _mm256_loadu_pd( x0 + 4*n_elem_per_reg ); + yv[5].v = _mm256_loadu_pd( x0 + 5*n_elem_per_reg ); + yv[6].v = _mm256_loadu_pd( x0 + 6*n_elem_per_reg ); + yv[7].v = _mm256_loadu_pd( x0 + 7*n_elem_per_reg ); + yv[8].v = _mm256_loadu_pd( x0 + 8*n_elem_per_reg ); + yv[9].v = _mm256_loadu_pd( x0 + 9*n_elem_per_reg ); + + // y := beta * y + x + yv[0].v = _mm256_fmadd_pd + ( + betav.v, + _mm256_loadu_pd( y0 + 0*n_elem_per_reg ), + yv[0].v + ); + yv[1].v = _mm256_fmadd_pd + ( + betav.v, + _mm256_loadu_pd( y0 + 1*n_elem_per_reg ), + yv[1].v + ); + yv[2].v = _mm256_fmadd_pd + ( + betav.v, + _mm256_loadu_pd( y0 + 2*n_elem_per_reg ), + yv[2].v + ); + yv[3].v = _mm256_fmadd_pd + ( + betav.v, + _mm256_loadu_pd( y0 + 3*n_elem_per_reg ), + yv[3].v + ); + yv[4].v = _mm256_fmadd_pd + ( + betav.v, + _mm256_loadu_pd( y0 + 4*n_elem_per_reg ), + yv[4].v + ); + yv[5].v = _mm256_fmadd_pd + ( + betav.v, + _mm256_loadu_pd( y0 + 5*n_elem_per_reg ), + yv[5].v + ); + yv[6].v = _mm256_fmadd_pd + ( + betav.v, + _mm256_loadu_pd( y0 + 6*n_elem_per_reg ), + yv[6].v + ); + yv[7].v = _mm256_fmadd_pd + ( + betav.v, + _mm256_loadu_pd( y0 + 7*n_elem_per_reg ), + yv[7].v + ); + yv[8].v = _mm256_fmadd_pd + ( + betav.v, + _mm256_loadu_pd( y0 + 8*n_elem_per_reg ), + yv[8].v + ); + yv[9].v = _mm256_fmadd_pd + ( + betav.v, + _mm256_loadu_pd( y0 + 9*n_elem_per_reg ), + yv[9].v + ); + + // Storing the output + _mm256_storeu_pd( ( y0 + 0*n_elem_per_reg ), yv[0].v ); + _mm256_storeu_pd( ( y0 + 1*n_elem_per_reg ), yv[1].v ); + _mm256_storeu_pd( ( y0 + 2*n_elem_per_reg ), yv[2].v ); + _mm256_storeu_pd( ( y0 + 3*n_elem_per_reg ), yv[3].v ); + _mm256_storeu_pd( ( y0 + 4*n_elem_per_reg ), yv[4].v ); + _mm256_storeu_pd( ( y0 + 5*n_elem_per_reg ), yv[5].v ); + _mm256_storeu_pd( ( y0 + 6*n_elem_per_reg ), yv[6].v ); + _mm256_storeu_pd( ( y0 + 7*n_elem_per_reg ), yv[7].v ); + _mm256_storeu_pd( ( y0 + 8*n_elem_per_reg ), yv[8].v ); + _mm256_storeu_pd( ( y0 + 9*n_elem_per_reg ), yv[9].v ); + + x0 += 10 * n_elem_per_reg; + y0 += 10 * n_elem_per_reg; + } + + // Processing 20 elements per loop, 5 FMAs + for ( ; ( i + 19 ) < n; i += 20 ) + { + // Loading input values + yv[0].v = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + yv[1].v = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); + yv[2].v = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); + yv[3].v = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); + yv[4].v = _mm256_loadu_pd( x0 + 4*n_elem_per_reg ); + + // y := beta * y + x + yv[0].v = _mm256_fmadd_pd + ( + betav.v, + _mm256_loadu_pd( y0 + 0*n_elem_per_reg ), + yv[0].v + ); + yv[1].v = _mm256_fmadd_pd + ( + betav.v, + _mm256_loadu_pd( y0 + 1*n_elem_per_reg ), + yv[1].v + ); + yv[2].v = _mm256_fmadd_pd + ( + betav.v, + _mm256_loadu_pd( y0 + 2*n_elem_per_reg ), + yv[2].v + ); + yv[3].v = _mm256_fmadd_pd + ( + betav.v, + _mm256_loadu_pd( y0 + 3*n_elem_per_reg ), + yv[3].v + ); + yv[4].v = _mm256_fmadd_pd + ( + betav.v, + _mm256_loadu_pd( y0 + 4*n_elem_per_reg ), + yv[4].v + ); + + // Storing the output + _mm256_storeu_pd( ( y0 + 0*n_elem_per_reg ), yv[0].v ); + _mm256_storeu_pd( ( y0 + 1*n_elem_per_reg ), yv[1].v ); + _mm256_storeu_pd( ( y0 + 2*n_elem_per_reg ), yv[2].v ); + _mm256_storeu_pd( ( y0 + 3*n_elem_per_reg ), yv[3].v ); + _mm256_storeu_pd( ( y0 + 4*n_elem_per_reg ), yv[4].v ); + + x0 += 5 * n_elem_per_reg; + y0 += 5 * n_elem_per_reg; + } + + // Processing 16 elements per loop, 4 FMAs + for ( ; ( i + 15 ) < n; i += 16 ) + { + // Loading input values + yv[0].v = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + yv[1].v = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); + yv[2].v = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); + yv[3].v = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); + + // y := beta * y + x + yv[0].v = _mm256_fmadd_pd + ( + betav.v, + _mm256_loadu_pd( y0 + 0*n_elem_per_reg ), + yv[0].v + ); + yv[1].v = _mm256_fmadd_pd + ( + betav.v, + _mm256_loadu_pd( y0 + 1*n_elem_per_reg ), + yv[1].v + ); + yv[2].v = _mm256_fmadd_pd + ( + betav.v, + _mm256_loadu_pd( y0 + 2*n_elem_per_reg ), + yv[2].v + ); + yv[3].v = _mm256_fmadd_pd + ( + betav.v, + _mm256_loadu_pd( y0 + 3*n_elem_per_reg ), + yv[3].v + ); + + // Storing the output + _mm256_storeu_pd( ( y0 + 0*n_elem_per_reg ), yv[0].v ); + _mm256_storeu_pd( ( y0 + 1*n_elem_per_reg ), yv[1].v ); + _mm256_storeu_pd( ( y0 + 2*n_elem_per_reg ), yv[2].v ); + _mm256_storeu_pd( ( y0 + 3*n_elem_per_reg ), yv[3].v ); + + x0 += 4 * n_elem_per_reg; + y0 += 4 * n_elem_per_reg; + } + + // Processing 8 elements per loop, 2 FMAs + for ( ; ( i + 7 ) < n; i += 8 ) + { + // Loading input values + yv[0].v = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + yv[1].v = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); + + // y := beta * y + x + yv[0].v = _mm256_fmadd_pd + ( + betav.v, + _mm256_loadu_pd( y0 + 0*n_elem_per_reg ), + yv[0].v + ); + yv[1].v = _mm256_fmadd_pd + ( + betav.v, + _mm256_loadu_pd( y0 + 1*n_elem_per_reg ), + yv[1].v + ); + + // Storing the output + _mm256_storeu_pd( ( y0 + 0*n_elem_per_reg ), yv[0].v ); + _mm256_storeu_pd( ( y0 + 1*n_elem_per_reg ), yv[1].v ); + + + x0 += 2 * n_elem_per_reg; + y0 += 2 * n_elem_per_reg; + } + + // Processing 4 elements per loop, 1 FMA + for ( ; ( i + 3 ) < n; i += 4 ) + { + // Loading input values + yv[0].v = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + + // y := beta * y + x + yv[0].v = _mm256_fmadd_pd + ( + betav.v, + _mm256_loadu_pd( y0 + 0*n_elem_per_reg ), + yv[0].v + ); + + // Storing the output + _mm256_storeu_pd( ( y0 + 0*n_elem_per_reg ), yv[0].v ); + + x0 += 1 * n_elem_per_reg; + y0 += 1 * n_elem_per_reg; + } } - - // Using 1 FMAs per loop - for ( ; ( i + 3 ) < n; i += 4 ) + else { - // loading input y - y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); - - // y' := y := beta * y - y0v.v = _mm256_mul_pd( betav.v, y0v.v ); - - // y := y' + alpha * x - // := beta * y + alpha * x - y0v.v = _mm256_fmadd_pd - ( - alphav.v, - _mm256_loadu_pd( x0 + 0*n_elem_per_reg ), - y0v.v - ); - - // storing the output - _mm256_storeu_pd( ( y0 + 0*n_elem_per_reg ), y0v.v ); - - x0 += n_elem_per_reg * 1; - y0 += n_elem_per_reg * 1; + // Broadcasting alpha onto a YMM register + alphav.v = _mm256_broadcast_sd( alpha ); + + // Processing 40 elements per loop, 10 FMAs and MULs + for ( i = 0; ( i + 39 ) < n; i += 40 ) + { + // loading input values + yv[0].v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1].v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + yv[2].v = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); + yv[3].v = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); + yv[4].v = _mm256_loadu_pd( y0 + 4*n_elem_per_reg ); + yv[5].v = _mm256_loadu_pd( y0 + 5*n_elem_per_reg ); + yv[6].v = _mm256_loadu_pd( y0 + 6*n_elem_per_reg ); + yv[7].v = _mm256_loadu_pd( y0 + 7*n_elem_per_reg ); + yv[8].v = _mm256_loadu_pd( y0 + 8*n_elem_per_reg ); + yv[9].v = _mm256_loadu_pd( y0 + 9*n_elem_per_reg ); + + // y' := beta * y + yv[0].v = _mm256_mul_pd( betav.v, yv[0].v ); + yv[1].v = _mm256_mul_pd( betav.v, yv[1].v ); + yv[2].v = _mm256_mul_pd( betav.v, yv[2].v ); + yv[3].v = _mm256_mul_pd( betav.v, yv[3].v ); + yv[4].v = _mm256_mul_pd( betav.v, yv[4].v ); + yv[5].v = _mm256_mul_pd( betav.v, yv[5].v ); + yv[6].v = _mm256_mul_pd( betav.v, yv[6].v ); + yv[7].v = _mm256_mul_pd( betav.v, yv[7].v ); + yv[8].v = _mm256_mul_pd( betav.v, yv[8].v ); + yv[9].v = _mm256_mul_pd( betav.v, yv[9].v ); + + // y := y' + alpha * x + yv[0].v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 0*n_elem_per_reg ), + yv[0].v + ); + yv[1].v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 1*n_elem_per_reg ), + yv[1].v + ); + yv[2].v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 2*n_elem_per_reg ), + yv[2].v + ); + yv[3].v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 3*n_elem_per_reg ), + yv[3].v + ); + yv[4].v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 4*n_elem_per_reg ), + yv[4].v + ); + yv[5].v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 5*n_elem_per_reg ), + yv[5].v + ); + yv[6].v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 6*n_elem_per_reg ), + yv[6].v + ); + yv[7].v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 7*n_elem_per_reg ), + yv[7].v + ); + yv[8].v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 8*n_elem_per_reg ), + yv[8].v + ); + yv[9].v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 9*n_elem_per_reg ), + yv[9].v + ); + + // storing the output + _mm256_storeu_pd( ( y0 + 0*n_elem_per_reg ), yv[0].v ); + _mm256_storeu_pd( ( y0 + 1*n_elem_per_reg ), yv[1].v ); + _mm256_storeu_pd( ( y0 + 2*n_elem_per_reg ), yv[2].v ); + _mm256_storeu_pd( ( y0 + 3*n_elem_per_reg ), yv[3].v ); + _mm256_storeu_pd( ( y0 + 4*n_elem_per_reg ), yv[4].v ); + _mm256_storeu_pd( ( y0 + 5*n_elem_per_reg ), yv[5].v ); + _mm256_storeu_pd( ( y0 + 6*n_elem_per_reg ), yv[6].v ); + _mm256_storeu_pd( ( y0 + 7*n_elem_per_reg ), yv[7].v ); + _mm256_storeu_pd( ( y0 + 8*n_elem_per_reg ), yv[8].v ); + _mm256_storeu_pd( ( y0 + 9*n_elem_per_reg ), yv[9].v ); + + x0 += 10 * n_elem_per_reg; + y0 += 10 * n_elem_per_reg; + } + + // Processing 20 elements per loop, 5 FMAs and MULs + for ( ; ( i + 19 ) < n; i += 20 ) + { + // loading input values + yv[0].v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1].v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + yv[2].v = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); + yv[3].v = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); + yv[4].v = _mm256_loadu_pd( y0 + 4*n_elem_per_reg ); + + // y' := beta * y + yv[0].v = _mm256_mul_pd( betav.v, yv[0].v ); + yv[1].v = _mm256_mul_pd( betav.v, yv[1].v ); + yv[2].v = _mm256_mul_pd( betav.v, yv[2].v ); + yv[3].v = _mm256_mul_pd( betav.v, yv[3].v ); + yv[4].v = _mm256_mul_pd( betav.v, yv[4].v ); + + // y := y' + alpha * x + yv[0].v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 0*n_elem_per_reg ), + yv[0].v + ); + yv[1].v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 1*n_elem_per_reg ), + yv[1].v + ); + yv[2].v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 2*n_elem_per_reg ), + yv[2].v + ); + yv[3].v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 3*n_elem_per_reg ), + yv[3].v + ); + yv[4].v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 4*n_elem_per_reg ), + yv[4].v + ); + + // storing the output + _mm256_storeu_pd( ( y0 + 0*n_elem_per_reg ), yv[0].v ); + _mm256_storeu_pd( ( y0 + 1*n_elem_per_reg ), yv[1].v ); + _mm256_storeu_pd( ( y0 + 2*n_elem_per_reg ), yv[2].v ); + _mm256_storeu_pd( ( y0 + 3*n_elem_per_reg ), yv[3].v ); + _mm256_storeu_pd( ( y0 + 4*n_elem_per_reg ), yv[4].v ); + + x0 += 5 * n_elem_per_reg; + y0 += 5 * n_elem_per_reg; + } + + // Processing 16 elements per loop, 4 FMAs and MULs + for ( ; ( i + 15 ) < n; i += 16 ) + { + // loading input values + yv[0].v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1].v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + yv[2].v = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); + yv[3].v = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); + + // y' := beta * y + yv[0].v = _mm256_mul_pd( betav.v, yv[0].v ); + yv[1].v = _mm256_mul_pd( betav.v, yv[1].v ); + yv[2].v = _mm256_mul_pd( betav.v, yv[2].v ); + yv[3].v = _mm256_mul_pd( betav.v, yv[3].v ); + + // y := y' + alpha * x + yv[0].v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 0*n_elem_per_reg ), + yv[0].v + ); + yv[1].v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 1*n_elem_per_reg ), + yv[1].v + ); + yv[2].v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 2*n_elem_per_reg ), + yv[2].v + ); + yv[3].v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 3*n_elem_per_reg ), + yv[3].v + ); + + // storing the output + _mm256_storeu_pd( ( y0 + 0*n_elem_per_reg ), yv[0].v ); + _mm256_storeu_pd( ( y0 + 1*n_elem_per_reg ), yv[1].v ); + _mm256_storeu_pd( ( y0 + 2*n_elem_per_reg ), yv[2].v ); + _mm256_storeu_pd( ( y0 + 3*n_elem_per_reg ), yv[3].v ); + + x0 += 4 * n_elem_per_reg; + y0 += 4 * n_elem_per_reg; + } + + // Processing 8 elements per loop, 2 FMAs and MULs + for ( ; ( i + 7 ) < n; i += 8 ) + { + // loading input values + yv[0].v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1].v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + + // y' := beta * y + yv[0].v = _mm256_mul_pd( betav.v, yv[0].v ); + yv[1].v = _mm256_mul_pd( betav.v, yv[1].v ); + + // y := y' + alpha * x + yv[0].v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 0*n_elem_per_reg ), + yv[0].v + ); + yv[1].v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 1*n_elem_per_reg ), + yv[1].v + ); + + // storing the output + _mm256_storeu_pd( ( y0 + 0*n_elem_per_reg ), yv[0].v ); + _mm256_storeu_pd( ( y0 + 1*n_elem_per_reg ), yv[1].v ); + + x0 += 2 * n_elem_per_reg; + y0 += 2 * n_elem_per_reg; + } + + // Processing 4 elements per loop, 1 FMA and MUL + for ( ; ( i + 3 ) < n; i += 4 ) + { + // loading input values + yv[0].v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + + // y' := y := beta * y + yv[0].v = _mm256_mul_pd( betav.v, yv[0].v ); + + // y := y' + alpha * x + yv[0].v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 0*n_elem_per_reg ), + yv[0].v + ); + + // storing the output + _mm256_storeu_pd( ( y0 + 0*n_elem_per_reg ), yv[0].v ); + + x0 += 1 * n_elem_per_reg; + y0 += 1 * n_elem_per_reg; + } } // Issue vzeroupper instruction to clear upper lanes of ymm registers. @@ -684,11 +1376,14 @@ void bli_daxpbyv_zen_int10 // as the n_left cleanup loop below if BLIS is compiled with // -mfpmath=sse). _mm256_zeroupper(); + } - // if there are leftover iterations, perform them with scaler code + // Handling fringe cases or non-unit strided inputs + if( is_alpha_one ) + { for ( ; i < n; ++i ) { - *y0 = ( (*alpha) * (*x0) ) + ( (*beta) * (*y0) ); + *y0 = (*beta) * (*y0) + (*x0); x0 += incx; y0 += incy; @@ -696,14 +1391,14 @@ void bli_daxpbyv_zen_int10 } else { - // for non-unit increments, use scaler code - for ( i = 0; i < n; ++i ) + for ( ; i < n; ++i ) { - *y0 = ( (*alpha) * (*x0) ) + ( (*beta) * (*y0) ); + *y0 = (*beta) * (*y0) + (*alpha) * (*x0); x0 += incx; y0 += incy; } } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) } diff --git a/kernels/zen/1/bli_axpyv_zen_int10.c b/kernels/zen/1/bli_axpyv_zen_int10.c index cc52b3dff7..23ae6b0ac6 100644 --- a/kernels/zen/1/bli_axpyv_zen_int10.c +++ b/kernels/zen/1/bli_axpyv_zen_int10.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2016 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2016 - 2024, Advanced Micro Devices, Inc. All rights reserved. Copyright (C) 2018 - 2020, The University of Texas at Austin. All rights reserved. Redistribution and use in source and binary forms, with or without @@ -340,7 +340,7 @@ void bli_saxpyv_zen_int10 // ----------------------------------------------------------------------------- -void bli_daxpyv_zen_int10 +BLIS_EXPORT_BLIS void bli_daxpyv_zen_int10 ( conj_t conjx, dim_t n, @@ -360,9 +360,9 @@ void bli_daxpyv_zen_int10 double* restrict y0 = y; __m256d alphav; - __m256d xv[13]; - __m256d yv[13]; - __m256d zv[13]; + __m256d xv[4]; + __m256d yv[4]; + __m256d zv[4]; // If the vector dimension is zero, or if alpha is zero, return early. if ( bli_zero_dim1( n ) || PASTEMAC(d,eq0)( *alpha ) ) @@ -380,151 +380,7 @@ void bli_daxpyv_zen_int10 // Broadcast the alpha scalar to all elements of a vector register. alphav = _mm256_broadcast_sd( alpha ); - for (i = 0; (i + 51) < n; i += 52) - { - // 52 elements will be processed per loop; 13 FMAs will run per loop. - xv[0] = _mm256_loadu_pd(x0 + 0 * n_elem_per_reg); - xv[1] = _mm256_loadu_pd(x0 + 1 * n_elem_per_reg); - xv[2] = _mm256_loadu_pd(x0 + 2 * n_elem_per_reg); - xv[3] = _mm256_loadu_pd(x0 + 3 * n_elem_per_reg); - xv[4] = _mm256_loadu_pd(x0 + 4 * n_elem_per_reg); - xv[5] = _mm256_loadu_pd(x0 + 5 * n_elem_per_reg); - xv[6] = _mm256_loadu_pd(x0 + 6 * n_elem_per_reg); - xv[7] = _mm256_loadu_pd(x0 + 7 * n_elem_per_reg); - xv[8] = _mm256_loadu_pd(x0 + 8 * n_elem_per_reg); - xv[9] = _mm256_loadu_pd(x0 + 9 * n_elem_per_reg); - xv[10] = _mm256_loadu_pd(x0 + 10 * n_elem_per_reg); - xv[11] = _mm256_loadu_pd(x0 + 11 * n_elem_per_reg); - xv[12] = _mm256_loadu_pd(x0 + 12 * n_elem_per_reg); - - yv[0] = _mm256_loadu_pd(y0 + 0 * n_elem_per_reg); - yv[1] = _mm256_loadu_pd(y0 + 1 * n_elem_per_reg); - yv[2] = _mm256_loadu_pd(y0 + 2 * n_elem_per_reg); - yv[3] = _mm256_loadu_pd(y0 + 3 * n_elem_per_reg); - yv[4] = _mm256_loadu_pd(y0 + 4 * n_elem_per_reg); - yv[5] = _mm256_loadu_pd(y0 + 5 * n_elem_per_reg); - yv[6] = _mm256_loadu_pd(y0 + 6 * n_elem_per_reg); - yv[7] = _mm256_loadu_pd(y0 + 7 * n_elem_per_reg); - yv[8] = _mm256_loadu_pd(y0 + 8 * n_elem_per_reg); - yv[9] = _mm256_loadu_pd(y0 + 9 * n_elem_per_reg); - yv[10] = _mm256_loadu_pd(y0 + 10 * n_elem_per_reg); - yv[11] = _mm256_loadu_pd(y0 + 11 * n_elem_per_reg); - yv[12] = _mm256_loadu_pd(y0 + 12 * n_elem_per_reg); - - zv[0] = _mm256_fmadd_pd(xv[0], alphav, yv[0]); - zv[1] = _mm256_fmadd_pd(xv[1], alphav, yv[1]); - zv[2] = _mm256_fmadd_pd(xv[2], alphav, yv[2]); - zv[3] = _mm256_fmadd_pd(xv[3], alphav, yv[3]); - zv[4] = _mm256_fmadd_pd(xv[4], alphav, yv[4]); - zv[5] = _mm256_fmadd_pd(xv[5], alphav, yv[5]); - zv[6] = _mm256_fmadd_pd(xv[6], alphav, yv[6]); - zv[7] = _mm256_fmadd_pd(xv[7], alphav, yv[7]); - zv[8] = _mm256_fmadd_pd(xv[8], alphav, yv[8]); - zv[9] = _mm256_fmadd_pd(xv[9], alphav, yv[9]); - zv[10] = _mm256_fmadd_pd(xv[10], alphav, yv[10]); - zv[11] = _mm256_fmadd_pd(xv[11], alphav, yv[11]); - zv[12] = _mm256_fmadd_pd(xv[12], alphav, yv[12]); - - _mm256_storeu_pd((y0 + 0 * n_elem_per_reg), zv[0]); - _mm256_storeu_pd((y0 + 1 * n_elem_per_reg), zv[1]); - _mm256_storeu_pd((y0 + 2 * n_elem_per_reg), zv[2]); - _mm256_storeu_pd((y0 + 3 * n_elem_per_reg), zv[3]); - _mm256_storeu_pd((y0 + 4 * n_elem_per_reg), zv[4]); - _mm256_storeu_pd((y0 + 5 * n_elem_per_reg), zv[5]); - _mm256_storeu_pd((y0 + 6 * n_elem_per_reg), zv[6]); - _mm256_storeu_pd((y0 + 7 * n_elem_per_reg), zv[7]); - _mm256_storeu_pd((y0 + 8 * n_elem_per_reg), zv[8]); - _mm256_storeu_pd((y0 + 9 * n_elem_per_reg), zv[9]); - _mm256_storeu_pd((y0 + 10 * n_elem_per_reg), zv[10]); - _mm256_storeu_pd((y0 + 11 * n_elem_per_reg), zv[11]); - _mm256_storeu_pd((y0 + 12 * n_elem_per_reg), zv[12]); - - x0 += 13 * n_elem_per_reg; - y0 += 13 * n_elem_per_reg; - } - - for ( ; (i + 39) < n; i += 40 ) - { - // 40 elements will be processed per loop; 10 FMAs will run per loop. - xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); - xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); - xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); - xv[3] = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); - xv[4] = _mm256_loadu_pd( x0 + 4*n_elem_per_reg ); - xv[5] = _mm256_loadu_pd( x0 + 5*n_elem_per_reg ); - xv[6] = _mm256_loadu_pd( x0 + 6*n_elem_per_reg ); - xv[7] = _mm256_loadu_pd( x0 + 7*n_elem_per_reg ); - xv[8] = _mm256_loadu_pd( x0 + 8*n_elem_per_reg ); - xv[9] = _mm256_loadu_pd( x0 + 9*n_elem_per_reg ); - - yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); - yv[1] = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); - yv[2] = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); - yv[3] = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); - yv[4] = _mm256_loadu_pd( y0 + 4*n_elem_per_reg ); - yv[5] = _mm256_loadu_pd( y0 + 5*n_elem_per_reg ); - yv[6] = _mm256_loadu_pd( y0 + 6*n_elem_per_reg ); - yv[7] = _mm256_loadu_pd( y0 + 7*n_elem_per_reg ); - yv[8] = _mm256_loadu_pd( y0 + 8*n_elem_per_reg ); - yv[9] = _mm256_loadu_pd( y0 + 9*n_elem_per_reg ); - - zv[0] = _mm256_fmadd_pd( xv[0], alphav, yv[0] ); - zv[1] = _mm256_fmadd_pd( xv[1], alphav, yv[1] ); - zv[2] = _mm256_fmadd_pd( xv[2], alphav, yv[2] ); - zv[3] = _mm256_fmadd_pd( xv[3], alphav, yv[3] ); - zv[4] = _mm256_fmadd_pd( xv[4], alphav, yv[4] ); - zv[5] = _mm256_fmadd_pd( xv[5], alphav, yv[5] ); - zv[6] = _mm256_fmadd_pd( xv[6], alphav, yv[6] ); - zv[7] = _mm256_fmadd_pd( xv[7], alphav, yv[7] ); - zv[8] = _mm256_fmadd_pd( xv[8], alphav, yv[8] ); - zv[9] = _mm256_fmadd_pd( xv[9], alphav, yv[9] ); - - _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), zv[0] ); - _mm256_storeu_pd( (y0 + 1*n_elem_per_reg), zv[1] ); - _mm256_storeu_pd( (y0 + 2*n_elem_per_reg), zv[2] ); - _mm256_storeu_pd( (y0 + 3*n_elem_per_reg), zv[3] ); - _mm256_storeu_pd( (y0 + 4*n_elem_per_reg), zv[4] ); - _mm256_storeu_pd( (y0 + 5*n_elem_per_reg), zv[5] ); - _mm256_storeu_pd( (y0 + 6*n_elem_per_reg), zv[6] ); - _mm256_storeu_pd( (y0 + 7*n_elem_per_reg), zv[7] ); - _mm256_storeu_pd( (y0 + 8*n_elem_per_reg), zv[8] ); - _mm256_storeu_pd( (y0 + 9*n_elem_per_reg), zv[9] ); - - x0 += 10*n_elem_per_reg; - y0 += 10*n_elem_per_reg; - } - - for ( ; (i + 19) < n; i += 20 ) - { - xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); - xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); - xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); - xv[3] = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); - xv[4] = _mm256_loadu_pd( x0 + 4*n_elem_per_reg ); - - yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); - yv[1] = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); - yv[2] = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); - yv[3] = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); - yv[4] = _mm256_loadu_pd( y0 + 4*n_elem_per_reg ); - - zv[0] = _mm256_fmadd_pd( xv[0], alphav, yv[0] ); - zv[1] = _mm256_fmadd_pd( xv[1], alphav, yv[1] ); - zv[2] = _mm256_fmadd_pd( xv[2], alphav, yv[2] ); - zv[3] = _mm256_fmadd_pd( xv[3], alphav, yv[3] ); - zv[4] = _mm256_fmadd_pd( xv[4], alphav, yv[4] ); - - _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), zv[0] ); - _mm256_storeu_pd( (y0 + 1*n_elem_per_reg), zv[1] ); - _mm256_storeu_pd( (y0 + 2*n_elem_per_reg), zv[2] ); - _mm256_storeu_pd( (y0 + 3*n_elem_per_reg), zv[3] ); - _mm256_storeu_pd( (y0 + 4*n_elem_per_reg), zv[4] ); - - x0 += 5*n_elem_per_reg; - y0 += 5*n_elem_per_reg; - } - - for ( ; (i + 15) < n; i += 16 ) + for ( i = 0; ( i + 15 ) < n; i += 16 ) { xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); @@ -550,7 +406,7 @@ void bli_daxpyv_zen_int10 y0 += 4*n_elem_per_reg; } - for ( ; i + 7 < n; i += 8 ) + for ( ; ( i + 7 ) < n; i += 8 ) { xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); @@ -568,7 +424,7 @@ void bli_daxpyv_zen_int10 y0 += 2*n_elem_per_reg; } - for ( ; i + 3 < n; i += 4 ) + for ( ; ( i + 3 ) < n; i += 4 ) { xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); @@ -960,7 +816,7 @@ void bli_zaxpyv_zen_int5 // Prefetch X vector to the L1 cache // as these elements will be need anyway - _mm_prefetch(x0, _MM_HINT_T1); + _mm_prefetch((char const*)x0, _MM_HINT_T1); // Broadcast the alpha scalar to all elements of a vector register. if (bli_is_noconj(conjx)) // If BLIS_NO_CONJUGATE @@ -1066,8 +922,8 @@ void bli_zaxpyv_zen_int5 xv[6] = _mm256_permute_pd(xv[6], 5); // Prefetch X and Y vectors to the L1 cache - _mm_prefetch(x0 + distance, _MM_HINT_T1); - _mm_prefetch(y0 + distance, _MM_HINT_T1); + _mm_prefetch((char const*)(x0 + distance), _MM_HINT_T1); + _mm_prefetch((char const*)(y0 + distance), _MM_HINT_T1); // alphaIv = -aI aI -aI aI // yv = ar*xv + yv diff --git a/kernels/zen/1/bli_copyv_zen_int.c b/kernels/zen/1/bli_copyv_zen_int.c index d940cefc52..bae19e01f1 100644 --- a/kernels/zen/1/bli_copyv_zen_int.c +++ b/kernels/zen/1/bli_copyv_zen_int.c @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -341,6 +341,221 @@ void bli_dcopyv_zen_int AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2) } +void bli_ccopyv_zen_int +( + conj_t conjx, + dim_t n, + scomplex* restrict x, inc_t incx, + scomplex* restrict y, inc_t incy, + cntx_t* restrict cntx +) +{ + // If the vector dimension is zero return early. + if ( bli_zero_dim1( n ) ) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2) + return; + } + + // Setting the local pointers and iterator + dim_t i = 0; + scomplex *x0 = x; + scomplex *y0 = y; + + // Handling conjugate separately + if ( bli_is_conj( conjx ) ) + { + if ( incx == 1 && incy == 1 ) + { + const dim_t n_elem_per_reg = 4; + __m256 x_vec[8]; + + __m256 conj_reg = _mm256_setr_ps(1, -1, 1, -1, 1, -1, 1, -1); + + for (; (i + 31) < n; i += 32) + { + /* 8 float values = 4 float complex values are loaded*/ + x_vec[0] = _mm256_loadu_ps((float *)x0); + x_vec[1] = _mm256_loadu_ps((float *)(x0 + n_elem_per_reg)); + x_vec[2] = _mm256_loadu_ps((float *)(x0 + 2 * n_elem_per_reg)); + x_vec[3] = _mm256_loadu_ps((float *)(x0 + 3 * n_elem_per_reg)); + x_vec[4] = _mm256_loadu_ps((float *)(x0 + 4 * n_elem_per_reg)); + x_vec[5] = _mm256_loadu_ps((float *)(x0 + 5 * n_elem_per_reg)); + x_vec[6] = _mm256_loadu_ps((float *)(x0 + 6 * n_elem_per_reg)); + x_vec[7] = _mm256_loadu_ps((float *)(x0 + 7 * n_elem_per_reg)); + + /* Perform conjugation by multiplying the imaginary + part with -1 and real part with 1*/ + x_vec[0] = _mm256_mul_ps(x_vec[0], conj_reg); + x_vec[1] = _mm256_mul_ps(x_vec[1], conj_reg); + x_vec[2] = _mm256_mul_ps(x_vec[2], conj_reg); + x_vec[3] = _mm256_mul_ps(x_vec[3], conj_reg); + x_vec[4] = _mm256_mul_ps(x_vec[4], conj_reg); + x_vec[5] = _mm256_mul_ps(x_vec[5], conj_reg); + x_vec[6] = _mm256_mul_ps(x_vec[6], conj_reg); + x_vec[7] = _mm256_mul_ps(x_vec[7], conj_reg); + + _mm256_storeu_ps((float *)y0, x_vec[0]); + _mm256_storeu_ps((float *)(y0 + n_elem_per_reg), x_vec[1]); + _mm256_storeu_ps((float *)(y0 + 2 * n_elem_per_reg), x_vec[2]); + _mm256_storeu_ps((float *)(y0 + 3 * n_elem_per_reg), x_vec[3]); + _mm256_storeu_ps((float *)(y0 + 4 * n_elem_per_reg), x_vec[4]); + _mm256_storeu_ps((float *)(y0 + 5 * n_elem_per_reg), x_vec[5]); + _mm256_storeu_ps((float *)(y0 + 6 * n_elem_per_reg), x_vec[6]); + _mm256_storeu_ps((float *)(y0 + 7 * n_elem_per_reg), x_vec[7]); + + x0 += 8 * n_elem_per_reg; + y0 += 8 * n_elem_per_reg; + } + + for (; (i + 15) < n; i += 16) + { + x_vec[0] = _mm256_loadu_ps((float *)x0); + x_vec[1] = _mm256_loadu_ps((float *)(x0 + n_elem_per_reg)); + x_vec[2] = _mm256_loadu_ps((float *)(x0 + 2 * n_elem_per_reg)); + x_vec[3] = _mm256_loadu_ps((float *)(x0 + 3 * n_elem_per_reg)); + + x_vec[0] = _mm256_mul_ps(x_vec[0], conj_reg); + x_vec[1] = _mm256_mul_ps(x_vec[1], conj_reg); + x_vec[2] = _mm256_mul_ps(x_vec[2], conj_reg); + x_vec[3] = _mm256_mul_ps(x_vec[3], conj_reg); + + x0 += 4 * n_elem_per_reg; + + _mm256_storeu_ps((float *)y0, x_vec[0]); + _mm256_storeu_ps((float *)(y0 + n_elem_per_reg), x_vec[1]); + _mm256_storeu_ps((float *)(y0 + 2 * n_elem_per_reg), x_vec[2]); + _mm256_storeu_ps((float *)(y0 + 3 * n_elem_per_reg), x_vec[3]); + + y0 += 4 * n_elem_per_reg; + } + + for (; (i + 7) < n; i += 8) + { + x_vec[0] = _mm256_loadu_ps((float *)x0); + x_vec[1] = _mm256_loadu_ps((float *)(x0 + n_elem_per_reg)); + + x0 += 2 * n_elem_per_reg; + + x_vec[0] = _mm256_mul_ps(x_vec[0], conj_reg); + x_vec[1] = _mm256_mul_ps(x_vec[1], conj_reg); + + _mm256_storeu_ps((float *)y0, x_vec[0]); + _mm256_storeu_ps((float *)(y0 + n_elem_per_reg), x_vec[1]); + + y0 += 2 * n_elem_per_reg; + } + + for (; (i + 3) < n; i += 4) + { + x_vec[0] = _mm256_loadu_ps((float *)x0); + + x_vec[0] = _mm256_mul_ps(x_vec[0], conj_reg); + + x0 += n_elem_per_reg; + + _mm256_storeu_ps((float *)y0, x_vec[0]); + + y0 += n_elem_per_reg; + } + + } + + // Handling fringe cases or non-unit strided inputs + for (; i < n; i += 1) + { + scomplex temp = *x0; + temp.imag = -temp.imag; + *y0 = temp; + + x0 += incx; + y0 += incy; + } + } + else + { + if (incx == 1 && incy == 1) + { + const dim_t n_elem_per_reg = 4; + __m256 x_vec[8]; + + for (; (i + 31) < n; i += 32) + { + x_vec[0] = _mm256_loadu_ps((float *)x0); + x_vec[1] = _mm256_loadu_ps((float *)(x0 + n_elem_per_reg)); + x_vec[2] = _mm256_loadu_ps((float *)(x0 + 2 * n_elem_per_reg)); + x_vec[3] = _mm256_loadu_ps((float *)(x0 + 3 * n_elem_per_reg)); + x_vec[4] = _mm256_loadu_ps((float *)(x0 + 4 * n_elem_per_reg)); + x_vec[5] = _mm256_loadu_ps((float *)(x0 + 5 * n_elem_per_reg)); + x_vec[6] = _mm256_loadu_ps((float *)(x0 + 6 * n_elem_per_reg)); + x_vec[7] = _mm256_loadu_ps((float *)(x0 + 7 * n_elem_per_reg)); + + x0 += 8 * n_elem_per_reg; + + _mm256_storeu_ps((float *)y0, x_vec[0]); + _mm256_storeu_ps((float *)(y0 + n_elem_per_reg), x_vec[1]); + _mm256_storeu_ps((float *)(y0 + 2 * n_elem_per_reg), x_vec[2]); + _mm256_storeu_ps((float *)(y0 + 3 * n_elem_per_reg), x_vec[3]); + _mm256_storeu_ps((float *)(y0 + 4 * n_elem_per_reg), x_vec[4]); + _mm256_storeu_ps((float *)(y0 + 5 * n_elem_per_reg), x_vec[5]); + _mm256_storeu_ps((float *)(y0 + 6 * n_elem_per_reg), x_vec[6]); + _mm256_storeu_ps((float *)(y0 + 7 * n_elem_per_reg), x_vec[7]); + + y0 += 8 * n_elem_per_reg; + } + + for (; (i + 15) < n; i += 16) + { + x_vec[0] = _mm256_loadu_ps((float *)x0); + x_vec[1] = _mm256_loadu_ps((float *)(x0 + n_elem_per_reg)); + x_vec[2] = _mm256_loadu_ps((float *)(x0 + 2 * n_elem_per_reg)); + x_vec[3] = _mm256_loadu_ps((float *)(x0 + 3 * n_elem_per_reg)); + + x0 += 4 * n_elem_per_reg; + + _mm256_storeu_ps((float *)y0, x_vec[0]); + _mm256_storeu_ps((float *)(y0 + n_elem_per_reg), x_vec[1]); + _mm256_storeu_ps((float *)(y0 + 2 * n_elem_per_reg), x_vec[2]); + _mm256_storeu_ps((float *)(y0 + 3 * n_elem_per_reg), x_vec[3]); + + y0 += 4 * n_elem_per_reg; + } + + for (; (i + 7) < n; i += 8) + { + x_vec[0] = _mm256_loadu_ps((float *)x0); + x_vec[1] = _mm256_loadu_ps((float *)(x0 + n_elem_per_reg)); + + x0 += 2 * n_elem_per_reg; + + _mm256_storeu_ps((float *)y0, x_vec[0]); + _mm256_storeu_ps((float *)(y0 + n_elem_per_reg), x_vec[1]); + + y0 += 2 * n_elem_per_reg; + } + + for (; (i + 3) < n; i += 4) + { + x_vec[0] = _mm256_loadu_ps((float *)x0); + + x0 += n_elem_per_reg; + + _mm256_storeu_ps((float *)y0, x_vec[0]); + + y0 += n_elem_per_reg; + } + + } + for (; i < n; i += 1) + { + *y0 = *x0; + + x0 += incx; + y0 += incy; + } + } +} + void bli_zcopyv_zen_int ( conj_t conjx, diff --git a/kernels/zen/1/bli_scal2v_zen_int.c b/kernels/zen/1/bli_scal2v_zen_int.c index 1c91138cf0..6ab9536877 100644 --- a/kernels/zen/1/bli_scal2v_zen_int.c +++ b/kernels/zen/1/bli_scal2v_zen_int.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -35,7 +35,417 @@ #include "blis.h" #include -/* This kernel performs y := alpha * conjx(x) +// This kernel performs y := alpha * conjx(x) +void bli_sscal2v_zen_int + ( + conj_t conjx, + dim_t n, + float* restrict alpha, + float* restrict x, inc_t incx, + float* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + // If the vector dimension is zero, return early. + if (bli_zero_dim1(n)) + return; + + if (PASTEMAC(s, eq0)(*alpha)) + { + /* If alpha is zero, use setv. */ + float *zero = PASTEMAC(s, 0); + + bli_ssetv_zen_int + ( + BLIS_NO_CONJUGATE, + n, + zero, + y, incy, + cntx + ); + + return; + } + else if (PASTEMAC(s, eq1)(*alpha)) + { + /* If alpha is one, use copyv. */ + bli_scopyv_zen_int + ( + conjx, + n, + x, incx, + y, incy, + cntx + ); + + return; + } + + dim_t i = 0; + float *x0 = x; + float *y0 = y; + + if (incx == 1 && incy == 1) + { + __m256 x_vec[12], alphav; + + alphav = _mm256_broadcast_ss(alpha); + + const dim_t n_elem_per_reg = 8; + + for (; (i + 95) < n; i += 96) + { + x_vec[0] = _mm256_loadu_ps((float *)x0); + x_vec[1] = _mm256_loadu_ps((float *)(x0 + n_elem_per_reg)); + x_vec[2] = _mm256_loadu_ps((float *)(x0 + 2 * n_elem_per_reg)); + x_vec[3] = _mm256_loadu_ps((float *)(x0 + 3 * n_elem_per_reg)); + + x_vec[0] = _mm256_mul_ps(x_vec[0], alphav); + x_vec[1] = _mm256_mul_ps(x_vec[1], alphav); + x_vec[2] = _mm256_mul_ps(x_vec[2], alphav); + x_vec[3] = _mm256_mul_ps(x_vec[3], alphav); + + _mm256_storeu_ps((float *)y0, x_vec[0]); + _mm256_storeu_ps((float *)(y0 + n_elem_per_reg), x_vec[1]); + _mm256_storeu_ps((float *)(y0 + 2 * n_elem_per_reg), x_vec[2]); + _mm256_storeu_ps((float *)(y0 + 3 * n_elem_per_reg), x_vec[3]); + + x_vec[4] = _mm256_loadu_ps((float *)(x0 + 4 * n_elem_per_reg)); + x_vec[5] = _mm256_loadu_ps((float *)(x0 + 5 * n_elem_per_reg)); + x_vec[6] = _mm256_loadu_ps((float *)(x0 + 6 * n_elem_per_reg)); + x_vec[7] = _mm256_loadu_ps((float *)(x0 + 7 * n_elem_per_reg)); + + x_vec[4] = _mm256_mul_ps(x_vec[4], alphav); + x_vec[5] = _mm256_mul_ps(x_vec[5], alphav); + x_vec[6] = _mm256_mul_ps(x_vec[6], alphav); + x_vec[7] = _mm256_mul_ps(x_vec[7], alphav); + + _mm256_storeu_ps((float *)(y0 + 4 * n_elem_per_reg), x_vec[4]); + _mm256_storeu_ps((float *)(y0 + 5 * n_elem_per_reg), x_vec[5]); + _mm256_storeu_ps((float *)(y0 + 6 * n_elem_per_reg), x_vec[6]); + _mm256_storeu_ps((float *)(y0 + 7 * n_elem_per_reg), x_vec[7]); + + x_vec[8] = _mm256_loadu_ps((float *)(x0 + 8 * n_elem_per_reg)); + x_vec[9] = _mm256_loadu_ps((float *)(x0 + 9 * n_elem_per_reg)); + x_vec[10] = _mm256_loadu_ps((float *)(x0 + 10 * n_elem_per_reg)); + x_vec[11] = _mm256_loadu_ps((float *)(x0 + 11 * n_elem_per_reg)); + + x_vec[8] = _mm256_mul_ps(x_vec[8], alphav); + x_vec[9] = _mm256_mul_ps(x_vec[9], alphav); + x_vec[10] = _mm256_mul_ps(x_vec[10], alphav); + x_vec[11] = _mm256_mul_ps(x_vec[11], alphav); + + _mm256_storeu_ps((float *)(y0 + 8 * n_elem_per_reg), x_vec[8]); + _mm256_storeu_ps((float *)(y0 + 9 * n_elem_per_reg), x_vec[9]); + _mm256_storeu_ps((float *)(y0 + 10 * n_elem_per_reg), x_vec[10]); + _mm256_storeu_ps((float *)(y0 + 11 * n_elem_per_reg), x_vec[11]); + + x0 += 96; + y0 += 96; + } + + for (; (i + 63) < n; i += 64) + { + x_vec[0] = _mm256_loadu_ps((float *)x0); + x_vec[1] = _mm256_loadu_ps((float *)(x0 + n_elem_per_reg)); + x_vec[2] = _mm256_loadu_ps((float *)(x0 + 2 * n_elem_per_reg)); + x_vec[3] = _mm256_loadu_ps((float *)(x0 + 3 * n_elem_per_reg)); + + x_vec[0] = _mm256_mul_ps(x_vec[0], alphav); + x_vec[1] = _mm256_mul_ps(x_vec[1], alphav); + x_vec[2] = _mm256_mul_ps(x_vec[2], alphav); + x_vec[3] = _mm256_mul_ps(x_vec[3], alphav); + + _mm256_storeu_ps((float *)y0, x_vec[0]); + _mm256_storeu_ps((float *)(y0 + n_elem_per_reg), x_vec[1]); + _mm256_storeu_ps((float *)(y0 + 2 * n_elem_per_reg), x_vec[2]); + _mm256_storeu_ps((float *)(y0 + 3 * n_elem_per_reg), x_vec[3]); + + x_vec[4] = _mm256_loadu_ps((float *)(x0 + 4 * n_elem_per_reg)); + x_vec[5] = _mm256_loadu_ps((float *)(x0 + 5 * n_elem_per_reg)); + x_vec[6] = _mm256_loadu_ps((float *)(x0 + 6 * n_elem_per_reg)); + x_vec[7] = _mm256_loadu_ps((float *)(x0 + 7 * n_elem_per_reg)); + + x_vec[4] = _mm256_mul_ps(x_vec[4], alphav); + x_vec[5] = _mm256_mul_ps(x_vec[5], alphav); + x_vec[6] = _mm256_mul_ps(x_vec[6], alphav); + x_vec[7] = _mm256_mul_ps(x_vec[7], alphav); + + _mm256_storeu_ps((float *)(y0 + 4 * n_elem_per_reg), x_vec[4]); + _mm256_storeu_ps((float *)(y0 + 5 * n_elem_per_reg), x_vec[5]); + _mm256_storeu_ps((float *)(y0 + 6 * n_elem_per_reg), x_vec[6]); + _mm256_storeu_ps((float *)(y0 + 7 * n_elem_per_reg), x_vec[7]); + + x0 += 64; + y0 += 64; + } + + for (; (i + 31) < n; i += 32) + { + x_vec[0] = _mm256_loadu_ps((float *)x0); + x_vec[1] = _mm256_loadu_ps((float *)(x0 + n_elem_per_reg)); + x_vec[2] = _mm256_loadu_ps((float *)(x0 + 2 * n_elem_per_reg)); + x_vec[3] = _mm256_loadu_ps((float *)(x0 + 3 * n_elem_per_reg)); + + x_vec[0] = _mm256_mul_ps(x_vec[0], alphav); + x_vec[1] = _mm256_mul_ps(x_vec[1], alphav); + x_vec[2] = _mm256_mul_ps(x_vec[2], alphav); + x_vec[3] = _mm256_mul_ps(x_vec[3], alphav); + + _mm256_storeu_ps((float *)y0, x_vec[0]); + _mm256_storeu_ps((float *)(y0 + n_elem_per_reg), x_vec[1]); + _mm256_storeu_ps((float *)(y0 + 2 * n_elem_per_reg), x_vec[2]); + _mm256_storeu_ps((float *)(y0 + 3 * n_elem_per_reg), x_vec[3]); + + x0 += 32; + y0 += 32; + } + + for (; (i + 15) < n; i += 16) + { + x_vec[0] = _mm256_loadu_ps((float *)x0); + x_vec[1] = _mm256_loadu_ps((float *)(x0 + n_elem_per_reg)); + + x_vec[0] = _mm256_mul_ps(x_vec[0], alphav); + x_vec[1] = _mm256_mul_ps(x_vec[1], alphav); + + _mm256_storeu_ps((float *)y0, x_vec[0]); + _mm256_storeu_ps((float *)(y0 + n_elem_per_reg), x_vec[1]); + + x0 += 16; + y0 += 16; + } + + for (; (i + 7) < n; i += 8) + { + x_vec[0] = _mm256_loadu_ps((float *)x0); + + x_vec[0] = _mm256_mul_ps(x_vec[0], alphav); + + _mm256_storeu_ps((float *)y0, x_vec[0]); + + x0 += 8; + y0 += 8; + } + + _mm256_zeroupper(); + } + + // Handling fringe case or non-unit strides + for (; i < n; i++) + { + *y0 = (*alpha) * (*x0); + x0 += incx; + y0 += incy; + } +} + +// This kernel performs y := alpha * conjx(x) +void bli_dscal2v_zen_int + ( + conj_t conjx, + dim_t n, + double* restrict alpha, + double* restrict x, inc_t incx, + double* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + // If the vector dimension is zero, return early. + if (bli_zero_dim1(n)) + return; + + if (PASTEMAC(d, eq0)(*alpha)) + { + /* If alpha is zero, use setv. */ + double *zero = PASTEMAC(d, 0); + + bli_dsetv_zen_int + ( + BLIS_NO_CONJUGATE, + n, + zero, + y, incy, + cntx + ); + + return; + } + else if (PASTEMAC(d, eq1)(*alpha)) + { + /* If alpha is one, use copyv. */ + bli_dcopyv_zen_int + ( + conjx, + n, + x, incx, + y, incy, + cntx + ); + + return; + } + + dim_t i = 0; + double *x0 = x; + double *y0 = y; + + if (incx == 1 && incy == 1) + { + __m256d x_vec[12], alphav; + + alphav = _mm256_broadcast_sd(alpha); + + const dim_t n_elem_per_reg = 4; + + for (; (i + 47) < n; i += 48) + { + x_vec[0] = _mm256_loadu_pd((double *)x0); + x_vec[1] = _mm256_loadu_pd((double *)(x0 + n_elem_per_reg)); + x_vec[2] = _mm256_loadu_pd((double *)(x0 + 2 * n_elem_per_reg)); + x_vec[3] = _mm256_loadu_pd((double *)(x0 + 3 * n_elem_per_reg)); + + x_vec[0] = _mm256_mul_pd(x_vec[0], alphav); + x_vec[1] = _mm256_mul_pd(x_vec[1], alphav); + x_vec[2] = _mm256_mul_pd(x_vec[2], alphav); + x_vec[3] = _mm256_mul_pd(x_vec[3], alphav); + + _mm256_storeu_pd((double *)y0, x_vec[0]); + _mm256_storeu_pd((double *)(y0 + n_elem_per_reg), x_vec[1]); + _mm256_storeu_pd((double *)(y0 + 2 * n_elem_per_reg), x_vec[2]); + _mm256_storeu_pd((double *)(y0 + 3 * n_elem_per_reg), x_vec[3]); + + x_vec[4] = _mm256_loadu_pd((double *)(x0 + 4 * n_elem_per_reg)); + x_vec[5] = _mm256_loadu_pd((double *)(x0 + 5 * n_elem_per_reg)); + x_vec[6] = _mm256_loadu_pd((double *)(x0 + 6 * n_elem_per_reg)); + x_vec[7] = _mm256_loadu_pd((double *)(x0 + 7 * n_elem_per_reg)); + + x_vec[4] = _mm256_mul_pd(x_vec[4], alphav); + x_vec[5] = _mm256_mul_pd(x_vec[5], alphav); + x_vec[6] = _mm256_mul_pd(x_vec[6], alphav); + x_vec[7] = _mm256_mul_pd(x_vec[7], alphav); + + _mm256_storeu_pd((double *)(y0 + 4 * n_elem_per_reg), x_vec[4]); + _mm256_storeu_pd((double *)(y0 + 5 * n_elem_per_reg), x_vec[5]); + _mm256_storeu_pd((double *)(y0 + 6 * n_elem_per_reg), x_vec[6]); + _mm256_storeu_pd((double *)(y0 + 7 * n_elem_per_reg), x_vec[7]); + + x_vec[8] = _mm256_loadu_pd((double *)(x0 + 8 * n_elem_per_reg)); + x_vec[9] = _mm256_loadu_pd((double *)(x0 + 9 * n_elem_per_reg)); + x_vec[10] = _mm256_loadu_pd((double *)(x0 + 10 * n_elem_per_reg)); + x_vec[11] = _mm256_loadu_pd((double *)(x0 + 11 * n_elem_per_reg)); + + x_vec[8] = _mm256_mul_pd(x_vec[8], alphav); + x_vec[9] = _mm256_mul_pd(x_vec[9], alphav); + x_vec[10] = _mm256_mul_pd(x_vec[10], alphav); + x_vec[11] = _mm256_mul_pd(x_vec[11], alphav); + + _mm256_storeu_pd((double *)(y0 + 8 * n_elem_per_reg), x_vec[8]); + _mm256_storeu_pd((double *)(y0 + 9 * n_elem_per_reg), x_vec[9]); + _mm256_storeu_pd((double *)(y0 + 10 * n_elem_per_reg), x_vec[10]); + _mm256_storeu_pd((double *)(y0 + 11 * n_elem_per_reg), x_vec[11]); + + x0 += 48; + y0 += 48; + } + + for (; (i + 31) < n; i += 32) + { + x_vec[0] = _mm256_loadu_pd((double *)x0); + x_vec[1] = _mm256_loadu_pd((double *)(x0 + n_elem_per_reg)); + x_vec[2] = _mm256_loadu_pd((double *)(x0 + 2 * n_elem_per_reg)); + x_vec[3] = _mm256_loadu_pd((double *)(x0 + 3 * n_elem_per_reg)); + + x_vec[0] = _mm256_mul_pd(x_vec[0], alphav); + x_vec[1] = _mm256_mul_pd(x_vec[1], alphav); + x_vec[2] = _mm256_mul_pd(x_vec[2], alphav); + x_vec[3] = _mm256_mul_pd(x_vec[3], alphav); + + _mm256_storeu_pd((double *)y0, x_vec[0]); + _mm256_storeu_pd((double *)(y0 + n_elem_per_reg), x_vec[1]); + _mm256_storeu_pd((double *)(y0 + 2 * n_elem_per_reg), x_vec[2]); + _mm256_storeu_pd((double *)(y0 + 3 * n_elem_per_reg), x_vec[3]); + + x_vec[4] = _mm256_loadu_pd((double *)(x0 + 4 * n_elem_per_reg)); + x_vec[5] = _mm256_loadu_pd((double *)(x0 + 5 * n_elem_per_reg)); + x_vec[6] = _mm256_loadu_pd((double *)(x0 + 6 * n_elem_per_reg)); + x_vec[7] = _mm256_loadu_pd((double *)(x0 + 7 * n_elem_per_reg)); + + x_vec[4] = _mm256_mul_pd(x_vec[4], alphav); + x_vec[5] = _mm256_mul_pd(x_vec[5], alphav); + x_vec[6] = _mm256_mul_pd(x_vec[6], alphav); + x_vec[7] = _mm256_mul_pd(x_vec[7], alphav); + + _mm256_storeu_pd((double *)(y0 + 4 * n_elem_per_reg), x_vec[4]); + _mm256_storeu_pd((double *)(y0 + 5 * n_elem_per_reg), x_vec[5]); + _mm256_storeu_pd((double *)(y0 + 6 * n_elem_per_reg), x_vec[6]); + _mm256_storeu_pd((double *)(y0 + 7 * n_elem_per_reg), x_vec[7]); + + x0 += 32; + y0 += 32; + } + + for (; (i + 15) < n; i += 16) + { + x_vec[0] = _mm256_loadu_pd((double *)x0); + x_vec[1] = _mm256_loadu_pd((double *)(x0 + n_elem_per_reg)); + x_vec[2] = _mm256_loadu_pd((double *)(x0 + 2 * n_elem_per_reg)); + x_vec[3] = _mm256_loadu_pd((double *)(x0 + 3 * n_elem_per_reg)); + + x_vec[0] = _mm256_mul_pd(x_vec[0], alphav); + x_vec[1] = _mm256_mul_pd(x_vec[1], alphav); + x_vec[2] = _mm256_mul_pd(x_vec[2], alphav); + x_vec[3] = _mm256_mul_pd(x_vec[3], alphav); + + _mm256_storeu_pd((double *)y0, x_vec[0]); + _mm256_storeu_pd((double *)(y0 + n_elem_per_reg), x_vec[1]); + _mm256_storeu_pd((double *)(y0 + 2 * n_elem_per_reg), x_vec[2]); + _mm256_storeu_pd((double *)(y0 + 3 * n_elem_per_reg), x_vec[3]); + + x0 += 16; + y0 += 16; + } + + for (; (i + 7) < n; i += 8) + { + x_vec[0] = _mm256_loadu_pd((double *)x0); + x_vec[1] = _mm256_loadu_pd((double *)(x0 + n_elem_per_reg)); + + x_vec[0] = _mm256_mul_pd(x_vec[0], alphav); + x_vec[1] = _mm256_mul_pd(x_vec[1], alphav); + + _mm256_storeu_pd((double *)y0, x_vec[0]); + _mm256_storeu_pd((double *)(y0 + n_elem_per_reg), x_vec[1]); + + x0 += 8; + y0 += 8; + } + + for (; (i + 3) < n; i += 4) + { + x_vec[0] = _mm256_loadu_pd((double *)x0); + + x_vec[0] = _mm256_mul_pd(x_vec[0], alphav); + + _mm256_storeu_pd((double *)y0, x_vec[0]); + + x0 += 4; + y0 += 4; + } + + _mm256_zeroupper(); + } + + // Handling fringe case or non-unit strides + for (; i < n; i++) + { + *y0 = (*alpha) * (*x0); + x0 += incx; + y0 += incy; + } +} + +/* This kernels for cscal2v and zscal2v perform y := alpha * conjx(x) alpha = a + i(b) X = x + i(y) @@ -114,6 +524,269 @@ the behaviour is not defined. In this kernel, we return without performing any computation. */ +void bli_cscal2v_zen_int + ( + conj_t conjx, + dim_t n, + scomplex* restrict alpha, + scomplex* restrict x, inc_t incx, + scomplex* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + // If the vector dimension is zero, return early. + if (bli_zero_dim1(n)) + return; + + if (PASTEMAC(c, eq0)(*alpha)) + { + /* If alpha is zero, use setv. */ + scomplex *zero = PASTEMAC(c, 0); + + bli_csetv_zen_int + ( + BLIS_NO_CONJUGATE, + n, + zero, + y, incy, + cntx + ); + + return; + } + else if (PASTEMAC(c, eq1)(*alpha)) + { + /* If alpha is one, use copyv. */ + bli_ccopyv_zen_int + ( + conjx, + n, + x, incx, + y, incy, + cntx + ); + + return; + } + + // Setting the iterator and local pointers + dim_t i = 0; + scomplex *x0 = x; + scomplex *y0 = y; + + float real = (*alpha).real; + float imag = (*alpha).imag; + + if (bli_is_noconj(conjx)) + { + if (incx == 1 && incy == 1) + { + __m256 temp[8], alpha_real, alpha_imag, x_vec[4]; + + alpha_real = _mm256_set1_ps(real); + alpha_imag = _mm256_set1_ps(imag); + + const dim_t n_elem_per_reg = 4; + + for (; (i + 15) < n; i += 16) + { + x_vec[0] = _mm256_loadu_ps((float *)x0); + x_vec[1] = _mm256_loadu_ps((float *)(x0 + n_elem_per_reg)); + x_vec[2] = _mm256_loadu_ps((float *)(x0 + 2 * n_elem_per_reg)); + x_vec[3] = _mm256_loadu_ps((float *)(x0 + 3 * n_elem_per_reg)); + + temp[0] = _mm256_mul_ps(x_vec[0], alpha_real); + temp[1] = _mm256_mul_ps(x_vec[0], alpha_imag); + temp[2] = _mm256_mul_ps(x_vec[1], alpha_real); + temp[3] = _mm256_mul_ps(x_vec[1], alpha_imag); + temp[4] = _mm256_mul_ps(x_vec[2], alpha_real); + temp[5] = _mm256_mul_ps(x_vec[2], alpha_imag); + temp[6] = _mm256_mul_ps(x_vec[3], alpha_real); + temp[7] = _mm256_mul_ps(x_vec[3], alpha_imag); + + temp[1] = _mm256_permute_ps(temp[1], 0b10110001); + temp[3] = _mm256_permute_ps(temp[3], 0b10110001); + temp[5] = _mm256_permute_ps(temp[5], 0b10110001); + temp[7] = _mm256_permute_ps(temp[7], 0b10110001); + + temp[0] = _mm256_addsub_ps(temp[0], temp[1]); + temp[2] = _mm256_addsub_ps(temp[2], temp[3]); + temp[4] = _mm256_addsub_ps(temp[4], temp[5]); + temp[6] = _mm256_addsub_ps(temp[6], temp[7]); + + _mm256_storeu_ps((float *)y0, temp[0]); + _mm256_storeu_ps((float *)(y0 + n_elem_per_reg), temp[2]); + _mm256_storeu_ps((float *)(y0 + 2 * n_elem_per_reg), temp[4]); + _mm256_storeu_ps((float *)(y0 + 3 * n_elem_per_reg), temp[6]); + + x0 += 16; + y0 += 16; + } + + for (; (i + 7) < n; i += 8) + { + x_vec[0] = _mm256_loadu_ps((float *)x0); + x_vec[1] = _mm256_loadu_ps((float *)(x0 + n_elem_per_reg)); + + temp[0] = _mm256_mul_ps(x_vec[0], alpha_real); + temp[1] = _mm256_mul_ps(x_vec[0], alpha_imag); + temp[2] = _mm256_mul_ps(x_vec[1], alpha_real); + temp[3] = _mm256_mul_ps(x_vec[1], alpha_imag); + + temp[1] = _mm256_permute_ps(temp[1], 0b10110001); + temp[3] = _mm256_permute_ps(temp[3], 0b10110001); + + temp[0] = _mm256_addsub_ps(temp[0], temp[1]); + temp[2] = _mm256_addsub_ps(temp[2], temp[3]); + + _mm256_storeu_ps((float *)y0, temp[0]); + _mm256_storeu_ps((float *)(y0 + n_elem_per_reg), temp[2]); + + x0 += 8; + y0 += 8; + } + + for (; (i + 3) < n; i += 4) + { + x_vec[0] = _mm256_loadu_ps((float *)x0); + + temp[0] = _mm256_mul_ps(x_vec[0], alpha_real); + temp[1] = _mm256_mul_ps(x_vec[0], alpha_imag); + + temp[1] = _mm256_permute_ps(temp[1], 0b10110001); + + temp[0] = _mm256_addsub_ps(temp[0], temp[1]); + + _mm256_storeu_ps((float *)y0, temp[0]); + + x0 += 4; + y0 += 4; + } + _mm256_zeroupper(); + } + + // Handling fringe cases or non-unit strides + for (; i < n; i++) + { + y0->real = real * ( x0->real ) - imag * ( x0->imag ); + y0->imag = imag * ( x0->real ) + real * ( x0->imag ); + + x0 += incx; + y0 += incy; + } + } + /* This else condition handles the computation + for conjugate X cases */ + else + { + if (incx == 1 && incy == 1) + { + __m256 temp[8], alpha_real, alpha_imag, x_vec[4]; + + alpha_real = _mm256_set1_ps(real); + alpha_imag = _mm256_set1_ps(imag); + + const dim_t n_elem_per_reg = 4; + + for (; (i + 15) < n; i += 16) + { + x_vec[0] = _mm256_loadu_ps((float *)x0); + x_vec[1] = _mm256_loadu_ps((float *)(x0 + n_elem_per_reg)); + x_vec[2] = _mm256_loadu_ps((float *)(x0 + 2 * n_elem_per_reg)); + x_vec[3] = _mm256_loadu_ps((float *)(x0 + 3 * n_elem_per_reg)); + + temp[0] = _mm256_mul_ps(x_vec[0], alpha_real); + temp[1] = _mm256_mul_ps(x_vec[0], alpha_imag); + temp[2] = _mm256_mul_ps(x_vec[1], alpha_real); + temp[3] = _mm256_mul_ps(x_vec[1], alpha_imag); + temp[4] = _mm256_mul_ps(x_vec[2], alpha_real); + temp[5] = _mm256_mul_ps(x_vec[2], alpha_imag); + temp[6] = _mm256_mul_ps(x_vec[3], alpha_real); + temp[7] = _mm256_mul_ps(x_vec[3], alpha_imag); + + temp[0] = _mm256_permute_ps(temp[0], 0b10110001); + temp[2] = _mm256_permute_ps(temp[2], 0b10110001); + temp[4] = _mm256_permute_ps(temp[4], 0b10110001); + temp[6] = _mm256_permute_ps(temp[6], 0b10110001); + + temp[0] = _mm256_addsub_ps(temp[1], temp[0]); + temp[2] = _mm256_addsub_ps(temp[3], temp[2]); + temp[4] = _mm256_addsub_ps(temp[5], temp[4]); + temp[6] = _mm256_addsub_ps(temp[7], temp[6]); + + temp[0] = _mm256_permute_ps(temp[0], 0b10110001); + temp[2] = _mm256_permute_ps(temp[2], 0b10110001); + temp[4] = _mm256_permute_ps(temp[4], 0b10110001); + temp[6] = _mm256_permute_ps(temp[6], 0b10110001); + + _mm256_storeu_ps((float *)y0, temp[0]); + _mm256_storeu_ps((float *)(y0 + n_elem_per_reg), temp[2]); + _mm256_storeu_ps((float *)(y0 + 2 * n_elem_per_reg), temp[4]); + _mm256_storeu_ps((float *)(y0 + 3 * n_elem_per_reg), temp[6]); + + x0 += 16; + y0 += 16; + } + + for (; (i + 7) < n; i += 8) + { + x_vec[0] = _mm256_loadu_ps((float *)x0); + x_vec[1] = _mm256_loadu_ps((float *)(x0 + n_elem_per_reg)); + + temp[0] = _mm256_mul_ps(x_vec[0], alpha_real); + temp[1] = _mm256_mul_ps(x_vec[0], alpha_imag); + temp[2] = _mm256_mul_ps(x_vec[1], alpha_real); + temp[3] = _mm256_mul_ps(x_vec[1], alpha_imag); + + temp[0] = _mm256_permute_ps(temp[0], 0b10110001); + temp[2] = _mm256_permute_ps(temp[2], 0b10110001); + + temp[0] = _mm256_addsub_ps(temp[1], temp[0]); + temp[2] = _mm256_addsub_ps(temp[3], temp[2]); + + temp[0] = _mm256_permute_ps(temp[0], 0b10110001); + temp[2] = _mm256_permute_ps(temp[2], 0b10110001); + + _mm256_storeu_ps((float *)y0, temp[0]); + _mm256_storeu_ps((float *)(y0 + n_elem_per_reg), temp[2]); + + x0 += 8; + y0 += 8; + } + + for (; (i + 3) < n; i += 4) + { + x_vec[0] = _mm256_loadu_ps((float *)x0); + + temp[0] = _mm256_mul_ps(x_vec[0], alpha_real); + temp[1] = _mm256_mul_ps(x_vec[0], alpha_imag); + + temp[0] = _mm256_permute_ps(temp[0], 0b10110001); + + temp[0] = _mm256_addsub_ps(temp[1], temp[0]); + + temp[0] = _mm256_permute_ps(temp[0], 0b10110001); + + _mm256_storeu_ps((float *)y0, temp[0]); + + x0 += 4; + y0 += 4; + } + + _mm256_zeroupper(); + } + + // Handling fringe cases or non-unit strides + for (; i < n; i++) + { + y0->real = real * ( x0->real ) + imag * ( x0->imag ); + y0->imag = imag * ( x0->real ) - real * ( x0->imag ); + + x0 += incx; + y0 += incy; + } + } +} void bli_zscal2v_zen_int ( @@ -127,9 +800,7 @@ void bli_zscal2v_zen_int { // If the vector dimension is zero, return early. - // When incx or incy is passed as zero or less than zero, - // the behaviour is not defined, so return early. - if (bli_zero_dim1(n)|| incx <= 0 || incy <=0) + if (bli_zero_dim1(n)) return; if (PASTEMAC(z, eq0)(*alpha)) @@ -137,15 +808,7 @@ void bli_zscal2v_zen_int /* If alpha is zero, use setv. */ dcomplex *zero = PASTEMAC(z, 0); - if(cntx == NULL) cntx = bli_gks_query_cntx(); - - /* Query the context for the kernel function pointer. */ - const num_t dt = PASTEMAC(z, type); - - PASTECH(z, setv_ker_ft) - setv_p = bli_cntx_get_l1v_ker_dt(dt, BLIS_SETV_KER, cntx); - - setv_p + bli_zsetv_zen_int ( BLIS_NO_CONJUGATE, n, diff --git a/kernels/zen/1/bli_scalv_zen_int.c b/kernels/zen/1/bli_scalv_zen_int.c index fa337c247f..34d6b161c5 100644 --- a/kernels/zen/1/bli_scalv_zen_int.c +++ b/kernels/zen/1/bli_scalv_zen_int.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2017 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2017 - 2024, Advanced Micro Devices, Inc. All rights reserved. Copyright (C) 2018, The University of Texas at Austin Redistribution and use in source and binary forms, with or without @@ -80,9 +80,11 @@ void bli_sscalv_zen_int if ( bli_zero_dim1( n ) || PASTEMAC(s,eq1)( *alpha ) ) return; // If alpha is zero, use setv (in case y contains NaN or Inf). - if ( PASTEMAC(s,eq0)( *alpha ) ) + // If alpha is zero, use setv if not called from BLAS scal itself (indicated by n being negative). + if ( PASTEMAC(s,eq0)( *alpha ) && n > 0 ) { float* zero = bli_s0; + if (cntx == NULL) cntx = bli_gks_query_cntx(); ssetv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_FLOAT, BLIS_SETV_KER, cntx ); f @@ -96,10 +98,12 @@ void bli_sscalv_zen_int return; } + dim_t n0 = bli_abs(n); + // Use the unrolling factor and the number of elements per register // to compute the number of vectorized and leftover iterations. - n_viter = ( n ) / ( n_elem_per_reg * n_iter_unroll ); - n_left = ( n ) % ( n_elem_per_reg * n_iter_unroll ); + n_viter = ( n0 ) / ( n_elem_per_reg * n_iter_unroll ); + n_left = ( n0 ) % ( n_elem_per_reg * n_iter_unroll ); // If there is anything that would interfere with our use of contiguous // vector loads/stores, override n_viter and n_left to use scalar code @@ -107,7 +111,7 @@ void bli_sscalv_zen_int if ( incx != 1 ) { n_viter = 0; - n_left = n; + n_left = n0; } // Initialize local pointers. @@ -178,10 +182,11 @@ void bli_dscalv_zen_int // If the vector dimension is zero, or if alpha is unit, return early. if ( bli_zero_dim1( n ) || PASTEMAC(d,eq1)( *alpha ) ) return; - // If alpha is zero, use setv (in case y contains NaN or Inf). - if ( PASTEMAC(d,eq0)( *alpha ) ) + // If alpha is zero, use setv if not called from BLAS scal itself (indicated by n being negative). + if ( PASTEMAC(d,eq0)( *alpha ) && n > 0 ) { double* zero = bli_d0; + if (cntx == NULL) cntx = bli_gks_query_cntx(); dsetv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_DOUBLE, BLIS_SETV_KER, cntx ); f @@ -195,10 +200,12 @@ void bli_dscalv_zen_int return; } + dim_t n0 = bli_abs(n); + // Use the unrolling factor and the number of elements per register // to compute the number of vectorized and leftover iterations. - n_viter = ( n ) / ( n_elem_per_reg * n_iter_unroll ); - n_left = ( n ) % ( n_elem_per_reg * n_iter_unroll ); + n_viter = ( n0 ) / ( n_elem_per_reg * n_iter_unroll ); + n_left = ( n0 ) % ( n_elem_per_reg * n_iter_unroll ); // If there is anything that would interfere with our use of contiguous // vector loads/stores, override n_viter and n_left to use scalar code @@ -206,7 +213,7 @@ void bli_dscalv_zen_int if ( incx != 1 ) { n_viter = 0; - n_left = n; + n_left = n0; } // Initialize local pointers. diff --git a/kernels/zen/1/bli_scalv_zen_int10.c b/kernels/zen/1/bli_scalv_zen_int10.c index e760367060..ab5e46af04 100644 --- a/kernels/zen/1/bli_scalv_zen_int10.c +++ b/kernels/zen/1/bli_scalv_zen_int10.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2017 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2017 - 2024, Advanced Micro Devices, Inc. All rights reserved. Copyright (C) 2018, The University of Texas at Austin Redistribution and use in source and binary forms, with or without @@ -60,8 +60,8 @@ void bli_sscalv_zen_int10 // If the vector dimension is zero, or if alpha is unit, return early. if ( bli_zero_dim1( n ) || PASTEMAC(s,eq1)( *alpha ) ) return; - // If alpha is zero, use setv. - if ( PASTEMAC(s,eq0)( *alpha ) ) + // If alpha is zero, use setv if not called from BLAS scal itself (indicated by n being negative). + if ( PASTEMAC(s,eq0)( *alpha ) && n > 0 ) { float* zero = bli_s0; if ( cntx == NULL ) cntx = bli_gks_query_cntx(); @@ -78,6 +78,8 @@ void bli_sscalv_zen_int10 return; } + dim_t n0 = bli_abs(n); + // Initialize local pointers. x0 = x; @@ -88,11 +90,11 @@ void bli_sscalv_zen_int10 dim_t option; // Unroll and the loop used is picked based on the input size. - if( n < 300) + if( n0 < 300) { option = 2; } - else if( n < 500) + else if( n0 < 500) { option = 1; } @@ -105,7 +107,7 @@ void bli_sscalv_zen_int10 { case 0: - for ( ; (i + 127) < n; i += 128 ) + for ( ; (i + 127) < n0; i += 128 ) { //Load the input values xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); @@ -175,7 +177,7 @@ void bli_sscalv_zen_int10 case 1 : - for ( ; (i + 95) < n; i += 96 ) + for ( ; (i + 95) < n0; i += 96 ) { xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); @@ -227,7 +229,7 @@ void bli_sscalv_zen_int10 case 2: - for ( ; (i + 47) < n; i += 48 ) + for ( ; (i + 47) < n0; i += 48 ) { xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); @@ -256,7 +258,7 @@ void bli_sscalv_zen_int10 x0 += 6*n_elem_per_reg; } - for ( ; (i + 23) < n; i += 24 ) + for ( ; (i + 23) < n0; i += 24 ) { xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); @@ -273,7 +275,7 @@ void bli_sscalv_zen_int10 x0 += 3*n_elem_per_reg; } - for ( ; (i + 7) < n; i += 8 ) + for ( ; (i + 7) < n0; i += 8 ) { xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); @@ -284,7 +286,7 @@ void bli_sscalv_zen_int10 x0 += 1*n_elem_per_reg; } - for ( ; (i + 0) < n; i += 1 ) + for ( ; (i + 0) < n0; i += 1 ) { *x0 *= *alpha; @@ -296,7 +298,7 @@ void bli_sscalv_zen_int10 { const float alphac = *alpha; - for ( ; i < n; ++i ) + for ( ; i < n0; ++i ) { *x0 *= alphac; @@ -307,7 +309,7 @@ void bli_sscalv_zen_int10 // ----------------------------------------------------------------------------- -void bli_dscalv_zen_int10 +BLIS_EXPORT_BLIS void bli_dscalv_zen_int10 ( conj_t conjalpha, dim_t n, @@ -329,8 +331,8 @@ void bli_dscalv_zen_int10 // If the vector dimension is zero, or if alpha is unit, return early. if ( bli_zero_dim1( n ) || PASTEMAC(d,eq1)( *alpha ) ) return; - // If alpha is zero, use setv. - if ( PASTEMAC(d,eq0)( *alpha ) ) + // If alpha is zero, use setv if not called from BLAS scal itself (indicated by n being negative). + if ( PASTEMAC(d,eq0)( *alpha ) && n > 0 ) { double* zero = bli_d0; if ( cntx == NULL ) cntx = bli_gks_query_cntx(); @@ -348,6 +350,8 @@ void bli_dscalv_zen_int10 return; } + dim_t n0 = bli_abs(n); + // Initialize local pointers. x0 = x; @@ -358,11 +362,11 @@ void bli_dscalv_zen_int10 dim_t option; // Unroll and the loop used is picked based on the input size. - if(n < 200) + if(n0 < 200) { option = 2; } - else if(n < 500) + else if(n0 < 500) { option = 1; } @@ -375,7 +379,7 @@ void bli_dscalv_zen_int10 { case 0: - for (; (i + 63) < n; i += 64 ) + for (; (i + 63) < n0; i += 64 ) { xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); @@ -440,7 +444,7 @@ void bli_dscalv_zen_int10 x0 += 16*n_elem_per_reg; } - for (; (i + 47) < n; i += 48 ) + for (; (i + 47) < n0; i += 48 ) { xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); @@ -492,7 +496,7 @@ void bli_dscalv_zen_int10 case 1: - for (; (i + 31) < n; i += 32 ) + for (; (i + 31) < n0; i += 32 ) { xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); @@ -529,7 +533,7 @@ void bli_dscalv_zen_int10 case 2: - for ( ; (i + 11) < n; i += 12 ) + for ( ; (i + 11) < n0; i += 12 ) { xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); @@ -546,7 +550,7 @@ void bli_dscalv_zen_int10 x0 += 3*n_elem_per_reg; } - for ( ; (i + 3) < n; i += 4 ) + for ( ; (i + 3) < n0; i += 4 ) { xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); @@ -557,7 +561,7 @@ void bli_dscalv_zen_int10 x0 += 1*n_elem_per_reg; } - for ( ; (i + 0) < n; i += 1 ) + for ( ; (i + 0) < n0; i += 1 ) { *x0 *= *alpha; @@ -569,7 +573,7 @@ void bli_dscalv_zen_int10 { const double alphac = *alpha; - for ( ; i < n; ++i ) + for ( ; i < n0; ++i ) { *x0 *= alphac; @@ -587,6 +591,30 @@ void bli_zdscalv_zen_int10 cntx_t* restrict cntx ) { + // If the vector dimension is zero, or if alpha is unit, return early. + if ( bli_zero_dim1( n ) || PASTEMAC(z,eq1)( *alpha )) return; + + // If alpha is zero, use setv if not called from BLAS scal itself (indicated by n being negative). + if ( PASTEMAC(z,eq0)( *alpha ) && n > 0 ) + { + // Expert interface of setv is invoked when alpha is zero + dcomplex *zero = bli_z0; + + /* When alpha is zero all the element in x are set to zero */ + PASTEMAC2(z, setv, BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n, + zero, + x, incx, + cntx, + NULL); + + return; + } + + dim_t n0 = bli_abs(n); + dim_t i = 0; const dim_t n_elem_per_reg = 4; // number of elements per register @@ -607,7 +635,7 @@ void bli_zdscalv_zen_int10 alphav = _mm256_broadcast_sd( &alphac ); - for ( ; ( i + 29 ) < n; i += 30 ) + for ( ; ( i + 29 ) < n0; i += 30 ) { xv[0] = _mm256_loadu_pd( x0 ); xv[1] = _mm256_loadu_pd( x0 + n_elem_per_reg ); @@ -660,7 +688,7 @@ void bli_zdscalv_zen_int10 x0 += 15 * n_elem_per_reg; } - for ( ; ( i + 23 ) < n; i += 24 ) + for ( ; ( i + 23 ) < n0; i += 24 ) { xv[0] = _mm256_loadu_pd( x0 ); xv[1] = _mm256_loadu_pd( x0 + n_elem_per_reg ); @@ -704,7 +732,7 @@ void bli_zdscalv_zen_int10 x0 += 12 * n_elem_per_reg; } - for ( ; ( i + 15 ) < n; i += 16 ) + for ( ; ( i + 15 ) < n0; i += 16 ) { xv[0] = _mm256_loadu_pd( x0 ); xv[1] = _mm256_loadu_pd( x0 + n_elem_per_reg ); @@ -736,7 +764,7 @@ void bli_zdscalv_zen_int10 x0 += 8 * n_elem_per_reg; } - for ( ; ( i + 7 ) < n; i += 8 ) + for ( ; ( i + 7 ) < n0; i += 8 ) { xv[0] = _mm256_loadu_pd( x0 ); xv[1] = _mm256_loadu_pd( x0 + n_elem_per_reg ); @@ -756,7 +784,7 @@ void bli_zdscalv_zen_int10 x0 += 4 * n_elem_per_reg; } - for ( ; ( i + 3 ) < n; i += 4 ) + for ( ; ( i + 3 ) < n0; i += 4 ) { xv[0] = _mm256_loadu_pd( x0 ); xv[1] = _mm256_loadu_pd( x0 + n_elem_per_reg ); @@ -770,7 +798,7 @@ void bli_zdscalv_zen_int10 x0 += 2 * n_elem_per_reg; } - for ( ; ( i + 1 ) < n; i += 2 ) + for ( ; ( i + 1 ) < n0; i += 2 ) { xv[0] = _mm256_loadu_pd( x0 ); @@ -795,7 +823,7 @@ void bli_zdscalv_zen_int10 alpha_reg = _mm_set1_pd((*alpha).real); - for (; i < n; ++i) + for (; i < n0; ++i) { x_vec = _mm_loadu_pd(x0); @@ -807,6 +835,173 @@ void bli_zdscalv_zen_int10 } } +void bli_cscalv_zen_int + ( + conj_t conjalpha, + dim_t n, + scomplex* restrict alpha, + scomplex* restrict x, inc_t incx, + cntx_t* restrict cntx + ) +{ + // If the vector dimension is zero, or if alpha is unit, return early. + if ( bli_zero_dim1( n ) || PASTEMAC(c,eq1)( *alpha ) ) return; + + // If alpha is zero, use setv if not called from BLAS scal itself (indicated by n being negative). + if ( PASTEMAC(c,eq0)( *alpha ) && n > 0 ) + { + // Expert interface of setv is invoked when alpha is zero + scomplex *zero = bli_c0; + + /* When alpha is zero all the element in x are set to zero */ + PASTEMAC2(c, setv, BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n, + zero, + x, incx, + cntx, + NULL); + + return; + } + + dim_t n0 = bli_abs(n); + + dim_t i = 0; + scomplex alpha_conj; + float *x0 = (float *)x; + + // Performs conjugation of alpha based on conjalpha + PASTEMAC(c, copycjs)(conjalpha, *alpha, alpha_conj) + + float real = alpha_conj.real; + float imag = alpha_conj.imag; + + // Handling computation for unit-strided vectors + if ( incx == 1 ) + { + dim_t const n_elem_per_reg = 8; + + __m256 alpha_real_ymm, alpha_imag_ymm; + + alpha_real_ymm = _mm256_broadcast_ss(&real); + alpha_imag_ymm = _mm256_broadcast_ss(&imag); + + __m256 x_vec_ymm[4], temp_ymm[8]; + + /* Code logic + + Consider, + x1= a1 + ib1, x2 = a1 + ib2 + alpha = p + iq + + Vector values + x_vec_ymm = a1, b1, a2, b2 + alpha_real_ymm = p, p, p, p + alpha_imag_ymm = q, q, q, q + + Computation + + All real values + temp_1 = x_vec_ymm * alpha_real_ymm = a1p, b1p, a2p, b2p + + All imaginary values + temp_2 = x_vec_ymm * alpha_imag_ymm = a1q, b1q, a2q, b2q + + permute temp_2 to get + + b1q, a1q, b2q, a2q + + addsub temp_1 and temp_2 to get the final result + and then store + */ + + for (; (i + 15) < n0; i += 16) + { + x_vec_ymm[0] = _mm256_loadu_ps(x0); + x_vec_ymm[1] = _mm256_loadu_ps(x0 + n_elem_per_reg); + x_vec_ymm[2] = _mm256_loadu_ps(x0 + 2 * n_elem_per_reg); + x_vec_ymm[3] = _mm256_loadu_ps(x0 + 3 * n_elem_per_reg); + + temp_ymm[0] = _mm256_mul_ps(x_vec_ymm[0], alpha_imag_ymm); + temp_ymm[1] = _mm256_mul_ps(x_vec_ymm[1], alpha_imag_ymm); + temp_ymm[2] = _mm256_mul_ps(x_vec_ymm[2], alpha_imag_ymm); + temp_ymm[3] = _mm256_mul_ps(x_vec_ymm[3], alpha_imag_ymm); + + temp_ymm[4] = _mm256_permute_ps(temp_ymm[0], 0xB1); + temp_ymm[5] = _mm256_permute_ps(temp_ymm[1], 0xB1); + temp_ymm[6] = _mm256_permute_ps(temp_ymm[2], 0xB1); + temp_ymm[7] = _mm256_permute_ps(temp_ymm[3], 0xB1); + + temp_ymm[0] = _mm256_fmaddsub_ps(x_vec_ymm[0], alpha_real_ymm, temp_ymm[4]); + temp_ymm[1] = _mm256_fmaddsub_ps(x_vec_ymm[1], alpha_real_ymm, temp_ymm[5]); + temp_ymm[2] = _mm256_fmaddsub_ps(x_vec_ymm[2], alpha_real_ymm, temp_ymm[6]); + temp_ymm[3] = _mm256_fmaddsub_ps(x_vec_ymm[3], alpha_real_ymm, temp_ymm[7]); + + _mm256_storeu_ps(x0, temp_ymm[0]); + _mm256_storeu_ps(x0 + n_elem_per_reg, temp_ymm[1]); + _mm256_storeu_ps(x0 + 2 * n_elem_per_reg, temp_ymm[2]); + _mm256_storeu_ps(x0 + 3 * n_elem_per_reg, temp_ymm[3]); + + x0 += 4 * n_elem_per_reg; + } + + for (; (i + 7) < n0; i += 8) + { + x_vec_ymm[0] = _mm256_loadu_ps(x0); + x_vec_ymm[1] = _mm256_loadu_ps(x0 + n_elem_per_reg); + + temp_ymm[0] = _mm256_mul_ps(x_vec_ymm[0], alpha_imag_ymm); + temp_ymm[1] = _mm256_mul_ps(x_vec_ymm[1], alpha_imag_ymm); + + temp_ymm[2] = _mm256_permute_ps(temp_ymm[0], 0xB1); + temp_ymm[3] = _mm256_permute_ps(temp_ymm[1], 0xB1); + + temp_ymm[0] = _mm256_fmaddsub_ps(x_vec_ymm[0], alpha_real_ymm, temp_ymm[2]); + temp_ymm[1] = _mm256_fmaddsub_ps(x_vec_ymm[1], alpha_real_ymm, temp_ymm[3]); + + _mm256_storeu_ps(x0, temp_ymm[0]); + _mm256_storeu_ps(x0 + n_elem_per_reg, temp_ymm[1]); + + x0 += 2 * n_elem_per_reg; + } + + for (; (i + 3) < n0; i += 4) + { + x_vec_ymm[0] = _mm256_loadu_ps(x0); + + temp_ymm[0] = _mm256_mul_ps(x_vec_ymm[0], alpha_imag_ymm); + + temp_ymm[1] = _mm256_permute_ps(temp_ymm[0], 0xB1); + + temp_ymm[0] = _mm256_fmaddsub_ps(x_vec_ymm[0], alpha_real_ymm, temp_ymm[1]); + + _mm256_storeu_ps(x0, temp_ymm[0]); + + x0 += n_elem_per_reg; + } + + // Issue vzeroupper instruction to clear upper lanes of ymm registers. + // This avoids a performance penalty caused by false dependencies when + // transitioning from AVX to SSE instructions (which may occur later, + // especially if BLIS is compiled with -mfpmath=sse). + _mm256_zeroupper(); + } + + for (; i < n0; i++) + { + float x_real, x_imag; + x_real = real * (*x0) - imag * (*(x0 + 1)); + x_imag = real * (*(x0 + 1)) + imag * (*x0); + + *x0 = x_real; + *(x0 + 1) = x_imag; + + x0 += 2 * incx; + } +} + void bli_zscalv_zen_int ( conj_t conjalpha, @@ -816,24 +1011,14 @@ void bli_zscalv_zen_int cntx_t* restrict cntx ) { - /* - Undefined behaviour - ------------------- - - 1. This layer is not BLAS complaint and the kernel results in - undefined behaviour when n <= 0 and incx <= 1. The expectation - is that the application/higher-layer invoking this layer should - the arg checks. - */ - // if (bli_zero_dim1(n) || PASTEMAC(z, eq1)(*alpha)) - // return; + // If the vector dimension is zero, or if alpha is unit, return early. + if ( bli_zero_dim1( n ) || PASTEMAC(z,eq1)( *alpha ) ) return; - // To Do: This call to SETV needs to be removed for BLAS compliance - // Currently removing this is resulting in ZHERK failures - if (PASTEMAC(z, eq0)(*alpha)) + // If alpha is zero, use setv if not called from BLAS scal itself (indicated by n being negative). + if ( PASTEMAC(z,eq0)( *alpha ) && n > 0 ) { // Expert interface of setv is invoked when alpha is zero - dcomplex *zero = PASTEMAC(z, 0); + dcomplex *zero = bli_z0; /* When alpha is zero all the element in x are set to zero */ PASTEMAC2(z, setv, BLIS_TAPI_EX_SUF) @@ -848,6 +1033,8 @@ void bli_zscalv_zen_int return; } + dim_t n0 = bli_abs(n); + dim_t i = 0; dcomplex alpha_conj; double *x0 = (double *)x; @@ -858,8 +1045,8 @@ void bli_zscalv_zen_int double real = alpha_conj.real; double imag = alpha_conj.imag; - /*When incx is 1 and n >= 2 it is possible to use AVX2 instructions*/ - if (incx == 1 && n >= 2) + /*When incx is 1 and n0 >= 2 it is possible to use AVX2 instructions*/ + if (incx == 1 && n0 >= 2) { dim_t const n_elem_per_reg = 4; @@ -897,7 +1084,7 @@ void bli_zscalv_zen_int and then store */ - for (; (i + 7) < n; i += 8) + for (; (i + 7) < n0; i += 8) { x_vec_ymm[0] = _mm256_loadu_pd(x0); x_vec_ymm[1] = _mm256_loadu_pd(x0 + n_elem_per_reg); @@ -931,7 +1118,7 @@ void bli_zscalv_zen_int x0 += 4 * n_elem_per_reg; } - for (; (i + 3) < n; i += 4) + for (; (i + 3) < n0; i += 4) { x_vec_ymm[0] = _mm256_loadu_pd(x0); x_vec_ymm[1] = _mm256_loadu_pd(x0 + n_elem_per_reg); @@ -951,7 +1138,7 @@ void bli_zscalv_zen_int x0 += 2 * n_elem_per_reg; } - for (; (i + 1) < n; i += 2) + for (; (i + 1) < n0; i += 2) { x_vec_ymm[0] = _mm256_loadu_pd(x0); @@ -980,7 +1167,7 @@ void bli_zscalv_zen_int alpha_real_xmm = _mm_set1_pd(real); alpha_imag_xmm = _mm_set1_pd(imag); - for (; i < n; i++) + for (; i < n0; i++) { x_vec_xmm = _mm_loadu_pd(x0); diff --git a/kernels/zen/1/bli_setv_zen_int.c b/kernels/zen/1/bli_setv_zen_int.c index 5ebd061cdd..3468bafa1b 100644 --- a/kernels/zen/1/bli_setv_zen_int.c +++ b/kernels/zen/1/bli_setv_zen_int.c @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -46,90 +46,92 @@ void bli_ssetv_zen_int cntx_t* restrict cntx ) { - const dim_t num_elem_per_reg = 8; - dim_t i = 0; - __m256 alphav; - - // If the vector dimension is zero return early. - if ( bli_zero_dim1( n ) ) return; - - if ( incx == 1 ) - { - alphav = _mm256_broadcast_ss( alpha ); - - // For loop with n & ~0x7F => n & 0xFFFFFF80 masks the lower bits and results in multiples of 128 - // for example if n = 255 - // n & ~0x7F results in 128: copy from 0 to 128 happens in first loop - // n & ~0x3F results in 192: copy from 128 to 192 happens in second loop - // n & ~0x1F results in 224: copy from 128 to 192 happens in third loop and so on. - for ( i = 0; i < (n & (~0x7F)); i += 128 ) - { - _mm256_storeu_ps(x + num_elem_per_reg * 0, alphav); - _mm256_storeu_ps(x + num_elem_per_reg * 1, alphav); - _mm256_storeu_ps(x + num_elem_per_reg * 2, alphav); - _mm256_storeu_ps(x + num_elem_per_reg * 3, alphav); - _mm256_storeu_ps(x + num_elem_per_reg * 4, alphav); - _mm256_storeu_ps(x + num_elem_per_reg * 5, alphav); - _mm256_storeu_ps(x + num_elem_per_reg * 6, alphav); - _mm256_storeu_ps(x + num_elem_per_reg * 7, alphav); - _mm256_storeu_ps(x + num_elem_per_reg * 8, alphav); - _mm256_storeu_ps(x + num_elem_per_reg * 9, alphav); - _mm256_storeu_ps(x + num_elem_per_reg * 10, alphav); - _mm256_storeu_ps(x + num_elem_per_reg * 11, alphav); - _mm256_storeu_ps(x + num_elem_per_reg * 12, alphav); - _mm256_storeu_ps(x + num_elem_per_reg * 13, alphav); - _mm256_storeu_ps(x + num_elem_per_reg * 14, alphav); - _mm256_storeu_ps(x + num_elem_per_reg * 15, alphav); - - x += 128; - } - for ( ; i < (n & (~0x3F)); i += 64 ) - { - _mm256_storeu_ps(x + num_elem_per_reg * 0, alphav); - _mm256_storeu_ps(x + num_elem_per_reg * 1, alphav); - _mm256_storeu_ps(x + num_elem_per_reg * 2, alphav); - _mm256_storeu_ps(x + num_elem_per_reg * 3, alphav); - _mm256_storeu_ps(x + num_elem_per_reg * 4, alphav); - _mm256_storeu_ps(x + num_elem_per_reg * 5, alphav); - _mm256_storeu_ps(x + num_elem_per_reg * 6, alphav); - _mm256_storeu_ps(x + num_elem_per_reg * 7, alphav); - - x += 64; - } - for ( ; i < (n & (~0x1F)); i += 32 ) - { - _mm256_storeu_ps(x + num_elem_per_reg * 0, alphav); - _mm256_storeu_ps(x + num_elem_per_reg * 1, alphav); - _mm256_storeu_ps(x + num_elem_per_reg * 2, alphav); - _mm256_storeu_ps(x + num_elem_per_reg * 3, alphav); - - x += 32; - } - for ( ; i < (n & (~0x0F)); i += 16 ) - { - _mm256_storeu_ps(x + num_elem_per_reg * 0, alphav); - _mm256_storeu_ps(x + num_elem_per_reg * 1, alphav); - - x += 16; - } - for ( ; i < (n & (~0x07)); i += 8 ) - { - _mm256_storeu_ps(x + num_elem_per_reg * 0, alphav); - x += 8; - } - for ( ; i < n; ++i ) - { - *x++ = *alpha; - } - } - else - { - for ( dim_t i = 0; i < n; ++i ) - { - *x = *alpha; - x += incx; - } - } + const dim_t num_elem_per_reg = 8; + dim_t i = 0; + __m256 alphav; + + float *x0 = x; + + // If the vector dimension is zero return early. + if ( bli_zero_dim1( n ) ) return; + + if ( incx == 1 ) + { + alphav = _mm256_broadcast_ss( alpha ); + + // For loop with n & ~0x7F => n & 0xFFFFFF80 masks the lower bits and results in multiples of 128 + // for example if n = 255 + // n & ~0x7F results in 128: copy from 0 to 128 happens in first loop + // n & ~0x3F results in 192: copy from 128 to 192 happens in second loop + // n & ~0x1F results in 224: copy from 128 to 192 happens in third loop and so on. + for ( i = 0; i < (n & (~0x7F)); i += 128 ) + { + _mm256_storeu_ps(x0 + num_elem_per_reg * 0, alphav); + _mm256_storeu_ps(x0 + num_elem_per_reg * 1, alphav); + _mm256_storeu_ps(x0 + num_elem_per_reg * 2, alphav); + _mm256_storeu_ps(x0 + num_elem_per_reg * 3, alphav); + _mm256_storeu_ps(x0 + num_elem_per_reg * 4, alphav); + _mm256_storeu_ps(x0 + num_elem_per_reg * 5, alphav); + _mm256_storeu_ps(x0 + num_elem_per_reg * 6, alphav); + _mm256_storeu_ps(x0 + num_elem_per_reg * 7, alphav); + _mm256_storeu_ps(x0 + num_elem_per_reg * 8, alphav); + _mm256_storeu_ps(x0 + num_elem_per_reg * 9, alphav); + _mm256_storeu_ps(x0 + num_elem_per_reg * 10, alphav); + _mm256_storeu_ps(x0 + num_elem_per_reg * 11, alphav); + _mm256_storeu_ps(x0 + num_elem_per_reg * 12, alphav); + _mm256_storeu_ps(x0 + num_elem_per_reg * 13, alphav); + _mm256_storeu_ps(x0 + num_elem_per_reg * 14, alphav); + _mm256_storeu_ps(x0 + num_elem_per_reg * 15, alphav); + + x0 += 128; + } + for ( ; i < (n & (~0x3F)); i += 64 ) + { + _mm256_storeu_ps(x0 + num_elem_per_reg * 0, alphav); + _mm256_storeu_ps(x0 + num_elem_per_reg * 1, alphav); + _mm256_storeu_ps(x0 + num_elem_per_reg * 2, alphav); + _mm256_storeu_ps(x0 + num_elem_per_reg * 3, alphav); + _mm256_storeu_ps(x0 + num_elem_per_reg * 4, alphav); + _mm256_storeu_ps(x0 + num_elem_per_reg * 5, alphav); + _mm256_storeu_ps(x0 + num_elem_per_reg * 6, alphav); + _mm256_storeu_ps(x0 + num_elem_per_reg * 7, alphav); + + x0 += 64; + } + for ( ; i < (n & (~0x1F)); i += 32 ) + { + _mm256_storeu_ps(x0 + num_elem_per_reg * 0, alphav); + _mm256_storeu_ps(x0 + num_elem_per_reg * 1, alphav); + _mm256_storeu_ps(x0 + num_elem_per_reg * 2, alphav); + _mm256_storeu_ps(x0 + num_elem_per_reg * 3, alphav); + + x0 += 32; + } + for ( ; i < (n & (~0x0F)); i += 16 ) + { + _mm256_storeu_ps(x0 + num_elem_per_reg * 0, alphav); + _mm256_storeu_ps(x0 + num_elem_per_reg * 1, alphav); + + x0 += 16; + } + for ( ; i < (n & (~0x07)); i += 8 ) + { + _mm256_storeu_ps(x0 + num_elem_per_reg * 0, alphav); + x0 += 8; + } + for ( ; i < n; ++i ) + { + *x0++ = *alpha; + } + } + else + { + for ( dim_t i = 0; i < n; ++i ) + { + *x0 = *alpha; + x0 += incx; + } + } } void bli_dsetv_zen_int @@ -141,88 +143,300 @@ void bli_dsetv_zen_int cntx_t* restrict cntx ) { - const dim_t num_elem_per_reg = 4; - dim_t i = 0; - __m256d alphav; - - // If the vector dimension is zero return early. - if ( bli_zero_dim1( n ) ) return; - - if ( incx == 1 ) - { - // Broadcast the alpha scalar to all elements of a vector register. - alphav = _mm256_broadcast_sd( alpha ); - - // n & (~0x3F) = n & 0xFFFFFFC0 -> this masks the numbers less than 64, - // the copy operation will be done for the multiples of 64 - for ( i = 0; i < (n & (~0x3F)); i += 64 ) - { - _mm256_storeu_pd(x + num_elem_per_reg * 0, alphav); - _mm256_storeu_pd(x + num_elem_per_reg * 1, alphav); - _mm256_storeu_pd(x + num_elem_per_reg * 2, alphav); - _mm256_storeu_pd(x + num_elem_per_reg * 3, alphav); - _mm256_storeu_pd(x + num_elem_per_reg * 4, alphav); - _mm256_storeu_pd(x + num_elem_per_reg * 5, alphav); - _mm256_storeu_pd(x + num_elem_per_reg * 6, alphav); - _mm256_storeu_pd(x + num_elem_per_reg * 7, alphav); - _mm256_storeu_pd(x + num_elem_per_reg * 8, alphav); - _mm256_storeu_pd(x + num_elem_per_reg * 9, alphav); - _mm256_storeu_pd(x + num_elem_per_reg * 10, alphav); - _mm256_storeu_pd(x + num_elem_per_reg * 11, alphav); - _mm256_storeu_pd(x + num_elem_per_reg * 12, alphav); - _mm256_storeu_pd(x + num_elem_per_reg * 13, alphav); - _mm256_storeu_pd(x + num_elem_per_reg * 14, alphav); - _mm256_storeu_pd(x + num_elem_per_reg * 15, alphav); - - x += num_elem_per_reg * 16; - } - for ( ; i < (n & (~0x1F)); i += 32 ) - { - _mm256_storeu_pd(x + num_elem_per_reg * 0, alphav); - _mm256_storeu_pd(x + num_elem_per_reg * 1, alphav); - _mm256_storeu_pd(x + num_elem_per_reg * 2, alphav); - _mm256_storeu_pd(x + num_elem_per_reg * 3, alphav); - _mm256_storeu_pd(x + num_elem_per_reg * 4, alphav); - _mm256_storeu_pd(x + num_elem_per_reg * 5, alphav); - _mm256_storeu_pd(x + num_elem_per_reg * 6, alphav); - _mm256_storeu_pd(x + num_elem_per_reg * 7, alphav); - - x += num_elem_per_reg * 8; - } - for ( ; i < (n & (~0xF)); i += 16 ) - { - _mm256_storeu_pd(x + num_elem_per_reg * 0, alphav); - _mm256_storeu_pd(x + num_elem_per_reg * 1, alphav); - _mm256_storeu_pd(x + num_elem_per_reg * 2, alphav); - _mm256_storeu_pd(x + num_elem_per_reg * 3, alphav); - - x += num_elem_per_reg * 4; - } - for ( ; i < (n & (~0x07)); i += 8 ) - { - _mm256_storeu_pd(x + num_elem_per_reg * 0, alphav); - _mm256_storeu_pd(x + num_elem_per_reg * 1, alphav); - - x += num_elem_per_reg * 2; - } - for ( ; i < (n & (~0x03)); i += 4 ) - { - _mm256_storeu_pd(x + num_elem_per_reg * 0, alphav); - x += num_elem_per_reg; - } - for ( ; i < n; ++i ) - { - *x++ = *alpha; - } - } - else - { - for ( i = 0; i < n; ++i ) - { - *x = *alpha; - - x += incx; - } - } + const dim_t num_elem_per_reg = 4; + dim_t i = 0; + __m256d alphav; + + double *x0 = x; + + // If the vector dimension is zero return early. + if ( bli_zero_dim1( n ) ) return; + + if ( incx == 1 ) + { + // Broadcast the alpha scalar to all elements of a vector register. + alphav = _mm256_broadcast_sd( alpha ); + + // n & (~0x3F) = n & 0xFFFFFFC0 -> this masks the numbers less than 64, + // the copy operation will be done for the multiples of 64 + for ( i = 0; i < (n & (~0x3F)); i += 64 ) + { + _mm256_storeu_pd(x0 + num_elem_per_reg * 0, alphav); + _mm256_storeu_pd(x0 + num_elem_per_reg * 1, alphav); + _mm256_storeu_pd(x0 + num_elem_per_reg * 2, alphav); + _mm256_storeu_pd(x0 + num_elem_per_reg * 3, alphav); + _mm256_storeu_pd(x0 + num_elem_per_reg * 4, alphav); + _mm256_storeu_pd(x0 + num_elem_per_reg * 5, alphav); + _mm256_storeu_pd(x0 + num_elem_per_reg * 6, alphav); + _mm256_storeu_pd(x0 + num_elem_per_reg * 7, alphav); + _mm256_storeu_pd(x0 + num_elem_per_reg * 8, alphav); + _mm256_storeu_pd(x0 + num_elem_per_reg * 9, alphav); + _mm256_storeu_pd(x0 + num_elem_per_reg * 10, alphav); + _mm256_storeu_pd(x0 + num_elem_per_reg * 11, alphav); + _mm256_storeu_pd(x0 + num_elem_per_reg * 12, alphav); + _mm256_storeu_pd(x0 + num_elem_per_reg * 13, alphav); + _mm256_storeu_pd(x0 + num_elem_per_reg * 14, alphav); + _mm256_storeu_pd(x0 + num_elem_per_reg * 15, alphav); + + x0 += num_elem_per_reg * 16; + } + for ( ; i < (n & (~0x1F)); i += 32 ) + { + _mm256_storeu_pd(x0 + num_elem_per_reg * 0, alphav); + _mm256_storeu_pd(x0 + num_elem_per_reg * 1, alphav); + _mm256_storeu_pd(x0 + num_elem_per_reg * 2, alphav); + _mm256_storeu_pd(x0 + num_elem_per_reg * 3, alphav); + _mm256_storeu_pd(x0 + num_elem_per_reg * 4, alphav); + _mm256_storeu_pd(x0 + num_elem_per_reg * 5, alphav); + _mm256_storeu_pd(x0 + num_elem_per_reg * 6, alphav); + _mm256_storeu_pd(x0 + num_elem_per_reg * 7, alphav); + + x0 += num_elem_per_reg * 8; + } + for ( ; i < (n & (~0xF)); i += 16 ) + { + _mm256_storeu_pd(x0 + num_elem_per_reg * 0, alphav); + _mm256_storeu_pd(x0 + num_elem_per_reg * 1, alphav); + _mm256_storeu_pd(x0 + num_elem_per_reg * 2, alphav); + _mm256_storeu_pd(x0 + num_elem_per_reg * 3, alphav); + + x0 += num_elem_per_reg * 4; + } + for ( ; i < (n & (~0x07)); i += 8 ) + { + _mm256_storeu_pd(x0 + num_elem_per_reg * 0, alphav); + _mm256_storeu_pd(x0 + num_elem_per_reg * 1, alphav); + + x0 += num_elem_per_reg * 2; + } + for ( ; i < (n & (~0x03)); i += 4 ) + { + _mm256_storeu_pd(x0 + num_elem_per_reg * 0, alphav); + x0 += num_elem_per_reg; + } + for ( ; i < n; ++i ) + { + *x0++ = *alpha; + } + } + else + { + for ( i = 0; i < n; ++i ) + { + *x0 = *alpha; + + x0 += incx; + } + } +} + +void bli_csetv_zen_int + ( + conj_t conjalpha, + dim_t n, + scomplex* restrict alpha, + scomplex* restrict x, inc_t incx, + cntx_t* restrict cntx + ) +{ + // Declaring and initializing local variables and pointers + const dim_t num_elem_per_reg = 8; + dim_t i = 0; + float *x0 = (float *)x; + + // If the vector dimension is zero return early. + if ( bli_zero_dim1( n ) ) return; + scomplex alpha_conj = *alpha; + + // Handle conjugation of alpha + if( bli_is_conj( conjalpha ) ) alpha_conj.imag = -alpha_conj.imag; + + if ( incx == 1 ) + { + __m256 alphaRv, alphaIv, alphav; + + // Broadcast the scomplex alpha value + alphaRv = _mm256_broadcast_ss( &(alpha_conj.real) ); + alphaIv = _mm256_broadcast_ss( &(alpha_conj.imag) ); + alphav = _mm256_unpacklo_ps( alphaRv, alphaIv ); + + // The condition n & ~0x3F => n & 0xFFFFFFC0 + // This sets the lower 6 bits to 0 and results in multiples of 64 + // Thus, we iterate in blocks of 64 scomplex elements + // Fringe loops have similar conditions to set their masks(32, 16, ...) + for ( i = 0; i < (n & (~0x3F)); i += 64 ) + { + _mm256_storeu_ps(x0 + num_elem_per_reg * 0, alphav); + _mm256_storeu_ps(x0 + num_elem_per_reg * 1, alphav); + _mm256_storeu_ps(x0 + num_elem_per_reg * 2, alphav); + _mm256_storeu_ps(x0 + num_elem_per_reg * 3, alphav); + _mm256_storeu_ps(x0 + num_elem_per_reg * 4, alphav); + _mm256_storeu_ps(x0 + num_elem_per_reg * 5, alphav); + _mm256_storeu_ps(x0 + num_elem_per_reg * 6, alphav); + _mm256_storeu_ps(x0 + num_elem_per_reg * 7, alphav); + _mm256_storeu_ps(x0 + num_elem_per_reg * 8, alphav); + _mm256_storeu_ps(x0 + num_elem_per_reg * 9, alphav); + _mm256_storeu_ps(x0 + num_elem_per_reg * 10, alphav); + _mm256_storeu_ps(x0 + num_elem_per_reg * 11, alphav); + _mm256_storeu_ps(x0 + num_elem_per_reg * 12, alphav); + _mm256_storeu_ps(x0 + num_elem_per_reg * 13, alphav); + _mm256_storeu_ps(x0 + num_elem_per_reg * 14, alphav); + _mm256_storeu_ps(x0 + num_elem_per_reg * 15, alphav); + + x0 += num_elem_per_reg * 16; + } + for ( ; i < (n & (~0x1F)); i += 32 ) + { + _mm256_storeu_ps(x0 + num_elem_per_reg * 0, alphav); + _mm256_storeu_ps(x0 + num_elem_per_reg * 1, alphav); + _mm256_storeu_ps(x0 + num_elem_per_reg * 2, alphav); + _mm256_storeu_ps(x0 + num_elem_per_reg * 3, alphav); + _mm256_storeu_ps(x0 + num_elem_per_reg * 4, alphav); + _mm256_storeu_ps(x0 + num_elem_per_reg * 5, alphav); + _mm256_storeu_ps(x0 + num_elem_per_reg * 6, alphav); + _mm256_storeu_ps(x0 + num_elem_per_reg * 7, alphav); + + x0 += num_elem_per_reg * 8; + } + for ( ; i < (n & (~0x0F)); i += 16 ) + { + _mm256_storeu_ps(x0 + num_elem_per_reg * 0, alphav); + _mm256_storeu_ps(x0 + num_elem_per_reg * 1, alphav); + _mm256_storeu_ps(x0 + num_elem_per_reg * 2, alphav); + _mm256_storeu_ps(x0 + num_elem_per_reg * 3, alphav); + + x0 += num_elem_per_reg * 4; + } + for ( ; i < (n & (~0x07)); i += 8 ) + { + _mm256_storeu_ps(x0 + num_elem_per_reg * 0, alphav); + _mm256_storeu_ps(x0 + num_elem_per_reg * 1, alphav); + + x0 += num_elem_per_reg * 2; + } + for ( ; i < (n & (~0x03)); i += 4 ) + { + _mm256_storeu_ps(x0 + num_elem_per_reg * 0, alphav); + x0 += num_elem_per_reg; + } + } + + // Code-section for non-unit stride + for( ; i < n; i += 1 ) + { + *x0 = alpha_conj.real; + *(x0 + 1) = alpha_conj.imag; + + x0 += 2 * incx; + } + +} + +void bli_zsetv_zen_int + ( + conj_t conjalpha, + dim_t n, + dcomplex* restrict alpha, + dcomplex* restrict x, inc_t incx, + cntx_t* restrict cntx + ) +{ + // Declaring and initializing local variables and pointers + const dim_t num_elem_per_reg = 4; + dim_t i = 0; + double *x0 = (double *)x; + + // If the vector dimension is zero return early. + if ( bli_zero_dim1( n ) ) return; + + // Handle conjugation of alpha + if( bli_is_conj( conjalpha ) ) alpha->imag = -alpha->imag; + + if ( incx == 1 ) + { + __m256d alphav; + + // Broadcast the dcomplex alpha value + alphav = _mm256_broadcast_pd( (const __m128d *)alpha ); + + // The condition n & ~0x1F => n & 0xFFFFFFE0 + // This sets the lower 5 bits to 0 and results in multiples of 32 + // Thus, we iterate in blocks of 32 elements + // Fringe loops have similar conditions to set their masks(16, 8, ...) + for ( i = 0; i < (n & (~0x1F)); i += 32 ) + { + _mm256_storeu_pd(x0 + num_elem_per_reg * 0, alphav); + _mm256_storeu_pd(x0 + num_elem_per_reg * 1, alphav); + _mm256_storeu_pd(x0 + num_elem_per_reg * 2, alphav); + _mm256_storeu_pd(x0 + num_elem_per_reg * 3, alphav); + _mm256_storeu_pd(x0 + num_elem_per_reg * 4, alphav); + _mm256_storeu_pd(x0 + num_elem_per_reg * 5, alphav); + _mm256_storeu_pd(x0 + num_elem_per_reg * 6, alphav); + _mm256_storeu_pd(x0 + num_elem_per_reg * 7, alphav); + _mm256_storeu_pd(x0 + num_elem_per_reg * 8, alphav); + _mm256_storeu_pd(x0 + num_elem_per_reg * 9, alphav); + _mm256_storeu_pd(x0 + num_elem_per_reg * 10, alphav); + _mm256_storeu_pd(x0 + num_elem_per_reg * 11, alphav); + _mm256_storeu_pd(x0 + num_elem_per_reg * 12, alphav); + _mm256_storeu_pd(x0 + num_elem_per_reg * 13, alphav); + _mm256_storeu_pd(x0 + num_elem_per_reg * 14, alphav); + _mm256_storeu_pd(x0 + num_elem_per_reg * 15, alphav); + + x0 += num_elem_per_reg * 16; + } + for ( ; i < (n & (~0x0F)); i += 16 ) + { + _mm256_storeu_pd(x0 + num_elem_per_reg * 0, alphav); + _mm256_storeu_pd(x0 + num_elem_per_reg * 1, alphav); + _mm256_storeu_pd(x0 + num_elem_per_reg * 2, alphav); + _mm256_storeu_pd(x0 + num_elem_per_reg * 3, alphav); + _mm256_storeu_pd(x0 + num_elem_per_reg * 4, alphav); + _mm256_storeu_pd(x0 + num_elem_per_reg * 5, alphav); + _mm256_storeu_pd(x0 + num_elem_per_reg * 6, alphav); + _mm256_storeu_pd(x0 + num_elem_per_reg * 7, alphav); + + x0 += num_elem_per_reg * 8; + } + for ( ; i < (n & (~0x07)); i += 8 ) + { + _mm256_storeu_pd(x0 + num_elem_per_reg * 0, alphav); + _mm256_storeu_pd(x0 + num_elem_per_reg * 1, alphav); + _mm256_storeu_pd(x0 + num_elem_per_reg * 2, alphav); + _mm256_storeu_pd(x0 + num_elem_per_reg * 3, alphav); + + x0 += num_elem_per_reg * 4; + } + for ( ; i < (n & (~0x03)); i += 4 ) + { + _mm256_storeu_pd(x0 + num_elem_per_reg * 0, alphav); + _mm256_storeu_pd(x0 + num_elem_per_reg * 1, alphav); + + x0 += num_elem_per_reg * 2; + } + for ( ; i < (n & (~0x01)); i += 2 ) + { + _mm256_storeu_pd(x0 + num_elem_per_reg * 0, alphav); + x0 += num_elem_per_reg; + } + + // Issue vzeroupper instruction to clear upper lanes of ymm registers. + // This avoids a performance penalty caused by false dependencies when + // transitioning from AVX to SSE instructions (which may occur later, + // especially if BLIS is compiled with -mfpmath=sse). + _mm256_zeroupper(); + } + + if ( i < n ) + { + __m128d alphav; + alphav = _mm_loadu_pd((const double*)alpha); + + for( ; i < n; i += 1 ) + { + _mm_storeu_pd(x0, alphav); + x0 += 2 * incx; + } + } + } diff --git a/kernels/zen/1/bli_swapv_zen_int8.c b/kernels/zen/1/bli_swapv_zen_int8.c index ba7c92593c..61c022a99a 100644 --- a/kernels/zen/1/bli_swapv_zen_int8.c +++ b/kernels/zen/1/bli_swapv_zen_int8.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -202,7 +202,7 @@ void bli_sswapv_zen_int8 //-------------------------------------------------------------------------------- -void bli_dswapv_zen_int8 +BLIS_EXPORT_BLIS void bli_dswapv_zen_int8 ( dim_t n, double* restrict x, inc_t incx, diff --git a/kernels/zen/1f/bli_axpy2v_zen_int.c b/kernels/zen/1f/bli_axpy2v_zen_int.c index 9d0d42dd3d..5b3196376e 100644 --- a/kernels/zen/1f/bli_axpy2v_zen_int.c +++ b/kernels/zen/1f/bli_axpy2v_zen_int.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2018, The University of Texas at Austin - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -192,7 +192,7 @@ void bli_daxpy2v_zen_int * z := z + alphax * conjx(x) + alphay * conjy(y) * where, * x, y & z are double complex vectors of length n. - * alpha & beta are complex scalers. + * alpha & beta are complex scalars. */ void bli_zaxpy2v_zen_int ( diff --git a/kernels/zen/1f/bli_axpyf_zen_int_4.c b/kernels/zen/1f/bli_axpyf_zen_int_4.c index 43236887d9..1bc3de6572 100644 --- a/kernels/zen/1f/bli_axpyf_zen_int_4.c +++ b/kernels/zen/1f/bli_axpyf_zen_int_4.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -440,10 +440,10 @@ void bli_zaxpyf_zen_int_4 // Prefetching the elements of A to the L1 cache. // These will be used even if SSE instructions are used - _mm_prefetch(a_ptr[0], _MM_HINT_T1); - _mm_prefetch(a_ptr[1], _MM_HINT_T1); - _mm_prefetch(a_ptr[2], _MM_HINT_T1); - _mm_prefetch(a_ptr[3], _MM_HINT_T1); + _mm_prefetch((char const*)(a_ptr[0]), _MM_HINT_T1); + _mm_prefetch((char const*)(a_ptr[1]), _MM_HINT_T1); + _mm_prefetch((char const*)(a_ptr[2]), _MM_HINT_T1); + _mm_prefetch((char const*)(a_ptr[3]), _MM_HINT_T1); if (inca == 1 && incy == 1) { @@ -482,15 +482,15 @@ void bli_zaxpyf_zen_int_4 ymm12.v = _mm256_fmadd_pd(ymm10.v, ymm2.v, ymm12.v); ymm13.v = _mm256_fmadd_pd(ymm10.v, ymm3.v, ymm13.v); - _mm_prefetch(a_ptr[0] + distance, _MM_HINT_T1); - _mm_prefetch(a_ptr[1] + distance, _MM_HINT_T1); - _mm_prefetch(a_ptr[2] + distance, _MM_HINT_T1); - _mm_prefetch(a_ptr[3] + distance, _MM_HINT_T1); + _mm_prefetch((char const*)(a_ptr[0] + distance), _MM_HINT_T1); + _mm_prefetch((char const*)(a_ptr[1] + distance), _MM_HINT_T1); + _mm_prefetch((char const*)(a_ptr[2] + distance), _MM_HINT_T1); + _mm_prefetch((char const*)(a_ptr[3] + distance), _MM_HINT_T1); ymm12.v = _mm256_fmadd_pd(ymm14.v, ymm4.v, ymm12.v); ymm13.v = _mm256_fmadd_pd(ymm14.v, ymm5.v, ymm13.v); - _mm_prefetch(y0 + distance, _MM_HINT_T1); + _mm_prefetch((char const*)(y0 + distance), _MM_HINT_T1); ymm12.v = _mm256_fmadd_pd(ymm15.v, ymm6.v, ymm12.v); ymm13.v = _mm256_fmadd_pd(ymm15.v, ymm7.v, ymm13.v); @@ -519,15 +519,15 @@ void bli_zaxpyf_zen_int_4 ymm12.v = _mm256_fmadd_pd(ymm10.v, ymm2.v, ymm12.v); ymm13.v = _mm256_fmadd_pd(ymm10.v, ymm3.v, ymm13.v); - _mm_prefetch(a_ptr[0] + distance * 2, _MM_HINT_T1); - _mm_prefetch(a_ptr[1] + distance * 2, _MM_HINT_T1); - _mm_prefetch(a_ptr[2] + distance * 2, _MM_HINT_T1); - _mm_prefetch(a_ptr[3] + distance * 2, _MM_HINT_T1); + _mm_prefetch((char const*)(a_ptr[0] + distance * 2), _MM_HINT_T1); + _mm_prefetch((char const*)(a_ptr[1] + distance * 2), _MM_HINT_T1); + _mm_prefetch((char const*)(a_ptr[2] + distance * 2), _MM_HINT_T1); + _mm_prefetch((char const*)(a_ptr[3] + distance * 2), _MM_HINT_T1); ymm12.v = _mm256_fmadd_pd(ymm14.v, ymm4.v, ymm12.v); ymm13.v = _mm256_fmadd_pd(ymm14.v, ymm5.v, ymm13.v); - _mm_prefetch(y0 + distance * 2, _MM_HINT_T1); + _mm_prefetch((char const*)(y0 + distance * 2), _MM_HINT_T1); ymm12.v = _mm256_fmadd_pd(ymm15.v, ymm6.v, ymm12.v); ymm13.v = _mm256_fmadd_pd(ymm15.v, ymm7.v, ymm13.v); @@ -605,15 +605,15 @@ void bli_zaxpyf_zen_int_4 ymm12.v = _mm256_fmadd_pd(ymm10.v, ymm2.v, ymm12.v); ymm13.v = _mm256_fmadd_pd(ymm10.v, ymm3.v, ymm13.v); - _mm_prefetch(a_ptr[0] + distance, _MM_HINT_T1); - _mm_prefetch(a_ptr[1] + distance, _MM_HINT_T1); - _mm_prefetch(a_ptr[2] + distance, _MM_HINT_T1); - _mm_prefetch(a_ptr[3] + distance, _MM_HINT_T1); + _mm_prefetch((char const*)(a_ptr[0] + distance), _MM_HINT_T1); + _mm_prefetch((char const*)(a_ptr[1] + distance), _MM_HINT_T1); + _mm_prefetch((char const*)(a_ptr[2] + distance), _MM_HINT_T1); + _mm_prefetch((char const*)(a_ptr[3] + distance), _MM_HINT_T1); ymm12.v = _mm256_fmadd_pd(ymm14.v, ymm4.v, ymm12.v); ymm13.v = _mm256_fmadd_pd(ymm14.v, ymm5.v, ymm13.v); - _mm_prefetch(y0 + distance, _MM_HINT_T1); + _mm_prefetch((char const*)(y0 + distance), _MM_HINT_T1); ymm12.v = _mm256_fmadd_pd(ymm15.v, ymm6.v, ymm12.v); ymm13.v = _mm256_fmadd_pd(ymm15.v, ymm7.v, ymm13.v); @@ -641,15 +641,15 @@ void bli_zaxpyf_zen_int_4 ymm12.v = _mm256_fmadd_pd(ymm10.v, ymm2.v, ymm12.v); ymm13.v = _mm256_fmadd_pd(ymm10.v, ymm3.v, ymm13.v); - _mm_prefetch(a_ptr[0] + distance * 2, _MM_HINT_T1); - _mm_prefetch(a_ptr[1] + distance * 2, _MM_HINT_T1); - _mm_prefetch(a_ptr[2] + distance * 2, _MM_HINT_T1); - _mm_prefetch(a_ptr[3] + distance * 2, _MM_HINT_T1); + _mm_prefetch((char const*)(a_ptr[0] + distance * 2), _MM_HINT_T1); + _mm_prefetch((char const*)(a_ptr[1] + distance * 2), _MM_HINT_T1); + _mm_prefetch((char const*)(a_ptr[2] + distance * 2), _MM_HINT_T1); + _mm_prefetch((char const*)(a_ptr[3] + distance * 2), _MM_HINT_T1); ymm12.v = _mm256_fmadd_pd(ymm14.v, ymm4.v, ymm12.v); ymm13.v = _mm256_fmadd_pd(ymm14.v, ymm5.v, ymm13.v); - _mm_prefetch(y0 + distance * 2, _MM_HINT_T1); + _mm_prefetch((char const*)(y0 + distance * 2), _MM_HINT_T1); ymm12.v = _mm256_fmadd_pd(ymm15.v, ymm6.v, ymm12.v); ymm13.v = _mm256_fmadd_pd(ymm15.v, ymm7.v, ymm13.v); diff --git a/kernels/zen/1f/bli_axpyf_zen_int_8.c b/kernels/zen/1f/bli_axpyf_zen_int_8.c index 3da593cf74..ae1e613ccd 100644 --- a/kernels/zen/1f/bli_axpyf_zen_int_8.c +++ b/kernels/zen/1f/bli_axpyf_zen_int_8.c @@ -301,7 +301,7 @@ void bli_daxpyf_zen_int_8 operation as axpyv or perform the operation using axpyf kernels with lower fuse factor. */ - if ( b_n != fuse_fac ) + if ( b_n < fuse_fac ) { if (b_n >= 5) { @@ -399,6 +399,33 @@ void bli_daxpyf_zen_int_8 return; } + else if ( b_n > fuse_fac ) + { + daxpyv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_DOUBLE, BLIS_AXPYV_KER, cntx ); + + for ( i = 0; i < b_n; ++i ) + { + double* a1 = a + (0 )*inca + (i )*lda; + double* chi1 = x + (i )*incx; + double* y1 = y + (0 )*incy; + double alpha_chi1; + + bli_dcopycjs( conjx, *chi1, alpha_chi1 ); + bli_dscals( *alpha, alpha_chi1 ); + + f + ( + conja, + m, + &alpha_chi1, + a1, inca, + y1, incy, + cntx + ); + } + + return; + } // At this point, we know that b_n is exactly equal to the fusing factor. diff --git a/kernels/zen/1f/bli_dotxaxpyf_zen_int_8.c b/kernels/zen/1f/bli_dotxaxpyf_zen_int_8.c index fbd354593c..91222c3245 100644 --- a/kernels/zen/1f/bli_dotxaxpyf_zen_int_8.c +++ b/kernels/zen/1f/bli_dotxaxpyf_zen_int_8.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -782,7 +782,7 @@ void bli_zdotxaxpyf_zen_int_8 // Temporary rho buffer holds computed dot product result dcomplex rho[ 4 ]; - // chi? variables to hold scaled scaler values from x vector + // chi? variables to hold scaled scalar values from x vector dcomplex chi0; dcomplex chi1; dcomplex chi2; @@ -1189,7 +1189,7 @@ void bli_cdotxaxpyf_zen_int_8 // Temporary rho buffer holds computed dot product result scomplex rho[ 4 ]; - // chi? variables to hold scaled scaler values from x vector + // chi? variables to hold scaled scalar values from x vector scomplex chi0; scomplex chi1; scomplex chi2; diff --git a/kernels/zen/1f/bli_dotxf_zen_int_8.c b/kernels/zen/1f/bli_dotxf_zen_int_8.c index bb39992de8..3f31d483ec 100644 --- a/kernels/zen/1f/bli_dotxf_zen_int_8.c +++ b/kernels/zen/1f/bli_dotxf_zen_int_8.c @@ -463,7 +463,7 @@ void bli_ddotxf_zen_int_8 operation as dotxv or perform the operation using dotxf kernels with lower fuse factor. */ - if (b_n != fuse_fac) + if (b_n < fuse_fac) { if (b_n >= 4) { @@ -535,6 +535,27 @@ void bli_ddotxf_zen_int_8 } return; } + else if ( b_n > fuse_fac ) + { + for (dim_t i = 0; i < b_n; ++i) + { + double *a1 = a + (0) * inca + (i)*lda; + double *x1 = x + (0) * incx; + double *psi1 = y + (i)*incy; + + bli_ddotxv_zen_int( + conjat, + conjx, + m, + alpha, + a1, inca, + x1, incx, + beta, + psi1, + cntx); + } + return; + } // At this point, we know that b_n is exactly equal to the fusing factor. // However, m may not be a multiple of the number of elements per vector. diff --git a/kernels/zen/2/bli_gemv_zen_int_4.c b/kernels/zen/2/bli_gemv_zen_int_4.c index 6970a7f62a..320b696561 100644 --- a/kernels/zen/2/bli_gemv_zen_int_4.c +++ b/kernels/zen/2/bli_gemv_zen_int_4.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -565,7 +565,9 @@ void bli_multi_sgemv_4x2 // Calculate the total number of multithreaded iteration total_iteration = b_n / b_fuse; +#ifdef BLIS_ENABLE_OPENMP _Pragma( "omp parallel for num_threads(n_threads)" ) +#endif for (dim_t j = 0; j < total_iteration; j++) { float *A1 = a + (b_fuse * j) * lda; diff --git a/kernels/zen/2/bli_gemv_zen_ref.c b/kernels/zen/2/bli_gemv_zen_ref.c index 0e53a5240f..5da6a332af 100644 --- a/kernels/zen/2/bli_gemv_zen_ref.c +++ b/kernels/zen/2/bli_gemv_zen_ref.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -120,3 +120,119 @@ void bli_dgemv_zen_ref_c } return; } + +/** + * bli_dgemv_zen_ref( ... ) + * This reference kernel for DGEMV supports row/colum storage schemes for both + * transpose and no-transpose cases. + */ +void bli_dgemv_zen_ref + ( + trans_t transa, + dim_t m, + dim_t n, + double* restrict alpha, + double* restrict a, inc_t inca, inc_t lda, + double* restrict x, inc_t incx, + double* restrict beta, + double* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + dim_t m0 = m; + dim_t n0 = n; + dim_t leny = m0; // Initializing length of y vector. + + double* a0 = (double*) a; + double* x0 = (double*) x; + double* y0 = (double*) y; + + if ( bli_is_trans( transa ) || bli_is_conjtrans( transa ) ) + { + // Updating length of y matrix if transpose is enabled. + leny = n0; + } + + // Perform y := beta * y + if ( !bli_deq1(*beta) ) // beta != 1 + { + if ( bli_deq0(*beta) ) // beta == 0 + { + for ( dim_t i = 0; i < leny; ++i ) + { + PASTEMAC(d,sets)( 0.0, 0.0, *(y0 + i*incy)) + } + } + else // beta != 0 + { + for ( dim_t i = 0; i < leny; ++i ) + { + PASTEMAC(d,scals)( *beta, *(y0 + i*incy) ) + } + } + } + + // If alpha == 0, return. + if ( bli_deq0( *alpha ) ) return; + + if ( bli_is_notrans( transa ) ) // BLIS_NO_TRANSPOSE + { + if ( incy == 1 ) + { + for ( dim_t i = 0; i < n0; ++i ) + { + double rho = (*alpha) * (*x0); + for ( dim_t j = 0; j < m0; ++j ) + { + *(y0 + j) += rho * (*(a0 + j)); + } + x0 += incx; + a0 += lda; + } + } + else // if ( incy != 1 ) + { + for ( dim_t i = 0; i < n0; ++i ) + { + double rho = (*alpha) * (*x0); + for ( dim_t j = 0; j < m0; ++j ) + { + *(y0 + j*incy) += rho * (*(a0 + j)); + } + x0 += incx; + a0 += lda; + } + } + } + else // BLIS_TRANSPOSE + { + if ( incx == 1 ) + { + for ( dim_t i = 0; i < n0; ++i ) + { + double rho = 0.0; + for ( dim_t j = 0; j < m0; ++j ) + { + rho += (*(a0 + j)) * (*(x0 + j)); + } + (*y0) += (*alpha) * rho; + y0 += incy; + a0 += lda; + } + } + else // if ( incx != 1 ) + { + for ( dim_t i = 0; i < n0; ++i ) + { + double rho = 0.0; + for ( dim_t j = 0; j < m0; ++j ) + { + rho += (*(a0 + j)) * (*(x0 + j*incx)); + } + (*y0) += (*alpha) * rho; + y0 += incy; + a0 += lda; + } + } + } +} diff --git a/kernels/zen/3/bli_gemm_tiny.c b/kernels/zen/3/bli_gemm_tiny.c index bf6ffa5cc2..32e42490be 100644 --- a/kernels/zen/3/bli_gemm_tiny.c +++ b/kernels/zen/3/bli_gemm_tiny.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -47,7 +47,7 @@ static dgemmsup_ker_ft kern_fp[] = bli_dgemmsup_rv_haswell_asm_6x8n }; -#if defined(BLIS_FAMILY_ZEN4) || defined(BLIS_FAMILY_AMDZEN) || defined(BLIS_FAMILY_X86_64) +#if defined(BLIS_FAMILY_ZEN5) || defined(BLIS_FAMILY_ZEN4) || defined(BLIS_FAMILY_AMDZEN) || defined(BLIS_FAMILY_X86_64) static err_t bli_dgemm_tiny_24x8_kernel ( conj_t conja, @@ -514,67 +514,15 @@ err_t bli_dgemm_tiny double* c, const inc_t rs_c0, const inc_t cs_c0 ) { - arch_t arch_id = get_arch_id(); - //for the below tiny sizes of matrix, we force it to be ST compute. - if( - m <= 24 && n <= 24 && k <= 20 && - (BLIS_ARCH_ZEN == arch_id || - BLIS_ARCH_ZEN2 == arch_id || - BLIS_ARCH_ZEN3 == arch_id || - BLIS_ARCH_ZEN4 == arch_id) - ) - { - bool ret = bli_aocl_enable_instruction_query(); - if((ret == FALSE) || - (arch_id != BLIS_ARCH_ZEN4) - ) - { - return bli_dgemm_tiny_6x8_kernel - ( - 1 * (transa == BLIS_CONJ_NO_TRANSPOSE), - 1 * (transb == BLIS_CONJ_NO_TRANSPOSE), - transa, - transb, - m, - n, - k, - alpha, - a, rs_a0, cs_a0, - b, rs_b0, cs_b0, - beta, - c, rs_c0, cs_c0 - ); - } -#if defined(BLIS_FAMILY_ZEN4) || defined(BLIS_FAMILY_AMDZEN) || defined(BLIS_FAMILY_X86_64) - else if(BLIS_ARCH_ZEN4 == arch_id) - { - return bli_dgemm_tiny_24x8_kernel - ( - 1 * (transa == BLIS_CONJ_NO_TRANSPOSE), - 1 * (transb == BLIS_CONJ_NO_TRANSPOSE), - transa, - transb, - m, - n, - k, - alpha, - a, rs_a0, cs_a0, - b, rs_b0, cs_b0, - beta, - c, rs_c0, cs_c0 - ); - } -#endif - } - if(FALSE == bli_thread_get_is_parallel()) + // Query the architecture ID + arch_t id = bli_arch_query_id(); + + // Pick the kernel based on the architecture ID + switch (id) { - if( - BLIS_ARCH_ZEN == arch_id || - BLIS_ARCH_ZEN2 == arch_id || - BLIS_ARCH_ZEN3 == arch_id - ) - { - if( ( (m <= 8) || ( (m <= 1000) && (n <= 24) && (k >= 4) ) ) && (k <= 1500) ) + case BLIS_ARCH_ZEN5: + if(m<24 && ((n<=24 && k<=20) || + (n<=50 && ((m<=4 && k<=50) || (m!=8 && m!=9 && m!=16 && k<=10))))) { return bli_dgemm_tiny_6x8_kernel ( @@ -592,16 +540,14 @@ err_t bli_dgemm_tiny c, rs_c0, cs_c0 ); } - } -#if defined(BLIS_FAMILY_ZEN4) || defined(BLIS_FAMILY_AMDZEN) || defined(BLIS_FAMILY_X86_64) - else if(BLIS_ARCH_ZEN4 == arch_id) - { - if(((m == n) && (m < 400) && (k < 1000)) || - ( (m != n) && (( ((m + n -k) < 1500) && - ((m + k-n) < 1500) && ((n + k-m) < 1500) ) || - ((n <= 100) && (k <=100))))) + break; + case BLIS_ARCH_ZEN4: + case BLIS_ARCH_ZEN3: + case BLIS_ARCH_ZEN2: + case BLIS_ARCH_ZEN: + if(m <= 24 && n <= 24 && k <= 20) { - return bli_dgemm_tiny_24x8_kernel + return bli_dgemm_tiny_6x8_kernel ( 1 * (transa == BLIS_CONJ_NO_TRANSPOSE), 1 * (transb == BLIS_CONJ_NO_TRANSPOSE), @@ -617,11 +563,67 @@ err_t bli_dgemm_tiny c, rs_c0, cs_c0 ); } - } -#endif - else + break; + default: + return BLIS_FAILURE; + } + + if(FALSE == bli_thread_get_is_parallel()) + { + // Pick the kernel based on the architecture ID + switch (id) { - ;//Return failure + case BLIS_ARCH_ZEN5: + case BLIS_ARCH_ZEN4: +#if defined(BLIS_FAMILY_ZEN5) || defined(BLIS_FAMILY_ZEN4) || defined(BLIS_FAMILY_AMDZEN) || defined(BLIS_FAMILY_X86_64) + if(((m == n) && (m < 400) && (k < 1000)) || + ( (m != n) && (( ((m + n -k) < 1500) && + ((m + k-n) < 1500) && ((n + k-m) < 1500) ) || + ((n <= 100) && (k <=100))))) + { + return bli_dgemm_tiny_24x8_kernel + ( + 1 * (transa == BLIS_CONJ_NO_TRANSPOSE), + 1 * (transb == BLIS_CONJ_NO_TRANSPOSE), + transa, + transb, + m, + n, + k, + alpha, + a, rs_a0, cs_a0, + b, rs_b0, cs_b0, + beta, + c, rs_c0, cs_c0 + ); + } +#endif + break; + + case BLIS_ARCH_ZEN: + case BLIS_ARCH_ZEN2: + case BLIS_ARCH_ZEN3: + if( ( (m <= 8) || ( (m <= 1000) && (n <= 24) && (k >= 4) ) ) && (k <= 1500) ) + { + return bli_dgemm_tiny_6x8_kernel + ( + 1 * (transa == BLIS_CONJ_NO_TRANSPOSE), + 1 * (transb == BLIS_CONJ_NO_TRANSPOSE), + transa, + transb, + m, + n, + k, + alpha, + a, rs_a0, cs_a0, + b, rs_b0, cs_b0, + beta, + c, rs_c0, cs_c0 + ); + } + break; + default: + return BLIS_FAILURE; } } diff --git a/kernels/zen/3/bli_trsm_small.c b/kernels/zen/3/bli_trsm_small.c index 0fd06c86f5..affd8ce147 100644 --- a/kernels/zen/3/bli_trsm_small.c +++ b/kernels/zen/3/bli_trsm_small.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -5123,7 +5123,10 @@ err_t bli_trsm_small switch(dt) { case BLIS_DOUBLE: + case BLIS_DCOMPLEX: { + // threshold checks for these datatypes is + // done at bla layer break; } case BLIS_FLOAT: @@ -5134,13 +5137,6 @@ err_t bli_trsm_small } break; } - case BLIS_DCOMPLEX: - { - if((!is_parallel) && (m > 500 || n > 500)) { - return BLIS_NOT_YET_IMPLEMENTED; - } - break; - } default: { return BLIS_NOT_YET_IMPLEMENTED; @@ -12921,7 +12917,10 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_loadu_pd((double const *)b11); + ymm0 = _mm256_broadcast_sd((double const *)b11 + 2); + xmm5 = _mm_loadu_pd((double *)(b11)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); + ymm3 = _mm256_blend_pd(ymm6, ymm3, 0x07); BLIS_POST_DTRSM_SMALL_1N_3M(b11,cs_b) @@ -40221,8 +40220,8 @@ BLIS_INLINE void ctrsm_small_pack_diag_element ymm18 = _mm256_setr_ps(-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0);\ for(k = 0; k< k_iter; k++) \ { \ - ymm0 = _mm256_broadcast_ps(( __m128 const *)(b10));\ - ymm0 = _mm256_permute_ps(ymm0, 0x44);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ \ ymm2 = _mm256_broadcast_ss(tptr + p_lda * 0 + 0);\ ymm3 = _mm256_broadcast_ss(tptr + p_lda * 0 + 1);\ @@ -40246,10 +40245,8 @@ BLIS_INLINE void ctrsm_small_pack_diag_element else {\ for(k = 0; k< k_iter; k++) \ { \ - ymm0 = _mm256_broadcast_ps(( __m128 const *)(b10 + 2));\ - ymm0 = _mm256_permute_ps(ymm0, 0x44);\ - xmm5 = _mm_loadu_ps((float const *)(b10));\ - ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ \ ymm2 = _mm256_broadcast_ss(tptr + p_lda * 0 + 0);\ ymm3 = _mm256_broadcast_ss(tptr + p_lda * 0 + 1);\ diff --git a/kernels/zen/3/bli_zgemm_avx2_k1.c b/kernels/zen/3/bli_zgemm_avx2_k1.c index 669afcfcfe..dfb45e812f 100644 --- a/kernels/zen/3/bli_zgemm_avx2_k1.c +++ b/kernels/zen/3/bli_zgemm_avx2_k1.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -90,16 +90,16 @@ with k == 1. It expects the inputs and output to support the column-major storage scheme, without any requirement to conjugate/transpose any of the operands. */ -void bli_zgemm_4x4_avx2_k1_nn -( - dim_t m, - dim_t n, - dim_t k, - dcomplex* alpha, - dcomplex* a, const inc_t lda, - dcomplex* b, const inc_t ldb, - dcomplex* beta, - dcomplex* c, const inc_t ldc +err_t bli_zgemm_4x4_avx2_k1_nn + ( + dim_t m, + dim_t n, + dim_t k, + dcomplex* alpha, + dcomplex* a, const inc_t lda, + dcomplex* b, const inc_t ldb, + dcomplex* beta, + dcomplex* c, const inc_t ldc ) { // Setting the required variables for choosing the right path @@ -1123,7 +1123,7 @@ void bli_zgemm_4x4_avx2_k1_nn temp_cij += Z_MR; temp_ai += Z_MR; } - } + return BLIS_SUCCESS; } diff --git a/kernels/zen/3/bli_zgemm_zen_2x6.c b/kernels/zen/3/bli_zgemm_zen_2x6.c index e29537bda8..f846fb03b4 100644 --- a/kernels/zen/3/bli_zgemm_zen_2x6.c +++ b/kernels/zen/3/bli_zgemm_zen_2x6.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -19,14 +19,14 @@ from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - AS IS AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY - OF TEXAS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, - EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, - PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR - PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY - OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/kernels/zen/3/bli_zgemmtrsm_l_2x6.c b/kernels/zen/3/bli_zgemmtrsm_l_2x6.c index 4a8d7c1b1d..2841b82cb0 100644 --- a/kernels/zen/3/bli_zgemmtrsm_l_2x6.c +++ b/kernels/zen/3/bli_zgemmtrsm_l_2x6.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -19,14 +19,14 @@ from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - AS IS AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY - OF TEXAS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, - EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, - PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR - PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY - OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/kernels/zen/3/bli_zgemmtrsm_u_2x6.c b/kernels/zen/3/bli_zgemmtrsm_u_2x6.c index 12b5a61d99..a66e8bb91e 100644 --- a/kernels/zen/3/bli_zgemmtrsm_u_2x6.c +++ b/kernels/zen/3/bli_zgemmtrsm_u_2x6.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -19,14 +19,14 @@ from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - AS IS AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY - OF TEXAS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, - EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, - PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR - PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY - OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/kernels/zen/3/sup/bli_gemmsup_rd_zen_asm_s6x16.c b/kernels/zen/3/sup/bli_gemmsup_rd_zen_asm_s6x16.c index a5dafcfcc3..ad87a1f817 100644 --- a/kernels/zen/3/sup/bli_gemmsup_rd_zen_asm_s6x16.c +++ b/kernels/zen/3/sup/bli_gemmsup_rd_zen_asm_s6x16.c @@ -1,10 +1,11 @@ /* + BLIS An object-based framework for developing high-performance BLAS-like libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -29,6 +30,7 @@ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ #include "blis.h" diff --git a/kernels/zen/3/sup/bli_gemmsup_rd_zen_asm_z3x4.c b/kernels/zen/3/sup/bli_gemmsup_rd_zen_asm_z3x4.c index 6597742a9d..1a2c57da1f 100644 --- a/kernels/zen/3/sup/bli_gemmsup_rd_zen_asm_z3x4.c +++ b/kernels/zen/3/sup/bli_gemmsup_rd_zen_asm_z3x4.c @@ -1,9 +1,10 @@ /* + BLIS An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -16,6 +17,7 @@ - Neither the name(s) of the copyright holder(s) nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR @@ -27,6 +29,7 @@ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ #include "blis.h" #define BLIS_ASM_SYNTAX_ATT diff --git a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_c3x8n.c b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_c3x8n.c index 8911e97d2c..56bfb865d6 100644 --- a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_c3x8n.c +++ b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_c3x8n.c @@ -1,4 +1,3 @@ - /* BLIS @@ -6,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4m.c b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4m.c index 804e196e12..ef9b0151ea 100644 --- a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4m.c +++ b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4m.c @@ -194,7 +194,6 @@ void bli_zgemmsup_rv_zen_asm_3x4m { if(alpha->real == 1.0) alpha_mul_type = BLIS_MUL_ONE; else if(alpha->real == -1.0) alpha_mul_type = BLIS_MUL_MINUS_ONE; - else if(alpha->real == 0.0) alpha_mul_type = BLIS_MUL_ZERO; } if(beta->imag == 0.0)// (beta is real) diff --git a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4n.c b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4n.c index 4e90b444d5..60b92b49f9 100644 --- a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4n.c +++ b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4n.c @@ -132,7 +132,6 @@ void bli_zgemmsup_rv_zen_asm_3x4n { if(alpha->real == 1.0) alpha_mul_type = BLIS_MUL_ONE; else if(alpha->real == -1.0) alpha_mul_type = BLIS_MUL_MINUS_ONE; - else if(alpha->real == 0.0) alpha_mul_type = BLIS_MUL_ZERO; } if(beta->imag == 0.0)// (beta is real) diff --git a/kernels/zen/3/sup/broken/bli_gemmsup_rv_zen_asm_c3x8n.c b/kernels/zen/3/sup/broken/bli_gemmsup_rv_zen_asm_c3x8n.c index 2e2b888f08..b50d68cc2e 100644 --- a/kernels/zen/3/sup/broken/bli_gemmsup_rv_zen_asm_c3x8n.c +++ b/kernels/zen/3/sup/broken/bli_gemmsup_rv_zen_asm_c3x8n.c @@ -1,4 +1,3 @@ - /* BLIS @@ -6,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/zen/3/sup/other/bli_gemmsup_rd_zen_asm_s6x16.c b/kernels/zen/3/sup/other/bli_gemmsup_rd_zen_asm_s6x16.c index c0c4d5f198..fa4a4d7bd1 100644 --- a/kernels/zen/3/sup/other/bli_gemmsup_rd_zen_asm_s6x16.c +++ b/kernels/zen/3/sup/other/bli_gemmsup_rd_zen_asm_s6x16.c @@ -1,10 +1,11 @@ /* + BLIS An object-based framework for developing high-performance BLAS-like libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -17,6 +18,7 @@ - Neither the name(s) of the copyright holder(s) nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR @@ -28,6 +30,7 @@ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ #include "blis.h" #define BLIS_ASM_SYNTAX_ATT diff --git a/kernels/zen/3/sup/s6x16/bli_gemmsup_rv_zen_asm_s5x16_mask.c b/kernels/zen/3/sup/s6x16/bli_gemmsup_rv_zen_asm_s5x16_mask.c index 3b93fc6802..acd644ffed 100644 --- a/kernels/zen/3/sup/s6x16/bli_gemmsup_rv_zen_asm_s5x16_mask.c +++ b/kernels/zen/3/sup/s6x16/bli_gemmsup_rv_zen_asm_s5x16_mask.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -13,7 +13,7 @@ notice, this list of conditions and the following disclaimer. - Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the - documentation and/or other materia provided with the distribution. + documentation and/or other materials provided with the distribution. - Neither the name(s) of the copyright holder(s) nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. diff --git a/kernels/zen/3/sup/s6x16/bli_gemmsup_rv_zen_asm_s5x4_mask.c b/kernels/zen/3/sup/s6x16/bli_gemmsup_rv_zen_asm_s5x4_mask.c index 55de26c884..9fafe653f4 100644 --- a/kernels/zen/3/sup/s6x16/bli_gemmsup_rv_zen_asm_s5x4_mask.c +++ b/kernels/zen/3/sup/s6x16/bli_gemmsup_rv_zen_asm_s5x4_mask.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -13,7 +13,7 @@ notice, this list of conditions and the following disclaimer. - Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the - documentation and/or other materia provided with the distribution. + documentation and/or other materials provided with the distribution. - Neither the name(s) of the copyright holder(s) nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. diff --git a/kernels/zen/3/sup/s6x16/bli_gemmsup_rv_zen_asm_s5x8_mask.c b/kernels/zen/3/sup/s6x16/bli_gemmsup_rv_zen_asm_s5x8_mask.c index 74c1c51989..bdcd372c4b 100644 --- a/kernels/zen/3/sup/s6x16/bli_gemmsup_rv_zen_asm_s5x8_mask.c +++ b/kernels/zen/3/sup/s6x16/bli_gemmsup_rv_zen_asm_s5x8_mask.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -13,7 +13,7 @@ notice, this list of conditions and the following disclaimer. - Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the - documentation and/or other materia provided with the distribution. + documentation and/or other materials provided with the distribution. - Neither the name(s) of the copyright holder(s) nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h index 45817f08be..cec27dffb1 100644 --- a/kernels/zen/bli_kernels_zen.h +++ b/kernels/zen/bli_kernels_zen.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -38,9 +38,15 @@ // -- level-1v -- +// amaxv (intrinsics) +ADDV_KER_PROT( float, s, addv_zen_int ) +ADDV_KER_PROT( double, d, addv_zen_int ) +ADDV_KER_PROT( scomplex, c, addv_zen_int ) +ADDV_KER_PROT( dcomplex, z, addv_zen_int ) + // amaxv (intrinsics) AMAXV_KER_PROT( float, s, amaxv_zen_int ) -AMAXV_KER_PROT( double, d, amaxv_zen_int ) +BLIS_EXPORT_BLIS AMAXV_KER_PROT( double, d, amaxv_zen_int ) // axpbyv (intrinsics) AXPBYV_KER_PROT( float, s, axpbyv_zen_int ) @@ -58,7 +64,7 @@ AXPYV_KER_PROT( double, d, axpyv_zen_int ) // axpyv (intrinsics unrolled x10) AXPYV_KER_PROT( float, s, axpyv_zen_int10 ) -AXPYV_KER_PROT( double, d, axpyv_zen_int10 ) +BLIS_EXPORT_BLIS AXPYV_KER_PROT( double, d, axpyv_zen_int10 ) AXPYV_KER_PROT( scomplex, c, axpyv_zen_int5 ) AXPYV_KER_PROT( dcomplex, z, axpyv_zen_int5 ) @@ -81,28 +87,35 @@ DOTXV_KER_PROT( scomplex, c, dotxv_zen_int ) // scalv (intrinsics) SCALV_KER_PROT( float, s, scalv_zen_int ) SCALV_KER_PROT( double, d, scalv_zen_int ) +SCALV_KER_PROT( scomplex, c, scalv_zen_int ) SCALV_KER_PROT( dcomplex, z, scalv_zen_int ) // scalv (intrinsics unrolled x10) SCALV_KER_PROT( float, s, scalv_zen_int10 ) -SCALV_KER_PROT( double, d, scalv_zen_int10 ) +BLIS_EXPORT_BLIS SCALV_KER_PROT( double, d, scalv_zen_int10 ) SCALV_KER_PROT( dcomplex, z, dscalv_zen_int10 ) // swapv (intrinsics) SWAPV_KER_PROT(float, s, swapv_zen_int8 ) -SWAPV_KER_PROT(double, d, swapv_zen_int8 ) +BLIS_EXPORT_BLIS SWAPV_KER_PROT(double, d, swapv_zen_int8 ) // copyv (intrinsics) COPYV_KER_PROT( float, s, copyv_zen_int ) COPYV_KER_PROT( double, d, copyv_zen_int ) +COPYV_KER_PROT( scomplex, c, copyv_zen_int ) COPYV_KER_PROT( dcomplex, z, copyv_zen_int ) // scal2v (intrinsics) +SCAL2V_KER_PROT(float, s, scal2v_zen_int) +SCAL2V_KER_PROT(double, d, scal2v_zen_int) +SCAL2V_KER_PROT(scomplex, c, scal2v_zen_int) SCAL2V_KER_PROT(dcomplex, z, scal2v_zen_int) // setv (intrinsics) -SETV_KER_PROT(float, s, setv_zen_int) -SETV_KER_PROT(double, d, setv_zen_int) +SETV_KER_PROT( float, s, setv_zen_int) +SETV_KER_PROT( double, d, setv_zen_int) +SETV_KER_PROT( scomplex, c, setv_zen_int) +SETV_KER_PROT( dcomplex, z, setv_zen_int) // -- level-1f -- @@ -357,7 +370,7 @@ err_t bli_dgemm_8x6_avx2_k1_nn double* c, const inc_t ldc ); -void bli_zgemm_4x4_avx2_k1_nn +err_t bli_zgemm_4x4_avx2_k1_nn ( dim_t m, dim_t n, @@ -473,3 +486,16 @@ GEMM_UKR_PROT( dcomplex, z, gemm_zen_asm_2x6) GEMMTRSM_UKR_PROT( dcomplex, z, gemmtrsm_l_zen_asm_2x6) GEMMTRSM_UKR_PROT( dcomplex, z, gemmtrsm_u_zen_asm_2x6) + +void bli_dgemv_zen_ref + ( + trans_t transa, + dim_t m, + dim_t b_n, + double* restrict alpha, + double* restrict a, inc_t inca, inc_t lda, + double* restrict x, inc_t incx, + double* restrict beta, + double* restrict y, inc_t incy, + cntx_t* restrict cntx + ); diff --git a/kernels/zen/lpgemm/f32f32f32/lpgemm_fringe_f32_avx2.c b/kernels/zen/lpgemm/f32f32f32/lpgemm_fringe_f32_avx2.c index ae0862d6a7..9f1d4fe3ac 100644 --- a/kernels/zen/lpgemm/f32f32f32/lpgemm_fringe_f32_avx2.c +++ b/kernels/zen/lpgemm/f32f32f32/lpgemm_fringe_f32_avx2.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -49,7 +49,11 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_5x16) &&POST_OPS_RELU_SCALE_5x16F, &&POST_OPS_GELU_TANH_5x16F, &&POST_OPS_GELU_ERF_5x16F, - &&POST_OPS_CLIP_5x16F + &&POST_OPS_CLIP_5x16F, + NULL, // Virtual node for downscale, else segfault + &&POST_OPS_MATRIX_ADD_5x16F, + &&POST_OPS_SWISH_5x16F, + &&POST_OPS_MATRIX_MUL_5x16F }; // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. @@ -415,6 +419,89 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_5x16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_MATRIX_ADD_5x16F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15] + F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,0,4,5); + + // c[1:0-15] + F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,1,6,7); + + // c[2:0-15] + F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,2,8,9); + + // c[3:0-15] + F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,3,10,11); + + // c[4:0-15] + F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,4,12,13); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_5x16F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15] + F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,0,4,5); + + // c[1:0-15] + F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,1,6,7); + + // c[2:0-15] + F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,2,8,9); + + // c[3:0-15] + F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,3,10,11); + + // c[4:0-15] + F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,4,12,13); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_5x16F: + { + ymm0 = + _mm256_broadcast_ss( ( float* )post_ops_list_temp->op_args2 ); + __m256 z, dn; + __m256i ex_out; + + // c[0,0-7] + SWISH_F32_AVX2_DEF(ymm4, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + // c[0,8-15] + SWISH_F32_AVX2_DEF(ymm5, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + // c[1,0-7] + SWISH_F32_AVX2_DEF(ymm6, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + // c[1,8-15] + SWISH_F32_AVX2_DEF(ymm7, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + // c[2,0-7] + SWISH_F32_AVX2_DEF(ymm8, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + // c[2,8-15] + SWISH_F32_AVX2_DEF(ymm9, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + // c[3,0-7] + SWISH_F32_AVX2_DEF(ymm10, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + // c[3,8-15] + SWISH_F32_AVX2_DEF(ymm11, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + // c[4,0-7] + SWISH_F32_AVX2_DEF(ymm12, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + // c[4,8-15] + SWISH_F32_AVX2_DEF(ymm13, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_5x16F_DISABLE: ; @@ -444,7 +531,11 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_4x16) &&POST_OPS_RELU_SCALE_4x16F, &&POST_OPS_GELU_TANH_4x16F, &&POST_OPS_GELU_ERF_4x16F, - &&POST_OPS_CLIP_4x16F + &&POST_OPS_CLIP_4x16F, + NULL, // Virtual node for downscale, else segfault + &&POST_OPS_MATRIX_ADD_4x16F, + &&POST_OPS_SWISH_4x16F, + &&POST_OPS_MATRIX_MUL_4x16F }; // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. @@ -754,6 +845,77 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_4x16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_MATRIX_ADD_4x16F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15] + F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,0,4,5); + + // c[1:0-15] + F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,1,6,7); + + // c[2:0-15] + F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,2,8,9); + + // c[3:0-15] + F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,3,10,11); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_4x16F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15] + F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,0,4,5); + + // c[1:0-15] + F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,1,6,7); + + // c[2:0-15] + F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,2,8,9); + + // c[3:0-15] + F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,3,10,11); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_4x16F: + { + ymm0 = + _mm256_broadcast_ss( ( float* )post_ops_list_temp->op_args2 ); + __m256 z, dn; + __m256i ex_out; + + // c[0,0-7] + SWISH_F32_AVX2_DEF(ymm4, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + // c[0,8-15] + SWISH_F32_AVX2_DEF(ymm5, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + // c[1,0-7] + SWISH_F32_AVX2_DEF(ymm6, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + // c[1,8-15] + SWISH_F32_AVX2_DEF(ymm7, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + // c[2,0-7] + SWISH_F32_AVX2_DEF(ymm8, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + // c[2,8-15] + SWISH_F32_AVX2_DEF(ymm9, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + // c[3,0-7] + SWISH_F32_AVX2_DEF(ymm10, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + // c[3,8-15] + SWISH_F32_AVX2_DEF(ymm11, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_4x16F_DISABLE: ; @@ -780,7 +942,11 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_3x16) &&POST_OPS_RELU_SCALE_3x16F, &&POST_OPS_GELU_TANH_3x16F, &&POST_OPS_GELU_ERF_3x16F, - &&POST_OPS_CLIP_3x16F + &&POST_OPS_CLIP_3x16F, + NULL, // Virtual node for downscale, else segfault + &&POST_OPS_MATRIX_ADD_3x16F, + &&POST_OPS_SWISH_3x16F, + &&POST_OPS_MATRIX_MUL_3x16F }; // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. @@ -1039,6 +1205,65 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_3x16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_MATRIX_ADD_3x16F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15] + F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,0,4,5); + + // c[1:0-15] + F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,1,6,7); + + // c[2:0-15] + F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,2,8,9); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_3x16F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15] + F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,0,4,5); + + // c[1:0-15] + F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,1,6,7); + + // c[2:0-15] + F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,2,8,9); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_3x16F: + { + ymm0 = + _mm256_broadcast_ss( ( float* )post_ops_list_temp->op_args2 ); + __m256 z, dn; + __m256i ex_out; + + // c[0,0-7] + SWISH_F32_AVX2_DEF(ymm4, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + // c[0,8-15] + SWISH_F32_AVX2_DEF(ymm5, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + // c[1,0-7] + SWISH_F32_AVX2_DEF(ymm6, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + // c[1,8-15] + SWISH_F32_AVX2_DEF(ymm7, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + // c[2,0-7] + SWISH_F32_AVX2_DEF(ymm8, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + // c[2,8-15] + SWISH_F32_AVX2_DEF(ymm9, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_3x16F_DISABLE: ; @@ -1062,7 +1287,11 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_2x16) &&POST_OPS_RELU_SCALE_2x16F, &&POST_OPS_GELU_TANH_2x16F, &&POST_OPS_GELU_ERF_2x16F, - &&POST_OPS_CLIP_2x16F + &&POST_OPS_CLIP_2x16F, + NULL, // Virtual node for downscale, else segfault + &&POST_OPS_MATRIX_ADD_2x16F, + &&POST_OPS_SWISH_2x16F, + &&POST_OPS_MATRIX_MUL_2x16F }; // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. @@ -1265,6 +1494,53 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_2x16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_MATRIX_ADD_2x16F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15] + F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,0,4,5); + + // c[1:0-15] + F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,1,6,7); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_2x16F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15] + F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,0,4,5); + + // c[1:0-15] + F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,1,6,7); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_2x16F: + { + ymm0 = + _mm256_broadcast_ss( ( float* )post_ops_list_temp->op_args2 ); + __m256 z, dn; + __m256i ex_out; + + // c[0,0-7] + SWISH_F32_AVX2_DEF(ymm4, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + // c[0,8-15] + SWISH_F32_AVX2_DEF(ymm5, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + // c[1,0-7] + SWISH_F32_AVX2_DEF(ymm6, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + // c[1,8-15] + SWISH_F32_AVX2_DEF(ymm7, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_2x16F_DISABLE: ; @@ -1285,7 +1561,11 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_1x16) &&POST_OPS_RELU_SCALE_1x16F, &&POST_OPS_GELU_TANH_1x16F, &&POST_OPS_GELU_ERF_1x16F, - &&POST_OPS_CLIP_1x16F + &&POST_OPS_CLIP_1x16F, + NULL, // Virtual node for downscale, else segfault + &&POST_OPS_MATRIX_ADD_1x16F, + &&POST_OPS_SWISH_1x16F, + &&POST_OPS_MATRIX_MUL_1x16F }; // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. @@ -1437,6 +1717,41 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_1x16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_MATRIX_ADD_1x16F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15] + F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,0,4,5); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_1x16F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15] + F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,0,4,5); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_1x16F: + { + ymm0 = + _mm256_broadcast_ss( ( float* )post_ops_list_temp->op_args2 ); + __m256 z, dn; + __m256i ex_out; + + // c[0,0-7] + SWISH_F32_AVX2_DEF(ymm4, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + // c[0,8-15] + SWISH_F32_AVX2_DEF(ymm5, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_1x16F_DISABLE: ; @@ -1454,7 +1769,11 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_5x8) &&POST_OPS_RELU_SCALE_5x8F, &&POST_OPS_GELU_TANH_5x8F, &&POST_OPS_GELU_ERF_5x8F, - &&POST_OPS_CLIP_5x8F + &&POST_OPS_CLIP_5x8F, + NULL, // Virtual node for downscale, else segfault + &&POST_OPS_MATRIX_ADD_5x8F, + &&POST_OPS_SWISH_5x8F, + &&POST_OPS_MATRIX_MUL_5x8F }; // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. @@ -1699,6 +2018,74 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_5x8) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_MATRIX_ADD_5x8F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-7] + F32_F32_MATRIX_ADD_1COL(ymm1,0,4); + + // c[1:0-7] + F32_F32_MATRIX_ADD_1COL(ymm1,1,6); + + // c[2:0-7] + F32_F32_MATRIX_ADD_1COL(ymm1,2,8); + + // c[3:0-7] + F32_F32_MATRIX_ADD_1COL(ymm1,3,10); + + // c[4:0-7] + F32_F32_MATRIX_ADD_1COL(ymm1,4,12); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_5x8F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-7] + F32_F32_MATRIX_MUL_1COL(ymm1,0,4); + + // c[1:0-7] + F32_F32_MATRIX_MUL_1COL(ymm1,1,6); + + // c[2:0-7] + F32_F32_MATRIX_MUL_1COL(ymm1,2,8); + + // c[3:0-7] + F32_F32_MATRIX_MUL_1COL(ymm1,3,10); + + // c[4:0-7] + F32_F32_MATRIX_MUL_1COL(ymm1,4,12); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_5x8F: + { + ymm0 = + _mm256_broadcast_ss( ( float* )post_ops_list_temp->op_args2 ); + __m256 z, dn; + __m256i ex_out; + + // c[0,0-7] + SWISH_F32_AVX2_DEF(ymm4, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + // c[1,0-7] + SWISH_F32_AVX2_DEF(ymm6, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + // c[2,0-7] + SWISH_F32_AVX2_DEF(ymm8, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + // c[3,0-7] + SWISH_F32_AVX2_DEF(ymm10, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + // c[4,0-7] + SWISH_F32_AVX2_DEF(ymm12, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_5x8F_DISABLE: ; @@ -1723,7 +2110,11 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_4x8) &&POST_OPS_RELU_SCALE_4x8F, &&POST_OPS_GELU_TANH_4x8F, &&POST_OPS_GELU_ERF_4x8F, - &&POST_OPS_CLIP_4x8F + &&POST_OPS_CLIP_4x8F, + NULL, // Virtual node for downscale, else segfault + &&POST_OPS_MATRIX_ADD_4x8F, + &&POST_OPS_SWISH_4x8F, + &&POST_OPS_MATRIX_MUL_4x8F }; // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. @@ -1935,6 +2326,65 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_4x8) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_MATRIX_ADD_4x8F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-7] + F32_F32_MATRIX_ADD_1COL(ymm1,0,4); + + // c[1:0-7] + F32_F32_MATRIX_ADD_1COL(ymm1,1,6); + + // c[2:0-7] + F32_F32_MATRIX_ADD_1COL(ymm1,2,8); + + // c[3:0-7] + F32_F32_MATRIX_ADD_1COL(ymm1,3,10); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_4x8F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-7] + F32_F32_MATRIX_MUL_1COL(ymm1,0,4); + + // c[1:0-7] + F32_F32_MATRIX_MUL_1COL(ymm1,1,6); + + // c[2:0-7] + F32_F32_MATRIX_MUL_1COL(ymm1,2,8); + + // c[3:0-7] + F32_F32_MATRIX_MUL_1COL(ymm1,3,10); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_4x8F: + { + ymm0 = + _mm256_broadcast_ss( ( float* )post_ops_list_temp->op_args2 ); + __m256 z, dn; + __m256i ex_out; + + // c[0,0-7] + SWISH_F32_AVX2_DEF(ymm4, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + // c[1,0-7] + SWISH_F32_AVX2_DEF(ymm6, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + // c[2,0-7] + SWISH_F32_AVX2_DEF(ymm8, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + // c[3,0-7] + SWISH_F32_AVX2_DEF(ymm10, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_4x8F_DISABLE: ; @@ -1957,7 +2407,11 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_3x8) &&POST_OPS_RELU_SCALE_3x8F, &&POST_OPS_GELU_TANH_3x8F, &&POST_OPS_GELU_ERF_3x8F, - &&POST_OPS_CLIP_3x8F + &&POST_OPS_CLIP_3x8F, + NULL, // Virtual node for downscale, else segfault + &&POST_OPS_MATRIX_ADD_3x8F, + &&POST_OPS_SWISH_3x8F, + &&POST_OPS_MATRIX_MUL_3x8F }; // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. @@ -2140,6 +2594,56 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_3x8) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_MATRIX_ADD_3x8F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-7] + F32_F32_MATRIX_ADD_1COL(ymm1,0,4); + + // c[1:0-7] + F32_F32_MATRIX_ADD_1COL(ymm1,1,6); + + // c[2:0-7] + F32_F32_MATRIX_ADD_1COL(ymm1,2,8); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_3x8F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-7] + F32_F32_MATRIX_MUL_1COL(ymm1,0,4); + + // c[1:0-7] + F32_F32_MATRIX_MUL_1COL(ymm1,1,6); + + // c[2:0-7] + F32_F32_MATRIX_MUL_1COL(ymm1,2,8); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_3x8F: + { + ymm0 = + _mm256_broadcast_ss( ( float* )post_ops_list_temp->op_args2 ); + __m256 z, dn; + __m256i ex_out; + + // c[0,0-7] + SWISH_F32_AVX2_DEF(ymm4, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + // c[1,0-7] + SWISH_F32_AVX2_DEF(ymm6, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + // c[2,0-7] + SWISH_F32_AVX2_DEF(ymm8, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_3x8F_DISABLE: ; @@ -2160,7 +2664,11 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_2x8) &&POST_OPS_RELU_SCALE_2x8F, &&POST_OPS_GELU_TANH_2x8F, &&POST_OPS_GELU_ERF_2x8F, - &&POST_OPS_CLIP_2x8F + &&POST_OPS_CLIP_2x8F, + NULL, // Virtual node for downscale, else segfault + &&POST_OPS_MATRIX_ADD_2x8F, + &&POST_OPS_SWISH_2x8F, + &&POST_OPS_MATRIX_MUL_2x8F }; // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. @@ -2315,6 +2823,47 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_2x8) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_MATRIX_ADD_2x8F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-7] + F32_F32_MATRIX_ADD_1COL(ymm1,0,4); + + // c[1:0-7] + F32_F32_MATRIX_ADD_1COL(ymm1,1,6); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_2x8F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-7] + F32_F32_MATRIX_MUL_1COL(ymm1,0,4); + + // c[1:0-7] + F32_F32_MATRIX_MUL_1COL(ymm1,1,6); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_2x8F: + { + ymm0 = + _mm256_broadcast_ss( ( float* )post_ops_list_temp->op_args2 ); + __m256 z, dn; + __m256i ex_out; + + // c[0,0-7] + SWISH_F32_AVX2_DEF(ymm4, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + // c[1,0-7] + SWISH_F32_AVX2_DEF(ymm6, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_2x8F_DISABLE: ; @@ -2333,7 +2882,11 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_1x8) &&POST_OPS_RELU_SCALE_1x8F, &&POST_OPS_GELU_TANH_1x8F, &&POST_OPS_GELU_ERF_1x8F, - &&POST_OPS_CLIP_1x8F + &&POST_OPS_CLIP_1x8F, + NULL, // Virtual node for downscale, else segfault + &&POST_OPS_MATRIX_ADD_1x8F, + &&POST_OPS_SWISH_1x8F, + &&POST_OPS_MATRIX_MUL_1x8F }; // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. @@ -2455,6 +3008,38 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_1x8) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_MATRIX_ADD_1x8F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-7] + F32_F32_MATRIX_ADD_1COL(ymm1,0,4); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_1x8F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-7] + F32_F32_MATRIX_MUL_1COL(ymm1,0,4); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_1x8F: + { + ymm0 = + _mm256_broadcast_ss( ( float* )post_ops_list_temp->op_args2 ); + __m256 z, dn; + __m256i ex_out; + + // c[0,0-7] + SWISH_F32_AVX2_DEF(ymm4, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_1x8F_DISABLE: ; @@ -2471,7 +3056,11 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_5x4) &&POST_OPS_RELU_SCALE_5x4F, &&POST_OPS_GELU_TANH_5x4F, &&POST_OPS_GELU_ERF_5x4F, - &&POST_OPS_CLIP_5x4F + &&POST_OPS_CLIP_5x4F, + NULL, // Virtual node for downscale, else segfault + &&POST_OPS_MATRIX_ADD_5x4F, + &&POST_OPS_SWISH_5x4F, + &&POST_OPS_MATRIX_MUL_5x4F }; // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. @@ -2714,6 +3303,74 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_5x4) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_MATRIX_ADD_5x4F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-3] + F32_F32_MATRIX_ADD_1COL_XMM(xmm1,0,4); + + // c[1:0-3] + F32_F32_MATRIX_ADD_1COL_XMM(xmm1,1,5); + + // c[2:0-3] + F32_F32_MATRIX_ADD_1COL_XMM(xmm1,2,6); + + // c[3:0-3] + F32_F32_MATRIX_ADD_1COL_XMM(xmm1,3,7); + + // c[4:0-3] + F32_F32_MATRIX_ADD_1COL_XMM(xmm1,4,8); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_5x4F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-3] + F32_F32_MATRIX_MUL_1COL_XMM(xmm1,0,4); + + // c[1:0-3] + F32_F32_MATRIX_MUL_1COL_XMM(xmm1,1,5); + + // c[2:0-3] + F32_F32_MATRIX_MUL_1COL_XMM(xmm1,2,6); + + // c[3:0-3] + F32_F32_MATRIX_MUL_1COL_XMM(xmm1,3,7); + + // c[4:0-3] + F32_F32_MATRIX_MUL_1COL_XMM(xmm1,4,8); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_5x4F: + { + xmm0 = + _mm_broadcast_ss( ( float* )post_ops_list_temp->op_args2 ); + __m128 z, dn; + __m128i ex_out; + + // c[0,0-3] + SWISH_F32_SSE_DEF(xmm4, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) + + // c[1,0-3] + SWISH_F32_SSE_DEF(xmm5, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) + + // c[2,0-3] + SWISH_F32_SSE_DEF(xmm6, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) + + // c[3,0-3] + SWISH_F32_SSE_DEF(xmm7, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) + + // c[4,0-3] + SWISH_F32_SSE_DEF(xmm8, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_5x4F_DISABLE: ; @@ -2738,7 +3395,11 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_4x4) &&POST_OPS_RELU_SCALE_4x4F, &&POST_OPS_GELU_TANH_4x4F, &&POST_OPS_GELU_ERF_4x4F, - &&POST_OPS_CLIP_4x4F + &&POST_OPS_CLIP_4x4F, + NULL, // Virtual node for downscale, else segfault + &&POST_OPS_MATRIX_ADD_4x4F, + &&POST_OPS_SWISH_4x4F, + &&POST_OPS_MATRIX_MUL_4x4F }; // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. @@ -2922,32 +3583,91 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_4x4) // c[1,0-3] GELU_ERF_F32S_SSE(xmm5, xmm0, xmm1, xmm2) - // c[2,0-3] - GELU_ERF_F32S_SSE(xmm6, xmm0, xmm1, xmm2) + // c[2,0-3] + GELU_ERF_F32S_SSE(xmm6, xmm0, xmm1, xmm2) + + // c[3,0-3] + GELU_ERF_F32S_SSE(xmm7, xmm0, xmm1, xmm2) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_CLIP_4x4F: + { + xmm0 = _mm_set1_ps( *( float* )post_ops_list_temp->op_args2 ); + xmm1 = _mm_set1_ps( *( float* )post_ops_list_temp->op_args3 ); + + // c[0,0-3] + CLIP_F32S_SSE(xmm4, xmm0, xmm1) + + // c[1,0-3] + CLIP_F32S_SSE(xmm5, xmm0, xmm1) + + // c[2,0-3] + CLIP_F32S_SSE(xmm6, xmm0, xmm1) + + // c[3,0-3] + CLIP_F32S_SSE(xmm7, xmm0, xmm1) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_4x4F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-3] + F32_F32_MATRIX_ADD_1COL_XMM(xmm1,0,4); + + // c[1:0-3] + F32_F32_MATRIX_ADD_1COL_XMM(xmm1,1,5); + + // c[2:0-3] + F32_F32_MATRIX_ADD_1COL_XMM(xmm1,2,6); + + // c[3:0-3] + F32_F32_MATRIX_ADD_1COL_XMM(xmm1,3,7); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_4x4F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-3] + F32_F32_MATRIX_MUL_1COL_XMM(xmm1,0,4); + + // c[1:0-3] + F32_F32_MATRIX_MUL_1COL_XMM(xmm1,1,5); - // c[3,0-3] - GELU_ERF_F32S_SSE(xmm7, xmm0, xmm1, xmm2) + // c[2:0-3] + F32_F32_MATRIX_MUL_1COL_XMM(xmm1,2,6); - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_CLIP_4x4F: + // c[3:0-3] + F32_F32_MATRIX_MUL_1COL_XMM(xmm1,3,7); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_4x4F: { - xmm0 = _mm_set1_ps( *( float* )post_ops_list_temp->op_args2 ); - xmm1 = _mm_set1_ps( *( float* )post_ops_list_temp->op_args3 ); + xmm0 = + _mm_broadcast_ss( ( float* )post_ops_list_temp->op_args2 ); + __m128 z, dn; + __m128i ex_out; - // c[0,0-3] - CLIP_F32S_SSE(xmm4, xmm0, xmm1) + // c[0,0-3] + SWISH_F32_SSE_DEF(xmm4, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) - // c[1,0-3] - CLIP_F32S_SSE(xmm5, xmm0, xmm1) + // c[1,0-3] + SWISH_F32_SSE_DEF(xmm5, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) - // c[2,0-3] - CLIP_F32S_SSE(xmm6, xmm0, xmm1) + // c[2,0-3] + SWISH_F32_SSE_DEF(xmm6, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) - // c[3,0-3] - CLIP_F32S_SSE(xmm7, xmm0, xmm1) + // c[3,0-3] + SWISH_F32_SSE_DEF(xmm7, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_4x4F_DISABLE: ; @@ -2971,7 +3691,11 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_3x4) &&POST_OPS_RELU_SCALE_3x4F, &&POST_OPS_GELU_TANH_3x4F, &&POST_OPS_GELU_ERF_3x4F, - &&POST_OPS_CLIP_3x4F + &&POST_OPS_CLIP_3x4F, + NULL, // Virtual node for downscale, else segfault + &&POST_OPS_MATRIX_ADD_3x4F, + &&POST_OPS_SWISH_3x4F, + &&POST_OPS_MATRIX_MUL_3x4F }; // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. @@ -3151,6 +3875,56 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_3x4) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_MATRIX_ADD_3x4F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-3] + F32_F32_MATRIX_ADD_1COL_XMM(xmm1,0,4); + + // c[1:0-3] + F32_F32_MATRIX_ADD_1COL_XMM(xmm1,1,5); + + // c[2:0-3] + F32_F32_MATRIX_ADD_1COL_XMM(xmm1,2,6); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_3x4F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-3] + F32_F32_MATRIX_MUL_1COL_XMM(xmm1,0,4); + + // c[1:0-3] + F32_F32_MATRIX_MUL_1COL_XMM(xmm1,1,5); + + // c[2:0-3] + F32_F32_MATRIX_MUL_1COL_XMM(xmm1,2,6); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_3x4F: + { + xmm0 = + _mm_broadcast_ss( ( float* )post_ops_list_temp->op_args2 ); + __m128 z, dn; + __m128i ex_out; + + // c[0,0-3] + SWISH_F32_SSE_DEF(xmm4, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) + + // c[1,0-3] + SWISH_F32_SSE_DEF(xmm5, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) + + // c[2,0-3] + SWISH_F32_SSE_DEF(xmm6, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_3x4F_DISABLE: ; @@ -3171,7 +3945,11 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_2x4) &&POST_OPS_RELU_SCALE_2x4F, &&POST_OPS_GELU_TANH_2x4F, &&POST_OPS_GELU_ERF_2x4F, - &&POST_OPS_CLIP_2x4F + &&POST_OPS_CLIP_2x4F, + NULL, // Virtual node for downscale, else segfault + &&POST_OPS_MATRIX_ADD_2x4F, + &&POST_OPS_SWISH_2x4F, + &&POST_OPS_MATRIX_MUL_2x4F }; // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. @@ -3325,6 +4103,47 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_2x4) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_MATRIX_ADD_2x4F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-3] + F32_F32_MATRIX_ADD_1COL_XMM(xmm1,0,4); + + // c[1:0-3] + F32_F32_MATRIX_ADD_1COL_XMM(xmm1,1,5); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_2x4F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-3] + F32_F32_MATRIX_MUL_1COL_XMM(xmm1,0,4); + + // c[1:0-3] + F32_F32_MATRIX_MUL_1COL_XMM(xmm1,1,5); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_2x4F: + { + xmm0 = + _mm_broadcast_ss( ( float* )post_ops_list_temp->op_args2 ); + __m128 z, dn; + __m128i ex_out; + + // c[0,0-3] + SWISH_F32_SSE_DEF(xmm4, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) + + // c[1,0-3] + SWISH_F32_SSE_DEF(xmm5, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_2x4F_DISABLE: ; @@ -3343,7 +4162,11 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_1x4) &&POST_OPS_RELU_SCALE_1x4F, &&POST_OPS_GELU_TANH_1x4F, &&POST_OPS_GELU_ERF_1x4F, - &&POST_OPS_CLIP_1x4F + &&POST_OPS_CLIP_1x4F, + NULL, // Virtual node for downscale, else segfault + &&POST_OPS_MATRIX_ADD_1x4F, + &&POST_OPS_SWISH_1x4F, + &&POST_OPS_MATRIX_MUL_1x4F }; // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. @@ -3462,6 +4285,38 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_1x4) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_MATRIX_ADD_1x4F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-3] + F32_F32_MATRIX_ADD_1COL_XMM(xmm1,0,4); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_1x4F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-3] + F32_F32_MATRIX_MUL_1COL_XMM(xmm1,0,4); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_1x4F: + { + xmm0 = + _mm_broadcast_ss( ( float* )post_ops_list_temp->op_args2 ); + __m128 z, dn; + __m128i ex_out; + + // c[0,0-3] + SWISH_F32_SSE_DEF(xmm4, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_1x4F_DISABLE: ; @@ -3478,7 +4333,11 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_5x2) &&POST_OPS_RELU_SCALE_5x2F, &&POST_OPS_GELU_TANH_5x2F, &&POST_OPS_GELU_ERF_5x2F, - &&POST_OPS_CLIP_5x2F + &&POST_OPS_CLIP_5x2F, + NULL, // Virtual node for downscale, else segfault + &&POST_OPS_MATRIX_ADD_5x2F, + &&POST_OPS_SWISH_5x2F, + &&POST_OPS_MATRIX_MUL_5x2F }; // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. @@ -3559,8 +4418,9 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_5x2) if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) { - xmm0 = _mm_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 8 ) ); + xmm0 = (__m128)_mm_load_sd((const double *) + ((float * )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + (0 * 8))); // c[0,0-3] xmm4 = _mm_add_ps( xmm4, xmm0 ); @@ -3721,6 +4581,74 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_5x2) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_MATRIX_ADD_5x2F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-1] + F32_F32_MATRIX_ADD_1COL_XMM_2ELE(xmm1,0,4); + + // c[1:0-1] + F32_F32_MATRIX_ADD_1COL_XMM_2ELE(xmm1,1,5); + + // c[2:0-1] + F32_F32_MATRIX_ADD_1COL_XMM_2ELE(xmm1,2,6); + + // c[3:0-1] + F32_F32_MATRIX_ADD_1COL_XMM_2ELE(xmm1,3,7); + + // c[4:0-1] + F32_F32_MATRIX_ADD_1COL_XMM_2ELE(xmm1,4,8); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_5x2F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-1] + F32_F32_MATRIX_MUL_1COL_XMM_2ELE(xmm1,0,4); + + // c[1:0-1] + F32_F32_MATRIX_MUL_1COL_XMM_2ELE(xmm1,1,5); + + // c[2:0-1] + F32_F32_MATRIX_MUL_1COL_XMM_2ELE(xmm1,2,6); + + // c[3:0-1] + F32_F32_MATRIX_MUL_1COL_XMM_2ELE(xmm1,3,7); + + // c[4:0-1] + F32_F32_MATRIX_MUL_1COL_XMM_2ELE(xmm1,4,8); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_5x2F: + { + xmm0 = + _mm_broadcast_ss( ( float* )post_ops_list_temp->op_args2 ); + __m128 z, dn; + __m128i ex_out; + + // c[0,0-1] + SWISH_F32_SSE_DEF(xmm4, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) + + // c[1,0-1] + SWISH_F32_SSE_DEF(xmm5, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) + + // c[2,0-1] + SWISH_F32_SSE_DEF(xmm6, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) + + // c[3,0-1] + SWISH_F32_SSE_DEF(xmm7, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) + + // c[4,0-1] + SWISH_F32_SSE_DEF(xmm8, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_5x2F_DISABLE: ; @@ -3745,7 +4673,11 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_4x2) &&POST_OPS_RELU_SCALE_4x2F, &&POST_OPS_GELU_TANH_4x2F, &&POST_OPS_GELU_ERF_4x2F, - &&POST_OPS_CLIP_4x2F + &&POST_OPS_CLIP_4x2F, + NULL, // Virtual node for downscale, else segfault + &&POST_OPS_MATRIX_ADD_4x2F, + &&POST_OPS_SWISH_4x2F, + &&POST_OPS_MATRIX_MUL_4x2F }; // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. @@ -3818,8 +4750,9 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_4x2) if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) { - xmm0 = _mm_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 8 ) ); + xmm0 = (__m128)_mm_load_sd((const double *) + ((float *)post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + (0 * 8))); // c[0,0-3] xmm4 = _mm_add_ps( xmm4, xmm0 ); @@ -3956,6 +4889,65 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_4x2) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_MATRIX_ADD_4x2F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-1] + F32_F32_MATRIX_ADD_1COL_XMM_2ELE(xmm1,0,4); + + // c[1:0-1] + F32_F32_MATRIX_ADD_1COL_XMM_2ELE(xmm1,1,5); + + // c[2:0-1] + F32_F32_MATRIX_ADD_1COL_XMM_2ELE(xmm1,2,6); + + // c[3:0-1] + F32_F32_MATRIX_ADD_1COL_XMM_2ELE(xmm1,3,7); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_4x2F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-1] + F32_F32_MATRIX_MUL_1COL_XMM_2ELE(xmm1,0,4); + + // c[1:0-1] + F32_F32_MATRIX_MUL_1COL_XMM_2ELE(xmm1,1,5); + + // c[2:0-1] + F32_F32_MATRIX_MUL_1COL_XMM_2ELE(xmm1,2,6); + + // c[3:0-1] + F32_F32_MATRIX_MUL_1COL_XMM_2ELE(xmm1,3,7); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_4x2F: + { + xmm0 = + _mm_broadcast_ss( ( float* )post_ops_list_temp->op_args2 ); + __m128 z, dn; + __m128i ex_out; + + // c[0,0-1] + SWISH_F32_SSE_DEF(xmm4, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) + + // c[1,0-1] + SWISH_F32_SSE_DEF(xmm5, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) + + // c[2,0-1] + SWISH_F32_SSE_DEF(xmm6, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) + + // c[3,0-1] + SWISH_F32_SSE_DEF(xmm7, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_4x2F_DISABLE: ; @@ -3978,7 +4970,11 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_3x2) &&POST_OPS_RELU_SCALE_3x2F, &&POST_OPS_GELU_TANH_3x2F, &&POST_OPS_GELU_ERF_3x2F, - &&POST_OPS_CLIP_3x2F + &&POST_OPS_CLIP_3x2F, + NULL, // Virtual node for downscale, else segfault + &&POST_OPS_MATRIX_ADD_3x2F, + &&POST_OPS_SWISH_3x2F, + &&POST_OPS_MATRIX_MUL_3x2F }; // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. @@ -4043,8 +5039,9 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_3x2) if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) { - xmm0 = _mm_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 8 ) ); + xmm0 = (__m128)_mm_load_sd( (const double *) + ((float *) post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + (0 * 8))); // c[0,0-3] xmm4 = _mm_add_ps( xmm4, xmm0 ); @@ -4158,6 +5155,56 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_3x2) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_MATRIX_ADD_3x2F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-1] + F32_F32_MATRIX_ADD_1COL_XMM_2ELE(xmm1,0,4); + + // c[1:0-1] + F32_F32_MATRIX_ADD_1COL_XMM_2ELE(xmm1,1,5); + + // c[2:0-1] + F32_F32_MATRIX_ADD_1COL_XMM_2ELE(xmm1,2,6); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_3x2F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-1] + F32_F32_MATRIX_MUL_1COL_XMM_2ELE(xmm1,0,4); + + // c[1:0-1] + F32_F32_MATRIX_MUL_1COL_XMM_2ELE(xmm1,1,5); + + // c[2:0-1] + F32_F32_MATRIX_MUL_1COL_XMM_2ELE(xmm1,2,6); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_3x2F: + { + xmm0 = + _mm_broadcast_ss( ( float* )post_ops_list_temp->op_args2 ); + __m128 z, dn; + __m128i ex_out; + + // c[0,0-1] + SWISH_F32_SSE_DEF(xmm4, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) + + // c[1,0-1] + SWISH_F32_SSE_DEF(xmm5, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) + + // c[2,0-1] + SWISH_F32_SSE_DEF(xmm6, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_3x2F_DISABLE: ; @@ -4178,7 +5225,11 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_2x2) &&POST_OPS_RELU_SCALE_2x2F, &&POST_OPS_GELU_TANH_2x2F, &&POST_OPS_GELU_ERF_2x2F, - &&POST_OPS_CLIP_2x2F + &&POST_OPS_CLIP_2x2F, + NULL, // Virtual node for downscale, else segfault + &&POST_OPS_MATRIX_ADD_2x2F, + &&POST_OPS_SWISH_2x2F, + &&POST_OPS_MATRIX_MUL_2x2F }; // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. @@ -4240,8 +5291,9 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_2x2) if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) { - xmm0 = _mm_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 8 ) ); + xmm0 = (__m128)_mm_load_sd((const double *) + ((float *)post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + (0 * 8))); // c[0,0-3] xmm4 = _mm_add_ps( xmm4, xmm0 ); @@ -4332,6 +5384,47 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_2x2) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_MATRIX_ADD_2x2F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-1] + F32_F32_MATRIX_ADD_1COL_XMM_2ELE(xmm1,0,4); + + // c[1:0-1] + F32_F32_MATRIX_ADD_1COL_XMM_2ELE(xmm1,1,5); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_2x2F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-1] + F32_F32_MATRIX_MUL_1COL_XMM_2ELE(xmm1,0,4); + + // c[1:0-1] + F32_F32_MATRIX_MUL_1COL_XMM_2ELE(xmm1,1,5); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_2x2F: + { + xmm0 = + _mm_broadcast_ss( ( float* )post_ops_list_temp->op_args2 ); + __m128 z, dn; + __m128i ex_out; + + // c[0,0-1] + SWISH_F32_SSE_DEF(xmm4, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) + + // c[1,0-1] + SWISH_F32_SSE_DEF(xmm5, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_2x2F_DISABLE: ; @@ -4350,7 +5443,11 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_1x2) &&POST_OPS_RELU_SCALE_1x2F, &&POST_OPS_GELU_TANH_1x2F, &&POST_OPS_GELU_ERF_1x2F, - &&POST_OPS_CLIP_1x2F + &&POST_OPS_CLIP_1x2F, + NULL, // Virtual node for downscale, else segfault + &&POST_OPS_MATRIX_ADD_1x2F, + &&POST_OPS_SWISH_1x2F, + &&POST_OPS_MATRIX_MUL_1x2F }; // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. @@ -4400,8 +5497,9 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_1x2) if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) { - xmm0 = _mm_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 8 ) ); + xmm0 = (__m128)_mm_load_sd((const double *) + ((float*)post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + (0 * 8))); // c[0,0-3] xmm4 = _mm_add_ps( xmm4, xmm0 ); @@ -4469,6 +5567,38 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_1x2) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_MATRIX_ADD_1x2F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-1] + F32_F32_MATRIX_ADD_1COL_XMM_2ELE(xmm1,0,4); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_1x2F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-1] + F32_F32_MATRIX_MUL_1COL_XMM_2ELE(xmm1,0,4); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_1x2F: + { + xmm0 = + _mm_broadcast_ss( ( float* )post_ops_list_temp->op_args2 ); + __m128 z, dn; + __m128i ex_out; + + // c[0,0-1] + SWISH_F32_SSE_DEF(xmm4, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_1x2F_DISABLE: ; @@ -4485,7 +5615,11 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_5x1) &&POST_OPS_RELU_SCALE_5x1F, &&POST_OPS_GELU_TANH_5x1F, &&POST_OPS_GELU_ERF_5x1F, - &&POST_OPS_CLIP_5x1F + &&POST_OPS_CLIP_5x1F, + NULL, // Virtual node for downscale, else segfault + &&POST_OPS_MATRIX_ADD_5x1F, + &&POST_OPS_SWISH_5x1F, + &&POST_OPS_MATRIX_MUL_5x1F }; // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. @@ -4566,7 +5700,7 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_5x1) if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) { - xmm0 = _mm_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + xmm0 = ( __m128 )_mm_load_ss( ( float* )post_ops_list_temp->op_args1 + post_ops_attr.post_op_c_j + ( 0 * 8 ) ); // c[0,0-3] @@ -4728,6 +5862,74 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_5x1) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_MATRIX_ADD_5x1F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-0] + F32_F32_MATRIX_ADD_1COL_XMM_1ELE(xmm1,0,4); + + // c[1:0-0] + F32_F32_MATRIX_ADD_1COL_XMM_1ELE(xmm1,1,5); + + // c[2:0-0] + F32_F32_MATRIX_ADD_1COL_XMM_1ELE(xmm1,2,6); + + // c[3:0-0] + F32_F32_MATRIX_ADD_1COL_XMM_1ELE(xmm1,3,7); + + // c[4:0-0] + F32_F32_MATRIX_ADD_1COL_XMM_1ELE(xmm1,4,8); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_5x1F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-0] + F32_F32_MATRIX_MUL_1COL_XMM_1ELE(xmm1,0,4); + + // c[1:0-0] + F32_F32_MATRIX_MUL_1COL_XMM_1ELE(xmm1,1,5); + + // c[2:0-0] + F32_F32_MATRIX_MUL_1COL_XMM_1ELE(xmm1,2,6); + + // c[3:0-0] + F32_F32_MATRIX_MUL_1COL_XMM_1ELE(xmm1,3,7); + + // c[4:0-0] + F32_F32_MATRIX_MUL_1COL_XMM_1ELE(xmm1,4,8); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_5x1F: + { + xmm0 = + _mm_broadcast_ss( ( float* )post_ops_list_temp->op_args2 ); + __m128 z, dn; + __m128i ex_out; + + // c[0,0-0] + SWISH_F32_SSE_DEF(xmm4, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) + + // c[1,0-0] + SWISH_F32_SSE_DEF(xmm5, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) + + // c[2,0-0] + SWISH_F32_SSE_DEF(xmm6, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) + + // c[3,0-0] + SWISH_F32_SSE_DEF(xmm7, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) + + // c[4,0-0] + SWISH_F32_SSE_DEF(xmm8, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_5x1F_DISABLE: ; @@ -4752,7 +5954,11 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_4x1) &&POST_OPS_RELU_SCALE_4x1F, &&POST_OPS_GELU_TANH_4x1F, &&POST_OPS_GELU_ERF_4x1F, - &&POST_OPS_CLIP_4x1F + &&POST_OPS_CLIP_4x1F, + NULL, // Virtual node for downscale, else segfault + &&POST_OPS_MATRIX_ADD_4x1F, + &&POST_OPS_SWISH_4x1F, + &&POST_OPS_MATRIX_MUL_4x1F }; // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. @@ -4825,7 +6031,7 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_4x1) if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) { - xmm0 = _mm_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + xmm0 = ( __m128 )_mm_load_ss( ( float* )post_ops_list_temp->op_args1 + post_ops_attr.post_op_c_j + ( 0 * 8 ) ); // c[0,0-3] @@ -4963,6 +6169,65 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_4x1) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_MATRIX_ADD_4x1F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-0] + F32_F32_MATRIX_ADD_1COL_XMM_1ELE(xmm1,0,4); + + // c[1:0-0] + F32_F32_MATRIX_ADD_1COL_XMM_1ELE(xmm1,1,5); + + // c[2:0-0] + F32_F32_MATRIX_ADD_1COL_XMM_1ELE(xmm1,2,6); + + // c[3:0-0] + F32_F32_MATRIX_ADD_1COL_XMM_1ELE(xmm1,3,7); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_4x1F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-0] + F32_F32_MATRIX_MUL_1COL_XMM_1ELE(xmm1,0,4); + + // c[1:0-0] + F32_F32_MATRIX_MUL_1COL_XMM_1ELE(xmm1,1,5); + + // c[2:0-0] + F32_F32_MATRIX_MUL_1COL_XMM_1ELE(xmm1,2,6); + + // c[3:0-0] + F32_F32_MATRIX_MUL_1COL_XMM_1ELE(xmm1,3,7); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_4x1F: + { + xmm0 = + _mm_broadcast_ss( ( float* )post_ops_list_temp->op_args2 ); + __m128 z, dn; + __m128i ex_out; + + // c[0,0-0] + SWISH_F32_SSE_DEF(xmm4, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) + + // c[1,0-0] + SWISH_F32_SSE_DEF(xmm5, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) + + // c[2,0-0] + SWISH_F32_SSE_DEF(xmm6, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) + + // c[3,0-0] + SWISH_F32_SSE_DEF(xmm7, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_4x1F_DISABLE: ; @@ -4985,7 +6250,11 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_3x1) &&POST_OPS_RELU_SCALE_3x1F, &&POST_OPS_GELU_TANH_3x1F, &&POST_OPS_GELU_ERF_3x1F, - &&POST_OPS_CLIP_3x1F + &&POST_OPS_CLIP_3x1F, + NULL, // Virtual node for downscale, else segfault + &&POST_OPS_MATRIX_ADD_3x1F, + &&POST_OPS_SWISH_3x1F, + &&POST_OPS_MATRIX_MUL_3x1F }; // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. @@ -5050,7 +6319,7 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_3x1) if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) { - xmm0 = _mm_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + xmm0 = ( __m128 )_mm_load_ss( ( float* )post_ops_list_temp->op_args1 + post_ops_attr.post_op_c_j + ( 0 * 8 ) ); // c[0,0-3] @@ -5165,6 +6434,56 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_3x1) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_MATRIX_ADD_3x1F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-0] + F32_F32_MATRIX_ADD_1COL_XMM_1ELE(xmm1,0,4); + + // c[1:0-0] + F32_F32_MATRIX_ADD_1COL_XMM_1ELE(xmm1,1,5); + + // c[2:0-0] + F32_F32_MATRIX_ADD_1COL_XMM_1ELE(xmm1,2,6); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_3x1F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-0] + F32_F32_MATRIX_MUL_1COL_XMM_1ELE(xmm1,0,4); + + // c[1:0-0] + F32_F32_MATRIX_MUL_1COL_XMM_1ELE(xmm1,1,5); + + // c[2:0-0] + F32_F32_MATRIX_MUL_1COL_XMM_1ELE(xmm1,2,6); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_3x1F: + { + xmm0 = + _mm_broadcast_ss( ( float* )post_ops_list_temp->op_args2 ); + __m128 z, dn; + __m128i ex_out; + + // c[0,0-0] + SWISH_F32_SSE_DEF(xmm4, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) + + // c[1,0-0] + SWISH_F32_SSE_DEF(xmm5, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) + + // c[2,0-0] + SWISH_F32_SSE_DEF(xmm6, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_3x1F_DISABLE: ; @@ -5185,7 +6504,11 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_2x1) &&POST_OPS_RELU_SCALE_2x1F, &&POST_OPS_GELU_TANH_2x1F, &&POST_OPS_GELU_ERF_2x1F, - &&POST_OPS_CLIP_2x1F + &&POST_OPS_CLIP_2x1F, + NULL, // Virtual node for downscale, else segfault + &&POST_OPS_MATRIX_ADD_2x1F, + &&POST_OPS_SWISH_2x1F, + &&POST_OPS_MATRIX_MUL_2x1F }; // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. @@ -5247,7 +6570,7 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_2x1) if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) { - xmm0 = _mm_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + xmm0 = ( __m128 )_mm_load_ss( ( float* )post_ops_list_temp->op_args1 + post_ops_attr.post_op_c_j + ( 0 * 8 ) ); // c[0,0-3] @@ -5339,6 +6662,47 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_2x1) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_MATRIX_ADD_2x1F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-0] + F32_F32_MATRIX_ADD_1COL_XMM_1ELE(xmm1,0,4); + + // c[1:0-0] + F32_F32_MATRIX_ADD_1COL_XMM_1ELE(xmm1,1,5); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_2x1F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-0] + F32_F32_MATRIX_MUL_1COL_XMM_1ELE(xmm1,0,4); + + // c[1:0-0] + F32_F32_MATRIX_MUL_1COL_XMM_1ELE(xmm1,1,5); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_2x1F: + { + xmm0 = + _mm_broadcast_ss( ( float* )post_ops_list_temp->op_args2 ); + __m128 z, dn; + __m128i ex_out; + + // c[0,0-0] + SWISH_F32_SSE_DEF(xmm4, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) + + // c[1,0-0] + SWISH_F32_SSE_DEF(xmm5, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_2x1F_DISABLE: ; @@ -5357,7 +6721,11 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_1x1) &&POST_OPS_RELU_SCALE_1x1F, &&POST_OPS_GELU_TANH_1x1F, &&POST_OPS_GELU_ERF_1x1F, - &&POST_OPS_CLIP_1x1F + &&POST_OPS_CLIP_1x1F, + NULL, // Virtual node for downscale, else segfault + &&POST_OPS_MATRIX_ADD_1x1F, + &&POST_OPS_SWISH_1x1F, + &&POST_OPS_MATRIX_MUL_1x1F }; // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. @@ -5407,7 +6775,7 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_1x1) if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) { - xmm0 = _mm_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + xmm0 = ( __m128 )_mm_load_ss( ( float* )post_ops_list_temp->op_args1 + post_ops_attr.post_op_c_j + ( 0 * 8 ) ); // c[0,0-3] @@ -5476,6 +6844,38 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_1x1) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_MATRIX_ADD_1x1F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-0] + F32_F32_MATRIX_ADD_1COL_XMM_1ELE(xmm1,0,4); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_1x1F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-0] + F32_F32_MATRIX_MUL_1COL_XMM_1ELE(xmm1,0,4); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_1x1F: + { + xmm0 = + _mm_broadcast_ss( ( float* )post_ops_list_temp->op_args2 ); + __m128 z, dn; + __m128i ex_out; + + // c[0,0-0] + SWISH_F32_SSE_DEF(xmm4, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_1x1F_DISABLE: ; diff --git a/kernels/zen/lpgemm/f32f32f32/lpgemm_kernel_macros_f32_avx2.h b/kernels/zen/lpgemm/f32f32f32/lpgemm_kernel_macros_f32_avx2.h index 8fbdd78a8b..d800b5ae45 100644 --- a/kernels/zen/lpgemm/f32f32f32/lpgemm_kernel_macros_f32_avx2.h +++ b/kernels/zen/lpgemm/f32f32f32/lpgemm_kernel_macros_f32_avx2.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -36,6 +36,7 @@ #define LPGEMM_F32_SGEMM_AVX2_KERN_MACROS_H #include "../gelu_avx2.h" +#include "../silu_avx2.h" #include "../math_utils_avx2.h" /* ReLU scale (Parametric ReLU): f(x) = x, when x > 0 and f(x) = a*x when x <= 0 */ @@ -110,12 +111,12 @@ /*Load C, Multiply with beta and add with A*B and store*/ #define F32_C_BNZ_8(cbuf,rs_c,ymm0,beta,ymm2) \ - ymm0 = _mm256_load_ps(cbuf); \ + ymm0 = _mm256_loadu_ps(cbuf); \ ymm2 = _mm256_fmadd_ps(ymm0, beta, ymm2); \ /*Load C, Multiply with beta and add with A*B and store*/ #define F32_C_BNZ_4(cbuf,rs_c,xmm0,beta,xmm2) \ - xmm0 = _mm_load_ps(cbuf); \ + xmm0 = _mm_loadu_ps(cbuf); \ xmm2 = _mm_fmadd_ps(xmm0, beta, xmm2); \ /*Load C, Multiply with beta and add with A*B and store*/ @@ -128,4 +129,124 @@ xmm0 = _mm_load_ss(cbuf); \ xmm2 = _mm_fmadd_ps(xmm0, beta, xmm2); \ +// Matrix Add post-ops helper macros +#define F32_MATRIX_ADD_1COL_XMM(scr0,m_ind,r_ind0) \ + xmm ## r_ind0 = _mm_add_ps( scr0, xmm ## r_ind0 ); \ + +#define F32_MATRIX_ADD_1COL_YMM(scr0,m_ind,r_ind0) \ + ymm ## r_ind0 = _mm256_add_ps( scr0, ymm ## r_ind0 ); \ + +#define F32_MATRIX_ADD_2COL_YMM(scr0,scr1,m_ind,r_ind0,r_ind1) \ + ymm ## r_ind0 = _mm256_add_ps( scr0, ymm ## r_ind0 ); \ + ymm ## r_ind1 = _mm256_add_ps( scr1, ymm ## r_ind1 ); \ + +#define F32_F32_MATRIX_ADD_LOAD_XMM_1ELE(scr,m_ind,n_ind) \ + scr = ( __m128 )_mm_load_ss \ + ( \ + matptr + ( ( post_ops_attr.post_op_c_i + m_ind ) * ldm ) + \ + post_ops_attr.post_op_c_j + ( n_ind * 2 ) \ + ); \ + +#define F32_F32_MATRIX_ADD_1COL_XMM_1ELE(scr0,m_ind,r_ind0) \ + F32_F32_MATRIX_ADD_LOAD_XMM_1ELE(scr0,m_ind,0); \ + F32_MATRIX_ADD_1COL_XMM(scr0,m_ind,r_ind0); \ + +#define F32_F32_MATRIX_ADD_LOAD_XMM_2ELE(scr,m_ind,n_ind) \ + scr = ( __m128 )_mm_load_sd \ + ( \ + (double*)(matptr + ( ( post_ops_attr.post_op_c_i + m_ind ) * ldm ) + \ + post_ops_attr.post_op_c_j + ( n_ind * 2 )) \ + ); \ + +#define F32_F32_MATRIX_ADD_1COL_XMM_2ELE(scr0,m_ind,r_ind0) \ + F32_F32_MATRIX_ADD_LOAD_XMM_2ELE(scr0,m_ind,0); \ + F32_MATRIX_ADD_1COL_XMM(scr0,m_ind,r_ind0); \ + +#define F32_F32_MATRIX_ADD_LOAD_XMM(scr,m_ind,n_ind) \ + scr = _mm_loadu_ps \ + ( \ + matptr + ( ( post_ops_attr.post_op_c_i + m_ind ) * ldm ) + \ + post_ops_attr.post_op_c_j + ( n_ind * 4 ) \ + ); \ + +#define F32_F32_MATRIX_ADD_1COL_XMM(scr0,m_ind,r_ind0) \ + F32_F32_MATRIX_ADD_LOAD_XMM(scr0,m_ind,0); \ + F32_MATRIX_ADD_1COL_XMM(scr0,m_ind,r_ind0); \ + +#define F32_F32_MATRIX_ADD_LOAD_YMM(scr,m_ind,n_ind) \ + scr = _mm256_loadu_ps \ + ( \ + matptr + ( ( post_ops_attr.post_op_c_i + m_ind ) * ldm ) + \ + post_ops_attr.post_op_c_j + ( n_ind * 8 ) \ + ); \ + +#define F32_F32_MATRIX_ADD_1COL(scr0,m_ind,r_ind0) \ + F32_F32_MATRIX_ADD_LOAD_YMM(scr0,m_ind,0); \ + F32_MATRIX_ADD_1COL_YMM(scr0,m_ind,r_ind0); \ + +#define F32_F32_MATRIX_ADD_2COL(scr0,scr1,m_ind,r_ind0,r_ind1) \ + F32_F32_MATRIX_ADD_LOAD_YMM(scr0,m_ind,0); \ + F32_F32_MATRIX_ADD_LOAD_YMM(scr1,m_ind,1); \ + F32_MATRIX_ADD_2COL_YMM(scr0,scr1,m_ind,r_ind0,r_ind1); \ + +// Matrix Mul post-ops helper macros +#define F32_MATRIX_MUL_1COL_XMM(scr0,m_ind,r_ind0) \ + xmm ## r_ind0 = _mm_mul_ps( scr0, xmm ## r_ind0 ); \ + +#define F32_MATRIX_MUL_1COL_YMM(scr0,m_ind,r_ind0) \ + ymm ## r_ind0 = _mm256_mul_ps( scr0, ymm ## r_ind0 ); \ + +#define F32_MATRIX_MUL_2COL_YMM(scr0,scr1,m_ind,r_ind0,r_ind1) \ + ymm ## r_ind0 = _mm256_mul_ps( scr0, ymm ## r_ind0 ); \ + ymm ## r_ind1 = _mm256_mul_ps( scr1, ymm ## r_ind1 ); \ + +#define F32_F32_MATRIX_MUL_LOAD_XMM_1ELE(scr,m_ind,n_ind) \ + scr = ( __m128 )_mm_load_ss \ + ( \ + matptr + ( ( post_ops_attr.post_op_c_i + m_ind ) * ldm ) + \ + post_ops_attr.post_op_c_j + ( n_ind * 2 ) \ + ); \ + +#define F32_F32_MATRIX_MUL_1COL_XMM_1ELE(scr0,m_ind,r_ind0) \ + F32_F32_MATRIX_MUL_LOAD_XMM_1ELE(scr0,m_ind,0); \ + F32_MATRIX_MUL_1COL_XMM(scr0,m_ind,r_ind0); \ + +#define F32_F32_MATRIX_MUL_LOAD_XMM_2ELE(scr,m_ind,n_ind) \ + scr = ( __m128 )_mm_load_sd \ + ( \ + (double*)(matptr + ( ( post_ops_attr.post_op_c_i + m_ind ) * ldm ) + \ + post_ops_attr.post_op_c_j + ( n_ind * 2 )) \ + ); \ + +#define F32_F32_MATRIX_MUL_1COL_XMM_2ELE(scr0,m_ind,r_ind0) \ + F32_F32_MATRIX_MUL_LOAD_XMM_2ELE(scr0,m_ind,0); \ + F32_MATRIX_MUL_1COL_XMM(scr0,m_ind,r_ind0); \ + +#define F32_F32_MATRIX_MUL_LOAD_XMM(scr,m_ind,n_ind) \ + scr = _mm_loadu_ps \ + ( \ + matptr + ( ( post_ops_attr.post_op_c_i + m_ind ) * ldm ) + \ + post_ops_attr.post_op_c_j + ( n_ind * 4 ) \ + ); \ + +#define F32_F32_MATRIX_MUL_1COL_XMM(scr0,m_ind,r_ind0) \ + F32_F32_MATRIX_MUL_LOAD_XMM(scr0,m_ind,0); \ + F32_MATRIX_MUL_1COL_XMM(scr0,m_ind,r_ind0); \ + +#define F32_F32_MATRIX_MUL_LOAD_YMM(scr,m_ind,n_ind) \ + scr = _mm256_loadu_ps \ + ( \ + matptr + ( ( post_ops_attr.post_op_c_i + m_ind ) * ldm ) + \ + post_ops_attr.post_op_c_j + ( n_ind * 8 ) \ + ); \ + +#define F32_F32_MATRIX_MUL_1COL(scr0,m_ind,r_ind0) \ + F32_F32_MATRIX_MUL_LOAD_YMM(scr0,m_ind,0); \ + F32_MATRIX_MUL_1COL_YMM(scr0,m_ind,r_ind0); \ + +#define F32_F32_MATRIX_MUL_2COL(scr0,scr1,m_ind,r_ind0,r_ind1) \ + F32_F32_MATRIX_MUL_LOAD_YMM(scr0,m_ind,0); \ + F32_F32_MATRIX_MUL_LOAD_YMM(scr1,m_ind,1); \ + F32_MATRIX_MUL_2COL_YMM(scr0,scr1,m_ind,r_ind0,r_ind1); \ + #endif //LPGEMM_F32_SGEMM_AVX2_KERN_MACROS_H diff --git a/kernels/zen/lpgemm/f32f32f32/lpgemm_m_kernel_f32_avx2.c b/kernels/zen/lpgemm/f32f32f32/lpgemm_m_kernel_f32_avx2.c index a142a0fb3a..16cb8b0916 100644 --- a/kernels/zen/lpgemm/f32f32f32/lpgemm_m_kernel_f32_avx2.c +++ b/kernels/zen/lpgemm/f32f32f32/lpgemm_m_kernel_f32_avx2.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -52,7 +52,11 @@ LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_6x16m) &&POST_OPS_RELU_SCALE_6x16F, &&POST_OPS_GELU_TANH_6x16F, &&POST_OPS_GELU_ERF_6x16F, - &&POST_OPS_CLIP_6x16F + &&POST_OPS_CLIP_6x16F, + NULL, // Virtual node for downscale, else segfault + &&POST_OPS_MATRIX_ADD_6x16F, + &&POST_OPS_SWISH_6x16F, + &&POST_OPS_MATRIX_MUL_6x16F }; uint64_t n_left = n0 % NR; //n0 is expected to be n0<=NR @@ -555,6 +559,101 @@ LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_6x16m) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_MATRIX_ADD_6x16F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15] + F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,0,4,5); + + // c[1:0-15] + F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,1,6,7); + + // c[2:0-15] + F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,2,8,9); + + // c[3:0-15] + F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,3,10,11); + + // c[4:0-15] + F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,4,12,13); + + // c[5:0-15] + F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,5,14,15); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_6x16F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15] + F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,0,4,5); + + // c[1:0-15] + F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,1,6,7); + + // c[2:0-15] + F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,2,8,9); + + // c[3:0-15] + F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,3,10,11); + + // c[4:0-15] + F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,4,12,13); + + // c[5:0-15] + F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,5,14,15); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_6x16F: + { + ymm0 = + _mm256_broadcast_ss( ( float* )post_ops_list_temp->op_args2 ); + __m256 z, dn; + __m256i ex_out; + + // c[0,0-7] + SWISH_F32_AVX2_DEF(ymm4, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + // c[0,8-15] + SWISH_F32_AVX2_DEF(ymm5, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + // c[1,0-7] + SWISH_F32_AVX2_DEF(ymm6, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + // c[1,8-15] + SWISH_F32_AVX2_DEF(ymm7, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + // c[2,0-7] + SWISH_F32_AVX2_DEF(ymm8, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + // c[2,8-15] + SWISH_F32_AVX2_DEF(ymm9, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + // c[3,0-7] + SWISH_F32_AVX2_DEF(ymm10, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + // c[3,8-15] + SWISH_F32_AVX2_DEF(ymm11, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + // c[4,0-7] + SWISH_F32_AVX2_DEF(ymm12, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + // c[4,8-15] + SWISH_F32_AVX2_DEF(ymm13, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + // c[5,0-7] + SWISH_F32_AVX2_DEF(ymm14, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + // c[5,8-15] + SWISH_F32_AVX2_DEF(ymm15, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_6x16F_DISABLE: ; @@ -625,7 +724,11 @@ LPGEMM_N_FRINGE_KERN(float,float,float,f32f32f32of32_6x8m) &&POST_OPS_RELU_SCALE_6x8F, &&POST_OPS_GELU_TANH_6x8F, &&POST_OPS_GELU_ERF_6x8F, - &&POST_OPS_CLIP_6x8F + &&POST_OPS_CLIP_6x8F, + NULL, // Virtual node for downscale, else segfault + &&POST_OPS_MATRIX_ADD_6x8F, + &&POST_OPS_SWISH_6x8F, + &&POST_OPS_MATRIX_MUL_6x8F }; // Typecast local copies of integers in case dim_t and inc_t are a @@ -907,6 +1010,83 @@ LPGEMM_N_FRINGE_KERN(float,float,float,f32f32f32of32_6x8m) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_MATRIX_ADD_6x8F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-7] + F32_F32_MATRIX_ADD_1COL(ymm1,0,4); + + // c[1:0-7] + F32_F32_MATRIX_ADD_1COL(ymm1,1,6); + + // c[2:0-7] + F32_F32_MATRIX_ADD_1COL(ymm1,2,8); + + // c[3:0-7] + F32_F32_MATRIX_ADD_1COL(ymm1,3,10); + + // c[4:0-7] + F32_F32_MATRIX_ADD_1COL(ymm1,4,12); + + // c[5:0-7] + F32_F32_MATRIX_ADD_1COL(ymm1,5,14); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_6x8F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-7] + F32_F32_MATRIX_MUL_1COL(ymm1,0,4); + + // c[1:0-7] + F32_F32_MATRIX_MUL_1COL(ymm1,1,6); + + // c[2:0-7] + F32_F32_MATRIX_MUL_1COL(ymm1,2,8); + + // c[3:0-7] + F32_F32_MATRIX_MUL_1COL(ymm1,3,10); + + // c[4:0-7] + F32_F32_MATRIX_MUL_1COL(ymm1,4,12); + + // c[5:0-7] + F32_F32_MATRIX_MUL_1COL(ymm1,5,14); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_6x8F: + { + ymm0 = + _mm256_broadcast_ss( ( float* )post_ops_list_temp->op_args2 ); + __m256 z, dn; + __m256i ex_out; + + // c[0,0-7] + SWISH_F32_AVX2_DEF(ymm4, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + // c[1,0-7] + SWISH_F32_AVX2_DEF(ymm6, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + // c[2,0-7] + SWISH_F32_AVX2_DEF(ymm8, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + // c[3,0-7] + SWISH_F32_AVX2_DEF(ymm10, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + // c[4,0-7] + SWISH_F32_AVX2_DEF(ymm12, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + // c[5,0-7] + SWISH_F32_AVX2_DEF(ymm14, ymm0, ymm1, ymm2, ymm3, z, dn, ex_out) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_6x8F_DISABLE: ; @@ -971,7 +1151,11 @@ LPGEMM_N_FRINGE_KERN(float,float,float,f32f32f32of32_6x4m) &&POST_OPS_RELU_SCALE_6x4F, &&POST_OPS_GELU_TANH_6x4F, &&POST_OPS_GELU_ERF_6x4F, - &&POST_OPS_CLIP_6x4F + &&POST_OPS_CLIP_6x4F, + NULL, // Virtual node for downscale, else segfault + &&POST_OPS_MATRIX_ADD_6x4F, + &&POST_OPS_SWISH_6x4F, + &&POST_OPS_MATRIX_MUL_6x4F }; // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. @@ -1250,6 +1434,83 @@ LPGEMM_N_FRINGE_KERN(float,float,float,f32f32f32of32_6x4m) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_MATRIX_ADD_6x4F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-3] + F32_F32_MATRIX_ADD_1COL_XMM(xmm1,0,4); + + // c[1:0-3] + F32_F32_MATRIX_ADD_1COL_XMM(xmm1,1,5); + + // c[2:0-3] + F32_F32_MATRIX_ADD_1COL_XMM(xmm1,2,6); + + // c[3:0-3] + F32_F32_MATRIX_ADD_1COL_XMM(xmm1,3,7); + + // c[4:0-3] + F32_F32_MATRIX_ADD_1COL_XMM(xmm1,4,8); + + // c[5:0-3] + F32_F32_MATRIX_ADD_1COL_XMM(xmm1,5,9); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_6x4F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-3] + F32_F32_MATRIX_MUL_1COL_XMM(xmm1,0,4); + + // c[1:0-3] + F32_F32_MATRIX_MUL_1COL_XMM(xmm1,1,5); + + // c[2:0-3] + F32_F32_MATRIX_MUL_1COL_XMM(xmm1,2,6); + + // c[3:0-3] + F32_F32_MATRIX_MUL_1COL_XMM(xmm1,3,7); + + // c[4:0-3] + F32_F32_MATRIX_MUL_1COL_XMM(xmm1,4,8); + + // c[5:0-3] + F32_F32_MATRIX_MUL_1COL_XMM(xmm1,5,9); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_6x4F: + { + xmm0 = + _mm_broadcast_ss( ( float* )post_ops_list_temp->op_args2 ); + __m128 z, dn; + __m128i ex_out; + + // c[0,0-3] + SWISH_F32_SSE_DEF(xmm4, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) + + // c[1,0-3] + SWISH_F32_SSE_DEF(xmm5, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) + + // c[2,0-3] + SWISH_F32_SSE_DEF(xmm6, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) + + // c[3,0-3] + SWISH_F32_SSE_DEF(xmm7, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) + + // c[4,0-3] + SWISH_F32_SSE_DEF(xmm8, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) + + // c[5,0-3] + SWISH_F32_SSE_DEF(xmm9, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_6x4F_DISABLE: ; @@ -1314,7 +1575,11 @@ LPGEMM_N_FRINGE_KERN(float,float,float,f32f32f32of32_6x2m) &&POST_OPS_RELU_SCALE_6x2F, &&POST_OPS_GELU_TANH_6x2F, &&POST_OPS_GELU_ERF_6x2F, - &&POST_OPS_CLIP_6x2F + &&POST_OPS_CLIP_6x2F, + NULL, // Virtual node for downscale, else segfault + &&POST_OPS_MATRIX_ADD_6x2F, + &&POST_OPS_SWISH_6x2F, + &&POST_OPS_MATRIX_MUL_6x2F }; // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. @@ -1408,8 +1673,9 @@ LPGEMM_N_FRINGE_KERN(float,float,float,f32f32f32of32_6x2m) if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) { - xmm0 = _mm_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 8 ) ); + xmm0 = ( __m128 )_mm_load_sd( (const double*) + (( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 8 ) )); // c[0,0-3] xmm4 = _mm_add_ps( xmm4, xmm0 ); @@ -1593,6 +1859,83 @@ LPGEMM_N_FRINGE_KERN(float,float,float,f32f32f32of32_6x2m) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_MATRIX_ADD_6x2F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-1] + F32_F32_MATRIX_ADD_1COL_XMM_2ELE(xmm1,0,4); + + // c[1:0-1] + F32_F32_MATRIX_ADD_1COL_XMM_2ELE(xmm1,1,5); + + // c[2:0-1] + F32_F32_MATRIX_ADD_1COL_XMM_2ELE(xmm1,2,6); + + // c[3:0-1] + F32_F32_MATRIX_ADD_1COL_XMM_2ELE(xmm1,3,7); + + // c[4:0-1] + F32_F32_MATRIX_ADD_1COL_XMM_2ELE(xmm1,4,8); + + // c[5:0-1] + F32_F32_MATRIX_ADD_1COL_XMM_2ELE(xmm1,5,9); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_6x2F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-1] + F32_F32_MATRIX_MUL_1COL_XMM_2ELE(xmm1,0,4); + + // c[1:0-1] + F32_F32_MATRIX_MUL_1COL_XMM_2ELE(xmm1,1,5); + + // c[2:0-1] + F32_F32_MATRIX_MUL_1COL_XMM_2ELE(xmm1,2,6); + + // c[3:0-1] + F32_F32_MATRIX_MUL_1COL_XMM_2ELE(xmm1,3,7); + + // c[4:0-1] + F32_F32_MATRIX_MUL_1COL_XMM_2ELE(xmm1,4,8); + + // c[5:0-1] + F32_F32_MATRIX_MUL_1COL_XMM_2ELE(xmm1,5,9); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_6x2F: + { + xmm0 = + _mm_broadcast_ss( ( float* )post_ops_list_temp->op_args2 ); + __m128 z, dn; + __m128i ex_out; + + // c[0,0-1] + SWISH_F32_SSE_DEF(xmm4, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) + + // c[1,0-1] + SWISH_F32_SSE_DEF(xmm5, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) + + // c[2,0-1] + SWISH_F32_SSE_DEF(xmm6, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) + + // c[3,0-1] + SWISH_F32_SSE_DEF(xmm7, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) + + // c[4,0-1] + SWISH_F32_SSE_DEF(xmm8, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) + + // c[5,0-1] + SWISH_F32_SSE_DEF(xmm9, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_6x2F_DISABLE: ; @@ -1657,7 +2000,11 @@ LPGEMM_N_FRINGE_KERN(float,float,float,f32f32f32of32_6x1m) &&POST_OPS_RELU_SCALE_6x1F, &&POST_OPS_GELU_TANH_6x1F, &&POST_OPS_GELU_ERF_6x1F, - &&POST_OPS_CLIP_6x1F + &&POST_OPS_CLIP_6x1F, + NULL, // Virtual node for downscale, else segfault + &&POST_OPS_MATRIX_ADD_6x1F, + &&POST_OPS_SWISH_6x1F, + &&POST_OPS_MATRIX_MUL_6x1F }; // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. @@ -1751,7 +2098,7 @@ LPGEMM_N_FRINGE_KERN(float,float,float,f32f32f32of32_6x1m) if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) { - xmm0 = _mm_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + xmm0 = ( __m128 )_mm_load_ss( ( float* )post_ops_list_temp->op_args1 + post_ops_attr.post_op_c_j + ( 0 * 8 ) ); // c[0,0-3] @@ -1936,6 +2283,83 @@ LPGEMM_N_FRINGE_KERN(float,float,float,f32f32f32of32_6x1m) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_MATRIX_ADD_6x1F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-0] + F32_F32_MATRIX_ADD_1COL_XMM_1ELE(xmm1,0,4); + + // c[1:0-0] + F32_F32_MATRIX_ADD_1COL_XMM_1ELE(xmm1,1,5); + + // c[2:0-0] + F32_F32_MATRIX_ADD_1COL_XMM_1ELE(xmm1,2,6); + + // c[3:0-0] + F32_F32_MATRIX_ADD_1COL_XMM_1ELE(xmm1,3,7); + + // c[4:0-0] + F32_F32_MATRIX_ADD_1COL_XMM_1ELE(xmm1,4,8); + + // c[5:0-0] + F32_F32_MATRIX_ADD_1COL_XMM_1ELE(xmm1,5,9); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_6x1F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-0] + F32_F32_MATRIX_MUL_1COL_XMM_1ELE(xmm1,0,4); + + // c[1:0-0] + F32_F32_MATRIX_MUL_1COL_XMM_1ELE(xmm1,1,5); + + // c[2:0-0] + F32_F32_MATRIX_MUL_1COL_XMM_1ELE(xmm1,2,6); + + // c[3:0-0] + F32_F32_MATRIX_MUL_1COL_XMM_1ELE(xmm1,3,7); + + // c[4:0-0] + F32_F32_MATRIX_MUL_1COL_XMM_1ELE(xmm1,4,8); + + // c[5:0-0] + F32_F32_MATRIX_MUL_1COL_XMM_1ELE(xmm1,5,9); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_6x1F: + { + xmm0 = + _mm_broadcast_ss( ( float* )post_ops_list_temp->op_args2 ); + __m128 z, dn; + __m128i ex_out; + + // c[0,0-0] + SWISH_F32_SSE_DEF(xmm4, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) + + // c[1,0-0] + SWISH_F32_SSE_DEF(xmm5, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) + + // c[2,0-0] + SWISH_F32_SSE_DEF(xmm6, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) + + // c[3,0-0] + SWISH_F32_SSE_DEF(xmm7, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) + + // c[4,0-0] + SWISH_F32_SSE_DEF(xmm8, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) + + // c[5,0-0] + SWISH_F32_SSE_DEF(xmm9, xmm0, xmm1, xmm2, xmm3, z, dn, ex_out) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_6x1F_DISABLE: ; diff --git a/kernels/zen/lpgemm/f32f32f32/lpgemv_m_kernel_f32_avx2.c b/kernels/zen/lpgemm/f32f32f32/lpgemv_m_kernel_f32_avx2.c new file mode 100644 index 0000000000..1fecbc0518 --- /dev/null +++ b/kernels/zen/lpgemm/f32f32f32/lpgemv_m_kernel_f32_avx2.c @@ -0,0 +1,72 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ +#include "immintrin.h" +#include "xmmintrin.h" +#include "blis.h" + +#ifdef BLIS_ADDON_LPGEMM + +#include "lpgemm_kernel_macros_f32_avx2.h" + +void lpgemv_m_one_kernel_f32_avx2_ker_ft +( + const dim_t n0, + const dim_t k, + const float *a, + const dim_t rs_a, + const dim_t cs_a, + const AOCL_MEMORY_TAG mtag_a, + const float *b, + const dim_t rs_b, + const dim_t cs_b, + const AOCL_MEMORY_TAG mtag_b, + float *c, + const dim_t rs_c, + const dim_t cs_c, + const float alpha, + const float beta, + const dim_t NR, + const dim_t KC, + const dim_t n_sub_updated, + const dim_t jc_cur_loop_rem, + lpgemm_post_op *post_op_list, + lpgemm_post_op_attr *post_op_attr +) +{ + // TODO: Created dummy function as place holder. + // AVX2 varient wil be implemented in next commits. + // Code will take LPGEMM path for LPGEMV in AVX2 env +} + +#endif // BLIS_ADDON_LPGEMM diff --git a/kernels/zen/lpgemm/f32f32f32/lpgemv_n_kernel_f32_avx2.c b/kernels/zen/lpgemm/f32f32f32/lpgemv_n_kernel_f32_avx2.c new file mode 100644 index 0000000000..1dd118748a --- /dev/null +++ b/kernels/zen/lpgemm/f32f32f32/lpgemv_n_kernel_f32_avx2.c @@ -0,0 +1,75 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ +#include "immintrin.h" +#include "xmmintrin.h" +#include "blis.h" + +#ifdef BLIS_ADDON_LPGEMM + +#include "lpgemm_kernel_macros_f32_avx2.h" + +// When n=1 is load 16x1 from B and load MRx16 from A and perform dot product +// to produce C output of MRX1. The vectorization is done in k loop and +// the horizontal reduction done to produce one output from each +// accumulator register +void lpgemv_n_one_kernel_f32_avx2_ker_ft +( + const dim_t m0, + const dim_t k, + const float *a, + const dim_t rs_a, + const dim_t cs_a, + const AOCL_MEMORY_TAG mtag_a, + const float *b, + const dim_t rs_b, + const dim_t cs_b, + const AOCL_MEMORY_TAG mtag_b, + float *c, + const dim_t rs_c, + const dim_t cs_c, + const float alpha, + const float beta, + const dim_t MR, + const dim_t KC, + lpgemm_post_op *post_op_list, + lpgemm_post_op_attr *post_op_attr +) +{ +//TODO: Created dummy function as place holder to get +//rid of linking issues in other zen configurations. +//AVX2 varient wil be implemented in next commits. +//Code will take LPGEMM path for LPGEMV in AVX2 env. +} + +#endif // BLIS_ADDON_LPGEMM diff --git a/kernels/zen/lpgemm/gelu_avx2.h b/kernels/zen/lpgemm/gelu_avx2.h index 3ee074e917..a14ff7cebc 100644 --- a/kernels/zen/lpgemm/gelu_avx2.h +++ b/kernels/zen/lpgemm/gelu_avx2.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/kernels/zen/lpgemm/lpgemm_util_l1_ops_avx2.c b/kernels/zen/lpgemm/lpgemm_util_l1_ops_avx2.c index 2e9a1b5deb..704c6e9250 100644 --- a/kernels/zen/lpgemm/lpgemm_util_l1_ops_avx2.c +++ b/kernels/zen/lpgemm/lpgemm_util_l1_ops_avx2.c @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/kernels/zen/lpgemm/math_utils_avx2.h b/kernels/zen/lpgemm/math_utils_avx2.h index 5f503fa3e7..c26c07b188 100644 --- a/kernels/zen/lpgemm/math_utils_avx2.h +++ b/kernels/zen/lpgemm/math_utils_avx2.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/kernels/zen/lpgemm/s8s8s16/lpgemm_s8_6x32rowmajor_amd256.c b/kernels/zen/lpgemm/s8s8s16/lpgemm_s8_6x32rowmajor_amd256.c index c102a89dea..7a5cd212ad 100644 --- a/kernels/zen/lpgemm/s8s8s16/lpgemm_s8_6x32rowmajor_amd256.c +++ b/kernels/zen/lpgemm/s8s8s16/lpgemm_s8_6x32rowmajor_amd256.c @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -50,7 +50,9 @@ LPGEMM_MAIN_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6x32) &&POST_OPS_GELU_TANH_6x32, &&POST_OPS_GELU_ERF_6x32, &&POST_OPS_CLIP_6x32, - &&POST_OPS_DOWNSCALE_6x32 + &&POST_OPS_DOWNSCALE_6x32, + &&POST_OPS_MATRIX_ADD_6x32, + &&POST_OPS_SWISH_6x32 }; dim_t MR = 6; @@ -107,7 +109,7 @@ LPGEMM_MAIN_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6x32) return; } - uint8_t cvt_uint8 = 128; + uint8_t cvt_uint8 = 128; __m256i vec_uint8 = _mm256_set1_epi8 (cvt_uint8); for (dim_t ir = 0; ir < m_full_pieces_loop_limit; ir += MR) @@ -146,9 +148,9 @@ LPGEMM_MAIN_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6x32) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - __m256i b0 = + __m256i b0 = _mm256_loadu_si256((__m256i const *)(b + (64 * kr) + (NR * 0))); - __m256i b1 = + __m256i b1 = _mm256_loadu_si256((__m256i const *)(b + (64 * kr) + (NR * 1))); // Separate register for intermediate op @@ -166,7 +168,7 @@ LPGEMM_MAIN_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6x32) _mm256_set1_epi16(*(int16_t *)(a + (rs_a * 1) + (cs_a * offset))); //convert signed int8 to uint8 for u8s8s16 FMA ops - a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); + a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); @@ -179,11 +181,11 @@ LPGEMM_MAIN_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6x32) c_int16_1p1 = _mm256_add_epi16(inter_vec, c_int16_1p1); // Broadcast a[2,kr:kr+2]. - a_int32_0 = + a_int32_0 = _mm256_set1_epi16(*(int16_t *)(a + (rs_a * 2) + (cs_a * offset))); //convert signed int8 to uint8 for u8s8s16 FMA ops - a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); + a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); @@ -195,11 +197,11 @@ LPGEMM_MAIN_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6x32) c_int16_2p1 = _mm256_add_epi16(inter_vec, c_int16_2p1); // Broadcast a[3,kr:kr+2]. - a_int32_0 = + a_int32_0 = _mm256_set1_epi16(*(int16_t *)(a + (rs_a * 3) + (cs_a * offset))); //convert signed int8 to uint8 for u8s8s16 FMA ops - a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); + a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); @@ -216,7 +218,7 @@ LPGEMM_MAIN_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6x32) _mm256_set1_epi16(*(int16_t *)(a + (rs_a * 4) + (cs_a * offset))); //convert signed int8 to uint8 for u8s8s16 FMA ops - a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); + a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); @@ -230,11 +232,11 @@ LPGEMM_MAIN_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6x32) c_int16_4p1 = _mm256_add_epi16(inter_vec, c_int16_4p1); // Broadcast a[5,kr:kr+2]. - a_int32_0 = + a_int32_0 = _mm256_set1_epi16(*(int16_t *)(a + (rs_a * 5) + (cs_a * offset))); //convert signed int8 to uint8 for u8s8s16 FMA ops - a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); + a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); @@ -355,7 +357,7 @@ LPGEMM_MAIN_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6x32) } if ( post_ops_attr.is_last_k == 1 ) { - //Subtract B matrix sum column values to compensate + //Subtract B matrix sum column values to compensate //for addition of 128 to A matrix elements int16_t* bsumptr = post_ops_attr.b_col_sum_vec_s16 + post_ops_attr.b_sum_offset; @@ -762,25 +764,44 @@ LPGEMM_MAIN_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6x32) __m128i temp[2]; __m256i temp_32[2]; __m256 temp_float[2]; - __m256 scale_1, scale_2; + __m256 scale_1 = _mm256_setzero_ps(); + __m256 scale_2 = _mm256_setzero_ps(); __m256 res_1, res_2; - /* Load the scale vector values into the register*/ - scale_1 = - _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + (0 * 8)); - scale_2 = - _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + (1 * 8)); - - // Load zero points (2 byte values). - __m128i _zero_point_0 = - _mm_loadu_si128( - ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + /* Load the scale vector values into the register*/ + scale_1 = _mm256_loadu_ps( + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 8 ) ); + scale_2 = _mm256_loadu_ps( + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 8 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + /* Broadcast scale factor. */ + scale_1 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + scale_2 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + __m128i _zero_point_0 = _mm_setzero_si128(); __m256i zero_point_0 = _mm256_setzero_si256(); + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + // Load zero points (2 byte values). + _zero_point_0 = _mm_loadu_si128( + ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + // Broadcast zero point. + _zero_point_0 = _mm_set1_epi8( + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } if ( post_ops_attr.c_stor_type == S8 ) { zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); @@ -798,19 +819,38 @@ LPGEMM_MAIN_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6x32) CVT_MULRND_CVT16(c_int16_4p0, scale_1, scale_2, zero_point_0) CVT_MULRND_CVT16(c_int16_5p0, scale_1, scale_2, zero_point_0) - scale_1 = - _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + (2 * 8)); - scale_2 = - _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + (3 * 8)); - - _zero_point_0 = - _mm_loadu_si128( - ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + /* Load the scale vector values into the register*/ + scale_1 = _mm256_loadu_ps( + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 8 ) ); + scale_2 = _mm256_loadu_ps( + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 3 * 8 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + /* Broadcast scale factor. */ + scale_1 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + scale_2 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + // Load zero points (2 byte values). + _zero_point_0 = _mm_loadu_si128( + ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + // Broadcast zero point. + _zero_point_0 = _mm_set1_epi8( + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } if ( post_ops_attr.c_stor_type == S8 ) { zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); @@ -828,6 +868,128 @@ LPGEMM_MAIN_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6x32) CVT_MULRND_CVT16(c_int16_4p1, scale_1, scale_2, zero_point_0) CVT_MULRND_CVT16(c_int16_5p1, scale_1, scale_2, zero_point_0) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_6x32: + { + __m256i selector1, selector2; + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + S8_S16_MATRIX_ADD_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + S8_S16_MATRIX_ADD_2COL(selector1,selector2,1); + + // c[2:0-15,16-31] + S8_S16_MATRIX_ADD_2COL(selector1,selector2,2); + + // c[3:0-15,16-31] + S8_S16_MATRIX_ADD_2COL(selector1,selector2,3); + + // c[4:0-15,16-31] + S8_S16_MATRIX_ADD_2COL(selector1,selector2,4); + + // c[5:0-15,16-31] + S8_S16_MATRIX_ADD_2COL(selector1,selector2,5); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + uint8_t* matptr = ( uint8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + U8_S16_MATRIX_ADD_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + U8_S16_MATRIX_ADD_2COL(selector1,selector2,1); + + // c[2:0-15,16-31] + U8_S16_MATRIX_ADD_2COL(selector1,selector2,2); + + // c[3:0-15,16-31] + U8_S16_MATRIX_ADD_2COL(selector1,selector2,3); + + // c[4:0-15,16-31] + U8_S16_MATRIX_ADD_2COL(selector1,selector2,4); + + // c[5:0-15,16-31] + U8_S16_MATRIX_ADD_2COL(selector1,selector2,5); + } + else + { + int16_t* matptr = ( int16_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + S16_S16_MATRIX_ADD_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + S16_S16_MATRIX_ADD_2COL(selector1,selector2,1); + + // c[2:0-15,16-31] + S16_S16_MATRIX_ADD_2COL(selector1,selector2,2); + + // c[3:0-15,16-31] + S16_S16_MATRIX_ADD_2COL(selector1,selector2,3); + + // c[4:0-15,16-31] + S16_S16_MATRIX_ADD_2COL(selector1,selector2,4); + + // c[5:0-15,16-31] + S16_S16_MATRIX_ADD_2COL(selector1,selector2,5); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_6x32: + { + alphav = + _mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args2 ) ); + __m256 al = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32( \ + _mm256_extractf128_si256( alphav, 0 ) ) ); + + __m256 al_in, tmp_reg1, tmp_reg2, r, r2, z, dn; + __m256i ex_out; + + // c[0,0-15] + SWISH_S16_AVX2(c_int16_0p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[0,16-31] + SWISH_S16_AVX2(c_int16_0p1, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[1,0-15] + SWISH_S16_AVX2(c_int16_1p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[1,16-31] + SWISH_S16_AVX2(c_int16_1p1, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[2,0-15] + SWISH_S16_AVX2(c_int16_2p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[2,16-31] + SWISH_S16_AVX2(c_int16_2p1, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[3,0-15] + SWISH_S16_AVX2(c_int16_3p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[3,16-31] + SWISH_S16_AVX2(c_int16_3p1, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[4,0-15] + SWISH_S16_AVX2(c_int16_4p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[4,16-31] + SWISH_S16_AVX2(c_int16_4p1, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[5,0-15] + SWISH_S16_AVX2(c_int16_5p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[5,16-31] + SWISH_S16_AVX2(c_int16_5p1, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_6x32_DISABLE: @@ -898,7 +1060,7 @@ LPGEMM_MAIN_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6x32) // c[5,16-31] _mm256_storeu_si256( (__m256i *)(c + ( rs_c * ( ir + 5 ) ) + ( 1*16 )), c_int16_5p1 ); } - + a = a + ( MR * ps_a ); post_ops_attr.post_op_c_i += MR; } diff --git a/kernels/zen/lpgemm/s8s8s16/lpgemm_s8_m_fringe_amd256.c b/kernels/zen/lpgemm/s8s8s16/lpgemm_s8_m_fringe_amd256.c index 8d5a99968c..b50892f432 100644 --- a/kernels/zen/lpgemm/s8s8s16/lpgemm_s8_m_fringe_amd256.c +++ b/kernels/zen/lpgemm/s8s8s16/lpgemm_s8_m_fringe_amd256.c @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -53,7 +53,9 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4x32) &&POST_OPS_GELU_TANH_4x32, &&POST_OPS_GELU_ERF_4x32, &&POST_OPS_CLIP_4x32, - &&POST_OPS_DOWNSCALE_4x32 + &&POST_OPS_DOWNSCALE_4x32, + &&POST_OPS_MATRIX_ADD_4x32, + &&POST_OPS_SWISH_4x32 }; // The division is done by considering the vpmaddubsw instruction @@ -520,26 +522,44 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4x32) __m128i temp[2]; __m256i temp_32[2]; __m256 temp_float[2]; - __m256 scale_1, scale_2; - __m128i _zero_point_0; + __m256 scale_1 = _mm256_setzero_ps(); + __m256 scale_2 = _mm256_setzero_ps(); + __m128i _zero_point_0 = _mm_setzero_si128(); __m256i zero_point_0 = _mm256_setzero_si256(); __m256 res_1, res_2; - /* Load the scale vector values into the register*/ - scale_1 = - _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + (0 * 8)); - scale_2 = - _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + (1 * 8)); - - // Load zero points (2 byte values). - _zero_point_0 = - _mm_loadu_si128( - ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + /* Load the scale vector values into the register*/ + scale_1 = _mm256_loadu_ps( + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 8 ) ); + scale_2 = _mm256_loadu_ps( + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 8 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + /* Broadcast scale factor. */ + scale_1 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + scale_2 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + // Load zero points (2 byte values). + _zero_point_0 = _mm_loadu_si128( + ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + // Broadcast zero point. + _zero_point_0 = _mm_set1_epi8( + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } if ( post_ops_attr.c_stor_type == S8 ) { zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); @@ -555,19 +575,38 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4x32) CVT_MULRND_CVT16(c_int16_2p0, scale_1, scale_2, zero_point_0) CVT_MULRND_CVT16(c_int16_3p0, scale_1, scale_2, zero_point_0) - scale_1 = - _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + (2 * 8)); - scale_2 = - _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + (3 * 8)); - - _zero_point_0 = - _mm_loadu_si128( - ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + /* Load the scale vector values into the register*/ + scale_1 = _mm256_loadu_ps( + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 8 ) ); + scale_2 = _mm256_loadu_ps( + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 3 * 8 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + /* Broadcast scale factor. */ + scale_1 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + scale_2 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + // Load zero points (2 byte values). + _zero_point_0 = _mm_loadu_si128( + ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + // Broadcast zero point. + _zero_point_0 = _mm_set1_epi8( + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } if ( post_ops_attr.c_stor_type == S8 ) { zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); @@ -583,6 +622,97 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4x32) CVT_MULRND_CVT16(c_int16_2p1, scale_1, scale_2, zero_point_0) CVT_MULRND_CVT16(c_int16_3p1, scale_1, scale_2, zero_point_0) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_4x32: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + S8_S16_MATRIX_ADD_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + S8_S16_MATRIX_ADD_2COL(selector1,selector2,1); + + // c[2:0-15,16-31] + S8_S16_MATRIX_ADD_2COL(selector1,selector2,2); + + // c[3:0-15,16-31] + S8_S16_MATRIX_ADD_2COL(selector1,selector2,3); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + uint8_t* matptr = ( uint8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + U8_S16_MATRIX_ADD_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + U8_S16_MATRIX_ADD_2COL(selector1,selector2,1); + + // c[2:0-15,16-31] + U8_S16_MATRIX_ADD_2COL(selector1,selector2,2); + + // c[3:0-15,16-31] + U8_S16_MATRIX_ADD_2COL(selector1,selector2,3); + } + else + { + int16_t* matptr = ( int16_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + S16_S16_MATRIX_ADD_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + S16_S16_MATRIX_ADD_2COL(selector1,selector2,1); + + // c[2:0-15,16-31] + S16_S16_MATRIX_ADD_2COL(selector1,selector2,2); + + // c[3:0-15,16-31] + S16_S16_MATRIX_ADD_2COL(selector1,selector2,3); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_4x32: + { + selector1 = + _mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args2 ) ); + __m256 al = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32( \ + _mm256_extractf128_si256( selector1, 0 ) ) ); + + __m256 al_in, tmp_reg1, tmp_reg2, r, r2, z, dn; + __m256i ex_out; + + // c[0,0-15] + SWISH_S16_AVX2(c_int16_0p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[0,16-31] + SWISH_S16_AVX2(c_int16_0p1, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[1,0-15] + SWISH_S16_AVX2(c_int16_1p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[1,16-31] + SWISH_S16_AVX2(c_int16_1p1, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[2,0-15] + SWISH_S16_AVX2(c_int16_2p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[2,16-31] + SWISH_S16_AVX2(c_int16_2p1, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[3,0-15] + SWISH_S16_AVX2(c_int16_3p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[3,16-31] + SWISH_S16_AVX2(c_int16_3p1, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_4x32_DISABLE: @@ -652,7 +782,9 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_2x32) &&POST_OPS_GELU_TANH_2x32, &&POST_OPS_GELU_ERF_2x32, &&POST_OPS_CLIP_2x32, - &&POST_OPS_DOWNSCALE_2x32 + &&POST_OPS_DOWNSCALE_2x32, + &&POST_OPS_MATRIX_ADD_2x32, + &&POST_OPS_SWISH_2x32 }; // The division is done by considering the vpmaddubsw instruction @@ -946,26 +1078,44 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_2x32) __m128i temp[2]; __m256i temp_32[2]; __m256 temp_float[2]; - __m256 scale_1, scale_2; - __m128i _zero_point_0; + __m256 scale_1 = _mm256_setzero_ps(); + __m256 scale_2 = _mm256_setzero_ps(); + __m128i _zero_point_0 = _mm_setzero_si128(); __m256i zero_point_0 = _mm256_setzero_si256(); __m256 res_1, res_2; - /* Load the scale vector values into the register*/ - scale_1 = - _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + (0 * 8)); - scale_2 = - _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + (1 * 8)); - - // Load zero points (2 byte values). - _zero_point_0 = - _mm_loadu_si128( - ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + /* Load the scale vector values into the register*/ + scale_1 = _mm256_loadu_ps( + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 8 ) ); + scale_2 = _mm256_loadu_ps( + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 8 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + /* Broadcast scale factor. */ + scale_1 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + scale_2 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + // Load zero points (2 byte values). + _zero_point_0 = _mm_loadu_si128( + ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + // Broadcast zero point. + _zero_point_0 = _mm_set1_epi8( + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } if ( post_ops_attr.c_stor_type == S8 ) { zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); @@ -979,19 +1129,38 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_2x32) CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2, zero_point_0) CVT_MULRND_CVT16(c_int16_1p0, scale_1, scale_2, zero_point_0) - scale_1 = - _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + (2 * 8)); - scale_2 = - _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + (3 * 8)); - - _zero_point_0 = - _mm_loadu_si128( - ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + /* Load the scale vector values into the register*/ + scale_1 = _mm256_loadu_ps( + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 8 ) ); + scale_2 = _mm256_loadu_ps( + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 3 * 8 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + /* Broadcast scale factor. */ + scale_1 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + scale_2 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + // Load zero points (2 byte values). + _zero_point_0 = _mm_loadu_si128( + ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + // Broadcast zero point. + _zero_point_0 = _mm_set1_epi8( + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } if ( post_ops_attr.c_stor_type == S8 ) { zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); @@ -1005,6 +1174,67 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_2x32) CVT_MULRND_CVT16(c_int16_0p1, scale_1, scale_2, zero_point_0) CVT_MULRND_CVT16(c_int16_1p1, scale_1, scale_2, zero_point_0) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_2x32: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + S8_S16_MATRIX_ADD_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + S8_S16_MATRIX_ADD_2COL(selector1,selector2,1); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + uint8_t* matptr = ( uint8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + U8_S16_MATRIX_ADD_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + U8_S16_MATRIX_ADD_2COL(selector1,selector2,1); + } + else + { + int16_t* matptr = ( int16_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + S16_S16_MATRIX_ADD_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + S16_S16_MATRIX_ADD_2COL(selector1,selector2,1); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_2x32: + { + selector1 = + _mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args2 ) ); + __m256 al = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32( \ + _mm256_extractf128_si256( selector1, 0 ) ) ); + + __m256 al_in, tmp_reg1, tmp_reg2, r, r2, z, dn; + __m256i ex_out; + + // c[0,0-15] + SWISH_S16_AVX2(c_int16_0p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[0,16-31] + SWISH_S16_AVX2(c_int16_0p1, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[1,0-15] + SWISH_S16_AVX2(c_int16_1p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[1,16-31] + SWISH_S16_AVX2(c_int16_1p1, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_2x32_DISABLE: @@ -1055,7 +1285,9 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_1x32) &&POST_OPS_GELU_TANH_1x32, &&POST_OPS_GELU_ERF_1x32, &&POST_OPS_CLIP_1x32, - &&POST_OPS_DOWNSCALE_1x32 + &&POST_OPS_DOWNSCALE_1x32, + &&POST_OPS_MATRIX_ADD_1x32, + &&POST_OPS_SWISH_1x32 }; // The division is done by considering the vpmaddubsw instruction @@ -1262,26 +1494,44 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_1x32) __m128i temp[2]; __m256i temp_32[2]; __m256 temp_float[2]; - __m256 scale_1, scale_2; - __m128i _zero_point_0; + __m256 scale_1 = _mm256_setzero_ps(); + __m256 scale_2 = _mm256_setzero_ps(); + __m128i _zero_point_0 = _mm_setzero_si128(); __m256i zero_point_0 = _mm256_setzero_si256(); __m256 res_1, res_2; - /* Load the scale vector values into the register*/ - scale_1 = - _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + (0 * 8)); - scale_2 = - _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + (1 * 8)); - - // Load zero points (2 byte values). - _zero_point_0 = - _mm_loadu_si128( - ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + /* Load the scale vector values into the register*/ + scale_1 = _mm256_loadu_ps( + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 8 ) ); + scale_2 = _mm256_loadu_ps( + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 8 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + /* Broadcast scale factor. */ + scale_1 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + scale_2 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + // Load zero points (2 byte values). + _zero_point_0 = _mm_loadu_si128( + ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + // Broadcast zero point. + _zero_point_0 = _mm_set1_epi8( + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } if ( post_ops_attr.c_stor_type == S8 ) { zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); @@ -1294,19 +1544,38 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_1x32) // Scale first 16 columns of the 4 rows. CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2, zero_point_0) - scale_1 = - _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + (2 * 8)); - scale_2 = - _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + (3 * 8)); - - _zero_point_0 = - _mm_loadu_si128( - ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + /* Load the scale vector values into the register*/ + scale_1 = _mm256_loadu_ps( + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 8 ) ); + scale_2 = _mm256_loadu_ps( + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 3 * 8 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + /* Broadcast scale factor. */ + scale_1 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + scale_2 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + // Load zero points (2 byte values). + _zero_point_0 = _mm_loadu_si128( + ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + // Broadcast zero point. + _zero_point_0 = _mm_set1_epi8( + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } if ( post_ops_attr.c_stor_type == S8 ) { zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); @@ -1319,6 +1588,52 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_1x32) // Scale next 16 columns of the 4 rows. CVT_MULRND_CVT16(c_int16_0p1, scale_1, scale_2, zero_point_0) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_1x32: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + S8_S16_MATRIX_ADD_2COL(selector1,selector2,0); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + uint8_t* matptr = ( uint8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + U8_S16_MATRIX_ADD_2COL(selector1,selector2,0); + } + else + { + int16_t* matptr = ( int16_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + S16_S16_MATRIX_ADD_2COL(selector1,selector2,0); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_1x32: + { + selector1 = + _mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args2 ) ); + __m256 al = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32( \ + _mm256_extractf128_si256( selector1, 0 ) ) ); + + __m256 al_in, tmp_reg1, tmp_reg2, r, r2, z, dn; + __m256i ex_out; + + // c[0,0-15] + SWISH_S16_AVX2(c_int16_0p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[0,16-31] + SWISH_S16_AVX2(c_int16_0p1, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_1x32_DISABLE: diff --git a/kernels/zen/lpgemm/s8s8s16/lpgemm_s8_mn_fringe_amd256.c b/kernels/zen/lpgemm/s8s8s16/lpgemm_s8_mn_fringe_amd256.c index 9e2355a711..1e293048f8 100644 --- a/kernels/zen/lpgemm/s8s8s16/lpgemm_s8_mn_fringe_amd256.c +++ b/kernels/zen/lpgemm/s8s8s16/lpgemm_s8_mn_fringe_amd256.c @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -53,7 +53,9 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4x16) &&POST_OPS_GELU_TANH_4x16, &&POST_OPS_GELU_ERF_4x16, &&POST_OPS_CLIP_4x16, - &&POST_OPS_DOWNSCALE_4x16 + &&POST_OPS_DOWNSCALE_4x16, + &&POST_OPS_MATRIX_ADD_4x16, + &&POST_OPS_SWISH_4x16 }; // The division is done by considering the vpmaddubsw instruction @@ -383,26 +385,44 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4x16) __m128i temp[2]; __m256i temp_32[2]; __m256 temp_float[2]; - __m256 scale_1, scale_2; - __m128i _zero_point_0; + __m256 scale_1 = _mm256_setzero_ps(); + __m256 scale_2 = _mm256_setzero_ps(); + __m128i _zero_point_0 = _mm_setzero_si128(); __m256i zero_point_0 = _mm256_setzero_si256(); __m256 res_1, res_2; - /* Load the scale vector values into the register*/ - scale_1 = - _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + (0 * 8)); - scale_2 = - _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + (1 * 8)); - - // Load zero points (2 byte values). - _zero_point_0 = - _mm_loadu_si128( - ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + /* Load the scale vector values into the register*/ + scale_1 = _mm256_loadu_ps( + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 8 ) ); + scale_2 = _mm256_loadu_ps( + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 8 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + /* Broadcast scale factor. */ + scale_1 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + scale_2 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + // Load zero points (2 byte values). + _zero_point_0 = _mm_loadu_si128( + ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + // Broadcast zero point. + _zero_point_0 = _mm_set1_epi8( + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } if ( post_ops_attr.c_stor_type == S8 ) { zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); @@ -418,6 +438,85 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4x16) CVT_MULRND_CVT16(c_int16_2p0, scale_1, scale_2, zero_point_0) CVT_MULRND_CVT16(c_int16_3p0, scale_1, scale_2, zero_point_0) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_4x16: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S8_S16_MATRIX_ADD_1COL(selector1,0); + + // c[1:0-15] + S8_S16_MATRIX_ADD_1COL(selector1,1); + + // c[2:0-15] + S8_S16_MATRIX_ADD_1COL(selector1,2); + + // c[3:0-15] + S8_S16_MATRIX_ADD_1COL(selector1,3); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + uint8_t* matptr = ( uint8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + U8_S16_MATRIX_ADD_1COL(selector1,0); + + // c[1:0-15] + U8_S16_MATRIX_ADD_1COL(selector1,1); + + // c[2:0-15] + U8_S16_MATRIX_ADD_1COL(selector1,2); + + // c[3:0-15] + U8_S16_MATRIX_ADD_1COL(selector1,3); + } + else + { + int16_t* matptr = ( int16_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S16_S16_MATRIX_ADD_1COL(selector1,0); + + // c[1:0-15] + S16_S16_MATRIX_ADD_1COL(selector1,1); + + // c[2:0-15] + S16_S16_MATRIX_ADD_1COL(selector1,2); + + // c[3:0-15] + S16_S16_MATRIX_ADD_1COL(selector1,3); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_4x16: + { + selector1 = + _mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args2 ) ); + __m256 al = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32( \ + _mm256_extractf128_si256( selector1, 0 ) ) ); + + __m256 al_in, tmp_reg1, tmp_reg2, r, r2, z, dn; + __m256i ex_out; + + // c[0,0-15] + SWISH_S16_AVX2(c_int16_0p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[1,0-15] + SWISH_S16_AVX2(c_int16_1p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[2,0-15] + SWISH_S16_AVX2(c_int16_2p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[3,0-15] + SWISH_S16_AVX2(c_int16_3p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_4x16_DISABLE: @@ -470,7 +569,9 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4xlt16) &&POST_OPS_GELU_TANH_4xlt16, &&POST_OPS_GELU_ERF_4xlt16, &&POST_OPS_CLIP_4xlt16, - &&POST_OPS_DOWNSCALE_4xlt16 + &&POST_OPS_DOWNSCALE_4xlt16, + &&POST_OPS_MATRIX_ADD_4xlt16, + &&POST_OPS_SWISH_4xlt16 }; // The division is done by considering the vpmaddubsw instruction @@ -824,36 +925,64 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4xlt16) __m128i temp[2]; __m256i temp_32[2]; __m256 temp_float[2]; - __m256 scale_1, scale_2; - __m128i _zero_point_0; + __m256 scale_1 = _mm256_setzero_ps(); + __m256 scale_2 = _mm256_setzero_ps(); + __m128i _zero_point_0 = _mm_setzero_si128(); __m256i zero_point_0 = _mm256_setzero_si256(); __m256 res_1, res_2; - float float_buf[16]; + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + float float_buf[16] = { 0 }; - memcpy( float_buf, ( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( float ) ) ); + memcpy( float_buf, ( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( float ) ) ); - // Load the scale vector values into the register - scale_1 = _mm256_loadu_ps(float_buf + (0 * 8)); - scale_2 = _mm256_loadu_ps(float_buf + (1 * 8)); + // Load the scale vector values into the register + scale_1 = _mm256_loadu_ps(float_buf + (0 * 8)); + scale_2 = _mm256_loadu_ps(float_buf + (1 * 8)); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + /* Broadcast scale factor. */ + scale_1 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + scale_2 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } - if ( post_ops_attr.c_stor_type == S8 ) + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t zero_point_buf[16]; + + memcpy( zero_point_buf, ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( int8_t ) ) ); + _zero_point_0 = _mm_loadu_si128( ( __m128i const* )zero_point_buf ); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + uint8_t zero_point_buf[16]; + + memcpy( zero_point_buf, ( ( uint8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( uint8_t ) ) ); + _zero_point_0 = _mm_loadu_si128( ( __m128i const* )zero_point_buf ); + } + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) { - int8_t zero_point_buf[16]; + // Broadcast zero point. + _zero_point_0 = _mm_set1_epi8( + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } - memcpy( zero_point_buf, ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( int8_t ) ) ); - _zero_point_0 = _mm_loadu_si128( ( __m128i const* )zero_point_buf ); + if ( post_ops_attr.c_stor_type == S8 ) + { zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); } else if ( post_ops_attr.c_stor_type == U8 ) { - uint8_t zero_point_buf[16]; - - memcpy( zero_point_buf, ( ( uint8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( uint8_t ) ) ); - _zero_point_0 = _mm_loadu_si128( ( __m128i const* )zero_point_buf ); zero_point_0 = _mm256_cvtepu8_epi16( _zero_point_0 ); } @@ -863,6 +992,85 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4xlt16) CVT_MULRND_CVT16(c_int16_2p0, scale_1, scale_2, zero_point_0) CVT_MULRND_CVT16(c_int16_3p0, scale_1, scale_2, zero_point_0) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_4xlt16: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S8_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,0,n0_rem,int8_t); + + // c[1:0-15] + S8_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,1,n0_rem,int8_t); + + // c[2:0-15] + S8_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,2,n0_rem,int8_t); + + // c[3:0-15] + S8_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,3,n0_rem,int8_t); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + uint8_t* matptr = ( uint8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + U8_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,0,n0_rem,uint8_t); + + // c[1:0-15] + U8_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,1,n0_rem,uint8_t); + + // c[2:0-15] + U8_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,2,n0_rem,uint8_t); + + // c[3:0-15] + U8_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,3,n0_rem,uint8_t); + } + else + { + int16_t* matptr = ( int16_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S16_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,0,n0_rem,int16_t); + + // c[1:0-15] + S16_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,1,n0_rem,int16_t); + + // c[2:0-15] + S16_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,2,n0_rem,int16_t); + + // c[3:0-15] + S16_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,3,n0_rem,int16_t); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_4xlt16: + { + selector1 = + _mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args2 ) ); + __m256 al = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32( \ + _mm256_extractf128_si256( selector1, 0 ) ) ); + + __m256 al_in, tmp_reg1, tmp_reg2, r, r2, z, dn; + __m256i ex_out; + + // c[0,0-15] + SWISH_S16_AVX2(c_int16_0p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[1,0-15] + SWISH_S16_AVX2(c_int16_1p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[2,0-15] + SWISH_S16_AVX2(c_int16_2p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[3,0-15] + SWISH_S16_AVX2(c_int16_3p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_4xlt16_DISABLE: @@ -935,7 +1143,9 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_2x16) &&POST_OPS_GELU_TANH_2x16, &&POST_OPS_GELU_ERF_2x16, &&POST_OPS_CLIP_2x16, - &&POST_OPS_DOWNSCALE_2x16 + &&POST_OPS_DOWNSCALE_2x16, + &&POST_OPS_MATRIX_ADD_2x16, + &&POST_OPS_SWISH_2x16 }; // The division is done by considering the vpmaddubsw instruction @@ -1157,26 +1367,44 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_2x16) __m128i temp[2]; __m256i temp_32[2]; __m256 temp_float[2]; - __m256 scale_1, scale_2; - __m128i _zero_point_0; + __m256 scale_1 = _mm256_setzero_ps(); + __m256 scale_2 = _mm256_setzero_ps(); + __m128i _zero_point_0 = _mm_setzero_si128(); __m256i zero_point_0 = _mm256_setzero_si256(); __m256 res_1, res_2; - /* Load the scale vector values into the register*/ - scale_1 = - _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + (0 * 8)); - scale_2 = - _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + (1 * 8)); - - // Load zero points (2 byte values). - _zero_point_0 = - _mm_loadu_si128( - ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + /* Load the scale vector values into the register*/ + scale_1 = _mm256_loadu_ps( + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 8 ) ); + scale_2 = _mm256_loadu_ps( + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 8 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + /* Broadcast scale factor. */ + scale_1 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + scale_2 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + // Load zero points (2 byte values). + _zero_point_0 = _mm_loadu_si128( + ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + // Broadcast zero point. + _zero_point_0 = _mm_set1_epi8( + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } if ( post_ops_attr.c_stor_type == S8 ) { zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); @@ -1190,6 +1418,61 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_2x16) CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2, zero_point_0) CVT_MULRND_CVT16(c_int16_1p0, scale_1, scale_2, zero_point_0) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_2x16: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S8_S16_MATRIX_ADD_1COL(selector1,0); + + // c[1:0-15] + S8_S16_MATRIX_ADD_1COL(selector1,1); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + uint8_t* matptr = ( uint8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + U8_S16_MATRIX_ADD_1COL(selector1,0); + + // c[1:0-15] + U8_S16_MATRIX_ADD_1COL(selector1,1); + } + else + { + int16_t* matptr = ( int16_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S16_S16_MATRIX_ADD_1COL(selector1,0); + + // c[1:0-15] + S16_S16_MATRIX_ADD_1COL(selector1,1); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_2x16: + { + selector1 = + _mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args2 ) ); + __m256 al = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32( \ + _mm256_extractf128_si256( selector1, 0 ) ) ); + + __m256 al_in, tmp_reg1, tmp_reg2, r, r2, z, dn; + __m256i ex_out; + + // c[0,0-15] + SWISH_S16_AVX2(c_int16_0p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[1,0-15] + SWISH_S16_AVX2(c_int16_1p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_2x16_DISABLE: @@ -1233,7 +1516,9 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_2xlt16) &&POST_OPS_GELU_TANH_2xlt16, &&POST_OPS_GELU_ERF_2xlt16, &&POST_OPS_CLIP_2xlt16, - &&POST_OPS_DOWNSCALE_2xlt16 + &&POST_OPS_DOWNSCALE_2xlt16, + &&POST_OPS_MATRIX_ADD_2xlt16, + &&POST_OPS_SWISH_2xlt16 }; // The division is done by considering the vpmaddubsw instruction @@ -1470,36 +1755,64 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_2xlt16) __m128i temp[2]; __m256i temp_32[2]; __m256 temp_float[2]; - __m256 scale_1, scale_2; - __m128i _zero_point_0; + __m256 scale_1 = _mm256_setzero_ps(); + __m256 scale_2 = _mm256_setzero_ps(); + __m128i _zero_point_0 = _mm_setzero_si128(); __m256i zero_point_0 = _mm256_setzero_si256(); __m256 res_1, res_2; - float float_buf[16]; + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + float float_buf[16] = { 0 }; + + memcpy( float_buf, ( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( float ) ) ); - memcpy( float_buf, ( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( float ) ) ); + // Load the scale vector values into the register + scale_1 = _mm256_loadu_ps(float_buf + (0 * 8)); + scale_2 = _mm256_loadu_ps(float_buf + (1 * 8)); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + /* Broadcast scale factor. */ + scale_1 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + scale_2 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } - // Load the scale vector values into the register - scale_1 = _mm256_loadu_ps(float_buf + (0 * 8)); - scale_2 = _mm256_loadu_ps(float_buf + (1 * 8)); + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t zero_point_buf[16]; + + memcpy( zero_point_buf, ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( int8_t ) ) ); + _zero_point_0 = _mm_loadu_si128( ( __m128i const* )zero_point_buf ); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + uint8_t zero_point_buf[16]; + + memcpy( zero_point_buf, ( ( uint8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( uint8_t ) ) ); + _zero_point_0 = _mm_loadu_si128( ( __m128i const* )zero_point_buf ); + } + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + // Broadcast zero point. + _zero_point_0 = _mm_set1_epi8( + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } if ( post_ops_attr.c_stor_type == S8 ) { - int8_t zero_point_buf[16]; - - memcpy( zero_point_buf, ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( int8_t ) ) ); - _zero_point_0 = _mm_loadu_si128( ( __m128i const* )zero_point_buf ); zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); } else if ( post_ops_attr.c_stor_type == U8 ) { - uint8_t zero_point_buf[16]; - - memcpy( zero_point_buf, ( ( uint8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( uint8_t ) ) ); - _zero_point_0 = _mm_loadu_si128( ( __m128i const* )zero_point_buf ); zero_point_0 = _mm256_cvtepu8_epi16( _zero_point_0 ); } @@ -1507,6 +1820,61 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_2xlt16) CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2, zero_point_0) CVT_MULRND_CVT16(c_int16_1p0, scale_1, scale_2, zero_point_0) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_2xlt16: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S8_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,0,n0_rem,int8_t); + + // c[1:0-15] + S8_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,1,n0_rem,int8_t); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + uint8_t* matptr = ( uint8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + U8_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,0,n0_rem,uint8_t); + + // c[1:0-15] + U8_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,1,n0_rem,uint8_t); + } + else + { + int16_t* matptr = ( int16_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S16_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,0,n0_rem,int16_t); + + // c[1:0-15] + S16_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,1,n0_rem,int16_t); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_2xlt16: + { + selector1 = + _mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args2 ) ); + __m256 al = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32( \ + _mm256_extractf128_si256( selector1, 0 ) ) ); + + __m256 al_in, tmp_reg1, tmp_reg2, r, r2, z, dn; + __m256i ex_out; + + // c[0,0-15] + SWISH_S16_AVX2(c_int16_0p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[1,0-15] + SWISH_S16_AVX2(c_int16_1p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_2xlt16_DISABLE: @@ -1562,7 +1930,9 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_1x16) &&POST_OPS_GELU_TANH_1x16, &&POST_OPS_GELU_ERF_1x16, &&POST_OPS_CLIP_1x16, - &&POST_OPS_DOWNSCALE_1x16 + &&POST_OPS_DOWNSCALE_1x16, + &&POST_OPS_MATRIX_ADD_1x16, + &&POST_OPS_SWISH_1x16 }; // The division is done by considering the vpmaddubsw instruction @@ -1729,26 +2099,44 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_1x16) __m128i temp[2]; __m256i temp_32[2]; __m256 temp_float[2]; - __m256 scale_1, scale_2; - __m128i _zero_point_0; + __m256 scale_1 = _mm256_setzero_ps(); + __m256 scale_2 = _mm256_setzero_ps(); + __m128i _zero_point_0 = _mm_setzero_si128(); __m256i zero_point_0 = _mm256_setzero_si256(); __m256 res_1, res_2; - /* Load the scale vector values into the register*/ - scale_1 = - _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + (0 * 8)); - scale_2 = - _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + (1 * 8)); - - // Load zero points (2 byte values). - _zero_point_0 = - _mm_loadu_si128( - ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + /* Load the scale vector values into the register*/ + scale_1 = _mm256_loadu_ps( + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 8 ) ); + scale_2 = _mm256_loadu_ps( + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 8 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + /* Broadcast scale factor. */ + scale_1 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + scale_2 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + // Load zero points (2 byte values). + _zero_point_0 = _mm_loadu_si128( + ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + // Broadcast zero point. + _zero_point_0 = _mm_set1_epi8( + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } if ( post_ops_attr.c_stor_type == S8 ) { zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); @@ -1761,6 +2149,49 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_1x16) // Scale first 16 columns of the 2 rows. CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2, zero_point_0) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_1x16: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S8_S16_MATRIX_ADD_1COL(selector1,0); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + uint8_t* matptr = ( uint8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + U8_S16_MATRIX_ADD_1COL(selector1,0); + } + else + { + int16_t* matptr = ( int16_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S16_S16_MATRIX_ADD_1COL(selector1,0); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_1x16: + { + selector1 = + _mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args2 ) ); + __m256 al = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32( \ + _mm256_extractf128_si256( selector1, 0 ) ) ); + + __m256 al_in, tmp_reg1, tmp_reg2, r, r2, z, dn; + __m256i ex_out; + + // c[0,0-15] + SWISH_S16_AVX2(c_int16_0p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_1x16_DISABLE: @@ -1802,7 +2233,9 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_1xlt16) &&POST_OPS_GELU_TANH_1xlt16, &&POST_OPS_GELU_ERF_1xlt16, &&POST_OPS_CLIP_1xlt16, - &&POST_OPS_DOWNSCALE_1xlt16 + &&POST_OPS_DOWNSCALE_1xlt16, + &&POST_OPS_MATRIX_ADD_1xlt16, + &&POST_OPS_SWISH_1xlt16 }; // The division is done by considering the vpmaddubsw instruction @@ -1981,42 +2414,113 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_1xlt16) __m128i temp[2]; __m256i temp_32[2]; __m256 temp_float[2]; - __m256 scale_1, scale_2; - __m128i _zero_point_0; + __m256 scale_1 = _mm256_setzero_ps(); + __m256 scale_2 = _mm256_setzero_ps(); + __m128i _zero_point_0 = _mm_setzero_si128(); __m256i zero_point_0 = _mm256_setzero_si256(); __m256 res_1, res_2; - float float_buf[16]; + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + float float_buf[16] = { 0 }; - memcpy( float_buf, ( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( float ) ) ); + memcpy( float_buf, ( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( float ) ) ); - // Load the scale vector values into the register - scale_1 = _mm256_loadu_ps(float_buf + (0 * 8)); - scale_2 = _mm256_loadu_ps(float_buf + (1 * 8)); + // Load the scale vector values into the register + scale_1 = _mm256_loadu_ps(float_buf + (0 * 8)); + scale_2 = _mm256_loadu_ps(float_buf + (1 * 8)); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + /* Broadcast scale factor. */ + scale_1 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + scale_2 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } - if ( post_ops_attr.c_stor_type == S8 ) + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) { - int8_t zero_point_buf[16]; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t zero_point_buf[16]; + + memcpy( zero_point_buf, ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( int8_t ) ) ); + _zero_point_0 = _mm_loadu_si128( ( __m128i const* )zero_point_buf ); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + uint8_t zero_point_buf[16]; + + memcpy( zero_point_buf, ( ( uint8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( uint8_t ) ) ); + _zero_point_0 = _mm_loadu_si128( ( __m128i const* )zero_point_buf ); + } + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + // Broadcast zero point. + _zero_point_0 = _mm_set1_epi8( + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } - memcpy( zero_point_buf, ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( int8_t ) ) ); - _zero_point_0 = _mm_loadu_si128( ( __m128i const* )zero_point_buf ); + if ( post_ops_attr.c_stor_type == S8 ) + { zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); } else if ( post_ops_attr.c_stor_type == U8 ) { - uint8_t zero_point_buf[16]; - - memcpy( zero_point_buf, ( ( uint8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( uint8_t ) ) ); - _zero_point_0 = _mm_loadu_si128( ( __m128i const* )zero_point_buf ); zero_point_0 = _mm256_cvtepu8_epi16( _zero_point_0 ); } // Scale first 16 columns of the 2 rows. CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2, zero_point_0) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_1xlt16: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S8_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,0,n0_rem,int8_t); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + uint8_t* matptr = ( uint8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + U8_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,0,n0_rem,uint8_t); + } + else + { + int16_t* matptr = ( int16_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S16_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,0,n0_rem,int16_t); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_1xlt16: + { + selector1 = + _mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args2 ) ); + __m256 al = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32( \ + _mm256_extractf128_si256( selector1, 0 ) ) ); + + __m256 al_in, tmp_reg1, tmp_reg2, r, r2, z, dn; + __m256i ex_out; + + // c[0,0-15] + SWISH_S16_AVX2(c_int16_0p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_1xlt16_DISABLE: diff --git a/kernels/zen/lpgemm/s8s8s16/lpgemm_s8_n_fringe_amd256.c b/kernels/zen/lpgemm/s8s8s16/lpgemm_s8_n_fringe_amd256.c index 36cad252a6..b3997cb23e 100644 --- a/kernels/zen/lpgemm/s8s8s16/lpgemm_s8_n_fringe_amd256.c +++ b/kernels/zen/lpgemm/s8s8s16/lpgemm_s8_n_fringe_amd256.c @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -54,7 +54,9 @@ LPGEMM_N_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6x16) &&POST_OPS_GELU_TANH_6x16, &&POST_OPS_GELU_ERF_6x16, &&POST_OPS_CLIP_6x16, - &&POST_OPS_DOWNSCALE_6x16 + &&POST_OPS_DOWNSCALE_6x16, + &&POST_OPS_MATRIX_ADD_6x16, + &&POST_OPS_SWISH_6x16 }; dim_t m_full_pieces = m0 / MR; @@ -504,26 +506,44 @@ LPGEMM_N_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6x16) __m128i temp[2]; __m256i temp_32[2]; __m256 temp_float[2]; - __m256 scale_1, scale_2; - __m128i _zero_point_0; + __m256 scale_1 = _mm256_setzero_ps(); + __m256 scale_2 = _mm256_setzero_ps(); + __m128i _zero_point_0 = _mm_setzero_si128(); __m256i zero_point_0 = _mm256_setzero_si256(); __m256 res_1, res_2; - /* Load the scale vector values into the register*/ - scale_1 = - _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + (0 * 8)); - scale_2 = - _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + (1 * 8)); - - // Load zero points (2 byte values). - _zero_point_0 = - _mm_loadu_si128( - ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + /* Load the scale vector values into the register*/ + scale_1 = _mm256_loadu_ps( + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 8 ) ); + scale_2 = _mm256_loadu_ps( + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 8 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + /* Broadcast scale factor. */ + scale_1 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + scale_2 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + // Load zero points (2 byte values). + _zero_point_0 = _mm_loadu_si128( + ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + // Broadcast zero point. + _zero_point_0 = _mm_set1_epi8( + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } if ( post_ops_attr.c_stor_type == S8 ) { zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); @@ -541,6 +561,109 @@ LPGEMM_N_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6x16) CVT_MULRND_CVT16(c_int16_4p0, scale_1, scale_2, zero_point_0) CVT_MULRND_CVT16(c_int16_5p0, scale_1, scale_2, zero_point_0) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_6x16: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S8_S16_MATRIX_ADD_1COL(selector1,0); + + // c[1:0-15] + S8_S16_MATRIX_ADD_1COL(selector1,1); + + // c[2:0-15] + S8_S16_MATRIX_ADD_1COL(selector1,2); + + // c[3:0-15] + S8_S16_MATRIX_ADD_1COL(selector1,3); + + // c[4:0-15] + S8_S16_MATRIX_ADD_1COL(selector1,4); + + // c[5:0-15] + S8_S16_MATRIX_ADD_1COL(selector1,5); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + uint8_t* matptr = ( uint8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + U8_S16_MATRIX_ADD_1COL(selector1,0); + + // c[1:0-15] + U8_S16_MATRIX_ADD_1COL(selector1,1); + + // c[2:0-15] + U8_S16_MATRIX_ADD_1COL(selector1,2); + + // c[3:0-15] + U8_S16_MATRIX_ADD_1COL(selector1,3); + + // c[4:0-15] + U8_S16_MATRIX_ADD_1COL(selector1,4); + + // c[5:0-15] + U8_S16_MATRIX_ADD_1COL(selector1,5); + } + else + { + int16_t* matptr = ( int16_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S16_S16_MATRIX_ADD_1COL(selector1,0); + + // c[1:0-15] + S16_S16_MATRIX_ADD_1COL(selector1,1); + + // c[2:0-15] + S16_S16_MATRIX_ADD_1COL(selector1,2); + + // c[3:0-15] + S16_S16_MATRIX_ADD_1COL(selector1,3); + + // c[4:0-15] + S16_S16_MATRIX_ADD_1COL(selector1,4); + + // c[5:0-15] + S16_S16_MATRIX_ADD_1COL(selector1,5); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_6x16: + { + selector1 = + _mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args2 ) ); + __m256 al = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32( \ + _mm256_extractf128_si256( selector1, 0 ) ) ); + + __m256 al_in, tmp_reg1, tmp_reg2, r, r2, z, dn; + __m256i ex_out; + + // c[0,0-15] + SWISH_S16_AVX2(c_int16_0p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[1,0-15] + SWISH_S16_AVX2(c_int16_1p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[2,0-15] + SWISH_S16_AVX2(c_int16_2p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[3,0-15] + SWISH_S16_AVX2(c_int16_3p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[4,0-15] + SWISH_S16_AVX2(c_int16_4p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[5,0-15] + SWISH_S16_AVX2(c_int16_5p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_6x16_DISABLE: @@ -659,7 +782,9 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6xlt16) &&POST_OPS_GELU_TANH_6xlt16, &&POST_OPS_GELU_ERF_6xlt16, &&POST_OPS_CLIP_6xlt16, - &&POST_OPS_DOWNSCALE_6xlt16 + &&POST_OPS_DOWNSCALE_6xlt16, + &&POST_OPS_MATRIX_ADD_6xlt16, + &&POST_OPS_SWISH_6xlt16 }; dim_t m_full_pieces = m0 / MR; @@ -1135,36 +1260,64 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6xlt16) __m128i temp[2]; __m256i temp_32[2]; __m256 temp_float[2]; - __m256 scale_1, scale_2; - __m128i _zero_point_0; + __m256 scale_1 = _mm256_setzero_ps(); + __m256 scale_2 = _mm256_setzero_ps(); + __m128i _zero_point_0 = _mm_setzero_si128(); __m256i zero_point_0 = _mm256_setzero_si256(); __m256 res_1, res_2; - float float_buf[16]; + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + float float_buf[16] = { 0 }; + + memcpy( float_buf, ( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( float ) ) ); - memcpy( float_buf, ( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( float ) ) ); + // Load the scale vector values into the register + scale_1 = _mm256_loadu_ps(float_buf + (0 * 8)); + scale_2 = _mm256_loadu_ps(float_buf + (1 * 8)); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + /* Broadcast scale factor. */ + scale_1 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + scale_2 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } - // Load the scale vector values into the register - scale_1 = _mm256_loadu_ps(float_buf + (0 * 8)); - scale_2 = _mm256_loadu_ps(float_buf + (1 * 8)); + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t zero_point_buf[16]; + + memcpy( zero_point_buf, ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( int8_t ) ) ); + _zero_point_0 = _mm_loadu_si128( ( __m128i const* )zero_point_buf ); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + uint8_t zero_point_buf[16]; + + memcpy( zero_point_buf, ( ( uint8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( uint8_t ) ) ); + _zero_point_0 = _mm_loadu_si128( ( __m128i const* )zero_point_buf ); + } + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + // Broadcast zero point. + _zero_point_0 = _mm_set1_epi8( + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } if ( post_ops_attr.c_stor_type == S8 ) { - int8_t zero_point_buf[16]; - - memcpy( zero_point_buf, ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( int8_t ) ) ); - _zero_point_0 = _mm_loadu_si128( ( __m128i const* )zero_point_buf ); zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); } else if ( post_ops_attr.c_stor_type == U8 ) { - uint8_t zero_point_buf[16]; - - memcpy( zero_point_buf, ( ( uint8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( uint8_t ) ) ); - _zero_point_0 = _mm_loadu_si128( ( __m128i const* )zero_point_buf ); zero_point_0 = _mm256_cvtepu8_epi16( _zero_point_0 ); } @@ -1176,6 +1329,109 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6xlt16) CVT_MULRND_CVT16(c_int16_4p0, scale_1, scale_2, zero_point_0) CVT_MULRND_CVT16(c_int16_5p0, scale_1, scale_2, zero_point_0) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_6xlt16: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S8_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,0,n0_rem,int8_t); + + // c[1:0-15] + S8_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,1,n0_rem,int8_t); + + // c[2:0-15] + S8_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,2,n0_rem,int8_t); + + // c[3:0-15] + S8_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,3,n0_rem,int8_t); + + // c[4:0-15] + S8_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,4,n0_rem,int8_t); + + // c[5:0-15] + S8_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,5,n0_rem,int8_t); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + uint8_t* matptr = ( uint8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + U8_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,0,n0_rem,uint8_t); + + // c[1:0-15] + U8_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,1,n0_rem,uint8_t); + + // c[2:0-15] + U8_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,2,n0_rem,uint8_t); + + // c[3:0-15] + U8_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,3,n0_rem,uint8_t); + + // c[4:0-15] + U8_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,4,n0_rem,uint8_t); + + // c[5:0-15] + U8_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,5,n0_rem,uint8_t); + } + else + { + int16_t* matptr = ( int16_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S16_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,0,n0_rem,int16_t); + + // c[1:0-15] + S16_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,1,n0_rem,int16_t); + + // c[2:0-15] + S16_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,2,n0_rem,int16_t); + + // c[3:0-15] + S16_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,3,n0_rem,int16_t); + + // c[4:0-15] + S16_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,4,n0_rem,int16_t); + + // c[5:0-15] + S16_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,5,n0_rem,int16_t); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_6xlt16: + { + selector1 = + _mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args2 ) ); + __m256 al = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32( \ + _mm256_extractf128_si256( selector1, 0 ) ) ); + + __m256 al_in, tmp_reg1, tmp_reg2, r, r2, z, dn; + __m256i ex_out; + + // c[0,0-15] + SWISH_S16_AVX2(c_int16_0p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[1,0-15] + SWISH_S16_AVX2(c_int16_1p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[2,0-15] + SWISH_S16_AVX2(c_int16_2p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[3,0-15] + SWISH_S16_AVX2(c_int16_3p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[4,0-15] + SWISH_S16_AVX2(c_int16_4p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[5,0-15] + SWISH_S16_AVX2(c_int16_5p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_6xlt16_DISABLE: diff --git a/kernels/zen/lpgemm/s8s8s16/lpgemm_s8_packb_amd256.c b/kernels/zen/lpgemm/s8s8s16/lpgemm_s8_packb_amd256.c index 5fa9879a51..def8196d4c 100644 --- a/kernels/zen/lpgemm/s8s8s16/lpgemm_s8_packb_amd256.c +++ b/kernels/zen/lpgemm/s8s8s16/lpgemm_s8_packb_amd256.c @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/kernels/zen/lpgemm/s8s8s16/lpgemv_n_kernel_amd256.c b/kernels/zen/lpgemm/s8s8s16/lpgemv_n_kernel_amd256.c new file mode 100644 index 0000000000..832af96f4a --- /dev/null +++ b/kernels/zen/lpgemm/s8s8s16/lpgemv_n_kernel_amd256.c @@ -0,0 +1,855 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "blis.h" + +#ifdef BLIS_ADDON_LPGEMM + +#include "../u8s8s16/lpgemm_s16_kern_macros.h" + +#define LPGEMV_N_KERNEL_2_LOADS( ymm0, ymm1, paddr, stride ) \ + ymm0 = _mm256_loadu_si256( (__m256i const *)paddr ); \ + ymm1 = _mm256_loadu_si256( (__m256i const *)(paddr + stride) ); \ + ymm0 = _mm256_add_epi8( ymm0, vec_uint8 ); \ + ymm1 = _mm256_add_epi8( ymm1, vec_uint8 ); + +#define LPGEMV_N_KERNEL_2_FMA( a_reg1, a_reg2, b_reg, \ + inter_reg1, inter_reg2, c_reg1, c_reg2 ) \ + inter_reg1 = _mm256_maddubs_epi16(a_reg1, b_reg); \ + c_reg1 = _mm256_add_epi16(inter_reg1, c_reg1); \ + inter_reg2 = _mm256_maddubs_epi16(a_reg2, b_reg); \ + c_reg2 = _mm256_add_epi16(inter_reg2, c_reg2); + + +#define LPGEMV_N_KERNEL_4_LOADS( ymm0, ymm1, ymm2, ymm3, paddr, stride ) \ + ymm0 = _mm256_loadu_si256( (__m256i const *)(paddr) ); \ + ymm1 = _mm256_loadu_si256( (__m256i const *)(paddr + stride) ); \ + ymm2 = _mm256_loadu_si256( (__m256i const *)(paddr + 2 * stride) ); \ + ymm3 = _mm256_loadu_si256( (__m256i const *)(paddr + 3 * stride) ); \ + ymm0 = _mm256_add_epi8( ymm0, vec_uint8 ); \ + ymm1 = _mm256_add_epi8( ymm1, vec_uint8 ); \ + ymm2 = _mm256_add_epi8( ymm2, vec_uint8 ); \ + ymm3 = _mm256_add_epi8( ymm3, vec_uint8 ); + +#define LPGEMV_N_KERNEL_4_FMA( a_reg1, a_reg2, a_reg3, a_reg4, b_reg, \ + inter_reg1, inter_reg2, \ + inter_reg3, inter_reg4, \ + out_reg1, out_reg2, out_reg3, out_reg4 ) \ + inter_reg1 = _mm256_maddubs_epi16(a_reg1, b_reg); \ + out_reg1 = _mm256_add_epi16(inter_reg1, out_reg1); \ + inter_reg2 = _mm256_maddubs_epi16(a_reg2, b_reg); \ + out_reg2 = _mm256_add_epi16(inter_reg2, out_reg2); \ + inter_reg3 = _mm256_maddubs_epi16(a_reg3, b_reg); \ + out_reg3 = _mm256_add_epi16(inter_reg3, out_reg3); \ + inter_reg4 = _mm256_maddubs_epi16(a_reg4, b_reg); \ + out_reg4 = _mm256_add_epi16(inter_reg4, out_reg4); + +#define LPGEMV_YMM2XMM( ymm0, ymm1, ymm2, ymm3, xmm0 ) \ + ymm0 = _mm256_hadd_epi16( ymm0, ymm1 ); \ + ymm1 = _mm256_hadd_epi16( ymm2, ymm3 ); \ + ymm0 = _mm256_hadd_epi16( ymm0, ymm1 ); \ + xmm0 = _mm_add_epi16( _mm256_extracti128_si256( ymm0, 0 ), \ + _mm256_extracti128_si256( ymm0, 1 ) ); + + + +LPGEMV_N_EQ1_KERN(int8_t, int8_t, int16_t, s8s8s16os16) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_DISABLE, + &&POST_OPS_BIAS, + &&POST_OPS_RELU, + &&POST_OPS_RELU_SCALE, + &&POST_OPS_GELU_TANH, + &&POST_OPS_GELU_ERF, + &&POST_OPS_CLIP, + &&POST_OPS_DOWNSCALE, + &&POST_OPS_MATRIX_ADD, + &&POST_OPS_SWISH + }; + + int8_t *a_use = NULL; + int8_t *b_use = NULL; + int16_t *c_use = NULL; + + lpgemm_post_op_attr post_ops_attr = *(post_op_attr); + + // temp buffer to store output C vector + int16_t ctemp[16]; + + // temp buffers to store a, b data in k_rem case. + int8_t buf0[32] = {0}; + int8_t buf1[32] = {0}; + int8_t buf2[32] = {0}; + int8_t buf3[32] = {0}; + int8_t buf4[32] = {0}; + int8_t buf5[32] = {0}; + int8_t buf6[32] = {0}; + int8_t buf7[32] = {0}; + int8_t buf8[32] = {0}; + + + uint8_t cvt_uint8 = 128; + __m256i vec_uint8; + + int16_t* bsumptr = post_ops_attr.b_col_sum_vec_s16; + + for ( dim_t ir = 0; ir < m0; ir += MR ) + { + dim_t mr0 = bli_min( ( m0 - ir ), MR ); + dim_t k_iter = k / 32; + dim_t k_rem = k % 32; + + __m256i ymm0, ymm1, ymm2, ymm3, ymm4, ymm5, ymm6, ymm7; + __m256i ymm8, ymm9, ymm10, ymm11, ymm12, ymm13, ymm14; + __m256i ymm15; + + __m128i xmm0, xmm1; + + /* zero the accumulator registers */ + ZERO_ACC_YMM_4_REG( ymm8, ymm9, ymm10, ymm11 ) + ZERO_ACC_YMM_4_REG( ymm12, ymm13, ymm14, ymm15 ) + + //update pointers + a_use = (int8_t*)a + ir * rs_a; + b_use = (int8_t*)b; + c_use = (int16_t*)c + ir * rs_c; + + if( mr0 == MR ) + { + vec_uint8 = _mm256_set1_epi8 (cvt_uint8); + + for (dim_t k = 0; k < k_iter; k++) + { + + ymm6 = _mm256_loadu_si256( (__m256i const *)(b_use) ); + b_use += 32; + + //Load 4x32 elements from row0-row3 of A + LPGEMV_N_KERNEL_4_LOADS( ymm0, ymm1, ymm2, ymm3, a_use, rs_a ) + + LPGEMV_N_KERNEL_4_FMA( ymm0, ymm1, ymm2, ymm3, + ymm6, ymm4, ymm5, ymm7, ymm4, + ymm8, ymm9, ymm10, ymm11 + ) + + // Load 4x32 elements from row8-row11 of A + LPGEMV_N_KERNEL_4_LOADS( ymm0, ymm1, ymm2, ymm3, + ( a_use + 4 * rs_a ), rs_a + ) + + LPGEMV_N_KERNEL_4_FMA( ymm0, ymm1, ymm2, ymm3, + ymm6, ymm4, ymm5, ymm7, ymm4, + ymm12, ymm13, ymm14, ymm15 + ) + + a_use += 32; + } + + + + if( k_rem ) + { + uint8_t buf_vec_uint8_t[32] = {0}; + int8_t* restrict a0 = (a_use); + int8_t* restrict a1 = (a_use + rs_a ); + int8_t* restrict a2 = (a_use + 2 * rs_a ); + int8_t* restrict a3 = (a_use + 3 * rs_a ); + int8_t* restrict a4 = (a_use + 4 * rs_a ); + int8_t* restrict a5 = (a_use + 5 * rs_a ); + int8_t* restrict a6 = (a_use + 6 * rs_a ); + int8_t* restrict a7 = (a_use + 7 * rs_a ); + + for( dim_t i = 0; i < k_rem; i++) + { + buf8[i] = b_use[i]; + buf0[i] = a0[i]; + buf1[i] = a1[i]; + buf2[i] = a2[i]; + buf3[i] = a3[i]; + buf4[i] = a4[i]; + buf5[i] = a5[i]; + buf6[i] = a6[i]; + buf7[i] = a7[i]; + buf_vec_uint8_t[i] = cvt_uint8; + } + ymm6 = _mm256_loadu_si256( (__m256i const *)buf8 ); + + vec_uint8 = _mm256_loadu_si256( ( __m256i const *) buf_vec_uint8_t ); + + //Load 4x32 elements from row0-row3 of A + ymm0 = _mm256_loadu_si256( (__m256i const *)buf0 ); + ymm1 = _mm256_loadu_si256( (__m256i const *)buf1 ); + ymm2 = _mm256_loadu_si256( (__m256i const *)buf2 ); + ymm3 = _mm256_loadu_si256( (__m256i const *)buf3 ); + + ymm0 = _mm256_add_epi8( ymm0, vec_uint8 ); + ymm1 = _mm256_add_epi8( ymm1, vec_uint8 ); + ymm2 = _mm256_add_epi8( ymm2, vec_uint8 ); + ymm3 = _mm256_add_epi8( ymm3, vec_uint8 ); + + LPGEMV_N_KERNEL_4_FMA( ymm0, ymm1, ymm2, ymm3, + ymm6, ymm4, ymm5, ymm7, ymm4, + ymm8, ymm9, ymm10, ymm11 + ) + + // Load 4x32 elements from row8-row11 of A + ymm0 = _mm256_loadu_si256( (__m256i const *)buf4 ); + ymm1 = _mm256_loadu_si256( (__m256i const *)buf5 ); + ymm2 = _mm256_loadu_si256( (__m256i const *)buf6 ); + ymm3 = _mm256_loadu_si256( (__m256i const *)buf7 ); + + ymm0 = _mm256_add_epi8( ymm0, vec_uint8 ); + ymm1 = _mm256_add_epi8( ymm1, vec_uint8 ); + ymm2 = _mm256_add_epi8( ymm2, vec_uint8 ); + ymm3 = _mm256_add_epi8( ymm3, vec_uint8 ); + + LPGEMV_N_KERNEL_4_FMA( ymm0, ymm1, ymm2, ymm3, + ymm6, ymm4, ymm5, ymm7, ymm4, + ymm12, ymm13, ymm14, ymm15 + ) + + } + //Add the registers horizantally to get one + LPGEMV_YMM2XMM( ymm8, ymm9, ymm10, ymm11, xmm0 ) + LPGEMV_YMM2XMM( ymm12, ymm13, ymm14, ymm15, xmm1 ) + + xmm0 = _mm_hadd_epi16( xmm0, xmm1 ); + + // post ops are applied on ymm register though + // second half of the register is filled with zeroes. + ymm8 = _mm256_setzero_si256(); + ymm8 = _mm256_inserti128_si256( ymm8, xmm0, 0); + + ymm0 = _mm256_set1_epi16( *bsumptr ); + ymm8 = _mm256_sub_epi16( ymm8, ymm0 ); + } + else + { + int8_t *a_use_fringe = a_use; + dim_t mr0_use = mr0; + dim_t regidx = 0; + + if( mr0_use >= 4 ) + { + vec_uint8 = _mm256_set1_epi8 (cvt_uint8); + + for (dim_t k = 0; k < k_iter; k++) + { + ymm6 = _mm256_loadu_si256( (__m256i const *)b_use ); + b_use += 32; + + //Load 4x32 elements from row0-row3 of A + LPGEMV_N_KERNEL_4_LOADS( ymm0, ymm1, ymm2, ymm3, + a_use, rs_a ) + + LPGEMV_N_KERNEL_4_FMA( ymm0, ymm1, ymm2, ymm3, + ymm6, ymm4, ymm5, ymm7, ymm4, + ymm8, ymm9, ymm10, ymm11 + ) + + a_use += 32; + } + + if( k_rem ) + { + uint8_t buf_vec_uint8_t[32] = {0}; + int8_t* restrict a0 = (a_use); + int8_t* restrict a1 = (a_use + rs_a ); + int8_t* restrict a2 = (a_use + 2 * rs_a ); + int8_t* restrict a3 = (a_use + 3 * rs_a ); + + for( dim_t i = 0; i < k_rem; i++) + { + buf8[i] = b_use[i]; + buf0[i] = a0[i]; + buf1[i] = a1[i]; + buf2[i] = a2[i]; + buf3[i] = a3[i]; + buf_vec_uint8_t[i] = cvt_uint8; + } + ymm6 = _mm256_loadu_si256( (__m256i const *)buf8 ); + + vec_uint8 = _mm256_loadu_si256( (__m256i const *)buf_vec_uint8_t ); + //Load 4xk_rem elements from row0-row3 of A + + ymm0 = _mm256_loadu_si256( (__m256i const *)buf0 ); + ymm1 = _mm256_loadu_si256( (__m256i const *)buf1 ); + ymm2 = _mm256_loadu_si256( (__m256i const *)buf2 ); + ymm3 = _mm256_loadu_si256( (__m256i const *)buf3 ); + + ymm0 = _mm256_add_epi8( ymm0, vec_uint8 ); + ymm1 = _mm256_add_epi8( ymm1, vec_uint8 ); + ymm2 = _mm256_add_epi8( ymm2, vec_uint8 ); + ymm3 = _mm256_add_epi8( ymm3, vec_uint8 ); + + LPGEMV_N_KERNEL_4_FMA( ymm0, ymm1, ymm2, ymm3, + ymm6, ymm4, ymm5, ymm7, ymm4, + ymm8, ymm9, ymm10, ymm11 + ) + } + + //update pointers + mr0_use -= 4; + a_use = a_use_fringe + 4 * rs_a; + a_use_fringe = a_use; + b_use = (int8_t*)b; + + //Add the registers horizantally to get one + LPGEMV_YMM2XMM( ymm8, ymm9, ymm10, ymm11, xmm0 ) + + xmm0 = _mm_hadd_epi16( xmm0, xmm0 ); + + int64_t data = _mm_extract_epi64( xmm0, 0); + //insert xmm outputs into final output reg based on regidx + ymm8 = _mm256_setzero_si256(); + ymm8 = _mm256_insert_epi64( ymm8, data, 0 ); + regidx++; + } + + // Dot product for <= 3 + if ( mr0_use ) + { + // Dot product for m = 2 + if ( mr0_use >= 2 ) + { + vec_uint8 = _mm256_set1_epi8 (cvt_uint8); + + for ( dim_t k = 0; k < k_iter; k++ ) + { + // Load 0-31 in b[k+0 - k+31] + ymm6 = _mm256_loadu_si256( (__m256i const *)b_use ); + + LPGEMV_N_KERNEL_2_LOADS( ymm0, ymm1, a_use, rs_a); + + LPGEMV_N_KERNEL_2_FMA( ymm0, ymm1, ymm6, ymm4, + ymm5, ymm12, ymm13); + b_use += 32; // move b pointer to next 32 elements + a_use += 32; + } + if ( k_rem ) + { + uint8_t buf_vec_uint8_t[32] = {0}; + int8_t* restrict a0 = (a_use); + int8_t* restrict a1 = (a_use + rs_a ); + + for( dim_t i = 0; i < k_rem; i++) + { + buf8[i] = b_use[i]; + buf0[i] = a0[i]; + buf1[i] = a1[i]; + buf_vec_uint8_t[i] = cvt_uint8; + } + ymm6 = _mm256_loadu_si256( (__m256i const *)buf8 ); + + vec_uint8 = _mm256_loadu_si256( (__m256i const *)buf_vec_uint8_t ); + //Load 2xk_rem elements from row0-row3 of A + + ymm0 = _mm256_loadu_si256( (__m256i const *)buf0 ); + ymm1 = _mm256_loadu_si256( (__m256i const *)buf1 ); + + ymm0 = _mm256_add_epi8( ymm0, vec_uint8 ); + ymm1 = _mm256_add_epi8( ymm1, vec_uint8 ); + + LPGEMV_N_KERNEL_2_FMA( ymm0, ymm1, ymm6, + ymm4, ymm5, ymm12, ymm13 ); + } + + mr0_use -= 2; + a_use = a_use_fringe + 2 * rs_a; + a_use_fringe = a_use; + b_use = (int8_t*)b; + } + + // Dot product for m = 1 + if ( mr0_use == 1 ) + { + vec_uint8 = _mm256_set1_epi8 (cvt_uint8); + + for ( dim_t k = 0; k < k_iter; k++ ) + { + // Load 0-31 in b[k+0 - k+31] + ymm6 = _mm256_loadu_si256( (__m256i const *)b_use ); + + // Load 1x32 elements from row0-row1 of A + ymm0 = _mm256_loadu_si256( (__m256i const *)a_use ); + ymm0 = _mm256_add_epi8( ymm0, vec_uint8 ); + + ymm4 = _mm256_maddubs_epi16(ymm0, ymm6); + ymm14 = _mm256_add_epi16(ymm4, ymm14); + + b_use += 32; // move b pointer to next 32 elements + a_use += 32; + } + if ( k_rem ) + { + uint8_t buf_vec_uint8_t[32] = {0}; + int8_t* restrict a0 = (a_use); + + for( dim_t i = 0; i < k_rem; i++) + { + buf8[i] = b_use[i]; + buf0[i] = a0[i]; + buf_vec_uint8_t[i] = cvt_uint8; + } + ymm6 = _mm256_loadu_si256( (__m256i const *)buf8 ); + + vec_uint8 = _mm256_loadu_si256( (__m256i const *)buf_vec_uint8_t ); + + //Load 1xk_rem elements from row0-row3 of A + + ymm0 = _mm256_loadu_si256( (__m256i const *)buf0 ); + ymm0 = _mm256_add_epi8( ymm0, vec_uint8 ); + + ymm4 = _mm256_maddubs_epi16(ymm0, ymm6); + ymm14 = _mm256_add_epi16(ymm4, ymm14); + } + + // When only fringe 1, + // update the registers to store in order + if ( !( mr0 & 0x2 ) ) ymm12 = ymm14; + } + + LPGEMV_YMM2XMM( ymm12, ymm13, ymm14, ymm15, xmm0) + xmm0 = _mm_hadd_epi16( xmm0, xmm0 ); + + int64_t data = _mm_extract_epi64( xmm0, 0); + //insert xmm outputs into final output reg based on regidx + + if( regidx == 0 ) + { + ymm8 = _mm256_insert_epi64( ymm8, data, 0 ); + } + else + { + ymm8 = _mm256_insert_epi64( ymm8, data, 1 ); + } + + } + + int16_t buf_vec_int16_t[16] = {0}; + for( dim_t i = 0; i < mr0; i++) + buf_vec_int16_t[i] = *bsumptr; + ymm0 = _mm256_loadu_si256( ( __m256i const *) buf_vec_int16_t); + ymm8 = _mm256_sub_epi16( ymm8, ymm0 ); + } + + // Load alpha and beta + __m256i selector1 = _mm256_set1_epi16(alpha); + __m256i selector2 = _mm256_set1_epi16(beta); + + // Scale by alpha + ymm8 = _mm256_mullo_epi16(selector1, ymm8); + + if( beta != 0 ) + { + if ( post_ops_attr.buf_downscale != NULL ) + { + if( post_ops_attr.rs_c_downscale == 1 ) + { + if( post_ops_attr.c_stor_type == S8 ) + { + dim_t m0_rem_dscale_bytes = mr0 * sizeof( int8_t ); + + S8_S16_BETA_NLT16_MEMCP_UTIL( ctemp, 0, + m0_rem_dscale_bytes ); + + S8_S16_BETA_OP_NLT16( ymm8, ctemp, + selector1, selector2 ) + } + else if( post_ops_attr.c_stor_type == U8 ) + { + dim_t m0_rem_dscale_bytes = mr0 * sizeof( uint8_t ); + + U8_S16_BETA_NLT16_MEMCP_UTIL( ctemp, 0, + m0_rem_dscale_bytes ); + + U8_S16_BETA_OP_NLT16( ymm8, ctemp, + selector1, selector2 ) + } + } + else + { + if( post_ops_attr.c_stor_type == S8 ) + { + int8_t ctemp[16]; + for( dim_t i = 0; i < mr0; i++ ) + { + ctemp[i] = *( (int8_t*)post_ops_attr.buf_downscale + + ( post_ops_attr.rs_c_downscale * + ( post_ops_attr.post_op_c_i + i ) ) ); + } + selector1 = _mm256_cvtepi8_epi32 + ( _mm_loadu_si128( (__m128i const*)ctemp ) ); + S16_BETA_FMA( ymm8, selector1, selector2 ); + } + else if( post_ops_attr.c_stor_type == U8 ) + { + uint8_t ctemp[16]; + for( dim_t i = 0; i < mr0; i++ ) + { + ctemp[i] = *( (uint8_t*)post_ops_attr.buf_downscale + + ( post_ops_attr.rs_c_downscale * + ( post_ops_attr.post_op_c_i + i ) ) ); + } + selector1 = _mm256_cvtepu8_epi32 + ( _mm_loadu_si128( (__m128i const*)ctemp ) ); + S16_BETA_FMA( ymm8, selector1, selector2 ); + } + } + } + else + { + if( rs_c == 1 ) + { + dim_t m0_rem_bytes = mr0 * sizeof( int16_t ); + memcpy( ctemp, c_use, m0_rem_bytes ); + S16_S16_BETA_OP_NLT16( ymm8, ctemp, + selector1, selector2 ) + } + else + { + for( dim_t i = 0; i < mr0; i++ ) + { + ctemp[i] = c_use[ i * rs_c ]; + } + selector1 = _mm256_loadu_si256( (__m256i const *)ctemp ); + S16_BETA_FMA( ymm8, selector1, selector2 ); + } + } + } + + // Post Ops + lpgemm_post_op * post_ops_list_temp = post_op; + + post_ops_attr.is_last_k = TRUE; + POST_OP_LABEL_LASTK_SAFE_JUMP + + + POST_OPS_BIAS: + { + + + selector1 = + _mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args1) ); + + ymm8 = _mm256_add_epi16( selector1, ymm8 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_RELU: + { + selector1 = _mm256_setzero_si256(); + + ymm8 = _mm256_max_epi16( selector1, ymm8 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_RELU_SCALE: + { + __m256i b0; + selector1 = _mm256_setzero_si256(); + selector2 = _mm256_set1_epi16( + *( ( int16_t* )post_ops_list_temp->op_args2 ) ); + + RELU_SCALE_OP_S16_AVX2( ymm8 ) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_GELU_TANH: + { + __m256 dn, z, x, r2, r, y1, y2, x_tanh; + __m256i q; + + GELU_TANH_S16_AVX2( ymm8, y1, y2, r, r2, x, z, dn, x_tanh, q ) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_GELU_ERF: + { + __m256 x, r, y1, y2, x_erf; + + GELU_ERF_S16_AVX2(ymm8, y1, y2, r, x, x_erf) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_CLIP: + { + __m256i min = _mm256_set1_epi16( + *( int16_t* )post_ops_list_temp->op_args2 ); + __m256i max = _mm256_set1_epi16( + *( int16_t* )post_ops_list_temp->op_args3 ); + + CLIP_S16_AVX2(ymm8, min, max) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_DOWNSCALE: + { + __m128i temp[2]; + __m256i temp_32[2]; + __m256 temp_float[2]; + __m256 scale_1 = _mm256_setzero_ps(); + __m256 scale_2 = _mm256_setzero_ps(); + __m128i _zero_point_0 = _mm_setzero_si128(); + __m256i zero_point_0 = _mm256_setzero_si256(); + __m256 res_1, res_2; + + scale_1 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + + scale_2 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + + _zero_point_0 = _mm_set1_epi8( + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + + if ( post_ops_attr.c_stor_type == S8 ) + { + zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + zero_point_0 = _mm256_cvtepu8_epi16( _zero_point_0 ); + } + + // Scale first 16 columns of the 2 rows. + CVT_MULRND_CVT16(ymm8, scale_1, scale_2, zero_point_0) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + + POST_OPS_MATRIX_ADD: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + if( ldm == 1 ) + { + memcpy + ( + ( int8_t* )ctemp, + matptr + ( ( post_ops_attr.post_op_c_i ) * ldm ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ), + ( mr0 ) * sizeof(int8_t) + ); + selector1 = _mm256_cvtepi8_epi16( + _mm_loadu_si128( ( __m128i const* )ctemp ) ); + ymm8 = _mm256_add_epi16( selector1, ymm8 ); + } + else + { + int8_t ctemp[16]; + for( dim_t i = 0; i < mr0; i++ ) + { + ctemp[i] = *( matptr + + ( ( post_ops_attr.post_op_c_i + i ) + * ldm ) ); + } + selector1 = _mm256_cvtepi8_epi16 + ( _mm_loadu_si128( (__m128i const*)ctemp ) ); + ymm8 = _mm256_add_epi16( selector1, ymm8 ); + } + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + uint8_t* matptr = ( uint8_t* )post_ops_list_temp->op_args1; + + if( ldm == 1 ) + { + memcpy + ( + ( uint8_t* )ctemp, + matptr + ( ( post_ops_attr.post_op_c_i ) * ldm ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ), + ( mr0 ) * sizeof(uint8_t) + ); + selector1 = _mm256_cvtepu8_epi16( + _mm_loadu_si128( ( __m128i const* )ctemp ) ); + ymm8 = _mm256_add_epi16( selector1, ymm8 ); + } + else + { + uint8_t ctemp[16]; + for( dim_t i = 0; i < mr0; i++ ) + { + ctemp[i] = *( matptr + + ( ( post_ops_attr.post_op_c_i + i ) + * ldm ) ); + } + selector1 = _mm256_cvtepu8_epi16 + ( _mm_loadu_si128( (__m128i const*)ctemp ) ); + ymm8 = _mm256_add_epi16( selector1, ymm8 ); + } + } + else + { + int16_t* matptr = ( int16_t* )post_ops_list_temp->op_args1; + + if( ldm == 1 ) + { + memcpy + ( + ( int16_t* )ctemp, + matptr + ( ( post_ops_attr.post_op_c_i ) * ldm ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ), + ( mr0 ) * sizeof(int16_t) + ); + + selector1 = _mm256_loadu_si256( ( __m256i const* )ctemp ); + + ymm8 = _mm256_add_epi16( selector1, ymm8 ); + } + else + { + int16_t ctemp[16]; + for( dim_t i = 0; i < mr0; i++ ) + { + ctemp[i] = *( matptr + + ( ( post_ops_attr.post_op_c_i + i ) + * ldm ) ); + } + selector1 = _mm256_loadu_si256( (__m256i const *)ctemp ); + ymm8 = _mm256_add_epi16( selector1, ymm8 ); + } + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_SWISH: + { + selector1 = + _mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args2 ) ); + __m256 al = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32( \ + _mm256_extractf128_si256( selector1, 0 ) ) ); + + __m256 al_in, tmp_reg1, tmp_reg2, r, r2, z, dn; + __m256i ex_out; + + SWISH_S16_AVX2( ymm8, al, al_in, tmp_reg1, + tmp_reg2, r, r2, z, dn, ex_out ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_DISABLE: + { + if ( post_ops_attr.buf_downscale != NULL ) + { + __m128i temp[2]; + __m256i zero_reg = _mm256_setzero_si256(); + if( post_ops_attr.rs_c_downscale == 1 ) + { + if( post_ops_attr.c_stor_type == S8 ) + { + // Store the results in downscaled type + // (int8 instead of int16). + CVT_STORE_S16_S8_1ROW_NLT16(ymm8, zero_reg, ctemp); + + dim_t m0_rem_dscale_bytes = mr0 * sizeof( int8_t ); + + CVT_STORE_S16_S8_NLT16_MEMCP_UTIL( ctemp, 0, + m0_rem_dscale_bytes); + } + else if( post_ops_attr.c_stor_type == U8 ) + { + // Store the results in downscaled type (uint8 instead of int16). + CVT_STORE_S16_U8_1ROW_NLT16(ymm8, zero_reg, ctemp); + + dim_t m0_rem_dscale_bytes = mr0 * sizeof( uint8_t ); + + CVT_STORE_S16_U8_NLT16_MEMCP_UTIL( ctemp, 0, + m0_rem_dscale_bytes); + } + } + else + { + if( post_ops_attr.c_stor_type == S8 ) + { + int8_t ctemp[16]; + + CVT_STORE_S16_S8_1ROW_NLT16(ymm8, zero_reg, ctemp); + for( dim_t i = 0; i < mr0; i++ ) + { + *( ( int8_t* )post_ops_attr.buf_downscale + + ( post_ops_attr.rs_c_downscale * + ( post_ops_attr.post_op_c_i + i ) ) ) = ctemp[i]; + } + } + else if( post_ops_attr.c_stor_type == U8 ) + { + uint8_t ctemp[16]; + + CVT_STORE_S16_U8_1ROW_NLT16(ymm8, zero_reg, ctemp); + + for( dim_t i = 0; i < mr0; i++ ) + { + *( ( uint8_t* )post_ops_attr.buf_downscale + + ( post_ops_attr.rs_c_downscale * + ( post_ops_attr.post_op_c_i + i ) ) ) = ctemp[i]; + } + } + } + } + else + { + if( rs_c == 1 ) + { + _mm256_storeu_si256( ( __m256i* )ctemp, ymm8 ); + + dim_t m0_rem_bytes = mr0 * sizeof( int16_t ); + + memcpy( c_use, ctemp, m0_rem_bytes ); + } + else + { + _mm256_storeu_si256( ( __m256i* )ctemp, ymm8 ); + + for( dim_t i = 0; i < mr0; i++ ) + { + c_use[i * rs_c] = ctemp[i]; + } + } + } + + post_ops_attr.post_op_c_i += MR; + } + } +} + +#endif diff --git a/kernels/zen/lpgemm/silu_avx2.h b/kernels/zen/lpgemm/silu_avx2.h new file mode 100644 index 0000000000..1f88fecf52 --- /dev/null +++ b/kernels/zen/lpgemm/silu_avx2.h @@ -0,0 +1,54 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef AOCL_LPGEMM_SWISH_AVX2_H +#define AOCL_LPGEMM_SWISH_AVX2_H + +// SiLU(in_reg) = in_reg / (1 + exp(-1 * al * in_reg)). +// in_reg and al are expected to contain float values. +#define SWISH_F32_AVX2_DEF(in_reg, al, al_in, r, r2, z, dn, ex_out) \ + al_in = _mm256_fnmadd_ps( in_reg, al, _mm256_setzero_ps() ); \ + EXPF_AVX2(al_in, r, r2, z, dn, ex_out); \ + ex_out = ( __m256i )_mm256_add_ps( ( __m256 )ex_out, _mm256_set1_ps( 1 ) ); \ + in_reg = _mm256_div_ps( in_reg, ( __m256 )ex_out ); \ + +// SiLU(in_reg) = in_reg / (1 + exp(-1 * al * in_reg)). +// in_reg and al are expected to contain float values. +#define SWISH_F32_SSE_DEF(in_reg, al, al_in, r, r2, z, dn, ex_out) \ + al_in = _mm_fnmadd_ps( in_reg, al, _mm_setzero_ps() ); \ + EXPF_SSE(al_in, r, r2, z, dn, ex_out); \ + ex_out = ( __m128i )_mm_add_ps( ( __m128 )ex_out, _mm_set1_ps( 1 ) ); \ + in_reg = _mm_div_ps( in_reg, ( __m128 )ex_out ); \ + +#endif // AOCL_LPGEMM_SWISH_AVX2_H diff --git a/kernels/zen/lpgemm/u8s8s16/lpgemm_6x32rowmajor_amd256.c b/kernels/zen/lpgemm/u8s8s16/lpgemm_6x32rowmajor_amd256.c index 3c92c49da2..2be262ee46 100644 --- a/kernels/zen/lpgemm/u8s8s16/lpgemm_6x32rowmajor_amd256.c +++ b/kernels/zen/lpgemm/u8s8s16/lpgemm_6x32rowmajor_amd256.c @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -50,7 +50,9 @@ LPGEMM_MAIN_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x32) &&POST_OPS_GELU_TANH_6x32, &&POST_OPS_GELU_ERF_6x32, &&POST_OPS_CLIP_6x32, - &&POST_OPS_DOWNSCALE_6x32 + &&POST_OPS_DOWNSCALE_6x32, + &&POST_OPS_MATRIX_ADD_6x32, + &&POST_OPS_SWISH_6x32 }; dim_t MR = 6; @@ -738,25 +740,44 @@ LPGEMM_MAIN_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x32) __m128i temp[2]; __m256i temp_32[2]; __m256 temp_float[2]; - __m256 scale_1, scale_2; + __m256 scale_1 = _mm256_setzero_ps(); + __m256 scale_2 = _mm256_setzero_ps(); __m256 res_1, res_2; - /* Load the scale vector values into the register*/ - scale_1 = - _mm256_loadu_ps( - ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 0 * 8 ) ); - scale_2 = - _mm256_loadu_ps( - ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 1 * 8 ) ); - - // Load zero points (2 byte values). - __m128i _zero_point_0 = - _mm_loadu_si128( - ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + /* Load the scale vector values into the register*/ + scale_1 = _mm256_loadu_ps( + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 8 ) ); + scale_2 = _mm256_loadu_ps( + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 8 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + /* Broadcast scale factor. */ + scale_1 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + scale_2 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + __m128i _zero_point_0 = _mm_setzero_si128(); __m256i zero_point_0 = _mm256_setzero_si256(); + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + // Load zero points (2 byte values). + _zero_point_0 = _mm_loadu_si128( + ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + // Broadcast zero point. + _zero_point_0 = _mm_set1_epi8( + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } if ( post_ops_attr.c_stor_type == S8 ) { zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); @@ -774,19 +795,38 @@ LPGEMM_MAIN_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x32) CVT_MULRND_CVT16(c_int16_4p0, scale_1, scale_2, zero_point_0) CVT_MULRND_CVT16(c_int16_5p0, scale_1, scale_2, zero_point_0) - scale_1 = - _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + (2 * 8)); - scale_2 = - _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + (3 * 8)); - - _zero_point_0 = - _mm_loadu_si128( - ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + /* Load the scale vector values into the register*/ + scale_1 = _mm256_loadu_ps( + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 8 ) ); + scale_2 = _mm256_loadu_ps( + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 3 * 8 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + /* Broadcast scale factor. */ + scale_1 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + scale_2 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + // Load zero points (2 byte values). + _zero_point_0 = _mm_loadu_si128( + ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + // Broadcast zero point. + _zero_point_0 = _mm_set1_epi8( + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } if ( post_ops_attr.c_stor_type == S8 ) { zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); @@ -804,6 +844,128 @@ LPGEMM_MAIN_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x32) CVT_MULRND_CVT16(c_int16_4p1, scale_1, scale_2, zero_point_0) CVT_MULRND_CVT16(c_int16_5p1, scale_1, scale_2, zero_point_0) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_6x32: + { + __m256i selector1, selector2; + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + S8_S16_MATRIX_ADD_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + S8_S16_MATRIX_ADD_2COL(selector1,selector2,1); + + // c[2:0-15,16-31] + S8_S16_MATRIX_ADD_2COL(selector1,selector2,2); + + // c[3:0-15,16-31] + S8_S16_MATRIX_ADD_2COL(selector1,selector2,3); + + // c[4:0-15,16-31] + S8_S16_MATRIX_ADD_2COL(selector1,selector2,4); + + // c[5:0-15,16-31] + S8_S16_MATRIX_ADD_2COL(selector1,selector2,5); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + uint8_t* matptr = ( uint8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + U8_S16_MATRIX_ADD_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + U8_S16_MATRIX_ADD_2COL(selector1,selector2,1); + + // c[2:0-15,16-31] + U8_S16_MATRIX_ADD_2COL(selector1,selector2,2); + + // c[3:0-15,16-31] + U8_S16_MATRIX_ADD_2COL(selector1,selector2,3); + + // c[4:0-15,16-31] + U8_S16_MATRIX_ADD_2COL(selector1,selector2,4); + + // c[5:0-15,16-31] + U8_S16_MATRIX_ADD_2COL(selector1,selector2,5); + } + else + { + int16_t* matptr = ( int16_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + S16_S16_MATRIX_ADD_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + S16_S16_MATRIX_ADD_2COL(selector1,selector2,1); + + // c[2:0-15,16-31] + S16_S16_MATRIX_ADD_2COL(selector1,selector2,2); + + // c[3:0-15,16-31] + S16_S16_MATRIX_ADD_2COL(selector1,selector2,3); + + // c[4:0-15,16-31] + S16_S16_MATRIX_ADD_2COL(selector1,selector2,4); + + // c[5:0-15,16-31] + S16_S16_MATRIX_ADD_2COL(selector1,selector2,5); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_6x32: + { + alphav = + _mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args2 ) ); + __m256 al = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32( \ + _mm256_extractf128_si256( alphav, 0 ) ) ); + + __m256 al_in, tmp_reg1, tmp_reg2, r, r2, z, dn; + __m256i ex_out; + + // c[0,0-15] + SWISH_S16_AVX2(c_int16_0p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[0,16-31] + SWISH_S16_AVX2(c_int16_0p1, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[1,0-15] + SWISH_S16_AVX2(c_int16_1p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[1,16-31] + SWISH_S16_AVX2(c_int16_1p1, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[2,0-15] + SWISH_S16_AVX2(c_int16_2p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[2,16-31] + SWISH_S16_AVX2(c_int16_2p1, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[3,0-15] + SWISH_S16_AVX2(c_int16_3p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[3,16-31] + SWISH_S16_AVX2(c_int16_3p1, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[4,0-15] + SWISH_S16_AVX2(c_int16_4p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[4,16-31] + SWISH_S16_AVX2(c_int16_4p1, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[5,0-15] + SWISH_S16_AVX2(c_int16_5p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[5,16-31] + SWISH_S16_AVX2(c_int16_5p1, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_6x32_DISABLE: diff --git a/kernels/zen/lpgemm/u8s8s16/lpgemm_m_fringe_amd256.c b/kernels/zen/lpgemm/u8s8s16/lpgemm_m_fringe_amd256.c index b6094c878d..c9f2d5ed64 100644 --- a/kernels/zen/lpgemm/u8s8s16/lpgemm_m_fringe_amd256.c +++ b/kernels/zen/lpgemm/u8s8s16/lpgemm_m_fringe_amd256.c @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -53,7 +53,9 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x32) &&POST_OPS_GELU_TANH_4x32, &&POST_OPS_GELU_ERF_4x32, &&POST_OPS_CLIP_4x32, - &&POST_OPS_DOWNSCALE_4x32 + &&POST_OPS_DOWNSCALE_4x32, + &&POST_OPS_MATRIX_ADD_4x32, + &&POST_OPS_SWISH_4x32 }; // The division is done by considering the vpmaddubsw instruction @@ -501,26 +503,44 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x32) __m128i temp[2]; __m256i temp_32[2]; __m256 temp_float[2]; - __m256 scale_1, scale_2; - __m128i _zero_point_0; + __m256 scale_1 = _mm256_setzero_ps(); + __m256 scale_2 = _mm256_setzero_ps(); + __m128i _zero_point_0 = _mm_setzero_si128(); __m256i zero_point_0 = _mm256_setzero_si256(); __m256 res_1, res_2; - /* Load the scale vector values into the register*/ - scale_1 = - _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + (0 * 8)); - scale_2 = - _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + (1 * 8)); - - // Load zero points (2 byte values). - _zero_point_0 = - _mm_loadu_si128( - ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + /* Load the scale vector values into the register*/ + scale_1 = _mm256_loadu_ps( + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 8 ) ); + scale_2 = _mm256_loadu_ps( + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 8 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + /* Broadcast scale factor. */ + scale_1 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + scale_2 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + // Load zero points (2 byte values). + _zero_point_0 = _mm_loadu_si128( + ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + // Broadcast zero point. + _zero_point_0 = _mm_set1_epi8( + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } if ( post_ops_attr.c_stor_type == S8 ) { zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); @@ -536,19 +556,38 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x32) CVT_MULRND_CVT16(c_int16_2p0, scale_1, scale_2, zero_point_0) CVT_MULRND_CVT16(c_int16_3p0, scale_1, scale_2, zero_point_0) - scale_1 = - _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + (2 * 8)); - scale_2 = - _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + (3 * 8)); - - _zero_point_0 = - _mm_loadu_si128( - ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + /* Load the scale vector values into the register*/ + scale_1 = _mm256_loadu_ps( + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 8 ) ); + scale_2 = _mm256_loadu_ps( + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 3 * 8 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + /* Broadcast scale factor. */ + scale_1 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + scale_2 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + // Load zero points (2 byte values). + _zero_point_0 = _mm_loadu_si128( + ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + // Broadcast zero point. + _zero_point_0 = _mm_set1_epi8( + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } if ( post_ops_attr.c_stor_type == S8 ) { zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); @@ -564,6 +603,97 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x32) CVT_MULRND_CVT16(c_int16_2p1, scale_1, scale_2, zero_point_0) CVT_MULRND_CVT16(c_int16_3p1, scale_1, scale_2, zero_point_0) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_4x32: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + S8_S16_MATRIX_ADD_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + S8_S16_MATRIX_ADD_2COL(selector1,selector2,1); + + // c[2:0-15,16-31] + S8_S16_MATRIX_ADD_2COL(selector1,selector2,2); + + // c[3:0-15,16-31] + S8_S16_MATRIX_ADD_2COL(selector1,selector2,3); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + uint8_t* matptr = ( uint8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + U8_S16_MATRIX_ADD_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + U8_S16_MATRIX_ADD_2COL(selector1,selector2,1); + + // c[2:0-15,16-31] + U8_S16_MATRIX_ADD_2COL(selector1,selector2,2); + + // c[3:0-15,16-31] + U8_S16_MATRIX_ADD_2COL(selector1,selector2,3); + } + else + { + int16_t* matptr = ( int16_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + S16_S16_MATRIX_ADD_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + S16_S16_MATRIX_ADD_2COL(selector1,selector2,1); + + // c[2:0-15,16-31] + S16_S16_MATRIX_ADD_2COL(selector1,selector2,2); + + // c[3:0-15,16-31] + S16_S16_MATRIX_ADD_2COL(selector1,selector2,3); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_4x32: + { + selector1 = + _mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args2 ) ); + __m256 al = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32( \ + _mm256_extractf128_si256( selector1, 0 ) ) ); + + __m256 al_in, tmp_reg1, tmp_reg2, r, r2, z, dn; + __m256i ex_out; + + // c[0,0-15] + SWISH_S16_AVX2(c_int16_0p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[0,16-31] + SWISH_S16_AVX2(c_int16_0p1, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[1,0-15] + SWISH_S16_AVX2(c_int16_1p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[1,16-31] + SWISH_S16_AVX2(c_int16_1p1, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[2,0-15] + SWISH_S16_AVX2(c_int16_2p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[2,16-31] + SWISH_S16_AVX2(c_int16_2p1, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[3,0-15] + SWISH_S16_AVX2(c_int16_3p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[3,16-31] + SWISH_S16_AVX2(c_int16_3p1, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_4x32_DISABLE: @@ -651,7 +781,9 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2x32) &&POST_OPS_GELU_TANH_2x32, &&POST_OPS_GELU_ERF_2x32, &&POST_OPS_CLIP_2x32, - &&POST_OPS_DOWNSCALE_2x32 + &&POST_OPS_DOWNSCALE_2x32, + &&POST_OPS_MATRIX_ADD_2x32, + &&POST_OPS_SWISH_2x32 }; // The division is done by considering the vpmaddubsw instruction @@ -930,26 +1062,44 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2x32) __m128i temp[2]; __m256i temp_32[2]; __m256 temp_float[2]; - __m256 scale_1, scale_2; - __m128i _zero_point_0; + __m256 scale_1 = _mm256_setzero_ps(); + __m256 scale_2 = _mm256_setzero_ps(); + __m128i _zero_point_0 = _mm_setzero_si128(); __m256i zero_point_0 = _mm256_setzero_si256(); __m256 res_1, res_2; - /* Load the scale vector values into the register*/ - scale_1 = - _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + (0 * 8)); - scale_2 = - _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + (1 * 8)); - - // Load zero points (2 byte values). - _zero_point_0 = - _mm_loadu_si128( - ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + /* Load the scale vector values into the register*/ + scale_1 = _mm256_loadu_ps( + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 8 ) ); + scale_2 = _mm256_loadu_ps( + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 8 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + /* Broadcast scale factor. */ + scale_1 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + scale_2 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + // Load zero points (2 byte values). + _zero_point_0 = _mm_loadu_si128( + ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + // Broadcast zero point. + _zero_point_0 = _mm_set1_epi8( + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } if ( post_ops_attr.c_stor_type == S8 ) { zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); @@ -963,19 +1113,38 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2x32) CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2, zero_point_0) CVT_MULRND_CVT16(c_int16_1p0, scale_1, scale_2, zero_point_0) - scale_1 = - _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + (2 * 8)); - scale_2 = - _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + (3 * 8)); - - _zero_point_0 = - _mm_loadu_si128( - ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + /* Load the scale vector values into the register*/ + scale_1 = _mm256_loadu_ps( + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 8 ) ); + scale_2 = _mm256_loadu_ps( + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 3 * 8 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + /* Broadcast scale factor. */ + scale_1 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + scale_2 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + // Load zero points (2 byte values). + _zero_point_0 = _mm_loadu_si128( + ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + // Broadcast zero point. + _zero_point_0 = _mm_set1_epi8( + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } if ( post_ops_attr.c_stor_type == S8 ) { zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); @@ -989,6 +1158,67 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2x32) CVT_MULRND_CVT16(c_int16_0p1, scale_1, scale_2, zero_point_0) CVT_MULRND_CVT16(c_int16_1p1, scale_1, scale_2, zero_point_0) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_2x32: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + S8_S16_MATRIX_ADD_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + S8_S16_MATRIX_ADD_2COL(selector1,selector2,1); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + uint8_t* matptr = ( uint8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + U8_S16_MATRIX_ADD_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + U8_S16_MATRIX_ADD_2COL(selector1,selector2,1); + } + else + { + int16_t* matptr = ( int16_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + S16_S16_MATRIX_ADD_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + S16_S16_MATRIX_ADD_2COL(selector1,selector2,1); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_2x32: + { + selector1 = + _mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args2 ) ); + __m256 al = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32( \ + _mm256_extractf128_si256( selector1, 0 ) ) ); + + __m256 al_in, tmp_reg1, tmp_reg2, r, r2, z, dn; + __m256i ex_out; + + // c[0,0-15] + SWISH_S16_AVX2(c_int16_0p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[0,16-31] + SWISH_S16_AVX2(c_int16_0p1, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[1,0-15] + SWISH_S16_AVX2(c_int16_1p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[1,16-31] + SWISH_S16_AVX2(c_int16_1p1, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_2x32_DISABLE: @@ -1051,7 +1281,9 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1x32) &&POST_OPS_GELU_TANH_1x32, &&POST_OPS_GELU_ERF_1x32, &&POST_OPS_CLIP_1x32, - &&POST_OPS_DOWNSCALE_1x32 + &&POST_OPS_DOWNSCALE_1x32, + &&POST_OPS_MATRIX_ADD_1x32, + &&POST_OPS_SWISH_1x32 }; // The division is done by considering the vpmaddubsw instruction @@ -1245,26 +1477,44 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1x32) __m128i temp[2]; __m256i temp_32[2]; __m256 temp_float[2]; - __m256 scale_1, scale_2; - __m128i _zero_point_0; + __m256 scale_1 = _mm256_setzero_ps(); + __m256 scale_2 = _mm256_setzero_ps(); + __m128i _zero_point_0 = _mm_setzero_si128(); __m256i zero_point_0 = _mm256_setzero_si256(); __m256 res_1, res_2; - /* Load the scale vector values into the register*/ - scale_1 = - _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + (0 * 8)); - scale_2 = - _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + (1 * 8)); - - // Load zero points (2 byte values). - _zero_point_0 = - _mm_loadu_si128( - ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + /* Load the scale vector values into the register*/ + scale_1 = _mm256_loadu_ps( + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 8 ) ); + scale_2 = _mm256_loadu_ps( + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 8 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + /* Broadcast scale factor. */ + scale_1 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + scale_2 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + // Load zero points (2 byte values). + _zero_point_0 = _mm_loadu_si128( + ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + // Broadcast zero point. + _zero_point_0 = _mm_set1_epi8( + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } if ( post_ops_attr.c_stor_type == S8 ) { zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); @@ -1277,19 +1527,38 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1x32) // Scale first 16 columns of the 4 rows. CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2, zero_point_0) - scale_1 = - _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + (2 * 8)); - scale_2 = - _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + (3 * 8)); - - _zero_point_0 = - _mm_loadu_si128( - ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + /* Load the scale vector values into the register*/ + scale_1 = _mm256_loadu_ps( + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 8 ) ); + scale_2 = _mm256_loadu_ps( + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 3 * 8 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + /* Broadcast scale factor. */ + scale_1 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + scale_2 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + // Load zero points (2 byte values). + _zero_point_0 = _mm_loadu_si128( + ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + // Broadcast zero point. + _zero_point_0 = _mm_set1_epi8( + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } if ( post_ops_attr.c_stor_type == S8 ) { zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); @@ -1302,6 +1571,52 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1x32) // Scale next 16 columns of the 4 rows. CVT_MULRND_CVT16(c_int16_0p1, scale_1, scale_2, zero_point_0) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_1x32: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + S8_S16_MATRIX_ADD_2COL(selector1,selector2,0); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + uint8_t* matptr = ( uint8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + U8_S16_MATRIX_ADD_2COL(selector1,selector2,0); + } + else + { + int16_t* matptr = ( int16_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + S16_S16_MATRIX_ADD_2COL(selector1,selector2,0); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_1x32: + { + selector1 = + _mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args2 ) ); + __m256 al = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32( \ + _mm256_extractf128_si256( selector1, 0 ) ) ); + + __m256 al_in, tmp_reg1, tmp_reg2, r, r2, z, dn; + __m256i ex_out; + + // c[0,0-15] + SWISH_S16_AVX2(c_int16_0p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[0,16-31] + SWISH_S16_AVX2(c_int16_0p1, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_1x32_DISABLE: diff --git a/kernels/zen/lpgemm/u8s8s16/lpgemm_mn_fringe_amd256.c b/kernels/zen/lpgemm/u8s8s16/lpgemm_mn_fringe_amd256.c index b19abe413d..5aec94f2cd 100644 --- a/kernels/zen/lpgemm/u8s8s16/lpgemm_mn_fringe_amd256.c +++ b/kernels/zen/lpgemm/u8s8s16/lpgemm_mn_fringe_amd256.c @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -53,7 +53,9 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x16) &&POST_OPS_GELU_TANH_4x16, &&POST_OPS_GELU_ERF_4x16, &&POST_OPS_CLIP_4x16, - &&POST_OPS_DOWNSCALE_4x16 + &&POST_OPS_DOWNSCALE_4x16, + &&POST_OPS_MATRIX_ADD_4x16, + &&POST_OPS_SWISH_4x16 }; // The division is done by considering the vpmaddubsw instruction @@ -359,26 +361,44 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x16) __m128i temp[2]; __m256i temp_32[2]; __m256 temp_float[2]; - __m256 scale_1, scale_2; - __m128i _zero_point_0; + __m256 scale_1 = _mm256_setzero_ps(); + __m256 scale_2 = _mm256_setzero_ps(); + __m128i _zero_point_0 = _mm_setzero_si128(); __m256i zero_point_0 = _mm256_setzero_si256(); __m256 res_1, res_2; - /* Load the scale vector values into the register*/ - scale_1 = - _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + (0 * 8)); - scale_2 = - _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + (1 * 8)); - - // Load zero points (2 byte values). - _zero_point_0 = - _mm_loadu_si128( - ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + /* Load the scale vector values into the register*/ + scale_1 = _mm256_loadu_ps( + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 8 ) ); + scale_2 = _mm256_loadu_ps( + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 8 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + /* Broadcast scale factor. */ + scale_1 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + scale_2 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + // Load zero points (2 byte values). + _zero_point_0 = _mm_loadu_si128( + ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + // Broadcast zero point. + _zero_point_0 = _mm_set1_epi8( + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } if ( post_ops_attr.c_stor_type == S8 ) { zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); @@ -394,6 +414,85 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x16) CVT_MULRND_CVT16(c_int16_2p0, scale_1, scale_2, zero_point_0) CVT_MULRND_CVT16(c_int16_3p0, scale_1, scale_2, zero_point_0) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_4x16: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S8_S16_MATRIX_ADD_1COL(selector1,0); + + // c[1:0-15] + S8_S16_MATRIX_ADD_1COL(selector1,1); + + // c[2:0-15] + S8_S16_MATRIX_ADD_1COL(selector1,2); + + // c[3:0-15] + S8_S16_MATRIX_ADD_1COL(selector1,3); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + uint8_t* matptr = ( uint8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + U8_S16_MATRIX_ADD_1COL(selector1,0); + + // c[1:0-15] + U8_S16_MATRIX_ADD_1COL(selector1,1); + + // c[2:0-15] + U8_S16_MATRIX_ADD_1COL(selector1,2); + + // c[3:0-15] + U8_S16_MATRIX_ADD_1COL(selector1,3); + } + else + { + int16_t* matptr = ( int16_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S16_S16_MATRIX_ADD_1COL(selector1,0); + + // c[1:0-15] + S16_S16_MATRIX_ADD_1COL(selector1,1); + + // c[2:0-15] + S16_S16_MATRIX_ADD_1COL(selector1,2); + + // c[3:0-15] + S16_S16_MATRIX_ADD_1COL(selector1,3); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_4x16: + { + selector1 = + _mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args2 ) ); + __m256 al = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32( \ + _mm256_extractf128_si256( selector1, 0 ) ) ); + + __m256 al_in, tmp_reg1, tmp_reg2, r, r2, z, dn; + __m256i ex_out; + + // c[0,0-15] + SWISH_S16_AVX2(c_int16_0p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[1,0-15] + SWISH_S16_AVX2(c_int16_1p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[2,0-15] + SWISH_S16_AVX2(c_int16_2p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[3,0-15] + SWISH_S16_AVX2(c_int16_3p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_4x16_DISABLE: @@ -460,7 +559,9 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4xlt16) &&POST_OPS_GELU_TANH_4xlt16, &&POST_OPS_GELU_ERF_4xlt16, &&POST_OPS_CLIP_4xlt16, - &&POST_OPS_DOWNSCALE_4xlt16 + &&POST_OPS_DOWNSCALE_4xlt16, + &&POST_OPS_MATRIX_ADD_4xlt16, + &&POST_OPS_SWISH_4xlt16 }; // The division is done by considering the vpmaddubsw instruction @@ -797,36 +898,64 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4xlt16) __m128i temp[2]; __m256i temp_32[2]; __m256 temp_float[2]; - __m256 scale_1, scale_2; - __m128i _zero_point_0; + __m256 scale_1 = _mm256_setzero_ps(); + __m256 scale_2 = _mm256_setzero_ps(); + __m128i _zero_point_0 = _mm_setzero_si128(); __m256i zero_point_0 = _mm256_setzero_si256(); __m256 res_1, res_2; - float float_buf[16]; + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + float float_buf[16] = { 0 }; - memcpy( float_buf, ( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( float ) ) ); + memcpy( float_buf, ( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( float ) ) ); - // Load the scale vector values into the register - scale_1 = _mm256_loadu_ps(float_buf + (0 * 8)); - scale_2 = _mm256_loadu_ps(float_buf + (1 * 8)); + // Load the scale vector values into the register + scale_1 = _mm256_loadu_ps(float_buf + (0 * 8)); + scale_2 = _mm256_loadu_ps(float_buf + (1 * 8)); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + /* Broadcast scale factor. */ + scale_1 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + scale_2 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } - if ( post_ops_attr.c_stor_type == S8 ) + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) { - int8_t zero_point_buf[16]; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t zero_point_buf[16]; - memcpy( zero_point_buf, ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( int8_t ) ) ); - _zero_point_0 = _mm_loadu_si128( ( __m128i const* )zero_point_buf ); + memcpy( zero_point_buf, ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( int8_t ) ) ); + _zero_point_0 = _mm_loadu_si128( ( __m128i const* )zero_point_buf ); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + uint8_t zero_point_buf[16]; + + memcpy( zero_point_buf, ( ( uint8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( uint8_t ) ) ); + _zero_point_0 = _mm_loadu_si128( ( __m128i const* )zero_point_buf ); + } + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + // Broadcast zero point. + _zero_point_0 = _mm_set1_epi8( + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } + + if ( post_ops_attr.c_stor_type == S8 ) + { zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); } else if ( post_ops_attr.c_stor_type == U8 ) { - uint8_t zero_point_buf[16]; - - memcpy( zero_point_buf, ( ( uint8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( uint8_t ) ) ); - _zero_point_0 = _mm_loadu_si128( ( __m128i const* )zero_point_buf ); zero_point_0 = _mm256_cvtepu8_epi16( _zero_point_0 ); } @@ -836,6 +965,85 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4xlt16) CVT_MULRND_CVT16(c_int16_2p0, scale_1, scale_2, zero_point_0) CVT_MULRND_CVT16(c_int16_3p0, scale_1, scale_2, zero_point_0) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_4xlt16: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S8_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,0,n0_rem,int8_t); + + // c[1:0-15] + S8_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,1,n0_rem,int8_t); + + // c[2:0-15] + S8_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,2,n0_rem,int8_t); + + // c[3:0-15] + S8_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,3,n0_rem,int8_t); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + uint8_t* matptr = ( uint8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + U8_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,0,n0_rem,uint8_t); + + // c[1:0-15] + U8_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,1,n0_rem,uint8_t); + + // c[2:0-15] + U8_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,2,n0_rem,uint8_t); + + // c[3:0-15] + U8_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,3,n0_rem,uint8_t); + } + else + { + int16_t* matptr = ( int16_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S16_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,0,n0_rem,int16_t); + + // c[1:0-15] + S16_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,1,n0_rem,int16_t); + + // c[2:0-15] + S16_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,2,n0_rem,int16_t); + + // c[3:0-15] + S16_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,3,n0_rem,int16_t); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_4xlt16: + { + selector1 = + _mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args2 ) ); + __m256 al = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32( \ + _mm256_extractf128_si256( selector1, 0 ) ) ); + + __m256 al_in, tmp_reg1, tmp_reg2, r, r2, z, dn; + __m256i ex_out; + + // c[0,0-15] + SWISH_S16_AVX2(c_int16_0p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[1,0-15] + SWISH_S16_AVX2(c_int16_1p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[2,0-15] + SWISH_S16_AVX2(c_int16_2p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[3,0-15] + SWISH_S16_AVX2(c_int16_3p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_4xlt16_DISABLE: @@ -929,7 +1137,9 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2x16) &&POST_OPS_GELU_TANH_2x16, &&POST_OPS_GELU_ERF_2x16, &&POST_OPS_CLIP_2x16, - &&POST_OPS_DOWNSCALE_2x16 + &&POST_OPS_DOWNSCALE_2x16, + &&POST_OPS_MATRIX_ADD_2x16, + &&POST_OPS_SWISH_2x16 }; // The division is done by considering the vpmaddubsw instruction @@ -1135,26 +1345,44 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2x16) __m128i temp[2]; __m256i temp_32[2]; __m256 temp_float[2]; - __m256 scale_1, scale_2; - __m128i _zero_point_0; + __m256 scale_1 = _mm256_setzero_ps(); + __m256 scale_2 = _mm256_setzero_ps(); + __m128i _zero_point_0 = _mm_setzero_si128(); __m256i zero_point_0 = _mm256_setzero_si256(); __m256 res_1, res_2; - /* Load the scale vector values into the register*/ - scale_1 = - _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + (0 * 8)); - scale_2 = - _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + (1 * 8)); - - // Load zero points (2 byte values). - _zero_point_0 = - _mm_loadu_si128( - ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + /* Load the scale vector values into the register*/ + scale_1 = _mm256_loadu_ps( + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 8 ) ); + scale_2 = _mm256_loadu_ps( + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 8 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + /* Broadcast scale factor. */ + scale_1 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + scale_2 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + // Load zero points (2 byte values). + _zero_point_0 = _mm_loadu_si128( + ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + // Broadcast zero point. + _zero_point_0 = _mm_set1_epi8( + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } if ( post_ops_attr.c_stor_type == S8 ) { zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); @@ -1168,6 +1396,61 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2x16) CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2, zero_point_0) CVT_MULRND_CVT16(c_int16_1p0, scale_1, scale_2, zero_point_0) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_2x16: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S8_S16_MATRIX_ADD_1COL(selector1,0); + + // c[1:0-15] + S8_S16_MATRIX_ADD_1COL(selector1,1); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + uint8_t* matptr = ( uint8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + U8_S16_MATRIX_ADD_1COL(selector1,0); + + // c[1:0-15] + U8_S16_MATRIX_ADD_1COL(selector1,1); + } + else + { + int16_t* matptr = ( int16_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S16_S16_MATRIX_ADD_1COL(selector1,0); + + // c[1:0-15] + S16_S16_MATRIX_ADD_1COL(selector1,1); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_2x16: + { + selector1 = + _mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args2 ) ); + __m256 al = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32( \ + _mm256_extractf128_si256( selector1, 0 ) ) ); + + __m256 al_in, tmp_reg1, tmp_reg2, r, r2, z, dn; + __m256i ex_out; + + // c[0,0-15] + SWISH_S16_AVX2(c_int16_0p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[1,0-15] + SWISH_S16_AVX2(c_int16_1p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_2x16_DISABLE: @@ -1222,7 +1505,9 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2xlt16) &&POST_OPS_GELU_TANH_2xlt16, &&POST_OPS_GELU_ERF_2xlt16, &&POST_OPS_CLIP_2xlt16, - &&POST_OPS_DOWNSCALE_2xlt16 + &&POST_OPS_DOWNSCALE_2xlt16, + &&POST_OPS_MATRIX_ADD_2xlt16, + &&POST_OPS_SWISH_2xlt16 }; // The division is done by considering the vpmaddubsw instruction @@ -1448,36 +1733,64 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2xlt16) __m128i temp[2]; __m256i temp_32[2]; __m256 temp_float[2]; - __m256 scale_1, scale_2; - __m128i _zero_point_0; + __m256 scale_1 = _mm256_setzero_ps(); + __m256 scale_2 = _mm256_setzero_ps(); + __m128i _zero_point_0 = _mm_setzero_si128(); __m256i zero_point_0 = _mm256_setzero_si256(); __m256 res_1, res_2; - float float_buf[16]; + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + float float_buf[16] = { 0 }; - memcpy( float_buf, ( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( float ) ) ); + memcpy( float_buf, ( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( float ) ) ); - // Load the scale vector values into the register - scale_1 = _mm256_loadu_ps(float_buf + (0 * 8)); - scale_2 = _mm256_loadu_ps(float_buf + (1 * 8)); + // Load the scale vector values into the register + scale_1 = _mm256_loadu_ps(float_buf + (0 * 8)); + scale_2 = _mm256_loadu_ps(float_buf + (1 * 8)); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + /* Broadcast scale factor. */ + scale_1 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + scale_2 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } - if ( post_ops_attr.c_stor_type == S8 ) + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) { - int8_t zero_point_buf[16]; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t zero_point_buf[16]; - memcpy( zero_point_buf, ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( int8_t ) ) ); - _zero_point_0 = _mm_loadu_si128( ( __m128i const* )zero_point_buf ); + memcpy( zero_point_buf, ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( int8_t ) ) ); + _zero_point_0 = _mm_loadu_si128( ( __m128i const* )zero_point_buf ); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + uint8_t zero_point_buf[16]; + + memcpy( zero_point_buf, ( ( uint8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( uint8_t ) ) ); + _zero_point_0 = _mm_loadu_si128( ( __m128i const* )zero_point_buf ); + } + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + // Broadcast zero point. + _zero_point_0 = _mm_set1_epi8( + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } + + if ( post_ops_attr.c_stor_type == S8 ) + { zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); } else if ( post_ops_attr.c_stor_type == U8 ) { - uint8_t zero_point_buf[16]; - - memcpy( zero_point_buf, ( ( uint8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( uint8_t ) ) ); - _zero_point_0 = _mm_loadu_si128( ( __m128i const* )zero_point_buf ); zero_point_0 = _mm256_cvtepu8_epi16( _zero_point_0 ); } @@ -1485,6 +1798,61 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2xlt16) CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2, zero_point_0) CVT_MULRND_CVT16(c_int16_1p0, scale_1, scale_2, zero_point_0) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_2xlt16: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S8_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,0,n0_rem,int8_t); + + // c[1:0-15] + S8_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,1,n0_rem,int8_t); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + uint8_t* matptr = ( uint8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + U8_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,0,n0_rem,uint8_t); + + // c[1:0-15] + U8_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,1,n0_rem,uint8_t); + } + else + { + int16_t* matptr = ( int16_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S16_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,0,n0_rem,int16_t); + + // c[1:0-15] + S16_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,1,n0_rem,int16_t); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_2xlt16: + { + selector1 = + _mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args2 ) ); + __m256 al = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32( \ + _mm256_extractf128_si256( selector1, 0 ) ) ); + + __m256 al_in, tmp_reg1, tmp_reg2, r, r2, z, dn; + __m256i ex_out; + + // c[0,0-15] + SWISH_S16_AVX2(c_int16_0p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[1,0-15] + SWISH_S16_AVX2(c_int16_1p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_2xlt16_DISABLE: @@ -1556,7 +1924,9 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1x16) &&POST_OPS_GELU_TANH_1x16, &&POST_OPS_GELU_ERF_1x16, &&POST_OPS_CLIP_1x16, - &&POST_OPS_DOWNSCALE_1x16 + &&POST_OPS_DOWNSCALE_1x16, + &&POST_OPS_MATRIX_ADD_1x16, + &&POST_OPS_SWISH_1x16 }; // The division is done by considering the vpmaddubsw instruction @@ -1711,26 +2081,44 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1x16) __m128i temp[2]; __m256i temp_32[2]; __m256 temp_float[2]; - __m256 scale_1, scale_2; - __m128i _zero_point_0; + __m256 scale_1 = _mm256_setzero_ps(); + __m256 scale_2 = _mm256_setzero_ps(); + __m128i _zero_point_0 = _mm_setzero_si128(); __m256i zero_point_0 = _mm256_setzero_si256(); __m256 res_1, res_2; - /* Load the scale vector values into the register*/ - scale_1 = - _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + (0 * 8)); - scale_2 = - _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + (1 * 8)); - - // Load zero points (2 byte values). - _zero_point_0 = - _mm_loadu_si128( - ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + /* Load the scale vector values into the register*/ + scale_1 = _mm256_loadu_ps( + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 8 ) ); + scale_2 = _mm256_loadu_ps( + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 8 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + /* Broadcast scale factor. */ + scale_1 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + scale_2 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + // Load zero points (2 byte values). + _zero_point_0 = _mm_loadu_si128( + ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + // Broadcast zero point. + _zero_point_0 = _mm_set1_epi8( + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } if ( post_ops_attr.c_stor_type == S8 ) { zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); @@ -1743,6 +2131,49 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1x16) // Scale first 16 columns of the 2 rows. CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2, zero_point_0) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_1x16: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S8_S16_MATRIX_ADD_1COL(selector1,0); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + uint8_t* matptr = ( uint8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + U8_S16_MATRIX_ADD_1COL(selector1,0); + } + else + { + int16_t* matptr = ( int16_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S16_S16_MATRIX_ADD_1COL(selector1,0); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_1x16: + { + selector1 = + _mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args2 ) ); + __m256 al = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32( \ + _mm256_extractf128_si256( selector1, 0 ) ) ); + + __m256 al_in, tmp_reg1, tmp_reg2, r, r2, z, dn; + __m256i ex_out; + + // c[0,0-15] + SWISH_S16_AVX2(c_int16_0p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_1x16_DISABLE: @@ -1796,7 +2227,9 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1xlt16) &&POST_OPS_GELU_TANH_1xlt16, &&POST_OPS_GELU_ERF_1xlt16, &&POST_OPS_CLIP_1xlt16, - &&POST_OPS_DOWNSCALE_1xlt16 + &&POST_OPS_DOWNSCALE_1xlt16, + &&POST_OPS_MATRIX_ADD_1xlt16, + &&POST_OPS_SWISH_1xlt16 }; // The division is done by considering the vpmaddubsw instruction @@ -1967,42 +2400,113 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1xlt16) __m128i temp[2]; __m256i temp_32[2]; __m256 temp_float[2]; - __m256 scale_1, scale_2; - __m128i _zero_point_0; + __m256 scale_1 = _mm256_setzero_ps(); + __m256 scale_2 = _mm256_setzero_ps(); + __m128i _zero_point_0 = _mm_setzero_si128(); __m256i zero_point_0 = _mm256_setzero_si256(); __m256 res_1, res_2; - float float_buf[16]; + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + float float_buf[16] = { 0 }; - memcpy( float_buf, ( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( float ) ) ); + memcpy( float_buf, ( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( float ) ) ); - // Load the scale vector values into the register - scale_1 = _mm256_loadu_ps(float_buf + (0 * 8)); - scale_2 = _mm256_loadu_ps(float_buf + (1 * 8)); + // Load the scale vector values into the register + scale_1 = _mm256_loadu_ps(float_buf + (0 * 8)); + scale_2 = _mm256_loadu_ps(float_buf + (1 * 8)); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + /* Broadcast scale factor. */ + scale_1 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + scale_2 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } - if ( post_ops_attr.c_stor_type == S8 ) + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t zero_point_buf[16]; + + memcpy( zero_point_buf, ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( int8_t ) ) ); + _zero_point_0 = _mm_loadu_si128( ( __m128i const* )zero_point_buf ); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + uint8_t zero_point_buf[16]; + + memcpy( zero_point_buf, ( ( uint8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( uint8_t ) ) ); + _zero_point_0 = _mm_loadu_si128( ( __m128i const* )zero_point_buf ); + } + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) { - int8_t zero_point_buf[16]; + // Broadcast zero point. + _zero_point_0 = _mm_set1_epi8( + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } - memcpy( zero_point_buf, ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( int8_t ) ) ); - _zero_point_0 = _mm_loadu_si128( ( __m128i const* )zero_point_buf ); + if ( post_ops_attr.c_stor_type == S8 ) + { zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); } else if ( post_ops_attr.c_stor_type == U8 ) { - uint8_t zero_point_buf[16]; - - memcpy( zero_point_buf, ( ( uint8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( uint8_t ) ) ); - _zero_point_0 = _mm_loadu_si128( ( __m128i const* )zero_point_buf ); zero_point_0 = _mm256_cvtepu8_epi16( _zero_point_0 ); } // Scale first 16 columns of the 2 rows. CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2, zero_point_0) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_1xlt16: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S8_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,0,n0_rem,int8_t); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + uint8_t* matptr = ( uint8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + U8_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,0,n0_rem,uint8_t); + } + else + { + int16_t* matptr = ( int16_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S16_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,0,n0_rem,int16_t); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_1xlt16: + { + selector1 = + _mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args2 ) ); + __m256 al = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32( \ + _mm256_extractf128_si256( selector1, 0 ) ) ); + + __m256 al_in, tmp_reg1, tmp_reg2, r, r2, z, dn; + __m256i ex_out; + + // c[0,0-15] + SWISH_S16_AVX2(c_int16_0p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_1xlt16_DISABLE: diff --git a/kernels/zen/lpgemm/u8s8s16/lpgemm_n_fringe_amd256.c b/kernels/zen/lpgemm/u8s8s16/lpgemm_n_fringe_amd256.c index 1947de5542..c050627ff9 100644 --- a/kernels/zen/lpgemm/u8s8s16/lpgemm_n_fringe_amd256.c +++ b/kernels/zen/lpgemm/u8s8s16/lpgemm_n_fringe_amd256.c @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -54,7 +54,9 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x16) &&POST_OPS_GELU_TANH_6x16, &&POST_OPS_GELU_ERF_6x16, &&POST_OPS_CLIP_6x16, - &&POST_OPS_DOWNSCALE_6x16 + &&POST_OPS_DOWNSCALE_6x16, + &&POST_OPS_MATRIX_ADD_6x16, + &&POST_OPS_SWISH_6x16 }; dim_t m_full_pieces = m0 / MR; @@ -471,26 +473,44 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x16) __m128i temp[2]; __m256i temp_32[2]; __m256 temp_float[2]; - __m256 scale_1, scale_2; - __m128i _zero_point_0; + __m256 scale_1 = _mm256_setzero_ps(); + __m256 scale_2 = _mm256_setzero_ps(); + __m128i _zero_point_0 = _mm_setzero_si128(); __m256i zero_point_0 = _mm256_setzero_si256(); __m256 res_1, res_2; - /* Load the scale vector values into the register*/ - scale_1 = - _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + (0 * 8)); - scale_2 = - _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + (1 * 8)); - - // Load zero points (2 byte values). - _zero_point_0 = - _mm_loadu_si128( - ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + /* Load the scale vector values into the register*/ + scale_1 = _mm256_loadu_ps( + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 8 ) ); + scale_2 = _mm256_loadu_ps( + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 8 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + /* Broadcast scale factor. */ + scale_1 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + scale_2 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + // Load zero points (2 byte values). + _zero_point_0 = _mm_loadu_si128( + ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + // Broadcast zero point. + _zero_point_0 = _mm_set1_epi8( + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } if ( post_ops_attr.c_stor_type == S8 ) { zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); @@ -508,6 +528,109 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x16) CVT_MULRND_CVT16(c_int16_4p0, scale_1, scale_2, zero_point_0) CVT_MULRND_CVT16(c_int16_5p0, scale_1, scale_2, zero_point_0) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_6x16: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S8_S16_MATRIX_ADD_1COL(selector1,0); + + // c[1:0-15] + S8_S16_MATRIX_ADD_1COL(selector1,1); + + // c[2:0-15] + S8_S16_MATRIX_ADD_1COL(selector1,2); + + // c[3:0-15] + S8_S16_MATRIX_ADD_1COL(selector1,3); + + // c[4:0-15] + S8_S16_MATRIX_ADD_1COL(selector1,4); + + // c[5:0-15] + S8_S16_MATRIX_ADD_1COL(selector1,5); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + uint8_t* matptr = ( uint8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + U8_S16_MATRIX_ADD_1COL(selector1,0); + + // c[1:0-15] + U8_S16_MATRIX_ADD_1COL(selector1,1); + + // c[2:0-15] + U8_S16_MATRIX_ADD_1COL(selector1,2); + + // c[3:0-15] + U8_S16_MATRIX_ADD_1COL(selector1,3); + + // c[4:0-15] + U8_S16_MATRIX_ADD_1COL(selector1,4); + + // c[5:0-15] + U8_S16_MATRIX_ADD_1COL(selector1,5); + } + else + { + int16_t* matptr = ( int16_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S16_S16_MATRIX_ADD_1COL(selector1,0); + + // c[1:0-15] + S16_S16_MATRIX_ADD_1COL(selector1,1); + + // c[2:0-15] + S16_S16_MATRIX_ADD_1COL(selector1,2); + + // c[3:0-15] + S16_S16_MATRIX_ADD_1COL(selector1,3); + + // c[4:0-15] + S16_S16_MATRIX_ADD_1COL(selector1,4); + + // c[5:0-15] + S16_S16_MATRIX_ADD_1COL(selector1,5); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_6x16: + { + selector1 = + _mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args2 ) ); + __m256 al = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32( \ + _mm256_extractf128_si256( selector1, 0 ) ) ); + + __m256 al_in, tmp_reg1, tmp_reg2, r, r2, z, dn; + __m256i ex_out; + + // c[0,0-15] + SWISH_S16_AVX2(c_int16_0p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[1,0-15] + SWISH_S16_AVX2(c_int16_1p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[2,0-15] + SWISH_S16_AVX2(c_int16_2p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[3,0-15] + SWISH_S16_AVX2(c_int16_3p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[4,0-15] + SWISH_S16_AVX2(c_int16_4p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[5,0-15] + SWISH_S16_AVX2(c_int16_5p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_6x16_DISABLE: @@ -643,7 +766,9 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6xlt16) &&POST_OPS_GELU_TANH_6xlt16, &&POST_OPS_GELU_ERF_6xlt16, &&POST_OPS_CLIP_6xlt16, - &&POST_OPS_DOWNSCALE_6xlt16 + &&POST_OPS_DOWNSCALE_6xlt16, + &&POST_OPS_MATRIX_ADD_6xlt16, + &&POST_OPS_SWISH_6xlt16 }; dim_t m_full_pieces = m0 / MR; @@ -1095,36 +1220,64 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6xlt16) __m128i temp[2]; __m256i temp_32[2]; __m256 temp_float[2]; - __m256 scale_1, scale_2; - __m128i _zero_point_0; + __m256 scale_1 = _mm256_setzero_ps(); + __m256 scale_2 = _mm256_setzero_ps(); + __m128i _zero_point_0 = _mm_setzero_si128(); __m256i zero_point_0 = _mm256_setzero_si256(); __m256 res_1, res_2; - float float_buf[16]; + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + float float_buf[16] = { 0 }; - memcpy( float_buf, ( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( float ) ) ); + memcpy( float_buf, ( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( float ) ) ); - // Load the scale vector values into the register - scale_1 = _mm256_loadu_ps(float_buf + (0 * 8)); - scale_2 = _mm256_loadu_ps(float_buf + (1 * 8)); + // Load the scale vector values into the register + scale_1 = _mm256_loadu_ps(float_buf + (0 * 8)); + scale_2 = _mm256_loadu_ps(float_buf + (1 * 8)); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + /* Broadcast scale factor. */ + scale_1 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + scale_2 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } - if ( post_ops_attr.c_stor_type == S8 ) + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) { - int8_t zero_point_buf[16]; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t zero_point_buf[16]; - memcpy( zero_point_buf, ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( int8_t ) ) ); - _zero_point_0 = _mm_loadu_si128( ( __m128i const* )zero_point_buf ); + memcpy( zero_point_buf, ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( int8_t ) ) ); + _zero_point_0 = _mm_loadu_si128( ( __m128i const* )zero_point_buf ); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + uint8_t zero_point_buf[16]; + + memcpy( zero_point_buf, ( ( uint8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( uint8_t ) ) ); + _zero_point_0 = _mm_loadu_si128( ( __m128i const* )zero_point_buf ); + } + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + // Broadcast zero point. + _zero_point_0 = _mm_set1_epi8( + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } + + if ( post_ops_attr.c_stor_type == S8 ) + { zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); } else if ( post_ops_attr.c_stor_type == U8 ) { - uint8_t zero_point_buf[16]; - - memcpy( zero_point_buf, ( ( uint8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( uint8_t ) ) ); - _zero_point_0 = _mm_loadu_si128( ( __m128i const* )zero_point_buf ); zero_point_0 = _mm256_cvtepu8_epi16( _zero_point_0 ); } @@ -1136,6 +1289,109 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6xlt16) CVT_MULRND_CVT16(c_int16_4p0, scale_1, scale_2, zero_point_0) CVT_MULRND_CVT16(c_int16_5p0, scale_1, scale_2, zero_point_0) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_6xlt16: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S8_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,0,n0_rem,int8_t); + + // c[1:0-15] + S8_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,1,n0_rem,int8_t); + + // c[2:0-15] + S8_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,2,n0_rem,int8_t); + + // c[3:0-15] + S8_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,3,n0_rem,int8_t); + + // c[4:0-15] + S8_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,4,n0_rem,int8_t); + + // c[5:0-15] + S8_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,5,n0_rem,int8_t); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + uint8_t* matptr = ( uint8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + U8_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,0,n0_rem,uint8_t); + + // c[1:0-15] + U8_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,1,n0_rem,uint8_t); + + // c[2:0-15] + U8_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,2,n0_rem,uint8_t); + + // c[3:0-15] + U8_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,3,n0_rem,uint8_t); + + // c[4:0-15] + U8_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,4,n0_rem,uint8_t); + + // c[5:0-15] + U8_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,5,n0_rem,uint8_t); + } + else + { + int16_t* matptr = ( int16_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S16_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,0,n0_rem,int16_t); + + // c[1:0-15] + S16_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,1,n0_rem,int16_t); + + // c[2:0-15] + S16_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,2,n0_rem,int16_t); + + // c[3:0-15] + S16_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,3,n0_rem,int16_t); + + // c[4:0-15] + S16_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,4,n0_rem,int16_t); + + // c[5:0-15] + S16_S16_MATRIX_ADD_1COL_PAR(buf0,selector1,5,n0_rem,int16_t); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_6xlt16: + { + selector1 = + _mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args2 ) ); + __m256 al = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32( \ + _mm256_extractf128_si256( selector1, 0 ) ) ); + + __m256 al_in, tmp_reg1, tmp_reg2, r, r2, z, dn; + __m256i ex_out; + + // c[0,0-15] + SWISH_S16_AVX2(c_int16_0p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[1,0-15] + SWISH_S16_AVX2(c_int16_1p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[2,0-15] + SWISH_S16_AVX2(c_int16_2p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[3,0-15] + SWISH_S16_AVX2(c_int16_3p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[4,0-15] + SWISH_S16_AVX2(c_int16_4p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + + // c[5,0-15] + SWISH_S16_AVX2(c_int16_5p0, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_6xlt16_DISABLE: diff --git a/kernels/zen/lpgemm/u8s8s16/lpgemm_packa_amd256.c b/kernels/zen/lpgemm/u8s8s16/lpgemm_packa_amd256.c new file mode 100644 index 0000000000..3394e1cfd3 --- /dev/null +++ b/kernels/zen/lpgemm/u8s8s16/lpgemm_packa_amd256.c @@ -0,0 +1,1314 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "blis.h" + +#ifdef BLIS_ADDON_LPGEMM + +void packa_mr16_u8s8s16o16_col_major + ( + uint8_t* pack_a_buffer_u8s8s16o16, + const uint8_t* a, + const dim_t rs, + const dim_t cs, + const dim_t MC, + const dim_t KC, + dim_t* rs_a, + dim_t* cs_a + ); + +void packa_u8s8s16os16 + ( + uint8_t* pack_a_buffer_u8s8s16o16, + const uint8_t* a, + const dim_t rs, + const dim_t cs, + const dim_t MC, + const dim_t KC, + dim_t* rs_a, + dim_t* cs_a + ) +{ + if( ( cs == 1 ) && ( MC != 1 ) ) + { + // Not yet supported + } + else + { + packa_mr16_u8s8s16o16_col_major + ( pack_a_buffer_u8s8s16o16, a, rs, cs, MC, KC, rs_a, cs_a ); + } +} + +#define SET_REGISTERS_ZERO \ + a_reg[0] = _mm_setzero_si128(); \ + a_reg[1] = _mm_setzero_si128(); \ + a_reg[2] = _mm_setzero_si128(); \ + a_reg[3] = _mm_setzero_si128(); \ + a_reg[4] = _mm_setzero_si128(); \ + a_reg[5] = _mm_setzero_si128(); \ + a_reg[6] = _mm_setzero_si128(); \ + a_reg[7] = _mm_setzero_si128(); \ + a_reg[8] = _mm_setzero_si128(); \ + a_reg[9] = _mm_setzero_si128(); \ + a_reg[10] = _mm_setzero_si128(); \ + a_reg[11] = _mm_setzero_si128(); \ + a_reg[12] = _mm_setzero_si128(); \ + a_reg[13] = _mm_setzero_si128(); \ + a_reg[14] = _mm_setzero_si128(); \ + a_reg[15] = _mm_setzero_si128(); + +#define UNPACKLOW_EPI8 \ + b_reg[0] = _mm_unpacklo_epi8( a_reg[0], a_reg[1] ); \ + b_reg[1] = _mm_unpacklo_epi8( a_reg[2], a_reg[3] ); \ + b_reg[2] = _mm_unpacklo_epi8( a_reg[4], a_reg[5] ); \ + b_reg[3] = _mm_unpacklo_epi8( a_reg[6], a_reg[7] ); \ + b_reg[4] = _mm_unpacklo_epi8( a_reg[8], a_reg[9] ); \ + b_reg[5] = _mm_unpacklo_epi8( a_reg[10], a_reg[11] ); \ + b_reg[6] = _mm_unpacklo_epi8( a_reg[12], a_reg[13] ); \ + b_reg[7] = _mm_unpacklo_epi8( a_reg[14], a_reg[15] ); + +#define UNPACKHI_EPI8 \ + b_reg[8] = _mm_unpackhi_epi8( a_reg[0], a_reg[1] ); \ + b_reg[9] = _mm_unpackhi_epi8( a_reg[2], a_reg[3] ); \ + b_reg[10] = _mm_unpackhi_epi8( a_reg[4], a_reg[5] ); \ + b_reg[11] = _mm_unpackhi_epi8( a_reg[6], a_reg[7] ); \ + b_reg[12] = _mm_unpackhi_epi8( a_reg[8], a_reg[9] ); \ + b_reg[13] = _mm_unpackhi_epi8( a_reg[10], a_reg[11] ); \ + b_reg[14] = _mm_unpackhi_epi8( a_reg[12], a_reg[13] ); \ + b_reg[15] = _mm_unpackhi_epi8( a_reg[14], a_reg[15] ); + +#define UNPACKLOW_EPI16 \ + a_reg[0] = _mm_unpacklo_epi16( b_reg[0], b_reg[1] ); \ + a_reg[1] = _mm_unpacklo_epi16( b_reg[2], b_reg[3] ); \ + a_reg[2] = _mm_unpacklo_epi16( b_reg[4], b_reg[5] ); \ + a_reg[3] = _mm_unpacklo_epi16( b_reg[6], b_reg[7] ); \ +\ + a_reg[8] = _mm_unpacklo_epi16( b_reg[8], b_reg[9] ); \ + a_reg[9] = _mm_unpacklo_epi16( b_reg[10], b_reg[11] ); \ + a_reg[10] = _mm_unpacklo_epi16( b_reg[12], b_reg[13] ); \ + a_reg[11] = _mm_unpacklo_epi16( b_reg[14], b_reg[15] ); + +#define UNPACKHI_EPI16 \ + a_reg[4] = _mm_unpackhi_epi16( b_reg[0], b_reg[1] ); \ + a_reg[5] = _mm_unpackhi_epi16( b_reg[2], b_reg[3] ); \ + a_reg[6] = _mm_unpackhi_epi16( b_reg[4], b_reg[5] ); \ + a_reg[7] = _mm_unpackhi_epi16( b_reg[6], b_reg[7] ); \ +\ + a_reg[12] = _mm_unpackhi_epi16( b_reg[8], b_reg[9] ); \ + a_reg[13] = _mm_unpackhi_epi16( b_reg[10], b_reg[11] ); \ + a_reg[14] = _mm_unpackhi_epi16( b_reg[12], b_reg[13] ); \ + a_reg[15] = _mm_unpackhi_epi16( b_reg[14], b_reg[15] ); + +#define UNPACKLOW_EPI32 \ + b_reg[0] = _mm_unpacklo_epi32( a_reg[0], a_reg[1] ); \ + b_reg[1] = _mm_unpacklo_epi32( a_reg[2], a_reg[3] ); \ + b_reg[2] = _mm_unpacklo_epi32( a_reg[4], a_reg[5] ); \ + b_reg[3] = _mm_unpacklo_epi32( a_reg[6], a_reg[7] ); \ +\ + b_reg[8] = _mm_unpacklo_epi32( a_reg[8], a_reg[9] ); \ + b_reg[9] = _mm_unpacklo_epi32( a_reg[10], a_reg[11] ); \ + b_reg[10] = _mm_unpacklo_epi32( a_reg[12], a_reg[13] ); \ + b_reg[11] = _mm_unpacklo_epi32( a_reg[14], a_reg[15] ); + +#define UNPACKHI_EPI32 \ + b_reg[4] = _mm_unpackhi_epi32( a_reg[0], a_reg[1] ); \ + b_reg[5] = _mm_unpackhi_epi32( a_reg[2], a_reg[3] ); \ + b_reg[6] = _mm_unpackhi_epi32( a_reg[4], a_reg[5] ); \ + b_reg[7] = _mm_unpackhi_epi32( a_reg[6], a_reg[7] ); \ +\ + b_reg[12] = _mm_unpackhi_epi32( a_reg[8], a_reg[9] ); \ + b_reg[13] = _mm_unpackhi_epi32( a_reg[10], a_reg[11] ); \ + b_reg[14] = _mm_unpackhi_epi32( a_reg[12], a_reg[13] ); \ + b_reg[15] = _mm_unpackhi_epi32( a_reg[14], a_reg[15] ); + +#define UNPACKLOW_EPI64 \ + a_reg[0] = _mm_unpacklo_epi64( b_reg[0], b_reg[1] ); \ + a_reg[2] = _mm_unpacklo_epi64( b_reg[2], b_reg[3] ); \ + a_reg[4] = _mm_unpacklo_epi64( b_reg[4], b_reg[5] ); \ + a_reg[6] = _mm_unpacklo_epi64( b_reg[6], b_reg[7]) ; \ +\ + a_reg[8] = _mm_unpacklo_epi64( b_reg[8], b_reg[9] ); \ + a_reg[10] = _mm_unpacklo_epi64( b_reg[10], b_reg[11] ); \ + a_reg[12] = _mm_unpacklo_epi64( b_reg[12], b_reg[13] ); \ + a_reg[14] = _mm_unpacklo_epi64( b_reg[14], b_reg[15] ); + +#define UNPACKHI_EPI64 \ + a_reg[1] = _mm_unpackhi_epi64( b_reg[0], b_reg[1] ); \ + a_reg[3] = _mm_unpackhi_epi64( b_reg[2], b_reg[3] ); \ + a_reg[5] = _mm_unpackhi_epi64( b_reg[4], b_reg[5] ); \ + a_reg[7] = _mm_unpackhi_epi64( b_reg[6], b_reg[7] ); \ +\ + a_reg[9] = _mm_unpackhi_epi64( b_reg[8], b_reg[9] ); \ + a_reg[11] = _mm_unpackhi_epi64( b_reg[10], b_reg[11] ); \ + a_reg[13] = _mm_unpackhi_epi64( b_reg[12], b_reg[13] ); \ + a_reg[15] = _mm_unpackhi_epi64( b_reg[14], b_reg[15] ); + +#define UNPACKLOW_EPI16_MR8 \ + a_reg[0] = _mm_unpacklo_epi16( b_reg[0], b_reg[1] ); \ + a_reg[1] = _mm_unpacklo_epi16( b_reg[2], b_reg[3] ); \ + a_reg[2] = _mm_unpacklo_epi16( b_reg[4], b_reg[5] ); \ + a_reg[3] = _mm_unpacklo_epi16( b_reg[6], b_reg[7] ); + +#define UNPACKHI_EPI16_MR8 \ + a_reg[4] = _mm_unpackhi_epi16( b_reg[0], b_reg[1] ); \ + a_reg[5] = _mm_unpackhi_epi16( b_reg[2], b_reg[3] ); \ + a_reg[6] = _mm_unpackhi_epi16( b_reg[4], b_reg[5] ); \ + a_reg[7] = _mm_unpackhi_epi16( b_reg[6], b_reg[7] ); + +#define UNPACKLOW_EPI32_MR8 \ + b_reg[0] = _mm_unpacklo_epi32( a_reg[0], a_reg[1] ); \ + b_reg[1] = _mm_unpacklo_epi32( a_reg[2], a_reg[3] ); \ + b_reg[2] = _mm_unpacklo_epi32( a_reg[4], a_reg[5] ); \ + b_reg[3] = _mm_unpacklo_epi32( a_reg[6], a_reg[7] ); + +#define UNPACKHI_EPI32_MR8 \ + b_reg[4] = _mm_unpackhi_epi32( a_reg[0], a_reg[1] ); \ + b_reg[5] = _mm_unpackhi_epi32( a_reg[2], a_reg[3] ); \ + b_reg[6] = _mm_unpackhi_epi32( a_reg[4], a_reg[5] ); \ + b_reg[7] = _mm_unpackhi_epi32( a_reg[6], a_reg[7] ); + +#define UNPACKLOW_EPI64_MR8 \ + a_reg[0] = _mm_unpacklo_epi64( b_reg[0], b_reg[1] ); \ + a_reg[2] = _mm_unpacklo_epi64( b_reg[2], b_reg[3] ); \ + a_reg[4] = _mm_unpacklo_epi64( b_reg[4], b_reg[5] ); \ + a_reg[6] = _mm_unpacklo_epi64( b_reg[6], b_reg[7] ); + +#define UNPACKHI_EPI64_MR8 \ + a_reg[1] = _mm_unpackhi_epi64( b_reg[0], b_reg[1] ); \ + a_reg[3] = _mm_unpackhi_epi64( b_reg[2], b_reg[3] ); \ + a_reg[5] = _mm_unpackhi_epi64( b_reg[4], b_reg[5] ); \ + a_reg[7] = _mm_unpackhi_epi64( b_reg[6], b_reg[7] ); + +#define UNPACKLOW_EPI32_MR4 \ + b_reg[0] = _mm_unpacklo_epi32( a_reg[0], a_reg[1] ); \ + b_reg[1] = _mm_unpacklo_epi32( a_reg[2], a_reg[3] ); + +#define UNPACKHI_EPI32_MR4 \ + b_reg[4] = _mm_unpackhi_epi32( a_reg[0], a_reg[1] ); \ + b_reg[5] = _mm_unpackhi_epi32( a_reg[2], a_reg[3] ); + +#define UNPACKLOW_EPI64_MR4 \ + a_reg[0] = _mm_unpacklo_epi64( b_reg[0], b_reg[1] ); \ + a_reg[4] = _mm_unpacklo_epi64( b_reg[4], b_reg[5] ); + +#define UNPACKHI_EPI64_MR4 \ + a_reg[1] = _mm_unpackhi_epi64( b_reg[0], b_reg[1] ); \ + a_reg[5] = _mm_unpackhi_epi64( b_reg[4], b_reg[5] ); + +#define MASKED_STORE_EPI32(mask) \ + _mm_maskstore_epi32( ( int* ) ( pack_a_buffer_u8s8s16o16 + ( ic + 0 ) * KC + kr ), mask, a_reg[0] ); \ + _mm_maskstore_epi32( ( int* ) ( pack_a_buffer_u8s8s16o16 + ( ic + 1 ) * KC + kr ), mask, a_reg[1] ); \ + _mm_maskstore_epi32( ( int* ) ( pack_a_buffer_u8s8s16o16 + ( ic + 2 ) * KC + kr ), mask, a_reg[4] ); \ + _mm_maskstore_epi32( ( int* ) ( pack_a_buffer_u8s8s16o16 + ( ic + 3 ) * KC + kr ), mask, a_reg[5] ); \ + _mm_maskstore_epi32( ( int* ) ( pack_a_buffer_u8s8s16o16 + ( ic + 4 ) * KC + kr ), mask, a_reg[2] ); \ + _mm_maskstore_epi32( ( int* ) ( pack_a_buffer_u8s8s16o16 + ( ic + 5 ) * KC + kr ), mask, a_reg[3] ); \ + _mm_maskstore_epi32( ( int* ) ( pack_a_buffer_u8s8s16o16 + ( ic + 6 ) * KC + kr ), mask, a_reg[6] ); \ + _mm_maskstore_epi32( ( int* ) ( pack_a_buffer_u8s8s16o16 + ( ic + 7 ) * KC + kr ), mask, a_reg[7] ); \ + _mm_maskstore_epi32( ( int* ) ( pack_a_buffer_u8s8s16o16 + ( ic + 8 ) * KC + kr ), mask, a_reg[8] ); \ + _mm_maskstore_epi32( ( int* ) ( pack_a_buffer_u8s8s16o16 + ( ic + 9 ) * KC + kr ), mask, a_reg[9] ); \ + _mm_maskstore_epi32( ( int* ) ( pack_a_buffer_u8s8s16o16 + ( ic + 10 ) * KC + kr ), mask, a_reg[12] ); \ + _mm_maskstore_epi32( ( int* ) ( pack_a_buffer_u8s8s16o16 + ( ic + 11 ) * KC + kr ), mask, a_reg[13] ); \ + _mm_maskstore_epi32( ( int* ) ( pack_a_buffer_u8s8s16o16 + ( ic + 12 ) * KC + kr ), mask, a_reg[10] ); \ + _mm_maskstore_epi32( ( int* ) ( pack_a_buffer_u8s8s16o16 + ( ic + 13 ) * KC + kr ), mask, a_reg[11] ); \ + _mm_maskstore_epi32( ( int* ) ( pack_a_buffer_u8s8s16o16 + ( ic + 14 ) * KC + kr ), mask, a_reg[14] ); \ + _mm_maskstore_epi32( ( int* ) ( pack_a_buffer_u8s8s16o16 + ( ic + 15 ) * KC + kr ), mask, a_reg[15] ); + +// Column-major transformation to row-major in blocks of MCxKC + +void packa_mr8_u8s8s16o16_col_major + ( + uint8_t* pack_a_buffer_u8s8s16o16, + const uint8_t* a, + const dim_t cs, + const dim_t KC + ); + +void packa_mr4_u8s8s16o16_col_major + ( + uint8_t* pack_a_buffer_u8s8s16o16, + const uint8_t* a, + const dim_t cs, + const dim_t KC + ); + +void packa_mrlt4_u8s8s16o16_col_major + ( + uint8_t* pack_a_buffer_u8s8s16o16, + const uint8_t* a, + const dim_t cs, + const dim_t KC, + const dim_t m_left + ); + +void packa_mr16_u8s8s16o16_col_major + ( + uint8_t* pack_a_buffer_u8s8s16o16, + const uint8_t* a, + const dim_t rs, + const dim_t cs, + const dim_t MC, + const dim_t KC, + dim_t* rs_a, + dim_t* cs_a + ) +{ + dim_t mr = 16; + __m128i a_reg[16], b_reg[16]; + + dim_t m_partial_pieces = MC % mr; + dim_t k_partial_pieces = KC % 16; + dim_t m_left = MC % 4; + __m128i mask; + + SET_REGISTERS_ZERO + + dim_t ic, kr; + + for ( ic =0; ( ic + mr - 1 ) < MC; ic += mr ) + { + for ( kr = 0; ( kr + 15 ) < KC; kr += 16 ) + { + a_reg[0] = _mm_loadu_si128 ( (__m128i const *) ( a + ( ic * rs ) + ( ( kr + 0 ) * cs ) ) ); + a_reg[1] = _mm_loadu_si128 ( (__m128i const *) ( a + ( ic * rs ) + ( ( kr + 1 ) * cs ) ) ); + a_reg[2] = _mm_loadu_si128 ( (__m128i const *) ( a + ( ic * rs ) + ( ( kr + 2 ) * cs ) ) ); + a_reg[3] = _mm_loadu_si128 ( (__m128i const *) ( a + ( ic * rs ) + ( ( kr + 3 ) * cs ) ) ); + a_reg[4] = _mm_loadu_si128 ( (__m128i const *) ( a + ( ic * rs ) + ( ( kr + 4 ) * cs ) ) ); + a_reg[5] = _mm_loadu_si128 ( (__m128i const *) ( a + ( ic * rs ) + ( ( kr + 5 ) * cs ) ) ); + a_reg[6] = _mm_loadu_si128 ( (__m128i const *) ( a + ( ic * rs ) + ( ( kr + 6 ) * cs ) ) ); + a_reg[7] = _mm_loadu_si128 ( (__m128i const *) ( a + ( ic * rs ) + ( ( kr + 7 ) * cs ) ) ); + a_reg[8] = _mm_loadu_si128 ( (__m128i const *) ( a + ( ic * rs ) + ( ( kr + 8 ) * cs ) ) ); + a_reg[9] = _mm_loadu_si128 ( (__m128i const *) ( a + ( ic * rs ) + ( ( kr + 9 ) * cs ) ) ); + a_reg[10] = _mm_loadu_si128 ( (__m128i const *) ( a + ( ic * rs ) + ( ( kr + 10 ) * cs ) ) ); + a_reg[11] = _mm_loadu_si128 ( (__m128i const *) ( a + ( ic * rs ) + ( ( kr + 11 ) * cs ) ) ); + a_reg[12] = _mm_loadu_si128 ( (__m128i const *) ( a + ( ic * rs ) + ( ( kr + 12 ) * cs ) ) ); + a_reg[13] = _mm_loadu_si128 ( (__m128i const *) ( a + ( ic * rs ) + ( ( kr + 13 ) * cs ) ) ); + a_reg[14] = _mm_loadu_si128 ( (__m128i const *) ( a + ( ic * rs ) + ( ( kr + 14 ) * cs ) ) ); + a_reg[15] = _mm_loadu_si128 ( (__m128i const *) ( a + ( ic * rs ) + ( ( kr + 15 ) * cs ) ) ); + + // Transpose operations + UNPACKLOW_EPI8 + UNPACKHI_EPI8 + + UNPACKLOW_EPI16 + UNPACKHI_EPI16 + + UNPACKLOW_EPI32 + UNPACKHI_EPI32 + + UNPACKLOW_EPI64 + UNPACKHI_EPI64 + + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s16o16 + ( ic + 0 ) * KC + kr ), a_reg[0] ); + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s16o16 + ( ic + 1 ) * KC + kr ), a_reg[1] ); + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s16o16 + ( ic + 2 ) * KC + kr ), a_reg[4] ); + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s16o16 + ( ic + 3 ) * KC + kr ), a_reg[5] ); + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s16o16 + ( ic + 4 ) * KC + kr ), a_reg[2] ); + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s16o16 + ( ic + 5 ) * KC + kr ), a_reg[3] ); + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s16o16 + ( ic + 6 ) * KC + kr ), a_reg[6] ); + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s16o16 + ( ic + 7 ) * KC + kr ), a_reg[7] ); + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s16o16 + ( ic + 8 ) * KC + kr ), a_reg[8] ); + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s16o16 + ( ic + 9 ) * KC + kr ), a_reg[9] ); + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s16o16 + ( ic + 10 ) * KC + kr ), a_reg[12] ); + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s16o16 + ( ic + 11 ) * KC + kr ), a_reg[13] ); + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s16o16 + ( ic + 12 ) * KC + kr ), a_reg[10] ); + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s16o16 + ( ic + 13 ) * KC + kr ), a_reg[11] ); + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s16o16 + ( ic + 14 ) * KC + kr ), a_reg[14] ); + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s16o16 + ( ic + 15 ) * KC + kr ), a_reg[15] ); + + } + + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + // k fringe 8 + if (( kr + 7 ) < KC ) + { + a_reg[0] = _mm_loadu_si128 ( (__m128i const *) ( a + ( ic * rs ) + ( ( kr + 0 ) * cs ) ) ); + a_reg[1] = _mm_loadu_si128 ( (__m128i const *) ( a + ( ic * rs ) + ( ( kr + 1 ) * cs ) ) ); + a_reg[2] = _mm_loadu_si128 ( (__m128i const *) ( a + ( ic * rs ) + ( ( kr + 2 ) * cs ) ) ); + a_reg[3] = _mm_loadu_si128 ( (__m128i const *) ( a + ( ic * rs ) + ( ( kr + 3 ) * cs ) ) ); + a_reg[4] = _mm_loadu_si128 ( (__m128i const *) ( a + ( ic * rs ) + ( ( kr + 4 ) * cs ) ) ); + a_reg[5] = _mm_loadu_si128 ( (__m128i const *) ( a + ( ic * rs ) + ( ( kr + 5 ) * cs ) ) ); + a_reg[6] = _mm_loadu_si128 ( (__m128i const *) ( a + ( ic * rs ) + ( ( kr + 6 ) * cs ) ) ); + a_reg[7] = _mm_loadu_si128 ( (__m128i const *) ( a + ( ic * rs ) + ( ( kr + 7 ) * cs ) ) ); + + // Transpose operations + UNPACKLOW_EPI8 + UNPACKHI_EPI8 + + UNPACKLOW_EPI16 + UNPACKHI_EPI16 + + UNPACKLOW_EPI32 + UNPACKHI_EPI32 + + UNPACKLOW_EPI64 + UNPACKHI_EPI64 + + mask = _mm_set_epi32 (0, 0, -1, -1); + + MASKED_STORE_EPI32(mask); + + kr += 8; + } + + // k fringe 4 + if ( ( kr + 3 ) < KC ) + { + a_reg[0] = _mm_loadu_si128( (__m128i const *)( a + ( ic * rs ) + ( ( kr + 0 ) * cs ) ) ); + a_reg[1] = _mm_loadu_si128( (__m128i const *)( a + ( ic * rs ) + ( ( kr + 1 ) * cs ) ) ); + a_reg[2] = _mm_loadu_si128( (__m128i const *)( a + ( ic * rs ) + ( ( kr + 2 ) * cs ) ) ); + a_reg[3] = _mm_loadu_si128( (__m128i const *)( a + ( ic * rs ) + ( ( kr + 3 ) * cs ) ) ); + + // Transpose operations + UNPACKLOW_EPI8 + UNPACKHI_EPI8 + + UNPACKLOW_EPI16 + UNPACKHI_EPI16 + + UNPACKLOW_EPI32 + UNPACKHI_EPI32 + + UNPACKLOW_EPI64 + UNPACKHI_EPI64 + + mask = _mm_set_epi32 (0, 0, 0, -1); + + MASKED_STORE_EPI32(mask); + + kr += 4; + } + + // k fringe 2 + if ( ( kr + 1 ) < KC ) + { + a_reg[0] = _mm_loadu_si128( (__m128i const *)( a + ( ic * rs ) + ( ( kr + 0 ) * cs ) ) ); + a_reg[1] = _mm_loadu_si128( (__m128i const *)( a + ( ic * rs ) + ( ( kr + 1 ) * cs ) ) ); + + // Transpose operations + UNPACKLOW_EPI8 + UNPACKHI_EPI8 + + UNPACKLOW_EPI16 + UNPACKHI_EPI16 + + UNPACKLOW_EPI32 + UNPACKHI_EPI32 + + UNPACKLOW_EPI64 + UNPACKHI_EPI64 + + uint8_t buf[16]; + dim_t n0_rem_bytes = 2 * sizeof( uint8_t ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[0] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + (ic+0) * KC + kr ), buf, n0_rem_bytes ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[1] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + (ic+1) * KC + kr ), buf, n0_rem_bytes ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[4] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + (ic+2) * KC + kr ), buf, n0_rem_bytes ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[5] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + (ic+3) * KC + kr ), buf, n0_rem_bytes ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[2] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + (ic+4) * KC + kr ), buf, n0_rem_bytes ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[3] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + (ic+5) * KC + kr ), buf, n0_rem_bytes ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[6] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + (ic+6) * KC + kr ), buf, n0_rem_bytes ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[7] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + (ic+7) * KC + kr ), buf, n0_rem_bytes ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[8] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + (ic+8) * KC + kr ), buf, n0_rem_bytes ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[9] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + (ic+9) * KC + kr ), buf, n0_rem_bytes ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[12] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + (ic+10) * KC + kr ), buf, n0_rem_bytes ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[13] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + (ic+11) * KC + kr ), buf, n0_rem_bytes ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[10] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + (ic+12) * KC + kr ), buf, n0_rem_bytes ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[11] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + (ic+13) * KC + kr ), buf, n0_rem_bytes ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[14] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + (ic+14) * KC + kr ), buf, n0_rem_bytes ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[15] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + (ic+15) * KC + kr ), buf, n0_rem_bytes ); + + kr += 2; + } + + // k fringe 1 + if ( ( kr ) < KC ) + { + a_reg[0] = _mm_loadu_si128( (__m128i const *)( a + ( ic * rs ) + ( ( kr + 0 ) * cs ) ) ); + + // Transpose operations + UNPACKLOW_EPI8 + UNPACKHI_EPI8 + + UNPACKLOW_EPI16 + UNPACKHI_EPI16 + + UNPACKLOW_EPI32 + UNPACKHI_EPI32 + + UNPACKLOW_EPI64 + UNPACKHI_EPI64 + + uint8_t buf[16]; + dim_t n0_rem_bytes = 1 * sizeof( uint8_t ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[0] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + (ic+0) * KC + kr ), buf, n0_rem_bytes ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[1] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + (ic+1) * KC + kr ), buf, n0_rem_bytes ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[4] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + (ic+2) * KC + kr ), buf, n0_rem_bytes ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[5] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + (ic+3) * KC + kr ), buf, n0_rem_bytes ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[2] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + (ic+4) * KC + kr ), buf, n0_rem_bytes ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[3] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + (ic+5) * KC + kr ), buf, n0_rem_bytes ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[6] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + (ic+6) * KC + kr ), buf, n0_rem_bytes ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[7] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + (ic+7) * KC + kr ), buf, n0_rem_bytes ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[8] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + (ic+8) * KC + kr ), buf, n0_rem_bytes ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[9] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + (ic+9) * KC + kr ), buf, n0_rem_bytes ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[12] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + (ic+10) * KC + kr ), buf, n0_rem_bytes ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[13] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + (ic+11) * KC + kr ), buf, n0_rem_bytes ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[10] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + (ic+12) * KC + kr ), buf, n0_rem_bytes ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[11] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + (ic+13) * KC + kr ), buf, n0_rem_bytes ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[14] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + (ic+14) * KC + kr ), buf, n0_rem_bytes ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[15] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + (ic+15) * KC + kr ), buf, n0_rem_bytes ); + + kr += 1; + } + } + } + + if( m_partial_pieces > 0 ) + { + if ( ( ic + 8 - 1 ) < MC ) + { + packa_mr8_u8s8s16o16_col_major + ( + ( pack_a_buffer_u8s8s16o16 + ( ic * KC ) ), + ( a + ic * rs ), cs, KC + ); + + ic += 8; + } + + if ( ( ic + 4 - 1 ) < MC ) + { + packa_mr4_u8s8s16o16_col_major + ( + ( pack_a_buffer_u8s8s16o16 + ( ic * KC ) ), + ( a + ic * rs ), cs, KC + ); + + ic += 4; + } + + if ( m_left ) + { + packa_mrlt4_u8s8s16o16_col_major + ( + ( pack_a_buffer_u8s8s16o16 + ( ic * KC ) ), + ( a + ic * rs ), cs, KC, m_left + ); + } + } + + *rs_a = KC; + *cs_a = 1; +} + +void packa_mr8_u8s8s16o16_col_major + ( + uint8_t* pack_a_buffer_u8s8s16o16, + const uint8_t* a, + const dim_t cs, + const dim_t KC + ) +{ + dim_t kr = 0; + __m128i a_reg[16], b_reg[16]; + + dim_t k_partial_pieces = KC % 16; + __m128i mask; + + SET_REGISTERS_ZERO + + for( kr = 0; ( kr + 15 ) < KC; kr += 16 ) + { + mask = _mm_set_epi32 (0, 0, -1, -1); + + a_reg[0] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 0 ) * cs ) ), mask ); + a_reg[1] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 1 ) * cs ) ), mask ); + a_reg[2] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 2 ) * cs ) ), mask ); + a_reg[3] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 3 ) * cs ) ), mask ); + a_reg[4] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 4 ) * cs ) ), mask ); + a_reg[5] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 5 ) * cs ) ), mask ); + a_reg[6] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 6 ) * cs ) ), mask ); + a_reg[7] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 7 ) * cs ) ), mask ); + a_reg[8] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 8 ) * cs ) ), mask ); + a_reg[9] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 9 ) * cs ) ), mask ); + a_reg[10] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 10 ) * cs ) ), mask ); + a_reg[11] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 11 ) * cs ) ), mask ); + a_reg[12] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 12 ) * cs ) ), mask ); + a_reg[13] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 13 ) * cs ) ), mask ); + a_reg[14] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 14 ) * cs ) ), mask ); + a_reg[15] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 15 ) * cs ) ), mask ); + + // Transpose operations + UNPACKLOW_EPI8 + + UNPACKLOW_EPI16_MR8 + UNPACKHI_EPI16_MR8 + + UNPACKLOW_EPI32_MR8 + UNPACKHI_EPI32_MR8 + + UNPACKLOW_EPI64_MR8 + UNPACKHI_EPI64_MR8 + + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s16o16 + ( 0 ) * KC + kr ), a_reg[0] ); + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s16o16 + ( 1 ) * KC + kr ), a_reg[1] ); + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s16o16 + ( 2 ) * KC + kr ), a_reg[4] ); + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s16o16 + ( 3 ) * KC + kr ), a_reg[5] ); + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s16o16 + ( 4 ) * KC + kr ), a_reg[2] ); + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s16o16 + ( 5 ) * KC + kr ), a_reg[3] ); + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s16o16 + ( 6 ) * KC + kr ), a_reg[6] ); + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s16o16 + ( 7 ) * KC + kr ), a_reg[7] ); + } + + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + // k fringe 8 + if ( ( kr + 7 ) < KC ) + { + mask = _mm_set_epi32 (0, 0, -1, -1); + + a_reg[0] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 0 ) * cs ) ), mask ); + a_reg[1] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 1 ) * cs ) ), mask ); + a_reg[2] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 2 ) * cs ) ), mask ); + a_reg[3] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 3 ) * cs ) ), mask ); + a_reg[4] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 4 ) * cs ) ), mask ); + a_reg[5] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 5 ) * cs ) ), mask ); + a_reg[6] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 6 ) * cs ) ), mask ); + a_reg[7] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 7 ) * cs ) ), mask ); + + // Transpose operations + UNPACKLOW_EPI8 + + UNPACKLOW_EPI16_MR8 + UNPACKHI_EPI16_MR8 + + UNPACKLOW_EPI32_MR8 + UNPACKHI_EPI32_MR8 + + UNPACKLOW_EPI64_MR8 + UNPACKHI_EPI64_MR8 + + _mm_maskstore_epi32( ( int* ) ( pack_a_buffer_u8s8s16o16 + ( 0 ) * KC + kr ), mask, a_reg[0] ); + _mm_maskstore_epi32( ( int* ) ( pack_a_buffer_u8s8s16o16 + ( 1 ) * KC + kr ), mask, a_reg[1] ); + _mm_maskstore_epi32( ( int* ) ( pack_a_buffer_u8s8s16o16 + ( 2 ) * KC + kr ), mask, a_reg[4] ); + _mm_maskstore_epi32( ( int* ) ( pack_a_buffer_u8s8s16o16 + ( 3 ) * KC + kr ), mask, a_reg[5] ); + _mm_maskstore_epi32( ( int* ) ( pack_a_buffer_u8s8s16o16 + ( 4 ) * KC + kr ), mask, a_reg[2] ); + _mm_maskstore_epi32( ( int* ) ( pack_a_buffer_u8s8s16o16 + ( 5 ) * KC + kr ), mask, a_reg[3] ); + _mm_maskstore_epi32( ( int* ) ( pack_a_buffer_u8s8s16o16 + ( 6 ) * KC + kr ), mask, a_reg[6] ); + _mm_maskstore_epi32( ( int* ) ( pack_a_buffer_u8s8s16o16 + ( 7 ) * KC + kr ), mask, a_reg[7] ); + + kr += 8; + } + + // k fringe 4 + if ( ( kr + 3 ) < KC ) + { + mask = _mm_set_epi32 (0, 0, -1, -1); + + a_reg[0] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 0 ) * cs ) ), mask ); + a_reg[1] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 1 ) * cs ) ), mask ); + a_reg[2] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 2 ) * cs ) ), mask ); + a_reg[3] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 3 ) * cs ) ), mask ); + + // Transpose operations + UNPACKLOW_EPI8 + + UNPACKLOW_EPI16_MR8 + UNPACKHI_EPI16_MR8 + + UNPACKLOW_EPI32_MR8 + UNPACKHI_EPI32_MR8 + + UNPACKLOW_EPI64_MR8 + UNPACKHI_EPI64_MR8 + + mask = _mm_set_epi32 (0, 0, 0, -1); + + _mm_maskstore_epi32( ( int* )( pack_a_buffer_u8s8s16o16 + ( 0 ) * KC + kr ), mask, a_reg[0] ); + _mm_maskstore_epi32( ( int* )( pack_a_buffer_u8s8s16o16 + ( 1 ) * KC + kr ), mask, a_reg[1] ); + _mm_maskstore_epi32( ( int* )( pack_a_buffer_u8s8s16o16 + ( 2 ) * KC + kr ), mask, a_reg[4] ); + _mm_maskstore_epi32( ( int* )( pack_a_buffer_u8s8s16o16 + ( 3 ) * KC + kr ), mask, a_reg[5] ); + _mm_maskstore_epi32( ( int* )( pack_a_buffer_u8s8s16o16 + ( 4 ) * KC + kr ), mask, a_reg[2] ); + _mm_maskstore_epi32( ( int* )( pack_a_buffer_u8s8s16o16 + ( 5 ) * KC + kr ), mask, a_reg[3] ); + _mm_maskstore_epi32( ( int* )( pack_a_buffer_u8s8s16o16 + ( 6 ) * KC + kr ), mask, a_reg[6] ); + _mm_maskstore_epi32( ( int* )( pack_a_buffer_u8s8s16o16 + ( 7 ) * KC + kr ), mask, a_reg[7] ); + + kr += 4; + } + + // k fringe 2 + if ( ( kr + 1 ) < KC ) + { + mask = _mm_set_epi32 (0, 0, -1, -1); + + a_reg[0] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 0 ) * cs ) ), mask ); + a_reg[1] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 1 ) * cs ) ), mask ); + + // Transpose operations + UNPACKLOW_EPI8 + + UNPACKLOW_EPI16_MR8 + UNPACKHI_EPI16_MR8 + + UNPACKLOW_EPI32_MR8 + UNPACKHI_EPI32_MR8 + + UNPACKLOW_EPI64_MR8 + UNPACKHI_EPI64_MR8 + + uint8_t buf[16]; + dim_t n0_rem_bytes = 2 * sizeof( uint8_t ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[0] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + ( 0 ) * KC + kr ), buf, n0_rem_bytes ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[1] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + ( 1 ) * KC + kr ), buf, n0_rem_bytes ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[4] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + ( 2 ) * KC + kr ), buf, n0_rem_bytes ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[5] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + ( 3 ) * KC + kr ), buf, n0_rem_bytes ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[2] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + ( 4 ) * KC + kr ), buf, n0_rem_bytes ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[3] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + ( 5 ) * KC + kr ), buf, n0_rem_bytes ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[6] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + ( 6 ) * KC + kr ), buf, n0_rem_bytes ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[7] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + ( 7 ) * KC + kr ), buf, n0_rem_bytes ); + + kr += 2; + + } + + // k fringe 1 + if ( ( kr ) < KC ) + { + mask = _mm_set_epi32 (0, 0, -1, -1); + + a_reg[0] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 0 ) * cs ) ), mask ); + + // Transpose operations + UNPACKLOW_EPI8 + + UNPACKLOW_EPI16_MR8 + UNPACKHI_EPI16_MR8 + + UNPACKLOW_EPI32_MR8 + UNPACKHI_EPI32_MR8 + + UNPACKLOW_EPI64_MR8 + UNPACKHI_EPI64_MR8 + + uint8_t buf[16]; + dim_t n0_rem_bytes = 1 * sizeof( uint8_t ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[0] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + ( 0 ) * KC + kr ), buf, n0_rem_bytes ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[1] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + ( 1 ) * KC + kr ), buf, n0_rem_bytes ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[4] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + ( 2 ) * KC + kr ), buf, n0_rem_bytes ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[5] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + ( 3 ) * KC + kr ), buf, n0_rem_bytes ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[2] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + ( 4 ) * KC + kr ), buf, n0_rem_bytes ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[3] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + ( 5 ) * KC + kr ), buf, n0_rem_bytes ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[6] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + ( 6 ) * KC + kr ), buf, n0_rem_bytes ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[7] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + ( 7 ) * KC + kr ), buf, n0_rem_bytes ); + + kr += 1; + } + } +} + + +void packa_mr4_u8s8s16o16_col_major + ( + uint8_t* pack_a_buffer_u8s8s16o16, + const uint8_t* a, + const dim_t cs, + const dim_t KC + ) +{ + dim_t kr = 0; + __m128i a_reg[16], b_reg[16]; + __m128i mask; + + SET_REGISTERS_ZERO + + dim_t k_partial_pieces = KC % 16; + + for( kr = 0; ( kr + 15 ) < KC; kr += 16 ) + { + mask = _mm_set_epi32 (0, -1, -1, -1); + + a_reg[0] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 0 ) * cs ) ), mask ); + a_reg[1] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 1 ) * cs ) ), mask ); + a_reg[2] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 2 ) * cs ) ), mask ); + a_reg[3] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 3 ) * cs ) ), mask ); + a_reg[4] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 4 ) * cs ) ), mask ); + a_reg[5] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 5 ) * cs ) ), mask ); + a_reg[6] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 6 ) * cs ) ), mask ); + a_reg[7] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 7 ) * cs ) ), mask ); + a_reg[8] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 8 ) * cs ) ), mask ); + a_reg[9] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 9 ) * cs ) ), mask ); + a_reg[10] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 10 ) * cs ) ), mask ); + a_reg[11] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 11 ) * cs ) ), mask ); + a_reg[12] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 12 ) * cs ) ), mask ); + a_reg[13] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 13 ) * cs ) ), mask ); + a_reg[14] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 14 ) * cs ) ), mask ); + a_reg[15] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 15 ) * cs ) ), mask ); + + // Transpose operations + UNPACKLOW_EPI8 + + UNPACKLOW_EPI16_MR8 + + UNPACKLOW_EPI32_MR4 + UNPACKHI_EPI32_MR4 + + UNPACKLOW_EPI64_MR4 + UNPACKHI_EPI64_MR4 + + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s16o16 + ( 0 ) * KC + kr ), a_reg[0] ); + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s16o16 + ( 1 ) * KC + kr ), a_reg[1] ); + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s16o16 + ( 2 ) * KC + kr ), a_reg[4] ); + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s16o16 + ( 3 ) * KC + kr ), a_reg[5] ); + } + + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + // k fringe 8 + if ( ( kr + 7 ) < KC ) + { + mask = _mm_set_epi32 (0, -1, -1, -1); + + a_reg[0] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 0 ) * cs ) ), mask ); + a_reg[1] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 1 ) * cs ) ), mask ); + a_reg[2] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 2 ) * cs ) ), mask ); + a_reg[3] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 3 ) * cs ) ), mask ); + a_reg[4] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 4 ) * cs ) ), mask ); + a_reg[5] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 5 ) * cs ) ), mask ); + a_reg[6] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 6 ) * cs ) ), mask ); + a_reg[7] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 7 ) * cs ) ), mask ); + + // Transpose operations + UNPACKLOW_EPI8 + + UNPACKLOW_EPI16_MR8 + + UNPACKLOW_EPI32_MR4 + UNPACKHI_EPI32_MR4 + + UNPACKLOW_EPI64_MR4 + UNPACKHI_EPI64_MR4 + + mask = _mm_set_epi32 (0, 0, -1, -1); + + _mm_maskstore_epi32( ( int* ) ( pack_a_buffer_u8s8s16o16 + ( 0 ) * KC + kr ), mask, a_reg[0] ); + _mm_maskstore_epi32( ( int* ) ( pack_a_buffer_u8s8s16o16 + ( 1 ) * KC + kr ), mask, a_reg[1] ); + _mm_maskstore_epi32( ( int* ) ( pack_a_buffer_u8s8s16o16 + ( 2 ) * KC + kr ), mask, a_reg[4] ); + _mm_maskstore_epi32( ( int* ) ( pack_a_buffer_u8s8s16o16 + ( 3 ) * KC + kr ), mask, a_reg[5] ); + + kr += 8; + } + + // k fringe 4 + if ( ( kr + 3 ) < KC ) + { + mask = _mm_set_epi32 (0, -1, -1, -1); + + a_reg[0] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 0 ) * cs ) ), mask ); + a_reg[1] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 1 ) * cs ) ), mask ); + a_reg[2] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 2 ) * cs ) ), mask ); + a_reg[3] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 3 ) * cs ) ), mask ); + + // Transpose operations + UNPACKLOW_EPI8 + + UNPACKLOW_EPI16_MR8 + + UNPACKLOW_EPI32_MR4 + UNPACKHI_EPI32_MR4 + + UNPACKLOW_EPI64_MR4 + UNPACKHI_EPI64_MR4 + + mask = _mm_set_epi32 (0, 0, 0, -1); + + _mm_maskstore_epi32( ( int* ) ( pack_a_buffer_u8s8s16o16 + ( 0 ) * KC + kr ), mask, a_reg[0] ); + _mm_maskstore_epi32( ( int* ) ( pack_a_buffer_u8s8s16o16 + ( 1 ) * KC + kr ), mask, a_reg[1] ); + _mm_maskstore_epi32( ( int* ) ( pack_a_buffer_u8s8s16o16 + ( 2 ) * KC + kr ), mask, a_reg[4] ); + _mm_maskstore_epi32( ( int* ) ( pack_a_buffer_u8s8s16o16 + ( 3 ) * KC + kr ), mask, a_reg[5] ); + + kr += 4; + } + + // k fringe 2 + if ( ( kr + 1 ) < KC ) + { + mask = _mm_set_epi32 (0, -1, -1, -1); + + a_reg[0] = _mm_maskload_epi32 ( (int const* ) ( a + ( ( kr + 0 ) * cs ) ), mask ); + a_reg[1] = _mm_maskload_epi32 ( (int const* ) ( a + ( ( kr + 1 ) * cs ) ), mask ); + + // Transpose operations + UNPACKLOW_EPI8 + + UNPACKLOW_EPI16_MR8 + + UNPACKLOW_EPI32_MR4 + UNPACKHI_EPI32_MR4 + + UNPACKLOW_EPI64_MR4 + UNPACKHI_EPI64_MR4 + + uint8_t buf[16]; + dim_t n0_rem_bytes = 2 * sizeof( uint8_t ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[0] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + ( 0 ) * KC + kr ), buf, n0_rem_bytes ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[1] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + ( 1 ) * KC + kr ), buf, n0_rem_bytes ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[4] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + ( 2 ) * KC + kr ), buf, n0_rem_bytes ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[5] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + ( 3 ) * KC + kr ), buf, n0_rem_bytes ); + + + kr += 2; + } + + // k fringe 1 + if ( ( kr ) < KC ) + { + mask = _mm_set_epi32 (0, -1, -1, -1); + + a_reg[0] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 0 ) * cs ) ), mask ); + + // Transpose operations + UNPACKLOW_EPI8 + + UNPACKLOW_EPI16_MR8 + + UNPACKLOW_EPI32_MR4 + UNPACKHI_EPI32_MR4 + + UNPACKLOW_EPI64_MR4 + UNPACKHI_EPI64_MR4 + + uint8_t buf[16]; + dim_t n0_rem_bytes = 1 * sizeof( uint8_t ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[0] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + ( 0 ) * KC + kr ), buf, n0_rem_bytes ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[1] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + ( 1 ) * KC + kr ), buf, n0_rem_bytes ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[4] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + ( 2 ) * KC + kr ), buf, n0_rem_bytes ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[5] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + ( 3 ) * KC + kr ), buf, n0_rem_bytes ); + + kr += 1; + } + } +} + +void packa_mrlt4_u8s8s16o16_col_major + ( + uint8_t* pack_a_buffer_u8s8s16o16, + const uint8_t* a, + const dim_t cs, + const dim_t KC, + const dim_t m_left + ) +{ + dim_t kr = 0; + __m128i a_reg[16], b_reg[16]; + __m128i mask; + + SET_REGISTERS_ZERO + + dim_t k_partial_pieces = KC % 16; + + for( kr = 0; ( kr + 15 ) < KC; kr += 16 ) + { + mask = _mm_set_epi32 (0, -1, -1, -1); + + a_reg[0] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 0 ) * cs ) ), mask ); + a_reg[1] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 1 ) * cs ) ), mask ); + a_reg[2] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 2 ) * cs ) ), mask ); + a_reg[3] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 3 ) * cs ) ), mask ); + a_reg[4] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 4 ) * cs ) ), mask ); + a_reg[5] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 5 ) * cs ) ), mask ); + a_reg[6] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 6 ) * cs ) ), mask ); + a_reg[7] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 7 ) * cs ) ), mask ); + a_reg[8] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 8 ) * cs ) ), mask ); + a_reg[9] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 9 ) * cs ) ), mask ); + a_reg[10] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 10 ) * cs ) ), mask ); + a_reg[11] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 11 ) * cs ) ), mask ); + a_reg[12] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 12 ) * cs ) ), mask ); + a_reg[13] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 13 ) * cs ) ), mask ); + a_reg[14] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 14 ) * cs ) ), mask ); + a_reg[15] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 15 ) * cs ) ), mask ); + + // Transpose operations + UNPACKLOW_EPI8 + + UNPACKLOW_EPI16_MR8 + + UNPACKLOW_EPI32_MR4 + UNPACKHI_EPI32_MR4 + + UNPACKLOW_EPI64_MR4 + UNPACKHI_EPI64_MR4 + + switch( m_left ) + { + case 3: + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s16o16 + ( 0 ) * KC + kr ), a_reg[0] ); + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s16o16 + ( 1 ) * KC + kr ), a_reg[1] ); + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s16o16 + ( 2 ) * KC + kr ), a_reg[4] ); + break; + + case 2: + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s16o16 + ( 0 ) * KC + kr ), a_reg[0] ); + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s16o16 + ( 1 ) * KC + kr ), a_reg[1] ); + break; + + case 1: + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s16o16 + ( 0 ) * KC + kr ), a_reg[0] ); + break; + } + } + + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + // k fringe 8 + if ( ( kr + 7 ) < KC ) + { + mask = _mm_set_epi32 (0, -1, -1, -1); + + a_reg[0] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 0 ) * cs ) ), mask ); + a_reg[1] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 1 ) * cs ) ), mask ); + a_reg[2] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 2 ) * cs ) ), mask ); + a_reg[3] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 3 ) * cs ) ), mask ); + a_reg[4] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 4 ) * cs ) ), mask ); + a_reg[5] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 5 ) * cs ) ), mask ); + a_reg[6] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 6 ) * cs ) ), mask ); + a_reg[7] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 7 ) * cs ) ), mask ); + + // Transpose operations + UNPACKLOW_EPI8 + + UNPACKLOW_EPI16_MR8 + + UNPACKLOW_EPI32_MR4 + UNPACKHI_EPI32_MR4 + + UNPACKLOW_EPI64_MR4 + UNPACKHI_EPI64_MR4 + + mask = _mm_set_epi32 (0, 0, -1, -1); + + switch( m_left ) + { + case 3: + _mm_maskstore_epi32( ( int* ) ( pack_a_buffer_u8s8s16o16 + (0) * KC + kr ), mask, a_reg[0] ); + _mm_maskstore_epi32( ( int* ) ( pack_a_buffer_u8s8s16o16 + (1) * KC + kr ), mask, a_reg[1] ); + _mm_maskstore_epi32( ( int* ) ( pack_a_buffer_u8s8s16o16 + (2) * KC + kr ), mask, a_reg[4] ); + break; + + case 2: + _mm_maskstore_epi32( ( int* ) ( pack_a_buffer_u8s8s16o16 + (0) * KC + kr ), mask, a_reg[0] ); + _mm_maskstore_epi32( ( int* ) ( pack_a_buffer_u8s8s16o16 + (1) * KC + kr ), mask, a_reg[1] ); + break; + + case 1: + _mm_maskstore_epi32( ( int* ) ( pack_a_buffer_u8s8s16o16 + (0) * KC + kr ), mask, a_reg[0] ); + break; + } + + kr += 8; + } + + // k fringe 4 + if ( ( kr + 3 ) < KC ) + { + mask = _mm_set_epi32 (0, -1, -1, -1); + + a_reg[0] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 0 ) * cs ) ), mask ); + a_reg[1] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 1 ) * cs ) ), mask ); + a_reg[2] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 2 ) * cs ) ), mask ); + a_reg[3] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 3 ) * cs ) ), mask ); + + // Transpose operations + UNPACKLOW_EPI8 + + UNPACKLOW_EPI16_MR8 + + UNPACKLOW_EPI32_MR4 + UNPACKHI_EPI32_MR4 + + UNPACKLOW_EPI64_MR4 + UNPACKHI_EPI64_MR4 + + mask = _mm_set_epi32 (0, 0, 0, -1); + + switch( m_left ) + { + case 3: + _mm_maskstore_epi32( ( int* ) ( pack_a_buffer_u8s8s16o16 + ( 0 ) * KC + kr ), mask, a_reg[0] ); + _mm_maskstore_epi32( ( int* ) ( pack_a_buffer_u8s8s16o16 + ( 1 ) * KC + kr ), mask, a_reg[1] ); + _mm_maskstore_epi32( ( int* ) ( pack_a_buffer_u8s8s16o16 + ( 2 ) * KC + kr ), mask, a_reg[4] ); + break; + + case 2: + _mm_maskstore_epi32( ( int* ) ( pack_a_buffer_u8s8s16o16 + ( 0 ) * KC + kr ), mask, a_reg[0] ); + _mm_maskstore_epi32( ( int* ) ( pack_a_buffer_u8s8s16o16 + ( 1 ) * KC + kr ), mask, a_reg[1] ); + break; + + case 1: + _mm_maskstore_epi32( ( int* ) ( pack_a_buffer_u8s8s16o16 + ( 0 ) * KC + kr ), mask, a_reg[0] ); + break; + } + + kr += 4; + } + + // k fringe 2 + if ( ( kr + 1 ) < KC ) + { + mask = _mm_set_epi32 (0, -1, -1, -1); + + a_reg[0] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 0 ) * cs ) ), mask ); + a_reg[1] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 1 ) * cs ) ), mask ); + + // Transpose operations + UNPACKLOW_EPI8 + + UNPACKLOW_EPI16_MR8 + + UNPACKLOW_EPI32_MR4 + UNPACKHI_EPI32_MR4 + + UNPACKLOW_EPI64_MR4 + UNPACKHI_EPI64_MR4 + + uint8_t buf[16]; + dim_t n0_rem_bytes = 2 * sizeof( uint8_t ); + + switch( m_left ) + { + case 3: + _mm_storeu_si128( ( __m128i* )buf, a_reg[0] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + ( 0 ) * KC + kr ), buf, n0_rem_bytes ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[1] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + ( 1 ) * KC + kr ), buf, n0_rem_bytes ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[4] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + ( 2 ) * KC + kr ), buf, n0_rem_bytes ); + + break; + + case 2: + _mm_storeu_si128( ( __m128i* )buf, a_reg[0] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + ( 0 ) * KC + kr ), buf, n0_rem_bytes ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[1] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + ( 1 ) * KC + kr ), buf, n0_rem_bytes ); + + break; + + case 1: + _mm_storeu_si128( ( __m128i* )buf, a_reg[0] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + ( 0 ) * KC + kr ), buf, n0_rem_bytes ); + + break; + } + + kr += 2; + } + + // k fringe 1 + if ( ( kr ) < KC ) + { + mask = _mm_set_epi32 (0, -1, -1, -1); + + a_reg[0] = _mm_maskload_epi32 ( ( int const* ) ( a + ( ( kr + 0 ) * cs )), mask ); + + // Transpose operations + UNPACKLOW_EPI8 + + UNPACKLOW_EPI16_MR8 + + UNPACKLOW_EPI32_MR4 + UNPACKHI_EPI32_MR4 + + UNPACKLOW_EPI64_MR4 + UNPACKHI_EPI64_MR4 + + uint8_t buf[16]; + dim_t n0_rem_bytes = 1 * sizeof( uint8_t ); + + switch( m_left ) + { + case 3: + _mm_storeu_si128( ( __m128i* )buf, a_reg[0] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + ( 0 ) * KC + kr ), buf, n0_rem_bytes ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[1] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + ( 1 ) * KC + kr ), buf, n0_rem_bytes ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[4] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + ( 2 ) * KC + kr ), buf, n0_rem_bytes ); + + break; + + case 2: + _mm_storeu_si128( ( __m128i* )buf, a_reg[0] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + ( 0 ) * KC + kr ), buf, n0_rem_bytes ); + + _mm_storeu_si128( ( __m128i* )buf, a_reg[1] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + ( 1 ) * KC + kr ), buf, n0_rem_bytes ); + + break; + + case 1: + _mm_storeu_si128( ( __m128i* )buf, a_reg[0] ); + memcpy( ( pack_a_buffer_u8s8s16o16 + ( 0 ) * KC + kr ), buf, n0_rem_bytes ); + + break; + } + + kr += 1; + } + } +} + + +#endif diff --git a/kernels/zen/lpgemm/u8s8s16/lpgemm_packb_amd256.c b/kernels/zen/lpgemm/u8s8s16/lpgemm_packb_amd256.c index 1169f825c8..1841ef6451 100644 --- a/kernels/zen/lpgemm/u8s8s16/lpgemm_packb_amd256.c +++ b/kernels/zen/lpgemm/u8s8s16/lpgemm_packb_amd256.c @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/kernels/zen/lpgemm/u8s8s16/lpgemm_s16_kern_macros.h b/kernels/zen/lpgemm/u8s8s16/lpgemm_s16_kern_macros.h index 48a95ccd53..e2b8c20e16 100644 --- a/kernels/zen/lpgemm/u8s8s16/lpgemm_s16_kern_macros.h +++ b/kernels/zen/lpgemm/u8s8s16/lpgemm_s16_kern_macros.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -36,6 +36,7 @@ #define LPGEMM_S16_KERN_MACROS_H #include "../gelu_avx2.h" +#include "../silu_avx2.h" #include "../math_utils_avx2.h" #define S8_MIN (-128) @@ -129,7 +130,7 @@ #define U8_S16_BETA_NLT16_MEMCP_UTIL(buf_,m_ind,bytes) \ US8_S16_BETA_NLT16_MEMCP_HELPER(buf_,m_ind,bytes,uint8_t) \ - + // Downscale macro #define CVT_MULRND_CVT16(reg, scale0, scale1, zero_point_0) \ \ @@ -350,4 +351,132 @@ \ reg = _mm256_min_epi16( _mm256_max_epi16( reg, min ), max ); \ +// Matrix Add post-ops helper macros +#define S16_MATRIX_ADD_1COL(scr0,m_ind) \ + c_int16_ ## m_ind ## p0 = _mm256_add_epi16( scr0, c_int16_ ## m_ind ## p0 ); \ + +#define S16_MATRIX_ADD_2COL(scr0,scr1,m_ind) \ + c_int16_ ## m_ind ## p0 = _mm256_add_epi16( scr0, c_int16_ ## m_ind ## p0 ); \ + c_int16_ ## m_ind ## p1 = _mm256_add_epi16( scr1, c_int16_ ## m_ind ## p1 ); \ + +#define S8_S16_MATRIX_ADD_LOAD(scr,m_ind,n_ind) \ + scr = _mm256_cvtepi8_epi16 \ + ( \ + _mm_loadu_si128 \ + ( \ + ( __m128i const* ) \ + ( matptr + ( ( post_ops_attr.post_op_c_i + m_ind ) * ldm ) + \ + post_ops_attr.post_op_c_j + ( n_ind * 16 ) ) \ + ) \ + ); \ + +#define S8_S16_MATRIX_ADD_1COL_PAR(buf,scr0,m_ind,n_rem,OTYPE) \ + memcpy \ + ( \ + ( OTYPE* )buf, \ + matptr + ( ( post_ops_attr.post_op_c_i + m_ind ) * ldm ) + \ + post_ops_attr.post_op_c_j + ( 0 * 16 ), \ + ( n_rem ) * sizeof(OTYPE) \ + ); \ + scr0 = _mm256_cvtepi8_epi16 \ + ( \ + _mm_loadu_si128( ( __m128i const* )buf ) \ + ); \ + S16_MATRIX_ADD_1COL(scr0,m_ind); \ + +#define S8_S16_MATRIX_ADD_1COL(scr0,m_ind) \ + S8_S16_MATRIX_ADD_LOAD(scr0,m_ind,0); \ + S16_MATRIX_ADD_1COL(scr0,m_ind); \ + +#define S8_S16_MATRIX_ADD_2COL(scr0,scr1,m_ind) \ + S8_S16_MATRIX_ADD_LOAD(scr0,m_ind,0); \ + S8_S16_MATRIX_ADD_LOAD(scr1,m_ind,1); \ + S16_MATRIX_ADD_2COL(scr0,scr1,m_ind); \ + +#define U8_S16_MATRIX_ADD_LOAD(scr,m_ind,n_ind) \ + scr = _mm256_cvtepu8_epi16 \ + ( \ + _mm_loadu_si128 \ + ( \ + ( __m128i const* ) \ + ( matptr + ( ( post_ops_attr.post_op_c_i + m_ind ) * ldm ) + \ + post_ops_attr.post_op_c_j + ( n_ind * 16 ) ) \ + ) \ + ); \ + +#define U8_S16_MATRIX_ADD_1COL_PAR(buf,scr0,m_ind,n_rem,OTYPE) \ + memcpy \ + ( \ + ( OTYPE* )buf, \ + matptr + ( ( post_ops_attr.post_op_c_i + m_ind ) * ldm ) + \ + post_ops_attr.post_op_c_j + ( 0 * 16 ), \ + ( n_rem ) * sizeof(OTYPE) \ + ); \ + scr0 = _mm256_cvtepu8_epi16 \ + ( \ + _mm_loadu_si128( ( __m128i const* )buf ) \ + ); \ + S16_MATRIX_ADD_1COL(scr0,m_ind); \ + +#define U8_S16_MATRIX_ADD_1COL(scr0,m_ind) \ + U8_S16_MATRIX_ADD_LOAD(scr0,m_ind,0); \ + S16_MATRIX_ADD_1COL(scr0,m_ind); \ + +#define U8_S16_MATRIX_ADD_2COL(scr0,scr1,m_ind) \ + U8_S16_MATRIX_ADD_LOAD(scr0,m_ind,0); \ + U8_S16_MATRIX_ADD_LOAD(scr1,m_ind,1); \ + S16_MATRIX_ADD_2COL(scr0,scr1,m_ind); \ + +#define S16_S16_MATRIX_ADD_LOAD(scr,m_ind,n_ind) \ + scr = _mm256_loadu_si256 \ + ( \ + (__m256i const *) \ + ( matptr + ( ( post_ops_attr.post_op_c_i + m_ind ) * ldm ) + \ + post_ops_attr.post_op_c_j + ( n_ind * 16 ) ) \ + ); \ + +#define S16_S16_MATRIX_ADD_1COL_PAR(buf,scr0,m_ind,n_rem,OTYPE) \ + memcpy \ + ( \ + ( OTYPE* )buf, \ + matptr + ( ( post_ops_attr.post_op_c_i + m_ind ) * ldm ) + \ + post_ops_attr.post_op_c_j + ( 0 * 16 ), \ + ( n_rem ) * sizeof(OTYPE) \ + ); \ + scr0 = _mm256_loadu_si256( ( __m256i const* )buf ); \ + S16_MATRIX_ADD_1COL(scr0,m_ind); \ + +#define S16_S16_MATRIX_ADD_1COL(scr0,m_ind) \ + S16_S16_MATRIX_ADD_LOAD(scr0,m_ind,0); \ + S16_MATRIX_ADD_1COL(scr0,m_ind); \ + +#define S16_S16_MATRIX_ADD_2COL(scr0,scr1,m_ind) \ + S16_S16_MATRIX_ADD_LOAD(scr0,m_ind,0); \ + S16_S16_MATRIX_ADD_LOAD(scr1,m_ind,1); \ + S16_MATRIX_ADD_2COL(scr0,scr1,m_ind); \ + +// SiLU utility macros. al1, al2 register expected to contain floats. +#define SWISH_S16_AVX2(in_reg, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out) \ +\ + tmp_reg1 = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32( \ + _mm256_extractf128_si256( in_reg, 0 ) ) ); \ + tmp_reg2 = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32( \ + _mm256_extractf128_si256( in_reg, 1 ) ) ); \ +\ + SWISH_F32_AVX2_DEF(tmp_reg1, al, al_in, r, r2, z, dn, ex_out); \ +\ + SWISH_F32_AVX2_DEF(tmp_reg2, al, al_in, r, r2, z, dn, ex_out); \ +\ + in_reg = _mm256_packs_epi32(_mm256_cvtps_epi32(tmp_reg1), _mm256_cvtps_epi32(tmp_reg2));\ + in_reg = _mm256_permute4x64_epi64(in_reg, 0XD8);\ + + +//Zero-out the given YMM accumulator registers +#define ZERO_ACC_YMM_4_REG(ymm0,ymm1,ymm2,ymm3) \ + ymm0 = _mm256_setzero_si256 (); \ + ymm1 = _mm256_setzero_si256 (); \ + ymm2 = _mm256_setzero_si256 (); \ + ymm3 = _mm256_setzero_si256 (); + + #endif //LPGEMM_S16_KERN_MACROS_H diff --git a/kernels/zen/lpgemm/u8s8s16/lpgemv_n_kernel_amd256.c b/kernels/zen/lpgemm/u8s8s16/lpgemv_n_kernel_amd256.c new file mode 100644 index 0000000000..e8fdfdebe6 --- /dev/null +++ b/kernels/zen/lpgemm/u8s8s16/lpgemv_n_kernel_amd256.c @@ -0,0 +1,793 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "blis.h" + +#ifdef BLIS_ADDON_LPGEMM + +#include "lpgemm_s16_kern_macros.h" + +#define LPGEMV_N_KERNEL_2_LOADS( ymm0, ymm1, paddr, stride ) \ + ymm0 = _mm256_loadu_si256( (__m256i const *)paddr ); \ + ymm1 = _mm256_loadu_si256( (__m256i const *)(paddr + stride) ); + +#define LPGEMV_N_KERNEL_2_FMA( a_reg1, a_reg2, b_reg, \ + inter_reg1, inter_reg2, c_reg1, c_reg2 ) \ + inter_reg1 = _mm256_maddubs_epi16(a_reg1, b_reg); \ + c_reg1 = _mm256_add_epi16(inter_reg1, c_reg1); \ + inter_reg2 = _mm256_maddubs_epi16(a_reg2, b_reg); \ + c_reg2 = _mm256_add_epi16(inter_reg2, c_reg2); + + +#define LPGEMV_N_KERNEL_4_LOADS( ymm0, ymm1, ymm2, ymm3, paddr, stride ) \ + ymm0 = _mm256_loadu_si256( (__m256i const *)(paddr) ); \ + ymm1 = _mm256_loadu_si256( (__m256i const *)(paddr + stride) ); \ + ymm2 = _mm256_loadu_si256( (__m256i const *)(paddr + 2 * stride) ); \ + ymm3 = _mm256_loadu_si256( (__m256i const *)(paddr + 3 * stride) ); + +#define LPGEMV_N_KERNEL_4_FMA( a_reg1, a_reg2, a_reg3, a_reg4, b_reg, \ + inter_reg1, inter_reg2, \ + inter_reg3, inter_reg4, \ + out_reg1, out_reg2, out_reg3, out_reg4 ) \ + inter_reg1 = _mm256_maddubs_epi16(a_reg1, b_reg); \ + out_reg1 = _mm256_add_epi16(inter_reg1, out_reg1); \ + inter_reg2 = _mm256_maddubs_epi16(a_reg2, b_reg); \ + out_reg2 = _mm256_add_epi16(inter_reg2, out_reg2); \ + inter_reg3 = _mm256_maddubs_epi16(a_reg3, b_reg); \ + out_reg3 = _mm256_add_epi16(inter_reg3, out_reg3); \ + inter_reg4 = _mm256_maddubs_epi16(a_reg4, b_reg); \ + out_reg4 = _mm256_add_epi16(inter_reg4, out_reg4); + +#define LPGEMV_YMM2XMM( ymm0, ymm1, ymm2, ymm3, xmm0 ) \ + ymm0 = _mm256_hadd_epi16( ymm0, ymm1 ); \ + ymm1 = _mm256_hadd_epi16( ymm2, ymm3 ); \ + ymm0 = _mm256_hadd_epi16( ymm0, ymm1 ); \ + xmm0 = _mm_add_epi16( _mm256_extracti128_si256( ymm0, 0 ), \ + _mm256_extracti128_si256( ymm0, 1 ) ); + + + +LPGEMV_N_EQ1_KERN(uint8_t, int8_t, int16_t, u8s8s16os16) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_DISABLE, + &&POST_OPS_BIAS, + &&POST_OPS_RELU, + &&POST_OPS_RELU_SCALE, + &&POST_OPS_GELU_TANH, + &&POST_OPS_GELU_ERF, + &&POST_OPS_CLIP, + &&POST_OPS_DOWNSCALE, + &&POST_OPS_MATRIX_ADD, + &&POST_OPS_SWISH + }; + + uint8_t *a_use = NULL; + int8_t *b_use = NULL; + int16_t *c_use = NULL; + + lpgemm_post_op_attr post_ops_attr = *(post_op_attr); + + // temp buffer to store output C vector + int16_t ctemp[16]; + + // temp buffers to store a, b data in k_rem case. + uint8_t buf0[32] = {0}; + uint8_t buf1[32] = {0}; + uint8_t buf2[32] = {0}; + uint8_t buf3[32] = {0}; + uint8_t buf4[32] = {0}; + uint8_t buf5[32] = {0}; + uint8_t buf6[32] = {0}; + uint8_t buf7[32] = {0}; + int8_t buf8[32] = {0}; + + for ( dim_t ir = 0; ir < m0; ir += MR ) + { + dim_t mr0 = bli_min( ( m0 - ir ), MR ); + dim_t k_iter = k / 32; + dim_t k_rem = k % 32; + + __m256i ymm0, ymm1, ymm2, ymm3, ymm4, ymm5, ymm6, ymm7; + __m256i ymm8, ymm9, ymm10, ymm11, ymm12, ymm13, ymm14; + __m256i ymm15; + + __m128i xmm0, xmm1; + + /* zero the accumulator registers */ + ZERO_ACC_YMM_4_REG( ymm8, ymm9, ymm10, ymm11 ) + ZERO_ACC_YMM_4_REG( ymm12, ymm13, ymm14, ymm15 ) + + //update pointers + a_use = (uint8_t*)a + ir * rs_a; + b_use = (int8_t*)b; + c_use = (int16_t*)c + ir * rs_c; + + if( mr0 == MR ) + { + for (dim_t k = 0; k < k_iter; k++) + { + + ymm6 = _mm256_loadu_si256( (__m256i const *)(b_use) ); + b_use += 32; + + //Load 4x32 elements from row0-row3 of A + LPGEMV_N_KERNEL_4_LOADS( ymm0, ymm1, ymm2, ymm3, a_use, rs_a ) + + LPGEMV_N_KERNEL_4_FMA( ymm0, ymm1, ymm2, ymm3, + ymm6, ymm4, ymm5, ymm7, ymm4, + ymm8, ymm9, ymm10, ymm11 + ) + + // Load 4x32 elements from row8-row11 of A + LPGEMV_N_KERNEL_4_LOADS( ymm0, ymm1, ymm2, ymm3, + ( a_use + 4 * rs_a ), rs_a + ) + + LPGEMV_N_KERNEL_4_FMA( ymm0, ymm1, ymm2, ymm3, + ymm6, ymm4, ymm5, ymm7, ymm4, + ymm12, ymm13, ymm14, ymm15 + ) + + a_use += 32; + } + + + + if( k_rem ) + { + + uint8_t* restrict a0 = (a_use); + uint8_t* restrict a1 = (a_use + rs_a ); + uint8_t* restrict a2 = (a_use + 2 * rs_a ); + uint8_t* restrict a3 = (a_use + 3 * rs_a ); + uint8_t* restrict a4 = (a_use + 4 * rs_a ); + uint8_t* restrict a5 = (a_use + 5 * rs_a ); + uint8_t* restrict a6 = (a_use + 6 * rs_a ); + uint8_t* restrict a7 = (a_use + 7 * rs_a ); + + for( dim_t i = 0; i < k_rem; i++) + { + buf8[i] = b_use[i]; + buf0[i] = a0[i]; + buf1[i] = a1[i]; + buf2[i] = a2[i]; + buf3[i] = a3[i]; + buf4[i] = a4[i]; + buf5[i] = a5[i]; + buf6[i] = a6[i]; + buf7[i] = a7[i]; + } + ymm6 = _mm256_loadu_si256( (__m256i const *)buf8 ); + + //Load 4x32 elements from row0-row3 of A + ymm0 = _mm256_loadu_si256( (__m256i const *)buf0 ); + ymm1 = _mm256_loadu_si256( (__m256i const *)buf1 ); + ymm2 = _mm256_loadu_si256( (__m256i const *)buf2 ); + ymm3 = _mm256_loadu_si256( (__m256i const *)buf3 ); + + LPGEMV_N_KERNEL_4_FMA( ymm0, ymm1, ymm2, ymm3, + ymm6, ymm4, ymm5, ymm7, ymm4, + ymm8, ymm9, ymm10, ymm11 + ) + + // Load 4x32 elements from row8-row11 of A + ymm0 = _mm256_loadu_si256( (__m256i const *)buf4 ); + ymm1 = _mm256_loadu_si256( (__m256i const *)buf5 ); + ymm2 = _mm256_loadu_si256( (__m256i const *)buf6 ); + ymm3 = _mm256_loadu_si256( (__m256i const *)buf7 ); + + LPGEMV_N_KERNEL_4_FMA( ymm0, ymm1, ymm2, ymm3, + ymm6, ymm4, ymm5, ymm7, ymm4, + ymm12, ymm13, ymm14, ymm15 + ) + + } + //Add the registers horizantally to get one + LPGEMV_YMM2XMM( ymm8, ymm9, ymm10, ymm11, xmm0 ) + LPGEMV_YMM2XMM( ymm12, ymm13, ymm14, ymm15, xmm1 ) + + xmm0 = _mm_hadd_epi16( xmm0, xmm1 ); + + // post ops are applied on ymm register though + // second half of the register is filled with zeroes. + ymm8 = _mm256_setzero_si256(); + ymm8 = _mm256_inserti128_si256( ymm8, xmm0, 0); + } + else + { + uint8_t *a_use_fringe = a_use; + dim_t mr0_use = mr0; + dim_t regidx = 0; + + if( mr0_use >= 4 ) + { + for (dim_t k = 0; k < k_iter; k++) + { + ymm6 = _mm256_loadu_si256( (__m256i const *)b_use ); + b_use += 32; + + //Load 4x32 elements from row0-row3 of A + LPGEMV_N_KERNEL_4_LOADS( ymm0, ymm1, ymm2, ymm3, + a_use, rs_a ) + + LPGEMV_N_KERNEL_4_FMA( ymm0, ymm1, ymm2, ymm3, + ymm6, ymm4, ymm5, ymm7, ymm4, + ymm8, ymm9, ymm10, ymm11 + ) + + a_use += 32; + } + + if( k_rem ) + { + uint8_t* restrict a0 = (a_use); + uint8_t* restrict a1 = (a_use + rs_a ); + uint8_t* restrict a2 = (a_use + 2 * rs_a ); + uint8_t* restrict a3 = (a_use + 3 * rs_a ); + + for( dim_t i = 0; i < k_rem; i++) + { + buf8[i] = b_use[i]; + buf0[i] = a0[i]; + buf1[i] = a1[i]; + buf2[i] = a2[i]; + buf3[i] = a3[i]; + } + ymm6 = _mm256_loadu_si256( (__m256i const *)buf8 ); + + //Load 4xk_rem elements from row0-row3 of A + + ymm0 = _mm256_loadu_si256( (__m256i const *)buf0 ); + ymm1 = _mm256_loadu_si256( (__m256i const *)buf1 ); + ymm2 = _mm256_loadu_si256( (__m256i const *)buf2 ); + ymm3 = _mm256_loadu_si256( (__m256i const *)buf3 ); + + LPGEMV_N_KERNEL_4_FMA( ymm0, ymm1, ymm2, ymm3, + ymm6, ymm4, ymm5, ymm7, ymm4, + ymm8, ymm9, ymm10, ymm11 + ) + } + + //update pointers + mr0_use -= 4; + a_use = a_use_fringe + 4 * rs_a; + a_use_fringe = a_use; + b_use = (int8_t*)b; + + //Add the registers horizantally to get one + LPGEMV_YMM2XMM( ymm8, ymm9, ymm10, ymm11, xmm0 ) + + xmm0 = _mm_hadd_epi16( xmm0, xmm0 ); + + int64_t data = _mm_extract_epi64( xmm0, 0); + //insert xmm outputs into final output reg based on regidx + ymm8 = _mm256_setzero_si256(); + ymm8 = _mm256_insert_epi64( ymm8, data, 0 ); + regidx++; + } + + // Dot product for <= 3 + if ( mr0_use ) + { + // Dot product for m = 2 + if ( mr0_use >= 2 ) + { + for ( dim_t k = 0; k < k_iter; k++ ) + { + // Load 0-31 in b[k+0 - k+31] + ymm6 = _mm256_loadu_si256( (__m256i const *)b_use ); + + LPGEMV_N_KERNEL_2_LOADS( ymm0, ymm1, a_use, rs_a); + + LPGEMV_N_KERNEL_2_FMA( ymm0, ymm1, ymm6, ymm4, + ymm5, ymm12, ymm13); + b_use += 32; // move b pointer to next 32 elements + a_use += 32; + } + if ( k_rem ) + { + uint8_t* restrict a0 = (a_use); + uint8_t* restrict a1 = (a_use + rs_a ); + + for( dim_t i = 0; i < k_rem; i++) + { + buf8[i] = b_use[i]; + buf0[i] = a0[i]; + buf1[i] = a1[i]; + } + ymm6 = _mm256_loadu_si256( (__m256i const *)buf8 ); + + //Load 2xk_rem elements from row0-row3 of A + + ymm0 = _mm256_loadu_si256( (__m256i const *)buf0 ); + ymm1 = _mm256_loadu_si256( (__m256i const *)buf1 ); + + LPGEMV_N_KERNEL_2_FMA( ymm0, ymm1, ymm6, + ymm4, ymm5, ymm12, ymm13 ); + } + + mr0_use -= 2; + a_use = a_use_fringe + 2 * rs_a; + a_use_fringe = a_use; + b_use = (int8_t*)b; + } + + // Dot product for m = 1 + if ( mr0_use == 1 ) + { + for ( dim_t k = 0; k < k_iter; k++ ) + { + // Load 0-31 in b[k+0 - k+31] + ymm6 = _mm256_loadu_si256( (__m256i const *)b_use ); + + // Load 1x32 elements from row0-row1 of A + ymm0 = _mm256_loadu_si256( (__m256i const *)a_use ); + + ymm4 = _mm256_maddubs_epi16(ymm0, ymm6); + ymm14 = _mm256_add_epi16(ymm4, ymm14); + + b_use += 32; // move b pointer to next 32 elements + a_use += 32; + } + if ( k_rem ) + { + uint8_t* restrict a0 = (a_use); + + for( dim_t i = 0; i < k_rem; i++) + { + buf8[i] = b_use[i]; + buf0[i] = a0[i]; + } + ymm6 = _mm256_loadu_si256( (__m256i const *)buf8 ); + + //Load 1xk_rem elements from row0-row3 of A + + ymm0 = _mm256_loadu_si256( (__m256i const *)buf0 ); + + ymm4 = _mm256_maddubs_epi16(ymm0, ymm6); + ymm14 = _mm256_add_epi16(ymm4, ymm14); + } + + // When only fringe 1, + // update the registers to store in order + if ( !( mr0 & 0x2 ) ) ymm12 = ymm14; + } + + LPGEMV_YMM2XMM( ymm12, ymm13, ymm14, ymm15, xmm0) + xmm0 = _mm_hadd_epi16( xmm0, xmm0 ); + + int64_t data = _mm_extract_epi64( xmm0, 0); + //insert xmm outputs into final output reg based on regidx + + if( regidx == 0 ) + { + ymm8 = _mm256_insert_epi64( ymm8, data, 0 ); + } + else + { + ymm8 = _mm256_insert_epi64( ymm8, data, 1 ); + } + + } + } + + // Load alpha and beta + __m256i selector1 = _mm256_set1_epi16(alpha); + __m256i selector2 = _mm256_set1_epi16(beta); + + // Scale by alpha + ymm8 = _mm256_mullo_epi16(selector1, ymm8); + + if( beta != 0 ) + { + if ( post_ops_attr.buf_downscale != NULL ) + { + if( post_ops_attr.rs_c_downscale == 1 ) + { + if( post_ops_attr.c_stor_type == S8 ) + { + dim_t m0_rem_dscale_bytes = mr0 * sizeof( int8_t ); + + S8_S16_BETA_NLT16_MEMCP_UTIL( ctemp, 0, + m0_rem_dscale_bytes ); + + S8_S16_BETA_OP_NLT16( ymm8, ctemp, + selector1, selector2 ) + } + else if( post_ops_attr.c_stor_type == U8 ) + { + dim_t m0_rem_dscale_bytes = mr0 * sizeof( uint8_t ); + + U8_S16_BETA_NLT16_MEMCP_UTIL( ctemp, 0, + m0_rem_dscale_bytes ); + + U8_S16_BETA_OP_NLT16( ymm8, ctemp, + selector1, selector2 ) + } + } + else + { + if( post_ops_attr.c_stor_type == S8 ) + { + int8_t ctemp[16]; + for( dim_t i = 0; i < mr0; i++ ) + { + ctemp[i] = *( (int8_t*)post_ops_attr.buf_downscale + + ( post_ops_attr.rs_c_downscale * + ( post_ops_attr.post_op_c_i + i ) ) ); + } + selector1 = _mm256_cvtepi8_epi32 + ( _mm_loadu_si128( (__m128i const*)ctemp ) ); + S16_BETA_FMA( ymm8, selector1, selector2 ); + } + else if( post_ops_attr.c_stor_type == U8 ) + { + uint8_t ctemp[16]; + for( dim_t i = 0; i < mr0; i++ ) + { + ctemp[i] = *( (uint8_t*)post_ops_attr.buf_downscale + + ( post_ops_attr.rs_c_downscale * + ( post_ops_attr.post_op_c_i + i ) ) ); + } + selector1 = _mm256_cvtepu8_epi32 + ( _mm_loadu_si128( (__m128i const*)ctemp ) ); + S16_BETA_FMA( ymm8, selector1, selector2 ); + } + } + } + else + { + if( rs_c == 1 ) + { + dim_t m0_rem_bytes = mr0 * sizeof( int16_t ); + memcpy( ctemp, c_use, m0_rem_bytes ); + S16_S16_BETA_OP_NLT16( ymm8, ctemp, + selector1, selector2 ) + } + else + { + for( dim_t i = 0; i < mr0; i++ ) + { + ctemp[i] = c_use[ i * rs_c ]; + } + selector1 = _mm256_loadu_si256( (__m256i const *)ctemp ); + S16_BETA_FMA( ymm8, selector1, selector2 ); + } + } + } + + // Post Ops + lpgemm_post_op * post_ops_list_temp = post_op; + + post_ops_attr.is_last_k = TRUE; + POST_OP_LABEL_LASTK_SAFE_JUMP + + + POST_OPS_BIAS: + { + + + selector1 = + _mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args1) ); + + ymm8 = _mm256_add_epi16( selector1, ymm8 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_RELU: + { + selector1 = _mm256_setzero_si256(); + + ymm8 = _mm256_max_epi16( selector1, ymm8 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_RELU_SCALE: + { + __m256i b0; + selector1 = _mm256_setzero_si256(); + selector2 = _mm256_set1_epi16( + *( ( int16_t* )post_ops_list_temp->op_args2 ) ); + + RELU_SCALE_OP_S16_AVX2( ymm8 ) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_GELU_TANH: + { + __m256 dn, z, x, r2, r, y1, y2, x_tanh; + __m256i q; + + GELU_TANH_S16_AVX2( ymm8, y1, y2, r, r2, x, z, dn, x_tanh, q ) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_GELU_ERF: + { + __m256 x, r, y1, y2, x_erf; + + GELU_ERF_S16_AVX2(ymm8, y1, y2, r, x, x_erf) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_CLIP: + { + __m256i min = _mm256_set1_epi16( + *( int16_t* )post_ops_list_temp->op_args2 ); + __m256i max = _mm256_set1_epi16( + *( int16_t* )post_ops_list_temp->op_args3 ); + + CLIP_S16_AVX2(ymm8, min, max) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_DOWNSCALE: + { + __m128i temp[2]; + __m256i temp_32[2]; + __m256 temp_float[2]; + __m256 scale_1 = _mm256_setzero_ps(); + __m256 scale_2 = _mm256_setzero_ps(); + __m128i _zero_point_0 = _mm_setzero_si128(); + __m256i zero_point_0 = _mm256_setzero_si256(); + __m256 res_1, res_2; + + scale_1 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + + scale_2 = + _mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + + _zero_point_0 = _mm_set1_epi8( + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + + if ( post_ops_attr.c_stor_type == S8 ) + { + zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + zero_point_0 = _mm256_cvtepu8_epi16( _zero_point_0 ); + } + + // Scale first 16 columns of the 2 rows. + CVT_MULRND_CVT16(ymm8, scale_1, scale_2, zero_point_0) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + + POST_OPS_MATRIX_ADD: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + if( ldm == 1 ) + { + memcpy + ( + ( int8_t* )ctemp, + matptr + ( ( post_ops_attr.post_op_c_i ) * ldm ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ), + ( mr0 ) * sizeof(int8_t) + ); + selector1 = _mm256_cvtepi8_epi16( + _mm_loadu_si128( ( __m128i const* )ctemp ) ); + ymm8 = _mm256_add_epi16( selector1, ymm8 ); + } + else + { + int8_t ctemp[16]; + for( dim_t i = 0; i < mr0; i++ ) + { + ctemp[i] = *( matptr + + ( ( post_ops_attr.post_op_c_i + i ) + * ldm ) ); + } + selector1 = _mm256_cvtepi8_epi16 + ( _mm_loadu_si128( (__m128i const*)ctemp ) ); + ymm8 = _mm256_add_epi16( selector1, ymm8 ); + } + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + uint8_t* matptr = ( uint8_t* )post_ops_list_temp->op_args1; + + if( ldm == 1 ) + { + memcpy + ( + ( uint8_t* )ctemp, + matptr + ( ( post_ops_attr.post_op_c_i ) * ldm ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ), + ( mr0 ) * sizeof(uint8_t) + ); + selector1 = _mm256_cvtepu8_epi16( + _mm_loadu_si128( ( __m128i const* )ctemp ) ); + ymm8 = _mm256_add_epi16( selector1, ymm8 ); + } + else + { + uint8_t ctemp[16]; + for( dim_t i = 0; i < mr0; i++ ) + { + ctemp[i] = *( matptr + + ( ( post_ops_attr.post_op_c_i + i ) + * ldm ) ); + } + selector1 = _mm256_cvtepu8_epi16 + ( _mm_loadu_si128( (__m128i const*)ctemp ) ); + ymm8 = _mm256_add_epi16( selector1, ymm8 ); + } + } + else + { + int16_t* matptr = ( int16_t* )post_ops_list_temp->op_args1; + + if( ldm == 1 ) + { + memcpy + ( + ( int16_t* )ctemp, + matptr + ( ( post_ops_attr.post_op_c_i ) * ldm ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ), + ( mr0 ) * sizeof(int16_t) + ); + + selector1 = _mm256_loadu_si256( ( __m256i const* )ctemp ); + + ymm8 = _mm256_add_epi16( selector1, ymm8 ); + } + else + { + int32_t ctemp[16]; + for( dim_t i = 0; i < mr0; i++ ) + { + ctemp[i] = *( matptr + + ( ( post_ops_attr.post_op_c_i + i ) + * ldm ) ); + } + selector1 = _mm256_loadu_si256( (__m256i const *)ctemp ); + ymm8 = _mm256_add_epi16( selector1, ymm8 ); + } + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_SWISH: + { + selector1 = + _mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args2 ) ); + __m256 al = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32( \ + _mm256_extractf128_si256( selector1, 0 ) ) ); + + __m256 al_in, tmp_reg1, tmp_reg2, r, r2, z, dn; + __m256i ex_out; + + SWISH_S16_AVX2( ymm8, al, al_in, tmp_reg1, + tmp_reg2, r, r2, z, dn, ex_out ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_DISABLE: + { + if ( post_ops_attr.buf_downscale != NULL ) + { + __m128i temp[2]; + __m256i zero_reg = _mm256_setzero_si256(); + if( post_ops_attr.rs_c_downscale == 1 ) + { + if( post_ops_attr.c_stor_type == S8 ) + { + // Store the results in downscaled type + // (int8 instead of int16). + CVT_STORE_S16_S8_1ROW_NLT16(ymm8, zero_reg, ctemp); + + dim_t m0_rem_dscale_bytes = mr0 * sizeof( int8_t ); + + CVT_STORE_S16_S8_NLT16_MEMCP_UTIL( ctemp, 0, + m0_rem_dscale_bytes); + } + else if( post_ops_attr.c_stor_type == U8 ) + { + // Store the results in downscaled type (uint8 instead of int16). + CVT_STORE_S16_U8_1ROW_NLT16(ymm8, zero_reg, ctemp); + + dim_t m0_rem_dscale_bytes = mr0 * sizeof( uint8_t ); + + CVT_STORE_S16_U8_NLT16_MEMCP_UTIL( ctemp, 0, + m0_rem_dscale_bytes); + } + } + else + { + if( post_ops_attr.c_stor_type == S8 ) + { + int8_t ctemp[16]; + + CVT_STORE_S16_S8_1ROW_NLT16(ymm8, zero_reg, ctemp); + for( dim_t i = 0; i < mr0; i++ ) + { + *( ( int8_t* )post_ops_attr.buf_downscale + + ( post_ops_attr.rs_c_downscale * + ( post_ops_attr.post_op_c_i + i ) ) ) = ctemp[i]; + } + } + else if( post_ops_attr.c_stor_type == U8 ) + { + uint8_t ctemp[16]; + + CVT_STORE_S16_U8_1ROW_NLT16(ymm8, zero_reg, ctemp); + + for( dim_t i = 0; i < mr0; i++ ) + { + *( ( uint8_t* )post_ops_attr.buf_downscale + + ( post_ops_attr.rs_c_downscale * + ( post_ops_attr.post_op_c_i + i ) ) ) = ctemp[i]; + } + } + } + } + else + { + if( rs_c == 1 ) + { + _mm256_storeu_si256( ( __m256i* )ctemp, ymm8 ); + + dim_t m0_rem_bytes = mr0 * sizeof( int16_t ); + + memcpy( c_use, ctemp, m0_rem_bytes ); + } + else + { + _mm256_storeu_si256( ( __m256i* )ctemp, ymm8 ); + + for( dim_t i = 0; i < mr0; i++ ) + { + c_use[i * rs_c] = ctemp[i]; + } + } + } + + post_ops_attr.post_op_c_i += MR; + } + } +} + +#endif diff --git a/kernels/zen4/1/bli_addv_zen_int_avx512.c b/kernels/zen4/1/bli_addv_zen_int_avx512.c new file mode 100644 index 0000000000..6ac6c36c1e --- /dev/null +++ b/kernels/zen4/1/bli_addv_zen_int_avx512.c @@ -0,0 +1,241 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "immintrin.h" +#include "blis.h" + +void bli_daddv_zen_int_avx512 + ( + conj_t conjx, + dim_t n, + double* restrict x, inc_t incx, + double* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + const dim_t num_elem_per_reg = 8; + dim_t i = 0; + __m512d yv[8]; + + // If the vector dimension is zero return early. + if ( bli_zero_dim1( n ) ) return; + + double *x0 = x; + double *y0 = y; + + if ( incx == 1 && incy ==1 ) + { + // n & (~0x3F) = n & 0xFFFFFFC0 -> this masks the numbers less than 64, + // the copy operation will be done for the multiples of 64 + for ( ; i < (n & (~0x3F)); i += 64 ) + { + // Loading input values + yv[0] = _mm512_loadu_pd( y0 ); + yv[1] = _mm512_loadu_pd( y0 + 1*num_elem_per_reg ); + yv[2] = _mm512_loadu_pd( y0 + 2*num_elem_per_reg ); + yv[3] = _mm512_loadu_pd( y0 + 3*num_elem_per_reg ); + yv[4] = _mm512_loadu_pd( y0 + 4*num_elem_per_reg ); + yv[5] = _mm512_loadu_pd( y0 + 5*num_elem_per_reg ); + yv[6] = _mm512_loadu_pd( y0 + 6*num_elem_per_reg ); + yv[7] = _mm512_loadu_pd( y0 + 7*num_elem_per_reg ); + + // y := y + x + yv[0] = _mm512_add_pd + ( + _mm512_loadu_pd( x0 ), + yv[0] + ); + yv[1] = _mm512_add_pd + ( + _mm512_loadu_pd( x0 + 1*num_elem_per_reg ), + yv[1] + ); + yv[2] = _mm512_add_pd + ( + _mm512_loadu_pd( x0 + 2*num_elem_per_reg ), + yv[2] + ); + yv[3] = _mm512_add_pd + ( + _mm512_loadu_pd( x0 + 3*num_elem_per_reg ), + yv[3] + ); + yv[4] = _mm512_add_pd + ( + _mm512_loadu_pd( x0 + 4*num_elem_per_reg ), + yv[4] + ); + yv[5] = _mm512_add_pd + ( + _mm512_loadu_pd( x0 + 5*num_elem_per_reg ), + yv[5] + ); + yv[6] = _mm512_add_pd + ( + _mm512_loadu_pd( x0 + 6*num_elem_per_reg ), + yv[6] + ); + yv[7] = _mm512_add_pd + ( + _mm512_loadu_pd( x0 + 7*num_elem_per_reg ), + yv[7] + ); + + _mm512_storeu_pd( y0, yv[0] ); + _mm512_storeu_pd( ( y0 + 1*num_elem_per_reg ), yv[1] ); + _mm512_storeu_pd( ( y0 + 2*num_elem_per_reg ), yv[2] ); + _mm512_storeu_pd( ( y0 + 3*num_elem_per_reg ), yv[3] ); + _mm512_storeu_pd( ( y0 + 4*num_elem_per_reg ), yv[4] ); + _mm512_storeu_pd( ( y0 + 5*num_elem_per_reg ), yv[5] ); + _mm512_storeu_pd( ( y0 + 6*num_elem_per_reg ), yv[6] ); + _mm512_storeu_pd( ( y0 + 7*num_elem_per_reg ), yv[7] ); + + x0 += 8 * num_elem_per_reg; + y0 += 8 * num_elem_per_reg; + } + + for ( ; i < (n & (~0x1F)); i += 32 ) + { + // Loading input values + yv[0] = _mm512_loadu_pd( y0 ); + yv[1] = _mm512_loadu_pd( y0 + 1*num_elem_per_reg ); + yv[2] = _mm512_loadu_pd( y0 + 2*num_elem_per_reg ); + yv[3] = _mm512_loadu_pd( y0 + 3*num_elem_per_reg ); + + // y := y + x + yv[0] = _mm512_add_pd + ( + _mm512_loadu_pd( x0 ), + yv[0] + ); + yv[1] = _mm512_add_pd + ( + _mm512_loadu_pd( x0 + 1*num_elem_per_reg ), + yv[1] + ); + yv[2] = _mm512_add_pd + ( + _mm512_loadu_pd( x0 + 2*num_elem_per_reg ), + yv[2] + ); + yv[3] = _mm512_add_pd + ( + _mm512_loadu_pd( x0 + 3*num_elem_per_reg ), + yv[3] + ); + + _mm512_storeu_pd( y0, yv[0] ); + _mm512_storeu_pd( ( y0 + 1*num_elem_per_reg ), yv[1] ); + _mm512_storeu_pd( ( y0 + 2*num_elem_per_reg ), yv[2] ); + _mm512_storeu_pd( ( y0 + 3*num_elem_per_reg ), yv[3] ); + + x0 += 4 * num_elem_per_reg; + y0 += 4 * num_elem_per_reg; + } + + for ( ; i < (n & (~0x0F)); i += 16 ) + { + // Loading input values + yv[0] = _mm512_loadu_pd( y0 ); + yv[1] = _mm512_loadu_pd( y0 + 1*num_elem_per_reg ); + + // y := y + x + yv[0] = _mm512_add_pd + ( + _mm512_loadu_pd( x0 ), + yv[0] + ); + yv[1] = _mm512_add_pd + ( + _mm512_loadu_pd( x0 + 1*num_elem_per_reg ), + yv[1] + ); + + _mm512_storeu_pd( y0, yv[0] ); + _mm512_storeu_pd( ( y0 + 1*num_elem_per_reg ), yv[1] ); + + x0 += 2 * num_elem_per_reg; + y0 += 2 * num_elem_per_reg; + } + + for ( ; i < (n & (~0x07)); i += 8 ) + { + // Loading input values + yv[0] = _mm512_loadu_pd( y0 ); + + // y := y + x + yv[0] = _mm512_add_pd + ( + _mm512_loadu_pd( x0 ), + yv[0] + ); + + _mm512_storeu_pd( y0, yv[0] ); + + x0 += 1 * num_elem_per_reg; + y0 += 1 * num_elem_per_reg; + } + + // Handling the frine case + if ( i < n ) + { + // Setting the mask for loading and storing the vectors + __mmask8 n_mask = (1 << ( n - i )) - 1; + + // Loading input values + yv[0] = _mm512_maskz_loadu_pd( n_mask, y0 ); + + // y := y + x + yv[0] = _mm512_add_pd + ( + _mm512_maskz_loadu_pd( n_mask, x0 ), + yv[0] + ); + + _mm512_mask_storeu_pd( y0, n_mask, yv[0] ); + } + } + + else + { + // Handling fringe cases or non-unit strided vectors + for ( ; i < n; i += 1 ) + { + *y0 += *x0; + + x0 += incx; + y0 += incy; + } + } +} diff --git a/kernels/zen4/1/bli_axpbyv_zen_int_avx512.c b/kernels/zen4/1/bli_axpbyv_zen_int_avx512.c new file mode 100644 index 0000000000..8026b6a6dd --- /dev/null +++ b/kernels/zen4/1/bli_axpbyv_zen_int_avx512.c @@ -0,0 +1,463 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "immintrin.h" +#include "blis.h" + +/* One 512-bit AVX register holds 8 DP elements */ +typedef union +{ + __m512d v; + double d[8] __attribute__((aligned(64))); +} v8df_t; + +/** + * daxpbyv kernel performs the axpbyv operation. + * y := beta * y + alpha * conjx(x) + * where, + * x & y are double precision vectors of length n. + * alpha & beta are scalars. + */ +void bli_daxpbyv_zen_int_avx512 + ( + conj_t conjx, + dim_t n, + double* restrict alpha, + double* restrict x, inc_t incx, + double* restrict beta, + double* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_4) + + // Redirecting to other L1 kernels based on alpha and beta values + // If alpha is 0, we call DSCALV + // This kernel would further reroute based on few other combinations + // of alpha and beta. They are as follows : + // When alpha = 0 : + // When beta = 0 --> DSETV + // When beta = 1 --> Early return + // When beta = !( 0 or 1 ) --> DSCALV + if ( bli_deq0( *alpha ) ) + { + bli_dscalv_zen_int10 + ( + BLIS_NO_CONJUGATE, + n, + beta, + y, incy, + cntx + ); + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) + return; + } + + // If beta is 0, we call DSCAL2V + // This kernel would further reroute based on few other combinations + // of alpha and beta. They are as follows : + // When beta = 0 : + // When alpha = 0 --> DSETV + // When alpha = 1 --> DCOPYV + // When alpha = !( 0 or 1 ) --> DSCAL2V + else if ( bli_deq0( *beta ) ) + { + bli_dscal2v_zen_int + ( + conjx, + n, + alpha, + x, incx, + y, incy, + cntx + ); + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) + return; + } + + // If beta is 1, we have 2 scenarios for rerouting + // When alpha = 1 --> DADDV + // When alpha = !( 0 or 1 ) --> DAXPYV + else if ( bli_deq1( *beta ) ) + { + if( bli_deq1( *alpha ) ) + { + bli_daddv_zen_int + ( + conjx, + n, + x, incx, + y, incy, + cntx + ); + } + else + { + bli_daxpyv_zen_int + ( + conjx, + n, + alpha, + x, incx, + y, incy, + cntx + ); + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) + return; + } + + const dim_t n_elem_per_reg = 8; // number of elements per register + + dim_t i = 0; // iterator + + // Local pointer aliases to the parameters + double* restrict x0; + double* restrict y0; + + // Registers to load/store the vectors + v8df_t alphav; + v8df_t betav; + v8df_t yv[8]; + + // Boolean to check for alpha being 1 + bool is_alpha_one = bli_seq1( *alpha ); + + // Initialize local pointers + x0 = x; + y0 = y; + + if( incx == 1 && incy == 1 ) + { + // Broadcasting beta onto a ZMM register + betav.v = _mm512_set1_pd( *beta ); + + if( is_alpha_one ) // Scale y with beta and add x to it + { + for( ; i + 63 < n; i += 64 ) + { + // Loading Y vector onto 8 registers + // Thus, we iterate in blocks of 64 elements + yv[0].v = _mm512_loadu_pd( x0 + 0 * n_elem_per_reg ); + yv[1].v = _mm512_loadu_pd( x0 + 1 * n_elem_per_reg ); + yv[2].v = _mm512_loadu_pd( x0 + 2 * n_elem_per_reg ); + yv[3].v = _mm512_loadu_pd( x0 + 3 * n_elem_per_reg ); + yv[4].v = _mm512_loadu_pd( x0 + 4 * n_elem_per_reg ); + yv[5].v = _mm512_loadu_pd( x0 + 5 * n_elem_per_reg ); + yv[6].v = _mm512_loadu_pd( x0 + 6 * n_elem_per_reg ); + yv[7].v = _mm512_loadu_pd( x0 + 7 * n_elem_per_reg ); + + // Loading Y vector and using it as part of beta scaling and adding to X + yv[0].v = _mm512_fmadd_pd( betav.v, _mm512_loadu_pd( y0 + 0 * n_elem_per_reg ), yv[0].v ); + yv[1].v = _mm512_fmadd_pd( betav.v, _mm512_loadu_pd( y0 + 1 * n_elem_per_reg ), yv[1].v ); + yv[2].v = _mm512_fmadd_pd( betav.v, _mm512_loadu_pd( y0 + 2 * n_elem_per_reg ), yv[2].v ); + yv[3].v = _mm512_fmadd_pd( betav.v, _mm512_loadu_pd( y0 + 3 * n_elem_per_reg ), yv[3].v ); + yv[4].v = _mm512_fmadd_pd( betav.v, _mm512_loadu_pd( y0 + 4 * n_elem_per_reg ), yv[4].v ); + yv[5].v = _mm512_fmadd_pd( betav.v, _mm512_loadu_pd( y0 + 5 * n_elem_per_reg ), yv[5].v ); + yv[6].v = _mm512_fmadd_pd( betav.v, _mm512_loadu_pd( y0 + 6 * n_elem_per_reg ), yv[6].v ); + yv[7].v = _mm512_fmadd_pd( betav.v, _mm512_loadu_pd( y0 + 7 * n_elem_per_reg ), yv[7].v ); + + // Storing the results onto Y vector + _mm512_storeu_pd( y0 + 0 * n_elem_per_reg, yv[0].v ); + _mm512_storeu_pd( y0 + 1 * n_elem_per_reg, yv[1].v ); + _mm512_storeu_pd( y0 + 2 * n_elem_per_reg, yv[2].v ); + _mm512_storeu_pd( y0 + 3 * n_elem_per_reg, yv[3].v ); + _mm512_storeu_pd( y0 + 4 * n_elem_per_reg, yv[4].v ); + _mm512_storeu_pd( y0 + 5 * n_elem_per_reg, yv[5].v ); + _mm512_storeu_pd( y0 + 6 * n_elem_per_reg, yv[6].v ); + _mm512_storeu_pd( y0 + 7 * n_elem_per_reg, yv[7].v ); + + // Adjusting the pointers + x0 += 8 * n_elem_per_reg; + y0 += 8 * n_elem_per_reg; + } + + for( ; i + 31 < n; i += 32 ) + { + // Loading Y vector onto 4 registers + // Thus, we iterate in blocks of 32 elements + yv[0].v = _mm512_loadu_pd( x0 + 0 * n_elem_per_reg ); + yv[1].v = _mm512_loadu_pd( x0 + 1 * n_elem_per_reg ); + yv[2].v = _mm512_loadu_pd( x0 + 2 * n_elem_per_reg ); + yv[3].v = _mm512_loadu_pd( x0 + 3 * n_elem_per_reg ); + + // Loading Y vector and using it as part of beta scaling and adding to X + yv[0].v = _mm512_fmadd_pd( betav.v, _mm512_loadu_pd( y0 + 0 * n_elem_per_reg ), yv[0].v ); + yv[1].v = _mm512_fmadd_pd( betav.v, _mm512_loadu_pd( y0 + 1 * n_elem_per_reg ), yv[1].v ); + yv[2].v = _mm512_fmadd_pd( betav.v, _mm512_loadu_pd( y0 + 2 * n_elem_per_reg ), yv[2].v ); + yv[3].v = _mm512_fmadd_pd( betav.v, _mm512_loadu_pd( y0 + 3 * n_elem_per_reg ), yv[3].v ); + + // Storing the results onto Y vector + _mm512_storeu_pd( y0 + 0 * n_elem_per_reg, yv[0].v ); + _mm512_storeu_pd( y0 + 1 * n_elem_per_reg, yv[1].v ); + _mm512_storeu_pd( y0 + 2 * n_elem_per_reg, yv[2].v ); + _mm512_storeu_pd( y0 + 3 * n_elem_per_reg, yv[3].v ); + + // Adjusting the pointers + x0 += 4 * n_elem_per_reg; + y0 += 4 * n_elem_per_reg; + } + + for( ; i + 15 < n; i += 16 ) + { + // Loading Y vector onto 2 registers + // Thus, we iterate in blocks of 16 elements + yv[0].v = _mm512_loadu_pd( x0 + 0 * n_elem_per_reg ); + yv[1].v = _mm512_loadu_pd( x0 + 1 * n_elem_per_reg ); + + // Loading Y vector and using it as part of beta scaling and adding to X + yv[0].v = _mm512_fmadd_pd( betav.v, _mm512_loadu_pd( y0 + 0 * n_elem_per_reg ), yv[0].v ); + yv[1].v = _mm512_fmadd_pd( betav.v, _mm512_loadu_pd( y0 + 1 * n_elem_per_reg ), yv[1].v ); + + // Storing the results onto Y vector + _mm512_storeu_pd( y0 + 0 * n_elem_per_reg, yv[0].v ); + _mm512_storeu_pd( y0 + 1 * n_elem_per_reg, yv[1].v ); + + // Adjusting the pointers + x0 += 2 * n_elem_per_reg; + y0 += 2 * n_elem_per_reg; + } + + for( ; i + 7 < n; i += 8 ) + { + // Loading Y vector onto 1 register + // Thus, we iterate in blocks of 8 elements + yv[0].v = _mm512_loadu_pd( x0 + 0 * n_elem_per_reg ); + + // Loading Y vector and using it as part of beta scaling and adding to X + yv[0].v = _mm512_fmadd_pd( betav.v, _mm512_loadu_pd( y0 + 0 * n_elem_per_reg ), yv[0].v ); + + // Storing the results onto Y vector + _mm512_storeu_pd( y0 + 0 * n_elem_per_reg, yv[0].v ); + + // Adjusting the pointers + x0 += 1 * n_elem_per_reg; + y0 += 1 * n_elem_per_reg; + } + + // Handling the fringe cases + if( i < n ) + { + // Setting the mask for loading and storing the vectors + __mmask8 n_mask = (1 << (n - i)) - 1; + + // Loading the X vector + yv[0].v = _mm512_maskz_loadu_pd( n_mask, x0 + 0 * n_elem_per_reg ); + + // Loading Y vector and using it as part of beta scaling and adding to X + yv[0].v = _mm512_fmadd_pd( betav.v, _mm512_maskz_loadu_pd( n_mask, y0 + 0 * n_elem_per_reg ), yv[0].v ); + + // Storing the results onto Y vector + _mm512_mask_storeu_pd( y0 + 0 * n_elem_per_reg, n_mask, yv[0].v ); + + } + } + else + { + // Broadcasting alpha onto a ZMM register + alphav.v = _mm512_set1_pd( *alpha ); + for( ; i + 63 < n; i += 64 ) + { + // Loading X vector onto 8 registers + // Thus, we iterate in blocks of 64 elements + yv[0].v = _mm512_loadu_pd( y0 + 0 * n_elem_per_reg ); + yv[1].v = _mm512_loadu_pd( y0 + 1 * n_elem_per_reg ); + yv[2].v = _mm512_loadu_pd( y0 + 2 * n_elem_per_reg ); + yv[3].v = _mm512_loadu_pd( y0 + 3 * n_elem_per_reg ); + yv[4].v = _mm512_loadu_pd( y0 + 4 * n_elem_per_reg ); + yv[5].v = _mm512_loadu_pd( y0 + 5 * n_elem_per_reg ); + yv[6].v = _mm512_loadu_pd( y0 + 6 * n_elem_per_reg ); + yv[7].v = _mm512_loadu_pd( y0 + 7 * n_elem_per_reg ); + + // Beta scaling Y vector + yv[0].v = _mm512_mul_pd( betav.v, yv[0].v ); + yv[1].v = _mm512_mul_pd( betav.v, yv[1].v ); + yv[2].v = _mm512_mul_pd( betav.v, yv[2].v ); + yv[3].v = _mm512_mul_pd( betav.v, yv[3].v ); + yv[4].v = _mm512_mul_pd( betav.v, yv[4].v ); + yv[5].v = _mm512_mul_pd( betav.v, yv[5].v ); + yv[6].v = _mm512_mul_pd( betav.v, yv[6].v ); + yv[7].v = _mm512_mul_pd( betav.v, yv[7].v ); + + // Loading X vector and using it as part of alpha scaling and adding to Y + yv[0].v = _mm512_fmadd_pd( alphav.v, _mm512_loadu_pd( x0 + 0 * n_elem_per_reg ), yv[0].v ); + yv[1].v = _mm512_fmadd_pd( alphav.v, _mm512_loadu_pd( x0 + 1 * n_elem_per_reg ), yv[1].v ); + yv[2].v = _mm512_fmadd_pd( alphav.v, _mm512_loadu_pd( x0 + 2 * n_elem_per_reg ), yv[2].v ); + yv[3].v = _mm512_fmadd_pd( alphav.v, _mm512_loadu_pd( x0 + 3 * n_elem_per_reg ), yv[3].v ); + yv[4].v = _mm512_fmadd_pd( alphav.v, _mm512_loadu_pd( x0 + 4 * n_elem_per_reg ), yv[4].v ); + yv[5].v = _mm512_fmadd_pd( alphav.v, _mm512_loadu_pd( x0 + 5 * n_elem_per_reg ), yv[5].v ); + yv[6].v = _mm512_fmadd_pd( alphav.v, _mm512_loadu_pd( x0 + 6 * n_elem_per_reg ), yv[6].v ); + yv[7].v = _mm512_fmadd_pd( alphav.v, _mm512_loadu_pd( x0 + 7 * n_elem_per_reg ), yv[7].v ); + + // Storing the result onto Y + _mm512_storeu_pd( y0 + 0 * n_elem_per_reg, yv[0].v ); + _mm512_storeu_pd( y0 + 1 * n_elem_per_reg, yv[1].v ); + _mm512_storeu_pd( y0 + 2 * n_elem_per_reg, yv[2].v ); + _mm512_storeu_pd( y0 + 3 * n_elem_per_reg, yv[3].v ); + _mm512_storeu_pd( y0 + 4 * n_elem_per_reg, yv[4].v ); + _mm512_storeu_pd( y0 + 5 * n_elem_per_reg, yv[5].v ); + _mm512_storeu_pd( y0 + 6 * n_elem_per_reg, yv[6].v ); + _mm512_storeu_pd( y0 + 7 * n_elem_per_reg, yv[7].v ); + + // Adjusting the pointers + x0 += 8 * n_elem_per_reg; + y0 += 8 * n_elem_per_reg; + } + + for( ; i + 31 < n; i += 32 ) + { + // Loading X vector onto 4 registers + // Thus, we iterate in blocks of 32 elements + yv[0].v = _mm512_loadu_pd( y0 + 0 * n_elem_per_reg ); + yv[1].v = _mm512_loadu_pd( y0 + 1 * n_elem_per_reg ); + yv[2].v = _mm512_loadu_pd( y0 + 2 * n_elem_per_reg ); + yv[3].v = _mm512_loadu_pd( y0 + 3 * n_elem_per_reg ); + + // Beta scaling Y vector + yv[0].v = _mm512_mul_pd( betav.v, yv[0].v ); + yv[1].v = _mm512_mul_pd( betav.v, yv[1].v ); + yv[2].v = _mm512_mul_pd( betav.v, yv[2].v ); + yv[3].v = _mm512_mul_pd( betav.v, yv[3].v ); + + // Loading X vector and using it as part of alpha scaling and adding to Y + yv[0].v = _mm512_fmadd_pd( alphav.v, _mm512_loadu_pd( x0 + 0 * n_elem_per_reg ), yv[0].v ); + yv[1].v = _mm512_fmadd_pd( alphav.v, _mm512_loadu_pd( x0 + 1 * n_elem_per_reg ), yv[1].v ); + yv[2].v = _mm512_fmadd_pd( alphav.v, _mm512_loadu_pd( x0 + 2 * n_elem_per_reg ), yv[2].v ); + yv[3].v = _mm512_fmadd_pd( alphav.v, _mm512_loadu_pd( x0 + 3 * n_elem_per_reg ), yv[3].v ); + + // Storing the result onto Y + _mm512_storeu_pd( y0 + 0 * n_elem_per_reg, yv[0].v ); + _mm512_storeu_pd( y0 + 1 * n_elem_per_reg, yv[1].v ); + _mm512_storeu_pd( y0 + 2 * n_elem_per_reg, yv[2].v ); + _mm512_storeu_pd( y0 + 3 * n_elem_per_reg, yv[3].v ); + + // Adjusting the pointers + x0 += 4 * n_elem_per_reg; + y0 += 4 * n_elem_per_reg; + } + + for( ; i + 15 < n; i += 16 ) + { + // Loading X vector onto 2 registers + // Thus, we iterate in blocks of 16 elements + yv[0].v = _mm512_loadu_pd( y0 + 0 * n_elem_per_reg ); + yv[1].v = _mm512_loadu_pd( y0 + 1 * n_elem_per_reg ); + + // Beta scaling Y vector + yv[0].v = _mm512_mul_pd( betav.v, yv[0].v ); + yv[1].v = _mm512_mul_pd( betav.v, yv[1].v ); + + // Loading X vector and using it as part of alpha scaling and adding to Y + yv[0].v = _mm512_fmadd_pd( alphav.v, _mm512_loadu_pd( x0 + 0 * n_elem_per_reg ), yv[0].v ); + yv[1].v = _mm512_fmadd_pd( alphav.v, _mm512_loadu_pd( x0 + 1 * n_elem_per_reg ), yv[1].v ); + + // Storing the result onto Y + _mm512_storeu_pd( y0 + 0 * n_elem_per_reg, yv[0].v ); + _mm512_storeu_pd( y0 + 1 * n_elem_per_reg, yv[1].v ); + + // Adjusting the pointers + x0 += 2 * n_elem_per_reg; + y0 += 2 * n_elem_per_reg; + } + + for( ; i + 7 < n; i += 8 ) + { + // Loading X vector onto 1 register + // Thus, we iterate in blocks of 8 elements + yv[0].v = _mm512_loadu_pd( y0 + 0 * n_elem_per_reg ); + + // Beta scaling Y vector + yv[0].v = _mm512_mul_pd( betav.v, yv[0].v ); + + // Loading X vector and using it as part of alpha scaling and adding to Y + yv[0].v = _mm512_fmadd_pd( alphav.v, _mm512_loadu_pd( x0 + 0 * n_elem_per_reg ), yv[0].v ); + + // Storing the result onto Y + _mm512_storeu_pd( y0 + 0 * n_elem_per_reg, yv[0].v ); + + // Adjusting the pointers + x0 += 1 * n_elem_per_reg; + y0 += 1 * n_elem_per_reg; + } + + // Handling the fringe cases + if( i < n ) + { + // Setting the mask to load/store the remaining elements + __mmask8 n_mask = (1 << (n - i)) - 1; + + // Loading Y vector + yv[0].v = _mm512_maskz_loadu_pd( n_mask, y0 + 0 * n_elem_per_reg ); + + // Beta scaling Y vector + yv[0].v = _mm512_mul_pd( betav.v, yv[0].v ); + + // Loading X vector and using it as part of alpha scaling and adding to Y + yv[0].v = _mm512_fmadd_pd( alphav.v, _mm512_maskz_loadu_pd( n_mask, x0 + 0 * n_elem_per_reg ), yv[0].v ); + + // Storing the result onto Y + _mm512_mask_storeu_pd( y0 + 0 * n_elem_per_reg, n_mask, yv[0].v ); + + } + } + } + else + { + if( is_alpha_one ) + { + for ( ; i < n; ++i ) + { + *y0 = (*beta) * (*y0) + (*x0); + + x0 += incx; + y0 += incy; + } + } + else + { + for ( ; i < n; ++i ) + { + *y0 = (*beta) * (*y0) + (*alpha) * (*x0); + + x0 += incx; + y0 += incy; + } + } + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) +} diff --git a/kernels/zen4/1/bli_axpyv_zen_int_avx512.c b/kernels/zen4/1/bli_axpyv_zen_int_avx512.c index 181a5a38ee..dce35c9ee0 100644 --- a/kernels/zen4/1/bli_axpyv_zen_int_avx512.c +++ b/kernels/zen4/1/bli_axpyv_zen_int_avx512.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -282,7 +282,7 @@ void bli_saxpyv_zen_int_avx512 The expectation is that these are standard BLAS exceptions and should be handled in a higher layer */ -void bli_daxpyv_zen_int_avx512 +BLIS_EXPORT_BLIS void bli_daxpyv_zen_int_avx512 ( conj_t conjx, dim_t n, @@ -445,3 +445,323 @@ void bli_daxpyv_zen_int_avx512 AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) } + +// ----------------------------------------------------------------------------- + +/* + Functionality + ------------- + + This function calculates y := y + alpha * x where all three variables are of type + double. + + Function Signature + ------------------- + + This function takes three float pointer as input, the correspending vector's stride + and length. It uses the function parameters to return the output. + + * 'conjx' - Info about conjugation of x (This variable is not used in the kernel) + * 'n' - Length of the array passed + * 'alpha' - Double pointer to a scalar value + * 'x' - Double pointer to an array + * 'incx' - Stride to point to the next element in the array + * 'y' - Double pointer to an array + * 'incy' - Stride to point to the next element in the array + * 'cntx' - BLIS context object + + Exception + ---------- + + None + + Deviation from BLAS + -------------------- + + None + + Undefined behaviour + ------------------- + + 1. The kernel results in undefined behaviour when n <= 0, incx <= 0 and incy <= 0. + The expectation is that these are standard BLAS exceptions and should be handled in + a higher layer +*/ +void bli_zaxpyv_zen_int_avx512 + ( + conj_t conjx, + dim_t n, + dcomplex* restrict alpha, + dcomplex* restrict x, inc_t incx, + dcomplex* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + const int n_elem_per_reg = 8; + + dim_t i = 0; + + // Initialize local pointers. + double *restrict x0 = (double *)x; + double *restrict y0 = (double *)y; + + if (incx == 1 && incy == 1) + { + __m512d xv[8], yv[8], alphaRv, alphaIv; + + // Broadcast real and imag parts of alpha to separate registers + alphaRv = _mm512_set1_pd(alpha->real); + alphaIv = _mm512_set1_pd(alpha->imag); + + xv[0] = _mm512_setzero_pd(); + + // Handle X conjugate by negating some elements of alphaRv/alphaIv + if ( bli_is_noconj( conjx ) ) + alphaIv = _mm512_fmaddsub_pd(xv[0], xv[0], alphaIv); + else + alphaRv = _mm512_fmsubadd_pd(xv[0], xv[0], alphaRv); + + // To check if code has to go to masked load/store directly + if ( n >= 4 ) + { + for (; (i + 31) < n; i += 32) + { + // Loading elements from X + xv[0] = _mm512_loadu_pd(x0 + 0 * n_elem_per_reg); + xv[1] = _mm512_loadu_pd(x0 + 1 * n_elem_per_reg); + xv[2] = _mm512_loadu_pd(x0 + 2 * n_elem_per_reg); + xv[3] = _mm512_loadu_pd(x0 + 3 * n_elem_per_reg); + + // Loading elements from Y + yv[0] = _mm512_loadu_pd(y0 + 0 * n_elem_per_reg); + yv[1] = _mm512_loadu_pd(y0 + 1 * n_elem_per_reg); + yv[2] = _mm512_loadu_pd(y0 + 2 * n_elem_per_reg); + yv[3] = _mm512_loadu_pd(y0 + 3 * n_elem_per_reg); + + // Scale X with real-part of alpha and add to Y + yv[0] = _mm512_fmadd_pd(alphaRv, xv[0], yv[0]); + yv[1] = _mm512_fmadd_pd(alphaRv, xv[1], yv[1]); + yv[2] = _mm512_fmadd_pd(alphaRv, xv[2], yv[2]); + yv[3] = _mm512_fmadd_pd(alphaRv, xv[3], yv[3]); + + // Swapping real and imag parts of every element in X + xv[0] = _mm512_permute_pd(xv[0], 0x55); + xv[1] = _mm512_permute_pd(xv[1], 0x55); + xv[2] = _mm512_permute_pd(xv[2], 0x55); + xv[3] = _mm512_permute_pd(xv[3], 0x55); + + // Scale X with imag-part of alpha and add to Y + yv[0] = _mm512_fmadd_pd(alphaIv, xv[0], yv[0]); + yv[1] = _mm512_fmadd_pd(alphaIv, xv[1], yv[1]); + yv[2] = _mm512_fmadd_pd(alphaIv, xv[2], yv[2]); + yv[3] = _mm512_fmadd_pd(alphaIv, xv[3], yv[3]); + + // Store updated Y + _mm512_storeu_pd((y0 + 0 * n_elem_per_reg), yv[0]); + _mm512_storeu_pd((y0 + 1 * n_elem_per_reg), yv[1]); + _mm512_storeu_pd((y0 + 2 * n_elem_per_reg), yv[2]); + _mm512_storeu_pd((y0 + 3 * n_elem_per_reg), yv[3]); + + // Loading elements from X + xv[4] = _mm512_loadu_pd(x0 + 4 * n_elem_per_reg); + xv[5] = _mm512_loadu_pd(x0 + 5 * n_elem_per_reg); + xv[6] = _mm512_loadu_pd(x0 + 6 * n_elem_per_reg); + xv[7] = _mm512_loadu_pd(x0 + 7 * n_elem_per_reg); + + // Loading elements from Y + yv[4] = _mm512_loadu_pd(y0 + 4 * n_elem_per_reg); + yv[5] = _mm512_loadu_pd(y0 + 5 * n_elem_per_reg); + yv[6] = _mm512_loadu_pd(y0 + 6 * n_elem_per_reg); + yv[7] = _mm512_loadu_pd(y0 + 7 * n_elem_per_reg); + + // Scale X with real-part of alpha and add to Y + yv[4] = _mm512_fmadd_pd(alphaRv, xv[4], yv[4]); + yv[5] = _mm512_fmadd_pd(alphaRv, xv[5], yv[5]); + yv[6] = _mm512_fmadd_pd(alphaRv, xv[6], yv[6]); + yv[7] = _mm512_fmadd_pd(alphaRv, xv[7], yv[7]); + + // Swapping real and imag parts of every element in X + xv[4] = _mm512_permute_pd(xv[4], 0x55); + xv[5] = _mm512_permute_pd(xv[5], 0x55); + xv[6] = _mm512_permute_pd(xv[6], 0x55); + xv[7] = _mm512_permute_pd(xv[7], 0x55); + + // Scale X with imag-part of alpha and add to Y + yv[4] = _mm512_fmadd_pd(alphaIv, xv[4], yv[4]); + yv[5] = _mm512_fmadd_pd(alphaIv, xv[5], yv[5]); + yv[6] = _mm512_fmadd_pd(alphaIv, xv[6], yv[6]); + yv[7] = _mm512_fmadd_pd(alphaIv, xv[7], yv[7]); + + // Store updated Y + _mm512_storeu_pd((y0 + 4 * n_elem_per_reg), yv[4]); + _mm512_storeu_pd((y0 + 5 * n_elem_per_reg), yv[5]); + _mm512_storeu_pd((y0 + 6 * n_elem_per_reg), yv[6]); + _mm512_storeu_pd((y0 + 7 * n_elem_per_reg), yv[7]); + + x0 += 8 * n_elem_per_reg; + y0 += 8 * n_elem_per_reg; + } + + for (; (i + 15) < n; i += 16) + { + // Loading elements from X + xv[0] = _mm512_loadu_pd(x0 + 0 * n_elem_per_reg); + xv[1] = _mm512_loadu_pd(x0 + 1 * n_elem_per_reg); + xv[2] = _mm512_loadu_pd(x0 + 2 * n_elem_per_reg); + xv[3] = _mm512_loadu_pd(x0 + 3 * n_elem_per_reg); + + // Loading elements from Y + yv[0] = _mm512_loadu_pd(y0 + 0 * n_elem_per_reg); + yv[1] = _mm512_loadu_pd(y0 + 1 * n_elem_per_reg); + yv[2] = _mm512_loadu_pd(y0 + 2 * n_elem_per_reg); + yv[3] = _mm512_loadu_pd(y0 + 3 * n_elem_per_reg); + + // Scale X with real-part of alpha and add to Y + yv[0] = _mm512_fmadd_pd(alphaRv, xv[0], yv[0]); + yv[1] = _mm512_fmadd_pd(alphaRv, xv[1], yv[1]); + yv[2] = _mm512_fmadd_pd(alphaRv, xv[2], yv[2]); + yv[3] = _mm512_fmadd_pd(alphaRv, xv[3], yv[3]); + + // Swapping real and imag parts of every element in X + xv[0] = _mm512_permute_pd(xv[0], 0x55); + xv[1] = _mm512_permute_pd(xv[1], 0x55); + xv[2] = _mm512_permute_pd(xv[2], 0x55); + xv[3] = _mm512_permute_pd(xv[3], 0x55); + + // Scale X with imag-part of alpha and add to Y + yv[0] = _mm512_fmadd_pd(alphaIv, xv[0], yv[0]); + yv[1] = _mm512_fmadd_pd(alphaIv, xv[1], yv[1]); + yv[2] = _mm512_fmadd_pd(alphaIv, xv[2], yv[2]); + yv[3] = _mm512_fmadd_pd(alphaIv, xv[3], yv[3]); + + // Store updated Y + _mm512_storeu_pd((y0 + 0 * n_elem_per_reg), yv[0]); + _mm512_storeu_pd((y0 + 1 * n_elem_per_reg), yv[1]); + _mm512_storeu_pd((y0 + 2 * n_elem_per_reg), yv[2]); + _mm512_storeu_pd((y0 + 3 * n_elem_per_reg), yv[3]); + + x0 += 4 * n_elem_per_reg; + y0 += 4 * n_elem_per_reg; + } + + for (; (i + 7) < n; i += 8) + { + // Loading elements from X + xv[0] = _mm512_loadu_pd(x0 + 0 * n_elem_per_reg); + xv[1] = _mm512_loadu_pd(x0 + 1 * n_elem_per_reg); + + // Loading elements from Y + yv[0] = _mm512_loadu_pd(y0 + 0 * n_elem_per_reg); + yv[1] = _mm512_loadu_pd(y0 + 1 * n_elem_per_reg); + + // Scale X with real-part of alpha and add to Y + yv[0] = _mm512_fmadd_pd(alphaRv, xv[0], yv[0]); + yv[1] = _mm512_fmadd_pd(alphaRv, xv[1], yv[1]); + + // Swapping real and imag parts of every element in X + xv[0] = _mm512_permute_pd(xv[0], 0x55); + xv[1] = _mm512_permute_pd(xv[1], 0x55); + + // Scale X with imag-part of alpha and add to Y + yv[0] = _mm512_fmadd_pd(alphaIv, xv[0], yv[0]); + yv[1] = _mm512_fmadd_pd(alphaIv, xv[1], yv[1]); + + // Store updated Y + _mm512_storeu_pd((y0 + 0 * n_elem_per_reg), yv[0]); + _mm512_storeu_pd((y0 + 1 * n_elem_per_reg), yv[1]); + + x0 += 2 * n_elem_per_reg; + y0 += 2 * n_elem_per_reg; + } + + for (; (i + 3) < n; i += 4) + { + // Loading elements from X + xv[0] = _mm512_loadu_pd(x0 + 0 * n_elem_per_reg); + + // Loading elements from Y + yv[0] = _mm512_loadu_pd(y0 + 0 * n_elem_per_reg); + + // Scale X with real-part of alpha and add to Y + yv[0] = _mm512_fmadd_pd(alphaRv, xv[0], yv[0]); + + // Swapping real and imag parts of every element in X + xv[0] = _mm512_permute_pd(xv[0], 0x55); + + // Scale X with imag-part of alpha and add to Y + yv[0] = _mm512_fmadd_pd(alphaIv, xv[0], yv[0]); + + // Store updated Y + _mm512_storeu_pd((y0 + 0 * n_elem_per_reg), yv[0]); + + x0 += n_elem_per_reg; + y0 += n_elem_per_reg; + + } + } + + if ( i < n ) + { + // Setting the mask bit based on remaining elements + // Since each dcomplex elements corresponds to 2 doubles + // we need to load and store 2*(n-i) elements. + __mmask8 n_mask = (1 << 2*(n - i)) - 1; + + // Loading elements from X + xv[0] = _mm512_maskz_loadu_pd(n_mask, x0); + + // Loading elements from Y + yv[0] = _mm512_maskz_loadu_pd(n_mask, y0); + + // Scale X with real-part of alpha and add to Y + yv[0] = _mm512_fmadd_pd(alphaRv, xv[0], yv[0]); + + // Swapping real and imag parts of every element in X + xv[0] = _mm512_permute_pd(xv[0], 0x55); + + // Scale X with imag-part of alpha and add to Y + yv[0] = _mm512_fmadd_pd(alphaIv, xv[0], yv[0]); + + // Store updated Y + _mm512_mask_storeu_pd(y0, n_mask, yv[0]); + } + } + else + { + __m128d xv, yv, temp, alphaRv, alphaIv; + + alphaRv = _mm_loaddup_pd((double *)alpha); + alphaIv = _mm_loaddup_pd((double *)alpha + 1); + + xv = _mm_setzero_pd(); + + if (bli_is_noconj(conjx)) + alphaIv = _mm_addsub_pd(xv, alphaIv); + else + { + alphaRv = _mm_addsub_pd(xv, alphaRv); + alphaRv = _mm_shuffle_pd(alphaRv, alphaRv, 0x01); + } + + for (; i < n; i += 1) + { + xv = _mm_loadu_pd(x0); + yv = _mm_loadu_pd(y0); + + temp = _mm_shuffle_pd(xv, xv, 0x01); + + temp = _mm_mul_pd(alphaIv, temp); + xv = _mm_mul_pd(alphaRv, xv); + + xv = _mm_add_pd(xv, temp); + yv = _mm_add_pd(yv, xv); + + _mm_storeu_pd(y0, yv); + + x0 += 2 * incx; + y0 += 2 * incy; + } + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) +} diff --git a/kernels/zen4/1/bli_copyv_zen4_asm_avx512.c b/kernels/zen4/1/bli_copyv_zen4_asm_avx512.c new file mode 100644 index 0000000000..02ccc9eed4 --- /dev/null +++ b/kernels/zen4/1/bli_copyv_zen4_asm_avx512.c @@ -0,0 +1,1764 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "immintrin.h" +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +// -------------------------------------------------------------------------------------- + +/* + Functionality + ------------- + + This function copies a vector x to a vector y for + type float. + + y := conj?(x) + + Function Signature + ------------------- + + * 'conjx' - Variable specified if x needs to be conjugated + * 'n' - Length of the array passed + * 'x' - Float pointer pointing to an array + * 'y' - Float pointer pointing to an array + * 'incx' - Stride to point to the next element in x array + * 'incy' - Stride to point to jthe next element in y array + * 'cntx' - BLIS context object + + Exception + ---------- + + None + + Deviation from BLAS + -------------------- + + None + + Undefined behaviour + ------------------- + + 1. The kernel results in undefined behaviour when n < 0, incx < 1 and incy < 1. + The expectation is that these are standard BLAS exceptions and should be handled in + a higher layer +*/ + +void bli_scopyv_zen4_asm_avx512 +( + conj_t conjx, + dim_t n, + float* restrict x, inc_t incx, + float* restrict y, inc_t incy, + cntx_t* restrict cntx +) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_2) + + // Initialize local pointers. + float *x0 = x; + float *y0 = y; + + // Typecast int to 64 bit + uint64_t n0 = (uint64_t)n; + int64_t incy0 = (int64_t)incy; + int64_t incx0 = (int64_t)incx; + + // If the vector dimension is zero return early. + if (bli_zero_dim1(n)) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2) + return; + } + + // Assembly Code + begin_asm() + + /* + rsi - > n + rdx - > x + rcx - > incx + r8 - > y + r9 - > incy + */ + + // Loading the source memory address to the respective registers + mov(var(x0), rdx) + mov(var(y0), r8) + + // Loading the values in 'n', 'incx' and 'incy' to the respective registers + mov(var(n0), rsi) + mov(var(incx0), rcx) + mov(var(incy0), r9) + + // Checking if incx == 1 and incy == 1, incase the condition fails then SCALAR code section is executed + cmp(imm(1),rcx) + jne(.SCALAR) + cmp(imm(1),r9) + jne(.SCALAR) + + // ======================================================================================================================== + + // Section of code to move the data as blocks of 256 elements + label(.BLOCK256) + + cmp(imm(16*16), rsi) // check if the number of remaining elements greater than or equal to 256 + jl(.BLOCK128) // else, goto to the section of code for block of size 128 + + label(.MAINLOOP) + + // Interleaved SIMD load and store operations to copy data from source to the destination + // Each vector register can hold 16 elements and is used twice before next jump operation + // 1 for loading the element from source and 1 for store it into the destination + + vmovups(mem(rdx, 0*64), zmm0) // zmm0 = x[i+0] - x[i+15] + vmovups(zmm0, mem(r8, 0*64)) // y[i+0] - y[i+15] = zmm0 + vmovups(mem(rdx, 1*64), zmm1) // zmm1 = x[i+16] - x[i+31] + vmovups(zmm1, mem(r8, 1*64)) // y[i+16] - y[i+31] = zmm1 + vmovups(mem(rdx, 2*64), zmm2) // zmm2 = x[i+32] - x[i+47] + vmovups(zmm2, mem(r8, 2*64)) // y[i+32] - y[i+47] = zmm2 + vmovups(mem(rdx, 3*64), zmm3) // zmm3 = x[i+48] - x[i+63] + vmovups(zmm3, mem(r8, 3*64)) // y[i+48] - y[i+63] = zmm3 + + vmovups(mem(rdx, 4*64), zmm4) // zmm4 = x[i+64] - x[i+79] + vmovups(zmm4, mem(r8, 4*64)) // y[i+64] - y[i+79] = zmm4 + vmovups(mem(rdx, 5*64), zmm5) // zmm5 = x[i+80] - x[i+95] + vmovups(zmm5, mem(r8, 5*64)) // y[i+80] - y[i+95] = zmm5 + vmovups(mem(rdx, 6*64), zmm6) // zmm6 = x[i+96] - x[i+111] + vmovups(zmm6, mem(r8, 6*64)) // y[i+96] - y[i+111] = zmm6 + vmovups(mem(rdx, 7*64), zmm7) // zmm7 = x[i+112] - x[i+127] + vmovups(zmm7, mem(r8, 7*64)) // y[i+112] - y[i+127] = zmm7 + + vmovups(mem(rdx, 8*64), zmm8) // zmm8 = x[i+128] - x[i+143] + vmovups(zmm8, mem(r8, 8*64)) // y[i+128] - y[i+143] = zmm8 + vmovups(mem(rdx, 9*64), zmm9) // zmm9 = x[i+144] - x[i+159] + vmovups(zmm9, mem(r8, 9*64)) // y[i+144] - y[i+159] = zmm9 + vmovups(mem(rdx, 10*64), zmm10) // zmm10 = x[i+160] - x[i+175] + vmovups(zmm10, mem(r8, 10*64)) // y[i+160] - y[i+175] = zmm10 + vmovups(mem(rdx, 11*64), zmm11) // zmm11 = x[i+176] - x[i+191] + vmovups(zmm11, mem(r8, 11*64)) // y[i+176] - y[i+191] = zmm11 + + vmovups(mem(rdx, 12*64), zmm12) // zmm12 = x[i+192] - x[i+207] + vmovups(zmm12, mem(r8, 12*64)) // y[i+192] - y[i+207] = zmm12 + vmovups(mem(rdx, 13*64), zmm13) // zmm13 = x[i+208] - x[i+223] + vmovups(zmm13, mem(r8, 13*64)) // y[i+208] - y[i+223] = zmm13 + vmovups(mem(rdx, 14*64), zmm14) // zmm14 = x[i+224] - x[i+239] + vmovups(zmm14, mem(r8, 14*64)) // y[i+224] - y[i+239] = zmm14 + vmovups(mem(rdx, 15*64), zmm15) // zmm15 = x[i+240] - x[i+255] + vmovups(zmm15, mem(r8, 15*64)) // y[i+240] - y[i+255] = zmm15 + + // Increment the pointer + add(imm(16*4*16), rdx) + add(imm(16*4*16), r8) + sub(imm(16*16), rsi) // reduce the number of remaining elements by 256 + + cmp(imm(16*16), rsi) + jge(.MAINLOOP) + + // ----------------------------------------------------------- + + // Section of code to move the data as blocks of 128 elements + label(.BLOCK128) + + cmp(imm(16*8), rsi) // check if the number of remaining elements greater than or equal to 128 + jl(.BLOCK64) // else, goto to the section of code for block of size 64 + + // Interleaved SIMD load and store operations to copy data from source to the destination + + vmovups(mem(rdx, 0*64), zmm0) // zmm0 = x[i+0] - x[i+15] + vmovups(zmm0, mem(r8, 0*64)) // y[i+0] - y[i+15] = zmm0 + vmovups(mem(rdx, 1*64), zmm1) // zmm1 = x[i+16] - x[i+31] + vmovups(zmm1, mem(r8, 1*64)) // y[i+16] - y[i+31] = zmm1 + vmovups(mem(rdx, 2*64), zmm2) // zmm2 = x[i+32] - x[i+47] + vmovups(zmm2, mem(r8, 2*64)) // y[i+32] - y[i+47] = zmm2 + vmovups(mem(rdx, 3*64), zmm3) // zmm3 = x[i+48] - x[i+63] + vmovups(zmm3, mem(r8, 3*64)) // y[i+48] - y[i+63] = zmm3 + + vmovups(mem(rdx, 4*64), zmm4) // zmm4 = x[i+64] - x[i+79] + vmovups(zmm4, mem(r8, 4*64)) // y[i+64] - y[i+79] = zmm4 + vmovups(mem(rdx, 5*64), zmm5) // zmm5 = x[i+80] - x[i+95] + vmovups(zmm5, mem(r8, 5*64)) // y[i+80] - y[i+95] = zmm5 + vmovups(mem(rdx, 6*64), zmm6) // zmm6 = x[i+96] - x[i+111] + vmovups(zmm6, mem(r8, 6*64)) // y[i+96] - y[i+111] = zmm6 + vmovups(mem(rdx, 7*64), zmm7) // zmm7 = x[i+112] - x[i+127] + vmovups(zmm7, mem(r8, 7*64)) // y[i+112] - y[i+127] = zmm7 + + // Increment the pointer + add(imm(16*4*8), rdx) + add(imm(16*4*8), r8) + sub(imm(16*8), rsi) // reduce the number of remaining elements by 128 + + // ----------------------------------------------------------- + + // Section of code to move the data as blocks of 64 elements + label(.BLOCK64) + + cmp(imm(16*4), rsi) // check if the number of remaining elements greater than or equal to 64 + jl(.BLOCK32) // else, goto to the section of code for block of size 32 + + // Interleaved SIMD load and store operations to copy data from source to the destination + + vmovups(mem(rdx, 0*64), zmm0) // zmm0 = x[i+0] - x[i+15] + vmovups(zmm0, mem(r8, 0*64)) // y[i+0] - y[i+15] = zmm0 + vmovups(mem(rdx, 1*64), zmm1) // zmm1 = x[i+16] - x[i+31] + vmovups(zmm1, mem(r8, 1*64)) // y[i+16] - y[i+31] = zmm1 + vmovups(mem(rdx, 2*64), zmm2) // zmm2 = x[i+32] - x[i+47] + vmovups(zmm2, mem(r8, 2*64)) // y[i+32] - y[i+47] = zmm2 + vmovups(mem(rdx, 3*64), zmm3) // zmm3 = x[i+48] - x[i+63] + vmovups(zmm3, mem(r8, 3*64)) // y[i+48] - y[i+63] = zmm3 + + // Increment the pointer + add(imm(16*4*4), rdx) + add(imm(16*4*4), r8) + sub(imm(16*4), rsi) // reduce the number of remaining elements by 64 + + // ----------------------------------------------------------- + + // Section of code to move the data as blocks of 32 elements + label(.BLOCK32) + + cmp(imm(16*2), rsi) // check if the number of remaining elements greater than or equal to 32 + jl(.BLOCK16) // else, goto to the section of code for block of size 16 + + // Interleaved SIMD load and store operations to copy data from source to the destination + + vmovups(mem(rdx, 0*64), zmm0) // zmm0 = x[i+0] - x[i+15] + vmovups(zmm0, mem(r8, 0*64)) // y[i+0] - y[i+15] = zmm0 + vmovups(mem(rdx, 1*64), zmm1) // zmm1 = x[i+16] - x[i+31] + vmovups(zmm1, mem(r8, 1*64)) // y[i+16] - y[i+31] = zmm1 + + add(imm(16*4*2), rdx) + add(imm(16*4*2), r8) + sub(imm(16*2), rsi) // reduce the number of remaining elements by 32 + + // ----------------------------------------------------------- + + // Section of code to move the data as blocks of 16 elements + label(.BLOCK16) + + cmp(imm(16), rsi) // check if the number of remaining elements greater than or equal to 16 + jl(.FRINGE) // else, goto to the section of code for fringe cases + + // Loading and storing the values to destination + + vmovups(mem(rdx, 0*64), zmm0) // zmm0 = x[i+0] - x[i+15] + vmovups(zmm0, mem(r8, 0*64)) // y[i+0] - y[i+15] = zmm0 + + // Increment the pointer + add(imm(16*4), rdx) + add(imm(16*4), r8) + sub(imm(16), rsi) // reduce the number of remaining elements by 16 + + // ----------------------------------------------------------- + + // Section of code to deal with fringe cases + label(.FRINGE) + + cmp(imm(0), rsi) // check if there is any fringe cases + je(.END) + + // Creating a 8-bit mask + mov(imm(65535), rcx) // (65535)BASE_10 -> (1111 1111 1111 1111)BASE_2 + shlx(rsi,rcx,rcx) // shifting the bits in the register to the left depending on the number of fringe elements remaining + xor(imm(65535),rcx) // taking compliment of the register + kmovq(rcx, k(2)) // copying the value in the register to mask register + + /* + Creating mask: Example - fringe case = 2 + step 1 : rdx = (1111 1111 1111 1111)BASE_2 or (65535)BASE_10 + step 2 : rdx = (1111 1111 1111 1100)BASE_2 or (65532)BASE_10 + step 3 : rdx = (0000 0000 0000 0011)BASE_2 or (3)BASE_10 + */ + + // Loading the input values using masked load + vmovups(mem(rdx, 0*64), zmm0 MASK_(K(2))) + + // Storing the values to destination using masked store + vmovups(zmm0, mem(r8) MASK_(K(2))) + + // After the above instructions are executed, the remaining part are not executed + jmp(.END) + + // ======================================================================================================================== + + // Code section used to deal with situations where incx or incy is not 1 + label(.SCALAR) + + // incx and incy are multipled by 8 (shift left by 2 bits) and stored back into their respective registers + mov(imm(2), r11) + shlx(r11, rcx, rcx) + shlx(r11, r9, r9) + + // A loop is used to move one element at a time to the destination + label(.SCALARLOOP) + + // checking if all the elements are moved, then the loop will be terminated + cmp(imm(0), rsi) + je(.END) + + // Using vector register to mov one element at a time + vmovss(mem(rdx, 0), xmm0) + vmovss(xmm0, mem(r8, 0)) + + // Moving the address pointer of x and y array by incx*8 and incy*8 bytes + add(rcx, rdx) + add(r9, r8) + + dec(rsi) + jmp(.SCALARLOOP) + + label(.END) + end_asm( + : + : [n0] "m" (n0), + [x0] "m" (x0), + [incx0] "m" (incx0), + [y0] "m" (y0), + [incy0] "m" (incy0) + + : "zmm0", "zmm1", "zmm2", "zmm3", + "zmm4", "zmm5", "zmm6", "zmm7", + "zmm8", "zmm9", "zmm10", "zmm11", + "zmm12", "zmm13", "zmm14", "zmm15", + "xmm0", "rsi", "rdx", "rcx", + "r8", "r9", "r11", "k2", + "memory" + ) + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2) +} + + +// -------------------------------------------------------------------------------------- + +/* + Functionality + ------------- + + This function copies a vector x to a vector y for + type double. + + y := conj?(x) + + Function Signature + ------------------- + + * 'conjx' - Variable specified if x needs to be conjugated + * 'n' - Length of the array passed + * 'x' - Double pointer pointing to an array + * 'y' - Double pointer pointing to an array + * 'incx' - Stride to point to the next element in x array + * 'incy' - Stride to point to the next element in y array + * 'cntx' - BLIS context object + + Exception + ---------- + + None + + Deviation from BLAS + -------------------- + + None + + Undefined behaviour + ------------------- + + 1. The kernel results in undefined behaviour when n < 0, incx < 1 and incy < 1. + The expectation is that these are standard BLAS exceptions and should be handled in + a higher layer +*/ + +void bli_dcopyv_zen4_asm_avx512 +( + conj_t conjx, + dim_t n, + double* restrict x, dim_t incx, + double* restrict y, dim_t incy, + cntx_t* restrict cntx +) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_2) + + // Initialize local pointers. + double *x0 = x; + double *y0 = y; + + // Typecast int to 64 bit + uint64_t n0 = (uint64_t)n; + int64_t incy0 = (int64_t)incy; + int64_t incx0 = (int64_t)incx; + + // If the vector dimension is zero return early. + if (bli_zero_dim1(n)) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2) + return; + } + + // assembly code + begin_asm() + + /* + rcx - > n + rsi - > x + r8 - > incx + rdi - > y + r9 - > incy + */ + + // Loading the source and destination memory addresses into the respective registers + mov(var(x0), rsi) + mov(var(y0), rdi) + + // Loading the values in n, incx and inxy into the respective registers + mov(var(n0), rcx) + mov(var(incx0), r8 ) + mov(var(incy0), r9 ) + + // Checking if incx == 1 and incy == 1, incase the condition fails then SCALAR code section is executed + cmp(imm(1), r8) + jne(.SCALAR) + cmp(imm(1),r9) + jne(.SCALAR) + +// ========================================================================================================================== + + // Section of code to move the data as blocks of 128 elements + label(.BLOCK128) + + cmp(imm(8*16), rcx) // Check if the number of remaining elements greater than or equal to 128 -> (NUMBER OF ELEMENTS PER REGISTER) * (NUMBER OF REGISTERS USED IN THE BLOCK) + jl(.BLOCK64) // Else, skip the BLOCK128 section and goto to BLOCK64 section of the code + + label(.MAINLOOP) + + // Interleaved SIMD load and store operations to copy data from source to the destination + // Each vector register can hold 8 elements and is used twice before next jump operation + // 1 vmovupd for loading the element from source and 1 vmovupd for store it into the destination + + vmovupd(mem(rsi, 0*64), zmm0) // zmm0 = x[i+0] - x[i+7] + vmovupd(zmm0, mem(rdi, 0*64)) // y[i+0] - y[i+7] = zmm0 + vmovupd(mem(rsi, 1*64), zmm1) // zmm1 = x[i+8] - x[i+15] + vmovupd(zmm1, mem(rdi, 1*64)) // y[i+8] - y[i+15] = zmm1 + vmovupd(mem(rsi, 2*64), zmm2) // zmm2 = x[i+16] - x[i+23] + vmovupd(zmm2, mem(rdi, 2*64)) // y[i+16] - y[i+23] = zmm2 + vmovupd(mem(rsi, 3*64), zmm3) // zmm3 = x[i+24] - x[i+31] + vmovupd(zmm3, mem(rdi, 3*64)) // y[i+24] - y[i+31] = zmm3 + + vmovupd(mem(rsi, 4*64), zmm4) // zmm4 = x[i+32] - x[i+39] + vmovupd(zmm4, mem(rdi, 4*64)) // y[i+32] - y[i+39] = zmm4 + vmovupd(mem(rsi, 5*64), zmm5) // zmm5 = x[i+40] - x[i+47] + vmovupd(zmm5, mem(rdi, 5*64)) // y[i+40] - y[i+47] = zmm5 + vmovupd(mem(rsi, 6*64), zmm6) // zmm6 = x[i+48] - x[i+55] + vmovupd(zmm6, mem(rdi, 6*64)) // y[i+48] - y[i+55] = zmm6 + vmovupd(mem(rsi, 7*64), zmm7) // zmm7 = x[i+56] - x[i+63] + vmovupd(zmm7, mem(rdi, 7*64)) // y[i+56] - y[i+63] = zmm7 + + vmovupd(mem(rsi, 8*64), zmm8) // zmm8 = x[i+64] - x[i+71] + vmovupd(zmm8, mem(rdi, 8*64)) // y[i+64] - y[i+71] = zmm8 + vmovupd(mem(rsi, 9*64), zmm9) // zmm9 = x[i+72] - x[i+79] + vmovupd(zmm9, mem(rdi, 9*64)) // y[i+72] - y[i+79] = zmm9 + vmovupd(mem(rsi, 10*64), zmm10) // zmm10 = x[i+80] - x[i+87] + vmovupd(zmm10, mem(rdi, 10*64)) // y[i+80] - y[i+87] = zmm10 + vmovupd(mem(rsi, 11*64), zmm11) // zmm11 = x[i+88] - x[i+95] + vmovupd(zmm11, mem(rdi, 11*64)) // y[i+88] - y[i+95] = zmm11 + + vmovupd(mem(rsi, 12*64), zmm12) // zmm12 = x[i+96] - x[i+103] + vmovupd(zmm12, mem(rdi, 12*64)) // y[i+96] - y[i+103] = zmm12 + vmovupd(mem(rsi, 13*64), zmm13) // zmm13 = x[i+104] - x[i+111] + vmovupd(zmm13, mem(rdi, 13*64)) // y[i+104] - y[i+111] = zmm13 + vmovupd(mem(rsi, 14*64), zmm14) // zmm14 = x[i+112] - x[i+119] + vmovupd(zmm14, mem(rdi, 14*64)) // y[i+112] - y[i+119] = zmm14 + vmovupd(mem(rsi, 15*64), zmm15) // zmm15 = x[i+120] - x[i+127] + vmovupd(zmm15, mem(rdi, 15*64)) // y[i+120] - y[i+127] = zmm15 + + // Increment the pointer + add(imm(8*8*16), rsi) // Increment the x0 pointer by 1024 -> ( Size of double datatype ) * ( Number of elements per register ) * ( Number of zmm registers used in the section of code ) + add(imm(8*8*16), rdi) // Increment the y0 pointer by 1024 + sub(imm(8*16), rcx) // reduce the number of remaining elements by 128 -> ( Number of elements per register ) * ( Number of zmm registers used in the section of code ) + + // Jump back to the Main loop if the number of remaning elements are still greater than 128 + cmp(imm(8*16), rcx) + jge(.MAINLOOP) + + // ----------------------------------------------------------- + + // Section of code to move the data as blocks of 64 elements + label(.BLOCK64) + + cmp(imm(8*8), rcx) // Check if the number of remaining elements greater than or equal to 64 + jl(.BLOCK32) // Else, skip the BLOCK64 section and goto to BLOCK32 section of the code + + // Interleaved SIMD load and store operations to copy data from source to the destination + + vmovupd(mem(rsi, 0*64), zmm0) // zmm0 = x[i+0] - x[i+7] + vmovupd(zmm0, mem(rdi, 0*64)) // y[i+0] - y[i+7] = zmm0 + vmovupd(mem(rsi, 1*64), zmm1) // zmm1 = x[i+8] - x[i+15] + vmovupd(zmm1, mem(rdi, 1*64)) // y[i+8] - y[i+15] = zmm1 + vmovupd(mem(rsi, 2*64), zmm2) // zmm2 = x[i+16] - x[i+23] + vmovupd(zmm2, mem(rdi, 2*64)) // y[i+16] - y[i+23] = zmm2 + vmovupd(mem(rsi, 3*64), zmm3) // zmm3 = x[i+24] - x[i+31] + vmovupd(zmm3, mem(rdi, 3*64)) // y[i+24] - y[i+31] = zmm3 + + vmovupd(mem(rsi, 4*64), zmm4) // zmm4 = x[i+32] - x[i+39] + vmovupd(zmm4, mem(rdi, 4*64)) // y[i+32] - y[i+39] = zmm4 + vmovupd(mem(rsi, 5*64), zmm5) // zmm5 = x[i+40] - x[i+47] + vmovupd(zmm5, mem(rdi, 5*64)) // y[i+40] - y[i+47] = zmm5 + vmovupd(mem(rsi, 6*64), zmm6) // zmm6 = x[i+48] - x[i+55] + vmovupd(zmm6, mem(rdi, 6*64)) // y[i+48] - y[i+55] = zmm6 + vmovupd(mem(rsi, 7*64), zmm7) // zmm7 = x[i+56] - x[i+63] + vmovupd(zmm7, mem(rdi, 7*64)) // y[i+56] - y[i+63] = zmm7 + + // Increment the pointer + add(imm(8*8*8), rsi) // Increment the x0 pointer by 512 + add(imm(8*8*8), rdi) // Increment the y0 pointer by 512 + sub(imm(8*8), rcx) // reduce the number of remaining elements by 64 + + // ----------------------------------------------------------- + + // Section of code to move the data as blocks of 32 elements + label(.BLOCK32) + + cmp(imm(8*4), rcx) // check if the number of remaining elements greater than or equal to 32 + jl(.BLOCK16) // Else, skip the BLOCK32 section and goto to BLOCK16 section of the code + + // Interleaved SIMD load and store operations to copy data from source to the destination + + vmovupd(mem(rsi, 0*64), zmm0) // zmm0 = x[i+0] - x[i+7] + vmovupd(zmm0, mem(rdi, 0*64)) // y[i+0] - y[i+7] = zmm0 + vmovupd(mem(rsi, 1*64), zmm1) // zmm1 = x[i+8] - x[i+15] + vmovupd(zmm1, mem(rdi, 1*64)) // y[i+8] - y[i+15] = zmm1 + vmovupd(mem(rsi, 2*64), zmm2) // zmm2 = x[i+16] - x[i+23] + vmovupd(zmm2, mem(rdi, 2*64)) // y[i+16] - y[i+23] = zmm2 + vmovupd(mem(rsi, 3*64), zmm3) // zmm3 = x[i+24] - x[i+31] + vmovupd(zmm3, mem(rdi, 3*64)) // y[i+24] - y[i+31] = zmm3 + + // Increment the pointer + add(imm(8*8*4), rsi) // Increment the x0 pointer by 256 + add(imm(8*8*4), rdi) // Increment the y0 pointer by 256 + sub(imm(8*4), rcx) // reduce the number of remaining elements by 32 + + // ----------------------------------------------------------- + + // Section of code to move the data as blocks of 16 elements + label(.BLOCK16) + + cmp(imm(8*2), rcx) // check if the number of remaining elements greater than or equal to 16 + jl(.BLOCK8) // else, skip the BLOCK16 section and goto to BLOCK8 section of the code + + // Interleaved SIMD load and store operations to copy data from source to the destination + + vmovupd(mem(rsi, 0*64), zmm0) // zmm0 = x[i+0] - x[i+7] + vmovupd(zmm0, mem(rdi, 0*64)) // y[i+0] - y[i+7] = zmm0 + vmovupd(mem(rsi, 1*64), zmm1) // zmm1 = x[i+8] - x[i+15] + vmovupd(zmm1, mem(rdi, 1*64)) // y[i+8] - y[i+15] = zmm1 + + // Increment the pointer + add(imm(8*8*2), rsi) // Increment the x0 pointer by 128 + add(imm(8*8*2), rdi) // Increment the y0 pointer by 128 + sub(imm(8*2), rcx) // reduce the number of remaining elements by 16 + + // ----------------------------------------------------------- + + // Section of code to move the data as blocks of 8 elements + label(.BLOCK8) + + cmp(imm(8), rcx) // check if the number of remaining elements greater than or equal to 8 + jl(.FRINGE) // else, skip the BLOCK8 section and goto to FRINGE section of the code + + // Load and store operations to copy data from source to the destination + + vmovupd(mem(rsi, 0*64), zmm0) // zmm0 = x[i+0] - x[i+7] + vmovupd(zmm0, mem(rdi, 0*64)) // y[i+0] - y[i+7] = zmm0 + + // Increment the pointer + add(imm(8*8), rsi) // Increment the x0 pointer by 64 + add(imm(8*8), rdi) // Increment the y0 pointer by 64 + sub(imm(8), rcx) // reduce the number of remaining elements by 8 + + // ----------------------------------------------------------- + + // Section of code to deal with fringe cases + label(.FRINGE) + + cmp(imm(0), rcx) // Check if there are any fringe cases + je(.END) // Else, skip rest of the code + + // Creating a 8-bit mask + mov(imm(255), r8) // (255)10 -> (1111 1111)2 + shlx(rcx, r8, r8) // shifting the bits in the register to the left depending on the number of fringe elements remaining + xor(imm(255), r8) // taking compliment of the register + + // Copying the 8-bit mask in the register to mask register + kmovq(r8, k(2)) + + /* + Creating mask: Example - fringe case = 2 + step 1 : r8 = (1111 1111)2 or (255)10 + step 2 : r8 = (1111 1100)2 or (252)10 + step 3 : r8 = (0000 0011)2 or (3)10 + */ + + // Loading the input values using masked load + vmovupd(mem(rsi), zmm0 MASK_(K(2))) + + // Storing the values to destination using masked store + vmovupd(zmm0, mem(rdi) MASK_(K(2))) + + // Multiple the value of remaining elements by 8 + mov(imm(3), r11) // Load the value 3 to r11 register + shlx(r11, rcx, r11) // Left-Shift the value in rcx by 8 + + // Increment the pointer + add(r11, rsi) // Increment the x0 pointer by (Number of remaining elements * 8) + add(r11, rdi) // Increment the y0 pointer by (Number of remaining elements * 8) + xor(rcx, rcx) // Set the value of remaining elements to 0 + + // After the above instructions are executed, the remaining part are skipped + jmp(.END) + + // ======================================================================================================================== + + // Code section used to deal with situations where incx or incy is not 1 + label(.SCALAR) + + // incx and incy are multipled by 8 (shift left by 3 bits) and stored back into their respective registers + mov(imm(3), r11) + shlx(r11, r8, r8) + shlx(r11, r9, r9) + + // A loop is used to move one element at a time to the destination + label(.SCALARLOOP) + + // Checking if all the elements are moved, then the loop will be terminated + cmp(imm(0), rcx) + je(.END) + + // Using vector register to mov one element at a time + vmovsd(mem(rsi, 0), xmm0) + vmovsd(xmm0, mem(rdi, 0)) + + // Moving the address pointer of x and y array by incx*8 and incy*8 bytes + add(r8, rsi) + add(r9, rdi) + + // Decrease the count for number of remaining elements + dec(rcx) + + // Jump back to SCALARLOOP + jmp(.SCALARLOOP) + + label(.END) + end_asm( + : + : [n0] "m" (n0), + [x0] "m" (x0), + [incx0] "m" (incx0), + [y0] "m" (y0), + [incy0] "m" (incy0) + + : "zmm0", "zmm1", "zmm2", "zmm3", + "zmm4", "zmm5", "zmm6", "zmm7", + "zmm8", "zmm9", "zmm10", "zmm11", + "zmm12", "zmm13", "zmm14", "zmm15", + "rsi", "rdi", "rcx", "r8", + "r9", "r11", "k2", "xmm0", + "memory" + ) + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2) +} + +// ----------------------------------------------------------------------------- + +/* + Functionality + ------------- + + This function copies a double complex vector x to a double complex vector y. + + y := conj?(x) + + Function Signature + ------------------- + + * 'conjx' - Variable specified if x needs to be conjugated + * 'n' - Length of the array passed + * 'x' - Double pointer pointing to an array + * 'y' - Double pointer pointing to an array + * 'incx' - Stride to point to the next element in x array + * 'incy' - Stride to point to the next element in y array + * 'cntx' - BLIS context object + + Exception + ---------- + + None + + Deviation from BLAS + -------------------- + + None + + Undefined behaviour + ------------------- + + 1. The kernel results in undefined behaviour when n < 0, incx < 1 and incy < 1. + The expectation is that these are standard BLAS exceptions and should be handled in + a higher layer +*/ + +void bli_zcopyv_zen4_asm_avx512 +( + conj_t conjx, + dim_t n, + dcomplex* restrict x, inc_t incx, + dcomplex* restrict y, inc_t incy, + cntx_t* restrict cntx +) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_2) + + // Initialize local pointers. + dcomplex *x0 = x; + dcomplex *y0 = y; + + // Typecast int to 64 bit + uint64_t n0 = (uint64_t)n; + + // If the vector dimension is zero return early. + if (bli_zero_dim1(n)) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2) + return; + } + + if (bli_is_conj(conjx)) + { + if (incx == 1 && incy == 1) + { + // assembly code + begin_asm() + + /* + rdi - > conjx + rsi - > n + rdx - > x + rcx - > incx + r8 - > y + r9 - > incy + */ + + // Loading the source memory address to respective registers + mov(var(x0), rdx) + mov(var(y0), r8) + + // Loading the value of 'n' into rsi register + mov(var(n0), rsi) + + // Setting the value of zmm16 to zero + vxorpd(zmm16, zmm16, zmm16) + + // =========================================================== + + // Section of code to move the data as blocks of 64 elements + label(.BLOCK64) + + cmp(imm(4*16), rsi) // check if the number of remaining elements greater than or equal to 64 + jl(.BLOCK32) // else, goto to the section of code for block of size 32 + + label(.MAINLOOP) + // Interleaved SIMD load, conjugate and store operations to copy data from source to the destination + + vmovupd(mem(rdx, 0*64), zmm0) // zmm0 = x[i+0] - x[i+3] + vfmsubadd231pd(zmm16, zmm16, zmm0) // zmm0 = conj(zmm0) + vmovupd(zmm0, mem(r8, 0*64)) // y[i+0] - y[i+3] = zmm0 + vmovupd(mem(rdx, 1*64), zmm1) // zmm1 = x[i+4] - x[i+7] + vfmsubadd231pd(zmm16, zmm16, zmm1) // zmm1 = conj(zmm1) + vmovupd(zmm1, mem(r8, 1*64)) // y[i+4] - y[i+7] = zmm1 + vmovupd(mem(rdx, 2*64), zmm2) // zmm2 = x[i+8] - x[i+11] + vfmsubadd231pd(zmm16, zmm16, zmm2) // zmm2 = conj(zmm2) + vmovupd(zmm2, mem(r8, 2*64)) // y[i+8] - y[i+11] = zmm2 + vmovupd(mem(rdx, 3*64), zmm3) // zmm3 = x[i+12] - x[i+15] + vfmsubadd231pd(zmm16, zmm16, zmm3) // zmm3 = conj(zmm3) + vmovupd(zmm3, mem(r8, 3*64)) // y[i+12] - y[i+15] = zmm3 + + vmovupd(mem(rdx, 4*64), zmm4) // zmm4 = x[i+16] - x[i+19] + vfmsubadd231pd(zmm16, zmm16, zmm4) // zmm4 = conj(zmm4) + vmovupd(zmm4, mem(r8, 4*64)) // y[i+16] - y[i+19] = zmm4 + vmovupd(mem(rdx, 5*64), zmm5) // zmm5 = x[i+20] - x[i+23] + vfmsubadd231pd(zmm16, zmm16, zmm5) // zmm5 = conj(zmm5) + vmovupd(zmm5, mem(r8, 5*64)) // y[i+20] - y[i+23] = zmm5 + vmovupd(mem(rdx, 6*64), zmm6) // zmm6 = x[i+24] - x[i+27] + vfmsubadd231pd(zmm16, zmm16, zmm6) // zmm6 = conj(zmm6) + vmovupd(zmm6, mem(r8, 6*64)) // y[i+24] - y[i+27] = zmm6 + vmovupd(mem(rdx, 7*64), zmm7) // zmm7 = x[i+28] - x[i+31] + vfmsubadd231pd(zmm16, zmm16, zmm7) // zmm7 = conj(zmm7) + vmovupd(zmm7, mem(r8, 7*64)) // y[i+28] - y[i+31] = zmm7 + + vmovupd(mem(rdx, 8*64), zmm8) // zmm8 = x[i+32] - x[i+35] + vfmsubadd231pd(zmm16, zmm16, zmm8) // zmm8 = conj(zmm8) + vmovupd(zmm8, mem(r8, 8*64)) // y[i+32] - y[i+35] = zmm8 + vmovupd(mem(rdx, 9*64), zmm9) // zmm9 = x[i+36] - x[i+39] + vfmsubadd231pd(zmm16, zmm16, zmm9) // zmm9 = conj(zmm9) + vmovupd(zmm9, mem(r8, 9*64)) // y[i+36] - y[i+39] = zmm9 + vmovupd(mem(rdx, 10*64), zmm10) // zmm10 = x[i+40] - x[i+43] + vfmsubadd231pd(zmm16, zmm16, zmm10) // zmm10 = conj(zmm10) + vmovupd(zmm10, mem(r8, 10*64)) // y[i+40] - y[i+43] = zmm10 + vmovupd(mem(rdx, 11*64), zmm11) // zmm11 = x[i+44] - x[i+47] + vfmsubadd231pd(zmm16, zmm16, zmm11) // zmm11 = conj(zmm11) + vmovupd(zmm11, mem(r8, 11*64)) // y[i+44] - y[i+47] = zmm11 + + vmovupd(mem(rdx, 12*64), zmm12) // zmm12 = x[i+48] - x[i+51] + vfmsubadd231pd(zmm16, zmm16, zmm12) // zmm12 = conj(zmm12) + vmovupd(zmm12, mem(r8, 12*64)) // y[i+48] - y[i+51] = zmm12 + vmovupd(mem(rdx, 13*64), zmm13) // zmm13 = x[i+52] - x[i+55] + vfmsubadd231pd(zmm16, zmm16, zmm13) // zmm13 = conj(zmm13) + vmovupd(zmm13, mem(r8, 13*64)) // y[i+52] - y[i+55] = zmm13 + vmovupd(mem(rdx, 14*64), zmm14) // zmm14 = x[i+56] - x[i+59] + vfmsubadd231pd(zmm16, zmm16, zmm14) // zmm14 = conj(zmm14) + vmovupd(zmm14, mem(r8, 14*64)) // y[i+56] - y[i+59] = zmm14 + vmovupd(mem(rdx, 15*64), zmm15) // zmm15 = x[i+60] - x[i+63] + vfmsubadd231pd(zmm16, zmm16, zmm15) // zmm15 = conj(zmm15) + vmovupd(zmm15, mem(r8, 15*64)) // y[i+60] - y[i+63] = zmm15 + + // Increment the pointer + add(imm(16*4*16), rdx) // ( Size of double datatype ) * ( Number of elements per register ) * ( Number of zmm registers used in the section of code ) + add(imm(16*4*16), r8) + sub(imm(4*16), rsi) // reduce the number of remaining elements by 64 -> ( Number of elements per register ) * ( Number of zmm registers used in the section of code ) + + cmp(imm(4*16), rsi) + jge(.MAINLOOP) + + // ----------------------------------------------------------- + + // Section of code to move the data as blocks of 32 elements + label(.BLOCK32) + + cmp(imm(4*8), rsi) // check if the number of remaining elements greater than or equal to 32 + jl(.BLOCK16) // else, goto to the section of code for block of size 16 + + // Interleaved SIMD load, conjugate and store operations to copy data from source to the destination + + vmovupd(mem(rdx, 0*64), zmm0) // zmm0 = x[i+0] - x[i+3] + vfmsubadd231pd(zmm16, zmm16, zmm0) // zmm0 = conj(zmm0) + vmovupd(zmm0, mem(r8, 0*64)) // y[i+0] - y[i+3] = zmm0 + vmovupd(mem(rdx, 1*64), zmm1) // zmm1 = x[i+4] - x[i+7] + vfmsubadd231pd(zmm16, zmm16, zmm1) // zmm1 = conj(zmm1) + vmovupd(zmm1, mem(r8, 1*64)) // y[i+4] - y[i+7] = zmm1 + vmovupd(mem(rdx, 2*64), zmm2) // zmm2 = x[i+8] - x[i+11] + vfmsubadd231pd(zmm16, zmm16, zmm2) // zmm2 = conj(zmm2) + vmovupd(zmm2, mem(r8, 2*64)) // y[i+8] - y[i+11] = zmm2 + vmovupd(mem(rdx, 3*64), zmm3) // zmm3 = x[i+12] - x[i+15] + vfmsubadd231pd(zmm16, zmm16, zmm3) // zmm3 = conj(zmm3) + vmovupd(zmm3, mem(r8, 3*64)) // y[i+12] - y[i+15] = zmm3 + + vmovupd(mem(rdx, 4*64), zmm4) // zmm4 = x[i+16] - x[i+19] + vfmsubadd231pd(zmm16, zmm16, zmm4) // zmm4 = conj(zmm4) + vmovupd(zmm4, mem(r8, 4*64)) // y[i+16] - y[i+19] = zmm4 + vmovupd(mem(rdx, 5*64), zmm5) // zmm5 = x[i+20] - x[i+23] + vfmsubadd231pd(zmm16, zmm16, zmm5) // zmm5 = conj(zmm5) + vmovupd(zmm5, mem(r8, 5*64)) // y[i+20] - y[i+23] = zmm5 + vmovupd(mem(rdx, 6*64), zmm6) // zmm6 = x[i+24] - x[i+27] + vfmsubadd231pd(zmm16, zmm16, zmm6) // zmm6 = conj(zmm6) + vmovupd(zmm6, mem(r8, 6*64)) // y[i+24] - y[i+27] = zmm6 + vmovupd(mem(rdx, 7*64), zmm7) // zmm7 = x[i+28] - x[i+31] + vfmsubadd231pd(zmm16, zmm16, zmm7) // zmm7 = conj(zmm7) + vmovupd(zmm7, mem(r8, 7*64)) // y[i+28] - y[i+31] = zmm7 + + // Increment the pointer + add(imm(16*4*8), rdx) + add(imm(16*4*8), r8) + sub(imm(4*8), rsi) // reduce the number of remaining elements by 32 + + // ----------------------------------------------------------- + + // Section of code to move the data as blocks of 16 elements + label(.BLOCK16) + + cmp(imm(4*4), rsi) // check if the number of remaining elements greater than or equal to 16 + jl(.BLOCK8) // else, goto to the section of code for block of size 8 + + // Interleaved SIMD load, conjugate and store operations to copy data from source to the destination + + vmovupd(mem(rdx, 0*64), zmm0) // zmm0 = x[i+0] - x[i+3] + vfmsubadd231pd(zmm16, zmm16, zmm0) // zmm0 = conj(zmm0) + vmovupd(zmm0, mem(r8, 0*64)) // y[i+0] - y[i+3] = zmm0 + vmovupd(mem(rdx, 1*64), zmm1) // zmm1 = x[i+4] - x[i+7] + vfmsubadd231pd(zmm16, zmm16, zmm1) // zmm1 = conj(zmm1) + vmovupd(zmm1, mem(r8, 1*64)) // y[i+4] - y[i+7] = zmm1 + vmovupd(mem(rdx, 2*64), zmm2) // zmm2 = x[i+8] - x[i+11] + vfmsubadd231pd(zmm16, zmm16, zmm2) // zmm2 = conj(zmm2) + vmovupd(zmm2, mem(r8, 2*64)) // y[i+8] - y[i+11] = zmm2 + vmovupd(mem(rdx, 3*64), zmm3) // zmm3 = x[i+12] - x[i+15] + vfmsubadd231pd(zmm16, zmm16, zmm3) // zmm3 = conj(zmm3) + vmovupd(zmm3, mem(r8, 3*64)) // y[i+12] - y[i+15] = zmm3 + + // Increment the pointer + add(imm(16*4*4), rdx) + add(imm(16*4*4), r8) + sub(imm(4*4), rsi) // reduce the number of remaining elements by 16 + + // ----------------------------------------------------------- + + // Section of code to move the data as blocks of 8 elements + label(.BLOCK8) + + cmp(imm(4*2), rsi) // check if the number of remaining elements greater than or equal to 8 + jl(.BLOCK4) // else, goto to the section of code for block of size 4 + + // Interleaved SIMD load, conjugate and store operations to copy data from source to the destination + + vmovupd(mem(rdx, 0*64), zmm0) // zmm0 = x[i+0] - x[i+3] + vfmsubadd231pd(zmm16, zmm16, zmm0) // zmm0 = conj(zmm0) + vmovupd(zmm0, mem(r8, 0*64)) // y[i+0] - y[i+3] = zmm0 + vmovupd(mem(rdx, 1*64), zmm1) // zmm1 = x[i+4] - x[i+7] + vfmsubadd231pd(zmm16, zmm16, zmm1) // zmm1 = conj(zmm1) + vmovupd(zmm1, mem(r8, 1*64)) // y[i+4] - y[i+7] = zmm1 + + // Increment the pointer + add(imm(16*4*2), rdx) + add(imm(16*4*2), r8) + sub(imm(4*2), rsi) // reduce the number of remaining elements by 8 + + // ----------------------------------------------------------- + + // Section of code to move the data as blocks of 4 elements + label(.BLOCK4) + + cmp(imm(4), rsi) // check if the number of remaining elements greater than or equal to 4 + jl(.FRINGE) // else, goto to the section of code that deals with fringe cases + + // Load, conjugate and store operations to copy data from source to the destination + + vmovupd(mem(rdx, 0*64), zmm0) // zmm0 = x[i+0] - x[i+3] + vfmsubadd231pd(zmm16, zmm16, zmm0) // zmm0 = conj(zmm0) + vmovupd(zmm0, mem(r8, 0*64)) // y[i+0] - y[i+3] = zmm0 + + // Increment the pointer + add(imm(16*4), rdx) + add(imm(16*4), r8) + sub(imm(4), rsi) // reduce the number of remaining elements by 4 + + // ----------------------------------------------------------- + + // Section of code to deal with fringe cases + label(.FRINGE) + + cmp(imm(0), rsi) // check if there is any fringe cases + je(.END) + + // Creating a 8-bit mask + mov(imm(255), rcx) // (255)10 -> (1111 1111)2 + shlx(rsi, rcx, rcx) // shifting the bits in the register to the left depending on the number of fringe elements remaining + shlx(rsi, rcx, rcx) + xor(imm(255),rcx) // taking compliment of the register + kmovq(rcx, k(2)) // copying the value in the register to mask register + + /* + Creating mask: Example - fringe case = 1 + step 1 : rcx = (1111 1111)2 or (255)10 + step 2 : rcx = (1111 1110)2 or (254)10 + step 3 : rcx = (1111 1100)2 or (252)10 + step 4 : rcx = (0000 0011)2 or (3)10 + */ + // Loading the input values using masked load + vmovupd(mem(rdx, 0*64), zmm0 MASK_(K(2))) + + // Using Fused Multiply-AlternatingAdd/Subtract operation to get conjugate of the input + vfmsubadd231pd(zmm16, zmm16, zmm0) + + // Storing the values to destination using masked store + vmovupd(zmm0, mem(r8) MASK_(K(2))) + + // Increment the pointer + add(rsi, rdx) + add(rsi, r8) + and(imm(0), rsi) + + label(.END) + end_asm( + : + : [n0] "m" (n0), + [x0] "m" (x0), + [y0] "m" (y0) + + : "zmm0", "zmm1", "zmm2", "zmm3", + "zmm4", "zmm5", "zmm6", "zmm7", + "zmm8", "zmm9", "zmm10", "zmm11", + "zmm12", "zmm13", "zmm14", "zmm15", + "zmm16", "rsi", "rdx", "rcx", + "r8", "r9", "k2", "memory" + ) + } + else + { + // Since double complex elements are of size 128 bits, + // vectorization can be done using XMM registers when incx and incy are not 1. + // This is done in the else condition. + dim_t i = 0; + __m128d xv[16]; + __m128d zero_reg = _mm_setzero_pd(); + + // n & (~0x0F) = n & 0xFFFFFFF0 -> this masks the numbers less than 16, + // if value of n < 16, then (n & (~0x0F)) = 0 + // the copy operation will be done for the multiples of 16 + for ( i = 0; i < (n & (~0x0F)); i += 16) + { + // Loading the input values + xv[0] = _mm_loadu_pd((double *)(x0 + 0 * incx)); + xv[1] = _mm_loadu_pd((double *)(x0 + 1 * incx)); + xv[2] = _mm_loadu_pd((double *)(x0 + 2 * incx)); + xv[3] = _mm_loadu_pd((double *)(x0 + 3 * incx)); + + xv[4] = _mm_loadu_pd((double *)(x0 + 4 * incx)); + xv[5] = _mm_loadu_pd((double *)(x0 + 5 * incx)); + xv[6] = _mm_loadu_pd((double *)(x0 + 6 * incx)); + xv[7] = _mm_loadu_pd((double *)(x0 + 7 * incx)); + + xv[8] = _mm_loadu_pd((double *)(x0 + 8 * incx)); + xv[9] = _mm_loadu_pd((double *)(x0 + 9 * incx)); + xv[10] = _mm_loadu_pd((double *)(x0 + 10 * incx)); + xv[11] = _mm_loadu_pd((double *)(x0 + 11 * incx)); + + xv[12] = _mm_loadu_pd((double *)(x0 + 12 * incx)); + xv[13] = _mm_loadu_pd((double *)(x0 + 13 * incx)); + xv[14] = _mm_loadu_pd((double *)(x0 + 14 * incx)); + xv[15] = _mm_loadu_pd((double *)(x0 + 15 * incx)); + + // Perform conjugation by multiplying the imaginary part with -1 and real part with 1 + xv[0] = _mm_fmsubadd_pd(zero_reg, zero_reg, xv[0]); + xv[1] = _mm_fmsubadd_pd(zero_reg, zero_reg, xv[1]); + xv[2] = _mm_fmsubadd_pd(zero_reg, zero_reg, xv[2]); + xv[3] = _mm_fmsubadd_pd(zero_reg, zero_reg, xv[3]); + + xv[4] = _mm_fmsubadd_pd(zero_reg, zero_reg, xv[4]); + xv[5] = _mm_fmsubadd_pd(zero_reg, zero_reg, xv[5]); + xv[6] = _mm_fmsubadd_pd(zero_reg, zero_reg, xv[6]); + xv[7] = _mm_fmsubadd_pd(zero_reg, zero_reg, xv[7]); + + xv[8] = _mm_fmsubadd_pd(zero_reg, zero_reg, xv[8]); + xv[9] = _mm_fmsubadd_pd(zero_reg, zero_reg, xv[9]); + xv[10] = _mm_fmsubadd_pd(zero_reg, zero_reg, xv[10]); + xv[11] = _mm_fmsubadd_pd(zero_reg, zero_reg, xv[11]); + + xv[12] = _mm_fmsubadd_pd(zero_reg, zero_reg, xv[12]); + xv[13] = _mm_fmsubadd_pd(zero_reg, zero_reg, xv[13]); + xv[14] = _mm_fmsubadd_pd(zero_reg, zero_reg, xv[14]); + xv[15] = _mm_fmsubadd_pd(zero_reg, zero_reg, xv[15]); + + // Storing the values to destination + _mm_storeu_pd((double *)(y0 + incy * 0), xv[0]); + _mm_storeu_pd((double *)(y0 + incy * 1), xv[1]); + _mm_storeu_pd((double *)(y0 + incy * 2), xv[2]); + _mm_storeu_pd((double *)(y0 + incy * 3), xv[3]); + + _mm_storeu_pd((double *)(y0 + incy * 4), xv[4]); + _mm_storeu_pd((double *)(y0 + incy * 5), xv[5]); + _mm_storeu_pd((double *)(y0 + incy * 6), xv[6]); + _mm_storeu_pd((double *)(y0 + incy * 7), xv[7]); + + _mm_storeu_pd((double *)(y0 + incy * 8), xv[8]); + _mm_storeu_pd((double *)(y0 + incy * 9 ), xv[9]); + _mm_storeu_pd((double *)(y0 + incy * 10), xv[10]); + _mm_storeu_pd((double *)(y0 + incy * 11), xv[11]); + + _mm_storeu_pd((double *)(y0 + incy * 12), xv[12]); + _mm_storeu_pd((double *)(y0 + incy * 13), xv[13]); + _mm_storeu_pd((double *)(y0 + incy * 14), xv[14]); + _mm_storeu_pd((double *)(y0 + incy * 15), xv[15]); + + // Increment the pointer + x0 += 16 * incx; + y0 += 16 * incy; + } + + for ( ; i < (n & (~0x07)); i += 8) + { + // Loading the input values + xv[0] = _mm_loadu_pd((double *)(x0 + 0 * incx)); + xv[1] = _mm_loadu_pd((double *)(x0 + 1 * incx)); + xv[2] = _mm_loadu_pd((double *)(x0 + 2 * incx)); + xv[3] = _mm_loadu_pd((double *)(x0 + 3 * incx)); + + xv[4] = _mm_loadu_pd((double *)(x0 + 4 * incx)); + xv[5] = _mm_loadu_pd((double *)(x0 + 5 * incx)); + xv[6] = _mm_loadu_pd((double *)(x0 + 6 * incx)); + xv[7] = _mm_loadu_pd((double *)(x0 + 7 * incx)); + + // Perform conjugation by multiplying the imaginary part with -1 and real part with 1 + xv[0] = _mm_fmsubadd_pd(zero_reg, zero_reg, xv[0]); + xv[1] = _mm_fmsubadd_pd(zero_reg, zero_reg, xv[1]); + xv[2] = _mm_fmsubadd_pd(zero_reg, zero_reg, xv[2]); + xv[3] = _mm_fmsubadd_pd(zero_reg, zero_reg, xv[3]); + + xv[4] = _mm_fmsubadd_pd(zero_reg, zero_reg, xv[4]); + xv[5] = _mm_fmsubadd_pd(zero_reg, zero_reg, xv[5]); + xv[6] = _mm_fmsubadd_pd(zero_reg, zero_reg, xv[6]); + xv[7] = _mm_fmsubadd_pd(zero_reg, zero_reg, xv[7]); + + // Storing the values to destination + _mm_storeu_pd((double *)(y0 + incy * 0), xv[0]); + _mm_storeu_pd((double *)(y0 + incy * 1), xv[1]); + _mm_storeu_pd((double *)(y0 + incy * 2), xv[2]); + _mm_storeu_pd((double *)(y0 + incy * 3), xv[3]); + + _mm_storeu_pd((double *)(y0 + incy * 4), xv[4]); + _mm_storeu_pd((double *)(y0 + incy * 5), xv[5]); + _mm_storeu_pd((double *)(y0 + incy * 6), xv[6]); + _mm_storeu_pd((double *)(y0 + incy * 7), xv[7]); + + // Increment the pointer + x0 += 8 * incx; + y0 += 8 * incy; + } + + for ( ; i < (n & (~0x03)); i += 4) + { + // Loading the input values + xv[0] = _mm_loadu_pd((double *)(x0 + 0 * incx)); + xv[1] = _mm_loadu_pd((double *)(x0 + 1 * incx)); + xv[2] = _mm_loadu_pd((double *)(x0 + 2 * incx)); + xv[3] = _mm_loadu_pd((double *)(x0 + 3 * incx)); + + // Perform conjugation by multiplying the imaginary part with -1 and real part with 1 + xv[0] = _mm_fmsubadd_pd(zero_reg, zero_reg, xv[0]); + xv[1] = _mm_fmsubadd_pd(zero_reg, zero_reg, xv[1]); + xv[2] = _mm_fmsubadd_pd(zero_reg, zero_reg, xv[2]); + xv[3] = _mm_fmsubadd_pd(zero_reg, zero_reg, xv[3]); + + // Storing the values to destination + _mm_storeu_pd((double *)(y0 + incy * 0), xv[0]); + _mm_storeu_pd((double *)(y0 + incy * 1), xv[1]); + _mm_storeu_pd((double *)(y0 + incy * 2), xv[2]); + _mm_storeu_pd((double *)(y0 + incy * 3), xv[3]); + + // Increment the pointer + x0 += 4 * incx; + y0 += 4 * incy; + } + + for ( ; i < (n & (~0x01)); i += 2) + { + // Loading the input values + xv[0] = _mm_loadu_pd((double *)(x0 + 0 * incx)); + xv[1] = _mm_loadu_pd((double *)(x0 + 1 * incx)); + + // Perform conjugation by multiplying the imaginary part with -1 and real part with 1 + xv[0] = _mm_fmsubadd_pd(zero_reg, zero_reg, xv[0]); + xv[1] = _mm_fmsubadd_pd(zero_reg, zero_reg, xv[1]); + + // Storing the values to destination + _mm_storeu_pd((double *)(y0 + incy * 0), xv[0]); + _mm_storeu_pd((double *)(y0 + incy * 1), xv[1]); + + // Increment the pointer + x0 += 2 * incx; + y0 += 2 * incy; + } + + for ( ; i < n; i += 1) + { + // Loading the input values + xv[0] = _mm_loadu_pd((double *)(x0 + 0 * incx)); + + // Perform conjugation by multiplying the imaginary part with -1 and real part with 1 + xv[0] = _mm_fmsubadd_pd(zero_reg, zero_reg, xv[0]); + + // Storing the values to destination + _mm_storeu_pd((double *)(y0 + incy * 0), xv[0]); + + // Increment the pointer + x0 += 1 * incx; + y0 += 1 * incy; + } + } + } + else + { + if (incx == 1 && incy == 1) + { + // assembly code + begin_asm() + + /* + rdi - > conjx + rsi - > n + rdx - > x + rcx - > incx + r8 - > y + r9 - > incy + */ + + // Loading the source memory address to respective registers + mov(var(x0), rdx) + mov(var(y0), r8) + + // Loading the value of 'n' to respective register + mov(var(n0), rsi) + + // =========================================================== + + // Section of code to move the data as blocks of 128 elements + label(.BLOCK128) + + cmp(imm(4*32), rsi) // check if the number of remaining elements greater than or equal to 128 -> (NUMBER OF ELEMENTS PER REGISTER) * (NUMBER OF REGISTERS USED IN THE BLOCK) + jl(.BLOCK64) // else, goto block of size 64 + + label(.MAINLOOP) + // Interleaved SIMD load and store operations to copy data from source to the destination + // Each vector register can hold 4 elements and is used twice before next jump operation + // 1 for loading the element from source and 1 for store it into the destination + + vmovupd(mem(rdx, 0*64), zmm0) // zmm0 = x[i+0] - x[i+3] + vmovupd(zmm0, mem(r8, 0*64)) // y[i+0] - y[i+3] = zmm0 + vmovupd(mem(rdx, 1*64), zmm1) // zmm1 = x[i+4] - x[i+7] + vmovupd(zmm1, mem(r8, 1*64)) // y[i+4] - y[i+7] = zmm1 + vmovupd(mem(rdx, 2*64), zmm2) // zmm2 = x[i+8] - x[i+11] + vmovupd(zmm2, mem(r8, 2*64)) // y[i+8] - y[i+11] = zmm2 + vmovupd(mem(rdx, 3*64), zmm3) // zmm3 = x[i+12] - x[i+15] + vmovupd(zmm3, mem(r8, 3*64)) // y[i+12] - y[i+15] = zmm3 + + vmovupd(mem(rdx, 4*64), zmm4) // zmm4 = x[i+16] - x[i+19] + vmovupd(zmm4, mem(r8, 4*64)) // y[i+16] - y[i+19] = zmm4 + vmovupd(mem(rdx, 5*64), zmm5) // zmm5 = x[i+20] - x[i+23] + vmovupd(zmm5, mem(r8, 5*64)) // y[i+20] - y[i+23] = zmm5 + vmovupd(mem(rdx, 6*64), zmm6) // zmm6 = x[i+24] - x[i+27] + vmovupd(zmm6, mem(r8, 6*64)) // y[i+24] - y[i+27] = zmm6 + vmovupd(mem(rdx, 7*64), zmm7) // zmm7 = x[i+28] - x[i+31] + vmovupd(zmm7, mem(r8, 7*64)) // y[i+28] - y[i+31] = zmm7 + + vmovupd(mem(rdx, 8*64), zmm8) // zmm8 = x[i+32] - x[i+35] + vmovupd(zmm8, mem(r8, 8*64)) // y[i+32] - y[i+35] = zmm8 + vmovupd(mem(rdx, 9*64), zmm9) // zmm9 = x[i+36] - x[i+39] + vmovupd(zmm9, mem(r8, 9*64)) // y[i+36] - y[i+39] = zmm9 + vmovupd(mem(rdx, 10*64), zmm10) // zmm10 = x[i+40] - x[i+43] + vmovupd(zmm10, mem(r8, 10*64)) // y[i+40] - y[i+43] = zmm10 + vmovupd(mem(rdx, 11*64), zmm11) // zmm11 = x[i+44] - x[i+47] + vmovupd(zmm11, mem(r8, 11*64)) // y[i+44] - y[i+47] = zmm11 + + vmovupd(mem(rdx, 12*64), zmm12) // zmm12 = x[i+48] - x[i+51] + vmovupd(zmm12, mem(r8, 12*64)) // y[i+48] - y[i+51] = zmm12 + vmovupd(mem(rdx, 13*64), zmm13) // zmm13 = x[i+52] - x[i+55] + vmovupd(zmm13, mem(r8, 13*64)) // y[i+52] - y[i+55] = zmm13 + vmovupd(mem(rdx, 14*64), zmm14) // zmm14 = x[i+56] - x[i+59] + vmovupd(zmm14, mem(r8, 14*64)) // y[i+56] - y[i+59] = zmm14 + vmovupd(mem(rdx, 15*64), zmm15) // zmm15 = x[i+60] - x[i+63] + vmovupd(zmm15, mem(r8, 15*64)) // y[i+60] - y[i+63] = zmm15 + + vmovupd(mem(rdx, 16*64), zmm16) // zmm16 = x[i+64] - x[i+67] + vmovupd(zmm16, mem(r8, 16*64)) // y[i+64] - y[i+67] = zmm16 + vmovupd(mem(rdx, 17*64), zmm17) // zmm17 = x[i+68] - x[i+71] + vmovupd(zmm17, mem(r8, 17*64)) // y[i+68] - y[i+71] = zmm17 + vmovupd(mem(rdx, 18*64), zmm18) // zmm18 = x[i+72] - x[i+75] + vmovupd(zmm18, mem(r8, 18*64)) // y[i+72] - y[i+75] = zmm18 + vmovupd(mem(rdx, 19*64), zmm19) // zmm19 = x[i+76] - x[i+79] + vmovupd(zmm19, mem(r8, 19*64)) // y[i+76] - y[i+79] = zmm19 + + vmovupd(mem(rdx, 20*64), zmm20) // zmm20 = x[i+80] - x[i+83] + vmovupd(zmm20, mem(r8, 20*64)) // y[i+80] - y[i+83] = zmm20 + vmovupd(mem(rdx, 21*64), zmm21) // zmm21 = x[i+84] - x[i+87] + vmovupd(zmm21, mem(r8, 21*64)) // y[i+84] - y[i+87] = zmm21 + vmovupd(mem(rdx, 22*64), zmm22) // zmm22 = x[i+88] - x[i+91] + vmovupd(zmm22, mem(r8, 22*64)) // y[i+88] - y[i+91] = zmm22 + vmovupd(mem(rdx, 23*64), zmm23) // zmm23 = x[i+92] - x[i+95] + vmovupd(zmm23, mem(r8, 23*64)) // y[i+92] - y[i+95] = zmm23 + + vmovupd(mem(rdx, 24*64), zmm24) // zmm24 = x[i+96] - x[i+99] + vmovupd(zmm24, mem(r8, 24*64)) // y[i+96] - y[i+99] = zmm24 + vmovupd(mem(rdx, 25*64), zmm25) // zmm25 = x[i+100] - x[i+103] + vmovupd(zmm25, mem(r8, 25*64)) // y[i+100] - y[i+103] = zmm25 + vmovupd(mem(rdx, 26*64), zmm26) // zmm26 = x[i+104] - x[i+107] + vmovupd(zmm26, mem(r8, 26*64)) // y[i+104] - y[i+107] = zmm26 + vmovupd(mem(rdx, 27*64), zmm27) // zmm27 = x[i+108] - x[i+111] + vmovupd(zmm27, mem(r8, 27*64)) // y[i+108] - y[i+111] = zmm27 + + vmovupd(mem(rdx, 28*64), zmm28) // zmm28 = x[i+112] - x[i+115] + vmovupd(zmm28, mem(r8, 28*64)) // y[i+112] - y[i+115] = zmm28 + vmovupd(mem(rdx, 29*64), zmm29) // zmm29 = x[i+116] - x[i+119] + vmovupd(zmm29, mem(r8, 29*64)) // y[i+116] - y[i+119] = zmm29 + vmovupd(mem(rdx, 30*64), zmm30) // zmm30 = x[i+120] - x[i+123] + vmovupd(zmm30, mem(r8, 30*64)) // y[i+120] - y[i+123] = zmm30 + vmovupd(mem(rdx, 31*64), zmm31) // zmm31 = x[i+124] - x[i+127] + vmovupd(zmm31, mem(r8, 31*64)) // y[i+124] - y[i+127] = zmm31 + + // Increment the pointer + add(imm(16*4*32), rdx) // ( Size of double datatype ) * ( Number of elements per register ) * ( Number of zmm registers used in the section of code ) + add(imm(16*4*32), r8) + + // reduce the number of remaining elements by 128 + sub(imm(4*32), rsi) // ( Number of elements per register ) * ( Number of zmm registers used in the section of code ) + + cmp(imm(4*32), rsi) + jge(.MAINLOOP) + + // ----------------------------------------------------------- + + // Section of code to move the data as blocks of 64 elements + label(.BLOCK64) + + cmp(imm(4*16), rsi) // check if the number of remaining elements greater than or equal to 64 + jl(.BLOCK32) // else, goto to the section of code for block of size 32 + + // Interleaved SIMD load and store operations to copy data from source to the destination + + vmovupd(mem(rdx, 0*64), zmm0) // zmm0 = x[i+0] - x[i+3] + vmovupd(zmm0, mem(r8, 0*64)) // y[i+0] - y[i+3] = zmm0 + vmovupd(mem(rdx, 1*64), zmm1) // zmm1 = x[i+4] - x[i+7] + vmovupd(zmm1, mem(r8, 1*64)) // y[i+4] - y[i+7] = zmm1 + vmovupd(mem(rdx, 2*64), zmm2) // zmm2 = x[i+8] - x[i+11] + vmovupd(zmm2, mem(r8, 2*64)) // y[i+8] - y[i+11] = zmm2 + vmovupd(mem(rdx, 3*64), zmm3) // zmm3 = x[i+12] - x[i+15] + vmovupd(zmm3, mem(r8, 3*64)) // y[i+12] - y[i+15] = zmm3 + + vmovupd(mem(rdx, 4*64), zmm4) // zmm4 = x[i+16] - x[i+19] + vmovupd(zmm4, mem(r8, 4*64)) // y[i+16] - y[i+19] = zmm4 + vmovupd(mem(rdx, 5*64), zmm5) // zmm5 = x[i+20] - x[i+23] + vmovupd(zmm5, mem(r8, 5*64)) // y[i+20] - y[i+23] = zmm5 + vmovupd(mem(rdx, 6*64), zmm6) // zmm6 = x[i+24] - x[i+27] + vmovupd(zmm6, mem(r8, 6*64)) // y[i+24] - y[i+27] = zmm6 + vmovupd(mem(rdx, 7*64), zmm7) // zmm7 = x[i+28] - x[i+31] + vmovupd(zmm7, mem(r8, 7*64)) // y[i+28] - y[i+31] = zmm7 + + vmovupd(mem(rdx, 8*64), zmm8) // zmm8 = x[i+32] - x[i+35] + vmovupd(zmm8, mem(r8, 8*64)) // y[i+32] - y[i+35] = zmm8 + vmovupd(mem(rdx, 9*64), zmm9) // zmm9 = x[i+36] - x[i+39] + vmovupd(zmm9, mem(r8, 9*64)) // y[i+36] - y[i+39] = zmm9 + vmovupd(mem(rdx, 10*64), zmm10) // zmm10 = x[i+40] - x[i+43] + vmovupd(zmm10, mem(r8, 10*64)) // y[i+40] - y[i+43] = zmm10 + vmovupd(mem(rdx, 11*64), zmm11) // zmm11 = x[i+44] - x[i+47] + vmovupd(zmm11, mem(r8, 11*64)) // y[i+44] - y[i+47] = zmm11 + + vmovupd(mem(rdx, 12*64), zmm12) // zmm12 = x[i+48] - x[i+51] + vmovupd(zmm12, mem(r8, 12*64)) // y[i+48] - y[i+51] = zmm12 + vmovupd(mem(rdx, 13*64), zmm13) // zmm13 = x[i+52] - x[i+55] + vmovupd(zmm13, mem(r8, 13*64)) // y[i+52] - y[i+55] = zmm13 + vmovupd(mem(rdx, 14*64), zmm14) // zmm14 = x[i+56] - x[i+59] + vmovupd(zmm14, mem(r8, 14*64)) // y[i+56] - y[i+59] = zmm14 + vmovupd(mem(rdx, 15*64), zmm15) // zmm15 = x[i+60] - x[i+63] + vmovupd(zmm15, mem(r8, 15*64)) // y[i+60] - y[i+63] = zmm15 + + // Increment the pointer + add(imm(16*4*16), rdx) + add(imm(16*4*16), r8) + + // reduce the number of remaining elements by 64 + sub(imm(4*16), rsi) + + // ----------------------------------------------------------- + + // Section of code to move the data as blocks of 32 elements + label(.BLOCK32) + + cmp(imm(4*8), rsi) // check if the number of remaining elements greater than or equal to 32 + jl(.BLOCK16) // else, goto to the section of code for block of size 16 + + // Interleaved SIMD load and store operations to copy data from source to the destination + + vmovupd(mem(rdx, 0*64), zmm0) // zmm0 = x[i+0] - x[i+3] + vmovupd(zmm0, mem(r8, 0*64)) // y[i+0] - y[i+3] = zmm0 + vmovupd(mem(rdx, 1*64), zmm1) // zmm1 = x[i+4] - x[i+7] + vmovupd(zmm1, mem(r8, 1*64)) // y[i+4] - y[i+7] = zmm1 + vmovupd(mem(rdx, 2*64), zmm2) // zmm2 = x[i+8] - x[i+11] + vmovupd(zmm2, mem(r8, 2*64)) // y[i+8] - y[i+11] = zmm2 + vmovupd(mem(rdx, 3*64), zmm3) // zmm3 = x[i+12] - x[i+15] + vmovupd(zmm3, mem(r8, 3*64)) // y[i+12] - y[i+15] = zmm3 + + vmovupd(mem(rdx, 4*64), zmm4) // zmm4 = x[i+16] - x[i+19] + vmovupd(zmm4, mem(r8, 4*64)) // y[i+16] - y[i+19] = zmm4 + vmovupd(mem(rdx, 5*64), zmm5) // zmm5 = x[i+20] - x[i+23] + vmovupd(zmm5, mem(r8, 5*64)) // y[i+20] - y[i+23] = zmm5 + vmovupd(mem(rdx, 6*64), zmm6) // zmm6 = x[i+24] - x[i+27] + vmovupd(zmm6, mem(r8, 6*64)) // y[i+24] - y[i+27] = zmm6 + vmovupd(mem(rdx, 7*64), zmm7) // zmm7 = x[i+28] - x[i+31] + vmovupd(zmm7, mem(r8, 7*64)) // y[i+28] - y[i+31] = zmm7 + + // Increment the pointer + add(imm(16*4*8), rdx) + add(imm(16*4*8), r8) + + // reduce the number of remaining elements by 32 + sub(imm(4*8), rsi) + + // ----------------------------------------------------------- + + // Section of code to move the data as blocks of 16 elements + label(.BLOCK16) + + cmp(imm(4*4), rsi) // check if the number of remaining elements greater than or equal to 16 + jl(.BLOCK8) // else, goto to the section of code for block of size 8 + + // Interleaved SIMD load and store operations to copy data from source to the destination + + vmovupd(mem(rdx, 0*64), zmm0) // zmm0 = x[i+0] - x[i+3] + vmovupd(zmm0, mem(r8, 0*64)) // y[i+0] - y[i+3] = zmm0 + vmovupd(mem(rdx, 1*64), zmm1) // zmm1 = x[i+4] - x[i+7] + vmovupd(zmm1, mem(r8, 1*64)) // y[i+4] - y[i+7] = zmm1 + vmovupd(mem(rdx, 2*64), zmm2) // zmm2 = x[i+8] - x[i+11] + vmovupd(zmm2, mem(r8, 2*64)) // y[i+8] - y[i+11] = zmm2 + vmovupd(mem(rdx, 3*64), zmm3) // zmm3 = x[i+12] - x[i+15] + vmovupd(zmm3, mem(r8, 3*64)) // y[i+12] - y[i+15] = zmm3 + + // Increment the pointer + add(imm(16*4*4), rdx) + add(imm(16*4*4), r8) + + // reduce the number of remaining elements by 16 + sub(imm(4*4), rsi) + + // ----------------------------------------------------------- + + // Section of code to move the data as blocks of 8 elements + label(.BLOCK8) + + cmp(imm(4*2), rsi) // check if the number of remaining elements greater than or equal to 8 + jl(.BLOCK4) // else, goto to the section of code for block of size 4 + + // Interleaved SIMD load and store operations to copy data from source to the destination + + vmovupd(mem(rdx, 0*64), zmm0) // zmm0 = x[i+0] - x[i+3] + vmovupd(zmm0, mem(r8, 0*64)) // y[i+0] - y[i+3] = zmm0 + vmovupd(mem(rdx, 1*64), zmm1) // zmm1 = x[i+4] - x[i+7] + vmovupd(zmm1, mem(r8, 1*64)) // y[i+4] - y[i+7] = zmm1 + + // Increment the pointer + add(imm(16*4*2), rdx) + add(imm(16*4*2), r8) + + // reduce the number of remaining elements by 8 + sub(imm(4*2), rsi) + + // ----------------------------------------------------------- + + // Section of code to move the data as blocks of 4 elements + label(.BLOCK4) + + cmp(imm(4), rsi) // check if the number of remaining elements greater than or equal to 4 + jl(.FRINGE) // else, goto to the section of code that deals with fringe cases + + // Loading and storing the values to destination + + vmovupd(mem(rdx, 0*64), zmm0) // zmm0 = x[i+0] - x[i+3] + vmovupd(zmm0, mem(r8, 0*64)) // y[i+0] - y[i+3] = zmm0 + + // Increment the pointer + add(imm(16*4), rdx) + add(imm(16*4), r8) + + // reduce the number of remaining elements by 4 + sub(imm(4), rsi) + + // ----------------------------------------------------------- + + // Section of code to deal with fringe cases + label(.FRINGE) + + cmp(imm(0), rsi) // check if there is any fringe cases + je(.END) + + // Creating a 8-bit mask + mov(imm(255), rcx) // (255)10 -> (1111 1111)2 + shlx(rsi,rcx,rcx) // shifting the bits in the register to the left depending on the number of fringe elements remaining + shlx(rsi,rcx,rcx) + xor(imm(255),rcx) // taking compliment of the register + kmovq(rcx, k(2)) // copying the value in the register to mask register + + /* + Creating mask: Example - fringe case = 1 + step 1 : rcx = (1111 1111)2 or (255)10 + step 2 : rcx = (1111 1110)2 or (254)10 + step 3 : rcx = (1111 1100)2 or (252)10 + step 4 : rcx = (0000 0011)2 or (3)10 + */ + // Loading the input values using masked load + vmovupd(mem(rdx, 0*64), zmm0 MASK_(K(2))) + + // Storing the values to destination using masked store + vmovupd(zmm0, mem(r8) MASK_(K(2))) + + // Increment the pointer + add(rsi, rdx) + add(rsi, r8) + and(imm(0), rsi) + + label(.END) + end_asm( + : + : [n0] "m" (n0), + [x0] "m" (x0), + [y0] "m" (y0) + + : "zmm0", "zmm1", "zmm2", "zmm3", + "zmm4", "zmm5", "zmm6", "zmm7", + "zmm8", "zmm9", "zmm10", "zmm11", + "zmm12", "zmm13", "zmm14", "zmm15", + "zmm16", "zmm17", "zmm18", "zmm19", + "zmm20", "zmm21", "zmm22", "zmm23", + "zmm24", "zmm25", "zmm26", "zmm27", + "zmm28", "zmm29", "zmm30", "zmm31", + "rsi", "rdx", "rcx", "r8", + "r9", "k2", "memory" + ) + } + else + { + // Since double complex elements are of size 128 bits, + // vectorization can be done using XMM registers when incx and incy are not 1. + // This is done in the else condition. + __m128d xv[32]; + dim_t i = 0; + + // n & (~0x1F) = n & 0xFFFFFFE0 -> this masks the numbers less than 32, + // if value of n < 32, then (n & (~0x1F)) = 0 + // the copy operation will be done for the multiples of 32 + for ( i = 0; i < (n & (~0x1F)); i += 32) + { + // Loading the input values + xv[0] = _mm_loadu_pd((double *)(x0 + 0 * incx)); + xv[1] = _mm_loadu_pd((double *)(x0 + 1 * incx)); + xv[2] = _mm_loadu_pd((double *)(x0 + 2 * incx)); + xv[3] = _mm_loadu_pd((double *)(x0 + 3 * incx)); + + xv[4] = _mm_loadu_pd((double *)(x0 + 4 * incx)); + xv[5] = _mm_loadu_pd((double *)(x0 + 5 * incx)); + xv[6] = _mm_loadu_pd((double *)(x0 + 6 * incx)); + xv[7] = _mm_loadu_pd((double *)(x0 + 7 * incx)); + + xv[8] = _mm_loadu_pd((double *)(x0 + 8 * incx)); + xv[9] = _mm_loadu_pd((double *)(x0 + 9 * incx)); + xv[10] = _mm_loadu_pd((double *)(x0 + 10 * incx)); + xv[11] = _mm_loadu_pd((double *)(x0 + 11 * incx)); + + xv[12] = _mm_loadu_pd((double *)(x0 + 12 * incx)); + xv[13] = _mm_loadu_pd((double *)(x0 + 13 * incx)); + xv[14] = _mm_loadu_pd((double *)(x0 + 14 * incx)); + xv[15] = _mm_loadu_pd((double *)(x0 + 15 * incx)); + + xv[16] = _mm_loadu_pd((double *)(x0 + 16 * incx)); + xv[17] = _mm_loadu_pd((double *)(x0 + 17 * incx)); + xv[18] = _mm_loadu_pd((double *)(x0 + 18 * incx)); + xv[19] = _mm_loadu_pd((double *)(x0 + 19 * incx)); + + xv[20] = _mm_loadu_pd((double *)(x0 + 20 * incx)); + xv[21] = _mm_loadu_pd((double *)(x0 + 21 * incx)); + xv[22] = _mm_loadu_pd((double *)(x0 + 22 * incx)); + xv[23] = _mm_loadu_pd((double *)(x0 + 23 * incx)); + + xv[24] = _mm_loadu_pd((double *)(x0 + 24 * incx)); + xv[25] = _mm_loadu_pd((double *)(x0 + 25 * incx)); + xv[26] = _mm_loadu_pd((double *)(x0 + 26 * incx)); + xv[27] = _mm_loadu_pd((double *)(x0 + 27 * incx)); + + xv[28] = _mm_loadu_pd((double *)(x0 + 28 * incx)); + xv[29] = _mm_loadu_pd((double *)(x0 + 29 * incx)); + xv[30] = _mm_loadu_pd((double *)(x0 + 30 * incx)); + xv[31] = _mm_loadu_pd((double *)(x0 + 31 * incx)); + + // Storing the values to destination + _mm_storeu_pd((double *)(y0 + incy * 0), xv[0]); + _mm_storeu_pd((double *)(y0 + incy * 1), xv[1]); + _mm_storeu_pd((double *)(y0 + incy * 2), xv[2]); + _mm_storeu_pd((double *)(y0 + incy * 3), xv[3]); + + _mm_storeu_pd((double *)(y0 + incy * 4), xv[4]); + _mm_storeu_pd((double *)(y0 + incy * 5), xv[5]); + _mm_storeu_pd((double *)(y0 + incy * 6), xv[6]); + _mm_storeu_pd((double *)(y0 + incy * 7), xv[7]); + + _mm_storeu_pd((double *)(y0 + incy * 8), xv[8]); + _mm_storeu_pd((double *)(y0 + incy * 9), xv[9]); + _mm_storeu_pd((double *)(y0 + incy * 10), xv[10]); + _mm_storeu_pd((double *)(y0 + incy * 11), xv[11]); + + _mm_storeu_pd((double *)(y0 + incy * 12), xv[12]); + _mm_storeu_pd((double *)(y0 + incy * 13), xv[13]); + _mm_storeu_pd((double *)(y0 + incy * 14), xv[14]); + _mm_storeu_pd((double *)(y0 + incy * 15), xv[15]); + + _mm_storeu_pd((double *)(y0 + incy * 16), xv[16]); + _mm_storeu_pd((double *)(y0 + incy * 17), xv[17]); + _mm_storeu_pd((double *)(y0 + incy * 18), xv[18]); + _mm_storeu_pd((double *)(y0 + incy * 19), xv[19]); + + _mm_storeu_pd((double *)(y0 + incy * 20), xv[20]); + _mm_storeu_pd((double *)(y0 + incy * 21), xv[21]); + _mm_storeu_pd((double *)(y0 + incy * 22), xv[22]); + _mm_storeu_pd((double *)(y0 + incy * 23), xv[23]); + + _mm_storeu_pd((double *)(y0 + incy * 24), xv[24]); + _mm_storeu_pd((double *)(y0 + incy * 25), xv[25]); + _mm_storeu_pd((double *)(y0 + incy * 26), xv[26]); + _mm_storeu_pd((double *)(y0 + incy * 27), xv[27]); + + _mm_storeu_pd((double *)(y0 + incy * 28), xv[28]); + _mm_storeu_pd((double *)(y0 + incy * 29), xv[29]); + _mm_storeu_pd((double *)(y0 + incy * 30), xv[30]); + _mm_storeu_pd((double *)(y0 + incy * 31), xv[31]); + + // Increment the pointer + x0 += 32 * incx; + y0 += 32 * incy; + } + + for ( ; i < (n & (~0x0F)); i += 16) + { + // Loading the input values + xv[0] = _mm_loadu_pd((double *)(x0 + 0 * incx)); + xv[1] = _mm_loadu_pd((double *)(x0 + 1 * incx)); + xv[2] = _mm_loadu_pd((double *)(x0 + 2 * incx)); + xv[3] = _mm_loadu_pd((double *)(x0 + 3 * incx)); + + xv[4] = _mm_loadu_pd((double *)(x0 + 4 * incx)); + xv[5] = _mm_loadu_pd((double *)(x0 + 5 * incx)); + xv[6] = _mm_loadu_pd((double *)(x0 + 6 * incx)); + xv[7] = _mm_loadu_pd((double *)(x0 + 7 * incx)); + + xv[8] = _mm_loadu_pd((double *)(x0 + 8 * incx)); + xv[9] = _mm_loadu_pd((double *)(x0 + 9 * incx)); + xv[10] = _mm_loadu_pd((double *)(x0 + 10 * incx)); + xv[11] = _mm_loadu_pd((double *)(x0 + 11 * incx)); + + xv[12] = _mm_loadu_pd((double *)(x0 + 12 * incx)); + xv[13] = _mm_loadu_pd((double *)(x0 + 13 * incx)); + xv[14] = _mm_loadu_pd((double *)(x0 + 14 * incx)); + xv[15] = _mm_loadu_pd((double *)(x0 + 15 * incx)); + + // Storing the values to destination + _mm_storeu_pd((double *)(y0 + incy * 0), xv[0]); + _mm_storeu_pd((double *)(y0 + incy * 1), xv[1]); + _mm_storeu_pd((double *)(y0 + incy * 2), xv[2]); + _mm_storeu_pd((double *)(y0 + incy * 3), xv[3]); + + _mm_storeu_pd((double *)(y0 + incy * 4), xv[4]); + _mm_storeu_pd((double *)(y0 + incy * 5), xv[5]); + _mm_storeu_pd((double *)(y0 + incy * 6), xv[6]); + _mm_storeu_pd((double *)(y0 + incy * 7), xv[7]); + + _mm_storeu_pd((double *)(y0 + incy * 8), xv[8]); + _mm_storeu_pd((double *)(y0 + incy * 9), xv[9]); + _mm_storeu_pd((double *)(y0 + incy * 10), xv[10]); + _mm_storeu_pd((double *)(y0 + incy * 11), xv[11]); + + _mm_storeu_pd((double *)(y0 + incy * 12), xv[12]); + _mm_storeu_pd((double *)(y0 + incy * 13), xv[13]); + _mm_storeu_pd((double *)(y0 + incy * 14), xv[14]); + _mm_storeu_pd((double *)(y0 + incy * 15), xv[15]); + + // Increment the pointer + x0 += 16 * incx; + y0 += 16 * incy; + } + + for ( ; i < (n & (~0x07)); i += 8) + { + // Loading the input values + xv[0] = _mm_loadu_pd((double *)(x0 + 0 * incx)); + xv[1] = _mm_loadu_pd((double *)(x0 + 1 * incx)); + xv[2] = _mm_loadu_pd((double *)(x0 + 2 * incx)); + xv[3] = _mm_loadu_pd((double *)(x0 + 3 * incx)); + + xv[4] = _mm_loadu_pd((double *)(x0 + 4 * incx)); + xv[5] = _mm_loadu_pd((double *)(x0 + 5 * incx)); + xv[6] = _mm_loadu_pd((double *)(x0 + 6 * incx)); + xv[7] = _mm_loadu_pd((double *)(x0 + 7 * incx)); + + // Storing the values to destination + _mm_storeu_pd((double *)(y0 + incy * 0), xv[0]); + _mm_storeu_pd((double *)(y0 + incy * 1), xv[1]); + _mm_storeu_pd((double *)(y0 + incy * 2), xv[2]); + _mm_storeu_pd((double *)(y0 + incy * 3), xv[3]); + + _mm_storeu_pd((double *)(y0 + incy * 4), xv[4]); + _mm_storeu_pd((double *)(y0 + incy * 5), xv[5]); + _mm_storeu_pd((double *)(y0 + incy * 6), xv[6]); + _mm_storeu_pd((double *)(y0 + incy * 7), xv[7]); + + // Increment the pointer + x0 += 8 * incx; + y0 += 8 * incy; + } + + for ( ; i < (n & (~0x03)); i += 4) + { + // Loading the input values + xv[0] = _mm_loadu_pd((double *)(x0 + 0 * incx)); + xv[1] = _mm_loadu_pd((double *)(x0 + 1 * incx)); + xv[2] = _mm_loadu_pd((double *)(x0 + 2 * incx)); + xv[3] = _mm_loadu_pd((double *)(x0 + 3 * incx)); + + // Storing the values to destination + _mm_storeu_pd((double *)(y0 + incy * 0), xv[0]); + _mm_storeu_pd((double *)(y0 + incy * 1), xv[1]); + _mm_storeu_pd((double *)(y0 + incy * 2), xv[2]); + _mm_storeu_pd((double *)(y0 + incy * 3), xv[3]); + + // Increment the pointer + x0 += 4 * incx; + y0 += 4 * incy; + } + + for ( ; i < (n & (~0x01)); i += 2) + { + // Loading the input values + xv[0] = _mm_loadu_pd((double *)(x0 + 0 * incx)); + xv[1] = _mm_loadu_pd((double *)(x0 + 1 * incx)); + + // Storing the values to destination + _mm_storeu_pd((double *)(y0 + incy * 0), xv[0]); + _mm_storeu_pd((double *)(y0 + incy * 1), xv[1]); + + // Increment the pointer + x0 += 2 * incx; + y0 += 2 * incy; + } + + for ( ; i < n; i += 1) + { + // Loading the input values + xv[0] = _mm_loadu_pd((double *)(x0 + 0 * incx)); + + // Storing the values to destination + _mm_storeu_pd((double *)(y0 + incy * 0), xv[0]); + + // Increment the pointer + x0 += 1 * incx; + y0 += 1 * incy; + } + } + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2) +} diff --git a/kernels/zen4/1/bli_copyv_zen_int_avx512.c b/kernels/zen4/1/bli_copyv_zen_int_avx512.c new file mode 100644 index 0000000000..6aed74cd1b --- /dev/null +++ b/kernels/zen4/1/bli_copyv_zen_int_avx512.c @@ -0,0 +1,1578 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "immintrin.h" +#include "blis.h" + +// -------------------------------------------------------------------------------------- + +/* + Functionality + ------------- + + This function copies a vector x to a vector y for + type float. + + y := conj?(x) + + Function Signature + ------------------- + + * 'conjx' - Variable specified if x needs to be conjugated + * 'n' - Length of the array passed + * 'x' - Float pointer pointing to an array + * 'y' - Float pointer pointing to an array + * 'incx' - Stride to point to the next element in x array + * 'incy' - Stride to point to the next element in y array + * 'cntx' - BLIS context object + + Exception + ---------- + + None + + Deviation from BLAS + -------------------- + + None + + Undefined behaviour + ------------------- + + 1. The kernel results in undefined behaviour when n < 0, incx < 1 and incy < 1. + The expectation is that these are standard BLAS exceptions and should be handled in + a higher layer +*/ + +void bli_scopyv_zen_int_avx512 +( + conj_t conjx, + dim_t n, + float* restrict x, inc_t incx, + float* restrict y, inc_t incy, + cntx_t* restrict cntx +) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_2) + dim_t i = 0; + + // Initialize local pointers. + float *restrict x0 = x; + float *restrict y0 = y; + + // If the vector dimension is zero return early. + if (bli_zero_dim1(n)) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2) + return; + } + + if (incx == 1 && incy == 1) + { + const dim_t num_elem_per_reg = 16; + __m512 xv[32]; + + // n & (~0x1FF) = n & 0xFFFFFE00 -> this masks the numbers less than 512, + // if value of n < 512, then (n & (~0xFF)) = 0 + // the copy operation will be done for the multiples of 512 + for (i = 0; i < (n & (~0x1FF)); i += 512) + { + // Loading the input values + xv[0] = _mm512_loadu_ps(x0 + num_elem_per_reg * 0); + xv[1] = _mm512_loadu_ps(x0 + num_elem_per_reg * 1); + xv[2] = _mm512_loadu_ps(x0 + num_elem_per_reg * 2); + xv[3] = _mm512_loadu_ps(x0 + num_elem_per_reg * 3); + + xv[4] = _mm512_loadu_ps(x0 + num_elem_per_reg * 4); + xv[5] = _mm512_loadu_ps(x0 + num_elem_per_reg * 5); + xv[6] = _mm512_loadu_ps(x0 + num_elem_per_reg * 6); + xv[7] = _mm512_loadu_ps(x0 + num_elem_per_reg * 7); + + xv[8] = _mm512_loadu_ps(x0 + num_elem_per_reg * 8); + xv[9] = _mm512_loadu_ps(x0 + num_elem_per_reg * 9); + xv[10] = _mm512_loadu_ps(x0 + num_elem_per_reg * 10); + xv[11] = _mm512_loadu_ps(x0 + num_elem_per_reg * 11); + + xv[12] = _mm512_loadu_ps(x0 + num_elem_per_reg * 12); + xv[13] = _mm512_loadu_ps(x0 + num_elem_per_reg * 13); + xv[14] = _mm512_loadu_ps(x0 + num_elem_per_reg * 14); + xv[15] = _mm512_loadu_ps(x0 + num_elem_per_reg * 15); + + xv[16] = _mm512_loadu_ps(x0 + num_elem_per_reg * 16); + xv[17] = _mm512_loadu_ps(x0 + num_elem_per_reg * 17); + xv[18] = _mm512_loadu_ps(x0 + num_elem_per_reg * 18); + xv[19] = _mm512_loadu_ps(x0 + num_elem_per_reg * 19); + + xv[20] = _mm512_loadu_ps(x0 + num_elem_per_reg * 20); + xv[21] = _mm512_loadu_ps(x0 + num_elem_per_reg * 21); + xv[22] = _mm512_loadu_ps(x0 + num_elem_per_reg * 22); + xv[23] = _mm512_loadu_ps(x0 + num_elem_per_reg * 23); + + xv[24] = _mm512_loadu_ps(x0 + num_elem_per_reg * 24); + xv[25] = _mm512_loadu_ps(x0 + num_elem_per_reg * 25); + xv[26] = _mm512_loadu_ps(x0 + num_elem_per_reg * 26); + xv[27] = _mm512_loadu_ps(x0 + num_elem_per_reg * 27); + + xv[28] = _mm512_loadu_ps(x0 + num_elem_per_reg * 28); + xv[29] = _mm512_loadu_ps(x0 + num_elem_per_reg * 29); + xv[30] = _mm512_loadu_ps(x0 + num_elem_per_reg * 30); + xv[31] = _mm512_loadu_ps(x0 + num_elem_per_reg * 31); + + // Storing the values to destination + _mm512_storeu_ps(y0 + num_elem_per_reg * 0, xv[0]); + _mm512_storeu_ps(y0 + num_elem_per_reg * 1, xv[1]); + _mm512_storeu_ps(y0 + num_elem_per_reg * 2, xv[2]); + _mm512_storeu_ps(y0 + num_elem_per_reg * 3, xv[3]); + + _mm512_storeu_ps(y0 + num_elem_per_reg * 4, xv[4]); + _mm512_storeu_ps(y0 + num_elem_per_reg * 5, xv[5]); + _mm512_storeu_ps(y0 + num_elem_per_reg * 6, xv[6]); + _mm512_storeu_ps(y0 + num_elem_per_reg * 7, xv[7]); + + _mm512_storeu_ps(y0 + num_elem_per_reg * 8, xv[8]); + _mm512_storeu_ps(y0 + num_elem_per_reg * 9 , xv[9]); + _mm512_storeu_ps(y0 + num_elem_per_reg * 10, xv[10]); + _mm512_storeu_ps(y0 + num_elem_per_reg * 11, xv[11]); + + _mm512_storeu_ps(y0 + num_elem_per_reg * 12, xv[12]); + _mm512_storeu_ps(y0 + num_elem_per_reg * 13, xv[13]); + _mm512_storeu_ps(y0 + num_elem_per_reg * 14, xv[14]); + _mm512_storeu_ps(y0 + num_elem_per_reg * 15, xv[15]); + + _mm512_storeu_ps(y0 + num_elem_per_reg * 16, xv[16]); + _mm512_storeu_ps(y0 + num_elem_per_reg * 17, xv[17]); + _mm512_storeu_ps(y0 + num_elem_per_reg * 18, xv[18]); + _mm512_storeu_ps(y0 + num_elem_per_reg * 19, xv[19]); + + _mm512_storeu_ps(y0 + num_elem_per_reg * 20, xv[20]); + _mm512_storeu_ps(y0 + num_elem_per_reg * 21, xv[21]); + _mm512_storeu_ps(y0 + num_elem_per_reg * 22, xv[22]); + _mm512_storeu_ps(y0 + num_elem_per_reg * 23, xv[23]); + + _mm512_storeu_ps(y0 + num_elem_per_reg * 24, xv[24]); + _mm512_storeu_ps(y0 + num_elem_per_reg * 25, xv[25]); + _mm512_storeu_ps(y0 + num_elem_per_reg * 26, xv[26]); + _mm512_storeu_ps(y0 + num_elem_per_reg * 27, xv[27]); + + _mm512_storeu_ps(y0 + num_elem_per_reg * 28, xv[28]); + _mm512_storeu_ps(y0 + num_elem_per_reg * 29, xv[29]); + _mm512_storeu_ps(y0 + num_elem_per_reg * 30, xv[30]); + _mm512_storeu_ps(y0 + num_elem_per_reg * 31, xv[31]); + + // Increment the pointer + x0 += 32 * num_elem_per_reg; + y0 += 32 * num_elem_per_reg; + } + + for (; i < (n & (~0xFF)); i += 256) + { + // Loading the input values + xv[0] = _mm512_loadu_ps(x0 + num_elem_per_reg * 0); + xv[1] = _mm512_loadu_ps(x0 + num_elem_per_reg * 1); + xv[2] = _mm512_loadu_ps(x0 + num_elem_per_reg * 2); + xv[3] = _mm512_loadu_ps(x0 + num_elem_per_reg * 3); + + xv[4] = _mm512_loadu_ps(x0 + num_elem_per_reg * 4); + xv[5] = _mm512_loadu_ps(x0 + num_elem_per_reg * 5); + xv[6] = _mm512_loadu_ps(x0 + num_elem_per_reg * 6); + xv[7] = _mm512_loadu_ps(x0 + num_elem_per_reg * 7); + + xv[8] = _mm512_loadu_ps(x0 + num_elem_per_reg * 8); + xv[9] = _mm512_loadu_ps(x0 + num_elem_per_reg * 9); + xv[10] = _mm512_loadu_ps(x0 + num_elem_per_reg * 10); + xv[11] = _mm512_loadu_ps(x0 + num_elem_per_reg * 11); + + xv[12] = _mm512_loadu_ps(x0 + num_elem_per_reg * 12); + xv[13] = _mm512_loadu_ps(x0 + num_elem_per_reg * 13); + xv[14] = _mm512_loadu_ps(x0 + num_elem_per_reg * 14); + xv[15] = _mm512_loadu_ps(x0 + num_elem_per_reg * 15); + + // Storing the values to destination + _mm512_storeu_ps(y0 + num_elem_per_reg * 0, xv[0]); + _mm512_storeu_ps(y0 + num_elem_per_reg * 1, xv[1]); + _mm512_storeu_ps(y0 + num_elem_per_reg * 2, xv[2]); + _mm512_storeu_ps(y0 + num_elem_per_reg * 3, xv[3]); + + _mm512_storeu_ps(y0 + num_elem_per_reg * 4, xv[4]); + _mm512_storeu_ps(y0 + num_elem_per_reg * 5, xv[5]); + _mm512_storeu_ps(y0 + num_elem_per_reg * 6, xv[6]); + _mm512_storeu_ps(y0 + num_elem_per_reg * 7, xv[7]); + + _mm512_storeu_ps(y0 + num_elem_per_reg * 8, xv[8]); + _mm512_storeu_ps(y0 + num_elem_per_reg * 9 , xv[9]); + _mm512_storeu_ps(y0 + num_elem_per_reg * 10, xv[10]); + _mm512_storeu_ps(y0 + num_elem_per_reg * 11, xv[11]); + + _mm512_storeu_ps(y0 + num_elem_per_reg * 12, xv[12]); + _mm512_storeu_ps(y0 + num_elem_per_reg * 13, xv[13]); + _mm512_storeu_ps(y0 + num_elem_per_reg * 14, xv[14]); + _mm512_storeu_ps(y0 + num_elem_per_reg * 15, xv[15]); + + // Increment the pointer + x0 += 16 * num_elem_per_reg; + y0 += 16 * num_elem_per_reg; + } + + for (; i < (n & (~0x7F)); i += 128) + { + // Loading the input values + xv[0] = _mm512_loadu_ps(x0 + num_elem_per_reg * 0); + xv[1] = _mm512_loadu_ps(x0 + num_elem_per_reg * 1); + xv[2] = _mm512_loadu_ps(x0 + num_elem_per_reg * 2); + xv[3] = _mm512_loadu_ps(x0 + num_elem_per_reg * 3); + + xv[4] = _mm512_loadu_ps(x0 + num_elem_per_reg * 4); + xv[5] = _mm512_loadu_ps(x0 + num_elem_per_reg * 5); + xv[6] = _mm512_loadu_ps(x0 + num_elem_per_reg * 6); + xv[7] = _mm512_loadu_ps(x0 + num_elem_per_reg * 7); + + // Storing the values to destination + _mm512_storeu_ps(y0 + num_elem_per_reg * 0, xv[0]); + _mm512_storeu_ps(y0 + num_elem_per_reg * 1, xv[1]); + _mm512_storeu_ps(y0 + num_elem_per_reg * 2, xv[2]); + _mm512_storeu_ps(y0 + num_elem_per_reg * 3, xv[3]); + + _mm512_storeu_ps(y0 + num_elem_per_reg * 4, xv[4]); + _mm512_storeu_ps(y0 + num_elem_per_reg * 5, xv[5]); + _mm512_storeu_ps(y0 + num_elem_per_reg * 6, xv[6]); + _mm512_storeu_ps(y0 + num_elem_per_reg * 7, xv[7]); + + // Increment the pointer + x0 += 8 * num_elem_per_reg; + y0 += 8 * num_elem_per_reg; + } + + for (; i < (n & (~0x3F)); i += 64) + { + // Loading the input values + xv[0] = _mm512_loadu_ps(x0 + num_elem_per_reg * 0); + xv[1] = _mm512_loadu_ps(x0 + num_elem_per_reg * 1); + xv[2] = _mm512_loadu_ps(x0 + num_elem_per_reg * 2); + xv[3] = _mm512_loadu_ps(x0 + num_elem_per_reg * 3); + + // Storing the values to destination + _mm512_storeu_ps(y0 + num_elem_per_reg * 0, xv[0]); + _mm512_storeu_ps(y0 + num_elem_per_reg * 1, xv[1]); + _mm512_storeu_ps(y0 + num_elem_per_reg * 2, xv[2]); + _mm512_storeu_ps(y0 + num_elem_per_reg * 3, xv[3]); + + // Increment the pointer + x0 += 4 * num_elem_per_reg; + y0 += 4 * num_elem_per_reg; + } + + for (; i < (n & (~0x1F)); i += 32) + { + // Loading the input values + xv[0] = _mm512_loadu_ps(x0 + num_elem_per_reg * 0); + xv[1] = _mm512_loadu_ps(x0 + num_elem_per_reg * 1); + + // Storing the values to destination + _mm512_storeu_ps(y0 + num_elem_per_reg * 0, xv[0]); + _mm512_storeu_ps(y0 + num_elem_per_reg * 1, xv[1]); + + // Increment the pointer + x0 += 2 * num_elem_per_reg; + y0 += 2 * num_elem_per_reg; + } + + for (; i < (n & (~0x0F)); i += 16) + { + // Loading the input values + xv[0] = _mm512_loadu_ps(x0 + num_elem_per_reg * 0); + + // Storing the values to destination + _mm512_storeu_ps(y0 + num_elem_per_reg * 0, xv[0]); + + // Increment the pointer + x0 += num_elem_per_reg; + y0 += num_elem_per_reg; + } + + if ( i < n ) + { + xv[1] = _mm512_setzero_ps(); + + // Creating the mask + __mmask16 mask = (1 << (n-i)) - 1; + + // Loading the input values + xv[0] = _mm512_mask_loadu_ps(xv[1], mask, x0 + num_elem_per_reg * 0); + + // Storing the values to destination + _mm512_mask_storeu_ps(y0 + num_elem_per_reg * 0, mask, xv[0]); + + } + } + else + { + for ( i = 0; i < n; ++i) + { + *y0 = *x0; + + x0 += incx; + y0 += incy; + } + } +} + + +// -------------------------------------------------------------------------------------- + +/* + Functionality + ------------- + + This function copies a vector x to a vector y for + type double. + + y := conj?(x) + + Function Signature + ------------------- + + * 'conjx' - Variable specified if x needs to be conjugated + * 'n' - Length of the array passed + * 'x' - Double pointer pointing to an array + * 'y' - Double pointer pointing to an array + * 'incx' - Stride to point to the next element in x array + * 'incy' - Stride to point to the next element in y array + * 'cntx' - BLIS context object + + Exception + ---------- + + None + + Deviation from BLAS + -------------------- + + None + + Undefined behaviour + ------------------- + + 1. The kernel results in undefined behaviour when n < 0, incx < 1 and incy < 1. + The expectation is that these are standard BLAS exceptions and should be handled in + a higher layer +*/ + +void bli_dcopyv_zen_int_avx512 +( + conj_t conjx, + dim_t n, + double* restrict x, inc_t incx, + double* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_2) + dim_t i = 0; + + // Initialize local pointers. + double *restrict x0 = x; + double *restrict y0 = y; + + // If the vector dimension is zero return early. + if (bli_zero_dim1(n)) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2) + return; + } + + if (incx == 1 && incy == 1) + { + const dim_t num_elem_per_reg = 8; + __m512d xv[32]; + + // n & (~0x7F) = n & 0xFFFFF00 -> this masks the numbers less than 256, + // if value of n < 256, then (n & (~0xFF)) = 0 + // the copy operation will be done for the multiples of 256 + for (i = 0; i < (n & (~0xFF)); i += 256) + { + // Loading the input values + xv[0] = _mm512_loadu_pd(x0 + num_elem_per_reg * 0); + xv[1] = _mm512_loadu_pd(x0 + num_elem_per_reg * 1); + xv[2] = _mm512_loadu_pd(x0 + num_elem_per_reg * 2); + xv[3] = _mm512_loadu_pd(x0 + num_elem_per_reg * 3); + + xv[4] = _mm512_loadu_pd(x0 + num_elem_per_reg * 4); + xv[5] = _mm512_loadu_pd(x0 + num_elem_per_reg * 5); + xv[6] = _mm512_loadu_pd(x0 + num_elem_per_reg * 6); + xv[7] = _mm512_loadu_pd(x0 + num_elem_per_reg * 7); + + xv[8] = _mm512_loadu_pd(x0 + num_elem_per_reg * 8); + xv[9] = _mm512_loadu_pd(x0 + num_elem_per_reg * 9); + xv[10] = _mm512_loadu_pd(x0 + num_elem_per_reg * 10); + xv[11] = _mm512_loadu_pd(x0 + num_elem_per_reg * 11); + + xv[12] = _mm512_loadu_pd(x0 + num_elem_per_reg * 12); + xv[13] = _mm512_loadu_pd(x0 + num_elem_per_reg * 13); + xv[14] = _mm512_loadu_pd(x0 + num_elem_per_reg * 14); + xv[15] = _mm512_loadu_pd(x0 + num_elem_per_reg * 15); + + xv[16] = _mm512_loadu_pd(x0 + num_elem_per_reg * 16); + xv[17] = _mm512_loadu_pd(x0 + num_elem_per_reg * 17); + xv[18] = _mm512_loadu_pd(x0 + num_elem_per_reg * 18); + xv[19] = _mm512_loadu_pd(x0 + num_elem_per_reg * 19); + + xv[20] = _mm512_loadu_pd(x0 + num_elem_per_reg * 20); + xv[21] = _mm512_loadu_pd(x0 + num_elem_per_reg * 21); + xv[22] = _mm512_loadu_pd(x0 + num_elem_per_reg * 22); + xv[23] = _mm512_loadu_pd(x0 + num_elem_per_reg * 23); + + xv[24] = _mm512_loadu_pd(x0 + num_elem_per_reg * 24); + xv[25] = _mm512_loadu_pd(x0 + num_elem_per_reg * 25); + xv[26] = _mm512_loadu_pd(x0 + num_elem_per_reg * 26); + xv[27] = _mm512_loadu_pd(x0 + num_elem_per_reg * 27); + + xv[28] = _mm512_loadu_pd(x0 + num_elem_per_reg * 28); + xv[29] = _mm512_loadu_pd(x0 + num_elem_per_reg * 29); + xv[30] = _mm512_loadu_pd(x0 + num_elem_per_reg * 30); + xv[31] = _mm512_loadu_pd(x0 + num_elem_per_reg * 31); + + // Storing the values to destination + _mm512_storeu_pd(y0 + num_elem_per_reg * 0, xv[0]); + _mm512_storeu_pd(y0 + num_elem_per_reg * 1, xv[1]); + _mm512_storeu_pd(y0 + num_elem_per_reg * 2, xv[2]); + _mm512_storeu_pd(y0 + num_elem_per_reg * 3, xv[3]); + + _mm512_storeu_pd(y0 + num_elem_per_reg * 4, xv[4]); + _mm512_storeu_pd(y0 + num_elem_per_reg * 5, xv[5]); + _mm512_storeu_pd(y0 + num_elem_per_reg * 6, xv[6]); + _mm512_storeu_pd(y0 + num_elem_per_reg * 7, xv[7]); + + _mm512_storeu_pd(y0 + num_elem_per_reg * 8, xv[8]); + _mm512_storeu_pd(y0 + num_elem_per_reg * 9 , xv[9]); + _mm512_storeu_pd(y0 + num_elem_per_reg * 10, xv[10]); + _mm512_storeu_pd(y0 + num_elem_per_reg * 11, xv[11]); + + _mm512_storeu_pd(y0 + num_elem_per_reg * 12, xv[12]); + _mm512_storeu_pd(y0 + num_elem_per_reg * 13, xv[13]); + _mm512_storeu_pd(y0 + num_elem_per_reg * 14, xv[14]); + _mm512_storeu_pd(y0 + num_elem_per_reg * 15, xv[15]); + + _mm512_storeu_pd(y0 + num_elem_per_reg * 16, xv[16]); + _mm512_storeu_pd(y0 + num_elem_per_reg * 17, xv[17]); + _mm512_storeu_pd(y0 + num_elem_per_reg * 18, xv[18]); + _mm512_storeu_pd(y0 + num_elem_per_reg * 19, xv[19]); + + _mm512_storeu_pd(y0 + num_elem_per_reg * 20, xv[20]); + _mm512_storeu_pd(y0 + num_elem_per_reg * 21, xv[21]); + _mm512_storeu_pd(y0 + num_elem_per_reg * 22, xv[22]); + _mm512_storeu_pd(y0 + num_elem_per_reg * 23, xv[23]); + + _mm512_storeu_pd(y0 + num_elem_per_reg * 24, xv[24]); + _mm512_storeu_pd(y0 + num_elem_per_reg * 25, xv[25]); + _mm512_storeu_pd(y0 + num_elem_per_reg * 26, xv[26]); + _mm512_storeu_pd(y0 + num_elem_per_reg * 27, xv[27]); + + _mm512_storeu_pd(y0 + num_elem_per_reg * 28, xv[28]); + _mm512_storeu_pd(y0 + num_elem_per_reg * 29, xv[29]); + _mm512_storeu_pd(y0 + num_elem_per_reg * 30, xv[30]); + _mm512_storeu_pd(y0 + num_elem_per_reg * 31, xv[31]); + + // Increment the pointer + x0 += 32 * num_elem_per_reg; + y0 += 32 * num_elem_per_reg; + } + + for (; i < (n & (~0x7F)); i += 128) + { + // Loading the input values + xv[0] = _mm512_loadu_pd(x0 + num_elem_per_reg * 0); + xv[1] = _mm512_loadu_pd(x0 + num_elem_per_reg * 1); + xv[2] = _mm512_loadu_pd(x0 + num_elem_per_reg * 2); + xv[3] = _mm512_loadu_pd(x0 + num_elem_per_reg * 3); + + xv[4] = _mm512_loadu_pd(x0 + num_elem_per_reg * 4); + xv[5] = _mm512_loadu_pd(x0 + num_elem_per_reg * 5); + xv[6] = _mm512_loadu_pd(x0 + num_elem_per_reg * 6); + xv[7] = _mm512_loadu_pd(x0 + num_elem_per_reg * 7); + + xv[8] = _mm512_loadu_pd(x0 + num_elem_per_reg * 8); + xv[9] = _mm512_loadu_pd(x0 + num_elem_per_reg * 9); + xv[10] = _mm512_loadu_pd(x0 + num_elem_per_reg * 10); + xv[11] = _mm512_loadu_pd(x0 + num_elem_per_reg * 11); + + xv[12] = _mm512_loadu_pd(x0 + num_elem_per_reg * 12); + xv[13] = _mm512_loadu_pd(x0 + num_elem_per_reg * 13); + xv[14] = _mm512_loadu_pd(x0 + num_elem_per_reg * 14); + xv[15] = _mm512_loadu_pd(x0 + num_elem_per_reg * 15); + + // Storing the values to destination + _mm512_storeu_pd(y0 + num_elem_per_reg * 0, xv[0]); + _mm512_storeu_pd(y0 + num_elem_per_reg * 1, xv[1]); + _mm512_storeu_pd(y0 + num_elem_per_reg * 2, xv[2]); + _mm512_storeu_pd(y0 + num_elem_per_reg * 3, xv[3]); + + _mm512_storeu_pd(y0 + num_elem_per_reg * 4, xv[4]); + _mm512_storeu_pd(y0 + num_elem_per_reg * 5, xv[5]); + _mm512_storeu_pd(y0 + num_elem_per_reg * 6, xv[6]); + _mm512_storeu_pd(y0 + num_elem_per_reg * 7, xv[7]); + + _mm512_storeu_pd(y0 + num_elem_per_reg * 8, xv[8]); + _mm512_storeu_pd(y0 + num_elem_per_reg * 9 , xv[9]); + _mm512_storeu_pd(y0 + num_elem_per_reg * 10, xv[10]); + _mm512_storeu_pd(y0 + num_elem_per_reg * 11, xv[11]); + + _mm512_storeu_pd(y0 + num_elem_per_reg * 12, xv[12]); + _mm512_storeu_pd(y0 + num_elem_per_reg * 13, xv[13]); + _mm512_storeu_pd(y0 + num_elem_per_reg * 14, xv[14]); + _mm512_storeu_pd(y0 + num_elem_per_reg * 15, xv[15]); + + // Increment the pointer + x0 += 16 * num_elem_per_reg; + y0 += 16 * num_elem_per_reg; + } + + for (; i < (n & (~0x3F)); i += 64) + { + // Loading the input values + xv[0] = _mm512_loadu_pd(x0 + num_elem_per_reg * 0); + xv[1] = _mm512_loadu_pd(x0 + num_elem_per_reg * 1); + xv[2] = _mm512_loadu_pd(x0 + num_elem_per_reg * 2); + xv[3] = _mm512_loadu_pd(x0 + num_elem_per_reg * 3); + + xv[4] = _mm512_loadu_pd(x0 + num_elem_per_reg * 4); + xv[5] = _mm512_loadu_pd(x0 + num_elem_per_reg * 5); + xv[6] = _mm512_loadu_pd(x0 + num_elem_per_reg * 6); + xv[7] = _mm512_loadu_pd(x0 + num_elem_per_reg * 7); + + // Storing the values to destination + _mm512_storeu_pd(y0 + num_elem_per_reg * 0, xv[0]); + _mm512_storeu_pd(y0 + num_elem_per_reg * 1, xv[1]); + _mm512_storeu_pd(y0 + num_elem_per_reg * 2, xv[2]); + _mm512_storeu_pd(y0 + num_elem_per_reg * 3, xv[3]); + + _mm512_storeu_pd(y0 + num_elem_per_reg * 4, xv[4]); + _mm512_storeu_pd(y0 + num_elem_per_reg * 5, xv[5]); + _mm512_storeu_pd(y0 + num_elem_per_reg * 6, xv[6]); + _mm512_storeu_pd(y0 + num_elem_per_reg * 7, xv[7]); + + // Increment the pointer + x0 += 8 * num_elem_per_reg; + y0 += 8 * num_elem_per_reg; + } + + for (; i < (n & (~0x1F)); i += 32) + { + // Loading the input values + xv[0] = _mm512_loadu_pd(x0 + num_elem_per_reg * 0); + xv[1] = _mm512_loadu_pd(x0 + num_elem_per_reg * 1); + xv[2] = _mm512_loadu_pd(x0 + num_elem_per_reg * 2); + xv[3] = _mm512_loadu_pd(x0 + num_elem_per_reg * 3); + + // Storing the values to destination + _mm512_storeu_pd(y0 + num_elem_per_reg * 0, xv[0]); + _mm512_storeu_pd(y0 + num_elem_per_reg * 1, xv[1]); + _mm512_storeu_pd(y0 + num_elem_per_reg * 2, xv[2]); + _mm512_storeu_pd(y0 + num_elem_per_reg * 3, xv[3]); + + // Increment the pointer + x0 += 4 * num_elem_per_reg; + y0 += 4 * num_elem_per_reg; + } + + for (; i < (n & (~0x0F)); i += 16) + { + // Loading the input values + xv[0] = _mm512_loadu_pd(x0 + num_elem_per_reg * 0); + xv[1] = _mm512_loadu_pd(x0 + num_elem_per_reg * 1); + + // Storing the values to destination + _mm512_storeu_pd(y0 + num_elem_per_reg * 0, xv[0]); + _mm512_storeu_pd(y0 + num_elem_per_reg * 1, xv[1]); + + // Increment the pointer + x0 += 2 * num_elem_per_reg; + y0 += 2 * num_elem_per_reg; + } + + for (; i < (n & (~0x07)); i += 8) + { + // Loading the input values + xv[0] = _mm512_loadu_pd(x0 + num_elem_per_reg * 0); + + // Storing the values to destination + _mm512_storeu_pd(y0 + num_elem_per_reg * 0, xv[0]); + + // Increment the pointer + x0 += num_elem_per_reg; + y0 += num_elem_per_reg; + } + + if ( i < n ) + { + xv[1] = _mm512_setzero_pd(); + + // Creating the mask + __mmask8 mask = (1 << (n-i)) - 1; + + // Loading the input values + xv[0] = _mm512_mask_loadu_pd(xv[1], mask, x0 + num_elem_per_reg * 0); + + // Storing the values to destination + _mm512_mask_storeu_pd(y0 + num_elem_per_reg * 0, mask, xv[0]); + + } + } + else + { + for ( i = 0; i < n; ++i) + { + *y0 = *x0; + + x0 += incx; + y0 += incy; + } + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2) +} + +// ----------------------------------------------------------------------------- + +/* + Functionality + ------------- + + This function copies a double complex vector x to a double complex vector y. + + y := conj?(x) + + Function Signature + ------------------- + + * 'conjx' - Variable specified if x needs to be conjugated + * 'n' - Length of the array passed + * 'x' - Double pointer pointing to an array + * 'y' - Double pointer pointing to an array + * 'incx' - Stride to point to the next element in x array + * 'incy' - Stride to point to the next element in y array + * 'cntx' - BLIS context object + + Exception + ---------- + + None + + Deviation from BLAS + -------------------- + + None + + Undefined behaviour + ------------------- + + 1. The kernel results in undefined behaviour when n < 0, incx < 1 and incy < 1. + The expectation is that these are standard BLAS exceptions and should be handled in + a higher layer +*/ + +void bli_zcopyv_zen_int_avx512 +( + conj_t conjx, + dim_t n, + dcomplex* restrict x, inc_t incx, + dcomplex* restrict y, inc_t incy, + cntx_t* restrict cntx +) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_2) + dim_t i = 0; + + // Initialize local pointers. + dcomplex *x0 = x; + dcomplex *y0 = y; + + // If the vector dimension is zero return early. + if (bli_zero_dim1(n)) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2) + return; + } + + // Check if conjugation is required and select the required code path + if (bli_is_conj(conjx)) + { + + if (incx == 1 && incy == 1) + { + const dim_t num_elem_per_reg = 8; + __m512d xv[16]; + __m512d zero_reg = _mm512_setzero_pd(); + + // n & (~0x3F) = n & 0xFFFFFFC0 -> this masks the numbers less than 64, + // if value of n < 64, then (n & (~0x3F)) = 0 + // the copy operation will be done for the multiples of 64 + for (i = 0; i < (n & (~0x3F)); i += 64) + { + // Loading the input values + xv[0] = _mm512_loadu_pd((double *)(x0 + num_elem_per_reg * 0)); + xv[1] = _mm512_loadu_pd((double *)(x0 + num_elem_per_reg * 1)); + xv[2] = _mm512_loadu_pd((double *)(x0 + num_elem_per_reg * 2)); + xv[3] = _mm512_loadu_pd((double *)(x0 + num_elem_per_reg * 3)); + + xv[4] = _mm512_loadu_pd((double *)(x0 + num_elem_per_reg * 4)); + xv[5] = _mm512_loadu_pd((double *)(x0 + num_elem_per_reg * 5)); + xv[6] = _mm512_loadu_pd((double *)(x0 + num_elem_per_reg * 6)); + xv[7] = _mm512_loadu_pd((double *)(x0 + num_elem_per_reg * 7)); + + xv[8] = _mm512_loadu_pd((double *)(x0 + num_elem_per_reg * 8)); + xv[9] = _mm512_loadu_pd((double *)(x0 + num_elem_per_reg * 9)); + xv[10] = _mm512_loadu_pd((double *)(x0 + num_elem_per_reg * 10)); + xv[11] = _mm512_loadu_pd((double *)(x0 + num_elem_per_reg * 11)); + + xv[12] = _mm512_loadu_pd((double *)(x0 + num_elem_per_reg * 12)); + xv[13] = _mm512_loadu_pd((double *)(x0 + num_elem_per_reg * 13)); + xv[14] = _mm512_loadu_pd((double *)(x0 + num_elem_per_reg * 14)); + xv[15] = _mm512_loadu_pd((double *)(x0 + num_elem_per_reg * 15)); + + // Perform conjugation by multiplying the imaginary part with -1 and real part with 1 + xv[0] = _mm512_fmsubadd_pd( zero_reg, zero_reg, xv[0]); + xv[1] = _mm512_fmsubadd_pd( zero_reg, zero_reg, xv[1]); + xv[2] = _mm512_fmsubadd_pd( zero_reg, zero_reg, xv[2]); + xv[3] = _mm512_fmsubadd_pd( zero_reg, zero_reg, xv[3]); + + xv[4] = _mm512_fmsubadd_pd( zero_reg, zero_reg, xv[4]); + xv[5] = _mm512_fmsubadd_pd( zero_reg, zero_reg, xv[5]); + xv[6] = _mm512_fmsubadd_pd( zero_reg, zero_reg, xv[6]); + xv[7] = _mm512_fmsubadd_pd( zero_reg, zero_reg, xv[7]); + + xv[8] = _mm512_fmsubadd_pd( zero_reg, zero_reg, xv[8]); + xv[9] = _mm512_fmsubadd_pd( zero_reg, zero_reg, xv[9]); + xv[10] = _mm512_fmsubadd_pd( zero_reg, zero_reg, xv[10]); + xv[11] = _mm512_fmsubadd_pd( zero_reg, zero_reg, xv[11]); + + xv[12] = _mm512_fmsubadd_pd( zero_reg, zero_reg, xv[12]); + xv[13] = _mm512_fmsubadd_pd( zero_reg, zero_reg, xv[13]); + xv[14] = _mm512_fmsubadd_pd( zero_reg, zero_reg, xv[14]); + xv[15] = _mm512_fmsubadd_pd( zero_reg, zero_reg, xv[15]); + + // Storing the values to destination + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 0), xv[0]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 1), xv[1]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 2), xv[2]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 3), xv[3]); + + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 4), xv[4]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 5), xv[5]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 6), xv[6]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 7), xv[7]); + + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 8), xv[8]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 9), xv[9]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 10), xv[10]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 11), xv[11]); + + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 12), xv[12]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 13), xv[13]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 14), xv[14]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 15), xv[15]); + + // Increment the pointer + x0 += 16 * num_elem_per_reg; + y0 += 16 * num_elem_per_reg; + } + + for (; i < (n & (~0x1F)); i += 32) + { + // Loading the input values + xv[0] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 0)); + xv[1] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 1)); + xv[2] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 2)); + xv[3] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 3)); + + xv[4] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 4)); + xv[5] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 5)); + xv[6] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 6)); + xv[7] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 7)); + + // Perform conjugation by multiplying the imaginary part with -1 and real part with 1 + xv[0] = _mm512_fmsubadd_pd( zero_reg, zero_reg, xv[0]); + xv[1] = _mm512_fmsubadd_pd( zero_reg, zero_reg, xv[1]); + xv[2] = _mm512_fmsubadd_pd( zero_reg, zero_reg, xv[2]); + xv[3] = _mm512_fmsubadd_pd( zero_reg, zero_reg, xv[3]); + + xv[4] = _mm512_fmsubadd_pd( zero_reg, zero_reg, xv[4]); + xv[5] = _mm512_fmsubadd_pd( zero_reg, zero_reg, xv[5]); + xv[6] = _mm512_fmsubadd_pd( zero_reg, zero_reg, xv[6]); + xv[7] = _mm512_fmsubadd_pd( zero_reg, zero_reg, xv[7]); + + // Storing the values to destination + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 0), xv[0]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 1), xv[1]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 2), xv[2]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 3), xv[3]); + + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 4), xv[4]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 5), xv[5]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 6), xv[6]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 7), xv[7]); + + // Increment the pointer + x0 += 8 * num_elem_per_reg; + y0 += 8 * num_elem_per_reg; + } + + for (; i < (n & (~0x0F)); i += 16) + { + // Loading the input values + xv[0] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 0)); + xv[1] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 1)); + xv[2] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 2)); + xv[3] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 3)); + + // Perform conjugation by multiplying the imaginary part with -1 and real part with 1 + xv[0] = _mm512_fmsubadd_pd( zero_reg, zero_reg, xv[0]); + xv[1] = _mm512_fmsubadd_pd( zero_reg, zero_reg, xv[1]); + xv[2] = _mm512_fmsubadd_pd( zero_reg, zero_reg, xv[2]); + xv[3] = _mm512_fmsubadd_pd( zero_reg, zero_reg, xv[3]); + + // Storing the values to destination + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 0), xv[0]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 1), xv[1]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 2), xv[2]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 3), xv[3]); + + // Increment the pointer + x0 += 4 * num_elem_per_reg; + y0 += 4 * num_elem_per_reg; + } + + for (; i < (n & (~0x07)); i += 8) + { + // Loading the input values + xv[0] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 0)); + xv[1] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 1)); + + // Perform conjugation by multiplying the imaginary part with -1 and real part with 1 + xv[0] = _mm512_fmsubadd_pd( zero_reg, zero_reg, xv[0]); + xv[1] = _mm512_fmsubadd_pd( zero_reg, zero_reg, xv[1]); + + // Storing the values to destination + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 0), xv[0]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 1), xv[1]); + + // Increment the pointer + x0 += 2 * num_elem_per_reg; + y0 += 2 * num_elem_per_reg; + } + + for (; i < (n & (~0x03)); i += 4) + { + // Loading the input values + xv[0] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 0)); + + // Perform conjugation by multiplying the imaginary part with -1 and real part with 1 + xv[0] = _mm512_fmsubadd_pd( zero_reg, zero_reg, xv[0]); + + // Storing the values to destination + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 0), xv[0]); + + // Increment the pointer + x0 += num_elem_per_reg; + y0 += num_elem_per_reg; + } + + if ( i < n ) + { + xv[1] = _mm512_setzero_pd(); + + // Creating the mask + __mmask8 mask = (1 << 2*(n-i)) - 1; + + // Loading the input values + xv[0] = _mm512_mask_loadu_pd( zero_reg, mask,(double *)( x0 + num_elem_per_reg * 0)); + + // Perform conjugation by multiplying the imaginary part with -1 and real part with 1 + xv[0] = _mm512_fmsubadd_pd( zero_reg, zero_reg, xv[0]); + + // Storing the values to destination + _mm512_mask_storeu_pd((double *)(y0 + num_elem_per_reg * 0), mask, xv[0]); + + } + } + else + { + // Since double complex elements are of size 128 bits, + // vectorization can be done using XMM registers when incx and incy are not 1. + // This is done in the else condition. + __m128d xv[16]; + __m128d conj_reg = _mm_setr_pd(1, -1); + + // n & (~0x0F) = n & 0xFFFFFFF0 -> this masks the numbers less than 16, + // if value of n < 16, then (n & (~0x0F)) = 0 + // the copy operation will be done for the multiples of 16 + for ( i = 0; i < (n & (~0x0F)); i += 16) + { + // Loading the input values + xv[0] = _mm_loadu_pd((double *)(x0 + 0 * incx)); + xv[1] = _mm_loadu_pd((double *)(x0 + 1 * incx)); + xv[2] = _mm_loadu_pd((double *)(x0 + 2 * incx)); + xv[3] = _mm_loadu_pd((double *)(x0 + 3 * incx)); + + xv[4] = _mm_loadu_pd((double *)(x0 + 4 * incx)); + xv[5] = _mm_loadu_pd((double *)(x0 + 5 * incx)); + xv[6] = _mm_loadu_pd((double *)(x0 + 6 * incx)); + xv[7] = _mm_loadu_pd((double *)(x0 + 7 * incx)); + + xv[8] = _mm_loadu_pd((double *)(x0 + 8 * incx)); + xv[9] = _mm_loadu_pd((double *)(x0 + 9 * incx)); + xv[10] = _mm_loadu_pd((double *)(x0 + 10 * incx)); + xv[11] = _mm_loadu_pd((double *)(x0 + 11 * incx)); + + xv[12] = _mm_loadu_pd((double *)(x0 + 12 * incx)); + xv[13] = _mm_loadu_pd((double *)(x0 + 13 * incx)); + xv[14] = _mm_loadu_pd((double *)(x0 + 14 * incx)); + xv[15] = _mm_loadu_pd((double *)(x0 + 15 * incx)); + + // Perform conjugation by multiplying the imaginary part with -1 and real part with 1 + xv[0] = _mm_mul_pd(xv[0], conj_reg); + xv[1] = _mm_mul_pd(xv[1], conj_reg); + xv[2] = _mm_mul_pd(xv[2], conj_reg); + xv[3] = _mm_mul_pd(xv[3], conj_reg); + + xv[4] = _mm_mul_pd(xv[4], conj_reg); + xv[5] = _mm_mul_pd(xv[5], conj_reg); + xv[6] = _mm_mul_pd(xv[6], conj_reg); + xv[7] = _mm_mul_pd(xv[7], conj_reg); + + xv[8] = _mm_mul_pd(xv[8], conj_reg); + xv[9] = _mm_mul_pd(xv[9], conj_reg); + xv[10] = _mm_mul_pd(xv[10], conj_reg); + xv[11] = _mm_mul_pd(xv[11], conj_reg); + + xv[12] = _mm_mul_pd(xv[12], conj_reg); + xv[13] = _mm_mul_pd(xv[13], conj_reg); + xv[14] = _mm_mul_pd(xv[14], conj_reg); + xv[15] = _mm_mul_pd(xv[15], conj_reg); + + // Storing the values to destination + + _mm_storeu_pd((double *)(y0 + incy * 0), xv[0]); + _mm_storeu_pd((double *)(y0 + incy * 1), xv[1]); + _mm_storeu_pd((double *)(y0 + incy * 2), xv[2]); + _mm_storeu_pd((double *)(y0 + incy * 3), xv[3]); + + _mm_storeu_pd((double *)(y0 + incy * 4), xv[4]); + _mm_storeu_pd((double *)(y0 + incy * 5), xv[5]); + _mm_storeu_pd((double *)(y0 + incy * 6), xv[6]); + _mm_storeu_pd((double *)(y0 + incy * 7), xv[7]); + + _mm_storeu_pd((double *)(y0 + incy * 8), xv[8]); + _mm_storeu_pd((double *)(y0 + incy * 9 ), xv[9]); + _mm_storeu_pd((double *)(y0 + incy * 10), xv[10]); + _mm_storeu_pd((double *)(y0 + incy * 11), xv[11]); + + _mm_storeu_pd((double *)(y0 + incy * 12), xv[12]); + _mm_storeu_pd((double *)(y0 + incy * 13), xv[13]); + _mm_storeu_pd((double *)(y0 + incy * 14), xv[14]); + _mm_storeu_pd((double *)(y0 + incy * 15), xv[15]); + + // Increment the pointer + x0 += 16 * incx; + y0 += 16 * incy; + } + + for ( ; i < (n & (~0x07)); i += 8) + { + // Loading the input values + xv[0] = _mm_loadu_pd((double *)(x0 + 0 * incx)); + xv[1] = _mm_loadu_pd((double *)(x0 + 1 * incx)); + xv[2] = _mm_loadu_pd((double *)(x0 + 2 * incx)); + xv[3] = _mm_loadu_pd((double *)(x0 + 3 * incx)); + + xv[4] = _mm_loadu_pd((double *)(x0 + 4 * incx)); + xv[5] = _mm_loadu_pd((double *)(x0 + 5 * incx)); + xv[6] = _mm_loadu_pd((double *)(x0 + 6 * incx)); + xv[7] = _mm_loadu_pd((double *)(x0 + 7 * incx)); + + // Perform conjugation by multiplying the imaginary part with -1 and real part with 1 + xv[0] = _mm_mul_pd(xv[0], conj_reg); + xv[1] = _mm_mul_pd(xv[1], conj_reg); + xv[2] = _mm_mul_pd(xv[2], conj_reg); + xv[3] = _mm_mul_pd(xv[3], conj_reg); + + xv[4] = _mm_mul_pd(xv[4], conj_reg); + xv[5] = _mm_mul_pd(xv[5], conj_reg); + xv[6] = _mm_mul_pd(xv[6], conj_reg); + xv[7] = _mm_mul_pd(xv[7], conj_reg); + + // Storing the values to destination + + _mm_storeu_pd((double *)(y0 + incy * 0), xv[0]); + _mm_storeu_pd((double *)(y0 + incy * 1), xv[1]); + _mm_storeu_pd((double *)(y0 + incy * 2), xv[2]); + _mm_storeu_pd((double *)(y0 + incy * 3), xv[3]); + + _mm_storeu_pd((double *)(y0 + incy * 4), xv[4]); + _mm_storeu_pd((double *)(y0 + incy * 5), xv[5]); + _mm_storeu_pd((double *)(y0 + incy * 6), xv[6]); + _mm_storeu_pd((double *)(y0 + incy * 7), xv[7]); + + // Increment the pointer + x0 += 8 * incx; + y0 += 8 * incy; + } + + for ( ; i < (n & (~0x03)); i += 4) + { + // Loading the input values + xv[0] = _mm_loadu_pd((double *)(x0 + 0 * incx)); + xv[1] = _mm_loadu_pd((double *)(x0 + 1 * incx)); + xv[2] = _mm_loadu_pd((double *)(x0 + 2 * incx)); + xv[3] = _mm_loadu_pd((double *)(x0 + 3 * incx)); + + // Perform conjugation by multiplying the imaginary part with -1 and real part with 1 + xv[0] = _mm_mul_pd(xv[0], conj_reg); + xv[1] = _mm_mul_pd(xv[1], conj_reg); + xv[2] = _mm_mul_pd(xv[2], conj_reg); + xv[3] = _mm_mul_pd(xv[3], conj_reg); + + // Storing the values to destination + + _mm_storeu_pd((double *)(y0 + incy * 0), xv[0]); + _mm_storeu_pd((double *)(y0 + incy * 1), xv[1]); + _mm_storeu_pd((double *)(y0 + incy * 2), xv[2]); + _mm_storeu_pd((double *)(y0 + incy * 3), xv[3]); + + // Increment the pointer + x0 += 4 * incx; + y0 += 4 * incy; + } + + for ( ; i < (n & (~0x01)); i += 2) + { + // Loading the input values + xv[0] = _mm_loadu_pd((double *)(x0 + 0 * incx)); + xv[1] = _mm_loadu_pd((double *)(x0 + 1 * incx)); + + // Perform conjugation by multiplying the imaginary part with -1 and real part with 1 + xv[0] = _mm_mul_pd(xv[0], conj_reg); + xv[1] = _mm_mul_pd(xv[1], conj_reg); + + // Storing the values to destination + + _mm_storeu_pd((double *)(y0 + incy * 0), xv[0]); + _mm_storeu_pd((double *)(y0 + incy * 1), xv[1]); + + // Increment the pointer + x0 += 2 * incx; + y0 += 2 * incy; + } + + for ( ; i < n; i += 1) + { + // Loading the input values + xv[0] = _mm_loadu_pd((double *)(x0 + 0 * incx)); + + // Perform conjugation by multiplying the imaginary part with -1 and real part with 1 + xv[0] = _mm_mul_pd(xv[0], conj_reg); + + // Storing the values to destination + _mm_storeu_pd((double *)(y0 + incy * 0), xv[0]); + + // Increment the pointer + x0 += 1 * incx; + y0 += 1 * incy; + } + } + } + else + { + if (incx == 1 && incy == 1) + { + const dim_t num_elem_per_reg = 8; + __m512d xv[32]; + + // n & (~0xFF) = n & 0xFFFFFF00 -> this masks the numbers less than 256, + // if value of n < 256, then (n & (~0xFF)) = 0 + // the copy operation will be done for the multiples of 256 + for (i = 0; i < (n & (~0xFF)); i += 256) + { + // Loading the input values + xv[0] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 0)); + xv[1] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 1)); + xv[2] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 2)); + xv[3] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 3)); + + xv[4] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 4)); + xv[5] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 5)); + xv[6] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 6)); + xv[7] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 7)); + + xv[8] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 8)); + xv[9] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 9)); + xv[10] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 10)); + xv[11] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 11)); + + xv[12] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 12)); + xv[13] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 13)); + xv[14] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 14)); + xv[15] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 15)); + + xv[16] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 16)); + xv[17] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 17)); + xv[18] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 18)); + xv[19] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 19)); + + xv[20] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 20)); + xv[21] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 21)); + xv[22] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 22)); + xv[23] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 23)); + + xv[24] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 24)); + xv[25] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 25)); + xv[26] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 26)); + xv[27] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 27)); + + xv[28] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 28)); + xv[29] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 29)); + xv[30] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 30)); + xv[31] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 31)); + + // Storing the values to destination + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 0), xv[0]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 1), xv[1]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 2), xv[2]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 3), xv[3]); + + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 4), xv[4]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 5), xv[5]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 6), xv[6]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 7), xv[7]); + + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 8), xv[8]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 9), xv[9]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 10), xv[10]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 11), xv[11]); + + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 12), xv[12]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 13), xv[13]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 14), xv[14]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 15), xv[15]); + + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 16), xv[16]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 17), xv[17]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 18), xv[18]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 19), xv[19]); + + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 20), xv[20]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 21), xv[21]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 22), xv[22]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 23), xv[23]); + + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 24), xv[24]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 25), xv[25]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 26), xv[26]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 27), xv[27]); + + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 28), xv[28]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 29), xv[29]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 30), xv[30]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 31), xv[31]); + + // Increment the pointer + x0 += 32 * num_elem_per_reg; + y0 += 32 * num_elem_per_reg; + } + + for (; i < (n & (~0x7F)); i += 128) + { + // Loading the input values + xv[0] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 0)); + xv[1] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 1)); + xv[2] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 2)); + xv[3] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 3)); + + xv[4] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 4)); + xv[5] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 5)); + xv[6] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 6)); + xv[7] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 7)); + + xv[8] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 8)); + xv[9] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 9)); + xv[10] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 10)); + xv[11] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 11)); + + xv[12] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 12)); + xv[13] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 13)); + xv[14] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 14)); + xv[15] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 15)); + + // Storing the values to destination + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 0), xv[0]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 1), xv[1]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 2), xv[2]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 3), xv[3]); + + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 4), xv[4]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 5), xv[5]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 6), xv[6]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 7), xv[7]); + + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 8), xv[8]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 9), xv[9]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 10), xv[10]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 11), xv[11]); + + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 12), xv[12]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 13), xv[13]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 14), xv[14]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 15), xv[15]); + + // Increment the pointer + x0 += 16 * num_elem_per_reg; + y0 += 16 * num_elem_per_reg; + } + + for (; i < (n & (~0x3F)); i += 64) + { + // Loading the input values + xv[0] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 0)); + xv[1] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 1)); + xv[2] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 2)); + xv[3] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 3)); + + xv[4] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 4)); + xv[5] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 5)); + xv[6] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 6)); + xv[7] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 7)); + + // Storing the values to destination + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 0), xv[0]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 1), xv[1]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 2), xv[2]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 3), xv[3]); + + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 4), xv[4]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 5), xv[5]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 6), xv[6]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 7), xv[7]); + + // Increment the pointer + x0 += 8 * num_elem_per_reg; + y0 += 8 * num_elem_per_reg; + } + + for (; i < (n & (~0x1F)); i += 32) + { + // Loading the input values + xv[0] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 0)); + xv[1] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 1)); + xv[2] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 2)); + xv[3] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 3)); + + // Storing the values to destination + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 0), xv[0]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 1), xv[1]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 2), xv[2]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 3), xv[3]); + + // Increment the pointer + x0 += 4 * num_elem_per_reg; + y0 += 4 * num_elem_per_reg; + } + + for (; i < (n & (~0x0F)); i += 16) + { + // Loading the input values + xv[0] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 0)); + xv[1] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 1)); + + // Storing the values to destination + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 0), xv[0]); + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 1), xv[1]); + + // Increment the pointer + x0 += 2 * num_elem_per_reg; + y0 += 2 * num_elem_per_reg; + } + + for (; i < (n & (~0x07)); i += 8) + { + // Loading the input values + xv[0] = _mm512_loadu_pd((double *)(x0+ num_elem_per_reg * 0)); + + // Storing the values to destination + _mm512_storeu_pd((double *)(y0 + num_elem_per_reg * 0), xv[0]); + + // Increment the pointer + x0 += num_elem_per_reg; + y0 += num_elem_per_reg; + } + + if ( i < n ) + { + xv[1] = _mm512_setzero_pd(); + + // Creating the mask + __mmask8 mask = (1 << 2*(n-i)) - 1; + + // Loading the input values + xv[0] = _mm512_mask_loadu_pd(xv[1], mask, (double *)(x0 + num_elem_per_reg * 0)); + + // Storing the values to destination + _mm512_mask_storeu_pd((double *)(y0 + num_elem_per_reg * 0), mask, xv[0]); + + } + } + else + { + // Since double complex elements are of size 128 bits, + // vectorization can be done using XMM registers when incx and incy are not 1. + // This is done in the else condition. + __m128d xv[32]; + + // n & (~0x1F) = n & 0xFFFFFFE0 -> this masks the numbers less than 32, + // if value of n < 32, then (n & (~0x1F)) = 0 + // the copy operation will be done for the multiples of 32 + for ( i = 0; i < (n & (~0x1F)); i += 32) + { + // Loading the input values + xv[0] = _mm_loadu_pd((double *)(x0 + 0 * incx)); + xv[1] = _mm_loadu_pd((double *)(x0 + 1 * incx)); + xv[2] = _mm_loadu_pd((double *)(x0 + 2 * incx)); + xv[3] = _mm_loadu_pd((double *)(x0 + 3 * incx)); + + xv[4] = _mm_loadu_pd((double *)(x0 + 4 * incx)); + xv[5] = _mm_loadu_pd((double *)(x0 + 5 * incx)); + xv[6] = _mm_loadu_pd((double *)(x0 + 6 * incx)); + xv[7] = _mm_loadu_pd((double *)(x0 + 7 * incx)); + + xv[8] = _mm_loadu_pd((double *)(x0 + 8 * incx)); + xv[9] = _mm_loadu_pd((double *)(x0 + 9 * incx)); + xv[10] = _mm_loadu_pd((double *)(x0 + 10 * incx)); + xv[11] = _mm_loadu_pd((double *)(x0 + 11 * incx)); + + xv[12] = _mm_loadu_pd((double *)(x0 + 12 * incx)); + xv[13] = _mm_loadu_pd((double *)(x0 + 13 * incx)); + xv[14] = _mm_loadu_pd((double *)(x0 + 14 * incx)); + xv[15] = _mm_loadu_pd((double *)(x0 + 15 * incx)); + + xv[16] = _mm_loadu_pd((double *)(x0 + 16 * incx)); + xv[17] = _mm_loadu_pd((double *)(x0 + 17 * incx)); + xv[18] = _mm_loadu_pd((double *)(x0 + 18 * incx)); + xv[19] = _mm_loadu_pd((double *)(x0 + 19 * incx)); + + xv[20] = _mm_loadu_pd((double *)(x0 + 20 * incx)); + xv[21] = _mm_loadu_pd((double *)(x0 + 21 * incx)); + xv[22] = _mm_loadu_pd((double *)(x0 + 22 * incx)); + xv[23] = _mm_loadu_pd((double *)(x0 + 23 * incx)); + + xv[24] = _mm_loadu_pd((double *)(x0 + 24 * incx)); + xv[25] = _mm_loadu_pd((double *)(x0 + 25 * incx)); + xv[26] = _mm_loadu_pd((double *)(x0 + 26 * incx)); + xv[27] = _mm_loadu_pd((double *)(x0 + 27 * incx)); + + xv[28] = _mm_loadu_pd((double *)(x0 + 28 * incx)); + xv[29] = _mm_loadu_pd((double *)(x0 + 29 * incx)); + xv[30] = _mm_loadu_pd((double *)(x0 + 30 * incx)); + xv[31] = _mm_loadu_pd((double *)(x0 + 31 * incx)); + + // Storing the values to destination + _mm_storeu_pd((double *)(y0 + incy * 0), xv[0]); + _mm_storeu_pd((double *)(y0 + incy * 1), xv[1]); + _mm_storeu_pd((double *)(y0 + incy * 2), xv[2]); + _mm_storeu_pd((double *)(y0 + incy * 3), xv[3]); + + _mm_storeu_pd((double *)(y0 + incy * 4), xv[4]); + _mm_storeu_pd((double *)(y0 + incy * 5), xv[5]); + _mm_storeu_pd((double *)(y0 + incy * 6), xv[6]); + _mm_storeu_pd((double *)(y0 + incy * 7), xv[7]); + + _mm_storeu_pd((double *)(y0 + incy * 8), xv[8]); + _mm_storeu_pd((double *)(y0 + incy * 9 ), xv[9]); + _mm_storeu_pd((double *)(y0 + incy * 10), xv[10]); + _mm_storeu_pd((double *)(y0 + incy * 11), xv[11]); + + _mm_storeu_pd((double *)(y0 + incy * 12), xv[12]); + _mm_storeu_pd((double *)(y0 + incy * 13), xv[13]); + _mm_storeu_pd((double *)(y0 + incy * 14), xv[14]); + _mm_storeu_pd((double *)(y0 + incy * 15), xv[15]); + + _mm_storeu_pd((double *)(y0 + incy * 16), xv[16]); + _mm_storeu_pd((double *)(y0 + incy * 17), xv[17]); + _mm_storeu_pd((double *)(y0 + incy * 18), xv[18]); + _mm_storeu_pd((double *)(y0 + incy * 19), xv[19]); + + _mm_storeu_pd((double *)(y0 + incy * 20), xv[20]); + _mm_storeu_pd((double *)(y0 + incy * 21), xv[21]); + _mm_storeu_pd((double *)(y0 + incy * 22), xv[22]); + _mm_storeu_pd((double *)(y0 + incy * 23), xv[23]); + + _mm_storeu_pd((double *)(y0 + incy * 24), xv[24]); + _mm_storeu_pd((double *)(y0 + incy * 25), xv[25]); + _mm_storeu_pd((double *)(y0 + incy * 26), xv[26]); + _mm_storeu_pd((double *)(y0 + incy * 27), xv[27]); + + _mm_storeu_pd((double *)(y0 + incy * 28), xv[28]); + _mm_storeu_pd((double *)(y0 + incy * 29), xv[29]); + _mm_storeu_pd((double *)(y0 + incy * 30), xv[30]); + _mm_storeu_pd((double *)(y0 + incy * 31), xv[31]); + + // Increment the pointer + x0 += 32 * incx; + y0 += 32 * incy; + } + + for ( ; i < (n & (~0x0F)); i += 16) + { + // Loading the input values + xv[0] = _mm_loadu_pd((double *)(x0 + 0 * incx)); + xv[1] = _mm_loadu_pd((double *)(x0 + 1 * incx)); + xv[2] = _mm_loadu_pd((double *)(x0 + 2 * incx)); + xv[3] = _mm_loadu_pd((double *)(x0 + 3 * incx)); + + xv[4] = _mm_loadu_pd((double *)(x0 + 4 * incx)); + xv[5] = _mm_loadu_pd((double *)(x0 + 5 * incx)); + xv[6] = _mm_loadu_pd((double *)(x0 + 6 * incx)); + xv[7] = _mm_loadu_pd((double *)(x0 + 7 * incx)); + + xv[8] = _mm_loadu_pd((double *)(x0 + 8 * incx)); + xv[9] = _mm_loadu_pd((double *)(x0 + 9 * incx)); + xv[10] = _mm_loadu_pd((double *)(x0 + 10 * incx)); + xv[11] = _mm_loadu_pd((double *)(x0 + 11 * incx)); + + xv[12] = _mm_loadu_pd((double *)(x0 + 12 * incx)); + xv[13] = _mm_loadu_pd((double *)(x0 + 13 * incx)); + xv[14] = _mm_loadu_pd((double *)(x0 + 14 * incx)); + xv[15] = _mm_loadu_pd((double *)(x0 + 15 * incx)); + + // Storing the values to destination + _mm_storeu_pd((double *)(y0 + incy * 0), xv[0]); + _mm_storeu_pd((double *)(y0 + incy * 1), xv[1]); + _mm_storeu_pd((double *)(y0 + incy * 2), xv[2]); + _mm_storeu_pd((double *)(y0 + incy * 3), xv[3]); + + _mm_storeu_pd((double *)(y0 + incy * 4), xv[4]); + _mm_storeu_pd((double *)(y0 + incy * 5), xv[5]); + _mm_storeu_pd((double *)(y0 + incy * 6), xv[6]); + _mm_storeu_pd((double *)(y0 + incy * 7), xv[7]); + + _mm_storeu_pd((double *)(y0 + incy * 8), xv[8]); + _mm_storeu_pd((double *)(y0 + incy * 9), xv[9]); + _mm_storeu_pd((double *)(y0 + incy * 10), xv[10]); + _mm_storeu_pd((double *)(y0 + incy * 11), xv[11]); + + _mm_storeu_pd((double *)(y0 + incy * 12), xv[12]); + _mm_storeu_pd((double *)(y0 + incy * 13), xv[13]); + _mm_storeu_pd((double *)(y0 + incy * 14), xv[14]); + _mm_storeu_pd((double *)(y0 + incy * 15), xv[15]); + + // Increment the pointer + x0 += 16 * incx; + y0 += 16 * incy; + } + + for ( ; i < (n & (~0x07)); i += 8) + { + // Loading the input values + xv[0] = _mm_loadu_pd((double *)(x0 + 0 * incx)); + xv[1] = _mm_loadu_pd((double *)(x0 + 1 * incx)); + xv[2] = _mm_loadu_pd((double *)(x0 + 2 * incx)); + xv[3] = _mm_loadu_pd((double *)(x0 + 3 * incx)); + + xv[4] = _mm_loadu_pd((double *)(x0 + 4 * incx)); + xv[5] = _mm_loadu_pd((double *)(x0 + 5 * incx)); + xv[6] = _mm_loadu_pd((double *)(x0 + 6 * incx)); + xv[7] = _mm_loadu_pd((double *)(x0 + 7 * incx)); + + // Storing the values to destination + _mm_storeu_pd((double *)(y0 + incy * 0), xv[0]); + _mm_storeu_pd((double *)(y0 + incy * 1), xv[1]); + _mm_storeu_pd((double *)(y0 + incy * 2), xv[2]); + _mm_storeu_pd((double *)(y0 + incy * 3), xv[3]); + + _mm_storeu_pd((double *)(y0 + incy * 4), xv[4]); + _mm_storeu_pd((double *)(y0 + incy * 5), xv[5]); + _mm_storeu_pd((double *)(y0 + incy * 6), xv[6]); + _mm_storeu_pd((double *)(y0 + incy * 7), xv[7]); + + // Increment the pointer + x0 += 8 * incx; + y0 += 8 * incy; + } + + for ( ; i < (n & (~0x03)); i += 4) + { + // Loading the input values + xv[0] = _mm_loadu_pd((double *)(x0 + 0 * incx)); + xv[1] = _mm_loadu_pd((double *)(x0 + 1 * incx)); + xv[2] = _mm_loadu_pd((double *)(x0 + 2 * incx)); + xv[3] = _mm_loadu_pd((double *)(x0 + 3 * incx)); + + // Storing the values to destination + _mm_storeu_pd((double *)(y0 + incy * 0), xv[0]); + _mm_storeu_pd((double *)(y0 + incy * 1), xv[1]); + _mm_storeu_pd((double *)(y0 + incy * 2), xv[2]); + _mm_storeu_pd((double *)(y0 + incy * 3), xv[3]); + + // Increment the pointer + x0 += 4 * incx; + y0 += 4 * incy; + } + + for ( ; i < (n & (~0x01)); i += 2) + { + // Loading the input values + xv[0] = _mm_loadu_pd((double *)(x0 + 0 * incx)); + xv[1] = _mm_loadu_pd((double *)(x0 + 1 * incx)); + + // Storing the values to desti-nation + _mm_storeu_pd((double *)(y0 + incy * 0), xv[0]); + _mm_storeu_pd((double *)(y0 + incy * 1), xv[1]); + + // Increment the pointer + x0 += 2 * incx; + y0 += 2 * incy; + } + + for ( ; i < n; i += 1) + { + // Loading the input values + xv[0] = _mm_loadu_pd((double *)(x0 + 0 * incx)); + + // Storing the values to destination + _mm_storeu_pd((double *)(y0 + incy * 0), xv[0]); + + // Increment the pointer + x0 += 1 * incx; + y0 += 1 * incy; + } + } + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2) +} diff --git a/kernels/zen4/1/bli_dotv_zen_int_avx512.c b/kernels/zen4/1/bli_dotv_zen_int_avx512.c index 4d9708e751..bb758a8ae7 100644 --- a/kernels/zen4/1/bli_dotv_zen_int_avx512.c +++ b/kernels/zen4/1/bli_dotv_zen_int_avx512.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2016 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2016 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -35,6 +35,9 @@ #include "immintrin.h" #include "blis.h" +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + /* Functionality ------------- @@ -393,3 +396,834 @@ void bli_ddotv_zen_int_avx512 // Copy the final result into the output variable. PASTEMAC(d, copys)(rho0, *rho); } + +/* + Functionality + ------------- + + This function calculates the dot product of two vectors for + type double complex. + + rho := conjx(x)^T * conjy(y) + + Function Signature + ------------------- + + * 'conjx' - Variable specified if x needs to be conjugated + * 'conjy' - Variable specified if x needs to be conjugated + * 'n' - Length of the array passed + * 'x' - Double pointer pointing to an array + * 'y' - Double pointer pointing to an array + * 'incx' - Stride to point to the next element in x array + * 'incy' - Stride to point to the next element in y array + * 'cntx' - BLIS context object + + Exception + ---------- + + None + + Deviation from BLAS + -------------------- + + None + + Undefined behaviour + ------------------- + + 1. The kernel results in undefined behaviour when n <= 0, incx <= 1 and incy <= 1. + The expectation is that these are standard BLAS exceptions and should be handled in + a higher layer +*/ +void bli_zdotv_zen_int_avx512 + ( + conj_t conjx, + conj_t conjy, + dim_t n, + dcomplex* restrict x, inc_t incx, + dcomplex* restrict y, inc_t incy, + dcomplex* restrict rho, + cntx_t* restrict cntx + ) +{ + // Initialize local pointers. + double* restrict x0 = (double*)x; + double* restrict y0 = (double*)y; + + dcomplex rho0 = *bli_z0; + + conj_t conjx_use = conjx; + if ( bli_is_conj( conjy ) ) + bli_toggle_conj( &conjx_use ); + + dim_t i = 0; + if ( incx == 1 && incy == 1 ) + { + const dim_t n_elem_per_reg = 8; + + __m512d xv[8]; + __m512d yv[8]; + __m512d rhov[16]; + + // Initialize rho accumulation vectors to 0. + // rhov[0] - rhov[7] store the real part of intermediate result. + // rhov[8] - rhov[15] store the imaginary part of intermediate result. + rhov[0] = _mm512_setzero_pd(); + rhov[1] = _mm512_setzero_pd(); + rhov[2] = _mm512_setzero_pd(); + rhov[3] = _mm512_setzero_pd(); + rhov[4] = _mm512_setzero_pd(); + rhov[5] = _mm512_setzero_pd(); + rhov[6] = _mm512_setzero_pd(); + rhov[7] = _mm512_setzero_pd(); + rhov[8] = _mm512_setzero_pd(); + rhov[9] = _mm512_setzero_pd(); + rhov[10] = _mm512_setzero_pd(); + rhov[11] = _mm512_setzero_pd(); + rhov[12] = _mm512_setzero_pd(); + rhov[13] = _mm512_setzero_pd(); + rhov[14] = _mm512_setzero_pd(); + rhov[15] = _mm512_setzero_pd(); + + /** + * General Algorithm: + * + * xv[0] = x0R x0I x1R x1I ... + * yv[0] = y0R y0I y1R y1I ... + * rhov[0] = xv[0] * yv[0] + rhov[0] + * = x0R*y0R x0I*y0I x1R*y1R x1I*y0I ... + * yv[0] = permute(0x55) + * = y0I y0R y1I y1R ... + * rhov[8] = xv[0] * yv[0] + rhov[8] + * = x0R*y0I x0I*y0R x1R*y1I x1I*y1R ... + */ + + // Processing 32 dcomplex elements per iteration. + for ( ; (i + 31) < n; i += 32 ) + { + // Load elements from x vector. + xv[0] = _mm512_loadu_pd( x0 + 0*n_elem_per_reg ); + xv[1] = _mm512_loadu_pd( x0 + 1*n_elem_per_reg ); + xv[2] = _mm512_loadu_pd( x0 + 2*n_elem_per_reg ); + xv[3] = _mm512_loadu_pd( x0 + 3*n_elem_per_reg ); + xv[4] = _mm512_loadu_pd( x0 + 4*n_elem_per_reg ); + xv[5] = _mm512_loadu_pd( x0 + 5*n_elem_per_reg ); + xv[6] = _mm512_loadu_pd( x0 + 6*n_elem_per_reg ); + xv[7] = _mm512_loadu_pd( x0 + 7*n_elem_per_reg ); + + // Load elements from y vector. + yv[0] = _mm512_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1] = _mm512_loadu_pd( y0 + 1*n_elem_per_reg ); + yv[2] = _mm512_loadu_pd( y0 + 2*n_elem_per_reg ); + yv[3] = _mm512_loadu_pd( y0 + 3*n_elem_per_reg ); + yv[4] = _mm512_loadu_pd( y0 + 4*n_elem_per_reg ); + yv[5] = _mm512_loadu_pd( y0 + 5*n_elem_per_reg ); + yv[6] = _mm512_loadu_pd( y0 + 6*n_elem_per_reg ); + yv[7] = _mm512_loadu_pd( y0 + 7*n_elem_per_reg ); + + // Operation: rhov = xv * yv + rhov + rhov[0] = _mm512_fmadd_pd( xv[0], yv[0], rhov[0] ); + rhov[1] = _mm512_fmadd_pd( xv[1], yv[1], rhov[1] ); + rhov[2] = _mm512_fmadd_pd( xv[2], yv[2], rhov[2] ); + rhov[3] = _mm512_fmadd_pd( xv[3], yv[3], rhov[3] ); + rhov[4] = _mm512_fmadd_pd( xv[4], yv[4], rhov[4] ); + rhov[5] = _mm512_fmadd_pd( xv[5], yv[5], rhov[5] ); + rhov[6] = _mm512_fmadd_pd( xv[6], yv[6], rhov[6] ); + rhov[7] = _mm512_fmadd_pd( xv[7], yv[7], rhov[7] ); + + // Operation: yv -> yv' + // yv = y0R y0I y1R y1I ... + // yv' = y0I y0R y1I y1R ... + yv[0] = _mm512_permute_pd( yv[0], 0x55 ); + yv[1] = _mm512_permute_pd( yv[1], 0x55 ); + yv[2] = _mm512_permute_pd( yv[2], 0x55 ); + yv[3] = _mm512_permute_pd( yv[3], 0x55 ); + yv[4] = _mm512_permute_pd( yv[4], 0x55 ); + yv[5] = _mm512_permute_pd( yv[5], 0x55 ); + yv[6] = _mm512_permute_pd( yv[6], 0x55 ); + yv[7] = _mm512_permute_pd( yv[7], 0x55 ); + + // Operation: rhov = xv * yv' + rhov + rhov[8] = _mm512_fmadd_pd( xv[0], yv[0], rhov[8] ); + rhov[9] = _mm512_fmadd_pd( xv[1], yv[1], rhov[9] ); + rhov[10] = _mm512_fmadd_pd( xv[2], yv[2], rhov[10] ); + rhov[11] = _mm512_fmadd_pd( xv[3], yv[3], rhov[11] ); + rhov[12] = _mm512_fmadd_pd( xv[4], yv[4], rhov[12] ); + rhov[13] = _mm512_fmadd_pd( xv[5], yv[5], rhov[13] ); + rhov[14] = _mm512_fmadd_pd( xv[6], yv[6], rhov[14] ); + rhov[15] = _mm512_fmadd_pd( xv[7], yv[7], rhov[15] ); + + // Increment x0 and y0 vector pointers. + x0 += 8 * n_elem_per_reg; + y0 += 8 * n_elem_per_reg; + } + + // Accumulating intermediate results to rhov[0] and rhov[8]. + rhov[0] = _mm512_add_pd( rhov[0], rhov[4] ); + rhov[0] = _mm512_add_pd( rhov[0], rhov[5] ); + rhov[0] = _mm512_add_pd( rhov[0], rhov[6] ); + rhov[0] = _mm512_add_pd( rhov[0], rhov[7] ); + + rhov[8] = _mm512_add_pd( rhov[8], rhov[12] ); + rhov[8] = _mm512_add_pd( rhov[8], rhov[13] ); + rhov[8] = _mm512_add_pd( rhov[8], rhov[14] ); + rhov[8] = _mm512_add_pd( rhov[8], rhov[15] ); + + // Processing 16 dcomplex elements per iteration. + for ( ; (i + 15) < n; i += 16 ) + { + xv[0] = _mm512_loadu_pd( x0 + 0*n_elem_per_reg ); + xv[1] = _mm512_loadu_pd( x0 + 1*n_elem_per_reg ); + xv[2] = _mm512_loadu_pd( x0 + 2*n_elem_per_reg ); + xv[3] = _mm512_loadu_pd( x0 + 3*n_elem_per_reg ); + + yv[0] = _mm512_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1] = _mm512_loadu_pd( y0 + 1*n_elem_per_reg ); + yv[2] = _mm512_loadu_pd( y0 + 2*n_elem_per_reg ); + yv[3] = _mm512_loadu_pd( y0 + 3*n_elem_per_reg ); + + rhov[0] = _mm512_fmadd_pd( xv[0], yv[0], rhov[0] ); + rhov[1] = _mm512_fmadd_pd( xv[1], yv[1], rhov[1] ); + rhov[2] = _mm512_fmadd_pd( xv[2], yv[2], rhov[2] ); + rhov[3] = _mm512_fmadd_pd( xv[3], yv[3], rhov[3] ); + + yv[0] = _mm512_permute_pd( yv[0], 0x55 ); + yv[1] = _mm512_permute_pd( yv[1], 0x55 ); + yv[2] = _mm512_permute_pd( yv[2], 0x55 ); + yv[3] = _mm512_permute_pd( yv[3], 0x55 ); + + rhov[8] = _mm512_fmadd_pd( xv[0], yv[0], rhov[8] ); + rhov[9] = _mm512_fmadd_pd( xv[1], yv[1], rhov[9] ); + rhov[10] = _mm512_fmadd_pd( xv[2], yv[2], rhov[10] ); + rhov[11] = _mm512_fmadd_pd( xv[3], yv[3], rhov[11] ); + + x0 += 4 * n_elem_per_reg; + y0 += 4 * n_elem_per_reg; + } + + rhov[0] = _mm512_add_pd( rhov[0], rhov[3] ); + rhov[8] = _mm512_add_pd( rhov[8], rhov[11] ); + + // Processing 12 dcomplex elements per iteration. + for ( ; (i + 11) < n; i += 12 ) + { + xv[0] = _mm512_loadu_pd( x0 + 0*n_elem_per_reg ); + xv[1] = _mm512_loadu_pd( x0 + 1*n_elem_per_reg ); + xv[2] = _mm512_loadu_pd( x0 + 2*n_elem_per_reg ); + + yv[0] = _mm512_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1] = _mm512_loadu_pd( y0 + 1*n_elem_per_reg ); + yv[2] = _mm512_loadu_pd( y0 + 2*n_elem_per_reg ); + + rhov[0] = _mm512_fmadd_pd( xv[0], yv[0], rhov[0] ); + rhov[1] = _mm512_fmadd_pd( xv[1], yv[1], rhov[1] ); + rhov[2] = _mm512_fmadd_pd( xv[2], yv[2], rhov[2] ); + + yv[0] = _mm512_permute_pd( yv[0], 0x55 ); + yv[1] = _mm512_permute_pd( yv[1], 0x55 ); + yv[2] = _mm512_permute_pd( yv[2], 0x55 ); + + rhov[8] = _mm512_fmadd_pd( xv[0], yv[0], rhov[8] ); + rhov[9] = _mm512_fmadd_pd( xv[1], yv[1], rhov[9] ); + rhov[10] = _mm512_fmadd_pd( xv[2], yv[2], rhov[10] ); + + x0 += 3 * n_elem_per_reg; + y0 += 3 * n_elem_per_reg; + } + + rhov[0] = _mm512_add_pd( rhov[0], rhov[2] ); + rhov[8] = _mm512_add_pd( rhov[8], rhov[10] ); + + // Processing 8 dcomplex elements per iteration. + for ( ; (i + 7) < n; i += 8 ) + { + xv[0] = _mm512_loadu_pd( x0 + 0*n_elem_per_reg ); + xv[1] = _mm512_loadu_pd( x0 + 1*n_elem_per_reg ); + + yv[0] = _mm512_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1] = _mm512_loadu_pd( y0 + 1*n_elem_per_reg ); + + rhov[0] = _mm512_fmadd_pd( xv[0], yv[0], rhov[0] ); + rhov[1] = _mm512_fmadd_pd( xv[1], yv[1], rhov[1] ); + + yv[0] = _mm512_permute_pd( yv[0], 0x55 ); + yv[1] = _mm512_permute_pd( yv[1], 0x55 ); + + rhov[8] = _mm512_fmadd_pd( xv[0], yv[0], rhov[8] ); + rhov[9] = _mm512_fmadd_pd( xv[1], yv[1], rhov[9] ); + + x0 += 2 * n_elem_per_reg; + y0 += 2 * n_elem_per_reg; + } + + rhov[0] = _mm512_add_pd( rhov[0], rhov[1] ); + rhov[8] = _mm512_add_pd( rhov[8], rhov[9] ); + + // Processing 4 dcomplex elements per iteration. + for ( ; (i + 3) < n; i += 4 ) + { + xv[0] = _mm512_loadu_pd( x0 + 0*n_elem_per_reg ); + + yv[0] = _mm512_loadu_pd( y0 + 0*n_elem_per_reg ); + + rhov[0] = _mm512_fmadd_pd( xv[0], yv[0], rhov[0] ); + + yv[0] = _mm512_permute_pd( yv[0], 0x55 ); + + rhov[8] = _mm512_fmadd_pd( xv[0], yv[0], rhov[8] ); + + x0 += 1 * n_elem_per_reg; + y0 += 1 * n_elem_per_reg; + } + + // Processing the remainder elements. + if( i < n ) + { + // Setting the mask bit based on remaining elements + // Since each dcomplex elements corresponds to 2 doubles + // we need to load and store 2*(m-i) elements. + __mmask8 mask = (1 << (2 * (n-i)) ) - 1; + + // Clearing the rhov[1] register for mask-load. + rhov[1] = _mm512_setzero_pd(); + + xv[0] = _mm512_mask_loadu_pd( rhov[1], mask, x0 ); + + yv[0] = _mm512_mask_loadu_pd( rhov[1], mask, y0 ); + + rhov[0] = _mm512_fmadd_pd( xv[0], yv[0], rhov[0] ); + + yv[0] = _mm512_permute_pd( yv[0], 0x55 ); + + rhov[8] = _mm512_fmadd_pd( xv[0], yv[0], rhov[8] ); + } + + // Initialize mask for reduce-add based on conjugate. + __m512d mask = _mm512_set_pd(-1, 1, -1, 1, -1, 1, -1, 1); + if ( bli_is_conj( conjx_use ) ) + { + rho0.real = _mm512_reduce_add_pd( rhov[0] ); + rhov[8] = _mm512_mul_pd( rhov[8], mask ); + rho0.imag = _mm512_reduce_add_pd( rhov[8] ); + } + else + { + rhov[0] = _mm512_mul_pd( rhov[0], mask ); + rho0.real = _mm512_reduce_add_pd( rhov[0] ); + rho0.imag = _mm512_reduce_add_pd( rhov[8] ); + } + } + else // Non-Unit Increments + { + if ( !bli_is_conj( conjx_use ) ) + { + for ( i = 0; i < n; ++i ) + { + const double x0c = *x0; + const double y0c = *y0; + + const double x1c = *( x0 + 1 ); + const double y1c = *( y0 + 1 ); + + rho0.real += x0c * y0c - x1c * y1c; + rho0.imag += x0c * y1c + x1c * y0c; + + x0 += incx * 2; + y0 += incy * 2; + } + } + else + { + for ( i = 0; i < n; ++i ) + { + const double x0c = *x0; + const double y0c = *y0; + + const double x1c = *( x0 + 1 ); + const double y1c = *( y0 + 1 ); + + rho0.real += x0c * y0c + x1c * y1c; + rho0.imag += x0c * y1c - x1c * y0c; + + x0 += incx * 2; + y0 += incy * 2; + } + } + } + + // Negate the sign of imaginary value when conjy is enabled. + if ( bli_is_conj( conjy ) ) + rho0.imag = -rho0.imag; + + // Copy the result to rho. + PASTEMAC(z,copys)( rho0, *rho ); +} + +/* + Functionality + ------------- + + This function calculates the dot product of two vectors for + type double complex. + + rho := conjx(x)^T * conjy(y) + + Function Signature + ------------------- + + * 'conjx' - Variable specified if x needs to be conjugated + * 'conjy' - Variable specified if x needs to be conjugated + * 'n' - Length of the array passed + * 'x' - Double pointer pointing to an array + * 'y' - Double pointer pointing to an array + * 'incx' - Stride to point to the next element in x array + * 'incy' - Stride to point to the next element in y array + * 'cntx' - BLIS context object + + Exception + ---------- + + None + + Deviation from BLAS + -------------------- + + None + + Undefined behaviour + ------------------- + + 1. The kernel results in undefined behaviour when n <= 0, incx <= 1 and incy <= 1. + The expectation is that these are standard BLAS exceptions and should be handled in + a higher layer +*/ +void bli_zdotv_zen4_asm_avx512 + ( + conj_t conjx, + conj_t conjy, + dim_t n, + dcomplex* restrict x, inc_t incx, + dcomplex* restrict y, inc_t incy, + dcomplex* restrict rho, + cntx_t* restrict cntx + ) +{ + // Initialize local pointers. + double* restrict x0 = (double*)x; + double* restrict y0 = (double*)y; + + dcomplex rho0 = *bli_z0; + double* restrict rho0R = &rho0.real; + double* restrict rho0I = &rho0.imag; + + // Using a local unit value for setting a unit register. + double one_l = 1.0; + double* restrict one = &one_l; + + conj_t conjx_use = conjx; + if ( bli_is_conj( conjy ) ) + bli_toggle_conj( &conjx_use ); + + // Copying conjx_use to a local conj variable for simple condition check + // within inline assembly. + dim_t conj = 0; + if ( bli_is_conj( conjx_use ) ) conj = 1; + + if ( incx == 1 && incy == 1 ) // Inline ASM used to handle unit-increment. + { + begin_asm() + + mov( var( n ), rsi ) // load n to rsi. + mov( var( x0 ), rax ) // load location of x vec to rax. + mov( var( y0 ), rbx ) // load location of y vec to rbx. + + // Initialize 16 registers (zmm0 - zmm15) to zero. + // These will be used for accumulation of rho. + // zmm0 - zmm7: real intermediate values of rho. + // zmm8 - zmm15: imaginary intermediate values of rho. + vxorpd( zmm0, zmm0, zmm0 ) + vxorpd( zmm1, zmm1, zmm1 ) + vxorpd( zmm2, zmm2, zmm2 ) + vxorpd( zmm3, zmm3, zmm3 ) + vxorpd( zmm4, zmm4, zmm4 ) + vxorpd( zmm5, zmm5, zmm5 ) + vxorpd( zmm6, zmm6, zmm6 ) + vxorpd( zmm7, zmm7, zmm7 ) + vxorpd( zmm8, zmm8, zmm8 ) + vxorpd( zmm9, zmm9, zmm9 ) + vxorpd( zmm10, zmm10, zmm10 ) + vxorpd( zmm11, zmm11, zmm11 ) + vxorpd( zmm12, zmm12, zmm12 ) + vxorpd( zmm13, zmm13, zmm13 ) + vxorpd( zmm14, zmm14, zmm14 ) + vxorpd( zmm15, zmm15, zmm15 ) + + + /** + * General Algorithm: + * + * zmm16 = x0R x0I x1R x1I ... + * zmm24 = y0R y0I y1R y1I ... + * zmm0 = zmm16 * zmm24 + zmm0 + * = x0R*y0R x0I*y0I x1R*y1R x1I*y0I ... + * zmm24 = permute(0x55) + * = y0I y0R y1I y1R ... + * zmm8 = zmm16 * zmm24 + zmm8 + * = x0R*y0I x0I*y0R x1R*y1I x1I*y1R ... + */ + + + // Each iteration of L32 handles 32 elements. + // Each zmm register can handle 8 doubles, i.e., 4 dcomplex elements. + // Thus, using 8 registers each for x and y vectors we handle 32 + // elements in every iteration of the loop. + label( .L32 ) + cmp( imm(32), rsi ) + jl( .ACCUM32 ) + + // Alternate loads from x & y. + vmovupd( ( rax ), zmm16 ) // load from x + vmovupd( ( rbx ), zmm24 ) // load from y + vmovupd( 0x40( rax ), zmm17 ) + vmovupd( 0x40( rbx ), zmm25 ) + vmovupd( 0x80( rax ), zmm18 ) + vmovupd( 0x80( rbx ), zmm26 ) + vmovupd( 0xC0( rax ), zmm19 ) + vmovupd( 0xC0( rbx ), zmm27 ) + vmovupd( 0x100( rax ), zmm20 ) + vmovupd( 0x100( rbx ), zmm28 ) + vmovupd( 0x140( rax ), zmm21 ) + vmovupd( 0x140( rbx ), zmm29 ) + vmovupd( 0x180( rax ), zmm22 ) + vmovupd( 0x180( rbx ), zmm30 ) + vmovupd( 0x1C0( rax ), zmm23 ) + vmovupd( 0x1C0( rbx ), zmm31 ) + + // Increment x0 and y0 vector pointers. + add( imm(512), rax ) + add( imm(512), rbx ) + + // Operation: rhov = xv * yv + rhov + vfmadd231pd( zmm16, zmm24, zmm0 ) + vfmadd231pd( zmm17, zmm25, zmm1 ) + vfmadd231pd( zmm18, zmm26, zmm2 ) + vfmadd231pd( zmm19, zmm27, zmm3 ) + vfmadd231pd( zmm20, zmm28, zmm4 ) + vfmadd231pd( zmm21, zmm29, zmm5 ) + vfmadd231pd( zmm22, zmm30, zmm6 ) + vfmadd231pd( zmm23, zmm31, zmm7 ) + + // Operation: yv -> yv' + // yv = y0R y0I y1R y1I ... + // yv' = y0I y0R y1I y1R ... + vpermilpd( imm(0x55), zmm24, zmm24 ) + vpermilpd( imm(0x55), zmm25, zmm25 ) + vpermilpd( imm(0x55), zmm26, zmm26 ) + vpermilpd( imm(0x55), zmm27, zmm27 ) + vpermilpd( imm(0x55), zmm28, zmm28 ) + vpermilpd( imm(0x55), zmm29, zmm29 ) + vpermilpd( imm(0x55), zmm30, zmm30 ) + vpermilpd( imm(0x55), zmm31, zmm31 ) + + // Operation: rhov = xv * yv' + rhov + vfmadd231pd( zmm16, zmm24, zmm8 ) + vfmadd231pd( zmm17, zmm25, zmm9 ) + vfmadd231pd( zmm18, zmm26, zmm10 ) + vfmadd231pd( zmm19, zmm27, zmm11 ) + vfmadd231pd( zmm20, zmm28, zmm12 ) + vfmadd231pd( zmm21, zmm29, zmm13 ) + vfmadd231pd( zmm22, zmm30, zmm14 ) + vfmadd231pd( zmm23, zmm31, zmm15 ) + + // Loop decrement. + sub( imm(32), rsi ) + jmp( .L32 ) + + + // Accumulating intermediate results to zmm0 and zmm8. + label( .ACCUM32 ) + vaddpd( zmm4, zmm0, zmm0 ) + vaddpd( zmm5, zmm0, zmm0 ) + vaddpd( zmm6, zmm0, zmm0 ) + vaddpd( zmm7, zmm0, zmm0 ) + + vaddpd( zmm12, zmm8, zmm8 ) + vaddpd( zmm13, zmm8, zmm8 ) + vaddpd( zmm14, zmm8, zmm8 ) + vaddpd( zmm15, zmm8, zmm8 ) + + // Each iteration of L16 handles 16 elements. + label( .L16 ) + cmp( imm(16), rsi ) + jl( .ACCUM16 ) + + // Alternate loads from x & y. + vmovupd( ( rax ), zmm16 ) // load from x + vmovupd( ( rbx ), zmm24 ) // load from y + vmovupd( 0x40( rax ), zmm17 ) + vmovupd( 0x40( rbx ), zmm25 ) + vmovupd( 0x80( rax ), zmm18 ) + vmovupd( 0x80( rbx ), zmm26 ) + vmovupd( 0xC0( rax ), zmm19 ) + vmovupd( 0xC0( rbx ), zmm27 ) + + // Increment x0 and y0 vector pointers. + add( imm(256), rax ) + add( imm(256), rbx ) + + // Operation: rhov = xv * yv + rhov + vfmadd231pd( zmm16, zmm24, zmm0 ) + vfmadd231pd( zmm17, zmm25, zmm1 ) + vfmadd231pd( zmm18, zmm26, zmm2 ) + vfmadd231pd( zmm19, zmm27, zmm3 ) + + // Operation: yv -> yv' + // yv = y0R y0I y1R y1I ... + // yv' = y0I y0R y1I y1R ... + vpermilpd( imm(0x55), zmm24, zmm24 ) + vpermilpd( imm(0x55), zmm25, zmm25 ) + vpermilpd( imm(0x55), zmm26, zmm26 ) + vpermilpd( imm(0x55), zmm27, zmm27 ) + + // Operation: rhov = xv * yv' + rhov + vfmadd231pd( zmm16, zmm24, zmm8 ) + vfmadd231pd( zmm17, zmm25, zmm9 ) + vfmadd231pd( zmm18, zmm26, zmm10 ) + vfmadd231pd( zmm19, zmm27, zmm11 ) + + // Loop decrement. + sub( imm(16), rsi ) + jmp( .L16 ) + + + // Accumulating intermediate results to zmm0 and zmm8. + label( .ACCUM16 ) + vaddpd( zmm2, zmm0, zmm0 ) + vaddpd( zmm3, zmm0, zmm0 ) + + vaddpd( zmm10, zmm8, zmm8 ) + vaddpd( zmm11, zmm8, zmm8 ) + + // Each iteration of L8 handles 8 elements. + label( .L8 ) + cmp( imm(8), rsi ) + jl( .ACCUM8 ) + + // Alternate loads from x & y. + vmovupd( ( rax ), zmm16 ) // load from x + vmovupd( ( rbx ), zmm24 ) // load from y + vmovupd( 0x40 ( rax ), zmm17 ) + vmovupd( 0x40 ( rbx ), zmm25 ) + + // Increment x0 and y0 vector pointers. + add( imm(128), rax ) + add( imm(128), rbx ) + + // Operation: rhov = xv * yv + rhov + vfmadd231pd( zmm16, zmm24, zmm0 ) + vfmadd231pd( zmm17, zmm25, zmm1 ) + + // Operation: yv -> yv' + // yv = y0R y0I y1R y1I ... + // yv' = y0I y0R y1I y1R ... + vpermilpd( imm(0x55), zmm24, zmm24 ) + vpermilpd( imm(0x55), zmm25, zmm25 ) + + // Operation: rhov = xv * yv' + rhov + vfmadd231pd( zmm16, zmm24, zmm8 ) + vfmadd231pd( zmm17, zmm25, zmm9 ) + + // Loop decrement. + sub( imm(8), rsi ) + jmp( .L8 ) + + + // Accumulating intermediate results to zmm0 and zmm8. + label( .ACCUM8 ) + vaddpd( zmm1, zmm0, zmm0 ) + vaddpd( zmm9, zmm8, zmm8 ) + + + // Each iteration of L4 handles 4 elements. + label( .L4 ) + cmp( imm(4), rsi ) + jl( .FRINGE ) + + // Alternate loads from x & y. + vmovupd( ( rax ), zmm16 ) // load from x + vmovupd( ( rbx ), zmm24 ) // load from y + + // Increment x0 and y0 vector pointers. + add( imm(64), rax ) + add( imm(64), rbx ) + + // Operation: rhov = xv * yv + rhov + vfmadd231pd( zmm16, zmm24, zmm0 ) + + // Operation: yv -> yv' + // yv = y0R y0I y1R y1I ... + // yv' = y0I y0R y1I y1R ... + vpermilpd( imm(0x55), zmm24, zmm24 ) + + // Operation: rhov = xv * yv' + rhov + vfmadd231pd( zmm16, zmm24, zmm8 ) + + // Loop decrement. + sub( imm(4), rsi ) + jmp( .L4 ) + + + // Fringe case to process the remainder elements. + LABEL( .FRINGE ) + cmp( imm(0x0), rsi ) + je( .CONJ ) + + vxorpd( zmm16, zmm16, zmm16 ) + vxorpd( zmm24, zmm24, zmm24 ) + mov( imm(255), ecx ) + shlx( esi, ecx, ecx ) + shlx( esi, ecx, ecx ) + xor( imm(255), ecx ) + kmovw( ecx, K(1) ) + + vmovupd( mem(rax), zmm16 MASK_(K(1)) ) + + vmovupd( mem(rbx), zmm24 MASK_(K(1)) ) + + vfmadd231pd( zmm16, zmm24, zmm0 ) + + vpermilpd( imm(0x55), zmm24, zmm24 ) + + vfmadd231pd( zmm16, zmm24, zmm8 ) + + + // Handling conjugates. + LABEL( .CONJ ) + // set zmm1 to all zeros + vxorpd( xmm1, xmm1, xmm1 ) + // broadcast one (1) to zmm2 + mov( var(one), rax ) + vbroadcastsd( (rax), zmm2 ) + vfmsubadd231pd( zmm1, zmm2, zmm2 ) + + // load rho0R and rho0I into memory. + mov( var(rho0R), rax ) + mov( var(rho0I), rbx ) + + mov( var(conj), rcx) + cmp( imm(0x0), rcx ) + je( .NOCONJX) + + // if conjx_use + label( .CONJX ) + vextractf64x4( imm(0x1), zmm0, ymm2 ) + vaddpd( ymm0, ymm2, ymm0 ) + vextractf128( imm(0x1), ymm0, xmm2 ) + vaddpd( xmm2, xmm0, xmm0 ) + vshufpd( imm(0x1), xmm0, xmm0, xmm2 ) + vaddpd( xmm2, xmm0, xmm0 ) + vmovupd( xmm0, (rax) ) // store result to rho0R + + vmulpd( zmm1, zmm8, zmm8 ) + vextractf64x4( imm(0x1), zmm8, ymm2 ) + vaddpd( ymm8, ymm2, ymm8 ) + vextractf128( imm(0x1), ymm8, xmm2 ) + vaddpd( xmm2, xmm8, xmm8 ) + vshufpd( imm(0x1), xmm8, xmm8, xmm2 ) + vaddpd( xmm2, xmm8, xmm8 ) + vmovupd( xmm8, (rbx) ) // store result to rho0I + jmp( .END ) + + // if !conjx_use + label( .NOCONJX ) + vmulpd( zmm2, zmm0, zmm0 ) + vextractf64x4( imm(0x1), zmm0, ymm2 ) + vaddpd( ymm0, ymm2, ymm0 ) + vextractf128( imm(0x1), ymm0, xmm2 ) + vaddpd( xmm2, xmm0, xmm0 ) + vshufpd( imm(0x1), xmm0, xmm0, xmm2 ) + vaddpd( xmm2, xmm0, xmm0 ) + vmovupd( xmm0, (rax) ) // store result to rho0R + + vextractf64x4( imm(0x1), zmm8, ymm2 ) + vaddpd( ymm8, ymm2, ymm8 ) + vextractf128( imm(0x1), ymm8, xmm2 ) + vaddpd( xmm2, xmm8, xmm8 ) + vshufpd( imm(0x1), xmm8, xmm8, xmm2 ) + vaddpd( xmm2, xmm8, xmm8 ) + vmovupd( xmm8, (rbx) ) // store result to rho0I + + label( .END ) + + end_asm( + : // output operands (none) + : // input operands + [n] "m" (n), + [x0] "m" (x0), + [y0] "m" (y0), + [rho0R] "m" (rho0R), + [rho0I] "m" (rho0I), + [one] "m" (one), + [conj] "m" (conj) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", "xmm12", + "zmm0", "zmm1", "zmm2", "zmm3", + "zmm4", "zmm5", "zmm6", "zmm7", "zmm8", "zmm9", "zmm10", + "zmm11", "zmm12", "zmm13", "zmm14", "zmm15", + "zmm16", "zmm17", "zmm18", "zmm19", + "zmm20", "zmm21", "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", + "zmm27", "zmm28", "zmm29", "zmm30", "zmm31", + "k1", "xmm8", "ymm0", "ymm2", "ymm8", "memory" + ) + + rho0.real = *rho0R; + rho0.imag = *rho0I; + } + else // Non-Unit Increments + { + dim_t i = 0; + if ( !bli_is_conj( conjx_use ) ) + { + for ( i = 0; i < n; ++i ) + { + const double x0c = *x0; + const double y0c = *y0; + + const double x1c = *( x0 + 1 ); + const double y1c = *( y0 + 1 ); + + rho0.real += x0c * y0c - x1c * y1c; + rho0.imag += x0c * y1c + x1c * y0c; + + x0 += incx * 2; + y0 += incy * 2; + } + } + else + { + for ( i = 0; i < n; ++i ) + { + const double x0c = *x0; + const double y0c = *y0; + + const double x1c = *( x0 + 1 ); + const double y1c = *( y0 + 1 ); + + rho0.real += x0c * y0c + x1c * y1c; + rho0.imag += x0c * y1c - x1c * y0c; + + x0 += incx * 2; + y0 += incy * 2; + } + } + } + + // Negate the sign of imaginary value when conjy is enabled. + if ( bli_is_conj( conjy ) ) + rho0.imag = -rho0.imag; + + // Copy the result to rho. + PASTEMAC(z,copys)( rho0, *rho ); +} diff --git a/kernels/zen4/1/bli_dotxv_zen_int_avx512.c b/kernels/zen4/1/bli_dotxv_zen_int_avx512.c new file mode 100644 index 0000000000..01ef9dec02 --- /dev/null +++ b/kernels/zen4/1/bli_dotxv_zen_int_avx512.c @@ -0,0 +1,382 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "immintrin.h" +#include "blis.h" + +/* Union data structure to access AVX-512 registers +* One 512-bit AVX register holds 8 DP elements. */ +typedef union +{ + __m512d v; + double d[8] __attribute__((aligned(64))); +} v8df_t; + +/* Union data structure to access AVX registers +* One 256-bit AVX register holds 4 DP elements. */ +typedef union +{ + __m256d v; + double d[4] __attribute__((aligned(64))); +} v4df_t; + +/* Union data structure to access AVX registers +* One 128-bit AVX register holds 2 DP elements. */ +typedef union +{ + __m128d v; + double d[2] __attribute__((aligned(64))); +} v2df_t; + +// ----------------------------------------------------------------------------- + +void bli_zdotxv_zen_int_avx512 + ( + conj_t conjx, + conj_t conjy, + dim_t n, + dcomplex* restrict alpha, + dcomplex* restrict x, inc_t incx, + dcomplex* restrict y, inc_t incy, + dcomplex* restrict beta, + dcomplex* restrict rho, + cntx_t* restrict cntx + ) +{ + dim_t i = 0; + + dcomplex* restrict x0; + dcomplex* restrict y0; + dcomplex rho0; + + // Performing XOR of conjx and conjy. + // conj_op is set if either X or Y has conjugate(not both) + conj_t conj_op = conjx ^ conjy; + + // If beta is zero, initialize rho to zero instead of scaling + // rho by beta (in case rho contains NaN or Inf). + if ( PASTEMAC(z,eq0)( *beta ) ) + { + PASTEMAC(z,set0s)( *rho ); + } + else + { + PASTEMAC(z,scals)( *beta, *rho ); + } + + // If the vector dimension is zero, output rho and return early. + if ( bli_zero_dim1( n ) || PASTEMAC(z,eq0)( *alpha ) ) return; + + // Initialize local pointers. + x0 = x; + y0 = y; + + // Computation to handle unit-stride cases + if ( incx == 1 && incy == 1 ) + { + dim_t n_elem_per_reg = 4; + + // Declaring 8 registers, to store partial sums over multiple loads + // Further declaring 4 registers for loading X and 8 for loading + // and permuting Y for complex datatype arithmetic. + v8df_t rhov[8], xv[4], yv[8]; + + // Initialize the unrolled iterations' rho vectors to zero. + rhov[0].v = _mm512_setzero_pd(); + rhov[1].v = _mm512_setzero_pd(); + rhov[2].v = _mm512_setzero_pd(); + rhov[3].v = _mm512_setzero_pd(); + + rhov[4].v = _mm512_setzero_pd(); + rhov[5].v = _mm512_setzero_pd(); + rhov[6].v = _mm512_setzero_pd(); + rhov[7].v = _mm512_setzero_pd(); + + // Setting 2 vectors to 0 and 1 for the compute. + v8df_t zero_reg, scale_one; + zero_reg.v = _mm512_setzero_pd(); + scale_one.v = _mm512_set1_pd(1.0); + + // Checking to see if we should take the unmasked vector code + if( n >= 4 ) + { + for (; ( i + 15 ) < n; i += 16 ) + { + // Load elements from X and Y + xv[0].v = _mm512_loadu_pd((double *) (x0 + 0*n_elem_per_reg) ); + yv[0].v = _mm512_loadu_pd((double *) (y0 + 0*n_elem_per_reg) ); + + xv[1].v = _mm512_loadu_pd((double *) (x0 + 1*n_elem_per_reg) ); + yv[1].v = _mm512_loadu_pd((double *) (y0 + 1*n_elem_per_reg) ); + + xv[2].v = _mm512_loadu_pd((double *) (x0 + 2*n_elem_per_reg) ); + yv[2].v = _mm512_loadu_pd((double *) (y0 + 2*n_elem_per_reg) ); + + xv[3].v = _mm512_loadu_pd((double *) (x0 + 3*n_elem_per_reg) ); + yv[3].v = _mm512_loadu_pd((double *) (y0 + 3*n_elem_per_reg) ); + + // Permute to duplicate the imag part for every element + // yv[4].v = I0 I0 I1 I1 ... + yv[4].v = _mm512_permute_pd( yv[0].v, 0xFF ); + yv[5].v = _mm512_permute_pd( yv[1].v, 0xFF ); + yv[6].v = _mm512_permute_pd( yv[2].v, 0xFF ); + yv[7].v = _mm512_permute_pd( yv[3].v, 0xFF ); + + // Permute to duplicate the real part for every element + // yv[0].v = R0 R0 R1 R1 ... + yv[0].v = _mm512_permute_pd( yv[0].v, 0x00 ); + yv[1].v = _mm512_permute_pd( yv[1].v, 0x00 ); + yv[2].v = _mm512_permute_pd( yv[2].v, 0x00 ); + yv[3].v = _mm512_permute_pd( yv[3].v, 0x00 ); + + // Compute the element-wise product of the X and Y vectors, + // storing in the corresponding rho vectors. + rhov[0].v = _mm512_fmadd_pd( xv[0].v, yv[0].v, rhov[0].v ); + rhov[1].v = _mm512_fmadd_pd( xv[1].v, yv[1].v, rhov[1].v ); + rhov[2].v = _mm512_fmadd_pd( xv[2].v, yv[2].v, rhov[2].v ); + rhov[3].v = _mm512_fmadd_pd( xv[3].v, yv[3].v, rhov[3].v ); + + rhov[4].v = _mm512_fmadd_pd( xv[0].v, yv[4].v, rhov[4].v ); + rhov[5].v = _mm512_fmadd_pd( xv[1].v, yv[5].v, rhov[5].v ); + rhov[6].v = _mm512_fmadd_pd( xv[2].v, yv[6].v, rhov[6].v ); + rhov[7].v = _mm512_fmadd_pd( xv[3].v, yv[7].v, rhov[7].v ); + + // Adjust the pointers accordingly + x0 += ( n_elem_per_reg * 4 ); + y0 += ( n_elem_per_reg * 4 ); + } + for (; ( i + 7 ) < n; i += 8 ) + { + // Load elements from X and Y + xv[0].v = _mm512_loadu_pd((double *) (x0 + 0*n_elem_per_reg) ); + yv[0].v = _mm512_loadu_pd((double *) (y0 + 0*n_elem_per_reg) ); + + xv[1].v = _mm512_loadu_pd((double *) (x0 + 1*n_elem_per_reg) ); + yv[1].v = _mm512_loadu_pd((double *) (y0 + 1*n_elem_per_reg) ); + + // Permute to duplicate the imag part for every element + // yv[4].v = I0 I0 I1 I1 ... + yv[4].v = _mm512_permute_pd( yv[0].v, 0xFF ); + yv[5].v = _mm512_permute_pd( yv[1].v, 0xFF ); + + // Permute to duplicate the real part for every element + // yv[0].v = R0 R0 R1 R1 ... + yv[0].v = _mm512_permute_pd( yv[0].v, 0x00 ); + yv[1].v = _mm512_permute_pd( yv[1].v, 0x00 ); + + // Compute the element-wise product of the X and Y vectors, + // storing in the corresponding rho vectors. + rhov[0].v = _mm512_fmadd_pd( xv[0].v, yv[0].v, rhov[0].v ); + rhov[1].v = _mm512_fmadd_pd( xv[1].v, yv[1].v, rhov[1].v ); + + rhov[4].v = _mm512_fmadd_pd( xv[0].v, yv[4].v, rhov[4].v ); + rhov[5].v = _mm512_fmadd_pd( xv[1].v, yv[5].v, rhov[5].v ); + + // Adjust the pointers accordingly + x0 += ( n_elem_per_reg * 2 ); + y0 += ( n_elem_per_reg * 2 ); + } + for (; ( i + 3 ) < n; i += 4 ) + { + // Load elements from X and Y + xv[0].v = _mm512_loadu_pd((double *) (x0 + 0*n_elem_per_reg) ); + yv[0].v = _mm512_loadu_pd((double *) (y0 + 0*n_elem_per_reg) ); + + // Permute to duplicate the imag part for every element + // yv[4].v = I0 I0 I1 I1 ... + yv[4].v = _mm512_permute_pd( yv[0].v, 0xFF ); + + // Permute to duplicate the real part for every element + // yv[0].v = R0 R0 R1 R1 ... + yv[0].v = _mm512_permute_pd( yv[0].v, 0x00 ); + + // Compute the element-wise product of the X and Y vectors, + // storing in the corresponding rho vectors. + rhov[0].v = _mm512_fmadd_pd( xv[0].v, yv[0].v, rhov[0].v ); + + rhov[4].v = _mm512_fmadd_pd( xv[0].v, yv[4].v, rhov[4].v ); + + x0 += ( n_elem_per_reg * 1 ); + y0 += ( n_elem_per_reg * 1 ); + } + } + if ( i < n ) + { + // Setting the mask bit based on remaining elements + // Since each dcomplex elements corresponds to 2 doubles + // we need to load and store 2*(n-i) elements. + __mmask8 n_mask = (1 << 2*(n - i)) - 1; + + // Load elements from X and Y + xv[0].v = _mm512_maskz_loadu_pd(n_mask, (double *)x0 ); + yv[0].v = _mm512_maskz_loadu_pd(n_mask, (double *)y0 ); + + // Permute to duplicate the imag part for every element + // yv[4].v = I0 I0 I1 I1 ... + yv[4].v = _mm512_permute_pd( yv[0].v, 0xFF ); + + // Permute to duplicate the real part for every element + // yv[0].v = R0 R0 R1 R1 ... + yv[0].v = _mm512_permute_pd( yv[0].v, 0x00 ); + + // Compute the element-wise product of the X and Y vectors, + // storing in the corresponding rho vectors. + rhov[0].v = _mm512_fmadd_pd( xv[0].v, yv[0].v, rhov[0].v ); + + rhov[4].v = _mm512_fmadd_pd( xv[0].v, yv[4].v, rhov[4].v ); + } + + // Permuting for final accumulation of real and imag parts + rhov[4].v = _mm512_permute_pd(rhov[4].v, 0x55); + rhov[5].v = _mm512_permute_pd(rhov[5].v, 0x55); + rhov[6].v = _mm512_permute_pd(rhov[6].v, 0x55); + rhov[7].v = _mm512_permute_pd(rhov[7].v, 0x55); + + // Accumulate the unrolled rho vectors into a single vector + // rhov[0] contains element by element real-part scaling + // rhov[4] contains element by element imag-part scaling + rhov[0].v = _mm512_add_pd(rhov[1].v, rhov[0].v); + rhov[2].v = _mm512_add_pd(rhov[3].v, rhov[2].v); + rhov[0].v = _mm512_add_pd(rhov[2].v, rhov[0].v); + + rhov[4].v = _mm512_add_pd(rhov[5].v, rhov[4].v); + rhov[6].v = _mm512_add_pd(rhov[7].v, rhov[6].v); + rhov[4].v = _mm512_add_pd(rhov[6].v, rhov[4].v); + + /* + conj_op maps to the compute as follows : + A = (a + ib), X = (x + iy) + ----------------------------------------------------------- + | A | X | Real part | Imag Part | + ----------------------------------------------------------- + | No-Conjugate | No-Conjugate | ax - by | bx + ay | + | No-Conjugate | Conjugate | ax + by | bx - ay | + | Conjugate | No-Conjugate | ax + by | -(bx - ay) | + | Conjugate | Conjugate | ax - by | -(bx + ay) | + ----------------------------------------------------------- + + If only X or A has conjugate, fmsubadd is performed. + Else, fmaddsub is performed. + + In the final reduction step, the imaginary part of every + partial sum is negated if conjat is true + */ + + if ( bli_is_noconj( conj_op ) ) + { + rhov[0].v = _mm512_fmaddsub_pd(scale_one.v, rhov[0].v, rhov[4].v); + } + else + { + rhov[0].v = _mm512_fmsubadd_pd(scale_one.v, rhov[0].v, rhov[4].v); + } + + // Negate the imaginary part if conjy is congutgate + if ( bli_is_conj( conjx ) ) + { + rhov[0].v = _mm512_fmsubadd_pd(zero_reg.v, zero_reg.v, rhov[0].v); + } + + // Intermediate registers for final reduction + v4df_t inter[2]; + + inter[0].v = _mm512_extractf64x4_pd(rhov[0].v, 0x00); + inter[1].v = _mm512_extractf64x4_pd(rhov[0].v, 0x01); + + inter[0].v = _mm256_add_pd(inter[1].v, inter[0].v); + + // Accumulate the final rho vector into a single scalar result. + rho0.real = inter[0].d[0] + inter[0].d[2]; + rho0.imag = inter[0].d[1] + inter[0].d[3]; + + } + else + { + v2df_t rhov[2], xv, yv[2]; + + rhov[0].v = _mm_setzero_pd(); + rhov[1].v = _mm_setzero_pd(); + + for(; i < n; i += 1) + { + // Load elements from X and Y + xv.v = _mm_loadu_pd((double *)x0 ); + yv[0].v = _mm_loadu_pd((double *)y0 ); + + // Permute to duplicate the imag part for every element + // yv[1].v = I0 I0 + yv[1].v = _mm_permute_pd( yv[0].v, 0b11 ); + + // Permute to duplicate the real part for every element + // yv[0].v = R0 R0 + yv[0].v = _mm_permute_pd( yv[0].v, 0b00 ); + + // Compute the element-wise product of the X and Y vectors, + // storing in the corresponding rho vectors. + rhov[0].v = _mm_fmadd_pd( xv.v, yv[0].v, rhov[0].v ); + + rhov[1].v = _mm_fmadd_pd( xv.v, yv[1].v, rhov[1].v ); + + x0 += incx; + y0 += incy; + } + + // Permute for final reduction + rhov[1].v = _mm_permute_pd(rhov[1].v, 0x01); + + v2df_t zero_reg, scale_one; + + zero_reg.v = _mm_setzero_pd(); + scale_one.v = _mm_set1_pd(1.0); + + if ( bli_is_noconj( conj_op ) ) + { + rhov[0].v = _mm_addsub_pd(rhov[0].v, rhov[1].v); + } + else + { + rhov[0].v = _mm_fmsubadd_pd(scale_one.v, rhov[0].v, rhov[1].v); + } + if( bli_is_conj( conjx ) ) + { + rhov[0].v = _mm_fmsubadd_pd(zero_reg.v, rhov[0].v, rhov[0].v); + } + + rho0.real = rhov[0].d[0]; + rho0.imag = rhov[0].d[1]; + } + + // Accumulate the final result into the output variable. + PASTEMAC(z,axpys)( *alpha, rho0, *rho ); +} diff --git a/kernels/zen4/1/bli_norm2_zen_int_avx512.c b/kernels/zen4/1/bli_norm2_zen_int_avx512.c new file mode 100644 index 0000000000..14d0b72d77 --- /dev/null +++ b/kernels/zen4/1/bli_norm2_zen_int_avx512.c @@ -0,0 +1,761 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ +#include "immintrin.h" +#include "blis.h" + +// Union data structure to access AVX registers +// One 512-bit AVX register holds 8 DP elements. +typedef union +{ + __m512d v; + double d[8] __attribute__( ( aligned( 64 ) ) ); +} v8df_t; + +/* + Optimized kernel that computes the Frobenius norm using AVX512 intrinsics. + The kernel takes in the following input parameters : + * n - Size of the vector + * x - Pointer to the vector's memory + * incx - Input stride of the vector + * norm - Pointer to the result's memory + * cntx - Context, set based on the configuration +*/ +void bli_dnorm2fv_unb_var1_avx512 + ( + dim_t n, + double* x, inc_t incx, + double* norm, + cntx_t* cntx + ) +{ + AOCL_DTL_TRACE_ENTRY( AOCL_DTL_LEVEL_TRACE_3 ); + + // Local variables and pointers used for the computation + double sumsq = 0; + + // Local pointer alias to the input vector + double *xt = x; + + // Compute the sum of squares on 3 accumulators to avoid overflow + // and underflow, depending on the vector element value. + // Accumulator for small values; using scaling to avoid underflow. + double sum_sml = 0; + // Accumulator for medium values; no scaling required. + double sum_med = 0; + // Accumulator for big values; using scaling to avoid overflow. + double sum_big = 0; + + // Constants chosen to minimize roundoff, according to Blue's algorithm. + const double thresh_sml = pow( ( double )FLT_RADIX, ceil( ( DBL_MIN_EXP - 1 ) * 0.5 ) ); + const double thresh_big = pow( ( double )FLT_RADIX, floor( ( DBL_MAX_EXP - 52) * 0.5 ) ); + const double scale_sml = pow( ( double )FLT_RADIX, - floor( ( DBL_MIN_EXP - 53 ) * 0.5 ) ); + const double scale_big = pow( ( double )FLT_RADIX, - ceil( ( DBL_MAX_EXP + 52 ) * 0.5 ) ); + + // Scaling factor to be set and used in the final accumulation + double scale; + + // Boolean to check if any value > thresh_big has been encountered + bool isbig = false; + + // Iterator + dim_t i = 0; + + // In case of unit-strided input + if( incx == 1 ) + { + // AVX-512 code-section + // Declaring registers for loading, accumulation, thresholds and scale factors + v8df_t x_vec[4], sum_sml_vec[4], sum_med_vec[4], sum_big_vec[4], temp[4]; + v8df_t thresh_sml_vec, thresh_big_vec, scale_sml_vec, scale_big_vec; + v8df_t zero_reg; + + // Masks to be used in computation + __mmask8 k_mask[8]; + + // Containers to hold the results of operations on mask registers + // Bitwise operations on 8-bit mask registers would return an + // unsigned char as its result(0 or 1) + unsigned char truth_val[4]; + + // Setting the thresholds and scaling factors + thresh_sml_vec.v = _mm512_set1_pd( thresh_sml ); + thresh_big_vec.v = _mm512_set1_pd( thresh_big ); + scale_sml_vec.v = _mm512_set1_pd( scale_sml ); + scale_big_vec.v = _mm512_set1_pd( scale_big ); + + // Resetting the accumulators + sum_sml_vec[0].v = _mm512_setzero_pd(); + sum_sml_vec[1].v = _mm512_setzero_pd(); + sum_sml_vec[2].v = _mm512_setzero_pd(); + sum_sml_vec[3].v = _mm512_setzero_pd(); + + sum_med_vec[0].v = _mm512_setzero_pd(); + sum_med_vec[1].v = _mm512_setzero_pd(); + sum_med_vec[2].v = _mm512_setzero_pd(); + sum_med_vec[3].v = _mm512_setzero_pd(); + + sum_big_vec[0].v = _mm512_setzero_pd(); + sum_big_vec[1].v = _mm512_setzero_pd(); + sum_big_vec[2].v = _mm512_setzero_pd(); + sum_big_vec[3].v = _mm512_setzero_pd(); + + zero_reg.v = _mm512_setzero_pd(); + + // Computing in blocks of 32 + for ( ; ( i + 32 ) <= n; i = i + 32 ) + { + // Set temp[0..3] to zero + temp[0].v = _mm512_setzero_pd(); + temp[1].v = _mm512_setzero_pd(); + temp[2].v = _mm512_setzero_pd(); + temp[3].v = _mm512_setzero_pd(); + + // Loading the vectors + x_vec[0].v = _mm512_loadu_pd( xt ); + x_vec[1].v = _mm512_loadu_pd( xt + 8 ); + x_vec[2].v = _mm512_loadu_pd( xt + 16 ); + x_vec[3].v = _mm512_loadu_pd( xt + 24 ); + + // Comparing to check for NaN + // Bits in the mask are set if NaN is encountered + k_mask[0] = _mm512_cmp_pd_mask( x_vec[0].v, x_vec[0].v, _CMP_UNORD_Q ); + k_mask[1] = _mm512_cmp_pd_mask( x_vec[1].v, x_vec[1].v, _CMP_UNORD_Q ); + k_mask[2] = _mm512_cmp_pd_mask( x_vec[2].v, x_vec[2].v, _CMP_UNORD_Q ); + k_mask[3] = _mm512_cmp_pd_mask( x_vec[3].v, x_vec[3].v, _CMP_UNORD_Q ); + + // Checking if any bit in the masks are set + // The truth_val is set to 0 if any bit in the mask is 1 + // Thus, truth_val[0] = 0 if x_vec[0].v or x_vec[1].v has NaN + // truth_val[1] = 0 if x_vec[2].v or x_vec[3].v has NaN + truth_val[0] = _kortestz_mask8_u8( k_mask[0], k_mask[1] ); + truth_val[1] = _kortestz_mask8_u8( k_mask[2], k_mask[3] ); + + // Set norm to NaN and return early, if either truth_val[0] or truth_val[1] is set to 0 + if( !( truth_val[0] && truth_val[1] ) ) + { + *norm = NAN; + + AOCL_DTL_TRACE_EXIT( AOCL_DTL_LEVEL_TRACE_3 ); + return; + } + + // Getting the absoulte values of elements in the vectors + x_vec[0].v = _mm512_abs_pd( x_vec[0].v ); + x_vec[1].v = _mm512_abs_pd( x_vec[1].v ); + x_vec[2].v = _mm512_abs_pd( x_vec[2].v ); + x_vec[3].v = _mm512_abs_pd( x_vec[3].v ); + + // Setting the masks by comparing with thresh_sml_vec.v + // That is, k_mask[0][i] = 1 if x_vec[0].v[i] > thresh_sml_vec.v + // k_mask[1][i] = 1 if x_vec[1].v[i] > thresh_sml_vec.v + // k_mask[2][i] = 1 if x_vec[2].v[i] > thresh_sml_vec.v + // k_mask[3][i] = 1 if x_vec[3].v[i] > thresh_sml_vec.v + k_mask[0] = _mm512_cmp_pd_mask( x_vec[0].v, thresh_sml_vec.v, _CMP_GT_OS ); + k_mask[1] = _mm512_cmp_pd_mask( x_vec[1].v, thresh_sml_vec.v, _CMP_GT_OS ); + k_mask[2] = _mm512_cmp_pd_mask( x_vec[2].v, thresh_sml_vec.v, _CMP_GT_OS ); + k_mask[3] = _mm512_cmp_pd_mask( x_vec[3].v, thresh_sml_vec.v, _CMP_GT_OS ); + + // Setting the masks by comparing with thresh_big_vec.v + // That is, k_mask[4][i] = 1 if x_vec[0].v[i] < thresh_big_vec.v + // k_mask[5][i] = 1 if x_vec[1].v[i] < thresh_big_vec.v + // k_mask[6][i] = 1 if x_vec[2].v[i] < thresh_big_vec.v + // k_mask[7][i] = 1 if x_vec[3].v[i] < thresh_big_vec.v + k_mask[4] = _mm512_cmp_pd_mask( x_vec[0].v, thresh_big_vec.v, _CMP_LT_OS ); + k_mask[5] = _mm512_cmp_pd_mask( x_vec[1].v, thresh_big_vec.v, _CMP_LT_OS ); + k_mask[6] = _mm512_cmp_pd_mask( x_vec[2].v, thresh_big_vec.v, _CMP_LT_OS ); + k_mask[7] = _mm512_cmp_pd_mask( x_vec[3].v, thresh_big_vec.v, _CMP_LT_OS ); + + // Setting the masks to filter only the elements within the thresholds + // k_mask[0 ... 3] contain masks for elements > thresh_sml + // k_mask[4 ... 7] contain masks for elements < thresh_big + // Thus, AND operation on these would give elements within these thresholds + k_mask[4] = _kand_mask8( k_mask[0], k_mask[4] ); + k_mask[5] = _kand_mask8( k_mask[1], k_mask[5] ); + k_mask[6] = _kand_mask8( k_mask[2], k_mask[6] ); + k_mask[7] = _kand_mask8( k_mask[3], k_mask[7] ); + + // Setting booleans to check for underflow/overflow handling + // In case of having values outside threshold, the associated + // bit in k_mask[4 ... 7] is 0. + // Thus, truth_val[0] = 0 if x_vec[0].v has elements outside thresholds + // truth_val[1] = 0 if x_vec[1].v has elements outside thresholds + // truth_val[2] = 0 if x_vec[2].v has elements outside thresholds + // truth_val[3] = 0 if x_vec[3].v has elements outside thresholds + truth_val[0] = _kortestc_mask8_u8( k_mask[4], k_mask[4] ); + truth_val[1] = _kortestc_mask8_u8( k_mask[5], k_mask[5] ); + truth_val[2] = _kortestc_mask8_u8( k_mask[6], k_mask[6] ); + truth_val[3] = _kortestc_mask8_u8( k_mask[7], k_mask[7] ); + + // Computing using masked fmadds, that carries over values from + // accumulator register if the mask bit is 0 + sum_med_vec[0].v = _mm512_mask3_fmadd_pd( x_vec[0].v, x_vec[0].v, sum_med_vec[0].v, k_mask[4] ); + sum_med_vec[1].v = _mm512_mask3_fmadd_pd( x_vec[1].v, x_vec[1].v, sum_med_vec[1].v, k_mask[5] ); + sum_med_vec[2].v = _mm512_mask3_fmadd_pd( x_vec[2].v, x_vec[2].v, sum_med_vec[2].v, k_mask[6] ); + sum_med_vec[3].v = _mm512_mask3_fmadd_pd( x_vec[3].v, x_vec[3].v, sum_med_vec[3].v, k_mask[7] ); + + // In case of having elements outside the threshold + if( !( truth_val[0] && truth_val[1] && truth_val[2] && truth_val[3] ) ) + { + // Acquiring the masks for numbers greater than thresh_big + // k_mask[4 ... 7] contain masks for elements within the thresholds + // k_mask[0 ... 3] contain masks for elements > thresh_sml. This would + // include both elements < thresh_big and >= thresh_big + // XOR on these will produce masks for elements >= thresh_big + // That is, k_mask[4][i] = 1 if x_vec[0].v[i] >= thresh_big_vec.v + // k_mask[5][i] = 1 if x_vec[1].v[i] >= thresh_big_vec.v + // k_mask[6][i] = 1 if x_vec[2].v[i] >= thresh_big_vec.v + // k_mask[7][i] = 1 if x_vec[3].v[i] >= thresh_big_vec.v + k_mask[4] = _kxor_mask8( k_mask[0], k_mask[4] ); + k_mask[5] = _kxor_mask8( k_mask[1], k_mask[5] ); + k_mask[6] = _kxor_mask8( k_mask[2], k_mask[6] ); + k_mask[7] = _kxor_mask8( k_mask[3], k_mask[7] ); + + // Inverting k_mask[0 ... 3], to obtain masks for elements <= thresh_sml + // That is, k_mask[0][i] = 1 if x_vec[0].v[i] <= thresh_sml_vec.v + // k_mask[1][i] = 1 if x_vec[1].v[i] <= thresh_sml_vec.v + // k_mask[2][i] = 1 if x_vec[2].v[i] <= thresh_sml_vec.v + // k_mask[3][i] = 1 if x_vec[3].v[i] <= thresh_sml_vec.v + k_mask[0] = _knot_mask8( k_mask[0] ); + k_mask[1] = _knot_mask8( k_mask[1] ); + k_mask[2] = _knot_mask8( k_mask[2] ); + k_mask[3] = _knot_mask8( k_mask[3] ); + + // Checking whether we have values greater than thresh_big + // The truth_val is set to 0 if any bit in the mask is 1 + // Thus, truth_val[2] = 0 if x_vec[0].v or x_vec[1].v has elements >= thresh_big_vec.v + // truth_val[3] = 0 if x_vec[2].v or x_vec[3].v has elements >= thresh_big_vec.v + truth_val[2] = _kortestz_mask8_u8( k_mask[4], k_mask[5] ); + truth_val[3] = _kortestz_mask8_u8( k_mask[6], k_mask[7] ); + + // In case of having values greater than thresh_big + if( !( truth_val[2] && truth_val[3] ) ) + { + // Set isbig to true + isbig = true; + + // Computing by breaking it into masked muls and fmadds + // This computation involves only the elements that + // are greater than thresh_big + + // Scale the required elements in x_vec[0..3] by scale_smal + temp[0].v = _mm512_mask_mul_pd( zero_reg.v, k_mask[4], scale_big_vec.v, x_vec[0].v ); + temp[1].v = _mm512_mask_mul_pd( zero_reg.v, k_mask[5], scale_big_vec.v, x_vec[1].v ); + temp[2].v = _mm512_mask_mul_pd( zero_reg.v, k_mask[6], scale_big_vec.v, x_vec[2].v ); + temp[3].v = _mm512_mask_mul_pd( zero_reg.v, k_mask[7], scale_big_vec.v, x_vec[3].v ); + + // Square and add the elements to the accumulators + sum_big_vec[0].v = _mm512_fmadd_pd( temp[0].v, temp[0].v, sum_big_vec[0].v ); + sum_big_vec[1].v = _mm512_fmadd_pd( temp[1].v, temp[1].v, sum_big_vec[1].v ); + sum_big_vec[2].v = _mm512_fmadd_pd( temp[2].v, temp[2].v, sum_big_vec[2].v ); + sum_big_vec[3].v = _mm512_fmadd_pd( temp[3].v, temp[3].v, sum_big_vec[3].v ); + } + else if( !isbig ) + { + // Computing by breaking it into muls and adds + // This computation involves only the elements that + // are lesser than thresh_sml, if needed + + // Scale the required elements in x_vec[0..3] by scale_smal + temp[0].v = _mm512_mask_mul_pd( zero_reg.v, k_mask[0], scale_sml_vec.v, x_vec[0].v ); + temp[1].v = _mm512_mask_mul_pd( zero_reg.v, k_mask[1], scale_sml_vec.v, x_vec[1].v ); + temp[2].v = _mm512_mask_mul_pd( zero_reg.v, k_mask[2], scale_sml_vec.v, x_vec[2].v ); + temp[3].v = _mm512_mask_mul_pd( zero_reg.v, k_mask[3], scale_sml_vec.v, x_vec[3].v ); + + // Square and add the elements to the accumulators + sum_sml_vec[0].v = _mm512_fmadd_pd( temp[0].v, temp[0].v, sum_sml_vec[0].v ); + sum_sml_vec[1].v = _mm512_fmadd_pd( temp[1].v, temp[1].v, sum_sml_vec[1].v ); + sum_sml_vec[2].v = _mm512_fmadd_pd( temp[2].v, temp[2].v, sum_sml_vec[2].v ); + sum_sml_vec[3].v = _mm512_fmadd_pd( temp[3].v, temp[3].v, sum_sml_vec[3].v ); + } + } + + // Updating the pointer for the next iteration + xt += 32; + } + + // Computing in blocks of 16 + for ( ; ( i + 16 ) <= n; i = i + 16 ) + { + // Set temp[0..1] to zero + temp[0].v = _mm512_setzero_pd(); + temp[1].v = _mm512_setzero_pd(); + + // Loading the vectors + x_vec[0].v = _mm512_loadu_pd( xt ); + x_vec[1].v = _mm512_loadu_pd( xt + 8 ); + + // Comparing to check for NaN + // Bits in the mask are set if NaN is encountered + k_mask[0] = _mm512_cmp_pd_mask( x_vec[0].v, x_vec[0].v, _CMP_UNORD_Q ); + k_mask[1] = _mm512_cmp_pd_mask( x_vec[1].v, x_vec[1].v, _CMP_UNORD_Q ); + + // Checking if any bit in the masks are set + // The truth_val is set to 0 if any bit in the mask is 1 + // Thus, truth_val[0] = 0 if x_vec[0].v or x_vec[1].v has NaN + truth_val[0] = _kortestz_mask8_u8( k_mask[0], k_mask[1] ); + + // Set norm to NaN and return early, if either truth_val[0] or truth_val[1] is set to 0 + if( !truth_val[0] ) + { + *norm = NAN; + + AOCL_DTL_TRACE_EXIT( AOCL_DTL_LEVEL_TRACE_3 ); + return; + } + + // Getting the absoulte values of elements in the vectors + x_vec[0].v = _mm512_abs_pd( x_vec[0].v ); + x_vec[1].v = _mm512_abs_pd( x_vec[1].v ); + + // Setting the masks by comparing with thresh_sml_vec.v + // That is, k_mask[0][i] = 1 if x_vec[0].v[i] > thresh_sml_vec.v + // k_mask[1][i] = 1 if x_vec[1].v[i] > thresh_sml_vec.v + k_mask[0] = _mm512_cmp_pd_mask( x_vec[0].v, thresh_sml_vec.v, _CMP_GT_OS ); + k_mask[1] = _mm512_cmp_pd_mask( x_vec[1].v, thresh_sml_vec.v, _CMP_GT_OS ); + + // Setting the masks by comparing with thresh_big_vec.v + // That is, k_mask[4][i] = 1 if x_vec[0].v[i] < thresh_big_vec.v + // k_mask[5][i] = 1 if x_vec[1].v[i] < thresh_big_vec.v + k_mask[4] = _mm512_cmp_pd_mask( x_vec[0].v, thresh_big_vec.v, _CMP_LT_OS ); + k_mask[5] = _mm512_cmp_pd_mask( x_vec[1].v, thresh_big_vec.v, _CMP_LT_OS ); + + // Setting the masks to filter only the elements within the thresholds + // k_mask[0 ... 1] contain masks for elements > thresh_sml + // k_mask[4 ... 5] contain masks for elements < thresh_big + // Thus, AND operation on these would give elements within these thresholds + k_mask[4] = _kand_mask8( k_mask[0], k_mask[4] ); + k_mask[5] = _kand_mask8( k_mask[1], k_mask[5] ); + + // Setting booleans to check for underflow/overflow handling + // In case of having values outside threshold, the associated + // bit in k_mask[4 ... 7] is 0. + // Thus, truth_val[0] = 0 if x_vec[0].v has elements outside thresholds + // truth_val[1] = 0 if x_vec[1].v has elements outside thresholds + truth_val[0] = _kortestc_mask8_u8( k_mask[4], k_mask[4] ); + truth_val[1] = _kortestc_mask8_u8( k_mask[5], k_mask[5] ); + + // Computing using masked fmadds, that carries over values from + // accumulator register if the mask bit is 0 + sum_med_vec[0].v = _mm512_mask3_fmadd_pd( x_vec[0].v, x_vec[0].v, sum_med_vec[0].v, k_mask[4] ); + sum_med_vec[1].v = _mm512_mask3_fmadd_pd( x_vec[1].v, x_vec[1].v, sum_med_vec[1].v, k_mask[5] ); + + // In case of having elements outside the threshold + if( !( truth_val[0] && truth_val[1] ) ) + { + // Acquiring the masks for numbers greater than thresh_big + // k_mask[4 ... 5] contain masks for elements within the thresholds + // k_mask[0 ... 1] contain masks for elements > thresh_sml. This would + // include both elements < thresh_big and >= thresh_big + // XOR on these will produce masks for elements >= thresh_big + // That is, k_mask[4][i] = 1 if x_vec[0].v[i] >= thresh_big_vec.v + // k_mask[5][i] = 1 if x_vec[1].v[i] >= thresh_big_vec.v + k_mask[4] = _kxor_mask8( k_mask[0], k_mask[4] ); + k_mask[5] = _kxor_mask8( k_mask[1], k_mask[5] ); + + // Inverting k_mask[0 ... 1], to obtain masks for elements <= thresh_sml + // That is, k_mask[0][i] = 1 if x_vec[0].v[i] <= thresh_sml_vec.v + // k_mask[1][i] = 1 if x_vec[1].v[i] <= thresh_sml_vec.v + k_mask[0] = _knot_mask8( k_mask[0] ); + k_mask[1] = _knot_mask8( k_mask[1] ); + + // Checking whether we have values greater than thresh_big + // The truth_val is set to 0 if any bit in the mask is 1 + // Thus, truth_val[2] = 0 if x_vec[0].v or x_vec[1].v has elements >= thresh_big_vec.v + truth_val[2] = _kortestz_mask8_u8( k_mask[4], k_mask[5] ); + + // In case of having values greater than thresh_big + if( !truth_val[2] ) + { + // Set isbig to true + isbig = true; + + // Computing by breaking it into masked muls and fmadds + // This computation involves only the elements that + // are greater than thresh_big + + // Scale the required elements in x_vec[0..3] by scale_smal + temp[0].v = _mm512_mask_mul_pd( zero_reg.v, k_mask[4], scale_big_vec.v, x_vec[0].v ); + temp[1].v = _mm512_mask_mul_pd( zero_reg.v, k_mask[5], scale_big_vec.v, x_vec[1].v ); + + // Square and add the elements to the accumulators + sum_big_vec[0].v = _mm512_fmadd_pd( temp[0].v, temp[0].v, sum_big_vec[0].v ); + sum_big_vec[1].v = _mm512_fmadd_pd( temp[1].v, temp[1].v, sum_big_vec[1].v ); + } + else if( !isbig ) + { + // Computing by breaking it into muls and adds + // This computation involves only the elements that + // are lesser than thresh_sml, if needed + + // Scale the required elements in x_vec[0..3] by scale_smal + temp[0].v = _mm512_mask_mul_pd( zero_reg.v, k_mask[0], scale_sml_vec.v, x_vec[0].v ); + temp[1].v = _mm512_mask_mul_pd( zero_reg.v, k_mask[1], scale_sml_vec.v, x_vec[1].v ); + + // Square and add the elements to the accumulators + sum_sml_vec[0].v = _mm512_fmadd_pd( temp[0].v, temp[0].v, sum_sml_vec[0].v ); + sum_sml_vec[1].v = _mm512_fmadd_pd( temp[1].v, temp[1].v, sum_sml_vec[1].v ); + } + } + + // Updating the pointer for the next iteration + xt += 16; + } + for ( ; ( i + 8 ) <= n; i = i + 8 ) + { + // Set temp[0].v to zero + temp[0].v = _mm512_setzero_pd(); + + // Loading the vectors + x_vec[0].v = _mm512_loadu_pd( xt ); + + // Comparing to check for NaN + // Bits in the mask are set if NaN is encountered + k_mask[0] = _mm512_cmp_pd_mask( x_vec[0].v, x_vec[0].v, _CMP_UNORD_Q ); + + // Checking if any bit in the masks are set + // The truth_val is set to 0 if any bit in the mask is 1 + // Thus, truth_val[0] = 0 if x_vec[0].v or x_vec[1].v has NaN + truth_val[0] = _kortestz_mask8_u8( k_mask[0], k_mask[0] ); + + // Set norm to NaN and return early, if either truth_val[0] or truth_val[1] is set to 0 + if( !truth_val[0] ) + { + *norm = NAN; + + AOCL_DTL_TRACE_EXIT( AOCL_DTL_LEVEL_TRACE_3 ); + return; + } + + // Getting the absoulte values of elements in the vectors + x_vec[0].v = _mm512_abs_pd( x_vec[0].v ); + + // Setting the masks by comparing with thresh_sml_vec.v + // That is, k_mask[0][i] = 1 if x_vec[0].v[i] > thresh_sml_vec.v + k_mask[0] = _mm512_cmp_pd_mask( x_vec[0].v, thresh_sml_vec.v, _CMP_GT_OS ); + + // Setting the masks by comparing with thresh_big_vec.v + // That is, k_mask[4][i] = 1 if x_vec[0].v[i] < thresh_big_vec.v + k_mask[4] = _mm512_cmp_pd_mask( x_vec[0].v, thresh_big_vec.v, _CMP_LT_OS ); + + // Setting the masks to filter only the elements within the thresholds + // k_mask[0] contain masks for elements > thresh_sml + // k_mask[4] contain masks for elements < thresh_big + // Thus, AND operation on these would give elements within these thresholds + k_mask[4] = _kand_mask8( k_mask[0], k_mask[4] ); + + // Setting booleans to check for underflow/overflow handling + // In case of having values outside threshold, the associated + // bit in k_mask[4] is 0. + // Thus, truth_val[0] = 0 if x_vec[0].v has elements outside thresholds + truth_val[0] = _kortestc_mask8_u8( k_mask[4], k_mask[4] ); + + // Computing using masked fmadds, that carries over values from + // accumulator register if the mask bit is 0 + sum_med_vec[0].v = _mm512_mask3_fmadd_pd( x_vec[0].v, x_vec[0].v, sum_med_vec[0].v, k_mask[4] ); + + // In case of having elements outside the threshold + if( !truth_val[0] ) + { + // Acquiring the masks for numbers greater than thresh_big + // k_mask[4 ... 5] contain masks for elements within the thresholds + // k_mask[0 ... 1] contain masks for elements > thresh_sml. This would + // include both elements < thresh_big and >= thresh_big + // XOR on these will produce masks for elements >= thresh_big + // That is, k_mask[4][i] = 1 if x_vec[0].v[i] >= thresh_big_vec.v + // k_mask[5][i] = 1 if x_vec[1].v[i] >= thresh_big_vec.v + k_mask[4] = _kxor_mask8( k_mask[0], k_mask[4] ); + + // Inverting k_mask[0 ... 1], to obtain masks for elements <= thresh_sml + // That is, k_mask[0][i] = 1 if x_vec[0].v[i] <= thresh_sml_vec.v + // k_mask[1][i] = 1 if x_vec[1].v[i] <= thresh_sml_vec.v + k_mask[0] = _knot_mask8( k_mask[0] ); + + // Checking whether we have values greater than thresh_big + // The truth_val is set to 0 if any bit in the mask is 1 + // Thus, truth_val[2] = 0 if x_vec[0].v or x_vec[1].v has elements >= thresh_big_vec.v + truth_val[2] = _kortestz_mask8_u8( k_mask[4], k_mask[4] ); + + // In case of having values greater than thresh_big + if( !truth_val[2] ) + { + // Set isbig to true + isbig = true; + + // Computing by breaking it into masked muls and fmadds + // This computation involves only the elements that + // are greater than thresh_big + + // Scale the required elements in x_vec[0..3] by scale_smal + temp[0].v = _mm512_mask_mul_pd( zero_reg.v, k_mask[4], scale_big_vec.v, x_vec[0].v ); + + // Square and add the elements to the accumulators + sum_big_vec[0].v = _mm512_fmadd_pd( temp[0].v, temp[0].v, sum_big_vec[0].v ); + } + else if( !isbig ) + { + // Computing by breaking it into muls and adds + // This computation involves only the elements that + // are lesser than thresh_sml, if needed + + // Scale the required elements in x_vec[0..3] by scale_smal + temp[0].v = _mm512_mask_mul_pd( zero_reg.v, k_mask[0], scale_sml_vec.v, x_vec[0].v ); + + // Square and add the elements to the accumulators + sum_sml_vec[0].v = _mm512_fmadd_pd( temp[0].v, temp[0].v, sum_sml_vec[0].v ); + } + } + + // Updating the pointer for the next iteration + xt += 8; + } + if( i < n ) + { + // Set temp[0].v to zero + temp[0].v = _mm512_setzero_pd(); + + // Setting the mask to load + k_mask[0] = ( 1 << ( n - i ) ) - 1; + + // Loading the vectors + x_vec[0].v = _mm512_maskz_loadu_pd( k_mask[0], xt ); + + // Comparing to check for NaN + // Bits in the mask are set if NaN is encountered + k_mask[0] = _mm512_cmp_pd_mask( x_vec[0].v, x_vec[0].v, _CMP_UNORD_Q ); + + // Checking if any bit in the masks are set + // The truth_val is set to 0 if any bit in the mask is 1 + // Thus, truth_val[0] = 0 if x_vec[0].v or x_vec[1].v has NaN + truth_val[0] = _kortestz_mask8_u8( k_mask[0], k_mask[0] ); + + // Set norm to NaN and return early, if either truth_val[0] or truth_val[1] is set to 0 + if( !truth_val[0] ) + { + *norm = NAN; + + AOCL_DTL_TRACE_EXIT( AOCL_DTL_LEVEL_TRACE_3 ); + return; + } + + // Getting the absoulte values of elements in the vectors + x_vec[0].v = _mm512_abs_pd( x_vec[0].v ); + + // Setting the masks by comparing with thresh_sml_vec.v + // That is, k_mask[0][i] = 1 if x_vec[0].v[i] > thresh_sml_vec.v + k_mask[0] = _mm512_cmp_pd_mask( x_vec[0].v, thresh_sml_vec.v, _CMP_GT_OS ); + + // Setting the masks by comparing with thresh_big_vec.v + // That is, k_mask[4][i] = 1 if x_vec[0].v[i] < thresh_big_vec.v + k_mask[4] = _mm512_cmp_pd_mask( x_vec[0].v, thresh_big_vec.v, _CMP_LT_OS ); + + // Setting the masks to filter only the elements within the thresholds + // k_mask[0] contain masks for elements > thresh_sml + // k_mask[4] contain masks for elements < thresh_big + // Thus, AND operation on these would give elements within these thresholds + k_mask[4] = _kand_mask8( k_mask[0], k_mask[4] ); + + // Setting booleans to check for underflow/overflow handling + // In case of having values outside threshold, the associated + // bit in k_mask[4] is 0. + // Thus, truth_val[0] = 0 if x_vec[0].v has elements outside thresholds + truth_val[0] = _kortestc_mask8_u8( k_mask[4], k_mask[4] ); + + // Computing using masked fmadds, that carries over values from + // accumulator register if the mask bit is 0 + sum_med_vec[0].v = _mm512_mask3_fmadd_pd( x_vec[0].v, x_vec[0].v, sum_med_vec[0].v, k_mask[4] ); + + // In case of having elements outside the threshold + if( !truth_val[0] ) + { + // Acquiring the masks for numbers greater than thresh_big + // k_mask[4 ... 5] contain masks for elements within the thresholds + // k_mask[0 ... 1] contain masks for elements > thresh_sml. This would + // include both elements < thresh_big and >= thresh_big + // XOR on these will produce masks for elements >= thresh_big + // That is, k_mask[4][i] = 1 if x_vec[0].v[i] >= thresh_big_vec.v + // k_mask[5][i] = 1 if x_vec[1].v[i] >= thresh_big_vec.v + k_mask[4] = _kxor_mask8( k_mask[0], k_mask[4] ); + + // Inverting k_mask[0 ... 1], to obtain masks for elements <= thresh_sml + // That is, k_mask[0][i] = 1 if x_vec[0].v[i] <= thresh_sml_vec.v + // k_mask[1][i] = 1 if x_vec[1].v[i] <= thresh_sml_vec.v + k_mask[0] = _knot_mask8( k_mask[0] ); + + // Checking whether we have values greater than thresh_big + // The truth_val is set to 0 if any bit in the mask is 1 + // Thus, truth_val[2] = 0 if x_vec[0].v or x_vec[1].v has elements >= thresh_big_vec.v + truth_val[2] = _kortestz_mask8_u8( k_mask[4], k_mask[4] ); + + // In case of having values greater than thresh_big + if( !truth_val[2] ) + { + // Set isbig to true + isbig = true; + + // Computing by breaking it into masked muls and fmadds + // This computation involves only the elements that + // are greater than thresh_big + + // Scale the required elements in x_vec[0..3] by scale_smal + temp[0].v = _mm512_mask_mul_pd( zero_reg.v, k_mask[4], scale_big_vec.v, x_vec[0].v ); + + // Square and add the elements to the accumulators + sum_big_vec[0].v = _mm512_fmadd_pd( temp[0].v, temp[0].v, sum_big_vec[0].v ); + } + else if( !isbig ) + { + // Computing by breaking it into muls and adds + // This computation involves only the elements that + // are lesser than thresh_sml, if needed + + // Scale the required elements in x_vec[0..3] by scale_smal + temp[0].v = _mm512_mask_mul_pd( zero_reg.v, k_mask[0], scale_sml_vec.v, x_vec[0].v ); + + // Square and add the elements to the accumulators + sum_sml_vec[0].v = _mm512_fmadd_pd( temp[0].v, temp[0].v, sum_sml_vec[0].v ); + } + } + } + + // Reduction step + // Combining the results of accumulators for each category + sum_med_vec[0].v = _mm512_add_pd( sum_med_vec[0].v, sum_med_vec[1].v ); + sum_med_vec[2].v = _mm512_add_pd( sum_med_vec[2].v, sum_med_vec[3].v ); + sum_med_vec[0].v = _mm512_add_pd( sum_med_vec[0].v, sum_med_vec[2].v ); + + sum_big_vec[0].v = _mm512_add_pd( sum_big_vec[0].v, sum_big_vec[1].v ); + sum_big_vec[2].v = _mm512_add_pd( sum_big_vec[2].v, sum_big_vec[3].v ); + sum_big_vec[0].v = _mm512_add_pd( sum_big_vec[0].v, sum_big_vec[2].v ); + + sum_sml_vec[0].v = _mm512_add_pd( sum_sml_vec[0].v, sum_sml_vec[1].v ); + sum_sml_vec[2].v = _mm512_add_pd( sum_sml_vec[2].v, sum_sml_vec[3].v ); + sum_sml_vec[0].v = _mm512_add_pd( sum_sml_vec[0].v, sum_sml_vec[2].v ); + + // Final accumulation on the scalars + sum_sml += sum_sml_vec[0].d[0] + sum_sml_vec[0].d[1] + sum_sml_vec[0].d[2] + sum_sml_vec[0].d[3] + + sum_sml_vec[0].d[4] + sum_sml_vec[0].d[5] + sum_sml_vec[0].d[6] + sum_sml_vec[0].d[7]; + sum_med += sum_med_vec[0].d[0] + sum_med_vec[0].d[1] + sum_med_vec[0].d[2] + sum_med_vec[0].d[3] + + sum_med_vec[0].d[4] + sum_med_vec[0].d[5] + sum_med_vec[0].d[6] + sum_med_vec[0].d[7]; + sum_big += sum_big_vec[0].d[0] + sum_big_vec[0].d[1] + sum_big_vec[0].d[2] + sum_big_vec[0].d[3] + + sum_big_vec[0].d[4] + sum_big_vec[0].d[5] + sum_big_vec[0].d[6] + sum_big_vec[0].d[7]; + } + // Dealing with non-unit strided inputs + else + { + // Dealing with fringe cases + double abs_chi; + for( ; i < n; i += 1 ) + { + abs_chi = bli_fabs( *xt ); + // Any thread encountering a NAN sets the sum_med accumalator to NAN + if ( bli_isnan( abs_chi ) ) + { + *norm = NAN; + + AOCL_DTL_TRACE_EXIT( AOCL_DTL_LEVEL_TRACE_3 ); + return; + } + // Most likely case: medium values, not over/under-flow. + else if ( ( abs_chi <= thresh_big ) && ( abs_chi >= thresh_sml ) ) + { + sum_med += abs_chi * abs_chi; + } + // Case where there could be an overflow. Scaling is required. + else if ( abs_chi > thresh_big ) + { + sum_big += ( abs_chi * scale_big ) * ( abs_chi * scale_big ); + isbig = true; + } + // Case where there could be an underflow. Scaling is required. + else if ( ( !isbig ) && ( abs_chi < thresh_sml ) ) + { + sum_sml += ( abs_chi * scale_sml ) * ( abs_chi * scale_sml ); + } + + xt += incx; + } + } + + // Combine accumulators. + if ( isbig ) + { + // Combine sum_big and sum_med if sum_med > 0. + if ( sum_med > 0.0 ) + { + sum_big += ( sum_med * scale_big ) * scale_big; + } + scale = 1.0 / scale_big; + sumsq = sum_big; + } + + else if ( sum_sml > 0.0 ) + { + // Combine sum_med and sum_sml if sum_sml>0. + if ( sum_med > 0.0 ) + { + sum_med = sqrt( sum_med ); + sum_sml = sqrt( sum_sml ) / scale_sml; + double ymin, ymax; + if ( sum_sml > sum_med ) + { + ymin = sum_med; + ymax = sum_sml; + } + else + { + ymin = sum_sml; + ymax = sum_med; + } + scale = 1.0; + sumsq = ymax * ymax * ( 1.0 + ( ymin / ymax ) * ( ymin / ymax ) ); + } + else + { + scale = 1.0 / scale_sml; + sumsq = sum_sml; + } + } + else + { + // If all values are mid-range: + scale = 1.0; + sumsq = sum_med; + } + + *norm = scale * sqrt( sumsq ); + + AOCL_DTL_TRACE_EXIT( AOCL_DTL_LEVEL_TRACE_3 ); + + return; +} diff --git a/kernels/zen4/1/bli_scal2v_zen_int_avx512.c b/kernels/zen4/1/bli_scal2v_zen_int_avx512.c new file mode 100644 index 0000000000..c28c3af7db --- /dev/null +++ b/kernels/zen4/1/bli_scal2v_zen_int_avx512.c @@ -0,0 +1,234 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include + +// This kernel performs y := alpha * conjx(x) +void bli_dscal2v_zen_int_avx512 + ( + conj_t conjx, + dim_t n, + double* restrict alpha, + double* restrict x, inc_t incx, + double* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + // If the vector dimension is zero, return early. + if ( bli_zero_dim1( n ) ) + return; + + // Redirecting to DSETV, if alpha is 0 + if ( PASTEMAC( d, eq0 )( *alpha ) ) + { + double *zero = PASTEMAC( d, 0 ); + + bli_dsetv_zen_int_avx512 + ( + BLIS_NO_CONJUGATE, + n, + zero, + y, incy, + cntx + ); + + return; + } + // Redirecting to DCOPYV, if alpha is 1 + else if ( PASTEMAC( d, eq1 )( *alpha ) ) + { + bli_dcopyv_zen4_asm_avx512 + ( + conjx, + n, + x, incx, + y, incy, + cntx + ); + + return; + } + + // Initializing the pointer aliases and iterator + dim_t i = 0; + double *x0 = x; + double *y0 = y; + + // Handling unit-strided inputs + if ( incx == 1 && incy == 1 ) + { + // Vectors to be used in the scal2v computation + __m512d x_vec[8], alphav; + + // Broadcasting alpha to a 512-bit register + alphav = _mm512_set1_pd( *alpha ); + + const dim_t n_elem_per_reg = 8; + + // Iterating in blocks of 64 elements + for ( ; ( i + 63 ) < n; i += 64 ) + { + // Loading X vector + x_vec[0] = _mm512_loadu_pd( x0 ); + x_vec[1] = _mm512_loadu_pd( x0 + 1 * n_elem_per_reg ); + x_vec[2] = _mm512_loadu_pd( x0 + 2 * n_elem_per_reg ); + x_vec[3] = _mm512_loadu_pd( x0 + 3 * n_elem_per_reg ); + + // Scaling X vector with alpha + x_vec[0] = _mm512_mul_pd( x_vec[0], alphav ); + x_vec[1] = _mm512_mul_pd( x_vec[1], alphav ); + x_vec[2] = _mm512_mul_pd( x_vec[2], alphav ); + x_vec[3] = _mm512_mul_pd( x_vec[3], alphav ); + + // Storing onto Y + _mm512_storeu_pd( y0, x_vec[0] ); + _mm512_storeu_pd( y0 + 1 * n_elem_per_reg, x_vec[1] ); + _mm512_storeu_pd( y0 + 2 * n_elem_per_reg, x_vec[2] ); + _mm512_storeu_pd( y0 + 3 * n_elem_per_reg, x_vec[3] ); + + // Loading X vector + x_vec[4] = _mm512_loadu_pd( x0 + 4 * n_elem_per_reg ); + x_vec[5] = _mm512_loadu_pd( x0 + 5 * n_elem_per_reg ); + x_vec[6] = _mm512_loadu_pd( x0 + 6 * n_elem_per_reg ); + x_vec[7] = _mm512_loadu_pd( x0 + 7 * n_elem_per_reg ); + + // Scaling X vector with alpha + x_vec[4] = _mm512_mul_pd( x_vec[4], alphav ); + x_vec[5] = _mm512_mul_pd( x_vec[5], alphav ); + x_vec[6] = _mm512_mul_pd( x_vec[6], alphav ); + x_vec[7] = _mm512_mul_pd( x_vec[7], alphav ); + + // Storing onto Y + _mm512_storeu_pd( y0 + 4 * n_elem_per_reg, x_vec[4] ); + _mm512_storeu_pd( y0 + 5 * n_elem_per_reg, x_vec[5] ); + _mm512_storeu_pd( y0 + 6 * n_elem_per_reg, x_vec[6] ); + _mm512_storeu_pd( y0 + 7 * n_elem_per_reg, x_vec[7] ); + + // Adjusting the pointers for the next iteration + x0 += 8 * n_elem_per_reg; + y0 += 8 * n_elem_per_reg; + } + + // Iterating in blocks of 32 elements + for ( ; ( i + 31 ) < n; i += 32 ) + { + // Loading X vector + x_vec[0] = _mm512_loadu_pd( x0 ); + x_vec[1] = _mm512_loadu_pd( x0 + 1 * n_elem_per_reg ); + x_vec[2] = _mm512_loadu_pd( x0 + 2 * n_elem_per_reg ); + x_vec[3] = _mm512_loadu_pd( x0 + 3 * n_elem_per_reg ); + + // Scaling X vector with alpha + x_vec[0] = _mm512_mul_pd( x_vec[0], alphav ); + x_vec[1] = _mm512_mul_pd( x_vec[1], alphav ); + x_vec[2] = _mm512_mul_pd( x_vec[2], alphav ); + x_vec[3] = _mm512_mul_pd( x_vec[3], alphav ); + + // Storing onto Y + _mm512_storeu_pd( y0, x_vec[0] ); + _mm512_storeu_pd( y0 + 1 * n_elem_per_reg, x_vec[1] ); + _mm512_storeu_pd( y0 + 2 * n_elem_per_reg, x_vec[2] ); + _mm512_storeu_pd( y0 + 3 * n_elem_per_reg, x_vec[3] ); + + // Adjusting the pointers for the next iteration + x0 += 4 * n_elem_per_reg; + y0 += 4 * n_elem_per_reg; + } + + // Iterating in blocks of 16 elements + for ( ; ( i + 15 ) < n; i += 16 ) + { + // Loading X vector + x_vec[0] = _mm512_loadu_pd( x0 ); + x_vec[1] = _mm512_loadu_pd( x0 + 1 * n_elem_per_reg ); + + // Scaling X vector with alpha + x_vec[0] = _mm512_mul_pd( x_vec[0], alphav ); + x_vec[1] = _mm512_mul_pd( x_vec[1], alphav ); + + // Storing onto Y + _mm512_storeu_pd( y0, x_vec[0] ); + _mm512_storeu_pd( y0 + 1 * n_elem_per_reg, x_vec[1] ); + + // Adjusting the pointers for the next iteration + x0 += 2 * n_elem_per_reg; + y0 += 2 * n_elem_per_reg; + } + + // Iterating in blocks of 8 elements + for ( ; ( i + 7 ) < n; i += 8 ) + { + // Loading X vector + x_vec[0] = _mm512_loadu_pd( x0 ); + + // Scaling X vector with alpha + x_vec[0] = _mm512_mul_pd( x_vec[0], alphav ); + + // Storing onto Y + _mm512_storeu_pd( y0, x_vec[0] ); + + // Adjusting the pointers for the next iteration + x0 += 1 * n_elem_per_reg; + y0 += 1 * n_elem_per_reg; + } + + // Handling the fringe case + if ( i < n ) + { + // Setting the mask for loading and storing the vectors + __mmask8 n_mask = (1 << ( n - i )) - 1; + + // Loading X vector + x_vec[0] = _mm512_maskz_loadu_pd( n_mask, x0 ); + + // Scaling X vector with alpha + x_vec[0] = _mm512_mul_pd( x_vec[0], alphav ); + + // Storing onto Y + _mm512_mask_storeu_pd( y0, n_mask, x_vec[0] ); + } + } + + else + { + // Handling fringe case or non-unit strides + for ( ; i < n; i += 1 ) + { + *y0 = (*alpha) * (*x0); + x0 += incx; + y0 += incy; + } + } +} diff --git a/kernels/zen4/1/bli_scalv_zen_int_avx512.c b/kernels/zen4/1/bli_scalv_zen_int_avx512.c index febd6aa8e9..a2143a5247 100644 --- a/kernels/zen4/1/bli_scalv_zen_int_avx512.c +++ b/kernels/zen4/1/bli_scalv_zen_int_avx512.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -61,13 +61,14 @@ Deviation from BLAS -------------------- - None + Setv is used when alpha=0 unless a negative value of n is supplied. + This only occurs in calls from BLAS and CBLAS scal APIs. Undefined behaviour ------------------- - 1. The kernel results in undefined behaviour when n <= 0 and incx <= 1. The expectation - is that these are standard BLAS exceptions and should be handled in a higher layer. + None + */ void bli_sscalv_zen_int_avx512 ( @@ -78,6 +79,30 @@ void bli_sscalv_zen_int_avx512 cntx_t *restrict cntx ) { + // If the vector dimension is zero, or if alpha is unit, return early. + if ( bli_zero_dim1( n ) || PASTEMAC(s,eq1)( *alpha ) ) return; + + // If alpha is zero, use setv if not called from BLAS scal itself (indicated by n being negative). + if ( PASTEMAC(s,eq0)( *alpha ) && n > 0 ) + { + float *zero = bli_s0; + if (cntx == NULL) cntx = bli_gks_query_cntx(); + ssetv_ker_ft f = bli_cntx_get_l1v_ker_dt(BLIS_FLOAT, BLIS_SETV_KER, cntx); + + f + ( + BLIS_NO_CONJUGATE, + n, + zero, + x, incx, + cntx + ); + + return; + } + + dim_t n0 = bli_abs(n); + dim_t i = 0; float *restrict x0 = x; @@ -89,7 +114,7 @@ void bli_sscalv_zen_int_avx512 __m512 xv[8], alphav; alphav = _mm512_set1_ps(*alpha); - for (i = 0; (i + 127) < n; i += 128) + for (i = 0; (i + 127) < n0; i += 128) { // Loading the input values xv[0] = _mm512_loadu_ps(x0 + 0 * n_elem_per_reg); @@ -125,7 +150,7 @@ void bli_sscalv_zen_int_avx512 x0 += 8 * n_elem_per_reg; } - for (; (i + 63) < n; i += 64) + for (; (i + 63) < n0; i += 64) { // Loading the input values xv[0] = _mm512_loadu_ps(x0 + 0 * n_elem_per_reg); @@ -147,7 +172,7 @@ void bli_sscalv_zen_int_avx512 x0 += 4 * n_elem_per_reg; } - for (; (i + 31) < n; i += 32) + for (; (i + 31) < n0; i += 32) { // Loading the input values xv[0] = _mm512_loadu_ps(x0 + 0 * n_elem_per_reg); @@ -163,7 +188,7 @@ void bli_sscalv_zen_int_avx512 x0 += 2 * n_elem_per_reg; } - for (; (i + 15) < n; i += 16) + for (; (i + 15) < n0; i += 16) { // Loading the input values xv[0] = _mm512_loadu_ps(x0 + 0 * n_elem_per_reg); @@ -176,7 +201,7 @@ void bli_sscalv_zen_int_avx512 x0 += n_elem_per_reg; } - for (; (i + 7) < n; i += 8) + for (; (i + 7) < n0; i += 8) { // Loading the input values __m256 x_vec = _mm256_loadu_ps(x0); @@ -198,7 +223,7 @@ void bli_sscalv_zen_int_avx512 */ _mm256_zeroupper(); - for (; (i + 3) < n; i += 4) + for (; (i + 3) < n0; i += 4) { // Loading the input values __m128 x_vec = _mm_loadu_ps(x0); @@ -215,7 +240,7 @@ void bli_sscalv_zen_int_avx512 const float alphac = *alpha; - for (; i < n; ++i) + for (; i < n0; ++i) { *x0 *= alphac; @@ -252,15 +277,16 @@ void bli_sscalv_zen_int_avx512 Deviation from BLAS -------------------- - None + Setv is used when alpha=0 unless a negative value of n is supplied. + This only occurs in calls from BLAS and CBLAS scal APIs. Undefined behaviour ------------------- - 1. The kernel results in undefined behaviour when n <= 0 and incx <= 1. The expectation - is that these are standard BLAS exceptions and should be handled in a higher layer. + None + */ -void bli_dscalv_zen_int_avx512 +BLIS_EXPORT_BLIS void bli_dscalv_zen_int_avx512 ( conj_t conjalpha, dim_t n, @@ -270,11 +296,10 @@ void bli_dscalv_zen_int_avx512 ) { // If the vector dimension is zero, or if alpha is unit, return early. - if (bli_zero_dim1(n) || PASTEMAC(d, eq1)(*alpha)) - return; + if ( bli_zero_dim1( n ) || PASTEMAC(d,eq1)( *alpha ) ) return; - // If alpha is zero, use setv. - if (PASTEMAC(d, eq0)(*alpha)) + // If alpha is zero, use setv if not called from BLAS scal itself (indicated by n being negative). + if ( PASTEMAC(d,eq0)( *alpha ) && n > 0 ) { double *zero = bli_d0; if (cntx == NULL) cntx = bli_gks_query_cntx(); @@ -292,6 +317,8 @@ void bli_dscalv_zen_int_avx512 return; } + dim_t n0 = bli_abs(n); + dim_t i = 0; double *restrict x0; @@ -307,7 +334,7 @@ void bli_dscalv_zen_int_avx512 alphav = _mm512_set1_pd(*alpha); __m512d xv[8]; - for (i = 0; (i + 63) < n; i += 64) + for (i = 0; (i + 63) < n0; i += 64) { // Loading the input values xv[0] = _mm512_loadu_pd(x0 + 0 * n_elem_per_reg); @@ -343,7 +370,7 @@ void bli_dscalv_zen_int_avx512 x0 += 8 * n_elem_per_reg; } - for (; (i + 31) < n; i += 32) + for (; (i + 31) < n0; i += 32) { // Loading the input values xv[0] = _mm512_loadu_pd(x0 + 0 * n_elem_per_reg); @@ -365,7 +392,7 @@ void bli_dscalv_zen_int_avx512 x0 += 4 * n_elem_per_reg; } - for (; (i + 15) < n; i += 16) + for (; (i + 15) < n0; i += 16) { // Loading the input values xv[0] = _mm512_loadu_pd(x0 + 0 * n_elem_per_reg); @@ -381,7 +408,7 @@ void bli_dscalv_zen_int_avx512 x0 += 2 * n_elem_per_reg; } - for (; (i + 7) < n; i += 8) + for (; (i + 7) < n0; i += 8) { // Loading the input values xv[0] = _mm512_loadu_pd(x0 + 0 * n_elem_per_reg); @@ -394,7 +421,7 @@ void bli_dscalv_zen_int_avx512 x0 += n_elem_per_reg; } - for (; (i + 3) < n; i += 4) + for (; (i + 3) < n0; i += 4) { // Loading the input values __m256d x_vec = _mm256_loadu_pd(x0); @@ -416,7 +443,7 @@ void bli_dscalv_zen_int_avx512 */ _mm256_zeroupper(); - for (; (i + 1) < n; i += 2) + for (; (i + 1) < n0; i += 2) { // Loading the input values __m128d x_vec = _mm_loadu_pd(x0); @@ -433,7 +460,7 @@ void bli_dscalv_zen_int_avx512 const double alphac = *alpha; - for (; i < n; ++i) + for (; i < n0; ++i) { *x0 *= alphac; @@ -468,13 +495,14 @@ void bli_dscalv_zen_int_avx512 Deviation from BLAS -------------------- - None + Setv is used when alpha=0 unless a negative value of n is supplied. + This only occurs in calls from BLAS and CBLAS scal APIs. Undefined behaviour ------------------- - 1. The kernel results in undefined behaviour when n <= 0 and incx <= 1. The expectation - is that these are standard BLAS exceptions and should be handled in a higher layer. + None + */ void bli_zdscalv_zen_int_avx512 ( @@ -491,6 +519,31 @@ void bli_zdscalv_zen_int_avx512 alpha is passed as double complex to adhere to function pointer definition in BLIS */ + + // If the vector dimension is zero, or if alpha is unit, return early. + if ( bli_zero_dim1( n ) || PASTEMAC(z,eq1)( *alpha ) ) return; + + // If alpha is zero, use setv if not called from BLAS scal itself (indicated by n being negative). + if ( PASTEMAC(z,eq0)( *alpha ) && n > 0 ) + { + // Expert interface of setv is invoked when alpha is zero + dcomplex *zero = bli_z0; + + /* When alpha is zero all the element in x are set to zero */ + PASTEMAC2(z, setv, BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n, + zero, + x, incx, + cntx, + NULL); + + return; + } + + dim_t n0 = bli_abs(n); + const double alphac = (*alpha).real; dim_t i = 0; @@ -504,7 +557,7 @@ void bli_zdscalv_zen_int_avx512 alphav = _mm512_set1_pd(alphac); - for (; (i + 15) < n; i += 16) + for (; (i + 15) < n0; i += 16) { xv[0] = _mm512_loadu_pd(x0); xv[1] = _mm512_loadu_pd(x0 + n_elem_per_reg); @@ -524,7 +577,7 @@ void bli_zdscalv_zen_int_avx512 x0 += 4 * n_elem_per_reg; } - for (; (i + 7) < n; i += 8) + for (; (i + 7) < n0; i += 8) { xv[0] = _mm512_loadu_pd(x0); xv[1] = _mm512_loadu_pd(x0 + n_elem_per_reg); @@ -538,7 +591,7 @@ void bli_zdscalv_zen_int_avx512 x0 += 2 * n_elem_per_reg; } - for (; (i + 3) < n; i += 4) + for (; (i + 3) < n0; i += 4) { xv[0] = _mm512_loadu_pd(x0); @@ -549,7 +602,7 @@ void bli_zdscalv_zen_int_avx512 x0 += n_elem_per_reg; } - for (; (i + 1) < n; i += 2) + for (; (i + 1) < n0; i += 2) { __m256d xv = _mm256_loadu_pd(x0); @@ -576,7 +629,7 @@ void bli_zdscalv_zen_int_avx512 alpha_reg = _mm_set1_pd((*alpha).real); - for (; i < n; ++i) + for (; i < n0; ++i) { x_vec = _mm_loadu_pd(x0); @@ -587,3 +640,704 @@ void bli_zdscalv_zen_int_avx512 x0 += 2 * incx; } } + + +#define MICRO_OP( r0, r1, r2, r3 ) \ + /** + * Loading 8 scomplex (16 float) elements from x to each zmm register. + * xv[0] = x0R x0I x1R x1I x2R x2I x3R x3I ... + */ \ + xv[r0] = _mm512_loadu_ps( x0 + r0*n_elem_per_reg ); \ + xv[r1] = _mm512_loadu_ps( x0 + r1*n_elem_per_reg ); \ + xv[r2] = _mm512_loadu_ps( x0 + r2*n_elem_per_reg ); \ + xv[r3] = _mm512_loadu_ps( x0 + r3*n_elem_per_reg ); \ + \ + /** + * Using itermediate ZMM register to interchange real and imaginary + * values of each element in xv register. + * inter[0] = x0I x0R x1I x1R x2I x2R x3I x3R... + */ \ + inter[r0] = _mm512_permute_ps( xv[r0], 0xB1 ); \ + inter[r1] = _mm512_permute_ps( xv[r1], 0xB1 ); \ + inter[r2] = _mm512_permute_ps( xv[r2], 0xB1 ); \ + inter[r3] = _mm512_permute_ps( xv[r3], 0xB1 ); \ + \ + /** + * Scaling intermediate vector with imaginary part of alpha. + * inter[0] = inter[0] * alphaI + * = x0I*alphaI x0R*alphaI x1I*alphaI x1R*alphaI ... + */ \ + \ + inter[r0] = _mm512_mul_ps( inter[r0], alphaIv ); \ + inter[r1] = _mm512_mul_ps( inter[r1], alphaIv ); \ + inter[r2] = _mm512_mul_ps( inter[r2], alphaIv ); \ + inter[r3] = _mm512_mul_ps( inter[r3], alphaIv ); \ + \ + /** + * Scaling xv with real part of alpha and doing alternatively sub-add of + * the scaled intermediate register. The fmaddsub operation will + * alternatively add and subtract elements in inter[0] from alphaRv*xv[0]. + * xv[0] = xv[0] * alphaR -/+ inter[0] + * = x0R*alphaR - x0I*alphaI x0I*alphaR + x0R*alphaI + * x1R*alphaR - x1I*alphaI x1I*alphaR + x1R*alphaI ... + */ \ + xv[r0] = _mm512_fmaddsub_ps( alphaRv, xv[r0], inter[r0] ); \ + xv[r1] = _mm512_fmaddsub_ps( alphaRv, xv[r1], inter[r1] ); \ + xv[r2] = _mm512_fmaddsub_ps( alphaRv, xv[r2], inter[r2] ); \ + xv[r3] = _mm512_fmaddsub_ps( alphaRv, xv[r3], inter[r3] ); \ + \ + /** + * Storing the scaled vector back to x0. + */ \ + _mm512_storeu_ps( x0 + r0*n_elem_per_reg, xv[r0] ); \ + _mm512_storeu_ps( x0 + r1*n_elem_per_reg, xv[r1] ); \ + _mm512_storeu_ps( x0 + r2*n_elem_per_reg, xv[r2] ); \ + _mm512_storeu_ps( x0 + r3*n_elem_per_reg, xv[r3] ); + +/* + Functionality + ------------- + + This function scales a single complex vector by an element of the + type single complex. + + x := conjalpha(alpha) * x + + Function Signature + ------------------- + + * 'conjalpha' - Variable specified if alpha needs to be conjugated + * 'n' - Length of the array passed + * 'alpha' - Pointer to the element by which the vector is to be scaled + * 'x' - Single complex pointer pointing to an array + * 'incx' - Stride to point to the next element in the array + * 'cntx' - BLIS context object + + Exception + ---------- + + None + + Deviation from BLAS + -------------------- + + 1. The kernel invokes SETV when alpha scalar is zero and explicitly sets all + elements to zero thus, not propagating any NaNs/Infs. + + Undefined behaviour + ------------------- + + None + +*/ +void bli_cscalv_zen_int_avx512 + ( + conj_t conjalpha, + dim_t n, + scomplex* restrict alpha, + scomplex* restrict x, inc_t incx, + cntx_t* restrict cntx + ) +{ + // If the vector dimension is zero, or if alpha is unit, return early. + if ( bli_zero_dim1( n ) || PASTEMAC(c,eq1)( *alpha ) ) return; + + // If alpha is zero, use setv if not called from BLAS scal itself (indicated by n being negative). + if ( PASTEMAC(c,eq0)( *alpha ) && n > 0 ) + { + // Expert interface of setv is invoked when alpha is zero + scomplex *zero = bli_c0; + + /* When alpha is zero all the element in x are set to zero */ + PASTEMAC2(c, setv, BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n, + zero, + x, incx, + cntx, + NULL + ); + + return; + } + + dim_t n0 = bli_abs(n); + + dim_t i = 0; + scomplex alpha_conj; + float* restrict x0 = (float*) x; + + // Performs conjugation of alpha based on conjalpha. + PASTEMAC(c,copycjs)( conjalpha, *alpha, alpha_conj ); + + const float alphaR = alpha_conj.real; + const float alphaI = alpha_conj.imag; + + if ( incx == 1 ) + { + // number of elements per register. + const dim_t n_elem_per_reg = 16; + + __m512 alphaRv, alphaIv; + + // Broadcast real and imaginary values of alpha. + alphaRv = _mm512_set1_ps( alphaR ); + alphaIv = _mm512_set1_ps( alphaI ); + + /** + * General Algorithm: + * + * Broadcasting real and imaginary parts of alpha scalar to separate + * zmm registers, alphaRv and alphaIv, respectively. + * alphaRv = alphaR alphaR alphaR alphaR ... + * alphaIv = alphaI alphaI alphaI alphaI ... + * + * Loading 8 scomplex (16 float) elements from x to each zmm register. + * xv[0] = x0R x0I x1R x1I x2R x2I x3R x3I ... + * + * Using itermediate ZMM register to interchange real and imaginary + * values of each element in xv register. + * inter[0] = x0I x0R x1I x1R x2I x2R x3I x3R... + * + * Scaling the intermediate register with imaginary part of alpha. + * inter[0] = inter[0] * alphaI + * = x0I*alphaI x0R*alphaI x1I*alphaI x1R*alphaI ... + * + * Scaling xv with real part of alpha and doing alternatively sub-add of + * the scaled intermediate register. + * xv[0] = xv[0] * alphaR -/+ inter[0] + * = x0R*alphaR - x0I*alphaI x0I*alphaR + x0R*alphaI + * x1R*alphaR - x1I*alphaI x1I*alphaR + x1R*alphaI ... + */ + + // Processing 96 scomplex elements (192 floats) per iteration + for ( ; (i + 95) < n0; i += 96 ) + { + __m512 xv[12], inter[12]; + + MICRO_OP( 0, 1, 2, 3 ) + + MICRO_OP( 4, 5, 6, 7 ) + + MICRO_OP( 8, 9, 10, 11 ) + + // Incrementing x0 by 12*n_elem_per_reg, 192 floats + // or 96 scomplex elements. + x0 += 12 * n_elem_per_reg; + } + + // Processing 64 scomplex elements (128 floats) per iteration + for ( ; (i + 63) < n0; i += 64 ) + { + __m512 xv[8], inter[8]; + + MICRO_OP( 0, 1, 2, 3 ) + + MICRO_OP( 4, 5, 6, 7 ) + + // Incrementing x0 by 8*n_elem_per_reg, 128 floats + // or 64 scomplex elements. + x0 += 8 * n_elem_per_reg; + } + + // Processing 32 scomplex elements (64 floats) per iteration + for ( ; (i + 31) < n0; i += 32 ) + { + __m512 xv[4], inter[4]; + + MICRO_OP( 0, 1, 2, 3 ) + + // Incrementing x0 by 4*n_elem_per_reg, 64 floats + // or 32 scomplex elements. + x0 += 4 * n_elem_per_reg; + } + + // Processing 16 scomplex elements (32 floats) per iteration + for ( ; (i + 15) < n0; i += 16 ) + { + __m512 xv[2], inter[2]; + + // Loading 8 scomplex (16 float) elements from x to each + // zmm register. + // xv[0] = x0R x0I x1R x1I x2R x2I x3R x3I ... + xv[0] = _mm512_loadu_ps( x0 ); + xv[1] = _mm512_loadu_ps( x0 + 1*n_elem_per_reg ); + + // Permuting xv and storing into intermediate vector. + // inter[0] = x0I x0R x1I x1R x2I x2R x3I x3R... + inter[0] = _mm512_permute_ps( xv[0], 0xB1 ); + inter[1] = _mm512_permute_ps( xv[1], 0xB1 ); + + // Scaling intermediate vector with imaginary part of alpha. + // inter[0] = inter[0] * alphaI + // = x0I*alphaI x0R*alphaI x1I*alphaI x1R*alphaI ... + inter[0] = _mm512_mul_ps( inter[0], alphaIv ); + inter[1] = _mm512_mul_ps( inter[1], alphaIv ); + + // Performing the fmaddsub operation to get resultant x scaled by + // alpha. The fmaddsub operation will alternatively add and subtract + // elements in inter[0] from alphaRv*xv[0]. + // xv[0] = xv[0] * alphaR -/+ inter[0] + // = x0R*alphaR - x0I*alphaI x0I*alphaR + x0R*alphaI + // x1R*alphaR - x1I*alphaI x1I*alphaR + x1R*alphaI ... + xv[0] = _mm512_fmaddsub_ps( alphaRv, xv[0], inter[0] ); + xv[1] = _mm512_fmaddsub_ps( alphaRv, xv[1], inter[1] ); + + // Storing the scaled vector back to x0. + _mm512_storeu_ps( x0, xv[0] ); + _mm512_storeu_ps( x0 + 1*n_elem_per_reg, xv[1] ); + + // Incrementing x0 by 2*n_elem_per_reg, 32 floats + // or 16 scomplex elements. + x0 += 2 * n_elem_per_reg; + } + + // Processing 8 scomplex elements (16 floats) per iteration + for ( ; (i + 7) < n0; i += 8 ) + { + __m512 xv[1], inter[1]; + + // Loading 8 scomplex (16 float) elements from x to each + // zmm register. + // xv[0] = x0R x0I x1R x1I x2R x2I x3R x3I ... + xv[0] = _mm512_loadu_ps( x0 ); + + // Permuting xv and storing into intermediate zmm register. + // inter[0] = x0I x0R x1I x1R x2I x2R x3I x3R... + inter[0] = _mm512_permute_ps( xv[0], 0xB1 ); + + // Scaling intermediate register with imaginary part of alpha. + // inter[0] = inter[0] * alphaI + // = x0I*alphaI x0R*alphaI x1I*alphaI x1R*alphaI ... + inter[0] = _mm512_mul_ps( inter[0], alphaIv ); + + // Performing the fmaddsub operation to get resultant x scaled by + // alpha. The fmaddsub operation will alternatively add and subtract + // elements in inter[0] from alphaRv*xv[0]. + // xv[0] = xv[0] * alphaR -/+ inter[0] + // = x0R*alphaR - x0I*alphaI x0I*alphaR + x0R*alphaI + // x1R*alphaR - x1I*alphaI x1I*alphaR + x1R*alphaI ... + xv[0] = _mm512_fmaddsub_ps( alphaRv, xv[0], inter[0] ); + + // Storing the scaled vector back to x0. + _mm512_storeu_ps( x0, xv[0] ); + + // Incrementing x0 by n_elem_per_reg, 16 floats + // or 8 scomplex elements. + x0 += n_elem_per_reg; + } + + // Processing remaining elements, if any. + if ( i < n0 ) + { + // Setting the mask bit based on remaining elements. + // Since each scomplex element corresponds to 2 floats, + // we need to load and store 2*(n0-i) elements. + + __mmask16 mask = ( 1 << ( 2 * ( n0 - i ) ) ) - 1; + + __m512 xv, temp; + + xv = _mm512_maskz_loadu_ps( mask, x0 ); + + temp = _mm512_permute_ps( xv, 0xB1 ); + + temp = _mm512_mul_ps( alphaIv, temp ); + + xv = _mm512_fmaddsub_ps( alphaRv, xv, temp ); + + _mm512_mask_storeu_ps( x0, mask, xv ); + } + } + else // if ( incx != 1 ) + { + const float alphaR = alpha_conj.real; + const float alphaI = alpha_conj.imag; + + float x0R, x0I; + for (; i < n0; ++i) + { + x0R = *(x0); + x0I = *(x0 + 1); + + *(x0) = x0R * alphaR - x0I * alphaI; + *(x0 + 1) = x0R * alphaI + x0I * alphaR; + + x0 += 2*incx; + } + } +} + +/* + Functionality + ------------- + + This function scales a double complex vector by an element of the + type double complex. + + x := conjalpha(alpha) * x + + Function Signature + ------------------- + + * 'conjalpha' - Variable specified if alpha needs to be conjugated + * 'n' - Length of the array passed + * 'alpha' - Pointer to the element by which the vector is to be scaled + * 'x' - Double complex pointer pointing to an array + * 'incx' - Stride to point to the next element in the array + * 'cntx' - BLIS context object + + Exception + ---------- + + None + + Deviation from BLAS + -------------------- + + Setv is used when alpha=0 unless a negative value of n is supplied. + This only occurs in calls from BLAS and CBLAS scal APIs. + + Undefined behaviour + ------------------- + + None + +*/ +void bli_zscalv_zen_int_avx512 + ( + conj_t conjalpha, + dim_t n, + dcomplex* restrict alpha, + dcomplex* restrict x, inc_t incx, + cntx_t* restrict cntx + ) +{ + // If the vector dimension is zero, or if alpha is unit, return early. + if ( bli_zero_dim1( n ) || PASTEMAC(z,eq1)( *alpha ) ) return; + + // If alpha is zero, use setv if not called from BLAS scal itself (indicated by n being negative). + if (PASTEMAC(z,eq0)( *alpha ) && n > 0 ) + { + // Expert interface of setv is invoked when alpha is zero + dcomplex *zero = bli_z0; + + /* When alpha is zero all the element in x are set to zero */ + PASTEMAC2(z, setv, BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n, + zero, + x, incx, + cntx, + NULL); + + return; + } + + dim_t n0 = bli_abs(n); + + dim_t i = 0; + dcomplex alpha_conj; + double *restrict x0 = (double *)x; + + // Performs conjugation of alpha based on conjalpha + PASTEMAC(z, copycjs)(conjalpha, *alpha, alpha_conj) + + const double alphaR = alpha_conj.real; + const double alphaI = alpha_conj.imag; + + if (incx == 1) + { + __m512d alphaRv, alphaIv; + const dim_t n_elem_per_reg = 8; // number of elements per register + + // Broadcast real and imaginary values of alpha to separate registers. + // alphaRv = alphaR alphaR alphaR alphaR ... + // alphaIv = alphaI alphaI alphaI alphaI ... + alphaRv = _mm512_set1_pd(alphaR); + alphaIv = _mm512_set1_pd(alphaI); + + /** + * General Algorithm: + * + * alphaRv = alphaR alphaR alphaR alphaR ... + * alphaIv = alphaI alphaI alphaI alphaI ... + * + * xv[0] = x0R x0I x1R x1I ... + * temp[0] = x0I x0R x1I x1R ... + * temp[0] = temp[0] * xv[0] + * = x0I*alphaI x0R*alphaI x1I*alphaI x1R*alphaI ... + * xv[0] = xv[0] * alphaR + temp[0] + * = x0R*alphaR + x0I*alphaI x0I*alphaR + x0R*alphaI + * x1R*alphaR + x1I*alphaI x1I*alphaR + x1R*alphaI ... + */ + + // Processing 48 dcomplex elements per iteration. + for (; (i + 47) < n0; i += 48) + { + __m512d xv[12], temp[12]; + + // Load elements from x vector. + xv[0] = _mm512_loadu_pd(x0); + xv[1] = _mm512_loadu_pd(x0 + n_elem_per_reg); + xv[2] = _mm512_loadu_pd(x0 + 2 * n_elem_per_reg); + xv[3] = _mm512_loadu_pd(x0 + 3 * n_elem_per_reg); + + // Operation: xv -> xv' + // xv = y0R y0I y1R y1I ... + // xv' = y0I y0R y1I y1R ... + temp[0] = _mm512_permute_pd(xv[0], 0x55); + temp[1] = _mm512_permute_pd(xv[1], 0x55); + temp[2] = _mm512_permute_pd(xv[2], 0x55); + temp[3] = _mm512_permute_pd(xv[3], 0x55); + + // Operation: temp = temp * alphaIv + // temp = x0I*alphaI x0R*alphaI x1I*alphaI x1R*alphaI ... + temp[0] = _mm512_mul_pd(alphaIv, temp[0]); + temp[1] = _mm512_mul_pd(alphaIv, temp[1]); + temp[2] = _mm512_mul_pd(alphaIv, temp[2]); + temp[3] = _mm512_mul_pd(alphaIv, temp[3]); + + // Operation: xv[0] = xv[0] * alphaR + temp[0] + // xv[0] = x0R*alphaR + x0I*alphaI x0I*alphaR + x0R*alphaI + // x1R*alphaR + x1I*alphaI x1I*alphaR + x1R*alphaI ... + xv[0] = _mm512_fmaddsub_pd(alphaRv, xv[0], temp[0]); + xv[1] = _mm512_fmaddsub_pd(alphaRv, xv[1], temp[1]); + xv[2] = _mm512_fmaddsub_pd(alphaRv, xv[2], temp[2]); + xv[3] = _mm512_fmaddsub_pd(alphaRv, xv[3], temp[3]); + + // Store result to memory. + _mm512_storeu_pd(x0, xv[0]); + _mm512_storeu_pd(x0 + n_elem_per_reg, xv[1]); + _mm512_storeu_pd(x0 + 2 * n_elem_per_reg, xv[2]); + _mm512_storeu_pd(x0 + 3 * n_elem_per_reg, xv[3]); + + xv[4] = _mm512_loadu_pd(x0 + 4 * n_elem_per_reg); + xv[5] = _mm512_loadu_pd(x0 + 5 * n_elem_per_reg); + xv[6] = _mm512_loadu_pd(x0 + 6 * n_elem_per_reg); + xv[7] = _mm512_loadu_pd(x0 + 7 * n_elem_per_reg); + + temp[4] = _mm512_permute_pd(xv[4], 0x55); + temp[5] = _mm512_permute_pd(xv[5], 0x55); + temp[6] = _mm512_permute_pd(xv[6], 0x55); + temp[7] = _mm512_permute_pd(xv[7], 0x55); + + temp[4] = _mm512_mul_pd(alphaIv, temp[4]); + temp[5] = _mm512_mul_pd(alphaIv, temp[5]); + temp[6] = _mm512_mul_pd(alphaIv, temp[6]); + temp[7] = _mm512_mul_pd(alphaIv, temp[7]); + + xv[4] = _mm512_fmaddsub_pd(alphaRv, xv[4], temp[4]); + xv[5] = _mm512_fmaddsub_pd(alphaRv, xv[5], temp[5]); + xv[6] = _mm512_fmaddsub_pd(alphaRv, xv[6], temp[6]); + xv[7] = _mm512_fmaddsub_pd(alphaRv, xv[7], temp[7]); + + _mm512_storeu_pd(x0 + 4 * n_elem_per_reg, xv[4]); + _mm512_storeu_pd(x0 + 5 * n_elem_per_reg, xv[5]); + _mm512_storeu_pd(x0 + 6 * n_elem_per_reg, xv[6]); + _mm512_storeu_pd(x0 + 7 * n_elem_per_reg, xv[7]); + + xv[8] = _mm512_loadu_pd(x0 + 8 * n_elem_per_reg); + xv[9] = _mm512_loadu_pd(x0 + 9 * n_elem_per_reg); + xv[10] = _mm512_loadu_pd(x0 + 10 * n_elem_per_reg); + xv[11] = _mm512_loadu_pd(x0 + 11 * n_elem_per_reg); + + temp[8] = _mm512_permute_pd(xv[8], 0x55); + temp[9] = _mm512_permute_pd(xv[9], 0x55); + temp[10] = _mm512_permute_pd(xv[10], 0x55); + temp[11] = _mm512_permute_pd(xv[11], 0x55); + + temp[8] = _mm512_mul_pd(alphaIv, temp[8]); + temp[9] = _mm512_mul_pd(alphaIv, temp[9]); + temp[10] = _mm512_mul_pd(alphaIv, temp[10]); + temp[11] = _mm512_mul_pd(alphaIv, temp[11]); + + xv[8] = _mm512_fmaddsub_pd(alphaRv, xv[8], temp[8]); + xv[9] = _mm512_fmaddsub_pd(alphaRv, xv[9], temp[9]); + xv[10] = _mm512_fmaddsub_pd(alphaRv, xv[10], temp[10]); + xv[11] = _mm512_fmaddsub_pd(alphaRv, xv[11], temp[11]); + + _mm512_storeu_pd(x0 + 8 * n_elem_per_reg, xv[8]); + _mm512_storeu_pd(x0 + 9 * n_elem_per_reg, xv[9]); + _mm512_storeu_pd(x0 + 10 * n_elem_per_reg, xv[10]); + _mm512_storeu_pd(x0 + 11 * n_elem_per_reg, xv[11]); + + // Increment x0 vector pointer. + x0 += 12 * n_elem_per_reg; + } + + // Processing 32 dcomplex elements per iteration. + for (; (i + 31) < n0; i += 32) + { + __m512d xv[8], temp[8]; + xv[0] = _mm512_loadu_pd(x0); + xv[1] = _mm512_loadu_pd(x0 + n_elem_per_reg); + xv[2] = _mm512_loadu_pd(x0 + 2 * n_elem_per_reg); + xv[3] = _mm512_loadu_pd(x0 + 3 * n_elem_per_reg); + + temp[0] = _mm512_permute_pd(xv[0], 0x55); + temp[1] = _mm512_permute_pd(xv[1], 0x55); + temp[2] = _mm512_permute_pd(xv[2], 0x55); + temp[3] = _mm512_permute_pd(xv[3], 0x55); + + temp[0] = _mm512_mul_pd(alphaIv, temp[0]); + temp[1] = _mm512_mul_pd(alphaIv, temp[1]); + temp[2] = _mm512_mul_pd(alphaIv, temp[2]); + temp[3] = _mm512_mul_pd(alphaIv, temp[3]); + + xv[0] = _mm512_fmaddsub_pd(alphaRv, xv[0], temp[0]); + xv[1] = _mm512_fmaddsub_pd(alphaRv, xv[1], temp[1]); + xv[2] = _mm512_fmaddsub_pd(alphaRv, xv[2], temp[2]); + xv[3] = _mm512_fmaddsub_pd(alphaRv, xv[3], temp[3]); + + _mm512_storeu_pd(x0, xv[0]); + _mm512_storeu_pd(x0 + n_elem_per_reg, xv[1]); + _mm512_storeu_pd(x0 + 2 * n_elem_per_reg, xv[2]); + _mm512_storeu_pd(x0 + 3 * n_elem_per_reg, xv[3]); + + xv[4] = _mm512_loadu_pd(x0 + 4 * n_elem_per_reg); + xv[5] = _mm512_loadu_pd(x0 + 5 * n_elem_per_reg); + xv[6] = _mm512_loadu_pd(x0 + 6 * n_elem_per_reg); + xv[7] = _mm512_loadu_pd(x0 + 7 * n_elem_per_reg); + + temp[4] = _mm512_permute_pd(xv[4], 0x55); + temp[5] = _mm512_permute_pd(xv[5], 0x55); + temp[6] = _mm512_permute_pd(xv[6], 0x55); + temp[7] = _mm512_permute_pd(xv[7], 0x55); + + temp[4] = _mm512_mul_pd(alphaIv, temp[4]); + temp[5] = _mm512_mul_pd(alphaIv, temp[5]); + temp[6] = _mm512_mul_pd(alphaIv, temp[6]); + temp[7] = _mm512_mul_pd(alphaIv, temp[7]); + + xv[4] = _mm512_fmaddsub_pd(alphaRv, xv[4], temp[4]); + xv[5] = _mm512_fmaddsub_pd(alphaRv, xv[5], temp[5]); + xv[6] = _mm512_fmaddsub_pd(alphaRv, xv[6], temp[6]); + xv[7] = _mm512_fmaddsub_pd(alphaRv, xv[7], temp[7]); + + _mm512_storeu_pd(x0 + 4 * n_elem_per_reg, xv[4]); + _mm512_storeu_pd(x0 + 5 * n_elem_per_reg, xv[5]); + _mm512_storeu_pd(x0 + 6 * n_elem_per_reg, xv[6]); + _mm512_storeu_pd(x0 + 7 * n_elem_per_reg, xv[7]); + + x0 += 8 * n_elem_per_reg; + } + + // Processing 16 dcomplex elements per iteration. + for (; (i + 15) < n0; i += 16) + { + __m512d xv[4], temp[4]; + xv[0] = _mm512_loadu_pd(x0); + xv[1] = _mm512_loadu_pd(x0 + n_elem_per_reg); + xv[2] = _mm512_loadu_pd(x0 + 2 * n_elem_per_reg); + xv[3] = _mm512_loadu_pd(x0 + 3 * n_elem_per_reg); + + temp[0] = _mm512_permute_pd(xv[0], 0x55); + temp[1] = _mm512_permute_pd(xv[1], 0x55); + temp[2] = _mm512_permute_pd(xv[2], 0x55); + temp[3] = _mm512_permute_pd(xv[3], 0x55); + + temp[0] = _mm512_mul_pd(alphaIv, temp[0]); + temp[1] = _mm512_mul_pd(alphaIv, temp[1]); + temp[2] = _mm512_mul_pd(alphaIv, temp[2]); + temp[3] = _mm512_mul_pd(alphaIv, temp[3]); + + xv[0] = _mm512_fmaddsub_pd(alphaRv, xv[0], temp[0]); + xv[1] = _mm512_fmaddsub_pd(alphaRv, xv[1], temp[1]); + xv[2] = _mm512_fmaddsub_pd(alphaRv, xv[2], temp[2]); + xv[3] = _mm512_fmaddsub_pd(alphaRv, xv[3], temp[3]); + + _mm512_storeu_pd(x0, xv[0]); + _mm512_storeu_pd(x0 + n_elem_per_reg, xv[1]); + _mm512_storeu_pd(x0 + 2 * n_elem_per_reg, xv[2]); + _mm512_storeu_pd(x0 + 3 * n_elem_per_reg, xv[3]); + + x0 += 4 * n_elem_per_reg; + } + + // Processing 8 dcomplex elements per iteration. + for (; (i + 7) < n0; i += 8) + { + __m512d xv[2], temp[2]; + xv[0] = _mm512_loadu_pd(x0); + xv[1] = _mm512_loadu_pd(x0 + n_elem_per_reg); + + temp[0] = _mm512_permute_pd(xv[0], 0x55); + temp[1] = _mm512_permute_pd(xv[1], 0x55); + + temp[0] = _mm512_mul_pd(alphaIv, temp[0]); + temp[1] = _mm512_mul_pd(alphaIv, temp[1]); + + xv[0] = _mm512_fmaddsub_pd(alphaRv, xv[0], temp[0]); + xv[1] = _mm512_fmaddsub_pd(alphaRv, xv[1], temp[1]); + + _mm512_storeu_pd(x0, xv[0]); + _mm512_storeu_pd(x0 + n_elem_per_reg, xv[1]); + + x0 += 2 * n_elem_per_reg; + } + + // Processing 4 dcomplex elements per iteration. + for (; (i + 3) < n0; i += 4) + { + __m512d xv, temp; + xv = _mm512_loadu_pd(x0); + + temp = _mm512_permute_pd(xv, 0x55); + + temp = _mm512_mul_pd(alphaIv, temp); + + xv = _mm512_fmaddsub_pd(alphaRv, xv, temp); + + _mm512_storeu_pd(x0, xv); + + x0 += n_elem_per_reg; + } + + // Processing the remainder elements. + if( i < n0 ) + { + // Setting the mask bit based on remaining elements + // Since each dcomplex elements corresponds to 2 doubles + // we need to load and store 2*(n0-i) elements. + + __mmask8 mask = ( 1 << ( 2 * ( n0 - i ) ) ) - 1; + + __m512d xv, temp, zero; + zero = _mm512_setzero_pd(); + + xv = _mm512_mask_loadu_pd( zero, mask, x0 ); + + temp = _mm512_permute_pd( xv, 0x55 ); + + temp = _mm512_mul_pd( alphaIv, temp ); + + xv = _mm512_fmaddsub_pd( alphaRv, xv, temp ); + + _mm512_mask_storeu_pd( x0, mask, xv ); + } + } + else // Non-unit increment. + { + __m128d alphaRv, alphaIv, x_vec, temp; + + alphaRv = _mm_loaddup_pd(&alphaR); + alphaIv = _mm_loaddup_pd(&alphaI); + + for (; i < n0; ++i) + { + x_vec = _mm_loadu_pd(x0); + + temp = _mm_shuffle_pd(x_vec, x_vec, 0x1); + + temp = _mm_mul_pd(alphaIv, temp); + x_vec = _mm_fmaddsub_pd(alphaRv, x_vec, temp); + + _mm_storeu_pd(x0, x_vec); + + x0 += 2 * incx; + } + } +} diff --git a/kernels/zen4/1/bli_setv_zen_int_avx512.c b/kernels/zen4/1/bli_setv_zen_int_avx512.c new file mode 100644 index 0000000000..ba9222edb3 --- /dev/null +++ b/kernels/zen4/1/bli_setv_zen_int_avx512.c @@ -0,0 +1,466 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "immintrin.h" +#include "blis.h" + +// ----------------------------------------------------------------------------- + +void bli_ssetv_zen_int_avx512 + ( + conj_t conjalpha, + dim_t n, + float* restrict alpha, + float* restrict x, inc_t incx, + cntx_t* restrict cntx + ) +{ + // Declaring and initializing local variables and pointers + const dim_t num_elem_per_reg = 16; + dim_t i = 0; + float *x0 = x; + + // If the vector dimension is zero return early. + if ( bli_zero_dim1( n ) ) return; + + // Handling unit strides + if ( incx == 1 ) + { + __m512 alphav; + + // Broadcast alpha to the register + alphav = _mm512_set1_ps( *alpha ); + + // The condition n & ~0x1FF => n & 0xFFFFFE00 + // This sets the lower 9 bits to 0 and results in multiples of 512 + // Thus, we iterate in blocks of 512 elements + // Fringe loops have similar conditions to set their masks(256, 128, ...) + for ( i = 0; i < (n & (~0x1FF)); i += 512 ) + { + _mm512_storeu_ps(x0 + num_elem_per_reg * 0, alphav); + _mm512_storeu_ps(x0 + num_elem_per_reg * 1, alphav); + _mm512_storeu_ps(x0 + num_elem_per_reg * 2, alphav); + _mm512_storeu_ps(x0 + num_elem_per_reg * 3, alphav); + _mm512_storeu_ps(x0 + num_elem_per_reg * 4, alphav); + _mm512_storeu_ps(x0 + num_elem_per_reg * 5, alphav); + _mm512_storeu_ps(x0 + num_elem_per_reg * 6, alphav); + _mm512_storeu_ps(x0 + num_elem_per_reg * 7, alphav); + _mm512_storeu_ps(x0 + num_elem_per_reg * 8, alphav); + _mm512_storeu_ps(x0 + num_elem_per_reg * 9, alphav); + _mm512_storeu_ps(x0 + num_elem_per_reg * 10, alphav); + _mm512_storeu_ps(x0 + num_elem_per_reg * 11, alphav); + _mm512_storeu_ps(x0 + num_elem_per_reg * 12, alphav); + _mm512_storeu_ps(x0 + num_elem_per_reg * 13, alphav); + _mm512_storeu_ps(x0 + num_elem_per_reg * 14, alphav); + _mm512_storeu_ps(x0 + num_elem_per_reg * 15, alphav); + + _mm512_storeu_ps(x0 + num_elem_per_reg * 16, alphav); + _mm512_storeu_ps(x0 + num_elem_per_reg * 17, alphav); + _mm512_storeu_ps(x0 + num_elem_per_reg * 18, alphav); + _mm512_storeu_ps(x0 + num_elem_per_reg * 19, alphav); + _mm512_storeu_ps(x0 + num_elem_per_reg * 20, alphav); + _mm512_storeu_ps(x0 + num_elem_per_reg * 21, alphav); + _mm512_storeu_ps(x0 + num_elem_per_reg * 22, alphav); + _mm512_storeu_ps(x0 + num_elem_per_reg * 23, alphav); + _mm512_storeu_ps(x0 + num_elem_per_reg * 24, alphav); + _mm512_storeu_ps(x0 + num_elem_per_reg * 25, alphav); + _mm512_storeu_ps(x0 + num_elem_per_reg * 26, alphav); + _mm512_storeu_ps(x0 + num_elem_per_reg * 27, alphav); + _mm512_storeu_ps(x0 + num_elem_per_reg * 28, alphav); + _mm512_storeu_ps(x0 + num_elem_per_reg * 29, alphav); + _mm512_storeu_ps(x0 + num_elem_per_reg * 30, alphav); + _mm512_storeu_ps(x0 + num_elem_per_reg * 31, alphav); + + x0 += 512; + } + for ( ; i < (n & (~0xFF)); i += 256 ) + { + _mm512_storeu_ps(x0 + num_elem_per_reg * 0, alphav); + _mm512_storeu_ps(x0 + num_elem_per_reg * 1, alphav); + _mm512_storeu_ps(x0 + num_elem_per_reg * 2, alphav); + _mm512_storeu_ps(x0 + num_elem_per_reg * 3, alphav); + _mm512_storeu_ps(x0 + num_elem_per_reg * 4, alphav); + _mm512_storeu_ps(x0 + num_elem_per_reg * 5, alphav); + _mm512_storeu_ps(x0 + num_elem_per_reg * 6, alphav); + _mm512_storeu_ps(x0 + num_elem_per_reg * 7, alphav); + _mm512_storeu_ps(x0 + num_elem_per_reg * 8, alphav); + _mm512_storeu_ps(x0 + num_elem_per_reg * 9, alphav); + _mm512_storeu_ps(x0 + num_elem_per_reg * 10, alphav); + _mm512_storeu_ps(x0 + num_elem_per_reg * 11, alphav); + _mm512_storeu_ps(x0 + num_elem_per_reg * 12, alphav); + _mm512_storeu_ps(x0 + num_elem_per_reg * 13, alphav); + _mm512_storeu_ps(x0 + num_elem_per_reg * 14, alphav); + _mm512_storeu_ps(x0 + num_elem_per_reg * 15, alphav); + + x0 += 256; + } + for ( ; i < (n & (~0x7F)); i += 128 ) + { + _mm512_storeu_ps(x0 + num_elem_per_reg * 0, alphav); + _mm512_storeu_ps(x0 + num_elem_per_reg * 1, alphav); + _mm512_storeu_ps(x0 + num_elem_per_reg * 2, alphav); + _mm512_storeu_ps(x0 + num_elem_per_reg * 3, alphav); + _mm512_storeu_ps(x0 + num_elem_per_reg * 4, alphav); + _mm512_storeu_ps(x0 + num_elem_per_reg * 5, alphav); + _mm512_storeu_ps(x0 + num_elem_per_reg * 6, alphav); + _mm512_storeu_ps(x0 + num_elem_per_reg * 7, alphav); + + x0 += 128; + } + for ( ; i < (n & (~0x3F)); i += 64 ) + { + _mm512_storeu_ps(x0 + num_elem_per_reg * 0, alphav); + _mm512_storeu_ps(x0 + num_elem_per_reg * 1, alphav); + _mm512_storeu_ps(x0 + num_elem_per_reg * 2, alphav); + _mm512_storeu_ps(x0 + num_elem_per_reg * 3, alphav); + + x0 += 64; + } + for ( ; i < (n & (~0x1F)); i += 32 ) + { + _mm512_storeu_ps(x0 + num_elem_per_reg * 0, alphav); + _mm512_storeu_ps(x0 + num_elem_per_reg * 1, alphav); + + x0 += 32; + } + for ( ; i < (n & (~0x0F)); i += 16 ) + { + _mm512_storeu_ps(x0 + num_elem_per_reg * 0, alphav); + x0 += 16; + } + if (i < n) + { + // Setting the mask register to store the remaining elements + __mmask16 m_mask = ( 1 << (n - i)) - 1; + _mm512_mask_storeu_ps(x0 + num_elem_per_reg * 0, m_mask, alphav); + } + } + else + { + // Scalar loop to handle non-unit strides + for ( dim_t i = 0; i < n; ++i ) + { + *x0 = *alpha; + x0 += incx; + } + } +} + +void bli_dsetv_zen_int_avx512 + ( + conj_t conjalpha, + dim_t n, + double* restrict alpha, + double* restrict x, inc_t incx, + cntx_t* restrict cntx + ) +{ + // Declaring and initializing local variables and pointers + const dim_t num_elem_per_reg = 8; + dim_t i = 0; + double *x0 = x; + + // If the vector dimension is zero return early. + if ( bli_zero_dim1( n ) ) return; + + if ( incx == 1 ) + { + __m512d alphav; + + // Broadcast alpha to the register + alphav = _mm512_set1_pd( *alpha ); + + // The condition n & ~0xFF => n & 0xFFFFFF00 + // This sets the lower 8 bits to 0 and results in multiples of 256 + // Thus, we iterate in blocks of 256 elements + // Fringe loops have similar conditions to set their masks(128, 64, ...) + for ( i = 0; i < (n & (~0xFF)); i += 256 ) + { + _mm512_storeu_pd(x0 + num_elem_per_reg * 0, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 1, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 2, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 3, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 4, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 5, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 6, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 7, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 8, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 9, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 10, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 11, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 12, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 13, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 14, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 15, alphav); + + _mm512_storeu_pd(x0 + num_elem_per_reg * 16, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 17, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 18, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 19, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 20, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 21, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 22, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 23, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 24, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 25, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 26, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 27, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 28, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 29, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 30, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 31, alphav); + + x0 += 256; + } + for ( ; i < (n & (~0x7F)); i += 128 ) + { + _mm512_storeu_pd(x0 + num_elem_per_reg * 0, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 1, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 2, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 3, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 4, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 5, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 6, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 7, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 8, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 9, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 10, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 11, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 12, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 13, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 14, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 15, alphav); + + x0 += 128; + } + for ( ; i < (n & (~0x3F)); i += 64 ) + { + _mm512_storeu_pd(x0 + num_elem_per_reg * 0, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 1, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 2, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 3, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 4, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 5, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 6, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 7, alphav); + + x0 += 64; + } + for ( ; i < (n & (~0x1F)); i += 32 ) + { + _mm512_storeu_pd(x0 + num_elem_per_reg * 0, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 1, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 2, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 3, alphav); + + x0 += 32; + } + for ( ; i < (n & (~0x0F)); i += 16 ) + { + _mm512_storeu_pd(x0 + num_elem_per_reg * 0, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 1, alphav); + + x0 += 16; + } + for ( ; i < (n & (~0x07)); i += 8 ) + { + _mm512_storeu_pd(x0 + num_elem_per_reg * 0, alphav); + x0 += 8; + } + if (i < n) + { + __mmask8 m_mask = ( 1 << (n - i)) - 1; + _mm512_mask_storeu_pd(x0 + num_elem_per_reg * 0, m_mask, alphav); + } + } + else + { + // Scalar loop to handle non-unit-strides + for ( i = 0; i < n; ++i ) + { + *x0 = *alpha; + x0 += incx; + } + } +} + +void bli_zsetv_zen_int_avx512 + ( + conj_t conjalpha, + dim_t n, + dcomplex* restrict alpha, + dcomplex* restrict x, inc_t incx, + cntx_t* restrict cntx + ) +{ + // Declaring and initializing local variables and pointers + const dim_t num_elem_per_reg = 8; + dim_t i = 0; + double *x0 = (double *)x; + + // If the vector dimension is zero return early. + if ( bli_zero_dim1( n ) ) return; + + // Handle conjugation of alpha + if ( bli_is_conj( conjalpha ) ) alpha->imag = -alpha->imag; + + if ( incx == 1 ) + { + __m512d alphaRv, alphaIv; + __m512d alphav; + + // Broadcast alpha(real and imag) to the separate registers + alphaRv = _mm512_set1_pd((double)(alpha->real)); + alphaIv = _mm512_set1_pd((double)(alpha->imag)); + + // Unpack and store it in interleaved format + alphav = _mm512_unpacklo_pd(alphaRv, alphaIv); + + // The condition n & ~0x7F => n & 0xFFFFFE80 + // This sets the lower 7 bits to 0 and results in multiples of 128 + // Thus, we iterate in blocks of 128 elements + // Fringe loops have similar conditions to set their masks(64, 32, ...) + for ( ; i < (n & (~0x7F)); i += 128 ) + { + _mm512_storeu_pd(x0 + num_elem_per_reg * 0, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 1, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 2, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 3, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 4, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 5, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 6, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 7, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 8, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 9, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 10, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 11, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 12, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 13, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 14, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 15, alphav); + + _mm512_storeu_pd(x0 + num_elem_per_reg * 16, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 17, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 18, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 19, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 20, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 21, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 22, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 23, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 24, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 25, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 26, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 27, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 28, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 29, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 30, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 31, alphav); + + x0 += 256; + } + for ( ; i < (n & (~0x3F)); i += 64 ) + { + _mm512_storeu_pd(x0 + num_elem_per_reg * 0, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 1, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 2, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 3, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 4, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 5, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 6, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 7, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 8, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 9, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 10, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 11, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 12, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 13, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 14, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 15, alphav); + + x0 += 128; + } + for ( ; i < (n & (~0x1F)); i += 32 ) + { + _mm512_storeu_pd(x0 + num_elem_per_reg * 0, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 1, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 2, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 3, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 4, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 5, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 6, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 7, alphav); + + x0 += 64; + } + for ( ; i < (n & (~0x0F)); i += 16 ) + { + _mm512_storeu_pd(x0 + num_elem_per_reg * 0, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 1, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 2, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 3, alphav); + + x0 += 32; + } + for ( ; i < (n & (~0x07)); i += 8 ) + { + _mm512_storeu_pd(x0 + num_elem_per_reg * 0, alphav); + _mm512_storeu_pd(x0 + num_elem_per_reg * 1, alphav); + + x0 += 16; + } + for ( ; i < (n & (~0x03)); i += 4 ) + { + _mm512_storeu_pd(x0 + num_elem_per_reg * 0, alphav); + x0 += 8; + } + if (i < n) + { + // Set the mask to load the remaining elements + // One double complex elements corresponds to two doubles in memory + __mmask8 m_mask = ( 1 << 2*(n - i)) - 1; + _mm512_mask_storeu_pd(x0 + num_elem_per_reg * 0, m_mask, alphav); + } + } + else + { + __m128d alphav; + alphav = _mm_loadu_pd((const double*)alpha); + + for( ; i < n; i += 1 ) + { + _mm_storeu_pd(x0, alphav); + x0 += 2 * incx; + } + } +} diff --git a/kernels/zen4/1f/bli_axpyf_zen_int_avx512.c b/kernels/zen4/1f/bli_axpyf_zen_int_avx512.c new file mode 100644 index 0000000000..02f894ef26 --- /dev/null +++ b/kernels/zen4/1f/bli_axpyf_zen_int_avx512.c @@ -0,0 +1,3290 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "immintrin.h" +#include "blis.h" + +#if defined __clang__ + #define UNROLL_LOOP_FULL() _Pragma("clang loop unroll(full)") +#elif defined __GNUC__ + #define UNROLL_LOOP_FULL() _Pragma("GCC unroll 32") +#else + #define UNROLL_LOOP_FULL() +#endif + +#define GENTFUNC_AXPYF(FUSE_FACTOR) \ + void PASTEMAC2(daxpyf_zen_int, FUSE_FACTOR, _avx512) \ + ( \ + conj_t conja, \ + conj_t conjx, \ + dim_t m, \ + dim_t b_n, \ + double* restrict alpha, \ + double* restrict a, inc_t inca, inc_t lda, \ + double* restrict x, inc_t incx, \ + double* restrict y0, inc_t incy, \ + cntx_t* restrict cntx \ + ) \ +{ \ + const dim_t fuse_fac = FUSE_FACTOR; \ + const dim_t n_elem_per_reg = 8; \ + dim_t i = 0; \ + \ + __m512d chi[fuse_fac]; \ + __m512d av[1]; \ + __m512d yv[1]; \ + double* as[fuse_fac] __attribute__((aligned(64))); \ + double* y = y0; \ + \ + /* If either dimension is zero, or if alpha is zero, return early.*/ \ + if ( bli_zero_dim2( m, b_n ) || bli_deq0( *alpha ) ) return; \ + \ + /* If b_n is not equal to the fusing factor, then perform the entire + operation as a loop over axpyv. */ \ + if ( b_n != fuse_fac ) \ + { \ + daxpyv_ker_ft f = bli_daxpyv_zen_int_avx512; \ + \ + for ( i = 0; i < b_n; ++i ) \ + { \ + double* a1 = a + (0 )*inca + (i )*lda; \ + double* chi1 = x + (i )*incx; \ + double* y1 = y + (0 )*incy; \ + double alphavchi1; \ + \ + bli_dcopycjs( conjx, *chi1, alphavchi1 ); \ + bli_dscals( *alpha, alphavchi1 ); \ + \ + f \ + ( \ + conja, \ + m, \ + &alphavchi1, \ + a1, inca, \ + y1, incy, \ + cntx \ + ); \ + } \ + return; \ + } \ + \ + /* At this point, we know that b_n is exactly equal to the fusing factor.*/ \ + UNROLL_LOOP_FULL() \ + for (dim_t ii = 0; ii < fuse_fac; ++ii) \ + { \ + as[ii] = a + (ii * lda); \ + chi[ii] = _mm512_set1_pd( (*alpha) * (*(x + ii * incx)) ); \ + } \ + /* If there are vectorized iterations, perform them with vector + instructions.*/ \ + if ( inca == 1 && incy == 1 ) \ + { \ + __mmask8 m_mask; \ + m_mask = (1 << 8) - 1; \ + for ( ; i < m; i += 8) \ + { \ + if ( (m - i) < 8) m_mask = (1 << (m - i)) - 1; \ + yv[0] = _mm512_mask_loadu_pd( chi[0], m_mask, y ); \ + \ + UNROLL_LOOP_FULL() \ + for(int ii = 0; ii < fuse_fac; ++ii) \ + { \ + av[0] = _mm512_maskz_loadu_pd( m_mask, as[ii] ); \ + as[ii] += n_elem_per_reg; \ + yv[0] = _mm512_fmadd_pd( av[0], chi[ii], yv[0]); \ + } \ + _mm512_mask_storeu_pd( (double *)(y ), m_mask, yv[0] ); \ + \ + y += n_elem_per_reg; \ + } \ + } \ + else \ + { \ + double yc = *y; \ + double chi_s[fuse_fac]; \ + \ + UNROLL_LOOP_FULL() \ + for (dim_t ii = 0; ii < fuse_fac; ++ii) \ + { \ + chi_s[ii] = *(x + ii * incx) * *alpha; \ + } \ + for ( i = 0; (i + 0) < m ; ++i ) \ + { \ + yc = *y; \ + UNROLL_LOOP_FULL() \ + for (dim_t ii = 0 ; ii < fuse_fac; ++ii) \ + { \ + yc += chi_s[ii] * (*as[ii]); \ + as[ii] += inca; \ + } \ + *y = yc; \ + y += incy; \ + } \ + } \ +} \ + +// Generate axpyf kernels with various fuse factors. +GENTFUNC_AXPYF(6) +GENTFUNC_AXPYF(16) +GENTFUNC_AXPYF(32) + +// Wrapper for DAXPYF to redirect to kernels with lower fuse factors. +void bli_daxpyf_zen_int_avx512 + ( + conj_t conja, + conj_t conjx, + dim_t m, + dim_t b_n, + double* restrict alpha, + double* restrict a, inc_t inca, inc_t lda, + double* restrict x, inc_t incx, + double* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + dim_t fuse_fac = 32; + + if ( b_n < fuse_fac ) + { + double* a1 = a; + double* chi1 = x; + double* y1 = y; + double alphavchi1; + + if ( b_n >= 16 ) + { + bli_daxpyf_zen_int16_avx512 + ( + conja, + conjx, + m, + (dim_t)16, + alpha, + a1, inca, lda, + chi1, incx, + y1, incy, + cntx + ); + + a1 += 16*lda; + chi1 += 16*incx; + b_n -= 16; + } + + if ( b_n >= 8 ) + { + bli_daxpyf_zen_int8_avx512 + ( + conja, + conjx, + m, + (dim_t)8, + alpha, + a1, inca, lda, + chi1, incx, + y1, incy, + cntx + ); + + a1 += 8*lda; + chi1 += 8*incx; + b_n -= 8; + } + + if ( b_n >= 6 ) + { + bli_daxpyf_zen_int6_avx512 + ( + conja, + conjx, + m, + (dim_t)6, + alpha, + a1, inca, lda, + chi1, incx, + y1, incy, + cntx + ); + + a1 += 6*lda; + chi1 += 6*incx; + b_n -= 6; + } + + if ( b_n >= 4 ) + { + bli_daxpyf_zen_int4_avx512 + ( + conja, + conjx, + m, + (dim_t)4, + alpha, + a1, inca, lda, + chi1, incx, + y1, incy, + cntx + ); + + a1 += 4*lda; + chi1 += 4*incx; + b_n -= 4; + } + + if ( b_n >= 2 ) + { + bli_daxpyf_zen_int2_avx512 + ( + conja, + conjx, + m, + (dim_t)2, + alpha, + a1, inca, lda, + chi1, incx, + y1, incy, + cntx + ); + + a1 += 2*lda; + chi1 += 2*incx; + b_n -= 2; + } + + if ( b_n == 1 ) + { + daxpyv_ker_ft f = bli_daxpyv_zen_int_avx512; + + bli_dcopycjs( conjx, *chi1, alphavchi1 ); + bli_dscals( *alpha, alphavchi1 ); + + f + ( + conja, + m, + &alphavchi1, + a1, inca, + y1, incy, + cntx + ); + + return; + } + } + else if ( b_n > fuse_fac ) + { + daxpyv_ker_ft f = bli_daxpyv_zen_int_avx512; + + for ( dim_t i = 0; i < b_n; ++i ) + { + double* a1 = a + (0 )*inca + (i )*lda; + double* chi1 = x + (i )*incx; + double* y1 = y + (0 )*incy; + double alphavchi1; + + bli_dcopycjs( conjx, *chi1, alphavchi1 ); + bli_dscals( *alpha, alphavchi1 ); + + f + ( + conja, + m, + &alphavchi1, + a1, inca, + y1, incy, + cntx + ); + } + return; + } + else // if ( b_n == fuse_fac ) + { + bli_daxpyf_zen_int32_avx512 + ( + conja, + conjx, + m, + b_n, + alpha, + a, inca, lda, + x, incx, + y, incy, + cntx + ); + } +} + +#ifdef BLIS_ENABLE_OPENMP +/* +* Multihreaded AVX512 DAXPYF kernel with fuse factor 32 +*/ +void bli_daxpyf_zen_int32_avx512_mt + ( + conj_t conja, + conj_t conjx, + dim_t m, + dim_t b_n, + double* restrict alpha, + double* restrict a, inc_t inca, inc_t lda, + double* restrict x, inc_t incx, + double* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + /* + Initializing the number of thread to one + to avoid compiler warnings + */ + dim_t nt = 1; + /* + For the given problem size and architecture, the function + returns the optimum number of threads with AOCL dynamic enabled + else it returns the number of threads requested by the user. + */ + bli_nthreads_l1f + ( + BLIS_AXPYF_KER, + BLIS_DOUBLE, + BLIS_DOUBLE, + bli_arch_query_id(), + m, + &nt + ); + + _Pragma("omp parallel num_threads(nt)") + { + const dim_t tid = omp_get_thread_num(); + const dim_t nt_real = omp_get_num_threads(); + // if num threads requested and num thread available + // is not same then use single thread + if( nt_real != nt ) + { + if( tid == 0 ) + { + bli_daxpyf_zen_int32_avx512 + ( + conja, + conjx, + m, + b_n, + alpha, + a, + inca, + lda, + x, + incx, + y, + incy, + cntx + ); + } + } + else + { + dim_t job_per_thread, offset; + + // Obtain the job-size and region for compute + // Calculate y_start and a_start for current thread + bli_normfv_thread_partition( m, nt_real, &offset, &job_per_thread, 32, incy, tid ); + double* restrict y_start = y + offset; + bli_normfv_thread_partition( m, nt_real, &offset, &job_per_thread, 32, inca, tid ); + double* restrict a_start = a + offset; + + // call axpyf kernel + bli_daxpyf_zen_int32_avx512 + ( + conja, + conjx, + job_per_thread, + b_n, + alpha, + a_start, + inca, + lda, + x, + incx, + y_start, + incy, + cntx + ); + } + } +} +#endif + + +void bli_zaxpyf_zen_int_2_avx512 + ( + conj_t conja, + conj_t conjx, + dim_t m, + dim_t b_n, + dcomplex* restrict alpha, + dcomplex* restrict a, inc_t inca, inc_t lda, + dcomplex* restrict x, inc_t incx, + dcomplex* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + dim_t fuse_fac = 2; + + // If either dimension is zero, or if alpha is zero, return early. + if ( bli_zero_dim2( m, b_n ) || bli_zeq0( *alpha ) ) return; + + // If b_n is not equal to the fusing factor, then perform the entire + // operation as a sequence of calls to zaxpyf kernels, with fuse-factor + // 4 and 2 and a single call to zaxpyv, based on the need. + if ( b_n != fuse_fac ) + { + dcomplex *a1 = a; + dcomplex *chi1 = x; + dcomplex *y1 = y; + dcomplex alpha_chi1; + + // Vectorization of alpha scaling of X + __m128d x_vec, alpha_real, alpha_imag, temp[2]; + alpha_real = _mm_loaddup_pd((double *)alpha); + alpha_imag = _mm_loaddup_pd((double *)alpha + 1); + + x_vec = _mm_loadu_pd((double *)chi1); + + if ( bli_is_conj( conjx ) ) + { + __m128d conj_set; + conj_set = _mm_set_pd(-0.0, 0.0); + + x_vec = _mm_xor_pd(conj_set, x_vec); + } + + temp[0] = _mm_mul_pd(x_vec, alpha_real); + temp[1] = _mm_mul_pd(x_vec, alpha_imag); + + temp[1] = _mm_permute_pd(temp[1], 0b01); + + temp[0] = _mm_addsub_pd(temp[0], temp[1]); + + _mm_storeu_pd((double *)&alpha_chi1, temp[0]); + + bli_zaxpyv_zen_int_avx512 + ( + conja, + m, + &alpha_chi1, + a1, inca, + y1, incy, + cntx + ); + + return; + } + + // Declaring and initializing the iterator and pointers + dim_t i = 0; + + double *a_ptr[2]; + double *y0 = (double *)y; + + a_ptr[0] = (double *)a; + a_ptr[1] = (double *)(a + 1 * lda); + + /* Alpha scaling of X can be vectorized + irrespective of the incx and should + be avoided when alpha is 1 */ + __m128d x_vec[2]; + + x_vec[0] = _mm_loadu_pd((double *)(x + 0 * incx)); + x_vec[1] = _mm_loadu_pd((double *)(x + 1 * incx)); + + if ( bli_is_conj( conjx ) ) + { + __m128d conj_set; + conj_set = _mm_set_pd(-0.0, 0.0); + + // The sequence of xor operations flip the sign bit + // of imaginary components in X vector + x_vec[0] = _mm_xor_pd(conj_set, x_vec[0]); + x_vec[1] = _mm_xor_pd(conj_set, x_vec[1]); + } + + // Special case handling when alpha == -1 + 0i + if( alpha->real == -1.0 && alpha->imag == 0.0 ) + { + __m128d zero_reg = _mm_setzero_pd(); + + x_vec[0] = _mm_sub_pd(zero_reg, x_vec[0]); + x_vec[1] = _mm_sub_pd(zero_reg, x_vec[1]); + } + // General case of scaling with alpha + else if (!(bli_zeq1(*alpha))) + { + __m128d alpha_real, alpha_imag, temp[2]; + alpha_real = _mm_loaddup_pd((double *)alpha); + alpha_imag = _mm_loaddup_pd(((double *)alpha) + 1); + + // Scaling with imaginary part of alpha + temp[0] = _mm_mul_pd(x_vec[0], alpha_imag); + temp[1] = _mm_mul_pd(x_vec[1], alpha_imag); + + // Scaling with real part of alpha + x_vec[0] = _mm_mul_pd(x_vec[0], alpha_real); + x_vec[1] = _mm_mul_pd(x_vec[1], alpha_real); + + // Permuting the registers to get the following pattern + // t[0] : xI0*alphaI + // xR0*alphaI, and so on + temp[0] = _mm_permute_pd(temp[0], 0x01); + temp[1] = _mm_permute_pd(temp[1], 0x01); + + // Addsub to complete the complex arithmetic as such: + // x_vec[0] : xR0*alphaR - xI0*alphaI + // xI0*alphaR + xR0*alphaI, and so on + x_vec[0] = _mm_addsub_pd(x_vec[0], temp[0]); + x_vec[1] = _mm_addsub_pd(x_vec[1], temp[1]); + } + + if ( (inca == 1) && (incy == 1) ) + { + // Temporary registers to store permuted alpha*X values + __m128d temp[2]; + + temp[0] = _mm_shuffle_pd(x_vec[0], x_vec[0], 0x01); + temp[1] = _mm_shuffle_pd(x_vec[1], x_vec[1], 0x01); + + // Declaring 4 registers, for re-use over the loops + // alpha_x_real[0] = xR0*alphaR xR0*alphaR ... + // alpah_x_imag[0] = xI0*alphaI xI0*alphaI ... + __m512d alpha_x_real[2], alpha_x_imag[2]; + + alpha_x_real[0] = _mm512_broadcastsd_pd(x_vec[0]); + alpha_x_real[1] = _mm512_broadcastsd_pd(x_vec[1]); + + alpha_x_imag[0] = _mm512_broadcastsd_pd(temp[0]); + alpha_x_imag[1] = _mm512_broadcastsd_pd(temp[1]); + + // Registers to load A, accumulate real and imag scaling separately + __m512d a_vec[2]; + __m512d real_acc, imag_acc, y_vec; + __m512d zero_reg = _mm512_setzero_pd(); + + // Execute the loops is m >= 4(AVX-512 unmasked code-section) + if( m >= 4 ) + { + if ( bli_is_noconj(conja) ) + { + for (; (i + 7) < m; i += 8) + { + // Load first 4 elements from first 4 columns of A + a_vec[0] = _mm512_loadu_pd(a_ptr[0]); + a_vec[1] = _mm512_loadu_pd(a_ptr[1]); + // Multiply the loaded columns of A by alpha*X(real and imag) + real_acc = _mm512_mul_pd(a_vec[0], alpha_x_real[0]); + imag_acc = _mm512_mul_pd(a_vec[0], alpha_x_imag[0]); + + real_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_real[1], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_imag[1], imag_acc); + + // Load first 4 elements of Y vector + y_vec = _mm512_loadu_pd(y0); + + // Permute and reduce the complex and real parts + imag_acc = _mm512_permute_pd(imag_acc, 0x55); + imag_acc = _mm512_fmaddsub_pd(zero_reg, zero_reg, imag_acc); + real_acc = _mm512_add_pd(real_acc, imag_acc); + + y_vec = _mm512_add_pd(y_vec, real_acc); + + // Store onto Y vector + _mm512_storeu_pd(y0, y_vec); + + // Load next 4 elements from first 4 columns of A + a_vec[0] = _mm512_loadu_pd(a_ptr[0] + 8); + a_vec[1] = _mm512_loadu_pd(a_ptr[1] + 8); + + // Multiply the loaded columns of A by alpha*X(real and imag) + real_acc = _mm512_mul_pd(a_vec[0], alpha_x_real[0]); + imag_acc = _mm512_mul_pd(a_vec[0], alpha_x_imag[0]); + + real_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_real[1], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_imag[1], imag_acc); + + // Load next 4 elements of Y vector + y_vec = _mm512_loadu_pd(y0 + 8); + + // Permute and reduce the complex and real parts + imag_acc = _mm512_permute_pd(imag_acc, 0x55); + imag_acc = _mm512_fmaddsub_pd(zero_reg, zero_reg, imag_acc); + real_acc = _mm512_add_pd(real_acc, imag_acc); + + y_vec = _mm512_add_pd(y_vec, real_acc); + + // Store onto Y vector + _mm512_storeu_pd(y0 + 8, y_vec); + + y0 += 16; + a_ptr[0] += 16; + a_ptr[1] += 16; + } + + for (; (i + 3) < m; i += 4) + { + // Load first 4 elements from first 4 columns of A + a_vec[0] = _mm512_loadu_pd(a_ptr[0]); + a_vec[1] = _mm512_loadu_pd(a_ptr[1]); + + // Multiply the loaded columns of A by alpha*X(real and imag) + real_acc = _mm512_mul_pd(a_vec[0], alpha_x_real[0]); + imag_acc = _mm512_mul_pd(a_vec[0], alpha_x_imag[0]); + + real_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_real[1], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_imag[1], imag_acc); + + // Load first 4 elements of Y vector + y_vec = _mm512_loadu_pd(y0); + + // Permute and reduce the complex and real parts + imag_acc = _mm512_permute_pd(imag_acc, 0x55); + imag_acc = _mm512_fmaddsub_pd(zero_reg, zero_reg, imag_acc); + real_acc = _mm512_add_pd(real_acc, imag_acc); + + y_vec = _mm512_add_pd(y_vec, real_acc); + + // Store onto Y vector + _mm512_storeu_pd(y0, y_vec); + + y0 += 8; + a_ptr[0] += 8; + a_ptr[1] += 8; + } + } + else + { + for (; (i + 7) < m; i += 8) + { + // Load first 4 elements from first 4 columns of A + a_vec[0] = _mm512_loadu_pd(a_ptr[0]); + a_vec[1] = _mm512_loadu_pd(a_ptr[1]); + + // Multiply the loaded columns of A by alpha*X(real and imag) + real_acc = _mm512_mul_pd(a_vec[0], alpha_x_real[0]); + imag_acc = _mm512_mul_pd(a_vec[0], alpha_x_imag[0]); + + real_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_real[1], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_imag[1], imag_acc); + + // Load first 4 elements of Y vector + y_vec = _mm512_loadu_pd(y0); + + // Permute and reduce the complex and real parts + imag_acc = _mm512_permute_pd(imag_acc, 0x55); + real_acc = _mm512_fmsubadd_pd(zero_reg, zero_reg, real_acc); + real_acc = _mm512_add_pd(real_acc, imag_acc); + + y_vec = _mm512_add_pd(y_vec, real_acc); + + // Store onto Y vector + _mm512_storeu_pd(y0, y_vec); + + // Load next 4 elements from first 4 columns of A + a_vec[0] = _mm512_loadu_pd(a_ptr[0] + 8); + a_vec[1] = _mm512_loadu_pd(a_ptr[1] + 8); + + // Multiply the loaded columns of A by alpha*X(real and imag) + real_acc = _mm512_mul_pd(a_vec[0], alpha_x_real[0]); + imag_acc = _mm512_mul_pd(a_vec[0], alpha_x_imag[0]); + + real_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_real[1], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_imag[1], imag_acc); + + // Load next 4 elements of Y vector + y_vec = _mm512_loadu_pd(y0 + 8); + + // Permute and reduce the complex and real parts + imag_acc = _mm512_permute_pd(imag_acc, 0x55); + real_acc = _mm512_fmsubadd_pd(zero_reg, zero_reg, real_acc); + real_acc = _mm512_add_pd(real_acc, imag_acc); + + y_vec = _mm512_add_pd(y_vec, real_acc); + + // Store onto Y vector + _mm512_storeu_pd(y0 + 8, y_vec); + + y0 += 16; + a_ptr[0] += 16; + a_ptr[1] += 16; + } + + for (; (i + 3) < m; i += 4) + { + // Load first 4 elements from first 4 columns of A + a_vec[0] = _mm512_loadu_pd(a_ptr[0]); + a_vec[1] = _mm512_loadu_pd(a_ptr[1]); + + // Multiply the loaded columns of A by alpha*X(real and imag) + real_acc = _mm512_mul_pd(a_vec[0], alpha_x_real[0]); + imag_acc = _mm512_mul_pd(a_vec[0], alpha_x_imag[0]); + + real_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_real[1], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_imag[1], imag_acc); + + // Load first 4 elements of Y vector + y_vec = _mm512_loadu_pd(y0); + + // Permute and reduce the complex and real parts + imag_acc = _mm512_permute_pd(imag_acc, 0x55); + real_acc = _mm512_fmsubadd_pd(zero_reg, zero_reg, real_acc); + real_acc = _mm512_add_pd(real_acc, imag_acc); + + y_vec = _mm512_add_pd(y_vec, real_acc); + + // Store onto Y vector + _mm512_storeu_pd(y0, y_vec); + + y0 += 8; + a_ptr[0] += 8; + a_ptr[1] += 8; + } + } + } + if( i < m ) + { + __mmask8 m_mask = (1 << 2*(m - i)) - 1; + if( bli_is_noconj(conja) ) + { + // Load remaining elements from first 4 columns of A + a_vec[0] = _mm512_maskz_loadu_pd(m_mask, a_ptr[0]); + a_vec[1] = _mm512_maskz_loadu_pd(m_mask, a_ptr[1]); + + // Multiply the loaded columns of A by alpha*X(real and imag) + real_acc = _mm512_mul_pd(a_vec[0], alpha_x_real[0]); + imag_acc = _mm512_mul_pd(a_vec[0], alpha_x_imag[0]); + + real_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_real[1], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_imag[1], imag_acc); + + // Load remaining elements of Y vector + y_vec = _mm512_maskz_loadu_pd(m_mask, y0); + + // Permute and reduce the complex and real parts + imag_acc = _mm512_permute_pd(imag_acc, 0x55); + imag_acc = _mm512_fmaddsub_pd(zero_reg, zero_reg, imag_acc); + real_acc = _mm512_add_pd(real_acc, imag_acc); + + y_vec = _mm512_add_pd(y_vec, real_acc); + + // Store onto Y vector + _mm512_mask_storeu_pd(y0, m_mask, y_vec); + } + else + { + // Load remaining elements from first 4 columns of A + a_vec[0] = _mm512_maskz_loadu_pd(m_mask, a_ptr[0]); + a_vec[1] = _mm512_maskz_loadu_pd(m_mask, a_ptr[1]); + + // Multiply the loaded columns of A by alpha*X(real and imag) + real_acc = _mm512_mul_pd(a_vec[0], alpha_x_real[0]); + imag_acc = _mm512_mul_pd(a_vec[0], alpha_x_imag[0]); + + real_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_real[1], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_imag[1], imag_acc); + + // Load remaining elements of Y vector + y_vec = _mm512_maskz_loadu_pd(m_mask, y0); + + // Permute and reduce the complex and real parts + imag_acc = _mm512_permute_pd(imag_acc, 0x55); + real_acc = _mm512_fmsubadd_pd(zero_reg, zero_reg, real_acc); + real_acc = _mm512_add_pd(real_acc, imag_acc); + + y_vec = _mm512_add_pd(y_vec, real_acc); + + // Store onto Y vector + _mm512_mask_storeu_pd(y0, m_mask, y_vec); + } + } + } + else + { + // Perform the computation with 128-bit registers, + // since dcomplex is 128 bits in size + __m128d a_vec[2], y_vec, real_acc, imag_acc, temp[2]; + + // Unpacking and storing real and imaginary components + // of alpha*X stored in x_vec[0...7] + temp[0] = _mm_unpackhi_pd(x_vec[0], x_vec[0]); + temp[1] = _mm_unpackhi_pd(x_vec[1], x_vec[1]); + + x_vec[0] = _mm_unpacklo_pd(x_vec[0], x_vec[0]); + x_vec[1] = _mm_unpacklo_pd(x_vec[1], x_vec[1]); + + if ( bli_is_noconj(conja) ) + { + for (; i < m; i++) + { + // Load elements from first 4 columns of A + a_vec[0] = _mm_loadu_pd(a_ptr[0]); + a_vec[1] = _mm_loadu_pd(a_ptr[1]); + + // Multiply the loaded columns of A by alpha*X(real and imag) + real_acc = _mm_mul_pd(a_vec[0], x_vec[0]); + imag_acc = _mm_mul_pd(a_vec[0], temp[0]); + + real_acc = _mm_fmadd_pd(a_vec[1], x_vec[1], real_acc); + imag_acc = _mm_fmadd_pd(a_vec[1], temp[1], imag_acc); + + // Load Y vector + y_vec = _mm_loadu_pd(y0); + + // Permute and reduce the complex and real parts + imag_acc = _mm_permute_pd(imag_acc, 0b01); + real_acc = _mm_addsub_pd(real_acc, imag_acc); + + y_vec = _mm_add_pd(y_vec, real_acc); + + // Store Y vector + _mm_storeu_pd(y0, y_vec); + + y0 += 2 * incy; + a_ptr[0] += 2 * inca; + a_ptr[1] += 2 * inca; + } + } + else + { + for (; i < m; i++) + { + // Load elements from first 4 columns of A + a_vec[0] = _mm_loadu_pd(a_ptr[0]); + a_vec[1] = _mm_loadu_pd(a_ptr[1]); + + // Multiply the loaded columns of A by alpha*X(real and imag) + real_acc = _mm_mul_pd(a_vec[0], x_vec[0]); + imag_acc = _mm_mul_pd(a_vec[0], temp[0]); + + real_acc = _mm_fmadd_pd(a_vec[1], x_vec[1], real_acc); + imag_acc = _mm_fmadd_pd(a_vec[1], temp[1], imag_acc); + + // Load Y vector + y_vec = _mm_loadu_pd(y0); + + // Permute and reduce the complex and real parts + real_acc = _mm_permute_pd(real_acc, 0b01); + real_acc = _mm_addsub_pd(imag_acc, real_acc); + real_acc = _mm_permute_pd(real_acc, 0b01); + + y_vec = _mm_add_pd(y_vec, real_acc); + + // Store Y vector + _mm_storeu_pd(y0, y_vec); + + y0 += 2 * incy; + a_ptr[0] += 2 * inca; + a_ptr[1] += 2 * inca; + } + } + } +} + +void bli_zaxpyf_zen_int_4_avx512 + ( + conj_t conja, + conj_t conjx, + dim_t m, + dim_t b_n, + dcomplex* restrict alpha, + dcomplex* restrict a, inc_t inca, inc_t lda, + dcomplex* restrict x, inc_t incx, + dcomplex* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + dim_t fuse_fac = 4; + + // If either dimension is zero, or if alpha is zero, return early. + if ( bli_zero_dim2( m, b_n ) || bli_zeq0( *alpha ) ) return; + + // If b_n is not equal to the fusing factor, then perform the entire + // operation as a sequence of calls to zaxpyf kernels, with fuse-factor + // 2 and a single call to zaxpyv, based on the need. + if ( b_n != fuse_fac ) + { + dcomplex *a1 = a; + dcomplex *chi1 = x; + dcomplex *y1 = y; + dcomplex alpha_chi1; + + // Buggy, try to mimic 8 kernel + if( b_n >= 2 ) + { + bli_zaxpyf_zen_int_2_avx512 + ( + conja, + conjx, + m, + (dim_t)2, + alpha, + a1, inca, lda, + chi1, incx, + y1, incy, + cntx + ); + + a1 += 2*lda; + chi1 += 2*incx; + b_n -= 2; + } + + if( b_n == 1 ) + { + // Vectorization of alpha scaling of X + __m128d x_vec, alpha_real, alpha_imag, temp[2]; + alpha_real = _mm_loaddup_pd((double *)alpha); + alpha_imag = _mm_loaddup_pd((double *)alpha + 1); + + x_vec = _mm_loadu_pd((double *)chi1); + + if ( bli_is_conj( conjx ) ) + { + __m128d conj_set; + conj_set = _mm_set_pd(-0.0, 0.0); + + x_vec = _mm_xor_pd(conj_set, x_vec); + } + + temp[0] = _mm_mul_pd(x_vec, alpha_real); + temp[1] = _mm_mul_pd(x_vec, alpha_imag); + + temp[1] = _mm_permute_pd(temp[1], 0b01); + + temp[0] = _mm_addsub_pd(temp[0], temp[1]); + + _mm_storeu_pd((double *)&alpha_chi1, temp[0]); + + bli_zaxpyv_zen_int_avx512 + ( + conja, + m, + &alpha_chi1, + a1, inca, + y1, incy, + cntx + ); + } + + return; + } + + // Declaring and initializing the iterator and pointers + dim_t i = 0; + + double *a_ptr[4]; + double *y0 = (double *)y; + + a_ptr[0] = (double *)a; + a_ptr[1] = (double *)(a + 1 * lda); + a_ptr[2] = (double *)(a + 2 * lda); + a_ptr[3] = (double *)(a + 3 * lda); + + /* Alpha scaling of X can be vectorized + irrespective of the incx and should + be avoided when alpha is 1 */ + __m128d x_vec[4]; + + x_vec[0] = _mm_loadu_pd((double *)(x + 0 * incx)); + x_vec[1] = _mm_loadu_pd((double *)(x + 1 * incx)); + x_vec[2] = _mm_loadu_pd((double *)(x + 2 * incx)); + x_vec[3] = _mm_loadu_pd((double *)(x + 3 * incx)); + + if ( bli_is_conj( conjx ) ) + { + __m128d conj_set; + conj_set = _mm_set_pd(-0.0, 0.0); + + // The sequence of xor operations flip the sign bit + // of imaginary components in X vector + x_vec[0] = _mm_xor_pd(conj_set, x_vec[0]); + x_vec[1] = _mm_xor_pd(conj_set, x_vec[1]); + x_vec[2] = _mm_xor_pd(conj_set, x_vec[2]); + x_vec[3] = _mm_xor_pd(conj_set, x_vec[3]); + } + + // Special case handling when alpha == -1 + 0i + if( alpha->real == -1.0 && alpha->imag == 0.0 ) + { + __m128d zero_reg = _mm_setzero_pd(); + + x_vec[0] = _mm_sub_pd(zero_reg, x_vec[0]); + x_vec[1] = _mm_sub_pd(zero_reg, x_vec[1]); + x_vec[2] = _mm_sub_pd(zero_reg, x_vec[2]); + x_vec[3] = _mm_sub_pd(zero_reg, x_vec[3]); + } + // General case of scaling with alpha + else if (!(bli_zeq1(*alpha))) + { + __m128d alpha_real, alpha_imag, temp[4]; + alpha_real = _mm_loaddup_pd((double *)alpha); + alpha_imag = _mm_loaddup_pd(((double *)alpha) + 1); + + // Scaling with imaginary part of alpha + temp[0] = _mm_mul_pd(x_vec[0], alpha_imag); + temp[1] = _mm_mul_pd(x_vec[1], alpha_imag); + temp[2] = _mm_mul_pd(x_vec[2], alpha_imag); + temp[3] = _mm_mul_pd(x_vec[3], alpha_imag); + + // Scaling with real part of alpha + x_vec[0] = _mm_mul_pd(x_vec[0], alpha_real); + x_vec[1] = _mm_mul_pd(x_vec[1], alpha_real); + x_vec[2] = _mm_mul_pd(x_vec[2], alpha_real); + x_vec[3] = _mm_mul_pd(x_vec[3], alpha_real); + + // Permuting the registers to get the following pattern + // t[0] : xI0*alphaI + // xR0*alphaI, and so on + temp[0] = _mm_permute_pd(temp[0], 0x01); + temp[1] = _mm_permute_pd(temp[1], 0x01); + temp[2] = _mm_permute_pd(temp[2], 0x01); + temp[3] = _mm_permute_pd(temp[3], 0x01); + + // Addsub to complete the complex arithmetic as such: + // x_vec[0] : xR0*alphaR - xI0*alphaI + // xI0*alphaR + xR0*alphaI, and so on + x_vec[0] = _mm_addsub_pd(x_vec[0], temp[0]); + x_vec[1] = _mm_addsub_pd(x_vec[1], temp[1]); + x_vec[2] = _mm_addsub_pd(x_vec[2], temp[2]); + x_vec[3] = _mm_addsub_pd(x_vec[3], temp[3]); + } + + if ( (inca == 1) && (incy == 1) ) + { + // Temporary registers to store permuted alpha*X values + __m128d temp[4]; + + temp[0] = _mm_shuffle_pd(x_vec[0], x_vec[0], 0x01); + temp[1] = _mm_shuffle_pd(x_vec[1], x_vec[1], 0x01); + temp[2] = _mm_shuffle_pd(x_vec[2], x_vec[2], 0x01); + temp[3] = _mm_shuffle_pd(x_vec[3], x_vec[3], 0x01); + + // Declaring 8 registers, for re-use over the loops + // alpha_x_real[0] = xR0*alphaR xR0*alphaR ... + // alpah_x_imag[0] = xI0*alphaI xI0*alphaI ... + __m512d alpha_x_real[4], alpha_x_imag[4]; + + alpha_x_real[0] = _mm512_broadcastsd_pd(x_vec[0]); + alpha_x_real[1] = _mm512_broadcastsd_pd(x_vec[1]); + alpha_x_real[2] = _mm512_broadcastsd_pd(x_vec[2]); + alpha_x_real[3] = _mm512_broadcastsd_pd(x_vec[3]); + + alpha_x_imag[0] = _mm512_broadcastsd_pd(temp[0]); + alpha_x_imag[1] = _mm512_broadcastsd_pd(temp[1]); + alpha_x_imag[2] = _mm512_broadcastsd_pd(temp[2]); + alpha_x_imag[3] = _mm512_broadcastsd_pd(temp[3]); + + // Registers to load A, accumulate real and imag scaling separately + __m512d a_vec[4]; + __m512d real_acc, imag_acc, y_vec; + __m512d zero_reg = _mm512_setzero_pd(); + + // Execute the loops is m >= 4(AVX-512 unmasked code-section) + if( m >= 4 ) + { + if ( bli_is_noconj(conja) ) + { + for (; (i + 7) < m; i += 8) + { + // Load first 4 elements from first 4 columns of A + a_vec[0] = _mm512_loadu_pd(a_ptr[0]); + a_vec[1] = _mm512_loadu_pd(a_ptr[1]); + a_vec[2] = _mm512_loadu_pd(a_ptr[2]); + a_vec[3] = _mm512_loadu_pd(a_ptr[3]); + + // Multiply the loaded columns of A by alpha*X(real and imag) + real_acc = _mm512_mul_pd(a_vec[0], alpha_x_real[0]); + imag_acc = _mm512_mul_pd(a_vec[0], alpha_x_imag[0]); + + real_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_real[1], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_imag[1], imag_acc); + + real_acc = _mm512_fmadd_pd(a_vec[2], alpha_x_real[2], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[2], alpha_x_imag[2], imag_acc); + + real_acc = _mm512_fmadd_pd(a_vec[3], alpha_x_real[3], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[3], alpha_x_imag[3], imag_acc); + + // Load first 4 elements of Y vector + y_vec = _mm512_loadu_pd(y0); + + // Permute and reduce the complex and real parts + imag_acc = _mm512_permute_pd(imag_acc, 0x55); + imag_acc = _mm512_fmaddsub_pd(zero_reg, zero_reg, imag_acc); + real_acc = _mm512_add_pd(real_acc, imag_acc); + + y_vec = _mm512_add_pd(y_vec, real_acc); + + // Store onto Y vector + _mm512_storeu_pd(y0, y_vec); + + // Load next 4 elements from first 4 columns of A + a_vec[0] = _mm512_loadu_pd(a_ptr[0] + 8); + a_vec[1] = _mm512_loadu_pd(a_ptr[1] + 8); + a_vec[2] = _mm512_loadu_pd(a_ptr[2] + 8); + a_vec[3] = _mm512_loadu_pd(a_ptr[3] + 8); + + // Multiply the loaded columns of A by alpha*X(real and imag) + real_acc = _mm512_mul_pd(a_vec[0], alpha_x_real[0]); + imag_acc = _mm512_mul_pd(a_vec[0], alpha_x_imag[0]); + + real_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_real[1], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_imag[1], imag_acc); + + real_acc = _mm512_fmadd_pd(a_vec[2], alpha_x_real[2], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[2], alpha_x_imag[2], imag_acc); + + real_acc = _mm512_fmadd_pd(a_vec[3], alpha_x_real[3], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[3], alpha_x_imag[3], imag_acc); + + // Load next 4 elements of Y vector + y_vec = _mm512_loadu_pd(y0 + 8); + + // Permute and reduce the complex and real parts + imag_acc = _mm512_permute_pd(imag_acc, 0x55); + imag_acc = _mm512_fmaddsub_pd(zero_reg, zero_reg, imag_acc); + real_acc = _mm512_add_pd(real_acc, imag_acc); + + y_vec = _mm512_add_pd(y_vec, real_acc); + + // Store onto Y vector + _mm512_storeu_pd(y0 + 8, y_vec); + + y0 += 16; + a_ptr[0] += 16; + a_ptr[1] += 16; + a_ptr[2] += 16; + a_ptr[3] += 16; + } + + for (; (i + 3) < m; i += 4) + { + // Load first 4 elements from first 4 columns of A + a_vec[0] = _mm512_loadu_pd(a_ptr[0]); + a_vec[1] = _mm512_loadu_pd(a_ptr[1]); + a_vec[2] = _mm512_loadu_pd(a_ptr[2]); + a_vec[3] = _mm512_loadu_pd(a_ptr[3]); + + // Multiply the loaded columns of A by alpha*X(real and imag) + real_acc = _mm512_mul_pd(a_vec[0], alpha_x_real[0]); + imag_acc = _mm512_mul_pd(a_vec[0], alpha_x_imag[0]); + + real_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_real[1], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_imag[1], imag_acc); + + real_acc = _mm512_fmadd_pd(a_vec[2], alpha_x_real[2], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[2], alpha_x_imag[2], imag_acc); + + real_acc = _mm512_fmadd_pd(a_vec[3], alpha_x_real[3], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[3], alpha_x_imag[3], imag_acc); + + // Load first 4 elements of Y vector + y_vec = _mm512_loadu_pd(y0); + + // Permute and reduce the complex and real parts + imag_acc = _mm512_permute_pd(imag_acc, 0x55); + imag_acc = _mm512_fmaddsub_pd(zero_reg, zero_reg, imag_acc); + real_acc = _mm512_add_pd(real_acc, imag_acc); + + y_vec = _mm512_add_pd(y_vec, real_acc); + + // Store onto Y vector + _mm512_storeu_pd(y0, y_vec); + + y0 += 8; + a_ptr[0] += 8; + a_ptr[1] += 8; + a_ptr[2] += 8; + a_ptr[3] += 8; + } + } + else + { + for (; (i + 7) < m; i += 8) + { + // Load first 4 elements from first 4 columns of A + a_vec[0] = _mm512_loadu_pd(a_ptr[0]); + a_vec[1] = _mm512_loadu_pd(a_ptr[1]); + a_vec[2] = _mm512_loadu_pd(a_ptr[2]); + a_vec[3] = _mm512_loadu_pd(a_ptr[3]); + + // Multiply the loaded columns of A by alpha*X(real and imag) + real_acc = _mm512_mul_pd(a_vec[0], alpha_x_real[0]); + imag_acc = _mm512_mul_pd(a_vec[0], alpha_x_imag[0]); + + real_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_real[1], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_imag[1], imag_acc); + + real_acc = _mm512_fmadd_pd(a_vec[2], alpha_x_real[2], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[2], alpha_x_imag[2], imag_acc); + + real_acc = _mm512_fmadd_pd(a_vec[3], alpha_x_real[3], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[3], alpha_x_imag[3], imag_acc); + + // Load first 4 elements of Y vector + y_vec = _mm512_loadu_pd(y0); + + // Permute and reduce the complex and real parts + imag_acc = _mm512_permute_pd(imag_acc, 0x55); + real_acc = _mm512_fmsubadd_pd(zero_reg, zero_reg, real_acc); + real_acc = _mm512_add_pd(real_acc, imag_acc); + + y_vec = _mm512_add_pd(y_vec, real_acc); + + // Store onto Y vector + _mm512_storeu_pd(y0, y_vec); + + // Load next 4 elements from first 4 columns of A + a_vec[0] = _mm512_loadu_pd(a_ptr[0] + 8); + a_vec[1] = _mm512_loadu_pd(a_ptr[1] + 8); + a_vec[2] = _mm512_loadu_pd(a_ptr[2] + 8); + a_vec[3] = _mm512_loadu_pd(a_ptr[3] + 8); + + // Multiply the loaded columns of A by alpha*X(real and imag) + real_acc = _mm512_mul_pd(a_vec[0], alpha_x_real[0]); + imag_acc = _mm512_mul_pd(a_vec[0], alpha_x_imag[0]); + + real_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_real[1], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_imag[1], imag_acc); + + real_acc = _mm512_fmadd_pd(a_vec[2], alpha_x_real[2], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[2], alpha_x_imag[2], imag_acc); + + real_acc = _mm512_fmadd_pd(a_vec[3], alpha_x_real[3], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[3], alpha_x_imag[3], imag_acc); + + // Load next 4 elements of Y vector + y_vec = _mm512_loadu_pd(y0 + 8); + + // Permute and reduce the complex and real parts + imag_acc = _mm512_permute_pd(imag_acc, 0x55); + real_acc = _mm512_fmsubadd_pd(zero_reg, zero_reg, real_acc); + real_acc = _mm512_add_pd(real_acc, imag_acc); + + y_vec = _mm512_add_pd(y_vec, real_acc); + + // Store onto Y vector + _mm512_storeu_pd(y0 + 8, y_vec); + + y0 += 16; + a_ptr[0] += 16; + a_ptr[1] += 16; + a_ptr[2] += 16; + a_ptr[3] += 16; + } + + for (; (i + 3) < m; i += 4) + { + // Load first 4 elements from first 4 columns of A + a_vec[0] = _mm512_loadu_pd(a_ptr[0]); + a_vec[1] = _mm512_loadu_pd(a_ptr[1]); + a_vec[2] = _mm512_loadu_pd(a_ptr[2]); + a_vec[3] = _mm512_loadu_pd(a_ptr[3]); + + // Multiply the loaded columns of A by alpha*X(real and imag) + real_acc = _mm512_mul_pd(a_vec[0], alpha_x_real[0]); + imag_acc = _mm512_mul_pd(a_vec[0], alpha_x_imag[0]); + + real_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_real[1], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_imag[1], imag_acc); + + real_acc = _mm512_fmadd_pd(a_vec[2], alpha_x_real[2], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[2], alpha_x_imag[2], imag_acc); + + real_acc = _mm512_fmadd_pd(a_vec[3], alpha_x_real[3], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[3], alpha_x_imag[3], imag_acc); + + // Load first 4 elements of Y vector + y_vec = _mm512_loadu_pd(y0); + + // Permute and reduce the complex and real parts + imag_acc = _mm512_permute_pd(imag_acc, 0x55); + real_acc = _mm512_fmsubadd_pd(zero_reg, zero_reg, real_acc); + real_acc = _mm512_add_pd(real_acc, imag_acc); + + y_vec = _mm512_add_pd(y_vec, real_acc); + + // Store onto Y vector + _mm512_storeu_pd(y0, y_vec); + + y0 += 8; + a_ptr[0] += 8; + a_ptr[1] += 8; + a_ptr[2] += 8; + a_ptr[3] += 8; + } + } + } + if( i < m ) + { + __mmask8 m_mask = (1 << 2*(m - i)) - 1; + if( bli_is_noconj(conja) ) + { + // Load remaining elements from first 4 columns of A + a_vec[0] = _mm512_maskz_loadu_pd(m_mask, a_ptr[0]); + a_vec[1] = _mm512_maskz_loadu_pd(m_mask, a_ptr[1]); + a_vec[2] = _mm512_maskz_loadu_pd(m_mask, a_ptr[2]); + a_vec[3] = _mm512_maskz_loadu_pd(m_mask, a_ptr[3]); + + // Multiply the loaded columns of A by alpha*X(real and imag) + real_acc = _mm512_mul_pd(a_vec[0], alpha_x_real[0]); + imag_acc = _mm512_mul_pd(a_vec[0], alpha_x_imag[0]); + + real_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_real[1], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_imag[1], imag_acc); + + real_acc = _mm512_fmadd_pd(a_vec[2], alpha_x_real[2], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[2], alpha_x_imag[2], imag_acc); + + real_acc = _mm512_fmadd_pd(a_vec[3], alpha_x_real[3], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[3], alpha_x_imag[3], imag_acc); + + // Load remaining elements of Y vector + y_vec = _mm512_maskz_loadu_pd(m_mask, y0); + + // Permute and reduce the complex and real parts + imag_acc = _mm512_permute_pd(imag_acc, 0x55); + imag_acc = _mm512_fmaddsub_pd(zero_reg, zero_reg, imag_acc); + real_acc = _mm512_add_pd(real_acc, imag_acc); + + y_vec = _mm512_add_pd(y_vec, real_acc); + + // Store onto Y vector + _mm512_mask_storeu_pd(y0, m_mask, y_vec); + } + else + { + // Load remaining elements from first 4 columns of A + a_vec[0] = _mm512_maskz_loadu_pd(m_mask, a_ptr[0]); + a_vec[1] = _mm512_maskz_loadu_pd(m_mask, a_ptr[1]); + a_vec[2] = _mm512_maskz_loadu_pd(m_mask, a_ptr[2]); + a_vec[3] = _mm512_maskz_loadu_pd(m_mask, a_ptr[3]); + + // Multiply the loaded columns of A by alpha*X(real and imag) + real_acc = _mm512_mul_pd(a_vec[0], alpha_x_real[0]); + imag_acc = _mm512_mul_pd(a_vec[0], alpha_x_imag[0]); + + real_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_real[1], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_imag[1], imag_acc); + + real_acc = _mm512_fmadd_pd(a_vec[2], alpha_x_real[2], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[2], alpha_x_imag[2], imag_acc); + + real_acc = _mm512_fmadd_pd(a_vec[3], alpha_x_real[3], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[3], alpha_x_imag[3], imag_acc); + + // Load remaining elements of Y vector + y_vec = _mm512_maskz_loadu_pd(m_mask, y0); + + // Permute and reduce the complex and real parts + imag_acc = _mm512_permute_pd(imag_acc, 0x55); + real_acc = _mm512_fmsubadd_pd(zero_reg, zero_reg, real_acc); + real_acc = _mm512_add_pd(real_acc, imag_acc); + + y_vec = _mm512_add_pd(y_vec, real_acc); + + // Store onto Y vector + _mm512_mask_storeu_pd(y0, m_mask, y_vec); + } + } + } + else + { + // Perform the computation with 128-bit registers, + // since dcomplex is 128 bits in size + __m128d a_vec[4], y_vec, real_acc, imag_acc, temp[4]; + + // Unpacking and storing real and imaginary components + // of alpha*X stored in x_vec[0...7] + temp[0] = _mm_unpackhi_pd(x_vec[0], x_vec[0]); + temp[1] = _mm_unpackhi_pd(x_vec[1], x_vec[1]); + temp[2] = _mm_unpackhi_pd(x_vec[2], x_vec[2]); + temp[3] = _mm_unpackhi_pd(x_vec[3], x_vec[3]); + + x_vec[0] = _mm_unpacklo_pd(x_vec[0], x_vec[0]); + x_vec[1] = _mm_unpacklo_pd(x_vec[1], x_vec[1]); + x_vec[2] = _mm_unpacklo_pd(x_vec[2], x_vec[2]); + x_vec[3] = _mm_unpacklo_pd(x_vec[3], x_vec[3]); + + if ( bli_is_noconj(conja) ) + { + for (; i < m; i++) + { + // Load elements from first 4 columns of A + a_vec[0] = _mm_loadu_pd(a_ptr[0]); + a_vec[1] = _mm_loadu_pd(a_ptr[1]); + a_vec[2] = _mm_loadu_pd(a_ptr[2]); + a_vec[3] = _mm_loadu_pd(a_ptr[3]); + + // Multiply the loaded columns of A by alpha*X(real and imag) + real_acc = _mm_mul_pd(a_vec[0], x_vec[0]); + imag_acc = _mm_mul_pd(a_vec[0], temp[0]); + + real_acc = _mm_fmadd_pd(a_vec[1], x_vec[1], real_acc); + imag_acc = _mm_fmadd_pd(a_vec[1], temp[1], imag_acc); + + real_acc = _mm_fmadd_pd(a_vec[2], x_vec[2], real_acc); + imag_acc = _mm_fmadd_pd(a_vec[2], temp[2], imag_acc); + + real_acc = _mm_fmadd_pd(a_vec[3], x_vec[3], real_acc); + imag_acc = _mm_fmadd_pd(a_vec[3], temp[3], imag_acc); + + // Load Y vector + y_vec = _mm_loadu_pd(y0); + + // Permute and reduce the complex and real parts + imag_acc = _mm_permute_pd(imag_acc, 0b01); + real_acc = _mm_addsub_pd(real_acc, imag_acc); + + y_vec = _mm_add_pd(y_vec, real_acc); + + // Store Y vector + _mm_storeu_pd(y0, y_vec); + + y0 += 2 * incy; + a_ptr[0] += 2 * inca; + a_ptr[1] += 2 * inca; + a_ptr[2] += 2 * inca; + a_ptr[3] += 2 * inca; + } + } + else + { + for (; i < m; i++) + { + // Load elements from first 4 columns of A + a_vec[0] = _mm_loadu_pd(a_ptr[0]); + a_vec[1] = _mm_loadu_pd(a_ptr[1]); + a_vec[2] = _mm_loadu_pd(a_ptr[2]); + a_vec[3] = _mm_loadu_pd(a_ptr[3]); + + // Multiply the loaded columns of A by alpha*X(real and imag) + real_acc = _mm_mul_pd(a_vec[0], x_vec[0]); + imag_acc = _mm_mul_pd(a_vec[0], temp[0]); + + real_acc = _mm_fmadd_pd(a_vec[1], x_vec[1], real_acc); + imag_acc = _mm_fmadd_pd(a_vec[1], temp[1], imag_acc); + + real_acc = _mm_fmadd_pd(a_vec[2], x_vec[2], real_acc); + imag_acc = _mm_fmadd_pd(a_vec[2], temp[2], imag_acc); + + real_acc = _mm_fmadd_pd(a_vec[3], x_vec[3], real_acc); + imag_acc = _mm_fmadd_pd(a_vec[3], temp[3], imag_acc); + + // Load Y vector + y_vec = _mm_loadu_pd(y0); + + // Permute and reduce the complex and real parts + real_acc = _mm_permute_pd(real_acc, 0b01); + real_acc = _mm_addsub_pd(imag_acc, real_acc); + real_acc = _mm_permute_pd(real_acc, 0b01); + + y_vec = _mm_add_pd(y_vec, real_acc); + + // Store Y vector + _mm_storeu_pd(y0, y_vec); + + y0 += 2 * incy; + a_ptr[0] += 2 * inca; + a_ptr[1] += 2 * inca; + a_ptr[2] += 2 * inca; + a_ptr[3] += 2 * inca; + } + } + } +} + +void bli_zaxpyf_zen_int_8_avx512 + ( + conj_t conja, + conj_t conjx, + dim_t m, + dim_t b_n, + dcomplex* restrict alpha, + dcomplex* restrict a, inc_t inca, inc_t lda, + dcomplex* restrict x, inc_t incx, + dcomplex* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + dim_t fuse_fac = 8; + + // If either dimension is zero, or if alpha is zero, return early. + if ( bli_zero_dim2( m, b_n ) || bli_zeq0( *alpha ) ) return; + + // If b_n is not equal to the fusing factor, then perform the entire + // operation as a sequence of calls to zaxpyf kernels, with fuse-factor + // 4 and 2 and a single call to zaxpyv, based on the need. + if ( b_n < fuse_fac ) + { + dcomplex *a1 = a; + dcomplex *chi1 = x; + dcomplex *y1 = y; + dcomplex alpha_chi1; + + if( b_n >= 4 ) + { + bli_zaxpyf_zen_int_4_avx512 + ( + conja, + conjx, + m, + (dim_t)4, + alpha, + a1, inca, lda, + chi1, incx, + y1, incy, + cntx + ); + + a1 += 4*lda; + chi1 += 4*incx; + b_n -= 4; + } + + // Buggy, try to mimic 8 kernel + if( b_n >= 2 ) + { + bli_zaxpyf_zen_int_2_avx512 + ( + conja, + conjx, + m, + (dim_t)2, + alpha, + a1, inca, lda, + chi1, incx, + y1, incy, + cntx + ); + + a1 += 2*lda; + chi1 += 2*incx; + b_n -= 2; + } + + if( b_n == 1 ) + { + // Vectorization of alpha scaling of X + __m128d x_vec, alpha_real, alpha_imag, temp[2]; + alpha_real = _mm_loaddup_pd((double *)alpha); + alpha_imag = _mm_loaddup_pd((double *)alpha + 1); + + x_vec = _mm_loadu_pd((double *)chi1); + + if ( bli_is_conj( conjx ) ) + { + __m128d conj_set; + conj_set = _mm_set_pd(-0.0, 0.0); + + x_vec = _mm_xor_pd(conj_set, x_vec); + } + + temp[0] = _mm_mul_pd(x_vec, alpha_real); + temp[1] = _mm_mul_pd(x_vec, alpha_imag); + + temp[1] = _mm_permute_pd(temp[1], 0b01); + + temp[0] = _mm_addsub_pd(temp[0], temp[1]); + + _mm_storeu_pd((double *)&alpha_chi1, temp[0]); + + bli_zaxpyv_zen_int_avx512 + ( + conja, + m, + &alpha_chi1, + a1, inca, + y1, incy, + cntx + ); + } + + return; + } + else if ( b_n > fuse_fac ) + { + zaxpyv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_DCOMPLEX, BLIS_AXPYV_KER, cntx ); + + for ( dim_t i = 0; i < b_n; ++i ) + { + dcomplex* a1 = a + (0 )*inca + (i )*lda; + dcomplex* chi1 = x + (i )*incx; + dcomplex* y1 = y + (0 )*incy; + dcomplex alpha_chi1; + + bli_zcopycjs( conjx, *chi1, alpha_chi1 ); + bli_zscals( *alpha, alpha_chi1 ); + + f + ( + conja, + m, + &alpha_chi1, + a1, inca, + y1, incy, + cntx + ); + } + + return; + } + + // Declaring and initializing the iterator and pointers + dim_t i = 0; + + double *a_ptr[8]; + double *y0 = (double *)y; + + a_ptr[0] = (double *)a; + a_ptr[1] = (double *)(a + 1 * lda); + a_ptr[2] = (double *)(a + 2 * lda); + a_ptr[3] = (double *)(a + 3 * lda); + + a_ptr[4] = (double *)(a + 4 * lda); + a_ptr[5] = (double *)(a + 5 * lda); + a_ptr[6] = (double *)(a + 6 * lda); + a_ptr[7] = (double *)(a + 7 * lda); + + /* Alpha scaling of X can be vectorized + irrespective of the incx and should + be avoided when alpha is 1 */ + __m128d x_vec[8]; + + x_vec[0] = _mm_loadu_pd((double *)(x + 0 * incx)); + x_vec[1] = _mm_loadu_pd((double *)(x + 1 * incx)); + x_vec[2] = _mm_loadu_pd((double *)(x + 2 * incx)); + x_vec[3] = _mm_loadu_pd((double *)(x + 3 * incx)); + + x_vec[4] = _mm_loadu_pd((double *)(x + 4 * incx)); + x_vec[5] = _mm_loadu_pd((double *)(x + 5 * incx)); + x_vec[6] = _mm_loadu_pd((double *)(x + 6 * incx)); + x_vec[7] = _mm_loadu_pd((double *)(x + 7 * incx)); + + if ( bli_is_conj( conjx ) ) + { + __m128d conj_set; + conj_set = _mm_set_pd(-0.0, 0.0); + + // The sequence of xor operations flip the sign bit + // of imaginary components in X vector + x_vec[0] = _mm_xor_pd(conj_set, x_vec[0]); + x_vec[1] = _mm_xor_pd(conj_set, x_vec[1]); + x_vec[2] = _mm_xor_pd(conj_set, x_vec[2]); + x_vec[3] = _mm_xor_pd(conj_set, x_vec[3]); + + x_vec[4] = _mm_xor_pd(conj_set, x_vec[4]); + x_vec[5] = _mm_xor_pd(conj_set, x_vec[5]); + x_vec[6] = _mm_xor_pd(conj_set, x_vec[6]); + x_vec[7] = _mm_xor_pd(conj_set, x_vec[7]); + + } + + // Special case handling when alpha == -1 + 0i + if( alpha->real == -1.0 && alpha->imag == 0.0 ) + { + __m128d zero_reg = _mm_setzero_pd(); + + x_vec[0] = _mm_sub_pd(zero_reg, x_vec[0]); + x_vec[1] = _mm_sub_pd(zero_reg, x_vec[1]); + x_vec[2] = _mm_sub_pd(zero_reg, x_vec[2]); + x_vec[3] = _mm_sub_pd(zero_reg, x_vec[3]); + + x_vec[4] = _mm_sub_pd(zero_reg, x_vec[4]); + x_vec[5] = _mm_sub_pd(zero_reg, x_vec[5]); + x_vec[6] = _mm_sub_pd(zero_reg, x_vec[6]); + x_vec[7] = _mm_sub_pd(zero_reg, x_vec[7]); + } + // General case of scaling with alpha + else if (!(bli_zeq1(*alpha))) + { + __m128d alpha_real, alpha_imag, temp[4]; + alpha_real = _mm_loaddup_pd((double *)alpha); + alpha_imag = _mm_loaddup_pd(((double *)alpha) + 1); + + // Scaling with imaginary part of alpha + temp[0] = _mm_mul_pd(x_vec[0], alpha_imag); + temp[1] = _mm_mul_pd(x_vec[1], alpha_imag); + temp[2] = _mm_mul_pd(x_vec[2], alpha_imag); + temp[3] = _mm_mul_pd(x_vec[3], alpha_imag); + + // Scaling with real part of alpha + x_vec[0] = _mm_mul_pd(x_vec[0], alpha_real); + x_vec[1] = _mm_mul_pd(x_vec[1], alpha_real); + x_vec[2] = _mm_mul_pd(x_vec[2], alpha_real); + x_vec[3] = _mm_mul_pd(x_vec[3], alpha_real); + + // Permuting the registers to get the following pattern + // t[0] : xI0*alphaI + // xR0*alphaI, and so on + temp[0] = _mm_permute_pd(temp[0], 0x01); + temp[1] = _mm_permute_pd(temp[1], 0x01); + temp[2] = _mm_permute_pd(temp[2], 0x01); + temp[3] = _mm_permute_pd(temp[3], 0x01); + + // Addsub to complete the complex arithmetic as such: + // x_vec[0] : xR0*alphaR - xI0*alphaI + // xI0*alphaR + xR0*alphaI, and so on + x_vec[0] = _mm_addsub_pd(x_vec[0], temp[0]); + x_vec[1] = _mm_addsub_pd(x_vec[1], temp[1]); + x_vec[2] = _mm_addsub_pd(x_vec[2], temp[2]); + x_vec[3] = _mm_addsub_pd(x_vec[3], temp[3]); + + // Scaling with imaginary part of alpha + temp[0] = _mm_mul_pd(x_vec[4], alpha_imag); + temp[1] = _mm_mul_pd(x_vec[5], alpha_imag); + temp[2] = _mm_mul_pd(x_vec[6], alpha_imag); + temp[3] = _mm_mul_pd(x_vec[7], alpha_imag); + + // Scaling with real part of alpha + x_vec[4] = _mm_mul_pd(x_vec[4], alpha_real); + x_vec[5] = _mm_mul_pd(x_vec[5], alpha_real); + x_vec[6] = _mm_mul_pd(x_vec[6], alpha_real); + x_vec[7] = _mm_mul_pd(x_vec[7], alpha_real); + + // Permuting the registers to get the following pattern + // t[0] : xI0*alphaI xR0*alphaI + temp[0] = _mm_permute_pd(temp[0], 0x01); + temp[1] = _mm_permute_pd(temp[1], 0x01); + temp[2] = _mm_permute_pd(temp[2], 0x01); + temp[3] = _mm_permute_pd(temp[3], 0x01); + + // Addsub to complete the complex arithmetic as such: + // x_vec[0] : ( xR0*alphaR - xI0*alphaI ) ( xI0*alphaR + xR0*alphaI ) + x_vec[4] = _mm_addsub_pd(x_vec[4], temp[0]); + x_vec[5] = _mm_addsub_pd(x_vec[5], temp[1]); + x_vec[6] = _mm_addsub_pd(x_vec[6], temp[2]); + x_vec[7] = _mm_addsub_pd(x_vec[7], temp[3]); + } + + if ( (inca == 1) && (incy == 1) ) + { + // Temporary registers to store permuted alpha*X values + __m128d temp[8]; + + temp[0] = _mm_shuffle_pd(x_vec[0], x_vec[0], 0x01); + temp[1] = _mm_shuffle_pd(x_vec[1], x_vec[1], 0x01); + temp[2] = _mm_shuffle_pd(x_vec[2], x_vec[2], 0x01); + temp[3] = _mm_shuffle_pd(x_vec[3], x_vec[3], 0x01); + + temp[4] = _mm_shuffle_pd(x_vec[4], x_vec[4], 0x01); + temp[5] = _mm_shuffle_pd(x_vec[5], x_vec[5], 0x01); + temp[6] = _mm_shuffle_pd(x_vec[6], x_vec[6], 0x01); + temp[7] = _mm_shuffle_pd(x_vec[7], x_vec[7], 0x01); + + // Declaring 16 registers, for re-use over the loops + // alpha_x_real[0] = xR0*alphaR xR0*alphaR ... + // alpah_x_imag[0] = xI0*alphaI xI0*alphaI ... + __m512d alpha_x_real[8], alpha_x_imag[8]; + + alpha_x_real[0] = _mm512_broadcastsd_pd(x_vec[0]); + alpha_x_real[1] = _mm512_broadcastsd_pd(x_vec[1]); + alpha_x_real[2] = _mm512_broadcastsd_pd(x_vec[2]); + alpha_x_real[3] = _mm512_broadcastsd_pd(x_vec[3]); + alpha_x_real[4] = _mm512_broadcastsd_pd(x_vec[4]); + alpha_x_real[5] = _mm512_broadcastsd_pd(x_vec[5]); + alpha_x_real[6] = _mm512_broadcastsd_pd(x_vec[6]); + alpha_x_real[7] = _mm512_broadcastsd_pd(x_vec[7]); + + alpha_x_imag[0] = _mm512_broadcastsd_pd(temp[0]); + alpha_x_imag[1] = _mm512_broadcastsd_pd(temp[1]); + alpha_x_imag[2] = _mm512_broadcastsd_pd(temp[2]); + alpha_x_imag[3] = _mm512_broadcastsd_pd(temp[3]); + alpha_x_imag[4] = _mm512_broadcastsd_pd(temp[4]); + alpha_x_imag[5] = _mm512_broadcastsd_pd(temp[5]); + alpha_x_imag[6] = _mm512_broadcastsd_pd(temp[6]); + alpha_x_imag[7] = _mm512_broadcastsd_pd(temp[7]); + + // Registers to load A, accumulate real and imag scaling separately + __m512d a_vec[4]; + __m512d real_acc, imag_acc, y_vec; + __m512d zero_reg = _mm512_setzero_pd(); + + // Execute the loops is m >= 4(AVX-512 unmasked code-section) + if( m >= 4 ) + { + if ( bli_is_noconj(conja) ) + { + for (; (i + 7) < m; i += 8) + { + // Load first 4 elements from first 4 columns of A + a_vec[0] = _mm512_loadu_pd(a_ptr[0]); + a_vec[1] = _mm512_loadu_pd(a_ptr[1]); + a_vec[2] = _mm512_loadu_pd(a_ptr[2]); + a_vec[3] = _mm512_loadu_pd(a_ptr[3]); + + // Multiply the loaded columns of A by alpha*X(real and imag) + real_acc = _mm512_mul_pd(a_vec[0], alpha_x_real[0]); + imag_acc = _mm512_mul_pd(a_vec[0], alpha_x_imag[0]); + + real_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_real[1], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_imag[1], imag_acc); + + real_acc = _mm512_fmadd_pd(a_vec[2], alpha_x_real[2], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[2], alpha_x_imag[2], imag_acc); + + real_acc = _mm512_fmadd_pd(a_vec[3], alpha_x_real[3], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[3], alpha_x_imag[3], imag_acc); + + // Load first 4 elements from next 4 columns of A + a_vec[0] = _mm512_loadu_pd(a_ptr[4]); + a_vec[1] = _mm512_loadu_pd(a_ptr[5]); + a_vec[2] = _mm512_loadu_pd(a_ptr[6]); + a_vec[3] = _mm512_loadu_pd(a_ptr[7]); + + // Multiply the loaded columns of A by alpha*X(real and imag) + real_acc = _mm512_fmadd_pd(a_vec[0], alpha_x_real[4], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[0], alpha_x_imag[4], imag_acc); + + real_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_real[5], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_imag[5], imag_acc); + + real_acc = _mm512_fmadd_pd(a_vec[2], alpha_x_real[6], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[2], alpha_x_imag[6], imag_acc); + + real_acc = _mm512_fmadd_pd(a_vec[3], alpha_x_real[7], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[3], alpha_x_imag[7], imag_acc); + + // Load first 4 elements of Y vector + y_vec = _mm512_loadu_pd(y0); + + // Permute and reduce the complex and real parts + imag_acc = _mm512_permute_pd(imag_acc, 0x55); + imag_acc = _mm512_fmaddsub_pd(zero_reg, zero_reg, imag_acc); + real_acc = _mm512_add_pd(real_acc, imag_acc); + + y_vec = _mm512_add_pd(y_vec, real_acc); + + // Store onto Y vector + _mm512_storeu_pd(y0, y_vec); + + // Load next 4 elements from first 4 columns of A + a_vec[0] = _mm512_loadu_pd(a_ptr[0] + 8); + a_vec[1] = _mm512_loadu_pd(a_ptr[1] + 8); + a_vec[2] = _mm512_loadu_pd(a_ptr[2] + 8); + a_vec[3] = _mm512_loadu_pd(a_ptr[3] + 8); + + // Multiply the loaded columns of A by alpha*X(real and imag) + real_acc = _mm512_mul_pd(a_vec[0], alpha_x_real[0]); + imag_acc = _mm512_mul_pd(a_vec[0], alpha_x_imag[0]); + + real_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_real[1], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_imag[1], imag_acc); + + real_acc = _mm512_fmadd_pd(a_vec[2], alpha_x_real[2], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[2], alpha_x_imag[2], imag_acc); + + real_acc = _mm512_fmadd_pd(a_vec[3], alpha_x_real[3], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[3], alpha_x_imag[3], imag_acc); + + // Load next 4 elements from next 4 columns of A + a_vec[0] = _mm512_loadu_pd(a_ptr[4] + 8); + a_vec[1] = _mm512_loadu_pd(a_ptr[5] + 8); + a_vec[2] = _mm512_loadu_pd(a_ptr[6] + 8); + a_vec[3] = _mm512_loadu_pd(a_ptr[7] + 8); + + // Multiply the loaded columns of A by alpha*X(real and imag) + real_acc = _mm512_fmadd_pd(a_vec[0], alpha_x_real[4], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[0], alpha_x_imag[4], imag_acc); + + real_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_real[5], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_imag[5], imag_acc); + + real_acc = _mm512_fmadd_pd(a_vec[2], alpha_x_real[6], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[2], alpha_x_imag[6], imag_acc); + + real_acc = _mm512_fmadd_pd(a_vec[3], alpha_x_real[7], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[3], alpha_x_imag[7], imag_acc); + + // Load next 4 elements of Y vector + y_vec = _mm512_loadu_pd(y0 + 8); + + // Permute and reduce the complex and real parts + imag_acc = _mm512_permute_pd(imag_acc, 0x55); + imag_acc = _mm512_fmaddsub_pd(zero_reg, zero_reg, imag_acc); + real_acc = _mm512_add_pd(real_acc, imag_acc); + + y_vec = _mm512_add_pd(y_vec, real_acc); + + // Store onto Y vector + _mm512_storeu_pd(y0 + 8, y_vec); + + y0 += 16; + a_ptr[0] += 16; + a_ptr[1] += 16; + a_ptr[2] += 16; + a_ptr[3] += 16; + a_ptr[4] += 16; + a_ptr[5] += 16; + a_ptr[6] += 16; + a_ptr[7] += 16; + } + + for (; (i + 3) < m; i += 4) + { + // Load first 4 elements from first 4 columns of A + a_vec[0] = _mm512_loadu_pd(a_ptr[0]); + a_vec[1] = _mm512_loadu_pd(a_ptr[1]); + a_vec[2] = _mm512_loadu_pd(a_ptr[2]); + a_vec[3] = _mm512_loadu_pd(a_ptr[3]); + + // Multiply the loaded columns of A by alpha*X(real and imag) + real_acc = _mm512_mul_pd(a_vec[0], alpha_x_real[0]); + imag_acc = _mm512_mul_pd(a_vec[0], alpha_x_imag[0]); + + real_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_real[1], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_imag[1], imag_acc); + + real_acc = _mm512_fmadd_pd(a_vec[2], alpha_x_real[2], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[2], alpha_x_imag[2], imag_acc); + + real_acc = _mm512_fmadd_pd(a_vec[3], alpha_x_real[3], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[3], alpha_x_imag[3], imag_acc); + + // Load first 4 elements from next 4 columns of A + a_vec[0] = _mm512_loadu_pd(a_ptr[4]); + a_vec[1] = _mm512_loadu_pd(a_ptr[5]); + a_vec[2] = _mm512_loadu_pd(a_ptr[6]); + a_vec[3] = _mm512_loadu_pd(a_ptr[7]); + + // Multiply the loaded columns of A by alpha*X(real and imag) + real_acc = _mm512_fmadd_pd(a_vec[0], alpha_x_real[4], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[0], alpha_x_imag[4], imag_acc); + + real_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_real[5], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_imag[5], imag_acc); + + real_acc = _mm512_fmadd_pd(a_vec[2], alpha_x_real[6], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[2], alpha_x_imag[6], imag_acc); + + real_acc = _mm512_fmadd_pd(a_vec[3], alpha_x_real[7], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[3], alpha_x_imag[7], imag_acc); + + // Load first 4 elements of Y vector + y_vec = _mm512_loadu_pd(y0); + + // Permute and reduce the complex and real parts + imag_acc = _mm512_permute_pd(imag_acc, 0x55); + imag_acc = _mm512_fmaddsub_pd(zero_reg, zero_reg, imag_acc); + real_acc = _mm512_add_pd(real_acc, imag_acc); + + y_vec = _mm512_add_pd(y_vec, real_acc); + + // Store onto Y vector + _mm512_storeu_pd(y0, y_vec); + + y0 += 8; + a_ptr[0] += 8; + a_ptr[1] += 8; + a_ptr[2] += 8; + a_ptr[3] += 8; + a_ptr[4] += 8; + a_ptr[5] += 8; + a_ptr[6] += 8; + a_ptr[7] += 8; + } + } + else + { + for (; (i + 7) < m; i += 8) + { + // Load first 4 elements from first 4 columns of A + a_vec[0] = _mm512_loadu_pd(a_ptr[0]); + a_vec[1] = _mm512_loadu_pd(a_ptr[1]); + a_vec[2] = _mm512_loadu_pd(a_ptr[2]); + a_vec[3] = _mm512_loadu_pd(a_ptr[3]); + + // Multiply the loaded columns of A by alpha*X(real and imag) + real_acc = _mm512_mul_pd(a_vec[0], alpha_x_real[0]); + imag_acc = _mm512_mul_pd(a_vec[0], alpha_x_imag[0]); + + real_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_real[1], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_imag[1], imag_acc); + + real_acc = _mm512_fmadd_pd(a_vec[2], alpha_x_real[2], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[2], alpha_x_imag[2], imag_acc); + + real_acc = _mm512_fmadd_pd(a_vec[3], alpha_x_real[3], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[3], alpha_x_imag[3], imag_acc); + + // Load first 4 elements from next 4 columns of A + a_vec[0] = _mm512_loadu_pd(a_ptr[4]); + a_vec[1] = _mm512_loadu_pd(a_ptr[5]); + a_vec[2] = _mm512_loadu_pd(a_ptr[6]); + a_vec[3] = _mm512_loadu_pd(a_ptr[7]); + + // Multiply the loaded columns of A by alpha*X(real and imag) + real_acc = _mm512_fmadd_pd(a_vec[0], alpha_x_real[4], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[0], alpha_x_imag[4], imag_acc); + + real_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_real[5], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_imag[5], imag_acc); + + real_acc = _mm512_fmadd_pd(a_vec[2], alpha_x_real[6], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[2], alpha_x_imag[6], imag_acc); + + real_acc = _mm512_fmadd_pd(a_vec[3], alpha_x_real[7], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[3], alpha_x_imag[7], imag_acc); + + // Load first 4 elements of Y vector + y_vec = _mm512_loadu_pd(y0); + + // Permute and reduce the complex and real parts + imag_acc = _mm512_permute_pd(imag_acc, 0x55); + real_acc = _mm512_fmsubadd_pd(zero_reg, zero_reg, real_acc); + real_acc = _mm512_add_pd(real_acc, imag_acc); + + y_vec = _mm512_add_pd(y_vec, real_acc); + + // Store onto Y vector + _mm512_storeu_pd(y0, y_vec); + + // Load next 4 elements from first 4 columns of A + a_vec[0] = _mm512_loadu_pd(a_ptr[0] + 8); + a_vec[1] = _mm512_loadu_pd(a_ptr[1] + 8); + a_vec[2] = _mm512_loadu_pd(a_ptr[2] + 8); + a_vec[3] = _mm512_loadu_pd(a_ptr[3] + 8); + + // Multiply the loaded columns of A by alpha*X(real and imag) + real_acc = _mm512_mul_pd(a_vec[0], alpha_x_real[0]); + imag_acc = _mm512_mul_pd(a_vec[0], alpha_x_imag[0]); + + real_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_real[1], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_imag[1], imag_acc); + + real_acc = _mm512_fmadd_pd(a_vec[2], alpha_x_real[2], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[2], alpha_x_imag[2], imag_acc); + + real_acc = _mm512_fmadd_pd(a_vec[3], alpha_x_real[3], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[3], alpha_x_imag[3], imag_acc); + + // Load next 4 elements from next 4 columns of A + a_vec[0] = _mm512_loadu_pd(a_ptr[4] + 8); + a_vec[1] = _mm512_loadu_pd(a_ptr[5] + 8); + a_vec[2] = _mm512_loadu_pd(a_ptr[6] + 8); + a_vec[3] = _mm512_loadu_pd(a_ptr[7] + 8); + + // Multiply the loaded columns of A by alpha*X(real and imag) + real_acc = _mm512_fmadd_pd(a_vec[0], alpha_x_real[4], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[0], alpha_x_imag[4], imag_acc); + + real_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_real[5], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_imag[5], imag_acc); + + real_acc = _mm512_fmadd_pd(a_vec[2], alpha_x_real[6], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[2], alpha_x_imag[6], imag_acc); + + real_acc = _mm512_fmadd_pd(a_vec[3], alpha_x_real[7], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[3], alpha_x_imag[7], imag_acc); + + // Load next 4 elements of Y vector + y_vec = _mm512_loadu_pd(y0 + 8); + + // Permute and reduce the complex and real parts + imag_acc = _mm512_permute_pd(imag_acc, 0x55); + real_acc = _mm512_fmsubadd_pd(zero_reg, zero_reg, real_acc); + real_acc = _mm512_add_pd(real_acc, imag_acc); + + y_vec = _mm512_add_pd(y_vec, real_acc); + + // Store onto Y vector + _mm512_storeu_pd(y0 + 8, y_vec); + + y0 += 16; + a_ptr[0] += 16; + a_ptr[1] += 16; + a_ptr[2] += 16; + a_ptr[3] += 16; + a_ptr[4] += 16; + a_ptr[5] += 16; + a_ptr[6] += 16; + a_ptr[7] += 16; + } + + for (; (i + 3) < m; i += 4) + { + // Load first 4 elements from first 4 columns of A + a_vec[0] = _mm512_loadu_pd(a_ptr[0]); + a_vec[1] = _mm512_loadu_pd(a_ptr[1]); + a_vec[2] = _mm512_loadu_pd(a_ptr[2]); + a_vec[3] = _mm512_loadu_pd(a_ptr[3]); + + // Multiply the loaded columns of A by alpha*X(real and imag) + real_acc = _mm512_mul_pd(a_vec[0], alpha_x_real[0]); + imag_acc = _mm512_mul_pd(a_vec[0], alpha_x_imag[0]); + + real_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_real[1], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_imag[1], imag_acc); + + real_acc = _mm512_fmadd_pd(a_vec[2], alpha_x_real[2], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[2], alpha_x_imag[2], imag_acc); + + real_acc = _mm512_fmadd_pd(a_vec[3], alpha_x_real[3], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[3], alpha_x_imag[3], imag_acc); + + // Load first 4 elements from next 4 columns of A + a_vec[0] = _mm512_loadu_pd(a_ptr[4]); + a_vec[1] = _mm512_loadu_pd(a_ptr[5]); + a_vec[2] = _mm512_loadu_pd(a_ptr[6]); + a_vec[3] = _mm512_loadu_pd(a_ptr[7]); + + // Multiply the loaded columns of A by alpha*X(real and imag) + real_acc = _mm512_fmadd_pd(a_vec[0], alpha_x_real[4], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[0], alpha_x_imag[4], imag_acc); + + real_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_real[5], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_imag[5], imag_acc); + + real_acc = _mm512_fmadd_pd(a_vec[2], alpha_x_real[6], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[2], alpha_x_imag[6], imag_acc); + + real_acc = _mm512_fmadd_pd(a_vec[3], alpha_x_real[7], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[3], alpha_x_imag[7], imag_acc); + + // Load first 4 elements of Y vector + y_vec = _mm512_loadu_pd(y0); + + // Permute and reduce the complex and real parts + imag_acc = _mm512_permute_pd(imag_acc, 0x55); + real_acc = _mm512_fmsubadd_pd(zero_reg, zero_reg, real_acc); + real_acc = _mm512_add_pd(real_acc, imag_acc); + + y_vec = _mm512_add_pd(y_vec, real_acc); + + // Store onto Y vector + _mm512_storeu_pd(y0, y_vec); + + y0 += 8; + a_ptr[0] += 8; + a_ptr[1] += 8; + a_ptr[2] += 8; + a_ptr[3] += 8; + a_ptr[4] += 8; + a_ptr[5] += 8; + a_ptr[6] += 8; + a_ptr[7] += 8; + } + } + } + if( i < m ) + { + __mmask8 m_mask = (1 << 2*(m - i)) - 1; + if( bli_is_noconj(conja) ) + { + // Load remaining elements from first 4 columns of A + a_vec[0] = _mm512_maskz_loadu_pd(m_mask, a_ptr[0]); + a_vec[1] = _mm512_maskz_loadu_pd(m_mask, a_ptr[1]); + a_vec[2] = _mm512_maskz_loadu_pd(m_mask, a_ptr[2]); + a_vec[3] = _mm512_maskz_loadu_pd(m_mask, a_ptr[3]); + + // Multiply the loaded columns of A by alpha*X(real and imag) + real_acc = _mm512_mul_pd(a_vec[0], alpha_x_real[0]); + imag_acc = _mm512_mul_pd(a_vec[0], alpha_x_imag[0]); + + real_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_real[1], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_imag[1], imag_acc); + + real_acc = _mm512_fmadd_pd(a_vec[2], alpha_x_real[2], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[2], alpha_x_imag[2], imag_acc); + + real_acc = _mm512_fmadd_pd(a_vec[3], alpha_x_real[3], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[3], alpha_x_imag[3], imag_acc); + + // Load remaining elements from next 4 columns of A + a_vec[0] = _mm512_maskz_loadu_pd(m_mask, a_ptr[4]); + a_vec[1] = _mm512_maskz_loadu_pd(m_mask, a_ptr[5]); + a_vec[2] = _mm512_maskz_loadu_pd(m_mask, a_ptr[6]); + a_vec[3] = _mm512_maskz_loadu_pd(m_mask, a_ptr[7]); + + // Multiply the loaded columns of A by alpha*X(real and imag) + real_acc = _mm512_fmadd_pd(a_vec[0], alpha_x_real[4], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[0], alpha_x_imag[4], imag_acc); + + real_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_real[5], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_imag[5], imag_acc); + + real_acc = _mm512_fmadd_pd(a_vec[2], alpha_x_real[6], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[2], alpha_x_imag[6], imag_acc); + + real_acc = _mm512_fmadd_pd(a_vec[3], alpha_x_real[7], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[3], alpha_x_imag[7], imag_acc); + + // Load remaining elements of Y vector + y_vec = _mm512_maskz_loadu_pd(m_mask, y0); + + // Permute and reduce the complex and real parts + imag_acc = _mm512_permute_pd(imag_acc, 0x55); + imag_acc = _mm512_fmaddsub_pd(zero_reg, zero_reg, imag_acc); + real_acc = _mm512_add_pd(real_acc, imag_acc); + + y_vec = _mm512_add_pd(y_vec, real_acc); + + // Store onto Y vector + _mm512_mask_storeu_pd(y0, m_mask, y_vec); + } + else + { + // Load remaining elements from first 4 columns of A + a_vec[0] = _mm512_maskz_loadu_pd(m_mask, a_ptr[0]); + a_vec[1] = _mm512_maskz_loadu_pd(m_mask, a_ptr[1]); + a_vec[2] = _mm512_maskz_loadu_pd(m_mask, a_ptr[2]); + a_vec[3] = _mm512_maskz_loadu_pd(m_mask, a_ptr[3]); + + // Multiply the loaded columns of A by alpha*X(real and imag) + real_acc = _mm512_mul_pd(a_vec[0], alpha_x_real[0]); + imag_acc = _mm512_mul_pd(a_vec[0], alpha_x_imag[0]); + + real_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_real[1], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_imag[1], imag_acc); + + real_acc = _mm512_fmadd_pd(a_vec[2], alpha_x_real[2], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[2], alpha_x_imag[2], imag_acc); + + real_acc = _mm512_fmadd_pd(a_vec[3], alpha_x_real[3], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[3], alpha_x_imag[3], imag_acc); + + // Load remaining elements from next 4 columns of A + a_vec[0] = _mm512_maskz_loadu_pd(m_mask, a_ptr[4]); + a_vec[1] = _mm512_maskz_loadu_pd(m_mask, a_ptr[5]); + a_vec[2] = _mm512_maskz_loadu_pd(m_mask, a_ptr[6]); + a_vec[3] = _mm512_maskz_loadu_pd(m_mask, a_ptr[7]); + + // Multiply the loaded columns of A by alpha*X(real and imag) + real_acc = _mm512_fmadd_pd(a_vec[0], alpha_x_real[4], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[0], alpha_x_imag[4], imag_acc); + + real_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_real[5], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[1], alpha_x_imag[5], imag_acc); + + real_acc = _mm512_fmadd_pd(a_vec[2], alpha_x_real[6], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[2], alpha_x_imag[6], imag_acc); + + real_acc = _mm512_fmadd_pd(a_vec[3], alpha_x_real[7], real_acc); + imag_acc = _mm512_fmadd_pd(a_vec[3], alpha_x_imag[7], imag_acc); + + // Load remaining elements of Y vector + y_vec = _mm512_maskz_loadu_pd(m_mask, y0); + + // Permute and reduce the complex and real parts + imag_acc = _mm512_permute_pd(imag_acc, 0x55); + real_acc = _mm512_fmsubadd_pd(zero_reg, zero_reg, real_acc); + real_acc = _mm512_add_pd(real_acc, imag_acc); + + y_vec = _mm512_add_pd(y_vec, real_acc); + + // Store onto Y vector + _mm512_mask_storeu_pd(y0, m_mask, y_vec); + } + } + } + else + { + // Perform the computation with 128-bit registers, + // since dcomplex is 128 bits in size + __m128d a_vec[4], y_vec, real_acc, imag_acc, temp[8]; + + // Unpacking and storing real and imaginary components + // of alpha*X stored in x_vec[0...7] + temp[0] = _mm_unpackhi_pd(x_vec[0], x_vec[0]); + temp[1] = _mm_unpackhi_pd(x_vec[1], x_vec[1]); + temp[2] = _mm_unpackhi_pd(x_vec[2], x_vec[2]); + temp[3] = _mm_unpackhi_pd(x_vec[3], x_vec[3]); + temp[4] = _mm_unpackhi_pd(x_vec[4], x_vec[4]); + temp[5] = _mm_unpackhi_pd(x_vec[5], x_vec[5]); + temp[6] = _mm_unpackhi_pd(x_vec[6], x_vec[6]); + temp[7] = _mm_unpackhi_pd(x_vec[7], x_vec[7]); + + x_vec[0] = _mm_unpacklo_pd(x_vec[0], x_vec[0]); + x_vec[1] = _mm_unpacklo_pd(x_vec[1], x_vec[1]); + x_vec[2] = _mm_unpacklo_pd(x_vec[2], x_vec[2]); + x_vec[3] = _mm_unpacklo_pd(x_vec[3], x_vec[3]); + x_vec[4] = _mm_unpacklo_pd(x_vec[4], x_vec[4]); + x_vec[5] = _mm_unpacklo_pd(x_vec[5], x_vec[5]); + x_vec[6] = _mm_unpacklo_pd(x_vec[6], x_vec[6]); + x_vec[7] = _mm_unpacklo_pd(x_vec[7], x_vec[7]); + + if ( bli_is_noconj(conja) ) + { + for (; i < m; i++) + { + // Load elements from first 4 columns of A + a_vec[0] = _mm_loadu_pd(a_ptr[0]); + a_vec[1] = _mm_loadu_pd(a_ptr[1]); + a_vec[2] = _mm_loadu_pd(a_ptr[2]); + a_vec[3] = _mm_loadu_pd(a_ptr[3]); + + // Multiply the loaded columns of A by alpha*X(real and imag) + real_acc = _mm_mul_pd(a_vec[0], x_vec[0]); + imag_acc = _mm_mul_pd(a_vec[0], temp[0]); + + real_acc = _mm_fmadd_pd(a_vec[1], x_vec[1], real_acc); + imag_acc = _mm_fmadd_pd(a_vec[1], temp[1], imag_acc); + + real_acc = _mm_fmadd_pd(a_vec[2], x_vec[2], real_acc); + imag_acc = _mm_fmadd_pd(a_vec[2], temp[2], imag_acc); + + real_acc = _mm_fmadd_pd(a_vec[3], x_vec[3], real_acc); + imag_acc = _mm_fmadd_pd(a_vec[3], temp[3], imag_acc); + + // Load elements from next 4 columns of A + a_vec[0] = _mm_loadu_pd(a_ptr[4]); + a_vec[1] = _mm_loadu_pd(a_ptr[5]); + a_vec[2] = _mm_loadu_pd(a_ptr[6]); + a_vec[3] = _mm_loadu_pd(a_ptr[7]); + + // Multiply the loaded columns of A by alpha*X(real and imag) + real_acc = _mm_fmadd_pd(a_vec[0], x_vec[4], real_acc); + imag_acc = _mm_fmadd_pd(a_vec[0], temp[4], imag_acc); + + real_acc = _mm_fmadd_pd(a_vec[1], x_vec[5], real_acc); + imag_acc = _mm_fmadd_pd(a_vec[1], temp[5], imag_acc); + + real_acc = _mm_fmadd_pd(a_vec[2], x_vec[6], real_acc); + imag_acc = _mm_fmadd_pd(a_vec[2], temp[6], imag_acc); + + real_acc = _mm_fmadd_pd(a_vec[3], x_vec[7], real_acc); + imag_acc = _mm_fmadd_pd(a_vec[3], temp[7], imag_acc); + + // Load Y vector + y_vec = _mm_loadu_pd(y0); + + // Permute and reduce the complex and real parts + imag_acc = _mm_permute_pd(imag_acc, 0b01); + real_acc = _mm_addsub_pd(real_acc, imag_acc); + + y_vec = _mm_add_pd(y_vec, real_acc); + + // Store Y vector + _mm_storeu_pd(y0, y_vec); + + y0 += 2 * incy; + a_ptr[0] += 2 * inca; + a_ptr[1] += 2 * inca; + a_ptr[2] += 2 * inca; + a_ptr[3] += 2 * inca; + a_ptr[4] += 2 * inca; + a_ptr[5] += 2 * inca; + a_ptr[6] += 2 * inca; + a_ptr[7] += 2 * inca; + } + } + else + { + for (; i < m; i++) + { + // Load elements from first 4 columns of A + a_vec[0] = _mm_loadu_pd(a_ptr[0]); + a_vec[1] = _mm_loadu_pd(a_ptr[1]); + a_vec[2] = _mm_loadu_pd(a_ptr[2]); + a_vec[3] = _mm_loadu_pd(a_ptr[3]); + + // Multiply the loaded columns of A by alpha*X(real and imag) + real_acc = _mm_mul_pd(a_vec[0], x_vec[0]); + imag_acc = _mm_mul_pd(a_vec[0], temp[0]); + + real_acc = _mm_fmadd_pd(a_vec[1], x_vec[1], real_acc); + imag_acc = _mm_fmadd_pd(a_vec[1], temp[1], imag_acc); + + real_acc = _mm_fmadd_pd(a_vec[2], x_vec[2], real_acc); + imag_acc = _mm_fmadd_pd(a_vec[2], temp[2], imag_acc); + + real_acc = _mm_fmadd_pd(a_vec[3], x_vec[3], real_acc); + imag_acc = _mm_fmadd_pd(a_vec[3], temp[3], imag_acc); + + // Load elements from next 4 columns of A + a_vec[0] = _mm_loadu_pd(a_ptr[4]); + a_vec[1] = _mm_loadu_pd(a_ptr[5]); + a_vec[2] = _mm_loadu_pd(a_ptr[6]); + a_vec[3] = _mm_loadu_pd(a_ptr[7]); + + // Multiply the loaded columns of A by alpha*X(real and imag) + real_acc = _mm_fmadd_pd(a_vec[0], x_vec[4], real_acc); + imag_acc = _mm_fmadd_pd(a_vec[0], temp[4], imag_acc); + + real_acc = _mm_fmadd_pd(a_vec[1], x_vec[5], real_acc); + imag_acc = _mm_fmadd_pd(a_vec[1], temp[5], imag_acc); + + real_acc = _mm_fmadd_pd(a_vec[2], x_vec[6], real_acc); + imag_acc = _mm_fmadd_pd(a_vec[2], temp[6], imag_acc); + + real_acc = _mm_fmadd_pd(a_vec[3], x_vec[7], real_acc); + imag_acc = _mm_fmadd_pd(a_vec[3], temp[7], imag_acc); + + // Load Y vector + y_vec = _mm_loadu_pd(y0); + + // Permute and reduce the complex and real parts + real_acc = _mm_permute_pd(real_acc, 0b01); + real_acc = _mm_addsub_pd(imag_acc, real_acc); + real_acc = _mm_permute_pd(real_acc, 0b01); + + y_vec = _mm_add_pd(y_vec, real_acc); + + // Store Y vector + _mm_storeu_pd(y0, y_vec); + + y0 += 2 * incy; + a_ptr[0] += 2 * inca; + a_ptr[1] += 2 * inca; + a_ptr[2] += 2 * inca; + a_ptr[3] += 2 * inca; + a_ptr[4] += 2 * inca; + a_ptr[5] += 2 * inca; + a_ptr[6] += 2 * inca; + a_ptr[7] += 2 * inca; + } + } + } +} + + +void bli_daxpyf_zen_int2_avx512 + ( + conj_t conja, + conj_t conjx, + dim_t m, + dim_t b_n, + double* restrict alpha, + double* restrict a, inc_t inca, inc_t lda, + double* restrict x, inc_t incx, + double* restrict y0, inc_t incy, + cntx_t* restrict cntx + ) +{ + const dim_t n_elem_per_reg = 8; + dim_t i = 0; + __m512d chi[2]; + __m512d av[2]; + __m512d yv; + double* as[2] __attribute__((aligned(64))); + double* y = y0; + + // If either dimension is zero, or if alpha is zero, return early. + if ( bli_zero_dim2( m, b_n ) || bli_deq0( *alpha ) ) + return; + + // If b_n is not equal to the fusing factor, then perform the entire + // operation as a loop over axpyv. + if ( b_n != 2 ) + { + // Definition of function pointer + daxpyv_ker_ft axpyv_ker_ptr = bli_daxpyv_zen_int_avx512; + + for ( i = 0; i < b_n; ++i ) + { + double* a1 = a + (i )*lda; + double* chi1 = x + (i )*incx; + double alphavchi1; + + bli_dcopycjs( conjx, *chi1, alphavchi1 ); + bli_dscals( *alpha, alphavchi1 ); + + axpyv_ker_ptr + ( + conja, + m, + &alphavchi1, + a1, inca, + y, incy, + cntx + ); + } + return; + } + + // At this point, we know that b_n is exactly equal to the fusing factor. + // Load the address of the first element of each column into an array. + as[0] = a + (0 * lda); + as[1] = a + (1 * lda); + + // Multiple the elements in the vector with alpha and broadcast the results into __m512 variables + chi[0] = _mm512_set1_pd( (*alpha) * (*(x + 0 * incx)) ); + chi[1] = _mm512_set1_pd( (*alpha) * (*(x + 1 * incx)) ); + + // If there are vectorized iterations, perform them with vector instructions. + // The execution can be vectorized only when the strides are equal to 1 + if ( inca == 1 && incy == 1 ) + { + + for ( ; i + n_elem_per_reg <= m; i += n_elem_per_reg) + { + // The existing value in y is loaded into a __m512 variable. + yv = _mm512_loadu_pd( y ); + + // Load 12 elements from each column into __m512 variables + // The elements will be stored using the pointers in the array "as" + av[0] = _mm512_loadu_pd( as[0] ); + av[1] = _mm512_loadu_pd( as[1] ); + + // After loading the elements into the __m512 variable, the pointer will be updated + as[0] += n_elem_per_reg; + as[1] += n_elem_per_reg; + + // fused-multiplication-add is used to multiple 8 elements in each column of the matrix + // with one element in the vector and store the results in multiple __m512 variables. + yv = _mm512_fmadd_pd( av[0], chi[0], yv ); + yv = _mm512_fmadd_pd( av[1], chi[1], yv ); + + // Store the result from the __m512 variable into the destination + _mm512_storeu_pd( (double *)(y ), yv ); + + y += n_elem_per_reg; + + } + + // Handling Fringe cases + if ( m > i ) + { + // Declaring and initialising the mask + __mmask8 m_mask = (1 << (m - i)) - 1; + + yv= _mm512_mask_loadu_pd( chi[0], m_mask, y ); + + // Load the remaining elements in each column into __m512 variables using mask operations + av[0] = _mm512_maskz_loadu_pd( m_mask, as[0] ); + av[1] = _mm512_maskz_loadu_pd( m_mask, as[1] ); + + // Use fused-multiply-add operations to multiple the columns in the matrix with the elements of the vector + yv = _mm512_fmadd_pd( av[0], chi[0], yv ); + yv = _mm512_fmadd_pd( av[1], chi[1], yv ); + + // Store the result from the __m512 variable into the destination + _mm512_mask_storeu_pd( (double *)(y ), m_mask, yv ); + } + } + // To handle inputs that cannot be vectorized + else + { + double yc = *y; + double chi_s[2]; + + // The elements in the vector are multipled with alpha and the result is stored in an array + chi_s[0] = *(x + 0 * incx) * *alpha; + chi_s[1] = *(x + 1 * incx) * *alpha; + + + // A loop is used to iterate over the matrix row-by-row. + // The elements in each row are multipled with each value in the array + for ( i = 0; (i + 0) < m ; ++i ) + { + yc = *y; + + yc += chi_s[0] * (*as[0]); + as[0] += inca; + + yc += chi_s[1] * (*as[1]); + as[1] += inca; + + *y = yc; + y += incy; + } + } +} + +void bli_daxpyf_zen_int4_avx512 + ( + conj_t conja, + conj_t conjx, + dim_t m, + dim_t b_n, + double* restrict alpha, + double* restrict a, inc_t inca, inc_t lda, + double* restrict x, inc_t incx, + double* restrict y0, inc_t incy, + cntx_t* restrict cntx + ) +{ + const dim_t n_elem_per_reg = 8; + dim_t i = 0; + __m512d chi[4]; + __m512d av[4]; + __m512d yv; + double* as[4] __attribute__((aligned(64))); + double* y = y0; + + // If either dimension is zero, or if alpha is zero, return early. + if ( bli_zero_dim2( m, b_n ) || bli_deq0( *alpha ) ) + return; + + // If b_n is not equal to the fusing factor, then perform the entire + // operation as a loop over axpyv. + if ( b_n != 4 ) + { + // Definition of function pointer + daxpyv_ker_ft axpyv_ker_ptr = bli_daxpyv_zen_int_avx512; + + for ( i = 0; i < b_n; ++i ) + { + double* a1 = a + (i )*lda; + double* chi1 = x + (i )*incx; + double alphavchi1; + + bli_dcopycjs( conjx, *chi1, alphavchi1 ); + bli_dscals( *alpha, alphavchi1 ); + + axpyv_ker_ptr + ( + conja, + m, + &alphavchi1, + a1, inca, + y, incy, + cntx + ); + } + return; + } + + // At this point, we know that b_n is exactly equal to the fusing factor. + // Load the address of the first element of each column into an array. + as[0] = a + (0 * lda); + as[1] = a + (1 * lda); + as[2] = a + (2 * lda); + as[3] = a + (3 * lda); + + // Multiple the elements in the vector with alpha and broadcast the results into __m512 variables + chi[0] = _mm512_set1_pd( (*alpha) * (*(x + 0 * incx)) ); + chi[1] = _mm512_set1_pd( (*alpha) * (*(x + 1 * incx)) ); + chi[2] = _mm512_set1_pd( (*alpha) * (*(x + 2 * incx)) ); + chi[3] = _mm512_set1_pd( (*alpha) * (*(x + 3 * incx)) ); + + // If there are vectorized iterations, perform them with vector instructions. + // The execution can be vectorized only when the strides are equal to 1 + if ( inca == 1 && incy == 1 ) + { + + for ( ; i + n_elem_per_reg <= m; i += n_elem_per_reg) + { + // The existing value in y is loaded into a __m512 variable. + yv = _mm512_loadu_pd( y ); + + // Load 12 elements from each column into __m512 variables + // The elements will be stored using the pointers in the array "as" + av[0] = _mm512_loadu_pd( as[0] ); + av[1] = _mm512_loadu_pd( as[1] ); + av[2] = _mm512_loadu_pd( as[2] ); + av[3] = _mm512_loadu_pd( as[3] ); + + // After loading the elements into the __m512 variable, the pointer will be updated + as[0] += n_elem_per_reg; + as[1] += n_elem_per_reg; + as[2] += n_elem_per_reg; + as[3] += n_elem_per_reg; + + // fused-multiplication-add is used to multiple 8 elements in each column of the matrix + // with one element in the vector and store the results in multiple __m512 variables. + yv = _mm512_fmadd_pd( av[0], chi[0], yv ); + yv = _mm512_fmadd_pd( av[1], chi[1], yv ); + yv = _mm512_fmadd_pd( av[2], chi[2], yv ); + yv = _mm512_fmadd_pd( av[3], chi[3], yv ); + + // Store the result from the __m512 variable into the destination + _mm512_storeu_pd( (double *)(y ), yv ); + + y += n_elem_per_reg; + + } + + // Handling Fringe cases + if ( m > i ) + { + // Declaring and initialising the mask + __mmask8 m_mask = (1 << (m - i)) - 1; + + yv= _mm512_mask_loadu_pd( chi[0], m_mask, y ); + + // Load the remaining elements in each column into __m512 variables using mask operations + av[0] = _mm512_maskz_loadu_pd( m_mask, as[0] ); + av[1] = _mm512_maskz_loadu_pd( m_mask, as[1] ); + av[2] = _mm512_maskz_loadu_pd( m_mask, as[2] ); + av[3] = _mm512_maskz_loadu_pd( m_mask, as[3] ); + + // Use fused-multiply-add operations to multiple the columns in the matrix with the elements of the vector + yv = _mm512_fmadd_pd( av[0], chi[0], yv ); + yv = _mm512_fmadd_pd( av[1], chi[1], yv ); + yv = _mm512_fmadd_pd( av[2], chi[2], yv ); + yv = _mm512_fmadd_pd( av[3], chi[3], yv ); + + // Store the result from the __m512 variable into the destination + _mm512_mask_storeu_pd( (double *)(y ), m_mask, yv ); + } + } + // To handle inputs that cannot be vectorized + else + { + double yc = *y; + double chi_s[4]; + + // The elements in the vector are multipled with alpha and the result is stored in an array + chi_s[0] = *(x + 0 * incx) * *alpha; + chi_s[1] = *(x + 1 * incx) * *alpha; + chi_s[2] = *(x + 2 * incx) * *alpha; + chi_s[3] = *(x + 3 * incx) * *alpha; + + + // A loop is used to iterate over the matrix row-by-row. + // The elements in each row are multipled with each value in the array + for ( i = 0; (i + 0) < m ; ++i ) + { + yc = *y; + + yc += chi_s[0] * (*as[0]); + as[0] += inca; + + yc += chi_s[1] * (*as[1]); + as[1] += inca; + + yc += chi_s[2] * (*as[2]); + as[2] += inca; + + yc += chi_s[3] * (*as[3]); + as[3] += inca; + + *y = yc; + y += incy; + } + } +} + +void bli_daxpyf_zen_int8_avx512 + ( + conj_t conja, + conj_t conjx, + dim_t m, + dim_t b_n, + double* restrict alpha, + double* restrict a, inc_t inca, inc_t lda, + double* restrict x, inc_t incx, + double* restrict y0, inc_t incy, + cntx_t* restrict cntx + ) +{ + + const dim_t n_elem_per_reg = 8; + dim_t i = 0; + double* y = y0; + double* as[8] __attribute__((aligned(64))); + __m512d chi[8]; + __m512d av[8]; + __m512d yv[8]; + + + + // If either dimension is zero, or if alpha is zero, return early. + if ( bli_zero_dim2( m, b_n ) || bli_deq0( *alpha ) ) + return; + + // If b_n is not equal to the fusing factor, then perform the entire + // operation as a loop over axpyv. + if ( b_n != 8 ) + { + // Definition of function pointer + daxpyv_ker_ft axpyv_ker_ptr = bli_daxpyv_zen_int_avx512; + + for ( i = 0; i < b_n; ++i ) + { + double* a1 = a + (i )*lda; + double* chi1 = x + (i )*incx; + double alphavchi1; + + bli_dcopycjs( conjx, *chi1, alphavchi1 ); + bli_dscals( *alpha, alphavchi1 ); + + axpyv_ker_ptr + ( + conja, + m, + &alphavchi1, + a1, inca, + y, incy, + cntx + ); + } + return; + } + + // At this point, we know that b_n is exactly equal to the fusing factor. + // Load the address of the first element of each column into an array. + as[0] = a + (0 * lda); + as[1] = a + (1 * lda); + as[2] = a + (2 * lda); + as[3] = a + (3 * lda); + + as[4] = a + (4 * lda); + as[5] = a + (5 * lda); + as[6] = a + (6 * lda); + as[7] = a + (7 * lda); + + // Multiple the elements in the vector with alpha and broadcast the results into __m512 variables + chi[0] = _mm512_set1_pd( (*alpha) * (*(x + 0 * incx)) ); + chi[1] = _mm512_set1_pd( (*alpha) * (*(x + 1 * incx)) ); + chi[2] = _mm512_set1_pd( (*alpha) * (*(x + 2 * incx)) ); + chi[3] = _mm512_set1_pd( (*alpha) * (*(x + 3 * incx)) ); + + chi[4] = _mm512_set1_pd( (*alpha) * (*(x + 4 * incx)) ); + chi[5] = _mm512_set1_pd( (*alpha) * (*(x + 5 * incx)) ); + chi[6] = _mm512_set1_pd( (*alpha) * (*(x + 6 * incx)) ); + chi[7] = _mm512_set1_pd( (*alpha) * (*(x + 7 * incx)) ); + + + // If there are vectorized iterations, perform them with vector instructions. + // The execution can be vectorized only when the strides are equal to 1 + if ( inca == 1 && incy == 1 ) + { + // Execute the loop with 8 rows of the matrix at a time. + // The loop is executed until less than 8 elements are remaining + for ( ; i + n_elem_per_reg <= m; i += n_elem_per_reg) + { + // Initialize the value of yv[7] to zero + // It will be used to store the result + yv[7] = _mm512_setzero_pd(); + + // Load 8 elements from each column into __m512 variables + // The elements will be stored using the pointers in the array as[] + av[0] = _mm512_loadu_pd( as[0] ); + av[1] = _mm512_loadu_pd( as[1] ); + av[2] = _mm512_loadu_pd( as[2] ); + av[3] = _mm512_loadu_pd( as[3] ); + av[4] = _mm512_loadu_pd( as[4] ); + av[5] = _mm512_loadu_pd( as[5] ); + av[6] = _mm512_loadu_pd( as[6] ); + av[7] = _mm512_loadu_pd( as[7] ); + + // After loading the elements into the __m512 variable, the pointer will be updated + as[0] += n_elem_per_reg; + as[1] += n_elem_per_reg; + as[2] += n_elem_per_reg; + as[3] += n_elem_per_reg; + as[4] += n_elem_per_reg; + as[5] += n_elem_per_reg; + as[6] += n_elem_per_reg; + as[7] += n_elem_per_reg; + + // fused-multiplication-add is used to multiple 8 elements in each column of the matrix + // with one element in the vector and store the results in multiple __m512 variables. + // Use of multiple __m512 variables reduces operand dependancy between the instructions. + yv[0] = _mm512_fmadd_pd( av[0], chi[0], yv[7] ); + yv[1] = _mm512_fmadd_pd( av[1], chi[1], yv[7] ); + yv[2] = _mm512_fmadd_pd( av[2], chi[2], yv[7] ); + yv[3] = _mm512_fmadd_pd( av[3], chi[3], yv[7] ); + yv[4] = _mm512_fmadd_pd( av[4], chi[4], yv[7] ); + yv[5] = _mm512_fmadd_pd( av[5], chi[5], yv[7] ); + yv[6] = _mm512_fmadd_pd( av[6], chi[6], yv[7] ); + yv[7] = _mm512_fmadd_pd( av[7], chi[7], yv[7] ); + + // The values in the 8 __m512 variables together and store it in a __m512 variable. + yv[0] = _mm512_add_pd( yv[0], yv[1] ); + yv[2] = _mm512_add_pd( yv[2], yv[3] ); + yv[4] = _mm512_add_pd( yv[4], yv[5] ); + yv[6] = _mm512_add_pd( yv[6], yv[7] ); + + // The existing value in y is loaded into a __m512 variable. + // It is then added together with the other __m512 variables. + yv[7] = _mm512_loadu_pd( y ); + yv[3] = _mm512_add_pd( yv[0], yv[2] ); + yv[5] = _mm512_add_pd( yv[4], yv[6] ); + + yv[1] = _mm512_add_pd( yv[3], yv[5] ); + yv[7] = _mm512_add_pd( yv[1], yv[7] ); + + // Store the result from the __m512 variable into the destination + _mm512_storeu_pd( (double *)(y ), yv[7] ); + + y += n_elem_per_reg; + + } + + // Handling Fringe cases using masked operations + if ( m > i ) + { + // Declaring and initialising the mask + __mmask8 m_mask = (1 << (m - i)) - 1; + + yv[7] = _mm512_setzero_pd(); + + // Load the remaining elements in each column into __m512 variables using mask operations + av[0] = _mm512_maskz_loadu_pd( m_mask, as[0] ); + av[1] = _mm512_maskz_loadu_pd( m_mask, as[1] ); + av[2] = _mm512_maskz_loadu_pd( m_mask, as[2] ); + av[3] = _mm512_maskz_loadu_pd( m_mask, as[3] ); + av[4] = _mm512_maskz_loadu_pd( m_mask, as[4] ); + av[5] = _mm512_maskz_loadu_pd( m_mask, as[5] ); + av[6] = _mm512_maskz_loadu_pd( m_mask, as[6] ); + av[7] = _mm512_maskz_loadu_pd( m_mask, as[7] ); + + // Use fused-multiply-add operations to multiple the columns in the matrix with the elements of the vector + yv[0] = _mm512_fmadd_pd( av[0], chi[0], yv[7] ); + yv[1] = _mm512_fmadd_pd( av[1], chi[1], yv[7] ); + yv[2] = _mm512_fmadd_pd( av[2], chi[2], yv[7] ); + yv[3] = _mm512_fmadd_pd( av[3], chi[3], yv[7] ); + yv[4] = _mm512_fmadd_pd( av[4], chi[4], yv[7] ); + yv[5] = _mm512_fmadd_pd( av[5], chi[5], yv[7] ); + yv[6] = _mm512_fmadd_pd( av[6], chi[6], yv[7] ); + yv[7] = _mm512_fmadd_pd( av[7], chi[7], yv[7] ); + + // The values in the 8 __m512 variables together and store it in a __m512 variable + yv[0] = _mm512_add_pd( yv[0], yv[1] ); + yv[2] = _mm512_add_pd( yv[2], yv[3] ); + yv[4] = _mm512_add_pd( yv[4], yv[5] ); + yv[6] = _mm512_add_pd( yv[6], yv[7] ); + + // The existing value in y is loaded into a __m512 variable. + // It is then added together with the other __m512 variables. + yv[7]= _mm512_mask_loadu_pd( chi[0], m_mask, y ); + yv[3] = _mm512_add_pd( yv[0], yv[2] ); + yv[5] = _mm512_add_pd( yv[4], yv[6] ); + + yv[1] = _mm512_add_pd( yv[3], yv[5] ); + yv[7] = _mm512_add_pd( yv[1], yv[7] ); + + // Store the result from the __m512 variable into the destination + _mm512_mask_storeu_pd( (double *)(y ), m_mask, yv[7]); + } + } + + // To handle inputs that cannot be vectorized + else + { + double yc = *y; + double chi_s[8]; + + // The elements in the vector are multipled with alpha and the result is stored in an array + chi_s[0] = *(x + 0 * incx) * *alpha; + chi_s[1] = *(x + 1 * incx) * *alpha; + chi_s[2] = *(x + 2 * incx) * *alpha; + chi_s[3] = *(x + 3 * incx) * *alpha; + chi_s[4] = *(x + 4 * incx) * *alpha; + chi_s[5] = *(x + 5 * incx) * *alpha; + chi_s[6] = *(x + 6 * incx) * *alpha; + chi_s[7] = *(x + 7 * incx) * *alpha; + + // A loop is used to iterate over the matrix row-by-row. + // The elements in each row are multipled with each value in the array + for ( i = 0; (i + 0) < m ; i++ ) + { + yc = *y; + + yc += chi_s[0] * (*as[0]); + as[0] += inca; + + yc += chi_s[1] * (*as[1]); + as[1] += inca; + + yc += chi_s[2] * (*as[2]); + as[2] += inca; + + yc += chi_s[3] * (*as[3]); + as[3] += inca; + + yc += chi_s[4] * (*as[4]); + as[4] += inca; + + yc += chi_s[5] * (*as[5]); + as[5] += inca; + + yc += chi_s[6] * (*as[6]); + as[6] += inca; + + yc += chi_s[7] * (*as[7]); + as[7] += inca; + + *y = yc; + y += incy; + } + } +} + +void bli_daxpyf_zen_int12_avx512 + ( + conj_t conja, + conj_t conjx, + dim_t m, + dim_t b_n, + double* restrict alpha, + double* restrict a, inc_t inca, inc_t lda, + double* restrict x, inc_t incx, + double* restrict y0, inc_t incy, + cntx_t* restrict cntx + ) +{ + const dim_t n_elem_per_reg = 8; + dim_t i = 0; + __m512d chi[12]; + __m512d av[12]; + __m512d yv; + double* as[12] __attribute__((aligned(64))); + double* y = y0; + + // If either dimension is zero, or if alpha is zero, return early. + if ( bli_zero_dim2( m, b_n ) || bli_deq0( *alpha ) ) + return; + + // If b_n is not equal to the fusing factor, then perform the entire + // operation as a loop over axpyv. + if ( b_n != 12 ) + { + // Definition of function pointer + daxpyv_ker_ft axpyv_ker_ptr = bli_daxpyv_zen_int_avx512; + + for ( i = 0; i < b_n; ++i ) + { + double* a1 = a + (i )*lda; + double* chi1 = x + (i )*incx; + double alphavchi1; + + bli_dcopycjs( conjx, *chi1, alphavchi1 ); + bli_dscals( *alpha, alphavchi1 ); + + axpyv_ker_ptr + ( + conja, + m, + &alphavchi1, + a1, inca, + y, incy, + cntx + ); + } + return; + } + + // At this point, we know that b_n is exactly equal to the fusing factor. + // Load the address of the first element of each column into an array. + as[0] = a + (0 * lda); + as[1] = a + (1 * lda); + as[2] = a + (2 * lda); + as[3] = a + (3 * lda); + + as[4] = a + (4 * lda); + as[5] = a + (5 * lda); + as[6] = a + (6 * lda); + as[7] = a + (7 * lda); + + as[8] = a + (8 * lda); + as[9] = a + (9 * lda); + as[10] = a + (10 * lda); + as[11] = a + (11 * lda); + + // Multiple the elements in the vector with alpha and broadcast the results into __m512 variables + chi[0] = _mm512_set1_pd( (*alpha) * (*(x + 0 * incx)) ); + chi[1] = _mm512_set1_pd( (*alpha) * (*(x + 1 * incx)) ); + chi[2] = _mm512_set1_pd( (*alpha) * (*(x + 2 * incx)) ); + chi[3] = _mm512_set1_pd( (*alpha) * (*(x + 3 * incx)) ); + + chi[4] = _mm512_set1_pd( (*alpha) * (*(x + 4 * incx)) ); + chi[5] = _mm512_set1_pd( (*alpha) * (*(x + 5 * incx)) ); + chi[6] = _mm512_set1_pd( (*alpha) * (*(x + 6 * incx)) ); + chi[7] = _mm512_set1_pd( (*alpha) * (*(x + 7 * incx)) ); + + chi[8] = _mm512_set1_pd( (*alpha) * (*(x + 8 * incx)) ); + chi[9] = _mm512_set1_pd( (*alpha) * (*(x + 9 * incx)) ); + chi[10] = _mm512_set1_pd( (*alpha) * (*(x + 10 * incx)) ); + chi[11] = _mm512_set1_pd( (*alpha) * (*(x + 11 * incx)) ); + + + // If there are vectorized iterations, perform them with vector instructions. + // The execution can be vectorized only when the strides are equal to 1 + if ( inca == 1 && incy == 1 ) + { + + for ( ; i + n_elem_per_reg <= m; i += n_elem_per_reg) + { + // The existing value in y is loaded into a __m512 variable. + yv = _mm512_loadu_pd( y ); + + // Load 12 elements from each column into __m512 variables + // The elements will be stored using the pointers in the array "as" + av[0] = _mm512_loadu_pd( as[0] ); + av[1] = _mm512_loadu_pd( as[1] ); + av[2] = _mm512_loadu_pd( as[2] ); + av[3] = _mm512_loadu_pd( as[3] ); + av[4] = _mm512_loadu_pd( as[4] ); + av[5] = _mm512_loadu_pd( as[5] ); + av[6] = _mm512_loadu_pd( as[6] ); + av[7] = _mm512_loadu_pd( as[7] ); + av[8] = _mm512_loadu_pd( as[8] ); + av[9] = _mm512_loadu_pd( as[9] ); + av[10] = _mm512_loadu_pd( as[10] ); + av[11] = _mm512_loadu_pd( as[11] ); + + // After loading the elements into the __m512 variable, the pointer will be updated + as[0] += n_elem_per_reg; + as[1] += n_elem_per_reg; + as[2] += n_elem_per_reg; + as[3] += n_elem_per_reg; + as[4] += n_elem_per_reg; + as[5] += n_elem_per_reg; + as[6] += n_elem_per_reg; + as[7] += n_elem_per_reg; + as[8] += n_elem_per_reg; + as[9] += n_elem_per_reg; + as[10] += n_elem_per_reg; + as[11] += n_elem_per_reg; + + // fused-multiplication-add is used to multiple 8 elements in each column of the matrix + // with one element in the vector and store the results in multiple __m512 variables. + yv = _mm512_fmadd_pd( av[0], chi[0], yv ); + yv = _mm512_fmadd_pd( av[1], chi[1], yv ); + yv = _mm512_fmadd_pd( av[2], chi[2], yv ); + yv = _mm512_fmadd_pd( av[3], chi[3], yv ); + yv = _mm512_fmadd_pd( av[4], chi[4], yv ); + yv = _mm512_fmadd_pd( av[5], chi[5], yv ); + yv = _mm512_fmadd_pd( av[6], chi[6], yv ); + yv = _mm512_fmadd_pd( av[7], chi[7], yv ); + yv = _mm512_fmadd_pd( av[8], chi[8], yv ); + yv = _mm512_fmadd_pd( av[9], chi[9], yv ); + yv = _mm512_fmadd_pd( av[10], chi[10], yv ); + yv = _mm512_fmadd_pd( av[11], chi[11], yv ); + + // Store the result from the __m512 variable into the destination + _mm512_storeu_pd( (double *)(y ), yv ); + + y += n_elem_per_reg; + + } + + // Handling Fringe cases + if ( m > i ) + { + // Declaring and initialising the mask + __mmask8 m_mask = (1 << (m - i)) - 1; + + yv= _mm512_mask_loadu_pd( chi[0], m_mask, y ); + + // Load the remaining elements in each column into __m512 variables using mask operations + av[0] = _mm512_maskz_loadu_pd( m_mask, as[0] ); + av[1] = _mm512_maskz_loadu_pd( m_mask, as[1] ); + av[2] = _mm512_maskz_loadu_pd( m_mask, as[2] ); + av[3] = _mm512_maskz_loadu_pd( m_mask, as[3] ); + av[4] = _mm512_maskz_loadu_pd( m_mask, as[4] ); + av[5] = _mm512_maskz_loadu_pd( m_mask, as[5] ); + av[6] = _mm512_maskz_loadu_pd( m_mask, as[6] ); + av[7] = _mm512_maskz_loadu_pd( m_mask, as[7] ); + av[8] = _mm512_maskz_loadu_pd( m_mask, as[8] ); + av[9] = _mm512_maskz_loadu_pd( m_mask, as[9] ); + av[10] = _mm512_maskz_loadu_pd( m_mask, as[10] ); + av[11] = _mm512_maskz_loadu_pd( m_mask, as[11] ); + + // Use fused-multiply-add operations to multiple the columns in the matrix with the elements of the vector + yv = _mm512_fmadd_pd( av[0], chi[0], yv ); + yv = _mm512_fmadd_pd( av[1], chi[1], yv ); + yv = _mm512_fmadd_pd( av[2], chi[2], yv ); + yv = _mm512_fmadd_pd( av[3], chi[3], yv ); + yv = _mm512_fmadd_pd( av[4], chi[4], yv ); + yv = _mm512_fmadd_pd( av[5], chi[5], yv ); + yv = _mm512_fmadd_pd( av[6], chi[6], yv ); + yv = _mm512_fmadd_pd( av[7], chi[7], yv ); + yv = _mm512_fmadd_pd( av[8], chi[8], yv ); + yv = _mm512_fmadd_pd( av[9], chi[9], yv ); + yv = _mm512_fmadd_pd( av[10], chi[10], yv ); + yv = _mm512_fmadd_pd( av[11], chi[11], yv ); + + // Store the result from the __m512 variable into the destination + _mm512_mask_storeu_pd( (double *)(y ), m_mask, yv ); + } + } + // To handle inputs that cannot be vectorized + else + { + double yc = *y; + double chi_s[12]; + + // The elements in the vector are multipled with alpha and the result is stored in an array + chi_s[0] = *(x + 0 * incx) * *alpha; + chi_s[1] = *(x + 1 * incx) * *alpha; + chi_s[2] = *(x + 2 * incx) * *alpha; + chi_s[3] = *(x + 3 * incx) * *alpha; + + chi_s[4] = *(x + 4 * incx) * *alpha; + chi_s[5] = *(x + 5 * incx) * *alpha; + chi_s[6] = *(x + 6 * incx) * *alpha; + chi_s[7] = *(x + 7 * incx) * *alpha; + + chi_s[8] = *(x + 8 * incx) * *alpha; + chi_s[9] = *(x + 9 * incx) * *alpha; + chi_s[10] = *(x + 10 * incx) * *alpha; + chi_s[11] = *(x + 11 * incx) * *alpha; + + + // A loop is used to iterate over the matrix row-by-row. + // The elements in each row are multipled with each value in the array + for ( i = 0; (i + 0) < m ; ++i ) + { + yc = *y; + + yc += chi_s[0] * (*as[0]); + as[0] += inca; + + yc += chi_s[1] * (*as[1]); + as[1] += inca; + + yc += chi_s[2] * (*as[2]); + as[2] += inca; + + yc += chi_s[3] * (*as[3]); + as[3] += inca; + + yc += chi_s[4] * (*as[4]); + as[4] += inca; + + yc += chi_s[5] * (*as[5]); + as[5] += inca; + + yc += chi_s[6] * (*as[6]); + as[6] += inca; + + yc += chi_s[7] * (*as[7]); + as[7] += inca; + + yc += chi_s[8] * (*as[8]); + as[8] += inca; + + yc += chi_s[9] * (*as[9]); + as[9] += inca; + + yc += chi_s[10] * (*as[10]); + as[10] += inca; + + yc += chi_s[11] * (*as[11]); + as[11] += inca; + + *y = yc; + y += incy; + } + } +} diff --git a/kernels/zen4/1f/bli_dotxf_zen_int_avx512.c b/kernels/zen4/1f/bli_dotxf_zen_int_avx512.c new file mode 100644 index 0000000000..fe573eb449 --- /dev/null +++ b/kernels/zen4/1f/bli_dotxf_zen_int_avx512.c @@ -0,0 +1,1940 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "immintrin.h" +#include "blis.h" + +#if defined __clang__ + #define UNROLL_LOOP_FULL() _Pragma("clang loop unroll(full)") +#elif defined __GNUC__ + #define UNROLL_LOOP_FULL() _Pragma("GCC unroll 8") +#else + #define UNROLL_LOOP_FULL() +#endif + +void bli_ddotxf_zen_int_avx512 + ( + conj_t conjat, + conj_t conjx, + dim_t m, + dim_t b_n, + double* restrict alpha, + double* restrict a_, inc_t inca, inc_t lda, + double* restrict x_, inc_t incx, + double* restrict beta, + double* restrict y_, inc_t incy, + cntx_t* restrict cntx + ) +{ + const dim_t fuse_fac = 8; + const dim_t n_elem_per_reg = 8; + double* a = a_; + double* x = x_; + double* y = y_; + + + // If the b_n dimension is zero, y is empty and there is no computation. + if (bli_zero_dim1(b_n)) + return; + + // If the m dimension is zero, or if alpha is zero, the computation + // simplifies to updating y. + if (bli_zero_dim1(m) || PASTEMAC(d, eq0)(*alpha)) + { + bli_dscalv_zen_int_avx512( + BLIS_NO_CONJUGATE, + b_n, + beta, + y, incy, + cntx); + return; + } + + /* + If b_n is not equal to the fusing factor, then perform the entire + operation as dotxv or perform the operation using dotxf kernels with + lower fuse factor. + */ + if (b_n < fuse_fac) + { + if (b_n >= 4) + { + dim_t fuse = 4; + + bli_ddotxf_zen_int_4 + ( + conjat, + conjx, + m, + fuse, + alpha, + a, inca, lda, + x, incx, + beta, + y, incy, + cntx + ); + + // Increment the pointers + a = a + (fuse)*lda; + y = y + (fuse)*incy; + + // Decrement to point to the remaining compute left + b_n -= 4; + } + + if (b_n >= 2) + { + dim_t fuse = 2; + + bli_ddotxf_zen_int_2 + ( + conjat, + conjx, + m, + fuse, + alpha, + a, inca, lda, + x, incx, + beta, + y, incy, + cntx + ); + + // Increment the pointers + a = a + (fuse)*lda; + y = y + (fuse)*incy; + + b_n -= 2; + } + + if (b_n == 1) + { + double *a1 = a; + double *x1 = x; + double *psi1 = y; + + bli_ddotxv_zen_int( + conjat, + conjx, + m, + alpha, + a1, inca, + x1, incx, + beta, + psi1, + cntx); + } + return; + } + else if (b_n > fuse_fac) + { + for (dim_t i = 0; i < b_n; ++i) + { + double *a1 = a + (0) * inca + (i)*lda; + double *x1 = x + (0) * incx; + double *psi1 = y + (i)*incy; + + bli_ddotxv_zen_int( + conjat, + conjx, + m, + alpha, + a1, inca, + x1, incx, + beta, + psi1, + cntx); + } + return; + } + + // At this point, we know that b_n is exactly equal to the fusing factor. + // However, m may not be a multiple of the number of elements per vector. + + // Going forward, we handle two possible storage formats of A explicitly: + // (1) A is stored by columns, or (2) A is stored by rows. Either case is + // further split into two subproblems along the m dimension: + // (a) a vectorized part, starting at m = 0 and ending at any 0 <= m' <= m. + // (b) a scalar part, starting at m' and ending at m. If no vectorization + // is possible then m' == 0 and thus the scalar part is the entire + // problem. If 0 < m', then the a and x pointers and m variable will + // be adjusted accordingly for the second subproblem. + // Note: since parts (b) for both (1) and (2) are so similar, they are + // factored out into one code block after the following conditional, which + // distinguishes between (1) and (2). + + + __m512d yv; + __m512d rho[8]; + double *restrict av[8]; + __m512d xv; + rho[0] = _mm512_setzero_pd(); + + if ( inca == 1 && incx == 1 ) + { + __m512d a_vec[8]; + dim_t m_iter = m / ( n_elem_per_reg ); + + UNROLL_LOOP_FULL() + for (dim_t ii = 0; ii < 8; ++ii) + { + rho[ii] = _mm512_setzero_pd(); + av[ii] = a + ii * lda; + } + + for(dim_t i = 0; i < m_iter; ++i) + { + xv = _mm512_loadu_pd( x ); + + UNROLL_LOOP_FULL() + for (dim_t ii = 0; ii < 8; ++ii) + { + a_vec[ii] = _mm512_loadu_pd( av[ii] ); + av[ii] += n_elem_per_reg; + rho[ii] = _mm512_fmadd_pd(a_vec[ii], xv, rho[ii]); + } + x += n_elem_per_reg; + } + UNROLL_LOOP_FULL() + for (dim_t ii = 0; ii < 8; ++ii) + { + rho[0][ii] = _mm512_reduce_add_pd(rho[ii]); + } + m -= n_elem_per_reg * m_iter; + a += n_elem_per_reg * m_iter; + } + + // Initialize pointers for x and the b_n columns of A (rows of A^T). + double *restrict x0 = x; + + if( m > 0) + { + UNROLL_LOOP_FULL() + for (dim_t ii = 0; ii < 8; ++ii) + { + av[ii] = a + ii * lda; + } + } + + // If there are leftover iterations, perform them with scalar code. + for (dim_t i = 0; i < m; ++i) + { + const double x0c = *x0; + + rho[0][0] += (*av[0]) * x0c; + rho[0][1] += (*av[1]) * x0c; + rho[0][2] += (*av[2]) * x0c; + rho[0][3] += (*av[3]) * x0c; + rho[0][4] += (*av[4]) * x0c; + rho[0][5] += (*av[5]) * x0c; + rho[0][6] += (*av[6]) * x0c; + rho[0][7] += (*av[7]) * x0c; + + x0 += incx; + av[0] += inca; + av[1] += inca; + av[2] += inca; + av[3] += inca; + av[4] += inca; + av[5] += inca; + av[6] += inca; + av[7] += inca; + } + + // Broadcast the alpha scalar. + __m512d alphav = _mm512_set1_pd( *alpha ); + + // We know at this point that alpha is nonzero; however, beta may still + // be zero. If beta is indeed zero, we must overwrite y rather than scale + // by beta (in case y contains NaN or Inf). + if (PASTEMAC(d, eq0)(*beta)) + yv = _mm512_mul_pd(alphav, rho[0]); + else + { + // Broadcast the beta scalar + __m512d betav = _mm512_set1_pd(*beta); + + // Load y. + if( incy == 1 ) + { + yv = _mm512_loadu_pd( y ); + } + else + { + UNROLL_LOOP_FULL() + for (dim_t ii = 0; ii < 8; ++ii) + { + yv[ii] = *(y + ii * incy); + } + } + // Apply beta to y and alpha to the accumulated dot product in rho: + // y := beta * y + alpha * rho + yv = _mm512_mul_pd(betav, yv); + yv = _mm512_fmadd_pd(alphav, rho[0], yv); + } + + // Store the output. + if (incy == 1) + { + _mm512_storeu_pd(y, yv); + } + else + { + UNROLL_LOOP_FULL() + for (dim_t ii = 0; ii < 8; ++ii) + { + *(y + ii * incy) = yv[ii]; + } + } + +} + + + +/* Union data structure to access AVX-512 registers +* One 512-bit AVX register holds 8 DP elements. */ +typedef union +{ + __m512d v; + double d[8] __attribute__((aligned(64))); +} v8df_t; + +/* Union data structure to access AVX registers +* One 256-bit AVX register holds 4 DP elements. */ +typedef union +{ + __m256d v; + double d[4] __attribute__((aligned(64))); +} v4df_t; + +/* Union data structure to access AVX registers +* One 128-bit AVX register holds 2 DP elements. */ +typedef union +{ + __m128d v; + double d[2] __attribute__((aligned(64))); +} v2df_t; + +void bli_zdotxf_zen_int_2_avx512 + ( + conj_t conjat, + conj_t conjx, + dim_t m, + dim_t b_n, + dcomplex* restrict alpha, + dcomplex* restrict a, inc_t inca, inc_t lda, + dcomplex* restrict x, inc_t incx, + dcomplex* restrict beta, + dcomplex* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + /* If the vectors are empty or if alpha is zero, return early */ + if ( bli_zero_dim1( m ) || PASTEMAC(z,eq0)( *alpha ) ) + { + bli_zscalv_zen_int + ( + BLIS_NO_CONJUGATE, + b_n, + beta, + y, incy, + cntx + ); + + return; + } + + // If b_n is not equal to the fusing factor(2), then perform the entire + // operation with a dotxv kernel call. + if ( b_n != 2 ) + { + dcomplex* restrict a1 = a; + dcomplex* restrict x1 = x; + dcomplex* restrict psi1 = y; + + bli_zdotxv_zen_int_avx512 + ( + conjat, + conjx, + m, + alpha, + a1, inca, + x1, incx, + beta, + psi1, + cntx + ); + + return; + } + + // Declaring and initializing the iterator and pointers + dim_t i = 0; + + double *restrict av[2]; + double *restrict x_temp = (double *)(x); + + av[0] = (double *)(a + 0 * lda); + av[1] = (double *)(a + 1 * lda); + + // Local memory to store the dot-products + dcomplex res[2] __attribute__((aligned(64))); + res[0] = res[1] = (*bli_z0); + + // Performing XOR of conjx and conjat. + // conj_op is set if either X or A has conjugate(not both) + conj_t conj_op = conjx ^ conjat; + + // Computation for unit-strided case + if (incx == 1 && inca == 1) + { + // Declaring 4 registers, to store partial sums over multiple loads + // Further declaring 2 registers for load, 2 for broadcast(real and imag) + v8df_t rhov[4], a_vec[2], xv[2]; + + // Clearing the partial-sum accumulators + rhov[0].v = _mm512_setzero_pd(); + rhov[1].v = _mm512_setzero_pd(); + rhov[2].v = _mm512_setzero_pd(); + rhov[3].v = _mm512_setzero_pd(); + + for (; (i + 3) < m; i += 4) + { + // Load 4 elements from X + xv[0].v = _mm512_loadu_pd(x_temp); + + // Permute to duplicate the imag part for every element + // xv[1].v = I0 I0 I1 I1 ... + xv[1].v = _mm512_permute_pd(xv[0].v, 0xFF); + + // Permute to duplicate the real part for every element + // xv[0].v = R0 R0 R1 R1 ... + xv[0].v = _mm512_permute_pd(xv[0].v, 0x00); + + // Load 4 elements from first 4 columns of A + a_vec[0].v = _mm512_loadu_pd(av[0]); + a_vec[1].v = _mm512_loadu_pd(av[1]); + + // Perform: rhov[i].v += a_vec[i].v * xv[0]; + // rhov[i + 8].v += a_vec[i].v * xv[1]; + // This stores the partial sums due to real and + // imag components separately + rhov[0].v = _mm512_fmadd_pd(a_vec[0].v, xv[0].v, rhov[0].v); + rhov[2].v = _mm512_fmadd_pd(a_vec[0].v, xv[1].v, rhov[2].v); + + rhov[1].v = _mm512_fmadd_pd(a_vec[1].v, xv[0].v, rhov[1].v); + rhov[3].v = _mm512_fmadd_pd(a_vec[1].v, xv[1].v, rhov[3].v); + + // Adjust the pointers accordingly + av[0] += 8; + av[1] += 8; + + x_temp += 8; + } + if (i < m) + { + // Setting the mask bit based on remaining elements + // Since each dcomplex elements corresponds to 2 doubles + // we need to load and store 2*(m-i) elements. + __mmask8 m_mask = (1 << 2*(m - i)) - 1; + + // Load remaining elements from X + // Maskz_load is used to ensure the unloaded elements are 0 + // Else, it affects the accumulation and final reduction + xv[0].v = _mm512_maskz_loadu_pd(m_mask, x_temp); + + // Permute to duplicate the imag part for every element + // xv[1].v = I0 I0 I1 I1 ... + xv[1].v = _mm512_permute_pd(xv[0].v, 0xFF); + + // Permute to duplicate the real part for every element + // xv[0].v = R0 R0 R1 R1 ... + xv[0].v = _mm512_permute_pd(xv[0].v, 0x00); + + // Load remaining elements from first 4 columns of A + // Maskz_load is used to ensure the unloaded elements are 0 + // Else, it affects the accumulation and final reduction + a_vec[0].v = _mm512_maskz_loadu_pd(m_mask, av[0]); + a_vec[1].v = _mm512_maskz_loadu_pd(m_mask, av[1]); + + // Perform: rhov[i].v += a_vec[i].v * xv[0]; + // rhov[i + 8].v += a_vec[i].v * xv[1]; + // This stores the partial sums due to real and + // imag components separately + rhov[0].v = _mm512_fmadd_pd(a_vec[0].v, xv[0].v, rhov[0].v); + rhov[2].v = _mm512_fmadd_pd(a_vec[0].v, xv[1].v, rhov[2].v); + + rhov[1].v = _mm512_fmadd_pd(a_vec[1].v, xv[0].v, rhov[1].v); + rhov[3].v = _mm512_fmadd_pd(a_vec[1].v, xv[1].v, rhov[3].v); + } + + // Permuting for final accumulation of real and imag parts + rhov[2].v = _mm512_permute_pd(rhov[2].v, 0x55); + rhov[3].v = _mm512_permute_pd(rhov[3].v, 0x55); + + v8df_t scale_one; + v4df_t zero_reg; + + zero_reg.v = _mm256_setzero_pd(); + scale_one.v = _mm512_set1_pd(1.0); + + /* + conj_op maps to the compute as follows : + A = (a + ib), X = (x + iy) + ----------------------------------------------------------- + | A | X | Real part | Imag Part | + ----------------------------------------------------------- + | No-Conjugate | No-Conjugate | ax - by | bx + ay | + | No-Conjugate | Conjugate | ax + by | bx - ay | + | Conjugate | No-Conjugate | ax + by | -(bx - ay) | + | Conjugate | Conjugate | ax - by | -(bx + ay) | + ----------------------------------------------------------- + + If only X or A has conjugate, fmsubadd is performed. + Else, fmaddsub is performed. + + In the final reduction step, the imaginary part of every + partial sum is negated if conjat is conjugate + */ + if ( bli_is_noconj( conj_op ) ) + { + rhov[0].v = _mm512_fmaddsub_pd(scale_one.v, rhov[0].v, rhov[2].v); + rhov[1].v = _mm512_fmaddsub_pd(scale_one.v, rhov[1].v, rhov[3].v); + } + else + { + rhov[0].v = _mm512_fmsubadd_pd(scale_one.v, rhov[0].v, rhov[2].v); + rhov[1].v = _mm512_fmsubadd_pd(scale_one.v, rhov[1].v, rhov[3].v); + } + + // rhov[0 ... 1] will have the element wise product. + // These have to be added horizontally(reduction) to get the + // final result for every element in y. + // If rhov[0] = R0 I0 R1 I1 R2 I2 R3 I3 + // Then rhov[2] = R1 I1 R0 I0 R3 I2 R2 I2 + rhov[2].v = _mm512_permutex_pd(rhov[0].v, 0x4E); + rhov[3].v = _mm512_permutex_pd(rhov[1].v, 0x4E); + + // rhov[0] = (R0 + R1) (I0 + I1) (R1 + R0) (I1 + I0) + // (R2 + R3) (I2 + I3) (R3 + R2) (I3 + I2) + rhov[0].v = _mm512_add_pd(rhov[0].v, rhov[2].v); + rhov[1].v = _mm512_add_pd(rhov[1].v, rhov[3].v); + + // 256-bit registers declared to extract 256-bit lanes + v4df_t reduce_sum[4]; + + // reduce_sum[0] = (R0 + R1) (I0 + I1) (R1 + R0) (I1 + I0) + reduce_sum[0].v = _mm512_extractf64x4_pd(rhov[0].v, 0x00); + reduce_sum[1].v = _mm512_extractf64x4_pd(rhov[1].v, 0x00); + + // reduce_sum[2] = (R2 + R3) (I2 + I3) (R3 + R2) (I3 + I2) + reduce_sum[2].v = _mm512_extractf64x4_pd(rhov[0].v, 0x1); + reduce_sum[3].v = _mm512_extractf64x4_pd(rhov[1].v, 0x1); + + // reduce_sum[0] = (R0 + R1 + R2 + R3) (I0 + I1 + I2 + I3) ... + reduce_sum[0].v = _mm256_add_pd(reduce_sum[0].v, reduce_sum[2].v); + reduce_sum[1].v = _mm256_add_pd(reduce_sum[1].v, reduce_sum[3].v); + + // The next set of shuffles and permutes are performed to store + // all the dot-products onto one 256-bit register. This is used to + // perform aligned stores onto the stack memory. + reduce_sum[2].v = _mm256_shuffle_pd(reduce_sum[0].v, reduce_sum[1].v, 0xC); + + reduce_sum[3].v = _mm256_permutex_pd(reduce_sum[2].v, 0xD8); + + // Negate the sign bit of imaginary part of dot-products if conjat is conjugate + if ( bli_is_conj( conjat ) ) + { + reduce_sum[3].v = _mm256_fmsubadd_pd(zero_reg.v, zero_reg.v, reduce_sum[3].v); + } + + /* + Computed dot product result is being stored + in temp buffer r for further computation. + */ + _mm256_store_pd((double *)res, reduce_sum[3].v); + } + + // This section will have the whole of compute when incx != 1 || inca != 1 + else + { + // Declaring 128-bit registers, for element by element computation + v2df_t rhov[4], a_vec[2], xv[2]; + + // Clearing the partial-sum accumulators + rhov[0].v = _mm_setzero_pd(); + rhov[1].v = _mm_setzero_pd(); + rhov[2].v = _mm_setzero_pd(); + rhov[3].v = _mm_setzero_pd(); + + for (dim_t i = 0; i < m; i++) + { + // Load from X + xv[0].v = _mm_loadu_pd(x_temp); + + // Permute to duplicate the imag part for every element + xv[1].v = _mm_permute_pd(xv[0].v, 0b11); + + // Permute to duplicate the real part for every element + xv[0].v = _mm_permute_pd(xv[0].v, 0b00); + + // Load elements from first 4 columns of A + a_vec[0].v = _mm_loadu_pd(av[0]); + a_vec[1].v = _mm_loadu_pd(av[1]); + + // Perform: rhov[i].v += a_vec[i].v * xv[0]; + // rhov[i + 8].v += a_vec[i].v * xv[1]; + // This stores the partial sums due to real and + // imag components separately + rhov[0].v = _mm_fmadd_pd(a_vec[0].v, xv[0].v, rhov[0].v); + rhov[2].v = _mm_fmadd_pd(a_vec[0].v, xv[1].v, rhov[2].v); + + rhov[1].v = _mm_fmadd_pd(a_vec[1].v, xv[0].v, rhov[1].v); + rhov[3].v = _mm_fmadd_pd(a_vec[1].v, xv[1].v, rhov[3].v); + + av[0] += 2 * inca; + av[1] += 2 * inca; + + x_temp += 2 * incx; + } + + // Permuting to help with final reduction + rhov[3].v = _mm_permute_pd(rhov[3].v, 0b01); + rhov[2].v = _mm_permute_pd(rhov[2].v, 0b01); + + v2df_t zero_reg, scale_one; + + zero_reg.v = _mm_setzero_pd(); + scale_one.v = _mm_set1_pd(1.0); + + if ( bli_is_noconj( conj_op ) ) + { + rhov[0].v = _mm_addsub_pd(rhov[0].v, rhov[2].v); + rhov[1].v = _mm_addsub_pd(rhov[1].v, rhov[3].v); + } + else + { + rhov[0].v = _mm_fmsubadd_pd(scale_one.v, rhov[0].v, rhov[2].v); + rhov[1].v = _mm_fmsubadd_pd(scale_one.v, rhov[1].v, rhov[3].v); + } + if( bli_is_conj( conjat ) ) + { + rhov[0].v = _mm_fmsubadd_pd(zero_reg.v, zero_reg.v, rhov[0].v); + rhov[1].v = _mm_fmsubadd_pd(zero_reg.v, zero_reg.v, rhov[1].v); + } + + // Storing onto static memory, to be used later + _mm_storeu_pd((double *)res, rhov[0].v); + _mm_storeu_pd((double *)(res + 1), rhov[1].v); + + } + + // Scaling by alpha + // Registers to load partial sums, stored in static memory + v4df_t rhov, temp; + + rhov.v = _mm256_load_pd((double *)res); + + if ( !bli_zeq1( *alpha ) ) + { + __m256d alphaRv, alphaIv; + alphaRv = _mm256_set1_pd((*alpha).real); + alphaIv = _mm256_set1_pd((*alpha).imag); + + temp.v = _mm256_permute_pd(rhov.v, 0x5); + + // Scaling with imag part of alpha + temp.v = _mm256_mul_pd(temp.v, alphaIv); + + // Scaling with real part of alpha, and addsub + rhov.v = _mm256_fmaddsub_pd(rhov.v, alphaRv, temp.v); + } + // When 'beta' is not zero we need to multiply scale 'y' by 'beta' + v4df_t yv; + + yv.v = _mm256_setzero_pd(); + + if (!PASTEMAC(z, eq0)(*beta)) + { + __m256d betaRv, betaIv; + + betaRv = _mm256_set1_pd((*beta).real); + betaIv = _mm256_set1_pd((*beta).imag); + + if (incy == 1) + { + yv.v = _mm256_loadu_pd((double *)(y)); + } + else + { + /* + This can be done using SSE instructions + but has been kept as scalar code to avoid + mixing SSE with AVX + */ + yv.d[0] = (*(y + 0 * incy)).real; + yv.d[1] = (*(y + 0 * incy)).imag; + yv.d[2] = (*(y + 1 * incy)).real; + yv.d[3] = (*(y + 1 * incy)).imag; + + } + + temp.v = _mm256_permute_pd(yv.v, 0x5); + + // Scaling with imag part of alpha + temp.v = _mm256_mul_pd(temp.v, betaIv); + + // Scaling with real part of alpha, and addsub + yv.v = _mm256_fmaddsub_pd(yv.v, betaRv, temp.v); + } + + // Adding alpha*A*x to beta*Y + yv.v = _mm256_add_pd(yv.v, rhov.v); + + if (incy == 1) + { + _mm256_storeu_pd((double *)y, yv.v); + } + else + { + (*(y + 0 * incy)).real = yv.d[0]; + (*(y + 0 * incy)).imag = yv.d[1]; + (*(y + 1 * incy)).real = yv.d[2]; + (*(y + 1 * incy)).imag = yv.d[3]; + + } + +} + +void bli_zdotxf_zen_int_4_avx512 + ( + conj_t conjat, + conj_t conjx, + dim_t m, + dim_t b_n, + dcomplex* restrict alpha, + dcomplex* restrict a, inc_t inca, inc_t lda, + dcomplex* restrict x, inc_t incx, + dcomplex* restrict beta, + dcomplex* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + /* If the vectors are empty or if alpha is zero, return early */ + if ( bli_zero_dim1( m ) || PASTEMAC(z,eq0)( *alpha ) ) + { + bli_zscalv_zen_int + ( + BLIS_NO_CONJUGATE, + b_n, + beta, + y, incy, + cntx + ); + + return; + } + + // If b_n is not equal to the fusing factor(4), then perform the entire + // operation as a sequence of fringe dotxf kernel(2) and dotxv + // kernel as per the requirement. + if ( b_n != 4 ) + { + dcomplex* restrict a1 = a; + dcomplex* restrict x1 = x; + dcomplex* restrict psi1 = y; + + if( b_n >= 2 ) + { + bli_zdotxf_zen_int_2_avx512 + ( + conjat, + conjx, + m, + (dim_t)2, + alpha, + a1, inca, lda, + x1, incx, + beta, + psi1, incy, + NULL + ); + + a1 += 2*lda; + psi1 += 2*incy; + + b_n -= 2; + } + + if( b_n == 1 ) + { + bli_zdotxv_zen_int_avx512 + ( + conjat, + conjx, + m, + alpha, + a1, inca, + x1, incx, + beta, + psi1, + cntx + ); + } + + return; + } + + // Declaring and initializing the iterator and pointers + dim_t i = 0; + + double *restrict av[4]; + double *restrict x_temp = (double *)(x); + + av[0] = (double *)(a + 0 * lda); + av[1] = (double *)(a + 1 * lda); + av[2] = (double *)(a + 2 * lda); + av[3] = (double *)(a + 3 * lda); + + // Local memory to store the dot-products + dcomplex res[4] __attribute__((aligned(64))); + res[0] = res[1] = res[2] = res[3] = (*bli_z0); + + // Performing XOR of conjx and conjat. + // conj_op is set if either X or A has conjugate(not both) + conj_t conj_op = conjx ^ conjat; + + // Computation for unit-strided case + if (incx == 1 && inca == 1) + { + // Declaring 8 registers, to store partial sums over multiple loads + // Further declaring 4 registers for load, 2 for broadcast(real and imag) + v8df_t rhov[8], a_vec[4], xv[2]; + + // Clearing the partial-sum accumulators + rhov[0].v = _mm512_setzero_pd(); + rhov[1].v = _mm512_setzero_pd(); + rhov[2].v = _mm512_setzero_pd(); + rhov[3].v = _mm512_setzero_pd(); + rhov[4].v = _mm512_setzero_pd(); + rhov[5].v = _mm512_setzero_pd(); + rhov[6].v = _mm512_setzero_pd(); + rhov[7].v = _mm512_setzero_pd(); + + for (; (i + 3) < m; i += 4) + { + // Load 4 elements from X + xv[0].v = _mm512_loadu_pd(x_temp); + + // Permute to duplicate the imag part for every element + // xv[1].v = I0 I0 I1 I1 ... + xv[1].v = _mm512_permute_pd(xv[0].v, 0xFF); + + // Permute to duplicate the real part for every element + // xv[0].v = R0 R0 R1 R1 ... + xv[0].v = _mm512_permute_pd(xv[0].v, 0x00); + + // Load 4 elements from first 4 columns of A + a_vec[0].v = _mm512_loadu_pd(av[0]); + a_vec[1].v = _mm512_loadu_pd(av[1]); + a_vec[2].v = _mm512_loadu_pd(av[2]); + a_vec[3].v = _mm512_loadu_pd(av[3]); + + // Perform: rhov[i].v += a_vec[i].v * xv[0]; + // rhov[i + 8].v += a_vec[i].v * xv[1]; + // This stores the partial sums due to real and + // imag components separately + rhov[0].v = _mm512_fmadd_pd(a_vec[0].v, xv[0].v, rhov[0].v); + rhov[4].v = _mm512_fmadd_pd(a_vec[0].v, xv[1].v, rhov[4].v); + + rhov[1].v = _mm512_fmadd_pd(a_vec[1].v, xv[0].v, rhov[1].v); + rhov[5].v = _mm512_fmadd_pd(a_vec[1].v, xv[1].v, rhov[5].v); + + rhov[2].v = _mm512_fmadd_pd(a_vec[2].v, xv[0].v, rhov[2].v); + rhov[6].v = _mm512_fmadd_pd(a_vec[2].v, xv[1].v, rhov[6].v); + + rhov[3].v = _mm512_fmadd_pd(a_vec[3].v, xv[0].v, rhov[3].v); + rhov[7].v = _mm512_fmadd_pd(a_vec[3].v, xv[1].v, rhov[7].v); + + // Adjust the pointers accordingly + av[0] += 8; + av[1] += 8; + av[2] += 8; + av[3] += 8; + + x_temp += 8; + } + if (i < m) + { + // Setting the mask bit based on remaining elements + // Since each dcomplex elements corresponds to 2 doubles + // we need to load and store 2*(m-i) elements. + __mmask8 m_mask = (1 << 2*(m - i)) - 1; + + // Load remaining elements from X + // Maskz_load is used to ensure the unloaded elements are 0 + // Else, it affects the accumulation and final reduction + xv[0].v = _mm512_maskz_loadu_pd(m_mask, x_temp); + + // Permute to duplicate the imag part for every element + // xv[1].v = I0 I0 I1 I1 ... + xv[1].v = _mm512_permute_pd(xv[0].v, 0xFF); + + // Permute to duplicate the real part for every element + // xv[0].v = R0 R0 R1 R1 ... + xv[0].v = _mm512_permute_pd(xv[0].v, 0x00); + + // Load remaining elements from first 4 columns of A + // Maskz_load is used to ensure the unloaded elements are 0 + // Else, it affects the accumulation and final reduction + a_vec[0].v = _mm512_maskz_loadu_pd(m_mask, av[0]); + a_vec[1].v = _mm512_maskz_loadu_pd(m_mask, av[1]); + a_vec[2].v = _mm512_maskz_loadu_pd(m_mask, av[2]); + a_vec[3].v = _mm512_maskz_loadu_pd(m_mask, av[3]); + + // Perform: rhov[i].v += a_vec[i].v * xv[0]; + // rhov[i + 8].v += a_vec[i].v * xv[1]; + // This stores the partial sums due to real and + // imag components separately + rhov[0].v = _mm512_fmadd_pd(a_vec[0].v, xv[0].v, rhov[0].v); + rhov[4].v = _mm512_fmadd_pd(a_vec[0].v, xv[1].v, rhov[4].v); + + rhov[1].v = _mm512_fmadd_pd(a_vec[1].v, xv[0].v, rhov[1].v); + rhov[5].v = _mm512_fmadd_pd(a_vec[1].v, xv[1].v, rhov[5].v); + + rhov[2].v = _mm512_fmadd_pd(a_vec[2].v, xv[0].v, rhov[2].v); + rhov[6].v = _mm512_fmadd_pd(a_vec[2].v, xv[1].v, rhov[6].v); + + rhov[3].v = _mm512_fmadd_pd(a_vec[3].v, xv[0].v, rhov[3].v); + rhov[7].v = _mm512_fmadd_pd(a_vec[3].v, xv[1].v, rhov[7].v); + } + + // Permuting for final accumulation of real and imag parts + rhov[4].v = _mm512_permute_pd(rhov[4].v, 0x55); + rhov[5].v = _mm512_permute_pd(rhov[5].v, 0x55); + rhov[6].v = _mm512_permute_pd(rhov[6].v, 0x55); + rhov[7].v = _mm512_permute_pd(rhov[7].v, 0x55); + + // Setting 2 registers to 0 and 1 + v8df_t zero_reg, scale_one; + + zero_reg.v = _mm512_setzero_pd(); + scale_one.v = _mm512_set1_pd(1.0); + + /* + conj_op maps to the compute as follows : + A = (a + ib), X = (x + iy) + ----------------------------------------------------------- + | A | X | Real part | Imag Part | + ----------------------------------------------------------- + | No-Conjugate | No-Conjugate | ax - by | bx + ay | + | No-Conjugate | Conjugate | ax + by | bx - ay | + | Conjugate | No-Conjugate | ax + by | -(bx - ay) | + | Conjugate | Conjugate | ax - by | -(bx + ay) | + ----------------------------------------------------------- + + If only X or A has conjugate, fmsubadd is performed. + Else, fmaddsub is performed. + + In the final reduction step, the imaginary part of every + partial sum is negated if conjat is conjugate + */ + + if ( bli_is_noconj( conj_op ) ) + { + rhov[0].v = _mm512_fmaddsub_pd(scale_one.v, rhov[0].v, rhov[4].v); + rhov[1].v = _mm512_fmaddsub_pd(scale_one.v, rhov[1].v, rhov[5].v); + rhov[2].v = _mm512_fmaddsub_pd(scale_one.v, rhov[2].v, rhov[6].v); + rhov[3].v = _mm512_fmaddsub_pd(scale_one.v, rhov[3].v, rhov[7].v); + } + else + { + rhov[0].v = _mm512_fmsubadd_pd(scale_one.v, rhov[0].v, rhov[4].v); + rhov[1].v = _mm512_fmsubadd_pd(scale_one.v, rhov[1].v, rhov[5].v); + rhov[2].v = _mm512_fmsubadd_pd(scale_one.v, rhov[2].v, rhov[6].v); + rhov[3].v = _mm512_fmsubadd_pd(scale_one.v, rhov[3].v, rhov[7].v); + } + + // rhov[0 ... 3] will have the element wise product. + // These have to be added horizontally(reduction) to get the + // final result for every element in y. + // If rhov[0] = R0 I0 R1 I1 R2 I2 R3 I3 + // Then rhov[4] = R1 I1 R0 I0 R3 I2 R2 I2 + rhov[4].v = _mm512_permutex_pd(rhov[0].v, 0x4E); + rhov[5].v = _mm512_permutex_pd(rhov[1].v, 0x4E); + rhov[6].v = _mm512_permutex_pd(rhov[2].v, 0x4E); + rhov[7].v = _mm512_permutex_pd(rhov[3].v, 0x4E); + + // rhov[0] = (R0 + R1) (I0 + I1) (R1 + R0) (I1 + I0) + // (R2 + R3) (I2 + I3) (R3 + R2) (I3 + I2) + rhov[0].v = _mm512_add_pd(rhov[0].v, rhov[4].v); + rhov[1].v = _mm512_add_pd(rhov[1].v, rhov[5].v); + rhov[2].v = _mm512_add_pd(rhov[2].v, rhov[6].v); + rhov[3].v = _mm512_add_pd(rhov[3].v, rhov[7].v); + + // 256-bit registers declared to extract 256-bit lanes + v4df_t reduce_sum[8]; + + // reduce_sum[0] = (R0 + R1) (I0 + I1) (R1 + R0) (I1 + I0) + reduce_sum[0].v = _mm512_extractf64x4_pd(rhov[0].v, 0x00); + reduce_sum[1].v = _mm512_extractf64x4_pd(rhov[1].v, 0x00); + reduce_sum[2].v = _mm512_extractf64x4_pd(rhov[2].v, 0x00); + reduce_sum[3].v = _mm512_extractf64x4_pd(rhov[3].v, 0x00); + + // reduce_sum[4] = (R2 + R3) (I2 + I3) (R3 + R2) (I3 + I2) + reduce_sum[4].v = _mm512_extractf64x4_pd(rhov[0].v, 0x1); + reduce_sum[5].v = _mm512_extractf64x4_pd(rhov[1].v, 0x1); + reduce_sum[6].v = _mm512_extractf64x4_pd(rhov[2].v, 0x1); + reduce_sum[7].v = _mm512_extractf64x4_pd(rhov[3].v, 0x1); + + // reduce_sum[0] = (R0 + R1 + R2 + R3) (I0 + I1 + I2 + I3) ... + reduce_sum[0].v = _mm256_add_pd(reduce_sum[0].v, reduce_sum[4].v); + reduce_sum[1].v = _mm256_add_pd(reduce_sum[1].v, reduce_sum[5].v); + reduce_sum[2].v = _mm256_add_pd(reduce_sum[2].v, reduce_sum[6].v); + reduce_sum[3].v = _mm256_add_pd(reduce_sum[3].v, reduce_sum[7].v); + + // The next set of shuffles, permutes and inserts are performed to store + // all the dot-products onto one 512-bit register. This is used to perform + // aligned stores onto the stack memory. + reduce_sum[4].v = _mm256_shuffle_pd(reduce_sum[0].v, reduce_sum[1].v, 0xC); + reduce_sum[5].v = _mm256_shuffle_pd(reduce_sum[2].v, reduce_sum[3].v, 0xC); + + reduce_sum[6].v = _mm256_permutex_pd(reduce_sum[4].v, 0xD8); + reduce_sum[7].v = _mm256_permutex_pd(reduce_sum[5].v, 0xD8); + + rhov[0].v = _mm512_insertf64x4(rhov[0].v, reduce_sum[6].v, 0x00); + rhov[0].v = _mm512_insertf64x4(rhov[0].v, reduce_sum[7].v, 0x01); + + // Negate the sign bit of imaginary part of dot-products if conjat is conjugate + if ( bli_is_conj( conjat ) ) + { + rhov[0].v = _mm512_fmsubadd_pd(zero_reg.v, zero_reg.v, rhov[0].v); + } + + /* + Computed dot product result is being stored + in temp buffer r for further computation. + */ + _mm512_store_pd((double *)res, rhov[0].v); + } + + // This section will have the whole of compute when incx != 1 || inca != 1 + else + { + // Declaring 128-bit registers, for element by element computation + v2df_t rhov[8], a_vec[4], xv[2]; + + // Clearing the partial-sum accumulators + rhov[0].v = _mm_setzero_pd(); + rhov[1].v = _mm_setzero_pd(); + rhov[2].v = _mm_setzero_pd(); + rhov[3].v = _mm_setzero_pd(); + rhov[4].v = _mm_setzero_pd(); + rhov[5].v = _mm_setzero_pd(); + rhov[6].v = _mm_setzero_pd(); + rhov[7].v = _mm_setzero_pd(); + + for (dim_t i = 0; i < m; i++) + { + // Load from X + xv[0].v = _mm_loadu_pd(x_temp); + + // Permute to duplicate the imag part for every element + xv[1].v = _mm_permute_pd(xv[0].v, 0b11); + + // Permute to duplicate the real part for every element + xv[0].v = _mm_permute_pd(xv[0].v, 0b00); + + // Load elements from first 4 columns of A + a_vec[0].v = _mm_loadu_pd(av[0]); + a_vec[1].v = _mm_loadu_pd(av[1]); + a_vec[2].v = _mm_loadu_pd(av[2]); + a_vec[3].v = _mm_loadu_pd(av[3]); + + // Perform: rhov[i].v += a_vec[i].v * xv[0]; + // rhov[i + 8].v += a_vec[i].v * xv[1]; + // This stores the partial sums due to real and + // imag components separately + rhov[0].v = _mm_fmadd_pd(a_vec[0].v, xv[0].v, rhov[0].v); + rhov[4].v = _mm_fmadd_pd(a_vec[0].v, xv[1].v, rhov[4].v); + + rhov[1].v = _mm_fmadd_pd(a_vec[1].v, xv[0].v, rhov[1].v); + rhov[5].v = _mm_fmadd_pd(a_vec[1].v, xv[1].v, rhov[5].v); + + rhov[2].v = _mm_fmadd_pd(a_vec[2].v, xv[0].v, rhov[2].v); + rhov[6].v = _mm_fmadd_pd(a_vec[2].v, xv[1].v, rhov[6].v); + + rhov[3].v = _mm_fmadd_pd(a_vec[3].v, xv[0].v, rhov[3].v); + rhov[7].v = _mm_fmadd_pd(a_vec[3].v, xv[1].v, rhov[7].v); + + av[0] += 2 * inca; + av[1] += 2 * inca; + av[2] += 2 * inca; + av[3] += 2 * inca; + + x_temp += 2 * incx; + } + + // Permuting to help with final reduction + rhov[4].v = _mm_permute_pd(rhov[4].v, 0b01); + rhov[5].v = _mm_permute_pd(rhov[5].v, 0b01); + rhov[6].v = _mm_permute_pd(rhov[6].v, 0b01); + rhov[7].v = _mm_permute_pd(rhov[7].v, 0b01); + + v2df_t zero_reg, scale_one; + + zero_reg.v = _mm_setzero_pd(); + scale_one.v = _mm_set1_pd(1.0); + + // Reduction based on conj_op + if ( bli_is_noconj( conj_op ) ) + { + rhov[0].v = _mm_addsub_pd(rhov[0].v, rhov[4].v); + rhov[1].v = _mm_addsub_pd(rhov[1].v, rhov[5].v); + rhov[2].v = _mm_addsub_pd(rhov[2].v, rhov[6].v); + rhov[3].v = _mm_addsub_pd(rhov[3].v, rhov[7].v); + } + else + { + rhov[0].v = _mm_fmsubadd_pd(scale_one.v, rhov[0].v, rhov[4].v); + rhov[1].v = _mm_fmsubadd_pd(scale_one.v, rhov[1].v, rhov[5].v); + rhov[2].v = _mm_fmsubadd_pd(scale_one.v, rhov[2].v, rhov[6].v); + rhov[3].v = _mm_fmsubadd_pd(scale_one.v, rhov[3].v, rhov[7].v); + } + if( bli_is_conj( conjat ) ) + { + rhov[0].v = _mm_fmsubadd_pd(zero_reg.v, zero_reg.v, rhov[0].v); + rhov[1].v = _mm_fmsubadd_pd(zero_reg.v, zero_reg.v, rhov[1].v); + rhov[2].v = _mm_fmsubadd_pd(zero_reg.v, zero_reg.v, rhov[2].v); + rhov[3].v = _mm_fmsubadd_pd(zero_reg.v, zero_reg.v, rhov[3].v); + } + + // Storing onto stack memory + _mm_storeu_pd((double *)res, rhov[0].v); + _mm_storeu_pd((double *)(res + 1), rhov[1].v); + _mm_storeu_pd((double *)(res + 2), rhov[2].v); + _mm_storeu_pd((double *)(res + 3), rhov[3].v); + + } + + // Scaling by alpha + // Registers to load partial sums, stored in static memory + v8df_t rhov, temp; + + rhov.v = _mm512_loadu_pd((double *)res); + + if ( !bli_zeq1( *alpha ) ) + { + __m512d alphaRv, alphaIv; + alphaRv = _mm512_set1_pd((*alpha).real); + alphaIv = _mm512_set1_pd((*alpha).imag); + + temp.v = _mm512_permute_pd(rhov.v, 0x55); + + // Scaling with imag part of alpha + temp.v = _mm512_mul_pd(temp.v, alphaIv); + + // Scaling with real part of alpha, and addsub + rhov.v = _mm512_fmaddsub_pd(rhov.v, alphaRv, temp.v); + } + // When 'beta' is not zero we need to multiply scale 'y' by 'beta' + v8df_t yv; + + yv.v = _mm512_setzero_pd(); + + if (!PASTEMAC(z, eq0)(*beta)) + { + __m512d betaRv, betaIv; + + betaRv = _mm512_set1_pd((*beta).real); + betaIv = _mm512_set1_pd((*beta).imag); + + if (incy == 1) + { + yv.v = _mm512_loadu_pd((double *)(y)); + } + else + { + /* + This can be done using SSE instructions + but has been kept as scalar code to avoid + mixing SSE with AVX + */ + yv.d[0] = (*(y + 0 * incy)).real; + yv.d[1] = (*(y + 0 * incy)).imag; + yv.d[2] = (*(y + 1 * incy)).real; + yv.d[3] = (*(y + 1 * incy)).imag; + yv.d[4] = (*(y + 2 * incy)).real; + yv.d[5] = (*(y + 2 * incy)).imag; + yv.d[6] = (*(y + 3 * incy)).real; + yv.d[7] = (*(y + 3 * incy)).imag; + + } + + temp.v = _mm512_permute_pd(yv.v, 0x55); + + // Scaling with imag part of alpha + temp.v = _mm512_mul_pd(temp.v, betaIv); + + // Scaling with real part of alpha, and addsub + yv.v = _mm512_fmaddsub_pd(yv.v, betaRv, temp.v); + } + + // Adding alpha*A*x to beta*Y + yv.v = _mm512_add_pd(yv.v, rhov.v); + + if (incy == 1) + { + _mm512_storeu_pd((double *)y, yv.v); + } + else + { + (*(y + 0 * incy)).real = yv.d[0]; + (*(y + 0 * incy)).imag = yv.d[1]; + (*(y + 1 * incy)).real = yv.d[2]; + (*(y + 1 * incy)).imag = yv.d[3]; + + (*(y + 2 * incy)).real = yv.d[4]; + (*(y + 2 * incy)).imag = yv.d[5]; + (*(y + 3 * incy)).real = yv.d[6]; + (*(y + 3 * incy)).imag = yv.d[7]; + + } + +} + +void bli_zdotxf_zen_int_8_avx512 + ( + conj_t conjat, + conj_t conjx, + dim_t m, + dim_t b_n, + dcomplex* restrict alpha, + dcomplex* restrict a, inc_t inca, inc_t lda, + dcomplex* restrict x, inc_t incx, + dcomplex* restrict beta, + dcomplex* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + /* If vectors are empty or if alpha is zero, scale y by beta and return */ + if ( bli_zero_dim1( m ) || PASTEMAC(z,eq0)( *alpha ) ) + { + bli_zscalv_zen_int + ( + BLIS_NO_CONJUGATE, + b_n, + beta, + y, incy, + cntx + ); + + return; + } + + // If b_n is not equal to the fusing factor(8), then perform the entire + // operation as a sequence of fringe dotxf kernels(4 and 2) and dotxv + // kernel as per the requirement. + if ( b_n != 8 ) + { + dcomplex* restrict a1 = a; + dcomplex* restrict x1 = x; + dcomplex* restrict psi1 = y; + + if( b_n >= 4 ) + { + bli_zdotxf_zen_int_4_avx512 + ( + conjat, + conjx, + m, + (dim_t)4, + alpha, + a1, inca, lda, + x1, incx, + beta, + psi1, incy, + NULL + ); + + a1 += 4*lda; + psi1 += 4*incy; + + b_n -= 4; + } + + if( b_n >= 2 ) + { + bli_zdotxf_zen_int_2_avx512 + ( + conjat, + conjx, + m, + (dim_t)2, + alpha, + a1, inca, lda, + x1, incx, + beta, + psi1, incy, + NULL + ); + + a1 += 2*lda; + psi1 += 2*incy; + + b_n -= 2; + } + + if( b_n == 1 ) + { + bli_zdotxv_zen_int_avx512 + ( + conjat, + conjx, + m, + alpha, + a1, inca, + x1, incx, + beta, + psi1, + cntx + ); + } + + return; + } + + // Declaring and initializing the iterator and pointers + dim_t i = 0; + + double *restrict av[8]; + double *restrict x_temp = (double *)(x); + + av[0] = (double *)(a + 0 * lda); + av[1] = (double *)(a + 1 * lda); + av[2] = (double *)(a + 2 * lda); + av[3] = (double *)(a + 3 * lda); + av[4] = (double *)(a + 4 * lda); + av[5] = (double *)(a + 5 * lda); + av[6] = (double *)(a + 6 * lda); + av[7] = (double *)(a + 7 * lda); + + // Local memory to store the dot-products + dcomplex res[8] __attribute__((aligned(64))); + res[0] = res[1] = res[2] = res[3] = res[4] = res[5] = res[6] = res[7] = (*bli_z0); + + // Performing XOR of conjx and conjat. + // conj_op is set if either X or A has conjugate(not both) + conj_t conj_op = conjx ^ conjat; + + // Computation for unit-strided case + if (incx == 1 && inca == 1) + { + // Declaring 16 registers, to store partial sums over multiple loads + // Further declaring 8 registers for load, 2 for broadcast(real and imag) + v8df_t rhov[16], a_vec[8], xv[2]; + + // Clearing the partial-sum accumulators + rhov[0].v = _mm512_setzero_pd(); + rhov[1].v = _mm512_setzero_pd(); + rhov[2].v = _mm512_setzero_pd(); + rhov[3].v = _mm512_setzero_pd(); + rhov[4].v = _mm512_setzero_pd(); + rhov[5].v = _mm512_setzero_pd(); + rhov[6].v = _mm512_setzero_pd(); + rhov[7].v = _mm512_setzero_pd(); + rhov[8].v = _mm512_setzero_pd(); + rhov[9].v = _mm512_setzero_pd(); + rhov[10].v = _mm512_setzero_pd(); + rhov[11].v = _mm512_setzero_pd(); + rhov[12].v = _mm512_setzero_pd(); + rhov[13].v = _mm512_setzero_pd(); + rhov[14].v = _mm512_setzero_pd(); + rhov[15].v = _mm512_setzero_pd(); + + for (; (i + 3) < m; i += 4) + { + // Load 4 elements from X + xv[0].v = _mm512_loadu_pd(x_temp); + + // Permute to duplicate the imag part for every element + // xv[1].v = I0 I0 I1 I1 ... + xv[1].v = _mm512_permute_pd(xv[0].v, 0xFF); + + // Permute to duplicate the real part for every element + // xv[0].v = R0 R0 R1 R1 ... + xv[0].v = _mm512_permute_pd(xv[0].v, 0x00); + + // Load 4 elements from first 4 columns of A + a_vec[0].v = _mm512_loadu_pd(av[0]); + a_vec[1].v = _mm512_loadu_pd(av[1]); + a_vec[2].v = _mm512_loadu_pd(av[2]); + a_vec[3].v = _mm512_loadu_pd(av[3]); + + // Perform: rhov[i].v += a_vec[i].v * xv[0]; + // rhov[i + 8].v += a_vec[i].v * xv[1]; + // This stores the partial sums due to real and + // imag components separately + rhov[0].v = _mm512_fmadd_pd(a_vec[0].v, xv[0].v, rhov[0].v); + rhov[8].v = _mm512_fmadd_pd(a_vec[0].v, xv[1].v, rhov[8].v); + + rhov[1].v = _mm512_fmadd_pd(a_vec[1].v, xv[0].v, rhov[1].v); + rhov[9].v = _mm512_fmadd_pd(a_vec[1].v, xv[1].v, rhov[9].v); + + rhov[2].v = _mm512_fmadd_pd(a_vec[2].v, xv[0].v, rhov[2].v); + rhov[10].v = _mm512_fmadd_pd(a_vec[2].v, xv[1].v, rhov[10].v); + + rhov[3].v = _mm512_fmadd_pd(a_vec[3].v, xv[0].v, rhov[3].v); + rhov[11].v = _mm512_fmadd_pd(a_vec[3].v, xv[1].v, rhov[11].v); + + // Load 4 elements from next 4 columns of A + a_vec[4].v = _mm512_loadu_pd(av[4]); + a_vec[5].v = _mm512_loadu_pd(av[5]); + a_vec[6].v = _mm512_loadu_pd(av[6]); + a_vec[7].v = _mm512_loadu_pd(av[7]); + + // Perform: rhov[i].v += a_vec[i].v * xv[0]; + // rhov[i + 8].v += a_vec[i].v * xv[1]; + // This stores the partial sums due to real and + // imag components separately + rhov[4].v = _mm512_fmadd_pd(a_vec[4].v, xv[0].v, rhov[4].v); + rhov[12].v = _mm512_fmadd_pd(a_vec[4].v, xv[1].v, rhov[12].v); + + rhov[5].v = _mm512_fmadd_pd(a_vec[5].v, xv[0].v, rhov[5].v); + rhov[13].v = _mm512_fmadd_pd(a_vec[5].v, xv[1].v, rhov[13].v); + + rhov[6].v = _mm512_fmadd_pd(a_vec[6].v, xv[0].v, rhov[6].v); + rhov[14].v = _mm512_fmadd_pd(a_vec[6].v, xv[1].v, rhov[14].v); + + rhov[7].v = _mm512_fmadd_pd(a_vec[7].v, xv[0].v, rhov[7].v); + rhov[15].v = _mm512_fmadd_pd(a_vec[7].v, xv[1].v, rhov[15].v); + + // Adjust the pointers accordingly + av[0] += 8; + av[1] += 8; + av[2] += 8; + av[3] += 8; + av[4] += 8; + av[5] += 8; + av[6] += 8; + av[7] += 8; + + x_temp += 8; + } + if (i < m) + { + // Setting the mask bit based on remaining elements + // Since each dcomplex elements corresponds to 2 doubles + // we need to load and store 2*(m-i) elements. + __mmask8 m_mask = (1 << 2*(m - i)) - 1; + + // Load remaining elements from X + // Maskz_load is used to ensure the unloaded elements are 0 + // Else, it affects the accumulation and final reduction + xv[0].v = _mm512_maskz_loadu_pd(m_mask, x_temp); + + // Permute to duplicate the imag part for every element + // xv[1].v = I0 I0 I1 I1 ... + xv[1].v = _mm512_permute_pd(xv[0].v, 0xFF); + + // Permute to duplicate the real part for every element + // xv[0].v = R0 R0 R1 R1 ... + xv[0].v = _mm512_permute_pd(xv[0].v, 0x00); + + // Load remaining elements from first 4 columns of A + // Maskz_load is used to ensure the unloaded elements are 0 + // Else, it affects the accumulation and final reduction + a_vec[0].v = _mm512_maskz_loadu_pd(m_mask, av[0]); + a_vec[1].v = _mm512_maskz_loadu_pd(m_mask, av[1]); + a_vec[2].v = _mm512_maskz_loadu_pd(m_mask, av[2]); + a_vec[3].v = _mm512_maskz_loadu_pd(m_mask, av[3]); + + // Perform: rhov[i].v += a_vec[i].v * xv[0]; + // rhov[i + 8].v += a_vec[i].v * xv[1]; + // This stores the partial sums due to real and + // imag components separately + rhov[0].v = _mm512_fmadd_pd(a_vec[0].v, xv[0].v, rhov[0].v); + rhov[8].v = _mm512_fmadd_pd(a_vec[0].v, xv[1].v, rhov[8].v); + + rhov[1].v = _mm512_fmadd_pd(a_vec[1].v, xv[0].v, rhov[1].v); + rhov[9].v = _mm512_fmadd_pd(a_vec[1].v, xv[1].v, rhov[9].v); + + rhov[2].v = _mm512_fmadd_pd(a_vec[2].v, xv[0].v, rhov[2].v); + rhov[10].v = _mm512_fmadd_pd(a_vec[2].v, xv[1].v, rhov[10].v); + + rhov[3].v = _mm512_fmadd_pd(a_vec[3].v, xv[0].v, rhov[3].v); + rhov[11].v = _mm512_fmadd_pd(a_vec[3].v, xv[1].v, rhov[11].v); + + // Load remaining elements from next 4 columns of A + // Maskz_load is used to ensure the unloaded elements are 0 + // Else, it affects the accumulation and final reduction + a_vec[4].v = _mm512_maskz_loadu_pd(m_mask, av[4]); + a_vec[5].v = _mm512_maskz_loadu_pd(m_mask, av[5]); + a_vec[6].v = _mm512_maskz_loadu_pd(m_mask, av[6]); + a_vec[7].v = _mm512_maskz_loadu_pd(m_mask, av[7]); + + // Perform: rhov[i].v += a_vec[i].v * xv[0]; + // rhov[i + 8].v += a_vec[i].v * xv[1]; + // This stores the partial sums due to real and + // imag components separately + rhov[4].v = _mm512_fmadd_pd(a_vec[4].v, xv[0].v, rhov[4].v); + rhov[12].v = _mm512_fmadd_pd(a_vec[4].v, xv[1].v, rhov[12].v); + + rhov[5].v = _mm512_fmadd_pd(a_vec[5].v, xv[0].v, rhov[5].v); + rhov[13].v = _mm512_fmadd_pd(a_vec[5].v, xv[1].v, rhov[13].v); + + rhov[6].v = _mm512_fmadd_pd(a_vec[6].v, xv[0].v, rhov[6].v); + rhov[14].v = _mm512_fmadd_pd(a_vec[6].v, xv[1].v, rhov[14].v); + + rhov[7].v = _mm512_fmadd_pd(a_vec[7].v, xv[0].v, rhov[7].v); + rhov[15].v = _mm512_fmadd_pd(a_vec[7].v, xv[1].v, rhov[15].v); + } + + // Permuting for final accumulation of real and imag parts + rhov[8].v = _mm512_permute_pd(rhov[8].v, 0x55); + rhov[9].v = _mm512_permute_pd(rhov[9].v, 0x55); + rhov[10].v = _mm512_permute_pd(rhov[10].v, 0x55); + rhov[11].v = _mm512_permute_pd(rhov[11].v, 0x55); + rhov[12].v = _mm512_permute_pd(rhov[12].v, 0x55); + rhov[13].v = _mm512_permute_pd(rhov[13].v, 0x55); + rhov[14].v = _mm512_permute_pd(rhov[14].v, 0x55); + rhov[15].v = _mm512_permute_pd(rhov[15].v, 0x55); + + // Setting 2 registers to 0 and 1 + v8df_t zero_reg, scale_one; + + zero_reg.v = _mm512_setzero_pd(); + scale_one.v = _mm512_set1_pd(1.0); + + /* + conj_op maps to the compute as follows : + A = (a + ib), X = (x + iy) + ----------------------------------------------------------- + | A | X | Real part | Imag Part | + ----------------------------------------------------------- + | No-Conjugate | No-Conjugate | ax - by | bx + ay | + | No-Conjugate | Conjugate | ax + by | bx - ay | + | Conjugate | No-Conjugate | ax + by | -(bx - ay) | + | Conjugate | Conjugate | ax - by | -(bx + ay) | + ----------------------------------------------------------- + + If only X or A has conjugate, fmsubadd is performed. + Else, fmaddsub is performed. + + In the final reduction step, the imaginary part of every + partial sum is negated if conjat is conjugate + */ + if ( bli_is_noconj( conj_op ) ) + { + rhov[0].v = _mm512_fmaddsub_pd(scale_one.v, rhov[0].v, rhov[8].v); + rhov[1].v = _mm512_fmaddsub_pd(scale_one.v, rhov[1].v, rhov[9].v); + rhov[2].v = _mm512_fmaddsub_pd(scale_one.v, rhov[2].v, rhov[10].v); + rhov[3].v = _mm512_fmaddsub_pd(scale_one.v, rhov[3].v, rhov[11].v); + rhov[4].v = _mm512_fmaddsub_pd(scale_one.v, rhov[4].v, rhov[12].v); + rhov[5].v = _mm512_fmaddsub_pd(scale_one.v, rhov[5].v, rhov[13].v); + rhov[6].v = _mm512_fmaddsub_pd(scale_one.v, rhov[6].v, rhov[14].v); + rhov[7].v = _mm512_fmaddsub_pd(scale_one.v, rhov[7].v, rhov[15].v); + } + else + { + rhov[0].v = _mm512_fmsubadd_pd(scale_one.v, rhov[0].v, rhov[8].v); + rhov[1].v = _mm512_fmsubadd_pd(scale_one.v, rhov[1].v, rhov[9].v); + rhov[2].v = _mm512_fmsubadd_pd(scale_one.v, rhov[2].v, rhov[10].v); + rhov[3].v = _mm512_fmsubadd_pd(scale_one.v, rhov[3].v, rhov[11].v); + rhov[4].v = _mm512_fmsubadd_pd(scale_one.v, rhov[4].v, rhov[12].v); + rhov[5].v = _mm512_fmsubadd_pd(scale_one.v, rhov[5].v, rhov[13].v); + rhov[6].v = _mm512_fmsubadd_pd(scale_one.v, rhov[6].v, rhov[14].v); + rhov[7].v = _mm512_fmsubadd_pd(scale_one.v, rhov[7].v, rhov[15].v); + } + + // rhov[0 ... 7] will have the element wise product. + // These have to be added horizontally(reduction) to get the + // final result for every element in y. + // If rhov[0] = R0 I0 R1 I1 R2 I2 R3 I3 + // Then rhov[8] = R1 I1 R0 I0 R3 I2 R2 I2 + rhov[8].v = _mm512_permutex_pd(rhov[0].v, 0x4E); + rhov[9].v = _mm512_permutex_pd(rhov[1].v, 0x4E); + rhov[10].v = _mm512_permutex_pd(rhov[2].v, 0x4E); + rhov[11].v = _mm512_permutex_pd(rhov[3].v, 0x4E); + rhov[12].v = _mm512_permutex_pd(rhov[4].v, 0x4E); + rhov[13].v = _mm512_permutex_pd(rhov[5].v, 0x4E); + rhov[14].v = _mm512_permutex_pd(rhov[6].v, 0x4E); + rhov[15].v = _mm512_permutex_pd(rhov[7].v, 0x4E); + + // rhov[0] = (R0 + R1) (I0 + I1) (R1 + R0) (I1 + I0) + // (R2 + R3) (I2 + I3) (R3 + R2) (I3 + I2) + rhov[0].v = _mm512_add_pd(rhov[0].v, rhov[8].v); + rhov[1].v = _mm512_add_pd(rhov[1].v, rhov[9].v); + rhov[2].v = _mm512_add_pd(rhov[2].v, rhov[10].v); + rhov[3].v = _mm512_add_pd(rhov[3].v, rhov[11].v); + rhov[4].v = _mm512_add_pd(rhov[4].v, rhov[12].v); + rhov[5].v = _mm512_add_pd(rhov[5].v, rhov[13].v); + rhov[6].v = _mm512_add_pd(rhov[6].v, rhov[14].v); + rhov[7].v = _mm512_add_pd(rhov[7].v, rhov[15].v); + + // 256-bit registers declared to extract 256-bit lanes + v4df_t reduce_sum[16]; + + // reduce_sum[0] = (R0 + R1) (I0 + I1) (R1 + R0) (I1 + I0) + reduce_sum[0].v = _mm512_extractf64x4_pd(rhov[0].v, 0x00); + reduce_sum[1].v = _mm512_extractf64x4_pd(rhov[1].v, 0x00); + reduce_sum[2].v = _mm512_extractf64x4_pd(rhov[2].v, 0x00); + reduce_sum[3].v = _mm512_extractf64x4_pd(rhov[3].v, 0x00); + reduce_sum[4].v = _mm512_extractf64x4_pd(rhov[4].v, 0x00); + reduce_sum[5].v = _mm512_extractf64x4_pd(rhov[5].v, 0x00); + reduce_sum[6].v = _mm512_extractf64x4_pd(rhov[6].v, 0x00); + reduce_sum[7].v = _mm512_extractf64x4_pd(rhov[7].v, 0x00); + + // reduce_sum[8] = (R2 + R3) (I2 + I3) (R3 + R2) (I3 + I2) + reduce_sum[8].v = _mm512_extractf64x4_pd(rhov[0].v, 0x1); + reduce_sum[9].v = _mm512_extractf64x4_pd(rhov[1].v, 0x1); + reduce_sum[10].v = _mm512_extractf64x4_pd(rhov[2].v, 0x1); + reduce_sum[11].v = _mm512_extractf64x4_pd(rhov[3].v, 0x1); + reduce_sum[12].v = _mm512_extractf64x4_pd(rhov[4].v, 0x1); + reduce_sum[13].v = _mm512_extractf64x4_pd(rhov[5].v, 0x1); + reduce_sum[14].v = _mm512_extractf64x4_pd(rhov[6].v, 0x1); + reduce_sum[15].v = _mm512_extractf64x4_pd(rhov[7].v, 0x1); + + // reduce_sum[0] = (R0 + R1 + R2 + R3) (I0 + I1 + I2 + I3) ... + reduce_sum[0].v = _mm256_add_pd(reduce_sum[0].v, reduce_sum[8].v); + reduce_sum[1].v = _mm256_add_pd(reduce_sum[1].v, reduce_sum[9].v); + reduce_sum[2].v = _mm256_add_pd(reduce_sum[2].v, reduce_sum[10].v); + reduce_sum[3].v = _mm256_add_pd(reduce_sum[3].v, reduce_sum[11].v); + reduce_sum[4].v = _mm256_add_pd(reduce_sum[4].v, reduce_sum[12].v); + reduce_sum[5].v = _mm256_add_pd(reduce_sum[5].v, reduce_sum[13].v); + reduce_sum[6].v = _mm256_add_pd(reduce_sum[6].v, reduce_sum[14].v); + reduce_sum[7].v = _mm256_add_pd(reduce_sum[7].v, reduce_sum[15].v); + + // The next set of shuffles, permutes and inserts are performed to store + // all the dot-products onto two 512 registers. They are used to perform + // aligned stores onto the stack memory. + reduce_sum[8].v = _mm256_shuffle_pd(reduce_sum[0].v, reduce_sum[1].v, 0xC); + reduce_sum[9].v = _mm256_shuffle_pd(reduce_sum[2].v, reduce_sum[3].v, 0xC); + reduce_sum[10].v = _mm256_shuffle_pd(reduce_sum[4].v, reduce_sum[5].v, 0xC); + reduce_sum[11].v = _mm256_shuffle_pd(reduce_sum[6].v, reduce_sum[7].v, 0xC); + + reduce_sum[12].v = _mm256_permutex_pd(reduce_sum[8].v, 0xD8); + reduce_sum[13].v = _mm256_permutex_pd(reduce_sum[9].v, 0xD8); + reduce_sum[14].v = _mm256_permutex_pd(reduce_sum[10].v, 0xD8); + reduce_sum[15].v = _mm256_permutex_pd(reduce_sum[11].v, 0xD8); + + rhov[0].v = _mm512_insertf64x4(rhov[0].v, reduce_sum[12].v, 0x00); + rhov[0].v = _mm512_insertf64x4(rhov[0].v, reduce_sum[13].v, 0x01); + rhov[1].v = _mm512_insertf64x4(rhov[1].v, reduce_sum[14].v, 0x00); + rhov[1].v = _mm512_insertf64x4(rhov[1].v, reduce_sum[15].v, 0x01); + + // Negate the sign bit of imaginary part of dot-products if conjat is conjugate + if ( bli_is_conj( conjat ) ) + { + rhov[0].v = _mm512_fmsubadd_pd(zero_reg.v, zero_reg.v, rhov[0].v); + rhov[1].v = _mm512_fmsubadd_pd(zero_reg.v, zero_reg.v, rhov[1].v); + } + + /* + Computed dot product result is being stored + in temp buffer r for further computation. + */ + _mm512_store_pd((double *)res, rhov[0].v); + _mm512_store_pd((double *)(res + 4), rhov[1].v); + } + + // This section will have the whole of compute when incx != 1 || inca != 1 + else + { + // Declaring 128-bit registers, for element by element computation + v2df_t rhov[16], a_vec[8], xv[2]; + + // Clearing the partial-sum accumulators + rhov[0].v = _mm_setzero_pd(); + rhov[1].v = _mm_setzero_pd(); + rhov[2].v = _mm_setzero_pd(); + rhov[3].v = _mm_setzero_pd(); + rhov[4].v = _mm_setzero_pd(); + rhov[5].v = _mm_setzero_pd(); + rhov[6].v = _mm_setzero_pd(); + rhov[7].v = _mm_setzero_pd(); + rhov[8].v = _mm_setzero_pd(); + rhov[9].v = _mm_setzero_pd(); + rhov[10].v = _mm_setzero_pd(); + rhov[11].v = _mm_setzero_pd(); + rhov[12].v = _mm_setzero_pd(); + rhov[13].v = _mm_setzero_pd(); + rhov[14].v = _mm_setzero_pd(); + rhov[15].v = _mm_setzero_pd(); + + for (dim_t i = 0; i < m; i++) + { + // Load from X + xv[0].v = _mm_loadu_pd(x_temp); + + // Permute to duplicate the imag part for every element + xv[1].v = _mm_permute_pd(xv[0].v, 0b11); + + // Permute to duplicate the real part for every element + xv[0].v = _mm_permute_pd(xv[0].v, 0b00); + + // Load elements from first 4 columns of A + a_vec[0].v = _mm_loadu_pd(av[0]); + a_vec[1].v = _mm_loadu_pd(av[1]); + a_vec[2].v = _mm_loadu_pd(av[2]); + a_vec[3].v = _mm_loadu_pd(av[3]); + + // Perform: rhov[i].v += a_vec[i].v * xv[0]; + // rhov[i + 8].v += a_vec[i].v * xv[1]; + // This stores the partial sums due to real and + // imag components separately + rhov[0].v = _mm_fmadd_pd(a_vec[0].v, xv[0].v, rhov[0].v); + rhov[8].v = _mm_fmadd_pd(a_vec[0].v, xv[1].v, rhov[8].v); + + rhov[1].v = _mm_fmadd_pd(a_vec[1].v, xv[0].v, rhov[1].v); + rhov[9].v = _mm_fmadd_pd(a_vec[1].v, xv[1].v, rhov[9].v); + + rhov[2].v = _mm_fmadd_pd(a_vec[2].v, xv[0].v, rhov[2].v); + rhov[10].v = _mm_fmadd_pd(a_vec[2].v, xv[1].v, rhov[10].v); + + rhov[3].v = _mm_fmadd_pd(a_vec[3].v, xv[0].v, rhov[3].v); + rhov[11].v = _mm_fmadd_pd(a_vec[3].v, xv[1].v, rhov[11].v); + + // Load elements from next 4 columns of A + a_vec[4].v = _mm_loadu_pd(av[4]); + a_vec[5].v = _mm_loadu_pd(av[5]); + a_vec[6].v = _mm_loadu_pd(av[6]); + a_vec[7].v = _mm_loadu_pd(av[7]); + + // Perform: rhov[i].v += a_vec[i].v * xv[0]; + // rhov[i + 8].v += a_vec[i].v * xv[1]; + // This stores the partial sums due to real and + // imag components separately + rhov[4].v = _mm_fmadd_pd(a_vec[4].v, xv[0].v, rhov[4].v); + rhov[12].v = _mm_fmadd_pd(a_vec[4].v, xv[1].v, rhov[12].v); + + rhov[5].v = _mm_fmadd_pd(a_vec[5].v, xv[0].v, rhov[5].v); + rhov[13].v = _mm_fmadd_pd(a_vec[5].v, xv[1].v, rhov[13].v); + + rhov[6].v = _mm_fmadd_pd(a_vec[6].v, xv[0].v, rhov[6].v); + rhov[14].v = _mm_fmadd_pd(a_vec[6].v, xv[1].v, rhov[14].v); + + rhov[7].v = _mm_fmadd_pd(a_vec[7].v, xv[0].v, rhov[7].v); + rhov[15].v = _mm_fmadd_pd(a_vec[7].v, xv[1].v, rhov[15].v); + + // Adjust the pointers accordingly + av[0] += 2 * inca; + av[1] += 2 * inca; + av[2] += 2 * inca; + av[3] += 2 * inca; + av[4] += 2 * inca; + av[5] += 2 * inca; + av[6] += 2 * inca; + av[7] += 2 * inca; + + x_temp += 2 * incx; + } + + // Permuting to help with final reduction + rhov[8].v = _mm_permute_pd(rhov[8].v, 0b01); + rhov[9].v = _mm_permute_pd(rhov[9].v, 0b01); + rhov[10].v = _mm_permute_pd(rhov[10].v, 0b01); + rhov[11].v = _mm_permute_pd(rhov[11].v, 0b01); + rhov[12].v = _mm_permute_pd(rhov[12].v, 0b01); + rhov[13].v = _mm_permute_pd(rhov[13].v, 0b01); + rhov[14].v = _mm_permute_pd(rhov[14].v, 0b01); + rhov[15].v = _mm_permute_pd(rhov[15].v, 0b01); + + v2df_t zero_reg, scale_one; + + zero_reg.v = _mm_setzero_pd(); + scale_one.v = _mm_set1_pd(1.0); + + // Reduction based on conj_op + if ( bli_is_noconj( conj_op ) ) + { + rhov[0].v = _mm_addsub_pd(rhov[0].v, rhov[8].v); + rhov[1].v = _mm_addsub_pd(rhov[1].v, rhov[9].v); + rhov[2].v = _mm_addsub_pd(rhov[2].v, rhov[10].v); + rhov[3].v = _mm_addsub_pd(rhov[3].v, rhov[11].v); + rhov[4].v = _mm_addsub_pd(rhov[4].v, rhov[12].v); + rhov[5].v = _mm_addsub_pd(rhov[5].v, rhov[13].v); + rhov[6].v = _mm_addsub_pd(rhov[6].v, rhov[14].v); + rhov[7].v = _mm_addsub_pd(rhov[7].v, rhov[15].v); + } + else + { + rhov[0].v = _mm_fmsubadd_pd(scale_one.v, rhov[0].v, rhov[8].v); + rhov[1].v = _mm_fmsubadd_pd(scale_one.v, rhov[1].v, rhov[9].v); + rhov[2].v = _mm_fmsubadd_pd(scale_one.v, rhov[2].v, rhov[10].v); + rhov[3].v = _mm_fmsubadd_pd(scale_one.v, rhov[3].v, rhov[11].v); + rhov[4].v = _mm_fmsubadd_pd(scale_one.v, rhov[4].v, rhov[12].v); + rhov[5].v = _mm_fmsubadd_pd(scale_one.v, rhov[5].v, rhov[13].v); + rhov[6].v = _mm_fmsubadd_pd(scale_one.v, rhov[6].v, rhov[14].v); + rhov[7].v = _mm_fmsubadd_pd(scale_one.v, rhov[7].v, rhov[15].v); + } + if( bli_is_conj( conjat ) ) + { + rhov[0].v = _mm_fmsubadd_pd(zero_reg.v, zero_reg.v, rhov[0].v); + rhov[1].v = _mm_fmsubadd_pd(zero_reg.v, zero_reg.v, rhov[1].v); + rhov[2].v = _mm_fmsubadd_pd(zero_reg.v, zero_reg.v, rhov[2].v); + rhov[3].v = _mm_fmsubadd_pd(zero_reg.v, zero_reg.v, rhov[3].v); + rhov[4].v = _mm_fmsubadd_pd(zero_reg.v, zero_reg.v, rhov[4].v); + rhov[5].v = _mm_fmsubadd_pd(zero_reg.v, zero_reg.v, rhov[5].v); + rhov[6].v = _mm_fmsubadd_pd(zero_reg.v, zero_reg.v, rhov[6].v); + rhov[7].v = _mm_fmsubadd_pd(zero_reg.v, zero_reg.v, rhov[7].v); + } + + // Storing onto stack memory + _mm_storeu_pd((double *)res, rhov[0].v); + _mm_storeu_pd((double *)(res + 1), rhov[1].v); + _mm_storeu_pd((double *)(res + 2), rhov[2].v); + _mm_storeu_pd((double *)(res + 3), rhov[3].v); + _mm_storeu_pd((double *)(res + 4), rhov[4].v); + _mm_storeu_pd((double *)(res + 5), rhov[5].v); + _mm_storeu_pd((double *)(res + 6), rhov[6].v); + _mm_storeu_pd((double *)(res + 7), rhov[7].v); + + } + + // Scaling by alpha + // Registers to load dot-products from res + v8df_t rhov[2], temp[2]; + + rhov[0].v = _mm512_load_pd((double *)res); + rhov[1].v = _mm512_load_pd((double *)(res + 4)); + + if ( !bli_zeq1( *alpha ) ) + { + __m512d alphaRv, alphaIv; + alphaRv = _mm512_set1_pd((*alpha).real); + alphaIv = _mm512_set1_pd((*alpha).imag); + + temp[0].v = _mm512_permute_pd(rhov[0].v, 0x55); + temp[1].v = _mm512_permute_pd(rhov[1].v, 0x55); + + // Scaling with imag part of alpha + temp[0].v = _mm512_mul_pd(temp[0].v, alphaIv); + temp[1].v = _mm512_mul_pd(temp[1].v, alphaIv); + + // Scaling with real part of alpha, and addsub + rhov[0].v = _mm512_fmaddsub_pd(rhov[0].v, alphaRv, temp[0].v); + rhov[1].v = _mm512_fmaddsub_pd(rhov[1].v, alphaRv, temp[1].v); + } + + // When 'beta' is not zero we need to scale 'y' by 'beta' + v8df_t yv[2]; + + yv[0].v = _mm512_setzero_pd(); + yv[1].v = _mm512_setzero_pd(); + + if (!PASTEMAC(z, eq0)(*beta)) + { + __m512d betaRv, betaIv; + + betaRv = _mm512_set1_pd((*beta).real); + betaIv = _mm512_set1_pd((*beta).imag); + + if (incy == 1) + { + yv[0].v = _mm512_loadu_pd((double *)(y)); + yv[1].v = _mm512_loadu_pd((double *)(y + 4)); + } + else + { + /* + This can be done using SSE instructions + but has been kept as scalar code to avoid + mixing SSE with AVX + */ + yv[0].d[0] = (*(y + 0 * incy)).real; + yv[0].d[1] = (*(y + 0 * incy)).imag; + yv[0].d[2] = (*(y + 1 * incy)).real; + yv[0].d[3] = (*(y + 1 * incy)).imag; + yv[0].d[4] = (*(y + 2 * incy)).real; + yv[0].d[5] = (*(y + 2 * incy)).imag; + yv[0].d[6] = (*(y + 3 * incy)).real; + yv[0].d[7] = (*(y + 3 * incy)).imag; + + yv[1].d[0] = (*(y + 4 * incy)).real; + yv[1].d[1] = (*(y + 4 * incy)).imag; + yv[1].d[2] = (*(y + 5 * incy)).real; + yv[1].d[3] = (*(y + 5 * incy)).imag; + yv[1].d[4] = (*(y + 6 * incy)).real; + yv[1].d[5] = (*(y + 6 * incy)).imag; + yv[1].d[6] = (*(y + 7 * incy)).real; + yv[1].d[7] = (*(y + 7 * incy)).imag; + } + + temp[0].v = _mm512_permute_pd(yv[0].v, 0x55); + temp[1].v = _mm512_permute_pd(yv[1].v, 0x55); + + // Scaling with imag part of alpha + temp[0].v = _mm512_mul_pd(temp[0].v, betaIv); + temp[1].v = _mm512_mul_pd(temp[1].v, betaIv); + + // Scaling with real part of alpha, and addsub + yv[0].v = _mm512_fmaddsub_pd(yv[0].v, betaRv, temp[0].v); + yv[1].v = _mm512_fmaddsub_pd(yv[1].v, betaRv, temp[1].v); + } + + // Adding alpha*A*x to beta*Y + yv[0].v = _mm512_add_pd(yv[0].v, rhov[0].v); + yv[1].v = _mm512_add_pd(yv[1].v, rhov[1].v); + + if (incy == 1) + { + _mm512_storeu_pd((double *)y, yv[0].v); + _mm512_storeu_pd((double *)(y + 4), yv[1].v); + } + else + { + (*(y + 0 * incy)).real = yv[0].d[0]; + (*(y + 0 * incy)).imag = yv[0].d[1]; + (*(y + 1 * incy)).real = yv[0].d[2]; + (*(y + 1 * incy)).imag = yv[0].d[3]; + + (*(y + 2 * incy)).real = yv[0].d[4]; + (*(y + 2 * incy)).imag = yv[0].d[5]; + (*(y + 3 * incy)).real = yv[0].d[6]; + (*(y + 3 * incy)).imag = yv[0].d[7]; + + (*(y + 4 * incy)).real = yv[1].d[0]; + (*(y + 4 * incy)).imag = yv[1].d[1]; + (*(y + 5 * incy)).real = yv[1].d[2]; + (*(y + 5 * incy)).imag = yv[1].d[3]; + + (*(y + 6 * incy)).real = yv[1].d[4]; + (*(y + 6 * incy)).imag = yv[1].d[5]; + (*(y + 7 * incy)).real = yv[1].d[6]; + (*(y + 7 * incy)).imag = yv[1].d[7]; + } + +} diff --git a/kernels/zen4/1m/bli_packm_zen4_asm_d16xk.c b/kernels/zen4/1m/bli_packm_zen4_asm_d16xk.c index c311d4ebf2..da82f0336e 100644 --- a/kernels/zen4/1m/bli_packm_zen4_asm_d16xk.c +++ b/kernels/zen4/1m/bli_packm_zen4_asm_d16xk.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -122,7 +122,7 @@ void bli_dpackm_zen4_asm_16xk { for ( dim_t k = k0; k != 0; --k ) { - _mm_prefetch( a + (8*lda), _MM_HINT_T0 ); + _mm_prefetch((char const*)(a + (8*lda)), _MM_HINT_T0 ); for ( dim_t i = 0 ; i < 16 ; i++ ) { bli_dcopys( *(a + i), *(p + i) ); } diff --git a/kernels/zen4/1m/bli_packm_zen4_asm_d24xk.c b/kernels/zen4/1m/bli_packm_zen4_asm_d24xk.c index 4c7151513e..f01ddb322a 100644 --- a/kernels/zen4/1m/bli_packm_zen4_asm_d24xk.c +++ b/kernels/zen4/1m/bli_packm_zen4_asm_d24xk.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -117,18 +117,16 @@ void bli_dpackm_zen4_asm_24xk const bool gs = ( inca0 != 1 && lda0 != 1 ); - // NOTE: If/when this kernel ever supports scaling by kappa within the - // assembly region, this constraint should be lifted. - const bool unitk = bli_deq1( *kappa ); - double* restrict a_next = a + cdim0; // ------------------------------------------------------------------------- - if ( cdim0 == mnr && !gs && unitk ) + if ( cdim0 == mnr && !gs ) { begin_asm() mov(var(mask), rdx) // load mask kmovw(edx, k(2)) // move mask to k2 register + mov(var(kappa), r10) // move kappa to r10 + vbroadcastsd(mem(r10), zmm17) // broadcast kappa into zmm17 mov(var(a), rax) // load address of source buffer. mov(var(a), r13) // load address of source buffer. mov(var(inca), r8) // load inca @@ -207,13 +205,21 @@ void bli_dpackm_zen4_asm_24xk SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + vmulpd(zmm0, zmm17, zmm0) // scale by kappa vmovupd(zmm0, mem(rbx, 0*192)) + vmulpd(zmm4, zmm17, zmm4) vmovupd(zmm4, mem(rbx, 1*192)) + vmulpd(zmm2, zmm17, zmm2) vmovupd(zmm2, mem(rbx, 2*192)) + vmulpd(zmm6, zmm17, zmm6) vmovupd(zmm6, mem(rbx, 3*192)) + vmulpd(zmm1, zmm17, zmm1) vmovupd(zmm1, mem(rbx, 4*192)) + vmulpd(zmm5, zmm17, zmm5) vmovupd(zmm5, mem(rbx, 5*192)) + vmulpd(zmm3, zmm17, zmm3) vmovupd(zmm3, mem(rbx, 6*192)) + vmulpd(zmm8, zmm17, zmm8) vmovupd(zmm8, mem(rbx, 7*192)) add(r15, rax) @@ -238,13 +244,21 @@ void bli_dpackm_zen4_asm_24xk SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + vmulpd(zmm0, zmm17, zmm0) // scale by kappa vmovupd(zmm0, mem(rbx, 0*192 + 64)) + vmulpd(zmm4, zmm17, zmm4) vmovupd(zmm4, mem(rbx, 1*192 + 64)) + vmulpd(zmm2, zmm17, zmm2) vmovupd(zmm2, mem(rbx, 2*192 + 64)) + vmulpd(zmm6, zmm17, zmm6) vmovupd(zmm6, mem(rbx, 3*192 + 64)) + vmulpd(zmm1, zmm17, zmm1) vmovupd(zmm1, mem(rbx, 4*192 + 64)) + vmulpd(zmm5, zmm17, zmm5) vmovupd(zmm5, mem(rbx, 5*192 + 64)) + vmulpd(zmm3, zmm17, zmm3) vmovupd(zmm3, mem(rbx, 6*192 + 64)) + vmulpd(zmm8, zmm17, zmm8) vmovupd(zmm8, mem(rbx, 7*192 + 64)) add(r15, rax) @@ -269,13 +283,21 @@ void bli_dpackm_zen4_asm_24xk SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + vmulpd(zmm0, zmm17, zmm0) // scale by kappa vmovupd(zmm0, mem(rbx, 0*192 + 128)) + vmulpd(zmm4, zmm17, zmm4) vmovupd(zmm4, mem(rbx, 1*192 + 128)) + vmulpd(zmm2, zmm17, zmm2) vmovupd(zmm2, mem(rbx, 2*192 + 128)) + vmulpd(zmm6, zmm17, zmm6) vmovupd(zmm6, mem(rbx, 3*192 + 128)) + vmulpd(zmm1, zmm17, zmm1) vmovupd(zmm1, mem(rbx, 4*192 + 128)) + vmulpd(zmm5, zmm17, zmm5) vmovupd(zmm5, mem(rbx, 5*192 + 128)) + vmulpd(zmm3, zmm17, zmm3) vmovupd(zmm3, mem(rbx, 6*192 + 128)) + vmulpd(zmm8, zmm17, zmm8) vmovupd(zmm8, mem(rbx, 7*192 + 128)) add(imm(8*8), r13) @@ -295,13 +317,21 @@ void bli_dpackm_zen4_asm_24xk label(.DKLEFTROWU) // EDGE LOOP (k_left) vmovupd(mem(rax, 0), zmm6 MASK_KZ(2)) + vmulpd(zmm6, zmm17, zmm6) // scale by kappa vmovupd(mem(rax, r8, 1, 0), zmm8 MASK_KZ(2)) + vmulpd(zmm8, zmm17, zmm8) vmovupd(mem(rax, r8, 2, 0), zmm10 MASK_KZ(2)) + vmulpd(zmm10, zmm17, zmm10) vmovupd(mem(rax, r12, 1, 0), zmm12 MASK_KZ(2)) + vmulpd(zmm12, zmm17, zmm12) vmovupd(mem(rax, r8, 4, 0), zmm14 MASK_KZ(2)) + vmulpd(zmm14, zmm17, zmm14) vmovupd(mem(rax, rcx, 1, 0), zmm16 MASK_KZ(2)) + vmulpd(zmm16, zmm17, zmm16) vmovupd(mem(rax, r12, 2, 0), zmm18 MASK_KZ(2)) + vmulpd(zmm18, zmm17, zmm18) vmovupd(mem(rax, rdx, 1, 0), zmm20 MASK_KZ(2)) + vmulpd(zmm20, zmm17, zmm20) UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) @@ -387,13 +417,21 @@ void bli_dpackm_zen4_asm_24xk LABEL(.UPDATEDONE) vmovupd(mem(rax, 0), zmm6 MASK_KZ(2)) + vmulpd(zmm6, zmm17, zmm6) // scale by kappa vmovupd(mem(rax, r8, 1, 0), zmm8 MASK_KZ(2)) + vmulpd(zmm8, zmm17, zmm8) vmovupd(mem(rax, r8, 2, 0), zmm10 MASK_KZ(2)) + vmulpd(zmm10, zmm17, zmm10) vmovupd(mem(rax, r12, 1, 0), zmm12 MASK_KZ(2)) + vmulpd(zmm12, zmm17, zmm12) vmovupd(mem(rax, r8, 4, 0), zmm14 MASK_KZ(2)) + vmulpd(zmm14, zmm17, zmm14) vmovupd(mem(rax, rcx, 1, 0), zmm16 MASK_KZ(2)) + vmulpd(zmm16, zmm17, zmm16) vmovupd(mem(rax, r12, 2, 0), zmm18 MASK_KZ(2)) + vmulpd(zmm18, zmm17, zmm18) vmovupd(mem(rax, rdx, 1, 0), zmm20 MASK_KZ(2)) + vmulpd(zmm20, zmm17, zmm20) UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) @@ -480,13 +518,21 @@ void bli_dpackm_zen4_asm_24xk LABEL(.UPDATEDONEL2) vmovupd(mem(rax, 0), zmm6 MASK_KZ(2)) + vmulpd(zmm6, zmm17, zmm6) // scale by kappa vmovupd(mem(rax, r8, 1, 0), zmm8 MASK_KZ(2)) + vmulpd(zmm8, zmm17, zmm8) vmovupd(mem(rax, r8, 2, 0), zmm10 MASK_KZ(2)) + vmulpd(zmm10, zmm17, zmm10) vmovupd(mem(rax, r12, 1, 0), zmm12 MASK_KZ(2)) + vmulpd(zmm12, zmm17, zmm12) vmovupd(mem(rax, r8, 4, 0), zmm14 MASK_KZ(2)) + vmulpd(zmm14, zmm17, zmm14) vmovupd(mem(rax, rcx, 1, 0), zmm16 MASK_KZ(2)) + vmulpd(zmm16, zmm17, zmm16) vmovupd(mem(rax, r12, 2, 0), zmm18 MASK_KZ(2)) + vmulpd(zmm18, zmm17, zmm18) vmovupd(mem(rax, rdx, 1, 0), zmm20 MASK_KZ(2)) + vmulpd(zmm20, zmm17, zmm20) UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) @@ -608,80 +654,104 @@ void bli_dpackm_zen4_asm_24xk * where i is updated by 1 and rax and rbx updated by lda and ldp. */ vmovupd(mem(rax, 0), zmm6) + vmulpd(zmm6, zmm17, zmm6) // scale by kappa vmovupd(mem(rax, 64), zmm8) vmovupd(mem(rax, 128), zmm10) + vmulpd(zmm8, zmm17, zmm8) vmovupd(zmm6, mem(rbx, 0*64+ 0)) vmovupd(zmm8, mem(rbx, 0*64+ 64)) + vmulpd(zmm10, zmm17, zmm10) vmovupd(zmm10, mem(rbx, 0*64+ 128)) add(r10, rax) add(r8, rbx) vmovupd(mem(rax, 0), zmm6) + vmulpd(zmm6, zmm17, zmm6) // scale by kappa vmovupd(mem(rax, 64), zmm8) vmovupd(mem(rax, 128), zmm10) + vmulpd(zmm8, zmm17, zmm8) vmovupd(zmm6, mem(rbx, 0*64+ 0)) vmovupd(zmm8, mem(rbx, 0*64+ 64)) + vmulpd(zmm10, zmm17, zmm10) vmovupd(zmm10, mem(rbx, 0*64+ 128)) add(r10, rax) add(r8, rbx) vmovupd(mem(rax, 0), zmm6) + vmulpd(zmm6, zmm17, zmm6) // scale by kappa vmovupd(mem(rax, 64), zmm8) vmovupd(mem(rax, 128), zmm10) + vmulpd(zmm8, zmm17, zmm8) vmovupd(zmm6, mem(rbx, 0*64+ 0)) vmovupd(zmm8, mem(rbx, 0*64+ 64)) + vmulpd(zmm10, zmm17, zmm10) vmovupd(zmm10, mem(rbx, 0*64+ 128)) add(r10, rax) add(r8, rbx) vmovupd(mem(rax, 0), zmm6) + vmulpd(zmm6, zmm17, zmm6) // scale by kappa vmovupd(mem(rax, 64), zmm8) vmovupd(mem(rax, 128), zmm10) + vmulpd(zmm8, zmm17, zmm8) vmovupd(zmm6, mem(rbx, 0*64+ 0)) vmovupd(zmm8, mem(rbx, 0*64+ 64)) + vmulpd(zmm10, zmm17, zmm10) vmovupd(zmm10, mem(rbx, 0*64+ 128)) add(r10, rax) add(r8, rbx) vmovupd(mem(rax, 0), zmm6) + vmulpd(zmm6, zmm17, zmm6) // scale by kappa vmovupd(mem(rax, 64), zmm8) vmovupd(mem(rax, 128), zmm10) + vmulpd(zmm8, zmm17, zmm8) vmovupd(zmm6, mem(rbx, 0*64+ 0)) vmovupd(zmm8, mem(rbx, 0*64+ 64)) + vmulpd(zmm10, zmm17, zmm10) vmovupd(zmm10, mem(rbx, 0*64+ 128)) add(r10, rax) add(r8, rbx) vmovupd(mem(rax, 0), zmm6) + vmulpd(zmm6, zmm17, zmm6) // scale by kappa vmovupd(mem(rax, 64), zmm8) vmovupd(mem(rax, 128), zmm10) + vmulpd(zmm8, zmm17, zmm8) vmovupd(zmm6, mem(rbx, 0*64+ 0)) vmovupd(zmm8, mem(rbx, 0*64+ 64)) + vmulpd(zmm10, zmm17, zmm10) vmovupd(zmm10, mem(rbx, 0*64+ 128)) add(r10, rax) add(r8, rbx) vmovupd(mem(rax, 0), zmm6) + vmulpd(zmm6, zmm17, zmm6) // scale by kappa vmovupd(mem(rax, 64), zmm8) vmovupd(mem(rax, 128), zmm10) + vmulpd(zmm8, zmm17, zmm8) vmovupd(zmm6, mem(rbx, 0*64+ 0)) vmovupd(zmm8, mem(rbx, 0*64+ 64)) + vmulpd(zmm10, zmm17, zmm10) vmovupd(zmm10, mem(rbx, 0*64+ 128)) add(r10, rax) add(r8, rbx) vmovupd(mem(rax, 0), zmm6) + vmulpd(zmm6, zmm17, zmm6) // scale by kappa vmovupd(mem(rax, 64), zmm8) vmovupd(mem(rax, 128), zmm10) + vmulpd(zmm8, zmm17, zmm8) vmovupd(zmm6, mem(rbx, 0*64+ 0)) vmovupd(zmm8, mem(rbx, 0*64+ 64)) + vmulpd(zmm10, zmm17, zmm10) vmovupd(zmm10, mem(rbx, 0*64+ 128)) add(r10, rax) @@ -699,10 +769,13 @@ void bli_dpackm_zen4_asm_24xk label(.DKLEFTCOLU) // EDGE LOOP (k_left) vmovupd(mem(rax, 0), zmm6) + vmulpd(zmm6, zmm17, zmm6) // scale by kappa vmovupd(mem(rax, 64), zmm8) vmovupd(mem(rax, 128), zmm10) + vmulpd(zmm8, zmm17, zmm8) vmovupd(zmm6, mem(rbx, 0*64+ 0)) vmovupd(zmm8, mem(rbx, 0*64+ 64)) + vmulpd(zmm10, zmm17, zmm10) vmovupd(zmm10, mem(rbx, 0*64+ 128)) add(r10, rax) @@ -723,6 +796,7 @@ void bli_dpackm_zen4_asm_24xk [lda] "m" (lda), [p] "m" (p), [ldp] "m" (ldp), + [kappa] "m" (kappa), [a_next] "m" (a_next) : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", @@ -731,7 +805,7 @@ void bli_dpackm_zen4_asm_24xk "zmm4", "zmm5", "zmm6", "zmm7", "zmm8", "zmm9", "zmm10", "zmm11", "zmm12", "zmm13", "zmm14", "zmm15", - "zmm16", "zmm18", "zmm20", "zmm30", "zmm31", "k2", "memory" + "zmm16", "zmm17", "zmm18", "zmm20", "zmm30", "zmm31", "k2", "memory" ) } else // if ( cdim0 < mnr || gs || !unitk ) diff --git a/kernels/zen4/1m/bli_packm_zen4_asm_d32xk.c b/kernels/zen4/1m/bli_packm_zen4_asm_d32xk.c index 60df4bca4e..8b50d52af4 100644 --- a/kernels/zen4/1m/bli_packm_zen4_asm_d32xk.c +++ b/kernels/zen4/1m/bli_packm_zen4_asm_d32xk.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -125,7 +125,7 @@ void bli_dpackm_zen4_asm_32xk { for ( dim_t k = k0; k != 0; --k ) { - _mm_prefetch( a + (8*lda), _MM_HINT_T0 ); + _mm_prefetch((char const*)(a + (8*lda)), _MM_HINT_T0 ); for ( dim_t i = 0 ; i < 32 ; i++ ) { bli_dcopys( *(a + i), *(pi1 + i) ); } diff --git a/kernels/zen4/3/CMakeLists.txt b/kernels/zen4/3/CMakeLists.txt deleted file mode 100644 index 6573f85ed8..0000000000 --- a/kernels/zen4/3/CMakeLists.txt +++ /dev/null @@ -1,25 +0,0 @@ -##Copyright (C) 2022-23, Advanced Micro Devices, Inc. All rights reserved.## - -add_library(zen4_3 - OBJECT - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmtrsm_l_zen_16x14.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmtrsm_u_zen_16x14.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmtrsm_l_zen4_8x24.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmtrsm_u_zen4_8x24.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_dgemm_zen4_asm_32x6.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_dgemm_zen4_asm_8x24.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_trsm_small_AVX512.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_zgemm_zen4_asm_12x4.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_zero_zmm.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_zgemm_zen4_asm_4x12.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_zgemmtrsm_l_4x12.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_zgemmtrsm_u_4x12.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_dgemm_avx512_k1.c - ) - -target_compile_options(zen4_3 PRIVATE /arch:AVX2 /arch:AVX512) -if(BUILD_SHARED_LIBS) - target_compile_definitions(zen4_3 PUBLIC -DBLIS_IS_BUILDING_LIBRARY) -endif() - -add_subdirectory(sup) diff --git a/kernels/zen4/3/bli_dgemm_zen4_asm_32x6.c b/kernels/zen4/3/bli_dgemm_zen4_asm_32x6.c index cab5ea0ce5..8546f9a09a 100644 --- a/kernels/zen4/3/bli_dgemm_zen4_asm_32x6.c +++ b/kernels/zen4/3/bli_dgemm_zen4_asm_32x6.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -19,14 +19,14 @@ from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - AS IS AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY - OF TEXAS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, - EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, - PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR - PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY - OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/kernels/zen4/3/bli_dgemm_zen4_asm_8x24.c b/kernels/zen4/3/bli_dgemm_zen4_asm_8x24.c index 887f27889c..0a5c6487ee 100644 --- a/kernels/zen4/3/bli_dgemm_zen4_asm_8x24.c +++ b/kernels/zen4/3/bli_dgemm_zen4_asm_8x24.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -19,14 +19,14 @@ from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - AS IS AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY - OF TEXAS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, - EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, - PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR - PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY - OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. @@ -155,15 +155,16 @@ VMULPD(ZMM(R5), ZMM(R5), ZMM(0)) \ VMULPD(ZMM(R6), ZMM(R6), ZMM(0)) \ VMULPD(ZMM(R7), ZMM(R7), ZMM(0)) \ - VMOVUPD(MEM(RCX, 0*8*8), ZMM(R0)) \ - VMOVUPD(MEM(RCX, 1*8*8), ZMM(R1)) \ - VMOVUPD(MEM(RCX, 2*8*8), ZMM(R2)) \ - VMOVUPD(MEM(RCX, 3*8*8), ZMM(R3)) \ - VMOVUPD(MEM(RCX, 4*8*8), ZMM(R4)) \ - VMOVUPD(MEM(RCX, 5*8*8), ZMM(R5)) \ - VMOVUPD(MEM(RCX, 6*8*8), ZMM(R6)) \ - VMOVUPD(MEM(RCX, 7*8*8), ZMM(R7)) \ - LEA(RCX, MEM(RCX,R10,1)) + /*store c*/ \ + VMOVUPD(MEM(RCX), ZMM(R0)) \ + VMOVUPD(MEM(RCX, R10, 1), ZMM(R1)) /*R10 = cs_c*/ \ + VMOVUPD(MEM(RCX, R10, 2), ZMM(R2)) \ + VMOVUPD(MEM(RCX, R11, 1), ZMM(R3)) /*R11 = 3*cs_c*/\ + VMOVUPD(MEM(RCX, R10, 4), ZMM(R4)) \ + VMOVUPD(MEM(RCX, R12, 1), ZMM(R5)) /*R12 = 5*cs_c*/\ + VMOVUPD(MEM(RCX, R11, 2), ZMM(R6)) \ + VMOVUPD(MEM(RCX, R13, 1), ZMM(R7)) /*R13 = 7*cs_c*/\ + LEA(RCX, MEM(RCX,R10,8)) #define SUBITER(n) \ \ diff --git a/kernels/zen4/3/bli_gemmtrsm_l_zen4_8x24.c b/kernels/zen4/3/bli_gemmtrsm_l_zen4_8x24.c index d5a10aa209..9eb7b1594c 100644 --- a/kernels/zen4/3/bli_gemmtrsm_l_zen4_8x24.c +++ b/kernels/zen4/3/bli_gemmtrsm_l_zen4_8x24.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -19,14 +19,14 @@ from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - AS IS AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY - OF TEXAS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, - EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, - PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR - PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY - OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/kernels/zen4/3/bli_gemmtrsm_u_zen4_8x24.c b/kernels/zen4/3/bli_gemmtrsm_u_zen4_8x24.c index e9dae78ba7..714e97064e 100644 --- a/kernels/zen4/3/bli_gemmtrsm_u_zen4_8x24.c +++ b/kernels/zen4/3/bli_gemmtrsm_u_zen4_8x24.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -19,14 +19,14 @@ from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - AS IS AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY - OF TEXAS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, - EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, - PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR - PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY - OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/kernels/zen4/3/bli_trsm_small_AVX512.c b/kernels/zen4/3/bli_trsm_small_AVX512.c index 3d10c3a9e4..a76215a076 100644 --- a/kernels/zen4/3/bli_trsm_small_AVX512.c +++ b/kernels/zen4/3/bli_trsm_small_AVX512.c @@ -1,19 +1,23 @@ /* + BLIS An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR @@ -25,6 +29,7 @@ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ #include "blis.h" #include "bli_trsm_small_ref.h" @@ -152,7 +157,7 @@ typedef err_t (*trsmsmall_ker_ft) Pack a block of 8xk from input buffer into packed buffer directly or after transpose based on input params */ -BLIS_INLINE void bli_dtrsm_small_pack_avx512 +void bli_dtrsm_small_pack_avx512 ( char side, dim_t size, @@ -406,7 +411,7 @@ BLIS_INLINE void bli_dtrsm_small_pack_avx512 a. This helps in utilze cache line efficiently in TRSM operation b. store ones when input is unit diagonal */ -BLIS_INLINE void dtrsm_small_pack_diag_element_avx512 +void dtrsm_small_pack_diag_element_avx512 ( bool is_unitdiag, double* a11, @@ -486,14 +491,14 @@ trsmsmall_ker_ft ker_fps_AVX512[4][8] = bli_dtrsm_small_XAltB_XAuB_AVX512, bli_dtrsm_small_XAltB_XAuB_AVX512, bli_dtrsm_small_XAutB_XAlB_AVX512}, - {NULL, - NULL, - NULL, - NULL, - NULL, - NULL, - NULL, - NULL}, + {bli_ztrsm_small_AutXB_AlXB_AVX512, + bli_ztrsm_small_AltXB_AuXB_AVX512, + bli_ztrsm_small_AltXB_AuXB_AVX512, + bli_ztrsm_small_AutXB_AlXB_AVX512, + bli_ztrsm_small_XAutB_XAlB_AVX512, + bli_ztrsm_small_XAltB_XAuB_AVX512, + bli_ztrsm_small_XAltB_XAuB_AVX512, + bli_ztrsm_small_XAutB_XAlB_AVX512}, }; /* * The bli_trsm_small implements a version of TRSM where A is packed and reused @@ -526,12 +531,12 @@ err_t bli_trsm_small_AVX512 switch (dt) { case BLIS_DOUBLE: + case BLIS_DCOMPLEX: { break; } case BLIS_FLOAT: case BLIS_SCOMPLEX: - case BLIS_DCOMPLEX: default: { return BLIS_NOT_YET_IMPLEMENTED; @@ -602,6 +607,11 @@ err_t bli_trsm_small_mt_AVX512 d_mr = 8, d_nr = 8; break; } + case BLIS_DCOMPLEX: + { + d_mr = 4, d_nr = 4; + break; + } default: { return BLIS_NOT_YET_IMPLEMENTED; @@ -616,7 +626,7 @@ err_t bli_trsm_small_mt_AVX512 // If dynamic-threading is enabled, calculate optimum number // of threads. // rntm will be updated with optimum number of threads. - if (bli_obj_is_double(b)) + if (bli_obj_is_double(b) || bli_obj_is_dcomplex(b) ) { bli_nthreads_optimum(a, b, b, BLIS_TRSM, &rntm); } @@ -750,7 +760,7 @@ err_t bli_trsm_small_mt_AVX512 zmm8 = _mm512_set1_pd(*(a01 + (p_lda * 7))); \ \ /*prefetch b10 4 iterations in advance*/ \ - _mm_prefetch((b10 + 4 * cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b10 + 4 * cs_b), _MM_HINT_T0); \ zmm9 = _mm512_fmadd_pd(zmm1, zmm0, zmm9 ); \ zmm10 = _mm512_fmadd_pd(zmm2, zmm0, zmm10); \ zmm11 = _mm512_fmadd_pd(zmm3, zmm0, zmm11); \ @@ -774,7 +784,7 @@ err_t bli_trsm_small_mt_AVX512 zmm21 = _mm512_set1_pd(*(a01_2 + (p_lda * 4))); \ zmm22 = _mm512_set1_pd(*(a01_2 + (p_lda * 5))); \ \ - _mm_prefetch((b10_2 + 4 * cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b10_2 + 4 * cs_b), _MM_HINT_T0); \ zmm24 = _mm512_fmadd_pd(zmm17, zmm23, zmm24); \ zmm17 = _mm512_set1_pd(*(a01_2 + (p_lda * 6))); \ zmm25 = _mm512_fmadd_pd(zmm18, zmm23, zmm25); \ @@ -791,22 +801,22 @@ err_t bli_trsm_small_mt_AVX512 } \ \ /*prefetch 8 columns of b11)*/ \ - _mm_prefetch((b11 + (0) * cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b11 + (0) * cs_b), _MM_HINT_T0); \ /*combine the results of both loops*/ \ zmm9 = _mm512_add_pd(zmm9, zmm24); \ - _mm_prefetch((b11 + (1) * cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b11 + (1) * cs_b), _MM_HINT_T0); \ zmm10 = _mm512_add_pd(zmm10, zmm25); \ - _mm_prefetch((b11 + (2) * cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b11 + (2) * cs_b), _MM_HINT_T0); \ zmm11 = _mm512_add_pd(zmm11, zmm26); \ - _mm_prefetch((b11 + (3) * cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b11 + (3) * cs_b), _MM_HINT_T0); \ zmm12 = _mm512_add_pd(zmm12, zmm27); \ - _mm_prefetch((b11 + (4) * cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b11 + (4) * cs_b), _MM_HINT_T0); \ zmm13 = _mm512_add_pd(zmm13, zmm28); \ - _mm_prefetch((b11 + (5) * cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b11 + (5) * cs_b), _MM_HINT_T0); \ zmm14 = _mm512_add_pd(zmm14, zmm29); \ - _mm_prefetch((b11 + (6) * cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b11 + (6) * cs_b), _MM_HINT_T0); \ zmm15 = _mm512_add_pd(zmm15, zmm30); \ - _mm_prefetch((b11 + (7) * cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b11 + (7) * cs_b), _MM_HINT_T0); \ zmm16 = _mm512_add_pd(zmm16, zmm31); /* // alternative way to prrefetch b11 @@ -822,8 +832,8 @@ err_t bli_trsm_small_mt_AVX512 // zmm21 = _mm512_set1_pd(*(a01_2 + p_lda * 4)); \ // zmm22 = _mm512_set1_pd(*(a01_2 + p_lda * 5)); \ // \ -// _mm_prefetch((b10_2 + 4*cs_b), _MM_HINT_T0); \ -// _mm_prefetch((b11 + (itr2-1)*cs_b), _MM_HINT_T0); \ +// _mm_prefetch((char const*)(b10_2 + 4*cs_b), _MM_HINT_T0); \ +// _mm_prefetch((char const*)(b11 + (itr2-1)*cs_b), _MM_HINT_T0); \ // zmm24 = _mm512_fmadd_pd(zmm17, zmm23, zmm24); \ // zmm17 = _mm512_set1_pd(*(a01_2 + p_lda * 6)); \ // zmm25 = _mm512_fmadd_pd(zmm18, zmm23, zmm25); \ @@ -856,7 +866,7 @@ err_t bli_trsm_small_mt_AVX512 zmm7 = _mm512_set1_pd(*(a01 + p_lda * 6)); \ zmm8 = _mm512_set1_pd(*(a01 + p_lda * 7)); \ \ - _mm_prefetch((b10 + 4*cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b10 + 4*cs_b), _MM_HINT_T0); \ zmm9 = _mm512_fmadd_pd(zmm1, zmm0, zmm9 ); \ zmm10 = _mm512_fmadd_pd(zmm2, zmm0, zmm10); \ zmm11 = _mm512_fmadd_pd(zmm3, zmm0, zmm11); \ @@ -883,8 +893,8 @@ err_t bli_trsm_small_mt_AVX512 zmm7 = _mm512_set1_pd(*(a01 + p_lda * 6)); \ zmm8 = _mm512_set1_pd(*(a01 + p_lda * 7)); \ \ - _mm_prefetch((b10 + 4*cs_b), _MM_HINT_T0); \ - _mm_prefetch((b11 + (itr-1)*cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b10 + 4*cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b11 + (itr-1)*cs_b), _MM_HINT_T0); \ zmm9 = _mm512_fmadd_pd(zmm1, zmm0, zmm9 ); \ zmm10 = _mm512_fmadd_pd(zmm2, zmm0, zmm10); \ zmm11 = _mm512_fmadd_pd(zmm3, zmm0, zmm11); \ @@ -920,7 +930,7 @@ err_t bli_trsm_small_mt_AVX512 ymm7 = _mm256_broadcast_sd((a01 + (p_lda * 6))); \ ymm8 = _mm256_broadcast_sd((a01 + (p_lda * 7))); \ \ - _mm_prefetch((b10 + 4 * cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b10 + 4 * cs_b), _MM_HINT_T0); \ ymm9 = _mm256_fmadd_pd(ymm1, ymm0, ymm9 ); \ ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); \ ymm11 = _mm256_fmadd_pd(ymm3, ymm0, ymm11); \ @@ -944,7 +954,7 @@ err_t bli_trsm_small_mt_AVX512 ymm21 = _mm256_broadcast_sd((a01_2 + (p_lda * 4))); \ ymm22 = _mm256_broadcast_sd((a01_2 + (p_lda * 5))); \ \ - _mm_prefetch((b10_2 + 4 * cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b10_2 + 4 * cs_b), _MM_HINT_T0); \ ymm24 = _mm256_fmadd_pd(ymm17, ymm23, ymm24); \ ymm17 = _mm256_broadcast_sd((a01_2 + (p_lda * 6))); \ ymm25 = _mm256_fmadd_pd(ymm18, ymm23, ymm25); \ @@ -960,21 +970,21 @@ err_t bli_trsm_small_mt_AVX512 b10_2 += cs_b; \ } \ /*combine the results of both loops*/ \ - _mm_prefetch((b11 + (0) * cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b11 + (0) * cs_b), _MM_HINT_T0); \ ymm9 = _mm256_add_pd(ymm9, ymm24); \ - _mm_prefetch((b11 + (1) * cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b11 + (1) * cs_b), _MM_HINT_T0); \ ymm10 = _mm256_add_pd(ymm10, ymm25); \ - _mm_prefetch((b11 + (2) * cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b11 + (2) * cs_b), _MM_HINT_T0); \ ymm11 = _mm256_add_pd(ymm11, ymm26); \ - _mm_prefetch((b11 + (3) * cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b11 + (3) * cs_b), _MM_HINT_T0); \ ymm12 = _mm256_add_pd(ymm12, ymm27); \ - _mm_prefetch((b11 + (4) * cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b11 + (4) * cs_b), _MM_HINT_T0); \ ymm13 = _mm256_add_pd(ymm13, ymm28); \ - _mm_prefetch((b11 + (5) * cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b11 + (5) * cs_b), _MM_HINT_T0); \ ymm14 = _mm256_add_pd(ymm14, ymm29); \ - _mm_prefetch((b11 + (6) * cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b11 + (6) * cs_b), _MM_HINT_T0); \ ymm15 = _mm256_add_pd(ymm15, ymm30); \ - _mm_prefetch((b11 + (7) * cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b11 + (7) * cs_b), _MM_HINT_T0); \ ymm16 = _mm256_add_pd(ymm16, ymm31); @@ -1002,7 +1012,7 @@ err_t bli_trsm_small_mt_AVX512 ymm7 = _mm256_broadcast_sd((a01 + (p_lda * 6))); \ ymm8 = _mm256_broadcast_sd((a01 + (p_lda * 7))); \ \ - _mm_prefetch((b10 + 4 * cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b10 + 4 * cs_b), _MM_HINT_T0); \ ymm9 = _mm256_fmadd_pd(ymm1, ymm0, ymm9 ); \ ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); \ ymm11 = _mm256_fmadd_pd(ymm3, ymm0, ymm11); \ @@ -1028,7 +1038,7 @@ err_t bli_trsm_small_mt_AVX512 ymm21 = _mm256_broadcast_sd((a01_2 + (p_lda * 4))); \ ymm22 = _mm256_broadcast_sd((a01_2 + (p_lda * 5))); \ \ - _mm_prefetch((b10_2 + 4 * cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b10_2 + 4 * cs_b), _MM_HINT_T0); \ ymm24 = _mm256_fmadd_pd(ymm17, ymm23, ymm24); \ ymm17 = _mm256_broadcast_sd((a01_2 + (p_lda * 6))); \ ymm25 = _mm256_fmadd_pd(ymm18, ymm23, ymm25); \ @@ -1044,21 +1054,21 @@ err_t bli_trsm_small_mt_AVX512 b10_2 += cs_b; \ } \ /*combine the results of both loops*/ \ - _mm_prefetch((b11 + (0) * cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b11 + (0) * cs_b), _MM_HINT_T0); \ ymm9 = _mm256_add_pd(ymm9, ymm24); \ - _mm_prefetch((b11 + (1) * cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b11 + (1) * cs_b), _MM_HINT_T0); \ ymm10 = _mm256_add_pd(ymm10, ymm25); \ - _mm_prefetch((b11 + (2) * cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b11 + (2) * cs_b), _MM_HINT_T0); \ ymm11 = _mm256_add_pd(ymm11, ymm26); \ - _mm_prefetch((b11 + (3) * cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b11 + (3) * cs_b), _MM_HINT_T0); \ ymm12 = _mm256_add_pd(ymm12, ymm27); \ - _mm_prefetch((b11 + (4) * cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b11 + (4) * cs_b), _MM_HINT_T0); \ ymm13 = _mm256_add_pd(ymm13, ymm28); \ - _mm_prefetch((b11 + (5) * cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b11 + (5) * cs_b), _MM_HINT_T0); \ ymm14 = _mm256_add_pd(ymm14, ymm29); \ - _mm_prefetch((b11 + (6) * cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b11 + (6) * cs_b), _MM_HINT_T0); \ ymm15 = _mm256_add_pd(ymm15, ymm30); \ - _mm_prefetch((b11 + (7) * cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b11 + (7) * cs_b), _MM_HINT_T0); \ ymm16 = _mm256_add_pd(ymm16, ymm31); #define BLIS_DTRSM_SMALL_GEMM_8nx2m_AVX512(a01, b10, cs_b, p_lda, k_iter, b11) \ @@ -1083,7 +1093,7 @@ err_t bli_trsm_small_mt_AVX512 ymm7 = _mm256_broadcast_sd((a01 + (p_lda * 6))); \ ymm8 = _mm256_broadcast_sd((a01 + (p_lda * 7))); \ \ - _mm_prefetch((b10 + 4 * cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b10 + 4 * cs_b), _MM_HINT_T0); \ ymm9 = _mm256_fmadd_pd(ymm1, ymm0, ymm9 ); \ ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); \ ymm11 = _mm256_fmadd_pd(ymm3, ymm0, ymm11); \ @@ -1108,7 +1118,7 @@ err_t bli_trsm_small_mt_AVX512 ymm21 = _mm256_broadcast_sd((a01_2 + (p_lda * 4))); \ ymm22 = _mm256_broadcast_sd((a01_2 + (p_lda * 5))); \ \ - _mm_prefetch((b10_2 + 4 * cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b10_2 + 4 * cs_b), _MM_HINT_T0); \ ymm24 = _mm256_fmadd_pd(ymm17, ymm23, ymm24); \ ymm17 = _mm256_broadcast_sd((a01_2 + (p_lda * 6))); \ ymm25 = _mm256_fmadd_pd(ymm18, ymm23, ymm25); \ @@ -1124,21 +1134,21 @@ err_t bli_trsm_small_mt_AVX512 b10_2 += cs_b; \ } \ /*combine the results of both loops*/ \ - _mm_prefetch((b11 + (0) * cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b11 + (0) * cs_b), _MM_HINT_T0); \ ymm9 = _mm256_add_pd(ymm9, ymm24); \ - _mm_prefetch((b11 + (1) * cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b11 + (1) * cs_b), _MM_HINT_T0); \ ymm10 = _mm256_add_pd(ymm10, ymm25); \ - _mm_prefetch((b11 + (2) * cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b11 + (2) * cs_b), _MM_HINT_T0); \ ymm11 = _mm256_add_pd(ymm11, ymm26); \ - _mm_prefetch((b11 + (3) * cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b11 + (3) * cs_b), _MM_HINT_T0); \ ymm12 = _mm256_add_pd(ymm12, ymm27); \ - _mm_prefetch((b11 + (4) * cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b11 + (4) * cs_b), _MM_HINT_T0); \ ymm13 = _mm256_add_pd(ymm13, ymm28); \ - _mm_prefetch((b11 + (5) * cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b11 + (5) * cs_b), _MM_HINT_T0); \ ymm14 = _mm256_add_pd(ymm14, ymm29); \ - _mm_prefetch((b11 + (6) * cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b11 + (6) * cs_b), _MM_HINT_T0); \ ymm15 = _mm256_add_pd(ymm15, ymm30); \ - _mm_prefetch((b11 + (7) * cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b11 + (7) * cs_b), _MM_HINT_T0); \ ymm16 = _mm256_add_pd(ymm16, ymm31); #define BLIS_DTRSM_SMALL_GEMM_8nx1m_AVX512(a01, b10, cs_b, p_lda, k_iter, b11) \ @@ -1162,7 +1172,7 @@ err_t bli_trsm_small_mt_AVX512 ymm7 = _mm256_broadcast_sd((a01 + (p_lda * 6))); \ ymm8 = _mm256_broadcast_sd((a01 + (p_lda * 7))); \ \ - _mm_prefetch((b10 + 4 * cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b10 + 4 * cs_b), _MM_HINT_T0); \ ymm9 = _mm256_fmadd_pd(ymm1, ymm0, ymm9 ); \ ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); \ ymm11 = _mm256_fmadd_pd(ymm3, ymm0, ymm11); \ @@ -1186,7 +1196,7 @@ err_t bli_trsm_small_mt_AVX512 ymm21 = _mm256_broadcast_sd((a01_2 + (p_lda * 4))); \ ymm22 = _mm256_broadcast_sd((a01_2 + (p_lda * 5))); \ \ - _mm_prefetch((b10_2 + 4 * cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b10_2 + 4 * cs_b), _MM_HINT_T0); \ ymm24 = _mm256_fmadd_pd(ymm17, ymm23, ymm24); \ ymm17 = _mm256_broadcast_sd((a01_2 + (p_lda * 6))); \ ymm25 = _mm256_fmadd_pd(ymm18, ymm23, ymm25); \ @@ -1202,21 +1212,21 @@ err_t bli_trsm_small_mt_AVX512 b10_2 += cs_b; \ } \ /*combine the results of both loops*/ \ - _mm_prefetch((b11 + (0) * cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b11 + (0) * cs_b), _MM_HINT_T0); \ ymm9 = _mm256_add_pd(ymm9, ymm24); \ - _mm_prefetch((b11 + (1) * cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b11 + (1) * cs_b), _MM_HINT_T0); \ ymm10 = _mm256_add_pd(ymm10, ymm25); \ - _mm_prefetch((b11 + (2) * cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b11 + (2) * cs_b), _MM_HINT_T0); \ ymm11 = _mm256_add_pd(ymm11, ymm26); \ - _mm_prefetch((b11 + (3) * cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b11 + (3) * cs_b), _MM_HINT_T0); \ ymm12 = _mm256_add_pd(ymm12, ymm27); \ - _mm_prefetch((b11 + (4) * cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b11 + (4) * cs_b), _MM_HINT_T0); \ ymm13 = _mm256_add_pd(ymm13, ymm28); \ - _mm_prefetch((b11 + (5) * cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b11 + (5) * cs_b), _MM_HINT_T0); \ ymm14 = _mm256_add_pd(ymm14, ymm29); \ - _mm_prefetch((b11 + (6) * cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b11 + (6) * cs_b), _MM_HINT_T0); \ ymm15 = _mm256_add_pd(ymm15, ymm30); \ - _mm_prefetch((b11 + (7) * cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b11 + (7) * cs_b), _MM_HINT_T0); \ ymm16 = _mm256_add_pd(ymm16, ymm31); @@ -1984,7 +1994,7 @@ err_t bli_trsm_small_mt_AVX512 // endregion - pre/post DTRSM macros for right variants // RUNN - RLTN -BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB_AVX512 +err_t bli_dtrsm_small_XAltB_XAuB_AVX512 ( obj_t* AlphaObj, obj_t* a, @@ -4314,7 +4324,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB_AVX512 // RLNN - RUTN -BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB_AVX512 +err_t bli_dtrsm_small_XAutB_XAlB_AVX512 ( obj_t* AlphaObj, obj_t* a, @@ -6545,7 +6555,10 @@ else if ( n_remainder == 2) ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_loadu_pd((double const *)b11); + ymm0 = _mm256_broadcast_sd((double const *)b11 + 2); + xmm5 = _mm_loadu_pd((double *)(b11)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); + ymm3 = _mm256_blend_pd(ymm6, ymm3, 0x07); BLIS_POST_DTRSM_SMALL_1N_3M(b11, cs_b) @@ -6853,7 +6866,7 @@ zmm7 = zmm16[0] zmm15[0] zmm14[0] zmm13[0] zmm12[0] zmm11[0] zmm10[0] zmm9 [0] zmm7 = _mm512_set1_pd(*(b01 + cs_b * 6)); \ zmm8 = _mm512_set1_pd(*(b01 + cs_b * 7)); \ \ - _mm_prefetch((b01 + 8), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b01 + 8), _MM_HINT_T0); \ zmm9 = _mm512_fmadd_pd(zmm1, zmm0, zmm9); \ zmm10 = _mm512_fmadd_pd(zmm2, zmm0, zmm10); \ zmm11 = _mm512_fmadd_pd(zmm3, zmm0, zmm11); \ @@ -6877,7 +6890,7 @@ zmm7 = zmm16[0] zmm15[0] zmm14[0] zmm13[0] zmm12[0] zmm11[0] zmm10[0] zmm9 [0] zmm21 = _mm512_set1_pd(*(b01_2 + cs_b * 4)); \ zmm22 = _mm512_set1_pd(*(b01_2 + cs_b * 5)); \ \ - _mm_prefetch((b01_2 + 8), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b01_2 + 8), _MM_HINT_T0); \ zmm24 = _mm512_fmadd_pd(zmm17, zmm23, zmm24); \ zmm17 = _mm512_set1_pd(*(b01_2 + cs_b * 6)); \ zmm25 = _mm512_fmadd_pd(zmm18, zmm23, zmm25); \ @@ -6892,21 +6905,21 @@ zmm7 = zmm16[0] zmm15[0] zmm14[0] zmm13[0] zmm12[0] zmm11[0] zmm10[0] zmm9 [0] b01_2 += 1; \ a10_2 += p_lda; \ } \ - _mm_prefetch((b11 + (0) * cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b11 + (0) * cs_b), _MM_HINT_T0); \ zmm9 = _mm512_add_pd(zmm9, zmm24); \ - _mm_prefetch((b11 + (1) * cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b11 + (1) * cs_b), _MM_HINT_T0); \ zmm10 = _mm512_add_pd(zmm10, zmm25); \ - _mm_prefetch((b11 + (2) * cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b11 + (2) * cs_b), _MM_HINT_T0); \ zmm11 = _mm512_add_pd(zmm11, zmm26); \ - _mm_prefetch((b11 + (3) * cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b11 + (3) * cs_b), _MM_HINT_T0); \ zmm12 = _mm512_add_pd(zmm12, zmm27); \ - _mm_prefetch((b11 + (4) * cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b11 + (4) * cs_b), _MM_HINT_T0); \ zmm13 = _mm512_add_pd(zmm13, zmm28); \ - _mm_prefetch((b11 + (5) * cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b11 + (5) * cs_b), _MM_HINT_T0); \ zmm14 = _mm512_add_pd(zmm14, zmm29); \ - _mm_prefetch((b11 + (6) * cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b11 + (6) * cs_b), _MM_HINT_T0); \ zmm15 = _mm512_add_pd(zmm15, zmm30); \ - _mm_prefetch((b11 + (7) * cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b11 + (7) * cs_b), _MM_HINT_T0); \ zmm16 = _mm512_add_pd(zmm16, zmm31); #define BLIS_DTRSM_SMALL_GEMM_8mx4n(a10, b01, cs_b, p_lda, k_iter) \ @@ -7002,7 +7015,7 @@ zmm7 = zmm16[0] zmm15[0] zmm14[0] zmm13[0] zmm12[0] zmm11[0] zmm10[0] zmm9 [0] ymm7 = _mm256_broadcast_sd((double const*)(b01 + (cs_b * 6))); \ ymm8 = _mm256_broadcast_sd((double const*)(b01 + (cs_b * 7))); \ \ - _mm_prefetch((b01 + 4 * cs_b), _MM_HINT_T0); \ + _mm_prefetch((char const*)(b01 + 4 * cs_b), _MM_HINT_T0); \ ymm9 = _mm256_fmadd_pd (ymm1, ymm0, ymm9); \ ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); \ ymm11 = _mm256_fmadd_pd(ymm3, ymm0, ymm11); \ @@ -7229,7 +7242,7 @@ zmm7 = zmm16[0] zmm15[0] zmm14[0] zmm13[0] zmm12[0] zmm11[0] zmm10[0] zmm9 [0] _mm_storel_pd((double *)(b11), _mm256_extractf128_pd(ymm8, 0)); // LLNN - LUTN -BLIS_INLINE err_t bli_dtrsm_small_AutXB_AlXB_AVX512 +err_t bli_dtrsm_small_AutXB_AlXB_AVX512 ( obj_t* AlphaObj, obj_t* a, @@ -9200,7 +9213,7 @@ BLIS_INLINE err_t bli_dtrsm_small_AutXB_AlXB_AVX512 // LUNN LUTN -BLIS_INLINE err_t bli_dtrsm_small_AltXB_AuXB_AVX512 +err_t bli_dtrsm_small_AltXB_AuXB_AVX512 ( obj_t* AlphaObj, obj_t* a, diff --git a/kernels/zen4/3/bli_zero_zmm.c b/kernels/zen4/3/bli_zero_zmm.c index 67ff9a62de..78b22e194c 100644 --- a/kernels/zen4/3/bli_zero_zmm.c +++ b/kernels/zen4/3/bli_zero_zmm.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -19,14 +19,14 @@ from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - AS IS AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY - OF TEXAS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, - EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, - PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR - PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY - OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/kernels/zen4/3/bli_zgemm_avx512_k1.c b/kernels/zen4/3/bli_zgemm_avx512_k1.c new file mode 100644 index 0000000000..a1f3fbd296 --- /dev/null +++ b/kernels/zen4/3/bli_zgemm_avx512_k1.c @@ -0,0 +1,1993 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "blis.h" +#include "immintrin.h" + +#define Z_MR 16 +#define Z_NR 4 + +/* + The following API implements the ZGEMM operation specifically for + inputs A and B with k == 1. It expects the inputs and output to + support the column-major storage scheme, without any requirement + to conjugate/transpose any of the operands. + + Design details : + Kernel dimensions - 16 x 4 + Loop ordering - N-loop, followed by M-loop + + The N-Loop will scale B by alpha and presave them on registers + for its reuse in M-Loop. Thus is blocks 2 * 4(broadcast) registers, + due to separate real and imaginary components + + Thus the register blocking for the hotspot code-section is as follows : + Loading A - 4 + Permuting A - 4 + alpha * B presave - 8 + Accumulating C - 16 + + Total - 32 + + Any other register used for miscellaneous computation will not induce + register dependency explicitly. +*/ + +err_t bli_zgemm_16x4_avx512_k1_nn + ( + dim_t m, + dim_t n, + dim_t k, + dcomplex* alpha, + dcomplex* a, const inc_t lda, + dcomplex* b, const inc_t ldb, + dcomplex* beta, + dcomplex* c, const inc_t ldc + ) +{ + // Setting the required variables to choose the right + // path for computation. + dim_t m_iter = ( m / Z_MR ); + dim_t n_iter = ( n / Z_NR ); + + dim_t m_remainder = ( m % Z_MR ); + dim_t n_remainder = ( n % Z_NR ); + + // Setting the alpha and beta scaling components(real and imaginary). + double alpha_real = alpha->real; + double alpha_imag = alpha->imag; + + double beta_real = beta->real; + double beta_imag = beta->imag; + + // Using the predefined enumerated constants to classify beta scaling + // into one of the below categories. + dim_t beta_mul_type = BLIS_MUL_DEFAULT; + + // Setting the appropriate type for beta scaling + // based on any of the special cases. + if( beta_imag == 0.0 ) + { + if( beta_real == 0.0 ) beta_mul_type = BLIS_MUL_ZERO; + else if( beta_real == 1.0 ) beta_mul_type = BLIS_MUL_ONE; + } + + // Implementing the GEMM operation, which is as follows : + // C := beta*C + alpha*A*B. + + // Local pointers for B and C, to be used along the n-loop + dcomplex* temp_b = b; + dcomplex* temp_c = c; + + // Main loop along N dimension + for( dim_t j = 0; j < n_iter; j++ ) + { + dcomplex* temp_ai = a; + dcomplex* temp_bj = temp_b; + dcomplex* temp_cij = temp_c; + + /* + Multiple blocks of Z_MR x 1(main loop for m) and/or m_remainder x 1 block(s) + of A use the same 1 x Z_NR block of B in order to compute the associated + Z_MR x Z_NR and/or m_remainder x Z_NR block(s) of C. Due to this, the + associated 1 x Z_NR block of B is scaled with alpha, and stored in registers + beforehand, to be reused in the main loop or fringe case of m. + */ + + // Intermediate registers used for alpha scaling the block of B and storing. + __m512d a_vec[4], bdcst_real[4], bdcst_imag[4], b_vec[4], temp[4]; + + // Broadcast elements from alpha, and exhibit the compute for complex scaling. + a_vec[0] = _mm512_set1_pd(alpha_real); + a_vec[1] = _mm512_set1_pd(alpha_imag); + + // Broadcasting real and imag components from B onto separate registers. + // They are then unpacked to get the interleaved storage format on registers. + // bdcst_real[0] = R0 R0 R0 R0 ... + bdcst_real[0] = _mm512_set1_pd(*((double *)(temp_bj))); + // bdcst_imag[0] = I0 I0 I0 I0 ... + bdcst_imag[0] = _mm512_set1_pd(*((double *)(temp_bj) + 1)); + // b_vec[0] = R0 I0 R0 I0 ... + b_vec[0] = _mm512_unpacklo_pd(bdcst_real[0], bdcst_imag[0]); + // temp[0] = I0 R0 I0 R0 ... + temp[0] = _mm512_unpacklo_pd(bdcst_imag[0], bdcst_real[0]); + + // bdcst_real[1] = R1 R1 R1 R1 ... + bdcst_real[1] = _mm512_set1_pd(*((double *)(temp_bj + ldb))); + // bdcst_imag[1] = I1 I1 I1 I1 ... + bdcst_imag[1] = _mm512_set1_pd(*((double *)(temp_bj + ldb) + 1)); + // b_vec[1] = R1 I1 R1 I1 ... + b_vec[1] = _mm512_unpacklo_pd(bdcst_real[1], bdcst_imag[1]); + // temp[1] = I1 R1 I1 R1 ... + temp[1] = _mm512_unpacklo_pd(bdcst_imag[1], bdcst_real[1]); + + // Scaling with imag component of alpha + temp[0] = _mm512_mul_pd(a_vec[1], temp[0]); + temp[1] = _mm512_mul_pd(a_vec[1], temp[1]); + // Scaling with real component of alpha and accumulating + b_vec[0] = _mm512_fmaddsub_pd(a_vec[0], b_vec[0], temp[0]); + b_vec[1] = _mm512_fmaddsub_pd(a_vec[0], b_vec[1], temp[1]); + + // Continuing the same set of instructions, to load B, unpack + // them, scale with alpha and store on registers + // bdcst_real[2] = R2 R2 R2 R2 ... + bdcst_real[2] = _mm512_set1_pd(*((double *)(temp_bj + 2 * ldb))); + // bdcst_imag[2] = I2 I2 I2 I2 ... + bdcst_imag[2] = _mm512_set1_pd(*((double *)(temp_bj + 2 * ldb) + 1)); + // b_vec[2] = R2 I2 R2 I2 ... + b_vec[2] = _mm512_unpacklo_pd(bdcst_real[2], bdcst_imag[2]); + // temp[2] = I2 R2 I2 R2 ... + temp[2] = _mm512_unpacklo_pd(bdcst_imag[2], bdcst_real[2]); + + // bdcst_real[3] = R3 R3 R3 R3 ... + bdcst_real[3] = _mm512_set1_pd(*((double *)(temp_bj + 3 * ldb))); + // bdcst_imag[3] = I3 I3 I3 I3 ... + bdcst_imag[3] = _mm512_set1_pd(*((double *)(temp_bj + 3 * ldb) + 1)); + // b_vec[3] = R3 I3 R3 I3 ... + b_vec[3] = _mm512_unpacklo_pd(bdcst_real[3], bdcst_imag[3]); + // temp[3] = I3 R3 I3 R3 ... + temp[3] = _mm512_unpacklo_pd(bdcst_imag[3], bdcst_real[3]); + + // Scaling with imag component of alpha + temp[2] = _mm512_mul_pd(a_vec[1], temp[2]); + temp[3] = _mm512_mul_pd(a_vec[1], temp[3]); + // Scaling with real component of alpha and accumulating + b_vec[2] = _mm512_fmaddsub_pd(a_vec[0], b_vec[2], temp[2]); + b_vec[3] = _mm512_fmaddsub_pd(a_vec[0], b_vec[3], temp[3]); + + // Registers b_vec[0 ... 3] contain alpha scaled B. These + // are unpacked in order to contain the real and imaginary + // components of each element in separate registers. + bdcst_real[0] = _mm512_unpacklo_pd(b_vec[0], b_vec[0]); + bdcst_real[1] = _mm512_unpacklo_pd(b_vec[1], b_vec[1]); + bdcst_real[2] = _mm512_unpacklo_pd(b_vec[2], b_vec[2]); + bdcst_real[3] = _mm512_unpacklo_pd(b_vec[3], b_vec[3]); + + bdcst_imag[0] = _mm512_unpackhi_pd(b_vec[0], b_vec[0]); + bdcst_imag[1] = _mm512_unpackhi_pd(b_vec[1], b_vec[1]); + bdcst_imag[2] = _mm512_unpackhi_pd(b_vec[2], b_vec[2]); + bdcst_imag[3] = _mm512_unpackhi_pd(b_vec[3], b_vec[3]); + + dim_t i = 0; + dim_t m_rem = m_remainder; + // Main loop along M dimension. + for( ; i < m_iter; i++ ) + { + __m512d a_perm[4], c_vec[16]; + __m512d betaRv, betaIv; + + // Clearing the scratch registers + c_vec[0] = _mm512_setzero_pd(); + c_vec[1] = _mm512_setzero_pd(); + c_vec[2] = _mm512_setzero_pd(); + c_vec[3] = _mm512_setzero_pd(); + c_vec[4] = _mm512_setzero_pd(); + c_vec[5] = _mm512_setzero_pd(); + c_vec[6] = _mm512_setzero_pd(); + c_vec[7] = _mm512_setzero_pd(); + c_vec[8] = _mm512_setzero_pd(); + c_vec[9] = _mm512_setzero_pd(); + c_vec[10] = _mm512_setzero_pd(); + c_vec[11] = _mm512_setzero_pd(); + c_vec[12] = _mm512_setzero_pd(); + c_vec[13] = _mm512_setzero_pd(); + c_vec[14] = _mm512_setzero_pd(); + c_vec[15] = _mm512_setzero_pd(); + + // Loading 16 elements from A + a_vec[0] = _mm512_loadu_pd((double const*)temp_ai); + a_vec[1] = _mm512_loadu_pd((double const*)(temp_ai + 4)); + a_vec[2] = _mm512_loadu_pd((double const*)(temp_ai + 8)); + a_vec[3] = _mm512_loadu_pd((double const*)(temp_ai + 12)); + + // Swapping real and imag components, to be used in computation + a_perm[0] = _mm512_permute_pd(a_vec[0], 0x55); + a_perm[1] = _mm512_permute_pd(a_vec[1], 0x55); + a_perm[2] = _mm512_permute_pd(a_vec[2], 0x55); + a_perm[3] = _mm512_permute_pd(a_vec[3], 0x55); + + // Scaling with imag components of alpha*B + c_vec[0] = _mm512_mul_pd(bdcst_imag[0], a_perm[0]); + c_vec[1] = _mm512_mul_pd(bdcst_imag[0], a_perm[1]); + c_vec[2] = _mm512_mul_pd(bdcst_imag[0], a_perm[2]); + c_vec[3] = _mm512_mul_pd(bdcst_imag[0], a_perm[3]); + c_vec[4] = _mm512_mul_pd(bdcst_imag[1], a_perm[0]); + c_vec[5] = _mm512_mul_pd(bdcst_imag[1], a_perm[1]); + c_vec[6] = _mm512_mul_pd(bdcst_imag[1], a_perm[2]); + c_vec[7] = _mm512_mul_pd(bdcst_imag[1], a_perm[3]); + + c_vec[8] = _mm512_mul_pd(bdcst_imag[2], a_perm[0]); + c_vec[9] = _mm512_mul_pd(bdcst_imag[2], a_perm[1]); + c_vec[10] = _mm512_mul_pd(bdcst_imag[2], a_perm[2]); + c_vec[11] = _mm512_mul_pd(bdcst_imag[2], a_perm[3]); + c_vec[12] = _mm512_mul_pd(bdcst_imag[3], a_perm[0]); + c_vec[13] = _mm512_mul_pd(bdcst_imag[3], a_perm[1]); + c_vec[14] = _mm512_mul_pd(bdcst_imag[3], a_perm[2]); + c_vec[15] = _mm512_mul_pd(bdcst_imag[3], a_perm[3]); + + // Scaling with real comp of alpha*B and accumulating + c_vec[0] = _mm512_fmaddsub_pd(bdcst_real[0], a_vec[0], c_vec[0]); + c_vec[1] = _mm512_fmaddsub_pd(bdcst_real[0], a_vec[1], c_vec[1]); + c_vec[2] = _mm512_fmaddsub_pd(bdcst_real[0], a_vec[2], c_vec[2]); + c_vec[3] = _mm512_fmaddsub_pd(bdcst_real[0], a_vec[3], c_vec[3]); + c_vec[4] = _mm512_fmaddsub_pd(bdcst_real[1], a_vec[0], c_vec[4]); + c_vec[5] = _mm512_fmaddsub_pd(bdcst_real[1], a_vec[1], c_vec[5]); + c_vec[6] = _mm512_fmaddsub_pd(bdcst_real[1], a_vec[2], c_vec[6]); + c_vec[7] = _mm512_fmaddsub_pd(bdcst_real[1], a_vec[3], c_vec[7]); + + c_vec[8] = _mm512_fmaddsub_pd(bdcst_real[2], a_vec[0], c_vec[8]); + c_vec[9] = _mm512_fmaddsub_pd(bdcst_real[2], a_vec[1], c_vec[9]); + c_vec[10] = _mm512_fmaddsub_pd(bdcst_real[2], a_vec[2], c_vec[10]); + c_vec[11] = _mm512_fmaddsub_pd(bdcst_real[2], a_vec[3], c_vec[11]); + c_vec[12] = _mm512_fmaddsub_pd(bdcst_real[3], a_vec[0], c_vec[12]); + c_vec[13] = _mm512_fmaddsub_pd(bdcst_real[3], a_vec[1], c_vec[13]); + c_vec[14] = _mm512_fmaddsub_pd(bdcst_real[3], a_vec[2], c_vec[14]); + c_vec[15] = _mm512_fmaddsub_pd(bdcst_real[3], a_vec[3], c_vec[15]); + + // Scaling with beta, according to its type. + switch( beta_mul_type ) + { + case BLIS_MUL_ZERO : + // Storing the result in C. + _mm512_storeu_pd((double *)(temp_cij), c_vec[0]); + _mm512_storeu_pd((double *)(temp_cij + 4), c_vec[1]); + _mm512_storeu_pd((double *)(temp_cij + 8), c_vec[2]); + _mm512_storeu_pd((double *)(temp_cij + 12), c_vec[3]); + + _mm512_storeu_pd((double *)(temp_cij + 1 * ldc), c_vec[4]); + _mm512_storeu_pd((double *)(temp_cij + 1 * ldc + 4), c_vec[5]); + _mm512_storeu_pd((double *)(temp_cij + 1 * ldc + 8), c_vec[6]); + _mm512_storeu_pd((double *)(temp_cij + 1 * ldc + 12), c_vec[7]); + + _mm512_storeu_pd((double *)(temp_cij + 2 * ldc), c_vec[8]); + _mm512_storeu_pd((double *)(temp_cij + 2 * ldc + 4), c_vec[9]); + _mm512_storeu_pd((double *)(temp_cij + 2 * ldc + 8), c_vec[10]); + _mm512_storeu_pd((double *)(temp_cij + 2 * ldc + 12), c_vec[11]); + + _mm512_storeu_pd((double *)(temp_cij + 3 * ldc), c_vec[12]); + _mm512_storeu_pd((double *)(temp_cij + 3 * ldc + 4), c_vec[13]); + _mm512_storeu_pd((double *)(temp_cij + 3 * ldc + 8), c_vec[14]); + _mm512_storeu_pd((double *)(temp_cij + 3 * ldc + 12), c_vec[15]); + break; + + case BLIS_MUL_ONE : + // Loading C from memory + a_vec[0] = _mm512_loadu_pd((double const*)(temp_cij)); + a_vec[1] = _mm512_loadu_pd((double const*)(temp_cij + 4)); + a_vec[2] = _mm512_loadu_pd((double const*)(temp_cij + 8)); + a_vec[3] = _mm512_loadu_pd((double const*)(temp_cij + 12)); + + // Adding to alpha*A*B + c_vec[0] = _mm512_add_pd(c_vec[0], a_vec[0]); + c_vec[1] = _mm512_add_pd(c_vec[1], a_vec[1]); + c_vec[2] = _mm512_add_pd(c_vec[2], a_vec[2]); + c_vec[3] = _mm512_add_pd(c_vec[3], a_vec[3]); + + // Storing the result to memory + _mm512_storeu_pd((double *)(temp_cij), c_vec[0]); + _mm512_storeu_pd((double *)(temp_cij + 4), c_vec[1]); + _mm512_storeu_pd((double *)(temp_cij + 8), c_vec[2]); + _mm512_storeu_pd((double *)(temp_cij + 12), c_vec[3]); + + // Loading C from memory + a_vec[0] = _mm512_loadu_pd((double const*)(temp_cij + 1 * ldc)); + a_vec[1] = _mm512_loadu_pd((double const*)(temp_cij + 1 * ldc + 4)); + a_vec[2] = _mm512_loadu_pd((double const*)(temp_cij + 1 * ldc + 8)); + a_vec[3] = _mm512_loadu_pd((double const*)(temp_cij + 1 * ldc + 12)); + + // Adding to alpha*A*B + c_vec[4] = _mm512_add_pd(c_vec[4], a_vec[0]); + c_vec[5] = _mm512_add_pd(c_vec[5], a_vec[1]); + c_vec[6] = _mm512_add_pd(c_vec[6], a_vec[2]); + c_vec[7] = _mm512_add_pd(c_vec[7], a_vec[3]); + + // Storing the result to memory + _mm512_storeu_pd((double *)(temp_cij + 1 * ldc), c_vec[4]); + _mm512_storeu_pd((double *)(temp_cij + 1 * ldc + 4), c_vec[5]); + _mm512_storeu_pd((double *)(temp_cij + 1 * ldc + 8), c_vec[6]); + _mm512_storeu_pd((double *)(temp_cij + 1 * ldc + 12), c_vec[7]); + + // Loading C from memory + a_vec[0] = _mm512_loadu_pd((double const*)(temp_cij + 2 * ldc)); + a_vec[1] = _mm512_loadu_pd((double const*)(temp_cij + 2 * ldc + 4)); + a_vec[2] = _mm512_loadu_pd((double const*)(temp_cij + 2 * ldc + 8)); + a_vec[3] = _mm512_loadu_pd((double const*)(temp_cij + 2 * ldc + 12)); + + // Adding to alpha*A*B + c_vec[8] = _mm512_add_pd(c_vec[8], a_vec[0]); + c_vec[9] = _mm512_add_pd(c_vec[9], a_vec[1]); + c_vec[10] = _mm512_add_pd(c_vec[10], a_vec[2]); + c_vec[11] = _mm512_add_pd(c_vec[11], a_vec[3]); + + // Storing the result to memory + _mm512_storeu_pd((double *)(temp_cij + 2 * ldc), c_vec[8]); + _mm512_storeu_pd((double *)(temp_cij + 2 * ldc + 4), c_vec[9]); + _mm512_storeu_pd((double *)(temp_cij + 2 * ldc + 8), c_vec[10]); + _mm512_storeu_pd((double *)(temp_cij + 2 * ldc + 12), c_vec[11]); + + // Loading C from memory + a_vec[0] = _mm512_loadu_pd((double const*)(temp_cij + 3 * ldc)); + a_vec[1] = _mm512_loadu_pd((double const*)(temp_cij + 3 * ldc + 4)); + a_vec[2] = _mm512_loadu_pd((double const*)(temp_cij + 3 * ldc + 8)); + a_vec[3] = _mm512_loadu_pd((double const*)(temp_cij + 3 * ldc + 12)); + + // Adding to alpha*A*B + c_vec[12] = _mm512_add_pd(c_vec[12], a_vec[0]); + c_vec[13] = _mm512_add_pd(c_vec[13], a_vec[1]); + c_vec[14] = _mm512_add_pd(c_vec[14], a_vec[2]); + c_vec[15] = _mm512_add_pd(c_vec[15], a_vec[3]); + + // Storing the result to memory + _mm512_storeu_pd((double *)(temp_cij + 3 * ldc), c_vec[12]); + _mm512_storeu_pd((double *)(temp_cij + 3 * ldc + 4), c_vec[13]); + _mm512_storeu_pd((double *)(temp_cij + 3 * ldc + 8), c_vec[14]); + _mm512_storeu_pd((double *)(temp_cij + 3 * ldc + 12), c_vec[15]); + break; + + default : + // Loading the real and imag parts of beta + betaRv = _mm512_set1_pd(beta_real); + betaIv = _mm512_set1_pd(beta_imag); + + // Load C from memory + a_vec[0] = _mm512_loadu_pd((double const*)(temp_cij)); + a_vec[1] = _mm512_loadu_pd((double const*)(temp_cij + 4)); + a_vec[2] = _mm512_loadu_pd((double const*)(temp_cij + 8)); + a_vec[3] = _mm512_loadu_pd((double const*)(temp_cij + 12)); + + // Swapping real and imag parts of C for computation + a_perm[0] = _mm512_permute_pd(a_vec[0], 0x55); + a_perm[1] = _mm512_permute_pd(a_vec[1], 0x55); + a_perm[2] = _mm512_permute_pd(a_vec[2], 0x55); + a_perm[3] = _mm512_permute_pd(a_vec[3], 0x55); + + // Scaling with imag component of beta + a_perm[0] = _mm512_mul_pd(betaIv, a_perm[0]); + a_perm[1] = _mm512_mul_pd(betaIv, a_perm[1]); + a_perm[2] = _mm512_mul_pd(betaIv, a_perm[2]); + a_perm[3] = _mm512_mul_pd(betaIv, a_perm[3]); + + // Scaling with real component of beta and accumulating + a_vec[0] = _mm512_fmaddsub_pd(betaRv, a_vec[0], a_perm[0]); + a_vec[1] = _mm512_fmaddsub_pd(betaRv, a_vec[1], a_perm[1]); + a_vec[2] = _mm512_fmaddsub_pd(betaRv, a_vec[2], a_perm[2]); + a_vec[3] = _mm512_fmaddsub_pd(betaRv, a_vec[3], a_perm[3]); + + c_vec[0] = _mm512_add_pd(a_vec[0], c_vec[0]); + c_vec[1] = _mm512_add_pd(a_vec[1], c_vec[1]); + c_vec[2] = _mm512_add_pd(a_vec[2], c_vec[2]); + c_vec[3] = _mm512_add_pd(a_vec[3], c_vec[3]); + + // Storing the result to memory + _mm512_storeu_pd((double *)(temp_cij), c_vec[0]); + _mm512_storeu_pd((double *)(temp_cij + 4), c_vec[1]); + _mm512_storeu_pd((double *)(temp_cij + 8), c_vec[2]); + _mm512_storeu_pd((double *)(temp_cij + 12), c_vec[3]); + + // Load C from memory + a_vec[0] = _mm512_loadu_pd((double const*)(temp_cij + 1 * ldc)); + a_vec[1] = _mm512_loadu_pd((double const*)(temp_cij + 1 * ldc + 4)); + a_vec[2] = _mm512_loadu_pd((double const*)(temp_cij + 1 * ldc + 8)); + a_vec[3] = _mm512_loadu_pd((double const*)(temp_cij + 1 * ldc + 12)); + + // Swapping real and imag parts of C for computation + a_perm[0] = _mm512_permute_pd(a_vec[0], 0x55); + a_perm[1] = _mm512_permute_pd(a_vec[1], 0x55); + a_perm[2] = _mm512_permute_pd(a_vec[2], 0x55); + a_perm[3] = _mm512_permute_pd(a_vec[3], 0x55); + + // Scaling with imag component of beta + a_perm[0] = _mm512_mul_pd(betaIv, a_perm[0]); + a_perm[1] = _mm512_mul_pd(betaIv, a_perm[1]); + a_perm[2] = _mm512_mul_pd(betaIv, a_perm[2]); + a_perm[3] = _mm512_mul_pd(betaIv, a_perm[3]); + + // Scaling with real component of beta and accumulating + a_vec[0] = _mm512_fmaddsub_pd(betaRv, a_vec[0], a_perm[0]); + a_vec[1] = _mm512_fmaddsub_pd(betaRv, a_vec[1], a_perm[1]); + a_vec[2] = _mm512_fmaddsub_pd(betaRv, a_vec[2], a_perm[2]); + a_vec[3] = _mm512_fmaddsub_pd(betaRv, a_vec[3], a_perm[3]); + + c_vec[4] = _mm512_add_pd(a_vec[0], c_vec[4]); + c_vec[5] = _mm512_add_pd(a_vec[1], c_vec[5]); + c_vec[6] = _mm512_add_pd(a_vec[2], c_vec[6]); + c_vec[7] = _mm512_add_pd(a_vec[3], c_vec[7]); + + // Storing the result to memory + _mm512_storeu_pd((double *)(temp_cij + 1 * ldc), c_vec[4]); + _mm512_storeu_pd((double *)(temp_cij + 1 * ldc + 4), c_vec[5]); + _mm512_storeu_pd((double *)(temp_cij + 1 * ldc + 8), c_vec[6]); + _mm512_storeu_pd((double *)(temp_cij + 1 * ldc + 12), c_vec[7]); + + // Load C from memory + a_vec[0] = _mm512_loadu_pd((double const*)(temp_cij + 2 * ldc)); + a_vec[1] = _mm512_loadu_pd((double const*)(temp_cij + 2 * ldc + 4)); + a_vec[2] = _mm512_loadu_pd((double const*)(temp_cij + 2 * ldc + 8)); + a_vec[3] = _mm512_loadu_pd((double const*)(temp_cij + 2 * ldc + 12)); + + // Swapping real and imag parts of C for computation + a_perm[0] = _mm512_permute_pd(a_vec[0], 0x55); + a_perm[1] = _mm512_permute_pd(a_vec[1], 0x55); + a_perm[2] = _mm512_permute_pd(a_vec[2], 0x55); + a_perm[3] = _mm512_permute_pd(a_vec[3], 0x55); + + // Scaling with imag component of beta + a_perm[0] = _mm512_mul_pd(betaIv, a_perm[0]); + a_perm[1] = _mm512_mul_pd(betaIv, a_perm[1]); + a_perm[2] = _mm512_mul_pd(betaIv, a_perm[2]); + a_perm[3] = _mm512_mul_pd(betaIv, a_perm[3]); + + // Scaling with real component of beta and accumulating + a_vec[0] = _mm512_fmaddsub_pd(betaRv, a_vec[0], a_perm[0]); + a_vec[1] = _mm512_fmaddsub_pd(betaRv, a_vec[1], a_perm[1]); + a_vec[2] = _mm512_fmaddsub_pd(betaRv, a_vec[2], a_perm[2]); + a_vec[3] = _mm512_fmaddsub_pd(betaRv, a_vec[3], a_perm[3]); + + c_vec[8] = _mm512_add_pd(a_vec[0], c_vec[8]); + c_vec[9] = _mm512_add_pd(a_vec[1], c_vec[9]); + c_vec[10] = _mm512_add_pd(a_vec[2], c_vec[10]); + c_vec[11] = _mm512_add_pd(a_vec[3], c_vec[11]); + + // Storing the result to memory + _mm512_storeu_pd((double *)(temp_cij + 2 * ldc), c_vec[8]); + _mm512_storeu_pd((double *)(temp_cij + 2 * ldc + 4), c_vec[9]); + _mm512_storeu_pd((double *)(temp_cij + 2 * ldc + 8), c_vec[10]); + _mm512_storeu_pd((double *)(temp_cij + 2 * ldc + 12), c_vec[11]); + + // Load C from memory + a_vec[0] = _mm512_loadu_pd((double const*)(temp_cij + 3 * ldc)); + a_vec[1] = _mm512_loadu_pd((double const*)(temp_cij + 3 * ldc + 4)); + a_vec[2] = _mm512_loadu_pd((double const*)(temp_cij + 3 * ldc + 8)); + a_vec[3] = _mm512_loadu_pd((double const*)(temp_cij + 3 * ldc + 12)); + + // Swapping real and imag parts of C for computation + a_perm[0] = _mm512_permute_pd(a_vec[0], 0x55); + a_perm[1] = _mm512_permute_pd(a_vec[1], 0x55); + a_perm[2] = _mm512_permute_pd(a_vec[2], 0x55); + a_perm[3] = _mm512_permute_pd(a_vec[3], 0x55); + + // Scaling with imag component of beta + a_perm[0] = _mm512_mul_pd(betaIv, a_perm[0]); + a_perm[1] = _mm512_mul_pd(betaIv, a_perm[1]); + a_perm[2] = _mm512_mul_pd(betaIv, a_perm[2]); + a_perm[3] = _mm512_mul_pd(betaIv, a_perm[3]); + + // Scaling with real component of beta and accumulating + a_vec[0] = _mm512_fmaddsub_pd(betaRv, a_vec[0], a_perm[0]); + a_vec[1] = _mm512_fmaddsub_pd(betaRv, a_vec[1], a_perm[1]); + a_vec[2] = _mm512_fmaddsub_pd(betaRv, a_vec[2], a_perm[2]); + a_vec[3] = _mm512_fmaddsub_pd(betaRv, a_vec[3], a_perm[3]); + + c_vec[12] = _mm512_add_pd(a_vec[0], c_vec[12]); + c_vec[13] = _mm512_add_pd(a_vec[1], c_vec[13]); + c_vec[14] = _mm512_add_pd(a_vec[2], c_vec[14]); + c_vec[15] = _mm512_add_pd(a_vec[3], c_vec[15]); + + // Storing the result to memory + _mm512_storeu_pd((double *)(temp_cij + 3 * ldc), c_vec[12]); + _mm512_storeu_pd((double *)(temp_cij + 3 * ldc + 4), c_vec[13]); + _mm512_storeu_pd((double *)(temp_cij + 3 * ldc + 8), c_vec[14]); + _mm512_storeu_pd((double *)(temp_cij + 3 * ldc + 12), c_vec[15]); + } + + // Adjusting the addresses of A and C for the next iteration. + temp_cij += 16; + temp_ai += 16; + } + + if( m_rem >= 8 ) + { + __m512d a_perm[2], c_vec[8]; + __m512d betaRv, betaIv; + + // Clearing the scratch registers + c_vec[0] = _mm512_setzero_pd(); + c_vec[1] = _mm512_setzero_pd(); + c_vec[2] = _mm512_setzero_pd(); + c_vec[3] = _mm512_setzero_pd(); + c_vec[4] = _mm512_setzero_pd(); + c_vec[5] = _mm512_setzero_pd(); + c_vec[6] = _mm512_setzero_pd(); + c_vec[7] = _mm512_setzero_pd(); + + // Loading 8 elements from A + a_vec[0] = _mm512_loadu_pd((double const*)temp_ai); + a_vec[1] = _mm512_loadu_pd((double const*)(temp_ai + 4)); + + // Swapping real and imag components, to be used in computation + a_perm[0] = _mm512_permute_pd(a_vec[0], 0x55); + a_perm[1] = _mm512_permute_pd(a_vec[1], 0x55); + + // Scaling with imag components of alpha*B + c_vec[0] = _mm512_mul_pd(bdcst_imag[0], a_perm[0]); + c_vec[1] = _mm512_mul_pd(bdcst_imag[0], a_perm[1]); + c_vec[2] = _mm512_mul_pd(bdcst_imag[1], a_perm[0]); + c_vec[3] = _mm512_mul_pd(bdcst_imag[1], a_perm[1]); + + c_vec[4] = _mm512_mul_pd(bdcst_imag[2], a_perm[0]); + c_vec[5] = _mm512_mul_pd(bdcst_imag[2], a_perm[1]); + c_vec[6] = _mm512_mul_pd(bdcst_imag[3], a_perm[0]); + c_vec[7] = _mm512_mul_pd(bdcst_imag[3], a_perm[1]); + + // Scaling with real comp of alpha*B and accumulating + c_vec[0] = _mm512_fmaddsub_pd(bdcst_real[0], a_vec[0], c_vec[0]); + c_vec[1] = _mm512_fmaddsub_pd(bdcst_real[0], a_vec[1], c_vec[1]); + c_vec[2] = _mm512_fmaddsub_pd(bdcst_real[1], a_vec[0], c_vec[2]); + c_vec[3] = _mm512_fmaddsub_pd(bdcst_real[1], a_vec[1], c_vec[3]); + + c_vec[4] = _mm512_fmaddsub_pd(bdcst_real[2], a_vec[0], c_vec[4]); + c_vec[5] = _mm512_fmaddsub_pd(bdcst_real[2], a_vec[1], c_vec[5]); + c_vec[6] = _mm512_fmaddsub_pd(bdcst_real[3], a_vec[0], c_vec[6]); + c_vec[7] = _mm512_fmaddsub_pd(bdcst_real[3], a_vec[1], c_vec[7]); + + // Scaling with beta, according to its type. + switch( beta_mul_type ) + { + case BLIS_MUL_ZERO : + // Storing the result in C. + _mm512_storeu_pd((double *)(temp_cij), c_vec[0]); + _mm512_storeu_pd((double *)(temp_cij + 4), c_vec[1]); + + _mm512_storeu_pd((double *)(temp_cij + 1 * ldc), c_vec[2]); + _mm512_storeu_pd((double *)(temp_cij + 1 * ldc + 4), c_vec[3]); + + _mm512_storeu_pd((double *)(temp_cij + 2 * ldc), c_vec[4]); + _mm512_storeu_pd((double *)(temp_cij + 2 * ldc + 4), c_vec[5]); + + _mm512_storeu_pd((double *)(temp_cij + 3 * ldc), c_vec[6]); + _mm512_storeu_pd((double *)(temp_cij + 3 * ldc + 4), c_vec[7]); + break; + + case BLIS_MUL_ONE : + // Loading C from memory + a_vec[0] = _mm512_loadu_pd((double const*)(temp_cij)); + a_vec[1] = _mm512_loadu_pd((double const*)(temp_cij + 4)); + + // Adding it to alpha*A*B + c_vec[0] = _mm512_add_pd(c_vec[0], a_vec[0]); + c_vec[1] = _mm512_add_pd(c_vec[1], a_vec[1]); + + // Storing the result to memory + _mm512_storeu_pd((double *)(temp_cij), c_vec[0]); + _mm512_storeu_pd((double *)(temp_cij + 4), c_vec[1]); + + // Loading C from memory + a_vec[0] = _mm512_loadu_pd((double const*)(temp_cij + 1 * ldc)); + a_vec[1] = _mm512_loadu_pd((double const*)(temp_cij + 1 * ldc + 4)); + + // Adding it to alpha*A*B + c_vec[2] = _mm512_add_pd(c_vec[2], a_vec[0]); + c_vec[3] = _mm512_add_pd(c_vec[3], a_vec[1]); + + // Storing the result to memory + _mm512_storeu_pd((double *)(temp_cij + 1 * ldc), c_vec[2]); + _mm512_storeu_pd((double *)(temp_cij + 1 * ldc + 4), c_vec[3]); + + // Loading C from memory + a_vec[0] = _mm512_loadu_pd((double const*)(temp_cij + 2 * ldc)); + a_vec[1] = _mm512_loadu_pd((double const*)(temp_cij + 2 * ldc + 4)); + + // Adding it to alpha*A*B + c_vec[4] = _mm512_add_pd(c_vec[4], a_vec[0]); + c_vec[5] = _mm512_add_pd(c_vec[5], a_vec[1]); + + // Storing the result to memory + _mm512_storeu_pd((double *)(temp_cij + 2 * ldc), c_vec[4]); + _mm512_storeu_pd((double *)(temp_cij + 2 * ldc + 4), c_vec[5]); + + // Loading C from memory + a_vec[0] = _mm512_loadu_pd((double const*)(temp_cij + 3 * ldc)); + a_vec[1] = _mm512_loadu_pd((double const*)(temp_cij + 3 * ldc + 4)); + + // Adding it to alpha*A*B + c_vec[6] = _mm512_add_pd(c_vec[6], a_vec[0]); + c_vec[7] = _mm512_add_pd(c_vec[7], a_vec[1]); + + // Storing the result to memory + _mm512_storeu_pd((double *)(temp_cij + 3 * ldc), c_vec[6]); + _mm512_storeu_pd((double *)(temp_cij + 3 * ldc + 4), c_vec[7]); + break; + + default : + // Loading real and imag components of beta + betaRv = _mm512_set1_pd(beta_real); + betaIv = _mm512_set1_pd(beta_imag); + + // Load C from memory + a_vec[0] = _mm512_loadu_pd((double const*)(temp_cij)); + a_vec[1] = _mm512_loadu_pd((double const*)(temp_cij + 4)); + + // Swapping real and imag parts of C for computation + a_perm[0] = _mm512_permute_pd(a_vec[0], 0x55); + a_perm[1] = _mm512_permute_pd(a_vec[1], 0x55); + + // Scaling with imag component of beta + a_perm[0] = _mm512_mul_pd(betaIv, a_perm[0]); + a_perm[1] = _mm512_mul_pd(betaIv, a_perm[1]); + + // Scaling with real component of beta and accumulating + a_vec[0] = _mm512_fmaddsub_pd(betaRv, a_vec[0], a_perm[0]); + a_vec[1] = _mm512_fmaddsub_pd(betaRv, a_vec[1], a_perm[1]); + + c_vec[0] = _mm512_add_pd(a_vec[0], c_vec[0]); + c_vec[1] = _mm512_add_pd(a_vec[1], c_vec[1]); + + // Storing the result to memory + _mm512_storeu_pd((double *)(temp_cij), c_vec[0]); + _mm512_storeu_pd((double *)(temp_cij + 4), c_vec[1]); + + // Load C from memory + a_vec[0] = _mm512_loadu_pd((double const*)(temp_cij + 1 * ldc)); + a_vec[1] = _mm512_loadu_pd((double const*)(temp_cij + 1 * ldc + 4)); + + // Swapping real and imag parts of C for computation + a_perm[0] = _mm512_permute_pd(a_vec[0], 0x55); + a_perm[1] = _mm512_permute_pd(a_vec[1], 0x55); + + // Scaling with imag component of beta + a_perm[0] = _mm512_mul_pd(betaIv, a_perm[0]); + a_perm[1] = _mm512_mul_pd(betaIv, a_perm[1]); + + // Scaling with real component of beta and accumulating + a_vec[0] = _mm512_fmaddsub_pd(betaRv, a_vec[0], a_perm[0]); + a_vec[1] = _mm512_fmaddsub_pd(betaRv, a_vec[1], a_perm[1]); + + c_vec[2] = _mm512_add_pd(a_vec[0], c_vec[2]); + c_vec[3] = _mm512_add_pd(a_vec[1], c_vec[3]); + + // Storing the result to memory + _mm512_storeu_pd((double *)(temp_cij + 1 * ldc), c_vec[2]); + _mm512_storeu_pd((double *)(temp_cij + 1 * ldc + 4), c_vec[3]); + + // Load C from memory + a_vec[0] = _mm512_loadu_pd((double const*)(temp_cij + 2 * ldc)); + a_vec[1] = _mm512_loadu_pd((double const*)(temp_cij + 2 * ldc + 4)); + + // Swapping real and imag parts of C for computation + a_perm[0] = _mm512_permute_pd(a_vec[0], 0x55); + a_perm[1] = _mm512_permute_pd(a_vec[1], 0x55); + + // Scaling with imag component of beta + a_perm[0] = _mm512_mul_pd(betaIv, a_perm[0]); + a_perm[1] = _mm512_mul_pd(betaIv, a_perm[1]); + + // Scaling with real component of beta and accumulating + a_vec[0] = _mm512_fmaddsub_pd(betaRv, a_vec[0], a_perm[0]); + a_vec[1] = _mm512_fmaddsub_pd(betaRv, a_vec[1], a_perm[1]); + + c_vec[4] = _mm512_add_pd(a_vec[0], c_vec[4]); + c_vec[5] = _mm512_add_pd(a_vec[1], c_vec[5]); + + // Storing the result to memory + _mm512_storeu_pd((double *)(temp_cij + 2 * ldc), c_vec[4]); + _mm512_storeu_pd((double *)(temp_cij + 2 * ldc + 4), c_vec[5]); + + // Load C from memory + a_vec[0] = _mm512_loadu_pd((double const*)(temp_cij + 3 * ldc)); + a_vec[1] = _mm512_loadu_pd((double const*)(temp_cij + 3 * ldc + 4)); + + // Swapping real and imag parts of C for computation + a_perm[0] = _mm512_permute_pd(a_vec[0], 0x55); + a_perm[1] = _mm512_permute_pd(a_vec[1], 0x55); + + // Scaling with imag component of beta + a_perm[0] = _mm512_mul_pd(betaIv, a_perm[0]); + a_perm[1] = _mm512_mul_pd(betaIv, a_perm[1]); + + // Scaling with real component of beta and accumulating + a_vec[0] = _mm512_fmaddsub_pd(betaRv, a_vec[0], a_perm[0]); + a_vec[1] = _mm512_fmaddsub_pd(betaRv, a_vec[1], a_perm[1]); + + c_vec[6] = _mm512_add_pd(a_vec[0], c_vec[6]); + c_vec[7] = _mm512_add_pd(a_vec[1], c_vec[7]); + + // Storing the result to memory + _mm512_storeu_pd((double *)(temp_cij + 3 * ldc), c_vec[6]); + _mm512_storeu_pd((double *)(temp_cij + 3 * ldc + 4), c_vec[7]); + } + + // Adjusting the addresses of A and C for the next iteration. + temp_cij += 8; + temp_ai += 8; + + m_rem -= 8; + } + + if( m_rem >= 4 ) + { + __m512d a_perm, c_vec[4]; + __m512d betaRv, betaIv; + + // Clearing scratch registers for accumalation + c_vec[0] = _mm512_setzero_pd(); + c_vec[1] = _mm512_setzero_pd(); + c_vec[2] = _mm512_setzero_pd(); + c_vec[3] = _mm512_setzero_pd(); + + // Loading 4 elements from A + a_vec[0] = _mm512_loadu_pd((double const*)temp_ai); + + // Swapping real and imag components, to be used in computation + a_perm = _mm512_permute_pd(a_vec[0], 0x55); + + // Scaling with imag components of alpha*B + c_vec[0] = _mm512_mul_pd(bdcst_imag[0], a_perm); + c_vec[1] = _mm512_mul_pd(bdcst_imag[1], a_perm); + + c_vec[2] = _mm512_mul_pd(bdcst_imag[2], a_perm); + c_vec[3] = _mm512_mul_pd(bdcst_imag[3], a_perm); + + // Scaling with real comp of alpha*B and accumulating + c_vec[0] = _mm512_fmaddsub_pd(bdcst_real[0], a_vec[0], c_vec[0]); + c_vec[1] = _mm512_fmaddsub_pd(bdcst_real[1], a_vec[0], c_vec[1]); + + c_vec[2] = _mm512_fmaddsub_pd(bdcst_real[2], a_vec[0], c_vec[2]); + c_vec[3] = _mm512_fmaddsub_pd(bdcst_real[3], a_vec[0], c_vec[3]); + + // Scaling with beta, according to its type. + switch( beta_mul_type ) + { + case BLIS_MUL_ZERO : + // Storing the result in C. + _mm512_storeu_pd((double *)(temp_cij), c_vec[0]); + _mm512_storeu_pd((double *)(temp_cij + 1 * ldc), c_vec[1]); + _mm512_storeu_pd((double *)(temp_cij + 2 * ldc), c_vec[2]); + _mm512_storeu_pd((double *)(temp_cij + 3 * ldc), c_vec[3]); + break; + + case BLIS_MUL_ONE : + // Loading C from memory + a_vec[0] = _mm512_loadu_pd((double const*)(temp_cij)); + + // Adding to alpha*A*B + c_vec[0] = _mm512_add_pd(c_vec[0], a_vec[0]); + + // Storing the result onto memory + _mm512_storeu_pd((double *)(temp_cij), c_vec[0]); + + // Loading C from memory + a_vec[0] = _mm512_loadu_pd((double const*)(temp_cij + 1 * ldc)); + + // Adding it to alpha*A*B + c_vec[1] = _mm512_add_pd(c_vec[1], a_vec[0]); + + // Storing the result to memory + _mm512_storeu_pd((double *)(temp_cij + 1 * ldc), c_vec[1]); + + // Loading C from memory + a_vec[0] = _mm512_loadu_pd((double const*)(temp_cij + 2 * ldc)); + + // Adding it to alpha*A*B + c_vec[2] = _mm512_add_pd(c_vec[2], a_vec[0]); + + // Storing the result to memory + _mm512_storeu_pd((double *)(temp_cij + 2 * ldc), c_vec[2]); + + // Loading C from memory + a_vec[0] = _mm512_loadu_pd((double const*)(temp_cij + 3 * ldc)); + + // Adding it to alpha*A*B + c_vec[3] = _mm512_add_pd(c_vec[3], a_vec[0]); + + // Storing the result to memory + _mm512_storeu_pd((double *)(temp_cij + 3 * ldc), c_vec[3]); + break; + + default : + + betaRv = _mm512_set1_pd(beta_real); + betaIv = _mm512_set1_pd(beta_imag); + + // Load C from memory + a_vec[0] = _mm512_loadu_pd((double const*)(temp_cij)); + + // Swapping real and imag parts of C for computation + a_perm = _mm512_permute_pd(a_vec[0], 0x55); + + // Scaling with imag component of beta + a_perm = _mm512_mul_pd(betaIv, a_perm); + + // Scaling with real component of beta and accumulating + a_vec[0] = _mm512_fmaddsub_pd(betaRv, a_vec[0], a_perm); + + c_vec[0] = _mm512_add_pd(a_vec[0], c_vec[0]); + + // Storing the result to memory + _mm512_storeu_pd((double *)(temp_cij), c_vec[0]); + + // Load C from memory + a_vec[0] = _mm512_loadu_pd((double const*)(temp_cij + 1 * ldc)); + + // Swapping real and imag parts of C for computation + a_perm = _mm512_permute_pd(a_vec[0], 0x55); + + // Scaling with imag component of beta + a_perm = _mm512_mul_pd(betaIv, a_perm); + + // Scaling with real component of beta and accumulating + a_vec[0] = _mm512_fmaddsub_pd(betaRv, a_vec[0], a_perm); + + c_vec[1] = _mm512_add_pd(a_vec[0], c_vec[1]); + + // Storing the result to memory + _mm512_storeu_pd((double *)(temp_cij + 1 * ldc), c_vec[1]); + + // Load C from memory + a_vec[0] = _mm512_loadu_pd((double const*)(temp_cij + 2 * ldc)); + + // Swapping real and imag parts of C for computation + a_perm = _mm512_permute_pd(a_vec[0], 0x55); + + // Scaling with imag component of beta + a_perm = _mm512_mul_pd(betaIv, a_perm); + + // Scaling with real component of beta and accumulating + a_vec[0] = _mm512_fmaddsub_pd(betaRv, a_vec[0], a_perm); + + c_vec[2] = _mm512_add_pd(a_vec[0], c_vec[2]); + + // Storing the result to memory + _mm512_storeu_pd((double *)(temp_cij + 2 * ldc), c_vec[2]); + + // Load C from memory + a_vec[0] = _mm512_loadu_pd((double const*)(temp_cij + 3 * ldc)); + + // Swapping real and imag parts of C for computation + a_perm = _mm512_permute_pd(a_vec[0], 0x55); + + // Scaling with imag component of beta + a_perm = _mm512_mul_pd(betaIv, a_perm); + + // Scaling with real component of beta and accumulating + a_vec[0] = _mm512_fmaddsub_pd(betaRv, a_vec[0], a_perm); + + c_vec[3] = _mm512_add_pd(a_vec[0], c_vec[3]); + + // Storing the result to memory + _mm512_storeu_pd((double *)(temp_cij + 3 * ldc), c_vec[3]); + } + + // Adjusting the addresses of A and C for the next iteration. + temp_cij += 4; + temp_ai += 4; + + m_rem -= 4; + } + + if( m_rem > 0 ) + { + // Setting the mask to load/store the reamining elements + // Ex : m_rem = 2 => m_mask = ( 1 << 2 * 2 ) - 1 + // = 0b0010000 - 1 + // = 0b00001111 + // m_rem is multiplied by 2 since it accounts for 2 doubles + __mmask8 m_mask = m_mask = (1 << 2 * m_rem) - 1; + __m512d a_perm, c_vec[4]; + __m512d betaRv, betaIv; + + // Clearing the scratch registers for accumalation + c_vec[0] = _mm512_setzero_pd(); + c_vec[1] = _mm512_setzero_pd(); + c_vec[2] = _mm512_setzero_pd(); + c_vec[3] = _mm512_setzero_pd(); + + // Loading the remaining elements from A + a_vec[0] = _mm512_maskz_loadu_pd(m_mask, (double const*)temp_ai); + + // Swapping real and imag components, to be used in computation + a_perm = _mm512_permute_pd(a_vec[0], 0x55); + + // Scaling with imag components of alpha*B + c_vec[0] = _mm512_mul_pd(bdcst_imag[0], a_perm); + c_vec[1] = _mm512_mul_pd(bdcst_imag[1], a_perm); + + c_vec[2] = _mm512_mul_pd(bdcst_imag[2], a_perm); + c_vec[3] = _mm512_mul_pd(bdcst_imag[3], a_perm); + + // Scaling with real comp of alpha*B and accumulating + c_vec[0] = _mm512_fmaddsub_pd(bdcst_real[0], a_vec[0], c_vec[0]); + c_vec[1] = _mm512_fmaddsub_pd(bdcst_real[1], a_vec[0], c_vec[1]); + + c_vec[2] = _mm512_fmaddsub_pd(bdcst_real[2], a_vec[0], c_vec[2]); + c_vec[3] = _mm512_fmaddsub_pd(bdcst_real[3], a_vec[0], c_vec[3]); + + // Scaling with beta, according to its type. + switch( beta_mul_type ) + { + case BLIS_MUL_ZERO : + // Storing the result in C. + _mm512_mask_storeu_pd((double *)(temp_cij), m_mask, c_vec[0]); + _mm512_mask_storeu_pd((double *)(temp_cij + 1 * ldc), m_mask, c_vec[1]); + _mm512_mask_storeu_pd((double *)(temp_cij + 2 * ldc), m_mask, c_vec[2]); + _mm512_mask_storeu_pd((double *)(temp_cij + 3 * ldc), m_mask, c_vec[3]); + break; + + case BLIS_MUL_ONE : + // Loading C from memory + a_vec[0] = _mm512_maskz_loadu_pd(m_mask, (double const*)(temp_cij)); + + // Adding it to alpha*A*B + c_vec[0] = _mm512_add_pd(c_vec[0], a_vec[0]); + + // Storing the result to memory + _mm512_mask_storeu_pd((double *)(temp_cij), m_mask, c_vec[0]); + + // Loading C from memory + a_vec[0] = _mm512_maskz_loadu_pd(m_mask, (double const*)(temp_cij + 1 * ldc)); + + // Adding it to alpha*A*B + c_vec[1] = _mm512_add_pd(c_vec[1], a_vec[0]); + + // Storing the result to memory + _mm512_mask_storeu_pd((double *)(temp_cij + 1 * ldc), m_mask, c_vec[1]); + + // Loading C from memory + a_vec[0] = _mm512_maskz_loadu_pd(m_mask, (double const*)(temp_cij + 2 * ldc)); + + // Adding it to alpha*A*B + c_vec[2] = _mm512_add_pd(c_vec[2], a_vec[0]); + + // Storing the result to memory + _mm512_mask_storeu_pd((double *)(temp_cij + 2 * ldc), m_mask, c_vec[2]); + + // Loading C from memory + a_vec[0] = _mm512_maskz_loadu_pd(m_mask, (double const*)(temp_cij + 3 * ldc)); + + // Adding it to alpha*A*B + c_vec[3] = _mm512_add_pd(c_vec[3], a_vec[0]); + + // Storing the result to memory + _mm512_mask_storeu_pd((double *)(temp_cij + 3 * ldc), m_mask, c_vec[3]); + break; + + default : + + betaRv = _mm512_set1_pd(beta_real); + betaIv = _mm512_set1_pd(beta_imag); + + // Load C from memory + a_vec[0] = _mm512_maskz_loadu_pd(m_mask, (double const*)(temp_cij)); + + // Swapping real and imag parts of C for computation + a_perm = _mm512_permute_pd(a_vec[0], 0x55); + + // Scaling with imag component of beta + a_perm = _mm512_mul_pd(betaIv, a_perm); + + // Scaling with real component of beta and accumulating + a_vec[0] = _mm512_fmaddsub_pd(betaRv, a_vec[0], a_perm); + + c_vec[0] = _mm512_add_pd(a_vec[0], c_vec[0]); + + // Storing the result to memory + _mm512_mask_storeu_pd((double *)(temp_cij), m_mask, c_vec[0]); + + // Load C from memory + a_vec[0] = _mm512_maskz_loadu_pd(m_mask, (double const*)(temp_cij + 1 * ldc)); + + // Swapping real and imag parts of C for computation + a_perm = _mm512_permute_pd(a_vec[0], 0x55); + + // Scaling with imag component of beta + a_perm = _mm512_mul_pd(betaIv, a_perm); + + // Scaling with real component of beta and accumulating + a_vec[0] = _mm512_fmaddsub_pd(betaRv, a_vec[0], a_perm); + + c_vec[1] = _mm512_add_pd(a_vec[0], c_vec[1]); + + // Storing the result to memory + _mm512_mask_storeu_pd((double *)(temp_cij + 1 * ldc), m_mask, c_vec[1]); + + // Load C from memory + a_vec[0] = _mm512_maskz_loadu_pd(m_mask, (double const*)(temp_cij + 2 * ldc)); + + // Swapping real and imag parts of C for computation + a_perm = _mm512_permute_pd(a_vec[0], 0x55); + + // Scaling with imag component of beta + a_perm = _mm512_mul_pd(betaIv, a_perm); + + // Scaling with real component of beta and accumulating + a_vec[0] = _mm512_fmaddsub_pd(betaRv, a_vec[0], a_perm); + + c_vec[2] = _mm512_add_pd(a_vec[0], c_vec[2]); + + // Storing the result to memory + _mm512_mask_storeu_pd((double *)(temp_cij + 2 * ldc), m_mask, c_vec[2]); + + // Load C from memory + a_vec[0] = _mm512_maskz_loadu_pd(m_mask, (double const*)(temp_cij + 3 * ldc)); + + // Swapping real and imag parts of C for computation + a_perm = _mm512_permute_pd(a_vec[0], 0x55); + + // Scaling with imag component of beta + a_perm = _mm512_mul_pd(betaIv, a_perm); + + // Scaling with real component of beta and accumulating + a_vec[0] = _mm512_fmaddsub_pd(betaRv, a_vec[0], a_perm); + + c_vec[3] = _mm512_add_pd(a_vec[0], c_vec[3]); + + // Storing the result to memory + _mm512_mask_storeu_pd((double *)(temp_cij + 3 * ldc), m_mask, c_vec[3]); + } + } + + // Adjusting the pointers for the next iteration + temp_b += ldb * Z_NR; + temp_c += ldc * Z_NR; + } + + // Fringe case for N + if( n_remainder >= 2 ) + { + dcomplex* temp_ai = a; + dcomplex* temp_bj = temp_b; + dcomplex* temp_cij = temp_c; + + /* Multiple blocks of Z_MR x 1(main loop for m) and/or m_remainder x 1 block(s) + of A use the same 1 x 2 block of B in order to compute the associated + Z_MR x 2 and/or m_remainder x 2 block(s) of C. This reusability has been + exploited, wherein the associated 1 x 2 block of B is scaled with alpha, + and stored in registers beforehand, to be reused in the main loop or fringe + case of m. */ + + // Intermediate registers used for alpha scaling the block of B and storing. + __m512d a_vec[4], bdcst_real[2], bdcst_imag[2], b_vec[2], temp[2]; + + // Broadcast elements from alpha, and exhibit the compute for complex scaling. + a_vec[0] = _mm512_set1_pd(alpha_real); + a_vec[1] = _mm512_set1_pd(alpha_imag); + + // Broadcasting real and imag components from B onto separate registers. + // They are then unpacked to get the interleaved storage format on registers. + // bdcst_real[0] = R0 R0 R0 R0 ... + bdcst_real[0] = _mm512_set1_pd(*((double *)(temp_bj))); + // bdcst_imag[0] = I0 I0 I0 I0 ... + bdcst_imag[0] = _mm512_set1_pd(*((double *)(temp_bj) + 1)); + // b_vec[0] = R0 I0 R0 I0 ... + b_vec[0] = _mm512_unpacklo_pd(bdcst_real[0], bdcst_imag[0]); + // temp[0] = I0 R0 I0 R0 ... + temp[0] = _mm512_unpacklo_pd(bdcst_imag[0], bdcst_real[0]); + + // bdcst_real[1] = R1 R1 R1 R1 ... + bdcst_real[1] = _mm512_set1_pd(*((double *)(temp_bj + ldb))); + // bdcst_imag[1] = I1 I1 I1 I1 ... + bdcst_imag[1] = _mm512_set1_pd(*((double *)(temp_bj + ldb) + 1)); + // b_vec[1] = R1 I1 R1 I1 ... + b_vec[1] = _mm512_unpacklo_pd(bdcst_real[1], bdcst_imag[1]); + // temp[1] = I1 R1 I1 R1 ... + temp[1] = _mm512_unpacklo_pd(bdcst_imag[1], bdcst_real[1]); + + // Scaling with imag component of alpha + temp[0] = _mm512_mul_pd(a_vec[1], temp[0]); + temp[1] = _mm512_mul_pd(a_vec[1], temp[1]); + // Scaling with real component of alpha and accumulating + b_vec[0] = _mm512_fmaddsub_pd(a_vec[0], b_vec[0], temp[0]); + b_vec[1] = _mm512_fmaddsub_pd(a_vec[0], b_vec[1], temp[1]); + + // Registers b_vec[0 ... 1] contain alpha scaled B. These + // are unpacked in order to contain the real and imaginary + // components of each element in separate registers. + bdcst_real[0] = _mm512_unpacklo_pd(b_vec[0], b_vec[0]); + bdcst_real[1] = _mm512_unpacklo_pd(b_vec[1], b_vec[1]); + + bdcst_imag[0] = _mm512_unpackhi_pd(b_vec[0], b_vec[0]); + bdcst_imag[1] = _mm512_unpackhi_pd(b_vec[1], b_vec[1]); + + dim_t i = 0; + dim_t m_rem = m_remainder; + // Main loop along M dimension. + for( ; i < m_iter; i++ ) + { + __m512d a_perm[4], c_vec[8]; + __m512d betaRv, betaIv; + + // Clearing the scratch registers for accumalation + c_vec[0] = _mm512_setzero_pd(); + c_vec[1] = _mm512_setzero_pd(); + c_vec[2] = _mm512_setzero_pd(); + c_vec[3] = _mm512_setzero_pd(); + c_vec[4] = _mm512_setzero_pd(); + c_vec[5] = _mm512_setzero_pd(); + c_vec[6] = _mm512_setzero_pd(); + c_vec[7] = _mm512_setzero_pd(); + + // Loading 16 elements from A + a_vec[0] = _mm512_loadu_pd((double const*)temp_ai); + a_vec[1] = _mm512_loadu_pd((double const*)(temp_ai + 4)); + a_vec[2] = _mm512_loadu_pd((double const*)(temp_ai + 8)); + a_vec[3] = _mm512_loadu_pd((double const*)(temp_ai + 12)); + + // Swapping real and imag components, to be used in computation + a_perm[0] = _mm512_permute_pd(a_vec[0], 0x55); + a_perm[1] = _mm512_permute_pd(a_vec[1], 0x55); + a_perm[2] = _mm512_permute_pd(a_vec[2], 0x55); + a_perm[3] = _mm512_permute_pd(a_vec[3], 0x55); + + // Scaling with imag components of alpha*B + c_vec[0] = _mm512_mul_pd(bdcst_imag[0], a_perm[0]); + c_vec[1] = _mm512_mul_pd(bdcst_imag[0], a_perm[1]); + c_vec[2] = _mm512_mul_pd(bdcst_imag[0], a_perm[2]); + c_vec[3] = _mm512_mul_pd(bdcst_imag[0], a_perm[3]); + c_vec[4] = _mm512_mul_pd(bdcst_imag[1], a_perm[0]); + c_vec[5] = _mm512_mul_pd(bdcst_imag[1], a_perm[1]); + c_vec[6] = _mm512_mul_pd(bdcst_imag[1], a_perm[2]); + c_vec[7] = _mm512_mul_pd(bdcst_imag[1], a_perm[3]); + + // Scaling with real comp of alpha*B and accumulating + c_vec[0] = _mm512_fmaddsub_pd(bdcst_real[0], a_vec[0], c_vec[0]); + c_vec[1] = _mm512_fmaddsub_pd(bdcst_real[0], a_vec[1], c_vec[1]); + c_vec[2] = _mm512_fmaddsub_pd(bdcst_real[0], a_vec[2], c_vec[2]); + c_vec[3] = _mm512_fmaddsub_pd(bdcst_real[0], a_vec[3], c_vec[3]); + c_vec[4] = _mm512_fmaddsub_pd(bdcst_real[1], a_vec[0], c_vec[4]); + c_vec[5] = _mm512_fmaddsub_pd(bdcst_real[1], a_vec[1], c_vec[5]); + c_vec[6] = _mm512_fmaddsub_pd(bdcst_real[1], a_vec[2], c_vec[6]); + c_vec[7] = _mm512_fmaddsub_pd(bdcst_real[1], a_vec[3], c_vec[7]); + + // Scaling with beta, according to its type. + switch( beta_mul_type ) + { + case BLIS_MUL_ZERO : + // Storing the result in C. + _mm512_storeu_pd((double *)(temp_cij), c_vec[0]); + _mm512_storeu_pd((double *)(temp_cij + 4), c_vec[1]); + _mm512_storeu_pd((double *)(temp_cij + 8), c_vec[2]); + _mm512_storeu_pd((double *)(temp_cij + 12), c_vec[3]); + + _mm512_storeu_pd((double *)(temp_cij + 1 * ldc), c_vec[4]); + _mm512_storeu_pd((double *)(temp_cij + 1 * ldc + 4), c_vec[5]); + _mm512_storeu_pd((double *)(temp_cij + 1 * ldc + 8), c_vec[6]); + _mm512_storeu_pd((double *)(temp_cij + 1 * ldc + 12), c_vec[7]); + break; + + case BLIS_MUL_ONE : + // Loading C from memory + a_vec[0] = _mm512_loadu_pd((double const*)(temp_cij)); + a_vec[1] = _mm512_loadu_pd((double const*)(temp_cij + 4)); + a_vec[2] = _mm512_loadu_pd((double const*)(temp_cij + 8)); + a_vec[3] = _mm512_loadu_pd((double const*)(temp_cij + 12)); + + // Adding C to alpha*A*B + c_vec[0] = _mm512_add_pd(c_vec[0], a_vec[0]); + c_vec[1] = _mm512_add_pd(c_vec[1], a_vec[1]); + c_vec[2] = _mm512_add_pd(c_vec[2], a_vec[2]); + c_vec[3] = _mm512_add_pd(c_vec[3], a_vec[3]); + + // Storing the result to memory + _mm512_storeu_pd((double *)(temp_cij), c_vec[0]); + _mm512_storeu_pd((double *)(temp_cij + 4), c_vec[1]); + _mm512_storeu_pd((double *)(temp_cij + 8), c_vec[2]); + _mm512_storeu_pd((double *)(temp_cij + 12), c_vec[3]); + + a_vec[0] = _mm512_loadu_pd((double const*)(temp_cij + 1 * ldc)); + a_vec[1] = _mm512_loadu_pd((double const*)(temp_cij + 1 * ldc + 4)); + a_vec[2] = _mm512_loadu_pd((double const*)(temp_cij + 1 * ldc + 8)); + a_vec[3] = _mm512_loadu_pd((double const*)(temp_cij + 1 * ldc + 12)); + + c_vec[4] = _mm512_add_pd(c_vec[4], a_vec[0]); + c_vec[5] = _mm512_add_pd(c_vec[5], a_vec[1]); + c_vec[6] = _mm512_add_pd(c_vec[6], a_vec[2]); + c_vec[7] = _mm512_add_pd(c_vec[7], a_vec[3]); + + _mm512_storeu_pd((double *)(temp_cij + 1 * ldc), c_vec[4]); + _mm512_storeu_pd((double *)(temp_cij + 1 * ldc + 4), c_vec[5]); + _mm512_storeu_pd((double *)(temp_cij + 1 * ldc + 8), c_vec[6]); + _mm512_storeu_pd((double *)(temp_cij + 1 * ldc + 12), c_vec[7]); + break; + + default : + + betaRv = _mm512_set1_pd(beta_real); + betaIv = _mm512_set1_pd(beta_imag); + + // Load C from memory + a_vec[0] = _mm512_loadu_pd((double const*)(temp_cij)); + a_vec[1] = _mm512_loadu_pd((double const*)(temp_cij + 4)); + a_vec[2] = _mm512_loadu_pd((double const*)(temp_cij + 8)); + a_vec[3] = _mm512_loadu_pd((double const*)(temp_cij + 12)); + + // Swapping real and imag parts of C for computation + a_perm[0] = _mm512_permute_pd(a_vec[0], 0x55); + a_perm[1] = _mm512_permute_pd(a_vec[1], 0x55); + a_perm[2] = _mm512_permute_pd(a_vec[2], 0x55); + a_perm[3] = _mm512_permute_pd(a_vec[3], 0x55); + + // Scaling with imag component of beta + a_perm[0] = _mm512_mul_pd(betaIv, a_perm[0]); + a_perm[1] = _mm512_mul_pd(betaIv, a_perm[1]); + a_perm[2] = _mm512_mul_pd(betaIv, a_perm[2]); + a_perm[3] = _mm512_mul_pd(betaIv, a_perm[3]); + + // Scaling with real component of beta and accumulating + a_vec[0] = _mm512_fmaddsub_pd(betaRv, a_vec[0], a_perm[0]); + a_vec[1] = _mm512_fmaddsub_pd(betaRv, a_vec[1], a_perm[1]); + a_vec[2] = _mm512_fmaddsub_pd(betaRv, a_vec[2], a_perm[2]); + a_vec[3] = _mm512_fmaddsub_pd(betaRv, a_vec[3], a_perm[3]); + + c_vec[0] = _mm512_add_pd(a_vec[0], c_vec[0]); + c_vec[1] = _mm512_add_pd(a_vec[1], c_vec[1]); + c_vec[2] = _mm512_add_pd(a_vec[2], c_vec[2]); + c_vec[3] = _mm512_add_pd(a_vec[3], c_vec[3]); + + // Storing the result to memory + _mm512_storeu_pd((double *)(temp_cij), c_vec[0]); + _mm512_storeu_pd((double *)(temp_cij + 4), c_vec[1]); + _mm512_storeu_pd((double *)(temp_cij + 8), c_vec[2]); + _mm512_storeu_pd((double *)(temp_cij + 12), c_vec[3]); + + // Registers to load beta(real and imag components) + a_vec[0] = _mm512_loadu_pd((double const*)(temp_cij + 1 * ldc)); + a_vec[1] = _mm512_loadu_pd((double const*)(temp_cij + 1 * ldc + 4)); + a_vec[2] = _mm512_loadu_pd((double const*)(temp_cij + 1 * ldc + 8)); + a_vec[3] = _mm512_loadu_pd((double const*)(temp_cij + 1 * ldc + 12)); + + // Load C from memory + a_perm[0] = _mm512_permute_pd(a_vec[0], 0x55); + a_perm[1] = _mm512_permute_pd(a_vec[1], 0x55); + a_perm[2] = _mm512_permute_pd(a_vec[2], 0x55); + a_perm[3] = _mm512_permute_pd(a_vec[3], 0x55); + + // Swapping real and imag parts of C for computation + a_perm[0] = _mm512_mul_pd(betaIv, a_perm[0]); + a_perm[1] = _mm512_mul_pd(betaIv, a_perm[1]); + a_perm[2] = _mm512_mul_pd(betaIv, a_perm[2]); + a_perm[3] = _mm512_mul_pd(betaIv, a_perm[3]); + + // Scaling with imag component of beta + a_vec[0] = _mm512_fmaddsub_pd(betaRv, a_vec[0], a_perm[0]); + a_vec[1] = _mm512_fmaddsub_pd(betaRv, a_vec[1], a_perm[1]); + a_vec[2] = _mm512_fmaddsub_pd(betaRv, a_vec[2], a_perm[2]); + a_vec[3] = _mm512_fmaddsub_pd(betaRv, a_vec[3], a_perm[3]); + + // Scaling with real component of beta and accumulating + c_vec[4] = _mm512_add_pd(a_vec[0], c_vec[4]); + c_vec[5] = _mm512_add_pd(a_vec[1], c_vec[5]); + c_vec[6] = _mm512_add_pd(a_vec[2], c_vec[6]); + c_vec[7] = _mm512_add_pd(a_vec[3], c_vec[7]); + + // Storing the result to memory + _mm512_storeu_pd((double *)(temp_cij + 1 * ldc), c_vec[4]); + _mm512_storeu_pd((double *)(temp_cij + 1 * ldc + 4), c_vec[5]); + _mm512_storeu_pd((double *)(temp_cij + 1 * ldc + 8), c_vec[6]); + _mm512_storeu_pd((double *)(temp_cij + 1 * ldc + 12), c_vec[7]); + } + + // Adjusting the addresses of A and C for the next iteration. + temp_cij += 16; + temp_ai += 16; + } + + if( m_rem >= 8 ) + { + __m512d a_perm[2], c_vec[4]; + __m512d betaRv, betaIv; + + // Clearing out the scratch registers + c_vec[0] = _mm512_setzero_pd(); + c_vec[1] = _mm512_setzero_pd(); + c_vec[2] = _mm512_setzero_pd(); + c_vec[3] = _mm512_setzero_pd(); + + // Loading 8 elements from A + a_vec[0] = _mm512_loadu_pd((double const*)temp_ai); + a_vec[1] = _mm512_loadu_pd((double const*)(temp_ai + 4)); + + // Swapping real and imag components, to be used in computation + a_perm[0] = _mm512_permute_pd(a_vec[0], 0x55); + a_perm[1] = _mm512_permute_pd(a_vec[1], 0x55); + + // Scaling with imag components of alpha*B + c_vec[0] = _mm512_mul_pd(bdcst_imag[0], a_perm[0]); + c_vec[1] = _mm512_mul_pd(bdcst_imag[0], a_perm[1]); + c_vec[2] = _mm512_mul_pd(bdcst_imag[1], a_perm[0]); + c_vec[3] = _mm512_mul_pd(bdcst_imag[1], a_perm[1]); + + // Scaling with real comp of alpha*B and accumulating + c_vec[0] = _mm512_fmaddsub_pd(bdcst_real[0], a_vec[0], c_vec[0]); + c_vec[1] = _mm512_fmaddsub_pd(bdcst_real[0], a_vec[1], c_vec[1]); + c_vec[2] = _mm512_fmaddsub_pd(bdcst_real[1], a_vec[0], c_vec[2]); + c_vec[3] = _mm512_fmaddsub_pd(bdcst_real[1], a_vec[1], c_vec[3]); + + // Scaling with beta, according to its type. + switch( beta_mul_type ) + { + case BLIS_MUL_ZERO : + // Storing the result in C. + _mm512_storeu_pd((double *)(temp_cij), c_vec[0]); + _mm512_storeu_pd((double *)(temp_cij + 4), c_vec[1]); + + _mm512_storeu_pd((double *)(temp_cij + 1 * ldc), c_vec[2]); + _mm512_storeu_pd((double *)(temp_cij + 1 * ldc + 4), c_vec[3]); + break; + + case BLIS_MUL_ONE : + // Load C from memory + a_vec[0] = _mm512_loadu_pd((double const*)(temp_cij)); + a_vec[1] = _mm512_loadu_pd((double const*)(temp_cij + 4)); + + // Add C to alpha*A*B + c_vec[0] = _mm512_add_pd(c_vec[0], a_vec[0]); + c_vec[1] = _mm512_add_pd(c_vec[1], a_vec[1]); + + // Store the result to memory + _mm512_storeu_pd((double *)(temp_cij), c_vec[0]); + _mm512_storeu_pd((double *)(temp_cij + 4), c_vec[1]); + + a_vec[0] = _mm512_loadu_pd((double const*)(temp_cij + 1 * ldc)); + a_vec[1] = _mm512_loadu_pd((double const*)(temp_cij + 1 * ldc + 4)); + + c_vec[2] = _mm512_add_pd(c_vec[2], a_vec[0]); + c_vec[3] = _mm512_add_pd(c_vec[3], a_vec[1]); + + _mm512_storeu_pd((double *)(temp_cij + 1 * ldc), c_vec[2]); + _mm512_storeu_pd((double *)(temp_cij + 1 * ldc + 4), c_vec[3]); + break; + + default : + + betaRv = _mm512_set1_pd(beta_real); + betaIv = _mm512_set1_pd(beta_imag); + + // Load C from memory + a_vec[0] = _mm512_loadu_pd((double const*)(temp_cij)); + a_vec[1] = _mm512_loadu_pd((double const*)(temp_cij + 4)); + + // Swapping real and imag parts of C for computation + a_perm[0] = _mm512_permute_pd(a_vec[0], 0x55); + a_perm[1] = _mm512_permute_pd(a_vec[1], 0x55); + + // Scaling with imag component of beta + a_perm[0] = _mm512_mul_pd(betaIv, a_perm[0]); + a_perm[1] = _mm512_mul_pd(betaIv, a_perm[1]); + + // Scaling with real component of beta and accumulating + a_vec[0] = _mm512_fmaddsub_pd(betaRv, a_vec[0], a_perm[0]); + a_vec[1] = _mm512_fmaddsub_pd(betaRv, a_vec[1], a_perm[1]); + + c_vec[0] = _mm512_add_pd(a_vec[0], c_vec[0]); + c_vec[1] = _mm512_add_pd(a_vec[1], c_vec[1]); + + // Storing the result to memory + _mm512_storeu_pd((double *)(temp_cij), c_vec[0]); + _mm512_storeu_pd((double *)(temp_cij + 4), c_vec[1]); + + a_vec[0] = _mm512_loadu_pd((double const*)(temp_cij + 1 * ldc)); + a_vec[1] = _mm512_loadu_pd((double const*)(temp_cij + 1 * ldc + 4)); + + a_perm[0] = _mm512_permute_pd(a_vec[0], 0x55); + a_perm[1] = _mm512_permute_pd(a_vec[1], 0x55); + + a_perm[0] = _mm512_mul_pd(betaIv, a_perm[0]); + a_perm[1] = _mm512_mul_pd(betaIv, a_perm[1]); + + a_vec[0] = _mm512_fmaddsub_pd(betaRv, a_vec[0], a_perm[0]); + a_vec[1] = _mm512_fmaddsub_pd(betaRv, a_vec[1], a_perm[1]); + + c_vec[2] = _mm512_add_pd(a_vec[0], c_vec[2]); + c_vec[3] = _mm512_add_pd(a_vec[1], c_vec[3]); + + _mm512_storeu_pd((double *)(temp_cij + 1 * ldc), c_vec[2]); + _mm512_storeu_pd((double *)(temp_cij + 1 * ldc + 4), c_vec[3]); + } + + // Adjusting the addresses of A and C for the next iteration. + temp_cij += 8; + temp_ai += 8; + m_rem -= 8; + } + + if( m_rem >= 4 ) + { + __m512d a_perm, c_vec[2]; + __m512d betaRv, betaIv; + + // Clearing out sctarch registers for accumalation + c_vec[0] = _mm512_setzero_pd(); + c_vec[1] = _mm512_setzero_pd(); + + // Loading 4 elements from A + a_vec[0] = _mm512_loadu_pd((double const*)temp_ai); + + // Swapping real and imag components, to be used in computation + a_perm = _mm512_permute_pd(a_vec[0], 0x55); + + // Scaling with imag components of alpha*B + c_vec[0] = _mm512_mul_pd(bdcst_imag[0], a_perm); + c_vec[1] = _mm512_mul_pd(bdcst_imag[1], a_perm); + + // Scaling with real comp of alpha*B and accumulating + c_vec[0] = _mm512_fmaddsub_pd(bdcst_real[0], a_vec[0], c_vec[0]); + c_vec[1] = _mm512_fmaddsub_pd(bdcst_real[1], a_vec[0], c_vec[1]); + + // Scaling with beta, according to its type. + switch( beta_mul_type ) + { + case BLIS_MUL_ZERO : + // Storing the result in C. + _mm512_storeu_pd((double *)(temp_cij), c_vec[0]); + _mm512_storeu_pd((double *)(temp_cij + 1 * ldc), c_vec[1]); + break; + + case BLIS_MUL_ONE : + // Loading C from memory + a_vec[0] = _mm512_loadu_pd((double const*)(temp_cij)); + + // Adding it to alpha*A*B + c_vec[0] = _mm512_add_pd(c_vec[0], a_vec[0]); + + // Storing the result to memory + _mm512_storeu_pd((double *)(temp_cij), c_vec[0]); + + a_vec[0] = _mm512_loadu_pd((double const*)(temp_cij + 1 * ldc)); + + c_vec[1] = _mm512_add_pd(c_vec[1], a_vec[0]); + + _mm512_storeu_pd((double *)(temp_cij + 1 * ldc), c_vec[1]); + break; + + default : + betaRv = _mm512_set1_pd(beta_real); + betaIv = _mm512_set1_pd(beta_imag); + + // Load C from memory + a_vec[0] = _mm512_loadu_pd((double const*)(temp_cij)); + + // Swapping real and imag parts of C for computation + a_perm = _mm512_permute_pd(a_vec[0], 0x55); + + // Scaling with imag component of beta + a_perm = _mm512_mul_pd(betaIv, a_perm); + + // Scaling with real component of beta and accumulating + a_vec[0] = _mm512_fmaddsub_pd(betaRv, a_vec[0], a_perm); + + c_vec[0] = _mm512_add_pd(a_vec[0], c_vec[0]); + + // Storing the result to memory + _mm512_storeu_pd((double *)(temp_cij), c_vec[0]); + + a_vec[0] = _mm512_loadu_pd((double const*)(temp_cij + 1 * ldc)); + + a_perm = _mm512_permute_pd(a_vec[0], 0x55); + + a_perm = _mm512_mul_pd(betaIv, a_perm); + + a_vec[0] = _mm512_fmaddsub_pd(betaRv, a_vec[0], a_perm); + + c_vec[1] = _mm512_add_pd(a_vec[0], c_vec[1]); + + _mm512_storeu_pd((double *)(temp_cij + 1 * ldc), c_vec[1]); + } + + // Adjusting the addresses of A and C for the next iteration. + temp_cij += 4; + temp_ai += 4; + + m_rem -= 4; + } + + if( m_rem > 0 ) + { + // Setting the mask to load/store remaining elements + __mmask8 m_mask = m_mask = (1 << 2 * m_rem) - 1; + __m512d a_perm, c_vec[2]; + __m512d betaRv, betaIv; + + // Clearing out scratch registers + c_vec[0] = _mm512_setzero_pd(); + c_vec[1] = _mm512_setzero_pd(); + + // Loading remaining elements from A + a_vec[0] = _mm512_maskz_loadu_pd(m_mask, (double const*)temp_ai); + + // Swapping real and imag components, to be used in computation + a_perm = _mm512_permute_pd(a_vec[0], 0x55); + + // Scaling with imag components of alpha*B + c_vec[0] = _mm512_mul_pd(bdcst_imag[0], a_perm); + c_vec[1] = _mm512_mul_pd(bdcst_imag[1], a_perm); + + // Scaling with real comp of alpha*B and accumulating + c_vec[0] = _mm512_fmaddsub_pd(bdcst_real[0], a_vec[0], c_vec[0]); + c_vec[1] = _mm512_fmaddsub_pd(bdcst_real[1], a_vec[0], c_vec[1]); + + // Scaling with beta, according to its type. + switch( beta_mul_type ) + { + case BLIS_MUL_ZERO : + // Storing the result in C. + _mm512_mask_storeu_pd((double *)(temp_cij), m_mask, c_vec[0]); + _mm512_mask_storeu_pd((double *)(temp_cij + 1 * ldc), m_mask, c_vec[1]); + break; + + case BLIS_MUL_ONE : + // Loading C from memory + a_vec[0] = _mm512_maskz_loadu_pd(m_mask, (double const*)(temp_cij)); + + // Adding it to alpha*A*B + c_vec[0] = _mm512_add_pd(c_vec[0], a_vec[0]); + + // Storing the result to memory + _mm512_mask_storeu_pd((double *)(temp_cij), m_mask, c_vec[0]); + + a_vec[0] = _mm512_maskz_loadu_pd(m_mask, (double const*)(temp_cij + 1 * ldc)); + + c_vec[1] = _mm512_add_pd(c_vec[1], a_vec[0]); + + _mm512_mask_storeu_pd((double *)(temp_cij + 1 * ldc), m_mask, c_vec[1]); + break; + + default : + + betaRv = _mm512_set1_pd(beta_real); + betaIv = _mm512_set1_pd(beta_imag); + + // Load C from memory + a_vec[0] = _mm512_maskz_loadu_pd(m_mask, (double const*)(temp_cij)); + + // Swapping real and imag parts of C for computation + a_perm = _mm512_permute_pd(a_vec[0], 0x55); + + // Scaling with imag component of beta + a_perm = _mm512_mul_pd(betaIv, a_perm); + + // Scaling with real component of beta and accumulating + a_vec[0] = _mm512_fmaddsub_pd(betaRv, a_vec[0], a_perm); + + c_vec[0] = _mm512_add_pd(a_vec[0], c_vec[0]); + + // Storing the result to memory + _mm512_mask_storeu_pd((double *)(temp_cij), m_mask, c_vec[0]); + + a_vec[0] = _mm512_maskz_loadu_pd(m_mask, (double const*)(temp_cij + 1 * ldc)); + + a_perm = _mm512_permute_pd(a_vec[0], 0x55); + + a_perm = _mm512_mul_pd(betaIv, a_perm); + + a_vec[0] = _mm512_fmaddsub_pd(betaRv, a_vec[0], a_perm); + + c_vec[1] = _mm512_add_pd(a_vec[0], c_vec[1]); + + _mm512_mask_storeu_pd((double *)(temp_cij + 1 * ldc), m_mask, c_vec[1]); + } + } + + // Adjusting the pointers accordingly + temp_b += ldb * 2; + temp_c += ldc * 2; + + // Updating n_remainder + n_remainder -= 2; + } + + if( n_remainder == 1 ) + { + dcomplex* temp_ai = a; + dcomplex* temp_bj = temp_b; + dcomplex* temp_cij = temp_c; + + /* + Multiple blocks of Z_MR x 1(main loop for m) and/or m_remainder x 1 block(s) + of A use the same 1 x 1 block of B in order to compute the associated + Z_MR x 1 and/or m_remainder x 1 block(s) of C. This reusability has been + exploited, wherein the associated 1 x 1 block of B is scaled with alpha, + and stored in registers beforehand, to be reused in the main loop or fringe + case of m. + */ + + // Intermediate registers used for alpha scaling the block of B and storing. + __m512d a_vec[4], bdcst_real[1], bdcst_imag[1], b_vec[1], temp[1]; + + // Broadcast elements from alpha, and exhibit the compute for complex scaling. + a_vec[0] = _mm512_set1_pd(alpha_real); + a_vec[1] = _mm512_set1_pd(alpha_imag); + + // Broadcasting real and imag components from B onto separate registers. + // They are then unpacked to get the interleaved storage format on registers. + // bdcst_real[0] = R0 R0 R0 R0 ... + bdcst_real[0] = _mm512_set1_pd(*((double *)(temp_bj))); + // bdcst_imag[0] = I0 I0 I0 I0 ... + bdcst_imag[0] = _mm512_set1_pd(*((double *)(temp_bj) + 1)); + // b_vec[0] = R0 I0 R0 I0 ... + b_vec[0] = _mm512_unpacklo_pd(bdcst_real[0], bdcst_imag[0]); + // temp[0] = I0 R0 I0 R0 ... + temp[0] = _mm512_unpacklo_pd(bdcst_imag[0], bdcst_real[0]); + + // Scaling with imag component of alpha + temp[0] = _mm512_mul_pd(a_vec[1], temp[0]); + // Scaling with real component of alpha and accumulating + b_vec[0] = _mm512_fmaddsub_pd(a_vec[0], b_vec[0], temp[0]); + + // Registers b_vec[0] contain alpha scaled B. These + // are unpacked in order to contain the real and imaginary + // components of each element in separate registers. + bdcst_real[0] = _mm512_unpacklo_pd(b_vec[0], b_vec[0]); + + bdcst_imag[0] = _mm512_unpackhi_pd(b_vec[0], b_vec[0]); + + dim_t i = 0; + dim_t m_rem = m_remainder; + // Main loop along M dimension. + for( ; i < m_iter; i++ ) + { + __m512d a_perm[4], c_vec[4]; + __m512d betaRv, betaIv; + + // Clearing scratch registers for accumalation + c_vec[0] = _mm512_setzero_pd(); + c_vec[1] = _mm512_setzero_pd(); + c_vec[2] = _mm512_setzero_pd(); + c_vec[3] = _mm512_setzero_pd(); + + // Loading 16 elements from A + a_vec[0] = _mm512_loadu_pd((double const*)temp_ai); + a_vec[1] = _mm512_loadu_pd((double const*)(temp_ai + 4)); + a_vec[2] = _mm512_loadu_pd((double const*)(temp_ai + 8)); + a_vec[3] = _mm512_loadu_pd((double const*)(temp_ai + 12)); + + // Swapping real and imag components, to be used in computation + a_perm[0] = _mm512_permute_pd(a_vec[0], 0x55); + a_perm[1] = _mm512_permute_pd(a_vec[1], 0x55); + a_perm[2] = _mm512_permute_pd(a_vec[2], 0x55); + a_perm[3] = _mm512_permute_pd(a_vec[3], 0x55); + + // Scaling with imag components of alpha*B + c_vec[0] = _mm512_mul_pd(bdcst_imag[0], a_perm[0]); + c_vec[1] = _mm512_mul_pd(bdcst_imag[0], a_perm[1]); + c_vec[2] = _mm512_mul_pd(bdcst_imag[0], a_perm[2]); + c_vec[3] = _mm512_mul_pd(bdcst_imag[0], a_perm[3]); + + // Scaling with real comp of alpha*B and accumulating + c_vec[0] = _mm512_fmaddsub_pd(bdcst_real[0], a_vec[0], c_vec[0]); + c_vec[1] = _mm512_fmaddsub_pd(bdcst_real[0], a_vec[1], c_vec[1]); + c_vec[2] = _mm512_fmaddsub_pd(bdcst_real[0], a_vec[2], c_vec[2]); + c_vec[3] = _mm512_fmaddsub_pd(bdcst_real[0], a_vec[3], c_vec[3]); + + // Scaling with beta, according to its type. + switch( beta_mul_type ) + { + case BLIS_MUL_ZERO : + // Storing the result in C. + _mm512_storeu_pd((double *)(temp_cij), c_vec[0]); + _mm512_storeu_pd((double *)(temp_cij + 4), c_vec[1]); + _mm512_storeu_pd((double *)(temp_cij + 8), c_vec[2]); + _mm512_storeu_pd((double *)(temp_cij + 12), c_vec[3]); + break; + + case BLIS_MUL_ONE : + // Loading from C + a_vec[0] = _mm512_loadu_pd((double const*)(temp_cij)); + a_vec[1] = _mm512_loadu_pd((double const*)(temp_cij + 4)); + a_vec[2] = _mm512_loadu_pd((double const*)(temp_cij + 8)); + a_vec[3] = _mm512_loadu_pd((double const*)(temp_cij + 12)); + + // Adding alpha*A*b to C + c_vec[0] = _mm512_add_pd(c_vec[0], a_vec[0]); + c_vec[1] = _mm512_add_pd(c_vec[1], a_vec[1]); + c_vec[2] = _mm512_add_pd(c_vec[2], a_vec[2]); + c_vec[3] = _mm512_add_pd(c_vec[3], a_vec[3]); + + // Storing to C + _mm512_storeu_pd((double *)(temp_cij), c_vec[0]); + _mm512_storeu_pd((double *)(temp_cij + 4), c_vec[1]); + _mm512_storeu_pd((double *)(temp_cij + 8), c_vec[2]); + _mm512_storeu_pd((double *)(temp_cij + 12), c_vec[3]); + break; + + default : + betaRv = _mm512_set1_pd(beta_real); + betaIv = _mm512_set1_pd(beta_imag); + + // Load C from memory + a_vec[0] = _mm512_loadu_pd((double const*)(temp_cij)); + a_vec[1] = _mm512_loadu_pd((double const*)(temp_cij + 4)); + a_vec[2] = _mm512_loadu_pd((double const*)(temp_cij + 8)); + a_vec[3] = _mm512_loadu_pd((double const*)(temp_cij + 12)); + + // Swapping real and imag parts of C for computation + a_perm[0] = _mm512_permute_pd(a_vec[0], 0x55); + a_perm[1] = _mm512_permute_pd(a_vec[1], 0x55); + a_perm[2] = _mm512_permute_pd(a_vec[2], 0x55); + a_perm[3] = _mm512_permute_pd(a_vec[3], 0x55); + + // Scaling with imag component of beta + a_perm[0] = _mm512_mul_pd(betaIv, a_perm[0]); + a_perm[1] = _mm512_mul_pd(betaIv, a_perm[1]); + a_perm[2] = _mm512_mul_pd(betaIv, a_perm[2]); + a_perm[3] = _mm512_mul_pd(betaIv, a_perm[3]); + + // Scaling with real component of beta and accumulating + a_vec[0] = _mm512_fmaddsub_pd(betaRv, a_vec[0], a_perm[0]); + a_vec[1] = _mm512_fmaddsub_pd(betaRv, a_vec[1], a_perm[1]); + a_vec[2] = _mm512_fmaddsub_pd(betaRv, a_vec[2], a_perm[2]); + a_vec[3] = _mm512_fmaddsub_pd(betaRv, a_vec[3], a_perm[3]); + + c_vec[0] = _mm512_add_pd(a_vec[0], c_vec[0]); + c_vec[1] = _mm512_add_pd(a_vec[1], c_vec[1]); + c_vec[2] = _mm512_add_pd(a_vec[2], c_vec[2]); + c_vec[3] = _mm512_add_pd(a_vec[3], c_vec[3]); + + // Storing the result to memory + _mm512_storeu_pd((double *)(temp_cij), c_vec[0]); + _mm512_storeu_pd((double *)(temp_cij + 4), c_vec[1]); + _mm512_storeu_pd((double *)(temp_cij + 8), c_vec[2]); + _mm512_storeu_pd((double *)(temp_cij + 12), c_vec[3]); + } + + // Adjusting the addresses of A and C for the next iteration. + temp_cij += 16; + temp_ai += 16; + } + + if( m_rem >= 8 ) + { + __m512d a_perm[2], c_vec[2]; + __m512d betaRv, betaIv; + + // Clearing scratch registers for accumalation + c_vec[0] = _mm512_setzero_pd(); + c_vec[1] = _mm512_setzero_pd(); + + // Loading 8 elements from A + a_vec[0] = _mm512_loadu_pd((double const*)temp_ai); + a_vec[1] = _mm512_loadu_pd((double const*)(temp_ai + 4)); + + // Swapping real and imag components, to be used in computation + a_perm[0] = _mm512_permute_pd(a_vec[0], 0x55); + a_perm[1] = _mm512_permute_pd(a_vec[1], 0x55); + + // Scaling with imag components of alpha*B + c_vec[0] = _mm512_mul_pd(bdcst_imag[0], a_perm[0]); + c_vec[1] = _mm512_mul_pd(bdcst_imag[0], a_perm[1]); + + // Scaling with real comp of alpha*B and accumulating + c_vec[0] = _mm512_fmaddsub_pd(bdcst_real[0], a_vec[0], c_vec[0]); + c_vec[1] = _mm512_fmaddsub_pd(bdcst_real[0], a_vec[1], c_vec[1]); + + // Scaling with beta, according to its type. + switch( beta_mul_type ) + { + case BLIS_MUL_ZERO : + // Storing the result in C. + _mm512_storeu_pd((double *)(temp_cij), c_vec[0]); + _mm512_storeu_pd((double *)(temp_cij + 4), c_vec[1]); + break; + + case BLIS_MUL_ONE : + // Loading from C + a_vec[0] = _mm512_loadu_pd((double const*)(temp_cij)); + a_vec[1] = _mm512_loadu_pd((double const*)(temp_cij + 4)); + + // Adding alpha*A*b to C + c_vec[0] = _mm512_add_pd(c_vec[0], a_vec[0]); + c_vec[1] = _mm512_add_pd(c_vec[1], a_vec[1]); + + // Storing to C + _mm512_storeu_pd((double *)(temp_cij), c_vec[0]); + _mm512_storeu_pd((double *)(temp_cij + 4), c_vec[1]); + break; + + default : + betaRv = _mm512_set1_pd(beta_real); + betaIv = _mm512_set1_pd(beta_imag); + + // Load C from memory + a_vec[0] = _mm512_loadu_pd((double const*)(temp_cij)); + a_vec[1] = _mm512_loadu_pd((double const*)(temp_cij + 4)); + + // Swapping real and imag parts of C for computation + a_perm[0] = _mm512_permute_pd(a_vec[0], 0x55); + a_perm[1] = _mm512_permute_pd(a_vec[1], 0x55); + + // Scaling with imag component of beta + a_perm[0] = _mm512_mul_pd(betaIv, a_perm[0]); + a_perm[1] = _mm512_mul_pd(betaIv, a_perm[1]); + + // Scaling with real component of beta and accumulating + a_vec[0] = _mm512_fmaddsub_pd(betaRv, a_vec[0], a_perm[0]); + a_vec[1] = _mm512_fmaddsub_pd(betaRv, a_vec[1], a_perm[1]); + + c_vec[0] = _mm512_add_pd(a_vec[0], c_vec[0]); + c_vec[1] = _mm512_add_pd(a_vec[1], c_vec[1]); + + // Storing the result to memory + _mm512_storeu_pd((double *)(temp_cij), c_vec[0]); + _mm512_storeu_pd((double *)(temp_cij + 4), c_vec[1]); + } + + // Adjusting the addresses of A and C for the next iteration. + temp_cij += 8; + temp_ai += 8; + m_rem -= 8; + } + + if( m_rem >= 4 ) + { + __m512d a_perm, c_vec; + __m512d betaRv, betaIv; + + // Clearing the scratch register for accumalation + c_vec = _mm512_setzero_pd(); + + // Loading 4 elements from A + a_vec[0] = _mm512_loadu_pd((double const*)temp_ai); + + // Swapping real and imag components, to be used in computation + a_perm = _mm512_permute_pd(a_vec[0], 0x55); + + // Scaling with imag components of alpha*B + c_vec = _mm512_mul_pd(bdcst_imag[0], a_perm); + + // Scaling with real comp of alpha*B and accumulating + c_vec = _mm512_fmaddsub_pd(bdcst_real[0], a_vec[0], c_vec); + + // Scaling with beta, according to its type. + switch( beta_mul_type ) + { + case BLIS_MUL_ZERO : + // Storing the result in C. + _mm512_storeu_pd((double *)(temp_cij), c_vec); + break; + + case BLIS_MUL_ONE : + // Loading from C + a_vec[0] = _mm512_loadu_pd((double const*)(temp_cij)); + + // Adding alpha*A*b to C + c_vec = _mm512_add_pd(c_vec, a_vec[0]); + + // Storing to C + _mm512_storeu_pd((double *)(temp_cij), c_vec); + break; + + default : + betaRv = _mm512_set1_pd(beta_real); + betaIv = _mm512_set1_pd(beta_imag); + + // Load C from memory + a_vec[0] = _mm512_loadu_pd((double const*)(temp_cij)); + + // Swapping real and imag parts of C for computation + a_perm = _mm512_permute_pd(a_vec[0], 0x55); + + // Scaling with imag component of beta + a_perm = _mm512_mul_pd(betaIv, a_perm); + + // Scaling with real component of beta and accumulating + a_vec[0] = _mm512_fmaddsub_pd(betaRv, a_vec[0], a_perm); + + c_vec = _mm512_add_pd(a_vec[0], c_vec); + + // Storing the result to memory + _mm512_storeu_pd((double *)(temp_cij), c_vec); + } + + // Adjusting the addresses of A and C for the next iteration. + temp_cij += 4; + temp_ai += 4; + + m_rem -= 4; + } + + if( m_rem > 0 ) + { + __mmask8 m_mask = m_mask = (1 << 2 * m_rem) - 1; + __m512d a_perm, c_vec; + __m512d betaRv, betaIv; + + // Clearing the scratch register + c_vec = _mm512_setzero_pd(); + + // Loading the remaining elements from A + a_vec[0] = _mm512_maskz_loadu_pd(m_mask, (double const*)temp_ai); + + // Swapping real and imag components, to be used in computation + a_perm = _mm512_permute_pd(a_vec[0], 0x55); + + // Scaling with imag components of alpha*B + c_vec = _mm512_mul_pd(bdcst_imag[0], a_perm); + + // Scaling with real comp of alpha*B and accumulating + c_vec = _mm512_fmaddsub_pd(bdcst_real[0], a_vec[0], c_vec); + + // Scaling with beta, according to its type. + switch( beta_mul_type ) + { + case BLIS_MUL_ZERO : + // Storing the result in C. + _mm512_mask_storeu_pd((double *)(temp_cij), m_mask, c_vec); + break; + + case BLIS_MUL_ONE : + // Loading from C + a_vec[0] = _mm512_maskz_loadu_pd(m_mask, (double const*)(temp_cij)); + + // Adding alpha*A*b to C + c_vec = _mm512_add_pd(c_vec, a_vec[0]); + + // Storing to C + _mm512_mask_storeu_pd((double *)(temp_cij), m_mask, c_vec); + break; + + default : + betaRv = _mm512_set1_pd(beta_real); + betaIv = _mm512_set1_pd(beta_imag); + + // Load C from memory + a_vec[0] = _mm512_maskz_loadu_pd(m_mask, (double const*)(temp_cij)); + + // Swapping real and imag parts of C for computation + a_perm = _mm512_permute_pd(a_vec[0], 0x55); + + // Scaling with imag component of beta + a_perm = _mm512_mul_pd(betaIv, a_perm); + + // Scaling with real component of beta and accumulating + a_vec[0] = _mm512_fmaddsub_pd(betaRv, a_vec[0], a_perm); + + c_vec = _mm512_add_pd(a_vec[0], c_vec); + + // Storing the result to memory + _mm512_mask_storeu_pd((double *)(temp_cij), m_mask, c_vec); + } + } + } + + return BLIS_SUCCESS; +} diff --git a/kernels/zen4/3/bli_zgemm_zen4_asm_4x12.c b/kernels/zen4/3/bli_zgemm_zen4_asm_4x12.c index fd0181c1d1..4b35608cd6 100644 --- a/kernels/zen4/3/bli_zgemm_zen4_asm_4x12.c +++ b/kernels/zen4/3/bli_zgemm_zen4_asm_4x12.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -19,14 +19,14 @@ from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - AS IS AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY - OF TEXAS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, - EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, - PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR - PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY - OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/kernels/zen4/3/bli_zgemmtrsm_l_4x12.c b/kernels/zen4/3/bli_zgemmtrsm_l_4x12.c index 5fe475421e..1f4789f69a 100644 --- a/kernels/zen4/3/bli_zgemmtrsm_l_4x12.c +++ b/kernels/zen4/3/bli_zgemmtrsm_l_4x12.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -19,14 +19,14 @@ from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - AS IS AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY - OF TEXAS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, - EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, - PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR - PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY - OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/kernels/zen4/3/bli_zgemmtrsm_u_4x12.c b/kernels/zen4/3/bli_zgemmtrsm_u_4x12.c index 8e86e2040c..dc20892d9d 100644 --- a/kernels/zen4/3/bli_zgemmtrsm_u_4x12.c +++ b/kernels/zen4/3/bli_zgemmtrsm_u_4x12.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -19,14 +19,14 @@ from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - AS IS AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY - OF TEXAS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, - EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, - PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR - PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY - OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/kernels/zen4/3/bli_ztrsm_small_AVX512.c b/kernels/zen4/3/bli_ztrsm_small_AVX512.c new file mode 100644 index 0000000000..ab1ce4551c --- /dev/null +++ b/kernels/zen4/3/bli_ztrsm_small_AVX512.c @@ -0,0 +1,1062 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM +#include "immintrin.h" + +#if defined __clang__ + #define UNROLL_LOOP() _Pragma("clang loop unroll_count(4)") + /* + * in clang, unroll_count(4) generates inefficient + * code compared to unroll(full) when loopCount = 4. + */ + #define UNROLL_LOOP_FULL() _Pragma("clang loop unroll(full)") +#elif defined __GNUC__ + #define UNROLL_LOOP() _Pragma("GCC unroll 4") + #define UNROLL_LOOP_FULL() _Pragma("GCC unroll 4") +#else + #define UNROLL_LOOP() + #define UNROLL_LOOP_FULL() +#endif + +/* +* Multiply dcomplex vector with a dcomplex scalar(S) +* reg_a -> input dcomplex vector +* reg_r -> vector with S->real broadcasted +* reg_i -> vector with S->imag broadcasted +* output -> vector where output is stored +* +* t_reg[5] contains [1, 1, 1, 1, 1, 1, 1, 1] +* +* (a + ib) (c + id) = (ac - bd) + i(ad + bc) +* here reg_a = [a1, b1, a2, b2, a3, b3, a4, b4] +* reg_r = [c, c, c, c, c, c, c, c ] +* reg_i = [d, d, d, d, d, d, d, d ] +*/ +#define MULTIPLY_COMPLEX( reg_a, reg_r, reg_i, output ) \ + t_reg[3] = _mm512_permute_pd(reg_a, 0x55); \ + /* t_reg[3] = [b1, a1, b2, a2, b3, a3, b4, a4] */ \ + output = _mm512_mul_pd(reg_a, reg_r); \ + /* output = c * [a1, b1, a2, b2, a3, b3, a4, b4]*/ \ + t_reg[3] = _mm512_mul_pd(t_reg[3], reg_i); \ + /* t_reg[3] = d * [b1, a1, b2, a2, b3, a3, b4, a4]*/ \ + output = _mm512_fmaddsub_pd(t_reg[5], output, t_reg[3]); \ + /* output = [a1c-b1d, a1d+b1c, a2c-b2d, a2d+b2c, ......]*/ \ + +/* +* Divide dcomplex vector with a dcomplex scalar(S) +* reg_a -> input dcomplex vector +* addr -> address of scalar +* output is stored in reg_a +* +* t_teg[4] contains [-1, -1, -1, -1, -1, -1, -1, -1] +* t_reg[5] contains [ 1, 1, 1, 1, 1, 1, 1, 1] +* +* (xr + i xi)/(ar + i ai) = +* (xrar + xiai)/(ar^2 + ai^2) + +* i(xiar - xrai)/(ar^2 + ai^2) +* +* instead if dividing by ar^2 + ai^2, we divide +* by ar/maxabs(ar, ai) * ar + ai / maxabs(ar, ai) * ai +* in order to reduce the possibility of underflow +* when c or d are very small +* +* here reg_a = [a1, b1, a2, b2, a3, b3, a4, b4] +*/ +#define DIVIDE_COMPLEX( reg_a, addr ) \ + g_double[2] = bli_fmaxabs(addr->real, addr->imag);/*s*/ \ + g_double[0] = addr->real / g_double[2];/*ar/s*/ \ + g_double[1] = addr->imag / g_double[2];/*ai/s*/ \ + t_reg[0] = _mm512_set1_pd(g_double[0]);/*ar/s*/ \ + t_reg[1] = _mm512_set1_pd(g_double[1]);/*ai/s*/ \ + g_double[2] = (g_double[0] * addr->real) + \ + (g_double[1] * addr->imag); \ + /*(ar/s * ar) +(ai/s * ai)*/ \ + t_reg[3] = _mm512_permute_pd(reg_a, 0x55); \ + /*t_reg[3] = [xi,xr,xi,xr....] */ \ + reg_a = _mm512_mul_pd(reg_a, t_reg[0]); \ + /* reg_a = ar/s * [xr, xi, xr, xi ....]*/ \ + t_reg[3] = _mm512_mul_pd(t_reg[3], t_reg[1]); \ + /*t_reg[3] = ai/s * [xi,xr,xi,xr........] */ \ + t_reg[3] = _mm512_mul_pd(t_reg[4], t_reg[3]); \ + /*t_reg[3] = -ai/s * [xi,xr,xi,xr........] */ \ + t_reg[1] = _mm512_set1_pd(g_double[2]); \ + /*t_reg[1] = [(ar/s * ar) +(ai/s * ai), ...] */ \ + reg_a = _mm512_fmaddsub_pd(t_reg[5], reg_a, t_reg[3]);\ + /*reg_a = [a1c+b1d, b1c-a1d, a2c+b2d, b2c-a2d, ....]*/ \ + reg_a = _mm512_div_pd(reg_a, t_reg[1]); \ + +// Zero the registors used for gemm accumulation +#define ZERO_REGISTERS() \ + c_reg[0] = _mm512_setzero_pd(); \ + c_reg[1] = _mm512_setzero_pd(); \ + c_reg[2] = _mm512_setzero_pd(); \ + c_reg[3] = _mm512_setzero_pd(); \ + c_reg[4] = _mm512_setzero_pd(); \ + c_reg[5] = _mm512_setzero_pd(); \ + c_reg[6] = _mm512_setzero_pd(); \ + c_reg[7] = _mm512_setzero_pd(); \ + t_reg[5] = _mm512_setzero_pd(); \ + b_reg[0] = _mm512_setzero_pd(); \ + b_reg[1] = _mm512_setzero_pd(); \ + b_reg[2] = _mm512_setzero_pd(); \ + b_reg[3] = _mm512_setzero_pd(); \ + +/* Initialize variable which are +* common across all kernels. +*/ +#define INIT() \ + __m512d t_reg[6]; /*temporary registers*/ \ + __m512d c_reg[8]; /*registors to hold GEMM accumulation*/\ + __m512d b_reg[4]; /*registors to hold B matrix*/ \ + t_reg[5] = _mm512_set1_pd( 1.0 ); /*(constant) used for fmaddsub*/\ + \ + double g_double[3]; \ + __mmask8 mask_m; /*registor to hold mask for laod/store*/\ + \ + dim_t m = bli_obj_length( b ); \ + dim_t n = bli_obj_width( b ); \ + dim_t cs_a = bli_obj_col_stride( a ); \ + dim_t rs_a = bli_obj_row_stride( a ); \ + dim_t cs_b = bli_obj_col_stride( b ); \ + \ + bool transa = bli_obj_has_trans( a ); \ + bool is_unitdiag = bli_obj_has_unit_diag( a ); \ + dcomplex AlphaVal = *(dcomplex *)AlphaObj->buffer; \ + \ + dim_t d_mr = 4; \ + dim_t d_nr = 4; \ + dim_t i, j; \ + dim_t k_iter; \ + \ + dcomplex* restrict L = bli_obj_buffer_at_off( a ); \ + dcomplex* restrict B = bli_obj_buffer_at_off( b ); \ + +/* +* Perform GEMM with given value of M, N, K +* K is always a multiple of 4 +* N is compile time constant. +* M <= 4 and N <= 4. +* Output is stored in registor c_reg[0] to c_reg[N-1] +*/ +#define GEMM_MxN( a01_, b10_, rs_a_, cs_a_, cs_b_, k_iter_, M_, N_ ) \ + \ + UNROLL_LOOP() \ + for( dim_t ii = 0; ii < k_iter_; ++ii ) \ + { \ + b_reg[0] = _mm512_mask_loadu_pd(c_reg[0], mask_m, b10_); \ + UNROLL_LOOP_FULL() \ + for( dim_t jj = 0; jj < N_; ++jj ) \ + { \ + t_reg[0] = _mm512_set1_pd((a01_ + cs_a_*jj)->real); \ + t_reg[1] = _mm512_set1_pd((a01_ + cs_a_*jj)->imag); \ + c_reg[jj] = _mm512_fmadd_pd(t_reg[0], b_reg[0], c_reg[jj]); \ + c_reg[jj+4] = _mm512_fmadd_pd(t_reg[1], b_reg[0], c_reg[jj+4]); \ + } \ + a01_ += rs_a_; \ + b10_ += cs_b_; \ + } \ + t_reg[5] = _mm512_set1_pd(1.0); \ + UNROLL_LOOP_FULL() \ + for ( dim_t jj = 0; jj < N_; ++jj ) \ + { \ + c_reg[jj+4] = _mm512_permute_pd(c_reg[jj+4], 0x55); \ + c_reg[jj] = _mm512_fmaddsub_pd(t_reg[5], c_reg[jj], c_reg[jj+4]); \ + } \ + +/* +* Performs alpha*B - gemm_output +* N is compile time constant. +* M <= 4 and N <= 4. +*/ +#define PRE_TRSM_NxM(AlphaVal, b11, cs_b, M, N) \ + \ + if(AlphaVal.real == 1 && AlphaVal.imag == 0) \ + { \ + UNROLL_LOOP_FULL() \ + for(int ii=0; iireal); \ + t_reg[1] = _mm512_set1_pd((a11 + jj*cs_a)->imag); \ + MULTIPLY_COMPLEX(c_reg[ii], t_reg[0], t_reg[1], t_reg[2]) \ + c_reg[jj] = _mm512_sub_pd(c_reg[jj], t_reg[2]); \ + } \ + a11 += rs_a; \ + } \ + +/* +* Perform TRSM computation for Right Lower +* NonTranpose variant. +* N is compile time constant. +*/ +#define TRSM_MAIN_RLNN_NXM(N) \ + \ + a11 += rs_a * (N-1); \ + UNROLL_LOOP_FULL() \ + for( dim_t ii = (N-1); ii >= 0; --ii ) \ + { \ + if( !is_unitdiag ) \ + { \ + DIVIDE_COMPLEX(c_reg[ii], (a11 + ii*cs_a)) \ + } \ + UNROLL_LOOP_FULL() \ + for( dim_t jj = (ii-1); jj >= 0; --jj ) \ + { \ + t_reg[0] = _mm512_set1_pd((a11 + jj*cs_a)->real); \ + t_reg[1] = _mm512_set1_pd((a11 + jj*cs_a)->imag); \ + MULTIPLY_COMPLEX(c_reg[ii], t_reg[0], t_reg[1], t_reg[2]) \ + c_reg[jj] = _mm512_sub_pd(c_reg[jj], t_reg[2]); \ + } \ + a11 -= rs_a; \ + } \ + +/* +* Stores output from registors(c_reg) to memory(B) +* n is a compile time constant. +*/ +#define STORE_RIGHT_C( n ) \ + UNROLL_LOOP_FULL() \ + for ( dim_t ii=0; ii < n; ++ii ) \ + { \ + _mm512_mask_storeu_pd((b11 + (ii * cs_b)), mask_m, c_reg[ii]); \ + } \ + +/* +* Perform GEMM + TRSM computation for Right Upper NonTranpose +* +* +* Left shift 1 by M times will set (M+1)th least significant bit +* subtracting 1 from that will unset (M+1)th LSB and set last M lSBs +* +* Example: 1 << 4 = 0b00010000 +* ( 1 << 4 ) - 1 = 0b00001111 +*/ +#define RUNN_FRINGE( M, N ) \ + mask_m = (1 << (M*2)) - 1; \ + \ + a01 = L + j*cs_a; \ + a11 = L + j*cs_a + j*rs_a; \ + b10 = B + i; \ + b11 = B + i + j*cs_b; \ + k_iter = j; \ + \ + ZERO_REGISTERS() \ + \ + GEMM_MxN( a01, b10, rs_a, cs_a, cs_b, k_iter, M, N ) \ + PRE_TRSM_NxM( AlphaVal, b11, cs_b, M, N ) \ + \ + t_reg[4] = _mm512_set1_pd(-1.0); \ + TRSM_MAIN_RUN_NxM( N ) \ + STORE_RIGHT_C( N ) \ + +/* +* Perform GEMM + TRSM computation for Right Lower NonTranpose +*/ +#define RLNN_FRINGE( M, N ) \ + mask_m = (1 << (M*2)) - 1; \ + \ + a01 = L + ((j - N + d_nr) * cs_a) + (j + d_nr) * rs_a; \ + a11 = L + (j - N + d_nr) * rs_a + (j - N + d_nr) * cs_a; \ + b10 = B + (i - M + d_mr) + (j + d_nr) * cs_b; \ + b11 = B + (i - M + d_mr) + (j - N + d_nr) * cs_b; \ + k_iter = (n - j - d_nr); \ + \ + ZERO_REGISTERS() \ + GEMM_MxN( a01, b10, rs_a, cs_a, cs_b, k_iter, M, N ) \ + PRE_TRSM_NxM( AlphaVal, b11, cs_b, M, N ) \ + \ + t_reg[4] = _mm512_set1_pd(-1.0); \ + TRSM_MAIN_RLNN_NXM( N ) \ + STORE_RIGHT_C( N ) \ + +; + +/* +* Solves Right Upper NonTranspose TRSM when N < 4 +*/ +BLIS_INLINE void runn_n_rem + ( + dim_t i, + dim_t j, + dim_t cs_a, + dim_t rs_a, + dim_t cs_b, + dim_t m, + dim_t n, + dcomplex* L, + dcomplex* B, + dim_t k_iter, + bool transa, + dcomplex AlphaVal, + bool is_unitdiag + ) +{ + __m512d t_reg[6]; + __m512d c_reg[8]; + __m512d b_reg[4]; + + double g_double[3]; + __mmask8 mask_m; + + t_reg[5] = _mm512_set1_pd(1.0); + + dim_t d_mr = 4; + dcomplex *a01, *a11, *b10, *b11; + dim_t m_rem; + dim_t n_rem = n - j; + + /* + * Switch statements used here to make sure that + * N is a constant and compiler can unroll the loop + * at compile time. + */ + switch( n_rem ) + { + case 1: + for( i = 0; (i+d_mr-1) < m; i += d_mr ) + { + RUNN_FRINGE( 4, 1 ) + } + m_rem = m - i; + if( m_rem > 0 ) + { + RUNN_FRINGE( m_rem, 1 ) + } + break; + case 2: + for( i = 0; (i+d_mr-1) < m; i += d_mr ) + { + RUNN_FRINGE( 4, 2 ) + } + m_rem = m - i; + if( m_rem > 0 ) + { + RUNN_FRINGE( m_rem, 2 ) + } + break; + case 3: + for( i = 0; (i+d_mr-1) < m; i += d_mr ) + { + RUNN_FRINGE( 4, 3 ) + } + m_rem = m - i; + if( m_rem > 0 ) + { + RUNN_FRINGE( m_rem, 3 ) + } + break; + default: + break; + } +} + +// RUNN - RLTN +err_t bli_ztrsm_small_XAltB_XAuB_AVX512 + ( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl + ) +{ + INIT() + if( transa ) + { + /* + * If variants being solved is RLTN + * then after swapping rs_a and cs_a, + * problem will become same as RUNN + */ + i = cs_a; + cs_a = rs_a; + rs_a = i; + } + dcomplex *a01, *a11, *b10, *b11; + for( j = 0; (j+d_nr-1) < n; j += d_nr ) + { + for( i = 0; (i+d_mr-1) < m; i += d_mr ) + { + RUNN_FRINGE( 4, 4 ) + } + dim_t m_rem = m - i; + if( m_rem > 0 ) + { + RUNN_FRINGE( m_rem, 4 ) + } + } + dim_t n_rem = n - j; + if( n_rem > 0 ) + { + /* + * A hack: + * clang/aocc generate inefficient code when + * all M and N are handled in one function. + * (AOCC tries to make sure that each of the gemm call is + * using independent set of registors, which causes many + * read/writes in stack.) + * So part of code is moved to a seperate function. + */ + runn_n_rem + ( + i, j, + cs_a, rs_a, + cs_b, + m, n, + L, B, + k_iter, + transa, + AlphaVal, + is_unitdiag + ); + } + return BLIS_SUCCESS; +} + +/* +* Solves Right Upper NonTranspose TRSM when N < 4 +*/ +BLIS_INLINE void rlnn_n_rem + ( + dim_t i, dim_t j, + dim_t cs_a, dim_t rs_a, + dim_t cs_b, + dim_t m, dim_t n, + dcomplex* L, + dcomplex* B, + dim_t k_iter, + bool transa, + dcomplex AlphaVal, + bool is_unitdiag + ) +{ + __m512d t_reg[6]; + __m512d c_reg[8]; + __m512d b_reg[4]; + + double g_double[3]; + __mmask8 mask_m; + + t_reg[5] = _mm512_set1_pd(1.0); + dim_t d_mr = 4; + dim_t d_nr = 4; + + dcomplex *a01, *a11, *b10, *b11; + dim_t m_rem; + dim_t n_rem = j + d_nr; + + switch( n_rem ) + { + case 1: + for( i = (m - d_mr); (i + 1) > 0; i -= d_mr ) + { + RLNN_FRINGE( 4, 1 ) + } + m_rem = i + d_mr; + if( m_rem > 0 ) + { + RLNN_FRINGE( m_rem, 1 ) + } + break; + case 2: + for( i = (m - d_mr); (i + 1) > 0; i -= d_mr ) + { + RLNN_FRINGE( 4, 2 ) + } + m_rem = i + d_mr; + if( m_rem > 0 ) + { + RLNN_FRINGE( m_rem, 2 ) + } + break; + case 3: + for( i = (m - d_mr); (i + 1) > 0; i -= d_mr ) + { + RLNN_FRINGE( 4, 3 ) + } + m_rem = i + d_mr; + if( m_rem > 0 ) + { + RLNN_FRINGE( m_rem, 3 ) + } + break; + default: + break; + } +} + +// RLNN - RUTNs +err_t bli_ztrsm_small_XAutB_XAlB_AVX512 + ( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl + ) +{ + INIT() + if( transa ) + { + /* + * If variants being solved is RUTN + * then after swapping rs_a and cs_a, + * problem will become same as RLNN + */ + i = cs_a; + cs_a = rs_a; + rs_a = i; + } + dcomplex *a01, *a11, *b10, *b11; + + for ( j = (n - d_nr); j > -1; j -= d_nr ) + { + for ( i = (m - d_mr); (i + 1) > 0; i -= d_mr ) + { + RLNN_FRINGE( 4, 4 ) + } + dim_t m_rem = i + d_mr; + if( m_rem > 0 ) + { + RLNN_FRINGE( m_rem, 4 ) + } + } + dim_t n_rem = j + d_nr; + if( n_rem > 0 ) + { + rlnn_n_rem + ( + i, j, + cs_a, rs_a, + cs_b, + m, n, + L, B, + k_iter, + transa, + AlphaVal, + is_unitdiag + ); + } + return BLIS_SUCCESS; +} + +/* +* Perform a 4x4 Transpose +* Data is read from c_reg[0] to c[4] +* and stored back to same registors after transpose +*/ +#define TRANSPOSE4x4() \ + t_reg[0] = _mm512_shuffle_f64x2(c_reg[0], c_reg[1], 0b10001000); \ + t_reg[1] = _mm512_shuffle_f64x2(c_reg[2], c_reg[3], 0b10001000); \ + t_reg[2] = _mm512_shuffle_f64x2(c_reg[0], c_reg[1], 0b11011101); \ + t_reg[3] = _mm512_shuffle_f64x2(c_reg[2], c_reg[3], 0b11011101); \ + \ + c_reg[0] = _mm512_shuffle_f64x2(t_reg[0], t_reg[1], 0b10001000); \ + c_reg[2] = _mm512_shuffle_f64x2(t_reg[0], t_reg[1], 0b11011101); \ + c_reg[1] = _mm512_shuffle_f64x2(t_reg[2], t_reg[3], 0b10001000); \ + c_reg[3] = _mm512_shuffle_f64x2(t_reg[2], t_reg[3], 0b11011101); \ + + +/* +* Perform GEMM when B is stored in row major order, +* k_iter is a multiple of 4 +*/ +#define GEMM_MxN_LEFT_TRANSPOSE( a01_, b10_, rs_a_, cs_a_, rs_b_, k_iter_, M_, N_ ) \ + \ + for( dim_t ii=0; ii < k_iter_/4; ++ii ) \ + { \ + /* load 4x4 B */ \ + for( dim_t jj=0; jj < M_; ++jj ) \ + { \ + b_reg[jj] = _mm512_loadu_pd(b10_ + (jj*rs_b_)); \ + } \ + /* Transpose 4x4 B*/ \ + t_reg[0] = _mm512_shuffle_f64x2(b_reg[0], b_reg[1], 0b10001000); \ + t_reg[1] = _mm512_shuffle_f64x2(b_reg[2], b_reg[3], 0b10001000); \ + t_reg[2] = _mm512_shuffle_f64x2(b_reg[0], b_reg[1], 0b11011101); \ + t_reg[3] = _mm512_shuffle_f64x2(b_reg[2], b_reg[3], 0b11011101); \ + b_reg[0] = _mm512_shuffle_f64x2(t_reg[0], t_reg[1], 0b10001000); \ + b_reg[2] = _mm512_shuffle_f64x2(t_reg[0], t_reg[1], 0b11011101); \ + b_reg[1] = _mm512_shuffle_f64x2(t_reg[2], t_reg[3], 0b10001000); \ + b_reg[3] = _mm512_shuffle_f64x2(t_reg[2], t_reg[3], 0b11011101); \ + \ + /*Iter 1*/ \ + UNROLL_LOOP_FULL() \ + for( dim_t jj=0; jj < N_; ++jj ) \ + { \ + t_reg[0] = _mm512_set1_pd((a01_ + cs_a_*jj)->real); \ + t_reg[1] = _mm512_set1_pd((a01_ + cs_a_*jj)->imag); \ + c_reg[jj] = _mm512_fmadd_pd(t_reg[0], b_reg[0], c_reg[jj]); \ + c_reg[jj+4] = _mm512_fmadd_pd(t_reg[1], b_reg[0], c_reg[jj+4]); \ + } \ + a01_ += rs_a_; \ + /*Iter 2*/ \ + UNROLL_LOOP_FULL() \ + for( dim_t jj=0; jj < N_; ++jj ) \ + { \ + t_reg[0] = _mm512_set1_pd((a01_ + cs_a_*jj)->real); \ + t_reg[1] = _mm512_set1_pd((a01_ + cs_a_*jj)->imag); \ + c_reg[jj] = _mm512_fmadd_pd(t_reg[0], b_reg[1], c_reg[jj]); \ + c_reg[jj+4] = _mm512_fmadd_pd(t_reg[1], b_reg[1], c_reg[jj+4]); \ + } \ + a01_ += rs_a_; \ + /*Iter 3*/ \ + UNROLL_LOOP_FULL() \ + for( dim_t jj=0; jj < N_; ++jj ) \ + { \ + t_reg[0] = _mm512_set1_pd((a01_ + cs_a_*jj)->real); \ + t_reg[1] = _mm512_set1_pd((a01_ + cs_a_*jj)->imag); \ + c_reg[jj] = _mm512_fmadd_pd(t_reg[0], b_reg[2], c_reg[jj]); \ + c_reg[jj+4] = _mm512_fmadd_pd(t_reg[1], b_reg[2], c_reg[jj+4]); \ + } \ + a01_ += rs_a_; \ + /*Iter 4*/ \ + UNROLL_LOOP_FULL() \ + for( dim_t jj=0; jj < N_; ++jj ) \ + { \ + t_reg[0] = _mm512_set1_pd((a01_ + cs_a_*jj)->real); \ + t_reg[1] = _mm512_set1_pd((a01_ + cs_a_*jj)->imag); \ + c_reg[jj] = _mm512_fmadd_pd(t_reg[0], b_reg[3], c_reg[jj]); \ + c_reg[jj+4] = _mm512_fmadd_pd(t_reg[1], b_reg[3], c_reg[jj+4]); \ + } \ + a01_ += rs_a_; \ + b10_ += 4; \ + } \ + t_reg[5] = _mm512_set1_pd(1.0); \ + UNROLL_LOOP_FULL() \ + for ( dim_t jj=0; jj < N_; ++jj ) \ + { \ + c_reg[jj+4] = _mm512_permute_pd(c_reg[jj+4], 0x55); \ + c_reg[jj] = _mm512_fmaddsub_pd(t_reg[5], c_reg[jj], c_reg[jj+4]); \ + } \ + +/* +* Perform GEMM + TRSM computation for Left Lower NonTranpose +* When Problem is LLNN, after a induced transpose problem +* becomes RUNN +*/ +#define LLNN_FRINGE( M, N ) \ + a10 = L + (i * cs_a); \ + a11 = L + (i * rs_a) + (i * cs_a); \ + b01 = B + j * cs_b; \ + b11 = B + i + j * cs_b; \ + \ + k_iter = i; \ + mask_m = (1 << (M*2)) - 1; \ + \ + ZERO_REGISTERS() \ + if (!transa) { \ + /*A and B are swapped are induced transpose*/ \ + GEMM_MxN( b01, a10, 1, cs_b, rs_a, k_iter, _, N ) \ + } else { \ + GEMM_MxN_LEFT_TRANSPOSE( b01, a10, 1, cs_b, cs_a, k_iter, M, N ) \ + } \ + PRE_TRSM_NxM( AlphaVal, b11, cs_b, _, N ) \ + /* + * RUNN kernel requires GEMM output to + * be in column major order + */ \ + TRANSPOSE4x4() \ + t_reg[4] = _mm512_set1_pd(-1.0); \ + TRSM_MAIN_RUN_NxM(M) \ + TRANSPOSE4x4() \ + STORE_RIGHT_C(N) \ + +/* +* Perform GEMM + TRSM computation for Left Upper NonTranpose +*/ +#define LUNN_FRINGE( M, N ) \ + mask_m = (1 << (M*2)) - 1; \ + \ + a10 = L + ((i - M + d_mr) * cs_a) + (i + d_nr) * rs_a; \ + a11 = L + (i - M + d_mr) * rs_a + (i - M + d_nr) * cs_a; \ + b01 = B + (i + d_mr) + (j - N + d_nr) * cs_b; \ + b11 = B + (i - M + d_mr) + (j - N + d_nr) * cs_b; \ + k_iter = ( m - i - d_mr ); \ + \ + ZERO_REGISTERS() \ + if (!transa) { \ + GEMM_MxN( b01, a10, 1, cs_b, rs_a, k_iter, _, N ) \ + } else { \ + GEMM_MxN_LEFT_TRANSPOSE( b01, a10, 1, cs_b, cs_a, k_iter, M, N ) \ + } \ + \ + PRE_TRSM_NxM( AlphaVal, b11, cs_b, _, N ) \ + TRANSPOSE4x4() \ + t_reg[4] = _mm512_set1_pd(-1.0); \ + TRSM_MAIN_RLNN_NXM( M ) \ + TRANSPOSE4x4() \ + STORE_RIGHT_C( N ) \ + +/* +* Solves Left Lower NonTranspose TRSM when M < 4 +*/ +BLIS_INLINE void llnn_m_rem + ( + dim_t i, dim_t j, + dim_t cs_a, dim_t rs_a, + dim_t cs_b, + dim_t m, dim_t n, + dcomplex* L, + dcomplex* B, + dim_t k_iter, + bool transa, + dcomplex AlphaVal, + bool is_unitdiag + ) +{ + __m512d t_reg[6]; + __m512d c_reg[8]; + __m512d b_reg[4]; + double g_double[3]; + + __mmask8 mask_m; + t_reg[5] = _mm512_set1_pd(1.0); + + dim_t d_nr = 4; + dcomplex *a10, *a11, *b01, *b11; + dim_t m_rem = m - i; + dim_t n_rem; + + switch( m_rem ) + { + case 1: + for( j = 0; (j + d_nr - 1) < n; j += d_nr ) + { + LLNN_FRINGE( 1, 4 ) + } + n_rem = n - j; + switch( n_rem ) + { + case 1: + LLNN_FRINGE( 1, 1 ); break; + case 2: + LLNN_FRINGE( 1, 2 ); break; + case 3: + LLNN_FRINGE( 1, 3 ); break; + default: + break; + } + break; + case 2: + for( j = 0; (j + d_nr - 1) < n; j += d_nr ) + { + LLNN_FRINGE( 2, 4 ) + } + n_rem = n - j; + switch( n_rem ) + { + case 1: + LLNN_FRINGE( 2, 1 ); break; + case 2: + LLNN_FRINGE( 2, 2 ); break; + case 3: + LLNN_FRINGE( 2, 3 ); break; + default: + break; + } + break; + case 3: + for( j = 0; (j + d_nr - 1) < n; j += d_nr ) + { + LLNN_FRINGE( 3, 4 ) + } + n_rem = n - j; + switch( n_rem ) + { + case 1: + LLNN_FRINGE( 3, 1 ); break; + case 2: + LLNN_FRINGE( 3, 2 ); break; + case 3: + LLNN_FRINGE( 3, 3 ); break; + default: + break; + } + break; + default: + break; + } +} + +// LLNN - LUTN +err_t bli_ztrsm_small_AutXB_AlXB_AVX512 + ( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl + ) +{ + INIT() + if( !transa ) + { + i = cs_a; + cs_a = rs_a; + rs_a = i; + } + dcomplex *a10, *a11, *b01, *b11; + for( i = 0; (i + d_mr - 1) < m; i += d_mr ) + { + for( j = 0; j < n - d_nr + 1; j += d_nr ) + { + LLNN_FRINGE( 4, 4 ) + } + dim_t n_rem = n - j; + if( n_rem > 0 ) + { + switch( n_rem ) + { + case 1: + LLNN_FRINGE( 4, 1 ); break; + case 2: + LLNN_FRINGE( 4, 2 ); break; + case 3: + LLNN_FRINGE( 4, 3 ); break; + default: + break; + } + } + } + dim_t m_rem = m - i; + if( m_rem > 0 ) + { + llnn_m_rem + ( + i, j, + cs_a, rs_a, + cs_b, + m, n, + L, B, + k_iter, + transa, + AlphaVal, + is_unitdiag + ); + } + return BLIS_SUCCESS; +} + +/* +* Solves Left Upper NonTranspose TRSM when M < 4 +*/ +BLIS_INLINE void lunn_m_rem + ( + dim_t i, dim_t j, + dim_t cs_a, dim_t rs_a, + dim_t cs_b, + dim_t m, dim_t n, + dcomplex* L, + dcomplex* B, + dim_t k_iter, + bool transa, + dcomplex AlphaVal, + bool is_unitdiag + ) +{ + __m512d t_reg[6]; + __m512d c_reg[8]; + __m512d b_reg[4]; + + double g_double[3]; + __mmask8 mask_m; + + t_reg[5] = _mm512_set1_pd(1.0); + dim_t d_mr = 4; + dim_t d_nr = 4; + dcomplex *a10, *a11, *b01, *b11; + dim_t m_rem = i + d_mr; + dim_t n_rem; + + switch( m_rem ) + { + case 1: + for( j = (n - d_nr); (j + 1) > 0; j -= d_nr ) + { + LUNN_FRINGE( 1, 4 ) + } + n_rem = j + d_nr; + switch( n_rem ) + { + case 1: + LUNN_FRINGE( 1, 1 ); break; + case 2: + LUNN_FRINGE( 1, 2 ); break; + case 3: + LUNN_FRINGE( 1, 3 ); break; + default: + break; + } + break; + case 2: + for( j = (n - d_nr); (j + 1) > 0; j -= d_nr ) + { + LUNN_FRINGE( 2, 4 ) + } + n_rem = j + d_nr; + switch( n_rem ) + { + case 1: + LUNN_FRINGE( 2, 1 ); break; + case 2: + LUNN_FRINGE( 2, 2 ); break; + case 3: + LUNN_FRINGE( 2, 3 ); break; + default: + break; + } + break; + case 3: + for( j = (n - d_nr); (j + 1) > 0; j -= d_nr ) + { + LUNN_FRINGE( 3, 4 ) + } + n_rem = j + d_nr; + switch( n_rem ) + { + case 1: + LUNN_FRINGE( 3, 1 ); break; + case 2: + LUNN_FRINGE( 3, 2 ); break; + case 3: + LUNN_FRINGE( 3, 3 ); break; + default: + break; + } + break; + default: + break; + } +} + +// LUNN - LLTN +err_t bli_ztrsm_small_AltXB_AuXB_AVX512 + ( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl + ) +{ + INIT() + if( !transa ) + { + i = cs_a; + cs_a = rs_a; + rs_a = i; + } + dcomplex *a10, *a11, *b01, *b11; + for( i = (m - d_mr); (i + 1) > 0; i -= d_mr ) + { + for( j = (n - d_nr); (j + 1) > 0; j -= d_nr ) + { + LUNN_FRINGE( 4, 4 ) + } + dim_t n_rem = j + d_nr; + if( n_rem > 0 ) + { + switch( n_rem ) + { + case 1: + LUNN_FRINGE( 4, 1 ); break; + case 2: + LUNN_FRINGE( 4, 2 ); break; + case 3: + LUNN_FRINGE( 4, 3 ); break; + default: + break; + } + } + } + dim_t m_rem = i + d_mr; + if( m_rem > 0 ) + { + lunn_m_rem + ( + i, j, + cs_a, rs_a, + cs_b, + m, n, + L, B, + k_iter, + transa, + AlphaVal, + is_unitdiag + ); + } + return BLIS_SUCCESS; +} + +#endif //BLIS_ENABLE_SMALL_MATRIX_TRSM + diff --git a/kernels/zen4/3/sup/bli_dgemmsup_rv_zen4_asm_24x8m.c b/kernels/zen4/3/sup/bli_dgemmsup_rv_zen4_asm_24x8m.c index 649aa416b5..8ee1ef8e08 100644 --- a/kernels/zen4/3/sup/bli_dgemmsup_rv_zen4_asm_24x8m.c +++ b/kernels/zen4/3/sup/bli_dgemmsup_rv_zen4_asm_24x8m.c @@ -4,32 +4,32 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer. - Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. - documentation and/or other materials provided with the distribution. - Neither the name(s) of the copyright holder(s) nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - AS IS AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY - OF TEXAS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, - EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, - PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR - PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY - OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ #include "blis.h" @@ -1780,10 +1780,8 @@ void bli_dgemmsup_rv_zen4_asm_24x8m [beta] "m" (beta), [c] "m" (c), [rs_c] "m" (rs_c), - [cs_c] "m" (cs_c), - [n0] "m" (n0), - [m0] "m" (m0) - : // register clobber list + [cs_c] "m" (cs_c) + : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm2", "xmm31", @@ -3277,8 +3275,6 @@ void bli_dgemmsup_rv_zen4_asm_24x7m [c] "m" (c), [rs_c] "m" (rs_c), [cs_c] "m" (cs_c), - [n0] "m" (n0), - [m0] "m" (m0), [mask] "m" (mask) : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", @@ -4662,8 +4658,6 @@ void bli_dgemmsup_rv_zen4_asm_24x6m [c] "m" (c), [rs_c] "m" (rs_c), [cs_c] "m" (cs_c), - [n0] "m" (n0), - [m0] "m" (m0), [mask] "m" (mask) : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", @@ -5930,8 +5924,6 @@ void bli_dgemmsup_rv_zen4_asm_24x5m [c] "m" (c), [rs_c] "m" (rs_c), [cs_c] "m" (cs_c), - [n0] "m" (n0), - [m0] "m" (m0), [mask] "m" (mask) : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", @@ -7071,8 +7063,6 @@ void bli_dgemmsup_rv_zen4_asm_24x4m [c] "m" (c), [rs_c] "m" (rs_c), [cs_c] "m" (cs_c), - [n0] "m" (n0), - [m0] "m" (m0), [mask] "m" (mask) : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", @@ -8088,8 +8078,6 @@ void bli_dgemmsup_rv_zen4_asm_24x3m [c] "m" (c), [rs_c] "m" (rs_c), [cs_c] "m" (cs_c), - [n0] "m" (n0), - [m0] "m" (m0), [mask] "m" (mask) : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", @@ -8983,8 +8971,6 @@ void bli_dgemmsup_rv_zen4_asm_24x2m [c] "m" (c), [rs_c] "m" (rs_c), [cs_c] "m" (cs_c), - [n0] "m" (n0), - [m0] "m" (m0), [mask] "m" (mask) : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", @@ -9754,8 +9740,6 @@ void bli_dgemmsup_rv_zen4_asm_24x1m [c] "m" (c), [rs_c] "m" (rs_c), [cs_c] "m" (cs_c), - [n0] "m" (n0), - [m0] "m" (m0), [mask] "m" (mask) : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", diff --git a/kernels/zen4/3/sup/bli_gemmsup_cv_zen4_z12x4m.c b/kernels/zen4/3/sup/bli_gemmsup_cv_zen4_z12x4m.c index 4fc04901ca..f58ffd179b 100644 --- a/kernels/zen4/3/sup/bli_gemmsup_cv_zen4_z12x4m.c +++ b/kernels/zen4/3/sup/bli_gemmsup_cv_zen4_z12x4m.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -19,14 +19,14 @@ from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - AS IS AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY - OF TEXAS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, - EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, - PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR - PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY - OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/kernels/zen4/3/sup/bli_gemmsup_rd_zen_s6x64.c b/kernels/zen4/3/sup/bli_gemmsup_rd_zen_s6x64.c index 96fa63e95d..c8b857eab1 100644 --- a/kernels/zen4/3/sup/bli_gemmsup_rd_zen_s6x64.c +++ b/kernels/zen4/3/sup/bli_gemmsup_rd_zen_s6x64.c @@ -1,9 +1,10 @@ /* + BLIS An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -28,6 +29,7 @@ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ #include "blis.h" @@ -94,7 +96,7 @@ void bli_sgemmsup_rd_zen_asm_5x64_avx512 lea( mem( , r15, 1 ), rsi ) imul( imm( 1*4 ), rsi ) lea( mem( r12, rsi, 1 ), r12 ) // c += r15 * cs_c - + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; imul( r9, rsi ) // rsi *= cs_b; lea( mem( rdx, rsi, 1 ), rdx ) // rbx = b + 4*jj*cs_b; @@ -102,7 +104,7 @@ void bli_sgemmsup_rd_zen_asm_5x64_avx512 lea( mem( r12 ), rcx ) // load c to rcx lea( mem( r14 ), rax ) // load a to rax lea( mem( rdx ), rbx ) // load b to rbx - + lea( mem( r8, r8, 2 ), r10 ) // r10 = 3 * rs_a lea( mem( r10, r8, 2 ), rdi ) // rdi = 5 * rs_a @@ -130,7 +132,7 @@ void bli_sgemmsup_rd_zen_asm_5x64_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA5( 11, 12, 13, 23, 24 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA5( 14, 15, 16, 26, 27 ) @@ -152,7 +154,7 @@ void bli_sgemmsup_rd_zen_asm_5x64_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA5( 11, 12, 13, 23, 24 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA5( 14, 15, 16, 26, 27 ) @@ -175,7 +177,7 @@ void bli_sgemmsup_rd_zen_asm_5x64_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA5( 11, 12, 13, 23, 24 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA5( 14, 15, 16, 26, 27 ) @@ -198,7 +200,7 @@ void bli_sgemmsup_rd_zen_asm_5x64_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA5( 11, 12, 13, 23, 24 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA5( 14, 15, 16, 26, 27 ) @@ -233,7 +235,7 @@ void bli_sgemmsup_rd_zen_asm_5x64_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA5( 11, 12, 13, 23, 24 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA5( 14, 15, 16, 26, 27 ) @@ -256,7 +258,7 @@ void bli_sgemmsup_rd_zen_asm_5x64_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA5( 11, 12, 13, 23, 24 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA5( 14, 15, 16, 26, 27 ) @@ -268,7 +270,7 @@ void bli_sgemmsup_rd_zen_asm_5x64_avx512 dec( rsi ) jne( .K_LOOP_ITER32 ) - + label( .CONSIDER_K_ITER_8 ) mov( var( k_iter8 ), rsi ) test( rsi, rsi ) @@ -276,7 +278,9 @@ void bli_sgemmsup_rd_zen_asm_5x64_avx512 label( .K_LOOP_ITER8 ) // ITER 0 - // load row from A + // Load row from A using ymm registers + // Upper 256-bit lanes are cleared for the + // zmm counterpart vmovups( ( rax ), ymm0 ) vmovups( ( rax, r8, 1 ), ymm1 ) vmovups( ( rax, r8, 2 ), ymm2 ) @@ -284,20 +288,23 @@ void bli_sgemmsup_rd_zen_asm_5x64_avx512 vmovups( ( rax, r8, 4 ), ymm4 ) add( imm( 8*4 ), rax ) - // load column from B + // Load column from B using ymm registers + // Upper 256-bit lane is cleared for the + // zmm counterpart + // Thus, we can re-use the VFMA6 macro vmovups( ( rbx ), ymm6 ) VFMA5( 8, 9, 10, 20, 21 ) vmovups( ( rbx, r9, 1 ), ymm6 ) VFMA5( 11, 12, 13, 23, 24 ) - + vmovups( ( rbx, r9, 2 ), ymm6 ) VFMA5( 14, 15, 16, 26, 27 ) vmovups( ( rbx, r13, 1 ), ymm6 ) VFMA5( 17, 18, 19, 29, 30 ) - add( imm( 8*4 ), rbx ) + add( imm( 8*4 ), rbx ) dec( rsi ) jne( .K_LOOP_ITER8 ) @@ -309,7 +316,11 @@ void bli_sgemmsup_rd_zen_asm_5x64_avx512 je( .POST_ACCUM ) label( .K_LOOP_LEFT1 ) - + + // Load row from A using xmm registers + // Upper 256-bit lanes and the upper 224 + // bits of the lower 256-bit lane are cleared + // for the zmm counterpart vmovss( ( rax ), xmm0 ) vmovss( ( rax, r8, 1 ), xmm1 ) vmovss( ( rax, r8, 2 ), xmm2 ) @@ -317,12 +328,17 @@ void bli_sgemmsup_rd_zen_asm_5x64_avx512 vmovss( ( rax, r8, 4 ), xmm4 ) add( imm( 1*4 ), rax ) + // Load column from B using xmm registers + // Upper 256-bit lanes and the upper 224 + // bits of the lower 256-bit lane are cleared + // for the zmm counterpart + // Thus, we can re-use the VFMA6 macro vmovss( ( rbx ), xmm6 ) VFMA5( 8, 9, 10, 20, 21 ) vmovss( ( rbx, r9, 1 ), xmm6 ) VFMA5( 11, 12, 13, 23, 24 ) - + vmovss( ( rbx, r9, 2 ), xmm6 ) VFMA5( 14, 15, 16, 26, 27 ) @@ -360,7 +376,7 @@ void bli_sgemmsup_rd_zen_asm_5x64_avx512 ZMM_TO_YMM( 8, 9, 10, 11, 4, 5, 6, 7 ) ZMM_TO_YMM( 12, 13, 14, 15, 8, 9, 10, 11 ) ZMM_TO_YMM( 16, 17, 18, 19, 12, 13, 14, 15 ) - + // Accumulates the results by horizontally adding the YMM registers, // and having the final result in xmm registers. ACCUM_YMM( 4, 7, 10, 13, 4 ) @@ -390,7 +406,7 @@ void bli_sgemmsup_rd_zen_asm_5x64_avx512 ZMM_TO_YMM( 8, 9, 10, 11, 4, 5, 6, 7 ) ZMM_TO_YMM( 12, 13, 14, 15, 8, 9, 10, 11 ) ZMM_TO_YMM( 16, 17, 18, 19, 12, 13, 14, 15 ) - + ACCUM_YMM( 4, 7, 10, 13, 4 ) ACCUM_YMM( 5, 8, 11, 14, 5 ) ACCUM_YMM( 6, 9, 12, 15, 6 ) @@ -524,7 +540,7 @@ void bli_sgemmsup_rd_zen_asm_4x64_avx512 lea( mem( , r15, 1 ), rsi ) imul( imm( 1*4 ), rsi ) lea( mem( r12, rsi, 1 ), r12 ) // c += r15 * cs_c - + lea( mem( , r15, 1 ), rsi ) // rsi = r15 = 4*jj; imul( r9, rsi ) // rsi *= cs_b; lea( mem( rdx, rsi, 1 ), rdx ) // rbx = b + 4*jj*cs_b; @@ -532,7 +548,7 @@ void bli_sgemmsup_rd_zen_asm_4x64_avx512 lea( mem( r12 ), rcx ) // load c to rcx lea( mem( r14 ), rax ) // load a to rax lea( mem( rdx ), rbx ) // load b to rbx - + lea( mem( r8, r8, 2 ), r10 ) // r10 = 3 * rs_b lea( mem( r10, r8, 2 ), rdi ) // rdi = 5 * rs_b @@ -558,7 +574,7 @@ void bli_sgemmsup_rd_zen_asm_4x64_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA4( 11, 12, 13, 23 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA4( 14, 15, 16, 26 ) @@ -580,7 +596,7 @@ void bli_sgemmsup_rd_zen_asm_4x64_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA4( 11, 12, 13, 23 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA4( 14, 15, 16, 26 ) @@ -603,7 +619,7 @@ void bli_sgemmsup_rd_zen_asm_4x64_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA4( 11, 12, 13, 23 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA4( 14, 15, 16, 26 ) @@ -625,7 +641,7 @@ void bli_sgemmsup_rd_zen_asm_4x64_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA4( 11, 12, 13, 23 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA4( 14, 15, 16, 26 ) @@ -659,7 +675,7 @@ void bli_sgemmsup_rd_zen_asm_4x64_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA4( 11, 12, 13, 23 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA4( 14, 15, 16, 26 ) @@ -681,7 +697,7 @@ void bli_sgemmsup_rd_zen_asm_4x64_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA4( 11, 12, 13, 23 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA4( 14, 15, 16, 26 ) @@ -693,7 +709,7 @@ void bli_sgemmsup_rd_zen_asm_4x64_avx512 dec( rsi ) jne( .K_LOOP_ITER32 ) - + label( .CONSIDER_K_ITER_8 ) mov( var( k_iter8 ), rsi ) test( rsi, rsi ) @@ -701,27 +717,32 @@ void bli_sgemmsup_rd_zen_asm_4x64_avx512 label( .K_LOOP_ITER8 ) // ITER 0 - // load row from A + // Load row from A using ymm registers + // Upper 256-bit lanes are cleared for the + // zmm counterpart vmovups( ( rax ), ymm0 ) vmovups( ( rax, r8, 1 ), ymm1 ) vmovups( ( rax, r8, 2 ), ymm2 ) vmovups( ( rax, r10, 1 ), ymm3 ) add( imm( 8*4 ), rax ) - // load column from B + // Load column from B using ymm registers + // Upper 256-bit lane is cleared for the + // zmm counterpart + // Thus, we can re-use the VFMA6 macro vmovups( ( rbx ), ymm6 ) VFMA4( 8, 9, 10, 20 ) vmovups( ( rbx, r9, 1 ), ymm6 ) VFMA4( 11, 12, 13, 23 ) - + vmovups( ( rbx, r9, 2 ), ymm6 ) VFMA4( 14, 15, 16, 26 ) vmovups( ( rbx, r13, 1 ), ymm6 ) VFMA4( 17, 18, 19, 29 ) - add( imm( 8*4 ), rbx ) + add( imm( 8*4 ), rbx ) dec( rsi ) jne( .K_LOOP_ITER8 ) @@ -733,19 +754,28 @@ void bli_sgemmsup_rd_zen_asm_4x64_avx512 je( .POST_ACCUM ) label( .K_LOOP_LEFT1 ) - + + // Load row from A using xmm registers + // Upper 256-bit lanes and the upper 224 + // bits of the lower 256-bit lane are cleared + // for the zmm counterpart vmovss( ( rax ), xmm0 ) vmovss( ( rax, r8, 1 ), xmm1 ) vmovss( ( rax, r8, 2 ), xmm2 ) vmovss( ( rax, r10, 1 ), xmm3 ) add( imm( 1*4 ), rax ) // a += 1*cs_b = 1*4; + // Load column from B using xmm registers + // Upper 256-bit lanes and the upper 224 + // bits of the lower 256-bit lane are cleared + // for the zmm counterpart + // Thus, we can re-use the VFMA6 macro vmovss( ( rbx ), xmm6 ) VFMA4( 8, 9, 10, 20 ) vmovss( ( rbx, r9, 1 ), xmm6 ) VFMA4( 11, 12, 13, 23 ) - + vmovss( ( rbx, r9, 2 ), xmm6 ) VFMA4( 14, 15, 16, 26 ) @@ -761,7 +791,7 @@ void bli_sgemmsup_rd_zen_asm_4x64_avx512 mov( var( beta ), rax ) // load address of beta vbroadcastss( ( rax ), xmm0 ) - + vxorps( xmm1, xmm1, xmm1 ) vucomiss( xmm1, xmm0 ) // check if beta = 0 @@ -772,7 +802,7 @@ void bli_sgemmsup_rd_zen_asm_4x64_avx512 ZMM_TO_YMM( 8, 9, 10, 11, 4, 5, 6, 7 ) ZMM_TO_YMM( 12, 13, 14, 15, 8, 9, 10, 11 ) ZMM_TO_YMM( 16, 17, 18, 19, 12, 13, 14, 15 ) - + ACCUM_YMM( 4, 7, 10, 13, 4 ) ACCUM_YMM( 5, 8, 11, 14, 5 ) ACCUM_YMM( 6, 9, 12, 15, 6 ) @@ -796,7 +826,7 @@ void bli_sgemmsup_rd_zen_asm_4x64_avx512 ZMM_TO_YMM( 8, 9, 10, 11, 4, 5, 6, 7 ) ZMM_TO_YMM( 12, 13, 14, 15, 8, 9, 10, 11 ) ZMM_TO_YMM( 16, 17, 18, 19, 12, 13, 14, 15 ) - + ACCUM_YMM( 4, 7, 10, 13, 4 ) ACCUM_YMM( 5, 8, 11, 14, 5 ) ACCUM_YMM( 6, 9, 12, 15, 6 ) @@ -931,7 +961,7 @@ void bli_sgemmsup_rd_zen_asm_3x64_avx512 lea( mem( , r15, 1 ), rsi ) imul( imm( 1*4 ), rsi ) lea( mem( r12, rsi, 1 ), r12 ) // c += r15 * cs_c - + lea( mem( , r15, 1 ), rsi ) // rsi = r15 = 4*jj; imul( r9, rsi ) // rsi *= cs_b; lea( mem( rdx, rsi, 1 ), rdx ) // rbx = b + 4*jj*cs_b; @@ -939,7 +969,7 @@ void bli_sgemmsup_rd_zen_asm_3x64_avx512 lea( mem( r12 ), rcx ) // load c to rcx lea( mem( r14 ), rax ) // load a to rax lea( mem( rdx ), rbx ) // load b to rbx - + lea( mem( r8, r8, 2 ), r10 ) // r10 = 3 * rs_b lea( mem( r10, r8, 2 ), rdi ) // rdi = 5 * rs_b @@ -965,7 +995,7 @@ void bli_sgemmsup_rd_zen_asm_3x64_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA3( 11, 12, 13 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA3( 14, 15, 16 ) @@ -986,7 +1016,7 @@ void bli_sgemmsup_rd_zen_asm_3x64_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA3( 11, 12, 13 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA3( 14, 15, 16 ) @@ -1007,7 +1037,7 @@ void bli_sgemmsup_rd_zen_asm_3x64_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA3( 11, 12, 13 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA3( 14, 15, 16 ) @@ -1028,7 +1058,7 @@ void bli_sgemmsup_rd_zen_asm_3x64_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA3( 11, 12, 13 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA3( 14, 15, 16 ) @@ -1063,7 +1093,7 @@ void bli_sgemmsup_rd_zen_asm_3x64_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA3( 11, 12, 13 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA3( 14, 15, 16 ) @@ -1084,7 +1114,7 @@ void bli_sgemmsup_rd_zen_asm_3x64_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA3( 11, 12, 13 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA3( 14, 15, 16 ) @@ -1096,7 +1126,7 @@ void bli_sgemmsup_rd_zen_asm_3x64_avx512 dec( rsi ) jne( .K_LOOP_ITER32 ) - + label( .CONSIDER_K_ITER_8 ) mov( var( k_iter8 ), rsi ) test( rsi, rsi ) @@ -1105,26 +1135,31 @@ void bli_sgemmsup_rd_zen_asm_3x64_avx512 label( .K_LOOP_ITER8 ) // ITER 0 - // load row from A + // Load row from A using ymm registers + // Upper 256-bit lanes are cleared for the + // zmm counterpart vmovups( ( rax ), ymm0 ) vmovups( ( rax, r8, 1 ), ymm1 ) vmovups( ( rax, r8, 2 ), ymm2 ) add( imm( 8*4 ), rax ) - // load column from B - vmovups( ( rbx ), zmm6 ) + // Load column from B using ymm registers + // Upper 256-bit lane is cleared for the + // zmm counterpart + // Thus, we can re-use the VFMA6 macro + vmovups( ( rbx ), ymm6 ) VFMA3( 8, 9, 10 ) - vmovups( ( rbx, r9, 1 ), zmm6 ) + vmovups( ( rbx, r9, 1 ), ymm6 ) VFMA3( 11, 12, 13 ) - - vmovups( ( rbx, r9, 2 ), zmm6 ) + + vmovups( ( rbx, r9, 2 ), ymm6 ) VFMA3( 14, 15, 16 ) - vmovups( ( rbx, r13, 1 ), zmm6 ) + vmovups( ( rbx, r13, 1 ), ymm6 ) VFMA3( 17, 18, 19 ) - add( imm( 8*4 ), rbx ) + add( imm( 8*4 ), rbx ) dec( rsi ) jne( .K_LOOP_ITER8 ) @@ -1136,18 +1171,27 @@ void bli_sgemmsup_rd_zen_asm_3x64_avx512 je( .POST_ACCUM ) label( .K_LOOP_LEFT1 ) - + + // Load row from A using xmm registers + // Upper 256-bit lanes and the upper 224 + // bits of the lower 256-bit lane are cleared + // for the zmm counterpart vmovss( ( rax ), xmm0 ) vmovss( ( rax, r8, 1 ), xmm1 ) vmovss( ( rax, r8, 2 ), xmm2 ) add( imm( 1*4 ), rax ) // a += 1*cs_b = 1*4; + // Load column from B using xmm registers + // Upper 256-bit lanes and the upper 224 + // bits of the lower 256-bit lane are cleared + // for the zmm counterpart + // Thus, we can re-use the VFMA6 macro vmovss( ( rbx ), xmm6 ) VFMA3( 8, 9, 10 ) vmovss( ( rbx, r9, 1 ), xmm6 ) VFMA3( 11, 12, 13 ) - + vmovss( ( rbx, r9, 2 ), xmm6 ) VFMA3( 14, 15, 16 ) @@ -1175,7 +1219,7 @@ void bli_sgemmsup_rd_zen_asm_3x64_avx512 ZMM_TO_YMM( 8, 9, 10, 11, 4, 5, 6, 7 ) ZMM_TO_YMM( 12, 13, 14, 15, 8, 9, 10, 11 ) ZMM_TO_YMM( 16, 17, 18, 19, 12, 13, 14, 15 ) - + ACCUM_YMM( 4, 7, 10, 13, 4 ) ACCUM_YMM( 5, 8, 11, 14, 5 ) ACCUM_YMM( 6, 9, 12, 15, 6 ) @@ -1193,7 +1237,7 @@ void bli_sgemmsup_rd_zen_asm_3x64_avx512 ZMM_TO_YMM( 8, 9, 10, 11, 4, 5, 6, 7 ) ZMM_TO_YMM( 12, 13, 14, 15, 8, 9, 10, 11 ) ZMM_TO_YMM( 16, 17, 18, 19, 12, 13, 14, 15 ) - + ACCUM_YMM( 4, 7, 10, 13, 4 ) ACCUM_YMM( 5, 8, 11, 14, 5 ) ACCUM_YMM( 6, 9, 12, 15, 6 ) @@ -1318,7 +1362,7 @@ void bli_sgemmsup_rd_zen_asm_2x64_avx512 lea( mem( , r15, 1 ), rsi ) imul( imm( 1*4 ), rsi ) lea( mem( r12, rsi, 1 ), r12 ) // c += r15 * cs_c - + lea( mem( , r15, 1 ), rsi ) // rsi = r15 = 4*jj; imul( r9, rsi ) // rsi *= cs_b; lea( mem( rdx, rsi, 1 ), rdx ) // rbx = b + 4*jj*cs_b; @@ -1326,10 +1370,10 @@ void bli_sgemmsup_rd_zen_asm_2x64_avx512 lea( mem( r12 ), rcx ) // load c to rcx lea( mem( r14 ), rax ) // load a to rax lea( mem( rdx ), rbx ) // load b to rbx - + lea( mem( r8, r8, 2 ), r10 ) // r10 = 3 * rs_b lea( mem( r10, r8, 2 ), rdi ) // rdi = 5 * rs_b - + INIT_REG @@ -1352,7 +1396,7 @@ void bli_sgemmsup_rd_zen_asm_2x64_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA2( 11, 12 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA2( 14, 15 ) @@ -1372,7 +1416,7 @@ void bli_sgemmsup_rd_zen_asm_2x64_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA2( 11, 12 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA2( 14, 15 ) @@ -1393,7 +1437,7 @@ void bli_sgemmsup_rd_zen_asm_2x64_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA2( 11, 12 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA2( 14, 15 ) @@ -1413,7 +1457,7 @@ void bli_sgemmsup_rd_zen_asm_2x64_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA2( 11, 12 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA2( 14, 15 ) @@ -1447,7 +1491,7 @@ void bli_sgemmsup_rd_zen_asm_2x64_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA2( 11, 12 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA2( 14, 15 ) @@ -1467,7 +1511,7 @@ void bli_sgemmsup_rd_zen_asm_2x64_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA2( 11, 12 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA2( 14, 15 ) @@ -1479,7 +1523,7 @@ void bli_sgemmsup_rd_zen_asm_2x64_avx512 dec( rsi ) jne( .K_LOOP_ITER32 ) - + label( .CONSIDER_K_ITER_8 ) mov( var( k_iter8 ), rsi ) test( rsi, rsi ) @@ -1488,25 +1532,30 @@ void bli_sgemmsup_rd_zen_asm_2x64_avx512 label( .K_LOOP_ITER8 ) // ITER 0 - // load row from A + // Load row from A using ymm registers + // Upper 256-bit lanes are cleared for the + // zmm counterpart vmovups( ( rax ), ymm0 ) vmovups( ( rax, r8, 1 ), ymm1 ) add( imm( 8*4 ), rax ) - // load column from B + // Load column from B using ymm registers + // Upper 256-bit lane is cleared for the + // zmm counterpart + // Thus, we can re-use the VFMA6 macro vmovups( ( rbx ), ymm6 ) VFMA2( 8, 9 ) vmovups( ( rbx, r9, 1 ), ymm6 ) VFMA2( 11, 12 ) - + vmovups( ( rbx, r9, 2 ), ymm6 ) VFMA2( 14, 15 ) vmovups( ( rbx, r13, 1 ), ymm6 ) VFMA2( 17, 18 ) - add( imm( 8*4 ), rbx ) + add( imm( 8*4 ), rbx ) dec( rsi ) jne( .K_LOOP_ITER8 ) @@ -1519,17 +1568,26 @@ void bli_sgemmsup_rd_zen_asm_2x64_avx512 label( .K_LOOP_LEFT1 ) - + + // Load row from A using xmm registers + // Upper 256-bit lanes and the upper 224 + // bits of the lower 256-bit lane are cleared + // for the zmm counterpart vmovss( ( rax ), xmm0 ) vmovss( ( rax, r8, 1 ), xmm1 ) add( imm( 1*4 ), rax ) // a += 1*cs_b = 1*4; + // Load column from B using xmm registers + // Upper 256-bit lanes and the upper 224 + // bits of the lower 256-bit lane are cleared + // for the zmm counterpart + // Thus, we can re-use the VFMA6 macro vmovss( ( rbx ), xmm6 ) VFMA2( 8, 9 ) vmovss( ( rbx, r9, 1 ), xmm6 ) VFMA2( 11, 12 ) - + vmovss( ( rbx, r9, 2 ), xmm6 ) VFMA2( 14, 15 ) @@ -1572,7 +1630,7 @@ void bli_sgemmsup_rd_zen_asm_2x64_avx512 ZMM_TO_YMM( 8, 9, 11, 12, 4, 5, 7, 8 ) ZMM_TO_YMM( 14, 15, 17, 18, 10, 11, 13, 14 ) - + ACCUM_YMM( 4, 7, 10, 13, 4 ) ACCUM_YMM( 5, 8, 11, 14, 5 ) @@ -1686,7 +1744,7 @@ void bli_sgemmsup_rd_zen_asm_1x64_avx512 lea( mem( , r15, 1 ), rsi ) imul( imm( 1*4 ), rsi ) lea( mem( r12, rsi, 1 ), r12 ) // c += r15 * cs_c - + lea( mem( , r15, 1 ), rsi ) // rsi = r15 = 4*jj; imul( r9, rsi ) // rsi *= cs_b; lea( mem( rdx, rsi, 1 ), rdx ) // rbx = b + 4*jj*cs_b; @@ -1694,10 +1752,10 @@ void bli_sgemmsup_rd_zen_asm_1x64_avx512 lea( mem( r12 ), rcx ) // load c to rcx lea( mem( r14 ), rax ) // load a to rax lea( mem( rdx ), rbx ) // load b to rbx - + lea( mem( r8, r8, 2 ), r10 ) // r10 = 3 * rs_b lea( mem( r10, r8, 2 ), rdi ) // rdi = 5 * rs_b - + INIT_REG @@ -1719,7 +1777,7 @@ void bli_sgemmsup_rd_zen_asm_1x64_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA1( 11 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA1( 14 ) @@ -1738,7 +1796,7 @@ void bli_sgemmsup_rd_zen_asm_1x64_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA1( 11 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA1( 14 ) @@ -1757,7 +1815,7 @@ void bli_sgemmsup_rd_zen_asm_1x64_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA1( 11 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA1( 14 ) @@ -1776,7 +1834,7 @@ void bli_sgemmsup_rd_zen_asm_1x64_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA1( 11 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA1( 14 ) @@ -1808,7 +1866,7 @@ void bli_sgemmsup_rd_zen_asm_1x64_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA1( 11 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA1( 14 ) @@ -1827,7 +1885,7 @@ void bli_sgemmsup_rd_zen_asm_1x64_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA1( 11 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA1( 14 ) @@ -1839,7 +1897,7 @@ void bli_sgemmsup_rd_zen_asm_1x64_avx512 dec( rsi ) jne( .K_LOOP_ITER32 ) - + label( .CONSIDER_K_ITER_8 ) mov( var( k_iter8 ), rsi ) test( rsi, rsi ) @@ -1848,24 +1906,29 @@ void bli_sgemmsup_rd_zen_asm_1x64_avx512 label( .K_LOOP_ITER8 ) // ITER 0 - // load row from A + // Load row from A using ymm registers + // Upper 256-bit lanes are cleared for the + // zmm counterpart vmovups( ( rax ), ymm0 ) add( imm( 8*4 ), rax ) - // load column from B + // Load column from B using ymm registers + // Upper 256-bit lane is cleared for the + // zmm counterpart + // Thus, we can re-use the VFMA6 macro vmovups( ( rbx ), ymm6 ) VFMA1( 8 ) vmovups( ( rbx, r9, 1 ), ymm6 ) VFMA1( 11 ) - + vmovups( ( rbx, r9, 2 ), ymm6 ) VFMA1( 14 ) vmovups( ( rbx, r13, 1 ), ymm6 ) VFMA1( 17 ) - add( imm( 8*4 ), rbx ) + add( imm( 8*4 ), rbx ) dec( rsi ) jne( .K_LOOP_ITER8 ) @@ -1878,16 +1941,25 @@ void bli_sgemmsup_rd_zen_asm_1x64_avx512 label( .K_LOOP_LEFT1 ) - + + // Load row from A using xmm registers + // Upper 256-bit lanes and the upper 224 + // bits of the lower 256-bit lane are cleared + // for the zmm counterpart vmovss( ( rax ), xmm0 ) add( imm( 1*4 ), rax ) // a += 1*cs_b = 1*4; + // Load column from B using xmm registers + // Upper 256-bit lanes and the upper 224 + // bits of the lower 256-bit lane are cleared + // for the zmm counterpart + // Thus, we can re-use the VFMA6 macro vmovss( ( rbx ), xmm6 ) VFMA1( 8 ) vmovss( ( rbx, r9, 1 ), xmm6 ) VFMA1( 11 ) - + vmovss( ( rbx, r9, 2 ), xmm6 ) VFMA1( 14 ) @@ -1913,7 +1985,7 @@ void bli_sgemmsup_rd_zen_asm_1x64_avx512 label( .POST_ACCUM_STOR ) ZMM_TO_YMM( 8, 11, 14, 17, 4, 7, 10, 13 ) - + ACCUM_YMM( 4, 7, 10, 13, 4 ) ALPHA_SCALE1 // Scaling the result of A*B with alpha @@ -1927,7 +1999,7 @@ void bli_sgemmsup_rd_zen_asm_1x64_avx512 label( .POST_ACCUM_STOR_BZ ) ZMM_TO_YMM( 8, 11, 14, 17, 4, 7, 10, 13 ) - + ACCUM_YMM( 4, 7, 10, 13, 4 ) ALPHA_SCALE1 // Scaling the result of A*B with alpha @@ -2040,7 +2112,7 @@ void bli_sgemmsup_rd_zen_asm_5x48_avx512 lea( mem( , r15, 1 ), rsi ) imul( imm( 1*4 ), rsi ) lea( mem( r12, rsi, 1 ), r12 ) // c += r15 * cs_c - + lea( mem( , r15, 1 ), rsi ) // rsi = r15 = 4*jj; imul( r9, rsi ) // rsi *= cs_b; lea( mem( rdx, rsi, 1 ), rdx ) // rbx = b + 4*jj*cs_b; @@ -2048,7 +2120,7 @@ void bli_sgemmsup_rd_zen_asm_5x48_avx512 lea( mem( r12 ), rcx ) // load c to rcx lea( mem( r14 ), rax ) // load a to rax lea( mem( rdx ), rbx ) // load b to rbx - + lea( mem( r8, r8, 2 ), r10 ) // r10 = 3 * rs_b lea( mem( r10, r8, 2 ), rdi ) // rdi = 5 * rs_b @@ -2076,7 +2148,7 @@ void bli_sgemmsup_rd_zen_asm_5x48_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA5( 11, 12, 13, 23, 24 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA5( 14, 15, 16, 26, 27 ) @@ -2099,7 +2171,7 @@ void bli_sgemmsup_rd_zen_asm_5x48_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA5( 11, 12, 13, 23, 24 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA5( 14, 15, 16, 26, 27 ) @@ -2123,7 +2195,7 @@ void bli_sgemmsup_rd_zen_asm_5x48_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA5( 11, 12, 13, 23, 24 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA5( 14, 15, 16, 26, 27 ) @@ -2146,7 +2218,7 @@ void bli_sgemmsup_rd_zen_asm_5x48_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA5( 11, 12, 13, 23, 24 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA5( 14, 15, 16, 26, 27 ) @@ -2183,7 +2255,7 @@ void bli_sgemmsup_rd_zen_asm_5x48_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA5( 11, 12, 13, 23, 24 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA5( 14, 15, 16, 26, 27 ) @@ -2206,7 +2278,7 @@ void bli_sgemmsup_rd_zen_asm_5x48_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA5( 11, 12, 13, 23, 24 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA5( 14, 15, 16, 26, 27 ) @@ -2218,7 +2290,7 @@ void bli_sgemmsup_rd_zen_asm_5x48_avx512 dec( rsi ) jne( .K_LOOP_ITER32 ) - + label( .CONSIDER_K_ITER_8 ) mov( var( k_iter8 ), rsi ) test( rsi, rsi ) @@ -2227,7 +2299,9 @@ void bli_sgemmsup_rd_zen_asm_5x48_avx512 label( .K_LOOP_ITER8 ) // ITER 0 - // load row from A + // Load row from A using ymm registers + // Upper 256-bit lanes are cleared for the + // zmm counterpart vmovups( ( rax ), ymm0 ) vmovups( ( rax, r8, 1 ), ymm1 ) vmovups( ( rax, r8, 2 ), ymm2 ) @@ -2235,20 +2309,23 @@ void bli_sgemmsup_rd_zen_asm_5x48_avx512 vmovups( ( rax, r8, 4 ), ymm4 ) add( imm( 8*4 ), rax ) - // load column from B + // Load column from B using ymm registers + // Upper 256-bit lane is cleared for the + // zmm counterpart + // Thus, we can re-use the VFMA6 macro vmovups( ( rbx ), ymm6 ) VFMA5( 8, 9, 10, 20, 21 ) vmovups( ( rbx, r9, 1 ), ymm6 ) VFMA5( 11, 12, 13, 23, 24 ) - + vmovups( ( rbx, r9, 2 ), ymm6 ) VFMA5( 14, 15, 16, 26, 27 ) vmovups( ( rbx, r13, 1 ), ymm6 ) VFMA5( 17, 18, 19, 29, 30 ) - add( imm( 8*4 ), rbx ) + add( imm( 8*4 ), rbx ) dec( rsi ) jne( .K_LOOP_ITER8 ) @@ -2261,7 +2338,11 @@ void bli_sgemmsup_rd_zen_asm_5x48_avx512 label( .K_LOOP_LEFT1 ) - + + // Load row from A using xmm registers + // Upper 256-bit lanes and the upper 224 + // bits of the lower 256-bit lane are cleared + // for the zmm counterpart vmovss( ( rax ), xmm0 ) vmovss( ( rax, r8, 1 ), xmm1 ) vmovss( ( rax, r8, 2 ), xmm2 ) @@ -2269,12 +2350,17 @@ void bli_sgemmsup_rd_zen_asm_5x48_avx512 vmovss( ( rax, r8, 4 ), xmm4 ) add( imm( 1*4 ), rax ) // a += 1*cs_b = 1*4; + // Load column from B using xmm registers + // Upper 256-bit lanes and the upper 224 + // bits of the lower 256-bit lane are cleared + // for the zmm counterpart + // Thus, we can re-use the VFMA6 macro vmovss( ( rbx ), xmm6 ) VFMA5( 8, 9, 10, 20, 21 ) vmovss( ( rbx, r9, 1 ), xmm6 ) VFMA5( 11, 12, 13, 23, 24 ) - + vmovss( ( rbx, r9, 2 ), xmm6 ) VFMA5( 14, 15, 16, 26, 27 ) @@ -2302,7 +2388,7 @@ void bli_sgemmsup_rd_zen_asm_5x48_avx512 ZMM_TO_YMM( 8, 9, 10, 11, 4, 5, 6, 7 ) ZMM_TO_YMM( 12, 13, 14, 15, 8, 9, 10, 11 ) ZMM_TO_YMM( 16, 17, 18, 19, 12, 13, 14, 15 ) - + ACCUM_YMM( 4, 7, 10, 13, 4 ) ACCUM_YMM( 5, 8, 11, 14, 5 ) ACCUM_YMM( 6, 9, 12, 15, 6 ) @@ -2330,7 +2416,7 @@ void bli_sgemmsup_rd_zen_asm_5x48_avx512 ZMM_TO_YMM( 8, 9, 10, 11, 4, 5, 6, 7 ) ZMM_TO_YMM( 12, 13, 14, 15, 8, 9, 10, 11 ) ZMM_TO_YMM( 16, 17, 18, 19, 12, 13, 14, 15 ) - + ACCUM_YMM( 4, 7, 10, 13, 4 ) ACCUM_YMM( 5, 8, 11, 14, 5 ) ACCUM_YMM( 6, 9, 12, 15, 6 ) @@ -2463,7 +2549,7 @@ void bli_sgemmsup_rd_zen_asm_4x48_avx512 lea( mem( , r15, 1 ), rsi ) imul( imm( 1*4 ), rsi ) lea( mem( r12, rsi, 1 ), r12 ) // c += r15 * cs_c - + lea( mem( , r15, 1 ), rsi ) // rsi = r15 = 4*jj; imul( r9, rsi ) // rsi *= cs_b; lea( mem( rdx, rsi, 1 ), rdx ) // rbx = b + 4*jj*cs_b; @@ -2471,7 +2557,7 @@ void bli_sgemmsup_rd_zen_asm_4x48_avx512 lea( mem( r12 ), rcx ) // load c to rcx lea( mem( r14 ), rax ) // load a to rax lea( mem( rdx ), rbx ) // load b to rbx - + lea( mem( r8, r8, 2 ), r10 ) // r10 = 3 * rs_b lea( mem( r10, r8, 2 ), rdi ) // rdi = 5 * rs_b @@ -2497,7 +2583,7 @@ void bli_sgemmsup_rd_zen_asm_4x48_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA4( 11, 12, 13, 23 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA4( 14, 15, 16, 26 ) @@ -2519,7 +2605,7 @@ void bli_sgemmsup_rd_zen_asm_4x48_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA4( 11, 12, 13, 23 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA4( 14, 15, 16, 26 ) @@ -2542,7 +2628,7 @@ void bli_sgemmsup_rd_zen_asm_4x48_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA4( 11, 12, 13, 23 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA4( 14, 15, 16, 26 ) @@ -2564,7 +2650,7 @@ void bli_sgemmsup_rd_zen_asm_4x48_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA4( 11, 12, 13, 23 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA4( 14, 15, 16, 26 ) @@ -2600,7 +2686,7 @@ void bli_sgemmsup_rd_zen_asm_4x48_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA4( 11, 12, 13, 23 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA4( 14, 15, 16, 26 ) @@ -2622,7 +2708,7 @@ void bli_sgemmsup_rd_zen_asm_4x48_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA4( 11, 12, 13, 23 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA4( 14, 15, 16, 26 ) @@ -2634,7 +2720,7 @@ void bli_sgemmsup_rd_zen_asm_4x48_avx512 dec( rsi ) jne( .K_LOOP_ITER32 ) - + label( .CONSIDER_K_ITER_8 ) mov( var( k_iter8 ), rsi ) test( rsi, rsi ) @@ -2643,27 +2729,32 @@ void bli_sgemmsup_rd_zen_asm_4x48_avx512 label( .K_LOOP_ITER8 ) // ITER 0 - // load row from A + // Load row from A using ymm registers + // Upper 256-bit lanes are cleared for the + // zmm counterpart vmovups( ( rax ), ymm0 ) vmovups( ( rax, r8, 1 ), ymm1 ) vmovups( ( rax, r8, 2 ), ymm2 ) vmovups( ( rax, r10, 1 ), ymm3 ) add( imm( 8*4 ), rax ) - // load column from B + // Load column from B using ymm registers + // Upper 256-bit lane is cleared for the + // zmm counterpart + // Thus, we can re-use the VFMA6 macro vmovups( ( rbx ), ymm6 ) VFMA4( 8, 9, 10, 20 ) vmovups( ( rbx, r9, 1 ), ymm6 ) VFMA4( 11, 12, 13, 23 ) - + vmovups( ( rbx, r9, 2 ), ymm6 ) VFMA4( 14, 15, 16, 26 ) vmovups( ( rbx, r13, 1 ), ymm6 ) VFMA4( 17, 18, 19, 29 ) - add( imm( 8*4 ), rbx ) + add( imm( 8*4 ), rbx ) dec( rsi ) jne( .K_LOOP_ITER8 ) @@ -2676,19 +2767,28 @@ void bli_sgemmsup_rd_zen_asm_4x48_avx512 label( .K_LOOP_LEFT1 ) - + + // Load row from A using xmm registers + // Upper 256-bit lanes and the upper 224 + // bits of the lower 256-bit lane are cleared + // for the zmm counterpart vmovss( ( rax ), xmm0 ) vmovss( ( rax, r8, 1 ), xmm1 ) vmovss( ( rax, r8, 2 ), xmm2 ) vmovss( ( rax, r10, 1 ), xmm3 ) add( imm( 1*4 ), rax ) // a += 1*cs_b = 1*4; + // Load column from B using xmm registers + // Upper 256-bit lanes and the upper 224 + // bits of the lower 256-bit lane are cleared + // for the zmm counterpart + // Thus, we can re-use the VFMA6 macro vmovss( ( rbx ), xmm6 ) VFMA4( 8, 9, 10, 20 ) vmovss( ( rbx, r9, 1 ), xmm6 ) VFMA4( 11, 12, 13, 23 ) - + vmovss( ( rbx, r9, 2 ), xmm6 ) VFMA4( 14, 15, 16, 26 ) @@ -2715,7 +2815,7 @@ void bli_sgemmsup_rd_zen_asm_4x48_avx512 ZMM_TO_YMM( 8, 9, 10, 11, 4, 5, 6, 7 ) ZMM_TO_YMM( 12, 13, 14, 15, 8, 9, 10, 11 ) ZMM_TO_YMM( 16, 17, 18, 19, 12, 13, 14, 15 ) - + ACCUM_YMM( 4, 7, 10, 13, 4 ) ACCUM_YMM( 5, 8, 11, 14, 5 ) ACCUM_YMM( 6, 9, 12, 15, 6 ) @@ -2741,7 +2841,7 @@ void bli_sgemmsup_rd_zen_asm_4x48_avx512 ZMM_TO_YMM( 8, 9, 10, 11, 4, 5, 6, 7 ) ZMM_TO_YMM( 12, 13, 14, 15, 8, 9, 10, 11 ) ZMM_TO_YMM( 16, 17, 18, 19, 12, 13, 14, 15 ) - + ACCUM_YMM( 4, 7, 10, 13, 4 ) ACCUM_YMM( 5, 8, 11, 14, 5 ) ACCUM_YMM( 6, 9, 12, 15, 6 ) @@ -2876,7 +2976,7 @@ void bli_sgemmsup_rd_zen_asm_3x48_avx512 lea( mem( , r15, 1 ), rsi ) imul( imm( 1*4 ), rsi ) lea( mem( r12, rsi, 1 ), r12 ) // c += r15 * cs_c - + lea( mem( , r15, 1 ), rsi ) // rsi = r15 = 4*jj; imul( r9, rsi ) // rsi *= cs_b; lea( mem( rdx, rsi, 1 ), rdx ) // rbx = b + 4*jj*cs_b; @@ -2884,7 +2984,7 @@ void bli_sgemmsup_rd_zen_asm_3x48_avx512 lea( mem( r12 ), rcx ) // load c to rcx lea( mem( r14 ), rax ) // load a to rax lea( mem( rdx ), rbx ) // load b to rbx - + lea( mem( r8, r8, 2 ), r10 ) // r10 = 3 * rs_b lea( mem( r10, r8, 2 ), rdi ) // rdi = 5 * rs_b @@ -2910,7 +3010,7 @@ void bli_sgemmsup_rd_zen_asm_3x48_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA3( 11, 12, 13 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA3( 14, 15, 16 ) @@ -2931,7 +3031,7 @@ void bli_sgemmsup_rd_zen_asm_3x48_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA3( 11, 12, 13 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA3( 14, 15, 16 ) @@ -2952,7 +3052,7 @@ void bli_sgemmsup_rd_zen_asm_3x48_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA3( 11, 12, 13 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA3( 14, 15, 16 ) @@ -2973,7 +3073,7 @@ void bli_sgemmsup_rd_zen_asm_3x48_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA3( 11, 12, 13 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA3( 14, 15, 16 ) @@ -3008,7 +3108,7 @@ void bli_sgemmsup_rd_zen_asm_3x48_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA3( 11, 12, 13 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA3( 14, 15, 16 ) @@ -3029,7 +3129,7 @@ void bli_sgemmsup_rd_zen_asm_3x48_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA3( 11, 12, 13 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA3( 14, 15, 16 ) @@ -3041,7 +3141,7 @@ void bli_sgemmsup_rd_zen_asm_3x48_avx512 dec( rsi ) jne( .K_LOOP_ITER32 ) - + label( .CONSIDER_K_ITER_8 ) mov( var( k_iter8 ), rsi ) test( rsi, rsi ) @@ -3050,26 +3150,31 @@ void bli_sgemmsup_rd_zen_asm_3x48_avx512 label( .K_LOOP_ITER8 ) // ITER 0 - // load row from A + // Load row from A using ymm registers + // Upper 256-bit lanes are cleared for the + // zmm counterpart vmovups( ( rax ), ymm0 ) vmovups( ( rax, r8, 1 ), ymm1 ) vmovups( ( rax, r8, 2 ), ymm2 ) add( imm( 8*4 ), rax ) - // load column from B - vmovups( ( rbx ), zmm6 ) + // Load column from B using ymm registers + // Upper 256-bit lane is cleared for the + // zmm counterpart + // Thus, we can re-use the VFMA6 macro + vmovups( ( rbx ), ymm6 ) VFMA3( 8, 9, 10 ) - vmovups( ( rbx, r9, 1 ), zmm6 ) + vmovups( ( rbx, r9, 1 ), ymm6 ) VFMA3( 11, 12, 13 ) - - vmovups( ( rbx, r9, 2 ), zmm6 ) + + vmovups( ( rbx, r9, 2 ), ymm6 ) VFMA3( 14, 15, 16 ) - vmovups( ( rbx, r13, 1 ), zmm6 ) + vmovups( ( rbx, r13, 1 ), ymm6 ) VFMA3( 17, 18, 19 ) - add( imm( 8*4 ), rbx ) + add( imm( 8*4 ), rbx ) dec( rsi ) jne( .K_LOOP_ITER8 ) @@ -3082,18 +3187,27 @@ void bli_sgemmsup_rd_zen_asm_3x48_avx512 label( .K_LOOP_LEFT1 ) - + + // Load row from A using xmm registers + // Upper 256-bit lanes and the upper 224 + // bits of the lower 256-bit lane are cleared + // for the zmm counterpart vmovss( ( rax ), xmm0 ) vmovss( ( rax, r8, 1 ), xmm1 ) vmovss( ( rax, r8, 2 ), xmm2 ) add( imm( 1*4 ), rax ) // a += 1*cs_b = 1*4; + // Load column from B using xmm registers + // Upper 256-bit lanes and the upper 224 + // bits of the lower 256-bit lane are cleared + // for the zmm counterpart + // Thus, we can re-use the VFMA6 macro vmovss( ( rbx ), xmm6 ) VFMA3( 8, 9, 10 ) vmovss( ( rbx, r9, 1 ), xmm6 ) VFMA3( 11, 12, 13 ) - + vmovss( ( rbx, r9, 2 ), xmm6 ) VFMA3( 14, 15, 16 ) @@ -3121,7 +3235,7 @@ void bli_sgemmsup_rd_zen_asm_3x48_avx512 ZMM_TO_YMM( 8, 9, 10, 11, 4, 5, 6, 7 ) ZMM_TO_YMM( 12, 13, 14, 15, 8, 9, 10, 11 ) ZMM_TO_YMM( 16, 17, 18, 19, 12, 13, 14, 15 ) - + ACCUM_YMM( 4, 7, 10, 13, 4 ) ACCUM_YMM( 5, 8, 11, 14, 5 ) ACCUM_YMM( 6, 9, 12, 15, 6 ) @@ -3139,7 +3253,7 @@ void bli_sgemmsup_rd_zen_asm_3x48_avx512 ZMM_TO_YMM( 8, 9, 10, 11, 4, 5, 6, 7 ) ZMM_TO_YMM( 12, 13, 14, 15, 8, 9, 10, 11 ) ZMM_TO_YMM( 16, 17, 18, 19, 12, 13, 14, 15 ) - + ACCUM_YMM( 4, 7, 10, 13, 4 ) ACCUM_YMM( 5, 8, 11, 14, 5 ) ACCUM_YMM( 6, 9, 12, 15, 6 ) @@ -3263,7 +3377,7 @@ void bli_sgemmsup_rd_zen_asm_2x48_avx512 lea( mem( , r15, 1 ), rsi ) imul( imm( 1*4 ), rsi ) lea( mem( r12, rsi, 1 ), r12 ) // c += r15 * cs_c - + lea( mem( , r15, 1 ), rsi ) // rsi = r15 = 4*jj; imul( r9, rsi ) // rsi *= cs_b; lea( mem( rdx, rsi, 1 ), rdx ) // rbx = b + 4*jj*cs_b; @@ -3271,10 +3385,10 @@ void bli_sgemmsup_rd_zen_asm_2x48_avx512 lea( mem( r12 ), rcx ) // load c to rcx lea( mem( r14 ), rax ) // load a to rax lea( mem( rdx ), rbx ) // load b to rbx - + lea( mem( r8, r8, 2 ), r10 ) // r10 = 3 * rs_b lea( mem( r10, r8, 2 ), rdi ) // rdi = 5 * rs_b - + INIT_REG @@ -3297,7 +3411,7 @@ void bli_sgemmsup_rd_zen_asm_2x48_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA2( 11, 12 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA2( 14, 15 ) @@ -3317,7 +3431,7 @@ void bli_sgemmsup_rd_zen_asm_2x48_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA2( 11, 12 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA2( 14, 15 ) @@ -3338,7 +3452,7 @@ void bli_sgemmsup_rd_zen_asm_2x48_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA2( 11, 12 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA2( 14, 15 ) @@ -3358,7 +3472,7 @@ void bli_sgemmsup_rd_zen_asm_2x48_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA2( 11, 12 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA2( 14, 15 ) @@ -3392,7 +3506,7 @@ void bli_sgemmsup_rd_zen_asm_2x48_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA2( 11, 12 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA2( 14, 15 ) @@ -3412,7 +3526,7 @@ void bli_sgemmsup_rd_zen_asm_2x48_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA2( 11, 12 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA2( 14, 15 ) @@ -3424,7 +3538,7 @@ void bli_sgemmsup_rd_zen_asm_2x48_avx512 dec( rsi ) jne( .K_LOOP_ITER32 ) - + label( .CONSIDER_K_ITER_8 ) mov( var( k_iter8 ), rsi ) test( rsi, rsi ) @@ -3433,25 +3547,30 @@ void bli_sgemmsup_rd_zen_asm_2x48_avx512 label( .K_LOOP_ITER8 ) // ITER 0 - // load row from A + // Load row from A using ymm registers + // Upper 256-bit lanes are cleared for the + // zmm counterpart vmovups( ( rax ), ymm0 ) vmovups( ( rax, r8, 1 ), ymm1 ) add( imm( 8*4 ), rax ) - // load column from B + // Load column from B using ymm registers + // Upper 256-bit lane is cleared for the + // zmm counterpart + // Thus, we can re-use the VFMA6 macro vmovups( ( rbx ), ymm6 ) VFMA2( 8, 9 ) vmovups( ( rbx, r9, 1 ), ymm6 ) VFMA2( 11, 12 ) - + vmovups( ( rbx, r9, 2 ), ymm6 ) VFMA2( 14, 15 ) vmovups( ( rbx, r13, 1 ), ymm6 ) VFMA2( 17, 18 ) - add( imm( 8*4 ), rbx ) + add( imm( 8*4 ), rbx ) dec( rsi ) jne( .K_LOOP_ITER8 ) @@ -3464,17 +3583,26 @@ void bli_sgemmsup_rd_zen_asm_2x48_avx512 label( .K_LOOP_LEFT1 ) - + + // Load row from A using xmm registers + // Upper 256-bit lanes and the upper 224 + // bits of the lower 256-bit lane are cleared + // for the zmm counterpart vmovss( ( rax ), xmm0 ) vmovss( ( rax, r8, 1 ), xmm1 ) add( imm( 1*4 ), rax ) // a += 1*cs_b = 1*4; + // Load column from B using xmm registers + // Upper 256-bit lanes and the upper 224 + // bits of the lower 256-bit lane are cleared + // for the zmm counterpart + // Thus, we can re-use the VFMA6 macro vmovss( ( rbx ), xmm6 ) VFMA2( 8, 9 ) vmovss( ( rbx, r9, 1 ), xmm6 ) VFMA2( 11, 12 ) - + vmovss( ( rbx, r9, 2 ), xmm6 ) VFMA2( 14, 15 ) @@ -3501,7 +3629,7 @@ void bli_sgemmsup_rd_zen_asm_2x48_avx512 ZMM_TO_YMM( 8, 9, 11, 12, 4, 5, 7, 8 ) ZMM_TO_YMM( 14, 15, 17, 18, 10, 11, 13, 14 ) - + ACCUM_YMM( 4, 7, 10, 13, 4 ) ACCUM_YMM( 5, 8, 11, 14, 5 ) @@ -3517,7 +3645,7 @@ void bli_sgemmsup_rd_zen_asm_2x48_avx512 ZMM_TO_YMM( 8, 9, 11, 12, 4, 5, 7, 8 ) ZMM_TO_YMM( 14, 15, 17, 18, 10, 11, 13, 14 ) - + ACCUM_YMM( 4, 7, 10, 13, 4 ) ACCUM_YMM( 5, 8, 11, 14, 5 ) @@ -3633,7 +3761,7 @@ void bli_sgemmsup_rd_zen_asm_1x48_avx512 lea( mem( , r15, 1 ), rsi ) imul( imm( 1*4 ), rsi ) lea( mem( r12, rsi, 1 ), r12 ) // c += r15 * cs_c - + lea( mem( , r15, 1 ), rsi ) // rsi = r15 = 4*jj; imul( r9, rsi ) // rsi *= cs_b; lea( mem( rdx, rsi, 1 ), rdx ) // rbx = b + 4*jj*cs_b; @@ -3641,10 +3769,10 @@ void bli_sgemmsup_rd_zen_asm_1x48_avx512 lea( mem( r12 ), rcx ) // load c to rcx lea( mem( r14 ), rax ) // load a to rax lea( mem( rdx ), rbx ) // load b to rbx - + lea( mem( r8, r8, 2 ), r10 ) // r10 = 3 * rs_b lea( mem( r10, r8, 2 ), rdi ) // rdi = 5 * rs_b - + INIT_REG mov( var( k_iter64 ), rsi ) // load k_iter @@ -3665,7 +3793,7 @@ void bli_sgemmsup_rd_zen_asm_1x48_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA1( 11 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA1( 14 ) @@ -3684,7 +3812,7 @@ void bli_sgemmsup_rd_zen_asm_1x48_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA1( 11 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA1( 14 ) @@ -3703,7 +3831,7 @@ void bli_sgemmsup_rd_zen_asm_1x48_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA1( 11 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA1( 14 ) @@ -3722,7 +3850,7 @@ void bli_sgemmsup_rd_zen_asm_1x48_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA1( 11 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA1( 14 ) @@ -3755,7 +3883,7 @@ void bli_sgemmsup_rd_zen_asm_1x48_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA1( 11 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA1( 14 ) @@ -3774,7 +3902,7 @@ void bli_sgemmsup_rd_zen_asm_1x48_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA1( 11 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA1( 14 ) @@ -3786,7 +3914,7 @@ void bli_sgemmsup_rd_zen_asm_1x48_avx512 dec( rsi ) jne( .K_LOOP_ITER32 ) - + label( .CONSIDER_K_ITER_8 ) mov( var( k_iter8 ), rsi ) test( rsi, rsi ) @@ -3795,24 +3923,29 @@ void bli_sgemmsup_rd_zen_asm_1x48_avx512 label( .K_LOOP_ITER8 ) // ITER 0 - // load row from A + // Load row from A using ymm registers + // Upper 256-bit lanes are cleared for the + // zmm counterpart vmovups( ( rax ), ymm0 ) add( imm( 8*4 ), rax ) - // load column from B + // Load column from B using ymm registers + // Upper 256-bit lane is cleared for the + // zmm counterpart + // Thus, we can re-use the VFMA6 macro vmovups( ( rbx ), ymm6 ) VFMA1( 8 ) vmovups( ( rbx, r9, 1 ), ymm6 ) VFMA1( 11 ) - + vmovups( ( rbx, r9, 2 ), ymm6 ) VFMA1( 14 ) vmovups( ( rbx, r13, 1 ), ymm6 ) VFMA1( 17 ) - add( imm( 8*4 ), rbx ) + add( imm( 8*4 ), rbx ) dec( rsi ) jne( .K_LOOP_ITER8 ) @@ -3825,16 +3958,25 @@ void bli_sgemmsup_rd_zen_asm_1x48_avx512 label( .K_LOOP_LEFT1 ) - + + // Load row from A using xmm registers + // Upper 256-bit lanes and the upper 224 + // bits of the lower 256-bit lane are cleared + // for the zmm counterpart vmovss( ( rax ), xmm0 ) add( imm( 1*4 ), rax ) // a += 1*cs_b = 1*4; + // Load column from B using xmm registers + // Upper 256-bit lanes and the upper 224 + // bits of the lower 256-bit lane are cleared + // for the zmm counterpart + // Thus, we can re-use the VFMA6 macro vmovss( ( rbx ), xmm6 ) VFMA1( 8 ) vmovss( ( rbx, r9, 1 ), xmm6 ) VFMA1( 11 ) - + vmovss( ( rbx, r9, 2 ), xmm6 ) VFMA1( 14 ) @@ -3859,7 +4001,7 @@ void bli_sgemmsup_rd_zen_asm_1x48_avx512 label( .POST_ACCUM_STOR ) ZMM_TO_YMM( 8, 11, 14, 17, 4, 7, 10, 13 ) - + ACCUM_YMM( 4, 7, 10, 13, 4 ) ALPHA_SCALE1 // Scaling the result of A*B with alpha @@ -3873,7 +4015,7 @@ void bli_sgemmsup_rd_zen_asm_1x48_avx512 label( .POST_ACCUM_STOR_BZ ) ZMM_TO_YMM( 8, 11, 14, 17, 4, 7, 10, 13 ) - + ACCUM_YMM( 4, 7, 10, 13, 4 ) ALPHA_SCALE1 // Scaling the result of A*B with alpha @@ -3985,7 +4127,7 @@ void bli_sgemmsup_rd_zen_asm_5x32_avx512 lea( mem( , r15, 1 ), rsi ) imul( imm( 1*4 ), rsi ) lea( mem( r12, rsi, 1 ), r12 ) // c += r15 * cs_c - + lea( mem( , r15, 1 ), rsi ) // rsi = r15 = 4*jj; imul( r9, rsi ) // rsi *= cs_b; lea( mem( rdx, rsi, 1 ), rdx ) // rbx = b + 4*jj*cs_b; @@ -3993,7 +4135,7 @@ void bli_sgemmsup_rd_zen_asm_5x32_avx512 lea( mem( r12 ), rcx ) // load c to rcx lea( mem( r14 ), rax ) // load a to rax lea( mem( rdx ), rbx ) // load b to rbx - + lea( mem( r8, r8, 2 ), r10 ) // r10 = 3 * rs_b lea( mem( r10, r8, 2 ), rdi ) // rdi = 5 * rs_b @@ -4021,7 +4163,7 @@ void bli_sgemmsup_rd_zen_asm_5x32_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA5( 11, 12, 13, 23, 24 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA5( 14, 15, 16, 26, 27 ) @@ -4044,7 +4186,7 @@ void bli_sgemmsup_rd_zen_asm_5x32_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA5( 11, 12, 13, 23, 24 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA5( 14, 15, 16, 26, 27 ) @@ -4068,7 +4210,7 @@ void bli_sgemmsup_rd_zen_asm_5x32_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA5( 11, 12, 13, 23, 24 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA5( 14, 15, 16, 26, 27 ) @@ -4091,7 +4233,7 @@ void bli_sgemmsup_rd_zen_asm_5x32_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA5( 11, 12, 13, 23, 24 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA5( 14, 15, 16, 26, 27 ) @@ -4127,7 +4269,7 @@ void bli_sgemmsup_rd_zen_asm_5x32_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA5( 11, 12, 13, 23, 24 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA5( 14, 15, 16, 26, 27 ) @@ -4150,7 +4292,7 @@ void bli_sgemmsup_rd_zen_asm_5x32_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA5( 11, 12, 13, 23, 24 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA5( 14, 15, 16, 26, 27 ) @@ -4162,7 +4304,7 @@ void bli_sgemmsup_rd_zen_asm_5x32_avx512 dec( rsi ) jne( .K_LOOP_ITER32 ) - + label( .CONSIDER_K_ITER_8 ) mov( var( k_iter8 ), rsi ) test( rsi, rsi ) @@ -4171,7 +4313,9 @@ void bli_sgemmsup_rd_zen_asm_5x32_avx512 label( .K_LOOP_ITER8 ) // ITER 0 - // load row from A + // Load row from A using ymm registers + // Upper 256-bit lanes are cleared for the + // zmm counterpart vmovups( ( rax ), ymm0 ) vmovups( ( rax, r8, 1 ), ymm1 ) vmovups( ( rax, r8, 2 ), ymm2 ) @@ -4179,20 +4323,23 @@ void bli_sgemmsup_rd_zen_asm_5x32_avx512 vmovups( ( rax, r8, 4 ), ymm4 ) add( imm( 8*4 ), rax ) - // load column from B + // Load column from B using ymm registers + // Upper 256-bit lane is cleared for the + // zmm counterpart + // Thus, we can re-use the VFMA6 macro vmovups( ( rbx ), ymm6 ) VFMA5( 8, 9, 10, 20, 21 ) vmovups( ( rbx, r9, 1 ), ymm6 ) VFMA5( 11, 12, 13, 23, 24 ) - + vmovups( ( rbx, r9, 2 ), ymm6 ) VFMA5( 14, 15, 16, 26, 27 ) vmovups( ( rbx, r13, 1 ), ymm6 ) VFMA5( 17, 18, 19, 29, 30 ) - add( imm( 8*4 ), rbx ) + add( imm( 8*4 ), rbx ) dec( rsi ) jne( .K_LOOP_ITER8 ) @@ -4205,7 +4352,11 @@ void bli_sgemmsup_rd_zen_asm_5x32_avx512 label( .K_LOOP_LEFT1 ) - + + // Load row from A using xmm registers + // Upper 256-bit lanes and the upper 224 + // bits of the lower 256-bit lane are cleared + // for the zmm counterpart vmovss( ( rax ), xmm0 ) vmovss( ( rax, r8, 1 ), xmm1 ) vmovss( ( rax, r8, 2 ), xmm2 ) @@ -4213,12 +4364,17 @@ void bli_sgemmsup_rd_zen_asm_5x32_avx512 vmovss( ( rax, r8, 4 ), xmm4 ) add( imm( 1*4 ), rax ) // a += 1*cs_b = 1*4; + // Load column from B using xmm registers + // Upper 256-bit lanes and the upper 224 + // bits of the lower 256-bit lane are cleared + // for the zmm counterpart + // Thus, we can re-use the VFMA6 macro vmovss( ( rbx ), xmm6 ) VFMA5( 8, 9, 10, 20, 21 ) vmovss( ( rbx, r9, 1 ), xmm6 ) VFMA5( 11, 12, 13, 23, 24 ) - + vmovss( ( rbx, r9, 2 ), xmm6 ) VFMA5( 14, 15, 16, 26, 27 ) @@ -4246,7 +4402,7 @@ void bli_sgemmsup_rd_zen_asm_5x32_avx512 ZMM_TO_YMM( 8, 9, 10, 11, 4, 5, 6, 7 ) ZMM_TO_YMM( 12, 13, 14, 15, 8, 9, 10, 11 ) ZMM_TO_YMM( 16, 17, 18, 19, 12, 13, 14, 15 ) - + ACCUM_YMM( 4, 7, 10, 13, 4 ) ACCUM_YMM( 5, 8, 11, 14, 5 ) ACCUM_YMM( 6, 9, 12, 15, 6 ) @@ -4274,7 +4430,7 @@ void bli_sgemmsup_rd_zen_asm_5x32_avx512 ZMM_TO_YMM( 8, 9, 10, 11, 4, 5, 6, 7 ) ZMM_TO_YMM( 12, 13, 14, 15, 8, 9, 10, 11 ) ZMM_TO_YMM( 16, 17, 18, 19, 12, 13, 14, 15 ) - + ACCUM_YMM( 4, 7, 10, 13, 4 ) ACCUM_YMM( 5, 8, 11, 14, 5 ) ACCUM_YMM( 6, 9, 12, 15, 6 ) @@ -4408,7 +4564,7 @@ void bli_sgemmsup_rd_zen_asm_4x32_avx512 lea( mem( , r15, 1 ), rsi ) imul( imm( 1*4 ), rsi ) lea( mem( r12, rsi, 1 ), r12 ) // c += r15 * cs_c - + lea( mem( , r15, 1 ), rsi ) // rsi = r15 = 4*jj; imul( r9, rsi ) // rsi *= cs_b; lea( mem( rdx, rsi, 1 ), rdx ) // rbx = b + 4*jj*cs_b; @@ -4416,7 +4572,7 @@ void bli_sgemmsup_rd_zen_asm_4x32_avx512 lea( mem( r12 ), rcx ) // load c to rcx lea( mem( r14 ), rax ) // load a to rax lea( mem( rdx ), rbx ) // load b to rbx - + lea( mem( r8, r8, 2 ), r10 ) // r10 = 3 * rs_b lea( mem( r10, r8, 2 ), rdi ) // rdi = 5 * rs_b @@ -4443,7 +4599,7 @@ void bli_sgemmsup_rd_zen_asm_4x32_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA4( 11, 12, 13, 23 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA4( 14, 15, 16, 26 ) @@ -4465,7 +4621,7 @@ void bli_sgemmsup_rd_zen_asm_4x32_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA4( 11, 12, 13, 23 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA4( 14, 15, 16, 26 ) @@ -4488,7 +4644,7 @@ void bli_sgemmsup_rd_zen_asm_4x32_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA4( 11, 12, 13, 23 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA4( 14, 15, 16, 26 ) @@ -4510,7 +4666,7 @@ void bli_sgemmsup_rd_zen_asm_4x32_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA4( 11, 12, 13, 23 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA4( 14, 15, 16, 26 ) @@ -4546,7 +4702,7 @@ void bli_sgemmsup_rd_zen_asm_4x32_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA4( 11, 12, 13, 23 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA4( 14, 15, 16, 26 ) @@ -4568,7 +4724,7 @@ void bli_sgemmsup_rd_zen_asm_4x32_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA4( 11, 12, 13, 23 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA4( 14, 15, 16, 26 ) @@ -4580,7 +4736,7 @@ void bli_sgemmsup_rd_zen_asm_4x32_avx512 dec( rsi ) jne( .K_LOOP_ITER32 ) - + label( .CONSIDER_K_ITER_8 ) mov( var( k_iter8 ), rsi ) test( rsi, rsi ) @@ -4589,27 +4745,32 @@ void bli_sgemmsup_rd_zen_asm_4x32_avx512 label( .K_LOOP_ITER8 ) // ITER 0 - // load row from A + // Load row from A using ymm registers + // Upper 256-bit lanes are cleared for the + // zmm counterpart vmovups( ( rax ), ymm0 ) vmovups( ( rax, r8, 1 ), ymm1 ) vmovups( ( rax, r8, 2 ), ymm2 ) vmovups( ( rax, r10, 1 ), ymm3 ) add( imm( 8*4 ), rax ) - // load column from B + // Load column from B using ymm registers + // Upper 256-bit lane is cleared for the + // zmm counterpart + // Thus, we can re-use the VFMA6 macro vmovups( ( rbx ), ymm6 ) VFMA4( 8, 9, 10, 20 ) vmovups( ( rbx, r9, 1 ), ymm6 ) VFMA4( 11, 12, 13, 23 ) - + vmovups( ( rbx, r9, 2 ), ymm6 ) VFMA4( 14, 15, 16, 26 ) vmovups( ( rbx, r13, 1 ), ymm6 ) VFMA4( 17, 18, 19, 29 ) - add( imm( 8*4 ), rbx ) + add( imm( 8*4 ), rbx ) dec( rsi ) jne( .K_LOOP_ITER8 ) @@ -4621,19 +4782,28 @@ void bli_sgemmsup_rd_zen_asm_4x32_avx512 je( .POST_ACCUM ) label( .K_LOOP_LEFT1 ) - + + // Load row from A using xmm registers + // Upper 256-bit lanes and the upper 224 + // bits of the lower 256-bit lane are cleared + // for the zmm counterpart vmovss( ( rax ), xmm0 ) vmovss( ( rax, r8, 1 ), xmm1 ) vmovss( ( rax, r8, 2 ), xmm2 ) vmovss( ( rax, r10, 1 ), xmm3 ) add( imm( 1*4 ), rax ) // a += 1*cs_b = 1*4; + // Load column from B using xmm registers + // Upper 256-bit lanes and the upper 224 + // bits of the lower 256-bit lane are cleared + // for the zmm counterpart + // Thus, we can re-use the VFMA6 macro vmovss( ( rbx ), xmm6 ) VFMA4( 8, 9, 10, 20 ) vmovss( ( rbx, r9, 1 ), xmm6 ) VFMA4( 11, 12, 13, 23 ) - + vmovss( ( rbx, r9, 2 ), xmm6 ) VFMA4( 14, 15, 16, 26 ) @@ -4656,12 +4826,12 @@ void bli_sgemmsup_rd_zen_asm_4x32_avx512 // Accumulating & storing the results when beta != 0 - label( .POST_ACCUM_STOR ) + label( .POST_ACCUM_STOR ) ZMM_TO_YMM( 8, 9, 10, 11, 4, 5, 6, 7 ) ZMM_TO_YMM( 12, 13, 14, 15, 8, 9, 10, 11 ) ZMM_TO_YMM( 16, 17, 18, 19, 12, 13, 14, 15 ) - + ACCUM_YMM( 4, 7, 10, 13, 4 ) ACCUM_YMM( 5, 8, 11, 14, 5 ) ACCUM_YMM( 6, 9, 12, 15, 6 ) @@ -4687,7 +4857,7 @@ void bli_sgemmsup_rd_zen_asm_4x32_avx512 ZMM_TO_YMM( 8, 9, 10, 11, 4, 5, 6, 7 ) ZMM_TO_YMM( 12, 13, 14, 15, 8, 9, 10, 11 ) ZMM_TO_YMM( 16, 17, 18, 19, 12, 13, 14, 15 ) - + ACCUM_YMM( 4, 7, 10, 13, 4 ) ACCUM_YMM( 5, 8, 11, 14, 5 ) ACCUM_YMM( 6, 9, 12, 15, 6 ) @@ -4819,7 +4989,7 @@ void bli_sgemmsup_rd_zen_asm_3x32_avx512 lea( mem( , r15, 1 ), rsi ) imul( imm( 1*4 ), rsi ) lea( mem( r12, rsi, 1 ), r12 ) // c += r15 * cs_c - + lea( mem( , r15, 1 ), rsi ) // rsi = r15 = 4*jj; imul( r9, rsi ) // rsi *= cs_b; lea( mem( rdx, rsi, 1 ), rdx ) // rbx = b + 4*jj*cs_b; @@ -4827,7 +4997,7 @@ void bli_sgemmsup_rd_zen_asm_3x32_avx512 lea( mem( r12 ), rcx ) // load c to rcx lea( mem( r14 ), rax ) // load a to rax lea( mem( rdx ), rbx ) // load b to rbx - + lea( mem( r8, r8, 2 ), r10 ) // r10 = 3 * rs_b lea( mem( r10, r8, 2 ), rdi ) // rdi = 5 * rs_b @@ -4853,7 +5023,7 @@ void bli_sgemmsup_rd_zen_asm_3x32_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA3( 11, 12, 13 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA3( 14, 15, 16 ) @@ -4874,7 +5044,7 @@ void bli_sgemmsup_rd_zen_asm_3x32_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA3( 11, 12, 13 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA3( 14, 15, 16 ) @@ -4895,7 +5065,7 @@ void bli_sgemmsup_rd_zen_asm_3x32_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA3( 11, 12, 13 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA3( 14, 15, 16 ) @@ -4916,7 +5086,7 @@ void bli_sgemmsup_rd_zen_asm_3x32_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA3( 11, 12, 13 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA3( 14, 15, 16 ) @@ -4951,7 +5121,7 @@ void bli_sgemmsup_rd_zen_asm_3x32_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA3( 11, 12, 13 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA3( 14, 15, 16 ) @@ -4972,7 +5142,7 @@ void bli_sgemmsup_rd_zen_asm_3x32_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA3( 11, 12, 13 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA3( 14, 15, 16 ) @@ -4984,7 +5154,7 @@ void bli_sgemmsup_rd_zen_asm_3x32_avx512 dec( rsi ) jne( .K_LOOP_ITER32 ) - + label( .CONSIDER_K_ITER_8 ) mov( var( k_iter8 ), rsi ) test( rsi, rsi ) @@ -4993,26 +5163,31 @@ void bli_sgemmsup_rd_zen_asm_3x32_avx512 label( .K_LOOP_ITER8 ) // ITER 0 - // load row from A + // Load row from A using ymm registers + // Upper 256-bit lanes are cleared for the + // zmm counterpart vmovups( ( rax ), ymm0 ) vmovups( ( rax, r8, 1 ), ymm1 ) vmovups( ( rax, r8, 2 ), ymm2 ) add( imm( 8*4 ), rax ) - // load column from B - vmovups( ( rbx ), zmm6 ) + // Load column from B using ymm registers + // Upper 256-bit lane is cleared for the + // zmm counterpart + // Thus, we can re-use the VFMA6 macro + vmovups( ( rbx ), ymm6 ) VFMA3( 8, 9, 10 ) - vmovups( ( rbx, r9, 1 ), zmm6 ) + vmovups( ( rbx, r9, 1 ), ymm6 ) VFMA3( 11, 12, 13 ) - - vmovups( ( rbx, r9, 2 ), zmm6 ) + + vmovups( ( rbx, r9, 2 ), ymm6 ) VFMA3( 14, 15, 16 ) - vmovups( ( rbx, r13, 1 ), zmm6 ) + vmovups( ( rbx, r13, 1 ), ymm6 ) VFMA3( 17, 18, 19 ) - add( imm( 8*4 ), rbx ) + add( imm( 8*4 ), rbx ) dec( rsi ) jne( .K_LOOP_ITER8 ) @@ -5025,18 +5200,27 @@ void bli_sgemmsup_rd_zen_asm_3x32_avx512 label( .K_LOOP_LEFT1 ) - + + // Load row from A using xmm registers + // Upper 256-bit lanes and the upper 224 + // bits of the lower 256-bit lane are cleared + // for the zmm counterpart vmovss( ( rax ), xmm0 ) vmovss( ( rax, r8, 1 ), xmm1 ) vmovss( ( rax, r8, 2 ), xmm2 ) add( imm( 1*4 ), rax ) // a += 1*cs_b = 1*4; + // Load column from B using xmm registers + // Upper 256-bit lanes and the upper 224 + // bits of the lower 256-bit lane are cleared + // for the zmm counterpart + // Thus, we can re-use the VFMA6 macro vmovss( ( rbx ), xmm6 ) VFMA3( 8, 9, 10 ) vmovss( ( rbx, r9, 1 ), xmm6 ) VFMA3( 11, 12, 13 ) - + vmovss( ( rbx, r9, 2 ), xmm6 ) VFMA3( 14, 15, 16 ) @@ -5064,7 +5248,7 @@ void bli_sgemmsup_rd_zen_asm_3x32_avx512 ZMM_TO_YMM( 8, 9, 10, 11, 4, 5, 6, 7 ) ZMM_TO_YMM( 12, 13, 14, 15, 8, 9, 10, 11 ) ZMM_TO_YMM( 16, 17, 18, 19, 12, 13, 14, 15 ) - + ACCUM_YMM( 4, 7, 10, 13, 4 ) ACCUM_YMM( 5, 8, 11, 14, 5 ) ACCUM_YMM( 6, 9, 12, 15, 6 ) @@ -5082,7 +5266,7 @@ void bli_sgemmsup_rd_zen_asm_3x32_avx512 ZMM_TO_YMM( 8, 9, 10, 11, 4, 5, 6, 7 ) ZMM_TO_YMM( 12, 13, 14, 15, 8, 9, 10, 11 ) ZMM_TO_YMM( 16, 17, 18, 19, 12, 13, 14, 15 ) - + ACCUM_YMM( 4, 7, 10, 13, 4 ) ACCUM_YMM( 5, 8, 11, 14, 5 ) ACCUM_YMM( 6, 9, 12, 15, 6 ) @@ -5207,7 +5391,7 @@ void bli_sgemmsup_rd_zen_asm_2x32_avx512 lea( mem( , r15, 1 ), rsi ) imul( imm( 1*4 ), rsi ) lea( mem( r12, rsi, 1 ), r12 ) // c += r15 * cs_c - + lea( mem( , r15, 1 ), rsi ) // rsi = r15 = 4*jj; imul( r9, rsi ) // rsi *= cs_b; lea( mem( rdx, rsi, 1 ), rdx ) // rbx = b + 4*jj*cs_b; @@ -5215,10 +5399,10 @@ void bli_sgemmsup_rd_zen_asm_2x32_avx512 lea( mem( r12 ), rcx ) // load c to rcx lea( mem( r14 ), rax ) // load a to rax lea( mem( rdx ), rbx ) // load b to rbx - + lea( mem( r8, r8, 2 ), r10 ) // r10 = 3 * rs_b lea( mem( r10, r8, 2 ), rdi ) // rdi = 5 * rs_b - + INIT_REG mov( var( k_iter64 ), rsi ) // load k_iter @@ -5240,7 +5424,7 @@ void bli_sgemmsup_rd_zen_asm_2x32_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA2( 11, 12 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA2( 14, 15 ) @@ -5260,7 +5444,7 @@ void bli_sgemmsup_rd_zen_asm_2x32_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA2( 11, 12 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA2( 14, 15 ) @@ -5281,7 +5465,7 @@ void bli_sgemmsup_rd_zen_asm_2x32_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA2( 11, 12 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA2( 14, 15 ) @@ -5301,7 +5485,7 @@ void bli_sgemmsup_rd_zen_asm_2x32_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA2( 11, 12 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA2( 14, 15 ) @@ -5335,7 +5519,7 @@ void bli_sgemmsup_rd_zen_asm_2x32_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA2( 11, 12 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA2( 14, 15 ) @@ -5355,7 +5539,7 @@ void bli_sgemmsup_rd_zen_asm_2x32_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA2( 11, 12 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA2( 14, 15 ) @@ -5367,7 +5551,7 @@ void bli_sgemmsup_rd_zen_asm_2x32_avx512 dec( rsi ) jne( .K_LOOP_ITER32 ) - + label( .CONSIDER_K_ITER_8 ) mov( var( k_iter8 ), rsi ) test( rsi, rsi ) @@ -5376,25 +5560,30 @@ void bli_sgemmsup_rd_zen_asm_2x32_avx512 label( .K_LOOP_ITER8 ) // ITER 0 - // load row from A + // Load row from A using ymm registers + // Upper 256-bit lanes are cleared for the + // zmm counterpart vmovups( ( rax ), ymm0 ) vmovups( ( rax, r8, 1 ), ymm1 ) add( imm( 8*4 ), rax ) - // load column from B + // Load column from B using ymm registers + // Upper 256-bit lane is cleared for the + // zmm counterpart + // Thus, we can re-use the VFMA6 macro vmovups( ( rbx ), ymm6 ) VFMA2( 8, 9 ) vmovups( ( rbx, r9, 1 ), ymm6 ) VFMA2( 11, 12 ) - + vmovups( ( rbx, r9, 2 ), ymm6 ) VFMA2( 14, 15 ) vmovups( ( rbx, r13, 1 ), ymm6 ) VFMA2( 17, 18 ) - add( imm( 8*4 ), rbx ) + add( imm( 8*4 ), rbx ) dec( rsi ) jne( .K_LOOP_ITER8 ) @@ -5407,17 +5596,26 @@ void bli_sgemmsup_rd_zen_asm_2x32_avx512 label( .K_LOOP_LEFT1 ) - + + // Load row from A using xmm registers + // Upper 256-bit lanes and the upper 224 + // bits of the lower 256-bit lane are cleared + // for the zmm counterpart vmovss( ( rax ), xmm0 ) vmovss( ( rax, r8, 1 ), xmm1 ) add( imm( 1*4 ), rax ) // a += 1*cs_b = 1*4; + // Load column from B using xmm registers + // Upper 256-bit lanes and the upper 224 + // bits of the lower 256-bit lane are cleared + // for the zmm counterpart + // Thus, we can re-use the VFMA6 macro vmovss( ( rbx ), xmm6 ) VFMA2( 8, 9 ) vmovss( ( rbx, r9, 1 ), xmm6 ) VFMA2( 11, 12 ) - + vmovss( ( rbx, r9, 2 ), xmm6 ) VFMA2( 14, 15 ) @@ -5444,7 +5642,7 @@ void bli_sgemmsup_rd_zen_asm_2x32_avx512 ZMM_TO_YMM( 8, 9, 11, 12, 4, 5, 7, 8 ) ZMM_TO_YMM( 14, 15, 17, 18, 10, 11, 13, 14 ) - + ACCUM_YMM( 4, 7, 10, 13, 4 ) ACCUM_YMM( 5, 8, 11, 14, 5 ) @@ -5460,7 +5658,7 @@ void bli_sgemmsup_rd_zen_asm_2x32_avx512 ZMM_TO_YMM( 8, 9, 11, 12, 4, 5, 7, 8 ) ZMM_TO_YMM( 14, 15, 17, 18, 10, 11, 13, 14 ) - + ACCUM_YMM( 4, 7, 10, 13, 4 ) ACCUM_YMM( 5, 8, 11, 14, 5 ) @@ -5575,7 +5773,7 @@ void bli_sgemmsup_rd_zen_asm_1x32_avx512 lea( mem( , r15, 1 ), rsi ) imul( imm( 1*4 ), rsi ) lea( mem( r12, rsi, 1 ), r12 ) // c += r15 * cs_c - + lea( mem( , r15, 1 ), rsi ) // rsi = r15 = 4*jj; imul( r9, rsi ) // rsi *= cs_b; lea( mem( rdx, rsi, 1 ), rdx ) // rbx = b + 4*jj*cs_b; @@ -5583,10 +5781,10 @@ void bli_sgemmsup_rd_zen_asm_1x32_avx512 lea( mem( r12 ), rcx ) // load c to rcx lea( mem( r14 ), rax ) // load a to rax lea( mem( rdx ), rbx ) // load b to rbx - + lea( mem( r8, r8, 2 ), r10 ) // r10 = 3 * rs_b lea( mem( r10, r8, 2 ), rdi ) // rdi = 5 * rs_b - + INIT_REG @@ -5608,7 +5806,7 @@ void bli_sgemmsup_rd_zen_asm_1x32_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA1( 11 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA1( 14 ) @@ -5627,7 +5825,7 @@ void bli_sgemmsup_rd_zen_asm_1x32_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA1( 11 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA1( 14 ) @@ -5646,7 +5844,7 @@ void bli_sgemmsup_rd_zen_asm_1x32_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA1( 11 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA1( 14 ) @@ -5665,7 +5863,7 @@ void bli_sgemmsup_rd_zen_asm_1x32_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA1( 11 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA1( 14 ) @@ -5698,7 +5896,7 @@ void bli_sgemmsup_rd_zen_asm_1x32_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA1( 11 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA1( 14 ) @@ -5717,7 +5915,7 @@ void bli_sgemmsup_rd_zen_asm_1x32_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA1( 11 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA1( 14 ) @@ -5729,7 +5927,7 @@ void bli_sgemmsup_rd_zen_asm_1x32_avx512 dec( rsi ) jne( .K_LOOP_ITER32 ) - + label( .CONSIDER_K_ITER_8 ) mov( var( k_iter8 ), rsi ) test( rsi, rsi ) @@ -5738,24 +5936,29 @@ void bli_sgemmsup_rd_zen_asm_1x32_avx512 label( .K_LOOP_ITER8 ) // ITER 0 - // load row from A + // Load row from A using ymm registers + // Upper 256-bit lanes are cleared for the + // zmm counterpart vmovups( ( rax ), ymm0 ) add( imm( 8*4 ), rax ) - // load column from B + // Load column from B using ymm registers + // Upper 256-bit lane is cleared for the + // zmm counterpart + // Thus, we can re-use the VFMA6 macro vmovups( ( rbx ), ymm6 ) VFMA1( 8 ) vmovups( ( rbx, r9, 1 ), ymm6 ) VFMA1( 11 ) - + vmovups( ( rbx, r9, 2 ), ymm6 ) VFMA1( 14 ) vmovups( ( rbx, r13, 1 ), ymm6 ) VFMA1( 17 ) - add( imm( 8*4 ), rbx ) + add( imm( 8*4 ), rbx ) dec( rsi ) jne( .K_LOOP_ITER8 ) @@ -5768,16 +5971,25 @@ void bli_sgemmsup_rd_zen_asm_1x32_avx512 label( .K_LOOP_LEFT1 ) - + + // Load row from A using xmm registers + // Upper 256-bit lanes and the upper 224 + // bits of the lower 256-bit lane are cleared + // for the zmm counterpart vmovss( ( rax ), xmm0 ) add( imm( 1*4 ), rax ) // a += 1*cs_b = 1*4; + // Load column from B using xmm registers + // Upper 256-bit lanes and the upper 224 + // bits of the lower 256-bit lane are cleared + // for the zmm counterpart + // Thus, we can re-use the VFMA6 macro vmovss( ( rbx ), xmm6 ) VFMA1( 8 ) vmovss( ( rbx, r9, 1 ), xmm6 ) VFMA1( 11 ) - + vmovss( ( rbx, r9, 2 ), xmm6 ) VFMA1( 14 ) @@ -5802,7 +6014,7 @@ void bli_sgemmsup_rd_zen_asm_1x32_avx512 label( .POST_ACCUM_STOR ) ZMM_TO_YMM( 8, 11, 14, 17, 4, 7, 10, 13 ) - + ACCUM_YMM( 4, 7, 10, 13, 4 ) ALPHA_SCALE1 // Scaling the result of A*B with alpha @@ -5816,7 +6028,7 @@ void bli_sgemmsup_rd_zen_asm_1x32_avx512 label( .POST_ACCUM_STOR_BZ ) ZMM_TO_YMM( 8, 11, 14, 17, 4, 7, 10, 13 ) - + ACCUM_YMM( 4, 7, 10, 13, 4 ) ALPHA_SCALE1 // Scaling the result of A*B with alpha diff --git a/kernels/zen4/3/sup/bli_gemmsup_rd_zen_s6x64.h b/kernels/zen4/3/sup/bli_gemmsup_rd_zen_s6x64.h index 80e43843cc..c76ca5dc1a 100644 --- a/kernels/zen4/3/sup/bli_gemmsup_rd_zen_s6x64.h +++ b/kernels/zen4/3/sup/bli_gemmsup_rd_zen_s6x64.h @@ -1,9 +1,10 @@ /* + BLIS An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -28,6 +29,7 @@ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ #define BLIS_ASM_SYNTAX_ATT diff --git a/kernels/zen4/3/sup/bli_gemmsup_rd_zen_s6x64m.c b/kernels/zen4/3/sup/bli_gemmsup_rd_zen_s6x64m.c index 1e0ce1c4c4..746dc8f102 100644 --- a/kernels/zen4/3/sup/bli_gemmsup_rd_zen_s6x64m.c +++ b/kernels/zen4/3/sup/bli_gemmsup_rd_zen_s6x64m.c @@ -1,9 +1,10 @@ /* + BLIS An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -28,6 +29,7 @@ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ #include "blis.h" @@ -208,7 +210,7 @@ void bli_sgemmsup_rd_zen_asm_6x64m_avx512 lea( mem( , r15, 1 ), rsi ) imul( imm( 1*4 ), rsi ) lea( mem( r12, rsi, 1 ), r12 ) // c += r15 * cs_c - + lea(mem( , r15, 1 ), rsi) // rsi = r15 = 4*jj; imul( r9, rsi ) // rsi *= cs_b; lea( mem( rdx, rsi, 1 ), rdx ) // rbx = b + 4*jj*cs_b; @@ -220,7 +222,7 @@ void bli_sgemmsup_rd_zen_asm_6x64m_avx512 lea( mem( r12 ), rcx ) // load c to rcx lea( mem( r14 ), rax ) // load a to rax lea( mem( rdx ), rbx ) // load b to rbx - + lea( mem( r8, r8, 2 ), r10 ) // r10 = 3 * rs_a lea( mem( r10, r8, 2 ), rdi ) // rdi = 5 * rs_a @@ -249,7 +251,7 @@ void bli_sgemmsup_rd_zen_asm_6x64m_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA6( 11, 12, 13, 23, 24, 25 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA6( 14, 15, 16, 26, 27, 28 ) @@ -273,7 +275,7 @@ void bli_sgemmsup_rd_zen_asm_6x64m_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA6( 11, 12, 13, 23, 24, 25 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA6( 14, 15, 16, 26, 27, 28 ) @@ -297,7 +299,7 @@ void bli_sgemmsup_rd_zen_asm_6x64m_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA6( 11, 12, 13, 23, 24, 25 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA6( 14, 15, 16, 26, 27, 28 ) @@ -321,7 +323,7 @@ void bli_sgemmsup_rd_zen_asm_6x64m_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA6( 11, 12, 13, 23, 24, 25 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA6( 14, 15, 16, 26, 27, 28 ) @@ -359,7 +361,7 @@ void bli_sgemmsup_rd_zen_asm_6x64m_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA6( 11, 12, 13, 23, 24, 25 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA6( 14, 15, 16, 26, 27, 28 ) @@ -383,7 +385,7 @@ void bli_sgemmsup_rd_zen_asm_6x64m_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA6( 11, 12, 13, 23, 24, 25 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA6( 14, 15, 16, 26, 27, 28 ) @@ -395,7 +397,7 @@ void bli_sgemmsup_rd_zen_asm_6x64m_avx512 dec( rsi ) jne( .K_LOOP_ITER32 ) - + label( .CONSIDER_K_ITER_8 ) mov( var( k_iter8 ), rsi ) test( rsi, rsi ) @@ -404,7 +406,9 @@ void bli_sgemmsup_rd_zen_asm_6x64m_avx512 label( .K_LOOP_ITER8 ) // ITER 0 - // load row from A + // Load row from A using ymm registers + // Upper 256-bit lanes are cleared for the + // zmm counterpart vmovups( ( rax ), ymm0 ) vmovups( ( rax, r8, 1 ), ymm1 ) vmovups( ( rax, r8, 2 ), ymm2 ) @@ -413,20 +417,23 @@ void bli_sgemmsup_rd_zen_asm_6x64m_avx512 vmovups( ( rax, rdi, 1 ), ymm5 ) add( imm( 8*4 ), rax ) - // load column from B + // Load column from B using ymm registers + // Upper 256-bit lane is cleared for the + // zmm counterpart + // Thus, we can re-use the VFMA6 macro vmovups( ( rbx ), ymm6 ) VFMA6( 8, 9, 10, 20, 21, 22 ) vmovups( ( rbx, r9, 1 ), ymm6 ) VFMA6( 11, 12, 13, 23, 24, 25 ) - + vmovups( ( rbx, r9, 2 ), ymm6 ) VFMA6( 14, 15, 16, 26, 27, 28 ) vmovups( ( rbx, r13, 1 ), ymm6 ) VFMA6( 17, 18, 19, 29, 30, 31 ) - add( imm( 8*4 ), rbx ) + add( imm( 8*4 ), rbx ) dec( rsi ) jne( .K_LOOP_ITER8 ) @@ -439,7 +446,11 @@ void bli_sgemmsup_rd_zen_asm_6x64m_avx512 label( .K_LOOP_LEFT1 ) - + + // Load row from A using xmm registers + // Upper 256-bit lanes and the upper 224 + // bits of the lower 256-bit lane are cleared + // for the zmm counterpart vmovss( ( rax ), xmm0 ) vmovss( ( rax, r8, 1 ), xmm1 ) vmovss( ( rax, r8, 2 ), xmm2 ) @@ -448,12 +459,17 @@ void bli_sgemmsup_rd_zen_asm_6x64m_avx512 vmovss( ( rax, rdi, 1 ), xmm5 ) add( imm( 1*4 ), rax ) + // Load column from B using xmm registers + // Upper 256-bit lanes and the upper 224 + // bits of the lower 256-bit lane are cleared + // for the zmm counterpart + // Thus, we can re-use the VFMA6 macro vmovss( ( rbx ), xmm6 ) VFMA6( 8, 9, 10, 20, 21, 22 ) vmovss( ( rbx, r9, 1 ), xmm6 ) VFMA6( 11, 12, 13, 23, 24, 25 ) - + vmovss( ( rbx, r9, 2 ), xmm6 ) VFMA6( 14, 15, 16, 26, 27, 28 ) @@ -518,12 +534,12 @@ void bli_sgemmsup_rd_zen_asm_6x64m_avx512 // Accumulating & storing the results when beta == 0 - label( .POST_ACCUM_STOR_BZ ) + label( .POST_ACCUM_STOR_BZ ) ZMM_TO_YMM( 8, 9, 10, 11, 4, 5, 6, 7 ) ZMM_TO_YMM( 12, 13, 14, 15, 8, 9, 10, 11 ) ZMM_TO_YMM( 16, 17, 18, 19, 12, 13, 14, 15 ) - + ACCUM_YMM( 4, 7, 10, 13, 4 ) ACCUM_YMM( 5, 8, 11, 14, 5 ) ACCUM_YMM( 6, 9, 12, 15, 6 ) @@ -626,7 +642,7 @@ void bli_sgemmsup_rd_zen_asm_6x64m_avx512 alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, beta, cij, rs_c0, cs_c0, data, cntx ); - } + } if ( 4 == m_left ) { const dim_t mr_cur = 4; @@ -736,7 +752,7 @@ void bli_sgemmsup_rd_zen_asm_6x48m_avx512 lea( mem( , r15, 1), rsi ) imul( imm( 1*4 ), rsi ) lea( mem( r12, rsi, 1 ), r12 ) // c += r15 * cs_c - + lea( mem( , r15, 1 ), rsi ) // rsi = r15 = 4*jj; imul( r9, rsi ) // rsi *= cs_b; lea( mem( rdx, rsi, 1 ), rdx ) // rbx = b + 4*jj*cs_b; @@ -748,10 +764,10 @@ void bli_sgemmsup_rd_zen_asm_6x48m_avx512 lea( mem( r14 ), rax ) // load c to rcx lea( mem( r12 ), rcx ) // load a to rax lea( mem( rdx ), rbx ) // load b to rbx - + lea( mem( r8, r8, 2 ), r10 ) // r10 = 3 * rs_b lea( mem( r10, r8, 2 ), rdi ) // rdi = 5 * rs_b - + INIT_REG mov( var( k_iter64 ), rsi ) // load k_iter @@ -777,7 +793,7 @@ void bli_sgemmsup_rd_zen_asm_6x48m_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA6( 11, 12, 13, 23, 24, 25 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA6( 14, 15, 16, 26, 27, 28 ) @@ -801,7 +817,7 @@ void bli_sgemmsup_rd_zen_asm_6x48m_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA6( 11, 12, 13, 23, 24, 25 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA6( 14, 15, 16, 26, 27, 28 ) @@ -825,7 +841,7 @@ void bli_sgemmsup_rd_zen_asm_6x48m_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA6( 11, 12, 13, 23, 24, 25 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA6( 14, 15, 16, 26, 27, 28 ) @@ -849,7 +865,7 @@ void bli_sgemmsup_rd_zen_asm_6x48m_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA6( 11, 12, 13, 23, 24, 25 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA6( 14, 15, 16, 26, 27, 28 ) @@ -867,7 +883,7 @@ void bli_sgemmsup_rd_zen_asm_6x48m_avx512 test( rsi, rsi ) je( .CONSIDER_K_ITER_8 ) - + label( .K_LOOP_ITER32 ) // ITER 0 @@ -886,7 +902,7 @@ void bli_sgemmsup_rd_zen_asm_6x48m_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA6( 11, 12, 13, 23, 24, 25 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA6( 14, 15, 16, 26, 27, 28 ) @@ -910,7 +926,7 @@ void bli_sgemmsup_rd_zen_asm_6x48m_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA6( 11, 12, 13, 23, 24, 25 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA6( 14, 15, 16, 26, 27, 28 ) @@ -922,7 +938,7 @@ void bli_sgemmsup_rd_zen_asm_6x48m_avx512 dec( rsi ) jne( .K_LOOP_ITER32 ) - + label( .CONSIDER_K_ITER_8 ) mov( var( k_iter8 ), rsi ) test( rsi, rsi ) @@ -931,7 +947,9 @@ void bli_sgemmsup_rd_zen_asm_6x48m_avx512 label( .K_LOOP_ITER8 ) // ITER 0 - // load row from A + // Load row from A using ymm registers + // Upper 256-bit lanes are cleared for the + // zmm counterpart vmovups( ( rax ), ymm0 ) vmovups( ( rax, r8, 1 ), ymm1 ) vmovups( ( rax, r8, 2 ), ymm2 ) @@ -940,20 +958,23 @@ void bli_sgemmsup_rd_zen_asm_6x48m_avx512 vmovups( ( rax, rdi, 1 ), ymm5 ) add( imm( 8*4 ), rax ) - // load column from B + // Load column from B using ymm registers + // Upper 256-bit lane is cleared for the + // zmm counterpart + // Thus, we can re-use the VFMA6 macro vmovups( ( rbx ), ymm6 ) VFMA6( 8, 9, 10, 20, 21, 22 ) vmovups( ( rbx, r9, 1 ), ymm6 ) VFMA6( 11, 12, 13, 23, 24, 25 ) - + vmovups( ( rbx, r9, 2 ), ymm6 ) VFMA6( 14, 15, 16, 26, 27, 28 ) vmovups( ( rbx, r13, 1 ), ymm6 ) VFMA6( 17, 18, 19, 29, 30, 31 ) - add( imm( 8*4 ), rbx ) + add( imm( 8*4 ), rbx ) dec( rsi ) jne( .K_LOOP_ITER8 ) @@ -966,7 +987,11 @@ void bli_sgemmsup_rd_zen_asm_6x48m_avx512 label( .K_LOOP_LEFT1 ) - + + // Load row from A using xmm registers + // Upper 256-bit lanes and the upper 224 + // bits of the lower 256-bit lane are cleared + // for the zmm counterpart vmovss( ( rax ), xmm0 ) vmovss( ( rax, r8, 1 ), xmm1 ) vmovss( ( rax, r8, 2 ), xmm2 ) @@ -975,12 +1000,17 @@ void bli_sgemmsup_rd_zen_asm_6x48m_avx512 vmovss( ( rax, rdi, 1 ), xmm5 ) add( imm( 1*4 ), rax ) // a += 1*cs_b = 1*4; + // Load column from B using xmm registers + // Upper 256-bit lanes and the upper 224 + // bits of the lower 256-bit lane are cleared + // for the zmm counterpart + // Thus, we can re-use the VFMA6 macro vmovss( ( rbx ), xmm6 ) VFMA6( 8, 9, 10, 20, 21, 22 ) vmovss( ( rbx, r9, 1 ), xmm6 ) VFMA6( 11, 12, 13, 23, 24, 25 ) - + vmovss( ( rbx, r9, 2 ), xmm6 ) VFMA6( 14, 15, 16, 26, 27, 28 ) @@ -1018,7 +1048,7 @@ void bli_sgemmsup_rd_zen_asm_6x48m_avx512 ZMM_TO_YMM( 8, 9, 10, 11, 4, 5, 6, 7 ) ZMM_TO_YMM( 12, 13, 14, 15, 8, 9, 10, 11 ) ZMM_TO_YMM( 16, 17, 18, 19, 12, 13, 14, 15 ) - + // Accumulates the results by horizontally adding the YMM registers, // and having the final result in xmm registers. ACCUM_YMM( 4, 7, 10, 13, 4 ) @@ -1050,7 +1080,7 @@ void bli_sgemmsup_rd_zen_asm_6x48m_avx512 ZMM_TO_YMM( 8, 9, 10, 11, 4, 5, 6, 7 ) ZMM_TO_YMM( 12, 13, 14, 15, 8, 9, 10, 11 ) ZMM_TO_YMM( 16, 17, 18, 19, 12, 13, 14, 15 ) - + ACCUM_YMM( 4, 7, 10, 13, 4 ) ACCUM_YMM( 5, 8, 11, 14, 5 ) ACCUM_YMM( 6, 9, 12, 15, 6 ) @@ -1153,7 +1183,7 @@ void bli_sgemmsup_rd_zen_asm_6x48m_avx512 alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, beta, cij, rs_c0, cs_c0, data, cntx ); - } + } if ( 4 == m_left ) { const dim_t mr_cur = 4; @@ -1263,7 +1293,7 @@ void bli_sgemmsup_rd_zen_asm_6x32m_avx512 lea( mem( , r15, 1), rsi ) imul( imm( 1*4 ), rsi ) lea( mem( r12, rsi, 1 ), r12 ) // c += r15 * cs_c - + lea( mem( , r15, 1 ), rsi ) // rsi = r15 = 4*jj; imul( r9, rsi ) // rsi *= cs_b; lea( mem( rdx, rsi, 1 ), rdx ) // rbx = b + 4*jj*cs_b; @@ -1275,10 +1305,10 @@ void bli_sgemmsup_rd_zen_asm_6x32m_avx512 lea( mem( r14 ), rax ) // load c to rcx lea( mem( r12 ), rcx ) // load a to rax lea( mem( rdx ), rbx ) // load b to rbx - + lea( mem( r8, r8, 2 ), r10 ) // r10 = 3 * rs_b lea( mem( r10, r8, 2 ), rdi ) // rdi = 5 * rs_b - + INIT_REG mov( var( k_iter64 ), rsi ) // load k_iter @@ -1304,7 +1334,7 @@ void bli_sgemmsup_rd_zen_asm_6x32m_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA6( 11, 12, 13, 23, 24, 25 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA6( 14, 15, 16, 26, 27, 28 ) @@ -1328,7 +1358,7 @@ void bli_sgemmsup_rd_zen_asm_6x32m_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA6( 11, 12, 13, 23, 24, 25 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA6( 14, 15, 16, 26, 27, 28 ) @@ -1352,7 +1382,7 @@ void bli_sgemmsup_rd_zen_asm_6x32m_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA6( 11, 12, 13, 23, 24, 25 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA6( 14, 15, 16, 26, 27, 28 ) @@ -1376,7 +1406,7 @@ void bli_sgemmsup_rd_zen_asm_6x32m_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA6( 11, 12, 13, 23, 24, 25 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA6( 14, 15, 16, 26, 27, 28 ) @@ -1413,7 +1443,7 @@ void bli_sgemmsup_rd_zen_asm_6x32m_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA6( 11, 12, 13, 23, 24, 25 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA6( 14, 15, 16, 26, 27, 28 ) @@ -1437,7 +1467,7 @@ void bli_sgemmsup_rd_zen_asm_6x32m_avx512 vmovups( ( rbx, r9, 1 ), zmm6 ) VFMA6( 11, 12, 13, 23, 24, 25 ) - + vmovups( ( rbx, r9, 2 ), zmm6 ) VFMA6( 14, 15, 16, 26, 27, 28 ) @@ -1449,7 +1479,7 @@ void bli_sgemmsup_rd_zen_asm_6x32m_avx512 dec( rsi ) jne( .K_LOOP_ITER32 ) - + label( .CONSIDER_K_ITER_8 ) mov( var( k_iter8 ), rsi ) test( rsi, rsi ) @@ -1458,7 +1488,9 @@ void bli_sgemmsup_rd_zen_asm_6x32m_avx512 label( .K_LOOP_ITER8 ) // ITER 0 - // load row from A + // Load row from A using ymm registers + // Upper 256-bit lanes are cleared for the + // zmm counterpart vmovups( ( rax ), ymm0 ) vmovups( ( rax, r8, 1 ), ymm1 ) vmovups( ( rax, r8, 2 ), ymm2 ) @@ -1467,20 +1499,23 @@ void bli_sgemmsup_rd_zen_asm_6x32m_avx512 vmovups( ( rax, rdi, 1 ), ymm5 ) add( imm( 8*4 ), rax ) - // load column from B + // Load column from B using ymm registers + // Upper 256-bit lane is cleared for the + // zmm counterpart + // Thus, we can re-use the VFMA6 macro vmovups( ( rbx ), ymm6 ) VFMA6( 8, 9, 10, 20, 21, 22 ) vmovups( ( rbx, r9, 1 ), ymm6 ) VFMA6( 11, 12, 13, 23, 24, 25 ) - + vmovups( ( rbx, r9, 2 ), ymm6 ) VFMA6( 14, 15, 16, 26, 27, 28 ) vmovups( ( rbx, r13, 1 ), ymm6 ) VFMA6( 17, 18, 19, 29, 30, 31 ) - add( imm( 8*4 ), rbx ) + add( imm( 8*4 ), rbx ) dec( rsi ) jne( .K_LOOP_ITER8 ) @@ -1493,7 +1528,11 @@ void bli_sgemmsup_rd_zen_asm_6x32m_avx512 label( .K_LOOP_LEFT1 ) - + + // Load row from A using xmm registers + // Upper 256-bit lanes and the upper 224 + // bits of the lower 256-bit lane are cleared + // for the zmm counterpart vmovss( ( rax ), xmm0 ) vmovss( ( rax, r8, 1 ), xmm1 ) vmovss( ( rax, r8, 2 ), xmm2 ) @@ -1502,12 +1541,17 @@ void bli_sgemmsup_rd_zen_asm_6x32m_avx512 vmovss( ( rax, rdi, 1 ), xmm5 ) add( imm( 1*4 ), rax ) // a += 1*cs_b = 1*4; + // Load column from B using xmm registers + // Upper 256-bit lanes and the upper 224 + // bits of the lower 256-bit lane are cleared + // for the zmm counterpart + // Thus, we can re-use the VFMA6 macro vmovss( ( rbx ), xmm6 ) VFMA6( 8, 9, 10, 20, 21, 22 ) vmovss( ( rbx, r9, 1 ), xmm6 ) VFMA6( 11, 12, 13, 23, 24, 25 ) - + vmovss( ( rbx, r9, 2 ), xmm6 ) VFMA6( 14, 15, 16, 26, 27, 28 ) @@ -1544,7 +1588,7 @@ void bli_sgemmsup_rd_zen_asm_6x32m_avx512 ZMM_TO_YMM( 8, 9, 10, 11, 4, 5, 6, 7 ) ZMM_TO_YMM( 12, 13, 14, 15, 8, 9, 10, 11 ) ZMM_TO_YMM( 16, 17, 18, 19, 12, 13, 14, 15 ) - + // Accumulates the results by horizontally adding the YMM registers, // and having the final result in xmm registers. ACCUM_YMM( 4, 7, 10, 13, 4 ) @@ -1576,7 +1620,7 @@ void bli_sgemmsup_rd_zen_asm_6x32m_avx512 ZMM_TO_YMM( 8, 9, 10, 11, 4, 5, 6, 7 ) ZMM_TO_YMM( 12, 13, 14, 15, 8, 9, 10, 11 ) ZMM_TO_YMM( 16, 17, 18, 19, 12, 13, 14, 15 ) - + ACCUM_YMM( 4, 7, 10, 13, 4 ) ACCUM_YMM( 5, 8, 11, 14, 5 ) ACCUM_YMM( 6, 9, 12, 15, 6 ) @@ -1680,7 +1724,7 @@ void bli_sgemmsup_rd_zen_asm_6x32m_avx512 alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, beta, cij, rs_c0, cs_c0, data, cntx ); - } + } if ( 4 == m_left ) { const dim_t mr_cur = 4; diff --git a/kernels/zen4/3/sup/bli_gemmsup_rd_zen_s6x64n.c b/kernels/zen4/3/sup/bli_gemmsup_rd_zen_s6x64n.c index 145d3b5201..c8de9bf1cc 100644 --- a/kernels/zen4/3/sup/bli_gemmsup_rd_zen_s6x64n.c +++ b/kernels/zen4/3/sup/bli_gemmsup_rd_zen_s6x64n.c @@ -1,9 +1,10 @@ /* + BLIS An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -28,6 +29,7 @@ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ #include "blis.h" diff --git a/kernels/zen4/3/sup/bli_gemmsup_rv_zen4_asm_d8x8m.c b/kernels/zen4/3/sup/bli_gemmsup_rv_zen4_asm_d8x8m.c new file mode 100644 index 0000000000..fdacd7c9ba --- /dev/null +++ b/kernels/zen4/3/sup/bli_gemmsup_rv_zen4_asm_d8x8m.c @@ -0,0 +1,1305 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + + +#include "blis.h" +#include "immintrin.h" + +#if defined __clang__ + #define UNROLL_LOOP() _Pragma("clang loop unroll_count(4)") + /* + * in clang, unroll_count(4) generates inefficient + * code compared to unroll(full) when loopCount = 4. + */ + #define UNROLL_LOOP_FULL() _Pragma("clang loop unroll(full)") +#elif defined __GNUC__ + #define UNROLL_LOOP() _Pragma("GCC unroll 4") + #define UNROLL_LOOP_FULL() _Pragma("GCC unroll 8") +#else + #define UNROLL_LOOP() + #define UNROLL_LOOP_FULL() +#endif + +#define ZERO_REGISTERS() \ + c_reg[0] = _mm512_setzero_pd(); \ + c_reg[1] = _mm512_setzero_pd(); \ + c_reg[2] = _mm512_setzero_pd(); \ + c_reg[3] = _mm512_setzero_pd(); \ + c_reg[4] = _mm512_setzero_pd(); \ + c_reg[5] = _mm512_setzero_pd(); \ + c_reg[6] = _mm512_setzero_pd(); \ + c_reg[7] = _mm512_setzero_pd(); \ + +#define TRANSPOSE_8x8() \ + a_reg[0] = _mm512_unpacklo_pd(c_reg[0], c_reg[1]); \ + a_reg[1] = _mm512_unpacklo_pd(c_reg[2], c_reg[3]); \ + a_reg[2] = _mm512_unpacklo_pd(c_reg[4], c_reg[5]); \ + a_reg[3] = _mm512_unpacklo_pd(c_reg[6], c_reg[7]); \ + a_reg[4] = _mm512_unpackhi_pd(c_reg[0], c_reg[1]); \ + a_reg[5] = _mm512_unpackhi_pd(c_reg[2], c_reg[3]); \ + a_reg[6] = _mm512_unpackhi_pd(c_reg[4], c_reg[5]); \ + a_reg[7] = _mm512_unpackhi_pd(c_reg[6], c_reg[7]); \ + /*Stage2*/ \ + b_reg[0] = _mm512_shuffle_f64x2(a_reg[0], a_reg[1], 0b10001000); \ + b_reg[1] = _mm512_shuffle_f64x2(a_reg[2], a_reg[3], 0b10001000); \ + /*Stage3 1,5*/ \ + c_reg[0] = _mm512_shuffle_f64x2(b_reg[0], b_reg[1], 0b10001000); \ + c_reg[4] = _mm512_shuffle_f64x2(b_reg[0], b_reg[1], 0b11011101); \ + /*Stage2*/ \ + b_reg[0] = _mm512_shuffle_f64x2(a_reg[0], a_reg[1], 0b11011101); \ + b_reg[1] = _mm512_shuffle_f64x2(a_reg[2], a_reg[3], 0b11011101); \ + /*Stage3 3,7*/ \ + c_reg[2] = _mm512_shuffle_f64x2(b_reg[0], b_reg[1], 0b10001000); \ + c_reg[6] = _mm512_shuffle_f64x2(b_reg[0], b_reg[1], 0b11011101); \ + /*Stage2*/ \ + b_reg[0] = _mm512_shuffle_f64x2(a_reg[4], a_reg[5], 0b10001000); \ + b_reg[1] = _mm512_shuffle_f64x2(a_reg[6], a_reg[7], 0b10001000); \ + /*Stage3 2,6*/ \ + c_reg[1] = _mm512_shuffle_f64x2(b_reg[0], b_reg[1], 0b10001000); \ + c_reg[5] = _mm512_shuffle_f64x2(b_reg[0], b_reg[1], 0b11011101); \ + /*Stage2*/ \ + b_reg[0] = _mm512_shuffle_f64x2(a_reg[4], a_reg[5], 0b11011101); \ + b_reg[1] = _mm512_shuffle_f64x2(a_reg[6], a_reg[7], 0b11011101); \ + /*Stage3 4,8*/ \ + c_reg[3] = _mm512_shuffle_f64x2(b_reg[0], b_reg[1], 0b10001000); \ + c_reg[7] = _mm512_shuffle_f64x2(b_reg[0], b_reg[1], 0b11011101); + +#define GEMM_MxN(M, N) \ + UNROLL_LOOP() \ + for (dim_t j = 0; j < k; ++j) \ + { \ + b_reg[0] = _mm512_mask_loadu_pd(c_reg[0], mask_n, b_curr); \ + b_curr += rs_b; \ + UNROLL_LOOP_FULL() \ + for(dim_t ii = 0; ii < M; ++ii) \ + { \ + a_reg[ii] = _mm512_set1_pd(*( a_curr + (rs_a * ii) )); \ + c_reg[ii] = _mm512_fmadd_pd(a_reg[ii] , b_reg[0], c_reg[ii]); \ + } \ + a_curr += cs_a; \ + } \ + + +#define STORE_COL(M, N) \ + if ((*beta) == 0) { STORE_COL_BZ(M, N) } \ + else \ + { \ + TRANSPOSE_8x8() \ + b_reg[0] = _mm512_set1_pd(*(alpha)); \ + b_reg[1] = _mm512_set1_pd(*(beta)); \ + UNROLL_LOOP_FULL() \ + for(dim_t ii = 0; ii < N; ++ii) \ + { \ + c_reg[ii] = _mm512_mul_pd(c_reg[ii], b_reg[0]); \ + a_reg[ii] = _mm512_mask_loadu_pd(c_reg[ii], (1 << (M)) - 1, c + cs_c * ii); \ + c_reg[ii] = _mm512_fmadd_pd(b_reg[1], a_reg[ii], c_reg[ii]); \ + _mm512_mask_storeu_pd(c + cs_c * ii, (1 << (M)) - 1, c_reg[ii]); \ + } \ + } \ + +#define STORE_COL_BZ(M, N) \ + TRANSPOSE_8x8() \ + b_reg[0] = _mm512_set1_pd(*(alpha)); \ + UNROLL_LOOP_FULL() \ + for(dim_t ii = 0; ii < N; ++ii) \ + { \ + c_reg[ii] = _mm512_mul_pd(c_reg[ii], b_reg[0]); \ + _mm512_mask_storeu_pd(c + cs_c * ii, (1 << (M)) - 1, c_reg[ii]); \ + } \ + +#define STORE_COL_LOWER(M, N) \ + if ((*beta) == 0) { STORE_COL_LOWER_BZ(M, N) } \ + else \ + { \ + TRANSPOSE_8x8() \ + b_reg[0] = _mm512_set1_pd(*(alpha)); \ + b_reg[1] = _mm512_set1_pd(*(beta)); \ + UNROLL_LOOP_FULL() \ + for(dim_t ii = 0; ii < N; ++ii) \ + { \ + c_reg[ii] = _mm512_mul_pd(c_reg[ii], b_reg[0]); \ + a_reg[ii] = _mm512_mask_loadu_pd(c_reg[ii], ((1 << (n_rem - ii)) -1) << ii, c + cs_c * ii); \ + c_reg[ii] = _mm512_fmadd_pd(b_reg[1], a_reg[ii], c_reg[ii]); \ + _mm512_mask_storeu_pd(c + cs_c * ii, ((1 << (n_rem - ii)) -1) << ii, c_reg[ii]); \ + } \ + } \ + +#define STORE_COL_LOWER_BZ(M, N) \ + TRANSPOSE_8x8() \ + b_reg[0] = _mm512_set1_pd(*(alpha)); \ + UNROLL_LOOP_FULL() \ + for(dim_t ii = 0; ii < N; ++ii) \ + { \ + c_reg[ii] = _mm512_mul_pd(c_reg[ii], b_reg[0]); \ + _mm512_mask_storeu_pd(c + cs_c * ii, ((1 << (n_rem - ii)) -1) << ii, c_reg[ii]); \ + } \ + +#define STORE_COL_UPPER(M, N) \ + if ((*beta) == 0) { STORE_COL_UPPER_BZ(M, N) } \ + else \ + { \ + TRANSPOSE_8x8() \ + b_reg[0] = _mm512_set1_pd(*(alpha)); \ + b_reg[1] = _mm512_set1_pd(*(beta)); \ + UNROLL_LOOP_FULL() \ + for(dim_t ii = 0; ii < N; ++ii) \ + { \ + c_reg[ii] = _mm512_mul_pd(c_reg[ii], b_reg[0]); \ + a_reg[ii] = _mm512_mask_loadu_pd(c_reg[ii], (1 << (ii+1)) - 1, c + cs_c * ii); \ + c_reg[ii] = _mm512_fmadd_pd(b_reg[1], a_reg[ii], c_reg[ii]); \ + _mm512_mask_storeu_pd(c + cs_c * ii, (1 << (ii+1)) - 1, c_reg[ii]); \ + } \ + } \ + +#define STORE_COL_UPPER_BZ(M, N) \ + TRANSPOSE_8x8() \ + b_reg[0] = _mm512_set1_pd(*(alpha)); \ + UNROLL_LOOP_FULL() \ + for(dim_t ii = 0; ii < N; ++ii) \ + { \ + c_reg[ii] = _mm512_mul_pd(c_reg[ii], b_reg[0]); \ + _mm512_mask_storeu_pd(c + cs_c * ii, (1 << (ii+1)) - 1, c_reg[ii]); \ + } \ + + +#define STORE_ROW(M, N) \ + if ((*beta) == 0) { STORE_ROW_BZ(M, N) } \ + else \ + { \ + b_reg[0] = _mm512_set1_pd(*(alpha)); \ + b_reg[1] = _mm512_set1_pd(*(beta)); \ + UNROLL_LOOP_FULL() \ + for(dim_t ii = 0; ii < M; ++ii) \ + { \ + c_reg[ii] = _mm512_mul_pd(c_reg[ii], b_reg[0]); \ + a_reg[ii] = _mm512_mask_loadu_pd(c_reg[ii], mask_n, c + (rs_c * ii)); \ + c_reg[ii] = _mm512_fmadd_pd(b_reg[1], a_reg[ii], c_reg[ii]); \ + _mm512_mask_storeu_pd(c + (rs_c * ii), mask_n, c_reg[ii]); \ + } \ + } \ + +#define STORE_ROW_BZ(M, N) \ + b_reg[0] = _mm512_set1_pd(*(alpha)); \ + UNROLL_LOOP_FULL() \ + for(dim_t ii = 0; ii < M; ++ii) \ + { \ + c_reg[ii] = _mm512_mul_pd(c_reg[ii], b_reg[0]); \ + _mm512_mask_storeu_pd(c + (rs_c * ii), mask_n, c_reg[ii]); \ + } \ + +#define STORE_ROW_LOWER(M, N) \ + if ((*beta) == 0) { STORE_ROW_LOWER_BZ(M, N) } \ + else \ + { \ + b_reg[0] = _mm512_set1_pd(*(alpha)); \ + b_reg[1] = _mm512_set1_pd(*(beta)); \ + UNROLL_LOOP_FULL() \ + for(dim_t ii = 0; ii < M; ++ii) \ + { \ + c_reg[ii] = _mm512_mul_pd(c_reg[ii], b_reg[0]); \ + a_reg[ii] = _mm512_mask_loadu_pd(c_reg[ii], (1 << (ii+1)) - 1, c + (rs_c * ii)); \ + c_reg[ii] = _mm512_fmadd_pd(b_reg[1], a_reg[ii], c_reg[ii]); \ + _mm512_mask_storeu_pd(c + (rs_c * ii), (1 << (ii+1)) - 1, c_reg[ii]); \ + } \ + } \ + +#define STORE_ROW_LOWER_BZ(M, N) \ + b_reg[0] = _mm512_set1_pd(*(alpha)); \ + UNROLL_LOOP_FULL() \ + for(dim_t ii = 0; ii < M; ++ii) \ + { \ + c_reg[ii] = _mm512_mul_pd(c_reg[ii], b_reg[0]); \ + _mm512_mask_storeu_pd(c + (rs_c * ii), (1 << (ii+1)) - 1, c_reg[ii]); \ + } \ + +#define STORE_ROW_UPPER(M, N) \ + if ((*beta) == 0) { STORE_ROW_UPPER_BZ(M, N) } \ + else \ + { \ + b_reg[0] = _mm512_set1_pd(*(alpha)); \ + b_reg[1] = _mm512_set1_pd(*(beta)); \ + UNROLL_LOOP_FULL() \ + for(dim_t ii = 0; ii < M; ++ii) \ + { \ + c_reg[ii] = _mm512_mul_pd(c_reg[ii], b_reg[0]); \ + a_reg[ii] = _mm512_mask_loadu_pd(c_reg[ii], ((1 << (n_rem - ii)) - 1) << ii, c + (rs_c * ii)); \ + c_reg[ii] = _mm512_fmadd_pd(b_reg[1], a_reg[ii], c_reg[ii]); \ + _mm512_mask_storeu_pd(c + (rs_c * ii), ((1 << (n_rem - ii)) - 1) << ii, c_reg[ii]); \ + } \ + } \ + +#define STORE_ROW_UPPER_BZ(M, N) \ + b_reg[0] = _mm512_set1_pd(*(alpha)); \ + UNROLL_LOOP_FULL() \ + for(dim_t ii = 0; ii < M; ++ii) \ + { \ + c_reg[ii] = _mm512_mul_pd(c_reg[ii], b_reg[0]); \ + _mm512_mask_storeu_pd(c + (rs_c * ii), ((1 << (n_rem - ii)) - 1) << ii, c_reg[ii]); \ + } \ + +#define MAIN_LOOP(M) \ + n_rem = n % 8; \ + if (n_rem == 0) n_rem = 8; \ + ZERO_REGISTERS() \ + b_curr = b; \ + a_curr = a + i * ps_a; \ + mask_n = (1 << (n_rem)) - 1; \ + GEMM_MxN(M, n_rem) \ + if (cs_c == 1) { STORE_ROW(M, n_rem) } \ + else { STORE_COL(M, n_rem) } \ + c += 8 * rs_c; \ + +#define MAIN_LOOP_LOWER_DIAG(M) \ + n_rem = n % 8; \ + if (n_rem == 0) n_rem = 8; \ + ZERO_REGISTERS() \ + b_curr = b; \ + a_curr = a + i * ps_a; \ + mask_n = (1 << (n_rem)) - 1; \ + GEMM_MxN(M, n_rem) \ + if (cs_c == 1) { STORE_ROW_LOWER(M, n_rem) } \ + else { STORE_COL_LOWER(M, n_rem) } \ + c += 8 * rs_c; \ + +#define MAIN_LOOP_UPPER_DIAG(M) \ + n_rem = n % 8; \ + if (n_rem == 0) n_rem = 8; \ + ZERO_REGISTERS() \ + b_curr = b; \ + a_curr = a + i * ps_a; \ + mask_n = (1 << (n_rem)) - 1; \ + GEMM_MxN(M, n_rem) \ + if (cs_c == 1) { STORE_ROW_UPPER(M, n_rem) } \ + else { STORE_COL_UPPER(M, n_rem) } \ + c += 8 * rs_c; \ + +void bli_dgemmsup_rv_zen4_asm_8x8m + ( + conj_t conja, + conj_t conjb, + dim_t m, + dim_t n, + dim_t k, + double* restrict alpha, + double* restrict a, inc_t rs_a, inc_t cs_a, + double* restrict b, inc_t rs_b, inc_t cs_b, + double* restrict beta, + double* restrict c_, inc_t rs_c, inc_t cs_c, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t ps_a = bli_auxinfo_ps_a( data ); + __m512d c_reg[8]; + __m512d a_reg[8]; + __m512d b_reg[2]; + __mmask8 mask_n; + dim_t n_rem; + dim_t m_main = m / 8; + dim_t m_rem = m % 8; + double *a_curr = a, *b_curr, *c = c_; + dim_t i =0; + for (i = 0; i < m_main; i++) + { + MAIN_LOOP(8); + } + switch (m_rem) + { + case 1: + MAIN_LOOP(1); break; + case 2: + MAIN_LOOP(2); break; + case 3: + MAIN_LOOP(3); break; + case 4: + MAIN_LOOP(4); break; + case 5: + MAIN_LOOP(5); break; + case 6: + MAIN_LOOP(6); break; + case 7: + MAIN_LOOP(7); break; + } +} + +void bli_dgemmsup_rv_zen4_asm_8x8m_lower + ( + conj_t conja, + conj_t conjb, + dim_t m, + dim_t n, + dim_t k, + double* restrict alpha, + double* restrict a, inc_t rs_a, inc_t cs_a, + double* restrict b, inc_t rs_b, inc_t cs_b, + double* restrict beta, + double* restrict c_, inc_t rs_c, inc_t cs_c, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t ps_a = bli_auxinfo_ps_a( data ); + __m512d c_reg[8]; + __m512d a_reg[8]; + __m512d b_reg[2]; + __mmask8 mask_n; + dim_t n_rem; + dim_t m_main = m / 8; + dim_t m_rem = m % 8; + double *a_curr = a, *b_curr, *c = c_; + dim_t i = 0; + for (i = 0; i < m_main; i++) + { + MAIN_LOOP_LOWER_DIAG(8); + } + switch (m_rem) + { + case 1: + MAIN_LOOP_LOWER_DIAG(1); break; + case 2: + MAIN_LOOP_LOWER_DIAG(2); break; + case 3: + MAIN_LOOP_LOWER_DIAG(3); break; + case 4: + MAIN_LOOP_LOWER_DIAG(4); break; + case 5: + MAIN_LOOP_LOWER_DIAG(5); break; + case 6: + MAIN_LOOP_LOWER_DIAG(6); break; + case 7: + MAIN_LOOP_LOWER_DIAG(7); break; + } +} + +void bli_dgemmsup_rv_zen4_asm_8x8m_upper + ( + conj_t conja, + conj_t conjb, + dim_t m, + dim_t n, + dim_t k, + double* restrict alpha, + double* restrict a, inc_t rs_a, inc_t cs_a, + double* restrict b, inc_t rs_b, inc_t cs_b, + double* restrict beta, + double* restrict c_, inc_t rs_c, inc_t cs_c, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t ps_a = bli_auxinfo_ps_a( data ); + __m512d c_reg[8]; + __m512d a_reg[8]; + __m512d b_reg[2]; + __mmask8 mask_n; + dim_t n_rem; + dim_t m_main = m / 8; + dim_t m_rem = m % 8; + double *a_curr = a, *b_curr, *c = c_; + dim_t i = 0; + for (i = 0; i < m_main; i++) + { + MAIN_LOOP_UPPER_DIAG(8); + } + switch (m_rem) + { + case 1: + MAIN_LOOP_UPPER_DIAG(1); break; + case 2: + MAIN_LOOP_UPPER_DIAG(2); break; + case 3: + MAIN_LOOP_UPPER_DIAG(3); break; + case 4: + MAIN_LOOP_UPPER_DIAG(4); break; + case 5: + MAIN_LOOP_UPPER_DIAG(5); break; + case 6: + MAIN_LOOP_UPPER_DIAG(6); break; + case 7: + MAIN_LOOP_UPPER_DIAG(7); break; + } +} + +/* + 8x8 lower triangular DGEMMT kernel + This kernels expects M <= 8; + + Region marked by '*' is computed by this kernel + Region marked by '-' is not computed. + ________ + |*-------| + |**------| + |***-----| + |****----| + |*****---| + |******--| + |*******-| + |********| + ________ +*/ +void bli_dgemmsup_rv_zen4_asm_8x8m_lower_mle8 + ( + conj_t conja, + conj_t conjb, + dim_t m, + dim_t n, + dim_t k, + double* restrict alpha, + double* restrict a, inc_t rs_a, inc_t cs_a, + double* restrict b, inc_t rs_b, inc_t cs_b, + double* restrict beta, + double* restrict c_, inc_t rs_c, inc_t cs_c, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t ps_a = bli_auxinfo_ps_a( data ); + __m512d c_reg[8]; + __m512d a_reg[8]; + __m512d b_reg[2]; + __mmask8 mask_n; + dim_t n_rem; + dim_t m_rem = m % 8; + double *a_curr = a, *b_curr, *c = c_; + dim_t i = 0; + if (m == 8) + { + MAIN_LOOP_LOWER_DIAG(8); + } + switch (m_rem) + { + case 1: + MAIN_LOOP_LOWER_DIAG(1); break; + case 2: + MAIN_LOOP_LOWER_DIAG(2); break; + case 3: + MAIN_LOOP_LOWER_DIAG(3); break; + case 4: + MAIN_LOOP_LOWER_DIAG(4); break; + case 5: + MAIN_LOOP_LOWER_DIAG(5); break; + case 6: + MAIN_LOOP_LOWER_DIAG(6); break; + case 7: + MAIN_LOOP_LOWER_DIAG(7); break; + } +} + +/* + 8x8 Upper triangular DGEMMT kernel + This kernels expects M <= 8; + + Region marked by '*' is computed by this kernel + Region marked by '-' is not computed. + ________ + |********| + |-*******| + |--******| + |---*****| + |----****| + |-----***| + |------**| + |-------*| + ________ +*/ +void bli_dgemmsup_rv_zen4_asm_8x8m_upper_mle8 + ( + conj_t conja, + conj_t conjb, + dim_t m, + dim_t n, + dim_t k, + double* restrict alpha, + double* restrict a, inc_t rs_a, inc_t cs_a, + double* restrict b, inc_t rs_b, inc_t cs_b, + double* restrict beta, + double* restrict c_, inc_t rs_c, inc_t cs_c, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t ps_a = bli_auxinfo_ps_a( data ); + __m512d c_reg[8]; + __m512d a_reg[8]; + __m512d b_reg[2]; + __mmask8 mask_n; + dim_t n_rem; + // dim_t m_main = m / 8; + dim_t m_rem = m % 8; + double *a_curr = a, *b_curr, *c = c_; + dim_t i = 0; + // for (i = 0; i < m_main; i++) + if (m == 8) + { + MAIN_LOOP_UPPER_DIAG(8); + } + switch (m_rem) + { + case 1: + MAIN_LOOP_UPPER_DIAG(1); break; + case 2: + MAIN_LOOP_UPPER_DIAG(2); break; + case 3: + MAIN_LOOP_UPPER_DIAG(3); break; + case 4: + MAIN_LOOP_UPPER_DIAG(4); break; + case 5: + MAIN_LOOP_UPPER_DIAG(5); break; + case 6: + MAIN_LOOP_UPPER_DIAG(6); break; + case 7: + MAIN_LOOP_UPPER_DIAG(7); break; + } +} + +/* + The diagonal pattern repeats after every block of + size 24x24, therefore three 24x8 kernels are added to + make sure that entire 24x24 block gets covered. + + Diagram for Lower traingular 24x24 block + + lower_0 lower_1 lower_2 + ________ ________ ________ + |*-------|--------|--------| + |**------|--------|--------| + |***-----|--------|--------| + |****----|--------|--------| + |*****---|--------|--------| + |******--|--------|--------| + |*******-|--------|--------| + |********|--------|--------| + ________ ________ ________ + |********|*-------|--------| + |********|**------|--------| + |********|***-----|--------| + |********|****----|--------| + |********|*****---|--------| + |********|******--|--------| + |********|*******-|--------| + |********|********|--------| + ________ ________ ________ + |********|********|*-------| + |********|********|**------| + |********|********|***-----| + |********|********|****----| + |********|********|*****---| + |********|********|******--| + |********|********|*******-| + |********|********|********| + ________ ________ ________ +*/ + +/* + 24x8 Lower traingular kernel, which computes the + first 24x8 micro panel of the 24x24 repeating block + + Region marked by '*' is computed by this kernel + Region marked by '-' is not computed. + ________ + |*-------| < + |**------| | + |***-----| | + |****----| intial 8x8 triangular panel + |*****---| | + |******--| | + |*******-| > + ________ + |********| < + |********| | + |********| | + |********| | + |********| + |********| + |********| 16x8 full GEMM panel + |********| + |********| + |********| + |********| + |********| | + |********| | + |********| | + |********| > + ________ +*/ +void bli_dgemmsup_rv_zen4_asm_24x8m_lower_0 + ( + conj_t conja, + conj_t conjb, + dim_t m, + dim_t n, + dim_t k, + double* restrict alpha, + double* restrict a, inc_t rs_a, inc_t cs_a, + double* restrict b, inc_t rs_b, inc_t cs_b, + double* restrict beta, + double* restrict c_, inc_t rs_c, inc_t cs_c, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + dim_t m_diag; // m for traingular kernel + dim_t m_full; // m for full GEMM kernel + // if m <= 8 then only diagonal region needs to be + // computed, therefor set m_full to 0. + if (m <= 8) + { + // if m <= 8, m_diag = 8 , m_full = 0 + m_diag = m; + m_full = 0; + } + // if m > 8, then full diagonal(m=8) needs to be computed + // and remaning m (m - 8) will be computed by DGEMM SUP kernel. + else + { + m_diag = 8; + m_full = m - 8; + } + + // since the 8x8m kernel is row major, + // call row major 8x8m upper diagonal kernel after + // inducing transpose to solve column major lower + // triangular GEMM + bli_dgemmsup_rv_zen4_asm_8x8m_upper_mle8 + ( + conjb, + conja, + n, + m_diag, + k, + alpha, + b, cs_b, rs_b, + a, cs_a, rs_a, + beta, + c_, cs_c, rs_c, + data, + cntx + ); + + // call full GEMM kernel for remaning parts of matrix + bli_dgemmsup_rv_zen4_asm_24x8m + ( + conja, + conjb, + m_full, + n, + k, + alpha, + a + (rs_a * m_diag), rs_a, cs_a, + b, rs_b, cs_b, + beta, + c_ + (rs_c * m_diag), rs_c, cs_c, + data, + cntx + ); +} + +/* + 24x8 Lower traingular kernel, which computes the + second 24x8 micro panel of the 24x24 repeating block + + Region marked by '*' is computed by this kernel + Region marked by '-' is not computed. + ________ + |--------| < + |--------| | + |--------| | + |--------| intial empty 8x8 panel + |--------| | + |--------| | + |--------| > + ________ + |*-------| < + |**------| | + |***-----| | + |****----| 8x8 triangular panel + |*****---| | + |******--| | + |*******-| > + ________ + |********| < + |********| | + |********| | + |********| | + |********| 8x8 full GEMM panel + |********| | + |********| | + |********| > + ________ +*/ +void bli_dgemmsup_rv_zen4_asm_24x8m_lower_1 + ( + conj_t conja, + conj_t conjb, + dim_t m, + dim_t n, + dim_t k, + double* restrict alpha, + double* restrict a, inc_t rs_a, inc_t cs_a, + double* restrict b, inc_t rs_b, inc_t cs_b, + double* restrict beta, + double* restrict c_, inc_t rs_c, inc_t cs_c, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + dim_t m_diag; // m for traingular kernel + dim_t m_full; // m for full GEMM kenrel + + // if m is less than 8, then only empty region is computed + // therefore set m_diag and m_full to 0. + if (m <= 8) + { + m_diag = 0; + m_full = 0; + } + // if m_diag is less than 16, then only empty region and triangular + // region needs to be computed, therefor set m_full to 0. + else if ( m <= 16) + { + m_diag = m - 8; + m_full = 0; + } + else + { + m_diag = 8; + m_full = m - 16; + } + + // since the 8x8m kernel is row major, + // call row major 8x8m upper diagonal kernel after + // inducing transpose to solve column major lower + // triangular GEMM + bli_dgemmsup_rv_zen4_asm_8x8m_upper_mle8 + ( + conjb, + conja, + n, + m_diag, + k, + alpha, + b, cs_b, rs_b, + a + (rs_a * 8), cs_a, rs_a, + beta, + c_ + (rs_c * 8), cs_c, rs_c, + data, + cntx + ); + + // call full GEMM kernel for remaning parts of matrix + bli_dgemmsup_rv_zen4_asm_24x8m + ( + conja, + conjb, + m_full, + n, + k, + alpha, + a + (rs_a*(8+m_diag)), rs_a, cs_a, + b, rs_b, cs_b, + beta, + c_ + (rs_c * (8+m_diag)), rs_c, cs_c, + data, + cntx + ); +} + +/* + 24x8 Lower traingular kernel, which computes the + third 24x8 micro panel of the 24x24 repeating block + + Region marked by '*' is computed by this kernel + Region marked by '-' is not computed. + ________ + |--------| < + |--------| | + |--------| | + |--------| | + |--------| | + |--------| | + |--------| | + |--------| | + |--------| intial empty 16x8 panel + |--------| | + |--------| | + |--------| | + |--------| | + |--------| | + |--------| | + |--------| > + ________ + |*-------| < + |**------| | + |***-----| | + |****----| 8x8 triangular panel + |*****---| | + |******--| | + |*******-| > + ________ +*/ +void bli_dgemmsup_rv_zen4_asm_24x8m_lower_2 + ( + conj_t conja, + conj_t conjb, + dim_t m, + dim_t n, + dim_t k, + double* restrict alpha, + double* restrict a, inc_t rs_a, inc_t cs_a, + double* restrict b, inc_t rs_b, inc_t cs_b, + double* restrict beta, + double* restrict c_, inc_t rs_c, inc_t cs_c, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + dim_t m_diag; // m for traingular kernel + dim_t m_full; // m for full GEMM kernel + + // if m <= 16, only empty region needs to be computed. + if (m <= 16) + { + m_diag = 0; + m_full = 0; + } + + // if m <= 24, initial 16 rows are empty and there is no full + // gemm region, therefore m_diag = 0 + else if (m <= 24) + { + m_diag = m - 16; + m_full = 0; + } + else + { + m_diag = 8; + m_full = m - 24; // m - (16(empty) + 8(diagonal)) + } + + // since the 8x8m kernel is row major, + // call row major 8x8m upper diagonal kernel after + // inducing transpose to solve column major lower + // triangular GEMM + bli_dgemmsup_rv_zen4_asm_8x8m_upper_mle8 + ( + conjb, + conja, + n, + m_diag, + k, + alpha, + b, cs_b, rs_b, + a + (rs_a * 16), cs_a, rs_a, + beta, + c_ + (rs_c * 16), cs_c, rs_c, + data, + cntx + ); + + // call full GEMM kernel for remaning parts of matrix + bli_dgemmsup_rv_zen4_asm_24x8m + ( + conja, + conjb, + m_full, + n, + k, + alpha, + a + (rs_a*(16+m_diag)), rs_a, cs_a, + b, rs_b, cs_b, + beta, + c_ + (rs_c * (16+m_diag)), rs_c, cs_c, + data, + cntx + ); +} + +/* + The diagonal pattern repeats after every block of + size 24x24, therefore three 24x8 kernels are added to + make sure that entire 24x24 block gets covered. + + Diagram for Upper traingular 24x24 block + + upper_0 upper_1 upper_2 + ________ ________ ________ + |********|********|********| + |-*******|********|********| + |--******|********|********| + |---*****|********|********| + |----****|********|********| + |-----***|********|********| + |------**|********|********| + |-------*|********|********| + ________ ________ ________ + |--------|********|********| + |--------|-*******|********| + |--------|--******|********| + |--------|---*****|********| + |--------|----****|********| + |--------|-----***|********| + |--------|------**|********| + |--------|-------*|********| + ________ ________ ________ + |--------|--------|********| + |--------|--------|-*******| + |--------|--------|--******| + |--------|--------|---*****| + |--------|--------|----****| + |--------|--------|-----***| + |--------|--------|------**| + |--------|--------|-------*| + ________ ________ ________ + +*/ + +/* + 24x8 Upper traingular kernel, which computes the + first 24x8 micro panel of the 24x24 repeating block + + Region marked by '*' is computed by this kernel + Region marked by '-' is not computed. + ________ + |********| < + |-*******| | + |--******| | + |---*****| intial 8x8 triangular block + |----****| | + |-----***| | + |------**| | + |-------*| > + ________ + |--------| + |--------| + |--------| + |--------| + |--------| + |--------| + |--------| + |--------| + |--------| + |--------| + |--------| + |--------| + |--------| + |--------| + |--------| + |--------| + ________ +*/ +void bli_dgemmsup_rv_zen4_asm_24x8m_upper_0 + ( + conj_t conja, + conj_t conjb, + dim_t m, + dim_t n, + dim_t k, + double* restrict alpha, + double* restrict a, inc_t rs_a, inc_t cs_a, + double* restrict b, inc_t rs_b, inc_t cs_b, + double* restrict beta, + double* restrict c_, inc_t rs_c, inc_t cs_c, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + dim_t m_diag; // m for traingular kernel + dim_t m_full; // m for full GEMM kenrel + + // if m <= 8, then only diagonal region exists + // therefore m_full = 0 + if (m <= 8) + { + m_diag = m; + m_full = 0; + } + + // if m >= 8, then initial 8 rows are computed + // by DGEMM SUP kernel, and last 16 rows are empty + else if (m <= 24) + { + m_diag = 8; + m_full = 0; + } + // if m > 24, then compute inital 24 rows with existing + // logic and use DGEMM SUP kernel for remainder. + else + { + m_diag = 8; + m_full = m - 24; // m - (16(empty) + 8(diagonal)) + } + + // call full GEMM kernel for intial part of matrix + bli_dgemmsup_rv_zen4_asm_24x8m + ( + conja, + conjb, + m_full, + n, + k, + alpha, + a, rs_a, cs_a, + b, rs_b, cs_b, + beta, + c_, rs_c, cs_c, + data, + cntx + ); + + // since the 8x8m kernel is row major, + // call row major 8x8m lower diagonal kernel after + // inducing transpose to solve column major upper + // triangular GEMM + bli_dgemmsup_rv_zen4_asm_8x8m_lower_mle8 + ( + conjb, + conja, + n, + m_diag, + k, + alpha, + b, cs_b, rs_b, + a + (rs_a*m_full), cs_a, rs_a, + beta, + c_ + (rs_c * m_full), cs_c, rs_c, + data, + cntx + ); +} + +/* + 24x8 Upper traingular kernel, which computes the + second 24x8 micro panel of the 24x24 repeating block + + Region marked by '*' is computed by this kernel + Region marked by '-' is not computed. + ________ + |********| < + |********| | + |********| | + |********| 8x8 full GEMM block + |********| | + |********| | + |********| | + |********| > + ________ + |********| < + |-*******| | + |--******| | + |---*****| 8x8 triangular block + |----****| | + |-----***| | + |------**| | + |-------*| > + ________ + |--------| + |--------| + |--------| + |--------| + |--------| + |--------| + |--------| + |--------| + ________ +*/ +void bli_dgemmsup_rv_zen4_asm_24x8m_upper_1 + ( + conj_t conja, + conj_t conjb, + dim_t m, + dim_t n, + dim_t k, + double* restrict alpha, + double* restrict a, inc_t rs_a, inc_t cs_a, + double* restrict b, inc_t rs_b, inc_t cs_b, + double* restrict beta, + double* restrict c_, inc_t rs_c, inc_t cs_c, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + dim_t m_diag, m_full; + if (m <= 8) + { + m_diag = m; + m_full = 0; + } + else if (m <= 16) + { + m_diag = 8; + m_full = 0; + } + else + { + m_diag = 8; + m_full = m - 16; + } + + // call full GEMM kernel for intial part of matrix + bli_dgemmsup_rv_zen4_asm_24x8m + ( + conja, + conjb, + m_full, + n, + k, + alpha, + a, rs_a, cs_a, + b, rs_b, cs_b, + beta, + c_, rs_c, cs_c, + data, + cntx + ); + + // since the 8x8m kernel is row major, + // call row major 8x8m lower diagonal kernel after + // inducing transpose to solve column major upper + // triangular GEMM + bli_dgemmsup_rv_zen4_asm_8x8m_lower_mle8 + ( + conjb, + conja, + n, + m_diag, + k, + alpha, + b, cs_b, rs_b, + a + (rs_a*m_full), cs_a, rs_a, + beta, + c_ + (rs_c * m_full), cs_c, rs_c, + data, + cntx + ); +} + +/* + 24x8 Upper traingular kernel, which computes the + second 24x8 micro panel of the 24x24 repeating block + + Region marked by '*' is computed by this kernel + Region marked by '-' is not computed. + ________ + |********| < + |********| | + |********| | + |********| | + |********| | + |********| | + |********| | + |********| 16x8 full GEMM block + |********| | + |********| | + |********| | + |********| | + |********| | + |********| | + |********| | + |********| > + ________ + |********| < + |-*******| | + |--******| | + |---*****| 8x8 triangular block + |----****| | + |-----***| | + |------**| | + |-------*| > + ________ +*/ +void bli_dgemmsup_rv_zen4_asm_24x8m_upper_2 + ( + conj_t conja, + conj_t conjb, + dim_t m, + dim_t n, + dim_t k, + double* restrict alpha, + double* restrict a, inc_t rs_a, inc_t cs_a, + double* restrict b, inc_t rs_b, inc_t cs_b, + double* restrict beta, + double* restrict c_, inc_t rs_c, inc_t cs_c, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + dim_t m_diag, m_full; + if (m <= 8) + { + m_diag = m; + m_full = 0; + } + else + { + m_diag = 8; + m_full = m - 8; + } + + // call full GEMM kernel for intial part of matrix + bli_dgemmsup_rv_zen4_asm_24x8m + ( + conja, + conjb, + m_full, + n, + k, + alpha, + a, rs_a, cs_a, + b, rs_b, cs_b, + beta, + c_, rs_c, cs_c, + data, + cntx + ); + + // since the 8x8m kernel is row major, + // call row major 8x8m lower diagonal kernel after + // inducing transpose to solve column major upper + // triangular GEMM + bli_dgemmsup_rv_zen4_asm_8x8m_lower_mle8 + ( + conjb, + conja, + n, + m_diag, + k, + alpha, + b, cs_b, rs_b, + a + (rs_a*m_full), cs_a, rs_a, + beta, + c_ + (rs_c * m_full), cs_c, rs_c, + data, + cntx + ); +} diff --git a/kernels/zen4/3/sup/bli_gemmsup_rv_zen4_asm_z4x4m.c b/kernels/zen4/3/sup/bli_gemmsup_rv_zen4_asm_z4x4m.c new file mode 100644 index 0000000000..f4de53b978 --- /dev/null +++ b/kernels/zen4/3/sup/bli_gemmsup_rv_zen4_asm_z4x4m.c @@ -0,0 +1,1147 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + + +#include "blis.h" +#include "immintrin.h" + +#if defined __clang__ + #define UNROLL_LOOP() _Pragma("clang loop unroll_count(4)") + /* + * in clang, unroll_count(4) generates inefficient + * code compared to unroll(full) when loopCount = 4. + */ + #define UNROLL_LOOP_FULL() _Pragma("clang loop unroll(full)") +#elif defined __GNUC__ + #define UNROLL_LOOP() _Pragma("GCC unroll 4") + #define UNROLL_LOOP_FULL() _Pragma("GCC unroll 4") +#else + #define UNROLL_LOOP() + #define UNROLL_LOOP_FULL() +#endif + +/*Set registers to zero which are used during fma operation*/ +#define ZERO_REGISTERS() \ + c_reg[0] = _mm512_setzero_pd(); \ + c_reg[1] = _mm512_setzero_pd(); \ + c_reg[2] = _mm512_setzero_pd(); \ + c_reg[3] = _mm512_setzero_pd(); \ + c_imag_reg[0] = _mm512_setzero_pd(); \ + c_imag_reg[1] = _mm512_setzero_pd(); \ + c_imag_reg[2] = _mm512_setzero_pd(); \ + c_imag_reg[3] = _mm512_setzero_pd(); \ + +/*************************************************************/ +/* Transpose contents of R0, R1, R2, R3 and store */ +/* the result to same register */ +/* Transpose 4x4 register */ +/* Input c_reg0 = Ar0 Ai0 Ar1 Ai1 Ar2 Ai2 Ar3 Ai3 */ +/* Input c_reg1 = Ar4 Ai4 Ar5 Ai5 Ar6 Ai6 Ar7 Ai7 */ +/* Input c_reg2 = Ar8 Ai8 Ar9 Ai9 Ar10 Ai10 Ar11 Ai11 */ +/* Input c_reg3 = Ar12 Ai12 Ar13 Ai13 Ar14 Ai14 Ar15 Ai15 */ +/* Inter c_imag_reg0 = Ar0 Ai0 Ar2 Ai2 Ar4 Ai4 Ar6 Ai6 */ +/* Inter c_imag_reg1 = Ar1 Ai1 Ar3 Ai3 Ar5 Ai5 Ar7 Ai7 */ +/* Inter c_imag_reg2 = Ar8 Ai8 Ar10 Ai10 Ar12 Ai12 Ar14 Ai14 */ +/* Inter c_imag_reg3 = Ar9 Ai9 Ar11 Ai11 Ar13 Ai13 Ar15 Ai15 */ +/* Output c_reg0 = Ar0 Ai0 Ar4 Ai4 Ar8 Ai8 Ar12 Ai12 */ +/* Output c_reg1 = Ar1 Ai1 Ar5 Ai5 Ar9 Ai9 Ar13 Ai13 */ +/* Output c_reg2 = Ar2 Ai2 Ar6 Ai6 Ar10 Ai10 Ar14 Ai14 */ +/* Output c_reg3 = Ar3 Ai3 Ar7 Ai7 Ar11 Ai11 Ar15 Ai15 */ +/*************************************************************/ +#define TRANSPOSE_4x4() \ + c_imag_reg[0] = _mm512_shuffle_f64x2(c_reg[0], c_reg[1], 0b10001000); \ + c_imag_reg[1] = _mm512_shuffle_f64x2(c_reg[0], c_reg[1], 0b11011101); \ + c_imag_reg[2] = _mm512_shuffle_f64x2(c_reg[2], c_reg[3], 0b10001000); \ + c_imag_reg[3] = _mm512_shuffle_f64x2(c_reg[2], c_reg[3], 0b11011101); \ + c_reg[0] = _mm512_shuffle_f64x2(c_imag_reg[0], c_imag_reg[2], 0b10001000); \ + c_reg[2] = _mm512_shuffle_f64x2(c_imag_reg[0], c_imag_reg[2], 0b11011101); \ + c_reg[1] = _mm512_shuffle_f64x2(c_imag_reg[1], c_imag_reg[3], 0b10001000); \ + c_reg[3] = _mm512_shuffle_f64x2(c_imag_reg[1], c_imag_reg[3], 0b11011101); + +/****************************************/ +/* Operation: */ +/* c_reg = A(real) * B(real,imag) */ +/* c_imag_reg = A(imag) * B(real,imag) */ +/* Elements: */ +/* MxK elements at a time */ +/* Inputs: */ +/* b_reg = b_curr */ +/* a_reg = a_curr->real */ +/* a_reg = a_curr->imag */ +/* Outputs: */ +/* c_reg = b_reg * a_curr->real */ +/* c_imag_reg = b_reg * a_curr->imag */ +/****************************************/ +#define GEMM_MxN(M,N) \ + UNROLL_LOOP() \ + for (dim_t j = 0; j < k; ++j) \ + { \ + b_reg = _mm512_maskz_loadu_pd(mask_n, b_curr); \ + b_curr += rs_b; \ + UNROLL_LOOP_FULL() \ + for(dim_t ii = 0; ii < M; ++ii) \ + { \ + a_reg[ii] = _mm512_set1_pd(*( (double*)(a_curr + (rs_a * ii) ))); \ + c_reg[ii] = _mm512_fmadd_pd(a_reg[ii] , b_reg, c_reg[ii]); \ + a_reg[ii] = _mm512_set1_pd((a_curr + (rs_a * ii))->imag); \ + c_imag_reg[ii] = _mm512_fmadd_pd(a_reg[ii] , b_reg, c_imag_reg[ii]); \ + } \ + a_curr += cs_a; \ + } + +/****************************************/ +/* Store elements in col order */ +/* c_reg = Beta * C + Alpha * A * B */ +/* Elements: */ +/* MxN elements at a time */ +/* Inputs: */ +/* c_reg = b_reg * a_curr->real */ +/* c_imag_reg = b_reg * a_curr->imag */ +/* Intermediate: */ +/* c_reg = c_reg +/- c_imag_reg */ +/* Transpose 4x4 elements in c_reg */ +/* Output: */ +/* c_reg = Beta * C(real,imag) + */ +/* Alpha * A(real,imag) * B(real,imag) */ +/****************************************/ +#define STORE_COL(M, N) \ + for(dim_t ii = 0; ii < M; ++ii) \ + { \ + a_reg[ii] = _mm512_permute_pd(c_imag_reg[ii], 0b01010101); \ + c_reg[ii] = _mm512_fmaddsub_pd(c_reg[ii], one_reg, a_reg[ii]); \ + } \ + TRANSPOSE_4x4() \ + if ((((beta->real) == 0) && (beta->imag) == 0) ) { STORE_COL_BZ(M, N) } \ + else \ + { \ + UNROLL_LOOP_FULL() \ + for(dim_t ii = 0; ii < N; ++ii) \ + { \ + SCALE_ALPHA_COL(M) \ + SCALE_BETA(mask_n, cs_c) \ + _mm512_mask_storeu_pd(c + cs_c * ii, mask_n, c_reg[ii]); \ + } \ + } \ + +/****************************************/ +/* Operation: */ +/* Scale reg with alpha value and */ +/* store elements in col major order */ +/* where Beta = 0 */ +/* Elements: */ +/* Nx4 elements at a time */ +/* Input: */ +/* c_reg = A(real, imag) * B(real, img) */ +/* Output: */ +/* c_reg = Alpha * A(real, imag) * */ +/* B(real, img) */ +/****************************************/ +#define STORE_COL_BZ(M, N) \ + UNROLL_LOOP_FULL() \ + for(dim_t ii = 0; ii < N; ++ii) \ + { \ + SCALE_ALPHA_COL(M) \ + _mm512_mask_storeu_pd(c + cs_c * ii, mask_n, c_reg[ii]); \ + } \ + +/****************************************/ +/* Operation: */ +/* 1. Load C register based on the mask */ +/* and scale it with beta */ +/* 2. Scale A*B result with alpha value */ +/* 3. Add results from step1 & step2 */ +/* 4. Transpose and store results in */ +/* in col major order */ +/* 5. Output update is done only for */ +/* lower traingular matrix */ +/* NOTE: */ +/* Mask value is set to 1 if the */ +/* element exist else it is set to 0 */ +/* For m=1, mask = 2 to store real and */ +/* imag component */ +/* Elements: */ +/* Nx4 elements at a time */ +/* Input: */ +/* c_reg = A(real, imag) * B(real, img) */ +/* Output: */ +/* c_reg = Beta * C + */ +/* Alpha * A(real, imag) * */ +/* B(real, img) */ +/****************************************/ +#define STORE_COL_LOWER(M, N) \ + for(dim_t ii = 0; ii < M; ++ii) \ + { \ + a_reg[ii] = _mm512_permute_pd(c_imag_reg[ii], 0b01010101); \ + c_reg[ii] = _mm512_fmaddsub_pd(c_reg[ii], one_reg, a_reg[ii]); \ + } \ + TRANSPOSE_4x4() \ + if ((((beta->real) == 0) && (beta->imag) == 0) ) { STORE_COL_LOWER_BZ(M, N) } \ + else \ + { \ + UNROLL_LOOP_FULL() \ + for(dim_t ii = 0; ii < N; ++ii) \ + { \ + SCALE_ALPHA_COL(M) \ + mask_n = ((1 << ((n_rem*2) - (ii*2))) -1) << (ii*2); \ + SCALE_BETA(mask_n, cs_c) \ + _mm512_mask_storeu_pd(c + cs_c * ii, mask_n, c_reg[ii]); \ + } \ + } \ + +/****************************************/ +/* Operation: */ +/* Scale reg with alpha value and store */ +/* number of elements based on the mask */ +/* in col major order where Beta = 0 */ +/* Output update is done only for */ +/* lower traingular matrix */ +/* Elements: */ +/* Nx4 elements at a time */ +/* Input: */ +/* c_reg = A(real, imag) * B(real, img) */ +/* Output: */ +/* c_reg = Alpha * A(real, imag) * */ +/* B(real, img) */ +/****************************************/ +#define STORE_COL_LOWER_BZ(M, N) \ + UNROLL_LOOP_FULL() \ + for(dim_t ii = 0; ii < N; ++ii) \ + { \ + SCALE_ALPHA_COL(M) \ + mask_n = ((1 << ((n_rem*2) - (ii*2))) - 1) << (ii*2); \ + _mm512_mask_storeu_pd(c + cs_c * ii, mask_n, c_reg[ii]); \ + } \ + +/****************************************/ +/* Operation: */ +/* 1. Load C register based on the mask */ +/* and scale it with beta */ +/* 2. Scale A*B result with alpha value */ +/* 3. Add results from step1 & step2 */ +/* 4. Transpose and store results in */ +/* in col major order */ +/* 5. Output update is done only for */ +/* upper traingular matrix */ +/* NOTE: */ +/* Mask value is set to 1 if the */ +/* element exist else it is set to 0 */ +/* For m=1, mask = 2 to store real and */ +/* imag component */ +/* Elements: */ +/* MxN elements at a time */ +/* Inputs: */ +/* c_reg = A(real, imag) * B(real, img) */ +/* Output: */ +/* c_reg = Beta * C + */ +/* Alpha * A(real, imag) * */ +/* B(real, img) */ +/****************************************/ +#define STORE_COL_UPPER(M, N) \ + for(dim_t ii = 0; ii < M; ++ii) \ + { \ + a_reg[ii] = _mm512_permute_pd(c_imag_reg[ii], 0b01010101); \ + c_reg[ii] = _mm512_fmaddsub_pd(c_reg[ii], one_reg, a_reg[ii]); \ + } \ + TRANSPOSE_4x4() \ + if ((((beta->real) == 0) && (beta->imag) == 0) ) { STORE_COL_UPPER_BZ(M, N) } \ + else \ + { \ + UNROLL_LOOP_FULL() \ + for(dim_t ii = 0; ii < N; ++ii) \ + { \ + SCALE_ALPHA_COL(M) \ + mask_n = (1 << ((ii+1)*2)) - 1; \ + SCALE_BETA(mask_n, cs_c) \ + _mm512_mask_storeu_pd(c + cs_c * ii, mask_n, c_reg[ii]); \ + } \ + } \ + +/****************************************/ +/* Operation: */ +/* Scale reg with alpha value and store */ +/* number of elements based on the mask */ +/* in col major order where Beta = 0 */ +/* Output update is done only for */ +/* upper traingular matrix */ +/* Elements: */ +/* Nx4 elements at a time */ +/* Inputs: */ +/* c_reg = A(real, imag) * B(real, img) */ +/* Output: */ +/* c_reg = Alpha * A(real, imag) * */ +/* B(real, img) */ +/****************************************/ +#define STORE_COL_UPPER_BZ(M, N) \ + UNROLL_LOOP_FULL() \ + for(dim_t ii = 0; ii < N; ++ii) \ + { \ + SCALE_ALPHA_COL(M) \ + mask_n = (1 << (((ii+1)*2))) - 1; \ + _mm512_mask_storeu_pd(c + cs_c * ii, mask_n, c_reg[ii]); \ + } \ + +/****************************************/ +/* Operation: */ +/* Scale reg with alpha value and */ +/* store elements in row major order */ +/* where Beta = 0 */ +/* Elements: */ +/* Mx4 elements at a time */ +/* Input: */ +/* c_reg = A(real, imag) * B(real, img) */ +/* Output: */ +/* c_reg = Alpha * A(real, imag) * */ +/* B(real, img) */ +/****************************************/ +#define STORE_ROW_BZ(M, N) \ + UNROLL_LOOP_FULL() \ + for(dim_t ii = 0; ii < M; ++ii) \ + { \ + SCALE_ALPHA(M) \ + _mm512_mask_storeu_pd(c + (rs_c * ii), mask_n, c_reg[ii]); \ + } \ + +/****************************************/ +/* Store elements in row major order */ +/* Elements: */ +/* Mx4 elements at a time */ +/* Inputs: */ +/* c_reg = b_reg * a_curr->real */ +/* c_imag_reg = b_reg * a_curr->imag */ +/* Intermediate: */ +/* c_reg = c_reg +/- c_imag_reg */ +/* Output: */ +/* c_reg = Beta * C(real,imag) + */ +/* Alpha * A(real,imag) * B(real,imag) */ +/****************************************/ +#define STORE_ROW(M, N) \ + if ((((beta->real) == 0) && (beta->imag) == 0) ) { STORE_ROW_BZ(M, N) } \ + else \ + { \ + UNROLL_LOOP_FULL() \ + for(dim_t ii = 0; ii < M; ++ii) \ + { \ + SCALE_ALPHA(M) \ + SCALE_BETA(mask_n, rs_c) \ + _mm512_mask_storeu_pd(c + (rs_c * ii), mask_n, c_reg[ii]); \ + } \ + } \ + +/****************************************/ +/* Scale A * B matrix with alpha value */ +/* Elements: */ +/* 4 elements at a time */ +/* Inputs: */ +/* c_reg = b_reg * a_curr->real */ +/* c_imag_reg = b_reg * a_curr->imag */ +/* Output: */ +/* c_reg = Alpha * A(real,imag) * */ +/* B(real,imag) */ +/****************************************/ +#define SCALE_ALPHA(M)\ + a_reg[ii] = _mm512_permute_pd(c_imag_reg[ii], 0b01010101); \ + c_reg[ii] = _mm512_fmaddsub_pd(c_reg[ii], one_reg, a_reg[ii]); \ + c_imag_reg[ii] = _mm512_permute_pd(c_reg[ii], 0b01010101); \ + c_reg[ii] = _mm512_mul_pd(c_reg[ii], alpha_reg); \ + c_imag_reg[ii] = _mm512_mul_pd(c_imag_reg[ii], alpha_imag_reg); \ + c_reg[ii] = _mm512_fmaddsub_pd(c_reg[ii], one_reg, c_imag_reg[ii]); \ + +/****************************************/ +/* Scale A * B matrix with alpha value */ +/* Elements: */ +/* 4 elements at a time */ +/* Input: */ +/* c_reg = A * B */ +/* Output: */ +/* c_reg = Alpha * A(real,imag) * */ +/* B(real,imag) */ +/****************************************/ +#define SCALE_ALPHA_COL(M)\ + c_imag_reg[ii] = _mm512_permute_pd(c_reg[ii], 0b01010101); \ + c_reg[ii] = _mm512_mul_pd(c_reg[ii], alpha_reg); \ + c_imag_reg[ii] = _mm512_mul_pd(c_imag_reg[ii], alpha_imag_reg); \ + c_reg[ii] = _mm512_fmaddsub_pd(c_reg[ii], one_reg, c_imag_reg[ii]); \ + +/****************************************/ +/* Scale C matrix with beta value */ +/* Elements: */ +/* 4 elements at a time */ +/* Mask is set based on M elements */ +/* Output : */ +/* c_reg = Beta * C */ +/****************************************/ +#define SCALE_BETA(mask_n, stride) \ + a_reg[ii] = _mm512_maskz_loadu_pd(mask_n, c + (stride * ii)); \ + c_imag_reg[ii] = _mm512_permute_pd(a_reg[ii], 0b01010101); \ + a_reg[ii] = _mm512_mul_pd(a_reg[ii], beta_reg); \ + c_imag_reg[ii] = _mm512_mul_pd(c_imag_reg[ii], beta_imag_reg); \ + a_reg[ii] = _mm512_fmaddsub_pd(a_reg[ii], one_reg, c_imag_reg[ii]); \ + c_reg[ii] = _mm512_add_pd(a_reg[ii], c_reg[ii]); \ + +/****************************************/ +/* Operation: */ +/* 1. Load C register based on the mask */ +/* and scale it with beta */ +/* 2. Scale A*B result with alpha value */ +/* 3. Add results from step1 & step2 */ +/* 4. Transpose and store results in */ +/* in row major order */ +/* 5. Output update is done only for */ +/* lower traingular matrix */ +/* NOTE: */ +/* Mask value is set to 1 if the */ +/* element exist else it is set to 0 */ +/* For m=1, mask = 2 to store real and */ +/* imag component */ +/* Elements: */ +/* Nx4 elements at a time */ +/* Input: */ +/* c_reg = A(real, imag) * B(real, img) */ +/* Output: */ +/* c_reg = Beta * C + */ +/* Alpha * A(real, imag) * */ +/* B(real, img) */ +/****************************************/ +#define STORE_ROW_LOWER(M, N) \ + if ((((beta->real) == 0) && (beta->imag) == 0) ) { STORE_ROW_LOWER_BZ(M, N) } \ + else \ + { \ + UNROLL_LOOP_FULL() \ + for(dim_t ii = 0; ii < M; ++ii) \ + { \ + SCALE_ALPHA(M) \ + mask_n = (1 << ((ii+1)*2)) - 1; \ + SCALE_BETA(mask_n, rs_c) \ + _mm512_mask_storeu_pd(c + (rs_c * ii), mask_n, c_reg[ii]); \ + } \ + } \ + +/****************************************/ +/* Operation: */ +/* Scale reg with alpha value and store */ +/* number of elements based on the mask */ +/* in row major order where Beta = 0 */ +/* Output update is done only for */ +/* lower traingular matrix */ +/* Elements: */ +/* Nx4 elements at a time */ +/* Input: */ +/* c_reg = A(real, imag) * B(real, img) */ +/* Output: */ +/* c_reg = Alpha * A(real, imag) * */ +/* B(real, img) */ +/****************************************/ +#define STORE_ROW_LOWER_BZ(M, N) \ + UNROLL_LOOP_FULL() \ + for(dim_t ii = 0; ii < M; ++ii) \ + { \ + SCALE_ALPHA(M) \ + mask_n = (1 << ((ii+1)*2)) - 1; \ + _mm512_mask_storeu_pd(c + (rs_c * ii), mask_n, c_reg[ii]); \ + } \ + +/****************************************/ +/* Operation: */ +/* Scale reg with alpha value and store */ +/* number of elements based on the mask */ +/* in row major order where Beta = 0 */ +/* Output update is done only for */ +/* upper traingular matrix */ +/* Elements: */ +/* Nx4 elements at a time */ +/* Inputs: */ +/* c_reg = A(real, imag) * B(real, img) */ +/* Output: */ +/* c_reg = Alpha * A(real, imag) * */ +/* B(real, img) */ +/****************************************/ +#define STORE_ROW_UPPER(M, N) \ + if ((((beta->real) == 0) && (beta->imag) == 0) ) { STORE_ROW_UPPER_BZ(M, N) } \ + else \ + { \ + UNROLL_LOOP_FULL() \ + for(dim_t ii = 0; ii < M; ++ii) \ + { \ + SCALE_ALPHA(M) \ + mask_n = ((1 << ((n_rem*2) - (ii*2))) - 1) << (ii*2); \ + SCALE_BETA(mask_n, rs_c) \ + _mm512_mask_storeu_pd(c + (rs_c * ii), mask_n, c_reg[ii]); \ + } \ + } \ + +/****************************************/ +/* Operation: */ +/* Scale reg with alpha value and store */ +/* number of elements based on the mask */ +/* in row major order where Beta = 0 */ +/* Output update is done only for */ +/* upper traingular matrix */ +/* Elements: */ +/* Nx4 elements at a time */ +/* Inputs: */ +/* c_reg = A(real, imag) * B(real, img) */ +/* Output: */ +/* c_reg = Alpha * A(real, imag) * */ +/* B(real, img) */ +/****************************************/ +#define STORE_ROW_UPPER_BZ(M, N) \ + UNROLL_LOOP_FULL() \ + for(dim_t ii = 0; ii < M; ++ii) \ + { \ + SCALE_ALPHA(M) \ + mask_n = (((1 << ((n_rem*2) - (ii*2)))) - 1) << (ii*2); \ + _mm512_mask_storeu_pd(c + (rs_c * ii), mask_n, c_reg[ii]); \ + } \ + +/****************************************/ +/* Perform C = C * Beta + Alpha * A * B */ +/* Below functions are categorised based*/ +/* on row/col order and upper/lower */ +/* 1. Calculate n_rem for 4x4 blocks */ +/* 2. Set AVX register to zero which */ +/* are used during fma operation */ +/* 3. a_curr is pointer to matrix A, */ +/* updated based on m and panel stride*/ +/* 4. Mask is required for fringe case */ +/* if n_rem=1, mask_n = 0011b, 1real */ +/* and 1complex elements to be */ +/* accessed/stored */ +/* if n_rem=2, mask_n = 1111b, since */ +/* 2real and 2complex elements to be */ +/* accessed/stored */ +/* 5. Perfom A*B */ +/* 6. Store Beta*C + Alpha*A*B in to C */ +/****************************************/ +#define MAIN_LOOP_ROW(M) \ + n_rem = n % 4; \ + if (n_rem == 0) n_rem = 4; \ + ZERO_REGISTERS() \ + b_curr = b; \ + a_curr = a + i * ps_a; \ + mask_n = (1 << (n_rem*2)) - 1; \ + GEMM_MxN(M, n_rem) \ + STORE_ROW(M, n_rem) \ + c += 4 * rs_c; \ + +#define MAIN_LOOP_COL(M) \ + n_rem = n % 4; \ + if (n_rem == 0) n_rem = 4; \ + ZERO_REGISTERS() \ + b_curr = b; \ + a_curr = a + i * ps_a; \ + mask_n = (1 << (n_rem*2)) - 1; \ + GEMM_MxN(M, n_rem) \ + mask_n = (1 << (M*2)) - 1; \ + STORE_COL(M, n_rem) \ + c += 4 * rs_c; \ + +#define MAIN_LOOP_LOWER_DIAG_ROW(M) \ + n_rem = n % 4; \ + if (n_rem == 0) n_rem = 4; \ + ZERO_REGISTERS() \ + b_curr = b; \ + a_curr = a + i * ps_a; \ + mask_n = (1 << (n_rem*2)) - 1; \ + GEMM_MxN(M, n_rem) \ + STORE_ROW_LOWER(M, n_rem) \ + c += 4 * rs_c; \ + +#define MAIN_LOOP_LOWER_DIAG_COL(M) \ + n_rem = n % 4; \ + if (n_rem == 0) n_rem = 4; \ + ZERO_REGISTERS() \ + b_curr = b; \ + a_curr = a + i * ps_a; \ + mask_n = (1 << (n_rem*2)) - 1; \ + GEMM_MxN(M, n_rem) \ + STORE_COL_LOWER(M, n_rem) \ + c += 4 * rs_c; \ + +#define MAIN_LOOP_UPPER_DIAG_ROW(M) \ + n_rem = n % 4; \ + if (n_rem == 0) n_rem = 4; \ + ZERO_REGISTERS() \ + b_curr = b; \ + a_curr = a + i * ps_a; \ + mask_n = (1 << (n_rem*2)) - 1; \ + GEMM_MxN(M, n_rem) \ + STORE_ROW_UPPER(M, n_rem) \ + c += 4 * rs_c; \ + +#define MAIN_LOOP_UPPER_DIAG_COL(M) \ + n_rem = n % 4; \ + if (n_rem == 0) n_rem = 4; \ + ZERO_REGISTERS() \ + b_curr = b; \ + a_curr = a + i * ps_a; \ + mask_n = (1 << (n_rem*2)) - 1; \ + GEMM_MxN(M, n_rem) \ + STORE_COL_UPPER(M, n_rem) \ + c += 4 * rs_c; \ + +/****************************************/ +/* Perform GEMMT operations */ +/* C matrix is row major matrix */ +/* Kernel size is 4x4 */ +/* For fringe cases, mask load/store */ +/* instruction is used */ +/****************************************/ +void bli_zgemmsup_rv_zen4_asm_4x4m_row + ( + conj_t conja, + conj_t conjb, + dim_t m, + dim_t n, + dim_t k, + dcomplex* restrict alpha, + dcomplex* restrict a, inc_t rs_a, inc_t cs_a, + dcomplex* restrict b, inc_t rs_b, inc_t cs_b, + dcomplex* restrict beta, + dcomplex* restrict c_, inc_t rs_c, inc_t cs_c, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t ps_a = bli_auxinfo_ps_a( data ); + __m512d c_reg[4]; + __m512d c_imag_reg[4]; + __m512d a_reg[4]; + __m512d b_reg; + __m512d one_reg = _mm512_set1_pd(1); + __mmask8 mask_n; + dim_t n_rem; + dim_t m_main = m / 4; + dim_t m_rem = m % 4; + dcomplex *a_curr, *b_curr, *c = c_; + + /*Load real and complex value of alpha*/ + __m512d alpha_reg = _mm512_set1_pd(alpha->real); + __m512d alpha_imag_reg = _mm512_set1_pd(alpha->imag); + + /*Load real and complex value of beta*/ + __m512d beta_reg = _mm512_set1_pd(beta->real); + __m512d beta_imag_reg = _mm512_set1_pd(beta->imag); + + dim_t i =0; + + /*4x4 block is handled here*/ + for (i = 0; i < m_main; i++) + { + MAIN_LOOP_ROW(4); + } + + /*Fringe blocks are handled here*/ + switch (m_rem) + { + case 1: + MAIN_LOOP_ROW(1); break; + case 2: + MAIN_LOOP_ROW(2); break; + case 3: + MAIN_LOOP_ROW(3); break; + } + +} + +/****************************************/ +/* Perform GEMMT operations */ +/* C matrix is col major matrix */ +/* Kernel size is 4x4 */ +/* For fringe cases, mask load/store */ +/* instruction is used */ +/****************************************/ +void bli_zgemmsup_rv_zen4_asm_4x4m_col + ( + conj_t conja, + conj_t conjb, + dim_t m, + dim_t n, + dim_t k, + dcomplex* restrict alpha, + dcomplex* restrict a, inc_t rs_a, inc_t cs_a, + dcomplex* restrict b, inc_t rs_b, inc_t cs_b, + dcomplex* restrict beta, + dcomplex* restrict c_, inc_t rs_c, inc_t cs_c, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t ps_a = bli_auxinfo_ps_a( data ); + __m512d c_reg[4]; + __m512d c_imag_reg[4]; + __m512d a_reg[4]; + __m512d b_reg; + __m512d one_reg = _mm512_set1_pd(1); + __mmask8 mask_n; + dim_t n_rem; + dim_t m_main = m / 4; + dim_t m_rem = m % 4; + dcomplex *a_curr, *b_curr, *c = c_; + + /*Load real and complex value of alpha*/ + __m512d alpha_reg = _mm512_set1_pd(alpha->real); + __m512d alpha_imag_reg = _mm512_set1_pd(alpha->imag); + + /*Load real and complex value of beta*/ + __m512d beta_reg = _mm512_set1_pd(beta->real); + __m512d beta_imag_reg = _mm512_set1_pd(beta->imag); + + dim_t i =0; + /*4x4 block is handled here*/ + for (i = 0; i < m_main; i++) + { + MAIN_LOOP_COL(4); + } + + /*Fringe blocks are handled here*/ + switch (m_rem) + { + case 1: + MAIN_LOOP_COL(1); break; + case 2: + MAIN_LOOP_COL(2); break; + case 3: + MAIN_LOOP_COL(3); break; + } + +} + +void bli_zgemmsup_rv_zen4_asm_4x4m + ( + conj_t conja, + conj_t conjb, + dim_t m, + dim_t n, + dim_t k, + dcomplex* restrict alpha, + dcomplex* restrict a, inc_t rs_a, inc_t cs_a, + dcomplex* restrict b, inc_t rs_b, inc_t cs_b, + dcomplex* restrict beta, + dcomplex* restrict c_, inc_t rs_c, inc_t cs_c, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + /* C is row stored*/ + if (cs_c == 1) { + bli_zgemmsup_rv_zen4_asm_4x4m_row + ( + conja, + conjb, + m, + n, + k, + alpha, + a, rs_a, cs_a, + b, rs_b, cs_b, + beta, + c_, rs_c, cs_c, + data, + cntx ); + }else{ + /* C is col stored*/ + bli_zgemmsup_rv_zen4_asm_4x4m_col + ( + conja, + conjb, + m, + n, + k, + alpha, + a, rs_a, cs_a, + b, rs_b, cs_b, + beta, + c_, rs_c, cs_c, + data, + cntx ); + } +} + +/****************************************/ +/* Perform GEMMT operations */ +/* C matrix is row major matrix */ +/* Only lower portion below diagonal */ +/* elements are updated */ +/* Kernel size is 4x4 */ +/* For fringe cases, mask load/store */ +/* instruction is used */ +/****************************************/ +void bli_zgemmsup_rv_zen4_asm_4x4m_lower_row + ( + conj_t conja, + conj_t conjb, + dim_t m, + dim_t n, + dim_t k, + dcomplex* restrict alpha, + dcomplex* restrict a, inc_t rs_a, inc_t cs_a, + dcomplex* restrict b, inc_t rs_b, inc_t cs_b, + dcomplex* restrict beta, + dcomplex* restrict c_, inc_t rs_c, inc_t cs_c, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t ps_a = bli_auxinfo_ps_a( data ); + __m512d c_reg[4]; + __m512d c_imag_reg[4]; + __m512d a_reg[4]; + __m512d b_reg; + __m512d one_reg = _mm512_set1_pd(1); + __mmask8 mask_n; + dim_t n_rem; + dim_t m_main = m / 4; + dim_t m_rem = m % 4; + dcomplex *a_curr,*b_curr, *c = c_; + + /*Load real and complex value of alpha*/ + __m512d alpha_reg = _mm512_set1_pd(alpha->real); + __m512d alpha_imag_reg = _mm512_set1_pd(alpha->imag); + + /*Load real and complex value of beta*/ + __m512d beta_reg = _mm512_set1_pd(beta->real); + __m512d beta_imag_reg = _mm512_set1_pd(beta->imag); + + dim_t i = 0; + /*4x4 block is handled here*/ + for (i = 0; i < m_main; i++) + { + MAIN_LOOP_LOWER_DIAG_ROW(4); + } + + /*Fringe blocks are handled here*/ + switch (m_rem) + { + case 1: + MAIN_LOOP_LOWER_DIAG_ROW(1); break; + case 2: + MAIN_LOOP_LOWER_DIAG_ROW(2); break; + case 3: + MAIN_LOOP_LOWER_DIAG_ROW(3); break; + } +} + +/****************************************/ +/* Perform GEMMT operations */ +/* C matrix is col major matrix */ +/* Only lower portion below diagonal */ +/* elements are updated */ +/* Kernel size is 4x4 */ +/* For fringe cases, mask load/store */ +/* instruction is used */ +/****************************************/ +void bli_zgemmsup_rv_zen4_asm_4x4m_lower_col + ( + conj_t conja, + conj_t conjb, + dim_t m, + dim_t n, + dim_t k, + dcomplex* restrict alpha, + dcomplex* restrict a, inc_t rs_a, inc_t cs_a, + dcomplex* restrict b, inc_t rs_b, inc_t cs_b, + dcomplex* restrict beta, + dcomplex* restrict c_, inc_t rs_c, inc_t cs_c, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t ps_a = bli_auxinfo_ps_a( data ); + __m512d c_reg[4]; + __m512d c_imag_reg[4]; + __m512d a_reg[4]; + __m512d b_reg; + __m512d one_reg = _mm512_set1_pd(1); + __mmask8 mask_n; + dim_t n_rem; + dim_t m_main = m / 4; + dim_t m_rem = m % 4; + dcomplex *a_curr,*b_curr, *c = c_; + + /*Load real and complex value of alpha*/ + __m512d alpha_reg = _mm512_set1_pd(alpha->real); + __m512d alpha_imag_reg = _mm512_set1_pd(alpha->imag); + + /*Load real and complex value of beta*/ + __m512d beta_reg = _mm512_set1_pd(beta->real); + __m512d beta_imag_reg = _mm512_set1_pd(beta->imag); + + dim_t i = 0; + /*4x4 block is handled here*/ + for (i = 0; i < m_main; i++) + { + MAIN_LOOP_LOWER_DIAG_COL(4); + } + + /*Fringe blocks are handled here*/ + switch (m_rem) + { + case 1: + MAIN_LOOP_LOWER_DIAG_COL(1); break; + case 2: + MAIN_LOOP_LOWER_DIAG_COL(2); break; + case 3: + MAIN_LOOP_LOWER_DIAG_COL(3); break; + } +} + +void bli_zgemmsup_rv_zen4_asm_4x4m_lower + ( + conj_t conja, + conj_t conjb, + dim_t m, + dim_t n, + dim_t k, + dcomplex* restrict alpha, + dcomplex* restrict a, inc_t rs_a, inc_t cs_a, + dcomplex* restrict b, inc_t rs_b, inc_t cs_b, + dcomplex* restrict beta, + dcomplex* restrict c_, inc_t rs_c, inc_t cs_c, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + /* C is row stored*/ + if (cs_c == 1) { + bli_zgemmsup_rv_zen4_asm_4x4m_lower_row + ( + conja, + conjb, + m, + n, + k, + alpha, + a, rs_a, cs_a, + b, rs_b, cs_b, + beta, + c_, rs_c, cs_c, + data, + cntx ); + }else{ + /* C is col stored*/ + bli_zgemmsup_rv_zen4_asm_4x4m_lower_col + ( + conja, + conjb, + m, + n, + k, + alpha, + a, rs_a, cs_a, + b, rs_b, cs_b, + beta, + c_, rs_c, cs_c, + data, + cntx ); + } +} + +/****************************************/ +/* Perform GEMMT operations */ +/* C matrix is row major matrix */ +/* Only upper portion above diagonal */ +/* elements are updated */ +/* Kernel size is 4x4 */ +/* For fringe cases, mask load/store */ +/* instruction is used */ +/****************************************/ +void bli_zgemmsup_rv_zen4_asm_4x4m_upper_row + ( + conj_t conja, + conj_t conjb, + dim_t m, + dim_t n, + dim_t k, + dcomplex* restrict alpha, + dcomplex* restrict a, inc_t rs_a, inc_t cs_a, + dcomplex* restrict b, inc_t rs_b, inc_t cs_b, + dcomplex* restrict beta, + dcomplex* restrict c_, inc_t rs_c, inc_t cs_c, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t ps_a = bli_auxinfo_ps_a( data ); + __m512d c_reg[4]; + __m512d c_imag_reg[4]; + __m512d a_reg[4]; + __m512d b_reg; + __m512d one_reg = _mm512_set1_pd(1); + __mmask8 mask_n; + dim_t n_rem; + dim_t m_main = m / 4; + dim_t m_rem = m % 4; + dcomplex *a_curr, *b_curr, *c = c_; + + /*Load real and complex value of alpha*/ + __m512d alpha_reg = _mm512_set1_pd(alpha->real); + __m512d alpha_imag_reg = _mm512_set1_pd(alpha->imag); + + /*Load real and complex value of beta*/ + __m512d beta_reg = _mm512_set1_pd(beta->real); + __m512d beta_imag_reg = _mm512_set1_pd(beta->imag); + + dim_t i = 0; + /*4x4 block is handled here*/ + for (i = 0; i < m_main; i++) + { + MAIN_LOOP_UPPER_DIAG_ROW(4); + } + + /*Fringe blocks are handled here*/ + switch (m_rem) + { + case 1: + MAIN_LOOP_UPPER_DIAG_ROW(1); break; + case 2: + MAIN_LOOP_UPPER_DIAG_ROW(2); break; + case 3: + MAIN_LOOP_UPPER_DIAG_ROW(3); break; + } +} + +/****************************************/ +/* Perform GEMMT operations */ +/* C matrix is col major matrix */ +/* Only upper portion above diagonal */ +/* elements are updated */ +/* Kernel size is 4x4 */ +/* For fringe cases, mask load/store */ +/* instruction is used */ +/****************************************/ +void bli_zgemmsup_rv_zen4_asm_4x4m_upper_col + ( + conj_t conja, + conj_t conjb, + dim_t m, + dim_t n, + dim_t k, + dcomplex* restrict alpha, + dcomplex* restrict a, inc_t rs_a, inc_t cs_a, + dcomplex* restrict b, inc_t rs_b, inc_t cs_b, + dcomplex* restrict beta, + dcomplex* restrict c_, inc_t rs_c, inc_t cs_c, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t ps_a = bli_auxinfo_ps_a( data ); + __m512d c_reg[4]; + __m512d c_imag_reg[4]; + __m512d a_reg[4]; + __m512d b_reg; + __m512d one_reg = _mm512_set1_pd(1); + __mmask8 mask_n; + dim_t n_rem; + dim_t m_main = m / 4; + dim_t m_rem = m % 4; + dcomplex *a_curr, *b_curr, *c = c_; + + /*Load real and complex value of alpha*/ + __m512d alpha_reg = _mm512_set1_pd(alpha->real); + __m512d alpha_imag_reg = _mm512_set1_pd(alpha->imag); + + /*Load real and complex value of beta*/ + __m512d beta_reg = _mm512_set1_pd(beta->real); + __m512d beta_imag_reg = _mm512_set1_pd(beta->imag); + + dim_t i = 0; + /*4x4 block is handled here*/ + for (i = 0; i < m_main; i++) + { + MAIN_LOOP_UPPER_DIAG_COL(4); + } + + /*Fringe blocks are handled here*/ + switch (m_rem) + { + case 1: + MAIN_LOOP_UPPER_DIAG_COL(1); break; + case 2: + MAIN_LOOP_UPPER_DIAG_COL(2); break; + case 3: + MAIN_LOOP_UPPER_DIAG_COL(3); break; + } +} + +void bli_zgemmsup_rv_zen4_asm_4x4m_upper + ( + conj_t conja, + conj_t conjb, + dim_t m, + dim_t n, + dim_t k, + dcomplex* restrict alpha, + dcomplex* restrict a, inc_t rs_a, inc_t cs_a, + dcomplex* restrict b, inc_t rs_b, inc_t cs_b, + dcomplex* restrict beta, + dcomplex* restrict c_, inc_t rs_c, inc_t cs_c, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + /* C is row stored*/ + if (cs_c == 1) { + bli_zgemmsup_rv_zen4_asm_4x4m_upper_row + ( + conja, + conjb, + m, + n, + k, + alpha, + a, rs_a, cs_a, + b, rs_b, cs_b, + beta, + c_, rs_c, cs_c, + data, + cntx ); + }else{ + /* C is col stored*/ + bli_zgemmsup_rv_zen4_asm_4x4m_upper_col + ( + conja, + conjb, + m, + n, + k, + alpha, + a, rs_a, cs_a, + b, rs_b, cs_b, + beta, + c_, rs_c, cs_c, + data, + cntx ); + } +} diff --git a/kernels/zen4/3/sup/bli_gemmsup_rv_zen_s6x64.c b/kernels/zen4/3/sup/bli_gemmsup_rv_zen_s6x64.c index a69d016b38..0fd2e7b034 100644 --- a/kernels/zen4/3/sup/bli_gemmsup_rv_zen_s6x64.c +++ b/kernels/zen4/3/sup/bli_gemmsup_rv_zen_s6x64.c @@ -1,9 +1,10 @@ /* + BLIS An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -28,6 +29,7 @@ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ #include "blis.h" diff --git a/kernels/zen4/3/sup/bli_gemmsup_rv_zen_s6x64.h b/kernels/zen4/3/sup/bli_gemmsup_rv_zen_s6x64.h index ae5023c400..6d7ff47d10 100644 --- a/kernels/zen4/3/sup/bli_gemmsup_rv_zen_s6x64.h +++ b/kernels/zen4/3/sup/bli_gemmsup_rv_zen_s6x64.h @@ -1,9 +1,10 @@ /* + BLIS An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -28,6 +29,7 @@ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ #define INIT_REG \ diff --git a/kernels/zen4/3/sup/bli_gemmsup_rv_zen_s6x64m.c b/kernels/zen4/3/sup/bli_gemmsup_rv_zen_s6x64m.c index 2e55b698ca..8e660a534e 100644 --- a/kernels/zen4/3/sup/bli_gemmsup_rv_zen_s6x64m.c +++ b/kernels/zen4/3/sup/bli_gemmsup_rv_zen_s6x64m.c @@ -1,9 +1,10 @@ /* + BLIS An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -28,6 +29,7 @@ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ #include "blis.h" diff --git a/kernels/zen4/3/sup/bli_gemmsup_rv_zen_s6x64n.c b/kernels/zen4/3/sup/bli_gemmsup_rv_zen_s6x64n.c index 08204eef20..8226d18ca7 100644 --- a/kernels/zen4/3/sup/bli_gemmsup_rv_zen_s6x64n.c +++ b/kernels/zen4/3/sup/bli_gemmsup_rv_zen_s6x64n.c @@ -1,9 +1,10 @@ /* + BLIS An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -28,6 +29,7 @@ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ #include "blis.h" diff --git a/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx1.c b/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx1.c index 690404628e..d60dee1cb0 100644 --- a/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx1.c +++ b/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx1.c @@ -4,32 +4,32 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer. - Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. - documentation and/or other materials provided with the distribution. - Neither the name(s) of the copyright holder(s) nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - AS IS AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY - OF TEXAS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, - EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, - PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR - PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY - OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ @@ -419,6 +419,8 @@ void bli_dgemmsup_rv_zen4_asm_24x1 uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + uint64_t m = (uint64_t)m0; + uint64_t ps_a = bli_auxinfo_ps_a( data ); uint64_t ps_a8 = ps_a * sizeof( double ); @@ -901,7 +903,7 @@ void bli_dgemmsup_rv_zen4_asm_24x1 SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) - mov(var(m0), rdi) + mov(var(m), rdi) sub(imm(16), rdi) cmp(imm(8), rdi) JZ(.UPDATE8) @@ -1012,7 +1014,7 @@ void bli_dgemmsup_rv_zen4_asm_24x1 SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) - mov(var(m0), rdi) + mov(var(m), rdi) sub(imm(16), rdi) cmp(imm(8), rdi) JZ(.UPDATE8BZ) @@ -1088,8 +1090,7 @@ void bli_dgemmsup_rv_zen4_asm_24x1 [c] "m" (c), [rs_c] "m" (rs_c), [cs_c] "m" (cs_c), - [n0] "m" (n0), - [m0] "m" (m0), + [m] "m" (m), [mask] "m" (mask), [mask_n0] "m" (mask_n0) : // register clobber list @@ -1135,6 +1136,8 @@ void bli_dgemmsup_rv_zen4_asm_16x1 uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + uint64_t m = m0; + uint64_t ps_a = bli_auxinfo_ps_a( data ); uint64_t ps_a8 = ps_a * sizeof( double ); @@ -1547,7 +1550,7 @@ void bli_dgemmsup_rv_zen4_asm_16x1 SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) - mov(var(m0), rdi) + mov(var(m), rdi) sub(imm(8), rdi) cmp(imm(8), rdi) JZ(.UPDATE8) @@ -1645,7 +1648,7 @@ void bli_dgemmsup_rv_zen4_asm_16x1 SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) - mov(var(m0), rdi) + mov(var(m), rdi) sub(imm(8), rdi) cmp(imm(8), rdi) JZ(.UPDATE8BZ) @@ -1721,8 +1724,7 @@ void bli_dgemmsup_rv_zen4_asm_16x1 [c] "m" (c), [rs_c] "m" (rs_c), [cs_c] "m" (cs_c), - [n0] "m" (n0), - [m0] "m" (m0), + [m] "m" (m), [mask] "m" (mask), [mask_n0] "m" (mask_n0) : // register clobber list @@ -1768,6 +1770,8 @@ void bli_dgemmsup_rv_zen4_asm_8x1 uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + uint64_t m = m0; + uint64_t ps_a = bli_auxinfo_ps_a( data ); uint64_t ps_a8 = ps_a * sizeof( double ); @@ -2111,7 +2115,7 @@ void bli_dgemmsup_rv_zen4_asm_8x1 vbroadcastsd(mem(rax), zmm31) - mov(var(m0), rdi) + mov(var(m), rdi) cmp(imm(8), rdi) JZ(.UPDATE8) cmp(imm(7), rdi) @@ -2195,7 +2199,7 @@ void bli_dgemmsup_rv_zen4_asm_8x1 SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) - mov(var(m0), rdi) + mov(var(m), rdi) cmp(imm(8), rdi) JZ(.UPDATE8BZ) cmp(imm(7), rdi) @@ -2270,8 +2274,7 @@ void bli_dgemmsup_rv_zen4_asm_8x1 [c] "m" (c), [rs_c] "m" (rs_c), [cs_c] "m" (cs_c), - [n0] "m" (n0), - [m0] "m" (m0), + [m] "m" (m), [mask] "m" (mask), [mask_n0] "m" (mask_n0) : // register clobber list diff --git a/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx2.c b/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx2.c index 67a58c1b82..5130333f73 100644 --- a/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx2.c +++ b/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx2.c @@ -4,32 +4,32 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer. - Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. - documentation and/or other materials provided with the distribution. - Neither the name(s) of the copyright holder(s) nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - AS IS AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY - OF TEXAS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, - EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, - PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR - PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY - OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ @@ -419,6 +419,8 @@ void bli_dgemmsup_rv_zen4_asm_24x2 uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + uint64_t m = m0; + uint64_t ps_a = bli_auxinfo_ps_a( data ); uint64_t ps_a8 = ps_a * sizeof( double ); @@ -1027,7 +1029,7 @@ void bli_dgemmsup_rv_zen4_asm_24x2 SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) - mov(var(m0), rdi) + mov(var(m), rdi) sub(imm(16), rdi) cmp(imm(8), rdi) JZ(.UPDATE8) @@ -1141,7 +1143,7 @@ void bli_dgemmsup_rv_zen4_asm_24x2 SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) - mov(var(m0), rdi) + mov(var(m), rdi) sub(imm(16), rdi) cmp(imm(8), rdi) JZ(.UPDATE8BZ) @@ -1217,8 +1219,7 @@ void bli_dgemmsup_rv_zen4_asm_24x2 [c] "m" (c), [rs_c] "m" (rs_c), [cs_c] "m" (cs_c), - [n0] "m" (n0), - [m0] "m" (m0), + [m] "m" (m), [mask] "m" (mask), [mask_n0] "m" (mask_n0) : // register clobber list @@ -1264,6 +1265,8 @@ void bli_dgemmsup_rv_zen4_asm_16x2 uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + uint64_t m = m0; + uint64_t ps_a = bli_auxinfo_ps_a( data ); uint64_t ps_a8 = ps_a * sizeof( double ); @@ -1770,7 +1773,7 @@ void bli_dgemmsup_rv_zen4_asm_16x2 SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) - mov(var(m0), rdi) + mov(var(m), rdi) sub(imm(8), rdi) cmp(imm(8), rdi) JZ(.UPDATE8) @@ -1870,7 +1873,7 @@ void bli_dgemmsup_rv_zen4_asm_16x2 SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) - mov(var(m0), rdi) + mov(var(m), rdi) sub(imm(8), rdi) cmp(imm(8), rdi) JZ(.UPDATE8BZ) @@ -1946,8 +1949,7 @@ void bli_dgemmsup_rv_zen4_asm_16x2 [c] "m" (c), [rs_c] "m" (rs_c), [cs_c] "m" (cs_c), - [n0] "m" (n0), - [m0] "m" (m0), + [m] "m" (m), [mask] "m" (mask), [mask_n0] "m" (mask_n0) : // register clobber list @@ -1993,6 +1995,8 @@ void bli_dgemmsup_rv_zen4_asm_8x2 uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + uint64_t m = m0; + uint64_t ps_a = bli_auxinfo_ps_a( data ); uint64_t ps_a8 = ps_a * sizeof( double ); @@ -2396,7 +2400,7 @@ void bli_dgemmsup_rv_zen4_asm_8x2 vbroadcastsd(mem(rax), zmm31) - mov(var(m0), rdi) + mov(var(m), rdi) cmp(imm(8), rdi) JZ(.UPDATE8) cmp(imm(7), rdi) @@ -2481,7 +2485,7 @@ void bli_dgemmsup_rv_zen4_asm_8x2 SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) - mov(var(m0), rdi) + mov(var(m), rdi) cmp(imm(8), rdi) JZ(.UPDATE8BZ) cmp(imm(7), rdi) @@ -2556,8 +2560,7 @@ void bli_dgemmsup_rv_zen4_asm_8x2 [c] "m" (c), [rs_c] "m" (rs_c), [cs_c] "m" (cs_c), - [n0] "m" (n0), - [m0] "m" (m0), + [m] "m" (m), [mask] "m" (mask), [mask_n0] "m" (mask_n0) : // register clobber list diff --git a/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx3.c b/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx3.c index ee6c3c573d..b2a66bc23f 100644 --- a/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx3.c +++ b/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx3.c @@ -4,32 +4,32 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer. - Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. - documentation and/or other materials provided with the distribution. - Neither the name(s) of the copyright holder(s) nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - AS IS AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY - OF TEXAS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, - EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, - PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR - PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY - OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ @@ -419,6 +419,8 @@ void bli_dgemmsup_rv_zen4_asm_24x3 uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + uint64_t m = m0; + uint64_t ps_a = bli_auxinfo_ps_a( data ); uint64_t ps_a8 = ps_a * sizeof( double ); @@ -1149,7 +1151,7 @@ void bli_dgemmsup_rv_zen4_asm_24x3 SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) - mov(var(m0), rdi) + mov(var(m), rdi) sub(imm(16), rdi) cmp(imm(8), rdi) JZ(.UPDATE8) @@ -1263,7 +1265,7 @@ void bli_dgemmsup_rv_zen4_asm_24x3 SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) - mov(var(m0), rdi) + mov(var(m), rdi) sub(imm(16), rdi) cmp(imm(8), rdi) JZ(.UPDATE8BZ) @@ -1339,8 +1341,7 @@ void bli_dgemmsup_rv_zen4_asm_24x3 [c] "m" (c), [rs_c] "m" (rs_c), [cs_c] "m" (cs_c), - [n0] "m" (n0), - [m0] "m" (m0), + [m] "m" (m), [mask] "m" (mask), [mask_n0] "m" (mask_n0) : // register clobber list @@ -1386,6 +1387,8 @@ void bli_dgemmsup_rv_zen4_asm_16x3 uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + uint64_t m = m0; + uint64_t ps_a = bli_auxinfo_ps_a( data ); uint64_t ps_a8 = ps_a * sizeof( double ); @@ -1983,7 +1986,7 @@ void bli_dgemmsup_rv_zen4_asm_16x3 SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) - mov(var(m0), rdi) + mov(var(m), rdi) sub(imm(8), rdi) cmp(imm(8), rdi) JZ(.UPDATE8) @@ -2083,7 +2086,7 @@ void bli_dgemmsup_rv_zen4_asm_16x3 SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) - mov(var(m0), rdi) + mov(var(m), rdi) sub(imm(8), rdi) cmp(imm(8), rdi) JZ(.UPDATE8BZ) @@ -2159,8 +2162,7 @@ void bli_dgemmsup_rv_zen4_asm_16x3 [c] "m" (c), [rs_c] "m" (rs_c), [cs_c] "m" (cs_c), - [n0] "m" (n0), - [m0] "m" (m0), + [m] "m" (m), [mask] "m" (mask), [mask_n0] "m" (mask_n0) : // register clobber list @@ -2206,6 +2208,8 @@ void bli_dgemmsup_rv_zen4_asm_8x3 uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + uint64_t m = m0; + uint64_t ps_a = bli_auxinfo_ps_a( data ); uint64_t ps_a8 = ps_a * sizeof( double ); @@ -2670,7 +2674,7 @@ void bli_dgemmsup_rv_zen4_asm_8x3 vbroadcastsd(mem(rax), zmm31) - mov(var(m0), rdi) + mov(var(m), rdi) cmp(imm(8), rdi) JZ(.UPDATE8) cmp(imm(7), rdi) @@ -2755,7 +2759,7 @@ void bli_dgemmsup_rv_zen4_asm_8x3 SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) - mov(var(m0), rdi) + mov(var(m), rdi) cmp(imm(8), rdi) JZ(.UPDATE8BZ) cmp(imm(7), rdi) @@ -2831,8 +2835,7 @@ void bli_dgemmsup_rv_zen4_asm_8x3 [c] "m" (c), [rs_c] "m" (rs_c), [cs_c] "m" (cs_c), - [n0] "m" (n0), - [m0] "m" (m0), + [m] "m" (m), [mask] "m" (mask), [mask_n0] "m" (mask_n0) : // register clobber list diff --git a/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx4.c b/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx4.c index f8a3968f7b..790f92fb28 100644 --- a/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx4.c +++ b/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx4.c @@ -4,32 +4,32 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer. - Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. - documentation and/or other materials provided with the distribution. - Neither the name(s) of the copyright holder(s) nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - AS IS AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY - OF TEXAS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, - EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, - PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR - PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY - OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ @@ -419,6 +419,8 @@ void bli_dgemmsup_rv_zen4_asm_24x4 uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + uint64_t m = m0; + uint64_t ps_a = bli_auxinfo_ps_a( data ); uint64_t ps_a8 = ps_a * sizeof( double ); @@ -1232,7 +1234,7 @@ void bli_dgemmsup_rv_zen4_asm_24x4 SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) - mov(var(m0), rdi) + mov(var(m), rdi) sub(imm(16), rdi) cmp(imm(8), rdi) JZ(.UPDATE8) @@ -1346,7 +1348,7 @@ void bli_dgemmsup_rv_zen4_asm_24x4 SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) - mov(var(m0), rdi) + mov(var(m), rdi) sub(imm(16), rdi) cmp(imm(8), rdi) JZ(.UPDATE8BZ) @@ -1422,8 +1424,7 @@ void bli_dgemmsup_rv_zen4_asm_24x4 [c] "m" (c), [rs_c] "m" (rs_c), [cs_c] "m" (cs_c), - [n0] "m" (n0), - [m0] "m" (m0), + [m] "m" (m), [mask] "m" (mask), [mask_n0] "m" (mask_n0) : // register clobber list @@ -1469,6 +1470,8 @@ void bli_dgemmsup_rv_zen4_asm_16x4 uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + uint64_t m = m0; + uint64_t ps_a = bli_auxinfo_ps_a( data ); uint64_t ps_a8 = ps_a * sizeof( double ); @@ -2161,7 +2164,7 @@ void bli_dgemmsup_rv_zen4_asm_16x4 SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) - mov(var(m0), rdi) + mov(var(m), rdi) sub(imm(8), rdi) cmp(imm(8), rdi) JZ(.UPDATE8) @@ -2263,7 +2266,7 @@ void bli_dgemmsup_rv_zen4_asm_16x4 SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) - mov(var(m0), rdi) + mov(var(m), rdi) sub(imm(8), rdi) cmp(imm(8), rdi) JZ(.UPDATE8BZ) @@ -2339,8 +2342,7 @@ void bli_dgemmsup_rv_zen4_asm_16x4 [c] "m" (c), [rs_c] "m" (rs_c), [cs_c] "m" (cs_c), - [n0] "m" (n0), - [m0] "m" (m0), + [m] "m" (m), [mask] "m" (mask), [mask_n0] "m" (mask_n0) : // register clobber list @@ -2386,6 +2388,8 @@ void bli_dgemmsup_rv_zen4_asm_8x4 uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + uint64_t m = m0; + uint64_t ps_a = bli_auxinfo_ps_a( data ); uint64_t ps_a8 = ps_a * sizeof( double ); @@ -2912,7 +2916,7 @@ void bli_dgemmsup_rv_zen4_asm_8x4 vbroadcastsd(mem(rax), zmm31) - mov(var(m0), rdi) + mov(var(m), rdi) cmp(imm(8), rdi) JZ(.UPDATE8) cmp(imm(7), rdi) @@ -2998,7 +3002,7 @@ void bli_dgemmsup_rv_zen4_asm_8x4 SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) - mov(var(m0), rdi) + mov(var(m), rdi) cmp(imm(8), rdi) JZ(.UPDATE8BZ) cmp(imm(7), rdi) @@ -3073,8 +3077,7 @@ void bli_dgemmsup_rv_zen4_asm_8x4 [c] "m" (c), [rs_c] "m" (rs_c), [cs_c] "m" (cs_c), - [n0] "m" (n0), - [m0] "m" (m0), + [m] "m" (m), [mask] "m" (mask), [mask_n0] "m" (mask_n0) : // register clobber list diff --git a/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx5.c b/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx5.c index d014358c84..7653e088ea 100644 --- a/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx5.c +++ b/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx5.c @@ -4,32 +4,32 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer. - Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. - documentation and/or other materials provided with the distribution. - Neither the name(s) of the copyright holder(s) nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - AS IS AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY - OF TEXAS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, - EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, - PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR - PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY - OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ @@ -419,6 +419,8 @@ void bli_dgemmsup_rv_zen4_asm_24x5 uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + uint64_t m = m0; + uint64_t ps_a = bli_auxinfo_ps_a( data ); uint64_t ps_a8 = ps_a * sizeof( double ); @@ -1391,7 +1393,7 @@ void bli_dgemmsup_rv_zen4_asm_24x5 SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) - mov(var(m0), rdi) + mov(var(m), rdi) sub(imm(16), rdi) cmp(imm(8), rdi) @@ -1515,7 +1517,7 @@ void bli_dgemmsup_rv_zen4_asm_24x5 SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) - mov(var(m0), rdi) + mov(var(m), rdi) sub(imm(16), rdi) cmp(imm(8), rdi) @@ -1592,8 +1594,7 @@ void bli_dgemmsup_rv_zen4_asm_24x5 [c] "m" (c), [rs_c] "m" (rs_c), [cs_c] "m" (cs_c), - [n0] "m" (n0), - [m0] "m" (m0), + [m] "m" (m), [mask] "m" (mask), [mask_n0] "m" (mask_n0) : // register clobber list @@ -1639,6 +1640,8 @@ void bli_dgemmsup_rv_zen4_asm_16x5 uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + uint64_t m = m0; + uint64_t ps_a = bli_auxinfo_ps_a( data ); uint64_t ps_a8 = ps_a * sizeof( double ); @@ -2462,7 +2465,7 @@ void bli_dgemmsup_rv_zen4_asm_16x5 SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) - mov(var(m0), rdi) + mov(var(m), rdi) sub(imm(8), rdi) cmp(imm(8), rdi) JZ(.UPDATE8) @@ -2570,7 +2573,7 @@ void bli_dgemmsup_rv_zen4_asm_16x5 SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) - mov(var(m0), rdi) + mov(var(m), rdi) sub(imm(8), rdi) cmp(imm(8), rdi) JZ(.UPDATE8BZ) @@ -2646,8 +2649,7 @@ void bli_dgemmsup_rv_zen4_asm_16x5 [c] "m" (c), [rs_c] "m" (rs_c), [cs_c] "m" (cs_c), - [n0] "m" (n0), - [m0] "m" (m0), + [m] "m" (m), [mask] "m" (mask), [mask_n0] "m" (mask_n0) : // register clobber list @@ -2693,6 +2695,8 @@ void bli_dgemmsup_rv_zen4_asm_8x5 uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + uint64_t m = m0; + uint64_t ps_a = bli_auxinfo_ps_a( data ); uint64_t ps_a8 = ps_a * sizeof( double ); @@ -3316,7 +3320,7 @@ void bli_dgemmsup_rv_zen4_asm_8x5 vbroadcastsd(mem(rax), zmm31) - mov(var(m0), rdi) + mov(var(m), rdi) cmp(imm(8), rdi) JZ(.UPDATE8) cmp(imm(7), rdi) @@ -3405,7 +3409,7 @@ void bli_dgemmsup_rv_zen4_asm_8x5 SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) - mov(var(m0), rdi) + mov(var(m), rdi) cmp(imm(8), rdi) JZ(.UPDATE8BZ) cmp(imm(7), rdi) @@ -3480,8 +3484,7 @@ void bli_dgemmsup_rv_zen4_asm_8x5 [c] "m" (c), [rs_c] "m" (rs_c), [cs_c] "m" (cs_c), - [n0] "m" (n0), - [m0] "m" (m0), + [m] "m" (m), [mask] "m" (mask), [mask_n0] "m" (mask_n0) : // register clobber list diff --git a/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx6.c b/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx6.c index db9ba7cae2..1578f66896 100644 --- a/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx6.c +++ b/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx6.c @@ -4,32 +4,32 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer. - Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. - documentation and/or other materials provided with the distribution. - Neither the name(s) of the copyright holder(s) nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - AS IS AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY - OF TEXAS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, - EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, - PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR - PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY - OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ @@ -419,6 +419,8 @@ void bli_dgemmsup_rv_zen4_asm_24x6 uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + uint64_t m = m0; + uint64_t ps_a = bli_auxinfo_ps_a( data ); uint64_t ps_a8 = ps_a * sizeof( double ); @@ -1509,7 +1511,7 @@ void bli_dgemmsup_rv_zen4_asm_24x6 SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) - mov(var(m0), rdi) + mov(var(m), rdi) sub(imm(16), rdi) cmp(imm(8), rdi) JZ(.UPDATE8) @@ -1635,7 +1637,7 @@ void bli_dgemmsup_rv_zen4_asm_24x6 SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) - mov(var(m0), rdi) + mov(var(m), rdi) sub(imm(16), rdi) cmp(imm(8), rdi) JZ(.UPDATE8BZ) @@ -1711,8 +1713,7 @@ void bli_dgemmsup_rv_zen4_asm_24x6 [c] "m" (c), [rs_c] "m" (rs_c), [cs_c] "m" (cs_c), - [n0] "m" (n0), - [m0] "m" (m0), + [m] "m" (m), [mask] "m" (mask), [mask_n0] "m" (mask_n0) : // register clobber list @@ -1758,6 +1759,8 @@ void bli_dgemmsup_rv_zen4_asm_16x6 uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + uint64_t m = m0; + uint64_t ps_a = bli_auxinfo_ps_a( data ); uint64_t ps_a8 = ps_a * sizeof( double ); @@ -2674,7 +2677,7 @@ void bli_dgemmsup_rv_zen4_asm_16x6 SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) - mov(var(m0), rdi) + mov(var(m), rdi) sub(imm(8), rdi) cmp(imm(8), rdi) JZ(.UPDATE8) @@ -2784,7 +2787,7 @@ void bli_dgemmsup_rv_zen4_asm_16x6 SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) - mov(var(m0), rdi) + mov(var(m), rdi) sub(imm(8), rdi) cmp(imm(8), rdi) JZ(.UPDATE8BZ) @@ -2860,8 +2863,7 @@ void bli_dgemmsup_rv_zen4_asm_16x6 [c] "m" (c), [rs_c] "m" (rs_c), [cs_c] "m" (cs_c), - [n0] "m" (n0), - [m0] "m" (m0), + [m] "m" (m), [mask] "m" (mask), [mask_n0] "m" (mask_n0) : // register clobber list @@ -2907,6 +2909,8 @@ void bli_dgemmsup_rv_zen4_asm_8x6 uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + uint64_t m = m0; + uint64_t ps_a = bli_auxinfo_ps_a( data ); uint64_t ps_a8 = ps_a * sizeof( double ); @@ -3590,7 +3594,7 @@ void bli_dgemmsup_rv_zen4_asm_8x6 vbroadcastsd(mem(rax), zmm31) - mov(var(m0), rdi) + mov(var(m), rdi) cmp(imm(8), rdi) JZ(.UPDATE8) cmp(imm(7), rdi) @@ -3680,7 +3684,7 @@ void bli_dgemmsup_rv_zen4_asm_8x6 SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) - mov(var(m0), rdi) + mov(var(m), rdi) cmp(imm(8), rdi) JZ(.UPDATE8BZ) cmp(imm(7), rdi) @@ -3755,8 +3759,7 @@ void bli_dgemmsup_rv_zen4_asm_8x6 [c] "m" (c), [rs_c] "m" (rs_c), [cs_c] "m" (cs_c), - [n0] "m" (n0), - [m0] "m" (m0), + [m] "m" (m), [mask] "m" (mask), [mask_n0] "m" (mask_n0) : // register clobber list diff --git a/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx7.c b/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx7.c index 9e4194c118..f5e25a8693 100644 --- a/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx7.c +++ b/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx7.c @@ -4,32 +4,32 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer. - Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. - documentation and/or other materials provided with the distribution. - Neither the name(s) of the copyright holder(s) nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - AS IS AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY - OF TEXAS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, - EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, - PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR - PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY - OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ @@ -419,6 +419,8 @@ void bli_dgemmsup_rv_zen4_asm_24x7 uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + uint64_t m = m0; + uint64_t ps_a = bli_auxinfo_ps_a( data ); uint64_t ps_a8 = ps_a * sizeof( double ); @@ -1624,7 +1626,7 @@ void bli_dgemmsup_rv_zen4_asm_24x7 SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) - mov(var(m0), rdi) + mov(var(m), rdi) sub(imm(16), rdi) cmp(imm(8), rdi) @@ -1751,7 +1753,7 @@ void bli_dgemmsup_rv_zen4_asm_24x7 SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) - mov(var(m0), rdi) + mov(var(m), rdi) sub(imm(16), rdi) cmp(imm(8), rdi) @@ -1828,8 +1830,7 @@ void bli_dgemmsup_rv_zen4_asm_24x7 [c] "m" (c), [rs_c] "m" (rs_c), [cs_c] "m" (cs_c), - [n0] "m" (n0), - [m0] "m" (m0), + [m] "m" (m), [mask] "m" (mask), [mask_n0] "m" (mask_n0) : // register clobber list @@ -1875,6 +1876,8 @@ void bli_dgemmsup_rv_zen4_asm_16x7 uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + uint64_t m = m0; + uint64_t ps_a = bli_auxinfo_ps_a( data ); uint64_t ps_a8 = ps_a * sizeof( double ); @@ -2836,7 +2839,7 @@ void bli_dgemmsup_rv_zen4_asm_16x7 SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) - mov(var(m0), rdi) + mov(var(m), rdi) sub(imm(8), rdi) cmp(imm(8), rdi) JZ(.UPDATE8) @@ -2946,7 +2949,7 @@ void bli_dgemmsup_rv_zen4_asm_16x7 SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) - mov(var(m0), rdi) + mov(var(m), rdi) sub(imm(8), rdi) cmp(imm(8), rdi) JZ(.UPDATE8BZ) @@ -3022,8 +3025,7 @@ void bli_dgemmsup_rv_zen4_asm_16x7 [c] "m" (c), [rs_c] "m" (rs_c), [cs_c] "m" (cs_c), - [n0] "m" (n0), - [m0] "m" (m0), + [m] "m" (m), [mask] "m" (mask), [mask_n0] "m" (mask_n0) : // register clobber list @@ -3069,6 +3071,8 @@ void bli_dgemmsup_rv_zen4_asm_8x7 uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + uint64_t m = m0; + uint64_t ps_a = bli_auxinfo_ps_a( data ); uint64_t ps_a8 = ps_a * sizeof( double ); @@ -3811,7 +3815,7 @@ void bli_dgemmsup_rv_zen4_asm_8x7 vbroadcastsd(mem(rax), zmm31) - mov(var(m0), rdi) + mov(var(m), rdi) cmp(imm(8), rdi) JZ(.UPDATE8) cmp(imm(7), rdi) @@ -3901,7 +3905,7 @@ void bli_dgemmsup_rv_zen4_asm_8x7 SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) - mov(var(m0), rdi) + mov(var(m), rdi) cmp(imm(8), rdi) JZ(.UPDATE8BZ) cmp(imm(7), rdi) @@ -3976,8 +3980,7 @@ void bli_dgemmsup_rv_zen4_asm_8x7 [c] "m" (c), [rs_c] "m" (rs_c), [cs_c] "m" (cs_c), - [n0] "m" (n0), - [m0] "m" (m0), + [m] "m" (m), [mask] "m" (mask), [mask_n0] "m" (mask_n0) : // register clobber list diff --git a/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx8.c b/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx8.c index 065cbd5bb6..6e897c8119 100644 --- a/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx8.c +++ b/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx8.c @@ -4,32 +4,32 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer. - Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. - documentation and/or other materials provided with the distribution. - Neither the name(s) of the copyright holder(s) nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - AS IS AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY - OF TEXAS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, - EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, - PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR - PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY - OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ @@ -382,6 +382,8 @@ void bli_dgemmsup_rv_zen4_asm_24x8 uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + uint64_t m = m0; + uint64_t ps_a = bli_auxinfo_ps_a( data ); uint64_t ps_a8 = ps_a * sizeof( double ); @@ -1701,7 +1703,7 @@ void bli_dgemmsup_rv_zen4_asm_24x8 SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) - mov(var(m0), rdi) + mov(var(m), rdi) sub(imm(16), rdi) cmp(imm(8), rdi) JZ(.UPDATE8) @@ -1833,7 +1835,7 @@ void bli_dgemmsup_rv_zen4_asm_24x8 SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) - mov(var(m0), rdi) + mov(var(m), rdi) sub(imm(16), rdi) cmp(imm(8), rdi) JZ(.UPDATE8BZ) @@ -1909,8 +1911,7 @@ void bli_dgemmsup_rv_zen4_asm_24x8 [c] "m" (c), [rs_c] "m" (rs_c), [cs_c] "m" (cs_c), - [n0] "m" (n0), - [m0] "m" (m0), + [m] "m" (m), [mask] "m" (mask) : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", @@ -1955,6 +1956,8 @@ void bli_dgemmsup_rv_zen4_asm_16x8 uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + uint64_t m = m0; + uint64_t ps_a = bli_auxinfo_ps_a( data ); uint64_t ps_a8 = ps_a * sizeof( double ); @@ -2998,7 +3001,7 @@ void bli_dgemmsup_rv_zen4_asm_16x8 SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) - mov(var(m0), rdi) + mov(var(m), rdi) sub(imm(8), rdi) cmp(imm(8), rdi) JZ(.UPDATE8) @@ -3110,7 +3113,7 @@ void bli_dgemmsup_rv_zen4_asm_16x8 SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) - mov(var(m0), rdi) + mov(var(m), rdi) sub(imm(8), rdi) cmp(imm(8), rdi) JZ(.UPDATE8BZ) @@ -3186,8 +3189,7 @@ void bli_dgemmsup_rv_zen4_asm_16x8 [c] "m" (c), [rs_c] "m" (rs_c), [cs_c] "m" (cs_c), - [n0] "m" (n0), - [m0] "m" (m0), + [m] "m" (m), [mask] "m" (mask) : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", @@ -3232,6 +3234,8 @@ void bli_dgemmsup_rv_zen4_asm_8x8 uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + uint64_t m = (uint64_t)m0; + uint64_t ps_a = bli_auxinfo_ps_a( data ); uint64_t ps_a8 = ps_a * sizeof( double ); @@ -4027,7 +4031,7 @@ void bli_dgemmsup_rv_zen4_asm_8x8 vbroadcastsd(mem(rax), zmm31) - mov(var(m0), rdi) + mov(var(m), rdi) cmp(imm(8), rdi) JZ(.UPDATE8) cmp(imm(7), rdi) @@ -4118,7 +4122,7 @@ void bli_dgemmsup_rv_zen4_asm_8x8 SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) - mov(var(m0), rdi) + mov(var(m), rdi) cmp(imm(8), rdi) JZ(.UPDATE8BZ) cmp(imm(7), rdi) @@ -4193,8 +4197,7 @@ void bli_dgemmsup_rv_zen4_asm_8x8 [c] "m" (c), [rs_c] "m" (rs_c), [cs_c] "m" (cs_c), - [n0] "m" (n0), - [m0] "m" (m0), + [m] "m" (m), [mask] "m" (mask) : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", diff --git a/kernels/zen4/aocl_smart/bli_aocl_smart.c b/kernels/zen4/aocl_smart/bli_aocl_smart.c index 96e45b7139..dd8539bab5 100644 --- a/kernels/zen4/aocl_smart/bli_aocl_smart.c +++ b/kernels/zen4/aocl_smart/bli_aocl_smart.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -66,6 +66,33 @@ bool bli_cntx_gemmsup_thresh_is_met_zen4( obj_t* a, obj_t* b, obj_t* c, cntx_t* if((m < 5000) && (n < 5000) && (k < 5000)) return TRUE; return FALSE; } + else if( dt == BLIS_DCOMPLEX ) + { + dim_t k = bli_obj_width_after_trans( a ); + dim_t m, n; + + const stor3_t stor_id = bli_obj_stor3_from_strides( c, a, b ); + + if ( bli_cntx_l3_sup_ker_dislikes_storage_of( c, stor_id, cntx ) ) + { + m = bli_obj_width(c); + n = bli_obj_length(c); + } + else + { + m = bli_obj_length( c ); + n = bli_obj_width( c ); + } + // For skinny sizes where m and/or n is small + // The threshold for m is a single value, but for n, it is + // also based on the packing size of A, since the kernels are + // column preferential + if( ( m <= 84 ) || ( ( n <= 84 ) && ( m < 4000 ) ) ) return TRUE; + + // For all combinations in small sizes + if( ( m <= 216 ) && ( n <= 216 ) && ( k <= 216 ) ) return TRUE; + return FALSE; + } else return bli_cntx_l3_sup_thresh_is_met( a, b, c, cntx ); } diff --git a/kernels/zen4/bli_kernels_zen4.h b/kernels/zen4/bli_kernels_zen4.h index 82872ac942..b27984731a 100644 --- a/kernels/zen4/bli_kernels_zen4.h +++ b/kernels/zen4/bli_kernels_zen4.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -34,22 +34,80 @@ // -- level-1v -- +// addv (intrinsics) +ADDV_KER_PROT( double, d, addv_zen_int_avx512 ) + // amaxv (intrinsics) AMAXV_KER_PROT( float, s, amaxv_zen_int_avx512 ) AMAXV_KER_PROT( double, d, amaxv_zen_int_avx512 ) // scalv (AVX512 intrinsics) SCALV_KER_PROT( float, s, scalv_zen_int_avx512 ) -SCALV_KER_PROT( double, d, scalv_zen_int_avx512 ) +BLIS_EXPORT_BLIS SCALV_KER_PROT( double, d, scalv_zen_int_avx512 ) +SCALV_KER_PROT( scomplex, c, scalv_zen_int_avx512 ) +SCALV_KER_PROT( dcomplex, z, scalv_zen_int_avx512 ) SCALV_KER_PROT( dcomplex, z, dscalv_zen_int_avx512) // ZDSCAL kernel +// setv (intrinsics) +SETV_KER_PROT(float, s, setv_zen_int_avx512) +SETV_KER_PROT(double, d, setv_zen_int_avx512) +SETV_KER_PROT(dcomplex, z, setv_zen_int_avx512) + // dotv (intrinsics) DOTV_KER_PROT( float, s, dotv_zen_int_avx512 ) DOTV_KER_PROT( double, d, dotv_zen_int_avx512 ) +DOTV_KER_PROT( dcomplex, z, dotv_zen_int_avx512 ) +DOTV_KER_PROT( dcomplex, z, dotv_zen4_asm_avx512 ) // axpyv (intrinsics) AXPYV_KER_PROT( float, s, axpyv_zen_int_avx512 ) -AXPYV_KER_PROT( double, d, axpyv_zen_int_avx512 ) +BLIS_EXPORT_BLIS AXPYV_KER_PROT( double, d, axpyv_zen_int_avx512 ) +AXPYV_KER_PROT( dcomplex, z, axpyv_zen_int_avx512 ) + +// axpbyv ( intrinsics ) +AXPBYV_KER_PROT( double, d, axpbyv_zen_int_avx512 ); + +// axpyf (intrinsics) +AXPYF_KER_PROT( dcomplex, z, axpyf_zen_int_2_avx512 ) +AXPYF_KER_PROT( dcomplex, z, axpyf_zen_int_4_avx512 ) +AXPYF_KER_PROT( dcomplex, z, axpyf_zen_int_8_avx512 ) + +// axpyf (intrinsics) +AXPYF_KER_PROT( double, d, axpyf_zen_int_avx512 ) +AXPYF_KER_PROT( double, d, axpyf_zen_int2_avx512 ) +AXPYF_KER_PROT( double, d, axpyf_zen_int4_avx512 ) +AXPYF_KER_PROT( double, d, axpyf_zen_int6_avx512 ) +AXPYF_KER_PROT( double, d, axpyf_zen_int8_avx512 ) +AXPYF_KER_PROT( double, d, axpyf_zen_int12_avx512 ) +AXPYF_KER_PROT( double, d, axpyf_zen_int16_avx512 ) +AXPYF_KER_PROT( double, d, axpyf_zen_int32_avx512 ) +#ifdef BLIS_ENABLE_OPENMP +AXPYF_KER_PROT( double, d, axpyf_zen_int32_avx512_mt ) +#endif + +// dotxf (intrinsics) +DOTXF_KER_PROT( double, d, dotxf_zen_int_avx512 ) + +// copyv (intrinsics) +// COPYV_KER_PROT( float, s, copyv_zen_int_avx512 ) +// COPYV_KER_PROT( double, d, copyv_zen_int_avx512 ) +// COPYV_KER_PROT( dcomplex, z, copyv_zen_int_avx512 ) + +// copyv (asm) +COPYV_KER_PROT( float, s, copyv_zen4_asm_avx512 ) +COPYV_KER_PROT( double, d, copyv_zen4_asm_avx512 ) +COPYV_KER_PROT( dcomplex, z, copyv_zen4_asm_avx512 ) + +// scal2v (intrinsics) +SCAL2V_KER_PROT(double, d, scal2v_zen_int_avx512) + +// dotxv (intrinsics) +DOTXV_KER_PROT( dcomplex, z, dotxv_zen_int_avx512 ) + +// dotxf (intrinsics) +DOTXF_KER_PROT( dcomplex, z, dotxf_zen_int_8_avx512 ) +DOTXF_KER_PROT( dcomplex, z, dotxf_zen_int_4_avx512 ) +DOTXF_KER_PROT( dcomplex, z, dotxf_zen_int_2_avx512 ) GEMMTRSM_UKR_PROT( double, d, gemmtrsm_l_zen_asm_16x14) GEMMTRSM_UKR_PROT( double, d, gemmtrsm_u_zen_asm_16x14) @@ -136,6 +194,10 @@ TRSMSMALL_KER_PROT( d, trsm_small_AutXB_AlXB_AVX512 ) TRSMSMALL_KER_PROT( d, trsm_small_XAltB_XAuB_AVX512 ) TRSMSMALL_KER_PROT( d, trsm_small_XAutB_XAlB_AVX512 ) TRSMSMALL_KER_PROT( d, trsm_small_AltXB_AuXB_AVX512 ) +TRSMSMALL_KER_PROT( z, trsm_small_AutXB_AlXB_AVX512 ) +TRSMSMALL_KER_PROT( z, trsm_small_XAltB_XAuB_AVX512 ) +TRSMSMALL_KER_PROT( z, trsm_small_XAutB_XAlB_AVX512 ) +TRSMSMALL_KER_PROT( z, trsm_small_AltXB_AuXB_AVX512 ) #ifdef BLIS_ENABLE_OPENMP TRSMSMALL_PROT(trsm_small_mt_AVX512) @@ -154,6 +216,21 @@ GEMMSUP_KER_PROT( double, d, gemmsup_rv_zen4_asm_24x1m) GEMMSUP_KER_PROT( double, d, gemmsup_rv_zen4_asm_24x8) GEMMSUP_KER_PROT( double, d, gemmsup_rv_zen4_asm_16x8) GEMMSUP_KER_PROT( double, d, gemmsup_rv_zen4_asm_8x8) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_zen4_asm_8x8m) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_zen4_asm_8x8m_lower) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_zen4_asm_8x8m_upper) + +/* DGEMMT 24x8 triangular kernels */ +GEMMSUP_KER_PROT( double, d, gemmsup_rv_zen4_asm_24x8m_lower_0) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_zen4_asm_24x8m_lower_1) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_zen4_asm_24x8m_lower_2) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_zen4_asm_24x8m_upper_0) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_zen4_asm_24x8m_upper_1) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_zen4_asm_24x8m_upper_2) + +GEMMSUP_KER_PROT( dcomplex, z, gemmsup_rv_zen4_asm_4x4m) +GEMMSUP_KER_PROT( dcomplex, z, gemmsup_rv_zen4_asm_4x4m_lower) +GEMMSUP_KER_PROT( dcomplex, z, gemmsup_rv_zen4_asm_4x4m_upper) GEMMSUP_KER_PROT( double, d, gemmsup_rv_zen4_asm_24x7) GEMMSUP_KER_PROT( double, d, gemmsup_rv_zen4_asm_16x7) @@ -216,6 +293,26 @@ err_t bli_dgemm_24x8_avx512_k1_nn double* c, const inc_t ldc ); +void bli_dnorm2fv_unb_var1_avx512 + ( + dim_t n, + double* x, inc_t incx, + double* norm, + cntx_t* cntx + ); + +err_t bli_zgemm_16x4_avx512_k1_nn +( + dim_t m, + dim_t n, + dim_t k, + dcomplex* alpha, + dcomplex* a, const inc_t lda, + dcomplex* b, const inc_t ldb, + dcomplex* beta, + dcomplex* c, const inc_t ldc +); + // threshold functions bool bli_cntx_gemmsup_thresh_is_met_zen4 ( diff --git a/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_6x64rowmajor_bf16_amd512vnni.c b/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_6x64rowmajor_bf16_amd512vnni.c index d5fa298c2d..2c2a67a62a 100644 --- a/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_6x64rowmajor_bf16_amd512vnni.c +++ b/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_6x64rowmajor_bf16_amd512vnni.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -25,7 +25,7 @@ HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS dim_tERRUPTION) HOWEVER CAUSED AND ON ANY + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. @@ -39,11 +39,176 @@ #include "lpgemm_f32_kern_macros.h" -#ifdef LPGEMM_BF16_NOT_SUPPORTED +#ifdef LPGEMM_BF16_JIT -// BF16 ISA is not supported by gcc < 10. Use a dummy kernel here. +typedef void (*jit_kernel)(lpgemm_jit_params_t*, lpgemm_post_op_attr*, lpgemm_post_op*); + +// BF16 ISA is not supported by gcc < 10. Use a JIT-generated kernel here. LPGEMM_MAIN_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x64) -{} +{ + jit_kernel kernel_fp; + dim_t MR = 6; + dim_t NR = 64; + + dim_t post_op_temp_c_i = post_ops_attr.post_op_c_i; + + dim_t m_full_pieces = m0 / MR; + dim_t m_full_pieces_loop_limit = m_full_pieces * MR; + dim_t m_partial_pieces = m0 % MR; + + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 & 1; + + dim_t value; + + if(k_full_pieces > 40) + { + value = 40; + } + else + { + value = 0; + } + + // Fill params_t struct with all the data that will be required + // during execution of the JIT kernel. + lpgemm_jit_params_t params; + params.m = m0; params.n = n0; params.k = k0; + params.rs_a = rs_a; params.cs_a = cs_a; + params.ps_a2 = ps_a * sizeof( bfloat16 ) * MR; + params.rs_b = rs_b; params.cs_b = cs_b; + params.rs_c = rs_c; params.cs_c = 1; + params.alpha = ( float* )α + params.beta = ( float* )β + params.m_iter = m_full_pieces; + params.k_iter_before_prefetch = k_full_pieces - value; + params.k_iter_after_prefetch = value; + params.k_left = k_partial_pieces; + params.a = ( bfloat16* )a; params.b = ( bfloat16* )b; params.c = ( float* )c; + + + dim_t n0_16 = n0 / NUM_F32_ELEMS_PER_ZMM; + + // n_fringe case + // if n < NR, handle them using n-fringe kernels. + if ( n0 < NR ) + { + dim_t n0_rem = n0 % NUM_F32_ELEMS_PER_ZMM; + + // KC when not multiple of 2 will have padding to make it multiple of + // 2 in packed buffer. Also the k0 cannot be passed as the updated + // value since A matrix is not packed and requires original k0. + dim_t k0_updated = k0; + k0_updated += ( k0_updated & 0x1 ); + + + // Split dim_to multiple smaller fringe kernels, so as to maximize + // vectorization. Any n0 < NR(64) can be expressed as n0 = 48 + n` + // or n0 = 32 + n` or n0 = 16 + n`, where n` < 16. + + // Handles case where n0 >=16. + if( n0 > n0_rem ) + { + params.rs_b = ( ( rs_b / 4 ) * ( n0_16 ) ); + + // kernel with m_iter loop. + if( m0 >= MR ) + { + kernel_fp = lpgemm_get_jit_kernel( 0, n0_16 ); + + ( kernel_fp )( + ¶ms, + &post_ops_attr, + post_ops_list + ); + } + // Handle m_fringe case. + if( m_partial_pieces ) + { + post_ops_attr.post_op_c_i += m_full_pieces_loop_limit; + params.a += m_full_pieces_loop_limit * ps_a; + params.c += m_full_pieces_loop_limit * rs_c; + kernel_fp = lpgemm_get_jit_kernel( m_partial_pieces, n0_16 ); + ( kernel_fp )( + ¶ms, + &post_ops_attr, + post_ops_list + ); + } + params.b = ( bfloat16* )b + ( n0 - n0_rem ) * k0_updated; + params.c = ( float* )c + ( n0 - n0_rem ); + + post_ops_attr.post_op_c_j += n0 - n0_rem; + } + + // Handles case where n0_rem < 16 + // We use mask loads/stores in this case. + if ( n0_rem > 0 ) + { + params.a = ( bfloat16* )a; + + params.mask16 = 0xFFFFFFFF >> ( NUM_F32_ELEMS_PER_ZMM - n0_rem); + params.mask32 = 0xFFFF >> ( NUM_F32_ELEMS_PER_ZMM - n0_rem ); + + params.rs_b = ( ( rs_b / 4 ) * 1 ); + post_ops_attr.post_op_c_i = post_op_temp_c_i; + + // kernel with m_iter loop + if( m0 >= MR ) + { + kernel_fp = lpgemm_get_jit_kernel( 0, 0 ); + ( kernel_fp )( + ¶ms, + &post_ops_attr, + post_ops_list + ); + } + // Handle m_fringe case. + if( m_partial_pieces ) + { + post_ops_attr.post_op_c_i += m_full_pieces_loop_limit; + params.a += m_full_pieces_loop_limit * ps_a; + params.c += m_full_pieces_loop_limit * rs_c; + kernel_fp = lpgemm_get_jit_kernel( m_partial_pieces, 0 ); + ( kernel_fp )( + ¶ms, + &post_ops_attr, + post_ops_list + ); + } + + // No leftover n-fringe after this point. + } + return; + } + + // Main 6x64 kernel with m_iter loop. + if( m0 >= MR ) + { + kernel_fp = lpgemm_get_jit_kernel( 0, n0_16 ); + ( kernel_fp )( + ¶ms, + &post_ops_attr, + post_ops_list + ); + } + + // Handle m_fringe case here. + if( m_partial_pieces ) + { + post_ops_attr.post_op_c_i += m_full_pieces_loop_limit; + + params.a += m_full_pieces_loop_limit * ps_a; + params.c += m_full_pieces_loop_limit * rs_c; + + kernel_fp = lpgemm_get_jit_kernel( m_partial_pieces, n0_16 ); + ( kernel_fp )( + ¶ms, + &post_ops_attr, + post_ops_list + ); + } +} #else @@ -59,7 +224,10 @@ LPGEMM_MAIN_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x64) &&POST_OPS_GELU_TANH_6x64, &&POST_OPS_GELU_ERF_6x64, &&POST_OPS_CLIP_6x64, - &&POST_OPS_DOWNSCALE_6x64 + &&POST_OPS_DOWNSCALE_6x64, + &&POST_OPS_MATRIX_ADD_6x64, + &&POST_OPS_SWISH_6x64, + &&POST_OPS_MATRIX_MUL_6x64 }; dim_t MR = 6; dim_t NR = 64; @@ -72,6 +240,7 @@ LPGEMM_MAIN_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x64) dim_t k_partial_pieces = k0 % 2; int16_t a_kfringe_buf = 0; + if ( n0 < NR ) { @@ -696,18 +865,29 @@ LPGEMM_MAIN_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x64) if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) { - selector1 = + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_LOAD(selector1, bias_mask, 0); + BF16_F32_BIAS_LOAD(selector2, bias_mask, 1); + BF16_F32_BIAS_LOAD(selector3, bias_mask, 2); + BF16_F32_BIAS_LOAD(selector4, bias_mask, 3); + } + else + { + selector1 = _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - selector2 = + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ); - selector3 = + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 2 * 16 ) ); - selector4 = + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + selector4 = _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + } // c[0,0-15] c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); @@ -789,24 +969,39 @@ LPGEMM_MAIN_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x64) // the ic index, and each bias element corresponds to an // entire row of the transposed output array, instead of an // entire column. - selector1 = + __m512 selector5; + __m512 selector6; + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + BF16_F32_BIAS_BCAST(selector2, bias_mask, 1); + BF16_F32_BIAS_BCAST(selector3, bias_mask, 2); + BF16_F32_BIAS_BCAST(selector4, bias_mask, 3); + BF16_F32_BIAS_BCAST(selector5, bias_mask, 4); + BF16_F32_BIAS_BCAST(selector6, bias_mask, 5); + } + else + { + selector1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_i + 0 ) ); - selector2 = + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_i + 1 ) ); - selector3 = + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_i + 2 ) ); - selector4 = + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_i + 3 ) ); - __m512 selector5 = + post_ops_attr.post_op_c_i + 3 ) ); + selector5 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_i + 4 ) ); - __m512 selector6 = + post_ops_attr.post_op_c_i + 4 ) ); + selector6 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_i + 5 ) ); + post_ops_attr.post_op_c_i + 5 ) ); + } // c[0,0-15] c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); @@ -1280,84 +1475,493 @@ LPGEMM_MAIN_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x64) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_6x64: -{ - // c[0, 0-15] - MULRND_F32(c_float_0p0,0,0); + { + __m512 selector3 = _mm512_setzero_ps(); + __m512 selector4 = _mm512_setzero_ps(); + + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + __m512 zero_point2 = _mm512_setzero_ps(); + __m512 zero_point3 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + selector4 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector2,zero_point1); + + // c[0, 32-47] + SCL_MULRND_F32(c_float_0p2,selector3,zero_point2); + + // c[0, 48-63] + SCL_MULRND_F32(c_float_0p3,selector4,zero_point3); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector1,zero_point0); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[1, 32-47] + SCL_MULRND_F32(c_float_1p2,selector3,zero_point2); + + // c[1, 48-63] + SCL_MULRND_F32(c_float_1p3,selector4,zero_point3); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector1,zero_point0); + + // c[2, 16-31] + SCL_MULRND_F32(c_float_2p1,selector2,zero_point1); + + // c[2, 32-47] + SCL_MULRND_F32(c_float_2p2,selector3,zero_point2); + + // c[2, 48-63] + SCL_MULRND_F32(c_float_2p3,selector4,zero_point3); + + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector1,zero_point0); + + // c[3, 16-31] + SCL_MULRND_F32(c_float_3p1,selector2,zero_point1); + + // c[3, 32-47] + SCL_MULRND_F32(c_float_3p2,selector3,zero_point2); + + // c[3, 48-63] + SCL_MULRND_F32(c_float_3p3,selector4,zero_point3); + + // c[4, 0-15] + SCL_MULRND_F32(c_float_4p0,selector1,zero_point0); + + // c[4, 16-31] + SCL_MULRND_F32(c_float_4p1,selector2,zero_point1); + + // c[4, 32-47] + SCL_MULRND_F32(c_float_4p2,selector3,zero_point2); + + // c[4, 48-63] + SCL_MULRND_F32(c_float_4p3,selector4,zero_point3); + + // c[5, 0-15] + SCL_MULRND_F32(c_float_5p0,selector1,zero_point0); + + // c[5, 16-31] + SCL_MULRND_F32(c_float_5p1,selector2,zero_point1); + + // c[5, 32-47] + SCL_MULRND_F32(c_float_5p2,selector3,zero_point2); + + // c[5, 48-63] + SCL_MULRND_F32(c_float_5p3,selector4,zero_point3); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 3 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 2 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 3 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector1,zero_point0); + + // c[0, 32-47] + SCL_MULRND_F32(c_float_0p2,selector1,zero_point0); + + // c[0, 48-63] + SCL_MULRND_F32(c_float_0p3,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector2,zero_point1); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[1, 32-47] + SCL_MULRND_F32(c_float_1p2,selector2,zero_point1); + + // c[1, 48-63] + SCL_MULRND_F32(c_float_1p3,selector2,zero_point1); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector3,zero_point2); + + // c[2, 16-31] + SCL_MULRND_F32(c_float_2p1,selector3,zero_point2); + + // c[2, 32-47] + SCL_MULRND_F32(c_float_2p2,selector3,zero_point2); + + // c[2, 48-63] + SCL_MULRND_F32(c_float_2p3,selector3,zero_point2); + + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector4,zero_point3); + + // c[3, 16-31] + SCL_MULRND_F32(c_float_3p1,selector4,zero_point3); + + // c[3, 32-47] + SCL_MULRND_F32(c_float_3p2,selector4,zero_point3); + + // c[3, 48-63] + SCL_MULRND_F32(c_float_3p3,selector4,zero_point3); + + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 4 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 5 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 4 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 5 ) ) ); + } + // c[4, 0-15] + SCL_MULRND_F32(c_float_4p0,selector1,zero_point0); + + // c[4, 16-31] + SCL_MULRND_F32(c_float_4p1,selector1,zero_point0); + + // c[4, 32-47] + SCL_MULRND_F32(c_float_4p2,selector1,zero_point0); + + // c[4, 48-63] + SCL_MULRND_F32(c_float_4p3,selector1,zero_point0); + + // c[5, 0-15] + SCL_MULRND_F32(c_float_5p0,selector2,zero_point1); + + // c[5, 16-31] + SCL_MULRND_F32(c_float_5p1,selector2,zero_point1); + + // c[5, 32-47] + SCL_MULRND_F32(c_float_5p2,selector2,zero_point1); + + // c[5, 48-63] + SCL_MULRND_F32(c_float_5p3,selector2,zero_point1); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_6x64: + { + __m512 selector3; + __m512 selector4; + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + // It is expected the post-op matrix arg has the same storage + // order as the output C matrix. + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,0); + + // c[1:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,1); + + // c[2:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,2); + + // c[3:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,3); + + // c[4:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,4); + + // c[5:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,5); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,0); + + // c[1:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,1); + + // c[2:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,2); + + // c[3:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,3); + + // c[4:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,4); + + // c[5:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,5); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_6x64: + { + __m512 selector3; + __m512 selector4; + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + // It is expected the post-op matrix arg has the same storage + // order as the output C matrix. + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,0); + + // c[1:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,1); + + // c[2:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,2); + + // c[3:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,3); + + // c[4:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,4); + + // c[5:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,5); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,0); + + // c[1:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,1); + + // c[2:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,2); + + // c[3:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,3); + + // c[4:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,4); + + // c[5:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,5); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_6x64: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); // c[0, 16-31] - MULRND_F32(c_float_0p1,0,1); + SWISH_F32_AVX512_DEF(c_float_0p1, selector1, al_in, r, r2, z, dn, ex_out); // c[0, 32-47] - MULRND_F32(c_float_0p2,0,2); + SWISH_F32_AVX512_DEF(c_float_0p2, selector1, al_in, r, r2, z, dn, ex_out); // c[0, 48-63] - MULRND_F32(c_float_0p3,0,3); + SWISH_F32_AVX512_DEF(c_float_0p3, selector1, al_in, r, r2, z, dn, ex_out); // c[1, 0-15] - MULRND_F32(c_float_1p0,1,0); + SWISH_F32_AVX512_DEF(c_float_1p0, selector1, al_in, r, r2, z, dn, ex_out); // c[1, 16-31] - MULRND_F32(c_float_1p1,1,1); + SWISH_F32_AVX512_DEF(c_float_1p1, selector1, al_in, r, r2, z, dn, ex_out); // c[1, 32-47] - MULRND_F32(c_float_1p2,1,2); + SWISH_F32_AVX512_DEF(c_float_1p2, selector1, al_in, r, r2, z, dn, ex_out); // c[1, 48-63] - MULRND_F32(c_float_1p3,1,3); + SWISH_F32_AVX512_DEF(c_float_1p3, selector1, al_in, r, r2, z, dn, ex_out); // c[2, 0-15] - MULRND_F32(c_float_2p0,2,0); + SWISH_F32_AVX512_DEF(c_float_2p0, selector1, al_in, r, r2, z, dn, ex_out); // c[2, 16-31] - MULRND_F32(c_float_2p1,2,1); + SWISH_F32_AVX512_DEF(c_float_2p1, selector1, al_in, r, r2, z, dn, ex_out); // c[2, 32-47] - MULRND_F32(c_float_2p2,2,2); + SWISH_F32_AVX512_DEF(c_float_2p2, selector1, al_in, r, r2, z, dn, ex_out); // c[2, 48-63] - MULRND_F32(c_float_2p3,2,3); + SWISH_F32_AVX512_DEF(c_float_2p3, selector1, al_in, r, r2, z, dn, ex_out); // c[3, 0-15] - MULRND_F32(c_float_3p0,3,0); + SWISH_F32_AVX512_DEF(c_float_3p0, selector1, al_in, r, r2, z, dn, ex_out); // c[3, 16-31] - MULRND_F32(c_float_3p1,3,1); + SWISH_F32_AVX512_DEF(c_float_3p1, selector1, al_in, r, r2, z, dn, ex_out); // c[3, 32-47] - MULRND_F32(c_float_3p2,3,2); + SWISH_F32_AVX512_DEF(c_float_3p2, selector1, al_in, r, r2, z, dn, ex_out); // c[3, 48-63] - MULRND_F32(c_float_3p3,3,3); + SWISH_F32_AVX512_DEF(c_float_3p3, selector1, al_in, r, r2, z, dn, ex_out); // c[4, 0-15] - MULRND_F32(c_float_4p0,4,0); + SWISH_F32_AVX512_DEF(c_float_4p0, selector1, al_in, r, r2, z, dn, ex_out); // c[4, 16-31] - MULRND_F32(c_float_4p1,4,1); + SWISH_F32_AVX512_DEF(c_float_4p1, selector1, al_in, r, r2, z, dn, ex_out); // c[4, 32-47] - MULRND_F32(c_float_4p2,4,2); + SWISH_F32_AVX512_DEF(c_float_4p2, selector1, al_in, r, r2, z, dn, ex_out); // c[4, 48-63] - MULRND_F32(c_float_4p3,4,3); + SWISH_F32_AVX512_DEF(c_float_4p3, selector1, al_in, r, r2, z, dn, ex_out); // c[5, 0-15] - MULRND_F32(c_float_5p0,5,0); + SWISH_F32_AVX512_DEF(c_float_5p0, selector1, al_in, r, r2, z, dn, ex_out); // c[5, 16-31] - MULRND_F32(c_float_5p1,5,1); + SWISH_F32_AVX512_DEF(c_float_5p1, selector1, al_in, r, r2, z, dn, ex_out); // c[5, 32-47] - MULRND_F32(c_float_5p2,5,2); + SWISH_F32_AVX512_DEF(c_float_5p2, selector1, al_in, r, r2, z, dn, ex_out); // c[5, 48-63] - MULRND_F32(c_float_5p3,5,3); + SWISH_F32_AVX512_DEF(c_float_5p3, selector1, al_in, r, r2, z, dn, ex_out); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR -} - + } POST_OPS_6x64_DISABLE: ; @@ -1447,7 +2051,7 @@ LPGEMM_MAIN_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x64) } - // Case where the output C matrix is float + // Case where the output C matrix is float else { // Store the results. @@ -1605,5 +2209,5 @@ LPGEMM_MAIN_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x64) } } -#endif //LPGEMM_BF16_NOT_SUPPORTED +#endif //LPGEMM_BF16_JIT #endif diff --git a/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_6x64rowmajor_bf16s4f32of32_amd512vnni.c b/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_6x64rowmajor_bf16s4f32of32_amd512vnni.c new file mode 100644 index 0000000000..d0e57618e0 --- /dev/null +++ b/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_6x64rowmajor_bf16s4f32of32_amd512vnni.c @@ -0,0 +1,2168 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "blis.h" + +#ifdef BLIS_ADDON_LPGEMM + +#include "lpgemm_f32_kern_macros.h" +#include "../int4_utils_avx512.h" + +#ifndef LPGEMM_BF16_JIT + +// 6x64 bf16 kernel +LPGEMM_MAIN_KERN(bfloat16, int8_t, float, bf16s4f32of32_6x64m) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_6x64_DISABLE, + &&POST_OPS_BIAS_6x64, + &&POST_OPS_RELU_6x64, + &&POST_OPS_RELU_SCALE_6x64, + &&POST_OPS_GELU_TANH_6x64, + &&POST_OPS_GELU_ERF_6x64, + &&POST_OPS_CLIP_6x64, + &&POST_OPS_DOWNSCALE_6x64, + &&POST_OPS_MATRIX_ADD_6x64, + &&POST_OPS_SWISH_6x64, + &&POST_OPS_MATRIX_MUL_6x64 + }; + dim_t MR = 6; + dim_t NR = 64; + + dim_t m_full_pieces = m0 / MR; + dim_t m_full_pieces_loop_limit = m_full_pieces * MR; + dim_t m_partial_pieces = m0 % MR; + + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int16_t a_kfringe_buf = 0; + + + if ( n0 < NR ) + { + dim_t n0_rem = n0 % 16; + + // Split dim_to multiple smaller fringe kernels, so as to maximize + // vectorization. Any n0 < NR(64) can be expressed as n0 = 48 + n` + // or n0 = 32 + n` or n0 = 16 + n`, where n` < 16. + dim_t n0_48 = n0 / 48; + dim_t n0_32 = n0 / 32; + dim_t n0_16 = n0 / 16; + + // KC when not multiple of 2 will have padding to make it multiple of + // 2 in packed buffer. Also the k0 cannot be passed as the updated + // value since A matrix is not packed and requires original k0. + dim_t k0_updated = k0; + k0_updated += (k0_updated & 0x1); + + if ( n0_48 == 1 ) + { + lpgemm_rowvar_bf16s4f32of32_6x48m + ( + m0, k0, + a, rs_a, cs_a, ps_a, + b, ( ( rs_b / 4 ) * 3 ), cs_b, + c, rs_c, + alpha, beta, + post_ops_list, post_ops_attr + ); + + b = b + ( ( 48 * k0_updated ) / 2 ); // k0x48 packed contiguosly. + c = c + 48; + post_ops_attr.post_op_c_j += 48; + post_ops_attr.pre_op_off += 48; + } + + else if ( n0_32 == 1 ) + { + lpgemm_rowvar_bf16s4f32of32_6x32m + ( + m0, k0, + a, rs_a, cs_a, ps_a, + b, ( ( rs_b / 4 ) * 2 ), cs_b, + c, rs_c, + alpha, beta, + post_ops_list, post_ops_attr + ); + + b = b + ( ( 32 * k0_updated ) / 2 ); // k0x32 packed contiguosly. + c = c + 32; + post_ops_attr.post_op_c_j += 32; + post_ops_attr.pre_op_off += 32; + } + + else if ( n0_16 == 1 ) + { + lpgemm_rowvar_bf16s4f32of32_6x16m + ( + m0, k0, + a, rs_a, cs_a, ps_a, + b, ( ( rs_b / 4 ) * 1 ), cs_b, + c, rs_c, + alpha, beta, + post_ops_list, post_ops_attr + ); + + b = b + ( ( 16 * k0_updated ) / 2 ); // k0x16 packed contiguosly. + c = c + 16; + post_ops_attr.post_op_c_j += 16; + post_ops_attr.pre_op_off += 16; + } + + if ( n0_rem > 0 ) + { + lpgemm_rowvar_bf16s4f32of32_6xlt16m + ( + m0, k0, + a, rs_a, cs_a, ps_a, + b, ( ( rs_b / 4 ) * 1 ), cs_b, + c, rs_c, + alpha, beta, n0_rem, + post_ops_list, post_ops_attr + ); + + // No leftover fringe after this podint. + } + return; + } + + + dim_t pre_op_off = post_ops_attr.pre_op_off; + + // B matrix storage bfloat type + __m512bh b0; + __m512bh b1; + __m512bh b2; + __m512bh b3; + + __m256i b0_s4; + __m256i b1_s4; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + __m512bh a_bf16_1; + + dim_t value; + + if(k_full_pieces > 40) + { + value = 40; + } + else + { + value = 0; + } + + __m512i shift_idx_64; + MULTISHIFT_32BIT_8_INT4_IDX_64ELEM(shift_idx_64); + __m512i sign_comp = _mm512_set1_epi8(0x08); + bool signed_upscale = true; + + /* regs to store intermediate int8 values */ + __m512i b0_s8, b1_s8; + + /* Regs to store F32 scale values */ + __m512 scale0, scale1, scale2, scale3, scale4, scale5, scale6, scale7; + /* Reg to store masks to interleave scale factor */ + __m512i mask_scale1, mask_scale2; + + mask_scale1 = _mm512_set_epi32( 0x17, 0x07, 0x16, 0x06, 0x15, 0x05, 0x14, + 0x04, 0x13, 0x03, 0x12, 0x02, 0x11, 0x01, + 0x10, 0x00 ); + + mask_scale2 = _mm512_set_epi32( 0x1F, 0x0F, 0x1E, 0x0E, 0x1D, 0x0D, 0x1C, + 0x0C, 0x1B, 0x0B, 0x1A, 0x0A, 0x19, 0x09, + 0x18, 0x08); + + if( post_ops_attr.pre_op_scale_factor_len > 1 ) + { + // load and interleave scale factor vectors + scale0 = _mm512_loadu_ps( (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off); + scale2 = _mm512_loadu_ps( (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off + 16 ); + scale4 = _mm512_loadu_ps( (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off + 32 ); + scale6 = _mm512_loadu_ps( (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off + 48 ); + + scale1 = _mm512_permutex2var_ps( scale0, mask_scale2, scale0 ); + scale0 = _mm512_permutex2var_ps( scale0, mask_scale1, scale0 ); + scale3 = _mm512_permutex2var_ps( scale2, mask_scale2, scale2 ); + scale2 = _mm512_permutex2var_ps( scale2, mask_scale1, scale2 ); + scale5 = _mm512_permutex2var_ps( scale4, mask_scale2, scale4 ); + scale4 = _mm512_permutex2var_ps( scale4, mask_scale1, scale4 ); + scale7 = _mm512_permutex2var_ps( scale6, mask_scale2, scale6 ); + scale6 = _mm512_permutex2var_ps( scale6, mask_scale1, scale6 ); + + } + else + { + scale0 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale1 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale2 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale3 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale4 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale5 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale6 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale7 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + } + + for ( dim_t ir = 0; ir < m_full_pieces_loop_limit; ir += MR ) + { + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + __m512 c_float_0p1 = _mm512_setzero_ps(); + __m512 c_float_0p2 = _mm512_setzero_ps(); + __m512 c_float_0p3 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + __m512 c_float_1p1 = _mm512_setzero_ps(); + __m512 c_float_1p2 = _mm512_setzero_ps(); + __m512 c_float_1p3 = _mm512_setzero_ps(); + + __m512 c_float_2p0 = _mm512_setzero_ps(); + __m512 c_float_2p1 = _mm512_setzero_ps(); + __m512 c_float_2p2 = _mm512_setzero_ps(); + __m512 c_float_2p3 = _mm512_setzero_ps(); + + __m512 c_float_3p0 = _mm512_setzero_ps(); + __m512 c_float_3p1 = _mm512_setzero_ps(); + __m512 c_float_3p2 = _mm512_setzero_ps(); + __m512 c_float_3p3 = _mm512_setzero_ps(); + + __m512 c_float_4p0 = _mm512_setzero_ps(); + __m512 c_float_4p1 = _mm512_setzero_ps(); + __m512 c_float_4p2 = _mm512_setzero_ps(); + __m512 c_float_4p3 = _mm512_setzero_ps(); + + __m512 c_float_5p0 = _mm512_setzero_ps(); + __m512 c_float_5p1 = _mm512_setzero_ps(); + __m512 c_float_5p2 = _mm512_setzero_ps(); + __m512 c_float_5p3 = _mm512_setzero_ps(); + + for ( dim_t kr = 0; kr < k_full_pieces - value; kr += 1 ) + { + // Broadcast a[0,kr:kr+2] + a_bf16_0 = (__m512bh)_mm512_set1_epi32( + *( int32_t* )(a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + b0_s4 = _mm256_loadu_si256( (__m256i const *)( b + ( rs_b * kr ) / 2 ) ); + + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_64, \ + sign_comp, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_16( b0_s8, 0, scale0 ) ); + + b1 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 3, scale3 ), + CVT_INT8_F32_SCAL_16( b0_s8, 2, scale2 ) ); + + b1_s4 = _mm256_loadu_si256( (__m256i const *)( b + ( ( rs_b * kr ) / 2 ) + 32 ) ); + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( b1_s4, b1_s8, shift_idx_64, \ + sign_comp, signed_upscale); + + b2 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b1_s8, 1, scale5 ), + CVT_INT8_F32_SCAL_16( b1_s8, 0, scale4 ) ); + + b3 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b1_s8, 3, scale7 ), + CVT_INT8_F32_SCAL_16( b1_s8, 2, scale6 ) ); + + + // Perform column direction mat-mul with k = 2. + // c[0,0-63] = a[0,kr:kr+2]*b[kr:kr+2,0-63] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_1 = (__m512bh)_mm512_set1_epi32( + *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + c_float_0p3 = _mm512_dpbf16_ps( c_float_0p3, a_bf16_0, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-63] = a[1,kr:kr+2]*b[kr:kr+2,0-63] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_1, b0 ); + + // Broadcast a[2,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( + *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_1, b1 ); + c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_1, b2 ); + c_float_1p3 = _mm512_dpbf16_ps( c_float_1p3, a_bf16_1, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-63] = a[2,kr:kr+2]*b[kr:kr+2,0-63] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + + // Broadcast a[3,kr:kr+2]. + a_bf16_1 = (__m512bh)_mm512_set1_epi32( + *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); + c_float_2p3 = _mm512_dpbf16_ps( c_float_2p3, a_bf16_0, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-63] = a[3,kr:kr+2]*b[kr:kr+2,0-63] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_1, b0 ); + + // Broadcast a[4,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( + *( int32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); + + c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_1, b1 ); + c_float_3p2 = _mm512_dpbf16_ps( c_float_3p2, a_bf16_1, b2 ); + c_float_3p3 = _mm512_dpbf16_ps( c_float_3p3, a_bf16_1, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[4,0-63] = a[4,kr:kr+2]*b[kr:kr+2,0-63] + c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); + + // Broadcast a[5,kr:kr+2]. + a_bf16_1 = (__m512bh)_mm512_set1_epi32( + *( int32_t* )( a + ( rs_a * 5 ) + ( cs_a * kr ) ) ); + + c_float_4p1 = _mm512_dpbf16_ps( c_float_4p1, a_bf16_0, b1 ); + c_float_4p2 = _mm512_dpbf16_ps( c_float_4p2, a_bf16_0, b2 ); + c_float_4p3 = _mm512_dpbf16_ps( c_float_4p3, a_bf16_0, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[5,0-63] = a[5,kr:kr+2]*b[kr:kr+2,0-63] + c_float_5p0 = _mm512_dpbf16_ps( c_float_5p0, a_bf16_1, b0 ); + c_float_5p1 = _mm512_dpbf16_ps( c_float_5p1, a_bf16_1, b1 ); + c_float_5p2 = _mm512_dpbf16_ps( c_float_5p2, a_bf16_1, b2 ); + c_float_5p3 = _mm512_dpbf16_ps( c_float_5p3, a_bf16_1, b3 ); + } + + _mm_prefetch(c + (rs_c * (ir + 0)) + (0 * 16), _MM_HINT_T1); + _mm_prefetch(c + (rs_c * (ir + 0)) + (1 * 16), _MM_HINT_T1); + _mm_prefetch(c + (rs_c * (ir + 0)) + (2 * 16), _MM_HINT_T1); + _mm_prefetch(c + (rs_c * (ir + 0)) + (3 * 16), _MM_HINT_T1); + + _mm_prefetch(c + (rs_c * (ir + 1)) + (0 * 16), _MM_HINT_T1); + _mm_prefetch(c + (rs_c * (ir + 1)) + (1 * 16), _MM_HINT_T1); + _mm_prefetch(c + (rs_c * (ir + 1)) + (2 * 16), _MM_HINT_T1); + _mm_prefetch(c + (rs_c * (ir + 1)) + (3 * 16), _MM_HINT_T1); + + _mm_prefetch(c + (rs_c * (ir + 2)) + (0 * 16), _MM_HINT_T1); + _mm_prefetch(c + (rs_c * (ir + 2)) + (1 * 16), _MM_HINT_T1); + _mm_prefetch(c + (rs_c * (ir + 2)) + (2 * 16), _MM_HINT_T1); + _mm_prefetch(c + (rs_c * (ir + 2)) + (3 * 16), _MM_HINT_T1); + + _mm_prefetch(c + (rs_c * (ir + 3)) + (0 * 16), _MM_HINT_T1); + _mm_prefetch(c + (rs_c * (ir + 3)) + (1 * 16), _MM_HINT_T1); + _mm_prefetch(c + (rs_c * (ir + 3)) + (2 * 16), _MM_HINT_T1); + _mm_prefetch(c + (rs_c * (ir + 3)) + (3 * 16), _MM_HINT_T1); + + _mm_prefetch(c + (rs_c * (ir + 4)) + (0 * 16), _MM_HINT_T1); + _mm_prefetch(c + (rs_c * (ir + 4)) + (1 * 16), _MM_HINT_T1); + _mm_prefetch(c + (rs_c * (ir + 4)) + (2 * 16), _MM_HINT_T1); + _mm_prefetch(c + (rs_c * (ir + 4)) + (3 * 16), _MM_HINT_T1); + + _mm_prefetch(c + (rs_c * (ir + 5)) + (0 * 16), _MM_HINT_T1); + _mm_prefetch(c + (rs_c * (ir + 5)) + (1 * 16), _MM_HINT_T1); + _mm_prefetch(c + (rs_c * (ir + 5)) + (2 * 16), _MM_HINT_T1); + _mm_prefetch(c + (rs_c * (ir + 5)) + (3 * 16), _MM_HINT_T1); + + for (dim_t kr = k_full_pieces - value; kr < k_full_pieces; kr += 1) + { + // The instructions are arranged in a mixed way to reduce data + // chain dependencies. + + // b0 = (__m512bh)_mm512_loadu_epi16(b + (rs_b * kr) + (cs_b * 0)); + + // Broadcast a[0,kr:kr+2] + a_bf16_0 = (__m512bh)_mm512_set1_epi32( + *(int32_t *)(a + (rs_a * 0) + (cs_a * kr))); + + // b1 = (__m512bh)_mm512_loadu_epi16(b + (rs_b * kr) + (cs_b * 1)); + // b2 = (__m512bh)_mm512_loadu_epi16(b + (rs_b * kr) + (cs_b * 2)); + // b3 = (__m512bh)_mm512_loadu_epi16(b + (rs_b * kr) + (cs_b * 3)); + + b0_s4 = _mm256_loadu_si256( (__m256i const *)( b + ( rs_b * kr ) / 2 ) ); + + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_64, \ + sign_comp, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_16( b0_s8, 0, scale0 ) ); + + b1 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 3, scale3 ), + CVT_INT8_F32_SCAL_16( b0_s8, 2, scale2 ) ); + + b1_s4 = _mm256_loadu_si256( (__m256i const *)( b + ( ( rs_b * kr ) / 2 ) + 32 ) ); + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( b1_s4, b1_s8, shift_idx_64, \ + sign_comp, signed_upscale); + + b2 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b1_s8, 1, scale5 ), + CVT_INT8_F32_SCAL_16( b1_s8, 0, scale4 ) ); + + b3 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b1_s8, 3, scale7 ), + CVT_INT8_F32_SCAL_16( b1_s8, 2, scale6 ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-63] = a[0,kr:kr+2]*b[kr:kr+2,0-63] + c_float_0p0 = _mm512_dpbf16_ps(c_float_0p0, a_bf16_0, b0); + + // Broadcast a[1,kr:kr+2]. + a_bf16_1 = (__m512bh)_mm512_set1_epi32( + *(int32_t *)(a + (rs_a * 1) + (cs_a * kr))); + + c_float_0p1 = _mm512_dpbf16_ps(c_float_0p1, a_bf16_0, b1); + c_float_0p2 = _mm512_dpbf16_ps(c_float_0p2, a_bf16_0, b2); + c_float_0p3 = _mm512_dpbf16_ps(c_float_0p3, a_bf16_0, b3); + + // Perform column direction mat-mul with k = 2. + // c[1,0-63] = a[1,kr:kr+2]*b[kr:kr+2,0-63] + c_float_1p0 = _mm512_dpbf16_ps(c_float_1p0, a_bf16_1, b0); + + // Broadcast a[2,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( + *(int32_t *)(a + (rs_a * 2) + (cs_a * kr))); + + c_float_1p1 = _mm512_dpbf16_ps(c_float_1p1, a_bf16_1, b1); + c_float_1p2 = _mm512_dpbf16_ps(c_float_1p2, a_bf16_1, b2); + c_float_1p3 = _mm512_dpbf16_ps(c_float_1p3, a_bf16_1, b3); + + // Perform column direction mat-mul with k = 2. + // c[2,0-63] = a[2,kr:kr+2]*b[kr:kr+2,0-63] + c_float_2p0 = _mm512_dpbf16_ps(c_float_2p0, a_bf16_0, b0); + + // Broadcast a[3,kr:kr+2]. + a_bf16_1 = (__m512bh)_mm512_set1_epi32( + *(int32_t *)(a + (rs_a * 3) + (cs_a * kr))); + + c_float_2p1 = _mm512_dpbf16_ps(c_float_2p1, a_bf16_0, b1); + c_float_2p2 = _mm512_dpbf16_ps(c_float_2p2, a_bf16_0, b2); + c_float_2p3 = _mm512_dpbf16_ps(c_float_2p3, a_bf16_0, b3); + + // Perform column direction mat-mul with k = 2. + // c[3,0-63] = a[3,kr:kr+2]*b[kr:kr+2,0-63] + c_float_3p0 = _mm512_dpbf16_ps(c_float_3p0, a_bf16_1, b0); + + // Broadcast a[4,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( + *(int32_t *)(a + (rs_a * 4) + (cs_a * kr))); + + c_float_3p1 = _mm512_dpbf16_ps(c_float_3p1, a_bf16_1, b1); + c_float_3p2 = _mm512_dpbf16_ps(c_float_3p2, a_bf16_1, b2); + c_float_3p3 = _mm512_dpbf16_ps(c_float_3p3, a_bf16_1, b3); + + // Perform column direction mat-mul with k = 2. + // c[4,0-63] = a[4,kr:kr+2]*b[kr:kr+2,0-63] + c_float_4p0 = _mm512_dpbf16_ps(c_float_4p0, a_bf16_0, b0); + + // Broadcast a[5,kr:kr+2]. + a_bf16_1 = (__m512bh)_mm512_set1_epi32( + *(int32_t *)(a + (rs_a * 5) + (cs_a * kr))); + + c_float_4p1 = _mm512_dpbf16_ps(c_float_4p1, a_bf16_0, b1); + c_float_4p2 = _mm512_dpbf16_ps(c_float_4p2, a_bf16_0, b2); + c_float_4p3 = _mm512_dpbf16_ps(c_float_4p3, a_bf16_0, b3); + + // Perform column direction mat-mul with k = 2. + // c[5,0-63] = a[5,kr:kr+2]*b[kr:kr+2,0-63] + c_float_5p0 = _mm512_dpbf16_ps(c_float_5p0, a_bf16_1, b0); + c_float_5p1 = _mm512_dpbf16_ps(c_float_5p1, a_bf16_1, b1); + c_float_5p2 = _mm512_dpbf16_ps(c_float_5p2, a_bf16_1, b2); + c_float_5p3 = _mm512_dpbf16_ps(c_float_5p3, a_bf16_1, b3); + } + + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + // b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2]. + a_kfringe_buf = *( a + (rs_a * 0) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + // b2 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + // b3 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 3 ) ); + + b0_s4 = _mm256_loadu_si256( (__m256i const *)( b + ( rs_b * k_full_pieces ) / 2 ) ); + + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_64, \ + sign_comp, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_16( b0_s8, 0, scale0 ) ); + + b1 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 3, scale3 ), + CVT_INT8_F32_SCAL_16( b0_s8, 2, scale2 ) ); + + b1_s4 = _mm256_loadu_si256( (__m256i const *)( b + ( ( rs_b * k_full_pieces ) / 2 ) + 32 ) ); + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( b1_s4, b1_s8, shift_idx_64, \ + sign_comp, signed_upscale); + + b2 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b1_s8, 1, scale5 ), + CVT_INT8_F32_SCAL_16( b1_s8, 0, scale4 ) ); + + b3 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b1_s8, 3, scale7 ), + CVT_INT8_F32_SCAL_16( b1_s8, 2, scale6 ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-63] = a[0,kr:kr+2]*b[kr:kr+2,0-63] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 1) + (cs_a * ( k_full_pieces ))); + a_bf16_1 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + c_float_0p3 = _mm512_dpbf16_ps( c_float_0p3, a_bf16_0, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-63] = a[1,kr:kr+2]*b[kr:kr+2,0-63] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_1, b0 ); + + // Broadcast a[2,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 2) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_1, b1 ); + c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_1, b2 ); + c_float_1p3 = _mm512_dpbf16_ps( c_float_1p3, a_bf16_1, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-63] = a[2,kr:kr+2]*b[kr:kr+2,0-63] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + + // Broadcast a[3,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 3) + (cs_a * ( k_full_pieces ))); + a_bf16_1 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); + c_float_2p3 = _mm512_dpbf16_ps( c_float_2p3, a_bf16_0, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-63] = a[3,kr:kr+2]*b[kr:kr+2,0-63] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_1, b0 ); + + // Broadcast a[4,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 4) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_1, b1 ); + c_float_3p2 = _mm512_dpbf16_ps( c_float_3p2, a_bf16_1, b2 ); + c_float_3p3 = _mm512_dpbf16_ps( c_float_3p3, a_bf16_1, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[4,0-63] = a[4,kr:kr+2]*b[kr:kr+2,0-63] + c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); + + // Broadcast a[5,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 5) + (cs_a * ( k_full_pieces ))); + a_bf16_1 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + c_float_4p1 = _mm512_dpbf16_ps( c_float_4p1, a_bf16_0, b1 ); + c_float_4p2 = _mm512_dpbf16_ps( c_float_4p2, a_bf16_0, b2 ); + c_float_4p3 = _mm512_dpbf16_ps( c_float_4p3, a_bf16_0, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[5,0-63] = a[5,kr:kr+2]*b[kr:kr+2,0-63] + c_float_5p0 = _mm512_dpbf16_ps( c_float_5p0, a_bf16_1, b0 ); + c_float_5p1 = _mm512_dpbf16_ps( c_float_5p1, a_bf16_1, b1 ); + c_float_5p2 = _mm512_dpbf16_ps( c_float_5p2, a_bf16_1, b2 ); + c_float_5p3 = _mm512_dpbf16_ps( c_float_5p3, a_bf16_1, b3 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps ( alpha ); + __m512 selector2 = _mm512_set1_ps ( beta ); + + if ( alpha != 1 ) + { + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + c_float_0p1 = _mm512_mul_ps( selector1, c_float_0p1 ); + c_float_0p2 = _mm512_mul_ps( selector1, c_float_0p2 ); + c_float_0p3 = _mm512_mul_ps( selector1, c_float_0p3 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + c_float_1p1 = _mm512_mul_ps( selector1, c_float_1p1 ); + c_float_1p2 = _mm512_mul_ps( selector1, c_float_1p2 ); + c_float_1p3 = _mm512_mul_ps( selector1, c_float_1p3 ); + + c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); + c_float_2p1 = _mm512_mul_ps( selector1, c_float_2p1 ); + c_float_2p2 = _mm512_mul_ps( selector1, c_float_2p2 ); + c_float_2p3 = _mm512_mul_ps( selector1, c_float_2p3 ); + + c_float_3p0 = _mm512_mul_ps( selector1, c_float_3p0 ); + c_float_3p1 = _mm512_mul_ps( selector1, c_float_3p1 ); + c_float_3p2 = _mm512_mul_ps( selector1, c_float_3p2 ); + c_float_3p3 = _mm512_mul_ps( selector1, c_float_3p3 ); + + c_float_4p0 = _mm512_mul_ps( selector1, c_float_4p0 ); + c_float_4p1 = _mm512_mul_ps( selector1, c_float_4p1 ); + c_float_4p2 = _mm512_mul_ps( selector1, c_float_4p2 ); + c_float_4p3 = _mm512_mul_ps( selector1, c_float_4p3 ); + + c_float_5p0 = _mm512_mul_ps( selector1, c_float_5p0 ); + c_float_5p1 = _mm512_mul_ps( selector1, c_float_5p1 ); + c_float_5p2 = _mm512_mul_ps( selector1, c_float_5p2 ); + c_float_5p3 = _mm512_mul_ps( selector1, c_float_5p3 ); + + } + + // Scale C by beta. + if ( beta != 0 ) + { + // For the downscaled api (C-bf16), the output C matrix values + // needs to be upscaled to float to be used for beta scale. + if ( ( post_ops_attr.buf_downscale != NULL ) && + ( post_ops_attr.is_first_k == TRUE ) ) + { + // c[0,0-15] + BF16_F32_BETA_OP(c_float_0p0,ir,0,0,selector1,selector2) + + // c[0, 16-31] + BF16_F32_BETA_OP(c_float_0p1,ir,0,1,selector1,selector2) + + // c[0,32-47] + BF16_F32_BETA_OP(c_float_0p2,ir,0,2,selector1,selector2) + + // c[0,48-63] + BF16_F32_BETA_OP(c_float_0p3,ir,0,3,selector1,selector2) + + // c[1,0-15] + BF16_F32_BETA_OP(c_float_1p0,ir,1,0,selector1,selector2) + + // c[1,16-31] + BF16_F32_BETA_OP(c_float_1p1,ir,1,1,selector1,selector2) + + // c[1,32-47] + BF16_F32_BETA_OP(c_float_1p2,ir,1,2,selector1,selector2) + + // c[1,48-63] + BF16_F32_BETA_OP(c_float_1p3,ir,1,3,selector1,selector2) + + // c[2,0-15] + BF16_F32_BETA_OP(c_float_2p0,ir,2,0,selector1,selector2) + + // c[2,16-31] + BF16_F32_BETA_OP(c_float_2p1,ir,2,1,selector1,selector2) + + // c[2,32-47] + BF16_F32_BETA_OP(c_float_2p2,ir,2,2,selector1,selector2) + + // c[2,48-63] + BF16_F32_BETA_OP(c_float_2p3,ir,2,3,selector1,selector2) + + // c[3,0-15] + BF16_F32_BETA_OP(c_float_3p0,ir,3,0,selector1,selector2) + + // c[3,16-31] + BF16_F32_BETA_OP(c_float_3p1,ir,3,1,selector1,selector2) + + // c[3,32-47] + BF16_F32_BETA_OP(c_float_3p2,ir,3,2,selector1,selector2) + + // c[0,48-63] + BF16_F32_BETA_OP(c_float_3p3,ir,3,3,selector1,selector2) + + // c[4,0-15] + BF16_F32_BETA_OP(c_float_4p0,ir,4,0,selector1,selector2) + + // c[4,16-31] + BF16_F32_BETA_OP(c_float_4p1,ir,4,1,selector1,selector2) + + // c[4,32-47] + BF16_F32_BETA_OP(c_float_4p2,ir,4,2,selector1,selector2) + + // c[4,48-63] + BF16_F32_BETA_OP(c_float_4p3,ir,4,3,selector1,selector2) + + // c[5,0-15] + BF16_F32_BETA_OP(c_float_5p0,ir,5,0,selector1,selector2) + + // c[5,16-31] + BF16_F32_BETA_OP(c_float_5p1,ir,5,1,selector1,selector2) + + // c[5,32-47] + BF16_F32_BETA_OP(c_float_5p2,ir,5,2,selector1,selector2) + + // c[5,48-63] + BF16_F32_BETA_OP(c_float_5p3,ir,5,3,selector1,selector2) + } + else + { + // c[0,0-15] + F32_F32_BETA_OP(c_float_0p0,ir,0,0,selector1,selector2) + + // c[0, 16-31] + F32_F32_BETA_OP(c_float_0p1,ir,0,1,selector1,selector2) + + // c[0,32-47] + F32_F32_BETA_OP(c_float_0p2,ir,0,2,selector1,selector2) + + // c[0,48-63] + F32_F32_BETA_OP(c_float_0p3,ir,0,3,selector1,selector2) + + // c[1,0-15] + F32_F32_BETA_OP(c_float_1p0,ir,1,0,selector1,selector2) + + // c[1,16-31] + F32_F32_BETA_OP(c_float_1p1,ir,1,1,selector1,selector2) + + // c[1,32-47] + F32_F32_BETA_OP(c_float_1p2,ir,1,2,selector1,selector2) + + // c[1,48-63] + F32_F32_BETA_OP(c_float_1p3,ir,1,3,selector1,selector2) + + // c[2,0-15] + F32_F32_BETA_OP(c_float_2p0,ir,2,0,selector1,selector2) + + // c[2,16-31] + F32_F32_BETA_OP(c_float_2p1,ir,2,1,selector1,selector2) + + // c[2,32-47] + F32_F32_BETA_OP(c_float_2p2,ir,2,2,selector1,selector2) + + // c[2,48-63] + F32_F32_BETA_OP(c_float_2p3,ir,2,3,selector1,selector2) + + // c[3,0-15] + F32_F32_BETA_OP(c_float_3p0,ir,3,0,selector1,selector2) + + // c[3,16-31] + F32_F32_BETA_OP(c_float_3p1,ir,3,1,selector1,selector2) + + // c[3,32-47] + F32_F32_BETA_OP(c_float_3p2,ir,3,2,selector1,selector2) + + // c[0,48-63] + F32_F32_BETA_OP(c_float_3p3,ir,3,3,selector1,selector2) + + // c[4,0-15] + F32_F32_BETA_OP(c_float_4p0,ir,4,0,selector1,selector2) + + // c[4,16-31] + F32_F32_BETA_OP(c_float_4p1,ir,4,1,selector1,selector2) + + // c[4,32-47] + F32_F32_BETA_OP(c_float_4p2,ir,4,2,selector1,selector2) + + // c[4,48-63] + F32_F32_BETA_OP(c_float_4p3,ir,4,3,selector1,selector2) + + // c[5,0-15] + F32_F32_BETA_OP(c_float_5p0,ir,5,0,selector1,selector2) + + // c[5,16-31] + F32_F32_BETA_OP(c_float_5p1,ir,5,1,selector1,selector2) + + // c[5,32-47] + F32_F32_BETA_OP(c_float_5p2,ir,5,2,selector1,selector2) + + // c[5,48-63] + F32_F32_BETA_OP(c_float_5p3,ir,5,3,selector1,selector2) + + } + + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_6x64: + { + __m512 selector3; + __m512 selector4; + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_LOAD(selector1, bias_mask, 0); + BF16_F32_BIAS_LOAD(selector2, bias_mask, 1); + BF16_F32_BIAS_LOAD(selector3, bias_mask, 2); + BF16_F32_BIAS_LOAD(selector4, bias_mask, 3); + } + else + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + selector4 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_add_ps( selector4, c_float_0p3 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector3, c_float_1p2 ); + + // c[1,48-63] + c_float_1p3 = _mm512_add_ps( selector4, c_float_1p3 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + + // c[2,48-63] + c_float_2p3 = _mm512_add_ps( selector4, c_float_2p3 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector2, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_add_ps( selector3, c_float_3p2 ); + + // c[3,48-63] + c_float_3p3 = _mm512_add_ps( selector4, c_float_3p3 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + + // c[4, 16-31] + c_float_4p1 = _mm512_add_ps( selector2, c_float_4p1 ); + + // c[4,32-47] + c_float_4p2 = _mm512_add_ps( selector3, c_float_4p2 ); + + // c[4,48-63] + c_float_4p3 = _mm512_add_ps( selector4, c_float_4p3 ); + + // c[5,0-15] + c_float_5p0 = _mm512_add_ps( selector1, c_float_5p0 ); + + // c[5, 16-31] + c_float_5p1 = _mm512_add_ps( selector2, c_float_5p1 ); + + // c[5,32-47] + c_float_5p2 = _mm512_add_ps( selector3, c_float_5p2 ); + + // c[5,48-63] + c_float_5p3 = _mm512_add_ps( selector4, c_float_5p3 ); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the bias array will be accessed by + // the ic index, and each bias element corresponds to an + // entire row of the transposed output array, instead of an + // entire column. + __m512 selector5; + __m512 selector6; + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + BF16_F32_BIAS_BCAST(selector2, bias_mask, 1); + BF16_F32_BIAS_BCAST(selector3, bias_mask, 2); + BF16_F32_BIAS_BCAST(selector4, bias_mask, 3); + BF16_F32_BIAS_BCAST(selector5, bias_mask, 4); + BF16_F32_BIAS_BCAST(selector6, bias_mask, 5); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 3 ) ); + selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 4 ) ); + selector6 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 5 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_add_ps( selector1, c_float_0p3 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector2, c_float_1p2 ); + + // c[1,48-63] + c_float_1p3 = _mm512_add_ps( selector2, c_float_1p3 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector3, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + + // c[2,48-63] + c_float_2p3 = _mm512_add_ps( selector3, c_float_2p3 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector4, c_float_3p0 ); + + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector4, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_add_ps( selector4, c_float_3p2 ); + + // c[3,48-63] + c_float_3p3 = _mm512_add_ps( selector4, c_float_3p3 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector5, c_float_4p0 ); + + // c[4, 16-31] + c_float_4p1 = _mm512_add_ps( selector5, c_float_4p1 ); + + // c[4,32-47] + c_float_4p2 = _mm512_add_ps( selector5, c_float_4p2 ); + + // c[4,48-63] + c_float_4p3 = _mm512_add_ps( selector5, c_float_4p3 ); + + // c[5,0-15] + c_float_5p0 = _mm512_add_ps( selector6, c_float_5p0 ); + + // c[5, 16-31] + c_float_5p1 = _mm512_add_ps( selector6, c_float_5p1 ); + + // c[5,32-47] + c_float_5p2 = _mm512_add_ps( selector6, c_float_5p2 ); + + // c[5,48-63] + c_float_5p3 = _mm512_add_ps( selector6, c_float_5p3 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_6x64: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_max_ps( selector1, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_max_ps( selector1, c_float_0p3 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + c_float_1p1 = _mm512_max_ps( selector1, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_max_ps( selector1, c_float_1p2 ); + + // c[1,48-63] + c_float_1p3 = _mm512_max_ps( selector1, c_float_1p3 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + c_float_2p1 = _mm512_max_ps( selector1, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_max_ps( selector1, c_float_2p2 ); + + // c[2,48-63] + c_float_2p3 = _mm512_max_ps( selector1, c_float_2p3 ); + + // c[3,0-15] + c_float_3p0 = _mm512_max_ps( selector1, c_float_3p0 ); + + // c[3,16-31] + c_float_3p1 = _mm512_max_ps( selector1, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_max_ps( selector1, c_float_3p2 ); + + // c[3,48-63] + c_float_3p3 = _mm512_max_ps( selector1, c_float_3p3 ); + + // c[4,0-15] + c_float_4p0 = _mm512_max_ps( selector1, c_float_4p0 ); + + // c[4,16-31] + c_float_4p1 = _mm512_max_ps( selector1, c_float_4p1 ); + + // c[4,32-47] + c_float_4p2 = _mm512_max_ps( selector1, c_float_4p2 ); + + // c[4,48-63] + c_float_4p3 = _mm512_max_ps( selector1, c_float_4p3 ); + + // c[5,0-15] + c_float_5p0 = _mm512_max_ps( selector1, c_float_5p0 ); + + // c[5,16-31] + c_float_5p1 = _mm512_max_ps( selector1, c_float_5p1 ); + + // c[5,32-47] + c_float_5p2 = _mm512_max_ps( selector1, c_float_5p2 ); + + // c[5,48-63] + c_float_5p3 = _mm512_max_ps( selector1, c_float_5p3 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_6x64: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_0p2) + + // c[0, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_0p3) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_1p1) + + // c[1, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_1p2) + + // c[1, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_1p3) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_2p1) + + // c[2, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_2p2) + + // c[2, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_2p3) + + // c[3, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_3p0) + + // c[3, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_3p1) + + // c[3, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_3p2) + + // c[3, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_3p3) + + // c[4, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_4p0) + + // c[4, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_4p1) + + // c[4, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_4p2) + + // c[4, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_4p3) + + // c[5, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_5p0) + + // c[5, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_5p1) + + // c[5, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_5p2) + + // c[5, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_5p3) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_TANH_6x64: + { + __m512 dn, z, x, r2, r, x_tanh; + __m512i q; + + // c[0, 0-15] + GELU_TANH_F32_AVX512(c_float_0p0, r, r2, x, z, dn, x_tanh, q) + + // c[0, 16-31] + GELU_TANH_F32_AVX512(c_float_0p1, r, r2, x, z, dn, x_tanh, q) + + // c[0, 32-47] + GELU_TANH_F32_AVX512(c_float_0p2, r, r2, x, z, dn, x_tanh, q) + + // c[0, 48-63] + GELU_TANH_F32_AVX512(c_float_0p3, r, r2, x, z, dn, x_tanh, q) + + // c[1, 0-15] + GELU_TANH_F32_AVX512(c_float_1p0, r, r2, x, z, dn, x_tanh, q) + + // c[1, 16-31] + GELU_TANH_F32_AVX512(c_float_1p1, r, r2, x, z, dn, x_tanh, q) + + // c[1, 32-47] + GELU_TANH_F32_AVX512(c_float_1p2, r, r2, x, z, dn, x_tanh, q) + + // c[1, 48-63] + GELU_TANH_F32_AVX512(c_float_1p3, r, r2, x, z, dn, x_tanh, q) + + // c[2, 0-15] + GELU_TANH_F32_AVX512(c_float_2p0, r, r2, x, z, dn, x_tanh, q) + + // c[2, 16-31] + GELU_TANH_F32_AVX512(c_float_2p1, r, r2, x, z, dn, x_tanh, q) + + // c[2, 32-47] + GELU_TANH_F32_AVX512(c_float_2p2, r, r2, x, z, dn, x_tanh, q) + + // c[2, 48-63] + GELU_TANH_F32_AVX512(c_float_2p3, r, r2, x, z, dn, x_tanh, q) + + // c[3, 0-15] + GELU_TANH_F32_AVX512(c_float_3p0, r, r2, x, z, dn, x_tanh, q) + + // c[3, 16-31] + GELU_TANH_F32_AVX512(c_float_3p1, r, r2, x, z, dn, x_tanh, q) + + // c[3, 32-47] + GELU_TANH_F32_AVX512(c_float_3p2, r, r2, x, z, dn, x_tanh, q) + + // c[3, 48-63] + GELU_TANH_F32_AVX512(c_float_3p3, r, r2, x, z, dn, x_tanh, q) + + // c[4, 0-15] + GELU_TANH_F32_AVX512(c_float_4p0, r, r2, x, z, dn, x_tanh, q) + + // c[4, 16-31] + GELU_TANH_F32_AVX512(c_float_4p1, r, r2, x, z, dn, x_tanh, q) + + // c[4, 32-47] + GELU_TANH_F32_AVX512(c_float_4p2, r, r2, x, z, dn, x_tanh, q) + + // c[4, 48-63] + GELU_TANH_F32_AVX512(c_float_4p3, r, r2, x, z, dn, x_tanh, q) + + // c[5, 0-15] + GELU_TANH_F32_AVX512(c_float_5p0, r, r2, x, z, dn, x_tanh, q) + + // c[5, 16-31] + GELU_TANH_F32_AVX512(c_float_5p1, r, r2, x, z, dn, x_tanh, q) + + // c[5, 32-47] + GELU_TANH_F32_AVX512(c_float_5p2, r, r2, x, z, dn, x_tanh, q) + + // c[5, 48-63] + GELU_TANH_F32_AVX512(c_float_5p3, r, r2, x, z, dn, x_tanh, q) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_ERF_6x64: + { + __m512 x, r, x_erf; + + // c[0, 0-15] + GELU_ERF_F32_AVX512(c_float_0p0, r, x, x_erf) + + // c[0, 16-31] + GELU_ERF_F32_AVX512(c_float_0p1, r, x, x_erf) + + // c[0, 32-47] + GELU_ERF_F32_AVX512(c_float_0p2, r, x, x_erf) + + // c[0, 48-63] + GELU_ERF_F32_AVX512(c_float_0p3, r, x, x_erf) + + // c[1, 0-15] + GELU_ERF_F32_AVX512(c_float_1p0, r, x, x_erf) + + // c[1, 16-31] + GELU_ERF_F32_AVX512(c_float_1p1, r, x, x_erf) + + // c[1, 32-47] + GELU_ERF_F32_AVX512(c_float_1p2, r, x, x_erf) + + // c[1, 48-63] + GELU_ERF_F32_AVX512(c_float_1p3, r, x, x_erf) + + // c[2, 0-15] + GELU_ERF_F32_AVX512(c_float_2p0, r, x, x_erf) + + // c[2, 16-31] + GELU_ERF_F32_AVX512(c_float_2p1, r, x, x_erf) + + // c[2, 32-47] + GELU_ERF_F32_AVX512(c_float_2p2, r, x, x_erf) + + // c[2, 48-63] + GELU_ERF_F32_AVX512(c_float_2p3, r, x, x_erf) + + // c[3, 0-15] + GELU_ERF_F32_AVX512(c_float_3p0, r, x, x_erf) + + // c[3, 16-31] + GELU_ERF_F32_AVX512(c_float_3p1, r, x, x_erf) + + // c[3, 32-47] + GELU_ERF_F32_AVX512(c_float_3p2, r, x, x_erf) + + // c[3, 48-63] + GELU_ERF_F32_AVX512(c_float_3p3, r, x, x_erf) + + // c[4, 0-15] + GELU_ERF_F32_AVX512(c_float_4p0, r, x, x_erf) + + // c[4, 16-31] + GELU_ERF_F32_AVX512(c_float_4p1, r, x, x_erf) + + // c[4, 32-47] + GELU_ERF_F32_AVX512(c_float_4p2, r, x, x_erf) + + // c[4, 48-63] + GELU_ERF_F32_AVX512(c_float_4p3, r, x, x_erf) + + // c[5, 0-15] + GELU_ERF_F32_AVX512(c_float_5p0, r, x, x_erf) + + // c[5, 16-31] + GELU_ERF_F32_AVX512(c_float_5p1, r, x, x_erf) + + // c[5, 32-47] + GELU_ERF_F32_AVX512(c_float_5p2, r, x, x_erf) + + // c[5, 48-63] + GELU_ERF_F32_AVX512(c_float_5p3, r, x, x_erf) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + +POST_OPS_CLIP_6x64: + { + __m512 min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 ); + __m512 max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 ); + + // c[0, 0-15] + CLIP_F32_AVX512(c_float_0p0, min, max) + + // c[0, 16-31] + CLIP_F32_AVX512(c_float_0p1, min, max) + + // c[0, 32-47] + CLIP_F32_AVX512(c_float_0p2, min, max) + + // c[0, 48-63] + CLIP_F32_AVX512(c_float_0p3, min, max) + + // c[1, 0-15] + CLIP_F32_AVX512(c_float_1p0, min, max) + + // c[1, 16-31] + CLIP_F32_AVX512(c_float_1p1, min, max) + + // c[1, 32-47] + CLIP_F32_AVX512(c_float_1p2, min, max) + + // c[1, 48-63] + CLIP_F32_AVX512(c_float_1p3, min, max) + + // c[2, 0-15] + CLIP_F32_AVX512(c_float_2p0, min, max) + + // c[2, 16-31] + CLIP_F32_AVX512(c_float_2p1, min, max) + + // c[2, 32-47] + CLIP_F32_AVX512(c_float_2p2, min, max) + + // c[2, 48-63] + CLIP_F32_AVX512(c_float_2p3, min, max) + + // c[3, 0-15] + CLIP_F32_AVX512(c_float_3p0, min, max) + + // c[3, 16-31] + CLIP_F32_AVX512(c_float_3p1, min, max) + + // c[3, 32-47] + CLIP_F32_AVX512(c_float_3p2, min, max) + + // c[3, 48-63] + CLIP_F32_AVX512(c_float_3p3, min, max) + + // c[4, 0-15] + CLIP_F32_AVX512(c_float_4p0, min, max) + + // c[4, 16-31] + CLIP_F32_AVX512(c_float_4p1, min, max) + + // c[4, 32-47] + CLIP_F32_AVX512(c_float_4p2, min, max) + + // c[4, 48-63] + CLIP_F32_AVX512(c_float_4p3, min, max) + + // c[5, 0-15] + CLIP_F32_AVX512(c_float_5p0, min, max) + + // c[5, 16-31] + CLIP_F32_AVX512(c_float_5p1, min, max) + + // c[5, 32-47] + CLIP_F32_AVX512(c_float_5p2, min, max) + + // c[5, 48-63] + CLIP_F32_AVX512(c_float_5p3, min, max) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_6x64: + { + __m512 selector3 = _mm512_setzero_ps(); + __m512 selector4 = _mm512_setzero_ps(); + + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + __m512 zero_point2 = _mm512_setzero_ps(); + __m512 zero_point3 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + selector4 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector2,zero_point1); + + // c[0, 32-47] + SCL_MULRND_F32(c_float_0p2,selector3,zero_point2); + + // c[0, 48-63] + SCL_MULRND_F32(c_float_0p3,selector4,zero_point3); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector1,zero_point0); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[1, 32-47] + SCL_MULRND_F32(c_float_1p2,selector3,zero_point2); + + // c[1, 48-63] + SCL_MULRND_F32(c_float_1p3,selector4,zero_point3); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector1,zero_point0); + + // c[2, 16-31] + SCL_MULRND_F32(c_float_2p1,selector2,zero_point1); + + // c[2, 32-47] + SCL_MULRND_F32(c_float_2p2,selector3,zero_point2); + + // c[2, 48-63] + SCL_MULRND_F32(c_float_2p3,selector4,zero_point3); + + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector1,zero_point0); + + // c[3, 16-31] + SCL_MULRND_F32(c_float_3p1,selector2,zero_point1); + + // c[3, 32-47] + SCL_MULRND_F32(c_float_3p2,selector3,zero_point2); + + // c[3, 48-63] + SCL_MULRND_F32(c_float_3p3,selector4,zero_point3); + + // c[4, 0-15] + SCL_MULRND_F32(c_float_4p0,selector1,zero_point0); + + // c[4, 16-31] + SCL_MULRND_F32(c_float_4p1,selector2,zero_point1); + + // c[4, 32-47] + SCL_MULRND_F32(c_float_4p2,selector3,zero_point2); + + // c[4, 48-63] + SCL_MULRND_F32(c_float_4p3,selector4,zero_point3); + + // c[5, 0-15] + SCL_MULRND_F32(c_float_5p0,selector1,zero_point0); + + // c[5, 16-31] + SCL_MULRND_F32(c_float_5p1,selector2,zero_point1); + + // c[5, 32-47] + SCL_MULRND_F32(c_float_5p2,selector3,zero_point2); + + // c[5, 48-63] + SCL_MULRND_F32(c_float_5p3,selector4,zero_point3); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 3 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 2 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 3 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector1,zero_point0); + + // c[0, 32-47] + SCL_MULRND_F32(c_float_0p2,selector1,zero_point0); + + // c[0, 48-63] + SCL_MULRND_F32(c_float_0p3,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector2,zero_point1); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[1, 32-47] + SCL_MULRND_F32(c_float_1p2,selector2,zero_point1); + + // c[1, 48-63] + SCL_MULRND_F32(c_float_1p3,selector2,zero_point1); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector3,zero_point2); + + // c[2, 16-31] + SCL_MULRND_F32(c_float_2p1,selector3,zero_point2); + + // c[2, 32-47] + SCL_MULRND_F32(c_float_2p2,selector3,zero_point2); + + // c[2, 48-63] + SCL_MULRND_F32(c_float_2p3,selector3,zero_point2); + + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector4,zero_point3); + + // c[3, 16-31] + SCL_MULRND_F32(c_float_3p1,selector4,zero_point3); + + // c[3, 32-47] + SCL_MULRND_F32(c_float_3p2,selector4,zero_point3); + + // c[3, 48-63] + SCL_MULRND_F32(c_float_3p3,selector4,zero_point3); + + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 4 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 5 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 4 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 5 ) ) ); + } + // c[4, 0-15] + SCL_MULRND_F32(c_float_4p0,selector1,zero_point0); + + // c[4, 16-31] + SCL_MULRND_F32(c_float_4p1,selector1,zero_point0); + + // c[4, 32-47] + SCL_MULRND_F32(c_float_4p2,selector1,zero_point0); + + // c[4, 48-63] + SCL_MULRND_F32(c_float_4p3,selector1,zero_point0); + + // c[5, 0-15] + SCL_MULRND_F32(c_float_5p0,selector2,zero_point1); + + // c[5, 16-31] + SCL_MULRND_F32(c_float_5p1,selector2,zero_point1); + + // c[5, 32-47] + SCL_MULRND_F32(c_float_5p2,selector2,zero_point1); + + // c[5, 48-63] + SCL_MULRND_F32(c_float_5p3,selector2,zero_point1); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_6x64: + { + __m512 selector3; + __m512 selector4; + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + // It is expected the post-op matrix arg has the same storage + // order as the output C matrix. + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,0); + + // c[1:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,1); + + // c[2:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,2); + + // c[3:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,3); + + // c[4:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,4); + + // c[5:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,5); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,0); + + // c[1:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,1); + + // c[2:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,2); + + // c[3:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,3); + + // c[4:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,4); + + // c[5:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,5); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_6x64: + { + __m512 selector3; + __m512 selector4; + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + // It is expected the post-op matrix arg has the same storage + // order as the output C matrix. + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,0); + + // c[1:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,1); + + // c[2:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,2); + + // c[3:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,3); + + // c[4:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,4); + + // c[5:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,5); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,0); + + // c[1:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,1); + + // c[2:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,2); + + // c[3:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,3); + + // c[4:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,4); + + // c[5:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,5); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_6x64: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[0, 16-31] + SWISH_F32_AVX512_DEF(c_float_0p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[0, 32-47] + SWISH_F32_AVX512_DEF(c_float_0p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[0, 48-63] + SWISH_F32_AVX512_DEF(c_float_0p3, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 0-15] + SWISH_F32_AVX512_DEF(c_float_1p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 16-31] + SWISH_F32_AVX512_DEF(c_float_1p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 32-47] + SWISH_F32_AVX512_DEF(c_float_1p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 48-63] + SWISH_F32_AVX512_DEF(c_float_1p3, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 0-15] + SWISH_F32_AVX512_DEF(c_float_2p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 16-31] + SWISH_F32_AVX512_DEF(c_float_2p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 32-47] + SWISH_F32_AVX512_DEF(c_float_2p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 48-63] + SWISH_F32_AVX512_DEF(c_float_2p3, selector1, al_in, r, r2, z, dn, ex_out); + + // c[3, 0-15] + SWISH_F32_AVX512_DEF(c_float_3p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[3, 16-31] + SWISH_F32_AVX512_DEF(c_float_3p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[3, 32-47] + SWISH_F32_AVX512_DEF(c_float_3p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[3, 48-63] + SWISH_F32_AVX512_DEF(c_float_3p3, selector1, al_in, r, r2, z, dn, ex_out); + + // c[4, 0-15] + SWISH_F32_AVX512_DEF(c_float_4p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[4, 16-31] + SWISH_F32_AVX512_DEF(c_float_4p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[4, 32-47] + SWISH_F32_AVX512_DEF(c_float_4p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[4, 48-63] + SWISH_F32_AVX512_DEF(c_float_4p3, selector1, al_in, r, r2, z, dn, ex_out); + + // c[5, 0-15] + SWISH_F32_AVX512_DEF(c_float_5p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[5, 16-31] + SWISH_F32_AVX512_DEF(c_float_5p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[5, 32-47] + SWISH_F32_AVX512_DEF(c_float_5p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[5, 48-63] + SWISH_F32_AVX512_DEF(c_float_5p3, selector1, al_in, r, r2, z, dn, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_6x64_DISABLE: + ; + + // Case where the output C matrix is bf16 (downscaled) and this is the + // final write for a given block within C. + if ( ( post_ops_attr.buf_downscale != NULL ) && + ( post_ops_attr.is_last_k == TRUE ) ) + { + // Generate a mask16 of all 1's. + __m512i selector_a = _mm512_setzero_epi32(); + __m512i selector_b = _mm512_set1_epi32( 10 ); + __mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b ); + + // Store the results in downscaled type (bf16 instead of float). + + // c[0, 0-15] + CVT_STORE_F32_BF16_MASK(c_float_0p0,0,0); + + // c[0, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_0p1,0,1); + + // c[0, 32-47] + CVT_STORE_F32_BF16_MASK(c_float_0p2,0,2); + + // c[0, 48-63] + CVT_STORE_F32_BF16_MASK(c_float_0p3,0,3); + + // c[1, 0-15] + CVT_STORE_F32_BF16_MASK(c_float_1p0,1,0); + + // c[1, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_1p1,1,1); + + // c[1, 32-47] + CVT_STORE_F32_BF16_MASK(c_float_1p2,1,2); + + // c[1, 48-63] + CVT_STORE_F32_BF16_MASK(c_float_1p3,1,3); + + // c[2, 0-15] + CVT_STORE_F32_BF16_MASK(c_float_2p0,2,0); + + // c[2, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_2p1,2,1); + + // c[2, 32-47] + CVT_STORE_F32_BF16_MASK(c_float_2p2,2,2); + + // c[2, 48-63] + CVT_STORE_F32_BF16_MASK(c_float_2p3,2,3); + + // c[3, 0-15] + CVT_STORE_F32_BF16_MASK(c_float_3p0,3,0); + + // c[3, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_3p1,3,1); + + // c[3, 32-47] + CVT_STORE_F32_BF16_MASK(c_float_3p2,3,2); + + // c[3, 48-63] + CVT_STORE_F32_BF16_MASK(c_float_3p3,3,3); + + // c[4, 0-15] + CVT_STORE_F32_BF16_MASK(c_float_4p0,4,0); + + // c[4, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_4p1,4,1); + + // c[4, 32-47] + CVT_STORE_F32_BF16_MASK(c_float_4p2,4,2); + + // c[4, 48-63] + CVT_STORE_F32_BF16_MASK(c_float_4p3,4,3); + + // c[5, 0-15] + CVT_STORE_F32_BF16_MASK(c_float_5p0,5,0); + + // c[5, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_5p1,5,1); + + // c[5, 32-47] + CVT_STORE_F32_BF16_MASK(c_float_5p2,5,2); + + // c[5, 48-63] + CVT_STORE_F32_BF16_MASK(c_float_5p3,5,3); + + } + + // Case where the output C matrix is float + else + { + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 0*16 ), c_float_0p0 ); + + // c[0, 16-31] + _mm512_storeu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 1*16 ), c_float_0p1 ); + + // c[0,32-47] + _mm512_storeu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 2*16 ), c_float_0p2 ); + + // c[0,48-63] + _mm512_storeu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 3*16 ), c_float_0p3 ); + + // c[1,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 0*16 ), c_float_1p0 ); + + // c[1,16-31] + _mm512_storeu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 1*16 ), c_float_1p1 ); + + // c[1,32-47] + _mm512_storeu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 2*16 ), c_float_1p2 ); + + // c[1,48-63] + _mm512_storeu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 3*16 ), c_float_1p3 ); + + // c[2,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 0*16 ), c_float_2p0 ); + + // c[2,16-31] + _mm512_storeu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 1*16 ), c_float_2p1 ); + + // c[2,32-47] + _mm512_storeu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 2*16 ), c_float_2p2 ); + + // c[2,48-63] + _mm512_storeu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 3*16 ), c_float_2p3 ); + + // c[3,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 0*16 ), c_float_3p0 ); + + // c[3,16-31] + _mm512_storeu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 1*16 ), c_float_3p1 ); + + // c[3,32-47] + _mm512_storeu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 2*16 ), c_float_3p2 ); + + // c[3,48-63] + _mm512_storeu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 3*16 ), c_float_3p3 ); + + // c[4,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 0*16 ), c_float_4p0 ); + + // c[4,16-31] + _mm512_storeu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 1*16 ), c_float_4p1 ); + + // c[4,32-47] + _mm512_storeu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 2*16 ), c_float_4p2 ); + + // c[4,48-63] + _mm512_storeu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 3*16 ), c_float_4p3 ); + + // c[5,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 0*16 ), c_float_5p0 ); + + // c[5,16-31] + _mm512_storeu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 1*16 ), c_float_5p1 ); + + // c[5,32-47] + _mm512_storeu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 2*16 ), c_float_5p2 ); + + // c[5,48-63] + _mm512_storeu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 3*16 ), c_float_5p3 ); + + } + + a = a + ( MR * ps_a ); + post_ops_attr.post_op_c_i += MR; + } + + if ( m_partial_pieces > 0 ) + { + if ( m_partial_pieces == 5 ) + { + // In cases where A matrix is packed cs_a is set to 12, since the + // next column in a given row is accessed after 2*6 elements, where + // 6 is MR and 2 elements are broadcasted each time from A (bf16). + // In fringe case, where m < MR, the next column will be after m'*2 + // elements, and subsequently following adjustment of cs_a is + // required before calling m fringe kernels. + dim_t cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 5 ); + lpgemm_rowvar_bf16s4f32of32_5x64 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + post_ops_list, post_ops_attr + ); + } + else if ( m_partial_pieces == 4 ) + { + dim_t cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 4 ); + lpgemm_rowvar_bf16s4f32of32_4x64 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + post_ops_list, post_ops_attr + ); + } + else if ( m_partial_pieces == 3 ) + { + dim_t cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 3 ); + lpgemm_rowvar_bf16s4f32of32_3x64 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + post_ops_list, post_ops_attr + ); + } + else if ( m_partial_pieces == 2 ) + { + dim_t cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 2 ); + lpgemm_rowvar_bf16s4f32of32_2x64 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + post_ops_list, post_ops_attr + ); + } + else if ( m_partial_pieces == 1 ) + { + dim_t cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 1 ); + lpgemm_rowvar_bf16s4f32of32_1x64 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + post_ops_list, post_ops_attr + ); + } + } +} + +#endif //LPGEMM_BF16_JIT +#endif diff --git a/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_eltwise_ops_6x64rowmajor_bf16_amd512vnni.c b/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_eltwise_ops_6x64rowmajor_bf16_amd512vnni.c new file mode 100644 index 0000000000..b25546c7e2 --- /dev/null +++ b/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_eltwise_ops_6x64rowmajor_bf16_amd512vnni.c @@ -0,0 +1,1537 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "blis.h" + +#ifdef BLIS_ADDON_LPGEMM + +#include "lpgemm_f32_kern_macros.h" + +#ifdef LPGEMM_BF16_JIT + +LPGEMM_ELTWISE_OPS_KERNEL(bfloat16,float,bf16of32_6x64) +{ + // Not supported! +} + +#else + +LPGEMM_ELTWISE_OPS_KERNEL(bfloat16,float,bf16of32_6x64) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_6x64_OPS_DISABLE, + &&POST_OPS_BIAS_6x64_OPS, + &&POST_OPS_RELU_6x64_OPS, + &&POST_OPS_RELU_SCALE_6x64_OPS, + &&POST_OPS_GELU_TANH_6x64_OPS, + &&POST_OPS_GELU_ERF_6x64_OPS, + &&POST_OPS_CLIP_6x64_OPS, + &&POST_OPS_DOWNSCALE_6x64_OPS, + &&POST_OPS_MATRIX_ADD_6x64_OPS, + &&POST_OPS_SWISH_6x64_OPS, + &&POST_OPS_MATRIX_MUL_6x64_OPS + }; + dim_t MR = 6; + dim_t NR = 64; + + dim_t m_full_pieces = m0 / MR; + dim_t m_full_pieces_loop_limit = m_full_pieces * MR; + dim_t m_partial_pieces = m0 % MR; + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + __m512 c_float_0p1 = _mm512_setzero_ps(); + __m512 c_float_0p2 = _mm512_setzero_ps(); + __m512 c_float_0p3 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + __m512 c_float_1p1 = _mm512_setzero_ps(); + __m512 c_float_1p2 = _mm512_setzero_ps(); + __m512 c_float_1p3 = _mm512_setzero_ps(); + + __m512 c_float_2p0 = _mm512_setzero_ps(); + __m512 c_float_2p1 = _mm512_setzero_ps(); + __m512 c_float_2p2 = _mm512_setzero_ps(); + __m512 c_float_2p3 = _mm512_setzero_ps(); + + __m512 c_float_3p0 = _mm512_setzero_ps(); + __m512 c_float_3p1 = _mm512_setzero_ps(); + __m512 c_float_3p2 = _mm512_setzero_ps(); + __m512 c_float_3p3 = _mm512_setzero_ps(); + + __m512 c_float_4p0 = _mm512_setzero_ps(); + __m512 c_float_4p1 = _mm512_setzero_ps(); + __m512 c_float_4p2 = _mm512_setzero_ps(); + __m512 c_float_4p3 = _mm512_setzero_ps(); + + __m512 c_float_5p0 = _mm512_setzero_ps(); + __m512 c_float_5p1 = _mm512_setzero_ps(); + __m512 c_float_5p2 = _mm512_setzero_ps(); + __m512 c_float_5p3 = _mm512_setzero_ps(); + + __m512 selector1 = _mm512_setzero_ps(); + __m512 selector2 = _mm512_setzero_ps(); + __m512 selector3 = _mm512_setzero_ps(); + __m512 selector4 = _mm512_setzero_ps(); + + uint64_t orig_post_op_c_j = post_ops_attr.post_op_c_j; + for ( dim_t ir = 0; ir < m_full_pieces_loop_limit; ir += MR ) + { + __mmask16 k0 = 0xFFFF, k1 = 0xFFFF, k2 = 0xFFFF, k3 = 0xFFFF; + + dim_t NR_L = NR; + for( dim_t jr = 0; jr < n0; jr += NR_L ) + { + dim_t n_left = n0 - jr; + NR_L = bli_min( NR_L, ( n_left >> 4 ) << 4 ); + if( NR_L == 0 ) { NR_L = 16; } + + dim_t nr0 = bli_min( n0 - jr, NR_L ); + if( nr0 == 64 ) + { + // all masks are already set. + // Nothing to modify. + } + else if( nr0 == 48 ) + { + k3 = 0x0; + } + else if( nr0 == 32 ) + { + k2 = k3 = 0x0; + } + else if( nr0 == 16 ) + { + k1 = k2 = k3 = 0; + } + else if( nr0 < 16 ) + { + k0 = (0xFFFF >> (16 - (nr0 & 0x0F))); + k1 = k2 = k3 = 0; + } + + // 1stx64 block. + c_float_0p0 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k0, \ + a + ( rs_a * ( ir + 0 ) ) + ( cs_a * ( jr + 0 ) ) ) ); + c_float_0p1 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k1, \ + a + ( rs_a * ( ir + 0 ) ) + ( cs_a * ( jr + 16 ) ) ) ); + c_float_0p2 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k2, \ + a + ( rs_a * ( ir + 0 ) ) + ( cs_a * ( jr + 32 ) ) ) ); + c_float_0p3 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k3, \ + a + ( rs_a * ( ir + 0 ) ) + ( cs_a * ( jr + 48 ) ) ) ); + + // 2ndx64 block. + c_float_1p0 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k0, \ + a + ( rs_a * ( ir + 1 ) ) + ( cs_a * ( jr + 0 ) ) ) ); + c_float_1p1 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k1, \ + a + ( rs_a * ( ir + 1 ) ) + ( cs_a * ( jr + 16 ) ) ) ); + c_float_1p2 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k2, \ + a + ( rs_a * ( ir + 1 ) ) + ( cs_a * ( jr + 32 ) ) ) ); + c_float_1p3 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k3, \ + a + ( rs_a * ( ir + 1 ) ) + ( cs_a * ( jr + 48 ) ) ) ); + + // 3rdx64 block. + c_float_2p0 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k0, \ + a + ( rs_a * ( ir + 2 ) ) + ( cs_a * ( jr + 0 ) ) ) ); + c_float_2p1 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k1, \ + a + ( rs_a * ( ir + 2 ) ) + ( cs_a * ( jr + 16 ) ) ) ); + c_float_2p2 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k2, \ + a + ( rs_a * ( ir + 2 ) ) + ( cs_a * ( jr + 32 ) ) ) ); + c_float_2p3 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k3, \ + a + ( rs_a * ( ir + 2 ) ) + ( cs_a * ( jr + 48 ) ) ) ); + + // 4thx64 block. + c_float_3p0 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k0, \ + a + ( rs_a * ( ir + 3 ) ) + ( cs_a * ( jr + 0 ) ) ) ); + c_float_3p1 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k1, \ + a + ( rs_a * ( ir + 3 ) ) + ( cs_a * ( jr + 16 ) ) ) ); + c_float_3p2 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k2, \ + a + ( rs_a * ( ir + 3 ) ) + ( cs_a * ( jr + 32 ) ) ) ); + c_float_3p3 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k3, \ + a + ( rs_a * ( ir + 3 ) ) + ( cs_a * ( jr + 48 ) ) ) ); + + // 5thx64 block. + c_float_4p0 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k0, \ + a + ( rs_a * ( ir + 4 ) ) + ( cs_a * ( jr + 0 ) ) ) ); + c_float_4p1 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k1, \ + a + ( rs_a * ( ir + 4 ) ) + ( cs_a * ( jr + 16 ) ) ) ); + c_float_4p2 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k2, \ + a + ( rs_a * ( ir + 4 ) ) + ( cs_a * ( jr + 32 ) ) ) ); + c_float_4p3 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k3, \ + a + ( rs_a * ( ir + 4 ) ) + ( cs_a * ( jr + 48 ) ) ) ); + + // 6thx64 block. + c_float_5p0 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k0, \ + a + ( rs_a * ( ir + 5 ) ) + ( cs_a * ( jr + 0 ) ) ) ); + c_float_5p1 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k1, \ + a + ( rs_a * ( ir + 5 ) ) + ( cs_a * ( jr + 16 ) ) ) ); + c_float_5p2 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k2, \ + a + ( rs_a * ( ir + 5 ) ) + ( cs_a * ( jr + 32 ) ) ) ); + c_float_5p3 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k3, \ + a + ( rs_a * ( ir + 5 ) ) + ( cs_a * ( jr + 48 ) ) ) ); + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP + +POST_OPS_BIAS_6x64_OPS: + { + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_attr.c_stor_type == BF16 ) + { + BF16_F32_BIAS_LOAD(selector1, k0, 0); + BF16_F32_BIAS_LOAD(selector2, k1, 1); + BF16_F32_BIAS_LOAD(selector3, k2, 2); + BF16_F32_BIAS_LOAD(selector4, k3, 3); + } + else + { + selector1 = + _mm512_maskz_loadu_ps( k0, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_maskz_loadu_ps( k1, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_maskz_loadu_ps( k2, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + selector4 = + _mm512_maskz_loadu_ps( k3, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_add_ps( selector4, c_float_0p3 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector3, c_float_1p2 ); + + // c[1,48-63] + c_float_1p3 = _mm512_add_ps( selector4, c_float_1p3 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + + // c[2,48-63] + c_float_2p3 = _mm512_add_ps( selector4, c_float_2p3 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector2, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_add_ps( selector3, c_float_3p2 ); + + // c[3,48-63] + c_float_3p3 = _mm512_add_ps( selector4, c_float_3p3 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + + // c[4, 16-31] + c_float_4p1 = _mm512_add_ps( selector2, c_float_4p1 ); + + // c[4,32-47] + c_float_4p2 = _mm512_add_ps( selector3, c_float_4p2 ); + + // c[4,48-63] + c_float_4p3 = _mm512_add_ps( selector4, c_float_4p3 ); + + // c[5,0-15] + c_float_5p0 = _mm512_add_ps( selector1, c_float_5p0 ); + + // c[5, 16-31] + c_float_5p1 = _mm512_add_ps( selector2, c_float_5p1 ); + + // c[5,32-47] + c_float_5p2 = _mm512_add_ps( selector3, c_float_5p2 ); + + // c[5,48-63] + c_float_5p3 = _mm512_add_ps( selector4, c_float_5p3 ); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the bias array will be accessed by + // the ic index, and each bias element corresponds to an + // entire row of the transposed output array, instead of an + // entire column. + __m512 selector5; + __m512 selector6; + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + BF16_F32_BIAS_BCAST(selector2, bias_mask, 1); + BF16_F32_BIAS_BCAST(selector3, bias_mask, 2); + BF16_F32_BIAS_BCAST(selector4, bias_mask, 3); + BF16_F32_BIAS_BCAST(selector5, bias_mask, 4); + BF16_F32_BIAS_BCAST(selector6, bias_mask, 5); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 3 ) ); + selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 4 ) ); + selector6 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 5 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_add_ps( selector1, c_float_0p3 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector2, c_float_1p2 ); + + // c[1,48-63] + c_float_1p3 = _mm512_add_ps( selector2, c_float_1p3 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector3, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + + // c[2,48-63] + c_float_2p3 = _mm512_add_ps( selector3, c_float_2p3 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector4, c_float_3p0 ); + + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector4, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_add_ps( selector4, c_float_3p2 ); + + // c[3,48-63] + c_float_3p3 = _mm512_add_ps( selector4, c_float_3p3 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector5, c_float_4p0 ); + + // c[4, 16-31] + c_float_4p1 = _mm512_add_ps( selector5, c_float_4p1 ); + + // c[4,32-47] + c_float_4p2 = _mm512_add_ps( selector5, c_float_4p2 ); + + // c[4,48-63] + c_float_4p3 = _mm512_add_ps( selector5, c_float_4p3 ); + + // c[5,0-15] + c_float_5p0 = _mm512_add_ps( selector6, c_float_5p0 ); + + // c[5, 16-31] + c_float_5p1 = _mm512_add_ps( selector6, c_float_5p1 ); + + // c[5,32-47] + c_float_5p2 = _mm512_add_ps( selector6, c_float_5p2 ); + + // c[5,48-63] + c_float_5p3 = _mm512_add_ps( selector6, c_float_5p3 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_6x64_OPS: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_max_ps( selector1, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_max_ps( selector1, c_float_0p3 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + c_float_1p1 = _mm512_max_ps( selector1, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_max_ps( selector1, c_float_1p2 ); + + // c[1,48-63] + c_float_1p3 = _mm512_max_ps( selector1, c_float_1p3 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + c_float_2p1 = _mm512_max_ps( selector1, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_max_ps( selector1, c_float_2p2 ); + + // c[2,48-63] + c_float_2p3 = _mm512_max_ps( selector1, c_float_2p3 ); + + // c[3,0-15] + c_float_3p0 = _mm512_max_ps( selector1, c_float_3p0 ); + + // c[3,16-31] + c_float_3p1 = _mm512_max_ps( selector1, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_max_ps( selector1, c_float_3p2 ); + + // c[3,48-63] + c_float_3p3 = _mm512_max_ps( selector1, c_float_3p3 ); + + // c[4,0-15] + c_float_4p0 = _mm512_max_ps( selector1, c_float_4p0 ); + + // c[4,16-31] + c_float_4p1 = _mm512_max_ps( selector1, c_float_4p1 ); + + // c[4,32-47] + c_float_4p2 = _mm512_max_ps( selector1, c_float_4p2 ); + + // c[4,48-63] + c_float_4p3 = _mm512_max_ps( selector1, c_float_4p3 ); + + // c[5,0-15] + c_float_5p0 = _mm512_max_ps( selector1, c_float_5p0 ); + + // c[5,16-31] + c_float_5p1 = _mm512_max_ps( selector1, c_float_5p1 ); + + // c[5,32-47] + c_float_5p2 = _mm512_max_ps( selector1, c_float_5p2 ); + + // c[5,48-63] + c_float_5p3 = _mm512_max_ps( selector1, c_float_5p3 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_6x64_OPS: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_0p2) + + // c[0, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_0p3) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_1p1) + + // c[1, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_1p2) + + // c[1, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_1p3) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_2p1) + + // c[2, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_2p2) + + // c[2, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_2p3) + + // c[3, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_3p0) + + // c[3, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_3p1) + + // c[3, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_3p2) + + // c[3, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_3p3) + + // c[4, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_4p0) + + // c[4, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_4p1) + + // c[4, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_4p2) + + // c[4, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_4p3) + + // c[5, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_5p0) + + // c[5, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_5p1) + + // c[5, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_5p2) + + // c[5, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_5p3) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_TANH_6x64_OPS: + { + __m512 dn, z, x, r2, r, x_tanh; + __m512i q; + + // c[0, 0-15] + GELU_TANH_F32_AVX512(c_float_0p0, r, r2, x, z, dn, x_tanh, q) + + // c[0, 16-31] + GELU_TANH_F32_AVX512(c_float_0p1, r, r2, x, z, dn, x_tanh, q) + + // c[0, 32-47] + GELU_TANH_F32_AVX512(c_float_0p2, r, r2, x, z, dn, x_tanh, q) + + // c[0, 48-63] + GELU_TANH_F32_AVX512(c_float_0p3, r, r2, x, z, dn, x_tanh, q) + + // c[1, 0-15] + GELU_TANH_F32_AVX512(c_float_1p0, r, r2, x, z, dn, x_tanh, q) + + // c[1, 16-31] + GELU_TANH_F32_AVX512(c_float_1p1, r, r2, x, z, dn, x_tanh, q) + + // c[1, 32-47] + GELU_TANH_F32_AVX512(c_float_1p2, r, r2, x, z, dn, x_tanh, q) + + // c[1, 48-63] + GELU_TANH_F32_AVX512(c_float_1p3, r, r2, x, z, dn, x_tanh, q) + + // c[2, 0-15] + GELU_TANH_F32_AVX512(c_float_2p0, r, r2, x, z, dn, x_tanh, q) + + // c[2, 16-31] + GELU_TANH_F32_AVX512(c_float_2p1, r, r2, x, z, dn, x_tanh, q) + + // c[2, 32-47] + GELU_TANH_F32_AVX512(c_float_2p2, r, r2, x, z, dn, x_tanh, q) + + // c[2, 48-63] + GELU_TANH_F32_AVX512(c_float_2p3, r, r2, x, z, dn, x_tanh, q) + + // c[3, 0-15] + GELU_TANH_F32_AVX512(c_float_3p0, r, r2, x, z, dn, x_tanh, q) + + // c[3, 16-31] + GELU_TANH_F32_AVX512(c_float_3p1, r, r2, x, z, dn, x_tanh, q) + + // c[3, 32-47] + GELU_TANH_F32_AVX512(c_float_3p2, r, r2, x, z, dn, x_tanh, q) + + // c[3, 48-63] + GELU_TANH_F32_AVX512(c_float_3p3, r, r2, x, z, dn, x_tanh, q) + + // c[4, 0-15] + GELU_TANH_F32_AVX512(c_float_4p0, r, r2, x, z, dn, x_tanh, q) + + // c[4, 16-31] + GELU_TANH_F32_AVX512(c_float_4p1, r, r2, x, z, dn, x_tanh, q) + + // c[4, 32-47] + GELU_TANH_F32_AVX512(c_float_4p2, r, r2, x, z, dn, x_tanh, q) + + // c[4, 48-63] + GELU_TANH_F32_AVX512(c_float_4p3, r, r2, x, z, dn, x_tanh, q) + + // c[5, 0-15] + GELU_TANH_F32_AVX512(c_float_5p0, r, r2, x, z, dn, x_tanh, q) + + // c[5, 16-31] + GELU_TANH_F32_AVX512(c_float_5p1, r, r2, x, z, dn, x_tanh, q) + + // c[5, 32-47] + GELU_TANH_F32_AVX512(c_float_5p2, r, r2, x, z, dn, x_tanh, q) + + // c[5, 48-63] + GELU_TANH_F32_AVX512(c_float_5p3, r, r2, x, z, dn, x_tanh, q) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_ERF_6x64_OPS: + { + __m512 x, r, x_erf; + + // c[0, 0-15] + GELU_ERF_F32_AVX512(c_float_0p0, r, x, x_erf) + + // c[0, 16-31] + GELU_ERF_F32_AVX512(c_float_0p1, r, x, x_erf) + + // c[0, 32-47] + GELU_ERF_F32_AVX512(c_float_0p2, r, x, x_erf) + + // c[0, 48-63] + GELU_ERF_F32_AVX512(c_float_0p3, r, x, x_erf) + + // c[1, 0-15] + GELU_ERF_F32_AVX512(c_float_1p0, r, x, x_erf) + + // c[1, 16-31] + GELU_ERF_F32_AVX512(c_float_1p1, r, x, x_erf) + + // c[1, 32-47] + GELU_ERF_F32_AVX512(c_float_1p2, r, x, x_erf) + + // c[1, 48-63] + GELU_ERF_F32_AVX512(c_float_1p3, r, x, x_erf) + + // c[2, 0-15] + GELU_ERF_F32_AVX512(c_float_2p0, r, x, x_erf) + + // c[2, 16-31] + GELU_ERF_F32_AVX512(c_float_2p1, r, x, x_erf) + + // c[2, 32-47] + GELU_ERF_F32_AVX512(c_float_2p2, r, x, x_erf) + + // c[2, 48-63] + GELU_ERF_F32_AVX512(c_float_2p3, r, x, x_erf) + + // c[3, 0-15] + GELU_ERF_F32_AVX512(c_float_3p0, r, x, x_erf) + + // c[3, 16-31] + GELU_ERF_F32_AVX512(c_float_3p1, r, x, x_erf) + + // c[3, 32-47] + GELU_ERF_F32_AVX512(c_float_3p2, r, x, x_erf) + + // c[3, 48-63] + GELU_ERF_F32_AVX512(c_float_3p3, r, x, x_erf) + + // c[4, 0-15] + GELU_ERF_F32_AVX512(c_float_4p0, r, x, x_erf) + + // c[4, 16-31] + GELU_ERF_F32_AVX512(c_float_4p1, r, x, x_erf) + + // c[4, 32-47] + GELU_ERF_F32_AVX512(c_float_4p2, r, x, x_erf) + + // c[4, 48-63] + GELU_ERF_F32_AVX512(c_float_4p3, r, x, x_erf) + + // c[5, 0-15] + GELU_ERF_F32_AVX512(c_float_5p0, r, x, x_erf) + + // c[5, 16-31] + GELU_ERF_F32_AVX512(c_float_5p1, r, x, x_erf) + + // c[5, 32-47] + GELU_ERF_F32_AVX512(c_float_5p2, r, x, x_erf) + + // c[5, 48-63] + GELU_ERF_F32_AVX512(c_float_5p3, r, x, x_erf) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_CLIP_6x64_OPS: + { + __m512 min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 ); + __m512 max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 ); + + // c[0, 0-15] + CLIP_F32_AVX512(c_float_0p0, min, max) + + // c[0, 16-31] + CLIP_F32_AVX512(c_float_0p1, min, max) + + // c[0, 32-47] + CLIP_F32_AVX512(c_float_0p2, min, max) + + // c[0, 48-63] + CLIP_F32_AVX512(c_float_0p3, min, max) + + // c[1, 0-15] + CLIP_F32_AVX512(c_float_1p0, min, max) + + // c[1, 16-31] + CLIP_F32_AVX512(c_float_1p1, min, max) + + // c[1, 32-47] + CLIP_F32_AVX512(c_float_1p2, min, max) + + // c[1, 48-63] + CLIP_F32_AVX512(c_float_1p3, min, max) + + // c[2, 0-15] + CLIP_F32_AVX512(c_float_2p0, min, max) + + // c[2, 16-31] + CLIP_F32_AVX512(c_float_2p1, min, max) + + // c[2, 32-47] + CLIP_F32_AVX512(c_float_2p2, min, max) + + // c[2, 48-63] + CLIP_F32_AVX512(c_float_2p3, min, max) + + // c[3, 0-15] + CLIP_F32_AVX512(c_float_3p0, min, max) + + // c[3, 16-31] + CLIP_F32_AVX512(c_float_3p1, min, max) + + // c[3, 32-47] + CLIP_F32_AVX512(c_float_3p2, min, max) + + // c[3, 48-63] + CLIP_F32_AVX512(c_float_3p3, min, max) + + // c[4, 0-15] + CLIP_F32_AVX512(c_float_4p0, min, max) + + // c[4, 16-31] + CLIP_F32_AVX512(c_float_4p1, min, max) + + // c[4, 32-47] + CLIP_F32_AVX512(c_float_4p2, min, max) + + // c[4, 48-63] + CLIP_F32_AVX512(c_float_4p3, min, max) + + // c[5, 0-15] + CLIP_F32_AVX512(c_float_5p0, min, max) + + // c[5, 16-31] + CLIP_F32_AVX512(c_float_5p1, min, max) + + // c[5, 32-47] + CLIP_F32_AVX512(c_float_5p2, min, max) + + // c[5, 48-63] + CLIP_F32_AVX512(c_float_5p3, min, max) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_6x64_OPS: + { + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + __m512 zero_point2 = _mm512_setzero_ps(); + __m512 zero_point3 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_maskz_loadu_ps( k0, + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_maskz_loadu_ps( k1, + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_maskz_loadu_ps( k2, + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + selector4 = + _mm512_maskz_loadu_ps( k3, + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( k0, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( k1, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( k2, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( k3, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector2,zero_point1); + + // c[0, 32-47] + SCL_MULRND_F32(c_float_0p2,selector3,zero_point2); + + // c[0, 48-63] + SCL_MULRND_F32(c_float_0p3,selector4,zero_point3); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector1,zero_point0); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[1, 32-47] + SCL_MULRND_F32(c_float_1p2,selector3,zero_point2); + + // c[1, 48-63] + SCL_MULRND_F32(c_float_1p3,selector4,zero_point3); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector1,zero_point0); + + // c[2, 16-31] + SCL_MULRND_F32(c_float_2p1,selector2,zero_point1); + + // c[2, 32-47] + SCL_MULRND_F32(c_float_2p2,selector3,zero_point2); + + // c[2, 48-63] + SCL_MULRND_F32(c_float_2p3,selector4,zero_point3); + + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector1,zero_point0); + + // c[3, 16-31] + SCL_MULRND_F32(c_float_3p1,selector2,zero_point1); + + // c[3, 32-47] + SCL_MULRND_F32(c_float_3p2,selector3,zero_point2); + + // c[3, 48-63] + SCL_MULRND_F32(c_float_3p3,selector4,zero_point3); + + // c[4, 0-15] + SCL_MULRND_F32(c_float_4p0,selector1,zero_point0); + + // c[4, 16-31] + SCL_MULRND_F32(c_float_4p1,selector2,zero_point1); + + // c[4, 32-47] + SCL_MULRND_F32(c_float_4p2,selector3,zero_point2); + + // c[4, 48-63] + SCL_MULRND_F32(c_float_4p3,selector4,zero_point3); + + // c[5, 0-15] + SCL_MULRND_F32(c_float_5p0,selector1,zero_point0); + + // c[5, 16-31] + SCL_MULRND_F32(c_float_5p1,selector2,zero_point1); + + // c[5, 32-47] + SCL_MULRND_F32(c_float_5p2,selector3,zero_point2); + + // c[5, 48-63] + SCL_MULRND_F32(c_float_5p3,selector4,zero_point3); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 3 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 2 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 3 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector1,zero_point0); + + // c[0, 32-47] + SCL_MULRND_F32(c_float_0p2,selector1,zero_point0); + + // c[0, 48-63] + SCL_MULRND_F32(c_float_0p3,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector2,zero_point1); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[1, 32-47] + SCL_MULRND_F32(c_float_1p2,selector2,zero_point1); + + // c[1, 48-63] + SCL_MULRND_F32(c_float_1p3,selector2,zero_point1); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector3,zero_point2); + + // c[2, 16-31] + SCL_MULRND_F32(c_float_2p1,selector3,zero_point2); + + // c[2, 32-47] + SCL_MULRND_F32(c_float_2p2,selector3,zero_point2); + + // c[2, 48-63] + SCL_MULRND_F32(c_float_2p3,selector3,zero_point2); + + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector4,zero_point3); + + // c[3, 16-31] + SCL_MULRND_F32(c_float_3p1,selector4,zero_point3); + + // c[3, 32-47] + SCL_MULRND_F32(c_float_3p2,selector4,zero_point3); + + // c[3, 48-63] + SCL_MULRND_F32(c_float_3p3,selector4,zero_point3); + + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 4 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 5 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 4 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 5 ) ) ); + } + // c[4, 0-15] + SCL_MULRND_F32(c_float_4p0,selector1,zero_point0); + + // c[4, 16-31] + SCL_MULRND_F32(c_float_4p1,selector1,zero_point0); + + // c[4, 32-47] + SCL_MULRND_F32(c_float_4p2,selector1,zero_point0); + + // c[4, 48-63] + SCL_MULRND_F32(c_float_4p3,selector1,zero_point0); + + // c[5, 0-15] + SCL_MULRND_F32(c_float_5p0,selector2,zero_point1); + + // c[5, 16-31] + SCL_MULRND_F32(c_float_5p1,selector2,zero_point1); + + // c[5, 32-47] + SCL_MULRND_F32(c_float_5p2,selector2,zero_point1); + + // c[5, 48-63] + SCL_MULRND_F32(c_float_5p3,selector2,zero_point1); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_6x64_OPS: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + // It is expected the post-op matrix arg has the same storage + // order as the output C matrix. + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,0); + + // c[1:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,1); + + // c[2:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,2); + + // c[3:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,3); + + // c[4:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,4); + + // c[5:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,5); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,0); + + // c[1:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,1); + + // c[2:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,2); + + // c[3:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,3); + + // c[4:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,4); + + // c[5:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,5); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_6x64_OPS: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + // It is expected the post-op matrix arg has the same storage + // order as the output C matrix. + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,0); + + // c[1:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,1); + + // c[2:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,2); + + // c[3:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,3); + + // c[4:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,4); + + // c[5:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,5); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,0); + + // c[1:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,1); + + // c[2:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,2); + + // c[3:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,3); + + // c[4:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,4); + + // c[5:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,5); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_6x64_OPS: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[0, 16-31] + SWISH_F32_AVX512_DEF(c_float_0p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[0, 32-47] + SWISH_F32_AVX512_DEF(c_float_0p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[0, 48-63] + SWISH_F32_AVX512_DEF(c_float_0p3, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 0-15] + SWISH_F32_AVX512_DEF(c_float_1p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 16-31] + SWISH_F32_AVX512_DEF(c_float_1p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 32-47] + SWISH_F32_AVX512_DEF(c_float_1p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 48-63] + SWISH_F32_AVX512_DEF(c_float_1p3, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 0-15] + SWISH_F32_AVX512_DEF(c_float_2p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 16-31] + SWISH_F32_AVX512_DEF(c_float_2p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 32-47] + SWISH_F32_AVX512_DEF(c_float_2p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 48-63] + SWISH_F32_AVX512_DEF(c_float_2p3, selector1, al_in, r, r2, z, dn, ex_out); + + // c[3, 0-15] + SWISH_F32_AVX512_DEF(c_float_3p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[3, 16-31] + SWISH_F32_AVX512_DEF(c_float_3p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[3, 32-47] + SWISH_F32_AVX512_DEF(c_float_3p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[3, 48-63] + SWISH_F32_AVX512_DEF(c_float_3p3, selector1, al_in, r, r2, z, dn, ex_out); + + // c[4, 0-15] + SWISH_F32_AVX512_DEF(c_float_4p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[4, 16-31] + SWISH_F32_AVX512_DEF(c_float_4p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[4, 32-47] + SWISH_F32_AVX512_DEF(c_float_4p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[4, 48-63] + SWISH_F32_AVX512_DEF(c_float_4p3, selector1, al_in, r, r2, z, dn, ex_out); + + // c[5, 0-15] + SWISH_F32_AVX512_DEF(c_float_5p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[5, 16-31] + SWISH_F32_AVX512_DEF(c_float_5p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[5, 32-47] + SWISH_F32_AVX512_DEF(c_float_5p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[5, 48-63] + SWISH_F32_AVX512_DEF(c_float_5p3, selector1, al_in, r, r2, z, dn, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_6x64_OPS_DISABLE: + ; + + // Case where the output C matrix is bf16 (downscaled) and this is the + // final write for a given block within C. + if ( post_ops_attr.c_stor_type == BF16 ) + { + // Actually the b matrix is of type bfloat16. However + // in order to reuse this kernel for f32, the output + // matrix type in kernel function signature is set to + // f32 irrespective of original output matrix type. + bfloat16* b_q = ( bfloat16* )b; + + // Store the results in downscaled type (bf16 instead of float). + // c[0, 0-15] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_0p0,k0,0,0); + // c[0, 16-31] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_0p1,k1,0,16); + // c[0, 32-47] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_0p2,k2,0,32); + // c[0, 48-63] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_0p3,k3,0,48); + + // c[1, 0-15] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_1p0,k0,1,0); + // c[1, 16-31] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_1p1,k1,1,16); + // c[1, 32-47] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_1p2,k2,1,32); + // c[1, 48-63] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_1p3,k3,1,48); + + // c[2, 0-15] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_2p0,k0,2,0); + // c[2, 16-31] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_2p1,k1,2,16); + // c[2, 32-47] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_2p2,k2,2,32); + // c[2, 48-63] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_2p3,k3,2,48); + + // c[3, 0-15] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_3p0,k0,3,0); + // c[3, 16-31] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_3p1,k1,3,16); + // c[3, 32-47] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_3p2,k2,3,32); + // c[3, 48-63] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_3p3,k3,3,48); + + // c[4, 0-15] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_4p0,k0,4,0); + // c[4, 16-31] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_4p1,k1,4,16); + // c[4, 32-47] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_4p2,k2,4,32); + // c[4, 48-63] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_4p3,k3,4,48); + + // c[5, 0-15] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_5p0,k0,5,0); + // c[5, 16-31] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_5p1,k1,5,16); + // c[5, 32-47] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_5p2,k2,5,32); + // c[5, 48-63] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_5p3,k3,5,48); + } + // Case where the output C matrix is float + else + { + // Store the results. + // c[0,0-15] + _mm512_mask_storeu_ps( b + ( rs_b * ( ir + 0 ) ) + + ( cs_b * ( jr + 0 ) ), k0, c_float_0p0 ); + // c[0,16-31] + _mm512_mask_storeu_ps( b + ( rs_b * ( ir + 0 ) ) + + ( cs_b * ( jr + 16 ) ), k1, c_float_0p1 ); + // c[0,32-47] + _mm512_mask_storeu_ps( b + ( rs_b * ( ir + 0 ) ) + + ( cs_b * ( jr + 32 ) ), k2, c_float_0p2 ); + // c[0,48-63] + _mm512_mask_storeu_ps( b + ( rs_b * ( ir + 0 ) ) + + ( cs_b * ( jr + 48 ) ), k3, c_float_0p3 ); + + // c[1,0-15] + _mm512_mask_storeu_ps( b + ( rs_b * ( ir + 1 ) ) + + ( cs_b * ( jr + 0 ) ), k0, c_float_1p0 ); + // c[1,16-31] + _mm512_mask_storeu_ps( b + ( rs_b * ( ir + 1 ) ) + + ( cs_b * ( jr + 16 ) ), k1, c_float_1p1 ); + // c[1,32-47] + _mm512_mask_storeu_ps( b + ( rs_b * ( ir + 1 ) ) + + ( cs_b * ( jr + 32 ) ), k2, c_float_1p2 ); + // c[1,48-63] + _mm512_mask_storeu_ps( b + ( rs_b * ( ir + 1 ) ) + + ( cs_b * ( jr + 48 ) ), k3, c_float_1p3 ); + + // c[2,0-15] + _mm512_mask_storeu_ps( b + ( rs_b * ( ir + 2 ) ) + + ( cs_b * ( jr + 0 ) ), k0, c_float_2p0 ); + // c[2,16-31] + _mm512_mask_storeu_ps( b + ( rs_b * ( ir + 2 ) ) + + ( cs_b * ( jr + 16 ) ), k1, c_float_2p1 ); + // c[2,32-47] + _mm512_mask_storeu_ps( b + ( rs_b * ( ir + 2 ) ) + + ( cs_b * ( jr + 32 ) ), k2, c_float_2p2 ); + // c[2,48-63] + _mm512_mask_storeu_ps( b + ( rs_b * ( ir + 2 ) ) + + ( cs_b * ( jr + 48 ) ), k3, c_float_2p3 ); + + // c[3,0-15] + _mm512_mask_storeu_ps( b + ( rs_b * ( ir + 3 ) ) + + ( cs_b * ( jr + 0 ) ), k0, c_float_3p0 ); + // c[3,16-31] + _mm512_mask_storeu_ps( b + ( rs_b * ( ir + 3 ) ) + + ( cs_b * ( jr + 16 ) ), k1, c_float_3p1 ); + // c[3,32-47] + _mm512_mask_storeu_ps( b + ( rs_b * ( ir + 3 ) ) + + ( cs_b * ( jr + 32 ) ), k2, c_float_3p2 ); + // c[3,48-63] + _mm512_mask_storeu_ps( b + ( rs_b * ( ir + 3 ) ) + + ( cs_b * ( jr + 48 ) ), k3, c_float_3p3 ); + + // c[4,0-15] + _mm512_mask_storeu_ps( b + ( rs_b * ( ir + 4 ) ) + + ( cs_b * ( jr + 0 ) ), k0, c_float_4p0 ); + // c[4,16-31] + _mm512_mask_storeu_ps( b + ( rs_b * ( ir + 4 ) ) + + ( cs_b * ( jr + 16 ) ), k1, c_float_4p1 ); + // c[4,32-47] + _mm512_mask_storeu_ps( b + ( rs_b * ( ir + 4 ) ) + + ( cs_b * ( jr + 32 ) ), k2, c_float_4p2 ); + // c[4,48-63] + _mm512_mask_storeu_ps( b + ( rs_b * ( ir + 4 ) ) + + ( cs_b * ( jr + 48 ) ), k3, c_float_4p3 ); + + // c[5,0-15] + _mm512_mask_storeu_ps( b + ( rs_b * ( ir + 5 ) ) + + ( cs_b * ( jr + 0 ) ), k0, c_float_5p0 ); + // c[5,16-31] + _mm512_mask_storeu_ps( b + ( rs_b * ( ir + 5 ) ) + + ( cs_b * ( jr + 16 ) ), k1, c_float_5p1 ); + // c[5,32-47] + _mm512_mask_storeu_ps( b + ( rs_b * ( ir + 5 ) ) + + ( cs_b * ( jr + 32 ) ), k2, c_float_5p2 ); + // c[5,48-63] + _mm512_mask_storeu_ps( b + ( rs_b * ( ir + 5 ) ) + + ( cs_b * ( jr + 48 ) ), k3, c_float_5p3 ); + } + + post_ops_attr.post_op_c_j += NR_L; + } + + post_ops_attr.post_op_c_j = orig_post_op_c_j; + post_ops_attr.post_op_c_i += MR; + } + + if ( m_partial_pieces > 0 ) + { + dim_t dsize = sizeof( float ); + if ( post_ops_attr.c_stor_type == BF16 ) + { + dsize = sizeof( bfloat16 ); + } + + int8_t* b_i = ( int8_t* )b; + if ( m_partial_pieces == 5 ) + { + lpgemm_eltwise_ops_kernel_bf16of32_5x64 + ( + n0, + a + ( rs_a * m_full_pieces_loop_limit ), rs_a, cs_a, + ( float* )( b_i + ( dsize * rs_b * m_full_pieces_loop_limit ) ), + rs_b, cs_b, + post_ops_list, post_ops_attr + ); + } + else if ( m_partial_pieces == 4 ) + { + lpgemm_eltwise_ops_kernel_bf16of32_4x64 + ( + n0, + a + ( rs_a * m_full_pieces_loop_limit ), rs_a, cs_a, + ( float* )( b_i + ( dsize * rs_b * m_full_pieces_loop_limit ) ), + rs_b, cs_b, + post_ops_list, post_ops_attr + ); + } + else if ( m_partial_pieces == 3 ) + { + lpgemm_eltwise_ops_kernel_bf16of32_3x64 + ( + n0, + a + ( rs_a * m_full_pieces_loop_limit ), rs_a, cs_a, + ( float* )( b_i + ( dsize * rs_b * m_full_pieces_loop_limit ) ), + rs_b, cs_b, + post_ops_list, post_ops_attr + ); + } + else if ( m_partial_pieces == 2 ) + { + lpgemm_eltwise_ops_kernel_bf16of32_2x64 + ( + n0, + a + ( rs_a * m_full_pieces_loop_limit ), rs_a, cs_a, + ( float* )( b_i + ( dsize * rs_b * m_full_pieces_loop_limit ) ), + rs_b, cs_b, + post_ops_list, post_ops_attr + ); + } + else if ( m_partial_pieces == 1 ) + { + lpgemm_eltwise_ops_kernel_bf16of32_1x64 + ( + n0, + a + ( rs_a * m_full_pieces_loop_limit ), rs_a, cs_a, + ( float* )( b_i + ( dsize * rs_b * m_full_pieces_loop_limit ) ), + rs_b, cs_b, + post_ops_list, post_ops_attr + ); + } + } +} + +#endif //LPGEMM_BF16_JIT +#endif diff --git a/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_eltwise_ops_m_fringe_bf16_amd512vnni.c b/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_eltwise_ops_m_fringe_bf16_amd512vnni.c new file mode 100644 index 0000000000..16c2f97523 --- /dev/null +++ b/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_eltwise_ops_m_fringe_bf16_amd512vnni.c @@ -0,0 +1,4348 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "blis.h" + +#ifdef BLIS_ADDON_LPGEMM + +#include "lpgemm_f32_kern_macros.h" + +#ifndef LPGEMM_BF16_JIT + +LPGEMM_ELTWISE_OPS_M_FRINGE_KERNEL(bfloat16,float,bf16of32_5x64) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_5x64_OPS_DISABLE, + &&POST_OPS_BIAS_5x64_OPS, + &&POST_OPS_RELU_5x64_OPS, + &&POST_OPS_RELU_SCALE_5x64_OPS, + &&POST_OPS_GELU_TANH_5x64_OPS, + &&POST_OPS_GELU_ERF_5x64_OPS, + &&POST_OPS_CLIP_5x64_OPS, + &&POST_OPS_DOWNSCALE_5x64_OPS, + &&POST_OPS_MATRIX_ADD_5x64_OPS, + &&POST_OPS_SWISH_5x64_OPS, + &&POST_OPS_MATRIX_MUL_5x64_OPS + }; + dim_t NR = 64; + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + __m512 c_float_0p1 = _mm512_setzero_ps(); + __m512 c_float_0p2 = _mm512_setzero_ps(); + __m512 c_float_0p3 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + __m512 c_float_1p1 = _mm512_setzero_ps(); + __m512 c_float_1p2 = _mm512_setzero_ps(); + __m512 c_float_1p3 = _mm512_setzero_ps(); + + __m512 c_float_2p0 = _mm512_setzero_ps(); + __m512 c_float_2p1 = _mm512_setzero_ps(); + __m512 c_float_2p2 = _mm512_setzero_ps(); + __m512 c_float_2p3 = _mm512_setzero_ps(); + + __m512 c_float_3p0 = _mm512_setzero_ps(); + __m512 c_float_3p1 = _mm512_setzero_ps(); + __m512 c_float_3p2 = _mm512_setzero_ps(); + __m512 c_float_3p3 = _mm512_setzero_ps(); + + __m512 c_float_4p0 = _mm512_setzero_ps(); + __m512 c_float_4p1 = _mm512_setzero_ps(); + __m512 c_float_4p2 = _mm512_setzero_ps(); + __m512 c_float_4p3 = _mm512_setzero_ps(); + + __m512 selector1 = _mm512_setzero_ps(); + __m512 selector2 = _mm512_setzero_ps(); + __m512 selector3 = _mm512_setzero_ps(); + __m512 selector4 = _mm512_setzero_ps(); + + __mmask16 k0 = 0xFFFF, k1 = 0xFFFF, k2 = 0xFFFF, k3 = 0xFFFF; + + dim_t NR_L = NR; + for( dim_t jr = 0; jr < n0; jr += NR_L ) + { + dim_t n_left = n0 - jr; + NR_L = bli_min( NR_L, ( n_left >> 4 ) << 4 ); + if( NR_L == 0 ) { NR_L = 16; } + + dim_t nr0 = bli_min( n0 - jr, NR_L ); + if( nr0 == 64 ) + { + // all masks are already set. + // Nothing to modify. + } + else if( nr0 == 48 ) + { + k3 = 0x0; + } + else if( nr0 == 32 ) + { + k2 = k3 = 0x0; + } + else if( nr0 == 16 ) + { + k1 = k2 = k3 = 0; + } + else if( nr0 < 16 ) + { + k0 = (0xFFFF >> (16 - (nr0 & 0x0F))); + k1 = k2 = k3 = 0; + } + + // 1stx64 block. + c_float_0p0 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k0, \ + a + ( rs_a * ( 0 ) ) + ( cs_a * ( jr + 0 ) ) ) ); + c_float_0p1 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k1, \ + a + ( rs_a * ( 0 ) ) + ( cs_a * ( jr + 16 ) ) ) ); + c_float_0p2 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k2, \ + a + ( rs_a * ( 0 ) ) + ( cs_a * ( jr + 32 ) ) ) ); + c_float_0p3 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k3, \ + a + ( rs_a * ( 0 ) ) + ( cs_a * ( jr + 48 ) ) ) ); + + // 2ndx64 block. + c_float_1p0 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k0, \ + a + ( rs_a * ( 1 ) ) + ( cs_a * ( jr + 0 ) ) ) ); + c_float_1p1 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k1, \ + a + ( rs_a * ( 1 ) ) + ( cs_a * ( jr + 16 ) ) ) ); + c_float_1p2 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k2, \ + a + ( rs_a * ( 1 ) ) + ( cs_a * ( jr + 32 ) ) ) ); + c_float_1p3 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k3, \ + a + ( rs_a * ( 1 ) ) + ( cs_a * ( jr + 48 ) ) ) ); + + // 3rdx64 block. + c_float_2p0 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k0, \ + a + ( rs_a * ( 2 ) ) + ( cs_a * ( jr + 0 ) ) ) ); + c_float_2p1 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k1, \ + a + ( rs_a * ( 2 ) ) + ( cs_a * ( jr + 16 ) ) ) ); + c_float_2p2 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k2, \ + a + ( rs_a * ( 2 ) ) + ( cs_a * ( jr + 32 ) ) ) ); + c_float_2p3 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k3, \ + a + ( rs_a * ( 2 ) ) + ( cs_a * ( jr + 48 ) ) ) ); + + // 4thx64 block. + c_float_3p0 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k0, \ + a + ( rs_a * ( 3 ) ) + ( cs_a * ( jr + 0 ) ) ) ); + c_float_3p1 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k1, \ + a + ( rs_a * ( 3 ) ) + ( cs_a * ( jr + 16 ) ) ) ); + c_float_3p2 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k2, \ + a + ( rs_a * ( 3 ) ) + ( cs_a * ( jr + 32 ) ) ) ); + c_float_3p3 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k3, \ + a + ( rs_a * ( 3 ) ) + ( cs_a * ( jr + 48 ) ) ) ); + + // 5thx64 block. + c_float_4p0 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k0, \ + a + ( rs_a * ( 4 ) ) + ( cs_a * ( jr + 0 ) ) ) ); + c_float_4p1 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k1, \ + a + ( rs_a * ( 4 ) ) + ( cs_a * ( jr + 16 ) ) ) ); + c_float_4p2 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k2, \ + a + ( rs_a * ( 4 ) ) + ( cs_a * ( jr + 32 ) ) ) ); + c_float_4p3 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k3, \ + a + ( rs_a * ( 4 ) ) + ( cs_a * ( jr + 48 ) ) ) ); + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP + +POST_OPS_BIAS_5x64_OPS: + { + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_attr.c_stor_type == BF16 ) + { + BF16_F32_BIAS_LOAD(selector1, k0, 0); + BF16_F32_BIAS_LOAD(selector2, k1, 1); + BF16_F32_BIAS_LOAD(selector3, k2, 2); + BF16_F32_BIAS_LOAD(selector4, k3, 3); + } + else + { + selector1 = + _mm512_maskz_loadu_ps( k0, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_maskz_loadu_ps( k1, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_maskz_loadu_ps( k2, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + selector4 = + _mm512_maskz_loadu_ps( k3, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_add_ps( selector4, c_float_0p3 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector3, c_float_1p2 ); + + // c[1,48-63] + c_float_1p3 = _mm512_add_ps( selector4, c_float_1p3 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + + // c[2,48-63] + c_float_2p3 = _mm512_add_ps( selector4, c_float_2p3 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector2, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_add_ps( selector3, c_float_3p2 ); + + // c[3,48-63] + c_float_3p3 = _mm512_add_ps( selector4, c_float_3p3 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + + // c[4, 16-31] + c_float_4p1 = _mm512_add_ps( selector2, c_float_4p1 ); + + // c[4,32-47] + c_float_4p2 = _mm512_add_ps( selector3, c_float_4p2 ); + + // c[4,48-63] + c_float_4p3 = _mm512_add_ps( selector4, c_float_4p3 ); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the bias array will be accessed by + // the ic index, and each bias element corresponds to an + // entire row of the transposed output array, instead of an + // entire column. + __m512 selector5; + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + BF16_F32_BIAS_BCAST(selector2, bias_mask, 1); + BF16_F32_BIAS_BCAST(selector3, bias_mask, 2); + BF16_F32_BIAS_BCAST(selector4, bias_mask, 3); + BF16_F32_BIAS_BCAST(selector5, bias_mask, 4); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 3 ) ); + selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 4 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_add_ps( selector1, c_float_0p3 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector2, c_float_1p2 ); + + // c[1,48-63] + c_float_1p3 = _mm512_add_ps( selector2, c_float_1p3 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector3, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + + // c[2,48-63] + c_float_2p3 = _mm512_add_ps( selector3, c_float_2p3 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector4, c_float_3p0 ); + + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector4, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_add_ps( selector4, c_float_3p2 ); + + // c[3,48-63] + c_float_3p3 = _mm512_add_ps( selector4, c_float_3p3 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector5, c_float_4p0 ); + + // c[4, 16-31] + c_float_4p1 = _mm512_add_ps( selector5, c_float_4p1 ); + + // c[4,32-47] + c_float_4p2 = _mm512_add_ps( selector5, c_float_4p2 ); + + // c[4,48-63] + c_float_4p3 = _mm512_add_ps( selector5, c_float_4p3 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_5x64_OPS: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_max_ps( selector1, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_max_ps( selector1, c_float_0p3 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + c_float_1p1 = _mm512_max_ps( selector1, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_max_ps( selector1, c_float_1p2 ); + + // c[1,48-63] + c_float_1p3 = _mm512_max_ps( selector1, c_float_1p3 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + c_float_2p1 = _mm512_max_ps( selector1, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_max_ps( selector1, c_float_2p2 ); + + // c[2,48-63] + c_float_2p3 = _mm512_max_ps( selector1, c_float_2p3 ); + + // c[3,0-15] + c_float_3p0 = _mm512_max_ps( selector1, c_float_3p0 ); + + // c[3,16-31] + c_float_3p1 = _mm512_max_ps( selector1, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_max_ps( selector1, c_float_3p2 ); + + // c[3,48-63] + c_float_3p3 = _mm512_max_ps( selector1, c_float_3p3 ); + + // c[4,0-15] + c_float_4p0 = _mm512_max_ps( selector1, c_float_4p0 ); + + // c[4,16-31] + c_float_4p1 = _mm512_max_ps( selector1, c_float_4p1 ); + + // c[4,32-47] + c_float_4p2 = _mm512_max_ps( selector1, c_float_4p2 ); + + // c[4,48-63] + c_float_4p3 = _mm512_max_ps( selector1, c_float_4p3 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_5x64_OPS: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_0p2) + + // c[0, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_0p3) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_1p1) + + // c[1, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_1p2) + + // c[1, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_1p3) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_2p1) + + // c[2, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_2p2) + + // c[2, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_2p3) + + // c[3, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_3p0) + + // c[3, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_3p1) + + // c[3, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_3p2) + + // c[3, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_3p3) + + // c[4, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_4p0) + + // c[4, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_4p1) + + // c[4, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_4p2) + + // c[4, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_4p3) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_TANH_5x64_OPS: + { + __m512 dn, z, x, r2, r, x_tanh; + __m512i q; + + // c[0, 0-15] + GELU_TANH_F32_AVX512(c_float_0p0, r, r2, x, z, dn, x_tanh, q) + + // c[0, 16-31] + GELU_TANH_F32_AVX512(c_float_0p1, r, r2, x, z, dn, x_tanh, q) + + // c[0, 32-47] + GELU_TANH_F32_AVX512(c_float_0p2, r, r2, x, z, dn, x_tanh, q) + + // c[0, 48-63] + GELU_TANH_F32_AVX512(c_float_0p3, r, r2, x, z, dn, x_tanh, q) + + // c[1, 0-15] + GELU_TANH_F32_AVX512(c_float_1p0, r, r2, x, z, dn, x_tanh, q) + + // c[1, 16-31] + GELU_TANH_F32_AVX512(c_float_1p1, r, r2, x, z, dn, x_tanh, q) + + // c[1, 32-47] + GELU_TANH_F32_AVX512(c_float_1p2, r, r2, x, z, dn, x_tanh, q) + + // c[1, 48-63] + GELU_TANH_F32_AVX512(c_float_1p3, r, r2, x, z, dn, x_tanh, q) + + // c[2, 0-15] + GELU_TANH_F32_AVX512(c_float_2p0, r, r2, x, z, dn, x_tanh, q) + + // c[2, 16-31] + GELU_TANH_F32_AVX512(c_float_2p1, r, r2, x, z, dn, x_tanh, q) + + // c[2, 32-47] + GELU_TANH_F32_AVX512(c_float_2p2, r, r2, x, z, dn, x_tanh, q) + + // c[2, 48-63] + GELU_TANH_F32_AVX512(c_float_2p3, r, r2, x, z, dn, x_tanh, q) + + // c[3, 0-15] + GELU_TANH_F32_AVX512(c_float_3p0, r, r2, x, z, dn, x_tanh, q) + + // c[3, 16-31] + GELU_TANH_F32_AVX512(c_float_3p1, r, r2, x, z, dn, x_tanh, q) + + // c[3, 32-47] + GELU_TANH_F32_AVX512(c_float_3p2, r, r2, x, z, dn, x_tanh, q) + + // c[3, 48-63] + GELU_TANH_F32_AVX512(c_float_3p3, r, r2, x, z, dn, x_tanh, q) + + // c[4, 0-15] + GELU_TANH_F32_AVX512(c_float_4p0, r, r2, x, z, dn, x_tanh, q) + + // c[4, 16-31] + GELU_TANH_F32_AVX512(c_float_4p1, r, r2, x, z, dn, x_tanh, q) + + // c[4, 32-47] + GELU_TANH_F32_AVX512(c_float_4p2, r, r2, x, z, dn, x_tanh, q) + + // c[4, 48-63] + GELU_TANH_F32_AVX512(c_float_4p3, r, r2, x, z, dn, x_tanh, q) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_ERF_5x64_OPS: + { + __m512 x, r, x_erf; + + // c[0, 0-15] + GELU_ERF_F32_AVX512(c_float_0p0, r, x, x_erf) + + // c[0, 16-31] + GELU_ERF_F32_AVX512(c_float_0p1, r, x, x_erf) + + // c[0, 32-47] + GELU_ERF_F32_AVX512(c_float_0p2, r, x, x_erf) + + // c[0, 48-63] + GELU_ERF_F32_AVX512(c_float_0p3, r, x, x_erf) + + // c[1, 0-15] + GELU_ERF_F32_AVX512(c_float_1p0, r, x, x_erf) + + // c[1, 16-31] + GELU_ERF_F32_AVX512(c_float_1p1, r, x, x_erf) + + // c[1, 32-47] + GELU_ERF_F32_AVX512(c_float_1p2, r, x, x_erf) + + // c[1, 48-63] + GELU_ERF_F32_AVX512(c_float_1p3, r, x, x_erf) + + // c[2, 0-15] + GELU_ERF_F32_AVX512(c_float_2p0, r, x, x_erf) + + // c[2, 16-31] + GELU_ERF_F32_AVX512(c_float_2p1, r, x, x_erf) + + // c[2, 32-47] + GELU_ERF_F32_AVX512(c_float_2p2, r, x, x_erf) + + // c[2, 48-63] + GELU_ERF_F32_AVX512(c_float_2p3, r, x, x_erf) + + // c[3, 0-15] + GELU_ERF_F32_AVX512(c_float_3p0, r, x, x_erf) + + // c[3, 16-31] + GELU_ERF_F32_AVX512(c_float_3p1, r, x, x_erf) + + // c[3, 32-47] + GELU_ERF_F32_AVX512(c_float_3p2, r, x, x_erf) + + // c[3, 48-63] + GELU_ERF_F32_AVX512(c_float_3p3, r, x, x_erf) + + // c[4, 0-15] + GELU_ERF_F32_AVX512(c_float_4p0, r, x, x_erf) + + // c[4, 16-31] + GELU_ERF_F32_AVX512(c_float_4p1, r, x, x_erf) + + // c[4, 32-47] + GELU_ERF_F32_AVX512(c_float_4p2, r, x, x_erf) + + // c[4, 48-63] + GELU_ERF_F32_AVX512(c_float_4p3, r, x, x_erf) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_CLIP_5x64_OPS: + { + __m512 min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 ); + __m512 max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 ); + + // c[0, 0-15] + CLIP_F32_AVX512(c_float_0p0, min, max) + + // c[0, 16-31] + CLIP_F32_AVX512(c_float_0p1, min, max) + + // c[0, 32-47] + CLIP_F32_AVX512(c_float_0p2, min, max) + + // c[0, 48-63] + CLIP_F32_AVX512(c_float_0p3, min, max) + + // c[1, 0-15] + CLIP_F32_AVX512(c_float_1p0, min, max) + + // c[1, 16-31] + CLIP_F32_AVX512(c_float_1p1, min, max) + + // c[1, 32-47] + CLIP_F32_AVX512(c_float_1p2, min, max) + + // c[1, 48-63] + CLIP_F32_AVX512(c_float_1p3, min, max) + + // c[2, 0-15] + CLIP_F32_AVX512(c_float_2p0, min, max) + + // c[2, 16-31] + CLIP_F32_AVX512(c_float_2p1, min, max) + + // c[2, 32-47] + CLIP_F32_AVX512(c_float_2p2, min, max) + + // c[2, 48-63] + CLIP_F32_AVX512(c_float_2p3, min, max) + + // c[3, 0-15] + CLIP_F32_AVX512(c_float_3p0, min, max) + + // c[3, 16-31] + CLIP_F32_AVX512(c_float_3p1, min, max) + + // c[3, 32-47] + CLIP_F32_AVX512(c_float_3p2, min, max) + + // c[3, 48-63] + CLIP_F32_AVX512(c_float_3p3, min, max) + + // c[4, 0-15] + CLIP_F32_AVX512(c_float_4p0, min, max) + + // c[4, 16-31] + CLIP_F32_AVX512(c_float_4p1, min, max) + + // c[4, 32-47] + CLIP_F32_AVX512(c_float_4p2, min, max) + + // c[4, 48-63] + CLIP_F32_AVX512(c_float_4p3, min, max) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_5x64_OPS: + { + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + __m512 zero_point2 = _mm512_setzero_ps(); + __m512 zero_point3 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_maskz_loadu_ps( k0, + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_maskz_loadu_ps( k1, + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_maskz_loadu_ps( k2, + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + selector4 = + _mm512_maskz_loadu_ps( k3, + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( k0, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( k1, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( k2, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( k3, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector2,zero_point1); + + // c[0, 32-47] + SCL_MULRND_F32(c_float_0p2,selector3,zero_point2); + + // c[0, 48-63] + SCL_MULRND_F32(c_float_0p3,selector4,zero_point3); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector1,zero_point0); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[1, 32-47] + SCL_MULRND_F32(c_float_1p2,selector3,zero_point2); + + // c[1, 48-63] + SCL_MULRND_F32(c_float_1p3,selector4,zero_point3); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector1,zero_point0); + + // c[2, 16-31] + SCL_MULRND_F32(c_float_2p1,selector2,zero_point1); + + // c[2, 32-47] + SCL_MULRND_F32(c_float_2p2,selector3,zero_point2); + + // c[2, 48-63] + SCL_MULRND_F32(c_float_2p3,selector4,zero_point3); + + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector1,zero_point0); + + // c[3, 16-31] + SCL_MULRND_F32(c_float_3p1,selector2,zero_point1); + + // c[3, 32-47] + SCL_MULRND_F32(c_float_3p2,selector3,zero_point2); + + // c[3, 48-63] + SCL_MULRND_F32(c_float_3p3,selector4,zero_point3); + + // c[4, 0-15] + SCL_MULRND_F32(c_float_4p0,selector1,zero_point0); + + // c[4, 16-31] + SCL_MULRND_F32(c_float_4p1,selector2,zero_point1); + + // c[4, 32-47] + SCL_MULRND_F32(c_float_4p2,selector3,zero_point2); + + // c[4, 48-63] + SCL_MULRND_F32(c_float_4p3,selector4,zero_point3); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 3 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 2 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 3 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector1,zero_point0); + + // c[0, 32-47] + SCL_MULRND_F32(c_float_0p2,selector1,zero_point0); + + // c[0, 48-63] + SCL_MULRND_F32(c_float_0p3,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector2,zero_point1); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[1, 32-47] + SCL_MULRND_F32(c_float_1p2,selector2,zero_point1); + + // c[1, 48-63] + SCL_MULRND_F32(c_float_1p3,selector2,zero_point1); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector3,zero_point2); + + // c[2, 16-31] + SCL_MULRND_F32(c_float_2p1,selector3,zero_point2); + + // c[2, 32-47] + SCL_MULRND_F32(c_float_2p2,selector3,zero_point2); + + // c[2, 48-63] + SCL_MULRND_F32(c_float_2p3,selector3,zero_point2); + + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector4,zero_point3); + + // c[3, 16-31] + SCL_MULRND_F32(c_float_3p1,selector4,zero_point3); + + // c[3, 32-47] + SCL_MULRND_F32(c_float_3p2,selector4,zero_point3); + + // c[3, 48-63] + SCL_MULRND_F32(c_float_3p3,selector4,zero_point3); + + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 4 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 4 ) ) ); + } + // c[4, 0-15] + SCL_MULRND_F32(c_float_4p0,selector1,zero_point0); + + // c[4, 16-31] + SCL_MULRND_F32(c_float_4p1,selector1,zero_point0); + + // c[4, 32-47] + SCL_MULRND_F32(c_float_4p2,selector1,zero_point0); + + // c[4, 48-63] + SCL_MULRND_F32(c_float_4p3,selector1,zero_point0); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_5x64_OPS: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + // It is expected the post-op matrix arg has the same storage + // order as the output C matrix. + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,0); + + // c[1:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,1); + + // c[2:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,2); + + // c[3:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,3); + + // c[4:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,4); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,0); + + // c[1:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,1); + + // c[2:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,2); + + // c[3:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,3); + + // c[4:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,4); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_5x64_OPS: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + // It is expected the post-op matrix arg has the same storage + // order as the output C matrix. + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,0); + + // c[1:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,1); + + // c[2:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,2); + + // c[3:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,3); + + // c[4:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,4); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,0); + + // c[1:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,1); + + // c[2:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,2); + + // c[3:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,3); + + // c[4:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,4); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_5x64_OPS: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[0, 16-31] + SWISH_F32_AVX512_DEF(c_float_0p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[0, 32-47] + SWISH_F32_AVX512_DEF(c_float_0p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[0, 48-63] + SWISH_F32_AVX512_DEF(c_float_0p3, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 0-15] + SWISH_F32_AVX512_DEF(c_float_1p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 16-31] + SWISH_F32_AVX512_DEF(c_float_1p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 32-47] + SWISH_F32_AVX512_DEF(c_float_1p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 48-63] + SWISH_F32_AVX512_DEF(c_float_1p3, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 0-15] + SWISH_F32_AVX512_DEF(c_float_2p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 16-31] + SWISH_F32_AVX512_DEF(c_float_2p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 32-47] + SWISH_F32_AVX512_DEF(c_float_2p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 48-63] + SWISH_F32_AVX512_DEF(c_float_2p3, selector1, al_in, r, r2, z, dn, ex_out); + + // c[3, 0-15] + SWISH_F32_AVX512_DEF(c_float_3p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[3, 16-31] + SWISH_F32_AVX512_DEF(c_float_3p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[3, 32-47] + SWISH_F32_AVX512_DEF(c_float_3p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[3, 48-63] + SWISH_F32_AVX512_DEF(c_float_3p3, selector1, al_in, r, r2, z, dn, ex_out); + + // c[4, 0-15] + SWISH_F32_AVX512_DEF(c_float_4p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[4, 16-31] + SWISH_F32_AVX512_DEF(c_float_4p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[4, 32-47] + SWISH_F32_AVX512_DEF(c_float_4p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[4, 48-63] + SWISH_F32_AVX512_DEF(c_float_4p3, selector1, al_in, r, r2, z, dn, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_5x64_OPS_DISABLE: + ; + + // Case where the output C matrix is bf16 (downscaled) and this is the + // final write for a given block within C. + if ( post_ops_attr.c_stor_type == BF16 ) + { + // Actually the b matrix is of type bfloat16. However + // in order to reuse this kernel for f32, the output + // matrix type in kernel function signature is set to + // f32 irrespective of original output matrix type. + bfloat16* b_q = ( bfloat16* )b; + dim_t ir = 0; + + // Store the results in downscaled type (bf16 instead of float). + // c[0, 0-15] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_0p0,k0,0,0); + // c[0, 16-31] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_0p1,k1,0,16); + // c[0, 32-47] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_0p2,k2,0,32); + // c[0, 48-63] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_0p3,k3,0,48); + + // c[1, 0-15] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_1p0,k0,1,0); + // c[1, 16-31] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_1p1,k1,1,16); + // c[1, 32-47] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_1p2,k2,1,32); + // c[1, 48-63] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_1p3,k3,1,48); + + // c[2, 0-15] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_2p0,k0,2,0); + // c[2, 16-31] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_2p1,k1,2,16); + // c[2, 32-47] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_2p2,k2,2,32); + // c[2, 48-63] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_2p3,k3,2,48); + + // c[3, 0-15] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_3p0,k0,3,0); + // c[3, 16-31] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_3p1,k1,3,16); + // c[3, 32-47] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_3p2,k2,3,32); + // c[3, 48-63] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_3p3,k3,3,48); + + // c[4, 0-15] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_4p0,k0,4,0); + // c[4, 16-31] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_4p1,k1,4,16); + // c[4, 32-47] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_4p2,k2,4,32); + // c[4, 48-63] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_4p3,k3,4,48); + } + // Case where the output C matrix is float + else + { + // Store the results. + // c[0,0-15] + _mm512_mask_storeu_ps( b + ( rs_b * ( 0 ) ) + + ( cs_b * ( jr + 0 ) ), k0, c_float_0p0 ); + // c[0,16-31] + _mm512_mask_storeu_ps( b + ( rs_b * ( 0 ) ) + + ( cs_b * ( jr + 16 ) ), k1, c_float_0p1 ); + // c[0,32-47] + _mm512_mask_storeu_ps( b + ( rs_b * ( 0 ) ) + + ( cs_b * ( jr + 32 ) ), k2, c_float_0p2 ); + // c[0,48-63] + _mm512_mask_storeu_ps( b + ( rs_b * ( 0 ) ) + + ( cs_b * ( jr + 48 ) ), k3, c_float_0p3 ); + + // c[1,0-15] + _mm512_mask_storeu_ps( b + ( rs_b * ( 1 ) ) + + ( cs_b * ( jr + 0 ) ), k0, c_float_1p0 ); + // c[1,16-31] + _mm512_mask_storeu_ps( b + ( rs_b * ( 1 ) ) + + ( cs_b * ( jr + 16 ) ), k1, c_float_1p1 ); + // c[1,32-47] + _mm512_mask_storeu_ps( b + ( rs_b * ( 1 ) ) + + ( cs_b * ( jr + 32 ) ), k2, c_float_1p2 ); + // c[1,48-63] + _mm512_mask_storeu_ps( b + ( rs_b * ( 1 ) ) + + ( cs_b * ( jr + 48 ) ), k3, c_float_1p3 ); + + // c[2,0-15] + _mm512_mask_storeu_ps( b + ( rs_b * ( 2 ) ) + + ( cs_b * ( jr + 0 ) ), k0, c_float_2p0 ); + // c[2,16-31] + _mm512_mask_storeu_ps( b + ( rs_b * ( 2 ) ) + + ( cs_b * ( jr + 16 ) ), k1, c_float_2p1 ); + // c[2,32-47] + _mm512_mask_storeu_ps( b + ( rs_b * ( 2 ) ) + + ( cs_b * ( jr + 32 ) ), k2, c_float_2p2 ); + // c[2,48-63] + _mm512_mask_storeu_ps( b + ( rs_b * ( 2 ) ) + + ( cs_b * ( jr + 48 ) ), k3, c_float_2p3 ); + + // c[3,0-15] + _mm512_mask_storeu_ps( b + ( rs_b * ( 3 ) ) + + ( cs_b * ( jr + 0 ) ), k0, c_float_3p0 ); + // c[3,16-31] + _mm512_mask_storeu_ps( b + ( rs_b * ( 3 ) ) + + ( cs_b * ( jr + 16 ) ), k1, c_float_3p1 ); + // c[3,32-47] + _mm512_mask_storeu_ps( b + ( rs_b * ( 3 ) ) + + ( cs_b * ( jr + 32 ) ), k2, c_float_3p2 ); + // c[3,48-63] + _mm512_mask_storeu_ps( b + ( rs_b * ( 3 ) ) + + ( cs_b * ( jr + 48 ) ), k3, c_float_3p3 ); + + // c[4,0-15] + _mm512_mask_storeu_ps( b + ( rs_b * ( 4 ) ) + + ( cs_b * ( jr + 0 ) ), k0, c_float_4p0 ); + // c[4,16-31] + _mm512_mask_storeu_ps( b + ( rs_b * ( 4 ) ) + + ( cs_b * ( jr + 16 ) ), k1, c_float_4p1 ); + // c[4,32-47] + _mm512_mask_storeu_ps( b + ( rs_b * ( 4 ) ) + + ( cs_b * ( jr + 32 ) ), k2, c_float_4p2 ); + // c[4,48-63] + _mm512_mask_storeu_ps( b + ( rs_b * ( 4 ) ) + + ( cs_b * ( jr + 48 ) ), k3, c_float_4p3 ); + } + + post_ops_attr.post_op_c_j += NR_L; + } +} + +LPGEMM_ELTWISE_OPS_M_FRINGE_KERNEL(bfloat16,float,bf16of32_4x64) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_4x64_OPS_DISABLE, + &&POST_OPS_BIAS_4x64_OPS, + &&POST_OPS_RELU_4x64_OPS, + &&POST_OPS_RELU_SCALE_4x64_OPS, + &&POST_OPS_GELU_TANH_4x64_OPS, + &&POST_OPS_GELU_ERF_4x64_OPS, + &&POST_OPS_CLIP_4x64_OPS, + &&POST_OPS_DOWNSCALE_4x64_OPS, + &&POST_OPS_MATRIX_ADD_4x64_OPS, + &&POST_OPS_SWISH_4x64_OPS, + &&POST_OPS_MATRIX_MUL_4x64_OPS + }; + dim_t NR = 64; + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + __m512 c_float_0p1 = _mm512_setzero_ps(); + __m512 c_float_0p2 = _mm512_setzero_ps(); + __m512 c_float_0p3 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + __m512 c_float_1p1 = _mm512_setzero_ps(); + __m512 c_float_1p2 = _mm512_setzero_ps(); + __m512 c_float_1p3 = _mm512_setzero_ps(); + + __m512 c_float_2p0 = _mm512_setzero_ps(); + __m512 c_float_2p1 = _mm512_setzero_ps(); + __m512 c_float_2p2 = _mm512_setzero_ps(); + __m512 c_float_2p3 = _mm512_setzero_ps(); + + __m512 c_float_3p0 = _mm512_setzero_ps(); + __m512 c_float_3p1 = _mm512_setzero_ps(); + __m512 c_float_3p2 = _mm512_setzero_ps(); + __m512 c_float_3p3 = _mm512_setzero_ps(); + + __m512 selector1 = _mm512_setzero_ps(); + __m512 selector2 = _mm512_setzero_ps(); + __m512 selector3 = _mm512_setzero_ps(); + __m512 selector4 = _mm512_setzero_ps(); + + __mmask16 k0 = 0xFFFF, k1 = 0xFFFF, k2 = 0xFFFF, k3 = 0xFFFF; + + dim_t NR_L = NR; + for( dim_t jr = 0; jr < n0; jr += NR_L ) + { + dim_t n_left = n0 - jr; + NR_L = bli_min( NR_L, ( n_left >> 4 ) << 4 ); + if( NR_L == 0 ) { NR_L = 16; } + + dim_t nr0 = bli_min( n0 - jr, NR_L ); + if( nr0 == 64 ) + { + // all masks are already set. + // Nothing to modify. + } + else if( nr0 == 48 ) + { + k3 = 0x0; + } + else if( nr0 == 32 ) + { + k2 = k3 = 0x0; + } + else if( nr0 == 16 ) + { + k1 = k2 = k3 = 0; + } + else if( nr0 < 16 ) + { + k0 = (0xFFFF >> (16 - (nr0 & 0x0F))); + k1 = k2 = k3 = 0; + } + + // 1stx64 block. + c_float_0p0 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k0, \ + a + ( rs_a * ( 0 ) ) + ( cs_a * ( jr + 0 ) ) ) ); + c_float_0p1 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k1, \ + a + ( rs_a * ( 0 ) ) + ( cs_a * ( jr + 16 ) ) ) ); + c_float_0p2 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k2, \ + a + ( rs_a * ( 0 ) ) + ( cs_a * ( jr + 32 ) ) ) ); + c_float_0p3 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k3, \ + a + ( rs_a * ( 0 ) ) + ( cs_a * ( jr + 48 ) ) ) ); + + // 2ndx64 block. + c_float_1p0 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k0, \ + a + ( rs_a * ( 1 ) ) + ( cs_a * ( jr + 0 ) ) ) ); + c_float_1p1 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k1, \ + a + ( rs_a * ( 1 ) ) + ( cs_a * ( jr + 16 ) ) ) ); + c_float_1p2 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k2, \ + a + ( rs_a * ( 1 ) ) + ( cs_a * ( jr + 32 ) ) ) ); + c_float_1p3 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k3, \ + a + ( rs_a * ( 1 ) ) + ( cs_a * ( jr + 48 ) ) ) ); + + // 3rdx64 block. + c_float_2p0 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k0, \ + a + ( rs_a * ( 2 ) ) + ( cs_a * ( jr + 0 ) ) ) ); + c_float_2p1 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k1, \ + a + ( rs_a * ( 2 ) ) + ( cs_a * ( jr + 16 ) ) ) ); + c_float_2p2 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k2, \ + a + ( rs_a * ( 2 ) ) + ( cs_a * ( jr + 32 ) ) ) ); + c_float_2p3 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k3, \ + a + ( rs_a * ( 2 ) ) + ( cs_a * ( jr + 48 ) ) ) ); + + // 4thx64 block. + c_float_3p0 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k0, \ + a + ( rs_a * ( 3 ) ) + ( cs_a * ( jr + 0 ) ) ) ); + c_float_3p1 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k1, \ + a + ( rs_a * ( 3 ) ) + ( cs_a * ( jr + 16 ) ) ) ); + c_float_3p2 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k2, \ + a + ( rs_a * ( 3 ) ) + ( cs_a * ( jr + 32 ) ) ) ); + c_float_3p3 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k3, \ + a + ( rs_a * ( 3 ) ) + ( cs_a * ( jr + 48 ) ) ) ); + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP + +POST_OPS_BIAS_4x64_OPS: + { + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_attr.c_stor_type == BF16 ) + { + BF16_F32_BIAS_LOAD(selector1, k0, 0); + BF16_F32_BIAS_LOAD(selector2, k1, 1); + BF16_F32_BIAS_LOAD(selector3, k2, 2); + BF16_F32_BIAS_LOAD(selector4, k3, 3); + } + else + { + selector1 = + _mm512_maskz_loadu_ps( k0, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_maskz_loadu_ps( k1, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_maskz_loadu_ps( k2, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + selector4 = + _mm512_maskz_loadu_ps( k3, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_add_ps( selector4, c_float_0p3 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector3, c_float_1p2 ); + + // c[1,48-63] + c_float_1p3 = _mm512_add_ps( selector4, c_float_1p3 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + + // c[2,48-63] + c_float_2p3 = _mm512_add_ps( selector4, c_float_2p3 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector2, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_add_ps( selector3, c_float_3p2 ); + + // c[3,48-63] + c_float_3p3 = _mm512_add_ps( selector4, c_float_3p3 ); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the bias array will be accessed by + // the ic index, and each bias element corresponds to an + // entire row of the transposed output array, instead of an + // entire column. + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + BF16_F32_BIAS_BCAST(selector2, bias_mask, 1); + BF16_F32_BIAS_BCAST(selector3, bias_mask, 2); + BF16_F32_BIAS_BCAST(selector4, bias_mask, 3); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 3 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_add_ps( selector1, c_float_0p3 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector2, c_float_1p2 ); + + // c[1,48-63] + c_float_1p3 = _mm512_add_ps( selector2, c_float_1p3 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector3, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + + // c[2,48-63] + c_float_2p3 = _mm512_add_ps( selector3, c_float_2p3 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector4, c_float_3p0 ); + + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector4, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_add_ps( selector4, c_float_3p2 ); + + // c[3,48-63] + c_float_3p3 = _mm512_add_ps( selector4, c_float_3p3 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_4x64_OPS: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_max_ps( selector1, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_max_ps( selector1, c_float_0p3 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + c_float_1p1 = _mm512_max_ps( selector1, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_max_ps( selector1, c_float_1p2 ); + + // c[1,48-63] + c_float_1p3 = _mm512_max_ps( selector1, c_float_1p3 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + c_float_2p1 = _mm512_max_ps( selector1, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_max_ps( selector1, c_float_2p2 ); + + // c[2,48-63] + c_float_2p3 = _mm512_max_ps( selector1, c_float_2p3 ); + + // c[3,0-15] + c_float_3p0 = _mm512_max_ps( selector1, c_float_3p0 ); + + // c[3,16-31] + c_float_3p1 = _mm512_max_ps( selector1, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_max_ps( selector1, c_float_3p2 ); + + // c[3,48-63] + c_float_3p3 = _mm512_max_ps( selector1, c_float_3p3 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_4x64_OPS: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_0p2) + + // c[0, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_0p3) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_1p1) + + // c[1, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_1p2) + + // c[1, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_1p3) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_2p1) + + // c[2, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_2p2) + + // c[2, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_2p3) + + // c[3, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_3p0) + + // c[3, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_3p1) + + // c[3, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_3p2) + + // c[3, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_3p3) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_TANH_4x64_OPS: + { + __m512 dn, z, x, r2, r, x_tanh; + __m512i q; + + // c[0, 0-15] + GELU_TANH_F32_AVX512(c_float_0p0, r, r2, x, z, dn, x_tanh, q) + + // c[0, 16-31] + GELU_TANH_F32_AVX512(c_float_0p1, r, r2, x, z, dn, x_tanh, q) + + // c[0, 32-47] + GELU_TANH_F32_AVX512(c_float_0p2, r, r2, x, z, dn, x_tanh, q) + + // c[0, 48-63] + GELU_TANH_F32_AVX512(c_float_0p3, r, r2, x, z, dn, x_tanh, q) + + // c[1, 0-15] + GELU_TANH_F32_AVX512(c_float_1p0, r, r2, x, z, dn, x_tanh, q) + + // c[1, 16-31] + GELU_TANH_F32_AVX512(c_float_1p1, r, r2, x, z, dn, x_tanh, q) + + // c[1, 32-47] + GELU_TANH_F32_AVX512(c_float_1p2, r, r2, x, z, dn, x_tanh, q) + + // c[1, 48-63] + GELU_TANH_F32_AVX512(c_float_1p3, r, r2, x, z, dn, x_tanh, q) + + // c[2, 0-15] + GELU_TANH_F32_AVX512(c_float_2p0, r, r2, x, z, dn, x_tanh, q) + + // c[2, 16-31] + GELU_TANH_F32_AVX512(c_float_2p1, r, r2, x, z, dn, x_tanh, q) + + // c[2, 32-47] + GELU_TANH_F32_AVX512(c_float_2p2, r, r2, x, z, dn, x_tanh, q) + + // c[2, 48-63] + GELU_TANH_F32_AVX512(c_float_2p3, r, r2, x, z, dn, x_tanh, q) + + // c[3, 0-15] + GELU_TANH_F32_AVX512(c_float_3p0, r, r2, x, z, dn, x_tanh, q) + + // c[3, 16-31] + GELU_TANH_F32_AVX512(c_float_3p1, r, r2, x, z, dn, x_tanh, q) + + // c[3, 32-47] + GELU_TANH_F32_AVX512(c_float_3p2, r, r2, x, z, dn, x_tanh, q) + + // c[3, 48-63] + GELU_TANH_F32_AVX512(c_float_3p3, r, r2, x, z, dn, x_tanh, q) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_ERF_4x64_OPS: + { + __m512 x, r, x_erf; + + // c[0, 0-15] + GELU_ERF_F32_AVX512(c_float_0p0, r, x, x_erf) + + // c[0, 16-31] + GELU_ERF_F32_AVX512(c_float_0p1, r, x, x_erf) + + // c[0, 32-47] + GELU_ERF_F32_AVX512(c_float_0p2, r, x, x_erf) + + // c[0, 48-63] + GELU_ERF_F32_AVX512(c_float_0p3, r, x, x_erf) + + // c[1, 0-15] + GELU_ERF_F32_AVX512(c_float_1p0, r, x, x_erf) + + // c[1, 16-31] + GELU_ERF_F32_AVX512(c_float_1p1, r, x, x_erf) + + // c[1, 32-47] + GELU_ERF_F32_AVX512(c_float_1p2, r, x, x_erf) + + // c[1, 48-63] + GELU_ERF_F32_AVX512(c_float_1p3, r, x, x_erf) + + // c[2, 0-15] + GELU_ERF_F32_AVX512(c_float_2p0, r, x, x_erf) + + // c[2, 16-31] + GELU_ERF_F32_AVX512(c_float_2p1, r, x, x_erf) + + // c[2, 32-47] + GELU_ERF_F32_AVX512(c_float_2p2, r, x, x_erf) + + // c[2, 48-63] + GELU_ERF_F32_AVX512(c_float_2p3, r, x, x_erf) + + // c[3, 0-15] + GELU_ERF_F32_AVX512(c_float_3p0, r, x, x_erf) + + // c[3, 16-31] + GELU_ERF_F32_AVX512(c_float_3p1, r, x, x_erf) + + // c[3, 32-47] + GELU_ERF_F32_AVX512(c_float_3p2, r, x, x_erf) + + // c[3, 48-63] + GELU_ERF_F32_AVX512(c_float_3p3, r, x, x_erf) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_CLIP_4x64_OPS: + { + __m512 min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 ); + __m512 max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 ); + + // c[0, 0-15] + CLIP_F32_AVX512(c_float_0p0, min, max) + + // c[0, 16-31] + CLIP_F32_AVX512(c_float_0p1, min, max) + + // c[0, 32-47] + CLIP_F32_AVX512(c_float_0p2, min, max) + + // c[0, 48-63] + CLIP_F32_AVX512(c_float_0p3, min, max) + + // c[1, 0-15] + CLIP_F32_AVX512(c_float_1p0, min, max) + + // c[1, 16-31] + CLIP_F32_AVX512(c_float_1p1, min, max) + + // c[1, 32-47] + CLIP_F32_AVX512(c_float_1p2, min, max) + + // c[1, 48-63] + CLIP_F32_AVX512(c_float_1p3, min, max) + + // c[2, 0-15] + CLIP_F32_AVX512(c_float_2p0, min, max) + + // c[2, 16-31] + CLIP_F32_AVX512(c_float_2p1, min, max) + + // c[2, 32-47] + CLIP_F32_AVX512(c_float_2p2, min, max) + + // c[2, 48-63] + CLIP_F32_AVX512(c_float_2p3, min, max) + + // c[3, 0-15] + CLIP_F32_AVX512(c_float_3p0, min, max) + + // c[3, 16-31] + CLIP_F32_AVX512(c_float_3p1, min, max) + + // c[3, 32-47] + CLIP_F32_AVX512(c_float_3p2, min, max) + + // c[3, 48-63] + CLIP_F32_AVX512(c_float_3p3, min, max) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_4x64_OPS: + { + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + __m512 zero_point2 = _mm512_setzero_ps(); + __m512 zero_point3 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_maskz_loadu_ps( k0, + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_maskz_loadu_ps( k1, + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_maskz_loadu_ps( k2, + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + selector4 = + _mm512_maskz_loadu_ps( k3, + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( k0, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( k1, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( k2, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( k3, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector2,zero_point1); + + // c[0, 32-47] + SCL_MULRND_F32(c_float_0p2,selector3,zero_point2); + + // c[0, 48-63] + SCL_MULRND_F32(c_float_0p3,selector4,zero_point3); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector1,zero_point0); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[1, 32-47] + SCL_MULRND_F32(c_float_1p2,selector3,zero_point2); + + // c[1, 48-63] + SCL_MULRND_F32(c_float_1p3,selector4,zero_point3); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector1,zero_point0); + + // c[2, 16-31] + SCL_MULRND_F32(c_float_2p1,selector2,zero_point1); + + // c[2, 32-47] + SCL_MULRND_F32(c_float_2p2,selector3,zero_point2); + + // c[2, 48-63] + SCL_MULRND_F32(c_float_2p3,selector4,zero_point3); + + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector1,zero_point0); + + // c[3, 16-31] + SCL_MULRND_F32(c_float_3p1,selector2,zero_point1); + + // c[3, 32-47] + SCL_MULRND_F32(c_float_3p2,selector3,zero_point2); + + // c[3, 48-63] + SCL_MULRND_F32(c_float_3p3,selector4,zero_point3); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 3 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 2 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 3 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector1,zero_point0); + + // c[0, 32-47] + SCL_MULRND_F32(c_float_0p2,selector1,zero_point0); + + // c[0, 48-63] + SCL_MULRND_F32(c_float_0p3,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector2,zero_point1); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[1, 32-47] + SCL_MULRND_F32(c_float_1p2,selector2,zero_point1); + + // c[1, 48-63] + SCL_MULRND_F32(c_float_1p3,selector2,zero_point1); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector3,zero_point2); + + // c[2, 16-31] + SCL_MULRND_F32(c_float_2p1,selector3,zero_point2); + + // c[2, 32-47] + SCL_MULRND_F32(c_float_2p2,selector3,zero_point2); + + // c[2, 48-63] + SCL_MULRND_F32(c_float_2p3,selector3,zero_point2); + + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector4,zero_point3); + + // c[3, 16-31] + SCL_MULRND_F32(c_float_3p1,selector4,zero_point3); + + // c[3, 32-47] + SCL_MULRND_F32(c_float_3p2,selector4,zero_point3); + + // c[3, 48-63] + SCL_MULRND_F32(c_float_3p3,selector4,zero_point3); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_4x64_OPS: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + // It is expected the post-op matrix arg has the same storage + // order as the output C matrix. + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,0); + + // c[1:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,1); + + // c[2:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,2); + + // c[3:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,3); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,0); + + // c[1:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,1); + + // c[2:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,2); + + // c[3:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,3); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_4x64_OPS: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + // It is expected the post-op matrix arg has the same storage + // order as the output C matrix. + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,0); + + // c[1:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,1); + + // c[2:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,2); + + // c[3:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,3); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,0); + + // c[1:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,1); + + // c[2:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,2); + + // c[3:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,3); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_4x64_OPS: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[0, 16-31] + SWISH_F32_AVX512_DEF(c_float_0p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[0, 32-47] + SWISH_F32_AVX512_DEF(c_float_0p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[0, 48-63] + SWISH_F32_AVX512_DEF(c_float_0p3, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 0-15] + SWISH_F32_AVX512_DEF(c_float_1p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 16-31] + SWISH_F32_AVX512_DEF(c_float_1p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 32-47] + SWISH_F32_AVX512_DEF(c_float_1p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 48-63] + SWISH_F32_AVX512_DEF(c_float_1p3, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 0-15] + SWISH_F32_AVX512_DEF(c_float_2p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 16-31] + SWISH_F32_AVX512_DEF(c_float_2p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 32-47] + SWISH_F32_AVX512_DEF(c_float_2p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 48-63] + SWISH_F32_AVX512_DEF(c_float_2p3, selector1, al_in, r, r2, z, dn, ex_out); + + // c[3, 0-15] + SWISH_F32_AVX512_DEF(c_float_3p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[3, 16-31] + SWISH_F32_AVX512_DEF(c_float_3p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[3, 32-47] + SWISH_F32_AVX512_DEF(c_float_3p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[3, 48-63] + SWISH_F32_AVX512_DEF(c_float_3p3, selector1, al_in, r, r2, z, dn, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_4x64_OPS_DISABLE: + ; + + // Case where the output C matrix is bf16 (downscaled) and this is the + // final write for a given block within C. + if ( post_ops_attr.c_stor_type == BF16 ) + { + // Actually the b matrix is of type bfloat16. However + // in order to reuse this kernel for f32, the output + // matrix type in kernel function signature is set to + // f32 irrespective of original output matrix type. + bfloat16* b_q = ( bfloat16* )b; + dim_t ir = 0; + + // Store the results in downscaled type (bf16 instead of float). + // c[0, 0-15] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_0p0,k0,0,0); + // c[0, 16-31] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_0p1,k1,0,16); + // c[0, 32-47] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_0p2,k2,0,32); + // c[0, 48-63] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_0p3,k3,0,48); + + // c[1, 0-15] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_1p0,k0,1,0); + // c[1, 16-31] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_1p1,k1,1,16); + // c[1, 32-47] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_1p2,k2,1,32); + // c[1, 48-63] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_1p3,k3,1,48); + + // c[2, 0-15] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_2p0,k0,2,0); + // c[2, 16-31] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_2p1,k1,2,16); + // c[2, 32-47] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_2p2,k2,2,32); + // c[2, 48-63] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_2p3,k3,2,48); + + // c[3, 0-15] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_3p0,k0,3,0); + // c[3, 16-31] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_3p1,k1,3,16); + // c[3, 32-47] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_3p2,k2,3,32); + // c[3, 48-63] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_3p3,k3,3,48); + } + // Case where the output C matrix is float + else + { + // Store the results. + // c[0,0-15] + _mm512_mask_storeu_ps( b + ( rs_b * ( 0 ) ) + + ( cs_b * ( jr + 0 ) ), k0, c_float_0p0 ); + // c[0,16-31] + _mm512_mask_storeu_ps( b + ( rs_b * ( 0 ) ) + + ( cs_b * ( jr + 16 ) ), k1, c_float_0p1 ); + // c[0,32-47] + _mm512_mask_storeu_ps( b + ( rs_b * ( 0 ) ) + + ( cs_b * ( jr + 32 ) ), k2, c_float_0p2 ); + // c[0,48-63] + _mm512_mask_storeu_ps( b + ( rs_b * ( 0 ) ) + + ( cs_b * ( jr + 48 ) ), k3, c_float_0p3 ); + + // c[1,0-15] + _mm512_mask_storeu_ps( b + ( rs_b * ( 1 ) ) + + ( cs_b * ( jr + 0 ) ), k0, c_float_1p0 ); + // c[1,16-31] + _mm512_mask_storeu_ps( b + ( rs_b * ( 1 ) ) + + ( cs_b * ( jr + 16 ) ), k1, c_float_1p1 ); + // c[1,32-47] + _mm512_mask_storeu_ps( b + ( rs_b * ( 1 ) ) + + ( cs_b * ( jr + 32 ) ), k2, c_float_1p2 ); + // c[1,48-63] + _mm512_mask_storeu_ps( b + ( rs_b * ( 1 ) ) + + ( cs_b * ( jr + 48 ) ), k3, c_float_1p3 ); + + // c[2,0-15] + _mm512_mask_storeu_ps( b + ( rs_b * ( 2 ) ) + + ( cs_b * ( jr + 0 ) ), k0, c_float_2p0 ); + // c[2,16-31] + _mm512_mask_storeu_ps( b + ( rs_b * ( 2 ) ) + + ( cs_b * ( jr + 16 ) ), k1, c_float_2p1 ); + // c[2,32-47] + _mm512_mask_storeu_ps( b + ( rs_b * ( 2 ) ) + + ( cs_b * ( jr + 32 ) ), k2, c_float_2p2 ); + // c[2,48-63] + _mm512_mask_storeu_ps( b + ( rs_b * ( 2 ) ) + + ( cs_b * ( jr + 48 ) ), k3, c_float_2p3 ); + + // c[3,0-15] + _mm512_mask_storeu_ps( b + ( rs_b * ( 3 ) ) + + ( cs_b * ( jr + 0 ) ), k0, c_float_3p0 ); + // c[3,16-31] + _mm512_mask_storeu_ps( b + ( rs_b * ( 3 ) ) + + ( cs_b * ( jr + 16 ) ), k1, c_float_3p1 ); + // c[3,32-47] + _mm512_mask_storeu_ps( b + ( rs_b * ( 3 ) ) + + ( cs_b * ( jr + 32 ) ), k2, c_float_3p2 ); + // c[3,48-63] + _mm512_mask_storeu_ps( b + ( rs_b * ( 3 ) ) + + ( cs_b * ( jr + 48 ) ), k3, c_float_3p3 ); + } + + post_ops_attr.post_op_c_j += NR_L; + } +} + +LPGEMM_ELTWISE_OPS_M_FRINGE_KERNEL(bfloat16,float,bf16of32_3x64) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_3x64_OPS_DISABLE, + &&POST_OPS_BIAS_3x64_OPS, + &&POST_OPS_RELU_3x64_OPS, + &&POST_OPS_RELU_SCALE_3x64_OPS, + &&POST_OPS_GELU_TANH_3x64_OPS, + &&POST_OPS_GELU_ERF_3x64_OPS, + &&POST_OPS_CLIP_3x64_OPS, + &&POST_OPS_DOWNSCALE_3x64_OPS, + &&POST_OPS_MATRIX_ADD_3x64_OPS, + &&POST_OPS_SWISH_3x64_OPS, + &&POST_OPS_MATRIX_MUL_3x64_OPS + }; + dim_t NR = 64; + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + __m512 c_float_0p1 = _mm512_setzero_ps(); + __m512 c_float_0p2 = _mm512_setzero_ps(); + __m512 c_float_0p3 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + __m512 c_float_1p1 = _mm512_setzero_ps(); + __m512 c_float_1p2 = _mm512_setzero_ps(); + __m512 c_float_1p3 = _mm512_setzero_ps(); + + __m512 c_float_2p0 = _mm512_setzero_ps(); + __m512 c_float_2p1 = _mm512_setzero_ps(); + __m512 c_float_2p2 = _mm512_setzero_ps(); + __m512 c_float_2p3 = _mm512_setzero_ps(); + + __m512 selector1 = _mm512_setzero_ps(); + __m512 selector2 = _mm512_setzero_ps(); + __m512 selector3 = _mm512_setzero_ps(); + __m512 selector4 = _mm512_setzero_ps(); + + __mmask16 k0 = 0xFFFF, k1 = 0xFFFF, k2 = 0xFFFF, k3 = 0xFFFF; + + dim_t NR_L = NR; + for( dim_t jr = 0; jr < n0; jr += NR_L ) + { + dim_t n_left = n0 - jr; + NR_L = bli_min( NR_L, ( n_left >> 4 ) << 4 ); + if( NR_L == 0 ) { NR_L = 16; } + + dim_t nr0 = bli_min( n0 - jr, NR_L ); + if( nr0 == 64 ) + { + // all masks are already set. + // Nothing to modify. + } + else if( nr0 == 48 ) + { + k3 = 0x0; + } + else if( nr0 == 32 ) + { + k2 = k3 = 0x0; + } + else if( nr0 == 16 ) + { + k1 = k2 = k3 = 0; + } + else if( nr0 < 16 ) + { + k0 = (0xFFFF >> (16 - (nr0 & 0x0F))); + k1 = k2 = k3 = 0; + } + + // 1stx64 block. + c_float_0p0 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k0, \ + a + ( rs_a * ( 0 ) ) + ( cs_a * ( jr + 0 ) ) ) ); + c_float_0p1 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k1, \ + a + ( rs_a * ( 0 ) ) + ( cs_a * ( jr + 16 ) ) ) ); + c_float_0p2 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k2, \ + a + ( rs_a * ( 0 ) ) + ( cs_a * ( jr + 32 ) ) ) ); + c_float_0p3 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k3, \ + a + ( rs_a * ( 0 ) ) + ( cs_a * ( jr + 48 ) ) ) ); + + // 2ndx64 block. + c_float_1p0 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k0, \ + a + ( rs_a * ( 1 ) ) + ( cs_a * ( jr + 0 ) ) ) ); + c_float_1p1 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k1, \ + a + ( rs_a * ( 1 ) ) + ( cs_a * ( jr + 16 ) ) ) ); + c_float_1p2 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k2, \ + a + ( rs_a * ( 1 ) ) + ( cs_a * ( jr + 32 ) ) ) ); + c_float_1p3 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k3, \ + a + ( rs_a * ( 1 ) ) + ( cs_a * ( jr + 48 ) ) ) ); + + // 3rdx64 block. + c_float_2p0 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k0, \ + a + ( rs_a * ( 2 ) ) + ( cs_a * ( jr + 0 ) ) ) ); + c_float_2p1 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k1, \ + a + ( rs_a * ( 2 ) ) + ( cs_a * ( jr + 16 ) ) ) ); + c_float_2p2 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k2, \ + a + ( rs_a * ( 2 ) ) + ( cs_a * ( jr + 32 ) ) ) ); + c_float_2p3 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k3, \ + a + ( rs_a * ( 2 ) ) + ( cs_a * ( jr + 48 ) ) ) ); + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP + +POST_OPS_BIAS_3x64_OPS: + { + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_attr.c_stor_type == BF16 ) + { + BF16_F32_BIAS_LOAD(selector1, k0, 0); + BF16_F32_BIAS_LOAD(selector2, k1, 1); + BF16_F32_BIAS_LOAD(selector3, k2, 2); + BF16_F32_BIAS_LOAD(selector4, k3, 3); + } + else + { + selector1 = + _mm512_maskz_loadu_ps( k0, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_maskz_loadu_ps( k1, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_maskz_loadu_ps( k2, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + selector4 = + _mm512_maskz_loadu_ps( k3, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_add_ps( selector4, c_float_0p3 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector3, c_float_1p2 ); + + // c[1,48-63] + c_float_1p3 = _mm512_add_ps( selector4, c_float_1p3 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + + // c[2,48-63] + c_float_2p3 = _mm512_add_ps( selector4, c_float_2p3 ); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the bias array will be accessed by + // the ic index, and each bias element corresponds to an + // entire row of the transposed output array, instead of an + // entire column. + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + BF16_F32_BIAS_BCAST(selector2, bias_mask, 1); + BF16_F32_BIAS_BCAST(selector3, bias_mask, 2); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 2 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_add_ps( selector1, c_float_0p3 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector2, c_float_1p2 ); + + // c[1,48-63] + c_float_1p3 = _mm512_add_ps( selector2, c_float_1p3 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector3, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + + // c[2,48-63] + c_float_2p3 = _mm512_add_ps( selector3, c_float_2p3 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_3x64_OPS: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_max_ps( selector1, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_max_ps( selector1, c_float_0p3 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + c_float_1p1 = _mm512_max_ps( selector1, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_max_ps( selector1, c_float_1p2 ); + + // c[1,48-63] + c_float_1p3 = _mm512_max_ps( selector1, c_float_1p3 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + c_float_2p1 = _mm512_max_ps( selector1, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_max_ps( selector1, c_float_2p2 ); + + // c[2,48-63] + c_float_2p3 = _mm512_max_ps( selector1, c_float_2p3 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_3x64_OPS: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_0p2) + + // c[0, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_0p3) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_1p1) + + // c[1, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_1p2) + + // c[1, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_1p3) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_2p1) + + // c[2, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_2p2) + + // c[2, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_2p3) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_TANH_3x64_OPS: + { + __m512 dn, z, x, r2, r, x_tanh; + __m512i q; + + // c[0, 0-15] + GELU_TANH_F32_AVX512(c_float_0p0, r, r2, x, z, dn, x_tanh, q) + + // c[0, 16-31] + GELU_TANH_F32_AVX512(c_float_0p1, r, r2, x, z, dn, x_tanh, q) + + // c[0, 32-47] + GELU_TANH_F32_AVX512(c_float_0p2, r, r2, x, z, dn, x_tanh, q) + + // c[0, 48-63] + GELU_TANH_F32_AVX512(c_float_0p3, r, r2, x, z, dn, x_tanh, q) + + // c[1, 0-15] + GELU_TANH_F32_AVX512(c_float_1p0, r, r2, x, z, dn, x_tanh, q) + + // c[1, 16-31] + GELU_TANH_F32_AVX512(c_float_1p1, r, r2, x, z, dn, x_tanh, q) + + // c[1, 32-47] + GELU_TANH_F32_AVX512(c_float_1p2, r, r2, x, z, dn, x_tanh, q) + + // c[1, 48-63] + GELU_TANH_F32_AVX512(c_float_1p3, r, r2, x, z, dn, x_tanh, q) + + // c[2, 0-15] + GELU_TANH_F32_AVX512(c_float_2p0, r, r2, x, z, dn, x_tanh, q) + + // c[2, 16-31] + GELU_TANH_F32_AVX512(c_float_2p1, r, r2, x, z, dn, x_tanh, q) + + // c[2, 32-47] + GELU_TANH_F32_AVX512(c_float_2p2, r, r2, x, z, dn, x_tanh, q) + + // c[2, 48-63] + GELU_TANH_F32_AVX512(c_float_2p3, r, r2, x, z, dn, x_tanh, q) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_ERF_3x64_OPS: + { + __m512 x, r, x_erf; + + // c[0, 0-15] + GELU_ERF_F32_AVX512(c_float_0p0, r, x, x_erf) + + // c[0, 16-31] + GELU_ERF_F32_AVX512(c_float_0p1, r, x, x_erf) + + // c[0, 32-47] + GELU_ERF_F32_AVX512(c_float_0p2, r, x, x_erf) + + // c[0, 48-63] + GELU_ERF_F32_AVX512(c_float_0p3, r, x, x_erf) + + // c[1, 0-15] + GELU_ERF_F32_AVX512(c_float_1p0, r, x, x_erf) + + // c[1, 16-31] + GELU_ERF_F32_AVX512(c_float_1p1, r, x, x_erf) + + // c[1, 32-47] + GELU_ERF_F32_AVX512(c_float_1p2, r, x, x_erf) + + // c[1, 48-63] + GELU_ERF_F32_AVX512(c_float_1p3, r, x, x_erf) + + // c[2, 0-15] + GELU_ERF_F32_AVX512(c_float_2p0, r, x, x_erf) + + // c[2, 16-31] + GELU_ERF_F32_AVX512(c_float_2p1, r, x, x_erf) + + // c[2, 32-47] + GELU_ERF_F32_AVX512(c_float_2p2, r, x, x_erf) + + // c[2, 48-63] + GELU_ERF_F32_AVX512(c_float_2p3, r, x, x_erf) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_CLIP_3x64_OPS: + { + __m512 min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 ); + __m512 max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 ); + + // c[0, 0-15] + CLIP_F32_AVX512(c_float_0p0, min, max) + + // c[0, 16-31] + CLIP_F32_AVX512(c_float_0p1, min, max) + + // c[0, 32-47] + CLIP_F32_AVX512(c_float_0p2, min, max) + + // c[0, 48-63] + CLIP_F32_AVX512(c_float_0p3, min, max) + + // c[1, 0-15] + CLIP_F32_AVX512(c_float_1p0, min, max) + + // c[1, 16-31] + CLIP_F32_AVX512(c_float_1p1, min, max) + + // c[1, 32-47] + CLIP_F32_AVX512(c_float_1p2, min, max) + + // c[1, 48-63] + CLIP_F32_AVX512(c_float_1p3, min, max) + + // c[2, 0-15] + CLIP_F32_AVX512(c_float_2p0, min, max) + + // c[2, 16-31] + CLIP_F32_AVX512(c_float_2p1, min, max) + + // c[2, 32-47] + CLIP_F32_AVX512(c_float_2p2, min, max) + + // c[2, 48-63] + CLIP_F32_AVX512(c_float_2p3, min, max) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_3x64_OPS: + { + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + __m512 zero_point2 = _mm512_setzero_ps(); + __m512 zero_point3 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_maskz_loadu_ps( k0, + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_maskz_loadu_ps( k1, + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_maskz_loadu_ps( k2, + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + selector4 = + _mm512_maskz_loadu_ps( k3, + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( k0, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( k1, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( k2, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( k3, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector2,zero_point1); + + // c[0, 32-47] + SCL_MULRND_F32(c_float_0p2,selector3,zero_point2); + + // c[0, 48-63] + SCL_MULRND_F32(c_float_0p3,selector4,zero_point3); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector1,zero_point0); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[1, 32-47] + SCL_MULRND_F32(c_float_1p2,selector3,zero_point2); + + // c[1, 48-63] + SCL_MULRND_F32(c_float_1p3,selector4,zero_point3); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector1,zero_point0); + + // c[2, 16-31] + SCL_MULRND_F32(c_float_2p1,selector2,zero_point1); + + // c[2, 32-47] + SCL_MULRND_F32(c_float_2p2,selector3,zero_point2); + + // c[2, 48-63] + SCL_MULRND_F32(c_float_2p3,selector4,zero_point3); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 2 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 2 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector1,zero_point0); + + // c[0, 32-47] + SCL_MULRND_F32(c_float_0p2,selector1,zero_point0); + + // c[0, 48-63] + SCL_MULRND_F32(c_float_0p3,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector2,zero_point1); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[1, 32-47] + SCL_MULRND_F32(c_float_1p2,selector2,zero_point1); + + // c[1, 48-63] + SCL_MULRND_F32(c_float_1p3,selector2,zero_point1); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector3,zero_point2); + + // c[2, 16-31] + SCL_MULRND_F32(c_float_2p1,selector3,zero_point2); + + // c[2, 32-47] + SCL_MULRND_F32(c_float_2p2,selector3,zero_point2); + + // c[2, 48-63] + SCL_MULRND_F32(c_float_2p3,selector3,zero_point2); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_3x64_OPS: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + // It is expected the post-op matrix arg has the same storage + // order as the output C matrix. + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,0); + + // c[1:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,1); + + // c[2:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,2); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,0); + + // c[1:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,1); + + // c[2:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,2); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_3x64_OPS: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + // It is expected the post-op matrix arg has the same storage + // order as the output C matrix. + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,0); + + // c[1:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,1); + + // c[2:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,2); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,0); + + // c[1:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,1); + + // c[2:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,2); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_3x64_OPS: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[0, 16-31] + SWISH_F32_AVX512_DEF(c_float_0p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[0, 32-47] + SWISH_F32_AVX512_DEF(c_float_0p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[0, 48-63] + SWISH_F32_AVX512_DEF(c_float_0p3, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 0-15] + SWISH_F32_AVX512_DEF(c_float_1p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 16-31] + SWISH_F32_AVX512_DEF(c_float_1p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 32-47] + SWISH_F32_AVX512_DEF(c_float_1p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 48-63] + SWISH_F32_AVX512_DEF(c_float_1p3, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 0-15] + SWISH_F32_AVX512_DEF(c_float_2p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 16-31] + SWISH_F32_AVX512_DEF(c_float_2p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 32-47] + SWISH_F32_AVX512_DEF(c_float_2p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 48-63] + SWISH_F32_AVX512_DEF(c_float_2p3, selector1, al_in, r, r2, z, dn, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_3x64_OPS_DISABLE: + ; + + // Case where the output C matrix is bf16 (downscaled) and this is the + // final write for a given block within C. + if ( post_ops_attr.c_stor_type == BF16 ) + { + // Actually the b matrix is of type bfloat16. However + // in order to reuse this kernel for f32, the output + // matrix type in kernel function signature is set to + // f32 irrespective of original output matrix type. + bfloat16* b_q = ( bfloat16* )b; + dim_t ir = 0; + + // Store the results in downscaled type (bf16 instead of float). + // c[0, 0-15] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_0p0,k0,0,0); + // c[0, 16-31] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_0p1,k1,0,16); + // c[0, 32-47] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_0p2,k2,0,32); + // c[0, 48-63] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_0p3,k3,0,48); + + // c[1, 0-15] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_1p0,k0,1,0); + // c[1, 16-31] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_1p1,k1,1,16); + // c[1, 32-47] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_1p2,k2,1,32); + // c[1, 48-63] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_1p3,k3,1,48); + + // c[2, 0-15] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_2p0,k0,2,0); + // c[2, 16-31] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_2p1,k1,2,16); + // c[2, 32-47] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_2p2,k2,2,32); + // c[2, 48-63] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_2p3,k3,2,48); + } + // Case where the output C matrix is float + else + { + // Store the results. + // c[0,0-15] + _mm512_mask_storeu_ps( b + ( rs_b * ( 0 ) ) + + ( cs_b * ( jr + 0 ) ), k0, c_float_0p0 ); + // c[0,16-31] + _mm512_mask_storeu_ps( b + ( rs_b * ( 0 ) ) + + ( cs_b * ( jr + 16 ) ), k1, c_float_0p1 ); + // c[0,32-47] + _mm512_mask_storeu_ps( b + ( rs_b * ( 0 ) ) + + ( cs_b * ( jr + 32 ) ), k2, c_float_0p2 ); + // c[0,48-63] + _mm512_mask_storeu_ps( b + ( rs_b * ( 0 ) ) + + ( cs_b * ( jr + 48 ) ), k3, c_float_0p3 ); + + // c[1,0-15] + _mm512_mask_storeu_ps( b + ( rs_b * ( 1 ) ) + + ( cs_b * ( jr + 0 ) ), k0, c_float_1p0 ); + // c[1,16-31] + _mm512_mask_storeu_ps( b + ( rs_b * ( 1 ) ) + + ( cs_b * ( jr + 16 ) ), k1, c_float_1p1 ); + // c[1,32-47] + _mm512_mask_storeu_ps( b + ( rs_b * ( 1 ) ) + + ( cs_b * ( jr + 32 ) ), k2, c_float_1p2 ); + // c[1,48-63] + _mm512_mask_storeu_ps( b + ( rs_b * ( 1 ) ) + + ( cs_b * ( jr + 48 ) ), k3, c_float_1p3 ); + + // c[2,0-15] + _mm512_mask_storeu_ps( b + ( rs_b * ( 2 ) ) + + ( cs_b * ( jr + 0 ) ), k0, c_float_2p0 ); + // c[2,16-31] + _mm512_mask_storeu_ps( b + ( rs_b * ( 2 ) ) + + ( cs_b * ( jr + 16 ) ), k1, c_float_2p1 ); + // c[2,32-47] + _mm512_mask_storeu_ps( b + ( rs_b * ( 2 ) ) + + ( cs_b * ( jr + 32 ) ), k2, c_float_2p2 ); + // c[2,48-63] + _mm512_mask_storeu_ps( b + ( rs_b * ( 2 ) ) + + ( cs_b * ( jr + 48 ) ), k3, c_float_2p3 ); + } + + post_ops_attr.post_op_c_j += NR_L; + } +} + +LPGEMM_ELTWISE_OPS_M_FRINGE_KERNEL(bfloat16,float,bf16of32_2x64) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_2x64_OPS_DISABLE, + &&POST_OPS_BIAS_2x64_OPS, + &&POST_OPS_RELU_2x64_OPS, + &&POST_OPS_RELU_SCALE_2x64_OPS, + &&POST_OPS_GELU_TANH_2x64_OPS, + &&POST_OPS_GELU_ERF_2x64_OPS, + &&POST_OPS_CLIP_2x64_OPS, + &&POST_OPS_DOWNSCALE_2x64_OPS, + &&POST_OPS_MATRIX_ADD_2x64_OPS, + &&POST_OPS_SWISH_2x64_OPS, + &&POST_OPS_MATRIX_MUL_2x64_OPS + }; + dim_t NR = 64; + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + __m512 c_float_0p1 = _mm512_setzero_ps(); + __m512 c_float_0p2 = _mm512_setzero_ps(); + __m512 c_float_0p3 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + __m512 c_float_1p1 = _mm512_setzero_ps(); + __m512 c_float_1p2 = _mm512_setzero_ps(); + __m512 c_float_1p3 = _mm512_setzero_ps(); + + __m512 selector1 = _mm512_setzero_ps(); + __m512 selector2 = _mm512_setzero_ps(); + __m512 selector3 = _mm512_setzero_ps(); + __m512 selector4 = _mm512_setzero_ps(); + + __mmask16 k0 = 0xFFFF, k1 = 0xFFFF, k2 = 0xFFFF, k3 = 0xFFFF; + + dim_t NR_L = NR; + for( dim_t jr = 0; jr < n0; jr += NR_L ) + { + dim_t n_left = n0 - jr; + NR_L = bli_min( NR_L, ( n_left >> 4 ) << 4 ); + if( NR_L == 0 ) { NR_L = 16; } + + dim_t nr0 = bli_min( n0 - jr, NR_L ); + if( nr0 == 64 ) + { + // all masks are already set. + // Nothing to modify. + } + else if( nr0 == 48 ) + { + k3 = 0x0; + } + else if( nr0 == 32 ) + { + k2 = k3 = 0x0; + } + else if( nr0 == 16 ) + { + k1 = k2 = k3 = 0; + } + else if( nr0 < 16 ) + { + k0 = (0xFFFF >> (16 - (nr0 & 0x0F))); + k1 = k2 = k3 = 0; + } + + // 1stx64 block. + c_float_0p0 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k0, \ + a + ( rs_a * ( 0 ) ) + ( cs_a * ( jr + 0 ) ) ) ); + c_float_0p1 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k1, \ + a + ( rs_a * ( 0 ) ) + ( cs_a * ( jr + 16 ) ) ) ); + c_float_0p2 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k2, \ + a + ( rs_a * ( 0 ) ) + ( cs_a * ( jr + 32 ) ) ) ); + c_float_0p3 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k3, \ + a + ( rs_a * ( 0 ) ) + ( cs_a * ( jr + 48 ) ) ) ); + + // 2ndx64 block. + c_float_1p0 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k0, \ + a + ( rs_a * ( 1 ) ) + ( cs_a * ( jr + 0 ) ) ) ); + c_float_1p1 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k1, \ + a + ( rs_a * ( 1 ) ) + ( cs_a * ( jr + 16 ) ) ) ); + c_float_1p2 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k2, \ + a + ( rs_a * ( 1 ) ) + ( cs_a * ( jr + 32 ) ) ) ); + c_float_1p3 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k3, \ + a + ( rs_a * ( 1 ) ) + ( cs_a * ( jr + 48 ) ) ) ); + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP + +POST_OPS_BIAS_2x64_OPS: + { + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_attr.c_stor_type == BF16 ) + { + BF16_F32_BIAS_LOAD(selector1, k0, 0); + BF16_F32_BIAS_LOAD(selector2, k1, 1); + BF16_F32_BIAS_LOAD(selector3, k2, 2); + BF16_F32_BIAS_LOAD(selector4, k3, 3); + } + else + { + selector1 = + _mm512_maskz_loadu_ps( k0, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_maskz_loadu_ps( k1, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_maskz_loadu_ps( k2, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + selector4 = + _mm512_maskz_loadu_ps( k3, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_add_ps( selector4, c_float_0p3 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector3, c_float_1p2 ); + + // c[1,48-63] + c_float_1p3 = _mm512_add_ps( selector4, c_float_1p3 ); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the bias array will be accessed by + // the ic index, and each bias element corresponds to an + // entire row of the transposed output array, instead of an + // entire column. + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + BF16_F32_BIAS_BCAST(selector2, bias_mask, 1); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 1 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_add_ps( selector1, c_float_0p3 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector2, c_float_1p2 ); + + // c[1,48-63] + c_float_1p3 = _mm512_add_ps( selector2, c_float_1p3 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_2x64_OPS: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_max_ps( selector1, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_max_ps( selector1, c_float_0p3 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + c_float_1p1 = _mm512_max_ps( selector1, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_max_ps( selector1, c_float_1p2 ); + + // c[1,48-63] + c_float_1p3 = _mm512_max_ps( selector1, c_float_1p3 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_2x64_OPS: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_0p2) + + // c[0, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_0p3) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_1p1) + + // c[1, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_1p2) + + // c[1, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_1p3) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_TANH_2x64_OPS: + { + __m512 dn, z, x, r2, r, x_tanh; + __m512i q; + + // c[0, 0-15] + GELU_TANH_F32_AVX512(c_float_0p0, r, r2, x, z, dn, x_tanh, q) + + // c[0, 16-31] + GELU_TANH_F32_AVX512(c_float_0p1, r, r2, x, z, dn, x_tanh, q) + + // c[0, 32-47] + GELU_TANH_F32_AVX512(c_float_0p2, r, r2, x, z, dn, x_tanh, q) + + // c[0, 48-63] + GELU_TANH_F32_AVX512(c_float_0p3, r, r2, x, z, dn, x_tanh, q) + + // c[1, 0-15] + GELU_TANH_F32_AVX512(c_float_1p0, r, r2, x, z, dn, x_tanh, q) + + // c[1, 16-31] + GELU_TANH_F32_AVX512(c_float_1p1, r, r2, x, z, dn, x_tanh, q) + + // c[1, 32-47] + GELU_TANH_F32_AVX512(c_float_1p2, r, r2, x, z, dn, x_tanh, q) + + // c[1, 48-63] + GELU_TANH_F32_AVX512(c_float_1p3, r, r2, x, z, dn, x_tanh, q) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_ERF_2x64_OPS: + { + __m512 x, r, x_erf; + + // c[0, 0-15] + GELU_ERF_F32_AVX512(c_float_0p0, r, x, x_erf) + + // c[0, 16-31] + GELU_ERF_F32_AVX512(c_float_0p1, r, x, x_erf) + + // c[0, 32-47] + GELU_ERF_F32_AVX512(c_float_0p2, r, x, x_erf) + + // c[0, 48-63] + GELU_ERF_F32_AVX512(c_float_0p3, r, x, x_erf) + + // c[1, 0-15] + GELU_ERF_F32_AVX512(c_float_1p0, r, x, x_erf) + + // c[1, 16-31] + GELU_ERF_F32_AVX512(c_float_1p1, r, x, x_erf) + + // c[1, 32-47] + GELU_ERF_F32_AVX512(c_float_1p2, r, x, x_erf) + + // c[1, 48-63] + GELU_ERF_F32_AVX512(c_float_1p3, r, x, x_erf) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_CLIP_2x64_OPS: + { + __m512 min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 ); + __m512 max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 ); + + // c[0, 0-15] + CLIP_F32_AVX512(c_float_0p0, min, max) + + // c[0, 16-31] + CLIP_F32_AVX512(c_float_0p1, min, max) + + // c[0, 32-47] + CLIP_F32_AVX512(c_float_0p2, min, max) + + // c[0, 48-63] + CLIP_F32_AVX512(c_float_0p3, min, max) + + // c[1, 0-15] + CLIP_F32_AVX512(c_float_1p0, min, max) + + // c[1, 16-31] + CLIP_F32_AVX512(c_float_1p1, min, max) + + // c[1, 32-47] + CLIP_F32_AVX512(c_float_1p2, min, max) + + // c[1, 48-63] + CLIP_F32_AVX512(c_float_1p3, min, max) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_2x64_OPS: + { + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + __m512 zero_point2 = _mm512_setzero_ps(); + __m512 zero_point3 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_maskz_loadu_ps( k0, + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_maskz_loadu_ps( k1, + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_maskz_loadu_ps( k2, + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + selector4 = + _mm512_maskz_loadu_ps( k3, + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( k0, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( k1, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( k2, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( k3, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector2,zero_point1); + + // c[0, 32-47] + SCL_MULRND_F32(c_float_0p2,selector3,zero_point2); + + // c[0, 48-63] + SCL_MULRND_F32(c_float_0p3,selector4,zero_point3); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector1,zero_point0); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[1, 32-47] + SCL_MULRND_F32(c_float_1p2,selector3,zero_point2); + + // c[1, 48-63] + SCL_MULRND_F32(c_float_1p3,selector4,zero_point3); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 1 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 1 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector1,zero_point0); + + // c[0, 32-47] + SCL_MULRND_F32(c_float_0p2,selector1,zero_point0); + + // c[0, 48-63] + SCL_MULRND_F32(c_float_0p3,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector2,zero_point1); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[1, 32-47] + SCL_MULRND_F32(c_float_1p2,selector2,zero_point1); + + // c[1, 48-63] + SCL_MULRND_F32(c_float_1p3,selector2,zero_point1); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_2x64_OPS: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + // It is expected the post-op matrix arg has the same storage + // order as the output C matrix. + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,0); + + // c[1:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,1); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,0); + + // c[1:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,1); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_2x64_OPS: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + // It is expected the post-op matrix arg has the same storage + // order as the output C matrix. + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,0); + + // c[1:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,1); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,0); + + // c[1:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,1); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_2x64_OPS: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[0, 16-31] + SWISH_F32_AVX512_DEF(c_float_0p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[0, 32-47] + SWISH_F32_AVX512_DEF(c_float_0p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[0, 48-63] + SWISH_F32_AVX512_DEF(c_float_0p3, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 0-15] + SWISH_F32_AVX512_DEF(c_float_1p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 16-31] + SWISH_F32_AVX512_DEF(c_float_1p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 32-47] + SWISH_F32_AVX512_DEF(c_float_1p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 48-63] + SWISH_F32_AVX512_DEF(c_float_1p3, selector1, al_in, r, r2, z, dn, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_2x64_OPS_DISABLE: + ; + + // Case where the output C matrix is bf16 (downscaled) and this is the + // final write for a given block within C. + if ( post_ops_attr.c_stor_type == BF16 ) + { + // Actually the b matrix is of type bfloat16. However + // in order to reuse this kernel for f32, the output + // matrix type in kernel function signature is set to + // f32 irrespective of original output matrix type. + bfloat16* b_q = ( bfloat16* )b; + dim_t ir = 0; + + // Store the results in downscaled type (bf16 instead of float). + // c[0, 0-15] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_0p0,k0,0,0); + // c[0, 16-31] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_0p1,k1,0,16); + // c[0, 32-47] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_0p2,k2,0,32); + // c[0, 48-63] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_0p3,k3,0,48); + + // c[1, 0-15] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_1p0,k0,1,0); + // c[1, 16-31] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_1p1,k1,1,16); + // c[1, 32-47] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_1p2,k2,1,32); + // c[1, 48-63] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_1p3,k3,1,48); + } + // Case where the output C matrix is float + else + { + // Store the results. + // c[0,0-15] + _mm512_mask_storeu_ps( b + ( rs_b * ( 0 ) ) + + ( cs_b * ( jr + 0 ) ), k0, c_float_0p0 ); + // c[0,16-31] + _mm512_mask_storeu_ps( b + ( rs_b * ( 0 ) ) + + ( cs_b * ( jr + 16 ) ), k1, c_float_0p1 ); + // c[0,32-47] + _mm512_mask_storeu_ps( b + ( rs_b * ( 0 ) ) + + ( cs_b * ( jr + 32 ) ), k2, c_float_0p2 ); + // c[0,48-63] + _mm512_mask_storeu_ps( b + ( rs_b * ( 0 ) ) + + ( cs_b * ( jr + 48 ) ), k3, c_float_0p3 ); + + // c[1,0-15] + _mm512_mask_storeu_ps( b + ( rs_b * ( 1 ) ) + + ( cs_b * ( jr + 0 ) ), k0, c_float_1p0 ); + // c[1,16-31] + _mm512_mask_storeu_ps( b + ( rs_b * ( 1 ) ) + + ( cs_b * ( jr + 16 ) ), k1, c_float_1p1 ); + // c[1,32-47] + _mm512_mask_storeu_ps( b + ( rs_b * ( 1 ) ) + + ( cs_b * ( jr + 32 ) ), k2, c_float_1p2 ); + // c[1,48-63] + _mm512_mask_storeu_ps( b + ( rs_b * ( 1 ) ) + + ( cs_b * ( jr + 48 ) ), k3, c_float_1p3 ); + } + + post_ops_attr.post_op_c_j += NR_L; + } +} + +LPGEMM_ELTWISE_OPS_M_FRINGE_KERNEL(bfloat16,float,bf16of32_1x64) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_1x64_OPS_DISABLE, + &&POST_OPS_BIAS_1x64_OPS, + &&POST_OPS_RELU_1x64_OPS, + &&POST_OPS_RELU_SCALE_1x64_OPS, + &&POST_OPS_GELU_TANH_1x64_OPS, + &&POST_OPS_GELU_ERF_1x64_OPS, + &&POST_OPS_CLIP_1x64_OPS, + &&POST_OPS_DOWNSCALE_1x64_OPS, + &&POST_OPS_MATRIX_ADD_1x64_OPS, + &&POST_OPS_SWISH_1x64_OPS, + &&POST_OPS_MATRIX_MUL_1x64_OPS + }; + dim_t NR = 64; + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + __m512 c_float_0p1 = _mm512_setzero_ps(); + __m512 c_float_0p2 = _mm512_setzero_ps(); + __m512 c_float_0p3 = _mm512_setzero_ps(); + + __m512 selector1 = _mm512_setzero_ps(); + __m512 selector2 = _mm512_setzero_ps(); + __m512 selector3 = _mm512_setzero_ps(); + __m512 selector4 = _mm512_setzero_ps(); + + __mmask16 k0 = 0xFFFF, k1 = 0xFFFF, k2 = 0xFFFF, k3 = 0xFFFF; + + dim_t NR_L = NR; + for( dim_t jr = 0; jr < n0; jr += NR_L ) + { + dim_t n_left = n0 - jr; + NR_L = bli_min( NR_L, ( n_left >> 4 ) << 4 ); + if( NR_L == 0 ) { NR_L = 16; } + + dim_t nr0 = bli_min( n0 - jr, NR_L ); + if( nr0 == 64 ) + { + // all masks are already set. + // Nothing to modify. + } + else if( nr0 == 48 ) + { + k3 = 0x0; + } + else if( nr0 == 32 ) + { + k2 = k3 = 0x0; + } + else if( nr0 == 16 ) + { + k1 = k2 = k3 = 0; + } + else if( nr0 < 16 ) + { + k0 = (0xFFFF >> (16 - (nr0 & 0x0F))); + k1 = k2 = k3 = 0; + } + + // 1stx64 block. + c_float_0p0 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k0, \ + a + ( rs_a * ( 0 ) ) + ( cs_a * ( jr + 0 ) ) ) ); + c_float_0p1 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k1, \ + a + ( rs_a * ( 0 ) ) + ( cs_a * ( jr + 16 ) ) ) ); + c_float_0p2 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k2, \ + a + ( rs_a * ( 0 ) ) + ( cs_a * ( jr + 32 ) ) ) ); + c_float_0p3 = CVT_BF16_F32_INT_SHIFT(_mm256_maskz_loadu_epi16( k3, \ + a + ( rs_a * ( 0 ) ) + ( cs_a * ( jr + 48 ) ) ) ); + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP + +POST_OPS_BIAS_1x64_OPS: + { + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_attr.c_stor_type == BF16 ) + { + BF16_F32_BIAS_LOAD(selector1, k0, 0); + BF16_F32_BIAS_LOAD(selector2, k1, 1); + BF16_F32_BIAS_LOAD(selector3, k2, 2); + BF16_F32_BIAS_LOAD(selector4, k3, 3); + } + else + { + selector1 = + _mm512_maskz_loadu_ps( k0, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_maskz_loadu_ps( k1, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_maskz_loadu_ps( k2, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + selector4 = + _mm512_maskz_loadu_ps( k3, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_add_ps( selector4, c_float_0p3 ); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the bias array will be accessed by + // the ic index, and each bias element corresponds to an + // entire row of the transposed output array, instead of an + // entire column. + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_add_ps( selector1, c_float_0p3 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_1x64_OPS: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_max_ps( selector1, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_max_ps( selector1, c_float_0p3 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_1x64_OPS: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_0p2) + + // c[0, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_0p3) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_TANH_1x64_OPS: + { + __m512 dn, z, x, r2, r, x_tanh; + __m512i q; + + // c[0, 0-15] + GELU_TANH_F32_AVX512(c_float_0p0, r, r2, x, z, dn, x_tanh, q) + + // c[0, 16-31] + GELU_TANH_F32_AVX512(c_float_0p1, r, r2, x, z, dn, x_tanh, q) + + // c[0, 32-47] + GELU_TANH_F32_AVX512(c_float_0p2, r, r2, x, z, dn, x_tanh, q) + + // c[0, 48-63] + GELU_TANH_F32_AVX512(c_float_0p3, r, r2, x, z, dn, x_tanh, q) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_ERF_1x64_OPS: + { + __m512 x, r, x_erf; + + // c[0, 0-15] + GELU_ERF_F32_AVX512(c_float_0p0, r, x, x_erf) + + // c[0, 16-31] + GELU_ERF_F32_AVX512(c_float_0p1, r, x, x_erf) + + // c[0, 32-47] + GELU_ERF_F32_AVX512(c_float_0p2, r, x, x_erf) + + // c[0, 48-63] + GELU_ERF_F32_AVX512(c_float_0p3, r, x, x_erf) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_CLIP_1x64_OPS: + { + __m512 min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 ); + __m512 max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 ); + + // c[0, 0-15] + CLIP_F32_AVX512(c_float_0p0, min, max) + + // c[0, 16-31] + CLIP_F32_AVX512(c_float_0p1, min, max) + + // c[0, 32-47] + CLIP_F32_AVX512(c_float_0p2, min, max) + + // c[0, 48-63] + CLIP_F32_AVX512(c_float_0p3, min, max) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_1x64_OPS: + { + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + __m512 zero_point2 = _mm512_setzero_ps(); + __m512 zero_point3 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_maskz_loadu_ps( k0, + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_maskz_loadu_ps( k1, + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_maskz_loadu_ps( k2, + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + selector4 = + _mm512_maskz_loadu_ps( k3, + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( k0, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( k1, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( k2, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( k3, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector2,zero_point1); + + // c[0, 32-47] + SCL_MULRND_F32(c_float_0p2,selector3,zero_point2); + + // c[0, 48-63] + SCL_MULRND_F32(c_float_0p3,selector4,zero_point3); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector1,zero_point0); + + // c[0, 32-47] + SCL_MULRND_F32(c_float_0p2,selector1,zero_point0); + + // c[0, 48-63] + SCL_MULRND_F32(c_float_0p3,selector1,zero_point0); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_1x64_OPS: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + // It is expected the post-op matrix arg has the same storage + // order as the output C matrix. + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,0); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,0); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_1x64_OPS: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + // It is expected the post-op matrix arg has the same storage + // order as the output C matrix. + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,0); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,selector1,selector2,selector3,selector4,0); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_1x64_OPS: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[0, 16-31] + SWISH_F32_AVX512_DEF(c_float_0p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[0, 32-47] + SWISH_F32_AVX512_DEF(c_float_0p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[0, 48-63] + SWISH_F32_AVX512_DEF(c_float_0p3, selector1, al_in, r, r2, z, dn, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_1x64_OPS_DISABLE: + ; + + // Case where the output C matrix is bf16 (downscaled) and this is the + // final write for a given block within C. + if ( post_ops_attr.c_stor_type == BF16 ) + { + // Actually the b matrix is of type bfloat16. However + // in order to reuse this kernel for f32, the output + // matrix type in kernel function signature is set to + // f32 irrespective of original output matrix type. + bfloat16* b_q = ( bfloat16* )b; + dim_t ir = 0; + + // Store the results in downscaled type (bf16 instead of float). + // c[0, 0-15] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_0p0,k0,0,0); + // c[0, 16-31] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_0p1,k1,0,16); + // c[0, 32-47] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_0p2,k2,0,32); + // c[0, 48-63] + CVT_STORE_F32_BF16_POST_OPS_MASK(c_float_0p3,k3,0,48); + } + // Case where the output C matrix is float + else + { + // Store the results. + // c[0,0-15] + _mm512_mask_storeu_ps( b + ( rs_b * ( 0 ) ) + + ( cs_b * ( jr + 0 ) ), k0, c_float_0p0 ); + // c[0,16-31] + _mm512_mask_storeu_ps( b + ( rs_b * ( 0 ) ) + + ( cs_b * ( jr + 16 ) ), k1, c_float_0p1 ); + // c[0,32-47] + _mm512_mask_storeu_ps( b + ( rs_b * ( 0 ) ) + + ( cs_b * ( jr + 32 ) ), k2, c_float_0p2 ); + // c[0,48-63] + _mm512_mask_storeu_ps( b + ( rs_b * ( 0 ) ) + + ( cs_b * ( jr + 48 ) ), k3, c_float_0p3 ); + } + + post_ops_attr.post_op_c_j += NR_L; + } +} + +#endif //LPGEMM_BF16_JIT +#endif diff --git a/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_f32_kern_macros.h b/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_f32_kern_macros.h index 484c2930eb..5146c19e90 100644 --- a/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_f32_kern_macros.h +++ b/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_f32_kern_macros.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -36,15 +36,9 @@ #define LPGEMM_F32_KERN_MACROS_H #include "../gelu_avx512.h" +#include "../silu_avx512.h" #include "../math_utils_avx512.h" -// Disable BF16 kernel in cases where compilers support other avx 512 -// features except BF16 ISA. -#if ( defined( BLIS_GCC ) && ( ( __GNUC__ < 11 ) || \ - ( ( __GNUC__ == 11 ) && ( __GNUC_MINOR__ < 2 ) ) ) ) -#define LPGEMM_BF16_NOT_SUPPORTED -#endif - /* ReLU scale (Parametric ReLU): f(x) = x, when x > 0 and f(x) = a*x when x <= 0 */ #define RELU_SCALE_OP_F32_AVX512(reg) \ /* Generate indenx of elements <= 0.*/ \ @@ -79,7 +73,7 @@ F32_BETA_FMA(reg,scratch1,scratch2) \ // Default n < 16 mask load beta macro -#define F32_F32_BETA_OP_NLT16F_MASK(lmask,reg,m_ir,m_ind,n_ind,scratch1,scratch2) \ +#define F32_F32_BETA_OP_NLT16F_MASK(c,lmask,reg,m_ir,m_ind,n_ind,scratch1,scratch2) \ scratch1 = _mm512_maskz_loadu_ps( lmask, c + ( rs_c * ( m_ir + m_ind ) ) + ( n_ind * 16 ) ); \ F32_BETA_FMA(reg,scratch1,scratch2) \ @@ -95,7 +89,10 @@ ) ), _mm512_set1_epi32 (16) ) );\ F32_BETA_FMA(reg,scratch1,scratch2) \ -#define MULRND_F32(reg,m_ind,n_ind) \ +// zero_point(avx512 register) contains bf16 zp upscaled to f32. +#define SCL_MULRND_F32(reg,selector,zero_point) \ + reg = _mm512_mul_ps( reg, selector ); \ + reg = _mm512_add_ps( reg, zero_point ); \ #define CVT_STORE_F32_BF16_MASK(reg,m_ind,n_ind) \ _mm256_mask_storeu_epi16 \ @@ -106,6 +103,49 @@ mask_all1, (__m256i) _mm512_cvtneps_pbh( reg ) \ ) \ +#define CVT_STORE_F32_BF16_POST_OPS_MASK(reg,mask,m_ind,n_ind) \ + _mm256_mask_storeu_epi16 \ + ( \ + b_q + ( rs_b * ( ir + m_ind ) ) + ( cs_b * ( jr + n_ind ) ), \ + mask, (__m256i) _mm512_cvtneps_pbh( reg ) \ + ) \ + +// BF16 -> F32 convert helpers. reg: __m512 +#define CVT_BF16_F32_INT_SHIFT(in) \ + ( __m512 )_mm512_sllv_epi32( _mm512_cvtepi16_epi32( ( in ) ), \ + _mm512_set1_epi32( 16 ) ); + +// BF16 bias helper macros. +#define BF16_F32_BIAS_LOAD(scr,mask,n_ind) \ + scr = (__m512)( _mm512_sllv_epi32 \ + ( \ + _mm512_cvtepi16_epi32 \ + ( \ + _mm256_maskz_loadu_epi16 \ + ( \ + ( mask ), \ + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + \ + post_ops_attr.post_op_c_j + ( n_ind * 16 ) \ + ) \ + ), _mm512_set1_epi32( 16 ) \ + ) \ + ); \ + +#define BF16_F32_BIAS_BCAST(scr,mask,m_ind) \ + scr = (__m512)( _mm512_sllv_epi32 \ + ( \ + _mm512_cvtepi16_epi32 \ + ( \ + _mm256_maskz_set1_epi16 \ + ( \ + ( mask ), \ + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + \ + post_ops_attr.post_op_c_i + m_ind ) \ + ) \ + ), _mm512_set1_epi32( 16 ) \ + ) \ + ); \ + /* TANH GeLU (x) = 0.5* x * (1 + tanh ( 0.797884 * ( x + ( 0.044715 * x^3 ) ) ) ) */ #define GELU_TANH_F32_AVX512(reg, r, r2, x, z, dn, x_tanh, q) \ \ @@ -120,4 +160,227 @@ \ reg = _mm512_min_ps( _mm512_max_ps( reg, min ), max ); \ +// Matrix Add post-ops helper macros +#define F32_MATRIX_ADD_1COL(scr0,m_ind) \ + c_float_ ## m_ind ## p0 = _mm512_add_ps( scr0, c_float_ ## m_ind ## p0 ); \ + +#define F32_MATRIX_ADD_2COL(scr0,scr1,m_ind) \ + c_float_ ## m_ind ## p0 = _mm512_add_ps( scr0, c_float_ ## m_ind ## p0 ); \ + c_float_ ## m_ind ## p1 = _mm512_add_ps( scr1, c_float_ ## m_ind ## p1 ); \ + +#define F32_MATRIX_ADD_3COL(scr0,scr1,scr2,m_ind) \ + c_float_ ## m_ind ## p0 = _mm512_add_ps( scr0, c_float_ ## m_ind ## p0 ); \ + c_float_ ## m_ind ## p1 = _mm512_add_ps( scr1, c_float_ ## m_ind ## p1 ); \ + c_float_ ## m_ind ## p2 = _mm512_add_ps( scr2, c_float_ ## m_ind ## p2 ); \ + +#define F32_MATRIX_ADD_4COL(scr0,scr1,scr2,scr3,m_ind) \ + c_float_ ## m_ind ## p0 = _mm512_add_ps( scr0, c_float_ ## m_ind ## p0 ); \ + c_float_ ## m_ind ## p1 = _mm512_add_ps( scr1, c_float_ ## m_ind ## p1 ); \ + c_float_ ## m_ind ## p2 = _mm512_add_ps( scr2, c_float_ ## m_ind ## p2 ); \ + c_float_ ## m_ind ## p3 = _mm512_add_ps( scr3, c_float_ ## m_ind ## p3 ); \ + +#define BF16_F32_MATRIX_ADD_LOAD(mask,scr,m_ind,n_ind) \ + scr = (__m512)( _mm512_sllv_epi32 \ + ( \ + _mm512_cvtepi16_epi32 \ + ( \ + _mm256_maskz_loadu_epi16 \ + ( \ + mask, \ + matptr + ( ( post_ops_attr.post_op_c_i + m_ind ) * ldm ) + \ + post_ops_attr.post_op_c_j + ( n_ind * 16 ) \ + ) \ + ), _mm512_set1_epi32( 16 ) \ + ) \ + ); \ + +#define BF16_F32_MATRIX_ADD_1COL_PAR(mask,scr0,m_ind) \ + BF16_F32_MATRIX_ADD_LOAD(mask,scr0,m_ind,0); \ + F32_MATRIX_ADD_1COL(scr0,m_ind); \ + +#define BF16_F32_MATRIX_ADD_1COL(scr0,m_ind) \ + BF16_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,m_ind,0); \ + F32_MATRIX_ADD_1COL(scr0,m_ind); \ + +#define BF16_F32_MATRIX_ADD_2COL(scr0,scr1,m_ind) \ + BF16_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,m_ind,0); \ + BF16_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,m_ind,1); \ + F32_MATRIX_ADD_2COL(scr0,scr1,m_ind); \ + +#define BF16_F32_MATRIX_ADD_3COL(scr0,scr1,scr2,m_ind) \ + BF16_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,m_ind,0); \ + BF16_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,m_ind,1); \ + BF16_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr2,m_ind,2); \ + F32_MATRIX_ADD_3COL(scr0,scr1,scr2,m_ind); \ + +#define BF16_F32_MATRIX_ADD_4COL(scr0,scr1,scr2,scr3,m_ind) \ + BF16_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,m_ind,0); \ + BF16_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,m_ind,1); \ + BF16_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr2,m_ind,2); \ + BF16_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr3,m_ind,3); \ + F32_MATRIX_ADD_4COL(scr0,scr1,scr2,scr3,m_ind); \ + +#define BF16_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,scr0,scr1,scr2,scr3,m_ind) \ + BF16_F32_MATRIX_ADD_LOAD(k0,scr0,m_ind,0); \ + BF16_F32_MATRIX_ADD_LOAD(k1,scr1,m_ind,1); \ + BF16_F32_MATRIX_ADD_LOAD(k2,scr2,m_ind,2); \ + BF16_F32_MATRIX_ADD_LOAD(k3,scr3,m_ind,3); \ + F32_MATRIX_ADD_4COL(scr0,scr1,scr2,scr3,m_ind); \ + +#define F32_F32_MATRIX_ADD_LOAD(mask,scr,m_ind,n_ind) \ + scr = _mm512_maskz_loadu_ps \ + ( \ + mask, \ + matptr + ( ( post_ops_attr.post_op_c_i + m_ind ) * ldm ) + \ + post_ops_attr.post_op_c_j + ( n_ind * 16 ) \ + ); \ + +#define F32_F32_MATRIX_ADD_1COL_PAR(mask,scr0,m_ind) \ + F32_F32_MATRIX_ADD_LOAD(mask,scr0,m_ind,0); \ + F32_MATRIX_ADD_1COL(scr0,m_ind); \ + +#define F32_F32_MATRIX_ADD_1COL(scr0,m_ind) \ + F32_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,m_ind,0); \ + F32_MATRIX_ADD_1COL(scr0,m_ind); \ + +#define F32_F32_MATRIX_ADD_2COL(scr0,scr1,m_ind) \ + F32_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,m_ind,0); \ + F32_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,m_ind,1); \ + F32_MATRIX_ADD_2COL(scr0,scr1,m_ind); \ + +#define F32_F32_MATRIX_ADD_3COL(scr0,scr1,scr2,m_ind) \ + F32_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,m_ind,0); \ + F32_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,m_ind,1); \ + F32_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr2,m_ind,2); \ + F32_MATRIX_ADD_3COL(scr0,scr1,scr2,m_ind); \ + +#define F32_F32_MATRIX_ADD_4COL(scr0,scr1,scr2,scr3,m_ind) \ + F32_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,m_ind,0); \ + F32_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,m_ind,1); \ + F32_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr2,m_ind,2); \ + F32_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr3,m_ind,3); \ + F32_MATRIX_ADD_4COL(scr0,scr1,scr2,scr3,m_ind); \ + +#define F32_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,scr0,scr1,scr2,scr3,m_ind) \ + F32_F32_MATRIX_ADD_LOAD(k0,scr0,m_ind,0); \ + F32_F32_MATRIX_ADD_LOAD(k1,scr1,m_ind,1); \ + F32_F32_MATRIX_ADD_LOAD(k2,scr2,m_ind,2); \ + F32_F32_MATRIX_ADD_LOAD(k3,scr3,m_ind,3); \ + F32_MATRIX_ADD_4COL(scr0,scr1,scr2,scr3,m_ind); \ + +// Matrix mul post-ops helper macros +#define F32_MATRIX_MUL_1COL(scr0,m_ind) \ + c_float_ ## m_ind ## p0 = _mm512_mul_ps( scr0, c_float_ ## m_ind ## p0 ); \ + +#define F32_MATRIX_MUL_2COL(scr0,scr1,m_ind) \ + c_float_ ## m_ind ## p0 = _mm512_mul_ps( scr0, c_float_ ## m_ind ## p0 ); \ + c_float_ ## m_ind ## p1 = _mm512_mul_ps( scr1, c_float_ ## m_ind ## p1 ); \ + +#define F32_MATRIX_MUL_3COL(scr0,scr1,scr2,m_ind) \ + c_float_ ## m_ind ## p0 = _mm512_mul_ps( scr0, c_float_ ## m_ind ## p0 ); \ + c_float_ ## m_ind ## p1 = _mm512_mul_ps( scr1, c_float_ ## m_ind ## p1 ); \ + c_float_ ## m_ind ## p2 = _mm512_mul_ps( scr2, c_float_ ## m_ind ## p2 ); \ + +#define F32_MATRIX_MUL_4COL(scr0,scr1,scr2,scr3,m_ind) \ + c_float_ ## m_ind ## p0 = _mm512_mul_ps( scr0, c_float_ ## m_ind ## p0 ); \ + c_float_ ## m_ind ## p1 = _mm512_mul_ps( scr1, c_float_ ## m_ind ## p1 ); \ + c_float_ ## m_ind ## p2 = _mm512_mul_ps( scr2, c_float_ ## m_ind ## p2 ); \ + c_float_ ## m_ind ## p3 = _mm512_mul_ps( scr3, c_float_ ## m_ind ## p3 ); \ + +#define BF16_F32_MATRIX_MUL_LOAD(mask,scr,m_ind,n_ind) \ + scr = (__m512)( _mm512_sllv_epi32 \ + ( \ + _mm512_cvtepi16_epi32 \ + ( \ + _mm256_maskz_loadu_epi16 \ + ( \ + mask, \ + matptr + ( ( post_ops_attr.post_op_c_i + m_ind ) * ldm ) + \ + post_ops_attr.post_op_c_j + ( n_ind * 16 ) \ + ) \ + ), _mm512_set1_epi32( 16 ) \ + ) \ + ); \ + +#define BF16_F32_MATRIX_MUL_1COL_PAR(mask,scr0,m_ind) \ + BF16_F32_MATRIX_MUL_LOAD(mask,scr0,m_ind,0); \ + F32_MATRIX_MUL_1COL(scr0,m_ind); \ + +#define BF16_F32_MATRIX_MUL_1COL(scr0,m_ind) \ + BF16_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,m_ind,0); \ + F32_MATRIX_MUL_1COL(scr0,m_ind); \ + +#define BF16_F32_MATRIX_MUL_2COL(scr0,scr1,m_ind) \ + BF16_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,m_ind,0); \ + BF16_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,m_ind,1); \ + F32_MATRIX_MUL_2COL(scr0,scr1,m_ind); \ + +#define BF16_F32_MATRIX_MUL_3COL(scr0,scr1,scr2,m_ind) \ + BF16_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,m_ind,0); \ + BF16_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,m_ind,1); \ + BF16_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr2,m_ind,2); \ + F32_MATRIX_MUL_3COL(scr0,scr1,scr2,m_ind); \ + +#define BF16_F32_MATRIX_MUL_4COL(scr0,scr1,scr2,scr3,m_ind) \ + BF16_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,m_ind,0); \ + BF16_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,m_ind,1); \ + BF16_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr2,m_ind,2); \ + BF16_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr3,m_ind,3); \ + F32_MATRIX_MUL_4COL(scr0,scr1,scr2,scr3,m_ind); \ + +#define BF16_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,scr0,scr1,scr2,scr3,m_ind) \ + BF16_F32_MATRIX_MUL_LOAD(k0,scr0,m_ind,0); \ + BF16_F32_MATRIX_MUL_LOAD(k1,scr1,m_ind,1); \ + BF16_F32_MATRIX_MUL_LOAD(k2,scr2,m_ind,2); \ + BF16_F32_MATRIX_MUL_LOAD(k3,scr3,m_ind,3); \ + F32_MATRIX_MUL_4COL(scr0,scr1,scr2,scr3,m_ind); \ + +#define F32_F32_MATRIX_MUL_LOAD(mask,scr,m_ind,n_ind) \ + scr = _mm512_maskz_loadu_ps \ + ( \ + mask, \ + matptr + ( ( post_ops_attr.post_op_c_i + m_ind ) * ldm ) + \ + post_ops_attr.post_op_c_j + ( n_ind * 16 ) \ + ); \ + +#define F32_F32_MATRIX_MUL_1COL_PAR(mask,scr0,m_ind) \ + F32_F32_MATRIX_MUL_LOAD(mask,scr0,m_ind,0); \ + F32_MATRIX_MUL_1COL(scr0,m_ind); \ + +#define F32_F32_MATRIX_MUL_1COL(scr0,m_ind) \ + F32_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,m_ind,0); \ + F32_MATRIX_MUL_1COL(scr0,m_ind); \ + +#define F32_F32_MATRIX_MUL_2COL(scr0,scr1,m_ind) \ + F32_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,m_ind,0); \ + F32_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,m_ind,1); \ + F32_MATRIX_MUL_2COL(scr0,scr1,m_ind); \ + +#define F32_F32_MATRIX_MUL_3COL(scr0,scr1,scr2,m_ind) \ + F32_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,m_ind,0); \ + F32_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,m_ind,1); \ + F32_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr2,m_ind,2); \ + F32_MATRIX_MUL_3COL(scr0,scr1,scr2,m_ind); \ + +#define F32_F32_MATRIX_MUL_4COL(scr0,scr1,scr2,scr3,m_ind) \ + F32_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,m_ind,0); \ + F32_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,m_ind,1); \ + F32_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr2,m_ind,2); \ + F32_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr3,m_ind,3); \ + F32_MATRIX_MUL_4COL(scr0,scr1,scr2,scr3,m_ind); \ + +#define F32_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,scr0,scr1,scr2,scr3,m_ind) \ + F32_F32_MATRIX_MUL_LOAD(k0,scr0,m_ind,0); \ + F32_F32_MATRIX_MUL_LOAD(k1,scr1,m_ind,1); \ + F32_F32_MATRIX_MUL_LOAD(k2,scr2,m_ind,2); \ + F32_F32_MATRIX_MUL_LOAD(k3,scr3,m_ind,3); \ + F32_MATRIX_MUL_4COL(scr0,scr1,scr2,scr3,m_ind); \ + +//Zero-out the given ZMM accumulator registers +#define ZERO_ACC_ZMM_4_REG(zmm0,zmm1,zmm2,zmm3) \ + zmm0 = _mm512_setzero_ps(); \ + zmm1 = _mm512_setzero_ps(); \ + zmm2 = _mm512_setzero_ps(); \ + zmm3 = _mm512_setzero_ps(); + #endif // LPGEMM_F32_KERN_MACROS_H diff --git a/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_m_fringe_bf16_amd512vnni.c b/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_m_fringe_bf16_amd512vnni.c index 26f45c5101..2c271d1a1e 100644 --- a/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_m_fringe_bf16_amd512vnni.c +++ b/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_m_fringe_bf16_amd512vnni.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -40,7 +40,7 @@ #include "lpgemm_f32_kern_macros.h" -#ifndef LPGEMM_BF16_NOT_SUPPORTED +#ifndef LPGEMM_BF16_JIT // 5x64 bf16 kernel LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x64) { @@ -53,7 +53,10 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x64) &&POST_OPS_GELU_TANH_5x64, &&POST_OPS_GELU_ERF_5x64, &&POST_OPS_CLIP_5x64, - &&POST_OPS_DOWNSCALE_5x64 + &&POST_OPS_DOWNSCALE_5x64, + &&POST_OPS_MATRIX_ADD_5x64, + &&POST_OPS_SWISH_5x64, + &&POST_OPS_MATRIX_MUL_5x64 }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -403,18 +406,29 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x64) if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) { - selector1 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j ); - selector2 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ); - selector3 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 2 * 16 ) ); - selector4 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_LOAD(selector1, bias_mask, 0); + BF16_F32_BIAS_LOAD(selector2, bias_mask, 1); + BF16_F32_BIAS_LOAD(selector3, bias_mask, 2); + BF16_F32_BIAS_LOAD(selector4, bias_mask, 3); + } + else + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + selector4 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + } // c[0,0-15] c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); @@ -484,21 +498,34 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x64) // the ic index, and each bias element corresponds to an // entire row of the transposed output array, instead of an // entire column. - selector1 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_i + 0 ) ); - selector2 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_i + 1 ) ); - selector3 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_i + 2 ) ); - selector4 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_i + 3 ) ); - __m512 selector5 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_i + 4 ) ); + __m512 selector5; + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + BF16_F32_BIAS_BCAST(selector2, bias_mask, 1); + BF16_F32_BIAS_BCAST(selector3, bias_mask, 2); + BF16_F32_BIAS_BCAST(selector4, bias_mask, 3); + BF16_F32_BIAS_BCAST(selector5, bias_mask, 4); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 3 ) ); + selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 4 ) ); + } // c[0,0-15] c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); @@ -899,68 +926,431 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x64) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_5x64: { + __m512 selector3 = _mm512_setzero_ps(); + __m512 selector4 = _mm512_setzero_ps(); + + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + __m512 zero_point2 = _mm512_setzero_ps(); + __m512 zero_point3 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + selector4 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector2,zero_point1); + + // c[0, 32-47] + SCL_MULRND_F32(c_float_0p2,selector3,zero_point2); + + // c[0, 48-63] + SCL_MULRND_F32(c_float_0p3,selector4,zero_point3); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector1,zero_point0); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[1, 32-47] + SCL_MULRND_F32(c_float_1p2,selector3,zero_point2); + + // c[1, 48-63] + SCL_MULRND_F32(c_float_1p3,selector4,zero_point3); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector1,zero_point0); + + // c[2, 16-31] + SCL_MULRND_F32(c_float_2p1,selector2,zero_point1); + + // c[2, 32-47] + SCL_MULRND_F32(c_float_2p2,selector3,zero_point2); + + // c[2, 48-63] + SCL_MULRND_F32(c_float_2p3,selector4,zero_point3); + + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector1,zero_point0); + + // c[3, 16-31] + SCL_MULRND_F32(c_float_3p1,selector2,zero_point1); + + // c[3, 32-47] + SCL_MULRND_F32(c_float_3p2,selector3,zero_point2); + + // c[3, 48-63] + SCL_MULRND_F32(c_float_3p3,selector4,zero_point3); + + // c[4, 0-15] + SCL_MULRND_F32(c_float_4p0,selector1,zero_point0); + + // c[4, 16-31] + SCL_MULRND_F32(c_float_4p1,selector2,zero_point1); + + // c[4, 32-47] + SCL_MULRND_F32(c_float_4p2,selector3,zero_point2); + + // c[4, 48-63] + SCL_MULRND_F32(c_float_4p3,selector4,zero_point3); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 3 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 2 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 3 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector1,zero_point0); + + // c[0, 32-47] + SCL_MULRND_F32(c_float_0p2,selector1,zero_point0); + + // c[0, 48-63] + SCL_MULRND_F32(c_float_0p3,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector2,zero_point1); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[1, 32-47] + SCL_MULRND_F32(c_float_1p2,selector2,zero_point1); + + // c[1, 48-63] + SCL_MULRND_F32(c_float_1p3,selector2,zero_point1); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector3,zero_point2); + + // c[2, 16-31] + SCL_MULRND_F32(c_float_2p1,selector3,zero_point2); + + // c[2, 32-47] + SCL_MULRND_F32(c_float_2p2,selector3,zero_point2); + + // c[2, 48-63] + SCL_MULRND_F32(c_float_2p3,selector3,zero_point2); + + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector4,zero_point3); + + // c[3, 16-31] + SCL_MULRND_F32(c_float_3p1,selector4,zero_point3); + + // c[3, 32-47] + SCL_MULRND_F32(c_float_3p2,selector4,zero_point3); + + // c[3, 48-63] + SCL_MULRND_F32(c_float_3p3,selector4,zero_point3); + + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 4 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 4 ) ) ); + } + // c[4, 0-15] + SCL_MULRND_F32(c_float_4p0,selector1,zero_point0); + + // c[4, 16-31] + SCL_MULRND_F32(c_float_4p1,selector1,zero_point0); + + // c[4, 32-47] + SCL_MULRND_F32(c_float_4p2,selector1,zero_point0); + + // c[4, 48-63] + SCL_MULRND_F32(c_float_4p3,selector1,zero_point0); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_5x64: + { + __m512 selector3; + __m512 selector4; + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,0); + + // c[1:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,1); + + // c[2:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,2); + + // c[3:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,3); + + // c[4:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,4); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,0); + + // c[1:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,1); + + // c[2:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,2); + + // c[3:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,3); + + // c[4:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,4); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_5x64: + { + __m512 selector3; + __m512 selector4; + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,0); + + // c[1:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,1); + + // c[2:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,2); + + // c[3:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,3); + + // c[4:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,4); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,0); + + // c[1:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,1); + + // c[2:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,2); + + // c[3:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,3); + + // c[4:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,4); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_5x64: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + // c[0, 0-15] - MULRND_F32(c_float_0p0,0,0); + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); // c[0, 16-31] - MULRND_F32(c_float_0p1,0,1); + SWISH_F32_AVX512_DEF(c_float_0p1, selector1, al_in, r, r2, z, dn, ex_out); // c[0, 32-47] - MULRND_F32(c_float_0p2,0,2); + SWISH_F32_AVX512_DEF(c_float_0p2, selector1, al_in, r, r2, z, dn, ex_out); // c[0, 48-63] - MULRND_F32(c_float_0p3,0,3); + SWISH_F32_AVX512_DEF(c_float_0p3, selector1, al_in, r, r2, z, dn, ex_out); // c[1, 0-15] - MULRND_F32(c_float_1p0,1,0); + SWISH_F32_AVX512_DEF(c_float_1p0, selector1, al_in, r, r2, z, dn, ex_out); // c[1, 16-31] - MULRND_F32(c_float_1p1,1,1); + SWISH_F32_AVX512_DEF(c_float_1p1, selector1, al_in, r, r2, z, dn, ex_out); // c[1, 32-47] - MULRND_F32(c_float_1p2,1,2); + SWISH_F32_AVX512_DEF(c_float_1p2, selector1, al_in, r, r2, z, dn, ex_out); // c[1, 48-63] - MULRND_F32(c_float_1p3,1,3); + SWISH_F32_AVX512_DEF(c_float_1p3, selector1, al_in, r, r2, z, dn, ex_out); // c[2, 0-15] - MULRND_F32(c_float_2p0,2,0); + SWISH_F32_AVX512_DEF(c_float_2p0, selector1, al_in, r, r2, z, dn, ex_out); // c[2, 16-31] - MULRND_F32(c_float_2p1,2,1); + SWISH_F32_AVX512_DEF(c_float_2p1, selector1, al_in, r, r2, z, dn, ex_out); // c[2, 32-47] - MULRND_F32(c_float_2p2,2,2); + SWISH_F32_AVX512_DEF(c_float_2p2, selector1, al_in, r, r2, z, dn, ex_out); // c[2, 48-63] - MULRND_F32(c_float_2p3,2,3); + SWISH_F32_AVX512_DEF(c_float_2p3, selector1, al_in, r, r2, z, dn, ex_out); // c[3, 0-15] - MULRND_F32(c_float_3p0,3,0); + SWISH_F32_AVX512_DEF(c_float_3p0, selector1, al_in, r, r2, z, dn, ex_out); // c[3, 16-31] - MULRND_F32(c_float_3p1,3,1); + SWISH_F32_AVX512_DEF(c_float_3p1, selector1, al_in, r, r2, z, dn, ex_out); // c[3, 32-47] - MULRND_F32(c_float_3p2,3,2); + SWISH_F32_AVX512_DEF(c_float_3p2, selector1, al_in, r, r2, z, dn, ex_out); // c[3, 48-63] - MULRND_F32(c_float_3p3,3,3); + SWISH_F32_AVX512_DEF(c_float_3p3, selector1, al_in, r, r2, z, dn, ex_out); // c[4, 0-15] - MULRND_F32(c_float_4p0,4,0); + SWISH_F32_AVX512_DEF(c_float_4p0, selector1, al_in, r, r2, z, dn, ex_out); // c[4, 16-31] - MULRND_F32(c_float_4p1,4,1); + SWISH_F32_AVX512_DEF(c_float_4p1, selector1, al_in, r, r2, z, dn, ex_out); // c[4, 32-47] - MULRND_F32(c_float_4p2,4,2); + SWISH_F32_AVX512_DEF(c_float_4p2, selector1, al_in, r, r2, z, dn, ex_out); // c[4, 48-63] - MULRND_F32(c_float_4p3,4,3); + SWISH_F32_AVX512_DEF(c_float_4p3, selector1, al_in, r, r2, z, dn, ex_out); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -1040,7 +1430,7 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x64) } - // Case where the output C matrix is float + // Case where the output C matrix is float else { // Store the results. @@ -1119,7 +1509,10 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x64) &&POST_OPS_GELU_TANH_4x64, &&POST_OPS_GELU_ERF_4x64, &&POST_OPS_CLIP_4x64, - &&POST_OPS_DOWNSCALE_4x64 + &&POST_OPS_DOWNSCALE_4x64, + &&POST_OPS_MATRIX_ADD_4x64, + &&POST_OPS_SWISH_4x64, + &&POST_OPS_MATRIX_MUL_4x64 }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -1413,18 +1806,29 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x64) if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) { - selector1 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j ); - selector2 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ); - selector3 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 2 * 16 ) ); - selector4 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_LOAD(selector1, bias_mask, 0); + BF16_F32_BIAS_LOAD(selector2, bias_mask, 1); + BF16_F32_BIAS_LOAD(selector3, bias_mask, 2); + BF16_F32_BIAS_LOAD(selector4, bias_mask, 3); + } + else + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + selector4 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + } // c[0,0-15] c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); @@ -1482,18 +1886,29 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x64) // the ic index, and each bias element corresponds to an // entire row of the transposed output array, instead of an // entire column. - selector1 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_i + 0 ) ); - selector2 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_i + 1 ) ); - selector3 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_i + 2 ) ); - selector4 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_i + 3 ) ); + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + BF16_F32_BIAS_BCAST(selector2, bias_mask, 1); + BF16_F32_BIAS_BCAST(selector3, bias_mask, 2); + BF16_F32_BIAS_BCAST(selector4, bias_mask, 3); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 3 ) ); + } // c[0,0-15] c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); @@ -1822,124 +2237,437 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x64) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_4x64: { - // c[0, 0-15] - MULRND_F32(c_float_0p0,0,0); + __m512 selector3 = _mm512_setzero_ps(); + __m512 selector4 = _mm512_setzero_ps(); + + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + __m512 zero_point2 = _mm512_setzero_ps(); + __m512 zero_point3 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } - // c[0, 16-31] - MULRND_F32(c_float_0p1,0,1); + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } - // c[0, 32-47] - MULRND_F32(c_float_0p2,0,2); + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + selector4 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); - // c[0, 48-63] - MULRND_F32(c_float_0p3,0,3); + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector2,zero_point1); - // c[1, 0-15] - MULRND_F32(c_float_1p0,1,0); + // c[0, 32-47] + SCL_MULRND_F32(c_float_0p2,selector3,zero_point2); - // c[1, 16-31] - MULRND_F32(c_float_1p1,1,1); + // c[0, 48-63] + SCL_MULRND_F32(c_float_0p3,selector4,zero_point3); - // c[1, 32-47] - MULRND_F32(c_float_1p2,1,2); + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector1,zero_point0); - // c[1, 48-63] - MULRND_F32(c_float_1p3,1,3); + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); - // c[2, 0-15] - MULRND_F32(c_float_2p0,2,0); + // c[1, 32-47] + SCL_MULRND_F32(c_float_1p2,selector3,zero_point2); - // c[2, 16-31] - MULRND_F32(c_float_2p1,2,1); + // c[1, 48-63] + SCL_MULRND_F32(c_float_1p3,selector4,zero_point3); - // c[2, 32-47] - MULRND_F32(c_float_2p2,2,2); + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector1,zero_point0); - // c[2, 48-63] - MULRND_F32(c_float_2p3,2,3); + // c[2, 16-31] + SCL_MULRND_F32(c_float_2p1,selector2,zero_point1); - // c[3, 0-15] - MULRND_F32(c_float_3p0,3,0); + // c[2, 32-47] + SCL_MULRND_F32(c_float_2p2,selector3,zero_point2); - // c[3, 16-31] - MULRND_F32(c_float_3p1,3,1); + // c[2, 48-63] + SCL_MULRND_F32(c_float_2p3,selector4,zero_point3); - // c[3, 32-47] - MULRND_F32(c_float_3p2,3,2); + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector1,zero_point0); - // c[3, 48-63] - MULRND_F32(c_float_3p3,3,3); + // c[3, 16-31] + SCL_MULRND_F32(c_float_3p1,selector2,zero_point1); - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } + // c[3, 32-47] + SCL_MULRND_F32(c_float_3p2,selector3,zero_point2); -POST_OPS_4x64_DISABLE: - ; + // c[3, 48-63] + SCL_MULRND_F32(c_float_3p3,selector4,zero_point3); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 3 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 2 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 3 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); - // Case where the output C matrix is bf16 (downscaled) and this is the - // final write for a given block within C. - if ( ( post_ops_attr.buf_downscale != NULL ) && - ( post_ops_attr.is_last_k == TRUE ) ) - { - // Generate a mask16 of all 1's. - __m512i selector_a = _mm512_setzero_epi32(); - __m512i selector_b = _mm512_set1_epi32( 10 ); - __mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b ); + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector1,zero_point0); - // Store the results in downscaled type (bf16 instead of float). + // c[0, 32-47] + SCL_MULRND_F32(c_float_0p2,selector1,zero_point0); - // c[0, 0-15] - CVT_STORE_F32_BF16_MASK(c_float_0p0,0,0); + // c[0, 48-63] + SCL_MULRND_F32(c_float_0p3,selector1,zero_point0); - // c[0, 16-31] - CVT_STORE_F32_BF16_MASK(c_float_0p1,0,1); + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector2,zero_point1); - // c[0, 32-47] - CVT_STORE_F32_BF16_MASK(c_float_0p2,0,2); + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); - // c[0, 48-63] - CVT_STORE_F32_BF16_MASK(c_float_0p3,0,3); + // c[1, 32-47] + SCL_MULRND_F32(c_float_1p2,selector2,zero_point1); - // c[1, 0-15] - CVT_STORE_F32_BF16_MASK(c_float_1p0,1,0); + // c[1, 48-63] + SCL_MULRND_F32(c_float_1p3,selector2,zero_point1); - // c[1, 16-31] - CVT_STORE_F32_BF16_MASK(c_float_1p1,1,1); + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector3,zero_point2); - // c[1, 32-47] - CVT_STORE_F32_BF16_MASK(c_float_1p2,1,2); + // c[2, 16-31] + SCL_MULRND_F32(c_float_2p1,selector3,zero_point2); - // c[1, 48-63] - CVT_STORE_F32_BF16_MASK(c_float_1p3,1,3); + // c[2, 32-47] + SCL_MULRND_F32(c_float_2p2,selector3,zero_point2); - // c[2, 0-15] - CVT_STORE_F32_BF16_MASK(c_float_2p0,2,0); + // c[2, 48-63] + SCL_MULRND_F32(c_float_2p3,selector3,zero_point2); - // c[2, 16-31] - CVT_STORE_F32_BF16_MASK(c_float_2p1,2,1); + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector4,zero_point3); - // c[2, 32-47] - CVT_STORE_F32_BF16_MASK(c_float_2p2,2,2); + // c[3, 16-31] + SCL_MULRND_F32(c_float_3p1,selector4,zero_point3); - // c[2, 48-63] - CVT_STORE_F32_BF16_MASK(c_float_2p3,2,3); + // c[3, 32-47] + SCL_MULRND_F32(c_float_3p2,selector4,zero_point3); - // c[3, 0-15] - CVT_STORE_F32_BF16_MASK(c_float_3p0,3,0); + // c[3, 48-63] + SCL_MULRND_F32(c_float_3p3,selector4,zero_point3); + } - // c[3, 16-31] - CVT_STORE_F32_BF16_MASK(c_float_3p1,3,1); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_4x64: + { + __m512 selector3; + __m512 selector4; + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; - // c[3, 32-47] - CVT_STORE_F32_BF16_MASK(c_float_3p2,3,2); + // c[0:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,0); + + // c[1:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,1); + + // c[2:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,2); + + // c[3:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,3); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,0); + + // c[1:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,1); + + // c[2:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,2); + + // c[3:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,3); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_4x64: + { + __m512 selector3; + __m512 selector4; + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,0); + + // c[1:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,1); + + // c[2:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,2); + + // c[3:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,3); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,0); + + // c[1:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,1); + + // c[2:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,2); + + // c[3:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,3); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_4x64: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[0, 16-31] + SWISH_F32_AVX512_DEF(c_float_0p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[0, 32-47] + SWISH_F32_AVX512_DEF(c_float_0p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[0, 48-63] + SWISH_F32_AVX512_DEF(c_float_0p3, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 0-15] + SWISH_F32_AVX512_DEF(c_float_1p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 16-31] + SWISH_F32_AVX512_DEF(c_float_1p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 32-47] + SWISH_F32_AVX512_DEF(c_float_1p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 48-63] + SWISH_F32_AVX512_DEF(c_float_1p3, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 0-15] + SWISH_F32_AVX512_DEF(c_float_2p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 16-31] + SWISH_F32_AVX512_DEF(c_float_2p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 32-47] + SWISH_F32_AVX512_DEF(c_float_2p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 48-63] + SWISH_F32_AVX512_DEF(c_float_2p3, selector1, al_in, r, r2, z, dn, ex_out); + + // c[3, 0-15] + SWISH_F32_AVX512_DEF(c_float_3p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[3, 16-31] + SWISH_F32_AVX512_DEF(c_float_3p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[3, 32-47] + SWISH_F32_AVX512_DEF(c_float_3p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[3, 48-63] + SWISH_F32_AVX512_DEF(c_float_3p3, selector1, al_in, r, r2, z, dn, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + +POST_OPS_4x64_DISABLE: + ; + + // Case where the output C matrix is bf16 (downscaled) and this is the + // final write for a given block within C. + if ( ( post_ops_attr.buf_downscale != NULL ) && + ( post_ops_attr.is_last_k == TRUE ) ) + { + // Generate a mask16 of all 1's. + __m512i selector_a = _mm512_setzero_epi32(); + __m512i selector_b = _mm512_set1_epi32( 10 ); + __mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b ); + + // Store the results in downscaled type (bf16 instead of float). + + // c[0, 0-15] + CVT_STORE_F32_BF16_MASK(c_float_0p0,0,0); + + // c[0, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_0p1,0,1); + + // c[0, 32-47] + CVT_STORE_F32_BF16_MASK(c_float_0p2,0,2); + + // c[0, 48-63] + CVT_STORE_F32_BF16_MASK(c_float_0p3,0,3); + + // c[1, 0-15] + CVT_STORE_F32_BF16_MASK(c_float_1p0,1,0); + + // c[1, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_1p1,1,1); + + // c[1, 32-47] + CVT_STORE_F32_BF16_MASK(c_float_1p2,1,2); + + // c[1, 48-63] + CVT_STORE_F32_BF16_MASK(c_float_1p3,1,3); + + // c[2, 0-15] + CVT_STORE_F32_BF16_MASK(c_float_2p0,2,0); + + // c[2, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_2p1,2,1); + + // c[2, 32-47] + CVT_STORE_F32_BF16_MASK(c_float_2p2,2,2); + + // c[2, 48-63] + CVT_STORE_F32_BF16_MASK(c_float_2p3,2,3); + + // c[3, 0-15] + CVT_STORE_F32_BF16_MASK(c_float_3p0,3,0); + + // c[3, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_3p1,3,1); + + // c[3, 32-47] + CVT_STORE_F32_BF16_MASK(c_float_3p2,3,2); // c[3, 48-63] CVT_STORE_F32_BF16_MASK(c_float_3p3,3,3); } - + // Case where the output C matrix is float else { @@ -2006,7 +2734,10 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x64) &&POST_OPS_GELU_TANH_3x64, &&POST_OPS_GELU_ERF_3x64, &&POST_OPS_CLIP_3x64, - &&POST_OPS_DOWNSCALE_3x64 + &&POST_OPS_DOWNSCALE_3x64, + &&POST_OPS_MATRIX_ADD_3x64, + &&POST_OPS_SWISH_3x64, + &&POST_OPS_MATRIX_MUL_3x64 }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -2242,18 +2973,29 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x64) if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) { - selector1 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j ); - selector2 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ); - selector3 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 2 * 16 ) ); - selector4 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_LOAD(selector1, bias_mask, 0); + BF16_F32_BIAS_LOAD(selector2, bias_mask, 1); + BF16_F32_BIAS_LOAD(selector3, bias_mask, 2); + BF16_F32_BIAS_LOAD(selector4, bias_mask, 3); + } + else + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + selector4 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + } // c[0,0-15] c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); @@ -2299,15 +3041,25 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x64) // the ic index, and each bias element corresponds to an // entire row of the transposed output array, instead of an // entire column. - selector1 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_i + 0 ) ); - selector2 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_i + 1 ) ); - selector3 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_i + 2 ) ); + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + BF16_F32_BIAS_BCAST(selector2, bias_mask, 1); + BF16_F32_BIAS_BCAST(selector3, bias_mask, 2); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 2 ) ); + } // c[0,0-15] c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); @@ -2564,44 +3316,314 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x64) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_3x64: { + __m512 selector3 = _mm512_setzero_ps(); + __m512 selector4 = _mm512_setzero_ps(); + + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + __m512 zero_point2 = _mm512_setzero_ps(); + __m512 zero_point3 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + selector4 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector2,zero_point1); + + // c[0, 32-47] + SCL_MULRND_F32(c_float_0p2,selector3,zero_point2); + + // c[0, 48-63] + SCL_MULRND_F32(c_float_0p3,selector4,zero_point3); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector1,zero_point0); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[1, 32-47] + SCL_MULRND_F32(c_float_1p2,selector3,zero_point2); + + // c[1, 48-63] + SCL_MULRND_F32(c_float_1p3,selector4,zero_point3); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector1,zero_point0); + + // c[2, 16-31] + SCL_MULRND_F32(c_float_2p1,selector2,zero_point1); + + // c[2, 32-47] + SCL_MULRND_F32(c_float_2p2,selector3,zero_point2); + + // c[2, 48-63] + SCL_MULRND_F32(c_float_2p3,selector4,zero_point3); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 2 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 2 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector1,zero_point0); + + // c[0, 32-47] + SCL_MULRND_F32(c_float_0p2,selector1,zero_point0); + + // c[0, 48-63] + SCL_MULRND_F32(c_float_0p3,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector2,zero_point1); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[1, 32-47] + SCL_MULRND_F32(c_float_1p2,selector2,zero_point1); + + // c[1, 48-63] + SCL_MULRND_F32(c_float_1p3,selector2,zero_point1); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector3,zero_point2); + + // c[2, 16-31] + SCL_MULRND_F32(c_float_2p1,selector3,zero_point2); + + // c[2, 32-47] + SCL_MULRND_F32(c_float_2p2,selector3,zero_point2); + + // c[2, 48-63] + SCL_MULRND_F32(c_float_2p3,selector3,zero_point2); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_3x64: + { + __m512 selector3; + __m512 selector4; + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,0); + + // c[1:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,1); + + // c[2:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,2); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,0); + + // c[1:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,1); + + // c[2:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,2); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_3x64: + { + __m512 selector3; + __m512 selector4; + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,0); + + // c[1:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,1); + + // c[2:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,2); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,0); + + // c[1:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,1); + + // c[2:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,2); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_3x64: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + // c[0, 0-15] - MULRND_F32(c_float_0p0,0,0); + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); // c[0, 16-31] - MULRND_F32(c_float_0p1,0,1); + SWISH_F32_AVX512_DEF(c_float_0p1, selector1, al_in, r, r2, z, dn, ex_out); // c[0, 32-47] - MULRND_F32(c_float_0p2,0,2); + SWISH_F32_AVX512_DEF(c_float_0p2, selector1, al_in, r, r2, z, dn, ex_out); // c[0, 48-63] - MULRND_F32(c_float_0p3,0,3); + SWISH_F32_AVX512_DEF(c_float_0p3, selector1, al_in, r, r2, z, dn, ex_out); // c[1, 0-15] - MULRND_F32(c_float_1p0,1,0); + SWISH_F32_AVX512_DEF(c_float_1p0, selector1, al_in, r, r2, z, dn, ex_out); // c[1, 16-31] - MULRND_F32(c_float_1p1,1,1); + SWISH_F32_AVX512_DEF(c_float_1p1, selector1, al_in, r, r2, z, dn, ex_out); // c[1, 32-47] - MULRND_F32(c_float_1p2,1,2); + SWISH_F32_AVX512_DEF(c_float_1p2, selector1, al_in, r, r2, z, dn, ex_out); // c[1, 48-63] - MULRND_F32(c_float_1p3,1,3); + SWISH_F32_AVX512_DEF(c_float_1p3, selector1, al_in, r, r2, z, dn, ex_out); // c[2, 0-15] - MULRND_F32(c_float_2p0,2,0); + SWISH_F32_AVX512_DEF(c_float_2p0, selector1, al_in, r, r2, z, dn, ex_out); // c[2, 16-31] - MULRND_F32(c_float_2p1,2,1); + SWISH_F32_AVX512_DEF(c_float_2p1, selector1, al_in, r, r2, z, dn, ex_out); // c[2, 32-47] - MULRND_F32(c_float_2p2,2,2); + SWISH_F32_AVX512_DEF(c_float_2p2, selector1, al_in, r, r2, z, dn, ex_out); // c[2, 48-63] - MULRND_F32(c_float_2p3,2,3); + SWISH_F32_AVX512_DEF(c_float_2p3, selector1, al_in, r, r2, z, dn, ex_out); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -2655,7 +3677,7 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x64) // c[2, 48-63] CVT_STORE_F32_BF16_MASK(c_float_2p3,2,3); } - + // Case where the output C matrix is float else { @@ -2710,7 +3732,10 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x64) &&POST_OPS_GELU_TANH_2x64, &&POST_OPS_GELU_ERF_2x64, &&POST_OPS_CLIP_2x64, - &&POST_OPS_DOWNSCALE_2x64 + &&POST_OPS_DOWNSCALE_2x64, + &&POST_OPS_MATRIX_ADD_2x64, + &&POST_OPS_SWISH_2x64, + &&POST_OPS_MATRIX_MUL_2x64 }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -2888,18 +3913,29 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x64) if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) { - selector1 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j ); - selector2 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ); - selector3 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 2 * 16 ) ); - selector4 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_LOAD(selector1, bias_mask, 0); + BF16_F32_BIAS_LOAD(selector2, bias_mask, 1); + BF16_F32_BIAS_LOAD(selector3, bias_mask, 2); + BF16_F32_BIAS_LOAD(selector4, bias_mask, 3); + } + else + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + selector4 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + } // c[0,0-15] c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); @@ -2933,12 +3969,21 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x64) // the ic index, and each bias element corresponds to an // entire row of the transposed output array, instead of an // entire column. - selector1 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_i + 0 ) ); - selector2 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_i + 1 ) ); + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + BF16_F32_BIAS_BCAST(selector2, bias_mask, 1); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 1 ) ); + } // c[0,0-15] c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); @@ -3123,32 +4168,259 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x64) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_2x64: { + __m512 selector3 = _mm512_setzero_ps(); + __m512 selector4 = _mm512_setzero_ps(); + + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + __m512 zero_point2 = _mm512_setzero_ps(); + __m512 zero_point3 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + selector4 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector2,zero_point1); + + // c[0, 32-47] + SCL_MULRND_F32(c_float_0p2,selector3,zero_point2); + + // c[0, 48-63] + SCL_MULRND_F32(c_float_0p3,selector4,zero_point3); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector1,zero_point0); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[1, 32-47] + SCL_MULRND_F32(c_float_1p2,selector3,zero_point2); + + // c[1, 48-63] + SCL_MULRND_F32(c_float_1p3,selector4,zero_point3); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 1 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 1 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector1,zero_point0); + + // c[0, 32-47] + SCL_MULRND_F32(c_float_0p2,selector1,zero_point0); + + // c[0, 48-63] + SCL_MULRND_F32(c_float_0p3,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector2,zero_point1); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[1, 32-47] + SCL_MULRND_F32(c_float_1p2,selector2,zero_point1); + + // c[1, 48-63] + SCL_MULRND_F32(c_float_1p3,selector2,zero_point1); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_2x64: + { + __m512 selector3; + __m512 selector4; + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,0); + + // c[1:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,1); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,0); + + // c[1:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,1); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_2x64: + { + __m512 selector3; + __m512 selector4; + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,0); + + // c[1:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,1); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,0); + + // c[1:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,1); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_2x64: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + // c[0, 0-15] - MULRND_F32(c_float_0p0,0,0); + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); // c[0, 16-31] - MULRND_F32(c_float_0p1,0,1); + SWISH_F32_AVX512_DEF(c_float_0p1, selector1, al_in, r, r2, z, dn, ex_out); // c[0, 32-47] - MULRND_F32(c_float_0p2,0,2); + SWISH_F32_AVX512_DEF(c_float_0p2, selector1, al_in, r, r2, z, dn, ex_out); // c[0, 48-63] - MULRND_F32(c_float_0p3,0,3); + SWISH_F32_AVX512_DEF(c_float_0p3, selector1, al_in, r, r2, z, dn, ex_out); // c[1, 0-15] - MULRND_F32(c_float_1p0,1,0); + SWISH_F32_AVX512_DEF(c_float_1p0, selector1, al_in, r, r2, z, dn, ex_out); // c[1, 16-31] - MULRND_F32(c_float_1p1,1,1); + SWISH_F32_AVX512_DEF(c_float_1p1, selector1, al_in, r, r2, z, dn, ex_out); // c[1, 32-47] - MULRND_F32(c_float_1p2,1,2); + SWISH_F32_AVX512_DEF(c_float_1p2, selector1, al_in, r, r2, z, dn, ex_out); // c[1, 48-63] - MULRND_F32(c_float_1p3,1,3); + SWISH_F32_AVX512_DEF(c_float_1p3, selector1, al_in, r, r2, z, dn, ex_out); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -3191,7 +4463,7 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x64) // c[1, 48-63] CVT_STORE_F32_BF16_MASK(c_float_1p3,1,3); } - + // Case where the output C matrix is float else { @@ -3234,7 +4506,10 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x64) &&POST_OPS_GELU_TANH_1x64, &&POST_OPS_GELU_ERF_1x64, &&POST_OPS_CLIP_1x64, - &&POST_OPS_DOWNSCALE_1x64 + &&POST_OPS_DOWNSCALE_1x64, + &&POST_OPS_MATRIX_ADD_1x64, + &&POST_OPS_SWISH_1x64, + &&POST_OPS_MATRIX_MUL_1x64 }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -3346,18 +4621,29 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x64) if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) { - selector1 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j ); - selector2 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ); - selector3 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 2 * 16 ) ); - selector4 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_LOAD(selector1, bias_mask, 0); + BF16_F32_BIAS_LOAD(selector2, bias_mask, 1); + BF16_F32_BIAS_LOAD(selector3, bias_mask, 2); + BF16_F32_BIAS_LOAD(selector4, bias_mask, 3); + } + else + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + selector4 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + } // c[0,0-15] c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); @@ -3379,9 +4665,17 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x64) // the ic index, and each bias element corresponds to an // entire row of the transposed output array, instead of an // entire column. - selector1 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_i + 0 ) ); + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + } // c[0,0-15] c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); @@ -3494,20 +4788,204 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x64) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_1x64: { + __m512 selector3 = _mm512_setzero_ps(); + __m512 selector4 = _mm512_setzero_ps(); + + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + __m512 zero_point2 = _mm512_setzero_ps(); + __m512 zero_point3 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + selector4 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector2,zero_point1); + + // c[0, 32-47] + SCL_MULRND_F32(c_float_0p2,selector3,zero_point2); + + // c[0, 48-63] + SCL_MULRND_F32(c_float_0p3,selector4,zero_point3); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector1,zero_point0); + + // c[0, 32-47] + SCL_MULRND_F32(c_float_0p2,selector1,zero_point0); + + // c[0, 48-63] + SCL_MULRND_F32(c_float_0p3,selector1,zero_point0); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_1x64: + { + __m512 selector3; + __m512 selector4; + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,0); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,0); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_1x64: + { + __m512 selector3; + __m512 selector4; + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,0); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,0); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_1x64: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + // c[0, 0-15] - MULRND_F32(c_float_0p0,0,0); + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); // c[0, 16-31] - MULRND_F32(c_float_0p1,0,1); + SWISH_F32_AVX512_DEF(c_float_0p1, selector1, al_in, r, r2, z, dn, ex_out); // c[0, 32-47] - MULRND_F32(c_float_0p2,0,2); + SWISH_F32_AVX512_DEF(c_float_0p2, selector1, al_in, r, r2, z, dn, ex_out); // c[0, 48-63] - MULRND_F32(c_float_0p3,0,3); + SWISH_F32_AVX512_DEF(c_float_0p3, selector1, al_in, r, r2, z, dn, ex_out); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -3522,7 +5000,7 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x64) __m512i selector_a = _mm512_setzero_epi32(); __m512i selector_b = _mm512_set1_epi32( 10 ); __mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b ); - + // Store the results in downscaled type (bf16 instead of float). // c[0, 0-15] @@ -3537,7 +5015,7 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x64) // c[0, 48-63] CVT_STORE_F32_BF16_MASK(c_float_0p3,0,3); } - + // Case where the output C matrix is float else { diff --git a/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_m_fringe_bf16s4f32of32_amd512vnni.c b/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_m_fringe_bf16s4f32of32_amd512vnni.c new file mode 100644 index 0000000000..d2fe6615ce --- /dev/null +++ b/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_m_fringe_bf16s4f32of32_amd512vnni.c @@ -0,0 +1,5524 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include +#include "blis.h" + +#ifdef BLIS_ADDON_LPGEMM + +#include "lpgemm_f32_kern_macros.h" +#include "../int4_utils_avx512.h" + +#ifndef LPGEMM_BF16_JIT +// 5x64 bf16 kernel +LPGEMM_M_FRINGE_KERN(bfloat16, int8_t, float, bf16s4f32of32_5x64) +{ + + dim_t pre_op_off = post_ops_attr.pre_op_off; + + static void* post_ops_labels[] = + { + &&POST_OPS_5x64_DISABLE, + &&POST_OPS_BIAS_5x64, + &&POST_OPS_RELU_5x64, + &&POST_OPS_RELU_SCALE_5x64, + &&POST_OPS_GELU_TANH_5x64, + &&POST_OPS_GELU_ERF_5x64, + &&POST_OPS_CLIP_5x64, + &&POST_OPS_DOWNSCALE_5x64, + &&POST_OPS_MATRIX_ADD_5x64, + &&POST_OPS_SWISH_5x64, + &&POST_OPS_MATRIX_MUL_5x64 + }; + + + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int16_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + __m512bh b1; + __m512bh b2; + __m512bh b3; + + __m256i b0_s4; + __m256i b1_s4; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + __m512bh a_bf16_1; + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + __m512 c_float_0p1 = _mm512_setzero_ps(); + __m512 c_float_0p2 = _mm512_setzero_ps(); + __m512 c_float_0p3 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + __m512 c_float_1p1 = _mm512_setzero_ps(); + __m512 c_float_1p2 = _mm512_setzero_ps(); + __m512 c_float_1p3 = _mm512_setzero_ps(); + + __m512 c_float_2p0 = _mm512_setzero_ps(); + __m512 c_float_2p1 = _mm512_setzero_ps(); + __m512 c_float_2p2 = _mm512_setzero_ps(); + __m512 c_float_2p3 = _mm512_setzero_ps(); + + __m512 c_float_3p0 = _mm512_setzero_ps(); + __m512 c_float_3p1 = _mm512_setzero_ps(); + __m512 c_float_3p2 = _mm512_setzero_ps(); + __m512 c_float_3p3 = _mm512_setzero_ps(); + + __m512 c_float_4p0 = _mm512_setzero_ps(); + __m512 c_float_4p1 = _mm512_setzero_ps(); + __m512 c_float_4p2 = _mm512_setzero_ps(); + __m512 c_float_4p3 = _mm512_setzero_ps(); + + __m512i shift_idx_64; + MULTISHIFT_32BIT_8_INT4_IDX_64ELEM(shift_idx_64); + __m512i sign_comp = _mm512_set1_epi8(0x08); + bool signed_upscale = true; + + /* regs to store intermediate int8 values */ + __m512i b0_s8, b1_s8; + + /* Regs to store F32 scale values */ + __m512 scale0, scale1, scale2, scale3, scale4, scale5, scale6, scale7; + /* Reg to store masks to interleave scale factor */ + __m512i mask_scale1, mask_scale2; + + mask_scale1 = _mm512_set_epi32( 0x17, 0x07, 0x16, 0x06, 0x15, 0x05, 0x14, + 0x04, 0x13, 0x03, 0x12, 0x02, 0x11, 0x01, + 0x10, 0x00 ); + + mask_scale2 = _mm512_set_epi32( 0x1F, 0x0F, 0x1E, 0x0E, 0x1D, 0x0D, 0x1C, + 0x0C, 0x1B, 0x0B, 0x1A, 0x0A, 0x19, 0x09, + 0x18, 0x08); + + if( post_ops_attr.pre_op_scale_factor_len > 1 ) + { + // load and interleave scale factor vectors + scale0 = _mm512_loadu_ps( (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off); + scale2 = _mm512_loadu_ps( (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off + 16 ); + scale4 = _mm512_loadu_ps( (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off + 32 ); + scale6 = _mm512_loadu_ps( (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off + 48 ); + + scale1 = _mm512_permutex2var_ps( scale0, mask_scale2, scale0 ); + scale0 = _mm512_permutex2var_ps( scale0, mask_scale1, scale0 ); + scale3 = _mm512_permutex2var_ps( scale2, mask_scale2, scale2 ); + scale2 = _mm512_permutex2var_ps( scale2, mask_scale1, scale2 ); + scale5 = _mm512_permutex2var_ps( scale4, mask_scale2, scale4 ); + scale4 = _mm512_permutex2var_ps( scale4, mask_scale1, scale4 ); + scale7 = _mm512_permutex2var_ps( scale6, mask_scale2, scale6 ); + scale6 = _mm512_permutex2var_ps( scale6, mask_scale1, scale6 ); + + } + else + { + scale0 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale1 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale2 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale3 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale4 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale5 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale6 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale7 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + } + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + b0_s4 = _mm256_loadu_si256( (__m256i const *)( b + ( rs_b * kr ) / 2 ) ); + + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_64, \ + sign_comp, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_16( b0_s8, 0, scale0 ) ); + + b1 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 3, scale3 ), + CVT_INT8_F32_SCAL_16( b0_s8, 2, scale2 ) ); + + b1_s4 = _mm256_loadu_si256( (__m256i const *)( b + ( ( rs_b * kr ) / 2 ) + 32 ) ); + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( b1_s4, b1_s8, shift_idx_64, \ + sign_comp, signed_upscale); + + b2 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b1_s8, 1, scale5 ), + CVT_INT8_F32_SCAL_16( b1_s8, 0, scale4 ) ); + + b3 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b1_s8, 3, scale7 ), + CVT_INT8_F32_SCAL_16( b1_s8, 2, scale6 ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-63] = a[0,kr:kr+2]*b[kr:kr+2,0-63] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_1 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + c_float_0p3 = _mm512_dpbf16_ps( c_float_0p3, a_bf16_0, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-63] = a[1,kr:kr+2]*b[kr:kr+2,0-63] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_1, b0 ); + + // Broadcast a[2,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_1, b1 ); + c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_1, b2 ); + c_float_1p3 = _mm512_dpbf16_ps( c_float_1p3, a_bf16_1, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-63] = a[2,kr:kr+2]*b[kr:kr+2,0-63] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + + // Broadcast a[3,kr:kr+2]. + a_bf16_1 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); + c_float_2p3 = _mm512_dpbf16_ps( c_float_2p3, a_bf16_0, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-63] = a[3,kr:kr+2]*b[kr:kr+2,0-63] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_1, b0 ); + + // Broadcast a[4,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); + + c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_1, b1 ); + c_float_3p2 = _mm512_dpbf16_ps( c_float_3p2, a_bf16_1, b2 ); + c_float_3p3 = _mm512_dpbf16_ps( c_float_3p3, a_bf16_1, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[4,0-63] = a[4,kr:kr+2]*b[kr:kr+2,0-63] + c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); + c_float_4p1 = _mm512_dpbf16_ps( c_float_4p1, a_bf16_0, b1 ); + c_float_4p2 = _mm512_dpbf16_ps( c_float_4p2, a_bf16_0, b2 ); + c_float_4p3 = _mm512_dpbf16_ps( c_float_4p3, a_bf16_0, b3 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + // Broadcast a[0,kr:kr+4]. + a_kfringe_buf = *( a + (rs_a * 0) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + b0_s4 = _mm256_loadu_si256( (__m256i const *)( b + ( rs_b * k_full_pieces ) / 2 ) ); + + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_64, \ + sign_comp, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_16( b0_s8, 0, scale0 ) ); + + b1 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 3, scale3 ), + CVT_INT8_F32_SCAL_16( b0_s8, 2, scale2 ) ); + + b1_s4 = _mm256_loadu_si256( (__m256i const *)( b + ( ( rs_b * k_full_pieces ) / 2 ) + 32 ) ); + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( b1_s4, b1_s8, shift_idx_64, \ + sign_comp, signed_upscale); + + b2 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b1_s8, 1, scale5 ), + CVT_INT8_F32_SCAL_16( b1_s8, 0, scale4 ) ); + + b3 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b1_s8, 3, scale7 ), + CVT_INT8_F32_SCAL_16( b1_s8, 2, scale6 ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-63] = a[0,kr:kr+2]*b[kr:kr+2,0-63] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 1) + (cs_a * ( k_full_pieces ))); + a_bf16_1 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + c_float_0p3 = _mm512_dpbf16_ps( c_float_0p3, a_bf16_0, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-63] = a[1,kr:kr+2]*b[kr:kr+2,0-63] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_1, b0 ); + + // Broadcast a[2,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 2) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_1, b1 ); + c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_1, b2 ); + c_float_1p3 = _mm512_dpbf16_ps( c_float_1p3, a_bf16_1, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-63] = a[2,kr:kr+2]*b[kr:kr+2,0-63] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + + // Broadcast a[3,kr:kr+4]. + a_kfringe_buf = *(a + (rs_a * 3) + (cs_a * ( k_full_pieces ))); + a_bf16_1 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); + c_float_2p3 = _mm512_dpbf16_ps( c_float_2p3, a_bf16_0, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-63] = a[3,kr:kr+2]*b[kr:kr+2,0-63] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_1, b0 ); + + // Broadcast a[4,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 4) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_1, b1 ); + c_float_3p2 = _mm512_dpbf16_ps( c_float_3p2, a_bf16_1, b2 ); + c_float_3p3 = _mm512_dpbf16_ps( c_float_3p3, a_bf16_1, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[4,0-63] = a[4,kr:kr+2]*b[kr:kr+2,0-63] + c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); + c_float_4p1 = _mm512_dpbf16_ps( c_float_4p1, a_bf16_0, b1 ); + c_float_4p2 = _mm512_dpbf16_ps( c_float_4p2, a_bf16_0, b2 ); + c_float_4p3 = _mm512_dpbf16_ps( c_float_4p3, a_bf16_0, b3 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + if ( alpha != 1 ) + { + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + c_float_0p1 = _mm512_mul_ps( selector1, c_float_0p1 ); + c_float_0p2 = _mm512_mul_ps( selector1, c_float_0p2 ); + c_float_0p3 = _mm512_mul_ps( selector1, c_float_0p3 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + c_float_1p1 = _mm512_mul_ps( selector1, c_float_1p1 ); + c_float_1p2 = _mm512_mul_ps( selector1, c_float_1p2 ); + c_float_1p3 = _mm512_mul_ps( selector1, c_float_1p3 ); + + c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); + c_float_2p1 = _mm512_mul_ps( selector1, c_float_2p1 ); + c_float_2p2 = _mm512_mul_ps( selector1, c_float_2p2 ); + c_float_2p3 = _mm512_mul_ps( selector1, c_float_2p3 ); + + c_float_3p0 = _mm512_mul_ps( selector1, c_float_3p0 ); + c_float_3p1 = _mm512_mul_ps( selector1, c_float_3p1 ); + c_float_3p2 = _mm512_mul_ps( selector1, c_float_3p2 ); + c_float_3p3 = _mm512_mul_ps( selector1, c_float_3p3 ); + + c_float_4p0 = _mm512_mul_ps( selector1, c_float_4p0 ); + c_float_4p1 = _mm512_mul_ps( selector1, c_float_4p1 ); + c_float_4p2 = _mm512_mul_ps( selector1, c_float_4p2 ); + c_float_4p3 = _mm512_mul_ps( selector1, c_float_4p3 ); + + } + + // Scale C by beta. + if ( beta != 0 ) + { + // For the downscaled api (C-bf16), the output C matrix values + // needs to be upscaled to float to be used for beta scale. + if ( ( post_ops_attr.buf_downscale != NULL ) && + ( post_ops_attr.is_first_k == TRUE ) ) + { + // c[0,0-15] + BF16_F32_BETA_OP(c_float_0p0,0,0,0,selector1,selector2) + + // c[0, 16-31] + BF16_F32_BETA_OP(c_float_0p1,0,0,1,selector1,selector2) + + // c[0,32-47] + BF16_F32_BETA_OP(c_float_0p2,0,0,2,selector1,selector2) + + // c[0,48-63] + BF16_F32_BETA_OP(c_float_0p3,0,0,3,selector1,selector2) + + // c[1,0-15] + BF16_F32_BETA_OP(c_float_1p0,0,1,0,selector1,selector2) + + // c[1,16-31] + BF16_F32_BETA_OP(c_float_1p1,0,1,1,selector1,selector2) + + // c[1,32-47] + BF16_F32_BETA_OP(c_float_1p2,0,1,2,selector1,selector2) + + // c[1,48-63] + BF16_F32_BETA_OP(c_float_1p3,0,1,3,selector1,selector2) + + // c[2,0-15] + BF16_F32_BETA_OP(c_float_2p0,0,2,0,selector1,selector2) + + // c[2,16-31] + BF16_F32_BETA_OP(c_float_2p1,0,2,1,selector1,selector2) + + // c[2,32-47] + BF16_F32_BETA_OP(c_float_2p2,0,2,2,selector1,selector2) + + // c[2,48-63] + BF16_F32_BETA_OP(c_float_2p3,0,2,3,selector1,selector2) + + // c[3,0-15] + BF16_F32_BETA_OP(c_float_3p0,0,3,0,selector1,selector2) + + // c[3,16-31] + BF16_F32_BETA_OP(c_float_3p1,0,3,1,selector1,selector2) + + // c[3,32-47] + BF16_F32_BETA_OP(c_float_3p2,0,3,2,selector1,selector2) + + // c[0,48-63] + BF16_F32_BETA_OP(c_float_3p3,0,3,3,selector1,selector2) + + // c[4,0-15] + BF16_F32_BETA_OP(c_float_4p0,0,4,0,selector1,selector2) + + // c[4,16-31] + BF16_F32_BETA_OP(c_float_4p1,0,4,1,selector1,selector2) + + // c[4,32-47] + BF16_F32_BETA_OP(c_float_4p2,0,4,2,selector1,selector2) + + // c[4,48-63] + BF16_F32_BETA_OP(c_float_4p3,0,4,3,selector1,selector2) + } + else + { + // c[0,0-15] + F32_F32_BETA_OP(c_float_0p0,0,0,0,selector1,selector2) + + // c[0, 16-31] + F32_F32_BETA_OP(c_float_0p1,0,0,1,selector1,selector2) + + // c[0,32-47] + F32_F32_BETA_OP(c_float_0p2,0,0,2,selector1,selector2) + + // c[0,48-63] + F32_F32_BETA_OP(c_float_0p3,0,0,3,selector1,selector2) + + // c[1,0-15] + F32_F32_BETA_OP(c_float_1p0,0,1,0,selector1,selector2) + + // c[1,16-31] + F32_F32_BETA_OP(c_float_1p1,0,1,1,selector1,selector2) + + // c[1,32-47] + F32_F32_BETA_OP(c_float_1p2,0,1,2,selector1,selector2) + + // c[1,48-63] + F32_F32_BETA_OP(c_float_1p3,0,1,3,selector1,selector2) + + // c[2,0-15] + F32_F32_BETA_OP(c_float_2p0,0,2,0,selector1,selector2) + + // c[2,16-31] + F32_F32_BETA_OP(c_float_2p1,0,2,1,selector1,selector2) + + // c[2,32-47] + F32_F32_BETA_OP(c_float_2p2,0,2,2,selector1,selector2) + + // c[2,48-63] + F32_F32_BETA_OP(c_float_2p3,0,2,3,selector1,selector2) + + // c[3,0-15] + F32_F32_BETA_OP(c_float_3p0,0,3,0,selector1,selector2) + + // c[3,16-31] + F32_F32_BETA_OP(c_float_3p1,0,3,1,selector1,selector2) + + // c[3,32-47] + F32_F32_BETA_OP(c_float_3p2,0,3,2,selector1,selector2) + + // c[0,48-63] + F32_F32_BETA_OP(c_float_3p3,0,3,3,selector1,selector2) + + // c[4,0-15] + F32_F32_BETA_OP(c_float_4p0,0,4,0,selector1,selector2) + + // c[4,16-31] + F32_F32_BETA_OP(c_float_4p1,0,4,1,selector1,selector2) + + // c[4,32-47] + F32_F32_BETA_OP(c_float_4p2,0,4,2,selector1,selector2) + + // c[4,48-63] + F32_F32_BETA_OP(c_float_4p3,0,4,3,selector1,selector2) + } + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_5x64: + { + __m512 selector3; + __m512 selector4; + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_LOAD(selector1, bias_mask, 0); + BF16_F32_BIAS_LOAD(selector2, bias_mask, 1); + BF16_F32_BIAS_LOAD(selector3, bias_mask, 2); + BF16_F32_BIAS_LOAD(selector4, bias_mask, 3); + } + else + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + selector4 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_add_ps( selector4, c_float_0p3 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector3, c_float_1p2 ); + + // c[1,48-63] + c_float_1p3 = _mm512_add_ps( selector4, c_float_1p3 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + + // c[2,48-63] + c_float_2p3 = _mm512_add_ps( selector4, c_float_2p3 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector2, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_add_ps( selector3, c_float_3p2 ); + + // c[3,48-63] + c_float_3p3 = _mm512_add_ps( selector4, c_float_3p3 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + + // c[4, 16-31] + c_float_4p1 = _mm512_add_ps( selector2, c_float_4p1 ); + + // c[4,32-47] + c_float_4p2 = _mm512_add_ps( selector3, c_float_4p2 ); + + // c[4,48-63] + c_float_4p3 = _mm512_add_ps( selector4, c_float_4p3 ); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the bias array will be accessed by + // the ic index, and each bias element corresponds to an + // entire row of the transposed output array, instead of an + // entire column. + __m512 selector5; + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + BF16_F32_BIAS_BCAST(selector2, bias_mask, 1); + BF16_F32_BIAS_BCAST(selector3, bias_mask, 2); + BF16_F32_BIAS_BCAST(selector4, bias_mask, 3); + BF16_F32_BIAS_BCAST(selector5, bias_mask, 4); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 3 ) ); + selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 4 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_add_ps( selector1, c_float_0p3 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector2, c_float_1p2 ); + + // c[1,48-63] + c_float_1p3 = _mm512_add_ps( selector2, c_float_1p3 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector3, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + + // c[2,48-63] + c_float_2p3 = _mm512_add_ps( selector3, c_float_2p3 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector4, c_float_3p0 ); + + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector4, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_add_ps( selector4, c_float_3p2 ); + + // c[3,48-63] + c_float_3p3 = _mm512_add_ps( selector4, c_float_3p3 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector5, c_float_4p0 ); + + // c[4, 16-31] + c_float_4p1 = _mm512_add_ps( selector5, c_float_4p1 ); + + // c[4,32-47] + c_float_4p2 = _mm512_add_ps( selector5, c_float_4p2 ); + + // c[4,48-63] + c_float_4p3 = _mm512_add_ps( selector5, c_float_4p3 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_5x64: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_max_ps( selector1, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_max_ps( selector1, c_float_0p3 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + c_float_1p1 = _mm512_max_ps( selector1, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_max_ps( selector1, c_float_1p2 ); + + // c[1,48-63] + c_float_1p3 = _mm512_max_ps( selector1, c_float_1p3 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + c_float_2p1 = _mm512_max_ps( selector1, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_max_ps( selector1, c_float_2p2 ); + + // c[2,48-63] + c_float_2p3 = _mm512_max_ps( selector1, c_float_2p3 ); + + // c[3,0-15] + c_float_3p0 = _mm512_max_ps( selector1, c_float_3p0 ); + + // c[3,16-31] + c_float_3p1 = _mm512_max_ps( selector1, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_max_ps( selector1, c_float_3p2 ); + + // c[3,48-63] + c_float_3p3 = _mm512_max_ps( selector1, c_float_3p3 ); + + // c[4,0-15] + c_float_4p0 = _mm512_max_ps( selector1, c_float_4p0 ); + + // c[4,16-31] + c_float_4p1 = _mm512_max_ps( selector1, c_float_4p1 ); + + // c[4,32-47] + c_float_4p2 = _mm512_max_ps( selector1, c_float_4p2 ); + + // c[4,48-63] + c_float_4p3 = _mm512_max_ps( selector1, c_float_4p3 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_5x64: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_0p2) + + // c[0, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_0p3) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_1p1) + + // c[1, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_1p2) + + // c[1, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_1p3) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_2p1) + + // c[2, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_2p2) + + // c[2, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_2p3) + + // c[3, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_3p0) + + // c[3, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_3p1) + + // c[3, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_3p2) + + // c[3, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_3p3) + + // c[4, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_4p0) + + // c[4, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_4p1) + + // c[4, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_4p2) + + // c[4, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_4p3) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_TANH_5x64: + { + __m512 dn, z, x, r2, r, x_tanh; + __m512i q; + + // c[0, 0-15] + GELU_TANH_F32_AVX512(c_float_0p0, r, r2, x, z, dn, x_tanh, q) + + // c[0, 16-31] + GELU_TANH_F32_AVX512(c_float_0p1, r, r2, x, z, dn, x_tanh, q) + + // c[0, 32-47] + GELU_TANH_F32_AVX512(c_float_0p2, r, r2, x, z, dn, x_tanh, q) + + // c[0, 48-63] + GELU_TANH_F32_AVX512(c_float_0p3, r, r2, x, z, dn, x_tanh, q) + + // c[1, 0-15] + GELU_TANH_F32_AVX512(c_float_1p0, r, r2, x, z, dn, x_tanh, q) + + // c[1, 16-31] + GELU_TANH_F32_AVX512(c_float_1p1, r, r2, x, z, dn, x_tanh, q) + + // c[1, 32-47] + GELU_TANH_F32_AVX512(c_float_1p2, r, r2, x, z, dn, x_tanh, q) + + // c[1, 48-63] + GELU_TANH_F32_AVX512(c_float_1p3, r, r2, x, z, dn, x_tanh, q) + + // c[2, 0-15] + GELU_TANH_F32_AVX512(c_float_2p0, r, r2, x, z, dn, x_tanh, q) + + // c[2, 16-31] + GELU_TANH_F32_AVX512(c_float_2p1, r, r2, x, z, dn, x_tanh, q) + + // c[2, 32-47] + GELU_TANH_F32_AVX512(c_float_2p2, r, r2, x, z, dn, x_tanh, q) + + // c[2, 48-63] + GELU_TANH_F32_AVX512(c_float_2p3, r, r2, x, z, dn, x_tanh, q) + + // c[3, 0-15] + GELU_TANH_F32_AVX512(c_float_3p0, r, r2, x, z, dn, x_tanh, q) + + // c[3, 16-31] + GELU_TANH_F32_AVX512(c_float_3p1, r, r2, x, z, dn, x_tanh, q) + + // c[3, 32-47] + GELU_TANH_F32_AVX512(c_float_3p2, r, r2, x, z, dn, x_tanh, q) + + // c[3, 48-63] + GELU_TANH_F32_AVX512(c_float_3p3, r, r2, x, z, dn, x_tanh, q) + + // c[4, 0-15] + GELU_TANH_F32_AVX512(c_float_4p0, r, r2, x, z, dn, x_tanh, q) + + // c[4, 16-31] + GELU_TANH_F32_AVX512(c_float_4p1, r, r2, x, z, dn, x_tanh, q) + + // c[4, 32-47] + GELU_TANH_F32_AVX512(c_float_4p2, r, r2, x, z, dn, x_tanh, q) + + // c[4, 48-63] + GELU_TANH_F32_AVX512(c_float_4p3, r, r2, x, z, dn, x_tanh, q) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_ERF_5x64: + { + __m512 x, r, x_erf; + + // c[0, 0-15] + GELU_ERF_F32_AVX512(c_float_0p0, r, x, x_erf) + + // c[0, 16-31] + GELU_ERF_F32_AVX512(c_float_0p1, r, x, x_erf) + + // c[0, 32-47] + GELU_ERF_F32_AVX512(c_float_0p2, r, x, x_erf) + + // c[0, 48-63] + GELU_ERF_F32_AVX512(c_float_0p3, r, x, x_erf) + + // c[1, 0-15] + GELU_ERF_F32_AVX512(c_float_1p0, r, x, x_erf) + + // c[1, 16-31] + GELU_ERF_F32_AVX512(c_float_1p1, r, x, x_erf) + + // c[1, 32-47] + GELU_ERF_F32_AVX512(c_float_1p2, r, x, x_erf) + + // c[1, 48-63] + GELU_ERF_F32_AVX512(c_float_1p3, r, x, x_erf) + + // c[2, 0-15] + GELU_ERF_F32_AVX512(c_float_2p0, r, x, x_erf) + + // c[2, 16-31] + GELU_ERF_F32_AVX512(c_float_2p1, r, x, x_erf) + + // c[2, 32-47] + GELU_ERF_F32_AVX512(c_float_2p2, r, x, x_erf) + + // c[2, 48-63] + GELU_ERF_F32_AVX512(c_float_2p3, r, x, x_erf) + + // c[3, 0-15] + GELU_ERF_F32_AVX512(c_float_3p0, r, x, x_erf) + + // c[3, 16-31] + GELU_ERF_F32_AVX512(c_float_3p1, r, x, x_erf) + + // c[3, 32-47] + GELU_ERF_F32_AVX512(c_float_3p2, r, x, x_erf) + + // c[3, 48-63] + GELU_ERF_F32_AVX512(c_float_3p3, r, x, x_erf) + + // c[4, 0-15] + GELU_ERF_F32_AVX512(c_float_4p0, r, x, x_erf) + + // c[4, 16-31] + GELU_ERF_F32_AVX512(c_float_4p1, r, x, x_erf) + + // c[4, 32-47] + GELU_ERF_F32_AVX512(c_float_4p2, r, x, x_erf) + + // c[4, 48-63] + GELU_ERF_F32_AVX512(c_float_4p3, r, x, x_erf) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_CLIP_5x64: + { + __m512 min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 ); + __m512 max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 ); + + // c[0, 0-15] + CLIP_F32_AVX512(c_float_0p0, min, max) + + // c[0, 16-31] + CLIP_F32_AVX512(c_float_0p1, min, max) + + // c[0, 32-47] + CLIP_F32_AVX512(c_float_0p2, min, max) + + // c[0, 48-63] + CLIP_F32_AVX512(c_float_0p3, min, max) + + // c[1, 0-15] + CLIP_F32_AVX512(c_float_1p0, min, max) + + // c[1, 16-31] + CLIP_F32_AVX512(c_float_1p1, min, max) + + // c[1, 32-47] + CLIP_F32_AVX512(c_float_1p2, min, max) + + // c[1, 48-63] + CLIP_F32_AVX512(c_float_1p3, min, max) + + // c[2, 0-15] + CLIP_F32_AVX512(c_float_2p0, min, max) + + // c[2, 16-31] + CLIP_F32_AVX512(c_float_2p1, min, max) + + // c[2, 32-47] + CLIP_F32_AVX512(c_float_2p2, min, max) + + // c[2, 48-63] + CLIP_F32_AVX512(c_float_2p3, min, max) + + // c[3, 0-15] + CLIP_F32_AVX512(c_float_3p0, min, max) + + // c[3, 16-31] + CLIP_F32_AVX512(c_float_3p1, min, max) + + // c[3, 32-47] + CLIP_F32_AVX512(c_float_3p2, min, max) + + // c[3, 48-63] + CLIP_F32_AVX512(c_float_3p3, min, max) + + // c[4, 0-15] + CLIP_F32_AVX512(c_float_4p0, min, max) + + // c[4, 16-31] + CLIP_F32_AVX512(c_float_4p1, min, max) + + // c[4, 32-47] + CLIP_F32_AVX512(c_float_4p2, min, max) + + // c[4, 48-63] + CLIP_F32_AVX512(c_float_4p3, min, max) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_5x64: + { + __m512 selector3 = _mm512_setzero_ps(); + __m512 selector4 = _mm512_setzero_ps(); + + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + __m512 zero_point2 = _mm512_setzero_ps(); + __m512 zero_point3 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + selector4 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector2,zero_point1); + + // c[0, 32-47] + SCL_MULRND_F32(c_float_0p2,selector3,zero_point2); + + // c[0, 48-63] + SCL_MULRND_F32(c_float_0p3,selector4,zero_point3); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector1,zero_point0); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[1, 32-47] + SCL_MULRND_F32(c_float_1p2,selector3,zero_point2); + + // c[1, 48-63] + SCL_MULRND_F32(c_float_1p3,selector4,zero_point3); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector1,zero_point0); + + // c[2, 16-31] + SCL_MULRND_F32(c_float_2p1,selector2,zero_point1); + + // c[2, 32-47] + SCL_MULRND_F32(c_float_2p2,selector3,zero_point2); + + // c[2, 48-63] + SCL_MULRND_F32(c_float_2p3,selector4,zero_point3); + + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector1,zero_point0); + + // c[3, 16-31] + SCL_MULRND_F32(c_float_3p1,selector2,zero_point1); + + // c[3, 32-47] + SCL_MULRND_F32(c_float_3p2,selector3,zero_point2); + + // c[3, 48-63] + SCL_MULRND_F32(c_float_3p3,selector4,zero_point3); + + // c[4, 0-15] + SCL_MULRND_F32(c_float_4p0,selector1,zero_point0); + + // c[4, 16-31] + SCL_MULRND_F32(c_float_4p1,selector2,zero_point1); + + // c[4, 32-47] + SCL_MULRND_F32(c_float_4p2,selector3,zero_point2); + + // c[4, 48-63] + SCL_MULRND_F32(c_float_4p3,selector4,zero_point3); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 3 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 2 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 3 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector1,zero_point0); + + // c[0, 32-47] + SCL_MULRND_F32(c_float_0p2,selector1,zero_point0); + + // c[0, 48-63] + SCL_MULRND_F32(c_float_0p3,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector2,zero_point1); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[1, 32-47] + SCL_MULRND_F32(c_float_1p2,selector2,zero_point1); + + // c[1, 48-63] + SCL_MULRND_F32(c_float_1p3,selector2,zero_point1); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector3,zero_point2); + + // c[2, 16-31] + SCL_MULRND_F32(c_float_2p1,selector3,zero_point2); + + // c[2, 32-47] + SCL_MULRND_F32(c_float_2p2,selector3,zero_point2); + + // c[2, 48-63] + SCL_MULRND_F32(c_float_2p3,selector3,zero_point2); + + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector4,zero_point3); + + // c[3, 16-31] + SCL_MULRND_F32(c_float_3p1,selector4,zero_point3); + + // c[3, 32-47] + SCL_MULRND_F32(c_float_3p2,selector4,zero_point3); + + // c[3, 48-63] + SCL_MULRND_F32(c_float_3p3,selector4,zero_point3); + + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 4 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 4 ) ) ); + } + // c[4, 0-15] + SCL_MULRND_F32(c_float_4p0,selector1,zero_point0); + + // c[4, 16-31] + SCL_MULRND_F32(c_float_4p1,selector1,zero_point0); + + // c[4, 32-47] + SCL_MULRND_F32(c_float_4p2,selector1,zero_point0); + + // c[4, 48-63] + SCL_MULRND_F32(c_float_4p3,selector1,zero_point0); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_5x64: + { + __m512 selector3; + __m512 selector4; + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,0); + + // c[1:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,1); + + // c[2:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,2); + + // c[3:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,3); + + // c[4:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,4); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,0); + + // c[1:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,1); + + // c[2:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,2); + + // c[3:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,3); + + // c[4:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,4); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_5x64: + { + __m512 selector3; + __m512 selector4; + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,0); + + // c[1:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,1); + + // c[2:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,2); + + // c[3:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,3); + + // c[4:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,4); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,0); + + // c[1:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,1); + + // c[2:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,2); + + // c[3:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,3); + + // c[4:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,4); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_5x64: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[0, 16-31] + SWISH_F32_AVX512_DEF(c_float_0p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[0, 32-47] + SWISH_F32_AVX512_DEF(c_float_0p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[0, 48-63] + SWISH_F32_AVX512_DEF(c_float_0p3, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 0-15] + SWISH_F32_AVX512_DEF(c_float_1p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 16-31] + SWISH_F32_AVX512_DEF(c_float_1p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 32-47] + SWISH_F32_AVX512_DEF(c_float_1p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 48-63] + SWISH_F32_AVX512_DEF(c_float_1p3, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 0-15] + SWISH_F32_AVX512_DEF(c_float_2p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 16-31] + SWISH_F32_AVX512_DEF(c_float_2p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 32-47] + SWISH_F32_AVX512_DEF(c_float_2p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 48-63] + SWISH_F32_AVX512_DEF(c_float_2p3, selector1, al_in, r, r2, z, dn, ex_out); + + // c[3, 0-15] + SWISH_F32_AVX512_DEF(c_float_3p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[3, 16-31] + SWISH_F32_AVX512_DEF(c_float_3p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[3, 32-47] + SWISH_F32_AVX512_DEF(c_float_3p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[3, 48-63] + SWISH_F32_AVX512_DEF(c_float_3p3, selector1, al_in, r, r2, z, dn, ex_out); + + // c[4, 0-15] + SWISH_F32_AVX512_DEF(c_float_4p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[4, 16-31] + SWISH_F32_AVX512_DEF(c_float_4p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[4, 32-47] + SWISH_F32_AVX512_DEF(c_float_4p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[4, 48-63] + SWISH_F32_AVX512_DEF(c_float_4p3, selector1, al_in, r, r2, z, dn, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_5x64_DISABLE: + ; + // Case where the output C matrix is bf16 (downscaled) and this is the + // final write for a given block within C. + if ( ( post_ops_attr.buf_downscale != NULL ) && + ( post_ops_attr.is_last_k == TRUE ) ) + { + // Generate a mask16 of all 1's. + __m512i selector_a = _mm512_setzero_epi32(); + __m512i selector_b = _mm512_set1_epi32( 10 ); + __mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b ); + + // Store the results in downscaled type (bf16 instead of float). + + // c[0, 0-15] + CVT_STORE_F32_BF16_MASK(c_float_0p0,0,0); + + // c[0, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_0p1,0,1); + + // c[0, 32-47] + CVT_STORE_F32_BF16_MASK(c_float_0p2,0,2); + + // c[0, 48-63] + CVT_STORE_F32_BF16_MASK(c_float_0p3,0,3); + + // c[1, 0-15] + CVT_STORE_F32_BF16_MASK(c_float_1p0,1,0); + + // c[1, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_1p1,1,1); + + // c[1, 32-47] + CVT_STORE_F32_BF16_MASK(c_float_1p2,1,2); + + // c[1, 48-63] + CVT_STORE_F32_BF16_MASK(c_float_1p3,1,3); + + // c[2, 0-15] + CVT_STORE_F32_BF16_MASK(c_float_2p0,2,0); + + // c[2, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_2p1,2,1); + + // c[2, 32-47] + CVT_STORE_F32_BF16_MASK(c_float_2p2,2,2); + + // c[2, 48-63] + CVT_STORE_F32_BF16_MASK(c_float_2p3,2,3); + + // c[3, 0-15] + CVT_STORE_F32_BF16_MASK(c_float_3p0,3,0); + + // c[3, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_3p1,3,1); + + // c[3, 32-47] + CVT_STORE_F32_BF16_MASK(c_float_3p2,3,2); + + // c[3, 48-63] + CVT_STORE_F32_BF16_MASK(c_float_3p3,3,3); + + // c[4, 0-15] + CVT_STORE_F32_BF16_MASK(c_float_4p0,4,0); + + // c[4, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_4p1,4,1); + + // c[4, 32-47] + CVT_STORE_F32_BF16_MASK(c_float_4p2,4,2); + + // c[4, 48-63] + CVT_STORE_F32_BF16_MASK(c_float_4p3,4,3); + + } + + // Case where the output C matrix is float + else + { + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); + + // c[0, 16-31] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 1*16 ), c_float_0p1 ); + + // c[0,32-47] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 2*16 ), c_float_0p2 ); + + // c[0,48-63] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 3*16 ), c_float_0p3 ); + + // c[1,0-15] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 0*16 ), c_float_1p0 ); + + // c[1,16-31] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 1*16 ), c_float_1p1 ); + + // c[1,32-47] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 2*16 ), c_float_1p2 ); + + // c[1,48-63] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 3*16 ), c_float_1p3 ); + + // c[2,0-15] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 0*16 ), c_float_2p0 ); + + // c[2,16-31] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 1*16 ), c_float_2p1 ); + + // c[2,32-47] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 2*16 ), c_float_2p2 ); + + // c[2,48-63] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 3*16 ), c_float_2p3 ); + + // c[3,0-15] + _mm512_storeu_ps( c + ( rs_c * 3 ) + ( 0*16 ), c_float_3p0 ); + + // c[3,16-31] + _mm512_storeu_ps( c + ( rs_c * 3 ) + ( 1*16 ), c_float_3p1 ); + + // c[3,32-47] + _mm512_storeu_ps( c + ( rs_c * 3 ) + ( 2*16 ), c_float_3p2 ); + + // c[3,48-63] + _mm512_storeu_ps( c + ( rs_c * 3 ) + ( 3*16 ), c_float_3p3 ); + + // c[4,0-15] + _mm512_storeu_ps( c + ( rs_c * 4 ) + ( 0*16 ), c_float_4p0 ); + + // c[4,16-31] + _mm512_storeu_ps( c + ( rs_c * 4 ) + ( 1*16 ), c_float_4p1 ); + + // c[4,32-47] + _mm512_storeu_ps( c + ( rs_c * 4 ) + ( 2*16 ), c_float_4p2 ); + + // c[4,48-63] + _mm512_storeu_ps( c + ( rs_c * 4 ) + ( 3*16 ), c_float_4p3 ); + + } +} + +// 4x64 bf16 kernel +LPGEMM_M_FRINGE_KERN(bfloat16, int8_t, float, bf16s4f32of32_4x64) +{ + + static void* post_ops_labels[] = + { + &&POST_OPS_4x64_DISABLE, + &&POST_OPS_BIAS_4x64, + &&POST_OPS_RELU_4x64, + &&POST_OPS_RELU_SCALE_4x64, + &&POST_OPS_GELU_TANH_4x64, + &&POST_OPS_GELU_ERF_4x64, + &&POST_OPS_CLIP_4x64, + &&POST_OPS_DOWNSCALE_4x64, + &&POST_OPS_MATRIX_ADD_4x64, + &&POST_OPS_SWISH_4x64, + &&POST_OPS_MATRIX_MUL_4x64 + }; + + dim_t pre_op_off = post_ops_attr.pre_op_off; + + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int16_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + __m512bh b1; + __m512bh b2; + __m512bh b3; + + __m256i b0_s4; + __m256i b1_s4; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + __m512bh a_bf16_1; + + __m512i shift_idx_64; + MULTISHIFT_32BIT_8_INT4_IDX_64ELEM(shift_idx_64); + __m512i sign_comp = _mm512_set1_epi8(0x08); + bool signed_upscale = true; + + /* regs to store intermediate int8 values */ + __m512i b0_s8, b1_s8; + + /* Regs to store F32 scale values */ + __m512 scale0, scale1, scale2, scale3, scale4, scale5, scale6, scale7; + /* Reg to store masks to interleave scale factor */ + __m512i mask_scale1, mask_scale2; + + mask_scale1 = _mm512_set_epi32( 0x17, 0x07, 0x16, 0x06, 0x15, 0x05, 0x14, + 0x04, 0x13, 0x03, 0x12, 0x02, 0x11, 0x01, + 0x10, 0x00 ); + + mask_scale2 = _mm512_set_epi32( 0x1F, 0x0F, 0x1E, 0x0E, 0x1D, 0x0D, 0x1C, + 0x0C, 0x1B, 0x0B, 0x1A, 0x0A, 0x19, 0x09, + 0x18, 0x08); + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + __m512 c_float_0p1 = _mm512_setzero_ps(); + __m512 c_float_0p2 = _mm512_setzero_ps(); + __m512 c_float_0p3 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + __m512 c_float_1p1 = _mm512_setzero_ps(); + __m512 c_float_1p2 = _mm512_setzero_ps(); + __m512 c_float_1p3 = _mm512_setzero_ps(); + + __m512 c_float_2p0 = _mm512_setzero_ps(); + __m512 c_float_2p1 = _mm512_setzero_ps(); + __m512 c_float_2p2 = _mm512_setzero_ps(); + __m512 c_float_2p3 = _mm512_setzero_ps(); + + __m512 c_float_3p0 = _mm512_setzero_ps(); + __m512 c_float_3p1 = _mm512_setzero_ps(); + __m512 c_float_3p2 = _mm512_setzero_ps(); + __m512 c_float_3p3 = _mm512_setzero_ps(); + + if( post_ops_attr.pre_op_scale_factor_len > 1 ) + { + // load and interleave scale factor vectors + scale0 = _mm512_loadu_ps( (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off); + scale2 = _mm512_loadu_ps( (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off + 16 ); + scale4 = _mm512_loadu_ps( (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off + 32 ); + scale6 = _mm512_loadu_ps( (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off + 48 ); + + scale1 = _mm512_permutex2var_ps( scale0, mask_scale2, scale0 ); + scale0 = _mm512_permutex2var_ps( scale0, mask_scale1, scale0 ); + scale3 = _mm512_permutex2var_ps( scale2, mask_scale2, scale2 ); + scale2 = _mm512_permutex2var_ps( scale2, mask_scale1, scale2 ); + scale5 = _mm512_permutex2var_ps( scale4, mask_scale2, scale4 ); + scale4 = _mm512_permutex2var_ps( scale4, mask_scale1, scale4 ); + scale7 = _mm512_permutex2var_ps( scale6, mask_scale2, scale6 ); + scale6 = _mm512_permutex2var_ps( scale6, mask_scale1, scale6 ); + + } + else + { + scale0 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale1 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale2 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale3 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale4 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale5 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale6 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale7 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + } + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + b0_s4 = _mm256_loadu_si256( (__m256i const *)( b + ( rs_b * kr ) / 2 ) ); + + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_64, \ + sign_comp, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_16( b0_s8, 0, scale0 ) ); + + b1 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 3, scale3 ), + CVT_INT8_F32_SCAL_16( b0_s8, 2, scale2 ) ); + + b1_s4 = _mm256_loadu_si256( (__m256i const *)( b + ( ( rs_b * kr ) / 2 ) + 32 ) ); + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( b1_s4, b1_s8, shift_idx_64, \ + sign_comp, signed_upscale); + + b2 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b1_s8, 1, scale5 ), + CVT_INT8_F32_SCAL_16( b1_s8, 0, scale4 ) ); + + b3 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b1_s8, 3, scale7 ), + CVT_INT8_F32_SCAL_16( b1_s8, 2, scale6 ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-63] = a[0,kr:kr+4]*b[kr:kr+4,0-63] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_1 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + c_float_0p3 = _mm512_dpbf16_ps( c_float_0p3, a_bf16_0, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-63] = a[1,kr:kr+2]*b[kr:kr+2,0-63] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_1, b0 ); + + // Broadcast a[2,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_1, b1 ); + c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_1, b2 ); + c_float_1p3 = _mm512_dpbf16_ps( c_float_1p3, a_bf16_1, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-63] = a[2,kr:kr+2]*b[kr:kr+2,0-63] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + + // Broadcast a[3,kr:kr+2]. + a_bf16_1 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); + c_float_2p3 = _mm512_dpbf16_ps( c_float_2p3, a_bf16_0, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-63] = a[3,kr:kr+2]*b[kr:kr+2,0-63] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_1, b0 ); + c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_1, b1 ); + c_float_3p2 = _mm512_dpbf16_ps( c_float_3p2, a_bf16_1, b2 ); + c_float_3p3 = _mm512_dpbf16_ps( c_float_3p3, a_bf16_1, b3 ); + } + + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + // Broadcast a[0,kr:kr+2]. + a_kfringe_buf = *( a + (rs_a * 0) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + b0_s4 = _mm256_loadu_si256( (__m256i const *)( b + ( rs_b * k_full_pieces ) / 2 ) ); + + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_64, \ + sign_comp, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_16( b0_s8, 0, scale0 ) ); + + b1 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 3, scale3 ), + CVT_INT8_F32_SCAL_16( b0_s8, 2, scale2 ) ); + + b1_s4 = _mm256_loadu_si256( (__m256i const *)( b + ( ( rs_b * k_full_pieces ) / 2 ) + 32 ) ); + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( b1_s4, b1_s8, shift_idx_64, \ + sign_comp, signed_upscale); + + b2 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b1_s8, 1, scale5 ), + CVT_INT8_F32_SCAL_16( b1_s8, 0, scale4 ) ); + + b3 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b1_s8, 3, scale7 ), + CVT_INT8_F32_SCAL_16( b1_s8, 2, scale6 ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-63] = a[0,kr:kr+2]*b[kr:kr+2,0-63] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 1) + (cs_a * ( k_full_pieces ))); + a_bf16_1 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + c_float_0p3 = _mm512_dpbf16_ps( c_float_0p3, a_bf16_0, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-63] = a[1,kr:kr+2]*b[kr:kr+2,0-63] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_1, b0 ); + + // Broadcast a[2,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 2) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_1, b1 ); + c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_1, b2 ); + c_float_1p3 = _mm512_dpbf16_ps( c_float_1p3, a_bf16_1, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-63] = a[2,kr:kr+2]*b[kr:kr+2,0-63] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + + // Broadcast a[3,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 3) + (cs_a * ( k_full_pieces ))); + a_bf16_1 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); + c_float_2p3 = _mm512_dpbf16_ps( c_float_2p3, a_bf16_0, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-63] = a[3,kr:kr+2]*b[kr:kr+2,0-63] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_1, b0 ); + c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_1, b1 ); + c_float_3p2 = _mm512_dpbf16_ps( c_float_3p2, a_bf16_1, b2 ); + c_float_3p3 = _mm512_dpbf16_ps( c_float_3p3, a_bf16_1, b3 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + if ( alpha != 1 ) + { + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + c_float_0p1 = _mm512_mul_ps( selector1, c_float_0p1 ); + c_float_0p2 = _mm512_mul_ps( selector1, c_float_0p2 ); + c_float_0p3 = _mm512_mul_ps( selector1, c_float_0p3 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + c_float_1p1 = _mm512_mul_ps( selector1, c_float_1p1 ); + c_float_1p2 = _mm512_mul_ps( selector1, c_float_1p2 ); + c_float_1p3 = _mm512_mul_ps( selector1, c_float_1p3 ); + + c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); + c_float_2p1 = _mm512_mul_ps( selector1, c_float_2p1 ); + c_float_2p2 = _mm512_mul_ps( selector1, c_float_2p2 ); + c_float_2p3 = _mm512_mul_ps( selector1, c_float_2p3 ); + + c_float_3p0 = _mm512_mul_ps( selector1, c_float_3p0 ); + c_float_3p1 = _mm512_mul_ps( selector1, c_float_3p1 ); + c_float_3p2 = _mm512_mul_ps( selector1, c_float_3p2 ); + c_float_3p3 = _mm512_mul_ps( selector1, c_float_3p3 ); + } + + // Scale C by beta. + if ( beta != 0 ) + { + // For the downscaled api (C-bf16), the output C matrix values + // needs to be upscaled to float to be used for beta scale. + if ( ( post_ops_attr.buf_downscale != NULL ) && + ( post_ops_attr.is_first_k == TRUE ) ) + { + // c[0,0-15] + BF16_F32_BETA_OP(c_float_0p0,0,0,0,selector1,selector2) + + // c[0, 16-31] + BF16_F32_BETA_OP(c_float_0p1,0,0,1,selector1,selector2) + + // c[0,32-47] + BF16_F32_BETA_OP(c_float_0p2,0,0,2,selector1,selector2) + + // c[0,48-63] + BF16_F32_BETA_OP(c_float_0p3,0,0,3,selector1,selector2) + + // c[1,0-15] + BF16_F32_BETA_OP(c_float_1p0,0,1,0,selector1,selector2) + + // c[1,16-31] + BF16_F32_BETA_OP(c_float_1p1,0,1,1,selector1,selector2) + + // c[1,32-47] + BF16_F32_BETA_OP(c_float_1p2,0,1,2,selector1,selector2) + + // c[1,48-63] + BF16_F32_BETA_OP(c_float_1p3,0,1,3,selector1,selector2) + + // c[2,0-15] + BF16_F32_BETA_OP(c_float_2p0,0,2,0,selector1,selector2) + + // c[2,16-31] + BF16_F32_BETA_OP(c_float_2p1,0,2,1,selector1,selector2) + + // c[2,32-47] + BF16_F32_BETA_OP(c_float_2p2,0,2,2,selector1,selector2) + + // c[2,48-63] + BF16_F32_BETA_OP(c_float_2p3,0,2,3,selector1,selector2) + + // c[3,0-15] + BF16_F32_BETA_OP(c_float_3p0,0,3,0,selector1,selector2) + + // c[3,16-31] + BF16_F32_BETA_OP(c_float_3p1,0,3,1,selector1,selector2) + + // c[3,32-47] + BF16_F32_BETA_OP(c_float_3p2,0,3,2,selector1,selector2) + + // c[0,48-63] + BF16_F32_BETA_OP(c_float_3p3,0,3,3,selector1,selector2) + + } + else + { + // c[0,0-15] + F32_F32_BETA_OP(c_float_0p0,0,0,0,selector1,selector2) + + // c[0, 16-31] + F32_F32_BETA_OP(c_float_0p1,0,0,1,selector1,selector2) + + // c[0,32-47] + F32_F32_BETA_OP(c_float_0p2,0,0,2,selector1,selector2) + + // c[0,48-63] + F32_F32_BETA_OP(c_float_0p3,0,0,3,selector1,selector2) + + // c[1,0-15] + F32_F32_BETA_OP(c_float_1p0,0,1,0,selector1,selector2) + + // c[1,16-31] + F32_F32_BETA_OP(c_float_1p1,0,1,1,selector1,selector2) + + // c[1,32-47] + F32_F32_BETA_OP(c_float_1p2,0,1,2,selector1,selector2) + + // c[1,48-63] + F32_F32_BETA_OP(c_float_1p3,0,1,3,selector1,selector2) + + // c[2,0-15] + F32_F32_BETA_OP(c_float_2p0,0,2,0,selector1,selector2) + + // c[2,16-31] + F32_F32_BETA_OP(c_float_2p1,0,2,1,selector1,selector2) + + // c[2,32-47] + F32_F32_BETA_OP(c_float_2p2,0,2,2,selector1,selector2) + + // c[2,48-63] + F32_F32_BETA_OP(c_float_2p3,0,2,3,selector1,selector2) + + // c[3,0-15] + F32_F32_BETA_OP(c_float_3p0,0,3,0,selector1,selector2) + + // c[3,16-31] + F32_F32_BETA_OP(c_float_3p1,0,3,1,selector1,selector2) + + // c[3,32-47] + F32_F32_BETA_OP(c_float_3p2,0,3,2,selector1,selector2) + + // c[0,48-63] + F32_F32_BETA_OP(c_float_3p3,0,3,3,selector1,selector2) + } + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_4x64: + { + __m512 selector3; + __m512 selector4; + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_LOAD(selector1, bias_mask, 0); + BF16_F32_BIAS_LOAD(selector2, bias_mask, 1); + BF16_F32_BIAS_LOAD(selector3, bias_mask, 2); + BF16_F32_BIAS_LOAD(selector4, bias_mask, 3); + } + else + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + selector4 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_add_ps( selector4, c_float_0p3 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector3, c_float_1p2 ); + + // c[1,48-63] + c_float_1p3 = _mm512_add_ps( selector4, c_float_1p3 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + + // c[2,48-63] + c_float_2p3 = _mm512_add_ps( selector4, c_float_2p3 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector2, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_add_ps( selector3, c_float_3p2 ); + + // c[3,48-63] + c_float_3p3 = _mm512_add_ps( selector4, c_float_3p3 ); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the bias array will be accessed by + // the ic index, and each bias element corresponds to an + // entire row of the transposed output array, instead of an + // entire column. + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + BF16_F32_BIAS_BCAST(selector2, bias_mask, 1); + BF16_F32_BIAS_BCAST(selector3, bias_mask, 2); + BF16_F32_BIAS_BCAST(selector4, bias_mask, 3); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 3 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_add_ps( selector1, c_float_0p3 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector2, c_float_1p2 ); + + // c[1,48-63] + c_float_1p3 = _mm512_add_ps( selector2, c_float_1p3 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector3, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + + // c[2,48-63] + c_float_2p3 = _mm512_add_ps( selector3, c_float_2p3 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector4, c_float_3p0 ); + + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector4, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_add_ps( selector4, c_float_3p2 ); + + // c[3,48-63] + c_float_3p3 = _mm512_add_ps( selector4, c_float_3p3 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_4x64: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_max_ps( selector1, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_max_ps( selector1, c_float_0p3 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + c_float_1p1 = _mm512_max_ps( selector1, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_max_ps( selector1, c_float_1p2 ); + + // c[1,48-63] + c_float_1p3 = _mm512_max_ps( selector1, c_float_1p3 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + c_float_2p1 = _mm512_max_ps( selector1, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_max_ps( selector1, c_float_2p2 ); + + // c[2,48-63] + c_float_2p3 = _mm512_max_ps( selector1, c_float_2p3 ); + + // c[3,0-15] + c_float_3p0 = _mm512_max_ps( selector1, c_float_3p0 ); + + // c[3,16-31] + c_float_3p1 = _mm512_max_ps( selector1, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_max_ps( selector1, c_float_3p2 ); + + // c[3,48-63] + c_float_3p3 = _mm512_max_ps( selector1, c_float_3p3 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_4x64: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_0p2) + + // c[0, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_0p3) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_1p1) + + // c[1, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_1p2) + + // c[1, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_1p3) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_2p1) + + // c[2, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_2p2) + + // c[2, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_2p3) + + // c[3, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_3p0) + + // c[3, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_3p1) + + // c[3, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_3p2) + + // c[3, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_3p3) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_TANH_4x64: + { + __m512 dn, z, x, r2, r, x_tanh; + __m512i q; + + // c[0, 0-15] + GELU_TANH_F32_AVX512(c_float_0p0, r, r2, x, z, dn, x_tanh, q) + + // c[0, 16-31] + GELU_TANH_F32_AVX512(c_float_0p1, r, r2, x, z, dn, x_tanh, q) + + // c[0, 32-47] + GELU_TANH_F32_AVX512(c_float_0p2, r, r2, x, z, dn, x_tanh, q) + + // c[0, 48-63] + GELU_TANH_F32_AVX512(c_float_0p3, r, r2, x, z, dn, x_tanh, q) + + // c[1, 0-15] + GELU_TANH_F32_AVX512(c_float_1p0, r, r2, x, z, dn, x_tanh, q) + + // c[1, 16-31] + GELU_TANH_F32_AVX512(c_float_1p1, r, r2, x, z, dn, x_tanh, q) + + // c[1, 32-47] + GELU_TANH_F32_AVX512(c_float_1p2, r, r2, x, z, dn, x_tanh, q) + + // c[1, 48-63] + GELU_TANH_F32_AVX512(c_float_1p3, r, r2, x, z, dn, x_tanh, q) + + // c[2, 0-15] + GELU_TANH_F32_AVX512(c_float_2p0, r, r2, x, z, dn, x_tanh, q) + + // c[2, 16-31] + GELU_TANH_F32_AVX512(c_float_2p1, r, r2, x, z, dn, x_tanh, q) + + // c[2, 32-47] + GELU_TANH_F32_AVX512(c_float_2p2, r, r2, x, z, dn, x_tanh, q) + + // c[2, 48-63] + GELU_TANH_F32_AVX512(c_float_2p3, r, r2, x, z, dn, x_tanh, q) + + // c[3, 0-15] + GELU_TANH_F32_AVX512(c_float_3p0, r, r2, x, z, dn, x_tanh, q) + + // c[3, 16-31] + GELU_TANH_F32_AVX512(c_float_3p1, r, r2, x, z, dn, x_tanh, q) + + // c[3, 32-47] + GELU_TANH_F32_AVX512(c_float_3p2, r, r2, x, z, dn, x_tanh, q) + + // c[3, 48-63] + GELU_TANH_F32_AVX512(c_float_3p3, r, r2, x, z, dn, x_tanh, q) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_ERF_4x64: + { + __m512 x, r, x_erf; + + // c[0, 0-15] + GELU_ERF_F32_AVX512(c_float_0p0, r, x, x_erf) + + // c[0, 16-31] + GELU_ERF_F32_AVX512(c_float_0p1, r, x, x_erf) + + // c[0, 32-47] + GELU_ERF_F32_AVX512(c_float_0p2, r, x, x_erf) + + // c[0, 48-63] + GELU_ERF_F32_AVX512(c_float_0p3, r, x, x_erf) + + // c[1, 0-15] + GELU_ERF_F32_AVX512(c_float_1p0, r, x, x_erf) + + // c[1, 16-31] + GELU_ERF_F32_AVX512(c_float_1p1, r, x, x_erf) + + // c[1, 32-47] + GELU_ERF_F32_AVX512(c_float_1p2, r, x, x_erf) + + // c[1, 48-63] + GELU_ERF_F32_AVX512(c_float_1p3, r, x, x_erf) + + // c[2, 0-15] + GELU_ERF_F32_AVX512(c_float_2p0, r, x, x_erf) + + // c[2, 16-31] + GELU_ERF_F32_AVX512(c_float_2p1, r, x, x_erf) + + // c[2, 32-47] + GELU_ERF_F32_AVX512(c_float_2p2, r, x, x_erf) + + // c[2, 48-63] + GELU_ERF_F32_AVX512(c_float_2p3, r, x, x_erf) + + // c[3, 0-15] + GELU_ERF_F32_AVX512(c_float_3p0, r, x, x_erf) + + // c[3, 16-31] + GELU_ERF_F32_AVX512(c_float_3p1, r, x, x_erf) + + // c[3, 32-47] + GELU_ERF_F32_AVX512(c_float_3p2, r, x, x_erf) + + // c[3, 48-63] + GELU_ERF_F32_AVX512(c_float_3p3, r, x, x_erf) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_CLIP_4x64: + { + __m512 min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 ); + __m512 max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 ); + + // c[0, 0-15] + CLIP_F32_AVX512(c_float_0p0, min, max) + + // c[0, 16-31] + CLIP_F32_AVX512(c_float_0p1, min, max) + + // c[0, 32-47] + CLIP_F32_AVX512(c_float_0p2, min, max) + + // c[0, 48-63] + CLIP_F32_AVX512(c_float_0p3, min, max) + + // c[1, 0-15] + CLIP_F32_AVX512(c_float_1p0, min, max) + + // c[1, 16-31] + CLIP_F32_AVX512(c_float_1p1, min, max) + + // c[1, 32-47] + CLIP_F32_AVX512(c_float_1p2, min, max) + + // c[1, 48-63] + CLIP_F32_AVX512(c_float_1p3, min, max) + + // c[2, 0-15] + CLIP_F32_AVX512(c_float_2p0, min, max) + + // c[2, 16-31] + CLIP_F32_AVX512(c_float_2p1, min, max) + + // c[2, 32-47] + CLIP_F32_AVX512(c_float_2p2, min, max) + + // c[2, 48-63] + CLIP_F32_AVX512(c_float_2p3, min, max) + + // c[3, 0-15] + CLIP_F32_AVX512(c_float_3p0, min, max) + + // c[3, 16-31] + CLIP_F32_AVX512(c_float_3p1, min, max) + + // c[3, 32-47] + CLIP_F32_AVX512(c_float_3p2, min, max) + + // c[3, 48-63] + CLIP_F32_AVX512(c_float_3p3, min, max) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_4x64: + { + __m512 selector3 = _mm512_setzero_ps(); + __m512 selector4 = _mm512_setzero_ps(); + + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + __m512 zero_point2 = _mm512_setzero_ps(); + __m512 zero_point3 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + selector4 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector2,zero_point1); + + // c[0, 32-47] + SCL_MULRND_F32(c_float_0p2,selector3,zero_point2); + + // c[0, 48-63] + SCL_MULRND_F32(c_float_0p3,selector4,zero_point3); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector1,zero_point0); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[1, 32-47] + SCL_MULRND_F32(c_float_1p2,selector3,zero_point2); + + // c[1, 48-63] + SCL_MULRND_F32(c_float_1p3,selector4,zero_point3); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector1,zero_point0); + + // c[2, 16-31] + SCL_MULRND_F32(c_float_2p1,selector2,zero_point1); + + // c[2, 32-47] + SCL_MULRND_F32(c_float_2p2,selector3,zero_point2); + + // c[2, 48-63] + SCL_MULRND_F32(c_float_2p3,selector4,zero_point3); + + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector1,zero_point0); + + // c[3, 16-31] + SCL_MULRND_F32(c_float_3p1,selector2,zero_point1); + + // c[3, 32-47] + SCL_MULRND_F32(c_float_3p2,selector3,zero_point2); + + // c[3, 48-63] + SCL_MULRND_F32(c_float_3p3,selector4,zero_point3); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 3 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 2 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 3 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector1,zero_point0); + + // c[0, 32-47] + SCL_MULRND_F32(c_float_0p2,selector1,zero_point0); + + // c[0, 48-63] + SCL_MULRND_F32(c_float_0p3,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector2,zero_point1); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[1, 32-47] + SCL_MULRND_F32(c_float_1p2,selector2,zero_point1); + + // c[1, 48-63] + SCL_MULRND_F32(c_float_1p3,selector2,zero_point1); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector3,zero_point2); + + // c[2, 16-31] + SCL_MULRND_F32(c_float_2p1,selector3,zero_point2); + + // c[2, 32-47] + SCL_MULRND_F32(c_float_2p2,selector3,zero_point2); + + // c[2, 48-63] + SCL_MULRND_F32(c_float_2p3,selector3,zero_point2); + + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector4,zero_point3); + + // c[3, 16-31] + SCL_MULRND_F32(c_float_3p1,selector4,zero_point3); + + // c[3, 32-47] + SCL_MULRND_F32(c_float_3p2,selector4,zero_point3); + + // c[3, 48-63] + SCL_MULRND_F32(c_float_3p3,selector4,zero_point3); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_4x64: + { + __m512 selector3; + __m512 selector4; + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,0); + + // c[1:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,1); + + // c[2:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,2); + + // c[3:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,3); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,0); + + // c[1:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,1); + + // c[2:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,2); + + // c[3:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,3); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_4x64: + { + __m512 selector3; + __m512 selector4; + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,0); + + // c[1:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,1); + + // c[2:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,2); + + // c[3:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,3); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,0); + + // c[1:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,1); + + // c[2:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,2); + + // c[3:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,3); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_4x64: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[0, 16-31] + SWISH_F32_AVX512_DEF(c_float_0p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[0, 32-47] + SWISH_F32_AVX512_DEF(c_float_0p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[0, 48-63] + SWISH_F32_AVX512_DEF(c_float_0p3, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 0-15] + SWISH_F32_AVX512_DEF(c_float_1p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 16-31] + SWISH_F32_AVX512_DEF(c_float_1p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 32-47] + SWISH_F32_AVX512_DEF(c_float_1p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 48-63] + SWISH_F32_AVX512_DEF(c_float_1p3, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 0-15] + SWISH_F32_AVX512_DEF(c_float_2p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 16-31] + SWISH_F32_AVX512_DEF(c_float_2p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 32-47] + SWISH_F32_AVX512_DEF(c_float_2p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 48-63] + SWISH_F32_AVX512_DEF(c_float_2p3, selector1, al_in, r, r2, z, dn, ex_out); + + // c[3, 0-15] + SWISH_F32_AVX512_DEF(c_float_3p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[3, 16-31] + SWISH_F32_AVX512_DEF(c_float_3p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[3, 32-47] + SWISH_F32_AVX512_DEF(c_float_3p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[3, 48-63] + SWISH_F32_AVX512_DEF(c_float_3p3, selector1, al_in, r, r2, z, dn, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + +POST_OPS_4x64_DISABLE: + ; + + // Case where the output C matrix is bf16 (downscaled) and this is the + // final write for a given block within C. + if ( ( post_ops_attr.buf_downscale != NULL ) && + ( post_ops_attr.is_last_k == TRUE ) ) + { + // Generate a mask16 of all 1's. + __m512i selector_a = _mm512_setzero_epi32(); + __m512i selector_b = _mm512_set1_epi32( 10 ); + __mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b ); + + // Store the results in downscaled type (bf16 instead of float). + + // c[0, 0-15] + CVT_STORE_F32_BF16_MASK(c_float_0p0,0,0); + + // c[0, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_0p1,0,1); + + // c[0, 32-47] + CVT_STORE_F32_BF16_MASK(c_float_0p2,0,2); + + // c[0, 48-63] + CVT_STORE_F32_BF16_MASK(c_float_0p3,0,3); + + // c[1, 0-15] + CVT_STORE_F32_BF16_MASK(c_float_1p0,1,0); + + // c[1, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_1p1,1,1); + + // c[1, 32-47] + CVT_STORE_F32_BF16_MASK(c_float_1p2,1,2); + + // c[1, 48-63] + CVT_STORE_F32_BF16_MASK(c_float_1p3,1,3); + + // c[2, 0-15] + CVT_STORE_F32_BF16_MASK(c_float_2p0,2,0); + + // c[2, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_2p1,2,1); + + // c[2, 32-47] + CVT_STORE_F32_BF16_MASK(c_float_2p2,2,2); + + // c[2, 48-63] + CVT_STORE_F32_BF16_MASK(c_float_2p3,2,3); + + // c[3, 0-15] + CVT_STORE_F32_BF16_MASK(c_float_3p0,3,0); + + // c[3, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_3p1,3,1); + + // c[3, 32-47] + CVT_STORE_F32_BF16_MASK(c_float_3p2,3,2); + + // c[3, 48-63] + CVT_STORE_F32_BF16_MASK(c_float_3p3,3,3); + } + + // Case where the output C matrix is float + else + { + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); + + // c[0, 16-31] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 1*16 ), c_float_0p1 ); + + // c[0,32-47] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 2*16 ), c_float_0p2 ); + + // c[0,48-63] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 3*16 ), c_float_0p3 ); + + // c[1,0-15] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 0*16 ), c_float_1p0 ); + + // c[1,16-31] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 1*16 ), c_float_1p1 ); + + // c[1,32-47] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 2*16 ), c_float_1p2 ); + + // c[1,48-63] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 3*16 ), c_float_1p3 ); + + // c[2,0-15] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 0*16 ), c_float_2p0 ); + + // c[2,16-31] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 1*16 ), c_float_2p1 ); + + // c[2,32-47] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 2*16 ), c_float_2p2 ); + + // c[2,48-63] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 3*16 ), c_float_2p3 ); + + // c[3,0-15] + _mm512_storeu_ps( c + ( rs_c * 3 ) + ( 0*16 ), c_float_3p0 ); + + // c[3,16-31] + _mm512_storeu_ps( c + ( rs_c * 3 ) + ( 1*16 ), c_float_3p1 ); + + // c[3,32-47] + _mm512_storeu_ps( c + ( rs_c * 3 ) + ( 2*16 ), c_float_3p2 ); + + // c[3,48-63] + _mm512_storeu_ps( c + ( rs_c * 3 ) + ( 3*16 ), c_float_3p3 ); + } +} + +// 3x64 bf16 kernel +LPGEMM_M_FRINGE_KERN(bfloat16, int8_t, float, bf16s4f32of32_3x64) +{ + + static void* post_ops_labels[] = + { + &&POST_OPS_3x64_DISABLE, + &&POST_OPS_BIAS_3x64, + &&POST_OPS_RELU_3x64, + &&POST_OPS_RELU_SCALE_3x64, + &&POST_OPS_GELU_TANH_3x64, + &&POST_OPS_GELU_ERF_3x64, + &&POST_OPS_CLIP_3x64, + &&POST_OPS_DOWNSCALE_3x64, + &&POST_OPS_MATRIX_ADD_3x64, + &&POST_OPS_SWISH_3x64, + &&POST_OPS_MATRIX_MUL_3x64 + }; + + dim_t pre_op_off = post_ops_attr.pre_op_off; + + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int16_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + __m512bh b1; + __m512bh b2; + __m512bh b3; + + __m256i b0_s4; + __m256i b1_s4; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + __m512bh a_bf16_1; + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + __m512 c_float_0p1 = _mm512_setzero_ps(); + __m512 c_float_0p2 = _mm512_setzero_ps(); + __m512 c_float_0p3 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + __m512 c_float_1p1 = _mm512_setzero_ps(); + __m512 c_float_1p2 = _mm512_setzero_ps(); + __m512 c_float_1p3 = _mm512_setzero_ps(); + + __m512 c_float_2p0 = _mm512_setzero_ps(); + __m512 c_float_2p1 = _mm512_setzero_ps(); + __m512 c_float_2p2 = _mm512_setzero_ps(); + __m512 c_float_2p3 = _mm512_setzero_ps(); + + __m512i shift_idx_64; + MULTISHIFT_32BIT_8_INT4_IDX_64ELEM(shift_idx_64); + __m512i sign_comp = _mm512_set1_epi8(0x08); + bool signed_upscale = true; + + /* regs to store intermediate int8 values */ + __m512i b0_s8, b1_s8; + + /* Regs to store F32 scale values */ + __m512 scale0, scale1, scale2, scale3, scale4, scale5, scale6, scale7; + /* Reg to store masks to interleave scale factor */ + __m512i mask_scale1, mask_scale2; + + mask_scale1 = _mm512_set_epi32( 0x17, 0x07, 0x16, 0x06, 0x15, 0x05, 0x14, + 0x04, 0x13, 0x03, 0x12, 0x02, 0x11, 0x01, + 0x10, 0x00 ); + + mask_scale2 = _mm512_set_epi32( 0x1F, 0x0F, 0x1E, 0x0E, 0x1D, 0x0D, 0x1C, + 0x0C, 0x1B, 0x0B, 0x1A, 0x0A, 0x19, 0x09, + 0x18, 0x08); + + if( post_ops_attr.pre_op_scale_factor_len > 1 ) + { + // load and interleave scale factor vectors + scale0 = _mm512_loadu_ps( (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off); + scale2 = _mm512_loadu_ps( (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off + 16 ); + scale4 = _mm512_loadu_ps( (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off + 32 ); + scale6 = _mm512_loadu_ps( (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off + 48 ); + + scale1 = _mm512_permutex2var_ps( scale0, mask_scale2, scale0 ); + scale0 = _mm512_permutex2var_ps( scale0, mask_scale1, scale0 ); + scale3 = _mm512_permutex2var_ps( scale2, mask_scale2, scale2 ); + scale2 = _mm512_permutex2var_ps( scale2, mask_scale1, scale2 ); + scale5 = _mm512_permutex2var_ps( scale4, mask_scale2, scale4 ); + scale4 = _mm512_permutex2var_ps( scale4, mask_scale1, scale4 ); + scale7 = _mm512_permutex2var_ps( scale6, mask_scale2, scale6 ); + scale6 = _mm512_permutex2var_ps( scale6, mask_scale1, scale6 ); + + } + else + { + scale0 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale1 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale2 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale3 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale4 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale5 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale6 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale7 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + } + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + b0_s4 = _mm256_loadu_si256( (__m256i const *)( b + ( rs_b * kr ) / 2 ) ); + + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_64, \ + sign_comp, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_16( b0_s8, 0, scale0 ) ); + + b1 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 3, scale3 ), + CVT_INT8_F32_SCAL_16( b0_s8, 2, scale2 ) ); + + b1_s4 = _mm256_loadu_si256( (__m256i const *)( b + ( ( rs_b * kr ) / 2 ) + 32 ) ); + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( b1_s4, b1_s8, shift_idx_64, \ + sign_comp, signed_upscale); + + b2 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b1_s8, 1, scale5 ), + CVT_INT8_F32_SCAL_16( b1_s8, 0, scale4 ) ); + + b3 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b1_s8, 3, scale7 ), + CVT_INT8_F32_SCAL_16( b1_s8, 2, scale6 ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-63] = a[0,kr:kr+2]*b[kr:kr+2,0-63] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_1 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + c_float_0p3 = _mm512_dpbf16_ps( c_float_0p3, a_bf16_0, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-63] = a[1,kr:kr+2]*b[kr:kr+2,0-63] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_1, b0 ); + + // Broadcast a[2,kr:kr+4]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_1, b1 ); + c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_1, b2 ); + c_float_1p3 = _mm512_dpbf16_ps( c_float_1p3, a_bf16_1, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-63] = a[2,kr:kr+2]*b[kr:kr+2,0-63] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); + c_float_2p3 = _mm512_dpbf16_ps( c_float_2p3, a_bf16_0, b3 ); + } + + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + // Broadcast a[0,kr:kr+2]. + a_kfringe_buf = *( a + (rs_a * 0) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + b0_s4 = _mm256_loadu_si256( (__m256i const *)( b + ( rs_b * k_full_pieces ) / 2 ) ); + + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_64, \ + sign_comp, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_16( b0_s8, 0, scale0 ) ); + + b1 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 3, scale3 ), + CVT_INT8_F32_SCAL_16( b0_s8, 2, scale2 ) ); + + b1_s4 = _mm256_loadu_si256( (__m256i const *)( b + ( ( rs_b * k_full_pieces ) / 2 ) + 32 ) ); + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( b1_s4, b1_s8, shift_idx_64, \ + sign_comp, signed_upscale); + + b2 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b1_s8, 1, scale5 ), + CVT_INT8_F32_SCAL_16( b1_s8, 0, scale4 ) ); + + b3 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b1_s8, 3, scale7 ), + CVT_INT8_F32_SCAL_16( b1_s8, 2, scale6 ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-63] = a[0,kr:kr+2]*b[kr:kr+2,0-63] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 1) + (cs_a * ( k_full_pieces ))); + a_bf16_1 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + c_float_0p3 = _mm512_dpbf16_ps( c_float_0p3, a_bf16_0, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-63] = a[1,kr:kr+2]*b[kr:kr+2,0-63] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_1, b0 ); + + // Broadcast a[2,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 2) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_1, b1 ); + c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_1, b2 ); + c_float_1p3 = _mm512_dpbf16_ps( c_float_1p3, a_bf16_1, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-63] = a[2,kr:kr+2]*b[kr:kr+2,0-63] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); + c_float_2p3 = _mm512_dpbf16_ps( c_float_2p3, a_bf16_0, b3 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + if ( alpha != 1 ) + { + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + c_float_0p1 = _mm512_mul_ps( selector1, c_float_0p1 ); + c_float_0p2 = _mm512_mul_ps( selector1, c_float_0p2 ); + c_float_0p3 = _mm512_mul_ps( selector1, c_float_0p3 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + c_float_1p1 = _mm512_mul_ps( selector1, c_float_1p1 ); + c_float_1p2 = _mm512_mul_ps( selector1, c_float_1p2 ); + c_float_1p3 = _mm512_mul_ps( selector1, c_float_1p3 ); + + c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); + c_float_2p1 = _mm512_mul_ps( selector1, c_float_2p1 ); + c_float_2p2 = _mm512_mul_ps( selector1, c_float_2p2 ); + c_float_2p3 = _mm512_mul_ps( selector1, c_float_2p3 ); + } + + // Scale C by beta. + if ( beta != 0 ) + { + // For the downscaled api (C-bf16), the output C matrix values + // needs to be upscaled to float to be used for beta scale. + if ( ( post_ops_attr.buf_downscale != NULL ) && + ( post_ops_attr.is_first_k == TRUE ) ) + { + // c[0,0-15] + BF16_F32_BETA_OP(c_float_0p0,0,0,0,selector1,selector2) + + // c[0, 16-31] + BF16_F32_BETA_OP(c_float_0p1,0,0,1,selector1,selector2) + + // c[0,32-47] + BF16_F32_BETA_OP(c_float_0p2,0,0,2,selector1,selector2) + + // c[0,48-63] + BF16_F32_BETA_OP(c_float_0p3,0,0,3,selector1,selector2) + + // c[1,0-15] + BF16_F32_BETA_OP(c_float_1p0,0,1,0,selector1,selector2) + + // c[1,16-31] + BF16_F32_BETA_OP(c_float_1p1,0,1,1,selector1,selector2) + + // c[1,32-47] + BF16_F32_BETA_OP(c_float_1p2,0,1,2,selector1,selector2) + + // c[1,48-63] + BF16_F32_BETA_OP(c_float_1p3,0,1,3,selector1,selector2) + + // c[2,0-15] + BF16_F32_BETA_OP(c_float_2p0,0,2,0,selector1,selector2) + + // c[2,16-31] + BF16_F32_BETA_OP(c_float_2p1,0,2,1,selector1,selector2) + + // c[2,32-47] + BF16_F32_BETA_OP(c_float_2p2,0,2,2,selector1,selector2) + + // c[2,48-63] + BF16_F32_BETA_OP(c_float_2p3,0,2,3,selector1,selector2) + } + else + { + // c[0,0-15] + F32_F32_BETA_OP(c_float_0p0,0,0,0,selector1,selector2) + + // c[0, 16-31] + F32_F32_BETA_OP(c_float_0p1,0,0,1,selector1,selector2) + + // c[0,32-47] + F32_F32_BETA_OP(c_float_0p2,0,0,2,selector1,selector2) + + // c[0,48-63] + F32_F32_BETA_OP(c_float_0p3,0,0,3,selector1,selector2) + + // c[1,0-15] + F32_F32_BETA_OP(c_float_1p0,0,1,0,selector1,selector2) + + // c[1,16-31] + F32_F32_BETA_OP(c_float_1p1,0,1,1,selector1,selector2) + + // c[1,32-47] + F32_F32_BETA_OP(c_float_1p2,0,1,2,selector1,selector2) + + // c[1,48-63] + F32_F32_BETA_OP(c_float_1p3,0,1,3,selector1,selector2) + + // c[2,0-15] + F32_F32_BETA_OP(c_float_2p0,0,2,0,selector1,selector2) + + // c[2,16-31] + F32_F32_BETA_OP(c_float_2p1,0,2,1,selector1,selector2) + + // c[2,32-47] + F32_F32_BETA_OP(c_float_2p2,0,2,2,selector1,selector2) + + // c[2,48-63] + F32_F32_BETA_OP(c_float_2p3,0,2,3,selector1,selector2) + } + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_3x64: + { + __m512 selector3; + __m512 selector4; + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_LOAD(selector1, bias_mask, 0); + BF16_F32_BIAS_LOAD(selector2, bias_mask, 1); + BF16_F32_BIAS_LOAD(selector3, bias_mask, 2); + BF16_F32_BIAS_LOAD(selector4, bias_mask, 3); + } + else + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + selector4 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_add_ps( selector4, c_float_0p3 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector3, c_float_1p2 ); + + // c[1,48-63] + c_float_1p3 = _mm512_add_ps( selector4, c_float_1p3 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + + // c[2,48-63] + c_float_2p3 = _mm512_add_ps( selector4, c_float_2p3 ); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the bias array will be accessed by + // the ic index, and each bias element corresponds to an + // entire row of the transposed output array, instead of an + // entire column. + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + BF16_F32_BIAS_BCAST(selector2, bias_mask, 1); + BF16_F32_BIAS_BCAST(selector3, bias_mask, 2); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 2 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_add_ps( selector1, c_float_0p3 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector2, c_float_1p2 ); + + // c[1,48-63] + c_float_1p3 = _mm512_add_ps( selector2, c_float_1p3 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector3, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + + // c[2,48-63] + c_float_2p3 = _mm512_add_ps( selector3, c_float_2p3 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_3x64: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_max_ps( selector1, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_max_ps( selector1, c_float_0p3 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + c_float_1p1 = _mm512_max_ps( selector1, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_max_ps( selector1, c_float_1p2 ); + + // c[1,48-63] + c_float_1p3 = _mm512_max_ps( selector1, c_float_1p3 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + c_float_2p1 = _mm512_max_ps( selector1, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_max_ps( selector1, c_float_2p2 ); + + // c[2,48-63] + c_float_2p3 = _mm512_max_ps( selector1, c_float_2p3 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_3x64: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_0p2) + + // c[0, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_0p3) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_1p1) + + // c[1, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_1p2) + + // c[1, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_1p3) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_2p1) + + // c[2, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_2p2) + + // c[2, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_2p3) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_TANH_3x64: + { + __m512 dn, z, x, r2, r, x_tanh; + __m512i q; + + // c[0, 0-15] + GELU_TANH_F32_AVX512(c_float_0p0, r, r2, x, z, dn, x_tanh, q) + + // c[0, 16-31] + GELU_TANH_F32_AVX512(c_float_0p1, r, r2, x, z, dn, x_tanh, q) + + // c[0, 32-47] + GELU_TANH_F32_AVX512(c_float_0p2, r, r2, x, z, dn, x_tanh, q) + + // c[0, 48-63] + GELU_TANH_F32_AVX512(c_float_0p3, r, r2, x, z, dn, x_tanh, q) + + // c[1, 0-15] + GELU_TANH_F32_AVX512(c_float_1p0, r, r2, x, z, dn, x_tanh, q) + + // c[1, 16-31] + GELU_TANH_F32_AVX512(c_float_1p1, r, r2, x, z, dn, x_tanh, q) + + // c[1, 32-47] + GELU_TANH_F32_AVX512(c_float_1p2, r, r2, x, z, dn, x_tanh, q) + + // c[1, 48-63] + GELU_TANH_F32_AVX512(c_float_1p3, r, r2, x, z, dn, x_tanh, q) + + // c[2, 0-15] + GELU_TANH_F32_AVX512(c_float_2p0, r, r2, x, z, dn, x_tanh, q) + + // c[2, 16-31] + GELU_TANH_F32_AVX512(c_float_2p1, r, r2, x, z, dn, x_tanh, q) + + // c[2, 32-47] + GELU_TANH_F32_AVX512(c_float_2p2, r, r2, x, z, dn, x_tanh, q) + + // c[2, 48-63] + GELU_TANH_F32_AVX512(c_float_2p3, r, r2, x, z, dn, x_tanh, q) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_ERF_3x64: + { + __m512 x, r, x_erf; + + // c[0, 0-15] + GELU_ERF_F32_AVX512(c_float_0p0, r, x, x_erf) + + // c[0, 16-31] + GELU_ERF_F32_AVX512(c_float_0p1, r, x, x_erf) + + // c[0, 32-47] + GELU_ERF_F32_AVX512(c_float_0p2, r, x, x_erf) + + // c[0, 48-63] + GELU_ERF_F32_AVX512(c_float_0p3, r, x, x_erf) + + // c[1, 0-15] + GELU_ERF_F32_AVX512(c_float_1p0, r, x, x_erf) + + // c[1, 16-31] + GELU_ERF_F32_AVX512(c_float_1p1, r, x, x_erf) + + // c[1, 32-47] + GELU_ERF_F32_AVX512(c_float_1p2, r, x, x_erf) + + // c[1, 48-63] + GELU_ERF_F32_AVX512(c_float_1p3, r, x, x_erf) + + // c[2, 0-15] + GELU_ERF_F32_AVX512(c_float_2p0, r, x, x_erf) + + // c[2, 16-31] + GELU_ERF_F32_AVX512(c_float_2p1, r, x, x_erf) + + // c[2, 32-47] + GELU_ERF_F32_AVX512(c_float_2p2, r, x, x_erf) + + // c[2, 48-63] + GELU_ERF_F32_AVX512(c_float_2p3, r, x, x_erf) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_CLIP_3x64: + { + __m512 min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 ); + __m512 max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 ); + + // c[0, 0-15] + CLIP_F32_AVX512(c_float_0p0, min, max) + + // c[0, 16-31] + CLIP_F32_AVX512(c_float_0p1, min, max) + + // c[0, 32-47] + CLIP_F32_AVX512(c_float_0p2, min, max) + + // c[0, 48-63] + CLIP_F32_AVX512(c_float_0p3, min, max) + + // c[1, 0-15] + CLIP_F32_AVX512(c_float_1p0, min, max) + + // c[1, 16-31] + CLIP_F32_AVX512(c_float_1p1, min, max) + + // c[1, 32-47] + CLIP_F32_AVX512(c_float_1p2, min, max) + + // c[1, 48-63] + CLIP_F32_AVX512(c_float_1p3, min, max) + + // c[2, 0-15] + CLIP_F32_AVX512(c_float_2p0, min, max) + + // c[2, 16-31] + CLIP_F32_AVX512(c_float_2p1, min, max) + + // c[2, 32-47] + CLIP_F32_AVX512(c_float_2p2, min, max) + + // c[2, 48-63] + CLIP_F32_AVX512(c_float_2p3, min, max) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_3x64: + { + __m512 selector3 = _mm512_setzero_ps(); + __m512 selector4 = _mm512_setzero_ps(); + + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + __m512 zero_point2 = _mm512_setzero_ps(); + __m512 zero_point3 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + selector4 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector2,zero_point1); + + // c[0, 32-47] + SCL_MULRND_F32(c_float_0p2,selector3,zero_point2); + + // c[0, 48-63] + SCL_MULRND_F32(c_float_0p3,selector4,zero_point3); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector1,zero_point0); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[1, 32-47] + SCL_MULRND_F32(c_float_1p2,selector3,zero_point2); + + // c[1, 48-63] + SCL_MULRND_F32(c_float_1p3,selector4,zero_point3); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector1,zero_point0); + + // c[2, 16-31] + SCL_MULRND_F32(c_float_2p1,selector2,zero_point1); + + // c[2, 32-47] + SCL_MULRND_F32(c_float_2p2,selector3,zero_point2); + + // c[2, 48-63] + SCL_MULRND_F32(c_float_2p3,selector4,zero_point3); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 2 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 2 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector1,zero_point0); + + // c[0, 32-47] + SCL_MULRND_F32(c_float_0p2,selector1,zero_point0); + + // c[0, 48-63] + SCL_MULRND_F32(c_float_0p3,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector2,zero_point1); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[1, 32-47] + SCL_MULRND_F32(c_float_1p2,selector2,zero_point1); + + // c[1, 48-63] + SCL_MULRND_F32(c_float_1p3,selector2,zero_point1); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector3,zero_point2); + + // c[2, 16-31] + SCL_MULRND_F32(c_float_2p1,selector3,zero_point2); + + // c[2, 32-47] + SCL_MULRND_F32(c_float_2p2,selector3,zero_point2); + + // c[2, 48-63] + SCL_MULRND_F32(c_float_2p3,selector3,zero_point2); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_3x64: + { + __m512 selector3; + __m512 selector4; + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,0); + + // c[1:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,1); + + // c[2:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,2); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,0); + + // c[1:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,1); + + // c[2:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,2); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_3x64: + { + __m512 selector3; + __m512 selector4; + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,0); + + // c[1:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,1); + + // c[2:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,2); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,0); + + // c[1:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,1); + + // c[2:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,2); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_3x64: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[0, 16-31] + SWISH_F32_AVX512_DEF(c_float_0p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[0, 32-47] + SWISH_F32_AVX512_DEF(c_float_0p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[0, 48-63] + SWISH_F32_AVX512_DEF(c_float_0p3, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 0-15] + SWISH_F32_AVX512_DEF(c_float_1p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 16-31] + SWISH_F32_AVX512_DEF(c_float_1p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 32-47] + SWISH_F32_AVX512_DEF(c_float_1p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 48-63] + SWISH_F32_AVX512_DEF(c_float_1p3, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 0-15] + SWISH_F32_AVX512_DEF(c_float_2p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 16-31] + SWISH_F32_AVX512_DEF(c_float_2p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 32-47] + SWISH_F32_AVX512_DEF(c_float_2p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 48-63] + SWISH_F32_AVX512_DEF(c_float_2p3, selector1, al_in, r, r2, z, dn, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_3x64_DISABLE: + ; + // Case where the output C matrix is bf16 (downscaled) and this is the + // final write for a given block within C. + if ( ( post_ops_attr.buf_downscale != NULL ) && + ( post_ops_attr.is_last_k == TRUE ) ) + { + // Generate a mask16 of all 1's. + __m512i selector_a = _mm512_setzero_epi32(); + __m512i selector_b = _mm512_set1_epi32( 10 ); + __mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b ); + + // Store the results in downscaled type (bf16 instead of float). + + // c[0, 0-15] + CVT_STORE_F32_BF16_MASK(c_float_0p0,0,0); + + // c[0, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_0p1,0,1); + + // c[0, 32-47] + CVT_STORE_F32_BF16_MASK(c_float_0p2,0,2); + + // c[0, 48-63] + CVT_STORE_F32_BF16_MASK(c_float_0p3,0,3); + + // c[1, 0-15] + CVT_STORE_F32_BF16_MASK(c_float_1p0,1,0); + + // c[1, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_1p1,1,1); + + // c[1, 32-47] + CVT_STORE_F32_BF16_MASK(c_float_1p2,1,2); + + // c[1, 48-63] + CVT_STORE_F32_BF16_MASK(c_float_1p3,1,3); + + // c[2, 0-15] + CVT_STORE_F32_BF16_MASK(c_float_2p0,2,0); + + // c[2, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_2p1,2,1); + + // c[2, 32-47] + CVT_STORE_F32_BF16_MASK(c_float_2p2,2,2); + + // c[2, 48-63] + CVT_STORE_F32_BF16_MASK(c_float_2p3,2,3); + } + + // Case where the output C matrix is float + else + { + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); + + // c[0, 16-31] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 1*16 ), c_float_0p1 ); + + // c[0,32-47] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 2*16 ), c_float_0p2 ); + + // c[0,48-63] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 3*16 ), c_float_0p3 ); + + // c[1,0-15] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 0*16 ), c_float_1p0 ); + + // c[1,16-31] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 1*16 ), c_float_1p1 ); + + // c[1,32-47] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 2*16 ), c_float_1p2 ); + + // c[1,48-63] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 3*16 ), c_float_1p3 ); + + // c[2,0-15] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 0*16 ), c_float_2p0 ); + + // c[2,16-31] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 1*16 ), c_float_2p1 ); + + // c[2,32-47] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 2*16 ), c_float_2p2 ); + + // c[2,48-63] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 3*16 ), c_float_2p3 ); + } +} + +// 2x64 bf16 kernel +LPGEMM_M_FRINGE_KERN(bfloat16, int8_t, float, bf16s4f32of32_2x64) +{ + dim_t pre_op_off = post_ops_attr.pre_op_off; + + static void* post_ops_labels[] = + { + &&POST_OPS_2x64_DISABLE, + &&POST_OPS_BIAS_2x64, + &&POST_OPS_RELU_2x64, + &&POST_OPS_RELU_SCALE_2x64, + &&POST_OPS_GELU_TANH_2x64, + &&POST_OPS_GELU_ERF_2x64, + &&POST_OPS_CLIP_2x64, + &&POST_OPS_DOWNSCALE_2x64, + &&POST_OPS_MATRIX_ADD_2x64, + &&POST_OPS_SWISH_2x64, + &&POST_OPS_MATRIX_MUL_2x64 + }; + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int16_t a_kfringe_buf = 0; + // B matrix storage bfloat type + __m512bh b0; + __m512bh b1; + __m512bh b2; + __m512bh b3; + + __m256i b0_s4; + __m256i b1_s4; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + __m512bh a_bf16_1; + + __m512i shift_idx_64; + MULTISHIFT_32BIT_8_INT4_IDX_64ELEM(shift_idx_64); + __m512i sign_comp = _mm512_set1_epi8(0x08); + bool signed_upscale = true; + + /* regs to store intermediate int8 values */ + __m512i b0_s8, b1_s8; + + /* Regs to store F32 scale values */ + __m512 scale0, scale1, scale2, scale3, scale4, scale5, scale6, scale7; + /* Reg to store masks to interleave scale factor */ + __m512i mask_scale1, mask_scale2; + + mask_scale1 = _mm512_set_epi32( 0x17, 0x07, 0x16, 0x06, 0x15, 0x05, 0x14, + 0x04, 0x13, 0x03, 0x12, 0x02, 0x11, 0x01, + 0x10, 0x00 ); + + mask_scale2 = _mm512_set_epi32( 0x1F, 0x0F, 0x1E, 0x0E, 0x1D, 0x0D, 0x1C, + 0x0C, 0x1B, 0x0B, 0x1A, 0x0A, 0x19, 0x09, + 0x18, 0x08); + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + __m512 c_float_0p1 = _mm512_setzero_ps(); + __m512 c_float_0p2 = _mm512_setzero_ps(); + __m512 c_float_0p3 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + __m512 c_float_1p1 = _mm512_setzero_ps(); + __m512 c_float_1p2 = _mm512_setzero_ps(); + __m512 c_float_1p3 = _mm512_setzero_ps(); + + if( post_ops_attr.pre_op_scale_factor_len > 1 ) + { + // load and interleave scale factor vectors + scale0 = _mm512_loadu_ps( (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off); + scale2 = _mm512_loadu_ps( (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off + 16 ); + scale4 = _mm512_loadu_ps( (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off + 32 ); + scale6 = _mm512_loadu_ps( (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off + 48 ); + + scale1 = _mm512_permutex2var_ps( scale0, mask_scale2, scale0 ); + scale0 = _mm512_permutex2var_ps( scale0, mask_scale1, scale0 ); + scale3 = _mm512_permutex2var_ps( scale2, mask_scale2, scale2 ); + scale2 = _mm512_permutex2var_ps( scale2, mask_scale1, scale2 ); + scale5 = _mm512_permutex2var_ps( scale4, mask_scale2, scale4 ); + scale4 = _mm512_permutex2var_ps( scale4, mask_scale1, scale4 ); + scale7 = _mm512_permutex2var_ps( scale6, mask_scale2, scale6 ); + scale6 = _mm512_permutex2var_ps( scale6, mask_scale1, scale6 ); + + } + else + { + scale0 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale1 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale2 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale3 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale4 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale5 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale6 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale7 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + } + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + b0_s4 = _mm256_loadu_si256( (__m256i const *)( b + ( rs_b * kr ) / 2 ) ); + + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_64, \ + sign_comp, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_16( b0_s8, 0, scale0 ) ); + + b1 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 3, scale3 ), + CVT_INT8_F32_SCAL_16( b0_s8, 2, scale2 ) ); + + b1_s4 = _mm256_loadu_si256( (__m256i const *)( b + ( ( rs_b * kr ) / 2 ) + 32 ) ); + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( b1_s4, b1_s8, shift_idx_64, \ + sign_comp, signed_upscale); + + b2 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b1_s8, 1, scale5 ), + CVT_INT8_F32_SCAL_16( b1_s8, 0, scale4 ) ); + + b3 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b1_s8, 3, scale7 ), + CVT_INT8_F32_SCAL_16( b1_s8, 2, scale6 ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-63] = a[0,kr:kr+2]*b[kr:kr+2,0-63] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_1 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + c_float_0p3 = _mm512_dpbf16_ps( c_float_0p3, a_bf16_0, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-63] = a[1,kr:kr+2]*b[kr:kr+2,0-63] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_1, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_1, b1 ); + c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_1, b2 ); + c_float_1p3 = _mm512_dpbf16_ps( c_float_1p3, a_bf16_1, b3 ); + } + + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + // Broadcast a[0,kr:kr+2]. + a_kfringe_buf = *( a + (rs_a * 0) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + b0_s4 = _mm256_loadu_si256( (__m256i const *)( b + ( rs_b * k_full_pieces ) / 2 ) ); + + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_64, \ + sign_comp, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_16( b0_s8, 0, scale0 ) ); + + b1 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 3, scale3 ), + CVT_INT8_F32_SCAL_16( b0_s8, 2, scale2 ) ); + + b1_s4 = _mm256_loadu_si256( (__m256i const *)( b + ( ( rs_b * k_full_pieces ) / 2 ) + 32 ) ); + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( b1_s4, b1_s8, shift_idx_64, \ + sign_comp, signed_upscale); + + b2 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b1_s8, 1, scale5 ), + CVT_INT8_F32_SCAL_16( b1_s8, 0, scale4 ) ); + + b3 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b1_s8, 3, scale7 ), + CVT_INT8_F32_SCAL_16( b1_s8, 2, scale6 ) ); + + + // Perform column direction mat-mul with k = 2. + // c[0,0-63] = a[0,kr:kr+2]*b[kr:kr+2,0-63] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 1) + (cs_a * ( k_full_pieces ))); + a_bf16_1 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + c_float_0p3 = _mm512_dpbf16_ps( c_float_0p3, a_bf16_0, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-63] = a[1,kr:kr+2]*b[kr:kr+2,0-63] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_1, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_1, b1 ); + c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_1, b2 ); + c_float_1p3 = _mm512_dpbf16_ps( c_float_1p3, a_bf16_1, b3 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + if ( alpha != 1 ) + { + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + c_float_0p1 = _mm512_mul_ps( selector1, c_float_0p1 ); + c_float_0p2 = _mm512_mul_ps( selector1, c_float_0p2 ); + c_float_0p3 = _mm512_mul_ps( selector1, c_float_0p3 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + c_float_1p1 = _mm512_mul_ps( selector1, c_float_1p1 ); + c_float_1p2 = _mm512_mul_ps( selector1, c_float_1p2 ); + c_float_1p3 = _mm512_mul_ps( selector1, c_float_1p3 ); + } + + // Scale C by beta. + if ( beta != 0 ) + { + // For the downscaled api (C-bf16), the output C matrix values + // needs to be upscaled to float to be used for beta scale. + if ( ( post_ops_attr.buf_downscale != NULL ) && + ( post_ops_attr.is_first_k == TRUE ) ) + { + // c[0,0-15] + BF16_F32_BETA_OP(c_float_0p0,0,0,0,selector1,selector2) + + // c[0, 16-31] + BF16_F32_BETA_OP(c_float_0p1,0,0,1,selector1,selector2) + + // c[0,32-47] + BF16_F32_BETA_OP(c_float_0p2,0,0,2,selector1,selector2) + + // c[0,48-63] + BF16_F32_BETA_OP(c_float_0p3,0,0,3,selector1,selector2) + + // c[1,0-15] + BF16_F32_BETA_OP(c_float_1p0,0,1,0,selector1,selector2) + + // c[1,16-31] + BF16_F32_BETA_OP(c_float_1p1,0,1,1,selector1,selector2) + + // c[1,32-47] + BF16_F32_BETA_OP(c_float_1p2,0,1,2,selector1,selector2) + + // c[1,48-63] + BF16_F32_BETA_OP(c_float_1p3,0,1,3,selector1,selector2) + } + else + { + // c[0,0-15] + F32_F32_BETA_OP(c_float_0p0,0,0,0,selector1,selector2) + + // c[0, 16-31] + F32_F32_BETA_OP(c_float_0p1,0,0,1,selector1,selector2) + + // c[0,32-47] + F32_F32_BETA_OP(c_float_0p2,0,0,2,selector1,selector2) + + // c[0,48-63] + F32_F32_BETA_OP(c_float_0p3,0,0,3,selector1,selector2) + + // c[1,0-15] + F32_F32_BETA_OP(c_float_1p0,0,1,0,selector1,selector2) + + // c[1,16-31] + F32_F32_BETA_OP(c_float_1p1,0,1,1,selector1,selector2) + + // c[1,32-47] + F32_F32_BETA_OP(c_float_1p2,0,1,2,selector1,selector2) + + // c[1,48-63] + F32_F32_BETA_OP(c_float_1p3,0,1,3,selector1,selector2) + } + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_2x64: + { + __m512 selector3; + __m512 selector4; + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_LOAD(selector1, bias_mask, 0); + BF16_F32_BIAS_LOAD(selector2, bias_mask, 1); + BF16_F32_BIAS_LOAD(selector3, bias_mask, 2); + BF16_F32_BIAS_LOAD(selector4, bias_mask, 3); + } + else + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + selector4 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_add_ps( selector4, c_float_0p3 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector3, c_float_1p2 ); + + // c[1,48-63] + c_float_1p3 = _mm512_add_ps( selector4, c_float_1p3 ); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the bias array will be accessed by + // the ic index, and each bias element corresponds to an + // entire row of the transposed output array, instead of an + // entire column. + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + BF16_F32_BIAS_BCAST(selector2, bias_mask, 1); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 1 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_add_ps( selector1, c_float_0p3 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector2, c_float_1p2 ); + + // c[1,48-63] + c_float_1p3 = _mm512_add_ps( selector2, c_float_1p3 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_2x64: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_max_ps( selector1, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_max_ps( selector1, c_float_0p3 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + c_float_1p1 = _mm512_max_ps( selector1, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_max_ps( selector1, c_float_1p2 ); + + // c[1,48-63] + c_float_1p3 = _mm512_max_ps( selector1, c_float_1p3 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_2x64: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_0p2) + + // c[0, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_0p3) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_1p1) + + // c[1, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_1p2) + + // c[1, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_1p3) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_TANH_2x64: + { + __m512 dn, z, x, r2, r, x_tanh; + __m512i q; + + // c[0, 0-15] + GELU_TANH_F32_AVX512(c_float_0p0, r, r2, x, z, dn, x_tanh, q) + + // c[0, 16-31] + GELU_TANH_F32_AVX512(c_float_0p1, r, r2, x, z, dn, x_tanh, q) + + // c[0, 32-47] + GELU_TANH_F32_AVX512(c_float_0p2, r, r2, x, z, dn, x_tanh, q) + + // c[0, 48-63] + GELU_TANH_F32_AVX512(c_float_0p3, r, r2, x, z, dn, x_tanh, q) + + // c[1, 0-15] + GELU_TANH_F32_AVX512(c_float_1p0, r, r2, x, z, dn, x_tanh, q) + + // c[1, 16-31] + GELU_TANH_F32_AVX512(c_float_1p1, r, r2, x, z, dn, x_tanh, q) + + // c[1, 32-47] + GELU_TANH_F32_AVX512(c_float_1p2, r, r2, x, z, dn, x_tanh, q) + + // c[1, 48-63] + GELU_TANH_F32_AVX512(c_float_1p3, r, r2, x, z, dn, x_tanh, q) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_ERF_2x64: + { + __m512 x, r, x_erf; + + // c[0, 0-15] + GELU_ERF_F32_AVX512(c_float_0p0, r, x, x_erf) + + // c[0, 16-31] + GELU_ERF_F32_AVX512(c_float_0p1, r, x, x_erf) + + // c[0, 32-47] + GELU_ERF_F32_AVX512(c_float_0p2, r, x, x_erf) + + // c[0, 48-63] + GELU_ERF_F32_AVX512(c_float_0p3, r, x, x_erf) + + // c[1, 0-15] + GELU_ERF_F32_AVX512(c_float_1p0, r, x, x_erf) + + // c[1, 16-31] + GELU_ERF_F32_AVX512(c_float_1p1, r, x, x_erf) + + // c[1, 32-47] + GELU_ERF_F32_AVX512(c_float_1p2, r, x, x_erf) + + // c[1, 48-63] + GELU_ERF_F32_AVX512(c_float_1p3, r, x, x_erf) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_CLIP_2x64: + { + __m512 min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 ); + __m512 max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 ); + + // c[0, 0-15] + CLIP_F32_AVX512(c_float_0p0, min, max) + + // c[0, 16-31] + CLIP_F32_AVX512(c_float_0p1, min, max) + + // c[0, 32-47] + CLIP_F32_AVX512(c_float_0p2, min, max) + + // c[0, 48-63] + CLIP_F32_AVX512(c_float_0p3, min, max) + + // c[1, 0-15] + CLIP_F32_AVX512(c_float_1p0, min, max) + + // c[1, 16-31] + CLIP_F32_AVX512(c_float_1p1, min, max) + + // c[1, 32-47] + CLIP_F32_AVX512(c_float_1p2, min, max) + + // c[1, 48-63] + CLIP_F32_AVX512(c_float_1p3, min, max) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_2x64: + { + __m512 selector3 = _mm512_setzero_ps(); + __m512 selector4 = _mm512_setzero_ps(); + + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + __m512 zero_point2 = _mm512_setzero_ps(); + __m512 zero_point3 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + selector4 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector2,zero_point1); + + // c[0, 32-47] + SCL_MULRND_F32(c_float_0p2,selector3,zero_point2); + + // c[0, 48-63] + SCL_MULRND_F32(c_float_0p3,selector4,zero_point3); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector1,zero_point0); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[1, 32-47] + SCL_MULRND_F32(c_float_1p2,selector3,zero_point2); + + // c[1, 48-63] + SCL_MULRND_F32(c_float_1p3,selector4,zero_point3); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 1 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 1 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector1,zero_point0); + + // c[0, 32-47] + SCL_MULRND_F32(c_float_0p2,selector1,zero_point0); + + // c[0, 48-63] + SCL_MULRND_F32(c_float_0p3,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector2,zero_point1); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[1, 32-47] + SCL_MULRND_F32(c_float_1p2,selector2,zero_point1); + + // c[1, 48-63] + SCL_MULRND_F32(c_float_1p3,selector2,zero_point1); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_2x64: + { + __m512 selector3; + __m512 selector4; + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,0); + + // c[1:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,1); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,0); + + // c[1:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,1); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_2x64: + { + __m512 selector3; + __m512 selector4; + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,0); + + // c[1:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,1); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,0); + + // c[1:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,1); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_2x64: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[0, 16-31] + SWISH_F32_AVX512_DEF(c_float_0p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[0, 32-47] + SWISH_F32_AVX512_DEF(c_float_0p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[0, 48-63] + SWISH_F32_AVX512_DEF(c_float_0p3, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 0-15] + SWISH_F32_AVX512_DEF(c_float_1p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 16-31] + SWISH_F32_AVX512_DEF(c_float_1p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 32-47] + SWISH_F32_AVX512_DEF(c_float_1p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 48-63] + SWISH_F32_AVX512_DEF(c_float_1p3, selector1, al_in, r, r2, z, dn, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_2x64_DISABLE: + ; + + // Case where the output C matrix is bf16 (downscaled) and this is the + // final write for a given block within C. + if ( ( post_ops_attr.buf_downscale != NULL ) && + ( post_ops_attr.is_last_k == TRUE ) ) + { + // Generate a mask16 of all 1's. + __m512i selector_a = _mm512_setzero_epi32(); + __m512i selector_b = _mm512_set1_epi32( 10 ); + __mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b ); + + // Store the results in downscaled type (bf16 instead of float). + + // c[0, 0-15] + CVT_STORE_F32_BF16_MASK(c_float_0p0,0,0); + + // c[0, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_0p1,0,1); + + // c[0, 32-47] + CVT_STORE_F32_BF16_MASK(c_float_0p2,0,2); + + // c[0, 48-63] + CVT_STORE_F32_BF16_MASK(c_float_0p3,0,3); + + // c[1, 0-15] + CVT_STORE_F32_BF16_MASK(c_float_1p0,1,0); + + // c[1, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_1p1,1,1); + + // c[1, 32-47] + CVT_STORE_F32_BF16_MASK(c_float_1p2,1,2); + + // c[1, 48-63] + CVT_STORE_F32_BF16_MASK(c_float_1p3,1,3); + } + + // Case where the output C matrix is float + else + { + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); + + // c[0, 16-31] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 1*16 ), c_float_0p1 ); + + // c[0,32-47] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 2*16 ), c_float_0p2 ); + + // c[0,48-63] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 3*16 ), c_float_0p3 ); + + // c[1,0-15] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 0*16 ), c_float_1p0 ); + + // c[1,16-31] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 1*16 ), c_float_1p1 ); + + // c[1,32-47] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 2*16 ), c_float_1p2 ); + + // c[1,48-63] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 3*16 ), c_float_1p3 ); + } +} + +// 1x64 bf16 kernel +LPGEMM_M_FRINGE_KERN(bfloat16, int8_t, float, bf16s4f32of32_1x64) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_1x64_DISABLE, + &&POST_OPS_BIAS_1x64, + &&POST_OPS_RELU_1x64, + &&POST_OPS_RELU_SCALE_1x64, + &&POST_OPS_GELU_TANH_1x64, + &&POST_OPS_GELU_ERF_1x64, + &&POST_OPS_CLIP_1x64, + &&POST_OPS_DOWNSCALE_1x64, + &&POST_OPS_MATRIX_ADD_1x64, + &&POST_OPS_SWISH_1x64, + &&POST_OPS_MATRIX_MUL_1x64 + }; + + + dim_t pre_op_off = post_ops_attr.pre_op_off; + + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int16_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + __m512bh b1; + __m512bh b2; + __m512bh b3; + + __m256i b0_s4; + __m256i b1_s4; + + __m512i shift_idx_64; + MULTISHIFT_32BIT_8_INT4_IDX_64ELEM(shift_idx_64); + __m512i sign_comp = _mm512_set1_epi8(0x08); + bool signed_upscale = true; + + /* regs to store intermediate int8 values */ + __m512i b0_s8, b1_s8; + + /* Regs to store F32 scale values */ + __m512 scale0, scale1, scale2, scale3, scale4, scale5, scale6, scale7; + /* Reg to store masks to interleave scale factor */ + __m512i mask_scale1, mask_scale2; + + mask_scale1 = _mm512_set_epi32( 0x17, 0x07, 0x16, 0x06, 0x15, 0x05, 0x14, + 0x04, 0x13, 0x03, 0x12, 0x02, 0x11, 0x01, + 0x10, 0x00 ); + + mask_scale2 = _mm512_set_epi32( 0x1F, 0x0F, 0x1E, 0x0E, 0x1D, 0x0D, 0x1C, + 0x0C, 0x1B, 0x0B, 0x1A, 0x0A, 0x19, 0x09, + 0x18, 0x08); + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + __m512 c_float_0p1 = _mm512_setzero_ps(); + __m512 c_float_0p2 = _mm512_setzero_ps(); + __m512 c_float_0p3 = _mm512_setzero_ps(); + + if( post_ops_attr.pre_op_scale_factor_len > 1 ) + { + // load and interleave scale factor vectors + scale0 = _mm512_loadu_ps( (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off); + scale2 = _mm512_loadu_ps( (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off + 16 ); + scale4 = _mm512_loadu_ps( (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off + 32 ); + scale6 = _mm512_loadu_ps( (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off + 48 ); + + scale1 = _mm512_permutex2var_ps( scale0, mask_scale2, scale0 ); + scale0 = _mm512_permutex2var_ps( scale0, mask_scale1, scale0 ); + scale3 = _mm512_permutex2var_ps( scale2, mask_scale2, scale2 ); + scale2 = _mm512_permutex2var_ps( scale2, mask_scale1, scale2 ); + scale5 = _mm512_permutex2var_ps( scale4, mask_scale2, scale4 ); + scale4 = _mm512_permutex2var_ps( scale4, mask_scale1, scale4 ); + scale7 = _mm512_permutex2var_ps( scale6, mask_scale2, scale6 ); + scale6 = _mm512_permutex2var_ps( scale6, mask_scale1, scale6 ); + + } + else + { + scale0 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale1 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale2 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale3 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale4 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale5 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale6 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale7 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + } + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + // Broadcast a[0,kr] + __m512bh a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + b0_s4 = _mm256_loadu_si256( (__m256i const *)( b + ( rs_b * kr ) / 2 ) ); + + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_64, \ + sign_comp, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_16( b0_s8, 0, scale0 ) ); + + b1 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 3, scale3 ), + CVT_INT8_F32_SCAL_16( b0_s8, 2, scale2 ) ); + + b1_s4 = _mm256_loadu_si256( (__m256i const *)( b + ( ( rs_b * kr ) / 2 ) + 32 ) ); + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( b1_s4, b1_s8, shift_idx_64, \ + sign_comp, signed_upscale); + + b2 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b1_s8, 1, scale5 ), + CVT_INT8_F32_SCAL_16( b1_s8, 0, scale4 ) ); + + b3 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b1_s8, 3, scale7 ), + CVT_INT8_F32_SCAL_16( b1_s8, 2, scale6 ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-63] = a[0,kr:kr+2]*b[kr:kr+2,0-63] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + c_float_0p3 = _mm512_dpbf16_ps( c_float_0p3, a_bf16_0, b3 ); + } + + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + // Broadcast a[0,kr:kr+2]. + a_kfringe_buf = *( a + (rs_a * 0) + (cs_a * ( k_full_pieces ))); + __m512bh a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + b0_s4 = _mm256_loadu_si256( (__m256i const *)( b + ( rs_b * k_full_pieces ) / 2 ) ); + + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_64, \ + sign_comp, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_16( b0_s8, 0, scale0 ) ); + + b1 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 3, scale3 ), + CVT_INT8_F32_SCAL_16( b0_s8, 2, scale2 ) ); + + b1_s4 = _mm256_loadu_si256( (__m256i const *)( b + ( ( rs_b * k_full_pieces ) / 2 ) + 32 ) ); + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( b1_s4, b1_s8, shift_idx_64, \ + sign_comp, signed_upscale); + + b2 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b1_s8, 1, scale5 ), + CVT_INT8_F32_SCAL_16( b1_s8, 0, scale4 ) ); + + b3 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b1_s8, 3, scale7 ), + CVT_INT8_F32_SCAL_16( b1_s8, 2, scale6 ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-63] = a[0,kr:kr+2]*b[kr:kr+2,0-63] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + c_float_0p3 = _mm512_dpbf16_ps( c_float_0p3, a_bf16_0, b3 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + if ( alpha != 1 ) + { + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + c_float_0p1 = _mm512_mul_ps( selector1, c_float_0p1 ); + c_float_0p2 = _mm512_mul_ps( selector1, c_float_0p2 ); + c_float_0p3 = _mm512_mul_ps( selector1, c_float_0p3 ); + } + + // Scale C by beta. + if ( beta != 0) + { + // For the downscaled api (C-bf16), the output C matrix values + // needs to be upscaled to float to be used for beta scale. + if ( ( post_ops_attr.buf_downscale != NULL ) && + ( post_ops_attr.is_first_k == TRUE ) ) + { + // c[0,0-15] + BF16_F32_BETA_OP(c_float_0p0,0,0,0,selector1,selector2) + + // c[0, 16-31] + BF16_F32_BETA_OP(c_float_0p1,0,0,1,selector1,selector2) + + // c[0,32-47] + BF16_F32_BETA_OP(c_float_0p2,0,0,2,selector1,selector2) + + // c[0,48-63] + BF16_F32_BETA_OP(c_float_0p3,0,0,3,selector1,selector2) + } + else + { + // c[0,0-15] + F32_F32_BETA_OP(c_float_0p0,0,0,0,selector1,selector2) + + // c[0, 16-31] + F32_F32_BETA_OP(c_float_0p1,0,0,1,selector1,selector2) + + // c[0,32-47] + F32_F32_BETA_OP(c_float_0p2,0,0,2,selector1,selector2) + + // c[0,48-63] + F32_F32_BETA_OP(c_float_0p3,0,0,3,selector1,selector2) + } + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_1x64: + { + __m512 selector3; + __m512 selector4; + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_LOAD(selector1, bias_mask, 0); + BF16_F32_BIAS_LOAD(selector2, bias_mask, 1); + BF16_F32_BIAS_LOAD(selector3, bias_mask, 2); + BF16_F32_BIAS_LOAD(selector4, bias_mask, 3); + } + else + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + selector4 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_add_ps( selector4, c_float_0p3 ); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the bias array will be accessed by + // the ic index, and each bias element corresponds to an + // entire row of the transposed output array, instead of an + // entire column. + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_add_ps( selector1, c_float_0p3 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_1x64: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_max_ps( selector1, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_max_ps( selector1, c_float_0p3 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_1x64: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_0p2) + + // c[0, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_0p3) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_TANH_1x64: + { + __m512 dn, z, x, r2, r, x_tanh; + __m512i q; + + // c[0, 0-15] + GELU_TANH_F32_AVX512(c_float_0p0, r, r2, x, z, dn, x_tanh, q) + + // c[0, 16-31] + GELU_TANH_F32_AVX512(c_float_0p1, r, r2, x, z, dn, x_tanh, q) + + // c[0, 32-47] + GELU_TANH_F32_AVX512(c_float_0p2, r, r2, x, z, dn, x_tanh, q) + + // c[0, 48-63] + GELU_TANH_F32_AVX512(c_float_0p3, r, r2, x, z, dn, x_tanh, q) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_ERF_1x64: + { + __m512 x, r, x_erf; + + // c[0, 0-15] + GELU_ERF_F32_AVX512(c_float_0p0, r, x, x_erf) + + // c[0, 16-31] + GELU_ERF_F32_AVX512(c_float_0p1, r, x, x_erf) + + // c[0, 32-47] + GELU_ERF_F32_AVX512(c_float_0p2, r, x, x_erf) + + // c[0, 48-63] + GELU_ERF_F32_AVX512(c_float_0p3, r, x, x_erf) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_CLIP_1x64: + { + __m512 min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 ); + __m512 max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 ); + + // c[0, 0-15] + CLIP_F32_AVX512(c_float_0p0, min, max) + + // c[0, 16-31] + CLIP_F32_AVX512(c_float_0p1, min, max) + + // c[0, 32-47] + CLIP_F32_AVX512(c_float_0p2, min, max) + + // c[0, 48-63] + CLIP_F32_AVX512(c_float_0p3, min, max) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_1x64: + { + __m512 selector3 = _mm512_setzero_ps(); + __m512 selector4 = _mm512_setzero_ps(); + + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + __m512 zero_point2 = _mm512_setzero_ps(); + __m512 zero_point3 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + selector4 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector2,zero_point1); + + // c[0, 32-47] + SCL_MULRND_F32(c_float_0p2,selector3,zero_point2); + + // c[0, 48-63] + SCL_MULRND_F32(c_float_0p3,selector4,zero_point3); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector1,zero_point0); + + // c[0, 32-47] + SCL_MULRND_F32(c_float_0p2,selector1,zero_point0); + + // c[0, 48-63] + SCL_MULRND_F32(c_float_0p3,selector1,zero_point0); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_1x64: + { + __m512 selector3; + __m512 selector4; + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,0); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(selector1,selector2,selector3,selector4,0); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_1x64: + { + __m512 selector3; + __m512 selector4; + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + BF16_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,0); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(selector1,selector2,selector3,selector4,0); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_1x64: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[0, 16-31] + SWISH_F32_AVX512_DEF(c_float_0p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[0, 32-47] + SWISH_F32_AVX512_DEF(c_float_0p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[0, 48-63] + SWISH_F32_AVX512_DEF(c_float_0p3, selector1, al_in, r, r2, z, dn, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_1x64_DISABLE: + ; + // Case where the output C matrix is bf16 (downscaled) and this is the + // final write for a given block within C. + if ( ( post_ops_attr.buf_downscale != NULL ) && + ( post_ops_attr.is_last_k == TRUE ) ) + { + // Generate a mask16 of all 1's. + __m512i selector_a = _mm512_setzero_epi32(); + __m512i selector_b = _mm512_set1_epi32( 10 ); + __mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b ); + + // Store the results in downscaled type (bf16 instead of float). + + // c[0, 0-15] + CVT_STORE_F32_BF16_MASK(c_float_0p0,0,0); + + // c[0, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_0p1,0,1); + + // c[0, 32-47] + CVT_STORE_F32_BF16_MASK(c_float_0p2,0,2); + + // c[0, 48-63] + CVT_STORE_F32_BF16_MASK(c_float_0p3,0,3); + } + + // Case where the output C matrix is float + else + { + // Store the accumulated results. + // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); + + // c[0, 16-31] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 1*16 ), c_float_0p1 ); + + // c[0,32-47] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 2*16 ), c_float_0p2 ); + + // c[0,48-63] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 3*16 ), c_float_0p3 ); + } +} +#endif +#endif diff --git a/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_mn_fringe_bf16_amd512vnni.c b/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_mn_fringe_bf16_amd512vnni.c index f0d58752e4..9f71a1d4b1 100644 --- a/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_mn_fringe_bf16_amd512vnni.c +++ b/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_mn_fringe_bf16_amd512vnni.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -40,7 +40,7 @@ #include "lpgemm_f32_kern_macros.h" -#ifndef LPGEMM_BF16_NOT_SUPPORTED +#ifndef LPGEMM_BF16_JIT // 5xlt16 bf16 fringe kernel LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5xlt16) { @@ -53,7 +53,10 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5xlt16) &&POST_OPS_GELU_TANH_5xLT16, &&POST_OPS_GELU_ERF_5xLT16, &&POST_OPS_CLIP_5xLT16, - &&POST_OPS_DOWNSCALE_5xLT16 + &&POST_OPS_DOWNSCALE_5xLT16, + &&POST_OPS_MATRIX_ADD_5xLT16, + &&POST_OPS_SWISH_5xLT16, + &&POST_OPS_MATRIX_MUL_5xLT16, }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -66,9 +69,6 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5xlt16) // A matrix storage bfloat type __m512bh a_bf16_0; - // For corner cases. - float buf0[16]; - // Registers to use for accumulating C. __m512 c_float_0p0 = _mm512_setzero_ps(); @@ -216,23 +216,23 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5xlt16) __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); // c[0,0-15] - F32_F32_BETA_OP_NLT16F_MASK(load_mask, c_float_0p0, 0, 0, 0, \ + F32_F32_BETA_OP_NLT16F_MASK(c, load_mask, c_float_0p0, 0, 0, 0, \ selector1, selector2); // c[1,0-15] - F32_F32_BETA_OP_NLT16F_MASK(load_mask, c_float_1p0, 0, 1, 0, \ + F32_F32_BETA_OP_NLT16F_MASK(c, load_mask, c_float_1p0, 0, 1, 0, \ selector1, selector2); // c[2,0-15] - F32_F32_BETA_OP_NLT16F_MASK(load_mask, c_float_2p0, 0, 2, 0, \ + F32_F32_BETA_OP_NLT16F_MASK(c, load_mask, c_float_2p0, 0, 2, 0, \ selector1, selector2); // c[3,0-15] - F32_F32_BETA_OP_NLT16F_MASK(load_mask, c_float_3p0, 0, 3, 0, \ + F32_F32_BETA_OP_NLT16F_MASK(c, load_mask, c_float_3p0, 0, 3, 0, \ selector1, selector2); // c[4,0-15] - F32_F32_BETA_OP_NLT16F_MASK(load_mask, c_float_4p0, 0, 4, 0, \ + F32_F32_BETA_OP_NLT16F_MASK(c, load_mask, c_float_4p0, 0, 4, 0, \ selector1, selector2); } } @@ -244,9 +244,21 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5xlt16) if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) { - memcpy( buf0, ( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( float ) ) ); - selector1 = _mm512_loadu_ps( buf0 ); + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + if ( post_ops_attr.c_stor_type == BF16 ) + { + BF16_F32_BIAS_LOAD(selector1, bias_mask, 0); + } + else + { + selector1 = + _mm512_maskz_loadu_ps + ( + bias_mask, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ); + } // c[0,0-15] c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); @@ -265,21 +277,36 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5xlt16) } else { - selector1 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 0 ) ); - selector2 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 1 ) ); - __m512 selector3 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 2 ) ); - __m512 selector4 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 3 ) ); - __m512 selector5 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 4 ) ); + __m512 selector3; + __m512 selector4; + __m512 selector5; + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + BF16_F32_BIAS_BCAST(selector2, bias_mask, 1); + BF16_F32_BIAS_BCAST(selector3, bias_mask, 2); + BF16_F32_BIAS_BCAST(selector4, bias_mask, 3); + BF16_F32_BIAS_BCAST(selector5, bias_mask, 4); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 3 ) ); + selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 4 ) ); + } // c[0,0-15] c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); @@ -413,25 +440,280 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5xlt16) POST_OPS_DOWNSCALE_5xLT16: { + __m512 selector3 = _mm512_setzero_ps(); + __m512 selector4 = _mm512_setzero_ps(); + __m512 selector5 = _mm512_setzero_ps(); + + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + __m512 zero_point2 = _mm512_setzero_ps(); + __m512 zero_point3 = _mm512_setzero_ps(); + __m512 zero_point4 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + // Also the same value is loaded to different registers so that + // branching can be reduced and same code/register can be used + // irrespective of whether scalar or vector op. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point4 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = _mm512_maskz_loadu_ps( zp_mask, + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector1,zero_point0); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector1,zero_point0); + + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector1,zero_point0); + + // c[4, 0-15] + SCL_MULRND_F32(c_float_4p0,selector1,zero_point0); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 3 ) ); + selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 4 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 2 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 3 ) ) ); + zero_point4 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 4 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector2,zero_point1); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector3,zero_point2); + + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector4,zero_point3); + + // c[4, 0-15] + SCL_MULRND_F32(c_float_4p0,selector5,zero_point4); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_5xLT16: + { + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15] + BF16_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + BF16_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,1); + + // c[2:0-15] + BF16_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,2); + + // c[3:0-15] + BF16_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,3); + + // c[4:0-15] + BF16_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,4); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15] + F32_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + F32_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,1); + + // c[2:0-15] + F32_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,2); + + // c[3:0-15] + F32_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,3); + + // c[4:0-15] + F32_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,4); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_5xLT16: + { + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15] + BF16_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + BF16_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,1); + + // c[2:0-15] + BF16_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,2); + + // c[3:0-15] + BF16_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,3); + + // c[4:0-15] + BF16_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,4); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15] + F32_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + F32_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,1); + + // c[2:0-15] + F32_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,2); + + // c[3:0-15] + F32_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,3); + + // c[4:0-15] + F32_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,4); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_5xLT16: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + // c[0, 0-15] - MULRND_F32(c_float_0p0,0,0); + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); // c[1, 0-15] - MULRND_F32(c_float_1p0,1,0); + SWISH_F32_AVX512_DEF(c_float_1p0, selector1, al_in, r, r2, z, dn, ex_out); // c[2, 0-15] - MULRND_F32(c_float_2p0,2,0); + SWISH_F32_AVX512_DEF(c_float_2p0, selector1, al_in, r, r2, z, dn, ex_out); // c[3, 0-15] - MULRND_F32(c_float_3p0,3,0); + SWISH_F32_AVX512_DEF(c_float_3p0, selector1, al_in, r, r2, z, dn, ex_out); // c[4, 0-15] - MULRND_F32(c_float_4p0,4,0); + SWISH_F32_AVX512_DEF(c_float_4p0, selector1, al_in, r, r2, z, dn, ex_out); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_5xLT16_DISABLE: ; + if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_last_k == TRUE ) ) { __mmask16 mask_all1 = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); @@ -487,7 +769,10 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4xlt16) &&POST_OPS_GELU_TANH_4xLT16, &&POST_OPS_GELU_ERF_4xLT16, &&POST_OPS_CLIP_4xLT16, - &&POST_OPS_DOWNSCALE_4xLT16 + &&POST_OPS_DOWNSCALE_4xLT16, + &&POST_OPS_MATRIX_ADD_4xLT16, + &&POST_OPS_SWISH_4xLT16, + &&POST_OPS_MATRIX_MUL_4xLT16 }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -500,9 +785,6 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4xlt16) // A matrix storage bfloat type __m512bh a_bf16_0; - // For corner cases. - float buf0[16]; - // Registers to use for accumulating C. __m512 c_float_0p0 = _mm512_setzero_ps(); @@ -582,7 +864,7 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4xlt16) // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-15] c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); } - + // Load alpha and beta __m512 selector1 = _mm512_set1_ps( alpha ); __m512 selector2 = _mm512_set1_ps( beta ); @@ -628,19 +910,19 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4xlt16) __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); // c[0,0-15] - F32_F32_BETA_OP_NLT16F_MASK(load_mask, c_float_0p0, 0, 0, 0, \ + F32_F32_BETA_OP_NLT16F_MASK(c, load_mask, c_float_0p0, 0, 0, 0, \ selector1, selector2); // c[1,0-15] - F32_F32_BETA_OP_NLT16F_MASK(load_mask, c_float_1p0, 0, 1, 0, \ + F32_F32_BETA_OP_NLT16F_MASK(c, load_mask, c_float_1p0, 0, 1, 0, \ selector1, selector2); // c[2,0-15] - F32_F32_BETA_OP_NLT16F_MASK(load_mask, c_float_2p0, 0, 2, 0, \ + F32_F32_BETA_OP_NLT16F_MASK(c, load_mask, c_float_2p0, 0, 2, 0, \ selector1, selector2); // c[3,0-15] - F32_F32_BETA_OP_NLT16F_MASK(load_mask, c_float_3p0, 0, 3, 0, \ + F32_F32_BETA_OP_NLT16F_MASK(c, load_mask, c_float_3p0, 0, 3, 0, \ selector1, selector2); } } @@ -652,9 +934,21 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4xlt16) if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) { - memcpy( buf0, ( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( float ) ) ); - selector1 = _mm512_loadu_ps( buf0 ); + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + if ( post_ops_attr.c_stor_type == BF16 ) + { + BF16_F32_BIAS_LOAD(selector1, bias_mask, 0); + } + else + { + selector1 = + _mm512_maskz_loadu_ps + ( + bias_mask, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ); + } // c[0,0-15] c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); @@ -670,18 +964,31 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4xlt16) } else { - selector1 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 0 ) ); - selector2 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 1 ) ); - __m512 selector3 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 2 ) ); - __m512 selector4 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 3 ) ); + __m512 selector3; + __m512 selector4; + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + BF16_F32_BIAS_BCAST(selector2, bias_mask, 1); + BF16_F32_BIAS_BCAST(selector3, bias_mask, 2); + BF16_F32_BIAS_BCAST(selector4, bias_mask, 3); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 3 ) ); + } // c[0,0-15] c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); @@ -794,25 +1101,247 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4xlt16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_4xLT16: { + __m512 selector3 = _mm512_setzero_ps(); + __m512 selector4 = _mm512_setzero_ps(); + + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + __m512 zero_point2 = _mm512_setzero_ps(); + __m512 zero_point3 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + // Also the same value is loaded to different registers so that + // branching can be reduced and same code/register can be used + // irrespective of whether scalar or vector op. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = _mm512_maskz_loadu_ps( zp_mask, + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector1,zero_point0); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector1,zero_point0); + + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector1,zero_point0); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 3 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 2 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 3 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector2,zero_point1); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector3,zero_point2); + + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector4,zero_point3); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_4xLT16: + { + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15] + BF16_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + BF16_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,1); + + // c[2:0-15] + BF16_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,2); + + // c[3:0-15] + BF16_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,3); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15] + F32_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + F32_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,1); + + // c[2:0-15] + F32_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,2); + + // c[3:0-15] + F32_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,3); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_4xLT16: + { + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15] + BF16_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + BF16_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,1); + + // c[2:0-15] + BF16_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,2); + + // c[3:0-15] + BF16_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,3); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15] + F32_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + F32_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,1); + + // c[2:0-15] + F32_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,2); + + // c[3:0-15] + F32_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,3); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_4xLT16: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + // c[0, 0-15] - MULRND_F32(c_float_0p0,0,0); + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); // c[1, 0-15] - MULRND_F32(c_float_1p0,1,0); + SWISH_F32_AVX512_DEF(c_float_1p0, selector1, al_in, r, r2, z, dn, ex_out); // c[2, 0-15] - MULRND_F32(c_float_2p0,2,0); + SWISH_F32_AVX512_DEF(c_float_2p0, selector1, al_in, r, r2, z, dn, ex_out); // c[3, 0-15] - MULRND_F32(c_float_3p0,3,0); + SWISH_F32_AVX512_DEF(c_float_3p0, selector1, al_in, r, r2, z, dn, ex_out); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_4xLT16_DISABLE: ; + if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_last_k == TRUE ) ) { __mmask16 mask_all1 = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); @@ -862,7 +1391,10 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3xlt16) &&POST_OPS_GELU_TANH_3xLT16, &&POST_OPS_GELU_ERF_3xLT16, &&POST_OPS_CLIP_3xLT16, - &&POST_OPS_DOWNSCALE_3xLT16 + &&POST_OPS_DOWNSCALE_3xLT16, + &&POST_OPS_MATRIX_ADD_3xLT16, + &&POST_OPS_SWISH_3xLT16, + &&POST_OPS_MATRIX_MUL_3xLT16 }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -875,9 +1407,6 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3xlt16) // A matrix storage bfloat type __m512bh a_bf16_0; - // For corner cases. - float buf0[16]; - // Registers to use for accumulating C. __m512 c_float_0p0 = _mm512_setzero_ps(); @@ -979,15 +1508,15 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3xlt16) __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); // c[0,0-15] - F32_F32_BETA_OP_NLT16F_MASK(load_mask, c_float_0p0, 0, 0, 0, \ + F32_F32_BETA_OP_NLT16F_MASK(c, load_mask, c_float_0p0, 0, 0, 0, \ selector1, selector2); // c[1,0-15] - F32_F32_BETA_OP_NLT16F_MASK(load_mask, c_float_1p0, 0, 1, 0, \ + F32_F32_BETA_OP_NLT16F_MASK(c, load_mask, c_float_1p0, 0, 1, 0, \ selector1, selector2); // c[2,0-15] - F32_F32_BETA_OP_NLT16F_MASK(load_mask, c_float_2p0, 0, 2, 0, \ + F32_F32_BETA_OP_NLT16F_MASK(c, load_mask, c_float_2p0, 0, 2, 0, \ selector1, selector2); } } @@ -999,9 +1528,21 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3xlt16) if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) { - memcpy( buf0, ( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( float ) ) ); - selector1 = _mm512_loadu_ps( buf0 ); + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + if ( post_ops_attr.c_stor_type == BF16 ) + { + BF16_F32_BIAS_LOAD(selector1, bias_mask, 0); + } + else + { + selector1 = + _mm512_maskz_loadu_ps + ( + bias_mask, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ); + } // c[0,0-15] c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); @@ -1014,15 +1555,26 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3xlt16) } else { - selector1 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 0 ) ); - selector2 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 1 ) ); - __m512 selector3 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 2 ) ); + __m512 selector3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + BF16_F32_BIAS_BCAST(selector2, bias_mask, 1); + BF16_F32_BIAS_BCAST(selector3, bias_mask, 2); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 2 ) ); + } // c[0,0-15] c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); @@ -1120,39 +1672,230 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3xlt16) POST_OPS_DOWNSCALE_3xLT16: { - // c[0, 0-15] - MULRND_F32(c_float_0p0,0,0); + __m512 selector3 = _mm512_setzero_ps(); - // c[1, 0-15] - MULRND_F32(c_float_1p0,1,0); + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + __m512 zero_point2 = _mm512_setzero_ps(); - // c[2, 0-15] - MULRND_F32(c_float_2p0,2,0); + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_3xLT16_DISABLE: - ; - if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_last_k == TRUE ) ) - { - __mmask16 mask_all1 = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + // Also the same value is loaded to different registers so that + // branching can be reduced and same code/register can be used + // irrespective of whether scalar or vector op. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } - // Store the results in downscaled type (int8 instead of int32). - // c[0,0-15] - CVT_STORE_F32_BF16_MASK(c_float_0p0,0,0); + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } - // c[1,0-15] - CVT_STORE_F32_BF16_MASK(c_float_1p0,1,0); + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = _mm512_maskz_loadu_ps( zp_mask, + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + } - // c[2,0-15] - CVT_STORE_F32_BF16_MASK(c_float_2p0,2,0); - } + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + } - else - { - __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); - // Store the results. + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector1,zero_point0); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector1,zero_point0); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 2 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 2 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector2,zero_point1); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector3,zero_point2); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_3xLT16: + { + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15] + BF16_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + BF16_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,1); + + // c[2:0-15] + BF16_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,2); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15] + F32_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + F32_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,1); + + // c[2:0-15] + F32_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,2); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_3xLT16: + { + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15] + BF16_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + BF16_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,1); + + // c[2:0-15] + BF16_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,2); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15] + F32_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + F32_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,1); + + // c[2:0-15] + F32_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,2); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_3xLT16: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 0-15] + SWISH_F32_AVX512_DEF(c_float_1p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 0-15] + SWISH_F32_AVX512_DEF(c_float_2p0, selector1, al_in, r, r2, z, dn, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_3xLT16_DISABLE: + ; + + if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_last_k == TRUE ) ) + { + __mmask16 mask_all1 = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + + // Store the results in downscaled type (int8 instead of int32). + // c[0,0-15] + CVT_STORE_F32_BF16_MASK(c_float_0p0,0,0); + + // c[1,0-15] + CVT_STORE_F32_BF16_MASK(c_float_1p0,1,0); + + // c[2,0-15] + CVT_STORE_F32_BF16_MASK(c_float_2p0,2,0); + } + + else + { + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + + // Store the results. // c[0,0-15] _mm512_mask_storeu_ps( c + ( rs_c * 0 ), load_mask, c_float_0p0 ); @@ -1177,7 +1920,10 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2xlt16) &&POST_OPS_GELU_TANH_2xLT16, &&POST_OPS_GELU_ERF_2xLT16, &&POST_OPS_CLIP_2xLT16, - &&POST_OPS_DOWNSCALE_2xLT16 + &&POST_OPS_DOWNSCALE_2xLT16, + &&POST_OPS_MATRIX_ADD_2xLT16, + &&POST_OPS_SWISH_2xLT16, + &&POST_OPS_MATRIX_MUL_2xLT16 }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -1190,9 +1936,6 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2xlt16) // A matrix storage bfloat type __m512bh a_bf16_0; - // For corner cases. - float buf0[16]; - // Registers to use for accumulating C. __m512 c_float_0p0 = _mm512_setzero_ps(); @@ -1271,11 +2014,11 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2xlt16) __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); // c[0,0-15] - F32_F32_BETA_OP_NLT16F_MASK(load_mask, c_float_0p0, 0, 0, 0, \ + F32_F32_BETA_OP_NLT16F_MASK(c, load_mask, c_float_0p0, 0, 0, 0, \ selector1, selector2); // c[1,0-15] - F32_F32_BETA_OP_NLT16F_MASK(load_mask, c_float_1p0, 0, 1, 0, \ + F32_F32_BETA_OP_NLT16F_MASK(c, load_mask, c_float_1p0, 0, 1, 0, \ selector1, selector2); } } @@ -1287,9 +2030,21 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2xlt16) if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) { - memcpy( buf0, ( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( float ) ) ); - selector1 = _mm512_loadu_ps( buf0 ); + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + if ( post_ops_attr.c_stor_type == BF16 ) + { + BF16_F32_BIAS_LOAD(selector1, bias_mask, 0); + } + else + { + selector1 = + _mm512_maskz_loadu_ps + ( + bias_mask, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ); + } // c[0,0-15] c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); @@ -1299,12 +2054,21 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2xlt16) } else { - selector1 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 0 ) ); - selector2 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 1 ) ); + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + BF16_F32_BIAS_BCAST(selector2, bias_mask, 1); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 1 ) ); + } // c[0,0-15] c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); @@ -1384,16 +2148,174 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2xlt16) POST_OPS_DOWNSCALE_2xLT16: { + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + // Also the same value is loaded to different registers so that + // branching can be reduced and same code/register can be used + // irrespective of whether scalar or vector op. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = _mm512_maskz_loadu_ps( zp_mask, + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector1,zero_point0); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 1 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 1 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector2,zero_point1); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_2xLT16: + { + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15] + BF16_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + BF16_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,1); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15] + F32_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + F32_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,1); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_2xLT16: + { + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15] + BF16_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + BF16_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,1); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15] + F32_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + F32_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,1); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_2xLT16: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + // c[0, 0-15] - MULRND_F32(c_float_0p0,0,0); + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); // c[1, 0-15] - MULRND_F32(c_float_1p0,1,0); + SWISH_F32_AVX512_DEF(c_float_1p0, selector1, al_in, r, r2, z, dn, ex_out); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_2xLT16_DISABLE: ; + if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_last_k == TRUE ) ) { __mmask16 mask_all1 = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); @@ -1432,7 +2354,10 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1xlt16) &&POST_OPS_GELU_TANH_1xLT16, &&POST_OPS_GELU_ERF_1xLT16, &&POST_OPS_CLIP_1xLT16, - &&POST_OPS_DOWNSCALE_1xLT16 + &&POST_OPS_DOWNSCALE_1xLT16, + &&POST_OPS_MATRIX_ADD_1xLT16, + &&POST_OPS_SWISH_1xLT16, + &&POST_OPS_MATRIX_MUL_1xLT16 }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -1445,9 +2370,6 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1xlt16) // A matrix storage bfloat type __m512bh a_bf16_0; - // For corner cases. - float buf0[16]; - // Registers to use for accumulating C. __m512 c_float_0p0 = _mm512_setzero_ps(); @@ -1503,7 +2425,7 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1xlt16) __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); // c[0,0-15] - F32_F32_BETA_OP_NLT16F_MASK(load_mask, c_float_0p0, 0, 0, 0, \ + F32_F32_BETA_OP_NLT16F_MASK(c, load_mask, c_float_0p0, 0, 0, 0, \ selector1, selector2); } } @@ -1515,18 +2437,38 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1xlt16) if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) { - memcpy( buf0, ( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( float ) ) ); - selector1 = _mm512_loadu_ps( buf0 ); + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + if ( post_ops_attr.c_stor_type == BF16 ) + { + BF16_F32_BIAS_LOAD(selector1, bias_mask, 0); + } + else + { + selector1 = + _mm512_maskz_loadu_ps + ( + bias_mask, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ); + } // c[0,0-15] c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); } else { - selector1 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 0 ) ); + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + } // c[0,0-15] c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); @@ -1588,13 +2530,140 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1xlt16) POST_OPS_DOWNSCALE_1xLT16: { + __m512 zero_point0 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + // Also the same value is loaded to different registers so that + // branching can be reduced and same code/register can be used + // irrespective of whether scalar or vector op. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = _mm512_maskz_loadu_ps( zp_mask, + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_1xLT16: + { + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15] + BF16_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,0); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15] + F32_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,0); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_1xLT16: + { + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15] + BF16_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,0); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15] + F32_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,0); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_1xLT16: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + // c[0, 0-15] - MULRND_F32(c_float_0p0,0,0); + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_1xLT16_DISABLE: ; + if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_last_k == TRUE ) ) { __mmask16 mask_all1 = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); @@ -1627,7 +2696,10 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x16) &&POST_OPS_GELU_TANH_5x16, &&POST_OPS_GELU_ERF_5x16, &&POST_OPS_CLIP_5x16, - &&POST_OPS_DOWNSCALE_5x16 + &&POST_OPS_DOWNSCALE_5x16, + &&POST_OPS_MATRIX_ADD_5x16, + &&POST_OPS_SWISH_5x16, + &&POST_OPS_MATRIX_MUL_5x16 }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -1812,9 +2884,17 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x16) if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) { - selector1 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j ); + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_LOAD(selector1, bias_mask, 0); + } + else + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ); + } // c[0,0-15] c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); @@ -1833,31 +2913,46 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x16) } else { - selector1 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 0 ) ); - selector2 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 1 ) ); - __m512 selector3 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 2 ) ); - __m512 selector4 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 3 ) ); - __m512 selector5 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 4 ) ); - - // c[0,0-15] - c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); - - // c[1,0-15] - c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); - - // c[2,0-15] - c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); - + __m512 selector3; + __m512 selector4; + __m512 selector5; + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + BF16_F32_BIAS_BCAST(selector2, bias_mask, 1); + BF16_F32_BIAS_BCAST(selector3, bias_mask, 2); + BF16_F32_BIAS_BCAST(selector4, bias_mask, 3); + BF16_F32_BIAS_BCAST(selector5, bias_mask, 4); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 3 ) ); + selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 4 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + // c[3,0-15] c_float_3p0 = _mm512_add_ps( selector4, c_float_3p0 ); @@ -1978,28 +3073,280 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_5x16: { + __m512 selector3 = _mm512_setzero_ps(); + __m512 selector4 = _mm512_setzero_ps(); + __m512 selector5 = _mm512_setzero_ps(); + + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + __m512 zero_point2 = _mm512_setzero_ps(); + __m512 zero_point3 = _mm512_setzero_ps(); + __m512 zero_point4 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + // Also the same value is loaded to different registers so that + // branching can be reduced and same code/register can be used + // irrespective of whether scalar or vector op. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point4 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = _mm512_maskz_loadu_ps( zp_mask, + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector1,zero_point0); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector1,zero_point0); + + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector1,zero_point0); + + // c[4, 0-15] + SCL_MULRND_F32(c_float_4p0,selector1,zero_point0); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 3 ) ); + selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 4 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 2 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 3 ) ) ); + zero_point4 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 4 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector2,zero_point1); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector3,zero_point2); + + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector4,zero_point3); + + // c[4, 0-15] + SCL_MULRND_F32(c_float_4p0,selector5,zero_point4); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_5x16: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15] + BF16_F32_MATRIX_ADD_1COL(selector1,0); + + // c[1:0-15] + BF16_F32_MATRIX_ADD_1COL(selector1,1); + + // c[2:0-15] + BF16_F32_MATRIX_ADD_1COL(selector1,2); + + // c[3:0-15] + BF16_F32_MATRIX_ADD_1COL(selector1,3); + + // c[4:0-15] + BF16_F32_MATRIX_ADD_1COL(selector1,4); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15] + F32_F32_MATRIX_ADD_1COL(selector1,0); + + // c[1:0-15] + F32_F32_MATRIX_ADD_1COL(selector1,1); + + // c[2:0-15] + F32_F32_MATRIX_ADD_1COL(selector1,2); + + // c[3:0-15] + F32_F32_MATRIX_ADD_1COL(selector1,3); + + // c[4:0-15] + F32_F32_MATRIX_ADD_1COL(selector1,4); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_5x16: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15] + BF16_F32_MATRIX_MUL_1COL(selector1,0); + + // c[1:0-15] + BF16_F32_MATRIX_MUL_1COL(selector1,1); + + // c[2:0-15] + BF16_F32_MATRIX_MUL_1COL(selector1,2); + + // c[3:0-15] + BF16_F32_MATRIX_MUL_1COL(selector1,3); + + // c[4:0-15] + BF16_F32_MATRIX_MUL_1COL(selector1,4); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15] + F32_F32_MATRIX_MUL_1COL(selector1,0); + + // c[1:0-15] + F32_F32_MATRIX_MUL_1COL(selector1,1); + + // c[2:0-15] + F32_F32_MATRIX_MUL_1COL(selector1,2); + + // c[3:0-15] + F32_F32_MATRIX_MUL_1COL(selector1,3); + + // c[4:0-15] + F32_F32_MATRIX_MUL_1COL(selector1,4); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_5x16: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + // c[0, 0-15] - MULRND_F32(c_float_0p0,0,0); + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); // c[1, 0-15] - MULRND_F32(c_float_1p0,1,0); + SWISH_F32_AVX512_DEF(c_float_1p0, selector1, al_in, r, r2, z, dn, ex_out); // c[2, 0-15] - MULRND_F32(c_float_2p0,2,0); + SWISH_F32_AVX512_DEF(c_float_2p0, selector1, al_in, r, r2, z, dn, ex_out); // c[3, 0-15] - MULRND_F32(c_float_3p0,3,0); + SWISH_F32_AVX512_DEF(c_float_3p0, selector1, al_in, r, r2, z, dn, ex_out); // c[4, 0-15] - MULRND_F32(c_float_4p0,4,0); + SWISH_F32_AVX512_DEF(c_float_4p0, selector1, al_in, r, r2, z, dn, ex_out); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_5x16_DISABLE: ; + if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_last_k == TRUE ) ) { // Generate a mask16 of all 1's. @@ -2056,7 +3403,10 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x16) &&POST_OPS_GELU_TANH_4x16, &&POST_OPS_GELU_ERF_4x16, &&POST_OPS_CLIP_4x16, - &&POST_OPS_DOWNSCALE_4x16 + &&POST_OPS_DOWNSCALE_4x16, + &&POST_OPS_MATRIX_ADD_4x16, + &&POST_OPS_SWISH_4x16, + &&POST_OPS_MATRIX_MUL_4x16 }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -2214,9 +3564,17 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x16) if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) { - selector1 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j ); + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_LOAD(selector1, bias_mask, 0); + } + else + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ); + } // c[0,0-15] c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); @@ -2232,18 +3590,31 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x16) } else { - selector1 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 0 ) ); - selector2 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 1 ) ); - __m512 selector3 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 2 ) ); - __m512 selector4 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 3 ) ); + __m512 selector3; + __m512 selector4; + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + BF16_F32_BIAS_BCAST(selector2, bias_mask, 1); + BF16_F32_BIAS_BCAST(selector3, bias_mask, 2); + BF16_F32_BIAS_BCAST(selector4, bias_mask, 3); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 3 ) ); + } // c[0,0-15] c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); @@ -2356,25 +3727,245 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_4x16: { + __m512 selector3 = _mm512_setzero_ps(); + __m512 selector4 = _mm512_setzero_ps(); + + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + __m512 zero_point2 = _mm512_setzero_ps(); + __m512 zero_point3 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + // Also the same value is loaded to different registers so that + // branching can be reduced and same code/register can be used + // irrespective of whether scalar or vector op. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = _mm512_maskz_loadu_ps( zp_mask, + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector1,zero_point0); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector1,zero_point0); + + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector1,zero_point0); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 3 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 2 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 3 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector2,zero_point1); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector3,zero_point2); + + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector4,zero_point3); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_4x16: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15] + BF16_F32_MATRIX_ADD_1COL(selector1,0); + + // c[1:0-15] + BF16_F32_MATRIX_ADD_1COL(selector1,1); + + // c[2:0-15] + BF16_F32_MATRIX_ADD_1COL(selector1,2); + + // c[3:0-15] + BF16_F32_MATRIX_ADD_1COL(selector1,3); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15] + F32_F32_MATRIX_ADD_1COL(selector1,0); + + // c[1:0-15] + F32_F32_MATRIX_ADD_1COL(selector1,1); + + // c[2:0-15] + F32_F32_MATRIX_ADD_1COL(selector1,2); + + // c[3:0-15] + F32_F32_MATRIX_ADD_1COL(selector1,3); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_4x16: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15] + BF16_F32_MATRIX_MUL_1COL(selector1,0); + + // c[1:0-15] + BF16_F32_MATRIX_MUL_1COL(selector1,1); + + // c[2:0-15] + BF16_F32_MATRIX_MUL_1COL(selector1,2); + + // c[3:0-15] + BF16_F32_MATRIX_MUL_1COL(selector1,3); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15] + F32_F32_MATRIX_MUL_1COL(selector1,0); + + // c[1:0-15] + F32_F32_MATRIX_MUL_1COL(selector1,1); + + // c[2:0-15] + F32_F32_MATRIX_MUL_1COL(selector1,2); + + // c[3:0-15] + F32_F32_MATRIX_MUL_1COL(selector1,3); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_4x16: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + // c[0, 0-15] - MULRND_F32(c_float_0p0,0,0); + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); // c[1, 0-15] - MULRND_F32(c_float_1p0,1,0); + SWISH_F32_AVX512_DEF(c_float_1p0, selector1, al_in, r, r2, z, dn, ex_out); // c[2, 0-15] - MULRND_F32(c_float_2p0,2,0); + SWISH_F32_AVX512_DEF(c_float_2p0, selector1, al_in, r, r2, z, dn, ex_out); // c[3, 0-15] - MULRND_F32(c_float_3p0,3,0); + SWISH_F32_AVX512_DEF(c_float_3p0, selector1, al_in, r, r2, z, dn, ex_out); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_4x16_DISABLE: ; + if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_last_k == TRUE ) ) { // Generate a mask16 of all 1's. @@ -2394,7 +3985,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x16) // c[3,0-15] CVT_STORE_F32_BF16_MASK(c_float_3p0,3,0); - } + } else { @@ -2425,7 +4016,10 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x16) &&POST_OPS_GELU_TANH_3x16, &&POST_OPS_GELU_ERF_3x16, &&POST_OPS_CLIP_3x16, - &&POST_OPS_DOWNSCALE_3x16 + &&POST_OPS_DOWNSCALE_3x16, + &&POST_OPS_MATRIX_ADD_3x16, + &&POST_OPS_SWISH_3x16, + &&POST_OPS_MATRIX_MUL_3x16 }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -2547,7 +4141,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x16) F32_F32_BETA_OP(c_float_2p0, 0, 2, 0, \ selector1, selector2); } - + } // Post Ops lpgemm_post_op* post_ops_list_temp = post_ops_list; @@ -2557,9 +4151,17 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x16) if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) { - selector1 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j ); + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_LOAD(selector1, bias_mask, 0); + } + else + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ); + } // c[0,0-15] c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); @@ -2572,26 +4174,37 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x16) } else { - selector1 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 0 ) ); - selector2 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 1 ) ); - __m512 selector3 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 2 ) ); - - // c[0,0-15] - c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); - - // c[1,0-15] - c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); - - // c[2,0-15] - c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); - } - + __m512 selector3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + BF16_F32_BIAS_BCAST(selector2, bias_mask, 1); + BF16_F32_BIAS_BCAST(selector3, bias_mask, 2); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 2 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + } + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_RELU_3x16: @@ -2675,22 +4288,210 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_3x16: { + __m512 selector3 = _mm512_setzero_ps(); + + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + __m512 zero_point2 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + // Also the same value is loaded to different registers so that + // branching can be reduced and same code/register can be used + // irrespective of whether scalar or vector op. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = _mm512_maskz_loadu_ps( zp_mask, + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector1,zero_point0); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector1,zero_point0); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 2 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 2 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector2,zero_point1); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector3,zero_point2); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_3x16: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15] + BF16_F32_MATRIX_ADD_1COL(selector1,0); + + // c[1:0-15] + BF16_F32_MATRIX_ADD_1COL(selector1,1); + + // c[2:0-15] + BF16_F32_MATRIX_ADD_1COL(selector1,2); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15] + F32_F32_MATRIX_ADD_1COL(selector1,0); + + // c[1:0-15] + F32_F32_MATRIX_ADD_1COL(selector1,1); + + // c[2:0-15] + F32_F32_MATRIX_ADD_1COL(selector1,2); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_3x16: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15] + BF16_F32_MATRIX_MUL_1COL(selector1,0); + + // c[1:0-15] + BF16_F32_MATRIX_MUL_1COL(selector1,1); + + // c[2:0-15] + BF16_F32_MATRIX_MUL_1COL(selector1,2); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15] + F32_F32_MATRIX_MUL_1COL(selector1,0); + + // c[1:0-15] + F32_F32_MATRIX_MUL_1COL(selector1,1); + + // c[2:0-15] + F32_F32_MATRIX_MUL_1COL(selector1,2); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_3x16: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + // c[0, 0-15] - MULRND_F32(c_float_0p0,0,0); + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); // c[1, 0-15] - MULRND_F32(c_float_1p0,1,0); + SWISH_F32_AVX512_DEF(c_float_1p0, selector1, al_in, r, r2, z, dn, ex_out); // c[2, 0-15] - MULRND_F32(c_float_2p0,2,0); + SWISH_F32_AVX512_DEF(c_float_2p0, selector1, al_in, r, r2, z, dn, ex_out); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_3x16_DISABLE: ; + if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_last_k == TRUE ) ) { // Generate a mask16 of all 1's. @@ -2707,7 +4508,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x16) // c[2,0-15] CVT_STORE_F32_BF16_MASK(c_float_2p0,2,0); - } + } else { @@ -2735,7 +4536,10 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x16) &&POST_OPS_GELU_TANH_2x16, &&POST_OPS_GELU_ERF_2x16, &&POST_OPS_CLIP_2x16, - &&POST_OPS_DOWNSCALE_2x16 + &&POST_OPS_DOWNSCALE_2x16, + &&POST_OPS_MATRIX_ADD_2x16, + &&POST_OPS_SWISH_2x16, + &&POST_OPS_MATRIX_MUL_2x16 }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -2830,7 +4634,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x16) F32_F32_BETA_OP(c_float_1p0, 0, 1, 0, \ selector1, selector2); } - + } // Post Ops lpgemm_post_op* post_ops_list_temp = post_ops_list; @@ -2840,9 +4644,17 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x16) if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) { - selector1 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j ); + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_LOAD(selector1, bias_mask, 0); + } + else + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ); + } // c[0,0-15] c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); @@ -2852,12 +4664,21 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x16) } else { - selector1 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 0 ) ); - selector2 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 1 ) ); + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + BF16_F32_BIAS_BCAST(selector2, bias_mask, 1); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 1 ) ); + } // c[0,0-15] c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); @@ -2937,16 +4758,172 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x16) POST_OPS_DOWNSCALE_2x16: { + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + // Also the same value is loaded to different registers so that + // branching can be reduced and same code/register can be used + // irrespective of whether scalar or vector op. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = _mm512_maskz_loadu_ps( zp_mask, + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector1,zero_point0); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 1 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 1 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector2,zero_point1); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_2x16: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15] + BF16_F32_MATRIX_ADD_1COL(selector1,0); + + // c[1:0-15] + BF16_F32_MATRIX_ADD_1COL(selector1,1); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15] + F32_F32_MATRIX_ADD_1COL(selector1,0); + + // c[1:0-15] + F32_F32_MATRIX_ADD_1COL(selector1,1); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_2x16: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15] + BF16_F32_MATRIX_MUL_1COL(selector1,0); + + // c[1:0-15] + BF16_F32_MATRIX_MUL_1COL(selector1,1); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15] + F32_F32_MATRIX_MUL_1COL(selector1,0); + + // c[1:0-15] + F32_F32_MATRIX_MUL_1COL(selector1,1); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_2x16: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + // c[0, 0-15] - MULRND_F32(c_float_0p0,0,0); + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); // c[1, 0-15] - MULRND_F32(c_float_1p0,1,0); + SWISH_F32_AVX512_DEF(c_float_1p0, selector1, al_in, r, r2, z, dn, ex_out); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_2x16_DISABLE: ; + if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_last_k == TRUE ) ) { // Generate a mask16 of all 1's. @@ -2960,7 +4937,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x16) // c[1,0-15] CVT_STORE_F32_BF16_MASK(c_float_1p0,1,0); - } + } else { @@ -2985,7 +4962,10 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x16) &&POST_OPS_GELU_TANH_1x16, &&POST_OPS_GELU_ERF_1x16, &&POST_OPS_CLIP_1x16, - &&POST_OPS_DOWNSCALE_1x16 + &&POST_OPS_DOWNSCALE_1x16, + &&POST_OPS_MATRIX_ADD_1x16, + &&POST_OPS_SWISH_1x16, + &&POST_OPS_MATRIX_MUL_1x16 }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -3053,7 +5033,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x16) F32_F32_BETA_OP(c_float_0p0, 0, 0, 0, \ selector1, selector2); } - + } // Post Ops lpgemm_post_op* post_ops_list_temp = post_ops_list; @@ -3063,18 +5043,34 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x16) if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) { - selector1 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j ); + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_LOAD(selector1, bias_mask, 0); + } + else + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ); + } // c[0,0-15] c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); } else { - selector1 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 0 ) ); + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + } // c[0,0-15] c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); @@ -3133,16 +5129,140 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_1x16: { + __m512 zero_point0 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + // Also the same value is loaded to different registers so that + // branching can be reduced and same code/register can be used + // irrespective of whether scalar or vector op. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = _mm512_maskz_loadu_ps( zp_mask, + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_1x16: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15] + BF16_F32_MATRIX_ADD_1COL(selector1,0); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15] + F32_F32_MATRIX_ADD_1COL(selector1,0); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_1x16: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15] + BF16_F32_MATRIX_MUL_1COL(selector1,0); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15] + F32_F32_MATRIX_MUL_1COL(selector1,0); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_1x16: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + // c[0, 0-15] - MULRND_F32(c_float_0p0,0,0); + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_1x16_DISABLE: ; + if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_last_k == TRUE ) ) { // Generate a mask16 of all 1's. @@ -3153,7 +5273,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x16) // Store the results in downscaled type (int8 instead of int32). // c[0,0-15] CVT_STORE_F32_BF16_MASK(c_float_0p0,0,0); - } + } else { // Store the results. @@ -3174,7 +5294,10 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x32) &&POST_OPS_GELU_TANH_5x32, &&POST_OPS_GELU_ERF_5x32, &&POST_OPS_CLIP_5x32, - &&POST_OPS_DOWNSCALE_5x32 + &&POST_OPS_DOWNSCALE_5x32, + &&POST_OPS_MATRIX_ADD_5x32, + &&POST_OPS_SWISH_5x32, + &&POST_OPS_MATRIX_MUL_5x32 }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -3361,7 +5484,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x32) // c[4, 16-31] BF16_F32_BETA_OP( c_float_4p1, 0, 4, 1, selector1, selector2 ); } - else + else { // c[0,0-15] F32_F32_BETA_OP( c_float_0p0, 0, 0, 0, selector1, selector2 ); @@ -3402,12 +5525,21 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x32) if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) { - selector1 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_LOAD(selector1, bias_mask, 0); + BF16_F32_BIAS_LOAD(selector2, bias_mask, 1); + } + else + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + } // c[0,0-15] c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); @@ -3441,21 +5573,36 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x32) } else { - selector1 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 0 ) ); - selector2 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 1 ) ); - __m512 selector3 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 2 ) ); - __m512 selector4 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 3 ) ); - __m512 selector5 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 4 ) ); + __m512 selector3; + __m512 selector4; + __m512 selector5; + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + BF16_F32_BIAS_BCAST(selector2, bias_mask, 1); + BF16_F32_BIAS_BCAST(selector3, bias_mask, 2); + BF16_F32_BIAS_BCAST(selector4, bias_mask, 3); + BF16_F32_BIAS_BCAST(selector5, bias_mask, 4); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 3 ) ); + selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 4 ) ); + } // c[0,0-15] c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); @@ -3676,43 +5823,329 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x32) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_5x32: { + __m512 selector3 = _mm512_setzero_ps(); + __m512 selector4 = _mm512_setzero_ps(); + __m512 selector5 = _mm512_setzero_ps(); + + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + __m512 zero_point2 = _mm512_setzero_ps(); + __m512 zero_point3 = _mm512_setzero_ps(); + __m512 zero_point4 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point4 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector2,zero_point1); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector1,zero_point0); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector1,zero_point0); + + // c[2, 16-31] + SCL_MULRND_F32(c_float_2p1,selector2,zero_point1); + + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector1,zero_point0); + + // c[3, 16-31] + SCL_MULRND_F32(c_float_3p1,selector2,zero_point1); + + // c[4, 0-15] + SCL_MULRND_F32(c_float_4p0,selector1,zero_point0); + + // c[4, 16-31] + SCL_MULRND_F32(c_float_4p1,selector2,zero_point1); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 3 ) ); + selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 4 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 2 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 3 ) ) ); + zero_point4 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 4 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector2,zero_point1); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector3,zero_point2); + + // c[2, 16-31] + SCL_MULRND_F32(c_float_2p1,selector3,zero_point2); + + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector4,zero_point3); + + // c[3, 16-31] + SCL_MULRND_F32(c_float_3p1,selector4,zero_point3); + + // c[4, 0-15] + SCL_MULRND_F32(c_float_4p0,selector5,zero_point4); + + // c[4, 16-31] + SCL_MULRND_F32(c_float_4p1,selector5,zero_point4); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_5x32: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + BF16_F32_MATRIX_ADD_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + BF16_F32_MATRIX_ADD_2COL(selector1,selector2,1); + + // c[2:0-15,16-31] + BF16_F32_MATRIX_ADD_2COL(selector1,selector2,2); + + // c[3:0-15,16-31] + BF16_F32_MATRIX_ADD_2COL(selector1,selector2,3); + + // c[4:0-15,16-31] + BF16_F32_MATRIX_ADD_2COL(selector1,selector2,4); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(selector1,selector2,1); + + // c[2:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(selector1,selector2,2); + + // c[3:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(selector1,selector2,3); + + // c[4:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(selector1,selector2,4); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_5x32: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + BF16_F32_MATRIX_MUL_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + BF16_F32_MATRIX_MUL_2COL(selector1,selector2,1); + + // c[2:0-15,16-31] + BF16_F32_MATRIX_MUL_2COL(selector1,selector2,2); + + // c[3:0-15,16-31] + BF16_F32_MATRIX_MUL_2COL(selector1,selector2,3); + + // c[4:0-15,16-31] + BF16_F32_MATRIX_MUL_2COL(selector1,selector2,4); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(selector1,selector2,1); + + // c[2:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(selector1,selector2,2); + + // c[3:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(selector1,selector2,3); + + // c[4:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(selector1,selector2,4); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_5x32: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + // c[0, 0-15] - MULRND_F32(c_float_0p0,0,0); + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); // c[0, 16-31] - MULRND_F32(c_float_0p1,0,1); + SWISH_F32_AVX512_DEF(c_float_0p1, selector1, al_in, r, r2, z, dn, ex_out); // c[1, 0-15] - MULRND_F32(c_float_1p0,1,0); + SWISH_F32_AVX512_DEF(c_float_1p0, selector1, al_in, r, r2, z, dn, ex_out); // c[1, 16-31] - MULRND_F32(c_float_1p1,1,1); + SWISH_F32_AVX512_DEF(c_float_1p1, selector1, al_in, r, r2, z, dn, ex_out); // c[2, 0-15] - MULRND_F32(c_float_2p0,2,0); + SWISH_F32_AVX512_DEF(c_float_2p0, selector1, al_in, r, r2, z, dn, ex_out); // c[2, 16-31] - MULRND_F32(c_float_2p1,2,1); + SWISH_F32_AVX512_DEF(c_float_2p1, selector1, al_in, r, r2, z, dn, ex_out); // c[3, 0-15] - MULRND_F32(c_float_3p0,3,0); + SWISH_F32_AVX512_DEF(c_float_3p0, selector1, al_in, r, r2, z, dn, ex_out); // c[3, 16-31] - MULRND_F32(c_float_3p1,3,1); + SWISH_F32_AVX512_DEF(c_float_3p1, selector1, al_in, r, r2, z, dn, ex_out); // c[4, 0-15] - MULRND_F32(c_float_4p0,4,0); + SWISH_F32_AVX512_DEF(c_float_4p0, selector1, al_in, r, r2, z, dn, ex_out); // c[4, 16-31] - MULRND_F32(c_float_4p1,4,1); + SWISH_F32_AVX512_DEF(c_float_4p1, selector1, al_in, r, r2, z, dn, ex_out); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_5x32_DISABLE: ; + if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_last_k == TRUE ) ) { // Generate a mask16 of all 1's. @@ -3799,7 +6232,10 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x32) &&POST_OPS_GELU_TANH_4x32, &&POST_OPS_GELU_ERF_4x32, &&POST_OPS_CLIP_4x32, - &&POST_OPS_DOWNSCALE_4x32 + &&POST_OPS_DOWNSCALE_4x32, + &&POST_OPS_MATRIX_ADD_4x32, + &&POST_OPS_SWISH_4x32, + &&POST_OPS_MATRIX_MUL_4x32 }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -3957,7 +6393,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x32) // c[3, 16-31] BF16_F32_BETA_OP( c_float_3p1, 0, 3, 1, selector1, selector2 ); } - else + else { // c[0,0-15] F32_F32_BETA_OP( c_float_0p0, 0, 0, 0, selector1, selector2 ); @@ -3992,12 +6428,21 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x32) if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) { - selector1 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_LOAD(selector1, bias_mask, 0); + BF16_F32_BIAS_LOAD(selector2, bias_mask, 1); + } + else + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + } // c[0,0-15] c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); @@ -4025,18 +6470,31 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x32) } else { - selector1 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 0 ) ); - selector2 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 1 ) ); - __m512 selector3 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 2 ) ); - __m512 selector4 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 3 ) ); + __m512 selector3; + __m512 selector4; + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + BF16_F32_BIAS_BCAST(selector2, bias_mask, 1); + BF16_F32_BIAS_BCAST(selector3, bias_mask, 2); + BF16_F32_BIAS_BCAST(selector4, bias_mask, 3); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 3 ) ); + } // c[0,0-15] c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); @@ -4224,34 +6682,283 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x32) POST_OPS_DOWNSCALE_4x32: { + __m512 selector3 = _mm512_setzero_ps(); + __m512 selector4 = _mm512_setzero_ps(); + + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + __m512 zero_point2 = _mm512_setzero_ps(); + __m512 zero_point3 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector2,zero_point1); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector1,zero_point0); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector1,zero_point0); + + // c[2, 16-31] + SCL_MULRND_F32(c_float_2p1,selector2,zero_point1); + + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector1,zero_point0); + + // c[3, 16-31] + SCL_MULRND_F32(c_float_3p1,selector2,zero_point1); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 3 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 2 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 3 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector2,zero_point1); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector3,zero_point2); + + // c[2, 16-31] + SCL_MULRND_F32(c_float_2p1,selector3,zero_point2); + + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector4,zero_point3); + + // c[3, 16-31] + SCL_MULRND_F32(c_float_3p1,selector4,zero_point3); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_4x32: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + BF16_F32_MATRIX_ADD_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + BF16_F32_MATRIX_ADD_2COL(selector1,selector2,1); + + // c[2:0-15,16-31] + BF16_F32_MATRIX_ADD_2COL(selector1,selector2,2); + + // c[3:0-15,16-31] + BF16_F32_MATRIX_ADD_2COL(selector1,selector2,3); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(selector1,selector2,1); + + // c[2:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(selector1,selector2,2); + + // c[3:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(selector1,selector2,3); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_4x32: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + BF16_F32_MATRIX_MUL_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + BF16_F32_MATRIX_MUL_2COL(selector1,selector2,1); + + // c[2:0-15,16-31] + BF16_F32_MATRIX_MUL_2COL(selector1,selector2,2); + + // c[3:0-15,16-31] + BF16_F32_MATRIX_MUL_2COL(selector1,selector2,3); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(selector1,selector2,1); + + // c[2:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(selector1,selector2,2); + + // c[3:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(selector1,selector2,3); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_4x32: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + // c[0, 0-15] - MULRND_F32(c_float_0p0,0,0); + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); // c[0, 16-31] - MULRND_F32(c_float_0p1,0,1); + SWISH_F32_AVX512_DEF(c_float_0p1, selector1, al_in, r, r2, z, dn, ex_out); // c[1, 0-15] - MULRND_F32(c_float_1p0,1,0); + SWISH_F32_AVX512_DEF(c_float_1p0, selector1, al_in, r, r2, z, dn, ex_out); // c[1, 16-31] - MULRND_F32(c_float_1p1,1,1); + SWISH_F32_AVX512_DEF(c_float_1p1, selector1, al_in, r, r2, z, dn, ex_out); // c[2, 0-15] - MULRND_F32(c_float_2p0,2,0); + SWISH_F32_AVX512_DEF(c_float_2p0, selector1, al_in, r, r2, z, dn, ex_out); // c[2, 16-31] - MULRND_F32(c_float_2p1,2,1); + SWISH_F32_AVX512_DEF(c_float_2p1, selector1, al_in, r, r2, z, dn, ex_out); // c[3, 0-15] - MULRND_F32(c_float_3p0,3,0); + SWISH_F32_AVX512_DEF(c_float_3p0, selector1, al_in, r, r2, z, dn, ex_out); // c[3, 16-31] - MULRND_F32(c_float_3p1,3,1); + SWISH_F32_AVX512_DEF(c_float_3p1, selector1, al_in, r, r2, z, dn, ex_out); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_4x32_DISABLE: ; + if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_last_k == TRUE ) ) { // Generate a mask16 of all 1's. @@ -4283,7 +6990,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x32) // c[3, 16-31] CVT_STORE_F32_BF16_MASK(c_float_3p1,3,1); - } + } else { @@ -4326,7 +7033,10 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x32) &&POST_OPS_GELU_TANH_3x32, &&POST_OPS_GELU_ERF_3x32, &&POST_OPS_CLIP_3x32, - &&POST_OPS_DOWNSCALE_3x32 + &&POST_OPS_DOWNSCALE_3x32, + &&POST_OPS_MATRIX_ADD_3x32, + &&POST_OPS_SWISH_3x32, + &&POST_OPS_MATRIX_MUL_3x32 }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -4455,7 +7165,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x32) // c[2, 16-31] BF16_F32_BETA_OP( c_float_2p1, 0, 2, 1, selector1, selector2 ); } - else + else { // c[0,0-15] F32_F32_BETA_OP( c_float_0p0, 0, 0, 0, selector1, selector2 ); @@ -4484,12 +7194,21 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x32) if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) { - selector1 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_LOAD(selector1, bias_mask, 0); + BF16_F32_BIAS_LOAD(selector2, bias_mask, 1); + } + else + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + } // c[0,0-15] c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); @@ -4511,15 +7230,26 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x32) } else { - selector1 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 0 ) ); - selector2 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 1 ) ); - __m512 selector3 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 2 ) ); + __m512 selector3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + BF16_F32_BIAS_BCAST(selector2, bias_mask, 1); + BF16_F32_BIAS_BCAST(selector3, bias_mask, 2); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 2 ) ); + } // c[0,0-15] c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); @@ -4621,78 +7351,289 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x32) } POST_OPS_GELU_ERF_3x32: { - __m512 x, r, x_erf; + __m512 x, r, x_erf; + + // c[0, 0-15] + GELU_ERF_F32_AVX512(c_float_0p0, r, x, x_erf) + + // c[0, 16-31] + GELU_ERF_F32_AVX512(c_float_0p1, r, x, x_erf) + + // c[1, 0-15] + GELU_ERF_F32_AVX512(c_float_1p0, r, x, x_erf) + + // c[1, 16-31] + GELU_ERF_F32_AVX512(c_float_1p1, r, x, x_erf) + + // c[2, 0-15] + GELU_ERF_F32_AVX512(c_float_2p0, r, x, x_erf) + + // c[2, 16-31] + GELU_ERF_F32_AVX512(c_float_2p1, r, x, x_erf) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_CLIP_3x32: + { + __m512 min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 ); + __m512 max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 ); + + // c[0, 0-15] + CLIP_F32_AVX512(c_float_0p0, min, max) + + // c[0, 16-31] + CLIP_F32_AVX512(c_float_0p1, min, max) + + // c[1, 0-15] + CLIP_F32_AVX512(c_float_1p0, min, max) + + // c[1, 16-31] + CLIP_F32_AVX512(c_float_1p1, min, max) + + // c[2, 0-15] + CLIP_F32_AVX512(c_float_2p0, min, max) + + // c[2, 16-31] + CLIP_F32_AVX512(c_float_2p1, min, max) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + +POST_OPS_DOWNSCALE_3x32: + { + __m512 selector3 = _mm512_setzero_ps(); + + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + __m512 zero_point2 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector2,zero_point1); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector1,zero_point0); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector1,zero_point0); + + // c[2, 16-31] + SCL_MULRND_F32(c_float_2p1,selector2,zero_point1); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 2 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 2 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector2,zero_point1); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector3,zero_point2); + + // c[2, 16-31] + SCL_MULRND_F32(c_float_2p1,selector3,zero_point2); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_3x32: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; - // c[0, 0-15] - GELU_ERF_F32_AVX512(c_float_0p0, r, x, x_erf) + // c[0:0-15,16-31] + BF16_F32_MATRIX_ADD_2COL(selector1,selector2,0); - // c[0, 16-31] - GELU_ERF_F32_AVX512(c_float_0p1, r, x, x_erf) + // c[1:0-15,16-31] + BF16_F32_MATRIX_ADD_2COL(selector1,selector2,1); - // c[1, 0-15] - GELU_ERF_F32_AVX512(c_float_1p0, r, x, x_erf) + // c[2:0-15,16-31] + BF16_F32_MATRIX_ADD_2COL(selector1,selector2,2); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; - // c[1, 16-31] - GELU_ERF_F32_AVX512(c_float_1p1, r, x, x_erf) + // c[0:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(selector1,selector2,0); - // c[2, 0-15] - GELU_ERF_F32_AVX512(c_float_2p0, r, x, x_erf) + // c[1:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(selector1,selector2,1); - // c[2, 16-31] - GELU_ERF_F32_AVX512(c_float_2p1, r, x, x_erf) + // c[2:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(selector1,selector2,2); + } POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } -POST_OPS_CLIP_3x32: +POST_OPS_MATRIX_MUL_3x32: { - __m512 min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 ); - __m512 max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 ); + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; - // c[0, 0-15] - CLIP_F32_AVX512(c_float_0p0, min, max) + // c[0:0-15,16-31] + BF16_F32_MATRIX_MUL_2COL(selector1,selector2,0); - // c[0, 16-31] - CLIP_F32_AVX512(c_float_0p1, min, max) + // c[1:0-15,16-31] + BF16_F32_MATRIX_MUL_2COL(selector1,selector2,1); - // c[1, 0-15] - CLIP_F32_AVX512(c_float_1p0, min, max) + // c[2:0-15,16-31] + BF16_F32_MATRIX_MUL_2COL(selector1,selector2,2); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; - // c[1, 16-31] - CLIP_F32_AVX512(c_float_1p1, min, max) + // c[0:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(selector1,selector2,0); - // c[2, 0-15] - CLIP_F32_AVX512(c_float_2p0, min, max) + // c[1:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(selector1,selector2,1); - // c[2, 16-31] - CLIP_F32_AVX512(c_float_2p1, min, max) + // c[2:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(selector1,selector2,2); + } POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - -POST_OPS_DOWNSCALE_3x32: +POST_OPS_SWISH_3x32: { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + // c[0, 0-15] - MULRND_F32(c_float_0p0,0,0); + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); // c[0, 16-31] - MULRND_F32(c_float_0p1,0,1); + SWISH_F32_AVX512_DEF(c_float_0p1, selector1, al_in, r, r2, z, dn, ex_out); // c[1, 0-15] - MULRND_F32(c_float_1p0,1,0); + SWISH_F32_AVX512_DEF(c_float_1p0, selector1, al_in, r, r2, z, dn, ex_out); // c[1, 16-31] - MULRND_F32(c_float_1p1,1,1); + SWISH_F32_AVX512_DEF(c_float_1p1, selector1, al_in, r, r2, z, dn, ex_out); // c[2, 0-15] - MULRND_F32(c_float_2p0,2,0); + SWISH_F32_AVX512_DEF(c_float_2p0, selector1, al_in, r, r2, z, dn, ex_out); // c[2, 16-31] - MULRND_F32(c_float_2p1,2,1); + SWISH_F32_AVX512_DEF(c_float_2p1, selector1, al_in, r, r2, z, dn, ex_out); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_3x32_DISABLE: ; + if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_last_k == TRUE ) ) { // Generate a mask16 of all 1's. @@ -4755,7 +7696,10 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x32) &&POST_OPS_GELU_TANH_2x32, &&POST_OPS_GELU_ERF_2x32, &&POST_OPS_CLIP_2x32, - &&POST_OPS_DOWNSCALE_2x32 + &&POST_OPS_DOWNSCALE_2x32, + &&POST_OPS_MATRIX_ADD_2x32, + &&POST_OPS_SWISH_2x32, + &&POST_OPS_MATRIX_MUL_2x32 }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -4855,7 +7799,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x32) // c[1, 16-31] BF16_F32_BETA_OP( c_float_1p1, 0, 1, 1, selector1, selector2 ); } - else + else { // c[0,0-15] F32_F32_BETA_OP( c_float_0p0, 0, 0, 0, selector1, selector2 ); @@ -4878,12 +7822,21 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x32) if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) { - selector1 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_LOAD(selector1, bias_mask, 0); + BF16_F32_BIAS_LOAD(selector2, bias_mask, 1); + } + else + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + } // c[0,0-15] c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); @@ -4899,12 +7852,21 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x32) } else { - selector1 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 0 ) ); - selector2 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 1 ) ); + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + BF16_F32_BIAS_BCAST(selector2, bias_mask, 1); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 1 ) ); + } // c[0,0-15] c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); @@ -5020,22 +7982,194 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x32) POST_OPS_DOWNSCALE_2x32: { + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector2,zero_point1); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector1,zero_point0); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 1 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 1 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector2,zero_point1); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_2x32: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + BF16_F32_MATRIX_ADD_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + BF16_F32_MATRIX_ADD_2COL(selector1,selector2,1); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(selector1,selector2,1); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_2x32: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + BF16_F32_MATRIX_MUL_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + BF16_F32_MATRIX_MUL_2COL(selector1,selector2,1); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(selector1,selector2,1); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_2x32: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + // c[0, 0-15] - MULRND_F32(c_float_0p0,0,0); + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); // c[0, 16-31] - MULRND_F32(c_float_0p1,0,1); + SWISH_F32_AVX512_DEF(c_float_0p1, selector1, al_in, r, r2, z, dn, ex_out); // c[1, 0-15] - MULRND_F32(c_float_1p0,1,0); + SWISH_F32_AVX512_DEF(c_float_1p0, selector1, al_in, r, r2, z, dn, ex_out); // c[1, 16-31] - MULRND_F32(c_float_1p1,1,1); + SWISH_F32_AVX512_DEF(c_float_1p1, selector1, al_in, r, r2, z, dn, ex_out); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_2x32_DISABLE: ; + if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_last_k == TRUE ) ) { // Generate a mask16 of all 1's. @@ -5085,7 +8219,10 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x32) &&POST_OPS_GELU_TANH_1x32, &&POST_OPS_GELU_ERF_1x32, &&POST_OPS_CLIP_1x32, - &&POST_OPS_DOWNSCALE_1x32 + &&POST_OPS_DOWNSCALE_1x32, + &&POST_OPS_MATRIX_ADD_1x32, + &&POST_OPS_SWISH_1x32, + &&POST_OPS_MATRIX_MUL_1x32 }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -5156,7 +8293,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x32) // c[0, 16-31] BF16_F32_BETA_OP( c_float_0p1, 0, 0, 1, selector1, selector2 ); } - else + else { // c[0,0-15] F32_F32_BETA_OP( c_float_0p0, 0, 0, 0, selector1, selector2 ); @@ -5173,12 +8310,21 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x32) if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) { - selector1 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_LOAD(selector1, bias_mask, 0); + BF16_F32_BIAS_LOAD(selector2, bias_mask, 1); + } + else + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + } // c[0,0-15] c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); @@ -5188,9 +8334,17 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x32) } else { - selector1 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 0 ) ); + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + } // c[0,0-15] c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); @@ -5208,78 +8362,219 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x32) // c[0,0-15] c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); - // c[0, 16-31] - c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + // c[0, 16-31] + c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_1x32: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_TANH_1x32: + { + __m512 dn, z, x, r2, r, x_tanh; + __m512i q; + + // c[0, 0-15] + GELU_TANH_F32_AVX512(c_float_0p0, r, r2, x, z, dn, x_tanh, q) + + // c[0, 16-31] + GELU_TANH_F32_AVX512(c_float_0p1, r, r2, x, z, dn, x_tanh, q) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_ERF_1x32: + { + __m512 x, r, x_erf; + + // c[0, 0-15] + GELU_ERF_F32_AVX512(c_float_0p0, r, x, x_erf) + + // c[0, 16-31] + GELU_ERF_F32_AVX512(c_float_0p1, r, x, x_erf) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_CLIP_1x32: + { + __m512 min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 ); + __m512 max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 ); + + // c[0, 0-15] + CLIP_F32_AVX512(c_float_0p0, min, max) + + // c[0, 16-31] + CLIP_F32_AVX512(c_float_0p1, min, max) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + +POST_OPS_DOWNSCALE_1x32: + { + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_SCALE_1x32: - { - selector1 = _mm512_setzero_ps(); - selector2 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector2,zero_point1); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + } - __mmask16 relu_cmp_mask; + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + } - // c[0, 0-15] - RELU_SCALE_OP_F32_AVX512(c_float_0p0) + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); - // c[0, 16-31] - RELU_SCALE_OP_F32_AVX512(c_float_0p1) + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector1,zero_point0); + } POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } -POST_OPS_GELU_TANH_1x32: +POST_OPS_MATRIX_ADD_1x32: { - __m512 dn, z, x, r2, r, x_tanh; - __m512i q; + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; - // c[0, 0-15] - GELU_TANH_F32_AVX512(c_float_0p0, r, r2, x, z, dn, x_tanh, q) + // c[0:0-15,16-31] + BF16_F32_MATRIX_ADD_2COL(selector1,selector2,0); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; - // c[0, 16-31] - GELU_TANH_F32_AVX512(c_float_0p1, r, r2, x, z, dn, x_tanh, q) + // c[0:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(selector1,selector2,0); + } POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } -POST_OPS_GELU_ERF_1x32: +POST_OPS_MATRIX_MUL_1x32: { - __m512 x, r, x_erf; + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; - // c[0, 0-15] - GELU_ERF_F32_AVX512(c_float_0p0, r, x, x_erf) + // c[0:0-15,16-31] + BF16_F32_MATRIX_MUL_2COL(selector1,selector2,0); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; - // c[0, 16-31] - GELU_ERF_F32_AVX512(c_float_0p1, r, x, x_erf) + // c[0:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(selector1,selector2,0); + } POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } -POST_OPS_CLIP_1x32: +POST_OPS_SWISH_1x32: { - __m512 min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 ); - __m512 max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 ); - - // c[0, 0-15] - CLIP_F32_AVX512(c_float_0p0, min, max) - - // c[0, 16-31] - CLIP_F32_AVX512(c_float_0p1, min, max) + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } + __m512 al_in, r, r2, z, dn; + __m512i ex_out; -POST_OPS_DOWNSCALE_1x32: - { // c[0, 0-15] - MULRND_F32(c_float_0p0,0,0); + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); // c[0, 16-31] - MULRND_F32(c_float_0p1,0,1); + SWISH_F32_AVX512_DEF(c_float_0p1, selector1, al_in, r, r2, z, dn, ex_out); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_1x32_DISABLE: ; + if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_last_k == TRUE ) ) { // Generate a mask16 of all 1's. @@ -5318,7 +8613,10 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x48) &&POST_OPS_GELU_TANH_5x48, &&POST_OPS_GELU_ERF_5x48, &&POST_OPS_CLIP_5x48, - &&POST_OPS_DOWNSCALE_5x48 + &&POST_OPS_DOWNSCALE_5x48, + &&POST_OPS_MATRIX_ADD_5x48, + &&POST_OPS_SWISH_5x48, + &&POST_OPS_MATRIX_MUL_5x48 }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -5600,15 +8898,25 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x48) if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) { - selector1 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ); - selector3 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_LOAD(selector1, bias_mask, 0); + BF16_F32_BIAS_LOAD(selector2, bias_mask, 1); + BF16_F32_BIAS_LOAD(selector3, bias_mask, 2); + } + else + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + } // c[0,0-15] c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); @@ -5657,21 +8965,35 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x48) } else { - selector1 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 0 ) ); - selector2 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 1 ) ); - selector3 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 2 ) ); - __m512 selector4 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 3 ) ); - __m512 selector5 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 4 ) ); + __m512 selector4; + __m512 selector5; + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + BF16_F32_BIAS_BCAST(selector2, bias_mask, 1); + BF16_F32_BIAS_BCAST(selector3, bias_mask, 2); + BF16_F32_BIAS_BCAST(selector4, bias_mask, 3); + BF16_F32_BIAS_BCAST(selector5, bias_mask, 4); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 3 ) ); + selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 4 ) ); + } // c[0,0-15] c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); @@ -5935,105 +9257,430 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x48) __m512 min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 ); __m512 max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 ); - // c[0, 0-15] - CLIP_F32_AVX512(c_float_0p0, min, max) + // c[0, 0-15] + CLIP_F32_AVX512(c_float_0p0, min, max) + + // c[0, 16-31] + CLIP_F32_AVX512(c_float_0p1, min, max) + + // c[0, 32-47] + CLIP_F32_AVX512(c_float_0p2, min, max) + + // c[1, 0-15] + CLIP_F32_AVX512(c_float_1p0, min, max) + + // c[1, 16-31] + CLIP_F32_AVX512(c_float_1p1, min, max) + + // c[1, 32-47] + CLIP_F32_AVX512(c_float_1p2, min, max) + + // c[2, 0-15] + CLIP_F32_AVX512(c_float_2p0, min, max) + + // c[2, 16-31] + CLIP_F32_AVX512(c_float_2p1, min, max) + + // c[2, 32-47] + CLIP_F32_AVX512(c_float_2p2, min, max) + + // c[3, 0-15] + CLIP_F32_AVX512(c_float_3p0, min, max) + + // c[3, 16-31] + CLIP_F32_AVX512(c_float_3p1, min, max) + + // c[3, 32-47] + CLIP_F32_AVX512(c_float_3p2, min, max) + + // c[4, 0-15] + CLIP_F32_AVX512(c_float_4p0, min, max) + + // c[4, 16-31] + CLIP_F32_AVX512(c_float_4p1, min, max) + + // c[4, 32-47] + CLIP_F32_AVX512(c_float_4p2, min, max) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + +POST_OPS_DOWNSCALE_5x48: + { + __m512 selector3 = _mm512_setzero_ps(); + __m512 selector4 = _mm512_setzero_ps(); + + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + __m512 zero_point2 = _mm512_setzero_ps(); + __m512 zero_point3 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector2,zero_point1); + + // c[0, 32-47] + SCL_MULRND_F32(c_float_0p2,selector3,zero_point2); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector1,zero_point0); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[1, 32-47] + SCL_MULRND_F32(c_float_1p2,selector3,zero_point2); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector1,zero_point0); + + // c[2, 16-31] + SCL_MULRND_F32(c_float_2p1,selector2,zero_point1); + + // c[2, 32-47] + SCL_MULRND_F32(c_float_2p2,selector3,zero_point2); + + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector1,zero_point0); + + // c[3, 16-31] + SCL_MULRND_F32(c_float_3p1,selector2,zero_point1); + + // c[3, 32-47] + SCL_MULRND_F32(c_float_3p2,selector3,zero_point2); + + // c[4, 0-15] + SCL_MULRND_F32(c_float_4p0,selector1,zero_point0); + + // c[4, 16-31] + SCL_MULRND_F32(c_float_4p1,selector2,zero_point1); + + // c[4, 32-47] + SCL_MULRND_F32(c_float_4p2,selector3,zero_point2); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 3 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 2 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 3 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector1,zero_point0); + + // c[0, 32-47] + SCL_MULRND_F32(c_float_0p2,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector2,zero_point1); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[1, 32-47] + SCL_MULRND_F32(c_float_1p2,selector2,zero_point1); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector3,zero_point2); + + // c[2, 16-31] + SCL_MULRND_F32(c_float_2p1,selector3,zero_point2); + + // c[2, 32-47] + SCL_MULRND_F32(c_float_2p2,selector3,zero_point2); + + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector4,zero_point3); + + // c[3, 16-31] + SCL_MULRND_F32(c_float_3p1,selector4,zero_point3); + + // c[3, 32-47] + SCL_MULRND_F32(c_float_3p2,selector4,zero_point3); + + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 4 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 4 ) ) ); + } + // c[4, 0-15] + SCL_MULRND_F32(c_float_4p0,selector1,zero_point0); + + // c[4, 16-31] + SCL_MULRND_F32(c_float_4p1,selector1,zero_point0); + + // c[4, 32-47] + SCL_MULRND_F32(c_float_4p2,selector1,zero_point0); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_5x48: + { + __m512 selector3; + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + BF16_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,0); + + // c[1:0-15,16-31,32-47] + BF16_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,1); + + // c[2:0-15,16-31,32-47] + BF16_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,2); + + // c[3:0-15,16-31,32-47] + BF16_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,3); + + // c[4:0-15,16-31,32-47] + BF16_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,4); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,0); - // c[0, 16-31] - CLIP_F32_AVX512(c_float_0p1, min, max) + // c[1:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,1); - // c[0, 32-47] - CLIP_F32_AVX512(c_float_0p2, min, max) + // c[2:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,2); - // c[1, 0-15] - CLIP_F32_AVX512(c_float_1p0, min, max) + // c[3:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,3); - // c[1, 16-31] - CLIP_F32_AVX512(c_float_1p1, min, max) + // c[4:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,4); + } - // c[1, 32-47] - CLIP_F32_AVX512(c_float_1p2, min, max) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_5x48: + { + __m512 selector3; + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; - // c[2, 0-15] - CLIP_F32_AVX512(c_float_2p0, min, max) + // c[0:0-15,16-31,32-47] + BF16_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,0); - // c[2, 16-31] - CLIP_F32_AVX512(c_float_2p1, min, max) + // c[1:0-15,16-31,32-47] + BF16_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,1); - // c[2, 32-47] - CLIP_F32_AVX512(c_float_2p2, min, max) + // c[2:0-15,16-31,32-47] + BF16_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,2); - // c[3, 0-15] - CLIP_F32_AVX512(c_float_3p0, min, max) + // c[3:0-15,16-31,32-47] + BF16_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,3); - // c[3, 16-31] - CLIP_F32_AVX512(c_float_3p1, min, max) + // c[4:0-15,16-31,32-47] + BF16_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,4); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; - // c[3, 32-47] - CLIP_F32_AVX512(c_float_3p2, min, max) + // c[0:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,0); - // c[4, 0-15] - CLIP_F32_AVX512(c_float_4p0, min, max) + // c[1:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,1); - // c[4, 16-31] - CLIP_F32_AVX512(c_float_4p1, min, max) + // c[2:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,2); + // c[3:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,3); - // c[4, 32-47] - CLIP_F32_AVX512(c_float_4p2, min, max) + // c[4:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,4); + } POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - -POST_OPS_DOWNSCALE_5x48: +POST_OPS_SWISH_5x48: { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + // c[0, 0-15] - MULRND_F32(c_float_0p0,0,0); + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); // c[0, 16-31] - MULRND_F32(c_float_0p1,0,1); + SWISH_F32_AVX512_DEF(c_float_0p1, selector1, al_in, r, r2, z, dn, ex_out); // c[0, 32-47] - MULRND_F32(c_float_0p2,0,2); + SWISH_F32_AVX512_DEF(c_float_0p2, selector1, al_in, r, r2, z, dn, ex_out); // c[1, 0-15] - MULRND_F32(c_float_1p0,1,0); + SWISH_F32_AVX512_DEF(c_float_1p0, selector1, al_in, r, r2, z, dn, ex_out); // c[1, 16-31] - MULRND_F32(c_float_1p1,1,1); + SWISH_F32_AVX512_DEF(c_float_1p1, selector1, al_in, r, r2, z, dn, ex_out); // c[1, 32-47] - MULRND_F32(c_float_1p2,1,2); + SWISH_F32_AVX512_DEF(c_float_1p2, selector1, al_in, r, r2, z, dn, ex_out); // c[2, 0-15] - MULRND_F32(c_float_2p0,2,0); + SWISH_F32_AVX512_DEF(c_float_2p0, selector1, al_in, r, r2, z, dn, ex_out); // c[2, 16-31] - MULRND_F32(c_float_2p1,2,1); + SWISH_F32_AVX512_DEF(c_float_2p1, selector1, al_in, r, r2, z, dn, ex_out); // c[2, 32-47] - MULRND_F32(c_float_2p2,2,2); + SWISH_F32_AVX512_DEF(c_float_2p2, selector1, al_in, r, r2, z, dn, ex_out); // c[3, 0-15] - MULRND_F32(c_float_3p0,3,0); + SWISH_F32_AVX512_DEF(c_float_3p0, selector1, al_in, r, r2, z, dn, ex_out); // c[3, 16-31] - MULRND_F32(c_float_3p1,3,1); + SWISH_F32_AVX512_DEF(c_float_3p1, selector1, al_in, r, r2, z, dn, ex_out); // c[3, 32-47] - MULRND_F32(c_float_3p2,3,2); + SWISH_F32_AVX512_DEF(c_float_3p2, selector1, al_in, r, r2, z, dn, ex_out); // c[4, 0-15] - MULRND_F32(c_float_4p0,4,0); + SWISH_F32_AVX512_DEF(c_float_4p0, selector1, al_in, r, r2, z, dn, ex_out); // c[4, 16-31] - MULRND_F32(c_float_4p1,4,1); + SWISH_F32_AVX512_DEF(c_float_4p1, selector1, al_in, r, r2, z, dn, ex_out); // c[4, 32-47] - MULRND_F32(c_float_4p2,4,2); + SWISH_F32_AVX512_DEF(c_float_4p2, selector1, al_in, r, r2, z, dn, ex_out); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_5x48_DISABLE: ; + if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_last_k == TRUE ) ) { @@ -6041,7 +9688,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x48) __m512i selector_a = _mm512_setzero_epi32(); __m512i selector_b = _mm512_set1_epi32( 10 ); __mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b ); - + // Store the results in downscaled type (bf16 instead of float). // c[0, 0-15] @@ -6152,7 +9799,10 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x48) &&POST_OPS_GELU_TANH_4x48, &&POST_OPS_GELU_ERF_4x48, &&POST_OPS_CLIP_4x48, - &&POST_OPS_DOWNSCALE_4x48 + &&POST_OPS_DOWNSCALE_4x48, + &&POST_OPS_MATRIX_ADD_4x48, + &&POST_OPS_SWISH_4x48, + &&POST_OPS_MATRIX_MUL_4x48 }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -6389,15 +10039,25 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x48) if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) { - selector1 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ); - selector3 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_LOAD(selector1, bias_mask, 0); + BF16_F32_BIAS_LOAD(selector2, bias_mask, 1); + BF16_F32_BIAS_LOAD(selector3, bias_mask, 2); + } + else + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + } // c[0,0-15] c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); @@ -6437,18 +10097,30 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x48) } else { - selector1 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 0 ) ); - selector2 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 1 ) ); - selector3 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 2 ) ); - __m512 selector4 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 3 ) ); + __m512 selector4; + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + BF16_F32_BIAS_BCAST(selector2, bias_mask, 1); + BF16_F32_BIAS_BCAST(selector3, bias_mask, 2); + BF16_F32_BIAS_BCAST(selector4, bias_mask, 3); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 3 ) ); + } // c[0,0-15] c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); @@ -6582,172 +10254,453 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x48) __m512 dn, z, x, r2, r, x_tanh; __m512i q; - // c[0, 0-15] - GELU_TANH_F32_AVX512(c_float_0p0, r, r2, x, z, dn, x_tanh, q) + // c[0, 0-15] + GELU_TANH_F32_AVX512(c_float_0p0, r, r2, x, z, dn, x_tanh, q) + + // c[0, 16-31] + GELU_TANH_F32_AVX512(c_float_0p1, r, r2, x, z, dn, x_tanh, q) + + // c[0, 32-47] + GELU_TANH_F32_AVX512(c_float_0p2, r, r2, x, z, dn, x_tanh, q) + + // c[1, 0-15] + GELU_TANH_F32_AVX512(c_float_1p0, r, r2, x, z, dn, x_tanh, q) + + // c[1, 16-31] + GELU_TANH_F32_AVX512(c_float_1p1, r, r2, x, z, dn, x_tanh, q) + + // c[1, 32-47] + GELU_TANH_F32_AVX512(c_float_1p2, r, r2, x, z, dn, x_tanh, q) + + // c[2, 0-15] + GELU_TANH_F32_AVX512(c_float_2p0, r, r2, x, z, dn, x_tanh, q) + + // c[2, 16-31] + GELU_TANH_F32_AVX512(c_float_2p1, r, r2, x, z, dn, x_tanh, q) + + // c[2, 32-47] + GELU_TANH_F32_AVX512(c_float_2p2, r, r2, x, z, dn, x_tanh, q) + + // c[3, 0-15] + GELU_TANH_F32_AVX512(c_float_3p0, r, r2, x, z, dn, x_tanh, q) + + // c[3, 16-31] + GELU_TANH_F32_AVX512(c_float_3p1, r, r2, x, z, dn, x_tanh, q) + + // c[3, 32-47] + GELU_TANH_F32_AVX512(c_float_3p2, r, r2, x, z, dn, x_tanh, q) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_ERF_4x48: + { + __m512 x, r, x_erf; + + // c[0, 0-15] + GELU_ERF_F32_AVX512(c_float_0p0, r, x, x_erf) + + // c[0, 16-31] + GELU_ERF_F32_AVX512(c_float_0p1, r, x, x_erf) + + // c[0, 32-47] + GELU_ERF_F32_AVX512(c_float_0p2, r, x, x_erf) + + // c[1, 0-15] + GELU_ERF_F32_AVX512(c_float_1p0, r, x, x_erf) + + // c[1, 16-31] + GELU_ERF_F32_AVX512(c_float_1p1, r, x, x_erf) + + // c[1, 32-47] + GELU_ERF_F32_AVX512(c_float_1p2, r, x, x_erf) + + // c[2, 0-15] + GELU_ERF_F32_AVX512(c_float_2p0, r, x, x_erf) + + // c[2, 16-31] + GELU_ERF_F32_AVX512(c_float_2p1, r, x, x_erf) + + // c[2, 32-47] + GELU_ERF_F32_AVX512(c_float_2p2, r, x, x_erf) + + // c[3, 0-15] + GELU_ERF_F32_AVX512(c_float_3p0, r, x, x_erf) + + // c[3, 16-31] + GELU_ERF_F32_AVX512(c_float_3p1, r, x, x_erf) + + // c[3, 32-47] + GELU_ERF_F32_AVX512(c_float_3p2, r, x, x_erf) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_CLIP_4x48: + { + __m512 min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 ); + __m512 max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 ); + + // c[0, 0-15] + CLIP_F32_AVX512(c_float_0p0, min, max) + + // c[0, 16-31] + CLIP_F32_AVX512(c_float_0p1, min, max) + + // c[0, 32-47] + CLIP_F32_AVX512(c_float_0p2, min, max) + + // c[1, 0-15] + CLIP_F32_AVX512(c_float_1p0, min, max) + + // c[1, 16-31] + CLIP_F32_AVX512(c_float_1p1, min, max) + + // c[1, 32-47] + CLIP_F32_AVX512(c_float_1p2, min, max) + + // c[2, 0-15] + CLIP_F32_AVX512(c_float_2p0, min, max) + + // c[2, 16-31] + CLIP_F32_AVX512(c_float_2p1, min, max) + + // c[2, 32-47] + CLIP_F32_AVX512(c_float_2p2, min, max) + + // c[3, 0-15] + CLIP_F32_AVX512(c_float_3p0, min, max) + + // c[3, 16-31] + CLIP_F32_AVX512(c_float_3p1, min, max) + + // c[3, 32-47] + CLIP_F32_AVX512(c_float_3p2, min, max) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_4x48: + { + __m512 selector3 = _mm512_setzero_ps(); + __m512 selector4 = _mm512_setzero_ps(); + + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + __m512 zero_point2 = _mm512_setzero_ps(); + __m512 zero_point3 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector2,zero_point1); + + // c[0, 32-47] + SCL_MULRND_F32(c_float_0p2,selector3,zero_point2); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector1,zero_point0); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[1, 32-47] + SCL_MULRND_F32(c_float_1p2,selector3,zero_point2); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector1,zero_point0); + + // c[2, 16-31] + SCL_MULRND_F32(c_float_2p1,selector2,zero_point1); + + // c[2, 32-47] + SCL_MULRND_F32(c_float_2p2,selector3,zero_point2); + + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector1,zero_point0); + + // c[3, 16-31] + SCL_MULRND_F32(c_float_3p1,selector2,zero_point1); + + // c[3, 32-47] + SCL_MULRND_F32(c_float_3p2,selector3,zero_point2); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 3 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 2 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 3 ) ) ); + } - // c[0, 16-31] - GELU_TANH_F32_AVX512(c_float_0p1, r, r2, x, z, dn, x_tanh, q) + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); - // c[0, 32-47] - GELU_TANH_F32_AVX512(c_float_0p2, r, r2, x, z, dn, x_tanh, q) + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector1,zero_point0); - // c[1, 0-15] - GELU_TANH_F32_AVX512(c_float_1p0, r, r2, x, z, dn, x_tanh, q) + // c[0, 32-47] + SCL_MULRND_F32(c_float_0p2,selector1,zero_point0); - // c[1, 16-31] - GELU_TANH_F32_AVX512(c_float_1p1, r, r2, x, z, dn, x_tanh, q) + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector2,zero_point1); - // c[1, 32-47] - GELU_TANH_F32_AVX512(c_float_1p2, r, r2, x, z, dn, x_tanh, q) + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); - // c[2, 0-15] - GELU_TANH_F32_AVX512(c_float_2p0, r, r2, x, z, dn, x_tanh, q) + // c[1, 32-47] + SCL_MULRND_F32(c_float_1p2,selector2,zero_point1); - // c[2, 16-31] - GELU_TANH_F32_AVX512(c_float_2p1, r, r2, x, z, dn, x_tanh, q) + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector3,zero_point2); - // c[2, 32-47] - GELU_TANH_F32_AVX512(c_float_2p2, r, r2, x, z, dn, x_tanh, q) + // c[2, 16-31] + SCL_MULRND_F32(c_float_2p1,selector3,zero_point2); - // c[3, 0-15] - GELU_TANH_F32_AVX512(c_float_3p0, r, r2, x, z, dn, x_tanh, q) + // c[2, 32-47] + SCL_MULRND_F32(c_float_2p2,selector3,zero_point2); - // c[3, 16-31] - GELU_TANH_F32_AVX512(c_float_3p1, r, r2, x, z, dn, x_tanh, q) + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector4,zero_point3); - // c[3, 32-47] - GELU_TANH_F32_AVX512(c_float_3p2, r, r2, x, z, dn, x_tanh, q) + // c[3, 16-31] + SCL_MULRND_F32(c_float_3p1,selector4,zero_point3); + + // c[3, 32-47] + SCL_MULRND_F32(c_float_3p2,selector4,zero_point3); + } POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } -POST_OPS_GELU_ERF_4x48: +POST_OPS_MATRIX_ADD_4x48: { - __m512 x, r, x_erf; - - // c[0, 0-15] - GELU_ERF_F32_AVX512(c_float_0p0, r, x, x_erf) - - // c[0, 16-31] - GELU_ERF_F32_AVX512(c_float_0p1, r, x, x_erf) - - // c[0, 32-47] - GELU_ERF_F32_AVX512(c_float_0p2, r, x, x_erf) - - // c[1, 0-15] - GELU_ERF_F32_AVX512(c_float_1p0, r, x, x_erf) + __m512 selector3; + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; - // c[1, 16-31] - GELU_ERF_F32_AVX512(c_float_1p1, r, x, x_erf) + // c[0:0-15,16-31,32-47] + BF16_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,0); - // c[1, 32-47] - GELU_ERF_F32_AVX512(c_float_1p2, r, x, x_erf) + // c[1:0-15,16-31,32-47] + BF16_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,1); - // c[2, 0-15] - GELU_ERF_F32_AVX512(c_float_2p0, r, x, x_erf) + // c[2:0-15,16-31,32-47] + BF16_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,2); - // c[2, 16-31] - GELU_ERF_F32_AVX512(c_float_2p1, r, x, x_erf) + // c[3:0-15,16-31,32-47] + BF16_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,3); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; - // c[2, 32-47] - GELU_ERF_F32_AVX512(c_float_2p2, r, x, x_erf) + // c[0:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,0); - // c[3, 0-15] - GELU_ERF_F32_AVX512(c_float_3p0, r, x, x_erf) + // c[1:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,1); - // c[3, 16-31] - GELU_ERF_F32_AVX512(c_float_3p1, r, x, x_erf) + // c[2:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,2); - // c[3, 32-47] - GELU_ERF_F32_AVX512(c_float_3p2, r, x, x_erf) + // c[3:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,3); + } POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } -POST_OPS_CLIP_4x48: +POST_OPS_MATRIX_MUL_4x48: { - __m512 min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 ); - __m512 max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 ); - - // c[0, 0-15] - CLIP_F32_AVX512(c_float_0p0, min, max) - - // c[0, 16-31] - CLIP_F32_AVX512(c_float_0p1, min, max) - - // c[0, 32-47] - CLIP_F32_AVX512(c_float_0p2, min, max) - - // c[1, 0-15] - CLIP_F32_AVX512(c_float_1p0, min, max) + __m512 selector3; + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; - // c[1, 16-31] - CLIP_F32_AVX512(c_float_1p1, min, max) + // c[0:0-15,16-31,32-47] + BF16_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,0); - // c[1, 32-47] - CLIP_F32_AVX512(c_float_1p2, min, max) + // c[1:0-15,16-31,32-47] + BF16_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,1); - // c[2, 0-15] - CLIP_F32_AVX512(c_float_2p0, min, max) + // c[2:0-15,16-31,32-47] + BF16_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,2); - // c[2, 16-31] - CLIP_F32_AVX512(c_float_2p1, min, max) + // c[3:0-15,16-31,32-47] + BF16_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,3); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; - // c[2, 32-47] - CLIP_F32_AVX512(c_float_2p2, min, max) + // c[0:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,0); - // c[3, 0-15] - CLIP_F32_AVX512(c_float_3p0, min, max) + // c[1:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,1); - // c[3, 16-31] - CLIP_F32_AVX512(c_float_3p1, min, max) + // c[2:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,2); - // c[3, 32-47] - CLIP_F32_AVX512(c_float_3p2, min, max) + // c[3:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,3); + } POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - -POST_OPS_DOWNSCALE_4x48: +POST_OPS_SWISH_4x48: { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + // c[0, 0-15] - MULRND_F32(c_float_0p0,0,0); + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); // c[0, 16-31] - MULRND_F32(c_float_0p1,0,1); + SWISH_F32_AVX512_DEF(c_float_0p1, selector1, al_in, r, r2, z, dn, ex_out); // c[0, 32-47] - MULRND_F32(c_float_0p2,0,2); + SWISH_F32_AVX512_DEF(c_float_0p2, selector1, al_in, r, r2, z, dn, ex_out); // c[1, 0-15] - MULRND_F32(c_float_1p0,1,0); + SWISH_F32_AVX512_DEF(c_float_1p0, selector1, al_in, r, r2, z, dn, ex_out); // c[1, 16-31] - MULRND_F32(c_float_1p1,1,1); + SWISH_F32_AVX512_DEF(c_float_1p1, selector1, al_in, r, r2, z, dn, ex_out); // c[1, 32-47] - MULRND_F32(c_float_1p2,1,2); + SWISH_F32_AVX512_DEF(c_float_1p2, selector1, al_in, r, r2, z, dn, ex_out); // c[2, 0-15] - MULRND_F32(c_float_2p0,2,0); + SWISH_F32_AVX512_DEF(c_float_2p0, selector1, al_in, r, r2, z, dn, ex_out); // c[2, 16-31] - MULRND_F32(c_float_2p1,2,1); + SWISH_F32_AVX512_DEF(c_float_2p1, selector1, al_in, r, r2, z, dn, ex_out); // c[2, 32-47] - MULRND_F32(c_float_2p2,2,2); + SWISH_F32_AVX512_DEF(c_float_2p2, selector1, al_in, r, r2, z, dn, ex_out); // c[3, 0-15] - MULRND_F32(c_float_3p0,3,0); + SWISH_F32_AVX512_DEF(c_float_3p0, selector1, al_in, r, r2, z, dn, ex_out); // c[3, 16-31] - MULRND_F32(c_float_3p1,3,1); + SWISH_F32_AVX512_DEF(c_float_3p1, selector1, al_in, r, r2, z, dn, ex_out); // c[3, 32-47] - MULRND_F32(c_float_3p2,3,2); + SWISH_F32_AVX512_DEF(c_float_3p2, selector1, al_in, r, r2, z, dn, ex_out); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_4x48_DISABLE: ; + if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_last_k == TRUE ) ) { @@ -6848,7 +10801,10 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x48) &&POST_OPS_GELU_TANH_3x48, &&POST_OPS_GELU_ERF_3x48, &&POST_OPS_CLIP_3x48, - &&POST_OPS_DOWNSCALE_3x48 + &&POST_OPS_DOWNSCALE_3x48, + &&POST_OPS_MATRIX_ADD_3x48, + &&POST_OPS_SWISH_3x48, + &&POST_OPS_MATRIX_MUL_3x48 }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -7040,15 +10996,25 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x48) if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) { - selector1 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ); - selector3 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_LOAD(selector1, bias_mask, 0); + BF16_F32_BIAS_LOAD(selector2, bias_mask, 1); + BF16_F32_BIAS_LOAD(selector3, bias_mask, 2); + } + else + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + } // c[0,0-15] c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); @@ -7079,15 +11045,25 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x48) } else { - selector1 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 0 ) ); - selector2 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 1 ) ); - selector3 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 2 ) ); + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + BF16_F32_BIAS_BCAST(selector2, bias_mask, 1); + BF16_F32_BIAS_BCAST(selector3, bias_mask, 2); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 2 ) ); + } // c[0,0-15] c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); @@ -7261,69 +11237,307 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x48) __m512 min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 ); __m512 max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 ); - // c[0, 0-15] - CLIP_F32_AVX512(c_float_0p0, min, max) + // c[0, 0-15] + CLIP_F32_AVX512(c_float_0p0, min, max) + + // c[0, 16-31] + CLIP_F32_AVX512(c_float_0p1, min, max) + + // c[0, 32-47] + CLIP_F32_AVX512(c_float_0p2, min, max) + + // c[1, 0-15] + CLIP_F32_AVX512(c_float_1p0, min, max) + + // c[1, 16-31] + CLIP_F32_AVX512(c_float_1p1, min, max) + + // c[1, 32-47] + CLIP_F32_AVX512(c_float_1p2, min, max) + + // c[2, 0-15] + CLIP_F32_AVX512(c_float_2p0, min, max) + + // c[2, 16-31] + CLIP_F32_AVX512(c_float_2p1, min, max) + + // c[2, 32-47] + CLIP_F32_AVX512(c_float_2p2, min, max) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + +POST_OPS_DOWNSCALE_3x48: + { + __m512 selector3 = _mm512_setzero_ps(); + + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + __m512 zero_point2 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector2,zero_point1); + + // c[0, 32-47] + SCL_MULRND_F32(c_float_0p2,selector3,zero_point2); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector1,zero_point0); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[1, 32-47] + SCL_MULRND_F32(c_float_1p2,selector3,zero_point2); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector1,zero_point0); + + // c[2, 16-31] + SCL_MULRND_F32(c_float_2p1,selector2,zero_point1); + + // c[2, 32-47] + SCL_MULRND_F32(c_float_2p2,selector3,zero_point2); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 2 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 2 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector1,zero_point0); + + // c[0, 32-47] + SCL_MULRND_F32(c_float_0p2,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector2,zero_point1); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[1, 32-47] + SCL_MULRND_F32(c_float_1p2,selector2,zero_point1); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector3,zero_point2); + + // c[2, 16-31] + SCL_MULRND_F32(c_float_2p1,selector3,zero_point2); + + // c[2, 32-47] + SCL_MULRND_F32(c_float_2p2,selector3,zero_point2); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_3x48: + { + __m512 selector3; + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + BF16_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,0); + + // c[1:0-15,16-31,32-47] + BF16_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,1); + + // c[2:0-15,16-31,32-47] + BF16_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,2); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,0); - // c[0, 16-31] - CLIP_F32_AVX512(c_float_0p1, min, max) + // c[1:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,1); - // c[0, 32-47] - CLIP_F32_AVX512(c_float_0p2, min, max) + // c[2:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,2); + } - // c[1, 0-15] - CLIP_F32_AVX512(c_float_1p0, min, max) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_3x48: + { + __m512 selector3; + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; - // c[1, 16-31] - CLIP_F32_AVX512(c_float_1p1, min, max) + // c[0:0-15,16-31,32-47] + BF16_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,0); - // c[1, 32-47] - CLIP_F32_AVX512(c_float_1p2, min, max) + // c[1:0-15,16-31,32-47] + BF16_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,1); - // c[2, 0-15] - CLIP_F32_AVX512(c_float_2p0, min, max) + // c[2:0-15,16-31,32-47] + BF16_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,2); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; - // c[2, 16-31] - CLIP_F32_AVX512(c_float_2p1, min, max) + // c[0:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,0); - // c[2, 32-47] - CLIP_F32_AVX512(c_float_2p2, min, max) + // c[1:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,1); - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } + // c[2:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,2); + } -POST_OPS_DOWNSCALE_3x48: + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_3x48: { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + // c[0, 0-15] - MULRND_F32(c_float_0p0,0,0); + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); // c[0, 16-31] - MULRND_F32(c_float_0p1,0,1); + SWISH_F32_AVX512_DEF(c_float_0p1, selector1, al_in, r, r2, z, dn, ex_out); // c[0, 32-47] - MULRND_F32(c_float_0p2,0,2); + SWISH_F32_AVX512_DEF(c_float_0p2, selector1, al_in, r, r2, z, dn, ex_out); // c[1, 0-15] - MULRND_F32(c_float_1p0,1,0); + SWISH_F32_AVX512_DEF(c_float_1p0, selector1, al_in, r, r2, z, dn, ex_out); // c[1, 16-31] - MULRND_F32(c_float_1p1,1,1); + SWISH_F32_AVX512_DEF(c_float_1p1, selector1, al_in, r, r2, z, dn, ex_out); // c[1, 32-47] - MULRND_F32(c_float_1p2,1,2); + SWISH_F32_AVX512_DEF(c_float_1p2, selector1, al_in, r, r2, z, dn, ex_out); // c[2, 0-15] - MULRND_F32(c_float_2p0,2,0); + SWISH_F32_AVX512_DEF(c_float_2p0, selector1, al_in, r, r2, z, dn, ex_out); // c[2, 16-31] - MULRND_F32(c_float_2p1,2,1); + SWISH_F32_AVX512_DEF(c_float_2p1, selector1, al_in, r, r2, z, dn, ex_out); // c[2, 32-47] - MULRND_F32(c_float_2p2,2,2); + SWISH_F32_AVX512_DEF(c_float_2p2, selector1, al_in, r, r2, z, dn, ex_out); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_3x48_DISABLE: ; + if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_last_k == TRUE ) ) { @@ -7406,7 +11620,10 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x48) &&POST_OPS_GELU_TANH_2x48, &&POST_OPS_GELU_ERF_2x48, &&POST_OPS_CLIP_2x48, - &&POST_OPS_DOWNSCALE_2x48 + &&POST_OPS_DOWNSCALE_2x48, + &&POST_OPS_MATRIX_ADD_2x48, + &&POST_OPS_SWISH_2x48, + &&POST_OPS_MATRIX_MUL_2x48 }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -7553,15 +11770,25 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x48) if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) { - selector1 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ); - selector3 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_LOAD(selector1, bias_mask, 0); + BF16_F32_BIAS_LOAD(selector2, bias_mask, 1); + BF16_F32_BIAS_LOAD(selector3, bias_mask, 2); + } + else + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + } // c[0,0-15] c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); @@ -7583,12 +11810,21 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x48) } else { - selector1 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 0 ) ); - selector2 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 1 ) ); + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + BF16_F32_BIAS_BCAST(selector2, bias_mask, 1); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 1 ) ); + } // c[0,0-15] c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); @@ -7737,31 +11973,231 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x48) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_2x48: { + __m512 selector3 = _mm512_setzero_ps(); + + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + __m512 zero_point2 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector2,zero_point1); + + // c[0, 32-47] + SCL_MULRND_F32(c_float_0p2,selector3,zero_point2); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector1,zero_point0); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[1, 32-47] + SCL_MULRND_F32(c_float_1p2,selector3,zero_point2); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 1 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 1 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector1,zero_point0); + + // c[0, 32-47] + SCL_MULRND_F32(c_float_0p2,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector2,zero_point1); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[1, 32-47] + SCL_MULRND_F32(c_float_1p2,selector2,zero_point1); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_2x48: + { + __m512 selector3; + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + BF16_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,0); + + // c[1:0-15,16-31,32-47] + BF16_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,1); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,0); + + // c[1:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,1); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_2x48: + { + __m512 selector3; + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + BF16_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,0); + + // c[1:0-15,16-31,32-47] + BF16_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,1); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,0); + + // c[1:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,1); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_2x48: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + // c[0, 0-15] - MULRND_F32(c_float_0p0,0,0); + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); // c[0, 16-31] - MULRND_F32(c_float_0p1,0,1); + SWISH_F32_AVX512_DEF(c_float_0p1, selector1, al_in, r, r2, z, dn, ex_out); // c[0, 32-47] - MULRND_F32(c_float_0p2,0,2); + SWISH_F32_AVX512_DEF(c_float_0p2, selector1, al_in, r, r2, z, dn, ex_out); // c[1, 0-15] - MULRND_F32(c_float_1p0,1,0); + SWISH_F32_AVX512_DEF(c_float_1p0, selector1, al_in, r, r2, z, dn, ex_out); // c[1, 16-31] - MULRND_F32(c_float_1p1,1,1); + SWISH_F32_AVX512_DEF(c_float_1p1, selector1, al_in, r, r2, z, dn, ex_out); // c[1, 32-47] - MULRND_F32(c_float_1p2,1,2); + SWISH_F32_AVX512_DEF(c_float_1p2, selector1, al_in, r, r2, z, dn, ex_out); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_2x48_DISABLE: ; + if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_last_k == TRUE ) ) { @@ -7826,7 +12262,10 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x48) &&POST_OPS_GELU_TANH_1x48, &&POST_OPS_GELU_ERF_1x48, &&POST_OPS_CLIP_1x48, - &&POST_OPS_DOWNSCALE_1x48 + &&POST_OPS_DOWNSCALE_1x48, + &&POST_OPS_MATRIX_ADD_1x48, + &&POST_OPS_SWISH_1x48, + &&POST_OPS_MATRIX_MUL_1x48 }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -7928,15 +12367,25 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x48) if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) { - selector1 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ); - selector3 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_LOAD(selector1, bias_mask, 0); + BF16_F32_BIAS_LOAD(selector2, bias_mask, 1); + BF16_F32_BIAS_LOAD(selector3, bias_mask, 2); + } + else + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + } // c[0,0-15] c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); @@ -7949,9 +12398,17 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x48) } else { - selector1 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 0 ) ); + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + } // c[0,0-15] c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); @@ -8034,7 +12491,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x48) { __m512 min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 ); __m512 max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 ); - + // c[0, 0-15] CLIP_F32_AVX512(c_float_0p0, min, max) @@ -8049,19 +12506,183 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x48) POST_OPS_DOWNSCALE_1x48: { + __m512 selector3 = _mm512_setzero_ps(); + + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + __m512 zero_point2 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector2,zero_point1); + + // c[0, 32-47] + SCL_MULRND_F32(c_float_0p2,selector3,zero_point2); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector1,zero_point0); + + // c[0, 32-47] + SCL_MULRND_F32(c_float_0p2,selector1,zero_point0); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_1x48: + { + __m512 selector3; + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + BF16_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,0); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,0); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_1x48: + { + __m512 selector3; + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + BF16_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,0); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,0); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_1x48: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + // c[0, 0-15] - MULRND_F32(c_float_0p0,0,0); + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); // c[0, 16-31] - MULRND_F32(c_float_0p1,0,1); + SWISH_F32_AVX512_DEF(c_float_0p1, selector1, al_in, r, r2, z, dn, ex_out); // c[0, 32-47] - MULRND_F32(c_float_0p2,0,2); + SWISH_F32_AVX512_DEF(c_float_0p2, selector1, al_in, r, r2, z, dn, ex_out); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_1x48_DISABLE: ; + if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_last_k == TRUE ) ) { diff --git a/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_mn_fringe_bf16s4f32of32_amd512vnni.c b/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_mn_fringe_bf16s4f32of32_amd512vnni.c new file mode 100644 index 0000000000..91e6c32bbd --- /dev/null +++ b/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_mn_fringe_bf16s4f32of32_amd512vnni.c @@ -0,0 +1,14085 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include +#include "blis.h" + +#ifdef BLIS_ADDON_LPGEMM + +#include "lpgemm_f32_kern_macros.h" +#include "../int4_utils_avx512.h" + +#ifndef LPGEMM_BF16_JIT +// 5xlt16 bf16 fringe kernel +LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, int8_t, float, bf16s4f32of32_5xlt16) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_5xLT16_DISABLE, + &&POST_OPS_BIAS_5xLT16, + &&POST_OPS_RELU_5xLT16, + &&POST_OPS_RELU_SCALE_5xLT16, + &&POST_OPS_GELU_TANH_5xLT16, + &&POST_OPS_GELU_ERF_5xLT16, + &&POST_OPS_CLIP_5xLT16, + &&POST_OPS_DOWNSCALE_5xLT16, + &&POST_OPS_MATRIX_ADD_5xLT16, + &&POST_OPS_SWISH_5xLT16, + &&POST_OPS_MATRIX_MUL_5xLT16, + }; + + dim_t pre_op_off = post_ops_attr.pre_op_off; + + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int16_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + + __m128i b0_s4; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + __m256i shift_idx_32; + MULTISHIFT_32BIT_8_INT4_IDX_32ELEM(shift_idx_32); + __m256i sign_comp_32 = _mm256_set1_epi8( 0x08 ); + + bool signed_upscale = true; + + /* regs to store intermediate int8 values */ + __m256i b0_s8; + + /* Regs to store F32 scale values */ + __m512 scale0, scale1; + /* Reg to store masks to interleave scale factor */ + __m512i mask_scale1, mask_scale2; + + mask_scale1 = _mm512_set_epi32( 0x17, 0x07, 0x16, 0x06, 0x15, 0x05, 0x14, + 0x04, 0x13, 0x03, 0x12, 0x02, 0x11, 0x01, + 0x10, 0x00 ); + + mask_scale2 = _mm512_set_epi32( 0x1F, 0x0F, 0x1E, 0x0E, 0x1D, 0x0D, 0x1C, + 0x0C, 0x1B, 0x0B, 0x1A, 0x0A, 0x19, 0x09, + 0x18, 0x08); + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + + __m512 c_float_2p0 = _mm512_setzero_ps(); + + __m512 c_float_3p0 = _mm512_setzero_ps(); + + __m512 c_float_4p0 = _mm512_setzero_ps(); + + if( post_ops_attr.pre_op_scale_factor_len > 1 ) + { + + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + + // load and interleave scale factor vectors + scale0 = _mm512_maskz_loadu_ps( load_mask, + (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off); + + scale1 = _mm512_permutex2var_ps( scale0, mask_scale2, scale0 ); + scale0 = _mm512_permutex2var_ps( scale0, mask_scale1, scale0 ); + + } + else + { + scale0 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale1 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + } + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0_s4 = _mm_loadu_si128( (__m128i const *)( b + ( ( rs_b * kr ) / 2 ) ) ); + + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_32, \ + sign_comp_32, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_8( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_8( b0_s8, 0, scale0 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + + // Broadcast a[2,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + + // Broadcast a[3,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-15] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + + // Broadcast a[4,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[4,0-15] = a[4,kr:kr+2]*b[kr:kr+2,0-15] + c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0_s4 = _mm_loadu_si128( (__m128i const *)( b + ( ( rs_b * k_full_pieces ) / 2 ) ) ); + + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_32, \ + sign_comp_32, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_8( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_8( b0_s8, 0, scale0 ) ); + + // Broadcast a[0,kr:kr+2]. + a_kfringe_buf = *( a + (rs_a * 0) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 1) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + + // Broadcast a[2,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 2) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + + // Broadcast a[3,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 3) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-15] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + + // Broadcast a[4,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 4) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[4,0-15] = a[4,kr:kr+2]*b[kr:kr+2,0-15] + c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + if ( alpha != 1 ) + { + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + + c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); + + c_float_3p0 = _mm512_mul_ps( selector1, c_float_3p0 ); + + c_float_4p0 = _mm512_mul_ps( selector1, c_float_4p0 ); + } + + // Scale C by beta. + if ( beta != 0 ) + { + if ( ( post_ops_attr.buf_downscale != NULL ) && + ( post_ops_attr.is_first_k == TRUE ) ) + { + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + + // c[0,0-15] + BF16_F32_BETA_OP_NLT16F_MASK( load_mask, c_float_0p0, 0, 0, \ + selector1, selector2 ); + + // c[1,0-15] + BF16_F32_BETA_OP_NLT16F_MASK( load_mask, c_float_1p0, 1, 0, \ + selector1, selector2 ); + + // c[2,0-15] + BF16_F32_BETA_OP_NLT16F_MASK( load_mask, c_float_2p0, 2, 0, \ + selector1, selector2 ); + + // c[3,0-15] + BF16_F32_BETA_OP_NLT16F_MASK( load_mask, c_float_3p0, 3, 0, \ + selector1, selector2 ); + + // c[4,0-15] + BF16_F32_BETA_OP_NLT16F_MASK( load_mask, c_float_4p0, 4, 0, \ + selector1, selector2 ); + } + else + { + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + + // c[0,0-15] + F32_F32_BETA_OP_NLT16F_MASK(c, load_mask, c_float_0p0, 0, 0, 0, \ + selector1, selector2); + + // c[1,0-15] + F32_F32_BETA_OP_NLT16F_MASK(c, load_mask, c_float_1p0, 0, 1, 0, \ + selector1, selector2); + + // c[2,0-15] + F32_F32_BETA_OP_NLT16F_MASK(c, load_mask, c_float_2p0, 0, 2, 0, \ + selector1, selector2); + + // c[3,0-15] + F32_F32_BETA_OP_NLT16F_MASK(c, load_mask, c_float_3p0, 0, 3, 0, \ + selector1, selector2); + + // c[4,0-15] + F32_F32_BETA_OP_NLT16F_MASK(c, load_mask, c_float_4p0, 0, 4, 0, \ + selector1, selector2); + } + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_5xLT16: + { + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + if ( post_ops_attr.c_stor_type == BF16 ) + { + BF16_F32_BIAS_LOAD(selector1, bias_mask, 0); + } + else + { + selector1 = + _mm512_maskz_loadu_ps + ( + bias_mask, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + } + else + { + __m512 selector3; + __m512 selector4; + __m512 selector5; + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + BF16_F32_BIAS_BCAST(selector2, bias_mask, 1); + BF16_F32_BIAS_BCAST(selector3, bias_mask, 2); + BF16_F32_BIAS_BCAST(selector4, bias_mask, 3); + BF16_F32_BIAS_BCAST(selector5, bias_mask, 4); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 3 ) ); + selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 4 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector4, c_float_3p0 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector5, c_float_4p0 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_5xLT16: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + // c[3,0-15] + c_float_3p0 = _mm512_max_ps( selector1, c_float_3p0 ); + + // c[4,0-15] + c_float_4p0 = _mm512_max_ps( selector1, c_float_4p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_5xLT16: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + // c[3, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_3p0) + + // c[4, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_4p0) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_TANH_5xLT16: + { + __m512 dn, z, x, r2, r, x_tanh; + __m512i q; + + // c[0, 0-15] + GELU_TANH_F32_AVX512(c_float_0p0, r, r2, x, z, dn, x_tanh, q) + + // c[1, 0-15] + GELU_TANH_F32_AVX512(c_float_1p0, r, r2, x, z, dn, x_tanh, q) + + // c[2, 0-15] + GELU_TANH_F32_AVX512(c_float_2p0, r, r2, x, z, dn, x_tanh, q) + + // c[3, 0-15] + GELU_TANH_F32_AVX512(c_float_3p0, r, r2, x, z, dn, x_tanh, q) + + // c[4, 0-15] + GELU_TANH_F32_AVX512(c_float_4p0, r, r2, x, z, dn, x_tanh, q) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_ERF_5xLT16: + { + __m512 x, r, x_erf; + + // c[0, 0-15] + GELU_ERF_F32_AVX512(c_float_0p0, r, x, x_erf) + + // c[1, 0-15] + GELU_ERF_F32_AVX512(c_float_1p0, r, x, x_erf) + + // c[2, 0-15] + GELU_ERF_F32_AVX512(c_float_2p0, r, x, x_erf) + + // c[3, 0-15] + GELU_ERF_F32_AVX512(c_float_3p0, r, x, x_erf) + + // c[4, 0-15] + GELU_ERF_F32_AVX512(c_float_4p0, r, x, x_erf) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_CLIP_5xLT16: + { + __m512 min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 ); + __m512 max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 ); + + // c[0, 0-15] + CLIP_F32_AVX512(c_float_0p0, min, max) + + // c[1, 0-15] + CLIP_F32_AVX512(c_float_1p0, min, max) + + // c[2, 0-15] + CLIP_F32_AVX512(c_float_2p0, min, max) + + // c[3, 0-15] + CLIP_F32_AVX512(c_float_3p0, min, max) + + // c[4, 0-15] + CLIP_F32_AVX512(c_float_4p0, min, max) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + +POST_OPS_DOWNSCALE_5xLT16: + { + __m512 selector3 = _mm512_setzero_ps(); + __m512 selector4 = _mm512_setzero_ps(); + __m512 selector5 = _mm512_setzero_ps(); + + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + __m512 zero_point2 = _mm512_setzero_ps(); + __m512 zero_point3 = _mm512_setzero_ps(); + __m512 zero_point4 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + // Also the same value is loaded to different registers so that + // branching can be reduced and same code/register can be used + // irrespective of whether scalar or vector op. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point4 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = _mm512_maskz_loadu_ps( zp_mask, + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector1,zero_point0); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector1,zero_point0); + + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector1,zero_point0); + + // c[4, 0-15] + SCL_MULRND_F32(c_float_4p0,selector1,zero_point0); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 3 ) ); + selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 4 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 2 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 3 ) ) ); + zero_point4 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 4 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector2,zero_point1); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector3,zero_point2); + + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector4,zero_point3); + + // c[4, 0-15] + SCL_MULRND_F32(c_float_4p0,selector5,zero_point4); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_5xLT16: + { + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15] + BF16_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + BF16_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,1); + + // c[2:0-15] + BF16_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,2); + + // c[3:0-15] + BF16_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,3); + + // c[4:0-15] + BF16_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,4); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15] + F32_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + F32_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,1); + + // c[2:0-15] + F32_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,2); + + // c[3:0-15] + F32_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,3); + + // c[4:0-15] + F32_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,4); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_5xLT16: + { + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15] + BF16_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + BF16_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,1); + + // c[2:0-15] + BF16_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,2); + + // c[3:0-15] + BF16_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,3); + + // c[4:0-15] + BF16_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,4); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15] + F32_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + F32_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,1); + + // c[2:0-15] + F32_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,2); + + // c[3:0-15] + F32_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,3); + + // c[4:0-15] + F32_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,4); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_5xLT16: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 0-15] + SWISH_F32_AVX512_DEF(c_float_1p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 0-15] + SWISH_F32_AVX512_DEF(c_float_2p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[3, 0-15] + SWISH_F32_AVX512_DEF(c_float_3p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[4, 0-15] + SWISH_F32_AVX512_DEF(c_float_4p0, selector1, al_in, r, r2, z, dn, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_5xLT16_DISABLE: + ; + + if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_last_k == TRUE ) ) + { + __mmask16 mask_all1 = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + + // Store the results in downscaled type (int8 instead of int32). + // c[0,0-15] + CVT_STORE_F32_BF16_MASK(c_float_0p0,0,0); + + // c[1,0-15] + CVT_STORE_F32_BF16_MASK(c_float_1p0,1,0); + + // c[2,0-15] + CVT_STORE_F32_BF16_MASK(c_float_2p0,2,0); + + // c[3,0-15] + CVT_STORE_F32_BF16_MASK(c_float_3p0,3,0); + + // c[4,0-15] + CVT_STORE_F32_BF16_MASK(c_float_4p0,4,0); + } + + else + { + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + + // Store the results. + // c[0,0-15] + _mm512_mask_storeu_ps( c + ( rs_c * 0 ), load_mask, c_float_0p0 ); + + // c[1,0-15] + _mm512_mask_storeu_ps( c + ( rs_c * 1 ), load_mask, c_float_1p0 ); + + // c[2,0-15] + _mm512_mask_storeu_ps( c + ( rs_c * 2 ), load_mask, c_float_2p0 ); + + // c[3,0-15] + _mm512_mask_storeu_ps( c + ( rs_c * 3 ), load_mask, c_float_3p0 ); + + // c[4,0-15] + _mm512_mask_storeu_ps( c + ( rs_c * 4 ), load_mask, c_float_4p0 ); + } +} + +// 4xlt16 bf16 fringe kernel +LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, int8_t, float, bf16s4f32of32_4xlt16) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_4xLT16_DISABLE, + &&POST_OPS_BIAS_4xLT16, + &&POST_OPS_RELU_4xLT16, + &&POST_OPS_RELU_SCALE_4xLT16, + &&POST_OPS_GELU_TANH_4xLT16, + &&POST_OPS_GELU_ERF_4xLT16, + &&POST_OPS_CLIP_4xLT16, + &&POST_OPS_DOWNSCALE_4xLT16, + &&POST_OPS_MATRIX_ADD_4xLT16, + &&POST_OPS_SWISH_4xLT16, + &&POST_OPS_MATRIX_MUL_4xLT16 + }; + + dim_t pre_op_off = post_ops_attr.pre_op_off; + + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int16_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + + __m128i b0_s4; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + __m256i shift_idx_32; + MULTISHIFT_32BIT_8_INT4_IDX_32ELEM(shift_idx_32); + __m256i sign_comp_32 = _mm256_set1_epi8( 0x08 ); + + bool signed_upscale = true; + + /* regs to store intermediate int8 values */ + __m256i b0_s8; + + /* Regs to store F32 scale values */ + __m512 scale0, scale1; + /* Reg to store masks to interleave scale factor */ + __m512i mask_scale1, mask_scale2; + + mask_scale1 = _mm512_set_epi32( 0x17, 0x07, 0x16, 0x06, 0x15, 0x05, 0x14, + 0x04, 0x13, 0x03, 0x12, 0x02, 0x11, 0x01, + 0x10, 0x00 ); + + mask_scale2 = _mm512_set_epi32( 0x1F, 0x0F, 0x1E, 0x0E, 0x1D, 0x0D, 0x1C, + 0x0C, 0x1B, 0x0B, 0x1A, 0x0A, 0x19, 0x09, + 0x18, 0x08); + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + + __m512 c_float_2p0 = _mm512_setzero_ps(); + + __m512 c_float_3p0 = _mm512_setzero_ps(); + + if( post_ops_attr.pre_op_scale_factor_len > 1 ) + { + + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + + // load and interleave scale factor vectors + scale0 = _mm512_maskz_loadu_ps( load_mask, + (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off); + + scale1 = _mm512_permutex2var_ps( scale0, mask_scale2, scale0 ); + scale0 = _mm512_permutex2var_ps( scale0, mask_scale1, scale0 ); + + } + else + { + scale0 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale1 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + } + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0_s4 = _mm_loadu_si128( (__m128i const *)( b + ( ( rs_b * kr ) / 2 ) ) ); + + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_32, \ + sign_comp_32, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_8( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_8( b0_s8, 0, scale0 ) ); + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + + // Broadcast a[2,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + + // Broadcast a[3,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-15] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + } + // Handle k remainder. + + if ( k_partial_pieces > 0 ) + { + b0_s4 = _mm_loadu_si128( (__m128i const *)( b + ( ( rs_b * k_full_pieces ) / 2 ) ) ); + + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_32, \ + sign_comp_32, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_8( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_8( b0_s8, 0, scale0 ) ); + + // Broadcast a[0,kr:kr+2]. + a_kfringe_buf = *( a + (rs_a * 0) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 1) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + + // Broadcast a[2,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 2) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + + // Broadcast a[3,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 3) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-15] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + if ( alpha != 1 ) + { + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + + c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); + + c_float_3p0 = _mm512_mul_ps( selector1, c_float_3p0 ); + } + + // Scale C by beta. + if ( beta != 0 ) + { + if ( ( post_ops_attr.buf_downscale != NULL ) && + ( post_ops_attr.is_first_k == TRUE ) ) + { + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + + // c[0,0-15] + BF16_F32_BETA_OP_NLT16F_MASK( load_mask, c_float_0p0, 0, 0, \ + selector1, selector2 ); + + // c[1,0-15] + BF16_F32_BETA_OP_NLT16F_MASK( load_mask, c_float_1p0, 1, 0, \ + selector1, selector2 ); + + // c[2,0-15] + BF16_F32_BETA_OP_NLT16F_MASK( load_mask, c_float_2p0, 2, 0, \ + selector1, selector2 ); + + // c[3,0-15] + BF16_F32_BETA_OP_NLT16F_MASK( load_mask, c_float_3p0, 3, 0, \ + selector1, selector2 ); + } + else + { + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + + // c[0,0-15] + F32_F32_BETA_OP_NLT16F_MASK(c, load_mask, c_float_0p0, 0, 0, 0, \ + selector1, selector2); + + // c[1,0-15] + F32_F32_BETA_OP_NLT16F_MASK(c, load_mask, c_float_1p0, 0, 1, 0, \ + selector1, selector2); + + // c[2,0-15] + F32_F32_BETA_OP_NLT16F_MASK(c, load_mask, c_float_2p0, 0, 2, 0, \ + selector1, selector2); + + // c[3,0-15] + F32_F32_BETA_OP_NLT16F_MASK(c, load_mask, c_float_3p0, 0, 3, 0, \ + selector1, selector2); + } + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_4xLT16: + { + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + if ( post_ops_attr.c_stor_type == BF16 ) + { + BF16_F32_BIAS_LOAD(selector1, bias_mask, 0); + } + else + { + selector1 = + _mm512_maskz_loadu_ps + ( + bias_mask, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + } + else + { + __m512 selector3; + __m512 selector4; + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + BF16_F32_BIAS_BCAST(selector2, bias_mask, 1); + BF16_F32_BIAS_BCAST(selector3, bias_mask, 2); + BF16_F32_BIAS_BCAST(selector4, bias_mask, 3); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 3 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector4, c_float_3p0 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_4xLT16: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + // c[3,0-15] + c_float_3p0 = _mm512_max_ps( selector1, c_float_3p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_4xLT16: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + // c[3, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_3p0) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_TANH_4xLT16: + { + __m512 dn, z, x, r2, r, x_tanh; + __m512i q; + + // c[0, 0-15] + GELU_TANH_F32_AVX512(c_float_0p0, r, r2, x, z, dn, x_tanh, q) + + // c[1, 0-15] + GELU_TANH_F32_AVX512(c_float_1p0, r, r2, x, z, dn, x_tanh, q) + + // c[2, 0-15] + GELU_TANH_F32_AVX512(c_float_2p0, r, r2, x, z, dn, x_tanh, q) + + // c[3, 0-15] + GELU_TANH_F32_AVX512(c_float_3p0, r, r2, x, z, dn, x_tanh, q) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_ERF_4xLT16: + { + __m512 x, r, x_erf; + + // c[0, 0-15] + GELU_ERF_F32_AVX512(c_float_0p0, r, x, x_erf) + + // c[1, 0-15] + GELU_ERF_F32_AVX512(c_float_1p0, r, x, x_erf) + + // c[2, 0-15] + GELU_ERF_F32_AVX512(c_float_2p0, r, x, x_erf) + + // c[3, 0-15] + GELU_ERF_F32_AVX512(c_float_3p0, r, x, x_erf) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_CLIP_4xLT16: + { + __m512 min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 ); + __m512 max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 ); + + // c[0, 0-15] + CLIP_F32_AVX512(c_float_0p0, min, max) + + // c[1, 0-15] + CLIP_F32_AVX512(c_float_1p0, min, max) + + // c[2, 0-15] + CLIP_F32_AVX512(c_float_2p0, min, max) + + // c[3, 0-15] + CLIP_F32_AVX512(c_float_3p0, min, max) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_4xLT16: + { + __m512 selector3 = _mm512_setzero_ps(); + __m512 selector4 = _mm512_setzero_ps(); + + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + __m512 zero_point2 = _mm512_setzero_ps(); + __m512 zero_point3 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + // Also the same value is loaded to different registers so that + // branching can be reduced and same code/register can be used + // irrespective of whether scalar or vector op. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = _mm512_maskz_loadu_ps( zp_mask, + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector1,zero_point0); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector1,zero_point0); + + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector1,zero_point0); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 3 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 2 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 3 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector2,zero_point1); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector3,zero_point2); + + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector4,zero_point3); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_4xLT16: + { + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15] + BF16_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + BF16_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,1); + + // c[2:0-15] + BF16_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,2); + + // c[3:0-15] + BF16_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,3); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15] + F32_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + F32_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,1); + + // c[2:0-15] + F32_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,2); + + // c[3:0-15] + F32_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,3); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_4xLT16: + { + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15] + BF16_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + BF16_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,1); + + // c[2:0-15] + BF16_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,2); + + // c[3:0-15] + BF16_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,3); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15] + F32_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + F32_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,1); + + // c[2:0-15] + F32_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,2); + + // c[3:0-15] + F32_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,3); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_4xLT16: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 0-15] + SWISH_F32_AVX512_DEF(c_float_1p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 0-15] + SWISH_F32_AVX512_DEF(c_float_2p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[3, 0-15] + SWISH_F32_AVX512_DEF(c_float_3p0, selector1, al_in, r, r2, z, dn, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_4xLT16_DISABLE: + ; + + if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_last_k == TRUE ) ) + { + __mmask16 mask_all1 = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + + // Store the results in downscaled type (int8 instead of int32). + // c[0,0-15] + CVT_STORE_F32_BF16_MASK(c_float_0p0,0,0); + + // c[1,0-15] + CVT_STORE_F32_BF16_MASK(c_float_1p0,1,0); + + // c[2,0-15] + CVT_STORE_F32_BF16_MASK(c_float_2p0,2,0); + + // c[3,0-15] + CVT_STORE_F32_BF16_MASK(c_float_3p0,3,0); + } + else + { + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + + // Store the results. + // c[0,0-15] + _mm512_mask_storeu_ps( c + ( rs_c * 0 ), load_mask, c_float_0p0 ); + + // c[1,0-15] + _mm512_mask_storeu_ps( c + ( rs_c * 1 ), load_mask, c_float_1p0 ); + + // c[2,0-15] + _mm512_mask_storeu_ps( c + ( rs_c * 2 ), load_mask, c_float_2p0 ); + + // c[3,0-15] + _mm512_mask_storeu_ps( c + ( rs_c * 3 ), load_mask, c_float_3p0 ); + } + +} + +// 3xlt16 bf16 fringe kernel +LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, int8_t, float, bf16s4f32of32_3xlt16) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_3xLT16_DISABLE, + &&POST_OPS_BIAS_3xLT16, + &&POST_OPS_RELU_3xLT16, + &&POST_OPS_RELU_SCALE_3xLT16, + &&POST_OPS_GELU_TANH_3xLT16, + &&POST_OPS_GELU_ERF_3xLT16, + &&POST_OPS_CLIP_3xLT16, + &&POST_OPS_DOWNSCALE_3xLT16, + &&POST_OPS_MATRIX_ADD_3xLT16, + &&POST_OPS_SWISH_3xLT16, + &&POST_OPS_MATRIX_MUL_3xLT16 + }; + + dim_t pre_op_off = post_ops_attr.pre_op_off; + + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int16_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + + __m128i b0_s4; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + __m256i shift_idx_32; + MULTISHIFT_32BIT_8_INT4_IDX_32ELEM(shift_idx_32); + __m256i sign_comp_32 = _mm256_set1_epi8( 0x08 ); + + bool signed_upscale = true; + + /* regs to store intermediate int8 values */ + __m256i b0_s8; + + /* Regs to store F32 scale values */ + __m512 scale0, scale1; + /* Reg to store masks to interleave scale factor */ + __m512i mask_scale1, mask_scale2; + + mask_scale1 = _mm512_set_epi32( 0x17, 0x07, 0x16, 0x06, 0x15, 0x05, 0x14, + 0x04, 0x13, 0x03, 0x12, 0x02, 0x11, 0x01, + 0x10, 0x00 ); + + mask_scale2 = _mm512_set_epi32( 0x1F, 0x0F, 0x1E, 0x0E, 0x1D, 0x0D, 0x1C, + 0x0C, 0x1B, 0x0B, 0x1A, 0x0A, 0x19, 0x09, + 0x18, 0x08); + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + + __m512 c_float_2p0 = _mm512_setzero_ps(); + + if( post_ops_attr.pre_op_scale_factor_len > 1 ) + { + + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + + // load and interleave scale factor vectors + scale0 = _mm512_maskz_loadu_ps( load_mask, + (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off); + + scale1 = _mm512_permutex2var_ps( scale0, mask_scale2, scale0 ); + scale0 = _mm512_permutex2var_ps( scale0, mask_scale1, scale0 ); + + } + else + { + scale0 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale1 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + } + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0_s4 = _mm_loadu_si128( (__m128i const *)( b + ( ( rs_b * kr ) / 2 ) ) ); + + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_32, \ + sign_comp_32, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_8( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_8( b0_s8, 0, scale0 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + + // Broadcast a[2,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0_s4 = _mm_loadu_si128( (__m128i const *)( b + ( ( rs_b * k_full_pieces ) / 2 ) ) ); + + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_32, \ + sign_comp_32, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_8( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_8( b0_s8, 0, scale0 ) ); + + // Broadcast a[0,kr:kr+2]. + a_kfringe_buf = *( a + (rs_a * 0) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 1) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + + // Broadcast a[2,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 2) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + if ( alpha != 1 ) + { + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + + c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); + } + + // Scale C by beta. + if ( beta != 0 ) + { + if ( ( post_ops_attr.buf_downscale != NULL ) && + ( post_ops_attr.is_first_k == TRUE ) ) + { + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + + // c[0,0-15] + BF16_F32_BETA_OP_NLT16F_MASK( load_mask, c_float_0p0, 0, 0, \ + selector1, selector2 ); + + // c[1,0-15] + BF16_F32_BETA_OP_NLT16F_MASK( load_mask, c_float_1p0, 1, 0, \ + selector1, selector2 ); + + // c[2,0-15] + BF16_F32_BETA_OP_NLT16F_MASK( load_mask, c_float_2p0, 2, 0, \ + selector1, selector2 ); + } + else + { + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + + // c[0,0-15] + F32_F32_BETA_OP_NLT16F_MASK(c, load_mask, c_float_0p0, 0, 0, 0, \ + selector1, selector2); + + // c[1,0-15] + F32_F32_BETA_OP_NLT16F_MASK(c, load_mask, c_float_1p0, 0, 1, 0, \ + selector1, selector2); + + // c[2,0-15] + F32_F32_BETA_OP_NLT16F_MASK(c, load_mask, c_float_2p0, 0, 2, 0, \ + selector1, selector2); + } + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_3xLT16: + { + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + if ( post_ops_attr.c_stor_type == BF16 ) + { + BF16_F32_BIAS_LOAD(selector1, bias_mask, 0); + } + else + { + selector1 = + _mm512_maskz_loadu_ps + ( + bias_mask, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + } + else + { + __m512 selector3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + BF16_F32_BIAS_BCAST(selector2, bias_mask, 1); + BF16_F32_BIAS_BCAST(selector3, bias_mask, 2); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 2 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_3xLT16: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_3xLT16: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_TANH_3xLT16: + { + __m512 dn, z, x, r2, r, x_tanh; + __m512i q; + + // c[0, 0-15] + GELU_TANH_F32_AVX512(c_float_0p0, r, r2, x, z, dn, x_tanh, q) + + // c[1, 0-15] + GELU_TANH_F32_AVX512(c_float_1p0, r, r2, x, z, dn, x_tanh, q) + + // c[2, 0-15] + GELU_TANH_F32_AVX512(c_float_2p0, r, r2, x, z, dn, x_tanh, q) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_ERF_3xLT16: + { + __m512 x, r, x_erf; + + // c[0, 0-15] + GELU_ERF_F32_AVX512(c_float_0p0, r, x, x_erf) + + // c[1, 0-15] + GELU_ERF_F32_AVX512(c_float_1p0, r, x, x_erf) + + // c[2, 0-15] + GELU_ERF_F32_AVX512(c_float_2p0, r, x, x_erf) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_CLIP_3xLT16: + { + __m512 min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 ); + __m512 max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 ); + + // c[0, 0-15] + CLIP_F32_AVX512(c_float_0p0, min, max) + + // c[1, 0-15] + CLIP_F32_AVX512(c_float_1p0, min, max) + + // c[2, 0-15] + CLIP_F32_AVX512(c_float_2p0, min, max) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + +POST_OPS_DOWNSCALE_3xLT16: + { + __m512 selector3 = _mm512_setzero_ps(); + + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + __m512 zero_point2 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + // Also the same value is loaded to different registers so that + // branching can be reduced and same code/register can be used + // irrespective of whether scalar or vector op. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = _mm512_maskz_loadu_ps( zp_mask, + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector1,zero_point0); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector1,zero_point0); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 2 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 2 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector2,zero_point1); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector3,zero_point2); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_3xLT16: + { + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15] + BF16_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + BF16_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,1); + + // c[2:0-15] + BF16_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,2); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15] + F32_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + F32_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,1); + + // c[2:0-15] + F32_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,2); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_3xLT16: + { + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15] + BF16_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + BF16_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,1); + + // c[2:0-15] + BF16_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,2); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15] + F32_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + F32_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,1); + + // c[2:0-15] + F32_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,2); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_3xLT16: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 0-15] + SWISH_F32_AVX512_DEF(c_float_1p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 0-15] + SWISH_F32_AVX512_DEF(c_float_2p0, selector1, al_in, r, r2, z, dn, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_3xLT16_DISABLE: + ; + + if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_last_k == TRUE ) ) + { + __mmask16 mask_all1 = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + + // Store the results in downscaled type (int8 instead of int32). + // c[0,0-15] + CVT_STORE_F32_BF16_MASK(c_float_0p0,0,0); + + // c[1,0-15] + CVT_STORE_F32_BF16_MASK(c_float_1p0,1,0); + + // c[2,0-15] + CVT_STORE_F32_BF16_MASK(c_float_2p0,2,0); + } + + else + { + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + + // Store the results. + // c[0,0-15] + _mm512_mask_storeu_ps( c + ( rs_c * 0 ), load_mask, c_float_0p0 ); + + // c[1,0-15] + _mm512_mask_storeu_ps( c + ( rs_c * 1 ), load_mask, c_float_1p0 ); + + // c[2,0-15] + _mm512_mask_storeu_ps( c + ( rs_c * 2 ), load_mask, c_float_2p0 ); + } + +} + +// 2xlt16 bf16 fringe kernel +LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, int8_t, float, bf16s4f32of32_2xlt16) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_2xLT16_DISABLE, + &&POST_OPS_BIAS_2xLT16, + &&POST_OPS_RELU_2xLT16, + &&POST_OPS_RELU_SCALE_2xLT16, + &&POST_OPS_GELU_TANH_2xLT16, + &&POST_OPS_GELU_ERF_2xLT16, + &&POST_OPS_CLIP_2xLT16, + &&POST_OPS_DOWNSCALE_2xLT16, + &&POST_OPS_MATRIX_ADD_2xLT16, + &&POST_OPS_SWISH_2xLT16, + &&POST_OPS_MATRIX_MUL_2xLT16 + }; + + dim_t pre_op_off = post_ops_attr.pre_op_off; + + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int16_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + + __m128i b0_s4; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + __m256i shift_idx_32; + MULTISHIFT_32BIT_8_INT4_IDX_32ELEM(shift_idx_32); + __m256i sign_comp_32 = _mm256_set1_epi8( 0x08 ); + + bool signed_upscale = true; + + /* regs to store intermediate int8 values */ + __m256i b0_s8; + + /* Regs to store F32 scale values */ + __m512 scale0, scale1; + /* Reg to store masks to interleave scale factor */ + __m512i mask_scale1, mask_scale2; + + mask_scale1 = _mm512_set_epi32( 0x17, 0x07, 0x16, 0x06, 0x15, 0x05, 0x14, + 0x04, 0x13, 0x03, 0x12, 0x02, 0x11, 0x01, + 0x10, 0x00 ); + + mask_scale2 = _mm512_set_epi32( 0x1F, 0x0F, 0x1E, 0x0E, 0x1D, 0x0D, 0x1C, + 0x0C, 0x1B, 0x0B, 0x1A, 0x0A, 0x19, 0x09, + 0x18, 0x08); + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + + if( post_ops_attr.pre_op_scale_factor_len > 1 ) + { + + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + + // load and interleave scale factor vectors + scale0 = _mm512_maskz_loadu_ps( load_mask, + (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off); + + scale1 = _mm512_permutex2var_ps( scale0, mask_scale2, scale0 ); + scale0 = _mm512_permutex2var_ps( scale0, mask_scale1, scale0 ); + + } + else + { + scale0 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale1 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + } + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0_s4 = _mm_loadu_si128( (__m128i const *)( b + ( ( rs_b * kr ) / 2 ) ) ); + + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_32, \ + sign_comp_32, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_8( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_8( b0_s8, 0, scale0 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0_s4 = _mm_loadu_si128( (__m128i const *)( b + ( ( rs_b * k_full_pieces ) / 2 ) ) ); + + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_32, \ + sign_comp_32, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_8( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_8( b0_s8, 0, scale0 ) ); + + // Broadcast a[0,kr:kr+2]. + a_kfringe_buf = *( a + (rs_a * 0) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 1) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + if ( alpha != 1 ) + { + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + } + + // Scale C by beta. + if ( beta != 0 ) + { + if ( ( post_ops_attr.buf_downscale != NULL ) && + ( post_ops_attr.is_first_k == TRUE ) ) + { + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + + // c[0,0-15] + BF16_F32_BETA_OP_NLT16F_MASK( load_mask, c_float_0p0, 0, 0, \ + selector1, selector2 ); + + // c[1,0-15] + BF16_F32_BETA_OP_NLT16F_MASK( load_mask, c_float_1p0, 1, 0, \ + selector1, selector2 ); + } + else + { + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + + // c[0,0-15] + F32_F32_BETA_OP_NLT16F_MASK(c, load_mask, c_float_0p0, 0, 0, 0, \ + selector1, selector2); + + // c[1,0-15] + F32_F32_BETA_OP_NLT16F_MASK(c, load_mask, c_float_1p0, 0, 1, 0, \ + selector1, selector2); + } + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_2xLT16: + { + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + if ( post_ops_attr.c_stor_type == BF16 ) + { + BF16_F32_BIAS_LOAD(selector1, bias_mask, 0); + } + else + { + selector1 = + _mm512_maskz_loadu_ps + ( + bias_mask, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + } + else + { + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + BF16_F32_BIAS_BCAST(selector2, bias_mask, 1); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 1 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_2xLT16: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_2xLT16: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_TANH_2xLT16: + { + __m512 dn, z, x, r2, r, x_tanh; + __m512i q; + + // c[0, 0-15] + GELU_TANH_F32_AVX512(c_float_0p0, r, r2, x, z, dn, x_tanh, q) + + // c[1, 0-15] + GELU_TANH_F32_AVX512(c_float_1p0, r, r2, x, z, dn, x_tanh, q) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_ERF_2xLT16: + { + __m512 x, r, x_erf; + + // c[0, 0-15] + GELU_ERF_F32_AVX512(c_float_0p0, r, x, x_erf) + + // c[1, 0-15] + GELU_ERF_F32_AVX512(c_float_1p0, r, x, x_erf) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_CLIP_2xLT16: + { + __m512 min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 ); + __m512 max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 ); + + // c[0, 0-15] + CLIP_F32_AVX512(c_float_0p0, min, max) + + // c[1, 0-15] + CLIP_F32_AVX512(c_float_1p0, min, max) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + +POST_OPS_DOWNSCALE_2xLT16: + { + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + // Also the same value is loaded to different registers so that + // branching can be reduced and same code/register can be used + // irrespective of whether scalar or vector op. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = _mm512_maskz_loadu_ps( zp_mask, + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector1,zero_point0); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 1 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 1 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector2,zero_point1); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_2xLT16: + { + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15] + BF16_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + BF16_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,1); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15] + F32_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + F32_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,1); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_2xLT16: + { + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15] + BF16_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + BF16_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,1); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15] + F32_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + F32_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,1); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_2xLT16: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 0-15] + SWISH_F32_AVX512_DEF(c_float_1p0, selector1, al_in, r, r2, z, dn, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_2xLT16_DISABLE: + ; + + if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_last_k == TRUE ) ) + { + __mmask16 mask_all1 = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + + // Store the results in downscaled type (int8 instead of int32). + // c[0,0-15] + CVT_STORE_F32_BF16_MASK(c_float_0p0,0,0); + + // c[1,0-15] + CVT_STORE_F32_BF16_MASK(c_float_1p0,1,0); + } + + else + { + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + + // Store the results. + // c[0,0-15] + _mm512_mask_storeu_ps( c + ( rs_c * 0 ), load_mask, c_float_0p0 ); + + // c[1,0-15] + _mm512_mask_storeu_ps( c + ( rs_c * 1 ), load_mask, c_float_1p0 ); + } + +} + +// 1xlt16 bf16 fringe kernel +LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, int8_t, float, bf16s4f32of32_1xlt16) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_1xLT16_DISABLE, + &&POST_OPS_BIAS_1xLT16, + &&POST_OPS_RELU_1xLT16, + &&POST_OPS_RELU_SCALE_1xLT16, + &&POST_OPS_GELU_TANH_1xLT16, + &&POST_OPS_GELU_ERF_1xLT16, + &&POST_OPS_CLIP_1xLT16, + &&POST_OPS_DOWNSCALE_1xLT16, + &&POST_OPS_MATRIX_ADD_1xLT16, + &&POST_OPS_SWISH_1xLT16, + &&POST_OPS_MATRIX_MUL_1xLT16 + }; + + dim_t pre_op_off = post_ops_attr.pre_op_off; + + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int16_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + + __m128i b0_s4; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + __m256i shift_idx_32; + MULTISHIFT_32BIT_8_INT4_IDX_32ELEM(shift_idx_32); + __m256i sign_comp_32 = _mm256_set1_epi8( 0x08 ); + + bool signed_upscale = true; + + /* regs to store intermediate int8 values */ + __m256i b0_s8; + + /* Regs to store F32 scale values */ + __m512 scale0, scale1; + /* Reg to store masks to interleave scale factor */ + __m512i mask_scale1, mask_scale2; + + mask_scale1 = _mm512_set_epi32( 0x17, 0x07, 0x16, 0x06, 0x15, 0x05, 0x14, + 0x04, 0x13, 0x03, 0x12, 0x02, 0x11, 0x01, + 0x10, 0x00 ); + + mask_scale2 = _mm512_set_epi32( 0x1F, 0x0F, 0x1E, 0x0E, 0x1D, 0x0D, 0x1C, + 0x0C, 0x1B, 0x0B, 0x1A, 0x0A, 0x19, 0x09, + 0x18, 0x08); + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + + if( post_ops_attr.pre_op_scale_factor_len > 1 ) + { + + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + + // load and interleave scale factor vectors + scale0 = _mm512_maskz_loadu_ps( load_mask, + (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off); + + scale1 = _mm512_permutex2var_ps( scale0, mask_scale2, scale0 ); + scale0 = _mm512_permutex2var_ps( scale0, mask_scale1, scale0 ); + + } + else + { + scale0 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale1 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + } + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0_s4 = _mm_loadu_si128( (__m128i const *)( b + ( ( rs_b * kr ) / 2 ) ) ); + + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_32, \ + sign_comp_32, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_8( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_8( b0_s8, 0, scale0 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0_s4 = _mm_loadu_si128( (__m128i const *)( b + ( ( rs_b * k_full_pieces ) / 2 ) ) ); + + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_32, \ + sign_comp_32, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_8( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_8( b0_s8, 0, scale0 ) ); + + // Broadcast a[0,kr:kr+2]. + a_kfringe_buf = *( a + (rs_a * 0) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + if ( alpha != 1 ) + { + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + } + + // Scale C by beta. + if ( beta != 0 ) + { + if ( ( post_ops_attr.buf_downscale != NULL ) && + ( post_ops_attr.is_first_k == TRUE ) ) + { + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + + // c[0,0-15] + BF16_F32_BETA_OP_NLT16F_MASK( load_mask, c_float_0p0, 0, 0, \ + selector1, selector2 ); + } + else + { + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + + // c[0,0-15] + F32_F32_BETA_OP_NLT16F_MASK(c, load_mask, c_float_0p0, 0, 0, 0, \ + selector1, selector2); + } + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_1xLT16: + { + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + if ( post_ops_attr.c_stor_type == BF16 ) + { + BF16_F32_BIAS_LOAD(selector1, bias_mask, 0); + } + else + { + selector1 = + _mm512_maskz_loadu_ps + ( + bias_mask, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + } + else + { + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_1xLT16: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_1xLT16: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_TANH_1xLT16: + { + __m512 dn, z, x, r2, r, x_tanh; + __m512i q; + + // c[0, 0-15] + GELU_TANH_F32_AVX512(c_float_0p0, r, r2, x, z, dn, x_tanh, q) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_ERF_1xLT16: + { + __m512 x, r, x_erf; + + // c[0, 0-15] + GELU_ERF_F32_AVX512(c_float_0p0, r, x, x_erf) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_CLIP_1xLT16: + { + __m512 min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 ); + __m512 max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 ); + + // c[0, 0-15] + CLIP_F32_AVX512(c_float_0p0, min, max) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + +POST_OPS_DOWNSCALE_1xLT16: + { + __m512 zero_point0 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + // Also the same value is loaded to different registers so that + // branching can be reduced and same code/register can be used + // irrespective of whether scalar or vector op. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = _mm512_maskz_loadu_ps( zp_mask, + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_1xLT16: + { + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15] + BF16_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,0); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15] + F32_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,0); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_1xLT16: + { + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15] + BF16_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,0); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15] + F32_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,0); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_1xLT16: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_1xLT16_DISABLE: + ; + + if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_last_k == TRUE ) ) + { + __mmask16 mask_all1 = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + + // Store the results in downscaled type (int8 instead of int32). + // c[0,0-15] + CVT_STORE_F32_BF16_MASK(c_float_0p0,0,0); + } + + else + { + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + + // Store the results. + // c[0,0-15] + _mm512_mask_storeu_ps( c + ( rs_c * 0 ), load_mask, c_float_0p0 ); + } + +} + +// 5x16 bf16 kernel +LPGEMM_MN_FRINGE_KERN(bfloat16, int8_t, float, bf16s4f32of32_5x16) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_5x16_DISABLE, + &&POST_OPS_BIAS_5x16, + &&POST_OPS_RELU_5x16, + &&POST_OPS_RELU_SCALE_5x16, + &&POST_OPS_GELU_TANH_5x16, + &&POST_OPS_GELU_ERF_5x16, + &&POST_OPS_CLIP_5x16, + &&POST_OPS_DOWNSCALE_5x16, + &&POST_OPS_MATRIX_ADD_5x16, + &&POST_OPS_SWISH_5x16, + &&POST_OPS_MATRIX_MUL_5x16 + }; + + dim_t pre_op_off = post_ops_attr.pre_op_off; + + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int16_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + + __m128i b0_s4; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + __m256i shift_idx_32; + MULTISHIFT_32BIT_8_INT4_IDX_32ELEM(shift_idx_32); + __m256i sign_comp_32 = _mm256_set1_epi8( 0x08 ); + + bool signed_upscale = true; + + /* regs to store intermediate int8 values */ + __m256i b0_s8; + + /* Regs to store F32 scale values */ + __m512 scale0, scale1; + /* Reg to store masks to interleave scale factor */ + __m512i mask_scale1, mask_scale2; + + mask_scale1 = _mm512_set_epi32( 0x17, 0x07, 0x16, 0x06, 0x15, 0x05, 0x14, + 0x04, 0x13, 0x03, 0x12, 0x02, 0x11, 0x01, + 0x10, 0x00 ); + + mask_scale2 = _mm512_set_epi32( 0x1F, 0x0F, 0x1E, 0x0E, 0x1D, 0x0D, 0x1C, + 0x0C, 0x1B, 0x0B, 0x1A, 0x0A, 0x19, 0x09, + 0x18, 0x08); + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + + __m512 c_float_2p0 = _mm512_setzero_ps(); + + __m512 c_float_3p0 = _mm512_setzero_ps(); + + __m512 c_float_4p0 = _mm512_setzero_ps(); + + if( post_ops_attr.pre_op_scale_factor_len > 1 ) + { + // load and interleave scale factor vectors + scale0 = _mm512_loadu_ps( (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off); + + scale1 = _mm512_permutex2var_ps( scale0, mask_scale2, scale0 ); + scale0 = _mm512_permutex2var_ps( scale0, mask_scale1, scale0 ); + + } + else + { + scale0 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale1 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + } + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0_s4 = _mm_loadu_si128( (__m128i const *)( b + ( ( rs_b * kr ) / 2 ) ) ); + + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_32, \ + sign_comp_32, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_8( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_8( b0_s8, 0, scale0 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + + // Broadcast a[2,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + + // Broadcast a[3,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-15] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + + // Broadcast a[4,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[4,0-15] = a[4,kr:kr+2]*b[kr:kr+2,0-15] + c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0_s4 = _mm_loadu_si128( (__m128i const *)( b + ( ( rs_b * k_full_pieces ) / 2 ) ) ); + + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_32, \ + sign_comp_32, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_8( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_8( b0_s8, 0, scale0 ) ); + + // Broadcast a[0,kr:kr+2]. + a_kfringe_buf = *( a + (rs_a * 0) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 1) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + + // Broadcast a[2,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 2) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + + // Broadcast a[3,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 3) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-15] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + + // Broadcast a[4,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 4) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[4,0-15] = a[4,kr:kr+2]*b[kr:kr+2,0-15] + c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + if ( alpha != 1 ) + { + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + + c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); + + c_float_3p0 = _mm512_mul_ps( selector1, c_float_3p0 ); + + c_float_4p0 = _mm512_mul_ps( selector1, c_float_4p0 ); + } + + // Scale C by beta. + if ( beta != 0 ) + { + if ( ( post_ops_attr.buf_downscale != NULL ) && + ( post_ops_attr.is_first_k == TRUE ) ) + { + + // c[0,0-15] + BF16_F32_BETA_OP( c_float_0p0, 0, 0, 0, \ + selector1, selector2 ); + + // c[1,0-15] + BF16_F32_BETA_OP( c_float_1p0, 0, 1, 0, \ + selector1, selector2 ); + + // c[2,0-15] + BF16_F32_BETA_OP( c_float_2p0, 0, 2, 0, \ + selector1, selector2 ); + + // c[3,0-15] + BF16_F32_BETA_OP( c_float_3p0, 0, 3, 0, \ + selector1, selector2 ); + + // c[4,0-15] + BF16_F32_BETA_OP( c_float_4p0, 0, 4, 0, \ + selector1, selector2 ); + } + else + { + // c[0,0-15] + F32_F32_BETA_OP(c_float_0p0, 0, 0, 0, \ + selector1, selector2); + + // c[1,0-15] + F32_F32_BETA_OP(c_float_1p0, 0, 1, 0, \ + selector1, selector2); + + // c[2,0-15] + F32_F32_BETA_OP(c_float_2p0, 0, 2, 0, \ + selector1, selector2); + + // c[3,0-15] + F32_F32_BETA_OP(c_float_3p0, 0, 3, 0, \ + selector1, selector2); + + // c[4,0-15] + F32_F32_BETA_OP(c_float_4p0, 0, 4, 0, \ + selector1, selector2); + } + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_5x16: + { + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_LOAD(selector1, bias_mask, 0); + } + else + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + } + else + { + __m512 selector3; + __m512 selector4; + __m512 selector5; + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + BF16_F32_BIAS_BCAST(selector2, bias_mask, 1); + BF16_F32_BIAS_BCAST(selector3, bias_mask, 2); + BF16_F32_BIAS_BCAST(selector4, bias_mask, 3); + BF16_F32_BIAS_BCAST(selector5, bias_mask, 4); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 3 ) ); + selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 4 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector4, c_float_3p0 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector5, c_float_4p0 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_5x16: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + // c[3,0-15] + c_float_3p0 = _mm512_max_ps( selector1, c_float_3p0 ); + + // c[4,0-15] + c_float_4p0 = _mm512_max_ps( selector1, c_float_4p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_5x16: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + // c[3, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_3p0) + + // c[4, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_4p0) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_TANH_5x16: + { + __m512 dn, z, x, r2, r, x_tanh; + __m512i q; + + // c[0, 0-15] + GELU_TANH_F32_AVX512(c_float_0p0, r, r2, x, z, dn, x_tanh, q) + + // c[1, 0-15] + GELU_TANH_F32_AVX512(c_float_1p0, r, r2, x, z, dn, x_tanh, q) + + // c[2, 0-15] + GELU_TANH_F32_AVX512(c_float_2p0, r, r2, x, z, dn, x_tanh, q) + + // c[3, 0-15] + GELU_TANH_F32_AVX512(c_float_3p0, r, r2, x, z, dn, x_tanh, q) + + // c[4, 0-15] + GELU_TANH_F32_AVX512(c_float_4p0, r, r2, x, z, dn, x_tanh, q) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_ERF_5x16: + { + __m512 x, r, x_erf; + + // c[0, 0-15] + GELU_ERF_F32_AVX512(c_float_0p0, r, x, x_erf) + + // c[1, 0-15] + GELU_ERF_F32_AVX512(c_float_1p0, r, x, x_erf) + + // c[2, 0-15] + GELU_ERF_F32_AVX512(c_float_2p0, r, x, x_erf) + + // c[3, 0-15] + GELU_ERF_F32_AVX512(c_float_3p0, r, x, x_erf) + + // c[4, 0-15] + GELU_ERF_F32_AVX512(c_float_4p0, r, x, x_erf) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_CLIP_5x16: + { + __m512 min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 ); + __m512 max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 ); + + // c[0, 0-15] + CLIP_F32_AVX512(c_float_0p0, min, max) + + // c[1, 0-15] + CLIP_F32_AVX512(c_float_1p0, min, max) + + // c[2, 0-15] + CLIP_F32_AVX512(c_float_2p0, min, max) + + // c[3, 0-15] + CLIP_F32_AVX512(c_float_3p0, min, max) + + // c[4, 0-15] + CLIP_F32_AVX512(c_float_4p0, min, max) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_5x16: + { + __m512 selector3 = _mm512_setzero_ps(); + __m512 selector4 = _mm512_setzero_ps(); + __m512 selector5 = _mm512_setzero_ps(); + + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + __m512 zero_point2 = _mm512_setzero_ps(); + __m512 zero_point3 = _mm512_setzero_ps(); + __m512 zero_point4 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + // Also the same value is loaded to different registers so that + // branching can be reduced and same code/register can be used + // irrespective of whether scalar or vector op. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point4 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = _mm512_maskz_loadu_ps( zp_mask, + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector1,zero_point0); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector1,zero_point0); + + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector1,zero_point0); + + // c[4, 0-15] + SCL_MULRND_F32(c_float_4p0,selector1,zero_point0); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 3 ) ); + selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 4 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 2 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 3 ) ) ); + zero_point4 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 4 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector2,zero_point1); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector3,zero_point2); + + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector4,zero_point3); + + // c[4, 0-15] + SCL_MULRND_F32(c_float_4p0,selector5,zero_point4); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_5x16: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15] + BF16_F32_MATRIX_ADD_1COL(selector1,0); + + // c[1:0-15] + BF16_F32_MATRIX_ADD_1COL(selector1,1); + + // c[2:0-15] + BF16_F32_MATRIX_ADD_1COL(selector1,2); + + // c[3:0-15] + BF16_F32_MATRIX_ADD_1COL(selector1,3); + + // c[4:0-15] + BF16_F32_MATRIX_ADD_1COL(selector1,4); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15] + F32_F32_MATRIX_ADD_1COL(selector1,0); + + // c[1:0-15] + F32_F32_MATRIX_ADD_1COL(selector1,1); + + // c[2:0-15] + F32_F32_MATRIX_ADD_1COL(selector1,2); + + // c[3:0-15] + F32_F32_MATRIX_ADD_1COL(selector1,3); + + // c[4:0-15] + F32_F32_MATRIX_ADD_1COL(selector1,4); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_5x16: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15] + BF16_F32_MATRIX_MUL_1COL(selector1,0); + + // c[1:0-15] + BF16_F32_MATRIX_MUL_1COL(selector1,1); + + // c[2:0-15] + BF16_F32_MATRIX_MUL_1COL(selector1,2); + + // c[3:0-15] + BF16_F32_MATRIX_MUL_1COL(selector1,3); + + // c[4:0-15] + BF16_F32_MATRIX_MUL_1COL(selector1,4); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15] + F32_F32_MATRIX_MUL_1COL(selector1,0); + + // c[1:0-15] + F32_F32_MATRIX_MUL_1COL(selector1,1); + + // c[2:0-15] + F32_F32_MATRIX_MUL_1COL(selector1,2); + + // c[3:0-15] + F32_F32_MATRIX_MUL_1COL(selector1,3); + + // c[4:0-15] + F32_F32_MATRIX_MUL_1COL(selector1,4); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_5x16: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 0-15] + SWISH_F32_AVX512_DEF(c_float_1p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 0-15] + SWISH_F32_AVX512_DEF(c_float_2p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[3, 0-15] + SWISH_F32_AVX512_DEF(c_float_3p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[4, 0-15] + SWISH_F32_AVX512_DEF(c_float_4p0, selector1, al_in, r, r2, z, dn, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_5x16_DISABLE: + ; + + if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_last_k == TRUE ) ) + { + // Generate a mask16 of all 1's. + __m512i selector_a = _mm512_setzero_epi32(); + __m512i selector_b = _mm512_set1_epi32( 10 ); + __mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b ); + + // Store the results in downscaled type (int8 instead of int32). + // c[0,0-15] + CVT_STORE_F32_BF16_MASK(c_float_0p0,0,0); + + // c[1,0-15] + CVT_STORE_F32_BF16_MASK(c_float_1p0,1,0); + + // c[2,0-15] + CVT_STORE_F32_BF16_MASK(c_float_2p0,2,0); + + // c[3,0-15] + CVT_STORE_F32_BF16_MASK(c_float_3p0,3,0); + + // c[4,0-15] + CVT_STORE_F32_BF16_MASK(c_float_4p0,4,0); + } + + else + { + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); + + // c[1,0-15] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 0*16 ), c_float_1p0 ); + + // c[2,0-15] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 0*16 ), c_float_2p0 ); + + // c[3,0-15] + _mm512_storeu_ps( c + ( rs_c * 3 ) + ( 0*16 ), c_float_3p0 ); + + // c[4,0-15] + _mm512_storeu_ps( c + ( rs_c * 4 ) + ( 0*16 ), c_float_4p0 ); + } +} + +// 4x16 bf16 kernel +LPGEMM_MN_FRINGE_KERN(bfloat16, int8_t, float, bf16s4f32of32_4x16) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_4x16_DISABLE, + &&POST_OPS_BIAS_4x16, + &&POST_OPS_RELU_4x16, + &&POST_OPS_RELU_SCALE_4x16, + &&POST_OPS_GELU_TANH_4x16, + &&POST_OPS_GELU_ERF_4x16, + &&POST_OPS_CLIP_4x16, + &&POST_OPS_DOWNSCALE_4x16, + &&POST_OPS_MATRIX_ADD_4x16, + &&POST_OPS_SWISH_4x16, + &&POST_OPS_MATRIX_MUL_4x16 + }; + + dim_t pre_op_off = post_ops_attr.pre_op_off; + + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int16_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + + __m128i b0_s4; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + __m256i shift_idx_32; + MULTISHIFT_32BIT_8_INT4_IDX_32ELEM(shift_idx_32); + __m256i sign_comp_32 = _mm256_set1_epi8( 0x08 ); + + bool signed_upscale = true; + + /* regs to store intermediate int8 values */ + __m256i b0_s8; + + /* Regs to store F32 scale values */ + __m512 scale0, scale1; + /* Reg to store masks to interleave scale factor */ + __m512i mask_scale1, mask_scale2; + + mask_scale1 = _mm512_set_epi32( 0x17, 0x07, 0x16, 0x06, 0x15, 0x05, 0x14, + 0x04, 0x13, 0x03, 0x12, 0x02, 0x11, 0x01, + 0x10, 0x00 ); + + mask_scale2 = _mm512_set_epi32( 0x1F, 0x0F, 0x1E, 0x0E, 0x1D, 0x0D, 0x1C, + 0x0C, 0x1B, 0x0B, 0x1A, 0x0A, 0x19, 0x09, + 0x18, 0x08); + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + + __m512 c_float_2p0 = _mm512_setzero_ps(); + + __m512 c_float_3p0 = _mm512_setzero_ps(); + + if( post_ops_attr.pre_op_scale_factor_len > 1 ) + { + // load and interleave scale factor vectors + scale0 = _mm512_loadu_ps( (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off); + + scale1 = _mm512_permutex2var_ps( scale0, mask_scale2, scale0 ); + scale0 = _mm512_permutex2var_ps( scale0, mask_scale1, scale0 ); + + } + else + { + scale0 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale1 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + } + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0_s4 = _mm_loadu_si128( (__m128i const *)( b + ( ( rs_b * kr ) / 2 ) ) ); + + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_32, \ + sign_comp_32, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_8( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_8( b0_s8, 0, scale0 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + + // Broadcast a[2,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + + // Broadcast a[3,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-15] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0_s4 = _mm_loadu_si128( (__m128i const *)( b + ( ( rs_b * k_full_pieces ) / 2 ) ) ); + + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_32, \ + sign_comp_32, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_8( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_8( b0_s8, 0, scale0 ) ); + + // Broadcast a[0,kr:kr+2]. + a_kfringe_buf = *( a + (rs_a * 0) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 1) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + + // Broadcast a[2,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 2) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + + // Broadcast a[3,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 3) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-15] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + if ( alpha != 1 ) + { + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + + c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); + + c_float_3p0 = _mm512_mul_ps( selector1, c_float_3p0 ); + } + + // Scale C by beta. + if ( beta != 0 ) + { + if ( ( post_ops_attr.buf_downscale != NULL ) && + ( post_ops_attr.is_first_k == TRUE ) ) + { + + // c[0,0-15] + BF16_F32_BETA_OP( c_float_0p0, 0, 0, 0, \ + selector1, selector2 ); + + // c[1,0-15] + BF16_F32_BETA_OP( c_float_1p0, 0, 1, 0, \ + selector1, selector2 ); + + // c[2,0-15] + BF16_F32_BETA_OP( c_float_2p0, 0, 2, 0, \ + selector1, selector2 ); + + // c[3,0-15] + BF16_F32_BETA_OP( c_float_3p0, 0, 3, 0, \ + selector1, selector2 ); + } + else + { + // c[0,0-15] + F32_F32_BETA_OP(c_float_0p0, 0, 0, 0, \ + selector1, selector2); + + // c[1,0-15] + F32_F32_BETA_OP(c_float_1p0, 0, 1, 0, \ + selector1, selector2); + + // c[2,0-15] + F32_F32_BETA_OP(c_float_2p0, 0, 2, 0, \ + selector1, selector2); + + // c[3,0-15] + F32_F32_BETA_OP(c_float_3p0, 0, 3, 0, \ + selector1, selector2); + } + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_4x16: + { + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_LOAD(selector1, bias_mask, 0); + } + else + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + } + else + { + __m512 selector3; + __m512 selector4; + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + BF16_F32_BIAS_BCAST(selector2, bias_mask, 1); + BF16_F32_BIAS_BCAST(selector3, bias_mask, 2); + BF16_F32_BIAS_BCAST(selector4, bias_mask, 3); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 3 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector4, c_float_3p0 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_4x16: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + // c[3,0-15] + c_float_3p0 = _mm512_max_ps( selector1, c_float_3p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_4x16: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + // c[3, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_3p0) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_TANH_4x16: + { + __m512 dn, z, x, r2, r, x_tanh; + __m512i q; + + // c[0, 0-15] + GELU_TANH_F32_AVX512(c_float_0p0, r, r2, x, z, dn, x_tanh, q) + + // c[1, 0-15] + GELU_TANH_F32_AVX512(c_float_1p0, r, r2, x, z, dn, x_tanh, q) + + // c[2, 0-15] + GELU_TANH_F32_AVX512(c_float_2p0, r, r2, x, z, dn, x_tanh, q) + + // c[3, 0-15] + GELU_TANH_F32_AVX512(c_float_3p0, r, r2, x, z, dn, x_tanh, q) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_ERF_4x16: + { + __m512 x, r, x_erf; + + // c[0, 0-15] + GELU_ERF_F32_AVX512(c_float_0p0, r, x, x_erf) + + // c[1, 0-15] + GELU_ERF_F32_AVX512(c_float_1p0, r, x, x_erf) + + // c[2, 0-15] + GELU_ERF_F32_AVX512(c_float_2p0, r, x, x_erf) + + // c[3, 0-15] + GELU_ERF_F32_AVX512(c_float_3p0, r, x, x_erf) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_CLIP_4x16: + { + __m512 min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 ); + __m512 max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 ); + + // c[0, 0-15] + CLIP_F32_AVX512(c_float_0p0, min, max) + + // c[1, 0-15] + CLIP_F32_AVX512(c_float_1p0, min, max) + + // c[2, 0-15] + CLIP_F32_AVX512(c_float_2p0, min, max) + + // c[3, 0-15] + CLIP_F32_AVX512(c_float_3p0, min, max) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_4x16: + { + __m512 selector3 = _mm512_setzero_ps(); + __m512 selector4 = _mm512_setzero_ps(); + + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + __m512 zero_point2 = _mm512_setzero_ps(); + __m512 zero_point3 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + // Also the same value is loaded to different registers so that + // branching can be reduced and same code/register can be used + // irrespective of whether scalar or vector op. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = _mm512_maskz_loadu_ps( zp_mask, + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector1,zero_point0); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector1,zero_point0); + + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector1,zero_point0); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 3 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 2 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 3 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector2,zero_point1); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector3,zero_point2); + + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector4,zero_point3); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_4x16: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15] + BF16_F32_MATRIX_ADD_1COL(selector1,0); + + // c[1:0-15] + BF16_F32_MATRIX_ADD_1COL(selector1,1); + + // c[2:0-15] + BF16_F32_MATRIX_ADD_1COL(selector1,2); + + // c[3:0-15] + BF16_F32_MATRIX_ADD_1COL(selector1,3); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15] + F32_F32_MATRIX_ADD_1COL(selector1,0); + + // c[1:0-15] + F32_F32_MATRIX_ADD_1COL(selector1,1); + + // c[2:0-15] + F32_F32_MATRIX_ADD_1COL(selector1,2); + + // c[3:0-15] + F32_F32_MATRIX_ADD_1COL(selector1,3); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_4x16: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15] + BF16_F32_MATRIX_MUL_1COL(selector1,0); + + // c[1:0-15] + BF16_F32_MATRIX_MUL_1COL(selector1,1); + + // c[2:0-15] + BF16_F32_MATRIX_MUL_1COL(selector1,2); + + // c[3:0-15] + BF16_F32_MATRIX_MUL_1COL(selector1,3); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15] + F32_F32_MATRIX_MUL_1COL(selector1,0); + + // c[1:0-15] + F32_F32_MATRIX_MUL_1COL(selector1,1); + + // c[2:0-15] + F32_F32_MATRIX_MUL_1COL(selector1,2); + + // c[3:0-15] + F32_F32_MATRIX_MUL_1COL(selector1,3); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_4x16: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 0-15] + SWISH_F32_AVX512_DEF(c_float_1p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 0-15] + SWISH_F32_AVX512_DEF(c_float_2p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[3, 0-15] + SWISH_F32_AVX512_DEF(c_float_3p0, selector1, al_in, r, r2, z, dn, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_4x16_DISABLE: + ; + + if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_last_k == TRUE ) ) + { + // Generate a mask16 of all 1's. + __m512i selector_a = _mm512_setzero_epi32(); + __m512i selector_b = _mm512_set1_epi32( 10 ); + __mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b ); + + // Store the results in downscaled type (int8 instead of int32). + // c[0,0-15] + CVT_STORE_F32_BF16_MASK(c_float_0p0,0,0); + + // c[1,0-15] + CVT_STORE_F32_BF16_MASK(c_float_1p0,1,0); + + // c[2,0-15] + CVT_STORE_F32_BF16_MASK(c_float_2p0,2,0); + + // c[3,0-15] + CVT_STORE_F32_BF16_MASK(c_float_3p0,3,0); + } + + else + { + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); + + // c[1,0-15] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 0*16 ), c_float_1p0 ); + + // c[2,0-15] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 0*16 ), c_float_2p0 ); + + // c[3,0-15] + _mm512_storeu_ps( c + ( rs_c * 3 ) + ( 0*16 ), c_float_3p0 ); + } +} + +// 3x16 bf16 kernel +LPGEMM_MN_FRINGE_KERN(bfloat16, int8_t, float, bf16s4f32of32_3x16) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_3x16_DISABLE, + &&POST_OPS_BIAS_3x16, + &&POST_OPS_RELU_3x16, + &&POST_OPS_RELU_SCALE_3x16, + &&POST_OPS_GELU_TANH_3x16, + &&POST_OPS_GELU_ERF_3x16, + &&POST_OPS_CLIP_3x16, + &&POST_OPS_DOWNSCALE_3x16, + &&POST_OPS_MATRIX_ADD_3x16, + &&POST_OPS_SWISH_3x16, + &&POST_OPS_MATRIX_MUL_3x16 + }; + + dim_t pre_op_off = post_ops_attr.pre_op_off; + + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int16_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + + __m128i b0_s4; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + __m256i shift_idx_32; + MULTISHIFT_32BIT_8_INT4_IDX_32ELEM(shift_idx_32); + __m256i sign_comp_32 = _mm256_set1_epi8( 0x08 ); + + bool signed_upscale = true; + + /* regs to store intermediate int8 values */ + __m256i b0_s8; + + /* Regs to store F32 scale values */ + __m512 scale0, scale1; + /* Reg to store masks to interleave scale factor */ + __m512i mask_scale1, mask_scale2; + + mask_scale1 = _mm512_set_epi32( 0x17, 0x07, 0x16, 0x06, 0x15, 0x05, 0x14, + 0x04, 0x13, 0x03, 0x12, 0x02, 0x11, 0x01, + 0x10, 0x00 ); + + mask_scale2 = _mm512_set_epi32( 0x1F, 0x0F, 0x1E, 0x0E, 0x1D, 0x0D, 0x1C, + 0x0C, 0x1B, 0x0B, 0x1A, 0x0A, 0x19, 0x09, + 0x18, 0x08); + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + + __m512 c_float_2p0 = _mm512_setzero_ps(); + + if( post_ops_attr.pre_op_scale_factor_len > 1 ) + { + // load and interleave scale factor vectors + scale0 = _mm512_loadu_ps( (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off); + + scale1 = _mm512_permutex2var_ps( scale0, mask_scale2, scale0 ); + scale0 = _mm512_permutex2var_ps( scale0, mask_scale1, scale0 ); + + } + else + { + scale0 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale1 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + } + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0_s4 = _mm_loadu_si128( (__m128i const *)( b + ( ( rs_b * kr ) / 2 ) ) ); + + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_32, \ + sign_comp_32, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_8( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_8( b0_s8, 0, scale0 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + + // Broadcast a[2,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0_s4 = _mm_loadu_si128( (__m128i const *)( b + ( ( rs_b * k_full_pieces ) / 2 ) ) ); + + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_32, \ + sign_comp_32, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_8( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_8( b0_s8, 0, scale0 ) ); + + // Broadcast a[0,kr:kr+2]. + a_kfringe_buf = *( a + (rs_a * 0) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 1) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + + // Broadcast a[2,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 2) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + if ( alpha != 1 ) + { + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + + c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); + } + + // Scale C by beta. + if ( beta != 0 ) + { + if ( ( post_ops_attr.buf_downscale != NULL ) && + ( post_ops_attr.is_first_k == TRUE ) ) + { + + // c[0,0-15] + BF16_F32_BETA_OP( c_float_0p0, 0, 0, 0, \ + selector1, selector2 ); + + // c[1,0-15] + BF16_F32_BETA_OP( c_float_1p0, 0, 1, 0, \ + selector1, selector2 ); + + // c[2,0-15] + BF16_F32_BETA_OP( c_float_2p0, 0, 2, 0, \ + selector1, selector2 ); + } + else + { + // c[0,0-15] + F32_F32_BETA_OP(c_float_0p0, 0, 0, 0, \ + selector1, selector2); + + // c[1,0-15] + F32_F32_BETA_OP(c_float_1p0, 0, 1, 0, \ + selector1, selector2); + + // c[2,0-15] + F32_F32_BETA_OP(c_float_2p0, 0, 2, 0, \ + selector1, selector2); + } + + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_3x16: + { + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_LOAD(selector1, bias_mask, 0); + } + else + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + } + else + { + __m512 selector3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + BF16_F32_BIAS_BCAST(selector2, bias_mask, 1); + BF16_F32_BIAS_BCAST(selector3, bias_mask, 2); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 2 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_3x16: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_3x16: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_TANH_3x16: + { + __m512 dn, z, x, r2, r, x_tanh; + __m512i q; + + // c[0, 0-15] + GELU_TANH_F32_AVX512(c_float_0p0, r, r2, x, z, dn, x_tanh, q) + + // c[1, 0-15] + GELU_TANH_F32_AVX512(c_float_1p0, r, r2, x, z, dn, x_tanh, q) + + // c[2, 0-15] + GELU_TANH_F32_AVX512(c_float_2p0, r, r2, x, z, dn, x_tanh, q) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_ERF_3x16: + { + __m512 x, r, x_erf; + + // c[0, 0-15] + GELU_ERF_F32_AVX512(c_float_0p0, r, x, x_erf) + + // c[1, 0-15] + GELU_ERF_F32_AVX512(c_float_1p0, r, x, x_erf) + + // c[2, 0-15] + GELU_ERF_F32_AVX512(c_float_2p0, r, x, x_erf) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_CLIP_3x16: + { + __m512 min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 ); + __m512 max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 ); + + // c[0, 0-15] + CLIP_F32_AVX512(c_float_0p0, min, max) + + // c[1, 0-15] + CLIP_F32_AVX512(c_float_1p0, min, max) + + // c[2, 0-15] + CLIP_F32_AVX512(c_float_2p0, min, max) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_3x16: + { + __m512 selector3 = _mm512_setzero_ps(); + + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + __m512 zero_point2 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + // Also the same value is loaded to different registers so that + // branching can be reduced and same code/register can be used + // irrespective of whether scalar or vector op. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = _mm512_maskz_loadu_ps( zp_mask, + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector1,zero_point0); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector1,zero_point0); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 2 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 2 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector2,zero_point1); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector3,zero_point2); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_3x16: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15] + BF16_F32_MATRIX_ADD_1COL(selector1,0); + + // c[1:0-15] + BF16_F32_MATRIX_ADD_1COL(selector1,1); + + // c[2:0-15] + BF16_F32_MATRIX_ADD_1COL(selector1,2); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15] + F32_F32_MATRIX_ADD_1COL(selector1,0); + + // c[1:0-15] + F32_F32_MATRIX_ADD_1COL(selector1,1); + + // c[2:0-15] + F32_F32_MATRIX_ADD_1COL(selector1,2); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_3x16: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15] + BF16_F32_MATRIX_MUL_1COL(selector1,0); + + // c[1:0-15] + BF16_F32_MATRIX_MUL_1COL(selector1,1); + + // c[2:0-15] + BF16_F32_MATRIX_MUL_1COL(selector1,2); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15] + F32_F32_MATRIX_MUL_1COL(selector1,0); + + // c[1:0-15] + F32_F32_MATRIX_MUL_1COL(selector1,1); + + // c[2:0-15] + F32_F32_MATRIX_MUL_1COL(selector1,2); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_3x16: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 0-15] + SWISH_F32_AVX512_DEF(c_float_1p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 0-15] + SWISH_F32_AVX512_DEF(c_float_2p0, selector1, al_in, r, r2, z, dn, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_3x16_DISABLE: + ; + + if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_last_k == TRUE ) ) + { + // Generate a mask16 of all 1's. + __m512i selector_a = _mm512_setzero_epi32(); + __m512i selector_b = _mm512_set1_epi32( 10 ); + __mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b ); + + // Store the results in downscaled type (int8 instead of int32). + // c[0,0-15] + CVT_STORE_F32_BF16_MASK(c_float_0p0,0,0); + + // c[1,0-15] + CVT_STORE_F32_BF16_MASK(c_float_1p0,1,0); + + // c[2,0-15] + CVT_STORE_F32_BF16_MASK(c_float_2p0,2,0); + } + + else + { + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); + + // c[1,0-15] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 0*16 ), c_float_1p0 ); + + // c[2,0-15] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 0*16 ), c_float_2p0 ); + } +} + +// 2x16 bf16 kernel +LPGEMM_MN_FRINGE_KERN(bfloat16, int8_t, float, bf16s4f32of32_2x16) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_2x16_DISABLE, + &&POST_OPS_BIAS_2x16, + &&POST_OPS_RELU_2x16, + &&POST_OPS_RELU_SCALE_2x16, + &&POST_OPS_GELU_TANH_2x16, + &&POST_OPS_GELU_ERF_2x16, + &&POST_OPS_CLIP_2x16, + &&POST_OPS_DOWNSCALE_2x16, + &&POST_OPS_MATRIX_ADD_2x16, + &&POST_OPS_SWISH_2x16, + &&POST_OPS_MATRIX_MUL_2x16 + }; + + dim_t pre_op_off = post_ops_attr.pre_op_off; + + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int16_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + + __m128i b0_s4; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + __m256i shift_idx_32; + MULTISHIFT_32BIT_8_INT4_IDX_32ELEM(shift_idx_32); + __m256i sign_comp_32 = _mm256_set1_epi8( 0x08 ); + + bool signed_upscale = true; + + /* regs to store intermediate int8 values */ + __m256i b0_s8; + + /* Regs to store F32 scale values */ + __m512 scale0, scale1; + /* Reg to store masks to interleave scale factor */ + __m512i mask_scale1, mask_scale2; + + mask_scale1 = _mm512_set_epi32( 0x17, 0x07, 0x16, 0x06, 0x15, 0x05, 0x14, + 0x04, 0x13, 0x03, 0x12, 0x02, 0x11, 0x01, + 0x10, 0x00 ); + + mask_scale2 = _mm512_set_epi32( 0x1F, 0x0F, 0x1E, 0x0E, 0x1D, 0x0D, 0x1C, + 0x0C, 0x1B, 0x0B, 0x1A, 0x0A, 0x19, 0x09, + 0x18, 0x08); + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + + if( post_ops_attr.pre_op_scale_factor_len > 1 ) + { + // load and interleave scale factor vectors + scale0 = _mm512_loadu_ps( (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off); + + scale1 = _mm512_permutex2var_ps( scale0, mask_scale2, scale0 ); + scale0 = _mm512_permutex2var_ps( scale0, mask_scale1, scale0 ); + + } + else + { + scale0 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale1 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + } + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0_s4 = _mm_loadu_si128( (__m128i const *)( b + ( ( rs_b * kr ) / 2 ) ) ); + + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_32, \ + sign_comp_32, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_8( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_8( b0_s8, 0, scale0 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0_s4 = _mm_loadu_si128( (__m128i const *)( b + ( ( rs_b * k_full_pieces ) / 2 ) ) ); + + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_32, \ + sign_comp_32, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_8( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_8( b0_s8, 0, scale0 ) ); + + // Broadcast a[0,kr:kr+2]. + a_kfringe_buf = *( a + (rs_a * 0) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 1) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + if ( alpha != 1 ) + { + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + } + + // Scale C by beta. + if ( beta != 0 ) + { + if ( ( post_ops_attr.buf_downscale != NULL ) && + ( post_ops_attr.is_first_k == TRUE ) ) + { + + // c[0,0-15] + BF16_F32_BETA_OP( c_float_0p0, 0, 0, 0, \ + selector1, selector2 ); + + // c[1,0-15] + BF16_F32_BETA_OP( c_float_1p0, 0, 1, 0, \ + selector1, selector2 ); + } + else + { + // c[0,0-15] + F32_F32_BETA_OP(c_float_0p0, 0, 0, 0, \ + selector1, selector2); + + // c[1,0-15] + F32_F32_BETA_OP(c_float_1p0, 0, 1, 0, \ + selector1, selector2); + } + + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_2x16: + { + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_LOAD(selector1, bias_mask, 0); + } + else + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + } + else + { + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + BF16_F32_BIAS_BCAST(selector2, bias_mask, 1); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 1 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_2x16: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_2x16: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_TANH_2x16: + { + __m512 dn, z, x, r2, r, x_tanh; + __m512i q; + + // c[0, 0-15] + GELU_TANH_F32_AVX512(c_float_0p0, r, r2, x, z, dn, x_tanh, q) + + // c[1, 0-15] + GELU_TANH_F32_AVX512(c_float_1p0, r, r2, x, z, dn, x_tanh, q) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_ERF_2x16: + { + __m512 x, r, x_erf; + + // c[0, 0-15] + GELU_ERF_F32_AVX512(c_float_0p0, r, x, x_erf) + + // c[1, 0-15] + GELU_ERF_F32_AVX512(c_float_1p0, r, x, x_erf) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_CLIP_2x16: + { + __m512 min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 ); + __m512 max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 ); + + // c[0, 0-15] + CLIP_F32_AVX512(c_float_0p0, min, max) + + // c[1, 0-15] + CLIP_F32_AVX512(c_float_1p0, min, max) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + +POST_OPS_DOWNSCALE_2x16: + { + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + // Also the same value is loaded to different registers so that + // branching can be reduced and same code/register can be used + // irrespective of whether scalar or vector op. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = _mm512_maskz_loadu_ps( zp_mask, + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector1,zero_point0); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 1 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 1 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector2,zero_point1); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_2x16: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15] + BF16_F32_MATRIX_ADD_1COL(selector1,0); + + // c[1:0-15] + BF16_F32_MATRIX_ADD_1COL(selector1,1); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15] + F32_F32_MATRIX_ADD_1COL(selector1,0); + + // c[1:0-15] + F32_F32_MATRIX_ADD_1COL(selector1,1); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_2x16: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15] + BF16_F32_MATRIX_MUL_1COL(selector1,0); + + // c[1:0-15] + BF16_F32_MATRIX_MUL_1COL(selector1,1); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15] + F32_F32_MATRIX_MUL_1COL(selector1,0); + + // c[1:0-15] + F32_F32_MATRIX_MUL_1COL(selector1,1); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_2x16: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 0-15] + SWISH_F32_AVX512_DEF(c_float_1p0, selector1, al_in, r, r2, z, dn, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_2x16_DISABLE: + ; + + if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_last_k == TRUE ) ) + { + // Generate a mask16 of all 1's. + __m512i selector_a = _mm512_setzero_epi32(); + __m512i selector_b = _mm512_set1_epi32( 10 ); + __mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b ); + + // Store the results in downscaled type (int8 instead of int32). + // c[0,0-15] + CVT_STORE_F32_BF16_MASK(c_float_0p0,0,0); + + // c[1,0-15] + CVT_STORE_F32_BF16_MASK(c_float_1p0,1,0); + } + + else + { + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); + + // c[1,0-15] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 0*16 ), c_float_1p0 ); + } +} + +// 1x16 bf16 kernel +LPGEMM_MN_FRINGE_KERN(bfloat16, int8_t, float, bf16s4f32of32_1x16) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_1x16_DISABLE, + &&POST_OPS_BIAS_1x16, + &&POST_OPS_RELU_1x16, + &&POST_OPS_RELU_SCALE_1x16, + &&POST_OPS_GELU_TANH_1x16, + &&POST_OPS_GELU_ERF_1x16, + &&POST_OPS_CLIP_1x16, + &&POST_OPS_DOWNSCALE_1x16, + &&POST_OPS_MATRIX_ADD_1x16, + &&POST_OPS_SWISH_1x16, + &&POST_OPS_MATRIX_MUL_1x16 + }; + + dim_t pre_op_off = post_ops_attr.pre_op_off; + + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int16_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + + __m128i b0_s4; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + __m256i shift_idx_32; + MULTISHIFT_32BIT_8_INT4_IDX_32ELEM(shift_idx_32); + __m256i sign_comp_32 = _mm256_set1_epi8( 0x08 ); + + bool signed_upscale = true; + + /* regs to store intermediate int8 values */ + __m256i b0_s8; + + /* Regs to store F32 scale values */ + __m512 scale0, scale1; + /* Reg to store masks to interleave scale factor */ + __m512i mask_scale1, mask_scale2; + + mask_scale1 = _mm512_set_epi32( 0x17, 0x07, 0x16, 0x06, 0x15, 0x05, 0x14, + 0x04, 0x13, 0x03, 0x12, 0x02, 0x11, 0x01, + 0x10, 0x00 ); + + mask_scale2 = _mm512_set_epi32( 0x1F, 0x0F, 0x1E, 0x0E, 0x1D, 0x0D, 0x1C, + 0x0C, 0x1B, 0x0B, 0x1A, 0x0A, 0x19, 0x09, + 0x18, 0x08); + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + + if( post_ops_attr.pre_op_scale_factor_len > 1 ) + { + // load and interleave scale factor vectors + scale0 = _mm512_loadu_ps( (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off); + + scale1 = _mm512_permutex2var_ps( scale0, mask_scale2, scale0 ); + scale0 = _mm512_permutex2var_ps( scale0, mask_scale1, scale0 ); + + } + else + { + scale0 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale1 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + } + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0_s4 = _mm_loadu_si128( (__m128i const *)( b + ( ( rs_b * kr ) / 2 ) ) ); + + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_32, \ + sign_comp_32, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_8( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_8( b0_s8, 0, scale0 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0_s4 = _mm_loadu_si128( (__m128i const *)( b + ( ( rs_b * k_full_pieces ) / 2 ) ) ); + + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_32, \ + sign_comp_32, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_8( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_8( b0_s8, 0, scale0 ) ); + + // Broadcast a[0,kr:kr+2]. + a_kfringe_buf = *( a + (rs_a * 0) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + if ( alpha != 1 ) + { + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + } + + // Scale C by beta. + if ( beta != 0 ) + { + if ( ( post_ops_attr.buf_downscale != NULL ) && + ( post_ops_attr.is_first_k == TRUE ) ) + { + + // c[0,0-15] + BF16_F32_BETA_OP( c_float_0p0, 0, 0, 0, \ + selector1, selector2 ); + } + else + { + // c[0,0-15] + F32_F32_BETA_OP(c_float_0p0, 0, 0, 0, \ + selector1, selector2); + } + + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_1x16: + { + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_LOAD(selector1, bias_mask, 0); + } + else + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + } + else + { + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_1x16: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_1x16: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_TANH_1x16: + { + __m512 dn, z, x, r2, r, x_tanh; + __m512i q; + + // c[0, 0-15] + GELU_TANH_F32_AVX512(c_float_0p0, r, r2, x, z, dn, x_tanh, q) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_ERF_1x16: + { + __m512 x, r, x_erf; + + // c[0, 0-15] + GELU_ERF_F32_AVX512(c_float_0p0, r, x, x_erf) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_CLIP_1x16: + { + __m512 min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 ); + __m512 max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 ); + + // c[0, 0-15] + CLIP_F32_AVX512(c_float_0p0, min, max) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_1x16: + { + __m512 zero_point0 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + // Also the same value is loaded to different registers so that + // branching can be reduced and same code/register can be used + // irrespective of whether scalar or vector op. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = _mm512_maskz_loadu_ps( zp_mask, + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_1x16: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15] + BF16_F32_MATRIX_ADD_1COL(selector1,0); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15] + F32_F32_MATRIX_ADD_1COL(selector1,0); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_1x16: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15] + BF16_F32_MATRIX_MUL_1COL(selector1,0); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15] + F32_F32_MATRIX_MUL_1COL(selector1,0); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_1x16: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_1x16_DISABLE: + ; + + if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_last_k == TRUE ) ) + { + // Generate a mask16 of all 1's. + __m512i selector_a = _mm512_setzero_epi32(); + __m512i selector_b = _mm512_set1_epi32( 10 ); + __mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b ); + + // Store the results in downscaled type (int8 instead of int32). + // c[0,0-15] + CVT_STORE_F32_BF16_MASK(c_float_0p0,0,0); + } + else + { + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); + } +} + +// 5x32 bf16 kernel +LPGEMM_MN_FRINGE_KERN(bfloat16, int8_t, float, bf16s4f32of32_5x32) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_5x32_DISABLE, + &&POST_OPS_BIAS_5x32, + &&POST_OPS_RELU_5x32, + &&POST_OPS_RELU_SCALE_5x32, + &&POST_OPS_GELU_TANH_5x32, + &&POST_OPS_GELU_ERF_5x32, + &&POST_OPS_CLIP_5x32, + &&POST_OPS_DOWNSCALE_5x32, + &&POST_OPS_MATRIX_ADD_5x32, + &&POST_OPS_SWISH_5x32, + &&POST_OPS_MATRIX_MUL_5x32 + }; + + dim_t pre_op_off = post_ops_attr.pre_op_off; + + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int16_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + __m512bh b1; + + __m256i b0_s4; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + __m512i shift_idx_64; + MULTISHIFT_32BIT_8_INT4_IDX_64ELEM(shift_idx_64); + __m512i sign_comp = _mm512_set1_epi8(0x08); + + bool signed_upscale = true; + + /* regs to store intermediate int8 values */ + __m512i b0_s8; + + /* Regs to store F32 scale values */ + __m512 scale0, scale1, scale2, scale3; + /* Reg to store masks to interleave scale factor */ + __m512i mask_scale1, mask_scale2; + + mask_scale1 = _mm512_set_epi32( 0x17, 0x07, 0x16, 0x06, 0x15, 0x05, 0x14, + 0x04, 0x13, 0x03, 0x12, 0x02, 0x11, 0x01, + 0x10, 0x00 ); + + mask_scale2 = _mm512_set_epi32( 0x1F, 0x0F, 0x1E, 0x0E, 0x1D, 0x0D, 0x1C, + 0x0C, 0x1B, 0x0B, 0x1A, 0x0A, 0x19, 0x09, + 0x18, 0x08); + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + __m512 c_float_0p1 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + __m512 c_float_1p1 = _mm512_setzero_ps(); + + __m512 c_float_2p0 = _mm512_setzero_ps(); + __m512 c_float_2p1 = _mm512_setzero_ps(); + + __m512 c_float_3p0 = _mm512_setzero_ps(); + __m512 c_float_3p1 = _mm512_setzero_ps(); + + __m512 c_float_4p0 = _mm512_setzero_ps(); + __m512 c_float_4p1 = _mm512_setzero_ps(); + + if( post_ops_attr.pre_op_scale_factor_len > 1 ) + { + // load and interleave scale factor vectors + scale0 = _mm512_loadu_ps( (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off); + scale2 = _mm512_loadu_ps( (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off + 16 ); + + scale1 = _mm512_permutex2var_ps( scale0, mask_scale2, scale0 ); + scale0 = _mm512_permutex2var_ps( scale0, mask_scale1, scale0 ); + scale3 = _mm512_permutex2var_ps( scale2, mask_scale2, scale2 ); + scale2 = _mm512_permutex2var_ps( scale2, mask_scale1, scale2 ); + } + else + { + scale0 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale1 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale2 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale3 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + } + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0_s4 = _mm256_loadu_si256( (__m256i const *)( b + ( rs_b * kr ) / 2 ) ); + + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_64, \ + sign_comp, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_16( b0_s8, 0, scale0 ) ); + + b1 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 3, scale3 ), + CVT_INT8_F32_SCAL_16( b0_s8, 2, scale2 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-31] = a[1,kr:kr+2]*b[kr:kr+2,0-31] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); + + // Broadcast a[2,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-31] = a[2,kr:kr+2]*b[kr:kr+2,0-31] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + + // Broadcast a[3,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-31] = a[3,kr:kr+2]*b[kr:kr+2,0-31] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_0, b1 ); + + // Broadcast a[4,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[4,0-31] = a[4,kr:kr+2]*b[kr:kr+2,0-31] + c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); + c_float_4p1 = _mm512_dpbf16_ps( c_float_4p1, a_bf16_0, b1 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0_s4 = _mm256_loadu_si256( (__m256i const *)( b + ( rs_b * k_full_pieces ) / 2 ) ); + + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_64, \ + sign_comp, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_16( b0_s8, 0, scale0 ) ); + + b1 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 3, scale3 ), + CVT_INT8_F32_SCAL_16( b0_s8, 2, scale2 ) ); + + // Broadcast a[0,kr:kr+2]. + a_kfringe_buf = *( a + (rs_a * 0) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + + // Broadcast a[1,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 1) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-31] = a[1,kr:kr+2]*b[kr:kr+2,0-31] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); + + // Broadcast a[2,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 2) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-31] = a[2,kr:kr+2]*b[kr:kr+2,0-31] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + + // Broadcast a[3,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 3) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-31] = a[3,kr:kr+2]*b[kr:kr+2,0-31] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_0, b1 ); + + // Broadcast a[4,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 4) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[4,0-31] = a[4,kr:kr+2]*b[kr:kr+2,0-31] + c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); + c_float_4p1 = _mm512_dpbf16_ps( c_float_4p1, a_bf16_0, b1 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta );\ + + if ( alpha != 1 ) + { + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + c_float_0p1 = _mm512_mul_ps( selector1, c_float_0p1 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + c_float_1p1 = _mm512_mul_ps( selector1, c_float_1p1 ); + + c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); + c_float_2p1 = _mm512_mul_ps( selector1, c_float_2p1 ); + + c_float_3p0 = _mm512_mul_ps( selector1, c_float_3p0 ); + c_float_3p1 = _mm512_mul_ps( selector1, c_float_3p1 ); + + c_float_4p0 = _mm512_mul_ps( selector1, c_float_4p0 ); + c_float_4p1 = _mm512_mul_ps( selector1, c_float_4p1 ); + } + + // Scale C by beta. + if ( beta != 0 ) + { + if ( ( post_ops_attr.buf_downscale != NULL ) && + ( post_ops_attr.is_first_k == TRUE ) ) + { + + // c[0,0-15] + BF16_F32_BETA_OP( c_float_0p0, 0, 0, 0, selector1, selector2 ); + + // c[0, 16-31] + BF16_F32_BETA_OP( c_float_0p1, 0, 0, 1, selector1, selector2 ); + + // c[1,0-15] + BF16_F32_BETA_OP( c_float_1p0, 0, 1, 0, selector1, selector2 ); + + // c[1, 16-31] + BF16_F32_BETA_OP( c_float_1p1, 0, 1, 1, selector1, selector2 ); + + // c[2,0-15] + BF16_F32_BETA_OP( c_float_2p0, 0, 2, 0, selector1, selector2 ); + + // c[2, 16-31] + BF16_F32_BETA_OP( c_float_2p1, 0, 2, 1, selector1, selector2 ); + + // c[3,0-15] + BF16_F32_BETA_OP( c_float_3p0, 0, 3, 0, selector1, selector2 ); + + // c[3, 16-31] + BF16_F32_BETA_OP( c_float_3p1, 0, 3, 1, selector1, selector2 ); + + // c[4,0-15] + BF16_F32_BETA_OP( c_float_4p0, 0, 4, 0, selector1, selector2 ); + + // c[4, 16-31] + BF16_F32_BETA_OP( c_float_4p1, 0, 4, 1, selector1, selector2 ); + } + else + { + // c[0,0-15] + F32_F32_BETA_OP( c_float_0p0, 0, 0, 0, selector1, selector2 ); + + // c[0, 16-31] + F32_F32_BETA_OP( c_float_0p1, 0, 0, 1, selector1, selector2 ); + + // c[1,0-15] + F32_F32_BETA_OP( c_float_1p0, 0, 1, 0, selector1, selector2 ); + + // c[1, 16-31] + F32_F32_BETA_OP( c_float_1p1, 0, 1, 1, selector1, selector2 ); + + // c[2,0-15] + F32_F32_BETA_OP( c_float_2p0, 0, 2, 0, selector1, selector2 ); + + // c[2, 16-31] + F32_F32_BETA_OP( c_float_2p1, 0, 2, 1, selector1, selector2 ); + + // c[3,0-15] + F32_F32_BETA_OP( c_float_3p0, 0, 3, 0, selector1, selector2 ); + + // c[3, 16-31] + F32_F32_BETA_OP( c_float_3p1, 0, 3, 1, selector1, selector2 ); + + // c[4,0-15] + F32_F32_BETA_OP( c_float_4p0, 0, 4, 0, selector1, selector2 ); + + // c[4, 16-31] + F32_F32_BETA_OP( c_float_4p1, 0, 4, 1, selector1, selector2 ); + } + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_5x32: + { + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_LOAD(selector1, bias_mask, 0); + BF16_F32_BIAS_LOAD(selector2, bias_mask, 1); + } + else + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector2, c_float_3p1 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + + // c[4, 16-31] + c_float_4p1 = _mm512_add_ps( selector2, c_float_4p1 ); + } + else + { + __m512 selector3; + __m512 selector4; + __m512 selector5; + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + BF16_F32_BIAS_BCAST(selector2, bias_mask, 1); + BF16_F32_BIAS_BCAST(selector3, bias_mask, 2); + BF16_F32_BIAS_BCAST(selector4, bias_mask, 3); + BF16_F32_BIAS_BCAST(selector5, bias_mask, 4); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 3 ) ); + selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 4 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector3, c_float_2p1 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector4, c_float_3p0 ); + + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector4, c_float_3p1 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector5, c_float_4p0 ); + + // c[4, 16-31] + c_float_4p1 = _mm512_add_ps( selector5, c_float_4p1 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_5x32: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + c_float_1p1 = _mm512_max_ps( selector1, c_float_1p1 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + c_float_2p1 = _mm512_max_ps( selector1, c_float_2p1 ); + + // c[3,0-15] + c_float_3p0 = _mm512_max_ps( selector1, c_float_3p0 ); + + // c[3,16-31] + c_float_3p1 = _mm512_max_ps( selector1, c_float_3p1 ); + + // c[4,0-15] + c_float_4p0 = _mm512_max_ps( selector1, c_float_4p0 ); + + // c[4,16-31] + c_float_4p1 = _mm512_max_ps( selector1, c_float_4p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_5x32: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_1p1) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_2p1) + + // c[3, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_3p0) + + // c[3, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_3p1) + + // c[4, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_4p0) + + // c[4, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_4p1) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_TANH_5x32: + { + __m512 dn, z, x, r2, r, x_tanh; + __m512i q; + + // c[0, 0-15] + GELU_TANH_F32_AVX512(c_float_0p0, r, r2, x, z, dn, x_tanh, q) + + // c[0, 16-31] + GELU_TANH_F32_AVX512(c_float_0p1, r, r2, x, z, dn, x_tanh, q) + + // c[1, 0-15] + GELU_TANH_F32_AVX512(c_float_1p0, r, r2, x, z, dn, x_tanh, q) + + // c[1, 16-31] + GELU_TANH_F32_AVX512(c_float_1p1, r, r2, x, z, dn, x_tanh, q) + + // c[2, 0-15] + GELU_TANH_F32_AVX512(c_float_2p0, r, r2, x, z, dn, x_tanh, q) + + // c[2, 16-31] + GELU_TANH_F32_AVX512(c_float_2p1, r, r2, x, z, dn, x_tanh, q) + + // c[3, 0-15] + GELU_TANH_F32_AVX512(c_float_3p0, r, r2, x, z, dn, x_tanh, q) + + // c[3, 16-31] + GELU_TANH_F32_AVX512(c_float_3p1, r, r2, x, z, dn, x_tanh, q) + + // c[4, 0-15] + GELU_TANH_F32_AVX512(c_float_4p0, r, r2, x, z, dn, x_tanh, q) + + // c[4, 16-31] + GELU_TANH_F32_AVX512(c_float_4p1, r, r2, x, z, dn, x_tanh, q) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_ERF_5x32: + { + __m512 x, r, x_erf; + + // c[0, 0-15] + GELU_ERF_F32_AVX512(c_float_0p0, r, x, x_erf) + + // c[0, 16-31] + GELU_ERF_F32_AVX512(c_float_0p1, r, x, x_erf) + + // c[1, 0-15] + GELU_ERF_F32_AVX512(c_float_1p0, r, x, x_erf) + + // c[1, 16-31] + GELU_ERF_F32_AVX512(c_float_1p1, r, x, x_erf) + + // c[2, 0-15] + GELU_ERF_F32_AVX512(c_float_2p0, r, x, x_erf) + + // c[2, 16-31] + GELU_ERF_F32_AVX512(c_float_2p1, r, x, x_erf) + + // c[3, 0-15] + GELU_ERF_F32_AVX512(c_float_3p0, r, x, x_erf) + + // c[3, 16-31] + GELU_ERF_F32_AVX512(c_float_3p1, r, x, x_erf) + + // c[4, 0-15] + GELU_ERF_F32_AVX512(c_float_4p0, r, x, x_erf) + + // c[4, 16-31] + GELU_ERF_F32_AVX512(c_float_4p1, r, x, x_erf) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_CLIP_5x32: + { + __m512 min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 ); + __m512 max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 ); + + // c[0, 0-15] + CLIP_F32_AVX512(c_float_0p0, min, max) + + // c[0, 16-31] + CLIP_F32_AVX512(c_float_0p1, min, max) + + // c[1, 0-15] + CLIP_F32_AVX512(c_float_1p0, min, max) + + // c[1, 16-31] + CLIP_F32_AVX512(c_float_1p1, min, max) + + // c[2, 0-15] + CLIP_F32_AVX512(c_float_2p0, min, max) + + // c[2, 16-31] + CLIP_F32_AVX512(c_float_2p1, min, max) + + // c[3, 0-15] + CLIP_F32_AVX512(c_float_3p0, min, max) + + // c[3, 16-31] + CLIP_F32_AVX512(c_float_3p1, min, max) + + // c[4, 0-15] + CLIP_F32_AVX512(c_float_4p0, min, max) + + // c[4, 16-31] + CLIP_F32_AVX512(c_float_4p1, min, max) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_5x32: + { + __m512 selector3 = _mm512_setzero_ps(); + __m512 selector4 = _mm512_setzero_ps(); + __m512 selector5 = _mm512_setzero_ps(); + + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + __m512 zero_point2 = _mm512_setzero_ps(); + __m512 zero_point3 = _mm512_setzero_ps(); + __m512 zero_point4 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point4 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector2,zero_point1); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector1,zero_point0); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector1,zero_point0); + + // c[2, 16-31] + SCL_MULRND_F32(c_float_2p1,selector2,zero_point1); + + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector1,zero_point0); + + // c[3, 16-31] + SCL_MULRND_F32(c_float_3p1,selector2,zero_point1); + + // c[4, 0-15] + SCL_MULRND_F32(c_float_4p0,selector1,zero_point0); + + // c[4, 16-31] + SCL_MULRND_F32(c_float_4p1,selector2,zero_point1); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 3 ) ); + selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 4 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 2 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 3 ) ) ); + zero_point4 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 4 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector2,zero_point1); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector3,zero_point2); + + // c[2, 16-31] + SCL_MULRND_F32(c_float_2p1,selector3,zero_point2); + + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector4,zero_point3); + + // c[3, 16-31] + SCL_MULRND_F32(c_float_3p1,selector4,zero_point3); + + // c[4, 0-15] + SCL_MULRND_F32(c_float_4p0,selector5,zero_point4); + + // c[4, 16-31] + SCL_MULRND_F32(c_float_4p1,selector5,zero_point4); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_5x32: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + BF16_F32_MATRIX_ADD_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + BF16_F32_MATRIX_ADD_2COL(selector1,selector2,1); + + // c[2:0-15,16-31] + BF16_F32_MATRIX_ADD_2COL(selector1,selector2,2); + + // c[3:0-15,16-31] + BF16_F32_MATRIX_ADD_2COL(selector1,selector2,3); + + // c[4:0-15,16-31] + BF16_F32_MATRIX_ADD_2COL(selector1,selector2,4); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(selector1,selector2,1); + + // c[2:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(selector1,selector2,2); + + // c[3:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(selector1,selector2,3); + + // c[4:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(selector1,selector2,4); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_5x32: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + BF16_F32_MATRIX_MUL_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + BF16_F32_MATRIX_MUL_2COL(selector1,selector2,1); + + // c[2:0-15,16-31] + BF16_F32_MATRIX_MUL_2COL(selector1,selector2,2); + + // c[3:0-15,16-31] + BF16_F32_MATRIX_MUL_2COL(selector1,selector2,3); + + // c[4:0-15,16-31] + BF16_F32_MATRIX_MUL_2COL(selector1,selector2,4); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(selector1,selector2,1); + + // c[2:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(selector1,selector2,2); + + // c[3:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(selector1,selector2,3); + + // c[4:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(selector1,selector2,4); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_5x32: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[0, 16-31] + SWISH_F32_AVX512_DEF(c_float_0p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 0-15] + SWISH_F32_AVX512_DEF(c_float_1p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 16-31] + SWISH_F32_AVX512_DEF(c_float_1p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 0-15] + SWISH_F32_AVX512_DEF(c_float_2p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 16-31] + SWISH_F32_AVX512_DEF(c_float_2p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[3, 0-15] + SWISH_F32_AVX512_DEF(c_float_3p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[3, 16-31] + SWISH_F32_AVX512_DEF(c_float_3p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[4, 0-15] + SWISH_F32_AVX512_DEF(c_float_4p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[4, 16-31] + SWISH_F32_AVX512_DEF(c_float_4p1, selector1, al_in, r, r2, z, dn, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_5x32_DISABLE: + ; + + if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_last_k == TRUE ) ) + { + // Generate a mask16 of all 1's. + __m512i selector_a = _mm512_setzero_epi32(); + __m512i selector_b = _mm512_set1_epi32( 10 ); + __mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b ); + + // Store the results in downscaled type (int8 instead of int32). + // c[0,0-15] + CVT_STORE_F32_BF16_MASK(c_float_0p0,0,0); + + // c[0, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_0p1,0,1); + + // c[1,0-15] + CVT_STORE_F32_BF16_MASK(c_float_1p0,1,0); + + // c[1, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_1p1,1,1); + + // c[2,0-15] + CVT_STORE_F32_BF16_MASK(c_float_2p0,2,0); + + // c[2, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_2p1,2,1); + + // c[3,0-15] + CVT_STORE_F32_BF16_MASK(c_float_3p0,3,0); + + // c[3, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_3p1,3,1); + + // c[4,0-15] + CVT_STORE_F32_BF16_MASK(c_float_4p0,4,0); + + // c[4, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_4p1,4,1); + } + + else + { + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); + + // c[0, 16-31] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 1*16 ), c_float_0p1 ); + + // c[1,0-15] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 0*16 ), c_float_1p0 ); + + // c[1,16-31] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 1*16 ), c_float_1p1 ); + + // c[2,0-15] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 0*16 ), c_float_2p0 ); + + // c[2,16-31] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 1*16 ), c_float_2p1 ); + + // c[3,0-15] + _mm512_storeu_ps( c + ( rs_c * 3 ) + ( 0*16 ), c_float_3p0 ); + + // c[3,16-31] + _mm512_storeu_ps( c + ( rs_c * 3 ) + ( 1*16 ), c_float_3p1 ); + + // c[4,0-15] + _mm512_storeu_ps( c + ( rs_c * 4 ) + ( 0*16 ), c_float_4p0 ); + + // c[4,16-31] + _mm512_storeu_ps( c + ( rs_c * 4 ) + ( 1*16 ), c_float_4p1 ); + } +} + +// 4x32 bf16 kernel +LPGEMM_MN_FRINGE_KERN(bfloat16, int8_t, float, bf16s4f32of32_4x32) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_4x32_DISABLE, + &&POST_OPS_BIAS_4x32, + &&POST_OPS_RELU_4x32, + &&POST_OPS_RELU_SCALE_4x32, + &&POST_OPS_GELU_TANH_4x32, + &&POST_OPS_GELU_ERF_4x32, + &&POST_OPS_CLIP_4x32, + &&POST_OPS_DOWNSCALE_4x32, + &&POST_OPS_MATRIX_ADD_4x32, + &&POST_OPS_SWISH_4x32, + &&POST_OPS_MATRIX_MUL_4x32 + }; + + dim_t pre_op_off = post_ops_attr.pre_op_off; + + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int16_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + __m512bh b1; + + __m256i b0_s4; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + __m512i shift_idx_64; + MULTISHIFT_32BIT_8_INT4_IDX_64ELEM(shift_idx_64); + __m512i sign_comp = _mm512_set1_epi8(0x08); + + bool signed_upscale = true; + + /* regs to store intermediate int8 values */ + __m512i b0_s8; + + /* Regs to store F32 scale values */ + __m512 scale0, scale1, scale2, scale3; + /* Reg to store masks to interleave scale factor */ + __m512i mask_scale1, mask_scale2; + + mask_scale1 = _mm512_set_epi32( 0x17, 0x07, 0x16, 0x06, 0x15, 0x05, 0x14, + 0x04, 0x13, 0x03, 0x12, 0x02, 0x11, 0x01, + 0x10, 0x00 ); + + mask_scale2 = _mm512_set_epi32( 0x1F, 0x0F, 0x1E, 0x0E, 0x1D, 0x0D, 0x1C, + 0x0C, 0x1B, 0x0B, 0x1A, 0x0A, 0x19, 0x09, + 0x18, 0x08); + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + __m512 c_float_0p1 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + __m512 c_float_1p1 = _mm512_setzero_ps(); + + __m512 c_float_2p0 = _mm512_setzero_ps(); + __m512 c_float_2p1 = _mm512_setzero_ps(); + + __m512 c_float_3p0 = _mm512_setzero_ps(); + __m512 c_float_3p1 = _mm512_setzero_ps(); + + if( post_ops_attr.pre_op_scale_factor_len > 1 ) + { + // load and interleave scale factor vectors + scale0 = _mm512_loadu_ps( (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off); + scale2 = _mm512_loadu_ps( (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off + 16 ); + + scale1 = _mm512_permutex2var_ps( scale0, mask_scale2, scale0 ); + scale0 = _mm512_permutex2var_ps( scale0, mask_scale1, scale0 ); + scale3 = _mm512_permutex2var_ps( scale2, mask_scale2, scale2 ); + scale2 = _mm512_permutex2var_ps( scale2, mask_scale1, scale2 ); + } + else + { + scale0 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale1 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale2 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale3 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + } + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0_s4 = _mm256_loadu_si256( (__m256i const *)( b + ( rs_b * kr ) / 2 ) ); + + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_64, \ + sign_comp, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_16( b0_s8, 0, scale0 ) ); + + b1 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 3, scale3 ), + CVT_INT8_F32_SCAL_16( b0_s8, 2, scale2 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-31] = a[1,kr:kr+2]*b[kr:kr+2,0-31] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); + + // Broadcast a[2,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-31] = a[2,kr:kr+2]*b[kr:kr+2,0-31] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + + // Broadcast a[3,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-31] = a[3,kr:kr+2]*b[kr:kr+2,0-31] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_0, b1 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0_s4 = _mm256_loadu_si256( (__m256i const *)( b + ( rs_b * k_full_pieces ) / 2 ) ); + + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_64, \ + sign_comp, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_16( b0_s8, 0, scale0 ) ); + + b1 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 3, scale3 ), + CVT_INT8_F32_SCAL_16( b0_s8, 2, scale2 ) ); + + // Broadcast a[0,kr:kr+2]. + a_kfringe_buf = *( a + (rs_a * 0) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + + // Broadcast a[1,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 1) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-31] = a[1,kr:kr+2]*b[kr:kr+2,0-31] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); + + // Broadcast a[2,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 2) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-31] = a[2,kr:kr+2]*b[kr:kr+2,0-31] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + + // Broadcast a[3,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 3) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-31] = a[3,kr:kr+2]*b[kr:kr+2,0-31] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_0, b1 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + if ( alpha != 1 ) + { + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + c_float_0p1 = _mm512_mul_ps( selector1, c_float_0p1 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + c_float_1p1 = _mm512_mul_ps( selector1, c_float_1p1 ); + + c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); + c_float_2p1 = _mm512_mul_ps( selector1, c_float_2p1 ); + + c_float_3p0 = _mm512_mul_ps( selector1, c_float_3p0 ); + c_float_3p1 = _mm512_mul_ps( selector1, c_float_3p1 ); + } + + // Scale C by beta. + if ( beta != 0 ) + { + if ( ( post_ops_attr.buf_downscale != NULL ) && + ( post_ops_attr.is_first_k == TRUE ) ) + { + + // c[0,0-15] + BF16_F32_BETA_OP( c_float_0p0, 0, 0, 0, selector1, selector2 ); + + // c[0, 16-31] + BF16_F32_BETA_OP( c_float_0p1, 0, 0, 1, selector1, selector2 ); + + // c[1,0-15] + BF16_F32_BETA_OP( c_float_1p0, 0, 1, 0, selector1, selector2 ); + + // c[1, 16-31] + BF16_F32_BETA_OP( c_float_1p1, 0, 1, 1, selector1, selector2 ); + + // c[2,0-15] + BF16_F32_BETA_OP( c_float_2p0, 0, 2, 0, selector1, selector2 ); + + // c[2, 16-31] + BF16_F32_BETA_OP( c_float_2p1, 0, 2, 1, selector1, selector2 ); + + // c[3,0-15] + BF16_F32_BETA_OP( c_float_3p0, 0, 3, 0, selector1, selector2 ); + + // c[3, 16-31] + BF16_F32_BETA_OP( c_float_3p1, 0, 3, 1, selector1, selector2 ); + } + else + { + // c[0,0-15] + F32_F32_BETA_OP( c_float_0p0, 0, 0, 0, selector1, selector2 ); + + // c[0, 16-31] + F32_F32_BETA_OP( c_float_0p1, 0, 0, 1, selector1, selector2 ); + + // c[1,0-15] + F32_F32_BETA_OP( c_float_1p0, 0, 1, 0, selector1, selector2 ); + + // c[1, 16-31] + F32_F32_BETA_OP( c_float_1p1, 0, 1, 1, selector1, selector2 ); + + // c[2,0-15] + F32_F32_BETA_OP( c_float_2p0, 0, 2, 0, selector1, selector2 ); + + // c[2, 16-31] + F32_F32_BETA_OP( c_float_2p1, 0, 2, 1, selector1, selector2 ); + + // c[3,0-15] + F32_F32_BETA_OP( c_float_3p0, 0, 3, 0, selector1, selector2 ); + + // c[3, 16-31] + F32_F32_BETA_OP( c_float_3p1, 0, 3, 1, selector1, selector2 ); + } + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_4x32: + { + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_LOAD(selector1, bias_mask, 0); + BF16_F32_BIAS_LOAD(selector2, bias_mask, 1); + } + else + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector2, c_float_3p1 ); + } + else + { + __m512 selector3; + __m512 selector4; + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + BF16_F32_BIAS_BCAST(selector2, bias_mask, 1); + BF16_F32_BIAS_BCAST(selector3, bias_mask, 2); + BF16_F32_BIAS_BCAST(selector4, bias_mask, 3); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 3 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector3, c_float_2p1 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector4, c_float_3p0 ); + + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector4, c_float_3p1 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_4x32: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + c_float_1p1 = _mm512_max_ps( selector1, c_float_1p1 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + c_float_2p1 = _mm512_max_ps( selector1, c_float_2p1 ); + + // c[3,0-15] + c_float_3p0 = _mm512_max_ps( selector1, c_float_3p0 ); + + // c[3,16-31] + c_float_3p1 = _mm512_max_ps( selector1, c_float_3p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_4x32: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_1p1) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_2p1) + + // c[3, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_3p0) + + // c[3, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_3p1) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_TANH_4x32: + { + __m512 dn, z, x, r2, r, x_tanh; + __m512i q; + + // c[0, 0-15] + GELU_TANH_F32_AVX512(c_float_0p0, r, r2, x, z, dn, x_tanh, q) + + // c[0, 16-31] + GELU_TANH_F32_AVX512(c_float_0p1, r, r2, x, z, dn, x_tanh, q) + + // c[1, 0-15] + GELU_TANH_F32_AVX512(c_float_1p0, r, r2, x, z, dn, x_tanh, q) + + // c[1, 16-31] + GELU_TANH_F32_AVX512(c_float_1p1, r, r2, x, z, dn, x_tanh, q) + + // c[2, 0-15] + GELU_TANH_F32_AVX512(c_float_2p0, r, r2, x, z, dn, x_tanh, q) + + // c[2, 16-31] + GELU_TANH_F32_AVX512(c_float_2p1, r, r2, x, z, dn, x_tanh, q) + + // c[3, 0-15] + GELU_TANH_F32_AVX512(c_float_3p0, r, r2, x, z, dn, x_tanh, q) + + // c[3, 16-31] + GELU_TANH_F32_AVX512(c_float_3p1, r, r2, x, z, dn, x_tanh, q) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_ERF_4x32: + { + __m512 x, r, x_erf; + + // c[0, 0-15] + GELU_ERF_F32_AVX512(c_float_0p0, r, x, x_erf) + + // c[0, 16-31] + GELU_ERF_F32_AVX512(c_float_0p1, r, x, x_erf) + + // c[1, 0-15] + GELU_ERF_F32_AVX512(c_float_1p0, r, x, x_erf) + + // c[1, 16-31] + GELU_ERF_F32_AVX512(c_float_1p1, r, x, x_erf) + + // c[2, 0-15] + GELU_ERF_F32_AVX512(c_float_2p0, r, x, x_erf) + + // c[2, 16-31] + GELU_ERF_F32_AVX512(c_float_2p1, r, x, x_erf) + + // c[3, 0-15] + GELU_ERF_F32_AVX512(c_float_3p0, r, x, x_erf) + + // c[3, 16-31] + GELU_ERF_F32_AVX512(c_float_3p1, r, x, x_erf) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_CLIP_4x32: + { + __m512 min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 ); + __m512 max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 ); + + // c[0, 0-15] + CLIP_F32_AVX512(c_float_0p0, min, max) + + // c[0, 16-31] + CLIP_F32_AVX512(c_float_0p1, min, max) + + // c[1, 0-15] + CLIP_F32_AVX512(c_float_1p0, min, max) + + // c[1, 16-31] + CLIP_F32_AVX512(c_float_1p1, min, max) + + // c[2, 0-15] + CLIP_F32_AVX512(c_float_2p0, min, max) + + // c[2, 16-31] + CLIP_F32_AVX512(c_float_2p1, min, max) + + // c[3, 0-15] + CLIP_F32_AVX512(c_float_3p0, min, max) + + // c[3, 16-31] + CLIP_F32_AVX512(c_float_3p1, min, max) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + +POST_OPS_DOWNSCALE_4x32: + { + __m512 selector3 = _mm512_setzero_ps(); + __m512 selector4 = _mm512_setzero_ps(); + + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + __m512 zero_point2 = _mm512_setzero_ps(); + __m512 zero_point3 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector2,zero_point1); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector1,zero_point0); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector1,zero_point0); + + // c[2, 16-31] + SCL_MULRND_F32(c_float_2p1,selector2,zero_point1); + + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector1,zero_point0); + + // c[3, 16-31] + SCL_MULRND_F32(c_float_3p1,selector2,zero_point1); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 3 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 2 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 3 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector2,zero_point1); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector3,zero_point2); + + // c[2, 16-31] + SCL_MULRND_F32(c_float_2p1,selector3,zero_point2); + + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector4,zero_point3); + + // c[3, 16-31] + SCL_MULRND_F32(c_float_3p1,selector4,zero_point3); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_4x32: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + BF16_F32_MATRIX_ADD_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + BF16_F32_MATRIX_ADD_2COL(selector1,selector2,1); + + // c[2:0-15,16-31] + BF16_F32_MATRIX_ADD_2COL(selector1,selector2,2); + + // c[3:0-15,16-31] + BF16_F32_MATRIX_ADD_2COL(selector1,selector2,3); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(selector1,selector2,1); + + // c[2:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(selector1,selector2,2); + + // c[3:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(selector1,selector2,3); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_4x32: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + BF16_F32_MATRIX_MUL_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + BF16_F32_MATRIX_MUL_2COL(selector1,selector2,1); + + // c[2:0-15,16-31] + BF16_F32_MATRIX_MUL_2COL(selector1,selector2,2); + + // c[3:0-15,16-31] + BF16_F32_MATRIX_MUL_2COL(selector1,selector2,3); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(selector1,selector2,1); + + // c[2:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(selector1,selector2,2); + + // c[3:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(selector1,selector2,3); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_4x32: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[0, 16-31] + SWISH_F32_AVX512_DEF(c_float_0p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 0-15] + SWISH_F32_AVX512_DEF(c_float_1p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 16-31] + SWISH_F32_AVX512_DEF(c_float_1p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 0-15] + SWISH_F32_AVX512_DEF(c_float_2p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 16-31] + SWISH_F32_AVX512_DEF(c_float_2p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[3, 0-15] + SWISH_F32_AVX512_DEF(c_float_3p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[3, 16-31] + SWISH_F32_AVX512_DEF(c_float_3p1, selector1, al_in, r, r2, z, dn, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_4x32_DISABLE: + ; + + if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_last_k == TRUE ) ) + { + // Generate a mask16 of all 1's. + __m512i selector_a = _mm512_setzero_epi32(); + __m512i selector_b = _mm512_set1_epi32( 10 ); + __mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b ); + + // Store the results in downscaled type (int8 instead of int32). + // c[0,0-15] + CVT_STORE_F32_BF16_MASK(c_float_0p0,0,0); + + // c[0, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_0p1,0,1); + + // c[1,0-15] + CVT_STORE_F32_BF16_MASK(c_float_1p0,1,0); + + // c[1, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_1p1,1,1); + + // c[2,0-15] + CVT_STORE_F32_BF16_MASK(c_float_2p0,2,0); + + // c[2, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_2p1,2,1); + + // c[3,0-15] + CVT_STORE_F32_BF16_MASK(c_float_3p0,3,0); + + // c[3, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_3p1,3,1); + } + + else + { + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); + + // c[0, 16-31] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 1*16 ), c_float_0p1 ); + + // c[1,0-15] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 0*16 ), c_float_1p0 ); + + // c[1,16-31] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 1*16 ), c_float_1p1 ); + + // c[2,0-15] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 0*16 ), c_float_2p0 ); + + // c[2,16-31] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 1*16 ), c_float_2p1 ); + + // c[3,0-15] + _mm512_storeu_ps( c + ( rs_c * 3 ) + ( 0*16 ), c_float_3p0 ); + + // c[3,16-31] + _mm512_storeu_ps( c + ( rs_c * 3 ) + ( 1*16 ), c_float_3p1 ); + } +} + +// 3x32 bf16 kernel +LPGEMM_MN_FRINGE_KERN(bfloat16, int8_t, float, bf16s4f32of32_3x32) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_3x32_DISABLE, + &&POST_OPS_BIAS_3x32, + &&POST_OPS_RELU_3x32, + &&POST_OPS_RELU_SCALE_3x32, + &&POST_OPS_GELU_TANH_3x32, + &&POST_OPS_GELU_ERF_3x32, + &&POST_OPS_CLIP_3x32, + &&POST_OPS_DOWNSCALE_3x32, + &&POST_OPS_MATRIX_ADD_3x32, + &&POST_OPS_SWISH_3x32, + &&POST_OPS_MATRIX_MUL_3x32 + }; + + dim_t pre_op_off = post_ops_attr.pre_op_off; + + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int16_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + __m512bh b1; + + __m256i b0_s4; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + __m512i shift_idx_64; + MULTISHIFT_32BIT_8_INT4_IDX_64ELEM(shift_idx_64); + __m512i sign_comp = _mm512_set1_epi8(0x08); + + bool signed_upscale = true; + + /* regs to store intermediate int8 values */ + __m512i b0_s8; + + /* Regs to store F32 scale values */ + __m512 scale0, scale1, scale2, scale3; + /* Reg to store masks to interleave scale factor */ + __m512i mask_scale1, mask_scale2; + + mask_scale1 = _mm512_set_epi32( 0x17, 0x07, 0x16, 0x06, 0x15, 0x05, 0x14, + 0x04, 0x13, 0x03, 0x12, 0x02, 0x11, 0x01, + 0x10, 0x00 ); + + mask_scale2 = _mm512_set_epi32( 0x1F, 0x0F, 0x1E, 0x0E, 0x1D, 0x0D, 0x1C, + 0x0C, 0x1B, 0x0B, 0x1A, 0x0A, 0x19, 0x09, + 0x18, 0x08); + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + __m512 c_float_0p1 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + __m512 c_float_1p1 = _mm512_setzero_ps(); + + __m512 c_float_2p0 = _mm512_setzero_ps(); + __m512 c_float_2p1 = _mm512_setzero_ps(); + + if( post_ops_attr.pre_op_scale_factor_len > 1 ) + { + // load and interleave scale factor vectors + scale0 = _mm512_loadu_ps( (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off); + scale2 = _mm512_loadu_ps( (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off + 16 ); + + scale1 = _mm512_permutex2var_ps( scale0, mask_scale2, scale0 ); + scale0 = _mm512_permutex2var_ps( scale0, mask_scale1, scale0 ); + scale3 = _mm512_permutex2var_ps( scale2, mask_scale2, scale2 ); + scale2 = _mm512_permutex2var_ps( scale2, mask_scale1, scale2 ); + } + else + { + scale0 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale1 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale2 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale3 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + } + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0_s4 = _mm256_loadu_si256( (__m256i const *)( b + ( rs_b * kr ) / 2 ) ); + + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_64, \ + sign_comp, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_16( b0_s8, 0, scale0 ) ); + + b1 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 3, scale3 ), + CVT_INT8_F32_SCAL_16( b0_s8, 2, scale2 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-31] = a[1,kr:kr+2]*b[kr:kr+2,0-31] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); + + // Broadcast a[2,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-31] = a[2,kr:kr+2]*b[kr:kr+2,0-31] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0_s4 = _mm256_loadu_si256( (__m256i const *)( b + ( rs_b * k_full_pieces ) / 2 ) ); + + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_64, \ + sign_comp, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_16( b0_s8, 0, scale0 ) ); + + b1 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 3, scale3 ), + CVT_INT8_F32_SCAL_16( b0_s8, 2, scale2 ) ); + + // Broadcast a[0,kr:kr+2]. + a_kfringe_buf = *( a + (rs_a * 0) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + + // Broadcast a[1,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 1) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-31] = a[1,kr:kr+2]*b[kr:kr+2,0-31] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); + + // Broadcast a[2,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 2) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-31] = a[2,kr:kr+2]*b[kr:kr+2,0-31] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + if ( alpha != 1 ) + { + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + c_float_0p1 = _mm512_mul_ps( selector1, c_float_0p1 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + c_float_1p1 = _mm512_mul_ps( selector1, c_float_1p1 ); + + c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); + c_float_2p1 = _mm512_mul_ps( selector1, c_float_2p1 ); + } + + // Scale C by beta. + if ( beta != 0 ) + { + if ( ( post_ops_attr.buf_downscale != NULL ) && + ( post_ops_attr.is_first_k == TRUE ) ) + { + + // c[0,0-15] + BF16_F32_BETA_OP( c_float_0p0, 0, 0, 0, selector1, selector2 ); + + // c[0, 16-31] + BF16_F32_BETA_OP( c_float_0p1, 0, 0, 1, selector1, selector2 ); + + // c[1,0-15] + BF16_F32_BETA_OP( c_float_1p0, 0, 1, 0, selector1, selector2 ); + + // c[1, 16-31] + BF16_F32_BETA_OP( c_float_1p1, 0, 1, 1, selector1, selector2 ); + + // c[2,0-15] + BF16_F32_BETA_OP( c_float_2p0, 0, 2, 0, selector1, selector2 ); + + // c[2, 16-31] + BF16_F32_BETA_OP( c_float_2p1, 0, 2, 1, selector1, selector2 ); + } + else + { + // c[0,0-15] + F32_F32_BETA_OP( c_float_0p0, 0, 0, 0, selector1, selector2 ); + + // c[0, 16-31] + F32_F32_BETA_OP( c_float_0p1, 0, 0, 1, selector1, selector2 ); + + // c[1,0-15] + F32_F32_BETA_OP( c_float_1p0, 0, 1, 0, selector1, selector2 ); + + // c[1, 16-31] + F32_F32_BETA_OP( c_float_1p1, 0, 1, 1, selector1, selector2 ); + + // c[2,0-15] + F32_F32_BETA_OP( c_float_2p0, 0, 2, 0, selector1, selector2 ); + + // c[2, 16-31] + F32_F32_BETA_OP( c_float_2p1, 0, 2, 1, selector1, selector2 ); + } + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_3x32: + { + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_LOAD(selector1, bias_mask, 0); + BF16_F32_BIAS_LOAD(selector2, bias_mask, 1); + } + else + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); + } + else + { + __m512 selector3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + BF16_F32_BIAS_BCAST(selector2, bias_mask, 1); + BF16_F32_BIAS_BCAST(selector3, bias_mask, 2); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 2 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector3, c_float_2p1 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_3x32: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + c_float_1p1 = _mm512_max_ps( selector1, c_float_1p1 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + c_float_2p1 = _mm512_max_ps( selector1, c_float_2p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_3x32: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_1p1) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_2p1) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_TANH_3x32: + { + __m512 dn, z, x, r2, r, x_tanh; + __m512i q; + + // c[0, 0-15] + GELU_TANH_F32_AVX512(c_float_0p0, r, r2, x, z, dn, x_tanh, q) + + // c[0, 16-31] + GELU_TANH_F32_AVX512(c_float_0p1, r, r2, x, z, dn, x_tanh, q) + + // c[1, 0-15] + GELU_TANH_F32_AVX512(c_float_1p0, r, r2, x, z, dn, x_tanh, q) + + // c[1, 16-31] + GELU_TANH_F32_AVX512(c_float_1p1, r, r2, x, z, dn, x_tanh, q) + + // c[2, 0-15] + GELU_TANH_F32_AVX512(c_float_2p0, r, r2, x, z, dn, x_tanh, q) + + // c[2, 16-31] + GELU_TANH_F32_AVX512(c_float_2p1, r, r2, x, z, dn, x_tanh, q) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_ERF_3x32: + { + __m512 x, r, x_erf; + + // c[0, 0-15] + GELU_ERF_F32_AVX512(c_float_0p0, r, x, x_erf) + + // c[0, 16-31] + GELU_ERF_F32_AVX512(c_float_0p1, r, x, x_erf) + + // c[1, 0-15] + GELU_ERF_F32_AVX512(c_float_1p0, r, x, x_erf) + + // c[1, 16-31] + GELU_ERF_F32_AVX512(c_float_1p1, r, x, x_erf) + + // c[2, 0-15] + GELU_ERF_F32_AVX512(c_float_2p0, r, x, x_erf) + + // c[2, 16-31] + GELU_ERF_F32_AVX512(c_float_2p1, r, x, x_erf) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_CLIP_3x32: + { + __m512 min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 ); + __m512 max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 ); + + // c[0, 0-15] + CLIP_F32_AVX512(c_float_0p0, min, max) + + // c[0, 16-31] + CLIP_F32_AVX512(c_float_0p1, min, max) + + // c[1, 0-15] + CLIP_F32_AVX512(c_float_1p0, min, max) + + // c[1, 16-31] + CLIP_F32_AVX512(c_float_1p1, min, max) + + // c[2, 0-15] + CLIP_F32_AVX512(c_float_2p0, min, max) + + // c[2, 16-31] + CLIP_F32_AVX512(c_float_2p1, min, max) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + +POST_OPS_DOWNSCALE_3x32: + { + __m512 selector3 = _mm512_setzero_ps(); + + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + __m512 zero_point2 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector2,zero_point1); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector1,zero_point0); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector1,zero_point0); + + // c[2, 16-31] + SCL_MULRND_F32(c_float_2p1,selector2,zero_point1); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 2 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 2 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector2,zero_point1); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector3,zero_point2); + + // c[2, 16-31] + SCL_MULRND_F32(c_float_2p1,selector3,zero_point2); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_3x32: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + BF16_F32_MATRIX_ADD_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + BF16_F32_MATRIX_ADD_2COL(selector1,selector2,1); + + // c[2:0-15,16-31] + BF16_F32_MATRIX_ADD_2COL(selector1,selector2,2); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(selector1,selector2,1); + + // c[2:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(selector1,selector2,2); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_3x32: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + BF16_F32_MATRIX_MUL_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + BF16_F32_MATRIX_MUL_2COL(selector1,selector2,1); + + // c[2:0-15,16-31] + BF16_F32_MATRIX_MUL_2COL(selector1,selector2,2); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(selector1,selector2,1); + + // c[2:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(selector1,selector2,2); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_3x32: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[0, 16-31] + SWISH_F32_AVX512_DEF(c_float_0p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 0-15] + SWISH_F32_AVX512_DEF(c_float_1p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 16-31] + SWISH_F32_AVX512_DEF(c_float_1p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 0-15] + SWISH_F32_AVX512_DEF(c_float_2p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 16-31] + SWISH_F32_AVX512_DEF(c_float_2p1, selector1, al_in, r, r2, z, dn, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_3x32_DISABLE: + ; + + if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_last_k == TRUE ) ) + { + // Generate a mask16 of all 1's. + __m512i selector_a = _mm512_setzero_epi32(); + __m512i selector_b = _mm512_set1_epi32( 10 ); + __mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b ); + + // Store the results in downscaled type (int8 instead of int32). + // c[0,0-15] + CVT_STORE_F32_BF16_MASK(c_float_0p0,0,0); + + // c[0, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_0p1,0,1); + + // c[1,0-15] + CVT_STORE_F32_BF16_MASK(c_float_1p0,1,0); + + // c[1, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_1p1,1,1); + + // c[2,0-15] + CVT_STORE_F32_BF16_MASK(c_float_2p0,2,0); + + // c[2, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_2p1,2,1); + } + + else + { + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); + + // c[0, 16-31] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 1*16 ), c_float_0p1 ); + + // c[1,0-15] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 0*16 ), c_float_1p0 ); + + // c[1,16-31] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 1*16 ), c_float_1p1 ); + + // c[2,0-15] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 0*16 ), c_float_2p0 ); + + // c[2,16-31] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 1*16 ), c_float_2p1 ); + } +} + +// 2x32 bf16 kernel +LPGEMM_MN_FRINGE_KERN(bfloat16, int8_t, float, bf16s4f32of32_2x32) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_2x32_DISABLE, + &&POST_OPS_BIAS_2x32, + &&POST_OPS_RELU_2x32, + &&POST_OPS_RELU_SCALE_2x32, + &&POST_OPS_GELU_TANH_2x32, + &&POST_OPS_GELU_ERF_2x32, + &&POST_OPS_CLIP_2x32, + &&POST_OPS_DOWNSCALE_2x32, + &&POST_OPS_MATRIX_ADD_2x32, + &&POST_OPS_SWISH_2x32, + &&POST_OPS_MATRIX_MUL_2x32 + }; + + dim_t pre_op_off = post_ops_attr.pre_op_off; + + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int16_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + __m512bh b1; + + __m256i b0_s4; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + __m512i shift_idx_64; + MULTISHIFT_32BIT_8_INT4_IDX_64ELEM(shift_idx_64); + __m512i sign_comp = _mm512_set1_epi8(0x08); + + bool signed_upscale = true; + + /* regs to store intermediate int8 values */ + __m512i b0_s8; + + /* Regs to store F32 scale values */ + __m512 scale0, scale1, scale2, scale3; + /* Reg to store masks to interleave scale factor */ + __m512i mask_scale1, mask_scale2; + + mask_scale1 = _mm512_set_epi32( 0x17, 0x07, 0x16, 0x06, 0x15, 0x05, 0x14, + 0x04, 0x13, 0x03, 0x12, 0x02, 0x11, 0x01, + 0x10, 0x00 ); + + mask_scale2 = _mm512_set_epi32( 0x1F, 0x0F, 0x1E, 0x0E, 0x1D, 0x0D, 0x1C, + 0x0C, 0x1B, 0x0B, 0x1A, 0x0A, 0x19, 0x09, + 0x18, 0x08); + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + __m512 c_float_0p1 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + __m512 c_float_1p1 = _mm512_setzero_ps(); + + if( post_ops_attr.pre_op_scale_factor_len > 1 ) + { + // load and interleave scale factor vectors + scale0 = _mm512_loadu_ps( (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off); + scale2 = _mm512_loadu_ps( (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off + 16 ); + + scale1 = _mm512_permutex2var_ps( scale0, mask_scale2, scale0 ); + scale0 = _mm512_permutex2var_ps( scale0, mask_scale1, scale0 ); + scale3 = _mm512_permutex2var_ps( scale2, mask_scale2, scale2 ); + scale2 = _mm512_permutex2var_ps( scale2, mask_scale1, scale2 ); + } + else + { + scale0 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale1 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale2 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale3 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + } + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0_s4 = _mm256_loadu_si256( (__m256i const *)( b + ( rs_b * kr ) / 2 ) ); + + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_64, \ + sign_comp, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_16( b0_s8, 0, scale0 ) ); + + b1 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 3, scale3 ), + CVT_INT8_F32_SCAL_16( b0_s8, 2, scale2 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-31] = a[1,kr:kr+2]*b[kr:kr+2,0-31] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0_s4 = _mm256_loadu_si256( (__m256i const *)( b + ( rs_b * k_full_pieces ) / 2 ) ); + + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_64, \ + sign_comp, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_16( b0_s8, 0, scale0 ) ); + + b1 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 3, scale3 ), + CVT_INT8_F32_SCAL_16( b0_s8, 2, scale2 ) ); + + // Broadcast a[0,kr:kr+2]. + a_kfringe_buf = *( a + (rs_a * 0) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + + // Broadcast a[1,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 1) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-31] = a[1,kr:kr+2]*b[kr:kr+2,0-31] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + if ( alpha != 1 ) + { + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + c_float_0p1 = _mm512_mul_ps( selector1, c_float_0p1 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + c_float_1p1 = _mm512_mul_ps( selector1, c_float_1p1 ); + } + + // Scale C by beta. + if ( beta != 0 ) + { + if ( ( post_ops_attr.buf_downscale != NULL ) && + ( post_ops_attr.is_first_k == TRUE ) ) + { + + // c[0,0-15] + BF16_F32_BETA_OP( c_float_0p0, 0, 0, 0, selector1, selector2 ); + + // c[0, 16-31] + BF16_F32_BETA_OP( c_float_0p1, 0, 0, 1, selector1, selector2 ); + + // c[1,0-15] + BF16_F32_BETA_OP( c_float_1p0, 0, 1, 0, selector1, selector2 ); + + // c[1, 16-31] + BF16_F32_BETA_OP( c_float_1p1, 0, 1, 1, selector1, selector2 ); + } + else + { + // c[0,0-15] + F32_F32_BETA_OP( c_float_0p0, 0, 0, 0, selector1, selector2 ); + + // c[0, 16-31] + F32_F32_BETA_OP( c_float_0p1, 0, 0, 1, selector1, selector2 ); + + // c[1,0-15] + F32_F32_BETA_OP( c_float_1p0, 0, 1, 0, selector1, selector2 ); + + // c[1, 16-31] + F32_F32_BETA_OP( c_float_1p1, 0, 1, 1, selector1, selector2 ); + } + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_2x32: + { + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_LOAD(selector1, bias_mask, 0); + BF16_F32_BIAS_LOAD(selector2, bias_mask, 1); + } + else + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + } + else + { + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + BF16_F32_BIAS_BCAST(selector2, bias_mask, 1); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 1 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_2x32: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + c_float_1p1 = _mm512_max_ps( selector1, c_float_1p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_2x32: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_1p1) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_TANH_2x32: + { + __m512 dn, z, x, r2, r, x_tanh; + __m512i q; + + // c[0, 0-15] + GELU_TANH_F32_AVX512(c_float_0p0, r, r2, x, z, dn, x_tanh, q) + + // c[0, 16-31] + GELU_TANH_F32_AVX512(c_float_0p1, r, r2, x, z, dn, x_tanh, q) + + // c[1, 0-15] + GELU_TANH_F32_AVX512(c_float_1p0, r, r2, x, z, dn, x_tanh, q) + + // c[1, 16-31] + GELU_TANH_F32_AVX512(c_float_1p1, r, r2, x, z, dn, x_tanh, q) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_ERF_2x32: + { + __m512 x, r, x_erf; + + // c[0, 0-15] + GELU_ERF_F32_AVX512(c_float_0p0, r, x, x_erf) + + // c[0, 16-31] + GELU_ERF_F32_AVX512(c_float_0p1, r, x, x_erf) + + // c[1, 0-15] + GELU_ERF_F32_AVX512(c_float_1p0, r, x, x_erf) + + // c[1, 16-31] + GELU_ERF_F32_AVX512(c_float_1p1, r, x, x_erf) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_CLIP_2x32: + { + __m512 min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 ); + __m512 max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 ); + + // c[0, 0-15] + CLIP_F32_AVX512(c_float_0p0, min, max) + + // c[0, 16-31] + CLIP_F32_AVX512(c_float_0p1, min, max) + + // c[1, 0-15] + CLIP_F32_AVX512(c_float_1p0, min, max) + + // c[1, 16-31] + CLIP_F32_AVX512(c_float_1p1, min, max) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + +POST_OPS_DOWNSCALE_2x32: + { + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector2,zero_point1); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector1,zero_point0); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 1 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 1 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector2,zero_point1); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_2x32: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + BF16_F32_MATRIX_ADD_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + BF16_F32_MATRIX_ADD_2COL(selector1,selector2,1); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(selector1,selector2,1); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_2x32: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + BF16_F32_MATRIX_MUL_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + BF16_F32_MATRIX_MUL_2COL(selector1,selector2,1); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(selector1,selector2,1); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_2x32: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[0, 16-31] + SWISH_F32_AVX512_DEF(c_float_0p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 0-15] + SWISH_F32_AVX512_DEF(c_float_1p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 16-31] + SWISH_F32_AVX512_DEF(c_float_1p1, selector1, al_in, r, r2, z, dn, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_2x32_DISABLE: + ; + + if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_last_k == TRUE ) ) + { + // Generate a mask16 of all 1's. + __m512i selector_a = _mm512_setzero_epi32(); + __m512i selector_b = _mm512_set1_epi32( 10 ); + __mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b ); + + // Store the results in downscaled type (int8 instead of int32). + // c[0,0-15] + CVT_STORE_F32_BF16_MASK(c_float_0p0,0,0); + + // c[0, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_0p1,0,1); + + // c[1,0-15] + CVT_STORE_F32_BF16_MASK(c_float_1p0,1,0); + + // c[1, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_1p1,1,1); + } + else + { + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); + + // c[0, 16-31] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 1*16 ), c_float_0p1 ); + + // c[1,0-15] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 0*16 ), c_float_1p0 ); + + // c[1,16-31] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 1*16 ), c_float_1p1 ); + } +} + +// 1x32 bf16 kernel +LPGEMM_MN_FRINGE_KERN(bfloat16, int8_t, float, bf16s4f32of32_1x32) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_1x32_DISABLE, + &&POST_OPS_BIAS_1x32, + &&POST_OPS_RELU_1x32, + &&POST_OPS_RELU_SCALE_1x32, + &&POST_OPS_GELU_TANH_1x32, + &&POST_OPS_GELU_ERF_1x32, + &&POST_OPS_CLIP_1x32, + &&POST_OPS_DOWNSCALE_1x32, + &&POST_OPS_MATRIX_ADD_1x32, + &&POST_OPS_SWISH_1x32, + &&POST_OPS_MATRIX_MUL_1x32 + }; + + dim_t pre_op_off = post_ops_attr.pre_op_off; + + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int16_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + __m512bh b1; + + __m256i b0_s4; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + __m512i shift_idx_64; + MULTISHIFT_32BIT_8_INT4_IDX_64ELEM(shift_idx_64); + __m512i sign_comp = _mm512_set1_epi8(0x08); + + bool signed_upscale = true; + + /* regs to store intermediate int8 values */ + __m512i b0_s8; + + /* Regs to store F32 scale values */ + __m512 scale0, scale1, scale2, scale3; + /* Reg to store masks to interleave scale factor */ + __m512i mask_scale1, mask_scale2; + + mask_scale1 = _mm512_set_epi32( 0x17, 0x07, 0x16, 0x06, 0x15, 0x05, 0x14, + 0x04, 0x13, 0x03, 0x12, 0x02, 0x11, 0x01, + 0x10, 0x00 ); + + mask_scale2 = _mm512_set_epi32( 0x1F, 0x0F, 0x1E, 0x0E, 0x1D, 0x0D, 0x1C, + 0x0C, 0x1B, 0x0B, 0x1A, 0x0A, 0x19, 0x09, + 0x18, 0x08); + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + __m512 c_float_0p1 = _mm512_setzero_ps(); + + if( post_ops_attr.pre_op_scale_factor_len > 1 ) + { + // load and interleave scale factor vectors + scale0 = _mm512_loadu_ps( (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off); + scale2 = _mm512_loadu_ps( (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off + 16 ); + + scale1 = _mm512_permutex2var_ps( scale0, mask_scale2, scale0 ); + scale0 = _mm512_permutex2var_ps( scale0, mask_scale1, scale0 ); + scale3 = _mm512_permutex2var_ps( scale2, mask_scale2, scale2 ); + scale2 = _mm512_permutex2var_ps( scale2, mask_scale1, scale2 ); + } + else + { + scale0 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale1 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale2 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale3 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + } + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0_s4 = _mm256_loadu_si256( (__m256i const *)( b + ( rs_b * kr ) / 2 ) ); + + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_64, \ + sign_comp, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_16( b0_s8, 0, scale0 ) ); + + b1 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 3, scale3 ), + CVT_INT8_F32_SCAL_16( b0_s8, 2, scale2 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0_s4 = _mm256_loadu_si256( (__m256i const *)( b + ( rs_b * k_full_pieces ) / 2 ) ); + + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_64, \ + sign_comp, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_16( b0_s8, 0, scale0 ) ); + + b1 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 3, scale3 ), + CVT_INT8_F32_SCAL_16( b0_s8, 2, scale2 ) ); + + // Broadcast a[0,kr:kr+2]. + a_kfringe_buf = *( a + (rs_a * 0) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + if ( alpha != 1 ) + { + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + c_float_0p1 = _mm512_mul_ps( selector1, c_float_0p1 ); + } + + // Scale C by beta. + if ( beta != 0 ) + { + if ( ( post_ops_attr.buf_downscale != NULL ) && + ( post_ops_attr.is_first_k == TRUE ) ) + { + + // c[0,0-15] + BF16_F32_BETA_OP( c_float_0p0, 0, 0, 0, selector1, selector2 ); + + // c[0, 16-31] + BF16_F32_BETA_OP( c_float_0p1, 0, 0, 1, selector1, selector2 ); + } + else + { + // c[0,0-15] + F32_F32_BETA_OP( c_float_0p0, 0, 0, 0, selector1, selector2 ); + + // c[0, 16-31] + F32_F32_BETA_OP( c_float_0p1, 0, 0, 1, selector1, selector2 ); + } + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_1x32: + { + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_LOAD(selector1, bias_mask, 0); + BF16_F32_BIAS_LOAD(selector2, bias_mask, 1); + } + else + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + } + else + { + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_1x32: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_1x32: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_TANH_1x32: + { + __m512 dn, z, x, r2, r, x_tanh; + __m512i q; + + // c[0, 0-15] + GELU_TANH_F32_AVX512(c_float_0p0, r, r2, x, z, dn, x_tanh, q) + + // c[0, 16-31] + GELU_TANH_F32_AVX512(c_float_0p1, r, r2, x, z, dn, x_tanh, q) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_ERF_1x32: + { + __m512 x, r, x_erf; + + // c[0, 0-15] + GELU_ERF_F32_AVX512(c_float_0p0, r, x, x_erf) + + // c[0, 16-31] + GELU_ERF_F32_AVX512(c_float_0p1, r, x, x_erf) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_CLIP_1x32: + { + __m512 min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 ); + __m512 max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 ); + + // c[0, 0-15] + CLIP_F32_AVX512(c_float_0p0, min, max) + + // c[0, 16-31] + CLIP_F32_AVX512(c_float_0p1, min, max) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + +POST_OPS_DOWNSCALE_1x32: + { + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector2,zero_point1); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector1,zero_point0); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_1x32: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + BF16_F32_MATRIX_ADD_2COL(selector1,selector2,0); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(selector1,selector2,0); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_1x32: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + BF16_F32_MATRIX_MUL_2COL(selector1,selector2,0); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(selector1,selector2,0); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_1x32: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[0, 16-31] + SWISH_F32_AVX512_DEF(c_float_0p1, selector1, al_in, r, r2, z, dn, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_1x32_DISABLE: + ; + + if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_last_k == TRUE ) ) + { + // Generate a mask16 of all 1's. + __m512i selector_a = _mm512_setzero_epi32(); + __m512i selector_b = _mm512_set1_epi32( 10 ); + __mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b ); + + // Store the results in downscaled type (int8 instead of int32). + // c[0,0-15] + CVT_STORE_F32_BF16_MASK(c_float_0p0,0,0); + + // c[0, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_0p1,0,1); + } + + else + { + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); + + // c[0, 16-31] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 1*16 ), c_float_0p1 ); + } +} + +// 5x48 bf16 kernel +LPGEMM_MN_FRINGE_KERN(bfloat16, int8_t, float, bf16s4f32of32_5x48) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_5x48_DISABLE, + &&POST_OPS_BIAS_5x48, + &&POST_OPS_RELU_5x48, + &&POST_OPS_RELU_SCALE_5x48, + &&POST_OPS_GELU_TANH_5x48, + &&POST_OPS_GELU_ERF_5x48, + &&POST_OPS_CLIP_5x48, + &&POST_OPS_DOWNSCALE_5x48, + &&POST_OPS_MATRIX_ADD_5x48, + &&POST_OPS_SWISH_5x48, + &&POST_OPS_MATRIX_MUL_5x48 + }; + + dim_t pre_op_off = post_ops_attr.pre_op_off; + + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int16_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + __m512bh b1; + __m512bh b2; + + __m256i b0_s4; + __m128i b1_s4; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + __m512i shift_idx_64; + MULTISHIFT_32BIT_8_INT4_IDX_64ELEM(shift_idx_64); + __m512i sign_comp = _mm512_set1_epi8(0x08); + + __m256i shift_idx_32; + MULTISHIFT_32BIT_8_INT4_IDX_32ELEM(shift_idx_32); + __m256i sign_comp_32 = _mm256_set1_epi8( 0x08 ); + + bool signed_upscale = true; + + /* regs to store intermediate int8 values */ + __m512i b0_s8; + __m256i b1_s8; + + /* Regs to store F32 scale values */ + __m512 scale0, scale1, scale2, scale3, scale4, scale5; + /* Reg to store masks to interleave scale factor */ + __m512i mask_scale1, mask_scale2; + + mask_scale1 = _mm512_set_epi32( 0x17, 0x07, 0x16, 0x06, 0x15, 0x05, 0x14, + 0x04, 0x13, 0x03, 0x12, 0x02, 0x11, 0x01, + 0x10, 0x00 ); + + mask_scale2 = _mm512_set_epi32( 0x1F, 0x0F, 0x1E, 0x0E, 0x1D, 0x0D, 0x1C, + 0x0C, 0x1B, 0x0B, 0x1A, 0x0A, 0x19, 0x09, + 0x18, 0x08); + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + __m512 c_float_0p1 = _mm512_setzero_ps(); + __m512 c_float_0p2 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + __m512 c_float_1p1 = _mm512_setzero_ps(); + __m512 c_float_1p2 = _mm512_setzero_ps(); + + __m512 c_float_2p0 = _mm512_setzero_ps(); + __m512 c_float_2p1 = _mm512_setzero_ps(); + __m512 c_float_2p2 = _mm512_setzero_ps(); + + __m512 c_float_3p0 = _mm512_setzero_ps(); + __m512 c_float_3p1 = _mm512_setzero_ps(); + __m512 c_float_3p2 = _mm512_setzero_ps(); + + __m512 c_float_4p0 = _mm512_setzero_ps(); + __m512 c_float_4p1 = _mm512_setzero_ps(); + __m512 c_float_4p2 = _mm512_setzero_ps(); + + if( post_ops_attr.pre_op_scale_factor_len > 1 ) + { + // load and interleave scale factor vectors + scale0 = _mm512_loadu_ps( (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off); + scale2 = _mm512_loadu_ps( (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off + 16 ); + scale4 = _mm512_loadu_ps( (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off + 32 ); + + scale1 = _mm512_permutex2var_ps( scale0, mask_scale2, scale0 ); + scale0 = _mm512_permutex2var_ps( scale0, mask_scale1, scale0 ); + scale3 = _mm512_permutex2var_ps( scale2, mask_scale2, scale2 ); + scale2 = _mm512_permutex2var_ps( scale2, mask_scale1, scale2 ); + scale5 = _mm512_permutex2var_ps( scale4, mask_scale2, scale4 ); + scale4 = _mm512_permutex2var_ps( scale4, mask_scale1, scale4 ); + + } + else + { + scale0 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale1 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale2 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale3 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale4 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale5 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + } + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0_s4 = _mm256_loadu_si256( (__m256i const *)( b + ( rs_b * kr ) / 2 ) ); + + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_64, \ + sign_comp, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_16( b0_s8, 0, scale0 ) ); + + b1 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 3, scale3 ), + CVT_INT8_F32_SCAL_16( b0_s8, 2, scale2 ) ); + + b1_s4 = _mm_loadu_si128( (__m128i const *)( b + ( ( rs_b * kr ) / 2 ) + 32 ) ); + + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT( b1_s4, b1_s8, shift_idx_32, \ + sign_comp_32, signed_upscale); + + b2 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_8( b1_s8, 1, scale5 ), + CVT_INT8_F32_SCAL_8( b1_s8, 0, scale4 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-47] = a[0,kr:kr+2]*b[kr:kr+2,0-47] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-47] = a[1,kr:kr+2]*b[kr:kr+2,0-47] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); + c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_0, b2 ); + + // Broadcast a[2,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-47] = a[2,kr:kr+2]*b[kr:kr+2,0-47] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); + + // Broadcast a[3,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-47] = a[3,kr:kr+2]*b[kr:kr+2,0-47] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_0, b1 ); + c_float_3p2 = _mm512_dpbf16_ps( c_float_3p2, a_bf16_0, b2 ); + + // Broadcast a[4,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[4,0-47] = a[4,kr:kr+2]*b[kr:kr+2,0-47] + c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); + c_float_4p1 = _mm512_dpbf16_ps( c_float_4p1, a_bf16_0, b1 ); + c_float_4p2 = _mm512_dpbf16_ps( c_float_4p2, a_bf16_0, b2 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0_s4 = _mm256_loadu_si256( (__m256i const *)( b + ( rs_b * k_full_pieces ) / 2 ) ); + + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_64, \ + sign_comp, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_16( b0_s8, 0, scale0 ) ); + + b1 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 3, scale3 ), + CVT_INT8_F32_SCAL_16( b0_s8, 2, scale2 ) ); + + b1_s4 = _mm_loadu_si128( (__m128i const *)( b + ( ( rs_b * k_full_pieces ) / 2 ) + 32 ) ); + + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT( b1_s4, b1_s8, shift_idx_32, \ + sign_comp_32, signed_upscale); + + b2 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_8( b1_s8, 1, scale5 ), + CVT_INT8_F32_SCAL_8( b1_s8, 0, scale4 ) ); + + // Broadcast a[0,kr:kr+2]. + a_kfringe_buf = *( a + (rs_a * 0) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-47] = a[0,kr:kr+2]*b[kr:kr+2,0-47] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + + // Broadcast a[1,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 1) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-47] = a[1,kr:kr+2]*b[kr:kr+2,0-47] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); + c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_0, b2 ); + + // Broadcast a[2,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 2) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-47] = a[2,kr:kr+2]*b[kr:kr+2,0-47] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); + + // Broadcast a[3,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 3) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-47] = a[3,kr:kr+2]*b[kr:kr+2,0-47] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_0, b1 ); + c_float_3p2 = _mm512_dpbf16_ps( c_float_3p2, a_bf16_0, b2 ); + + // Broadcast a[4,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 4) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[4,0-47] = a[4,kr:kr+2]*b[kr:kr+2,0-47] + c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); + c_float_4p1 = _mm512_dpbf16_ps( c_float_4p1, a_bf16_0, b1 ); + c_float_4p2 = _mm512_dpbf16_ps( c_float_4p2, a_bf16_0, b2 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + if ( alpha != 1 ) + { + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + c_float_0p1 = _mm512_mul_ps( selector1, c_float_0p1 ); + c_float_0p2 = _mm512_mul_ps( selector1, c_float_0p2 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + c_float_1p1 = _mm512_mul_ps( selector1, c_float_1p1 ); + c_float_1p2 = _mm512_mul_ps( selector1, c_float_1p2 ); + + c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); + c_float_2p1 = _mm512_mul_ps( selector1, c_float_2p1 ); + c_float_2p2 = _mm512_mul_ps( selector1, c_float_2p2 ); + + c_float_3p0 = _mm512_mul_ps( selector1, c_float_3p0 ); + c_float_3p1 = _mm512_mul_ps( selector1, c_float_3p1 ); + c_float_3p2 = _mm512_mul_ps( selector1, c_float_3p2 ); + + c_float_4p0 = _mm512_mul_ps( selector1, c_float_4p0 ); + c_float_4p1 = _mm512_mul_ps( selector1, c_float_4p1 ); + c_float_4p2 = _mm512_mul_ps( selector1, c_float_4p2 ); + } + + // Scale C by beta. + if ( beta != 0 ) + { + if ( ( post_ops_attr.buf_downscale != NULL ) && + ( post_ops_attr.is_first_k == TRUE ) ) + { + // c[0,0-15] + BF16_F32_BETA_OP(c_float_0p0,0,0,0,selector1,selector2) + + // c[0, 16-31] + BF16_F32_BETA_OP(c_float_0p1,0,0,1,selector1,selector2) + + // c[0,32-47] + BF16_F32_BETA_OP(c_float_0p2,0,0,2,selector1,selector2) + + // c[1,0-15] + BF16_F32_BETA_OP(c_float_1p0,0,1,0,selector1,selector2) + + // c[1,16-31] + BF16_F32_BETA_OP(c_float_1p1,0,1,1,selector1,selector2) + + // c[1,32-47] + BF16_F32_BETA_OP(c_float_1p2,0,1,2,selector1,selector2) + + // c[2,0-15] + BF16_F32_BETA_OP(c_float_2p0,0,2,0,selector1,selector2) + + // c[2,16-31] + BF16_F32_BETA_OP(c_float_2p1,0,2,1,selector1,selector2) + + // c[2,32-47] + BF16_F32_BETA_OP(c_float_2p2,0,2,2,selector1,selector2) + + // c[3,0-15] + BF16_F32_BETA_OP(c_float_3p0,0,3,0,selector1,selector2) + + // c[3,16-31] + BF16_F32_BETA_OP(c_float_3p1,0,3,1,selector1,selector2) + + // c[3,32-47] + BF16_F32_BETA_OP(c_float_3p2,0,3,2,selector1,selector2) + + // c[4,0-15] + BF16_F32_BETA_OP(c_float_4p0,0,4,0,selector1,selector2) + + // c[4,16-31] + BF16_F32_BETA_OP(c_float_4p1,0,4,1,selector1,selector2) + + // c[4,32-47] + BF16_F32_BETA_OP(c_float_4p2,0,4,2,selector1,selector2) + } + else + { + // c[0,0-15] + F32_F32_BETA_OP(c_float_0p0,0,0,0,selector1,selector2) + + // c[0, 16-31] + F32_F32_BETA_OP(c_float_0p1,0,0,1,selector1,selector2) + + // c[0,32-47] + F32_F32_BETA_OP(c_float_0p2,0,0,2,selector1,selector2) + + // c[1,0-15] + F32_F32_BETA_OP(c_float_1p0,0,1,0,selector1,selector2) + + // c[1,16-31] + F32_F32_BETA_OP(c_float_1p1,0,1,1,selector1,selector2) + + // c[1,32-47] + F32_F32_BETA_OP(c_float_1p2,0,1,2,selector1,selector2) + + // c[2,0-15] + F32_F32_BETA_OP(c_float_2p0,0,2,0,selector1,selector2) + + // c[2,16-31] + F32_F32_BETA_OP(c_float_2p1,0,2,1,selector1,selector2) + + // c[2,32-47] + F32_F32_BETA_OP(c_float_2p2,0,2,2,selector1,selector2) + + // c[3,0-15] + F32_F32_BETA_OP(c_float_3p0,0,3,0,selector1,selector2) + + // c[3,16-31] + F32_F32_BETA_OP(c_float_3p1,0,3,1,selector1,selector2) + + // c[3,32-47] + F32_F32_BETA_OP(c_float_3p2,0,3,2,selector1,selector2) + + // c[4,0-15] + F32_F32_BETA_OP(c_float_4p0,0,4,0,selector1,selector2) + + // c[4,16-31] + F32_F32_BETA_OP(c_float_4p1,0,4,1,selector1,selector2) + + // c[4,32-47] + F32_F32_BETA_OP(c_float_4p2,0,4,2,selector1,selector2) + } + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_5x48: + { + __m512 selector3; + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_LOAD(selector1, bias_mask, 0); + BF16_F32_BIAS_LOAD(selector2, bias_mask, 1); + BF16_F32_BIAS_LOAD(selector3, bias_mask, 2); + } + else + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector3, c_float_1p2 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector2, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_add_ps( selector3, c_float_3p2 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + + // c[4, 16-31] + c_float_4p1 = _mm512_add_ps( selector2, c_float_4p1 ); + + // c[4,32-47] + c_float_4p2 = _mm512_add_ps( selector3, c_float_4p2 ); + } + else + { + __m512 selector4; + __m512 selector5; + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + BF16_F32_BIAS_BCAST(selector2, bias_mask, 1); + BF16_F32_BIAS_BCAST(selector3, bias_mask, 2); + BF16_F32_BIAS_BCAST(selector4, bias_mask, 3); + BF16_F32_BIAS_BCAST(selector5, bias_mask, 4); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 3 ) ); + selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 4 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector2, c_float_1p2 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector3, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector4, c_float_3p0 ); + + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector4, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_add_ps( selector4, c_float_3p2 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector5, c_float_4p0 ); + + // c[4, 16-31] + c_float_4p1 = _mm512_add_ps( selector5, c_float_4p1 ); + + // c[4,32-47] + c_float_4p2 = _mm512_add_ps( selector5, c_float_4p2 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_5x48: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_max_ps( selector1, c_float_0p2 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + c_float_1p1 = _mm512_max_ps( selector1, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_max_ps( selector1, c_float_1p2 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + c_float_2p1 = _mm512_max_ps( selector1, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_max_ps( selector1, c_float_2p2 ); + + // c[3,0-15] + c_float_3p0 = _mm512_max_ps( selector1, c_float_3p0 ); + + // c[3,16-31] + c_float_3p1 = _mm512_max_ps( selector1, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_max_ps( selector1, c_float_3p2 ); + + // c[4,0-15] + c_float_4p0 = _mm512_max_ps( selector1, c_float_4p0 ); + + // c[4,16-31] + c_float_4p1 = _mm512_max_ps( selector1, c_float_4p1 ); + + // c[4,32-47] + c_float_4p2 = _mm512_max_ps( selector1, c_float_4p2 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_5x48: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_0p2) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_1p1) + + // c[1, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_1p2) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_2p1) + + // c[2, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_2p2) + + // c[3, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_3p0) + + // c[3, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_3p1) + + // c[3, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_3p2) + + // c[4, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_4p0) + + // c[4, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_4p1) + + // c[4, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_4p2) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_TANH_5x48: + { + __m512 dn, z, x, r2, r, x_tanh; + __m512i q; + + // c[0, 0-15] + GELU_TANH_F32_AVX512(c_float_0p0, r, r2, x, z, dn, x_tanh, q) + + // c[0, 16-31] + GELU_TANH_F32_AVX512(c_float_0p1, r, r2, x, z, dn, x_tanh, q) + + // c[0, 32-47] + GELU_TANH_F32_AVX512(c_float_0p2, r, r2, x, z, dn, x_tanh, q) + + // c[1, 0-15] + GELU_TANH_F32_AVX512(c_float_1p0, r, r2, x, z, dn, x_tanh, q) + + // c[1, 16-31] + GELU_TANH_F32_AVX512(c_float_1p1, r, r2, x, z, dn, x_tanh, q) + + // c[1, 32-47] + GELU_TANH_F32_AVX512(c_float_1p2, r, r2, x, z, dn, x_tanh, q) + + // c[2, 0-15] + GELU_TANH_F32_AVX512(c_float_2p0, r, r2, x, z, dn, x_tanh, q) + + // c[2, 16-31] + GELU_TANH_F32_AVX512(c_float_2p1, r, r2, x, z, dn, x_tanh, q) + + // c[2, 32-47] + GELU_TANH_F32_AVX512(c_float_2p2, r, r2, x, z, dn, x_tanh, q) + + // c[3, 0-15] + GELU_TANH_F32_AVX512(c_float_3p0, r, r2, x, z, dn, x_tanh, q) + + // c[3, 16-31] + GELU_TANH_F32_AVX512(c_float_3p1, r, r2, x, z, dn, x_tanh, q) + + // c[3, 32-47] + GELU_TANH_F32_AVX512(c_float_3p2, r, r2, x, z, dn, x_tanh, q) + + // c[4, 0-15] + GELU_TANH_F32_AVX512(c_float_4p0, r, r2, x, z, dn, x_tanh, q) + + // c[4, 16-31] + GELU_TANH_F32_AVX512(c_float_4p1, r, r2, x, z, dn, x_tanh, q) + + // c[4, 32-47] + GELU_TANH_F32_AVX512(c_float_4p2, r, r2, x, z, dn, x_tanh, q) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_ERF_5x48: + { + __m512 x, r, x_erf; + + // c[0, 0-15] + GELU_ERF_F32_AVX512(c_float_0p0, r, x, x_erf) + + // c[0, 16-31] + GELU_ERF_F32_AVX512(c_float_0p1, r, x, x_erf) + + // c[0, 32-47] + GELU_ERF_F32_AVX512(c_float_0p2, r, x, x_erf) + + // c[1, 0-15] + GELU_ERF_F32_AVX512(c_float_1p0, r, x, x_erf) + + // c[1, 16-31] + GELU_ERF_F32_AVX512(c_float_1p1, r, x, x_erf) + + // c[1, 32-47] + GELU_ERF_F32_AVX512(c_float_1p2, r, x, x_erf) + + // c[2, 0-15] + GELU_ERF_F32_AVX512(c_float_2p0, r, x, x_erf) + + // c[2, 16-31] + GELU_ERF_F32_AVX512(c_float_2p1, r, x, x_erf) + + // c[2, 32-47] + GELU_ERF_F32_AVX512(c_float_2p2, r, x, x_erf) + + // c[3, 0-15] + GELU_ERF_F32_AVX512(c_float_3p0, r, x, x_erf) + + // c[3, 16-31] + GELU_ERF_F32_AVX512(c_float_3p1, r, x, x_erf) + + // c[3, 32-47] + GELU_ERF_F32_AVX512(c_float_3p2, r, x, x_erf) + + // c[4, 0-15] + GELU_ERF_F32_AVX512(c_float_4p0, r, x, x_erf) + + // c[4, 16-31] + GELU_ERF_F32_AVX512(c_float_4p1, r, x, x_erf) + + // c[4, 32-47] + GELU_ERF_F32_AVX512(c_float_4p2, r, x, x_erf) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_CLIP_5x48: + { + __m512 min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 ); + __m512 max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 ); + + // c[0, 0-15] + CLIP_F32_AVX512(c_float_0p0, min, max) + + // c[0, 16-31] + CLIP_F32_AVX512(c_float_0p1, min, max) + + // c[0, 32-47] + CLIP_F32_AVX512(c_float_0p2, min, max) + + // c[1, 0-15] + CLIP_F32_AVX512(c_float_1p0, min, max) + + // c[1, 16-31] + CLIP_F32_AVX512(c_float_1p1, min, max) + + // c[1, 32-47] + CLIP_F32_AVX512(c_float_1p2, min, max) + + // c[2, 0-15] + CLIP_F32_AVX512(c_float_2p0, min, max) + + // c[2, 16-31] + CLIP_F32_AVX512(c_float_2p1, min, max) + + // c[2, 32-47] + CLIP_F32_AVX512(c_float_2p2, min, max) + + // c[3, 0-15] + CLIP_F32_AVX512(c_float_3p0, min, max) + + // c[3, 16-31] + CLIP_F32_AVX512(c_float_3p1, min, max) + + // c[3, 32-47] + CLIP_F32_AVX512(c_float_3p2, min, max) + + // c[4, 0-15] + CLIP_F32_AVX512(c_float_4p0, min, max) + + // c[4, 16-31] + CLIP_F32_AVX512(c_float_4p1, min, max) + + // c[4, 32-47] + CLIP_F32_AVX512(c_float_4p2, min, max) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + +POST_OPS_DOWNSCALE_5x48: + { + __m512 selector3 = _mm512_setzero_ps(); + __m512 selector4 = _mm512_setzero_ps(); + + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + __m512 zero_point2 = _mm512_setzero_ps(); + __m512 zero_point3 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector2,zero_point1); + + // c[0, 32-47] + SCL_MULRND_F32(c_float_0p2,selector3,zero_point2); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector1,zero_point0); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[1, 32-47] + SCL_MULRND_F32(c_float_1p2,selector3,zero_point2); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector1,zero_point0); + + // c[2, 16-31] + SCL_MULRND_F32(c_float_2p1,selector2,zero_point1); + + // c[2, 32-47] + SCL_MULRND_F32(c_float_2p2,selector3,zero_point2); + + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector1,zero_point0); + + // c[3, 16-31] + SCL_MULRND_F32(c_float_3p1,selector2,zero_point1); + + // c[3, 32-47] + SCL_MULRND_F32(c_float_3p2,selector3,zero_point2); + + // c[4, 0-15] + SCL_MULRND_F32(c_float_4p0,selector1,zero_point0); + + // c[4, 16-31] + SCL_MULRND_F32(c_float_4p1,selector2,zero_point1); + + // c[4, 32-47] + SCL_MULRND_F32(c_float_4p2,selector3,zero_point2); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 3 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 2 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 3 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector1,zero_point0); + + // c[0, 32-47] + SCL_MULRND_F32(c_float_0p2,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector2,zero_point1); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[1, 32-47] + SCL_MULRND_F32(c_float_1p2,selector2,zero_point1); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector3,zero_point2); + + // c[2, 16-31] + SCL_MULRND_F32(c_float_2p1,selector3,zero_point2); + + // c[2, 32-47] + SCL_MULRND_F32(c_float_2p2,selector3,zero_point2); + + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector4,zero_point3); + + // c[3, 16-31] + SCL_MULRND_F32(c_float_3p1,selector4,zero_point3); + + // c[3, 32-47] + SCL_MULRND_F32(c_float_3p2,selector4,zero_point3); + + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 4 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 4 ) ) ); + } + // c[4, 0-15] + SCL_MULRND_F32(c_float_4p0,selector1,zero_point0); + + // c[4, 16-31] + SCL_MULRND_F32(c_float_4p1,selector1,zero_point0); + + // c[4, 32-47] + SCL_MULRND_F32(c_float_4p2,selector1,zero_point0); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_5x48: + { + __m512 selector3; + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + BF16_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,0); + + // c[1:0-15,16-31,32-47] + BF16_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,1); + + // c[2:0-15,16-31,32-47] + BF16_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,2); + + // c[3:0-15,16-31,32-47] + BF16_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,3); + + // c[4:0-15,16-31,32-47] + BF16_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,4); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,0); + + // c[1:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,1); + + // c[2:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,2); + + // c[3:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,3); + + // c[4:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,4); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_5x48: + { + __m512 selector3; + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + BF16_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,0); + + // c[1:0-15,16-31,32-47] + BF16_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,1); + + // c[2:0-15,16-31,32-47] + BF16_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,2); + + // c[3:0-15,16-31,32-47] + BF16_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,3); + + // c[4:0-15,16-31,32-47] + BF16_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,4); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,0); + + // c[1:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,1); + + // c[2:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,2); + // c[3:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,3); + + // c[4:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,4); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_5x48: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[0, 16-31] + SWISH_F32_AVX512_DEF(c_float_0p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[0, 32-47] + SWISH_F32_AVX512_DEF(c_float_0p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 0-15] + SWISH_F32_AVX512_DEF(c_float_1p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 16-31] + SWISH_F32_AVX512_DEF(c_float_1p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 32-47] + SWISH_F32_AVX512_DEF(c_float_1p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 0-15] + SWISH_F32_AVX512_DEF(c_float_2p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 16-31] + SWISH_F32_AVX512_DEF(c_float_2p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 32-47] + SWISH_F32_AVX512_DEF(c_float_2p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[3, 0-15] + SWISH_F32_AVX512_DEF(c_float_3p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[3, 16-31] + SWISH_F32_AVX512_DEF(c_float_3p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[3, 32-47] + SWISH_F32_AVX512_DEF(c_float_3p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[4, 0-15] + SWISH_F32_AVX512_DEF(c_float_4p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[4, 16-31] + SWISH_F32_AVX512_DEF(c_float_4p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[4, 32-47] + SWISH_F32_AVX512_DEF(c_float_4p2, selector1, al_in, r, r2, z, dn, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_5x48_DISABLE: + ; + + if ( ( post_ops_attr.buf_downscale != NULL ) && + ( post_ops_attr.is_last_k == TRUE ) ) + { + // Generate a mask16 of all 1's. + __m512i selector_a = _mm512_setzero_epi32(); + __m512i selector_b = _mm512_set1_epi32( 10 ); + __mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b ); + + // Store the results in downscaled type (bf16 instead of float). + + // c[0, 0-15] + CVT_STORE_F32_BF16_MASK(c_float_0p0,0,0); + + // c[0, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_0p1,0,1); + + // c[0, 32-47] + CVT_STORE_F32_BF16_MASK(c_float_0p2,0,2); + + // c[1, 0-15] + CVT_STORE_F32_BF16_MASK(c_float_1p0,1,0); + + // c[1, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_1p1,1,1); + + // c[1, 32-47] + CVT_STORE_F32_BF16_MASK(c_float_1p2,1,2); + + // c[2, 0-15] + CVT_STORE_F32_BF16_MASK(c_float_2p0,2,0); + + // c[2, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_2p1,2,1); + + // c[2, 32-47] + CVT_STORE_F32_BF16_MASK(c_float_2p2,2,2); + + // c[3, 0-15] + CVT_STORE_F32_BF16_MASK(c_float_3p0,3,0); + + // c[3, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_3p1,3,1); + + // c[3, 32-47] + CVT_STORE_F32_BF16_MASK(c_float_3p2,3,2); + + // c[4, 0-15] + CVT_STORE_F32_BF16_MASK(c_float_4p0,4,0); + + // c[4, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_4p1,4,1); + + // c[4, 32-47] + CVT_STORE_F32_BF16_MASK(c_float_4p2,4,2); + } + + else + { + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); + + // c[0, 16-31] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 1*16 ), c_float_0p1 ); + + // c[0,32-47] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 2*16 ), c_float_0p2 ); + + // c[1,0-15] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 0*16 ), c_float_1p0 ); + + // c[1,16-31] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 1*16 ), c_float_1p1 ); + + // c[1,32-47] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 2*16 ), c_float_1p2 ); + + // c[2,0-15] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 0*16 ), c_float_2p0 ); + + // c[2,16-31] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 1*16 ), c_float_2p1 ); + + // c[2,32-47] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 2*16 ), c_float_2p2 ); + + // c[3,0-15] + _mm512_storeu_ps( c + ( rs_c * 3 ) + ( 0*16 ), c_float_3p0 ); + + // c[3,16-31] + _mm512_storeu_ps( c + ( rs_c * 3 ) + ( 1*16 ), c_float_3p1 ); + + // c[3,32-47] + _mm512_storeu_ps( c + ( rs_c * 3 ) + ( 2*16 ), c_float_3p2 ); + + // c[4,0-15] + _mm512_storeu_ps( c + ( rs_c * 4 ) + ( 0*16 ), c_float_4p0 ); + + // c[4,16-31] + _mm512_storeu_ps( c + ( rs_c * 4 ) + ( 1*16 ), c_float_4p1 ); + + // c[4,32-47] + _mm512_storeu_ps( c + ( rs_c * 4 ) + ( 2*16 ), c_float_4p2 ); + } +} + +// 4x48 bf16 kernel +LPGEMM_MN_FRINGE_KERN(bfloat16, int8_t, float, bf16s4f32of32_4x48) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_4x48_DISABLE, + &&POST_OPS_BIAS_4x48, + &&POST_OPS_RELU_4x48, + &&POST_OPS_RELU_SCALE_4x48, + &&POST_OPS_GELU_TANH_4x48, + &&POST_OPS_GELU_ERF_4x48, + &&POST_OPS_CLIP_4x48, + &&POST_OPS_DOWNSCALE_4x48, + &&POST_OPS_MATRIX_ADD_4x48, + &&POST_OPS_SWISH_4x48, + &&POST_OPS_MATRIX_MUL_4x48 + }; + + dim_t pre_op_off = post_ops_attr.pre_op_off; + + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int16_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + __m512bh b1; + __m512bh b2; + + __m256i b0_s4; + __m128i b1_s4; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + __m512i shift_idx_64; + MULTISHIFT_32BIT_8_INT4_IDX_64ELEM(shift_idx_64); + __m512i sign_comp = _mm512_set1_epi8(0x08); + + __m256i shift_idx_32; + MULTISHIFT_32BIT_8_INT4_IDX_32ELEM(shift_idx_32); + __m256i sign_comp_32 = _mm256_set1_epi8( 0x08 ); + + bool signed_upscale = true; + + /* regs to store intermediate int8 values */ + __m512i b0_s8; + __m256i b1_s8; + + /* Regs to store F32 scale values */ + __m512 scale0, scale1, scale2, scale3, scale4, scale5; + /* Reg to store masks to interleave scale factor */ + __m512i mask_scale1, mask_scale2; + + mask_scale1 = _mm512_set_epi32( 0x17, 0x07, 0x16, 0x06, 0x15, 0x05, 0x14, + 0x04, 0x13, 0x03, 0x12, 0x02, 0x11, 0x01, + 0x10, 0x00 ); + + mask_scale2 = _mm512_set_epi32( 0x1F, 0x0F, 0x1E, 0x0E, 0x1D, 0x0D, 0x1C, + 0x0C, 0x1B, 0x0B, 0x1A, 0x0A, 0x19, 0x09, + 0x18, 0x08); + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + __m512 c_float_0p1 = _mm512_setzero_ps(); + __m512 c_float_0p2 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + __m512 c_float_1p1 = _mm512_setzero_ps(); + __m512 c_float_1p2 = _mm512_setzero_ps(); + + __m512 c_float_2p0 = _mm512_setzero_ps(); + __m512 c_float_2p1 = _mm512_setzero_ps(); + __m512 c_float_2p2 = _mm512_setzero_ps(); + + __m512 c_float_3p0 = _mm512_setzero_ps(); + __m512 c_float_3p1 = _mm512_setzero_ps(); + __m512 c_float_3p2 = _mm512_setzero_ps(); + + if( post_ops_attr.pre_op_scale_factor_len > 1 ) + { + // load and interleave scale factor vectors + scale0 = _mm512_loadu_ps( (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off); + scale2 = _mm512_loadu_ps( (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off + 16 ); + scale4 = _mm512_loadu_ps( (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off + 32 ); + + scale1 = _mm512_permutex2var_ps( scale0, mask_scale2, scale0 ); + scale0 = _mm512_permutex2var_ps( scale0, mask_scale1, scale0 ); + scale3 = _mm512_permutex2var_ps( scale2, mask_scale2, scale2 ); + scale2 = _mm512_permutex2var_ps( scale2, mask_scale1, scale2 ); + scale5 = _mm512_permutex2var_ps( scale4, mask_scale2, scale4 ); + scale4 = _mm512_permutex2var_ps( scale4, mask_scale1, scale4 ); + + } + else + { + scale0 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale1 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale2 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale3 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale4 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale5 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + } + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0_s4 = _mm256_loadu_si256( (__m256i const *)( b + ( rs_b * kr ) / 2 ) ); + + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_64, \ + sign_comp, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_16( b0_s8, 0, scale0 ) ); + + b1 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 3, scale3 ), + CVT_INT8_F32_SCAL_16( b0_s8, 2, scale2 ) ); + + b1_s4 = _mm_loadu_si128( (__m128i const *)( b + ( ( rs_b * kr ) / 2 ) + 32 ) ); + + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT( b1_s4, b1_s8, shift_idx_32, \ + sign_comp_32, signed_upscale); + + b2 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_8( b1_s8, 1, scale5 ), + CVT_INT8_F32_SCAL_8( b1_s8, 0, scale4 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-47] = a[0,kr:kr+2]*b[kr:kr+2,0-47] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-47] = a[1,kr:kr+2]*b[kr:kr+2,0-47] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); + c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_0, b2 ); + + // Broadcast a[2,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-47] = a[2,kr:kr+2]*b[kr:kr+2,0-47] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); + + // Broadcast a[3,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-47] = a[3,kr:kr+2]*b[kr:kr+2,0-47] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_0, b1 ); + c_float_3p2 = _mm512_dpbf16_ps( c_float_3p2, a_bf16_0, b2 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0_s4 = _mm256_loadu_si256( (__m256i const *)( b + ( rs_b * k_full_pieces ) / 2 ) ); + + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_64, \ + sign_comp, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_16( b0_s8, 0, scale0 ) ); + + b1 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 3, scale3 ), + CVT_INT8_F32_SCAL_16( b0_s8, 2, scale2 ) ); + + b1_s4 = _mm_loadu_si128( (__m128i const *)( b + ( ( rs_b * k_full_pieces ) / 2 ) + 32 ) ); + + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT( b1_s4, b1_s8, shift_idx_32, \ + sign_comp_32, signed_upscale); + + b2 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_8( b1_s8, 1, scale5 ), + CVT_INT8_F32_SCAL_8( b1_s8, 0, scale4 ) ); + + // Broadcast a[0,kr:kr+2]. + a_kfringe_buf = *( a + (rs_a * 0) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-47] = a[0,kr:kr+2]*b[kr:kr+2,0-47] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + + // Broadcast a[1,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 1) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-47] = a[1,kr:kr+2]*b[kr:kr+2,0-47] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); + c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_0, b2 ); + + // Broadcast a[2,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 2) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-47] = a[2,kr:kr+2]*b[kr:kr+2,0-47] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); + + // Broadcast a[3,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 3) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-47] = a[3,kr:kr+2]*b[kr:kr+2,0-47] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_0, b1 ); + c_float_3p2 = _mm512_dpbf16_ps( c_float_3p2, a_bf16_0, b2 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + if ( alpha != 1 ) + { + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + c_float_0p1 = _mm512_mul_ps( selector1, c_float_0p1 ); + c_float_0p2 = _mm512_mul_ps( selector1, c_float_0p2 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + c_float_1p1 = _mm512_mul_ps( selector1, c_float_1p1 ); + c_float_1p2 = _mm512_mul_ps( selector1, c_float_1p2 ); + + c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); + c_float_2p1 = _mm512_mul_ps( selector1, c_float_2p1 ); + c_float_2p2 = _mm512_mul_ps( selector1, c_float_2p2 ); + + c_float_3p0 = _mm512_mul_ps( selector1, c_float_3p0 ); + c_float_3p1 = _mm512_mul_ps( selector1, c_float_3p1 ); + c_float_3p2 = _mm512_mul_ps( selector1, c_float_3p2 ); + } + + // Scale C by beta. + if ( beta != 0 ) + { + if ( ( post_ops_attr.buf_downscale != NULL ) && + ( post_ops_attr.is_first_k == TRUE ) ) + { + // c[0,0-15] + BF16_F32_BETA_OP(c_float_0p0,0,0,0,selector1,selector2) + + // c[0, 16-31] + BF16_F32_BETA_OP(c_float_0p1,0,0,1,selector1,selector2) + + // c[0,32-47] + BF16_F32_BETA_OP(c_float_0p2,0,0,2,selector1,selector2) + + // c[1,0-15] + BF16_F32_BETA_OP(c_float_1p0,0,1,0,selector1,selector2) + + // c[1,16-31] + BF16_F32_BETA_OP(c_float_1p1,0,1,1,selector1,selector2) + + // c[1,32-47] + BF16_F32_BETA_OP(c_float_1p2,0,1,2,selector1,selector2) + + // c[2,0-15] + BF16_F32_BETA_OP(c_float_2p0,0,2,0,selector1,selector2) + + // c[2,16-31] + BF16_F32_BETA_OP(c_float_2p1,0,2,1,selector1,selector2) + + // c[2,32-47] + BF16_F32_BETA_OP(c_float_2p2,0,2,2,selector1,selector2) + + // c[3,0-15] + BF16_F32_BETA_OP(c_float_3p0,0,3,0,selector1,selector2) + + // c[3,16-31] + BF16_F32_BETA_OP(c_float_3p1,0,3,1,selector1,selector2) + + // c[3,32-47] + BF16_F32_BETA_OP(c_float_3p2,0,3,2,selector1,selector2) + } + else + { + // c[0,0-15] + F32_F32_BETA_OP(c_float_0p0,0,0,0,selector1,selector2) + + // c[0, 16-31] + F32_F32_BETA_OP(c_float_0p1,0,0,1,selector1,selector2) + + // c[0,32-47] + F32_F32_BETA_OP(c_float_0p2,0,0,2,selector1,selector2) + + // c[1,0-15] + F32_F32_BETA_OP(c_float_1p0,0,1,0,selector1,selector2) + + // c[1,16-31] + F32_F32_BETA_OP(c_float_1p1,0,1,1,selector1,selector2) + + // c[1,32-47] + F32_F32_BETA_OP(c_float_1p2,0,1,2,selector1,selector2) + + // c[2,0-15] + F32_F32_BETA_OP(c_float_2p0,0,2,0,selector1,selector2) + + // c[2,16-31] + F32_F32_BETA_OP(c_float_2p1,0,2,1,selector1,selector2) + + // c[2,32-47] + F32_F32_BETA_OP(c_float_2p2,0,2,2,selector1,selector2) + + // c[3,0-15] + F32_F32_BETA_OP(c_float_3p0,0,3,0,selector1,selector2) + + // c[3,16-31] + F32_F32_BETA_OP(c_float_3p1,0,3,1,selector1,selector2) + + // c[3,32-47] + F32_F32_BETA_OP(c_float_3p2,0,3,2,selector1,selector2) + } + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_4x48: + { + __m512 selector3; + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_LOAD(selector1, bias_mask, 0); + BF16_F32_BIAS_LOAD(selector2, bias_mask, 1); + BF16_F32_BIAS_LOAD(selector3, bias_mask, 2); + } + else + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector3, c_float_1p2 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector2, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_add_ps( selector3, c_float_3p2 ); + } + else + { + __m512 selector4; + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + BF16_F32_BIAS_BCAST(selector2, bias_mask, 1); + BF16_F32_BIAS_BCAST(selector3, bias_mask, 2); + BF16_F32_BIAS_BCAST(selector4, bias_mask, 3); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 3 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector2, c_float_1p2 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector3, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector4, c_float_3p0 ); + + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector4, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_add_ps( selector4, c_float_3p2 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_4x48: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_max_ps( selector1, c_float_0p2 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + c_float_1p1 = _mm512_max_ps( selector1, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_max_ps( selector1, c_float_1p2 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + c_float_2p1 = _mm512_max_ps( selector1, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_max_ps( selector1, c_float_2p2 ); + + // c[3,0-15] + c_float_3p0 = _mm512_max_ps( selector1, c_float_3p0 ); + + // c[3,16-31] + c_float_3p1 = _mm512_max_ps( selector1, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_max_ps( selector1, c_float_3p2 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_4x48: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_0p2) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_1p1) + + // c[1, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_1p2) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_2p1) + + // c[2, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_2p2) + + // c[3, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_3p0) + + // c[3, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_3p1) + + // c[3, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_3p2) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_TANH_4x48: + { + __m512 dn, z, x, r2, r, x_tanh; + __m512i q; + + // c[0, 0-15] + GELU_TANH_F32_AVX512(c_float_0p0, r, r2, x, z, dn, x_tanh, q) + + // c[0, 16-31] + GELU_TANH_F32_AVX512(c_float_0p1, r, r2, x, z, dn, x_tanh, q) + + // c[0, 32-47] + GELU_TANH_F32_AVX512(c_float_0p2, r, r2, x, z, dn, x_tanh, q) + + // c[1, 0-15] + GELU_TANH_F32_AVX512(c_float_1p0, r, r2, x, z, dn, x_tanh, q) + + // c[1, 16-31] + GELU_TANH_F32_AVX512(c_float_1p1, r, r2, x, z, dn, x_tanh, q) + + // c[1, 32-47] + GELU_TANH_F32_AVX512(c_float_1p2, r, r2, x, z, dn, x_tanh, q) + + // c[2, 0-15] + GELU_TANH_F32_AVX512(c_float_2p0, r, r2, x, z, dn, x_tanh, q) + + // c[2, 16-31] + GELU_TANH_F32_AVX512(c_float_2p1, r, r2, x, z, dn, x_tanh, q) + + // c[2, 32-47] + GELU_TANH_F32_AVX512(c_float_2p2, r, r2, x, z, dn, x_tanh, q) + + // c[3, 0-15] + GELU_TANH_F32_AVX512(c_float_3p0, r, r2, x, z, dn, x_tanh, q) + + // c[3, 16-31] + GELU_TANH_F32_AVX512(c_float_3p1, r, r2, x, z, dn, x_tanh, q) + + // c[3, 32-47] + GELU_TANH_F32_AVX512(c_float_3p2, r, r2, x, z, dn, x_tanh, q) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_ERF_4x48: + { + __m512 x, r, x_erf; + + // c[0, 0-15] + GELU_ERF_F32_AVX512(c_float_0p0, r, x, x_erf) + + // c[0, 16-31] + GELU_ERF_F32_AVX512(c_float_0p1, r, x, x_erf) + + // c[0, 32-47] + GELU_ERF_F32_AVX512(c_float_0p2, r, x, x_erf) + + // c[1, 0-15] + GELU_ERF_F32_AVX512(c_float_1p0, r, x, x_erf) + + // c[1, 16-31] + GELU_ERF_F32_AVX512(c_float_1p1, r, x, x_erf) + + // c[1, 32-47] + GELU_ERF_F32_AVX512(c_float_1p2, r, x, x_erf) + + // c[2, 0-15] + GELU_ERF_F32_AVX512(c_float_2p0, r, x, x_erf) + + // c[2, 16-31] + GELU_ERF_F32_AVX512(c_float_2p1, r, x, x_erf) + + // c[2, 32-47] + GELU_ERF_F32_AVX512(c_float_2p2, r, x, x_erf) + + // c[3, 0-15] + GELU_ERF_F32_AVX512(c_float_3p0, r, x, x_erf) + + // c[3, 16-31] + GELU_ERF_F32_AVX512(c_float_3p1, r, x, x_erf) + + // c[3, 32-47] + GELU_ERF_F32_AVX512(c_float_3p2, r, x, x_erf) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_CLIP_4x48: + { + __m512 min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 ); + __m512 max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 ); + + // c[0, 0-15] + CLIP_F32_AVX512(c_float_0p0, min, max) + + // c[0, 16-31] + CLIP_F32_AVX512(c_float_0p1, min, max) + + // c[0, 32-47] + CLIP_F32_AVX512(c_float_0p2, min, max) + + // c[1, 0-15] + CLIP_F32_AVX512(c_float_1p0, min, max) + + // c[1, 16-31] + CLIP_F32_AVX512(c_float_1p1, min, max) + + // c[1, 32-47] + CLIP_F32_AVX512(c_float_1p2, min, max) + + // c[2, 0-15] + CLIP_F32_AVX512(c_float_2p0, min, max) + + // c[2, 16-31] + CLIP_F32_AVX512(c_float_2p1, min, max) + + // c[2, 32-47] + CLIP_F32_AVX512(c_float_2p2, min, max) + + // c[3, 0-15] + CLIP_F32_AVX512(c_float_3p0, min, max) + + // c[3, 16-31] + CLIP_F32_AVX512(c_float_3p1, min, max) + + // c[3, 32-47] + CLIP_F32_AVX512(c_float_3p2, min, max) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_4x48: + { + __m512 selector3 = _mm512_setzero_ps(); + __m512 selector4 = _mm512_setzero_ps(); + + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + __m512 zero_point2 = _mm512_setzero_ps(); + __m512 zero_point3 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector2,zero_point1); + + // c[0, 32-47] + SCL_MULRND_F32(c_float_0p2,selector3,zero_point2); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector1,zero_point0); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[1, 32-47] + SCL_MULRND_F32(c_float_1p2,selector3,zero_point2); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector1,zero_point0); + + // c[2, 16-31] + SCL_MULRND_F32(c_float_2p1,selector2,zero_point1); + + // c[2, 32-47] + SCL_MULRND_F32(c_float_2p2,selector3,zero_point2); + + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector1,zero_point0); + + // c[3, 16-31] + SCL_MULRND_F32(c_float_3p1,selector2,zero_point1); + + // c[3, 32-47] + SCL_MULRND_F32(c_float_3p2,selector3,zero_point2); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 3 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 2 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 3 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector1,zero_point0); + + // c[0, 32-47] + SCL_MULRND_F32(c_float_0p2,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector2,zero_point1); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[1, 32-47] + SCL_MULRND_F32(c_float_1p2,selector2,zero_point1); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector3,zero_point2); + + // c[2, 16-31] + SCL_MULRND_F32(c_float_2p1,selector3,zero_point2); + + // c[2, 32-47] + SCL_MULRND_F32(c_float_2p2,selector3,zero_point2); + + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector4,zero_point3); + + // c[3, 16-31] + SCL_MULRND_F32(c_float_3p1,selector4,zero_point3); + + // c[3, 32-47] + SCL_MULRND_F32(c_float_3p2,selector4,zero_point3); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_4x48: + { + __m512 selector3; + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + BF16_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,0); + + // c[1:0-15,16-31,32-47] + BF16_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,1); + + // c[2:0-15,16-31,32-47] + BF16_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,2); + + // c[3:0-15,16-31,32-47] + BF16_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,3); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,0); + + // c[1:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,1); + + // c[2:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,2); + + // c[3:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,3); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_4x48: + { + __m512 selector3; + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + BF16_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,0); + + // c[1:0-15,16-31,32-47] + BF16_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,1); + + // c[2:0-15,16-31,32-47] + BF16_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,2); + + // c[3:0-15,16-31,32-47] + BF16_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,3); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,0); + + // c[1:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,1); + + // c[2:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,2); + + // c[3:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,3); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_4x48: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[0, 16-31] + SWISH_F32_AVX512_DEF(c_float_0p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[0, 32-47] + SWISH_F32_AVX512_DEF(c_float_0p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 0-15] + SWISH_F32_AVX512_DEF(c_float_1p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 16-31] + SWISH_F32_AVX512_DEF(c_float_1p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 32-47] + SWISH_F32_AVX512_DEF(c_float_1p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 0-15] + SWISH_F32_AVX512_DEF(c_float_2p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 16-31] + SWISH_F32_AVX512_DEF(c_float_2p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 32-47] + SWISH_F32_AVX512_DEF(c_float_2p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[3, 0-15] + SWISH_F32_AVX512_DEF(c_float_3p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[3, 16-31] + SWISH_F32_AVX512_DEF(c_float_3p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[3, 32-47] + SWISH_F32_AVX512_DEF(c_float_3p2, selector1, al_in, r, r2, z, dn, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_4x48_DISABLE: + ; + + if ( ( post_ops_attr.buf_downscale != NULL ) && + ( post_ops_attr.is_last_k == TRUE ) ) + { + // Generate a mask16 of all 1's. + __m512i selector_a = _mm512_setzero_epi32(); + __m512i selector_b = _mm512_set1_epi32( 10 ); + __mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b ); + + // Store the results in downscaled type (bf16 instead of float). + + // c[0, 0-15] + CVT_STORE_F32_BF16_MASK(c_float_0p0,0,0); + + // c[0, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_0p1,0,1); + + // c[0, 32-47] + CVT_STORE_F32_BF16_MASK(c_float_0p2,0,2); + + // c[1, 0-15] + CVT_STORE_F32_BF16_MASK(c_float_1p0,1,0); + + // c[1, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_1p1,1,1); + + // c[1, 32-47] + CVT_STORE_F32_BF16_MASK(c_float_1p2,1,2); + + // c[2, 0-15] + CVT_STORE_F32_BF16_MASK(c_float_2p0,2,0); + + // c[2, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_2p1,2,1); + + // c[2, 32-47] + CVT_STORE_F32_BF16_MASK(c_float_2p2,2,2); + + // c[3, 0-15] + CVT_STORE_F32_BF16_MASK(c_float_3p0,3,0); + + // c[3, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_3p1,3,1); + + // c[3, 32-47] + CVT_STORE_F32_BF16_MASK(c_float_3p2,3,2); + } + + else + { + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); + + // c[0, 16-31] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 1*16 ), c_float_0p1 ); + + // c[0,32-47] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 2*16 ), c_float_0p2 ); + + // c[1,0-15] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 0*16 ), c_float_1p0 ); + + // c[1,16-31] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 1*16 ), c_float_1p1 ); + + // c[1,32-47] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 2*16 ), c_float_1p2 ); + + // c[2,0-15] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 0*16 ), c_float_2p0 ); + + // c[2,16-31] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 1*16 ), c_float_2p1 ); + + // c[2,32-47] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 2*16 ), c_float_2p2 ); + + // c[3,0-15] + _mm512_storeu_ps( c + ( rs_c * 3 ) + ( 0*16 ), c_float_3p0 ); + + // c[3,16-31] + _mm512_storeu_ps( c + ( rs_c * 3 ) + ( 1*16 ), c_float_3p1 ); + + // c[3,32-47] + _mm512_storeu_ps( c + ( rs_c * 3 ) + ( 2*16 ), c_float_3p2 ); + } +} + +// 3x48 bf16 kernel +LPGEMM_MN_FRINGE_KERN(bfloat16, int8_t, float, bf16s4f32of32_3x48) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_3x48_DISABLE, + &&POST_OPS_BIAS_3x48, + &&POST_OPS_RELU_3x48, + &&POST_OPS_RELU_SCALE_3x48, + &&POST_OPS_GELU_TANH_3x48, + &&POST_OPS_GELU_ERF_3x48, + &&POST_OPS_CLIP_3x48, + &&POST_OPS_DOWNSCALE_3x48, + &&POST_OPS_MATRIX_ADD_3x48, + &&POST_OPS_SWISH_3x48, + &&POST_OPS_MATRIX_MUL_3x48 + }; + + dim_t pre_op_off = post_ops_attr.pre_op_off; + + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int16_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + __m512bh b1; + __m512bh b2; + + __m256i b0_s4; + __m128i b1_s4; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + __m512i shift_idx_64; + MULTISHIFT_32BIT_8_INT4_IDX_64ELEM(shift_idx_64); + __m512i sign_comp = _mm512_set1_epi8(0x08); + + __m256i shift_idx_32; + MULTISHIFT_32BIT_8_INT4_IDX_32ELEM(shift_idx_32); + __m256i sign_comp_32 = _mm256_set1_epi8( 0x08 ); + + bool signed_upscale = true; + + /* regs to store intermediate int8 values */ + __m512i b0_s8; + __m256i b1_s8; + + /* Regs to store F32 scale values */ + __m512 scale0, scale1, scale2, scale3, scale4, scale5; + /* Reg to store masks to interleave scale factor */ + __m512i mask_scale1, mask_scale2; + + mask_scale1 = _mm512_set_epi32( 0x17, 0x07, 0x16, 0x06, 0x15, 0x05, 0x14, + 0x04, 0x13, 0x03, 0x12, 0x02, 0x11, 0x01, + 0x10, 0x00 ); + + mask_scale2 = _mm512_set_epi32( 0x1F, 0x0F, 0x1E, 0x0E, 0x1D, 0x0D, 0x1C, + 0x0C, 0x1B, 0x0B, 0x1A, 0x0A, 0x19, 0x09, + 0x18, 0x08); + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + __m512 c_float_0p1 = _mm512_setzero_ps(); + __m512 c_float_0p2 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + __m512 c_float_1p1 = _mm512_setzero_ps(); + __m512 c_float_1p2 = _mm512_setzero_ps(); + + __m512 c_float_2p0 = _mm512_setzero_ps(); + __m512 c_float_2p1 = _mm512_setzero_ps(); + __m512 c_float_2p2 = _mm512_setzero_ps(); + + if( post_ops_attr.pre_op_scale_factor_len > 1 ) + { + // load and interleave scale factor vectors + scale0 = _mm512_loadu_ps( (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off); + scale2 = _mm512_loadu_ps( (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off + 16 ); + scale4 = _mm512_loadu_ps( (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off + 32 ); + + scale1 = _mm512_permutex2var_ps( scale0, mask_scale2, scale0 ); + scale0 = _mm512_permutex2var_ps( scale0, mask_scale1, scale0 ); + scale3 = _mm512_permutex2var_ps( scale2, mask_scale2, scale2 ); + scale2 = _mm512_permutex2var_ps( scale2, mask_scale1, scale2 ); + scale5 = _mm512_permutex2var_ps( scale4, mask_scale2, scale4 ); + scale4 = _mm512_permutex2var_ps( scale4, mask_scale1, scale4 ); + + } + else + { + scale0 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale1 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale2 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale3 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale4 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale5 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + } + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0_s4 = _mm256_loadu_si256( (__m256i const *)( b + ( rs_b * kr ) / 2 ) ); + + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_64, \ + sign_comp, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_16( b0_s8, 0, scale0 ) ); + + b1 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 3, scale3 ), + CVT_INT8_F32_SCAL_16( b0_s8, 2, scale2 ) ); + + b1_s4 = _mm_loadu_si128( (__m128i const *)( b + ( ( rs_b * kr ) / 2 ) + 32 ) ); + + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT( b1_s4, b1_s8, shift_idx_32, \ + sign_comp_32, signed_upscale); + + b2 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_8( b1_s8, 1, scale5 ), + CVT_INT8_F32_SCAL_8( b1_s8, 0, scale4 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-47] = a[0,kr:kr+2]*b[kr:kr+2,0-47] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-47] = a[1,kr:kr+2]*b[kr:kr+2,0-47] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); + c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_0, b2 ); + + // Broadcast a[2,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-47] = a[2,kr:kr+2]*b[kr:kr+2,0-47] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0_s4 = _mm256_loadu_si256( (__m256i const *)( b + ( rs_b * k_full_pieces ) / 2 ) ); + + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_64, \ + sign_comp, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_16( b0_s8, 0, scale0 ) ); + + b1 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 3, scale3 ), + CVT_INT8_F32_SCAL_16( b0_s8, 2, scale2 ) ); + + b1_s4 = _mm_loadu_si128( (__m128i const *)( b + ( ( rs_b * k_full_pieces ) / 2 ) + 32 ) ); + + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT( b1_s4, b1_s8, shift_idx_32, \ + sign_comp_32, signed_upscale); + + b2 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_8( b1_s8, 1, scale5 ), + CVT_INT8_F32_SCAL_8( b1_s8, 0, scale4 ) ); + + // Broadcast a[0,kr:kr+2]. + a_kfringe_buf = *( a + (rs_a * 0) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-47] = a[0,kr:kr+2]*b[kr:kr+2,0-47] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + + // Broadcast a[1,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 1) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-47] = a[1,kr:kr+2]*b[kr:kr+2,0-47] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); + c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_0, b2 ); + + // Broadcast a[2,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 2) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-47] = a[2,kr:kr+2]*b[kr:kr+2,0-47] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + if ( alpha != 1 ) + { + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + c_float_0p1 = _mm512_mul_ps( selector1, c_float_0p1 ); + c_float_0p2 = _mm512_mul_ps( selector1, c_float_0p2 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + c_float_1p1 = _mm512_mul_ps( selector1, c_float_1p1 ); + c_float_1p2 = _mm512_mul_ps( selector1, c_float_1p2 ); + + c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); + c_float_2p1 = _mm512_mul_ps( selector1, c_float_2p1 ); + c_float_2p2 = _mm512_mul_ps( selector1, c_float_2p2 ); + } + + // Scale C by beta. + if ( beta != 0 ) + { + if ( ( post_ops_attr.buf_downscale != NULL ) && + ( post_ops_attr.is_first_k == TRUE ) ) + { + // c[0,0-15] + BF16_F32_BETA_OP(c_float_0p0,0,0,0,selector1,selector2) + + // c[0, 16-31] + BF16_F32_BETA_OP(c_float_0p1,0,0,1,selector1,selector2) + + // c[0,32-47] + BF16_F32_BETA_OP(c_float_0p2,0,0,2,selector1,selector2) + + // c[1,0-15] + BF16_F32_BETA_OP(c_float_1p0,0,1,0,selector1,selector2) + + // c[1,16-31] + BF16_F32_BETA_OP(c_float_1p1,0,1,1,selector1,selector2) + + // c[1,32-47] + BF16_F32_BETA_OP(c_float_1p2,0,1,2,selector1,selector2) + + // c[2,0-15] + BF16_F32_BETA_OP(c_float_2p0,0,2,0,selector1,selector2) + + // c[2,16-31] + BF16_F32_BETA_OP(c_float_2p1,0,2,1,selector1,selector2) + + // c[2,32-47] + BF16_F32_BETA_OP(c_float_2p2,0,2,2,selector1,selector2) + } + else + { + // c[0,0-15] + F32_F32_BETA_OP(c_float_0p0,0,0,0,selector1,selector2) + + // c[0, 16-31] + F32_F32_BETA_OP(c_float_0p1,0,0,1,selector1,selector2) + + // c[0,32-47] + F32_F32_BETA_OP(c_float_0p2,0,0,2,selector1,selector2) + + // c[1,0-15] + F32_F32_BETA_OP(c_float_1p0,0,1,0,selector1,selector2) + + // c[1,16-31] + F32_F32_BETA_OP(c_float_1p1,0,1,1,selector1,selector2) + + // c[1,32-47] + F32_F32_BETA_OP(c_float_1p2,0,1,2,selector1,selector2) + + // c[2,0-15] + F32_F32_BETA_OP(c_float_2p0,0,2,0,selector1,selector2) + + // c[2,16-31] + F32_F32_BETA_OP(c_float_2p1,0,2,1,selector1,selector2) + + // c[2,32-47] + F32_F32_BETA_OP(c_float_2p2,0,2,2,selector1,selector2) + } + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_3x48: + { + __m512 selector3; + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_LOAD(selector1, bias_mask, 0); + BF16_F32_BIAS_LOAD(selector2, bias_mask, 1); + BF16_F32_BIAS_LOAD(selector3, bias_mask, 2); + } + else + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector3, c_float_1p2 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + } + else + { + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + BF16_F32_BIAS_BCAST(selector2, bias_mask, 1); + BF16_F32_BIAS_BCAST(selector3, bias_mask, 2); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 2 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector2, c_float_1p2 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector3, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_3x48: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_max_ps( selector1, c_float_0p2 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + c_float_1p1 = _mm512_max_ps( selector1, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_max_ps( selector1, c_float_1p2 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + c_float_2p1 = _mm512_max_ps( selector1, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_max_ps( selector1, c_float_2p2 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_3x48: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_0p2) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_1p1) + + // c[1, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_1p2) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_2p1) + + // c[2, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_2p2) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_TANH_3x48: + { + __m512 dn, z, x, r2, r, x_tanh; + __m512i q; + + // c[0, 0-15] + GELU_TANH_F32_AVX512(c_float_0p0, r, r2, x, z, dn, x_tanh, q) + + // c[0, 16-31] + GELU_TANH_F32_AVX512(c_float_0p1, r, r2, x, z, dn, x_tanh, q) + + // c[0, 32-47] + GELU_TANH_F32_AVX512(c_float_0p2, r, r2, x, z, dn, x_tanh, q) + + // c[1, 0-15] + GELU_TANH_F32_AVX512(c_float_1p0, r, r2, x, z, dn, x_tanh, q) + + // c[1, 16-31] + GELU_TANH_F32_AVX512(c_float_1p1, r, r2, x, z, dn, x_tanh, q) + + // c[1, 32-47] + GELU_TANH_F32_AVX512(c_float_1p2, r, r2, x, z, dn, x_tanh, q) + + // c[2, 0-15] + GELU_TANH_F32_AVX512(c_float_2p0, r, r2, x, z, dn, x_tanh, q) + + // c[2, 16-31] + GELU_TANH_F32_AVX512(c_float_2p1, r, r2, x, z, dn, x_tanh, q) + + // c[2, 32-47] + GELU_TANH_F32_AVX512(c_float_2p2, r, r2, x, z, dn, x_tanh, q) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_ERF_3x48: + { + __m512 x, r, x_erf; + + // c[0, 0-15] + GELU_ERF_F32_AVX512(c_float_0p0, r, x, x_erf) + + // c[0, 16-31] + GELU_ERF_F32_AVX512(c_float_0p1, r, x, x_erf) + + // c[0, 32-47] + GELU_ERF_F32_AVX512(c_float_0p2, r, x, x_erf) + + // c[1, 0-15] + GELU_ERF_F32_AVX512(c_float_1p0, r, x, x_erf) + + // c[1, 16-31] + GELU_ERF_F32_AVX512(c_float_1p1, r, x, x_erf) + + // c[1, 32-47] + GELU_ERF_F32_AVX512(c_float_1p2, r, x, x_erf) + + // c[2, 0-15] + GELU_ERF_F32_AVX512(c_float_2p0, r, x, x_erf) + + // c[2, 16-31] + GELU_ERF_F32_AVX512(c_float_2p1, r, x, x_erf) + + // c[2, 32-47] + GELU_ERF_F32_AVX512(c_float_2p2, r, x, x_erf) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_CLIP_3x48: + { + __m512 min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 ); + __m512 max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 ); + + // c[0, 0-15] + CLIP_F32_AVX512(c_float_0p0, min, max) + + // c[0, 16-31] + CLIP_F32_AVX512(c_float_0p1, min, max) + + // c[0, 32-47] + CLIP_F32_AVX512(c_float_0p2, min, max) + + // c[1, 0-15] + CLIP_F32_AVX512(c_float_1p0, min, max) + + // c[1, 16-31] + CLIP_F32_AVX512(c_float_1p1, min, max) + + // c[1, 32-47] + CLIP_F32_AVX512(c_float_1p2, min, max) + + // c[2, 0-15] + CLIP_F32_AVX512(c_float_2p0, min, max) + + // c[2, 16-31] + CLIP_F32_AVX512(c_float_2p1, min, max) + + // c[2, 32-47] + CLIP_F32_AVX512(c_float_2p2, min, max) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + +POST_OPS_DOWNSCALE_3x48: + { + __m512 selector3 = _mm512_setzero_ps(); + + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + __m512 zero_point2 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector2,zero_point1); + + // c[0, 32-47] + SCL_MULRND_F32(c_float_0p2,selector3,zero_point2); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector1,zero_point0); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[1, 32-47] + SCL_MULRND_F32(c_float_1p2,selector3,zero_point2); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector1,zero_point0); + + // c[2, 16-31] + SCL_MULRND_F32(c_float_2p1,selector2,zero_point1); + + // c[2, 32-47] + SCL_MULRND_F32(c_float_2p2,selector3,zero_point2); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 2 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 2 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector1,zero_point0); + + // c[0, 32-47] + SCL_MULRND_F32(c_float_0p2,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector2,zero_point1); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[1, 32-47] + SCL_MULRND_F32(c_float_1p2,selector2,zero_point1); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector3,zero_point2); + + // c[2, 16-31] + SCL_MULRND_F32(c_float_2p1,selector3,zero_point2); + + // c[2, 32-47] + SCL_MULRND_F32(c_float_2p2,selector3,zero_point2); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_3x48: + { + __m512 selector3; + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + BF16_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,0); + + // c[1:0-15,16-31,32-47] + BF16_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,1); + + // c[2:0-15,16-31,32-47] + BF16_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,2); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,0); + + // c[1:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,1); + + // c[2:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,2); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_3x48: + { + __m512 selector3; + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + BF16_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,0); + + // c[1:0-15,16-31,32-47] + BF16_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,1); + + // c[2:0-15,16-31,32-47] + BF16_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,2); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,0); + + // c[1:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,1); + + // c[2:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,2); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_3x48: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[0, 16-31] + SWISH_F32_AVX512_DEF(c_float_0p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[0, 32-47] + SWISH_F32_AVX512_DEF(c_float_0p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 0-15] + SWISH_F32_AVX512_DEF(c_float_1p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 16-31] + SWISH_F32_AVX512_DEF(c_float_1p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 32-47] + SWISH_F32_AVX512_DEF(c_float_1p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 0-15] + SWISH_F32_AVX512_DEF(c_float_2p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 16-31] + SWISH_F32_AVX512_DEF(c_float_2p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 32-47] + SWISH_F32_AVX512_DEF(c_float_2p2, selector1, al_in, r, r2, z, dn, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_3x48_DISABLE: + ; + + if ( ( post_ops_attr.buf_downscale != NULL ) && + ( post_ops_attr.is_last_k == TRUE ) ) + { + // Generate a mask16 of all 1's. + __m512i selector_a = _mm512_setzero_epi32(); + __m512i selector_b = _mm512_set1_epi32( 10 ); + __mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b ); + + // Store the results in downscaled type (bf16 instead of float). + + // c[0, 0-15] + CVT_STORE_F32_BF16_MASK(c_float_0p0,0,0); + + // c[0, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_0p1,0,1); + + // c[0, 32-47] + CVT_STORE_F32_BF16_MASK(c_float_0p2,0,2); + + // c[1, 0-15] + CVT_STORE_F32_BF16_MASK(c_float_1p0,1,0); + + // c[1, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_1p1,1,1); + + // c[1, 32-47] + CVT_STORE_F32_BF16_MASK(c_float_1p2,1,2); + + // c[2, 0-15] + CVT_STORE_F32_BF16_MASK(c_float_2p0,2,0); + + // c[2, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_2p1,2,1); + + // c[2, 32-47] + CVT_STORE_F32_BF16_MASK(c_float_2p2,2,2); + } + + else + { + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); + + // c[0, 16-31] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 1*16 ), c_float_0p1 ); + + // c[0,32-47] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 2*16 ), c_float_0p2 ); + + // c[1,0-15] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 0*16 ), c_float_1p0 ); + + // c[1,16-31] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 1*16 ), c_float_1p1 ); + + // c[1,32-47] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 2*16 ), c_float_1p2 ); + + // c[2,0-15] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 0*16 ), c_float_2p0 ); + + // c[2,16-31] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 1*16 ), c_float_2p1 ); + + // c[2,32-47] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 2*16 ), c_float_2p2 ); + } +} + +// 2x48 bf16 kernel +LPGEMM_MN_FRINGE_KERN(bfloat16, int8_t, float, bf16s4f32of32_2x48) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_2x48_DISABLE, + &&POST_OPS_BIAS_2x48, + &&POST_OPS_RELU_2x48, + &&POST_OPS_RELU_SCALE_2x48, + &&POST_OPS_GELU_TANH_2x48, + &&POST_OPS_GELU_ERF_2x48, + &&POST_OPS_CLIP_2x48, + &&POST_OPS_DOWNSCALE_2x48, + &&POST_OPS_MATRIX_ADD_2x48, + &&POST_OPS_SWISH_2x48, + &&POST_OPS_MATRIX_MUL_2x48 + }; + + dim_t pre_op_off = post_ops_attr.pre_op_off; + + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int16_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + __m512bh b1; + __m512bh b2; + + __m256i b0_s4; + __m128i b1_s4; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + __m512i shift_idx_64; + MULTISHIFT_32BIT_8_INT4_IDX_64ELEM(shift_idx_64); + __m512i sign_comp = _mm512_set1_epi8(0x08); + + __m256i shift_idx_32; + MULTISHIFT_32BIT_8_INT4_IDX_32ELEM(shift_idx_32); + __m256i sign_comp_32 = _mm256_set1_epi8( 0x08 ); + + bool signed_upscale = true; + + /* regs to store intermediate int8 values */ + __m512i b0_s8; + __m256i b1_s8; + + /* Regs to store F32 scale values */ + __m512 scale0, scale1, scale2, scale3, scale4, scale5; + /* Reg to store masks to interleave scale factor */ + __m512i mask_scale1, mask_scale2; + + mask_scale1 = _mm512_set_epi32( 0x17, 0x07, 0x16, 0x06, 0x15, 0x05, 0x14, + 0x04, 0x13, 0x03, 0x12, 0x02, 0x11, 0x01, + 0x10, 0x00 ); + + mask_scale2 = _mm512_set_epi32( 0x1F, 0x0F, 0x1E, 0x0E, 0x1D, 0x0D, 0x1C, + 0x0C, 0x1B, 0x0B, 0x1A, 0x0A, 0x19, 0x09, + 0x18, 0x08); + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + __m512 c_float_0p1 = _mm512_setzero_ps(); + __m512 c_float_0p2 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + __m512 c_float_1p1 = _mm512_setzero_ps(); + __m512 c_float_1p2 = _mm512_setzero_ps(); + + if( post_ops_attr.pre_op_scale_factor_len > 1 ) + { + // load and interleave scale factor vectors + scale0 = _mm512_loadu_ps( (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off); + scale2 = _mm512_loadu_ps( (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off + 16 ); + scale4 = _mm512_loadu_ps( (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off + 32 ); + + scale1 = _mm512_permutex2var_ps( scale0, mask_scale2, scale0 ); + scale0 = _mm512_permutex2var_ps( scale0, mask_scale1, scale0 ); + scale3 = _mm512_permutex2var_ps( scale2, mask_scale2, scale2 ); + scale2 = _mm512_permutex2var_ps( scale2, mask_scale1, scale2 ); + scale5 = _mm512_permutex2var_ps( scale4, mask_scale2, scale4 ); + scale4 = _mm512_permutex2var_ps( scale4, mask_scale1, scale4 ); + + } + else + { + scale0 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale1 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale2 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale3 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale4 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale5 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + } + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0_s4 = _mm256_loadu_si256( (__m256i const *)( b + ( rs_b * kr ) / 2 ) ); + + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_64, \ + sign_comp, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_16( b0_s8, 0, scale0 ) ); + + b1 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 3, scale3 ), + CVT_INT8_F32_SCAL_16( b0_s8, 2, scale2 ) ); + + b1_s4 = _mm_loadu_si128( (__m128i const *)( b + ( ( rs_b * kr ) / 2 ) + 32 ) ); + + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT( b1_s4, b1_s8, shift_idx_32, \ + sign_comp_32, signed_upscale); + + b2 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_8( b1_s8, 1, scale5 ), + CVT_INT8_F32_SCAL_8( b1_s8, 0, scale4 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-47] = a[0,kr:kr+2]*b[kr:kr+2,0-47] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-47] = a[1,kr:kr+2]*b[kr:kr+2,0-47] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); + c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_0, b2 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0_s4 = _mm256_loadu_si256( (__m256i const *)( b + ( rs_b * k_full_pieces ) / 2 ) ); + + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_64, \ + sign_comp, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_16( b0_s8, 0, scale0 ) ); + + b1 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 3, scale3 ), + CVT_INT8_F32_SCAL_16( b0_s8, 2, scale2 ) ); + + b1_s4 = _mm_loadu_si128( (__m128i const *)( b + ( ( rs_b * k_full_pieces ) / 2 ) + 32 ) ); + + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT( b1_s4, b1_s8, shift_idx_32, \ + sign_comp_32, signed_upscale); + + b2 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_8( b1_s8, 1, scale5 ), + CVT_INT8_F32_SCAL_8( b1_s8, 0, scale4 ) ); + + // Broadcast a[0,kr:kr+2]. + a_kfringe_buf = *( a + (rs_a * 0) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-47] = a[0,kr:kr+2]*b[kr:kr+2,0-47] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + + // Broadcast a[1,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 1) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-47] = a[1,kr:kr+2]*b[kr:kr+2,0-47] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); + c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_0, b2 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + if ( alpha != 1 ) + { + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + c_float_0p1 = _mm512_mul_ps( selector1, c_float_0p1 ); + c_float_0p2 = _mm512_mul_ps( selector1, c_float_0p2 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + c_float_1p1 = _mm512_mul_ps( selector1, c_float_1p1 ); + c_float_1p2 = _mm512_mul_ps( selector1, c_float_1p2 ); + } + + // Scale C by beta. + if ( beta != 0 ) + { + if ( ( post_ops_attr.buf_downscale != NULL ) && + ( post_ops_attr.is_first_k == TRUE ) ) + { + // c[0,0-15] + BF16_F32_BETA_OP(c_float_0p0,0,0,0,selector1,selector2) + + // c[0, 16-31] + BF16_F32_BETA_OP(c_float_0p1,0,0,1,selector1,selector2) + + // c[0,32-47] + BF16_F32_BETA_OP(c_float_0p2,0,0,2,selector1,selector2) + + // c[1,0-15] + BF16_F32_BETA_OP(c_float_1p0,0,1,0,selector1,selector2) + + // c[1,16-31] + BF16_F32_BETA_OP(c_float_1p1,0,1,1,selector1,selector2) + + // c[1,32-47] + BF16_F32_BETA_OP(c_float_1p2,0,1,2,selector1,selector2) + } + else + { + // c[0,0-15] + F32_F32_BETA_OP(c_float_0p0,0,0,0,selector1,selector2) + + // c[0, 16-31] + F32_F32_BETA_OP(c_float_0p1,0,0,1,selector1,selector2) + + // c[0,32-47] + F32_F32_BETA_OP(c_float_0p2,0,0,2,selector1,selector2) + + // c[1,0-15] + F32_F32_BETA_OP(c_float_1p0,0,1,0,selector1,selector2) + + // c[1,16-31] + F32_F32_BETA_OP(c_float_1p1,0,1,1,selector1,selector2) + + // c[1,32-47] + F32_F32_BETA_OP(c_float_1p2,0,1,2,selector1,selector2) + } + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_2x48: + { + __m512 selector3; + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_LOAD(selector1, bias_mask, 0); + BF16_F32_BIAS_LOAD(selector2, bias_mask, 1); + BF16_F32_BIAS_LOAD(selector3, bias_mask, 2); + } + else + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector3, c_float_1p2 ); + } + else + { + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + BF16_F32_BIAS_BCAST(selector2, bias_mask, 1); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 1 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector2, c_float_1p2 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_2x48: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_max_ps( selector1, c_float_0p2 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + c_float_1p1 = _mm512_max_ps( selector1, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_max_ps( selector1, c_float_1p2 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_2x48: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_0p2) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_1p1) + + // c[1, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_1p2) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_TANH_2x48: + { + __m512 dn, z, x, r2, r, x_tanh; + __m512i q; + + // c[0, 0-15] + GELU_TANH_F32_AVX512(c_float_0p0, r, r2, x, z, dn, x_tanh, q) + + // c[0, 16-31] + GELU_TANH_F32_AVX512(c_float_0p1, r, r2, x, z, dn, x_tanh, q) + + // c[0, 32-47] + GELU_TANH_F32_AVX512(c_float_0p2, r, r2, x, z, dn, x_tanh, q) + + // c[1, 0-15] + GELU_TANH_F32_AVX512(c_float_1p0, r, r2, x, z, dn, x_tanh, q) + + // c[1, 16-31] + GELU_TANH_F32_AVX512(c_float_1p1, r, r2, x, z, dn, x_tanh, q) + + // c[1, 32-47] + GELU_TANH_F32_AVX512(c_float_1p2, r, r2, x, z, dn, x_tanh, q) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_ERF_2x48: + { + __m512 x, r, x_erf; + + // c[0, 0-15] + GELU_ERF_F32_AVX512(c_float_0p0, r, x, x_erf) + + // c[0, 16-31] + GELU_ERF_F32_AVX512(c_float_0p1, r, x, x_erf) + + // c[0, 32-47] + GELU_ERF_F32_AVX512(c_float_0p2, r, x, x_erf) + + // c[1, 0-15] + GELU_ERF_F32_AVX512(c_float_1p0, r, x, x_erf) + + // c[1, 16-31] + GELU_ERF_F32_AVX512(c_float_1p1, r, x, x_erf) + + // c[1, 32-47] + GELU_ERF_F32_AVX512(c_float_1p2, r, x, x_erf) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_CLIP_2x48: + { + __m512 min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 ); + __m512 max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 ); + + // c[0, 0-15] + CLIP_F32_AVX512(c_float_0p0, min, max) + + // c[0, 16-31] + CLIP_F32_AVX512(c_float_0p1, min, max) + + // c[0, 32-47] + CLIP_F32_AVX512(c_float_0p2, min, max) + + // c[1, 0-15] + CLIP_F32_AVX512(c_float_1p0, min, max) + + // c[1, 16-31] + CLIP_F32_AVX512(c_float_1p1, min, max) + + // c[1, 32-47] + CLIP_F32_AVX512(c_float_1p2, min, max) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_2x48: + { + __m512 selector3 = _mm512_setzero_ps(); + + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + __m512 zero_point2 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector2,zero_point1); + + // c[0, 32-47] + SCL_MULRND_F32(c_float_0p2,selector3,zero_point2); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector1,zero_point0); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[1, 32-47] + SCL_MULRND_F32(c_float_1p2,selector3,zero_point2); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 1 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 1 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector1,zero_point0); + + // c[0, 32-47] + SCL_MULRND_F32(c_float_0p2,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector2,zero_point1); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[1, 32-47] + SCL_MULRND_F32(c_float_1p2,selector2,zero_point1); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_2x48: + { + __m512 selector3; + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + BF16_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,0); + + // c[1:0-15,16-31,32-47] + BF16_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,1); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,0); + + // c[1:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,1); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_2x48: + { + __m512 selector3; + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + BF16_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,0); + + // c[1:0-15,16-31,32-47] + BF16_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,1); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,0); + + // c[1:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,1); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_2x48: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[0, 16-31] + SWISH_F32_AVX512_DEF(c_float_0p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[0, 32-47] + SWISH_F32_AVX512_DEF(c_float_0p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 0-15] + SWISH_F32_AVX512_DEF(c_float_1p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 16-31] + SWISH_F32_AVX512_DEF(c_float_1p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 32-47] + SWISH_F32_AVX512_DEF(c_float_1p2, selector1, al_in, r, r2, z, dn, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_2x48_DISABLE: + ; + + if ( ( post_ops_attr.buf_downscale != NULL ) && + ( post_ops_attr.is_last_k == TRUE ) ) + { + // Generate a mask16 of all 1's. + __m512i selector_a = _mm512_setzero_epi32(); + __m512i selector_b = _mm512_set1_epi32( 10 ); + __mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b ); + + // Store the results in downscaled type (bf16 instead of float). + + // c[0, 0-15] + CVT_STORE_F32_BF16_MASK(c_float_0p0,0,0); + + // c[0, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_0p1,0,1); + + // c[0, 32-47] + CVT_STORE_F32_BF16_MASK(c_float_0p2,0,2); + + // c[1, 0-15] + CVT_STORE_F32_BF16_MASK(c_float_1p0,1,0); + + // c[1, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_1p1,1,1); + + // c[1, 32-47] + CVT_STORE_F32_BF16_MASK(c_float_1p2,1,2); + } + + else + { + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); + + // c[0, 16-31] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 1*16 ), c_float_0p1 ); + + // c[0,32-47] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 2*16 ), c_float_0p2 ); + + // c[1,0-15] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 0*16 ), c_float_1p0 ); + + // c[1,16-31] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 1*16 ), c_float_1p1 ); + + // c[1,32-47] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 2*16 ), c_float_1p2 ); + } +} + +// 1x48 bf16 kernel +LPGEMM_MN_FRINGE_KERN(bfloat16, int8_t, float, bf16s4f32of32_1x48) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_1x48_DISABLE, + &&POST_OPS_BIAS_1x48, + &&POST_OPS_RELU_1x48, + &&POST_OPS_RELU_SCALE_1x48, + &&POST_OPS_GELU_TANH_1x48, + &&POST_OPS_GELU_ERF_1x48, + &&POST_OPS_CLIP_1x48, + &&POST_OPS_DOWNSCALE_1x48, + &&POST_OPS_MATRIX_ADD_1x48, + &&POST_OPS_SWISH_1x48, + &&POST_OPS_MATRIX_MUL_1x48 + }; + + dim_t pre_op_off = post_ops_attr.pre_op_off; + + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int16_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + __m512bh b1; + __m512bh b2; + + __m256i b0_s4; + __m128i b1_s4; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + __m512i shift_idx_64; + MULTISHIFT_32BIT_8_INT4_IDX_64ELEM(shift_idx_64); + __m512i sign_comp = _mm512_set1_epi8(0x08); + + __m256i shift_idx_32; + MULTISHIFT_32BIT_8_INT4_IDX_32ELEM(shift_idx_32); + __m256i sign_comp_32 = _mm256_set1_epi8( 0x08 ); + + bool signed_upscale = true; + + /* regs to store intermediate int8 values */ + __m512i b0_s8; + __m256i b1_s8; + + /* Regs to store F32 scale values */ + __m512 scale0, scale1, scale2, scale3, scale4, scale5; + /* Reg to store masks to interleave scale factor */ + __m512i mask_scale1, mask_scale2; + + mask_scale1 = _mm512_set_epi32( 0x17, 0x07, 0x16, 0x06, 0x15, 0x05, 0x14, + 0x04, 0x13, 0x03, 0x12, 0x02, 0x11, 0x01, + 0x10, 0x00 ); + + mask_scale2 = _mm512_set_epi32( 0x1F, 0x0F, 0x1E, 0x0E, 0x1D, 0x0D, 0x1C, + 0x0C, 0x1B, 0x0B, 0x1A, 0x0A, 0x19, 0x09, + 0x18, 0x08); + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + __m512 c_float_0p1 = _mm512_setzero_ps(); + __m512 c_float_0p2 = _mm512_setzero_ps(); + + if( post_ops_attr.pre_op_scale_factor_len > 1 ) + { + // load and interleave scale factor vectors + scale0 = _mm512_loadu_ps( (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off); + scale2 = _mm512_loadu_ps( (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off + 16 ); + scale4 = _mm512_loadu_ps( (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off + 32 ); + + scale1 = _mm512_permutex2var_ps( scale0, mask_scale2, scale0 ); + scale0 = _mm512_permutex2var_ps( scale0, mask_scale1, scale0 ); + scale3 = _mm512_permutex2var_ps( scale2, mask_scale2, scale2 ); + scale2 = _mm512_permutex2var_ps( scale2, mask_scale1, scale2 ); + scale5 = _mm512_permutex2var_ps( scale4, mask_scale2, scale4 ); + scale4 = _mm512_permutex2var_ps( scale4, mask_scale1, scale4 ); + + } + else + { + scale0 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale1 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale2 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale3 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale4 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale5 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + } + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0_s4 = _mm256_loadu_si256( (__m256i const *)( b + ( rs_b * kr ) / 2 ) ); + + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_64, \ + sign_comp, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_16( b0_s8, 0, scale0 ) ); + + b1 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 3, scale3 ), + CVT_INT8_F32_SCAL_16( b0_s8, 2, scale2 ) ); + + b1_s4 = _mm_loadu_si128( (__m128i const *)( b + ( ( rs_b * kr ) / 2 ) + 32 ) ); + + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT( b1_s4, b1_s8, shift_idx_32, \ + sign_comp_32, signed_upscale); + + b2 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_8( b1_s8, 1, scale5 ), + CVT_INT8_F32_SCAL_8( b1_s8, 0, scale4 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-47] = a[0,kr:kr+2]*b[kr:kr+2,0-47] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0_s4 = _mm256_loadu_si256( (__m256i const *)( b + ( rs_b * k_full_pieces ) / 2 ) ); + + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_64, \ + sign_comp, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_16( b0_s8, 0, scale0 ) ); + + b1 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 3, scale3 ), + CVT_INT8_F32_SCAL_16( b0_s8, 2, scale2 ) ); + + b1_s4 = _mm_loadu_si128( (__m128i const *)( b + ( ( rs_b * k_full_pieces ) / 2 ) + 32 ) ); + + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT( b1_s4, b1_s8, shift_idx_32, \ + sign_comp_32, signed_upscale); + + b2 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_8( b1_s8, 1, scale5 ), + CVT_INT8_F32_SCAL_8( b1_s8, 0, scale4 ) ); + + // Broadcast a[0,kr:kr+2]. + a_kfringe_buf = *( a + (rs_a * 0) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-47] = a[0,kr:kr+2]*b[kr:kr+2,0-47] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + if ( alpha != 1 ) + { + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + c_float_0p1 = _mm512_mul_ps( selector1, c_float_0p1 ); + c_float_0p2 = _mm512_mul_ps( selector1, c_float_0p2 ); + } + + // Scale C by beta. + if ( beta != 0 ) + { + if ( ( post_ops_attr.buf_downscale != NULL ) && + ( post_ops_attr.is_first_k == TRUE ) ) + { + // c[0,0-15] + BF16_F32_BETA_OP(c_float_0p0,0,0,0,selector1,selector2) + + // c[0, 16-31] + BF16_F32_BETA_OP(c_float_0p1,0,0,1,selector1,selector2) + + // c[0,32-47] + BF16_F32_BETA_OP(c_float_0p2,0,0,2,selector1,selector2) + } + else + { + // c[0,0-15] + F32_F32_BETA_OP(c_float_0p0,0,0,0,selector1,selector2) + + // c[0, 16-31] + F32_F32_BETA_OP(c_float_0p1,0,0,1,selector1,selector2) + + // c[0,32-47] + F32_F32_BETA_OP(c_float_0p2,0,0,2,selector1,selector2) + } + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_1x48: + { + __m512 selector3; + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_LOAD(selector1, bias_mask, 0); + BF16_F32_BIAS_LOAD(selector2, bias_mask, 1); + BF16_F32_BIAS_LOAD(selector3, bias_mask, 2); + } + else + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); + } + else + { + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_1x48: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_max_ps( selector1, c_float_0p2 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_1x48: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_0p2) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_TANH_1x48: + { + __m512 dn, z, x, r2, r, x_tanh; + __m512i q; + + // c[0, 0-15] + GELU_TANH_F32_AVX512(c_float_0p0, r, r2, x, z, dn, x_tanh, q) + + // c[0, 16-31] + GELU_TANH_F32_AVX512(c_float_0p1, r, r2, x, z, dn, x_tanh, q) + + // c[0, 32-47] + GELU_TANH_F32_AVX512(c_float_0p2, r, r2, x, z, dn, x_tanh, q) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_ERF_1x48: + { + __m512 x, r, x_erf; + + // c[0, 0-15] + GELU_ERF_F32_AVX512(c_float_0p0, r, x, x_erf) + + // c[0, 16-31] + GELU_ERF_F32_AVX512(c_float_0p1, r, x, x_erf) + + // c[0, 32-47] + GELU_ERF_F32_AVX512(c_float_0p2, r, x, x_erf) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_CLIP_1x48: + { + __m512 min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 ); + __m512 max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 ); + + // c[0, 0-15] + CLIP_F32_AVX512(c_float_0p0, min, max) + + // c[0, 16-31] + CLIP_F32_AVX512(c_float_0p1, min, max) + + // c[0, 32-47] + CLIP_F32_AVX512(c_float_0p2, min, max) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + +POST_OPS_DOWNSCALE_1x48: + { + __m512 selector3 = _mm512_setzero_ps(); + + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + __m512 zero_point2 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector2,zero_point1); + + // c[0, 32-47] + SCL_MULRND_F32(c_float_0p2,selector3,zero_point2); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector1,zero_point0); + + // c[0, 32-47] + SCL_MULRND_F32(c_float_0p2,selector1,zero_point0); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_1x48: + { + __m512 selector3; + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + BF16_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,0); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,0); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_1x48: + { + __m512 selector3; + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + BF16_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,0); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,0); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_1x48: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[0, 16-31] + SWISH_F32_AVX512_DEF(c_float_0p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[0, 32-47] + SWISH_F32_AVX512_DEF(c_float_0p2, selector1, al_in, r, r2, z, dn, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_1x48_DISABLE: + ; + + if ( ( post_ops_attr.buf_downscale != NULL ) && + ( post_ops_attr.is_last_k == TRUE ) ) + { + // Generate a mask16 of all 1's. + __m512i selector_a = _mm512_setzero_epi32(); + __m512i selector_b = _mm512_set1_epi32( 10 ); + __mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b ); + + // Store the results in downscaled type (bf16 instead of float). + + // c[0, 0-15] + CVT_STORE_F32_BF16_MASK(c_float_0p0,0,0); + + // c[0, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_0p1,0,1); + + // c[0, 32-47] + CVT_STORE_F32_BF16_MASK(c_float_0p2,0,2); + } + + else + { + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); + + // c[0, 16-31] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 1*16 ), c_float_0p1 ); + + // c[0,32-47] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 2*16 ), c_float_0p2 ); + } +} +#endif +#endif diff --git a/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_n_fringe_bf16_amd512vnni.c b/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_n_fringe_bf16_amd512vnni.c index 36bc91d78f..2bad062cc0 100644 --- a/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_n_fringe_bf16_amd512vnni.c +++ b/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_n_fringe_bf16_amd512vnni.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -39,7 +39,7 @@ #include "lpgemm_f32_kern_macros.h" -#ifndef LPGEMM_BF16_NOT_SUPPORTED +#ifndef LPGEMM_BF16_JIT // 6xlt16 bf16 fringe kernel LPGEMM_N_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6xlt16) { @@ -52,7 +52,10 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6xlt16) &&POST_OPS_GELU_TANH_6xLT16, &&POST_OPS_GELU_ERF_6xLT16, &&POST_OPS_CLIP_6xLT16, - &&POST_OPS_DOWNSCALE_6xLT16 + &&POST_OPS_DOWNSCALE_6xLT16, + &&POST_OPS_MATRIX_ADD_6xLT16, + &&POST_OPS_SWISH_6xLT16, + &&POST_OPS_MATRIX_MUL_6xLT16 }; dim_t MR = 6; dim_t m_full_pieces = m0 / MR; @@ -70,9 +73,6 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6xlt16) // A matrix storage bfloat type __m512bh a_bf16_0; - // For corner cases. - float buf0[16]; - dim_t value; if(k_full_pieces > 40) @@ -338,27 +338,27 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6xlt16) __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); // c[0,0-15] - F32_F32_BETA_OP_NLT16F_MASK(load_mask, c_float_0p0, ir, 0, 0, \ + F32_F32_BETA_OP_NLT16F_MASK(c, load_mask, c_float_0p0, ir, 0, 0, \ selector1, selector2); // c[1,0-15] - F32_F32_BETA_OP_NLT16F_MASK(load_mask, c_float_1p0, ir, 1, 0, \ + F32_F32_BETA_OP_NLT16F_MASK(c, load_mask, c_float_1p0, ir, 1, 0, \ selector1, selector2); // c[2,0-15] - F32_F32_BETA_OP_NLT16F_MASK(load_mask, c_float_2p0, ir, 2, 0, \ + F32_F32_BETA_OP_NLT16F_MASK(c, load_mask, c_float_2p0, ir, 2, 0, \ selector1, selector2); // c[3,0-15] - F32_F32_BETA_OP_NLT16F_MASK(load_mask, c_float_3p0, ir, 3, 0, \ + F32_F32_BETA_OP_NLT16F_MASK(c, load_mask, c_float_3p0, ir, 3, 0, \ selector1, selector2); // c[4,0-15] - F32_F32_BETA_OP_NLT16F_MASK(load_mask, c_float_4p0, ir, 4, 0, \ + F32_F32_BETA_OP_NLT16F_MASK(c, load_mask, c_float_4p0, ir, 4, 0, \ selector1, selector2); // c[5,0-15] - F32_F32_BETA_OP_NLT16F_MASK(load_mask, c_float_5p0, ir, 5, 0, \ + F32_F32_BETA_OP_NLT16F_MASK(c, load_mask, c_float_5p0, ir, 5, 0, \ selector1, selector2); } } @@ -370,9 +370,21 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6xlt16) if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) { - memcpy( buf0, ( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( float ) ) ); - selector1 = _mm512_loadu_ps( buf0 ); + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + if ( post_ops_attr.c_stor_type == BF16 ) + { + BF16_F32_BIAS_LOAD(selector1, bias_mask, 0); + } + else + { + selector1 = + _mm512_maskz_loadu_ps + ( + bias_mask, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ); + } // c[0,0-15] c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); @@ -394,24 +406,41 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6xlt16) } else { - selector1 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 0 ) ); - selector2 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 1 ) ); - __m512 selector3 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 2 ) ); - __m512 selector4 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 3 ) ); - __m512 selector5 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 4 ) ); - __m512 selector6 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 5 ) ); + __m512 selector3; + __m512 selector4; + __m512 selector5; + __m512 selector6; + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + BF16_F32_BIAS_BCAST(selector2, bias_mask, 1); + BF16_F32_BIAS_BCAST(selector3, bias_mask, 2); + BF16_F32_BIAS_BCAST(selector4, bias_mask, 3); + BF16_F32_BIAS_BCAST(selector5, bias_mask, 4); + BF16_F32_BIAS_BCAST(selector6, bias_mask, 5); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 3 ) ); + selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 4 ) ); + selector6 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 5 ) ); + } // c[0,0-15] c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); @@ -539,7 +568,7 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6xlt16) { __m512 min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 ); __m512 max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 ); - + // c[0, 0-15] CLIP_F32_AVX512(c_float_0p0, min, max) @@ -560,29 +589,314 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6xlt16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_6xLT16: { - // c[0, 0-15] - MULRND_F32(c_float_0p0,0,0); + __m512 selector3 = _mm512_setzero_ps(); + __m512 selector4 = _mm512_setzero_ps(); + __m512 selector5 = _mm512_setzero_ps(); + __m512 selector6 = _mm512_setzero_ps(); + + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + __m512 zero_point2 = _mm512_setzero_ps(); + __m512 zero_point3 = _mm512_setzero_ps(); + __m512 zero_point4 = _mm512_setzero_ps(); + __m512 zero_point5 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + // Also the same value is loaded to different registers so that + // branching can be reduced and same code/register can be used + // irrespective of whether scalar or vector op. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector6 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point4 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point5 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = _mm512_maskz_loadu_ps( zp_mask, + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector1,zero_point0); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector1,zero_point0); + + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector1,zero_point0); + + // c[4, 0-15] + SCL_MULRND_F32(c_float_4p0,selector1,zero_point0); + + // c[5, 0-15] + SCL_MULRND_F32(c_float_5p0,selector1,zero_point0); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 3 ) ); + selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 4 ) ); + selector6 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 5 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 2 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 3 ) ) ); + zero_point4 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 4 ) ) ); + zero_point5 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 5 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); - // c[1, 0-15] - MULRND_F32(c_float_1p0,1,0); + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector2,zero_point1); - // c[2, 0-15] - MULRND_F32(c_float_2p0,2,0); + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector3,zero_point2); - // c[3, 0-15] - MULRND_F32(c_float_3p0,3,0); + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector4,zero_point3); - // c[4, 0-15] - MULRND_F32(c_float_4p0,4,0); + // c[4, 0-15] + SCL_MULRND_F32(c_float_4p0,selector5,zero_point4); - // c[5, 0-15] - MULRND_F32(c_float_5p0,5,0); + // c[5, 0-15] + SCL_MULRND_F32(c_float_5p0,selector6,zero_point5); + } POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_MATRIX_ADD_6xLT16: + { + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15] + BF16_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + BF16_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,1); + + // c[2:0-15] + BF16_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,2); + + // c[3:0-15] + BF16_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,3); + + // c[4:0-15] + BF16_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,4); + + // c[5:0-15] + BF16_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,5); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15] + F32_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + F32_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,1); + + // c[2:0-15] + F32_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,2); + + // c[3:0-15] + F32_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,3); + + // c[4:0-15] + F32_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,4); + + // c[5:0-15] + F32_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,5); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_6xLT16: + { + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15] + BF16_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + BF16_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,1); + + // c[2:0-15] + BF16_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,2); + + // c[3:0-15] + BF16_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,3); + + // c[4:0-15] + BF16_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,4); + + // c[5:0-15] + BF16_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,5); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15] + F32_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + F32_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,1); + + // c[2:0-15] + F32_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,2); + + // c[3:0-15] + F32_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,3); + + // c[4:0-15] + F32_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,4); + + // c[5:0-15] + F32_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,5); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_6xLT16: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 0-15] + SWISH_F32_AVX512_DEF(c_float_1p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 0-15] + SWISH_F32_AVX512_DEF(c_float_2p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[3, 0-15] + SWISH_F32_AVX512_DEF(c_float_3p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[4, 0-15] + SWISH_F32_AVX512_DEF(c_float_4p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[5, 0-15] + SWISH_F32_AVX512_DEF(c_float_5p0, selector1, al_in, r, r2, z, dn, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_6xLT16_DISABLE: ; // Store the results. @@ -613,7 +927,7 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6xlt16) else { __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); - + // Store the results. // c[0,0-15] _mm512_mask_storeu_ps( c + ( rs_c * ( ir + 0 ) ), load_mask, c_float_0p0 ); @@ -721,7 +1035,10 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x16) &&POST_OPS_GELU_TANH_6x16, &&POST_OPS_GELU_ERF_6x16, &&POST_OPS_CLIP_6x16, - &&POST_OPS_DOWNSCALE_6x16 + &&POST_OPS_DOWNSCALE_6x16, + &&POST_OPS_MATRIX_ADD_6x16, + &&POST_OPS_SWISH_6x16, + &&POST_OPS_MATRIX_MUL_6x16 }; dim_t MR = 6; dim_t m_full_pieces = m0 / MR; @@ -1033,9 +1350,17 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x16) if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) { - selector1 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j ); + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_LOAD(selector1, bias_mask, 0); + } + else + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ); + } // c[0,0-15] c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); @@ -1057,24 +1382,41 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x16) } else { - selector1 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 0 ) ); - selector2 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 1 ) ); - __m512 selector3 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 2 ) ); - __m512 selector4 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 3 ) ); - __m512 selector5 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 4 ) ); - __m512 selector6 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 5 ) ); + __m512 selector3; + __m512 selector4; + __m512 selector5; + __m512 selector6; + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + BF16_F32_BIAS_BCAST(selector2, bias_mask, 1); + BF16_F32_BIAS_BCAST(selector3, bias_mask, 2); + BF16_F32_BIAS_BCAST(selector4, bias_mask, 3); + BF16_F32_BIAS_BCAST(selector5, bias_mask, 4); + BF16_F32_BIAS_BCAST(selector6, bias_mask, 5); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 3 ) ); + selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 4 ) ); + selector6 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 5 ) ); + } // c[0,0-15] c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); @@ -1223,29 +1565,312 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_6x16: { - // c[0, 0-15] - MULRND_F32(c_float_0p0,0,0); + __m512 selector3 = _mm512_setzero_ps(); + __m512 selector4 = _mm512_setzero_ps(); + __m512 selector5 = _mm512_setzero_ps(); + __m512 selector6 = _mm512_setzero_ps(); + + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + __m512 zero_point2 = _mm512_setzero_ps(); + __m512 zero_point3 = _mm512_setzero_ps(); + __m512 zero_point4 = _mm512_setzero_ps(); + __m512 zero_point5 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + // Also the same value is loaded to different registers so that + // branching can be reduced and same code/register can be used + // irrespective of whether scalar or vector op. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector6 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point4 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point5 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = _mm512_maskz_loadu_ps( zp_mask, + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + } - // c[1, 0-15] - MULRND_F32(c_float_1p0,1,0); + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector1,zero_point0); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector1,zero_point0); + + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector1,zero_point0); + + // c[4, 0-15] + SCL_MULRND_F32(c_float_4p0,selector1,zero_point0); + + // c[5, 0-15] + SCL_MULRND_F32(c_float_5p0,selector1,zero_point0); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 3 ) ); + selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 4 ) ); + selector6 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 5 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 2 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 3 ) ) ); + zero_point4 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 4 ) ) ); + zero_point5 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 5 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); - // c[2, 0-15] - MULRND_F32(c_float_2p0,2,0); + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector2,zero_point1); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector3,zero_point2); - // c[3, 0-15] - MULRND_F32(c_float_3p0,3,0); + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector4,zero_point3); - // c[4, 0-15] - MULRND_F32(c_float_4p0,4,0); + // c[4, 0-15] + SCL_MULRND_F32(c_float_4p0,selector5,zero_point4); - // c[5, 0-15] - MULRND_F32(c_float_5p0,5,0); + // c[5, 0-15] + SCL_MULRND_F32(c_float_5p0,selector6,zero_point5); + } POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_MATRIX_ADD_6x16: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15] + BF16_F32_MATRIX_ADD_1COL(selector1,0); + + // c[1:0-15] + BF16_F32_MATRIX_ADD_1COL(selector1,1); + + // c[2:0-15] + BF16_F32_MATRIX_ADD_1COL(selector1,2); + + // c[3:0-15] + BF16_F32_MATRIX_ADD_1COL(selector1,3); + + // c[4:0-15] + BF16_F32_MATRIX_ADD_1COL(selector1,4); + + // c[5:0-15] + BF16_F32_MATRIX_ADD_1COL(selector1,5); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15] + F32_F32_MATRIX_ADD_1COL(selector1,0); + + // c[1:0-15] + F32_F32_MATRIX_ADD_1COL(selector1,1); + + // c[2:0-15] + F32_F32_MATRIX_ADD_1COL(selector1,2); + + // c[3:0-15] + F32_F32_MATRIX_ADD_1COL(selector1,3); + + // c[4:0-15] + F32_F32_MATRIX_ADD_1COL(selector1,4); + + // c[5:0-15] + F32_F32_MATRIX_ADD_1COL(selector1,5); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_6x16: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15] + BF16_F32_MATRIX_MUL_1COL(selector1,0); + + // c[1:0-15] + BF16_F32_MATRIX_MUL_1COL(selector1,1); + + // c[2:0-15] + BF16_F32_MATRIX_MUL_1COL(selector1,2); + + // c[3:0-15] + BF16_F32_MATRIX_MUL_1COL(selector1,3); + + // c[4:0-15] + BF16_F32_MATRIX_MUL_1COL(selector1,4); + + // c[5:0-15] + BF16_F32_MATRIX_MUL_1COL(selector1,5); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15] + F32_F32_MATRIX_MUL_1COL(selector1,0); + + // c[1:0-15] + F32_F32_MATRIX_MUL_1COL(selector1,1); + + // c[2:0-15] + F32_F32_MATRIX_MUL_1COL(selector1,2); + + // c[3:0-15] + F32_F32_MATRIX_MUL_1COL(selector1,3); + + // c[4:0-15] + F32_F32_MATRIX_MUL_1COL(selector1,4); + + // c[5:0-15] + F32_F32_MATRIX_MUL_1COL(selector1,5); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_6x16: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 0-15] + SWISH_F32_AVX512_DEF(c_float_1p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 0-15] + SWISH_F32_AVX512_DEF(c_float_2p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[3, 0-15] + SWISH_F32_AVX512_DEF(c_float_3p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[4, 0-15] + SWISH_F32_AVX512_DEF(c_float_4p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[5, 0-15] + SWISH_F32_AVX512_DEF(c_float_5p0, selector1, al_in, r, r2, z, dn, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_6x16_DISABLE: ; if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_last_k == TRUE ) ) @@ -1383,7 +2008,10 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x32) &&POST_OPS_GELU_TANH_6x32, &&POST_OPS_GELU_ERF_6x32, &&POST_OPS_CLIP_6x32, - &&POST_OPS_DOWNSCALE_6x32 + &&POST_OPS_DOWNSCALE_6x32, + &&POST_OPS_MATRIX_ADD_6x32, + &&POST_OPS_SWISH_6x32, + &&POST_OPS_MATRIX_MUL_6x32 }; dim_t MR = 6; dim_t m_full_pieces = m0 / MR; @@ -1711,7 +2339,7 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x32) // c[5, 16-31] BF16_F32_BETA_OP( c_float_5p1, ir, 5, 1, selector1, selector2 ); } - else + else { // c[0,0-15] F32_F32_BETA_OP( c_float_0p0, ir, 0, 0, selector1, selector2 ); @@ -1749,7 +2377,7 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x32) // c[5, 16-31] F32_F32_BETA_OP( c_float_5p1, ir, 5, 1, selector1, selector2 ); } - + } // Post Ops lpgemm_post_op* post_ops_list_temp = post_ops_list; @@ -1759,12 +2387,21 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x32) if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) { - selector1 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_LOAD(selector1, bias_mask, 0); + BF16_F32_BIAS_LOAD(selector2, bias_mask, 1); + } + else + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + } // c[0,0-15] c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); @@ -1804,24 +2441,41 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x32) } else { - selector1 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 0 ) ); - selector2 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 1 ) ); - __m512 selector3 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 2 ) ); - __m512 selector4 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 3 ) ); - __m512 selector5 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 4 ) ); - __m512 selector6 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 5 ) ); + __m512 selector3; + __m512 selector4; + __m512 selector5; + __m512 selector6; + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + BF16_F32_BIAS_BCAST(selector2, bias_mask, 1); + BF16_F32_BIAS_BCAST(selector3, bias_mask, 2); + BF16_F32_BIAS_BCAST(selector4, bias_mask, 3); + BF16_F32_BIAS_BCAST(selector5, bias_mask, 4); + BF16_F32_BIAS_BCAST(selector6, bias_mask, 5); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 3 ) ); + selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 4 ) ); + selector6 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 5 ) ); + } // c[0,0-15] c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); @@ -2078,55 +2732,378 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x32) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_6x32: { - // c[0, 0-15] - MULRND_F32(c_float_0p0,0,0); + __m512 selector3 = _mm512_setzero_ps(); + __m512 selector4 = _mm512_setzero_ps(); + __m512 selector5 = _mm512_setzero_ps(); + __m512 selector6 = _mm512_setzero_ps(); + + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + __m512 zero_point2 = _mm512_setzero_ps(); + __m512 zero_point3 = _mm512_setzero_ps(); + __m512 zero_point4 = _mm512_setzero_ps(); + __m512 zero_point5 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector6 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } - // c[0, 16-31] - MULRND_F32(c_float_0p1,0,1); + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point4 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point5 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } - // c[1, 0-15] - MULRND_F32(c_float_1p0,1,0); + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + } - // c[1, 16-31] - MULRND_F32(c_float_1p1,1,1); + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + } - // c[2, 0-15] - MULRND_F32(c_float_2p0,2,0); + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); - // c[2, 16-31] - MULRND_F32(c_float_2p1,2,1); + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector2,zero_point1); - // c[3, 0-15] - MULRND_F32(c_float_3p0,3,0); + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector1,zero_point0); - // c[3, 16-31] - MULRND_F32(c_float_3p1,3,1); + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); - // c[4, 0-15] - MULRND_F32(c_float_4p0,4,0); + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector1,zero_point0); - // c[4, 16-31] - MULRND_F32(c_float_4p1,4,1); + // c[2, 16-31] + SCL_MULRND_F32(c_float_2p1,selector2,zero_point1); - // c[5, 0-15] - MULRND_F32(c_float_5p0,5,0); + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector1,zero_point0); - // c[5, 16-31] - MULRND_F32(c_float_5p1,5,1); + // c[3, 16-31] + SCL_MULRND_F32(c_float_3p1,selector2,zero_point1); - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_6x32_DISABLE: - ; - if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_last_k == TRUE ) ) - { - // Generate a mask16 of all 1's. - __m512i selector_a = _mm512_setzero_epi32(); - __m512i selector_b = _mm512_set1_epi32( 10 ); - __mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b ); + // c[4, 0-15] + SCL_MULRND_F32(c_float_4p0,selector1,zero_point0); + + // c[4, 16-31] + SCL_MULRND_F32(c_float_4p1,selector2,zero_point1); + + // c[5, 0-15] + SCL_MULRND_F32(c_float_5p0,selector1,zero_point0); + + // c[5, 16-31] + SCL_MULRND_F32(c_float_5p1,selector2,zero_point1); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 3 ) ); + selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 4 ) ); + selector6 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 5 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 2 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 3 ) ) ); + zero_point4 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 4 ) ) ); + zero_point5 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 5 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector2,zero_point1); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector3,zero_point2); + + // c[2, 16-31] + SCL_MULRND_F32(c_float_2p1,selector3,zero_point2); + + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector4,zero_point3); + + // c[3, 16-31] + SCL_MULRND_F32(c_float_3p1,selector4,zero_point3); + + // c[4, 0-15] + SCL_MULRND_F32(c_float_4p0,selector5,zero_point4); + + // c[4, 16-31] + SCL_MULRND_F32(c_float_4p1,selector5,zero_point4); + + // c[5, 0-15] + SCL_MULRND_F32(c_float_5p0,selector6,zero_point5); + + // c[5, 16-31] + SCL_MULRND_F32(c_float_5p1,selector6,zero_point5); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_6x32: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + BF16_F32_MATRIX_ADD_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + BF16_F32_MATRIX_ADD_2COL(selector1,selector2,1); + + // c[2:0-15,16-31] + BF16_F32_MATRIX_ADD_2COL(selector1,selector2,2); + + // c[3:0-15,16-31] + BF16_F32_MATRIX_ADD_2COL(selector1,selector2,3); + + // c[4:0-15,16-31] + BF16_F32_MATRIX_ADD_2COL(selector1,selector2,4); + + // c[5:0-15,16-31] + BF16_F32_MATRIX_ADD_2COL(selector1,selector2,5); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(selector1,selector2,1); + + // c[2:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(selector1,selector2,2); + + // c[3:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(selector1,selector2,3); + + // c[4:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(selector1,selector2,4); + + // c[5:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(selector1,selector2,5); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_6x32: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + BF16_F32_MATRIX_MUL_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + BF16_F32_MATRIX_MUL_2COL(selector1,selector2,1); + + // c[2:0-15,16-31] + BF16_F32_MATRIX_MUL_2COL(selector1,selector2,2); + + // c[3:0-15,16-31] + BF16_F32_MATRIX_MUL_2COL(selector1,selector2,3); + + // c[4:0-15,16-31] + BF16_F32_MATRIX_MUL_2COL(selector1,selector2,4); + + // c[5:0-15,16-31] + BF16_F32_MATRIX_MUL_2COL(selector1,selector2,5); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(selector1,selector2,1); + + // c[2:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(selector1,selector2,2); + + // c[3:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(selector1,selector2,3); + + // c[4:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(selector1,selector2,4); + + // c[5:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(selector1,selector2,5); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_6x32: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[0, 16-31] + SWISH_F32_AVX512_DEF(c_float_0p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 0-15] + SWISH_F32_AVX512_DEF(c_float_1p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 16-31] + SWISH_F32_AVX512_DEF(c_float_1p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 0-15] + SWISH_F32_AVX512_DEF(c_float_2p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 16-31] + SWISH_F32_AVX512_DEF(c_float_2p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[3, 0-15] + SWISH_F32_AVX512_DEF(c_float_3p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[3, 16-31] + SWISH_F32_AVX512_DEF(c_float_3p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[4, 0-15] + SWISH_F32_AVX512_DEF(c_float_4p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[4, 16-31] + SWISH_F32_AVX512_DEF(c_float_4p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[5, 0-15] + SWISH_F32_AVX512_DEF(c_float_5p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[5, 16-31] + SWISH_F32_AVX512_DEF(c_float_5p1, selector1, al_in, r, r2, z, dn, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_6x32_DISABLE: + ; + if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_last_k == TRUE ) ) + { + // Generate a mask16 of all 1's. + __m512i selector_a = _mm512_setzero_epi32(); + __m512i selector_b = _mm512_set1_epi32( 10 ); + __mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b ); // Store the results in downscaled type (int8 instead of int32). // c[0,0-15] @@ -2292,7 +3269,10 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x48) &&POST_OPS_GELU_TANH_6x48, &&POST_OPS_GELU_ERF_6x48, &&POST_OPS_CLIP_6x48, - &&POST_OPS_DOWNSCALE_6x48 + &&POST_OPS_DOWNSCALE_6x48, + &&POST_OPS_MATRIX_ADD_6x48, + &&POST_OPS_SWISH_6x48, + &&POST_OPS_MATRIX_MUL_6x48 }; dim_t MR = 6; dim_t m_full_pieces = m0 / MR; @@ -2748,15 +3728,25 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x48) if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) { - selector1 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ); - selector3 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_LOAD(selector1, bias_mask, 0); + BF16_F32_BIAS_LOAD(selector2, bias_mask, 1); + BF16_F32_BIAS_LOAD(selector3, bias_mask, 2); + } + else + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + } // c[0,0-15] c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); @@ -2814,24 +3804,40 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x48) } else { - selector1 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 0 ) ); - selector2 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 1 ) ); - selector3 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 2 ) ); - __m512 selector4 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 3 ) ); - __m512 selector5 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 4 ) ); - __m512 selector6 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_ops_attr.post_op_c_i + 5 ) ); + __m512 selector4; + __m512 selector5; + __m512 selector6; + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + BF16_F32_BIAS_BCAST(selector2, bias_mask, 1); + BF16_F32_BIAS_BCAST(selector3, bias_mask, 2); + BF16_F32_BIAS_BCAST(selector4, bias_mask, 3); + BF16_F32_BIAS_BCAST(selector5, bias_mask, 4); + BF16_F32_BIAS_BCAST(selector6, bias_mask, 5); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 3 ) ); + selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 4 ) ); + selector6 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 5 ) ); + } // c[0,0-15] c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); @@ -3197,65 +4203,426 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x48) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_6x48: { - // c[0, 0-15] - MULRND_F32(c_float_0p0,0,0); + __m512 selector3 = _mm512_setzero_ps(); + __m512 selector4 = _mm512_setzero_ps(); + + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + __m512 zero_point2 = _mm512_setzero_ps(); + __m512 zero_point3 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector2,zero_point1); + + // c[0, 32-47] + SCL_MULRND_F32(c_float_0p2,selector3,zero_point2); - // c[0, 16-31] - MULRND_F32(c_float_0p1,0,1); + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector1,zero_point0); - // c[0, 32-47] - MULRND_F32(c_float_0p2,0,2); + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); - // c[1, 0-15] - MULRND_F32(c_float_1p0,1,0); + // c[1, 32-47] + SCL_MULRND_F32(c_float_1p2,selector3,zero_point2); - // c[1, 16-31] - MULRND_F32(c_float_1p1,1,1); + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector1,zero_point0); - // c[1, 32-47] - MULRND_F32(c_float_1p2,1,2); + // c[2, 16-31] + SCL_MULRND_F32(c_float_2p1,selector2,zero_point1); - // c[2, 0-15] - MULRND_F32(c_float_2p0,2,0); + // c[2, 32-47] + SCL_MULRND_F32(c_float_2p2,selector3,zero_point2); - // c[2, 16-31] - MULRND_F32(c_float_2p1,2,1); + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector1,zero_point0); - // c[2, 32-47] - MULRND_F32(c_float_2p2,2,2); + // c[3, 16-31] + SCL_MULRND_F32(c_float_3p1,selector2,zero_point1); - // c[3, 0-15] - MULRND_F32(c_float_3p0,3,0); + // c[3, 32-47] + SCL_MULRND_F32(c_float_3p2,selector3,zero_point2); - // c[3, 16-31] - MULRND_F32(c_float_3p1,3,1); + // c[4, 0-15] + SCL_MULRND_F32(c_float_4p0,selector1,zero_point0); - // c[3, 32-47] - MULRND_F32(c_float_3p2,3,2); + // c[4, 16-31] + SCL_MULRND_F32(c_float_4p1,selector2,zero_point1); - // c[4, 0-15] - MULRND_F32(c_float_4p0,4,0); + // c[4, 32-47] + SCL_MULRND_F32(c_float_4p2,selector3,zero_point2); - // c[4, 16-31] - MULRND_F32(c_float_4p1,4,1); + // c[5, 0-15] + SCL_MULRND_F32(c_float_5p0,selector1,zero_point0); - // c[4, 32-47] - MULRND_F32(c_float_4p2,4,2); + // c[5, 16-31] + SCL_MULRND_F32(c_float_5p1,selector2,zero_point1); - // c[5, 0-15] - MULRND_F32(c_float_5p0,5,0); + // c[5, 32-47] + SCL_MULRND_F32(c_float_5p2,selector3,zero_point2); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 3 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 2 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 3 ) ) ); + } - // c[5, 16-31] - MULRND_F32(c_float_5p1,5,1); + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); - // c[5, 32-47] - MULRND_F32(c_float_5p2,5,2); + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector1,zero_point0); + + // c[0, 32-47] + SCL_MULRND_F32(c_float_0p2,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector2,zero_point1); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[1, 32-47] + SCL_MULRND_F32(c_float_1p2,selector2,zero_point1); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector3,zero_point2); + + // c[2, 16-31] + SCL_MULRND_F32(c_float_2p1,selector3,zero_point2); + + // c[2, 32-47] + SCL_MULRND_F32(c_float_2p2,selector3,zero_point2); + + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector4,zero_point3); + + // c[3, 16-31] + SCL_MULRND_F32(c_float_3p1,selector4,zero_point3); + + // c[3, 32-47] + SCL_MULRND_F32(c_float_3p2,selector4,zero_point3); + + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 4 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 5 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 4 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 5 ) ) ); + } + // c[4, 0-15] + SCL_MULRND_F32(c_float_4p0,selector1,zero_point0); + + // c[4, 16-31] + SCL_MULRND_F32(c_float_4p1,selector1,zero_point0); + + // c[4, 32-47] + SCL_MULRND_F32(c_float_4p2,selector1,zero_point0); + + // c[5, 0-15] + SCL_MULRND_F32(c_float_5p0,selector2,zero_point1); + + // c[5, 16-31] + SCL_MULRND_F32(c_float_5p1,selector2,zero_point1); + + // c[5, 32-47] + SCL_MULRND_F32(c_float_5p2,selector2,zero_point1); + } POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_MATRIX_ADD_6x48: + { + __m512 selector3; + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + BF16_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,0); + + // c[1:0-15,16-31,32-47] + BF16_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,1); + + // c[2:0-15,16-31,32-47] + BF16_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,2); + + // c[3:0-15,16-31,32-47] + BF16_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,3); + + // c[4:0-15,16-31,32-47] + BF16_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,4); + + // c[5:0-15,16-31,32-47] + BF16_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,5); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,0); + + // c[1:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,1); + + // c[2:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,2); + + // c[3:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,3); + + // c[4:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,4); + + // c[5:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,5); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_6x48: + { + __m512 selector3; + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + BF16_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,0); + + // c[1:0-15,16-31,32-47] + BF16_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,1); + + // c[2:0-15,16-31,32-47] + BF16_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,2); + + // c[3:0-15,16-31,32-47] + BF16_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,3); + + // c[4:0-15,16-31,32-47] + BF16_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,4); + + // c[5:0-15,16-31,32-47] + BF16_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,5); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,0); + + // c[1:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,1); + + // c[2:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,2); + + // c[3:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,3); + + // c[4:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,4); + + // c[5:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,5); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_6x48: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[0, 16-31] + SWISH_F32_AVX512_DEF(c_float_0p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[0, 32-47] + SWISH_F32_AVX512_DEF(c_float_0p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 0-15] + SWISH_F32_AVX512_DEF(c_float_1p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 16-31] + SWISH_F32_AVX512_DEF(c_float_1p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 32-47] + SWISH_F32_AVX512_DEF(c_float_1p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 0-15] + SWISH_F32_AVX512_DEF(c_float_2p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 16-31] + SWISH_F32_AVX512_DEF(c_float_2p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 32-47] + SWISH_F32_AVX512_DEF(c_float_2p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[3, 0-15] + SWISH_F32_AVX512_DEF(c_float_3p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[3, 16-31] + SWISH_F32_AVX512_DEF(c_float_3p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[3, 32-47] + SWISH_F32_AVX512_DEF(c_float_3p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[4, 0-15] + SWISH_F32_AVX512_DEF(c_float_4p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[4, 16-31] + SWISH_F32_AVX512_DEF(c_float_4p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[4, 32-47] + SWISH_F32_AVX512_DEF(c_float_4p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[5, 0-15] + SWISH_F32_AVX512_DEF(c_float_5p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[5, 16-31] + SWISH_F32_AVX512_DEF(c_float_5p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[5, 32-47] + SWISH_F32_AVX512_DEF(c_float_5p2, selector1, al_in, r, r2, z, dn, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_6x48_DISABLE: ; // Case where the output C matrix is bf16 (downscaled) and this is the diff --git a/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_n_fringe_bf16s4f32of32_amd512vnni.c b/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_n_fringe_bf16s4f32of32_amd512vnni.c new file mode 100644 index 0000000000..075fa4dee1 --- /dev/null +++ b/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_n_fringe_bf16s4f32of32_amd512vnni.c @@ -0,0 +1,5103 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ +#include +#include +#include "blis.h" + +#ifdef BLIS_ADDON_LPGEMM + +#include "lpgemm_f32_kern_macros.h" +#include "../int4_utils_avx512.h" + +#ifndef LPGEMM_BF16_JIT + +// 6xlt16 bf16s4f32of32 fringe kernel +LPGEMM_N_LT_NR0_FRINGE_KERN(bfloat16, int8_t, float, bf16s4f32of32_6xlt16m) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_6xLT16_DISABLE, + &&POST_OPS_BIAS_6xLT16, + &&POST_OPS_RELU_6xLT16, + &&POST_OPS_RELU_SCALE_6xLT16, + &&POST_OPS_GELU_TANH_6xLT16, + &&POST_OPS_GELU_ERF_6xLT16, + &&POST_OPS_CLIP_6xLT16, + &&POST_OPS_DOWNSCALE_6xLT16, + &&POST_OPS_MATRIX_ADD_6xLT16, + &&POST_OPS_SWISH_6xLT16, + &&POST_OPS_MATRIX_MUL_6xLT16 + }; + + dim_t pre_op_off = post_ops_attr.pre_op_off; + + dim_t MR = 6; + dim_t m_full_pieces = m0 / MR; + dim_t m_full_pieces_loop_limit = m_full_pieces * MR; + dim_t m_partial_pieces = m0 % MR; + + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int16_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + __m128i b0_s4; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + dim_t value; + + if(k_full_pieces > 40) + { + value = 40; + } + else + { + value = 0; + } + __m256i shift_idx_32; + MULTISHIFT_32BIT_8_INT4_IDX_32ELEM(shift_idx_32); + __m256i sign_comp_32 = _mm256_set1_epi8( 0x08 ); + + bool signed_upscale = true; + + /* regs to store intermediate int8 values */ + __m256i b0_s8; + + /* Regs to store F32 scale values */ + __m512 scale0, scale1; + /* Reg to store masks to interleave scale factor */ + __m512i mask_scale1, mask_scale2; + + mask_scale1 = _mm512_set_epi32( 0x17, 0x07, 0x16, 0x06, 0x15, 0x05, 0x14, + 0x04, 0x13, 0x03, 0x12, 0x02, 0x11, 0x01, + 0x10, 0x00 ); + + mask_scale2 = _mm512_set_epi32( 0x1F, 0x0F, 0x1E, 0x0E, 0x1D, 0x0D, 0x1C, + 0x0C, 0x1B, 0x0B, 0x1A, 0x0A, 0x19, 0x09, + 0x18, 0x08); + + if( post_ops_attr.pre_op_scale_factor_len > 1 ) + { + + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + + // load and interleave scale factor vectors + scale0 = _mm512_maskz_loadu_ps( load_mask, + (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off); + + scale1 = _mm512_permutex2var_ps( scale0, mask_scale2, scale0 ); + scale0 = _mm512_permutex2var_ps( scale0, mask_scale1, scale0 ); + + } + else + { + scale0 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale1 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + } + + + for ( dim_t ir = 0; ir < m_full_pieces_loop_limit; ir += MR ) + { + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + + __m512 c_float_2p0 = _mm512_setzero_ps(); + + __m512 c_float_3p0 = _mm512_setzero_ps(); + + __m512 c_float_4p0 = _mm512_setzero_ps(); + + __m512 c_float_5p0 = _mm512_setzero_ps(); + + for ( dim_t kr = 0; kr < k_full_pieces - value; kr += 1 ) + { + b0_s4 = _mm_loadu_si128( (__m128i const *)( b + ( ( rs_b * kr ) / 2 ) ) ); + + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_32, \ + sign_comp_32, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_8( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_8( b0_s8, 0, scale0 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( + *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( + *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + + // Broadcast a[2,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( + *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + + // Broadcast a[3,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( + *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-15] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + + // Broadcast a[4,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( + *( int32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[4,0-15] = a[4,kr:kr+2]*b[kr:kr+2,0-15] + c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); + + // Broadcast a[5,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( + *( int32_t* )( a + ( rs_a * 5 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[5,0-15] = a[5,kr:kr+2]*b[kr:kr+2,0-15] + c_float_5p0 = _mm512_dpbf16_ps( c_float_5p0, a_bf16_0, b0 ); + } + + _mm_prefetch(c + (rs_c * (ir + 0)) + (0 * 16), _MM_HINT_T1); + + _mm_prefetch(c + (rs_c * (ir + 1)) + (0 * 16), _MM_HINT_T1); + + _mm_prefetch(c + (rs_c * (ir + 2)) + (0 * 16), _MM_HINT_T1); + + _mm_prefetch(c + (rs_c * (ir + 3)) + (0 * 16), _MM_HINT_T1); + + _mm_prefetch(c + (rs_c * (ir + 4)) + (0 * 16), _MM_HINT_T1); + + _mm_prefetch(c + (rs_c * (ir + 5)) + (0 * 16), _MM_HINT_T1); + + for (dim_t kr = k_full_pieces - value; kr < k_full_pieces; kr += 1) + { + b0_s4 = _mm_loadu_si128( (__m128i const *)( b + ( ( rs_b * kr ) / 2 ) ) ); + + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_32, \ + sign_comp_32, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_8( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_8( b0_s8, 0, scale0 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( + *(int32_t *)(a + (rs_a * 0) + (cs_a * kr))); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps(c_float_0p0, a_bf16_0, b0); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( + *(int32_t *)(a + (rs_a * 1) + (cs_a * kr))); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] + c_float_1p0 = _mm512_dpbf16_ps(c_float_1p0, a_bf16_0, b0); + + // Broadcast a[2,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( + *(int32_t *)(a + (rs_a * 2) + (cs_a * kr))); + + // Perform column direction mat-mul with k = 2. + // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] + c_float_2p0 = _mm512_dpbf16_ps(c_float_2p0, a_bf16_0, b0); + + // Broadcast a[3,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( + *(int32_t *)(a + (rs_a * 3) + (cs_a * kr))); + + // Perform column direction mat-mul with k = 2. + // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-15] + c_float_3p0 = _mm512_dpbf16_ps(c_float_3p0, a_bf16_0, b0); + + // Broadcast a[4,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( + *(int32_t *)(a + (rs_a * 4) + (cs_a * kr))); + + // Perform column direction mat-mul with k = 2. + // c[4,0-15] = a[4,kr:kr+2]*b[kr:kr+2,0-15] + c_float_4p0 = _mm512_dpbf16_ps(c_float_4p0, a_bf16_0, b0); + + // Broadcast a[5,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( + *(int32_t *)(a + (rs_a * 5) + (cs_a * kr))); + + // Perform column direction mat-mul with k = 2. + // c[5,0-15] = a[5,kr:kr+2]*b[kr:kr+2,0-15] + c_float_5p0 = _mm512_dpbf16_ps(c_float_5p0, a_bf16_0, b0); + } + + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0_s4 = _mm_loadu_si128( (__m128i const *)( b + ( ( rs_b * k_full_pieces ) / 2 ) ) ); + + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_32, \ + sign_comp_32, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_8( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_8( b0_s8, 0, scale0 ) ); + // Broadcast a[0,kr:kr+2]. + a_kfringe_buf = *( a + (rs_a * 0) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 1) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + + // Broadcast a[2,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 2) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + + // Broadcast a[3,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 3) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-15] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + + // Broadcast a[4,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 4) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[4,0-15] = a[4,kr:kr+2]*b[kr:kr+2,0-15] + c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); + + // Broadcast a[5,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 5) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[5,0-15] = a[5,kr:kr+2]*b[kr:kr+2,0-15] + c_float_5p0 = _mm512_dpbf16_ps( c_float_5p0, a_bf16_0, b0 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + if ( alpha != 1 ) + { + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + + c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); + + c_float_3p0 = _mm512_mul_ps( selector1, c_float_3p0 ); + + c_float_4p0 = _mm512_mul_ps( selector1, c_float_4p0 ); + + c_float_5p0 = _mm512_mul_ps( selector1, c_float_5p0 ); + } + + // Scale C by beta. + if ( beta != 0 ) + { + if ( ( post_ops_attr.buf_downscale != NULL ) && + ( post_ops_attr.is_first_k == TRUE ) ) + { + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + + // c[0,0-15] + BF16_F32_BETA_OP_NLT16F_MASK( load_mask, c_float_0p0, 0, 0, \ + selector1, selector2 ); + + // c[1,0-15] + BF16_F32_BETA_OP_NLT16F_MASK( load_mask, c_float_1p0, 1, 0, \ + selector1, selector2 ); + + // c[2,0-15] + BF16_F32_BETA_OP_NLT16F_MASK( load_mask, c_float_2p0, 2, 0, \ + selector1, selector2 ); + + // c[3,0-15] + BF16_F32_BETA_OP_NLT16F_MASK( load_mask, c_float_3p0, 3, 0, \ + selector1, selector2 ); + + // c[4,0-15] + BF16_F32_BETA_OP_NLT16F_MASK( load_mask, c_float_4p0, 4, 0, \ + selector1, selector2 ); + + // c[5,0-15] + BF16_F32_BETA_OP_NLT16F_MASK( load_mask, c_float_5p0, 5, 0, \ + selector1, selector2 ); + } + else + { + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + + // c[0,0-15] + F32_F32_BETA_OP_NLT16F_MASK(c, load_mask, c_float_0p0, ir, 0, 0, \ + selector1, selector2); + + // c[1,0-15] + F32_F32_BETA_OP_NLT16F_MASK(c, load_mask, c_float_1p0, ir, 1, 0, \ + selector1, selector2); + + // c[2,0-15] + F32_F32_BETA_OP_NLT16F_MASK(c, load_mask, c_float_2p0, ir, 2, 0, \ + selector1, selector2); + + // c[3,0-15] + F32_F32_BETA_OP_NLT16F_MASK(c, load_mask, c_float_3p0, ir, 3, 0, \ + selector1, selector2); + + // c[4,0-15] + F32_F32_BETA_OP_NLT16F_MASK(c, load_mask, c_float_4p0, ir, 4, 0, \ + selector1, selector2); + + // c[5,0-15] + F32_F32_BETA_OP_NLT16F_MASK(c, load_mask, c_float_5p0, ir, 5, 0, \ + selector1, selector2); + } + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_6xLT16: + { + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + if ( post_ops_attr.c_stor_type == BF16 ) + { + BF16_F32_BIAS_LOAD(selector1, bias_mask, 0); + } + else + { + selector1 = + _mm512_maskz_loadu_ps + ( + bias_mask, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + + // c[5,0-15] + c_float_5p0 = _mm512_add_ps( selector1, c_float_5p0 ); + } + else + { + __m512 selector3; + __m512 selector4; + __m512 selector5; + __m512 selector6; + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + BF16_F32_BIAS_BCAST(selector2, bias_mask, 1); + BF16_F32_BIAS_BCAST(selector3, bias_mask, 2); + BF16_F32_BIAS_BCAST(selector4, bias_mask, 3); + BF16_F32_BIAS_BCAST(selector5, bias_mask, 4); + BF16_F32_BIAS_BCAST(selector6, bias_mask, 5); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 3 ) ); + selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 4 ) ); + selector6 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 5 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector4, c_float_3p0 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector5, c_float_4p0 ); + + // c[5,0-15] + c_float_5p0 = _mm512_add_ps( selector6, c_float_5p0 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_6xLT16: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + // c[3,0-15] + c_float_3p0 = _mm512_max_ps( selector1, c_float_3p0 ); + + // c[4,0-15] + c_float_4p0 = _mm512_max_ps( selector1, c_float_4p0 ); + + // c[5,0-15] + c_float_5p0 = _mm512_max_ps( selector1, c_float_5p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_6xLT16: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + // c[3, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_3p0) + + // c[4, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_4p0) + + // c[5, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_5p0) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_TANH_6xLT16: + { + __m512 dn, z, x, r2, r, x_tanh; + __m512i q; + + // c[0, 0-15] + GELU_TANH_F32_AVX512(c_float_0p0, r, r2, x, z, dn, x_tanh, q) + + // c[1, 0-15] + GELU_TANH_F32_AVX512(c_float_1p0, r, r2, x, z, dn, x_tanh, q) + + // c[2, 0-15] + GELU_TANH_F32_AVX512(c_float_2p0, r, r2, x, z, dn, x_tanh, q) + + // c[3, 0-15] + GELU_TANH_F32_AVX512(c_float_3p0, r, r2, x, z, dn, x_tanh, q) + + // c[4, 0-15] + GELU_TANH_F32_AVX512(c_float_4p0, r, r2, x, z, dn, x_tanh, q) + + // c[5, 0-15] + GELU_TANH_F32_AVX512(c_float_5p0, r, r2, x, z, dn, x_tanh, q) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_ERF_6xLT16: + { + __m512 x, r, x_erf; + + // c[0, 0-15] + GELU_ERF_F32_AVX512(c_float_0p0, r, x, x_erf) + + // c[1, 0-15] + GELU_ERF_F32_AVX512(c_float_1p0, r, x, x_erf) + + // c[2, 0-15] + GELU_ERF_F32_AVX512(c_float_2p0, r, x, x_erf) + + // c[3, 0-15] + GELU_ERF_F32_AVX512(c_float_3p0, r, x, x_erf) + + // c[4, 0-15] + GELU_ERF_F32_AVX512(c_float_4p0, r, x, x_erf) + + // c[5, 0-15] + GELU_ERF_F32_AVX512(c_float_5p0, r, x, x_erf) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_CLIP_6xLT16: + { + __m512 min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 ); + __m512 max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 ); + + // c[0, 0-15] + CLIP_F32_AVX512(c_float_0p0, min, max) + + // c[1, 0-15] + CLIP_F32_AVX512(c_float_1p0, min, max) + + // c[2, 0-15] + CLIP_F32_AVX512(c_float_2p0, min, max) + + // c[3, 0-15] + CLIP_F32_AVX512(c_float_3p0, min, max) + + // c[4, 0-15] + CLIP_F32_AVX512(c_float_4p0, min, max) + + // c[5, 0-15] + CLIP_F32_AVX512(c_float_5p0, min, max) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_6xLT16: + { + __m512 selector3 = _mm512_setzero_ps(); + __m512 selector4 = _mm512_setzero_ps(); + __m512 selector5 = _mm512_setzero_ps(); + __m512 selector6 = _mm512_setzero_ps(); + + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + __m512 zero_point2 = _mm512_setzero_ps(); + __m512 zero_point3 = _mm512_setzero_ps(); + __m512 zero_point4 = _mm512_setzero_ps(); + __m512 zero_point5 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + // Also the same value is loaded to different registers so that + // branching can be reduced and same code/register can be used + // irrespective of whether scalar or vector op. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector6 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point4 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point5 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = _mm512_maskz_loadu_ps( zp_mask, + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector1,zero_point0); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector1,zero_point0); + + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector1,zero_point0); + + // c[4, 0-15] + SCL_MULRND_F32(c_float_4p0,selector1,zero_point0); + + // c[5, 0-15] + SCL_MULRND_F32(c_float_5p0,selector1,zero_point0); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 3 ) ); + selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 4 ) ); + selector6 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 5 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 2 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 3 ) ) ); + zero_point4 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 4 ) ) ); + zero_point5 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 5 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector2,zero_point1); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector3,zero_point2); + + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector4,zero_point3); + + // c[4, 0-15] + SCL_MULRND_F32(c_float_4p0,selector5,zero_point4); + + // c[5, 0-15] + SCL_MULRND_F32(c_float_5p0,selector6,zero_point5); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_6xLT16: + { + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15] + BF16_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + BF16_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,1); + + // c[2:0-15] + BF16_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,2); + + // c[3:0-15] + BF16_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,3); + + // c[4:0-15] + BF16_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,4); + + // c[5:0-15] + BF16_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,5); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15] + F32_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + F32_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,1); + + // c[2:0-15] + F32_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,2); + + // c[3:0-15] + F32_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,3); + + // c[4:0-15] + F32_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,4); + + // c[5:0-15] + F32_F32_MATRIX_ADD_1COL_PAR(load_mask,selector1,5); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_6xLT16: + { + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15] + BF16_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + BF16_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,1); + + // c[2:0-15] + BF16_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,2); + + // c[3:0-15] + BF16_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,3); + + // c[4:0-15] + BF16_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,4); + + // c[5:0-15] + BF16_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,5); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15] + F32_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + F32_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,1); + + // c[2:0-15] + F32_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,2); + + // c[3:0-15] + F32_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,3); + + // c[4:0-15] + F32_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,4); + + // c[5:0-15] + F32_F32_MATRIX_MUL_1COL_PAR(load_mask,selector1,5); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_6xLT16: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 0-15] + SWISH_F32_AVX512_DEF(c_float_1p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 0-15] + SWISH_F32_AVX512_DEF(c_float_2p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[3, 0-15] + SWISH_F32_AVX512_DEF(c_float_3p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[4, 0-15] + SWISH_F32_AVX512_DEF(c_float_4p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[5, 0-15] + SWISH_F32_AVX512_DEF(c_float_5p0, selector1, al_in, r, r2, z, dn, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_6xLT16_DISABLE: + ; + // Store the results. + if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_last_k == TRUE ) ) + { + __mmask16 mask_all1 = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + + // Store the results in downscaled type (int8 instead of int32). + // c[0,0-15] + CVT_STORE_F32_BF16_MASK(c_float_0p0,0,0); + + // c[1,0-15] + CVT_STORE_F32_BF16_MASK(c_float_1p0,1,0); + + // c[2,0-15] + CVT_STORE_F32_BF16_MASK(c_float_2p0,2,0); + + // c[3,0-15] + CVT_STORE_F32_BF16_MASK(c_float_3p0,3,0); + + // c[4,0-15] + CVT_STORE_F32_BF16_MASK(c_float_4p0,4,0); + + // c[5,0-15] + CVT_STORE_F32_BF16_MASK(c_float_5p0,5,0); + } + + else + { + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + + // Store the results. + // c[0,0-15] + _mm512_mask_storeu_ps( c + ( rs_c * ( ir + 0 ) ), load_mask, c_float_0p0 ); + + // c[1,0-15] + _mm512_mask_storeu_ps( c + ( rs_c * ( ir + 1 ) ), load_mask, c_float_1p0 ); + + // c[2,0-15] + _mm512_mask_storeu_ps( c + ( rs_c * ( ir + 2 ) ), load_mask, c_float_2p0 ); + + // c[3,0-15] + _mm512_mask_storeu_ps( c + ( rs_c * ( ir + 3 ) ), load_mask, c_float_3p0 ); + + // c[4,0-15] + _mm512_mask_storeu_ps( c + ( rs_c * ( ir + 4 ) ), load_mask, c_float_4p0 ); + + // c[5,0-15] + _mm512_mask_storeu_ps( c + ( rs_c * ( ir + 5 ) ), load_mask, c_float_5p0 ); + + } + + a = a + ( MR * ps_a ); + post_ops_attr.post_op_c_i += MR; + } + + if ( m_partial_pieces > 0 ) + { + if ( m_partial_pieces == 5 ) + { + int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 5 ); + lpgemm_rowvar_bf16s4f32of32_5xlt16 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, n0_rem, + post_ops_list, post_ops_attr + ); + } + else if ( m_partial_pieces == 4 ) + { + int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 4 ); + lpgemm_rowvar_bf16s4f32of32_4xlt16 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, n0_rem, + post_ops_list, post_ops_attr + ); + } + else if ( m_partial_pieces == 3 ) + { + int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 3 ); + lpgemm_rowvar_bf16s4f32of32_3xlt16 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, n0_rem, + post_ops_list, post_ops_attr + ); + } + else if ( m_partial_pieces == 2 ) + { + int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 2 ); + lpgemm_rowvar_bf16s4f32of32_2xlt16 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, n0_rem, + post_ops_list, post_ops_attr + ); + } + else if ( m_partial_pieces == 1 ) + { + int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 1 ); + lpgemm_rowvar_bf16s4f32of32_1xlt16 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, n0_rem, + post_ops_list, post_ops_attr + ); + } + } + +} + +// 6x16 bf16 fringe kernel +LPGEMM_N_FRINGE_KERN(bfloat16, int8_t, float, bf16s4f32of32_6x16m) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_6x16_DISABLE, + &&POST_OPS_BIAS_6x16, + &&POST_OPS_RELU_6x16, + &&POST_OPS_RELU_SCALE_6x16, + &&POST_OPS_GELU_TANH_6x16, + &&POST_OPS_GELU_ERF_6x16, + &&POST_OPS_CLIP_6x16, + &&POST_OPS_DOWNSCALE_6x16, + &&POST_OPS_MATRIX_ADD_6x16, + &&POST_OPS_SWISH_6x16, + &&POST_OPS_MATRIX_MUL_6x16 + }; + dim_t pre_op_off = post_ops_attr.pre_op_off; + + dim_t MR = 6; + dim_t m_full_pieces = m0 / MR; + dim_t m_full_pieces_loop_limit = m_full_pieces * MR; + dim_t m_partial_pieces = m0 % MR; + + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int16_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + __m128i b0_s4; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + __m256i shift_idx_32; + MULTISHIFT_32BIT_8_INT4_IDX_32ELEM(shift_idx_32); + __m256i sign_comp_32 = _mm256_set1_epi8( 0x08 ); + + bool signed_upscale = true; + + /* regs to store intermediate int8 values */ + __m256i b0_s8; + + /* Regs to store F32 scale values */ + __m512 scale0, scale1; + /* Reg to store masks to interleave scale factor */ + __m512i mask_scale1, mask_scale2; + + mask_scale1 = _mm512_set_epi32( 0x17, 0x07, 0x16, 0x06, 0x15, 0x05, 0x14, + 0x04, 0x13, 0x03, 0x12, 0x02, 0x11, 0x01, + 0x10, 0x00 ); + + mask_scale2 = _mm512_set_epi32( 0x1F, 0x0F, 0x1E, 0x0E, 0x1D, 0x0D, 0x1C, + 0x0C, 0x1B, 0x0B, 0x1A, 0x0A, 0x19, 0x09, + 0x18, 0x08); + + dim_t value; + + if(k_full_pieces > 40) + { + value = 40; + } + else + { + value = 0; + } + + if( post_ops_attr.pre_op_scale_factor_len > 1 ) + { + // load and interleave scale factor vectors + scale0 = _mm512_loadu_ps( (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off); + + scale1 = _mm512_permutex2var_ps( scale0, mask_scale2, scale0 ); + scale0 = _mm512_permutex2var_ps( scale0, mask_scale1, scale0 ); + + } + else + { + scale0 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale1 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + } + + + for ( dim_t ir = 0; ir < m_full_pieces_loop_limit; ir += MR ) + { + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + + __m512 c_float_2p0 = _mm512_setzero_ps(); + + __m512 c_float_3p0 = _mm512_setzero_ps(); + + __m512 c_float_4p0 = _mm512_setzero_ps(); + + __m512 c_float_5p0 = _mm512_setzero_ps(); + + for ( dim_t kr = 0; kr < k_full_pieces - value; kr += 1 ) + { + b0_s4 = _mm_loadu_si128( (__m128i const *)( b + ( ( rs_b * kr ) / 2 ) ) ); + + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_32, \ + sign_comp_32, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_8( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_8( b0_s8, 0, scale0 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( + *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( + *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + + // Broadcast a[2,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( + *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + + // Broadcast a[3,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( + *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-15] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + + // Broadcast a[4,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( + *( int32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[4,0-15] = a[4,kr:kr+2]*b[kr:kr+2,0-15] + c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); + + // Broadcast a[5,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( + *( int32_t* )( a + ( rs_a * 5 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[5,0-15] = a[5,kr:kr+2]*b[kr:kr+2,0-15] + c_float_5p0 = _mm512_dpbf16_ps( c_float_5p0, a_bf16_0, b0 ); + } + + _mm_prefetch(c + (rs_c * (ir + 0)) + (0 * 16), _MM_HINT_T1); + + _mm_prefetch(c + (rs_c * (ir + 1)) + (0 * 16), _MM_HINT_T1); + + _mm_prefetch(c + (rs_c * (ir + 2)) + (0 * 16), _MM_HINT_T1); + + _mm_prefetch(c + (rs_c * (ir + 3)) + (0 * 16), _MM_HINT_T1); + + _mm_prefetch(c + (rs_c * (ir + 4)) + (0 * 16), _MM_HINT_T1); + + _mm_prefetch(c + (rs_c * (ir + 5)) + (0 * 16), _MM_HINT_T1); + + for (dim_t kr = k_full_pieces - value; kr < k_full_pieces; kr += 1) + { + b0_s4 = _mm_loadu_si128( (__m128i const *)( b + ( ( rs_b * kr ) / 2 ) ) ); + + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_32, \ + sign_comp_32, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_8( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_8( b0_s8, 0, scale0 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( + *(int32_t *)(a + (rs_a * 0) + (cs_a * kr))); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps(c_float_0p0, a_bf16_0, b0); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( + *(int32_t *)(a + (rs_a * 1) + (cs_a * kr))); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] + c_float_1p0 = _mm512_dpbf16_ps(c_float_1p0, a_bf16_0, b0); + + // Broadcast a[2,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( + *(int32_t *)(a + (rs_a * 2) + (cs_a * kr))); + + // Perform column direction mat-mul with k = 2. + // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] + c_float_2p0 = _mm512_dpbf16_ps(c_float_2p0, a_bf16_0, b0); + + // Broadcast a[3,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( + *(int32_t *)(a + (rs_a * 3) + (cs_a * kr))); + + // Perform column direction mat-mul with k = 2. + // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-15] + c_float_3p0 = _mm512_dpbf16_ps(c_float_3p0, a_bf16_0, b0); + + // Broadcast a[4,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( + *(int32_t *)(a + (rs_a * 4) + (cs_a * kr))); + + // Perform column direction mat-mul with k = 2. + // c[4,0-15] = a[4,kr:kr+2]*b[kr:kr+2,0-15] + c_float_4p0 = _mm512_dpbf16_ps(c_float_4p0, a_bf16_0, b0); + + // Broadcast a[5,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( + *(int32_t *)(a + (rs_a * 5) + (cs_a * kr))); + + // Perform column direction mat-mul with k = 2. + // c[5,0-15] = a[5,kr:kr+2]*b[kr:kr+2,0-15] + c_float_5p0 = _mm512_dpbf16_ps(c_float_5p0, a_bf16_0, b0); + } + + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0_s4 = _mm_loadu_si128( (__m128i const *)( b + ( ( rs_b * k_full_pieces ) / 2 ) ) ); + + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_32, \ + sign_comp_32, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_8( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_8( b0_s8, 0, scale0 ) ); + + // Broadcast a[0,kr:kr+2]. + a_kfringe_buf = *( a + (rs_a * 0) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 1) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + + // Broadcast a[2,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 2) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + + // Broadcast a[3,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 3) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-15] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + + // Broadcast a[4,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 4) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[4,0-15] = a[4,kr:kr+2]*b[kr:kr+2,0-15] + c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); + + // Broadcast a[5,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 5) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[5,0-15] = a[5,kr:kr+2]*b[kr:kr+2,0-15] + c_float_5p0 = _mm512_dpbf16_ps( c_float_5p0, a_bf16_0, b0 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + if ( alpha != 1 ) + { + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + + c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); + + c_float_3p0 = _mm512_mul_ps( selector1, c_float_3p0 ); + + c_float_4p0 = _mm512_mul_ps( selector1, c_float_4p0 ); + + c_float_5p0 = _mm512_mul_ps( selector1, c_float_5p0 ); + } + + // Scale C by beta. + if ( beta != 0 ) + { + if ( ( post_ops_attr.buf_downscale != NULL ) && + ( post_ops_attr.is_first_k == TRUE ) ) + { + + // c[0,0-15] + BF16_F32_BETA_OP( c_float_0p0, ir, 0, 0, \ + selector1, selector2 ); + + // c[1,0-15] + BF16_F32_BETA_OP( c_float_1p0, ir, 1, 0, \ + selector1, selector2 ); + + // c[2,0-15] + BF16_F32_BETA_OP( c_float_2p0, ir, 2, 0, \ + selector1, selector2 ); + + // c[3,0-15] + BF16_F32_BETA_OP( c_float_3p0, ir, 3, 0, \ + selector1, selector2 ); + + // c[4,0-15] + BF16_F32_BETA_OP( c_float_4p0, ir, 4, 0, \ + selector1, selector2 ); + + // c[5,0-15] + BF16_F32_BETA_OP( c_float_5p0, ir, 5, 0, \ + selector1, selector2 ); + } + else + { + // c[0,0-15] + F32_F32_BETA_OP(c_float_0p0, ir, 0, 0, \ + selector1, selector2); + + // c[1,0-15] + F32_F32_BETA_OP(c_float_1p0, ir, 1, 0, \ + selector1, selector2); + + // c[2,0-15] + F32_F32_BETA_OP(c_float_2p0, ir, 2, 0, \ + selector1, selector2); + + // c[3,0-15] + F32_F32_BETA_OP(c_float_3p0, ir, 3, 0, \ + selector1, selector2); + + // c[4,0-15] + F32_F32_BETA_OP(c_float_4p0, ir, 4, 0, \ + selector1, selector2); + + // c[5,0-15] + F32_F32_BETA_OP(c_float_5p0, ir, 5, 0, \ + selector1, selector2); + } + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_6x16: + { + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_LOAD(selector1, bias_mask, 0); + } + else + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + + // c[5,0-15] + c_float_5p0 = _mm512_add_ps( selector1, c_float_5p0 ); + } + else + { + __m512 selector3; + __m512 selector4; + __m512 selector5; + __m512 selector6; + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + BF16_F32_BIAS_BCAST(selector2, bias_mask, 1); + BF16_F32_BIAS_BCAST(selector3, bias_mask, 2); + BF16_F32_BIAS_BCAST(selector4, bias_mask, 3); + BF16_F32_BIAS_BCAST(selector5, bias_mask, 4); + BF16_F32_BIAS_BCAST(selector6, bias_mask, 5); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 3 ) ); + selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 4 ) ); + selector6 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 5 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector4, c_float_3p0 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector5, c_float_4p0 ); + + // c[5,0-15] + c_float_5p0 = _mm512_add_ps( selector6, c_float_5p0 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_6x16: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + // c[3,0-15] + c_float_3p0 = _mm512_max_ps( selector1, c_float_3p0 ); + + // c[4,0-15] + c_float_4p0 = _mm512_max_ps( selector1, c_float_4p0 ); + + // c[5,0-15] + c_float_5p0 = _mm512_max_ps( selector1, c_float_5p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_6x16: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + // c[3, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_3p0) + + // c[4, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_4p0) + + // c[5, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_5p0) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_TANH_6x16: + { + __m512 dn, z, x, r2, r, x_tanh; + __m512i q; + + // c[0, 0-15] + GELU_TANH_F32_AVX512(c_float_0p0, r, r2, x, z, dn, x_tanh, q) + + // c[1, 0-15] + GELU_TANH_F32_AVX512(c_float_1p0, r, r2, x, z, dn, x_tanh, q) + + // c[2, 0-15] + GELU_TANH_F32_AVX512(c_float_2p0, r, r2, x, z, dn, x_tanh, q) + + // c[3, 0-15] + GELU_TANH_F32_AVX512(c_float_3p0, r, r2, x, z, dn, x_tanh, q) + + // c[4, 0-15] + GELU_TANH_F32_AVX512(c_float_4p0, r, r2, x, z, dn, x_tanh, q) + + // c[5, 0-15] + GELU_TANH_F32_AVX512(c_float_5p0, r, r2, x, z, dn, x_tanh, q) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_ERF_6x16: + { + __m512 x, r, x_erf; + + // c[0, 0-15] + GELU_ERF_F32_AVX512(c_float_0p0, r, x, x_erf) + + // c[1, 0-15] + GELU_ERF_F32_AVX512(c_float_1p0, r, x, x_erf) + + // c[2, 0-15] + GELU_ERF_F32_AVX512(c_float_2p0, r, x, x_erf) + + // c[3, 0-15] + GELU_ERF_F32_AVX512(c_float_3p0, r, x, x_erf) + + // c[4, 0-15] + GELU_ERF_F32_AVX512(c_float_4p0, r, x, x_erf) + + // c[5, 0-15] + GELU_ERF_F32_AVX512(c_float_5p0, r, x, x_erf) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_CLIP_6x16: + { + __m512 min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 ); + __m512 max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 ); + + // c[0, 0-15] + CLIP_F32_AVX512(c_float_0p0, min, max) + + // c[1, 0-15] + CLIP_F32_AVX512(c_float_1p0, min, max) + + // c[2, 0-15] + CLIP_F32_AVX512(c_float_2p0, min, max) + + // c[3, 0-15] + CLIP_F32_AVX512(c_float_3p0, min, max) + + // c[4, 0-15] + CLIP_F32_AVX512(c_float_4p0, min, max) + + // c[5, 0-15] + CLIP_F32_AVX512(c_float_5p0, min, max) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_6x16: + { + __m512 selector3 = _mm512_setzero_ps(); + __m512 selector4 = _mm512_setzero_ps(); + __m512 selector5 = _mm512_setzero_ps(); + __m512 selector6 = _mm512_setzero_ps(); + + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + __m512 zero_point2 = _mm512_setzero_ps(); + __m512 zero_point3 = _mm512_setzero_ps(); + __m512 zero_point4 = _mm512_setzero_ps(); + __m512 zero_point5 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + // Also the same value is loaded to different registers so that + // branching can be reduced and same code/register can be used + // irrespective of whether scalar or vector op. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector6 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point4 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point5 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = _mm512_maskz_loadu_ps( zp_mask, + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector1,zero_point0); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector1,zero_point0); + + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector1,zero_point0); + + // c[4, 0-15] + SCL_MULRND_F32(c_float_4p0,selector1,zero_point0); + + // c[5, 0-15] + SCL_MULRND_F32(c_float_5p0,selector1,zero_point0); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 3 ) ); + selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 4 ) ); + selector6 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 5 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 2 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 3 ) ) ); + zero_point4 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 4 ) ) ); + zero_point5 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 5 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector2,zero_point1); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector3,zero_point2); + + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector4,zero_point3); + + // c[4, 0-15] + SCL_MULRND_F32(c_float_4p0,selector5,zero_point4); + + // c[5, 0-15] + SCL_MULRND_F32(c_float_5p0,selector6,zero_point5); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_6x16: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15] + BF16_F32_MATRIX_ADD_1COL(selector1,0); + + // c[1:0-15] + BF16_F32_MATRIX_ADD_1COL(selector1,1); + + // c[2:0-15] + BF16_F32_MATRIX_ADD_1COL(selector1,2); + + // c[3:0-15] + BF16_F32_MATRIX_ADD_1COL(selector1,3); + + // c[4:0-15] + BF16_F32_MATRIX_ADD_1COL(selector1,4); + + // c[5:0-15] + BF16_F32_MATRIX_ADD_1COL(selector1,5); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15] + F32_F32_MATRIX_ADD_1COL(selector1,0); + + // c[1:0-15] + F32_F32_MATRIX_ADD_1COL(selector1,1); + + // c[2:0-15] + F32_F32_MATRIX_ADD_1COL(selector1,2); + + // c[3:0-15] + F32_F32_MATRIX_ADD_1COL(selector1,3); + + // c[4:0-15] + F32_F32_MATRIX_ADD_1COL(selector1,4); + + // c[5:0-15] + F32_F32_MATRIX_ADD_1COL(selector1,5); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_6x16: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15] + BF16_F32_MATRIX_MUL_1COL(selector1,0); + + // c[1:0-15] + BF16_F32_MATRIX_MUL_1COL(selector1,1); + + // c[2:0-15] + BF16_F32_MATRIX_MUL_1COL(selector1,2); + + // c[3:0-15] + BF16_F32_MATRIX_MUL_1COL(selector1,3); + + // c[4:0-15] + BF16_F32_MATRIX_MUL_1COL(selector1,4); + + // c[5:0-15] + BF16_F32_MATRIX_MUL_1COL(selector1,5); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15] + F32_F32_MATRIX_MUL_1COL(selector1,0); + + // c[1:0-15] + F32_F32_MATRIX_MUL_1COL(selector1,1); + + // c[2:0-15] + F32_F32_MATRIX_MUL_1COL(selector1,2); + + // c[3:0-15] + F32_F32_MATRIX_MUL_1COL(selector1,3); + + // c[4:0-15] + F32_F32_MATRIX_MUL_1COL(selector1,4); + + // c[5:0-15] + F32_F32_MATRIX_MUL_1COL(selector1,5); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_6x16: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 0-15] + SWISH_F32_AVX512_DEF(c_float_1p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 0-15] + SWISH_F32_AVX512_DEF(c_float_2p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[3, 0-15] + SWISH_F32_AVX512_DEF(c_float_3p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[4, 0-15] + SWISH_F32_AVX512_DEF(c_float_4p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[5, 0-15] + SWISH_F32_AVX512_DEF(c_float_5p0, selector1, al_in, r, r2, z, dn, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_6x16_DISABLE: + ; + if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_last_k == TRUE ) ) + { + // Generate a mask16 of all 1's. + __m512i selector_a = _mm512_setzero_epi32(); + __m512i selector_b = _mm512_set1_epi32( 10 ); + __mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b ); + + // Store the results in downscaled type (int8 instead of int32). + // c[0,0-15] + CVT_STORE_F32_BF16_MASK(c_float_0p0,0,0); + + // c[1,0-15] + CVT_STORE_F32_BF16_MASK(c_float_1p0,1,0); + + // c[2,0-15] + CVT_STORE_F32_BF16_MASK(c_float_2p0,2,0); + + // c[3,0-15] + CVT_STORE_F32_BF16_MASK(c_float_3p0,3,0); + + // c[4,0-15] + CVT_STORE_F32_BF16_MASK(c_float_4p0,4,0); + + // c[5,0-15] + CVT_STORE_F32_BF16_MASK(c_float_5p0,5,0); + } + + else + { + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 0*16 ), c_float_0p0 ); + + // c[1,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 0*16 ), c_float_1p0 ); + + // c[2,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 0*16 ), c_float_2p0 ); + + // c[3,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 0*16 ), c_float_3p0 ); + + // c[4,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 0*16 ), c_float_4p0 ); + + // c[5,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 0*16 ), c_float_5p0 ); + } + + a = a + ( MR * ps_a ); + post_ops_attr.post_op_c_i += MR; + } + + + if ( m_partial_pieces > 0 ) + { + if ( m_partial_pieces == 5 ) + { + int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 5 ); + lpgemm_rowvar_bf16s4f32of32_5x16 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + post_ops_list, post_ops_attr + ); + } + else if ( m_partial_pieces == 4 ) + { + int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 4 ); + lpgemm_rowvar_bf16s4f32of32_4x16 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + post_ops_list, post_ops_attr + ); + } + else if ( m_partial_pieces == 3 ) + { + int cs_a_use = ( cs_a == 2) ? 2 : ( ( cs_a / 6 ) * 3 ); + lpgemm_rowvar_bf16s4f32of32_3x16 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + post_ops_list, post_ops_attr + ); + } + else if ( m_partial_pieces == 2 ) + { + int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 2 ); + lpgemm_rowvar_bf16s4f32of32_2x16 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + post_ops_list, post_ops_attr + ); + } + else if ( m_partial_pieces == 1 ) + { + int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 1 ); + lpgemm_rowvar_bf16s4f32of32_1x16 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + post_ops_list, post_ops_attr + ); + } + } + +} + +// 6x32 bf16 fringe kernel +LPGEMM_N_FRINGE_KERN(bfloat16, int8_t, float, bf16s4f32of32_6x32m) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_6x32_DISABLE, + &&POST_OPS_BIAS_6x32, + &&POST_OPS_RELU_6x32, + &&POST_OPS_RELU_SCALE_6x32, + &&POST_OPS_GELU_TANH_6x32, + &&POST_OPS_GELU_ERF_6x32, + &&POST_OPS_CLIP_6x32, + &&POST_OPS_DOWNSCALE_6x32, + &&POST_OPS_MATRIX_ADD_6x32, + &&POST_OPS_SWISH_6x32, + &&POST_OPS_MATRIX_MUL_6x32 + }; + + dim_t pre_op_off = post_ops_attr.pre_op_off; + + dim_t MR = 6; + dim_t m_full_pieces = m0 / MR; + dim_t m_full_pieces_loop_limit = m_full_pieces * MR; + dim_t m_partial_pieces = m0 % MR; + + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int16_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + __m512bh b1; + + __m256i b0_s4; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + + __m512i shift_idx_64; + MULTISHIFT_32BIT_8_INT4_IDX_64ELEM(shift_idx_64); + __m512i sign_comp = _mm512_set1_epi8(0x08); + + bool signed_upscale = true; + + /* regs to store intermediate int8 values */ + __m512i b0_s8; + + /* Regs to store F32 scale values */ + __m512 scale0, scale1, scale2, scale3; + /* Reg to store masks to interleave scale factor */ + __m512i mask_scale1, mask_scale2; + + mask_scale1 = _mm512_set_epi32( 0x17, 0x07, 0x16, 0x06, 0x15, 0x05, 0x14, + 0x04, 0x13, 0x03, 0x12, 0x02, 0x11, 0x01, + 0x10, 0x00 ); + + mask_scale2 = _mm512_set_epi32( 0x1F, 0x0F, 0x1E, 0x0E, 0x1D, 0x0D, 0x1C, + 0x0C, 0x1B, 0x0B, 0x1A, 0x0A, 0x19, 0x09, + 0x18, 0x08); + + dim_t value; + + if(k_full_pieces > 40) + { + value = 40; + } + else + { + value = 0; + } + + if( post_ops_attr.pre_op_scale_factor_len > 1 ) + { + // load and interleave scale factor vectors + scale0 = _mm512_loadu_ps( (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off); + scale2 = _mm512_loadu_ps( (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off + 16 ); + + scale1 = _mm512_permutex2var_ps( scale0, mask_scale2, scale0 ); + scale0 = _mm512_permutex2var_ps( scale0, mask_scale1, scale0 ); + scale3 = _mm512_permutex2var_ps( scale2, mask_scale2, scale2 ); + scale2 = _mm512_permutex2var_ps( scale2, mask_scale1, scale2 ); + } + else + { + scale0 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale1 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale2 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale3 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + } + + for ( dim_t ir = 0; ir < m_full_pieces_loop_limit; ir += MR ) + { + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + __m512 c_float_0p1 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + __m512 c_float_1p1 = _mm512_setzero_ps(); + + __m512 c_float_2p0 = _mm512_setzero_ps(); + __m512 c_float_2p1 = _mm512_setzero_ps(); + + __m512 c_float_3p0 = _mm512_setzero_ps(); + __m512 c_float_3p1 = _mm512_setzero_ps(); + + __m512 c_float_4p0 = _mm512_setzero_ps(); + __m512 c_float_4p1 = _mm512_setzero_ps(); + + __m512 c_float_5p0 = _mm512_setzero_ps(); + __m512 c_float_5p1 = _mm512_setzero_ps(); + + for ( dim_t kr = 0; kr < k_full_pieces - value; kr += 1 ) + { + b0_s4 = _mm256_loadu_si256( (__m256i const *)( b + ( rs_b * kr ) / 2 ) ); + + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_64, \ + sign_comp, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_16( b0_s8, 0, scale0 ) ); + + b1 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 3, scale3 ), + CVT_INT8_F32_SCAL_16( b0_s8, 2, scale2 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( + *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( + *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-31] = a[1,kr:kr+2]*b[kr:kr+2,0-31] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); + + // Broadcast a[2,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( + *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-31] = a[2,kr:kr+2]*b[kr:kr+2,0-31] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + + // Broadcast a[3,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( + *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-31] = a[3,kr:kr+2]*b[kr:kr+2,0-31] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_0, b1 ); + + // Broadcast a[4,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( + *( int32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[4,0-31] = a[4,kr:kr+2]*b[kr:kr+2,0-31] + c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); + c_float_4p1 = _mm512_dpbf16_ps( c_float_4p1, a_bf16_0, b1 ); + + // Broadcast a[5,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( + *( int32_t* )( a + ( rs_a * 5 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[5,0-31] = a[5,kr:kr+2]*b[kr:kr+2,0-31] + c_float_5p0 = _mm512_dpbf16_ps( c_float_5p0, a_bf16_0, b0 ); + c_float_5p1 = _mm512_dpbf16_ps( c_float_5p1, a_bf16_0, b1 ); + } + + _mm_prefetch(c + (rs_c * (ir + 0)) + (0 * 16), _MM_HINT_T1); + _mm_prefetch(c + (rs_c * (ir + 0)) + (1 * 16), _MM_HINT_T1); + + _mm_prefetch(c + (rs_c * (ir + 1)) + (0 * 16), _MM_HINT_T1); + _mm_prefetch(c + (rs_c * (ir + 1)) + (1 * 16), _MM_HINT_T1); + + _mm_prefetch(c + (rs_c * (ir + 2)) + (0 * 16), _MM_HINT_T1); + _mm_prefetch(c + (rs_c * (ir + 2)) + (1 * 16), _MM_HINT_T1); + + _mm_prefetch(c + (rs_c * (ir + 3)) + (0 * 16), _MM_HINT_T1); + _mm_prefetch(c + (rs_c * (ir + 3)) + (1 * 16), _MM_HINT_T1); + + _mm_prefetch(c + (rs_c * (ir + 4)) + (0 * 16), _MM_HINT_T1); + _mm_prefetch(c + (rs_c * (ir + 4)) + (1 * 16), _MM_HINT_T1); + + _mm_prefetch(c + (rs_c * (ir + 5)) + (0 * 16), _MM_HINT_T1); + _mm_prefetch(c + (rs_c * (ir + 5)) + (1 * 16), _MM_HINT_T1); + + for (dim_t kr = k_full_pieces - value; kr < k_full_pieces; kr += 1) + { + b0_s4 = _mm256_loadu_si256( (__m256i const *)( b + ( rs_b * kr ) / 2 ) ); + + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_64, \ + sign_comp, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_16( b0_s8, 0, scale0 ) ); + + b1 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 3, scale3 ), + CVT_INT8_F32_SCAL_16( b0_s8, 2, scale2 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( + *(int32_t *)(a + (rs_a * 0) + (cs_a * kr))); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_float_0p0 = _mm512_dpbf16_ps(c_float_0p0, a_bf16_0, b0); + c_float_0p1 = _mm512_dpbf16_ps(c_float_0p1, a_bf16_0, b1); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( + *(int32_t *)(a + (rs_a * 1) + (cs_a * kr))); + + // Perform column direction mat-mul with k = 2. + // c[1,0-31] = a[1,kr:kr+2]*b[kr:kr+2,0-31] + c_float_1p0 = _mm512_dpbf16_ps(c_float_1p0, a_bf16_0, b0); + c_float_1p1 = _mm512_dpbf16_ps(c_float_1p1, a_bf16_0, b1); + + // Broadcast a[2,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( + *(int32_t *)(a + (rs_a * 2) + (cs_a * kr))); + + // Perform column direction mat-mul with k = 2. + // c[2,0-31] = a[2,kr:kr+2]*b[kr:kr+2,0-31] + c_float_2p0 = _mm512_dpbf16_ps(c_float_2p0, a_bf16_0, b0); + c_float_2p1 = _mm512_dpbf16_ps(c_float_2p1, a_bf16_0, b1); + + // Broadcast a[3,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( + *(int32_t *)(a + (rs_a * 3) + (cs_a * kr))); + + // Perform column direction mat-mul with k = 2. + // c[3,0-31] = a[3,kr:kr+2]*b[kr:kr+2,0-31] + c_float_3p0 = _mm512_dpbf16_ps(c_float_3p0, a_bf16_0, b0); + c_float_3p1 = _mm512_dpbf16_ps(c_float_3p1, a_bf16_0, b1); + + // Broadcast a[4,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( + *(int32_t *)(a + (rs_a * 4) + (cs_a * kr))); + + // Perform column direction mat-mul with k = 2. + // c[4,0-31] = a[4,kr:kr+2]*b[kr:kr+2,0-31] + c_float_4p0 = _mm512_dpbf16_ps(c_float_4p0, a_bf16_0, b0); + c_float_4p1 = _mm512_dpbf16_ps(c_float_4p1, a_bf16_0, b1); + + // Broadcast a[5,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( + *(int32_t *)(a + (rs_a * 5) + (cs_a * kr))); + + // Perform column direction mat-mul with k = 2. + // c[5,0-31] = a[5,kr:kr+2]*b[kr:kr+2,0-31] + c_float_5p0 = _mm512_dpbf16_ps(c_float_5p0, a_bf16_0, b0); + c_float_5p1 = _mm512_dpbf16_ps(c_float_5p1, a_bf16_0, b1); + } + + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0_s4 = _mm256_loadu_si256( (__m256i const *)( b + ( rs_b * k_full_pieces ) / 2 ) ); + + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_64, \ + sign_comp, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_16( b0_s8, 0, scale0 ) ); + + b1 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 3, scale3 ), + CVT_INT8_F32_SCAL_16( b0_s8, 2, scale2 ) ); + + // Broadcast a[0,kr:kr+2]. + a_kfringe_buf = *( a + (rs_a * 0) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + + // Broadcast a[1,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 1) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-31] = a[1,kr:kr+2]*b[kr:kr+2,0-31] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); + + // Broadcast a[2,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 2) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-31] = a[2,kr:kr+2]*b[kr:kr+2,0-31] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + + // Broadcast a[3,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 3) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-31] = a[3,kr:kr+2]*b[kr:kr+2,0-31] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_0, b1 ); + + // Broadcast a[4,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 4) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[4,0-31] = a[4,kr:kr+2]*b[kr:kr+2,0-31] + c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); + c_float_4p1 = _mm512_dpbf16_ps( c_float_4p1, a_bf16_0, b1 ); + + // Broadcast a[5,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 5) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[5,0-31] = a[5,kr:kr+2]*b[kr:kr+2,0-31] + c_float_5p0 = _mm512_dpbf16_ps( c_float_5p0, a_bf16_0, b0 ); + c_float_5p1 = _mm512_dpbf16_ps( c_float_5p1, a_bf16_0, b1 ); + } + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + if ( alpha != 1 ) + { + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + c_float_0p1 = _mm512_mul_ps( selector1, c_float_0p1 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + c_float_1p1 = _mm512_mul_ps( selector1, c_float_1p1 ); + + c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); + c_float_2p1 = _mm512_mul_ps( selector1, c_float_2p1 ); + + c_float_3p0 = _mm512_mul_ps( selector1, c_float_3p0 ); + c_float_3p1 = _mm512_mul_ps( selector1, c_float_3p1 ); + + c_float_4p0 = _mm512_mul_ps( selector1, c_float_4p0 ); + c_float_4p1 = _mm512_mul_ps( selector1, c_float_4p1 ); + + c_float_5p0 = _mm512_mul_ps( selector1, c_float_5p0 ); + c_float_5p1 = _mm512_mul_ps( selector1, c_float_5p1 ); + } + + // Scale C by beta. + if ( beta != 0 ) + { + if ( ( post_ops_attr.buf_downscale != NULL ) && + ( post_ops_attr.is_first_k == TRUE ) ) + { + + // c[0,0-15] + BF16_F32_BETA_OP( c_float_0p0, ir, 0, 0, selector1, selector2 ); + + // c[0, 16-31] + BF16_F32_BETA_OP( c_float_0p1, ir, 0, 1, selector1, selector2 ); + + // c[1,0-15] + BF16_F32_BETA_OP( c_float_1p0, ir, 1, 0, selector1, selector2 ); + + // c[1, 16-31] + BF16_F32_BETA_OP( c_float_1p1, ir, 1, 1, selector1, selector2 ); + + // c[2,0-15] + BF16_F32_BETA_OP( c_float_2p0, ir, 2, 0, selector1, selector2 ); + + // c[2, 16-31] + BF16_F32_BETA_OP( c_float_2p1, ir, 2, 1, selector1, selector2 ); + + // c[3,0-15] + BF16_F32_BETA_OP( c_float_3p0, ir, 3, 0, selector1, selector2 ); + + // c[3, 16-31] + BF16_F32_BETA_OP( c_float_3p1, ir, 3, 1, selector1, selector2 ); + + // c[4,0-15] + BF16_F32_BETA_OP( c_float_4p0, ir, 4, 0, selector1, selector2 ); + + // c[4, 16-31] + BF16_F32_BETA_OP( c_float_4p1, ir, 4, 1, selector1, selector2 ); + + // c[5,0-15] + BF16_F32_BETA_OP( c_float_5p0, ir, 5, 0, selector1, selector2 ); + + // c[5, 16-31] + BF16_F32_BETA_OP( c_float_5p1, ir, 5, 1, selector1, selector2 ); + } + else + { + // c[0,0-15] + F32_F32_BETA_OP( c_float_0p0, ir, 0, 0, selector1, selector2 ); + + // c[0, 16-31] + F32_F32_BETA_OP( c_float_0p1, ir, 0, 1, selector1, selector2 ); + + // c[1,0-15] + F32_F32_BETA_OP( c_float_1p0, ir, 1, 0, selector1, selector2 ); + + // c[1, 16-31] + F32_F32_BETA_OP( c_float_1p1, ir, 1, 1, selector1, selector2 ); + + // c[2,0-15] + F32_F32_BETA_OP( c_float_2p0, ir, 2, 0, selector1, selector2 ); + + // c[2, 16-31] + F32_F32_BETA_OP( c_float_2p1, ir, 2, 1, selector1, selector2 ); + + // c[3,0-15] + F32_F32_BETA_OP( c_float_3p0, ir, 3, 0, selector1, selector2 ); + + // c[3, 16-31] + F32_F32_BETA_OP( c_float_3p1, ir, 3, 1, selector1, selector2 ); + + // c[4,0-15] + F32_F32_BETA_OP( c_float_4p0, ir, 4, 0, selector1, selector2 ); + + // c[4, 16-31] + F32_F32_BETA_OP( c_float_4p1, ir, 4, 1, selector1, selector2 ); + + // c[5,0-15] + F32_F32_BETA_OP( c_float_5p0, ir, 5, 0, selector1, selector2 ); + + // c[5, 16-31] + F32_F32_BETA_OP( c_float_5p1, ir, 5, 1, selector1, selector2 ); + } + + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_6x32: + { + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_LOAD(selector1, bias_mask, 0); + BF16_F32_BIAS_LOAD(selector2, bias_mask, 1); + } + else + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector2, c_float_3p1 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + + // c[4, 16-31] + c_float_4p1 = _mm512_add_ps( selector2, c_float_4p1 ); + + // c[5,0-15] + c_float_5p0 = _mm512_add_ps( selector1, c_float_5p0 ); + + // c[5, 16-31] + c_float_5p1 = _mm512_add_ps( selector2, c_float_5p1 ); + } + else + { + __m512 selector3; + __m512 selector4; + __m512 selector5; + __m512 selector6; + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + BF16_F32_BIAS_BCAST(selector2, bias_mask, 1); + BF16_F32_BIAS_BCAST(selector3, bias_mask, 2); + BF16_F32_BIAS_BCAST(selector4, bias_mask, 3); + BF16_F32_BIAS_BCAST(selector5, bias_mask, 4); + BF16_F32_BIAS_BCAST(selector6, bias_mask, 5); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 3 ) ); + selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 4 ) ); + selector6 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 5 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector3, c_float_2p1 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector4, c_float_3p0 ); + + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector4, c_float_3p1 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector5, c_float_4p0 ); + + // c[4, 16-31] + c_float_4p1 = _mm512_add_ps( selector5, c_float_4p1 ); + + // c[5,0-15] + c_float_5p0 = _mm512_add_ps( selector6, c_float_5p0 ); + + // c[5, 16-31] + c_float_5p1 = _mm512_add_ps( selector6, c_float_5p1 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_6x32: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + c_float_1p1 = _mm512_max_ps( selector1, c_float_1p1 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + c_float_2p1 = _mm512_max_ps( selector1, c_float_2p1 ); + + // c[3,0-15] + c_float_3p0 = _mm512_max_ps( selector1, c_float_3p0 ); + + // c[3,16-31] + c_float_3p1 = _mm512_max_ps( selector1, c_float_3p1 ); + + // c[4,0-15] + c_float_4p0 = _mm512_max_ps( selector1, c_float_4p0 ); + + // c[4,16-31] + c_float_4p1 = _mm512_max_ps( selector1, c_float_4p1 ); + + // c[5,0-15] + c_float_5p0 = _mm512_max_ps( selector1, c_float_5p0 ); + + // c[5,16-31] + c_float_5p1 = _mm512_max_ps( selector1, c_float_5p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_6x32: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_1p1) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_2p1) + + // c[3, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_3p0) + + // c[3, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_3p1) + + // c[4, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_4p0) + + // c[4, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_4p1) + + // c[5, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_5p0) + + // c[5, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_5p1) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_TANH_6x32: + { + __m512 dn, z, x, r2, r, x_tanh; + __m512i q; + + // c[0, 0-15] + GELU_TANH_F32_AVX512(c_float_0p0, r, r2, x, z, dn, x_tanh, q) + + // c[0, 16-31] + GELU_TANH_F32_AVX512(c_float_0p1, r, r2, x, z, dn, x_tanh, q) + + // c[1, 0-15] + GELU_TANH_F32_AVX512(c_float_1p0, r, r2, x, z, dn, x_tanh, q) + + // c[1, 16-31] + GELU_TANH_F32_AVX512(c_float_1p1, r, r2, x, z, dn, x_tanh, q) + + // c[2, 0-15] + GELU_TANH_F32_AVX512(c_float_2p0, r, r2, x, z, dn, x_tanh, q) + + // c[2, 16-31] + GELU_TANH_F32_AVX512(c_float_2p1, r, r2, x, z, dn, x_tanh, q) + + // c[3, 0-15] + GELU_TANH_F32_AVX512(c_float_3p0, r, r2, x, z, dn, x_tanh, q) + + // c[3, 16-31] + GELU_TANH_F32_AVX512(c_float_3p1, r, r2, x, z, dn, x_tanh, q) + + // c[4, 0-15] + GELU_TANH_F32_AVX512(c_float_4p0, r, r2, x, z, dn, x_tanh, q) + + // c[4, 16-31] + GELU_TANH_F32_AVX512(c_float_4p1, r, r2, x, z, dn, x_tanh, q) + + // c[5, 0-15] + GELU_TANH_F32_AVX512(c_float_5p0, r, r2, x, z, dn, x_tanh, q) + + // c[5, 16-31] + GELU_TANH_F32_AVX512(c_float_5p1, r, r2, x, z, dn, x_tanh, q) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_ERF_6x32: + { + __m512 x, r, x_erf; + + // c[0, 0-15] + GELU_ERF_F32_AVX512(c_float_0p0, r, x, x_erf) + + // c[0, 16-31] + GELU_ERF_F32_AVX512(c_float_0p1, r, x, x_erf) + + // c[1, 0-15] + GELU_ERF_F32_AVX512(c_float_1p0, r, x, x_erf) + + // c[1, 16-31] + GELU_ERF_F32_AVX512(c_float_1p1, r, x, x_erf) + + // c[2, 0-15] + GELU_ERF_F32_AVX512(c_float_2p0, r, x, x_erf) + + // c[2, 16-31] + GELU_ERF_F32_AVX512(c_float_2p1, r, x, x_erf) + + // c[3, 0-15] + GELU_ERF_F32_AVX512(c_float_3p0, r, x, x_erf) + + // c[3, 16-31] + GELU_ERF_F32_AVX512(c_float_3p1, r, x, x_erf) + + // c[4, 0-15] + GELU_ERF_F32_AVX512(c_float_4p0, r, x, x_erf) + + // c[4, 16-31] + GELU_ERF_F32_AVX512(c_float_4p1, r, x, x_erf) + + // c[5, 0-15] + GELU_ERF_F32_AVX512(c_float_5p0, r, x, x_erf) + + // c[5, 16-31] + GELU_ERF_F32_AVX512(c_float_5p1, r, x, x_erf) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_CLIP_6x32: + { + __m512 min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 ); + __m512 max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 ); + + // c[0, 0-15] + CLIP_F32_AVX512(c_float_0p0, min, max) + + // c[0, 16-31] + CLIP_F32_AVX512(c_float_0p1, min, max) + + // c[1, 0-15] + CLIP_F32_AVX512(c_float_1p0, min, max) + + // c[1, 16-31] + CLIP_F32_AVX512(c_float_1p1, min, max) + + // c[2, 0-15] + CLIP_F32_AVX512(c_float_2p0, min, max) + + // c[2, 16-31] + CLIP_F32_AVX512(c_float_2p1, min, max) + + // c[3, 0-15] + CLIP_F32_AVX512(c_float_3p0, min, max) + + // c[3, 16-31] + CLIP_F32_AVX512(c_float_3p1, min, max) + + // c[4, 0-15] + CLIP_F32_AVX512(c_float_4p0, min, max) + + // c[4, 16-31] + CLIP_F32_AVX512(c_float_4p1, min, max) + + // c[5, 0-15] + CLIP_F32_AVX512(c_float_5p0, min, max) + + // c[5, 16-31] + CLIP_F32_AVX512(c_float_5p1, min, max) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_6x32: + { + __m512 selector3 = _mm512_setzero_ps(); + __m512 selector4 = _mm512_setzero_ps(); + __m512 selector5 = _mm512_setzero_ps(); + __m512 selector6 = _mm512_setzero_ps(); + + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + __m512 zero_point2 = _mm512_setzero_ps(); + __m512 zero_point3 = _mm512_setzero_ps(); + __m512 zero_point4 = _mm512_setzero_ps(); + __m512 zero_point5 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector6 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point4 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point5 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector2,zero_point1); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector1,zero_point0); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector1,zero_point0); + + // c[2, 16-31] + SCL_MULRND_F32(c_float_2p1,selector2,zero_point1); + + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector1,zero_point0); + + // c[3, 16-31] + SCL_MULRND_F32(c_float_3p1,selector2,zero_point1); + + // c[4, 0-15] + SCL_MULRND_F32(c_float_4p0,selector1,zero_point0); + + // c[4, 16-31] + SCL_MULRND_F32(c_float_4p1,selector2,zero_point1); + + // c[5, 0-15] + SCL_MULRND_F32(c_float_5p0,selector1,zero_point0); + + // c[5, 16-31] + SCL_MULRND_F32(c_float_5p1,selector2,zero_point1); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 3 ) ); + selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 4 ) ); + selector6 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 5 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 2 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 3 ) ) ); + zero_point4 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 4 ) ) ); + zero_point5 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 5 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector2,zero_point1); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector3,zero_point2); + + // c[2, 16-31] + SCL_MULRND_F32(c_float_2p1,selector3,zero_point2); + + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector4,zero_point3); + + // c[3, 16-31] + SCL_MULRND_F32(c_float_3p1,selector4,zero_point3); + + // c[4, 0-15] + SCL_MULRND_F32(c_float_4p0,selector5,zero_point4); + + // c[4, 16-31] + SCL_MULRND_F32(c_float_4p1,selector5,zero_point4); + + // c[5, 0-15] + SCL_MULRND_F32(c_float_5p0,selector6,zero_point5); + + // c[5, 16-31] + SCL_MULRND_F32(c_float_5p1,selector6,zero_point5); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_6x32: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + BF16_F32_MATRIX_ADD_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + BF16_F32_MATRIX_ADD_2COL(selector1,selector2,1); + + // c[2:0-15,16-31] + BF16_F32_MATRIX_ADD_2COL(selector1,selector2,2); + + // c[3:0-15,16-31] + BF16_F32_MATRIX_ADD_2COL(selector1,selector2,3); + + // c[4:0-15,16-31] + BF16_F32_MATRIX_ADD_2COL(selector1,selector2,4); + + // c[5:0-15,16-31] + BF16_F32_MATRIX_ADD_2COL(selector1,selector2,5); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(selector1,selector2,1); + + // c[2:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(selector1,selector2,2); + + // c[3:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(selector1,selector2,3); + + // c[4:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(selector1,selector2,4); + + // c[5:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(selector1,selector2,5); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_6x32: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + BF16_F32_MATRIX_MUL_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + BF16_F32_MATRIX_MUL_2COL(selector1,selector2,1); + + // c[2:0-15,16-31] + BF16_F32_MATRIX_MUL_2COL(selector1,selector2,2); + + // c[3:0-15,16-31] + BF16_F32_MATRIX_MUL_2COL(selector1,selector2,3); + + // c[4:0-15,16-31] + BF16_F32_MATRIX_MUL_2COL(selector1,selector2,4); + + // c[5:0-15,16-31] + BF16_F32_MATRIX_MUL_2COL(selector1,selector2,5); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(selector1,selector2,1); + + // c[2:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(selector1,selector2,2); + + // c[3:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(selector1,selector2,3); + + // c[4:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(selector1,selector2,4); + + // c[5:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(selector1,selector2,5); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_6x32: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[0, 16-31] + SWISH_F32_AVX512_DEF(c_float_0p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 0-15] + SWISH_F32_AVX512_DEF(c_float_1p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 16-31] + SWISH_F32_AVX512_DEF(c_float_1p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 0-15] + SWISH_F32_AVX512_DEF(c_float_2p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 16-31] + SWISH_F32_AVX512_DEF(c_float_2p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[3, 0-15] + SWISH_F32_AVX512_DEF(c_float_3p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[3, 16-31] + SWISH_F32_AVX512_DEF(c_float_3p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[4, 0-15] + SWISH_F32_AVX512_DEF(c_float_4p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[4, 16-31] + SWISH_F32_AVX512_DEF(c_float_4p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[5, 0-15] + SWISH_F32_AVX512_DEF(c_float_5p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[5, 16-31] + SWISH_F32_AVX512_DEF(c_float_5p1, selector1, al_in, r, r2, z, dn, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_6x32_DISABLE: + ; + if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_last_k == TRUE ) ) + { + // Generate a mask16 of all 1's. + __m512i selector_a = _mm512_setzero_epi32(); + __m512i selector_b = _mm512_set1_epi32( 10 ); + __mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b ); + + // Store the results in downscaled type (int8 instead of int32). + // c[0,0-15] + CVT_STORE_F32_BF16_MASK(c_float_0p0,0,0); + + // c[0, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_0p1,0,1); + + // c[1,0-15] + CVT_STORE_F32_BF16_MASK(c_float_1p0,1,0); + + // c[1, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_1p1,1,1); + + // c[2,0-15] + CVT_STORE_F32_BF16_MASK(c_float_2p0,2,0); + + // c[2, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_2p1,2,1); + + // c[3,0-15] + CVT_STORE_F32_BF16_MASK(c_float_3p0,3,0); + + // c[3, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_3p1,3,1); + + // c[4,0-15] + CVT_STORE_F32_BF16_MASK(c_float_4p0,4,0); + + // c[4, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_4p1,4,1); + + // c[5,0-15] + CVT_STORE_F32_BF16_MASK(c_float_5p0,5,0); + + // c[5, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_5p1,5,1); + } + + else + { + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 0*16 ), c_float_0p0 ); + + // c[0, 16-31] + _mm512_storeu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 1*16 ), c_float_0p1 ); + + // c[1,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 0*16 ), c_float_1p0 ); + + // c[1,16-31] + _mm512_storeu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 1*16 ), c_float_1p1 ); + + // c[2,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 0*16 ), c_float_2p0 ); + + // c[2,16-31] + _mm512_storeu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 1*16 ), c_float_2p1 ); + + // c[3,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 0*16 ), c_float_3p0 ); + + // c[3,16-31] + _mm512_storeu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 1*16 ), c_float_3p1 ); + + // c[4,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 0*16 ), c_float_4p0 ); + + // c[4,16-31] + _mm512_storeu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 1*16 ), c_float_4p1 ); + + // c[5,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 0*16 ), c_float_5p0 ); + + // c[5,16-31] + _mm512_storeu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 1*16 ), c_float_5p1 ); + } + + a = a + ( MR * ps_a ); + post_ops_attr.post_op_c_i += MR; + } + + if ( m_partial_pieces > 0 ) + { + if ( m_partial_pieces == 5 ) + { + int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 5 ); + lpgemm_rowvar_bf16s4f32of32_5x32 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + post_ops_list, post_ops_attr + ); + } + else if ( m_partial_pieces == 4 ) + { + int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 4 ); + lpgemm_rowvar_bf16s4f32of32_4x32 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + post_ops_list, post_ops_attr + ); + } + else if ( m_partial_pieces == 3 ) + { + int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 3 ); + lpgemm_rowvar_bf16s4f32of32_3x32 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + post_ops_list, post_ops_attr + ); + } + else if ( m_partial_pieces == 2 ) + { + int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 2 ); + lpgemm_rowvar_bf16s4f32of32_2x32 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + post_ops_list, post_ops_attr + ); + } + else if ( m_partial_pieces == 1 ) + { + int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 1 ); + lpgemm_rowvar_bf16s4f32of32_1x32 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + post_ops_list, post_ops_attr + ); + } + } + +} + +// 6x48 bf16 fringe kernel +LPGEMM_N_FRINGE_KERN(bfloat16, int8_t, float, bf16s4f32of32_6x48m) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_6x48_DISABLE, + &&POST_OPS_BIAS_6x48, + &&POST_OPS_RELU_6x48, + &&POST_OPS_RELU_SCALE_6x48, + &&POST_OPS_GELU_TANH_6x48, + &&POST_OPS_GELU_ERF_6x48, + &&POST_OPS_CLIP_6x48, + &&POST_OPS_DOWNSCALE_6x48, + &&POST_OPS_MATRIX_ADD_6x48, + &&POST_OPS_SWISH_6x48, + &&POST_OPS_MATRIX_MUL_6x48 + }; + + dim_t pre_op_off = post_ops_attr.pre_op_off; + + dim_t MR = 6; + dim_t m_full_pieces = m0 / MR; + dim_t m_full_pieces_loop_limit = m_full_pieces * MR; + dim_t m_partial_pieces = m0 % MR; + + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int16_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + __m512bh b1; + __m512bh b2; + + __m256i b0_s4; + __m128i b1_s4; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + dim_t value; + + if(k_full_pieces > 40) + { + value = 40; + } + else + { + value = 0; + } + + __m512i shift_idx_64; + MULTISHIFT_32BIT_8_INT4_IDX_64ELEM(shift_idx_64); + __m512i sign_comp = _mm512_set1_epi8(0x08); + + __m256i shift_idx_32; + MULTISHIFT_32BIT_8_INT4_IDX_32ELEM(shift_idx_32); + __m256i sign_comp_32 = _mm256_set1_epi8( 0x08 ); + + bool signed_upscale = true; + + /* regs to store intermediate int8 values */ + __m512i b0_s8; + __m256i b1_s8; + + /* Regs to store F32 scale values */ + __m512 scale0, scale1, scale2, scale3, scale4, scale5; + /* Reg to store masks to interleave scale factor */ + __m512i mask_scale1, mask_scale2; + + mask_scale1 = _mm512_set_epi32( 0x17, 0x07, 0x16, 0x06, 0x15, 0x05, 0x14, + 0x04, 0x13, 0x03, 0x12, 0x02, 0x11, 0x01, + 0x10, 0x00 ); + + mask_scale2 = _mm512_set_epi32( 0x1F, 0x0F, 0x1E, 0x0E, 0x1D, 0x0D, 0x1C, + 0x0C, 0x1B, 0x0B, 0x1A, 0x0A, 0x19, 0x09, + 0x18, 0x08); + + if( post_ops_attr.pre_op_scale_factor_len > 1 ) + { + // load and interleave scale factor vectors + scale0 = _mm512_loadu_ps( (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off); + scale2 = _mm512_loadu_ps( (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off + 16 ); + scale4 = _mm512_loadu_ps( (float*)( post_ops_attr.pre_op_scale_factor ) + + pre_op_off + 32 ); + + scale1 = _mm512_permutex2var_ps( scale0, mask_scale2, scale0 ); + scale0 = _mm512_permutex2var_ps( scale0, mask_scale1, scale0 ); + scale3 = _mm512_permutex2var_ps( scale2, mask_scale2, scale2 ); + scale2 = _mm512_permutex2var_ps( scale2, mask_scale1, scale2 ); + scale5 = _mm512_permutex2var_ps( scale4, mask_scale2, scale4 ); + scale4 = _mm512_permutex2var_ps( scale4, mask_scale1, scale4 ); + + } + else + { + scale0 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale1 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale2 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale3 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale4 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + scale5 = _mm512_set1_ps( *( ( float* )post_ops_attr.pre_op_scale_factor ) ); + } + + for ( dim_t ir = 0; ir < m_full_pieces_loop_limit; ir += MR ) + { + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + __m512 c_float_0p1 = _mm512_setzero_ps(); + __m512 c_float_0p2 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + __m512 c_float_1p1 = _mm512_setzero_ps(); + __m512 c_float_1p2 = _mm512_setzero_ps(); + + __m512 c_float_2p0 = _mm512_setzero_ps(); + __m512 c_float_2p1 = _mm512_setzero_ps(); + __m512 c_float_2p2 = _mm512_setzero_ps(); + + __m512 c_float_3p0 = _mm512_setzero_ps(); + __m512 c_float_3p1 = _mm512_setzero_ps(); + __m512 c_float_3p2 = _mm512_setzero_ps(); + + __m512 c_float_4p0 = _mm512_setzero_ps(); + __m512 c_float_4p1 = _mm512_setzero_ps(); + __m512 c_float_4p2 = _mm512_setzero_ps(); + + __m512 c_float_5p0 = _mm512_setzero_ps(); + __m512 c_float_5p1 = _mm512_setzero_ps(); + __m512 c_float_5p2 = _mm512_setzero_ps(); + + for ( dim_t kr = 0; kr < k_full_pieces - value; kr += 1 ) + { + + b0_s4 = _mm256_loadu_si256( (__m256i const *)( b + ( rs_b * kr ) / 2 ) ); + + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_64, \ + sign_comp, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_16( b0_s8, 0, scale0 ) ); + + b1 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 3, scale3 ), + CVT_INT8_F32_SCAL_16( b0_s8, 2, scale2 ) ); + + b1_s4 = _mm_loadu_si128( (__m128i const *)( b + ( ( rs_b * kr ) / 2 ) + 32 ) ); + + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT( b1_s4, b1_s8, shift_idx_32, \ + sign_comp_32, signed_upscale); + + b2 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_8( b1_s8, 1, scale5 ), + CVT_INT8_F32_SCAL_8( b1_s8, 0, scale4 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( + *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-47] = a[0,kr:kr+2]*b[kr:kr+2,0-47] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( + *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-47] = a[1,kr:kr+2]*b[kr:kr+2,0-47] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); + c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_0, b2 ); + + // Broadcast a[2,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( + *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-47] = a[2,kr:kr+2]*b[kr:kr+2,0-47] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); + + // Broadcast a[3,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( + *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-47] = a[3,kr:kr+2]*b[kr:kr+2,0-47] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_0, b1 ); + c_float_3p2 = _mm512_dpbf16_ps( c_float_3p2, a_bf16_0, b2 ); + + // Broadcast a[4,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( + *( int32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[4,0-47] = a[4,kr:kr+2]*b[kr:kr+2,0-47] + c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); + c_float_4p1 = _mm512_dpbf16_ps( c_float_4p1, a_bf16_0, b1 ); + c_float_4p2 = _mm512_dpbf16_ps( c_float_4p2, a_bf16_0, b2 ); + + // Broadcast a[5,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( + *( int32_t* )( a + ( rs_a * 5 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[5,0-47] = a[5,kr:kr+2]*b[kr:kr+2,0-47] + c_float_5p0 = _mm512_dpbf16_ps( c_float_5p0, a_bf16_0, b0 ); + c_float_5p1 = _mm512_dpbf16_ps( c_float_5p1, a_bf16_0, b1 ); + c_float_5p2 = _mm512_dpbf16_ps( c_float_5p2, a_bf16_0, b2 ); + + } + + _mm_prefetch(c + (rs_c * (ir + 0)) + (0 * 16), _MM_HINT_T1); + _mm_prefetch(c + (rs_c * (ir + 0)) + (1 * 16), _MM_HINT_T1); + _mm_prefetch(c + (rs_c * (ir + 0)) + (2 * 16), _MM_HINT_T1); + + _mm_prefetch(c + (rs_c * (ir + 1)) + (0 * 16), _MM_HINT_T1); + _mm_prefetch(c + (rs_c * (ir + 1)) + (1 * 16), _MM_HINT_T1); + _mm_prefetch(c + (rs_c * (ir + 1)) + (2 * 16), _MM_HINT_T1); + + _mm_prefetch(c + (rs_c * (ir + 2)) + (0 * 16), _MM_HINT_T1); + _mm_prefetch(c + (rs_c * (ir + 2)) + (1 * 16), _MM_HINT_T1); + _mm_prefetch(c + (rs_c * (ir + 2)) + (2 * 16), _MM_HINT_T1); + + _mm_prefetch(c + (rs_c * (ir + 3)) + (0 * 16), _MM_HINT_T1); + _mm_prefetch(c + (rs_c * (ir + 3)) + (1 * 16), _MM_HINT_T1); + _mm_prefetch(c + (rs_c * (ir + 3)) + (2 * 16), _MM_HINT_T1); + + _mm_prefetch(c + (rs_c * (ir + 4)) + (0 * 16), _MM_HINT_T1); + _mm_prefetch(c + (rs_c * (ir + 4)) + (1 * 16), _MM_HINT_T1); + _mm_prefetch(c + (rs_c * (ir + 4)) + (2 * 16), _MM_HINT_T1); + + _mm_prefetch(c + (rs_c * (ir + 5)) + (0 * 16), _MM_HINT_T1); + _mm_prefetch(c + (rs_c * (ir + 5)) + (1 * 16), _MM_HINT_T1); + _mm_prefetch(c + (rs_c * (ir + 5)) + (2 * 16), _MM_HINT_T1); + + for (dim_t kr = k_full_pieces - value; kr < k_full_pieces; kr += 1) + { + + b0_s4 = _mm256_loadu_si256( (__m256i const *)( b + ( rs_b * kr ) / 2 ) ); + + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_64, \ + sign_comp, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_16( b0_s8, 0, scale0 ) ); + + b1 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 3, scale3 ), + CVT_INT8_F32_SCAL_16( b0_s8, 2, scale2 ) ); + + b1_s4 = _mm_loadu_si128( (__m128i const *)( b + ( ( rs_b * kr ) / 2 ) + 32 ) ); + + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT( b1_s4, b1_s8, shift_idx_32, \ + sign_comp_32, signed_upscale); + + b2 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_8( b1_s8, 1, scale5 ), + CVT_INT8_F32_SCAL_8( b1_s8, 0, scale4 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( + *(int32_t *)(a + (rs_a * 0) + (cs_a * kr))); + + // Perform column direction mat-mul with k = 2. + // c[0,0-47] = a[0,kr:kr+2]*b[kr:kr+2,0-47] + c_float_0p0 = _mm512_dpbf16_ps(c_float_0p0, a_bf16_0, b0); + c_float_0p1 = _mm512_dpbf16_ps(c_float_0p1, a_bf16_0, b1); + c_float_0p2 = _mm512_dpbf16_ps(c_float_0p2, a_bf16_0, b2); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( + *(int32_t *)(a + (rs_a * 1) + (cs_a * kr))); + + // Perform column direction mat-mul with k = 2. + // c[1,0-47] = a[1,kr:kr+2]*b[kr:kr+2,0-47] + c_float_1p0 = _mm512_dpbf16_ps(c_float_1p0, a_bf16_0, b0); + c_float_1p1 = _mm512_dpbf16_ps(c_float_1p1, a_bf16_0, b1); + c_float_1p2 = _mm512_dpbf16_ps(c_float_1p2, a_bf16_0, b2); + + // Broadcast a[2,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( + *(int32_t *)(a + (rs_a * 2) + (cs_a * kr))); + + // Perform column direction mat-mul with k = 2. + // c[2,0-47] = a[2,kr:kr+2]*b[kr:kr+2,0-47] + c_float_2p0 = _mm512_dpbf16_ps(c_float_2p0, a_bf16_0, b0); + c_float_2p1 = _mm512_dpbf16_ps(c_float_2p1, a_bf16_0, b1); + c_float_2p2 = _mm512_dpbf16_ps(c_float_2p2, a_bf16_0, b2); + + // Broadcast a[3,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( + *(int32_t *)(a + (rs_a * 3) + (cs_a * kr))); + + // Perform column direction mat-mul with k = 2. + // c[3,0-47] = a[3,kr:kr+2]*b[kr:kr+2,0-47] + c_float_3p0 = _mm512_dpbf16_ps(c_float_3p0, a_bf16_0, b0); + c_float_3p1 = _mm512_dpbf16_ps(c_float_3p1, a_bf16_0, b1); + c_float_3p2 = _mm512_dpbf16_ps(c_float_3p2, a_bf16_0, b2); + + // Broadcast a[4,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( + *(int32_t *)(a + (rs_a * 4) + (cs_a * kr))); + + // Perform column direction mat-mul with k = 2. + // c[4,0-47] = a[4,kr:kr+2]*b[kr:kr+2,0-47] + c_float_4p0 = _mm512_dpbf16_ps(c_float_4p0, a_bf16_0, b0); + c_float_4p1 = _mm512_dpbf16_ps(c_float_4p1, a_bf16_0, b1); + c_float_4p2 = _mm512_dpbf16_ps(c_float_4p2, a_bf16_0, b2); + + // Broadcast a[5,kr:kr+2]. + a_bf16_0 = (__m512bh)_mm512_set1_epi32( + *(int32_t *)(a + (rs_a * 5) + (cs_a * kr))); + + // Perform column direction mat-mul with k = 2. + // c[5,0-47] = a[5,kr:kr+2]*b[kr:kr+2,0-47] + c_float_5p0 = _mm512_dpbf16_ps(c_float_5p0, a_bf16_0, b0); + c_float_5p1 = _mm512_dpbf16_ps(c_float_5p1, a_bf16_0, b1); + c_float_5p2 = _mm512_dpbf16_ps(c_float_5p2, a_bf16_0, b2); + } + + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0_s4 = _mm256_loadu_si256( (__m256i const *)( b + ( rs_b * k_full_pieces ) / 2 ) ); + + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( b0_s4, b0_s8, shift_idx_64, \ + sign_comp, signed_upscale); + + b0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 1, scale1 ), + CVT_INT8_F32_SCAL_16( b0_s8, 0, scale0 ) ); + + b1 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( b0_s8, 3, scale3 ), + CVT_INT8_F32_SCAL_16( b0_s8, 2, scale2 ) ); + + b1_s4 = _mm_loadu_si128( (__m128i const *)( b + ( ( rs_b * k_full_pieces ) / 2 ) + 32 ) ); + + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT( b1_s4, b1_s8, shift_idx_32, \ + sign_comp_32, signed_upscale); + + b2 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_8( b1_s8, 1, scale5 ), + CVT_INT8_F32_SCAL_8( b1_s8, 0, scale4 ) ); + + // Broadcast a[0,kr:kr+4]. + a_kfringe_buf = *( a + (rs_a * 0) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-47] = a[0,kr:kr+2]*b[kr:kr+2,0-47] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + + // Broadcast a[1,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 1) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-47] = a[1,kr:kr+2]*b[kr:kr+2,0-47] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); + c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_0, b2 ); + + // Broadcast a[2,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 2) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-47] = a[2,kr:kr+2]*b[kr:kr+2,0-47] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); + + // Broadcast a[3,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 3) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-47] = a[3,kr:kr+2]*b[kr:kr+2,0-47] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_0, b1 ); + c_float_3p2 = _mm512_dpbf16_ps( c_float_3p2, a_bf16_0, b2 ); + + // Broadcast a[4,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 4) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[4,0-47] = a[4,kr:kr+2]*b[kr:kr+2,0-47] + c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); + c_float_4p1 = _mm512_dpbf16_ps( c_float_4p1, a_bf16_0, b1 ); + c_float_4p2 = _mm512_dpbf16_ps( c_float_4p2, a_bf16_0, b2 ); + + // Broadcast a[5,kr:kr+2]. + a_kfringe_buf = *(a + (rs_a * 5) + (cs_a * ( k_full_pieces ))); + a_bf16_0 = (__m512bh)_mm512_set1_epi16( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[5,0-47] = a[5,kr:kr+2]*b[kr:kr+2,0-47] + c_float_5p0 = _mm512_dpbf16_ps( c_float_5p0, a_bf16_0, b0 ); + c_float_5p1 = _mm512_dpbf16_ps( c_float_5p1, a_bf16_0, b1 ); + c_float_5p2 = _mm512_dpbf16_ps( c_float_5p2, a_bf16_0, b2 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + if ( alpha != 1 ) + { + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + c_float_0p1 = _mm512_mul_ps( selector1, c_float_0p1 ); + c_float_0p2 = _mm512_mul_ps( selector1, c_float_0p2 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + c_float_1p1 = _mm512_mul_ps( selector1, c_float_1p1 ); + c_float_1p2 = _mm512_mul_ps( selector1, c_float_1p2 ); + + c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); + c_float_2p1 = _mm512_mul_ps( selector1, c_float_2p1 ); + c_float_2p2 = _mm512_mul_ps( selector1, c_float_2p2 ); + + c_float_3p0 = _mm512_mul_ps( selector1, c_float_3p0 ); + c_float_3p1 = _mm512_mul_ps( selector1, c_float_3p1 ); + c_float_3p2 = _mm512_mul_ps( selector1, c_float_3p2 ); + + c_float_4p0 = _mm512_mul_ps( selector1, c_float_4p0 ); + c_float_4p1 = _mm512_mul_ps( selector1, c_float_4p1 ); + c_float_4p2 = _mm512_mul_ps( selector1, c_float_4p2 ); + + c_float_5p0 = _mm512_mul_ps( selector1, c_float_5p0 ); + c_float_5p1 = _mm512_mul_ps( selector1, c_float_5p1 ); + c_float_5p2 = _mm512_mul_ps( selector1, c_float_5p2 ); + } + + // Scale C by beta. + if ( beta != 0 ) + { + if ( ( post_ops_attr.buf_downscale != NULL ) && + ( post_ops_attr.is_first_k == TRUE ) ) + { + // c[0,0-15] + BF16_F32_BETA_OP(c_float_0p0,ir,0,0,selector1,selector2) + + // c[0, 16-31] + BF16_F32_BETA_OP(c_float_0p1,ir,0,1,selector1,selector2) + + // c[0,32-47] + BF16_F32_BETA_OP(c_float_0p2,ir,0,2,selector1,selector2) + + // c[1,0-15] + BF16_F32_BETA_OP(c_float_1p0,ir,1,0,selector1,selector2) + + // c[1,16-31] + BF16_F32_BETA_OP(c_float_1p1,ir,1,1,selector1,selector2) + + // c[1,32-47] + BF16_F32_BETA_OP(c_float_1p2,ir,1,2,selector1,selector2) + + // c[2,0-15] + BF16_F32_BETA_OP(c_float_2p0,ir,2,0,selector1,selector2) + + // c[2,16-31] + BF16_F32_BETA_OP(c_float_2p1,ir,2,1,selector1,selector2) + + // c[2,32-47] + BF16_F32_BETA_OP(c_float_2p2,ir,2,2,selector1,selector2) + + // c[3,0-15] + BF16_F32_BETA_OP(c_float_3p0,ir,3,0,selector1,selector2) + + // c[3,16-31] + BF16_F32_BETA_OP(c_float_3p1,ir,3,1,selector1,selector2) + + // c[3,32-47] + BF16_F32_BETA_OP(c_float_3p2,ir,3,2,selector1,selector2) + + // c[4,0-15] + BF16_F32_BETA_OP(c_float_4p0,ir,4,0,selector1,selector2) + + // c[4,16-31] + BF16_F32_BETA_OP(c_float_4p1,ir,4,1,selector1,selector2) + + // c[4,32-47] + BF16_F32_BETA_OP(c_float_4p2,ir,4,2,selector1,selector2) + + // c[5,0-15] + BF16_F32_BETA_OP(c_float_5p0,ir,5,0,selector1,selector2) + + // c[5,16-31] + BF16_F32_BETA_OP(c_float_5p1,ir,5,1,selector1,selector2) + + // c[5,32-47] + BF16_F32_BETA_OP(c_float_5p2,ir,5,2,selector1,selector2) + } + else + { + // c[0,0-15] + F32_F32_BETA_OP(c_float_0p0,ir,0,0,selector1,selector2) + + // c[0, 16-31] + F32_F32_BETA_OP(c_float_0p1,ir,0,1,selector1,selector2) + + // c[0,32-47] + F32_F32_BETA_OP(c_float_0p2,ir,0,2,selector1,selector2) + + // c[1,0-15] + F32_F32_BETA_OP(c_float_1p0,ir,1,0,selector1,selector2) + + // c[1,16-31] + F32_F32_BETA_OP(c_float_1p1,ir,1,1,selector1,selector2) + + // c[1,32-47] + F32_F32_BETA_OP(c_float_1p2,ir,1,2,selector1,selector2) + + // c[2,0-15] + F32_F32_BETA_OP(c_float_2p0,ir,2,0,selector1,selector2) + + // c[2,16-31] + F32_F32_BETA_OP(c_float_2p1,ir,2,1,selector1,selector2) + + // c[2,32-47] + F32_F32_BETA_OP(c_float_2p2,ir,2,2,selector1,selector2) + + // c[3,0-15] + F32_F32_BETA_OP(c_float_3p0,ir,3,0,selector1,selector2) + + // c[3,16-31] + F32_F32_BETA_OP(c_float_3p1,ir,3,1,selector1,selector2) + + // c[3,32-47] + F32_F32_BETA_OP(c_float_3p2,ir,3,2,selector1,selector2) + + // c[4,0-15] + F32_F32_BETA_OP(c_float_4p0,ir,4,0,selector1,selector2) + + // c[4,16-31] + F32_F32_BETA_OP(c_float_4p1,ir,4,1,selector1,selector2) + + // c[4,32-47] + F32_F32_BETA_OP(c_float_4p2,ir,4,2,selector1,selector2) + + // c[5,0-15] + F32_F32_BETA_OP(c_float_5p0,ir,5,0,selector1,selector2) + + // c[5,16-31] + F32_F32_BETA_OP(c_float_5p1,ir,5,1,selector1,selector2) + + // c[5,32-47] + F32_F32_BETA_OP(c_float_5p2,ir,5,2,selector1,selector2) + } + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_6x48: + { + __m512 selector3; + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_LOAD(selector1, bias_mask, 0); + BF16_F32_BIAS_LOAD(selector2, bias_mask, 1); + BF16_F32_BIAS_LOAD(selector3, bias_mask, 2); + } + else + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector3, c_float_1p2 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector2, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_add_ps( selector3, c_float_3p2 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + + // c[4, 16-31] + c_float_4p1 = _mm512_add_ps( selector2, c_float_4p1 ); + + // c[4,32-47] + c_float_4p2 = _mm512_add_ps( selector3, c_float_4p2 ); + + // c[5,0-15] + c_float_5p0 = _mm512_add_ps( selector1, c_float_5p0 ); + + // c[5, 16-31] + c_float_5p1 = _mm512_add_ps( selector2, c_float_5p1 ); + + // c[5,32-47] + c_float_5p2 = _mm512_add_ps( selector3, c_float_5p2 ); + } + else + { + __m512 selector4; + __m512 selector5; + __m512 selector6; + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + BF16_F32_BIAS_BCAST(selector2, bias_mask, 1); + BF16_F32_BIAS_BCAST(selector3, bias_mask, 2); + BF16_F32_BIAS_BCAST(selector4, bias_mask, 3); + BF16_F32_BIAS_BCAST(selector5, bias_mask, 4); + BF16_F32_BIAS_BCAST(selector6, bias_mask, 5); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 3 ) ); + selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 4 ) ); + selector6 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 5 ) ); + } + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector2, c_float_1p2 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector3, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector4, c_float_3p0 ); + + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector4, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_add_ps( selector4, c_float_3p2 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector5, c_float_4p0 ); + + // c[4, 16-31] + c_float_4p1 = _mm512_add_ps( selector5, c_float_4p1 ); + + // c[4,32-47] + c_float_4p2 = _mm512_add_ps( selector5, c_float_4p2 ); + + // c[5,0-15] + c_float_5p0 = _mm512_add_ps( selector6, c_float_5p0 ); + + // c[5, 16-31] + c_float_5p1 = _mm512_add_ps( selector6, c_float_5p1 ); + + // c[5,32-47] + c_float_5p2 = _mm512_add_ps( selector6, c_float_5p2 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_6x48: + { + //printf("relu\n"); + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_max_ps( selector1, c_float_0p2 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + c_float_1p1 = _mm512_max_ps( selector1, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_max_ps( selector1, c_float_1p2 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + c_float_2p1 = _mm512_max_ps( selector1, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_max_ps( selector1, c_float_2p2 ); + + // c[3,0-15] + c_float_3p0 = _mm512_max_ps( selector1, c_float_3p0 ); + + // c[3,16-31] + c_float_3p1 = _mm512_max_ps( selector1, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_max_ps( selector1, c_float_3p2 ); + + // c[4,0-15] + c_float_4p0 = _mm512_max_ps( selector1, c_float_4p0 ); + + // c[4,16-31] + c_float_4p1 = _mm512_max_ps( selector1, c_float_4p1 ); + + // c[4,32-47] + c_float_4p2 = _mm512_max_ps( selector1, c_float_4p2 ); + + // c[5,0-15] + c_float_5p0 = _mm512_max_ps( selector1, c_float_5p0 ); + + // c[5,16-31] + c_float_5p1 = _mm512_max_ps( selector1, c_float_5p1 ); + + // c[5,32-47] + c_float_5p2 = _mm512_max_ps( selector1, c_float_5p2 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_6x48: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_0p2) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_1p1) + + // c[1, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_1p2) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_2p1) + + // c[2, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_2p2) + + // c[3, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_3p0) + + // c[3, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_3p1) + + // c[3, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_3p2) + + // c[4, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_4p0) + + // c[4, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_4p1) + + // c[4, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_4p2) + + // c[5, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_5p0) + + // c[5, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_5p1) + + // c[5, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_5p2) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_TANH_6x48: + { + __m512 dn, z, x, r2, r, x_tanh; + __m512i q; + + // c[0, 0-15] + GELU_TANH_F32_AVX512(c_float_0p0, r, r2, x, z, dn, x_tanh, q) + + // c[0, 16-31] + GELU_TANH_F32_AVX512(c_float_0p1, r, r2, x, z, dn, x_tanh, q) + + // c[0, 32-47] + GELU_TANH_F32_AVX512(c_float_0p2, r, r2, x, z, dn, x_tanh, q) + + // c[1, 0-15] + GELU_TANH_F32_AVX512(c_float_1p0, r, r2, x, z, dn, x_tanh, q) + + // c[1, 16-31] + GELU_TANH_F32_AVX512(c_float_1p1, r, r2, x, z, dn, x_tanh, q) + + // c[1, 32-47] + GELU_TANH_F32_AVX512(c_float_1p2, r, r2, x, z, dn, x_tanh, q) + + // c[2, 0-15] + GELU_TANH_F32_AVX512(c_float_2p0, r, r2, x, z, dn, x_tanh, q) + + // c[2, 16-31] + GELU_TANH_F32_AVX512(c_float_2p1, r, r2, x, z, dn, x_tanh, q) + + // c[2, 32-47] + GELU_TANH_F32_AVX512(c_float_2p2, r, r2, x, z, dn, x_tanh, q) + + // c[3, 0-15] + GELU_TANH_F32_AVX512(c_float_3p0, r, r2, x, z, dn, x_tanh, q) + + // c[3, 16-31] + GELU_TANH_F32_AVX512(c_float_3p1, r, r2, x, z, dn, x_tanh, q) + + // c[3, 32-47] + GELU_TANH_F32_AVX512(c_float_3p2, r, r2, x, z, dn, x_tanh, q) + + // c[4, 0-15] + GELU_TANH_F32_AVX512(c_float_4p0, r, r2, x, z, dn, x_tanh, q) + + // c[4, 16-31] + GELU_TANH_F32_AVX512(c_float_4p1, r, r2, x, z, dn, x_tanh, q) + + // c[4, 32-47] + GELU_TANH_F32_AVX512(c_float_4p2, r, r2, x, z, dn, x_tanh, q) + + // c[5, 0-15] + GELU_TANH_F32_AVX512(c_float_5p0, r, r2, x, z, dn, x_tanh, q) + + // c[5, 16-31] + GELU_TANH_F32_AVX512(c_float_5p1, r, r2, x, z, dn, x_tanh, q) + + // c[5, 32-47] + GELU_TANH_F32_AVX512(c_float_5p2, r, r2, x, z, dn, x_tanh, q) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_ERF_6x48: + { + __m512 x, r, x_erf; + + // c[0, 0-15] + GELU_ERF_F32_AVX512(c_float_0p0, r, x, x_erf) + + // c[0, 16-31] + GELU_ERF_F32_AVX512(c_float_0p1, r, x, x_erf) + + // c[0, 32-47] + GELU_ERF_F32_AVX512(c_float_0p2, r, x, x_erf) + + // c[1, 0-15] + GELU_ERF_F32_AVX512(c_float_1p0, r, x, x_erf) + + // c[1, 16-31] + GELU_ERF_F32_AVX512(c_float_1p1, r, x, x_erf) + + // c[1, 32-47] + GELU_ERF_F32_AVX512(c_float_1p2, r, x, x_erf) + + // c[2, 0-15] + GELU_ERF_F32_AVX512(c_float_2p0, r, x, x_erf) + + // c[2, 16-31] + GELU_ERF_F32_AVX512(c_float_2p1, r, x, x_erf) + + // c[2, 32-47] + GELU_ERF_F32_AVX512(c_float_2p2, r, x, x_erf) + + // c[3, 0-15] + GELU_ERF_F32_AVX512(c_float_3p0, r, x, x_erf) + + // c[3, 16-31] + GELU_ERF_F32_AVX512(c_float_3p1, r, x, x_erf) + + // c[3, 32-47] + GELU_ERF_F32_AVX512(c_float_3p2, r, x, x_erf) + + // c[4, 0-15] + GELU_ERF_F32_AVX512(c_float_4p0, r, x, x_erf) + + // c[4, 16-31] + GELU_ERF_F32_AVX512(c_float_4p1, r, x, x_erf) + + // c[4, 32-47] + GELU_ERF_F32_AVX512(c_float_4p2, r, x, x_erf) + + // c[5, 0-15] + GELU_ERF_F32_AVX512(c_float_5p0, r, x, x_erf) + + // c[5, 16-31] + GELU_ERF_F32_AVX512(c_float_5p1, r, x, x_erf) + + // c[5, 32-47] + GELU_ERF_F32_AVX512(c_float_5p2, r, x, x_erf) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_CLIP_6x48: + { + __m512 min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 ); + __m512 max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 ); + + // c[0, 0-15] + CLIP_F32_AVX512(c_float_0p0, min, max) + + // c[0, 16-31] + CLIP_F32_AVX512(c_float_0p1, min, max) + + // c[0, 32-47] + CLIP_F32_AVX512(c_float_0p2, min, max) + + // c[1, 0-15] + CLIP_F32_AVX512(c_float_1p0, min, max) + + // c[1, 16-31] + CLIP_F32_AVX512(c_float_1p1, min, max) + + // c[1, 32-47] + CLIP_F32_AVX512(c_float_1p2, min, max) + + // c[2, 0-15] + CLIP_F32_AVX512(c_float_2p0, min, max) + + // c[2, 16-31] + CLIP_F32_AVX512(c_float_2p1, min, max) + + // c[2, 32-47] + CLIP_F32_AVX512(c_float_2p2, min, max) + + // c[3, 0-15] + CLIP_F32_AVX512(c_float_3p0, min, max) + + // c[3, 16-31] + CLIP_F32_AVX512(c_float_3p1, min, max) + + // c[3, 32-47] + CLIP_F32_AVX512(c_float_3p2, min, max) + + // c[4, 0-15] + CLIP_F32_AVX512(c_float_4p0, min, max) + + // c[4, 16-31] + CLIP_F32_AVX512(c_float_4p1, min, max) + + // c[4, 32-47] + CLIP_F32_AVX512(c_float_4p2, min, max) + + // c[5, 0-15] + CLIP_F32_AVX512(c_float_5p0, min, max) + + // c[5, 16-31] + CLIP_F32_AVX512(c_float_5p1, min, max) + + // c[5, 32-47] + CLIP_F32_AVX512(c_float_5p2, min, max) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_6x48: + { + __m512 selector3 = _mm512_setzero_ps(); + __m512 selector4 = _mm512_setzero_ps(); + + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + __m512 zero_point2 = _mm512_setzero_ps(); + __m512 zero_point3 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( zp_mask, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector2,zero_point1); + + // c[0, 32-47] + SCL_MULRND_F32(c_float_0p2,selector3,zero_point2); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector1,zero_point0); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[1, 32-47] + SCL_MULRND_F32(c_float_1p2,selector3,zero_point2); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector1,zero_point0); + + // c[2, 16-31] + SCL_MULRND_F32(c_float_2p1,selector2,zero_point1); + + // c[2, 32-47] + SCL_MULRND_F32(c_float_2p2,selector3,zero_point2); + + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector1,zero_point0); + + // c[3, 16-31] + SCL_MULRND_F32(c_float_3p1,selector2,zero_point1); + + // c[3, 32-47] + SCL_MULRND_F32(c_float_3p2,selector3,zero_point2); + + // c[4, 0-15] + SCL_MULRND_F32(c_float_4p0,selector1,zero_point0); + + // c[4, 16-31] + SCL_MULRND_F32(c_float_4p1,selector2,zero_point1); + + // c[4, 32-47] + SCL_MULRND_F32(c_float_4p2,selector3,zero_point2); + + // c[5, 0-15] + SCL_MULRND_F32(c_float_5p0,selector1,zero_point0); + + // c[5, 16-31] + SCL_MULRND_F32(c_float_5p1,selector2,zero_point1); + + // c[5, 32-47] + SCL_MULRND_F32(c_float_5p2,selector3,zero_point2); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 3 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 2 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 3 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(c_float_0p0,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(c_float_0p1,selector1,zero_point0); + + // c[0, 32-47] + SCL_MULRND_F32(c_float_0p2,selector1,zero_point0); + + // c[1, 0-15] + SCL_MULRND_F32(c_float_1p0,selector2,zero_point1); + + // c[1, 16-31] + SCL_MULRND_F32(c_float_1p1,selector2,zero_point1); + + // c[1, 32-47] + SCL_MULRND_F32(c_float_1p2,selector2,zero_point1); + + // c[2, 0-15] + SCL_MULRND_F32(c_float_2p0,selector3,zero_point2); + + // c[2, 16-31] + SCL_MULRND_F32(c_float_2p1,selector3,zero_point2); + + // c[2, 32-47] + SCL_MULRND_F32(c_float_2p2,selector3,zero_point2); + + // c[3, 0-15] + SCL_MULRND_F32(c_float_3p0,selector4,zero_point3); + + // c[3, 16-31] + SCL_MULRND_F32(c_float_3p1,selector4,zero_point3); + + // c[3, 32-47] + SCL_MULRND_F32(c_float_3p2,selector4,zero_point3); + + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 4 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 5 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 4 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 5 ) ) ); + } + // c[4, 0-15] + SCL_MULRND_F32(c_float_4p0,selector1,zero_point0); + + // c[4, 16-31] + SCL_MULRND_F32(c_float_4p1,selector1,zero_point0); + + // c[4, 32-47] + SCL_MULRND_F32(c_float_4p2,selector1,zero_point0); + + // c[5, 0-15] + SCL_MULRND_F32(c_float_5p0,selector2,zero_point1); + + // c[5, 16-31] + SCL_MULRND_F32(c_float_5p1,selector2,zero_point1); + + // c[5, 32-47] + SCL_MULRND_F32(c_float_5p2,selector2,zero_point1); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_6x48: + { + __m512 selector3; + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + BF16_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,0); + + // c[1:0-15,16-31,32-47] + BF16_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,1); + + // c[2:0-15,16-31,32-47] + BF16_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,2); + + // c[3:0-15,16-31,32-47] + BF16_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,3); + + // c[4:0-15,16-31,32-47] + BF16_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,4); + + // c[5:0-15,16-31,32-47] + BF16_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,5); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,0); + + // c[1:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,1); + + // c[2:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,2); + + // c[3:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,3); + + // c[4:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,4); + + // c[5:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(selector1,selector2,selector3,5); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_6x48: + { + __m512 selector3; + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + BF16_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,0); + + // c[1:0-15,16-31,32-47] + BF16_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,1); + + // c[2:0-15,16-31,32-47] + BF16_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,2); + + // c[3:0-15,16-31,32-47] + BF16_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,3); + + // c[4:0-15,16-31,32-47] + BF16_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,4); + + // c[5:0-15,16-31,32-47] + BF16_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,5); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,0); + + // c[1:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,1); + + // c[2:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,2); + + // c[3:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,3); + + // c[4:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,4); + + // c[5:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(selector1,selector2,selector3,5); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_6x48: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(c_float_0p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[0, 16-31] + SWISH_F32_AVX512_DEF(c_float_0p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[0, 32-47] + SWISH_F32_AVX512_DEF(c_float_0p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 0-15] + SWISH_F32_AVX512_DEF(c_float_1p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 16-31] + SWISH_F32_AVX512_DEF(c_float_1p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[1, 32-47] + SWISH_F32_AVX512_DEF(c_float_1p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 0-15] + SWISH_F32_AVX512_DEF(c_float_2p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 16-31] + SWISH_F32_AVX512_DEF(c_float_2p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[2, 32-47] + SWISH_F32_AVX512_DEF(c_float_2p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[3, 0-15] + SWISH_F32_AVX512_DEF(c_float_3p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[3, 16-31] + SWISH_F32_AVX512_DEF(c_float_3p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[3, 32-47] + SWISH_F32_AVX512_DEF(c_float_3p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[4, 0-15] + SWISH_F32_AVX512_DEF(c_float_4p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[4, 16-31] + SWISH_F32_AVX512_DEF(c_float_4p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[4, 32-47] + SWISH_F32_AVX512_DEF(c_float_4p2, selector1, al_in, r, r2, z, dn, ex_out); + + // c[5, 0-15] + SWISH_F32_AVX512_DEF(c_float_5p0, selector1, al_in, r, r2, z, dn, ex_out); + + // c[5, 16-31] + SWISH_F32_AVX512_DEF(c_float_5p1, selector1, al_in, r, r2, z, dn, ex_out); + + // c[5, 32-47] + SWISH_F32_AVX512_DEF(c_float_5p2, selector1, al_in, r, r2, z, dn, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_6x48_DISABLE: + ; + // Case where the output C matrix is bf16 (downscaled) and this is the + // final write for a given block within C. + if ( ( post_ops_attr.buf_downscale != NULL ) && + ( post_ops_attr.is_last_k == TRUE ) ) + { + // Generate a mask16 of all 1's. + __m512i selector_a = _mm512_setzero_epi32(); + __m512i selector_b = _mm512_set1_epi32( 10 ); + __mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b ); + + // Store the results in downscaled type (bf16 instead of float). + + // c[0, 0-15] + CVT_STORE_F32_BF16_MASK(c_float_0p0,0,0); + + // c[0, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_0p1,0,1); + + // c[0, 32-47] + CVT_STORE_F32_BF16_MASK(c_float_0p2,0,2); + + // c[1, 0-15] + CVT_STORE_F32_BF16_MASK(c_float_1p0,1,0); + + // c[1, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_1p1,1,1); + + // c[1, 32-47] + CVT_STORE_F32_BF16_MASK(c_float_1p2,1,2); + + // c[2, 0-15] + CVT_STORE_F32_BF16_MASK(c_float_2p0,2,0); + + // c[2, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_2p1,2,1); + + // c[2, 32-47] + CVT_STORE_F32_BF16_MASK(c_float_2p2,2,2); + + // c[3, 0-15] + CVT_STORE_F32_BF16_MASK(c_float_3p0,3,0); + + // c[3, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_3p1,3,1); + + // c[3, 32-47] + CVT_STORE_F32_BF16_MASK(c_float_3p2,3,2); + + // c[4, 0-15] + CVT_STORE_F32_BF16_MASK(c_float_4p0,4,0); + + // c[4, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_4p1,4,1); + + // c[4, 32-47] + CVT_STORE_F32_BF16_MASK(c_float_4p2,4,2); + + // c[5, 0-15] + CVT_STORE_F32_BF16_MASK(c_float_5p0,5,0); + + // c[5, 16-31] + CVT_STORE_F32_BF16_MASK(c_float_5p1,5,1); + + // c[5, 32-47] + CVT_STORE_F32_BF16_MASK(c_float_5p2,5,2); + } + + else + { + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 0*16 ), c_float_0p0 ); + + // c[0, 16-31] + _mm512_storeu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 1*16 ), c_float_0p1 ); + + // c[0,32-47] + _mm512_storeu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 2*16 ), c_float_0p2 ); + + // c[1,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 0*16 ), c_float_1p0 ); + + // c[1,16-31] + _mm512_storeu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 1*16 ), c_float_1p1 ); + + // c[1,32-47] + _mm512_storeu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 2*16 ), c_float_1p2 ); + + // c[2,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 0*16 ), c_float_2p0 ); + + // c[2,16-31] + _mm512_storeu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 1*16 ), c_float_2p1 ); + + // c[2,32-47] + _mm512_storeu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 2*16 ), c_float_2p2 ); + + // c[3,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 0*16 ), c_float_3p0 ); + + // c[3,16-31] + _mm512_storeu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 1*16 ), c_float_3p1 ); + + // c[3,32-47] + _mm512_storeu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 2*16 ), c_float_3p2 ); + + // c[4,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 0*16 ), c_float_4p0 ); + + // c[4,16-31] + _mm512_storeu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 1*16 ), c_float_4p1 ); + + // c[4,32-47] + _mm512_storeu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 2*16 ), c_float_4p2 ); + + // c[5,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 0*16 ), c_float_5p0 ); + + // c[5,16-31] + _mm512_storeu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 1*16 ), c_float_5p1 ); + + // c[5,32-47] + _mm512_storeu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 2*16 ), c_float_5p2 ); + } + + a = a + ( MR * ps_a ); + post_ops_attr.post_op_c_i += MR; + + } + + if ( m_partial_pieces > 0 ) + { + if ( m_partial_pieces == 5 ) + { + int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 5 ); + lpgemm_rowvar_bf16s4f32of32_5x48 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + post_ops_list, post_ops_attr + ); + } + else if ( m_partial_pieces == 4 ) + { + int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 4 ); + lpgemm_rowvar_bf16s4f32of32_4x48 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + post_ops_list, post_ops_attr + ); + } + else if ( m_partial_pieces == 3 ) + { + int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 3 ); + lpgemm_rowvar_bf16s4f32of32_3x48 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + post_ops_list, post_ops_attr + ); + } + else if ( m_partial_pieces == 2 ) + { + int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 2 ); + lpgemm_rowvar_bf16s4f32of32_2x48 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + post_ops_list, post_ops_attr + ); + } + else if ( m_partial_pieces == 1 ) + { + int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 1 ); + lpgemm_rowvar_bf16s4f32of32_1x48 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + post_ops_list, post_ops_attr + ); + } + } + +} +#endif +#endif diff --git a/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_packa_bf16_amd256vnni.c b/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_packa_bf16_amd256vnni.c index b928338f30..d3afdfefc6 100644 --- a/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_packa_bf16_amd256vnni.c +++ b/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_packa_bf16_amd256vnni.c @@ -379,6 +379,9 @@ void packa_mr16_bf16bf16f32of32_row_major { dim_t MR = 32; + dim_t k_left = KC % 32; + + __mmask32 mask = 0xFFFFFFFF >> ( 32 - k_left ); __m512i a_reg[32]; dim_t ic = 0, kr = 0; @@ -454,36 +457,13 @@ void packa_mr16_bf16bf16f32of32_row_major _mm512_storeu_si512( pack_a_buffer + ( ( ic + 30 ) * KC ) + kr , a_reg[30] ); _mm512_storeu_si512( pack_a_buffer + ( ( ic + 31 ) * KC ) + kr , a_reg[31] ); } - for( ; ( kr + 15 ) < KC; kr += 16 ) - { - MASKED_LOAD_32_ROWS_AVX512( 0xFFFF ) - - MASKED_STORE_32_ROWS_AVX512( 0xFFFF ) - } - for( ; ( kr + 7 ) < KC; kr += 8 ) - { - MASKED_LOAD_32_ROWS_AVX512( 0xFF ) - - MASKED_STORE_32_ROWS_AVX512( 0xFF ) - } - for( ; ( kr + 3 ) < KC; kr += 4 ) - { - MASKED_LOAD_32_ROWS_AVX512( 0xF ) - MASKED_STORE_32_ROWS_AVX512( 0xF ) - } - for( ; ( kr + 1 ) < KC; kr += 2 ) + if( k_left > 0 ) { - MASKED_LOAD_32_ROWS_AVX512( 0x3 ) - - MASKED_STORE_32_ROWS_AVX512( 0x3 ) + MASKED_LOAD_32_ROWS_AVX512( mask ) + MASKED_STORE_32_ROWS_AVX512( mask ) } - for( ; ( kr ) < KC; kr += 1 ) - { - MASKED_LOAD_32_ROWS_AVX512( 0x1 ) - MASKED_STORE_32_ROWS_AVX512( 0x1 ) - } } for( ; ( ic + 16 - 1 ) < MC; ic += 16 ) { @@ -523,37 +503,15 @@ void packa_mr16_bf16bf16f32of32_row_major _mm512_storeu_si512( pack_a_buffer + ( ( ic + 14 ) * KC ) + kr , a_reg[14] ); _mm512_storeu_si512( pack_a_buffer + ( ( ic + 15 ) * KC ) + kr , a_reg[15] ); } - for( ; ( kr + 16 - 1 ) < KC; kr += 16 ) - { - MASKED_LOAD_16_ROWS_AVX512( 0xFFFF ) - MASKED_STORE_16_ROWS_AVX512( 0xFFFF ) - } - for( ; ( kr + 7 ) < KC; kr += 8 ) - { - MASKED_LOAD_16_ROWS_AVX512( 0xFF ) - - MASKED_STORE_16_ROWS_AVX512( 0xFF ) - } - for( ; ( kr + 3 ) < KC; kr += 4 ) - { - MASKED_LOAD_16_ROWS_AVX512( 0xF ) - MASKED_STORE_16_ROWS_AVX512( 0xF ) - } - for( ; ( kr + 1 ) < KC; kr += 2 ) + if( k_left > 0 ) { - MASKED_LOAD_16_ROWS_AVX512( 0x3 ) - - MASKED_STORE_16_ROWS_AVX512( 0x3 ) + MASKED_LOAD_16_ROWS_AVX512( mask ) + MASKED_STORE_16_ROWS_AVX512( mask ) } - for( ; ( kr ) < KC; kr += 1 ) - { - MASKED_LOAD_16_ROWS_AVX512( 0x1 ) - MASKED_STORE_16_ROWS_AVX512( 0x1 ) - } } - for( ; ( ic + 7 - 1 ) < MC; ic += 8 ) + for( ; ( ic + 8 - 1 ) < MC; ic += 8 ) { for( kr = 0; ( kr + 32 - 1 ) < KC; kr += 32 ) { @@ -575,35 +533,13 @@ void packa_mr16_bf16bf16f32of32_row_major _mm512_storeu_si512( pack_a_buffer + ( ( ic + 6 ) * KC ) + kr , a_reg[6] ); _mm512_storeu_si512( pack_a_buffer + ( ( ic + 7 ) * KC ) + kr , a_reg[7] ); } - for( ; ( kr + 16 - 1 ) < KC; kr += 16 ) - { - MASKED_LOAD_8_ROWS_AVX512( 0xFFFF ) - MASKED_STORE_8_ROWS_AVX512( 0xFFFF ) - } - for( ; ( kr + 7 ) < KC; kr += 8 ) - { - MASKED_LOAD_8_ROWS_AVX512( 0xFF ) - - MASKED_STORE_8_ROWS_AVX512( 0xFF ) - } - for( ; ( kr + 3 ) < KC; kr += 4 ) - { - MASKED_LOAD_8_ROWS_AVX512( 0xF ) - MASKED_STORE_8_ROWS_AVX512( 0xF ) - } - for( ; ( kr + 1 ) < KC; kr += 2 ) + if( k_left > 0 ) { - MASKED_LOAD_8_ROWS_AVX512( 0x3 ) - - MASKED_STORE_8_ROWS_AVX512( 0x3 ) + MASKED_LOAD_8_ROWS_AVX512( mask ) + MASKED_STORE_8_ROWS_AVX512( mask ) } - for( ; ( kr ) < KC; kr += 1 ) - { - MASKED_LOAD_8_ROWS_AVX512( 0x1 ) - MASKED_STORE_8_ROWS_AVX512( 0x1 ) - } } for( ; ( ic + 4 - 1 ) < MC; ic += 4 ) { @@ -619,35 +555,13 @@ void packa_mr16_bf16bf16f32of32_row_major _mm512_storeu_si512( pack_a_buffer + ( ( ic + 2 ) * KC ) + kr , a_reg[2] ); _mm512_storeu_si512( pack_a_buffer + ( ( ic + 3 ) * KC ) + kr , a_reg[3] ); } - for( ; ( kr + 16 - 1 ) < KC; kr += 16 ) - { - MASKED_LOAD_4_ROWS_AVX512( 0xFFFF ) - MASKED_STORE_4_ROWS_AVX512( 0xFFFF ) - } - for( ; ( kr + 7 ) < KC; kr += 8 ) - { - MASKED_LOAD_4_ROWS_AVX512( 0xFF ) - - MASKED_STORE_4_ROWS_AVX512( 0xFF ) - } - for( ; ( kr + 3 ) < KC; kr += 4 ) - { - MASKED_LOAD_4_ROWS_AVX512( 0xF ) - MASKED_STORE_4_ROWS_AVX512( 0xF ) - } - for( ; ( kr + 1 ) < KC; kr += 2 ) + if( k_left > 0 ) { - MASKED_LOAD_4_ROWS_AVX512( 0x3 ) - - MASKED_STORE_4_ROWS_AVX512( 0x3 ) + MASKED_LOAD_4_ROWS_AVX512( mask ) + MASKED_STORE_4_ROWS_AVX512( mask ) } - for( ; ( kr ) < KC; kr += 1 ) - { - MASKED_LOAD_4_ROWS_AVX512( 0x1 ) - MASKED_STORE_4_ROWS_AVX512( 0x1 ) - } } for( ; ( ic + 2 - 1 ) < MC; ic += 2 ) @@ -762,6 +676,27 @@ void packa_mr16_bf16bf16f32of32_col_major __m256i a_reg[16], b_reg[16]; + // These registers are set with zeroes to avoid compiler warnings + // To-DO: TO be removed when pack code is optimized for fringe cases. + + a_reg[0] = _mm256_setzero_si256(); + a_reg[1] = _mm256_setzero_si256(); + a_reg[2] = _mm256_setzero_si256(); + a_reg[3] = _mm256_setzero_si256(); + a_reg[4] = _mm256_setzero_si256(); + a_reg[5] = _mm256_setzero_si256(); + a_reg[6] = _mm256_setzero_si256(); + a_reg[7] = _mm256_setzero_si256(); + a_reg[8] = _mm256_setzero_si256(); + a_reg[9] = _mm256_setzero_si256(); + a_reg[10] = _mm256_setzero_si256(); + a_reg[11] = _mm256_setzero_si256(); + a_reg[12] = _mm256_setzero_si256(); + a_reg[13] = _mm256_setzero_si256(); + a_reg[14] = _mm256_setzero_si256(); + a_reg[15] = _mm256_setzero_si256(); + + dim_t ic, kr; for( ic = 0; ( ic + MR - 1 ) < MC; ic += MR) @@ -821,14 +756,6 @@ void packa_mr16_bf16bf16f32of32_col_major a_reg[5] = _mm256_loadu_si256( (__m256i const *)( a + ( ic * rs_a ) + ( ( kr + 5 ) * cs_a ) ) ); a_reg[6] = _mm256_loadu_si256( (__m256i const *)( a + ( ic * rs_a ) + ( ( kr + 6 ) * cs_a ) ) ); a_reg[7] = _mm256_loadu_si256( (__m256i const *)( a + ( ic * rs_a ) + ( ( kr + 7 ) * cs_a ) ) ); - a_reg[8] = _mm256_setzero_si256(); - a_reg[9] = _mm256_setzero_si256(); - a_reg[10] = _mm256_setzero_si256(); - a_reg[11] = _mm256_setzero_si256(); - a_reg[12] = _mm256_setzero_si256(); - a_reg[13] = _mm256_setzero_si256(); - a_reg[14] = _mm256_setzero_si256(); - a_reg[15] = _mm256_setzero_si256(); UNPACKLO_EPI16 UNPACKHI_EPI16 @@ -846,18 +773,7 @@ void packa_mr16_bf16bf16f32of32_col_major a_reg[1] = _mm256_loadu_si256( (__m256i const *)( a + ( ic * rs_a ) + ( ( kr + 1 ) * cs_a ) ) ); a_reg[2] = _mm256_loadu_si256( (__m256i const *)( a + ( ic * rs_a ) + ( ( kr + 2 ) * cs_a ) ) ); a_reg[3] = _mm256_loadu_si256( (__m256i const *)( a + ( ic * rs_a ) + ( ( kr + 3 ) * cs_a ) ) ); - a_reg[4] = _mm256_setzero_si256(); - a_reg[5] = _mm256_setzero_si256(); - a_reg[6] = _mm256_setzero_si256(); - a_reg[7] = _mm256_setzero_si256(); - a_reg[8] = _mm256_setzero_si256(); - a_reg[9] = _mm256_setzero_si256(); - a_reg[10] = _mm256_setzero_si256(); - a_reg[11] = _mm256_setzero_si256(); - a_reg[12] = _mm256_setzero_si256(); - a_reg[13] = _mm256_setzero_si256(); - a_reg[14] = _mm256_setzero_si256(); - a_reg[15] = _mm256_setzero_si256(); + UNPACKLO_EPI16 UNPACKHI_EPI16 @@ -872,20 +788,7 @@ void packa_mr16_bf16bf16f32of32_col_major { a_reg[0] = _mm256_loadu_si256( (__m256i const *)( a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ) ); a_reg[1] = _mm256_loadu_si256( (__m256i const *)( a + ( ic * rs_a ) + ( ( kr + 1 ) * cs_a ) ) ); - a_reg[2] = _mm256_setzero_si256(); - a_reg[3] = _mm256_setzero_si256(); - a_reg[4] = _mm256_setzero_si256(); - a_reg[5] = _mm256_setzero_si256(); - a_reg[6] = _mm256_setzero_si256(); - a_reg[7] = _mm256_setzero_si256(); - a_reg[8] = _mm256_setzero_si256(); - a_reg[9] = _mm256_setzero_si256(); - a_reg[10] = _mm256_setzero_si256(); - a_reg[11] = _mm256_setzero_si256(); - a_reg[12] = _mm256_setzero_si256(); - a_reg[13] = _mm256_setzero_si256(); - a_reg[14] = _mm256_setzero_si256(); - a_reg[15] = _mm256_setzero_si256(); + UNPACKLO_EPI16 UNPACKHI_EPI16 @@ -899,21 +802,7 @@ void packa_mr16_bf16bf16f32of32_col_major for( ; ( kr ) < KC; kr += 1) { a_reg[0] = _mm256_loadu_si256( (__m256i const *)(a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ) ); - a_reg[1] = _mm256_setzero_si256(); - a_reg[2] = _mm256_setzero_si256(); - a_reg[3] = _mm256_setzero_si256(); - a_reg[4] = _mm256_setzero_si256(); - a_reg[5] = _mm256_setzero_si256(); - a_reg[6] = _mm256_setzero_si256(); - a_reg[7] = _mm256_setzero_si256(); - a_reg[8] = _mm256_setzero_si256(); - a_reg[9] = _mm256_setzero_si256(); - a_reg[10] = _mm256_setzero_si256(); - a_reg[11] = _mm256_setzero_si256(); - a_reg[12] = _mm256_setzero_si256(); - a_reg[13] = _mm256_setzero_si256(); - a_reg[14] = _mm256_setzero_si256(); - a_reg[15] = _mm256_setzero_si256(); + UNPACKLO_EPI16 UNPACKHI_EPI16 @@ -925,7 +814,6 @@ void packa_mr16_bf16bf16f32of32_col_major MASKED_STORE_EPI16(0x01) } } - for( ; ( ic + 8 - 1) < MC; ic += 8) { for( kr = 0; ( kr + 15 ) < KC; kr += 16) @@ -963,7 +851,6 @@ void packa_mr16_bf16bf16f32of32_col_major _mm256_storeu_si256( (__m256i *)( pack_a_buffer + ( ic + 5 ) * KC + kr ), a_reg[12] ); _mm256_storeu_si256( (__m256i *)( pack_a_buffer + ( ic + 6 ) * KC + kr ), a_reg[10] ); _mm256_storeu_si256( (__m256i *)( pack_a_buffer + ( ic + 7 ) * KC + kr ), a_reg[14] ); - _mm256_storeu_si256( (__m256i *)( pack_a_buffer + ( ic + 8 ) * KC + kr ), a_reg[1] ); } for( ; ( kr + 7 ) < KC; kr += 8) @@ -976,14 +863,7 @@ void packa_mr16_bf16bf16f32of32_col_major a_reg[5] = _mm256_maskz_loadu_epi16( 0xFF, a + ( ic * rs_a ) + ( ( kr + 5 ) * cs_a ) ); a_reg[6] = _mm256_maskz_loadu_epi16( 0xFF, a + ( ic * rs_a ) + ( ( kr + 6 ) * cs_a ) ); a_reg[7] = _mm256_maskz_loadu_epi16( 0xFF, a + ( ic * rs_a ) + ( ( kr + 7 ) * cs_a ) ); - a_reg[8] = _mm256_setzero_si256(); - a_reg[9] = _mm256_setzero_si256(); - a_reg[10] = _mm256_setzero_si256(); - a_reg[11] = _mm256_setzero_si256(); - a_reg[12] = _mm256_setzero_si256(); - a_reg[13] = _mm256_setzero_si256(); - a_reg[14] = _mm256_setzero_si256(); - a_reg[15] = _mm256_setzero_si256(); + UNPACKLO_EPI16 UNPACKHI_EPI16 UNPACKLO_EPI32 @@ -1007,18 +887,7 @@ void packa_mr16_bf16bf16f32of32_col_major a_reg[1] = _mm256_maskz_loadu_epi16( 0xFF, a + ( ic * rs_a ) + ( ( kr + 1 ) * cs_a ) ); a_reg[2] = _mm256_maskz_loadu_epi16( 0xFF, a + ( ic * rs_a ) + ( ( kr + 2 ) * cs_a ) ); a_reg[3] = _mm256_maskz_loadu_epi16( 0xFF, a + ( ic * rs_a ) + ( ( kr + 3 ) * cs_a ) ); - a_reg[4] = _mm256_setzero_si256(); - a_reg[5] = _mm256_setzero_si256(); - a_reg[6] = _mm256_setzero_si256(); - a_reg[7] = _mm256_setzero_si256(); - a_reg[8] = _mm256_setzero_si256(); - a_reg[9] = _mm256_setzero_si256(); - a_reg[10] = _mm256_setzero_si256(); - a_reg[11] = _mm256_setzero_si256(); - a_reg[12] = _mm256_setzero_si256(); - a_reg[13] = _mm256_setzero_si256(); - a_reg[14] = _mm256_setzero_si256(); - a_reg[15] = _mm256_setzero_si256(); + UNPACKLO_EPI16 UNPACKHI_EPI16 @@ -1040,20 +909,7 @@ void packa_mr16_bf16bf16f32of32_col_major { a_reg[0] = _mm256_maskz_loadu_epi16( 0xFF, a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ); a_reg[1] = _mm256_maskz_loadu_epi16( 0xFF, a + ( ic * rs_a ) + ( ( kr + 1 ) * cs_a ) ); - a_reg[2] = _mm256_setzero_si256(); - a_reg[3] = _mm256_setzero_si256(); - a_reg[4] = _mm256_setzero_si256(); - a_reg[5] = _mm256_setzero_si256(); - a_reg[6] = _mm256_setzero_si256(); - a_reg[7] = _mm256_setzero_si256(); - a_reg[8] = _mm256_setzero_si256(); - a_reg[9] = _mm256_setzero_si256(); - a_reg[10] = _mm256_setzero_si256(); - a_reg[11] = _mm256_setzero_si256(); - a_reg[12] = _mm256_setzero_si256(); - a_reg[13] = _mm256_setzero_si256(); - a_reg[14] = _mm256_setzero_si256(); - a_reg[15] = _mm256_setzero_si256(); + UNPACKLO_EPI16 UNPACKHI_EPI16 @@ -1074,21 +930,6 @@ void packa_mr16_bf16bf16f32of32_col_major for( ; ( kr ) < KC; kr += 1) { a_reg[0] = _mm256_maskz_loadu_epi16( 0xFF, a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ); - a_reg[1] = _mm256_setzero_si256(); - a_reg[2] = _mm256_setzero_si256(); - a_reg[3] = _mm256_setzero_si256(); - a_reg[4] = _mm256_setzero_si256(); - a_reg[5] = _mm256_setzero_si256(); - a_reg[6] = _mm256_setzero_si256(); - a_reg[7] = _mm256_setzero_si256(); - a_reg[8] = _mm256_setzero_si256(); - a_reg[9] = _mm256_setzero_si256(); - a_reg[10] = _mm256_setzero_si256(); - a_reg[11] = _mm256_setzero_si256(); - a_reg[12] = _mm256_setzero_si256(); - a_reg[13] = _mm256_setzero_si256(); - a_reg[14] = _mm256_setzero_si256(); - a_reg[15] = _mm256_setzero_si256(); UNPACKLO_EPI16 UNPACKHI_EPI16 @@ -1153,14 +994,7 @@ void packa_mr16_bf16bf16f32of32_col_major a_reg[5] = _mm256_maskz_loadu_epi16( 0x0F, a + ( ic * rs_a ) + ( ( kr + 5 ) * cs_a ) ); a_reg[6] = _mm256_maskz_loadu_epi16( 0x0F, a + ( ic * rs_a ) + ( ( kr + 6 ) * cs_a ) ); a_reg[7] = _mm256_maskz_loadu_epi16( 0x0F, a + ( ic * rs_a ) + ( ( kr + 7 ) * cs_a ) ); - a_reg[8] = _mm256_setzero_si256(); - a_reg[9] = _mm256_setzero_si256(); - a_reg[10] = _mm256_setzero_si256(); - a_reg[11] = _mm256_setzero_si256(); - a_reg[12] = _mm256_setzero_si256(); - a_reg[13] = _mm256_setzero_si256(); - a_reg[14] = _mm256_setzero_si256(); - a_reg[15] = _mm256_setzero_si256(); + UNPACKLO_EPI16 UNPACKHI_EPI16 UNPACKLO_EPI32 @@ -1180,18 +1014,7 @@ void packa_mr16_bf16bf16f32of32_col_major a_reg[1] = _mm256_maskz_loadu_epi16( 0x0F, a + ( ic * rs_a ) + ( ( kr + 1 ) * cs_a ) ); a_reg[2] = _mm256_maskz_loadu_epi16( 0x0F, a + ( ic * rs_a ) + ( ( kr + 2 ) * cs_a ) ); a_reg[3] = _mm256_maskz_loadu_epi16( 0x0F, a + ( ic * rs_a ) + ( ( kr + 3 ) * cs_a ) ); - a_reg[4] = _mm256_setzero_si256(); - a_reg[5] = _mm256_setzero_si256(); - a_reg[6] = _mm256_setzero_si256(); - a_reg[7] = _mm256_setzero_si256(); - a_reg[8] = _mm256_setzero_si256(); - a_reg[9] = _mm256_setzero_si256(); - a_reg[10] = _mm256_setzero_si256(); - a_reg[11] = _mm256_setzero_si256(); - a_reg[12] = _mm256_setzero_si256(); - a_reg[13] = _mm256_setzero_si256(); - a_reg[14] = _mm256_setzero_si256(); - a_reg[15] = _mm256_setzero_si256(); + UNPACKLO_EPI16 UNPACKHI_EPI16 @@ -1209,20 +1032,7 @@ void packa_mr16_bf16bf16f32of32_col_major { a_reg[0] = _mm256_maskz_loadu_epi16( 0x0F, a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ); a_reg[1] = _mm256_maskz_loadu_epi16( 0x0F, a + ( ic * rs_a ) + ( ( kr + 1 ) * cs_a ) ); - a_reg[2] = _mm256_setzero_si256(); - a_reg[3] = _mm256_setzero_si256(); - a_reg[4] = _mm256_setzero_si256(); - a_reg[5] = _mm256_setzero_si256(); - a_reg[6] = _mm256_setzero_si256(); - a_reg[7] = _mm256_setzero_si256(); - a_reg[8] = _mm256_setzero_si256(); - a_reg[9] = _mm256_setzero_si256(); - a_reg[10] = _mm256_setzero_si256(); - a_reg[11] = _mm256_setzero_si256(); - a_reg[12] = _mm256_setzero_si256(); - a_reg[13] = _mm256_setzero_si256(); - a_reg[14] = _mm256_setzero_si256(); - a_reg[15] = _mm256_setzero_si256(); + UNPACKLO_EPI16 UNPACKHI_EPI16 @@ -1239,21 +1049,7 @@ void packa_mr16_bf16bf16f32of32_col_major for( ; ( kr ) < KC; kr += 1) { a_reg[0] = _mm256_maskz_loadu_epi16( 0x0F, a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ); - a_reg[1] = _mm256_setzero_si256(); - a_reg[2] = _mm256_setzero_si256(); - a_reg[3] = _mm256_setzero_si256(); - a_reg[4] = _mm256_setzero_si256(); - a_reg[5] = _mm256_setzero_si256(); - a_reg[6] = _mm256_setzero_si256(); - a_reg[7] = _mm256_setzero_si256(); - a_reg[8] = _mm256_setzero_si256(); - a_reg[9] = _mm256_setzero_si256(); - a_reg[10] = _mm256_setzero_si256(); - a_reg[11] = _mm256_setzero_si256(); - a_reg[12] = _mm256_setzero_si256(); - a_reg[13] = _mm256_setzero_si256(); - a_reg[14] = _mm256_setzero_si256(); - a_reg[15] = _mm256_setzero_si256(); + UNPACKLO_EPI16 UNPACKHI_EPI16 @@ -1326,14 +1122,8 @@ void packa_mr16_bf16bf16f32of32_col_major a_reg[5] = _mm256_maskz_loadu_epi16( mask, a + ( ic * rs_a ) + ( ( kr + 5 ) * cs_a ) ); a_reg[6] = _mm256_maskz_loadu_epi16( mask, a + ( ic * rs_a ) + ( ( kr + 6 ) * cs_a ) ); a_reg[7] = _mm256_maskz_loadu_epi16( mask, a + ( ic * rs_a ) + ( ( kr + 7 ) * cs_a ) ); - a_reg[8] = _mm256_setzero_si256(); - a_reg[9] = _mm256_setzero_si256(); - a_reg[10] = _mm256_setzero_si256(); - a_reg[11] = _mm256_setzero_si256(); - a_reg[12] = _mm256_setzero_si256(); - a_reg[13] = _mm256_setzero_si256(); - a_reg[14] = _mm256_setzero_si256(); - a_reg[15] = _mm256_setzero_si256(); + + UNPACKLO_EPI16 UNPACKHI_EPI16 UNPACKLO_EPI32 @@ -1364,18 +1154,7 @@ void packa_mr16_bf16bf16f32of32_col_major a_reg[1] = _mm256_maskz_loadu_epi16( mask, a + ( ic * rs_a ) + ( ( kr + 1 ) * cs_a ) ); a_reg[2] = _mm256_maskz_loadu_epi16( mask, a + ( ic * rs_a ) + ( ( kr + 2 ) * cs_a ) ); a_reg[3] = _mm256_maskz_loadu_epi16( mask, a + ( ic * rs_a ) + ( ( kr + 3 ) * cs_a ) ); - a_reg[4] = _mm256_setzero_si256(); - a_reg[5] = _mm256_setzero_si256(); - a_reg[6] = _mm256_setzero_si256(); - a_reg[7] = _mm256_setzero_si256(); - a_reg[8] = _mm256_setzero_si256(); - a_reg[9] = _mm256_setzero_si256(); - a_reg[10] = _mm256_setzero_si256(); - a_reg[11] = _mm256_setzero_si256(); - a_reg[12] = _mm256_setzero_si256(); - a_reg[13] = _mm256_setzero_si256(); - a_reg[14] = _mm256_setzero_si256(); - a_reg[15] = _mm256_setzero_si256(); + UNPACKLO_EPI16 UNPACKHI_EPI16 @@ -1405,20 +1184,7 @@ void packa_mr16_bf16bf16f32of32_col_major { a_reg[0] = _mm256_maskz_loadu_epi16( mask, a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ); a_reg[1] = _mm256_maskz_loadu_epi16( mask, a + ( ic * rs_a ) + ( ( kr + 1 ) * cs_a ) ); - a_reg[2] = _mm256_setzero_si256(); - a_reg[3] = _mm256_setzero_si256(); - a_reg[4] = _mm256_setzero_si256(); - a_reg[5] = _mm256_setzero_si256(); - a_reg[6] = _mm256_setzero_si256(); - a_reg[7] = _mm256_setzero_si256(); - a_reg[8] = _mm256_setzero_si256(); - a_reg[9] = _mm256_setzero_si256(); - a_reg[10] = _mm256_setzero_si256(); - a_reg[11] = _mm256_setzero_si256(); - a_reg[12] = _mm256_setzero_si256(); - a_reg[13] = _mm256_setzero_si256(); - a_reg[14] = _mm256_setzero_si256(); - a_reg[15] = _mm256_setzero_si256(); + UNPACKLO_EPI16 UNPACKHI_EPI16 @@ -1446,21 +1212,7 @@ void packa_mr16_bf16bf16f32of32_col_major for( ; ( kr ) < KC; kr += 1) { a_reg[0] = _mm256_maskz_loadu_epi16( mask, a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ); - a_reg[1] = _mm256_setzero_si256(); - a_reg[2] = _mm256_setzero_si256(); - a_reg[3] = _mm256_setzero_si256(); - a_reg[4] = _mm256_setzero_si256(); - a_reg[5] = _mm256_setzero_si256(); - a_reg[6] = _mm256_setzero_si256(); - a_reg[7] = _mm256_setzero_si256(); - a_reg[8] = _mm256_setzero_si256(); - a_reg[9] = _mm256_setzero_si256(); - a_reg[10] = _mm256_setzero_si256(); - a_reg[11] = _mm256_setzero_si256(); - a_reg[12] = _mm256_setzero_si256(); - a_reg[13] = _mm256_setzero_si256(); - a_reg[14] = _mm256_setzero_si256(); - a_reg[15] = _mm256_setzero_si256(); + UNPACKLO_EPI16 UNPACKHI_EPI16 diff --git a/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_packb_bf16_amd512vnni.c b/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_packb_bf16_amd512vnni.c index 54d0fb86b8..71a20ffef3 100644 --- a/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_packb_bf16_amd512vnni.c +++ b/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_packb_bf16_amd512vnni.c @@ -359,6 +359,7 @@ void packb_nr48_bf16bf16f32of32_row_major kr_new += 3; } + // Handle k remainder. if ( k_partial_pieces > 0 ) { @@ -440,6 +441,7 @@ void packb_nr32_bf16bf16f32of32_row_major kr_new += 2; } + // Handle k remainder. if ( k_partial_pieces > 0 ) { @@ -503,6 +505,7 @@ void packb_nr16_bf16bf16f32of32_row_major kr_new += 2; } + // Handle k remainder. if ( k_partial_pieces > 0 ) { @@ -580,6 +583,7 @@ void packb_nrlt16_bf16bf16f32of32_row_major kr_new += 2; } + // Handle k remainder. if ( k_partial_pieces > 0 ) { @@ -816,7 +820,6 @@ void packb_nr_mult_16_bf16bf16f32of32_col_major const dim_t KC ) { - // Used for permuting the mm512i elements for use in dpbf16_ps instruction. __m512i selector1 = _mm512_setr_epi64( 0x0, 0x1, 0x8, 0x9, 0x4, 0x5, 0xC, 0xD ); __m512i selector2 = _mm512_setr_epi64( 0x2, 0x3, 0xA, 0xB, 0x6, 0x7, 0xE, 0xF ); @@ -830,7 +833,6 @@ void packb_nr_mult_16_bf16bf16f32of32_col_major for( dim_t jr = 0; jr < NR; jr += 16 ) { // Rearrange for dpbf16_ps, read 2 rows from B with 64 elements in each row. - LOAD_16_COLS_AVX512 UNPACKHILO32_AVX512 UNPACKHILO64_AVX512 @@ -854,9 +856,9 @@ void packb_nr_mult_16_bf16bf16f32of32_col_major _mm512_storeu_si512( pack_b_buffer + ( jr * 2 ) + ( ( kr + 26 ) * NR ), a_reg[13] ); _mm512_storeu_si512( pack_b_buffer + ( jr * 2 ) + ( ( kr + 28 ) * NR ), a_reg[14] ); _mm512_storeu_si512( pack_b_buffer + ( jr * 2 ) + ( ( kr + 30 ) * NR ), a_reg[15] ); - } } + for ( ; ( kr + 15 ) < KC; kr += 16 ) { for( dim_t jr = 0; jr < NR; jr += 16 ) @@ -900,6 +902,7 @@ void packb_nr_mult_16_bf16bf16f32of32_col_major _mm512_storeu_si512( pack_b_buffer + ( jr * 2 ) + ( ( kr + 6 ) * NR ), a_reg[3] ); } } + for( ; ( kr +3 ) < KC; kr += 4 ) { for( dim_t jr = 0; jr < NR; jr += 16 ) @@ -916,6 +919,7 @@ void packb_nr_mult_16_bf16bf16f32of32_col_major _mm512_storeu_si512( pack_b_buffer + ( jr * 2 ) + ( ( kr + 2 ) * NR ), a_reg[1] ); } } + for( ; ( kr +1 ) < KC; kr += 2 ) { for( dim_t jr = 0; jr < NR; jr += 16 ) @@ -931,6 +935,7 @@ void packb_nr_mult_16_bf16bf16f32of32_col_major _mm512_storeu_si512( pack_b_buffer + ( jr * 2 ) + ( ( kr ) * NR ), a_reg[0] ); } } + for( ; kr < KC; kr += 1 ) { for( dim_t jr = 0; jr < NR; jr += 16 ) @@ -1004,6 +1009,7 @@ void packb_nrlt16_bf16bf16f32of32_col_major _mm512_storeu_si512( pack_b_buffer + ( ( kr + 30 ) * NR ), a_reg[15] ); } + for ( ; ( kr + 15 ) < KC; kr += 16 ) { for( jr = 0; jr < n0_partial_rem; jr += 1 ) @@ -1055,6 +1061,7 @@ void packb_nrlt16_bf16bf16f32of32_col_major _mm512_storeu_si512( pack_b_buffer + ( ( kr + 4 ) * NR ), a_reg[2] ); _mm512_storeu_si512( pack_b_buffer + ( ( kr + 6 ) * NR ), a_reg[3] ); } + for ( ; (kr+3) < KC; kr += 4 ) { for( jr = 0; jr < n0_partial_rem; jr += 1 ) @@ -1076,6 +1083,7 @@ void packb_nrlt16_bf16bf16f32of32_col_major _mm512_storeu_si512( pack_b_buffer + ( ( kr + 0 ) * NR ), a_reg[0] ); _mm512_storeu_si512( pack_b_buffer + ( ( kr + 2 ) * NR ), a_reg[1] ); } + for ( ; ( kr + 1 ) < KC; kr += 2 ) { for( jr = 0; jr < n0_partial_rem; jr += 1 ) @@ -1095,6 +1103,7 @@ void packb_nrlt16_bf16bf16f32of32_col_major // store to pack_b buffer _mm512_storeu_si512( pack_b_buffer + ( kr * NR ), a_reg[0] ); } + for ( ; kr < KC; kr += 1 ) { for( jr = 0; jr < n0_partial_rem; jr += 1 ) diff --git a/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_packb_bf16_s4_amd512vnni.c b/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_packb_bf16_s4_amd512vnni.c new file mode 100644 index 0000000000..e62368e40a --- /dev/null +++ b/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_packb_bf16_s4_amd512vnni.c @@ -0,0 +1,1492 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include +#include "blis.h" + +#ifdef BLIS_ADDON_LPGEMM + +#include "../int4_utils_avx512.h" + +void packb_nr64_bf16s4f32of32_row_major + ( + int8_t* pack_b_buffer, + const int8_t* b, + const dim_t rs_b, + const dim_t NC, + const dim_t KC, + dim_t* rs_p, + dim_t* cs_p, + lpgemm_pre_op* pre_op + ); + +void packb_nr64_bf16s4f32of32_col_major + ( + int8_t* pack_b_buffer, + const int8_t* b, + const dim_t rs_b, + const dim_t NC, + const dim_t KC, + dim_t* rs_p, + dim_t* cs_p, + lpgemm_pre_op* pre_op + ); + +void packb_nr48_bf16s4f32of32_row_major + ( + int8_t* pack_b_buffer, + const int8_t* b, + const dim_t rs_b, + const dim_t KC + ); + +void packb_nr32_bf16s4f32of32_row_major + ( + int8_t* pack_b_buffer, + const int8_t* b, + const dim_t rs_b, + const dim_t KC + ); + +void packb_nr16_bf16s4f32of32_row_major + ( + int8_t* pack_b_buffer, + const int8_t* b, + const dim_t rs_b, + const dim_t KC + ); + +void packb_nrlt16_bf16s4f32of32_row_major + ( + int8_t* pack_b_buffer, + const int8_t* b, + const dim_t rs_b, + const dim_t KC, + const dim_t n0_partial_rem + ); + +void packb_nr64_bf16s4f32of32 + ( + int8_t* pack_b_buffer, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + const dim_t NC, + const dim_t KC, + dim_t* rs_p, + dim_t* cs_p, + lpgemm_pre_op* pre_op + ) +{ + if (cs_b == 1) + { + packb_nr64_bf16s4f32of32_row_major + ( + pack_b_buffer, b, rs_b, NC, + KC, rs_p, cs_p, pre_op + ); + } + else + { + packb_nr64_bf16s4f32of32_col_major + ( + pack_b_buffer, b, cs_b, NC, KC, + rs_p, cs_p, pre_op + ); + } +} + +void packb_nr64_bf16s4f32of32_row_major + ( + int8_t* pack_b_buffer, + const int8_t* b, + const dim_t rs_b, + const dim_t NC, + const dim_t KC, + dim_t* rs_p, + dim_t* cs_p, + lpgemm_pre_op* pre_op + ) +{ + dim_t NR = 64; + + dim_t n_full_pieces = NC / NR; + dim_t n_full_pieces_loop_limit = n_full_pieces * NR; + dim_t n_partial_pieces = NC % NR; + + dim_t k_full_pieces_blks = KC / 2; + dim_t k_full_pieces = k_full_pieces_blks * 2; + dim_t k_partial_pieces = KC % 2; + + // KC when not multiple of 2 will have padding to make it multiple of 2 + // in packed buffer. + dim_t KC_updated = KC; + if ( k_partial_pieces > 0 ) + { + KC_updated += ( 2 - k_partial_pieces ); + } + + bool is_odd_stride = ( ( rs_b % 2 ) == 0 ) ? FALSE : TRUE; + bool signed_upscale = TRUE; + const dim_t incr_adj_factor = 2; // (Byte / 2) for int4 increments. + + // Used for permuting the mm512i elements for use in dpbf16_ps instruction. + __m512i selector1 = _mm512_setr_epi64( 0x0, 0x1, 0x8, 0x9, 0x2, 0x3, 0xA, 0xB ); + __m512i selector1_1 = _mm512_setr_epi64( 0x4, 0x5, 0xC, 0xD, 0x6, 0x7, 0xE, 0xF ); + + // Selectors for int4 -> int8 conversion. + __m512i shift_idx_64; + MULTISHIFT_32BIT_8_INT4_IDX_64ELEM( shift_idx_64 ); + + __m512i sign_comp = _mm512_set1_epi8( 0x08 ); + __mmask32 hmask = _cvtu32_mask32(0xFFFFFFFF); // 32 bytes or 64 int4. + __mmask32 hmask_odd = _cvtu32_mask32(0x80000000); // Last 1 int4. + + CREATE_CVT_INT4_INT8_PERM_IDX_64ELEM_ODD_LD(conv_shift_arr); + __m512i conv_shift = _mm512_loadu_epi64(conv_shift_arr); + + // Selectors for int8 -> int4 conversion. + CREATE_CVT_INT8_INT4_PERM_IDX_64ELEM_2_ZMM_REG(even_idx_arr) + __m512i even_perm_idx = _mm512_loadu_si512( even_idx_arr ); + __m512i all_1s = _mm512_maskz_set1_epi8( _cvtu64_mask64( 0xFFFFFFFFFFFFFFFF ), 0x01 ); + __m512i odd_perm_idx = _mm512_add_epi8( even_perm_idx, all_1s ); + __m512i clear_hi_bits = _mm512_maskz_set1_epi8( _cvtu64_mask64( 0xFFFFFFFFFFFFFFFF ), 0x0F ); + + __m256i h_a0; + __m256i h_b0; + __m256i h_b0_l4bit; + + __m512i a0; + __m512i b0; + __m512i r_lo; + __m512i r_hi; + __m512i s4_out; + + for ( dim_t jc = 0; jc < n_full_pieces_loop_limit; jc += NR ) + { + for ( dim_t kr = 0; kr < k_full_pieces; kr += 2 ) + { + // Int4 array has to be accessed like byte array, but with + // half the elements traversed in the byte array. + h_a0 = _mm256_maskz_loadu_epi8( hmask, + b + ( ( ( rs_b * ( kr + 0 ) ) + jc ) / incr_adj_factor ) ); + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT(h_a0, a0, shift_idx_64, \ + sign_comp, signed_upscale); + + // If the stride, i.e. rs_b is odd, then the stride increment + // (rs_b * ...)/2 will point at the byte of which the high 4 + // bits is our desired starting element. However since data + // access is at byte level, the low 4 bits of this byte will + // be wrongly included, and additionally the last int4 element + // won't be included either. Extra data movement done to + // account for the same. + // Since kr is a multiple of 2, only kr+1 will have the + // aforementioned issue. + if ( is_odd_stride == FALSE ) + { + h_b0 = _mm256_maskz_loadu_epi8( hmask, + b + ( ( ( rs_b * ( kr + 1 ) ) + jc ) / incr_adj_factor ) ); + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT(h_b0, b0, shift_idx_64, \ + sign_comp, signed_upscale); + } + else + { + h_b0 = _mm256_maskz_loadu_epi8( hmask, + b + ( ( ( rs_b * ( kr + 1 ) ) + jc ) / incr_adj_factor ) ); + // Only load the last byte/ 32nd byte. + h_b0_l4bit = _mm256_maskz_loadu_epi8( hmask_odd, + b + ( ( ( rs_b * ( kr + 1 ) ) + jc ) / incr_adj_factor ) + 1 ); + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT_ODD(h_b0, h_b0_l4bit, b0, \ + shift_idx_64, conv_shift, sign_comp, signed_upscale); + } + + // Restructuring at int8 level. + r_lo = _mm512_unpacklo_epi8( a0, b0 ); + r_hi = _mm512_unpackhi_epi8( a0, b0 ); + + a0 = _mm512_permutex2var_epi64( r_lo, selector1, r_hi ); + b0 = _mm512_permutex2var_epi64( r_lo, selector1_1, r_hi ); + + // To be converted to int4 for storing. + CVT_INT8_INT4_64ELEM_2_ZMM_REG(a0, b0, s4_out, \ + even_perm_idx, odd_perm_idx, clear_hi_bits); + + // Int4 array has to be accessed like byte array, but with + // half the elements traversed in the byte array. + _mm512_storeu_si512( pack_b_buffer + + ( ( ( jc * KC_updated ) + ( kr * NR ) ) / incr_adj_factor ), + s4_out ); + } + // Handle k remainder. + if( k_partial_pieces > 0) + { + h_a0 = _mm256_maskz_loadu_epi8( hmask, + b + ( ( ( rs_b * ( k_full_pieces + 0 ) ) + jc ) / + incr_adj_factor ) ); + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT(h_a0, a0, shift_idx_64, \ + sign_comp, signed_upscale); + + b0 = _mm512_setzero_si512(); + + // Restructuring at int8 level. + r_lo = _mm512_unpacklo_epi8( a0, b0 ); + r_hi = _mm512_unpackhi_epi8( a0, b0 ); + + a0 = _mm512_permutex2var_epi64( r_lo, selector1, r_hi ); + b0 = _mm512_permutex2var_epi64( r_lo, selector1_1, r_hi ); + + // To be converted to int4 for storing. + CVT_INT8_INT4_64ELEM_2_ZMM_REG(a0, b0, s4_out, \ + even_perm_idx, odd_perm_idx, clear_hi_bits); + + _mm512_storeu_si512( pack_b_buffer + + ( ( ( jc * KC_updated ) + ( k_full_pieces * NR ) ) / + incr_adj_factor ), s4_out ); + } + } + + if(n_partial_pieces > 0) + { + dim_t n0_partial_rem = n_partial_pieces % 16; + dim_t n0_partial_pack = 0; + + // Split into multiple smaller fringe kernels, so as to maximize + // vectorization after packing. Any n0 < NR(64) can be expressed + // as n0 = 48 + n` / n0 = 32 + n` / n0 = 16 + n`, where n` < 16. + dim_t n0_48 = n_partial_pieces / 48; + dim_t n0_32 = n_partial_pieces / 32; + dim_t n0_16 = n_partial_pieces / 16; + + if ( n0_48 == 1 ) + { + packb_nr48_bf16s4f32of32_row_major + ( + ( pack_b_buffer + + ( ( n_full_pieces_loop_limit * KC_updated ) / + incr_adj_factor ) ), + ( b + ( n_full_pieces_loop_limit / incr_adj_factor ) ), + rs_b, KC + ); + + n0_partial_pack = 48; + } + else if ( n0_32 == 1 ) + { + packb_nr32_bf16s4f32of32_row_major + ( + ( pack_b_buffer + + ( ( n_full_pieces_loop_limit * KC_updated ) / + incr_adj_factor ) ), + ( b + ( n_full_pieces_loop_limit / incr_adj_factor ) ), + rs_b, KC + ); + + n0_partial_pack = 32; + } + else if ( n0_16 == 1 ) + { + packb_nr16_bf16s4f32of32_row_major + ( + ( pack_b_buffer + + ( ( n_full_pieces_loop_limit * KC_updated ) / + incr_adj_factor ) ), + ( b + ( n_full_pieces_loop_limit / incr_adj_factor ) ), + rs_b, KC + ); + + n0_partial_pack = 16; + } + + if ( n0_partial_rem > 0 ) + { + packb_nrlt16_bf16s4f32of32_row_major + ( + ( pack_b_buffer + ( ( ( n_full_pieces_loop_limit * KC_updated ) + + ( n0_partial_pack * KC_updated ) ) / incr_adj_factor ) ), + ( b + ( ( n_full_pieces_loop_limit + n0_partial_pack ) / + incr_adj_factor ) ), + rs_b, KC, n0_partial_rem + ); + } + } + *rs_p = NR * 2; + *cs_p = NR / 2; +} + +void packb_nr48_bf16s4f32of32_row_major + ( + int8_t* pack_b_buffer, + const int8_t* b, + const dim_t rs_b, + const dim_t KC + ) +{ + const dim_t NR = 48; + const dim_t NR_32x2 = 64; + + dim_t k_full_pieces_blks = KC / 2; + dim_t k_full_pieces = k_full_pieces_blks * 2; + dim_t k_partial_pieces = KC % 2; + + bool is_odd_stride = ( ( rs_b % 2 ) == 0 ) ? FALSE : TRUE; + bool signed_upscale = TRUE; + const dim_t incr_adj_factor = 2; // (Byte / 2) for int4 increments. + + // Used for permuting the mm512i elements for use in dpbf16_ps instruction. + __m256i selector1_32 = _mm256_setr_epi64x( 0x0, 0x1, 0x4, 0x5 ); + __m256i selector1_1_32 = _mm256_setr_epi64x( 0x2, 0x3, 0x6, 0x7 ); + + // Selectors for int4 -> int8 conversion. + // First 32 int4 elements selectors. + __m256i shift_idx_32; + MULTISHIFT_32BIT_8_INT4_IDX_32ELEM(shift_idx_32); + + __m256i sign_comp_32 = _mm256_set1_epi8( 0x08 ); + __mmask16 hmask_32 = _cvtu32_mask16( 0x0000FFFF ); //16 bytes or 32 int4. + __mmask16 hmask_odd_32 = _cvtu32_mask16( 0x00008000 ); // Last 1 int4. + + CREATE_CVT_INT4_INT8_PERM_IDX_32ELEM_ODD_LD(conv_shift_arr_32); + __m256i conv_shift_32 = _mm256_maskz_loadu_epi64( _cvtu32_mask8( 0X000000FF ), + conv_shift_arr_32 ); + + // Next 16 int4 elements selectors. + __m128i shift_idx_16; + MULTISHIFT_32BIT_8_INT4_IDX_16ELEM(shift_idx_16); + + __m128i sign_comp_16 = _mm_set1_epi8( 0x08 ); + __mmask16 hmask_16 = _cvtu32_mask16( 0x000000FF ); //8 bytes or 16 int4. + __mmask16 hmask_odd_16 = _cvtu32_mask16( 0x00000080 ); // Last 1 int4. + + CREATE_CVT_INT4_INT8_PERM_IDX_16ELEM_ODD_LD(conv_shift_arr_16); + __m128i conv_shift_16 = _mm_maskz_loadu_epi64( _cvtu32_mask8( 0X000000FF ), + conv_shift_arr_16 ); + + // Selectors for int8 -> int4 conversion. + // First 32 int8 elements selectors. + CREATE_CVT_INT8_INT4_PERM_IDX_32ELEM_2_YMM_REG(even_idx_arr_32); + __m256i even_perm_idx_32 = _mm256_maskz_loadu_epi64( _cvtu32_mask8( 0xFF ), + even_idx_arr_32 ); + __m256i all_1s_32 = _mm256_maskz_set1_epi8( _cvtu32_mask32( 0xFFFFFFFF ), + 0x01 ); + __m256i odd_perm_idx_32 = _mm256_add_epi8( even_perm_idx_32, all_1s_32 ); + __m256i clear_hi_bits_32 = + _mm256_maskz_set1_epi8( _cvtu32_mask32( 0xFFFFFFFF ), 0x0F ); + + // Next 16 int4 elements selectors. + CREATE_CVT_INT8_INT4_PERM_IDX_16ELEM_2_XMM_REG(even_idx_arr_16); + __m128i even_perm_idx_16 = _mm_maskz_loadu_epi64( _cvtu32_mask8( 0xFF ), + even_idx_arr_16 ); + __m128i all_1s_16 = _mm_maskz_set1_epi8( _cvtu32_mask16( 0xFFFF ), + 0x01 ); + __m128i odd_perm_idx_16 = _mm_add_epi8( even_perm_idx_16, all_1s_16 ); + __m128i clear_hi_bits_16 = + _mm_maskz_set1_epi8( _cvtu32_mask16( 0xFFFF ), 0x0F ); + + __mmask16 sel_all_mask_16 = _cvtu32_mask16( 0xFFFF ); + + __m128i h_a0_32; + __m128i h_b0_32; + __m128i h_b0_32_l4bit; + __m128i a0_16; + __m128i b0_16; + __m128i r_lo_16; + __m128i r_hi_16; + __m128i s4_out_16; + __m256i a0_32; + __m256i b0_32; + __m256i r_lo_32; + __m256i r_hi_32; + __m256i s4_out_32; + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 2 ) + { + // First 32 columns. + h_a0_32 = _mm_maskz_loadu_epi8( hmask_32, + b + ( ( rs_b * ( kr + 0 ) ) / incr_adj_factor ) ); + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT(h_a0_32, a0_32, shift_idx_32, \ + sign_comp_32, signed_upscale); + + // Last 16 columns. + h_a0_32 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( ( rs_b * ( kr + 0 ) ) + 32 ) / 2 ) ); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(h_a0_32, a0_16, shift_idx_16, \ + sign_comp_16, signed_upscale); + + if ( is_odd_stride == FALSE ) + { + // First 32 columns. + h_b0_32 = _mm_maskz_loadu_epi8( hmask_32, + b + ( ( rs_b * ( kr + 1 ) ) / incr_adj_factor ) ); + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT(h_b0_32, b0_32, shift_idx_32, \ + sign_comp_32, signed_upscale); + + // Last 16 columns. + h_b0_32 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( ( rs_b * ( kr + 1 ) ) + 32 ) / 2 ) ); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(h_b0_32, b0_16, shift_idx_16, \ + sign_comp_16, signed_upscale); + } + else + { + // First 32 columns. + h_b0_32 = _mm_maskz_loadu_epi8( hmask_32, + b + ( ( rs_b * ( kr + 1 ) ) / incr_adj_factor ) ); + // Only load the last byte/ 16th byte. + h_b0_32_l4bit = _mm_maskz_loadu_epi8( hmask_odd_32, + b + ( ( rs_b * ( kr + 1 ) ) / incr_adj_factor ) + 1 ); + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT_ODD(h_b0_32, h_b0_32_l4bit, \ + b0_32, shift_idx_32, conv_shift_32, sign_comp_32, \ + signed_upscale); + + // Last 16 columns. + h_b0_32 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( ( rs_b * ( kr + 1 ) ) + 32 ) / 2 ) ); + // Only load the last byte/ 8th byte. + h_b0_32_l4bit = _mm_maskz_loadu_epi8( hmask_odd_16, + b + ( ( ( rs_b * ( kr + 1 ) ) + 32 ) / 2 ) + 1 ); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT_ODD(h_b0_32, h_b0_32_l4bit, \ + b0_16, shift_idx_16, conv_shift_16, sign_comp_16, \ + signed_upscale); + } + + // Restructuring at int8 level. + // First 32 columns. + r_lo_32 = _mm256_unpacklo_epi8( a0_32, b0_32 ); + r_hi_32 = _mm256_unpackhi_epi8( a0_32, b0_32 ); + + a0_32 = _mm256_permutex2var_epi64( r_lo_32, selector1_32, r_hi_32 ); + b0_32 = _mm256_permutex2var_epi64( r_lo_32, selector1_1_32, r_hi_32 ); + + CVT_INT8_INT4_32ELEM_2_YMM_REG(a0_32, b0_32, s4_out_32, \ + even_perm_idx_32, odd_perm_idx_32, clear_hi_bits_32); + + _mm256_storeu_epi64( pack_b_buffer + + ( ( kr * NR ) / incr_adj_factor ), s4_out_32 ); + + // Last 16 columns. + r_lo_16 = _mm_maskz_unpacklo_epi8( sel_all_mask_16, a0_16, b0_16 ); + r_hi_16 = _mm_maskz_unpackhi_epi8( sel_all_mask_16, a0_16, b0_16 ); + + CVT_INT8_INT4_16ELEM_2_XMM_REG(r_lo_16, r_hi_16, s4_out_16, \ + even_perm_idx_16, odd_perm_idx_16, clear_hi_bits_16); + + _mm_storeu_epi64( pack_b_buffer + + ( ( ( kr * NR ) + NR_32x2 ) / incr_adj_factor ), s4_out_16 ); + } + // Handle k remainder. + if( k_partial_pieces > 0) + { + // First 32 columns. + h_a0_32 = _mm_maskz_loadu_epi8( hmask_32, + b + ( ( rs_b * ( k_full_pieces + 0 ) ) / incr_adj_factor ) ); + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT(h_a0_32, a0_32, shift_idx_32, \ + sign_comp_32, signed_upscale); + b0_32 = _mm256_setzero_si256(); + + r_lo_32 = _mm256_unpacklo_epi8( a0_32, b0_32 ); + r_hi_32 = _mm256_unpackhi_epi8( a0_32, b0_32 ); + + a0_32 = _mm256_permutex2var_epi64( r_lo_32, selector1_32, r_hi_32 ); + b0_32 = _mm256_permutex2var_epi64( r_lo_32, selector1_1_32, r_hi_32 ); + + CVT_INT8_INT4_32ELEM_2_YMM_REG(a0_32, b0_32, s4_out_32, \ + even_perm_idx_32, odd_perm_idx_32, clear_hi_bits_32); + + _mm256_storeu_epi64( pack_b_buffer + + ( ( k_full_pieces * NR ) / incr_adj_factor ), s4_out_32 ); + + // Last 16 columns. + h_a0_32 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( ( rs_b * ( k_full_pieces + 0 ) ) + 32 ) / 2 ) ); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(h_a0_32, a0_16, shift_idx_16, \ + sign_comp_16, signed_upscale); + b0_16 = _mm_setzero_si128(); + + r_lo_16 = _mm_maskz_unpacklo_epi8( sel_all_mask_16, a0_16, b0_16 ); + r_hi_16 = _mm_maskz_unpackhi_epi8( sel_all_mask_16, a0_16, b0_16 ); + + CVT_INT8_INT4_16ELEM_2_XMM_REG(r_lo_16, r_hi_16, s4_out_16, \ + even_perm_idx_16, odd_perm_idx_16, clear_hi_bits_16); + + _mm_storeu_epi64( pack_b_buffer + + ( ( ( k_full_pieces * NR ) + NR_32x2 ) / incr_adj_factor ), s4_out_16 ); + } +} + +void packb_nr32_bf16s4f32of32_row_major + ( + int8_t* pack_b_buffer, + const int8_t* b, + const dim_t rs_b, + const dim_t KC + ) +{ + const dim_t NR = 32; + + dim_t k_full_pieces_blks = KC / 2; + dim_t k_full_pieces = k_full_pieces_blks * 2; + dim_t k_partial_pieces = KC % 2; + + bool is_odd_stride = ( ( rs_b % 2 ) == 0 ) ? FALSE : TRUE; + bool signed_upscale = TRUE; + const dim_t incr_adj_factor = 2; // (Byte / 2) for int4 increments. + + // Used for permuting the mm512i elements for use in dpbf16_ps instruction. + __m256i selector1_32 = _mm256_setr_epi64x( 0x0, 0x1, 0x4, 0x5 ); + __m256i selector1_1_32 = _mm256_setr_epi64x( 0x2, 0x3, 0x6, 0x7 ); + + // Selectors for int4 -> int8 conversion. + __m256i shift_idx_32; + MULTISHIFT_32BIT_8_INT4_IDX_32ELEM(shift_idx_32); + + __m256i sign_comp_32 = _mm256_set1_epi8( 0x08 ); + __mmask16 hmask_32 = _cvtu32_mask16( 0x0000FFFF ); //16 bytes or 32 int4. + __mmask16 hmask_odd_32 = _cvtu32_mask16( 0x00008000 ); // Last 1 int4. + + CREATE_CVT_INT4_INT8_PERM_IDX_32ELEM_ODD_LD(conv_shift_arr_32); + __m256i conv_shift_32 = _mm256_maskz_loadu_epi64( _cvtu32_mask8( 0X000000FF ), + conv_shift_arr_32 ); + + // Selectors for int8 -> int4 conversion. + CREATE_CVT_INT8_INT4_PERM_IDX_32ELEM_2_YMM_REG(even_idx_arr_32); + __m256i even_perm_idx_32 = _mm256_maskz_loadu_epi64( _cvtu32_mask8( 0xFF ), + even_idx_arr_32 ); + __m256i all_1s_32 = _mm256_maskz_set1_epi8( _cvtu32_mask32( 0xFFFFFFFF ), + 0x01 ); + __m256i odd_perm_idx_32 = _mm256_add_epi8( even_perm_idx_32, all_1s_32 ); + __m256i clear_hi_bits_32 = + _mm256_maskz_set1_epi8( _cvtu32_mask32( 0xFFFFFFFF ), 0x0F ); + + __m128i h_a0_32; + __m128i h_b0_32; + __m128i h_b0_32_l4bit; + __m256i a0_32; + __m256i b0_32; + __m256i r_lo_32; + __m256i r_hi_32; + __m256i s4_out_32; + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 2 ) + { + h_a0_32 = _mm_maskz_loadu_epi8( hmask_32, + b + ( ( rs_b * ( kr + 0 ) ) / incr_adj_factor ) ); + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT(h_a0_32, a0_32, shift_idx_32, \ + sign_comp_32, signed_upscale); + + if ( is_odd_stride == FALSE ) + { + h_b0_32 = _mm_maskz_loadu_epi8( hmask_32, + b + ( ( rs_b * ( kr + 1 ) ) / incr_adj_factor ) ); + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT(h_b0_32, b0_32, shift_idx_32, \ + sign_comp_32, signed_upscale); + } + else + { + h_b0_32 = _mm_maskz_loadu_epi8( hmask_32, + b + ( ( rs_b * ( kr + 1 ) ) / incr_adj_factor ) ); + // Only load the last byte/ 16th byte. + h_b0_32_l4bit = _mm_maskz_loadu_epi8( hmask_odd_32, + b + ( ( rs_b * ( kr + 1 ) ) / incr_adj_factor ) + 1 ); + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT_ODD(h_b0_32, h_b0_32_l4bit, \ + b0_32, shift_idx_32, conv_shift_32, sign_comp_32, \ + signed_upscale); + } + + // Restructuring at int8 level. + // First 32 columns. + r_lo_32 = _mm256_unpacklo_epi8( a0_32, b0_32 ); + r_hi_32 = _mm256_unpackhi_epi8( a0_32, b0_32 ); + + a0_32 = _mm256_permutex2var_epi64( r_lo_32, selector1_32, r_hi_32 ); + b0_32 = _mm256_permutex2var_epi64( r_lo_32, selector1_1_32, r_hi_32 ); + + CVT_INT8_INT4_32ELEM_2_YMM_REG(a0_32, b0_32, s4_out_32, \ + even_perm_idx_32, odd_perm_idx_32, clear_hi_bits_32); + + _mm256_storeu_epi64( pack_b_buffer + + ( ( kr * NR ) / incr_adj_factor ), s4_out_32 ); + } + // Handle k remainder. + if( k_partial_pieces > 0) + { + h_a0_32 = _mm_maskz_loadu_epi8( hmask_32, + b + ( ( rs_b * ( k_full_pieces + 0 ) ) / incr_adj_factor ) ); + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT(h_a0_32, a0_32, shift_idx_32, \ + sign_comp_32, signed_upscale); + b0_32 = _mm256_setzero_si256(); + + r_lo_32 = _mm256_unpacklo_epi8( a0_32, b0_32 ); + r_hi_32 = _mm256_unpackhi_epi8( a0_32, b0_32 ); + + a0_32 = _mm256_permutex2var_epi64( r_lo_32, selector1_32, r_hi_32 ); + b0_32 = _mm256_permutex2var_epi64( r_lo_32, selector1_1_32, r_hi_32 ); + + CVT_INT8_INT4_32ELEM_2_YMM_REG(a0_32, b0_32, s4_out_32, \ + even_perm_idx_32, odd_perm_idx_32, clear_hi_bits_32); + + _mm256_storeu_epi64( pack_b_buffer + + ( ( k_full_pieces * NR ) / incr_adj_factor ), s4_out_32 ); + } +} + +void packb_nr16_bf16s4f32of32_row_major + ( + int8_t* pack_b_buffer, + const int8_t* b, + const dim_t rs_b, + const dim_t KC + ) +{ + const dim_t NR = 16; + + dim_t k_full_pieces_blks = KC / 2; + dim_t k_full_pieces = k_full_pieces_blks * 2; + dim_t k_partial_pieces = KC % 2; + + bool is_odd_stride = ( ( rs_b % 2 ) == 0 ) ? FALSE : TRUE; + bool signed_upscale = TRUE; + const dim_t incr_adj_factor = 2; // (Byte / 2) for int4 increments. + + // Selectors for int4 -> int8 conversion. + __m128i shift_idx_16; + MULTISHIFT_32BIT_8_INT4_IDX_16ELEM(shift_idx_16); + + __m128i sign_comp_16 = _mm_set1_epi8( 0x08 ); + __mmask16 hmask_16 = _cvtu32_mask16( 0x000000FF ); //8 bytes or 16 int4. + __mmask16 hmask_odd_16 = _cvtu32_mask16( 0x00000080 ); // Last 1 int4. + + CREATE_CVT_INT4_INT8_PERM_IDX_16ELEM_ODD_LD(conv_shift_arr_16); + __m128i conv_shift_16 = _mm_maskz_loadu_epi64( _cvtu32_mask8( 0X000000FF ), + conv_shift_arr_16 ); + + // Selectors for int8 -> int4 conversion. + CREATE_CVT_INT8_INT4_PERM_IDX_16ELEM_2_XMM_REG(even_idx_arr_16); + __m128i even_perm_idx_16 = _mm_maskz_loadu_epi64( _cvtu32_mask8( 0xFF ), + even_idx_arr_16 ); + __m128i all_1s_16 = _mm_maskz_set1_epi8( _cvtu32_mask16( 0xFFFF ), + 0x01 ); + __m128i odd_perm_idx_16 = _mm_add_epi8( even_perm_idx_16, all_1s_16 ); + __m128i clear_hi_bits_16 = + _mm_maskz_set1_epi8( _cvtu32_mask16( 0xFFFF ), 0x0F ); + + __mmask16 sel_all_mask_16 = _cvtu32_mask16( 0xFFFF ); + + __m128i h_a0_16; + __m128i h_b0_16; + __m128i h_b0_16_l4bit; + __m128i a0_16; + __m128i b0_16; + __m128i r_lo_16; + __m128i r_hi_16; + __m128i s4_out_16; + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 2 ) + { + h_a0_16 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( rs_b * ( kr + 0 ) ) / 2 ) ); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(h_a0_16, a0_16, shift_idx_16, \ + sign_comp_16, signed_upscale); + + if ( is_odd_stride == FALSE ) + { + h_b0_16 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( rs_b * ( kr + 1 ) ) / 2 ) ); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(h_b0_16, b0_16, shift_idx_16, \ + sign_comp_16, signed_upscale); + } + else + { + h_b0_16 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( rs_b * ( kr + 1 ) ) / 2 ) ); + // Only load the last byte/ 8th byte. + h_b0_16_l4bit = _mm_maskz_loadu_epi8( hmask_odd_16, + b + ( ( rs_b * ( kr + 1 ) ) / 2 ) + 1 ); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT_ODD(h_b0_16, h_b0_16_l4bit, \ + b0_16, shift_idx_16, conv_shift_16, sign_comp_16, \ + signed_upscale); + } + + r_lo_16 = _mm_maskz_unpacklo_epi8( sel_all_mask_16, a0_16, b0_16 ); + r_hi_16 = _mm_maskz_unpackhi_epi8( sel_all_mask_16, a0_16, b0_16 ); + + CVT_INT8_INT4_16ELEM_2_XMM_REG(r_lo_16, r_hi_16, s4_out_16, \ + even_perm_idx_16, odd_perm_idx_16, clear_hi_bits_16); + + _mm_storeu_epi64( pack_b_buffer + + ( ( kr * NR ) / incr_adj_factor ), s4_out_16 ); + } + // Handle k remainder. + if( k_partial_pieces > 0) + { + h_a0_16 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( rs_b * ( k_full_pieces + 0 ) ) / 2 ) ); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(h_a0_16, a0_16, shift_idx_16, \ + sign_comp_16, signed_upscale); + b0_16 = _mm_setzero_si128(); + + r_lo_16 = _mm_maskz_unpacklo_epi8( sel_all_mask_16, a0_16, b0_16 ); + r_hi_16 = _mm_maskz_unpackhi_epi8( sel_all_mask_16, a0_16, b0_16 ); + + CVT_INT8_INT4_16ELEM_2_XMM_REG(r_lo_16, r_hi_16, s4_out_16, \ + even_perm_idx_16, odd_perm_idx_16, clear_hi_bits_16); + + _mm_storeu_epi64( pack_b_buffer + + ( ( k_full_pieces * NR ) / incr_adj_factor ), s4_out_16 ); + } +} + +void packb_nrlt16_bf16s4f32of32_row_major + ( + int8_t* pack_b_buffer, + const int8_t* b, + const dim_t rs_b, + const dim_t KC, + const dim_t n0_partial_rem + ) +{ + const dim_t NR = 16; + + dim_t k_full_pieces_blks = KC / 2; + dim_t k_full_pieces = k_full_pieces_blks * 2; + dim_t k_partial_pieces = KC % 2; + + bool is_odd_stride = ( ( rs_b % 2 ) == 0 ) ? FALSE : TRUE; + bool signed_upscale = TRUE; + const dim_t incr_adj_factor = 2; // (Byte / 2) for int4 increments. + + // Selectors for int4 -> int8 conversion. + __m128i shift_idx_16; + MULTISHIFT_32BIT_8_INT4_IDX_16ELEM(shift_idx_16); + + __m128i sign_comp_16 = _mm_set1_epi8( 0x08 ); + // 16 int4 elems in 8 bytes, so adjusting the mask for nr < 16 by + // a factor of 2. In case of odd remainder, the last int4 element + // within the last byte (hi 4 bits) will be ingnored similar to + // padding bits. + __mmask16 hmask_16; + if ( is_odd_stride == FALSE ) + { + hmask_16 = _cvtu32_mask16( 0x000000FF >> + ( ( 16 - n0_partial_rem ) / 2 ) ); + } + else + { + if ( ( n0_partial_rem % 2 ) == 0 ) + { + // An interesting property here is that n0_partial_rem is + // guaranteed to be < 16. In that case the largest even n0 + // rem would be 14, and the max number of bytes that will be + // loaded including the extra 4 bit at the beginning will + // only be 7 bytes out of 8. So in any case loading 1 more + // byte will bring the last int4 in the register, while not + // crossing the register boundaries. + hmask_16 = _cvtu32_mask16( 0x000000FF >> + ( ( ( 16 - n0_partial_rem ) / 2 ) - 1 ) ); + } + else + { + // If the n0 rem is odd, and if the starting position is an odd + // index, then the last odd element will also be loaded as part + // of loading the last byte (high 4 bits of last byte). + hmask_16 = _cvtu32_mask16( 0x000000FF >> + ( ( 16 - n0_partial_rem ) / 2 ) ); + } + } + + CREATE_CVT_INT4_INT8_PERM_IDX_16ELEM_ODD_LD(conv_shift_arr_16); + __m128i conv_shift_16 = _mm_maskz_loadu_epi64( _cvtu32_mask8( 0X000000FF ), + conv_shift_arr_16 ); + + // Selectors for int8 -> int4 conversion. + CREATE_CVT_INT8_INT4_PERM_IDX_16ELEM_2_XMM_REG(even_idx_arr_16); + __m128i even_perm_idx_16 = _mm_maskz_loadu_epi64( _cvtu32_mask8( 0xFF ), + even_idx_arr_16 ); + __m128i all_1s_16 = _mm_maskz_set1_epi8( _cvtu32_mask16( 0xFFFF ), + 0x01 ); + __m128i odd_perm_idx_16 = _mm_add_epi8( even_perm_idx_16, all_1s_16 ); + __m128i clear_hi_bits_16 = + _mm_maskz_set1_epi8( _cvtu32_mask16( 0xFFFF ), 0x0F ); + + __mmask16 sel_all_mask_16 = _cvtu32_mask16( 0xFFFF ); + + __m128i h_a0_16; + __m128i h_b0_16; + __m128i a0_16; + __m128i b0_16; + __m128i r_lo_16; + __m128i r_hi_16; + __m128i s4_out_16; + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 2 ) + { + h_a0_16 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( rs_b * ( kr + 0 ) ) / 2 ) ); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(h_a0_16, a0_16, shift_idx_16, \ + sign_comp_16, signed_upscale); + + if ( is_odd_stride == FALSE ) + { + h_b0_16 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( rs_b * ( kr + 1 ) ) / 2 ) ); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(h_b0_16, b0_16, shift_idx_16, \ + sign_comp_16, signed_upscale); + } + else + { + h_b0_16 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( rs_b * ( kr + 1 ) ) / 2 ) ); + // The last int4 elem is already loaded in the previous + // register. Details given in comments about hmask_16. + __m128i h_b0_16_l4bit = _mm_setzero_si128(); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT_ODD(h_b0_16, h_b0_16_l4bit, \ + b0_16, shift_idx_16, conv_shift_16, sign_comp_16, \ + signed_upscale); + } + + r_lo_16 = _mm_maskz_unpacklo_epi8( sel_all_mask_16, a0_16, b0_16 ); + r_hi_16 = _mm_maskz_unpackhi_epi8( sel_all_mask_16, a0_16, b0_16 ); + + CVT_INT8_INT4_16ELEM_2_XMM_REG(r_lo_16, r_hi_16, s4_out_16, \ + even_perm_idx_16, odd_perm_idx_16, clear_hi_bits_16); + + _mm_storeu_epi64( pack_b_buffer + + ( ( kr * NR ) / incr_adj_factor ), s4_out_16 ); + } + // Handle k remainder. + if( k_partial_pieces > 0) + { + h_a0_16 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( rs_b * ( k_full_pieces + 0 ) ) / 2 ) ); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(h_a0_16, a0_16, shift_idx_16, \ + sign_comp_16, signed_upscale); + b0_16 = _mm_setzero_si128(); + + r_lo_16 = _mm_maskz_unpacklo_epi8( sel_all_mask_16, a0_16, b0_16 ); + r_hi_16 = _mm_maskz_unpackhi_epi8( sel_all_mask_16, a0_16, b0_16 ); + + CVT_INT8_INT4_16ELEM_2_XMM_REG(r_lo_16, r_hi_16, s4_out_16, \ + even_perm_idx_16, odd_perm_idx_16, clear_hi_bits_16); + + _mm_storeu_epi64( pack_b_buffer + + ( ( k_full_pieces * NR ) / incr_adj_factor ), s4_out_16 ); + } +} + + +#define LOAD_16_COLS_AVX2 \ + a_reg[0] = _mm256_loadu_si256((__m256i const *)(b + ( ( ldb * ( jr + 0 ) ) + kr) / 2 )); \ + a_reg[1] = _mm256_loadu_si256((__m256i const *) (b + ( ( ldb * ( jr + 1 ) ) + kr) / 2 )); \ + a_reg[2] = _mm256_loadu_si256((__m256i const *) (b + ( ( ldb * ( jr + 2 ) ) + kr) / 2 )); \ + a_reg[3] = _mm256_loadu_si256((__m256i const *) (b + ( ( ldb * ( jr + 3 ) ) + kr) / 2 )); \ + a_reg[4] = _mm256_loadu_si256((__m256i const *) (b + ( ( ldb * ( jr + 4 ) ) + kr) / 2 )); \ + a_reg[5] = _mm256_loadu_si256((__m256i const *) (b + ( ( ldb * ( jr + 5 ) ) + kr) / 2 )); \ + a_reg[6] = _mm256_loadu_si256((__m256i const *) (b + ( ( ldb * ( jr + 6 ) ) + kr) / 2 )); \ + a_reg[7] = _mm256_loadu_si256((__m256i const *) (b + ( ( ldb * ( jr + 7 ) ) + kr) / 2 )); \ + a_reg[8] = _mm256_loadu_si256((__m256i const *) (b + ( ( ldb * ( jr + 8 ) ) + kr) / 2 )); \ + a_reg[9] = _mm256_loadu_si256((__m256i const *) (b + ( ( ldb * ( jr + 9 ) ) + kr) / 2 )); \ + a_reg[10] = _mm256_loadu_si256((__m256i const *) (b + ( ( ldb * ( jr + 10 ) ) + kr) / 2 )); \ + a_reg[11] = _mm256_loadu_si256((__m256i const *) (b + ( ( ldb * ( jr + 11 ) ) + kr) / 2 )); \ + a_reg[12] = _mm256_loadu_si256((__m256i const *) (b + ( ( ldb * ( jr + 12 ) ) + kr) / 2 )); \ + a_reg[13] = _mm256_loadu_si256((__m256i const *) (b + ( ( ldb * ( jr + 13 ) ) + kr) / 2 )); \ + a_reg[14] = _mm256_loadu_si256((__m256i const *) (b + ( ( ldb * ( jr + 14 ) ) + kr) / 2 )); \ + a_reg[15] = _mm256_loadu_si256((__m256i const *) (b + ( ( ldb * ( jr + 15 ) ) + kr) / 2 )); + +#define UNPACKHILO8_AVX2 \ + b_reg[0] = _mm256_unpacklo_epi8(a_reg[0], a_reg[1]); \ + b_reg[2] = _mm256_unpacklo_epi8(a_reg[2], a_reg[3]); \ + b_reg[4] = _mm256_unpacklo_epi8(a_reg[4], a_reg[5]); \ + b_reg[6] = _mm256_unpacklo_epi8(a_reg[6], a_reg[7]); \ + b_reg[8] = _mm256_unpacklo_epi8(a_reg[8], a_reg[9]); \ + b_reg[10] = _mm256_unpacklo_epi8(a_reg[10], a_reg[11]); \ + b_reg[12] = _mm256_unpacklo_epi8(a_reg[12], a_reg[13]); \ + b_reg[14] = _mm256_unpacklo_epi8(a_reg[14], a_reg[15]); \ +\ + b_reg[1] = _mm256_unpackhi_epi8(a_reg[0], a_reg[1]); \ + b_reg[3] = _mm256_unpackhi_epi8(a_reg[2], a_reg[3]); \ + b_reg[5] = _mm256_unpackhi_epi8(a_reg[4], a_reg[5]); \ + b_reg[7] = _mm256_unpackhi_epi8(a_reg[6], a_reg[7]); \ + b_reg[9] = _mm256_unpackhi_epi8(a_reg[8], a_reg[9]); \ + b_reg[11] = _mm256_unpackhi_epi8(a_reg[10], a_reg[11]); \ + b_reg[13] = _mm256_unpackhi_epi8(a_reg[12], a_reg[13]); \ + b_reg[15] = _mm256_unpackhi_epi8(a_reg[14], a_reg[15]); + +#define UNPACKHILO16_AVX2 \ + a_reg[0] = _mm256_unpacklo_epi16(b_reg[0], b_reg[2]); \ + a_reg[1] = _mm256_unpacklo_epi16(b_reg[4], b_reg[6]); \ + a_reg[2] = _mm256_unpacklo_epi16(b_reg[8], b_reg[10]); \ + a_reg[3] = _mm256_unpacklo_epi16(b_reg[12], b_reg[14]); \ + a_reg[4] = _mm256_unpacklo_epi16(b_reg[1], b_reg[3]); \ + a_reg[5] = _mm256_unpacklo_epi16(b_reg[5], b_reg[7]); \ + a_reg[6] = _mm256_unpacklo_epi16(b_reg[9], b_reg[11]); \ + a_reg[7] = _mm256_unpacklo_epi16(b_reg[13], b_reg[15]); \ +\ + a_reg[8] = _mm256_unpackhi_epi16(b_reg[0], b_reg[2]); \ + a_reg[9] = _mm256_unpackhi_epi16(b_reg[4], b_reg[6]); \ + a_reg[10] = _mm256_unpackhi_epi16(b_reg[8], b_reg[10]); \ + a_reg[11] = _mm256_unpackhi_epi16(b_reg[12], b_reg[14]); \ + a_reg[12] = _mm256_unpackhi_epi16(b_reg[1], b_reg[3]); \ + a_reg[13] = _mm256_unpackhi_epi16(b_reg[5], b_reg[7]); \ + a_reg[14] = _mm256_unpackhi_epi16(b_reg[9], b_reg[11]); \ + a_reg[15] = _mm256_unpackhi_epi16(b_reg[13], b_reg[15]); + +#define UNPACKHILO32_AVX2 \ + b_reg[0] = _mm256_unpacklo_epi32(a_reg[0], a_reg[1]); \ + b_reg[1] = _mm256_unpacklo_epi32(a_reg[2], a_reg[3]); \ + b_reg[2] = _mm256_unpacklo_epi32(a_reg[4], a_reg[5]); \ + b_reg[3] = _mm256_unpacklo_epi32(a_reg[6], a_reg[7]); \ + b_reg[4] = _mm256_unpacklo_epi32(a_reg[8], a_reg[9]); \ + b_reg[5] = _mm256_unpacklo_epi32(a_reg[10], a_reg[11]); \ + b_reg[6] = _mm256_unpacklo_epi32(a_reg[12], a_reg[13]); \ + b_reg[7] = _mm256_unpacklo_epi32(a_reg[14], a_reg[15]); \ +\ + b_reg[8] = _mm256_unpackhi_epi32(a_reg[0], a_reg[1]); \ + b_reg[9] = _mm256_unpackhi_epi32(a_reg[2], a_reg[3]); \ + b_reg[10] = _mm256_unpackhi_epi32(a_reg[4], a_reg[5]); \ + b_reg[11] = _mm256_unpackhi_epi32(a_reg[6], a_reg[7]); \ + b_reg[12] = _mm256_unpackhi_epi32(a_reg[8], a_reg[9]); \ + b_reg[13] = _mm256_unpackhi_epi32(a_reg[10], a_reg[11]); \ + b_reg[14] = _mm256_unpackhi_epi32(a_reg[12], a_reg[13]); \ + b_reg[15] = _mm256_unpackhi_epi32(a_reg[14], a_reg[15]); + +#define UNPACKHILO64_AVX2 \ + a_reg[0] = _mm256_unpacklo_epi64(b_reg[0], b_reg[1]); \ + a_reg[1] = _mm256_unpacklo_epi64(b_reg[2], b_reg[3]); \ + a_reg[2] = _mm256_unpacklo_epi64(b_reg[4], b_reg[5]); \ + a_reg[3] = _mm256_unpacklo_epi64(b_reg[6], b_reg[7]); \ + a_reg[4] = _mm256_unpacklo_epi64(b_reg[8], b_reg[9]); \ + a_reg[5] = _mm256_unpacklo_epi64(b_reg[10], b_reg[11]); \ + a_reg[6] = _mm256_unpacklo_epi64(b_reg[12], b_reg[13]); \ + a_reg[7] = _mm256_unpacklo_epi64(b_reg[14], b_reg[15]); \ +\ + a_reg[8] = _mm256_unpackhi_epi64(b_reg[0], b_reg[1]); \ + a_reg[9] = _mm256_unpackhi_epi64(b_reg[2], b_reg[3]); \ + a_reg[10] = _mm256_unpackhi_epi64(b_reg[4], b_reg[5]); \ + a_reg[11] = _mm256_unpackhi_epi64(b_reg[6], b_reg[7]); \ + a_reg[12] = _mm256_unpackhi_epi64(b_reg[8], b_reg[9]); \ + a_reg[13] = _mm256_unpackhi_epi64(b_reg[10], b_reg[11]); \ + a_reg[14] = _mm256_unpackhi_epi64(b_reg[12], b_reg[13]); \ + a_reg[15] = _mm256_unpackhi_epi64(b_reg[14], b_reg[15]); + +#define MASK_LOAD_16_COLS_AVX2(mask) \ + a_reg[0] = _mm256_maskz_loadu_epi8( mask, (b + ( ( ldb * ( jr + 0 ) ) + kr) / 2 )); \ + a_reg[1] = _mm256_maskz_loadu_epi8( mask, (b + ( ( ldb * ( jr + 1 ) ) + kr) / 2 )); \ + a_reg[2] = _mm256_maskz_loadu_epi8( mask, (b + ( ( ldb * ( jr + 2 ) ) + kr) / 2 )); \ + a_reg[3] = _mm256_maskz_loadu_epi8( mask, (b + ( ( ldb * ( jr + 3 ) ) + kr) / 2 )); \ + a_reg[4] = _mm256_maskz_loadu_epi8( mask, (b + ( ( ldb * ( jr + 4 ) ) + kr) / 2 )); \ + a_reg[5] = _mm256_maskz_loadu_epi8( mask, (b + ( ( ldb * ( jr + 5 ) ) + kr) / 2 )); \ + a_reg[6] = _mm256_maskz_loadu_epi8( mask, (b + ( ( ldb * ( jr + 6 ) ) + kr) / 2 )); \ + a_reg[7] = _mm256_maskz_loadu_epi8( mask, (b + ( ( ldb * ( jr + 7 ) ) + kr) / 2 )); \ + a_reg[8] = _mm256_maskz_loadu_epi8( mask, (b + ( ( ldb * ( jr + 8 ) ) + kr) / 2 )); \ + a_reg[9] = _mm256_maskz_loadu_epi8( mask, (b + ( ( ldb * ( jr + 9 ) ) + kr) / 2 )); \ + a_reg[10] = _mm256_maskz_loadu_epi8( mask, (b + ( ( ldb * ( jr + 10 ) ) + kr) / 2 )); \ + a_reg[11] = _mm256_maskz_loadu_epi8( mask, (b + ( ( ldb * ( jr + 11 ) ) + kr) / 2 )); \ + a_reg[12] = _mm256_maskz_loadu_epi8( mask, (b + ( ( ldb * ( jr + 12 ) ) + kr) / 2 )); \ + a_reg[13] = _mm256_maskz_loadu_epi8( mask, (b + ( ( ldb * ( jr + 13 ) ) + kr) / 2 )); \ + a_reg[14] = _mm256_maskz_loadu_epi8( mask, (b + ( ( ldb * ( jr + 14 ) ) + kr) / 2 )); \ + a_reg[15] = _mm256_maskz_loadu_epi8( mask, (b + ( ( ldb * ( jr + 15 ) ) + kr) / 2 )); + +void packb_nr_mult_16_bf16s4f32of32_col_major + ( + int8_t* pack_b_buffer, + const int8_t* b, + const dim_t NR, + const dim_t ldb, + const dim_t KC + ) +{ + // Used for storing the mm256i elements for use in dpbf16_ps instruction. + __mmask8 msk0 = _cvtu32_mask8(0x0F); + __mmask8 msk1 = _cvtu32_mask8(0xF0); + + __m256i a_reg[16]; + __m256i b_reg[16]; + + dim_t kr = 0; + for (kr= 0; ( kr + 63 ) < KC; kr += 64 ) + { + for( dim_t jr = 0; jr < NR; jr += 16 ) + { + // Rearrange for dpbf16_ps, read 16 cols from B with 64 elements in each row. + LOAD_16_COLS_AVX2 + UNPACKHILO8_AVX2 + UNPACKHILO16_AVX2 + UNPACKHILO32_AVX2 + UNPACKHILO64_AVX2 + + // store to pack_b buffer + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 0 ) * NR))/2 )), msk0, a_reg[0] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 2 ) * NR))/2 )), msk0, a_reg[8] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 4 ) * NR))/2 )), msk0, a_reg[4] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 6 ) * NR))/2 )), msk0, a_reg[12] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 8 ) * NR))/2 )), msk0, a_reg[2] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 10 ) * NR))/2 )), msk0, a_reg[10] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 12 ) * NR))/2 )), msk0, a_reg[6] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 14 ) * NR))/2 )), msk0, a_reg[14] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 16 ) * NR))/2 )), msk0, a_reg[1] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 18 ) * NR))/2 )), msk0, a_reg[9] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 20 ) * NR))/2 )), msk0, a_reg[5] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 22 ) * NR))/2 )), msk0, a_reg[13] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 24 ) * NR))/2 )), msk0, a_reg[3] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 26 ) * NR))/2 )), msk0, a_reg[11] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 28 ) * NR))/2 )), msk0, a_reg[7] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 30 ) * NR))/2 )), msk0, a_reg[15] ); + + // The 16 value decrement is to correct the masked store starting postion with respect to the msk1. + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 32 ) * NR))/2 - 16)), msk1, a_reg[0] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 34 ) * NR))/2 - 16)), msk1, a_reg[8] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 36 ) * NR))/2 - 16)), msk1, a_reg[4] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 38 ) * NR))/2 - 16)), msk1, a_reg[12] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 40 ) * NR))/2 - 16)), msk1, a_reg[2] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 42 ) * NR))/2 - 16)), msk1, a_reg[10] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 44 ) * NR))/2 - 16)), msk1, a_reg[6] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 46 ) * NR))/2 - 16)), msk1, a_reg[14] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 48 ) * NR))/2 - 16)), msk1, a_reg[1] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 50 ) * NR))/2 - 16)), msk1, a_reg[9] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 52 ) * NR))/2 - 16)), msk1, a_reg[5] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 54 ) * NR))/2 - 16)), msk1, a_reg[13] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 56 ) * NR))/2 - 16)), msk1, a_reg[3] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 58 ) * NR))/2 - 16)), msk1, a_reg[11] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 60 ) * NR))/2 - 16)), msk1, a_reg[7] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 62 ) * NR))/2 - 16)), msk1, a_reg[15] ); + } + } + + for ( ; ( kr + 31 ) < KC; kr += 32 ) + { + for( dim_t jr = 0; jr < NR; jr += 16 ) + { + // Rearrange for dpbf16_ps, read 16 cols from B with 64 elements in each row. + MASK_LOAD_16_COLS_AVX2(0x0000FFFF) + UNPACKHILO8_AVX2 + UNPACKHILO16_AVX2 + UNPACKHILO32_AVX2 + UNPACKHILO64_AVX2 + + //store to pack_b buffer + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 0 ) * NR))/2 )), msk0, a_reg[0] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 2 ) * NR))/2 )), msk0, a_reg[8] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 4 ) * NR))/2 )), msk0, a_reg[4] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 6 ) * NR))/2 )), msk0, a_reg[12] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 8 ) * NR))/2 )), msk0, a_reg[2] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 10 ) * NR))/2 )), msk0, a_reg[10] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 12 ) * NR))/2 )), msk0, a_reg[6] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 14 ) * NR))/2 )), msk0, a_reg[14] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 16 ) * NR))/2 )), msk0, a_reg[1] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 18 ) * NR))/2 )), msk0, a_reg[9] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 20 ) * NR))/2 )), msk0, a_reg[5] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 22 ) * NR))/2 )), msk0, a_reg[13] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 24 ) * NR))/2 )), msk0, a_reg[3] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 26 ) * NR))/2 )), msk0, a_reg[11] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 28 ) * NR))/2 )), msk0, a_reg[7] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 30 ) * NR))/2 )), msk0, a_reg[15] ); + } + } + + for ( ; ( kr + 15 ) < KC; kr += 16 ) + { + for( dim_t jr = 0; jr < NR; jr += 16 ) + { + // Rearrange for dpbf16_ps, read 16 cols from B with 64 elements in each row. + MASK_LOAD_16_COLS_AVX2(0x000000FF) + UNPACKHILO8_AVX2 + UNPACKHILO16_AVX2 + UNPACKHILO32_AVX2 + UNPACKHILO64_AVX2 + + // store to pack_b buffer + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 0 ) * NR))/2 )), msk0, a_reg[0] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 2 ) * NR))/2 )), msk0, a_reg[8] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 4 ) * NR))/2 )), msk0, a_reg[4] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 6 ) * NR))/2 )), msk0, a_reg[12] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 8 ) * NR))/2 )), msk0, a_reg[2] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 10 ) * NR))/2 )), msk0, a_reg[10] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 12 ) * NR))/2 )), msk0, a_reg[6] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 14 ) * NR))/2 )), msk0, a_reg[14] ); + } + } + + for ( ; ( kr + 7 ) < KC; kr += 8 ) + { + for( dim_t jr = 0; jr < NR; jr += 16 ) + { + // Rearrange for dpbf16_ps, read 16 cols from B with 64 elements in each row. + MASK_LOAD_16_COLS_AVX2(0x0F) + UNPACKHILO8_AVX2 + UNPACKHILO16_AVX2 + UNPACKHILO32_AVX2 + UNPACKHILO64_AVX2 + + // store to pack_b buffer + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 0 ) * NR))/2 )), msk0, a_reg[0] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 2 ) * NR))/2 )), msk0, a_reg[8] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 4 ) * NR))/2 )), msk0, a_reg[4] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 6 ) * NR))/2 )), msk0, a_reg[12] ); + } + } + + for ( ; ( kr + 3 ) < KC; kr += 4 ) + { + for( dim_t jr = 0; jr < NR; jr += 16 ) + { + // Rearrange for dpbf16_ps, read 16 cols from B with 64 elements in each row. + MASK_LOAD_16_COLS_AVX2(0x03) + UNPACKHILO8_AVX2 + UNPACKHILO16_AVX2 + UNPACKHILO32_AVX2 + UNPACKHILO64_AVX2 + + // store to pack_b buffer + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 0 ) * NR))/2 )), msk0, a_reg[0] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 2 ) * NR))/2 )), msk0, a_reg[8] ); + } + } + + for ( ; ( kr + 1 ) < KC; kr += 2 ) + { + for( dim_t jr = 0; jr < NR; jr += 16 ) + { + // Rearrange for dpbf16_ps, read 16 cols from B with 64 elements in each row. + MASK_LOAD_16_COLS_AVX2(0x01) + UNPACKHILO8_AVX2 + UNPACKHILO16_AVX2 + UNPACKHILO32_AVX2 + UNPACKHILO64_AVX2 + + // store to pack_b buffer + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((jr * 2) + (( kr + 0 ) * NR))/2 )), msk0, a_reg[0] ); + } + } +} + +void packb_nrlt16_bf16s4f32of32_col_major + ( + int8_t* pack_b_buffer, + const int8_t* b, + const dim_t ldb, + const dim_t KC, + const dim_t n0_partial_rem + ) +{ + dim_t NR = 16; + + // Used for storing the mm256i elements for use in dpbf16_ps instruction. + __mmask8 msk0 = _cvtu32_mask8(0x0F); + __mmask8 msk1 = _cvtu32_mask8(0xF0); + + __m256i a_reg[16]; + __m256i b_reg[16]; + + dim_t kr = 0, jr = 0; + for ( kr = 0; ( kr + 63 ) < KC; kr += 64 ) + { + for( jr = 0; jr < n0_partial_rem; jr += 1 ) + { + // Rearrange for dpbf16_ps, read n0_partial_rem cols from B with 64 elements in each row + a_reg[jr] = _mm256_loadu_si256((__m256i const *)(b + ( ( ldb * jr ) + kr) / 2 )); + } + for(; jr < NR; jr++) + { + a_reg[jr] = _mm256_setzero_si256(); + } + + UNPACKHILO8_AVX2 + UNPACKHILO16_AVX2 + UNPACKHILO32_AVX2 + UNPACKHILO64_AVX2 + + // store to pack_b buffer + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 0 ) * NR))/2 )), msk0, a_reg[0] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 2 ) * NR))/2 )), msk0, a_reg[8] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 4 ) * NR))/2 )), msk0, a_reg[4] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 6 ) * NR))/2 )), msk0, a_reg[12] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 8 ) * NR))/2 )), msk0, a_reg[2] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 10 ) * NR))/2 )), msk0, a_reg[10] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 12 ) * NR))/2 )), msk0, a_reg[6] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 14 ) * NR))/2 )), msk0, a_reg[14] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 16 ) * NR))/2 )), msk0, a_reg[1] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 18 ) * NR))/2 )), msk0, a_reg[9] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 20 ) * NR))/2 )), msk0, a_reg[5] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 22 ) * NR))/2 )), msk0, a_reg[13] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 24 ) * NR))/2 )), msk0, a_reg[3] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 26 ) * NR))/2 )), msk0, a_reg[11] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 28 ) * NR))/2 )), msk0, a_reg[7] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 30 ) * NR))/2 )), msk0, a_reg[15] ); + + // The 16 value decrement is to correct the masked store starting postion with respect to the msk1. + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 32 ) * NR))/2 - 16)), msk1, a_reg[0] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 34 ) * NR))/2 - 16)), msk1, a_reg[8] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 36 ) * NR))/2 - 16)), msk1, a_reg[4] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 38 ) * NR))/2 - 16)), msk1, a_reg[12] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 40 ) * NR))/2 - 16)), msk1, a_reg[2] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 42 ) * NR))/2 - 16)), msk1, a_reg[10] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 44 ) * NR))/2 - 16)), msk1, a_reg[6] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 46 ) * NR))/2 - 16)), msk1, a_reg[14] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 48 ) * NR))/2 - 16)), msk1, a_reg[1] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 50 ) * NR))/2 - 16)), msk1, a_reg[9] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 52 ) * NR))/2 - 16)), msk1, a_reg[5] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 54 ) * NR))/2 - 16)), msk1, a_reg[13] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 56 ) * NR))/2 - 16)), msk1, a_reg[3] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 58 ) * NR))/2 - 16)), msk1, a_reg[11] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 60 ) * NR))/2 - 16)), msk1, a_reg[7] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 62 ) * NR))/2 - 16)), msk1, a_reg[15] ); + } + + for ( ; ( kr + 31 ) < KC; kr += 32 ) + { + for( jr = 0; jr < n0_partial_rem; jr += 1 ) + { + // Rearrange for dpbf16_ps, read n0_partial_rem cols from B with 64 elements in each row + a_reg[jr] = _mm256_maskz_loadu_epi8( 0x0000FFFF, (b + ( ( ldb * ( jr + 0 ) ) + kr) / 2 )); + } + for( ; jr < NR; jr++ ) + { + a_reg[jr] = _mm256_setzero_si256(); + } + + UNPACKHILO8_AVX2 + UNPACKHILO16_AVX2 + UNPACKHILO32_AVX2 + UNPACKHILO64_AVX2 + + // store to pack_b buffer + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 0 ) * NR))/2 )), msk0, a_reg[0] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 2 ) * NR))/2 )), msk0, a_reg[8] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 4 ) * NR))/2 )), msk0, a_reg[4] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 6 ) * NR))/2 )), msk0, a_reg[12] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 8 ) * NR))/2 )), msk0, a_reg[2] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 10 ) * NR))/2 )), msk0, a_reg[10] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 12 ) * NR))/2 )), msk0, a_reg[6] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 14 ) * NR))/2 )), msk0, a_reg[14] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 16 ) * NR))/2 )), msk0, a_reg[1] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 18 ) * NR))/2 )), msk0, a_reg[9] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 20 ) * NR))/2 )), msk0, a_reg[5] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 22 ) * NR))/2 )), msk0, a_reg[13] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 24 ) * NR))/2 )), msk0, a_reg[3] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 26 ) * NR))/2 )), msk0, a_reg[11] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 28 ) * NR))/2 )), msk0, a_reg[7] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 30 ) * NR))/2 )), msk0, a_reg[15] ); + } + + for ( ; ( kr + 15 ) < KC; kr += 16 ) + { + for( jr = 0; jr < n0_partial_rem; jr += 1 ) + { + // Rearrange for dpbf16_ps, read n0_partial_rem cols from B with 64 elements in each row + a_reg[jr] = _mm256_maskz_loadu_epi8( 0xFF, (b + ( ( ldb * ( jr + 0 ) ) + kr) / 2 )); \ + } + for( ; jr < NR; jr++ ) + { + a_reg[jr] = _mm256_setzero_si256(); + } + + UNPACKHILO8_AVX2 + UNPACKHILO16_AVX2 + UNPACKHILO32_AVX2 + UNPACKHILO64_AVX2 + + // store to pack_b buffer + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 0 ) * NR))/2 )), msk0, a_reg[0] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 2 ) * NR))/2 )), msk0, a_reg[8] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 4 ) * NR))/2 )), msk0, a_reg[4] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 6 ) * NR))/2 )), msk0, a_reg[12] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 8 ) * NR))/2 )), msk0, a_reg[2] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 10 ) * NR))/2 )), msk0, a_reg[10] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 12 ) * NR))/2 )), msk0, a_reg[6] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 14 ) * NR))/2 )), msk0, a_reg[14] ); + } + + for ( ; ( kr + 7 ) < KC; kr += 8 ) + { + for( jr = 0; jr < n0_partial_rem; jr += 1 ) + { + // Rearrange for dpbf16_ps, read n0_partial_rem cols from B with 64 elements in each row + a_reg[jr] = _mm256_maskz_loadu_epi8( 0x0F, (b + ( ( ldb * ( jr + 0 ) ) + kr) / 2 )); \ + } + for( ; jr < NR; jr++ ) + { + a_reg[jr] = _mm256_setzero_si256(); + } + + UNPACKHILO8_AVX2 + UNPACKHILO16_AVX2 + UNPACKHILO32_AVX2 + UNPACKHILO64_AVX2 + + // store to pack_b buffer + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 0 ) * NR))/2 )), msk0, a_reg[0] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 2 ) * NR))/2 )), msk0, a_reg[8] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 4 ) * NR))/2 )), msk0, a_reg[4] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 6 ) * NR))/2 )), msk0, a_reg[12] ); + } + + for ( ; (kr+3) < KC; kr += 4 ) + { + for( jr = 0; jr < n0_partial_rem; jr += 1 ) + { + // Rearrange for dpbf16_ps, read n0_partial_rem cols from B with 64 elements in each row + a_reg[jr] = _mm256_maskz_loadu_epi8( 0x03, (b + ( ( ldb * ( jr + 0 ) ) + kr) / 2 )); \ + } + for( ; jr < NR; jr++ ) + { + a_reg[jr] = _mm256_setzero_si256(); + } + + UNPACKHILO8_AVX2 + UNPACKHILO16_AVX2 + UNPACKHILO32_AVX2 + UNPACKHILO64_AVX2 + + // store to pack_b buffer + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 0 ) * NR))/2 )), msk0, a_reg[0] ); + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 2 ) * NR))/2 )), msk0, a_reg[8] ); + } + + for ( ; ( kr + 1 ) < KC; kr += 2 ) + { + for( jr = 0; jr < n0_partial_rem; jr += 1 ) + { + // Rearrange for dpbf16_ps, read n0_partial_rem cols from B with 64 elements in each row + a_reg[jr] = _mm256_maskz_loadu_epi8( 0x01, (b + ( ( ldb * ( jr + 0 ) ) + kr) / 2 )); + } + for( ; jr < NR; jr++ ) + { + a_reg[jr] = _mm256_setzero_si256(); + } + UNPACKHILO8_AVX2 + UNPACKHILO16_AVX2 + UNPACKHILO32_AVX2 + UNPACKHILO64_AVX2 + + // store to pack_b buffer + _mm256_mask_storeu_epi32( ((pack_b_buffer + ((( kr + 0 ) * NR))/2 )), msk0, a_reg[0] ); + } +} + + +void packb_nr64_bf16s4f32of32_col_major + ( + int8_t* pack_b_buffer, + const int8_t* b, + const dim_t ldb, + const dim_t NC, + const dim_t KC, + dim_t* rs_b, + dim_t* cs_b, + lpgemm_pre_op* pre_op + ) +{ + dim_t NR = 64; + dim_t n_full_pieces = NC / NR; + dim_t n_full_pieces_loop_limit = n_full_pieces * NR; + + + dim_t n_partial_pieces = NC % NR; + dim_t k_partial_pieces = KC % 2; + dim_t KC_updated = KC; + if ( k_partial_pieces > 0 ) + { + KC_updated += ( 2 - k_partial_pieces ); + } + + for ( dim_t jc = 0; jc < n_full_pieces_loop_limit; jc += NR ) + { + packb_nr_mult_16_bf16s4f32of32_col_major + ( + ( pack_b_buffer + ((jc* KC_updated)/2)) , (b + (jc*ldb)/2), 64, ldb, KC + ); + } + + if(n_partial_pieces > 0) + { + + dim_t n0_partial_rem = n_partial_pieces % 16; + dim_t n0_partial_pack = 0; + + // Split into multiple smaller fringe kernels, so as to maximize + // vectorization after packing. Any n0 < NR(64) can be expressed + // as n0 = 48 + n` / n0 = 32 + n` / n0 = 16 + n`, where n` < 16. + dim_t n0_48 = n_partial_pieces / 48; + dim_t n0_32 = n_partial_pieces / 32; + dim_t n0_16 = n_partial_pieces / 16; + + if ( n0_48 == 1 ) + { + packb_nr_mult_16_bf16s4f32of32_col_major + ( + ( pack_b_buffer + ( n_full_pieces_loop_limit * KC_updated )/2 ), + ( b + (n_full_pieces_loop_limit * ldb )/2), 48, ldb, KC + ); + + n0_partial_pack = 48; + } + else if ( n0_32 == 1 ) + { + packb_nr_mult_16_bf16s4f32of32_col_major + ( + ( pack_b_buffer + ( n_full_pieces_loop_limit * KC_updated )/2 ), + ( b + (n_full_pieces_loop_limit * ldb)/2 ), 32, ldb, KC + ); + + n0_partial_pack = 32; + } + else if ( n0_16 == 1 ) + { + packb_nr_mult_16_bf16s4f32of32_col_major + ( + ( pack_b_buffer + ( n_full_pieces_loop_limit * KC_updated )/2 ), + ( b + (n_full_pieces_loop_limit * ldb)/2 ), 16, ldb, KC + ); + + n0_partial_pack = 16; + } + + if ( n0_partial_rem > 0 ) + { + packb_nrlt16_bf16s4f32of32_col_major + ( + ( pack_b_buffer + (( n_full_pieces_loop_limit * KC_updated ) + + ( n0_partial_pack * KC_updated ))/2 ), + ( b + (( n_full_pieces_loop_limit + n0_partial_pack ) * ldb)/2 ), ldb, KC, + n0_partial_rem + ); + } + } + + *rs_b = NR * 2; + *cs_b = NR / 2; +} + + +#endif diff --git a/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_packb_s4_to_bf16_amd512vnni.c b/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_packb_s4_to_bf16_amd512vnni.c new file mode 100644 index 0000000000..c35ba29327 --- /dev/null +++ b/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_packb_s4_to_bf16_amd512vnni.c @@ -0,0 +1,857 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include +#include "blis.h" + +#ifdef BLIS_ADDON_LPGEMM + +#include "../int4_utils_avx512.h" + +#ifdef LPGEMM_BF16_JIT + +void packsclb_nr64_bf16s4f32of32( + bfloat16 *packb_bf16, + const int8_t *b, + const dim_t NC, + const dim_t KC, + dim_t *rs_p, + dim_t *cs_p, + lpgemm_pre_op *b_pre_ops, + dim_t pre_op_off) +{ + //This bf16 packB_s4_bf16 is Not supported for gcc<11.2 +} + +#else //LPGEMM_BF16_JIT +/* +input:__m512i containing 64 int8 elements +output: two __m512 containing 16 f32 elements +*/ +#define CVT_INT8_F32_SCAL_16( in, idx, scale_reg) \ + (_mm512_mul_ps( \ + _mm512_cvtepi32_ps( \ + _mm512_cvtepi8_epi32( \ + _mm512_extracti32x4_epi32( in, idx ) ) ), scale_reg ) ) + +void packsclb_nr48_bf16s4f32of32 +( + bfloat16* packb_bf16, + const int8_t* b, + const dim_t KC, + bool signed_upscale, + lpgemm_pre_op* b_pre_ops, + dim_t pre_op_off +) +{ + dim_t NR = 48; + + dim_t k_full_pieces_blks = KC / 2; + dim_t k_full_pieces = k_full_pieces_blks * 2; + dim_t k_partial_pieces = KC % 2; + + /* Regs to load int4 elements */ + __m256i ymm0, ymm1; + /* Regs to store zero-point values */ + __m512i zero_point, zero_point0, zero_point1; + /* Regs to store scale factor values */ + __m512 zmm4, zmm5, zmm6, zmm7, zmm8, zmm9; + /* Regs to store intermediate int8 elements */ + __m512i zmm14, zmm15; + /* Regs to store bf16 values */ + __m512bh zmm0, zmm1, zmm2; + /* Regs to store masks */ + __m512i mask_zp1, mask_zp2, mask_scale1, mask_scale2; + + __m512i shift_idx_64; + MULTISHIFT_32BIT_8_INT4_IDX_64ELEM(shift_idx_64); + __m512i sign_comp = _mm512_set1_epi8(0x08); + + mask_zp1 = _mm512_set_epi64( 0x5F1F5E1E5D1D5C1C, 0x5B1B5A1A59195818, + 0x5717561655155414, 0x5313521251115010, + 0x4F0F4E0E4D0D4C0C, 0x4B0B4A0A49094808, + 0x4707460645054404, 0x4303420241014000 ); + + mask_zp2 = _mm512_set_epi64( 0x7F3F7E3E7D3D7C3C, 0x7B3B7A3A79397838, + 0x7737763675357434, 0x7333723271317030, + 0x6F2F6E2E6D2D6C2C, 0x6B2B6A2A69296828, + 0x6727662665256424, 0x6323622261216020 ); + + mask_scale1 = _mm512_set_epi32( 0x17, 0x07, 0x16, 0x06, 0x15, 0x05, 0x14, + 0x04, 0x13, 0x03, 0x12, 0x02, 0x11, 0x01, + 0x10, 0x00 ); + + mask_scale2 = _mm512_set_epi32( 0x1F, 0x0F, 0x1E, 0x0E, 0x1D, 0x0D, 0x1C, + 0x0C, 0x1B, 0x0B, 0x1A, 0x0A, 0x19, 0x09, + 0x18, 0x08); + + if( b_pre_ops->zp_len > 1 ) + { + zero_point = _mm512_maskz_loadu_epi8( 0xFFFFFFFFFFFF, ( b_pre_ops->zp + + pre_op_off ) ); + } + else + { + zero_point = _mm512_set1_epi8( *( ( int8_t* )b_pre_ops->zp ) ); + } + zero_point1 = _mm512_permutex2var_epi8( zero_point, mask_zp2, zero_point ); + zero_point0 = _mm512_permutex2var_epi8( zero_point, mask_zp1, zero_point ); + + if( b_pre_ops->scale_factor_len > 1 ) + { + zmm4 = _mm512_loadu_ps( (float*)( b_pre_ops->scale_factor ) + + pre_op_off ); + zmm6 = _mm512_loadu_ps( (float*)( b_pre_ops->scale_factor ) + + pre_op_off + 16 ); + zmm8 = _mm512_loadu_ps( (float*)( b_pre_ops->scale_factor ) + + pre_op_off + 32 ); + + zmm5 = _mm512_permutex2var_ps( zmm4, mask_scale2, zmm4 ); + zmm4 = _mm512_permutex2var_ps( zmm4, mask_scale1, zmm4 ); + zmm7 = _mm512_permutex2var_ps( zmm6, mask_scale2, zmm6 ); + zmm6 = _mm512_permutex2var_ps( zmm6, mask_scale1, zmm6 ); + zmm9 = _mm512_permutex2var_ps( zmm8, mask_scale2, zmm8 ); + zmm8 = _mm512_permutex2var_ps( zmm8, mask_scale1, zmm8 ); + } + else + { + zmm4 = _mm512_set1_ps( *( ( float* )b_pre_ops->scale_factor ) ); + zmm5 = _mm512_set1_ps( *( ( float* )b_pre_ops->scale_factor ) ); + zmm6 = _mm512_set1_ps( *( ( float* )b_pre_ops->scale_factor ) ); + zmm7 = _mm512_set1_ps( *( ( float* )b_pre_ops->scale_factor ) ); + zmm8 = _mm512_set1_ps( *( ( float* )b_pre_ops->scale_factor ) ); + zmm9 = _mm512_set1_ps( *( ( float* )b_pre_ops->scale_factor ) ); + } + + for( dim_t kr = 0; kr < k_full_pieces; kr += 2 ) + { + ymm0 = _mm256_loadu_si256((__m256i const* )(b + ( kr * NR ) / 2 ) ); + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( ymm0, zmm14, shift_idx_64, \ + sign_comp, signed_upscale); + + zmm14 = _mm512_sub_epi8( zmm14, zero_point0 ); + + zmm0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( zmm14, 1, zmm5), + CVT_INT8_F32_SCAL_16( zmm14, 0, zmm4) ); + + zmm1 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( zmm14, 3, zmm7), + CVT_INT8_F32_SCAL_16( zmm14, 2, zmm6) ); + + ymm1 = _mm256_maskz_loadu_epi8(0xFFFF, (__m128i const* )(b + + ( kr * NR + 64 ) / 2 ) ); + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( ymm1, zmm15, shift_idx_64, \ + sign_comp, signed_upscale); + + zmm15 = _mm512_sub_epi8( zmm15, zero_point1 ); + + zmm2 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( zmm15, 1, zmm9), + CVT_INT8_F32_SCAL_16( zmm15, 0, zmm8) ); + + //store to pack_b buffer + _mm512_storeu_si512( packb_bf16 + ( ( kr + 0 ) * NR ), (__m512i)zmm0 ); + _mm512_storeu_si512( packb_bf16 + ( ( kr + 0 ) * NR ) + 32, + (__m512i)zmm1 ); + _mm512_storeu_si512( packb_bf16 + ( ( kr + 0 ) * NR ) + 64, + (__m512i)zmm2 ); + } + /* Handle k remainder. */ + if( k_partial_pieces > 0 ) + { + __m512i zero_reg = _mm512_setzero_si512(); + zero_point1 = _mm512_permutex2var_epi8( zero_point, mask_zp2, zero_reg ); + zero_point0 = _mm512_permutex2var_epi8( zero_point, mask_zp1, zero_reg ); + + ymm0 = _mm256_loadu_si256((__m256i const* )(b + ( k_full_pieces + 0 ) + * NR / 2 ) ); + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( ymm0, zmm14, shift_idx_64, \ + sign_comp, signed_upscale); + + zmm14 = _mm512_sub_epi8( zmm14, zero_point0 ); + + zmm0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( zmm14, 1, zmm5), + CVT_INT8_F32_SCAL_16( zmm14, 0, zmm4) ); + + zmm1 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( zmm14, 3, zmm7), + CVT_INT8_F32_SCAL_16( zmm14, 2, zmm6) ); + + ymm1 = _mm256_maskz_loadu_epi8( 0xFFFF, (__m128i const* )(b + + ( k_full_pieces * NR + 64 ) / 2 ) ); + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( ymm1, zmm15, shift_idx_64, \ + sign_comp, signed_upscale); + + zmm15 = _mm512_sub_epi8( zmm15, zero_point1 ); + + zmm2 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( zmm15, 1, zmm9), + CVT_INT8_F32_SCAL_16( zmm15, 0, zmm8) ); + + //store to pack_b buffer + _mm512_storeu_si512( packb_bf16 + ( ( k_full_pieces + 0 ) * NR ), + (__m512i)zmm0 ); + _mm512_storeu_si512( packb_bf16 + ( ( k_full_pieces + 0 ) * NR ) + 32, + (__m512i)zmm1 ); + _mm512_storeu_si512( packb_bf16 + ( ( k_full_pieces + 0 ) * NR ) + 64, + (__m512i)zmm2 ); + } +} + + +void packsclb_nr32_bf16s4f32of32 +( + bfloat16* packb_bf16, + const int8_t* b, + const dim_t KC, + bool signed_upscale, + lpgemm_pre_op* b_pre_ops, + dim_t pre_op_off +) +{ + dim_t NR = 32; + + dim_t k_full_pieces_blks = KC / 2; + dim_t k_full_pieces = k_full_pieces_blks * 2; + dim_t k_partial_pieces = KC % 2; + + /* Regs to load int4 elements */ + __m256i ymm0; + /* Regs to store zero-point values */ + __m512i zero_point, zero_point0; + /* Regs to store scale factor values */ + __m512 zmm4, zmm5, zmm6, zmm7; + /* Regs to store intermediate int8 elements */ + __m512i zmm14; + /* Regs to store bf16 values */ + __m512bh zmm0, zmm1; + /* Regs to store masks */ + __m512i mask_zp1, mask_scale1, mask_scale2; + + __m512i shift_idx_64; + MULTISHIFT_32BIT_8_INT4_IDX_64ELEM(shift_idx_64); + __m512i sign_comp = _mm512_set1_epi8(0x08); + + mask_zp1 = _mm512_set_epi64( 0x5F1F5E1E5D1D5C1C, 0x5B1B5A1A59195818, + 0x5717561655155414, 0x5313521251115010, + 0x4F0F4E0E4D0D4C0C, 0x4B0B4A0A49094808, + 0x4707460645054404, 0x4303420241014000 ); + + mask_scale1 = _mm512_set_epi32( 0x17, 0x07, 0x16, 0x06, 0x15, 0x05, 0x14, + 0x04, 0x13, 0x03, 0x12, 0x02, 0x11, 0x01, + 0x10, 0x00 ); + + mask_scale2 = _mm512_set_epi32( 0x1F, 0x0F, 0x1E, 0x0E, 0x1D, 0x0D, 0x1C, + 0x0C, 0x1B, 0x0B, 0x1A, 0x0A, 0x19, 0x09, + 0x18, 0x08); + + if( b_pre_ops->zp_len > 1 ) + { + zero_point = _mm512_maskz_loadu_epi8( 0xFFFFFFFF, ( b_pre_ops->zp + + pre_op_off ) ); + } + else + { + zero_point = _mm512_set1_epi8( *( ( int8_t* )b_pre_ops->zp ) ); + } + zero_point0 = _mm512_permutex2var_epi8( zero_point, mask_zp1, zero_point ); + + if( b_pre_ops->scale_factor_len > 1 ) + { + zmm4 = _mm512_loadu_ps( (float*)( b_pre_ops->scale_factor ) + + pre_op_off ); + zmm6 = _mm512_loadu_ps( (float*)( b_pre_ops->scale_factor ) + + pre_op_off + 16 ); + + zmm5 = _mm512_permutex2var_ps( zmm4, mask_scale2, zmm4 ); + zmm4 = _mm512_permutex2var_ps( zmm4, mask_scale1, zmm4 ); + zmm7 = _mm512_permutex2var_ps( zmm6, mask_scale2, zmm6 ); + zmm6 = _mm512_permutex2var_ps( zmm6, mask_scale1, zmm6 ); + } + else + { + zmm4 = _mm512_set1_ps( *( ( float* )b_pre_ops->scale_factor ) ); + zmm5 = _mm512_set1_ps( *( ( float* )b_pre_ops->scale_factor ) ); + zmm6 = _mm512_set1_ps( *( ( float* )b_pre_ops->scale_factor ) ); + zmm7 = _mm512_set1_ps( *( ( float* )b_pre_ops->scale_factor ) ); + } + + for( dim_t kr = 0; kr < k_full_pieces; kr += 2 ) + { + ymm0 = _mm256_loadu_si256((__m256i const* )(b + ( kr * NR ) / 2 ) ); + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( ymm0, zmm14, shift_idx_64, \ + sign_comp, signed_upscale); + + zmm14 = _mm512_sub_epi8( zmm14, zero_point0 ); + + zmm0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( zmm14, 1, zmm5), + CVT_INT8_F32_SCAL_16( zmm14, 0, zmm4) ); + + zmm1 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( zmm14, 3, zmm7), + CVT_INT8_F32_SCAL_16( zmm14, 2, zmm6) ); + + //store to pack_b buffer + _mm512_storeu_si512( packb_bf16 + ( ( kr + 0 ) * NR ), (__m512i)zmm0 ); + _mm512_storeu_si512( packb_bf16 + ( ( kr + 0 ) * NR ) + 32, + (__m512i)zmm1 ); + } + /* Handle k remainder. */ + if( k_partial_pieces > 0 ) + { + __m512i zero_reg = _mm512_setzero_si512(); + zero_point0 = _mm512_permutex2var_epi8( zero_point, mask_zp1, zero_reg ); + + ymm0 = _mm256_loadu_si256((__m256i const* )(b + ( k_full_pieces + 0 ) + * NR / 2 ) ); + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( ymm0, zmm14, shift_idx_64, \ + sign_comp, signed_upscale); + + zmm14 = _mm512_sub_epi8( zmm14, zero_point0 ); + + zmm0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( zmm14, 1, zmm5), + CVT_INT8_F32_SCAL_16( zmm14, 0, zmm4) ); + + zmm1 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( zmm14, 3, zmm7), + CVT_INT8_F32_SCAL_16( zmm14, 2, zmm6) ); + + //store to pack_b buffer + _mm512_storeu_si512( packb_bf16 + ( ( k_full_pieces + 0 ) * NR ), + (__m512i)zmm0 ); + _mm512_storeu_si512( packb_bf16 + ( ( k_full_pieces + 0 ) * NR ) + 32, + (__m512i)zmm1 ); + } +} + + +void packsclb_nr16_bf16s4f32of32 +( + bfloat16* packb_bf16, + const int8_t* b, + const dim_t KC, + bool signed_upscale, + lpgemm_pre_op* b_pre_ops, + dim_t pre_op_off +) +{ + dim_t NR = 16; + + dim_t k_full_pieces_blks = KC / 2; + dim_t k_full_pieces = k_full_pieces_blks * 2; + dim_t k_partial_pieces = KC % 2; + + /* Regs to load int4 elements */ + __m256i ymm0; + /* Regs to store zero-point values */ + __m512i zero_point, zero_point0; + /* Regs to store scale factor values */ + __m512 zmm4, zmm5; + /* Regs to store intermediate int8 elements */ + __m512i zmm14; + /* Regs to store bf16 values */ + __m512bh zmm0; + /* Regs to store masks */ + __m512i mask_zp1, mask_scale1, mask_scale2; + + __m512i shift_idx_64; + MULTISHIFT_32BIT_8_INT4_IDX_64ELEM(shift_idx_64); + __m512i sign_comp = _mm512_set1_epi8(0x08); + + mask_zp1 = _mm512_set_epi64( 0x5F1F5E1E5D1D5C1C, 0x5B1B5A1A59195818, + 0x5717561655155414, 0x5313521251115010, + 0x4F0F4E0E4D0D4C0C, 0x4B0B4A0A49094808, + 0x4707460645054404, 0x4303420241014000 ); + + mask_scale1 = _mm512_set_epi32( 0x17, 0x07, 0x16, 0x06, 0x15, 0x05, 0x14, + 0x04, 0x13, 0x03, 0x12, 0x02, 0x11, 0x01, + 0x10, 0x00 ); + + mask_scale2 = _mm512_set_epi32( 0x1F, 0x0F, 0x1E, 0x0E, 0x1D, 0x0D, 0x1C, + 0x0C, 0x1B, 0x0B, 0x1A, 0x0A, 0x19, 0x09, + 0x18, 0x08); + + if( b_pre_ops->zp_len > 1 ) + { + zero_point = _mm512_maskz_loadu_epi8( 0xFFFF, ( b_pre_ops->zp + + pre_op_off ) ); + } + else + { + zero_point = _mm512_set1_epi8( *( ( int8_t* )b_pre_ops->zp ) ); + } + zero_point0 = _mm512_permutex2var_epi8( zero_point, mask_zp1, zero_point ); + + if( b_pre_ops->scale_factor_len > 1 ) + { + zmm4 = _mm512_loadu_ps( (float*)( b_pre_ops->scale_factor ) + + pre_op_off ); + zmm5 = _mm512_permutex2var_ps( zmm4, mask_scale2, zmm4 ); + zmm4 = _mm512_permutex2var_ps( zmm4, mask_scale1, zmm4 ); + } + else + { + zmm4 = _mm512_set1_ps( *( ( float* )b_pre_ops->scale_factor ) ); + zmm5 = _mm512_set1_ps( *( ( float* )b_pre_ops->scale_factor ) ); + } + + for( dim_t kr = 0; kr < k_full_pieces; kr += 2 ) + { + ymm0 = _mm256_maskz_loadu_epi8( 0xFFFF, (__m256i const* )(b + + ( kr * NR ) / 2 ) ); + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( ymm0, zmm14, shift_idx_64, \ + sign_comp, signed_upscale); + + zmm14 = _mm512_sub_epi8( zmm14, zero_point0 ); + + zmm0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( zmm14, 1, zmm5), + CVT_INT8_F32_SCAL_16( zmm14, 0, zmm4) ); + + //store to pack_b buffer + _mm512_storeu_si512( packb_bf16 + ( ( kr + 0 ) * NR ), (__m512i)zmm0 ); + } + /* Handle k remainder. */ + if( k_partial_pieces > 0 ) + { + __m512i zero_reg = _mm512_setzero_si512(); + zero_point0 = _mm512_permutex2var_epi8( zero_point, mask_zp1, zero_reg ); + + ymm0 = _mm256_maskz_loadu_epi8( 0xFFFF, (__m256i const* )(b + + ( k_full_pieces * NR ) / 2 ) ); + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( ymm0, zmm14, shift_idx_64, \ + sign_comp, signed_upscale); + + zmm14 = _mm512_sub_epi8( zmm14, zero_point0 ); + + zmm0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( zmm14, 1, zmm5), + CVT_INT8_F32_SCAL_16( zmm14, 0, zmm4) ); + + //store to pack_b buffer + _mm512_storeu_si512( packb_bf16 + ( ( k_full_pieces + 0 ) * NR ), + (__m512i)zmm0 ); + } +} + + +void packsclb_nrlt16_bf16s4f32of32 +( + bfloat16* packb_bf16, + const int8_t* b, + const dim_t KC, + const dim_t n_rem, + bool signed_upscale, + lpgemm_pre_op* b_pre_ops, + dim_t pre_op_off +) +{ + dim_t NR = 16; + + dim_t k_full_pieces_blks = KC / 2; + dim_t k_full_pieces = k_full_pieces_blks * 2; + dim_t k_partial_pieces = KC % 2; + + /* Regs to load int4 elements */ + __m256i ymm0; + /* Regs to store zero-point values */ + __m512i zero_point, zero_point0; + /* Regs to store scale factor values */ + __m512 zmm4, zmm5; + /* Regs to store intermediate int8 elements */ + __m512i zmm14; + /* Regs to store bf16 values */ + __m512bh zmm0; + /* Regs to store masks */ + __m512i mask_zp1, mask_scale1, mask_scale2; + + __m512i shift_idx_64; + MULTISHIFT_32BIT_8_INT4_IDX_64ELEM(shift_idx_64); + __m512i sign_comp = _mm512_set1_epi8(0x08); + + __mmask16 lmask = _cvtu32_mask16( 0xFFFF >> ( 16 - n_rem ) ); + + mask_zp1 = _mm512_set_epi64( 0x5F1F5E1E5D1D5C1C, 0x5B1B5A1A59195818, + 0x5717561655155414, 0x5313521251115010, + 0x4F0F4E0E4D0D4C0C, 0x4B0B4A0A49094808, + 0x4707460645054404, 0x4303420241014000 ); + + mask_scale1 = _mm512_set_epi32( 0x17, 0x07, 0x16, 0x06, 0x15, 0x05, 0x14, + 0x04, 0x13, 0x03, 0x12, 0x02, 0x11, 0x01, + 0x10, 0x00 ); + + mask_scale2 = _mm512_set_epi32( 0x1F, 0x0F, 0x1E, 0x0E, 0x1D, 0x0D, 0x1C, + 0x0C, 0x1B, 0x0B, 0x1A, 0x0A, 0x19, 0x09, + 0x18, 0x08); + + if( b_pre_ops->zp_len > 1 ) + { + zero_point = _mm512_maskz_loadu_epi8( lmask, ( b_pre_ops->zp + + pre_op_off ) ); + } + else + { + zero_point = _mm512_set1_epi8( *( ( int8_t* )b_pre_ops->zp ) ); + } + zero_point0 = _mm512_permutex2var_epi8( zero_point, mask_zp1, zero_point ); + + if( b_pre_ops->scale_factor_len > 1 ) + { + zmm4 = _mm512_maskz_loadu_ps( lmask, (float*)( b_pre_ops->scale_factor ) + + pre_op_off ); + zmm5 = _mm512_permutex2var_ps( zmm4, mask_scale2, zmm4 ); + zmm4 = _mm512_permutex2var_ps( zmm4, mask_scale1, zmm4 ); + } + else + { + zmm4 = _mm512_set1_ps( *( ( float* )b_pre_ops->scale_factor ) ); + zmm5 = _mm512_set1_ps( *( ( float* )b_pre_ops->scale_factor ) ); + } + + for( dim_t kr = 0; kr < k_full_pieces; kr += 2 ) + { + ymm0 = _mm256_maskz_loadu_epi8( lmask, (__m256i const* )(b + + ( kr * NR ) / 2 ) ); + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( ymm0, zmm14, shift_idx_64, \ + sign_comp, signed_upscale); + + zmm14 = _mm512_sub_epi8( zmm14, zero_point0 ); + + zmm0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( zmm14, 1, zmm5), + CVT_INT8_F32_SCAL_16( zmm14, 0, zmm4) ); + + //store to pack_b buffer + _mm512_mask_storeu_epi32( packb_bf16 + ( ( kr + 0 ) * NR ), + lmask, (__m512i)zmm0 ); + } + /* Handle k remainder. */ + if( k_partial_pieces > 0 ) + { + __m512i zero_reg = _mm512_setzero_si512(); + zero_point0 = _mm512_permutex2var_epi8( zero_point, mask_zp1, zero_reg ); + + ymm0 = _mm256_maskz_loadu_epi8(lmask, (__m256i const* )(b + ( k_full_pieces + 0 ) + * NR / 2 ) ); + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( ymm0, zmm14, shift_idx_64, \ + sign_comp, signed_upscale); + + zmm14 = _mm512_sub_epi8( zmm14, zero_point0 ); + + zmm0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( zmm14, 1, zmm5), + CVT_INT8_F32_SCAL_16( zmm14, 0, zmm4) ); + + //store to pack_b buffer + _mm512_mask_storeu_epi32( packb_bf16 + ( ( k_full_pieces + 0 ) * NR ), + lmask, (__m512i)zmm0 ); + } +} + + +void packsclb_nr64_bf16s4f32of32 + ( + bfloat16* packb_bf16, + const int8_t* b, + const dim_t NC, + const dim_t KC, + dim_t *rs_p, + dim_t *cs_p, + lpgemm_pre_op* b_pre_ops, + dim_t pre_op_off + ) +{ + dim_t NR = 64; + + dim_t n_full_pieces = NC / NR; + dim_t n_full_pieces_loop_limit = n_full_pieces * NR; + dim_t n_partial_pieces = NC % NR; + + dim_t k_full_pieces_blks = KC / 2; + dim_t k_full_pieces = k_full_pieces_blks * 2; + dim_t k_partial_pieces = KC % 2; + + dim_t KC_updated = KC; + if ( k_partial_pieces > 0 ) + { + KC_updated += ( 2 - k_partial_pieces ); + } + + bool signed_upscale = true; + + /* Regs to store bf16 elems */ + __m512bh zmm0, zmm1, zmm2, zmm3; + /* Regs to store F32 scale */ + __m512 zmm4, zmm5, zmm6, zmm7, zmm8, zmm9, zmm10, zmm11; + /* Regs to store int8 elems zero-point values */ + __m512i zero_point, zero_point0, zero_point1; + /* Reg to load int4 data */ + __m256i ymm0, ymm1; + /* Reg to store intermediate int8 elements */ + __m512i zmm14, zmm15; + /* Reg to store masks to interleave scale factor */ + __m512i mask_scale1, mask_scale2; + /* Regs to store masks to interleave zero_point values */ + __m512i mask_zp1, mask_zp2; + + __m512i shift_idx_64; + MULTISHIFT_32BIT_8_INT4_IDX_64ELEM(shift_idx_64); + + mask_zp1 = _mm512_set_epi64( 0x5F1F5E1E5D1D5C1C, 0x5B1B5A1A59195818, + 0x5717561655155414, 0x5313521251115010, + 0x4F0F4E0E4D0D4C0C, 0x4B0B4A0A49094808, + 0x4707460645054404, 0x4303420241014000 ); + + mask_zp2 = _mm512_set_epi64( 0x7F3F7E3E7D3D7C3C, 0x7B3B7A3A79397838, + 0x7737763675357434, 0x7333723271317030, + 0x6F2F6E2E6D2D6C2C, 0x6B2B6A2A69296828, + 0x6727662665256424, 0x6323622261216020 ); + + mask_scale1 = _mm512_set_epi32( 0x17, 0x07, 0x16, 0x06, 0x15, 0x05, 0x14, + 0x04, 0x13, 0x03, 0x12, 0x02, 0x11, 0x01, + 0x10, 0x00 ); + + mask_scale2 = _mm512_set_epi32( 0x1F, 0x0F, 0x1E, 0x0E, 0x1D, 0x0D, 0x1C, + 0x0C, 0x1B, 0x0B, 0x1A, 0x0A, 0x19, 0x09, + 0x18, 0x08); + + __m512i sign_comp = _mm512_set1_epi8(0x08); + + for( dim_t jr = 0; jr < n_full_pieces_loop_limit; jr += NR ) + { + if( b_pre_ops->zp_len > 1 ) + { + zero_point = _mm512_loadu_si512( ( b_pre_ops->zp ) + + pre_op_off + jr ); + } + else + { + zero_point = _mm512_set1_epi8( *( ( int8_t* )b_pre_ops->zp ) ); + } + /* interleave zero-point values */ + zero_point1 = _mm512_permutex2var_epi8( zero_point, mask_zp2, zero_point ); + zero_point0 = _mm512_permutex2var_epi8( zero_point, mask_zp1, zero_point ); + + if( b_pre_ops->scale_factor_len > 1 ) + { + // load and interleave scale factor vectors + zmm4 = _mm512_loadu_ps( (float*)( b_pre_ops->scale_factor ) + + pre_op_off + jr); + zmm6 = _mm512_loadu_ps( (float*)( b_pre_ops->scale_factor ) + + pre_op_off + jr + 16 ); + zmm8 = _mm512_loadu_ps( (float*)( b_pre_ops->scale_factor ) + + pre_op_off + jr + 32 ); + zmm10 = _mm512_loadu_ps( (float*)( b_pre_ops->scale_factor ) + + pre_op_off + jr + 48 ); + + zmm5 = _mm512_permutex2var_ps( zmm4, mask_scale2, zmm4 ); + zmm4 = _mm512_permutex2var_ps( zmm4, mask_scale1, zmm4 ); + zmm7 = _mm512_permutex2var_ps( zmm6, mask_scale2, zmm6 ); + zmm6 = _mm512_permutex2var_ps( zmm6, mask_scale1, zmm6 ); + zmm9 = _mm512_permutex2var_ps( zmm8, mask_scale2, zmm8 ); + zmm8 = _mm512_permutex2var_ps( zmm8, mask_scale1, zmm8 ); + zmm11 = _mm512_permutex2var_ps( zmm10, mask_scale2, zmm10 ); + zmm10 = _mm512_permutex2var_ps( zmm10, mask_scale1, zmm10 ); + + } + else + { + zmm4 = _mm512_set1_ps( *( ( float* )b_pre_ops->scale_factor ) ); + zmm5 = _mm512_set1_ps( *( ( float* )b_pre_ops->scale_factor ) ); + zmm6 = _mm512_set1_ps( *( ( float* )b_pre_ops->scale_factor ) ); + zmm7 = _mm512_set1_ps( *( ( float* )b_pre_ops->scale_factor ) ); + zmm8 = _mm512_set1_ps( *( ( float* )b_pre_ops->scale_factor ) ); + zmm9 = _mm512_set1_ps( *( ( float* )b_pre_ops->scale_factor ) ); + zmm10 = _mm512_set1_ps( *( ( float* )b_pre_ops->scale_factor ) ); + zmm11 = _mm512_set1_ps( *( ( float* )b_pre_ops->scale_factor ) ); + } + for( dim_t kr = 0; kr < k_full_pieces; kr += 2 ) + { + // Int4 array has to be accessed like byte array, but with + // half the elements traversed in the byte array. + + ymm0 = _mm256_loadu_si256( (__m256i const *)(b + ( ( jr * KC_updated ) + + ( ( kr + 0 ) * NR ) ) / 2 ) ); + ymm1 = _mm256_loadu_si256( (__m256i const *)(b + ( ( jr * KC_updated ) + + ( ( kr + 1 ) * NR ) ) / 2 ) ); + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( ymm0, zmm14, shift_idx_64, \ + sign_comp, signed_upscale); + + zmm14 = _mm512_sub_epi8( zmm14, zero_point0 ); + + zmm0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( zmm14, 1, zmm5), + CVT_INT8_F32_SCAL_16( zmm14, 0, zmm4) ); + + zmm1 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( zmm14, 3, zmm7), + CVT_INT8_F32_SCAL_16( zmm14, 2, zmm6) ); + + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( ymm1, zmm15, shift_idx_64, \ + sign_comp, signed_upscale); + + zmm15 = _mm512_sub_epi8( zmm15, zero_point1 ); + + zmm2 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( zmm15, 1, zmm9), + CVT_INT8_F32_SCAL_16( zmm15, 0, zmm8) ); + zmm3 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( zmm15, 3, zmm11), + CVT_INT8_F32_SCAL_16( zmm15, 2, zmm10) ); + + //store to pack_b buffer + _mm512_storeu_si512( packb_bf16 + ( jr * KC_updated ) + + ( ( kr + 0 ) * NR ), (__m512i)zmm0 ); + _mm512_storeu_si512( packb_bf16 + ( jr * KC_updated ) + + ( ( kr + 0 ) * NR ) + 32, (__m512i)zmm1 ); + _mm512_storeu_si512( packb_bf16 + ( jr * KC_updated ) + + ( ( kr + 1 ) * NR ), (__m512i)zmm2 ); + _mm512_storeu_si512( packb_bf16 + ( jr * KC_updated ) + + ( ( kr + 1 ) * NR ) + 32, (__m512i)zmm3 ); + + } + // Handle k remainder. + if( k_partial_pieces > 0 ) + { + __m512i zero_reg = _mm512_setzero_si512(); + + /* Interleave zero_point values with zeroes */ + zero_point1 = _mm512_permutex2var_epi8( zero_point, mask_zp2, zero_reg ); + zero_point0 = _mm512_permutex2var_epi8( zero_point, mask_zp1, zero_reg ); + + ymm0 = _mm256_loadu_si256( (__m256i const *)(b + ( ( jr * KC_updated ) + + ( ( k_full_pieces + 0 ) * NR ) ) / 2 ) ); + ymm1 = _mm256_loadu_si256( (__m256i const *)(b + ( ( jr * KC_updated ) + + ( ( k_full_pieces + 1 ) * NR ) ) / 2 ) ); + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( ymm0, zmm14, shift_idx_64, \ + sign_comp, signed_upscale); + + zmm14 = _mm512_sub_epi8( zmm14, zero_point0 ); + + zmm0 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( zmm14, 1, zmm5), + CVT_INT8_F32_SCAL_16( zmm14, 0, zmm4) ); + zmm1 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( zmm14, 3, zmm7), + CVT_INT8_F32_SCAL_16( zmm14, 2, zmm6) ); + + + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT( ymm1, zmm15, shift_idx_64, \ + sign_comp, signed_upscale); + + zmm15 = _mm512_sub_epi8( zmm15, zero_point1 ); + + zmm2 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( zmm15, 1, zmm9), + CVT_INT8_F32_SCAL_16( zmm15, 0, zmm8) ); + zmm3 = _mm512_cvtne2ps_pbh( CVT_INT8_F32_SCAL_16( zmm15, 3, zmm11), + CVT_INT8_F32_SCAL_16( zmm15, 2, zmm10) ); + + //store to pack_b buffer + _mm512_storeu_si512( packb_bf16 + ( jr * KC_updated ) + + ( ( k_full_pieces + 0 ) * NR ), (__m512i)zmm0 ); + _mm512_storeu_si512( packb_bf16 + ( jr * KC_updated ) + + ( ( k_full_pieces + 0 ) * NR ) + 32, (__m512i)zmm1 ); + _mm512_storeu_si512( packb_bf16 + ( jr * KC_updated ) + + ( ( k_full_pieces + 1 ) * NR ), (__m512i)zmm2 ); + _mm512_storeu_si512( packb_bf16 + ( jr * KC_updated ) + + ( ( k_full_pieces + 1 ) * NR ) + 32, (__m512i)zmm3 ); + } + } + + if( n_partial_pieces > 0 ) + { + pre_op_off += n_full_pieces_loop_limit; + + // Handle NR edge cases + dim_t n0_partial_rem = n_partial_pieces % 16; + dim_t n0_partial_pack = 0; + + // Split into multiple smaller fringe kernels, so as to maximize + // vectorization after packing. Any n0 < NR(64) can be expressed + // as n0 = 48 + n` / n0 = 32 + n` / n0 = 16 + n`, where n` < 16. + dim_t n0_48 = n_partial_pieces / 48; + dim_t n0_32 = n_partial_pieces / 32; + dim_t n0_16 = n_partial_pieces / 16; + + if ( n0_48 == 1 ) + { + packsclb_nr48_bf16s4f32of32 + ( + ( packb_bf16 + ( n_full_pieces_loop_limit * KC_updated ) ), + ( b + ( n_full_pieces_loop_limit * KC_updated / 2 ) ), KC, + signed_upscale, b_pre_ops, pre_op_off + ); + + n0_partial_pack = 48; + } + else if ( n0_32 == 1 ) + { + packsclb_nr32_bf16s4f32of32 + ( + ( packb_bf16 + ( n_full_pieces_loop_limit * KC_updated ) ), + ( b + ( n_full_pieces_loop_limit * KC_updated / 2 ) ), KC, + signed_upscale, b_pre_ops, pre_op_off + ); + + n0_partial_pack = 32; + } + else if ( n0_16 == 1 ) + { + packsclb_nr16_bf16s4f32of32 + ( + ( packb_bf16 + ( n_full_pieces_loop_limit * KC_updated ) ), + ( b + ( n_full_pieces_loop_limit * KC_updated / 2 ) ), KC, + signed_upscale, b_pre_ops, pre_op_off + ); + + n0_partial_pack = 16; + } + + if ( n0_partial_rem > 0 ) + { + pre_op_off += n0_partial_pack; + packsclb_nrlt16_bf16s4f32of32 + ( + ( packb_bf16 + ( n_full_pieces_loop_limit * KC_updated ) + + ( n0_partial_pack * KC_updated ) ), + ( b + ( ( n_full_pieces_loop_limit + n0_partial_pack ) * KC_updated / 2 ) ), + KC, n0_partial_rem, signed_upscale, b_pre_ops, pre_op_off + ); + } + } + + *rs_p = NR * 2; + *cs_p = NR / 2; +} + +#endif // LPGEMM_BF16_JIT +#endif // BLIS_ADDON_LPGEMM diff --git a/kernels/zen4/lpgemm/bf16bf16f32/lpgemv_m_kernel_bf16_amd512vnni.c b/kernels/zen4/lpgemm/bf16bf16f32/lpgemv_m_kernel_bf16_amd512vnni.c new file mode 100644 index 0000000000..09d1cd9d71 --- /dev/null +++ b/kernels/zen4/lpgemm/bf16bf16f32/lpgemv_m_kernel_bf16_amd512vnni.c @@ -0,0 +1,739 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ +#include "immintrin.h" +#include "xmmintrin.h" +#include "blis.h" + +#ifdef BLIS_ADDON_LPGEMM + +#include "lpgemm_f32_kern_macros.h" + + +#ifdef LPGEMM_BF16_JIT +LPGEMV_M_EQ1_KERN(bfloat16, bfloat16, float, bf16bf16f32of32) +{} +#else + + +LPGEMV_M_EQ1_KERN(bfloat16, bfloat16, float, bf16bf16f32of32) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_6x64_DISABLE, + &&POST_OPS_BIAS_6x64, + &&POST_OPS_RELU_6x64, + &&POST_OPS_RELU_SCALE_6x64, + &&POST_OPS_GELU_TANH_6x64, + &&POST_OPS_GELU_ERF_6x64, + &&POST_OPS_CLIP_6x64, + &&POST_OPS_DOWNSCALE_6x64, + &&POST_OPS_MATRIX_ADD_6x64, + &&POST_OPS_SWISH_6x64, + &&POST_OPS_MATRIX_MUL_6x64 + }; + + + // Strides are updated based on matrix packing/reordering. + const bfloat16 *a_use = NULL; + const bfloat16 *b_use = NULL; + + lpgemm_post_op_attr post_ops_attr = *(post_op_attr); + + for( dim_t jr = 0; jr < n0; jr += NR ) + { + + float* c_use = c + jr * cs_c; + + dim_t n_left = n0 - jr; + + NR = bli_min( NR, ( n_left >> 4 ) << 4 ); + + if( NR == 0 ) NR = 16; + + rs_b = NR * 2; + + dim_t nr0 = bli_min( n0 - jr, NR ); + + __mmask16 k1 = 0xFFFF, k2 = 0xFFFF, k3 = 0xFFFF, k4 = 0xFFFF; + __mmask32 k5 = 0xFFFFFFFF, k6 = 0xFFFFFFFF; + __mmask32 k7 = 0xFFFFFFFF, k8 = 0xFFFFFFFF; + + if( nr0 == 64 ) + { + // all masks are already set. + // Nothing to modify. + } + else if( nr0 == 48 ) + { + k4 = k8 = 0x0; + } + else if( nr0 == 32 ) + { + k3 = k4 = k7 = k8 = 0x0; + } + else if( nr0 == 16 ) + { + k2 = k3 = k4 = k6 = k7 = k8 = 0; + } + else if( nr0 < 16 ) + { + k1 = (0xFFFF >> (16 - (nr0 & 0x0F))); + k2 = k3 = k4 = k6 = k7 = k8 = 0; + } + + __m512bh zmm0, zmm1, zmm2, zmm3, zmm4, zmm5, zmm6, zmm7; + __m512 zmm8, zmm9, zmm10, zmm11, zmm12, zmm13, zmm14; + __m512 zmm15, zmm16, zmm17, zmm18, zmm19, zmm20, zmm21; + __m512 zmm22, zmm23; + __m512bh zmm24, zmm25, zmm26, zmm27, zmm28, zmm29, zmm30, zmm31; + + // zero the accumulator registers + ZERO_ACC_ZMM_4_REG(zmm8, zmm9, zmm10, zmm11); + ZERO_ACC_ZMM_4_REG(zmm12, zmm13, zmm14, zmm15); + ZERO_ACC_ZMM_4_REG(zmm16, zmm17, zmm18, zmm19); + ZERO_ACC_ZMM_4_REG(zmm20, zmm21, zmm22, zmm23); + + for (dim_t pc = 0; pc < k; pc += KC) + { + dim_t kc0 = bli_min((k - pc), KC); + + // kc0 needs to be a multiple of 2 so that it can be + // used with dpbf16_ps instruction. Padding is added in + // cases this condition is not satisfied, and therefore + // the kc0 offsets used for packed/reordered buffers + // needs to be updated. + dim_t kc0_updated = kc0; + kc0_updated += (kc0_updated & 0x1); + + uint64_t k_iter = kc0 / 8; + uint64_t k_rem = ( kc0 / 2) % 4; + + // No parallelization in k dim, k always starts at 0. + + // In multi-threaded scenarios, an extra offset into a given + // packed B panel is required, since the jc loop split can + // result in per thread start offset inside the panel, instead + // of panel boundaries. + b_use = b + ( n_sub_updated * pc ) + + ( ( jc_cur_loop_rem + jr ) * kc0_updated ) ; + + a_use = a + pc; + + for (dim_t k = 0; k < k_iter; k++) + { + + // load first 4x32 tile from row 0-3 + zmm0 = (__m512bh)_mm512_maskz_loadu_epi16( k5, b_use ); + zmm1 = (__m512bh)_mm512_maskz_loadu_epi16( k5, b_use + rs_b ); + zmm2 = (__m512bh)_mm512_maskz_loadu_epi16( k5, + b_use + 2 * rs_b ); + zmm3 = (__m512bh)_mm512_maskz_loadu_epi16( k5, + b_use + 3 * rs_b ); + b_use += 32; + + // Broadcast col0-col3 elements of A + zmm4 = (__m512bh)_mm512_set1_epi32(*( int32_t* )( a_use ) ); + zmm5 = (__m512bh)_mm512_set1_epi32(*( int32_t* )( a_use + + ( cs_a ) ) ); + zmm6 = (__m512bh)_mm512_set1_epi32(*( int32_t* )( a_use + + ( cs_a * 2 ) ) ); + zmm7 = (__m512bh)_mm512_set1_epi32(*( int32_t* )( a_use + + ( cs_a * 3 ) ) ); + + // Load second 4x32 tile from row 0-3 + zmm24 = (__m512bh)_mm512_maskz_loadu_epi16 ( k6, b_use ); + zmm25 = (__m512bh)_mm512_maskz_loadu_epi16 ( k6, b_use + rs_b ); + zmm26 = (__m512bh)_mm512_maskz_loadu_epi16 ( k6, + b_use + 2 * rs_b ); + zmm27 = (__m512bh)_mm512_maskz_loadu_epi16 ( k6, + b_use + 3 * rs_b ); + b_use += 32; + + zmm8 = _mm512_dpbf16_ps( zmm8, zmm4, zmm0 ); + zmm9 = _mm512_dpbf16_ps( zmm9, zmm5, zmm1 ); + zmm10 = _mm512_dpbf16_ps( zmm10, zmm6, zmm2 ); + zmm11 = _mm512_dpbf16_ps( zmm11, zmm7, zmm3 ); + + // load third 4x32 tile from row 0-3 + zmm0 = (__m512bh)_mm512_maskz_loadu_epi16 ( k7, b_use ); + zmm1 = (__m512bh)_mm512_maskz_loadu_epi16 ( k7, b_use + rs_b ); + zmm2 = (__m512bh)_mm512_maskz_loadu_epi16 ( k7, + b_use + 2 * rs_b ); + zmm3 = (__m512bh)_mm512_maskz_loadu_epi16 ( k7, + b_use + 3 * rs_b ); + b_use += 32; + + + zmm12 = _mm512_dpbf16_ps( zmm12, zmm4, zmm24 ); + zmm13 = _mm512_dpbf16_ps( zmm13, zmm5, zmm25 ); + zmm14 = _mm512_dpbf16_ps( zmm14, zmm6, zmm26 ); + zmm15 = _mm512_dpbf16_ps( zmm15, zmm7, zmm27 ); + + // Load fourth 4x32 tile from row 0-3 + zmm28 = (__m512bh)_mm512_maskz_loadu_epi16 ( k8, b_use ); + zmm29 = (__m512bh)_mm512_maskz_loadu_epi16 ( k8, b_use + rs_b ); + zmm30 = (__m512bh)_mm512_maskz_loadu_epi16 ( k8, + b_use + 2 * rs_b ); + zmm31 = (__m512bh)_mm512_maskz_loadu_epi16 ( k8, + b_use + 3 * rs_b ); + + + zmm16 = _mm512_dpbf16_ps( zmm16, zmm4, zmm0 ); + zmm17 = _mm512_dpbf16_ps( zmm17, zmm5, zmm1 ); + zmm18 = _mm512_dpbf16_ps( zmm18, zmm6, zmm2 ); + zmm19 = _mm512_dpbf16_ps( zmm19, zmm7, zmm3 ); + + zmm20 = _mm512_dpbf16_ps( zmm20, zmm4, zmm28 ); + zmm21 = _mm512_dpbf16_ps( zmm21, zmm5, zmm29 ); + zmm22 = _mm512_dpbf16_ps( zmm22, zmm6, zmm30 ); + zmm23 = _mm512_dpbf16_ps( zmm23, zmm7, zmm31 ); + + b_use -= 96; // move b point back to start of KCXNR + b_use += (4 * rs_b); + a_use += 4 * cs_a; // move a pointer to next col + + } + + for (dim_t kr = 0; kr < k_rem; kr++) + { + // load 128 elements from a row of B + zmm0 = (__m512bh)_mm512_maskz_loadu_epi16 ( k5, b_use ); + zmm1 = (__m512bh)_mm512_maskz_loadu_epi16 ( k6, + b_use + cs_b ); + zmm2 = (__m512bh)_mm512_maskz_loadu_epi16 ( k7, + b_use + cs_b*2 ); + zmm3 = (__m512bh)_mm512_maskz_loadu_epi16 ( k8, + b_use + cs_b*3 ); + + // Broadcast col0 elements of A + zmm4 = (__m512bh)_mm512_set1_epi32(*( int32_t* )(a_use ) ); + + zmm8 = _mm512_dpbf16_ps( zmm8, zmm4, zmm0 ); + zmm12 = _mm512_dpbf16_ps( zmm12, zmm4, zmm1 ); + zmm16 = _mm512_dpbf16_ps( zmm16, zmm4, zmm2 ); + zmm20 = _mm512_dpbf16_ps( zmm20, zmm4, zmm3 ); + + b_use += rs_b; + a_use += cs_a; + } + if( kc0 & 1 ) + { + // load 128 elements from a row of B + zmm0 = (__m512bh)_mm512_maskz_loadu_epi16 ( k5, b_use ); + zmm1 = (__m512bh)_mm512_maskz_loadu_epi16 ( k6, b_use + cs_b ); + zmm2 = (__m512bh)_mm512_maskz_loadu_epi16 ( k7, + b_use + cs_b*2 ); + zmm3 = (__m512bh)_mm512_maskz_loadu_epi16 ( k8, + b_use + cs_b*3 ); + + // Broadcast col0 elements of A + zmm4 = (__m512bh)_mm512_set1_epi16(*(int16_t*) a_use ); + + zmm8 = _mm512_dpbf16_ps( zmm8, zmm4, zmm0 ); + zmm12 = _mm512_dpbf16_ps( zmm12, zmm4, zmm1 ); + zmm16 = _mm512_dpbf16_ps( zmm16, zmm4, zmm2 ); + zmm20 = _mm512_dpbf16_ps( zmm20, zmm4, zmm3 ); + + } + } + // Sumup k-unroll outputs + zmm8 = _mm512_add_ps( zmm9, zmm8 ); + zmm10 = _mm512_add_ps(zmm11, zmm10); + zmm8 = _mm512_add_ps(zmm10, zmm8); // 32 outputs + + zmm12 = _mm512_add_ps(zmm13, zmm12); + zmm14 = _mm512_add_ps(zmm15, zmm14); + zmm12 = _mm512_add_ps(zmm14, zmm12); // 32 outputs + + zmm16 = _mm512_add_ps(zmm17, zmm16); + zmm18 = _mm512_add_ps(zmm19, zmm18); + zmm16 = _mm512_add_ps(zmm18, zmm16); // 32 outputs + + zmm20 = _mm512_add_ps(zmm21, zmm20); + zmm22 = _mm512_add_ps(zmm23, zmm22); + zmm20 = _mm512_add_ps(zmm22, zmm20); // 32 outputs + + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + //Mulitply A*B output with alpha + zmm8 = _mm512_mul_ps(selector1, zmm8); + zmm12 = _mm512_mul_ps(selector1, zmm12); + zmm16 = _mm512_mul_ps(selector1, zmm16); + zmm20 = _mm512_mul_ps(selector1, zmm20); + + if (beta != 0) + { + + // For the downscaled api (C-bf16), the output C matrix values + // needs to be upscaled to float to be used for beta scale. + if ( post_ops_attr.buf_downscale != NULL ) + { + BF16_F32_BETA_OP_NLT16F_MASK( k1, zmm8, 0, 0, selector1, selector2 ) + BF16_F32_BETA_OP_NLT16F_MASK( k2, zmm12, 0, 1, selector1, selector2 ) + BF16_F32_BETA_OP_NLT16F_MASK( k3, zmm16, 0, 2, selector1, selector2 ) + BF16_F32_BETA_OP_NLT16F_MASK( k4, zmm20, 0, 3, selector1, selector2 ) + } + else + { + F32_F32_BETA_OP_NLT16F_MASK( c_use, k1, zmm8, 0, 0, 0, selector1, selector2 ) + F32_F32_BETA_OP_NLT16F_MASK( c_use, k2, zmm12, 0, 0, 1, selector1, selector2 ) + F32_F32_BETA_OP_NLT16F_MASK( c_use, k3, zmm16, 0, 0, 2, selector1, selector2 ) + F32_F32_BETA_OP_NLT16F_MASK( c_use, k4, zmm20, 0, 0, 3, selector1, selector2 ) + } + } + + post_ops_attr.is_last_k = TRUE; + lpgemm_post_op *post_ops_list_temp = post_op; + POST_OP_LABEL_LASTK_SAFE_JUMP + + POST_OPS_BIAS_6x64: + { + __m512 selector3; + __m512 selector4; + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_attr.c_stor_type == BF16 ) + { + BF16_F32_BIAS_LOAD(selector1, k1, 0); + BF16_F32_BIAS_LOAD(selector2, k2, 1); + BF16_F32_BIAS_LOAD(selector3, k3, 2); + BF16_F32_BIAS_LOAD(selector4, k4, 3); + } + else + { + selector1 = + _mm512_maskz_loadu_ps( k1, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_maskz_loadu_ps( k2, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_maskz_loadu_ps( k3, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + selector4 = + _mm512_maskz_loadu_ps( k4, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + } + + zmm8 = _mm512_add_ps( selector1, zmm8 ); + zmm12 = _mm512_add_ps( selector2, zmm12 ); + zmm16 = _mm512_add_ps( selector3, zmm16 ); + zmm20 = _mm512_add_ps( selector4, zmm20 ); + } + else + { + if ( post_ops_attr.c_stor_type == BF16 ) + { + __mmask16 bias_mask = _cvtu32_mask16( 0xFFFF ); + BF16_F32_BIAS_BCAST(selector1, bias_mask, 0); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + } + + zmm8 = _mm512_add_ps( selector1, zmm8 ); + zmm12 = _mm512_add_ps( selector1, zmm12 ); + zmm16 = _mm512_add_ps( selector1, zmm16 ); + zmm20 = _mm512_add_ps( selector1, zmm20 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_RELU_6x64: + { + selector1 = _mm512_setzero_ps(); + + zmm8 = _mm512_max_ps( selector1, zmm8 ); + zmm12 = _mm512_max_ps( selector1, zmm12 ); + zmm16 = _mm512_max_ps( selector1, zmm16 ); + zmm20 = _mm512_max_ps( selector1, zmm20 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + + POST_OPS_RELU_SCALE_6x64: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + RELU_SCALE_OP_F32_AVX512( zmm8 ) + RELU_SCALE_OP_F32_AVX512( zmm12 ) + RELU_SCALE_OP_F32_AVX512( zmm16 ) + RELU_SCALE_OP_F32_AVX512( zmm20 ) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + + POST_OPS_GELU_TANH_6x64: + { + __m512 dn, z, x, r2, r, x_tanh; + __m512i q; + + GELU_TANH_F32_AVX512( zmm8, r, r2, x, z, dn, x_tanh, q ) + GELU_TANH_F32_AVX512( zmm12, r, r2, x, z, dn, x_tanh, q ) + GELU_TANH_F32_AVX512( zmm16, r, r2, x, z, dn, x_tanh, q ) + GELU_TANH_F32_AVX512( zmm20, r, r2, x, z, dn, x_tanh, q ) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + + POST_OPS_GELU_ERF_6x64: + { + __m512 x, r, x_erf; + + GELU_ERF_F32_AVX512( zmm8, r, x, x_erf ) + GELU_ERF_F32_AVX512( zmm12, r, x, x_erf ) + GELU_ERF_F32_AVX512( zmm16, r, x, x_erf ) + GELU_ERF_F32_AVX512( zmm20, r, x, x_erf ) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + + POST_OPS_CLIP_6x64: + { + __m512 min = + _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 ); + __m512 max = + _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 ); + + CLIP_F32_AVX512( zmm8, min, max ) + CLIP_F32_AVX512( zmm12, min, max ) + CLIP_F32_AVX512( zmm16, min, max ) + CLIP_F32_AVX512( zmm20, min, max ) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + + POST_OPS_DOWNSCALE_6x64: + { + __m512 selector3 = _mm512_setzero_ps(); + __m512 selector4 = _mm512_setzero_ps(); + + __m512 zero_point0 = _mm512_setzero_ps(); + __m512 zero_point1 = _mm512_setzero_ps(); + __m512 zero_point2 = _mm512_setzero_ps(); + __m512 zero_point3 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_maskz_loadu_ps( k1, + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_maskz_loadu_ps( k2, + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_maskz_loadu_ps( k3, + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + selector4 = + _mm512_maskz_loadu_ps( k4, + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( k1, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( k2, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( k3, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( k4, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ) ); + } + + // c[0, 0-15] + SCL_MULRND_F32(zmm8,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(zmm12,selector2,zero_point1); + + // c[0, 32-47] + SCL_MULRND_F32(zmm16,selector3,zero_point2); + + // c[0, 48-63] + SCL_MULRND_F32(zmm20,selector4,zero_point3); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + // Scale/zp len cannot be > 1, since original n = 1 for + // swapped m to be = 1. + + // c[0, 0-15] + SCL_MULRND_F32(zmm8,selector1,zero_point0); + + // c[0, 16-31] + SCL_MULRND_F32(zmm12,selector1,zero_point0); + + // c[0, 32-47] + SCL_MULRND_F32(zmm16,selector1,zero_point0); + + // c[0, 48-63] + SCL_MULRND_F32(zmm20,selector1,zero_point0); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + + POST_OPS_MATRIX_ADD_6x64: + { + __m512 selector3; + __m512 selector4; + + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + BF16_F32_MATRIX_ADD_LOAD + ( k1, selector1, 0, 0 ) + BF16_F32_MATRIX_ADD_LOAD + ( k2, selector2, 0, 1 ) + BF16_F32_MATRIX_ADD_LOAD + ( k3, selector3, 0, 2 ) + BF16_F32_MATRIX_ADD_LOAD + ( k4, selector4, 0, 3 ) + + zmm8 = _mm512_add_ps( selector1, zmm8 ); + zmm12 = _mm512_add_ps( selector2, zmm12 ); + zmm16 = _mm512_add_ps( selector3, zmm16 ); + zmm20 = _mm512_add_ps( selector4, zmm20 ); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + F32_F32_MATRIX_ADD_LOAD + ( k1, selector1, 0, 0 ) + F32_F32_MATRIX_ADD_LOAD + ( k2, selector2, 0, 1 ) + F32_F32_MATRIX_ADD_LOAD + ( k3, selector3, 0, 2 ) + F32_F32_MATRIX_ADD_LOAD + ( k4, selector4, 0, 3 ) + + zmm8 = _mm512_add_ps( selector1, zmm8 ); + zmm12 = _mm512_add_ps( selector2, zmm12 ); + zmm16 = _mm512_add_ps( selector3, zmm16 ); + zmm20 = _mm512_add_ps( selector4, zmm20 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + + POST_OPS_MATRIX_MUL_6x64: + { + __m512 selector3; + __m512 selector4; + + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + BF16_F32_MATRIX_MUL_LOAD + ( k1, selector1, 0, 0 ) + BF16_F32_MATRIX_MUL_LOAD + ( k2, selector2, 0, 1 ) + BF16_F32_MATRIX_MUL_LOAD + ( k3, selector3, 0, 2 ) + BF16_F32_MATRIX_MUL_LOAD + ( k4, selector4, 0, 3 ) + + zmm8 = _mm512_mul_ps( selector1, zmm8 ); + zmm12 = _mm512_mul_ps( selector2, zmm12 ); + zmm16 = _mm512_mul_ps( selector3, zmm16 ); + zmm20 = _mm512_mul_ps( selector4, zmm20 ); + } + else + { + float* matptr = ( float* )post_ops_list_temp->op_args1; + + F32_F32_MATRIX_MUL_LOAD + ( k1, selector1, 0, 0 ) + F32_F32_MATRIX_MUL_LOAD + ( k2, selector2, 0, 1 ) + F32_F32_MATRIX_MUL_LOAD + ( k3, selector3, 0, 2 ) + F32_F32_MATRIX_MUL_LOAD + ( k4, selector4, 0, 3 ) + + zmm8 = _mm512_mul_ps( selector1, zmm8 ); + zmm12 = _mm512_mul_ps( selector2, zmm12 ); + zmm16 = _mm512_mul_ps( selector3, zmm16 ); + zmm20 = _mm512_mul_ps( selector4, zmm20 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + + POST_OPS_SWISH_6x64: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + + SWISH_F32_AVX512_DEF( zmm8, selector1, al_in, r, r2, z, dn, ex_out ); + SWISH_F32_AVX512_DEF( zmm12, selector1, al_in, r, r2, z, dn, ex_out ); + SWISH_F32_AVX512_DEF( zmm16, selector1, al_in, r, r2, z, dn, ex_out ); + SWISH_F32_AVX512_DEF( zmm20, selector1, al_in, r, r2, z, dn, ex_out ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + + POST_OPS_6x64_DISABLE: + { + // Case where the output C matrix is bf16 (downscaled) + // and this is the final write for a given block within C. + if ( post_ops_attr.buf_downscale != NULL ) + { + _mm256_mask_storeu_epi16 + ( + ( bfloat16* )post_ops_attr.buf_downscale + + post_ops_attr.post_op_c_j + ( 0 * 16 ), + k1, (__m256i) _mm512_cvtneps_pbh( zmm8 ) + ); + + _mm256_mask_storeu_epi16 + ( + ( bfloat16* )post_ops_attr.buf_downscale + + post_ops_attr.post_op_c_j + ( 1 * 16 ), + k2, (__m256i) _mm512_cvtneps_pbh( zmm12 ) + ); + + _mm256_mask_storeu_epi16 + ( + ( bfloat16* )post_ops_attr.buf_downscale + + post_ops_attr.post_op_c_j + ( 2 * 16 ), + k3, (__m256i) _mm512_cvtneps_pbh( zmm16 ) + ); + + _mm256_mask_storeu_epi16 + ( + ( bfloat16* )post_ops_attr.buf_downscale + + post_ops_attr.post_op_c_j + ( 3 * 16 ), + k4, (__m256i) _mm512_cvtneps_pbh( zmm20 ) + ); + } + else + { + // Store the results. + _mm512_mask_storeu_ps( c_use + ( 0*16 ), k1, zmm8 ); + _mm512_mask_storeu_ps( c_use + ( 1*16 ), k2, zmm12 ); + _mm512_mask_storeu_ps( c_use + ( 2*16 ), k3, zmm16 ); + _mm512_mask_storeu_ps( c_use + ( 3*16 ), k4, zmm20 ); + } + } + + post_ops_attr.post_op_c_j += nr0; + + } // jr loop +} + +#endif // LPGEMM_BF16_JIT +#endif // BLIS_ADDON_LPGEMM diff --git a/kernels/zen4/lpgemm/bf16bf16f32/lpgemv_n_kernel_bf16_amd512vnni.c b/kernels/zen4/lpgemm/bf16bf16f32/lpgemv_n_kernel_bf16_amd512vnni.c new file mode 100644 index 0000000000..96ecdd3dac --- /dev/null +++ b/kernels/zen4/lpgemm/bf16bf16f32/lpgemv_n_kernel_bf16_amd512vnni.c @@ -0,0 +1,910 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "immintrin.h" +#include "xmmintrin.h" +#include "blis.h" + +#ifdef BLIS_ADDON_LPGEMM + +#include "lpgemm_f32_kern_macros.h" + + +// Zero-out the given ZMM accumulator registers +#define ZERO_ACC_XMM_4_REG(xmm0, xmm1, xmm2, xmm3) \ + xmm0 = _mm_setzero_ps(); \ + xmm1 = _mm_setzero_ps(); \ + xmm2 = _mm_setzero_ps(); \ + xmm3 = _mm_setzero_ps(); + + +#define LPGEMV_N_KERNEL_4_MASKLOADS( zmm0, zmm1, zmm2, zmm3, k1, paddr, stride ) \ + zmm0 = (__m512bh)_mm512_maskz_loadu_epi16( k1, paddr ); \ + zmm1 = (__m512bh)_mm512_maskz_loadu_epi16( k1, paddr + stride ); \ + zmm2 = (__m512bh)_mm512_maskz_loadu_epi16( k1, paddr + 2 * stride ); \ + zmm3 = (__m512bh)_mm512_maskz_loadu_epi16( k1, paddr + 3 * stride ); + +#define LPGEMV_N_KERNEL_4_LOADS( zmm0, zmm1, zmm2, zmm3, paddr, stride ) \ + zmm0 = (__m512bh)_mm512_loadu_epi16( paddr ); \ + zmm1 = (__m512bh)_mm512_loadu_epi16( paddr + stride ); \ + zmm2 = (__m512bh)_mm512_loadu_epi16( paddr + 2 * stride ); \ + zmm3 = (__m512bh)_mm512_loadu_epi16( paddr + 3 * stride ); + + +#define LPGEMV_N_KERNEL_4_FMA( zmm8, zmm9, zmm10, zmm11, zmm6, zmm0, zmm1, zmm2, zmm3 ) \ + zmm8 = _mm512_dpbf16_ps( zmm8, zmm6, zmm0 ); \ + zmm9 = _mm512_dpbf16_ps( zmm9, zmm6, zmm1 ); \ + zmm10 = _mm512_dpbf16_ps( zmm10, zmm6, zmm2 ); \ + zmm11 = _mm512_dpbf16_ps( zmm11, zmm6, zmm3 ); + + +#define LPGEMV_ZMM2XMM(zmm0, zmm1, zmm2, zmm3, ymm0, ymm1, ymm2, ymm3, xmm0) \ + ymm0 = _mm256_add_ps(_mm512_extractf32x8_ps(zmm0, 0x0), \ + _mm512_extractf32x8_ps(zmm0, 0x1)); \ + ymm1 = _mm256_add_ps(_mm512_extractf32x8_ps(zmm1, 0x0), \ + _mm512_extractf32x8_ps(zmm1, 0x1)); \ + ymm0 = _mm256_hadd_ps(ymm0, ymm1); \ + ymm2 = _mm256_add_ps(_mm512_extractf32x8_ps(zmm2, 0x0), \ + _mm512_extractf32x8_ps(zmm2, 0x1)); \ + ymm3 = _mm256_add_ps(_mm512_extractf32x8_ps(zmm3, 0x0), \ + _mm512_extractf32x8_ps(zmm3, 0x1)); \ + ymm1 = _mm256_hadd_ps(ymm2, ymm3); \ + ymm0 = _mm256_hadd_ps(ymm0, ymm1); \ + xmm0 = _mm_add_ps(_mm256_extractf128_ps(ymm0, 0), _mm256_extractf128_ps(ymm0,1)); + +#ifdef LPGEMM_BF16_JIT +LPGEMV_N_EQ1_KERN(bfloat16, bfloat16, float, bf16bf16f32of32) +{} +#else +LPGEMV_N_EQ1_KERN(bfloat16, bfloat16, float, bf16bf16f32of32) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_6x64_DISABLE, + &&POST_OPS_BIAS_6x64, + &&POST_OPS_RELU_6x64, + &&POST_OPS_RELU_SCALE_6x64, + &&POST_OPS_GELU_TANH_6x64, + &&POST_OPS_GELU_ERF_6x64, + &&POST_OPS_CLIP_6x64, + &&POST_OPS_DOWNSCALE_6x64, + &&POST_OPS_MATRIX_ADD_6x64, + &&POST_OPS_SWISH_6x64, + &&POST_OPS_MATRIX_MUL_6x64 + }; + + // Strides are updated based on matrix packing/reordering. + const bfloat16 *a_use = NULL; + const bfloat16 *b_use = NULL; + float *c_use = NULL; + + lpgemm_post_op_attr post_ops_attr = *(post_op_attr); + + for ( dim_t ir = 0; ir < m0; ir += MR ) + { + dim_t mr0 = bli_min( ( m0 - ir ), MR ); + dim_t k_iter = k/32; + dim_t k_rem = k & 0x1F; + + //Create load mask for k fringe + __mmask32 k1 = 0xFFFFFFFF; + if( k_rem ) + { + k1 = ( 0xFFFFFFFF >> ( 32 - k_rem ) ); + } + + // Create store mask for C for mr fringe + __mmask16 k2 = 0xFFFF; + if ( mr0 < MR ) + { + k2 = ( 0xFFFF >> ( MR - mr0 ) ); + } + + __m512bh zmm0, zmm1, zmm2, zmm3; + __m512bh zmm6; + __m512 zmm8, zmm9, zmm10, zmm11; + __m512 zmm12, zmm13, zmm14, zmm15; + __m512 zmm16, zmm17, zmm18, zmm19; + __m512 zmm20, zmm21, zmm22, zmm23; + __m512bh zmm24, zmm25, zmm26, zmm27; + __m512bh zmm28, zmm29, zmm30, zmm31; + + __m256 ymm0,ymm1,ymm2,ymm3,ymm4,ymm5,ymm6; + __m128 xmm0, xmm1, xmm2, xmm3; + + /* zero the accumulator registers */ + ZERO_ACC_ZMM_4_REG( zmm8, zmm9, zmm10, zmm11 ) + ZERO_ACC_ZMM_4_REG( zmm12, zmm13, zmm14, zmm15 ) + ZERO_ACC_ZMM_4_REG( zmm16, zmm17, zmm18, zmm19 ) + ZERO_ACC_ZMM_4_REG( zmm20, zmm21, zmm22, zmm23 ) + ZERO_ACC_XMM_4_REG( xmm0, xmm1, xmm2, xmm3 ) + //update pointers + a_use = a + ir * rs_a; + b_use = b; + c_use = c + ir * rs_c; + + if( mr0 == MR ) + { + //Dot product kernel + for (dim_t k = 0; k < k_iter; k++) + { + zmm6 = ( __m512bh )_mm512_loadu_epi16( b_use ); + b_use += 32; + + //Load 4x32 elements from row0-row3 of A + LPGEMV_N_KERNEL_4_LOADS( zmm0, zmm1, zmm2, zmm3, a_use, rs_a ) + a_use += ( 4 * rs_a ); + + // Load 4x32 elements from row3-row7 of A + LPGEMV_N_KERNEL_4_LOADS( zmm24, zmm25, zmm26, + zmm27, a_use, rs_a + ) + a_use += ( 4 * rs_a ); + + LPGEMV_N_KERNEL_4_FMA( zmm8, zmm9, zmm10, zmm11, + zmm6, zmm0, zmm1, zmm2, zmm3 + ) + + // Load 4x32 elements from row8-row11 of A + LPGEMV_N_KERNEL_4_LOADS( zmm28, zmm29, zmm30, + zmm31, a_use, rs_a + ) + a_use += ( 4 * rs_a ); + + // Load 4x32 elements from row12-row15 of A + LPGEMV_N_KERNEL_4_LOADS( zmm0, zmm1, zmm2, zmm3, a_use, rs_a ) + a_use -= ( 12 * rs_a ); //Update aptr back to move horizontally + + LPGEMV_N_KERNEL_4_FMA( zmm12, zmm13, zmm14, zmm15, + zmm6, zmm24, zmm25, zmm26, zmm27 + ) + LPGEMV_N_KERNEL_4_FMA( zmm16, zmm17, zmm18, zmm19, + zmm6, zmm28, zmm29, zmm30, zmm31 + ) + LPGEMV_N_KERNEL_4_FMA( zmm20, zmm21, zmm22, zmm23, + zmm6, zmm0, zmm1, zmm2, zmm3 + ) + a_use += 32; + + + } // kloop + if( k_rem ) + { + zmm6 = ( __m512bh )_mm512_maskz_loadu_epi16( k1, b_use ); + + //Load 4x32 elements from row0-row3 of A + LPGEMV_N_KERNEL_4_MASKLOADS( zmm0, zmm1, zmm2, + zmm3, k1, a_use, rs_a + ) + a_use += ( 4 * rs_a ); + + // Load 4x32 elements from row3-row7 of A + LPGEMV_N_KERNEL_4_MASKLOADS( zmm24, zmm25, zmm26, + zmm27, k1, a_use, rs_a + ) + a_use += ( 4 * rs_a ); + + LPGEMV_N_KERNEL_4_FMA( zmm8, zmm9, zmm10, zmm11, + zmm6, zmm0, zmm1, zmm2, zmm3 + ) + + // Load 4x32 elements from row8-row11 of A + LPGEMV_N_KERNEL_4_MASKLOADS( zmm28, zmm29, zmm30, + zmm31, k1, a_use, rs_a + ) + a_use += ( 4 * rs_a ); + + // Load 4x32 elements from row12-row15 of A + LPGEMV_N_KERNEL_4_MASKLOADS( zmm0, zmm1, zmm2, + zmm3, k1, a_use, rs_a + ) + a_use -= ( 12 * rs_a ); //Update aptr back to move horizontally + + + LPGEMV_N_KERNEL_4_FMA( zmm12, zmm13, zmm14, zmm15, + zmm6, zmm24, zmm25, zmm26, zmm27 + ) + LPGEMV_N_KERNEL_4_FMA( zmm16, zmm17, zmm18, zmm19, + zmm6, zmm28, zmm29, zmm30, zmm31 + ) + LPGEMV_N_KERNEL_4_FMA( zmm20, zmm21, zmm22, zmm23, + zmm6, zmm0, zmm1, zmm2, zmm3 + ) + a_use += 32; + + } + + //Add the registers horizantally to get one + LPGEMV_ZMM2XMM( zmm8, zmm9, zmm10, zmm11, + ymm0, ymm1, ymm2, ymm3, xmm0 + ) + LPGEMV_ZMM2XMM( zmm12, zmm13, zmm14, zmm15, + ymm4, ymm1, ymm2, ymm3, xmm1 + ) + LPGEMV_ZMM2XMM( zmm16, zmm17, zmm18, zmm19, + ymm5, ymm1, ymm2, ymm3, xmm2 + ) + LPGEMV_ZMM2XMM( zmm20, zmm21, zmm22, zmm23, + ymm6, ymm1, ymm2, ymm3, xmm3 + ) + + //compose outputs into one zmm to perform post-ops + zmm8 = _mm512_insertf32x4( zmm8, xmm0, 0 ); + zmm8 = _mm512_insertf32x4( zmm8, xmm1, 1 ); + zmm8 = _mm512_insertf32x4( zmm8, xmm2, 2 ); + zmm8 = _mm512_insertf32x4( zmm8, xmm3, 3 ); + } + else + { + //Handle fringe cases when mr0 < MR + const bfloat16 *a_use_fringe = a_use; + dim_t mr0_use = mr0; + dim_t regidx = 0; + + // Dot product for mfringe 8 + if ( mr0_use >= 8 ) + { + // Dot product kernel for mr0 == 8 + for( dim_t k = 0; k < k_iter; k++ ) + { + // Load 0-31 in b[k+0 - k+31] + zmm6 = ( __m512bh )_mm512_loadu_epi16( b_use ); + // move b pointer to next 32 elements + b_use += 32; + + // Load 4x32 elements from row0-row3 of A + LPGEMV_N_KERNEL_4_LOADS( zmm0, zmm1, zmm2, + zmm3, a_use, rs_a + ) + a_use += ( 4 * rs_a ); + + // Load 4x32 elements from row3-row7 of A + LPGEMV_N_KERNEL_4_LOADS( zmm24, zmm25, zmm26, + zmm27, a_use, rs_a + ) + a_use -= ( 4 * rs_a ); + + //Perform FMA on two 4x16 block of A with 16x1 + LPGEMV_N_KERNEL_4_FMA( zmm8, zmm9, zmm10, zmm11, + zmm6, zmm0, zmm1, zmm2, zmm3 + ) + LPGEMV_N_KERNEL_4_FMA( zmm12, zmm13, zmm14, zmm15, + zmm6, zmm24, zmm25, zmm26, zmm27 + ) + a_use += 32; + } + + if ( k_rem ) + { + // Load 0-31 in b[k+0 - k+31] + zmm6 = ( __m512bh )_mm512_maskz_loadu_epi16( k1, b_use ); + + // Load 4x32 elements from row0-row3 of A + LPGEMV_N_KERNEL_4_MASKLOADS( zmm0, zmm1, zmm2, + zmm3, k1, a_use, rs_a + ) + a_use += ( 4 * rs_a ); + LPGEMV_N_KERNEL_4_MASKLOADS( zmm24, zmm25, zmm26, + zmm27, k1, a_use, rs_a + ) + LPGEMV_N_KERNEL_4_FMA( zmm8, zmm9, zmm10, zmm11, + zmm6, zmm0, zmm1, zmm2, zmm3 + ) + LPGEMV_N_KERNEL_4_FMA( zmm12, zmm13, zmm14, zmm15, + zmm6, zmm24, zmm25, zmm26, zmm27 + ) + } + + //update pointers + mr0_use -= 8; + a_use = a_use_fringe + 8 * rs_a; + a_use_fringe = a_use; + b_use = b; + + //Horizontal add 8 zmm registers and get output into 2 xmm registers + LPGEMV_ZMM2XMM( zmm8, zmm9, zmm10, zmm11, + ymm0, ymm1, ymm2, ymm3, xmm0 + ) + LPGEMV_ZMM2XMM( zmm12, zmm13, zmm14, zmm15, + ymm4, ymm1, ymm2, ymm3, xmm1 + ) + + //insert xmm outputs into final output zmm8 reg + zmm8 = _mm512_insertf32x4( zmm8, xmm0, 0 ); + zmm8 = _mm512_insertf32x4( zmm8, xmm1, 1 ); + regidx = 2; + } + + // Dot product for mfringe 4 + if ( mr0_use >= 4 ) + { + // Dot product kernel for mr0 == 8 + for ( dim_t k = 0; k < k_iter; k++ ) + { + // Load 0-31 in b[k+0 - k+31] + zmm6 = ( __m512bh )_mm512_loadu_epi16( b_use ); + + // move b pointer to next 32 elements + b_use += 32; + + // Load 4x32 elements from row0-row3 of A + LPGEMV_N_KERNEL_4_LOADS( zmm0, zmm1, zmm2, + zmm3, a_use, rs_a + ) + // Perform FMA on 4x32 block of A with 16x1 + LPGEMV_N_KERNEL_4_FMA( zmm16, zmm17, zmm18, zmm19, + zmm6, zmm0, zmm1, zmm2, zmm3 + ) + a_use += 32; + } + + if ( k_rem ) + { + // Load 0-31 in b[k+0 - k+31] + zmm6 = ( __m512bh )_mm512_maskz_loadu_epi16( k1, b_use ); + + // Load 4x32 elements from row0-row3 of A + LPGEMV_N_KERNEL_4_MASKLOADS( zmm0, zmm1, zmm2, + zmm3, k1, a_use, rs_a + ) + LPGEMV_N_KERNEL_4_FMA( zmm16, zmm17, zmm18, zmm19, + zmm6, zmm0, zmm1, zmm2, zmm3 + ) + } + + //update pointers + mr0_use -= 4; + a_use = a_use_fringe + 4 * rs_a; + a_use_fringe = a_use; + b_use = b; + + //Horizontal add 4 zmm reg and get the output into one xmm + LPGEMV_ZMM2XMM( zmm16, zmm17, zmm18, zmm19, + ymm5, ymm1, ymm2, ymm3, xmm2 + ) + + //insert xmm outputs into final output zmm8 reg based on regidx + if( regidx == 0 ) zmm8 = _mm512_insertf32x4( zmm8, xmm2, 0 ); + else zmm8 = _mm512_insertf32x4( zmm8, xmm2, 2 ); + regidx++; + } + + // Dot product for <= 3 + if ( mr0_use ) + { + // Dot product for m = 2 + if ( mr0_use >= 2 ) + { + for ( dim_t k = 0; k < k_iter; k++ ) + { + // Load 0-31 in b[k+0 - k+31] + zmm6 = ( __m512bh )_mm512_loadu_epi16( b_use ); + + // Load 2x32 elements from row0-row1 of A + zmm0 = ( __m512bh )_mm512_loadu_epi16( a_use ); + zmm1 = ( __m512bh )_mm512_loadu_epi16( a_use + rs_a ); + zmm20 = _mm512_dpbf16_ps( zmm20, zmm6, zmm0 ); + zmm21 = _mm512_dpbf16_ps( zmm21, zmm6, zmm1 ); + + b_use += 32; // move b pointer to next 32 elements + a_use += 32; + } + if ( k_rem ) + { + // Load 0-31 in b[k+0 - k+31] + zmm6 = ( __m512bh )_mm512_maskz_loadu_epi16( k1, b_use ); + zmm0 = ( __m512bh )_mm512_maskz_loadu_epi16( k1, a_use ); + zmm1 = ( __m512bh )_mm512_maskz_loadu_epi16( k1, a_use + rs_a ); + zmm20 = _mm512_dpbf16_ps( zmm20, zmm6, zmm0 ); + zmm21 = _mm512_dpbf16_ps( zmm21, zmm6, zmm1 ); + } + mr0_use -= 2; + a_use = a_use_fringe + 2 * rs_a; + a_use_fringe = a_use; + b_use = b; + } + + // Dot product for m = 2 + if ( mr0_use == 1 ) + { + for ( dim_t k = 0; k < k_iter; k++ ) + { + // Load 0-31 in b[k+0 - k+15] + zmm6 = ( __m512bh )_mm512_loadu_epi16( b_use ); + zmm0 = ( __m512bh )_mm512_loadu_epi16( a_use ); + zmm22 = _mm512_dpbf16_ps( zmm22, zmm6, zmm0 ); + b_use += 32; // move b pointer to next 32 elements + a_use += 32; + } + + if ( k_rem ) + { + zmm6 = ( __m512bh )_mm512_maskz_loadu_epi16( k1, b_use ); + zmm0 = ( __m512bh )_mm512_maskz_loadu_epi16( k1, a_use ); + zmm22 = _mm512_dpbf16_ps( zmm22, zmm6, zmm0 ); + } + // When only fringe 1, update the registers to store in order + if ( !( mr0 & 0x2 ) ) zmm20 = zmm22; + } + + // Horizontal add 4 zmm reg and get the output into one xmm + LPGEMV_ZMM2XMM( zmm20, zmm21, zmm22, zmm23, + ymm6, ymm1, ymm2, ymm3, xmm3 + ) + + // insert xmm outputs into final output zmm8 reg based on regidx + if( regidx == 0 ) + { + zmm8 = _mm512_insertf32x4( zmm8, xmm3, 0 ); + } + else if( regidx == 1 ) + { + zmm8 = _mm512_insertf32x4( zmm8, xmm3, 1 ); + } + else if ( regidx == 2 ) + { + zmm8 = _mm512_insertf32x4( zmm8, xmm3, 2 ); + } + else + { + zmm8 = _mm512_insertf32x4( zmm8, xmm3, 3 ); + } + } + } + + //Scale accumulated output with alpha + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + //Mulitply A*B output with alpha + zmm8 = _mm512_mul_ps( selector1, zmm8 ); + + if ( beta != 0 ) + { + + // For the downscaled api (C-bf16), the output C matrix values + // needs to be upscaled to float to be used for beta scale. + if ( post_ops_attr.buf_downscale != NULL ) + { + if( post_ops_attr.rs_c_downscale == 1 ) + { + BF16_F32_BETA_OP_NLT16F_MASK( k2, zmm8, 0, 0, + selector1, selector2 ) + } + else + { + bfloat16 ctemp[16]; + for( dim_t i = 0; i < mr0; i++ ) + { + ctemp[i] = *( ( bfloat16* )post_ops_attr.buf_downscale + + ( post_ops_attr.rs_c_downscale * + ( post_ops_attr.post_op_c_i + i ) ) ); + } + selector1 = (__m512)( _mm512_sllv_epi32( _mm512_cvtepi16_epi32 + ( (__m256i)_mm256_loadu_epi16( ctemp ) ), + _mm512_set1_epi32 (16) ) ); + F32_BETA_FMA(zmm8,selector1,selector2) + } + } + else + { + if( rs_c == 1 ) + { + F32_F32_BETA_OP_NLT16F_MASK( c_use, k2, zmm8, 0, 0, 0, + selector1, selector2 ) + } + else + { + float ctemp[16]; + for( dim_t i = 0; i < mr0; i++ ) + { + ctemp[i] = c_use[i*rs_c]; + } + + selector1 = _mm512_loadu_ps( ctemp ); + F32_BETA_FMA( zmm8, selector1, selector2 ); + } + } + } + + // Post Ops + lpgemm_post_op *post_ops_list_temp = post_op; + + post_ops_attr.is_last_k = TRUE; + POST_OP_LABEL_LASTK_SAFE_JUMP + + POST_OPS_BIAS_6x64: + { + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + if ( post_ops_attr.c_stor_type == BF16 ) + { + selector1 = + (__m512)( _mm512_sllv_epi32 + ( + _mm512_cvtepi16_epi32 + ( + _mm256_maskz_set1_epi16 + ( + _cvtu32_mask16( 0xFFFF ), + *( ( bfloat16* )post_ops_list_temp->op_args1 ) + ) + ), _mm512_set1_epi32( 16 ) + ) + ); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1) ); + } + zmm8 = _mm512_add_ps( selector1, zmm8 ); + } + else + { + if ( post_ops_attr.c_stor_type == BF16 ) + { + selector1 = + (__m512)( _mm512_sllv_epi32 + ( + _mm512_cvtepi16_epi32 + ( + _mm256_maskz_loadu_epi16 + ( + k2, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + ) + ), _mm512_set1_epi32( 16 ) + ) + ); + } + else + { + selector1 = + _mm512_maskz_loadu_ps( k2, + (float*)post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i ); + } + + zmm8 = _mm512_add_ps( selector1, zmm8 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_RELU_6x64: + { + selector1 = _mm512_setzero_ps(); + + zmm8 = _mm512_max_ps( selector1, zmm8 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_RELU_SCALE_6x64: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512( zmm8 ) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_GELU_TANH_6x64: + { + __m512 dn, z, x, r2, r, x_tanh; + __m512i q; + + GELU_TANH_F32_AVX512( zmm8, r, r2, x, z, dn, x_tanh, q ) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_GELU_ERF_6x64: + { + __m512 x, r, x_erf; + + GELU_ERF_F32_AVX512( zmm8, r, x, x_erf ) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_CLIP_6x64: + { + __m512 min = + _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 ); + __m512 max = + _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 ); + + CLIP_F32_AVX512( zmm8, min, max ) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_DOWNSCALE_6x64: + { + __m512 zero_point0 = _mm512_setzero_ps(); + + __mmask16 zp_mask = _cvtu32_mask16( 0xFFFF ); + + // Need to account for row vs column major swaps. For scalars + // scale and zero point, no implications. + // Even though different registers are used for scalar in column + // and row major downscale path, all those registers will contain + // the same value. + if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // bf16 zero point value (scalar or vector). + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, + *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); + } + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + // Scale/zp len cannot be > 1, since orignal n = 1. + SCL_MULRND_F32(zmm8,selector1,zero_point0); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the scale as well as zp array will + // be accessed by the ic index, and each scale/zp element + // corresponds to an entire row of the transposed output array, + // instead of an entire column. + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_maskz_loadu_ps( k2, + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_i + 0 ); + } + + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( k2, + ( ( bfloat16* )post_ops_list_temp->op_args1 ) + + post_ops_attr.post_op_c_i + 0 ) ); + } + SCL_MULRND_F32(zmm8,selector1,zero_point0); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_MATRIX_ADD_6x64: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + if( ldm == 1 ) + { + BF16_F32_MATRIX_ADD_LOAD(k2,selector1,0,0) + + zmm8 = _mm512_add_ps( selector1, zmm8 ); + } + else + { + bfloat16 ctemp[16]; + for( dim_t i = 0; i < mr0; i++ ) + { + ctemp[i] = *( matptr + + ( ( post_ops_attr.post_op_c_i + i ) + * ldm ) ); + } + selector1 = (__m512)( _mm512_sllv_epi32 \ + ( \ + _mm512_cvtepi16_epi32 \ + ( \ + _mm256_maskz_loadu_epi16 \ + ( \ + k2 , ctemp \ + ) \ + ), _mm512_set1_epi32( 16 ) \ + ) \ + ); \ + zmm8 = _mm512_add_ps( selector1, zmm8 ); + } + } + else + { + + float* matptr = ( float* )post_ops_list_temp->op_args1; + + if( ldm == 1 ) + { + F32_F32_MATRIX_ADD_LOAD(k2,selector1,0,0) + zmm8 = _mm512_add_ps( selector1, zmm8 ); + } + else + { + float ctemp[16]; + for( dim_t i = 0; i < mr0; i++ ) + { + ctemp[i] = *( matptr + + ( ( post_ops_attr.post_op_c_i + i ) + * ldm ) ); + } + selector1 = _mm512_maskz_loadu_ps( k2, ctemp ); + zmm8 = _mm512_add_ps( selector1, zmm8 ); + } + + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_MATRIX_MUL_6x64: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + + if ( post_ops_attr.c_stor_type == BF16 ) + { + bfloat16* matptr = ( bfloat16* )post_ops_list_temp->op_args1; + + if( ldm == 1 ) + { + BF16_F32_MATRIX_MUL_LOAD(k2,selector1,0,0) + + zmm8 = _mm512_mul_ps( selector1, zmm8 ); + } + else + { + bfloat16 ctemp[16]; + for( dim_t i = 0; i < mr0; i++ ) + { + ctemp[i] = *( matptr + + ( ( post_ops_attr.post_op_c_i + i ) + * ldm ) ); + } + selector1 = (__m512)( _mm512_sllv_epi32 \ + ( \ + _mm512_cvtepi16_epi32 \ + ( \ + _mm256_maskz_loadu_epi16 \ + ( \ + k2 , ctemp \ + ) \ + ), _mm512_set1_epi32( 16 ) \ + ) \ + ); \ + zmm8 = _mm512_mul_ps( selector1, zmm8 ); + } + } + else + { + + float* matptr = ( float* )post_ops_list_temp->op_args1; + + if( ldm == 1 ) + { + F32_F32_MATRIX_MUL_LOAD(k2,selector1,0,0) + zmm8 = _mm512_mul_ps( selector1, zmm8 ); + } + else + { + float ctemp[16]; + for( dim_t i = 0; i < mr0; i++ ) + { + ctemp[i] = *( matptr + + ( ( post_ops_attr.post_op_c_i + i ) + * ldm ) ); + } + selector1 = _mm512_maskz_loadu_ps( k2, ctemp ); + zmm8 = _mm512_mul_ps( selector1, zmm8 ); + } + + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_SWISH_6x64: + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + + SWISH_F32_AVX512_DEF( zmm8, selector1, al_in, + r, r2, z, dn, ex_out ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_6x64_DISABLE: + { + // Case where the output C matrix is bf16 (downscaled) and + // this is the final write for a given block within C. + if ( post_ops_attr.buf_downscale != NULL ) + { + if( post_ops_attr.rs_c_downscale == 1 ) + { + _mm256_mask_storeu_epi16 + ( + ( bfloat16* )post_ops_attr.buf_downscale + + post_ops_attr.post_op_c_i, + k2, (__m256i) _mm512_cvtneps_pbh( zmm8 ) + ); + } + else + { + bfloat16 ctemp[16]; + _mm256_mask_storeu_epi16 + ( + ctemp, + k2, (__m256i) _mm512_cvtneps_pbh( zmm8 ) + ); + for (dim_t i = 0; i < mr0; i++) + { + *( ( bfloat16* )post_ops_attr.buf_downscale + + ( post_ops_attr.rs_c_downscale * + ( post_ops_attr.post_op_c_i + i ) ) ) = ctemp[i]; + } + } + } + else + { + if(rs_c == 1) + { + _mm512_mask_storeu_ps(c_use, k2, zmm8); + } + else + { + // Store ZMM8 into ctemp buffer and store back + // element by element into output buffer at strides + float ctemp[16]; + _mm512_mask_storeu_ps(ctemp, k2, zmm8); + for (dim_t i = 0; i < mr0; i++) + { + c_use[i * rs_c] = ctemp[i]; + } + } + } + post_ops_attr.post_op_c_i += MR; + } + } +} +#endif // LPGEMM_BF16_JIT +#endif // BLIS_ADDON_LPGEMM diff --git a/kernels/zen4/lpgemm/f32f32f32/lpgemm_eltwise_ops_fringe_f32_avx512.c b/kernels/zen4/lpgemm/f32f32f32/lpgemm_eltwise_ops_fringe_f32_avx512.c new file mode 100644 index 0000000000..0925d38542 --- /dev/null +++ b/kernels/zen4/lpgemm/f32f32f32/lpgemm_eltwise_ops_fringe_f32_avx512.c @@ -0,0 +1,2851 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "blis.h" + +#ifdef BLIS_ADDON_LPGEMM + +#include "lpgemm_kernel_macros_f32.h" + +LPGEMM_ELTWISE_OPS_M_FRINGE_KERNEL(float,float,f32of32_5x64) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_5x64_OPS_DISABLE, + &&POST_OPS_BIAS_5x64_OPS, + &&POST_OPS_RELU_5x64_OPS, + &&POST_OPS_RELU_SCALE_5x64_OPS, + &&POST_OPS_GELU_TANH_5x64_OPS, + &&POST_OPS_GELU_ERF_5x64_OPS, + &&POST_OPS_CLIP_5x64_OPS, + NULL,// Virtual node for downscale, else segfault + &&POST_OPS_MATRIX_ADD_5x64_OPS, + &&POST_OPS_SWISH_5x64_OPS, + &&POST_OPS_MATRIX_MUL_5x64_OPS + }; + dim_t NR = 64; + + // Registers to use for accumulating C. + __m512 zmm8 = _mm512_setzero_ps(); + __m512 zmm9 = _mm512_setzero_ps(); + __m512 zmm10 = _mm512_setzero_ps(); + __m512 zmm11 = _mm512_setzero_ps(); + + __m512 zmm12 = _mm512_setzero_ps(); + __m512 zmm13 = _mm512_setzero_ps(); + __m512 zmm14 = _mm512_setzero_ps(); + __m512 zmm15 = _mm512_setzero_ps(); + + __m512 zmm16 = _mm512_setzero_ps(); + __m512 zmm17 = _mm512_setzero_ps(); + __m512 zmm18 = _mm512_setzero_ps(); + __m512 zmm19 = _mm512_setzero_ps(); + + __m512 zmm20 = _mm512_setzero_ps(); + __m512 zmm21 = _mm512_setzero_ps(); + __m512 zmm22 = _mm512_setzero_ps(); + __m512 zmm23 = _mm512_setzero_ps(); + + __m512 zmm24 = _mm512_setzero_ps(); + __m512 zmm25 = _mm512_setzero_ps(); + __m512 zmm26 = _mm512_setzero_ps(); + __m512 zmm27 = _mm512_setzero_ps(); + + __m512 zmm1 = _mm512_setzero_ps(); + __m512 zmm2 = _mm512_setzero_ps(); + __m512 zmm3 = _mm512_setzero_ps(); + __m512 zmm4 = _mm512_setzero_ps(); + + __mmask16 k0 = 0xFFFF, k1 = 0xFFFF, k2 = 0xFFFF, k3 = 0xFFFF; + + dim_t NR_L = NR; + for( dim_t jr = 0; jr < n0; jr += NR_L ) + { + dim_t n_left = n0 - jr; + NR_L = bli_min( NR_L, ( n_left >> 4 ) << 4 ); + if( NR_L == 0 ) { NR_L = 16; } + + dim_t nr0 = bli_min( n0 - jr, NR_L ); + if( nr0 == 64 ) + { + // all masks are already set. + // Nothing to modify. + } + else if( nr0 == 48 ) + { + k3 = 0x0; + } + else if( nr0 == 32 ) + { + k2 = k3 = 0x0; + } + else if( nr0 == 16 ) + { + k1 = k2 = k3 = 0; + } + else if( nr0 < 16 ) + { + k0 = (0xFFFF >> (16 - (nr0 & 0x0F))); + k1 = k2 = k3 = 0; + } + // 1stx64 block. + zmm8 = _mm512_maskz_loadu_ps( k0, a + ( rs_a * ( 0 ) ) + ( cs_a * ( jr + 0 ) ) ); + zmm9 = _mm512_maskz_loadu_ps( k1, a + ( rs_a * ( 0 ) ) + ( cs_a * ( jr + 16 ) ) ); + zmm10 = _mm512_maskz_loadu_ps( k2, a + ( rs_a * ( 0 ) ) + ( cs_a * ( jr + 32 ) ) ); + zmm11 = _mm512_maskz_loadu_ps( k3, a + ( rs_a * ( 0 ) ) + ( cs_a * ( jr + 48 ) ) ); + + // 2ndx64 block. + zmm12 = _mm512_maskz_loadu_ps( k0, a + ( rs_a * ( 1 ) ) + ( cs_a * ( jr + 0 ) ) ); + zmm13 = _mm512_maskz_loadu_ps( k1, a + ( rs_a * ( 1 ) ) + ( cs_a * ( jr + 16 ) ) ); + zmm14 = _mm512_maskz_loadu_ps( k2, a + ( rs_a * ( 1 ) ) + ( cs_a * ( jr + 32 ) ) ); + zmm15 = _mm512_maskz_loadu_ps( k3, a + ( rs_a * ( 1 ) ) + ( cs_a * ( jr + 48 ) ) ); + + // 3rdx64 block. + zmm16 = _mm512_maskz_loadu_ps( k0, a + ( rs_a * ( 2 ) ) + ( cs_a * ( jr + 0 ) ) ); + zmm17 = _mm512_maskz_loadu_ps( k1, a + ( rs_a * ( 2 ) ) + ( cs_a * ( jr + 16 ) ) ); + zmm18 = _mm512_maskz_loadu_ps( k2, a + ( rs_a * ( 2 ) ) + ( cs_a * ( jr + 32 ) ) ); + zmm19 = _mm512_maskz_loadu_ps( k3, a + ( rs_a * ( 2 ) ) + ( cs_a * ( jr + 48 ) ) ); + + // 4thx64 block. + zmm20 = _mm512_maskz_loadu_ps( k0, a + ( rs_a * ( 3 ) ) + ( cs_a * ( jr + 0 ) ) ); + zmm21 = _mm512_maskz_loadu_ps( k1, a + ( rs_a * ( 3 ) ) + ( cs_a * ( jr + 16 ) ) ); + zmm22 = _mm512_maskz_loadu_ps( k2, a + ( rs_a * ( 3 ) ) + ( cs_a * ( jr + 32 ) ) ); + zmm23 = _mm512_maskz_loadu_ps( k3, a + ( rs_a * ( 3 ) ) + ( cs_a * ( jr + 48 ) ) ); + + // 5thx64 block. + zmm24 = _mm512_maskz_loadu_ps( k0, a + ( rs_a * ( 4 ) ) + ( cs_a * ( jr + 0 ) ) ); + zmm25 = _mm512_maskz_loadu_ps( k1, a + ( rs_a * ( 4 ) ) + ( cs_a * ( jr + 16 ) ) ); + zmm26 = _mm512_maskz_loadu_ps( k2, a + ( rs_a * ( 4 ) ) + ( cs_a * ( jr + 32 ) ) ); + zmm27 = _mm512_maskz_loadu_ps( k3, a + ( rs_a * ( 4 ) ) + ( cs_a * ( jr + 48 ) ) ); + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP + +POST_OPS_BIAS_5x64_OPS: + { + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + zmm1 =_mm512_maskz_loadu_ps( k0, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + zmm2 =_mm512_maskz_loadu_ps( k1, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + zmm3 =_mm512_maskz_loadu_ps( k2, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + zmm4 =_mm512_maskz_loadu_ps( k3, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + + // c[0,0-15] + zmm8 = _mm512_add_ps( zmm1, zmm8 ); + + // c[0, 16-31] + zmm9 = _mm512_add_ps( zmm2, zmm9 ); + + // c[0,32-47] + zmm10 = _mm512_add_ps( zmm3, zmm10 ); + + // c[0,48-63] + zmm11 = _mm512_add_ps( zmm4, zmm11 ); + + // c[1,0-15] + zmm12 = _mm512_add_ps( zmm1, zmm12 ); + + // c[1, 16-31] + zmm13 = _mm512_add_ps( zmm2, zmm13 ); + + // c[1,32-47] + zmm14 = _mm512_add_ps( zmm3, zmm14 ); + + // c[1,48-63] + zmm15 = _mm512_add_ps( zmm4, zmm15 ); + + // c[2,0-15] + zmm16 = _mm512_add_ps( zmm1, zmm16 ); + + // c[2, 16-31] + zmm17 = _mm512_add_ps( zmm2, zmm17 ); + + // c[2,32-47] + zmm18 = _mm512_add_ps( zmm3, zmm18 ); + + // c[2,48-63] + zmm19 = _mm512_add_ps( zmm4, zmm19 ); + + // c[3,0-15] + zmm20 = _mm512_add_ps( zmm1, zmm20 ); + + // c[3, 16-31] + zmm21 = _mm512_add_ps( zmm2, zmm21 ); + + // c[3,32-47] + zmm22 = _mm512_add_ps( zmm3, zmm22 ); + + // c[3,48-63] + zmm23 = _mm512_add_ps( zmm4, zmm23 ); + + // c[4,0-15] + zmm24 = _mm512_add_ps( zmm1, zmm24 ); + + // c[4, 16-31] + zmm25 = _mm512_add_ps( zmm2, zmm25 ); + + // c[4,32-47] + zmm26 = _mm512_add_ps( zmm3, zmm26 ); + + // c[4,48-63] + zmm27 = _mm512_add_ps( zmm4, zmm27 ); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the bias array will be accessed by + // the ic index, and each bias element corresponds to an + // entire row of the transposed output array, instead of an + // entire column. + __m512 selector5; + + zmm1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + zmm2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 1 ) ); + zmm3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 2 ) ); + zmm4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 3 ) ); + selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 4 ) ); + + // c[0,0-15] + zmm8 = _mm512_add_ps( zmm1, zmm8 ); + + // c[0, 16-31] + zmm9 = _mm512_add_ps( zmm1, zmm9 ); + + // c[0,32-47] + zmm10 = _mm512_add_ps( zmm1, zmm10 ); + + // c[0,48-63] + zmm11 = _mm512_add_ps( zmm1, zmm11 ); + + // c[1,0-15] + zmm12 = _mm512_add_ps( zmm2, zmm12 ); + + // c[1, 16-31] + zmm13 = _mm512_add_ps( zmm2, zmm13 ); + + // c[1,32-47] + zmm14 = _mm512_add_ps( zmm2, zmm14 ); + + // c[1,48-63] + zmm15 = _mm512_add_ps( zmm2, zmm15 ); + + // c[2,0-15] + zmm16 = _mm512_add_ps( zmm3, zmm16 ); + + // c[2, 16-31] + zmm17 = _mm512_add_ps( zmm3, zmm17 ); + + // c[2,32-47] + zmm18 = _mm512_add_ps( zmm3, zmm18 ); + + // c[2,48-63] + zmm19 = _mm512_add_ps( zmm3, zmm19 ); + + // c[3,0-15] + zmm20 = _mm512_add_ps( zmm4, zmm20 ); + + // c[3, 16-31] + zmm21 = _mm512_add_ps( zmm4, zmm21 ); + + // c[3,32-47] + zmm22 = _mm512_add_ps( zmm4, zmm22 ); + + // c[3,48-63] + zmm23 = _mm512_add_ps( zmm4, zmm23 ); + + // c[4,0-15] + zmm24 = _mm512_add_ps( selector5, zmm24 ); + + // c[4, 16-31] + zmm25 = _mm512_add_ps( selector5, zmm25 ); + + // c[4,32-47] + zmm26 = _mm512_add_ps( selector5, zmm26 ); + + // c[4,48-63] + zmm27 = _mm512_add_ps( selector5, zmm27 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_5x64_OPS: + { + zmm1 = _mm512_setzero_ps(); + + // c[0,0-15] + zmm8 = _mm512_max_ps( zmm1, zmm8 ); + + // c[0, 16-31] + zmm9 = _mm512_max_ps( zmm1, zmm9 ); + + // c[0,32-47] + zmm10 = _mm512_max_ps( zmm1, zmm10 ); + + // c[0,48-63] + zmm11 = _mm512_max_ps( zmm1, zmm11 ); + + // c[1,0-15] + zmm12 = _mm512_max_ps( zmm1, zmm12 ); + + // c[1,16-31] + zmm13 = _mm512_max_ps( zmm1, zmm13 ); + + // c[1,32-47] + zmm14 = _mm512_max_ps( zmm1, zmm14 ); + + // c[1,48-63] + zmm15 = _mm512_max_ps( zmm1, zmm15 ); + + // c[2,0-15] + zmm16 = _mm512_max_ps( zmm1, zmm16 ); + + // c[2,16-31] + zmm17 = _mm512_max_ps( zmm1, zmm17 ); + + // c[2,32-47] + zmm18 = _mm512_max_ps( zmm1, zmm18 ); + + // c[2,48-63] + zmm19 = _mm512_max_ps( zmm1, zmm19 ); + + // c[3,0-15] + zmm20 = _mm512_max_ps( zmm1, zmm20 ); + + // c[3,16-31] + zmm21 = _mm512_max_ps( zmm1, zmm21 ); + + // c[3,32-47] + zmm22 = _mm512_max_ps( zmm1, zmm22 ); + + // c[3,48-63] + zmm23 = _mm512_max_ps( zmm1, zmm23 ); + + // c[4,0-15] + zmm24 = _mm512_max_ps( zmm1, zmm24 ); + + // c[4,16-31] + zmm25 = _mm512_max_ps( zmm1, zmm25 ); + + // c[4,32-47] + zmm26 = _mm512_max_ps( zmm1, zmm26 ); + + // c[4,48-63] + zmm27 = _mm512_max_ps( zmm1, zmm27 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_5x64_OPS: + { + zmm1 = _mm512_setzero_ps(); + zmm2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32S_AVX512(zmm8) + + // c[0, 16-31] + RELU_SCALE_OP_F32S_AVX512(zmm9) + + // c[0, 32-47] + RELU_SCALE_OP_F32S_AVX512(zmm10) + + // c[0, 48-63] + RELU_SCALE_OP_F32S_AVX512(zmm11) + + // c[1, 0-15] + RELU_SCALE_OP_F32S_AVX512(zmm12) + + // c[1, 16-31] + RELU_SCALE_OP_F32S_AVX512(zmm13) + + // c[1, 32-47] + RELU_SCALE_OP_F32S_AVX512(zmm14) + + // c[1, 48-63] + RELU_SCALE_OP_F32S_AVX512(zmm15) + + // c[2, 0-15] + RELU_SCALE_OP_F32S_AVX512(zmm16) + + // c[2, 16-31] + RELU_SCALE_OP_F32S_AVX512(zmm17) + + // c[2, 32-47] + RELU_SCALE_OP_F32S_AVX512(zmm18) + + // c[2, 48-63] + RELU_SCALE_OP_F32S_AVX512(zmm19) + + // c[3, 0-15] + RELU_SCALE_OP_F32S_AVX512(zmm20) + + // c[3, 16-31] + RELU_SCALE_OP_F32S_AVX512(zmm21) + + // c[3, 32-47] + RELU_SCALE_OP_F32S_AVX512(zmm22) + + // c[3, 48-63] + RELU_SCALE_OP_F32S_AVX512(zmm23) + + // c[4, 0-15] + RELU_SCALE_OP_F32S_AVX512(zmm24) + + // c[4, 16-31] + RELU_SCALE_OP_F32S_AVX512(zmm25) + + // c[4, 32-47] + RELU_SCALE_OP_F32S_AVX512(zmm26) + + // c[4, 48-63] + RELU_SCALE_OP_F32S_AVX512(zmm27) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_TANH_5x64_OPS: + { + __m512 dn, z, x, r2, r, x_tanh; + __m512i q; + + // c[0, 0-15] + GELU_TANH_F32S_AVX512(zmm8, r, r2, x, z, dn, x_tanh, q) + + // c[0, 16-31] + GELU_TANH_F32S_AVX512(zmm9, r, r2, x, z, dn, x_tanh, q) + + // c[0, 32-47] + GELU_TANH_F32S_AVX512(zmm10, r, r2, x, z, dn, x_tanh, q) + + // c[0, 48-63] + GELU_TANH_F32S_AVX512(zmm11, r, r2, x, z, dn, x_tanh, q) + + // c[1, 0-15] + GELU_TANH_F32S_AVX512(zmm12, r, r2, x, z, dn, x_tanh, q) + + // c[1, 16-31] + GELU_TANH_F32S_AVX512(zmm13, r, r2, x, z, dn, x_tanh, q) + + // c[1, 32-47] + GELU_TANH_F32S_AVX512(zmm14, r, r2, x, z, dn, x_tanh, q) + + // c[1, 48-63] + GELU_TANH_F32S_AVX512(zmm15, r, r2, x, z, dn, x_tanh, q) + + // c[2, 0-15] + GELU_TANH_F32S_AVX512(zmm16, r, r2, x, z, dn, x_tanh, q) + + // c[2, 16-31] + GELU_TANH_F32S_AVX512(zmm17, r, r2, x, z, dn, x_tanh, q) + + // c[2, 32-47] + GELU_TANH_F32S_AVX512(zmm18, r, r2, x, z, dn, x_tanh, q) + + // c[2, 48-63] + GELU_TANH_F32S_AVX512(zmm19, r, r2, x, z, dn, x_tanh, q) + + // c[3, 0-15] + GELU_TANH_F32S_AVX512(zmm20, r, r2, x, z, dn, x_tanh, q) + + // c[3, 16-31] + GELU_TANH_F32S_AVX512(zmm21, r, r2, x, z, dn, x_tanh, q) + + // c[3, 32-47] + GELU_TANH_F32S_AVX512(zmm22, r, r2, x, z, dn, x_tanh, q) + + // c[3, 48-63] + GELU_TANH_F32S_AVX512(zmm23, r, r2, x, z, dn, x_tanh, q) + + // c[4, 0-15] + GELU_TANH_F32S_AVX512(zmm24, r, r2, x, z, dn, x_tanh, q) + + // c[4, 16-31] + GELU_TANH_F32S_AVX512(zmm25, r, r2, x, z, dn, x_tanh, q) + + // c[4, 32-47] + GELU_TANH_F32S_AVX512(zmm26, r, r2, x, z, dn, x_tanh, q) + + // c[4, 48-63] + GELU_TANH_F32S_AVX512(zmm27, r, r2, x, z, dn, x_tanh, q) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_ERF_5x64_OPS: + { + __m512 x, r, x_erf; + + // c[0, 0-15] + GELU_ERF_F32S_AVX512(zmm8, r, x, x_erf) + + // c[0, 16-31] + GELU_ERF_F32S_AVX512(zmm9, r, x, x_erf) + + // c[0, 32-47] + GELU_ERF_F32S_AVX512(zmm10, r, x, x_erf) + + // c[0, 48-63] + GELU_ERF_F32S_AVX512(zmm11, r, x, x_erf) + + // c[1, 0-15] + GELU_ERF_F32S_AVX512(zmm12, r, x, x_erf) + + // c[1, 16-31] + GELU_ERF_F32S_AVX512(zmm13, r, x, x_erf) + + // c[1, 32-47] + GELU_ERF_F32S_AVX512(zmm14, r, x, x_erf) + + // c[1, 48-63] + GELU_ERF_F32S_AVX512(zmm15, r, x, x_erf) + + // c[2, 0-15] + GELU_ERF_F32S_AVX512(zmm16, r, x, x_erf) + + // c[2, 16-31] + GELU_ERF_F32S_AVX512(zmm17, r, x, x_erf) + + // c[2, 32-47] + GELU_ERF_F32S_AVX512(zmm18, r, x, x_erf) + + // c[2, 48-63] + GELU_ERF_F32S_AVX512(zmm19, r, x, x_erf) + + // c[3, 0-15] + GELU_ERF_F32S_AVX512(zmm20, r, x, x_erf) + + // c[3, 16-31] + GELU_ERF_F32S_AVX512(zmm21, r, x, x_erf) + + // c[3, 32-47] + GELU_ERF_F32S_AVX512(zmm22, r, x, x_erf) + + // c[3, 48-63] + GELU_ERF_F32S_AVX512(zmm23, r, x, x_erf) + + // c[4, 0-15] + GELU_ERF_F32S_AVX512(zmm24, r, x, x_erf) + + // c[4, 16-31] + GELU_ERF_F32S_AVX512(zmm25, r, x, x_erf) + + // c[4, 32-47] + GELU_ERF_F32S_AVX512(zmm26, r, x, x_erf) + + // c[4, 48-63] + GELU_ERF_F32S_AVX512(zmm27, r, x, x_erf) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_CLIP_5x64_OPS: + { + __m512 min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 ); + __m512 max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 ); + + // c[0, 0-15] + CLIP_F32S_AVX512(zmm8, min, max) + + // c[0, 16-31] + CLIP_F32S_AVX512(zmm9, min, max) + + // c[0, 32-47] + CLIP_F32S_AVX512(zmm10, min, max) + + // c[0, 48-63] + CLIP_F32S_AVX512(zmm11, min, max) + + // c[1, 0-15] + CLIP_F32S_AVX512(zmm12, min, max) + + // c[1, 16-31] + CLIP_F32S_AVX512(zmm13, min, max) + + // c[1, 32-47] + CLIP_F32S_AVX512(zmm14, min, max) + + // c[1, 48-63] + CLIP_F32S_AVX512(zmm15, min, max) + + // c[2, 0-15] + CLIP_F32S_AVX512(zmm16, min, max) + + // c[2, 16-31] + CLIP_F32S_AVX512(zmm17, min, max) + + // c[2, 32-47] + CLIP_F32S_AVX512(zmm18, min, max) + + // c[2, 48-63] + CLIP_F32S_AVX512(zmm19, min, max) + + // c[3, 0-15] + CLIP_F32S_AVX512(zmm20, min, max) + + // c[3, 16-31] + CLIP_F32S_AVX512(zmm21, min, max) + + // c[3, 32-47] + CLIP_F32S_AVX512(zmm22, min, max) + + // c[3, 48-63] + CLIP_F32S_AVX512(zmm23, min, max) + + // c[4, 0-15] + CLIP_F32S_AVX512(zmm24, min, max) + + // c[4, 16-31] + CLIP_F32S_AVX512(zmm25, min, max) + + // c[4, 32-47] + CLIP_F32S_AVX512(zmm26, min, max) + + // c[4, 48-63] + CLIP_F32S_AVX512(zmm27, min, max) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_5x64_OPS: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + // It is expected the post-op matrix arg has the same storage + // order as the output C matrix. + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,8,9,10,11,zmm1,zmm2,zmm3,zmm4,0); + + // c[1:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,12,13,14,15,zmm1,zmm2,zmm3,zmm4,1); + + // c[2:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,16,17,18,19,zmm1,zmm2,zmm3,zmm4,2); + + // c[3:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,20,21,22,23,zmm1,zmm2,zmm3,zmm4,3); + + // c[4:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,24,25,26,27,zmm1,zmm2,zmm3,zmm4,4); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_5x64_OPS: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + // It is expected the post-op matrix arg has the same storage + // order as the output C matrix. + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,8,9,10,11,zmm1,zmm2,zmm3,zmm4,0); + + // c[1:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,12,13,14,15,zmm1,zmm2,zmm3,zmm4,1); + + // c[2:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,16,17,18,19,zmm1,zmm2,zmm3,zmm4,2); + + // c[3:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,20,21,22,23,zmm1,zmm2,zmm3,zmm4,3); + + // c[4:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,24,25,26,27,zmm1,zmm2,zmm3,zmm4,4); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_5x64_OPS: + { + zmm1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(zmm8, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[0, 16-31] + SWISH_F32_AVX512_DEF(zmm9, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[0, 32-47] + SWISH_F32_AVX512_DEF(zmm10, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[0, 48-63] + SWISH_F32_AVX512_DEF(zmm11, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[1, 0-15] + SWISH_F32_AVX512_DEF(zmm12, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[1, 16-31] + SWISH_F32_AVX512_DEF(zmm13, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[1, 32-47] + SWISH_F32_AVX512_DEF(zmm14, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[1, 48-63] + SWISH_F32_AVX512_DEF(zmm15, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[2, 0-15] + SWISH_F32_AVX512_DEF(zmm16, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[2, 16-31] + SWISH_F32_AVX512_DEF(zmm17, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[2, 32-47] + SWISH_F32_AVX512_DEF(zmm18, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[2, 48-63] + SWISH_F32_AVX512_DEF(zmm19, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[3, 0-15] + SWISH_F32_AVX512_DEF(zmm20, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[3, 16-31] + SWISH_F32_AVX512_DEF(zmm21, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[3, 32-47] + SWISH_F32_AVX512_DEF(zmm22, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[3, 48-63] + SWISH_F32_AVX512_DEF(zmm23, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[4, 0-15] + SWISH_F32_AVX512_DEF(zmm24, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[4, 16-31] + SWISH_F32_AVX512_DEF(zmm25, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[4, 32-47] + SWISH_F32_AVX512_DEF(zmm26, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[4, 48-63] + SWISH_F32_AVX512_DEF(zmm27, zmm1, al_in, r, r2, z, dn, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_5x64_OPS_DISABLE: + ; + + // Store the results. + // c[0,0-15] + _mm512_mask_storeu_ps( b + ( rs_b * ( 0 ) ) + + ( cs_b * ( jr + 0 ) ), k0, zmm8 ); + // c[0,16-31] + _mm512_mask_storeu_ps( b + ( rs_b * ( 0 ) ) + + ( cs_b * ( jr + 16 ) ), k1, zmm9 ); + // c[0,32-47] + _mm512_mask_storeu_ps( b + ( rs_b * ( 0 ) ) + + ( cs_b * ( jr + 32 ) ), k2, zmm10 ); + // c[0,48-63] + _mm512_mask_storeu_ps( b + ( rs_b * ( 0 ) ) + + ( cs_b * ( jr + 48 ) ), k3, zmm11 ); + + // c[1,0-15] + _mm512_mask_storeu_ps( b + ( rs_b * ( 1 ) ) + + ( cs_b * ( jr + 0 ) ), k0, zmm12 ); + // c[1,16-31] + _mm512_mask_storeu_ps( b + ( rs_b * ( 1 ) ) + + ( cs_b * ( jr + 16 ) ), k1, zmm13 ); + // c[1,32-47] + _mm512_mask_storeu_ps( b + ( rs_b * ( 1 ) ) + + ( cs_b * ( jr + 32 ) ), k2, zmm14 ); + // c[1,48-63] + _mm512_mask_storeu_ps( b + ( rs_b * ( 1 ) ) + + ( cs_b * ( jr + 48 ) ), k3, zmm15 ); + + // c[2,0-15] + _mm512_mask_storeu_ps( b + ( rs_b * ( 2 ) ) + + ( cs_b * ( jr + 0 ) ), k0, zmm16 ); + // c[2,16-31] + _mm512_mask_storeu_ps( b + ( rs_b * ( 2 ) ) + + ( cs_b * ( jr + 16 ) ), k1, zmm17 ); + // c[2,32-47] + _mm512_mask_storeu_ps( b + ( rs_b * ( 2 ) ) + + ( cs_b * ( jr + 32 ) ), k2, zmm18 ); + // c[2,48-63] + _mm512_mask_storeu_ps( b + ( rs_b * ( 2 ) ) + + ( cs_b * ( jr + 48 ) ), k3, zmm19 ); + + // c[3,0-15] + _mm512_mask_storeu_ps( b + ( rs_b * ( 3 ) ) + + ( cs_b * ( jr + 0 ) ), k0, zmm20 ); + // c[3,16-31] + _mm512_mask_storeu_ps( b + ( rs_b * ( 3 ) ) + + ( cs_b * ( jr + 16 ) ), k1, zmm21 ); + // c[3,32-47] + _mm512_mask_storeu_ps( b + ( rs_b * ( 3 ) ) + + ( cs_b * ( jr + 32 ) ), k2, zmm22 ); + // c[3,48-63] + _mm512_mask_storeu_ps( b + ( rs_b * ( 3 ) ) + + ( cs_b * ( jr + 48 ) ), k3, zmm23 ); + + // c[4,0-15] + _mm512_mask_storeu_ps( b + ( rs_b * ( 4 ) ) + + ( cs_b * ( jr + 0 ) ), k0, zmm24 ); + // c[4,16-31] + _mm512_mask_storeu_ps( b + ( rs_b * ( 4 ) ) + + ( cs_b * ( jr + 16 ) ), k1, zmm25 ); + // c[4,32-47] + _mm512_mask_storeu_ps( b + ( rs_b * ( 4 ) ) + + ( cs_b * ( jr + 32 ) ), k2, zmm26 ); + // c[4,48-63] + _mm512_mask_storeu_ps( b + ( rs_b * ( 4 ) ) + + ( cs_b * ( jr + 48 ) ), k3, zmm27 ); + + post_ops_attr.post_op_c_j += NR_L; + } +} + +LPGEMM_ELTWISE_OPS_M_FRINGE_KERNEL(float,float,f32of32_4x64) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_4x64_OPS_DISABLE, + &&POST_OPS_BIAS_4x64_OPS, + &&POST_OPS_RELU_4x64_OPS, + &&POST_OPS_RELU_SCALE_4x64_OPS, + &&POST_OPS_GELU_TANH_4x64_OPS, + &&POST_OPS_GELU_ERF_4x64_OPS, + &&POST_OPS_CLIP_4x64_OPS, + NULL,// Virtual node for downscale, else segfault + &&POST_OPS_MATRIX_ADD_4x64_OPS, + &&POST_OPS_SWISH_4x64_OPS, + &&POST_OPS_MATRIX_MUL_4x64_OPS + }; + dim_t NR = 64; + + // Registers to use for accumulating C. + __m512 zmm8 = _mm512_setzero_ps(); + __m512 zmm9 = _mm512_setzero_ps(); + __m512 zmm10 = _mm512_setzero_ps(); + __m512 zmm11 = _mm512_setzero_ps(); + + __m512 zmm12 = _mm512_setzero_ps(); + __m512 zmm13 = _mm512_setzero_ps(); + __m512 zmm14 = _mm512_setzero_ps(); + __m512 zmm15 = _mm512_setzero_ps(); + + __m512 zmm16 = _mm512_setzero_ps(); + __m512 zmm17 = _mm512_setzero_ps(); + __m512 zmm18 = _mm512_setzero_ps(); + __m512 zmm19 = _mm512_setzero_ps(); + + __m512 zmm20 = _mm512_setzero_ps(); + __m512 zmm21 = _mm512_setzero_ps(); + __m512 zmm22 = _mm512_setzero_ps(); + __m512 zmm23 = _mm512_setzero_ps(); + + __m512 zmm1 = _mm512_setzero_ps(); + __m512 zmm2 = _mm512_setzero_ps(); + __m512 zmm3 = _mm512_setzero_ps(); + __m512 zmm4 = _mm512_setzero_ps(); + + __mmask16 k0 = 0xFFFF, k1 = 0xFFFF, k2 = 0xFFFF, k3 = 0xFFFF; + + dim_t NR_L = NR; + for( dim_t jr = 0; jr < n0; jr += NR_L ) + { + dim_t n_left = n0 - jr; + NR_L = bli_min( NR_L, ( n_left >> 4 ) << 4 ); + if( NR_L == 0 ) { NR_L = 16; } + + dim_t nr0 = bli_min( n0 - jr, NR_L ); + if( nr0 == 64 ) + { + // all masks are already set. + // Nothing to modify. + } + else if( nr0 == 48 ) + { + k3 = 0x0; + } + else if( nr0 == 32 ) + { + k2 = k3 = 0x0; + } + else if( nr0 == 16 ) + { + k1 = k2 = k3 = 0; + } + else if( nr0 < 16 ) + { + k0 = (0xFFFF >> (16 - (nr0 & 0x0F))); + k1 = k2 = k3 = 0; + } + + // 1stx64 block. + zmm8 = _mm512_maskz_loadu_ps( k0, \ + a + ( rs_a * ( 0 ) ) + ( cs_a * ( jr + 0 ) ) ); + zmm9 = _mm512_maskz_loadu_ps( k1, \ + a + ( rs_a * ( 0 ) ) + ( cs_a * ( jr + 16 ) ) ); + zmm10 = _mm512_maskz_loadu_ps( k2, \ + a + ( rs_a * ( 0 ) ) + ( cs_a * ( jr + 32 ) ) ); + zmm11 = _mm512_maskz_loadu_ps( k3, \ + a + ( rs_a * ( 0 ) ) + ( cs_a * ( jr + 48 ) ) ); + + // 2ndx64 block. + zmm12 = _mm512_maskz_loadu_ps( k0, \ + a + ( rs_a * ( 1 ) ) + ( cs_a * ( jr + 0 ) ) ); + zmm13 = _mm512_maskz_loadu_ps( k1, \ + a + ( rs_a * ( 1 ) ) + ( cs_a * ( jr + 16 ) ) ); + zmm14 = _mm512_maskz_loadu_ps( k2, \ + a + ( rs_a * ( 1 ) ) + ( cs_a * ( jr + 32 ) ) ); + zmm15 = _mm512_maskz_loadu_ps( k3, \ + a + ( rs_a * ( 1 ) ) + ( cs_a * ( jr + 48 ) ) ); + + // 3rdx64 block. + zmm16 = _mm512_maskz_loadu_ps( k0, \ + a + ( rs_a * ( 2 ) ) + ( cs_a * ( jr + 0 ) ) ); + zmm17 = _mm512_maskz_loadu_ps( k1, \ + a + ( rs_a * ( 2 ) ) + ( cs_a * ( jr + 16 ) ) ); + zmm18 = _mm512_maskz_loadu_ps( k2, \ + a + ( rs_a * ( 2 ) ) + ( cs_a * ( jr + 32 ) ) ); + zmm19 = _mm512_maskz_loadu_ps( k3, \ + a + ( rs_a * ( 2 ) ) + ( cs_a * ( jr + 48 ) ) ); + + // 4thx64 block. + zmm20 = _mm512_maskz_loadu_ps( k0, \ + a + ( rs_a * ( 3 ) ) + ( cs_a * ( jr + 0 ) ) ); + zmm21 = _mm512_maskz_loadu_ps( k1, \ + a + ( rs_a * ( 3 ) ) + ( cs_a * ( jr + 16 ) ) ); + zmm22 = _mm512_maskz_loadu_ps( k2, \ + a + ( rs_a * ( 3 ) ) + ( cs_a * ( jr + 32 ) ) ); + zmm23 = _mm512_maskz_loadu_ps( k3, \ + a + ( rs_a * ( 3 ) ) + ( cs_a * ( jr + 48 ) ) ); + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP + +POST_OPS_BIAS_4x64_OPS: + { + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + zmm1 = + _mm512_maskz_loadu_ps( k0, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + zmm2 = + _mm512_maskz_loadu_ps( k1, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + zmm3 = + _mm512_maskz_loadu_ps( k2, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + zmm4 = + _mm512_maskz_loadu_ps( k3, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + + // c[0,0-15] + zmm8 = _mm512_add_ps( zmm1, zmm8 ); + + // c[0, 16-31] + zmm9 = _mm512_add_ps( zmm2, zmm9 ); + + // c[0,32-47] + zmm10 = _mm512_add_ps( zmm3, zmm10 ); + + // c[0,48-63] + zmm11 = _mm512_add_ps( zmm4, zmm11 ); + + // c[1,0-15] + zmm12 = _mm512_add_ps( zmm1, zmm12 ); + + // c[1, 16-31] + zmm13 = _mm512_add_ps( zmm2, zmm13 ); + + // c[1,32-47] + zmm14 = _mm512_add_ps( zmm3, zmm14 ); + + // c[1,48-63] + zmm15 = _mm512_add_ps( zmm4, zmm15 ); + + // c[2,0-15] + zmm16 = _mm512_add_ps( zmm1, zmm16 ); + + // c[2, 16-31] + zmm17 = _mm512_add_ps( zmm2, zmm17 ); + + // c[2,32-47] + zmm18 = _mm512_add_ps( zmm3, zmm18 ); + + // c[2,48-63] + zmm19 = _mm512_add_ps( zmm4, zmm19 ); + + // c[3,0-15] + zmm20 = _mm512_add_ps( zmm1, zmm20 ); + + // c[3, 16-31] + zmm21 = _mm512_add_ps( zmm2, zmm21 ); + + // c[3,32-47] + zmm22 = _mm512_add_ps( zmm3, zmm22 ); + + // c[3,48-63] + zmm23 = _mm512_add_ps( zmm4, zmm23 ); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the bias array will be accessed by + // the ic index, and each bias element corresponds to an + // entire row of the transposed output array, instead of an + // entire column. + zmm1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + zmm2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 1 ) ); + zmm3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 2 ) ); + zmm4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 3 ) ); + + + // c[0,0-15] + zmm8 = _mm512_add_ps( zmm1, zmm8 ); + + // c[0, 16-31] + zmm9 = _mm512_add_ps( zmm1, zmm9 ); + + // c[0,32-47] + zmm10 = _mm512_add_ps( zmm1, zmm10 ); + + // c[0,48-63] + zmm11 = _mm512_add_ps( zmm1, zmm11 ); + + // c[1,0-15] + zmm12 = _mm512_add_ps( zmm2, zmm12 ); + + // c[1, 16-31] + zmm13 = _mm512_add_ps( zmm2, zmm13 ); + + // c[1,32-47] + zmm14 = _mm512_add_ps( zmm2, zmm14 ); + + // c[1,48-63] + zmm15 = _mm512_add_ps( zmm2, zmm15 ); + + // c[2,0-15] + zmm16 = _mm512_add_ps( zmm3, zmm16 ); + + // c[2, 16-31] + zmm17 = _mm512_add_ps( zmm3, zmm17 ); + + // c[2,32-47] + zmm18 = _mm512_add_ps( zmm3, zmm18 ); + + // c[2,48-63] + zmm19 = _mm512_add_ps( zmm3, zmm19 ); + + // c[3,0-15] + zmm20 = _mm512_add_ps( zmm4, zmm20 ); + + // c[3, 16-31] + zmm21 = _mm512_add_ps( zmm4, zmm21 ); + + // c[3,32-47] + zmm22 = _mm512_add_ps( zmm4, zmm22 ); + + // c[3,48-63] + zmm23 = _mm512_add_ps( zmm4, zmm23 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_4x64_OPS: + { + zmm1 = _mm512_setzero_ps(); + + // c[0,0-15] + zmm8 = _mm512_max_ps( zmm1, zmm8 ); + + // c[0, 16-31] + zmm9 = _mm512_max_ps( zmm1, zmm9 ); + + // c[0,32-47] + zmm10 = _mm512_max_ps( zmm1, zmm10 ); + + // c[0,48-63] + zmm11 = _mm512_max_ps( zmm1, zmm11 ); + + // c[1,0-15] + zmm12 = _mm512_max_ps( zmm1, zmm12 ); + + // c[1,16-31] + zmm13 = _mm512_max_ps( zmm1, zmm13 ); + + // c[1,32-47] + zmm14 = _mm512_max_ps( zmm1, zmm14 ); + + // c[1,48-63] + zmm15 = _mm512_max_ps( zmm1, zmm15 ); + + // c[2,0-15] + zmm16 = _mm512_max_ps( zmm1, zmm16 ); + + // c[2,16-31] + zmm17 = _mm512_max_ps( zmm1, zmm17 ); + + // c[2,32-47] + zmm18 = _mm512_max_ps( zmm1, zmm18 ); + + // c[2,48-63] + zmm19 = _mm512_max_ps( zmm1, zmm19 ); + + // c[3,0-15] + zmm20 = _mm512_max_ps( zmm1, zmm20 ); + + // c[3,16-31] + zmm21 = _mm512_max_ps( zmm1, zmm21 ); + + // c[3,32-47] + zmm22 = _mm512_max_ps( zmm1, zmm22 ); + + // c[3,48-63] + zmm23 = _mm512_max_ps( zmm1, zmm23 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_4x64_OPS: + { + zmm1 = _mm512_setzero_ps(); + zmm2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32S_AVX512(zmm8) + + // c[0, 16-31] + RELU_SCALE_OP_F32S_AVX512(zmm9) + + // c[0, 32-47] + RELU_SCALE_OP_F32S_AVX512(zmm10) + + // c[0, 48-63] + RELU_SCALE_OP_F32S_AVX512(zmm11) + + // c[1, 0-15] + RELU_SCALE_OP_F32S_AVX512(zmm12) + + // c[1, 16-31] + RELU_SCALE_OP_F32S_AVX512(zmm13) + + // c[1, 32-47] + RELU_SCALE_OP_F32S_AVX512(zmm14) + + // c[1, 48-63] + RELU_SCALE_OP_F32S_AVX512(zmm15) + + // c[2, 0-15] + RELU_SCALE_OP_F32S_AVX512(zmm16) + + // c[2, 16-31] + RELU_SCALE_OP_F32S_AVX512(zmm17) + + // c[2, 32-47] + RELU_SCALE_OP_F32S_AVX512(zmm18) + + // c[2, 48-63] + RELU_SCALE_OP_F32S_AVX512(zmm19) + + // c[3, 0-15] + RELU_SCALE_OP_F32S_AVX512(zmm20) + + // c[3, 16-31] + RELU_SCALE_OP_F32S_AVX512(zmm21) + + // c[3, 32-47] + RELU_SCALE_OP_F32S_AVX512(zmm22) + + // c[3, 48-63] + RELU_SCALE_OP_F32S_AVX512(zmm23) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_TANH_4x64_OPS: + { + __m512 dn, z, x, r2, r, x_tanh; + __m512i q; + + // c[0, 0-15] + GELU_TANH_F32S_AVX512(zmm8, r, r2, x, z, dn, x_tanh, q) + + // c[0, 16-31] + GELU_TANH_F32S_AVX512(zmm9, r, r2, x, z, dn, x_tanh, q) + + // c[0, 32-47] + GELU_TANH_F32S_AVX512(zmm10, r, r2, x, z, dn, x_tanh, q) + + // c[0, 48-63] + GELU_TANH_F32S_AVX512(zmm11, r, r2, x, z, dn, x_tanh, q) + + // c[1, 0-15] + GELU_TANH_F32S_AVX512(zmm12, r, r2, x, z, dn, x_tanh, q) + + // c[1, 16-31] + GELU_TANH_F32S_AVX512(zmm13, r, r2, x, z, dn, x_tanh, q) + + // c[1, 32-47] + GELU_TANH_F32S_AVX512(zmm14, r, r2, x, z, dn, x_tanh, q) + + // c[1, 48-63] + GELU_TANH_F32S_AVX512(zmm15, r, r2, x, z, dn, x_tanh, q) + + // c[2, 0-15] + GELU_TANH_F32S_AVX512(zmm16, r, r2, x, z, dn, x_tanh, q) + + // c[2, 16-31] + GELU_TANH_F32S_AVX512(zmm17, r, r2, x, z, dn, x_tanh, q) + + // c[2, 32-47] + GELU_TANH_F32S_AVX512(zmm18, r, r2, x, z, dn, x_tanh, q) + + // c[2, 48-63] + GELU_TANH_F32S_AVX512(zmm19, r, r2, x, z, dn, x_tanh, q) + + // c[3, 0-15] + GELU_TANH_F32S_AVX512(zmm20, r, r2, x, z, dn, x_tanh, q) + + // c[3, 16-31] + GELU_TANH_F32S_AVX512(zmm21, r, r2, x, z, dn, x_tanh, q) + + // c[3, 32-47] + GELU_TANH_F32S_AVX512(zmm22, r, r2, x, z, dn, x_tanh, q) + + // c[3, 48-63] + GELU_TANH_F32S_AVX512(zmm23, r, r2, x, z, dn, x_tanh, q) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_ERF_4x64_OPS: + { + __m512 x, r, x_erf; + + // c[0, 0-15] + GELU_ERF_F32S_AVX512(zmm8, r, x, x_erf) + + // c[0, 16-31] + GELU_ERF_F32S_AVX512(zmm9, r, x, x_erf) + + // c[0, 32-47] + GELU_ERF_F32S_AVX512(zmm10, r, x, x_erf) + + // c[0, 48-63] + GELU_ERF_F32S_AVX512(zmm11, r, x, x_erf) + + // c[1, 0-15] + GELU_ERF_F32S_AVX512(zmm12, r, x, x_erf) + + // c[1, 16-31] + GELU_ERF_F32S_AVX512(zmm13, r, x, x_erf) + + // c[1, 32-47] + GELU_ERF_F32S_AVX512(zmm14, r, x, x_erf) + + // c[1, 48-63] + GELU_ERF_F32S_AVX512(zmm15, r, x, x_erf) + + // c[2, 0-15] + GELU_ERF_F32S_AVX512(zmm16, r, x, x_erf) + + // c[2, 16-31] + GELU_ERF_F32S_AVX512(zmm17, r, x, x_erf) + + // c[2, 32-47] + GELU_ERF_F32S_AVX512(zmm18, r, x, x_erf) + + // c[2, 48-63] + GELU_ERF_F32S_AVX512(zmm19, r, x, x_erf) + + // c[3, 0-15] + GELU_ERF_F32S_AVX512(zmm20, r, x, x_erf) + + // c[3, 16-31] + GELU_ERF_F32S_AVX512(zmm21, r, x, x_erf) + + // c[3, 32-47] + GELU_ERF_F32S_AVX512(zmm22, r, x, x_erf) + + // c[3, 48-63] + GELU_ERF_F32S_AVX512(zmm23, r, x, x_erf) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_CLIP_4x64_OPS: + { + __m512 min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 ); + __m512 max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 ); + + // c[0, 0-15] + CLIP_F32S_AVX512(zmm8, min, max) + + // c[0, 16-31] + CLIP_F32S_AVX512(zmm9, min, max) + + // c[0, 32-47] + CLIP_F32S_AVX512(zmm10, min, max) + + // c[0, 48-63] + CLIP_F32S_AVX512(zmm11, min, max) + + // c[1, 0-15] + CLIP_F32S_AVX512(zmm12, min, max) + + // c[1, 16-31] + CLIP_F32S_AVX512(zmm13, min, max) + + // c[1, 32-47] + CLIP_F32S_AVX512(zmm14, min, max) + + // c[1, 48-63] + CLIP_F32S_AVX512(zmm15, min, max) + + // c[2, 0-15] + CLIP_F32S_AVX512(zmm16, min, max) + + // c[2, 16-31] + CLIP_F32S_AVX512(zmm17, min, max) + + // c[2, 32-47] + CLIP_F32S_AVX512(zmm18, min, max) + + // c[2, 48-63] + CLIP_F32S_AVX512(zmm19, min, max) + + // c[3, 0-15] + CLIP_F32S_AVX512(zmm20, min, max) + + // c[3, 16-31] + CLIP_F32S_AVX512(zmm21, min, max) + + // c[3, 32-47] + CLIP_F32S_AVX512(zmm22, min, max) + + // c[3, 48-63] + CLIP_F32S_AVX512(zmm23, min, max) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_4x64_OPS: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + // It is expected the post-op matrix arg has the same storage + // order as the output C matrix. + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,8,9,10,11,zmm1,zmm2,zmm3,zmm4,0); + + // c[1:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,12,13,14,15,zmm1,zmm2,zmm3,zmm4,1); + + // c[2:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,16,17,18,19,zmm1,zmm2,zmm3,zmm4,2); + + // c[3:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,20,21,22,23,zmm1,zmm2,zmm3,zmm4,3); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_4x64_OPS: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + // It is expected the post-op matrix arg has the same storage + // order as the output C matrix. + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,8,9,10,11,zmm1,zmm2,zmm3,zmm4,0); + + // c[1:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,12,13,14,15,zmm1,zmm2,zmm3,zmm4,1); + + // c[2:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,16,17,18,19,zmm1,zmm2,zmm3,zmm4,2); + + // c[3:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,20,21,22,23,zmm1,zmm2,zmm3,zmm4,3); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_4x64_OPS: + { + zmm1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(zmm8, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[0, 16-31] + SWISH_F32_AVX512_DEF(zmm9, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[0, 32-47] + SWISH_F32_AVX512_DEF(zmm10, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[0, 48-63] + SWISH_F32_AVX512_DEF(zmm11, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[1, 0-15] + SWISH_F32_AVX512_DEF(zmm12, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[1, 16-31] + SWISH_F32_AVX512_DEF(zmm13, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[1, 32-47] + SWISH_F32_AVX512_DEF(zmm14, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[1, 48-63] + SWISH_F32_AVX512_DEF(zmm15, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[2, 0-15] + SWISH_F32_AVX512_DEF(zmm16, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[2, 16-31] + SWISH_F32_AVX512_DEF(zmm17, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[2, 32-47] + SWISH_F32_AVX512_DEF(zmm18, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[2, 48-63] + SWISH_F32_AVX512_DEF(zmm19, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[3, 0-15] + SWISH_F32_AVX512_DEF(zmm20, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[3, 16-31] + SWISH_F32_AVX512_DEF(zmm21, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[3, 32-47] + SWISH_F32_AVX512_DEF(zmm22, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[3, 48-63] + SWISH_F32_AVX512_DEF(zmm23, zmm1, al_in, r, r2, z, dn, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_4x64_OPS_DISABLE: + ; + + // Store the results. + // c[0,0-15] + _mm512_mask_storeu_ps( b + ( rs_b * ( 0 ) ) + + ( cs_b * ( jr + 0 ) ), k0, zmm8 ); + // c[0,16-31] + _mm512_mask_storeu_ps( b + ( rs_b * ( 0 ) ) + + ( cs_b * ( jr + 16 ) ), k1, zmm9 ); + // c[0,32-47] + _mm512_mask_storeu_ps( b + ( rs_b * ( 0 ) ) + + ( cs_b * ( jr + 32 ) ), k2, zmm10 ); + // c[0,48-63] + _mm512_mask_storeu_ps( b + ( rs_b * ( 0 ) ) + + ( cs_b * ( jr + 48 ) ), k3, zmm11 ); + + // c[1,0-15] + _mm512_mask_storeu_ps( b + ( rs_b * ( 1 ) ) + + ( cs_b * ( jr + 0 ) ), k0, zmm12 ); + // c[1,16-31] + _mm512_mask_storeu_ps( b + ( rs_b * ( 1 ) ) + + ( cs_b * ( jr + 16 ) ), k1, zmm13 ); + // c[1,32-47] + _mm512_mask_storeu_ps( b + ( rs_b * ( 1 ) ) + + ( cs_b * ( jr + 32 ) ), k2, zmm14 ); + // c[1,48-63] + _mm512_mask_storeu_ps( b + ( rs_b * ( 1 ) ) + + ( cs_b * ( jr + 48 ) ), k3, zmm15 ); + + // c[2,0-15] + _mm512_mask_storeu_ps( b + ( rs_b * ( 2 ) ) + + ( cs_b * ( jr + 0 ) ), k0, zmm16 ); + // c[2,16-31] + _mm512_mask_storeu_ps( b + ( rs_b * ( 2 ) ) + + ( cs_b * ( jr + 16 ) ), k1, zmm17 ); + // c[2,32-47] + _mm512_mask_storeu_ps( b + ( rs_b * ( 2 ) ) + + ( cs_b * ( jr + 32 ) ), k2, zmm18 ); + // c[2,48-63] + _mm512_mask_storeu_ps( b + ( rs_b * ( 2 ) ) + + ( cs_b * ( jr + 48 ) ), k3, zmm19 ); + + // c[3,0-15] + _mm512_mask_storeu_ps( b + ( rs_b * ( 3 ) ) + + ( cs_b * ( jr + 0 ) ), k0, zmm20 ); + // c[3,16-31] + _mm512_mask_storeu_ps( b + ( rs_b * ( 3 ) ) + + ( cs_b * ( jr + 16 ) ), k1, zmm21 ); + // c[3,32-47] + _mm512_mask_storeu_ps( b + ( rs_b * ( 3 ) ) + + ( cs_b * ( jr + 32 ) ), k2, zmm22 ); + // c[3,48-63] + _mm512_mask_storeu_ps( b + ( rs_b * ( 3 ) ) + + ( cs_b * ( jr + 48 ) ), k3, zmm23 ); + + post_ops_attr.post_op_c_j += NR_L; + } +} + +LPGEMM_ELTWISE_OPS_M_FRINGE_KERNEL(float,float,f32of32_3x64) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_3x64_OPS_DISABLE, + &&POST_OPS_BIAS_3x64_OPS, + &&POST_OPS_RELU_3x64_OPS, + &&POST_OPS_RELU_SCALE_3x64_OPS, + &&POST_OPS_GELU_TANH_3x64_OPS, + &&POST_OPS_GELU_ERF_3x64_OPS, + &&POST_OPS_CLIP_3x64_OPS, + NULL,// Virtual node for downscale, else segfault + &&POST_OPS_MATRIX_ADD_3x64_OPS, + &&POST_OPS_SWISH_3x64_OPS, + &&POST_OPS_MATRIX_MUL_3x64_OPS + }; + dim_t NR = 64; + + // Registers to use for accumulating C. + __m512 zmm8 = _mm512_setzero_ps(); + __m512 zmm9 = _mm512_setzero_ps(); + __m512 zmm10 = _mm512_setzero_ps(); + __m512 zmm11 = _mm512_setzero_ps(); + + __m512 zmm12 = _mm512_setzero_ps(); + __m512 zmm13 = _mm512_setzero_ps(); + __m512 zmm14 = _mm512_setzero_ps(); + __m512 zmm15 = _mm512_setzero_ps(); + + __m512 zmm16 = _mm512_setzero_ps(); + __m512 zmm17 = _mm512_setzero_ps(); + __m512 zmm18 = _mm512_setzero_ps(); + __m512 zmm19 = _mm512_setzero_ps(); + + __m512 zmm1 = _mm512_setzero_ps(); + __m512 zmm2 = _mm512_setzero_ps(); + __m512 zmm3 = _mm512_setzero_ps(); + __m512 zmm4 = _mm512_setzero_ps(); + + __mmask16 k0 = 0xFFFF, k1 = 0xFFFF, k2 = 0xFFFF, k3 = 0xFFFF; + + dim_t NR_L = NR; + for( dim_t jr = 0; jr < n0; jr += NR_L ) + { + dim_t n_left = n0 - jr; + NR_L = bli_min( NR_L, ( n_left >> 4 ) << 4 ); + if( NR_L == 0 ) { NR_L = 16; } + + dim_t nr0 = bli_min( n0 - jr, NR_L ); + if( nr0 == 64 ) + { + // all masks are already set. + // Nothing to modify. + } + else if( nr0 == 48 ) + { + k3 = 0x0; + } + else if( nr0 == 32 ) + { + k2 = k3 = 0x0; + } + else if( nr0 == 16 ) + { + k1 = k2 = k3 = 0; + } + else if( nr0 < 16 ) + { + k0 = (0xFFFF >> (16 - (nr0 & 0x0F))); + k1 = k2 = k3 = 0; + } + + // 1stx64 block. + zmm8 = _mm512_maskz_loadu_ps( k0, \ + a + ( rs_a * ( 0 ) ) + ( cs_a * ( jr + 0 ) ) ); + zmm9 = _mm512_maskz_loadu_ps( k1, \ + a + ( rs_a * ( 0 ) ) + ( cs_a * ( jr + 16 ) ) ) ; + zmm10 = _mm512_maskz_loadu_ps( k2, \ + a + ( rs_a * ( 0 ) ) + ( cs_a * ( jr + 32 ) ) ) ; + zmm11 = _mm512_maskz_loadu_ps( k3, \ + a + ( rs_a * ( 0 ) ) + ( cs_a * ( jr + 48 ) ) ) ; + + // 2ndx64 block. + zmm12 = _mm512_maskz_loadu_ps( k0, \ + a + ( rs_a * ( 1 ) ) + ( cs_a * ( jr + 0 ) ) ); + zmm13 = _mm512_maskz_loadu_ps( k1, \ + a + ( rs_a * ( 1 ) ) + ( cs_a * ( jr + 16 ) ) ) ; + zmm14 = _mm512_maskz_loadu_ps( k2, \ + a + ( rs_a * ( 1 ) ) + ( cs_a * ( jr + 32 ) ) ) ; + zmm15 = _mm512_maskz_loadu_ps( k3, \ + a + ( rs_a * ( 1 ) ) + ( cs_a * ( jr + 48 ) ) ) ; + + // 3rdx64 block. + zmm16 = _mm512_maskz_loadu_ps( k0, \ + a + ( rs_a * ( 2 ) ) + ( cs_a * ( jr + 0 ) ) ); + zmm17 = _mm512_maskz_loadu_ps( k1, \ + a + ( rs_a * ( 2 ) ) + ( cs_a * ( jr + 16 ) ) ) ; + zmm18 = _mm512_maskz_loadu_ps( k2, \ + a + ( rs_a * ( 2 ) ) + ( cs_a * ( jr + 32 ) ) ) ; + zmm19 = _mm512_maskz_loadu_ps( k3, \ + a + ( rs_a * ( 2 ) ) + ( cs_a * ( jr + 48 ) ) ) ; + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP + +POST_OPS_BIAS_3x64_OPS: + { + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + zmm1 = + _mm512_maskz_loadu_ps( k0, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + zmm2 = + _mm512_maskz_loadu_ps( k1, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + zmm3 = + _mm512_maskz_loadu_ps( k2, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + zmm4 = + _mm512_maskz_loadu_ps( k3, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + + // c[0,0-15] + zmm8 = _mm512_add_ps( zmm1, zmm8 ); + + // c[0, 16-31] + zmm9 = _mm512_add_ps( zmm2, zmm9 ); + + // c[0,32-47] + zmm10 = _mm512_add_ps( zmm3, zmm10 ); + + // c[0,48-63] + zmm11 = _mm512_add_ps( zmm4, zmm11 ); + + // c[1,0-15] + zmm12 = _mm512_add_ps( zmm1, zmm12 ); + + // c[1, 16-31] + zmm13 = _mm512_add_ps( zmm2, zmm13 ); + + // c[1,32-47] + zmm14 = _mm512_add_ps( zmm3, zmm14 ); + + // c[1,48-63] + zmm15 = _mm512_add_ps( zmm4, zmm15 ); + + // c[2,0-15] + zmm16 = _mm512_add_ps( zmm1, zmm16 ); + + // c[2, 16-31] + zmm17 = _mm512_add_ps( zmm2, zmm17 ); + + // c[2,32-47] + zmm18 = _mm512_add_ps( zmm3, zmm18 ); + + // c[2,48-63] + zmm19 = _mm512_add_ps( zmm4, zmm19 ); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the bias array will be accessed by + // the ic index, and each bias element corresponds to an + // entire row of the transposed output array, instead of an + // entire column. + zmm1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + zmm2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 1 ) ); + zmm3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 2 ) ); + + // c[0,0-15] + zmm8 = _mm512_add_ps( zmm1, zmm8 ); + + // c[0, 16-31] + zmm9 = _mm512_add_ps( zmm1, zmm9 ); + + // c[0,32-47] + zmm10 = _mm512_add_ps( zmm1, zmm10 ); + + // c[0,48-63] + zmm11 = _mm512_add_ps( zmm1, zmm11 ); + + // c[1,0-15] + zmm12 = _mm512_add_ps( zmm2, zmm12 ); + + // c[1, 16-31] + zmm13 = _mm512_add_ps( zmm2, zmm13 ); + + // c[1,32-47] + zmm14 = _mm512_add_ps( zmm2, zmm14 ); + + // c[1,48-63] + zmm15 = _mm512_add_ps( zmm2, zmm15 ); + + // c[2,0-15] + zmm16 = _mm512_add_ps( zmm3, zmm16 ); + + // c[2, 16-31] + zmm17 = _mm512_add_ps( zmm3, zmm17 ); + + // c[2,32-47] + zmm18 = _mm512_add_ps( zmm3, zmm18 ); + + // c[2,48-63] + zmm19 = _mm512_add_ps( zmm3, zmm19 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_3x64_OPS: + { + zmm1 = _mm512_setzero_ps(); + + // c[0,0-15] + zmm8 = _mm512_max_ps( zmm1, zmm8 ); + + // c[0, 16-31] + zmm9 = _mm512_max_ps( zmm1, zmm9 ); + + // c[0,32-47] + zmm10 = _mm512_max_ps( zmm1, zmm10 ); + + // c[0,48-63] + zmm11 = _mm512_max_ps( zmm1, zmm11 ); + + // c[1,0-15] + zmm12 = _mm512_max_ps( zmm1, zmm12 ); + + // c[1,16-31] + zmm13 = _mm512_max_ps( zmm1, zmm13 ); + + // c[1,32-47] + zmm14 = _mm512_max_ps( zmm1, zmm14 ); + + // c[1,48-63] + zmm15 = _mm512_max_ps( zmm1, zmm15 ); + + // c[2,0-15] + zmm16 = _mm512_max_ps( zmm1, zmm16 ); + + // c[2,16-31] + zmm17 = _mm512_max_ps( zmm1, zmm17 ); + + // c[2,32-47] + zmm18 = _mm512_max_ps( zmm1, zmm18 ); + + // c[2,48-63] + zmm19 = _mm512_max_ps( zmm1, zmm19 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_3x64_OPS: + { + zmm1 = _mm512_setzero_ps(); + zmm2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32S_AVX512(zmm8) + + // c[0, 16-31] + RELU_SCALE_OP_F32S_AVX512(zmm9) + + // c[0, 32-47] + RELU_SCALE_OP_F32S_AVX512(zmm10) + + // c[0, 48-63] + RELU_SCALE_OP_F32S_AVX512(zmm11) + + // c[1, 0-15] + RELU_SCALE_OP_F32S_AVX512(zmm12) + + // c[1, 16-31] + RELU_SCALE_OP_F32S_AVX512(zmm13) + + // c[1, 32-47] + RELU_SCALE_OP_F32S_AVX512(zmm14) + + // c[1, 48-63] + RELU_SCALE_OP_F32S_AVX512(zmm15) + + // c[2, 0-15] + RELU_SCALE_OP_F32S_AVX512(zmm16) + + // c[2, 16-31] + RELU_SCALE_OP_F32S_AVX512(zmm17) + + // c[2, 32-47] + RELU_SCALE_OP_F32S_AVX512(zmm18) + + // c[2, 48-63] + RELU_SCALE_OP_F32S_AVX512(zmm19) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_TANH_3x64_OPS: + { + __m512 dn, z, x, r2, r, x_tanh; + __m512i q; + + // c[0, 0-15] + GELU_TANH_F32S_AVX512(zmm8, r, r2, x, z, dn, x_tanh, q) + + // c[0, 16-31] + GELU_TANH_F32S_AVX512(zmm9, r, r2, x, z, dn, x_tanh, q) + + // c[0, 32-47] + GELU_TANH_F32S_AVX512(zmm10, r, r2, x, z, dn, x_tanh, q) + + // c[0, 48-63] + GELU_TANH_F32S_AVX512(zmm11, r, r2, x, z, dn, x_tanh, q) + + // c[1, 0-15] + GELU_TANH_F32S_AVX512(zmm12, r, r2, x, z, dn, x_tanh, q) + + // c[1, 16-31] + GELU_TANH_F32S_AVX512(zmm13, r, r2, x, z, dn, x_tanh, q) + + // c[1, 32-47] + GELU_TANH_F32S_AVX512(zmm14, r, r2, x, z, dn, x_tanh, q) + + // c[1, 48-63] + GELU_TANH_F32S_AVX512(zmm15, r, r2, x, z, dn, x_tanh, q) + + // c[2, 0-15] + GELU_TANH_F32S_AVX512(zmm16, r, r2, x, z, dn, x_tanh, q) + + // c[2, 16-31] + GELU_TANH_F32S_AVX512(zmm17, r, r2, x, z, dn, x_tanh, q) + + // c[2, 32-47] + GELU_TANH_F32S_AVX512(zmm18, r, r2, x, z, dn, x_tanh, q) + + // c[2, 48-63] + GELU_TANH_F32S_AVX512(zmm19, r, r2, x, z, dn, x_tanh, q) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_ERF_3x64_OPS: + { + __m512 x, r, x_erf; + + // c[0, 0-15] + GELU_ERF_F32S_AVX512(zmm8, r, x, x_erf) + + // c[0, 16-31] + GELU_ERF_F32S_AVX512(zmm9, r, x, x_erf) + + // c[0, 32-47] + GELU_ERF_F32S_AVX512(zmm10, r, x, x_erf) + + // c[0, 48-63] + GELU_ERF_F32S_AVX512(zmm11, r, x, x_erf) + + // c[1, 0-15] + GELU_ERF_F32S_AVX512(zmm12, r, x, x_erf) + + // c[1, 16-31] + GELU_ERF_F32S_AVX512(zmm13, r, x, x_erf) + + // c[1, 32-47] + GELU_ERF_F32S_AVX512(zmm14, r, x, x_erf) + + // c[1, 48-63] + GELU_ERF_F32S_AVX512(zmm15, r, x, x_erf) + + // c[2, 0-15] + GELU_ERF_F32S_AVX512(zmm16, r, x, x_erf) + + // c[2, 16-31] + GELU_ERF_F32S_AVX512(zmm17, r, x, x_erf) + + // c[2, 32-47] + GELU_ERF_F32S_AVX512(zmm18, r, x, x_erf) + + // c[2, 48-63] + GELU_ERF_F32S_AVX512(zmm19, r, x, x_erf) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_CLIP_3x64_OPS: + { + __m512 min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 ); + __m512 max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 ); + + // c[0, 0-15] + CLIP_F32S_AVX512(zmm8, min, max) + + // c[0, 16-31] + CLIP_F32S_AVX512(zmm9, min, max) + + // c[0, 32-47] + CLIP_F32S_AVX512(zmm10, min, max) + + // c[0, 48-63] + CLIP_F32S_AVX512(zmm11, min, max) + + // c[1, 0-15] + CLIP_F32S_AVX512(zmm12, min, max) + + // c[1, 16-31] + CLIP_F32S_AVX512(zmm13, min, max) + + // c[1, 32-47] + CLIP_F32S_AVX512(zmm14, min, max) + + // c[1, 48-63] + CLIP_F32S_AVX512(zmm15, min, max) + + // c[2, 0-15] + CLIP_F32S_AVX512(zmm16, min, max) + + // c[2, 16-31] + CLIP_F32S_AVX512(zmm17, min, max) + + // c[2, 32-47] + CLIP_F32S_AVX512(zmm18, min, max) + + // c[2, 48-63] + CLIP_F32S_AVX512(zmm19, min, max) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_3x64_OPS: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + // It is expected the post-op matrix arg has the same storage + // order as the output C matrix. + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,8,9,10,11,zmm1,zmm2,zmm3,zmm4,0); + + // c[1:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,12,13,14,15,zmm1,zmm2,zmm3,zmm4,1); + + // c[2:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,16,17,18,19,zmm1,zmm2,zmm3,zmm4,2); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_3x64_OPS: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + // It is expected the post-op matrix arg has the same storage + // order as the output C matrix. + + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,8,9,10,11,zmm1,zmm2,zmm3,zmm4,0); + + // c[1:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,12,13,14,15,zmm1,zmm2,zmm3,zmm4,1); + + // c[2:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,16,17,18,19,zmm1,zmm2,zmm3,zmm4,2); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_3x64_OPS: + { + zmm1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(zmm8, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[0, 16-31] + SWISH_F32_AVX512_DEF(zmm9, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[0, 32-47] + SWISH_F32_AVX512_DEF(zmm10, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[0, 48-63] + SWISH_F32_AVX512_DEF(zmm11, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[1, 0-15] + SWISH_F32_AVX512_DEF(zmm12, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[1, 16-31] + SWISH_F32_AVX512_DEF(zmm13, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[1, 32-47] + SWISH_F32_AVX512_DEF(zmm14, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[1, 48-63] + SWISH_F32_AVX512_DEF(zmm15, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[2, 0-15] + SWISH_F32_AVX512_DEF(zmm16, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[2, 16-31] + SWISH_F32_AVX512_DEF(zmm17, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[2, 32-47] + SWISH_F32_AVX512_DEF(zmm18, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[2, 48-63] + SWISH_F32_AVX512_DEF(zmm19, zmm1, al_in, r, r2, z, dn, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_3x64_OPS_DISABLE: + ; + + // Store the results. + // c[0,0-15] + _mm512_mask_storeu_ps( b + ( rs_b * ( 0 ) ) + + ( cs_b * ( jr + 0 ) ), k0, zmm8 ); + // c[0,16-31] + _mm512_mask_storeu_ps( b + ( rs_b * ( 0 ) ) + + ( cs_b * ( jr + 16 ) ), k1, zmm9 ); + // c[0,32-47] + _mm512_mask_storeu_ps( b + ( rs_b * ( 0 ) ) + + ( cs_b * ( jr + 32 ) ), k2, zmm10 ); + // c[0,48-63] + _mm512_mask_storeu_ps( b + ( rs_b * ( 0 ) ) + + ( cs_b * ( jr + 48 ) ), k3, zmm11 ); + + // c[1,0-15] + _mm512_mask_storeu_ps( b + ( rs_b * ( 1 ) ) + + ( cs_b * ( jr + 0 ) ), k0, zmm12 ); + // c[1,16-31] + _mm512_mask_storeu_ps( b + ( rs_b * ( 1 ) ) + + ( cs_b * ( jr + 16 ) ), k1, zmm13 ); + // c[1,32-47] + _mm512_mask_storeu_ps( b + ( rs_b * ( 1 ) ) + + ( cs_b * ( jr + 32 ) ), k2, zmm14 ); + // c[1,48-63] + _mm512_mask_storeu_ps( b + ( rs_b * ( 1 ) ) + + ( cs_b * ( jr + 48 ) ), k3, zmm15 ); + + // c[2,0-15] + _mm512_mask_storeu_ps( b + ( rs_b * ( 2 ) ) + + ( cs_b * ( jr + 0 ) ), k0, zmm16 ); + // c[2,16-31] + _mm512_mask_storeu_ps( b + ( rs_b * ( 2 ) ) + + ( cs_b * ( jr + 16 ) ), k1, zmm17 ); + // c[2,32-47] + _mm512_mask_storeu_ps( b + ( rs_b * ( 2 ) ) + + ( cs_b * ( jr + 32 ) ), k2, zmm18 ); + // c[2,48-63] + _mm512_mask_storeu_ps( b + ( rs_b * ( 2 ) ) + + ( cs_b * ( jr + 48 ) ), k3, zmm19 ); + + post_ops_attr.post_op_c_j += NR_L; + } +} + +LPGEMM_ELTWISE_OPS_M_FRINGE_KERNEL(float,float,f32of32_2x64) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_2x64_OPS_DISABLE, + &&POST_OPS_BIAS_2x64_OPS, + &&POST_OPS_RELU_2x64_OPS, + &&POST_OPS_RELU_SCALE_2x64_OPS, + &&POST_OPS_GELU_TANH_2x64_OPS, + &&POST_OPS_GELU_ERF_2x64_OPS, + &&POST_OPS_CLIP_2x64_OPS, + NULL,// Virtual node for downscale, else segfault + &&POST_OPS_MATRIX_ADD_2x64_OPS, + &&POST_OPS_SWISH_2x64_OPS, + &&POST_OPS_MATRIX_MUL_2x64_OPS + }; + dim_t NR = 64; + + // Registers to use for accumulating C. + __m512 zmm8 = _mm512_setzero_ps(); + __m512 zmm9 = _mm512_setzero_ps(); + __m512 zmm10 = _mm512_setzero_ps(); + __m512 zmm11 = _mm512_setzero_ps(); + + __m512 zmm12 = _mm512_setzero_ps(); + __m512 zmm13 = _mm512_setzero_ps(); + __m512 zmm14 = _mm512_setzero_ps(); + __m512 zmm15 = _mm512_setzero_ps(); + + __m512 zmm1 = _mm512_setzero_ps(); + __m512 zmm2 = _mm512_setzero_ps(); + __m512 zmm3 = _mm512_setzero_ps(); + __m512 zmm4 = _mm512_setzero_ps(); + + __mmask16 k0 = 0xFFFF, k1 = 0xFFFF, k2 = 0xFFFF, k3 = 0xFFFF; + + dim_t NR_L = NR; + for( dim_t jr = 0; jr < n0; jr += NR_L ) + { + dim_t n_left = n0 - jr; + NR_L = bli_min( NR_L, ( n_left >> 4 ) << 4 ); + if( NR_L == 0 ) { NR_L = 16; } + + dim_t nr0 = bli_min( n0 - jr, NR_L ); + if( nr0 == 64 ) + { + // all masks are already set. + // Nothing to modify. + } + else if( nr0 == 48 ) + { + k3 = 0x0; + } + else if( nr0 == 32 ) + { + k2 = k3 = 0x0; + } + else if( nr0 == 16 ) + { + k1 = k2 = k3 = 0; + } + else if( nr0 < 16 ) + { + k0 = (0xFFFF >> (16 - (nr0 & 0x0F))); + k1 = k2 = k3 = 0; + } + + // 1stx64 block. + zmm8 = _mm512_maskz_loadu_ps( k0, \ + a + ( rs_a * ( 0 ) ) + ( cs_a * ( jr + 0 ) )); + zmm9 = _mm512_maskz_loadu_ps( k1, \ + a + ( rs_a * ( 0 ) ) + ( cs_a * ( jr + 16 ) ) ); + zmm10 = _mm512_maskz_loadu_ps( k2, \ + a + ( rs_a * ( 0 ) ) + ( cs_a * ( jr + 32 ) ) ); + zmm11 = _mm512_maskz_loadu_ps( k3, \ + a + ( rs_a * ( 0 ) ) + ( cs_a * ( jr + 48 ) ) ); + + // 2ndx64 block. + zmm12 = _mm512_maskz_loadu_ps( k0, \ + a + ( rs_a * ( 1 ) ) + ( cs_a * ( jr + 0 ) )); + zmm13 = _mm512_maskz_loadu_ps( k1, \ + a + ( rs_a * ( 1 ) ) + ( cs_a * ( jr + 16 ) ) ); + zmm14 = _mm512_maskz_loadu_ps( k2, \ + a + ( rs_a * ( 1 ) ) + ( cs_a * ( jr + 32 ) ) ); + zmm15 = _mm512_maskz_loadu_ps( k3, \ + a + ( rs_a * ( 1 ) ) + ( cs_a * ( jr + 48 ) ) ); + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP + +POST_OPS_BIAS_2x64_OPS: + { + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + zmm1 = + _mm512_maskz_loadu_ps( k0, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + zmm2 = + _mm512_maskz_loadu_ps( k1, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + zmm3 = + _mm512_maskz_loadu_ps( k2, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + zmm4 = + _mm512_maskz_loadu_ps( k3, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + + // c[0,0-15] + zmm8 = _mm512_add_ps( zmm1, zmm8 ); + + // c[0, 16-31] + zmm9 = _mm512_add_ps( zmm2, zmm9 ); + + // c[0,32-47] + zmm10 = _mm512_add_ps( zmm3, zmm10 ); + + // c[0,48-63] + zmm11 = _mm512_add_ps( zmm4, zmm11 ); + + // c[1,0-15] + zmm12 = _mm512_add_ps( zmm1, zmm12 ); + + // c[1, 16-31] + zmm13 = _mm512_add_ps( zmm2, zmm13 ); + + // c[1,32-47] + zmm14 = _mm512_add_ps( zmm3, zmm14 ); + + // c[1,48-63] + zmm15 = _mm512_add_ps( zmm4, zmm15 ); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the bias array will be accessed by + // the ic index, and each bias element corresponds to an + // entire row of the transposed output array, instead of an + // entire column. + zmm1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + zmm2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 1 ) ); + + // c[0,0-15] + zmm8 = _mm512_add_ps( zmm1, zmm8 ); + + // c[0, 16-31] + zmm9 = _mm512_add_ps( zmm1, zmm9 ); + + // c[0,32-47] + zmm10 = _mm512_add_ps( zmm1, zmm10 ); + + // c[0,48-63] + zmm11 = _mm512_add_ps( zmm1, zmm11 ); + + // c[1,0-15] + zmm12 = _mm512_add_ps( zmm2, zmm12 ); + + // c[1, 16-31] + zmm13 = _mm512_add_ps( zmm2, zmm13 ); + + // c[1,32-47] + zmm14 = _mm512_add_ps( zmm2, zmm14 ); + + // c[1,48-63] + zmm15 = _mm512_add_ps( zmm2, zmm15 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_2x64_OPS: + { + zmm1 = _mm512_setzero_ps(); + + // c[0,0-15] + zmm8 = _mm512_max_ps( zmm1, zmm8 ); + + // c[0, 16-31] + zmm9 = _mm512_max_ps( zmm1, zmm9 ); + + // c[0,32-47] + zmm10 = _mm512_max_ps( zmm1, zmm10 ); + + // c[0,48-63] + zmm11 = _mm512_max_ps( zmm1, zmm11 ); + + // c[1,0-15] + zmm12 = _mm512_max_ps( zmm1, zmm12 ); + + // c[1,16-31] + zmm13 = _mm512_max_ps( zmm1, zmm13 ); + + // c[1,32-47] + zmm14 = _mm512_max_ps( zmm1, zmm14 ); + + // c[1,48-63] + zmm15 = _mm512_max_ps( zmm1, zmm15 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_2x64_OPS: + { + zmm1 = _mm512_setzero_ps(); + zmm2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32S_AVX512(zmm8) + + // c[0, 16-31] + RELU_SCALE_OP_F32S_AVX512(zmm9) + + // c[0, 32-47] + RELU_SCALE_OP_F32S_AVX512(zmm10) + + // c[0, 48-63] + RELU_SCALE_OP_F32S_AVX512(zmm11) + + // c[1, 0-15] + RELU_SCALE_OP_F32S_AVX512(zmm12) + + // c[1, 16-31] + RELU_SCALE_OP_F32S_AVX512(zmm13) + + // c[1, 32-47] + RELU_SCALE_OP_F32S_AVX512(zmm14) + + // c[1, 48-63] + RELU_SCALE_OP_F32S_AVX512(zmm15) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_TANH_2x64_OPS: + { + __m512 dn, z, x, r2, r, x_tanh; + __m512i q; + + // c[0, 0-15] + GELU_TANH_F32S_AVX512(zmm8, r, r2, x, z, dn, x_tanh, q) + + // c[0, 16-31] + GELU_TANH_F32S_AVX512(zmm9, r, r2, x, z, dn, x_tanh, q) + + // c[0, 32-47] + GELU_TANH_F32S_AVX512(zmm10, r, r2, x, z, dn, x_tanh, q) + + // c[0, 48-63] + GELU_TANH_F32S_AVX512(zmm11, r, r2, x, z, dn, x_tanh, q) + + // c[1, 0-15] + GELU_TANH_F32S_AVX512(zmm12, r, r2, x, z, dn, x_tanh, q) + + // c[1, 16-31] + GELU_TANH_F32S_AVX512(zmm13, r, r2, x, z, dn, x_tanh, q) + + // c[1, 32-47] + GELU_TANH_F32S_AVX512(zmm14, r, r2, x, z, dn, x_tanh, q) + + // c[1, 48-63] + GELU_TANH_F32S_AVX512(zmm15, r, r2, x, z, dn, x_tanh, q) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_ERF_2x64_OPS: + { + __m512 x, r, x_erf; + + // c[0, 0-15] + GELU_ERF_F32S_AVX512(zmm8, r, x, x_erf) + + // c[0, 16-31] + GELU_ERF_F32S_AVX512(zmm9, r, x, x_erf) + + // c[0, 32-47] + GELU_ERF_F32S_AVX512(zmm10, r, x, x_erf) + + // c[0, 48-63] + GELU_ERF_F32S_AVX512(zmm11, r, x, x_erf) + + // c[1, 0-15] + GELU_ERF_F32S_AVX512(zmm12, r, x, x_erf) + + // c[1, 16-31] + GELU_ERF_F32S_AVX512(zmm13, r, x, x_erf) + + // c[1, 32-47] + GELU_ERF_F32S_AVX512(zmm14, r, x, x_erf) + + // c[1, 48-63] + GELU_ERF_F32S_AVX512(zmm15, r, x, x_erf) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_CLIP_2x64_OPS: + { + __m512 min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 ); + __m512 max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 ); + + // c[0, 0-15] + CLIP_F32S_AVX512(zmm8, min, max) + + // c[0, 16-31] + CLIP_F32S_AVX512(zmm9, min, max) + + // c[0, 32-47] + CLIP_F32S_AVX512(zmm10, min, max) + + // c[0, 48-63] + CLIP_F32S_AVX512(zmm11, min, max) + + // c[1, 0-15] + CLIP_F32S_AVX512(zmm12, min, max) + + // c[1, 16-31] + CLIP_F32S_AVX512(zmm13, min, max) + + // c[1, 32-47] + CLIP_F32S_AVX512(zmm14, min, max) + + // c[1, 48-63] + CLIP_F32S_AVX512(zmm15, min, max) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_2x64_OPS: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + // It is expected the post-op matrix arg has the same storage + // order as the output C matrix. + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,8,9,10,11,zmm1,zmm2,zmm3,zmm4,0); + + // c[1:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,12,13,14,15,zmm1,zmm2,zmm3,zmm4,1); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_2x64_OPS: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + // It is expected the post-op matrix arg has the same storage + // order as the output C matrix. + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,8,9,10,11,zmm1,zmm2,zmm3,zmm4,0); + + // c[1:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,12,13,14,15,zmm1,zmm2,zmm3,zmm4,1); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_2x64_OPS: + { + zmm1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(zmm8, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[0, 16-31] + SWISH_F32_AVX512_DEF(zmm9, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[0, 32-47] + SWISH_F32_AVX512_DEF(zmm10, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[0, 48-63] + SWISH_F32_AVX512_DEF(zmm11, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[1, 0-15] + SWISH_F32_AVX512_DEF(zmm12, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[1, 16-31] + SWISH_F32_AVX512_DEF(zmm13, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[1, 32-47] + SWISH_F32_AVX512_DEF(zmm14, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[1, 48-63] + SWISH_F32_AVX512_DEF(zmm15, zmm1, al_in, r, r2, z, dn, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_2x64_OPS_DISABLE: + ; + + // Store the results. + // c[0,0-15] + _mm512_mask_storeu_ps( b + ( rs_b * ( 0 ) ) + + ( cs_b * ( jr + 0 ) ), k0, zmm8 ); + // c[0,16-31] + _mm512_mask_storeu_ps( b + ( rs_b * ( 0 ) ) + + ( cs_b * ( jr + 16 ) ), k1, zmm9 ); + // c[0,32-47] + _mm512_mask_storeu_ps( b + ( rs_b * ( 0 ) ) + + ( cs_b * ( jr + 32 ) ), k2, zmm10 ); + // c[0,48-63] + _mm512_mask_storeu_ps( b + ( rs_b * ( 0 ) ) + + ( cs_b * ( jr + 48 ) ), k3, zmm11 ); + + // c[1,0-15] + _mm512_mask_storeu_ps( b + ( rs_b * ( 1 ) ) + + ( cs_b * ( jr + 0 ) ), k0, zmm12 ); + // c[1,16-31] + _mm512_mask_storeu_ps( b + ( rs_b * ( 1 ) ) + + ( cs_b * ( jr + 16 ) ), k1, zmm13 ); + // c[1,32-47] + _mm512_mask_storeu_ps( b + ( rs_b * ( 1 ) ) + + ( cs_b * ( jr + 32 ) ), k2, zmm14 ); + // c[1,48-63] + _mm512_mask_storeu_ps( b + ( rs_b * ( 1 ) ) + + ( cs_b * ( jr + 48 ) ), k3, zmm15 ); + + post_ops_attr.post_op_c_j += NR_L; + } +} + +LPGEMM_ELTWISE_OPS_M_FRINGE_KERNEL(float,float,f32of32_1x64) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_1x64_OPS_DISABLE, + &&POST_OPS_BIAS_1x64_OPS, + &&POST_OPS_RELU_1x64_OPS, + &&POST_OPS_RELU_SCALE_1x64_OPS, + &&POST_OPS_GELU_TANH_1x64_OPS, + &&POST_OPS_GELU_ERF_1x64_OPS, + &&POST_OPS_CLIP_1x64_OPS, + NULL,// Virtual node for downscale, else segfault + &&POST_OPS_MATRIX_ADD_1x64_OPS, + &&POST_OPS_SWISH_1x64_OPS, + &&POST_OPS_MATRIX_MUL_1x64_OPS + }; + dim_t NR = 64; + + // Registers to use for accumulating C. + __m512 zmm8 = _mm512_setzero_ps(); + __m512 zmm9 = _mm512_setzero_ps(); + __m512 zmm10 = _mm512_setzero_ps(); + __m512 zmm11 = _mm512_setzero_ps(); + + __m512 zmm1 = _mm512_setzero_ps(); + __m512 zmm2 = _mm512_setzero_ps(); + __m512 zmm3 = _mm512_setzero_ps(); + __m512 zmm4 = _mm512_setzero_ps(); + + __mmask16 k0 = 0xFFFF, k1 = 0xFFFF, k2 = 0xFFFF, k3 = 0xFFFF; + + dim_t NR_L = NR; + for( dim_t jr = 0; jr < n0; jr += NR_L ) + { + dim_t n_left = n0 - jr; + NR_L = bli_min( NR_L, ( n_left >> 4 ) << 4 ); + if( NR_L == 0 ) { NR_L = 16; } + + dim_t nr0 = bli_min( n0 - jr, NR_L ); + if( nr0 == 64 ) + { + // all masks are already set. + // Nothing to modify. + } + else if( nr0 == 48 ) + { + k3 = 0x0; + } + else if( nr0 == 32 ) + { + k2 = k3 = 0x0; + } + else if( nr0 == 16 ) + { + k1 = k2 = k3 = 0; + } + else if( nr0 < 16 ) + { + k0 = (0xFFFF >> (16 - (nr0 & 0x0F))); + k1 = k2 = k3 = 0; + } + + // 1stx64 block. + zmm8 = _mm512_maskz_loadu_ps( k0, \ + a + ( rs_a * ( 0 ) ) + ( cs_a * ( jr + 0 ) ) ); + zmm9 = _mm512_maskz_loadu_ps( k1, \ + a + ( rs_a * ( 0 ) ) + ( cs_a * ( jr + 16 ) ) ); + zmm10 = _mm512_maskz_loadu_ps( k2, \ + a + ( rs_a * ( 0 ) ) + ( cs_a * ( jr + 32 ) ) ); + zmm11 = _mm512_maskz_loadu_ps( k3, \ + a + ( rs_a * ( 0 ) ) + ( cs_a * ( jr + 48 ) ) ); + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP + +POST_OPS_BIAS_1x64_OPS: + { + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + zmm1 = + _mm512_maskz_loadu_ps( k0, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + zmm2 = + _mm512_maskz_loadu_ps( k1, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + zmm3 = + _mm512_maskz_loadu_ps( k2, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + zmm4 = + _mm512_maskz_loadu_ps( k3, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + + // c[0,0-15] + zmm8 = _mm512_add_ps( zmm1, zmm8 ); + + // c[0, 16-31] + zmm9 = _mm512_add_ps( zmm2, zmm9 ); + + // c[0,32-47] + zmm10 = _mm512_add_ps( zmm3, zmm10 ); + + // c[0,48-63] + zmm11 = _mm512_add_ps( zmm4, zmm11 ); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the bias array will be accessed by + // the ic index, and each bias element corresponds to an + // entire row of the transposed output array, instead of an + // entire column. + zmm1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + + // c[0,0-15] + zmm8 = _mm512_add_ps( zmm1, zmm8 ); + + // c[0, 16-31] + zmm9 = _mm512_add_ps( zmm1, zmm9 ); + + // c[0,32-47] + zmm10 = _mm512_add_ps( zmm1, zmm10 ); + + // c[0,48-63] + zmm11 = _mm512_add_ps( zmm1, zmm11 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_1x64_OPS: + { + zmm1 = _mm512_setzero_ps(); + + // c[0,0-15] + zmm8 = _mm512_max_ps( zmm1, zmm8 ); + + // c[0, 16-31] + zmm9 = _mm512_max_ps( zmm1, zmm9 ); + + // c[0,32-47] + zmm10 = _mm512_max_ps( zmm1, zmm10 ); + + // c[0,48-63] + zmm11 = _mm512_max_ps( zmm1, zmm11 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_1x64_OPS: + { + zmm1 = _mm512_setzero_ps(); + zmm2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32S_AVX512(zmm8) + + // c[0, 16-31] + RELU_SCALE_OP_F32S_AVX512(zmm9) + + // c[0, 32-47] + RELU_SCALE_OP_F32S_AVX512(zmm10) + + // c[0, 48-63] + RELU_SCALE_OP_F32S_AVX512(zmm11) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_TANH_1x64_OPS: + { + __m512 dn, z, x, r2, r, x_tanh; + __m512i q; + + // c[0, 0-15] + GELU_TANH_F32S_AVX512(zmm8, r, r2, x, z, dn, x_tanh, q) + + // c[0, 16-31] + GELU_TANH_F32S_AVX512(zmm9, r, r2, x, z, dn, x_tanh, q) + + // c[0, 32-47] + GELU_TANH_F32S_AVX512(zmm10, r, r2, x, z, dn, x_tanh, q) + + // c[0, 48-63] + GELU_TANH_F32S_AVX512(zmm11, r, r2, x, z, dn, x_tanh, q) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_ERF_1x64_OPS: + { + __m512 x, r, x_erf; + + // c[0, 0-15] + GELU_ERF_F32S_AVX512(zmm8, r, x, x_erf) + + // c[0, 16-31] + GELU_ERF_F32S_AVX512(zmm9, r, x, x_erf) + + // c[0, 32-47] + GELU_ERF_F32S_AVX512(zmm10, r, x, x_erf) + + // c[0, 48-63] + GELU_ERF_F32S_AVX512(zmm11, r, x, x_erf) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_CLIP_1x64_OPS: + { + __m512 min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 ); + __m512 max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 ); + + // c[0, 0-15] + CLIP_F32S_AVX512(zmm8, min, max) + + // c[0, 16-31] + CLIP_F32S_AVX512(zmm9, min, max) + + // c[0, 32-47] + CLIP_F32S_AVX512(zmm10, min, max) + + // c[0, 48-63] + CLIP_F32S_AVX512(zmm11, min, max) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_1x64_OPS: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + // It is expected the post-op matrix arg has the same storage + // order as the output C matrix. + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,8,9,10,11,zmm1,zmm2,zmm3,zmm4,0); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_1x64_OPS: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + // It is expected the post-op matrix arg has the same storage + // order as the output C matrix. + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,8,9,10,11,zmm1,zmm2,zmm3,zmm4,0); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_1x64_OPS: + { + zmm1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(zmm8, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[0, 16-31] + SWISH_F32_AVX512_DEF(zmm9, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[0, 32-47] + SWISH_F32_AVX512_DEF(zmm10, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[0, 48-63] + SWISH_F32_AVX512_DEF(zmm11, zmm1, al_in, r, r2, z, dn, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_1x64_OPS_DISABLE: + ; + + // Store the results. + // c[0,0-15] + _mm512_mask_storeu_ps( b + ( rs_b * ( 0 ) ) + + ( cs_b * ( jr + 0 ) ), k0, zmm8 ); + // c[0,16-31] + _mm512_mask_storeu_ps( b + ( rs_b * ( 0 ) ) + + ( cs_b * ( jr + 16 ) ), k1, zmm9 ); + // c[0,32-47] + _mm512_mask_storeu_ps( b + ( rs_b * ( 0 ) ) + + ( cs_b * ( jr + 32 ) ), k2, zmm10 ); + // c[0,48-63] + _mm512_mask_storeu_ps( b + ( rs_b * ( 0 ) ) + + ( cs_b * ( jr + 48 ) ), k3, zmm11 ); + + post_ops_attr.post_op_c_j += NR_L; + } +} + +#endif diff --git a/kernels/zen4/lpgemm/f32f32f32/lpgemm_eltwise_ops_m_kernel_f32_avx512.c b/kernels/zen4/lpgemm/f32f32f32/lpgemm_eltwise_ops_m_kernel_f32_avx512.c new file mode 100644 index 0000000000..7ad8f17096 --- /dev/null +++ b/kernels/zen4/lpgemm/f32f32f32/lpgemm_eltwise_ops_m_kernel_f32_avx512.c @@ -0,0 +1,1081 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "blis.h" + +#ifdef BLIS_ADDON_LPGEMM + +#include "lpgemm_kernel_macros_f32.h" + +LPGEMM_ELTWISE_OPS_KERNEL(float,float,f32of32_6x64) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_6x64_OPS_DISABLE, + &&POST_OPS_BIAS_6x64_OPS, + &&POST_OPS_RELU_6x64_OPS, + &&POST_OPS_RELU_SCALE_6x64_OPS, + &&POST_OPS_GELU_TANH_6x64_OPS, + &&POST_OPS_GELU_ERF_6x64_OPS, + &&POST_OPS_CLIP_6x64_OPS, + NULL,// Virtual node for downscale, else segfault + &&POST_OPS_MATRIX_ADD_6x64_OPS, + &&POST_OPS_SWISH_6x64_OPS, + &&POST_OPS_MATRIX_MUL_6x64_OPS + }; + dim_t MR = 6; + dim_t NR = 64; + + dim_t m_full_pieces = m0 / MR; + dim_t m_full_pieces_loop_limit = m_full_pieces * MR; + dim_t m_partial_pieces = m0 % MR; + + // Registers to use for accumulating C. + __m512 zmm8 = _mm512_setzero_ps(); + __m512 zmm9 = _mm512_setzero_ps(); + __m512 zmm10 = _mm512_setzero_ps(); + __m512 zmm11 = _mm512_setzero_ps(); + + __m512 zmm12 = _mm512_setzero_ps(); + __m512 zmm13 = _mm512_setzero_ps(); + __m512 zmm14 = _mm512_setzero_ps(); + __m512 zmm15 = _mm512_setzero_ps(); + + __m512 zmm16 = _mm512_setzero_ps(); + __m512 zmm17 = _mm512_setzero_ps(); + __m512 zmm18 = _mm512_setzero_ps(); + __m512 zmm19 = _mm512_setzero_ps(); + + __m512 zmm20 = _mm512_setzero_ps(); + __m512 zmm21 = _mm512_setzero_ps(); + __m512 zmm22 = _mm512_setzero_ps(); + __m512 zmm23 = _mm512_setzero_ps(); + + __m512 zmm24 = _mm512_setzero_ps(); + __m512 zmm25 = _mm512_setzero_ps(); + __m512 zmm26 = _mm512_setzero_ps(); + __m512 zmm27 = _mm512_setzero_ps(); + + __m512 zmm28 = _mm512_setzero_ps(); + __m512 zmm29 = _mm512_setzero_ps(); + __m512 zmm30 = _mm512_setzero_ps(); + __m512 zmm31 = _mm512_setzero_ps(); + + __m512 zmm1 = _mm512_setzero_ps(); + __m512 zmm2 = _mm512_setzero_ps(); + __m512 zmm3 = _mm512_setzero_ps(); + __m512 zmm4 = _mm512_setzero_ps(); + + uint64_t orig_post_op_c_j = post_ops_attr.post_op_c_j; + for ( dim_t ir = 0; ir < m_full_pieces_loop_limit; ir += MR ) + { + __mmask16 k0 = 0xFFFF, k1 = 0xFFFF, k2 = 0xFFFF, k3 = 0xFFFF; + + dim_t NR_L = NR; + for( dim_t jr = 0; jr < n0; jr += NR_L ) + { + dim_t n_left = n0 - jr; + NR_L = bli_min( NR_L, ( n_left >> 4 ) << 4 ); + if( NR_L == 0 ) { NR_L = 16; } + + dim_t nr0 = bli_min( n0 - jr, NR_L ); + if( nr0 == 64 ) + { + // all masks are already set. + // Nothing to modify. + } + else if( nr0 == 48 ) + { + k3 = 0x0; + } + else if( nr0 == 32 ) + { + k2 = k3 = 0x0; + } + else if( nr0 == 16 ) + { + k1 = k2 = k3 = 0; + } + else if( nr0 < 16 ) + { + k0 = (0xFFFF >> (16 - (nr0 & 0x0F))); + k1 = k2 = k3 = 0; + } + + // 1stx64 block. + zmm8 = _mm512_maskz_loadu_ps(k0, a + ( rs_a * ( ir + 0 ) ) + + ( cs_a * ( jr + 0 ) ) ); + zmm9 = _mm512_maskz_loadu_ps(k1, a + ( rs_a * ( ir + 0 ) ) + + ( cs_a * ( jr + 16 ) ) ); + zmm10 = _mm512_maskz_loadu_ps(k2, a + ( rs_a * ( ir + 0 ) ) + + ( cs_a * ( jr + 32 ) ) ); + zmm11 = _mm512_maskz_loadu_ps(k3, a + ( rs_a * ( ir + 0 ) ) + + ( cs_a * ( jr + 48 ) ) ); + + // 2ndx64 block. + zmm12 = _mm512_maskz_loadu_ps(k0, a + ( rs_a * ( ir + 1 ) ) + + ( cs_a * ( jr + 0 ) ) ); + zmm13 = _mm512_maskz_loadu_ps(k1, a + ( rs_a * ( ir + 1 ) ) + + ( cs_a * ( jr + 16 ) ) ); + zmm14 = _mm512_maskz_loadu_ps(k2, a + ( rs_a * ( ir + 1 ) ) + + ( cs_a * ( jr + 32 ) ) ); + zmm15 = _mm512_maskz_loadu_ps(k3, a + ( rs_a * ( ir + 1 ) ) + + ( cs_a * ( jr + 48 ) ) ); + + // 3rdx64 block. + zmm16 = _mm512_maskz_loadu_ps(k0, a + ( rs_a * ( ir + 2 ) ) + + ( cs_a * ( jr + 0 ) ) ); + zmm17 = _mm512_maskz_loadu_ps(k1, a + ( rs_a * ( ir + 2 ) ) + + ( cs_a * ( jr + 16 ) ) ); + zmm18 = _mm512_maskz_loadu_ps(k2, a + ( rs_a * ( ir + 2 ) ) + + ( cs_a * ( jr + 32 ) ) ); + zmm19 = _mm512_maskz_loadu_ps(k3, a + ( rs_a * ( ir + 2 ) ) + + ( cs_a * ( jr + 48 ) ) ); + + // 4thx64 block. + zmm20 = _mm512_maskz_loadu_ps(k0, a + ( rs_a * ( ir + 3 ) ) + + ( cs_a * ( jr + 0 ) ) ); + zmm21 = _mm512_maskz_loadu_ps(k1, a + ( rs_a * ( ir + 3 ) ) + + ( cs_a * ( jr + 16 ) ) ); + zmm22 = _mm512_maskz_loadu_ps(k2, a + ( rs_a * ( ir + 3 ) ) + + ( cs_a * ( jr + 32 ) ) ); + zmm23 = _mm512_maskz_loadu_ps(k3, a + ( rs_a * ( ir + 3 ) ) + + ( cs_a * ( jr + 48 ) ) ); + + // 5thx64 block. + zmm24 = _mm512_maskz_loadu_ps(k0, a + ( rs_a * ( ir + 4 ) ) + + ( cs_a * ( jr + 0 ) ) ); + zmm25 = _mm512_maskz_loadu_ps(k1, a + ( rs_a * ( ir + 4 ) ) + + ( cs_a * ( jr + 16 ) ) ); + zmm26 = _mm512_maskz_loadu_ps(k2, a + ( rs_a * ( ir + 4 ) ) + + ( cs_a * ( jr + 32 ) ) ); + zmm27 = _mm512_maskz_loadu_ps(k3, a + ( rs_a * ( ir + 4 ) ) + + ( cs_a * ( jr + 48 ) ) ); + + // 6thx64 block. + zmm28 = _mm512_maskz_loadu_ps(k0, a + ( rs_a * ( ir + 5 ) ) + + ( cs_a * ( jr + 0 ) ) ); + zmm29 = _mm512_maskz_loadu_ps(k1, a + ( rs_a * ( ir + 5 ) ) + + ( cs_a * ( jr + 16 ) ) ); + zmm30 = _mm512_maskz_loadu_ps(k2, a + ( rs_a * ( ir + 5 ) ) + + ( cs_a * ( jr + 32 ) ) ); + zmm31 = _mm512_maskz_loadu_ps(k3, a + ( rs_a * ( ir + 5 ) ) + + ( cs_a * ( jr + 48 ) ) ); + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP + +POST_OPS_BIAS_6x64_OPS: + { + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + zmm1 =_mm512_maskz_loadu_ps( k0, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + zmm2 = + _mm512_maskz_loadu_ps( k1, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + zmm3 = + _mm512_maskz_loadu_ps( k2, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + zmm4 = + _mm512_maskz_loadu_ps( k3, + ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + + // c[0,0-15] + zmm8 = _mm512_add_ps( zmm1, zmm8 ); + + // c[0, 16-31] + zmm9 = _mm512_add_ps( zmm2, zmm9 ); + + // c[0,32-47] + zmm10 = _mm512_add_ps( zmm3, zmm10 ); + + // c[0,48-63] + zmm11 = _mm512_add_ps( zmm4, zmm11 ); + + // c[1,0-15] + zmm12 = _mm512_add_ps( zmm1, zmm12 ); + + // c[1, 16-31] + zmm13 = _mm512_add_ps( zmm2, zmm13 ); + + // c[1,32-47] + zmm14 = _mm512_add_ps( zmm3, zmm14 ); + + // c[1,48-63] + zmm15 = _mm512_add_ps( zmm4, zmm15 ); + + // c[2,0-15] + zmm16 = _mm512_add_ps( zmm1, zmm16 ); + + // c[2, 16-31] + zmm17 = _mm512_add_ps( zmm2, zmm17 ); + + // c[2,32-47] + zmm18 = _mm512_add_ps( zmm3, zmm18 ); + + // c[2,48-63] + zmm19 = _mm512_add_ps( zmm4, zmm19 ); + + // c[3,0-15] + zmm20 = _mm512_add_ps( zmm1, zmm20 ); + + // c[3, 16-31] + zmm21 = _mm512_add_ps( zmm2, zmm21 ); + + // c[3,32-47] + zmm22 = _mm512_add_ps( zmm3, zmm22 ); + + // c[3,48-63] + zmm23 = _mm512_add_ps( zmm4, zmm23 ); + + // c[4,0-15] + zmm24 = _mm512_add_ps( zmm1, zmm24 ); + + // c[4, 16-31] + zmm25 = _mm512_add_ps( zmm2, zmm25 ); + + // c[4,32-47] + zmm26 = _mm512_add_ps( zmm3, zmm26 ); + + // c[4,48-63] + zmm27 = _mm512_add_ps( zmm4, zmm27 ); + + // c[5,0-15] + zmm28 = _mm512_add_ps( zmm1, zmm28 ); + + // c[5, 16-31] + zmm29 = _mm512_add_ps( zmm2, zmm29 ); + + // c[5,32-47] + zmm30 = _mm512_add_ps( zmm3, zmm30 ); + + // c[5,48-63] + zmm31 = _mm512_add_ps( zmm4, zmm31 ); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the bias array will be accessed by + // the ic index, and each bias element corresponds to an + // entire row of the transposed output array, instead of an + // entire column. + __m512 selector5; + __m512 selector6; + + zmm1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0 ) ); + zmm2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 1 ) ); + zmm3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 2 ) ); + zmm4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 3 ) ); + selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 4 ) ); + selector6 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 5 ) ); + + // c[0,0-15] + zmm8 = _mm512_add_ps( zmm1, zmm8 ); + + // c[0, 16-31] + zmm9 = _mm512_add_ps( zmm1, zmm9 ); + + // c[0,32-47] + zmm10 = _mm512_add_ps( zmm1, zmm10 ); + + // c[0,48-63] + zmm11 = _mm512_add_ps( zmm1, zmm11 ); + + // c[1,0-15] + zmm12 = _mm512_add_ps( zmm2, zmm12 ); + + // c[1, 16-31] + zmm13 = _mm512_add_ps( zmm2, zmm13 ); + + // c[1,32-47] + zmm14 = _mm512_add_ps( zmm2, zmm14 ); + + // c[1,48-63] + zmm15 = _mm512_add_ps( zmm2, zmm15 ); + + // c[2,0-15] + zmm16 = _mm512_add_ps( zmm3, zmm16 ); + + // c[2, 16-31] + zmm17 = _mm512_add_ps( zmm3, zmm17 ); + + // c[2,32-47] + zmm18 = _mm512_add_ps( zmm3, zmm18 ); + + // c[2,48-63] + zmm19 = _mm512_add_ps( zmm3, zmm19 ); + + // c[3,0-15] + zmm20 = _mm512_add_ps( zmm4, zmm20 ); + + // c[3, 16-31] + zmm21 = _mm512_add_ps( zmm4, zmm21 ); + + // c[3,32-47] + zmm22 = _mm512_add_ps( zmm4, zmm22 ); + + // c[3,48-63] + zmm23 = _mm512_add_ps( zmm4, zmm23 ); + + // c[4,0-15] + zmm24 = _mm512_add_ps( selector5, zmm24 ); + + // c[4, 16-31] + zmm25 = _mm512_add_ps( selector5, zmm25 ); + + // c[4,32-47] + zmm26 = _mm512_add_ps( selector5, zmm26 ); + + // c[4,48-63] + zmm27 = _mm512_add_ps( selector5, zmm27 ); + + // c[5,0-15] + zmm28 = _mm512_add_ps( selector6, zmm28 ); + + // c[5, 16-31] + zmm29 = _mm512_add_ps( selector6, zmm29 ); + + // c[5,32-47] + zmm30 = _mm512_add_ps( selector6, zmm30 ); + + // c[5,48-63] + zmm31 = _mm512_add_ps( selector6, zmm31 ); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_6x64_OPS: + { + zmm1 = _mm512_setzero_ps(); + + // c[0,0-15] + zmm8 = _mm512_max_ps( zmm1, zmm8 ); + + // c[0, 16-31] + zmm9 = _mm512_max_ps( zmm1, zmm9 ); + + // c[0,32-47] + zmm10 = _mm512_max_ps( zmm1, zmm10 ); + + // c[0,48-63] + zmm11 = _mm512_max_ps( zmm1, zmm11 ); + + // c[1,0-15] + zmm12 = _mm512_max_ps( zmm1, zmm12 ); + + // c[1,16-31] + zmm13 = _mm512_max_ps( zmm1, zmm13 ); + + // c[1,32-47] + zmm14 = _mm512_max_ps( zmm1, zmm14 ); + + // c[1,48-63] + zmm15 = _mm512_max_ps( zmm1, zmm15 ); + + // c[2,0-15] + zmm16 = _mm512_max_ps( zmm1, zmm16 ); + + // c[2,16-31] + zmm17 = _mm512_max_ps( zmm1, zmm17 ); + + // c[2,32-47] + zmm18 = _mm512_max_ps( zmm1, zmm18 ); + + // c[2,48-63] + zmm19 = _mm512_max_ps( zmm1, zmm19 ); + + // c[3,0-15] + zmm20 = _mm512_max_ps( zmm1, zmm20 ); + + // c[3,16-31] + zmm21 = _mm512_max_ps( zmm1, zmm21 ); + + // c[3,32-47] + zmm22 = _mm512_max_ps( zmm1, zmm22 ); + + // c[3,48-63] + zmm23 = _mm512_max_ps( zmm1, zmm23 ); + + // c[4,0-15] + zmm24 = _mm512_max_ps( zmm1, zmm24 ); + + // c[4,16-31] + zmm25 = _mm512_max_ps( zmm1, zmm25 ); + + // c[4,32-47] + zmm26 = _mm512_max_ps( zmm1, zmm26 ); + + // c[4,48-63] + zmm27 = _mm512_max_ps( zmm1, zmm27 ); + + // c[5,0-15] + zmm28 = _mm512_max_ps( zmm1, zmm28 ); + + // c[5,16-31] + zmm29 = _mm512_max_ps( zmm1, zmm29 ); + + // c[5,32-47] + zmm30 = _mm512_max_ps( zmm1, zmm30 ); + + // c[5,48-63] + zmm31 = _mm512_max_ps( zmm1, zmm31 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_6x64_OPS: + { + zmm1 = _mm512_setzero_ps(); + zmm2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32S_AVX512(zmm8) + + // c[0, 16-31] + RELU_SCALE_OP_F32S_AVX512(zmm9) + + // c[0, 32-47] + RELU_SCALE_OP_F32S_AVX512(zmm10) + + // c[0, 48-63] + RELU_SCALE_OP_F32S_AVX512(zmm11) + + // c[1, 0-15] + RELU_SCALE_OP_F32S_AVX512(zmm12) + + // c[1, 16-31] + RELU_SCALE_OP_F32S_AVX512(zmm13) + + // c[1, 32-47] + RELU_SCALE_OP_F32S_AVX512(zmm14) + + // c[1, 48-63] + RELU_SCALE_OP_F32S_AVX512(zmm15) + + // c[2, 0-15] + RELU_SCALE_OP_F32S_AVX512(zmm16) + + // c[2, 16-31] + RELU_SCALE_OP_F32S_AVX512(zmm17) + + // c[2, 32-47] + RELU_SCALE_OP_F32S_AVX512(zmm18) + + // c[2, 48-63] + RELU_SCALE_OP_F32S_AVX512(zmm19) + + // c[3, 0-15] + RELU_SCALE_OP_F32S_AVX512(zmm20) + + // c[3, 16-31] + RELU_SCALE_OP_F32S_AVX512(zmm21) + + // c[3, 32-47] + RELU_SCALE_OP_F32S_AVX512(zmm22) + + // c[3, 48-63] + RELU_SCALE_OP_F32S_AVX512(zmm23) + + // c[4, 0-15] + RELU_SCALE_OP_F32S_AVX512(zmm24) + + // c[4, 16-31] + RELU_SCALE_OP_F32S_AVX512(zmm25) + + // c[4, 32-47] + RELU_SCALE_OP_F32S_AVX512(zmm26) + + // c[4, 48-63] + RELU_SCALE_OP_F32S_AVX512(zmm27) + + // c[5, 0-15] + RELU_SCALE_OP_F32S_AVX512(zmm28) + + // c[5, 16-31] + RELU_SCALE_OP_F32S_AVX512(zmm29) + + // c[5, 32-47] + RELU_SCALE_OP_F32S_AVX512(zmm30) + + // c[5, 48-63] + RELU_SCALE_OP_F32S_AVX512(zmm31) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_TANH_6x64_OPS: + { + __m512 dn, z, x, r2, r, x_tanh; + __m512i q; + + // c[0, 0-15] + GELU_TANH_F32S_AVX512(zmm8, r, r2, x, z, dn, x_tanh, q) + + // c[0, 16-31] + GELU_TANH_F32S_AVX512(zmm9, r, r2, x, z, dn, x_tanh, q) + + // c[0, 32-47] + GELU_TANH_F32S_AVX512(zmm10, r, r2, x, z, dn, x_tanh, q) + + // c[0, 48-63] + GELU_TANH_F32S_AVX512(zmm11, r, r2, x, z, dn, x_tanh, q) + + // c[1, 0-15] + GELU_TANH_F32S_AVX512(zmm12, r, r2, x, z, dn, x_tanh, q) + + // c[1, 16-31] + GELU_TANH_F32S_AVX512(zmm13, r, r2, x, z, dn, x_tanh, q) + + // c[1, 32-47] + GELU_TANH_F32S_AVX512(zmm14, r, r2, x, z, dn, x_tanh, q) + + // c[1, 48-63] + GELU_TANH_F32S_AVX512(zmm15, r, r2, x, z, dn, x_tanh, q) + + // c[2, 0-15] + GELU_TANH_F32S_AVX512(zmm16, r, r2, x, z, dn, x_tanh, q) + + // c[2, 16-31] + GELU_TANH_F32S_AVX512(zmm17, r, r2, x, z, dn, x_tanh, q) + + // c[2, 32-47] + GELU_TANH_F32S_AVX512(zmm18, r, r2, x, z, dn, x_tanh, q) + + // c[2, 48-63] + GELU_TANH_F32S_AVX512(zmm19, r, r2, x, z, dn, x_tanh, q) + + // c[3, 0-15] + GELU_TANH_F32S_AVX512(zmm20, r, r2, x, z, dn, x_tanh, q) + + // c[3, 16-31] + GELU_TANH_F32S_AVX512(zmm21, r, r2, x, z, dn, x_tanh, q) + + // c[3, 32-47] + GELU_TANH_F32S_AVX512(zmm22, r, r2, x, z, dn, x_tanh, q) + + // c[3, 48-63] + GELU_TANH_F32S_AVX512(zmm23, r, r2, x, z, dn, x_tanh, q) + + // c[4, 0-15] + GELU_TANH_F32S_AVX512(zmm24, r, r2, x, z, dn, x_tanh, q) + + // c[4, 16-31] + GELU_TANH_F32S_AVX512(zmm25, r, r2, x, z, dn, x_tanh, q) + + // c[4, 32-47] + GELU_TANH_F32S_AVX512(zmm26, r, r2, x, z, dn, x_tanh, q) + + // c[4, 48-63] + GELU_TANH_F32S_AVX512(zmm27, r, r2, x, z, dn, x_tanh, q) + + // c[5, 0-15] + GELU_TANH_F32S_AVX512(zmm28, r, r2, x, z, dn, x_tanh, q) + + // c[5, 16-31] + GELU_TANH_F32S_AVX512(zmm29, r, r2, x, z, dn, x_tanh, q) + + // c[5, 32-47] + GELU_TANH_F32S_AVX512(zmm30, r, r2, x, z, dn, x_tanh, q) + + // c[5, 48-63] + GELU_TANH_F32S_AVX512(zmm31, r, r2, x, z, dn, x_tanh, q) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_GELU_ERF_6x64_OPS: + { + __m512 x, r, x_erf; + + // c[0, 0-15] + GELU_ERF_F32S_AVX512(zmm8, r, x, x_erf) + + // c[0, 16-31] + GELU_ERF_F32S_AVX512(zmm9, r, x, x_erf) + + // c[0, 32-47] + GELU_ERF_F32S_AVX512(zmm10, r, x, x_erf) + + // c[0, 48-63] + GELU_ERF_F32S_AVX512(zmm11, r, x, x_erf) + + // c[1, 0-15] + GELU_ERF_F32S_AVX512(zmm12, r, x, x_erf) + + // c[1, 16-31] + GELU_ERF_F32S_AVX512(zmm13, r, x, x_erf) + + // c[1, 32-47] + GELU_ERF_F32S_AVX512(zmm14, r, x, x_erf) + + // c[1, 48-63] + GELU_ERF_F32S_AVX512(zmm15, r, x, x_erf) + + // c[2, 0-15] + GELU_ERF_F32S_AVX512(zmm16, r, x, x_erf) + + // c[2, 16-31] + GELU_ERF_F32S_AVX512(zmm17, r, x, x_erf) + + // c[2, 32-47] + GELU_ERF_F32S_AVX512(zmm18, r, x, x_erf) + + // c[2, 48-63] + GELU_ERF_F32S_AVX512(zmm19, r, x, x_erf) + + // c[3, 0-15] + GELU_ERF_F32S_AVX512(zmm20, r, x, x_erf) + + // c[3, 16-31] + GELU_ERF_F32S_AVX512(zmm21, r, x, x_erf) + + // c[3, 32-47] + GELU_ERF_F32S_AVX512(zmm22, r, x, x_erf) + + // c[3, 48-63] + GELU_ERF_F32S_AVX512(zmm23, r, x, x_erf) + + // c[4, 0-15] + GELU_ERF_F32S_AVX512(zmm24, r, x, x_erf) + + // c[4, 16-31] + GELU_ERF_F32S_AVX512(zmm25, r, x, x_erf) + + // c[4, 32-47] + GELU_ERF_F32S_AVX512(zmm26, r, x, x_erf) + + // c[4, 48-63] + GELU_ERF_F32S_AVX512(zmm27, r, x, x_erf) + + // c[5, 0-15] + GELU_ERF_F32S_AVX512(zmm28, r, x, x_erf) + + // c[5, 16-31] + GELU_ERF_F32S_AVX512(zmm29, r, x, x_erf) + + // c[5, 32-47] + GELU_ERF_F32S_AVX512(zmm30, r, x, x_erf) + + // c[5, 48-63] + GELU_ERF_F32S_AVX512(zmm31, r, x, x_erf) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_CLIP_6x64_OPS: + { + __m512 min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 ); + __m512 max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 ); + + // c[0, 0-15] + CLIP_F32S_AVX512(zmm8, min, max) + + // c[0, 16-31] + CLIP_F32S_AVX512(zmm9, min, max) + + // c[0, 32-47] + CLIP_F32S_AVX512(zmm10, min, max) + + // c[0, 48-63] + CLIP_F32S_AVX512(zmm11, min, max) + + // c[1, 0-15] + CLIP_F32S_AVX512(zmm12, min, max) + + // c[1, 16-31] + CLIP_F32S_AVX512(zmm13, min, max) + + // c[1, 32-47] + CLIP_F32S_AVX512(zmm14, min, max) + + // c[1, 48-63] + CLIP_F32S_AVX512(zmm15, min, max) + + // c[2, 0-15] + CLIP_F32S_AVX512(zmm16, min, max) + + // c[2, 16-31] + CLIP_F32S_AVX512(zmm17, min, max) + + // c[2, 32-47] + CLIP_F32S_AVX512(zmm18, min, max) + + // c[2, 48-63] + CLIP_F32S_AVX512(zmm19, min, max) + + // c[3, 0-15] + CLIP_F32S_AVX512(zmm20, min, max) + + // c[3, 16-31] + CLIP_F32S_AVX512(zmm21, min, max) + + // c[3, 32-47] + CLIP_F32S_AVX512(zmm22, min, max) + + // c[3, 48-63] + CLIP_F32S_AVX512(zmm23, min, max) + + // c[4, 0-15] + CLIP_F32S_AVX512(zmm24, min, max) + + // c[4, 16-31] + CLIP_F32S_AVX512(zmm25, min, max) + + // c[4, 32-47] + CLIP_F32S_AVX512(zmm26, min, max) + + // c[4, 48-63] + CLIP_F32S_AVX512(zmm27, min, max) + + // c[5, 0-15] + CLIP_F32S_AVX512(zmm28, min, max) + + // c[5, 16-31] + CLIP_F32S_AVX512(zmm29, min, max) + + // c[5, 32-47] + CLIP_F32S_AVX512(zmm30, min, max) + + // c[5, 48-63] + CLIP_F32S_AVX512(zmm31, min, max) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_6x64_OPS: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + // It is expected the post-op matrix arg has the same storage + // order as the output C matrix. + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,8,9,10,11,zmm1,zmm2,zmm3,zmm4,0); + + // c[1:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,12,13,14,15,zmm1,zmm2,zmm3,zmm4,1); + + // c[2:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,16,17,18,19,zmm1,zmm2,zmm3,zmm4,2); + + // c[3:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,20,21,22,23,zmm1,zmm2,zmm3,zmm4,3); + + // c[4:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,24,25,26,27,zmm1,zmm2,zmm3,zmm4,4); + + // c[5:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,28,29,30,31,zmm1,zmm2,zmm3,zmm4,5); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_6x64_OPS: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + // It is expected the post-op matrix arg has the same storage + // order as the output C matrix. + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,8,9,10,11,zmm1,zmm2,zmm3,zmm4,0); + + // c[1:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,12,13,14,15,zmm1,zmm2,zmm3,zmm4,1); + + // c[2:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,16,17,18,19,zmm1,zmm2,zmm3,zmm4,2); + + // c[3:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,20,21,22,23,zmm1,zmm2,zmm3,zmm4,3); + + // c[4:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,24,25,26,27,zmm1,zmm2,zmm3,zmm4,4); + + // c[5:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,28,29,30,31,zmm1,zmm2,zmm3,zmm4,5); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_6x64_OPS: + { + zmm1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + __m512 al_in, r, r2, z, dn; + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(zmm8, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[0, 16-31] + SWISH_F32_AVX512_DEF(zmm9, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[0, 32-47] + SWISH_F32_AVX512_DEF(zmm10, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[0, 48-63] + SWISH_F32_AVX512_DEF(zmm11, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[1, 0-15] + SWISH_F32_AVX512_DEF(zmm12, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[1, 16-31] + SWISH_F32_AVX512_DEF(zmm13, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[1, 32-47] + SWISH_F32_AVX512_DEF(zmm14, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[1, 48-63] + SWISH_F32_AVX512_DEF(zmm15, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[2, 0-15] + SWISH_F32_AVX512_DEF(zmm16, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[2, 16-31] + SWISH_F32_AVX512_DEF(zmm17, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[2, 32-47] + SWISH_F32_AVX512_DEF(zmm18, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[2, 48-63] + SWISH_F32_AVX512_DEF(zmm19, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[3, 0-15] + SWISH_F32_AVX512_DEF(zmm20, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[3, 16-31] + SWISH_F32_AVX512_DEF(zmm21, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[3, 32-47] + SWISH_F32_AVX512_DEF(zmm22, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[3, 48-63] + SWISH_F32_AVX512_DEF(zmm23, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[4, 0-15] + SWISH_F32_AVX512_DEF(zmm24, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[4, 16-31] + SWISH_F32_AVX512_DEF(zmm25, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[4, 32-47] + SWISH_F32_AVX512_DEF(zmm26, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[4, 48-63] + SWISH_F32_AVX512_DEF(zmm27, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[5, 0-15] + SWISH_F32_AVX512_DEF(zmm28, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[5, 16-31] + SWISH_F32_AVX512_DEF(zmm29, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[5, 32-47] + SWISH_F32_AVX512_DEF(zmm30, zmm1, al_in, r, r2, z, dn, ex_out); + + // c[5, 48-63] + SWISH_F32_AVX512_DEF(zmm31, zmm1, al_in, r, r2, z, dn, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_6x64_OPS_DISABLE: + ; + + // Case where the output C matrix is float + // Store the results. + // c[0,0-15] + _mm512_mask_storeu_ps( b + ( rs_b * ( ir + 0 ) ) + + ( cs_b * ( jr + 0 ) ), k0, zmm8 ); + // c[0,16-31] + _mm512_mask_storeu_ps( b + ( rs_b * ( ir + 0 ) ) + + ( cs_b * ( jr + 16 ) ), k1, zmm9 ); + // c[0,32-47] + _mm512_mask_storeu_ps( b + ( rs_b * ( ir + 0 ) ) + + ( cs_b * ( jr + 32 ) ), k2, zmm10 ); + // c[0,48-63] + _mm512_mask_storeu_ps( b + ( rs_b * ( ir + 0 ) ) + + ( cs_b * ( jr + 48 ) ), k3, zmm11 ); + + // c[1,0-15] + _mm512_mask_storeu_ps( b + ( rs_b * ( ir + 1 ) ) + + ( cs_b * ( jr + 0 ) ), k0, zmm12 ); + // c[1,16-31] + _mm512_mask_storeu_ps( b + ( rs_b * ( ir + 1 ) ) + + ( cs_b * ( jr + 16 ) ), k1, zmm13 ); + // c[1,32-47] + _mm512_mask_storeu_ps( b + ( rs_b * ( ir + 1 ) ) + + ( cs_b * ( jr + 32 ) ), k2, zmm14 ); + // c[1,48-63] + _mm512_mask_storeu_ps( b + ( rs_b * ( ir + 1 ) ) + + ( cs_b * ( jr + 48 ) ), k3, zmm15 ); + + // c[2,0-15] + _mm512_mask_storeu_ps( b + ( rs_b * ( ir + 2 ) ) + + ( cs_b * ( jr + 0 ) ), k0, zmm16 ); + // c[2,16-31] + _mm512_mask_storeu_ps( b + ( rs_b * ( ir + 2 ) ) + + ( cs_b * ( jr + 16 ) ), k1, zmm17 ); + // c[2,32-47] + _mm512_mask_storeu_ps( b + ( rs_b * ( ir + 2 ) ) + + ( cs_b * ( jr + 32 ) ), k2, zmm18 ); + // c[2,48-63] + _mm512_mask_storeu_ps( b + ( rs_b * ( ir + 2 ) ) + + ( cs_b * ( jr + 48 ) ), k3, zmm19 ); + + // c[3,0-15] + _mm512_mask_storeu_ps( b + ( rs_b * ( ir + 3 ) ) + + ( cs_b * ( jr + 0 ) ), k0, zmm20 ); + // c[3,16-31] + _mm512_mask_storeu_ps( b + ( rs_b * ( ir + 3 ) ) + + ( cs_b * ( jr + 16 ) ), k1, zmm21 ); + // c[3,32-47] + _mm512_mask_storeu_ps( b + ( rs_b * ( ir + 3 ) ) + + ( cs_b * ( jr + 32 ) ), k2, zmm22 ); + // c[3,48-63] + _mm512_mask_storeu_ps( b + ( rs_b * ( ir + 3 ) ) + + ( cs_b * ( jr + 48 ) ), k3, zmm23 ); + + // c[4,0-15] + _mm512_mask_storeu_ps( b + ( rs_b * ( ir + 4 ) ) + + ( cs_b * ( jr + 0 ) ), k0, zmm24 ); + // c[4,16-31] + _mm512_mask_storeu_ps( b + ( rs_b * ( ir + 4 ) ) + + ( cs_b * ( jr + 16 ) ), k1, zmm25 ); + // c[4,32-47] + _mm512_mask_storeu_ps( b + ( rs_b * ( ir + 4 ) ) + + ( cs_b * ( jr + 32 ) ), k2, zmm26 ); + // c[4,48-63] + _mm512_mask_storeu_ps( b + ( rs_b * ( ir + 4 ) ) + + ( cs_b * ( jr + 48 ) ), k3, zmm27 ); + + // c[5,0-15] + _mm512_mask_storeu_ps( b + ( rs_b * ( ir + 5 ) ) + + ( cs_b * ( jr + 0 ) ), k0, zmm28 ); + // c[5,16-31] + _mm512_mask_storeu_ps( b + ( rs_b * ( ir + 5 ) ) + + ( cs_b * ( jr + 16 ) ), k1, zmm29 ); + // c[5,32-47] + _mm512_mask_storeu_ps( b + ( rs_b * ( ir + 5 ) ) + + ( cs_b * ( jr + 32 ) ), k2, zmm30 ); + // c[5,48-63] + _mm512_mask_storeu_ps( b + ( rs_b * ( ir + 5 ) ) + + ( cs_b * ( jr + 48 ) ), k3, zmm31 ); + + post_ops_attr.post_op_c_j += NR_L; + } + post_ops_attr.post_op_c_j = orig_post_op_c_j; + post_ops_attr.post_op_c_i += MR; + } + + if ( m_partial_pieces > 0 ) + { + dim_t dsize = sizeof( float ); + int8_t* b_i = ( int8_t* )b; + if ( m_partial_pieces == 5 ) + { + lpgemm_eltwise_ops_kernel_f32of32_5x64 + ( + n0, + a + ( rs_a * m_full_pieces_loop_limit ), rs_a, cs_a, + ( float* )( b_i + ( dsize * rs_b * m_full_pieces_loop_limit ) ), + rs_b, cs_b, + post_ops_list, post_ops_attr + ); + } + else if ( m_partial_pieces == 4 ) + { + lpgemm_eltwise_ops_kernel_f32of32_4x64 + ( + n0, + a + ( rs_a * m_full_pieces_loop_limit ), rs_a, cs_a, + ( float* )( b_i + ( dsize * rs_b * m_full_pieces_loop_limit ) ), + rs_b, cs_b, + post_ops_list, post_ops_attr + ); + } + else if ( m_partial_pieces == 3 ) + { + lpgemm_eltwise_ops_kernel_f32of32_3x64 + ( + n0, + a + ( rs_a * m_full_pieces_loop_limit ), rs_a, cs_a, + ( float* )( b_i + ( dsize * rs_b * m_full_pieces_loop_limit ) ), + rs_b, cs_b, + post_ops_list, post_ops_attr + ); + } + else if ( m_partial_pieces == 2 ) + { + lpgemm_eltwise_ops_kernel_f32of32_2x64 + ( + n0, + a + ( rs_a * m_full_pieces_loop_limit ), rs_a, cs_a, + ( float* )( b_i + ( dsize * rs_b * m_full_pieces_loop_limit ) ), + rs_b, cs_b, + post_ops_list, post_ops_attr + ); + } + else if ( m_partial_pieces == 1 ) + { + lpgemm_eltwise_ops_kernel_f32of32_1x64 + ( + n0, + a + ( rs_a * m_full_pieces_loop_limit ), rs_a, cs_a, + ( float* )( b_i + ( dsize * rs_b * m_full_pieces_loop_limit ) ), + rs_b, cs_b, + post_ops_list, post_ops_attr + ); + } + } +} + +#endif diff --git a/kernels/zen4/lpgemm/f32f32f32/lpgemm_fringe_f32_avx512.c b/kernels/zen4/lpgemm/f32f32f32/lpgemm_fringe_f32_avx512.c index 70ac7f9b90..f9c3d7236d 100644 --- a/kernels/zen4/lpgemm/f32f32f32/lpgemm_fringe_f32_avx512.c +++ b/kernels/zen4/lpgemm/f32f32f32/lpgemm_fringe_f32_avx512.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -49,7 +49,11 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_5x64) &&POST_OPS_RELU_SCALE_5x64F, &&POST_OPS_GELU_TANH_5x64F, &&POST_OPS_GELU_ERF_5x64F, - &&POST_OPS_CLIP_5x64F + &&POST_OPS_CLIP_5x64F, + NULL, // Virtual node for downscale, else segfault + &&POST_OPS_MATRIX_ADD_5x64F, + &&POST_OPS_SWISH_5x64F, + &&POST_OPS_MATRIX_MUL_5x64F }; // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. @@ -134,57 +138,57 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_5x64) //add to accumulator and store back zmm3 = _mm512_set1_ps(beta); - zmm0 = _mm512_load_ps(_cbuf); - zmm1 = _mm512_load_ps(_cbuf + 16); + zmm0 = _mm512_loadu_ps(_cbuf); + zmm1 = _mm512_loadu_ps(_cbuf + 16); zmm8 = _mm512_fmadd_ps(zmm0, zmm3, zmm8); zmm9 = _mm512_fmadd_ps(zmm1, zmm3, zmm9); - zmm0 = _mm512_load_ps(_cbuf + 32); - zmm1 = _mm512_load_ps(_cbuf + 48); + zmm0 = _mm512_loadu_ps(_cbuf + 32); + zmm1 = _mm512_loadu_ps(_cbuf + 48); zmm10 = _mm512_fmadd_ps(zmm0, zmm3, zmm10); zmm11 = _mm512_fmadd_ps(zmm1, zmm3, zmm11); _cbuf += rs_c; - zmm0 = _mm512_load_ps(_cbuf); - zmm1 = _mm512_load_ps(_cbuf + 16); + zmm0 = _mm512_loadu_ps(_cbuf); + zmm1 = _mm512_loadu_ps(_cbuf + 16); zmm12 = _mm512_fmadd_ps(zmm0, zmm3, zmm12); zmm13 = _mm512_fmadd_ps(zmm1, zmm3, zmm13); - zmm0 = _mm512_load_ps(_cbuf + 32); - zmm1 = _mm512_load_ps(_cbuf + 48); + zmm0 = _mm512_loadu_ps(_cbuf + 32); + zmm1 = _mm512_loadu_ps(_cbuf + 48); zmm14 = _mm512_fmadd_ps(zmm0, zmm3, zmm14); zmm15 = _mm512_fmadd_ps(zmm1, zmm3, zmm15); _cbuf += rs_c; - zmm0 = _mm512_load_ps(_cbuf); - zmm1 = _mm512_load_ps(_cbuf+16); + zmm0 = _mm512_loadu_ps(_cbuf); + zmm1 = _mm512_loadu_ps(_cbuf+16); zmm16 = _mm512_fmadd_ps(zmm0, zmm3, zmm16); zmm17 = _mm512_fmadd_ps(zmm1, zmm3, zmm17); - zmm0 = _mm512_load_ps(_cbuf + 32); - zmm1 = _mm512_load_ps(_cbuf + 48); + zmm0 = _mm512_loadu_ps(_cbuf + 32); + zmm1 = _mm512_loadu_ps(_cbuf + 48); zmm18 = _mm512_fmadd_ps(zmm0, zmm3, zmm18); zmm19 = _mm512_fmadd_ps(zmm1, zmm3, zmm19); _cbuf += rs_c; - zmm0 = _mm512_load_ps(_cbuf); - zmm1 = _mm512_load_ps(_cbuf+16); + zmm0 = _mm512_loadu_ps(_cbuf); + zmm1 = _mm512_loadu_ps(_cbuf+16); zmm20 = _mm512_fmadd_ps(zmm0, zmm3, zmm20); zmm21 = _mm512_fmadd_ps(zmm1, zmm3, zmm21); - zmm0 = _mm512_load_ps(_cbuf + 32); - zmm1 = _mm512_load_ps(_cbuf + 48); + zmm0 = _mm512_loadu_ps(_cbuf + 32); + zmm1 = _mm512_loadu_ps(_cbuf + 48); zmm22 = _mm512_fmadd_ps(zmm0, zmm3, zmm22); zmm23 = _mm512_fmadd_ps(zmm1, zmm3, zmm23); _cbuf += rs_c; - zmm0 = _mm512_load_ps(_cbuf); - zmm1 = _mm512_load_ps(_cbuf+16); + zmm0 = _mm512_loadu_ps(_cbuf); + zmm1 = _mm512_loadu_ps(_cbuf+16); zmm24 = _mm512_fmadd_ps(zmm0, zmm3, zmm24); zmm25 = _mm512_fmadd_ps(zmm1, zmm3, zmm25); - zmm0 = _mm512_load_ps(_cbuf + 32); - zmm1 = _mm512_load_ps(_cbuf + 48); + zmm0 = _mm512_loadu_ps(_cbuf + 32); + zmm1 = _mm512_loadu_ps(_cbuf + 48); zmm26 = _mm512_fmadd_ps(zmm0, zmm3, zmm26); zmm27 = _mm512_fmadd_ps(zmm1, zmm3, zmm27); } @@ -689,6 +693,118 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_5x64) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_MATRIX_ADD_5x64F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(zmm1,zmm2,zmm3,zmm4,0,8,9,10,11); + + // c[1:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(zmm1,zmm2,zmm3,zmm4,1,12,13,14,15); + + // c[2:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(zmm1,zmm2,zmm3,zmm4,2,16,17,18,19); + + // c[3:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(zmm1,zmm2,zmm3,zmm4,3,20,21,22,23); + + // c[4:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(zmm1,zmm2,zmm3,zmm4,4,24,25,26,27); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_5x64F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(zmm1,zmm2,zmm3,zmm4,0,8,9,10,11); + + // c[1:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(zmm1,zmm2,zmm3,zmm4,1,12,13,14,15); + + // c[2:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(zmm1,zmm2,zmm3,zmm4,2,16,17,18,19); + + // c[3:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(zmm1,zmm2,zmm3,zmm4,3,20,21,22,23); + + // c[4:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(zmm1,zmm2,zmm3,zmm4,4,24,25,26,27); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_5x64F: + { + zmm7 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(zmm8, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[0, 16-31] + SWISH_F32_AVX512_DEF(zmm9, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[0, 32-47] + SWISH_F32_AVX512_DEF(zmm10, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[0, 48-63] + SWISH_F32_AVX512_DEF(zmm11, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[1, 0-15] + SWISH_F32_AVX512_DEF(zmm12, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[1, 16-31] + SWISH_F32_AVX512_DEF(zmm13, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[1, 32-47] + SWISH_F32_AVX512_DEF(zmm14, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[1, 48-63] + SWISH_F32_AVX512_DEF(zmm15, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[2, 0-15] + SWISH_F32_AVX512_DEF(zmm16, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[2, 16-31] + SWISH_F32_AVX512_DEF(zmm17, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[2, 32-47] + SWISH_F32_AVX512_DEF(zmm18, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[2, 48-63] + SWISH_F32_AVX512_DEF(zmm19, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[3, 0-15] + SWISH_F32_AVX512_DEF(zmm20, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[3, 16-31] + SWISH_F32_AVX512_DEF(zmm21, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[3, 32-47] + SWISH_F32_AVX512_DEF(zmm22, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[3, 48-63] + SWISH_F32_AVX512_DEF(zmm23, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[4, 0-15] + SWISH_F32_AVX512_DEF(zmm24, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[4, 16-31] + SWISH_F32_AVX512_DEF(zmm25, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[4, 32-47] + SWISH_F32_AVX512_DEF(zmm26, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[4, 48-63] + SWISH_F32_AVX512_DEF(zmm27, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_5x64F_DISABLE: ; @@ -728,7 +844,11 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_4x64) &&POST_OPS_RELU_SCALE_4x64F, &&POST_OPS_GELU_TANH_4x64F, &&POST_OPS_GELU_ERF_4x64F, - &&POST_OPS_CLIP_4x64F + &&POST_OPS_CLIP_4x64F, + NULL, // Virtual node for downscale, else segfault + &&POST_OPS_MATRIX_ADD_4x64F, + &&POST_OPS_SWISH_4x64F, + &&POST_OPS_MATRIX_MUL_4x64F }; // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. @@ -754,11 +874,14 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_4x64) /*Load 32 elements from row0 of B*/ zmm0 = _mm512_loadu_ps (bbuf ); //load 0-15 values from current row zmm1 = _mm512_loadu_ps (bbuf + 16); //load 16-31 values from current row - + zmm0 = _mm512_shuffle_ps(zmm0, zmm0, 0xE4); // dummy shuffle + zmm1 = _mm512_shuffle_ps(zmm1, zmm1, 0xE4); // dummy shuffle /*Load Next 32 elements from row0 of B*/ zmm6 = _mm512_loadu_ps (bbuf + 32); //load 32-47 from current row zmm7 = _mm512_loadu_ps (bbuf + 48); //load 48-63 from current row - + zmm6 = _mm512_shuffle_ps(zmm6, zmm6, 0xE4); // dummy shuffle + zmm7 = _mm512_shuffle_ps(zmm7, zmm7, 0xE4); // dummy shuffle + /*Broadcast col0 elements of 12 rows of A*/ zmm2 = _mm512_set1_ps(*(abuf + 0*rs_a)); //broadcast c0r0 zmm3 = _mm512_set1_ps(*(abuf + 1*rs_a)); //broadcast c0r1 @@ -803,46 +926,46 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_4x64) //add to accumulator and store back zmm3 = _mm512_set1_ps(beta); - zmm0 = _mm512_load_ps(_cbuf); - zmm1 = _mm512_load_ps(_cbuf + 16); + zmm0 = _mm512_loadu_ps(_cbuf); + zmm1 = _mm512_loadu_ps(_cbuf + 16); zmm8 = _mm512_fmadd_ps(zmm0, zmm3, zmm8); zmm9 = _mm512_fmadd_ps(zmm1, zmm3, zmm9); - zmm0 = _mm512_load_ps(_cbuf + 32); - zmm1 = _mm512_load_ps(_cbuf + 48); + zmm0 = _mm512_loadu_ps(_cbuf + 32); + zmm1 = _mm512_loadu_ps(_cbuf + 48); zmm10 = _mm512_fmadd_ps(zmm0, zmm3, zmm10); zmm11 = _mm512_fmadd_ps(zmm1, zmm3, zmm11); _cbuf += rs_c; - zmm0 = _mm512_load_ps(_cbuf); - zmm1 = _mm512_load_ps(_cbuf + 16); + zmm0 = _mm512_loadu_ps(_cbuf); + zmm1 = _mm512_loadu_ps(_cbuf + 16); zmm12 = _mm512_fmadd_ps(zmm0, zmm3, zmm12); zmm13 = _mm512_fmadd_ps(zmm1, zmm3, zmm13); - zmm0 = _mm512_load_ps(_cbuf + 32); - zmm1 = _mm512_load_ps(_cbuf + 48); + zmm0 = _mm512_loadu_ps(_cbuf + 32); + zmm1 = _mm512_loadu_ps(_cbuf + 48); zmm14 = _mm512_fmadd_ps(zmm0, zmm3, zmm14); zmm15 = _mm512_fmadd_ps(zmm1, zmm3, zmm15); _cbuf += rs_c; - zmm0 = _mm512_load_ps(_cbuf); - zmm1 = _mm512_load_ps(_cbuf+16); + zmm0 = _mm512_loadu_ps(_cbuf); + zmm1 = _mm512_loadu_ps(_cbuf+16); zmm16 = _mm512_fmadd_ps(zmm0, zmm3, zmm16); zmm17 = _mm512_fmadd_ps(zmm1, zmm3, zmm17); - zmm0 = _mm512_load_ps(_cbuf + 32); - zmm1 = _mm512_load_ps(_cbuf + 48); + zmm0 = _mm512_loadu_ps(_cbuf + 32); + zmm1 = _mm512_loadu_ps(_cbuf + 48); zmm18 = _mm512_fmadd_ps(zmm0, zmm3, zmm18); zmm19 = _mm512_fmadd_ps(zmm1, zmm3, zmm19); _cbuf += rs_c; - zmm0 = _mm512_load_ps(_cbuf); - zmm1 = _mm512_load_ps(_cbuf+16); + zmm0 = _mm512_loadu_ps(_cbuf); + zmm1 = _mm512_loadu_ps(_cbuf+16); zmm20 = _mm512_fmadd_ps(zmm0, zmm3, zmm20); zmm21 = _mm512_fmadd_ps(zmm1, zmm3, zmm21); - zmm0 = _mm512_load_ps(_cbuf + 32); - zmm1 = _mm512_load_ps(_cbuf + 48); + zmm0 = _mm512_loadu_ps(_cbuf + 32); + zmm1 = _mm512_loadu_ps(_cbuf + 48); zmm22 = _mm512_fmadd_ps(zmm0, zmm3, zmm22); zmm23 = _mm512_fmadd_ps(zmm1, zmm3, zmm23); } @@ -1260,6 +1383,100 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_4x64) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_MATRIX_ADD_4x64F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(zmm1,zmm2,zmm3,zmm4,0,8,9,10,11); + + // c[1:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(zmm1,zmm2,zmm3,zmm4,1,12,13,14,15); + + // c[2:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(zmm1,zmm2,zmm3,zmm4,2,16,17,18,19); + + // c[3:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(zmm1,zmm2,zmm3,zmm4,3,20,21,22,23); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_4x64F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(zmm1,zmm2,zmm3,zmm4,0,8,9,10,11); + + // c[1:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(zmm1,zmm2,zmm3,zmm4,1,12,13,14,15); + + // c[2:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(zmm1,zmm2,zmm3,zmm4,2,16,17,18,19); + + // c[3:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(zmm1,zmm2,zmm3,zmm4,3,20,21,22,23); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_4x64F: + { + zmm7 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(zmm8, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[0, 16-31] + SWISH_F32_AVX512_DEF(zmm9, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[0, 32-47] + SWISH_F32_AVX512_DEF(zmm10, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[0, 48-63] + SWISH_F32_AVX512_DEF(zmm11, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[1, 0-15] + SWISH_F32_AVX512_DEF(zmm12, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[1, 16-31] + SWISH_F32_AVX512_DEF(zmm13, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[1, 32-47] + SWISH_F32_AVX512_DEF(zmm14, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[1, 48-63] + SWISH_F32_AVX512_DEF(zmm15, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[2, 0-15] + SWISH_F32_AVX512_DEF(zmm16, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[2, 16-31] + SWISH_F32_AVX512_DEF(zmm17, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[2, 32-47] + SWISH_F32_AVX512_DEF(zmm18, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[2, 48-63] + SWISH_F32_AVX512_DEF(zmm19, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[3, 0-15] + SWISH_F32_AVX512_DEF(zmm20, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[3, 16-31] + SWISH_F32_AVX512_DEF(zmm21, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[3, 32-47] + SWISH_F32_AVX512_DEF(zmm22, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[3, 48-63] + SWISH_F32_AVX512_DEF(zmm23, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_4x64F_DISABLE: ; @@ -1294,7 +1511,11 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_3x64) &&POST_OPS_RELU_SCALE_3x64F, &&POST_OPS_GELU_TANH_3x64F, &&POST_OPS_GELU_ERF_3x64F, - &&POST_OPS_CLIP_3x64F + &&POST_OPS_CLIP_3x64F, + NULL, // Virtual node for downscale, else segfault + &&POST_OPS_MATRIX_ADD_3x64F, + &&POST_OPS_SWISH_3x64F, + &&POST_OPS_MATRIX_MUL_3x64F }; // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. @@ -1319,11 +1540,15 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_3x64) /*Load 32 elements from row0 of B*/ zmm0 = _mm512_loadu_ps (bbuf ); //load 0-15 values from current row zmm1 = _mm512_loadu_ps (bbuf + 16); //load 16-31 values from current row + zmm0 = _mm512_shuffle_ps(zmm0, zmm0, 0xE4); // dummy shuffle + zmm1 = _mm512_shuffle_ps(zmm1, zmm1, 0xE4); // dummy shuffle /*Load Next 32 elements from row0 of B*/ zmm6 = _mm512_loadu_ps (bbuf + 32); //load 32-47 from current row zmm7 = _mm512_loadu_ps (bbuf + 48); //load 48-63 from current row - + zmm6 = _mm512_shuffle_ps(zmm6, zmm6, 0xE4); // dummy shuffle + zmm7 = _mm512_shuffle_ps(zmm7, zmm7, 0xE4); // dummy shuffle + /*Broadcast col0 elements of 12 rows of A*/ zmm2 = _mm512_set1_ps(*(abuf + 0*rs_a)); //broadcast c0r0 zmm3 = _mm512_set1_ps(*(abuf + 1*rs_a)); //broadcast c0r1 @@ -1361,35 +1586,35 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_3x64) //add to accumulator and store back zmm3 = _mm512_set1_ps(beta); - zmm0 = _mm512_load_ps(_cbuf); - zmm1 = _mm512_load_ps(_cbuf + 16); + zmm0 = _mm512_loadu_ps(_cbuf); + zmm1 = _mm512_loadu_ps(_cbuf + 16); zmm8 = _mm512_fmadd_ps(zmm0, zmm3, zmm8); zmm9 = _mm512_fmadd_ps(zmm1, zmm3, zmm9); - zmm0 = _mm512_load_ps(_cbuf + 32); - zmm1 = _mm512_load_ps(_cbuf + 48); + zmm0 = _mm512_loadu_ps(_cbuf + 32); + zmm1 = _mm512_loadu_ps(_cbuf + 48); zmm10 = _mm512_fmadd_ps(zmm0, zmm3, zmm10); zmm11 = _mm512_fmadd_ps(zmm1, zmm3, zmm11); _cbuf += rs_c; - zmm0 = _mm512_load_ps(_cbuf); - zmm1 = _mm512_load_ps(_cbuf + 16); + zmm0 = _mm512_loadu_ps(_cbuf); + zmm1 = _mm512_loadu_ps(_cbuf + 16); zmm12 = _mm512_fmadd_ps(zmm0, zmm3, zmm12); zmm13 = _mm512_fmadd_ps(zmm1, zmm3, zmm13); - zmm0 = _mm512_load_ps(_cbuf + 32); - zmm1 = _mm512_load_ps(_cbuf + 48); + zmm0 = _mm512_loadu_ps(_cbuf + 32); + zmm1 = _mm512_loadu_ps(_cbuf + 48); zmm14 = _mm512_fmadd_ps(zmm0, zmm3, zmm14); zmm15 = _mm512_fmadd_ps(zmm1, zmm3, zmm15); _cbuf += rs_c; - zmm0 = _mm512_load_ps(_cbuf); - zmm1 = _mm512_load_ps(_cbuf+16); + zmm0 = _mm512_loadu_ps(_cbuf); + zmm1 = _mm512_loadu_ps(_cbuf+16); zmm16 = _mm512_fmadd_ps(zmm0, zmm3, zmm16); zmm17 = _mm512_fmadd_ps(zmm1, zmm3, zmm17); - zmm0 = _mm512_load_ps(_cbuf + 32); - zmm1 = _mm512_load_ps(_cbuf + 48); + zmm0 = _mm512_loadu_ps(_cbuf + 32); + zmm1 = _mm512_loadu_ps(_cbuf + 48); zmm18 = _mm512_fmadd_ps(zmm0, zmm3, zmm18); zmm19 = _mm512_fmadd_ps(zmm1, zmm3, zmm19); } @@ -1720,6 +1945,82 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_3x64) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_MATRIX_ADD_3x64F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(zmm1,zmm2,zmm3,zmm4,0,8,9,10,11); + + // c[1:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(zmm1,zmm2,zmm3,zmm4,1,12,13,14,15); + + // c[2:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(zmm1,zmm2,zmm3,zmm4,2,16,17,18,19); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_3x64F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(zmm1,zmm2,zmm3,zmm4,0,8,9,10,11); + + // c[1:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(zmm1,zmm2,zmm3,zmm4,1,12,13,14,15); + + // c[2:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(zmm1,zmm2,zmm3,zmm4,2,16,17,18,19); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_3x64F: + { + zmm7 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(zmm8, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[0, 16-31] + SWISH_F32_AVX512_DEF(zmm9, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[0, 32-47] + SWISH_F32_AVX512_DEF(zmm10, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[0, 48-63] + SWISH_F32_AVX512_DEF(zmm11, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[1, 0-15] + SWISH_F32_AVX512_DEF(zmm12, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[1, 16-31] + SWISH_F32_AVX512_DEF(zmm13, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[1, 32-47] + SWISH_F32_AVX512_DEF(zmm14, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[1, 48-63] + SWISH_F32_AVX512_DEF(zmm15, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[2, 0-15] + SWISH_F32_AVX512_DEF(zmm16, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[2, 16-31] + SWISH_F32_AVX512_DEF(zmm17, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[2, 32-47] + SWISH_F32_AVX512_DEF(zmm18, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[2, 48-63] + SWISH_F32_AVX512_DEF(zmm19, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_3x64F_DISABLE: ; @@ -1749,7 +2050,11 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_2x64) &&POST_OPS_RELU_SCALE_2x64F, &&POST_OPS_GELU_TANH_2x64F, &&POST_OPS_GELU_ERF_2x64F, - &&POST_OPS_CLIP_2x64F + &&POST_OPS_CLIP_2x64F, + NULL, // Virtual node for downscale, else segfault + &&POST_OPS_MATRIX_ADD_2x64F, + &&POST_OPS_SWISH_2x64F, + &&POST_OPS_MATRIX_MUL_2x64F }; // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. @@ -1772,11 +2077,15 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_2x64) /*Load 32 elements from row0 of B*/ zmm0 = _mm512_loadu_ps (bbuf ); //load 0-15 values from current row zmm1 = _mm512_loadu_ps (bbuf + 16); //load 16-31 values from current row + zmm0 = _mm512_shuffle_ps(zmm0, zmm0, 0xE4); // dummy shuffle + zmm1 = _mm512_shuffle_ps(zmm1, zmm1, 0xE4); // dummy shuffle /*Load Next 32 elements from row0 of B*/ zmm6 = _mm512_loadu_ps (bbuf + 32); //load 32-47 from current row zmm7 = _mm512_loadu_ps (bbuf + 48); //load 48-63 from current row - + zmm6 = _mm512_shuffle_ps(zmm6, zmm6, 0xE4); // dummy shuffle + zmm7 = _mm512_shuffle_ps(zmm7, zmm7, 0xE4); // dummy shuffle + /*Broadcast col0 elements of 12 rows of A*/ zmm2 = _mm512_set1_ps(*(abuf + 0*rs_a)); //broadcast c0r0 zmm3 = _mm512_set1_ps(*(abuf + 1*rs_a)); //broadcast c0r1 @@ -1807,24 +2116,24 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_2x64) //add to accumulator and store back zmm3 = _mm512_set1_ps(beta); - zmm0 = _mm512_load_ps(_cbuf); - zmm1 = _mm512_load_ps(_cbuf + 16); + zmm0 = _mm512_loadu_ps(_cbuf); + zmm1 = _mm512_loadu_ps(_cbuf + 16); zmm8 = _mm512_fmadd_ps(zmm0, zmm3, zmm8); zmm9 = _mm512_fmadd_ps(zmm1, zmm3, zmm9); - zmm0 = _mm512_load_ps(_cbuf + 32); - zmm1 = _mm512_load_ps(_cbuf + 48); + zmm0 = _mm512_loadu_ps(_cbuf + 32); + zmm1 = _mm512_loadu_ps(_cbuf + 48); zmm10 = _mm512_fmadd_ps(zmm0, zmm3, zmm10); zmm11 = _mm512_fmadd_ps(zmm1, zmm3, zmm11); _cbuf += rs_c; - zmm0 = _mm512_load_ps(_cbuf); - zmm1 = _mm512_load_ps(_cbuf + 16); + zmm0 = _mm512_loadu_ps(_cbuf); + zmm1 = _mm512_loadu_ps(_cbuf + 16); zmm12 = _mm512_fmadd_ps(zmm0, zmm3, zmm12); zmm13 = _mm512_fmadd_ps(zmm1, zmm3, zmm13); - zmm0 = _mm512_load_ps(_cbuf + 32); - zmm1 = _mm512_load_ps(_cbuf + 48); + zmm0 = _mm512_loadu_ps(_cbuf + 32); + zmm1 = _mm512_loadu_ps(_cbuf + 48); zmm14 = _mm512_fmadd_ps(zmm0, zmm3, zmm14); zmm15 = _mm512_fmadd_ps(zmm1, zmm3, zmm15); } @@ -2068,6 +2377,64 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_2x64) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_MATRIX_ADD_2x64F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(zmm1,zmm2,zmm3,zmm4,0,8,9,10,11); + + // c[1:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(zmm1,zmm2,zmm3,zmm4,1,12,13,14,15); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_2x64F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(zmm1,zmm2,zmm3,zmm4,0,8,9,10,11); + + // c[1:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(zmm1,zmm2,zmm3,zmm4,1,12,13,14,15); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_2x64F: + { + zmm7 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(zmm8, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[0, 16-31] + SWISH_F32_AVX512_DEF(zmm9, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[0, 32-47] + SWISH_F32_AVX512_DEF(zmm10, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[0, 48-63] + SWISH_F32_AVX512_DEF(zmm11, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[1, 0-15] + SWISH_F32_AVX512_DEF(zmm12, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[1, 16-31] + SWISH_F32_AVX512_DEF(zmm13, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[1, 32-47] + SWISH_F32_AVX512_DEF(zmm14, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[1, 48-63] + SWISH_F32_AVX512_DEF(zmm15, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_2x64F_DISABLE: ; @@ -2092,7 +2459,11 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_1x64) &&POST_OPS_RELU_SCALE_1x64F, &&POST_OPS_GELU_TANH_1x64F, &&POST_OPS_GELU_ERF_1x64F, - &&POST_OPS_CLIP_1x64F + &&POST_OPS_CLIP_1x64F, + NULL, // Virtual node for downscale, else segfault + &&POST_OPS_MATRIX_ADD_1x64F, + &&POST_OPS_SWISH_1x64F, + &&POST_OPS_MATRIX_MUL_1x64F }; // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. @@ -2141,13 +2512,13 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_1x64) //add to accumulator and store back zmm3 = _mm512_set1_ps(beta); - zmm0 = _mm512_load_ps(cbuf); - zmm1 = _mm512_load_ps(cbuf + 16); + zmm0 = _mm512_loadu_ps(cbuf); + zmm1 = _mm512_loadu_ps(cbuf + 16); zmm8 = _mm512_fmadd_ps(zmm0, zmm3, zmm8); zmm9 = _mm512_fmadd_ps(zmm1, zmm3, zmm9); - zmm0 = _mm512_load_ps(cbuf + 32); - zmm1 = _mm512_load_ps(cbuf + 48); + zmm0 = _mm512_loadu_ps(cbuf + 32); + zmm1 = _mm512_loadu_ps(cbuf + 48); zmm10 = _mm512_fmadd_ps(zmm0, zmm3, zmm10); zmm11 = _mm512_fmadd_ps(zmm1, zmm3, zmm11); } @@ -2304,6 +2675,46 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_1x64) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_MATRIX_ADD_1x64F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(zmm1,zmm2,zmm3,zmm4,0,8,9,10,11); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_1x64F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(zmm1,zmm2,zmm3,zmm4,0,8,9,10,11); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_1x64F: + { + zmm7 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(zmm8, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[0, 16-31] + SWISH_F32_AVX512_DEF(zmm9, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[0, 32-47] + SWISH_F32_AVX512_DEF(zmm10, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[0, 48-63] + SWISH_F32_AVX512_DEF(zmm11, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_1x64F_DISABLE: ; @@ -2323,7 +2734,11 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_5x48) &&POST_OPS_RELU_SCALE_5x48F, &&POST_OPS_GELU_TANH_5x48F, &&POST_OPS_GELU_ERF_5x48F, - &&POST_OPS_CLIP_5x48F + &&POST_OPS_CLIP_5x48F, + NULL, // Virtual node for downscale, else segfault + &&POST_OPS_MATRIX_ADD_5x48F, + &&POST_OPS_SWISH_5x48F, + &&POST_OPS_MATRIX_MUL_5x48F }; // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. @@ -2400,48 +2815,48 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_5x48) //add to accumulator and store back zmm3 = _mm512_set1_ps(beta); - zmm0 = _mm512_load_ps(_cbuf); - zmm1 = _mm512_load_ps(_cbuf + 16); + zmm0 = _mm512_loadu_ps(_cbuf); + zmm1 = _mm512_loadu_ps(_cbuf + 16); zmm8 = _mm512_fmadd_ps(zmm0, zmm3, zmm8); zmm9 = _mm512_fmadd_ps(zmm1, zmm3, zmm9); - zmm0 = _mm512_load_ps(_cbuf + 32); + zmm0 = _mm512_loadu_ps(_cbuf + 32); zmm10 = _mm512_fmadd_ps(zmm0, zmm3, zmm10); _cbuf += rs_c; - zmm0 = _mm512_load_ps(_cbuf); - zmm1 = _mm512_load_ps(_cbuf + 16); + zmm0 = _mm512_loadu_ps(_cbuf); + zmm1 = _mm512_loadu_ps(_cbuf + 16); zmm12 = _mm512_fmadd_ps(zmm0, zmm3, zmm12); zmm13 = _mm512_fmadd_ps(zmm1, zmm3, zmm13); - zmm0 = _mm512_load_ps(_cbuf + 32); + zmm0 = _mm512_loadu_ps(_cbuf + 32); zmm14 = _mm512_fmadd_ps(zmm0, zmm3, zmm14); _cbuf += rs_c; - zmm0 = _mm512_load_ps(_cbuf); - zmm1 = _mm512_load_ps(_cbuf+16); + zmm0 = _mm512_loadu_ps(_cbuf); + zmm1 = _mm512_loadu_ps(_cbuf+16); zmm16 = _mm512_fmadd_ps(zmm0, zmm3, zmm16); zmm17 = _mm512_fmadd_ps(zmm1, zmm3, zmm17); - zmm0 = _mm512_load_ps(_cbuf + 32); + zmm0 = _mm512_loadu_ps(_cbuf + 32); zmm18 = _mm512_fmadd_ps(zmm0, zmm3, zmm18); _cbuf += rs_c; - zmm0 = _mm512_load_ps(_cbuf); - zmm1 = _mm512_load_ps(_cbuf+16); + zmm0 = _mm512_loadu_ps(_cbuf); + zmm1 = _mm512_loadu_ps(_cbuf+16); zmm20 = _mm512_fmadd_ps(zmm0, zmm3, zmm20); zmm21 = _mm512_fmadd_ps(zmm1, zmm3, zmm21); - zmm0 = _mm512_load_ps(_cbuf + 32); + zmm0 = _mm512_loadu_ps(_cbuf + 32); zmm22 = _mm512_fmadd_ps(zmm0, zmm3, zmm22); _cbuf += rs_c; - zmm0 = _mm512_load_ps(_cbuf); - zmm1 = _mm512_load_ps(_cbuf+16); + zmm0 = _mm512_loadu_ps(_cbuf); + zmm1 = _mm512_loadu_ps(_cbuf+16); zmm24 = _mm512_fmadd_ps(zmm0, zmm3, zmm24); zmm25 = _mm512_fmadd_ps(zmm1, zmm3, zmm25); - zmm0 = _mm512_load_ps(_cbuf + 32); + zmm0 = _mm512_loadu_ps(_cbuf + 32); zmm26 = _mm512_fmadd_ps(zmm0, zmm3, zmm26); } @@ -2837,6 +3252,103 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_5x48) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_MATRIX_ADD_5x48F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(zmm1,zmm2,zmm3,0,8,9,10); + + // c[1:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(zmm1,zmm2,zmm3,1,12,13,14); + + // c[2:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(zmm1,zmm2,zmm3,2,16,17,18); + + // c[3:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(zmm1,zmm2,zmm3,3,20,21,22); + + // c[4:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(zmm1,zmm2,zmm3,4,24,25,26); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_5x48F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(zmm1,zmm2,zmm3,0,8,9,10); + + // c[1:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(zmm1,zmm2,zmm3,1,12,13,14); + + // c[2:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(zmm1,zmm2,zmm3,2,16,17,18); + + // c[3:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(zmm1,zmm2,zmm3,3,20,21,22); + + // c[4:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(zmm1,zmm2,zmm3,4,24,25,26); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_5x48F: + { + __m512 zmm7 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(zmm8, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[0, 16-31] + SWISH_F32_AVX512_DEF(zmm9, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[0, 32-47] + SWISH_F32_AVX512_DEF(zmm10, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[1, 0-15] + SWISH_F32_AVX512_DEF(zmm12, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[1, 16-31] + SWISH_F32_AVX512_DEF(zmm13, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[1, 32-47] + SWISH_F32_AVX512_DEF(zmm14, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[2, 0-15] + SWISH_F32_AVX512_DEF(zmm16, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[2, 16-31] + SWISH_F32_AVX512_DEF(zmm17, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[2, 32-47] + SWISH_F32_AVX512_DEF(zmm18, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[3, 0-15] + SWISH_F32_AVX512_DEF(zmm20, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[3, 16-31] + SWISH_F32_AVX512_DEF(zmm21, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[3, 32-47] + SWISH_F32_AVX512_DEF(zmm22, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[4, 0-15] + SWISH_F32_AVX512_DEF(zmm24, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[4, 16-31] + SWISH_F32_AVX512_DEF(zmm25, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[4, 32-47] + SWISH_F32_AVX512_DEF(zmm26, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_5x48F_DISABLE: ; @@ -2871,7 +3383,11 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_4x48) &&POST_OPS_RELU_SCALE_4x48F, &&POST_OPS_GELU_TANH_4x48F, &&POST_OPS_GELU_ERF_4x48F, - &&POST_OPS_CLIP_4x48F + &&POST_OPS_CLIP_4x48F, + NULL, // Virtual node for downscale, else segfault + &&POST_OPS_MATRIX_ADD_4x48F, + &&POST_OPS_SWISH_4x48F, + &&POST_OPS_MATRIX_MUL_4x48F }; // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. @@ -2940,39 +3456,39 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_4x48) //add to accumulator and store back zmm3 = _mm512_set1_ps(beta); - zmm0 = _mm512_load_ps(_cbuf); - zmm1 = _mm512_load_ps(_cbuf + 16); + zmm0 = _mm512_loadu_ps(_cbuf); + zmm1 = _mm512_loadu_ps(_cbuf + 16); zmm8 = _mm512_fmadd_ps(zmm0, zmm3, zmm8); zmm9 = _mm512_fmadd_ps(zmm1, zmm3, zmm9); - zmm0 = _mm512_load_ps(_cbuf + 32); + zmm0 = _mm512_loadu_ps(_cbuf + 32); zmm10 = _mm512_fmadd_ps(zmm0, zmm3, zmm10); _cbuf += rs_c; - zmm0 = _mm512_load_ps(_cbuf); - zmm1 = _mm512_load_ps(_cbuf + 16); + zmm0 = _mm512_loadu_ps(_cbuf); + zmm1 = _mm512_loadu_ps(_cbuf + 16); zmm12 = _mm512_fmadd_ps(zmm0, zmm3, zmm12); zmm13 = _mm512_fmadd_ps(zmm1, zmm3, zmm13); - zmm0 = _mm512_load_ps(_cbuf + 32); + zmm0 = _mm512_loadu_ps(_cbuf + 32); zmm14 = _mm512_fmadd_ps(zmm0, zmm3, zmm14); _cbuf += rs_c; - zmm0 = _mm512_load_ps(_cbuf); - zmm1 = _mm512_load_ps(_cbuf+16); + zmm0 = _mm512_loadu_ps(_cbuf); + zmm1 = _mm512_loadu_ps(_cbuf+16); zmm16 = _mm512_fmadd_ps(zmm0, zmm3, zmm16); zmm17 = _mm512_fmadd_ps(zmm1, zmm3, zmm17); - zmm0 = _mm512_load_ps(_cbuf + 32); + zmm0 = _mm512_loadu_ps(_cbuf + 32); zmm18 = _mm512_fmadd_ps(zmm0, zmm3, zmm18); _cbuf += rs_c; - zmm0 = _mm512_load_ps(_cbuf); - zmm1 = _mm512_load_ps(_cbuf+16); + zmm0 = _mm512_loadu_ps(_cbuf); + zmm1 = _mm512_loadu_ps(_cbuf+16); zmm20 = _mm512_fmadd_ps(zmm0, zmm3, zmm20); zmm21 = _mm512_fmadd_ps(zmm1, zmm3, zmm21); - zmm0 = _mm512_load_ps(_cbuf + 32); + zmm0 = _mm512_loadu_ps(_cbuf + 32); zmm22 = _mm512_fmadd_ps(zmm0, zmm3, zmm22); } @@ -3302,6 +3818,88 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_4x48) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_MATRIX_ADD_4x48F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(zmm1,zmm2,zmm3,0,8,9,10); + + // c[1:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(zmm1,zmm2,zmm3,1,12,13,14); + + // c[2:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(zmm1,zmm2,zmm3,2,16,17,18); + + // c[3:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(zmm1,zmm2,zmm3,3,20,21,22); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_4x48F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(zmm1,zmm2,zmm3,0,8,9,10); + + // c[1:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(zmm1,zmm2,zmm3,1,12,13,14); + + // c[2:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(zmm1,zmm2,zmm3,2,16,17,18); + + // c[3:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(zmm1,zmm2,zmm3,3,20,21,22); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_4x48F: + { + __m512 zmm7 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(zmm8, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[0, 16-31] + SWISH_F32_AVX512_DEF(zmm9, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[0, 32-47] + SWISH_F32_AVX512_DEF(zmm10, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[1, 0-15] + SWISH_F32_AVX512_DEF(zmm12, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[1, 16-31] + SWISH_F32_AVX512_DEF(zmm13, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[1, 32-47] + SWISH_F32_AVX512_DEF(zmm14, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[2, 0-15] + SWISH_F32_AVX512_DEF(zmm16, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[2, 16-31] + SWISH_F32_AVX512_DEF(zmm17, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[2, 32-47] + SWISH_F32_AVX512_DEF(zmm18, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[3, 0-15] + SWISH_F32_AVX512_DEF(zmm20, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[3, 16-31] + SWISH_F32_AVX512_DEF(zmm21, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[3, 32-47] + SWISH_F32_AVX512_DEF(zmm22, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_4x48F_DISABLE: ; @@ -3332,7 +3930,11 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_3x48) &&POST_OPS_RELU_SCALE_3x48F, &&POST_OPS_GELU_TANH_3x48F, &&POST_OPS_GELU_ERF_3x48F, - &&POST_OPS_CLIP_3x48F + &&POST_OPS_CLIP_3x48F, + NULL, // Virtual node for downscale, else segfault + &&POST_OPS_MATRIX_ADD_3x48F, + &&POST_OPS_SWISH_3x48F, + &&POST_OPS_MATRIX_MUL_3x48F }; // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. @@ -3395,30 +3997,30 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_3x48) //add to accumulator and store back zmm3 = _mm512_set1_ps(beta); - zmm0 = _mm512_load_ps(_cbuf); - zmm1 = _mm512_load_ps(_cbuf + 16); + zmm0 = _mm512_loadu_ps(_cbuf); + zmm1 = _mm512_loadu_ps(_cbuf + 16); zmm8 = _mm512_fmadd_ps(zmm0, zmm3, zmm8); zmm9 = _mm512_fmadd_ps(zmm1, zmm3, zmm9); - zmm0 = _mm512_load_ps(_cbuf + 32); + zmm0 = _mm512_loadu_ps(_cbuf + 32); zmm10 = _mm512_fmadd_ps(zmm0, zmm3, zmm10); _cbuf += rs_c; - zmm0 = _mm512_load_ps(_cbuf); - zmm1 = _mm512_load_ps(_cbuf + 16); + zmm0 = _mm512_loadu_ps(_cbuf); + zmm1 = _mm512_loadu_ps(_cbuf + 16); zmm12 = _mm512_fmadd_ps(zmm0, zmm3, zmm12); zmm13 = _mm512_fmadd_ps(zmm1, zmm3, zmm13); - zmm0 = _mm512_load_ps(_cbuf + 32); + zmm0 = _mm512_loadu_ps(_cbuf + 32); zmm14 = _mm512_fmadd_ps(zmm0, zmm3, zmm14); _cbuf += rs_c; - zmm0 = _mm512_load_ps(_cbuf); - zmm1 = _mm512_load_ps(_cbuf+16); + zmm0 = _mm512_loadu_ps(_cbuf); + zmm1 = _mm512_loadu_ps(_cbuf+16); zmm16 = _mm512_fmadd_ps(zmm0, zmm3, zmm16); zmm17 = _mm512_fmadd_ps(zmm1, zmm3, zmm17); - zmm0 = _mm512_load_ps(_cbuf + 32); + zmm0 = _mm512_loadu_ps(_cbuf + 32); zmm18 = _mm512_fmadd_ps(zmm0, zmm3, zmm18); } @@ -3682,6 +4284,73 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_3x48) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_MATRIX_ADD_3x48F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(zmm1,zmm2,zmm3,0,8,9,10); + + // c[1:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(zmm1,zmm2,zmm3,1,12,13,14); + + // c[2:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(zmm1,zmm2,zmm3,2,16,17,18); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_3x48F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(zmm1,zmm2,zmm3,0,8,9,10); + + // c[1:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(zmm1,zmm2,zmm3,1,12,13,14); + + // c[2:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(zmm1,zmm2,zmm3,2,16,17,18); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_3x48F: + { + __m512 zmm7 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(zmm8, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[0, 16-31] + SWISH_F32_AVX512_DEF(zmm9, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[0, 32-47] + SWISH_F32_AVX512_DEF(zmm10, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[1, 0-15] + SWISH_F32_AVX512_DEF(zmm12, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[1, 16-31] + SWISH_F32_AVX512_DEF(zmm13, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[1, 32-47] + SWISH_F32_AVX512_DEF(zmm14, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[2, 0-15] + SWISH_F32_AVX512_DEF(zmm16, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[2, 16-31] + SWISH_F32_AVX512_DEF(zmm17, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[2, 32-47] + SWISH_F32_AVX512_DEF(zmm18, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_3x48F_DISABLE: ; @@ -3708,7 +4377,11 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_2x48) &&POST_OPS_RELU_SCALE_2x48F, &&POST_OPS_GELU_TANH_2x48F, &&POST_OPS_GELU_ERF_2x48F, - &&POST_OPS_CLIP_2x48F + &&POST_OPS_CLIP_2x48F, + NULL, // Virtual node for downscale, else segfault + &&POST_OPS_MATRIX_ADD_2x48F, + &&POST_OPS_SWISH_2x48F, + &&POST_OPS_MATRIX_MUL_2x48F }; // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. @@ -3763,21 +4436,21 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_2x48) //add to accumulator and store back zmm3 = _mm512_set1_ps(beta); - zmm0 = _mm512_load_ps(_cbuf); - zmm1 = _mm512_load_ps(_cbuf + 16); + zmm0 = _mm512_loadu_ps(_cbuf); + zmm1 = _mm512_loadu_ps(_cbuf + 16); zmm8 = _mm512_fmadd_ps(zmm0, zmm3, zmm8); zmm9 = _mm512_fmadd_ps(zmm1, zmm3, zmm9); - zmm0 = _mm512_load_ps(_cbuf + 32); + zmm0 = _mm512_loadu_ps(_cbuf + 32); zmm10 = _mm512_fmadd_ps(zmm0, zmm3, zmm10); _cbuf += rs_c; - zmm0 = _mm512_load_ps(_cbuf); - zmm1 = _mm512_load_ps(_cbuf + 16); + zmm0 = _mm512_loadu_ps(_cbuf); + zmm1 = _mm512_loadu_ps(_cbuf + 16); zmm12 = _mm512_fmadd_ps(zmm0, zmm3, zmm12); zmm13 = _mm512_fmadd_ps(zmm1, zmm3, zmm13); - zmm0 = _mm512_load_ps(_cbuf + 32); + zmm0 = _mm512_loadu_ps(_cbuf + 32); zmm14 = _mm512_fmadd_ps(zmm0, zmm3, zmm14); } @@ -3975,6 +4648,58 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_2x48) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_MATRIX_ADD_2x48F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(zmm1,zmm2,zmm3,0,8,9,10); + + // c[1:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(zmm1,zmm2,zmm3,1,12,13,14); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_2x48F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(zmm1,zmm2,zmm3,0,8,9,10); + + // c[1:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(zmm1,zmm2,zmm3,1,12,13,14); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_2x48F: + { + __m512 zmm7 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(zmm8, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[0, 16-31] + SWISH_F32_AVX512_DEF(zmm9, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[0, 32-47] + SWISH_F32_AVX512_DEF(zmm10, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[1, 0-15] + SWISH_F32_AVX512_DEF(zmm12, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[1, 16-31] + SWISH_F32_AVX512_DEF(zmm13, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[1, 32-47] + SWISH_F32_AVX512_DEF(zmm14, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_2x48F_DISABLE: ; @@ -3997,7 +4722,11 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_1x48) &&POST_OPS_RELU_SCALE_1x48F, &&POST_OPS_GELU_TANH_1x48F, &&POST_OPS_GELU_ERF_1x48F, - &&POST_OPS_CLIP_1x48F + &&POST_OPS_CLIP_1x48F, + NULL, // Virtual node for downscale, else segfault + &&POST_OPS_MATRIX_ADD_1x48F, + &&POST_OPS_SWISH_1x48F, + &&POST_OPS_MATRIX_MUL_1x48F }; // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. @@ -4043,12 +4772,12 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_1x48) //add to accumulator and store back zmm3 = _mm512_set1_ps(beta); - zmm0 = _mm512_load_ps(cbuf); - zmm1 = _mm512_load_ps(cbuf + 16); + zmm0 = _mm512_loadu_ps(cbuf); + zmm1 = _mm512_loadu_ps(cbuf + 16); zmm8 = _mm512_fmadd_ps(zmm0, zmm3, zmm8); zmm9 = _mm512_fmadd_ps(zmm1, zmm3, zmm9); - zmm0 = _mm512_load_ps(cbuf + 32); + zmm0 = _mm512_loadu_ps(cbuf + 32); zmm10 = _mm512_fmadd_ps(zmm0, zmm3, zmm10); } @@ -4180,6 +4909,43 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_1x48) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_MATRIX_ADD_1x48F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(zmm1,zmm2,zmm3,0,8,9,10); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_1x48F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(zmm1,zmm2,zmm3,0,8,9,10); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_1x48F: + { + __m512 zmm7 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(zmm8, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[0, 16-31] + SWISH_F32_AVX512_DEF(zmm9, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[0, 32-47] + SWISH_F32_AVX512_DEF(zmm10, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_1x48F_DISABLE: ; @@ -4198,7 +4964,11 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_5x32) &&POST_OPS_RELU_SCALE_5x32F, &&POST_OPS_GELU_TANH_5x32F, &&POST_OPS_GELU_ERF_5x32F, - &&POST_OPS_CLIP_5x32F + &&POST_OPS_CLIP_5x32F, + NULL, // Virtual node for downscale, else segfault + &&POST_OPS_MATRIX_ADD_5x32F, + &&POST_OPS_SWISH_5x32F, + &&POST_OPS_MATRIX_MUL_5x32F }; // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. @@ -4266,32 +5036,32 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_5x32) //add to accumulator and store back zmm3 = _mm512_set1_ps(beta); - zmm0 = _mm512_load_ps(_cbuf); - zmm1 = _mm512_load_ps(_cbuf + 16); + zmm0 = _mm512_loadu_ps(_cbuf); + zmm1 = _mm512_loadu_ps(_cbuf + 16); zmm8 = _mm512_fmadd_ps(zmm0, zmm3, zmm8); zmm9 = _mm512_fmadd_ps(zmm1, zmm3, zmm9); _cbuf += rs_c; - zmm0 = _mm512_load_ps(_cbuf); - zmm1 = _mm512_load_ps(_cbuf + 16); + zmm0 = _mm512_loadu_ps(_cbuf); + zmm1 = _mm512_loadu_ps(_cbuf + 16); zmm12 = _mm512_fmadd_ps(zmm0, zmm3, zmm12); zmm13 = _mm512_fmadd_ps(zmm1, zmm3, zmm13); _cbuf += rs_c; - zmm0 = _mm512_load_ps(_cbuf); - zmm1 = _mm512_load_ps(_cbuf+16); + zmm0 = _mm512_loadu_ps(_cbuf); + zmm1 = _mm512_loadu_ps(_cbuf+16); zmm16 = _mm512_fmadd_ps(zmm0, zmm3, zmm16); zmm17 = _mm512_fmadd_ps(zmm1, zmm3, zmm17); _cbuf += rs_c; - zmm0 = _mm512_load_ps(_cbuf); - zmm1 = _mm512_load_ps(_cbuf+16); + zmm0 = _mm512_loadu_ps(_cbuf); + zmm1 = _mm512_loadu_ps(_cbuf+16); zmm20 = _mm512_fmadd_ps(zmm0, zmm3, zmm20); zmm21 = _mm512_fmadd_ps(zmm1, zmm3, zmm21); _cbuf += rs_c; - zmm0 = _mm512_load_ps(_cbuf); - zmm1 = _mm512_load_ps(_cbuf+16); + zmm0 = _mm512_loadu_ps(_cbuf); + zmm1 = _mm512_loadu_ps(_cbuf+16); zmm24 = _mm512_fmadd_ps(zmm0, zmm3, zmm24); zmm25 = _mm512_fmadd_ps(zmm1, zmm3, zmm25); } @@ -4580,6 +5350,88 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_5x32) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_MATRIX_ADD_5x32F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(zmm1,zmm2,0,8,9); + + // c[1:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(zmm1,zmm2,1,12,13); + + // c[2:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(zmm1,zmm2,2,16,17); + + // c[3:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(zmm1,zmm2,3,20,21); + + // c[4:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(zmm1,zmm2,4,24,25); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_5x32F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(zmm1,zmm2,0,8,9); + + // c[1:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(zmm1,zmm2,1,12,13); + + // c[2:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(zmm1,zmm2,2,16,17); + + // c[3:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(zmm1,zmm2,3,20,21); + + // c[4:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(zmm1,zmm2,4,24,25); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_5x32F: + { + __m512 zmm7 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(zmm8, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[0, 16-31] + SWISH_F32_AVX512_DEF(zmm9, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[1, 0-15] + SWISH_F32_AVX512_DEF(zmm12, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[1, 16-31] + SWISH_F32_AVX512_DEF(zmm13, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[2, 0-15] + SWISH_F32_AVX512_DEF(zmm16, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[2, 16-31] + SWISH_F32_AVX512_DEF(zmm17, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[3, 0-15] + SWISH_F32_AVX512_DEF(zmm20, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[3, 16-31] + SWISH_F32_AVX512_DEF(zmm21, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[4, 0-15] + SWISH_F32_AVX512_DEF(zmm24, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[4, 16-31] + SWISH_F32_AVX512_DEF(zmm25, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_5x32F_DISABLE: ; @@ -4609,7 +5461,11 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_4x32) &&POST_OPS_RELU_SCALE_4x32F, &&POST_OPS_GELU_TANH_4x32F, &&POST_OPS_GELU_ERF_4x32F, - &&POST_OPS_CLIP_4x32F + &&POST_OPS_CLIP_4x32F, + NULL, // Virtual node for downscale, else segfault + &&POST_OPS_MATRIX_ADD_4x32F, + &&POST_OPS_SWISH_4x32F, + &&POST_OPS_MATRIX_MUL_4x32F }; // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. @@ -4669,26 +5525,26 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_4x32) //add to accumulator and store back zmm3 = _mm512_set1_ps(beta); - zmm0 = _mm512_load_ps(_cbuf); - zmm1 = _mm512_load_ps(_cbuf + 16); + zmm0 = _mm512_loadu_ps(_cbuf); + zmm1 = _mm512_loadu_ps(_cbuf + 16); zmm8 = _mm512_fmadd_ps(zmm0, zmm3, zmm8); zmm9 = _mm512_fmadd_ps(zmm1, zmm3, zmm9); _cbuf += rs_c; - zmm0 = _mm512_load_ps(_cbuf); - zmm1 = _mm512_load_ps(_cbuf + 16); + zmm0 = _mm512_loadu_ps(_cbuf); + zmm1 = _mm512_loadu_ps(_cbuf + 16); zmm12 = _mm512_fmadd_ps(zmm0, zmm3, zmm12); zmm13 = _mm512_fmadd_ps(zmm1, zmm3, zmm13); _cbuf += rs_c; - zmm0 = _mm512_load_ps(_cbuf); - zmm1 = _mm512_load_ps(_cbuf+16); + zmm0 = _mm512_loadu_ps(_cbuf); + zmm1 = _mm512_loadu_ps(_cbuf+16); zmm16 = _mm512_fmadd_ps(zmm0, zmm3, zmm16); zmm17 = _mm512_fmadd_ps(zmm1, zmm3, zmm17); _cbuf += rs_c; - zmm0 = _mm512_load_ps(_cbuf); - zmm1 = _mm512_load_ps(_cbuf+16); + zmm0 = _mm512_loadu_ps(_cbuf); + zmm1 = _mm512_loadu_ps(_cbuf+16); zmm20 = _mm512_fmadd_ps(zmm0, zmm3, zmm20); zmm21 = _mm512_fmadd_ps(zmm1, zmm3, zmm21); } @@ -4932,6 +5788,76 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_4x32) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_MATRIX_ADD_4x32F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(zmm1,zmm2,0,8,9); + + // c[1:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(zmm1,zmm2,1,12,13); + + // c[2:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(zmm1,zmm2,2,16,17); + + // c[3:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(zmm1,zmm2,3,20,21); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_4x32F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(zmm1,zmm2,0,8,9); + + // c[1:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(zmm1,zmm2,1,12,13); + + // c[2:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(zmm1,zmm2,2,16,17); + + // c[3:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(zmm1,zmm2,3,20,21); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_4x32F: + { + __m512 zmm7 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(zmm8, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[0, 16-31] + SWISH_F32_AVX512_DEF(zmm9, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[1, 0-15] + SWISH_F32_AVX512_DEF(zmm12, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[1, 16-31] + SWISH_F32_AVX512_DEF(zmm13, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[2, 0-15] + SWISH_F32_AVX512_DEF(zmm16, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[2, 16-31] + SWISH_F32_AVX512_DEF(zmm17, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[3, 0-15] + SWISH_F32_AVX512_DEF(zmm20, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[3, 16-31] + SWISH_F32_AVX512_DEF(zmm21, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_4x32F_DISABLE: ; @@ -4958,7 +5884,11 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_3x32) &&POST_OPS_RELU_SCALE_3x32F, &&POST_OPS_GELU_TANH_3x32F, &&POST_OPS_GELU_ERF_3x32F, - &&POST_OPS_CLIP_3x32F + &&POST_OPS_CLIP_3x32F, + NULL, // Virtual node for downscale, else segfault + &&POST_OPS_MATRIX_ADD_3x32F, + &&POST_OPS_SWISH_3x32F, + &&POST_OPS_MATRIX_MUL_3x32F }; // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. @@ -5014,20 +5944,20 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_3x32) //add to accumulator and store back zmm3 = _mm512_set1_ps(beta); - zmm0 = _mm512_load_ps(_cbuf); - zmm1 = _mm512_load_ps(_cbuf + 16); + zmm0 = _mm512_loadu_ps(_cbuf); + zmm1 = _mm512_loadu_ps(_cbuf + 16); zmm8 = _mm512_fmadd_ps(zmm0, zmm3, zmm8); zmm9 = _mm512_fmadd_ps(zmm1, zmm3, zmm9); _cbuf += rs_c; - zmm0 = _mm512_load_ps(_cbuf); - zmm1 = _mm512_load_ps(_cbuf + 16); + zmm0 = _mm512_loadu_ps(_cbuf); + zmm1 = _mm512_loadu_ps(_cbuf + 16); zmm12 = _mm512_fmadd_ps(zmm0, zmm3, zmm12); zmm13 = _mm512_fmadd_ps(zmm1, zmm3, zmm13); _cbuf += rs_c; - zmm0 = _mm512_load_ps(_cbuf); - zmm1 = _mm512_load_ps(_cbuf+16); + zmm0 = _mm512_loadu_ps(_cbuf); + zmm1 = _mm512_loadu_ps(_cbuf+16); zmm16 = _mm512_fmadd_ps(zmm0, zmm3, zmm16); zmm17 = _mm512_fmadd_ps(zmm1, zmm3, zmm17); _cbuf += rs_c; @@ -5227,6 +6157,64 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_3x32) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_MATRIX_ADD_3x32F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(zmm1,zmm2,0,8,9); + + // c[1:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(zmm1,zmm2,1,12,13); + + // c[2:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(zmm1,zmm2,2,16,17); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_3x32F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(zmm1,zmm2,0,8,9); + + // c[1:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(zmm1,zmm2,1,12,13); + + // c[2:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(zmm1,zmm2,2,16,17); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_3x32F: + { + __m512 zmm7 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(zmm8, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[0, 16-31] + SWISH_F32_AVX512_DEF(zmm9, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[1, 0-15] + SWISH_F32_AVX512_DEF(zmm12, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[1, 16-31] + SWISH_F32_AVX512_DEF(zmm13, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[2, 0-15] + SWISH_F32_AVX512_DEF(zmm16, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[2, 16-31] + SWISH_F32_AVX512_DEF(zmm17, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_3x32F_DISABLE: ; @@ -5250,7 +6238,11 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_2x32) &&POST_OPS_RELU_SCALE_2x32F, &&POST_OPS_GELU_TANH_2x32F, &&POST_OPS_GELU_ERF_2x32F, - &&POST_OPS_CLIP_2x32F + &&POST_OPS_CLIP_2x32F, + NULL, // Virtual node for downscale, else segfault + &&POST_OPS_MATRIX_ADD_2x32F, + &&POST_OPS_SWISH_2x32F, + &&POST_OPS_MATRIX_MUL_2x32F }; // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. @@ -5299,14 +6291,14 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_2x32) //add to accumulator and store back zmm3 = _mm512_set1_ps(beta); - zmm0 = _mm512_load_ps(_cbuf); - zmm1 = _mm512_load_ps(_cbuf + 16); + zmm0 = _mm512_loadu_ps(_cbuf); + zmm1 = _mm512_loadu_ps(_cbuf + 16); zmm8 = _mm512_fmadd_ps(zmm0, zmm3, zmm8); zmm9 = _mm512_fmadd_ps(zmm1, zmm3, zmm9); _cbuf += rs_c; - zmm0 = _mm512_load_ps(_cbuf); - zmm1 = _mm512_load_ps(_cbuf + 16); + zmm0 = _mm512_loadu_ps(_cbuf); + zmm1 = _mm512_loadu_ps(_cbuf + 16); zmm12 = _mm512_fmadd_ps(zmm0, zmm3, zmm12); zmm13 = _mm512_fmadd_ps(zmm1, zmm3, zmm13); } @@ -5460,6 +6452,52 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_2x32) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_MATRIX_ADD_2x32F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(zmm1,zmm2,0,8,9); + + // c[1:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(zmm1,zmm2,1,12,13); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_2x32F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(zmm1,zmm2,0,8,9); + + // c[1:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(zmm1,zmm2,1,12,13); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_2x32F: + { + __m512 zmm7 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(zmm8, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[0, 16-31] + SWISH_F32_AVX512_DEF(zmm9, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[1, 0-15] + SWISH_F32_AVX512_DEF(zmm12, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[1, 16-31] + SWISH_F32_AVX512_DEF(zmm13, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_2x32F_DISABLE: ; @@ -5480,7 +6518,11 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_1x32) &&POST_OPS_RELU_SCALE_1x32F, &&POST_OPS_GELU_TANH_1x32F, &&POST_OPS_GELU_ERF_1x32F, - &&POST_OPS_CLIP_1x32F + &&POST_OPS_CLIP_1x32F, + NULL, // Virtual node for downscale, else segfault + &&POST_OPS_MATRIX_ADD_1x32F, + &&POST_OPS_SWISH_1x32F, + &&POST_OPS_MATRIX_MUL_1x32F }; // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. @@ -5523,8 +6565,8 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_1x32) //add to accumulator and store back zmm3 = _mm512_set1_ps(beta); - zmm0 = _mm512_load_ps(cbuf); - zmm1 = _mm512_load_ps(cbuf + 16); + zmm0 = _mm512_loadu_ps(cbuf); + zmm1 = _mm512_loadu_ps(cbuf + 16); zmm8 = _mm512_fmadd_ps(zmm0, zmm3, zmm8); zmm9 = _mm512_fmadd_ps(zmm1, zmm3, zmm9); } @@ -5633,6 +6675,40 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_1x32) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_MATRIX_ADD_1x32F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(zmm1,zmm2,0,8,9); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_1x32F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(zmm1,zmm2,0,8,9); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_1x32F: + { + __m512 zmm7 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(zmm8, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[0, 16-31] + SWISH_F32_AVX512_DEF(zmm9, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_1x32F_DISABLE: ; diff --git a/kernels/zen4/lpgemm/f32f32f32/lpgemm_kernel_macros_f32.h b/kernels/zen4/lpgemm/f32f32f32/lpgemm_kernel_macros_f32.h index f24bca9e1f..1c1bc2a338 100644 --- a/kernels/zen4/lpgemm/f32f32f32/lpgemm_kernel_macros_f32.h +++ b/kernels/zen4/lpgemm/f32f32f32/lpgemm_kernel_macros_f32.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -36,6 +36,7 @@ #define LPGEMM_F32_SGEMM_KERN_MACROS_H #include "../gelu_avx512.h" +#include "../silu_avx512.h" #include "../math_utils_avx512.h" /* ReLU scale (Parametric ReLU): f(x) = x, when x > 0 and f(x) = a*x when x <= 0 */ @@ -67,12 +68,117 @@ zmm2 = _mm512_setzero_ps(); \ zmm3 = _mm512_setzero_ps(); +// Zero-out the given ZMM accumulator registers +#define ZERO_ACC_XMM_4_REG(xmm0, xmm1, xmm2, xmm3) \ + xmm0 = _mm_setzero_ps(); \ + xmm1 = _mm_setzero_ps(); \ + xmm2 = _mm_setzero_ps(); \ + xmm3 = _mm_setzero_ps(); + /*Multiply alpha with accumulator registers and store back*/ #define ALPHA_MUL_ACC_ZMM_4_REG(zmm0,zmm1,zmm2,zmm3,alpha) \ zmm0 = _mm512_mul_ps(zmm0,alpha); \ zmm1 = _mm512_mul_ps(zmm1,alpha); \ zmm2 = _mm512_mul_ps(zmm2,alpha); \ zmm3 = _mm512_mul_ps(zmm3,alpha); - + +// Matrix Add post-ops helper macros +#define F32_MATRIX_ADD_2COL(scr0,scr1,m_ind,r_ind0,r_ind1) \ + zmm ## r_ind0 = _mm512_add_ps( scr0, zmm ## r_ind0 ); \ + zmm ## r_ind1 = _mm512_add_ps( scr1, zmm ## r_ind1 ); \ + +#define F32_MATRIX_ADD_3COL(scr0,scr1,scr2,m_ind,r_ind0,r_ind1,r_ind2) \ + zmm ## r_ind0 = _mm512_add_ps( scr0, zmm ## r_ind0 ); \ + zmm ## r_ind1 = _mm512_add_ps( scr1, zmm ## r_ind1 ); \ + zmm ## r_ind2 = _mm512_add_ps( scr2, zmm ## r_ind2 ); \ + +#define F32_MATRIX_ADD_4COL(scr0,scr1,scr2,scr3,m_ind,r_ind0,r_ind1,r_ind2,r_ind3) \ + zmm ## r_ind0 = _mm512_add_ps( scr0, zmm ## r_ind0 ); \ + zmm ## r_ind1 = _mm512_add_ps( scr1, zmm ## r_ind1 ); \ + zmm ## r_ind2 = _mm512_add_ps( scr2, zmm ## r_ind2 ); \ + zmm ## r_ind3 = _mm512_add_ps( scr3, zmm ## r_ind3 ); \ + +#define F32_F32_MATRIX_ADD_LOAD(mask,scr,m_ind,n_ind) \ + scr = _mm512_maskz_loadu_ps \ + ( \ + mask, \ + matptr + ( ( post_ops_attr.post_op_c_i + m_ind ) * ldm ) + \ + post_ops_attr.post_op_c_j + ( n_ind * 16 ) \ + ); \ + +#define F32_F32_MATRIX_ADD_2COL(scr0,scr1,m_ind,r_ind0,r_ind1) \ + F32_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,m_ind,0); \ + F32_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,m_ind,1); \ + F32_MATRIX_ADD_2COL(scr0,scr1,m_ind,r_ind0,r_ind1); \ + +#define F32_F32_MATRIX_ADD_3COL(scr0,scr1,scr2,m_ind,r_ind0,r_ind1,r_ind2) \ + F32_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,m_ind,0); \ + F32_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,m_ind,1); \ + F32_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr2,m_ind,2); \ + F32_MATRIX_ADD_3COL(scr0,scr1,scr2,m_ind,r_ind0,r_ind1,r_ind2); \ + +#define F32_F32_MATRIX_ADD_4COL(scr0,scr1,scr2,scr3,m_ind,r_ind0,r_ind1,r_ind2,r_ind3) \ + F32_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,m_ind,0); \ + F32_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,m_ind,1); \ + F32_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr2,m_ind,2); \ + F32_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr3,m_ind,3); \ + F32_MATRIX_ADD_4COL(scr0,scr1,scr2,scr3,m_ind,r_ind0,r_ind1,r_ind2,r_ind3); \ + +#define F32_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,r_ind0,r_ind1,r_ind2,r_ind3,scr0,scr1,scr2,scr3,m_ind) \ + F32_F32_MATRIX_ADD_LOAD(k0,scr0,m_ind,0); \ + F32_F32_MATRIX_ADD_LOAD(k1,scr1,m_ind,1); \ + F32_F32_MATRIX_ADD_LOAD(k2,scr2,m_ind,2); \ + F32_F32_MATRIX_ADD_LOAD(k3,scr3,m_ind,3); \ + F32_MATRIX_ADD_4COL(scr0,scr1,scr2,scr3,m_ind,r_ind0,r_ind1,r_ind2,r_ind3); \ + +// Matrix Mul post-ops helper macros +#define F32_MATRIX_MUL_2COL(scr0,scr1,m_ind,r_ind0,r_ind1) \ + zmm ## r_ind0 = _mm512_mul_ps( scr0, zmm ## r_ind0 ); \ + zmm ## r_ind1 = _mm512_mul_ps( scr1, zmm ## r_ind1 ); \ + +#define F32_MATRIX_MUL_3COL(scr0,scr1,scr2,m_ind,r_ind0,r_ind1,r_ind2) \ + zmm ## r_ind0 = _mm512_mul_ps( scr0, zmm ## r_ind0 ); \ + zmm ## r_ind1 = _mm512_mul_ps( scr1, zmm ## r_ind1 ); \ + zmm ## r_ind2 = _mm512_mul_ps( scr2, zmm ## r_ind2 ); \ + +#define F32_MATRIX_MUL_4COL(scr0,scr1,scr2,scr3,m_ind,r_ind0,r_ind1,r_ind2,r_ind3) \ + zmm ## r_ind0 = _mm512_mul_ps( scr0, zmm ## r_ind0 ); \ + zmm ## r_ind1 = _mm512_mul_ps( scr1, zmm ## r_ind1 ); \ + zmm ## r_ind2 = _mm512_mul_ps( scr2, zmm ## r_ind2 ); \ + zmm ## r_ind3 = _mm512_mul_ps( scr3, zmm ## r_ind3 ); \ + +#define F32_F32_MATRIX_MUL_LOAD(mask,scr,m_ind,n_ind) \ + scr = _mm512_maskz_loadu_ps \ + ( \ + mask, \ + matptr + ( ( post_ops_attr.post_op_c_i + m_ind ) * ldm ) + \ + post_ops_attr.post_op_c_j + ( n_ind * 16 ) \ + ); \ + +#define F32_F32_MATRIX_MUL_2COL(scr0,scr1,m_ind,r_ind0,r_ind1) \ + F32_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,m_ind,0); \ + F32_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,m_ind,1); \ + F32_MATRIX_MUL_2COL(scr0,scr1,m_ind,r_ind0,r_ind1); \ + +#define F32_F32_MATRIX_MUL_3COL(scr0,scr1,scr2,m_ind,r_ind0,r_ind1,r_ind2) \ + F32_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,m_ind,0); \ + F32_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,m_ind,1); \ + F32_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr2,m_ind,2); \ + F32_MATRIX_MUL_3COL(scr0,scr1,scr2,m_ind,r_ind0,r_ind1,r_ind2); \ + +#define F32_F32_MATRIX_MUL_4COL(scr0,scr1,scr2,scr3,m_ind,r_ind0,r_ind1,r_ind2,r_ind3) \ + F32_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,m_ind,0); \ + F32_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,m_ind,1); \ + F32_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr2,m_ind,2); \ + F32_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr3,m_ind,3); \ + F32_MATRIX_MUL_4COL(scr0,scr1,scr2,scr3,m_ind,r_ind0,r_ind1,r_ind2,r_ind3); \ + +#define F32_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,r_ind0,r_ind1,r_ind2,r_ind3,scr0,scr1,scr2,scr3,m_ind) \ + F32_F32_MATRIX_MUL_LOAD(k0,scr0,m_ind,0); \ + F32_F32_MATRIX_MUL_LOAD(k1,scr1,m_ind,1); \ + F32_F32_MATRIX_MUL_LOAD(k2,scr2,m_ind,2); \ + F32_F32_MATRIX_MUL_LOAD(k3,scr3,m_ind,3); \ + F32_MATRIX_MUL_4COL(scr0,scr1,scr2,scr3,m_ind,r_ind0,r_ind1,r_ind2,r_ind3); \ + #endif //LPGEMM_F32_SGEMM_KERN_MACROS_H diff --git a/kernels/zen4/lpgemm/f32f32f32/lpgemm_m_kernel_f32_avx512.c b/kernels/zen4/lpgemm/f32f32f32/lpgemm_m_kernel_f32_avx512.c index d1d14209ba..b985624a28 100644 --- a/kernels/zen4/lpgemm/f32f32f32/lpgemm_m_kernel_f32_avx512.c +++ b/kernels/zen4/lpgemm/f32f32f32/lpgemm_m_kernel_f32_avx512.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -52,7 +52,11 @@ LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_avx512_6x64m) &&POST_OPS_RELU_SCALE_6x64F, &&POST_OPS_GELU_TANH_6x64F, &&POST_OPS_GELU_ERF_6x64F, - &&POST_OPS_CLIP_6x64F + &&POST_OPS_CLIP_6x64F, + NULL, // Virtual node for downscale, else segfault + &&POST_OPS_MATRIX_ADD_6x64F, + &&POST_OPS_SWISH_6x64F, + &&POST_OPS_MATRIX_MUL_6x64F }; uint64_t n_left = n0 % 64; //n0 is expected to be n0<=NR @@ -295,72 +299,71 @@ LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_avx512_6x64m) //add to accumulator and store back zmm3 = _mm512_set1_ps(beta); - zmm0 = _mm512_load_ps(_cbuf); - zmm1 = _mm512_load_ps(_cbuf + 16); + zmm0 = _mm512_loadu_ps(_cbuf); + zmm1 = _mm512_loadu_ps(_cbuf + 16); zmm8 = _mm512_fmadd_ps(zmm0, zmm3, zmm8); zmm9 = _mm512_fmadd_ps(zmm1, zmm3, zmm9); - zmm0 = _mm512_load_ps(_cbuf + 32); - zmm1 = _mm512_load_ps(_cbuf + 48); + zmm0 = _mm512_loadu_ps(_cbuf + 32); + zmm1 = _mm512_loadu_ps(_cbuf + 48); zmm10 = _mm512_fmadd_ps(zmm0, zmm3, zmm10); zmm11 = _mm512_fmadd_ps(zmm1, zmm3, zmm11); _cbuf += rs_c; - zmm0 = _mm512_load_ps(_cbuf); - zmm1 = _mm512_load_ps(_cbuf + 16); + zmm0 = _mm512_loadu_ps(_cbuf); + zmm1 = _mm512_loadu_ps(_cbuf + 16); zmm12 = _mm512_fmadd_ps(zmm0, zmm3, zmm12); zmm13 = _mm512_fmadd_ps(zmm1, zmm3, zmm13); - zmm0 = _mm512_load_ps(_cbuf + 32); - zmm1 = _mm512_load_ps(_cbuf + 48); + zmm0 = _mm512_loadu_ps(_cbuf + 32); + zmm1 = _mm512_loadu_ps(_cbuf + 48); zmm14 = _mm512_fmadd_ps(zmm0, zmm3, zmm14); zmm15 = _mm512_fmadd_ps(zmm1, zmm3, zmm15); _cbuf += rs_c; - zmm0 = _mm512_load_ps(_cbuf); - zmm1 = _mm512_load_ps(_cbuf + 16); + zmm0 = _mm512_loadu_ps(_cbuf); + zmm1 = _mm512_loadu_ps(_cbuf + 16); zmm16 = _mm512_fmadd_ps(zmm0, zmm3, zmm16); zmm17 = _mm512_fmadd_ps(zmm1, zmm3, zmm17); - zmm0 = _mm512_load_ps(_cbuf + 32); - zmm1 = _mm512_load_ps(_cbuf + 48); + zmm0 = _mm512_loadu_ps(_cbuf + 32); + zmm1 = _mm512_loadu_ps(_cbuf + 48); zmm18 = _mm512_fmadd_ps(zmm0, zmm3, zmm18); zmm19 = _mm512_fmadd_ps(zmm1, zmm3, zmm19); _cbuf += rs_c; - zmm0 = _mm512_load_ps(_cbuf); - zmm1 = _mm512_load_ps(_cbuf + 16); + zmm0 = _mm512_loadu_ps(_cbuf); + zmm1 = _mm512_loadu_ps(_cbuf + 16); zmm20 = _mm512_fmadd_ps(zmm0, zmm3, zmm20); zmm21 = _mm512_fmadd_ps(zmm1, zmm3, zmm21); - zmm0 = _mm512_load_ps(_cbuf + 32); - zmm1 = _mm512_load_ps(_cbuf + 48); + zmm0 = _mm512_loadu_ps(_cbuf + 32); + zmm1 = _mm512_loadu_ps(_cbuf + 48); zmm22 = _mm512_fmadd_ps(zmm0, zmm3, zmm22); zmm23 = _mm512_fmadd_ps(zmm1, zmm3, zmm23); _cbuf += rs_c; - zmm0 = _mm512_load_ps(_cbuf); - zmm1 = _mm512_load_ps(_cbuf + 16); + zmm0 = _mm512_loadu_ps(_cbuf); + zmm1 = _mm512_loadu_ps(_cbuf + 16); zmm24 = _mm512_fmadd_ps(zmm0, zmm3, zmm24); zmm25 = _mm512_fmadd_ps(zmm1, zmm3, zmm25); - zmm0 = _mm512_load_ps(_cbuf + 32); - zmm1 = _mm512_load_ps(_cbuf + 48); + zmm0 = _mm512_loadu_ps(_cbuf + 32); + zmm1 = _mm512_loadu_ps(_cbuf + 48); zmm26 = _mm512_fmadd_ps(zmm0, zmm3, zmm26); zmm27 = _mm512_fmadd_ps(zmm1, zmm3, zmm27); _cbuf += rs_c; - zmm0 = _mm512_load_ps(_cbuf); - zmm1 = _mm512_load_ps(_cbuf + 16); + zmm0 = _mm512_loadu_ps(_cbuf); + zmm1 = _mm512_loadu_ps(_cbuf + 16); zmm28 = _mm512_fmadd_ps(zmm0, zmm3, zmm28); zmm29 = _mm512_fmadd_ps(zmm1, zmm3, zmm29); - zmm0 = _mm512_load_ps(_cbuf + 32); - zmm1 = _mm512_load_ps(_cbuf + 48); + zmm0 = _mm512_loadu_ps(_cbuf + 32); + zmm1 = _mm512_loadu_ps(_cbuf + 48); zmm30 = _mm512_fmadd_ps(zmm0, zmm3, zmm30); zmm31 = _mm512_fmadd_ps(zmm1, zmm3, zmm31); } - // Post Ops lpgemm_post_op* post_ops_list_temp = post_ops_list; POST_OP_LABEL_LASTK_SAFE_JUMP @@ -948,6 +951,135 @@ LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_avx512_6x64m) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_MATRIX_ADD_6x64F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(zmm1,zmm2,zmm3,zmm4,0,8,9,10,11); + + // c[1:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(zmm1,zmm2,zmm3,zmm4,1,12,13,14,15); + + // c[2:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(zmm1,zmm2,zmm3,zmm4,2,16,17,18,19); + + // c[3:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(zmm1,zmm2,zmm3,zmm4,3,20,21,22,23); + + // c[4:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(zmm1,zmm2,zmm3,zmm4,4,24,25,26,27); + + // c[5:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_ADD_4COL(zmm1,zmm2,zmm3,zmm4,5,28,29,30,31); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_6x64F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + // c[0:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(zmm1,zmm2,zmm3,zmm4,0,8,9,10,11); + + // c[1:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(zmm1,zmm2,zmm3,zmm4,1,12,13,14,15); + + // c[2:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(zmm1,zmm2,zmm3,zmm4,2,16,17,18,19); + + // c[3:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(zmm1,zmm2,zmm3,zmm4,3,20,21,22,23); + + // c[4:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(zmm1,zmm2,zmm3,zmm4,4,24,25,26,27); + + // c[5:0-15,16-31,32-47,48-63] + F32_F32_MATRIX_MUL_4COL(zmm1,zmm2,zmm3,zmm4,5,28,29,30,31); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_6x64F: + { + zmm7 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(zmm8, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[0, 16-31] + SWISH_F32_AVX512_DEF(zmm9, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[0, 32-47] + SWISH_F32_AVX512_DEF(zmm10, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[0, 48-63] + SWISH_F32_AVX512_DEF(zmm11, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[1, 0-15] + SWISH_F32_AVX512_DEF(zmm12, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[1, 16-31] + SWISH_F32_AVX512_DEF(zmm13, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[1, 32-47] + SWISH_F32_AVX512_DEF(zmm14, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[1, 48-63] + SWISH_F32_AVX512_DEF(zmm15, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[2, 0-15] + SWISH_F32_AVX512_DEF(zmm16, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[2, 16-31] + SWISH_F32_AVX512_DEF(zmm17, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[2, 32-47] + SWISH_F32_AVX512_DEF(zmm18, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[2, 48-63] + SWISH_F32_AVX512_DEF(zmm19, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[3, 0-15] + SWISH_F32_AVX512_DEF(zmm20, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[3, 16-31] + SWISH_F32_AVX512_DEF(zmm21, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[3, 32-47] + SWISH_F32_AVX512_DEF(zmm22, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[3, 48-63] + SWISH_F32_AVX512_DEF(zmm23, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[4, 0-15] + SWISH_F32_AVX512_DEF(zmm24, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[4, 16-31] + SWISH_F32_AVX512_DEF(zmm25, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[4, 32-47] + SWISH_F32_AVX512_DEF(zmm26, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[4, 48-63] + SWISH_F32_AVX512_DEF(zmm27, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[5, 0-15] + SWISH_F32_AVX512_DEF(zmm28, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[5, 16-31] + SWISH_F32_AVX512_DEF(zmm29, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[5, 32-47] + SWISH_F32_AVX512_DEF(zmm30, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[5, 48-63] + SWISH_F32_AVX512_DEF(zmm31, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_6x64F_DISABLE: ; @@ -1030,7 +1162,11 @@ LPGEMM_N_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_6x48m) &&POST_OPS_RELU_SCALE_6x48F, &&POST_OPS_GELU_TANH_6x48F, &&POST_OPS_GELU_ERF_6x48F, - &&POST_OPS_CLIP_6x48F + &&POST_OPS_CLIP_6x48F, + NULL, // Virtual node for downscale, else segfault + &&POST_OPS_MATRIX_ADD_6x48F, + &&POST_OPS_SWISH_6x48F, + &&POST_OPS_MATRIX_MUL_6x48F }; // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. @@ -1134,57 +1270,57 @@ LPGEMM_N_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_6x48m) //add to accumulator and store back zmm3 = _mm512_set1_ps(beta); - zmm0 = _mm512_load_ps(_cbuf); - zmm1 = _mm512_load_ps(_cbuf + 16); + zmm0 = _mm512_loadu_ps(_cbuf); + zmm1 = _mm512_loadu_ps(_cbuf + 16); zmm8 = _mm512_fmadd_ps(zmm0, zmm3, zmm8); zmm9 = _mm512_fmadd_ps(zmm1, zmm3, zmm9); - zmm0 = _mm512_load_ps(_cbuf + 32); + zmm0 = _mm512_loadu_ps(_cbuf + 32); zmm10 = _mm512_fmadd_ps(zmm0, zmm3, zmm10); _cbuf += rs_c; - zmm0 = _mm512_load_ps(_cbuf); - zmm1 = _mm512_load_ps(_cbuf + 16); + zmm0 = _mm512_loadu_ps(_cbuf); + zmm1 = _mm512_loadu_ps(_cbuf + 16); zmm12 = _mm512_fmadd_ps(zmm0, zmm3, zmm12); zmm13 = _mm512_fmadd_ps(zmm1, zmm3, zmm13); - zmm0 = _mm512_load_ps(_cbuf + 32); + zmm0 = _mm512_loadu_ps(_cbuf + 32); zmm14 = _mm512_fmadd_ps(zmm0, zmm3, zmm14); _cbuf += rs_c; - zmm0 = _mm512_load_ps(_cbuf); - zmm1 = _mm512_load_ps(_cbuf + 16); + zmm0 = _mm512_loadu_ps(_cbuf); + zmm1 = _mm512_loadu_ps(_cbuf + 16); zmm16 = _mm512_fmadd_ps(zmm0, zmm3, zmm16); zmm17 = _mm512_fmadd_ps(zmm1, zmm3, zmm17); - zmm0 = _mm512_load_ps(_cbuf + 32); + zmm0 = _mm512_loadu_ps(_cbuf + 32); zmm18 = _mm512_fmadd_ps(zmm0, zmm3, zmm18); _cbuf += rs_c; - zmm0 = _mm512_load_ps(_cbuf); - zmm1 = _mm512_load_ps(_cbuf + 16); + zmm0 = _mm512_loadu_ps(_cbuf); + zmm1 = _mm512_loadu_ps(_cbuf + 16); zmm20 = _mm512_fmadd_ps(zmm0, zmm3, zmm20); zmm21 = _mm512_fmadd_ps(zmm1, zmm3, zmm21); - zmm0 = _mm512_load_ps(_cbuf + 32); + zmm0 = _mm512_loadu_ps(_cbuf + 32); zmm22 = _mm512_fmadd_ps(zmm0, zmm3, zmm22); _cbuf += rs_c; - zmm0 = _mm512_load_ps(_cbuf); - zmm1 = _mm512_load_ps(_cbuf + 16); + zmm0 = _mm512_loadu_ps(_cbuf); + zmm1 = _mm512_loadu_ps(_cbuf + 16); zmm24 = _mm512_fmadd_ps(zmm0, zmm3, zmm24); zmm25 = _mm512_fmadd_ps(zmm1, zmm3, zmm25); - zmm0 = _mm512_load_ps(_cbuf + 32); + zmm0 = _mm512_loadu_ps(_cbuf + 32); zmm26 = _mm512_fmadd_ps(zmm0, zmm3, zmm26); _cbuf += rs_c; - zmm0 = _mm512_load_ps(_cbuf); - zmm1 = _mm512_load_ps(_cbuf + 16); + zmm0 = _mm512_loadu_ps(_cbuf); + zmm1 = _mm512_loadu_ps(_cbuf + 16); zmm28 = _mm512_fmadd_ps(zmm0, zmm3, zmm28); zmm29 = _mm512_fmadd_ps(zmm1, zmm3, zmm29); - zmm0 = _mm512_load_ps(_cbuf + 32); + zmm0 = _mm512_loadu_ps(_cbuf + 32); zmm30 = _mm512_fmadd_ps(zmm0, zmm3, zmm30); } @@ -1646,6 +1782,118 @@ LPGEMM_N_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_6x48m) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_MATRIX_ADD_6x48F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(zmm1,zmm2,zmm3,0,8,9,10); + + // c[1:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(zmm1,zmm2,zmm3,1,12,13,14); + + // c[2:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(zmm1,zmm2,zmm3,2,16,17,18); + + // c[3:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(zmm1,zmm2,zmm3,3,20,21,22); + + // c[4:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(zmm1,zmm2,zmm3,4,24,25,26); + + // c[5:0-15,16-31,32-47] + F32_F32_MATRIX_ADD_3COL(zmm1,zmm2,zmm3,5,28,29,30); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_6x48F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(zmm1,zmm2,zmm3,0,8,9,10); + + // c[1:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(zmm1,zmm2,zmm3,1,12,13,14); + + // c[2:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(zmm1,zmm2,zmm3,2,16,17,18); + + // c[3:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(zmm1,zmm2,zmm3,3,20,21,22); + + // c[4:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(zmm1,zmm2,zmm3,4,24,25,26); + + // c[5:0-15,16-31,32-47] + F32_F32_MATRIX_MUL_3COL(zmm1,zmm2,zmm3,5,28,29,30); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_6x48F: + { + __m512 zmm7 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(zmm8, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[0, 16-31] + SWISH_F32_AVX512_DEF(zmm9, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[0, 32-47] + SWISH_F32_AVX512_DEF(zmm10, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[1, 0-15] + SWISH_F32_AVX512_DEF(zmm12, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[1, 16-31] + SWISH_F32_AVX512_DEF(zmm13, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[1, 32-47] + SWISH_F32_AVX512_DEF(zmm14, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[2, 0-15] + SWISH_F32_AVX512_DEF(zmm16, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[2, 16-31] + SWISH_F32_AVX512_DEF(zmm17, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[2, 32-47] + SWISH_F32_AVX512_DEF(zmm18, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[3, 0-15] + SWISH_F32_AVX512_DEF(zmm20, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[3, 16-31] + SWISH_F32_AVX512_DEF(zmm21, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[3, 32-47] + SWISH_F32_AVX512_DEF(zmm22, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[4, 0-15] + SWISH_F32_AVX512_DEF(zmm24, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[4, 16-31] + SWISH_F32_AVX512_DEF(zmm25, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[4, 32-47] + SWISH_F32_AVX512_DEF(zmm26, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[5, 0-15] + SWISH_F32_AVX512_DEF(zmm28, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[5, 16-31] + SWISH_F32_AVX512_DEF(zmm29, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[5, 32-47] + SWISH_F32_AVX512_DEF(zmm30, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_6x48F_DISABLE: ; @@ -1722,7 +1970,11 @@ LPGEMM_N_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_6x32m) &&POST_OPS_RELU_SCALE_6x32F, &&POST_OPS_GELU_TANH_6x32F, &&POST_OPS_GELU_ERF_6x32F, - &&POST_OPS_CLIP_6x32F + &&POST_OPS_CLIP_6x32F, + NULL, // Virtual node for downscale, else segfault + &&POST_OPS_MATRIX_ADD_6x32F, + &&POST_OPS_SWISH_6x32F, + &&POST_OPS_MATRIX_MUL_6x32F }; // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. @@ -1806,38 +2058,38 @@ LPGEMM_N_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_6x32m) //add to accumulator and store back zmm3 = _mm512_set1_ps(beta); - zmm0 = _mm512_load_ps(_cbuf); - zmm1 = _mm512_load_ps(_cbuf + 16); + zmm0 = _mm512_loadu_ps(_cbuf); + zmm1 = _mm512_loadu_ps(_cbuf + 16); zmm8 = _mm512_fmadd_ps(zmm0, zmm3, zmm8); zmm9 = _mm512_fmadd_ps(zmm1, zmm3, zmm9); _cbuf += rs_c; - zmm0 = _mm512_load_ps(_cbuf); - zmm1 = _mm512_load_ps(_cbuf + 16); + zmm0 = _mm512_loadu_ps(_cbuf); + zmm1 = _mm512_loadu_ps(_cbuf + 16); zmm12 = _mm512_fmadd_ps(zmm0, zmm3, zmm12); zmm13 = _mm512_fmadd_ps(zmm1, zmm3, zmm13); _cbuf += rs_c; - zmm0 = _mm512_load_ps(_cbuf); - zmm1 = _mm512_load_ps(_cbuf + 16); + zmm0 = _mm512_loadu_ps(_cbuf); + zmm1 = _mm512_loadu_ps(_cbuf + 16); zmm16 = _mm512_fmadd_ps(zmm0, zmm3, zmm16); zmm17 = _mm512_fmadd_ps(zmm1, zmm3, zmm17); _cbuf += rs_c; - zmm0 = _mm512_load_ps(_cbuf); - zmm1 = _mm512_load_ps(_cbuf + 16); + zmm0 = _mm512_loadu_ps(_cbuf); + zmm1 = _mm512_loadu_ps(_cbuf + 16); zmm20 = _mm512_fmadd_ps(zmm0, zmm3, zmm20); zmm21 = _mm512_fmadd_ps(zmm1, zmm3, zmm21); _cbuf += rs_c; - zmm0 = _mm512_load_ps(_cbuf); - zmm1 = _mm512_load_ps(_cbuf + 16); + zmm0 = _mm512_loadu_ps(_cbuf); + zmm1 = _mm512_loadu_ps(_cbuf + 16); zmm24 = _mm512_fmadd_ps(zmm0, zmm3, zmm24); zmm25 = _mm512_fmadd_ps(zmm1, zmm3, zmm25); _cbuf += rs_c; - zmm0 = _mm512_load_ps(_cbuf); - zmm1 = _mm512_load_ps(_cbuf + 16); + zmm0 = _mm512_loadu_ps(_cbuf); + zmm1 = _mm512_loadu_ps(_cbuf + 16); zmm28 = _mm512_fmadd_ps(zmm0, zmm3, zmm28); zmm29 = _mm512_fmadd_ps(zmm1, zmm3, zmm29); } @@ -2171,6 +2423,100 @@ LPGEMM_N_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_6x32m) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_MATRIX_ADD_6x32F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(zmm1,zmm2,0,8,9); + + // c[1:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(zmm1,zmm2,1,12,13); + + // c[2:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(zmm1,zmm2,2,16,17); + + // c[3:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(zmm1,zmm2,3,20,21); + + // c[4:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(zmm1,zmm2,4,24,25); + + // c[5:0-15,16-31] + F32_F32_MATRIX_ADD_2COL(zmm1,zmm2,5,28,29); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_MUL_6x32F: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + float* matptr = ( float* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(zmm1,zmm2,0,8,9); + + // c[1:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(zmm1,zmm2,1,12,13); + + // c[2:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(zmm1,zmm2,2,16,17); + + // c[3:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(zmm1,zmm2,3,20,21); + + // c[4:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(zmm1,zmm2,4,24,25); + + // c[5:0-15,16-31] + F32_F32_MATRIX_MUL_2COL(zmm1,zmm2,5,28,29); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_6x32F: + { + __m512 zmm7 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(zmm8, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[0, 16-31] + SWISH_F32_AVX512_DEF(zmm9, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[1, 0-15] + SWISH_F32_AVX512_DEF(zmm12, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[1, 16-31] + SWISH_F32_AVX512_DEF(zmm13, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[2, 0-15] + SWISH_F32_AVX512_DEF(zmm16, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[2, 16-31] + SWISH_F32_AVX512_DEF(zmm17, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[3, 0-15] + SWISH_F32_AVX512_DEF(zmm20, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[3, 16-31] + SWISH_F32_AVX512_DEF(zmm21, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[4, 0-15] + SWISH_F32_AVX512_DEF(zmm24, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[4, 16-31] + SWISH_F32_AVX512_DEF(zmm25, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[5, 0-15] + SWISH_F32_AVX512_DEF(zmm28, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[5, 16-31] + SWISH_F32_AVX512_DEF(zmm29, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_6x32F_DISABLE: ; diff --git a/kernels/zen4/lpgemm/f32f32f32/lpgemm_pack_a_f32_amd512vnni.c b/kernels/zen4/lpgemm/f32f32f32/lpgemm_pack_a_f32_amd512vnni.c new file mode 100644 index 0000000000..631bef66c5 --- /dev/null +++ b/kernels/zen4/lpgemm/f32f32f32/lpgemm_pack_a_f32_amd512vnni.c @@ -0,0 +1,762 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binarsy form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include +#include "blis.h" + +#ifdef BLIS_ADDON_LPGEMM + +#define UNPACKLO_PS16 \ + b_reg[0] = _mm512_unpacklo_ps(a_reg[0], a_reg[1]); \ + b_reg[1] = _mm512_unpacklo_ps(a_reg[2], a_reg[3]); \ + b_reg[2] = _mm512_unpacklo_ps(a_reg[4], a_reg[5]); \ + b_reg[3] = _mm512_unpacklo_ps(a_reg[6], a_reg[7]); \ + b_reg[4] = _mm512_unpacklo_ps(a_reg[8], a_reg[9]); \ + b_reg[5] = _mm512_unpacklo_ps(a_reg[10], a_reg[11]); \ + b_reg[6] = _mm512_unpacklo_ps(a_reg[12], a_reg[13]); \ + b_reg[7] = _mm512_unpacklo_ps(a_reg[14], a_reg[15]); + +#define UNPACKHI_PS16 \ + b_reg[8] = _mm512_unpackhi_ps(a_reg[0], a_reg[1]); \ + b_reg[9] = _mm512_unpackhi_ps(a_reg[2], a_reg[3]); \ + b_reg[10] = _mm512_unpackhi_ps(a_reg[4], a_reg[5]); \ + b_reg[11] = _mm512_unpackhi_ps(a_reg[6], a_reg[7]); \ + b_reg[12] = _mm512_unpackhi_ps(a_reg[8], a_reg[9]); \ + b_reg[13] = _mm512_unpackhi_ps(a_reg[10], a_reg[11]); \ + b_reg[14] = _mm512_unpackhi_ps(a_reg[12], a_reg[13]); \ + b_reg[15] = _mm512_unpackhi_ps(a_reg[14], a_reg[15]); + +#define SHUFFLE_64x2 \ + a_reg[0] = _mm512_shuffle_ps(b_reg[0], b_reg[1], 0x44); \ + a_reg[1] = _mm512_shuffle_ps(b_reg[0], b_reg[1], 0xEE); \ + a_reg[2] = _mm512_shuffle_ps(b_reg[2], b_reg[3], 0x44); \ + a_reg[3] = _mm512_shuffle_ps(b_reg[2], b_reg[3], 0xEE); \ +\ + a_reg[4] = _mm512_shuffle_ps(b_reg[4], b_reg[5], 0x44); \ + a_reg[5] = _mm512_shuffle_ps(b_reg[4], b_reg[5], 0xEE); \ + a_reg[6] = _mm512_shuffle_ps(b_reg[6], b_reg[7], 0x44); \ + a_reg[7] = _mm512_shuffle_ps(b_reg[6], b_reg[7], 0xEE); \ +\ + a_reg[8] = _mm512_shuffle_ps(b_reg[8], b_reg[9], 0x44); \ + a_reg[9] = _mm512_shuffle_ps(b_reg[8], b_reg[9], 0xEE); \ + a_reg[10] = _mm512_shuffle_ps(b_reg[10], b_reg[11], 0x44); \ + a_reg[11] = _mm512_shuffle_ps(b_reg[10], b_reg[11], 0xEE); \ +\ + a_reg[12] = _mm512_shuffle_ps(b_reg[12], b_reg[13], 0x44); \ + a_reg[13] = _mm512_shuffle_ps(b_reg[12], b_reg[13], 0xEE); \ + a_reg[14] = _mm512_shuffle_ps(b_reg[14], b_reg[15], 0x44); \ + a_reg[15] = _mm512_shuffle_ps(b_reg[14], b_reg[15], 0xEE); + +#define MASKED_STORE_PS(mask) \ + _mm512_mask_storeu_ps((pack_a_buffer + (ic+0) * KC + kr ), mask, a_reg[0]); \ + _mm512_mask_storeu_ps((pack_a_buffer + (ic+1) * KC + kr ), mask, a_reg[1]); \ + _mm512_mask_storeu_ps((pack_a_buffer + (ic+2) * KC + kr ), mask, a_reg[2]); \ + _mm512_mask_storeu_ps((pack_a_buffer + (ic+3) * KC + kr ), mask, a_reg[3]); \ + _mm512_mask_storeu_ps((pack_a_buffer + (ic+4) * KC + kr ), mask, a_reg[4]); \ + _mm512_mask_storeu_ps((pack_a_buffer + (ic+5) * KC + kr ), mask, a_reg[5]); \ + _mm512_mask_storeu_ps((pack_a_buffer + (ic+6) * KC + kr ), mask, a_reg[6]); \ + _mm512_mask_storeu_ps((pack_a_buffer + (ic+7) * KC + kr ), mask, a_reg[7]); \ + _mm512_mask_storeu_ps((pack_a_buffer + (ic+8) * KC + kr ), mask, a_reg[8]); \ + _mm512_mask_storeu_ps((pack_a_buffer + (ic+9) * KC + kr ), mask, a_reg[9]); \ + _mm512_mask_storeu_ps((pack_a_buffer + (ic+10) * KC + kr ), mask, a_reg[10]); \ + _mm512_mask_storeu_ps((pack_a_buffer + (ic+11) * KC + kr ), mask, a_reg[11]); \ + _mm512_mask_storeu_ps((pack_a_buffer + (ic+12) * KC + kr ), mask, a_reg[12]); \ + _mm512_mask_storeu_ps((pack_a_buffer + (ic+13) * KC + kr ), mask, a_reg[13]); \ + _mm512_mask_storeu_ps((pack_a_buffer + (ic+14) * KC + kr ), mask, a_reg[14]); \ + _mm512_mask_storeu_ps((pack_a_buffer + (ic+15) * KC + kr ), mask, a_reg[15]); + +#define PERMUTE4x4( mask1, mask2 ) \ + b_reg[0] = _mm512_permutex2var_ps( a_reg[0], mask1, a_reg[2] ); \ + b_reg[1] = _mm512_permutex2var_ps( a_reg[1], mask1, a_reg[3] ); \ + b_reg[2] = _mm512_permutex2var_ps( a_reg[8], mask1, a_reg[10] ); \ + b_reg[3] = _mm512_permutex2var_ps( a_reg[9], mask1, a_reg[11] ); \ +\ + b_reg[4] = _mm512_permutex2var_ps( a_reg[4], mask1, a_reg[6]); \ + b_reg[5] = _mm512_permutex2var_ps( a_reg[5], mask1, a_reg[7]); \ + b_reg[6] = _mm512_permutex2var_ps( a_reg[12], mask1, a_reg[14]); \ + b_reg[7] = _mm512_permutex2var_ps( a_reg[13], mask1, a_reg[15]); \ +\ + b_reg[8] = _mm512_permutex2var_ps( a_reg[0], mask2, a_reg[2]); \ + b_reg[9] = _mm512_permutex2var_ps( a_reg[1], mask2, a_reg[3]); \ + b_reg[10] = _mm512_permutex2var_ps( a_reg[8], mask2, a_reg[10]); \ + b_reg[11] = _mm512_permutex2var_ps( a_reg[9], mask2, a_reg[11]); \ +\ + b_reg[12] = _mm512_permutex2var_ps( a_reg[4], mask2, a_reg[6]); \ + b_reg[13] = _mm512_permutex2var_ps( a_reg[5], mask2, a_reg[7]); \ + b_reg[14] = _mm512_permutex2var_ps( a_reg[12], mask2, a_reg[14]); \ + b_reg[15] = _mm512_permutex2var_ps( a_reg[13], mask2, a_reg[15]); + +#define PERMUTE8x8( mask3, mask4 ) \ + a_reg[0] = _mm512_permutex2var_ps( b_reg[0], mask3, b_reg[4]); \ + a_reg[1] = _mm512_permutex2var_ps( b_reg[1], mask3, b_reg[5]); \ + a_reg[2] = _mm512_permutex2var_ps( b_reg[2], mask3, b_reg[6]); \ + a_reg[3] = _mm512_permutex2var_ps( b_reg[3], mask3, b_reg[7]); \ +\ + a_reg[4] = _mm512_permutex2var_ps( b_reg[0], mask4, b_reg[4]); \ + a_reg[5] = _mm512_permutex2var_ps( b_reg[1], mask4, b_reg[5]); \ + a_reg[6] = _mm512_permutex2var_ps( b_reg[2], mask4, b_reg[6]); \ + a_reg[7] = _mm512_permutex2var_ps( b_reg[3], mask4, b_reg[7]); \ +\ + a_reg[8] = _mm512_permutex2var_ps( b_reg[8], mask3, b_reg[12]); \ + a_reg[9] = _mm512_permutex2var_ps( b_reg[9], mask3, b_reg[13]); \ + a_reg[10] = _mm512_permutex2var_ps( b_reg[10], mask3, b_reg[14]); \ + a_reg[11] = _mm512_permutex2var_ps( b_reg[11], mask3, b_reg[15]); \ +\ + a_reg[12] = _mm512_permutex2var_ps( b_reg[8], mask4, b_reg[12]); \ + a_reg[13] = _mm512_permutex2var_ps( b_reg[9], mask4, b_reg[13]); \ + a_reg[14] = _mm512_permutex2var_ps( b_reg[10], mask4, b_reg[14]); \ + a_reg[15] = _mm512_permutex2var_ps( b_reg[11], mask4, b_reg[15]); + +void packa_mr16_f32f32f32of32_col_major +( + float* pack_a_buffer, + const float* a, + const dim_t rs_a, + const dim_t cs_a, + const dim_t MC, + const dim_t KC, + dim_t* rs_p, + dim_t* cs_p +) +{ + dim_t MR = 16; + dim_t ic, kr; + dim_t m_left = MC % 4; + + __m512 a_reg[16], b_reg[16]; + + __m512i mask1 = _mm512_set_epi32( 0x17, 0x16, 0x15, 0x14, + 0x07, 0x06, 0x05, 0x04, + 0x13, 0x12, 0x11, 0x10, + 0x03, 0x02, 0x01, 0x00 ); + + __m512i mask2 = _mm512_set_epi32( 0x1F, 0x1E, 0x1D, 0x1C, + 0x0F, 0x0E, 0x0D, 0x0C, + 0x1B, 0x1A, 0x19, 0x18, + 0x0B, 0x0A, 0x9, 0x08 ); + + __m512i mask3 = _mm512_set_epi32( 0x17, 0x16, 0x15, 0x14, 0x13, 0x12, 0x11, 0x10, + 0x07, 0x06, 0x05, 0x04, 0x03, 0x02, 0x01, 0x00 ); + __m512i mask4 = _mm512_set_epi32( 0x1F, 0x1E, 0x1D, 0x1C, 0x1B, 0x1A, 0x19, 0x18, + 0x0F, 0x0E, 0x0D, 0x0C, 0x0B, 0x0A, 0x09, 0x08 ); + + // These registers are set with zeroes to avoid compiler warnings + // To-DO: TO be removed when pack code is optimized for fringe cases. + a_reg[0] = _mm512_setzero_ps(); + a_reg[1] = _mm512_setzero_ps(); + a_reg[2] = _mm512_setzero_ps(); + a_reg[3] = _mm512_setzero_ps(); + a_reg[4] = _mm512_setzero_ps(); + a_reg[5] = _mm512_setzero_ps(); + a_reg[6] = _mm512_setzero_ps(); + a_reg[7] = _mm512_setzero_ps(); + a_reg[8] = _mm512_setzero_ps(); + a_reg[9] = _mm512_setzero_ps(); + a_reg[10] = _mm512_setzero_ps(); + a_reg[11] = _mm512_setzero_ps(); + a_reg[12] = _mm512_setzero_ps(); + a_reg[13] = _mm512_setzero_ps(); + a_reg[14] = _mm512_setzero_ps(); + a_reg[15] = _mm512_setzero_ps(); + + for( ic = 0; ( ic + MR - 1 ) < MC; ic += MR) + { + for( kr = 0; ( kr + 15 ) < KC; kr += 16) + { + a_reg[0] = _mm512_loadu_ps( (__m512 const *) ( a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ) ); + a_reg[1] = _mm512_loadu_ps( (__m512 const *) ( a + ( ic * rs_a ) + ( ( kr + 1 ) * cs_a ) ) ); + a_reg[2] = _mm512_loadu_ps( (__m512 const *) ( a + ( ic * rs_a ) + ( ( kr + 2 ) * cs_a ) ) ); + a_reg[3] = _mm512_loadu_ps( (__m512 const *) ( a + ( ic * rs_a ) + ( ( kr + 3 ) * cs_a ) ) ); + a_reg[4] = _mm512_loadu_ps( (__m512 const *) ( a + ( ic * rs_a ) + ( ( kr + 4 ) * cs_a ) ) ); + a_reg[5] = _mm512_loadu_ps( (__m512 const *) ( a + ( ic * rs_a ) + ( ( kr + 5 ) * cs_a ) ) ); + a_reg[6] = _mm512_loadu_ps( (__m512 const *) ( a + ( ic * rs_a ) + ( ( kr + 6 ) * cs_a ) ) ); + a_reg[7] = _mm512_loadu_ps( (__m512 const *) ( a + ( ic * rs_a ) + ( ( kr + 7 ) * cs_a ) ) ); + a_reg[8] = _mm512_loadu_ps( (__m512 const *) ( a + ( ic * rs_a ) + ( ( kr + 8 ) * cs_a ) ) ); + a_reg[9] = _mm512_loadu_ps( (__m512 const *) ( a + ( ic * rs_a ) + ( ( kr + 9 ) * cs_a ) ) ); + a_reg[10] = _mm512_loadu_ps( (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 10 ) * cs_a ) ) ); + a_reg[11] = _mm512_loadu_ps( (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 11 ) * cs_a ) ) ); + a_reg[12] = _mm512_loadu_ps( (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 12 ) * cs_a ) ) ); + a_reg[13] = _mm512_loadu_ps( (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 13 ) * cs_a ) ) ); + a_reg[14] = _mm512_loadu_ps( (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 14 ) * cs_a ) ) ); + a_reg[15] = _mm512_loadu_ps( (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 15 ) * cs_a ) ) ); + + UNPACKLO_PS16 + UNPACKHI_PS16 + SHUFFLE_64x2 + PERMUTE4x4( mask1, mask2 ); + PERMUTE8x8( mask3, mask4 ) + + _mm512_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 0 ) * KC + kr ), a_reg[0] ); + _mm512_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 1 ) * KC + kr ), a_reg[1] ); + _mm512_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 2 ) * KC + kr ), a_reg[2] ); + _mm512_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 3 ) * KC + kr ), a_reg[3] ); + _mm512_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 4 ) * KC + kr ), a_reg[4] ); + _mm512_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 5 ) * KC + kr ), a_reg[5] ); + _mm512_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 6 ) * KC + kr ), a_reg[6] ); + _mm512_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 7 ) * KC + kr ), a_reg[7] ); + _mm512_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 8 ) * KC + kr ), a_reg[8] ); + _mm512_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 9 ) * KC + kr ), a_reg[9] ); + _mm512_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 10 ) * KC + kr ), a_reg[10] ); + _mm512_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 11 ) * KC + kr ), a_reg[11] ); + _mm512_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 12 ) * KC + kr ), a_reg[12] ); + _mm512_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 13 ) * KC + kr ), a_reg[13] ); + _mm512_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 14 ) * KC + kr ), a_reg[14] ); + _mm512_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 15 ) * KC + kr ), a_reg[15] ); + } + for ( ; ( kr + 7 ) < KC; kr += 8 ) + { + a_reg[0] = _mm512_loadu_ps( (__m512 const *)( a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ) ); + a_reg[1] = _mm512_loadu_ps( (__m512 const *)( a + ( ic * rs_a ) + ( ( kr + 1 ) * cs_a ) ) ); + a_reg[2] = _mm512_loadu_ps( (__m512 const *)( a + ( ic * rs_a ) + ( ( kr + 2 ) * cs_a ) ) ); + a_reg[3] = _mm512_loadu_ps( (__m512 const *)( a + ( ic * rs_a ) + ( ( kr + 3 ) * cs_a ) ) ); + a_reg[4] = _mm512_loadu_ps( (__m512 const *)( a + ( ic * rs_a ) + ( ( kr + 4 ) * cs_a ) ) ); + a_reg[5] = _mm512_loadu_ps( (__m512 const *)( a + ( ic * rs_a ) + ( ( kr + 5 ) * cs_a ) ) ); + a_reg[6] = _mm512_loadu_ps( (__m512 const *)( a + ( ic * rs_a ) + ( ( kr + 6 ) * cs_a ) ) ); + a_reg[7] = _mm512_loadu_ps( (__m512 const *)( a + ( ic * rs_a ) + ( ( kr + 7 ) * cs_a ) ) ); + + UNPACKLO_PS16 + UNPACKHI_PS16 + SHUFFLE_64x2 + PERMUTE4x4( mask1, mask2 ) + PERMUTE8x8(mask3, mask4) + MASKED_STORE_PS(0xFF); + } + for( ; ( kr + 3 ) < KC; kr += 4) + { + a_reg[0] = _mm512_loadu_ps( (__m512 const *)( a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ) ); + a_reg[1] = _mm512_loadu_ps( (__m512 const *)( a + ( ic * rs_a ) + ( ( kr + 1 ) * cs_a ) ) ); + a_reg[2] = _mm512_loadu_ps( (__m512 const *)( a + ( ic * rs_a ) + ( ( kr + 2 ) * cs_a ) ) ); + a_reg[3] = _mm512_loadu_ps( (__m512 const *)( a + ( ic * rs_a ) + ( ( kr + 3 ) * cs_a ) ) ); + + UNPACKLO_PS16 + UNPACKHI_PS16 + SHUFFLE_64x2 + PERMUTE4x4( mask1, mask2 ) + PERMUTE8x8( mask3, mask4 ) + MASKED_STORE_PS(0x0F); + } + for( ; ( kr + 1 ) < KC; kr += 2) + { + a_reg[0] = _mm512_loadu_ps( (__m512 const *)( a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ) ); + a_reg[1] = _mm512_loadu_ps( (__m512 const *)( a + ( ic * rs_a ) + ( ( kr + 1 ) * cs_a ) ) ); + + UNPACKLO_PS16 + UNPACKHI_PS16 + SHUFFLE_64x2 + PERMUTE4x4( mask1, mask2 ) + PERMUTE8x8(mask3, mask4) + MASKED_STORE_PS(0x03); + } + for( ; ( kr ) < KC; kr += 1) + { + a_reg[0] = _mm512_loadu_ps( (__m512 const *)(a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ) ); + + UNPACKLO_PS16 + UNPACKHI_PS16 + SHUFFLE_64x2 + PERMUTE4x4( mask1, mask2 ) + PERMUTE8x8(mask3, mask4) + MASKED_STORE_PS(0x01); + } + } + for( ; (ic + 8 - 1) < MC; ic += 8) + { + for( kr = 0; ( kr + 15 ) < KC; kr += 16) + { + a_reg[0] = _mm512_maskz_loadu_ps( 0xFF, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 0 ) * cs_a ) ) ); + a_reg[1] = _mm512_maskz_loadu_ps( 0xFF, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 1 ) * cs_a ) ) ); + a_reg[2] = _mm512_maskz_loadu_ps( 0xFF, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 2 ) * cs_a ) ) ); + a_reg[3] = _mm512_maskz_loadu_ps( 0xFF, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 3 ) * cs_a ) ) ); + a_reg[4] = _mm512_maskz_loadu_ps( 0xFF, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 4 ) * cs_a ) ) ); + a_reg[5] = _mm512_maskz_loadu_ps( 0xFF, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 5 ) * cs_a ) ) ); + a_reg[6] = _mm512_maskz_loadu_ps( 0xFF, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 6 ) * cs_a ) ) ); + a_reg[7] = _mm512_maskz_loadu_ps( 0xFF, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 7 ) * cs_a ) ) ); + a_reg[8] = _mm512_maskz_loadu_ps( 0xFF, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 8 ) * cs_a ) ) ); + a_reg[9] = _mm512_maskz_loadu_ps( 0xFF, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 9 ) * cs_a ) ) ); + a_reg[10] = _mm512_maskz_loadu_ps( 0xFF, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 10 ) * cs_a ) ) ); + a_reg[11] = _mm512_maskz_loadu_ps( 0xFF, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 11 ) * cs_a ) ) ); + a_reg[12] = _mm512_maskz_loadu_ps( 0xFF, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 12 ) * cs_a ) ) ); + a_reg[13] = _mm512_maskz_loadu_ps( 0xFF, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 13 ) * cs_a ) ) ); + a_reg[14] = _mm512_maskz_loadu_ps( 0xFF, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 14 ) * cs_a ) ) ); + a_reg[15] = _mm512_maskz_loadu_ps( 0xFF, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 15 ) * cs_a ) ) ); + + UNPACKLO_PS16 + UNPACKHI_PS16 + SHUFFLE_64x2 + PERMUTE4x4( mask1, mask2 ) + PERMUTE8x8( mask3, mask4 ) + _mm512_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 0 ) * KC + kr ), a_reg[0] ); + _mm512_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 1 ) * KC + kr ), a_reg[1] ); + _mm512_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 2 ) * KC + kr ), a_reg[2] ); + _mm512_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 3 ) * KC + kr ), a_reg[3] ); + _mm512_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 4 ) * KC + kr ), a_reg[4] ); + _mm512_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 5 ) * KC + kr ), a_reg[5] ); + _mm512_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 6 ) * KC + kr ), a_reg[6] ); + _mm512_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 7 ) * KC + kr ), a_reg[7] ); + } + for( ; ( kr + 7 ) < KC; kr += 8) + { + a_reg[0] = (__m512)_mm512_maskz_loadu_ps( 0xFF, a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ); + a_reg[1] = (__m512)_mm512_maskz_loadu_ps( 0xFF, a + ( ic * rs_a ) + ( ( kr + 1 ) * cs_a ) ); + a_reg[2] = (__m512)_mm512_maskz_loadu_ps( 0xFF, a + ( ic * rs_a ) + ( ( kr + 2 ) * cs_a ) ); + a_reg[3] = (__m512)_mm512_maskz_loadu_ps( 0xFF, a + ( ic * rs_a ) + ( ( kr + 3 ) * cs_a ) ); + a_reg[4] = (__m512)_mm512_maskz_loadu_ps( 0xFF, a + ( ic * rs_a ) + ( ( kr + 4 ) * cs_a ) ); + a_reg[5] = (__m512)_mm512_maskz_loadu_ps( 0xFF, a + ( ic * rs_a ) + ( ( kr + 5 ) * cs_a ) ); + a_reg[6] = (__m512)_mm512_maskz_loadu_ps( 0xFF, a + ( ic * rs_a ) + ( ( kr + 6 ) * cs_a ) ); + a_reg[7] = (__m512)_mm512_maskz_loadu_ps( 0xFF, a + ( ic * rs_a ) + ( ( kr + 7 ) * cs_a ) ); + + UNPACKLO_PS16 + UNPACKHI_PS16 + SHUFFLE_64x2 + PERMUTE4x4(mask1, mask2) + PERMUTE8x8(mask3, mask4) + + _mm512_mask_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 0 ) * KC + kr ), 0xFF, a_reg[0] ); + _mm512_mask_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 1 ) * KC + kr ), 0xFF, a_reg[1] ); + _mm512_mask_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 2 ) * KC + kr ), 0xFF, a_reg[2] ); + _mm512_mask_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 3 ) * KC + kr ), 0xFF, a_reg[3] ); + _mm512_mask_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 4 ) * KC + kr ), 0xFF, a_reg[4] ); + _mm512_mask_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 5 ) * KC + kr ), 0xFF, a_reg[5] ); + _mm512_mask_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 6 ) * KC + kr ), 0xFF, a_reg[6] ); + _mm512_mask_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 7 ) * KC + kr ), 0xFF, a_reg[7] ); + } + for( ; ( kr + 3 ) < KC; kr += 4) + { + a_reg[0] = (__m512)_mm512_maskz_loadu_ps( 0xFF, a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ); + a_reg[1] = (__m512)_mm512_maskz_loadu_ps( 0xFF, a + ( ic * rs_a ) + ( ( kr + 1 ) * cs_a ) ); + a_reg[2] = (__m512)_mm512_maskz_loadu_ps( 0xFF, a + ( ic * rs_a ) + ( ( kr + 2 ) * cs_a ) ); + a_reg[3] = (__m512)_mm512_maskz_loadu_ps( 0xFF, a + ( ic * rs_a ) + ( ( kr + 3 ) * cs_a ) ); + + UNPACKLO_PS16 + UNPACKHI_PS16 + SHUFFLE_64x2 + PERMUTE4x4( mask1, mask2 ) + PERMUTE8x8( mask3, mask4 ) + _mm512_mask_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 0 ) * KC + kr ), 0x0F, a_reg[0] ); + _mm512_mask_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 1 ) * KC + kr ), 0x0F, a_reg[1] ); + _mm512_mask_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 2 ) * KC + kr ), 0x0F, a_reg[2] ); + _mm512_mask_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 3 ) * KC + kr ), 0x0F, a_reg[3] ); + _mm512_mask_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 4 ) * KC + kr ), 0x0F, a_reg[4] ); + _mm512_mask_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 5 ) * KC + kr ), 0x0F, a_reg[5] ); + _mm512_mask_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 6 ) * KC + kr ), 0x0F, a_reg[6] ); + _mm512_mask_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 7 ) * KC + kr ), 0x0F, a_reg[7] ); + } + for( ; ( kr + 1 ) < KC; kr += 2) + { + a_reg[0] = (__m512)_mm512_maskz_loadu_ps( 0xFF, a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ); + a_reg[1] = (__m512)_mm512_maskz_loadu_ps( 0xFF, a + ( ic * rs_a ) + ( ( kr + 1 ) * cs_a ) ); + + UNPACKLO_PS16 + UNPACKHI_PS16 + SHUFFLE_64x2 + PERMUTE4x4( mask1, mask2 ) + PERMUTE8x8( mask3, mask4 ) + + _mm512_mask_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 0 ) * KC + kr ), 0x03, a_reg[0] ); + _mm512_mask_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 1 ) * KC + kr ), 0x03, a_reg[1] ); + _mm512_mask_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 2 ) * KC + kr ), 0x03, a_reg[2] ); + _mm512_mask_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 3 ) * KC + kr ), 0x03, a_reg[3] ); + _mm512_mask_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 4 ) * KC + kr ), 0x03, a_reg[4] ); + _mm512_mask_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 5 ) * KC + kr ), 0x03, a_reg[5] ); + _mm512_mask_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 6 ) * KC + kr ), 0x03, a_reg[6] ); + _mm512_mask_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 7 ) * KC + kr ), 0x03, a_reg[7] ); + + } + for( ; ( kr ) < KC; kr += 1) + { + a_reg[0] = (__m512)_mm512_maskz_loadu_ps( 0xFF, a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ); + + UNPACKLO_PS16 + UNPACKHI_PS16 + SHUFFLE_64x2 + PERMUTE4x4( mask1, mask2 ) + PERMUTE8x8( mask3, mask4 ) + + _mm512_mask_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 0 ) * KC + kr ), 0x01, a_reg[0] ); + _mm512_mask_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 1 ) * KC + kr ), 0x01, a_reg[1] ); + _mm512_mask_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 2 ) * KC + kr ), 0x01, a_reg[2] ); + _mm512_mask_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 3 ) * KC + kr ), 0x01, a_reg[3] ); + _mm512_mask_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 4 ) * KC + kr ), 0x01, a_reg[4] ); + _mm512_mask_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 5 ) * KC + kr ), 0x01, a_reg[5] ); + _mm512_mask_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 6 ) * KC + kr ), 0x01, a_reg[6] ); + _mm512_mask_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 7 ) * KC + kr ), 0x01, a_reg[7] ); + } + } + for( ; ( ic + 4 - 1 ) < MC; ic += 4) + { + for( kr = 0; ( kr + 15 ) < KC; kr += 16) + { + a_reg[0] = _mm512_maskz_loadu_ps ( 0x0F, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 0 ) * cs_a ) ) ); + a_reg[1] = _mm512_maskz_loadu_ps ( 0x0F, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 1 ) * cs_a ) ) ); + a_reg[2] = _mm512_maskz_loadu_ps ( 0x0F, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 2 ) * cs_a ) ) ); + a_reg[3] = _mm512_maskz_loadu_ps ( 0x0F, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 3 ) * cs_a ) ) ); + a_reg[4] = _mm512_maskz_loadu_ps ( 0x0F, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 4 ) * cs_a ) ) ); + a_reg[5] = _mm512_maskz_loadu_ps ( 0x0F, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 5 ) * cs_a ) ) ); + a_reg[6] = _mm512_maskz_loadu_ps ( 0x0F, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 6 ) * cs_a ) ) ); + a_reg[7] = _mm512_maskz_loadu_ps ( 0x0F, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 7 ) * cs_a ) ) ); + a_reg[8] = _mm512_maskz_loadu_ps ( 0x0F, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 8 ) * cs_a ) ) ); + a_reg[9] = _mm512_maskz_loadu_ps ( 0x0F, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 9 ) * cs_a ) ) ); + a_reg[10] = _mm512_maskz_loadu_ps ( 0x0F, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 10 ) * cs_a ) ) ); + a_reg[11] = _mm512_maskz_loadu_ps ( 0x0F, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 11 ) * cs_a ) ) ); + a_reg[12] = _mm512_maskz_loadu_ps ( 0x0F, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 12 ) * cs_a ) ) ); + a_reg[13] = _mm512_maskz_loadu_ps ( 0x0F, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 13 ) * cs_a ) ) ); + a_reg[14] = _mm512_maskz_loadu_ps ( 0x0F, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 14 ) * cs_a ) ) ); + a_reg[15] = _mm512_maskz_loadu_ps ( 0x0F, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 15 ) * cs_a ) ) ); + + UNPACKLO_PS16 + UNPACKHI_PS16 + SHUFFLE_64x2 + PERMUTE4x4( mask1, mask2 ) + PERMUTE8x8( mask3, mask4 ) + + _mm512_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 0 ) * KC + kr ), a_reg[0] ); + _mm512_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 1 ) * KC + kr ), a_reg[1] ); + _mm512_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 2 ) * KC + kr ), a_reg[2] ); + _mm512_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 3 ) * KC + kr ), a_reg[3] ); + } + for( ; ( kr + 7 ) < KC; kr += 8) + { + a_reg[0] = (__m512)_mm512_maskz_loadu_ps( 0xFF, a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ); + a_reg[1] = (__m512)_mm512_maskz_loadu_ps( 0xFF, a + ( ic * rs_a ) + ( ( kr + 1 ) * cs_a ) ); + a_reg[2] = (__m512)_mm512_maskz_loadu_ps( 0xFF, a + ( ic * rs_a ) + ( ( kr + 2 ) * cs_a ) ); + a_reg[3] = (__m512)_mm512_maskz_loadu_ps( 0xFF, a + ( ic * rs_a ) + ( ( kr + 3 ) * cs_a ) ); + a_reg[4] = (__m512)_mm512_maskz_loadu_ps( 0xFF, a + ( ic * rs_a ) + ( ( kr + 4 ) * cs_a ) ); + a_reg[5] = (__m512)_mm512_maskz_loadu_ps( 0xFF, a + ( ic * rs_a ) + ( ( kr + 5 ) * cs_a ) ); + a_reg[6] = (__m512)_mm512_maskz_loadu_ps( 0xFF, a + ( ic * rs_a ) + ( ( kr + 6 ) * cs_a ) ); + a_reg[7] = (__m512)_mm512_maskz_loadu_ps( 0xFF, a + ( ic * rs_a ) + ( ( kr + 7 ) * cs_a ) ); + + UNPACKLO_PS16 + UNPACKHI_PS16 + SHUFFLE_64x2 + PERMUTE4x4(mask1, mask2) + PERMUTE8x8(mask3, mask4) + + _mm512_mask_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 0 ) * KC + kr ), 0xFF, a_reg[0] ); + _mm512_mask_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 1 ) * KC + kr ), 0xFF, a_reg[1] ); + _mm512_mask_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 2 ) * KC + kr ), 0xFF, a_reg[2] ); + _mm512_mask_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 3 ) * KC + kr ), 0xFF, a_reg[3] ); + } + for( ; ( kr + 3 ) < KC; kr += 4) + { + a_reg[0] = (__m512)_mm512_maskz_loadu_ps( 0xFF, a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ); + a_reg[1] = (__m512)_mm512_maskz_loadu_ps( 0xFF, a + ( ic * rs_a ) + ( ( kr + 1 ) * cs_a ) ); + a_reg[2] = (__m512)_mm512_maskz_loadu_ps( 0xFF, a + ( ic * rs_a ) + ( ( kr + 2 ) * cs_a ) ); + a_reg[3] = (__m512)_mm512_maskz_loadu_ps( 0xFF, a + ( ic * rs_a ) + ( ( kr + 3 ) * cs_a ) ); + + UNPACKLO_PS16 + UNPACKHI_PS16 + SHUFFLE_64x2 + PERMUTE4x4( mask1, mask2 ) + PERMUTE8x8( mask3, mask4 ) + + _mm512_mask_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 0 ) * KC + kr ), 0x0F, a_reg[0] ); + _mm512_mask_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 1 ) * KC + kr ), 0x0F, a_reg[1] ); + _mm512_mask_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 2 ) * KC + kr ), 0x0F, a_reg[2] ); + _mm512_mask_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 3 ) * KC + kr ), 0x0F, a_reg[3] ); + } + for( ; ( kr + 1 ) < KC; kr += 2) + { + a_reg[0] = (__m512)_mm512_maskz_loadu_ps( 0xFF, a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ); + a_reg[1] = (__m512)_mm512_maskz_loadu_ps( 0xFF, a + ( ic * rs_a ) + ( ( kr + 1 ) * cs_a ) ); + + UNPACKLO_PS16 + UNPACKHI_PS16 + SHUFFLE_64x2 + PERMUTE4x4( mask1, mask2 ) + PERMUTE8x8( mask3, mask4 ) + + _mm512_mask_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 0 ) * KC + kr ), 0x03, a_reg[0] ); + _mm512_mask_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 1 ) * KC + kr ), 0x03, a_reg[1] ); + _mm512_mask_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 2 ) * KC + kr ), 0x03, a_reg[2] ); + _mm512_mask_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 3 ) * KC + kr ), 0x03, a_reg[3] ); + } + for( ; ( kr ) < KC; kr += 1) + { + a_reg[0] = (__m512)_mm512_maskz_loadu_ps( 0xFF, a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ); + + UNPACKLO_PS16 + UNPACKHI_PS16 + SHUFFLE_64x2 + PERMUTE4x4( mask1, mask2 ) + PERMUTE8x8( mask3, mask4 ) + + _mm512_mask_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 0 ) * KC + kr ), 0x01, a_reg[0] ); + _mm512_mask_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 1 ) * KC + kr ), 0x01, a_reg[1] ); + _mm512_mask_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 2 ) * KC + kr ), 0x01, a_reg[2] ); + _mm512_mask_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 3 ) * KC + kr ), 0x01, a_reg[3] ); + } + } + if( m_left ) { + __mmask16 mask = 0xFFFF >> ( 16 - m_left ); + for( kr = 0; ( kr + 15 ) < KC; kr += 16) + { + a_reg[0] = _mm512_maskz_loadu_ps ( mask, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 0 ) * cs_a ) ) ); + a_reg[1] = _mm512_maskz_loadu_ps ( mask, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 1 ) * cs_a ) ) ); + a_reg[2] = _mm512_maskz_loadu_ps ( mask, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 2 ) * cs_a ) ) ); + a_reg[3] = _mm512_maskz_loadu_ps ( mask, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 3 ) * cs_a ) ) ); + a_reg[4] = _mm512_maskz_loadu_ps ( mask, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 4 ) * cs_a ) ) ); + a_reg[5] = _mm512_maskz_loadu_ps ( mask, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 5 ) * cs_a ) ) ); + a_reg[6] = _mm512_maskz_loadu_ps ( mask, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 6 ) * cs_a ) ) ); + a_reg[7] = _mm512_maskz_loadu_ps ( mask, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 7 ) * cs_a ) ) ); + a_reg[8] = _mm512_maskz_loadu_ps ( mask, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 8 ) * cs_a ) ) ); + a_reg[9] = _mm512_maskz_loadu_ps ( mask, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 9 ) * cs_a ) ) ); + a_reg[10] = _mm512_maskz_loadu_ps ( mask, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 10 ) * cs_a ) ) ); + a_reg[11] = _mm512_maskz_loadu_ps ( mask, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 11 ) * cs_a ) ) ); + a_reg[12] = _mm512_maskz_loadu_ps ( mask, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 12 ) * cs_a ) ) ); + a_reg[13] = _mm512_maskz_loadu_ps ( mask, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 13 ) * cs_a ) ) ); + a_reg[14] = _mm512_maskz_loadu_ps ( mask, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 14 ) * cs_a ) ) ); + a_reg[15] = _mm512_maskz_loadu_ps ( mask, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 15 ) * cs_a ) ) ); + + UNPACKLO_PS16 + UNPACKHI_PS16 + SHUFFLE_64x2 + PERMUTE4x4( mask1, mask2 ) + PERMUTE8x8( mask3, mask4 ) + + switch( m_left ) + { + case 3: + _mm512_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 0 ) * KC + kr ), a_reg[0] ); + _mm512_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 1 ) * KC + kr ), a_reg[1] ); + _mm512_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 2 ) * KC + kr ), a_reg[2] ); + break; + + case 2: + _mm512_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 0 ) * KC + kr ), a_reg[0] ); + _mm512_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 1 ) * KC + kr ), a_reg[1] ); + break; + + case 1: + _mm512_storeu_ps( (__m512 *)( pack_a_buffer + ( ic + 0 ) * KC + kr ), a_reg[0] ); + break; + } + } + for( ; ( kr + 7 ) < KC; kr += 8) + { + a_reg[0] = _mm512_maskz_loadu_ps ( mask, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 0 ) * cs_a ) ) ); + a_reg[1] = _mm512_maskz_loadu_ps ( mask, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 1 ) * cs_a ) ) ); + a_reg[2] = _mm512_maskz_loadu_ps ( mask, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 2 ) * cs_a ) ) ); + a_reg[3] = _mm512_maskz_loadu_ps ( mask, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 3 ) * cs_a ) ) ); + a_reg[4] = _mm512_maskz_loadu_ps ( mask, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 4 ) * cs_a ) ) ); + a_reg[5] = _mm512_maskz_loadu_ps ( mask, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 5 ) * cs_a ) ) ); + a_reg[6] = _mm512_maskz_loadu_ps ( mask, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 6 ) * cs_a ) ) ); + a_reg[7] = _mm512_maskz_loadu_ps ( mask, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 7 ) * cs_a ) ) ); + + UNPACKLO_PS16 + UNPACKHI_PS16 + SHUFFLE_64x2 + PERMUTE4x4( mask1, mask2 ) + PERMUTE8x8( mask3, mask4 ) + + switch( m_left ) + { + case 3: + _mm512_mask_storeu_ps((__m512 *)( pack_a_buffer + ( ic + 0 ) * KC + kr ), 0xFF, a_reg[0]); + _mm512_mask_storeu_ps((__m512 *)( pack_a_buffer + ( ic + 1 ) * KC + kr ), 0xFF, a_reg[1]); + _mm512_mask_storeu_ps((__m512 *)( pack_a_buffer + ( ic + 2 ) * KC + kr ), 0xFF, a_reg[2]); + break; + + case 2: + _mm512_mask_storeu_ps((__m512 *)( pack_a_buffer + ( ic + 0 ) * KC + kr ), 0xFF, a_reg[0]); + _mm512_mask_storeu_ps((__m512 *)( pack_a_buffer + ( ic + 1 ) * KC + kr ), 0xFF, a_reg[1]); + break; + + case 1: + _mm512_mask_storeu_ps((__m512 *)( pack_a_buffer + ( ic + 0 ) * KC + kr ), 0xFF, a_reg[0]); + break; + } + } + for( ; ( kr + 3 ) < KC; kr += 4) + { + a_reg[0] = _mm512_maskz_loadu_ps ( mask, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 0 ) * cs_a ) ) ); + a_reg[1] = _mm512_maskz_loadu_ps ( mask, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 1 ) * cs_a ) ) ); + a_reg[2] = _mm512_maskz_loadu_ps ( mask, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 2 ) * cs_a ) ) ); + a_reg[3] = _mm512_maskz_loadu_ps ( mask, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 3 ) * cs_a ) ) ); + + UNPACKLO_PS16 + UNPACKHI_PS16 + SHUFFLE_64x2 + PERMUTE4x4( mask1, mask2 ) + PERMUTE8x8( mask3, mask4 ) + + switch( m_left ) + { + case 3: + _mm512_mask_storeu_ps((__m512 *)( pack_a_buffer + ( ic + 0 ) * KC + kr ), 0x0F, a_reg[0]); + _mm512_mask_storeu_ps((__m512 *)( pack_a_buffer + ( ic + 1 ) * KC + kr ), 0x0F, a_reg[1]); + _mm512_mask_storeu_ps((__m512 *)( pack_a_buffer + ( ic + 2 ) * KC + kr ), 0x0F, a_reg[2]); + break; + + case 2: + _mm512_mask_storeu_ps((__m512 *)( pack_a_buffer + ( ic + 0 ) * KC + kr ), 0x0F, a_reg[0]); + _mm512_mask_storeu_ps((__m512 *)( pack_a_buffer + ( ic + 1 ) * KC + kr ), 0x0F, a_reg[1]); + break; + + case 1: + _mm512_mask_storeu_ps((__m512 *)( pack_a_buffer + ( ic + 0 ) * KC + kr ), 0x0F, a_reg[0]); + break; + } + } + for( ; ( kr + 1 ) < KC; kr += 2) + { + a_reg[0] = _mm512_maskz_loadu_ps ( mask, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 0 ) * cs_a ) ) ); + a_reg[1] = _mm512_maskz_loadu_ps ( mask, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 1 ) * cs_a ) ) ); + + UNPACKLO_PS16 + UNPACKHI_PS16 + SHUFFLE_64x2 + PERMUTE4x4( mask1, mask2 ) + PERMUTE8x8( mask3, mask4 ) + + switch( m_left ) + { + case 3: + _mm512_mask_storeu_ps((__m512 *)( pack_a_buffer + ( ic + 0 ) * KC + kr ), 0x03, a_reg[0]); + _mm512_mask_storeu_ps((__m512 *)( pack_a_buffer + ( ic + 1 ) * KC + kr ), 0x03, a_reg[1]); + _mm512_mask_storeu_ps((__m512 *)( pack_a_buffer + ( ic + 2 ) * KC + kr ), 0x03, a_reg[2]); + break; + + case 2: + _mm512_mask_storeu_ps((__m512 *)( pack_a_buffer + ( ic + 0 ) * KC + kr ), 0x03, a_reg[0]); + _mm512_mask_storeu_ps((__m512 *)( pack_a_buffer + ( ic + 1 ) * KC + kr ), 0x03, a_reg[1]); + break; + + case 1: + _mm512_mask_storeu_ps((__m512 *)( pack_a_buffer + ( ic + 0 ) * KC + kr ), 0x03, a_reg[0]); + break; + } + } + for( ; ( kr ) < KC; kr += 1) + { + a_reg[0] = _mm512_maskz_loadu_ps ( mask, (__m512 const *) ( a + ( ic * rs_a ) + + ( ( kr + 0 ) * cs_a ) ) ); + + UNPACKLO_PS16 + UNPACKHI_PS16 + SHUFFLE_64x2 + PERMUTE4x4( mask1, mask2 ) + PERMUTE8x8( mask3, mask4 ) + + switch( m_left ) + { + case 3: + _mm512_mask_storeu_ps((__m512 *)( pack_a_buffer + ( ic + 0 ) * KC + kr ), 0x01, a_reg[0]); + _mm512_mask_storeu_ps((__m512 *)( pack_a_buffer + ( ic + 1 ) * KC + kr ), 0x01, a_reg[1]); + _mm512_mask_storeu_ps((__m512 *)( pack_a_buffer + ( ic + 2 ) * KC + kr ), 0x01, a_reg[2]); + break; + + case 2: + _mm512_mask_storeu_ps((__m512 *)( pack_a_buffer + ( ic + 0 ) * KC + kr ), 0x01, a_reg[0]); + _mm512_mask_storeu_ps((__m512 *)( pack_a_buffer + ( ic + 1 ) * KC + kr ), 0x01, a_reg[1]); + break; + + case 1: + _mm512_mask_storeu_ps((__m512 *)( pack_a_buffer + ( ic + 0 ) * KC + kr ), 0x01, a_reg[0]); + break; + } + } + } + *rs_p = KC; + *cs_p = 1; +} +#endif diff --git a/kernels/zen4/lpgemm/f32f32f32/lpgemv_m_kernel_f32_avx512.c b/kernels/zen4/lpgemm/f32f32f32/lpgemv_m_kernel_f32_avx512.c new file mode 100644 index 0000000000..59b934f207 --- /dev/null +++ b/kernels/zen4/lpgemm/f32f32f32/lpgemv_m_kernel_f32_avx512.c @@ -0,0 +1,435 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ +#include "immintrin.h" +#include "xmmintrin.h" +#include "blis.h" + +#ifdef BLIS_ADDON_LPGEMM + +#include "lpgemm_kernel_macros_f32.h" + +LPGEMV_M_EQ1_KERN( float, float, float, f32f32f32of32 ) +{ + static void *post_ops_labels[] = + { + &&POST_OPS_6x64F_DISABLE, + &&POST_OPS_BIAS_6x64F, + &&POST_OPS_RELU_6x64F, + &&POST_OPS_RELU_SCALE_6x64F, + &&POST_OPS_GELU_TANH_6x64F, + &&POST_OPS_GELU_ERF_6x64F, + &&POST_OPS_CLIP_6x64F, + NULL, // Virtual node for downscale, else segfault + &&POST_OPS_MATRIX_ADD_6x64F, + &&POST_OPS_SWISH_6x64F, + &&POST_OPS_MATRIX_MUL_6x64F + }; + + // Strides are updated based on matrix packing/reordering. + const float *a_use = NULL; + const float *b_use = NULL; + float *c_use = NULL; + + lpgemm_post_op_attr post_ops_attr = *(post_op_attr); + + for (dim_t jr = 0; jr < n0; jr += NR) + { + dim_t nr0 = bli_min((n0 - jr), NR); + c_use = c + jr; + __mmask16 k1 = 0xFFFF, k2 = 0xFFFF, k3 = 0xFFFF, k4 = 0xFFFF; + + if (nr0 < NR) + { + __mmask16 k = (0xFFFF >> (16 - (nr0 & 0x0F))); + if (nr0 >= 48) + { + k4 = k; + } + else if (nr0 >= 32) + { + k3 = k; + k4 = 0; + } + else if (nr0 >= 16) + { + k2 = k; + k3 = k4 = 0; + } + else + { + k1 = k; + k2 = k3 = k4 = 0; + } + } + + __m512 zmm0, zmm1, zmm2, zmm3, zmm4, zmm5, zmm6, zmm7; + __m512 zmm8, zmm9, zmm10, zmm11, zmm12, zmm13, zmm14; + __m512 zmm15, zmm16, zmm17, zmm18, zmm19, zmm20, zmm21; + __m512 zmm22, zmm23, zmm24, zmm25, zmm26, zmm27, zmm28; + __m512 zmm29, zmm30, zmm31; + + // zero the accumulator registers + ZERO_ACC_ZMM_4_REG(zmm8, zmm9, zmm10, zmm11); + ZERO_ACC_ZMM_4_REG(zmm12, zmm13, zmm14, zmm15); + ZERO_ACC_ZMM_4_REG(zmm16, zmm17, zmm18, zmm19); + ZERO_ACC_ZMM_4_REG(zmm20, zmm21, zmm22, zmm23); + + //Zero out registers used for mask load to avoid warnings + ZERO_ACC_ZMM_4_REG(zmm0, zmm1, zmm2, zmm3); + ZERO_ACC_ZMM_4_REG(zmm24, zmm25, zmm26, zmm27); + ZERO_ACC_ZMM_4_REG(zmm28, zmm29, zmm30, zmm31); + + //_mm_prefetch( (MR X NR) from C + _mm_prefetch((c_use + 0 * rs_c), _MM_HINT_T0); + _mm_prefetch((c_use + 16 * rs_c), _MM_HINT_T0); + _mm_prefetch((c_use + 32 * rs_c), _MM_HINT_T0); + _mm_prefetch((c_use + 64 * rs_c), _MM_HINT_T0); + + for (dim_t pc = 0; pc < k; pc += KC) + { + dim_t kc0 = bli_min((k - pc), KC); + uint64_t k_iter = kc0 / 4; + uint64_t k_rem = kc0 % 4; + dim_t ps_b_use = 0; + dim_t rs_b_use = NR; + // No parallelization in k dim, k always starts at 0. + if (mtag_b == REORDERED||mtag_b == PACK) + { + // In multi-threaded scenarios, an extra offset into a given + // packed B panel is required, since the jc loop split can + // result in per thread start offset inside the panel, instead + // of panel boundaries. + b_use = b + (n_sub_updated * pc) + (jc_cur_loop_rem * kc0); + ps_b_use = kc0; + } + else + { + b_use = b + (pc * rs_b); + ps_b_use = 1; + rs_b_use = rs_b; + } + + a_use = a + pc; + b_use = b_use + jr * ps_b_use; + + for (dim_t k = 0; k < k_iter; k++) + { + _mm_prefetch((b_use + 4 * rs_b_use), _MM_HINT_T0); + //Using mask loads to avoid writing fringe kernels + + //Load first 4x16 tile from row 0-3 + zmm0 = _mm512_maskz_loadu_ps(k1, b_use); + zmm1 = _mm512_maskz_loadu_ps(k1, b_use + rs_b_use); + zmm2 = _mm512_maskz_loadu_ps(k1, b_use + 2 * rs_b_use); + zmm3 = _mm512_maskz_loadu_ps(k1, b_use + 3 * rs_b_use); + b_use += 16; + + //Broadcast col0 - col3 element of A + zmm4 = _mm512_set1_ps(*(a_use)); // broadcast c0 + zmm5 = _mm512_set1_ps(*(a_use + 1)); // broadcast c1 + zmm6 = _mm512_set1_ps(*(a_use + 2)); // broadcast c2 + zmm7 = _mm512_set1_ps(*(a_use + 3)); // broadcast c3 + + //Load second 4x16 tile from row 0-3 + zmm24 = _mm512_maskz_loadu_ps(k2, b_use); + zmm25 = _mm512_maskz_loadu_ps(k2, b_use + rs_b_use); + zmm26 = _mm512_maskz_loadu_ps(k2, b_use + 2 * rs_b_use); + zmm27 = _mm512_maskz_loadu_ps(k2, b_use + 3 * rs_b_use); + b_use += 16; + + zmm8 = _mm512_fmadd_ps(zmm0, zmm4, zmm8); + zmm9 = _mm512_fmadd_ps(zmm1, zmm5, zmm9); + zmm10 = _mm512_fmadd_ps(zmm2, zmm6, zmm10); + zmm11 = _mm512_fmadd_ps(zmm3, zmm7, zmm11); + + //Load third 4x16 tile from row 0-3 + zmm0 = _mm512_maskz_loadu_ps(k3, b_use); + zmm1 = _mm512_maskz_loadu_ps(k3, b_use + rs_b_use); + zmm2 = _mm512_maskz_loadu_ps(k3, b_use + 2 * rs_b_use); + zmm3 = _mm512_maskz_loadu_ps(k3, b_use + 3 * rs_b_use); + b_use += 16; + + zmm12 = _mm512_fmadd_ps(zmm24, zmm4, zmm12); + zmm13 = _mm512_fmadd_ps(zmm25, zmm5, zmm13); + zmm14 = _mm512_fmadd_ps(zmm26, zmm6, zmm14); + zmm15 = _mm512_fmadd_ps(zmm27, zmm7, zmm15); + + //Load fourth 4x16 tile from row 0-3 + zmm28 = _mm512_maskz_loadu_ps(k4, b_use); + zmm29 = _mm512_maskz_loadu_ps(k4, b_use + rs_b_use); + zmm30 = _mm512_maskz_loadu_ps(k4, b_use + 2 * rs_b_use); + zmm31 = _mm512_maskz_loadu_ps(k4, b_use + 3 * rs_b_use); + + zmm16 = _mm512_fmadd_ps(zmm0, zmm4, zmm16); + zmm17 = _mm512_fmadd_ps(zmm1, zmm5, zmm17); + zmm18 = _mm512_fmadd_ps(zmm2, zmm6, zmm18); + zmm19 = _mm512_fmadd_ps(zmm3, zmm7, zmm19); + + zmm20 = _mm512_fmadd_ps(zmm28, zmm4, zmm20); + zmm21 = _mm512_fmadd_ps(zmm29, zmm5, zmm21); + zmm22 = _mm512_fmadd_ps(zmm30, zmm6, zmm22); + zmm23 = _mm512_fmadd_ps(zmm31, zmm7, zmm23); + + b_use -= 48; // move b point back to start of KCXNR + b_use += (4 * rs_b_use); + a_use += 4; // move a pointer to next col + } // kloop + + for (dim_t kr = 0; kr < k_rem; kr++) + { + //Load 64 elements from a row of B + zmm0 = _mm512_maskz_loadu_ps(k1, b_use); + zmm1 = _mm512_maskz_loadu_ps(k2, b_use + 16); + zmm2 = _mm512_maskz_loadu_ps(k3, b_use + 32); + zmm3 = _mm512_maskz_loadu_ps(k4, b_use + 48); + + //Broadcast col0 elements of 12 rows of A + zmm4 = _mm512_set1_ps(*(a_use)); // broadcast c0r0 + + zmm8 = _mm512_fmadd_ps(zmm0, zmm4, zmm8); + zmm12 = _mm512_fmadd_ps(zmm1, zmm4, zmm12); + zmm16 = _mm512_fmadd_ps(zmm2, zmm4, zmm16); + zmm20 = _mm512_fmadd_ps(zmm3, zmm4, zmm20); + + b_use += rs_b_use; // move b pointer to next row + a_use++; // move a pointer to next col + } // kloop + } // kc loop + + //SUMUP K untoll output + zmm8 = _mm512_add_ps(zmm9, zmm8); + zmm10 = _mm512_add_ps(zmm11, zmm10); + zmm8 = _mm512_add_ps(zmm10, zmm8); // 16 outputs + + zmm12 = _mm512_add_ps(zmm13, zmm12); + zmm14 = _mm512_add_ps(zmm15, zmm14); + zmm12 = _mm512_add_ps(zmm14, zmm12); // 16 outputs + + zmm16 = _mm512_add_ps(zmm17, zmm16); + zmm18 = _mm512_add_ps(zmm19, zmm18); + zmm16 = _mm512_add_ps(zmm18, zmm16); // 16 outputs + + zmm20 = _mm512_add_ps(zmm21, zmm20); + zmm22 = _mm512_add_ps(zmm23, zmm22); + zmm20 = _mm512_add_ps(zmm22, zmm20); // 16 outputs + + //Mulitply A*B output with alpha + zmm0 = _mm512_set1_ps(alpha); + zmm8 = _mm512_mul_ps(zmm0, zmm8); + zmm12 = _mm512_mul_ps(zmm0, zmm12); + zmm16 = _mm512_mul_ps(zmm0, zmm16); + zmm20 = _mm512_mul_ps(zmm0, zmm20); + + if (beta != 0) + { + const float *_cbuf = c_use; + // load c and multiply with beta and + // add to accumulator and store back + zmm3 = _mm512_set1_ps(beta); + zmm0 = _mm512_maskz_loadu_ps(k1, _cbuf); + zmm8 = _mm512_fmadd_ps(zmm0, zmm3, zmm8); + + zmm1 = _mm512_maskz_loadu_ps(k2, (_cbuf + 16)); + zmm12 = _mm512_fmadd_ps(zmm1, zmm3, zmm12); + + zmm2 = _mm512_maskz_loadu_ps(k3, (_cbuf + 32)); + zmm16 = _mm512_fmadd_ps(zmm2, zmm3, zmm16); + + zmm4 = _mm512_maskz_loadu_ps(k4, (_cbuf + 48)); + zmm20 = _mm512_fmadd_ps(zmm4, zmm3, zmm20); + } + + // Post Ops + post_ops_attr.is_last_k = TRUE; + lpgemm_post_op *post_ops_list_temp = post_op; + POST_OP_LABEL_LASTK_SAFE_JUMP + + POST_OPS_BIAS_6x64F: + { + if ((*(char *)post_ops_list_temp->op_args2 == 'r') || + (*(char *)post_ops_list_temp->op_args2 == 'R')) + { + float* bias_ptr = (float *)post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j; + zmm9 = _mm512_maskz_loadu_ps(k1, bias_ptr + (0 * 16)); + + zmm10 = _mm512_maskz_loadu_ps(k2, bias_ptr + (1 * 16)); + + zmm13 = _mm512_maskz_loadu_ps(k3, bias_ptr + (2 * 16)); + + zmm14 = _mm512_maskz_loadu_ps(k4, bias_ptr + (3 * 16)); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the bias array will be accessed by + // the ic index, and each bias element corresponds to an + // entire row of the transposed output array, instead of an + // entire column. + float bias = (*((float *)post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0)); + + zmm9 = _mm512_set1_ps(bias); + zmm10 = zmm13 = zmm14 = zmm9; + } + // c[0,0-15] + zmm8 = _mm512_add_ps(zmm9, zmm8); + zmm12 = _mm512_add_ps(zmm10, zmm12); + zmm16 = _mm512_add_ps(zmm13, zmm16); + zmm20 = _mm512_add_ps(zmm14, zmm20); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_RELU_6x64F: + { + zmm1 = _mm512_setzero_ps(); + + // c[0,0-15] + zmm8 = _mm512_max_ps(zmm1, zmm8); + zmm12 = _mm512_max_ps(zmm1, zmm12); + zmm16 = _mm512_max_ps(zmm1, zmm16); + zmm20 = _mm512_max_ps(zmm1, zmm20); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_RELU_SCALE_6x64F: + { + zmm1 = _mm512_setzero_ps(); + zmm2 = + _mm512_set1_ps(*((float *)post_ops_list_temp->op_args2)); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32S_AVX512(zmm8) + RELU_SCALE_OP_F32S_AVX512(zmm12) + RELU_SCALE_OP_F32S_AVX512(zmm16) + RELU_SCALE_OP_F32S_AVX512(zmm20) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_GELU_TANH_6x64F: + { + __m512i zmm6; + // c[0, 0-15] + GELU_TANH_F32S_AVX512(zmm8, zmm0, zmm1, zmm2, zmm3, zmm4, zmm5, zmm6) + GELU_TANH_F32S_AVX512(zmm12, zmm0, zmm1, zmm2, zmm3, zmm4, zmm5, zmm6) + GELU_TANH_F32S_AVX512(zmm16, zmm0, zmm1, zmm2, zmm3, zmm4, zmm5, zmm6) + GELU_TANH_F32S_AVX512(zmm20, zmm0, zmm1, zmm2, zmm3, zmm4, zmm5, zmm6) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_GELU_ERF_6x64F: + { + // c[0, 0-15] + GELU_ERF_F32S_AVX512(zmm8, zmm0, zmm1, zmm2) + GELU_ERF_F32S_AVX512(zmm12, zmm0, zmm1, zmm2) + GELU_ERF_F32S_AVX512(zmm16, zmm0, zmm1, zmm2) + GELU_ERF_F32S_AVX512(zmm20, zmm0, zmm1, zmm2) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_CLIP_6x64F: + { + zmm0 = _mm512_set1_ps(*(float *)post_ops_list_temp->op_args2); + zmm1 = _mm512_set1_ps(*(float *)post_ops_list_temp->op_args3); + + // c[0, 0-15] + CLIP_F32S_AVX512(zmm8, zmm0, zmm1) + CLIP_F32S_AVX512(zmm12, zmm0, zmm1) + CLIP_F32S_AVX512(zmm16, zmm0, zmm1) + CLIP_F32S_AVX512(zmm20, zmm0, zmm1) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_MATRIX_ADD_6x64F: + { + float *matptr = (float *)post_ops_list_temp->op_args1; + zmm0 = _mm512_maskz_loadu_ps(k1, (matptr + post_ops_attr.post_op_c_j)); + zmm8 = _mm512_add_ps(zmm8, zmm0); + zmm0 = _mm512_maskz_loadu_ps(k2, (matptr + post_ops_attr.post_op_c_j + 16)); + zmm12 = _mm512_add_ps(zmm12, zmm0); + zmm0 = _mm512_maskz_loadu_ps(k3, (matptr + post_ops_attr.post_op_c_j + 32)); + zmm16 = _mm512_add_ps(zmm16, zmm0); + zmm0 = _mm512_maskz_loadu_ps(k4, (matptr + post_ops_attr.post_op_c_j + 48)); + zmm20 = _mm512_add_ps(zmm20, zmm0); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_MATRIX_MUL_6x64F: + { + float *matptr = (float *)post_ops_list_temp->op_args1; + zmm0 = _mm512_maskz_loadu_ps(k1, (matptr + post_ops_attr.post_op_c_j)); + zmm8 = _mm512_mul_ps(zmm8, zmm0); + zmm0 = _mm512_maskz_loadu_ps(k2, (matptr + post_ops_attr.post_op_c_j + 16)); + zmm12 = _mm512_mul_ps(zmm12, zmm0); + zmm0 = _mm512_maskz_loadu_ps(k3, (matptr + post_ops_attr.post_op_c_j + 32)); + zmm16 = _mm512_mul_ps(zmm16, zmm0); + zmm0 = _mm512_maskz_loadu_ps(k4, (matptr + post_ops_attr.post_op_c_j + 48)); + zmm20 = _mm512_mul_ps(zmm20, zmm0); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_SWISH_6x64F: + { + zmm7 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(zmm8, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[1, 0-15] + SWISH_F32_AVX512_DEF(zmm12, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[2, 0-15] + SWISH_F32_AVX512_DEF(zmm16, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + // c[3, 0-15] + SWISH_F32_AVX512_DEF(zmm20, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_6x64F_DISABLE: + { + _mm512_mask_storeu_ps(c_use, k1, zmm8); + _mm512_mask_storeu_ps((c_use + 16), k2, zmm12); + _mm512_mask_storeu_ps((c_use + 32), k3, zmm16); + _mm512_mask_storeu_ps((c_use + 48), k4, zmm20); + post_ops_attr.post_op_c_j += NR; + } + } // jr loop +} + +#endif // BLIS_ADDON_LPGEMM diff --git a/kernels/zen4/lpgemm/f32f32f32/lpgemv_n_kernel_f32_avx512.c b/kernels/zen4/lpgemm/f32f32f32/lpgemv_n_kernel_f32_avx512.c new file mode 100644 index 0000000000..2e1576995d --- /dev/null +++ b/kernels/zen4/lpgemm/f32f32f32/lpgemv_n_kernel_f32_avx512.c @@ -0,0 +1,512 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ +#include "immintrin.h" +#include "xmmintrin.h" +#include "blis.h" + +#ifdef BLIS_ADDON_LPGEMM + +#include "lpgemm_kernel_macros_f32.h" + +#define LPGEMV_N_KERNEL_4_LOADS(zmm0, zmm1, zmm2, zmm3, paddr, stride) \ + zmm0 = _mm512_loadu_ps(paddr); \ + zmm1 = _mm512_loadu_ps(paddr + stride); \ + zmm2 = _mm512_loadu_ps(paddr + 2 * stride); \ + zmm3 = _mm512_loadu_ps(paddr + 3 * stride); + +#define LPGEMV_N_KERNEL_4_MASKLOADS(zmm0, zmm1, zmm2, zmm3, zmm7, k1, paddr, stride) \ + zmm0 = _mm512_mask_loadu_ps(zmm7, k1, paddr); \ + zmm1 = _mm512_mask_loadu_ps(zmm7, k1, paddr + stride); \ + zmm2 = _mm512_mask_loadu_ps(zmm7, k1, paddr + 2 * stride); \ + zmm3 = _mm512_mask_loadu_ps(zmm7, k1, paddr + 3 * stride); + +#define LPGEMV_N_KERNEL_4_FMA(zmm8, zmm9, zmm10, zmm11, zmm6, zmm0, zmm1, zmm2, zmm3) \ + zmm8 = _mm512_fmadd_ps(zmm0, zmm6, zmm8); \ + zmm9 = _mm512_fmadd_ps(zmm1, zmm6, zmm9); \ + zmm10 = _mm512_fmadd_ps(zmm2, zmm6, zmm10); \ + zmm11 = _mm512_fmadd_ps(zmm3, zmm6, zmm11); + +#define LPGEMV_ZMM2XMM(zmm0, zmm1, zmm2, zmm3, ymm0, ymm1, ymm2, ymm3, xmm0) \ + ymm0 = _mm256_add_ps(_mm512_extractf32x8_ps(zmm0, 0x0), \ + _mm512_extractf32x8_ps(zmm0, 0x1)); \ + ymm1 = _mm256_add_ps(_mm512_extractf32x8_ps(zmm1, 0x0), \ + _mm512_extractf32x8_ps(zmm1, 0x1)); \ + ymm0 = _mm256_hadd_ps(ymm0, ymm1); \ + ymm2 = _mm256_add_ps(_mm512_extractf32x8_ps(zmm2, 0x0), \ + _mm512_extractf32x8_ps(zmm2, 0x1)); \ + ymm3 = _mm256_add_ps(_mm512_extractf32x8_ps(zmm3, 0x0), \ + _mm512_extractf32x8_ps(zmm3, 0x1)); \ + ymm1 = _mm256_hadd_ps(ymm2, ymm3); \ + ymm0 = _mm256_hadd_ps(ymm0, ymm1); \ + xmm0 = _mm_add_ps(_mm256_extractf128_ps(ymm0, 0), _mm256_extractf128_ps(ymm0,1)); + +// When n=1 is load 16x1 from B and load MRx16 from A and perform dot product +// to produce C output of MRX1. The vectorization is done in k loop and +// the horizontal reduction done to produce one output from each +// accumulator register +LPGEMV_N_EQ1_KERN( float, float, float, f32f32f32of32 ) +{ + static void *post_ops_labels[] = + { + &&POST_OPS_6x64F_DISABLE, + &&POST_OPS_BIAS_6x64F, + &&POST_OPS_RELU_6x64F, + &&POST_OPS_RELU_SCALE_6x64F, + &&POST_OPS_GELU_TANH_6x64F, + &&POST_OPS_GELU_ERF_6x64F, + &&POST_OPS_CLIP_6x64F, + NULL, // Virtual node for downscale, else segfault + &&POST_OPS_MATRIX_ADD_6x64F, + &&POST_OPS_SWISH_6x64F, + &&POST_OPS_MATRIX_MUL_6x64F + }; + + // Strides are updated based on matrix packing/reordering. + const float *a_use = NULL; + const float *b_use = NULL; + float *c_use = NULL; + + lpgemm_post_op_attr post_ops_attr = *(post_op_attr); + + for (dim_t mr = 0; mr < m0; mr += MR) + { + dim_t mr0 = bli_min((m0 - mr), MR); + dim_t k_iter = k/16; + dim_t k_rem = k & 0xF; + + //Create load mask for k fringe + __mmask16 k1 = 0xFFFF; + if (k_rem) + { + k1 = (0xFFFF >> (16 - k_rem)); + } + + // Create store mask for C for mr fringe + __mmask16 k2 = 0xFFFF; + if (mr0 < MR) + { + k2 = (0xFFFF >> (MR - mr0)); + } + + __m512 zmm0, zmm1, zmm2, zmm3, zmm4, zmm5, zmm6, zmm7; + __m512 zmm8, zmm9, zmm10, zmm11, zmm12, zmm13, zmm14; + __m512 zmm15, zmm16, zmm17, zmm18, zmm19, zmm20, zmm21; + __m512 zmm22, zmm23, zmm24, zmm25, zmm26, zmm27, zmm28; + __m512 zmm29, zmm30, zmm31; + + __m256 ymm0,ymm1,ymm2,ymm3,ymm4,ymm5,ymm6; + __m128 xmm0, xmm1, xmm2, xmm3; + + ZERO_ACC_ZMM_4_REG(zmm0, zmm1, zmm2, zmm3); + ZERO_ACC_ZMM_4_REG(zmm4, zmm5, zmm6, zmm7); + /* zero the accumulator registers */ + ZERO_ACC_ZMM_4_REG(zmm8, zmm9, zmm10, zmm11); + ZERO_ACC_ZMM_4_REG(zmm12, zmm13, zmm14, zmm15); + ZERO_ACC_ZMM_4_REG(zmm16, zmm17, zmm18, zmm19); + ZERO_ACC_ZMM_4_REG(zmm20, zmm21, zmm22, zmm23); + ZERO_ACC_ZMM_4_REG(zmm24, zmm25, zmm26, zmm27); + ZERO_ACC_ZMM_4_REG(zmm28, zmm29, zmm30, zmm31); + ZERO_ACC_XMM_4_REG (xmm0,xmm1,xmm2,xmm3) + + //update pointers + a_use = a + mr * rs_a; + b_use = b; + c_use = c + mr * rs_c; + + //prefetch C + _mm_prefetch(c_use, _MM_HINT_T0); + _mm_prefetch(b_use, _MM_HINT_T0); + + //Check for MR whether to process main kernel or mfringe kernel + if (mr0 == MR) + { + //Dot product kernel + for (dim_t k = 0; k < k_iter; k++) + { + zmm6 = _mm512_loadu_ps(b_use); // Load 0-15 in b[k+0 - k+15] + b_use += 16; // move b pointer to next 16 elements + + //Load 4x16 elements from row0-row3 of A + LPGEMV_N_KERNEL_4_LOADS(zmm0, zmm1, zmm2, zmm3, a_use, rs_a) + a_use += (4 * rs_a); + + // Load 4x16 elements from row3-row7 of A + LPGEMV_N_KERNEL_4_LOADS(zmm24, zmm25, zmm26, zmm27, a_use, rs_a) + a_use += (4 * rs_a); + + LPGEMV_N_KERNEL_4_FMA(zmm8, zmm9, zmm10, zmm11, zmm6, zmm0, zmm1, zmm2, zmm3) + + // Load 4x16 elements from row8-row11 of A + LPGEMV_N_KERNEL_4_LOADS(zmm28, zmm29, zmm30, zmm31, a_use, rs_a) + a_use += (4 * rs_a); + + // Load 4x16 elements from row12-row15 of A + LPGEMV_N_KERNEL_4_LOADS(zmm0, zmm1, zmm2, zmm3, a_use, rs_a) + a_use -= (12 * rs_a); //Update aptr back to move horizontally + + LPGEMV_N_KERNEL_4_FMA(zmm12, zmm13, zmm14, zmm15, zmm6, zmm24, zmm25, zmm26, zmm27) + LPGEMV_N_KERNEL_4_FMA(zmm16, zmm17, zmm18, zmm19, zmm6, zmm28, zmm29, zmm30, zmm31) + LPGEMV_N_KERNEL_4_FMA(zmm20, zmm21, zmm22, zmm23, zmm6, zmm0, zmm1, zmm2, zmm3) + a_use += 16; + }// kloop + + if(k_rem) + { + zmm6 = _mm512_mask_loadu_ps(zmm7, k1, b_use); // Load 0-15 in b[k+0 - k+15] + + // Load 4x16 elements from row0-row3 of A + LPGEMV_N_KERNEL_4_MASKLOADS(zmm0, zmm1, zmm2, zmm3, zmm7, k1, a_use, rs_a) + a_use += (4 * rs_a); + + LPGEMV_N_KERNEL_4_MASKLOADS(zmm24, zmm25, zmm26, zmm27, zmm7, k1, a_use, rs_a) + a_use += (4 * rs_a); + + LPGEMV_N_KERNEL_4_FMA(zmm8, zmm9, zmm10, zmm11, zmm6, zmm0, zmm1, zmm2, zmm3) + + LPGEMV_N_KERNEL_4_MASKLOADS(zmm28, zmm29, zmm30, zmm31, zmm7, k1, a_use, rs_a) + a_use += (4 * rs_a); + + LPGEMV_N_KERNEL_4_MASKLOADS(zmm0, zmm1, zmm2, zmm3, zmm7, k1, a_use, rs_a) + + LPGEMV_N_KERNEL_4_FMA(zmm12, zmm13, zmm14, zmm15, zmm6, zmm24, zmm25, zmm26, zmm27) + LPGEMV_N_KERNEL_4_FMA(zmm16, zmm17, zmm18, zmm19, zmm6, zmm28, zmm29, zmm30, zmm31) + LPGEMV_N_KERNEL_4_FMA(zmm20, zmm21, zmm22, zmm23, zmm6, zmm0, zmm1, zmm2, zmm3) + }// kloop + + //Add the registers horizantally to get one + LPGEMV_ZMM2XMM(zmm8, zmm9, zmm10, zmm11, ymm0, ymm1, ymm2, ymm3, xmm0) + LPGEMV_ZMM2XMM(zmm12, zmm13, zmm14, zmm15, ymm4, ymm1, ymm2, ymm3, xmm1) + LPGEMV_ZMM2XMM(zmm16, zmm17, zmm18, zmm19, ymm5, ymm1, ymm2, ymm3, xmm2) + LPGEMV_ZMM2XMM(zmm20, zmm21, zmm22, zmm23, ymm6, ymm1, ymm2, ymm3, xmm3) + + //compose outputs into one zmm to perform post-ops + zmm8 = _mm512_insertf32x4(zmm8, xmm0, 0); + zmm8 = _mm512_insertf32x4(zmm8, xmm1, 1); + zmm8 = _mm512_insertf32x4(zmm8, xmm2, 2); + zmm8 = _mm512_insertf32x4(zmm8, xmm3, 3); + }else + { + //Handle fringe cases when mr0 < MR + const float *a_use_fringe = a_use; + dim_t mr0_use = mr0; + dim_t regidx = 0; + + // Dot product for mfringe 8 + if (mr0_use >= 8) + { + // Dot product kernel for mr0 == 8 + for (dim_t k = 0; k < k_iter; k++) + { + zmm6 = _mm512_loadu_ps(b_use); // Load 0-15 in b[k+0 - k+15] + b_use += 16; // move b pointer to next 16 elements + + // Load 4x16 elements from row0-row3 of A + LPGEMV_N_KERNEL_4_LOADS(zmm0, zmm1, zmm2, zmm3, a_use, rs_a) + a_use += (4 * rs_a); + + // Load 4x16 elements from row3-row7 of A + LPGEMV_N_KERNEL_4_LOADS(zmm24, zmm25, zmm26, zmm27, a_use, rs_a) + a_use -= (4 * rs_a); + + //Perform FMA on two 4x16 block of A with 16x1 + LPGEMV_N_KERNEL_4_FMA(zmm8, zmm9, zmm10, zmm11, zmm6, zmm0, zmm1, zmm2, zmm3) + LPGEMV_N_KERNEL_4_FMA(zmm12, zmm13, zmm14, zmm15, zmm6, zmm24, zmm25, zmm26, zmm27) + a_use += 16; + } + + if (k_rem) + { + zmm6 = _mm512_mask_loadu_ps(zmm7, k1, b_use); // Load 0-15 in b[k+0 - k+15] + + // Load 4x16 elements from row0-row3 of A + LPGEMV_N_KERNEL_4_MASKLOADS(zmm0, zmm1, zmm2, zmm3, zmm7, k1, a_use, rs_a) + a_use += (4 * rs_a); + LPGEMV_N_KERNEL_4_MASKLOADS(zmm24, zmm25, zmm26, zmm27, zmm7, k1, a_use, rs_a) + LPGEMV_N_KERNEL_4_FMA(zmm8, zmm9, zmm10, zmm11, zmm6, zmm0, zmm1, zmm2, zmm3) + LPGEMV_N_KERNEL_4_FMA(zmm12, zmm13, zmm14, zmm15, zmm6, zmm24, zmm25, zmm26, zmm27) + } + + //update pointers + mr0_use -= 8; + a_use = a_use_fringe + 8 * rs_a; + a_use_fringe = a_use; + b_use = b; + + //Horizontal add 8 zmm registers and get output into 2 xmm registers + LPGEMV_ZMM2XMM(zmm8, zmm9, zmm10, zmm11, ymm0, ymm1, ymm2, ymm3, xmm0) + LPGEMV_ZMM2XMM(zmm12, zmm13, zmm14, zmm15, ymm4, ymm1, ymm2, ymm3, xmm1) + + //insert xmm outputs into final output zmm8 reg + zmm8 = _mm512_insertf32x4(zmm8, xmm0, 0); + zmm8 = _mm512_insertf32x4(zmm8, xmm1, 1); + regidx = 2; + } + + // Dot product for mfringe 4 + if (mr0_use >= 4) + { + // Dot product kernel for mr0 == 8 + for (dim_t k = 0; k < k_iter; k++) + { + zmm6 = _mm512_loadu_ps(b_use); // Load 0-15 in b[k+0 - k+15] + b_use += 16; // move b pointer to next 16 elements + // Load 4x16 elements from row0-row3 of A + LPGEMV_N_KERNEL_4_LOADS(zmm0, zmm1, zmm2, zmm3, a_use, rs_a) + // Perform FMA on 4x16 block of A with 16x1 + LPGEMV_N_KERNEL_4_FMA(zmm16, zmm17, zmm18, zmm19, zmm6, zmm0, zmm1, zmm2, zmm3) + a_use += 16; + } + + if (k_rem) + { + zmm6 = _mm512_mask_loadu_ps(zmm7, k1, b_use); // Load 0-15 in b[k+0 - k+15] + // Load 4x16 elements from row0-row3 of A + LPGEMV_N_KERNEL_4_MASKLOADS(zmm0, zmm1, zmm2, zmm3, zmm7, k1, a_use, rs_a) + LPGEMV_N_KERNEL_4_FMA(zmm16, zmm17, zmm18, zmm19, zmm6, zmm0, zmm1, zmm2, zmm3) + } + + //update pointers + mr0_use -= 4; + a_use = a_use_fringe + 4 * rs_a; + a_use_fringe = a_use; + b_use = b; + + //Horizontal add 4 zmm reg and get the output into one xmm + LPGEMV_ZMM2XMM(zmm16, zmm17, zmm18, zmm19, ymm5, ymm1, ymm2, ymm3, xmm2) + + //insert xmm outputs into final output zmm8 reg based on regidx + if(regidx == 0) zmm8 = _mm512_insertf32x4(zmm8, xmm2, 0); + else zmm8 = _mm512_insertf32x4(zmm8, xmm2, 2); + regidx++; + } + + // Dot product for <= 3 + if (mr0_use) + { + // Dot product for m = 2 + if (mr0_use >= 2) + { + for (dim_t k = 0; k < k_iter; k++) + { + zmm6 = _mm512_loadu_ps(b_use); // Load 0-15 in b[k+0 - k+15] + // Load 2x16 elements from row0-row1 of A + zmm0 = _mm512_loadu_ps(a_use); + zmm1 = _mm512_loadu_ps(a_use + rs_a); + zmm20 = _mm512_fmadd_ps(zmm0, zmm6, zmm20); + zmm21 = _mm512_fmadd_ps(zmm1, zmm6, zmm21); + b_use += 16; // move b pointer to next 16 elements + a_use += 16; + } + if (k_rem) + { + zmm6 = _mm512_mask_loadu_ps(zmm7, k1, b_use); // Load 0-15 in b[k+0 - k+15] + zmm0 = _mm512_mask_loadu_ps(zmm7, k1, a_use); // Load 0-15 in b[k+0 - k+15] + zmm1 = _mm512_mask_loadu_ps(zmm7, k1, a_use + rs_a); // Load 0-15 in b[k+0 - k+15] + zmm20 = _mm512_fmadd_ps(zmm0, zmm6, zmm20); + zmm21 = _mm512_fmadd_ps(zmm1, zmm6, zmm21); + } + mr0_use -= 2; + a_use = a_use_fringe + 2 * rs_a; + a_use_fringe = a_use; + b_use = b; + } + + // Dot product for m = 2 + if (mr0_use == 1) + { + for (dim_t k = 0; k < k_iter; k++) + { + zmm6 = _mm512_loadu_ps(b_use); // Load 0-15 in b[k+0 - k+15] + zmm0 = _mm512_loadu_ps(a_use); + zmm22 = _mm512_fmadd_ps(zmm0, zmm6, zmm22); + b_use += 16; // move b pointer to next 16 elements + a_use += 16; + } + + if (k_rem) + { + zmm6 = _mm512_mask_loadu_ps(zmm7, k1, b_use); + zmm0 = _mm512_mask_loadu_ps(zmm7, k1, a_use); + zmm22 = _mm512_fmadd_ps(zmm0, zmm6, zmm22); + } + // When only fringe 1, update the registers to store in order + if (!(mr0 & 0x2)) zmm20 = zmm22; + } + + // Horizontal add 4 zmm reg and get the output into one xmm + LPGEMV_ZMM2XMM(zmm20, zmm21, zmm22, zmm23, ymm6, ymm1, ymm2, ymm3, xmm3) + + // insert xmm outputs into final output zmm8 reg based on regidx + if (regidx == 0) zmm8 = _mm512_insertf32x4(zmm8, xmm3, 0); + else if(regidx == 1) zmm8 = _mm512_insertf32x4(zmm8, xmm3, 1); + else if (regidx == 2) zmm8 = _mm512_insertf32x4(zmm8, xmm3, 2); + else zmm8 = _mm512_insertf32x4(zmm8, xmm3, 3); + } + } + + //Scale accumulated output with alpha + zmm0 = _mm512_set1_ps(alpha); + zmm8 = _mm512_mul_ps(zmm0, zmm8); + + if (beta != 0) + { + const float *_cbuf = c_use; + + //C = beta*C + alpha*A*B + zmm3 = _mm512_set1_ps(beta); + if (rs_c == 1) + { + zmm0 = _mm512_maskz_loadu_ps(k2, _cbuf); + }else + { + //load C into zmm0 + float ctemp[16]; + for(dim_t i = 0; i < mr0; i++) + { + ctemp[i] = _cbuf[i * rs_c]; + } + zmm0 = _mm512_maskz_loadu_ps(k2, ctemp); + } + zmm8 = _mm512_fmadd_ps(zmm0, zmm3, zmm8); + } + + // Post Ops + post_ops_attr.is_last_k = TRUE; + lpgemm_post_op *post_ops_list_temp = post_op; + POST_OP_LABEL_LASTK_SAFE_JUMP + + POST_OPS_BIAS_6x64F: + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the bias array will be accessed by + // the ic index, and each bias element corresponds to an + // entire row of the transposed output array, instead of an + // entire column. + zmm9 = _mm512_set1_ps(*((float *)post_ops_list_temp->op_args1)); + zmm8 = _mm512_add_ps(zmm9, zmm8); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_RELU_6x64F: + { + zmm1 = _mm512_setzero_ps(); + + // c[0,0-15] + zmm8 = _mm512_max_ps(zmm1, zmm8); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_RELU_SCALE_6x64F: + { + zmm1 = _mm512_setzero_ps(); + zmm2 = + _mm512_set1_ps(*((float *)post_ops_list_temp->op_args2)); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32S_AVX512(zmm8) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_GELU_TANH_6x64F: + { + __m512i zmm6; + // c[0, 0-15] + GELU_TANH_F32S_AVX512(zmm8, zmm0, zmm1, zmm2, zmm3, zmm4, zmm5, zmm6) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_GELU_ERF_6x64F: + { + // c[0, 0-15] + GELU_ERF_F32S_AVX512(zmm8, zmm0, zmm1, zmm2) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_CLIP_6x64F: + { + zmm0 = _mm512_set1_ps(*(float *)post_ops_list_temp->op_args2); + zmm1 = _mm512_set1_ps(*(float *)post_ops_list_temp->op_args3); + + // c[0, 0-15] + CLIP_F32S_AVX512(zmm8, zmm0, zmm1) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_MATRIX_ADD_6x64F: + { + float *matptr = (float *)post_ops_list_temp->op_args1; + zmm0 = _mm512_maskz_loadu_ps(k2, (matptr + post_ops_attr.post_op_c_i)); + zmm8 = _mm512_add_ps(zmm8, zmm0); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_MATRIX_MUL_6x64F: + { + float *matptr = (float *)post_ops_list_temp->op_args1; + zmm0 = _mm512_maskz_loadu_ps(k2, (matptr + post_ops_attr.post_op_c_i)); + zmm8 = _mm512_mul_ps(zmm8, zmm0); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_SWISH_6x64F: + { + zmm7 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + __m512i ex_out; + + // c[0, 0-15] + SWISH_F32_AVX512_DEF(zmm8, zmm7, zmm0, zmm1, zmm2, zmm3, zmm4, ex_out); + } + POST_OPS_6x64F_DISABLE: + { + if (rs_c == 1) + { + _mm512_mask_storeu_ps(c_use, k2, zmm8); + } + else + { + // Store ZMM8 into ctemp buffer and store back + // element by element into output buffer at strides + float ctemp[16]; + _mm512_mask_storeu_ps(ctemp, k2, zmm8); + for (dim_t i = 0; i < mr0; i++) + { + c_use[i * rs_c] = ctemp[i]; + } + } + post_ops_attr.post_op_c_i += MR; + } + } // mr loop +} + +#endif // BLIS_ADDON_LPGEMM diff --git a/kernels/zen4/lpgemm/gelu_avx512.h b/kernels/zen4/lpgemm/gelu_avx512.h index 814f136f50..868c9cca67 100644 --- a/kernels/zen4/lpgemm/gelu_avx512.h +++ b/kernels/zen4/lpgemm/gelu_avx512.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/kernels/zen4/lpgemm/int4_utils_avx512.h b/kernels/zen4/lpgemm/int4_utils_avx512.h new file mode 100644 index 0000000000..a9c08435f8 --- /dev/null +++ b/kernels/zen4/lpgemm/int4_utils_avx512.h @@ -0,0 +1,411 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef LPGEMM_INT4_CVT_UTILS_H +#define LPGEMM_INT4_CVT_UTILS_H + +/* shift_idx:__m512i*/ +#define MULTISHIFT_32BIT_8_INT4_IDX_64ELEM(shift_idx) \ + /* Multi shift uses indices that corresponds to the bit starting positions + * of each of the 8 int4 elements in a given 32 bits, which is 0, 4, 8, 12, + * 16, 20, 24, 28. */ \ + shift_idx = _mm512_set1_epi64( 0x1C1814100C080400lu ); + +/* shift_idx:__m256i*/ +#define MULTISHIFT_32BIT_8_INT4_IDX_32ELEM(shift_idx) \ + /* Multi shift uses indices that corresponds to the bit starting positions + * of each of the 8 int4 elements in a given 32 bits, which is 0, 4, 8, 12, + * 16, 20, 24, 28. */ \ + shift_idx = _mm256_maskz_set1_epi64( _cvtu32_mask8( 0xFF ), \ + 0x1C1814100C080400lu ); + +/* shift_idx:__m128i*/ +#define MULTISHIFT_32BIT_8_INT4_IDX_16ELEM(shift_idx) \ + /* Multi shift uses indices that corresponds to the bit starting positions + * of each of the 8 int4 elements in a given 32 bits, which is 0, 4, 8, 12, + * 16, 20, 24, 28. */ \ + shift_idx = _mm_maskz_set1_epi64( _cvtu32_mask8( 0xFF ), \ + 0x1C1814100C080400lu ); + +/* input:__m256i, output: __m512i*/ +#define UPSCALE_INT4_TO_INT8_64ELEM_MULTISHIFT(input, output, shift_idx) \ + /* Upscale 32 bits/4 bytes (containing 8 int4 elements) into 64 bit + * /8 bytes (containing 8 int8 elements). Unsigned conversion is + * used so as to ensure the signed bit in int4 at MSB position of 4 + * byte group is not modified. */ \ + output = _mm512_multishift_epi64_epi8( shift_idx, \ + _mm512_cvtepu32_epi64( input ) ); \ + \ + /* The upper 4 bits of each converted int8 element is junk, zeroing it. */ \ + output = _mm512_maskz_and_epi64( _cvtu32_mask8( 0xFF ), output, \ + _mm512_set1_epi8( 0x0F ) ); + +/* input:__m256i, output: __m512i*/ +#define UPSCALE_INT4_TO_INT8_64ELEM_MULTISHIFT_ODD(input_0, input_1, \ + output, odd_shift_idx, conv_shift) \ + /* Unsigned conversion is used so as to ensure the signed bit. + * in int4 at MSB position of 4 byte group is not modified. */ \ + __m512i upscale_input = _mm512_cvtepu32_epi64( input_0 ); \ + __m512i shift_input = _mm512_cvtepu32_epi64( input_1 ); \ + \ + /* Upscale 32 bits/4 bytes (containing 8 int4 elements) into 64 bit + * /8 bytes (containing 8 int8 elements). */ \ + output = _mm512_multishift_epi64_epi8( odd_shift_idx, upscale_input ); \ + \ + /* Combine both the input registers, starting from elem[1] till elem[n-1] + * in output(without elem[0]), and first non zero element in shift_input. + * It is at this point that the first 4bit and last 4bit elements, the 2 + * that were loaded extra due to byte level access are discarded. */ \ + output = _mm512_permutex2var_epi8( output, conv_shift, shift_input ); \ + \ + /* The upper 4 bits of each converted int8 element is junk, zeroing it. */ \ + output = _mm512_maskz_and_epi64( _cvtu32_mask8( 0xFF ), output, \ + _mm512_set1_epi8( 0x0F ) ); + +/* input:__m128i, output: __m256i*/ +#define UPSCALE_INT4_TO_INT8_32ELEM_MULTISHIFT(input, output, shift_idx) \ + /* Upscale 32 bits/4 bytes (containing 8 int4 elements) into 64 bit + * /8 bytes (containing 8 int8 elements). Unsigned conversion is + * used so as to ensure the signed bit in int4 at MSB position of 4 + * byte group is not modified. */ \ + output = _mm256_multishift_epi64_epi8( shift_idx, \ + _mm256_cvtepu32_epi64( input ) ); \ + \ + /* The upper 4 bits of each converted int8 element is junk, zeroing it. */ \ + output = _mm256_maskz_and_epi64( _cvtu32_mask8( 0xFF ), output, \ + _mm256_set1_epi8( 0x0F ) ); + +/* input:__m128i, output: __m256i*/ +#define UPSCALE_INT4_TO_INT8_32ELEM_MULTISHIFT_ODD(input_0, input_1, \ + output, odd_shift_idx, conv_shift) \ + /* Unsigned conversion is used so as to ensure the signed bit. + * in int4 at MSB position of 4 byte group is not modified. */ \ + __m256i upscale_input = _mm256_cvtepu32_epi64( input_0 ); \ + __m256i shift_input = _mm256_cvtepu32_epi64( input_1 ); \ + \ + /* Upscale 32 bits/4 bytes (containing 8 int4 elements) into 64 bit + * /8 bytes (containing 8 int8 elements). */ \ + output = _mm256_multishift_epi64_epi8( odd_shift_idx, upscale_input ); \ + \ + /* Combine both the input registers, starting from elem[1] till elem[n-1] + * in output(without elem[0]), and first non zero element in shift_input. + * It is at this point that the first 4bit and last 4bit elements, the 2 + * that were loaded extra due to byte level access are discarded. */ \ + output = _mm256_permutex2var_epi8( output, conv_shift, shift_input ); \ + \ + /* The upper 4 bits of each converted int8 element is junk, zeroing it. */ \ + output = _mm256_maskz_and_epi64( _cvtu32_mask8( 0xFF ), output, \ + _mm256_set1_epi8( 0x0F ) ); + +/* input:int64_t, output: __m128i*/ +#define UPSCALE_INT4_TO_INT8_16ELEM_MULTISHIFT(input, output, shift_idx) \ + /* Upscale 32 bits/4 bytes (containing 8 int4 elements) into 64 bit + * /8 bytes (containing 8 int8 elements). Unsigned conversion is + * used so as to ensure the signed bit in int4 at MSB position of 4 + * byte group is not modified. */ \ + output = _mm_multishift_epi64_epi8( shift_idx, \ + _mm_cvtepu32_epi64( input ) ); \ + \ + /* The upper 4 bits of each converted int8 element is junk, zeroing it. */ \ + output = _mm_maskz_and_epi64( _cvtu32_mask8( 0xFF ), output, \ + _mm_set1_epi8( 0x0F ) ); + +/* input:int64_t, output:__m128i*/ +#define UPSCALE_INT4_TO_INT8_16ELEM_MULTISHIFT_ODD(input_0, input_1, \ + output, odd_shift_idx, conv_shift) \ + /* Unsigned conversion is used so as to ensure the signed bit. + * in int4 at MSB position of 4 byte group is not modified. */ \ + input_0 = _mm_cvtepu32_epi64( input_0 ); \ + input_1 = _mm_cvtepu32_epi64( input_1 ); \ + \ + /* Upscale 32 bits/4 bytes (containing 8 int4 elements) into 64 bit + * /8 bytes (containing 8 int8 elements). */ \ + output = _mm_multishift_epi64_epi8( odd_shift_idx, input_0 ); \ + \ + /* Combine both the input registers, starting from elem[1] till elem[n-1] + * in output(without elem[0]), and first non zero element in shift_input. + * It is at this point that the first 4bit and last 4bit elements, the 2 + * that were loaded extra due to byte level access are discarded. */ \ + output = _mm_permutex2var_epi8( output, conv_shift, input_1 ); \ + \ + /* The upper 4 bits of each converted int8 element is junk, zeroing it. */ \ + output = _mm_maskz_and_epi64( _cvtu32_mask8( 0xFF ), output, \ + _mm_set1_epi8( 0x0F ) ); + +#define SIGN_EXTEND_BITWISE_OPS_64ELEM(output, sign_comp) \ + /* Comparison of signed bit in int4 and appending sign bits. */ \ + /* Set 4th bit (bit[3]/MSB/sign bit) of negative int4 values (signed bit + * is 1) to 1 and rest every other bits to 0. */ \ + __m512i hi_bits_512 = _mm512_and_epi32( output, sign_comp ); \ + \ + /* Set 4th bit (bit[3]/MSB/sign bit) of positive int4 values (signed bit + * is 0) to 1 and rest every other bits to 0. */ \ + hi_bits_512 = _mm512_xor_epi32( hi_bits_512, sign_comp ); \ + \ + /* Set the sign extension bits on an int8_t size basis, this will then be + * OR with output to get the signed outputs. */ \ + hi_bits_512 = _mm512_add_epi8( hi_bits_512, _mm512_set1_epi8( 0xF8 ) ); \ + \ + output = _mm512_or_epi32( output, hi_bits_512 ); + +#define SIGN_EXTEND_BITWISE_OPS_32ELEM(output, sign_comp) \ + /* Comparison of signed bit in int4 and appending sign bits. */ \ + /* Set 4th bit (bit[3]/MSB/sign bit) of negative int4 values (signed bit + * is 1) to 1 and rest every other bits to 0. */ \ + __m256i hi_bits_256 = _mm256_maskz_and_epi32( _cvtu32_mask8( 0xFF ),\ + output, sign_comp ); \ + \ + /* Set 4th bit (bit[3]/MSB/sign bit) of positive int4 values (signed bit + * is 0) to 1 and rest every other bits to 0. */ \ + hi_bits_256 = _mm256_xor_epi32( hi_bits_256, sign_comp ); \ + \ + /* Set the sign extension bits on an int8_t size basis, this will then be + * OR with output to get the signed outputs. */ \ + hi_bits_256 = _mm256_add_epi8( hi_bits_256, _mm256_set1_epi8( 0xF8 ) ); \ + \ + output = _mm256_or_epi32( output, hi_bits_256 ); + +#define SIGN_EXTEND_BITWISE_OPS_16ELEM(output, sign_comp) \ + /* Comparison of signed bit in int4 and appending sign bits. */ \ + /* Set 4th bit (bit[3]/MSB/sign bit) of negative int4 values (signed bit + * is 1) to 1 and rest every other bits to 0. */ \ + __m128i hi_bits_128 = _mm_maskz_and_epi32( _cvtu32_mask8( 0xFF ),\ + output, sign_comp ); \ + \ + /* Set 4th bit (bit[3]/MSB/sign bit) of positive int4 values (signed bit + * is 0) to 1 and rest every other bits to 0. */ \ + hi_bits_128 = _mm_xor_epi32( hi_bits_128, sign_comp ); \ + \ + /* Set the sign extension bits on an int8_t size basis, this will then be + * OR with output to get the signed outputs. */ \ + hi_bits_128 = _mm_add_epi8( hi_bits_128, _mm_set1_epi8( 0xF8 ) ); \ + \ + output = _mm_or_epi32( output, hi_bits_128 ); + +/* input:__m256i, output: __m512i*/ +#define CVT_INT4_TO_INT8_64ELEM_MULTISHIFT(input, output, shift_idx, sign_comp, signed_scale) \ +do { \ + UPSCALE_INT4_TO_INT8_64ELEM_MULTISHIFT(input, output, shift_idx); \ + \ + if ( signed_scale == TRUE ) \ + { \ + SIGN_EXTEND_BITWISE_OPS_64ELEM(output, sign_comp); \ + } \ +} while (0); + +#define CREATE_CVT_INT4_INT8_PERM_IDX_64ELEM_ODD_LD(var_name) \ + const int64_t var_name[8] = { \ + 0x0807060504030201, 0x100F0E0D0C0B0A09, \ + 0X1817161514131211, 0X201F1E1D1C1B1A19, \ + 0X2827262524232221, 0X302F2E2D2C2B2A29, \ + 0X3837363534333231, 0X7B3F3E3D3C3B3A39 }; + +/* input:__m256i, output: __m512i*/ +#define CVT_INT4_TO_INT8_64ELEM_MULTISHIFT_ODD(input_0, input_1, output, \ + odd_shift_idx, conv_shift, sign_comp, signed_scale) \ +do { \ + UPSCALE_INT4_TO_INT8_64ELEM_MULTISHIFT_ODD(input_0, input_1, output, \ + odd_shift_idx, conv_shift); \ + \ + if ( signed_scale == TRUE ) \ + { \ + SIGN_EXTEND_BITWISE_OPS_64ELEM(output, sign_comp); \ + } \ +} while (0); + +/* input:__m128i, output: __m256i*/ +#define CVT_INT4_TO_INT8_32ELEM_MULTISHIFT(input, output, shift_idx, sign_comp, signed_scale) \ +do { \ + UPSCALE_INT4_TO_INT8_32ELEM_MULTISHIFT(input, output, shift_idx); \ + \ + if ( signed_scale == TRUE ) \ + { \ + SIGN_EXTEND_BITWISE_OPS_32ELEM(output, sign_comp); \ + } \ +} while (0); + +#define CREATE_CVT_INT4_INT8_PERM_IDX_32ELEM_ODD_LD(var_name) \ + const int64_t var_name[4] = { \ + 0x0807060504030201, 0x100F0E0D0C0B0A09, \ + 0X1817161514131211, 0X3B1F1E1D1C1B1A19 }; + +/* input:__m128i, output: __m256i*/ +#define CVT_INT4_TO_INT8_32ELEM_MULTISHIFT_ODD(input_0, input_1, output, \ + odd_shift_idx, conv_shift, sign_comp, signed_scale) \ +do { \ + UPSCALE_INT4_TO_INT8_32ELEM_MULTISHIFT_ODD(input_0, input_1, output, \ + odd_shift_idx, conv_shift); \ + \ + if ( signed_scale == TRUE ) \ + { \ + SIGN_EXTEND_BITWISE_OPS_32ELEM(output, sign_comp); \ + } \ +} while (0); + +/* input:int64_t, output: __m128i*/ +#define CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(input, output, shift_idx, sign_comp, signed_scale) \ +do { \ + UPSCALE_INT4_TO_INT8_16ELEM_MULTISHIFT(input, output, shift_idx); \ + \ + if ( signed_scale == TRUE ) \ + { \ + SIGN_EXTEND_BITWISE_OPS_16ELEM(output, sign_comp); \ + } \ +} while (0); + +#define CREATE_CVT_INT4_INT8_PERM_IDX_16ELEM_ODD_LD(var_name) \ + const int64_t var_name[2] = { \ + 0x0807060504030201, 0x1B0F0E0D0C0B0A09 }; + +/* input:int64_t, output: __m128i*/ +#define CVT_INT4_TO_INT8_16ELEM_MULTISHIFT_ODD(input_0, input_1, output, \ + odd_shift_idx, conv_shift, sign_comp, signed_scale) \ +do { \ + UPSCALE_INT4_TO_INT8_16ELEM_MULTISHIFT_ODD(input_0, input_1, output, \ + odd_shift_idx, conv_shift); \ + \ + if ( signed_scale == TRUE ) \ + { \ + SIGN_EXTEND_BITWISE_OPS_16ELEM(output, sign_comp); \ + } \ +} while (0); + +#define CREATE_CVT_INT8_INT4_PERM_IDX_64ELEM_2_ZMM_REG(var_name) \ + int8_t var_name[64] __attribute__((aligned(64))) = \ + {0x00, 0x02, 0x04, 0x06, 0x08, 0x0A, 0x0C, 0x0E, \ + 0x10, 0x12, 0x14, 0x16, 0x18, 0x1A, 0x1C, 0x1E, \ + 0x20, 0x22, 0x24, 0x26, 0x28, 0x2A, 0x2C, 0x2E, \ + 0x30, 0x32, 0x34, 0x36, 0x38, 0x3A, 0x3C, 0x3E, \ + 0x40, 0x42, 0x44, 0x46, 0x48, 0x4A, 0x4C, 0x4E, \ + 0x50, 0x52, 0x54, 0x56, 0x58, 0x5A, 0x5C, 0x5E, \ + 0x60, 0x62, 0x64, 0x66, 0x68, 0x6A, 0x6C, 0x6E, \ + 0x70, 0x72, 0x74, 0x76, 0x78, 0x7A, 0x7C, 0x7E}; + +/* Conversion from int8 to int4. First split the elements in __m512i + * register at even indices and odd indices into two separate __m256i + * even and odd registers. Then shift the elements in odd by 4 to the + * left and OR with even register. */ +/* input_*:__m512i, output: __m512i */ +#define CVT_INT8_INT4_64ELEM_2_ZMM_REG(input_0, input_1, output, \ + even_perm_idx, odd_perm_idx, clear_hi_bits) \ +do { \ + output = _mm512_permutex2var_epi8( input_0, even_perm_idx, input_1 ); \ + __m512i odd_out = _mm512_permutex2var_epi8( input_0, \ + odd_perm_idx, input_1 ); \ + \ + /* Ensure the hi 4 bits are cleared. */ \ + output = _mm512_and_epi32( output, clear_hi_bits ); \ + \ + __m256i odd1_256 = _mm512_extracti64x4_epi64( odd_out, 0x0 ); \ + __m256i odd2_256 = _mm512_extracti64x4_epi64( odd_out, 0x1 ); \ + \ + /* Shift the elemts in odd register by 4 to the left. */ \ + odd1_256 = _mm512_cvtepi16_epi8( \ + _mm512_slli_epi16( _mm512_cvtepu8_epi16( odd1_256 ), 0x4 ) ); \ + odd2_256 = _mm512_cvtepi16_epi8( \ + _mm512_slli_epi16( _mm512_cvtepu8_epi16( odd2_256 ), 0x4 ) ); \ + \ + odd_out = _mm512_castsi256_si512( odd1_256 ); \ + odd_out = _mm512_inserti64x4( odd_out, odd2_256, 0x01 ); \ + \ + output = _mm512_or_epi32( output, odd_out ); \ +} while (0); + +#define CREATE_CVT_INT8_INT4_PERM_IDX_32ELEM_2_YMM_REG(var_name) \ + int8_t var_name[32] __attribute__((aligned(64))) = \ + {0x00, 0x02, 0x04, 0x06, 0x08, 0x0A, 0x0C, 0x0E, \ + 0x10, 0x12, 0x14, 0x16, 0x18, 0x1A, 0x1C, 0x1E, \ + 0x20, 0x22, 0x24, 0x26, 0x28, 0x2A, 0x2C, 0x2E, \ + 0x30, 0x32, 0x34, 0x36, 0x38, 0x3A, 0x3C, 0x3E}; + +/* input_*:__m256i, output: __m256i */ +#define CVT_INT8_INT4_32ELEM_2_YMM_REG(input_0, input_1, output, \ + even_perm_idx, odd_perm_idx, clear_hi_bits) \ +do { \ + output = _mm256_permutex2var_epi8( input_0, even_perm_idx, input_1 ); \ + __m256i odd_out = _mm256_permutex2var_epi8( input_0, \ + odd_perm_idx, input_1 ); \ + \ + /* Ensure the hi 4 bits are cleared. */ \ + output = _mm256_maskz_and_epi32( _cvtu32_mask8( 0xFF ), \ + output, clear_hi_bits ); \ + \ + /* Shift the elemts in odd register by 4 to the left. */ \ + odd_out = _mm512_cvtepi16_epi8( \ + _mm512_slli_epi16( _mm512_cvtepu8_epi16( odd_out ), 0x4 ) ); \ + \ + output = _mm256_or_epi32( output, odd_out ); \ +} while (0); + +#define CREATE_CVT_INT8_INT4_PERM_IDX_16ELEM_2_XMM_REG(var_name) \ + int8_t var_name[16] __attribute__((aligned(64))) = \ + {0x00, 0x02, 0x04, 0x06, 0x08, 0x0A, 0x0C, 0x0E, \ + 0x10, 0x12, 0x14, 0x16, 0x18, 0x1A, 0x1C, 0x1E}; + +/* input_*:__m128i, output: __m128i */ +#define CVT_INT8_INT4_16ELEM_2_XMM_REG(input_0, input_1, output, \ + even_perm_idx, odd_perm_idx, clear_hi_bits) \ +do { \ + output = _mm_permutex2var_epi8( input_0, even_perm_idx, input_1 ); \ + __m128i odd_out = _mm_permutex2var_epi8( input_0, \ + odd_perm_idx, input_1 ); \ + \ + /* Ensure the hi 4 bits are cleared. */ \ + output = _mm_maskz_and_epi32( _cvtu32_mask8( 0xFF ), \ + output, clear_hi_bits ); \ + \ + /* Shift the elemts in odd register by 4 to the left. */ \ + __mmask16 sel_all_mask = _cvtu32_mask16( 0xFFFF ); \ + odd_out = _mm256_maskz_cvtepi16_epi8( sel_all_mask, \ + _mm256_maskz_slli_epi16( sel_all_mask, \ + _mm256_maskz_cvtepu8_epi16( sel_all_mask, odd_out ), 0x4 ) ); \ + \ + output = _mm_or_epi32( output, odd_out ); \ +} while (0); + + +#define CVT_INT8_F32_SCAL_16( in, idx, scale_reg) \ + (_mm512_mul_ps( \ + _mm512_cvtepi32_ps( \ + _mm512_cvtepi8_epi32( \ + _mm512_extracti32x4_epi32( in, idx ) ) ), scale_reg ) ) + +#define CVT_INT8_F32_SCAL_8( in, idx, scale_reg) \ + (_mm512_mul_ps( \ + _mm512_cvtepi32_ps( \ + _mm512_cvtepi8_epi32( \ + _mm256_extracti32x4_epi32( in, idx ) ) ), scale_reg ) ) + + +#endif //LPGEMM_INT4_CVT_UTILS_H diff --git a/kernels/zen4/lpgemm/lpgemm_util_l1_ops_avx512.c b/kernels/zen4/lpgemm/lpgemm_util_l1_ops_avx512.c index 36ad94569d..2a190757b5 100644 --- a/kernels/zen4/lpgemm/lpgemm_util_l1_ops_avx512.c +++ b/kernels/zen4/lpgemm/lpgemm_util_l1_ops_avx512.c @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/kernels/zen4/lpgemm/math_utils_avx512.h b/kernels/zen4/lpgemm/math_utils_avx512.h index dddfd58825..5916d02523 100644 --- a/kernels/zen4/lpgemm/math_utils_avx512.h +++ b/kernels/zen4/lpgemm/math_utils_avx512.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -76,6 +76,7 @@ r2 = _mm512_mul_ps (r2, r2); \ r = _mm512_fmadd_ps (r2, _mm512_fmadd_ps (r, _mm512_set1_ps(lpgemm_exp_c5), _mm512_set1_ps(lpgemm_exp_c4)), z); \ +// Require in and out registers to be different. x : in, q : out. #define EXPF_AVX512(x, r, r2, z, dn, q) \ z = _mm512_mul_ps (x, _mm512_set1_ps(TBL_LN2)); \ dn = _mm512_add_ps (z , _mm512_set1_ps(EXPF_HUGE)); \ diff --git a/kernels/zen4/lpgemm/s8s8s32/lpgemm_6x64rowmajor_s8_amd512vnni.c b/kernels/zen4/lpgemm/s8s8s32/lpgemm_6x64rowmajor_s8_amd512vnni.c index df5d29472c..d68ffe3232 100644 --- a/kernels/zen4/lpgemm/s8s8s32/lpgemm_6x64rowmajor_s8_amd512vnni.c +++ b/kernels/zen4/lpgemm/s8s8s32/lpgemm_6x64rowmajor_s8_amd512vnni.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -52,7 +52,9 @@ LPGEMM_MAIN_KERN(int8_t,int8_t,int32_t,s8s8s32os32_6x64) &&POST_OPS_GELU_TANH_6x64, &&POST_OPS_GELU_ERF_6x64, &&POST_OPS_CLIP_6x64, - &&POST_OPS_DOWNSCALE_6x64 + &&POST_OPS_DOWNSCALE_6x64, + &&POST_OPS_MATRIX_ADD_6x64, + &&POST_OPS_SWISH_6x64 }; dim_t MR = 6; @@ -156,14 +158,14 @@ LPGEMM_MAIN_KERN(int8_t,int8_t,int32_t,s8s8s32os32_6x64) } // B matrix storage. - __m512i b0; - __m512i b1; - __m512i b2; - __m512i b3; + __m512i b0 = _mm512_setzero_epi32(); + __m512i b1 = _mm512_setzero_epi32(); + __m512i b2 = _mm512_setzero_epi32(); + __m512i b3 = _mm512_setzero_epi32(); // A matrix storage. - __m512i a_int32_0; - __m512i a_int32_1; + __m512i a_int32_0 = _mm512_setzero_epi32(); + __m512i a_int32_1 = _mm512_setzero_epi32(); uint8_t cvt_uint8 = 128; __m512i vec_uint8 = _mm512_set1_epi8 (cvt_uint8); @@ -1047,36 +1049,66 @@ LPGEMM_MAIN_KERN(int8_t,int8_t,int32_t,s8s8s32os32_6x64) POST_OPS_DOWNSCALE_6x64: { - selector1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ); - a_int32_0 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 2 * 16 ) ); - a_int32_1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + a_int32_1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + a_int32_0 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + a_int32_1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point0 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point1 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point2 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point3 = _mm512_castsi512_si128( _mm512_setzero_si512() ); // int8_t zero point value. - __m128i zero_point0 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); - __m128i zero_point1 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); - __m128i zero_point2 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); - __m128i zero_point3 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 3 * 16 ) ) ); + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + zero_point2 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + zero_point3 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point1 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point2 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point3 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } // c[0, 0-15] CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); @@ -1150,6 +1182,138 @@ LPGEMM_MAIN_KERN(int8_t,int8_t,int32_t,s8s8s32os32_6x64) // c[5, 48-63] CVT_MULRND_CVT32(c_int32_5p3,a_int32_1,zero_point3); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_6x64: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + S8_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,0); + + // c[1:0-15,16-31,32-47,48-63] + S8_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,1); + + // c[2:0-15,16-31,32-47,48-63] + S8_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,2); + + // c[3:0-15,16-31,32-47,48-63] + S8_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,3); + + // c[4:0-15,16-31,32-47,48-63] + S8_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,4); + + // c[5:0-15,16-31,32-47,48-63] + S8_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,5); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + S32_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,0); + + // c[1:0-15,16-31,32-47,48-63] + S32_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,1); + + // c[2:0-15,16-31,32-47,48-63] + S32_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,2); + + // c[3:0-15,16-31,32-47,48-63] + S32_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,3); + + // c[4:0-15,16-31,32-47,48-63] + S32_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,4); + + // c[5:0-15,16-31,32-47,48-63] + S32_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,5); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_6x64: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + // c[0, 0-15] + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 16-31] + SWISH_S32_AVX512(c_int32_0p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 32-47] + SWISH_S32_AVX512(c_int32_0p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 48-63] + SWISH_S32_AVX512(c_int32_0p3, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 0-15] + SWISH_S32_AVX512(c_int32_1p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 16-31] + SWISH_S32_AVX512(c_int32_1p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 32-47] + SWISH_S32_AVX512(c_int32_1p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 48-63] + SWISH_S32_AVX512(c_int32_1p3, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 0-15] + SWISH_S32_AVX512(c_int32_2p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 16-31] + SWISH_S32_AVX512(c_int32_2p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 32-47] + SWISH_S32_AVX512(c_int32_2p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 48-63] + SWISH_S32_AVX512(c_int32_2p3, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[3, 0-15] + SWISH_S32_AVX512(c_int32_3p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[3, 16-31] + SWISH_S32_AVX512(c_int32_3p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[3, 32-47] + SWISH_S32_AVX512(c_int32_3p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[3, 48-63] + SWISH_S32_AVX512(c_int32_3p3, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[4, 0-15] + SWISH_S32_AVX512(c_int32_4p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[4, 16-31] + SWISH_S32_AVX512(c_int32_4p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[4, 32-47] + SWISH_S32_AVX512(c_int32_4p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[4, 48-63] + SWISH_S32_AVX512(c_int32_4p3, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[5, 0-15] + SWISH_S32_AVX512(c_int32_5p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[5, 16-31] + SWISH_S32_AVX512(c_int32_5p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[5, 32-47] + SWISH_S32_AVX512(c_int32_5p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[5, 48-63] + SWISH_S32_AVX512(c_int32_5p3, fl_reg, al, al_in, r, r2, z, dn, selector2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_6x64_DISABLE: diff --git a/kernels/zen4/lpgemm/s8s8s32/lpgemm_m_fringe_s8_amd512vnni.c b/kernels/zen4/lpgemm/s8s8s32/lpgemm_m_fringe_s8_amd512vnni.c index 53a0f51d17..44038a229b 100644 --- a/kernels/zen4/lpgemm/s8s8s32/lpgemm_m_fringe_s8_amd512vnni.c +++ b/kernels/zen4/lpgemm/s8s8s32/lpgemm_m_fringe_s8_amd512vnni.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -53,20 +53,22 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_5x64) &&POST_OPS_GELU_TANH_5x64, &&POST_OPS_GELU_ERF_5x64, &&POST_OPS_CLIP_5x64, - &&POST_OPS_DOWNSCALE_5x64 + &&POST_OPS_DOWNSCALE_5x64, + &&POST_OPS_MATRIX_ADD_5x64, + &&POST_OPS_SWISH_5x64 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; // B matrix storage. - __m512i b0; - __m512i b1; - __m512i b2; - __m512i b3; + __m512i b0 = _mm512_setzero_epi32(); + __m512i b1 = _mm512_setzero_epi32(); + __m512i b2 = _mm512_setzero_epi32(); + __m512i b3 = _mm512_setzero_epi32(); // A matrix storage. - __m512i a_int32_0; - __m512i a_int32_1; + __m512i a_int32_0 = _mm512_setzero_epi32(); + __m512i a_int32_1 = _mm512_setzero_epi32(); uint8_t cvt_uint8 = 128; __m512i vec_uint8 = _mm512_set1_epi8 (cvt_uint8); @@ -810,37 +812,68 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_5x64) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_5x64: { - selector1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ); - a_int32_0 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 2 * 16 ) ); - a_int32_1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 3 * 16 ) ); - __m128i zero_point0 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); - __m128i zero_point1 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); - __m128i zero_point2 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); - __m128i zero_point3 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 3 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + a_int32_1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + a_int32_0 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + a_int32_1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point0 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point1 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point2 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point3 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + + // int8_t zero point value. + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + zero_point2 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + zero_point3 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point1 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point2 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point3 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } // c[0, 0-15] CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); @@ -902,6 +935,120 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_5x64) // c[4, 48-63] CVT_MULRND_CVT32(c_int32_4p3,a_int32_1,zero_point3); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_5x64: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + S8_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,0); + + // c[1:0-15,16-31,32-47,48-63] + S8_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,1); + + // c[2:0-15,16-31,32-47,48-63] + S8_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,2); + + // c[3:0-15,16-31,32-47,48-63] + S8_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,3); + + // c[4:0-15,16-31,32-47,48-63] + S8_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,4); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + S32_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,0); + + // c[1:0-15,16-31,32-47,48-63] + S32_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,1); + + // c[2:0-15,16-31,32-47,48-63] + S32_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,2); + + // c[3:0-15,16-31,32-47,48-63] + S32_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,3); + + // c[4:0-15,16-31,32-47,48-63] + S32_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,4); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_5x64: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + // c[0, 0-15] + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 16-31] + SWISH_S32_AVX512(c_int32_0p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 32-47] + SWISH_S32_AVX512(c_int32_0p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 48-63] + SWISH_S32_AVX512(c_int32_0p3, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 0-15] + SWISH_S32_AVX512(c_int32_1p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 16-31] + SWISH_S32_AVX512(c_int32_1p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 32-47] + SWISH_S32_AVX512(c_int32_1p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 48-63] + SWISH_S32_AVX512(c_int32_1p3, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 0-15] + SWISH_S32_AVX512(c_int32_2p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 16-31] + SWISH_S32_AVX512(c_int32_2p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 32-47] + SWISH_S32_AVX512(c_int32_2p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 48-63] + SWISH_S32_AVX512(c_int32_2p3, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[3, 0-15] + SWISH_S32_AVX512(c_int32_3p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[3, 16-31] + SWISH_S32_AVX512(c_int32_3p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[3, 32-47] + SWISH_S32_AVX512(c_int32_3p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[3, 48-63] + SWISH_S32_AVX512(c_int32_3p3, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[4, 0-15] + SWISH_S32_AVX512(c_int32_4p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[4, 16-31] + SWISH_S32_AVX512(c_int32_4p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[4, 32-47] + SWISH_S32_AVX512(c_int32_4p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[4, 48-63] + SWISH_S32_AVX512(c_int32_4p3, fl_reg, al, al_in, r, r2, z, dn, selector2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_5x64_DISABLE: @@ -1052,20 +1199,22 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_4x64) &&POST_OPS_GELU_TANH_4x64, &&POST_OPS_GELU_ERF_4x64, &&POST_OPS_CLIP_4x64, - &&POST_OPS_DOWNSCALE_4x64 + &&POST_OPS_DOWNSCALE_4x64, + &&POST_OPS_MATRIX_ADD_4x64, + &&POST_OPS_SWISH_4x64 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; // B matrix storage. - __m512i b0; - __m512i b1; - __m512i b2; - __m512i b3; + __m512i b0 = _mm512_setzero_epi32(); + __m512i b1 = _mm512_setzero_epi32(); + __m512i b2 = _mm512_setzero_epi32(); + __m512i b3 = _mm512_setzero_epi32(); // A matrix storage. - __m512i a_int32_0; - __m512i a_int32_1; + __m512i a_int32_0 = _mm512_setzero_epi32(); + __m512i a_int32_1 = _mm512_setzero_epi32(); uint8_t cvt_uint8 = 128; __m512i vec_uint8 = _mm512_set1_epi8 (cvt_uint8); @@ -1090,20 +1239,30 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_4x64) __m512i c_int32_3p1 = _mm512_setzero_epi32(); __m512i c_int32_3p2 = _mm512_setzero_epi32(); __m512i c_int32_3p3 = _mm512_setzero_epi32(); - + // gcc compiler (atleast 11.2 to 13.1) avoid loading B into + // registers while generating the code. A dummy shuffle instruction + // is used on b data to explicitly specify to gcc compiler + // b data needs to be kept in registers to reuse across FMA's + __m512i dsmask = _mm512_set_epi64( + 0x0F0E0D0C0B0A0908, 0x0706050403020100, + 0x0F0E0D0C0B0A0908, 0x0706050403020100, + 0x0F0E0D0C0B0A0908, 0x0706050403020100, + 0x0F0E0D0C0B0A0908, 0x0706050403020100); for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) { b0 = _mm512_loadu_si512( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - // Broadcast a[0,kr:kr+4]. a_int32_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); - - //convert signed int8 to uint8 for VNNI + b0 = _mm512_shuffle_epi8(b0, dsmask); + // convert signed int8 to uint8 for VNNI a_int32_0 = _mm512_add_epi8( a_int32_0, vec_uint8 ); b1 = _mm512_loadu_si512( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b1 = _mm512_shuffle_epi8(b1, dsmask); b2 = _mm512_loadu_si512( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + b2 = _mm512_shuffle_epi8(b2, dsmask); b3 = _mm512_loadu_si512( b + ( rs_b * kr ) + ( cs_b * 3 ) ); + b3 = _mm512_shuffle_epi8(b3, dsmask); // Perform column direction mat-mul with k = 4. // c[0,0-63] = a[0,kr:kr+4]*b[kr:kr+4,0-63] @@ -1685,37 +1844,68 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_4x64) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_4x64: { - selector1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ); - a_int32_0 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 2 * 16 ) ); - a_int32_1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 3 * 16 ) ); - __m128i zero_point0 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); - __m128i zero_point1 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); - __m128i zero_point2 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); - __m128i zero_point3 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 3 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + a_int32_1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + a_int32_0 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + a_int32_1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point0 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point1 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point2 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point3 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + + // int8_t zero point value. + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + zero_point2 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + zero_point3 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point1 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point2 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point3 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } // c[0, 0-15] CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); @@ -1765,6 +1955,102 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_4x64) // c[3, 48-63] CVT_MULRND_CVT32(c_int32_3p3,a_int32_1,zero_point3); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_4x64: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + S8_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,0); + + // c[1:0-15,16-31,32-47,48-63] + S8_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,1); + + // c[2:0-15,16-31,32-47,48-63] + S8_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,2); + + // c[3:0-15,16-31,32-47,48-63] + S8_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,3); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + S32_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,0); + + // c[1:0-15,16-31,32-47,48-63] + S32_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,1); + + // c[2:0-15,16-31,32-47,48-63] + S32_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,2); + + // c[3:0-15,16-31,32-47,48-63] + S32_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,3); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_4x64: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + // c[0, 0-15] + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 16-31] + SWISH_S32_AVX512(c_int32_0p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 32-47] + SWISH_S32_AVX512(c_int32_0p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 48-63] + SWISH_S32_AVX512(c_int32_0p3, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 0-15] + SWISH_S32_AVX512(c_int32_1p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 16-31] + SWISH_S32_AVX512(c_int32_1p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 32-47] + SWISH_S32_AVX512(c_int32_1p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 48-63] + SWISH_S32_AVX512(c_int32_1p3, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 0-15] + SWISH_S32_AVX512(c_int32_2p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 16-31] + SWISH_S32_AVX512(c_int32_2p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 32-47] + SWISH_S32_AVX512(c_int32_2p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 48-63] + SWISH_S32_AVX512(c_int32_2p3, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[3, 0-15] + SWISH_S32_AVX512(c_int32_3p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[3, 16-31] + SWISH_S32_AVX512(c_int32_3p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[3, 32-47] + SWISH_S32_AVX512(c_int32_3p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[3, 48-63] + SWISH_S32_AVX512(c_int32_3p3, fl_reg, al, al_in, r, r2, z, dn, selector2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_4x64_DISABLE: @@ -1891,20 +2177,22 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_3x64) &&POST_OPS_GELU_TANH_3x64, &&POST_OPS_GELU_ERF_3x64, &&POST_OPS_CLIP_3x64, - &&POST_OPS_DOWNSCALE_3x64 + &&POST_OPS_DOWNSCALE_3x64, + &&POST_OPS_MATRIX_ADD_3x64, + &&POST_OPS_SWISH_3x64 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; // B matrix storage. - __m512i b0; - __m512i b1; - __m512i b2; - __m512i b3; + __m512i b0 = _mm512_setzero_epi32(); + __m512i b1 = _mm512_setzero_epi32(); + __m512i b2 = _mm512_setzero_epi32(); + __m512i b3 = _mm512_setzero_epi32(); // A matrix storage. - __m512i a_int32_0; - __m512i a_int32_1; + __m512i a_int32_0 = _mm512_setzero_epi32(); + __m512i a_int32_1 = _mm512_setzero_epi32(); uint8_t cvt_uint8 = 128; __m512i vec_uint8 = _mm512_set1_epi8 (cvt_uint8); @@ -1925,10 +2213,20 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_3x64) __m512i c_int32_2p2 = _mm512_setzero_epi32(); __m512i c_int32_2p3 = _mm512_setzero_epi32(); + // gcc compiler (atleast 11.2 to 13.1) avoid loading B into + // registers while generating the code. A dummy shuffle instruction + // is used on b data to explicitly specify to gcc compiler + // b data needs to be kept in registers to reuse across FMA's + __m512i dsmask = _mm512_set_epi64( + 0x0F0E0D0C0B0A0908, 0x0706050403020100, + 0x0F0E0D0C0B0A0908, 0x0706050403020100, + 0x0F0E0D0C0B0A0908, 0x0706050403020100, + 0x0F0E0D0C0B0A0908, 0x0706050403020100); + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) { b0 = _mm512_loadu_si512( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - + b0 = _mm512_shuffle_epi8( b0, dsmask ); // Broadcast a[0,kr:kr+4]. a_int32_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); @@ -1936,8 +2234,11 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_3x64) a_int32_0 = _mm512_add_epi8( a_int32_0, vec_uint8 ); b1 = _mm512_loadu_si512( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b1 = _mm512_shuffle_epi8( b1, dsmask ); b2 = _mm512_loadu_si512( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + b2 = _mm512_shuffle_epi8( b2, dsmask ); b3 = _mm512_loadu_si512( b + ( rs_b * kr ) + ( cs_b * 3 ) ); + b3 = _mm512_shuffle_epi8( b3, dsmask ); // Perform column direction mat-mul with k = 4. // c[0,0-63] = a[0,kr:kr+4]*b[kr:kr+4,0-63] @@ -2398,37 +2699,68 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_3x64) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_3x64: { - selector1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ); - a_int32_0 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 2 * 16 ) ); - a_int32_1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 3 * 16 ) ); - __m128i zero_point0 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); - __m128i zero_point1 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); - __m128i zero_point2 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); - __m128i zero_point3 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 3 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + a_int32_1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + a_int32_0 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + a_int32_1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point0 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point1 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point2 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point3 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + + // int8_t zero point value. + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + zero_point2 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + zero_point3 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point1 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point2 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point3 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } // c[0, 0-15] CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); @@ -2466,6 +2798,84 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_3x64) // c[2, 48-63] CVT_MULRND_CVT32(c_int32_2p3,a_int32_1,zero_point3); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_3x64: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + S8_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,0); + + // c[1:0-15,16-31,32-47,48-63] + S8_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,1); + + // c[2:0-15,16-31,32-47,48-63] + S8_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,2); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + S32_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,0); + + // c[1:0-15,16-31,32-47,48-63] + S32_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,1); + + // c[2:0-15,16-31,32-47,48-63] + S32_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,2); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_3x64: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + // c[0, 0-15] + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 16-31] + SWISH_S32_AVX512(c_int32_0p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 32-47] + SWISH_S32_AVX512(c_int32_0p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 48-63] + SWISH_S32_AVX512(c_int32_0p3, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 0-15] + SWISH_S32_AVX512(c_int32_1p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 16-31] + SWISH_S32_AVX512(c_int32_1p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 32-47] + SWISH_S32_AVX512(c_int32_1p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 48-63] + SWISH_S32_AVX512(c_int32_1p3, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 0-15] + SWISH_S32_AVX512(c_int32_2p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 16-31] + SWISH_S32_AVX512(c_int32_2p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 32-47] + SWISH_S32_AVX512(c_int32_2p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 48-63] + SWISH_S32_AVX512(c_int32_2p3, fl_reg, al, al_in, r, r2, z, dn, selector2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_3x64_DISABLE: @@ -2568,20 +2978,22 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_2x64) &&POST_OPS_GELU_TANH_2x64, &&POST_OPS_GELU_ERF_2x64, &&POST_OPS_CLIP_2x64, - &&POST_OPS_DOWNSCALE_2x64 + &&POST_OPS_DOWNSCALE_2x64, + &&POST_OPS_MATRIX_ADD_2x64, + &&POST_OPS_SWISH_2x64 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; // B matrix storage. - __m512i b0; - __m512i b1; - __m512i b2; - __m512i b3; + __m512i b0 = _mm512_setzero_epi32(); + __m512i b1 = _mm512_setzero_epi32(); + __m512i b2 = _mm512_setzero_epi32(); + __m512i b3 = _mm512_setzero_epi32(); // A matrix storage. - __m512i a_int32_0; - __m512i a_int32_1; + __m512i a_int32_0 = _mm512_setzero_epi32(); + __m512i a_int32_1 = _mm512_setzero_epi32(); uint8_t cvt_uint8 = 128; __m512i vec_uint8 = _mm512_set1_epi8 (cvt_uint8); @@ -2596,11 +3008,20 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_2x64) __m512i c_int32_1p1 = _mm512_setzero_epi32(); __m512i c_int32_1p2 = _mm512_setzero_epi32(); __m512i c_int32_1p3 = _mm512_setzero_epi32(); + // gcc compiler (atleast 11.2 to 13.1) avoid loading B into + // registers while generating the code. A dummy shuffle instruction + // is used on b data to explicitly specify to gcc compiler + // b data needs to be kept in registers to reuse across FMA's + __m512i dsmask = _mm512_set_epi64( + 0x0F0E0D0C0B0A0908, 0x0706050403020100, + 0x0F0E0D0C0B0A0908, 0x0706050403020100, + 0x0F0E0D0C0B0A0908, 0x0706050403020100, + 0x0F0E0D0C0B0A0908, 0x0706050403020100); for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) { b0 = _mm512_loadu_si512( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - + b0 = _mm512_shuffle_epi8( b0, dsmask); // Broadcast a[0,kr:kr+4]. a_int32_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); @@ -2608,8 +3029,11 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_2x64) a_int32_0 = _mm512_add_epi8( a_int32_0, vec_uint8 ); b1 = _mm512_loadu_si512( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b1 = _mm512_shuffle_epi8( b1, dsmask ); b2 = _mm512_loadu_si512( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + b2 = _mm512_shuffle_epi8( b2, dsmask ); b3 = _mm512_loadu_si512( b + ( rs_b * kr ) + ( cs_b * 3 ) ); + b3 = _mm512_shuffle_epi8( b3, dsmask ); // Perform column direction mat-mul with k = 4. // c[0,0-63] = a[0,kr:kr+4]*b[kr:kr+4,0-63] @@ -2951,37 +3375,68 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_2x64) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_2x64: { - selector1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ); - a_int32_0 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 2 * 16 ) ); - a_int32_1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 3 * 16 ) ); - __m128i zero_point0 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); - __m128i zero_point1 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); - __m128i zero_point2 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); - __m128i zero_point3 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 3 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + a_int32_1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + a_int32_0 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + a_int32_1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point0 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point1 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point2 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point3 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + + // int8_t zero point value. + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + zero_point2 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + zero_point3 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point1 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point2 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point3 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } // c[0, 0-15] CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); @@ -3007,6 +3462,66 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_2x64) // c[1, 48-63] CVT_MULRND_CVT32(c_int32_1p3,a_int32_1,zero_point3); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_2x64: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + S8_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,0); + + // c[1:0-15,16-31,32-47,48-63] + S8_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,1); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + S32_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,0); + + // c[1:0-15,16-31,32-47,48-63] + S32_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,1); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_2x64: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + // c[0, 0-15] + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 16-31] + SWISH_S32_AVX512(c_int32_0p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 32-47] + SWISH_S32_AVX512(c_int32_0p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 48-63] + SWISH_S32_AVX512(c_int32_0p3, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 0-15] + SWISH_S32_AVX512(c_int32_1p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 16-31] + SWISH_S32_AVX512(c_int32_1p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 32-47] + SWISH_S32_AVX512(c_int32_1p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 48-63] + SWISH_S32_AVX512(c_int32_1p3, fl_reg, al, al_in, r, r2, z, dn, selector2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_2x64_DISABLE: @@ -3085,20 +3600,22 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_1x64) &&POST_OPS_GELU_TANH_1x64, &&POST_OPS_GELU_ERF_1x64, &&POST_OPS_CLIP_1x64, - &&POST_OPS_DOWNSCALE_1x64 + &&POST_OPS_DOWNSCALE_1x64, + &&POST_OPS_MATRIX_ADD_1x64, + &&POST_OPS_SWISH_1x64 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; // B matrix storage. - __m512i b0; - __m512i b1; - __m512i b2; - __m512i b3; + __m512i b0 = _mm512_setzero_epi32(); + __m512i b1 = _mm512_setzero_epi32(); + __m512i b2 = _mm512_setzero_epi32(); + __m512i b3 = _mm512_setzero_epi32(); // A matrix storage. - __m512i a_int32_0; - __m512i a_int32_1; + __m512i a_int32_0 = _mm512_setzero_epi32(); + __m512i a_int32_1 = _mm512_setzero_epi32(); uint8_t cvt_uint8 = 128; __m512i vec_uint8 = _mm512_set1_epi8 (cvt_uint8); @@ -3116,7 +3633,7 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_1x64) // Broadcast a[0,kr] a_int32_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); - //convert signed int8 to uint8 for VNNI + //convert signed int8 to uint8 for VNNI a_int32_0 = _mm512_add_epi8( a_int32_0, vec_uint8 ); b1 = _mm512_loadu_si512( b + ( rs_b * kr ) + ( cs_b * 1 ) ); @@ -3124,7 +3641,7 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_1x64) b3 = _mm512_loadu_si512( b + ( rs_b * kr ) + ( cs_b * 3 ) ); // Perform column direction mat-mul with k = 4. - // c[0,0-63] = a[0,kr:kr+4]*b[kr:kr+4,0-63] + // c[0,0-63] = a[0,kr:kr+4]*b[kr:kr+4,0-63] c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); @@ -3341,37 +3858,68 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_1x64) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_1x64: { - selector1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ); - a_int32_0 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 2 * 16 ) ); - a_int32_1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 3 * 16 ) ); - __m128i zero_point0 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); - __m128i zero_point1 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); - __m128i zero_point2 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); - __m128i zero_point3 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 3 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + a_int32_1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + a_int32_0 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + a_int32_1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point0 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point1 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point2 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point3 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + + // int8_t zero point value. + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + zero_point2 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + zero_point3 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point1 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point2 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point3 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } // c[0, 0-15] CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); @@ -3385,6 +3933,48 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_1x64) // c[0, 48-63] CVT_MULRND_CVT32(c_int32_0p3,a_int32_1,zero_point3); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_1x64: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + S8_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,0); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + S32_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,0); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_1x64: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + // c[0, 0-15] + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 16-31] + SWISH_S32_AVX512(c_int32_0p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 32-47] + SWISH_S32_AVX512(c_int32_0p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 48-63] + SWISH_S32_AVX512(c_int32_0p3, fl_reg, al, al_in, r, r2, z, dn, selector2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_1x64_DISABLE: diff --git a/kernels/zen4/lpgemm/s8s8s32/lpgemm_mn_fringe_s8_amd512vnni.c b/kernels/zen4/lpgemm/s8s8s32/lpgemm_mn_fringe_s8_amd512vnni.c index ced733e131..a98cfe5e66 100644 --- a/kernels/zen4/lpgemm/s8s8s32/lpgemm_mn_fringe_s8_amd512vnni.c +++ b/kernels/zen4/lpgemm/s8s8s32/lpgemm_mn_fringe_s8_amd512vnni.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -53,7 +53,9 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_5xlt16) &&POST_OPS_GELU_TANH_5xLT16, &&POST_OPS_GELU_ERF_5xLT16, &&POST_OPS_CLIP_5xLT16, - &&POST_OPS_DOWNSCALE_5xLT16 + &&POST_OPS_DOWNSCALE_5xLT16, + &&POST_OPS_MATRIX_ADD_5xLT16, + &&POST_OPS_SWISH_5xLT16 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -259,23 +261,23 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_5xlt16) __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); // c[0,0-15] - S32_S32_BETA_OP_NLT16F_MASK(load_mask, c_int32_0p0, 0, 0, 0, \ + S32_S32_BETA_OP_NLT16F_MASK(c, load_mask, c_int32_0p0, 0, 0, 0, \ selector1, selector2); // c[1,0-15] - S32_S32_BETA_OP_NLT16F_MASK(load_mask, c_int32_1p0, 0, 1, 0, \ + S32_S32_BETA_OP_NLT16F_MASK(c, load_mask, c_int32_1p0, 0, 1, 0, \ selector1, selector2); // c[2,0-15] - S32_S32_BETA_OP_NLT16F_MASK(load_mask, c_int32_2p0, 0, 2, 0, \ + S32_S32_BETA_OP_NLT16F_MASK(c, load_mask, c_int32_2p0, 0, 2, 0, \ selector1, selector2); // c[3,0-15] - S32_S32_BETA_OP_NLT16F_MASK(load_mask, c_int32_3p0, 0, 3, 0, \ + S32_S32_BETA_OP_NLT16F_MASK(c, load_mask, c_int32_3p0, 0, 3, 0, \ selector1, selector2); // c[4,0-15] - S32_S32_BETA_OP_NLT16F_MASK(load_mask, c_int32_4p0, 0, 4, 0, \ + S32_S32_BETA_OP_NLT16F_MASK(c, load_mask, c_int32_4p0, 0, 4, 0, \ selector1, selector2); } } @@ -421,23 +423,41 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_5xlt16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_5xLT16: { // Typecast without data modification, safe operation. __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); - selector1 = _mm512_maskz_loadu_epi32 - ( - load_mask, - ( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j ) - ); - __m128i zero_point = _mm_maskz_loadu_epi8 - ( - load_mask, - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j ) - ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = _mm512_maskz_loadu_epi32 + ( + load_mask, + ( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j ) + ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point = _mm512_castsi512_si128( _mm512_setzero_si512() ); + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point = _mm_maskz_loadu_epi8 + ( + load_mask, + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ) + ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } // c[0, 0-15] CVT_MULRND_CVT32_LT16(c_int32_0p0,selector1,zero_point); @@ -454,6 +474,76 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_5xlt16) // c[4, 0-15] CVT_MULRND_CVT32_LT16(c_int32_4p0,selector1,zero_point); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_5xLT16: + { + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S8_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + S8_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,1); + + // c[2:0-15] + S8_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,2); + + // c[3:0-15] + S8_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,3); + + // c[4:0-15] + S8_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,4); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S32_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + S32_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,1); + + // c[2:0-15] + S32_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,2); + + // c[3:0-15] + S32_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,3); + + // c[4:0-15] + S32_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,4); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_5xLT16: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + // c[0, 0-15] + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 0-15] + SWISH_S32_AVX512(c_int32_1p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 0-15] + SWISH_S32_AVX512(c_int32_2p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[3, 0-15] + SWISH_S32_AVX512(c_int32_3p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[4, 0-15] + SWISH_S32_AVX512(c_int32_4p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_5xLT16_DISABLE: @@ -515,7 +605,9 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_4xlt16) &&POST_OPS_GELU_TANH_4xLT16, &&POST_OPS_GELU_ERF_4xLT16, &&POST_OPS_CLIP_4xLT16, - &&POST_OPS_DOWNSCALE_4xLT16 + &&POST_OPS_DOWNSCALE_4xLT16, + &&POST_OPS_MATRIX_ADD_4xLT16, + &&POST_OPS_SWISH_4xLT16 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -690,19 +782,19 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_4xlt16) __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); // c[0,0-15] - S32_S32_BETA_OP_NLT16F_MASK(load_mask, c_int32_0p0, 0, 0, 0, \ + S32_S32_BETA_OP_NLT16F_MASK(c, load_mask, c_int32_0p0, 0, 0, 0, \ selector1, selector2); // c[1,0-15] - S32_S32_BETA_OP_NLT16F_MASK(load_mask, c_int32_1p0, 0, 1, 0, \ + S32_S32_BETA_OP_NLT16F_MASK(c, load_mask, c_int32_1p0, 0, 1, 0, \ selector1, selector2); // c[2,0-15] - S32_S32_BETA_OP_NLT16F_MASK(load_mask, c_int32_2p0, 0, 2, 0, \ + S32_S32_BETA_OP_NLT16F_MASK(c, load_mask, c_int32_2p0, 0, 2, 0, \ selector1, selector2); // c[3,0-15] - S32_S32_BETA_OP_NLT16F_MASK(load_mask, c_int32_3p0, 0, 3, 0, \ + S32_S32_BETA_OP_NLT16F_MASK(c, load_mask, c_int32_3p0, 0, 3, 0, \ selector1, selector2); } } @@ -830,23 +922,41 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_4xlt16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_4xLT16: { // Typecast without data modification, safe operation. __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); - selector1 = _mm512_maskz_loadu_epi32 - ( - load_mask, - ( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j ) - ); - __m128i zero_point = _mm_maskz_loadu_epi8 - ( - load_mask, - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j ) - ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = _mm512_maskz_loadu_epi32 + ( + load_mask, + ( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j ) + ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point = _mm512_castsi512_si128( _mm512_setzero_si512() ); + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point = _mm_maskz_loadu_epi8 + ( + load_mask, + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ) + ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } // c[0, 0-15] CVT_MULRND_CVT32_LT16(c_int32_0p0,selector1,zero_point); @@ -860,6 +970,67 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_4xlt16) // c[3, 0-15] CVT_MULRND_CVT32_LT16(c_int32_3p0,selector1,zero_point); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_4xLT16: + { + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S8_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + S8_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,1); + + // c[2:0-15] + S8_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,2); + + // c[3:0-15] + S8_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,3); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S32_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + S32_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,1); + + // c[2:0-15] + S32_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,2); + + // c[3:0-15] + S32_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,3); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_4xLT16: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + // c[0, 0-15] + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 0-15] + SWISH_S32_AVX512(c_int32_1p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 0-15] + SWISH_S32_AVX512(c_int32_2p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[3, 0-15] + SWISH_S32_AVX512(c_int32_3p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_4xLT16_DISABLE: @@ -915,7 +1086,9 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_3xlt16) &&POST_OPS_GELU_TANH_3xLT16, &&POST_OPS_GELU_ERF_3xLT16, &&POST_OPS_CLIP_3xLT16, - &&POST_OPS_DOWNSCALE_3xLT16 + &&POST_OPS_DOWNSCALE_3xLT16, + &&POST_OPS_MATRIX_ADD_3xLT16, + &&POST_OPS_SWISH_3xLT16 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -1060,15 +1233,15 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_3xlt16) __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); // c[0,0-15] - S32_S32_BETA_OP_NLT16F_MASK(load_mask, c_int32_0p0, 0, 0, 0, \ + S32_S32_BETA_OP_NLT16F_MASK(c, load_mask, c_int32_0p0, 0, 0, 0, \ selector1, selector2); // c[1,0-15] - S32_S32_BETA_OP_NLT16F_MASK(load_mask, c_int32_1p0, 0, 1, 0, \ + S32_S32_BETA_OP_NLT16F_MASK(c, load_mask, c_int32_1p0, 0, 1, 0, \ selector1, selector2); // c[2,0-15] - S32_S32_BETA_OP_NLT16F_MASK(load_mask, c_int32_2p0, 0, 2, 0, \ + S32_S32_BETA_OP_NLT16F_MASK(c, load_mask, c_int32_2p0, 0, 2, 0, \ selector1, selector2); } } @@ -1178,23 +1351,41 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_3xlt16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_3xLT16: { // Typecast without data modification, safe operation. __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); - selector1 = _mm512_maskz_loadu_epi32 - ( - load_mask, - ( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j ) - ); - __m128i zero_point = _mm_maskz_loadu_epi8 - ( - load_mask, - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j ) - ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = _mm512_maskz_loadu_epi32 + ( + load_mask, + ( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j ) + ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point = _mm512_castsi512_si128( _mm512_setzero_si512() ); + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point = _mm_maskz_loadu_epi8 + ( + load_mask, + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ) + ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } // c[0, 0-15] CVT_MULRND_CVT32_LT16(c_int32_0p0,selector1,zero_point); @@ -1205,6 +1396,58 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_3xlt16) // c[2, 0-15] CVT_MULRND_CVT32_LT16(c_int32_2p0,selector1,zero_point); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_3xLT16: + { + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S8_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + S8_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,1); + + // c[2:0-15] + S8_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,2); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S32_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + S32_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,1); + + // c[2:0-15] + S32_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,2); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_3xLT16: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + // c[0, 0-15] + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 0-15] + SWISH_S32_AVX512(c_int32_1p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 0-15] + SWISH_S32_AVX512(c_int32_2p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_3xLT16_DISABLE: @@ -1254,7 +1497,9 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_2xlt16) &&POST_OPS_GELU_TANH_2xLT16, &&POST_OPS_GELU_ERF_2xLT16, &&POST_OPS_CLIP_2xLT16, - &&POST_OPS_DOWNSCALE_2xLT16 + &&POST_OPS_DOWNSCALE_2xLT16, + &&POST_OPS_MATRIX_ADD_2xLT16, + &&POST_OPS_SWISH_2xLT16 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -1369,11 +1614,11 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_2xlt16) __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); // c[0,0-15] - S32_S32_BETA_OP_NLT16F_MASK(load_mask, c_int32_0p0, 0, 0, 0, \ + S32_S32_BETA_OP_NLT16F_MASK(c, load_mask, c_int32_0p0, 0, 0, 0, \ selector1, selector2); // c[1,0-15] - S32_S32_BETA_OP_NLT16F_MASK(load_mask, c_int32_1p0, 0, 1, 0, \ + S32_S32_BETA_OP_NLT16F_MASK(c, load_mask, c_int32_1p0, 0, 1, 0, \ selector1, selector2); } } @@ -1465,23 +1710,41 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_2xlt16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_2xLT16: { // Typecast without data modification, safe operation. __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); - selector1 = _mm512_maskz_loadu_epi32 - ( - load_mask, - ( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j ) - ); - __m128i zero_point = _mm_maskz_loadu_epi8 - ( - load_mask, - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j ) - ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = _mm512_maskz_loadu_epi32 + ( + load_mask, + ( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j ) + ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point = _mm512_castsi512_si128( _mm512_setzero_si512() ); + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point = _mm_maskz_loadu_epi8 + ( + load_mask, + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ) + ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } // c[0, 0-15] CVT_MULRND_CVT32_LT16(c_int32_0p0,selector1,zero_point); @@ -1489,6 +1752,49 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_2xlt16) // c[1, 0-15] CVT_MULRND_CVT32_LT16(c_int32_1p0,selector1,zero_point); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_2xLT16: + { + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S8_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + S8_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,1); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S32_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + S32_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,1); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_2xLT16: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + // c[0, 0-15] + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 0-15] + SWISH_S32_AVX512(c_int32_1p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_2xLT16_DISABLE: @@ -1532,7 +1838,9 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_1xlt16) &&POST_OPS_GELU_TANH_1xLT16, &&POST_OPS_GELU_ERF_1xLT16, &&POST_OPS_CLIP_1xLT16, - &&POST_OPS_DOWNSCALE_1xLT16 + &&POST_OPS_DOWNSCALE_1xLT16, + &&POST_OPS_MATRIX_ADD_1xLT16, + &&POST_OPS_SWISH_1xLT16 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -1617,7 +1925,7 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_1xlt16) __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); // c[0,0-15] - S32_S32_BETA_OP_NLT16F_MASK(load_mask, c_int32_0p0, 0, 0, 0, \ + S32_S32_BETA_OP_NLT16F_MASK(c, load_mask, c_int32_0p0, 0, 0, 0, \ selector1, selector2); } } @@ -1691,27 +1999,79 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_1xlt16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_1xLT16: { // Typecast without data modification, safe operation. __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); - selector1 = _mm512_maskz_loadu_epi32 - ( - load_mask, - ( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j ) - ); - __m128i zero_point = _mm_maskz_loadu_epi8 - ( - load_mask, - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j ) - ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = _mm512_maskz_loadu_epi32 + ( + load_mask, + ( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j ) + ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point = _mm512_castsi512_si128( _mm512_setzero_si512() ); + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point = _mm_maskz_loadu_epi8 + ( + load_mask, + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ) + ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } // c[0, 0-15] CVT_MULRND_CVT32_LT16(c_int32_0p0,selector1,zero_point); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_1xLT16: + { + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S8_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,0); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S32_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,0); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_1xLT16: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + // c[0, 0-15] + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_1xLT16_DISABLE: @@ -1749,7 +2109,9 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_5x16) &&POST_OPS_GELU_TANH_5x16, &&POST_OPS_GELU_ERF_5x16, &&POST_OPS_CLIP_5x16, - &&POST_OPS_DOWNSCALE_5x16 + &&POST_OPS_DOWNSCALE_5x16, + &&POST_OPS_MATRIX_ADD_5x16, + &&POST_OPS_SWISH_5x16 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -2096,16 +2458,33 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_5x16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_5x16: { - selector1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - __m128i zero_point0 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point0 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } // c[0, 0-15] CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); @@ -2124,15 +2503,84 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_5x16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } -POST_OPS_5x16_DISABLE: - ; - - if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_last_k == TRUE ) ) +POST_OPS_MATRIX_ADD_5x16: { - // Generate a mask16 of all 1's. - selector1 = _mm512_setzero_epi32(); - selector2 = _mm512_set1_epi32( 10 ); - __mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector1, selector2 ); + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S8_S32_MATRIX_ADD_1COL(selector1,0); + + // c[1:0-15] + S8_S32_MATRIX_ADD_1COL(selector1,1); + + // c[2:0-15] + S8_S32_MATRIX_ADD_1COL(selector1,2); + + // c[3:0-15] + S8_S32_MATRIX_ADD_1COL(selector1,3); + + // c[4:0-15] + S8_S32_MATRIX_ADD_1COL(selector1,4); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S32_S32_MATRIX_ADD_1COL(selector1,0); + + // c[1:0-15] + S32_S32_MATRIX_ADD_1COL(selector1,1); + + // c[2:0-15] + S32_S32_MATRIX_ADD_1COL(selector1,2); + + // c[3:0-15] + S32_S32_MATRIX_ADD_1COL(selector1,3); + + // c[4:0-15] + S32_S32_MATRIX_ADD_1COL(selector1,4); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_5x16: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + // c[0, 0-15] + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 0-15] + SWISH_S32_AVX512(c_int32_1p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 0-15] + SWISH_S32_AVX512(c_int32_2p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[3, 0-15] + SWISH_S32_AVX512(c_int32_3p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[4, 0-15] + SWISH_S32_AVX512(c_int32_4p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_5x16_DISABLE: + ; + + if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_last_k == TRUE ) ) + { + // Generate a mask16 of all 1's. + selector1 = _mm512_setzero_epi32(); + selector2 = _mm512_set1_epi32( 10 ); + __mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector1, selector2 ); // Store the results in downscaled type (int8 instead of int32). // c[0,0-15] @@ -2182,7 +2630,9 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_4x16) &&POST_OPS_GELU_TANH_4x16, &&POST_OPS_GELU_ERF_4x16, &&POST_OPS_CLIP_4x16, - &&POST_OPS_DOWNSCALE_4x16 + &&POST_OPS_DOWNSCALE_4x16, + &&POST_OPS_MATRIX_ADD_4x16, + &&POST_OPS_SWISH_4x16 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -2479,16 +2929,33 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_4x16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_4x16: { - selector1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - __m128i zero_point0 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point0 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } // c[0, 0-15] CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); @@ -2502,6 +2969,66 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_4x16) // c[3, 0-15] CVT_MULRND_CVT32(c_int32_3p0,selector1,zero_point0); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_4x16: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S8_S32_MATRIX_ADD_1COL(selector1,0); + + // c[1:0-15] + S8_S32_MATRIX_ADD_1COL(selector1,1); + + // c[2:0-15] + S8_S32_MATRIX_ADD_1COL(selector1,2); + + // c[3:0-15] + S8_S32_MATRIX_ADD_1COL(selector1,3); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S32_S32_MATRIX_ADD_1COL(selector1,0); + + // c[1:0-15] + S32_S32_MATRIX_ADD_1COL(selector1,1); + + // c[2:0-15] + S32_S32_MATRIX_ADD_1COL(selector1,2); + + // c[3:0-15] + S32_S32_MATRIX_ADD_1COL(selector1,3); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_4x16: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + // c[0, 0-15] + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 0-15] + SWISH_S32_AVX512(c_int32_1p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 0-15] + SWISH_S32_AVX512(c_int32_2p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[3, 0-15] + SWISH_S32_AVX512(c_int32_3p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_4x16_DISABLE: @@ -2556,7 +3083,9 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_3x16) &&POST_OPS_GELU_TANH_3x16, &&POST_OPS_GELU_ERF_3x16, &&POST_OPS_CLIP_3x16, - &&POST_OPS_DOWNSCALE_3x16 + &&POST_OPS_DOWNSCALE_3x16, + &&POST_OPS_MATRIX_ADD_3x16, + &&POST_OPS_SWISH_3x16 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -2803,16 +3332,33 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_3x16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_3x16: { - selector1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - __m128i zero_point0 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point0 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } // c[0, 0-15] CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); @@ -2823,6 +3369,57 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_3x16) // c[2, 0-15] CVT_MULRND_CVT32(c_int32_2p0,selector1,zero_point0); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_3x16: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S8_S32_MATRIX_ADD_1COL(selector1,0); + + // c[1:0-15] + S8_S32_MATRIX_ADD_1COL(selector1,1); + + // c[2:0-15] + S8_S32_MATRIX_ADD_1COL(selector1,2); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S32_S32_MATRIX_ADD_1COL(selector1,0); + + // c[1:0-15] + S32_S32_MATRIX_ADD_1COL(selector1,1); + + // c[2:0-15] + S32_S32_MATRIX_ADD_1COL(selector1,2); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_3x16: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + // c[0, 0-15] + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 0-15] + SWISH_S32_AVX512(c_int32_1p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 0-15] + SWISH_S32_AVX512(c_int32_2p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_3x16_DISABLE: @@ -2871,7 +3468,9 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_2x16) &&POST_OPS_GELU_TANH_2x16, &&POST_OPS_GELU_ERF_2x16, &&POST_OPS_CLIP_2x16, - &&POST_OPS_DOWNSCALE_2x16 + &&POST_OPS_DOWNSCALE_2x16, + &&POST_OPS_MATRIX_ADD_2x16, + &&POST_OPS_SWISH_2x16 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -3068,16 +3667,33 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_2x16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_2x16: { - selector1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - __m128i zero_point0 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point0 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } // c[0, 0-15] CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); @@ -3085,6 +3701,48 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_2x16) // c[1, 0-15] CVT_MULRND_CVT32(c_int32_1p0,selector1,zero_point0); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_2x16: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S8_S32_MATRIX_ADD_1COL(selector1,0); + + // c[1:0-15] + S8_S32_MATRIX_ADD_1COL(selector1,1); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S32_S32_MATRIX_ADD_1COL(selector1,0); + + // c[1:0-15] + S32_S32_MATRIX_ADD_1COL(selector1,1); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_2x16: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + // c[0, 0-15] + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 0-15] + SWISH_S32_AVX512(c_int32_1p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_2x16_DISABLE: @@ -3127,7 +3785,9 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_1x16) &&POST_OPS_GELU_TANH_1x16, &&POST_OPS_GELU_ERF_1x16, &&POST_OPS_CLIP_1x16, - &&POST_OPS_DOWNSCALE_1x16 + &&POST_OPS_DOWNSCALE_1x16, + &&POST_OPS_MATRIX_ADD_1x16, + &&POST_OPS_SWISH_1x16 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -3274,20 +3934,70 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_1x16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_1x16: { - selector1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - __m128i zero_point0 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point0 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } // c[0, 0-15] CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_1x16: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S8_S32_MATRIX_ADD_1COL(selector1,0); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S32_S32_MATRIX_ADD_1COL(selector1,0); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_1x16: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + // c[0, 0-15] + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_1x16_DISABLE: @@ -3324,17 +4034,19 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_5x32) &&POST_OPS_GELU_TANH_5x32, &&POST_OPS_GELU_ERF_5x32, &&POST_OPS_CLIP_5x32, - &&POST_OPS_DOWNSCALE_5x32 + &&POST_OPS_DOWNSCALE_5x32, + &&POST_OPS_MATRIX_ADD_5x32, + &&POST_OPS_SWISH_5x32 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; // B matrix storage. - __m512i b0; - __m512i b1; + __m512i b0 = _mm512_setzero_epi32(); + __m512i b1 = _mm512_setzero_epi32(); // A matrix storage. - __m512i a_int32_0; + __m512i a_int32_0 = _mm512_setzero_epi32(); uint8_t cvt_uint8 = 128; __m512i vec_uint8 = _mm512_set1_epi8 (cvt_uint8); @@ -3801,23 +4513,44 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_5x32) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_5x32: { - selector1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ); - __m128i zero_point0 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); - __m128i zero_point1 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point0 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point1 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point1 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } // c[0, 0-15] CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); @@ -3849,6 +4582,90 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_5x32) // c[4, 16-31] CVT_MULRND_CVT32(c_int32_4p1,selector2,zero_point1); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_5x32: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + S8_S32_MATRIX_ADD_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + S8_S32_MATRIX_ADD_2COL(selector1,selector2,1); + + // c[2:0-15,16-31] + S8_S32_MATRIX_ADD_2COL(selector1,selector2,2); + + // c[3:0-15,16-31] + S8_S32_MATRIX_ADD_2COL(selector1,selector2,3); + + // c[4:0-15,16-31] + S8_S32_MATRIX_ADD_2COL(selector1,selector2,4); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + S32_S32_MATRIX_ADD_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + S32_S32_MATRIX_ADD_2COL(selector1,selector2,1); + + // c[2:0-15,16-31] + S32_S32_MATRIX_ADD_2COL(selector1,selector2,2); + + // c[3:0-15,16-31] + S32_S32_MATRIX_ADD_2COL(selector1,selector2,3); + + // c[4:0-15,16-31] + S32_S32_MATRIX_ADD_2COL(selector1,selector2,4); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_5x32: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + // c[0, 0-15] + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 16-31] + SWISH_S32_AVX512(c_int32_0p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 0-15] + SWISH_S32_AVX512(c_int32_1p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 16-31] + SWISH_S32_AVX512(c_int32_1p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 0-15] + SWISH_S32_AVX512(c_int32_2p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 16-31] + SWISH_S32_AVX512(c_int32_2p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[3, 0-15] + SWISH_S32_AVX512(c_int32_3p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[3, 16-31] + SWISH_S32_AVX512(c_int32_3p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[4, 0-15] + SWISH_S32_AVX512(c_int32_4p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[4, 16-31] + SWISH_S32_AVX512(c_int32_4p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_5x32_DISABLE: @@ -3939,17 +4756,19 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_4x32) &&POST_OPS_GELU_TANH_4x32, &&POST_OPS_GELU_ERF_4x32, &&POST_OPS_CLIP_4x32, - &&POST_OPS_DOWNSCALE_4x32 + &&POST_OPS_DOWNSCALE_4x32, + &&POST_OPS_MATRIX_ADD_4x32, + &&POST_OPS_SWISH_4x32 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; // B matrix storage. - __m512i b0; - __m512i b1; + __m512i b0 = _mm512_setzero_epi32(); + __m512i b1 = _mm512_setzero_epi32(); // A matrix storage. - __m512i a_int32_0; + __m512i a_int32_0 = _mm512_setzero_epi32(); uint8_t cvt_uint8 = 128; __m512i vec_uint8 = _mm512_set1_epi8 (cvt_uint8); @@ -4335,55 +5154,148 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_4x32) // c[2, 16-31] CLIP_S32_AVX512(c_int32_2p1, min, max) - // c[3, 0-15] - CLIP_S32_AVX512(c_int32_3p0, min, max) + // c[3, 0-15] + CLIP_S32_AVX512(c_int32_3p0, min, max) + + // c[3, 16-31] + CLIP_S32_AVX512(c_int32_3p1, min, max) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_4x32: + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point0 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point1 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point1 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } + + // c[0, 0-15] + CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); + + // c[0, 16-31] + CVT_MULRND_CVT32(c_int32_0p1,selector2,zero_point1); + + // c[1, 0-15] + CVT_MULRND_CVT32(c_int32_1p0,selector1,zero_point0); + + // c[1, 16-31] + CVT_MULRND_CVT32(c_int32_1p1,selector2,zero_point1); + + // c[2, 0-15] + CVT_MULRND_CVT32(c_int32_2p0,selector1,zero_point0); + + // c[2, 16-31] + CVT_MULRND_CVT32(c_int32_2p1,selector2,zero_point1); + + // c[3, 0-15] + CVT_MULRND_CVT32(c_int32_3p0,selector1,zero_point0); + + // c[3, 16-31] + CVT_MULRND_CVT32(c_int32_3p1,selector2,zero_point1); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_4x32: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + S8_S32_MATRIX_ADD_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + S8_S32_MATRIX_ADD_2COL(selector1,selector2,1); + + // c[2:0-15,16-31] + S8_S32_MATRIX_ADD_2COL(selector1,selector2,2); + + // c[3:0-15,16-31] + S8_S32_MATRIX_ADD_2COL(selector1,selector2,3); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + S32_S32_MATRIX_ADD_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + S32_S32_MATRIX_ADD_2COL(selector1,selector2,1); + + // c[2:0-15,16-31] + S32_S32_MATRIX_ADD_2COL(selector1,selector2,2); - // c[3, 16-31] - CLIP_S32_AVX512(c_int32_3p1, min, max) + // c[3:0-15,16-31] + S32_S32_MATRIX_ADD_2COL(selector1,selector2,3); + } POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - -POST_OPS_DOWNSCALE_4x32: +POST_OPS_SWISH_4x32: { selector1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ); - __m128i zero_point0 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); - __m128i zero_point1 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; // c[0, 0-15] - CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); // c[0, 16-31] - CVT_MULRND_CVT32(c_int32_0p1,selector2,zero_point1); + SWISH_S32_AVX512(c_int32_0p1, fl_reg, al, al_in, r, r2, z, dn, selector2); // c[1, 0-15] - CVT_MULRND_CVT32(c_int32_1p0,selector1,zero_point0); + SWISH_S32_AVX512(c_int32_1p0, fl_reg, al, al_in, r, r2, z, dn, selector2); // c[1, 16-31] - CVT_MULRND_CVT32(c_int32_1p1,selector2,zero_point1); + SWISH_S32_AVX512(c_int32_1p1, fl_reg, al, al_in, r, r2, z, dn, selector2); // c[2, 0-15] - CVT_MULRND_CVT32(c_int32_2p0,selector1,zero_point0); + SWISH_S32_AVX512(c_int32_2p0, fl_reg, al, al_in, r, r2, z, dn, selector2); // c[2, 16-31] - CVT_MULRND_CVT32(c_int32_2p1,selector2,zero_point1); + SWISH_S32_AVX512(c_int32_2p1, fl_reg, al, al_in, r, r2, z, dn, selector2); // c[3, 0-15] - CVT_MULRND_CVT32(c_int32_3p0,selector1,zero_point0); + SWISH_S32_AVX512(c_int32_3p0, fl_reg, al, al_in, r, r2, z, dn, selector2); // c[3, 16-31] - CVT_MULRND_CVT32(c_int32_3p1,selector2,zero_point1); + SWISH_S32_AVX512(c_int32_3p1, fl_reg, al, al_in, r, r2, z, dn, selector2); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -4463,17 +5375,19 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_3x32) &&POST_OPS_GELU_TANH_3x32, &&POST_OPS_GELU_ERF_3x32, &&POST_OPS_CLIP_3x32, - &&POST_OPS_DOWNSCALE_3x32 + &&POST_OPS_DOWNSCALE_3x32, + &&POST_OPS_MATRIX_ADD_3x32, + &&POST_OPS_SWISH_3x32 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; // B matrix storage. - __m512i b0; - __m512i b1; + __m512i b0 = _mm512_setzero_epi32(); + __m512i b1 = _mm512_setzero_epi32(); // A matrix storage. - __m512i a_int32_0; + __m512i a_int32_0 = _mm512_setzero_epi32(); uint8_t cvt_uint8 = 128; __m512i vec_uint8 = _mm512_set1_epi8 (cvt_uint8); @@ -4794,23 +5708,44 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_3x32) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_3x32: { - selector1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ); - __m128i zero_point0 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); - __m128i zero_point1 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point0 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point1 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point1 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } // c[0, 0-15] CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); @@ -4830,6 +5765,66 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_3x32) // c[2, 16-31] CVT_MULRND_CVT32(c_int32_2p1,selector2,zero_point1); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_3x32: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + S8_S32_MATRIX_ADD_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + S8_S32_MATRIX_ADD_2COL(selector1,selector2,1); + + // c[2:0-15,16-31] + S8_S32_MATRIX_ADD_2COL(selector1,selector2,2); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + S32_S32_MATRIX_ADD_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + S32_S32_MATRIX_ADD_2COL(selector1,selector2,1); + + // c[2:0-15,16-31] + S32_S32_MATRIX_ADD_2COL(selector1,selector2,2); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_3x32: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + // c[0, 0-15] + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 16-31] + SWISH_S32_AVX512(c_int32_0p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 0-15] + SWISH_S32_AVX512(c_int32_1p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 16-31] + SWISH_S32_AVX512(c_int32_1p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 0-15] + SWISH_S32_AVX512(c_int32_2p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 16-31] + SWISH_S32_AVX512(c_int32_2p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_3x32_DISABLE: @@ -4896,17 +5891,19 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_2x32) &&POST_OPS_GELU_TANH_2x32, &&POST_OPS_GELU_ERF_2x32, &&POST_OPS_CLIP_2x32, - &&POST_OPS_DOWNSCALE_2x32 + &&POST_OPS_DOWNSCALE_2x32, + &&POST_OPS_MATRIX_ADD_2x32, + &&POST_OPS_SWISH_2x32 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; // B matrix storage. - __m512i b0; - __m512i b1; + __m512i b0 = _mm512_setzero_epi32(); + __m512i b1 = _mm512_setzero_epi32(); // A matrix storage. - __m512i a_int32_0; + __m512i a_int32_0 = _mm512_setzero_epi32(); uint8_t cvt_uint8 = 128; __m512i vec_uint8 = _mm512_set1_epi8 (cvt_uint8); @@ -5154,23 +6151,44 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_2x32) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_2x32: { - selector1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ); - __m128i zero_point0 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); - __m128i zero_point1 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point0 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point1 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point1 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } // c[0, 0-15] CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); @@ -5184,6 +6202,54 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_2x32) // c[1, 16-31] CVT_MULRND_CVT32(c_int32_1p1,selector2,zero_point1); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_2x32: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + S8_S32_MATRIX_ADD_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + S8_S32_MATRIX_ADD_2COL(selector1,selector2,1); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + S32_S32_MATRIX_ADD_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + S32_S32_MATRIX_ADD_2COL(selector1,selector2,1); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_2x32: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + // c[0, 0-15] + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 16-31] + SWISH_S32_AVX512(c_int32_0p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 0-15] + SWISH_S32_AVX512(c_int32_1p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 16-31] + SWISH_S32_AVX512(c_int32_1p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_2x32_DISABLE: @@ -5238,17 +6304,19 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_1x32) &&POST_OPS_GELU_TANH_1x32, &&POST_OPS_GELU_ERF_1x32, &&POST_OPS_CLIP_1x32, - &&POST_OPS_DOWNSCALE_1x32 + &&POST_OPS_DOWNSCALE_1x32, + &&POST_OPS_MATRIX_ADD_1x32, + &&POST_OPS_SWISH_1x32 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; // B matrix storage. - __m512i b0; - __m512i b1; + __m512i b0 = _mm512_setzero_epi32(); + __m512i b1 = _mm512_setzero_epi32(); // A matrix storage. - __m512i a_int32_0; + __m512i a_int32_0 = _mm512_setzero_epi32(); uint8_t cvt_uint8 = 128; __m512i vec_uint8 = _mm512_set1_epi8 (cvt_uint8); @@ -5423,23 +6491,44 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_1x32) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_1x32: { - selector1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ); - __m128i zero_point0 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); - __m128i zero_point1 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point0 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point1 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point1 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } // c[0, 0-15] CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); @@ -5447,6 +6536,42 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_1x32) // c[0, 16-31] CVT_MULRND_CVT32(c_int32_0p1,selector2,zero_point1); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_1x32: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + S8_S32_MATRIX_ADD_2COL(selector1,selector2,0); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + S32_S32_MATRIX_ADD_2COL(selector1,selector2,0); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_1x32: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + // c[0, 0-15] + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 16-31] + SWISH_S32_AVX512(c_int32_0p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_1x32_DISABLE: @@ -5489,18 +6614,20 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_5x48) &&POST_OPS_GELU_TANH_5x48, &&POST_OPS_GELU_ERF_5x48, &&POST_OPS_CLIP_5x48, - &&POST_OPS_DOWNSCALE_5x48 + &&POST_OPS_DOWNSCALE_5x48, + &&POST_OPS_MATRIX_ADD_5x48, + &&POST_OPS_SWISH_5x48 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; // B matrix storage. - __m512i b0; - __m512i b1; - __m512i b2; + __m512i b0 = _mm512_setzero_epi32(); + __m512i b1 = _mm512_setzero_epi32(); + __m512i b2 = _mm512_setzero_epi32(); // A matrix storage. - __m512i a_int32_0; + __m512i a_int32_0 = _mm512_setzero_epi32(); uint8_t cvt_uint8 = 128; __m512i vec_uint8 = _mm512_set1_epi8 (cvt_uint8); @@ -6090,75 +7217,199 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_5x48) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_5x48: + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + a_int32_0 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point0 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point1 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point2 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + zero_point2 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point1 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point2 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } + + // c[0, 0-15] + CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); + + // c[0, 16-31] + CVT_MULRND_CVT32(c_int32_0p1,selector2,zero_point1); + + // c[0, 32-47] + CVT_MULRND_CVT32(c_int32_0p2,a_int32_0,zero_point2); + + // c[1, 0-15] + CVT_MULRND_CVT32(c_int32_1p0,selector1,zero_point0); + + // c[1, 16-31] + CVT_MULRND_CVT32(c_int32_1p1,selector2,zero_point1); + + // c[1, 32-47] + CVT_MULRND_CVT32(c_int32_1p2,a_int32_0,zero_point2); + + // c[2, 0-15] + CVT_MULRND_CVT32(c_int32_2p0,selector1,zero_point0); + + // c[2, 16-31] + CVT_MULRND_CVT32(c_int32_2p1,selector2,zero_point1); + + // c[2, 32-47] + CVT_MULRND_CVT32(c_int32_2p2,a_int32_0,zero_point2); + + // c[3, 0-15] + CVT_MULRND_CVT32(c_int32_3p0,selector1,zero_point0); + + // c[3, 16-31] + CVT_MULRND_CVT32(c_int32_3p1,selector2,zero_point1); + + // c[3, 32-47] + CVT_MULRND_CVT32(c_int32_3p2,a_int32_0,zero_point2); + + // c[4, 0-15] + CVT_MULRND_CVT32(c_int32_4p0,selector1,zero_point0); + + // c[4, 16-31] + CVT_MULRND_CVT32(c_int32_4p1,selector2,zero_point1); + + // c[4, 32-47] + CVT_MULRND_CVT32(c_int32_4p2,a_int32_0,zero_point2); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_5x48: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + S8_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,0); + + // c[1:0-15,16-31,32-47] + S8_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,1); + + // c[2:0-15,16-31,32-47] + S8_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,2); + + // c[3:0-15,16-31,32-47] + S8_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,3); + + // c[4:0-15,16-31,32-47] + S8_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,4); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + S32_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,0); + + // c[1:0-15,16-31,32-47] + S32_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,1); + + // c[2:0-15,16-31,32-47] + S32_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,2); + + // c[3:0-15,16-31,32-47] + S32_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,3); + + // c[4:0-15,16-31,32-47] + S32_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,4); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_5x48: { selector1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ); - a_int32_0 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 2 * 16 ) ); - __m128i zero_point0 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); - __m128i zero_point1 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); - __m128i zero_point2 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; // c[0, 0-15] - CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); // c[0, 16-31] - CVT_MULRND_CVT32(c_int32_0p1,selector2,zero_point1); + SWISH_S32_AVX512(c_int32_0p1, fl_reg, al, al_in, r, r2, z, dn, selector2); // c[0, 32-47] - CVT_MULRND_CVT32(c_int32_0p2,a_int32_0,zero_point2); + SWISH_S32_AVX512(c_int32_0p2, fl_reg, al, al_in, r, r2, z, dn, selector2); // c[1, 0-15] - CVT_MULRND_CVT32(c_int32_1p0,selector1,zero_point0); + SWISH_S32_AVX512(c_int32_1p0, fl_reg, al, al_in, r, r2, z, dn, selector2); // c[1, 16-31] - CVT_MULRND_CVT32(c_int32_1p1,selector2,zero_point1); + SWISH_S32_AVX512(c_int32_1p1, fl_reg, al, al_in, r, r2, z, dn, selector2); // c[1, 32-47] - CVT_MULRND_CVT32(c_int32_1p2,a_int32_0,zero_point2); + SWISH_S32_AVX512(c_int32_1p2, fl_reg, al, al_in, r, r2, z, dn, selector2); // c[2, 0-15] - CVT_MULRND_CVT32(c_int32_2p0,selector1,zero_point0); + SWISH_S32_AVX512(c_int32_2p0, fl_reg, al, al_in, r, r2, z, dn, selector2); // c[2, 16-31] - CVT_MULRND_CVT32(c_int32_2p1,selector2,zero_point1); + SWISH_S32_AVX512(c_int32_2p1, fl_reg, al, al_in, r, r2, z, dn, selector2); // c[2, 32-47] - CVT_MULRND_CVT32(c_int32_2p2,a_int32_0,zero_point2); + SWISH_S32_AVX512(c_int32_2p2, fl_reg, al, al_in, r, r2, z, dn, selector2); // c[3, 0-15] - CVT_MULRND_CVT32(c_int32_3p0,selector1,zero_point0); + SWISH_S32_AVX512(c_int32_3p0, fl_reg, al, al_in, r, r2, z, dn, selector2); // c[3, 16-31] - CVT_MULRND_CVT32(c_int32_3p1,selector2,zero_point1); + SWISH_S32_AVX512(c_int32_3p1, fl_reg, al, al_in, r, r2, z, dn, selector2); // c[3, 32-47] - CVT_MULRND_CVT32(c_int32_3p2,a_int32_0,zero_point2); + SWISH_S32_AVX512(c_int32_3p2, fl_reg, al, al_in, r, r2, z, dn, selector2); // c[4, 0-15] - CVT_MULRND_CVT32(c_int32_4p0,selector1,zero_point0); + SWISH_S32_AVX512(c_int32_4p0, fl_reg, al, al_in, r, r2, z, dn, selector2); // c[4, 16-31] - CVT_MULRND_CVT32(c_int32_4p1,selector2,zero_point1); + SWISH_S32_AVX512(c_int32_4p1, fl_reg, al, al_in, r, r2, z, dn, selector2); // c[4, 32-47] - CVT_MULRND_CVT32(c_int32_4p2,a_int32_0,zero_point2); + SWISH_S32_AVX512(c_int32_4p2, fl_reg, al, al_in, r, r2, z, dn, selector2); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -6280,18 +7531,20 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_4x48) &&POST_OPS_GELU_TANH_4x48, &&POST_OPS_GELU_ERF_4x48, &&POST_OPS_CLIP_4x48, - &&POST_OPS_DOWNSCALE_4x48 + &&POST_OPS_DOWNSCALE_4x48, + &&POST_OPS_MATRIX_ADD_4x48, + &&POST_OPS_SWISH_4x48 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; // B matrix storage. - __m512i b0; - __m512i b1; - __m512i b2; + __m512i b0 = _mm512_setzero_epi32(); + __m512i b1 = _mm512_setzero_epi32(); + __m512i b2 = _mm512_setzero_epi32(); // A matrix storage. - __m512i a_int32_0; + __m512i a_int32_0 = _mm512_setzero_epi32(); uint8_t cvt_uint8 = 128; __m512i vec_uint8 = _mm512_set1_epi8 (cvt_uint8); @@ -6785,30 +8038,55 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_4x48) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_4x48: { - selector1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ); - a_int32_0 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 2 * 16 ) ); - __m128i zero_point0 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); - __m128i zero_point1 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); - __m128i zero_point2 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + a_int32_0 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point0 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point1 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point2 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + zero_point2 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point1 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point2 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } // c[0, 0-15] CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); @@ -6846,6 +8124,90 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_4x48) // c[3, 32-47] CVT_MULRND_CVT32(c_int32_3p2,a_int32_0,zero_point2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_4x48: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + S8_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,0); + + // c[1:0-15,16-31,32-47] + S8_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,1); + + // c[2:0-15,16-31,32-47] + S8_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,2); + + // c[3:0-15,16-31,32-47] + S8_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,3); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + S32_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,0); + + // c[1:0-15,16-31,32-47] + S32_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,1); + + // c[2:0-15,16-31,32-47] + S32_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,2); + + // c[3:0-15,16-31,32-47] + S32_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,3); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_4x48: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + // c[0, 0-15] + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 16-31] + SWISH_S32_AVX512(c_int32_0p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 32-47] + SWISH_S32_AVX512(c_int32_0p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 0-15] + SWISH_S32_AVX512(c_int32_1p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 16-31] + SWISH_S32_AVX512(c_int32_1p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 32-47] + SWISH_S32_AVX512(c_int32_1p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 0-15] + SWISH_S32_AVX512(c_int32_2p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 16-31] + SWISH_S32_AVX512(c_int32_2p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 32-47] + SWISH_S32_AVX512(c_int32_2p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[3, 0-15] + SWISH_S32_AVX512(c_int32_3p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[3, 16-31] + SWISH_S32_AVX512(c_int32_3p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[3, 32-47] + SWISH_S32_AVX512(c_int32_3p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_4x48_DISABLE: @@ -6948,18 +8310,20 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_3x48) &&POST_OPS_GELU_TANH_3x48, &&POST_OPS_GELU_ERF_3x48, &&POST_OPS_CLIP_3x48, - &&POST_OPS_DOWNSCALE_3x48 + &&POST_OPS_DOWNSCALE_3x48, + &&POST_OPS_MATRIX_ADD_3x48, + &&POST_OPS_SWISH_3x48 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; // B matrix storage. - __m512i b0; - __m512i b1; - __m512i b2; + __m512i b0 = _mm512_setzero_epi32(); + __m512i b1 = _mm512_setzero_epi32(); + __m512i b2 = _mm512_setzero_epi32(); // A matrix storage. - __m512i a_int32_0; + __m512i a_int32_0 = _mm512_setzero_epi32(); uint8_t cvt_uint8 = 128; __m512i vec_uint8 = _mm512_set1_epi8 (cvt_uint8); @@ -7357,30 +8721,55 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_3x48) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_3x48: { - selector1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ); - a_int32_0 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 2 * 16 ) ); - __m128i zero_point0 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); - __m128i zero_point1 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); - __m128i zero_point2 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + a_int32_0 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point0 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point1 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point2 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + zero_point2 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point1 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point2 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } // c[0, 0-15] CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); @@ -7409,6 +8798,75 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_3x48) // c[2, 32-47] CVT_MULRND_CVT32(c_int32_2p2,a_int32_0,zero_point2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_3x48: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + S8_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,0); + + // c[1:0-15,16-31,32-47] + S8_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,1); + + // c[2:0-15,16-31,32-47] + S8_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,2); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + S32_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,0); + + // c[1:0-15,16-31,32-47] + S32_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,1); + + // c[2:0-15,16-31,32-47] + S32_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,2); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_3x48: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + // c[0, 0-15] + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 16-31] + SWISH_S32_AVX512(c_int32_0p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 32-47] + SWISH_S32_AVX512(c_int32_0p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 0-15] + SWISH_S32_AVX512(c_int32_1p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 16-31] + SWISH_S32_AVX512(c_int32_1p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 32-47] + SWISH_S32_AVX512(c_int32_1p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 0-15] + SWISH_S32_AVX512(c_int32_2p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 16-31] + SWISH_S32_AVX512(c_int32_2p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 32-47] + SWISH_S32_AVX512(c_int32_2p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_3x48_DISABLE: @@ -7493,18 +8951,20 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_2x48) &&POST_OPS_GELU_TANH_2x48, &&POST_OPS_GELU_ERF_2x48, &&POST_OPS_CLIP_2x48, - &&POST_OPS_DOWNSCALE_2x48 + &&POST_OPS_DOWNSCALE_2x48, + &&POST_OPS_MATRIX_ADD_2x48, + &&POST_OPS_SWISH_2x48 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; // B matrix storage. - __m512i b0; - __m512i b1; - __m512i b2; + __m512i b0 = _mm512_setzero_epi32(); + __m512i b1 = _mm512_setzero_epi32(); + __m512i b2 = _mm512_setzero_epi32(); // A matrix storage. - __m512i a_int32_0; + __m512i a_int32_0 = _mm512_setzero_epi32(); uint8_t cvt_uint8 = 128; __m512i vec_uint8 = _mm512_set1_epi8 (cvt_uint8); @@ -7807,30 +9267,55 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_2x48) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_2x48: { - selector1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ); - a_int32_0 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 2 * 16 ) ); - __m128i zero_point0 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); - __m128i zero_point1 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); - __m128i zero_point2 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + a_int32_0 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point0 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point1 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point2 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + zero_point2 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point1 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point2 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } // c[0, 0-15] CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); @@ -7850,6 +9335,60 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_2x48) // c[1, 32-47] CVT_MULRND_CVT32(c_int32_1p2,a_int32_0,zero_point2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_2x48: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + S8_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,0); + + // c[1:0-15,16-31,32-47] + S8_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,1); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + S32_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,0); + + // c[1:0-15,16-31,32-47] + S32_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,1); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_2x48: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + // c[0, 0-15] + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 16-31] + SWISH_S32_AVX512(c_int32_0p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 32-47] + SWISH_S32_AVX512(c_int32_0p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 0-15] + SWISH_S32_AVX512(c_int32_1p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 16-31] + SWISH_S32_AVX512(c_int32_1p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 32-47] + SWISH_S32_AVX512(c_int32_1p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_2x48_DISABLE: @@ -7916,7 +9455,9 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_1x48) &&POST_OPS_GELU_TANH_1x48, &&POST_OPS_GELU_ERF_1x48, &&POST_OPS_CLIP_1x48, - &&POST_OPS_DOWNSCALE_1x48 + &&POST_OPS_DOWNSCALE_1x48, + &&POST_OPS_MATRIX_ADD_1x48, + &&POST_OPS_SWISH_1x48 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -7925,12 +9466,12 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_1x48) __m512i vec_uint8 = _mm512_set1_epi8 (cvt_uint8); // B matrix storage. - __m512i b0; - __m512i b1; - __m512i b2; + __m512i b0 = _mm512_setzero_epi32(); + __m512i b1 = _mm512_setzero_epi32(); + __m512i b2 = _mm512_setzero_epi32(); // A matrix storage. - __m512i a_int32_0; + __m512i a_int32_0 = _mm512_setzero_epi32(); // Registers to use for accumulating C. __m512i c_int32_0p0 = _mm512_setzero_epi32(); @@ -8134,30 +9675,55 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_1x48) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_1x48: { - selector1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ); - a_int32_0 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 2 * 16 ) ); - __m128i zero_point0 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); - __m128i zero_point1 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); - __m128i zero_point2 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + a_int32_0 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point0 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point1 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point2 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + zero_point2 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point1 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point2 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } // c[0, 0-15] CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); @@ -8168,6 +9734,45 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_1x48) // c[0, 32-47] CVT_MULRND_CVT32(c_int32_0p2,a_int32_0,zero_point2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_1x48: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + S8_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,0); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + S32_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,0); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_1x48: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + // c[0, 0-15] + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 16-31] + SWISH_S32_AVX512(c_int32_0p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 32-47] + SWISH_S32_AVX512(c_int32_0p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_1x48_DISABLE: diff --git a/kernels/zen4/lpgemm/s8s8s32/lpgemm_n_fringe_s8_amd512vnni.c b/kernels/zen4/lpgemm/s8s8s32/lpgemm_n_fringe_s8_amd512vnni.c index 9669b638b5..8bdd351de0 100644 --- a/kernels/zen4/lpgemm/s8s8s32/lpgemm_n_fringe_s8_amd512vnni.c +++ b/kernels/zen4/lpgemm/s8s8s32/lpgemm_n_fringe_s8_amd512vnni.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -53,7 +53,9 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_6xlt16) &&POST_OPS_GELU_TANH_6xLT16, &&POST_OPS_GELU_ERF_6xLT16, &&POST_OPS_CLIP_6xLT16, - &&POST_OPS_DOWNSCALE_6xLT16 + &&POST_OPS_DOWNSCALE_6xLT16, + &&POST_OPS_MATRIX_ADD_6xLT16, + &&POST_OPS_SWISH_6xLT16 }; dim_t MR = 6; dim_t m_full_pieces = m0 / MR; @@ -64,10 +66,10 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_6xlt16) dim_t k_partial_pieces = k0 % 4; // B matrix storage. - __m512i b0; + __m512i b0 = _mm512_setzero_epi32(); // A matrix storage. - __m512i a_int32_0; + __m512i a_int32_0 = _mm512_setzero_epi32(); uint8_t cvt_uint8 = 128; __m512i vec_uint8 = _mm512_set1_epi8 (cvt_uint8); @@ -329,27 +331,27 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_6xlt16) __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); // c[0,0-15] - S32_S32_BETA_OP_NLT16F_MASK(load_mask, c_int32_0p0, ir, 0, 0, \ + S32_S32_BETA_OP_NLT16F_MASK(c, load_mask, c_int32_0p0, ir, 0, 0, \ selector1, selector2); // c[1,0-15] - S32_S32_BETA_OP_NLT16F_MASK(load_mask, c_int32_1p0, ir, 1, 0, \ + S32_S32_BETA_OP_NLT16F_MASK(c, load_mask, c_int32_1p0, ir, 1, 0, \ selector1, selector2); // c[2,0-15] - S32_S32_BETA_OP_NLT16F_MASK(load_mask, c_int32_2p0, ir, 2, 0, \ + S32_S32_BETA_OP_NLT16F_MASK(c, load_mask, c_int32_2p0, ir, 2, 0, \ selector1, selector2); // c[3,0-15] - S32_S32_BETA_OP_NLT16F_MASK(load_mask, c_int32_3p0, ir, 3, 0, \ + S32_S32_BETA_OP_NLT16F_MASK(c, load_mask, c_int32_3p0, ir, 3, 0, \ selector1, selector2); // c[4,0-15] - S32_S32_BETA_OP_NLT16F_MASK(load_mask, c_int32_4p0, ir, 4, 0, \ + S32_S32_BETA_OP_NLT16F_MASK(c, load_mask, c_int32_4p0, ir, 4, 0, \ selector1, selector2); // c[5,0-15] - S32_S32_BETA_OP_NLT16F_MASK(load_mask, c_int32_5p0, ir, 5, 0, \ + S32_S32_BETA_OP_NLT16F_MASK(c, load_mask, c_int32_5p0, ir, 5, 0, \ selector1, selector2); } } @@ -513,23 +515,41 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_6xlt16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_6xLT16: { // Typecast without data modification, safe operation. __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); - selector1 = _mm512_maskz_loadu_epi32 - ( - load_mask, - ( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j ) - ); - __m128i zero_point = _mm_maskz_loadu_epi8 - ( - load_mask, - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j ) - ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = _mm512_maskz_loadu_epi32 + ( + load_mask, + ( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j ) + ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point = _mm512_castsi512_si128( _mm512_setzero_si512() ); + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point = _mm_maskz_loadu_epi8 + ( + load_mask, + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ) + ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } // c[0, 0-15] CVT_MULRND_CVT32_LT16(c_int32_0p0,selector1,zero_point); @@ -549,6 +569,85 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_6xlt16) // c[5, 0-15] CVT_MULRND_CVT32_LT16(c_int32_5p0,selector1,zero_point); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_6xLT16: + { + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S8_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + S8_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,1); + + // c[2:0-15] + S8_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,2); + + // c[3:0-15] + S8_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,3); + + // c[4:0-15] + S8_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,4); + + // c[5:0-15] + S8_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,5); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S32_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + S32_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,1); + + // c[2:0-15] + S32_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,2); + + // c[3:0-15] + S32_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,3); + + // c[4:0-15] + S32_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,4); + + // c[5:0-15] + S32_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,5); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_6xLT16: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + // c[0, 0-15] + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 0-15] + SWISH_S32_AVX512(c_int32_1p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 0-15] + SWISH_S32_AVX512(c_int32_2p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[3, 0-15] + SWISH_S32_AVX512(c_int32_3p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[4, 0-15] + SWISH_S32_AVX512(c_int32_4p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[5, 0-15] + SWISH_S32_AVX512(c_int32_5p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_6xLT16_DISABLE: @@ -689,7 +788,9 @@ LPGEMM_N_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_6x16) &&POST_OPS_GELU_TANH_6x16, &&POST_OPS_GELU_ERF_6x16, &&POST_OPS_CLIP_6x16, - &&POST_OPS_DOWNSCALE_6x16 + &&POST_OPS_DOWNSCALE_6x16, + &&POST_OPS_MATRIX_ADD_6x16, + &&POST_OPS_SWISH_6x16 }; dim_t MR = 6; dim_t m_full_pieces = m0 / MR; @@ -700,10 +801,10 @@ LPGEMM_N_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_6x16) dim_t k_partial_pieces = k0 % 4; // B matrix storage. - __m512i b0; + __m512i b0 = _mm512_setzero_epi32(); // A matrix storage. - __m512i a_int32_0; + __m512i a_int32_0 = _mm512_setzero_epi32(); uint8_t cvt_uint8 = 128; __m512i vec_uint8 = _mm512_set1_epi8 (cvt_uint8); @@ -1129,16 +1230,33 @@ LPGEMM_N_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_6x16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_6x16: { - selector1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - __m128i zero_point0 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point0 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } // c[0, 0-15] CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); @@ -1160,6 +1278,84 @@ LPGEMM_N_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_6x16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_MATRIX_ADD_6x16: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S8_S32_MATRIX_ADD_1COL(selector1,0); + + // c[1:0-15] + S8_S32_MATRIX_ADD_1COL(selector1,1); + + // c[2:0-15] + S8_S32_MATRIX_ADD_1COL(selector1,2); + + // c[3:0-15] + S8_S32_MATRIX_ADD_1COL(selector1,3); + + // c[4:0-15] + S8_S32_MATRIX_ADD_1COL(selector1,4); + + // c[5:0-15] + S8_S32_MATRIX_ADD_1COL(selector1,5); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S32_S32_MATRIX_ADD_1COL(selector1,0); + + // c[1:0-15] + S32_S32_MATRIX_ADD_1COL(selector1,1); + + // c[2:0-15] + S32_S32_MATRIX_ADD_1COL(selector1,2); + + // c[3:0-15] + S32_S32_MATRIX_ADD_1COL(selector1,3); + + // c[4:0-15] + S32_S32_MATRIX_ADD_1COL(selector1,4); + + // c[5:0-15] + S32_S32_MATRIX_ADD_1COL(selector1,5); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_6x16: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + // c[0, 0-15] + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 0-15] + SWISH_S32_AVX512(c_int32_1p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 0-15] + SWISH_S32_AVX512(c_int32_2p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[3, 0-15] + SWISH_S32_AVX512(c_int32_3p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[4, 0-15] + SWISH_S32_AVX512(c_int32_4p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[5, 0-15] + SWISH_S32_AVX512(c_int32_5p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_6x16_DISABLE: ; @@ -1298,7 +1494,9 @@ LPGEMM_N_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_6x32) &&POST_OPS_GELU_TANH_6x32, &&POST_OPS_GELU_ERF_6x32, &&POST_OPS_CLIP_6x32, - &&POST_OPS_DOWNSCALE_6x32 + &&POST_OPS_DOWNSCALE_6x32, + &&POST_OPS_MATRIX_ADD_6x32, + &&POST_OPS_SWISH_6x32 }; dim_t MR = 6; dim_t m_full_pieces = m0 / MR; @@ -1309,11 +1507,11 @@ LPGEMM_N_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_6x32) dim_t k_partial_pieces = k0 % 4; // B matrix storage. - __m512i b0; - __m512i b1; + __m512i b0 = _mm512_setzero_epi32(); + __m512i b1 = _mm512_setzero_epi32(); // A matrix storage. - __m512i a_int32_0; + __m512i a_int32_0 = _mm512_setzero_epi32(); uint8_t cvt_uint8 = 128; __m512i vec_uint8 = _mm512_set1_epi8 (cvt_uint8); @@ -1885,23 +2083,44 @@ LPGEMM_N_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_6x32) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_6x32: { - selector1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ); - __m128i zero_point0 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); - __m128i zero_point1 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point0 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point1 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point1 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } // c[0, 0-15] CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); @@ -1941,6 +2160,102 @@ LPGEMM_N_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_6x32) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_MATRIX_ADD_6x32: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + S8_S32_MATRIX_ADD_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + S8_S32_MATRIX_ADD_2COL(selector1,selector2,1); + + // c[2:0-15,16-31] + S8_S32_MATRIX_ADD_2COL(selector1,selector2,2); + + // c[3:0-15,16-31] + S8_S32_MATRIX_ADD_2COL(selector1,selector2,3); + + // c[4:0-15,16-31] + S8_S32_MATRIX_ADD_2COL(selector1,selector2,4); + + // c[5:0-15,16-31] + S8_S32_MATRIX_ADD_2COL(selector1,selector2,5); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + S32_S32_MATRIX_ADD_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + S32_S32_MATRIX_ADD_2COL(selector1,selector2,1); + + // c[2:0-15,16-31] + S32_S32_MATRIX_ADD_2COL(selector1,selector2,2); + + // c[3:0-15,16-31] + S32_S32_MATRIX_ADD_2COL(selector1,selector2,3); + + // c[4:0-15,16-31] + S32_S32_MATRIX_ADD_2COL(selector1,selector2,4); + + // c[5:0-15,16-31] + S32_S32_MATRIX_ADD_2COL(selector1,selector2,5); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_6x32: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + // c[0, 0-15] + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 16-31] + SWISH_S32_AVX512(c_int32_0p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 0-15] + SWISH_S32_AVX512(c_int32_1p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 16-31] + SWISH_S32_AVX512(c_int32_1p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 0-15] + SWISH_S32_AVX512(c_int32_2p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 16-31] + SWISH_S32_AVX512(c_int32_2p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[3, 0-15] + SWISH_S32_AVX512(c_int32_3p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[3, 16-31] + SWISH_S32_AVX512(c_int32_3p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[4, 0-15] + SWISH_S32_AVX512(c_int32_4p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[4, 16-31] + SWISH_S32_AVX512(c_int32_4p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[5, 0-15] + SWISH_S32_AVX512(c_int32_5p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[5, 16-31] + SWISH_S32_AVX512(c_int32_5p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_6x32_DISABLE: ; @@ -2114,7 +2429,9 @@ LPGEMM_N_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_6x48) &&POST_OPS_GELU_TANH_6x48, &&POST_OPS_GELU_ERF_6x48, &&POST_OPS_CLIP_6x48, - &&POST_OPS_DOWNSCALE_6x48 + &&POST_OPS_DOWNSCALE_6x48, + &&POST_OPS_MATRIX_ADD_6x48, + &&POST_OPS_SWISH_6x48 }; dim_t MR = 6; dim_t m_full_pieces = m0 / MR; @@ -2125,12 +2442,12 @@ LPGEMM_N_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_6x48) dim_t k_partial_pieces = k0 % 4; // B matrix storage. - __m512i b0; - __m512i b1; - __m512i b2; + __m512i b0 = _mm512_setzero_epi32(); + __m512i b1 = _mm512_setzero_epi32(); + __m512i b2 = _mm512_setzero_epi32(); // A matrix storage. - __m512i a_int32_0; + __m512i a_int32_0 = _mm512_setzero_epi32(); uint8_t cvt_uint8 = 128; __m512i vec_uint8 = _mm512_set1_epi8 (cvt_uint8); @@ -2849,30 +3166,55 @@ LPGEMM_N_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_6x48) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_6x48: { - selector1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ); - a_int32_0 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 2 * 16 ) ); - __m128i zero_point0 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); - __m128i zero_point1 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); - __m128i zero_point2 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + a_int32_0 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point0 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point1 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point2 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + zero_point2 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point1 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point2 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } // c[0, 0-15] CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); @@ -2930,6 +3272,120 @@ LPGEMM_N_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_6x48) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_MATRIX_ADD_6x48: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + S8_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,0); + + // c[1:0-15,16-31,32-47] + S8_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,1); + + // c[2:0-15,16-31,32-47] + S8_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,2); + + // c[3:0-15,16-31,32-47] + S8_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,3); + + // c[4:0-15,16-31,32-47] + S8_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,4); + + // c[5:0-15,16-31,32-47] + S8_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,5); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + S32_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,0); + + // c[1:0-15,16-31,32-47] + S32_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,1); + + // c[2:0-15,16-31,32-47] + S32_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,2); + + // c[3:0-15,16-31,32-47] + S32_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,3); + + // c[4:0-15,16-31,32-47] + S32_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,4); + + // c[5:0-15,16-31,32-47] + S32_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,5); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_6x48: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + // c[0, 0-15] + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 16-31] + SWISH_S32_AVX512(c_int32_0p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 32-47] + SWISH_S32_AVX512(c_int32_0p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 0-15] + SWISH_S32_AVX512(c_int32_1p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 16-31] + SWISH_S32_AVX512(c_int32_1p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 32-47] + SWISH_S32_AVX512(c_int32_1p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 0-15] + SWISH_S32_AVX512(c_int32_2p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 16-31] + SWISH_S32_AVX512(c_int32_2p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 32-47] + SWISH_S32_AVX512(c_int32_2p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[3, 0-15] + SWISH_S32_AVX512(c_int32_3p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[3, 16-31] + SWISH_S32_AVX512(c_int32_3p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[3, 32-47] + SWISH_S32_AVX512(c_int32_3p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[4, 0-15] + SWISH_S32_AVX512(c_int32_4p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[4, 16-31] + SWISH_S32_AVX512(c_int32_4p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[4, 32-47] + SWISH_S32_AVX512(c_int32_4p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[5, 0-15] + SWISH_S32_AVX512(c_int32_5p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[5, 16-31] + SWISH_S32_AVX512(c_int32_5p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[5, 32-47] + SWISH_S32_AVX512(c_int32_5p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_6x48_DISABLE: ; diff --git a/kernels/zen4/lpgemm/s8s8s32/lpgemm_packb_s8_amd512vnni.c b/kernels/zen4/lpgemm/s8s8s32/lpgemm_packb_s8_amd512vnni.c index 532f2c264b..f815f5f209 100644 --- a/kernels/zen4/lpgemm/s8s8s32/lpgemm_packb_s8_amd512vnni.c +++ b/kernels/zen4/lpgemm/s8s8s32/lpgemm_packb_s8_amd512vnni.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -38,335 +38,426 @@ #ifdef BLIS_ADDON_LPGEMM -#define NR 64 - -void packb_nrlt16_s8s8s32os32 - ( - int8_t* pack_b_buffer_s8s8s32o32, - int32_t* pack_b_column_sum, - const int8_t* b, - const dim_t ldb, - const dim_t KC, - const dim_t n0_partial_rem - ); - -void packb_nr16_s8s8s32os32 - ( - int8_t* pack_b_buffer_s8s8s32o32, - int32_t* pack_b_column_sum, - const int8_t* b, - const dim_t ldb, - const dim_t KC - ); - -void packb_nr32_s8s8s32os32 - ( - int8_t* pack_b_buffer_s8s8s32o32, - int32_t* pack_b_column_sum, - const int8_t* b, - const dim_t ldb, - const dim_t KC - ); - -void packb_nr48_s8s8s32os32 - ( - int8_t* pack_b_buffer_s8s8s32o32, - int32_t* pack_b_column_sum, - const int8_t* b, - const dim_t ldb, - const dim_t KC - ); +void packb_nrlt16_s8s8s32os32_row_major + ( + int8_t *pack_b_buffer_s8s8s32o32, + int32_t *pack_b_column_sum, + const int8_t *b, + const dim_t ldb, + const dim_t KC, + const dim_t n0_partial_rem + ); + +void packb_nr16_s8s8s32os32_row_major + ( + int8_t* pack_b_buffer_s8s8s32o32, + int32_t* pack_b_column_sum, + const int8_t* b, + const dim_t ldb, + const dim_t KC + ); + +void packb_nr32_s8s8s32os32_row_major + ( + int8_t *pack_b_buffer_s8s8s32o32, + int32_t *pack_b_column_sum, + const int8_t *b, + const dim_t ldb, + const dim_t KC + ); + +void packb_nr48_s8s8s32os32_row_major + ( + int8_t *pack_b_buffer_s8s8s32o32, + int32_t *pack_b_column_sum, + const int8_t *b, + const dim_t ldb, + const dim_t KC + ); + +void packb_nr_mult_16_s8s8s32o32_col_major + ( + int8_t *pack_b_buffer, + int32_t *pack_b_column_sum, + const int8_t *b, + const dim_t NR, + const dim_t ldb, + const dim_t KC + ); + +void packb_nr64_s8s8s32os32_col_major + ( + int8_t *pack_b_buffer_s8s8s32o32, + int32_t *pack_b_column_sum, + const int8_t *b, + const dim_t ldb, + const dim_t NC, + const dim_t KC, + dim_t *rs_b, + dim_t *cs_b + ); + +void packb_nrlt16_s8s8s32o32_col_major + ( + int8_t *pack_b_buffer, + int32_t *pack_b_column_sum, + const int8_t *b, + const dim_t ldb, + const dim_t KC, + const dim_t n0_partial_rem + ); + +void packb_nr64_s8s8s32os32_row_major + ( + int8_t *pack_b_buffer_s8s8s32o32, + int32_t *pack_b_column_sum, + const int8_t *b, + const dim_t ldb, + const dim_t NC, + const dim_t KC, + dim_t *rs_b, + dim_t *cs_b + ); void packb_nr64_s8s8s32os32 - ( - int8_t* pack_b_buffer_s8s8s32o32, - int32_t* pack_b_column_sum, - const int8_t* b, - const dim_t ldb, - const dim_t NC, - const dim_t KC, - dim_t* rs_b, - dim_t* cs_b - ) + ( + int8_t *pack_b_buffer_s8s8s32o32, + int32_t *pack_b_column_sum, + const int8_t *b, + const dim_t rs_b, + const dim_t cs_b, + const dim_t NC, + const dim_t KC, + dim_t *rs_p, + dim_t *cs_p + ) +{ + if (cs_b == 1) + { + packb_nr64_s8s8s32os32_row_major(pack_b_buffer_s8s8s32o32, + pack_b_column_sum, b, + rs_b, NC, KC, rs_p, cs_p); + } + else + { + packb_nr64_s8s8s32os32_col_major(pack_b_buffer_s8s8s32o32, + pack_b_column_sum, b, + cs_b, NC, KC, rs_p, cs_p); + } +} + +void packb_nr64_s8s8s32os32_row_major + ( + int8_t *pack_b_buffer_s8s8s32o32, + int32_t *pack_b_column_sum, + const int8_t *b, + const dim_t ldb, + const dim_t NC, + const dim_t KC, + dim_t *rs_b, + dim_t *cs_b + ) { - // Used for permuting the mm512i elements for use in vpdpbusd instruction. - // These are indexes of the format a0-a1-b0-b1-a2-a3-b2-b3 and a0-a1-a2-a3-b0-b1-b2-b3. - // Adding int32 wise all4 gives format a4-a5-b4-b5-a6-a7-b6-b7 and a4-a5-a6-a7-b4-b5-b6-b7. - __m512i selector1 = _mm512_setr_epi64( 0x0, 0x1, 0x8, 0x9, 0x2, 0x3, 0xA, 0xB ); - __m512i selector1_1 = _mm512_setr_epi64( 0x4, 0x5, 0xC, 0xD, 0x6, 0x7, 0xE, 0xF ); - - __m512i selector2 = _mm512_setr_epi64( 0x0, 0x1, 0x2, 0x3, 0x8, 0x9, 0xA, 0xB ); - __m512i selector2_1 = _mm512_setr_epi64( 0x4, 0x5, 0x6, 0x7, 0xC, 0xD, 0xE, 0xF ); - - dim_t n_full_pieces = NC / NR; - dim_t n_full_pieces_loop_limit = n_full_pieces * NR; - dim_t n_partial_pieces = NC % NR; - - dim_t k_full_pieces_blks = KC / 4; - dim_t k_full_pieces = k_full_pieces_blks * 4; - dim_t k_partial_pieces = KC % 4; - - // KC when not multiple of 4 will have padding to make it multiple of 4 in packed buffer. - dim_t KC_updated = KC; - if ( k_partial_pieces > 0 ) - { - KC_updated += ( 4 - k_partial_pieces ); - } - - //to compute column sum of B matrix - __m512i sum1, sum2, sum3, sum4; - __m512i mul_128 = _mm512_set1_epi32 (7); - - __m512i a0; - __m512i b0; - __m512i c0; - __m512i d0; - __m512i a01; - __m512i c01; - - for ( dim_t jc = 0; jc < n_full_pieces_loop_limit; jc += NR ) - { - //load the temp buffer to compute column sum of B matrix - sum1 = _mm512_loadu_si512( pack_b_column_sum + jc ); - sum2 = _mm512_loadu_si512( pack_b_column_sum + 16 + jc ); //offset 16- as 16 int32 elements fit in 1 zmm register - sum3 = _mm512_loadu_si512( pack_b_column_sum + 32 + jc ); - sum4 = _mm512_loadu_si512( pack_b_column_sum + 48 + jc ); - - for ( dim_t kr = 0; kr < k_full_pieces; kr += 4 ) - { - // Rearrange for vpdpbusd, read 4 rows from B with 64 elements in each row. - a0 = _mm512_loadu_si512( b + ( ldb * ( kr + 0 ) ) + jc ); - b0 = _mm512_loadu_si512( b + ( ldb * ( kr + 1 ) ) + jc ); - c0 = _mm512_loadu_si512( b + ( ldb * ( kr + 2 ) ) + jc ); - d0 = _mm512_loadu_si512( b + ( ldb * ( kr + 3 ) ) + jc ); - - //add all the columns : sum = add (sum, a0, b0, c0, d0) - sum1 = _mm512_add_epi32 ( sum1, _mm512_sllv_epi32 ( - _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( a0, 0)), - _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( b0, 0)), - _mm512_add_epi32 (_mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( c0, 0)), - _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( d0, 0))))) , mul_128)); - - sum2 = _mm512_add_epi32 ( sum2, _mm512_sllv_epi32 ( - _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( a0, 1)), - _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( b0, 1)), - _mm512_add_epi32 (_mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( c0, 1)), - _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( d0, 1))))) , mul_128)); - - sum3 = _mm512_add_epi32 ( sum3, _mm512_sllv_epi32 ( - _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( a0, 2)), - _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( b0, 2)), - _mm512_add_epi32 (_mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( c0, 2)), - _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( d0, 2))))) , mul_128)); - - sum4 = _mm512_add_epi32 ( sum4, _mm512_sllv_epi32 ( - _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( a0, 3)), - _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( b0, 3)), - _mm512_add_epi32 (_mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( c0, 3)), - _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( d0, 3))))), mul_128)); - - a01 = _mm512_unpacklo_epi8( a0, b0 ); - a0 = _mm512_unpackhi_epi8( a0, b0 ); - - c01 = _mm512_unpacklo_epi8( c0, d0 ); - c0 = _mm512_unpackhi_epi8( c0, d0 ); - - b0 = _mm512_unpacklo_epi16( a01, c01 ); - a01 = _mm512_unpackhi_epi16( a01, c01 ); - - d0 = _mm512_unpacklo_epi16( a0, c0 ); - c01 = _mm512_unpackhi_epi16( a0, c0 ); - - a0 = _mm512_permutex2var_epi64( b0, selector1, a01 ); - c0 = _mm512_permutex2var_epi64( d0, selector1, c01 ); - b0 = _mm512_permutex2var_epi64( b0, selector1_1, a01 ); - d0 = _mm512_permutex2var_epi64( d0, selector1_1, c01 ); - - a01 = _mm512_permutex2var_epi64( a0, selector2, c0 ); // b[0] - c01 = _mm512_permutex2var_epi64( b0, selector2, d0 ); // b[2] - a0 = _mm512_permutex2var_epi64( a0, selector2_1, c0 ); // b[1] - c0 = _mm512_permutex2var_epi64( b0, selector2_1, d0 ); // b[3] - - _mm512_storeu_si512( pack_b_buffer_s8s8s32o32 + ( ( jc * KC_updated ) + ( ( kr + 0 ) * NR ) ), a01 ); - _mm512_storeu_si512( pack_b_buffer_s8s8s32o32 + ( ( jc * KC_updated ) + ( ( kr + 1 ) * NR ) ) , a0 ); - _mm512_storeu_si512( pack_b_buffer_s8s8s32o32 + ( ( jc * KC_updated ) + ( ( kr + 2 ) * NR ) ), c01 ); - _mm512_storeu_si512( pack_b_buffer_s8s8s32o32 + ( ( jc * KC_updated ) + ( ( kr + 3 ) * NR ) ), c0 ); - } - // Handle k remainder. - if ( k_partial_pieces > 0 ) - { - if ( k_partial_pieces == 3 ) - { - a0 = _mm512_loadu_si512( b + ( ldb * ( k_full_pieces + 0 ) ) + jc ); - b0 = _mm512_loadu_si512( b + ( ldb * ( k_full_pieces + 1 ) ) + jc ); - c0 = _mm512_loadu_si512( b + ( ldb * ( k_full_pieces + 2 ) ) + jc ); - d0 = _mm512_setzero_si512(); - - //add all the columns : sum = add (sum, a0, b0, c0) - sum1 = _mm512_add_epi32 ( sum1, _mm512_sllv_epi32 ( - _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( a0, 0)), - _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( b0, 0)), - _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( c0, 0)))), mul_128)); - - sum2 = _mm512_add_epi32 ( sum2, _mm512_sllv_epi32 ( - _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( a0, 1)), - _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( b0, 1)), - _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( c0, 1)))), mul_128)); - - sum3 = _mm512_add_epi32 ( sum3, _mm512_sllv_epi32 ( - _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( a0, 2)), - _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( b0, 2)), - _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( c0, 2)))), mul_128)); - - sum4 = _mm512_add_epi32 ( sum4, _mm512_sllv_epi32 ( - _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( a0, 3)), - _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( b0, 3)), - _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( c0, 3)))), mul_128)); - - } - else if( k_partial_pieces == 2 ) - { - a0 = _mm512_loadu_si512( b + ( ldb * ( k_full_pieces + 0 ) ) + jc ); - b0 = _mm512_loadu_si512( b + ( ldb * ( k_full_pieces + 1 ) ) + jc ); - c0 = _mm512_setzero_si512(); - d0 = _mm512_setzero_si512(); - - //add all the columns : sum = add (sum, a0, b0) - sum1 = _mm512_add_epi32 ( sum1, _mm512_sllv_epi32 ( - _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( a0, 0)), - _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( b0, 0))), mul_128)); - - sum2 = _mm512_add_epi32 ( sum2, _mm512_sllv_epi32 ( - _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( a0, 1)), - _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( b0, 1))), mul_128)); - - sum3 = _mm512_add_epi32 ( sum3, _mm512_sllv_epi32 ( - _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( a0, 2)), - _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( b0, 2))), mul_128)); - - sum4 = _mm512_add_epi32 ( sum4, _mm512_sllv_epi32 ( - _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( a0, 3)), - _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( b0, 3))), mul_128)); - } - else //k_partial_pieces == 1 - { - a0 = _mm512_loadu_si512( b + ( ldb * ( k_full_pieces + 0 ) ) + jc ); - b0 = _mm512_setzero_si512(); - c0 = _mm512_setzero_si512(); - d0 = _mm512_setzero_si512(); - - //add all the columns: sum = add (sum, a0) - sum1 = _mm512_add_epi32 ( sum1, _mm512_sllv_epi32 ( - _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( a0, 0)), mul_128)); - - sum2 = _mm512_add_epi32 ( sum2, _mm512_sllv_epi32 ( - _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( a0, 1)), mul_128)); - - sum3 = _mm512_add_epi32 ( sum3, _mm512_sllv_epi32 ( - _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( a0, 2)), mul_128)); - - sum4 = _mm512_add_epi32 ( sum4, _mm512_sllv_epi32 ( - _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( a0, 3)), mul_128)); - } - - a01 = _mm512_unpacklo_epi8( a0, b0 ); - a0 = _mm512_unpackhi_epi8( a0, b0 ); - - c01 = _mm512_unpacklo_epi8( c0, d0 ); - c0 = _mm512_unpackhi_epi8( c0, d0 ); - - b0 = _mm512_unpacklo_epi16( a01, c01 ); - a01 = _mm512_unpackhi_epi16( a01, c01 ); - - d0 = _mm512_unpacklo_epi16( a0, c0 ); - c01 = _mm512_unpackhi_epi16( a0, c0 ); - - a0 = _mm512_permutex2var_epi64( b0, selector1, a01 ); - c0 = _mm512_permutex2var_epi64( d0, selector1, c01 ); - b0 = _mm512_permutex2var_epi64( b0, selector1_1, a01 ); - d0 = _mm512_permutex2var_epi64( d0, selector1_1, c01 ); - - a01 = _mm512_permutex2var_epi64( a0, selector2, c0 ); // b[0] - c01 = _mm512_permutex2var_epi64( b0, selector2, d0 ); // b[2] - a0 = _mm512_permutex2var_epi64( a0, selector2_1, c0 ); // b[1] - c0 = _mm512_permutex2var_epi64( b0, selector2_1, d0 ); // b[3] - - _mm512_storeu_si512( pack_b_buffer_s8s8s32o32 + ( ( jc * KC_updated ) + ( ( k_full_pieces + 0 ) * NR ) ), a01 ); - _mm512_storeu_si512( pack_b_buffer_s8s8s32o32 + ( ( jc * KC_updated ) + ( ( k_full_pieces + 1 ) * NR ) ) , a0 ); - _mm512_storeu_si512( pack_b_buffer_s8s8s32o32 + ( ( jc * KC_updated ) + ( ( k_full_pieces + 2 ) * NR ) ), c01 ); - _mm512_storeu_si512( pack_b_buffer_s8s8s32o32 + ( ( jc * KC_updated ) + ( ( k_full_pieces + 3 ) * NR ) ), c0 ); - } - //store the sum column - _mm512_storeu_si512( pack_b_column_sum + jc, sum1 ); - _mm512_storeu_si512( pack_b_column_sum + 16 + jc, sum2 ); - _mm512_storeu_si512( pack_b_column_sum + 32 + jc, sum3 ); - _mm512_storeu_si512( pack_b_column_sum + 48 + jc, sum4 ); - } - - // Contiguous packing of fringe panel (n` < NR). - if ( n_partial_pieces > 0 ) - { - dim_t n0_partial_rem = n_partial_pieces % 16; - dim_t n0_partial_pack = 0; - - // Split into multiple smaller fringe kernels, so as to maximize - // vectorization after packing. Any n0 < NR(64) can be expressed - // as n0 = 48 + n` / n0 = 32 + n` / n0 = 16 + n`, where n` < 16. - dim_t n0_48 = n_partial_pieces / 48; - dim_t n0_32 = n_partial_pieces / 32; - dim_t n0_16 = n_partial_pieces / 16; - - if ( n0_48 == 1 ) - { - packb_nr48_s8s8s32os32 - ( - ( pack_b_buffer_s8s8s32o32 + ( n_full_pieces_loop_limit * KC_updated ) ), - ( pack_b_column_sum + n_full_pieces_loop_limit ), - ( b + n_full_pieces_loop_limit ), ldb, KC - ); - - n0_partial_pack = 48; - } - else if ( n0_32 == 1 ) - { - packb_nr32_s8s8s32os32 - ( - ( pack_b_buffer_s8s8s32o32 + ( n_full_pieces_loop_limit * KC_updated ) ), - ( pack_b_column_sum + n_full_pieces_loop_limit ), - ( b + n_full_pieces_loop_limit ), ldb, KC - ); - - n0_partial_pack = 32; - } - else if ( n0_16 == 1 ) - { - packb_nr16_s8s8s32os32 - ( - ( pack_b_buffer_s8s8s32o32 + ( n_full_pieces_loop_limit * KC_updated ) ), - ( pack_b_column_sum + n_full_pieces_loop_limit ), - ( b + n_full_pieces_loop_limit ), ldb, KC - ); - - n0_partial_pack = 16; - } - - if ( n0_partial_rem > 0 ) - { - packb_nrlt16_s8s8s32os32 - ( - ( pack_b_buffer_s8s8s32o32 + ( n_full_pieces_loop_limit * KC_updated ) + - ( n0_partial_pack * KC_updated ) ), - ( pack_b_column_sum + n_full_pieces_loop_limit + n0_partial_pack ), - ( b + n_full_pieces_loop_limit + n0_partial_pack ), ldb, KC, - n0_partial_rem - ); - } - } - *rs_b = NR * 4; - *cs_b = NR; + dim_t NR = 64; + // Used for permuting the mm512i elements for use in vpdpbusd instruction. + // These are indexes of the format a0-a1-b0-b1-a2-a3-b2-b3 and a0-a1-a2-a3-b0-b1-b2-b3. + // Adding int32 wise all4 gives format a4-a5-b4-b5-a6-a7-b6-b7 and a4-a5-a6-a7-b4-b5-b6-b7. + __m512i selector1 = _mm512_setr_epi64( 0x0, 0x1, 0x8, 0x9, 0x2, 0x3, 0xA, 0xB ); + __m512i selector1_1 = _mm512_setr_epi64( 0x4, 0x5, 0xC, 0xD, 0x6, 0x7, 0xE, 0xF ); + + __m512i selector2 = _mm512_setr_epi64( 0x0, 0x1, 0x2, 0x3, 0x8, 0x9, 0xA, 0xB ); + __m512i selector2_1 = _mm512_setr_epi64( 0x4, 0x5, 0x6, 0x7, 0xC, 0xD, 0xE, 0xF ); + + dim_t n_full_pieces = NC / NR; + dim_t n_full_pieces_loop_limit = n_full_pieces * NR; + dim_t n_partial_pieces = NC % NR; + + dim_t k_full_pieces_blks = KC / 4; + dim_t k_full_pieces = k_full_pieces_blks * 4; + dim_t k_partial_pieces = KC % 4; + + // KC when not multiple of 4 will have padding to make it multiple of 4 in packed buffer. + dim_t KC_updated = KC; + if ( k_partial_pieces > 0 ) + { + KC_updated += ( 4 - k_partial_pieces ); + } + + //to compute column sum of B matrix + __m512i sum1, sum2, sum3, sum4; + __m512i mul_128 = _mm512_set1_epi32 (7); + + __m512i a0; + __m512i b0; + __m512i c0; + __m512i d0; + __m512i a01; + __m512i c01; + + for ( dim_t jc = 0; jc < n_full_pieces_loop_limit; jc += NR ) + { + //load the temp buffer to compute column sum of B matrix + sum1 = _mm512_loadu_si512( pack_b_column_sum + jc ); + sum2 = _mm512_loadu_si512( pack_b_column_sum + 16 + jc ); + //offset 16- as 16 int32 elements fit in 1 zmm register + sum3 = _mm512_loadu_si512( pack_b_column_sum + 32 + jc ); + sum4 = _mm512_loadu_si512( pack_b_column_sum + 48 + jc ); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 4 ) + { + // Rearrange for vpdpbusd, read 4 rows from B with 64 elements in each row. + a0 = _mm512_loadu_si512( b + ( ldb * ( kr + 0 ) ) + jc ); + b0 = _mm512_loadu_si512( b + ( ldb * ( kr + 1 ) ) + jc ); + c0 = _mm512_loadu_si512( b + ( ldb * ( kr + 2 ) ) + jc ); + d0 = _mm512_loadu_si512( b + ( ldb * ( kr + 3 ) ) + jc ); + + //add all the columns : sum = add (sum, a0, b0, c0, d0) + sum1 = + _mm512_add_epi32 ( sum1, _mm512_sllv_epi32 ( + _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( a0, 0)), + _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( b0, 0)), + _mm512_add_epi32 (_mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( c0, 0)), + _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( d0, 0))))) , mul_128)); + + sum2 = + _mm512_add_epi32 ( sum2, _mm512_sllv_epi32 ( + _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( a0, 1)), + _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( b0, 1)), + _mm512_add_epi32 (_mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( c0, 1)), + _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( d0, 1))))) , mul_128)); + + sum3 = + _mm512_add_epi32 ( sum3, _mm512_sllv_epi32 ( + _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( a0, 2)), + _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( b0, 2)), + _mm512_add_epi32 (_mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( c0, 2)), + _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( d0, 2))))) , mul_128)); + + sum4 = + _mm512_add_epi32 ( sum4, _mm512_sllv_epi32 ( + _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( a0, 3)), + _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( b0, 3)), + _mm512_add_epi32 (_mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( c0, 3)), + _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( d0, 3))))), mul_128)); + + a01 = _mm512_unpacklo_epi8( a0, b0 ); + a0 = _mm512_unpackhi_epi8( a0, b0 ); + + c01 = _mm512_unpacklo_epi8( c0, d0 ); + c0 = _mm512_unpackhi_epi8( c0, d0 ); + + b0 = _mm512_unpacklo_epi16( a01, c01 ); + a01 = _mm512_unpackhi_epi16( a01, c01 ); + + d0 = _mm512_unpacklo_epi16( a0, c0 ); + c01 = _mm512_unpackhi_epi16( a0, c0 ); + + a0 = _mm512_permutex2var_epi64( b0, selector1, a01 ); + c0 = _mm512_permutex2var_epi64( d0, selector1, c01 ); + b0 = _mm512_permutex2var_epi64( b0, selector1_1, a01 ); + d0 = _mm512_permutex2var_epi64( d0, selector1_1, c01 ); + + a01 = _mm512_permutex2var_epi64( a0, selector2, c0 ); // b[0] + c01 = _mm512_permutex2var_epi64( b0, selector2, d0 ); // b[2] + a0 = _mm512_permutex2var_epi64( a0, selector2_1, c0 ); // b[1] + c0 = _mm512_permutex2var_epi64( b0, selector2_1, d0 ); // b[3] + + _mm512_storeu_si512( pack_b_buffer_s8s8s32o32 + + ( ( jc * KC_updated ) + ( ( kr + 0 ) * NR ) ), a01 ); + _mm512_storeu_si512( pack_b_buffer_s8s8s32o32 + + ( ( jc * KC_updated ) + ( ( kr + 1 ) * NR ) ) , a0 ); + _mm512_storeu_si512( pack_b_buffer_s8s8s32o32 + + ( ( jc * KC_updated ) + ( ( kr + 2 ) * NR ) ), c01 ); + _mm512_storeu_si512( pack_b_buffer_s8s8s32o32 + + ( ( jc * KC_updated ) + ( ( kr + 3 ) * NR ) ), c0 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + if ( k_partial_pieces == 3 ) + { + a0 = _mm512_loadu_si512( b + ( ldb * ( k_full_pieces + 0 ) ) + jc ); + b0 = _mm512_loadu_si512( b + ( ldb * ( k_full_pieces + 1 ) ) + jc ); + c0 = _mm512_loadu_si512( b + ( ldb * ( k_full_pieces + 2 ) ) + jc ); + d0 = _mm512_setzero_si512(); + + //add all the columns : sum = add (sum, a0, b0, c0) + sum1 = + _mm512_add_epi32 ( sum1, _mm512_sllv_epi32 ( + _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( a0, 0)), + _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( b0, 0)), + _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( c0, 0)))), mul_128)); + + sum2 = + _mm512_add_epi32 ( sum2, _mm512_sllv_epi32 ( + _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( a0, 1)), + _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( b0, 1)), + _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( c0, 1)))), mul_128)); + + sum3 = + _mm512_add_epi32 ( sum3, _mm512_sllv_epi32 ( + _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( a0, 2)), + _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( b0, 2)), + _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( c0, 2)))), mul_128)); + + sum4 = + _mm512_add_epi32 ( sum4, _mm512_sllv_epi32 ( + _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( a0, 3)), + _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( b0, 3)), + _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( c0, 3)))), mul_128)); + + } + else if( k_partial_pieces == 2 ) + { + a0 = _mm512_loadu_si512( b + ( ldb * ( k_full_pieces + 0 ) ) + jc ); + b0 = _mm512_loadu_si512( b + ( ldb * ( k_full_pieces + 1 ) ) + jc ); + c0 = _mm512_setzero_si512(); + d0 = _mm512_setzero_si512(); + + //add all the columns : sum = add (sum, a0, b0) + sum1 = + _mm512_add_epi32 ( sum1, _mm512_sllv_epi32 ( + _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( a0, 0)), + _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( b0, 0))), mul_128)); + + sum2 = + _mm512_add_epi32 ( sum2, _mm512_sllv_epi32 ( + _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( a0, 1)), + _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( b0, 1))), mul_128)); + + sum3 = + _mm512_add_epi32 ( sum3, _mm512_sllv_epi32 ( + _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( a0, 2)), + _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( b0, 2))), mul_128)); + + sum4 = + _mm512_add_epi32 ( sum4, _mm512_sllv_epi32 ( + _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( a0, 3)), + _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( b0, 3))), mul_128)); + } + else //k_partial_pieces == 1 + { + a0 = _mm512_loadu_si512( b + ( ldb * ( k_full_pieces + 0 ) ) + jc ); + b0 = _mm512_setzero_si512(); + c0 = _mm512_setzero_si512(); + d0 = _mm512_setzero_si512(); + + //add all the columns: sum = add (sum, a0) + sum1 = _mm512_add_epi32 ( sum1, _mm512_sllv_epi32 ( + _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( a0, 0)), mul_128)); + + sum2 = _mm512_add_epi32 ( sum2, _mm512_sllv_epi32 ( + _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( a0, 1)), mul_128)); + + sum3 = _mm512_add_epi32 ( sum3, _mm512_sllv_epi32 ( + _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( a0, 2)), mul_128)); + + sum4 = _mm512_add_epi32 ( sum4, _mm512_sllv_epi32 ( + _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( a0, 3)), mul_128)); + } + + a01 = _mm512_unpacklo_epi8( a0, b0 ); + a0 = _mm512_unpackhi_epi8( a0, b0 ); + + c01 = _mm512_unpacklo_epi8( c0, d0 ); + c0 = _mm512_unpackhi_epi8( c0, d0 ); + + b0 = _mm512_unpacklo_epi16( a01, c01 ); + a01 = _mm512_unpackhi_epi16( a01, c01 ); + + d0 = _mm512_unpacklo_epi16( a0, c0 ); + c01 = _mm512_unpackhi_epi16( a0, c0 ); + + a0 = _mm512_permutex2var_epi64( b0, selector1, a01 ); + c0 = _mm512_permutex2var_epi64( d0, selector1, c01 ); + b0 = _mm512_permutex2var_epi64( b0, selector1_1, a01 ); + d0 = _mm512_permutex2var_epi64( d0, selector1_1, c01 ); + + a01 = _mm512_permutex2var_epi64( a0, selector2, c0 ); // b[0] + c01 = _mm512_permutex2var_epi64( b0, selector2, d0 ); // b[2] + a0 = _mm512_permutex2var_epi64( a0, selector2_1, c0 ); // b[1] + c0 = _mm512_permutex2var_epi64( b0, selector2_1, d0 ); // b[3] + + _mm512_storeu_si512( pack_b_buffer_s8s8s32o32 + + ( ( jc * KC_updated ) + ( ( k_full_pieces + 0 ) * NR ) ), a01 ); + _mm512_storeu_si512( pack_b_buffer_s8s8s32o32 + + ( ( jc * KC_updated ) + ( ( k_full_pieces + 1 ) * NR ) ) , a0 ); + _mm512_storeu_si512( pack_b_buffer_s8s8s32o32 + + ( ( jc * KC_updated ) + ( ( k_full_pieces + 2 ) * NR ) ), c01 ); + _mm512_storeu_si512( pack_b_buffer_s8s8s32o32 + + ( ( jc * KC_updated ) + ( ( k_full_pieces + 3 ) * NR ) ), c0 ); + } + //store the sum column + _mm512_storeu_si512( pack_b_column_sum + jc, sum1 ); + _mm512_storeu_si512( pack_b_column_sum + 16 + jc, sum2 ); + _mm512_storeu_si512( pack_b_column_sum + 32 + jc, sum3 ); + _mm512_storeu_si512( pack_b_column_sum + 48 + jc, sum4 ); + } + + // Contiguous packing of fringe panel (n` < NR). + if ( n_partial_pieces > 0 ) + { + dim_t n0_partial_rem = n_partial_pieces % 16; + dim_t n0_partial_pack = 0; + + // Split into multiple smaller fringe kernels, so as to maximize + // vectorization after packing. Any n0 < NR(64) can be expressed + // as n0 = 48 + n` / n0 = 32 + n` / n0 = 16 + n`, where n` < 16. + dim_t n0_48 = n_partial_pieces / 48; + dim_t n0_32 = n_partial_pieces / 32; + dim_t n0_16 = n_partial_pieces / 16; + + if ( n0_48 == 1 ) + { + packb_nr48_s8s8s32os32_row_major + ( + ( pack_b_buffer_s8s8s32o32 + ( n_full_pieces_loop_limit * KC_updated ) ), + ( pack_b_column_sum + n_full_pieces_loop_limit ), + ( b + n_full_pieces_loop_limit ), ldb, KC + ); + + n0_partial_pack = 48; + } + else if ( n0_32 == 1 ) + { + packb_nr32_s8s8s32os32_row_major + ( + ( pack_b_buffer_s8s8s32o32 + ( n_full_pieces_loop_limit * KC_updated ) ), + ( pack_b_column_sum + n_full_pieces_loop_limit ), + ( b + n_full_pieces_loop_limit ), ldb, KC + ); + + n0_partial_pack = 32; + } + else if ( n0_16 == 1 ) + { + packb_nr16_s8s8s32os32_row_major + ( + ( pack_b_buffer_s8s8s32o32 + ( n_full_pieces_loop_limit * KC_updated ) ), + ( pack_b_column_sum + n_full_pieces_loop_limit ), + ( b + n_full_pieces_loop_limit ), ldb, KC + ); + + n0_partial_pack = 16; + } + + if ( n0_partial_rem > 0 ) + { + packb_nrlt16_s8s8s32os32_row_major + ( + ( pack_b_buffer_s8s8s32o32 + ( n_full_pieces_loop_limit * KC_updated ) + + ( n0_partial_pack * KC_updated ) ), + ( pack_b_column_sum + n_full_pieces_loop_limit + n0_partial_pack ), + ( b + n_full_pieces_loop_limit + n0_partial_pack ), ldb, KC, + n0_partial_rem + ); + } + } + *rs_b = NR * 4; + *cs_b = NR; } -void packb_nr48_s8s8s32os32 +void packb_nr48_s8s8s32os32_row_major ( int8_t* pack_b_buffer_s8s8s32o32, int32_t* pack_b_column_sum, @@ -375,246 +466,266 @@ void packb_nr48_s8s8s32os32 const dim_t KC ) { - dim_t kr_new = 0; - - dim_t k_full_pieces_blks = KC / 4; - dim_t k_full_pieces = k_full_pieces_blks * 4; - dim_t k_partial_pieces = KC % 4; - - __m256i a0_32; - __m256i b0_32; - __m256i c0_32; - __m256i d0_32; - __m256i a01_32; - __m256i c01_32; - __m512i a0_zmm; - __m512i b0_zmm; - __m128i a0_16; - __m128i b0_16; - __m128i c0_16; - __m128i d0_16; - __m128i a01_16; - __m128i c01_16; - - //to compute column sum of B matrix - __m512i sum1, sum2, sum3; - __m512i mul_128 = _mm512_set1_epi32 (7); - - //load the temp buffer to compute column sum of B matrix - sum1 = _mm512_loadu_si512( pack_b_column_sum ); - sum2 = _mm512_loadu_si512( pack_b_column_sum + 16 ); //offset 16- as 16 int32 elements fit in 1 zmm register - sum3 = _mm512_loadu_si512( pack_b_column_sum + 32 ); - - for ( dim_t kr = 0; kr < k_full_pieces; kr += 4 ) - { - // Rearrange for vpdpbusd, read 4 rows from B with 32 elements in each row. - a0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( kr + 0 ) ) ); - b0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( kr + 1 ) ) ); - c0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( kr + 2 ) ) ); - d0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( kr + 3 ) ) ); - - //add all the columns : sum = add (sum, a0, b0, c0, d0) - sum1 = _mm512_add_epi32 ( sum1, _mm512_sllv_epi32 ( - _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( a0_32, 0)), - _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( b0_32, 0)), - _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( c0_32, 0)), - _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( d0_32, 0))))) , mul_128)); - - sum2 = _mm512_add_epi32 ( sum2, _mm512_sllv_epi32 ( - _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( a0_32, 1)), - _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( b0_32, 1)), - _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( c0_32, 1)), - _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( d0_32, 1))))) , mul_128)); - - a01_32 = _mm256_unpacklo_epi8( a0_32, b0_32 ); - a0_32 = _mm256_unpackhi_epi8( a0_32, b0_32 ); - - c01_32 = _mm256_unpacklo_epi8( c0_32, d0_32 ); - c0_32 = _mm256_unpackhi_epi8( c0_32, d0_32 ); - - b0_32 = _mm256_unpacklo_epi16( a01_32, c01_32 ); - a01_32 = _mm256_unpackhi_epi16( a01_32, c01_32 ); - - d0_32 = _mm256_unpacklo_epi16( a0_32, c0_32 ); - c01_32 = _mm256_unpackhi_epi16( a0_32, c0_32 ); - - a0_32 = _mm256_shuffle_i32x4( b0_32, a01_32, 0x0 ); // 0 elem - c0_32 = _mm256_shuffle_i32x4( b0_32, a01_32, 0x3 ); // 2 elem - b0_32 = _mm256_shuffle_i32x4( d0_32, c01_32, 0x0 ); // 1 elem - d0_32 = _mm256_shuffle_i32x4( d0_32, c01_32, 0x3 ); // 3 elem - - a0_zmm = _mm512_castsi256_si512( a0_32 ); - a0_zmm = _mm512_inserti32x8( a0_zmm, b0_32, 0x1 ); - b0_zmm = _mm512_castsi256_si512( c0_32 ); - b0_zmm = _mm512_inserti32x8( b0_zmm, d0_32, 0x1 ); - - // First 4x32 elements. - _mm512_storeu_si512( pack_b_buffer_s8s8s32o32 + ( ( kr_new + 0 ) * NR ), a0_zmm ); - _mm512_storeu_si512( pack_b_buffer_s8s8s32o32 + ( ( kr_new + 1 ) * NR ), b0_zmm ); - - // Rearrange for vpdpbusd, read 4 rows from B with next 16 elements in each row. - a0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( kr + 0 ) ) + ( 32 ) ); - b0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( kr + 1 ) ) + ( 32 ) ); - c0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( kr + 2 ) ) + ( 32 ) ); - d0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( kr + 3 ) ) + ( 32 ) ); - - //add all the columns : sum = add (sum, a0_32, b0_32, c0_32, d0_32) - sum3 = _mm512_add_epi32 ( sum3, _mm512_sllv_epi32 ( _mm512_add_epi32 ( _mm512_cvtepi8_epi32( a0_16 ), - _mm512_add_epi32 ( _mm512_cvtepi8_epi32( b0_16 ), _mm512_add_epi32 ( _mm512_cvtepi8_epi32( c0_16 ), - _mm512_cvtepi8_epi32( d0_16 )))) , mul_128 )); - - a01_16 = _mm_unpacklo_epi8( a0_16, b0_16 ); - a0_16 = _mm_unpackhi_epi8( a0_16, b0_16 ); - - c01_16 = _mm_unpacklo_epi8( c0_16, d0_16 ); - c0_16 = _mm_unpackhi_epi8( c0_16, d0_16 ); - - b0_16 = _mm_unpacklo_epi16( a01_16, c01_16 ); // 0 elem - a01_16 = _mm_unpackhi_epi16( a01_16, c01_16 ); // 1 elem - d0_16 = _mm_unpacklo_epi16( a0_16, c0_16 ); // 2 elem - c01_16 = _mm_unpackhi_epi16( a0_16, c0_16 ); // 3 elem - - a0_zmm = _mm512_castsi128_si512( b0_16 ); - a0_zmm = _mm512_inserti32x4( a0_zmm, a01_16, 0x1 ); - a0_zmm = _mm512_inserti32x4( a0_zmm, d0_16, 0x2 ); - a0_zmm = _mm512_inserti32x4( a0_zmm, c01_16, 0x3 ); - - // Last 4x16 elements. - _mm512_storeu_si512( pack_b_buffer_s8s8s32o32 + ( ( kr_new + 2 ) * NR ), a0_zmm ); - - // The 4th 16byte chunk will be ignored, since its not part of the original data, - // but is here due to the packing in 4 16byte chunks format. - kr_new += 3; - } - // Handle k remainder. - if ( k_partial_pieces > 0 ) - { - if ( k_partial_pieces == 3 ) - { - a0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( k_full_pieces + 0 ) ) ); - b0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( k_full_pieces + 1 ) ) ); - c0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( k_full_pieces + 2 ) ) ); - d0_32 = _mm256_setzero_si256(); - - //add all the columns : sum = add (sum, a0, b0, c0) - sum1 = _mm512_add_epi32 ( sum1, _mm512_sllv_epi32 ( - _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( a0_32, 0)), - _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( b0_32, 0)), - _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( c0_32, 0)))) , mul_128)); - - sum2 = _mm512_add_epi32 ( sum2, _mm512_sllv_epi32 ( - _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( a0_32, 1)), - _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( b0_32, 1)), - _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( c0_32, 1)))) , mul_128)); - - a0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( k_full_pieces + 0 ) ) + ( 32 ) ); - b0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( k_full_pieces + 1 ) ) + ( 32 ) ); - c0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( k_full_pieces + 2 ) ) + ( 32 ) ); - d0_16 = _mm_setzero_si128(); - - sum3 = _mm512_add_epi32 ( sum3, _mm512_sllv_epi32 (_mm512_add_epi32 ( _mm512_cvtepi8_epi32( a0_16 ), - _mm512_add_epi32 ( _mm512_cvtepi8_epi32( b0_16 ), _mm512_cvtepi8_epi32( c0_16 ))) , mul_128)); - - } - else if( k_partial_pieces == 2 ) - { - a0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( k_full_pieces + 0 ) ) ); - b0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( k_full_pieces + 1 ) ) ); - c0_32 = _mm256_setzero_si256(); - d0_32 = _mm256_setzero_si256(); - - //add all the columns : sum = add (sum, a0, b0) - sum1 = _mm512_add_epi32 ( sum1, _mm512_sllv_epi32 ( - _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( a0_32, 0)), - _mm512_cvtepi8_epi32( _mm256_extracti32x4_epi32 ( b0_32, 0) )) , mul_128 )); - - sum2 = _mm512_add_epi32 ( sum2, _mm512_sllv_epi32 ( - _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( a0_32, 1)), - _mm512_cvtepi8_epi32( _mm256_extracti32x4_epi32 ( b0_32, 1) )) , mul_128 )); - - a0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( k_full_pieces + 0 ) ) + ( 32 ) ); - b0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( k_full_pieces + 1 ) ) + ( 32 ) ); - c0_16 = _mm_setzero_si128(); - d0_16 = _mm_setzero_si128(); - - sum3 = _mm512_add_epi32 ( sum3, _mm512_sllv_epi32 ( _mm512_add_epi32 ( _mm512_cvtepi8_epi32( a0_16 ), - _mm512_cvtepi8_epi32( b0_16 )) , mul_128)); - } - else //k_partial_pieces == 1 - { - a0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( k_full_pieces + 0 ) ) ); - b0_32 = _mm256_setzero_si256(); - c0_32 = _mm256_setzero_si256(); - d0_32 = _mm256_setzero_si256(); - - //add all the columns : sum = add (sum, a0, b0) - sum1 = _mm512_add_epi32 ( sum1, _mm512_sllv_epi32 ( - _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( a0_32, 0)) , mul_128)); - - sum2 = _mm512_add_epi32 ( sum2, _mm512_sllv_epi32 ( - _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( a0_32, 1)) , mul_128)); - - a0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( k_full_pieces + 0 ) ) + ( 32 ) ); - b0_16 = _mm_setzero_si128(); - c0_16 = _mm_setzero_si128(); - d0_16 = _mm_setzero_si128(); - - sum3 = _mm512_add_epi32 ( sum3, _mm512_sllv_epi32 ( - _mm512_cvtepi8_epi32( a0_16 ) , mul_128)); - } - - a01_32 = _mm256_unpacklo_epi8( a0_32, b0_32 ); - a0_32 = _mm256_unpackhi_epi8( a0_32, b0_32 ); - - c01_32 = _mm256_unpacklo_epi8( c0_32, d0_32 ); - c0_32 = _mm256_unpackhi_epi8( c0_32, d0_32 ); - - b0_32 = _mm256_unpacklo_epi16( a01_32, c01_32 ); - a01_32 = _mm256_unpackhi_epi16( a01_32, c01_32 ); - - d0_32 = _mm256_unpacklo_epi16( a0_32, c0_32 ); - c01_32 = _mm256_unpackhi_epi16( a0_32, c0_32 ); - - a0_32 = _mm256_shuffle_i32x4( b0_32, a01_32, 0x0 ); // 0 elem - c0_32 = _mm256_shuffle_i32x4( b0_32, a01_32, 0x3 ); // 2 elem - b0_32 = _mm256_shuffle_i32x4( d0_32, c01_32, 0x0 ); // 1 elem - d0_32 = _mm256_shuffle_i32x4( d0_32, c01_32, 0x3 ); // 3 elem - - a0_zmm = _mm512_castsi256_si512( a0_32 ); - a0_zmm = _mm512_inserti32x8( a0_zmm, b0_32, 0x1 ); - b0_zmm = _mm512_castsi256_si512( c0_32 ); - b0_zmm = _mm512_inserti32x8( b0_zmm, d0_32, 0x1 ); - - // First 4x32 elements. - _mm512_storeu_si512( pack_b_buffer_s8s8s32o32 + ( ( kr_new + 0 ) * NR ), a0_zmm ); - _mm512_storeu_si512( pack_b_buffer_s8s8s32o32 + ( ( kr_new + 1 ) * NR ), b0_zmm ); - - a01_16 = _mm_unpacklo_epi8( a0_16, b0_16 ); - a0_16 = _mm_unpackhi_epi8( a0_16, b0_16 ); - - c01_16 = _mm_unpacklo_epi8( c0_16, d0_16 ); - c0_16 = _mm_unpackhi_epi8( c0_16, d0_16 ); - - b0_16 = _mm_unpacklo_epi16( a01_16, c01_16 ); // 0 elem - a01_16 = _mm_unpackhi_epi16( a01_16, c01_16 ); // 1 elem - d0_16 = _mm_unpacklo_epi16( a0_16, c0_16 ); // 2 elem - c01_16 = _mm_unpackhi_epi16( a0_16, c0_16 ); // 3 elem - - a0_zmm = _mm512_castsi128_si512( b0_16 ); - a0_zmm = _mm512_inserti32x4( a0_zmm, a01_16, 0x1 ); - a0_zmm = _mm512_inserti32x4( a0_zmm, d0_16, 0x2 ); - a0_zmm = _mm512_inserti32x4( a0_zmm, c01_16, 0x3 ); - - // Last 4x16 elements. - _mm512_storeu_si512( pack_b_buffer_s8s8s32o32 + ( ( kr_new + 2 ) * NR ), a0_zmm ); - } - //store the sum column - _mm512_storeu_si512( pack_b_column_sum, sum1 ); - _mm512_storeu_si512( pack_b_column_sum + 16, sum2 ); - _mm512_storeu_si512( pack_b_column_sum + 32, sum3 ); + dim_t NR = 64; + dim_t kr_new = 0; + + dim_t k_full_pieces_blks = KC / 4; + dim_t k_full_pieces = k_full_pieces_blks * 4; + dim_t k_partial_pieces = KC % 4; + + __m256i a0_32; + __m256i b0_32; + __m256i c0_32; + __m256i d0_32; + __m256i a01_32; + __m256i c01_32; + __m512i a0_zmm; + __m512i b0_zmm; + __m128i a0_16; + __m128i b0_16; + __m128i c0_16; + __m128i d0_16; + __m128i a01_16; + __m128i c01_16; + + //to compute column sum of B matrix + __m512i sum1, sum2, sum3; + __m512i mul_128 = _mm512_set1_epi32 (7); + + //load the temp buffer to compute column sum of B matrix + sum1 = _mm512_loadu_si512( pack_b_column_sum ); + sum2 = _mm512_loadu_si512( pack_b_column_sum + 16 ); + sum3 = _mm512_loadu_si512( pack_b_column_sum + 32 ); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 4 ) + { + // Rearrange for vpdpbusd, read 4 rows from B with 32 elements in each row. + a0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( kr + 0 ) ) ); + b0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( kr + 1 ) ) ); + c0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( kr + 2 ) ) ); + d0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( kr + 3 ) ) ); + + //add all the columns : sum = add (sum, a0, b0, c0, d0) + sum1 = + _mm512_add_epi32 ( sum1, _mm512_sllv_epi32 ( + _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( a0_32, 0)), + _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( b0_32, 0)), + _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( c0_32, 0)), + _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( d0_32, 0))))) , mul_128)); + + sum2 = + _mm512_add_epi32 ( sum2, _mm512_sllv_epi32 ( + _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( a0_32, 1)), + _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( b0_32, 1)), + _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( c0_32, 1)), + _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( d0_32, 1))))) , mul_128)); + + a01_32 = _mm256_unpacklo_epi8( a0_32, b0_32 ); + a0_32 = _mm256_unpackhi_epi8( a0_32, b0_32 ); + + c01_32 = _mm256_unpacklo_epi8( c0_32, d0_32 ); + c0_32 = _mm256_unpackhi_epi8( c0_32, d0_32 ); + + b0_32 = _mm256_unpacklo_epi16( a01_32, c01_32 ); + a01_32 = _mm256_unpackhi_epi16( a01_32, c01_32 ); + + d0_32 = _mm256_unpacklo_epi16( a0_32, c0_32 ); + c01_32 = _mm256_unpackhi_epi16( a0_32, c0_32 ); + + a0_32 = _mm256_shuffle_i32x4( b0_32, a01_32, 0x0 ); // 0 elem + c0_32 = _mm256_shuffle_i32x4( b0_32, a01_32, 0x3 ); // 2 elem + b0_32 = _mm256_shuffle_i32x4( d0_32, c01_32, 0x0 ); // 1 elem + d0_32 = _mm256_shuffle_i32x4( d0_32, c01_32, 0x3 ); // 3 elem + + a0_zmm = _mm512_castsi256_si512( a0_32 ); + a0_zmm = _mm512_inserti32x8( a0_zmm, b0_32, 0x1 ); + b0_zmm = _mm512_castsi256_si512( c0_32 ); + b0_zmm = _mm512_inserti32x8( b0_zmm, d0_32, 0x1 ); + + // First 4x32 elements. + _mm512_storeu_si512( pack_b_buffer_s8s8s32o32 + ( ( kr_new + 0 ) * NR ), a0_zmm ); + _mm512_storeu_si512( pack_b_buffer_s8s8s32o32 + ( ( kr_new + 1 ) * NR ), b0_zmm ); + + // Rearrange for vpdpbusd, read 4 rows from B with next 16 elements in each row. + a0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( kr + 0 ) ) + ( 32 ) ); + b0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( kr + 1 ) ) + ( 32 ) ); + c0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( kr + 2 ) ) + ( 32 ) ); + d0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( kr + 3 ) ) + ( 32 ) ); + + //add all the columns : sum = add (sum, a0_32, b0_32, c0_32, d0_32) + sum3 = + _mm512_add_epi32 + ( sum3, _mm512_sllv_epi32 ( _mm512_add_epi32 ( _mm512_cvtepi8_epi32( a0_16 ), + _mm512_add_epi32 ( _mm512_cvtepi8_epi32( b0_16 ), + _mm512_add_epi32 ( _mm512_cvtepi8_epi32( c0_16 ), + _mm512_cvtepi8_epi32( d0_16 )))) , mul_128 ) + ); + + a01_16 = _mm_unpacklo_epi8( a0_16, b0_16 ); + a0_16 = _mm_unpackhi_epi8( a0_16, b0_16 ); + + c01_16 = _mm_unpacklo_epi8( c0_16, d0_16 ); + c0_16 = _mm_unpackhi_epi8( c0_16, d0_16 ); + + b0_16 = _mm_unpacklo_epi16( a01_16, c01_16 ); // 0 elem + a01_16 = _mm_unpackhi_epi16( a01_16, c01_16 ); // 1 elem + d0_16 = _mm_unpacklo_epi16( a0_16, c0_16 ); // 2 elem + c01_16 = _mm_unpackhi_epi16( a0_16, c0_16 ); // 3 elem + + a0_zmm = _mm512_castsi128_si512( b0_16 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, a01_16, 0x1 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, d0_16, 0x2 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, c01_16, 0x3 ); + + // Last 4x16 elements. + _mm512_storeu_si512( pack_b_buffer_s8s8s32o32 + ( ( kr_new + 2 ) * NR ), a0_zmm ); + + // The 4th 16byte chunk will be ignored, since its not part of the original data, + // but is here due to the packing in 4 16byte chunks format. + kr_new += 3; + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + if ( k_partial_pieces == 3 ) + { + a0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( k_full_pieces + 0 ) ) ); + b0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( k_full_pieces + 1 ) ) ); + c0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( k_full_pieces + 2 ) ) ); + d0_32 = _mm256_setzero_si256(); + + //add all the columns : sum = add (sum, a0, b0, c0) + sum1 = + _mm512_add_epi32 ( sum1, _mm512_sllv_epi32 ( + _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( a0_32, 0)), + _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( b0_32, 0)), + _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( c0_32, 0)))) , mul_128)); + + sum2 = + _mm512_add_epi32 ( sum2, _mm512_sllv_epi32 ( + _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( a0_32, 1)), + _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( b0_32, 1)), + _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( c0_32, 1)))) , mul_128)); + + a0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( k_full_pieces + 0 ) ) + ( 32 ) ); + b0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( k_full_pieces + 1 ) ) + ( 32 ) ); + c0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( k_full_pieces + 2 ) ) + ( 32 ) ); + d0_16 = _mm_setzero_si128(); + + sum3 = + _mm512_add_epi32 + ( sum3, _mm512_sllv_epi32 (_mm512_add_epi32 ( _mm512_cvtepi8_epi32( a0_16 ), + _mm512_add_epi32 ( _mm512_cvtepi8_epi32( b0_16 ), + _mm512_cvtepi8_epi32( c0_16 ))) , mul_128) + ); + } + else if( k_partial_pieces == 2 ) + { + a0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( k_full_pieces + 0 ) ) ); + b0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( k_full_pieces + 1 ) ) ); + c0_32 = _mm256_setzero_si256(); + d0_32 = _mm256_setzero_si256(); + + //add all the columns : sum = add (sum, a0, b0) + sum1 = + _mm512_add_epi32 ( sum1, _mm512_sllv_epi32 ( + _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( a0_32, 0)), + _mm512_cvtepi8_epi32( _mm256_extracti32x4_epi32 ( b0_32, 0) )) , mul_128 )); + + sum2 = + _mm512_add_epi32 ( sum2, _mm512_sllv_epi32 ( + _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( a0_32, 1)), + _mm512_cvtepi8_epi32( _mm256_extracti32x4_epi32 ( b0_32, 1) )) , mul_128 )); + + a0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( k_full_pieces + 0 ) ) + ( 32 ) ); + b0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( k_full_pieces + 1 ) ) + ( 32 ) ); + c0_16 = _mm_setzero_si128(); + d0_16 = _mm_setzero_si128(); + + sum3 = + _mm512_add_epi32 + ( sum3, _mm512_sllv_epi32 ( _mm512_add_epi32 ( _mm512_cvtepi8_epi32( a0_16 ), + _mm512_cvtepi8_epi32( b0_16 )) , mul_128) + ); + } + else //k_partial_pieces == 1 + { + a0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( k_full_pieces + 0 ) ) ); + b0_32 = _mm256_setzero_si256(); + c0_32 = _mm256_setzero_si256(); + d0_32 = _mm256_setzero_si256(); + + //add all the columns : sum = add (sum, a0, b0) + sum1 = + _mm512_add_epi32 ( sum1, _mm512_sllv_epi32 ( + _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( a0_32, 0)) , mul_128)); + + sum2 = + _mm512_add_epi32 ( sum2, _mm512_sllv_epi32 ( + _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( a0_32, 1)) , mul_128)); + + a0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( k_full_pieces + 0 ) ) + ( 32 ) ); + b0_16 = _mm_setzero_si128(); + c0_16 = _mm_setzero_si128(); + d0_16 = _mm_setzero_si128(); + + sum3 = + _mm512_add_epi32 ( sum3, _mm512_sllv_epi32 ( + _mm512_cvtepi8_epi32( a0_16 ) , mul_128)); + } + + a01_32 = _mm256_unpacklo_epi8( a0_32, b0_32 ); + a0_32 = _mm256_unpackhi_epi8( a0_32, b0_32 ); + + c01_32 = _mm256_unpacklo_epi8( c0_32, d0_32 ); + c0_32 = _mm256_unpackhi_epi8( c0_32, d0_32 ); + + b0_32 = _mm256_unpacklo_epi16( a01_32, c01_32 ); + a01_32 = _mm256_unpackhi_epi16( a01_32, c01_32 ); + + d0_32 = _mm256_unpacklo_epi16( a0_32, c0_32 ); + c01_32 = _mm256_unpackhi_epi16( a0_32, c0_32 ); + + a0_32 = _mm256_shuffle_i32x4( b0_32, a01_32, 0x0 ); // 0 elem + c0_32 = _mm256_shuffle_i32x4( b0_32, a01_32, 0x3 ); // 2 elem + b0_32 = _mm256_shuffle_i32x4( d0_32, c01_32, 0x0 ); // 1 elem + d0_32 = _mm256_shuffle_i32x4( d0_32, c01_32, 0x3 ); // 3 elem + + a0_zmm = _mm512_castsi256_si512( a0_32 ); + a0_zmm = _mm512_inserti32x8( a0_zmm, b0_32, 0x1 ); + b0_zmm = _mm512_castsi256_si512( c0_32 ); + b0_zmm = _mm512_inserti32x8( b0_zmm, d0_32, 0x1 ); + + // First 4x32 elements. + _mm512_storeu_si512( pack_b_buffer_s8s8s32o32 + ( ( kr_new + 0 ) * NR ), a0_zmm ); + _mm512_storeu_si512( pack_b_buffer_s8s8s32o32 + ( ( kr_new + 1 ) * NR ), b0_zmm ); + + a01_16 = _mm_unpacklo_epi8( a0_16, b0_16 ); + a0_16 = _mm_unpackhi_epi8( a0_16, b0_16 ); + + c01_16 = _mm_unpacklo_epi8( c0_16, d0_16 ); + c0_16 = _mm_unpackhi_epi8( c0_16, d0_16 ); + + b0_16 = _mm_unpacklo_epi16( a01_16, c01_16 ); // 0 elem + a01_16 = _mm_unpackhi_epi16( a01_16, c01_16 ); // 1 elem + d0_16 = _mm_unpacklo_epi16( a0_16, c0_16 ); // 2 elem + c01_16 = _mm_unpackhi_epi16( a0_16, c0_16 ); // 3 elem + + a0_zmm = _mm512_castsi128_si512( b0_16 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, a01_16, 0x1 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, d0_16, 0x2 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, c01_16, 0x3 ); + + // Last 4x16 elements. + _mm512_storeu_si512( pack_b_buffer_s8s8s32o32 + ( ( kr_new + 2 ) * NR ), a0_zmm ); + } + //store the sum column + _mm512_storeu_si512( pack_b_column_sum, sum1 ); + _mm512_storeu_si512( pack_b_column_sum + 16, sum2 ); + _mm512_storeu_si512( pack_b_column_sum + 32, sum3 ); } -void packb_nr32_s8s8s32os32 +void packb_nr32_s8s8s32os32_row_major ( int8_t* pack_b_buffer_s8s8s32o32, int32_t* pack_b_column_sum, @@ -623,165 +734,174 @@ void packb_nr32_s8s8s32os32 const dim_t KC ) { - dim_t kr_new = 0; - - dim_t k_full_pieces_blks = KC / 4; - dim_t k_full_pieces = k_full_pieces_blks * 4; - dim_t k_partial_pieces = KC % 4; - - __m256i a0_32; - __m256i b0_32; - __m256i c0_32; - __m256i d0_32; - __m256i a01_32; - __m256i c01_32; - __m512i a0_zmm; - __m512i b0_zmm; - - //to compute column sum of B matrix - __m512i sum1, sum2; - __m512i mul_128 = _mm512_set1_epi32 (7); - - //load the temp buffer to compute column sum of B matrix - sum1 = _mm512_loadu_si512( pack_b_column_sum ); - sum2 = _mm512_loadu_si512( pack_b_column_sum + 16 ); //offset 16- as 16 int32 elements fit in 1 zmm register - - for ( dim_t kr = 0; kr < k_full_pieces; kr += 4 ) - { - // Rearrange for vpdpbusd, read 4 rows from B with 32 elements in each row. - a0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( kr + 0 ) ) ); - b0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( kr + 1 ) ) ); - c0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( kr + 2 ) ) ); - d0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( kr + 3 ) ) ); - - //add all the columns : sum = add (sum, a0, b0, c0, d0) - sum1 = _mm512_add_epi32 ( sum1, _mm512_sllv_epi32 ( - _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( a0_32, 0)), - _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( b0_32, 0)), - _mm512_add_epi32 (_mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( c0_32, 0)), - _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( d0_32, 0))))) , mul_128)); - - sum2 = _mm512_add_epi32 ( sum2, _mm512_sllv_epi32 ( - _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( a0_32, 1)), - _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( b0_32, 1)), - _mm512_add_epi32 (_mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( c0_32, 1)), - _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( d0_32, 1))))) , mul_128)); - - a01_32 = _mm256_unpacklo_epi8( a0_32, b0_32 ); - a0_32 = _mm256_unpackhi_epi8( a0_32, b0_32 ); - - c01_32 = _mm256_unpacklo_epi8( c0_32, d0_32 ); - c0_32 = _mm256_unpackhi_epi8( c0_32, d0_32 ); - - b0_32 = _mm256_unpacklo_epi16( a01_32, c01_32 ); - a01_32 = _mm256_unpackhi_epi16( a01_32, c01_32 ); - - d0_32 = _mm256_unpacklo_epi16( a0_32, c0_32 ); - c01_32 = _mm256_unpackhi_epi16( a0_32, c0_32 ); - - a0_32 = _mm256_shuffle_i32x4( b0_32, a01_32, 0x0 ); // 0 elem - c0_32 = _mm256_shuffle_i32x4( b0_32, a01_32, 0x3 ); // 2 elem - b0_32 = _mm256_shuffle_i32x4( d0_32, c01_32, 0x0 ); // 1 elem - d0_32 = _mm256_shuffle_i32x4( d0_32, c01_32, 0x3 ); // 3 elem - - a0_zmm = _mm512_castsi256_si512( a0_32 ); - a0_zmm = _mm512_inserti32x8( a0_zmm, b0_32, 0x1 ); - b0_zmm = _mm512_castsi256_si512( c0_32 ); - b0_zmm = _mm512_inserti32x8( b0_zmm, d0_32, 0x1 ); - - // First 4x32 elements. - _mm512_storeu_si512( pack_b_buffer_s8s8s32o32 + ( ( kr_new + 0 ) * NR ), a0_zmm ); - _mm512_storeu_si512( pack_b_buffer_s8s8s32o32 + ( ( kr_new + 1 ) * NR ), b0_zmm ); - - // The 3rd and 4th 16byte chunk will be ignored, since its not part of the original data, - // but is here due to the packing in 4 16byte chunks format. - kr_new += 2; - } - // Handle k remainder. - if ( k_partial_pieces > 0 ) - { - if ( k_partial_pieces == 3 ) - { - a0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( k_full_pieces + 0 ) ) ); - b0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( k_full_pieces + 1 ) ) ); - c0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( k_full_pieces + 2 ) ) ); - d0_32 = _mm256_setzero_si256(); - - //add all the columns : sum = add (sum, a0, b0, c0) - sum1 = _mm512_add_epi32 ( sum1, _mm512_sllv_epi32 ( - _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( a0_32, 0)), - _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( b0_32, 0)), - _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( c0_32, 0)))) , mul_128)); - - sum2 = _mm512_add_epi32 ( sum2, _mm512_sllv_epi32 ( - _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( a0_32, 1)), - _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( b0_32, 1)), - _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( c0_32, 1)))) , mul_128)); - - } - else if( k_partial_pieces == 2 ) - { - a0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( k_full_pieces + 0 ) ) ); - b0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( k_full_pieces + 1 ) ) ); - c0_32 = _mm256_setzero_si256(); - d0_32 = _mm256_setzero_si256(); - - //add all the columns : sum = add (sum, a0, b0) - sum1 = _mm512_add_epi32 ( sum1, _mm512_sllv_epi32 ( - _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( a0_32, 0)), - _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( b0_32, 0))) , mul_128)); - - sum2 = _mm512_add_epi32 ( sum2, _mm512_sllv_epi32 ( - _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( a0_32, 1)), - _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( b0_32, 1))) , mul_128)); - } - else //k_partial_pieces == 1 - { - a0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( k_full_pieces + 0 ) ) ); - b0_32 = _mm256_setzero_si256(); - c0_32 = _mm256_setzero_si256(); - d0_32 = _mm256_setzero_si256(); - - //add all the columns : sum = add (sum, a0, b0) - sum1 = _mm512_add_epi32 ( sum1, _mm512_sllv_epi32 ( - _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( a0_32, 0)) , mul_128)); - - sum2 = _mm512_add_epi32 ( sum2, _mm512_sllv_epi32 ( - _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( a0_32, 1)) , mul_128)); - } - - a01_32 = _mm256_unpacklo_epi8( a0_32, b0_32 ); - a0_32 = _mm256_unpackhi_epi8( a0_32, b0_32 ); - - c01_32 = _mm256_unpacklo_epi8( c0_32, d0_32 ); - c0_32 = _mm256_unpackhi_epi8( c0_32, d0_32 ); - - b0_32 = _mm256_unpacklo_epi16( a01_32, c01_32 ); - a01_32 = _mm256_unpackhi_epi16( a01_32, c01_32 ); - - d0_32 = _mm256_unpacklo_epi16( a0_32, c0_32 ); - c01_32 = _mm256_unpackhi_epi16( a0_32, c0_32 ); - - a0_32 = _mm256_shuffle_i32x4( b0_32, a01_32, 0x0 ); // 0 elem - c0_32 = _mm256_shuffle_i32x4( b0_32, a01_32, 0x3 ); // 2 elem - b0_32 = _mm256_shuffle_i32x4( d0_32, c01_32, 0x0 ); // 1 elem - d0_32 = _mm256_shuffle_i32x4( d0_32, c01_32, 0x3 ); // 3 elem - - a0_zmm = _mm512_castsi256_si512( a0_32 ); - a0_zmm = _mm512_inserti32x8( a0_zmm, b0_32, 0x1 ); - b0_zmm = _mm512_castsi256_si512( c0_32 ); - b0_zmm = _mm512_inserti32x8( b0_zmm, d0_32, 0x1 ); - - // First 4x32 elements. - _mm512_storeu_si512( pack_b_buffer_s8s8s32o32 + ( ( kr_new + 0 ) * NR ), a0_zmm ); - _mm512_storeu_si512( pack_b_buffer_s8s8s32o32 + ( ( kr_new + 1 ) * NR ), b0_zmm ); - } - //store the sum column - _mm512_storeu_si512( pack_b_column_sum, sum1 ); - _mm512_storeu_si512( pack_b_column_sum + 16, sum2 ); + dim_t NR = 64; + dim_t kr_new = 0; + + dim_t k_full_pieces_blks = KC / 4; + dim_t k_full_pieces = k_full_pieces_blks * 4; + dim_t k_partial_pieces = KC % 4; + + __m256i a0_32; + __m256i b0_32; + __m256i c0_32; + __m256i d0_32; + __m256i a01_32; + __m256i c01_32; + __m512i a0_zmm; + __m512i b0_zmm; + + //to compute column sum of B matrix + __m512i sum1, sum2; + __m512i mul_128 = _mm512_set1_epi32 (7); + + //load the temp buffer to compute column sum of B matrix + sum1 = _mm512_loadu_si512( pack_b_column_sum ); + sum2 = _mm512_loadu_si512( pack_b_column_sum + 16 ); //offset 16- as 16 int32 elements fit in 1 zmm register + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 4 ) + { + // Rearrange for vpdpbusd, read 4 rows from B with 32 elements in each row. + a0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( kr + 0 ) ) ); + b0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( kr + 1 ) ) ); + c0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( kr + 2 ) ) ); + d0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( kr + 3 ) ) ); + + //add all the columns : sum = add (sum, a0, b0, c0, d0) + sum1 = + _mm512_add_epi32 ( sum1, _mm512_sllv_epi32 ( + _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( a0_32, 0)), + _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( b0_32, 0)), + _mm512_add_epi32 (_mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( c0_32, 0)), + _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( d0_32, 0))))) , mul_128)); + + sum2 = + _mm512_add_epi32 ( sum2, _mm512_sllv_epi32 ( + _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( a0_32, 1)), + _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( b0_32, 1)), + _mm512_add_epi32 (_mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( c0_32, 1)), + _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( d0_32, 1))))) , mul_128)); + + a01_32 = _mm256_unpacklo_epi8( a0_32, b0_32 ); + a0_32 = _mm256_unpackhi_epi8( a0_32, b0_32 ); + + c01_32 = _mm256_unpacklo_epi8( c0_32, d0_32 ); + c0_32 = _mm256_unpackhi_epi8( c0_32, d0_32 ); + + b0_32 = _mm256_unpacklo_epi16( a01_32, c01_32 ); + a01_32 = _mm256_unpackhi_epi16( a01_32, c01_32 ); + + d0_32 = _mm256_unpacklo_epi16( a0_32, c0_32 ); + c01_32 = _mm256_unpackhi_epi16( a0_32, c0_32 ); + + a0_32 = _mm256_shuffle_i32x4( b0_32, a01_32, 0x0 ); // 0 elem + c0_32 = _mm256_shuffle_i32x4( b0_32, a01_32, 0x3 ); // 2 elem + b0_32 = _mm256_shuffle_i32x4( d0_32, c01_32, 0x0 ); // 1 elem + d0_32 = _mm256_shuffle_i32x4( d0_32, c01_32, 0x3 ); // 3 elem + + a0_zmm = _mm512_castsi256_si512( a0_32 ); + a0_zmm = _mm512_inserti32x8( a0_zmm, b0_32, 0x1 ); + b0_zmm = _mm512_castsi256_si512( c0_32 ); + b0_zmm = _mm512_inserti32x8( b0_zmm, d0_32, 0x1 ); + + // First 4x32 elements. + _mm512_storeu_si512( pack_b_buffer_s8s8s32o32 + ( ( kr_new + 0 ) * NR ), a0_zmm ); + _mm512_storeu_si512( pack_b_buffer_s8s8s32o32 + ( ( kr_new + 1 ) * NR ), b0_zmm ); + + // The 3rd and 4th 16byte chunk will be ignored, since its not part of the original data, + // but is here due to the packing in 4 16byte chunks format. + kr_new += 2; + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + if ( k_partial_pieces == 3 ) + { + a0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( k_full_pieces + 0 ) ) ); + b0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( k_full_pieces + 1 ) ) ); + c0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( k_full_pieces + 2 ) ) ); + d0_32 = _mm256_setzero_si256(); + + //add all the columns : sum = add (sum, a0, b0, c0) + sum1 = + _mm512_add_epi32 ( sum1, _mm512_sllv_epi32 ( + _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( a0_32, 0)), + _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( b0_32, 0)), + _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( c0_32, 0)))) , mul_128)); + + sum2 = + _mm512_add_epi32 ( sum2, _mm512_sllv_epi32 ( + _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( a0_32, 1)), + _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( b0_32, 1)), + _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( c0_32, 1)))) , mul_128)); + + } + else if( k_partial_pieces == 2 ) + { + a0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( k_full_pieces + 0 ) ) ); + b0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( k_full_pieces + 1 ) ) ); + c0_32 = _mm256_setzero_si256(); + d0_32 = _mm256_setzero_si256(); + + //add all the columns : sum = add (sum, a0, b0) + sum1 = + _mm512_add_epi32 ( sum1, _mm512_sllv_epi32 ( + _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( a0_32, 0)), + _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( b0_32, 0))) , mul_128)); + + sum2 = + _mm512_add_epi32 ( sum2, _mm512_sllv_epi32 ( + _mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( a0_32, 1)), + _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( b0_32, 1))) , mul_128)); + } + else //k_partial_pieces == 1 + { + a0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( k_full_pieces + 0 ) ) ); + b0_32 = _mm256_setzero_si256(); + c0_32 = _mm256_setzero_si256(); + d0_32 = _mm256_setzero_si256(); + + //add all the columns : sum = add (sum, a0, b0) + sum1 = + _mm512_add_epi32 ( sum1, _mm512_sllv_epi32 ( + _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( a0_32, 0)) , mul_128)); + + sum2 = + _mm512_add_epi32 ( sum2, _mm512_sllv_epi32 ( + _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( a0_32, 1)) , mul_128)); + } + + a01_32 = _mm256_unpacklo_epi8( a0_32, b0_32 ); + a0_32 = _mm256_unpackhi_epi8( a0_32, b0_32 ); + + c01_32 = _mm256_unpacklo_epi8( c0_32, d0_32 ); + c0_32 = _mm256_unpackhi_epi8( c0_32, d0_32 ); + + b0_32 = _mm256_unpacklo_epi16( a01_32, c01_32 ); + a01_32 = _mm256_unpackhi_epi16( a01_32, c01_32 ); + + d0_32 = _mm256_unpacklo_epi16( a0_32, c0_32 ); + c01_32 = _mm256_unpackhi_epi16( a0_32, c0_32 ); + + a0_32 = _mm256_shuffle_i32x4( b0_32, a01_32, 0x0 ); // 0 elem + c0_32 = _mm256_shuffle_i32x4( b0_32, a01_32, 0x3 ); // 2 elem + b0_32 = _mm256_shuffle_i32x4( d0_32, c01_32, 0x0 ); // 1 elem + d0_32 = _mm256_shuffle_i32x4( d0_32, c01_32, 0x3 ); // 3 elem + + a0_zmm = _mm512_castsi256_si512( a0_32 ); + a0_zmm = _mm512_inserti32x8( a0_zmm, b0_32, 0x1 ); + b0_zmm = _mm512_castsi256_si512( c0_32 ); + b0_zmm = _mm512_inserti32x8( b0_zmm, d0_32, 0x1 ); + + // First 4x32 elements. + _mm512_storeu_si512( pack_b_buffer_s8s8s32o32 + ( ( kr_new + 0 ) * NR ), a0_zmm ); + _mm512_storeu_si512( pack_b_buffer_s8s8s32o32 + ( ( kr_new + 1 ) * NR ), b0_zmm ); + } + //store the sum column + _mm512_storeu_si512( pack_b_column_sum, sum1 ); + _mm512_storeu_si512( pack_b_column_sum + 16, sum2 ); } -void packb_nr16_s8s8s32os32 +void packb_nr16_s8s8s32os32_row_major ( int8_t* pack_b_buffer_s8s8s32o32, int32_t* pack_b_column_sum, @@ -790,121 +910,136 @@ void packb_nr16_s8s8s32os32 const dim_t KC ) { - dim_t kr_new = 0; - - dim_t k_full_pieces_blks = KC / 4; - dim_t k_full_pieces = k_full_pieces_blks * 4; - dim_t k_partial_pieces = KC % 4; - - __m128i a0_16; - __m128i b0_16; - __m128i c0_16; - __m128i d0_16; - __m128i a01_16; - __m128i c01_16; - __m512i a0_zmm; - - //to compute column sum of B matrix - __m512i sum1; - __m512i mul_128 = _mm512_set1_epi32 (7); - - //load the temp buffer to compute column sum of B matrix - sum1 = _mm512_loadu_si512( pack_b_column_sum ); - - for ( dim_t kr = 0; kr < k_full_pieces; kr += 4 ) - { - // Rearrange for vpdpbusd, read 4 rows from B with next 16 elements in each row. - a0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( kr + 0 ) ) ); - b0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( kr + 1 ) ) ); - c0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( kr + 2 ) ) ); - d0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( kr + 3 ) ) ); - - //add all the columns : sum = add (sum, a0, b0, c0, d0) - sum1 = _mm512_add_epi32 ( sum1, _mm512_sllv_epi32 ( _mm512_add_epi32 ( _mm512_cvtepi8_epi32( a0_16 ), - _mm512_add_epi32 ( _mm512_cvtepi8_epi32( b0_16 ), _mm512_add_epi32 ( _mm512_cvtepi8_epi32( c0_16 ), - _mm512_cvtepi8_epi32( d0_16 )))) , mul_128 )); - - a01_16 = _mm_unpacklo_epi8( a0_16, b0_16 ); - a0_16 = _mm_unpackhi_epi8( a0_16, b0_16 ); - - c01_16 = _mm_unpacklo_epi8( c0_16, d0_16 ); - c0_16 = _mm_unpackhi_epi8( c0_16, d0_16 ); - - b0_16 = _mm_unpacklo_epi16( a01_16, c01_16 ); // 0 elem - a01_16 = _mm_unpackhi_epi16( a01_16, c01_16 ); // 1 elem - d0_16 = _mm_unpacklo_epi16( a0_16, c0_16 ); // 2 elem - c01_16 = _mm_unpackhi_epi16( a0_16, c0_16 ); // 3 elem - - a0_zmm = _mm512_castsi128_si512( b0_16 ); - a0_zmm = _mm512_inserti32x4( a0_zmm, a01_16, 0x1 ); - a0_zmm = _mm512_inserti32x4( a0_zmm, d0_16, 0x2 ); - a0_zmm = _mm512_inserti32x4( a0_zmm, c01_16, 0x3 ); - - // Last 4x16 elements. - _mm512_storeu_si512( pack_b_buffer_s8s8s32o32 + ( ( kr_new + 0 ) * NR ), a0_zmm ); - - // The 2nd, 3rd, and 4th 16byte chunk will be ignored, since its not part of the original data, - // but is here due to the packing in 4 16byte chunks format. - kr_new += 1; - } - // Handle k remainder. - if ( k_partial_pieces > 0 ) - { - if ( k_partial_pieces == 3 ) - { - a0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( k_full_pieces + 0 ) ) ); - b0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( k_full_pieces + 1 ) ) ); - c0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( k_full_pieces + 2 ) ) ); - d0_16 = _mm_setzero_si128(); - - sum1 = _mm512_add_epi32 ( sum1, _mm512_sllv_epi32 (_mm512_add_epi32 ( _mm512_cvtepi8_epi32( a0_16 ), - _mm512_add_epi32 ( _mm512_cvtepi8_epi32( b0_16 ), _mm512_cvtepi8_epi32( c0_16 ))) , mul_128)); - - } - else if( k_partial_pieces == 2 ) - { - a0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( k_full_pieces + 0 ) ) ); - b0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( k_full_pieces + 1 ) ) ); - c0_16 = _mm_setzero_si128(); - d0_16 = _mm_setzero_si128(); - - sum1 = _mm512_add_epi32 ( sum1, _mm512_sllv_epi32 ( _mm512_add_epi32 ( _mm512_cvtepi8_epi32( a0_16 ), - _mm512_cvtepi8_epi32( b0_16 )) , mul_128)); - } - else //k_partial_pieces == 1 - { - a0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( k_full_pieces + 0 ) ) ); - b0_16 = _mm_setzero_si128(); - c0_16 = _mm_setzero_si128(); - d0_16 = _mm_setzero_si128(); - - sum1 = _mm512_add_epi32 ( sum1, _mm512_sllv_epi32 ( _mm512_cvtepi8_epi32( a0_16 ) , mul_128 )); - } - - a01_16 = _mm_unpacklo_epi8( a0_16, b0_16 ); - a0_16 = _mm_unpackhi_epi8( a0_16, b0_16 ); - - c01_16 = _mm_unpacklo_epi8( c0_16, d0_16 ); - c0_16 = _mm_unpackhi_epi8( c0_16, d0_16 ); - - b0_16 = _mm_unpacklo_epi16( a01_16, c01_16 ); // 0 elem - a01_16 = _mm_unpackhi_epi16( a01_16, c01_16 ); // 1 elem - d0_16 = _mm_unpacklo_epi16( a0_16, c0_16 ); // 2 elem - c01_16 = _mm_unpackhi_epi16( a0_16, c0_16 ); // 3 elem - - __m512i a0_zmm = _mm512_castsi128_si512( b0_16 ); - a0_zmm = _mm512_inserti32x4( a0_zmm, a01_16, 0x1 ); - a0_zmm = _mm512_inserti32x4( a0_zmm, d0_16, 0x2 ); - a0_zmm = _mm512_inserti32x4( a0_zmm, c01_16, 0x3 ); - - // Last 4x16 elements. - _mm512_storeu_si512( pack_b_buffer_s8s8s32o32 + ( ( kr_new + 0 ) * NR ), a0_zmm ); - } - //store the sum column - _mm512_storeu_si512( pack_b_column_sum, sum1 ); + dim_t NR = 64; + dim_t kr_new = 0; + + dim_t k_full_pieces_blks = KC / 4; + dim_t k_full_pieces = k_full_pieces_blks * 4; + dim_t k_partial_pieces = KC % 4; + + __m128i a0_16; + __m128i b0_16; + __m128i c0_16; + __m128i d0_16; + __m128i a01_16; + __m128i c01_16; + __m512i a0_zmm; + + //to compute column sum of B matrix + __m512i sum1; + __m512i mul_128 = _mm512_set1_epi32 (7); + + //load the temp buffer to compute column sum of B matrix + sum1 = _mm512_loadu_si512( pack_b_column_sum ); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 4 ) + { + // Rearrange for vpdpbusd, read 4 rows from B with next 16 elements in each row. + a0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( kr + 0 ) ) ); + b0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( kr + 1 ) ) ); + c0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( kr + 2 ) ) ); + d0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( kr + 3 ) ) ); + + //add all the columns : sum = add (sum, a0, b0, c0, d0) + sum1 = + _mm512_add_epi32 + ( sum1, _mm512_sllv_epi32 ( _mm512_add_epi32 ( _mm512_cvtepi8_epi32( a0_16 ), + _mm512_add_epi32 ( _mm512_cvtepi8_epi32( b0_16 ), + _mm512_add_epi32 ( _mm512_cvtepi8_epi32( c0_16 ), + _mm512_cvtepi8_epi32( d0_16 )))) , mul_128 ) + ); + + a01_16 = _mm_unpacklo_epi8( a0_16, b0_16 ); + a0_16 = _mm_unpackhi_epi8( a0_16, b0_16 ); + + c01_16 = _mm_unpacklo_epi8( c0_16, d0_16 ); + c0_16 = _mm_unpackhi_epi8( c0_16, d0_16 ); + + b0_16 = _mm_unpacklo_epi16( a01_16, c01_16 ); // 0 elem + a01_16 = _mm_unpackhi_epi16( a01_16, c01_16 ); // 1 elem + d0_16 = _mm_unpacklo_epi16( a0_16, c0_16 ); // 2 elem + c01_16 = _mm_unpackhi_epi16( a0_16, c0_16 ); // 3 elem + + a0_zmm = _mm512_castsi128_si512( b0_16 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, a01_16, 0x1 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, d0_16, 0x2 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, c01_16, 0x3 ); + + // Last 4x16 elements. + _mm512_storeu_si512( pack_b_buffer_s8s8s32o32 + ( ( kr_new + 0 ) * NR ), a0_zmm ); + + // The 2nd, 3rd, and 4th 16byte chunk will be ignored, since its not part of the original data, + // but is here due to the packing in 4 16byte chunks format. + kr_new += 1; + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + if ( k_partial_pieces == 3 ) + { + a0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( k_full_pieces + 0 ) ) ); + b0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( k_full_pieces + 1 ) ) ); + c0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( k_full_pieces + 2 ) ) ); + d0_16 = _mm_setzero_si128(); + + sum1 = + _mm512_add_epi32 + ( sum1, _mm512_sllv_epi32 (_mm512_add_epi32 ( _mm512_cvtepi8_epi32( a0_16 ), + _mm512_add_epi32 ( _mm512_cvtepi8_epi32( b0_16 ), + _mm512_cvtepi8_epi32( c0_16 ))) , mul_128) + ); + } + else if( k_partial_pieces == 2 ) + { + a0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( k_full_pieces + 0 ) ) ); + b0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( k_full_pieces + 1 ) ) ); + c0_16 = _mm_setzero_si128(); + d0_16 = _mm_setzero_si128(); + + sum1 = + _mm512_add_epi32 + ( sum1, _mm512_sllv_epi32 ( _mm512_add_epi32 ( _mm512_cvtepi8_epi32( a0_16 ), + _mm512_cvtepi8_epi32( b0_16 )) , mul_128) + ); + } + else //k_partial_pieces == 1 + { + a0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( k_full_pieces + 0 ) ) ); + b0_16 = _mm_setzero_si128(); + c0_16 = _mm_setzero_si128(); + d0_16 = _mm_setzero_si128(); + + sum1 = + _mm512_add_epi32 + ( sum1, + _mm512_sllv_epi32 ( _mm512_cvtepi8_epi32( a0_16 ) , mul_128 ) + ); + } + + a01_16 = _mm_unpacklo_epi8( a0_16, b0_16 ); + a0_16 = _mm_unpackhi_epi8( a0_16, b0_16 ); + + c01_16 = _mm_unpacklo_epi8( c0_16, d0_16 ); + c0_16 = _mm_unpackhi_epi8( c0_16, d0_16 ); + + b0_16 = _mm_unpacklo_epi16( a01_16, c01_16 ); // 0 elem + a01_16 = _mm_unpackhi_epi16( a01_16, c01_16 ); // 1 elem + d0_16 = _mm_unpacklo_epi16( a0_16, c0_16 ); // 2 elem + c01_16 = _mm_unpackhi_epi16( a0_16, c0_16 ); // 3 elem + + __m512i a0_zmm = _mm512_castsi128_si512( b0_16 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, a01_16, 0x1 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, d0_16, 0x2 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, c01_16, 0x3 ); + + // Last 4x16 elements. + _mm512_storeu_si512( pack_b_buffer_s8s8s32o32 + ( ( kr_new + 0 ) * NR ), a0_zmm ); + } + //store the sum column + _mm512_storeu_si512( pack_b_column_sum, sum1 ); } -void packb_nrlt16_s8s8s32os32 +void packb_nrlt16_s8s8s32os32_row_major ( int8_t* pack_b_buffer_s8s8s32o32, int32_t* pack_b_column_sum, @@ -914,136 +1049,867 @@ void packb_nrlt16_s8s8s32os32 const dim_t n0_partial_rem ) { - int8_t buf0[16]; - int8_t buf1[16]; - int8_t buf2[16]; - int8_t buf3[16]; - - dim_t kr_new = 0; - - dim_t k_full_pieces_blks = KC / 4; - dim_t k_full_pieces = k_full_pieces_blks * 4; - dim_t k_partial_pieces = KC % 4; - - __m128i a0_16; - __m128i b0_16; - __m128i c0_16; - __m128i d0_16; - __m128i a01_16; - __m128i c01_16; - __m512i a0_zmm; - - //to compute column sum of B matrix - __m512i sum1; - __m512i mul_128 = _mm512_set1_epi32 (7); - - //load the temp buffer to compute column sum of B matrix - sum1 = _mm512_loadu_si512( pack_b_column_sum ); - - for ( dim_t kr = 0; kr < k_full_pieces; kr += 4 ) - { - memcpy( buf0, ( b + ( ldb * ( kr + 0 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) ); - memcpy( buf1, ( b + ( ldb * ( kr + 1 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) ); - memcpy( buf2, ( b + ( ldb * ( kr + 2 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) ); - memcpy( buf3, ( b + ( ldb * ( kr + 3 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) ); - - // Rearrange for vpdpbusd, read 4 rows from B with next 16 elements in each row. - a0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf0 ); - b0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf1 ); - c0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf2 ); - d0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf3 ); - - //add all the columns : sum = add (sum, a0, b0, c0, d0) - sum1 = _mm512_add_epi32 ( sum1, _mm512_sllv_epi32 ( _mm512_add_epi32 ( _mm512_cvtepi8_epi32( a0_16 ), - _mm512_add_epi32 ( _mm512_cvtepi8_epi32( b0_16 ), _mm512_add_epi32 ( _mm512_cvtepi8_epi32( c0_16 ), - _mm512_cvtepi8_epi32( d0_16 )))) , mul_128 )); - - a01_16 = _mm_unpacklo_epi8( a0_16, b0_16 ); - a0_16 = _mm_unpackhi_epi8( a0_16, b0_16 ); - - c01_16 = _mm_unpacklo_epi8( c0_16, d0_16 ); - c0_16 = _mm_unpackhi_epi8( c0_16, d0_16 ); - - b0_16 = _mm_unpacklo_epi16( a01_16, c01_16 ); // 0 elem - a01_16 = _mm_unpackhi_epi16( a01_16, c01_16 ); // 1 elem - d0_16 = _mm_unpacklo_epi16( a0_16, c0_16 ); // 2 elem - c01_16 = _mm_unpackhi_epi16( a0_16, c0_16 ); // 3 elem - - a0_zmm = _mm512_castsi128_si512( b0_16 ); - a0_zmm = _mm512_inserti32x4( a0_zmm, a01_16, 0x1 ); - a0_zmm = _mm512_inserti32x4( a0_zmm, d0_16, 0x2 ); - a0_zmm = _mm512_inserti32x4( a0_zmm, c01_16, 0x3 ); - - // Last 4x16 elements. - _mm512_storeu_si512( pack_b_buffer_s8s8s32o32 + ( ( kr_new + 0 ) * NR ), a0_zmm ); - - // The 2nd, 3rd, and 4th 16byte chunk will be ignored, since its not part of the original data, - // but is here due to the packing in 4 16byte chunks format. - kr_new += 1; - } - // Handle k remainder. - if ( k_partial_pieces > 0 ) - { - if ( k_partial_pieces == 3 ) - { - memcpy( buf0, ( b + ( ldb * ( k_full_pieces + 0 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) ); - memcpy( buf1, ( b + ( ldb * ( k_full_pieces + 1 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) ); - memcpy( buf2, ( b + ( ldb * ( k_full_pieces + 2 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) ); - - a0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf0 ); - b0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf1 ); - c0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf2 ); - d0_16 = _mm_setzero_si128(); - - sum1 = _mm512_add_epi32 ( sum1, _mm512_sllv_epi32 (_mm512_add_epi32 ( _mm512_cvtepi8_epi32( a0_16 ), - _mm512_add_epi32 ( _mm512_cvtepi8_epi32( b0_16 ), _mm512_cvtepi8_epi32( c0_16 ))) , mul_128)); - - } - else if( k_partial_pieces == 2 ) - { - memcpy( buf0, ( b + ( ldb * ( k_full_pieces + 0 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) ); - memcpy( buf1, ( b + ( ldb * ( k_full_pieces + 1 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) ); - - a0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf0 ); - b0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf1 ); - c0_16 = _mm_setzero_si128(); - d0_16 = _mm_setzero_si128(); - - sum1 = _mm512_add_epi32 ( sum1, _mm512_sllv_epi32 ( _mm512_add_epi32 ( _mm512_cvtepi8_epi32( a0_16 ), - _mm512_cvtepi8_epi32( b0_16 )) , mul_128)); - } - else //k_partial_pieces == 1 - { - memcpy( buf0, ( b + ( ldb * ( k_full_pieces + 0 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) ); - - a0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf0 ); - b0_16 = _mm_setzero_si128(); - c0_16 = _mm_setzero_si128(); - d0_16 = _mm_setzero_si128(); - - sum1 = _mm512_add_epi32 ( sum1, _mm512_sllv_epi32 ( _mm512_cvtepi8_epi32( a0_16 ) , mul_128 )); - } - - a01_16 = _mm_unpacklo_epi8( a0_16, b0_16 ); - a0_16 = _mm_unpackhi_epi8( a0_16, b0_16 ); - - c01_16 = _mm_unpacklo_epi8( c0_16, d0_16 ); - c0_16 = _mm_unpackhi_epi8( c0_16, d0_16 ); - - b0_16 = _mm_unpacklo_epi16( a01_16, c01_16 ); // 0 elem - a01_16 = _mm_unpackhi_epi16( a01_16, c01_16 ); // 1 elem - d0_16 = _mm_unpacklo_epi16( a0_16, c0_16 ); // 2 elem - c01_16 = _mm_unpackhi_epi16( a0_16, c0_16 ); // 3 elem - - __m512i a0_zmm = _mm512_castsi128_si512( b0_16 ); - a0_zmm = _mm512_inserti32x4( a0_zmm, a01_16, 0x1 ); - a0_zmm = _mm512_inserti32x4( a0_zmm, d0_16, 0x2 ); - a0_zmm = _mm512_inserti32x4( a0_zmm, c01_16, 0x3 ); - - // Last 4x16 elements. - _mm512_storeu_si512( pack_b_buffer_s8s8s32o32 + ( ( kr_new + 0 ) * NR ), a0_zmm ); - } - //store the sum column - _mm512_storeu_si512( pack_b_column_sum, sum1 ); + dim_t NR = 64; + int8_t buf0[16]; + int8_t buf1[16]; + int8_t buf2[16]; + int8_t buf3[16]; + + dim_t kr_new = 0; + + dim_t k_full_pieces_blks = KC / 4; + dim_t k_full_pieces = k_full_pieces_blks * 4; + dim_t k_partial_pieces = KC % 4; + + __m128i a0_16; + __m128i b0_16; + __m128i c0_16; + __m128i d0_16; + __m128i a01_16; + __m128i c01_16; + __m512i a0_zmm; + + //to compute column sum of B matrix + __m512i sum1; + __m512i mul_128 = _mm512_set1_epi32 (7); + + //load the temp buffer to compute column sum of B matrix + sum1 = _mm512_loadu_si512( pack_b_column_sum ); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 4 ) + { + memcpy( buf0, ( b + ( ldb * ( kr + 0 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) ); + memcpy( buf1, ( b + ( ldb * ( kr + 1 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) ); + memcpy( buf2, ( b + ( ldb * ( kr + 2 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) ); + memcpy( buf3, ( b + ( ldb * ( kr + 3 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) ); + + // Rearrange for vpdpbusd, read 4 rows from B with next 16 elements in each row. + a0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf0 ); + b0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf1 ); + c0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf2 ); + d0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf3 ); + + //add all the columns : sum = add (sum, a0, b0, c0, d0) + sum1 = + _mm512_add_epi32 + ( sum1, + _mm512_sllv_epi32 ( _mm512_add_epi32 ( _mm512_cvtepi8_epi32( a0_16 ), + _mm512_add_epi32 ( _mm512_cvtepi8_epi32( b0_16 ), + _mm512_add_epi32 ( _mm512_cvtepi8_epi32( c0_16 ), + _mm512_cvtepi8_epi32( d0_16 )))) , mul_128 ) + ); + + a01_16 = _mm_unpacklo_epi8( a0_16, b0_16 ); + a0_16 = _mm_unpackhi_epi8( a0_16, b0_16 ); + + c01_16 = _mm_unpacklo_epi8( c0_16, d0_16 ); + c0_16 = _mm_unpackhi_epi8( c0_16, d0_16 ); + + b0_16 = _mm_unpacklo_epi16( a01_16, c01_16 ); // 0 elem + a01_16 = _mm_unpackhi_epi16( a01_16, c01_16 ); // 1 elem + d0_16 = _mm_unpacklo_epi16( a0_16, c0_16 ); // 2 elem + c01_16 = _mm_unpackhi_epi16( a0_16, c0_16 ); // 3 elem + + a0_zmm = _mm512_castsi128_si512( b0_16 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, a01_16, 0x1 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, d0_16, 0x2 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, c01_16, 0x3 ); + + // Last 4x16 elements. + _mm512_storeu_si512( pack_b_buffer_s8s8s32o32 + ( ( kr_new + 0 ) * NR ), a0_zmm ); + + // The 2nd, 3rd, and 4th 16byte chunk will be ignored, since its not part of the + // original data, but is here due to the packing in 4 16byte chunks format. + kr_new += 1; + } + + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + if ( k_partial_pieces == 3 ) + { + memcpy( buf0, ( b + ( ldb * ( k_full_pieces + 0 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) ); + memcpy( buf1, ( b + ( ldb * ( k_full_pieces + 1 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) ); + memcpy( buf2, ( b + ( ldb * ( k_full_pieces + 2 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) ); + + a0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf0 ); + b0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf1 ); + c0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf2 ); + d0_16 = _mm_setzero_si128(); + + sum1 = + _mm512_add_epi32 + ( sum1, + _mm512_sllv_epi32 (_mm512_add_epi32 ( _mm512_cvtepi8_epi32( a0_16 ), + _mm512_add_epi32 ( _mm512_cvtepi8_epi32( b0_16 ), + _mm512_cvtepi8_epi32( c0_16 ))) , mul_128) + ); + + } + else if( k_partial_pieces == 2 ) + { + memcpy( buf0, ( b + ( ldb * ( k_full_pieces + 0 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) ); + memcpy( buf1, ( b + ( ldb * ( k_full_pieces + 1 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) ); + + a0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf0 ); + b0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf1 ); + c0_16 = _mm_setzero_si128(); + d0_16 = _mm_setzero_si128(); + + sum1 = + _mm512_add_epi32 + ( sum1, _mm512_sllv_epi32 ( _mm512_add_epi32 ( _mm512_cvtepi8_epi32( a0_16 ), + _mm512_cvtepi8_epi32( b0_16 )) , mul_128) + ); + } + else //k_partial_pieces == 1 + { + memcpy( buf0, ( b + ( ldb * ( k_full_pieces + 0 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) ); + + a0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf0 ); + b0_16 = _mm_setzero_si128(); + c0_16 = _mm_setzero_si128(); + d0_16 = _mm_setzero_si128(); + + sum1 = + _mm512_add_epi32 + ( sum1, + _mm512_sllv_epi32 ( _mm512_cvtepi8_epi32( a0_16 ) , mul_128 ) + ); + } + + a01_16 = _mm_unpacklo_epi8( a0_16, b0_16 ); + a0_16 = _mm_unpackhi_epi8( a0_16, b0_16 ); + + c01_16 = _mm_unpacklo_epi8( c0_16, d0_16 ); + c0_16 = _mm_unpackhi_epi8( c0_16, d0_16 ); + + b0_16 = _mm_unpacklo_epi16( a01_16, c01_16 ); // 0 elem + a01_16 = _mm_unpackhi_epi16( a01_16, c01_16 ); // 1 elem + d0_16 = _mm_unpacklo_epi16( a0_16, c0_16 ); // 2 elem + c01_16 = _mm_unpackhi_epi16( a0_16, c0_16 ); // 3 elem + + __m512i a0_zmm = _mm512_castsi128_si512( b0_16 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, a01_16, 0x1 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, d0_16, 0x2 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, c01_16, 0x3 ); + + // Last 4x16 elements. + _mm512_storeu_si512( pack_b_buffer_s8s8s32o32 + ( ( kr_new + 0 ) * NR ), a0_zmm ); + } + //store the sum column + _mm512_storeu_si512( pack_b_column_sum, sum1 ); +} + +#define LOAD_16_COLS_AVX512 \ + a_reg[0] = _mm512_loadu_si512(b + (ldb * (jr + 0)) + kr); \ + a_reg[1] = _mm512_loadu_si512(b + (ldb * (jr + 1)) + kr); \ + a_reg[2] = _mm512_loadu_si512(b + (ldb * (jr + 2)) + kr); \ + a_reg[3] = _mm512_loadu_si512(b + (ldb * (jr + 3)) + kr); \ + a_reg[4] = _mm512_loadu_si512(b + (ldb * (jr + 4)) + kr); \ + a_reg[5] = _mm512_loadu_si512(b + (ldb * (jr + 5)) + kr); \ + a_reg[6] = _mm512_loadu_si512(b + (ldb * (jr + 6)) + kr); \ + a_reg[7] = _mm512_loadu_si512(b + (ldb * (jr + 7)) + kr); \ + a_reg[8] = _mm512_loadu_si512(b + (ldb * (jr + 8)) + kr); \ + a_reg[9] = _mm512_loadu_si512(b + (ldb * (jr + 9)) + kr); \ + a_reg[10] = _mm512_loadu_si512(b + (ldb * (jr + 10)) + kr); \ + a_reg[11] = _mm512_loadu_si512(b + (ldb * (jr + 11)) + kr); \ + a_reg[12] = _mm512_loadu_si512(b + (ldb * (jr + 12)) + kr); \ + a_reg[13] = _mm512_loadu_si512(b + (ldb * (jr + 13)) + kr); \ + a_reg[14] = _mm512_loadu_si512(b + (ldb * (jr + 14)) + kr); \ + a_reg[15] = _mm512_loadu_si512(b + (ldb * (jr + 15)) + kr); + +#define UNPACKHILO32_AVX512 \ + b_reg[0] = _mm512_unpacklo_epi32(a_reg[0], a_reg[1]); \ + b_reg[2] = _mm512_unpacklo_epi32(a_reg[2], a_reg[3]); \ + b_reg[4] = _mm512_unpacklo_epi32(a_reg[4], a_reg[5]); \ + b_reg[6] = _mm512_unpacklo_epi32(a_reg[6], a_reg[7]); \ + b_reg[8] = _mm512_unpacklo_epi32(a_reg[8], a_reg[9]); \ + b_reg[10] = _mm512_unpacklo_epi32(a_reg[10], a_reg[11]); \ + b_reg[12] = _mm512_unpacklo_epi32(a_reg[12], a_reg[13]); \ + b_reg[14] = _mm512_unpacklo_epi32(a_reg[14], a_reg[15]); \ + \ + b_reg[1] = _mm512_unpackhi_epi32(a_reg[0], a_reg[1]); \ + b_reg[3] = _mm512_unpackhi_epi32(a_reg[2], a_reg[3]); \ + b_reg[5] = _mm512_unpackhi_epi32(a_reg[4], a_reg[5]); \ + b_reg[7] = _mm512_unpackhi_epi32(a_reg[6], a_reg[7]); \ + b_reg[9] = _mm512_unpackhi_epi32(a_reg[8], a_reg[9]); \ + b_reg[11] = _mm512_unpackhi_epi32(a_reg[10], a_reg[11]); \ + b_reg[13] = _mm512_unpackhi_epi32(a_reg[12], a_reg[13]); \ + b_reg[15] = _mm512_unpackhi_epi32(a_reg[14], a_reg[15]); + +#define UNPACKHILO64_AVX512 \ + a_reg[0] = _mm512_unpacklo_epi64(b_reg[0], b_reg[2]); \ + a_reg[1] = _mm512_unpacklo_epi64(b_reg[4], b_reg[6]); \ + a_reg[2] = _mm512_unpacklo_epi64(b_reg[8], b_reg[10]); \ + a_reg[3] = _mm512_unpacklo_epi64(b_reg[12], b_reg[14]); \ + a_reg[4] = _mm512_unpacklo_epi64(b_reg[1], b_reg[3]); \ + a_reg[5] = _mm512_unpacklo_epi64(b_reg[5], b_reg[7]); \ + a_reg[6] = _mm512_unpacklo_epi64(b_reg[9], b_reg[11]); \ + a_reg[7] = _mm512_unpacklo_epi64(b_reg[13], b_reg[15]); \ + \ + a_reg[8] = _mm512_unpackhi_epi64(b_reg[0], b_reg[2]); \ + a_reg[9] = _mm512_unpackhi_epi64(b_reg[4], b_reg[6]); \ + a_reg[10] = _mm512_unpackhi_epi64(b_reg[8], b_reg[10]); \ + a_reg[11] = _mm512_unpackhi_epi64(b_reg[12], b_reg[14]); \ + a_reg[12] = _mm512_unpackhi_epi64(b_reg[1], b_reg[3]); \ + a_reg[13] = _mm512_unpackhi_epi64(b_reg[5], b_reg[7]); \ + a_reg[14] = _mm512_unpackhi_epi64(b_reg[9], b_reg[11]); \ + a_reg[15] = _mm512_unpackhi_epi64(b_reg[13], b_reg[15]); + +#define PERMUTEX2_VAR64_AVX512 \ + b_reg[0] = _mm512_permutex2var_epi64(a_reg[0], selector1, a_reg[1]); \ + b_reg[1] = _mm512_permutex2var_epi64(a_reg[2], selector1, a_reg[3]); \ + b_reg[2] = _mm512_permutex2var_epi64(a_reg[8], selector1, a_reg[9]); \ + b_reg[3] = _mm512_permutex2var_epi64(a_reg[10], selector1, a_reg[11]); \ + b_reg[4] = _mm512_permutex2var_epi64(a_reg[4], selector1, a_reg[5]); \ + b_reg[5] = _mm512_permutex2var_epi64(a_reg[6], selector1, a_reg[7]); \ + b_reg[6] = _mm512_permutex2var_epi64(a_reg[12], selector1, a_reg[13]); \ + b_reg[7] = _mm512_permutex2var_epi64(a_reg[14], selector1, a_reg[15]); \ + b_reg[8] = _mm512_permutex2var_epi64(a_reg[0], selector2, a_reg[1]); \ + b_reg[9] = _mm512_permutex2var_epi64(a_reg[2], selector2, a_reg[3]); \ + b_reg[10] = _mm512_permutex2var_epi64(a_reg[8], selector2, a_reg[9]); \ + b_reg[11] = _mm512_permutex2var_epi64(a_reg[10], selector2, a_reg[11]); \ + b_reg[12] = _mm512_permutex2var_epi64(a_reg[4], selector2, a_reg[5]); \ + b_reg[13] = _mm512_permutex2var_epi64(a_reg[6], selector2, a_reg[7]); \ + b_reg[14] = _mm512_permutex2var_epi64(a_reg[12], selector2, a_reg[13]); \ + b_reg[15] = _mm512_permutex2var_epi64(a_reg[14], selector2, a_reg[15]); + +#define SHUFFLE64x2_AVX512 \ + a_reg[0] = _mm512_shuffle_i64x2(b_reg[0], b_reg[1], 0x44); \ + a_reg[1] = _mm512_shuffle_i64x2(b_reg[2], b_reg[3], 0x44); \ + a_reg[2] = _mm512_shuffle_i64x2(b_reg[4], b_reg[5], 0x44); \ + a_reg[3] = _mm512_shuffle_i64x2(b_reg[6], b_reg[7], 0x44); \ + a_reg[4] = _mm512_shuffle_i64x2(b_reg[8], b_reg[9], 0x44); \ + a_reg[5] = _mm512_shuffle_i64x2(b_reg[10], b_reg[11], 0x44); \ + a_reg[6] = _mm512_shuffle_i64x2(b_reg[12], b_reg[13], 0x44); \ + a_reg[7] = _mm512_shuffle_i64x2(b_reg[14], b_reg[15], 0x44); \ + a_reg[8] = _mm512_shuffle_i64x2(b_reg[0], b_reg[1], 0xEE); \ + a_reg[9] = _mm512_shuffle_i64x2(b_reg[2], b_reg[3], 0xEE); \ + a_reg[10] = _mm512_shuffle_i64x2(b_reg[4], b_reg[5], 0xEE); \ + a_reg[11] = _mm512_shuffle_i64x2(b_reg[6], b_reg[7], 0xEE); \ + a_reg[12] = _mm512_shuffle_i64x2(b_reg[8], b_reg[9], 0xEE); \ + a_reg[13] = _mm512_shuffle_i64x2(b_reg[10], b_reg[11], 0xEE); \ + a_reg[14] = _mm512_shuffle_i64x2(b_reg[12], b_reg[13], 0xEE); \ + a_reg[15] = _mm512_shuffle_i64x2(b_reg[14], b_reg[15], 0xEE); + +#define MASK_LOAD_16_COLS_AVX512(mask) \ + a_reg[0] = _mm512_maskz_loadu_epi8(mask, b + (ldb * (jr + 0)) + kr); \ + a_reg[1] = _mm512_maskz_loadu_epi8(mask, b + (ldb * (jr + 1)) + kr); \ + a_reg[2] = _mm512_maskz_loadu_epi8(mask, b + (ldb * (jr + 2)) + kr); \ + a_reg[3] = _mm512_maskz_loadu_epi8(mask, b + (ldb * (jr + 3)) + kr); \ + a_reg[4] = _mm512_maskz_loadu_epi8(mask, b + (ldb * (jr + 4)) + kr); \ + a_reg[5] = _mm512_maskz_loadu_epi8(mask, b + (ldb * (jr + 5)) + kr); \ + a_reg[6] = _mm512_maskz_loadu_epi8(mask, b + (ldb * (jr + 6)) + kr); \ + a_reg[7] = _mm512_maskz_loadu_epi8(mask, b + (ldb * (jr + 7)) + kr); \ + a_reg[8] = _mm512_maskz_loadu_epi8(mask, b + (ldb * (jr + 8)) + kr); \ + a_reg[9] = _mm512_maskz_loadu_epi8(mask, b + (ldb * (jr + 9)) + kr); \ + a_reg[10] = _mm512_maskz_loadu_epi8(mask, b + (ldb * (jr + 10)) + kr); \ + a_reg[11] = _mm512_maskz_loadu_epi8(mask, b + (ldb * (jr + 11)) + kr); \ + a_reg[12] = _mm512_maskz_loadu_epi8(mask, b + (ldb * (jr + 12)) + kr); \ + a_reg[13] = _mm512_maskz_loadu_epi8(mask, b + (ldb * (jr + 13)) + kr); \ + a_reg[14] = _mm512_maskz_loadu_epi8(mask, b + (ldb * (jr + 14)) + kr); \ + a_reg[15] = _mm512_maskz_loadu_epi8(mask, b + (ldb * (jr + 15)) + kr); + +void packb_nr64_s8s8s32os32_col_major + ( + int8_t *pack_b_buffer, + int32_t *pack_b_column_sum, + const int8_t *b, + const dim_t ldb, + const dim_t NC, + const dim_t KC, + dim_t *rs_p, + dim_t *cs_p + ) +{ + dim_t NR = 64; + + dim_t n_full_pieces = NC / NR; + dim_t n_full_pieces_loop_limit = n_full_pieces * NR; + dim_t n_partial_pieces = NC % NR; + + dim_t k_partial_pieces = KC % 4; + + dim_t KC_updated = KC; + if (k_partial_pieces > 0) + { + KC_updated += (4 - k_partial_pieces); + } + + for (dim_t jc = 0; jc < n_full_pieces_loop_limit; jc += NR) + { + packb_nr_mult_16_s8s8s32o32_col_major + ( + pack_b_buffer + (jc * KC_updated), + pack_b_column_sum + jc, + b + (jc * ldb), 64, ldb, KC + ); + } + + if (n_partial_pieces > 0) + { + dim_t n0_partial_rem = n_partial_pieces % 16; + dim_t n0_partial_pack = 0; + + // Split into multiple smaller fringe kernels, so as to maximize + // vectorization after packing. Any n0 < NR(64) can be expressed + // as n0 = 48 + n` / n0 = 32 + n` / n0 = 16 + n`, where n` < 16. + dim_t n0_48 = n_partial_pieces / 48; + dim_t n0_32 = n_partial_pieces / 32; + dim_t n0_16 = n_partial_pieces / 16; + + if (n0_48 == 1) + { + packb_nr_mult_16_s8s8s32o32_col_major( + (pack_b_buffer + (n_full_pieces_loop_limit * KC_updated)), + pack_b_column_sum + n_full_pieces_loop_limit, + (b + n_full_pieces_loop_limit * ldb), 48, ldb, KC); + + n0_partial_pack = 48; + } + else if (n0_32 == 1) + { + packb_nr_mult_16_s8s8s32o32_col_major( + (pack_b_buffer + (n_full_pieces_loop_limit * KC_updated)), + pack_b_column_sum + n_full_pieces_loop_limit, + (b + n_full_pieces_loop_limit * ldb), 32, ldb, KC); + + n0_partial_pack = 32; + } + else if (n0_16 == 1) + { + packb_nr_mult_16_s8s8s32o32_col_major( + (pack_b_buffer + (n_full_pieces_loop_limit * KC_updated)), + pack_b_column_sum + n_full_pieces_loop_limit, + (b + n_full_pieces_loop_limit * ldb), 16, ldb, KC); + + n0_partial_pack = 16; + } + + if (n0_partial_rem > 0) + { + packb_nrlt16_s8s8s32o32_col_major( + (pack_b_buffer + (n_full_pieces_loop_limit * KC_updated) + + (n0_partial_pack * KC_updated)), + pack_b_column_sum + n_full_pieces_loop_limit + n0_partial_pack, + (b + (n_full_pieces_loop_limit + n0_partial_pack) * ldb), ldb, KC, + n0_partial_rem); + } + } + + *rs_p = NR * 4; + *cs_p = NR / 4; +} + +//Extract 16 8-bit elements from each 128-bit lane of 512-bit register and convert them into +//32 bit and add to reduce to 16 elements based on K size. + +#define SUM_16_COLS_AVX512_K64 \ + for (dim_t i = 0; i < 16; i++) \ + { \ + __m512i sum0, sum1; \ + sum0 = \ + _mm512_add_epi32 \ + ( \ + _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32(a_reg[i], 0)), \ + _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32(a_reg[i], 1)) \ + ); \ + sum1 = \ + _mm512_add_epi32 \ + ( \ + _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32(a_reg[i], 2)), \ + _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32(a_reg[i], 3)) \ + ); \ + sum[i + jr] = \ + _mm512_add_epi32(sum[i + jr], _mm512_add_epi32(sum0, sum1)); \ + } \ + +#define SUM_16_COLS_AVX512_K32 \ + for (dim_t i = 0; i < 16; i++) \ + { \ + sum[i + jr] = \ + _mm512_add_epi32 \ + ( sum[i + jr ], \ + _mm512_add_epi32(_mm512_cvtepi8_epi32( \ + _mm512_extracti32x4_epi32(a_reg[i], 0)), \ + _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32(a_reg[i], 1))) \ + ); \ + } \ + +#define SUM_16_COLS_AVX512_K16 \ + for (dim_t i = 0; i < 16; i++) \ + { \ + sum[i + jr] = \ + _mm512_add_epi32 \ + ( sum[i + jr], \ + _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32(a_reg[i], 0)) \ + ); \ + } \ + +void packb_nr_mult_16_s8s8s32o32_col_major +( + int8_t *pack_b_buffer, + int32_t *pack_b_column_sum, + const int8_t *b, + const dim_t NR, + const dim_t ldb, + const dim_t KC) +{ + // Used for permuting the mm512i elements for use in vpdpbusd instruction. + __m512i selector1 = _mm512_setr_epi64(0x0, 0x1, 0x8, 0x9, 0x4, 0x5, 0xC, 0xD); + __m512i selector2 = _mm512_setr_epi64(0x2, 0x3, 0xA, 0xB, 0x6, 0x7, 0xE, 0xF); + + __m512i a_reg[16]; + __m512i b_reg[16]; + + // to compute column sum of B matrix + __m512i sum[64]; + __m512i mul_128 = _mm512_set1_epi32(7); + + for (dim_t i = 0; i < 64; i++) + { + sum[i] = _mm512_setzero_si512(); + } + + dim_t kr = 0; + for (kr = 0; (kr + 63) < KC; kr += 64) + { + for (dim_t jr = 0; jr < NR; jr += 16) + { + // Rearrange for vpdpbusd, read 4 rows from B with 64 elements in each row. + LOAD_16_COLS_AVX512 + SUM_16_COLS_AVX512_K64 + UNPACKHILO32_AVX512 + UNPACKHILO64_AVX512 + PERMUTEX2_VAR64_AVX512 + SHUFFLE64x2_AVX512 + + // store to pack_b buffer + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + ((kr + 0) * NR), a_reg[0]); + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + ((kr + 4) * NR), a_reg[1]); + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + ((kr + 8) * NR), a_reg[2]); + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + ((kr + 12) * NR), a_reg[3]); + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + ((kr + 16) * NR), a_reg[4]); + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + ((kr + 20) * NR), a_reg[5]); + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + ((kr + 24) * NR), a_reg[6]); + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + ((kr + 28) * NR), a_reg[7]); + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + ((kr + 32) * NR), a_reg[8]); + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + ((kr + 36) * NR), a_reg[9]); + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + ((kr + 40) * NR), a_reg[10]); + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + ((kr + 44) * NR), a_reg[11]); + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + ((kr + 48) * NR), a_reg[12]); + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + ((kr + 52) * NR), a_reg[13]); + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + ((kr + 56) * NR), a_reg[14]); + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + ((kr + 60) * NR), a_reg[15]); + } + } + + for (; (kr + 31) < KC; kr += 32) + { + for (dim_t jr = 0; jr < NR; jr += 16) + { + // Rearrange for vpdpbusd, read 4 rows from B with 64 elements in each row. + MASK_LOAD_16_COLS_AVX512(0xFFFFFFFF) + SUM_16_COLS_AVX512_K32 + UNPACKHILO32_AVX512 + UNPACKHILO64_AVX512 + PERMUTEX2_VAR64_AVX512 + SHUFFLE64x2_AVX512 + + // store to pack_b buffer + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + ((kr + 0) * NR), a_reg[0]); + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + ((kr + 4) * NR), a_reg[1]); + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + ((kr + 8) * NR), a_reg[2]); + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + ((kr + 12) * NR), a_reg[3]); + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + ((kr + 16) * NR), a_reg[4]); + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + ((kr + 20) * NR), a_reg[5]); + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + ((kr + 24) * NR), a_reg[6]); + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + ((kr + 28) * NR), a_reg[7]); + } + } + + for (; (kr + 15) < KC; kr += 16) + { + for (dim_t jr = 0; jr < NR; jr += 16) + { + // Rearrange for vpdpbusd, read 4 rows from B with 64 elements in each row. + MASK_LOAD_16_COLS_AVX512((__mmask64)0xFFFF) + SUM_16_COLS_AVX512_K16 + UNPACKHILO32_AVX512 + UNPACKHILO64_AVX512 + PERMUTEX2_VAR64_AVX512 + SHUFFLE64x2_AVX512 + + // store to pack_b buffer + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + ((kr + 0) * NR), a_reg[0]); + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + ((kr + 4) * NR), a_reg[1]); + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + ((kr + 8) * NR), a_reg[2]); + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + ((kr + 12) * NR), a_reg[3]); + } + } + + for (; (kr + 7) < KC; kr += 8) + { + for (dim_t jr = 0; jr < NR; jr += 16) + { + // Rearrange for vpdpbusd, read 4 rows from B with 64 elements in each row. + MASK_LOAD_16_COLS_AVX512((__mmask64)0xFF) + SUM_16_COLS_AVX512_K16 + UNPACKHILO32_AVX512 + UNPACKHILO64_AVX512 + PERMUTEX2_VAR64_AVX512 + SHUFFLE64x2_AVX512 + + // store to pack_b buffer + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + ((kr + 0) * NR), a_reg[0]); + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + ((kr + 4) * NR), a_reg[1]); + } + } + + for (; (kr + 3) < KC; kr += 4) + { + for (dim_t jr = 0; jr < NR; jr += 16) + { + // Rearrange for vpdpbusd, read 4 rows from B with 64 elements in each row. + MASK_LOAD_16_COLS_AVX512((__mmask64)0x0F) + SUM_16_COLS_AVX512_K16 + UNPACKHILO32_AVX512 + UNPACKHILO64_AVX512 + PERMUTEX2_VAR64_AVX512 + SHUFFLE64x2_AVX512 + + // store to pack_b buffer + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + (kr * NR), a_reg[0]); + } + } + + for (; (kr + 2) < KC; kr += 3) + { + for (dim_t jr = 0; jr < NR; jr += 16) + { + // Rearrange for vpdpbusd, read 4 rows from B with 64 elements in each row. + MASK_LOAD_16_COLS_AVX512(0x07) + SUM_16_COLS_AVX512_K16 + UNPACKHILO32_AVX512 + UNPACKHILO64_AVX512 + PERMUTEX2_VAR64_AVX512 + SHUFFLE64x2_AVX512 + + // store to pack_b buffer + _mm512_storeu_si512((pack_b_buffer + (jr * 4) + (kr * NR)), a_reg[0]); + } + } + + for (; (kr + 1) < KC; kr += 2) + { + for (dim_t jr = 0; jr < NR; jr += 16) + { + // Rearrange for vpdpbusd, read 4 rows from B with 64 elements in each row. + MASK_LOAD_16_COLS_AVX512(0x03) + SUM_16_COLS_AVX512_K16 + UNPACKHILO32_AVX512 + UNPACKHILO64_AVX512 + PERMUTEX2_VAR64_AVX512 + SHUFFLE64x2_AVX512 + + // store to pack_b buffer + _mm512_storeu_si512((pack_b_buffer + (jr * 4) + (kr * NR)), a_reg[0]); + } + } + + for (; kr < KC; kr += 1) + { + for (dim_t jr = 0; jr < NR; jr += 16) + { + // Rearrange for vpdpbusd, read 4 rows from B with 64 elements in each row. + MASK_LOAD_16_COLS_AVX512(0x01) + SUM_16_COLS_AVX512_K16 + UNPACKHILO32_AVX512 + UNPACKHILO64_AVX512 + PERMUTEX2_VAR64_AVX512 + SHUFFLE64x2_AVX512 + + // store to pack_b buffer + _mm512_storeu_si512((pack_b_buffer + (jr * 4) + (kr * NR)), a_reg[0]); + } + } + + // sum/reduce 16 int32 values into one final sum as int. + // insert 16 columns into one 512 bit and store into pack_b_column_sum + for (dim_t jr = 0; jr < NR; jr += 16) + { + __m512i sum0, sum1; + sum0 = _mm512_set_epi32 + ( + _mm512_reduce_add_epi32(sum[jr + 15]), _mm512_reduce_add_epi32(sum[jr + 14]), + _mm512_reduce_add_epi32(sum[jr + 13]), _mm512_reduce_add_epi32(sum[jr + 12]), + _mm512_reduce_add_epi32(sum[jr + 11]), _mm512_reduce_add_epi32(sum[jr + 10]), + _mm512_reduce_add_epi32(sum[jr + 9]), _mm512_reduce_add_epi32(sum[jr + 8]), + _mm512_reduce_add_epi32(sum[jr + 7]), _mm512_reduce_add_epi32(sum[jr + 6]), + _mm512_reduce_add_epi32(sum[jr + 5]), _mm512_reduce_add_epi32(sum[jr + 4]), + _mm512_reduce_add_epi32(sum[jr + 3]), _mm512_reduce_add_epi32(sum[jr + 2]), + _mm512_reduce_add_epi32(sum[jr + 1]), _mm512_reduce_add_epi32(sum[jr + 0]) + ); + + sum0 = _mm512_sllv_epi32(sum0, mul_128); + sum1 = _mm512_loadu_si512(pack_b_column_sum + jr); + sum1 = _mm512_add_epi32(sum0, sum1); + _mm512_storeu_si512(pack_b_column_sum + jr, sum1); + } +} + +void packb_nrlt16_s8s8s32o32_col_major + ( + int8_t *pack_b_buffer, + int32_t *pack_b_column_sum, + const int8_t *b, + const dim_t ldb, + const dim_t KC, + const dim_t n0_partial_rem + ) +{ + dim_t NR = 16; + + // Used for permuting the mm512i elements for use in vpdpbusd instruction. + __m512i selector1 = _mm512_setr_epi64(0x0, 0x1, 0x8, 0x9, 0x4, 0x5, 0xC, 0xD); + __m512i selector2 = _mm512_setr_epi64(0x2, 0x3, 0xA, 0xB, 0x6, 0x7, 0xE, 0xF); + + __m512i a_reg[16]; + __m512i b_reg[16]; + + __m512i sum[16]; + __m512i mul_128 = _mm512_set1_epi32(7); + + for (dim_t i = 0; i < 16; i++) + { + sum[i] = _mm512_setzero_si512(); + } + + dim_t kr = 0, jr = 0; + for (kr = 0; (kr + 63) < KC; kr += 64) + { + for (jr = 0; jr < n0_partial_rem; jr += 1) + { + // Rearrange for vpdpbusd, read 4 rows from B with 64 elements in each row. + a_reg[jr] = _mm512_loadu_si512(b + (ldb * (jr + 0)) + kr); + } + for (; jr < NR; jr++) + { + a_reg[jr] = _mm512_setzero_si512(); + } + + jr = 0; /*Initialize jr=0 as SUM macro expects jr*/ + SUM_16_COLS_AVX512_K64 + UNPACKHILO32_AVX512 + UNPACKHILO64_AVX512 + PERMUTEX2_VAR64_AVX512 + SHUFFLE64x2_AVX512 + + _mm512_storeu_si512(pack_b_buffer + ((kr + 0) * NR), a_reg[0]); + _mm512_storeu_si512(pack_b_buffer + ((kr + 4) * NR), a_reg[1]); + _mm512_storeu_si512(pack_b_buffer + ((kr + 8) * NR), a_reg[2]); + _mm512_storeu_si512(pack_b_buffer + ((kr + 12) * NR), a_reg[3]); + _mm512_storeu_si512(pack_b_buffer + ((kr + 16) * NR), a_reg[4]); + _mm512_storeu_si512(pack_b_buffer + ((kr + 20) * NR), a_reg[5]); + _mm512_storeu_si512(pack_b_buffer + ((kr + 24) * NR), a_reg[6]); + _mm512_storeu_si512(pack_b_buffer + ((kr + 28) * NR), a_reg[7]); + _mm512_storeu_si512(pack_b_buffer + ((kr + 32) * NR), a_reg[8]); + _mm512_storeu_si512(pack_b_buffer + ((kr + 36) * NR), a_reg[9]); + _mm512_storeu_si512(pack_b_buffer + ((kr + 40) * NR), a_reg[10]); + _mm512_storeu_si512(pack_b_buffer + ((kr + 44) * NR), a_reg[11]); + _mm512_storeu_si512(pack_b_buffer + ((kr + 48) * NR), a_reg[12]); + _mm512_storeu_si512(pack_b_buffer + ((kr + 52) * NR), a_reg[13]); + _mm512_storeu_si512(pack_b_buffer + ((kr + 56) * NR), a_reg[14]); + _mm512_storeu_si512(pack_b_buffer + ((kr + 60) * NR), a_reg[15]); + } + + for (; (kr + 31) < KC; kr += 32) + { + for (jr = 0; jr < n0_partial_rem; jr += 1) + { + // Rearrange for vpdpbusd, read 4 rows from B with 64 elements in each row. + a_reg[jr] = _mm512_maskz_loadu_epi8(0xFFFFFFFF, b + (ldb * (jr + 0)) + kr); + } + + for (; jr < NR; jr++) + { + a_reg[jr] = _mm512_setzero_si512(); + } + + jr = 0; /*Initialize jr=0 as SUM macro expects jr*/ + SUM_16_COLS_AVX512_K32 + UNPACKHILO32_AVX512 + UNPACKHILO64_AVX512 + PERMUTEX2_VAR64_AVX512 + SHUFFLE64x2_AVX512 + + _mm512_storeu_si512(pack_b_buffer + ((kr + 0) * NR), a_reg[0]); + _mm512_storeu_si512(pack_b_buffer + ((kr + 4) * NR), a_reg[1]); + _mm512_storeu_si512(pack_b_buffer + ((kr + 8) * NR), a_reg[2]); + _mm512_storeu_si512(pack_b_buffer + ((kr + 12) * NR), a_reg[3]); + _mm512_storeu_si512(pack_b_buffer + ((kr + 16) * NR), a_reg[4]); + _mm512_storeu_si512(pack_b_buffer + ((kr + 20) * NR), a_reg[5]); + _mm512_storeu_si512(pack_b_buffer + ((kr + 24) * NR), a_reg[6]); + _mm512_storeu_si512(pack_b_buffer + ((kr + 28) * NR), a_reg[7]); + } + + for (; (kr + 15) < KC; kr += 16) + { + for (jr = 0; jr < n0_partial_rem; jr += 1) + { + // Rearrange for vpdpbusd, read 4 rows from B with 64 elements in each row. + a_reg[jr] = _mm512_maskz_loadu_epi8(0xFFFF, b + (ldb * (jr + 0)) + kr); + } + for (; jr < NR; jr++) + { + a_reg[jr] = _mm512_setzero_si512(); + } + jr = 0; /*Initialize jr=0 as SUM macro expects jr*/ + SUM_16_COLS_AVX512_K16 + UNPACKHILO32_AVX512 + UNPACKHILO64_AVX512 + PERMUTEX2_VAR64_AVX512 + SHUFFLE64x2_AVX512 + + _mm512_storeu_si512(pack_b_buffer + ((kr + 0) * NR), a_reg[0]); + _mm512_storeu_si512(pack_b_buffer + ((kr + 4) * NR), a_reg[1]); + _mm512_storeu_si512(pack_b_buffer + ((kr + 8) * NR), a_reg[2]); + _mm512_storeu_si512(pack_b_buffer + ((kr + 12) * NR), a_reg[3]); + } + + for (; (kr + 7) < KC; kr += 8) + { + for (jr = 0; jr < n0_partial_rem; jr += 1) + { + // Rearrange for vpdpbusd, read 4 rows from B with 64 elements in each row. + a_reg[jr] = _mm512_maskz_loadu_epi8(0xFF, b + (ldb * (jr + 0)) + kr); + } + for (; jr < NR; jr++) + { + a_reg[jr] = _mm512_setzero_si512(); + } + jr = 0; /*Initialize jr=0 as SUM macro expects jr*/ + SUM_16_COLS_AVX512_K16 + + UNPACKHILO32_AVX512 + UNPACKHILO64_AVX512 + PERMUTEX2_VAR64_AVX512 + SHUFFLE64x2_AVX512 + + _mm512_storeu_si512(pack_b_buffer + ((kr + 0) * NR), a_reg[0]); + _mm512_storeu_si512(pack_b_buffer + ((kr + 4) * NR), a_reg[1]); + } + + for (; (kr + 3) < KC; kr += 4) + { + for (jr = 0; jr < n0_partial_rem; jr += 1) + { + // Rearrange for vpdpbusd, read 4 rows from B with 64 elements in each row. + a_reg[jr] = _mm512_maskz_loadu_epi8(0x0F, b + (ldb * (jr + 0)) + kr); + } + for (; jr < NR; jr++) + { + a_reg[jr] = _mm512_setzero_si512(); + } + jr = 0; /*Initialize jr=0 as SUM macro expects jr*/ + SUM_16_COLS_AVX512_K16 + + UNPACKHILO32_AVX512 + UNPACKHILO64_AVX512 + PERMUTEX2_VAR64_AVX512 + SHUFFLE64x2_AVX512 + + _mm512_storeu_si512(pack_b_buffer + ((kr + 0) * NR), a_reg[0]); + } + + for (; (kr + 2) < KC; kr += 3) + { + for (jr = 0; jr < n0_partial_rem; jr += 1) + { + // Rearrange for vpdpbusd, read 4 rows from B with 64 elements in each row. + a_reg[jr] = _mm512_maskz_loadu_epi8(0x07, b + (ldb * (jr + 0)) + kr); + } + for (; jr < NR; jr++) + { + a_reg[jr] = _mm512_setzero_si512(); + } + jr = 0; /*Initialize jr=0 as SUM macro expects jr*/ + SUM_16_COLS_AVX512_K16 + + UNPACKHILO32_AVX512 + UNPACKHILO64_AVX512 + PERMUTEX2_VAR64_AVX512 + SHUFFLE64x2_AVX512 + + _mm512_storeu_si512(pack_b_buffer + ((kr + 0) * NR), a_reg[0]); + } + + for (; (kr + 1) < KC; kr += 2) + { + for (jr = 0; jr < n0_partial_rem; jr += 1) + { + // Rearrange for vpdpbusd, read 4 rows from B with 64 elements in each row. + a_reg[jr] = _mm512_maskz_loadu_epi16(0x03, b + (ldb * (jr + 0)) + kr); + } + + for (; jr < NR; jr++) + { + a_reg[jr] = _mm512_setzero_si512(); + } + + jr = 0; /*Initialize jr=0 as SUM macro expects jr*/ + SUM_16_COLS_AVX512_K16 + UNPACKHILO32_AVX512 + UNPACKHILO64_AVX512 + PERMUTEX2_VAR64_AVX512 + SHUFFLE64x2_AVX512 + + _mm512_storeu_si512(pack_b_buffer + (kr * NR), a_reg[0]); + } + + for (; kr < KC; kr += 1) + { + for (jr = 0; jr < n0_partial_rem; jr += 1) + { + // Rearrange for vpdpbusd, read 4 rows from B with 64 elements in each row. + a_reg[jr] = _mm512_maskz_loadu_epi8(0x01, b + (ldb * (jr + 0)) + kr); + } + for (; jr < NR; jr++) + { + a_reg[jr] = _mm512_setzero_si512(); + } + jr = 0; /*Initialize jr=0 as SUM macro expects jr*/ + SUM_16_COLS_AVX512_K16 + + UNPACKHILO32_AVX512 + UNPACKHILO64_AVX512 + PERMUTEX2_VAR64_AVX512 + SHUFFLE64x2_AVX512 + + _mm512_storeu_si512(pack_b_buffer + (kr * NR), a_reg[0]); + } + + // sum/reduce < 16 (max 15) int32 values into one final sum as int. + // insert sum of all columns into one 512 bit, multiply with 128 and + // store into pack_b_column_sum + __m512i sum0, sum1; + sum0 = _mm512_set_epi32 + ( + _mm512_reduce_add_epi32(sum[15]), _mm512_reduce_add_epi32(sum[14]), + _mm512_reduce_add_epi32(sum[13]), _mm512_reduce_add_epi32(sum[12]), + _mm512_reduce_add_epi32(sum[11]), _mm512_reduce_add_epi32(sum[10]), + _mm512_reduce_add_epi32(sum[9]), _mm512_reduce_add_epi32(sum[8]), + _mm512_reduce_add_epi32(sum[7]), _mm512_reduce_add_epi32(sum[6]), + _mm512_reduce_add_epi32(sum[5]), _mm512_reduce_add_epi32(sum[4]), + _mm512_reduce_add_epi32(sum[3]), _mm512_reduce_add_epi32(sum[2]), + _mm512_reduce_add_epi32(sum[1]), _mm512_reduce_add_epi32(sum[0]) + ); + sum0 = _mm512_sllv_epi32(sum0, mul_128); + + sum1 = _mm512_loadu_epi32(pack_b_column_sum); + sum1 = _mm512_add_epi32(sum0, sum1); + _mm512_storeu_si512(pack_b_column_sum, sum1); } + #endif diff --git a/kernels/zen4/lpgemm/s8s8s32/lpgemv_m_kernel_amd512vnni.c b/kernels/zen4/lpgemm/s8s8s32/lpgemv_m_kernel_amd512vnni.c new file mode 100644 index 0000000000..7d56b0c9bd --- /dev/null +++ b/kernels/zen4/lpgemm/s8s8s32/lpgemv_m_kernel_amd512vnni.c @@ -0,0 +1,571 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "blis.h" + +#ifdef BLIS_ADDON_LPGEMM + +#include "../u8s8s32/lpgemm_s32_kern_macros.h" +#include "../u8s8s32/lpgemm_s32_memcpy_macros.h" + +LPGEMV_M_EQ1_KERN(int8_t,int8_t,int32_t,s8s8s32os32) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_6x64_DISABLE, + &&POST_OPS_BIAS_6x64, + &&POST_OPS_RELU_6x64, + &&POST_OPS_RELU_SCALE_6x64, + &&POST_OPS_GELU_TANH_6x64, + &&POST_OPS_GELU_ERF_6x64, + &&POST_OPS_CLIP_6x64, + &&POST_OPS_DOWNSCALE_6x64, + &&POST_OPS_MATRIX_ADD_6x64, + &&POST_OPS_SWISH_6x64 + }; + + const int8_t *a_use = NULL; + const int8_t *b_use = NULL; + + lpgemm_post_op_attr post_ops_attr = *(post_op_attr); + + for( dim_t jr = 0; jr < n0; jr += NR ) + { + NR = bli_min( 64, ( ( n0 - jr ) / 16 ) * 16 ); + + if( NR == 0 ) NR = 16; + + rs_b = NR * 4; + dim_t nr0 = bli_min( n0 - jr, NR ); + + int32_t* c_use = c + jr * cs_c; + + __mmask16 k1 = 0xFFFF, k2 = 0xFFFF, k3 = 0xFFFF, k4 = 0xFFFF; + __mmask32 k5 = 0xFFFFFFFF, k6 = 0xFFFFFFFF; + __mmask32 k7 = 0xFFFFFFFF, k8 = 0xFFFFFFFF; + + + if( nr0 == 64 ) + { + + } + if( nr0 == 48 ) + { + k4 = k8 = 0x0; + } + else if( nr0 == 32 ) + { + k3 = k4 = k7 = k8 = 0x0; + } + else if( nr0 == 16 ) + { + k2 = k3 = k4 = k6 = k7 = k8 = 0; + } + else if( nr0 < 16 ) + { + k1 = (0xFFFF >> (16 - (nr0 & 0x0F))); + k2 = k3 = k4 = k6 = k7 = k8 = 0; + } + + + __m512i zmm0, zmm1, zmm2, zmm3, zmm4, zmm5, zmm6, zmm7; + __m512i zmm8, zmm9, zmm10, zmm11, zmm12, zmm13, zmm14; + __m512i zmm15, zmm16, zmm17, zmm18, zmm19, zmm20, zmm21; + __m512i zmm22, zmm23, zmm24, zmm25, zmm26, zmm27, zmm28; + __m512i zmm29, zmm30, zmm31; + + // zero the accumulator registers + ZERO_ACC_ZMM_4_REG(zmm8, zmm9, zmm10, zmm11); + ZERO_ACC_ZMM_4_REG(zmm12, zmm13, zmm14, zmm15); + ZERO_ACC_ZMM_4_REG(zmm16, zmm17, zmm18, zmm19); + ZERO_ACC_ZMM_4_REG(zmm20, zmm21, zmm22, zmm23); + + for (dim_t pc = 0; pc < k; pc += KC) + { + dim_t kc0 = bli_min((k - pc), KC); + + dim_t k_full_pieces = kc0 / 4; + dim_t k_partial_pieces = kc0 % 4; + + dim_t k_iter = kc0 / 16; + dim_t k_rem = k_full_pieces % 4; + + dim_t kc0_updated = kc0; + + if ( k_partial_pieces > 0 ) + { + kc0_updated += ( 4 - k_partial_pieces ); + } + + b_use = b + (n_sub_updated * pc) + + ( ( jc_cur_loop_rem + jr ) * kc0_updated ); + + a_use = a + pc; + + uint8_t cvt_uint8 = 128; + __m512i vec_uint8 = _mm512_set1_epi8 (cvt_uint8); + + for( dim_t kr = 0; kr < k_iter; kr++ ) + { + // load first 4x64 tile from row 0-3 + zmm0 = _mm512_maskz_loadu_epi16( k5, b_use ); + zmm1 = _mm512_maskz_loadu_epi16( k5, b_use + rs_b ); + zmm2 = _mm512_maskz_loadu_epi16( k5, b_use + 2 * rs_b ); + zmm3 = _mm512_maskz_loadu_epi16( k5, b_use + 3 * rs_b ); + b_use += 64; + + // Broadcast col0-col3 elements of A + zmm4 = _mm512_set1_epi32( *( int32_t* )( a_use ) ); + zmm5 = _mm512_set1_epi32( *( int32_t* )( a_use + cs_a ) ); + zmm6 = _mm512_set1_epi32( *( int32_t* )( a_use + cs_a * 2 ) ); + zmm7 = _mm512_set1_epi32( *( int32_t* )( a_use + cs_a * 3 ) ); + + zmm4 = _mm512_add_epi8( zmm4, vec_uint8 ); + zmm5 = _mm512_add_epi8( zmm5, vec_uint8 ); + zmm6 = _mm512_add_epi8( zmm6, vec_uint8 ); + zmm7 = _mm512_add_epi8( zmm7, vec_uint8 ); + + // Load second 4x64 tile from row 0-3 + zmm24 = _mm512_maskz_loadu_epi16( k6, b_use ); + zmm25 = _mm512_maskz_loadu_epi16( k6, b_use + rs_b ); + zmm26 = _mm512_maskz_loadu_epi16( k6, b_use + 2 * rs_b ); + zmm27 = _mm512_maskz_loadu_epi16( k6, b_use + 3 * rs_b ); + b_use += 64; + + zmm8 = _mm512_dpbusd_epi32( zmm8, zmm4, zmm0 ); + zmm9 = _mm512_dpbusd_epi32( zmm9, zmm5, zmm1 ); + zmm10 = _mm512_dpbusd_epi32( zmm10, zmm6, zmm2 ); + zmm11 = _mm512_dpbusd_epi32( zmm11, zmm7, zmm3 ); + + // load third 4x64 tile from row 0-3 + zmm0 = _mm512_maskz_loadu_epi16( k7, b_use ); + zmm1 = _mm512_maskz_loadu_epi16( k7, b_use + rs_b ); + zmm2 = _mm512_maskz_loadu_epi16( k7, b_use + 2 * rs_b ); + zmm3 = _mm512_maskz_loadu_epi16( k7, b_use + 3 * rs_b ); + b_use += 64; + + zmm12 = _mm512_dpbusd_epi32( zmm12, zmm4, zmm24 ); + zmm13 = _mm512_dpbusd_epi32( zmm13, zmm5, zmm25 ); + zmm14 = _mm512_dpbusd_epi32( zmm14, zmm6, zmm26 ); + zmm15 = _mm512_dpbusd_epi32( zmm15, zmm7, zmm27 ); + + // load third 4x64 tile from row 0-3 + zmm28 = _mm512_maskz_loadu_epi16( k8, b_use ); + zmm29 = _mm512_maskz_loadu_epi16( k8, b_use + rs_b ); + zmm30 = _mm512_maskz_loadu_epi16( k8, b_use + 2 * rs_b ); + zmm31 = _mm512_maskz_loadu_epi16( k8, b_use + 3 * rs_b ); + + zmm16 = _mm512_dpbusd_epi32( zmm16, zmm4, zmm0 ); + zmm17 = _mm512_dpbusd_epi32( zmm17, zmm5, zmm1 ); + zmm18 = _mm512_dpbusd_epi32( zmm18, zmm6, zmm2 ); + zmm19 = _mm512_dpbusd_epi32( zmm19, zmm7, zmm3 ); + + zmm20 = _mm512_dpbusd_epi32( zmm20, zmm4, zmm28 ); + zmm21 = _mm512_dpbusd_epi32( zmm21, zmm5, zmm29 ); + zmm22 = _mm512_dpbusd_epi32( zmm22, zmm6, zmm30 ); + zmm23 = _mm512_dpbusd_epi32( zmm23, zmm7, zmm31 ); + + b_use -= 192; // move b point back to start of KCXNR + b_use += ( 4 * rs_b ); + a_use += 4 * cs_a; // move a pointer to next col + } + for( dim_t kr = 0; kr < k_rem; kr++ ) + { + // load first 4x64 tile from row 0-3 + zmm0 = _mm512_maskz_loadu_epi16( k5, b_use ); + zmm1 = _mm512_maskz_loadu_epi16( k6, b_use + cs_b ); + zmm2 = _mm512_maskz_loadu_epi16( k7, b_use + 2 * cs_b ); + zmm3 = _mm512_maskz_loadu_epi16( k8, b_use + 3 * cs_b ); + + // Broadcast col0 elements of A + zmm4 = _mm512_set1_epi32( *( int32_t* )( a_use ) ); + zmm4 = _mm512_add_epi8( zmm4, vec_uint8 ); + + zmm8 = _mm512_dpbusd_epi32( zmm8, zmm4, zmm0 ); + zmm12 = _mm512_dpbusd_epi32( zmm12, zmm4, zmm1 ); + zmm16 = _mm512_dpbusd_epi32( zmm16, zmm4, zmm2 ); + zmm20 = _mm512_dpbusd_epi32( zmm20, zmm4, zmm3 ); + + b_use += rs_b; + a_use += cs_a; // move a pointer to next col + } + if( k_partial_pieces > 0 ) + { + __m128i a_kfringe_buf; + __mmask16 load_mask = + _cvtu32_mask16( 0xFFFF >> ( 16 - k_partial_pieces ) ); + + zmm0 = _mm512_maskz_loadu_epi16( k5, b_use ); + + // Broadcast a[0,kr:kr+4]. + a_kfringe_buf = _mm_maskz_loadu_epi8( load_mask, a_use ); + + zmm4 = _mm512_broadcastd_epi32( a_kfringe_buf ); + zmm4 = _mm512_add_epi8( zmm4, vec_uint8 ); + + zmm1 = _mm512_maskz_loadu_epi16( k6, b_use + cs_b ); + zmm2 = _mm512_maskz_loadu_epi16( k7, b_use + 2 * cs_b ); + zmm3 = _mm512_maskz_loadu_epi16( k8, b_use + 3 * cs_b ); + + zmm8 = _mm512_dpbusd_epi32( zmm8, zmm4, zmm0 ); + zmm12 = _mm512_dpbusd_epi32( zmm12, zmm4, zmm1 ); + zmm16 = _mm512_dpbusd_epi32( zmm16, zmm4, zmm2 ); + zmm20 = _mm512_dpbusd_epi32( zmm20, zmm4, zmm3 ); + + } + + } + + // Sumup k-unroll outputs + zmm8 = _mm512_add_epi32( zmm9, zmm8 ); + zmm10 = _mm512_add_epi32(zmm11, zmm10); + zmm8 = _mm512_add_epi32(zmm10, zmm8); // 64 outputs + + zmm12 = _mm512_add_epi32(zmm13, zmm12); + zmm14 = _mm512_add_epi32(zmm15, zmm14); + zmm12 = _mm512_add_epi32(zmm14, zmm12); // 64 outputs + + zmm16 = _mm512_add_epi32(zmm17, zmm16); + zmm18 = _mm512_add_epi32(zmm19, zmm18); + zmm16 = _mm512_add_epi32(zmm18, zmm16); // 64 outputs + + zmm20 = _mm512_add_epi32(zmm21, zmm20); + zmm22 = _mm512_add_epi32(zmm23, zmm22); + zmm20 = _mm512_add_epi32(zmm22, zmm20); // 64 outputs + + int32_t* bsumptr = post_ops_attr.b_col_sum_vec + + post_ops_attr.b_sum_offset; + + zmm0 = _mm512_maskz_loadu_epi32( k1, bsumptr ); + zmm1 = _mm512_maskz_loadu_epi32( k2, bsumptr + 16 ); + zmm2 = _mm512_maskz_loadu_epi32( k3, bsumptr + 32 ); + zmm3 = _mm512_maskz_loadu_epi32( k4, bsumptr + 48 ); + + zmm8 = _mm512_sub_epi32( zmm8, zmm0 ); + zmm12 = _mm512_sub_epi32( zmm12, zmm1 ); + zmm16 = _mm512_sub_epi32( zmm16, zmm2 ); + zmm20 = _mm512_sub_epi32( zmm20, zmm3 ); + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + __m512i selector3 = _mm512_setzero_epi32(); + __m512i selector4 = _mm512_setzero_epi32(); + + //Mulitply A*B output with alpha + zmm8 = _mm512_mullo_epi32(selector1, zmm8); + zmm12 = _mm512_mullo_epi32(selector1, zmm12); + zmm16 = _mm512_mullo_epi32(selector1, zmm16); + zmm20 = _mm512_mullo_epi32(selector1, zmm20); + + if (beta != 0) + { + // For the downscaled api (C-s8), the output C matrix values + // needs to be upscaled to s32 to be used for beta scale. + if ( post_ops_attr.buf_downscale != NULL ) + { + S8_S32_BETA_OP_NLT16F_MASK( k1, zmm8, 0, 0, + selector1, selector2 ) + S8_S32_BETA_OP_NLT16F_MASK( k2, zmm12, 0, 1, + selector1, selector2 ) + S8_S32_BETA_OP_NLT16F_MASK( k3, zmm16, 0, 2, + selector1, selector2 ) + S8_S32_BETA_OP_NLT16F_MASK( k4, zmm20, 0, 3, + selector1, selector2 ) + } + else + { + S32_S32_BETA_OP_NLT16F_MASK( c_use, k1, zmm8, 0, 0, 0, + selector1, selector2 ) + S32_S32_BETA_OP_NLT16F_MASK( c_use, k2, zmm12, 0, 0, 1, + selector1, selector2 ) + S32_S32_BETA_OP_NLT16F_MASK( c_use, k3, zmm16, 0, 0, 2, + selector1, selector2 ) + S32_S32_BETA_OP_NLT16F_MASK( c_use, k4, zmm20, 0, 0, 3, + selector1, selector2 ) + } + } + + post_ops_attr.is_last_k = TRUE; + lpgemm_post_op *post_ops_list_temp = post_op; + POST_OP_LABEL_LASTK_SAFE_JUMP + + POST_OPS_BIAS_6x64: + { + selector1 = + _mm512_loadu_si512( ( int32_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_si512( ( int32_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_si512( ( int32_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + selector4 = + _mm512_loadu_si512( ( int32_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + + zmm8 = _mm512_add_epi32( selector1, zmm8 ); + zmm12 = _mm512_add_epi32( selector2, zmm12 ); + zmm16 = _mm512_add_epi32( selector3, zmm16 ); + zmm20 = _mm512_add_epi32( selector4, zmm20 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_RELU_6x64: + { + selector1 = _mm512_setzero_epi32(); + + zmm8 = _mm512_max_epi32( selector1, zmm8 ); + zmm12 = _mm512_max_epi32( selector1, zmm12 ); + zmm16 = _mm512_max_epi32( selector1, zmm16 ); + zmm20 = _mm512_max_epi32( selector1, zmm20 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_RELU_SCALE_6x64: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( (int32_t*)post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + RELU_SCALE_OP_S32_AVX512( zmm8 ) + RELU_SCALE_OP_S32_AVX512( zmm12 ) + RELU_SCALE_OP_S32_AVX512( zmm16 ) + RELU_SCALE_OP_S32_AVX512( zmm20 ) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_GELU_TANH_6x64: + { + __m512 dn, z, x, r2, r, y, x_tanh; + + GELU_TANH_S32_AVX512( zmm8, y, r, r2, x, + z, dn, x_tanh, selector1 ) + GELU_TANH_S32_AVX512( zmm12, y, r, r2, x, + z, dn, x_tanh, selector1 ) + GELU_TANH_S32_AVX512( zmm16, y, r, r2, x, + z, dn, x_tanh, selector1 ) + GELU_TANH_S32_AVX512( zmm20, y, r, r2, x, + z, dn, x_tanh, selector1 ) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_GELU_ERF_6x64: + { + __m512 x, r, y, x_erf; + + GELU_ERF_S32_AVX512( zmm8, y, r, x, x_erf ) + GELU_ERF_S32_AVX512( zmm12, y, r, x, x_erf ) + GELU_ERF_S32_AVX512( zmm16, y, r, x, x_erf ) + GELU_ERF_S32_AVX512( zmm20, y, r, x, x_erf ) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + + } + POST_OPS_CLIP_6x64: + { + __m512i min = _mm512_set1_epi32( + *( int32_t* )post_ops_list_temp->op_args2 ); + __m512i max = _mm512_set1_epi32( + *( int32_t* )post_ops_list_temp->op_args3 ); + + CLIP_S32_AVX512( zmm8, min, max ) + CLIP_S32_AVX512( zmm12, min, max ) + CLIP_S32_AVX512( zmm16, min, max ) + CLIP_S32_AVX512( zmm20, min, max ) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_DOWNSCALE_6x64: + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_si512( (float*)post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_si512( (float*)post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_si512( (float*)post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + selector4 = + _mm512_loadu_si512( (float*)post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = ( __m512i )_mm512_set1_ps( + *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = ( __m512i )_mm512_set1_ps( + *( ( float* )post_ops_list_temp->scale_factor ) ); + selector3 = ( __m512i )_mm512_set1_ps( + *( ( float* )post_ops_list_temp->scale_factor ) ); + selector4 = ( __m512i )_mm512_set1_ps( + *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point0 = _mm512_castsi512_si128( + _mm512_setzero_si512() ); + __m128i zero_point1 = _mm512_castsi512_si128( + _mm512_setzero_si512() ); + __m128i zero_point2 = _mm512_castsi512_si128( + _mm512_setzero_si512() ); + __m128i zero_point3 = _mm512_castsi512_si128( + _mm512_setzero_si512() ); + + // int8_t zero point value. + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + zero_point2 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + zero_point3 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point1 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point2 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point3 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } + + CVT_MULRND_CVT32(zmm8, selector1, zero_point0 ); + CVT_MULRND_CVT32(zmm12, selector2, zero_point1 ); + CVT_MULRND_CVT32(zmm16, selector3, zero_point2 ); + CVT_MULRND_CVT32(zmm20, selector4, zero_point3 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_MATRIX_ADD_6x64: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + S8_S32_MATRIX_ADD_LOAD( _cvtu32_mask16( 0xFFFF ), + selector1, 0, 0 ); + S8_S32_MATRIX_ADD_LOAD( _cvtu32_mask16( 0xFFFF ), + selector2, 0, 1 ); + S8_S32_MATRIX_ADD_LOAD( _cvtu32_mask16( 0xFFFF ), + selector3, 0, 2 ); + S8_S32_MATRIX_ADD_LOAD( _cvtu32_mask16( 0xFFFF ), + selector4, 0, 3 ); + + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + S32_S32_MATRIX_ADD_LOAD( _cvtu32_mask16( 0xFFFF ), + selector1, 0, 0 ); + S32_S32_MATRIX_ADD_LOAD( _cvtu32_mask16( 0xFFFF ), + selector2, 0, 1 ); + S32_S32_MATRIX_ADD_LOAD( _cvtu32_mask16( 0xFFFF ), + selector3, 0, 2 ); + S32_S32_MATRIX_ADD_LOAD( _cvtu32_mask16( 0xFFFF ), + selector4, 0, 3 ); + } + + zmm8 = _mm512_add_epi32( selector1, zmm8 ); + zmm12 = _mm512_add_epi32( selector2, zmm12 ); + zmm16 = _mm512_add_epi32( selector3, zmm16 ); + zmm20 = _mm512_add_epi32( selector4, zmm20 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + + POST_OPS_SWISH_6x64: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + SWISH_S32_AVX512( zmm8, fl_reg, al, al_in, + r, r2, z, dn, selector2 ); + SWISH_S32_AVX512( zmm12, fl_reg, al, al_in, + r, r2, z, dn, selector2 ); + SWISH_S32_AVX512( zmm16, fl_reg, al, al_in, + r, r2, z, dn, selector2 ); + SWISH_S32_AVX512( zmm20, fl_reg, al, al_in, + r, r2, z, dn, selector2 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_6x64_DISABLE: + { + if ( post_ops_attr.buf_downscale != NULL ) + { + CVT_STORE_S32_S8_MASK( zmm8, k1, 0, 0 ); + CVT_STORE_S32_S8_MASK( zmm12, k2, 0, 1 ); + CVT_STORE_S32_S8_MASK( zmm16, k3, 0, 2 ); + CVT_STORE_S32_S8_MASK( zmm20, k4, 0, 3 ); + } + else + { + _mm512_mask_storeu_epi32( c_use + ( 0*16 ), k1, zmm8 ); + _mm512_mask_storeu_epi32( c_use + ( 1*16 ), k2, zmm12 ); + _mm512_mask_storeu_epi32( c_use + ( 2*16 ), k3, zmm16 ); + _mm512_mask_storeu_epi32( c_use + ( 3*16 ), k4, zmm20 ); + } + } + + post_ops_attr.post_op_c_j += nr0; + post_ops_attr.b_sum_offset += nr0; + + } // jr loop + +} +#endif // BLIS_ADDON_LPGEMM diff --git a/kernels/zen4/lpgemm/s8s8s32/lpgemv_n_kernel_amd512vnni.c b/kernels/zen4/lpgemm/s8s8s32/lpgemv_n_kernel_amd512vnni.c new file mode 100644 index 0000000000..88921a8a03 --- /dev/null +++ b/kernels/zen4/lpgemm/s8s8s32/lpgemv_n_kernel_amd512vnni.c @@ -0,0 +1,760 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "immintrin.h" +#include "xmmintrin.h" +#include "blis.h" + +#ifdef BLIS_ADDON_LPGEMM + +#include "../u8s8s32/lpgemm_s32_kern_macros.h" +#include "../u8s8s32/lpgemm_s32_memcpy_macros.h" + +#define LPGEMV_N_KERNEL_4_LOADS( zmm0, zmm1, zmm2, zmm3, paddr, stride ) \ + zmm0 = _mm512_loadu_si512( paddr ); \ + zmm1 = _mm512_loadu_si512( paddr + stride ); \ + zmm2 = _mm512_loadu_si512( paddr + 2 * stride ); \ + zmm3 = _mm512_loadu_si512( paddr + 3 * stride ); \ + zmm0 = _mm512_add_epi8( zmm0, vec_uint8 ); \ + zmm1 = _mm512_add_epi8( zmm1, vec_uint8 ); \ + zmm2 = _mm512_add_epi8( zmm2, vec_uint8 ); \ + zmm3 = _mm512_add_epi8( zmm3, vec_uint8 ); + + +#define LPGEMV_N_KERNEL_4_MASKLOADS( zmm0, zmm1, zmm2, \ + zmm3, k1, paddr, stride ) \ + zmm0 = _mm512_maskz_loadu_epi8( k1, paddr ); \ + zmm1 = _mm512_maskz_loadu_epi8( k1, paddr + stride ); \ + zmm2 = _mm512_maskz_loadu_epi8( k1, paddr + 2 * stride ); \ + zmm3 = _mm512_maskz_loadu_epi8( k1, paddr + 3 * stride ); \ + zmm0 = _mm512_maskz_add_epi8( k1, zmm0, vec_uint8 ); \ + zmm1 = _mm512_maskz_add_epi8( k1, zmm1, vec_uint8 ); \ + zmm2 = _mm512_maskz_add_epi8( k1, zmm2, vec_uint8 ); \ + zmm3 = _mm512_maskz_add_epi8( k1, zmm3, vec_uint8 ); \ + +#define LPGEMV_N_KERNEL_4_FMA( zmm8, zmm9, zmm10, zmm11, \ + zmm6, zmm0, zmm1, zmm2, zmm3 ) \ + zmm8 = _mm512_dpbusd_epi32( zmm8, zmm0, zmm6 ); \ + zmm9 = _mm512_dpbusd_epi32( zmm9, zmm1, zmm6 ); \ + zmm10 = _mm512_dpbusd_epi32( zmm10, zmm2, zmm6 ); \ + zmm11 = _mm512_dpbusd_epi32( zmm11, zmm3, zmm6 ); + +#define LPGEMV_ZMM2XMM( zmm0, zmm1, zmm2, zmm3, \ + ymm0, ymm1, ymm2, ymm3, xmm0) \ + ymm0 = _mm256_add_epi32 (_mm512_extracti32x8_epi32 (zmm0, 0x0), \ + _mm512_extracti32x8_epi32 (zmm0, 0x1)); \ + ymm1 = _mm256_add_epi32 (_mm512_extracti32x8_epi32 (zmm1, 0x0), \ + _mm512_extracti32x8_epi32 (zmm1, 0x1)); \ + ymm0 = _mm256_hadd_epi32 (ymm0, ymm1); \ + ymm2 = _mm256_add_epi32 (_mm512_extracti32x8_epi32 (zmm2, 0x0), \ + _mm512_extracti32x8_epi32 (zmm2, 0x1)); \ + ymm3 = _mm256_add_epi32 (_mm512_extracti32x8_epi32 (zmm3, 0x0), \ + _mm512_extracti32x8_epi32 (zmm3, 0x1)); \ + ymm1 = _mm256_hadd_epi32 (ymm2, ymm3); \ + ymm0 = _mm256_hadd_epi32 (ymm0, ymm1); \ + xmm0 = _mm_add_epi32 ( _mm256_extracti128_si256 (ymm0, 0), \ + _mm256_extracti128_si256 (ymm0,1)); + +#define CVT_STORE_S32_S8_MASK(reg,mask,m_ind,n_ind) \ + _mm512_mask_cvtsepi32_storeu_epi8 \ + ( \ + ( int8_t* )post_ops_attr.buf_downscale + \ + ( post_ops_attr.rs_c_downscale * \ + ( post_ops_attr.post_op_c_i + m_ind ) ) + \ + post_ops_attr.post_op_c_j + ( n_ind * 16 ), \ + mask, reg \ + ); \ + +LPGEMV_N_EQ1_KERN(int8_t, int8_t, int32_t, s8s8s32os32) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_6x64_DISABLE, + &&POST_OPS_BIAS_6x64, + &&POST_OPS_RELU_6x64, + &&POST_OPS_RELU_SCALE_6x64, + &&POST_OPS_GELU_TANH_6x64, + &&POST_OPS_GELU_ERF_6x64, + &&POST_OPS_CLIP_6x64, + &&POST_OPS_DOWNSCALE_6x64, + &&POST_OPS_MATRIX_ADD_6x64, + &&POST_OPS_SWISH_6x64 + }; + + const int8_t *a_use = NULL; + const int8_t *b_use = NULL; + int32_t *c_use = NULL; + + lpgemm_post_op_attr post_ops_attr = *(post_op_attr); + + uint8_t cvt_uint8 = 128; + __m512i vec_uint8 = _mm512_set1_epi8 (cvt_uint8); + + int32_t* bsumptr = post_ops_attr.b_col_sum_vec; + + for ( dim_t ir = 0; ir < m0; ir += MR ) + { + dim_t mr0 = bli_min( ( m0 - ir ), MR ); + dim_t k_iter = k/64; + dim_t k_rem = k & 0x3F; + + //Create load mask for k fringe + __mmask64 k1 = 0xFFFFFFFFFFFFFFFF; + if( k_rem ) + { + k1 = ( k1 >> ( 64 - k_rem ) ); + } + + // Create store mask for C for mr fringe + __mmask16 k2 = 0xFFFF; + if ( mr0 < MR ) + { + k2 = ( 0xFFFF >> ( MR - mr0 ) ); + } + + __m512i zmm0, zmm1, zmm2, zmm3, zmm6; + __m512i zmm8, zmm9, zmm10, zmm11, zmm12, zmm13, zmm14; + __m512i zmm15, zmm16, zmm17, zmm18, zmm19, zmm20, zmm21; + __m512i zmm22, zmm23, zmm24, zmm25, zmm26, zmm27, zmm28; + __m512i zmm29, zmm30, zmm31; + + __m256i ymm0,ymm1,ymm2,ymm3,ymm4,ymm5,ymm6; + __m128i xmm0, xmm1, xmm2, xmm3; + + /* zero the accumulator registers */ + ZERO_ACC_ZMM_4_REG( zmm8, zmm9, zmm10, zmm11 ) + ZERO_ACC_ZMM_4_REG( zmm12, zmm13, zmm14, zmm15 ) + ZERO_ACC_ZMM_4_REG( zmm16, zmm17, zmm18, zmm19 ) + ZERO_ACC_ZMM_4_REG( zmm20, zmm21, zmm22, zmm23 ) + ZERO_ACC_XMM_4_REG( xmm0, xmm1, xmm2, xmm3 ) + + //update pointers + a_use = a + ir * rs_a; + b_use = b; + c_use = c + ir * rs_c; + + if( mr0 == MR ) + { + //Dot product kernel + for (dim_t k = 0; k < k_iter; k++) + { + zmm6 = _mm512_loadu_si512( b_use ); + b_use += 64; + + //Load 4x64 elements from row0-row3 of A + LPGEMV_N_KERNEL_4_LOADS( zmm0, zmm1, zmm2, zmm3, a_use, rs_a ) + + a_use += ( 4 * rs_a ); + + // Load 4x64 elements from row3-row7 of A + LPGEMV_N_KERNEL_4_LOADS( zmm24, zmm25, zmm26, + zmm27, a_use, rs_a + ) + a_use += ( 4 * rs_a ); + + LPGEMV_N_KERNEL_4_FMA( zmm8, zmm9, zmm10, zmm11, + zmm6, zmm0, zmm1, zmm2, zmm3 + ) + + // Load 4x64 elements from row8-row11 of A + LPGEMV_N_KERNEL_4_LOADS( zmm28, zmm29, zmm30, + zmm31, a_use, rs_a + ) + a_use += ( 4 * rs_a ); + + // Load 4x64 elements from row12-row15 of A + LPGEMV_N_KERNEL_4_LOADS( zmm0, zmm1, zmm2, zmm3, a_use, rs_a ) + a_use -= ( 12 * rs_a ); //Update aptr back to move horizontally + + LPGEMV_N_KERNEL_4_FMA( zmm12, zmm13, zmm14, zmm15, + zmm6, zmm24, zmm25, zmm26, zmm27 + ) + LPGEMV_N_KERNEL_4_FMA( zmm16, zmm17, zmm18, zmm19, + zmm6, zmm28, zmm29, zmm30, zmm31 + ) + LPGEMV_N_KERNEL_4_FMA( zmm20, zmm21, zmm22, zmm23, + zmm6, zmm0, zmm1, zmm2, zmm3 + ) + a_use += 64; + + } // kloop + if( k_rem ) + { + zmm6 = _mm512_maskz_loadu_epi8( k1, b_use ); + + //Load 4x64 elements from row0-row3 of A + LPGEMV_N_KERNEL_4_MASKLOADS( zmm0, zmm1, zmm2, + zmm3, k1, a_use, rs_a + ) + a_use += ( 4 * rs_a ); + + // Load 4x64 elements from row3-row7 of A + LPGEMV_N_KERNEL_4_MASKLOADS( zmm24, zmm25, zmm26, + zmm27, k1, a_use, rs_a + ) + a_use += ( 4 * rs_a ); + + LPGEMV_N_KERNEL_4_FMA( zmm8, zmm9, zmm10, zmm11, + zmm6, zmm0, zmm1, zmm2, zmm3 + ) + + // Load 4x64 elements from row8-row11 of A + LPGEMV_N_KERNEL_4_MASKLOADS( zmm28, zmm29, zmm30, + zmm31, k1, a_use, rs_a + ) + a_use += ( 4 * rs_a ); + + // Load 4x64 elements from row12-row15 of A + LPGEMV_N_KERNEL_4_MASKLOADS( zmm0, zmm1, zmm2, + zmm3, k1, a_use, rs_a + ) + a_use -= ( 12 * rs_a ); //Update aptr back to move horizontally + + + LPGEMV_N_KERNEL_4_FMA( zmm12, zmm13, zmm14, zmm15, + zmm6, zmm24, zmm25, zmm26, zmm27 + ) + LPGEMV_N_KERNEL_4_FMA( zmm16, zmm17, zmm18, zmm19, + zmm6, zmm28, zmm29, zmm30, zmm31 + ) + LPGEMV_N_KERNEL_4_FMA( zmm20, zmm21, zmm22, zmm23, + zmm6, zmm0, zmm1, zmm2, zmm3 + ) + a_use += 64; + } + + //Add the registers horizantally to get one + LPGEMV_ZMM2XMM( zmm8, zmm9, zmm10, zmm11, + ymm0, ymm1, ymm2, ymm3, xmm0 + ) + LPGEMV_ZMM2XMM( zmm12, zmm13, zmm14, zmm15, + ymm4, ymm1, ymm2, ymm3, xmm1 + ) + LPGEMV_ZMM2XMM( zmm16, zmm17, zmm18, zmm19, + ymm5, ymm1, ymm2, ymm3, xmm2 + ) + LPGEMV_ZMM2XMM( zmm20, zmm21, zmm22, zmm23, + ymm6, ymm1, ymm2, ymm3, xmm3 + ) + + //compose outputs into one zmm to perform post-ops + zmm8 = _mm512_inserti32x4 ( zmm8, xmm0, 0 ); + zmm8 = _mm512_inserti32x4 ( zmm8, xmm1, 1 ); + zmm8 = _mm512_inserti32x4 ( zmm8, xmm2, 2 ); + zmm8 = _mm512_inserti32x4 ( zmm8, xmm3, 3 ); + + zmm0 = _mm512_set1_epi32( *bsumptr ); + zmm8 = _mm512_sub_epi32( zmm8, zmm0 ); + + } + else + { + //Handle fringe cases when mr0 < MR + const int8_t *a_use_fringe = a_use; + dim_t mr0_use = mr0; + dim_t regidx = 0; + + // Dot product for mfringe 8 + if ( mr0_use >= 8 ) + { + // Dot product kernel for mr0 == 8 + for( dim_t k = 0; k < k_iter; k++ ) + { + // Load 0-63 in b[k+0 - k+31] + zmm6 = _mm512_loadu_si512( b_use ); + // move b pointer to next 64 elements + b_use += 64; + + // Load 4x64 elements from row0-row3 of A + LPGEMV_N_KERNEL_4_LOADS( zmm0, zmm1, zmm2, + zmm3, a_use, rs_a + ) + a_use += ( 4 * rs_a ); + + // Load 4x64 elements from row3-row7 of A + LPGEMV_N_KERNEL_4_LOADS( zmm24, zmm25, zmm26, + zmm27, a_use, rs_a + ) + a_use -= ( 4 * rs_a ); + + //Perform FMA on two 4x64 block of A with 64x1 + LPGEMV_N_KERNEL_4_FMA( zmm8, zmm9, zmm10, zmm11, + zmm6, zmm0, zmm1, zmm2, zmm3 + ) + LPGEMV_N_KERNEL_4_FMA( zmm12, zmm13, zmm14, zmm15, + zmm6, zmm24, zmm25, zmm26, zmm27 + ) + a_use += 64; + } + + if ( k_rem ) + { + // Load 0-63 in b[k+0 - k+63] + zmm6 = _mm512_maskz_loadu_epi8( k1, b_use ); + + // Load 4x64 elements from row0-row3 of A + LPGEMV_N_KERNEL_4_MASKLOADS( zmm0, zmm1, zmm2, + zmm3, k1, a_use, rs_a + ) + a_use += ( 4 * rs_a ); + LPGEMV_N_KERNEL_4_MASKLOADS( zmm24, zmm25, zmm26, + zmm27, k1, a_use, rs_a + ) + LPGEMV_N_KERNEL_4_FMA( zmm8, zmm9, zmm10, zmm11, + zmm6, zmm0, zmm1, zmm2, zmm3 + ) + LPGEMV_N_KERNEL_4_FMA( zmm12, zmm13, zmm14, zmm15, + zmm6, zmm24, zmm25, zmm26, zmm27 + ) + } + + // update pointers + mr0_use -= 8; + a_use = a_use_fringe + 8 * rs_a; + a_use_fringe = a_use; + b_use = b; + + // Horizontal add 8 zmm registers + // and get output into 2 xmm registers + LPGEMV_ZMM2XMM( zmm8, zmm9, zmm10, zmm11, + ymm0, ymm1, ymm2, ymm3, xmm0 + ) + LPGEMV_ZMM2XMM( zmm12, zmm13, zmm14, zmm15, + ymm4, ymm1, ymm2, ymm3, xmm1 + ) + + //insert xmm outputs into final output zmm8 reg + zmm8 = _mm512_inserti32x4( zmm8, xmm0, 0 ); + zmm8 = _mm512_inserti32x4( zmm8, xmm1, 1 ); + regidx = 2; + + } + + // Dot product for mfringe 4 + if ( mr0_use >= 4 ) + { + // Dot product kernel for mr0 == 8 + for ( dim_t k = 0; k < k_iter; k++ ) + { + // Load 0-63 in b[k+0 - k+63] + zmm6 = _mm512_loadu_si512( b_use ); + + // move b pointer to next 64 elements + b_use += 64; + + // Load 4x64 elements from row0-row3 of A + LPGEMV_N_KERNEL_4_LOADS( zmm0, zmm1, zmm2, + zmm3, a_use, rs_a + ) + // Perform FMA on 4x64 block of A with 64x1 + LPGEMV_N_KERNEL_4_FMA( zmm16, zmm17, zmm18, zmm19, + zmm6, zmm0, zmm1, zmm2, zmm3 + ) + a_use += 64; + } + + if ( k_rem ) + { + // Load 0-63 in b[k+0 - k+63] + zmm6 = _mm512_maskz_loadu_epi8( k1, b_use ); + + // Load 4x64 elements from row0-row3 of A + LPGEMV_N_KERNEL_4_MASKLOADS( zmm0, zmm1, zmm2, + zmm3, k1, a_use, rs_a + ) + LPGEMV_N_KERNEL_4_FMA( zmm16, zmm17, zmm18, zmm19, + zmm6, zmm0, zmm1, zmm2, zmm3 + ) + } + + //update pointers + mr0_use -= 4; + a_use = a_use_fringe + 4 * rs_a; + a_use_fringe = a_use; + b_use = b; + + //Horizontal add 4 zmm reg and get the output into one xmm + LPGEMV_ZMM2XMM( zmm16, zmm17, zmm18, zmm19, + ymm5, ymm1, ymm2, ymm3, xmm2 + ) + + //insert xmm outputs into final output zmm8 reg based on regidx + if( regidx == 0 ) zmm8 = _mm512_inserti32x4( zmm8, xmm2, 0 ); + else zmm8 = _mm512_inserti32x4( zmm8, xmm2, 2 ); + regidx++; + } + + // Dot product for <= 3 + if ( mr0_use ) + { + // Dot product for m = 2 + if ( mr0_use >= 2 ) + { + for ( dim_t k = 0; k < k_iter; k++ ) + { + // Load 0-63 in b[k+0 - k+63] + zmm6 = _mm512_loadu_si512( b_use ); + + // Load 2x64 elements from row0-row1 of A + zmm0 = _mm512_loadu_si512( a_use ); + zmm1 = _mm512_loadu_si512( a_use + rs_a ); + + zmm0 = _mm512_add_epi8( zmm0, vec_uint8 ); + zmm1 = _mm512_add_epi8( zmm1, vec_uint8 ); + + zmm20 = _mm512_dpbusd_epi32( zmm20, zmm0, zmm6 ); + zmm21 = _mm512_dpbusd_epi32( zmm21, zmm1, zmm6 ); + + b_use += 64; // move b pointer to next 64 elements + a_use += 64; + } + if ( k_rem ) + { + // Load 0-63 in b[k+0 - k+63] + zmm6 = _mm512_maskz_loadu_epi8( k1, b_use ); + + zmm0 = _mm512_maskz_loadu_epi8( k1, a_use ); + zmm1 = _mm512_maskz_loadu_epi8( k1, a_use + rs_a ); + + zmm0 = _mm512_maskz_add_epi8( k1, zmm0, vec_uint8 ); + zmm1 = _mm512_maskz_add_epi8( k1, zmm1, vec_uint8 ); + + zmm20 = _mm512_dpbusd_epi32( zmm20, zmm0, zmm6 ); + zmm21 = _mm512_dpbusd_epi32( zmm21, zmm1, zmm6 ); + } + mr0_use -= 2; + a_use = a_use_fringe + 2 * rs_a; + a_use_fringe = a_use; + b_use = b; + } + + // Dot product for m = 2 + if ( mr0_use == 1 ) + { + for ( dim_t k = 0; k < k_iter; k++ ) + { + // Load 0-63 in b[k+0 - k+63] + zmm6 = _mm512_loadu_si512( b_use ); + zmm0 = _mm512_loadu_si512( a_use ); + zmm0 = _mm512_add_epi8( zmm0, vec_uint8 ); + zmm22 = _mm512_dpbusd_epi32( zmm22, zmm0, zmm6 ); + b_use += 64; // move b pointer to next 64 elements + a_use += 64; + } + + if ( k_rem ) + { + zmm6 = _mm512_maskz_loadu_epi8( k1, b_use ); + zmm0 = _mm512_maskz_loadu_epi8( k1, a_use ); + zmm0 = _mm512_maskz_add_epi8( k1, zmm0, vec_uint8 ); + zmm22 = _mm512_dpbusd_epi32( zmm22, zmm0, zmm6 ); + } + // When only fringe 1, + // update the registers to store in order + if ( !( mr0 & 0x2 ) ) zmm20 = zmm22; + } + + // Horizontal add 4 zmm reg and get the output into one xmm + LPGEMV_ZMM2XMM( zmm20, zmm21, zmm22, zmm23, + ymm6, ymm1, ymm2, ymm3, xmm3 + ) + + // insert xmm outputs into final output zmm8 reg based on regidx + if( regidx == 0 ) + { + zmm8 = _mm512_inserti32x4( zmm8, xmm3, 0 ); + } + else if( regidx == 1 ) + { + zmm8 = _mm512_inserti32x4( zmm8, xmm3, 1 ); + } + else if ( regidx == 2 ) + { + zmm8 = _mm512_inserti32x4( zmm8, xmm3, 2 ); + } + else + { + zmm8 = _mm512_inserti32x4( zmm8, xmm3, 3 ); + } + } + + zmm0 = _mm512_set1_epi32( *bsumptr ); + zmm8 = _mm512_maskz_sub_epi32( k2, zmm8, zmm0 ); + + } + + //Scale accumulated output with alpha + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + //Mulitply A*B output with alpha + zmm8 = _mm512_mullo_epi32( selector1, zmm8 ); + + if( beta != 0 ) + { + if( post_ops_attr.buf_downscale != NULL ) + { + if( post_ops_attr.rs_c_downscale == 1 ) + { + S8_S32_BETA_OP_NLT16F_MASK( k2, zmm8, 0, 0, + selector1, selector2 ) + } + else + { + int8_t ctemp[16]; + for( dim_t i = 0; i < mr0; i++ ) + { + ctemp[i] = *( ( int8_t* )post_ops_attr.buf_downscale + + ( post_ops_attr.rs_c_downscale * + ( post_ops_attr.post_op_c_i + i ) ) ); + } + selector1 = _mm512_cvtepi8_epi32 + ( _mm_maskz_loadu_epi8( 0xFFFF, ctemp ) ); + S32_BETA_FMA( zmm8, selector1, selector2 ); + } + } + else + { + if( rs_c == 1) + { + S32_S32_BETA_OP_NLT16F_MASK( c_use, k2, zmm8, 0, 0, 0, + selector1, selector2 ) + } + else + { + int32_t ctemp[16]; + for( dim_t i = 0; i < mr0; i++ ) + { + ctemp[i] = c_use[ i * rs_c ]; + } + selector1 = _mm512_loadu_epi32( ctemp ); + S32_BETA_FMA( zmm8, selector1, selector2 ); + } + } + } + + // Post Ops + lpgemm_post_op *post_ops_list_temp = post_op; + + post_ops_attr.is_last_k = TRUE; + POST_OP_LABEL_LASTK_SAFE_JUMP + + POST_OPS_BIAS_6x64: + { + selector1 = + _mm512_set1_epi32( + *( ( int32_t* )post_ops_list_temp->op_args1) ); + zmm8 = _mm512_add_epi32( selector1, zmm8 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_RELU_6x64: + { + selector1 = _mm512_setzero_epi32(); + + zmm8 = _mm512_max_epi32( selector1, zmm8 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_RELU_SCALE_6x64: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( + *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + RELU_SCALE_OP_S32_AVX512(zmm8) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_GELU_TANH_6x64: + { + __m512 dn, z, x, r2, r, y, x_tanh; + GELU_TANH_S32_AVX512( zmm8, y, r, r2, x, + z, dn, x_tanh, selector1 ) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_GELU_ERF_6x64: + { + __m512 x, r, y, x_erf; + + GELU_ERF_S32_AVX512( zmm8, y, r, x, x_erf ) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_CLIP_6x64: + { + __m512i min = _mm512_set1_epi32( + *( int32_t* )post_ops_list_temp->op_args2 ); + __m512i max = _mm512_set1_epi32( + *( int32_t* )post_ops_list_temp->op_args3 ); + + CLIP_S32_AVX512( zmm8, min, max ) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_DOWNSCALE_6x64: + { + selector1 = ( __m512i )_mm512_set1_ps( + *( ( float* )post_ops_list_temp->scale_factor ) ); + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point0 = _mm512_castsi512_si128( + _mm512_setzero_si512() ); + + zero_point0 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + + CVT_MULRND_CVT32(zmm8, selector1, zero_point0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_MATRIX_ADD_6x64: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + if( ldm == 1 ) + { + S8_S32_MATRIX_ADD_LOAD( k2, selector1, 0, 0 ) + + zmm8 = _mm512_add_epi32( selector1, zmm8 ); + } + else + { + int8_t ctemp[16]; + for( dim_t i = 0; i < mr0; i++ ) + { + ctemp[i] = *( matptr + + ( ( post_ops_attr.post_op_c_i + i ) + * ldm ) ); + } + selector1 = _mm512_cvtepi8_epi32 + ( _mm_maskz_loadu_epi8( k2, ctemp ) ); + + zmm8 = _mm512_add_epi32( selector1, zmm8 ); + } + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + if( ldm == 1 ) + { + S32_S32_MATRIX_ADD_LOAD(k2, selector1, 0, 0 ); + zmm8 = _mm512_add_epi32( selector1, zmm8 ); + } + else + { + int32_t ctemp[16]; + for( dim_t i = 0; i < mr0; i++ ) + { + ctemp[i] = *( matptr + + ( ( post_ops_attr.post_op_c_i + i ) + * ldm ) ); + } + selector1 = _mm512_maskz_loadu_epi32( k2, ctemp ); + zmm8 = _mm512_add_epi32( selector1, zmm8 ); + } + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + + POST_OPS_SWISH_6x64: + { + selector1 = + _mm512_set1_epi32( *( (int32_t*)post_ops_list_temp->op_args2 ) ); + + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + SWISH_S32_AVX512( zmm8, fl_reg, al, al_in, r, r2, z, dn, selector2 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_6x64_DISABLE: + { + // Case where the output C matrix is s8 (downscaled) and + // this is the final write for a given block within C. + if ( post_ops_attr.buf_downscale != NULL ) + { + if( post_ops_attr.rs_c_downscale == 1 ) + { + CVT_STORE_S32_S8_MASK( zmm8, k2, 0, 0 ); + } + else + { + int8_t ctemp[16]; + + _mm512_mask_cvtsepi32_storeu_epi8 ( ctemp, k2, zmm8 ); + + for (dim_t i = 0; i < mr0; i++) + { + *( ( int8_t* )post_ops_attr.buf_downscale + + ( post_ops_attr.rs_c_downscale * + ( post_ops_attr.post_op_c_i + i ) ) ) = ctemp[i]; + } + } + } + else + { + if(rs_c == 1) + { + _mm512_mask_storeu_epi32(c_use, k2, zmm8); + } + else + { + // Store ZMM8 into ctemp buffer and store back + // element by element into output buffer at strides + int32_t ctemp[16]; + _mm512_mask_storeu_epi32(ctemp, k2, zmm8); + for (dim_t i = 0; i < mr0; i++) + { + c_use[i * rs_c] = ctemp[i]; + } + } + } + post_ops_attr.post_op_c_i += MR; + } + } +} + +#endif // BLIS_ADDON_LPGEMM diff --git a/kernels/zen4/lpgemm/silu_avx512.h b/kernels/zen4/lpgemm/silu_avx512.h new file mode 100644 index 0000000000..3250dfecd9 --- /dev/null +++ b/kernels/zen4/lpgemm/silu_avx512.h @@ -0,0 +1,46 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef AOCL_LPGEMM_SWISH_AVX512_H +#define AOCL_LPGEMM_SWISH_AVX512_H + +// SiLU(in_reg) = in_reg / (1 + exp(-1 * al * in_reg)). +// in_reg and al are expected to contain float values. +#define SWISH_F32_AVX512_DEF(in_reg, al, al_in, r, r2, z, dn, ex_out) \ + al_in = _mm512_fnmadd_ps( in_reg, al, _mm512_setzero_ps() ); \ + EXPF_AVX512(al_in, r, r2, z, dn, ex_out); \ + ex_out = ( __m512i )_mm512_add_ps( ( __m512 )ex_out, _mm512_set1_ps( 1 ) ); \ + in_reg = _mm512_div_ps( in_reg, ( __m512 )ex_out ); \ + +#endif // AOCL_LPGEMM_SWISH_AVX512_H diff --git a/kernels/zen4/lpgemm/u8s8s32/lpgemm_6x64rowmajor_amd512vnni.c b/kernels/zen4/lpgemm/u8s8s32/lpgemm_6x64rowmajor_amd512vnni.c index 32bfc2c8af..8a34499161 100644 --- a/kernels/zen4/lpgemm/u8s8s32/lpgemm_6x64rowmajor_amd512vnni.c +++ b/kernels/zen4/lpgemm/u8s8s32/lpgemm_6x64rowmajor_amd512vnni.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -53,7 +53,9 @@ LPGEMM_MAIN_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x64) &&POST_OPS_GELU_TANH_6x64, &&POST_OPS_GELU_ERF_6x64, &&POST_OPS_CLIP_6x64, - &&POST_OPS_DOWNSCALE_6x64 + &&POST_OPS_DOWNSCALE_6x64, + &&POST_OPS_MATRIX_ADD_6x64, + &&POST_OPS_SWISH_6x64 }; const dim_t MR = 6; @@ -154,16 +156,16 @@ LPGEMM_MAIN_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x64) } // B matrix storage. - __m512i b0; - __m512i b1; - __m512i b2; - __m512i b3; + __m512i b0 = _mm512_setzero_epi32(); + __m512i b1 = _mm512_setzero_epi32(); + __m512i b2 = _mm512_setzero_epi32(); + __m512i b3 = _mm512_setzero_epi32(); // A matrix storage. - __m512i a_int32_0; - __m512i a_int32_1; - __m512i a_int32_2; - __m512i a_int32_3; + __m512i a_int32_0 = _mm512_setzero_epi32(); + __m512i a_int32_1 = _mm512_setzero_epi32(); + __m512i a_int32_2 = _mm512_setzero_epi32(); + __m512i a_int32_3 = _mm512_setzero_epi32(); _mm_prefetch( a, _MM_HINT_T0 ); for ( dim_t ir = 0; ir < m_full_pieces_loop_limit; ir += MR ) @@ -890,39 +892,68 @@ LPGEMM_MAIN_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x64) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_6x64: { - selector1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ); - a_int32_0 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 2 * 16 ) ); - a_int32_1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + a_int32_1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + a_int32_0 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + a_int32_1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point0 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point1 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point2 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point3 = _mm512_castsi512_si128( _mm512_setzero_si512() ); // int8_t zero point value. - __m128i zero_point0 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); - __m128i zero_point1 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); - __m128i zero_point2 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); - __m128i zero_point3 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 3 * 16 ) ) ); + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + zero_point2 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + zero_point3 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point1 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point2 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point3 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } // c[0, 0-15] CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); @@ -996,6 +1027,138 @@ LPGEMM_MAIN_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x64) // c[5, 48-63] CVT_MULRND_CVT32(c_int32_5p3,a_int32_1,zero_point3); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_6x64: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + S8_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,0); + + // c[1:0-15,16-31,32-47,48-63] + S8_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,1); + + // c[2:0-15,16-31,32-47,48-63] + S8_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,2); + + // c[3:0-15,16-31,32-47,48-63] + S8_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,3); + + // c[4:0-15,16-31,32-47,48-63] + S8_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,4); + + // c[5:0-15,16-31,32-47,48-63] + S8_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,5); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + S32_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,0); + + // c[1:0-15,16-31,32-47,48-63] + S32_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,1); + + // c[2:0-15,16-31,32-47,48-63] + S32_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,2); + + // c[3:0-15,16-31,32-47,48-63] + S32_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,3); + + // c[4:0-15,16-31,32-47,48-63] + S32_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,4); + + // c[5:0-15,16-31,32-47,48-63] + S32_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,5); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_6x64: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + // c[0, 0-15] + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 16-31] + SWISH_S32_AVX512(c_int32_0p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 32-47] + SWISH_S32_AVX512(c_int32_0p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 48-63] + SWISH_S32_AVX512(c_int32_0p3, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 0-15] + SWISH_S32_AVX512(c_int32_1p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 16-31] + SWISH_S32_AVX512(c_int32_1p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 32-47] + SWISH_S32_AVX512(c_int32_1p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 48-63] + SWISH_S32_AVX512(c_int32_1p3, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 0-15] + SWISH_S32_AVX512(c_int32_2p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 16-31] + SWISH_S32_AVX512(c_int32_2p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 32-47] + SWISH_S32_AVX512(c_int32_2p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 48-63] + SWISH_S32_AVX512(c_int32_2p3, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[3, 0-15] + SWISH_S32_AVX512(c_int32_3p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[3, 16-31] + SWISH_S32_AVX512(c_int32_3p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[3, 32-47] + SWISH_S32_AVX512(c_int32_3p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[3, 48-63] + SWISH_S32_AVX512(c_int32_3p3, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[4, 0-15] + SWISH_S32_AVX512(c_int32_4p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[4, 16-31] + SWISH_S32_AVX512(c_int32_4p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[4, 32-47] + SWISH_S32_AVX512(c_int32_4p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[4, 48-63] + SWISH_S32_AVX512(c_int32_4p3, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[5, 0-15] + SWISH_S32_AVX512(c_int32_5p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[5, 16-31] + SWISH_S32_AVX512(c_int32_5p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[5, 32-47] + SWISH_S32_AVX512(c_int32_5p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[5, 48-63] + SWISH_S32_AVX512(c_int32_5p3, fl_reg, al, al_in, r, r2, z, dn, selector2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_6x64_DISABLE: diff --git a/kernels/zen4/lpgemm/u8s8s32/lpgemm_m_fringe_amd512vnni.c b/kernels/zen4/lpgemm/u8s8s32/lpgemm_m_fringe_amd512vnni.c index 23393cad4f..73f2f97405 100644 --- a/kernels/zen4/lpgemm/u8s8s32/lpgemm_m_fringe_amd512vnni.c +++ b/kernels/zen4/lpgemm/u8s8s32/lpgemm_m_fringe_amd512vnni.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -53,20 +53,22 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x64) &&POST_OPS_GELU_TANH_5x64, &&POST_OPS_GELU_ERF_5x64, &&POST_OPS_CLIP_5x64, - &&POST_OPS_DOWNSCALE_5x64 + &&POST_OPS_DOWNSCALE_5x64, + &&POST_OPS_MATRIX_ADD_5x64, + &&POST_OPS_SWISH_5x64 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; // B matrix storage. - __m512i b0; - __m512i b1; - __m512i b2; - __m512i b3; + __m512i b0 = _mm512_setzero_epi32(); + __m512i b1 = _mm512_setzero_epi32(); + __m512i b2 = _mm512_setzero_epi32(); + __m512i b3 = _mm512_setzero_epi32(); // A matrix storage. - __m512i a_int32_0; - __m512i a_int32_1; + __m512i a_int32_0 = _mm512_setzero_epi32(); + __m512i a_int32_1 = _mm512_setzero_epi32(); // Registers to use for accumulating C. __m512i c_int32_0p0 = _mm512_setzero_epi32(); @@ -737,37 +739,68 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x64) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_5x64: { - selector1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ); - a_int32_0 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 2 * 16 ) ); - a_int32_1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 3 * 16 ) ); - __m128i zero_point0 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); - __m128i zero_point1 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); - __m128i zero_point2 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); - __m128i zero_point3 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 3 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + a_int32_1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + a_int32_0 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + a_int32_1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point0 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point1 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point2 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point3 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + + // int8_t zero point value. + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + zero_point2 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + zero_point3 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point1 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point2 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point3 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } // c[0, 0-15] CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); @@ -829,6 +862,120 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x64) // c[4, 48-63] CVT_MULRND_CVT32(c_int32_4p3,a_int32_1,zero_point3); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_5x64: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + S8_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,0); + + // c[1:0-15,16-31,32-47,48-63] + S8_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,1); + + // c[2:0-15,16-31,32-47,48-63] + S8_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,2); + + // c[3:0-15,16-31,32-47,48-63] + S8_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,3); + + // c[4:0-15,16-31,32-47,48-63] + S8_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,4); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + S32_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,0); + + // c[1:0-15,16-31,32-47,48-63] + S32_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,1); + + // c[2:0-15,16-31,32-47,48-63] + S32_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,2); + + // c[3:0-15,16-31,32-47,48-63] + S32_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,3); + + // c[4:0-15,16-31,32-47,48-63] + S32_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,4); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_5x64: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + // c[0, 0-15] + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 16-31] + SWISH_S32_AVX512(c_int32_0p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 32-47] + SWISH_S32_AVX512(c_int32_0p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 48-63] + SWISH_S32_AVX512(c_int32_0p3, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 0-15] + SWISH_S32_AVX512(c_int32_1p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 16-31] + SWISH_S32_AVX512(c_int32_1p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 32-47] + SWISH_S32_AVX512(c_int32_1p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 48-63] + SWISH_S32_AVX512(c_int32_1p3, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 0-15] + SWISH_S32_AVX512(c_int32_2p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 16-31] + SWISH_S32_AVX512(c_int32_2p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 32-47] + SWISH_S32_AVX512(c_int32_2p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 48-63] + SWISH_S32_AVX512(c_int32_2p3, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[3, 0-15] + SWISH_S32_AVX512(c_int32_3p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[3, 16-31] + SWISH_S32_AVX512(c_int32_3p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[3, 32-47] + SWISH_S32_AVX512(c_int32_3p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[3, 48-63] + SWISH_S32_AVX512(c_int32_3p3, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[4, 0-15] + SWISH_S32_AVX512(c_int32_4p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[4, 16-31] + SWISH_S32_AVX512(c_int32_4p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[4, 32-47] + SWISH_S32_AVX512(c_int32_4p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[4, 48-63] + SWISH_S32_AVX512(c_int32_4p3, fl_reg, al, al_in, r, r2, z, dn, selector2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_5x64_DISABLE: @@ -979,20 +1126,22 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4x64) &&POST_OPS_GELU_TANH_4x64, &&POST_OPS_GELU_ERF_4x64, &&POST_OPS_CLIP_4x64, - &&POST_OPS_DOWNSCALE_4x64 + &&POST_OPS_DOWNSCALE_4x64, + &&POST_OPS_MATRIX_ADD_4x64, + &&POST_OPS_SWISH_4x64 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; // B matrix storage. - __m512i b0; - __m512i b1; - __m512i b2; - __m512i b3; + __m512i b0 = _mm512_setzero_epi32(); + __m512i b1 = _mm512_setzero_epi32(); + __m512i b2 = _mm512_setzero_epi32(); + __m512i b3 = _mm512_setzero_epi32(); // A matrix storage. - __m512i a_int32_0; - __m512i a_int32_1; + __m512i a_int32_0 = _mm512_setzero_epi32(); + __m512i a_int32_1 = _mm512_setzero_epi32(); // Registers to use for accumulating C. __m512i c_int32_0p0 = _mm512_setzero_epi32(); @@ -1015,16 +1164,29 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4x64) __m512i c_int32_3p2 = _mm512_setzero_epi32(); __m512i c_int32_3p3 = _mm512_setzero_epi32(); - for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + //gcc compiler (atleast 11.2 to 13.1) avoid loading B into + // registers while generating the code. A dummy shuffle instruction + // is used on b data to explicitly specify to gcc compiler + // b data needs to be kept in registers to reuse across FMA's + __m512i dsmask = _mm512_set_epi64( + 0x0F0E0D0C0B0A0908, 0x0706050403020100, + 0x0F0E0D0C0B0A0908, 0x0706050403020100, + 0x0F0E0D0C0B0A0908, 0x0706050403020100, + 0x0F0E0D0C0B0A0908, 0x0706050403020100); + + for (dim_t kr = 0; kr < k_full_pieces; kr += 1) { b0 = _mm512_loadu_si512( b + ( rs_b * kr ) + ( cs_b * 0 ) ); // Broadcast a[0,kr:kr+4]. a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); - + b0 = _mm512_shuffle_epi8(b0, dsmask); b1 = _mm512_loadu_si512( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b1 = _mm512_shuffle_epi8(b1, dsmask); b2 = _mm512_loadu_si512( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + b2 = _mm512_shuffle_epi8(b2, dsmask); b3 = _mm512_loadu_si512( b + ( rs_b * kr ) + ( cs_b * 3 ) ); + b3 = _mm512_shuffle_epi8(b3, dsmask); // Perform column direction mat-mul with k = 4. // c[0,0-63] = a[0,kr:kr+4]*b[kr:kr+4,0-63] @@ -1073,7 +1235,6 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4x64) __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - k_partial_pieces ) ); b0 = _mm512_loadu_si512( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); - // Broadcast a[0,kr:kr+4]. a_kfringe_buf = _mm_maskz_loadu_epi8 ( @@ -1548,37 +1709,68 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4x64) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_4x64: { - selector1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ); - a_int32_0 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 2 * 16 ) ); - a_int32_1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 3 * 16 ) ); - __m128i zero_point0 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); - __m128i zero_point1 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); - __m128i zero_point2 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); - __m128i zero_point3 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 3 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + a_int32_1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + a_int32_0 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + a_int32_1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point0 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point1 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point2 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point3 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + + // int8_t zero point value. + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + zero_point2 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + zero_point3 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point1 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point2 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point3 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } // c[0, 0-15] CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); @@ -1628,6 +1820,102 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4x64) // c[3, 48-63] CVT_MULRND_CVT32(c_int32_3p3,a_int32_1,zero_point3); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_4x64: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + S8_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,0); + + // c[1:0-15,16-31,32-47,48-63] + S8_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,1); + + // c[2:0-15,16-31,32-47,48-63] + S8_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,2); + + // c[3:0-15,16-31,32-47,48-63] + S8_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,3); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + S32_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,0); + + // c[1:0-15,16-31,32-47,48-63] + S32_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,1); + + // c[2:0-15,16-31,32-47,48-63] + S32_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,2); + + // c[3:0-15,16-31,32-47,48-63] + S32_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,3); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_4x64: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + // c[0, 0-15] + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 16-31] + SWISH_S32_AVX512(c_int32_0p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 32-47] + SWISH_S32_AVX512(c_int32_0p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 48-63] + SWISH_S32_AVX512(c_int32_0p3, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 0-15] + SWISH_S32_AVX512(c_int32_1p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 16-31] + SWISH_S32_AVX512(c_int32_1p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 32-47] + SWISH_S32_AVX512(c_int32_1p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 48-63] + SWISH_S32_AVX512(c_int32_1p3, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 0-15] + SWISH_S32_AVX512(c_int32_2p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 16-31] + SWISH_S32_AVX512(c_int32_2p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 32-47] + SWISH_S32_AVX512(c_int32_2p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 48-63] + SWISH_S32_AVX512(c_int32_2p3, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[3, 0-15] + SWISH_S32_AVX512(c_int32_3p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[3, 16-31] + SWISH_S32_AVX512(c_int32_3p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[3, 32-47] + SWISH_S32_AVX512(c_int32_3p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[3, 48-63] + SWISH_S32_AVX512(c_int32_3p3, fl_reg, al, al_in, r, r2, z, dn, selector2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_4x64_DISABLE: @@ -1754,20 +2042,22 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x64) &&POST_OPS_GELU_TANH_3x64, &&POST_OPS_GELU_ERF_3x64, &&POST_OPS_CLIP_3x64, - &&POST_OPS_DOWNSCALE_3x64 + &&POST_OPS_DOWNSCALE_3x64, + &&POST_OPS_MATRIX_ADD_3x64, + &&POST_OPS_SWISH_3x64 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; // B matrix storage. - __m512i b0; - __m512i b1; - __m512i b2; - __m512i b3; + __m512i b0 = _mm512_setzero_epi32(); + __m512i b1 = _mm512_setzero_epi32(); + __m512i b2 = _mm512_setzero_epi32(); + __m512i b3 = _mm512_setzero_epi32(); // A matrix storage. - __m512i a_int32_0; - __m512i a_int32_1; + __m512i a_int32_0 = _mm512_setzero_epi32(); + __m512i a_int32_1 = _mm512_setzero_epi32(); // Registers to use for accumulating C. __m512i c_int32_0p0 = _mm512_setzero_epi32(); @@ -1785,16 +2075,29 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x64) __m512i c_int32_2p2 = _mm512_setzero_epi32(); __m512i c_int32_2p3 = _mm512_setzero_epi32(); + // gcc compiler (atleast 11.2 to 13.1) avoid loading B into + // registers while generating the code. A dummy shuffle instruction + // is used on b data to explicitly specify to gcc compiler + // b data needs to be kept in registers to reuse across FMA's + __m512i dsmask = _mm512_set_epi64( + 0x0F0E0D0C0B0A0908, 0x0706050403020100, + 0x0F0E0D0C0B0A0908, 0x0706050403020100, + 0x0F0E0D0C0B0A0908, 0x0706050403020100, + 0x0F0E0D0C0B0A0908, 0x0706050403020100); + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) { b0 = _mm512_loadu_si512( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - + b0 = _mm512_shuffle_epi8(b0, dsmask); // Broadcast a[0,kr:kr+4]. a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); b1 = _mm512_loadu_si512( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b1 = _mm512_shuffle_epi8(b1, dsmask); b2 = _mm512_loadu_si512( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + b2 = _mm512_shuffle_epi8(b2, dsmask); b3 = _mm512_loadu_si512( b + ( rs_b * kr ) + ( cs_b * 3 ) ); + b3 = _mm512_shuffle_epi8(b3, dsmask); // Perform column direction mat-mul with k = 4. // c[0,0-63] = a[0,kr:kr+4]*b[kr:kr+4,0-63] @@ -2208,37 +2511,68 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x64) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_3x64: { - selector1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ); - a_int32_0 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 2 * 16 ) ); - a_int32_1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 3 * 16 ) ); - __m128i zero_point0 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); - __m128i zero_point1 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); - __m128i zero_point2 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); - __m128i zero_point3 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 3 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + a_int32_1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + a_int32_0 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + a_int32_1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point0 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point1 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point2 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point3 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + + // int8_t zero point value. + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + zero_point2 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + zero_point3 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point1 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point2 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point3 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } // c[0, 0-15] CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); @@ -2276,6 +2610,84 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x64) // c[2, 48-63] CVT_MULRND_CVT32(c_int32_2p3,a_int32_1,zero_point3); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_3x64: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + S8_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,0); + + // c[1:0-15,16-31,32-47,48-63] + S8_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,1); + + // c[2:0-15,16-31,32-47,48-63] + S8_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,2); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + S32_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,0); + + // c[1:0-15,16-31,32-47,48-63] + S32_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,1); + + // c[2:0-15,16-31,32-47,48-63] + S32_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,2); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_3x64: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + // c[0, 0-15] + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 16-31] + SWISH_S32_AVX512(c_int32_0p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 32-47] + SWISH_S32_AVX512(c_int32_0p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 48-63] + SWISH_S32_AVX512(c_int32_0p3, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 0-15] + SWISH_S32_AVX512(c_int32_1p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 16-31] + SWISH_S32_AVX512(c_int32_1p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 32-47] + SWISH_S32_AVX512(c_int32_1p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 48-63] + SWISH_S32_AVX512(c_int32_1p3, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 0-15] + SWISH_S32_AVX512(c_int32_2p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 16-31] + SWISH_S32_AVX512(c_int32_2p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 32-47] + SWISH_S32_AVX512(c_int32_2p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 48-63] + SWISH_S32_AVX512(c_int32_2p3, fl_reg, al, al_in, r, r2, z, dn, selector2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_3x64_DISABLE: @@ -2378,20 +2790,22 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x64) &&POST_OPS_GELU_TANH_2x64, &&POST_OPS_GELU_ERF_2x64, &&POST_OPS_CLIP_2x64, - &&POST_OPS_DOWNSCALE_2x64 + &&POST_OPS_DOWNSCALE_2x64, + &&POST_OPS_MATRIX_ADD_2x64, + &&POST_OPS_SWISH_2x64 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; // B matrix storage. - __m512i b0; - __m512i b1; - __m512i b2; - __m512i b3; + __m512i b0 = _mm512_setzero_epi32(); + __m512i b1 = _mm512_setzero_epi32(); + __m512i b2 = _mm512_setzero_epi32(); + __m512i b3 = _mm512_setzero_epi32(); // A matrix storage. - __m512i a_int32_0; - __m512i a_int32_1; + __m512i a_int32_0 = _mm512_setzero_epi32(); + __m512i a_int32_1 = _mm512_setzero_epi32(); // Registers to use for accumulating C. __m512i c_int32_0p0 = _mm512_setzero_epi32(); @@ -2404,16 +2818,29 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x64) __m512i c_int32_1p2 = _mm512_setzero_epi32(); __m512i c_int32_1p3 = _mm512_setzero_epi32(); + // gcc compiler (atleast 11.2 to 13.1) avoid loading B into + // registers while generating the code. A dummy shuffle instruction + // is used on b data to explicitly specify to gcc compiler + // b data needs to be kept in registers to reuse across FMA's + __m512i dsmask = _mm512_set_epi64( + 0x0F0E0D0C0B0A0908, 0x0706050403020100, + 0x0F0E0D0C0B0A0908, 0x0706050403020100, + 0x0F0E0D0C0B0A0908, 0x0706050403020100, + 0x0F0E0D0C0B0A0908, 0x0706050403020100); + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) { b0 = _mm512_loadu_si512( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - + b0 = _mm512_shuffle_epi8(b0, dsmask); // Broadcast a[0,kr:kr+4]. a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); - b1 = _mm512_loadu_si512( b + ( rs_b * kr ) + ( cs_b * 1 ) ); - b2 = _mm512_loadu_si512( b + ( rs_b * kr ) + ( cs_b * 2 ) ); - b3 = _mm512_loadu_si512( b + ( rs_b * kr ) + ( cs_b * 3 ) ); + b1 = _mm512_loadu_si512( b + (rs_b * kr) + (cs_b * 1)); + b1 = _mm512_shuffle_epi8( b1, dsmask); + b2 = _mm512_loadu_si512( b + (rs_b * kr) + (cs_b * 2)); + b2 = _mm512_shuffle_epi8( b2, dsmask); + b3 = _mm512_loadu_si512( b + (rs_b * kr) + (cs_b * 3)); + b3 = _mm512_shuffle_epi8( b3, dsmask); // Perform column direction mat-mul with k = 4. // c[0,0-63] = a[0,kr:kr+4]*b[kr:kr+4,0-63] @@ -2718,37 +3145,68 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x64) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_2x64: { - selector1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ); - a_int32_0 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 2 * 16 ) ); - a_int32_1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 3 * 16 ) ); - __m128i zero_point0 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); - __m128i zero_point1 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); - __m128i zero_point2 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); - __m128i zero_point3 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 3 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + a_int32_1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + a_int32_0 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + a_int32_1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point0 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point1 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point2 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point3 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + + // int8_t zero point value. + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + zero_point2 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + zero_point3 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point1 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point2 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point3 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } // c[0, 0-15] CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); @@ -2774,6 +3232,66 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x64) // c[1, 48-63] CVT_MULRND_CVT32(c_int32_1p3,a_int32_1,zero_point3); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_2x64: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + S8_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,0); + + // c[1:0-15,16-31,32-47,48-63] + S8_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,1); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + S32_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,0); + + // c[1:0-15,16-31,32-47,48-63] + S32_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,1); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_2x64: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + // c[0, 0-15] + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 16-31] + SWISH_S32_AVX512(c_int32_0p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 32-47] + SWISH_S32_AVX512(c_int32_0p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 48-63] + SWISH_S32_AVX512(c_int32_0p3, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 0-15] + SWISH_S32_AVX512(c_int32_1p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 16-31] + SWISH_S32_AVX512(c_int32_1p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 32-47] + SWISH_S32_AVX512(c_int32_1p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 48-63] + SWISH_S32_AVX512(c_int32_1p3, fl_reg, al, al_in, r, r2, z, dn, selector2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_2x64_DISABLE: @@ -2852,20 +3370,22 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x64) &&POST_OPS_GELU_TANH_1x64, &&POST_OPS_GELU_ERF_1x64, &&POST_OPS_CLIP_1x64, - &&POST_OPS_DOWNSCALE_1x64 + &&POST_OPS_DOWNSCALE_1x64, + &&POST_OPS_MATRIX_ADD_1x64, + &&POST_OPS_SWISH_1x64 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; // B matrix storage. - __m512i b0; - __m512i b1; - __m512i b2; - __m512i b3; + __m512i b0 = _mm512_setzero_epi32(); + __m512i b1 = _mm512_setzero_epi32(); + __m512i b2 = _mm512_setzero_epi32(); + __m512i b3 = _mm512_setzero_epi32(); // A matrix storage. - __m512i a_int32_0; - __m512i a_int32_1; + __m512i a_int32_0 = _mm512_setzero_epi32(); + __m512i a_int32_1 = _mm512_setzero_epi32(); // Registers to use for accumulating C. __m512i c_int32_0p0 = _mm512_setzero_epi32(); @@ -3076,37 +3596,68 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x64) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_1x64: { - selector1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ); - a_int32_0 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 2 * 16 ) ); - a_int32_1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 3 * 16 ) ); - __m128i zero_point0 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); - __m128i zero_point1 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); - __m128i zero_point2 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); - __m128i zero_point3 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 3 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + a_int32_1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + a_int32_0 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + a_int32_1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point0 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point1 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point2 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point3 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + + // int8_t zero point value. + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + zero_point2 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + zero_point3 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point1 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point2 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point3 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } // c[0, 0-15] CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); @@ -3120,6 +3671,48 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x64) // c[0, 48-63] CVT_MULRND_CVT32(c_int32_0p3,a_int32_1,zero_point3); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_1x64: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + S8_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,0); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47,48-63] + S32_S32_MATRIX_ADD_4COL(selector1,selector2,a_int32_0,a_int32_1,0); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_1x64: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + // c[0, 0-15] + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 16-31] + SWISH_S32_AVX512(c_int32_0p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 32-47] + SWISH_S32_AVX512(c_int32_0p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 48-63] + SWISH_S32_AVX512(c_int32_0p3, fl_reg, al, al_in, r, r2, z, dn, selector2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_1x64_DISABLE: diff --git a/kernels/zen4/lpgemm/u8s8s32/lpgemm_mn_fringe_amd512vnni.c b/kernels/zen4/lpgemm/u8s8s32/lpgemm_mn_fringe_amd512vnni.c index 3dcb7eed07..019083ad15 100644 --- a/kernels/zen4/lpgemm/u8s8s32/lpgemm_mn_fringe_amd512vnni.c +++ b/kernels/zen4/lpgemm/u8s8s32/lpgemm_mn_fringe_amd512vnni.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -53,7 +53,9 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5xlt16) &&POST_OPS_GELU_TANH_5xLT16, &&POST_OPS_GELU_ERF_5xLT16, &&POST_OPS_CLIP_5xLT16, - &&POST_OPS_DOWNSCALE_5xLT16 + &&POST_OPS_DOWNSCALE_5xLT16, + &&POST_OPS_MATRIX_ADD_5xLT16, + &&POST_OPS_SWISH_5xLT16 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -229,23 +231,23 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5xlt16) __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); // c[0,0-15] - S32_S32_BETA_OP_NLT16F_MASK(load_mask, c_int32_0p0, 0, 0, 0, \ + S32_S32_BETA_OP_NLT16F_MASK(c, load_mask, c_int32_0p0, 0, 0, 0, \ selector1, selector2); // c[1,0-15] - S32_S32_BETA_OP_NLT16F_MASK(load_mask, c_int32_1p0, 0, 1, 0, \ + S32_S32_BETA_OP_NLT16F_MASK(c, load_mask, c_int32_1p0, 0, 1, 0, \ selector1, selector2); // c[2,0-15] - S32_S32_BETA_OP_NLT16F_MASK(load_mask, c_int32_2p0, 0, 2, 0, \ + S32_S32_BETA_OP_NLT16F_MASK(c, load_mask, c_int32_2p0, 0, 2, 0, \ selector1, selector2); // c[3,0-15] - S32_S32_BETA_OP_NLT16F_MASK(load_mask, c_int32_3p0, 0, 3, 0, \ + S32_S32_BETA_OP_NLT16F_MASK(c, load_mask, c_int32_3p0, 0, 3, 0, \ selector1, selector2); // c[4,0-15] - S32_S32_BETA_OP_NLT16F_MASK(load_mask, c_int32_4p0, 0, 4, 0, \ + S32_S32_BETA_OP_NLT16F_MASK(c, load_mask, c_int32_4p0, 0, 4, 0, \ selector1, selector2); } } @@ -391,23 +393,41 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5xlt16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_5xLT16: { // Typecast without data modification, safe operation. __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); - selector1 = _mm512_maskz_loadu_epi32 - ( - load_mask, - ( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j ) - ); - __m128i zero_point = _mm_maskz_loadu_epi8 - ( - load_mask, - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j ) - ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = _mm512_maskz_loadu_epi32 + ( + load_mask, + ( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j ) + ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point = _mm512_castsi512_si128( _mm512_setzero_si512() ); + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point = _mm_maskz_loadu_epi8 + ( + load_mask, + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ) + ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } // c[0, 0-15] CVT_MULRND_CVT32_LT16(c_int32_0p0,selector1,zero_point); @@ -424,6 +444,76 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5xlt16) // c[4, 0-15] CVT_MULRND_CVT32_LT16(c_int32_4p0,selector1,zero_point); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_5xLT16: + { + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S8_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + S8_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,1); + + // c[2:0-15] + S8_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,2); + + // c[3:0-15] + S8_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,3); + + // c[4:0-15] + S8_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,4); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S32_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + S32_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,1); + + // c[2:0-15] + S32_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,2); + + // c[3:0-15] + S32_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,3); + + // c[4:0-15] + S32_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,4); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_5xLT16: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + // c[0, 0-15] + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 0-15] + SWISH_S32_AVX512(c_int32_1p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 0-15] + SWISH_S32_AVX512(c_int32_2p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[3, 0-15] + SWISH_S32_AVX512(c_int32_3p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[4, 0-15] + SWISH_S32_AVX512(c_int32_4p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_5xLT16_DISABLE: @@ -485,7 +575,9 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4xlt16) &&POST_OPS_GELU_TANH_4xLT16, &&POST_OPS_GELU_ERF_4xLT16, &&POST_OPS_CLIP_4xLT16, - &&POST_OPS_DOWNSCALE_4xLT16 + &&POST_OPS_DOWNSCALE_4xLT16, + &&POST_OPS_MATRIX_ADD_4xLT16, + &&POST_OPS_SWISH_4xLT16 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -634,19 +726,19 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4xlt16) __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); // c[0,0-15] - S32_S32_BETA_OP_NLT16F_MASK(load_mask, c_int32_0p0, 0, 0, 0, \ + S32_S32_BETA_OP_NLT16F_MASK(c, load_mask, c_int32_0p0, 0, 0, 0, \ selector1, selector2); // c[1,0-15] - S32_S32_BETA_OP_NLT16F_MASK(load_mask, c_int32_1p0, 0, 1, 0, \ + S32_S32_BETA_OP_NLT16F_MASK(c, load_mask, c_int32_1p0, 0, 1, 0, \ selector1, selector2); // c[2,0-15] - S32_S32_BETA_OP_NLT16F_MASK(load_mask, c_int32_2p0, 0, 2, 0, \ + S32_S32_BETA_OP_NLT16F_MASK(c, load_mask, c_int32_2p0, 0, 2, 0, \ selector1, selector2); // c[3,0-15] - S32_S32_BETA_OP_NLT16F_MASK(load_mask, c_int32_3p0, 0, 3, 0, \ + S32_S32_BETA_OP_NLT16F_MASK(c, load_mask, c_int32_3p0, 0, 3, 0, \ selector1, selector2); } } @@ -774,23 +866,41 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4xlt16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_4xLT16: { // Typecast without data modification, safe operation. __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); - selector1 = _mm512_maskz_loadu_epi32 - ( - load_mask, - ( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j ) - ); - __m128i zero_point = _mm_maskz_loadu_epi8 - ( - load_mask, - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j ) - ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = _mm512_maskz_loadu_epi32 + ( + load_mask, + ( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j ) + ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point = _mm512_castsi512_si128( _mm512_setzero_si512() ); + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point = _mm_maskz_loadu_epi8 + ( + load_mask, + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ) + ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } // c[0, 0-15] CVT_MULRND_CVT32_LT16(c_int32_0p0,selector1,zero_point); @@ -804,6 +914,67 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4xlt16) // c[3, 0-15] CVT_MULRND_CVT32_LT16(c_int32_3p0,selector1,zero_point); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_4xLT16: + { + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S8_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + S8_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,1); + + // c[2:0-15] + S8_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,2); + + // c[3:0-15] + S8_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,3); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S32_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + S32_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,1); + + // c[2:0-15] + S32_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,2); + + // c[3:0-15] + S32_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,3); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_4xLT16: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + // c[0, 0-15] + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 0-15] + SWISH_S32_AVX512(c_int32_1p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 0-15] + SWISH_S32_AVX512(c_int32_2p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[3, 0-15] + SWISH_S32_AVX512(c_int32_3p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_4xLT16_DISABLE: @@ -859,7 +1030,9 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3xlt16) &&POST_OPS_GELU_TANH_3xLT16, &&POST_OPS_GELU_ERF_3xLT16, &&POST_OPS_CLIP_3xLT16, - &&POST_OPS_DOWNSCALE_3xLT16 + &&POST_OPS_DOWNSCALE_3xLT16, + &&POST_OPS_MATRIX_ADD_3xLT16, + &&POST_OPS_SWISH_3xLT16 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -981,15 +1154,15 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3xlt16) __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); // c[0,0-15] - S32_S32_BETA_OP_NLT16F_MASK(load_mask, c_int32_0p0, 0, 0, 0, \ + S32_S32_BETA_OP_NLT16F_MASK(c, load_mask, c_int32_0p0, 0, 0, 0, \ selector1, selector2); // c[1,0-15] - S32_S32_BETA_OP_NLT16F_MASK(load_mask, c_int32_1p0, 0, 1, 0, \ + S32_S32_BETA_OP_NLT16F_MASK(c, load_mask, c_int32_1p0, 0, 1, 0, \ selector1, selector2); // c[2,0-15] - S32_S32_BETA_OP_NLT16F_MASK(load_mask, c_int32_2p0, 0, 2, 0, \ + S32_S32_BETA_OP_NLT16F_MASK(c, load_mask, c_int32_2p0, 0, 2, 0, \ selector1, selector2); } } @@ -1099,23 +1272,41 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3xlt16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_3xLT16: { // Typecast without data modification, safe operation. __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); - selector1 = _mm512_maskz_loadu_epi32 - ( - load_mask, - ( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j ) - ); - __m128i zero_point = _mm_maskz_loadu_epi8 - ( - load_mask, - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j ) - ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = _mm512_maskz_loadu_epi32 + ( + load_mask, + ( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j ) + ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point = _mm512_castsi512_si128( _mm512_setzero_si512() ); + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point = _mm_maskz_loadu_epi8 + ( + load_mask, + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ) + ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } // c[0, 0-15] CVT_MULRND_CVT32_LT16(c_int32_0p0,selector1,zero_point); @@ -1126,6 +1317,58 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3xlt16) // c[2, 0-15] CVT_MULRND_CVT32_LT16(c_int32_2p0,selector1,zero_point); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_3xLT16: + { + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S8_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + S8_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,1); + + // c[2:0-15] + S8_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,2); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S32_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + S32_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,1); + + // c[2:0-15] + S32_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,2); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_3xLT16: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + // c[0, 0-15] + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 0-15] + SWISH_S32_AVX512(c_int32_1p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 0-15] + SWISH_S32_AVX512(c_int32_2p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_3xLT16_DISABLE: @@ -1175,7 +1418,9 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2xlt16) &&POST_OPS_GELU_TANH_2xLT16, &&POST_OPS_GELU_ERF_2xLT16, &&POST_OPS_CLIP_2xLT16, - &&POST_OPS_DOWNSCALE_2xLT16 + &&POST_OPS_DOWNSCALE_2xLT16, + &&POST_OPS_MATRIX_ADD_2xLT16, + &&POST_OPS_SWISH_2xLT16 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -1270,11 +1515,11 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2xlt16) __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); // c[0,0-15] - S32_S32_BETA_OP_NLT16F_MASK(load_mask, c_int32_0p0, 0, 0, 0, \ + S32_S32_BETA_OP_NLT16F_MASK(c, load_mask, c_int32_0p0, 0, 0, 0, \ selector1, selector2); // c[1,0-15] - S32_S32_BETA_OP_NLT16F_MASK(load_mask, c_int32_1p0, 0, 1, 0, \ + S32_S32_BETA_OP_NLT16F_MASK(c, load_mask, c_int32_1p0, 0, 1, 0, \ selector1, selector2); } } @@ -1366,23 +1611,41 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2xlt16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_2xLT16: { // Typecast without data modification, safe operation. __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); - selector1 = _mm512_maskz_loadu_epi32 - ( - load_mask, - ( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j ) - ); - __m128i zero_point = _mm_maskz_loadu_epi8 - ( - load_mask, - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j ) - ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = _mm512_maskz_loadu_epi32 + ( + load_mask, + ( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j ) + ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point = _mm512_castsi512_si128( _mm512_setzero_si512() ); + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point = _mm_maskz_loadu_epi8 + ( + load_mask, + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ) + ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } // c[0, 0-15] CVT_MULRND_CVT32_LT16(c_int32_0p0,selector1,zero_point); @@ -1390,6 +1653,49 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2xlt16) // c[1, 0-15] CVT_MULRND_CVT32_LT16(c_int32_1p0,selector1,zero_point); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_2xLT16: + { + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S8_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + S8_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,1); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S32_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + S32_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,1); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_2xLT16: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + // c[0, 0-15] + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 0-15] + SWISH_S32_AVX512(c_int32_1p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_2xLT16_DISABLE: @@ -1433,7 +1739,9 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1xlt16) &&POST_OPS_GELU_TANH_1xLT16, &&POST_OPS_GELU_ERF_1xLT16, &&POST_OPS_CLIP_1xLT16, - &&POST_OPS_DOWNSCALE_1xLT16 + &&POST_OPS_DOWNSCALE_1xLT16, + &&POST_OPS_MATRIX_ADD_1xLT16, + &&POST_OPS_SWISH_1xLT16 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -1501,7 +1809,7 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1xlt16) __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); // c[0,0-15] - S32_S32_BETA_OP_NLT16F_MASK(load_mask, c_int32_0p0, 0, 0, 0, \ + S32_S32_BETA_OP_NLT16F_MASK(c, load_mask, c_int32_0p0, 0, 0, 0, \ selector1, selector2); } } @@ -1575,27 +1883,79 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1xlt16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_1xLT16: { // Typecast without data modification, safe operation. __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); - selector1 = _mm512_maskz_loadu_epi32 - ( - load_mask, - ( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j ) - ); - __m128i zero_point = _mm_maskz_loadu_epi8 - ( - load_mask, - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j ) - ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = _mm512_maskz_loadu_epi32 + ( + load_mask, + ( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j ) + ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point = _mm512_castsi512_si128( _mm512_setzero_si512() ); + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point = _mm_maskz_loadu_epi8 + ( + load_mask, + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ) + ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } // c[0, 0-15] CVT_MULRND_CVT32_LT16(c_int32_0p0,selector1,zero_point); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_1xLT16: + { + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S8_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,0); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S32_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,0); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_1xLT16: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + // c[0, 0-15] + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_1xLT16_DISABLE: @@ -1633,7 +1993,9 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x16) &&POST_OPS_GELU_TANH_5x16, &&POST_OPS_GELU_ERF_5x16, &&POST_OPS_CLIP_5x16, - &&POST_OPS_DOWNSCALE_5x16 + &&POST_OPS_DOWNSCALE_5x16, + &&POST_OPS_MATRIX_ADD_5x16, + &&POST_OPS_SWISH_5x16 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -1952,16 +2314,33 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_5x16: { - selector1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - __m128i zero_point0 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point0 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } // c[0, 0-15] CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); @@ -1980,15 +2359,84 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } -POST_OPS_5x16_DISABLE: - ; - - if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_last_k == TRUE ) ) +POST_OPS_MATRIX_ADD_5x16: { - // Generate a mask16 of all 1's. - selector1 = _mm512_setzero_epi32(); - selector2 = _mm512_set1_epi32( 10 ); - __mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector1, selector2 ); + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S8_S32_MATRIX_ADD_1COL(selector1,0); + + // c[1:0-15] + S8_S32_MATRIX_ADD_1COL(selector1,1); + + // c[2:0-15] + S8_S32_MATRIX_ADD_1COL(selector1,2); + + // c[3:0-15] + S8_S32_MATRIX_ADD_1COL(selector1,3); + + // c[4:0-15] + S8_S32_MATRIX_ADD_1COL(selector1,4); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S32_S32_MATRIX_ADD_1COL(selector1,0); + + // c[1:0-15] + S32_S32_MATRIX_ADD_1COL(selector1,1); + + // c[2:0-15] + S32_S32_MATRIX_ADD_1COL(selector1,2); + + // c[3:0-15] + S32_S32_MATRIX_ADD_1COL(selector1,3); + + // c[4:0-15] + S32_S32_MATRIX_ADD_1COL(selector1,4); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_5x16: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + // c[0, 0-15] + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 0-15] + SWISH_S32_AVX512(c_int32_1p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 0-15] + SWISH_S32_AVX512(c_int32_2p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[3, 0-15] + SWISH_S32_AVX512(c_int32_3p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[4, 0-15] + SWISH_S32_AVX512(c_int32_4p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_5x16_DISABLE: + ; + + if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_last_k == TRUE ) ) + { + // Generate a mask16 of all 1's. + selector1 = _mm512_setzero_epi32(); + selector2 = _mm512_set1_epi32( 10 ); + __mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector1, selector2 ); // Store the results in downscaled type (int8 instead of int32). // c[0,0-15] @@ -2038,7 +2486,9 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4x16) &&POST_OPS_GELU_TANH_4x16, &&POST_OPS_GELU_ERF_4x16, &&POST_OPS_CLIP_4x16, - &&POST_OPS_DOWNSCALE_4x16 + &&POST_OPS_DOWNSCALE_4x16, + &&POST_OPS_MATRIX_ADD_4x16, + &&POST_OPS_SWISH_4x16 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -2310,16 +2760,33 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4x16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_4x16: { - selector1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - __m128i zero_point0 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point0 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } // c[0, 0-15] CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); @@ -2333,6 +2800,66 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4x16) // c[3, 0-15] CVT_MULRND_CVT32(c_int32_3p0,selector1,zero_point0); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_4x16: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S8_S32_MATRIX_ADD_1COL(selector1,0); + + // c[1:0-15] + S8_S32_MATRIX_ADD_1COL(selector1,1); + + // c[2:0-15] + S8_S32_MATRIX_ADD_1COL(selector1,2); + + // c[3:0-15] + S8_S32_MATRIX_ADD_1COL(selector1,3); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S32_S32_MATRIX_ADD_1COL(selector1,0); + + // c[1:0-15] + S32_S32_MATRIX_ADD_1COL(selector1,1); + + // c[2:0-15] + S32_S32_MATRIX_ADD_1COL(selector1,2); + + // c[3:0-15] + S32_S32_MATRIX_ADD_1COL(selector1,3); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_4x16: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + // c[0, 0-15] + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 0-15] + SWISH_S32_AVX512(c_int32_1p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 0-15] + SWISH_S32_AVX512(c_int32_2p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[3, 0-15] + SWISH_S32_AVX512(c_int32_3p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_4x16_DISABLE: @@ -2387,7 +2914,9 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x16) &&POST_OPS_GELU_TANH_3x16, &&POST_OPS_GELU_ERF_3x16, &&POST_OPS_CLIP_3x16, - &&POST_OPS_DOWNSCALE_3x16 + &&POST_OPS_DOWNSCALE_3x16, + &&POST_OPS_MATRIX_ADD_3x16, + &&POST_OPS_SWISH_3x16 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -2612,16 +3141,33 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_3x16: { - selector1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - __m128i zero_point0 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point0 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } // c[0, 0-15] CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); @@ -2632,6 +3178,57 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x16) // c[2, 0-15] CVT_MULRND_CVT32(c_int32_2p0,selector1,zero_point0); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_3x16: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S8_S32_MATRIX_ADD_1COL(selector1,0); + + // c[1:0-15] + S8_S32_MATRIX_ADD_1COL(selector1,1); + + // c[2:0-15] + S8_S32_MATRIX_ADD_1COL(selector1,2); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S32_S32_MATRIX_ADD_1COL(selector1,0); + + // c[1:0-15] + S32_S32_MATRIX_ADD_1COL(selector1,1); + + // c[2:0-15] + S32_S32_MATRIX_ADD_1COL(selector1,2); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_3x16: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + // c[0, 0-15] + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 0-15] + SWISH_S32_AVX512(c_int32_1p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 0-15] + SWISH_S32_AVX512(c_int32_2p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_3x16_DISABLE: @@ -2680,7 +3277,9 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x16) &&POST_OPS_GELU_TANH_2x16, &&POST_OPS_GELU_ERF_2x16, &&POST_OPS_CLIP_2x16, - &&POST_OPS_DOWNSCALE_2x16 + &&POST_OPS_DOWNSCALE_2x16, + &&POST_OPS_MATRIX_ADD_2x16, + &&POST_OPS_SWISH_2x16 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -2858,16 +3457,33 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_2x16: { - selector1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - __m128i zero_point0 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point0 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } // c[0, 0-15] CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); @@ -2875,6 +3491,48 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x16) // c[1, 0-15] CVT_MULRND_CVT32(c_int32_1p0,selector1,zero_point0); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_2x16: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S8_S32_MATRIX_ADD_1COL(selector1,0); + + // c[1:0-15] + S8_S32_MATRIX_ADD_1COL(selector1,1); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S32_S32_MATRIX_ADD_1COL(selector1,0); + + // c[1:0-15] + S32_S32_MATRIX_ADD_1COL(selector1,1); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_2x16: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + // c[0, 0-15] + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 0-15] + SWISH_S32_AVX512(c_int32_1p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_2x16_DISABLE: @@ -2917,7 +3575,9 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x16) &&POST_OPS_GELU_TANH_1x16, &&POST_OPS_GELU_ERF_1x16, &&POST_OPS_CLIP_1x16, - &&POST_OPS_DOWNSCALE_1x16 + &&POST_OPS_DOWNSCALE_1x16, + &&POST_OPS_MATRIX_ADD_1x16, + &&POST_OPS_SWISH_1x16 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -3048,20 +3708,70 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_1x16: { - selector1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - __m128i zero_point0 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point0 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } // c[0, 0-15] CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_1x16: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S8_S32_MATRIX_ADD_1COL(selector1,0); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S32_S32_MATRIX_ADD_1COL(selector1,0); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_1x16: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + // c[0, 0-15] + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_1x16_DISABLE: @@ -3098,17 +3808,19 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x32) &&POST_OPS_GELU_TANH_5x32, &&POST_OPS_GELU_ERF_5x32, &&POST_OPS_CLIP_5x32, - &&POST_OPS_DOWNSCALE_5x32 + &&POST_OPS_DOWNSCALE_5x32, + &&POST_OPS_MATRIX_ADD_5x32, + &&POST_OPS_SWISH_5x32 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; // B matrix storage. - __m512i b0; - __m512i b1; + __m512i b0 = _mm512_setzero_epi32(); + __m512i b1 = _mm512_setzero_epi32(); // A matrix storage. - __m512i a_int32_0; + __m512i a_int32_0 = _mm512_setzero_epi32(); // Registers to use for accumulating C. __m512i c_int32_0p0 = _mm512_setzero_epi32(); @@ -3539,23 +4251,44 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x32) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_5x32: { - selector1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ); - __m128i zero_point0 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); - __m128i zero_point1 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point0 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point1 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point1 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } // c[0, 0-15] CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); @@ -3587,6 +4320,90 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x32) // c[4, 16-31] CVT_MULRND_CVT32(c_int32_4p1,selector2,zero_point1); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_5x32: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + S8_S32_MATRIX_ADD_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + S8_S32_MATRIX_ADD_2COL(selector1,selector2,1); + + // c[2:0-15,16-31] + S8_S32_MATRIX_ADD_2COL(selector1,selector2,2); + + // c[3:0-15,16-31] + S8_S32_MATRIX_ADD_2COL(selector1,selector2,3); + + // c[4:0-15,16-31] + S8_S32_MATRIX_ADD_2COL(selector1,selector2,4); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + S32_S32_MATRIX_ADD_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + S32_S32_MATRIX_ADD_2COL(selector1,selector2,1); + + // c[2:0-15,16-31] + S32_S32_MATRIX_ADD_2COL(selector1,selector2,2); + + // c[3:0-15,16-31] + S32_S32_MATRIX_ADD_2COL(selector1,selector2,3); + + // c[4:0-15,16-31] + S32_S32_MATRIX_ADD_2COL(selector1,selector2,4); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_5x32: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + // c[0, 0-15] + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 16-31] + SWISH_S32_AVX512(c_int32_0p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 0-15] + SWISH_S32_AVX512(c_int32_1p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 16-31] + SWISH_S32_AVX512(c_int32_1p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 0-15] + SWISH_S32_AVX512(c_int32_2p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 16-31] + SWISH_S32_AVX512(c_int32_2p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[3, 0-15] + SWISH_S32_AVX512(c_int32_3p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[3, 16-31] + SWISH_S32_AVX512(c_int32_3p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[4, 0-15] + SWISH_S32_AVX512(c_int32_4p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[4, 16-31] + SWISH_S32_AVX512(c_int32_4p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_5x32_DISABLE: @@ -3677,17 +4494,19 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4x32) &&POST_OPS_GELU_TANH_4x32, &&POST_OPS_GELU_ERF_4x32, &&POST_OPS_CLIP_4x32, - &&POST_OPS_DOWNSCALE_4x32 + &&POST_OPS_DOWNSCALE_4x32, + &&POST_OPS_MATRIX_ADD_4x32, + &&POST_OPS_SWISH_4x32 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; // B matrix storage. - __m512i b0; - __m512i b1; + __m512i b0 = _mm512_setzero_epi32(); + __m512i b1 = _mm512_setzero_epi32(); // A matrix storage. - __m512i a_int32_0; + __m512i a_int32_0 = _mm512_setzero_epi32(); // Registers to use for accumulating C. __m512i c_int32_0p0 = _mm512_setzero_epi32(); @@ -4041,55 +4860,148 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4x32) // c[2, 16-31] CLIP_S32_AVX512(c_int32_2p1, min, max) - // c[3, 0-15] - CLIP_S32_AVX512(c_int32_3p0, min, max) + // c[3, 0-15] + CLIP_S32_AVX512(c_int32_3p0, min, max) + + // c[3, 16-31] + CLIP_S32_AVX512(c_int32_3p1, min, max) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_4x32: + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point0 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point1 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point1 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } + + // c[0, 0-15] + CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); + + // c[0, 16-31] + CVT_MULRND_CVT32(c_int32_0p1,selector2,zero_point1); + + // c[1, 0-15] + CVT_MULRND_CVT32(c_int32_1p0,selector1,zero_point0); + + // c[1, 16-31] + CVT_MULRND_CVT32(c_int32_1p1,selector2,zero_point1); + + // c[2, 0-15] + CVT_MULRND_CVT32(c_int32_2p0,selector1,zero_point0); + + // c[2, 16-31] + CVT_MULRND_CVT32(c_int32_2p1,selector2,zero_point1); + + // c[3, 0-15] + CVT_MULRND_CVT32(c_int32_3p0,selector1,zero_point0); + + // c[3, 16-31] + CVT_MULRND_CVT32(c_int32_3p1,selector2,zero_point1); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_4x32: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + S8_S32_MATRIX_ADD_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + S8_S32_MATRIX_ADD_2COL(selector1,selector2,1); + + // c[2:0-15,16-31] + S8_S32_MATRIX_ADD_2COL(selector1,selector2,2); + + // c[3:0-15,16-31] + S8_S32_MATRIX_ADD_2COL(selector1,selector2,3); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + S32_S32_MATRIX_ADD_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + S32_S32_MATRIX_ADD_2COL(selector1,selector2,1); + + // c[2:0-15,16-31] + S32_S32_MATRIX_ADD_2COL(selector1,selector2,2); - // c[3, 16-31] - CLIP_S32_AVX512(c_int32_3p1, min, max) + // c[3:0-15,16-31] + S32_S32_MATRIX_ADD_2COL(selector1,selector2,3); + } POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - -POST_OPS_DOWNSCALE_4x32: +POST_OPS_SWISH_4x32: { selector1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ); - __m128i zero_point0 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); - __m128i zero_point1 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; // c[0, 0-15] - CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); // c[0, 16-31] - CVT_MULRND_CVT32(c_int32_0p1,selector2,zero_point1); + SWISH_S32_AVX512(c_int32_0p1, fl_reg, al, al_in, r, r2, z, dn, selector2); // c[1, 0-15] - CVT_MULRND_CVT32(c_int32_1p0,selector1,zero_point0); + SWISH_S32_AVX512(c_int32_1p0, fl_reg, al, al_in, r, r2, z, dn, selector2); // c[1, 16-31] - CVT_MULRND_CVT32(c_int32_1p1,selector2,zero_point1); + SWISH_S32_AVX512(c_int32_1p1, fl_reg, al, al_in, r, r2, z, dn, selector2); // c[2, 0-15] - CVT_MULRND_CVT32(c_int32_2p0,selector1,zero_point0); + SWISH_S32_AVX512(c_int32_2p0, fl_reg, al, al_in, r, r2, z, dn, selector2); // c[2, 16-31] - CVT_MULRND_CVT32(c_int32_2p1,selector2,zero_point1); + SWISH_S32_AVX512(c_int32_2p1, fl_reg, al, al_in, r, r2, z, dn, selector2); // c[3, 0-15] - CVT_MULRND_CVT32(c_int32_3p0,selector1,zero_point0); + SWISH_S32_AVX512(c_int32_3p0, fl_reg, al, al_in, r, r2, z, dn, selector2); // c[3, 16-31] - CVT_MULRND_CVT32(c_int32_3p1,selector2,zero_point1); + SWISH_S32_AVX512(c_int32_3p1, fl_reg, al, al_in, r, r2, z, dn, selector2); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -4169,17 +5081,19 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x32) &&POST_OPS_GELU_TANH_3x32, &&POST_OPS_GELU_ERF_3x32, &&POST_OPS_CLIP_3x32, - &&POST_OPS_DOWNSCALE_3x32 + &&POST_OPS_DOWNSCALE_3x32, + &&POST_OPS_MATRIX_ADD_3x32, + &&POST_OPS_SWISH_3x32 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; // B matrix storage. - __m512i b0; - __m512i b1; + __m512i b0 = _mm512_setzero_epi32(); + __m512i b1 = _mm512_setzero_epi32(); // A matrix storage. - __m512i a_int32_0; + __m512i a_int32_0 = _mm512_setzero_epi32(); // Registers to use for accumulating C. __m512i c_int32_0p0 = _mm512_setzero_epi32(); @@ -4472,23 +5386,44 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x32) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_3x32: { - selector1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ); - __m128i zero_point0 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); - __m128i zero_point1 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point0 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point1 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point1 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } // c[0, 0-15] CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); @@ -4508,6 +5443,66 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x32) // c[2, 16-31] CVT_MULRND_CVT32(c_int32_2p1,selector2,zero_point1); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_3x32: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + S8_S32_MATRIX_ADD_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + S8_S32_MATRIX_ADD_2COL(selector1,selector2,1); + + // c[2:0-15,16-31] + S8_S32_MATRIX_ADD_2COL(selector1,selector2,2); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + S32_S32_MATRIX_ADD_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + S32_S32_MATRIX_ADD_2COL(selector1,selector2,1); + + // c[2:0-15,16-31] + S32_S32_MATRIX_ADD_2COL(selector1,selector2,2); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_3x32: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + // c[0, 0-15] + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 16-31] + SWISH_S32_AVX512(c_int32_0p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 0-15] + SWISH_S32_AVX512(c_int32_1p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 16-31] + SWISH_S32_AVX512(c_int32_1p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 0-15] + SWISH_S32_AVX512(c_int32_2p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 16-31] + SWISH_S32_AVX512(c_int32_2p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_3x32_DISABLE: @@ -4574,17 +5569,19 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x32) &&POST_OPS_GELU_TANH_2x32, &&POST_OPS_GELU_ERF_2x32, &&POST_OPS_CLIP_2x32, - &&POST_OPS_DOWNSCALE_2x32 + &&POST_OPS_DOWNSCALE_2x32, + &&POST_OPS_MATRIX_ADD_2x32, + &&POST_OPS_SWISH_2x32 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; // B matrix storage. - __m512i b0; - __m512i b1; + __m512i b0 = _mm512_setzero_epi32(); + __m512i b1 = _mm512_setzero_epi32(); // A matrix storage. - __m512i a_int32_0; + __m512i a_int32_0 = _mm512_setzero_epi32(); // Registers to use for accumulating C. __m512i c_int32_0p0 = _mm512_setzero_epi32(); @@ -4808,23 +5805,44 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x32) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_2x32: { - selector1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ); - __m128i zero_point0 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); - __m128i zero_point1 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point0 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point1 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point1 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } // c[0, 0-15] CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); @@ -4838,6 +5856,54 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x32) // c[1, 16-31] CVT_MULRND_CVT32(c_int32_1p1,selector2,zero_point1); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_2x32: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + S8_S32_MATRIX_ADD_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + S8_S32_MATRIX_ADD_2COL(selector1,selector2,1); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + S32_S32_MATRIX_ADD_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + S32_S32_MATRIX_ADD_2COL(selector1,selector2,1); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_2x32: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + // c[0, 0-15] + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 16-31] + SWISH_S32_AVX512(c_int32_0p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 0-15] + SWISH_S32_AVX512(c_int32_1p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 16-31] + SWISH_S32_AVX512(c_int32_1p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_2x32_DISABLE: @@ -4892,17 +5958,19 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x32) &&POST_OPS_GELU_TANH_1x32, &&POST_OPS_GELU_ERF_1x32, &&POST_OPS_CLIP_1x32, - &&POST_OPS_DOWNSCALE_1x32 + &&POST_OPS_DOWNSCALE_1x32, + &&POST_OPS_MATRIX_ADD_1x32, + &&POST_OPS_SWISH_1x32 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; // B matrix storage. - __m512i b0; - __m512i b1; + __m512i b0 = _mm512_setzero_epi32(); + __m512i b1 = _mm512_setzero_epi32(); // A matrix storage. - __m512i a_int32_0; + __m512i a_int32_0 = _mm512_setzero_epi32(); // Registers to use for accumulating C. __m512i c_int32_0p0 = _mm512_setzero_epi32(); @@ -5057,23 +6125,44 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x32) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_1x32: { - selector1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ); - __m128i zero_point0 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); - __m128i zero_point1 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point0 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point1 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point1 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } // c[0, 0-15] CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); @@ -5081,6 +6170,42 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x32) // c[0, 16-31] CVT_MULRND_CVT32(c_int32_0p1,selector2,zero_point1); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_1x32: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + S8_S32_MATRIX_ADD_2COL(selector1,selector2,0); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + S32_S32_MATRIX_ADD_2COL(selector1,selector2,0); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_1x32: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + // c[0, 0-15] + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 16-31] + SWISH_S32_AVX512(c_int32_0p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_1x32_DISABLE: @@ -5123,18 +6248,20 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x48) &&POST_OPS_GELU_TANH_5x48, &&POST_OPS_GELU_ERF_5x48, &&POST_OPS_CLIP_5x48, - &&POST_OPS_DOWNSCALE_5x48 + &&POST_OPS_DOWNSCALE_5x48, + &&POST_OPS_MATRIX_ADD_5x48, + &&POST_OPS_SWISH_5x48 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; // B matrix storage. - __m512i b0; - __m512i b1; - __m512i b2; + __m512i b0 = _mm512_setzero_epi32(); + __m512i b1 = _mm512_setzero_epi32(); + __m512i b2 = _mm512_setzero_epi32(); // A matrix storage. - __m512i a_int32_0; + __m512i a_int32_0 = _mm512_setzero_epi32(); // Registers to use for accumulating C. __m512i c_int32_0p0 = _mm512_setzero_epi32(); @@ -5680,75 +6807,199 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x48) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_5x48: + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + a_int32_0 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point0 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point1 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point2 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + zero_point2 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point1 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point2 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } + + // c[0, 0-15] + CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); + + // c[0, 16-31] + CVT_MULRND_CVT32(c_int32_0p1,selector2,zero_point1); + + // c[0, 32-47] + CVT_MULRND_CVT32(c_int32_0p2,a_int32_0,zero_point2); + + // c[1, 0-15] + CVT_MULRND_CVT32(c_int32_1p0,selector1,zero_point0); + + // c[1, 16-31] + CVT_MULRND_CVT32(c_int32_1p1,selector2,zero_point1); + + // c[1, 32-47] + CVT_MULRND_CVT32(c_int32_1p2,a_int32_0,zero_point2); + + // c[2, 0-15] + CVT_MULRND_CVT32(c_int32_2p0,selector1,zero_point0); + + // c[2, 16-31] + CVT_MULRND_CVT32(c_int32_2p1,selector2,zero_point1); + + // c[2, 32-47] + CVT_MULRND_CVT32(c_int32_2p2,a_int32_0,zero_point2); + + // c[3, 0-15] + CVT_MULRND_CVT32(c_int32_3p0,selector1,zero_point0); + + // c[3, 16-31] + CVT_MULRND_CVT32(c_int32_3p1,selector2,zero_point1); + + // c[3, 32-47] + CVT_MULRND_CVT32(c_int32_3p2,a_int32_0,zero_point2); + + // c[4, 0-15] + CVT_MULRND_CVT32(c_int32_4p0,selector1,zero_point0); + + // c[4, 16-31] + CVT_MULRND_CVT32(c_int32_4p1,selector2,zero_point1); + + // c[4, 32-47] + CVT_MULRND_CVT32(c_int32_4p2,a_int32_0,zero_point2); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_5x48: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + S8_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,0); + + // c[1:0-15,16-31,32-47] + S8_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,1); + + // c[2:0-15,16-31,32-47] + S8_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,2); + + // c[3:0-15,16-31,32-47] + S8_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,3); + + // c[4:0-15,16-31,32-47] + S8_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,4); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + S32_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,0); + + // c[1:0-15,16-31,32-47] + S32_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,1); + + // c[2:0-15,16-31,32-47] + S32_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,2); + + // c[3:0-15,16-31,32-47] + S32_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,3); + + // c[4:0-15,16-31,32-47] + S32_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,4); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_5x48: { selector1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ); - a_int32_0 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 2 * 16 ) ); - __m128i zero_point0 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); - __m128i zero_point1 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); - __m128i zero_point2 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; // c[0, 0-15] - CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); // c[0, 16-31] - CVT_MULRND_CVT32(c_int32_0p1,selector2,zero_point1); + SWISH_S32_AVX512(c_int32_0p1, fl_reg, al, al_in, r, r2, z, dn, selector2); // c[0, 32-47] - CVT_MULRND_CVT32(c_int32_0p2,a_int32_0,zero_point2); + SWISH_S32_AVX512(c_int32_0p2, fl_reg, al, al_in, r, r2, z, dn, selector2); // c[1, 0-15] - CVT_MULRND_CVT32(c_int32_1p0,selector1,zero_point0); + SWISH_S32_AVX512(c_int32_1p0, fl_reg, al, al_in, r, r2, z, dn, selector2); // c[1, 16-31] - CVT_MULRND_CVT32(c_int32_1p1,selector2,zero_point1); + SWISH_S32_AVX512(c_int32_1p1, fl_reg, al, al_in, r, r2, z, dn, selector2); // c[1, 32-47] - CVT_MULRND_CVT32(c_int32_1p2,a_int32_0,zero_point2); + SWISH_S32_AVX512(c_int32_1p2, fl_reg, al, al_in, r, r2, z, dn, selector2); // c[2, 0-15] - CVT_MULRND_CVT32(c_int32_2p0,selector1,zero_point0); + SWISH_S32_AVX512(c_int32_2p0, fl_reg, al, al_in, r, r2, z, dn, selector2); // c[2, 16-31] - CVT_MULRND_CVT32(c_int32_2p1,selector2,zero_point1); + SWISH_S32_AVX512(c_int32_2p1, fl_reg, al, al_in, r, r2, z, dn, selector2); // c[2, 32-47] - CVT_MULRND_CVT32(c_int32_2p2,a_int32_0,zero_point2); + SWISH_S32_AVX512(c_int32_2p2, fl_reg, al, al_in, r, r2, z, dn, selector2); // c[3, 0-15] - CVT_MULRND_CVT32(c_int32_3p0,selector1,zero_point0); + SWISH_S32_AVX512(c_int32_3p0, fl_reg, al, al_in, r, r2, z, dn, selector2); // c[3, 16-31] - CVT_MULRND_CVT32(c_int32_3p1,selector2,zero_point1); + SWISH_S32_AVX512(c_int32_3p1, fl_reg, al, al_in, r, r2, z, dn, selector2); // c[3, 32-47] - CVT_MULRND_CVT32(c_int32_3p2,a_int32_0,zero_point2); + SWISH_S32_AVX512(c_int32_3p2, fl_reg, al, al_in, r, r2, z, dn, selector2); // c[4, 0-15] - CVT_MULRND_CVT32(c_int32_4p0,selector1,zero_point0); + SWISH_S32_AVX512(c_int32_4p0, fl_reg, al, al_in, r, r2, z, dn, selector2); // c[4, 16-31] - CVT_MULRND_CVT32(c_int32_4p1,selector2,zero_point1); + SWISH_S32_AVX512(c_int32_4p1, fl_reg, al, al_in, r, r2, z, dn, selector2); // c[4, 32-47] - CVT_MULRND_CVT32(c_int32_4p2,a_int32_0,zero_point2); + SWISH_S32_AVX512(c_int32_4p2, fl_reg, al, al_in, r, r2, z, dn, selector2); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -5870,18 +7121,20 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4x48) &&POST_OPS_GELU_TANH_4x48, &&POST_OPS_GELU_ERF_4x48, &&POST_OPS_CLIP_4x48, - &&POST_OPS_DOWNSCALE_4x48 + &&POST_OPS_DOWNSCALE_4x48, + &&POST_OPS_MATRIX_ADD_4x48, + &&POST_OPS_SWISH_4x48 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; // B matrix storage. - __m512i b0; - __m512i b1; - __m512i b2; + __m512i b0 = _mm512_setzero_epi32(); + __m512i b1 = _mm512_setzero_epi32(); + __m512i b2 = _mm512_setzero_epi32(); // A matrix storage. - __m512i a_int32_0; + __m512i a_int32_0 = _mm512_setzero_epi32(); // Registers to use for accumulating C. __m512i c_int32_0p0 = _mm512_setzero_epi32(); @@ -6336,30 +7589,55 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4x48) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_4x48: { - selector1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ); - a_int32_0 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 2 * 16 ) ); - __m128i zero_point0 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); - __m128i zero_point1 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); - __m128i zero_point2 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + a_int32_0 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point0 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point1 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point2 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + zero_point2 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point1 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point2 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } // c[0, 0-15] CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); @@ -6397,6 +7675,90 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4x48) // c[3, 32-47] CVT_MULRND_CVT32(c_int32_3p2,a_int32_0,zero_point2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_4x48: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + S8_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,0); + + // c[1:0-15,16-31,32-47] + S8_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,1); + + // c[2:0-15,16-31,32-47] + S8_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,2); + + // c[3:0-15,16-31,32-47] + S8_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,3); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + S32_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,0); + + // c[1:0-15,16-31,32-47] + S32_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,1); + + // c[2:0-15,16-31,32-47] + S32_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,2); + + // c[3:0-15,16-31,32-47] + S32_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,3); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_4x48: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + // c[0, 0-15] + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 16-31] + SWISH_S32_AVX512(c_int32_0p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 32-47] + SWISH_S32_AVX512(c_int32_0p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 0-15] + SWISH_S32_AVX512(c_int32_1p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 16-31] + SWISH_S32_AVX512(c_int32_1p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 32-47] + SWISH_S32_AVX512(c_int32_1p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 0-15] + SWISH_S32_AVX512(c_int32_2p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 16-31] + SWISH_S32_AVX512(c_int32_2p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 32-47] + SWISH_S32_AVX512(c_int32_2p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[3, 0-15] + SWISH_S32_AVX512(c_int32_3p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[3, 16-31] + SWISH_S32_AVX512(c_int32_3p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[3, 32-47] + SWISH_S32_AVX512(c_int32_3p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_4x48_DISABLE: @@ -6499,18 +7861,20 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x48) &&POST_OPS_GELU_TANH_3x48, &&POST_OPS_GELU_ERF_3x48, &&POST_OPS_CLIP_3x48, - &&POST_OPS_DOWNSCALE_3x48 + &&POST_OPS_DOWNSCALE_3x48, + &&POST_OPS_MATRIX_ADD_3x48, + &&POST_OPS_SWISH_3x48 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; // B matrix storage. - __m512i b0; - __m512i b1; - __m512i b2; + __m512i b0 = _mm512_setzero_epi32(); + __m512i b1 = _mm512_setzero_epi32(); + __m512i b2 = _mm512_setzero_epi32(); // A matrix storage. - __m512i a_int32_0; + __m512i a_int32_0 = _mm512_setzero_epi32(); // Registers to use for accumulating C. __m512i c_int32_0p0 = _mm512_setzero_epi32(); @@ -6874,30 +8238,55 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x48) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_3x48: { - selector1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ); - a_int32_0 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 2 * 16 ) ); - __m128i zero_point0 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); - __m128i zero_point1 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); - __m128i zero_point2 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + a_int32_0 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point0 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point1 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point2 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + zero_point2 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point1 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point2 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } // c[0, 0-15] CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); @@ -6926,6 +8315,75 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x48) // c[2, 32-47] CVT_MULRND_CVT32(c_int32_2p2,a_int32_0,zero_point2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_3x48: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + S8_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,0); + + // c[1:0-15,16-31,32-47] + S8_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,1); + + // c[2:0-15,16-31,32-47] + S8_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,2); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + S32_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,0); + + // c[1:0-15,16-31,32-47] + S32_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,1); + + // c[2:0-15,16-31,32-47] + S32_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,2); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_3x48: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + // c[0, 0-15] + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 16-31] + SWISH_S32_AVX512(c_int32_0p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 32-47] + SWISH_S32_AVX512(c_int32_0p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 0-15] + SWISH_S32_AVX512(c_int32_1p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 16-31] + SWISH_S32_AVX512(c_int32_1p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 32-47] + SWISH_S32_AVX512(c_int32_1p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 0-15] + SWISH_S32_AVX512(c_int32_2p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 16-31] + SWISH_S32_AVX512(c_int32_2p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 32-47] + SWISH_S32_AVX512(c_int32_2p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_3x48_DISABLE: @@ -7010,18 +8468,20 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x48) &&POST_OPS_GELU_TANH_2x48, &&POST_OPS_GELU_ERF_2x48, &&POST_OPS_CLIP_2x48, - &&POST_OPS_DOWNSCALE_2x48 + &&POST_OPS_DOWNSCALE_2x48, + &&POST_OPS_MATRIX_ADD_2x48, + &&POST_OPS_SWISH_2x48 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; // B matrix storage. - __m512i b0; - __m512i b1; - __m512i b2; + __m512i b0 = _mm512_setzero_epi32(); + __m512i b1 = _mm512_setzero_epi32(); + __m512i b2 = _mm512_setzero_epi32(); // A matrix storage. - __m512i a_int32_0; + __m512i a_int32_0 = _mm512_setzero_epi32(); // Registers to use for accumulating C. __m512i c_int32_0p0 = _mm512_setzero_epi32(); @@ -7294,30 +8754,55 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x48) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_2x48: { - selector1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ); - a_int32_0 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 2 * 16 ) ); - __m128i zero_point0 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); - __m128i zero_point1 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); - __m128i zero_point2 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + a_int32_0 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point0 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point1 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point2 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + zero_point2 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point1 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point2 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } // c[0, 0-15] CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); @@ -7337,6 +8822,60 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x48) // c[1, 32-47] CVT_MULRND_CVT32(c_int32_1p2,a_int32_0,zero_point2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_2x48: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + S8_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,0); + + // c[1:0-15,16-31,32-47] + S8_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,1); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + S32_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,0); + + // c[1:0-15,16-31,32-47] + S32_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,1); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_2x48: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + // c[0, 0-15] + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 16-31] + SWISH_S32_AVX512(c_int32_0p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 32-47] + SWISH_S32_AVX512(c_int32_0p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 0-15] + SWISH_S32_AVX512(c_int32_1p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 16-31] + SWISH_S32_AVX512(c_int32_1p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 32-47] + SWISH_S32_AVX512(c_int32_1p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_2x48_DISABLE: @@ -7403,18 +8942,20 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x48) &&POST_OPS_GELU_TANH_1x48, &&POST_OPS_GELU_ERF_1x48, &&POST_OPS_CLIP_1x48, - &&POST_OPS_DOWNSCALE_1x48 + &&POST_OPS_DOWNSCALE_1x48, + &&POST_OPS_MATRIX_ADD_1x48, + &&POST_OPS_SWISH_1x48 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; // B matrix storage. - __m512i b0; - __m512i b1; - __m512i b2; + __m512i b0 = _mm512_setzero_epi32(); + __m512i b1 = _mm512_setzero_epi32(); + __m512i b2 = _mm512_setzero_epi32(); // A matrix storage. - __m512i a_int32_0; + __m512i a_int32_0 = _mm512_setzero_epi32(); // Registers to use for accumulating C. __m512i c_int32_0p0 = _mm512_setzero_epi32(); @@ -7596,30 +9137,55 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x48) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_1x48: { - selector1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ); - a_int32_0 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 2 * 16 ) ); - __m128i zero_point0 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); - __m128i zero_point1 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); - __m128i zero_point2 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + a_int32_0 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point0 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point1 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point2 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + zero_point2 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point1 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point2 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } // c[0, 0-15] CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); @@ -7630,6 +9196,45 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x48) // c[0, 32-47] CVT_MULRND_CVT32(c_int32_0p2,a_int32_0,zero_point2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_1x48: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + S8_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,0); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + S32_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,0); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_1x48: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + // c[0, 0-15] + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 16-31] + SWISH_S32_AVX512(c_int32_0p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 32-47] + SWISH_S32_AVX512(c_int32_0p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_1x48_DISABLE: diff --git a/kernels/zen4/lpgemm/u8s8s32/lpgemm_n_extMR_fringe_amd512vnni.c b/kernels/zen4/lpgemm/u8s8s32/lpgemm_n_extMR_fringe_amd512vnni.c index bfe3fb6ce1..8d9c377637 100644 --- a/kernels/zen4/lpgemm/u8s8s32/lpgemm_n_extMR_fringe_amd512vnni.c +++ b/kernels/zen4/lpgemm/u8s8s32/lpgemm_n_extMR_fringe_amd512vnni.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -59,7 +59,9 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_12xlt16) &&POST_OPS_GELU_TANH_12xLT16, &&POST_OPS_GELU_ERF_12xLT16, &&POST_OPS_CLIP_12xLT16, - &&POST_OPS_DOWNSCALE_12xLT16 + &&POST_OPS_DOWNSCALE_12xLT16, + &&POST_OPS_MATRIX_ADD_12xLT16, + &&POST_OPS_SWISH_12xLT16 }; dim_t MR = 12; dim_t m_full_pieces = m0 / MR; @@ -70,21 +72,21 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_12xlt16) dim_t k_partial_pieces = k0 % 4; // B matrix storage. - __m512i b0; + __m512i b0 = _mm512_setzero_epi32(); // A matrix storage. - __m512i a_int32_0; - __m512i a_int32_1; - __m512i a_int32_2; - __m512i a_int32_3; - __m512i a_int32_4; - __m512i a_int32_5; - __m512i a_int32_6; - __m512i a_int32_7; - __m512i a_int32_8; - __m512i a_int32_9; - __m512i a_int32_10; - __m512i a_int32_11; + __m512i a_int32_0 = _mm512_setzero_epi32(); + __m512i a_int32_1 = _mm512_setzero_epi32(); + __m512i a_int32_2 = _mm512_setzero_epi32(); + __m512i a_int32_3 = _mm512_setzero_epi32(); + __m512i a_int32_4 = _mm512_setzero_epi32(); + __m512i a_int32_5 = _mm512_setzero_epi32(); + __m512i a_int32_6 = _mm512_setzero_epi32(); + __m512i a_int32_7 = _mm512_setzero_epi32(); + __m512i a_int32_8 = _mm512_setzero_epi32(); + __m512i a_int32_9 = _mm512_setzero_epi32(); + __m512i a_int32_10 = _mm512_setzero_epi32(); + __m512i a_int32_11 = _mm512_setzero_epi32(); for ( dim_t ir = 0; ir < m_full_pieces_loop_limit; ir += MR ) { @@ -456,51 +458,51 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_12xlt16) __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); // c[0,0-15] - S32_S32_BETA_OP_NLT16F_MASK(load_mask, c_int32_0p0, ir, 0, 0, \ + S32_S32_BETA_OP_NLT16F_MASK(c, load_mask, c_int32_0p0, ir, 0, 0, \ selector1, selector2); // c[1,0-15] - S32_S32_BETA_OP_NLT16F_MASK(load_mask, c_int32_1p0, ir, 1, 0, \ + S32_S32_BETA_OP_NLT16F_MASK(c, load_mask, c_int32_1p0, ir, 1, 0, \ selector1, selector2); // c[2,0-15] - S32_S32_BETA_OP_NLT16F_MASK(load_mask, c_int32_2p0, ir, 2, 0, \ + S32_S32_BETA_OP_NLT16F_MASK(c, load_mask, c_int32_2p0, ir, 2, 0, \ selector1, selector2); // c[3,0-15] - S32_S32_BETA_OP_NLT16F_MASK(load_mask, c_int32_3p0, ir, 3, 0, \ + S32_S32_BETA_OP_NLT16F_MASK(c, load_mask, c_int32_3p0, ir, 3, 0, \ selector1, selector2); // c[4,0-15] - S32_S32_BETA_OP_NLT16F_MASK(load_mask, c_int32_4p0, ir, 4, 0, \ + S32_S32_BETA_OP_NLT16F_MASK(c, load_mask, c_int32_4p0, ir, 4, 0, \ selector1, selector2); // c[5,0-15] - S32_S32_BETA_OP_NLT16F_MASK(load_mask, c_int32_5p0, ir, 5, 0, \ + S32_S32_BETA_OP_NLT16F_MASK(c, load_mask, c_int32_5p0, ir, 5, 0, \ selector1, selector2); // c[6,0-15] - S32_S32_BETA_OP_NLT16F_MASK(load_mask, c_int32_6p0, ir, 6, 0, \ + S32_S32_BETA_OP_NLT16F_MASK(c, load_mask, c_int32_6p0, ir, 6, 0, \ selector1, selector2); // c[7,0-15] - S32_S32_BETA_OP_NLT16F_MASK(load_mask, c_int32_7p0, ir, 7, 0, \ + S32_S32_BETA_OP_NLT16F_MASK(c, load_mask, c_int32_7p0, ir, 7, 0, \ selector1, selector2); // c[8,0-15] - S32_S32_BETA_OP_NLT16F_MASK(load_mask, c_int32_8p0, ir, 8, 0, \ + S32_S32_BETA_OP_NLT16F_MASK(c, load_mask, c_int32_8p0, ir, 8, 0, \ selector1, selector2); // c[9,0-15] - S32_S32_BETA_OP_NLT16F_MASK(load_mask, c_int32_9p0, ir, 9, 0, \ + S32_S32_BETA_OP_NLT16F_MASK(c, load_mask, c_int32_9p0, ir, 9, 0, \ selector1, selector2); // c[10,0-15] - S32_S32_BETA_OP_NLT16F_MASK(load_mask, c_int32_10p0, ir, 10, 0, \ + S32_S32_BETA_OP_NLT16F_MASK(c, load_mask, c_int32_10p0, ir, 10, 0, \ selector1, selector2); // c[11,0-15] - S32_S32_BETA_OP_NLT16F_MASK(load_mask, c_int32_11p0, ir, 11, 0, \ + S32_S32_BETA_OP_NLT16F_MASK(c, load_mask, c_int32_11p0, ir, 11, 0, \ selector1, selector2); } } @@ -772,23 +774,41 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_12xlt16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_12xLT16: { // Typecast without data modification, safe operation. __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); - selector1 = _mm512_maskz_loadu_epi32 - ( - load_mask, - ( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j ) - ); - __m128i zero_point = _mm_maskz_loadu_epi8 - ( - load_mask, - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j ) - ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = _mm512_maskz_loadu_epi32 + ( + load_mask, + ( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j ) + ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point = _mm512_castsi512_si128( _mm512_setzero_si512() ); + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point = _mm_maskz_loadu_epi8 + ( + load_mask, + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ) + ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } // c[0, 0-15] CVT_MULRND_CVT32_LT16(c_int32_0p0,selector1,zero_point); @@ -826,6 +846,139 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_12xlt16) // c[11, 0-15] CVT_MULRND_CVT32_LT16(c_int32_11p0,selector1,zero_point); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_12xLT16: + { + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S8_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + S8_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,1); + + // c[2:0-15] + S8_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,2); + + // c[3:0-15] + S8_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,3); + + // c[4:0-15] + S8_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,4); + + // c[5:0-15] + S8_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,5); + + // c[6:0-15] + S8_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,6); + + // c[7:0-15] + S8_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,7); + + // c[8:0-15] + S8_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,8); + + // c[9:0-15] + S8_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,9); + + // c[10:0-15] + S8_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,10); + + // c[11:0-15] + S8_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,11); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S32_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + S32_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,1); + + // c[2:0-15] + S32_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,2); + + // c[3:0-15] + S32_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,3); + + // c[4:0-15] + S32_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,4); + + // c[5:0-15] + S32_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,5); + + // c[6:0-15] + S32_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,6); + + // c[7:0-15] + S32_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,7); + + // c[8:0-15] + S32_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,8); + + // c[9:0-15] + S32_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,9); + + // c[10:0-15] + S32_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,10); + + // c[11:0-15] + S32_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,11); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_12xLT16: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + // c[0, 0-15] + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 0-15] + SWISH_S32_AVX512(c_int32_1p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 0-15] + SWISH_S32_AVX512(c_int32_2p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[3, 0-15] + SWISH_S32_AVX512(c_int32_3p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[4, 0-15] + SWISH_S32_AVX512(c_int32_4p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[5, 0-15] + SWISH_S32_AVX512(c_int32_5p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[6, 0-15] + SWISH_S32_AVX512(c_int32_6p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[7, 0-15] + SWISH_S32_AVX512(c_int32_7p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[8, 0-15] + SWISH_S32_AVX512(c_int32_8p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[9, 0-15] + SWISH_S32_AVX512(c_int32_9p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[10, 0-15] + SWISH_S32_AVX512(c_int32_10p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[11, 0-15] + SWISH_S32_AVX512(c_int32_11p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_12xLT16_DISABLE: @@ -983,7 +1136,9 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_12x16) &&POST_OPS_GELU_TANH_12x16, &&POST_OPS_GELU_ERF_12x16, &&POST_OPS_CLIP_12x16, - &&POST_OPS_DOWNSCALE_12x16 + &&POST_OPS_DOWNSCALE_12x16, + &&POST_OPS_MATRIX_ADD_12x16, + &&POST_OPS_SWISH_12x16 }; dim_t MR = 12; dim_t m_full_pieces = m0 / MR; @@ -994,21 +1149,21 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_12x16) dim_t k_partial_pieces = k0 % 4; // B matrix storage. - __m512i b0; + __m512i b0 = _mm512_setzero_epi32(); // A matrix storage. - __m512i a_int32_0; - __m512i a_int32_1; - __m512i a_int32_2; - __m512i a_int32_3; - __m512i a_int32_4; - __m512i a_int32_5; - __m512i a_int32_6; - __m512i a_int32_7; - __m512i a_int32_8; - __m512i a_int32_9; - __m512i a_int32_10; - __m512i a_int32_11; + __m512i a_int32_0 = _mm512_setzero_epi32(); + __m512i a_int32_1 = _mm512_setzero_epi32(); + __m512i a_int32_2 = _mm512_setzero_epi32(); + __m512i a_int32_3 = _mm512_setzero_epi32(); + __m512i a_int32_4 = _mm512_setzero_epi32(); + __m512i a_int32_5 = _mm512_setzero_epi32(); + __m512i a_int32_6 = _mm512_setzero_epi32(); + __m512i a_int32_7 = _mm512_setzero_epi32(); + __m512i a_int32_8 = _mm512_setzero_epi32(); + __m512i a_int32_9 = _mm512_setzero_epi32(); + __m512i a_int32_10 = _mm512_setzero_epi32(); + __m512i a_int32_11 = _mm512_setzero_epi32(); for ( dim_t ir = 0; ir < m_full_pieces_loop_limit; ir += MR ) { @@ -1664,16 +1819,33 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_12x16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_12x16: { - selector1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - __m128i zero_point0 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point0 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } // c[0, 0-15] CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); @@ -1713,6 +1885,138 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_12x16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_MATRIX_ADD_12x16: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S8_S32_MATRIX_ADD_1COL(selector1,0); + + // c[1:0-15] + S8_S32_MATRIX_ADD_1COL(selector1,1); + + // c[2:0-15] + S8_S32_MATRIX_ADD_1COL(selector1,2); + + // c[3:0-15] + S8_S32_MATRIX_ADD_1COL(selector1,3); + + // c[4:0-15] + S8_S32_MATRIX_ADD_1COL(selector1,4); + + // c[5:0-15] + S8_S32_MATRIX_ADD_1COL(selector1,5); + + // c[6:0-15] + S8_S32_MATRIX_ADD_1COL(selector1,6); + + // c[7:0-15] + S8_S32_MATRIX_ADD_1COL(selector1,7); + + // c[8:0-15] + S8_S32_MATRIX_ADD_1COL(selector1,8); + + // c[9:0-15] + S8_S32_MATRIX_ADD_1COL(selector1,9); + + // c[10:0-15] + S8_S32_MATRIX_ADD_1COL(selector1,10); + + // c[11:0-15] + S8_S32_MATRIX_ADD_1COL(selector1,11); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S32_S32_MATRIX_ADD_1COL(selector1,0); + + // c[1:0-15] + S32_S32_MATRIX_ADD_1COL(selector1,1); + + // c[2:0-15] + S32_S32_MATRIX_ADD_1COL(selector1,2); + + // c[3:0-15] + S32_S32_MATRIX_ADD_1COL(selector1,3); + + // c[4:0-15] + S32_S32_MATRIX_ADD_1COL(selector1,4); + + // c[5:0-15] + S32_S32_MATRIX_ADD_1COL(selector1,5); + + // c[6:0-15] + S32_S32_MATRIX_ADD_1COL(selector1,6); + + // c[7:0-15] + S32_S32_MATRIX_ADD_1COL(selector1,7); + + // c[8:0-15] + S32_S32_MATRIX_ADD_1COL(selector1,8); + + // c[9:0-15] + S32_S32_MATRIX_ADD_1COL(selector1,9); + + // c[10:0-15] + S32_S32_MATRIX_ADD_1COL(selector1,10); + + // c[11:0-15] + S32_S32_MATRIX_ADD_1COL(selector1,11); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_12x16: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + // c[0, 0-15] + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 0-15] + SWISH_S32_AVX512(c_int32_1p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 0-15] + SWISH_S32_AVX512(c_int32_2p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[3, 0-15] + SWISH_S32_AVX512(c_int32_3p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[4, 0-15] + SWISH_S32_AVX512(c_int32_4p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[5, 0-15] + SWISH_S32_AVX512(c_int32_5p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[6, 0-15] + SWISH_S32_AVX512(c_int32_6p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[7, 0-15] + SWISH_S32_AVX512(c_int32_7p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[8, 0-15] + SWISH_S32_AVX512(c_int32_8p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[9, 0-15] + SWISH_S32_AVX512(c_int32_9p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[10, 0-15] + SWISH_S32_AVX512(c_int32_10p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[11, 0-15] + SWISH_S32_AVX512(c_int32_11p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_12x16_DISABLE: ; @@ -1831,7 +2135,9 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_9x32) &&POST_OPS_GELU_TANH_9x32, &&POST_OPS_GELU_ERF_9x32, &&POST_OPS_CLIP_9x32, - &&POST_OPS_DOWNSCALE_9x32 + &&POST_OPS_DOWNSCALE_9x32, + &&POST_OPS_MATRIX_ADD_9x32, + &&POST_OPS_SWISH_9x32 }; dim_t MR = 9; dim_t m_full_pieces = m0 / MR; @@ -1842,14 +2148,14 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_9x32) dim_t k_partial_pieces = k0 % 4; // B matrix storage. - __m512i b0; - __m512i b1; + __m512i b0 = _mm512_setzero_epi32(); + __m512i b1 = _mm512_setzero_epi32(); // A matrix storage. - __m512i a_int32_0; - __m512i a_int32_1; - __m512i a_int32_2; - __m512i a_int32_3; + __m512i a_int32_0 = _mm512_setzero_epi32(); + __m512i a_int32_1 = _mm512_setzero_epi32(); + __m512i a_int32_2 = _mm512_setzero_epi32(); + __m512i a_int32_3 = _mm512_setzero_epi32(); __m512i selector1; __m512i selector2; @@ -2570,23 +2876,44 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_9x32) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_9x32: { - selector1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ); - __m128i zero_point0 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); - __m128i zero_point1 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point0 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point1 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point1 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } // c[0, 0-15] CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); @@ -2644,6 +2971,138 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_9x32) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_MATRIX_ADD_9x32: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + S8_S32_MATRIX_ADD_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + S8_S32_MATRIX_ADD_2COL(selector1,selector2,1); + + // c[2:0-15,16-31] + S8_S32_MATRIX_ADD_2COL(selector1,selector2,2); + + // c[3:0-15,16-31] + S8_S32_MATRIX_ADD_2COL(selector1,selector2,3); + + // c[4:0-15,16-31] + S8_S32_MATRIX_ADD_2COL(selector1,selector2,4); + + // c[5:0-15,16-31] + S8_S32_MATRIX_ADD_2COL(selector1,selector2,5); + + // c[6:0-15,16-31] + S8_S32_MATRIX_ADD_2COL(selector1,selector2,6); + + // c[7:0-15,16-31] + S8_S32_MATRIX_ADD_2COL(selector1,selector2,7); + + // c[8:0-15,16-31] + S8_S32_MATRIX_ADD_2COL(selector1,selector2,8); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + S32_S32_MATRIX_ADD_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + S32_S32_MATRIX_ADD_2COL(selector1,selector2,1); + + // c[2:0-15,16-31] + S32_S32_MATRIX_ADD_2COL(selector1,selector2,2); + + // c[3:0-15,16-31] + S32_S32_MATRIX_ADD_2COL(selector1,selector2,3); + + // c[4:0-15,16-31] + S32_S32_MATRIX_ADD_2COL(selector1,selector2,4); + + // c[5:0-15,16-31] + S32_S32_MATRIX_ADD_2COL(selector1,selector2,5); + + // c[6:0-15,16-31] + S32_S32_MATRIX_ADD_2COL(selector1,selector2,6); + + // c[7:0-15,16-31] + S32_S32_MATRIX_ADD_2COL(selector1,selector2,7); + + // c[8:0-15,16-31] + S32_S32_MATRIX_ADD_2COL(selector1,selector2,8); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_9x32: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + // c[0, 0-15] + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 16-31] + SWISH_S32_AVX512(c_int32_0p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 0-15] + SWISH_S32_AVX512(c_int32_1p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 16-31] + SWISH_S32_AVX512(c_int32_1p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 0-15] + SWISH_S32_AVX512(c_int32_2p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 16-31] + SWISH_S32_AVX512(c_int32_2p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[3, 0-15] + SWISH_S32_AVX512(c_int32_3p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[3, 16-31] + SWISH_S32_AVX512(c_int32_3p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[4, 0-15] + SWISH_S32_AVX512(c_int32_4p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[4, 16-31] + SWISH_S32_AVX512(c_int32_4p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[5, 0-15] + SWISH_S32_AVX512(c_int32_5p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[5, 16-31] + SWISH_S32_AVX512(c_int32_5p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[6, 0-15] + SWISH_S32_AVX512(c_int32_6p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[6, 16-31] + SWISH_S32_AVX512(c_int32_6p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[7, 0-15] + SWISH_S32_AVX512(c_int32_7p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[7, 16-31] + SWISH_S32_AVX512(c_int32_7p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[8, 0-15] + SWISH_S32_AVX512(c_int32_8p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[8, 16-31] + SWISH_S32_AVX512(c_int32_8p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_9x32_DISABLE: ; diff --git a/kernels/zen4/lpgemm/u8s8s32/lpgemm_n_fringe_amd512vnni.c b/kernels/zen4/lpgemm/u8s8s32/lpgemm_n_fringe_amd512vnni.c index f3574e5dc0..a5ed7e6b1f 100644 --- a/kernels/zen4/lpgemm/u8s8s32/lpgemm_n_fringe_amd512vnni.c +++ b/kernels/zen4/lpgemm/u8s8s32/lpgemm_n_fringe_amd512vnni.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -53,7 +53,9 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6xlt16) &&POST_OPS_GELU_TANH_6xLT16, &&POST_OPS_GELU_ERF_6xLT16, &&POST_OPS_CLIP_6xLT16, - &&POST_OPS_DOWNSCALE_6xLT16 + &&POST_OPS_DOWNSCALE_6xLT16, + &&POST_OPS_MATRIX_ADD_6xLT16, + &&POST_OPS_SWISH_6xLT16 }; dim_t MR = 6; dim_t m_full_pieces = m0 / MR; @@ -64,15 +66,15 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6xlt16) dim_t k_partial_pieces = k0 % 4; // B matrix storage. - __m512i b0; + __m512i b0 = _mm512_setzero_epi32(); // A matrix storage. - __m512i a_int32_0; - __m512i a_int32_1; - __m512i a_int32_2; - __m512i a_int32_3; - __m512i a_int32_4; - __m512i a_int32_5; + __m512i a_int32_0 = _mm512_setzero_epi32(); + __m512i a_int32_1 = _mm512_setzero_epi32(); + __m512i a_int32_2 = _mm512_setzero_epi32(); + __m512i a_int32_3 = _mm512_setzero_epi32(); + __m512i a_int32_4 = _mm512_setzero_epi32(); + __m512i a_int32_5 = _mm512_setzero_epi32(); for ( dim_t ir = 0; ir < m_full_pieces_loop_limit; ir += MR ) { @@ -278,27 +280,27 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6xlt16) __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); // c[0,0-15] - S32_S32_BETA_OP_NLT16F_MASK(load_mask, c_int32_0p0, ir, 0, 0, \ + S32_S32_BETA_OP_NLT16F_MASK(c, load_mask, c_int32_0p0, ir, 0, 0, \ selector1, selector2); // c[1,0-15] - S32_S32_BETA_OP_NLT16F_MASK(load_mask, c_int32_1p0, ir, 1, 0, \ + S32_S32_BETA_OP_NLT16F_MASK(c, load_mask, c_int32_1p0, ir, 1, 0, \ selector1, selector2); // c[2,0-15] - S32_S32_BETA_OP_NLT16F_MASK(load_mask, c_int32_2p0, ir, 2, 0, \ + S32_S32_BETA_OP_NLT16F_MASK(c, load_mask, c_int32_2p0, ir, 2, 0, \ selector1, selector2); // c[3,0-15] - S32_S32_BETA_OP_NLT16F_MASK(load_mask, c_int32_3p0, ir, 3, 0, \ + S32_S32_BETA_OP_NLT16F_MASK(c, load_mask, c_int32_3p0, ir, 3, 0, \ selector1, selector2); // c[4,0-15] - S32_S32_BETA_OP_NLT16F_MASK(load_mask, c_int32_4p0, ir, 4, 0, \ + S32_S32_BETA_OP_NLT16F_MASK(c, load_mask, c_int32_4p0, ir, 4, 0, \ selector1, selector2); // c[5,0-15] - S32_S32_BETA_OP_NLT16F_MASK(load_mask, c_int32_5p0, ir, 5, 0, \ + S32_S32_BETA_OP_NLT16F_MASK(c, load_mask, c_int32_5p0, ir, 5, 0, \ selector1, selector2); } } @@ -462,23 +464,41 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6xlt16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_6xLT16: { // Typecast without data modification, safe operation. __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); - selector1 = _mm512_maskz_loadu_epi32 - ( - load_mask, - ( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j ) - ); - __m128i zero_point = _mm_maskz_loadu_epi8 - ( - load_mask, - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j ) - ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = _mm512_maskz_loadu_epi32 + ( + load_mask, + ( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j ) + ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point = _mm512_castsi512_si128( _mm512_setzero_si512() ); + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point = _mm_maskz_loadu_epi8 + ( + load_mask, + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ) + ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } // c[0, 0-15] CVT_MULRND_CVT32_LT16(c_int32_0p0,selector1,zero_point); @@ -498,6 +518,85 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6xlt16) // c[5, 0-15] CVT_MULRND_CVT32_LT16(c_int32_5p0,selector1,zero_point); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_MATRIX_ADD_6xLT16: + { + __mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_rem ) ); + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S8_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + S8_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,1); + + // c[2:0-15] + S8_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,2); + + // c[3:0-15] + S8_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,3); + + // c[4:0-15] + S8_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,4); + + // c[5:0-15] + S8_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,5); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S32_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,0); + + // c[1:0-15] + S32_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,1); + + // c[2:0-15] + S32_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,2); + + // c[3:0-15] + S32_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,3); + + // c[4:0-15] + S32_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,4); + + // c[5:0-15] + S32_S32_MATRIX_ADD_1COL_PAR(load_mask,selector1,5); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_6xLT16: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + // c[0, 0-15] + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 0-15] + SWISH_S32_AVX512(c_int32_1p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 0-15] + SWISH_S32_AVX512(c_int32_2p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[3, 0-15] + SWISH_S32_AVX512(c_int32_3p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[4, 0-15] + SWISH_S32_AVX512(c_int32_4p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[5, 0-15] + SWISH_S32_AVX512(c_int32_5p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_6xLT16_DISABLE: @@ -637,7 +736,9 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x16) &&POST_OPS_GELU_TANH_6x16, &&POST_OPS_GELU_ERF_6x16, &&POST_OPS_CLIP_6x16, - &&POST_OPS_DOWNSCALE_6x16 + &&POST_OPS_DOWNSCALE_6x16, + &&POST_OPS_MATRIX_ADD_6x16, + &&POST_OPS_SWISH_6x16 }; dim_t MR = 6; dim_t m_full_pieces = m0 / MR; @@ -648,15 +749,15 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x16) dim_t k_partial_pieces = k0 % 4; // B matrix storage. - __m512i b0; + __m512i b0 = _mm512_setzero_epi32(); // A matrix storage. - __m512i a_int32_0; - __m512i a_int32_1; - __m512i a_int32_2; - __m512i a_int32_3; - __m512i a_int32_4; - __m512i a_int32_5; + __m512i a_int32_0 = _mm512_setzero_epi32(); + __m512i a_int32_1 = _mm512_setzero_epi32(); + __m512i a_int32_2 = _mm512_setzero_epi32(); + __m512i a_int32_3 = _mm512_setzero_epi32(); + __m512i a_int32_4 = _mm512_setzero_epi32(); + __m512i a_int32_5 = _mm512_setzero_epi32(); for ( dim_t ir = 0; ir < m_full_pieces_loop_limit; ir += MR ) { @@ -1026,16 +1127,33 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_6x16: { - selector1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - __m128i zero_point0 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point0 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } // c[0, 0-15] CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); @@ -1057,6 +1175,84 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_MATRIX_ADD_6x16: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S8_S32_MATRIX_ADD_1COL(selector1,0); + + // c[1:0-15] + S8_S32_MATRIX_ADD_1COL(selector1,1); + + // c[2:0-15] + S8_S32_MATRIX_ADD_1COL(selector1,2); + + // c[3:0-15] + S8_S32_MATRIX_ADD_1COL(selector1,3); + + // c[4:0-15] + S8_S32_MATRIX_ADD_1COL(selector1,4); + + // c[5:0-15] + S8_S32_MATRIX_ADD_1COL(selector1,5); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15] + S32_S32_MATRIX_ADD_1COL(selector1,0); + + // c[1:0-15] + S32_S32_MATRIX_ADD_1COL(selector1,1); + + // c[2:0-15] + S32_S32_MATRIX_ADD_1COL(selector1,2); + + // c[3:0-15] + S32_S32_MATRIX_ADD_1COL(selector1,3); + + // c[4:0-15] + S32_S32_MATRIX_ADD_1COL(selector1,4); + + // c[5:0-15] + S32_S32_MATRIX_ADD_1COL(selector1,5); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_6x16: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + // c[0, 0-15] + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 0-15] + SWISH_S32_AVX512(c_int32_1p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 0-15] + SWISH_S32_AVX512(c_int32_2p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[3, 0-15] + SWISH_S32_AVX512(c_int32_3p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[4, 0-15] + SWISH_S32_AVX512(c_int32_4p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[5, 0-15] + SWISH_S32_AVX512(c_int32_5p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_6x16_DISABLE: ; @@ -1194,7 +1390,9 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x32) &&POST_OPS_GELU_TANH_6x32, &&POST_OPS_GELU_ERF_6x32, &&POST_OPS_CLIP_6x32, - &&POST_OPS_DOWNSCALE_6x32 + &&POST_OPS_DOWNSCALE_6x32, + &&POST_OPS_MATRIX_ADD_6x32, + &&POST_OPS_SWISH_6x32 }; dim_t MR = 6; dim_t m_full_pieces = m0 / MR; @@ -1205,16 +1403,16 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x32) dim_t k_partial_pieces = k0 % 4; // B matrix storage. - __m512i b0; - __m512i b1; + __m512i b0 = _mm512_setzero_epi32(); + __m512i b1 = _mm512_setzero_epi32(); // A matrix storage. - __m512i a_int32_0; - __m512i a_int32_1; - __m512i a_int32_2; - __m512i a_int32_3; - __m512i a_int32_4; - __m512i a_int32_5; + __m512i a_int32_0 = _mm512_setzero_epi32(); + __m512i a_int32_1 = _mm512_setzero_epi32(); + __m512i a_int32_2 = _mm512_setzero_epi32(); + __m512i a_int32_3 = _mm512_setzero_epi32(); + __m512i a_int32_4 = _mm512_setzero_epi32(); + __m512i a_int32_5 = _mm512_setzero_epi32(); for ( dim_t ir = 0; ir < m_full_pieces_loop_limit; ir += MR ) { @@ -1725,23 +1923,44 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x32) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_6x32: { - selector1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ); - __m128i zero_point0 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); - __m128i zero_point1 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point0 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point1 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point1 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } // c[0, 0-15] CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); @@ -1781,6 +2000,102 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x32) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_MATRIX_ADD_6x32: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + S8_S32_MATRIX_ADD_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + S8_S32_MATRIX_ADD_2COL(selector1,selector2,1); + + // c[2:0-15,16-31] + S8_S32_MATRIX_ADD_2COL(selector1,selector2,2); + + // c[3:0-15,16-31] + S8_S32_MATRIX_ADD_2COL(selector1,selector2,3); + + // c[4:0-15,16-31] + S8_S32_MATRIX_ADD_2COL(selector1,selector2,4); + + // c[5:0-15,16-31] + S8_S32_MATRIX_ADD_2COL(selector1,selector2,5); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31] + S32_S32_MATRIX_ADD_2COL(selector1,selector2,0); + + // c[1:0-15,16-31] + S32_S32_MATRIX_ADD_2COL(selector1,selector2,1); + + // c[2:0-15,16-31] + S32_S32_MATRIX_ADD_2COL(selector1,selector2,2); + + // c[3:0-15,16-31] + S32_S32_MATRIX_ADD_2COL(selector1,selector2,3); + + // c[4:0-15,16-31] + S32_S32_MATRIX_ADD_2COL(selector1,selector2,4); + + // c[5:0-15,16-31] + S32_S32_MATRIX_ADD_2COL(selector1,selector2,5); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_6x32: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + // c[0, 0-15] + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 16-31] + SWISH_S32_AVX512(c_int32_0p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 0-15] + SWISH_S32_AVX512(c_int32_1p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 16-31] + SWISH_S32_AVX512(c_int32_1p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 0-15] + SWISH_S32_AVX512(c_int32_2p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 16-31] + SWISH_S32_AVX512(c_int32_2p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[3, 0-15] + SWISH_S32_AVX512(c_int32_3p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[3, 16-31] + SWISH_S32_AVX512(c_int32_3p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[4, 0-15] + SWISH_S32_AVX512(c_int32_4p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[4, 16-31] + SWISH_S32_AVX512(c_int32_4p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[5, 0-15] + SWISH_S32_AVX512(c_int32_5p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[5, 16-31] + SWISH_S32_AVX512(c_int32_5p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_6x32_DISABLE: ; @@ -1954,7 +2269,9 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x48) &&POST_OPS_GELU_TANH_6x48, &&POST_OPS_GELU_ERF_6x48, &&POST_OPS_CLIP_6x48, - &&POST_OPS_DOWNSCALE_6x48 + &&POST_OPS_DOWNSCALE_6x48, + &&POST_OPS_MATRIX_ADD_6x48, + &&POST_OPS_SWISH_6x48 }; dim_t MR = 6; dim_t m_full_pieces = m0 / MR; @@ -1965,12 +2282,12 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x48) dim_t k_partial_pieces = k0 % 4; // B matrix storage. - __m512i b0; - __m512i b1; - __m512i b2; + __m512i b0 = _mm512_setzero_epi32(); + __m512i b1 = _mm512_setzero_epi32(); + __m512i b2 = _mm512_setzero_epi32(); // A matrix storage. - __m512i a_int32_0; + __m512i a_int32_0 = _mm512_setzero_epi32(); for ( dim_t ir = 0; ir < m_full_pieces_loop_limit; ir += MR ) { @@ -2615,30 +2932,55 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x48) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } - POST_OPS_DOWNSCALE_6x48: { - selector1 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ); - a_int32_0 = - _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + ( 2 * 16 ) ); - __m128i zero_point0 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); - __m128i zero_point1 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); - __m128i zero_point2 = - _mm_loadu_si128( ( __m128i const* ) - ( ( int8_t* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + a_int32_0 = + ( __m512i )_mm512_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point0 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point1 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + __m128i zero_point2 = _mm512_castsi512_si128( _mm512_setzero_si512() ); + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + zero_point2 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point1 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point2 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } // c[0, 0-15] CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); @@ -2696,6 +3038,120 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x48) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_MATRIX_ADD_6x48: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + S8_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,0); + + // c[1:0-15,16-31,32-47] + S8_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,1); + + // c[2:0-15,16-31,32-47] + S8_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,2); + + // c[3:0-15,16-31,32-47] + S8_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,3); + + // c[4:0-15,16-31,32-47] + S8_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,4); + + // c[5:0-15,16-31,32-47] + S8_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,5); + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + // c[0:0-15,16-31,32-47] + S32_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,0); + + // c[1:0-15,16-31,32-47] + S32_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,1); + + // c[2:0-15,16-31,32-47] + S32_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,2); + + // c[3:0-15,16-31,32-47] + S32_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,3); + + // c[4:0-15,16-31,32-47] + S32_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,4); + + // c[5:0-15,16-31,32-47] + S32_S32_MATRIX_ADD_3COL(selector1,selector2,a_int32_0,5); + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_SWISH_6x48: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + // c[0, 0-15] + SWISH_S32_AVX512(c_int32_0p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 16-31] + SWISH_S32_AVX512(c_int32_0p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[0, 32-47] + SWISH_S32_AVX512(c_int32_0p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 0-15] + SWISH_S32_AVX512(c_int32_1p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 16-31] + SWISH_S32_AVX512(c_int32_1p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[1, 32-47] + SWISH_S32_AVX512(c_int32_1p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 0-15] + SWISH_S32_AVX512(c_int32_2p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 16-31] + SWISH_S32_AVX512(c_int32_2p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[2, 32-47] + SWISH_S32_AVX512(c_int32_2p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[3, 0-15] + SWISH_S32_AVX512(c_int32_3p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[3, 16-31] + SWISH_S32_AVX512(c_int32_3p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[3, 32-47] + SWISH_S32_AVX512(c_int32_3p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[4, 0-15] + SWISH_S32_AVX512(c_int32_4p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[4, 16-31] + SWISH_S32_AVX512(c_int32_4p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[4, 32-47] + SWISH_S32_AVX512(c_int32_4p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[5, 0-15] + SWISH_S32_AVX512(c_int32_5p0, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[5, 16-31] + SWISH_S32_AVX512(c_int32_5p1, fl_reg, al, al_in, r, r2, z, dn, selector2); + + // c[5, 32-47] + SWISH_S32_AVX512(c_int32_5p2, fl_reg, al, al_in, r, r2, z, dn, selector2); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_6x48_DISABLE: ; diff --git a/kernels/zen4/lpgemm/u8s8s32/lpgemm_packa_amd512vnni.c b/kernels/zen4/lpgemm/u8s8s32/lpgemm_packa_amd512vnni.c index cdaf576172..475b74e549 100644 --- a/kernels/zen4/lpgemm/u8s8s32/lpgemm_packa_amd512vnni.c +++ b/kernels/zen4/lpgemm/u8s8s32/lpgemm_packa_amd512vnni.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -40,6 +40,17 @@ #define MR 6 #define NR 64 +void packa_k64_u8s8s32o32 + ( + uint8_t* pack_a_buffer_u8s8s32o32, + const uint8_t* a, + const dim_t lda, + const dim_t MC, + const dim_t KC, + dim_t* rs_a, + dim_t* cs_a + ); + void packa_m5_k64_u8s8s32o32 ( uint8_t* pack_a_buffer_u8s8s32o32, @@ -80,6 +91,44 @@ void packa_m1_k64_u8s8s32o32 const dim_t KC ); +void packa_mr16_u8s8s32o32_col_major + ( + uint8_t* pack_a_buffer_u8s8s32o32, + const uint8_t* a, + const dim_t rs, + const dim_t cs, + const dim_t MC, + const dim_t KC, + dim_t* rs_a, + dim_t* cs_a + ); + +void packa_u8s8s32os32 + ( + uint8_t* pack_a_buffer_u8s8s32o32, + const uint8_t* a, + const dim_t rs, + const dim_t cs, + const dim_t MC, + const dim_t KC, + dim_t* rs_a, + dim_t* cs_a + ) +{ + if( cs == 1 ) + { + packa_k64_u8s8s32o32 + ( pack_a_buffer_u8s8s32o32, a, rs, MC, KC, rs_a, cs_a ); + } + else + { + packa_mr16_u8s8s32o32_col_major + ( pack_a_buffer_u8s8s32o32, a, rs, cs, MC, KC, rs_a, cs_a ); + } +} + + +// Row Major Packing in blocks of MRxKC // TODO: k fringe till k=4, k%4=0 and padding to make k'%4 = 0 if k%4 != 0 originally. void packa_k64_u8s8s32o32 ( @@ -531,4 +580,1044 @@ void packa_m1_k64_u8s8s32o32 _mm512_storeu_si512( pack_a_buffer_u8s8s32o32 + ( ( kr * 1 ) + ( 0 ) ), a0 ); } } + +#define SET_REGISTERS_ZERO \ + a_reg[0] = _mm_setzero_si128(); \ + a_reg[1] = _mm_setzero_si128(); \ + a_reg[2] = _mm_setzero_si128(); \ + a_reg[3] = _mm_setzero_si128(); \ + a_reg[4] = _mm_setzero_si128(); \ + a_reg[5] = _mm_setzero_si128(); \ + a_reg[6] = _mm_setzero_si128(); \ + a_reg[7] = _mm_setzero_si128(); \ + a_reg[8] = _mm_setzero_si128(); \ + a_reg[9] = _mm_setzero_si128(); \ + a_reg[10] = _mm_setzero_si128(); \ + a_reg[11] = _mm_setzero_si128(); \ + a_reg[12] = _mm_setzero_si128(); \ + a_reg[13] = _mm_setzero_si128(); \ + a_reg[14] = _mm_setzero_si128(); \ + a_reg[15] = _mm_setzero_si128(); + +#define UNPACKLOW_EPI8 \ + b_reg[0] = _mm_unpacklo_epi8( a_reg[0], a_reg[1] ); \ + b_reg[1] = _mm_unpacklo_epi8( a_reg[2], a_reg[3] ); \ + b_reg[2] = _mm_unpacklo_epi8( a_reg[4], a_reg[5] ); \ + b_reg[3] = _mm_unpacklo_epi8( a_reg[6], a_reg[7] ); \ + b_reg[4] = _mm_unpacklo_epi8( a_reg[8], a_reg[9] ); \ + b_reg[5] = _mm_unpacklo_epi8( a_reg[10], a_reg[11] ); \ + b_reg[6] = _mm_unpacklo_epi8( a_reg[12], a_reg[13] ); \ + b_reg[7] = _mm_unpacklo_epi8( a_reg[14], a_reg[15] ); + +#define UNPACKHI_EPI8 \ + b_reg[8] = _mm_unpackhi_epi8( a_reg[0], a_reg[1] ); \ + b_reg[9] = _mm_unpackhi_epi8( a_reg[2], a_reg[3] ); \ + b_reg[10] = _mm_unpackhi_epi8( a_reg[4], a_reg[5] ); \ + b_reg[11] = _mm_unpackhi_epi8( a_reg[6], a_reg[7] ); \ + b_reg[12] = _mm_unpackhi_epi8( a_reg[8], a_reg[9] ); \ + b_reg[13] = _mm_unpackhi_epi8( a_reg[10], a_reg[11] ); \ + b_reg[14] = _mm_unpackhi_epi8( a_reg[12], a_reg[13] ); \ + b_reg[15] = _mm_unpackhi_epi8( a_reg[14], a_reg[15] ); + +#define UNPACKLOW_EPI16 \ + a_reg[0] = _mm_unpacklo_epi16( b_reg[0], b_reg[1] ); \ + a_reg[1] = _mm_unpacklo_epi16( b_reg[2], b_reg[3] ); \ + a_reg[2] = _mm_unpacklo_epi16( b_reg[4], b_reg[5] ); \ + a_reg[3] = _mm_unpacklo_epi16( b_reg[6], b_reg[7] ); \ +\ + a_reg[8] = _mm_unpacklo_epi16( b_reg[8], b_reg[9] ); \ + a_reg[9] = _mm_unpacklo_epi16( b_reg[10], b_reg[11] ); \ + a_reg[10] = _mm_unpacklo_epi16( b_reg[12], b_reg[13] ); \ + a_reg[11] = _mm_unpacklo_epi16( b_reg[14], b_reg[15] ); + +#define UNPACKHI_EPI16 \ + a_reg[4] = _mm_unpackhi_epi16( b_reg[0], b_reg[1] ); \ + a_reg[5] = _mm_unpackhi_epi16( b_reg[2], b_reg[3] ); \ + a_reg[6] = _mm_unpackhi_epi16( b_reg[4], b_reg[5] ); \ + a_reg[7] = _mm_unpackhi_epi16( b_reg[6], b_reg[7] ); \ +\ + a_reg[12] = _mm_unpackhi_epi16( b_reg[8], b_reg[9] ); \ + a_reg[13] = _mm_unpackhi_epi16( b_reg[10], b_reg[11] ); \ + a_reg[14] = _mm_unpackhi_epi16( b_reg[12], b_reg[13] ); \ + a_reg[15] = _mm_unpackhi_epi16( b_reg[14], b_reg[15] ); + +#define UNPACKLOW_EPI32 \ + b_reg[0] = _mm_unpacklo_epi32( a_reg[0], a_reg[1] ); \ + b_reg[1] = _mm_unpacklo_epi32( a_reg[2], a_reg[3] ); \ + b_reg[2] = _mm_unpacklo_epi32( a_reg[4], a_reg[5] ); \ + b_reg[3] = _mm_unpacklo_epi32( a_reg[6], a_reg[7] ); \ +\ + b_reg[8] = _mm_unpacklo_epi32( a_reg[8], a_reg[9] ); \ + b_reg[9] = _mm_unpacklo_epi32( a_reg[10], a_reg[11] ); \ + b_reg[10] = _mm_unpacklo_epi32( a_reg[12], a_reg[13] ); \ + b_reg[11] = _mm_unpacklo_epi32( a_reg[14], a_reg[15] ); + +#define UNPACKHI_EPI32 \ + b_reg[4] = _mm_unpackhi_epi32( a_reg[0], a_reg[1] ); \ + b_reg[5] = _mm_unpackhi_epi32( a_reg[2], a_reg[3] ); \ + b_reg[6] = _mm_unpackhi_epi32( a_reg[4], a_reg[5] ); \ + b_reg[7] = _mm_unpackhi_epi32( a_reg[6], a_reg[7] ); \ +\ + b_reg[12] = _mm_unpackhi_epi32( a_reg[8], a_reg[9] ); \ + b_reg[13] = _mm_unpackhi_epi32( a_reg[10], a_reg[11] ); \ + b_reg[14] = _mm_unpackhi_epi32( a_reg[12], a_reg[13] ); \ + b_reg[15] = _mm_unpackhi_epi32( a_reg[14], a_reg[15] ); + +#define UNPACKLOW_EPI64 \ + a_reg[0] = _mm_unpacklo_epi64( b_reg[0], b_reg[1] ); \ + a_reg[2] = _mm_unpacklo_epi64( b_reg[2], b_reg[3] ); \ + a_reg[4] = _mm_unpacklo_epi64( b_reg[4], b_reg[5] ); \ + a_reg[6] = _mm_unpacklo_epi64( b_reg[6], b_reg[7] ); \ +\ + a_reg[8] = _mm_unpacklo_epi64( b_reg[8], b_reg[9] ); \ + a_reg[10] = _mm_unpacklo_epi64( b_reg[10], b_reg[11] ); \ + a_reg[12] = _mm_unpacklo_epi64( b_reg[12], b_reg[13] ); \ + a_reg[14] = _mm_unpacklo_epi64( b_reg[14], b_reg[15] ); + +#define UNPACKHI_EPI64 \ + a_reg[1] = _mm_unpackhi_epi64( b_reg[0], b_reg[1] ); \ + a_reg[3] = _mm_unpackhi_epi64( b_reg[2], b_reg[3] ); \ + a_reg[5] = _mm_unpackhi_epi64( b_reg[4], b_reg[5] ); \ + a_reg[7] = _mm_unpackhi_epi64( b_reg[6], b_reg[7] ); \ +\ + a_reg[9] = _mm_unpackhi_epi64( b_reg[8], b_reg[9] ); \ + a_reg[11] = _mm_unpackhi_epi64( b_reg[10], b_reg[11] ); \ + a_reg[13] = _mm_unpackhi_epi64( b_reg[12], b_reg[13] ); \ + a_reg[15] = _mm_unpackhi_epi64( b_reg[14], b_reg[15] ); + +#define UNPACKLOW_EPI16_MR8 \ + a_reg[0] = _mm_unpacklo_epi16( b_reg[0], b_reg[1] ); \ + a_reg[1] = _mm_unpacklo_epi16( b_reg[2], b_reg[3] ); \ + a_reg[2] = _mm_unpacklo_epi16( b_reg[4], b_reg[5] ); \ + a_reg[3] = _mm_unpacklo_epi16( b_reg[6], b_reg[7] ); + +#define UNPACKHI_EPI16_MR8 \ + a_reg[4] = _mm_unpackhi_epi16( b_reg[0], b_reg[1] ); \ + a_reg[5] = _mm_unpackhi_epi16( b_reg[2], b_reg[3] ); \ + a_reg[6] = _mm_unpackhi_epi16( b_reg[4], b_reg[5] ); \ + a_reg[7] = _mm_unpackhi_epi16( b_reg[6], b_reg[7] ); + +#define UNPACKLOW_EPI32_MR8 \ + b_reg[0] = _mm_unpacklo_epi32( a_reg[0], a_reg[1] ); \ + b_reg[1] = _mm_unpacklo_epi32( a_reg[2], a_reg[3] ); \ + b_reg[2] = _mm_unpacklo_epi32( a_reg[4], a_reg[5] ); \ + b_reg[3] = _mm_unpacklo_epi32( a_reg[6], a_reg[7] ); + +#define UNPACKHI_EPI32_MR8 \ + b_reg[4] = _mm_unpackhi_epi32( a_reg[0], a_reg[1] ); \ + b_reg[5] = _mm_unpackhi_epi32( a_reg[2], a_reg[3] ); \ + b_reg[6] = _mm_unpackhi_epi32( a_reg[4], a_reg[5] ); \ + b_reg[7] = _mm_unpackhi_epi32( a_reg[6], a_reg[7] ); + +#define UNPACKLOW_EPI64_MR8 \ + a_reg[0] = _mm_unpacklo_epi64( b_reg[0], b_reg[1] ); \ + a_reg[2] = _mm_unpacklo_epi64( b_reg[2], b_reg[3] ); \ + a_reg[4] = _mm_unpacklo_epi64( b_reg[4], b_reg[5] ); \ + a_reg[6] = _mm_unpacklo_epi64( b_reg[6], b_reg[7] ); + +#define UNPACKHI_EPI64_MR8 \ + a_reg[1] = _mm_unpackhi_epi64( b_reg[0], b_reg[1] ); \ + a_reg[3] = _mm_unpackhi_epi64( b_reg[2], b_reg[3] ); \ + a_reg[5] = _mm_unpackhi_epi64( b_reg[4], b_reg[5] ); \ + a_reg[7] = _mm_unpackhi_epi64( b_reg[6], b_reg[7] ); + +#define UNPACKLOW_EPI32_MR4 \ + b_reg[0] = _mm_unpacklo_epi32( a_reg[0], a_reg[1] ); \ + b_reg[1] = _mm_unpacklo_epi32( a_reg[2], a_reg[3] ); + +#define UNPACKHI_EPI32_MR4 \ + b_reg[4] = _mm_unpackhi_epi32( a_reg[0], a_reg[1] ); \ + b_reg[5] = _mm_unpackhi_epi32( a_reg[2], a_reg[3] ); + +#define UNPACKLOW_EPI64_MR4 \ + a_reg[0] = _mm_unpacklo_epi64( b_reg[0], b_reg[1] ); \ + a_reg[4] = _mm_unpacklo_epi64( b_reg[4], b_reg[5] ); + +#define UNPACKHI_EPI64_MR4 \ + a_reg[1] = _mm_unpackhi_epi64( b_reg[0], b_reg[1] ); \ + a_reg[5] = _mm_unpackhi_epi64( b_reg[4], b_reg[5] ); + +#define MASKED_STORE_EPI32(mask) \ + _mm_mask_storeu_epi32( ( pack_a_buffer_u8s8s32o32 + ( ic + 0 ) * KC + kr ), mask, a_reg[0] ); \ + _mm_mask_storeu_epi32( ( pack_a_buffer_u8s8s32o32 + ( ic + 1 ) * KC + kr ), mask, a_reg[1] ); \ + _mm_mask_storeu_epi32( ( pack_a_buffer_u8s8s32o32 + ( ic + 2 ) * KC + kr ), mask, a_reg[4] ); \ + _mm_mask_storeu_epi32( ( pack_a_buffer_u8s8s32o32 + ( ic + 3 ) * KC + kr ), mask, a_reg[5] ); \ + _mm_mask_storeu_epi32( ( pack_a_buffer_u8s8s32o32 + ( ic + 4 ) * KC + kr ), mask, a_reg[2] ); \ + _mm_mask_storeu_epi32( ( pack_a_buffer_u8s8s32o32 + ( ic + 5 ) * KC + kr ), mask, a_reg[3] ); \ + _mm_mask_storeu_epi32( ( pack_a_buffer_u8s8s32o32 + ( ic + 6 ) * KC + kr ), mask, a_reg[6] ); \ + _mm_mask_storeu_epi32( ( pack_a_buffer_u8s8s32o32 + ( ic + 7 ) * KC + kr ), mask, a_reg[7] ); \ + _mm_mask_storeu_epi32( ( pack_a_buffer_u8s8s32o32 + ( ic + 8 ) * KC + kr ), mask, a_reg[8] ); \ + _mm_mask_storeu_epi32( ( pack_a_buffer_u8s8s32o32 + ( ic + 9 ) * KC + kr ), mask, a_reg[9] ); \ + _mm_mask_storeu_epi32( ( pack_a_buffer_u8s8s32o32 + ( ic + 10 ) * KC + kr ), mask, a_reg[12] ); \ + _mm_mask_storeu_epi32( ( pack_a_buffer_u8s8s32o32 + ( ic + 11 ) * KC + kr ), mask, a_reg[13] ); \ + _mm_mask_storeu_epi32( ( pack_a_buffer_u8s8s32o32 + ( ic + 12 ) * KC + kr ), mask, a_reg[10] ); \ + _mm_mask_storeu_epi32( ( pack_a_buffer_u8s8s32o32 + ( ic + 13 ) * KC + kr ), mask, a_reg[11] ); \ + _mm_mask_storeu_epi32( ( pack_a_buffer_u8s8s32o32 + ( ic + 14 ) * KC + kr ), mask, a_reg[14] ); \ + _mm_mask_storeu_epi32( ( pack_a_buffer_u8s8s32o32 + ( ic + 15 ) * KC + kr ), mask, a_reg[15] ); + +#define MASKED_STORE_EPI16(mask) \ + _mm_mask_storeu_epi16( ( pack_a_buffer_u8s8s32o32 + ( ic + 0 ) * KC + kr ), mask, a_reg[0] ); \ + _mm_mask_storeu_epi16( ( pack_a_buffer_u8s8s32o32 + ( ic + 1 ) * KC + kr ), mask, a_reg[1] ); \ + _mm_mask_storeu_epi16( ( pack_a_buffer_u8s8s32o32 + ( ic + 2 ) * KC + kr ), mask, a_reg[4] ); \ + _mm_mask_storeu_epi16( ( pack_a_buffer_u8s8s32o32 + ( ic + 3 ) * KC + kr ), mask, a_reg[5] ); \ + _mm_mask_storeu_epi16( ( pack_a_buffer_u8s8s32o32 + ( ic + 4 ) * KC + kr ), mask, a_reg[2] ); \ + _mm_mask_storeu_epi16( ( pack_a_buffer_u8s8s32o32 + ( ic + 5 ) * KC + kr ), mask, a_reg[3] ); \ + _mm_mask_storeu_epi16( ( pack_a_buffer_u8s8s32o32 + ( ic + 6 ) * KC + kr ), mask, a_reg[6] ); \ + _mm_mask_storeu_epi16( ( pack_a_buffer_u8s8s32o32 + ( ic + 7 ) * KC + kr ), mask, a_reg[7] ); \ + _mm_mask_storeu_epi16( ( pack_a_buffer_u8s8s32o32+ ( ic + 8 ) * KC + kr ), mask, a_reg[8] ); \ + _mm_mask_storeu_epi16( ( pack_a_buffer_u8s8s32o32 + ( ic + 9 ) * KC + kr ), mask, a_reg[9] ); \ + _mm_mask_storeu_epi16( ( pack_a_buffer_u8s8s32o32 + ( ic + 10 ) * KC + kr ), mask, a_reg[12] ); \ + _mm_mask_storeu_epi16( ( pack_a_buffer_u8s8s32o32 + ( ic + 11 ) * KC + kr ), mask, a_reg[13] ); \ + _mm_mask_storeu_epi16( ( pack_a_buffer_u8s8s32o32 + ( ic + 12 ) * KC + kr ), mask, a_reg[10] ); \ + _mm_mask_storeu_epi16( ( pack_a_buffer_u8s8s32o32 + ( ic + 13 ) * KC + kr ), mask, a_reg[11] ); \ + _mm_mask_storeu_epi16( ( pack_a_buffer_u8s8s32o32 + ( ic + 14 ) * KC + kr ), mask, a_reg[14] ); \ + _mm_mask_storeu_epi16( ( pack_a_buffer_u8s8s32o32 + ( ic + 15 ) * KC + kr ), mask, a_reg[15] ); + +#define MASKED_STORE_EPI8(mask) \ + _mm_mask_storeu_epi8( ( pack_a_buffer_u8s8s32o32 + ( ic + 0 ) * KC + kr ), mask, a_reg[0] ); \ + _mm_mask_storeu_epi8( ( pack_a_buffer_u8s8s32o32 + ( ic + 1 ) * KC + kr ), mask, a_reg[1] ); \ + _mm_mask_storeu_epi8( ( pack_a_buffer_u8s8s32o32 + ( ic + 2 ) * KC + kr ), mask, a_reg[4] ); \ + _mm_mask_storeu_epi8( ( pack_a_buffer_u8s8s32o32 + ( ic + 3 ) * KC + kr ), mask, a_reg[5] ); \ + _mm_mask_storeu_epi8( ( pack_a_buffer_u8s8s32o32 + ( ic + 4 ) * KC + kr ), mask, a_reg[2] ); \ + _mm_mask_storeu_epi8( ( pack_a_buffer_u8s8s32o32 + ( ic + 5 ) * KC + kr ), mask, a_reg[3] ); \ + _mm_mask_storeu_epi8( ( pack_a_buffer_u8s8s32o32 + ( ic + 6 ) * KC + kr ), mask, a_reg[6] ); \ + _mm_mask_storeu_epi8( ( pack_a_buffer_u8s8s32o32 + ( ic + 7 ) * KC + kr ), mask, a_reg[7] ); \ + _mm_mask_storeu_epi8( ( pack_a_buffer_u8s8s32o32 + ( ic + 8 ) * KC + kr ), mask, a_reg[8] ); \ + _mm_mask_storeu_epi8( ( pack_a_buffer_u8s8s32o32 + ( ic + 9 ) * KC + kr ), mask, a_reg[9] ); \ + _mm_mask_storeu_epi8( ( pack_a_buffer_u8s8s32o32 + ( ic + 10 ) * KC + kr ), mask, a_reg[12] ); \ + _mm_mask_storeu_epi8( ( pack_a_buffer_u8s8s32o32 + ( ic + 11 ) * KC + kr ), mask, a_reg[13] ); \ + _mm_mask_storeu_epi8( ( pack_a_buffer_u8s8s32o32 + ( ic + 12 ) * KC + kr ), mask, a_reg[10] ); \ + _mm_mask_storeu_epi8( ( pack_a_buffer_u8s8s32o32 + ( ic + 13 ) * KC + kr ), mask, a_reg[11] ); \ + _mm_mask_storeu_epi8( ( pack_a_buffer_u8s8s32o32 + ( ic + 14 ) * KC + kr ), mask, a_reg[14] ); \ + _mm_mask_storeu_epi8( ( pack_a_buffer_u8s8s32o32 + ( ic + 15 ) * KC + kr ), mask, a_reg[15] ); + + +// Column-major transformation to row-major in blocks of MCxKC + +void packa_mr8_u8s8s32o32_col_major + ( + uint8_t* pack_a_buffer_u8s8s32o32, + const uint8_t* a, + const dim_t cs, + const dim_t KC + ); + +void packa_mr4_u8s8s32o32_col_major + ( + uint8_t* pack_a_buffer_u8s8s32o32, + const uint8_t* a, + const dim_t cs, + const dim_t KC + ); + +void packa_mrlt4_u8s8s32o32_col_major + ( + uint8_t* pack_a_buffer_u8s8s32o32, + const uint8_t* a, + const dim_t cs, + const dim_t KC, + const dim_t m_left + ); + +void packa_mr16_u8s8s32o32_col_major + ( + uint8_t* pack_a_buffer_u8s8s32o32, + const uint8_t* a, + const dim_t rs, + const dim_t cs, + const dim_t MC, + const dim_t KC, + dim_t* rs_a, + dim_t* cs_a + ) +{ + dim_t mr = 16; + __m128i a_reg[16], b_reg[16]; + + dim_t m_partial_pieces = MC % mr; + dim_t k_partial_pieces = KC % 16; + dim_t m_left = MC % 4; + + SET_REGISTERS_ZERO + + dim_t ic, kr; + + for ( ic =0; ( ic + mr - 1 ) < MC; ic += mr ) + { + for ( kr = 0; ( kr + 15 ) < KC; kr += 16 ) + { + a_reg[0] = _mm_loadu_si128 ( (__m128i const *) ( a + ( ic * rs ) + ( ( kr + 0 ) * cs ) ) ); + a_reg[1] = _mm_loadu_si128 ( (__m128i const *) ( a + ( ic * rs ) + ( ( kr + 1 ) * cs ) ) ); + a_reg[2] = _mm_loadu_si128 ( (__m128i const *) ( a + ( ic * rs ) + ( ( kr + 2 ) * cs ) ) ); + a_reg[3] = _mm_loadu_si128 ( (__m128i const *) ( a + ( ic * rs ) + ( ( kr + 3 ) * cs ) ) ); + a_reg[4] = _mm_loadu_si128 ( (__m128i const *) ( a + ( ic * rs ) + ( ( kr + 4 ) * cs ) ) ); + a_reg[5] = _mm_loadu_si128 ( (__m128i const *) ( a + ( ic * rs ) + ( ( kr + 5 ) * cs ) ) ); + a_reg[6] = _mm_loadu_si128 ( (__m128i const *) ( a + ( ic * rs ) + ( ( kr + 6 ) * cs ) ) ); + a_reg[7] = _mm_loadu_si128 ( (__m128i const *) ( a + ( ic * rs ) + ( ( kr + 7 ) * cs ) ) ); + a_reg[8] = _mm_loadu_si128 ( (__m128i const *) ( a + ( ic * rs ) + ( ( kr + 8 ) * cs ) ) ); + a_reg[9] = _mm_loadu_si128 ( (__m128i const *) ( a + ( ic * rs ) + ( ( kr + 9 ) * cs ) ) ); + a_reg[10] = _mm_loadu_si128 ( (__m128i const *) ( a + ( ic * rs ) + ( ( kr + 10 ) * cs ) ) ); + a_reg[11] = _mm_loadu_si128 ( (__m128i const *) ( a + ( ic * rs ) + ( ( kr + 11 ) * cs ) ) ); + a_reg[12] = _mm_loadu_si128 ( (__m128i const *) ( a + ( ic * rs ) + ( ( kr + 12 ) * cs ) ) ); + a_reg[13] = _mm_loadu_si128 ( (__m128i const *) ( a + ( ic * rs ) + ( ( kr + 13 ) * cs ) ) ); + a_reg[14] = _mm_loadu_si128 ( (__m128i const *) ( a + ( ic * rs ) + ( ( kr + 14 ) * cs ) ) ); + a_reg[15] = _mm_loadu_si128 ( (__m128i const *) ( a + ( ic * rs ) + ( ( kr + 15 ) * cs ) ) ); + + // Transpose operations + UNPACKLOW_EPI8 + UNPACKHI_EPI8 + + UNPACKLOW_EPI16 + UNPACKHI_EPI16 + + UNPACKLOW_EPI32 + UNPACKHI_EPI32 + + UNPACKLOW_EPI64 + UNPACKHI_EPI64 + + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s32o32 + ( ic + 0 ) * KC + kr ), a_reg[0] ); + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s32o32 + ( ic + 1 ) * KC + kr ), a_reg[1] ); + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s32o32 + ( ic + 2 ) * KC + kr ), a_reg[4] ); + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s32o32 + ( ic + 3 ) * KC + kr ), a_reg[5] ); + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s32o32 + ( ic + 4 ) * KC + kr ), a_reg[2] ); + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s32o32 + ( ic + 5 ) * KC + kr ), a_reg[3] ); + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s32o32 + ( ic + 6 ) * KC + kr ), a_reg[6] ); + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s32o32 + ( ic + 7 ) * KC + kr ), a_reg[7] ); + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s32o32 + ( ic + 8 ) * KC + kr ), a_reg[8] ); + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s32o32 + ( ic + 9 ) * KC + kr ), a_reg[9] ); + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s32o32 + ( ic + 10 ) * KC + kr ), a_reg[12] ); + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s32o32 + ( ic + 11 ) * KC + kr ), a_reg[13] ); + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s32o32 + ( ic + 12 ) * KC + kr ), a_reg[10] ); + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s32o32 + ( ic + 13 ) * KC + kr ), a_reg[11] ); + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s32o32 + ( ic + 14 ) * KC + kr ), a_reg[14] ); + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s32o32 + ( ic + 15 ) * KC + kr ), a_reg[15] ); + + } + + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + // k fringe 8 + if (( kr + 7 ) < KC ) + { + a_reg[0] = _mm_loadu_si128 ( (__m128i const *) ( a + ( ic * rs ) + ( ( kr + 0 ) * cs ) ) ); + a_reg[1] = _mm_loadu_si128 ( (__m128i const *) ( a + ( ic * rs ) + ( ( kr + 1 ) * cs ) ) ); + a_reg[2] = _mm_loadu_si128 ( (__m128i const *) ( a + ( ic * rs ) + ( ( kr + 2 ) * cs ) ) ); + a_reg[3] = _mm_loadu_si128 ( (__m128i const *) ( a + ( ic * rs ) + ( ( kr + 3 ) * cs ) ) ); + a_reg[4] = _mm_loadu_si128 ( (__m128i const *) ( a + ( ic * rs ) + ( ( kr + 4 ) * cs ) ) ); + a_reg[5] = _mm_loadu_si128 ( (__m128i const *) ( a + ( ic * rs ) + ( ( kr + 5 ) * cs ) ) ); + a_reg[6] = _mm_loadu_si128 ( (__m128i const *) ( a + ( ic * rs ) + ( ( kr + 6 ) * cs ) ) ); + a_reg[7] = _mm_loadu_si128 ( (__m128i const *) ( a + ( ic * rs ) + ( ( kr + 7 ) * cs ) ) ); + + // Transpose operations + UNPACKLOW_EPI8 + UNPACKHI_EPI8 + + UNPACKLOW_EPI16 + UNPACKHI_EPI16 + + UNPACKLOW_EPI32 + UNPACKHI_EPI32 + + UNPACKLOW_EPI64 + UNPACKHI_EPI64 + + MASKED_STORE_EPI32(0x03); + + kr += 8; + } + + // k fringe 4 + if ( ( kr + 3 ) < KC ) + { + a_reg[0] = _mm_loadu_si128( (__m128i const *)( a + ( ic * rs ) + ( ( kr + 0 ) * cs ) ) ); + a_reg[1] = _mm_loadu_si128( (__m128i const *)( a + ( ic * rs ) + ( ( kr + 1 ) * cs ) ) ); + a_reg[2] = _mm_loadu_si128( (__m128i const *)( a + ( ic * rs ) + ( ( kr + 2 ) * cs ) ) ); + a_reg[3] = _mm_loadu_si128( (__m128i const *)( a + ( ic * rs ) + ( ( kr + 3 ) * cs ) ) ); + + // Transpose operations + UNPACKLOW_EPI8 + UNPACKHI_EPI8 + + UNPACKLOW_EPI16 + UNPACKHI_EPI16 + + UNPACKLOW_EPI32 + UNPACKHI_EPI32 + + UNPACKLOW_EPI64 + UNPACKHI_EPI64 + + MASKED_STORE_EPI32(0x01); + + kr += 4; + } + + // k fringe 2 + if ( ( kr + 1 ) < KC ) + { + a_reg[0] = _mm_loadu_si128( (__m128i const *)( a + ( ic * rs ) + ( ( kr + 0 ) * cs ) ) ); + a_reg[1] = _mm_loadu_si128( (__m128i const *)( a + ( ic * rs ) + ( ( kr + 1 ) * cs ) ) ); + + // Transpose operations + UNPACKLOW_EPI8 + UNPACKHI_EPI8 + + UNPACKLOW_EPI16 + UNPACKHI_EPI16 + + UNPACKLOW_EPI32 + UNPACKHI_EPI32 + + UNPACKLOW_EPI64 + UNPACKHI_EPI64 + + MASKED_STORE_EPI16(0x01); + + kr += 2; + } + + // k fringe 1 + if ( ( kr ) < KC ) + { + a_reg[0] = _mm_loadu_si128( (__m128i const *)( a + ( ic * rs ) + ( ( kr + 0 ) * cs ) ) ); + + // Transpose operations + UNPACKLOW_EPI8 + UNPACKHI_EPI8 + + UNPACKLOW_EPI16 + UNPACKHI_EPI16 + + UNPACKLOW_EPI32 + UNPACKHI_EPI32 + + UNPACKLOW_EPI64 + UNPACKHI_EPI64 + + MASKED_STORE_EPI8(0x01); + + kr += 1; + } + } + } + + if( m_partial_pieces > 0 ) + { + if ( ( ic + 8 - 1 ) < MC ) + { + packa_mr8_u8s8s32o32_col_major + ( + ( pack_a_buffer_u8s8s32o32 + ( ic * KC ) ), + ( a + ic * rs ), cs, KC + ); + + ic += 8; + } + + if ( ( ic + 4 - 1 ) < MC ) + { + packa_mr4_u8s8s32o32_col_major + ( + ( pack_a_buffer_u8s8s32o32 + ( ic * KC ) ), + ( a + ic * rs ), cs, KC + ); + + ic += 4; + } + + if ( m_left ) + { + packa_mrlt4_u8s8s32o32_col_major + ( + ( pack_a_buffer_u8s8s32o32 + ( ic * KC ) ), + ( a + ic * rs ), cs, KC, m_left + ); + } + } + + *rs_a = KC; + *cs_a = 4; +} + +void packa_mr8_u8s8s32o32_col_major + ( + uint8_t* pack_a_buffer_u8s8s32o32, + const uint8_t* a, + const dim_t cs, + const dim_t KC + ) +{ + //printf("in mr 8 - "); + dim_t kr = 0; + __m128i a_reg[16], b_reg[16]; + + dim_t k_partial_pieces = KC % 16; + + SET_REGISTERS_ZERO + + for( kr = 0; ( kr + 15 ) < KC; kr += 16 ) + { + a_reg[0] = _mm_maskz_loadu_epi8( 0xFF, a + ( ( kr + 0 ) * cs ) ); + a_reg[1] = _mm_maskz_loadu_epi8( 0xFF, a + ( ( kr + 1 ) * cs ) ); + a_reg[2] = _mm_maskz_loadu_epi8( 0xFF, a + ( ( kr + 2 ) * cs ) ); + a_reg[3] = _mm_maskz_loadu_epi8( 0xFF, a + ( ( kr + 3 ) * cs ) ); + a_reg[4] = _mm_maskz_loadu_epi8( 0xFF, a + ( ( kr + 4 ) * cs ) ); + a_reg[5] = _mm_maskz_loadu_epi8( 0xFF, a + ( ( kr + 5 ) * cs ) ); + a_reg[6] = _mm_maskz_loadu_epi8( 0xFF, a + ( ( kr + 6 ) * cs ) ); + a_reg[7] = _mm_maskz_loadu_epi8( 0xFF, a + ( ( kr + 7 ) * cs ) ); + a_reg[8] = _mm_maskz_loadu_epi8( 0xFF, a + ( ( kr + 8 ) * cs ) ); + a_reg[9] = _mm_maskz_loadu_epi8( 0xFF, a + ( ( kr + 9 ) * cs ) ); + a_reg[10] = _mm_maskz_loadu_epi8( 0xFF, a + ( ( kr + 10 ) * cs ) ); + a_reg[11] = _mm_maskz_loadu_epi8( 0xFF, a + ( ( kr + 11 ) * cs ) ); + a_reg[12] = _mm_maskz_loadu_epi8( 0xFF, a + ( ( kr + 12 ) * cs ) ); + a_reg[13] = _mm_maskz_loadu_epi8( 0xFF, a + ( ( kr + 13 ) * cs ) ); + a_reg[14] = _mm_maskz_loadu_epi8( 0xFF, a + ( ( kr + 14 ) * cs ) ); + a_reg[15] = _mm_maskz_loadu_epi8( 0xFF, a + ( ( kr + 15 ) * cs ) ); + + // Transpose operations + UNPACKLOW_EPI8 + + UNPACKLOW_EPI16_MR8 + UNPACKHI_EPI16_MR8 + + UNPACKLOW_EPI32_MR8 + UNPACKHI_EPI32_MR8 + + UNPACKLOW_EPI64_MR8 + UNPACKHI_EPI64_MR8 + + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s32o32 + ( 0 ) * KC + kr ), a_reg[0] ); + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s32o32 + ( 1 ) * KC + kr ), a_reg[1] ); + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s32o32 + ( 2 ) * KC + kr ), a_reg[4] ); + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s32o32 + ( 3 ) * KC + kr ), a_reg[5] ); + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s32o32 + ( 4 ) * KC + kr ), a_reg[2] ); + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s32o32 + ( 5 ) * KC + kr ), a_reg[3] ); + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s32o32 + ( 6 ) * KC + kr ), a_reg[6] ); + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s32o32 + ( 7 ) * KC + kr ), a_reg[7] ); + } + + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + // k fringe 8 + if ( ( kr + 7 ) < KC ) + { + a_reg[0] = _mm_maskz_loadu_epi8( 0xFF, a + ( ( kr + 0 ) * cs ) ); + a_reg[1] = _mm_maskz_loadu_epi8( 0xFF, a + ( ( kr + 1 ) * cs ) ); + a_reg[2] = _mm_maskz_loadu_epi8( 0xFF, a + ( ( kr + 2 ) * cs ) ); + a_reg[3] = _mm_maskz_loadu_epi8( 0xFF, a + ( ( kr + 3 ) * cs ) ); + a_reg[4] = _mm_maskz_loadu_epi8( 0xFF, a + ( ( kr + 4 ) * cs ) ); + a_reg[5] = _mm_maskz_loadu_epi8( 0xFF, a + ( ( kr + 5 ) * cs ) ); + a_reg[6] = _mm_maskz_loadu_epi8( 0xFF, a + ( ( kr + 6 ) * cs ) ); + a_reg[7] = _mm_maskz_loadu_epi8( 0xFF, a + ( ( kr + 7 ) * cs ) ); + + // Transpose operations + UNPACKLOW_EPI8 + + UNPACKLOW_EPI16_MR8 + UNPACKHI_EPI16_MR8 + + UNPACKLOW_EPI32_MR8 + UNPACKHI_EPI32_MR8 + + UNPACKLOW_EPI64_MR8 + UNPACKHI_EPI64_MR8 + + _mm_mask_storeu_epi32( ( pack_a_buffer_u8s8s32o32 + ( 0 ) * KC + kr ), 0x03, a_reg[0] ); + _mm_mask_storeu_epi32( ( pack_a_buffer_u8s8s32o32 + ( 1 ) * KC + kr ), 0x03, a_reg[1] ); + _mm_mask_storeu_epi32( ( pack_a_buffer_u8s8s32o32 + ( 2 ) * KC + kr ), 0x03, a_reg[4] ); + _mm_mask_storeu_epi32( ( pack_a_buffer_u8s8s32o32 + ( 3 ) * KC + kr ), 0x03, a_reg[5] ); + _mm_mask_storeu_epi32( ( pack_a_buffer_u8s8s32o32 + ( 4 ) * KC + kr ), 0x03, a_reg[2] ); + _mm_mask_storeu_epi32( ( pack_a_buffer_u8s8s32o32 + ( 5 ) * KC + kr ), 0x03, a_reg[3] ); + _mm_mask_storeu_epi32( ( pack_a_buffer_u8s8s32o32 + ( 6 ) * KC + kr ), 0x03, a_reg[6] ); + _mm_mask_storeu_epi32( ( pack_a_buffer_u8s8s32o32 + ( 7 ) * KC + kr ), 0x03, a_reg[7] ); + + kr += 8; + } + + // k fringe 4 + if ( ( kr + 3 ) < KC ) + { + a_reg[0] = _mm_maskz_loadu_epi8( 0xFF, a + ( ( kr + 0 ) * cs ) ); + a_reg[1] = _mm_maskz_loadu_epi8( 0xFF, a + ( ( kr + 1 ) * cs ) ); + a_reg[2] = _mm_maskz_loadu_epi8( 0xFF, a + ( ( kr + 2 ) * cs ) ); + a_reg[3] = _mm_maskz_loadu_epi8( 0xFF, a + ( ( kr + 3 ) * cs ) ); + + // Transpose operations + UNPACKLOW_EPI8 + + UNPACKLOW_EPI16_MR8 + UNPACKHI_EPI16_MR8 + + UNPACKLOW_EPI32_MR8 + UNPACKHI_EPI32_MR8 + + UNPACKLOW_EPI64_MR8 + UNPACKHI_EPI64_MR8 + + _mm_mask_storeu_epi32( ( pack_a_buffer_u8s8s32o32 + ( 0 ) * KC + kr ), 0x01, a_reg[0] ); + _mm_mask_storeu_epi32( ( pack_a_buffer_u8s8s32o32 + ( 1 ) * KC + kr ), 0x01, a_reg[1] ); + _mm_mask_storeu_epi32( ( pack_a_buffer_u8s8s32o32 + ( 2 ) * KC + kr ), 0x01, a_reg[4] ); + _mm_mask_storeu_epi32( ( pack_a_buffer_u8s8s32o32 + ( 3 ) * KC + kr ), 0x01, a_reg[5] ); + _mm_mask_storeu_epi32( ( pack_a_buffer_u8s8s32o32 + ( 4 ) * KC + kr ), 0x01, a_reg[2] ); + _mm_mask_storeu_epi32( ( pack_a_buffer_u8s8s32o32 + ( 5 ) * KC + kr ), 0x01, a_reg[3] ); + _mm_mask_storeu_epi32( ( pack_a_buffer_u8s8s32o32 + ( 6 ) * KC + kr ), 0x01, a_reg[6] ); + _mm_mask_storeu_epi32( ( pack_a_buffer_u8s8s32o32 + ( 7 ) * KC + kr ), 0x01, a_reg[7] ); + + kr += 4; + } + + // k fringe 2 + if ( ( kr + 1 ) < KC ) + { + a_reg[0] = _mm_maskz_loadu_epi8( 0xFF, a + ( ( kr + 0 ) * cs ) ); + a_reg[1] = _mm_maskz_loadu_epi8( 0xFF, a + ( ( kr + 1 ) * cs ) ); + + // Transpose operations + UNPACKLOW_EPI8 + + UNPACKLOW_EPI16_MR8 + UNPACKHI_EPI16_MR8 + + UNPACKLOW_EPI32_MR8 + UNPACKHI_EPI32_MR8 + + UNPACKLOW_EPI64_MR8 + UNPACKHI_EPI64_MR8 + + _mm_mask_storeu_epi16( ( pack_a_buffer_u8s8s32o32 + ( 0 ) * KC + kr ), 0x01, a_reg[0] ); + _mm_mask_storeu_epi16( ( pack_a_buffer_u8s8s32o32 + ( 1 ) * KC + kr ), 0x01, a_reg[1] ); + _mm_mask_storeu_epi16( ( pack_a_buffer_u8s8s32o32 + ( 2 ) * KC + kr ), 0x01, a_reg[4] ); + _mm_mask_storeu_epi16( ( pack_a_buffer_u8s8s32o32 + ( 3 ) * KC + kr ), 0x01, a_reg[5] ); + _mm_mask_storeu_epi16( ( pack_a_buffer_u8s8s32o32 + ( 4 ) * KC + kr ), 0x01, a_reg[2] ); + _mm_mask_storeu_epi16( ( pack_a_buffer_u8s8s32o32 + ( 5 ) * KC + kr ), 0x01, a_reg[3] ); + _mm_mask_storeu_epi16( ( pack_a_buffer_u8s8s32o32 + ( 6 ) * KC + kr ), 0x01, a_reg[6] ); + _mm_mask_storeu_epi16( ( pack_a_buffer_u8s8s32o32 + ( 7 ) * KC + kr ), 0x01, a_reg[7] ); + + kr += 2; + } + + // k fringe 1 + if ( ( kr ) < KC ) + { + a_reg[0] = _mm_maskz_loadu_epi8( 0xFF, a + ( ( kr + 0 ) * cs ) ); + + // Transpose operations + UNPACKLOW_EPI8 + + UNPACKLOW_EPI16_MR8 + UNPACKHI_EPI16_MR8 + + UNPACKLOW_EPI32_MR8 + UNPACKHI_EPI32_MR8 + + UNPACKLOW_EPI64_MR8 + UNPACKHI_EPI64_MR8 + + _mm_mask_storeu_epi8( ( pack_a_buffer_u8s8s32o32 + ( 0 ) * KC + kr ), 0x01, a_reg[0] ); + _mm_mask_storeu_epi8( ( pack_a_buffer_u8s8s32o32 + ( 1 ) * KC + kr ), 0x01, a_reg[1] ); + _mm_mask_storeu_epi8( ( pack_a_buffer_u8s8s32o32 + ( 2 ) * KC + kr ), 0x01, a_reg[4] ); + _mm_mask_storeu_epi8( ( pack_a_buffer_u8s8s32o32 + ( 3 ) * KC + kr ), 0x01, a_reg[5] ); + _mm_mask_storeu_epi8( ( pack_a_buffer_u8s8s32o32 + ( 4 ) * KC + kr ), 0x01, a_reg[2] ); + _mm_mask_storeu_epi8( ( pack_a_buffer_u8s8s32o32 + ( 5 ) * KC + kr ), 0x01, a_reg[3] ); + _mm_mask_storeu_epi8( ( pack_a_buffer_u8s8s32o32 + ( 6 ) * KC + kr ), 0x01, a_reg[6] ); + _mm_mask_storeu_epi8( ( pack_a_buffer_u8s8s32o32 + ( 7 ) * KC + kr ), 0x01, a_reg[7] ); + + kr += 1; + } + } +} + + +void packa_mr4_u8s8s32o32_col_major + ( + uint8_t* pack_a_buffer_u8s8s32o32, + const uint8_t* a, + const dim_t cs, + const dim_t KC + ) +{ + dim_t kr = 0; + __m128i a_reg[16], b_reg[16]; + + dim_t k_partial_pieces = KC % 16; + + SET_REGISTERS_ZERO + + for( kr = 0; ( kr + 15 ) < KC; kr += 16 ) + { + a_reg[0] = _mm_maskz_loadu_epi8( 0x0F, a + ( ( kr + 0 ) * cs ) ); + a_reg[1] = _mm_maskz_loadu_epi8( 0x0F, a + ( ( kr + 1 ) * cs ) ); + a_reg[2] = _mm_maskz_loadu_epi8( 0x0F, a + ( ( kr + 2 ) * cs ) ); + a_reg[3] = _mm_maskz_loadu_epi8( 0x0F, a + ( ( kr + 3 ) * cs ) ); + a_reg[4] = _mm_maskz_loadu_epi8( 0x0F, a + ( ( kr + 4 ) * cs ) ); + a_reg[5] = _mm_maskz_loadu_epi8( 0x0F, a + ( ( kr + 5 ) * cs ) ); + a_reg[6] = _mm_maskz_loadu_epi8( 0x0F, a + ( ( kr + 6 ) * cs ) ); + a_reg[7] = _mm_maskz_loadu_epi8( 0x0F, a + ( ( kr + 7 ) * cs ) ); + a_reg[8] = _mm_maskz_loadu_epi8( 0x0F, a + ( ( kr + 8 ) * cs ) ); + a_reg[9] = _mm_maskz_loadu_epi8( 0x0F, a + ( ( kr + 9 ) * cs ) ); + a_reg[10] = _mm_maskz_loadu_epi8( 0x0F, a + ( ( kr + 10 ) * cs ) ); + a_reg[11] = _mm_maskz_loadu_epi8( 0x0F, a + ( ( kr + 11 ) * cs ) ); + a_reg[12] = _mm_maskz_loadu_epi8( 0x0F, a + ( ( kr + 12 ) * cs ) ); + a_reg[13] = _mm_maskz_loadu_epi8( 0x0F, a + ( ( kr + 13 ) * cs ) ); + a_reg[14] = _mm_maskz_loadu_epi8( 0x0F, a + ( ( kr + 14 ) * cs ) ); + a_reg[15] = _mm_maskz_loadu_epi8( 0x0F, a + ( ( kr + 15 ) * cs ) ); + + // Transpose operations + UNPACKLOW_EPI8 + + UNPACKLOW_EPI16_MR8 + + UNPACKLOW_EPI32_MR4 + UNPACKHI_EPI32_MR4 + + UNPACKLOW_EPI64_MR4 + UNPACKHI_EPI64_MR4 + + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s32o32 + ( 0 ) * KC + kr ), a_reg[0] ); + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s32o32 + ( 1 ) * KC + kr ), a_reg[1] ); + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s32o32 + ( 2 ) * KC + kr ), a_reg[4] ); + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s32o32 + ( 3 ) * KC + kr ), a_reg[5] ); + } + + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + // k fringe 8 + if ( ( kr + 7 ) < KC ) + { + a_reg[0] = _mm_maskz_loadu_epi8( 0x0F, a + ( ( kr + 0 ) * cs ) ); + a_reg[1] = _mm_maskz_loadu_epi8( 0x0F, a + ( ( kr + 1 ) * cs ) ); + a_reg[2] = _mm_maskz_loadu_epi8( 0x0F, a + ( ( kr + 2 ) * cs ) ); + a_reg[3] = _mm_maskz_loadu_epi8( 0x0F, a + ( ( kr + 3 ) * cs ) ); + a_reg[4] = _mm_maskz_loadu_epi8( 0x0F, a + ( ( kr + 4 ) * cs ) ); + a_reg[5] = _mm_maskz_loadu_epi8( 0x0F, a + ( ( kr + 5 ) * cs ) ); + a_reg[6] = _mm_maskz_loadu_epi8( 0x0F, a + ( ( kr + 6 ) * cs ) ); + a_reg[7] = _mm_maskz_loadu_epi8( 0x0F, a + ( ( kr + 7 ) * cs ) ); + + // Transpose operations + UNPACKLOW_EPI8 + + UNPACKLOW_EPI16_MR8 + + UNPACKLOW_EPI32_MR4 + UNPACKHI_EPI32_MR4 + + UNPACKLOW_EPI64_MR4 + UNPACKHI_EPI64_MR4 + + _mm_mask_storeu_epi32( ( pack_a_buffer_u8s8s32o32 + ( 0 ) * KC + kr ), 0x03, a_reg[0] ); + _mm_mask_storeu_epi32( ( pack_a_buffer_u8s8s32o32 + ( 1 ) * KC + kr ), 0x03, a_reg[1] ); + _mm_mask_storeu_epi32( ( pack_a_buffer_u8s8s32o32 + ( 2 ) * KC + kr ), 0x03, a_reg[4] ); + _mm_mask_storeu_epi32( ( pack_a_buffer_u8s8s32o32 + ( 3 ) * KC + kr ), 0x03, a_reg[5] ); + + kr += 8; + } + + // k fringe 4 + if ( ( kr + 3 ) < KC ) + { + a_reg[0] = _mm_maskz_loadu_epi8( 0x0F, a + ( ( kr + 0 ) * cs ) ); + a_reg[1] = _mm_maskz_loadu_epi8( 0x0F, a + ( ( kr + 1 ) * cs ) ); + a_reg[2] = _mm_maskz_loadu_epi8( 0x0F, a + ( ( kr + 2 ) * cs ) ); + a_reg[3] = _mm_maskz_loadu_epi8( 0x0F, a + ( ( kr + 3 ) * cs ) ); + + // Transpose operations + UNPACKLOW_EPI8 + + UNPACKLOW_EPI16_MR8 + + UNPACKLOW_EPI32_MR4 + UNPACKHI_EPI32_MR4 + + UNPACKLOW_EPI64_MR4 + UNPACKHI_EPI64_MR4 + + _mm_mask_storeu_epi32( ( pack_a_buffer_u8s8s32o32 + ( 0 ) * KC + kr ), 0x01, a_reg[0] ); + _mm_mask_storeu_epi32( ( pack_a_buffer_u8s8s32o32 + ( 1 ) * KC + kr ), 0x01, a_reg[1] ); + _mm_mask_storeu_epi32( ( pack_a_buffer_u8s8s32o32 + ( 2 ) * KC + kr ), 0x01, a_reg[4] ); + _mm_mask_storeu_epi32( ( pack_a_buffer_u8s8s32o32 + ( 3 ) * KC + kr ), 0x01, a_reg[5] ); + + kr += 4; + } + + // k fringe 2 + if ( ( kr + 1 ) < KC ) + { + a_reg[0] = _mm_maskz_loadu_epi8( 0x0F, a + ( ( kr + 0 ) * cs ) ); + a_reg[1] = _mm_maskz_loadu_epi8( 0x0F, a + ( ( kr + 1 ) * cs ) ); + + // Transpose operations + UNPACKLOW_EPI8 + + UNPACKLOW_EPI16_MR8 + + UNPACKLOW_EPI32_MR4 + UNPACKHI_EPI32_MR4 + + UNPACKLOW_EPI64_MR4 + UNPACKHI_EPI64_MR4 + + _mm_mask_storeu_epi16( ( pack_a_buffer_u8s8s32o32 + ( 0 ) * KC + kr ), 0x01, a_reg[0] ); + _mm_mask_storeu_epi16( ( pack_a_buffer_u8s8s32o32 + ( 1 ) * KC + kr ), 0x01, a_reg[1] ); + _mm_mask_storeu_epi16( ( pack_a_buffer_u8s8s32o32 + ( 2 ) * KC + kr ), 0x01, a_reg[4] ); + _mm_mask_storeu_epi16( ( pack_a_buffer_u8s8s32o32 + ( 3 ) * KC + kr ), 0x01, a_reg[5] ); + + kr += 2; + } + + // k fringe 1 + if ( ( kr ) < KC ) + { + a_reg[0] = _mm_maskz_loadu_epi8( 0x0F, a + ( ( kr + 0 ) * cs ) ); + + // Transpose operations + UNPACKLOW_EPI8 + + UNPACKLOW_EPI16_MR8 + + UNPACKLOW_EPI32_MR4 + UNPACKHI_EPI32_MR4 + + UNPACKLOW_EPI64_MR4 + UNPACKHI_EPI64_MR4 + + _mm_mask_storeu_epi8( ( pack_a_buffer_u8s8s32o32 + ( 0 ) * KC + kr ), 0x01, a_reg[0] ); + _mm_mask_storeu_epi8( ( pack_a_buffer_u8s8s32o32 + ( 1 ) * KC + kr ), 0x01, a_reg[1] ); + _mm_mask_storeu_epi8( ( pack_a_buffer_u8s8s32o32 + ( 2 ) * KC + kr ), 0x01, a_reg[4] ); + _mm_mask_storeu_epi8( ( pack_a_buffer_u8s8s32o32 + ( 3 ) * KC + kr ), 0x01, a_reg[5] ); + + kr += 1; + } + } +} + +void packa_mrlt4_u8s8s32o32_col_major + ( + uint8_t* pack_a_buffer_u8s8s32o32, + const uint8_t* a, + const dim_t cs, + const dim_t KC, + const dim_t m_left + ) +{ + __mmask16 mask = 0xFFFF >> ( 16 - m_left ); + dim_t kr = 0; + __m128i a_reg[16], b_reg[16]; + + dim_t k_partial_pieces = KC % 16; + + SET_REGISTERS_ZERO + + for( kr = 0; ( kr + 15 ) < KC; kr += 16 ) + { + a_reg[0] = _mm_maskz_loadu_epi8( mask, a + ( ( kr + 0 ) * cs ) ); + a_reg[1] = _mm_maskz_loadu_epi8( mask, a + ( ( kr + 1 ) * cs ) ); + a_reg[2] = _mm_maskz_loadu_epi8( mask, a + ( ( kr + 2 ) * cs ) ); + a_reg[3] = _mm_maskz_loadu_epi8( mask, a + ( ( kr + 3 ) * cs ) ); + a_reg[4] = _mm_maskz_loadu_epi8( mask, a + ( ( kr + 4 ) * cs ) ); + a_reg[5] = _mm_maskz_loadu_epi8( mask, a + ( ( kr + 5 ) * cs ) ); + a_reg[6] = _mm_maskz_loadu_epi8( mask, a + ( ( kr + 6 ) * cs ) ); + a_reg[7] = _mm_maskz_loadu_epi8( mask, a + ( ( kr + 7 ) * cs ) ); + a_reg[8] = _mm_maskz_loadu_epi8( mask, a + ( ( kr + 8 ) * cs ) ); + a_reg[9] = _mm_maskz_loadu_epi8( mask, a + ( ( kr + 9 ) * cs ) ); + a_reg[10] = _mm_maskz_loadu_epi8( mask, a + ( ( kr + 10 ) * cs ) ); + a_reg[11] = _mm_maskz_loadu_epi8( mask, a + ( ( kr + 11 ) * cs ) ); + a_reg[12] = _mm_maskz_loadu_epi8( mask, a + ( ( kr + 12 ) * cs ) ); + a_reg[13] = _mm_maskz_loadu_epi8( mask, a + ( ( kr + 13 ) * cs ) ); + a_reg[14] = _mm_maskz_loadu_epi8( mask, a + ( ( kr + 14 ) * cs ) ); + a_reg[15] = _mm_maskz_loadu_epi8( mask, a + ( ( kr + 15 ) * cs ) ); + + // Transpose operations + UNPACKLOW_EPI8 + + UNPACKLOW_EPI16_MR8 + + UNPACKLOW_EPI32_MR4 + UNPACKHI_EPI32_MR4 + + UNPACKLOW_EPI64_MR4 + UNPACKHI_EPI64_MR4 + + switch( m_left ) + { + case 3: + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s32o32 + ( 0 ) * KC + kr ), a_reg[0] ); + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s32o32 + ( 1 ) * KC + kr ), a_reg[1] ); + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s32o32 + ( 2 ) * KC + kr ), a_reg[4] ); + break; + + case 2: + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s32o32 + ( 0 ) * KC + kr ), a_reg[0] ); + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s32o32 + ( 1 ) * KC + kr ), a_reg[1] ); + break; + + case 1: + _mm_storeu_si128( (__m128i *)( pack_a_buffer_u8s8s32o32 + ( 0 ) * KC + kr ), a_reg[0] ); + break; + } + } + + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + // k fringe 8 + if ( ( kr + 7 ) < KC ) + { + a_reg[0] = _mm_maskz_loadu_epi8( mask, a + ( ( kr + 0 ) * cs ) ); + a_reg[1] = _mm_maskz_loadu_epi8( mask, a + ( ( kr + 1 ) * cs ) ); + a_reg[2] = _mm_maskz_loadu_epi8( mask, a + ( ( kr + 2 ) * cs ) ); + a_reg[3] = _mm_maskz_loadu_epi8( mask, a + ( ( kr + 3 ) * cs ) ); + a_reg[4] = _mm_maskz_loadu_epi8( mask, a + ( ( kr + 4 ) * cs ) ); + a_reg[5] = _mm_maskz_loadu_epi8( mask, a + ( ( kr + 5 ) * cs ) ); + a_reg[6] = _mm_maskz_loadu_epi8( mask, a + ( ( kr + 6 ) * cs ) ); + a_reg[7] = _mm_maskz_loadu_epi8( mask, a + ( ( kr + 7 ) * cs ) ); + + // Transpose operations + UNPACKLOW_EPI8 + + UNPACKLOW_EPI16_MR8 + + UNPACKLOW_EPI32_MR4 + UNPACKHI_EPI32_MR4 + + UNPACKLOW_EPI64_MR4 + UNPACKHI_EPI64_MR4 + + switch( m_left ) + { + case 3: + _mm_mask_storeu_epi32( ( pack_a_buffer_u8s8s32o32 + ( 0 ) * KC + kr ), 0x03, a_reg[0] ); + _mm_mask_storeu_epi32( ( pack_a_buffer_u8s8s32o32 + ( 1 ) * KC + kr ), 0x03, a_reg[1] ); + _mm_mask_storeu_epi32( ( pack_a_buffer_u8s8s32o32 + ( 2 ) * KC + kr ), 0x03, a_reg[4] ); + break; + + case 2: + _mm_mask_storeu_epi32( ( pack_a_buffer_u8s8s32o32 + ( 0 ) * KC + kr ), 0x03, a_reg[0] ); + _mm_mask_storeu_epi32( ( pack_a_buffer_u8s8s32o32 + ( 1 ) * KC + kr ), 0x03, a_reg[1] ); + break; + + case 1: + _mm_mask_storeu_epi32( ( pack_a_buffer_u8s8s32o32 + ( 0 ) * KC + kr ), 0x03, a_reg[0] ); + break; + } + kr += 8; + } + + // k fringe 4 + if ( ( kr + 3 ) < KC ) + { + a_reg[0] = _mm_maskz_loadu_epi8( mask, a + ( ( kr + 0 ) * cs ) ); + a_reg[1] = _mm_maskz_loadu_epi8( mask, a + ( ( kr + 1 ) * cs ) ); + a_reg[2] = _mm_maskz_loadu_epi8( mask, a + ( ( kr + 2 ) * cs ) ); + a_reg[3] = _mm_maskz_loadu_epi8( mask, a + ( ( kr + 3 ) * cs ) ); + + // Transpose operations + UNPACKLOW_EPI8 + + UNPACKLOW_EPI16_MR8 + + UNPACKLOW_EPI32_MR4 + UNPACKHI_EPI32_MR4 + + UNPACKLOW_EPI64_MR4 + UNPACKHI_EPI64_MR4 + + switch( m_left ) + { + case 3: + _mm_mask_storeu_epi32( ( pack_a_buffer_u8s8s32o32 + ( 0 ) * KC + kr ), 0x01, a_reg[0] ); + _mm_mask_storeu_epi32( ( pack_a_buffer_u8s8s32o32 + ( 1 ) * KC + kr ), 0x01, a_reg[1] ); + _mm_mask_storeu_epi32( ( pack_a_buffer_u8s8s32o32 + ( 2 ) * KC + kr ), 0x01, a_reg[4] ); + break; + + case 2: + _mm_mask_storeu_epi32( ( pack_a_buffer_u8s8s32o32 + ( 0 ) * KC + kr ), 0x01, a_reg[0] ); + _mm_mask_storeu_epi32( ( pack_a_buffer_u8s8s32o32 + ( 1 ) * KC + kr ), 0x01, a_reg[1] ); + break; + + case 1: + _mm_mask_storeu_epi32( ( pack_a_buffer_u8s8s32o32 + ( 0 ) * KC + kr ), 0x01, a_reg[0] ); + break; + } + kr += 4; + } + + // k fringe 2 + if ( ( kr + 1 ) < KC ) + { + a_reg[0] = _mm_maskz_loadu_epi8( mask, a + ( ( kr + 0 ) * cs ) ); + a_reg[1] = _mm_maskz_loadu_epi8( mask, a + ( ( kr + 1 ) * cs ) ); + + // Transpose operations + UNPACKLOW_EPI8 + + UNPACKLOW_EPI16_MR8 + + UNPACKLOW_EPI32_MR4 + UNPACKHI_EPI32_MR4 + + UNPACKLOW_EPI64_MR4 + UNPACKHI_EPI64_MR4 + + switch( m_left ) + { + case 3: + _mm_mask_storeu_epi16( ( pack_a_buffer_u8s8s32o32 + ( 0 ) * KC + kr ), 0x01, a_reg[0] ); + _mm_mask_storeu_epi16( ( pack_a_buffer_u8s8s32o32 + ( 1 ) * KC + kr ), 0x01, a_reg[1] ); + _mm_mask_storeu_epi16( ( pack_a_buffer_u8s8s32o32 + ( 2 ) * KC + kr ), 0x01, a_reg[4] ); + break; + + case 2: + _mm_mask_storeu_epi16( ( pack_a_buffer_u8s8s32o32 + ( 0 ) * KC + kr ), 0x01, a_reg[0] ); + _mm_mask_storeu_epi16( ( pack_a_buffer_u8s8s32o32 + ( 1 ) * KC + kr ), 0x01, a_reg[1] ); + break; + + case 1: + _mm_mask_storeu_epi16( ( pack_a_buffer_u8s8s32o32 + ( 0 ) * KC + kr ), 0x01, a_reg[0] ); + break; + } + kr += 2; + } + + // k fringe 1 + if ( ( kr ) < KC ) + { + a_reg[0] = _mm_maskz_loadu_epi8( mask, a + ( ( kr + 0 ) * cs ) ); + + // Transpose operations + UNPACKLOW_EPI8 + + UNPACKLOW_EPI16_MR8 + + UNPACKLOW_EPI32_MR4 + UNPACKHI_EPI32_MR4 + + UNPACKLOW_EPI64_MR4 + UNPACKHI_EPI64_MR4 + + switch( m_left ) + { + case 3: + _mm_mask_storeu_epi8( ( pack_a_buffer_u8s8s32o32 + ( 0 ) * KC + kr ), 0x01, a_reg[0] ); + _mm_mask_storeu_epi8( ( pack_a_buffer_u8s8s32o32 + ( 1 ) * KC + kr ), 0x01, a_reg[1] ); + _mm_mask_storeu_epi8( ( pack_a_buffer_u8s8s32o32 + ( 2 ) * KC + kr ), 0x01, a_reg[4] ); + break; + + case 2: + _mm_mask_storeu_epi8( ( pack_a_buffer_u8s8s32o32 + ( 0 ) * KC + kr ), 0x01, a_reg[0] ); + _mm_mask_storeu_epi8( ( pack_a_buffer_u8s8s32o32 + ( 1 ) * KC + kr ), 0x01, a_reg[1] ); + break; + + case 1: + _mm_mask_storeu_epi8( ( pack_a_buffer_u8s8s32o32 + ( 0 ) * KC + kr ), 0x01, a_reg[0] ); + break; + } + kr += 1; + } + } +} + #endif diff --git a/kernels/zen4/lpgemm/u8s8s32/lpgemm_packb_amd512vnni.c b/kernels/zen4/lpgemm/u8s8s32/lpgemm_packb_amd512vnni.c index 06a1c9ba52..0a87245c90 100644 --- a/kernels/zen4/lpgemm/u8s8s32/lpgemm_packb_amd512vnni.c +++ b/kernels/zen4/lpgemm/u8s8s32/lpgemm_packb_amd512vnni.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -38,746 +38,2192 @@ #ifdef BLIS_ADDON_LPGEMM -#define NR 64 +#include "lpgemm_s32_pack_macros.h" -void packb_nrlt16_u8s8s32o32 +void packb_nrlt16_u8s8s32o32_row_major ( - int8_t* pack_b_buffer_u8s8s32o32, + int8_t* pack_b_buffer, const int8_t* b, - const dim_t ldb, + const dim_t rs_b, const dim_t KC, - const dim_t n0_partial_rem + const dim_t n0_partial_rem, + bool int4_upscale, + bool signed_upscale ); -void packb_nr16_u8s8s32o32 +void packb_nr16_u8s8s32o32_row_major ( - int8_t* pack_b_buffer_u8s8s32o32, + int8_t* pack_b_buffer, const int8_t* b, - const dim_t ldb, - const dim_t KC + const dim_t rs_b, + const dim_t KC, + bool int4_upscale, + bool signed_upscale ); -void packb_nr32_u8s8s32o32 +void packb_nr32_u8s8s32o32_row_major ( - int8_t* pack_b_buffer_u8s8s32o32, + int8_t* pack_b_buffer, const int8_t* b, - const dim_t ldb, - const dim_t KC + const dim_t rs_b, + const dim_t KC, + bool int4_upscale, + bool signed_upscale ); -void packb_nr48_u8s8s32o32 +void packb_nr48_u8s8s32o32_row_major ( - int8_t* pack_b_buffer_u8s8s32o32, + int8_t* pack_b_buffer, const int8_t* b, - const dim_t ldb, - const dim_t KC + const dim_t rs_b, + const dim_t KC, + bool int4_upscale, + bool signed_upscale ); +void packb_nr64_u8s8s32o32_row_major + ( + int8_t *pack_b_buffer, + const int8_t *b, + const dim_t rs_b, + const dim_t NC, + const dim_t KC, + dim_t *rs_p, + dim_t *cs_p, + bool int4_upscale, + bool signed_upscale + ); + +void packb_nr64_u8s8s32o32_col_major( + int8_t *pack_b_buffer, + const int8_t *b, + const dim_t rs_b, + const dim_t NC, + const dim_t KC, + dim_t *rs_p, + dim_t *cs_p); + +void packb_nrlt16_u8s8s32o32_col_major( + int8_t *pack_b_buffer, + const int8_t *b, + const dim_t rs_b, + const dim_t KC, + const dim_t n0_partial_rem); + +void packb_nr_mult_16_u8s8s32o32_col_major( + int8_t *pack_b_buffer, + const int8_t *b, + const dim_t NR, + const dim_t ldb, + const dim_t KC); + +void packb_nrlt16_u8s8s32o32_col_major( + int8_t *pack_b_buffer, + const int8_t *b, + const dim_t ldb, + const dim_t KC, + const dim_t n0_partial_rem); + void packb_nr64_u8s8s32o32 ( - int8_t* pack_b_buffer_u8s8s32o32, - const int8_t* b, - const dim_t ldb, - const dim_t NC, - const dim_t KC, - dim_t* rs_b, - dim_t* cs_b + int8_t *pack_b_buffer, + const int8_t *b, + const dim_t rs_b, + const dim_t cs_b, + const dim_t NC, + const dim_t KC, + dim_t *rs_p, + dim_t *cs_p ) { - // Used for permuting the mm512i elements for use in vpdpbusd instruction. - // These are indexes of the format a0-a1-b0-b1-a2-a3-b2-b3 and a0-a1-a2-a3-b0-b1-b2-b3. - // Adding int32 wise all4 gives format a4-a5-b4-b5-a6-a7-b6-b7 and a4-a5-a6-a7-b4-b5-b6-b7. - __m512i selector1 = _mm512_setr_epi64( 0x0, 0x1, 0x8, 0x9, 0x2, 0x3, 0xA, 0xB ); - __m512i selector1_1 = _mm512_setr_epi64( 0x4, 0x5, 0xC, 0xD, 0x6, 0x7, 0xE, 0xF ); - - __m512i selector2 = _mm512_setr_epi64( 0x0, 0x1, 0x2, 0x3, 0x8, 0x9, 0xA, 0xB ); - __m512i selector2_1 = _mm512_setr_epi64( 0x4, 0x5, 0x6, 0x7, 0xC, 0xD, 0xE, 0xF ); - - dim_t n_full_pieces = NC / NR; - dim_t n_full_pieces_loop_limit = n_full_pieces * NR; - dim_t n_partial_pieces = NC % NR; - - dim_t k_full_pieces_blks = KC / 4; - dim_t k_full_pieces = k_full_pieces_blks * 4; - dim_t k_partial_pieces = KC % 4; - - // KC when not multiple of 4 will have padding to make it multiple of 4 in packed buffer. - dim_t KC_updated = KC; - if ( k_partial_pieces > 0 ) - { - KC_updated += ( 4 - k_partial_pieces ); - } - - __m512i a0; - __m512i b0; - __m512i c0; - __m512i d0; - __m512i a01; - __m512i c01; - - for ( dim_t jc = 0; jc < n_full_pieces_loop_limit; jc += NR ) - { - for ( dim_t kr = 0; kr < k_full_pieces; kr += 4 ) - { - // Rearrange for vpdpbusd, read 4 rows from B with 64 elements in each row. - a0 = _mm512_loadu_si512( b + ( ldb * ( kr + 0 ) ) + jc ); - b0 = _mm512_loadu_si512( b + ( ldb * ( kr + 1 ) ) + jc ); - c0 = _mm512_loadu_si512( b + ( ldb * ( kr + 2 ) ) + jc ); - d0 = _mm512_loadu_si512( b + ( ldb * ( kr + 3 ) ) + jc ); - - a01 = _mm512_unpacklo_epi8( a0, b0 ); - a0 = _mm512_unpackhi_epi8( a0, b0 ); - - c01 = _mm512_unpacklo_epi8( c0, d0 ); - c0 = _mm512_unpackhi_epi8( c0, d0 ); - - b0 = _mm512_unpacklo_epi16( a01, c01 ); - a01 = _mm512_unpackhi_epi16( a01, c01 ); - - d0 = _mm512_unpacklo_epi16( a0, c0 ); - c01 = _mm512_unpackhi_epi16( a0, c0 ); - - a0 = _mm512_permutex2var_epi64( b0, selector1, a01 ); - c0 = _mm512_permutex2var_epi64( d0, selector1, c01 ); - b0 = _mm512_permutex2var_epi64( b0, selector1_1, a01 ); - d0 = _mm512_permutex2var_epi64( d0, selector1_1, c01 ); - - a01 = _mm512_permutex2var_epi64( a0, selector2, c0 ); // b[0] - c01 = _mm512_permutex2var_epi64( b0, selector2, d0 ); // b[2] - a0 = _mm512_permutex2var_epi64( a0, selector2_1, c0 ); // b[1] - c0 = _mm512_permutex2var_epi64( b0, selector2_1, d0 ); // b[3] - - _mm512_storeu_si512( pack_b_buffer_u8s8s32o32 + ( ( jc * KC_updated ) + ( ( kr + 0 ) * NR ) ), a01 ); - _mm512_storeu_si512( pack_b_buffer_u8s8s32o32 + ( ( jc * KC_updated ) + ( ( kr + 1 ) * NR ) ) , a0 ); - _mm512_storeu_si512( pack_b_buffer_u8s8s32o32 + ( ( jc * KC_updated ) + ( ( kr + 2 ) * NR ) ), c01 ); - _mm512_storeu_si512( pack_b_buffer_u8s8s32o32 + ( ( jc * KC_updated ) + ( ( kr + 3 ) * NR ) ), c0 ); - } - // Handle k remainder. - if ( k_partial_pieces > 0 ) - { - if ( k_partial_pieces == 3 ) - { - a0 = _mm512_loadu_si512( b + ( ldb * ( k_full_pieces + 0 ) ) + jc ); - b0 = _mm512_loadu_si512( b + ( ldb * ( k_full_pieces + 1 ) ) + jc ); - c0 = _mm512_loadu_si512( b + ( ldb * ( k_full_pieces + 2 ) ) + jc ); - d0 = _mm512_setzero_si512(); - - } - else if( k_partial_pieces == 2 ) - { - a0 = _mm512_loadu_si512( b + ( ldb * ( k_full_pieces + 0 ) ) + jc ); - b0 = _mm512_loadu_si512( b + ( ldb * ( k_full_pieces + 1 ) ) + jc ); - c0 = _mm512_setzero_si512(); - d0 = _mm512_setzero_si512(); - } - else //k_partial_pieces == 1 - { - a0 = _mm512_loadu_si512( b + ( ldb * ( k_full_pieces + 0 ) ) + jc ); - b0 = _mm512_setzero_si512(); - c0 = _mm512_setzero_si512(); - d0 = _mm512_setzero_si512(); - } - - a01 = _mm512_unpacklo_epi8( a0, b0 ); - a0 = _mm512_unpackhi_epi8( a0, b0 ); - - c01 = _mm512_unpacklo_epi8( c0, d0 ); - c0 = _mm512_unpackhi_epi8( c0, d0 ); - - b0 = _mm512_unpacklo_epi16( a01, c01 ); - a01 = _mm512_unpackhi_epi16( a01, c01 ); - - d0 = _mm512_unpacklo_epi16( a0, c0 ); - c01 = _mm512_unpackhi_epi16( a0, c0 ); - - a0 = _mm512_permutex2var_epi64( b0, selector1, a01 ); - c0 = _mm512_permutex2var_epi64( d0, selector1, c01 ); - b0 = _mm512_permutex2var_epi64( b0, selector1_1, a01 ); - d0 = _mm512_permutex2var_epi64( d0, selector1_1, c01 ); - - a01 = _mm512_permutex2var_epi64( a0, selector2, c0 ); // b[0] - c01 = _mm512_permutex2var_epi64( b0, selector2, d0 ); // b[2] - a0 = _mm512_permutex2var_epi64( a0, selector2_1, c0 ); // b[1] - c0 = _mm512_permutex2var_epi64( b0, selector2_1, d0 ); // b[3] - - _mm512_storeu_si512( pack_b_buffer_u8s8s32o32 + ( ( jc * KC_updated ) + ( ( k_full_pieces + 0 ) * NR ) ), a01 ); - _mm512_storeu_si512( pack_b_buffer_u8s8s32o32 + ( ( jc * KC_updated ) + ( ( k_full_pieces + 1 ) * NR ) ) , a0 ); - _mm512_storeu_si512( pack_b_buffer_u8s8s32o32 + ( ( jc * KC_updated ) + ( ( k_full_pieces + 2 ) * NR ) ), c01 ); - _mm512_storeu_si512( pack_b_buffer_u8s8s32o32 + ( ( jc * KC_updated ) + ( ( k_full_pieces + 3 ) * NR ) ), c0 ); - } - } - - // Contiguous packing of fringe panel (n` < NR). - if ( n_partial_pieces > 0 ) - { - dim_t n0_partial_rem = n_partial_pieces % 16; - dim_t n0_partial_pack = 0; - - // Split into multiple smaller fringe kernels, so as to maximize - // vectorization after packing. Any n0 < NR(64) can be expressed - // as n0 = 48 + n` / n0 = 32 + n` / n0 = 16 + n`, where n` < 16. - dim_t n0_48 = n_partial_pieces / 48; - dim_t n0_32 = n_partial_pieces / 32; - dim_t n0_16 = n_partial_pieces / 16; - - if ( n0_48 == 1 ) - { - packb_nr48_u8s8s32o32 - ( - ( pack_b_buffer_u8s8s32o32 + ( n_full_pieces_loop_limit * KC_updated ) ), - ( b + n_full_pieces_loop_limit ), ldb, KC - ); - - n0_partial_pack = 48; - } - else if ( n0_32 == 1 ) - { - packb_nr32_u8s8s32o32 - ( - ( pack_b_buffer_u8s8s32o32 + ( n_full_pieces_loop_limit * KC_updated ) ), - ( b + n_full_pieces_loop_limit ), ldb, KC - ); - - n0_partial_pack = 32; - } - else if ( n0_16 == 1 ) - { - packb_nr16_u8s8s32o32 - ( - ( pack_b_buffer_u8s8s32o32 + ( n_full_pieces_loop_limit * KC_updated ) ), - ( b + n_full_pieces_loop_limit ), ldb, KC - ); - - n0_partial_pack = 16; - } - - if ( n0_partial_rem > 0 ) - { - packb_nrlt16_u8s8s32o32 - ( - ( pack_b_buffer_u8s8s32o32 + ( n_full_pieces_loop_limit * KC_updated ) + - ( n0_partial_pack * KC_updated ) ), - ( b + n_full_pieces_loop_limit + n0_partial_pack ), ldb, KC, - n0_partial_rem - ); - } - } - *rs_b = NR * 4; - *cs_b = NR; + if (cs_b == 1) + { + packb_nr64_u8s8s32o32_row_major(pack_b_buffer, + b, rs_b, NC, KC, rs_p, cs_p, + FALSE, FALSE); + } + else + { + packb_nr64_u8s8s32o32_col_major(pack_b_buffer, + b, cs_b, NC, KC, rs_p, cs_p); + } } -void packb_nr48_u8s8s32o32 +void packb_nr64_u8s4s32o32 ( - int8_t* pack_b_buffer_u8s8s32o32, + int8_t *pack_b_buffer, + const int8_t *b, + const dim_t rs_b, + const dim_t cs_b, + const dim_t NC, + const dim_t KC, + dim_t *rs_p, + dim_t *cs_p + ) +{ + if (cs_b == 1) + { + packb_nr64_u8s8s32o32_row_major(pack_b_buffer, + b, rs_b, NC, KC, rs_p, cs_p, + TRUE, TRUE); + } + else + { + bli_print_msg("Only row major supported for int4 packing.", + __FILE__, __LINE__); + return; + } +} + +void packb_nr64_u8s8s32o32_row_major + ( + int8_t *pack_b_buffer, + const int8_t *b, + const dim_t rs_b, + const dim_t NC, + const dim_t KC, + dim_t *rs_p, + dim_t *cs_p, + bool int4_upscale, + bool signed_upscale + ) +{ + + dim_t NR = 64; + + // Used for permuting the mm512i elements for use in vpdpbusd instruction. + // These are indexes of the format a0-a1-b0-b1-a2-a3-b2-b3 and a0-a1-a2-a3-b0-b1-b2-b3. + // Adding int32 wise all4 gives format a4-a5-b4-b5-a6-a7-b6-b7 and a4-a5-a6-a7-b4-b5-b6-b7. + __m512i selector1 = _mm512_setr_epi64( 0x0, 0x1, 0x8, 0x9, 0x2, 0x3, 0xA, 0xB ); + __m512i selector1_1 = _mm512_setr_epi64( 0x4, 0x5, 0xC, 0xD, 0x6, 0x7, 0xE, 0xF ); + + __m512i selector2 = _mm512_setr_epi64( 0x0, 0x1, 0x2, 0x3, 0x8, 0x9, 0xA, 0xB ); + __m512i selector2_1 = _mm512_setr_epi64( 0x4, 0x5, 0x6, 0x7, 0xC, 0xD, 0xE, 0xF ); + + dim_t n_full_pieces = NC / NR; + dim_t n_full_pieces_loop_limit = n_full_pieces * NR; + dim_t n_partial_pieces = NC % NR; + + dim_t k_full_pieces_blks = KC / 4; + dim_t k_full_pieces = k_full_pieces_blks * 4; + dim_t k_partial_pieces = KC % 4; + + // KC when not multiple of 4 will have padding to make it multiple of 4 in packed buffer. + dim_t KC_updated = KC; + if ( k_partial_pieces > 0 ) + { + KC_updated += ( 4 - k_partial_pieces ); + } + + bool is_odd_stride = ( ( rs_b % 2 ) == 0 ) ? FALSE : TRUE; + + __m512i a0; + __m512i b0; + __m512i c0; + __m512i d0; + __m512i a01; + __m512i c01; + + __m512i shift_idx_64; + MULTISHIFT_32BIT_8_INT4_IDX_64ELEM(shift_idx_64); + + __m512i sign_comp = _mm512_set1_epi8(0x08); + __mmask32 hmask = _cvtu32_mask32(0xFFFFFFFF); // 32 bytes or 64 int4. + __mmask32 hmask_odd = _cvtu32_mask32(0x80000000); // Last 1 int4. + + CREATE_CVT_INT4_INT8_PERM_IDX_64ELEM_ODD_LD(conv_shift_arr); + __m512i conv_shift = _mm512_loadu_epi64(conv_shift_arr); + + for ( dim_t jc = 0; jc < n_full_pieces_loop_limit; jc += NR ) + { + for ( dim_t kr = 0; kr < k_full_pieces; kr += 4 ) + { + // Rearrange for vpdpbusd, read 4 rows from B with 64 elements in each row. + if ( int4_upscale == FALSE ) + { + a0 = _mm512_loadu_si512( b + ( rs_b * ( kr + 0 ) ) + jc ); + b0 = _mm512_loadu_si512( b + ( rs_b * ( kr + 1 ) ) + jc ); + c0 = _mm512_loadu_si512( b + ( rs_b * ( kr + 2 ) ) + jc ); + d0 = _mm512_loadu_si512( b + ( rs_b * ( kr + 3 ) ) + jc ); + } + else + { + // Int4 array has to be accessed like byte array, but with + // half the elements traversed in the byte array. + __m256i h_a0 = _mm256_maskz_loadu_epi8( hmask, + b + ( ( ( rs_b * ( kr + 0 ) ) + jc ) / 2 ) ); + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT(h_a0, a0, shift_idx_64, \ + sign_comp, signed_upscale); + + __m256i h_c0 = _mm256_maskz_loadu_epi8( hmask, + b + ( ( ( rs_b * ( kr + 2 ) ) + jc ) / 2 ) ); + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT(h_c0, c0, shift_idx_64, \ + sign_comp, signed_upscale); + // If the stride, i.e. rs_b is odd, then the stride increment + // (rs_b * ...)/2 will point at the byte of which the high 4 + // bits is our desired starting element. However since data + // access is at byte level, the low 4 bits of this byte will + // be wrongly included, and additionally the last int4 element + // won't be included either. Extra data movement done to + // account for the same. + // Since kr is a multiple of 4, only kr+1 and kr+3 will have + // the aforementioned issue. + if ( is_odd_stride == FALSE ) + { + __m256i h_b0 = _mm256_maskz_loadu_epi8( hmask, + b + ( ( ( rs_b * ( kr + 1 ) ) + jc ) / 2 ) ); + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT(h_b0, b0, shift_idx_64, \ + sign_comp, signed_upscale); + + __m256i h_d0 = _mm256_maskz_loadu_epi8( hmask, + b + ( ( ( rs_b * ( kr + 3 ) ) + jc ) / 2 ) ); + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT(h_d0, d0, shift_idx_64, \ + sign_comp, signed_upscale); + } + else + { + __m256i h_b0 = _mm256_maskz_loadu_epi8( hmask, + b + ( ( ( rs_b * ( kr + 1 ) ) + jc ) / 2 ) ); + // Only load the last byte/ 32nd byte. + __m256i h_b0_l4bit = _mm256_maskz_loadu_epi8( hmask_odd, + b + ( ( ( rs_b * ( kr + 1 ) ) + jc ) / 2 ) + 1 ); + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT_ODD(h_b0, h_b0_l4bit, b0, \ + shift_idx_64, conv_shift, sign_comp, signed_upscale); + + __m256i h_d0 = _mm256_maskz_loadu_epi8( hmask, + b + ( ( ( rs_b * ( kr + 3 ) ) + jc ) / 2 ) ); + __m256i h_d0_l4bit = _mm256_maskz_loadu_epi8( hmask_odd, + b + ( ( ( rs_b * ( kr + 3 ) ) + jc ) / 2 ) + 1 ); + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT_ODD(h_d0, h_d0_l4bit, d0, \ + shift_idx_64, conv_shift, sign_comp, signed_upscale); + } + } + + a01 = _mm512_unpacklo_epi8( a0, b0 ); + a0 = _mm512_unpackhi_epi8( a0, b0 ); + + c01 = _mm512_unpacklo_epi8( c0, d0 ); + c0 = _mm512_unpackhi_epi8( c0, d0 ); + + b0 = _mm512_unpacklo_epi16( a01, c01 ); + a01 = _mm512_unpackhi_epi16( a01, c01 ); + + d0 = _mm512_unpacklo_epi16( a0, c0 ); + c01 = _mm512_unpackhi_epi16( a0, c0 ); + + a0 = _mm512_permutex2var_epi64( b0, selector1, a01 ); + c0 = _mm512_permutex2var_epi64( d0, selector1, c01 ); + b0 = _mm512_permutex2var_epi64( b0, selector1_1, a01 ); + d0 = _mm512_permutex2var_epi64( d0, selector1_1, c01 ); + + a01 = _mm512_permutex2var_epi64( a0, selector2, c0 ); // b[0] + c01 = _mm512_permutex2var_epi64( b0, selector2, d0 ); // b[2] + a0 = _mm512_permutex2var_epi64( a0, selector2_1, c0 ); // b[1] + c0 = _mm512_permutex2var_epi64( b0, selector2_1, d0 ); // b[3] + + _mm512_storeu_si512( pack_b_buffer + + ( ( jc * KC_updated ) + ( ( kr + 0 ) * NR ) ), a01 ); + _mm512_storeu_si512( pack_b_buffer + + ( ( jc * KC_updated ) + ( ( kr + 1 ) * NR ) ) , a0 ); + _mm512_storeu_si512( pack_b_buffer + + ( ( jc * KC_updated ) + ( ( kr + 2 ) * NR ) ), c01 ); + _mm512_storeu_si512( pack_b_buffer + + ( ( jc * KC_updated ) + ( ( kr + 3 ) * NR ) ), c0 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + if ( int4_upscale == FALSE ) + { + if ( k_partial_pieces == 3 ) + { + a0 = _mm512_loadu_si512( b + ( rs_b * ( k_full_pieces + 0 ) ) + jc ); + b0 = _mm512_loadu_si512( b + ( rs_b * ( k_full_pieces + 1 ) ) + jc ); + c0 = _mm512_loadu_si512( b + ( rs_b * ( k_full_pieces + 2 ) ) + jc ); + d0 = _mm512_setzero_si512(); + + } + else if( k_partial_pieces == 2 ) + { + a0 = _mm512_loadu_si512( b + ( rs_b * ( k_full_pieces + 0 ) ) + jc ); + b0 = _mm512_loadu_si512( b + ( rs_b * ( k_full_pieces + 1 ) ) + jc ); + c0 = _mm512_setzero_si512(); + d0 = _mm512_setzero_si512(); + } + else //k_partial_pieces == 1 + { + a0 = _mm512_loadu_si512( b + ( rs_b * ( k_full_pieces + 0 ) ) + jc ); + b0 = _mm512_setzero_si512(); + c0 = _mm512_setzero_si512(); + d0 = _mm512_setzero_si512(); + } + } + else + { + if ( k_partial_pieces == 3 ) + { + __m256i h_a0 = _mm256_maskz_loadu_epi8( hmask, b + + ( ( ( rs_b * ( k_full_pieces + 0 ) ) + jc ) / 2 ) ); + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT(h_a0, a0, shift_idx_64, \ + sign_comp, signed_upscale); + + __m256i h_c0 = _mm256_maskz_loadu_epi8( hmask, b + + ( ( ( rs_b * ( k_full_pieces + 2 ) ) + jc ) / 2 ) ); + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT(h_c0, c0, shift_idx_64, \ + sign_comp, signed_upscale); + + if ( is_odd_stride == FALSE ) + { + __m256i h_b0 = _mm256_maskz_loadu_epi8( hmask, b + + ( ( ( rs_b * ( k_full_pieces + 1 ) ) + jc ) / 2 ) ); + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT(h_b0, b0, shift_idx_64, \ + sign_comp, signed_upscale); + } + else + { + __m256i h_b0 = _mm256_maskz_loadu_epi8( hmask, + b + ( ( ( rs_b * ( k_full_pieces + 1 ) ) + jc ) / 2 ) ); + __m256i h_b0_l4bit = _mm256_maskz_loadu_epi8( hmask_odd, + b + ( ( ( rs_b * ( k_full_pieces + 1 ) ) + jc ) / 2 ) + 1 ); + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT_ODD(h_b0, h_b0_l4bit, b0, \ + shift_idx_64, conv_shift, sign_comp, signed_upscale); + } + + d0 = _mm512_setzero_si512(); + } + else if( k_partial_pieces == 2 ) + { + __m256i h_a0 = _mm256_maskz_loadu_epi8( hmask, b + + ( ( ( rs_b * ( k_full_pieces + 0 ) ) + jc ) / 2 ) ); + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT(h_a0, a0, shift_idx_64, \ + sign_comp, signed_upscale); + + if ( is_odd_stride == FALSE ) + { + __m256i h_b0 = _mm256_maskz_loadu_epi8( hmask, b + + ( ( ( rs_b * ( k_full_pieces + 1 ) ) + jc ) / 2 ) ); + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT(h_b0, b0, shift_idx_64, \ + sign_comp, signed_upscale); + } + else + { + __m256i h_b0 = _mm256_maskz_loadu_epi8( hmask, + b + ( ( ( rs_b * ( k_full_pieces + 1 ) ) + jc ) / 2 ) ); + __m256i h_b0_l4bit = _mm256_maskz_loadu_epi8( hmask_odd, + b + ( ( ( rs_b * ( k_full_pieces + 1 ) ) + jc ) / 2 ) + 1 ); + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT_ODD(h_b0, h_b0_l4bit, b0, \ + shift_idx_64, conv_shift, sign_comp, signed_upscale); + } + + c0 = _mm512_setzero_si512(); + d0 = _mm512_setzero_si512(); + } + else //k_partial_pieces == 1 + { + __m256i h_a0 = _mm256_maskz_loadu_epi8( hmask, b + + ( ( ( rs_b * ( k_full_pieces + 0 ) ) + jc ) / 2 ) ); + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT(h_a0, a0, shift_idx_64, \ + sign_comp, signed_upscale); + + b0 = _mm512_setzero_si512(); + c0 = _mm512_setzero_si512(); + d0 = _mm512_setzero_si512(); + } + } + + a01 = _mm512_unpacklo_epi8( a0, b0 ); + a0 = _mm512_unpackhi_epi8( a0, b0 ); + + c01 = _mm512_unpacklo_epi8( c0, d0 ); + c0 = _mm512_unpackhi_epi8( c0, d0 ); + + b0 = _mm512_unpacklo_epi16( a01, c01 ); + a01 = _mm512_unpackhi_epi16( a01, c01 ); + + d0 = _mm512_unpacklo_epi16( a0, c0 ); + c01 = _mm512_unpackhi_epi16( a0, c0 ); + + a0 = _mm512_permutex2var_epi64( b0, selector1, a01 ); + c0 = _mm512_permutex2var_epi64( d0, selector1, c01 ); + b0 = _mm512_permutex2var_epi64( b0, selector1_1, a01 ); + d0 = _mm512_permutex2var_epi64( d0, selector1_1, c01 ); + + a01 = _mm512_permutex2var_epi64( a0, selector2, c0 ); // b[0] + c01 = _mm512_permutex2var_epi64( b0, selector2, d0 ); // b[2] + a0 = _mm512_permutex2var_epi64( a0, selector2_1, c0 ); // b[1] + c0 = _mm512_permutex2var_epi64( b0, selector2_1, d0 ); // b[3] + + _mm512_storeu_si512( pack_b_buffer + + ( ( jc * KC_updated ) + ( ( k_full_pieces + 0 ) * NR ) ), a01 ); + _mm512_storeu_si512( pack_b_buffer + + ( ( jc * KC_updated ) + ( ( k_full_pieces + 1 ) * NR ) ) , a0 ); + _mm512_storeu_si512( pack_b_buffer + + ( ( jc * KC_updated ) + ( ( k_full_pieces + 2 ) * NR ) ), c01 ); + _mm512_storeu_si512( pack_b_buffer + + ( ( jc * KC_updated ) + ( ( k_full_pieces + 3 ) * NR ) ), c0 ); + } + } + + // Contiguous packing of fringe panel (n` < NR). + if ( n_partial_pieces > 0 ) + { + dim_t n0_partial_rem = n_partial_pieces % 16; + dim_t n0_partial_pack = 0; + + // Split into multiple smaller fringe kernels, so as to maximize + // vectorization after packing. Any n0 < NR(64) can be expressed + // as n0 = 48 + n` / n0 = 32 + n` / n0 = 16 + n`, where n` < 16. + dim_t n0_48 = n_partial_pieces / 48; + dim_t n0_32 = n_partial_pieces / 32; + dim_t n0_16 = n_partial_pieces / 16; + dim_t scale_factor = 1; + if ( int4_upscale == TRUE ) + { + scale_factor = 2; + } + + if ( n0_48 == 1 ) + { + packb_nr48_u8s8s32o32_row_major + ( + ( pack_b_buffer + ( n_full_pieces_loop_limit * KC_updated ) ), + ( b + ( n_full_pieces_loop_limit / scale_factor ) ), rs_b, KC, + int4_upscale, signed_upscale + ); + + n0_partial_pack = 48; + } + else if ( n0_32 == 1 ) + { + packb_nr32_u8s8s32o32_row_major + ( + ( pack_b_buffer + ( n_full_pieces_loop_limit * KC_updated ) ), + ( b + ( n_full_pieces_loop_limit / scale_factor ) ), rs_b, KC, + int4_upscale, signed_upscale + ); + + n0_partial_pack = 32; + } + else if ( n0_16 == 1 ) + { + packb_nr16_u8s8s32o32_row_major + ( + ( pack_b_buffer + ( n_full_pieces_loop_limit * KC_updated ) ), + ( b + ( n_full_pieces_loop_limit / scale_factor ) ), rs_b, KC, + int4_upscale, signed_upscale + ); + + n0_partial_pack = 16; + } + + if ( n0_partial_rem > 0 ) + { + packb_nrlt16_u8s8s32o32_row_major + ( + ( pack_b_buffer + ( n_full_pieces_loop_limit * KC_updated ) + + ( n0_partial_pack * KC_updated ) ), + ( b + ( ( n_full_pieces_loop_limit + n0_partial_pack ) / scale_factor ) ), rs_b, KC, + n0_partial_rem, int4_upscale, signed_upscale + ); + } + } + *rs_p = NR * 4; + *cs_p = NR; +} + +void packb_nr48_u8s8s32o32_row_major + ( + int8_t* pack_b_buffer, const int8_t* b, - const dim_t ldb, - const dim_t KC + const dim_t rs_b, + const dim_t KC, + bool int4_upscale, + bool signed_upscale ) { - dim_t kr_new = 0; - - dim_t k_full_pieces_blks = KC / 4; - dim_t k_full_pieces = k_full_pieces_blks * 4; - dim_t k_partial_pieces = KC % 4; - - __m256i a0_32; - __m256i b0_32; - __m256i c0_32; - __m256i d0_32; - __m256i a01_32; - __m256i c01_32; - __m512i a0_zmm; - __m512i b0_zmm; - __m128i a0_16; - __m128i b0_16; - __m128i c0_16; - __m128i d0_16; - __m128i a01_16; - __m128i c01_16; - - for ( dim_t kr = 0; kr < k_full_pieces; kr += 4 ) - { - // Rearrange for vpdpbusd, read 4 rows from B with 32 elements in each row. - a0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( kr + 0 ) ) ); - b0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( kr + 1 ) ) ); - c0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( kr + 2 ) ) ); - d0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( kr + 3 ) ) ); - - a01_32 = _mm256_unpacklo_epi8( a0_32, b0_32 ); - a0_32 = _mm256_unpackhi_epi8( a0_32, b0_32 ); - - c01_32 = _mm256_unpacklo_epi8( c0_32, d0_32 ); - c0_32 = _mm256_unpackhi_epi8( c0_32, d0_32 ); - - b0_32 = _mm256_unpacklo_epi16( a01_32, c01_32 ); - a01_32 = _mm256_unpackhi_epi16( a01_32, c01_32 ); - - d0_32 = _mm256_unpacklo_epi16( a0_32, c0_32 ); - c01_32 = _mm256_unpackhi_epi16( a0_32, c0_32 ); - - a0_32 = _mm256_shuffle_i32x4( b0_32, a01_32, 0x0 ); // 0 elem - c0_32 = _mm256_shuffle_i32x4( b0_32, a01_32, 0x3 ); // 2 elem - b0_32 = _mm256_shuffle_i32x4( d0_32, c01_32, 0x0 ); // 1 elem - d0_32 = _mm256_shuffle_i32x4( d0_32, c01_32, 0x3 ); // 3 elem - - a0_zmm = _mm512_castsi256_si512( a0_32 ); - a0_zmm = _mm512_inserti32x8( a0_zmm, b0_32, 0x1 ); - b0_zmm = _mm512_castsi256_si512( c0_32 ); - b0_zmm = _mm512_inserti32x8( b0_zmm, d0_32, 0x1 ); - - // First 4x32 elements. - _mm512_storeu_si512( pack_b_buffer_u8s8s32o32 + ( ( kr_new + 0 ) * NR ), a0_zmm ); - _mm512_storeu_si512( pack_b_buffer_u8s8s32o32 + ( ( kr_new + 1 ) * NR ), b0_zmm ); - - // Rearrange for vpdpbusd, read 4 rows from B with next 16 elements in each row. - a0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( kr + 0 ) ) + ( 32 ) ); - b0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( kr + 1 ) ) + ( 32 ) ); - c0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( kr + 2 ) ) + ( 32 ) ); - d0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( kr + 3 ) ) + ( 32 ) ); - - a01_16 = _mm_unpacklo_epi8( a0_16, b0_16 ); - a0_16 = _mm_unpackhi_epi8( a0_16, b0_16 ); - - c01_16 = _mm_unpacklo_epi8( c0_16, d0_16 ); - c0_16 = _mm_unpackhi_epi8( c0_16, d0_16 ); - - b0_16 = _mm_unpacklo_epi16( a01_16, c01_16 ); // 0 elem - a01_16 = _mm_unpackhi_epi16( a01_16, c01_16 ); // 1 elem - d0_16 = _mm_unpacklo_epi16( a0_16, c0_16 ); // 2 elem - c01_16 = _mm_unpackhi_epi16( a0_16, c0_16 ); // 3 elem - - a0_zmm = _mm512_castsi128_si512( b0_16 ); - a0_zmm = _mm512_inserti32x4( a0_zmm, a01_16, 0x1 ); - a0_zmm = _mm512_inserti32x4( a0_zmm, d0_16, 0x2 ); - a0_zmm = _mm512_inserti32x4( a0_zmm, c01_16, 0x3 ); - - // Last 4x16 elements. - _mm512_storeu_si512( pack_b_buffer_u8s8s32o32 + ( ( kr_new + 2 ) * NR ), a0_zmm ); - - // The 4th 16byte chunk will be ignored, since its not part of the original data, - // but is here due to the packing in 4 16byte chunks format. - kr_new += 3; - } - // Handle k remainder. - if ( k_partial_pieces > 0 ) - { - if ( k_partial_pieces == 3 ) - { - a0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( k_full_pieces + 0 ) ) ); - b0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( k_full_pieces + 1 ) ) ); - c0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( k_full_pieces + 2 ) ) ); - d0_32 = _mm256_setzero_si256(); - - a0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( k_full_pieces + 0 ) ) + ( 32 ) ); - b0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( k_full_pieces + 1 ) ) + ( 32 ) ); - c0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( k_full_pieces + 2 ) ) + ( 32 ) ); - d0_16 = _mm_setzero_si128(); - - } - else if( k_partial_pieces == 2 ) - { - a0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( k_full_pieces + 0 ) ) ); - b0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( k_full_pieces + 1 ) ) ); - c0_32 = _mm256_setzero_si256(); - d0_32 = _mm256_setzero_si256(); - - a0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( k_full_pieces + 0 ) ) + ( 32 ) ); - b0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( k_full_pieces + 1 ) ) + ( 32 ) ); - c0_16 = _mm_setzero_si128(); - d0_16 = _mm_setzero_si128(); - } - else //k_partial_pieces == 1 - { - a0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( k_full_pieces + 0 ) ) ); - b0_32 = _mm256_setzero_si256(); - c0_32 = _mm256_setzero_si256(); - d0_32 = _mm256_setzero_si256(); - - a0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( k_full_pieces + 0 ) ) + ( 32 ) ); - b0_16 = _mm_setzero_si128(); - c0_16 = _mm_setzero_si128(); - d0_16 = _mm_setzero_si128(); - } - - a01_32 = _mm256_unpacklo_epi8( a0_32, b0_32 ); - a0_32 = _mm256_unpackhi_epi8( a0_32, b0_32 ); - - c01_32 = _mm256_unpacklo_epi8( c0_32, d0_32 ); - c0_32 = _mm256_unpackhi_epi8( c0_32, d0_32 ); - - b0_32 = _mm256_unpacklo_epi16( a01_32, c01_32 ); - a01_32 = _mm256_unpackhi_epi16( a01_32, c01_32 ); - - d0_32 = _mm256_unpacklo_epi16( a0_32, c0_32 ); - c01_32 = _mm256_unpackhi_epi16( a0_32, c0_32 ); - - a0_32 = _mm256_shuffle_i32x4( b0_32, a01_32, 0x0 ); // 0 elem - c0_32 = _mm256_shuffle_i32x4( b0_32, a01_32, 0x3 ); // 2 elem - b0_32 = _mm256_shuffle_i32x4( d0_32, c01_32, 0x0 ); // 1 elem - d0_32 = _mm256_shuffle_i32x4( d0_32, c01_32, 0x3 ); // 3 elem - - a0_zmm = _mm512_castsi256_si512( a0_32 ); - a0_zmm = _mm512_inserti32x8( a0_zmm, b0_32, 0x1 ); - b0_zmm = _mm512_castsi256_si512( c0_32 ); - b0_zmm = _mm512_inserti32x8( b0_zmm, d0_32, 0x1 ); - - // First 4x32 elements. - _mm512_storeu_si512( pack_b_buffer_u8s8s32o32 + ( ( kr_new + 0 ) * NR ), a0_zmm ); - _mm512_storeu_si512( pack_b_buffer_u8s8s32o32 + ( ( kr_new + 1 ) * NR ), b0_zmm ); - - a01_16 = _mm_unpacklo_epi8( a0_16, b0_16 ); - a0_16 = _mm_unpackhi_epi8( a0_16, b0_16 ); - - c01_16 = _mm_unpacklo_epi8( c0_16, d0_16 ); - c0_16 = _mm_unpackhi_epi8( c0_16, d0_16 ); - - b0_16 = _mm_unpacklo_epi16( a01_16, c01_16 ); // 0 elem - a01_16 = _mm_unpackhi_epi16( a01_16, c01_16 ); // 1 elem - d0_16 = _mm_unpacklo_epi16( a0_16, c0_16 ); // 2 elem - c01_16 = _mm_unpackhi_epi16( a0_16, c0_16 ); // 3 elem - - a0_zmm = _mm512_castsi128_si512( b0_16 ); - a0_zmm = _mm512_inserti32x4( a0_zmm, a01_16, 0x1 ); - a0_zmm = _mm512_inserti32x4( a0_zmm, d0_16, 0x2 ); - a0_zmm = _mm512_inserti32x4( a0_zmm, c01_16, 0x3 ); - - // Last 4x16 elements. - _mm512_storeu_si512( pack_b_buffer_u8s8s32o32 + ( ( kr_new + 2 ) * NR ), a0_zmm ); - } + dim_t NR = 64; + dim_t kr_new = 0; + + dim_t k_full_pieces_blks = KC / 4; + dim_t k_full_pieces = k_full_pieces_blks * 4; + dim_t k_partial_pieces = KC % 4; + + bool is_odd_stride = ( ( rs_b % 2 ) == 0 ) ? FALSE : TRUE; + + __m256i a0_32; + __m256i b0_32; + __m256i c0_32; + __m256i d0_32; + __m256i a01_32; + __m256i c01_32; + __m512i a0_zmm; + __m512i b0_zmm; + __m128i a0_16; + __m128i b0_16; + __m128i c0_16; + __m128i d0_16; + __m128i a01_16; + __m128i c01_16; + + // First 32 int4 elements selectors. + __m256i shift_idx_32; + MULTISHIFT_32BIT_8_INT4_IDX_32ELEM(shift_idx_32); + + __m256i sign_comp_32 = _mm256_set1_epi8( 0x08 ); + __mmask16 hmask_32 = _cvtu32_mask16( 0x0000FFFF ); //16 bytes or 32 int4. + + __mmask16 hmask_odd_32 = _cvtu32_mask16( 0x00008000 ); // Last 1 int4. + + CREATE_CVT_INT4_INT8_PERM_IDX_32ELEM_ODD_LD(conv_shift_arr_32); + __m256i conv_shift_32 = _mm256_maskz_loadu_epi64( _cvtu32_mask8( 0X000000FF ), + conv_shift_arr_32 ); + + // Next 16 int4 elements selectors. + __m128i shift_idx_16; + MULTISHIFT_32BIT_8_INT4_IDX_16ELEM(shift_idx_16); + + __m128i sign_comp_16 = _mm_set1_epi8( 0x08 ); + __mmask16 hmask_16 = _cvtu32_mask16( 0x000000FF ); //8 bytes or 16 int4. + + __mmask16 hmask_odd_16 = _cvtu32_mask16( 0x00000080 ); // Last 1 int4. + + CREATE_CVT_INT4_INT8_PERM_IDX_16ELEM_ODD_LD(conv_shift_arr_16); + __m128i conv_shift_16 = _mm_maskz_loadu_epi64( _cvtu32_mask8( 0X000000FF ), + conv_shift_arr_16 ); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 4 ) + { + if ( int4_upscale == FALSE ) + { + // Rearrange for vpdpbusd, read 4 rows from B with 32 elements in each row. + a0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + \ + ( rs_b * ( kr + 0 ) ) ); + b0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + \ + ( rs_b * ( kr + 1 ) ) ); + c0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + \ + ( rs_b * ( kr + 2 ) ) ); + d0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + \ + ( rs_b * ( kr + 3 ) ) ); + + // Rearrange for vpdpbusd, read 4 rows from B with next 16 elements in each row. + a0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + \ + ( rs_b * ( kr + 0 ) ) + ( 32 ) ); + b0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + \ + ( rs_b * ( kr + 1 ) ) + ( 32 ) ); + c0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + \ + ( rs_b * ( kr + 2 ) ) + ( 32 ) ); + d0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + \ + ( rs_b * ( kr + 3 ) ) + ( 32 ) ); + } + else + { + // Int4 array has to be accessed like byte array, but with + // half the elements traversed in the byte array. + // First 32 columns. + __m128i h_a0_32 = _mm_maskz_loadu_epi8( hmask_32, + b + ( ( rs_b * ( kr + 0 ) ) / 2 ) ); + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT(h_a0_32, a0_32, shift_idx_32, \ + sign_comp_32, signed_upscale); + + __m128i h_c0_32 = _mm_maskz_loadu_epi8( hmask_32, + b + ( ( rs_b * ( kr + 2 ) ) / 2 ) ); + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT(h_c0_32, c0_32, shift_idx_32, \ + sign_comp_32, signed_upscale); + + // Last 16 columns. + h_a0_32 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( ( rs_b * ( kr + 0 ) ) + 32 ) / 2 ) ); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(h_a0_32, a0_16, shift_idx_16, \ + sign_comp_16, signed_upscale); + + h_c0_32 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( ( rs_b * ( kr + 2 ) ) + 32 ) / 2 ) ); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(h_c0_32, c0_16, shift_idx_16, \ + sign_comp_16, signed_upscale); + + if (is_odd_stride == FALSE) + { + __m128i h_b0_32 = _mm_maskz_loadu_epi8( hmask_32, + b + ( ( rs_b * ( kr + 1 ) ) / 2 ) ); + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT(h_b0_32, b0_32, shift_idx_32, \ + sign_comp_32, signed_upscale); + + __m128i h_d0_32 = _mm_maskz_loadu_epi8( hmask_32, + b + ( ( rs_b * ( kr + 3 ) ) / 2 ) ); + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT(h_d0_32, d0_32, shift_idx_32, \ + sign_comp_32, signed_upscale); + + // Last 16 columns. + h_b0_32 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( ( rs_b * ( kr + 1 ) ) + 32 ) / 2 ) ); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(h_b0_32, b0_16, shift_idx_16, \ + sign_comp_16, signed_upscale); + + h_d0_32 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( ( rs_b * ( kr + 3 ) ) + 32 ) / 2 ) ); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(h_d0_32, d0_16, shift_idx_16, \ + sign_comp_16, signed_upscale); + } + else + { + __m128i h_b0_32 = _mm_maskz_loadu_epi8( hmask_32, + b + ( ( rs_b * ( kr + 1 ) ) / 2 ) ); + // Only load the last byte/ 16th byte. + __m128i h_b0_32_l4bit = _mm_maskz_loadu_epi8( hmask_odd_32, + b + ( ( rs_b * ( kr + 1 ) ) / 2 ) + 1 ); + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT_ODD(h_b0_32, h_b0_32_l4bit, b0_32, \ + shift_idx_32, conv_shift_32, sign_comp_32, signed_upscale); + + __m128i h_d0_32 = _mm_maskz_loadu_epi8( hmask_32, + b + ( ( rs_b * ( kr + 3 ) ) / 2 ) ); + // Only load the last byte/ 16th byte. + __m128i h_d0_32_l4bit = _mm_maskz_loadu_epi8( hmask_odd_32, + b + ( ( rs_b * ( kr + 3 ) ) / 2 ) + 1 ); + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT_ODD(h_d0_32, h_d0_32_l4bit, d0_32, \ + shift_idx_32, conv_shift_32, sign_comp_32, signed_upscale); + + // Last 16 columns. + h_b0_32 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( ( rs_b * ( kr + 1 ) ) + 32 ) / 2 ) ); + // Only load the last byte/ 8th byte. + h_b0_32_l4bit = _mm_maskz_loadu_epi8( hmask_odd_16, + b + ( ( ( rs_b * ( kr + 1 ) ) + 32 ) / 2 ) + 1 ); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT_ODD(h_b0_32, h_b0_32_l4bit, b0_16, \ + shift_idx_16, conv_shift_16, sign_comp_16, signed_upscale); + + h_d0_32 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( ( rs_b * ( kr + 3 ) ) + 32 ) / 2 ) ); + // Only load the last byte/ 8th byte. + h_d0_32_l4bit = _mm_maskz_loadu_epi8( hmask_odd_16, + b + ( ( ( rs_b * ( kr + 3 ) ) + 32 ) / 2 ) + 1 ); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT_ODD(h_d0_32, h_d0_32_l4bit, d0_16, \ + shift_idx_16, conv_shift_16, sign_comp_16, signed_upscale); + } + } + + a01_32 = _mm256_unpacklo_epi8( a0_32, b0_32 ); + a0_32 = _mm256_unpackhi_epi8( a0_32, b0_32 ); + + c01_32 = _mm256_unpacklo_epi8( c0_32, d0_32 ); + c0_32 = _mm256_unpackhi_epi8( c0_32, d0_32 ); + + b0_32 = _mm256_unpacklo_epi16( a01_32, c01_32 ); + a01_32 = _mm256_unpackhi_epi16( a01_32, c01_32 ); + + d0_32 = _mm256_unpacklo_epi16( a0_32, c0_32 ); + c01_32 = _mm256_unpackhi_epi16( a0_32, c0_32 ); + + a0_32 = _mm256_shuffle_i32x4( b0_32, a01_32, 0x0 ); // 0 elem + c0_32 = _mm256_shuffle_i32x4( b0_32, a01_32, 0x3 ); // 2 elem + b0_32 = _mm256_shuffle_i32x4( d0_32, c01_32, 0x0 ); // 1 elem + d0_32 = _mm256_shuffle_i32x4( d0_32, c01_32, 0x3 ); // 3 elem + + a0_zmm = _mm512_castsi256_si512( a0_32 ); + a0_zmm = _mm512_inserti32x8( a0_zmm, b0_32, 0x1 ); + b0_zmm = _mm512_castsi256_si512( c0_32 ); + b0_zmm = _mm512_inserti32x8( b0_zmm, d0_32, 0x1 ); + + // First 4x32 elements. + _mm512_storeu_si512( pack_b_buffer + ( ( kr_new + 0 ) * NR ), a0_zmm ); + _mm512_storeu_si512( pack_b_buffer + ( ( kr_new + 1 ) * NR ), b0_zmm ); + + // Next 16 columns. + a01_16 = _mm_unpacklo_epi8( a0_16, b0_16 ); + a0_16 = _mm_unpackhi_epi8( a0_16, b0_16 ); + + c01_16 = _mm_unpacklo_epi8( c0_16, d0_16 ); + c0_16 = _mm_unpackhi_epi8( c0_16, d0_16 ); + + b0_16 = _mm_unpacklo_epi16( a01_16, c01_16 ); // 0 elem + a01_16 = _mm_unpackhi_epi16( a01_16, c01_16 ); // 1 elem + d0_16 = _mm_unpacklo_epi16( a0_16, c0_16 ); // 2 elem + c01_16 = _mm_unpackhi_epi16( a0_16, c0_16 ); // 3 elem + + a0_zmm = _mm512_castsi128_si512( b0_16 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, a01_16, 0x1 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, d0_16, 0x2 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, c01_16, 0x3 ); + + // Last 4x16 elements. + _mm512_storeu_si512( pack_b_buffer + ( ( kr_new + 2 ) * NR ), a0_zmm ); + + // The 4th 16byte chunk will be ignored, since its not part of the original data, + // but is here due to the packing in 4 16byte chunks format. + kr_new += 3; + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + if ( int4_upscale == FALSE ) + { + if ( k_partial_pieces == 3 ) + { + a0_32 = _mm256_maskz_loadu_epi8(0xFFFFFFFF, + b + (rs_b * (k_full_pieces + 0))); + b0_32 = _mm256_maskz_loadu_epi8(0xFFFFFFFF, + b + (rs_b * (k_full_pieces + 1))); + c0_32 = _mm256_maskz_loadu_epi8(0xFFFFFFFF, + b + (rs_b * (k_full_pieces + 2))); + d0_32 = _mm256_setzero_si256(); + + a0_16 = _mm_maskz_loadu_epi8(0xFFFF, + b + (rs_b * (k_full_pieces + 0)) + (32)); + b0_16 = _mm_maskz_loadu_epi8(0xFFFF, + b + (rs_b * (k_full_pieces + 1)) + (32)); + c0_16 = _mm_maskz_loadu_epi8(0xFFFF, + b + (rs_b * (k_full_pieces + 2)) + (32)); + d0_16 = _mm_setzero_si128(); + + } + else if( k_partial_pieces == 2 ) + { + a0_32 = _mm256_maskz_loadu_epi8(0xFFFFFFFF, + b + (rs_b * (k_full_pieces + 0))); + b0_32 = _mm256_maskz_loadu_epi8(0xFFFFFFFF, + b + (rs_b * (k_full_pieces + 1))); + c0_32 = _mm256_setzero_si256(); + d0_32 = _mm256_setzero_si256(); + + a0_16 = _mm_maskz_loadu_epi8(0xFFFF, + b + (rs_b * (k_full_pieces + 0)) + (32)); + b0_16 = _mm_maskz_loadu_epi8(0xFFFF, + b + (rs_b * (k_full_pieces + 1)) + (32)); + c0_16 = _mm_setzero_si128(); + d0_16 = _mm_setzero_si128(); + } + else //k_partial_pieces == 1 + { + a0_32 = _mm256_maskz_loadu_epi8(0xFFFFFFFF, + b + (rs_b * (k_full_pieces + 0))); + b0_32 = _mm256_setzero_si256(); + c0_32 = _mm256_setzero_si256(); + d0_32 = _mm256_setzero_si256(); + + a0_16 = _mm_maskz_loadu_epi8(0xFFFF, + b + (rs_b * (k_full_pieces + 0)) + (32)); + b0_16 = _mm_setzero_si128(); + c0_16 = _mm_setzero_si128(); + d0_16 = _mm_setzero_si128(); + } + } + else + { + if ( k_partial_pieces == 3 ) + { + // First 32 columns. + __m128i h_a0_32 = _mm_maskz_loadu_epi8( hmask_32, + b + ( ( rs_b * ( k_full_pieces + 0 ) ) / 2 ) ); + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT(h_a0_32, a0_32, shift_idx_32, \ + sign_comp_32, signed_upscale); + + __m128i h_c0_32 = _mm_maskz_loadu_epi8( hmask_32, + b + ( ( rs_b * ( k_full_pieces + 2 ) ) / 2 ) ); + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT(h_c0_32, c0_32, shift_idx_32, \ + sign_comp_32, signed_upscale); + + d0_32 = _mm256_setzero_si256(); + + // Last 16 columns. + h_a0_32 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( ( rs_b * ( k_full_pieces + 0 ) ) + 32 ) / 2 ) ); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(h_a0_32, a0_16, shift_idx_16, \ + sign_comp_16, signed_upscale); + + h_c0_32 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( ( rs_b * ( k_full_pieces + 2 ) ) + 32 ) / 2 ) ); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(h_c0_32, c0_16, shift_idx_16, \ + sign_comp_16, signed_upscale); + + d0_16 = _mm_setzero_si128(); + + if (is_odd_stride == FALSE ) + { + __m128i h_b0_32 = _mm_maskz_loadu_epi8( hmask_32, + b + ( ( rs_b * ( k_full_pieces + 1 ) ) / 2 ) ); + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT(h_b0_32, b0_32, shift_idx_32, \ + sign_comp_32, signed_upscale); + + h_b0_32 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( ( rs_b * ( k_full_pieces + 1 ) ) + 32 ) / 2 ) ); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(h_b0_32, b0_16, shift_idx_16, \ + sign_comp_16, signed_upscale); + } + else + { + __m128i h_b0_32 = _mm_maskz_loadu_epi8( hmask_32, + b + ( ( rs_b * ( k_full_pieces + 1 ) ) / 2 ) ); + // Only load the last byte/ 16th byte. + __m128i h_b0_32_l4bit = _mm_maskz_loadu_epi8( hmask_odd_32, + b + ( ( rs_b * ( k_full_pieces + 1 ) ) / 2 ) + 1 ); + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT_ODD(h_b0_32, h_b0_32_l4bit, b0_32, \ + shift_idx_32, conv_shift_32, sign_comp_32, signed_upscale); + + h_b0_32 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( ( rs_b * ( k_full_pieces + 1 ) ) + 32 ) / 2 ) ); + // Only load the last byte/ 8th byte. + h_b0_32_l4bit = _mm_maskz_loadu_epi8( hmask_odd_16, + b + ( ( ( rs_b * ( k_full_pieces + 1 ) ) + 32 ) / 2 ) + 1 ); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT_ODD(h_b0_32, h_b0_32_l4bit, b0_16, \ + shift_idx_16, conv_shift_16, sign_comp_16, signed_upscale); + } + + } + else if( k_partial_pieces == 2 ) + { + __m128i h_a0_32 = _mm_maskz_loadu_epi8( hmask_32, + b + ( ( rs_b * ( k_full_pieces + 0 ) ) / 2 ) ); + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT(h_a0_32, a0_32, shift_idx_32, \ + sign_comp_32, signed_upscale); + + c0_32 = _mm256_setzero_si256(); + d0_32 = _mm256_setzero_si256(); + + h_a0_32 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( ( rs_b * ( k_full_pieces + 0 ) ) + 32 ) / 2 ) ); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(h_a0_32, a0_16, shift_idx_16, \ + sign_comp_16, signed_upscale); + + c0_16 = _mm_setzero_si128(); + d0_16 = _mm_setzero_si128(); + + if (is_odd_stride == FALSE ) + { + __m128i h_b0_32 = _mm_maskz_loadu_epi8( hmask_32, + b + ( ( rs_b * ( k_full_pieces + 1 ) ) / 2 ) ); + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT(h_b0_32, b0_32, shift_idx_32, \ + sign_comp_32, signed_upscale); + + h_b0_32 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( ( rs_b * ( k_full_pieces + 1 ) ) + 32 ) / 2 ) ); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(h_b0_32, b0_16, shift_idx_16, \ + sign_comp_16, signed_upscale); + } + else + { + __m128i h_b0_32 = _mm_maskz_loadu_epi8( hmask_32, + b + ( ( rs_b * ( k_full_pieces + 1 ) ) / 2 ) ); + // Only load the last byte/ 16th byte. + __m128i h_b0_32_l4bit = _mm_maskz_loadu_epi8( hmask_odd_32, + b + ( ( rs_b * ( k_full_pieces + 1 ) ) / 2 ) + 1 ); + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT_ODD(h_b0_32, h_b0_32_l4bit, b0_32, \ + shift_idx_32, conv_shift_32, sign_comp_32, signed_upscale); + + h_b0_32 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( ( rs_b * ( k_full_pieces + 1 ) ) + 32 ) / 2 ) ); + // Only load the last byte/ 8th byte. + h_b0_32_l4bit = _mm_maskz_loadu_epi8( hmask_odd_16, + b + ( ( ( rs_b * ( k_full_pieces + 1 ) ) + 32 ) / 2 ) + 1 ); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT_ODD(h_b0_32, h_b0_32_l4bit, b0_16, \ + shift_idx_16, conv_shift_16, sign_comp_16, signed_upscale); + } + } + else //k_partial_pieces == 1 + { + __m128i h_a0_32 = _mm_maskz_loadu_epi8( hmask_32, + b + ( ( rs_b * ( k_full_pieces + 0 ) ) / 2 ) ); + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT(h_a0_32, a0_32, shift_idx_32, \ + sign_comp_32, signed_upscale); + + b0_32 = _mm256_setzero_si256(); + c0_32 = _mm256_setzero_si256(); + d0_32 = _mm256_setzero_si256(); + + h_a0_32 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( ( rs_b * ( k_full_pieces + 0 ) ) + 32 ) / 2 ) ); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(h_a0_32, a0_16, shift_idx_16, \ + sign_comp_16, signed_upscale); + + b0_16 = _mm_setzero_si128(); + c0_16 = _mm_setzero_si128(); + d0_16 = _mm_setzero_si128(); + } + } + + a01_32 = _mm256_unpacklo_epi8( a0_32, b0_32 ); + a0_32 = _mm256_unpackhi_epi8( a0_32, b0_32 ); + + c01_32 = _mm256_unpacklo_epi8( c0_32, d0_32 ); + c0_32 = _mm256_unpackhi_epi8( c0_32, d0_32 ); + + b0_32 = _mm256_unpacklo_epi16( a01_32, c01_32 ); + a01_32 = _mm256_unpackhi_epi16( a01_32, c01_32 ); + + d0_32 = _mm256_unpacklo_epi16( a0_32, c0_32 ); + c01_32 = _mm256_unpackhi_epi16( a0_32, c0_32 ); + + a0_32 = _mm256_shuffle_i32x4( b0_32, a01_32, 0x0 ); // 0 elem + c0_32 = _mm256_shuffle_i32x4( b0_32, a01_32, 0x3 ); // 2 elem + b0_32 = _mm256_shuffle_i32x4( d0_32, c01_32, 0x0 ); // 1 elem + d0_32 = _mm256_shuffle_i32x4( d0_32, c01_32, 0x3 ); // 3 elem + + a0_zmm = _mm512_castsi256_si512( a0_32 ); + a0_zmm = _mm512_inserti32x8( a0_zmm, b0_32, 0x1 ); + b0_zmm = _mm512_castsi256_si512( c0_32 ); + b0_zmm = _mm512_inserti32x8( b0_zmm, d0_32, 0x1 ); + + // First 4x32 elements. + _mm512_storeu_si512( pack_b_buffer + ( ( kr_new + 0 ) * NR ), a0_zmm ); + _mm512_storeu_si512( pack_b_buffer + ( ( kr_new + 1 ) * NR ), b0_zmm ); + + a01_16 = _mm_unpacklo_epi8( a0_16, b0_16 ); + a0_16 = _mm_unpackhi_epi8( a0_16, b0_16 ); + + c01_16 = _mm_unpacklo_epi8( c0_16, d0_16 ); + c0_16 = _mm_unpackhi_epi8( c0_16, d0_16 ); + + b0_16 = _mm_unpacklo_epi16( a01_16, c01_16 ); // 0 elem + a01_16 = _mm_unpackhi_epi16( a01_16, c01_16 ); // 1 elem + d0_16 = _mm_unpacklo_epi16( a0_16, c0_16 ); // 2 elem + c01_16 = _mm_unpackhi_epi16( a0_16, c0_16 ); // 3 elem + + a0_zmm = _mm512_castsi128_si512( b0_16 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, a01_16, 0x1 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, d0_16, 0x2 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, c01_16, 0x3 ); + + // Last 4x16 elements. + _mm512_storeu_si512( pack_b_buffer + ( ( kr_new + 2 ) * NR ), a0_zmm ); + } } -void packb_nr32_u8s8s32o32 +void packb_nr32_u8s8s32o32_row_major ( - int8_t* pack_b_buffer_u8s8s32o32, + int8_t* pack_b_buffer, const int8_t* b, - const dim_t ldb, - const dim_t KC + const dim_t rs_b, + const dim_t KC, + bool int4_upscale, + bool signed_upscale ) { - dim_t kr_new = 0; - - dim_t k_full_pieces_blks = KC / 4; - dim_t k_full_pieces = k_full_pieces_blks * 4; - dim_t k_partial_pieces = KC % 4; - - __m256i a0_32; - __m256i b0_32; - __m256i c0_32; - __m256i d0_32; - __m256i a01_32; - __m256i c01_32; - __m512i a0_zmm; - __m512i b0_zmm; - - for ( dim_t kr = 0; kr < k_full_pieces; kr += 4 ) - { - // Rearrange for vpdpbusd, read 4 rows from B with 32 elements in each row. - a0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( kr + 0 ) ) ); - b0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( kr + 1 ) ) ); - c0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( kr + 2 ) ) ); - d0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( kr + 3 ) ) ); - - a01_32 = _mm256_unpacklo_epi8( a0_32, b0_32 ); - a0_32 = _mm256_unpackhi_epi8( a0_32, b0_32 ); - - c01_32 = _mm256_unpacklo_epi8( c0_32, d0_32 ); - c0_32 = _mm256_unpackhi_epi8( c0_32, d0_32 ); - - b0_32 = _mm256_unpacklo_epi16( a01_32, c01_32 ); - a01_32 = _mm256_unpackhi_epi16( a01_32, c01_32 ); - - d0_32 = _mm256_unpacklo_epi16( a0_32, c0_32 ); - c01_32 = _mm256_unpackhi_epi16( a0_32, c0_32 ); - - a0_32 = _mm256_shuffle_i32x4( b0_32, a01_32, 0x0 ); // 0 elem - c0_32 = _mm256_shuffle_i32x4( b0_32, a01_32, 0x3 ); // 2 elem - b0_32 = _mm256_shuffle_i32x4( d0_32, c01_32, 0x0 ); // 1 elem - d0_32 = _mm256_shuffle_i32x4( d0_32, c01_32, 0x3 ); // 3 elem - - a0_zmm = _mm512_castsi256_si512( a0_32 ); - a0_zmm = _mm512_inserti32x8( a0_zmm, b0_32, 0x1 ); - b0_zmm = _mm512_castsi256_si512( c0_32 ); - b0_zmm = _mm512_inserti32x8( b0_zmm, d0_32, 0x1 ); - - // First 4x32 elements. - _mm512_storeu_si512( pack_b_buffer_u8s8s32o32 + ( ( kr_new + 0 ) * NR ), a0_zmm ); - _mm512_storeu_si512( pack_b_buffer_u8s8s32o32 + ( ( kr_new + 1 ) * NR ), b0_zmm ); - - // The 3rd and 4th 16byte chunk will be ignored, since its not part of the original data, - // but is here due to the packing in 4 16byte chunks format. - kr_new += 2; - } - // Handle k remainder. - if ( k_partial_pieces > 0 ) - { - if ( k_partial_pieces == 3 ) - { - a0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( k_full_pieces + 0 ) ) ); - b0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( k_full_pieces + 1 ) ) ); - c0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( k_full_pieces + 2 ) ) ); - d0_32 = _mm256_setzero_si256(); - - } - else if( k_partial_pieces == 2 ) - { - a0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( k_full_pieces + 0 ) ) ); - b0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( k_full_pieces + 1 ) ) ); - c0_32 = _mm256_setzero_si256(); - d0_32 = _mm256_setzero_si256(); - } - else //k_partial_pieces == 1 - { - a0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( k_full_pieces + 0 ) ) ); - b0_32 = _mm256_setzero_si256(); - c0_32 = _mm256_setzero_si256(); - d0_32 = _mm256_setzero_si256(); - } - - a01_32 = _mm256_unpacklo_epi8( a0_32, b0_32 ); - a0_32 = _mm256_unpackhi_epi8( a0_32, b0_32 ); - - c01_32 = _mm256_unpacklo_epi8( c0_32, d0_32 ); - c0_32 = _mm256_unpackhi_epi8( c0_32, d0_32 ); - - b0_32 = _mm256_unpacklo_epi16( a01_32, c01_32 ); - a01_32 = _mm256_unpackhi_epi16( a01_32, c01_32 ); - - d0_32 = _mm256_unpacklo_epi16( a0_32, c0_32 ); - c01_32 = _mm256_unpackhi_epi16( a0_32, c0_32 ); - - a0_32 = _mm256_shuffle_i32x4( b0_32, a01_32, 0x0 ); // 0 elem - c0_32 = _mm256_shuffle_i32x4( b0_32, a01_32, 0x3 ); // 2 elem - b0_32 = _mm256_shuffle_i32x4( d0_32, c01_32, 0x0 ); // 1 elem - d0_32 = _mm256_shuffle_i32x4( d0_32, c01_32, 0x3 ); // 3 elem - - a0_zmm = _mm512_castsi256_si512( a0_32 ); - a0_zmm = _mm512_inserti32x8( a0_zmm, b0_32, 0x1 ); - b0_zmm = _mm512_castsi256_si512( c0_32 ); - b0_zmm = _mm512_inserti32x8( b0_zmm, d0_32, 0x1 ); - - // First 4x32 elements. - _mm512_storeu_si512( pack_b_buffer_u8s8s32o32 + ( ( kr_new + 0 ) * NR ), a0_zmm ); - _mm512_storeu_si512( pack_b_buffer_u8s8s32o32 + ( ( kr_new + 1 ) * NR ), b0_zmm ); - } + dim_t NR = 64; + dim_t kr_new = 0; + + dim_t k_full_pieces_blks = KC / 4; + dim_t k_full_pieces = k_full_pieces_blks * 4; + dim_t k_partial_pieces = KC % 4; + + bool is_odd_stride = ( ( rs_b % 2 ) == 0 ) ? FALSE : TRUE; + + __m256i a0_32; + __m256i b0_32; + __m256i c0_32; + __m256i d0_32; + __m256i a01_32; + __m256i c01_32; + __m512i a0_zmm; + __m512i b0_zmm; + + __m256i shift_idx_32; + MULTISHIFT_32BIT_8_INT4_IDX_32ELEM(shift_idx_32); + + __m256i sign_comp_32 = _mm256_set1_epi8( 0x08 ); + __mmask16 hmask_32 = _cvtu32_mask16( 0x0000FFFF ); //16 bytes or 32 int4. + + __mmask16 hmask_odd_32 = _cvtu32_mask16( 0x00008000 ); // Last 1 int4. + + CREATE_CVT_INT4_INT8_PERM_IDX_32ELEM_ODD_LD(conv_shift_arr_32); + __m256i conv_shift_32 = _mm256_maskz_loadu_epi64( _cvtu32_mask8( 0X000000FF ), + conv_shift_arr_32 ); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 4 ) + { + if ( int4_upscale == FALSE ) + { + // Rearrange for vpdpbusd, read 4 rows from B with 32 elements in each row. + a0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( rs_b * ( kr + 0 ) ) ); + b0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( rs_b * ( kr + 1 ) ) ); + c0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( rs_b * ( kr + 2 ) ) ); + d0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( rs_b * ( kr + 3 ) ) ); + } + else + { + // Int4 array has to be accessed like byte array, but with + // half the elements traversed in the byte array. + // First 32 columns. + __m128i h_a0_32 = _mm_maskz_loadu_epi8( hmask_32, + b + ( ( rs_b * ( kr + 0 ) ) / 2 ) ); + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT(h_a0_32, a0_32, shift_idx_32, \ + sign_comp_32, signed_upscale); + + __m128i h_c0_32 = _mm_maskz_loadu_epi8( hmask_32, + b + ( ( rs_b * ( kr + 2 ) ) / 2 ) ); + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT(h_c0_32, c0_32, shift_idx_32, \ + sign_comp_32, signed_upscale); + + if (is_odd_stride == FALSE) + { + __m128i h_b0_32 = _mm_maskz_loadu_epi8( hmask_32, + b + ( ( rs_b * ( kr + 1 ) ) / 2 ) ); + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT(h_b0_32, b0_32, shift_idx_32, \ + sign_comp_32, signed_upscale); + + __m128i h_d0_32 = _mm_maskz_loadu_epi8( hmask_32, + b + ( ( rs_b * ( kr + 3 ) ) / 2 ) ); + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT(h_d0_32, d0_32, shift_idx_32, \ + sign_comp_32, signed_upscale); + } + else + { + __m128i h_b0_32 = _mm_maskz_loadu_epi8( hmask_32, + b + ( ( rs_b * ( kr + 1 ) ) / 2 ) ); + // Only load the last byte/ 16th byte. + __m128i h_b0_32_l4bit = _mm_maskz_loadu_epi8( hmask_odd_32, + b + ( ( rs_b * ( kr + 1 ) ) / 2 ) + 1 ); + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT_ODD(h_b0_32, h_b0_32_l4bit, b0_32, \ + shift_idx_32, conv_shift_32, sign_comp_32, signed_upscale); + + __m128i h_d0_32 = _mm_maskz_loadu_epi8( hmask_32, + b + ( ( rs_b * ( kr + 3 ) ) / 2 ) ); + // Only load the last byte/ 16th byte. + __m128i h_d0_32_l4bit = _mm_maskz_loadu_epi8( hmask_odd_32, + b + ( ( rs_b * ( kr + 3 ) ) / 2 ) + 1 ); + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT_ODD(h_d0_32, h_d0_32_l4bit, d0_32, \ + shift_idx_32, conv_shift_32, sign_comp_32, signed_upscale); + } + } + + a01_32 = _mm256_unpacklo_epi8( a0_32, b0_32 ); + a0_32 = _mm256_unpackhi_epi8( a0_32, b0_32 ); + + c01_32 = _mm256_unpacklo_epi8( c0_32, d0_32 ); + c0_32 = _mm256_unpackhi_epi8( c0_32, d0_32 ); + + b0_32 = _mm256_unpacklo_epi16( a01_32, c01_32 ); + a01_32 = _mm256_unpackhi_epi16( a01_32, c01_32 ); + + d0_32 = _mm256_unpacklo_epi16( a0_32, c0_32 ); + c01_32 = _mm256_unpackhi_epi16( a0_32, c0_32 ); + + a0_32 = _mm256_shuffle_i32x4( b0_32, a01_32, 0x0 ); // 0 elem + c0_32 = _mm256_shuffle_i32x4( b0_32, a01_32, 0x3 ); // 2 elem + b0_32 = _mm256_shuffle_i32x4( d0_32, c01_32, 0x0 ); // 1 elem + d0_32 = _mm256_shuffle_i32x4( d0_32, c01_32, 0x3 ); // 3 elem + + a0_zmm = _mm512_castsi256_si512( a0_32 ); + a0_zmm = _mm512_inserti32x8( a0_zmm, b0_32, 0x1 ); + b0_zmm = _mm512_castsi256_si512( c0_32 ); + b0_zmm = _mm512_inserti32x8( b0_zmm, d0_32, 0x1 ); + + // First 4x32 elements. + _mm512_storeu_si512( pack_b_buffer + ( ( kr_new + 0 ) * NR ), a0_zmm ); + _mm512_storeu_si512( pack_b_buffer + ( ( kr_new + 1 ) * NR ), b0_zmm ); + + // The 3rd and 4th 16byte chunk will be ignored, since its not part of + // the original data,but is here due to the packing in 4 16byte chunks format. + kr_new += 2; + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + if ( int4_upscale == FALSE ) + { + if ( k_partial_pieces == 3 ) + { + a0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, + b + ( rs_b * ( k_full_pieces + 0 ) ) ); + b0_32 = _mm256_maskz_loadu_epi8(0xFFFFFFFF, + b + (rs_b * (k_full_pieces + 1))); + c0_32 = _mm256_maskz_loadu_epi8(0xFFFFFFFF, + b + (rs_b * (k_full_pieces + 2))); + d0_32 = _mm256_setzero_si256(); + + } + else if( k_partial_pieces == 2 ) + { + a0_32 = _mm256_maskz_loadu_epi8(0xFFFFFFFF, + b + (rs_b * (k_full_pieces + 0))); + b0_32 = _mm256_maskz_loadu_epi8(0xFFFFFFFF, + b + (rs_b * (k_full_pieces + 1))); + c0_32 = _mm256_setzero_si256(); + d0_32 = _mm256_setzero_si256(); + } + else //k_partial_pieces == 1 + { + a0_32 = _mm256_maskz_loadu_epi8(0xFFFFFFFF, + b + (rs_b * (k_full_pieces + 0))); + b0_32 = _mm256_setzero_si256(); + c0_32 = _mm256_setzero_si256(); + d0_32 = _mm256_setzero_si256(); + } + } + else + { + if ( k_partial_pieces == 3 ) + { + __m128i h_a0_32 = _mm_maskz_loadu_epi8( hmask_32, + b + ( ( rs_b * ( k_full_pieces + 0 ) ) / 2 ) ); + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT(h_a0_32, a0_32, shift_idx_32, \ + sign_comp_32, signed_upscale); + + __m128i h_c0_32 = _mm_maskz_loadu_epi8( hmask_32, + b + ( ( rs_b * ( k_full_pieces + 2 ) ) / 2 ) ); + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT(h_c0_32, c0_32, shift_idx_32, \ + sign_comp_32, signed_upscale); + + d0_32 = _mm256_setzero_si256(); + + if (is_odd_stride == FALSE ) + { + __m128i h_b0_32 = _mm_maskz_loadu_epi8( hmask_32, + b + ( ( rs_b * ( k_full_pieces + 1 ) ) / 2 ) ); + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT(h_b0_32, b0_32, shift_idx_32, \ + sign_comp_32, signed_upscale); + } + else + { + __m128i h_b0_32 = _mm_maskz_loadu_epi8( hmask_32, + b + ( ( rs_b * ( k_full_pieces + 1 ) ) / 2 ) ); + // Only load the last byte/ 16th byte. + __m128i h_b0_32_l4bit = _mm_maskz_loadu_epi8( hmask_odd_32, + b + ( ( rs_b * ( k_full_pieces + 1 ) ) / 2 ) + 1 ); + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT_ODD(h_b0_32, h_b0_32_l4bit, b0_32, \ + shift_idx_32, conv_shift_32, sign_comp_32, signed_upscale); + } + + } + else if( k_partial_pieces == 2 ) + { + __m128i h_a0_32 = _mm_maskz_loadu_epi8( hmask_32, + b + ( ( rs_b * ( k_full_pieces + 0 ) ) / 2 ) ); + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT(h_a0_32, a0_32, shift_idx_32, \ + sign_comp_32, signed_upscale); + + c0_32 = _mm256_setzero_si256(); + d0_32 = _mm256_setzero_si256(); + + if (is_odd_stride == FALSE ) + { + __m128i h_b0_32 = _mm_maskz_loadu_epi8( hmask_32, + b + ( ( rs_b * ( k_full_pieces + 1 ) ) / 2 ) ); + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT(h_b0_32, b0_32, shift_idx_32, \ + sign_comp_32, signed_upscale); + } + else + { + __m128i h_b0_32 = _mm_maskz_loadu_epi8( hmask_32, + b + ( ( rs_b * ( k_full_pieces + 1 ) ) / 2 ) ); + // Only load the last byte/ 16th byte. + __m128i h_b0_32_l4bit = _mm_maskz_loadu_epi8( hmask_odd_32, + b + ( ( rs_b * ( k_full_pieces + 1 ) ) / 2 ) + 1 ); + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT_ODD(h_b0_32, h_b0_32_l4bit, b0_32, \ + shift_idx_32, conv_shift_32, sign_comp_32, signed_upscale); + } + } + else //k_partial_pieces == 1 + { + __m128i h_a0_32 = _mm_maskz_loadu_epi8( hmask_32, + b + ( ( rs_b * ( k_full_pieces + 0 ) ) / 2 ) ); + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT(h_a0_32, a0_32, shift_idx_32, \ + sign_comp_32, signed_upscale); + + b0_32 = _mm256_setzero_si256(); + c0_32 = _mm256_setzero_si256(); + d0_32 = _mm256_setzero_si256(); + } + } + + a01_32 = _mm256_unpacklo_epi8( a0_32, b0_32 ); + a0_32 = _mm256_unpackhi_epi8( a0_32, b0_32 ); + + c01_32 = _mm256_unpacklo_epi8( c0_32, d0_32 ); + c0_32 = _mm256_unpackhi_epi8( c0_32, d0_32 ); + + b0_32 = _mm256_unpacklo_epi16( a01_32, c01_32 ); + a01_32 = _mm256_unpackhi_epi16( a01_32, c01_32 ); + + d0_32 = _mm256_unpacklo_epi16( a0_32, c0_32 ); + c01_32 = _mm256_unpackhi_epi16( a0_32, c0_32 ); + + a0_32 = _mm256_shuffle_i32x4( b0_32, a01_32, 0x0 ); // 0 elem + c0_32 = _mm256_shuffle_i32x4( b0_32, a01_32, 0x3 ); // 2 elem + b0_32 = _mm256_shuffle_i32x4( d0_32, c01_32, 0x0 ); // 1 elem + d0_32 = _mm256_shuffle_i32x4( d0_32, c01_32, 0x3 ); // 3 elem + + a0_zmm = _mm512_castsi256_si512( a0_32 ); + a0_zmm = _mm512_inserti32x8( a0_zmm, b0_32, 0x1 ); + b0_zmm = _mm512_castsi256_si512( c0_32 ); + b0_zmm = _mm512_inserti32x8( b0_zmm, d0_32, 0x1 ); + + // First 4x32 elements. + _mm512_storeu_si512( pack_b_buffer + ( ( kr_new + 0 ) * NR ), a0_zmm ); + _mm512_storeu_si512( pack_b_buffer + ( ( kr_new + 1 ) * NR ), b0_zmm ); + } } -void packb_nr16_u8s8s32o32 +void packb_nr16_u8s8s32o32_row_major ( - int8_t* pack_b_buffer_u8s8s32o32, + int8_t* pack_b_buffer, const int8_t* b, - const dim_t ldb, - const dim_t KC + const dim_t rs_b, + const dim_t KC, + bool int4_upscale, + bool signed_upscale ) { - dim_t kr_new = 0; - - dim_t k_full_pieces_blks = KC / 4; - dim_t k_full_pieces = k_full_pieces_blks * 4; - dim_t k_partial_pieces = KC % 4; - - __m128i a0_16; - __m128i b0_16; - __m128i c0_16; - __m128i d0_16; - __m128i a01_16; - __m128i c01_16; - __m512i a0_zmm; - - for ( dim_t kr = 0; kr < k_full_pieces; kr += 4 ) - { - // Rearrange for vpdpbusd, read 4 rows from B with next 16 elements in each row. - a0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( kr + 0 ) ) ); - b0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( kr + 1 ) ) ); - c0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( kr + 2 ) ) ); - d0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( kr + 3 ) ) ); - - a01_16 = _mm_unpacklo_epi8( a0_16, b0_16 ); - a0_16 = _mm_unpackhi_epi8( a0_16, b0_16 ); - - c01_16 = _mm_unpacklo_epi8( c0_16, d0_16 ); - c0_16 = _mm_unpackhi_epi8( c0_16, d0_16 ); - - b0_16 = _mm_unpacklo_epi16( a01_16, c01_16 ); // 0 elem - a01_16 = _mm_unpackhi_epi16( a01_16, c01_16 ); // 1 elem - d0_16 = _mm_unpacklo_epi16( a0_16, c0_16 ); // 2 elem - c01_16 = _mm_unpackhi_epi16( a0_16, c0_16 ); // 3 elem - - a0_zmm = _mm512_castsi128_si512( b0_16 ); - a0_zmm = _mm512_inserti32x4( a0_zmm, a01_16, 0x1 ); - a0_zmm = _mm512_inserti32x4( a0_zmm, d0_16, 0x2 ); - a0_zmm = _mm512_inserti32x4( a0_zmm, c01_16, 0x3 ); - - // Last 4x16 elements. - _mm512_storeu_si512( pack_b_buffer_u8s8s32o32 + ( ( kr_new + 0 ) * NR ), a0_zmm ); - - // The 2nd, 3rd, and 4th 16byte chunk will be ignored, since its not part of the original data, - // but is here due to the packing in 4 16byte chunks format. - kr_new += 1; - } - // Handle k remainder. - if ( k_partial_pieces > 0 ) - { - if ( k_partial_pieces == 3 ) - { - a0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( k_full_pieces + 0 ) ) ); - b0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( k_full_pieces + 1 ) ) ); - c0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( k_full_pieces + 2 ) ) ); - d0_16 = _mm_setzero_si128(); - - } - else if( k_partial_pieces == 2 ) - { - a0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( k_full_pieces + 0 ) ) ); - b0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( k_full_pieces + 1 ) ) ); - c0_16 = _mm_setzero_si128(); - d0_16 = _mm_setzero_si128(); - } - else //k_partial_pieces == 1 - { - a0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( k_full_pieces + 0 ) ) ); - b0_16 = _mm_setzero_si128(); - c0_16 = _mm_setzero_si128(); - d0_16 = _mm_setzero_si128(); - } - - a01_16 = _mm_unpacklo_epi8( a0_16, b0_16 ); - a0_16 = _mm_unpackhi_epi8( a0_16, b0_16 ); - - c01_16 = _mm_unpacklo_epi8( c0_16, d0_16 ); - c0_16 = _mm_unpackhi_epi8( c0_16, d0_16 ); - - b0_16 = _mm_unpacklo_epi16( a01_16, c01_16 ); // 0 elem - a01_16 = _mm_unpackhi_epi16( a01_16, c01_16 ); // 1 elem - d0_16 = _mm_unpacklo_epi16( a0_16, c0_16 ); // 2 elem - c01_16 = _mm_unpackhi_epi16( a0_16, c0_16 ); // 3 elem - - __m512i a0_zmm = _mm512_castsi128_si512( b0_16 ); - a0_zmm = _mm512_inserti32x4( a0_zmm, a01_16, 0x1 ); - a0_zmm = _mm512_inserti32x4( a0_zmm, d0_16, 0x2 ); - a0_zmm = _mm512_inserti32x4( a0_zmm, c01_16, 0x3 ); - - // Last 4x16 elements. - _mm512_storeu_si512( pack_b_buffer_u8s8s32o32 + ( ( kr_new + 0 ) * NR ), a0_zmm ); - } + dim_t NR = 64; + dim_t kr_new = 0; + + dim_t k_full_pieces_blks = KC / 4; + dim_t k_full_pieces = k_full_pieces_blks * 4; + dim_t k_partial_pieces = KC % 4; + + bool is_odd_stride = ( ( rs_b % 2 ) == 0 ) ? FALSE : TRUE; + + __m128i a0_16; + __m128i b0_16; + __m128i c0_16; + __m128i d0_16; + __m128i a01_16; + __m128i c01_16; + __m512i a0_zmm; + + __m128i shift_idx_16; + MULTISHIFT_32BIT_8_INT4_IDX_16ELEM(shift_idx_16); + + __m128i sign_comp_16 = _mm_set1_epi8( 0x08 ); + __mmask16 hmask_16 = _cvtu32_mask16( 0x000000FF ); //8 bytes or 16 int4. + + __mmask16 hmask_odd_16 = _cvtu32_mask16( 0x00000080 ); // Last 1 int4. + + CREATE_CVT_INT4_INT8_PERM_IDX_16ELEM_ODD_LD(conv_shift_arr_16); + __m128i conv_shift_16 = _mm_maskz_loadu_epi64( _cvtu32_mask8( 0X000000FF ), + conv_shift_arr_16 ); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 4 ) + { + if ( int4_upscale == FALSE ) + { + // Rearrange for vpdpbusd, read 4 rows from B with next 16 elements in each row. + a0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( rs_b * ( kr + 0 ) ) ); + b0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( rs_b * ( kr + 1 ) ) ); + c0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( rs_b * ( kr + 2 ) ) ); + d0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( rs_b * ( kr + 3 ) ) ); + } + else + { + __m128i h_a0_16 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( rs_b * ( kr + 0 ) ) / 2 ) ); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(h_a0_16, a0_16, shift_idx_16, \ + sign_comp_16, signed_upscale); + + __m128i h_c0_16 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( rs_b * ( kr + 2 ) ) / 2 ) ); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(h_c0_16, c0_16, shift_idx_16, \ + sign_comp_16, signed_upscale); + + if (is_odd_stride == FALSE) + { + __m128i h_b0_16 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( rs_b * ( kr + 1 ) ) / 2 ) ); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(h_b0_16, b0_16, shift_idx_16, \ + sign_comp_16, signed_upscale); + + __m128i h_d0_16 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( rs_b * ( kr + 3 ) ) / 2 ) ); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(h_d0_16, d0_16, shift_idx_16, \ + sign_comp_16, signed_upscale); + } + else + { + __m128i h_b0_16 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( rs_b * ( kr + 1 ) ) / 2 ) ); + // Only load the last byte/ 8th byte. + __m128i h_b0_16_l4bit = _mm_maskz_loadu_epi8( hmask_odd_16, + b + ( ( rs_b * ( kr + 1 ) ) / 2 ) + 1 ); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT_ODD(h_b0_16, h_b0_16_l4bit, b0_16, \ + shift_idx_16, conv_shift_16, sign_comp_16, signed_upscale); + + __m128i h_d0_16 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( rs_b * ( kr + 3 ) ) / 2 ) ); + // Only load the last byte/ 8th byte. + __m128i h_d0_16_l4bit = _mm_maskz_loadu_epi8( hmask_odd_16, + b + ( ( rs_b * ( kr + 3 ) ) / 2 ) + 1 ); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT_ODD(h_d0_16, h_d0_16_l4bit, d0_16, \ + shift_idx_16, conv_shift_16, sign_comp_16, signed_upscale); + } + } + + a01_16 = _mm_unpacklo_epi8( a0_16, b0_16 ); + a0_16 = _mm_unpackhi_epi8( a0_16, b0_16 ); + + c01_16 = _mm_unpacklo_epi8( c0_16, d0_16 ); + c0_16 = _mm_unpackhi_epi8( c0_16, d0_16 ); + + b0_16 = _mm_unpacklo_epi16( a01_16, c01_16 ); // 0 elem + a01_16 = _mm_unpackhi_epi16( a01_16, c01_16 ); // 1 elem + d0_16 = _mm_unpacklo_epi16( a0_16, c0_16 ); // 2 elem + c01_16 = _mm_unpackhi_epi16( a0_16, c0_16 ); // 3 elem + + a0_zmm = _mm512_castsi128_si512( b0_16 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, a01_16, 0x1 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, d0_16, 0x2 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, c01_16, 0x3 ); + + // Last 4x16 elements. + _mm512_storeu_si512( pack_b_buffer + ( ( kr_new + 0 ) * NR ), a0_zmm ); + + // The 2nd, 3rd, and 4th 16byte chunk will be ignored, since its not part of + // the original data, but is here due to the packing in 4 16byte chunks format. + kr_new += 1; + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + if ( int4_upscale == FALSE ) + { + if ( k_partial_pieces == 3 ) + { + a0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( rs_b * ( k_full_pieces + 0 ) ) ); + b0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( rs_b * ( k_full_pieces + 1 ) ) ); + c0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( rs_b * ( k_full_pieces + 2 ) ) ); + d0_16 = _mm_setzero_si128(); + + } + else if( k_partial_pieces == 2 ) + { + a0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( rs_b * ( k_full_pieces + 0 ) ) ); + b0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( rs_b * ( k_full_pieces + 1 ) ) ); + c0_16 = _mm_setzero_si128(); + d0_16 = _mm_setzero_si128(); + } + else //k_partial_pieces == 1 + { + a0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( rs_b * ( k_full_pieces + 0 ) ) ); + b0_16 = _mm_setzero_si128(); + c0_16 = _mm_setzero_si128(); + d0_16 = _mm_setzero_si128(); + } + } + else + { + if ( k_partial_pieces == 3 ) + { + __m128i h_a0_16 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( rs_b * ( k_full_pieces + 0 ) ) / 2 ) ); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(h_a0_16, a0_16, shift_idx_16, \ + sign_comp_16, signed_upscale); + + __m128i h_c0_16 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( rs_b * ( k_full_pieces + 2 ) ) / 2 ) ); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(h_c0_16, c0_16, shift_idx_16, \ + sign_comp_16, signed_upscale); + + d0_16 = _mm_setzero_si128(); + + if (is_odd_stride == FALSE ) + { + __m128i h_b0_16 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( rs_b * ( k_full_pieces + 1 ) ) / 2 ) ); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(h_b0_16, b0_16, shift_idx_16, \ + sign_comp_16, signed_upscale); + } + else + { + __m128i h_b0_16 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( rs_b * ( k_full_pieces + 1 ) ) / 2 ) ); + // Only load the last byte/ 8th byte. + __m128i h_b0_16_l4bit = _mm_maskz_loadu_epi8( hmask_odd_16, + b + ( ( rs_b * ( k_full_pieces + 1 ) ) / 2 ) + 1 ); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT_ODD(h_b0_16, h_b0_16_l4bit, b0_16, \ + shift_idx_16, conv_shift_16, sign_comp_16, signed_upscale); + } + + } + else if( k_partial_pieces == 2 ) + { + __m128i h_a0_16 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( rs_b * ( k_full_pieces + 0 ) ) / 2 ) ); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(h_a0_16, a0_16, shift_idx_16, \ + sign_comp_16, signed_upscale); + + c0_16 = _mm_setzero_si128(); + d0_16 = _mm_setzero_si128(); + + if (is_odd_stride == FALSE ) + { + __m128i h_b0_16 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( rs_b * ( k_full_pieces + 1 ) ) / 2 ) ); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(h_b0_16, b0_16, shift_idx_16, \ + sign_comp_16, signed_upscale); + } + else + { + __m128i h_b0_16 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( rs_b * ( k_full_pieces + 1 ) ) / 2 ) ); + // Only load the last byte/ 8th byte. + __m128i h_b0_16_l4bit = _mm_maskz_loadu_epi8( hmask_odd_16, + b + ( ( rs_b * ( k_full_pieces + 1 ) ) / 2 ) + 1 ); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT_ODD(h_b0_16, h_b0_16_l4bit, b0_16, \ + shift_idx_16, conv_shift_16, sign_comp_16, signed_upscale); + } + } + else //k_partial_pieces == 1 + { + __m128i h_a0_16 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( rs_b * ( k_full_pieces + 0 ) ) / 2 ) ); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(h_a0_16, a0_16, shift_idx_16, \ + sign_comp_16, signed_upscale); + + b0_16 = _mm_setzero_si128(); + c0_16 = _mm_setzero_si128(); + d0_16 = _mm_setzero_si128(); + } + } + + a01_16 = _mm_unpacklo_epi8( a0_16, b0_16 ); + a0_16 = _mm_unpackhi_epi8( a0_16, b0_16 ); + + c01_16 = _mm_unpacklo_epi8( c0_16, d0_16 ); + c0_16 = _mm_unpackhi_epi8( c0_16, d0_16 ); + + b0_16 = _mm_unpacklo_epi16( a01_16, c01_16 ); // 0 elem + a01_16 = _mm_unpackhi_epi16( a01_16, c01_16 ); // 1 elem + d0_16 = _mm_unpacklo_epi16( a0_16, c0_16 ); // 2 elem + c01_16 = _mm_unpackhi_epi16( a0_16, c0_16 ); // 3 elem + + __m512i a0_zmm = _mm512_castsi128_si512( b0_16 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, a01_16, 0x1 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, d0_16, 0x2 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, c01_16, 0x3 ); + + // Last 4x16 elements. + _mm512_storeu_si512( pack_b_buffer + ( ( kr_new + 0 ) * NR ), a0_zmm ); + } } -void packb_nrlt16_u8s8s32o32 +void packb_nrlt16_u8s8s32o32_row_major ( - int8_t* pack_b_buffer_u8s8s32o32, + int8_t* pack_b_buffer, const int8_t* b, - const dim_t ldb, + const dim_t rs_b, const dim_t KC, - const dim_t n0_partial_rem + const dim_t n0_partial_rem, + bool int4_upscale, + bool signed_upscale ) { - int8_t buf0[16]; - int8_t buf1[16]; - int8_t buf2[16]; - int8_t buf3[16]; - - dim_t kr_new = 0; - - dim_t k_full_pieces_blks = KC / 4; - dim_t k_full_pieces = k_full_pieces_blks * 4; - dim_t k_partial_pieces = KC % 4; - - __m128i a0_16; - __m128i b0_16; - __m128i c0_16; - __m128i d0_16; - __m128i a01_16; - __m128i c01_16; - __m512i a0_zmm; - - for ( dim_t kr = 0; kr < k_full_pieces; kr += 4 ) - { - memcpy( buf0, ( b + ( ldb * ( kr + 0 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) ); - memcpy( buf1, ( b + ( ldb * ( kr + 1 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) ); - memcpy( buf2, ( b + ( ldb * ( kr + 2 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) ); - memcpy( buf3, ( b + ( ldb * ( kr + 3 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) ); - - // Rearrange for vpdpbusd, read 4 rows from B with next 16 elements in each row. - a0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf0 ); - b0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf1 ); - c0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf2 ); - d0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf3 ); - - a01_16 = _mm_unpacklo_epi8( a0_16, b0_16 ); - a0_16 = _mm_unpackhi_epi8( a0_16, b0_16 ); - - c01_16 = _mm_unpacklo_epi8( c0_16, d0_16 ); - c0_16 = _mm_unpackhi_epi8( c0_16, d0_16 ); - - b0_16 = _mm_unpacklo_epi16( a01_16, c01_16 ); // 0 elem - a01_16 = _mm_unpackhi_epi16( a01_16, c01_16 ); // 1 elem - d0_16 = _mm_unpacklo_epi16( a0_16, c0_16 ); // 2 elem - c01_16 = _mm_unpackhi_epi16( a0_16, c0_16 ); // 3 elem - - a0_zmm = _mm512_castsi128_si512( b0_16 ); - a0_zmm = _mm512_inserti32x4( a0_zmm, a01_16, 0x1 ); - a0_zmm = _mm512_inserti32x4( a0_zmm, d0_16, 0x2 ); - a0_zmm = _mm512_inserti32x4( a0_zmm, c01_16, 0x3 ); - - // Last 4x16 elements. - _mm512_storeu_si512( pack_b_buffer_u8s8s32o32 + ( ( kr_new + 0 ) * NR ), a0_zmm ); - - // The 2nd, 3rd, and 4th 16byte chunk will be ignored, since its not part of the original data, - // but is here due to the packing in 4 16byte chunks format. - kr_new += 1; - } - // Handle k remainder. - if ( k_partial_pieces > 0 ) - { - if ( k_partial_pieces == 3 ) - { - memcpy( buf0, ( b + ( ldb * ( k_full_pieces + 0 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) ); - memcpy( buf1, ( b + ( ldb * ( k_full_pieces + 1 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) ); - memcpy( buf2, ( b + ( ldb * ( k_full_pieces + 2 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) ); - - a0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf0 ); - b0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf1 ); - c0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf2 ); - d0_16 = _mm_setzero_si128(); - - } - else if( k_partial_pieces == 2 ) - { - memcpy( buf0, ( b + ( ldb * ( k_full_pieces + 0 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) ); - memcpy( buf1, ( b + ( ldb * ( k_full_pieces + 1 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) ); - - a0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf0 ); - b0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf1 ); - c0_16 = _mm_setzero_si128(); - d0_16 = _mm_setzero_si128(); - } - else //k_partial_pieces == 1 - { - memcpy( buf0, ( b + ( ldb * ( k_full_pieces + 0 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) ); - - a0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf0 ); - b0_16 = _mm_setzero_si128(); - c0_16 = _mm_setzero_si128(); - d0_16 = _mm_setzero_si128(); - } - - a01_16 = _mm_unpacklo_epi8( a0_16, b0_16 ); - a0_16 = _mm_unpackhi_epi8( a0_16, b0_16 ); - - c01_16 = _mm_unpacklo_epi8( c0_16, d0_16 ); - c0_16 = _mm_unpackhi_epi8( c0_16, d0_16 ); - - b0_16 = _mm_unpacklo_epi16( a01_16, c01_16 ); // 0 elem - a01_16 = _mm_unpackhi_epi16( a01_16, c01_16 ); // 1 elem - d0_16 = _mm_unpacklo_epi16( a0_16, c0_16 ); // 2 elem - c01_16 = _mm_unpackhi_epi16( a0_16, c0_16 ); // 3 elem - - __m512i a0_zmm = _mm512_castsi128_si512( b0_16 ); - a0_zmm = _mm512_inserti32x4( a0_zmm, a01_16, 0x1 ); - a0_zmm = _mm512_inserti32x4( a0_zmm, d0_16, 0x2 ); - a0_zmm = _mm512_inserti32x4( a0_zmm, c01_16, 0x3 ); - - // Last 4x16 elements. - _mm512_storeu_si512( pack_b_buffer_u8s8s32o32 + ( ( kr_new + 0 ) * NR ), a0_zmm ); - } + dim_t NR = 64; + + dim_t kr_new = 0; + + dim_t k_full_pieces_blks = KC / 4; + dim_t k_full_pieces = k_full_pieces_blks * 4; + dim_t k_partial_pieces = KC % 4; + + bool is_odd_stride = ( ( rs_b % 2 ) == 0 ) ? FALSE : TRUE; + + __m128i a0_16; + __m128i b0_16; + __m128i c0_16; + __m128i d0_16; + __m128i a01_16; + __m128i c01_16; + __m512i a0_zmm; + + __mmask16 lmask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_partial_rem ) ); + + __m128i shift_idx_16; + MULTISHIFT_32BIT_8_INT4_IDX_16ELEM(shift_idx_16); + + __m128i sign_comp_16 = _mm_set1_epi8( 0x08 ); + // 16 int4 elems in 8 bytes, so adjusting the mask for nr < 16 by + // a factor of 2. In case of odd remainder, the last int4 element + // within the last byte (hi 4 bits) will be ingnored similar to + // padding bits. + __mmask16 hmask_16; + if ( is_odd_stride == FALSE ) + { + hmask_16 = _cvtu32_mask16( 0x000000FF >> + ( ( 16 - n0_partial_rem ) / 2 ) ); + } + else + { + if ( ( n0_partial_rem % 2 ) == 0 ) + { + // An interesting property here is that n0_partial_rem is + // guaranteed to be < 16. In that case the largest even n0 + // rem would be 14, and the max number of bytes that will be + // loaded including the extra 4 bit at the beginning will + // only be 7 bytes out of 8. So in any case loading 1 more + // byte will bring the last int4 in the register, while not + // crossing the register boundaries. + hmask_16 = _cvtu32_mask16( 0x000000FF >> + ( ( ( 16 - n0_partial_rem ) / 2 ) - 1 ) ); + } + else + { + // If the n0 rem is odd, and if the starting position is an odd + // index, then the last odd element will also be loaded as part + // of loading the last byte (high 4 bits of last byte). + hmask_16 = _cvtu32_mask16( 0x000000FF >> + ( ( 16 - n0_partial_rem ) / 2 ) ); + } + } + + CREATE_CVT_INT4_INT8_PERM_IDX_16ELEM_ODD_LD(conv_shift_arr_16); + __m128i conv_shift_16 = _mm_maskz_loadu_epi64( _cvtu32_mask8( 0X000000FF ), + conv_shift_arr_16 ); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 4 ) + { + if ( int4_upscale == FALSE ) + { + // Rearrange for vpdpbusd, read 4 rows from B with next 16 elements + // in each row. + a0_16 = _mm_maskz_loadu_epi8( lmask, ( b + ( rs_b * ( kr + 0 ) ) ) ); + b0_16 = _mm_maskz_loadu_epi8( lmask, ( b + ( rs_b * ( kr + 1 ) ) ) ); + c0_16 = _mm_maskz_loadu_epi8( lmask, ( b + ( rs_b * ( kr + 2 ) ) ) ); + d0_16 = _mm_maskz_loadu_epi8( lmask, ( b + ( rs_b * ( kr + 3 ) ) ) ); + } + else + { + __m128i h_a0_16 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( rs_b * ( kr + 0 ) ) / 2 ) ); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(h_a0_16, a0_16, shift_idx_16, \ + sign_comp_16, signed_upscale); + + __m128i h_c0_16 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( rs_b * ( kr + 2 ) ) / 2 ) ); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(h_c0_16, c0_16, shift_idx_16, \ + sign_comp_16, signed_upscale); + + if (is_odd_stride == FALSE) + { + __m128i h_b0_16 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( rs_b * ( kr + 1 ) ) / 2 ) ); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(h_b0_16, b0_16, shift_idx_16, \ + sign_comp_16, signed_upscale); + + __m128i h_d0_16 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( rs_b * ( kr + 3 ) ) / 2 ) ); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(h_d0_16, d0_16, shift_idx_16, \ + sign_comp_16, signed_upscale); + } + else + { + __m128i h_b0_16 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( rs_b * ( kr + 1 ) ) / 2 ) ); + // The last int4 elem is already loaded in the previous + // register. Details given in comments about hmask_16. + __m128i h_b0_16_l4bit = _mm_setzero_si128(); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT_ODD(h_b0_16, h_b0_16_l4bit, b0_16, \ + shift_idx_16, conv_shift_16, sign_comp_16, signed_upscale); + + __m128i h_d0_16 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( rs_b * ( kr + 3 ) ) / 2 ) ); + __m128i h_d0_16_l4bit = _mm_setzero_si128(); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT_ODD(h_d0_16, h_d0_16_l4bit, d0_16, \ + shift_idx_16, conv_shift_16, sign_comp_16, signed_upscale); + } + } + + a01_16 = _mm_unpacklo_epi8( a0_16, b0_16 ); + a0_16 = _mm_unpackhi_epi8( a0_16, b0_16 ); + + c01_16 = _mm_unpacklo_epi8( c0_16, d0_16 ); + c0_16 = _mm_unpackhi_epi8( c0_16, d0_16 ); + + b0_16 = _mm_unpacklo_epi16( a01_16, c01_16 ); // 0 elem + a01_16 = _mm_unpackhi_epi16( a01_16, c01_16 ); // 1 elem + d0_16 = _mm_unpacklo_epi16( a0_16, c0_16 ); // 2 elem + c01_16 = _mm_unpackhi_epi16( a0_16, c0_16 ); // 3 elem + + a0_zmm = _mm512_castsi128_si512( b0_16 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, a01_16, 0x1 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, d0_16, 0x2 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, c01_16, 0x3 ); + + // Last 4x16 elements. + _mm512_storeu_si512( pack_b_buffer + ( ( kr_new + 0 ) * NR ), a0_zmm ); + + // The 2nd, 3rd, and 4th 16byte chunk will be ignored, since its not + // part of the original data, but is here due to the packing in 4 + // 16byte chunks format. + kr_new += 1; + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + if ( int4_upscale == FALSE ) + { + if ( k_partial_pieces == 3 ) + { + a0_16 = _mm_maskz_loadu_epi8( lmask, ( b + ( rs_b * ( k_full_pieces + 0 ) ) ) ); + b0_16 = _mm_maskz_loadu_epi8( lmask, ( b + ( rs_b * ( k_full_pieces + 1 ) ) ) ); + c0_16 = _mm_maskz_loadu_epi8( lmask, ( b + ( rs_b * ( k_full_pieces + 2 ) ) ) ); + d0_16 = _mm_setzero_si128(); + + } + else if( k_partial_pieces == 2 ) + { + a0_16 = _mm_maskz_loadu_epi8( lmask, ( b + ( rs_b * ( k_full_pieces + 0 ) ) ) ); + b0_16 = _mm_maskz_loadu_epi8( lmask, ( b + ( rs_b * ( k_full_pieces + 1 ) ) ) ); + c0_16 = _mm_setzero_si128(); + d0_16 = _mm_setzero_si128(); + } + else //k_partial_pieces == 1 + { + a0_16 = _mm_maskz_loadu_epi8( lmask, ( b + ( rs_b * ( k_full_pieces + 0 ) ) ) ); + b0_16 = _mm_setzero_si128(); + c0_16 = _mm_setzero_si128(); + d0_16 = _mm_setzero_si128(); + } + } + else + { + if ( k_partial_pieces == 3 ) + { + __m128i h_a0_16 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( rs_b * ( k_full_pieces + 0 ) ) / 2 ) ); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(h_a0_16, a0_16, shift_idx_16, \ + sign_comp_16, signed_upscale); + + __m128i h_c0_16 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( rs_b * ( k_full_pieces + 2 ) ) / 2 ) ); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(h_c0_16, c0_16, shift_idx_16, \ + sign_comp_16, signed_upscale); + + d0_16 = _mm_setzero_si128(); + + if (is_odd_stride == FALSE) + { + __m128i h_b0_16 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( rs_b * ( k_full_pieces + 1 ) ) / 2 ) ); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(h_b0_16, b0_16, shift_idx_16, \ + sign_comp_16, signed_upscale); + } + else + { + __m128i h_b0_16 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( rs_b * ( k_full_pieces + 1 ) ) / 2 ) ); + // The last int4 elem is already loaded in the previous + // register. Details given in comments about hmask_16. + __m128i h_b0_16_l4bit = _mm_setzero_si128(); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT_ODD(h_b0_16, h_b0_16_l4bit, b0_16, \ + shift_idx_16, conv_shift_16, sign_comp_16, signed_upscale); + } + } + else if( k_partial_pieces == 2 ) + { + __m128i h_a0_16 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( rs_b * ( k_full_pieces + 0 ) ) / 2 ) ); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(h_a0_16, a0_16, shift_idx_16, \ + sign_comp_16, signed_upscale); + + c0_16 = _mm_setzero_si128(); + d0_16 = _mm_setzero_si128(); + + if (is_odd_stride == FALSE) + { + __m128i h_b0_16 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( rs_b * ( k_full_pieces + 1 ) ) / 2 ) ); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(h_b0_16, b0_16, shift_idx_16, \ + sign_comp_16, signed_upscale); + } + else + { + __m128i h_b0_16 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( rs_b * ( k_full_pieces + 1 ) ) / 2 ) ); + // The last int4 elem is already loaded in the previous + // register. Details given in comments about hmask_16. + __m128i h_b0_16_l4bit = _mm_setzero_si128(); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT_ODD(h_b0_16, h_b0_16_l4bit, b0_16, \ + shift_idx_16, conv_shift_16, sign_comp_16, signed_upscale); + } + } + else //k_partial_pieces == 1 + { + __m128i h_a0_16 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( rs_b * ( k_full_pieces + 0 ) ) / 2 ) ); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(h_a0_16, a0_16, shift_idx_16, \ + sign_comp_16, signed_upscale); + + b0_16 = _mm_setzero_si128(); + c0_16 = _mm_setzero_si128(); + d0_16 = _mm_setzero_si128(); + } + } + + a01_16 = _mm_unpacklo_epi8( a0_16, b0_16 ); + a0_16 = _mm_unpackhi_epi8( a0_16, b0_16 ); + + c01_16 = _mm_unpacklo_epi8( c0_16, d0_16 ); + c0_16 = _mm_unpackhi_epi8( c0_16, d0_16 ); + + b0_16 = _mm_unpacklo_epi16( a01_16, c01_16 ); // 0 elem + a01_16 = _mm_unpackhi_epi16( a01_16, c01_16 ); // 1 elem + d0_16 = _mm_unpacklo_epi16( a0_16, c0_16 ); // 2 elem + c01_16 = _mm_unpackhi_epi16( a0_16, c0_16 ); // 3 elem + + __m512i a0_zmm = _mm512_castsi128_si512( b0_16 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, a01_16, 0x1 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, d0_16, 0x2 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, c01_16, 0x3 ); + + // Last 4x16 elements. + _mm512_storeu_si512( pack_b_buffer + ( ( kr_new + 0 ) * NR ), a0_zmm ); + } +} + +void packb_nr64_u8s8s32o32_col_major( + int8_t *pack_b_buffer, + const int8_t *b, + const dim_t ldb, + const dim_t NC, + const dim_t KC, + dim_t *rs_p, + dim_t *cs_p) +{ + dim_t NR = 64; + + dim_t n_full_pieces = NC / NR; + dim_t n_full_pieces_loop_limit = n_full_pieces * NR; + dim_t n_partial_pieces = NC % NR; + + dim_t k_partial_pieces = KC % 4; + + dim_t KC_updated = KC; + if (k_partial_pieces > 0) + { + KC_updated += (4 - k_partial_pieces); + } + + for (dim_t jc = 0; jc < n_full_pieces_loop_limit; jc += NR) + { + packb_nr_mult_16_u8s8s32o32_col_major(pack_b_buffer + (jc * KC_updated), + b + (jc * ldb), 64, ldb, KC); + } + + if (n_partial_pieces > 0) + { + dim_t n0_partial_rem = n_partial_pieces % 16; + dim_t n0_partial_pack = 0; + + // Split into multiple smaller fringe kernels, so as to maximize + // vectorization after packing. Any n0 < NR(64) can be expressed + // as n0 = 48 + n` / n0 = 32 + n` / n0 = 16 + n`, where n` < 16. + dim_t n0_48 = n_partial_pieces / 48; + dim_t n0_32 = n_partial_pieces / 32; + dim_t n0_16 = n_partial_pieces / 16; + + if (n0_48 == 1) + { + packb_nr_mult_16_u8s8s32o32_col_major( + (pack_b_buffer + (n_full_pieces_loop_limit * KC_updated)), + (b + n_full_pieces_loop_limit * ldb), 48, ldb, KC); + + n0_partial_pack = 48; + } + else if (n0_32 == 1) + { + packb_nr_mult_16_u8s8s32o32_col_major( + (pack_b_buffer + (n_full_pieces_loop_limit * KC_updated)), + (b + n_full_pieces_loop_limit * ldb), 32, ldb, KC); + + n0_partial_pack = 32; + } + else if (n0_16 == 1) + { + packb_nr_mult_16_u8s8s32o32_col_major( + (pack_b_buffer + (n_full_pieces_loop_limit * KC_updated)), + (b + n_full_pieces_loop_limit * ldb), 16, ldb, KC); + + n0_partial_pack = 16; + } + + if (n0_partial_rem > 0) + { + packb_nrlt16_u8s8s32o32_col_major( + (pack_b_buffer + (n_full_pieces_loop_limit * KC_updated) + + (n0_partial_pack * KC_updated)), + (b + (n_full_pieces_loop_limit + n0_partial_pack) * ldb), ldb, KC, + n0_partial_rem); + } + } + + *rs_p = NR * 4; + *cs_p = NR / 4; +} + +void packb_nr_mult_16_u8s8s32o32_col_major( + int8_t *pack_b_buffer, + const int8_t *b, + const dim_t NR, + const dim_t ldb, + const dim_t KC) +{ + // Used for permuting the mm512i elements for use in vpdpbusd instruction. + __m512i selector1 = _mm512_setr_epi64(0x0, 0x1, 0x8, 0x9, 0x4, 0x5, 0xC, 0xD); + __m512i selector2 = _mm512_setr_epi64(0x2, 0x3, 0xA, 0xB, 0x6, 0x7, 0xE, 0xF); + + __m512i a_reg[16]; + __m512i b_reg[16]; + + dim_t kr = 0; + for (kr = 0; (kr + 63) < KC; kr += 64) + { + for (dim_t jr = 0; jr < NR; jr += 16) + { + // Rearrange for vpdpbusd, read 4 rows from B with 64 elements in each row. + LOAD_16_COLS_AVX512 + UNPACKHILO32_AVX512 + UNPACKHILO64_AVX512 + PERMUTEX2_VAR64_AVX512 + SHUFFLE64x2_AVX512 + + // store to pack_b buffer + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + ((kr + 0) * NR), a_reg[0]); + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + ((kr + 4) * NR), a_reg[1]); + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + ((kr + 8) * NR), a_reg[2]); + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + ((kr + 12) * NR), a_reg[3]); + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + ((kr + 16) * NR), a_reg[4]); + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + ((kr + 20) * NR), a_reg[5]); + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + ((kr + 24) * NR), a_reg[6]); + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + ((kr + 28) * NR), a_reg[7]); + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + ((kr + 32) * NR), a_reg[8]); + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + ((kr + 36) * NR), a_reg[9]); + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + ((kr + 40) * NR), a_reg[10]); + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + ((kr + 44) * NR), a_reg[11]); + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + ((kr + 48) * NR), a_reg[12]); + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + ((kr + 52) * NR), a_reg[13]); + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + ((kr + 56) * NR), a_reg[14]); + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + ((kr + 60) * NR), a_reg[15]); + } + } + + for (; (kr + 31) < KC; kr += 32) + { + for (dim_t jr = 0; jr < NR; jr += 16) + { + // Rearrange for vpdpbusd, read 4 rows from B with 64 elements in each row. + MASK_LOAD_16_COLS_AVX512((__mmask64 )0xFFFFFFFF) + UNPACKHILO32_AVX512 + UNPACKHILO64_AVX512 + PERMUTEX2_VAR64_AVX512 + SHUFFLE64x2_AVX512 + + // store to pack_b buffer + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + ((kr + 0) * NR), a_reg[0]); + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + ((kr + 4) * NR), a_reg[1]); + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + ((kr + 8) * NR), a_reg[2]); + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + ((kr + 12) * NR), a_reg[3]); + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + ((kr + 16) * NR), a_reg[4]); + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + ((kr + 20) * NR), a_reg[5]); + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + ((kr + 24) * NR), a_reg[6]); + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + ((kr + 28) * NR), a_reg[7]); + } + } + + for (; (kr + 15) < KC; kr += 16) + { + for (dim_t jr = 0; jr < NR; jr += 16) + { + // Rearrange for vpdpbusd, read 4 rows from B with 64 elements in each row. + MASK_LOAD_16_COLS_AVX512((__mmask64)0xFFFF) + UNPACKHILO32_AVX512 + UNPACKHILO64_AVX512 + PERMUTEX2_VAR64_AVX512 + SHUFFLE64x2_AVX512 + + // store to pack_b buffer + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + ((kr + 0) * NR), a_reg[0]); + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + ((kr + 4) * NR), a_reg[1]); + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + ((kr + 8) * NR), a_reg[2]); + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + ((kr + 12) * NR), a_reg[3]); + } + } + + for (; (kr + 7) < KC; kr += 8) + { + for (dim_t jr = 0; jr < NR; jr += 16) + { + // Rearrange for vpdpbusd, read 4 rows from B with 64 elements in each row. + MASK_LOAD_16_COLS_AVX512((__mmask64)0xFF) + UNPACKHILO32_AVX512 + UNPACKHILO64_AVX512 + PERMUTEX2_VAR64_AVX512 + SHUFFLE64x2_AVX512 + + // store to pack_b buffer + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + ((kr + 0) * NR), a_reg[0]); + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + ((kr + 4) * NR), a_reg[1]); + } + } + + for (; (kr + 3) < KC; kr += 4) + { + for (dim_t jr = 0; jr < NR; jr += 16) + { + // Rearrange for vpdpbusd, read 4 rows from B with 64 elements in each row. + MASK_LOAD_16_COLS_AVX512((__mmask64)0x0F) + UNPACKHILO32_AVX512 + UNPACKHILO64_AVX512 + PERMUTEX2_VAR64_AVX512 + SHUFFLE64x2_AVX512 + + // store to pack_b buffer + _mm512_storeu_si512(pack_b_buffer + (jr * 4) + (kr * NR), a_reg[0]); + } + } + + for (; (kr + 2) < KC; kr += 3) + { + for (dim_t jr = 0; jr < NR; jr += 16) + { + // Rearrange for vpdpbusd, read 4 rows from B with 64 elements in each row. + MASK_LOAD_16_COLS_AVX512((__mmask64)0x07) + UNPACKHILO32_AVX512 + UNPACKHILO64_AVX512 + PERMUTEX2_VAR64_AVX512 + SHUFFLE64x2_AVX512 + + // store to pack_b buffer + _mm512_storeu_si512((pack_b_buffer + (jr * 4) + (kr * NR)), a_reg[0]); + } + } + + for (; (kr + 1) < KC; kr += 2) + { + for (dim_t jr = 0; jr < NR; jr += 16) + { + // Rearrange for vpdpbusd, read 4 rows from B with 64 elements in each row. + MASK_LOAD_16_COLS_AVX512((__mmask64)0x03) + UNPACKHILO32_AVX512 + UNPACKHILO64_AVX512 + PERMUTEX2_VAR64_AVX512 + SHUFFLE64x2_AVX512 + + // store to pack_b buffer + _mm512_storeu_si512((pack_b_buffer + (jr * 4) + (kr * NR)), a_reg[0]); + } + } + + for (; kr < KC; kr += 1) + { + for (dim_t jr = 0; jr < NR; jr += 16) + { + // Rearrange for vpdpbusd, read 4 rows from B with 64 elements in each row. + MASK_LOAD_16_COLS_AVX512((__mmask64)0x01) + UNPACKHILO32_AVX512 + UNPACKHILO64_AVX512 + PERMUTEX2_VAR64_AVX512 + SHUFFLE64x2_AVX512 + + // store to pack_b buffer + _mm512_storeu_si512((pack_b_buffer + (jr * 4) + (kr * NR)), a_reg[0]); + } + } } + +void packb_nrlt16_u8s8s32o32_col_major( + int8_t *pack_b_buffer, + const int8_t *b, + const dim_t ldb, + const dim_t KC, + const dim_t n0_partial_rem) +{ + dim_t NR = 16; + + // Used for permuting the mm512i elements for use in vpdpbusd instruction. + __m512i selector1 = _mm512_setr_epi64(0x0, 0x1, 0x8, 0x9, 0x4, 0x5, 0xC, 0xD); + __m512i selector2 = _mm512_setr_epi64(0x2, 0x3, 0xA, 0xB, 0x6, 0x7, 0xE, 0xF); + + __m512i a_reg[16]; + __m512i b_reg[16]; + + dim_t kr = 0, jr = 0; + for (kr = 0; (kr + 63) < KC; kr += 64) + { + for (jr = 0; jr < n0_partial_rem; jr += 1) + { + // Rearrange for vpdpbusd, read 4 rows from B with 64 elements in each row. + a_reg[jr] = _mm512_loadu_si512(b + (ldb * (jr + 0)) + kr); + } + + for (; jr < NR; jr++) + { + a_reg[jr] = _mm512_setzero_si512(); + } + + UNPACKHILO32_AVX512 + UNPACKHILO64_AVX512 + PERMUTEX2_VAR64_AVX512 + SHUFFLE64x2_AVX512 + + _mm512_storeu_si512(pack_b_buffer + ((kr + 0) * NR), a_reg[0]); + _mm512_storeu_si512(pack_b_buffer + ((kr + 4) * NR), a_reg[1]); + _mm512_storeu_si512(pack_b_buffer + ((kr + 8) * NR), a_reg[2]); + _mm512_storeu_si512(pack_b_buffer + ((kr + 12) * NR), a_reg[3]); + _mm512_storeu_si512(pack_b_buffer + ((kr + 16) * NR), a_reg[4]); + _mm512_storeu_si512(pack_b_buffer + ((kr + 20) * NR), a_reg[5]); + _mm512_storeu_si512(pack_b_buffer + ((kr + 24) * NR), a_reg[6]); + _mm512_storeu_si512(pack_b_buffer + ((kr + 28) * NR), a_reg[7]); + _mm512_storeu_si512(pack_b_buffer + ((kr + 32) * NR), a_reg[8]); + _mm512_storeu_si512(pack_b_buffer + ((kr + 36) * NR), a_reg[9]); + _mm512_storeu_si512(pack_b_buffer + ((kr + 40) * NR), a_reg[10]); + _mm512_storeu_si512(pack_b_buffer + ((kr + 44) * NR), a_reg[11]); + _mm512_storeu_si512(pack_b_buffer + ((kr + 48) * NR), a_reg[12]); + _mm512_storeu_si512(pack_b_buffer + ((kr + 52) * NR), a_reg[13]); + _mm512_storeu_si512(pack_b_buffer + ((kr + 56) * NR), a_reg[14]); + _mm512_storeu_si512(pack_b_buffer + ((kr + 60) * NR), a_reg[15]); + } + + for (; (kr + 31) < KC; kr += 32) + { + for (jr = 0; jr < n0_partial_rem; jr += 1) + { + // Rearrange for vpdpbusd, read 4 rows from B with 64 elements in each row. + a_reg[jr] = _mm512_maskz_loadu_epi8(0xFFFFFFFF, b + (ldb * (jr + 0)) + kr); + } + + for (; jr < NR; jr++) + { + a_reg[jr] = _mm512_setzero_si512(); + } + + UNPACKHILO32_AVX512 + UNPACKHILO64_AVX512 + PERMUTEX2_VAR64_AVX512 + SHUFFLE64x2_AVX512 + + _mm512_storeu_si512(pack_b_buffer + ((kr + 0) * NR), a_reg[0]); + _mm512_storeu_si512(pack_b_buffer + ((kr + 4) * NR), a_reg[1]); + _mm512_storeu_si512(pack_b_buffer + ((kr + 8) * NR), a_reg[2]); + _mm512_storeu_si512(pack_b_buffer + ((kr + 12) * NR), a_reg[3]); + _mm512_storeu_si512(pack_b_buffer + ((kr + 16) * NR), a_reg[4]); + _mm512_storeu_si512(pack_b_buffer + ((kr + 20) * NR), a_reg[5]); + _mm512_storeu_si512(pack_b_buffer + ((kr + 24) * NR), a_reg[6]); + _mm512_storeu_si512(pack_b_buffer + ((kr + 28) * NR), a_reg[7]); + } + + for (; (kr + 15) < KC; kr += 16) + { + for (jr = 0; jr < n0_partial_rem; jr += 1) + { + // Rearrange for vpdpbusd, read 4 rows from B with 64 elements in each row. + a_reg[jr] = _mm512_maskz_loadu_epi8(0xFFFF, b + (ldb * (jr + 0)) + kr); + } + + for (; jr < NR; jr++) + { + a_reg[jr] = _mm512_setzero_si512(); + } + + UNPACKHILO32_AVX512 + UNPACKHILO64_AVX512 + PERMUTEX2_VAR64_AVX512 + SHUFFLE64x2_AVX512 + + _mm512_storeu_si512(pack_b_buffer + ((kr + 0) * NR), a_reg[0]); + _mm512_storeu_si512(pack_b_buffer + ((kr + 4) * NR), a_reg[1]); + _mm512_storeu_si512(pack_b_buffer + ((kr + 8) * NR), a_reg[2]); + _mm512_storeu_si512(pack_b_buffer + ((kr + 12) * NR), a_reg[3]); + } + + for (; (kr + 7) < KC; kr += 8) + { + for (jr = 0; jr < n0_partial_rem; jr += 1) + { + // Rearrange for vpdpbusd, read 4 rows from B with 64 elements in each row. + a_reg[jr] = _mm512_maskz_loadu_epi8(0xFF, b + (ldb * (jr + 0)) + kr); + } + + for (; jr < NR; jr++) + { + a_reg[jr] = _mm512_setzero_si512(); + } + + UNPACKHILO32_AVX512 + UNPACKHILO64_AVX512 + PERMUTEX2_VAR64_AVX512 + SHUFFLE64x2_AVX512 + + _mm512_storeu_si512(pack_b_buffer + ((kr + 0) * NR), a_reg[0]); + _mm512_storeu_si512(pack_b_buffer + ((kr + 4) * NR), a_reg[1]); + } + + for (; (kr + 3) < KC; kr += 4) + { + for (jr = 0; jr < n0_partial_rem; jr += 1) + { + // Rearrange for vpdpbusd, read 4 rows from B with 64 elements in each row. + a_reg[jr] = _mm512_maskz_loadu_epi8(0x0F, b + (ldb * (jr + 0)) + kr); + } + + for (; jr < NR; jr++) + { + a_reg[jr] = _mm512_setzero_si512(); + } + + UNPACKHILO32_AVX512 + UNPACKHILO64_AVX512 + PERMUTEX2_VAR64_AVX512 + SHUFFLE64x2_AVX512 + + _mm512_storeu_si512(pack_b_buffer + ((kr + 0) * NR), a_reg[0]); + } + + for (; (kr + 2) < KC; kr += 3) + { + for (jr = 0; jr < n0_partial_rem; jr += 1) + { + // Rearrange for vpdpbusd, read 4 rows from B with 64 elements in each row. + a_reg[jr] = _mm512_maskz_loadu_epi8(0x07, b + (ldb * (jr + 0)) + kr); + } + + for (; jr < NR; jr++) + { + a_reg[jr] = _mm512_setzero_si512(); + } + UNPACKHILO32_AVX512 + UNPACKHILO64_AVX512 + PERMUTEX2_VAR64_AVX512 + SHUFFLE64x2_AVX512 + + _mm512_storeu_si512(pack_b_buffer + ((kr + 0) * NR), a_reg[0]); + } + + for (; (kr + 1) < KC; kr += 2) + { + for (jr = 0; jr < n0_partial_rem; jr += 1) + { + // Rearrange for vpdpbusd, read 4 rows from B with 64 elements in each row. + a_reg[jr] = _mm512_maskz_loadu_epi8(0x03, b + (ldb * (jr + 0)) + kr); + } + + for (; jr < NR; jr++) + { + a_reg[jr] = _mm512_setzero_si512(); + } + UNPACKHILO32_AVX512 + UNPACKHILO64_AVX512 + PERMUTEX2_VAR64_AVX512 + SHUFFLE64x2_AVX512 + + _mm512_storeu_si512(pack_b_buffer + ((kr + 0) * NR), a_reg[0]); + } + + for (; kr < KC; kr += 1) + { + for (jr = 0; jr < n0_partial_rem; jr += 1) + { + // Rearrange for vpdpbusd, read 4 rows from B with 64 elements in each row. + a_reg[jr] = _mm512_maskz_loadu_epi8(0x01, b + (ldb * (jr + 0)) + kr); + } + + for (; jr < NR; jr++) + { + a_reg[jr] = _mm512_setzero_si512(); + } + + UNPACKHILO32_AVX512 + UNPACKHILO64_AVX512 + PERMUTEX2_VAR64_AVX512 + SHUFFLE64x2_AVX512 + + _mm512_storeu_si512(pack_b_buffer + ((kr + 0) * NR), a_reg[0]); + } +} + #endif diff --git a/kernels/zen4/lpgemm/u8s8s32/lpgemm_s32_kern_macros.h b/kernels/zen4/lpgemm/u8s8s32/lpgemm_s32_kern_macros.h index 1e91381001..ed817c14a4 100644 --- a/kernels/zen4/lpgemm/u8s8s32/lpgemm_s32_kern_macros.h +++ b/kernels/zen4/lpgemm/u8s8s32/lpgemm_s32_kern_macros.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -36,6 +36,7 @@ #define LPGEMM_S32_KERN_MACROS_H #include "../gelu_avx512.h" +#include "../silu_avx512.h" #include "../math_utils_avx512.h" #define S32_BETA_FMA(reg,scratch1,scratch2) \ @@ -97,7 +98,7 @@ S32_BETA_FMA(reg,scratch1,scratch2) \ // Default n < 16 mask load beta macro -#define S32_S32_BETA_OP_NLT16F_MASK(lmask,reg,m_ir,m_ind,n_ind,scratch1,scratch2) \ +#define S32_S32_BETA_OP_NLT16F_MASK(c,lmask,reg,m_ir,m_ind,n_ind,scratch1,scratch2) \ scratch1 = _mm512_maskz_loadu_epi32( lmask, c + ( rs_c * ( m_ir + m_ind ) ) + ( n_ind * 16 ) ); \ S32_BETA_FMA(reg,scratch1,scratch2) \ @@ -161,7 +162,7 @@ ); \ reg = _mm512_add_epi32( reg, _mm512_cvtepi8_epi32( zero_point ) ); \ -/* TANH GeLU (x) = 0.5* x * (1 + tanh ( 0.797884 * ( x + ( 0.044715 * x^3 ) ) ) ) */ +/* TANH GeLU (x) = 0.5* x * (1 + tanh ( 0.797884 * ( x + ( 0.044715 * x^3 ) ) ) ) */ #define GELU_TANH_S32_AVX512(reg, y, r, r2, x, z, dn, x_tanh, q) \ \ y = _mm512_cvtepi32_ps( reg ); \ @@ -183,7 +184,7 @@ \ reg = _mm512_min_epi32( _mm512_max_epi32( reg, min ), max ); \ -// Load helper macros. +// Gelu load helper macros. #define S32_GELU_LOAD1R_1C(temp_buf,offset,stride,reg_base) \ _mm512_storeu_si512( ( temp_buf ) + ( ( 0 + offset ) * ( stride ) ), reg_base ## p0); \ @@ -202,7 +203,7 @@ _mm512_storeu_si512( ( temp_buf ) + ( ( 2 + offset ) * ( stride ) ), reg_base ## p2); \ _mm512_storeu_si512( ( temp_buf ) + ( ( 3 + offset ) * ( stride ) ), reg_base ## p3); \ -// Store helper macros. +// Gelu store helper macros. #define S32_GELU_STORE1R_1C(temp_buf,offset,stride,reg_base) \ reg_base ## p0 = _mm512_loadu_si512( ( temp_buf ) + ( ( 0 + offset ) * ( stride ) ) ); \ @@ -221,4 +222,122 @@ reg_base ## p2 = _mm512_loadu_si512( ( temp_buf ) + ( ( 2 + offset ) * ( stride ) ) ); \ reg_base ## p3 = _mm512_loadu_si512( ( temp_buf ) + ( ( 3 + offset ) * ( stride ) ) ); \ +// Matrix Add post-ops helper macros +#define S32_MATRIX_ADD_1COL(scr0,m_ind) \ + c_int32_ ## m_ind ## p0 = _mm512_add_epi32( scr0, c_int32_ ## m_ind ## p0 ); \ + +#define S32_MATRIX_ADD_2COL(scr0,scr1,m_ind) \ + c_int32_ ## m_ind ## p0 = _mm512_add_epi32( scr0, c_int32_ ## m_ind ## p0 ); \ + c_int32_ ## m_ind ## p1 = _mm512_add_epi32( scr1, c_int32_ ## m_ind ## p1 ); \ + +#define S32_MATRIX_ADD_3COL(scr0,scr1,scr2,m_ind) \ + c_int32_ ## m_ind ## p0 = _mm512_add_epi32( scr0, c_int32_ ## m_ind ## p0 ); \ + c_int32_ ## m_ind ## p1 = _mm512_add_epi32( scr1, c_int32_ ## m_ind ## p1 ); \ + c_int32_ ## m_ind ## p2 = _mm512_add_epi32( scr2, c_int32_ ## m_ind ## p2 ); \ + +#define S32_MATRIX_ADD_4COL(scr0,scr1,scr2,scr3,m_ind) \ + c_int32_ ## m_ind ## p0 = _mm512_add_epi32( scr0, c_int32_ ## m_ind ## p0 ); \ + c_int32_ ## m_ind ## p1 = _mm512_add_epi32( scr1, c_int32_ ## m_ind ## p1 ); \ + c_int32_ ## m_ind ## p2 = _mm512_add_epi32( scr2, c_int32_ ## m_ind ## p2 ); \ + c_int32_ ## m_ind ## p3 = _mm512_add_epi32( scr3, c_int32_ ## m_ind ## p3 ); \ + +#define S8_S32_MATRIX_ADD_LOAD(mask,scr,m_ind,n_ind) \ + scr = _mm512_cvtepi8_epi32 \ + ( \ + _mm_maskz_loadu_epi8 \ + ( \ + mask, \ + matptr + ( ( post_ops_attr.post_op_c_i + m_ind ) * ldm ) + \ + post_ops_attr.post_op_c_j + ( n_ind * 16 ) \ + ) \ + ); \ + +#define S8_S32_MATRIX_ADD_1COL_PAR(mask,scr0,m_ind) \ + S8_S32_MATRIX_ADD_LOAD(mask,scr0,m_ind,0); \ + S32_MATRIX_ADD_1COL(scr0,m_ind); \ + +#define S8_S32_MATRIX_ADD_1COL(scr0,m_ind) \ + S8_S32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,m_ind,0); \ + S32_MATRIX_ADD_1COL(scr0,m_ind); \ + +#define S8_S32_MATRIX_ADD_2COL(scr0,scr1,m_ind) \ + S8_S32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,m_ind,0); \ + S8_S32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,m_ind,1); \ + S32_MATRIX_ADD_2COL(scr0,scr1,m_ind); \ + +#define S8_S32_MATRIX_ADD_3COL(scr0,scr1,scr2,m_ind) \ + S8_S32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,m_ind,0); \ + S8_S32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,m_ind,1); \ + S8_S32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr2,m_ind,2); \ + S32_MATRIX_ADD_3COL(scr0,scr1,scr2,m_ind); \ + +#define S8_S32_MATRIX_ADD_4COL(scr0,scr1,scr2,scr3,m_ind) \ + S8_S32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,m_ind,0); \ + S8_S32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,m_ind,1); \ + S8_S32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr2,m_ind,2); \ + S8_S32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr3,m_ind,3); \ + S32_MATRIX_ADD_4COL(scr0,scr1,scr2,scr3,m_ind); \ + +#define S32_S32_MATRIX_ADD_LOAD(mask,scr,m_ind,n_ind) \ + scr = _mm512_maskz_loadu_epi32 \ + ( \ + mask, \ + matptr + ( ( post_ops_attr.post_op_c_i + m_ind ) * ldm ) + \ + post_ops_attr.post_op_c_j + ( n_ind * 16 ) \ + ); \ + +#define S32_S32_MATRIX_ADD_1COL_PAR(mask,scr0,m_ind) \ + S32_S32_MATRIX_ADD_LOAD(mask,scr0,m_ind,0); \ + S32_MATRIX_ADD_1COL(scr0,m_ind); \ + +#define S32_S32_MATRIX_ADD_1COL(scr0,m_ind) \ + S32_S32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,m_ind,0); \ + S32_MATRIX_ADD_1COL(scr0,m_ind); \ + +#define S32_S32_MATRIX_ADD_2COL(scr0,scr1,m_ind) \ + S32_S32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,m_ind,0); \ + S32_S32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,m_ind,1); \ + S32_MATRIX_ADD_2COL(scr0,scr1,m_ind); \ + +#define S32_S32_MATRIX_ADD_3COL(scr0,scr1,scr2,m_ind) \ + S32_S32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,m_ind,0); \ + S32_S32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,m_ind,1); \ + S32_S32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr2,m_ind,2); \ + S32_MATRIX_ADD_3COL(scr0,scr1,scr2,m_ind); \ + +#define S32_S32_MATRIX_ADD_4COL(scr0,scr1,scr2,scr3,m_ind) \ + S32_S32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,m_ind,0); \ + S32_S32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,m_ind,1); \ + S32_S32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr2,m_ind,2); \ + S32_S32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr3,m_ind,3); \ + S32_MATRIX_ADD_4COL(scr0,scr1,scr2,scr3,m_ind); \ + +// SiLU utility macros. al register expected to contains floats. +#define SWISH_S32_AVX512(in_reg, fl_reg, al, al_in, r, r2, z, dn, ex_out) \ + fl_reg = _mm512_cvtepi32_ps( in_reg ); \ + SWISH_F32_AVX512_DEF( fl_reg, al, al_in, r, r2, z, dn, ex_out); \ + in_reg = _mm512_cvtps_epi32( fl_reg ); \ + +//Zero-out the given ZMM accumulator registers +#define ZERO_ACC_ZMM_4_REG(zmm0,zmm1,zmm2,zmm3) \ + zmm0 = _mm512_setzero_epi32(); \ + zmm1 = _mm512_setzero_epi32(); \ + zmm2 = _mm512_setzero_epi32(); \ + zmm3 = _mm512_setzero_epi32(); + +#define ZERO_ACC_XMM_4_REG(zmm0,zmm1,zmm2,zmm3) \ + zmm0 = _mm_setzero_si128 (); \ + zmm1 = _mm_setzero_si128 (); \ + zmm2 = _mm_setzero_si128 (); \ + zmm3 = _mm_setzero_si128 (); + +#define CVT_STORE_S32_S8_MASK(reg,mask,m_ind,n_ind) \ + _mm512_mask_cvtsepi32_storeu_epi8 \ + ( \ + ( int8_t* )post_ops_attr.buf_downscale + \ + ( post_ops_attr.rs_c_downscale * ( post_ops_attr.post_op_c_i + m_ind ) ) + \ + post_ops_attr.post_op_c_j + ( n_ind * 16 ), \ + mask, reg \ + ); \ + #endif // LPGEMM_S32_KERN_MACROS_H diff --git a/kernels/zen4/lpgemm/u8s8s32/lpgemm_s32_memcpy_macros.h b/kernels/zen4/lpgemm/u8s8s32/lpgemm_s32_memcpy_macros.h index fc5f0158b7..003e3dd996 100644 --- a/kernels/zen4/lpgemm/u8s8s32/lpgemm_s32_memcpy_macros.h +++ b/kernels/zen4/lpgemm/u8s8s32/lpgemm_s32_memcpy_macros.h @@ -4,19 +4,19 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/kernels/zen4/lpgemm/u8s8s32/lpgemm_s32_pack_macros.h b/kernels/zen4/lpgemm/u8s8s32/lpgemm_s32_pack_macros.h new file mode 100644 index 0000000000..6a51828eca --- /dev/null +++ b/kernels/zen4/lpgemm/u8s8s32/lpgemm_s32_pack_macros.h @@ -0,0 +1,150 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef LPGEMM_S32_PACK_MACROS_H +#define LPGEMM_S32_PACK_MACROS_H + +#include "../int4_utils_avx512.h" + +#define LOAD_16_COLS_AVX512 \ + a_reg[0] = _mm512_loadu_si512(b + (ldb * (jr + 0)) + kr); \ + a_reg[1] = _mm512_loadu_si512(b + (ldb * (jr + 1)) + kr); \ + a_reg[2] = _mm512_loadu_si512(b + (ldb * (jr + 2)) + kr); \ + a_reg[3] = _mm512_loadu_si512(b + (ldb * (jr + 3)) + kr); \ + a_reg[4] = _mm512_loadu_si512(b + (ldb * (jr + 4)) + kr); \ + a_reg[5] = _mm512_loadu_si512(b + (ldb * (jr + 5)) + kr); \ + a_reg[6] = _mm512_loadu_si512(b + (ldb * (jr + 6)) + kr); \ + a_reg[7] = _mm512_loadu_si512(b + (ldb * (jr + 7)) + kr); \ + a_reg[8] = _mm512_loadu_si512(b + (ldb * (jr + 8)) + kr); \ + a_reg[9] = _mm512_loadu_si512(b + (ldb * (jr + 9)) + kr); \ + a_reg[10] = _mm512_loadu_si512(b + (ldb * (jr + 10)) + kr); \ + a_reg[11] = _mm512_loadu_si512(b + (ldb * (jr + 11)) + kr); \ + a_reg[12] = _mm512_loadu_si512(b + (ldb * (jr + 12)) + kr); \ + a_reg[13] = _mm512_loadu_si512(b + (ldb * (jr + 13)) + kr); \ + a_reg[14] = _mm512_loadu_si512(b + (ldb * (jr + 14)) + kr); \ + a_reg[15] = _mm512_loadu_si512(b + (ldb * (jr + 15)) + kr); + +#define UNPACKHILO32_AVX512 \ + b_reg[0] = _mm512_unpacklo_epi32(a_reg[0], a_reg[1]); \ + b_reg[2] = _mm512_unpacklo_epi32(a_reg[2], a_reg[3]); \ + b_reg[4] = _mm512_unpacklo_epi32(a_reg[4], a_reg[5]); \ + b_reg[6] = _mm512_unpacklo_epi32(a_reg[6], a_reg[7]); \ + b_reg[8] = _mm512_unpacklo_epi32(a_reg[8], a_reg[9]); \ + b_reg[10] = _mm512_unpacklo_epi32(a_reg[10], a_reg[11]); \ + b_reg[12] = _mm512_unpacklo_epi32(a_reg[12], a_reg[13]); \ + b_reg[14] = _mm512_unpacklo_epi32(a_reg[14], a_reg[15]); \ + \ + b_reg[1] = _mm512_unpackhi_epi32(a_reg[0], a_reg[1]); \ + b_reg[3] = _mm512_unpackhi_epi32(a_reg[2], a_reg[3]); \ + b_reg[5] = _mm512_unpackhi_epi32(a_reg[4], a_reg[5]); \ + b_reg[7] = _mm512_unpackhi_epi32(a_reg[6], a_reg[7]); \ + b_reg[9] = _mm512_unpackhi_epi32(a_reg[8], a_reg[9]); \ + b_reg[11] = _mm512_unpackhi_epi32(a_reg[10], a_reg[11]); \ + b_reg[13] = _mm512_unpackhi_epi32(a_reg[12], a_reg[13]); \ + b_reg[15] = _mm512_unpackhi_epi32(a_reg[14], a_reg[15]); + +#define UNPACKHILO64_AVX512 \ + a_reg[0] = _mm512_unpacklo_epi64(b_reg[0], b_reg[2]); \ + a_reg[1] = _mm512_unpacklo_epi64(b_reg[4], b_reg[6]); \ + a_reg[2] = _mm512_unpacklo_epi64(b_reg[8], b_reg[10]); \ + a_reg[3] = _mm512_unpacklo_epi64(b_reg[12], b_reg[14]); \ + a_reg[4] = _mm512_unpacklo_epi64(b_reg[1], b_reg[3]); \ + a_reg[5] = _mm512_unpacklo_epi64(b_reg[5], b_reg[7]); \ + a_reg[6] = _mm512_unpacklo_epi64(b_reg[9], b_reg[11]); \ + a_reg[7] = _mm512_unpacklo_epi64(b_reg[13], b_reg[15]); \ + \ + a_reg[8] = _mm512_unpackhi_epi64(b_reg[0], b_reg[2]); \ + a_reg[9] = _mm512_unpackhi_epi64(b_reg[4], b_reg[6]); \ + a_reg[10] = _mm512_unpackhi_epi64(b_reg[8], b_reg[10]); \ + a_reg[11] = _mm512_unpackhi_epi64(b_reg[12], b_reg[14]); \ + a_reg[12] = _mm512_unpackhi_epi64(b_reg[1], b_reg[3]); \ + a_reg[13] = _mm512_unpackhi_epi64(b_reg[5], b_reg[7]); \ + a_reg[14] = _mm512_unpackhi_epi64(b_reg[9], b_reg[11]); \ + a_reg[15] = _mm512_unpackhi_epi64(b_reg[13], b_reg[15]); + +#define PERMUTEX2_VAR64_AVX512 \ + b_reg[0] = _mm512_permutex2var_epi64(a_reg[0], selector1, a_reg[1]); \ + b_reg[1] = _mm512_permutex2var_epi64(a_reg[2], selector1, a_reg[3]); \ + b_reg[2] = _mm512_permutex2var_epi64(a_reg[8], selector1, a_reg[9]); \ + b_reg[3] = _mm512_permutex2var_epi64(a_reg[10], selector1, a_reg[11]); \ + b_reg[4] = _mm512_permutex2var_epi64(a_reg[4], selector1, a_reg[5]); \ + b_reg[5] = _mm512_permutex2var_epi64(a_reg[6], selector1, a_reg[7]); \ + b_reg[6] = _mm512_permutex2var_epi64(a_reg[12], selector1, a_reg[13]); \ + b_reg[7] = _mm512_permutex2var_epi64(a_reg[14], selector1, a_reg[15]); \ + b_reg[8] = _mm512_permutex2var_epi64(a_reg[0], selector2, a_reg[1]); \ + b_reg[9] = _mm512_permutex2var_epi64(a_reg[2], selector2, a_reg[3]); \ + b_reg[10] = _mm512_permutex2var_epi64(a_reg[8], selector2, a_reg[9]); \ + b_reg[11] = _mm512_permutex2var_epi64(a_reg[10], selector2, a_reg[11]); \ + b_reg[12] = _mm512_permutex2var_epi64(a_reg[4], selector2, a_reg[5]); \ + b_reg[13] = _mm512_permutex2var_epi64(a_reg[6], selector2, a_reg[7]); \ + b_reg[14] = _mm512_permutex2var_epi64(a_reg[12], selector2, a_reg[13]); \ + b_reg[15] = _mm512_permutex2var_epi64(a_reg[14], selector2, a_reg[15]); + +#define SHUFFLE64x2_AVX512 \ + a_reg[0] = _mm512_shuffle_i64x2(b_reg[0], b_reg[1], 0x44); \ + a_reg[1] = _mm512_shuffle_i64x2(b_reg[2], b_reg[3], 0x44); \ + a_reg[2] = _mm512_shuffle_i64x2(b_reg[4], b_reg[5], 0x44); \ + a_reg[3] = _mm512_shuffle_i64x2(b_reg[6], b_reg[7], 0x44); \ + a_reg[4] = _mm512_shuffle_i64x2(b_reg[8], b_reg[9], 0x44); \ + a_reg[5] = _mm512_shuffle_i64x2(b_reg[10], b_reg[11], 0x44); \ + a_reg[6] = _mm512_shuffle_i64x2(b_reg[12], b_reg[13], 0x44); \ + a_reg[7] = _mm512_shuffle_i64x2(b_reg[14], b_reg[15], 0x44); \ + a_reg[8] = _mm512_shuffle_i64x2(b_reg[0], b_reg[1], 0xEE); \ + a_reg[9] = _mm512_shuffle_i64x2(b_reg[2], b_reg[3], 0xEE); \ + a_reg[10] = _mm512_shuffle_i64x2(b_reg[4], b_reg[5], 0xEE); \ + a_reg[11] = _mm512_shuffle_i64x2(b_reg[6], b_reg[7], 0xEE); \ + a_reg[12] = _mm512_shuffle_i64x2(b_reg[8], b_reg[9], 0xEE); \ + a_reg[13] = _mm512_shuffle_i64x2(b_reg[10], b_reg[11], 0xEE); \ + a_reg[14] = _mm512_shuffle_i64x2(b_reg[12], b_reg[13], 0xEE); \ + a_reg[15] = _mm512_shuffle_i64x2(b_reg[14], b_reg[15], 0xEE); + +#define MASK_LOAD_16_COLS_AVX512(mask) \ + a_reg[0] = _mm512_maskz_loadu_epi8(mask, b + (ldb * (jr + 0)) + kr); \ + a_reg[1] = _mm512_maskz_loadu_epi8(mask, b + (ldb * (jr + 1)) + kr); \ + a_reg[2] = _mm512_maskz_loadu_epi8(mask, b + (ldb * (jr + 2)) + kr); \ + a_reg[3] = _mm512_maskz_loadu_epi8(mask, b + (ldb * (jr + 3)) + kr); \ + a_reg[4] = _mm512_maskz_loadu_epi8(mask, b + (ldb * (jr + 4)) + kr); \ + a_reg[5] = _mm512_maskz_loadu_epi8(mask, b + (ldb * (jr + 5)) + kr); \ + a_reg[6] = _mm512_maskz_loadu_epi8(mask, b + (ldb * (jr + 6)) + kr); \ + a_reg[7] = _mm512_maskz_loadu_epi8(mask, b + (ldb * (jr + 7)) + kr); \ + a_reg[8] = _mm512_maskz_loadu_epi8(mask, b + (ldb * (jr + 8)) + kr); \ + a_reg[9] = _mm512_maskz_loadu_epi8(mask, b + (ldb * (jr + 9)) + kr); \ + a_reg[10] = _mm512_maskz_loadu_epi8(mask, b + (ldb * (jr + 10)) + kr); \ + a_reg[11] = _mm512_maskz_loadu_epi8(mask, b + (ldb * (jr + 11)) + kr); \ + a_reg[12] = _mm512_maskz_loadu_epi8(mask, b + (ldb * (jr + 12)) + kr); \ + a_reg[13] = _mm512_maskz_loadu_epi8(mask, b + (ldb * (jr + 13)) + kr); \ + a_reg[14] = _mm512_maskz_loadu_epi8(mask, b + (ldb * (jr + 14)) + kr); \ + a_reg[15] = _mm512_maskz_loadu_epi8(mask, b + (ldb * (jr + 15)) + kr); + +#endif //LPGEMM_S32_PACK_MACROS_H diff --git a/kernels/zen4/lpgemm/u8s8s32/lpgemv_m_kernel_amd512vnni.c b/kernels/zen4/lpgemm/u8s8s32/lpgemv_m_kernel_amd512vnni.c new file mode 100644 index 0000000000..d3405731a6 --- /dev/null +++ b/kernels/zen4/lpgemm/u8s8s32/lpgemv_m_kernel_amd512vnni.c @@ -0,0 +1,548 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ +#include "immintrin.h" +#include "xmmintrin.h" +#include "blis.h" + +#ifdef BLIS_ADDON_LPGEMM + +#include "lpgemm_s32_kern_macros.h" +#include "lpgemm_s32_memcpy_macros.h" + +LPGEMV_M_EQ1_KERN(uint8_t, int8_t, int32_t, u8s8s32os32) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_6x64_DISABLE, + &&POST_OPS_BIAS_6x64, + &&POST_OPS_RELU_6x64, + &&POST_OPS_RELU_SCALE_6x64, + &&POST_OPS_GELU_TANH_6x64, + &&POST_OPS_GELU_ERF_6x64, + &&POST_OPS_CLIP_6x64, + &&POST_OPS_DOWNSCALE_6x64, + &&POST_OPS_MATRIX_ADD_6x64, + &&POST_OPS_SWISH_6x64 + }; + + const uint8_t *a_use = NULL; + const int8_t *b_use = NULL; + + lpgemm_post_op_attr post_ops_attr = *(post_op_attr); + + for( dim_t jr = 0; jr < n0; jr += NR ) + { + NR = bli_min( 64, ( ( n0 - jr ) / 16 ) * 16 ); + + if( NR == 0 ) NR = 16; + + rs_b = NR * 4; + dim_t nr0 = bli_min( n0 - jr, NR ); + + int32_t* c_use = c + jr * cs_c; + + __mmask16 k1 = 0xFFFF, k2 = 0xFFFF, k3 = 0xFFFF, k4 = 0xFFFF; + __mmask32 k5 = 0xFFFFFFFF, k6 = 0xFFFFFFFF; + __mmask32 k7 = 0xFFFFFFFF, k8 = 0xFFFFFFFF; + + + if( nr0 == 64 ) + { + + } + if( nr0 == 48 ) + { + k4 = k8 = 0x0; + } + else if( nr0 == 32 ) + { + k3 = k4 = k7 = k8 = 0x0; + } + else if( nr0 == 16 ) + { + k2 = k3 = k4 = k6 = k7 = k8 = 0; + } + else if( nr0 < 16 ) + { + k1 = (0xFFFF >> (16 - (nr0 & 0x0F))); + k2 = k3 = k4 = k6 = k7 = k8 = 0; + } + + + + __m512i zmm0, zmm1, zmm2, zmm3, zmm4, zmm5, zmm6, zmm7; + __m512i zmm8, zmm9, zmm10, zmm11, zmm12, zmm13, zmm14; + __m512i zmm15, zmm16, zmm17, zmm18, zmm19, zmm20, zmm21; + __m512i zmm22, zmm23, zmm24, zmm25, zmm26, zmm27, zmm28; + __m512i zmm29, zmm30, zmm31; + + // zero the accumulator registers + ZERO_ACC_ZMM_4_REG(zmm8, zmm9, zmm10, zmm11); + ZERO_ACC_ZMM_4_REG(zmm12, zmm13, zmm14, zmm15); + ZERO_ACC_ZMM_4_REG(zmm16, zmm17, zmm18, zmm19); + ZERO_ACC_ZMM_4_REG(zmm20, zmm21, zmm22, zmm23); + + for (dim_t pc = 0; pc < k; pc += KC) + { + dim_t kc0 = bli_min((k - pc), KC); + + dim_t k_full_pieces = kc0 / 4; + dim_t k_partial_pieces = kc0 % 4; + + dim_t k_iter = kc0 / 16; + dim_t k_rem = k_full_pieces % 4; + + dim_t kc0_updated = kc0; + + if ( k_partial_pieces > 0 ) + { + kc0_updated += ( 4 - k_partial_pieces ); + } + + b_use = b + (n_sub_updated * pc) + + ( ( jc_cur_loop_rem + jr ) * kc0_updated ); + + a_use = a + pc; + + + for( dim_t kr = 0; kr < k_iter; kr++ ) + { + // load first 4x64 tile from row 0-3 + zmm0 = _mm512_maskz_loadu_epi16( k5, b_use ); + zmm1 = _mm512_maskz_loadu_epi16( k5, b_use + rs_b ); + zmm2 = _mm512_maskz_loadu_epi16( k5, b_use + 2 * rs_b ); + zmm3 = _mm512_maskz_loadu_epi16( k5, b_use + 3 * rs_b ); + b_use += 64; + + // Broadcast col0-col3 elements of A + zmm4 = _mm512_set1_epi32( *( int32_t* )( a_use ) ); + zmm5 = _mm512_set1_epi32( *( int32_t* )( a_use + cs_a ) ); + zmm6 = _mm512_set1_epi32( *( int32_t* )( a_use + cs_a * 2 ) ); + zmm7 = _mm512_set1_epi32( *( int32_t* )( a_use + cs_a * 3 ) ); + + // Load second 4x64 tile from row 0-3 + zmm24 = _mm512_maskz_loadu_epi16( k6, b_use ); + zmm25 = _mm512_maskz_loadu_epi16( k6, b_use + rs_b ); + zmm26 = _mm512_maskz_loadu_epi16( k6, b_use + 2 * rs_b ); + zmm27 = _mm512_maskz_loadu_epi16( k6, b_use + 3 * rs_b ); + b_use += 64; + + zmm8 = _mm512_dpbusd_epi32( zmm8, zmm4, zmm0 ); + zmm9 = _mm512_dpbusd_epi32( zmm9, zmm5, zmm1 ); + zmm10 = _mm512_dpbusd_epi32( zmm10, zmm6, zmm2 ); + zmm11 = _mm512_dpbusd_epi32( zmm11, zmm7, zmm3 ); + + // load third 4x64 tile from row 0-3 + zmm0 = _mm512_maskz_loadu_epi16( k7, b_use ); + zmm1 = _mm512_maskz_loadu_epi16( k7, b_use + rs_b ); + zmm2 = _mm512_maskz_loadu_epi16( k7, b_use + 2 * rs_b ); + zmm3 = _mm512_maskz_loadu_epi16( k7, b_use + 3 * rs_b ); + b_use += 64; + + zmm12 = _mm512_dpbusd_epi32( zmm12, zmm4, zmm24 ); + zmm13 = _mm512_dpbusd_epi32( zmm13, zmm5, zmm25 ); + zmm14 = _mm512_dpbusd_epi32( zmm14, zmm6, zmm26 ); + zmm15 = _mm512_dpbusd_epi32( zmm15, zmm7, zmm27 ); + + // load third 4x64 tile from row 0-3 + zmm28 = _mm512_maskz_loadu_epi16( k8, b_use ); + zmm29 = _mm512_maskz_loadu_epi16( k8, b_use + rs_b ); + zmm30 = _mm512_maskz_loadu_epi16( k8, b_use + 2 * rs_b ); + zmm31 = _mm512_maskz_loadu_epi16( k8, b_use + 3 * rs_b ); + + zmm16 = _mm512_dpbusd_epi32( zmm16, zmm4, zmm0 ); + zmm17 = _mm512_dpbusd_epi32( zmm17, zmm5, zmm1 ); + zmm18 = _mm512_dpbusd_epi32( zmm18, zmm6, zmm2 ); + zmm19 = _mm512_dpbusd_epi32( zmm19, zmm7, zmm3 ); + + zmm20 = _mm512_dpbusd_epi32( zmm20, zmm4, zmm28 ); + zmm21 = _mm512_dpbusd_epi32( zmm21, zmm5, zmm29 ); + zmm22 = _mm512_dpbusd_epi32( zmm22, zmm6, zmm30 ); + zmm23 = _mm512_dpbusd_epi32( zmm23, zmm7, zmm31 ); + + b_use -= 192; // move b point back to start of KCXNR + b_use += ( 4 * rs_b ); + a_use += 4 * cs_a; // move a pointer to next col + } + for( dim_t kr = 0; kr < k_rem; kr++ ) + { + // load first 4x64 tile from row 0-3 + zmm0 = _mm512_maskz_loadu_epi16( k5, b_use ); + zmm1 = _mm512_maskz_loadu_epi16( k6, b_use + cs_b ); + zmm2 = _mm512_maskz_loadu_epi16( k7, b_use + 2 * cs_b ); + zmm3 = _mm512_maskz_loadu_epi16( k8, b_use + 3 * cs_b ); + + // Broadcast col0 elements of A + zmm4 = _mm512_set1_epi32( *( int32_t* )( a_use ) ); + + zmm8 = _mm512_dpbusd_epi32( zmm8, zmm4, zmm0 ); + zmm12 = _mm512_dpbusd_epi32( zmm12, zmm4, zmm1 ); + zmm16 = _mm512_dpbusd_epi32( zmm16, zmm4, zmm2 ); + zmm20 = _mm512_dpbusd_epi32( zmm20, zmm4, zmm3 ); + + b_use += rs_b; + a_use += cs_a; // move a pointer to next col + } + if( k_partial_pieces > 0 ) + { + __m128i a_kfringe_buf; + __mmask16 load_mask = + _cvtu32_mask16( 0xFFFF >> ( 16 - k_partial_pieces ) ); + + // load first 4x64 tile from row 0-3 + zmm0 = _mm512_maskz_loadu_epi16( k5, b_use ); + + // Broadcast a[0,kr:kr+4]. + a_kfringe_buf = _mm_maskz_loadu_epi8( load_mask, a_use ); + zmm4 = _mm512_broadcastd_epi32( a_kfringe_buf ); + + zmm1 = _mm512_maskz_loadu_epi16( k6, b_use + cs_b ); + zmm2 = _mm512_maskz_loadu_epi16( k7, b_use + 2 * cs_b ); + zmm3 = _mm512_maskz_loadu_epi16( k8, b_use + 3 * cs_b ); + + zmm8 = _mm512_dpbusd_epi32( zmm8, zmm4, zmm0 ); + zmm12 = _mm512_dpbusd_epi32( zmm12, zmm4, zmm1 ); + zmm16 = _mm512_dpbusd_epi32( zmm16, zmm4, zmm2 ); + zmm20 = _mm512_dpbusd_epi32( zmm20, zmm4, zmm3 ); + + } + + } + + // Sumup k-unroll outputs + zmm8 = _mm512_add_epi32( zmm9, zmm8 ); + zmm10 = _mm512_add_epi32(zmm11, zmm10); + zmm8 = _mm512_add_epi32(zmm10, zmm8); // 64 outputs + + zmm12 = _mm512_add_epi32(zmm13, zmm12); + zmm14 = _mm512_add_epi32(zmm15, zmm14); + zmm12 = _mm512_add_epi32(zmm14, zmm12); // 64 outputs + + zmm16 = _mm512_add_epi32(zmm17, zmm16); + zmm18 = _mm512_add_epi32(zmm19, zmm18); + zmm16 = _mm512_add_epi32(zmm18, zmm16); // 64 outputs + + zmm20 = _mm512_add_epi32(zmm21, zmm20); + zmm22 = _mm512_add_epi32(zmm23, zmm22); + zmm20 = _mm512_add_epi32(zmm22, zmm20); // 64 outputs + + + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + __m512i selector3 = _mm512_setzero_epi32(); + __m512i selector4 = _mm512_setzero_epi32(); + + //Mulitply A*B output with alpha + zmm8 = _mm512_mullo_epi32(selector1, zmm8); + zmm12 = _mm512_mullo_epi32(selector1, zmm12); + zmm16 = _mm512_mullo_epi32(selector1, zmm16); + zmm20 = _mm512_mullo_epi32(selector1, zmm20); + + if (beta != 0) + { + // For the downscaled api (C-s8), the output C matrix values + // needs to be upscaled to s32 to be used for beta scale. + if ( post_ops_attr.buf_downscale != NULL ) + { + S8_S32_BETA_OP_NLT16F_MASK( k1, zmm8, 0, 0, + selector1, selector2 ) + S8_S32_BETA_OP_NLT16F_MASK( k2, zmm12, 0, 1, + selector1, selector2 ) + S8_S32_BETA_OP_NLT16F_MASK( k3, zmm16, 0, 2, + selector1, selector2 ) + S8_S32_BETA_OP_NLT16F_MASK( k4, zmm20, 0, 3, + selector1, selector2 ) + } + else + { + S32_S32_BETA_OP_NLT16F_MASK( c_use, k1, zmm8, 0, 0, 0, + selector1, selector2 ) + S32_S32_BETA_OP_NLT16F_MASK( c_use, k2, zmm12, 0, 0, 1, + selector1, selector2 ) + S32_S32_BETA_OP_NLT16F_MASK( c_use, k3, zmm16, 0, 0, 2, + selector1, selector2 ) + S32_S32_BETA_OP_NLT16F_MASK( c_use, k4, zmm20, 0, 0, 3, + selector1, selector2 ) + } + } + + post_ops_attr.is_last_k = TRUE; + lpgemm_post_op *post_ops_list_temp = post_op; + POST_OP_LABEL_LASTK_SAFE_JUMP + + POST_OPS_BIAS_6x64: + { + selector1 = + _mm512_loadu_si512( ( int32_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_si512( ( int32_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_si512( ( int32_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + selector4 = + _mm512_loadu_si512( ( int32_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + + zmm8 = _mm512_add_epi32( selector1, zmm8 ); + zmm12 = _mm512_add_epi32( selector2, zmm12 ); + zmm16 = _mm512_add_epi32( selector3, zmm16 ); + zmm20 = _mm512_add_epi32( selector4, zmm20 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_RELU_6x64: + { + selector1 = _mm512_setzero_epi32(); + + zmm8 = _mm512_max_epi32( selector1, zmm8 ); + zmm12 = _mm512_max_epi32( selector1, zmm12 ); + zmm16 = _mm512_max_epi32( selector1, zmm16 ); + zmm20 = _mm512_max_epi32( selector1, zmm20 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_RELU_SCALE_6x64: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( (int32_t*)post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + RELU_SCALE_OP_S32_AVX512( zmm8 ) + RELU_SCALE_OP_S32_AVX512( zmm12 ) + RELU_SCALE_OP_S32_AVX512( zmm16 ) + RELU_SCALE_OP_S32_AVX512( zmm20 ) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_GELU_TANH_6x64: + { + __m512 dn, z, x, r2, r, y, x_tanh; + + GELU_TANH_S32_AVX512( zmm8, y, r, r2, x, + z, dn, x_tanh, selector1 ) + GELU_TANH_S32_AVX512( zmm12, y, r, r2, x, + z, dn, x_tanh, selector1 ) + GELU_TANH_S32_AVX512( zmm16, y, r, r2, x, + z, dn, x_tanh, selector1 ) + GELU_TANH_S32_AVX512( zmm20, y, r, r2, x, + z, dn, x_tanh, selector1 ) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_GELU_ERF_6x64: + { + __m512 x, r, y, x_erf; + + GELU_ERF_S32_AVX512( zmm8, y, r, x, x_erf ) + GELU_ERF_S32_AVX512( zmm12, y, r, x, x_erf ) + GELU_ERF_S32_AVX512( zmm16, y, r, x, x_erf ) + GELU_ERF_S32_AVX512( zmm20, y, r, x, x_erf ) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + + } + POST_OPS_CLIP_6x64: + { + __m512i min = _mm512_set1_epi32( + *( int32_t* )post_ops_list_temp->op_args2 ); + __m512i max = _mm512_set1_epi32( + *( int32_t* )post_ops_list_temp->op_args3 ); + + CLIP_S32_AVX512( zmm8, min, max ) + CLIP_S32_AVX512( zmm12, min, max ) + CLIP_S32_AVX512( zmm16, min, max ) + CLIP_S32_AVX512( zmm20, min, max ) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_DOWNSCALE_6x64: + { + if ( post_ops_list_temp->scale_factor_len > 1 ) + { + selector1 = + _mm512_loadu_si512( (float*)post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_si512( (float*)post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_si512( (float*)post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + selector4 = + _mm512_loadu_si512( (float*)post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + } + else if ( post_ops_list_temp->scale_factor_len == 1 ) + { + selector1 = ( __m512i )_mm512_set1_ps( + *( ( float* )post_ops_list_temp->scale_factor ) ); + selector2 = ( __m512i )_mm512_set1_ps( + *( ( float* )post_ops_list_temp->scale_factor ) ); + selector3 = ( __m512i )_mm512_set1_ps( + *( ( float* )post_ops_list_temp->scale_factor ) ); + selector4 = ( __m512i )_mm512_set1_ps( + *( ( float* )post_ops_list_temp->scale_factor ) ); + } + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point0 = _mm512_castsi512_si128( + _mm512_setzero_si512() ); + __m128i zero_point1 = _mm512_castsi512_si128( + _mm512_setzero_si512() ); + __m128i zero_point2 = _mm512_castsi512_si128( + _mm512_setzero_si512() ); + __m128i zero_point3 = _mm512_castsi512_si128( + _mm512_setzero_si512() ); + + // int8_t zero point value. + if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) + { + zero_point0 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + zero_point1 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + zero_point2 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + zero_point3 = _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ) ); + } + else if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) + { + zero_point0 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point1 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point2 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + zero_point3 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + } + + CVT_MULRND_CVT32(zmm8, selector1, zero_point0 ); + CVT_MULRND_CVT32(zmm12, selector2, zero_point1 ); + CVT_MULRND_CVT32(zmm16, selector3, zero_point2 ); + CVT_MULRND_CVT32(zmm20, selector4, zero_point3 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_MATRIX_ADD_6x64: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + S8_S32_MATRIX_ADD_LOAD( _cvtu32_mask16( 0xFFFF ), + selector1, 0, 0 ); + S8_S32_MATRIX_ADD_LOAD( _cvtu32_mask16( 0xFFFF ), + selector2, 0, 1 ); + S8_S32_MATRIX_ADD_LOAD( _cvtu32_mask16( 0xFFFF ), + selector3, 0, 2 ); + S8_S32_MATRIX_ADD_LOAD( _cvtu32_mask16( 0xFFFF ), + selector4, 0, 3 ); + + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + S32_S32_MATRIX_ADD_LOAD( _cvtu32_mask16( 0xFFFF ), + selector1, 0, 0 ); + S32_S32_MATRIX_ADD_LOAD( _cvtu32_mask16( 0xFFFF ), + selector2, 0, 1 ); + S32_S32_MATRIX_ADD_LOAD( _cvtu32_mask16( 0xFFFF ), + selector3, 0, 2 ); + S32_S32_MATRIX_ADD_LOAD( _cvtu32_mask16( 0xFFFF ), + selector4, 0, 3 ); + } + + zmm8 = _mm512_add_epi32( selector1, zmm8 ); + zmm12 = _mm512_add_epi32( selector2, zmm12 ); + zmm16 = _mm512_add_epi32( selector3, zmm16 ); + zmm20 = _mm512_add_epi32( selector4, zmm20 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + + POST_OPS_SWISH_6x64: + { + selector1 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + SWISH_S32_AVX512( zmm8, fl_reg, al, al_in, + r, r2, z, dn, selector2 ); + SWISH_S32_AVX512( zmm12, fl_reg, al, al_in, + r, r2, z, dn, selector2 ); + SWISH_S32_AVX512( zmm16, fl_reg, al, al_in, + r, r2, z, dn, selector2 ); + SWISH_S32_AVX512( zmm20, fl_reg, al, al_in, + r, r2, z, dn, selector2 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_6x64_DISABLE: + { + if ( post_ops_attr.buf_downscale != NULL ) + { + CVT_STORE_S32_S8_MASK( zmm8, k1, 0, 0 ); + CVT_STORE_S32_S8_MASK( zmm12, k2, 0, 1 ); + CVT_STORE_S32_S8_MASK( zmm16, k3, 0, 2 ); + CVT_STORE_S32_S8_MASK( zmm20, k4, 0, 3 ); + } + else + { + _mm512_mask_storeu_epi32( c_use + ( 0*16 ), k1, zmm8 ); + _mm512_mask_storeu_epi32( c_use + ( 1*16 ), k2, zmm12 ); + _mm512_mask_storeu_epi32( c_use + ( 2*16 ), k3, zmm16 ); + _mm512_mask_storeu_epi32( c_use + ( 3*16 ), k4, zmm20 ); + } + } + + post_ops_attr.post_op_c_j += nr0; + + } // jr loop +} +#endif // BLIS_ADDON_LPGEMM diff --git a/kernels/zen4/lpgemm/u8s8s32/lpgemv_n_kernel_amd512vnni.c b/kernels/zen4/lpgemm/u8s8s32/lpgemv_n_kernel_amd512vnni.c new file mode 100644 index 0000000000..3406d79745 --- /dev/null +++ b/kernels/zen4/lpgemm/u8s8s32/lpgemv_n_kernel_amd512vnni.c @@ -0,0 +1,726 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "immintrin.h" +#include "xmmintrin.h" +#include "blis.h" + +#ifdef BLIS_ADDON_LPGEMM + +#include "lpgemm_s32_kern_macros.h" +#include "lpgemm_s32_memcpy_macros.h" + + +#define LPGEMV_N_KERNEL_4_LOADS( zmm0, zmm1, zmm2, zmm3, paddr, stride ) \ + zmm0 = _mm512_loadu_si512( paddr ); \ + zmm1 = _mm512_loadu_si512( paddr + stride ); \ + zmm2 = _mm512_loadu_si512( paddr + 2 * stride ); \ + zmm3 = _mm512_loadu_si512( paddr + 3 * stride ); + +#define LPGEMV_N_KERNEL_4_MASKLOADS( zmm0, zmm1, zmm2, \ + zmm3, k1, paddr, stride ) \ + zmm0 = _mm512_maskz_loadu_epi8( k1, paddr ); \ + zmm1 = _mm512_maskz_loadu_epi8( k1, paddr + stride ); \ + zmm2 = _mm512_maskz_loadu_epi8( k1, paddr + 2 * stride ); \ + zmm3 = _mm512_maskz_loadu_epi8( k1, paddr + 3 * stride ); + +#define LPGEMV_N_KERNEL_4_FMA( zmm8, zmm9, zmm10, zmm11, \ + zmm6, zmm0, zmm1, zmm2, zmm3 ) \ + zmm8 = _mm512_dpbusd_epi32( zmm8, zmm0, zmm6 ); \ + zmm9 = _mm512_dpbusd_epi32( zmm9, zmm1, zmm6 ); \ + zmm10 = _mm512_dpbusd_epi32( zmm10, zmm2, zmm6 ); \ + zmm11 = _mm512_dpbusd_epi32( zmm11, zmm3, zmm6 ); + +#define LPGEMV_ZMM2XMM( zmm0, zmm1, zmm2, zmm3, \ + ymm0, ymm1, ymm2, ymm3, xmm0) \ + ymm0 = _mm256_add_epi32 (_mm512_extracti32x8_epi32 (zmm0, 0x0), \ + _mm512_extracti32x8_epi32 (zmm0, 0x1)); \ + ymm1 = _mm256_add_epi32 (_mm512_extracti32x8_epi32 (zmm1, 0x0), \ + _mm512_extracti32x8_epi32 (zmm1, 0x1)); \ + ymm0 = _mm256_hadd_epi32 (ymm0, ymm1); \ + ymm2 = _mm256_add_epi32 (_mm512_extracti32x8_epi32 (zmm2, 0x0), \ + _mm512_extracti32x8_epi32 (zmm2, 0x1)); \ + ymm3 = _mm256_add_epi32 (_mm512_extracti32x8_epi32 (zmm3, 0x0), \ + _mm512_extracti32x8_epi32 (zmm3, 0x1)); \ + ymm1 = _mm256_hadd_epi32 (ymm2, ymm3); \ + ymm0 = _mm256_hadd_epi32 (ymm0, ymm1); \ + xmm0 = _mm_add_epi32 ( _mm256_extracti128_si256 (ymm0, 0), \ + _mm256_extracti128_si256 (ymm0,1)); + +#define CVT_STORE_S32_S8_MASK(reg,mask,m_ind,n_ind) \ + _mm512_mask_cvtsepi32_storeu_epi8 \ + ( \ + ( int8_t* )post_ops_attr.buf_downscale + \ + ( post_ops_attr.rs_c_downscale * \ + ( post_ops_attr.post_op_c_i + m_ind ) ) + \ + post_ops_attr.post_op_c_j + ( n_ind * 16 ), \ + mask, reg \ + ); \ + +LPGEMV_N_EQ1_KERN(uint8_t, int8_t, int32_t, u8s8s32os32) +{ + static void* post_ops_labels[] = + { + &&POST_OPS_6x64_DISABLE, + &&POST_OPS_BIAS_6x64, + &&POST_OPS_RELU_6x64, + &&POST_OPS_RELU_SCALE_6x64, + &&POST_OPS_GELU_TANH_6x64, + &&POST_OPS_GELU_ERF_6x64, + &&POST_OPS_CLIP_6x64, + &&POST_OPS_DOWNSCALE_6x64, + &&POST_OPS_MATRIX_ADD_6x64, + &&POST_OPS_SWISH_6x64 + }; + + const uint8_t *a_use = NULL; + const int8_t *b_use = NULL; + int32_t *c_use = NULL; + + lpgemm_post_op_attr post_ops_attr = *(post_op_attr); + + for ( dim_t ir = 0; ir < m0; ir += MR ) + { + dim_t mr0 = bli_min( ( m0 - ir ), MR ); + dim_t k_iter = k/64; + dim_t k_rem = k & 0x3F; + + //Create load mask for k fringe + __mmask64 k1 = 0xFFFFFFFFFFFFFFFF; + if( k_rem ) + { + k1 = ( k1 >> ( 64 - k_rem ) ); + } + + // Create store mask for C for mr fringe + __mmask16 k2 = 0xFFFF; + if ( mr0 < MR ) + { + k2 = ( 0xFFFF >> ( MR - mr0 ) ); + } + + __m512i zmm0, zmm1, zmm2, zmm3, zmm6; + __m512i zmm8, zmm9, zmm10, zmm11, zmm12, zmm13, zmm14; + __m512i zmm15, zmm16, zmm17, zmm18, zmm19, zmm20, zmm21; + __m512i zmm22, zmm23, zmm24, zmm25, zmm26, zmm27, zmm28; + __m512i zmm29, zmm30, zmm31; + + __m256i ymm0,ymm1,ymm2,ymm3,ymm4,ymm5,ymm6; + __m128i xmm0, xmm1, xmm2, xmm3; + + /* zero the accumulator registers */ + ZERO_ACC_ZMM_4_REG( zmm8, zmm9, zmm10, zmm11 ) + ZERO_ACC_ZMM_4_REG( zmm12, zmm13, zmm14, zmm15 ) + ZERO_ACC_ZMM_4_REG( zmm16, zmm17, zmm18, zmm19 ) + ZERO_ACC_ZMM_4_REG( zmm20, zmm21, zmm22, zmm23 ) + ZERO_ACC_XMM_4_REG( xmm0, xmm1, xmm2, xmm3 ) + + //update pointers + a_use = a + ir * rs_a; + b_use = b; + c_use = c + ir * rs_c; + + if( mr0 == MR ) + { + //Dot product kernel + for (dim_t k = 0; k < k_iter; k++) + { + zmm6 = _mm512_loadu_si512( b_use ); + b_use += 64; + + //Load 4x64 elements from row0-row3 of A + LPGEMV_N_KERNEL_4_LOADS( zmm0, zmm1, zmm2, zmm3, a_use, rs_a ) + a_use += ( 4 * rs_a ); + + // Load 4x64 elements from row3-row7 of A + LPGEMV_N_KERNEL_4_LOADS( zmm24, zmm25, zmm26, + zmm27, a_use, rs_a + ) + a_use += ( 4 * rs_a ); + + LPGEMV_N_KERNEL_4_FMA( zmm8, zmm9, zmm10, zmm11, + zmm6, zmm0, zmm1, zmm2, zmm3 + ) + + // Load 4x64 elements from row8-row11 of A + LPGEMV_N_KERNEL_4_LOADS( zmm28, zmm29, zmm30, + zmm31, a_use, rs_a + ) + a_use += ( 4 * rs_a ); + + // Load 4x64 elements from row12-row15 of A + LPGEMV_N_KERNEL_4_LOADS( zmm0, zmm1, zmm2, zmm3, a_use, rs_a ) + a_use -= ( 12 * rs_a ); //Update aptr back to move horizontally + + LPGEMV_N_KERNEL_4_FMA( zmm12, zmm13, zmm14, zmm15, + zmm6, zmm24, zmm25, zmm26, zmm27 + ) + LPGEMV_N_KERNEL_4_FMA( zmm16, zmm17, zmm18, zmm19, + zmm6, zmm28, zmm29, zmm30, zmm31 + ) + LPGEMV_N_KERNEL_4_FMA( zmm20, zmm21, zmm22, zmm23, + zmm6, zmm0, zmm1, zmm2, zmm3 + ) + a_use += 64; + + } // kloop + if( k_rem ) + { + zmm6 = _mm512_maskz_loadu_epi8( k1, b_use ); + + //Load 4x64 elements from row0-row3 of A + LPGEMV_N_KERNEL_4_MASKLOADS( zmm0, zmm1, zmm2, + zmm3, k1, a_use, rs_a + ) + a_use += ( 4 * rs_a ); + + // Load 4x64 elements from row3-row7 of A + LPGEMV_N_KERNEL_4_MASKLOADS( zmm24, zmm25, zmm26, + zmm27, k1, a_use, rs_a + ) + a_use += ( 4 * rs_a ); + + LPGEMV_N_KERNEL_4_FMA( zmm8, zmm9, zmm10, zmm11, + zmm6, zmm0, zmm1, zmm2, zmm3 + ) + + // Load 4x64 elements from row8-row11 of A + LPGEMV_N_KERNEL_4_MASKLOADS( zmm28, zmm29, zmm30, + zmm31, k1, a_use, rs_a + ) + a_use += ( 4 * rs_a ); + + // Load 4x64 elements from row12-row15 of A + LPGEMV_N_KERNEL_4_MASKLOADS( zmm0, zmm1, zmm2, + zmm3, k1, a_use, rs_a + ) + a_use -= ( 12 * rs_a ); //Update aptr back to move horizontally + + + LPGEMV_N_KERNEL_4_FMA( zmm12, zmm13, zmm14, zmm15, + zmm6, zmm24, zmm25, zmm26, zmm27 + ) + LPGEMV_N_KERNEL_4_FMA( zmm16, zmm17, zmm18, zmm19, + zmm6, zmm28, zmm29, zmm30, zmm31 + ) + LPGEMV_N_KERNEL_4_FMA( zmm20, zmm21, zmm22, zmm23, + zmm6, zmm0, zmm1, zmm2, zmm3 + ) + a_use += 64; + } + + //Add the registers horizantally to get one + LPGEMV_ZMM2XMM( zmm8, zmm9, zmm10, zmm11, + ymm0, ymm1, ymm2, ymm3, xmm0 + ) + LPGEMV_ZMM2XMM( zmm12, zmm13, zmm14, zmm15, + ymm4, ymm1, ymm2, ymm3, xmm1 + ) + LPGEMV_ZMM2XMM( zmm16, zmm17, zmm18, zmm19, + ymm5, ymm1, ymm2, ymm3, xmm2 + ) + LPGEMV_ZMM2XMM( zmm20, zmm21, zmm22, zmm23, + ymm6, ymm1, ymm2, ymm3, xmm3 + ) + + //compose outputs into one zmm to perform post-ops + zmm8 = _mm512_inserti32x4 ( zmm8, xmm0, 0 ); + zmm8 = _mm512_inserti32x4 ( zmm8, xmm1, 1 ); + zmm8 = _mm512_inserti32x4 ( zmm8, xmm2, 2 ); + zmm8 = _mm512_inserti32x4 ( zmm8, xmm3, 3 ); + } + else + { + //Handle fringe cases when mr0 < MR + const uint8_t *a_use_fringe = a_use; + dim_t mr0_use = mr0; + dim_t regidx = 0; + + // Dot product for mfringe 8 + if ( mr0_use >= 8 ) + { + // Dot product kernel for mr0 == 8 + for( dim_t k = 0; k < k_iter; k++ ) + { + // Load 0-63 in b[k+0 - k+31] + zmm6 = _mm512_loadu_si512( b_use ); + // move b pointer to next 64 elements + b_use += 64; + + // Load 4x64 elements from row0-row3 of A + LPGEMV_N_KERNEL_4_LOADS( zmm0, zmm1, zmm2, + zmm3, a_use, rs_a + ) + a_use += ( 4 * rs_a ); + + // Load 4x64 elements from row3-row7 of A + LPGEMV_N_KERNEL_4_LOADS( zmm24, zmm25, zmm26, + zmm27, a_use, rs_a + ) + a_use -= ( 4 * rs_a ); + + //Perform FMA on two 4x64 block of A with 64x1 + LPGEMV_N_KERNEL_4_FMA( zmm8, zmm9, zmm10, zmm11, + zmm6, zmm0, zmm1, zmm2, zmm3 + ) + LPGEMV_N_KERNEL_4_FMA( zmm12, zmm13, zmm14, zmm15, + zmm6, zmm24, zmm25, zmm26, zmm27 + ) + a_use += 64; + } + + if ( k_rem ) + { + // Load 0-63 in b[k+0 - k+63] + zmm6 = _mm512_maskz_loadu_epi8( k1, b_use ); + + // Load 4x64 elements from row0-row3 of A + LPGEMV_N_KERNEL_4_MASKLOADS( zmm0, zmm1, zmm2, + zmm3, k1, a_use, rs_a + ) + a_use += ( 4 * rs_a ); + LPGEMV_N_KERNEL_4_MASKLOADS( zmm24, zmm25, zmm26, + zmm27, k1, a_use, rs_a + ) + LPGEMV_N_KERNEL_4_FMA( zmm8, zmm9, zmm10, zmm11, + zmm6, zmm0, zmm1, zmm2, zmm3 + ) + LPGEMV_N_KERNEL_4_FMA( zmm12, zmm13, zmm14, zmm15, + zmm6, zmm24, zmm25, zmm26, zmm27 + ) + } + + // update pointers + mr0_use -= 8; + a_use = a_use_fringe + 8 * rs_a; + a_use_fringe = a_use; + b_use = b; + + // Horizontal add 8 zmm registers + // and get output into 2 xmm registers + LPGEMV_ZMM2XMM( zmm8, zmm9, zmm10, zmm11, + ymm0, ymm1, ymm2, ymm3, xmm0 + ) + LPGEMV_ZMM2XMM( zmm12, zmm13, zmm14, zmm15, + ymm4, ymm1, ymm2, ymm3, xmm1 + ) + + //insert xmm outputs into final output zmm8 reg + zmm8 = _mm512_inserti32x4( zmm8, xmm0, 0 ); + zmm8 = _mm512_inserti32x4( zmm8, xmm1, 1 ); + regidx = 2; + } + + // Dot product for mfringe 4 + if ( mr0_use >= 4 ) + { + // Dot product kernel for mr0 == 8 + for ( dim_t k = 0; k < k_iter; k++ ) + { + // Load 0-63 in b[k+0 - k+63] + zmm6 = _mm512_loadu_si512( b_use ); + + // move b pointer to next 64 elements + b_use += 64; + + // Load 4x64 elements from row0-row3 of A + LPGEMV_N_KERNEL_4_LOADS( zmm0, zmm1, zmm2, + zmm3, a_use, rs_a + ) + // Perform FMA on 4x64 block of A with 64x1 + LPGEMV_N_KERNEL_4_FMA( zmm16, zmm17, zmm18, zmm19, + zmm6, zmm0, zmm1, zmm2, zmm3 + ) + a_use += 64; + } + + if ( k_rem ) + { + // Load 0-63 in b[k+0 - k+63] + zmm6 = _mm512_maskz_loadu_epi8( k1, b_use ); + + // Load 4x64 elements from row0-row3 of A + LPGEMV_N_KERNEL_4_MASKLOADS( zmm0, zmm1, zmm2, + zmm3, k1, a_use, rs_a + ) + LPGEMV_N_KERNEL_4_FMA( zmm16, zmm17, zmm18, zmm19, + zmm6, zmm0, zmm1, zmm2, zmm3 + ) + } + + //update pointers + mr0_use -= 4; + a_use = a_use_fringe + 4 * rs_a; + a_use_fringe = a_use; + b_use = b; + + //Horizontal add 4 zmm reg and get the output into one xmm + LPGEMV_ZMM2XMM( zmm16, zmm17, zmm18, zmm19, + ymm5, ymm1, ymm2, ymm3, xmm2 + ) + + //insert xmm outputs into final output zmm8 reg based on regidx + if( regidx == 0 ) zmm8 = _mm512_inserti32x4( zmm8, xmm2, 0 ); + else zmm8 = _mm512_inserti32x4( zmm8, xmm2, 2 ); + regidx++; + } + + // Dot product for <= 3 + if ( mr0_use ) + { + // Dot product for m = 2 + if ( mr0_use >= 2 ) + { + for ( dim_t k = 0; k < k_iter; k++ ) + { + // Load 0-63 in b[k+0 - k+63] + zmm6 = _mm512_loadu_si512( b_use ); + + // Load 2x64 elements from row0-row1 of A + zmm0 = _mm512_loadu_si512( a_use ); + zmm1 = _mm512_loadu_si512( a_use + rs_a ); + zmm20 = _mm512_dpbusd_epi32( zmm20, zmm0, zmm6 ); + zmm21 = _mm512_dpbusd_epi32( zmm21, zmm1, zmm6 ); + + b_use += 64; // move b pointer to next 64 elements + a_use += 64; + } + if ( k_rem ) + { + // Load 0-63 in b[k+0 - k+63] + zmm6 = _mm512_maskz_loadu_epi8( k1, b_use ); + zmm0 = _mm512_maskz_loadu_epi8( k1, a_use ); + zmm1 = _mm512_maskz_loadu_epi8( k1, a_use + rs_a ); + zmm20 = _mm512_dpbusd_epi32( zmm20, zmm0, zmm6 ); + zmm21 = _mm512_dpbusd_epi32( zmm21, zmm1, zmm6 ); + } + mr0_use -= 2; + a_use = a_use_fringe + 2 * rs_a; + a_use_fringe = a_use; + b_use = b; + } + + // Dot product for m = 2 + if ( mr0_use == 1 ) + { + for ( dim_t k = 0; k < k_iter; k++ ) + { + // Load 0-63 in b[k+0 - k+63] + zmm6 = _mm512_loadu_si512( b_use ); + zmm0 = _mm512_loadu_si512( a_use ); + zmm22 = _mm512_dpbusd_epi32( zmm22, zmm0, zmm6 ); + b_use += 64; // move b pointer to next 64 elements + a_use += 64; + } + + if ( k_rem ) + { + zmm6 = _mm512_maskz_loadu_epi8( k1, b_use ); + zmm0 = _mm512_maskz_loadu_epi8( k1, a_use ); + zmm22 = _mm512_dpbusd_epi32( zmm22, zmm0, zmm6 ); + } + // When only fringe 1, + // update the registers to store in order + if ( !( mr0 & 0x2 ) ) zmm20 = zmm22; + } + + // Horizontal add 4 zmm reg and get the output into one xmm + LPGEMV_ZMM2XMM( zmm20, zmm21, zmm22, zmm23, + ymm6, ymm1, ymm2, ymm3, xmm3 + ) + + // insert xmm outputs into final output zmm8 reg based on regidx + if( regidx == 0 ) + { + zmm8 = _mm512_inserti32x4( zmm8, xmm3, 0 ); + } + else if( regidx == 1 ) + { + zmm8 = _mm512_inserti32x4( zmm8, xmm3, 1 ); + } + else if ( regidx == 2 ) + { + zmm8 = _mm512_inserti32x4( zmm8, xmm3, 2 ); + } + else + { + zmm8 = _mm512_inserti32x4( zmm8, xmm3, 3 ); + } + } + } + + //Scale accumulated output with alpha + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + //Mulitply A*B output with alpha + zmm8 = _mm512_mullo_epi32( selector1, zmm8 ); + + if( beta != 0 ) + { + if( post_ops_attr.buf_downscale != NULL ) + { + if( post_ops_attr.rs_c_downscale == 1 ) + { + S8_S32_BETA_OP_NLT16F_MASK( k2, zmm8, 0, 0, + selector1, selector2 ) + } + else + { + int8_t ctemp[16]; + for( dim_t i = 0; i < mr0; i++ ) + { + ctemp[i] = *( ( int8_t* )post_ops_attr.buf_downscale + + ( post_ops_attr.rs_c_downscale * + ( post_ops_attr.post_op_c_i + i ) ) ); + } + selector1 = _mm512_cvtepi8_epi32 + ( _mm_maskz_loadu_epi8( 0xFFFF, ctemp ) ); + S32_BETA_FMA( zmm8, selector1, selector2 ); + } + } + else + { + if( rs_c == 1) + { + S32_S32_BETA_OP_NLT16F_MASK( c_use, k2, zmm8, 0, 0, 0, + selector1, selector2 ) + } + else + { + int32_t ctemp[16]; + for( dim_t i = 0; i < mr0; i++ ) + { + ctemp[i] = c_use[ i * rs_c ]; + } + selector1 = _mm512_loadu_epi32( ctemp ); + S32_BETA_FMA( zmm8, selector1, selector2 ); + } + } + } + + // Post Ops + lpgemm_post_op *post_ops_list_temp = post_op; + + post_ops_attr.is_last_k = TRUE; + POST_OP_LABEL_LASTK_SAFE_JUMP + + POST_OPS_BIAS_6x64: + { + selector1 = + _mm512_set1_epi32( + *( ( int32_t* )post_ops_list_temp->op_args1) ); + zmm8 = _mm512_add_epi32( selector1, zmm8 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_RELU_6x64: + { + selector1 = _mm512_setzero_epi32(); + + zmm8 = _mm512_max_epi32( selector1, zmm8 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_RELU_SCALE_6x64: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( + *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + RELU_SCALE_OP_S32_AVX512(zmm8) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_GELU_TANH_6x64: + { + __m512 dn, z, x, r2, r, y, x_tanh; + GELU_TANH_S32_AVX512( zmm8, y, r, r2, x, + z, dn, x_tanh, selector1 ) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_GELU_ERF_6x64: + { + __m512 x, r, y, x_erf; + + GELU_ERF_S32_AVX512( zmm8, y, r, x, x_erf ) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_CLIP_6x64: + { + __m512i min = _mm512_set1_epi32( + *( int32_t* )post_ops_list_temp->op_args2 ); + __m512i max = _mm512_set1_epi32( + *( int32_t* )post_ops_list_temp->op_args3 ); + + CLIP_S32_AVX512( zmm8, min, max ) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_DOWNSCALE_6x64: + { + selector1 = ( __m512i )_mm512_set1_ps( + *( ( float* )post_ops_list_temp->scale_factor ) ); + + // Need to ensure sse not used to avoid avx512 -> sse transition. + __m128i zero_point0 = _mm512_castsi512_si128( + _mm512_setzero_si512() ); + + zero_point0 = _mm_maskz_set1_epi8( 0xFFFF, + *( ( int8_t* )post_ops_list_temp->op_args1 ) ); + + CVT_MULRND_CVT32(zmm8, selector1, zero_point0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_MATRIX_ADD_6x64: + { + dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3; + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1; + + if( ldm == 1 ) + { + S8_S32_MATRIX_ADD_LOAD( k2, selector1, 0, 0 ) + + zmm8 = _mm512_add_epi32( selector1, zmm8 ); + } + else + { + int8_t ctemp[16]; + for( dim_t i = 0; i < mr0; i++ ) + { + ctemp[i] = *( matptr + + ( ( post_ops_attr.post_op_c_i + i ) + * ldm ) ); + } + selector1 = _mm512_cvtepi8_epi32 + ( _mm_maskz_loadu_epi8( k2, ctemp ) ); + + zmm8 = _mm512_add_epi32( selector1, zmm8 ); + } + } + else + { + int32_t* matptr = ( int32_t* )post_ops_list_temp->op_args1; + + if( ldm == 1 ) + { + S32_S32_MATRIX_ADD_LOAD(k2, selector1, 0, 0 ); + zmm8 = _mm512_add_epi32( selector1, zmm8 ); + } + else + { + int32_t ctemp[16]; + for( dim_t i = 0; i < mr0; i++ ) + { + ctemp[i] = *( matptr + + ( ( post_ops_attr.post_op_c_i + i ) + * ldm ) ); + } + selector1 = _mm512_maskz_loadu_epi32( k2, ctemp ); + zmm8 = _mm512_add_epi32( selector1, zmm8 ); + } + } + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + + POST_OPS_SWISH_6x64: + { + selector1 = + _mm512_set1_epi32( *( (int32_t*)post_ops_list_temp->op_args2 ) ); + + __m512 al = _mm512_cvtepi32_ps( selector1 ); + + __m512 fl_reg, al_in, r, r2, z, dn; + + SWISH_S32_AVX512( zmm8, fl_reg, al, al_in, r, r2, z, dn, selector2 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_6x64_DISABLE: + { + // Case where the output C matrix is s8 (downscaled) and + // this is the final write for a given block within C. + if ( post_ops_attr.buf_downscale != NULL ) + { + if( post_ops_attr.rs_c_downscale == 1 ) + { + CVT_STORE_S32_S8_MASK( zmm8, k2, 0, 0 ); + } + else + { + int8_t ctemp[16]; + + _mm512_mask_cvtsepi32_storeu_epi8 ( ctemp, k2, zmm8 ); + + for (dim_t i = 0; i < mr0; i++) + { + *( ( int8_t* )post_ops_attr.buf_downscale + + ( post_ops_attr.rs_c_downscale * + ( post_ops_attr.post_op_c_i + i ) ) ) = ctemp[i]; + } + } + } + else + { + if(rs_c == 1) + { + _mm512_mask_storeu_epi32(c_use, k2, zmm8); + } + else + { + // Store ZMM8 into ctemp buffer and store back + // element by element into output buffer at strides + int32_t ctemp[16]; + _mm512_mask_storeu_epi32(ctemp, k2, zmm8); + for (dim_t i = 0; i < mr0; i++) + { + c_use[i * rs_c] = ctemp[i]; + } + } + } + post_ops_attr.post_op_c_i += MR; + } + } +} + +#endif // BLIS_ADDON_LPGEMM diff --git a/kernels/zen5/3/bli_dgemm_avx512_asm_8x24.c b/kernels/zen5/3/bli_dgemm_avx512_asm_8x24.c new file mode 100644 index 0000000000..bb796c6fe8 --- /dev/null +++ b/kernels/zen5/3/bli_dgemm_avx512_asm_8x24.c @@ -0,0 +1,1303 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "bli_x86_asm_macros.h" +// BLIS_ASM_SYNTAX_INTEL syntax is followed + +/* + * Enable code to handle BETA = 0 and BETA = 1 + * Enabling this is causing regression when BETA is not equal + * 0 or 1, no improvement is observed when BETA = o or 1. + * Enabled this code for compliance reasons. + */ +#define BETA_OPTIMIZATION +#define ENABLE_COL_GEN_STORE + + +/* + * Prefetch distance for C + * TAIL_NITER = 26 is working better for single thread + * TAIL_NITER = 20 is working better for 128 threads + * TAIL_NITER = 24 used which gives good performance for 1 thread + * as well as 128 threads + * + * Prefetch C distance = TAIL_NITER + MR (24+8 = 32) + */ +#define TAIL_NITER 24 + +/* + * A_ADDITION is the negative offset added to address of A matrix + * so that the range of offsets for all references of A can be minimized + * in order to reduce the encoded instruction size. + * Max offset for A matrix will be := + * (MR*(UNROLL_FACTOR-1+ (MR+ number of A preloads))*sizeof(double) = 264 (used when + * SUBITER_1(3) macro is expanded ). + * Using A_ADDITION = 132 should reduce the instructions size + * the most, but A_ADDITION = 512 is giving better performance + * + * Similarly for B_ADDITION, max offset will be (24*3+16)*8 + 24*8*2 + * = 1088, therefore using B_ADDITION = 544 should reduce instruction + * size the most, but B_ADDITION = 1024 is giving better performance. + */ +#define A_ADDITION (512) +#define B_ADDITION (1024) + +#define LOOP_ALIGN ALIGN32 + + +/* + * Two different subiters(SUBITER_0 and SUBITER_1) are used + * so that latency of mov can be hidden + * SUBITER_0 laods B into ZMM0-2 + * SUBITER_0 laods B into ZMM3-5 + * SUBITER_0 and SUBITER_1 called alternatively + * + * ---------------------------------------------------------------- + * SUBITER_0 + * computes 8x24 block of C for one iteration of k loop + * parameters: n k index A(i,k) * B(k,j) + * Registers: rbx matrix b pointer + * rax matrix a pointer + * zmm6, zmm7 broadcast registers for a + * zmm0-zmm2 - 24 elements of "b" + * zmm8-zmm31 - stores a*b product + * -------------------------------------------------------------- +*/ +#define SUBITER_0(n) \ +\ + VFMADD231PD(ZMM( 8), ZMM(0), ZMM(6)) /*b(0 : 7, n) * a(n, 0) */\ + VFMADD231PD(ZMM( 9), ZMM(1), ZMM(6)) /*b(8 :15, n) * a(n, 0) */ \ + VFMADD231PD(ZMM(10), ZMM(2), ZMM(6)) /*b(16:23, n) * a(n, 0) */ \ + \ + VBROADCASTSD(ZMM(6), MEM(RAX,(8*n+ 2)*8 - A_ADDITION)) /*zmm6 = a(n, 2)*/ \ + VFMADD231PD(ZMM(11), ZMM(0), ZMM(7)) /*b(0 : 7, n) * a(n, 1) */\ + VFMADD231PD(ZMM(12), ZMM(1), ZMM(7)) /*b(8 :15, n) * a(n, 1) */ \ + VFMADD231PD(ZMM(13), ZMM(2), ZMM(7)) /*b(16:23, n) * a(n, 1) */ \ + \ + VBROADCASTSD(ZMM(7), MEM(RAX,(8*n+ 3)*8 - A_ADDITION)) /*zmm7 = a(n, 3)*/ \ + VFMADD231PD(ZMM(14), ZMM(0), ZMM(6)) /*b(0 : 7, n) * a(n, 2) */\ + VFMADD231PD(ZMM(15), ZMM(1), ZMM(6)) /*b(8 :15, n) * a(n, 2) */ \ + VFMADD231PD(ZMM(16), ZMM(2), ZMM(6)) /*b(16:23, n) * a(n, 2) */ \ + \ + VBROADCASTSD(ZMM(6), MEM(RAX,(8*n+ 4)*8 - A_ADDITION)) /*zmm6 = a(n, 4)*/ \ + VFMADD231PD(ZMM(17), ZMM(0), ZMM(7)) /*b(0 : 7, n) * a(n, 3) */\ + VFMADD231PD(ZMM(18), ZMM(1), ZMM(7)) /*b(8 :15, n) * a(n, 3) */ \ + VFMADD231PD(ZMM(19), ZMM(2), ZMM(7)) /*b(16:23, n) * a(n, 3) */ \ + \ + VBROADCASTSD(ZMM(7), MEM(RAX,(8*n+ 5)*8 - A_ADDITION)) /*zmm7 = a(n, 5)*/ \ + VFMADD231PD(ZMM(20), ZMM(0), ZMM(6)) /*b(0 : 7, n) * a(n, 4) */\ + VFMADD231PD(ZMM(21), ZMM(1), ZMM(6)) /*b(8 :15, n) * a(n, 4) */ \ + VFMADD231PD(ZMM(22), ZMM(2), ZMM(6)) /*b(16:23, n) * a(n, 4) */ \ + \ + VBROADCASTSD(ZMM(6), MEM(RAX,(8*n+ 6)*8 - A_ADDITION)) /*zmm6 = a(n, 6)*/ \ + VFMADD231PD(ZMM(23), ZMM(0), ZMM(7)) /*b(0 : 7, n) * a(n, 5) */\ + VFMADD231PD(ZMM(24), ZMM(1), ZMM(7)) /*b(8 :15, n) * a(n, 5) */ \ + VFMADD231PD(ZMM(25), ZMM(2), ZMM(7)) /*b(16:23, n) * a(n, 5) */ \ + \ + VBROADCASTSD(ZMM(7), MEM(RAX,(8*n+ 7)*8 - A_ADDITION)) /*zmm7 = a(n, 7)*/ \ + VFMADD231PD(ZMM(26), ZMM(0), ZMM(6)) /*b(0 : 7, n) * a(n, 6) */\ + VFMADD231PD(ZMM(27), ZMM(1), ZMM(6)) /*b(8 :15, n) * a(n, 6) */ \ + VFMADD231PD(ZMM(28), ZMM(2), ZMM(6)) /*b(16:23, n) * a(n, 6) */ \ + \ + VBROADCASTSD(ZMM(6), MEM(RAX,(8*n+ 8)*8 - A_ADDITION)) /*zmm6 = a(n+1, 0)*/\ + VFMADD231PD(ZMM(29), ZMM(0), ZMM(7)) /*b(0 : 7, n) * a(n, 7) */\ + VFMADD231PD(ZMM(30), ZMM(1), ZMM(7)) /*b(8 :15, n) * a(n, 7) */ \ + VFMADD231PD(ZMM(31), ZMM(2), ZMM(7)) /*b(16:23, n) * a(n, 7) */ \ + VBROADCASTSD(ZMM(7), MEM(RAX,(8*n+ 9)*8 - A_ADDITION)) /*zmm7 = a(n+1, 1)*/ \ + VMOVAPD(ZMM(0), MEM(RBX,(24*n+0 )*8 - B_ADDITION + 24*8*2))/*zmm0 = b(0 :7 , n+2)*/ \ + VMOVAPD(ZMM(1), MEM(RBX,(24*n+8 )*8 - B_ADDITION + 24*8*2))/*zmm1 = b(8 :15, n+2)*/ \ + VMOVAPD(ZMM(2), MEM(RBX,(24*n+16)*8 - B_ADDITION + 24*8*2))/*zmm2 = b(16:23, n+2)*/ \ + /*24*8*2 is preload offset compensated for B preload*/ \ +/* + * ---------------------------------------------------------------- + * SUBITER_1 + * computes 8x24 block of C for one iteration of k loop + * parameters: n k index A(i,k) * B(k,j) + * Registers: rbx matrix b pointer + * rax matrix a pointer + * zmm6, zmm7 broadcast registers for a + * zmm3-zmm5 - 24 elements of "b" + * zmm8-zmm31 - stores a*b product + * -------------------------------------------------------------- +*/ +#define SUBITER_1(n) \ +\ + VFMADD231PD(ZMM( 8), ZMM(3), ZMM(6)) /*b(0 : 7, n) * a(n, 0) */\ + VFMADD231PD(ZMM( 9), ZMM(4), ZMM(6)) /*b(8 :15, n) * a(n, 0) */ \ + VFMADD231PD(ZMM(10), ZMM(5), ZMM(6)) /*b(16:23, n) * a(n, 0) */ \ + \ + VBROADCASTSD(ZMM(6), MEM(RAX,(8*n+ 2)*8 - A_ADDITION)) /*zmm6 = a(n, 2)*/ \ + VFMADD231PD(ZMM(11), ZMM(3), ZMM(7)) /*b(0 : 7, n) * a(n, 1) */\ + VFMADD231PD(ZMM(12), ZMM(4), ZMM(7)) /*b(8 :15, n) * a(n, 1) */ \ + VFMADD231PD(ZMM(13), ZMM(5), ZMM(7)) /*b(16:23, n) * a(n, 1) */ \ + \ + VBROADCASTSD(ZMM(7), MEM(RAX,(8*n+ 3)*8 - A_ADDITION)) /*zmm7 = a(n, 3)*/ \ + VFMADD231PD(ZMM(14), ZMM(3), ZMM(6)) /*b(0 : 7, n) * a(n, 2) */\ + VFMADD231PD(ZMM(15), ZMM(4), ZMM(6)) /*b(8 :15, n) * a(n, 2) */ \ + VFMADD231PD(ZMM(16), ZMM(5), ZMM(6)) /*b(16:23, n) * a(n, 2) */ \ + \ + VBROADCASTSD(ZMM(6), MEM(RAX,(8*n+ 4)*8 - A_ADDITION)) /*zmm6 = a(n, 4)*/ \ + VFMADD231PD(ZMM(17), ZMM(3), ZMM(7)) /*b(0 : 7, n) * a(n, 3) */\ + VFMADD231PD(ZMM(18), ZMM(4), ZMM(7)) /*b(8 :15, n) * a(n, 3) */ \ + VFMADD231PD(ZMM(19), ZMM(5), ZMM(7)) /*b(16:23, n) * a(n, 3) */ \ + \ + VBROADCASTSD(ZMM(7), MEM(RAX,(8*n+ 5)*8 - A_ADDITION)) /*zmm7 = a(n, 5)*/ \ + VFMADD231PD(ZMM(20), ZMM(3), ZMM(6)) /*b(0 : 7, n) * a(n, 4) */\ + VFMADD231PD(ZMM(21), ZMM(4), ZMM(6)) /*b(8 :15, n) * a(n, 4) */ \ + VFMADD231PD(ZMM(22), ZMM(5), ZMM(6)) /*b(16:23, n) * a(n, 4) */ \ + \ + VBROADCASTSD(ZMM(6), MEM(RAX,(8*n+ 6)*8 - A_ADDITION)) /*zmm6 = a(n, 6)*/ \ + VFMADD231PD(ZMM(23), ZMM(3), ZMM(7)) /*b(0 : 7, n) * a(n, 5) */\ + VFMADD231PD(ZMM(24), ZMM(4), ZMM(7)) /*b(8 :15, n) * a(n, 5) */ \ + VFMADD231PD(ZMM(25), ZMM(5), ZMM(7)) /*b(16:23, n) * a(n, 5) */ \ + \ + VBROADCASTSD(ZMM(7), MEM(RAX,(8*n+ 7)*8 - A_ADDITION)) /*zmm7 = a(n, 7)*/ \ + VFMADD231PD(ZMM(26), ZMM(3), ZMM(6)) /*b(0 : 7, n) * a(n, 6) */\ + VFMADD231PD(ZMM(27), ZMM(4), ZMM(6)) /*b(8 :15, n) * a(n, 6) */ \ + VFMADD231PD(ZMM(28), ZMM(5), ZMM(6)) /*b(16:23, n) * a(n, 6) */ \ + VBROADCASTSD(ZMM(6), MEM(RAX,(8*n+ 8)*8 - A_ADDITION)) /*zmm6 = a(n+1, 0)*/ \ + \ + VFMADD231PD(ZMM(29), ZMM(3), ZMM(7)) /*b(0 : 7, n) * a(n, 7) */\ + VFMADD231PD(ZMM(30), ZMM(4), ZMM(7)) /*b(8 :15, n) * a(n, 7) */ \ + VFMADD231PD(ZMM(31), ZMM(5), ZMM(7)) /*b(16:23, n) * a(n, 7) */ \ + VBROADCASTSD(ZMM(7), MEM(RAX,(8*n+ 9)*8 - A_ADDITION)) /*zmm7 = a(n+1, 1)*/ \ + VMOVAPD(ZMM(3), MEM(RBX,(24*n+0 )*8 - B_ADDITION + 24*8*2))/*zmm3 = b(0 :7 , n+2)*/ \ + VMOVAPD(ZMM(4), MEM(RBX,(24*n+8 )*8 - B_ADDITION + 24*8*2))/*zmm4 = b(8 :15, n+2)*/ \ + VMOVAPD(ZMM(5), MEM(RBX,(24*n+16)*8 - B_ADDITION + 24*8*2))/*zmm5 = b(16:23, n+2)*/ \ + /*24*8*2 is preload offset compensated for B preload*/ \ + + +// Update C when C is general stored +#define UPDATE_C_SCATTERED(R1,R2,R3) \ +\ + KXNORW(K(1), K(0), K(0)) /*set mask register to zero*/ \ + KXNORW(K(2), K(0), K(0)) /*set mask register to zero*/ \ + KXNORW(K(3), K(0), K(0)) /*set mask register to zero*/ \ + VGATHERQPD(ZMM(0) MASK_K(1), MEM(RCX,ZMM(2),1)) /*load C(0:7) from current row of C*/\ + /*scale by beta*/ \ + VFMADD231PD(ZMM(R1), ZMM(0), ZMM(1)) /*zmmR1 += zmm0(C(0:7)*zmm1(beta)*/\ + VGATHERQPD(ZMM(0) MASK_K(2), MEM(RCX,ZMM(3),1)) /*load C(8:15)*/ \ + VFMADD231PD(ZMM(R2), ZMM(0), ZMM(1)) /*zmmR3 += zmm0(C(8:15)*zmm1(beta)*/\ + VGATHERQPD(ZMM(0) MASK_K(3), MEM(RCX,ZMM(4),1)) /*load C(16:23)*/ \ + VFMADD231PD(ZMM(R3), ZMM(0), ZMM(1)) /*zmmR3 += zmm0(C(16:23)*zmm1(beta)*/\ + /*mask registers are reset to 1 after gather/scatter instruction*/ \ + KXNORW(K(1), K(0), K(0)) /*set mask registers to zero*/\ + KXNORW(K(2), K(0), K(0)) \ + KXNORW(K(3), K(0), K(0)) \ + /*store c*/ \ + VSCATTERQPD(MEM(RCX,ZMM(2),1) MASK_K(1), ZMM(R1)) /*store C(0:7)*/ \ + VSCATTERQPD(MEM(RCX,ZMM(3),1) MASK_K(2), ZMM(R2)) /*store C(7:15)*/ \ + VSCATTERQPD(MEM(RCX,ZMM(4),1) MASK_K(3), ZMM(R3)) /*store C(16:23)*/ \ + LEA(RCX, MEM(RCX,R10,1)) + +// Update C when C is general stored and beta = 0 +#define UPDATE_C_SCATTERED_BZ(R1,R2,R3) \ +\ + KXNORW(K(1), K(0), K(0)) \ + KXNORW(K(2), K(0), K(0)) \ + KXNORW(K(3), K(0), K(0)) \ + VSCATTERQPD(MEM(RCX,ZMM(2),1) MASK_K(1), ZMM(R1)) \ + VSCATTERQPD(MEM(RCX,ZMM(3),1) MASK_K(2), ZMM(R2)) \ + VSCATTERQPD(MEM(RCX,ZMM(4),1) MASK_K(3), ZMM(R3)) \ + LEA(RCX, MEM(RCX,R10,1)) + +// 8x8 in register transpose, used for column stored C +#define TRANSPOSE_8X8(R0, R1, R2, R3, R4, R5, R6, R7) \ +\ + VUNPCKLPD(ZMM(6), ZMM(R0), ZMM(R1)) \ + VUNPCKLPD(ZMM(7), ZMM(R2), ZMM(R3)) \ + VUNPCKLPD(ZMM(2), ZMM(R4), ZMM(R5)) \ + VUNPCKLPD(ZMM(3), ZMM(R6), ZMM(R7)) \ + VMOVUPD(ZMM(0), ZMM(R0)) \ + VMOVUPD(ZMM(1), ZMM(R4)) \ + /*Stage2*/ \ + VSHUFF64X2(ZMM(4), ZMM(6), ZMM(7), IMM(0x88)) \ + VSHUFF64X2(ZMM(5), ZMM(2), ZMM(3), IMM(0x88)) \ + /*Stage3 1,5*/ \ + VSHUFF64X2(ZMM(R0), ZMM(4), ZMM(5), IMM(0x88)) \ + VSHUFF64X2(ZMM(R4), ZMM(4), ZMM(5), IMM(0xDD)) \ + /*Stage2*/ \ + VSHUFF64X2(ZMM(4), ZMM(6), ZMM(7), IMM(0xDD)) \ + VSHUFF64X2(ZMM(5), ZMM(2), ZMM(3), IMM(0xDD)) \ + /*Stage3 3,7*/ \ + VUNPCKHPD(ZMM(6), ZMM(0 ), ZMM(R1)) \ + VUNPCKHPD(ZMM(7), ZMM(R2), ZMM(R3)) \ + VUNPCKHPD(ZMM(2), ZMM(1 ), ZMM(R5)) \ + VUNPCKHPD(ZMM(3), ZMM(R6), ZMM(R7)) \ + VSHUFF64X2(ZMM(R2), ZMM(4), ZMM(5), IMM(0x88)) \ + VSHUFF64X2(ZMM(R6), ZMM(4), ZMM(5), IMM(0xDD)) \ + \ + /*Stage2*/ \ + VSHUFF64X2(ZMM(4), ZMM(6), ZMM(7), IMM(0x88)) \ + VSHUFF64X2(ZMM(5), ZMM(2), ZMM(3), IMM(0x88)) \ + /*Stage3 2,6*/ \ + VSHUFF64X2(ZMM(R1), ZMM(4), ZMM(5), IMM(0x88)) \ + VSHUFF64X2(ZMM(R5), ZMM(4), ZMM(5), IMM(0xDD)) \ + /*Stage2*/ \ + VSHUFF64X2(ZMM(4), ZMM(6), ZMM(7), IMM(0xDD)) \ + VSHUFF64X2(ZMM(5), ZMM(2), ZMM(3), IMM(0xDD)) \ + /*Stage3 4,8*/ \ + VSHUFF64X2(ZMM(R3), ZMM(4), ZMM(5), IMM(0x88)) \ + VSHUFF64X2(ZMM(R7), ZMM(4), ZMM(5), IMM(0xDD)) \ + +// Update C when C is column stored +#define UPDATE_C_COL_STORE(R0, R1, R2, R3, R4, R5, R6, R7) \ + \ + /* scale by alpha */\ + VMULPD(ZMM(R0), ZMM(R0), ZMM(0)) \ + VMULPD(ZMM(R1), ZMM(R1), ZMM(0)) \ + VMULPD(ZMM(R2), ZMM(R2), ZMM(0)) \ + VMULPD(ZMM(R3), ZMM(R3), ZMM(0)) \ + VMULPD(ZMM(R4), ZMM(R4), ZMM(0)) \ + VMULPD(ZMM(R5), ZMM(R5), ZMM(0)) \ + VMULPD(ZMM(R6), ZMM(R6), ZMM(0)) \ + VMULPD(ZMM(R7), ZMM(R7), ZMM(0)) \ + /*scale by beta*/\ + VFMADD231PD(ZMM(R0), ZMM(1), MEM(RCX)) \ + /*store c*/ \ + VMOVUPD(MEM(RCX), ZMM(R0)) \ + VFMADD231PD(ZMM(R1), ZMM(1), MEM(RCX, R12, 1)) \ + VMOVUPD(MEM(RCX, R12, 1), ZMM(R1)) \ + VFMADD231PD(ZMM(R2), ZMM(1), MEM(RCX, R12, 2)) \ + VMOVUPD(MEM(RCX, R12, 2), ZMM(R2)) \ + VFMADD231PD(ZMM(R3), ZMM(1), MEM(RCX, R13, 1)) \ + VMOVUPD(MEM(RCX, R13, 1), ZMM(R3)) \ + VFMADD231PD(ZMM(R4), ZMM(1), MEM(RCX, R12, 4)) \ + VMOVUPD(MEM(RCX, R12, 4), ZMM(R4)) \ + VFMADD231PD(ZMM(R5), ZMM(1), MEM(RCX, RDX, 1)) \ + VMOVUPD(MEM(RCX, RDX, 1), ZMM(R5)) \ + VFMADD231PD(ZMM(R6), ZMM(1), MEM(RCX, R13, 2)) \ + VMOVUPD(MEM(RCX, R13, 2), ZMM(R6)) \ + VFMADD231PD(ZMM(R7), ZMM(1), MEM(RCX, R14, 1)) \ + VMOVUPD(MEM(RCX, R14, 1), ZMM(R7)) \ + LEA(RCX, MEM(RCX,R12,8)) + +// Update C when C is column stored and beta = 0 +#define UPDATE_C_COL_STORE_BZ(R0, R1, R2, R3, R4, R5, R6, R7) \ + /* scale by alpha */\ + VMULPD(ZMM(R0), ZMM(R0), ZMM(0)) \ + VMULPD(ZMM(R1), ZMM(R1), ZMM(0)) \ + VMULPD(ZMM(R2), ZMM(R2), ZMM(0)) \ + VMULPD(ZMM(R3), ZMM(R3), ZMM(0)) \ + VMULPD(ZMM(R4), ZMM(R4), ZMM(0)) \ + VMULPD(ZMM(R5), ZMM(R5), ZMM(0)) \ + VMULPD(ZMM(R6), ZMM(R6), ZMM(0)) \ + VMULPD(ZMM(R7), ZMM(R7), ZMM(0)) \ + /*store c*/ \ + VMOVUPD(MEM(RCX), ZMM(R0)) \ + VMOVUPD(MEM(RCX, R12, 1), ZMM(R1)) /*R12 = cs_c*/ \ + VMOVUPD(MEM(RCX, R12, 2), ZMM(R2)) \ + VMOVUPD(MEM(RCX, R13, 1), ZMM(R3)) /*R13 = 3*cs_c*/\ + VMOVUPD(MEM(RCX, R12, 4), ZMM(R4)) \ + VMOVUPD(MEM(RCX, RDX, 1), ZMM(R5)) /*RDX = 5*cs_c*/\ + VMOVUPD(MEM(RCX, R13, 2), ZMM(R6)) \ + VMOVUPD(MEM(RCX, R14, 1), ZMM(R7)) /*R14 = 7*cs_c*/\ + LEA(RCX, MEM(RCX,R12,8)) + +#define ZERO_REGISTERS() \ + VXORPD(ZMM(8) , ZMM(8), ZMM(8)) \ + VXORPD(ZMM(9) , ZMM(9), ZMM(9)) \ + VXORPD(ZMM(10), ZMM(10), ZMM(10)) \ + VXORPD(ZMM(11), ZMM(11), ZMM(11)) \ + VXORPD(ZMM(12), ZMM(12), ZMM(12)) \ + VXORPD(ZMM(13), ZMM(13), ZMM(13)) \ + VXORPD(ZMM(14), ZMM(14), ZMM(14)) \ + VXORPD(ZMM(15), ZMM(15), ZMM(15)) \ + VXORPD(ZMM(16), ZMM(16), ZMM(16)) \ + VXORPD(ZMM(17), ZMM(17), ZMM(17)) \ + VXORPD(ZMM(18), ZMM(18), ZMM(18)) \ + VXORPD(ZMM(19), ZMM(19), ZMM(19)) \ + VXORPD(ZMM(20), ZMM(20), ZMM(20)) \ + VXORPD(ZMM(21), ZMM(21), ZMM(21)) \ + VXORPD(ZMM(22), ZMM(22), ZMM(22)) \ + VXORPD(ZMM(23), ZMM(23), ZMM(23)) \ + VXORPD(ZMM(24), ZMM(24), ZMM(24)) \ + VXORPD(ZMM(25), ZMM(25), ZMM(25)) \ + VXORPD(ZMM(26), ZMM(26), ZMM(26)) \ + VXORPD(ZMM(27), ZMM(27), ZMM(27)) \ + VXORPD(ZMM(28), ZMM(28), ZMM(28)) \ + VXORPD(ZMM(29), ZMM(29), ZMM(29)) \ + VXORPD(ZMM(30), ZMM(30), ZMM(30)) \ + VXORPD(ZMM(31), ZMM(31), ZMM(31)) + +#define K_LOOP() \ + /* pre-load two rows of B */ \ + VMOVAPD(ZMM(0), MEM(RBX, 0*8)) /* zmm0 = row - b[k - 0:7] */ \ + VMOVAPD(ZMM(1), MEM(RBX, 8*8)) /* zmm1 = row - b[k - 8:15] */ \ + VMOVAPD(ZMM(2), MEM(RBX,16*8)) /* zmm2 = row - b[k - 16:23] */ \ + \ + VMOVAPD(ZMM(3), MEM(RBX,24*8)) /* zmm3 = row - b[k+1 - 24:31] */ \ + VMOVAPD(ZMM(4), MEM(RBX,32*8)) /* zmm4 = row - b[k+1 - 32:39] */ \ + VMOVAPD(ZMM(5), MEM(RBX,40*8)) /* zmm5 = row - b[k+1 - 40:48] */ \ + \ + /* pre-load A */ \ + VBROADCASTSD(ZMM(6), MEM(RAX,(8*0+0)*8)) /* zmm6 = a[0] */ \ + VBROADCASTSD(ZMM(7), MEM(RAX,(8*0+1)*8)) /* zmm7 = a[1] */ \ + \ + /* move address of A and B forward so that negative addresses */ \ + /* can be used */ \ + ADD(RBX, IMM( 0+B_ADDITION )) /* A += A_ADDITION */ \ + ADD(RAX, IMM( 0+A_ADDITION )) /* B += B_ADDITION */ \ + \ + MOV(R13, RDX) /* R14 = k */ \ + MOV(R14, RDX) /* R14 = k */ \ + AND(R14, IMM(3)) /* R14(k_left) = k & 3, R14 = k % 4 */ \ + SAR(R13, IMM(2)) /* R13(k_iter) = k >> 2, R13 = k / 4 */ \ + \ + SUB(R13, IMM(8+TAIL_NITER)) /* k/4 - MR - TAIL_NITER, MR = 8 */ \ + JLE(K_PREFETCH) /* jump to C prefetch loop if k_iter <= 0 */ \ + /* LABEL(K_MAIN)*/ \ + \ + LOOP_ALIGN \ + LABEL(LOOP1) \ + \ + SUBITER_0(0) /* k=0 */ \ + SUBITER_1(1) /* k=1 */ \ + SUBITER_0(2) /* k=2 */ \ + SUBITER_1(3) /* k=3 */ \ + \ + LEA(RAX, MEM(RAX,4*8*8)) /* rax -> (UNROLL_FACTOR * MR * sizeof(double)) next 4th col of a */ \ + LEA(RBX, MEM(RBX,4*24*8)) /* rbx -> (UNROLL_FACTOR * NR * sizeof(double)) next 4th row of b */ \ + DEC(R13) /* R13-=1 */ \ + \ + JNZ(LOOP1) /* if R13 != 0 jump to loop1 */ \ + \ + LABEL(K_PREFETCH) \ + \ + ADD(R13, IMM(8)) /* add prefetch loop count ( R13(k_iter) += MR ) */ \ + JLE(K_TAIL) /* jump to tail iteration if k_iter <= 0 */ \ + \ + LOOP_ALIGN \ + /* MR * 24 block of c is prefetched */ \ + LABEL(LOOP2) \ + \ + PREFETCHW0(MEM(R12)) /* prefetch row - C[k, 0:7] */ \ + SUBITER_0(0) /* k=0 */ \ + PREFETCHW0(MEM(R12,8*8)) /* prefetch row - C[k, 8:15] */ \ + SUBITER_1(1) /* k=1 */ \ + PREFETCHW0(MEM(R12,16*8)) /* prefetch row - C[k, 16:23] */ \ + SUBITER_0(2) /* k=2 */ \ + SUBITER_1(3) /* k=3 */ \ + \ + LEA(RAX, MEM(RAX,4*8*8)) /* rax -> (UNROLL_FACTOR * MR * sizeof(double)) next 4th col of a */ \ + LEA(RBX, MEM(RBX,4*24*8)) /* rbx -> (UNROLL_FACTOR * NR * sizeof(double)) next 4th row of b */ \ + LEA(R12, MEM(R12,R10,1)) /* R12 -> c += ldc (next row of c) */ \ + DEC(R13) /* R13-=1 */ \ + \ + JNZ(LOOP2) /* if R13 != 0 jump to loop2 */ \ + \ + LABEL(K_TAIL) \ + \ + ADD(R13, IMM(0+TAIL_NITER)) /* R13(k_iter) += TAIL_ITER */ \ + JLE(POST_K) /* jump to TAIL loop if k_iter <= 0 */ \ + \ + LOOP_ALIGN \ + LABEL(LOOP3) \ + \ + SUBITER_0(0) /* k=0 */ \ + SUBITER_1(1) /* k=1 */ \ + SUBITER_0(2) /* k=2 */ \ + SUBITER_1(3) /* k=3 */ \ + \ + LEA(RAX, MEM(RAX,4*8*8)) /* rax -> next 4th col of a*/ \ + LEA(RBX, MEM(RBX,4*24*8)) /* rbx -> next 4th row of b*/ \ + DEC(R13) /* R13-=1 */ \ + \ + JNZ(LOOP3) /* if R13 != 0 jump to LOOP3 */ \ + \ + LABEL(POST_K) \ + \ + TEST(R14, R14) \ + JZ(POSTACCUM) \ + /* Only SUBITER_0 is used in this loop, */ \ + /* therefore negative offset is done for 1 iter */ \ + /* of K only(24*8) */ \ + SUB(RBX, IMM(24*8)) /* rbx -> prev 4th row of b */ \ + LOOP_ALIGN \ + LABEL(LOOP4) \ + \ + SUBITER_0(0) /*k=0 */ \ + \ + LEA(RAX, MEM(RAX,8*8)) /* rax -> (UNROLL_FACTOR(1) * MR * sizeof(double)) next col of a */ \ + LEA(RBX, MEM(RBX,24*8)) /* rbx -> (UNROLL_FACTOR(1) * NR * sizeof(double)) next row of b */ \ + DEC(R14) \ + \ + JNZ(LOOP4) + + +//This is an array used for the scatter/gather instructions. +static int64_t offsets[24] __attribute__((aligned(64))) = + { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14,15,16,17,18,19,20,21,22,23}; + + +/* + * number of accumulation registers = 24/8 * 8 = 24 zmm8 to zmm31 + * number of registers used for load B = + * 24/8 = 3 (*2 for hiding load latency) zmm0 to zmm5 + * number of registers used for broadcast A = 2 zmm6 and zmm7 + */ +void bli_dgemm_avx512_asm_8x24( + dim_t k_, + double* restrict alpha, + double* restrict a, + double* restrict b, + double* restrict beta, + double* restrict c, inc_t rs_c_, inc_t cs_c_, + auxinfo_t* data, + cntx_t* restrict cntx + ) +{ + (void)data; + (void)cntx; + (void)cs_c_; + + const int64_t* offsetPtr = &offsets[0]; + const int64_t k = k_; + const int64_t rs_c = rs_c_*8; //convert strides to bytes + const int64_t cs_c = cs_c_*8; //convert strides to bytes + + + BEGIN_ASM() + + ZERO_REGISTERS() + MOV(RDX, VAR(k)) // loop index + MOV(RAX, VAR(a)) // load address of a + MOV(RBX, VAR(b)) // load address of b + MOV(RCX, VAR(c)) // load address of c + MOV(R10, VAR(rs_c)) // load rs_c + + LEA(R12, MEM(RCX,63)) // c for prefetching R12 := C + cacheline_offset + + K_LOOP() + + LABEL(POSTACCUM) + + MOV(RAX, VAR(alpha)) + MOV(RBX, VAR(beta)) + VBROADCASTSD(ZMM(0), MEM(RAX)) // broadcast alpha into zmm0 + + // R10 = rs_c + LEA(R13, MEM(R10, R10, 2)) // (R13)rs_c*3 -> rs_c + rs_c*2 + LEA(RDX, MEM(R10, R10, 4)) // (RDX)rs_c*5 -> rs_c + rs_c*4 + LEA(R14, MEM(R10, R13, 2)) // (R14)rs_c*7 -> rs_c + rs_c*3*2 + +#ifdef ENABLE_COL_GEN_STORE + MOV(R12, VAR(cs_c)) // load cs_c + CMP(R10, IMM(8)) + JE(COLUPDATE) // jump to COLUPDATE if rs_c(R10) == 1 + + CMP(R12, IMM(8)) // R12 = cs_c + JNE(SCATTERUPDATE) // if cs_c(R12) != 1 jump to scatterupdate +#endif + +#ifdef BETA_OPTIMIZATION // if beta = 0 and beta = 1 are handled + MOV(RAX, IMM(1)) + CVTSI2SD(XMM(3), RAX) + + MOV(RAX, VAR(alpha)) + + VXORPD(ZMM(2), ZMM(2), ZMM(2)) + VBROADCASTSD(ZMM(1), MEM(RBX)) + + VCOMISD(XMM(1), XMM(2)) + JZ(BETA_ZERO) // jump to BETA_ZERO if beta == 0 + + VCOMISD(XMM(1), XMM(3)) + CMP(RBX, IMM(1)) + JNZ(BETA_NZ_N1)// jump to BETA_NZ_N1 if beta != 1 + + // no jumps for beta = 1 + // LABEL(BETA_ONE) + + // row0 + // scale by alpha, zmm0 = alpha + VMULPD(ZMM( 8), ZMM( 8), ZMM(0)) // zmm8 *= alpha + VMULPD(ZMM( 9), ZMM( 9), ZMM(0)) // zmm9 *= alpha + VMULPD(ZMM(10), ZMM(10), ZMM(0)) // zmm10*= alpha + /*since beta == 1, C += alpha(AB)*/ + VADDPD(ZMM( 8), ZMM( 8), MEM(RCX)) // zmm8 = C(0 :7 ) + zmm8 *alpha + VADDPD(ZMM( 9), ZMM( 9), MEM(RCX,64)) // zmm9 = C(8 :15) + zmm9 *alpha + VADDPD(ZMM(10), ZMM(10), MEM(RCX,128)) // zmm10= C(16:23) + zmm10*alpha + /*store c*/ + VMOVUPD(MEM(RCX ), ZMM( 8)) // C(0 :7 ) = zmm8 + VMOVUPD(MEM(RCX, 64), ZMM( 9)) // C(8 :15) = zmm9 + VMOVUPD(MEM(RCX,128), ZMM(10)) // C(16:23) = zmm10 + + // row1 + VMULPD(ZMM(11), ZMM(11), ZMM(0)) // zmm11 *= alpha + VMULPD(ZMM(12), ZMM(12), ZMM(0)) // zmm12 *= alpha + VMULPD(ZMM(13), ZMM(13), ZMM(0)) // zmm13 *= alpha + /*scale by beta*/ + VADDPD(ZMM(11), ZMM(11), MEM(RCX, R10, 1 )) // zmm11= C(0 :7 ) + zmm11*alpha + VADDPD(ZMM(12), ZMM(12), MEM(RCX, R10, 1, 64 )) // zmm12= C(8 :15) + zmm12*alpha + VADDPD(ZMM(13), ZMM(13), MEM(RCX, R10, 1, 128)) // zmm13= C(16:23) + zmm13*alpha + /*store c*/ + VMOVUPD(MEM(RCX, R10, 1 ), ZMM(11)) // C(0 :7 ) = zmm11 + VMOVUPD(MEM(RCX, R10, 1, 64 ), ZMM(12)) // C(8 :15) = zmm12 + VMOVUPD(MEM(RCX, R10, 1, 128), ZMM(13)) // C(16:23) = zmm13 + + // row2 + VMULPD(ZMM(14), ZMM(14), ZMM(0)) // zmm14 *= alpha + VMULPD(ZMM(15), ZMM(15), ZMM(0)) // zmm15 *= alpha + VMULPD(ZMM(16), ZMM(16), ZMM(0)) // zmm16 *= alpha + /*scale by beta*/ + VADDPD(ZMM(14), ZMM(14), MEM(RCX, R10, 2 )) // zmm14 = C(0 :7 ) + zmm14 *alpha + VADDPD(ZMM(15), ZMM(15), MEM(RCX, R10, 2, 64 )) // zmm15 = C(8 :15) + zmm15 *alpha + VADDPD(ZMM(16), ZMM(16), MEM(RCX, R10, 2, 128)) // zmm16 = C(16:23) + zmm16 *alpha + /*store c*/ + VMOVUPD(MEM(RCX, R10, 2 ), ZMM(14)) // C(0 :7 ) = zmm14 + VMOVUPD(MEM(RCX, R10, 2, 64 ), ZMM(15)) // C(8 :15) = zmm15 + VMOVUPD(MEM(RCX, R10, 2, 128), ZMM(16)) // C(16:23) = zmm16 + + // row3 + VMULPD(ZMM(17), ZMM(17), ZMM(0)) // zmm17 *= alpha + VMULPD(ZMM(18), ZMM(18), ZMM(0)) // zmm18 *= alpha + VMULPD(ZMM(19), ZMM(19), ZMM(0)) // zmm19 *= alpha + /*scale by beta*/ + VADDPD(ZMM(17), ZMM(17), MEM(RCX, R13, 1 )) // zmm17 = C(0 :7 ) + zmm17 *alpha + VADDPD(ZMM(18), ZMM(18), MEM(RCX, R13, 1, 64 )) // zmm18 = C(8 :15) + zmm18 *alpha + VADDPD(ZMM(19), ZMM(19), MEM(RCX, R13, 1, 128)) // zmm18 = C(16:23) + zmm18 *alpha + /*store c*/ + VMOVUPD(MEM(RCX, R13, 1 ), ZMM(17)) // C(0 :7 ) = zmm17 + VMOVUPD(MEM(RCX, R13, 1, 64 ), ZMM(18)) // C(8 :15) = zmm18 + VMOVUPD(MEM(RCX, R13, 1, 128), ZMM(19)) // C(16:23) = zmm18 + + // row4 + VMULPD(ZMM(20), ZMM(20), ZMM(0)) // zmm20 *= alpha + VMULPD(ZMM(21), ZMM(21), ZMM(0)) // zmm21 *= alpha + VMULPD(ZMM(22), ZMM(22), ZMM(0)) // zmm22 *= alpha + /*scale by beta*/ + VADDPD(ZMM(20), ZMM(20), MEM(RCX, R10, 4 )) // zmm20 = C(0 :7 ) + zmm20 *alpha + VADDPD(ZMM(21), ZMM(21), MEM(RCX, R10, 4, 64 )) // zmm21 = C(8 :15) + zmm21 *alpha + VADDPD(ZMM(22), ZMM(22), MEM(RCX, R10, 4, 128)) // zmm22 = C(16:23) + zmm22 *alpha + /*store c*/ + VMOVUPD(MEM(RCX, R10, 4 ), ZMM(20)) // C(0 :7 ) = zmm20 + VMOVUPD(MEM(RCX, R10, 4, 64 ), ZMM(21)) // C(8 :15) = zmm21 + VMOVUPD(MEM(RCX, R10, 4, 128), ZMM(22)) // C(16:23) = zmm22 + + // row5 + VMULPD(ZMM(23), ZMM(23), ZMM(0)) // zmm23 *= alpha + VMULPD(ZMM(24), ZMM(24), ZMM(0)) // zmm24 *= alpha + VMULPD(ZMM(25), ZMM(25), ZMM(0)) // zmm25 *= alpha + /*scale by beta*/ + VADDPD(ZMM(23), ZMM(23), MEM(RCX, RDX, 1 )) // zmm23 = C(0 :7 ) + zmm23 *alpha + VADDPD(ZMM(24), ZMM(24), MEM(RCX, RDX, 1, 64 )) // zmm24 = C(8 :15) + zmm24 *alpha + VADDPD(ZMM(25), ZMM(25), MEM(RCX, RDX, 1, 128)) // zmm25 = C(16:23) + zmm25 *alpha + /*store c*/ + VMOVUPD(MEM(RCX, RDX, 1 ), ZMM(23)) // C(0 :7 ) = zmm23 + VMOVUPD(MEM(RCX, RDX, 1, 64 ), ZMM(24)) // C(8 :15) = zmm24 + VMOVUPD(MEM(RCX, RDX, 1, 128), ZMM(25)) // C(16:23) = zmm25 + + // row6 + VMULPD(ZMM(26), ZMM(26), ZMM(0)) // zmm26 *= alpha + VMULPD(ZMM(27), ZMM(27), ZMM(0)) // zmm27 *= alpha + VMULPD(ZMM(28), ZMM(28), ZMM(0)) // zmm28 *= alpha + /*scale by beta*/ + VADDPD(ZMM(26), ZMM(26), MEM(RCX, R13, 2 )) // zmm26 = C(0 :7 ) + zmm26 *alpha + VADDPD(ZMM(27), ZMM(27), MEM(RCX, R13, 2, 64 )) // zmm27 = C(8 :15) + zmm27 *alpha + VADDPD(ZMM(28), ZMM(28), MEM(RCX, R13, 2, 128)) // zmm28 = C(16:23) + zmm28 *alpha + /*store c*/ + VMOVUPD(MEM(RCX, R13, 2 ), ZMM(26)) // C(0 :7 ) = zmm26 + VMOVUPD(MEM(RCX, R13, 2, 64 ), ZMM(27)) // C(8 :15) = zmm27 + VMOVUPD(MEM(RCX, R13, 2, 128), ZMM(28)) // C(16:23) = zmm28 + + // row6 + VMULPD(ZMM(29), ZMM(29), ZMM(0)) // zmm29 *= alpha + VMULPD(ZMM(30), ZMM(30), ZMM(0)) // zmm30 *= alpha + VMULPD(ZMM(31), ZMM(31), ZMM(0)) // zmm31 *= alpha + /*scale by beta*/ + VADDPD(ZMM(29), ZMM(29), MEM(RCX, R14, 1 )) // zmm29 = C(0 :7 ) + zmm29 *alpha + VADDPD(ZMM(30), ZMM(30), MEM(RCX, R14, 1, 64 )) // zmm30 = C(8 :15) + zmm30 *alpha + VADDPD(ZMM(31), ZMM(31), MEM(RCX, R14, 1, 128)) // zmm31 = C(16:23) + zmm31 *alpha + /*store c*/ + VMOVUPD(MEM(RCX, R14, 1 ), ZMM(29)) // C(0 :7 ) = zmm29 + VMOVUPD(MEM(RCX, R14, 1, 64 ), ZMM(30)) // C(8 :15) = zmm30 + VMOVUPD(MEM(RCX, R14, 1, 128), ZMM(31)) // C(16:23) = zmm31 + JMP(END) + + LABEL(BETA_ZERO) + // row0 + VMULPD(ZMM( 8), ZMM( 8), ZMM(0)) // zmm8 *= alpha + VMULPD(ZMM( 9), ZMM( 9), ZMM(0)) // zmm9 *= alpha + VMULPD(ZMM(10), ZMM(10), ZMM(0)) // zmm10 *= alpha + /*store c*/ + VMOVUPD(MEM(RCX ), ZMM( 8)) // C(0 :7 ) = zmm8 + VMOVUPD(MEM(RCX, 64), ZMM( 9)) // C(7 :15) = zmm9 + VMOVUPD(MEM(RCX,128), ZMM(10)) // C(16:23) = zmm10 + + // row1 + VMULPD(ZMM(11), ZMM(11), ZMM(0)) // zmm11 *= alpha + VMULPD(ZMM(12), ZMM(12), ZMM(0)) // zmm12 *= alpha + VMULPD(ZMM(13), ZMM(13), ZMM(0)) // zmm13 *= alpha + /*store c*/ + VMOVUPD(MEM(RCX, R10, 1 ), ZMM(11)) // C(0 :7 ) = zmm11 + VMOVUPD(MEM(RCX, R10, 1, 64 ), ZMM(12)) // C(7 :15) = zmm12 + VMOVUPD(MEM(RCX, R10, 1, 128), ZMM(13)) // C(16:23) = zmm13 + + // row2 + VMULPD(ZMM(14), ZMM(14), ZMM(0)) // zmm14 *= alpha + VMULPD(ZMM(15), ZMM(15), ZMM(0)) // zmm15 *= alpha + VMULPD(ZMM(16), ZMM(16), ZMM(0)) // zmm16 *= alpha + /*store c*/ + VMOVUPD(MEM(RCX, R10, 2 ), ZMM(14)) // C(0 :7 ) = zmm14 + VMOVUPD(MEM(RCX, R10, 2, 64 ), ZMM(15)) // C(7 :15) = zmm15 + VMOVUPD(MEM(RCX, R10, 2, 128), ZMM(16)) // C(16:23) = zmm16 + + // row3 + VMULPD(ZMM(17), ZMM(17), ZMM(0)) // zmm17 *= alpha + VMULPD(ZMM(18), ZMM(18), ZMM(0)) // zmm18 *= alpha + VMULPD(ZMM(19), ZMM(19), ZMM(0)) // zmm19 *= alpha + /*store c*/ + VMOVUPD(MEM(RCX, R13, 1 ), ZMM(17)) // C(0 :7 ) = zmm17 + VMOVUPD(MEM(RCX, R13, 1, 64 ), ZMM(18)) // C(7 :15) = zmm18 + VMOVUPD(MEM(RCX, R13, 1, 128), ZMM(19)) // C(16:23) = zmm19 + + // row4 + VMULPD(ZMM(20), ZMM(20), ZMM(0)) // zmm20 *= alpha + VMULPD(ZMM(21), ZMM(21), ZMM(0)) // zmm21 *= alpha + VMULPD(ZMM(22), ZMM(22), ZMM(0)) // zmm22 *= alpha + /*store c*/ + VMOVUPD(MEM(RCX, R10, 4 ), ZMM(20)) // C(0 :7 ) = zmm20 + VMOVUPD(MEM(RCX, R10, 4, 64 ), ZMM(21)) // C(7 :15) = zmm21 + VMOVUPD(MEM(RCX, R10, 4, 128), ZMM(22)) // C(16:23) = zmm22 + + // row5 + VMULPD(ZMM(23), ZMM(23), ZMM(0)) // zmm23 *= alpha + VMULPD(ZMM(24), ZMM(24), ZMM(0)) // zmm24 *= alpha + VMULPD(ZMM(25), ZMM(25), ZMM(0)) // zmm25 *= alpha + /*store c*/ + VMOVUPD(MEM(RCX, RDX, 1 ), ZMM(23)) // C(0 :7 ) = zmm23 + VMOVUPD(MEM(RCX, RDX, 1, 64 ), ZMM(24)) // C(7 :15) = zmm24 + VMOVUPD(MEM(RCX, RDX, 1, 128), ZMM(25)) // C(16:23) = zmm25 + + // row6 + VMULPD(ZMM(26), ZMM(26), ZMM(0)) // zmm26 *= alpha + VMULPD(ZMM(27), ZMM(27), ZMM(0)) // zmm27 *= alpha + VMULPD(ZMM(28), ZMM(28), ZMM(0)) // zmm28 *= alpha + /*store c*/ + VMOVUPD(MEM(RCX, R13, 2 ), ZMM(26)) // C(0 :7 ) = zmm26 + VMOVUPD(MEM(RCX, R13, 2, 64 ), ZMM(27)) // C(7 :15) = zmm27 + VMOVUPD(MEM(RCX, R13, 2, 128), ZMM(28)) // C(16:23) = zmm28 + + // row6 + VMULPD(ZMM(29), ZMM(29), ZMM(0)) // zmm29 *= alpha + VMULPD(ZMM(30), ZMM(30), ZMM(0)) // zmm30 *= alpha + VMULPD(ZMM(31), ZMM(31), ZMM(0)) // zmm31 *= alpha + /*store c*/ + VMOVUPD(MEM(RCX, R14, 1 ), ZMM(29)) // C(0 :7 ) = zmm29 + VMOVUPD(MEM(RCX, R14, 1, 64 ), ZMM(30)) // C(7 :15) = zmm30 + VMOVUPD(MEM(RCX, R14, 1, 128), ZMM(31)) // C(16:23) = zmm31 + + JMP(END) + + LABEL(BETA_NZ_N1) // beta not zero or not 1 +#endif //BETA_OPTIMIZATION + VBROADCASTSD(ZMM(1), MEM(RBX)) // broadcast beta to zmm1 + + // row0 + VMULPD(ZMM( 8), ZMM( 8), ZMM(0)) // zmm8 *= alpha + VMULPD(ZMM( 9), ZMM( 9), ZMM(0)) // zmm9 *= alpha + VMULPD(ZMM(10), ZMM(10), ZMM(0)) // zmm10 *= alpha + /*scale by beta*/ + VFMADD231PD(ZMM( 8), ZMM(1), MEM(RCX)) // zmm8 = zmm1*C(0 :7 ) + zmm8, zmm8 = beta*C(0 :7 ) + zmm8 + VFMADD231PD(ZMM( 9), ZMM(1), MEM(RCX,64)) // zmm9 = zmm1*C(8 :15) + zmm9 + VFMADD231PD(ZMM(10), ZMM(1), MEM(RCX,128)) // zmm10 = zmm1*C(16:23) + zmm10 + /*store c*/ + VMOVUPD(MEM(RCX ), ZMM( 8)) // C(0 :7 ) = zmm8 + VMOVUPD(MEM(RCX, 64), ZMM( 9)) // C(7 :15) = zmm9 + VMOVUPD(MEM(RCX,128), ZMM(10)) // C(16:23) = zmm10 + + // row1 + VMULPD(ZMM(11), ZMM(11), ZMM(0)) // zmm11 *= alpha + VMULPD(ZMM(12), ZMM(12), ZMM(0)) // zmm12 *= alpha + VMULPD(ZMM(13), ZMM(13), ZMM(0)) // zmm13 *= alpha + /*scale by beta*/ + VFMADD231PD(ZMM(11), ZMM(1), MEM(RCX, R10, 1 )) // zmm11 = zmm1*C(0 :7 ) + zmm11 + VFMADD231PD(ZMM(12), ZMM(1), MEM(RCX, R10, 1, 64 )) // zmm12 = zmm1*C(8 :15) + zmm12 + VFMADD231PD(ZMM(13), ZMM(1), MEM(RCX, R10, 1, 128)) // zmm13 = zmm1*C(16:23) + zmm13 + /*store c*/ + VMOVUPD(MEM(RCX, R10, 1 ), ZMM(11)) // C(0 :7 ) = zmm11 + VMOVUPD(MEM(RCX, R10, 1, 64 ), ZMM(12)) // C(7 :15) = zmm12 + VMOVUPD(MEM(RCX, R10, 1, 128), ZMM(13)) // C(16:23) = zmm13 + + // row2 + VMULPD(ZMM(14), ZMM(14), ZMM(0)) // zmm14 *= alpha + VMULPD(ZMM(15), ZMM(15), ZMM(0)) // zmm15 *= alpha + VMULPD(ZMM(16), ZMM(16), ZMM(0)) // zmm16 *= alpha + /*scale by beta*/ + VFMADD231PD(ZMM(14), ZMM(1), MEM(RCX, R10, 2 )) // zmm14 = zmm1*C(0 :7 ) + zmm14 + VFMADD231PD(ZMM(15), ZMM(1), MEM(RCX, R10, 2, 64 )) // zmm15 = zmm1*C(8 :15) + zmm15 + VFMADD231PD(ZMM(16), ZMM(1), MEM(RCX, R10, 2, 128)) // zmm16 = zmm1*C(16:23) + zmm16 + /*store c*/ + VMOVUPD(MEM(RCX, R10, 2 ), ZMM(14)) // C(0 :7 ) = zmm14 + VMOVUPD(MEM(RCX, R10, 2, 64 ), ZMM(15)) // C(7 :15) = zmm15 + VMOVUPD(MEM(RCX, R10, 2, 128), ZMM(16)) // C(16:23) = zmm16 + + // row3 + VMULPD(ZMM(17), ZMM(17), ZMM(0)) // zmm17 *= alpha + VMULPD(ZMM(18), ZMM(18), ZMM(0)) // zmm18 *= alpha + VMULPD(ZMM(19), ZMM(19), ZMM(0)) // zmm19 *= alpha + /*scale by beta*/ + VFMADD231PD(ZMM(17), ZMM(1), MEM(RCX, R13, 1 )) // zmm17 = zmm1*C(0 :7 ) + zmm17 + VFMADD231PD(ZMM(18), ZMM(1), MEM(RCX, R13, 1, 64 )) // zmm18 = zmm1*C(8 :15) + zmm18 + VFMADD231PD(ZMM(19), ZMM(1), MEM(RCX, R13, 1, 128)) // zmm19 = zmm1*C(16:23) + zmm19 + /*store c*/ + VMOVUPD(MEM(RCX, R13, 1 ), ZMM(17)) // C(0 :7 ) = zmm17 + VMOVUPD(MEM(RCX, R13, 1, 64 ), ZMM(18)) // C(7 :15) = zmm18 + VMOVUPD(MEM(RCX, R13, 1, 128), ZMM(19)) // C(16:23) = zmm19 + + // row4 + VMULPD(ZMM(20), ZMM(20), ZMM(0)) // zmm20 *= alpha + VMULPD(ZMM(21), ZMM(21), ZMM(0)) // zmm21 *= alpha + VMULPD(ZMM(22), ZMM(22), ZMM(0)) // zmm22 *= alpha + /*scale by beta*/ + VFMADD231PD(ZMM(20), ZMM(1), MEM(RCX, R10, 4 )) // zmm20 = zmm1*C(0 :7 ) + zmm20 + VFMADD231PD(ZMM(21), ZMM(1), MEM(RCX, R10, 4, 64 )) // zmm21 = zmm1*C(8 :15) + zmm21 + VFMADD231PD(ZMM(22), ZMM(1), MEM(RCX, R10, 4, 128)) // zmm22 = zmm1*C(16:23) + zmm22 + /*store c*/ + VMOVUPD(MEM(RCX, R10, 4 ), ZMM(20)) // C(0 :7 ) = zmm20 + VMOVUPD(MEM(RCX, R10, 4, 64 ), ZMM(21)) // C(7 :15) = zmm21 + VMOVUPD(MEM(RCX, R10, 4, 128), ZMM(22)) // C(16:23) = zmm22 + + // row5 + VMULPD(ZMM(23), ZMM(23), ZMM(0)) // zmm23 *= alpha + VMULPD(ZMM(24), ZMM(24), ZMM(0)) // zmm24 *= alpha + VMULPD(ZMM(25), ZMM(25), ZMM(0)) // zmm25 *= alpha + /*scale by beta*/ + VFMADD231PD(ZMM(23), ZMM(1), MEM(RCX, RDX, 1 )) // zmm23 = zmm1*C(0 :7 ) + zmm23 + VFMADD231PD(ZMM(24), ZMM(1), MEM(RCX, RDX, 1, 64 )) // zmm24 = zmm1*C(8 :15) + zmm24 + VFMADD231PD(ZMM(25), ZMM(1), MEM(RCX, RDX, 1, 128)) // zmm25 = zmm1*C(16:23) + zmm25 + /*store c*/ + VMOVUPD(MEM(RCX, RDX, 1 ), ZMM(23)) // C(0 :7 ) = zmm23 + VMOVUPD(MEM(RCX, RDX, 1, 64 ), ZMM(24)) // C(7 :15) = zmm24 + VMOVUPD(MEM(RCX, RDX, 1, 128), ZMM(25)) // C(16:23) = zmm25 + + // row6 + VMULPD(ZMM(26), ZMM(26), ZMM(0)) // zmm26 *= alpha + VMULPD(ZMM(27), ZMM(27), ZMM(0)) // zmm27 *= alpha + VMULPD(ZMM(28), ZMM(28), ZMM(0)) // zmm28 *= alpha + /*scale by beta*/ + VFMADD231PD(ZMM(26), ZMM(1), MEM(RCX, R13, 2 )) // zmm26 = zmm1*C(0 :7 ) + zmm26 + VFMADD231PD(ZMM(27), ZMM(1), MEM(RCX, R13, 2, 64 )) // zmm27 = zmm1*C(8 :15) + zmm27 + VFMADD231PD(ZMM(28), ZMM(1), MEM(RCX, R13, 2, 128)) // zmm28 = zmm1*C(16:23) + zmm28 + /*store c*/ + VMOVUPD(MEM(RCX, R13, 2 ), ZMM(26)) // C(0 :7 ) = zmm26 + VMOVUPD(MEM(RCX, R13, 2, 64 ), ZMM(27)) // C(7 :15) = zmm27 + VMOVUPD(MEM(RCX, R13, 2, 128), ZMM(28)) // C(16:23) = zmm28 + + // row6 + VMULPD(ZMM(29), ZMM(29), ZMM(0)) // zmm29 *= alpha + VMULPD(ZMM(30), ZMM(30), ZMM(0)) // zmm20 *= alpha + VMULPD(ZMM(31), ZMM(31), ZMM(0)) // zmm31 *= alpha + /*scale by beta*/ + VFMADD231PD(ZMM(29), ZMM(1), MEM(RCX, R14, 1 )) // zmm29 = zmm1*C(0 :7 ) + zmm29 + VFMADD231PD(ZMM(30), ZMM(1), MEM(RCX, R14, 1, 64 )) // zmm30 = zmm1*C(8 :15) + zmm30 + VFMADD231PD(ZMM(31), ZMM(1), MEM(RCX, R14, 1, 128)) // zmm31 = zmm1*C(16:23) + zmm31 + /*store c*/ + VMOVUPD(MEM(RCX, R14, 1 ), ZMM(29)) // C(0 :7 ) = zmm29 + VMOVUPD(MEM(RCX, R14, 1, 64 ), ZMM(30)) // C(7 :15) = zmm30 + VMOVUPD(MEM(RCX, R14, 1, 128), ZMM(31)) // C(16:23) = zmm31 +#ifdef ENABLE_COL_GEN_STORE + JMP(END) + + LABEL(COLUPDATE) + // if C is col major stored + // R12 = cs_c + VBROADCASTSD(ZMM(1), MEM(RBX)) // broadcast beta to zmm1 + + LEA(R13, MEM(R12, R12, 2)) // cs_c*3 -> cs_c + cs_c*2 + LEA(RDX, MEM(R12, R12, 4)) // cs_c*5 -> cs_c + cs_c*4 + LEA(R14, MEM(R12, R13, 2)) // cs_c*7 -> cs_c + cs_c*3*2 + + VCOMISD(XMM(1), XMM(2)) + JE(COLSTORBZ) // jump is beta == 0 + // beta != 0 + + /* + * // registers pre tranpose + * _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ + * | zmm8 | zmm9 | zmm10 | + * | zmm11 | zmm12 | zmm13 | + * | zmm14 | zmm15 | zmm16 | + * | zmm17 | zmm18 | zmm19 | + * | zmm20 | zmm21 | zmm22 | + * | zmm23 | zmm24 | zmm25 | + * | zmm26 | zmm27 | zmm28 | + * | zmm29 | zmm30 | zmm31 | + * _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ + * + * + * // registers post transpose + * __________________________ + * | z z z z z z z z | + * | m m m m m m m m | + * | m m m m m m m m | + * | 8 1 1 1 2 2 2 2 | + * | 1 4 7 0 3 6 9 | + * | ________________________| + * | z z z z z z z z | + * | m m m m m m m m | + * | m m m m m m m m | + * | 9 1 1 1 2 2 2 3 | + * | 2 5 8 1 4 7 0 | + * | ________________________| + * | z z z z z z z z | + * | m m m m m m m m | + * | m m m m m m m m | + * | 1 1 1 1 2 2 2 3 | + * | 0 3 6 9 2 5 8 1 | + * | ________________________| + */ + + + TRANSPOSE_8X8( 8, 11, 14, 17, 20, 23, 26, 29) // registers + TRANSPOSE_8X8( 9, 12, 15, 18, 21, 24, 27, 30) + TRANSPOSE_8X8(10, 13, 16, 19, 22, 25, 28, 31) + VBROADCASTSD(ZMM(1), MEM(RBX)) // broadcast beta to zmm1 + VBROADCASTSD(ZMM(0), MEM(RAX)) // broadcast alpha into zmm0 + + UPDATE_C_COL_STORE( 8, 11, 14, 17, 20, 23, 26, 29) // scale by beta and store + UPDATE_C_COL_STORE( 9, 12, 15, 18, 21, 24, 27, 30) + UPDATE_C_COL_STORE(10, 13, 16, 19, 22, 25, 28, 31) + JMP(END) + + LABEL(COLSTORBZ) + // beta == 0 + + TRANSPOSE_8X8( 8, 11, 14, 17, 20, 23, 26, 29) + TRANSPOSE_8X8( 9, 12, 15, 18, 21, 24, 27, 30) + TRANSPOSE_8X8(10, 13, 16, 19, 22, 25, 28, 31) + VBROADCASTSD(ZMM(0), MEM(RAX)) // broadcast alpha into zmm0 + + UPDATE_C_COL_STORE_BZ( 8, 11, 14, 17, 20, 23, 26, 29) + UPDATE_C_COL_STORE_BZ( 9, 12, 15, 18, 21, 24, 27, 30) + UPDATE_C_COL_STORE_BZ(10, 13, 16, 19, 22, 25, 28, 31) + JMP(END) + + LABEL(SCATTERUPDATE) + // if C is general stride + VMULPD(ZMM( 8), ZMM( 8), ZMM(0)) // scale all registers by alpha + VMULPD(ZMM( 9), ZMM( 9), ZMM(0)) + VMULPD(ZMM(10), ZMM(10), ZMM(0)) + VMULPD(ZMM(11), ZMM(11), ZMM(0)) + VMULPD(ZMM(12), ZMM(12), ZMM(0)) + VMULPD(ZMM(13), ZMM(13), ZMM(0)) + VMULPD(ZMM(14), ZMM(14), ZMM(0)) + VMULPD(ZMM(15), ZMM(15), ZMM(0)) + VMULPD(ZMM(16), ZMM(16), ZMM(0)) + VMULPD(ZMM(17), ZMM(17), ZMM(0)) + VMULPD(ZMM(18), ZMM(18), ZMM(0)) + VMULPD(ZMM(19), ZMM(19), ZMM(0)) + VMULPD(ZMM(20), ZMM(20), ZMM(0)) + VMULPD(ZMM(21), ZMM(21), ZMM(0)) + VMULPD(ZMM(22), ZMM(22), ZMM(0)) + VMULPD(ZMM(23), ZMM(23), ZMM(0)) + VMULPD(ZMM(24), ZMM(24), ZMM(0)) + VMULPD(ZMM(25), ZMM(25), ZMM(0)) + VMULPD(ZMM(26), ZMM(26), ZMM(0)) + VMULPD(ZMM(27), ZMM(27), ZMM(0)) + VMULPD(ZMM(28), ZMM(28), ZMM(0)) + VMULPD(ZMM(29), ZMM(29), ZMM(0)) + VMULPD(ZMM(30), ZMM(30), ZMM(0)) + VMULPD(ZMM(31), ZMM(31), ZMM(0)) + + MOV(R13, VAR(offsetPtr)) // load pointer to the array containing + // offsets for scatter/gather + VPBROADCASTQ(ZMM(0), R12) // broadcast cs_c to zmm0 + VPMULLQ(ZMM(2), ZMM(0), MEM(R13)) // scale offsets array with cs_c + VPMULLQ(ZMM(3), ZMM(0), MEM(R13, 8*8)) + VPMULLQ(ZMM(4), ZMM(0), MEM(R13,16*8)) + VBROADCASTSD(ZMM(1), MEM(RBX)) // broadcast beta to zmm1 + + VCOMISD(XMM(1), XMM(2)) + JE(GENSTORBZ) // if beta == 0 jump + UPDATE_C_SCATTERED( 8, 9, 10) // scale by beta and store + UPDATE_C_SCATTERED(11, 12, 13) + UPDATE_C_SCATTERED(14, 15, 16) + UPDATE_C_SCATTERED(17, 18, 19) + UPDATE_C_SCATTERED(20, 21, 22) + UPDATE_C_SCATTERED(23, 24, 25) + UPDATE_C_SCATTERED(26, 27, 28) + UPDATE_C_SCATTERED(29, 30, 31) + JMP(END) + LABEL(GENSTORBZ) + UPDATE_C_SCATTERED_BZ( 8, 9, 10) + UPDATE_C_SCATTERED_BZ(11, 12, 13) + UPDATE_C_SCATTERED_BZ(14, 15, 16) + UPDATE_C_SCATTERED_BZ(17, 18, 19) + UPDATE_C_SCATTERED_BZ(20, 21, 22) + UPDATE_C_SCATTERED_BZ(23, 24, 25) + UPDATE_C_SCATTERED_BZ(26, 27, 28) + UPDATE_C_SCATTERED_BZ(29, 30, 31) +#endif + + LABEL(END) + + // VZEROUPPER() // slight improvement when K is small by removing vzeroupper + + END_ASM + ( + : // output operands + : // input operands + [k] "m" (k), + [a] "m" (a), + [b] "m" (b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + [offsetPtr] "m" (offsetPtr) + : // register clobber list + "rax", "rbx", "rcx", "r10", "r12", "r13", "r14", + "k0", "k1", "k2", "k3", "xmm1", "xmm2", + "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5", "zmm6", + "zmm7", "zmm8", "zmm9", "zmm10", "zmm11", "zmm12", "zmm13", + "zmm14", "zmm15", "zmm16", "zmm17", "zmm18", "zmm19", + "zmm20", "zmm21", "zmm22", "zmm23", "zmm24", "zmm25", + "zmm26", "zmm27", "zmm28", "zmm29", "zmm30", "zmm31", "memory" + ) +} + +/* C += A*B */ +#define UPDATE_C_BETA_1(R1, R2, R3) \ + VADDPD(ZMM(R1), ZMM(R1), MEM(RCX)) /* C += A*B */ \ + VADDPD(ZMM(R2), ZMM(R2), MEM(RCX, 64)) \ + VADDPD(ZMM(R3), ZMM(R3), MEM(RCX, 128)) \ + VMOVUPD(MEM(RCX ), ZMM(R1)) \ + VXORPD(ZMM(R1), ZMM(R1), ZMM(R1)) \ + VMOVUPD(MEM(RCX, 64), ZMM(R2)) \ + VXORPD(ZMM(R2), ZMM(R2), ZMM(R2)) \ + VMOVUPD(MEM(RCX, 128), ZMM(R3)) \ + VXORPD(ZMM(R3), ZMM(R3), ZMM(R3)) \ + LEA(RCX, MEM(RCX, R10, 1)) \ + +/* C = A*B - C */ +#define UPDATE_C_BETA_M1(R1, R2, R3) \ + VSUBPD(ZMM(R1), ZMM(R1), MEM(RCX)) \ + VSUBPD(ZMM(R2), ZMM(R2), MEM(RCX, 64)) \ + VSUBPD(ZMM(R3), ZMM(R3), MEM(RCX, 128)) \ + VMOVUPD(MEM(RCX ), ZMM(R1)) \ + VXORPD(ZMM(R1), ZMM(R1), ZMM(R1)) \ + VMOVUPD(MEM(RCX, 64), ZMM(R2)) \ + VXORPD(ZMM(R2), ZMM(R2), ZMM(R2)) \ + VMOVUPD(MEM(RCX, 128), ZMM(R3)) \ + VXORPD(ZMM(R3), ZMM(R3), ZMM(R3)) \ + LEA(RCX, MEM(RCX, R10, 1)) \ + +/* C = A*B */ +#define UPDATE_C_BETA_0(R1, R2, R3) \ + VMOVUPD(MEM(RCX ), ZMM(R1)) \ + VXORPD(ZMM(R1), ZMM(R1), ZMM(R1)) \ + VMOVUPD(MEM(RCX, 64), ZMM(R2)) \ + VXORPD(ZMM(R2), ZMM(R2), ZMM(R2)) \ + VMOVUPD(MEM(RCX, 128), ZMM(R3)) \ + VXORPD(ZMM(R3), ZMM(R3), ZMM(R3)) \ + LEA(RCX, MEM(RCX, R10, 1)) \ + +/* C = (beta*c) + (A*B) */ +#define UPDATE_C_BETA_N(R1, R2, R3) \ + VFMADD231PD(ZMM(R1), ZMM(1), MEM(RCX)) \ + VFMADD231PD(ZMM(R2), ZMM(1), MEM(RCX,64)) \ + VFMADD231PD(ZMM(R3), ZMM(1), MEM(RCX,128)) \ + \ + VMOVUPD(MEM(RCX ), ZMM(R1)) \ + VXORPD(ZMM(R1), ZMM(R1), ZMM(R1)) \ + VMOVUPD(MEM(RCX, 64), ZMM(R2)) \ + VXORPD(ZMM(R2), ZMM(R2), ZMM(R2)) \ + VMOVUPD(MEM(RCX,128), ZMM(R3)) \ + VXORPD(ZMM(R3), ZMM(R3), ZMM(R3)) \ + LEA(RCX, MEM(RCX, R10, 1)) \ + +#define PRE_K_LOOP() \ + const int64_t n = n0; \ + const int64_t m = m0; \ + const int64_t k = k0; \ + const int64_t ldc = ldc0; \ + BEGIN_ASM() \ + \ + MOV(RDI, VAR(n)) /* load N into RDI */ \ + MOV(RSI, VAR(m)) /* load M into RSI */ \ + MOV(RDX, VAR(k)) /* load K into RDX */ \ + MOV(RCX, VAR(c)) /* load C macro panel pointer into RCX*/ \ + MOV(R8 , VAR(a)) /* load A macro panel pointer into R8 */ \ + MOV(R9 , VAR(b)) /* load B macro panel pointer into R9 */ \ + MOV(R10, VAR(ldc)) /* load ldc into R10*/ \ + \ + SAL(R10, IMM(3)) /* ldc *= 8 */ \ + SAR(RSI, IMM(3)) /* m_iter = M/8 */ \ + \ + ZERO_REGISTERS() /* zero accumulation registers */ \ + \ + MOV(VAR(m), RSI) /* backup m_iter into stack */ \ + MOV(R15, R8) /* backup A macro panel pointer to R15 */ \ + MOV(RBP, RCX) /* backup C macro panel pointer to RBP */ \ + \ + CMP(RDI, IMM(0)) /* check if m_iter is zero */ \ + JLE(ENDJR) /* JMP to endjr if m_iter <= 0*/ \ + \ + LOOP_ALIGN \ + LABEL(LOOPJR) /* JR loop */ \ + \ + MOV(R8, R15) /* restore A macro panel pointer */ \ + MOV(RSI, VAR(m)) /* copy m_iter to RSI */ \ + MOV(RCX, RBP) /* restore pointer to C macro panel pointer */\ + TEST(RSI, RSI) \ + \ + JZ(ENDIR) /* Jump to ENDIR if m_iter(RSI) == 0*/ \ + LOOP_ALIGN \ + LABEL(LOOPIR) \ + MOV(RAX, R8) /* Move A micro panel pointer to RAX */ \ + MOV(RBX, R9) /* Move B micro panel pointer to RBX */ \ + LEA(R12, MEM(RCX, 63)) /* calculate c_prefetch pointer */ + +#define POST_K_LOOP() \ + LABEL(END_MICRO_KER) \ + \ + MOV(R13, RDX) /* move k_iter into R13 */ \ + IMUL(R13, IMM(8)) /* k_iter *= 8 */ \ + LEA(R8, MEM(R8, R13, 8)) /* a_next_upanel = A + (k*8) */ \ + \ + DEC(RSI) /* decrement m_iter */ \ + JNZ(LOOPIR) \ + \ + LABEL(ENDIR) \ + \ + MOV(R14, RDX) /* move k_iter into R14 */ \ + IMUL(R14, IMM(24)) /* k_iter *= 24 */ \ + LEA(R9, MEM(R9, R14, 8)) /* b_next_upanel = B + (k*24) */ \ + LEA(RBP, MEM(RBP, 24*8)) /* c_next_upanel = C + (24*8) */ \ + SUB(RDI, IMM(24)) /* subtract NR(24) from N */ \ + JNZ(LOOPJR) \ + \ + LABEL(ENDJR) \ + \ + END_ASM \ + ( \ + :: \ + [n] "m" (n), \ + [m] "m" (m), \ + [k] "m" (k), \ + [c] "m" (c), \ + [a] "m" (a), \ + [b] "m" (b), \ + [beta] "m" (beta), \ + [ldc] "m" (ldc) \ + : \ + "rax", "rbp", "rbx", "rcx", "rdi", "rsi", "r8", "r9", \ + "r10", "r12", "r13", "r14", "r15", "xmm1", "xmm2",\ + "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5", "zmm6", \ + "zmm7", "zmm8", "zmm9", "zmm10", "zmm11", "zmm12", "zmm13",\ + "zmm14", "zmm15", "zmm16", "zmm17", "zmm18", "zmm19", \ + "zmm20", "zmm21", "zmm22", "zmm23", "zmm24", "zmm25", \ + "zmm26", "zmm27", "zmm28", "zmm29", "zmm30", "zmm31", "memory"\ + ) \ + + +/* + Macro kernel for C = A*B (beta = 0) + Only Row major stored C is supported. +*/ +BLIS_INLINE void bli_dgemm_avx512_asm_8x24_macro_kernel_b0 +( + dim_t n0, + dim_t m0, + dim_t k0, + double* c, + double* a, + double* b, + dim_t ldc0, + double* beta +) +{ + PRE_K_LOOP() + K_LOOP() + LABEL(POSTACCUM) + UPDATE_C_BETA_0( 8, 9, 10) + UPDATE_C_BETA_0(11, 12, 13) + UPDATE_C_BETA_0(14, 15, 16) + UPDATE_C_BETA_0(17, 18, 19) + UPDATE_C_BETA_0(20, 21, 22) + UPDATE_C_BETA_0(23, 24, 25) + UPDATE_C_BETA_0(26, 27, 28) + UPDATE_C_BETA_0(29, 30, 31) + POST_K_LOOP() + +} + +/* + Macro kernel for C = C + (A*B) (beta = 1) + Only Row major stored C is supported. +*/ +BLIS_INLINE void bli_dgemm_avx512_asm_8x24_macro_kernel_b1 +( + dim_t n0, + dim_t m0, + dim_t k0, + double* c, + double* a, + double* b, + dim_t ldc0, + double* beta +) +{ + PRE_K_LOOP() + K_LOOP() + LABEL(POSTACCUM) + UPDATE_C_BETA_1( 8, 9, 10) + UPDATE_C_BETA_1(11, 12, 13) + UPDATE_C_BETA_1(14, 15, 16) + UPDATE_C_BETA_1(17, 18, 19) + UPDATE_C_BETA_1(20, 21, 22) + UPDATE_C_BETA_1(23, 24, 25) + UPDATE_C_BETA_1(26, 27, 28) + UPDATE_C_BETA_1(29, 30, 31) + POST_K_LOOP() + +} + +/* + Macro kernel for C = (A*B) - C (beta = 1) + Only Row major stored C is supported. +*/ +BLIS_INLINE void bli_dgemm_avx512_asm_8x24_macro_kernel_bm1 +( + dim_t n0, + dim_t m0, + dim_t k0, + double* c, + double* a, + double* b, + dim_t ldc0, + double* beta +) +{ + PRE_K_LOOP() + K_LOOP() + LABEL(POSTACCUM) + MOV(RBX, VAR(beta)) + VBROADCASTSD(ZMM(1), MEM(RBX)) + UPDATE_C_BETA_M1( 8, 9, 10) + UPDATE_C_BETA_M1(11, 12, 13) + UPDATE_C_BETA_M1(14, 15, 16) + UPDATE_C_BETA_M1(17, 18, 19) + UPDATE_C_BETA_M1(20, 21, 22) + UPDATE_C_BETA_M1(23, 24, 25) + UPDATE_C_BETA_M1(26, 27, 28) + UPDATE_C_BETA_M1(29, 30, 31) + POST_K_LOOP() + +} + +/* + Macro kernel for C = (beta*C) + (A*B) + Only Row major stored C is supported. +*/ +BLIS_INLINE void bli_dgemm_avx512_asm_8x24_macro_kernel_bn +( + dim_t n0, + dim_t m0, + dim_t k0, + double* c, + double* a, + double* b, + dim_t ldc0, + double* beta +) +{ + PRE_K_LOOP() + K_LOOP() + LABEL(POSTACCUM) + MOV(RBX, VAR(beta)) + VBROADCASTSD(ZMM(1), MEM(RBX)) + UPDATE_C_BETA_N( 8, 9, 10) + UPDATE_C_BETA_N(11, 12, 13) + UPDATE_C_BETA_N(14, 15, 16) + UPDATE_C_BETA_N(17, 18, 19) + UPDATE_C_BETA_N(20, 21, 22) + UPDATE_C_BETA_N(23, 24, 25) + UPDATE_C_BETA_N(26, 27, 28) + UPDATE_C_BETA_N(29, 30, 31) + POST_K_LOOP() + +} + +/* + DGEMM 8x24 Macro kernel + MR = 8, NR = 24 + Only row major stored C is supported by this kernel. + Alpha scaling is not supported. +*/ +void bli_dgemm_avx512_asm_8x24_macro_kernel +( + dim_t n, + dim_t m, + dim_t k, + double* c, + double* a, + double* b, + dim_t ldc, + double* beta +) +{ + if(*(double*)beta == 1) + { + bli_dgemm_avx512_asm_8x24_macro_kernel_b1 + ( + n, m, k, c, a, b, ldc, beta + ); + } + else if(*(double*)beta == -1) + { + bli_dgemm_avx512_asm_8x24_macro_kernel_bm1 + ( + n, m, k, c, a, b, ldc, beta + ); + } + else if (*(double*)beta == 0) + { + bli_dgemm_avx512_asm_8x24_macro_kernel_b0 + ( + n, m, k, c, a, b, ldc, beta + ); + } + else + { + bli_dgemm_avx512_asm_8x24_macro_kernel_bn + ( + n, m, k, c, a, b, ldc, beta + ); + } +} diff --git a/kernels/zen5/3/sup/bli_dgemmsup_rv_zen5_asm_24x8m.c b/kernels/zen5/3/sup/bli_dgemmsup_rv_zen5_asm_24x8m.c new file mode 100644 index 0000000000..27e07e8998 --- /dev/null +++ b/kernels/zen5/3/sup/bli_dgemmsup_rv_zen5_asm_24x8m.c @@ -0,0 +1,9795 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" +#define TAIL_NITER 3 + +/** + * Shuffle 2 double-precision elements selected by imm8 from S1 and S2, + * and store the results in D1 + * S1 : 1 9 3 11 5 13 7 15 + * S2 : 2 10 4 12 6 14 8 16 + * D1 : 1 9 5 13 2 10 6 14 + * D2 : 3 11 7 15 4 12 8 16 +*/ +#define SHUFFLE_DATA(S1, S2, D1, D2, S3, S4, D3, D4) \ +\ + VSHUFF64X2(IMM(0x88), ZMM(S1), ZMM(S2), ZMM(D1)) \ + VSHUFF64X2(IMM(0xDD), ZMM(S1), ZMM(S2), ZMM(D2)) \ + VSHUFF64X2(IMM(0x88), ZMM(S3), ZMM(S4), ZMM(D3)) \ + VSHUFF64X2(IMM(0xDD), ZMM(S3), ZMM(S4), ZMM(D4)) \ + +/** + * Unpacks and interleave low half and high half of each + * 128-bit lane in S1 and S2 and store into D1 and D2 + * respectively. + * S1 : 1 2 3 4 5 6 7 8 + * S2 : 9 10 11 12 13 14 15 16 + * D1 : 1 9 3 11 5 13 7 15 + * D2 : 2 10 4 12 6 14 8 16 +*/ +#define UNPACK_LO_HIGH(S1, S2, D1, D2, S3, S4, D3, D4) \ +\ + vunpcklpd( zmm(S1), zmm(S2), zmm(D1)) \ + vunpckhpd( zmm(S1), zmm(S2), zmm(D2)) \ + vunpcklpd( zmm(S3), zmm(S4), zmm(D3)) \ + vunpckhpd( zmm(S3), zmm(S4), zmm(D4)) + +/** + * Loads elements from C row, Scales it with Beta + * and adds FMA result to it. + * Stores back the C row. +*/ +#define UPDATE_C \ +\ + vfmadd231pd( mem(rcx),zmm31,zmm0 ) /*Scale by Beta and add it to fma result*/ \ + vmovupd( zmm0, (rcx) ) /*Stores back to C*/\ +\ + vfmadd231pd( mem(rcx, rsi, 1),zmm31,zmm4 ) \ + vmovupd( zmm4, (rcx, rsi, 1) )\ +\ + vfmadd231pd( mem(rcx, rsi, 2),zmm31,zmm2 ) \ + vmovupd( zmm2, (rcx, rsi, 2) )\ +\ + vfmadd231pd( mem(rcx, r12, 1),zmm31,zmm6 ) \ + vmovupd( zmm6, (rcx, r12, 1) )\ +\ + vfmadd231pd( mem(rcx, rsi, 4),zmm31,zmm1 ) \ + vmovupd( zmm1, (rcx, rsi, 4) )\ +\ + vfmadd231pd( mem(rcx, r13, 1),zmm31,zmm5 ) \ + vmovupd( zmm5, (rcx, r13, 1) )\ +\ + vfmadd231pd( mem(rcx, r12, 2),zmm31,zmm3 ) \ + vmovupd( zmm3, (rcx, r12, 2) )\ +\ + vfmadd231pd( mem(rcx, rdx, 1),zmm31,zmm8 ) \ + vmovupd( zmm8, (rcx, rdx, 1) )\ + add(r14, rcx) + + +/** + * stores FMA result to C. +*/ +#define UPDATE_C_BZ \ +\ + vmovupd( zmm0, (rcx) ) /*Stores back to C*/ \ +\ + vmovupd( zmm4, (rcx, rsi, 1) ) \ +\ + vmovupd( zmm2, (rcx, rsi, 2) ) \ +\ + vmovupd( zmm6, (rcx, r12, 1) ) \ +\ + vmovupd( zmm1, (rcx, rsi, 4) ) \ +\ + vmovupd( zmm5, (rcx, r13, 1) ) \ +\ + vmovupd( zmm3, (rcx, r12, 2) ) \ +\ + vmovupd( zmm8, (rcx, rdx, 1) ) \ + add(r14, rcx) + +/** + * Loads elements from C row only if correspondnig bits in + * mask register is set, Scales it with Beta and adds FMA result to it + * Stores back the C row. +*/ +#define UPDATE_MASKED_C \ +\ + vmovupd( mem(rcx), zmm30 MASK_KZ(2) ) \ + vfmadd231pd( zmm31,zmm30,zmm0 ) \ +\ + vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_KZ(2) ) \ + vfmadd231pd( zmm31,zmm10,zmm4 ) \ +\ + vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_KZ(2) ) \ + vfmadd231pd( zmm31,zmm12,zmm2 ) \ +\ + vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_KZ(2) ) \ + vfmadd231pd( zmm31,zmm16,zmm6 ) \ +\ + vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_KZ(2) ) \ + vfmadd231pd( zmm31,zmm14,zmm1 ) \ +\ + vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_KZ(2) ) \ + vfmadd231pd( zmm31,zmm18,zmm5 ) \ +\ + vmovupd( mem(rcx, r12, 2, 0), zmm10 MASK_KZ(2) ) \ + vfmadd231pd( zmm31,zmm10,zmm3 ) \ +\ + vmovupd( mem(rcx, rdx, 1, 0), zmm12 MASK_KZ(2) ) \ + vfmadd231pd( zmm31,zmm12,zmm8 ) \ +\ + vmovupd( zmm0, (rcx) MASK_(k(2))) /*Stores back to C*/\ + vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(2)))\ + vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(2)))\ + vmovupd( zmm6, (rcx, r12, 1) MASK_(k(2)))\ + vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(2)))\ + vmovupd( zmm5, (rcx, r13, 1) MASK_(k(2)))\ + vmovupd( zmm3, (rcx, r12, 2) MASK_(k(2)))\ + vmovupd( zmm8, (rcx, rdx, 1) MASK_(k(2)))\ + add(r14, rcx) + +/** + * mask register is set, stores FMA result to C. +*/ +#define UPDATE_MASKED_C_BZ \ +\ + vmovupd( zmm0, mem(rcx) MASK_(k(2))) \ +\ + vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(2))) \ +\ + vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(2)) ) \ +\ + vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(2)) ) \ +\ + vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(2))) \ +\ + vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(2))) \ +\ + vmovupd( zmm3, mem(rcx, r12, 2) MASK_(k(2))) \ +\ + vmovupd( zmm8, mem(rcx, rdx, 1) MASK_(k(2))) \ + add(r14, rcx) + +/* These kernels Assume that A matrix needs to be in col-major order + * B matrix can be col/row-major + * C matrix can be col/row-major + * Prefetch for C is done assuming that C is col-stored. + * Prefetch of B is done assuming that the matrix is col-stored. + * Prefetch for B and C matrices when row-stored is yet to be added. + * Prefetch of A matrix is not done in edge-case kernels. + */ + +void bli_dgemmsup_rv_zen5_asm_24x8m +( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + double *abuf = a; + double *bbuf = b; + double *cbuf = c; + + // n0 is actually n_left which is calculated at JR loop. + uint64_t n_left = (uint64_t)n0 % 8; + + // First check whether this is a edge case in the n dimension. If so, + // dispatch other nx? kernels, as needed + if( n_left ) + { + dgemmsup_ker_ft ker_fps[8] = + { + NULL, + bli_dgemmsup_rv_zen5_asm_24x1m, + bli_dgemmsup_rv_zen5_asm_24x2m, + bli_dgemmsup_rv_zen5_asm_24x3m, + bli_dgemmsup_rv_zen5_asm_24x4m, + bli_dgemmsup_rv_zen5_asm_24x5m, + bli_dgemmsup_rv_zen5_asm_24x6m, + bli_dgemmsup_rv_zen5_asm_24x7m, + }; + + dgemmsup_ker_ft ker_fp = ker_fps[ n_left ]; + + ker_fp + ( + conja, conjb, m0, n_left, k0, + alpha, abuf, rs_a0, cs_a0, bbuf, rs_b0, cs_b0, + beta, cbuf, rs_c0, cs_c0, data, cntx + ); + + return; + } + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t m_iter = (uint64_t)m0 / 24; + uint64_t m_left = (uint64_t)m0 % 24; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_a8 = ps_a * sizeof( double ); + + uint64_t k_iter = (uint64_t)k0 / 8; + uint64_t k_left = (uint64_t)k0 % 8; + + if ( m_iter == 0 ) goto consider_edge_cases; + + /* For one iteration of this loop, a block of MRxNR is computed + * This loop moves along m-dimension of c matrix with steps of MR*rs_c. + */ + for(dim_t m=0; m < m_iter; m++) + { + + a = abuf + m * ps_a ; // Move to next MRXKC in MCXKC (where MC>=MR) + b = bbuf; //Same KCXNR is used across different MRXKC in MCXKC + c = cbuf + m * rs_c * 24; // Move to next MRxNR in MCxNR (where MC >= MR) + + // ------------------------------------------------------------------------- + begin_asm() + + mov(var(a), rax) // load address of a + mov(var(cs_a), r10) // load cs_a + mov(var(b), rbx) // load address of b + mov(var(rs_b), r8) // load rs_b + mov(var(cs_b), r9) // load cs_b + mov(var(c), rcx) // load address of c + mov(var(cs_c), rdi) // load cs_c + lea(mem(, r8, 8), r8) // rs_b *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_b *= sizeof(double) + lea(mem(, r10, 8), r10) // cs_a *= sizeof(double) + lea(mem(, rdi, 8), rdi) // cs_c *= sizeof(double) + lea(mem(r9, r9, 2 ), r13) // r13 = 3*cs_b + // if n > 4, a second pointer(r12) which points to rbx + 4*cs_b + //is also used to traverse B matrix + lea(mem(rbx, r9, 4), r12) // r12 = rbx + 4*cs_b + lea(mem(rcx, 7*8), rdx) // C for prefetching + mov(r10, r14) // col stride of A + lea(mem(rax, r14, 2, 7*8), r14) // r14 = rax + 2*cs_a(A for prefetching) + lea(mem(rbx, r8, 8, 7*8), r11) // r11 = rbx + 8*rs_b(B for prefetching) + // if n > 4, a second pointer which point to r11 + 4*cs_b + //is also used to prefetch from B matrix + lea(mem(r11, r9, 4), r15) // r15 = r11 + 4* cs_b(B for prefetching) + + /* Register usage: zmm0-5 are used to load A matrix + * zmm6-29 are used for accumulation + * zmm30-31 are used for broadcasting B matrix + */ + + // zero out all accumulation registers + vxorpd(zmm6, zmm6, zmm6) + vxorpd(zmm7, zmm7, zmm7) + vxorpd(zmm28, zmm28, zmm28) + vxorpd(zmm8, zmm8, zmm8) + vxorpd(zmm9, zmm9, zmm9) + vxorpd(zmm29, zmm29, zmm29) + vxorpd(zmm10, zmm10, zmm10) + vxorpd(zmm11, zmm11, zmm11) + vxorpd(zmm26, zmm26, zmm26) + vxorpd(zmm12, zmm12, zmm12) + vxorpd(zmm13, zmm13, zmm13) + vxorpd(zmm27,zmm27, zmm27) + vxorpd(zmm14, zmm14, zmm14) + vxorpd(zmm15, zmm15, zmm15) + vxorpd(zmm24, zmm24, zmm24) + vxorpd(zmm16, zmm16, zmm16) + vxorpd(zmm17, zmm17, zmm17) + vxorpd(zmm25, zmm25, zmm25) + vxorpd(zmm18, zmm18, zmm18) + vxorpd(zmm19, zmm19, zmm19) + vxorpd(zmm22, zmm22, zmm22) + vxorpd(zmm20, zmm20, zmm20) + vxorpd(zmm21,zmm21, zmm21) + vxorpd(zmm23, zmm23, zmm23) + + // K is unrolled by 8 to facilitate prefetch of B + // Assuming B to be col-stored, for each iteration of K, + //one cacheline of B_next is prefetched where b_next = b + (unroll)*rs_b + label(.DLOOPKITER) // main loop + mov(var(k_iter), rsi) // i = k_iter + sub(imm( 8+TAIL_NITER), rsi) // i -= NR + TAIL_NITER + jle(.PREFETCHLOOP) // jump if i <= 0 + + label(.LOOP1) + + // ---------------------------------- iteration 1 + + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm14 ) + vfmadd231pd( zmm1,zmm30,zmm15 ) + vfmadd231pd( zmm2,zmm30,zmm24 ) + vbroadcastsd( mem(r12,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm16 ) + vfmadd231pd( zmm1,zmm31,zmm17 ) + vfmadd231pd( zmm2,zmm31,zmm25 ) + vbroadcastsd( mem(r12,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm18 ) + vfmadd231pd( zmm1,zmm30,zmm19 ) + vfmadd231pd( zmm2,zmm30,zmm22 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm0,zmm31,zmm20 ) + vfmadd231pd( zmm1,zmm31,zmm21 ) + vfmadd231pd( zmm2,zmm31,zmm23 ) + + // ---------------------------------- iteration 2 + + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11,r9,1) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm5,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm5,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm5,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm5,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm24 ) + vbroadcastsd( mem(r12,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm4,zmm31,zmm17 ) + vfmadd231pd( zmm5,zmm31,zmm25 ) + vbroadcastsd( mem(r12,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm4,zmm30,zmm19 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm3,zmm31,zmm20 ) + vfmadd231pd( zmm4,zmm31,zmm21 ) + vfmadd231pd( zmm5,zmm31,zmm23 ) + + // ---------------------------------- iteration 3 + + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11,r9,2) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm14 ) + vfmadd231pd( zmm1,zmm30,zmm15 ) + vfmadd231pd( zmm2,zmm30,zmm24 ) + vbroadcastsd( mem(r12,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm16 ) + vfmadd231pd( zmm1,zmm31,zmm17 ) + vfmadd231pd( zmm2,zmm31,zmm25 ) + vbroadcastsd( mem(r12,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm18 ) + vfmadd231pd( zmm1,zmm30,zmm19 ) + vfmadd231pd( zmm2,zmm30,zmm22 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm0,zmm31,zmm20 ) + vfmadd231pd( zmm1,zmm31,zmm21 ) + vfmadd231pd( zmm2,zmm31,zmm23 ) + + // ---------------------------------- iteration 4 + + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11,r13,1) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm5,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm5,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm5,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm5,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm24 ) + vbroadcastsd( mem(r12,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm4,zmm31,zmm17 ) + vfmadd231pd( zmm5,zmm31,zmm25 ) + vbroadcastsd( mem(r12,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm4,zmm30,zmm19 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm3,zmm31,zmm20 ) + vfmadd231pd( zmm4,zmm31,zmm21 ) + vfmadd231pd( zmm5,zmm31,zmm23 ) + + // ---------------------------------- iteration 5 + + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r15) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm14 ) + vfmadd231pd( zmm1,zmm30,zmm15 ) + vfmadd231pd( zmm2,zmm30,zmm24 ) + vbroadcastsd( mem(r12,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm16 ) + vfmadd231pd( zmm1,zmm31,zmm17 ) + vfmadd231pd( zmm2,zmm31,zmm25 ) + vbroadcastsd( mem(r12,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm18 ) + vfmadd231pd( zmm1,zmm30,zmm19 ) + vfmadd231pd( zmm2,zmm30,zmm22 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm0,zmm31,zmm20 ) + vfmadd231pd( zmm1,zmm31,zmm21 ) + vfmadd231pd( zmm2,zmm31,zmm23 ) + + // ---------------------------------- iteration 6 + + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r15,r9,1) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm5,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm5,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm5,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm5,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm24 ) + vbroadcastsd( mem(r12,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm4,zmm31,zmm17 ) + vfmadd231pd( zmm5,zmm31,zmm25 ) + vbroadcastsd( mem(r12,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm4,zmm30,zmm19 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm3,zmm31,zmm20 ) + vfmadd231pd( zmm4,zmm31,zmm21 ) + vfmadd231pd( zmm5,zmm31,zmm23 ) + + // ---------------------------------- iteration 7 + + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r15,r9,2) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm14 ) + vfmadd231pd( zmm1,zmm30,zmm15 ) + vfmadd231pd( zmm2,zmm30,zmm24 ) + vbroadcastsd( mem(r12,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm16 ) + vfmadd231pd( zmm1,zmm31,zmm17 ) + vfmadd231pd( zmm2,zmm31,zmm25 ) + vbroadcastsd( mem(r12,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm18 ) + vfmadd231pd( zmm1,zmm30,zmm19 ) + vfmadd231pd( zmm2,zmm30,zmm22 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm0,zmm31,zmm20 ) + vfmadd231pd( zmm1,zmm31,zmm21 ) + vfmadd231pd( zmm2,zmm31,zmm23 ) + + // ---------------------------------- iteration 8 + + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r15,r13,1) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm5,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm5,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm5,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm5,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm24 ) + vbroadcastsd( mem(r12,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm4,zmm31,zmm17 ) + vfmadd231pd( zmm5,zmm31,zmm25 ) + vbroadcastsd( mem(r12,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm4,zmm30,zmm19 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm3,zmm31,zmm20 ) + vfmadd231pd( zmm4,zmm31,zmm21 ) + vfmadd231pd( zmm5,zmm31,zmm23 ) + lea(mem(r11,r8,8), r11) // b_next += 8*rs_b + lea(mem(r15,r8,8), r15) // second pointer to b_next += 8*rs_b + dec(rsi) // i -= 1 + jnz(.LOOP1) // iterate again if i != 0. + + label(.PREFETCHLOOP) + add(imm(8), rsi) // i += NR + jle(.TAILITER) // jump if i <= 0. + + label(.LOOP2) + + // ---------------------------------- iteration 1 + prefetchw0( mem(rdx)) // prefetch C + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm14 ) + vfmadd231pd( zmm1,zmm30,zmm15 ) + vfmadd231pd( zmm2,zmm30,zmm24 ) + vbroadcastsd( mem(r12,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm16 ) + vfmadd231pd( zmm1,zmm31,zmm17 ) + vfmadd231pd( zmm2,zmm31,zmm25 ) + vbroadcastsd( mem(r12,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm18 ) + vfmadd231pd( zmm1,zmm30,zmm19 ) + vfmadd231pd( zmm2,zmm30,zmm22 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm0,zmm31,zmm20 ) + vfmadd231pd( zmm1,zmm31,zmm21 ) + vfmadd231pd( zmm2,zmm31,zmm23 ) + + // ---------------------------------- iteration 2 + prefetchw0( mem(rdx, 64)) // prefetch C + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11,r9,1) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm5,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm5,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm5,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm5,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm24 ) + vbroadcastsd( mem(r12,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm4,zmm31,zmm17 ) + vfmadd231pd( zmm5,zmm31,zmm25 ) + vbroadcastsd( mem(r12,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm4,zmm30,zmm19 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm3,zmm31,zmm20 ) + vfmadd231pd( zmm4,zmm31,zmm21 ) + vfmadd231pd( zmm5,zmm31,zmm23 ) + + // ---------------------------------- iteration 3 + prefetchw0( mem(rdx, 128)) // prefetch C + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11,r9,2) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm14 ) + vfmadd231pd( zmm1,zmm30,zmm15 ) + vfmadd231pd( zmm2,zmm30,zmm24 ) + vbroadcastsd( mem(r12,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm16 ) + vfmadd231pd( zmm1,zmm31,zmm17 ) + vfmadd231pd( zmm2,zmm31,zmm25 ) + vbroadcastsd( mem(r12,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm18 ) + vfmadd231pd( zmm1,zmm30,zmm19 ) + vfmadd231pd( zmm2,zmm30,zmm22 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm0,zmm31,zmm20 ) + vfmadd231pd( zmm1,zmm31,zmm21 ) + vfmadd231pd( zmm2,zmm31,zmm23 ) + + // ---------------------------------- iteration 4 + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11,r13,1) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm5,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm5,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm5,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm5,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm24 ) + vbroadcastsd( mem(r12,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm4,zmm31,zmm17 ) + vfmadd231pd( zmm5,zmm31,zmm25 ) + vbroadcastsd( mem(r12,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm4,zmm30,zmm19 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm3,zmm31,zmm20 ) + vfmadd231pd( zmm4,zmm31,zmm21 ) + vfmadd231pd( zmm5,zmm31,zmm23 ) + + // ---------------------------------- iteration 5 + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r15) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm14 ) + vfmadd231pd( zmm1,zmm30,zmm15 ) + vfmadd231pd( zmm2,zmm30,zmm24 ) + vbroadcastsd( mem(r12,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm16 ) + vfmadd231pd( zmm1,zmm31,zmm17 ) + vfmadd231pd( zmm2,zmm31,zmm25 ) + vbroadcastsd( mem(r12,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm18 ) + vfmadd231pd( zmm1,zmm30,zmm19 ) + vfmadd231pd( zmm2,zmm30,zmm22 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm0,zmm31,zmm20 ) + vfmadd231pd( zmm1,zmm31,zmm21 ) + vfmadd231pd( zmm2,zmm31,zmm23 ) + + // ---------------------------------- iteration 6 + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r15,r9,1) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm5,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm5,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm5,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm5,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm24 ) + vbroadcastsd( mem(r12,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm4,zmm31,zmm17 ) + vfmadd231pd( zmm5,zmm31,zmm25 ) + vbroadcastsd( mem(r12,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm4,zmm30,zmm19 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm3,zmm31,zmm20 ) + vfmadd231pd( zmm4,zmm31,zmm21 ) + vfmadd231pd( zmm5,zmm31,zmm23 ) + + // ---------------------------------- iteration 7 + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r15,r9,2) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm14 ) + vfmadd231pd( zmm1,zmm30,zmm15 ) + vfmadd231pd( zmm2,zmm30,zmm24 ) + vbroadcastsd( mem(r12,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm16 ) + vfmadd231pd( zmm1,zmm31,zmm17 ) + vfmadd231pd( zmm2,zmm31,zmm25 ) + vbroadcastsd( mem(r12,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm18 ) + vfmadd231pd( zmm1,zmm30,zmm19 ) + vfmadd231pd( zmm2,zmm30,zmm22 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm0,zmm31,zmm20 ) + vfmadd231pd( zmm1,zmm31,zmm21 ) + vfmadd231pd( zmm2,zmm31,zmm23 ) + + // ---------------------------------- iteration 8 + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r15,r13,1) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm5,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm5,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm5,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm5,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm24 ) + vbroadcastsd( mem(r12,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm4,zmm31,zmm17 ) + vfmadd231pd( zmm5,zmm31,zmm25 ) + vbroadcastsd( mem(r12,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm4,zmm30,zmm19 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm3,zmm31,zmm20 ) + vfmadd231pd( zmm4,zmm31,zmm21 ) + vfmadd231pd( zmm5,zmm31,zmm23 ) + lea(mem(rdx, rdi, 1), rdx) // C += cs_c + lea(mem(r11,r8,8), r11) // b_next += 8*rs_b + lea(mem(r15,r8,8), r15) // second pointer of b_next += 8*rs_b + sub(imm(1), rsi) // i -= 1 + jnz(.LOOP2) // iterate again if i != 0. + label(.TAILITER) + add(imm(TAIL_NITER), rsi) // i += TAIL_NITER + jle(.TAIL) // jump if i <= 0 + + label(.LOOP3) + + // ---------------------------------- iteration 1 + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm14 ) + vfmadd231pd( zmm1,zmm30,zmm15 ) + vfmadd231pd( zmm2,zmm30,zmm24 ) + vbroadcastsd( mem(r12,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm16 ) + vfmadd231pd( zmm1,zmm31,zmm17 ) + vfmadd231pd( zmm2,zmm31,zmm25 ) + vbroadcastsd( mem(r12,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm18 ) + vfmadd231pd( zmm1,zmm30,zmm19 ) + vfmadd231pd( zmm2,zmm30,zmm22 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm0,zmm31,zmm20 ) + vfmadd231pd( zmm1,zmm31,zmm21 ) + vfmadd231pd( zmm2,zmm31,zmm23 ) + + // ---------------------------------- iteration 2 + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11,r9,1) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm5,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm5,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm5,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm5,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm24 ) + vbroadcastsd( mem(r12,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm4,zmm31,zmm17 ) + vfmadd231pd( zmm5,zmm31,zmm25 ) + vbroadcastsd( mem(r12,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm4,zmm30,zmm19 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm3,zmm31,zmm20 ) + vfmadd231pd( zmm4,zmm31,zmm21 ) + vfmadd231pd( zmm5,zmm31,zmm23 ) + + // ---------------------------------- iteration 3 + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11,r9,2) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm14 ) + vfmadd231pd( zmm1,zmm30,zmm15 ) + vfmadd231pd( zmm2,zmm30,zmm24 ) + vbroadcastsd( mem(r12,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm16 ) + vfmadd231pd( zmm1,zmm31,zmm17 ) + vfmadd231pd( zmm2,zmm31,zmm25 ) + vbroadcastsd( mem(r12,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm18 ) + vfmadd231pd( zmm1,zmm30,zmm19 ) + vfmadd231pd( zmm2,zmm30,zmm22 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm0,zmm31,zmm20 ) + vfmadd231pd( zmm1,zmm31,zmm21 ) + vfmadd231pd( zmm2,zmm31,zmm23 ) + + // ---------------------------------- iteration 4 + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11,r13,1) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm5,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm5,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm5,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm5,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm24 ) + vbroadcastsd( mem(r12,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm4,zmm31,zmm17 ) + vfmadd231pd( zmm5,zmm31,zmm25 ) + vbroadcastsd( mem(r12,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm4,zmm30,zmm19 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm3,zmm31,zmm20 ) + vfmadd231pd( zmm4,zmm31,zmm21 ) + vfmadd231pd( zmm5,zmm31,zmm23 ) + + // ---------------------------------- iteration 5 + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r15) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm14 ) + vfmadd231pd( zmm1,zmm30,zmm15 ) + vfmadd231pd( zmm2,zmm30,zmm24 ) + vbroadcastsd( mem(r12,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm16 ) + vfmadd231pd( zmm1,zmm31,zmm17 ) + vfmadd231pd( zmm2,zmm31,zmm25 ) + vbroadcastsd( mem(r12,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm18 ) + vfmadd231pd( zmm1,zmm30,zmm19 ) + vfmadd231pd( zmm2,zmm30,zmm22 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm0,zmm31,zmm20 ) + vfmadd231pd( zmm1,zmm31,zmm21 ) + vfmadd231pd( zmm2,zmm31,zmm23 ) + + // ---------------------------------- iteration 6 + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r15,r9,1) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm5,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm5,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm5,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm5,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm24 ) + vbroadcastsd( mem(r12,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm4,zmm31,zmm17 ) + vfmadd231pd( zmm5,zmm31,zmm25 ) + vbroadcastsd( mem(r12,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm4,zmm30,zmm19 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm3,zmm31,zmm20 ) + vfmadd231pd( zmm4,zmm31,zmm21 ) + vfmadd231pd( zmm5,zmm31,zmm23 ) + + // ---------------------------------- iteration 7 + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r15,r9,2) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm14 ) + vfmadd231pd( zmm1,zmm30,zmm15 ) + vfmadd231pd( zmm2,zmm30,zmm24 ) + vbroadcastsd( mem(r12,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm16 ) + vfmadd231pd( zmm1,zmm31,zmm17 ) + vfmadd231pd( zmm2,zmm31,zmm25 ) + vbroadcastsd( mem(r12,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm18 ) + vfmadd231pd( zmm1,zmm30,zmm19 ) + vfmadd231pd( zmm2,zmm30,zmm22 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm0,zmm31,zmm20 ) + vfmadd231pd( zmm1,zmm31,zmm21 ) + vfmadd231pd( zmm2,zmm31,zmm23 ) + + // ---------------------------------- iteration 8 + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r15,r13,1) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm5,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm5,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm5,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm5,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm24 ) + vbroadcastsd( mem(r12,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm4,zmm31,zmm17 ) + vfmadd231pd( zmm5,zmm31,zmm25 ) + vbroadcastsd( mem(r12,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm4,zmm30,zmm19 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm3,zmm31,zmm20 ) + vfmadd231pd( zmm4,zmm31,zmm21 ) + vfmadd231pd( zmm5,zmm31,zmm23 ) + lea(mem(r11,r8,8), r11) // b_next += 8*rs_b + lea(mem(r15,r8,8), r15) // Second pointer of b_next += 8*rs_b + dec(rsi) // i -= 1 + jnz(.LOOP3) // iterate again if i != 0. + + + label(.TAIL) + mov(var(k_left), rsi) // i = k_left + test(rsi, rsi) // check i via logical AND + je(.DPOSTACCUM) // if i == 0, jump to post-accumulation + + label(.DLOOPKLEFT) // k_left loop + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm14 ) + vfmadd231pd( zmm1,zmm30,zmm15 ) + vfmadd231pd( zmm2,zmm30,zmm24 ) + vbroadcastsd( mem(r12,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm16 ) + vfmadd231pd( zmm1,zmm31,zmm17 ) + vfmadd231pd( zmm2,zmm31,zmm25 ) + vbroadcastsd( mem(r12,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm18 ) + vfmadd231pd( zmm1,zmm30,zmm19 ) + vfmadd231pd( zmm2,zmm30,zmm22 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm0,zmm31,zmm20 ) + vfmadd231pd( zmm1,zmm31,zmm21 ) + vfmadd231pd( zmm2,zmm31,zmm23 ) + dec(rsi) // i -= 1 + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + label(.DPOSTACCUM) + mov(var(alpha), rdx) // load address of alpha + vbroadcastsd(mem(rdx), zmm30) // broadcast alpha + mov(var(beta), rax) // load address of beta + vbroadcastsd(mem(rax), zmm31) // broadcast beta + + // scale by alpha + vmulpd( zmm30,zmm6,zmm6 ) + vmulpd( zmm30,zmm7,zmm7 ) + vmulpd( zmm30,zmm28,zmm28 ) + vmulpd( zmm30,zmm8,zmm8 ) + vmulpd( zmm30,zmm9,zmm9 ) + vmulpd( zmm30,zmm29,zmm29 ) + vmulpd( zmm30,zmm10,zmm10 ) + vmulpd( zmm30,zmm11,zmm11 ) + vmulpd( zmm30,zmm26,zmm26 ) + vmulpd( zmm30,zmm12,zmm12 ) + vmulpd( zmm30,zmm13,zmm13 ) + vmulpd( zmm30,zmm27,zmm27 ) + vmulpd( zmm30,zmm14,zmm14 ) + vmulpd( zmm30,zmm15,zmm15 ) + vmulpd( zmm30,zmm24,zmm24 ) + vmulpd( zmm30,zmm16,zmm16 ) + vmulpd( zmm30,zmm17,zmm17 ) + vmulpd( zmm30,zmm25,zmm25 ) + vmulpd( zmm30,zmm18,zmm18 ) + vmulpd( zmm30,zmm19,zmm19 ) + vmulpd( zmm30,zmm22,zmm22 ) + vmulpd( zmm30,zmm20,zmm20 ) + vmulpd( zmm30,zmm21,zmm21 ) + vmulpd( zmm30,zmm23,zmm23 ) + + + mov(var(rs_c), rsi) // load rs_c + lea(mem(, rsi, 8), rsi) // rsi = rs_c * sizeof(double) + lea(mem(rcx, rdi, 4), rdx) // rdx = rcx + 4 * cs_c + lea(mem(rdi, rdi, 2), r13) // r13 = 3*cs_c + vxorpd(ymm2, ymm2, ymm2) + vucomisd(xmm2, xmm31) // set ZF if beta == 0 + je(.DBETAZERO) // if ZF == 1, jump to beta == 0 case + + + cmp(imm(8), rdi) // set ZF if (8*cs_c) == 8 + + + jz(.DROWSTORED) // jump to row storage case + + label(.DCOLSTORED) + vfmadd231pd( mem(rcx),zmm31,zmm6) + vmovupd( zmm6,(rcx)) + vfmadd231pd( 0x40(rcx),zmm31,zmm7) + vmovupd( zmm7,0x40(rcx)) + vfmadd231pd( 0x80(rcx),zmm31,zmm28) + vmovupd( zmm28,0x80(rcx)) + vfmadd231pd( mem(rcx,rdi,1),zmm31,zmm8) + vmovupd( zmm8,(rcx,rdi,1)) + vfmadd231pd( 0x40(rcx,rdi,1),zmm31,zmm9) + vmovupd( zmm9,0x40(rcx,rdi,1)) + vfmadd231pd( 0x80(rcx,rdi,1),zmm31,zmm29) + vmovupd( zmm29,0x80(rcx,rdi,1)) + vfmadd231pd( mem(rcx,rdi,2),zmm31,zmm10) + vmovupd( zmm10,(rcx,rdi,2)) + vfmadd231pd( 0x40(rcx,rdi,2),zmm31,zmm11) + vmovupd( zmm11,0x40(rcx,rdi,2)) + vfmadd231pd( 0x80(rcx,rdi,2),zmm31,zmm26) + vmovupd( zmm26,0x80(rcx,rdi,2)) + vfmadd231pd( mem(rcx,r13,1),zmm31,zmm12) + vmovupd( zmm12,(rcx,r13,1)) + vfmadd231pd( 0x40(rcx,r13,1),zmm31,zmm13) + vmovupd( zmm13,0x40(rcx,r13,1)) + vfmadd231pd( 0x80(rcx,r13,1),zmm31,zmm27) + vmovupd( zmm27,0x80(rcx,r13,1)) + vfmadd231pd( mem(rdx),zmm31,zmm14) + vmovupd( zmm14,(rdx)) + vfmadd231pd( 0x40(rdx),zmm31,zmm15) + vmovupd( zmm15,0x40(rdx)) + vfmadd231pd( 0x80(rdx),zmm31,zmm24) + vmovupd( zmm24,0x80(rdx)) + vfmadd231pd( mem(rdx,rdi,1),zmm31,zmm16) + vmovupd( zmm16,(rdx,rdi,1)) + vfmadd231pd( 0x40(rdx,rdi,1),zmm31,zmm17) + vmovupd( zmm17,0x40(rdx,rdi,1)) + vfmadd231pd( 0x80(rdx,rdi,1),zmm31,zmm25) + vmovupd( zmm25,0x80(rdx,rdi,1)) + vfmadd231pd( mem(rdx,rdi,2),zmm31,zmm18) + vmovupd( zmm18,(rdx,rdi,2)) + vfmadd231pd( 0x40(rdx,rdi,2),zmm31,zmm19) + vmovupd( zmm19,0x40(rdx,rdi,2)) + vfmadd231pd( 0x80(rdx,rdi,2),zmm31,zmm22) + vmovupd( zmm22,0x80(rdx,rdi,2)) + vfmadd231pd( mem(rdx,r13,1),zmm31,zmm20) + vmovupd( zmm20,(rdx,r13,1)) + vfmadd231pd( 0x40(rdx,r13,1),zmm31,zmm21) + vmovupd( zmm21,0x40(rdx,r13,1)) + vfmadd231pd( 0x80(rdx,r13,1),zmm31,zmm23) + vmovupd( zmm23,0x80(rdx,r13,1)) + + jmp(.DDONE) // jump to end. + + label(.DROWSTORED) + // r12 = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // r13 = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) + + UNPACK_LO_HIGH(16, 14, 0, 1, 20, 18, 2, 3) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + vbroadcastsd(mem(rax), zmm31) + UPDATE_C + //First 8x8 tile updated + + UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + UNPACK_LO_HIGH(17, 15, 0, 1, 21, 19, 2, 3) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_C + //Second 8x8 tile updated + + UNPACK_LO_HIGH(29, 28, 0, 1, 27, 26, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + UNPACK_LO_HIGH(25, 24, 0, 1, 23, 22, 2, 3) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_C + //Third 8x8 tile updated + jmp(.DDONE) // jump to end. + + + label(.DBETAZERO) + cmp(imm(8), rdi) // set ZF if (8*cs_c) == 8 + + jz(.DROWSTORBZ) // jump to row storage case + label(.DCOLSTORBZ) + vmovupd( zmm6,(rcx)) + vmovupd( zmm7,0x40(rcx)) + vmovupd( zmm28,0x80(rcx)) + vmovupd( zmm8,(rcx,rdi,1)) + vmovupd( zmm9,0x40(rcx,rdi,1)) + vmovupd( zmm29,0x80(rcx,rdi,1)) + vmovupd( zmm10,(rcx,rdi,2)) + vmovupd( zmm11,0x40(rcx,rdi,2)) + vmovupd( zmm26,0x80(rcx,rdi,2)) + vmovupd( zmm12,(rcx,r13,1)) + vmovupd( zmm13,0x40(rcx,r13,1)) + vmovupd( zmm27,0x80(rcx,r13,1)) + vmovupd( zmm14,(rdx)) + vmovupd( zmm15,0x40(rdx)) + vmovupd( zmm24,0x80(rdx)) + vmovupd( zmm16,(rdx,rdi,1)) + vmovupd( zmm17,0x40(rdx,rdi,1)) + vmovupd( zmm25,0x80(rdx,rdi,1)) + vmovupd( zmm18,(rdx,rdi,2)) + vmovupd( zmm19,0x40(rdx,rdi,2)) + vmovupd( zmm22,0x80(rdx,rdi,2)) + vmovupd( zmm20,(rdx,r13,1)) + vmovupd( zmm21,0x40(rdx,r13,1)) + vmovupd( zmm23,0x80(rdx,r13,1)) + + jmp(.DDONE) // jump to end. + + + label(.DROWSTORBZ) + // r12 = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // r13 = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) + + UNPACK_LO_HIGH(16, 14, 0, 1, 20, 18, 2, 3) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + UPDATE_C_BZ + //First 8x8 tile updated + + UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + UNPACK_LO_HIGH(17, 15, 0, 1, 21, 19, 2, 3) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_C_BZ + //Second 8x8 tile updated + + UNPACK_LO_HIGH(29, 28, 0, 1, 27, 26, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + UNPACK_LO_HIGH(25, 24, 0, 1, 23, 22, 2, 3) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_C_BZ + //Third 8x8 tile updated + label(.DDONE) + + + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a8] "m" (ps_a8), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm2", "xmm31", + "ymm2", + "zmm0", "zmm1", "zmm2", "zmm3", + "zmm4", "zmm5", "zmm6", "zmm7", "zmm8", "zmm9", "zmm10", + "zmm11", "zmm12", "zmm13", "zmm14", "zmm15", + "zmm16", "zmm17", "zmm18", "zmm19", + "zmm20", "zmm21", "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", + "zmm27", "zmm28", "zmm29", "zmm30", "zmm31", + "memory" + ) + } //mloop + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if (m_left) + { + const dim_t nr_cur = 8; + const dim_t i_edge = m0 - ( dim_t )m_left; + double *restrict cij = cbuf + i_edge * rs_c; + double *restrict ai = abuf + m_iter * ps_a; + double *restrict bj = bbuf; + // covers the range 16 < m_left <= 24 by using masked load/store instructions + if( 16 < m_left ) + { + bli_dgemmsup_rv_zen4_asm_24x8( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx); + } + // covers the range 8 < m_left <= 16 by using masked load/store instructions + else if( 8 < m_left ) + { + bli_dgemmsup_rv_zen4_asm_16x8( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx); + } + // covers the range 0 < m_left <= 8 by using masked load/store instructions + else if( 0 < m_left ) + { + bli_dgemmsup_rv_zen4_asm_8x8( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx); + } + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); +} + +void bli_dgemmsup_rv_zen5_asm_24x7m +( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + double *abuf = a; + double *bbuf = b; + double *cbuf = c; + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t m_iter = (uint64_t)m0 / 24; + uint64_t m_left = (uint64_t)m0 % 24; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_a8 = ps_a * sizeof( double ); + + uint64_t k_iter = (uint64_t)k0 / 8; + uint64_t k_left = (uint64_t)k0 % 8; + + uint8_t mask = (0xff >> (0x8 - (n0 & 7))); // calculate mask based on n_left + + if ( m_iter == 0 ) goto consider_edge_cases; + + /* For one iteration of this loop, a block of MRxNR is computed + * This loop moves along m-dimension of c matrix with steps of MR*rs_c. + */ + for(dim_t m=0; m < m_iter; m++) + { + + a = abuf + m * ps_a ; // Move to next MRXKC in MCXKC (where MC>=MR) + b = bbuf; //Same KCXNR is used across different MRXKC in MCXKC + c = cbuf + m * rs_c * 24; // Move to next MRxNR in MCxNR (where MC >= MR) + + // ------------------------------------------------------------------------- + begin_asm() + + mov(var(mask), rdx) // load mask + kmovw(edx, k(2)) // move mask to k2 register + mov(var(a), rax) // load address of a + mov(var(cs_a), r10) // load cs_a + mov(var(b), rbx) // load address of b + mov(var(rs_b), r8) // load rs_b + mov(var(cs_b), r9) // load cs_b + mov(var(c), rcx) // load address of c + mov(var(cs_c), rdi) // load cs_c + lea(mem(, r8, 8), r8) // rs_b *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_b *= sizeof(double) + lea(mem(, r10, 8), r10) // cs_a *= sizeof(double) + lea(mem(, rdi, 8), rdi) // cs_c *= sizeof(double) + lea(mem(r9, r9, 2 ), r13) // r13 = 3*cs_b + // if n > 4, a second pointer(r12) which points to rbx + 4*cs_b + //is also used to traverse B matrix + lea(mem(rbx, r9, 4), r12) // r12 = rbx + 4*cs_b + lea(mem(rcx, 7*8), rdx) // C for prefetching + mov(r10, r14) // col stride of A + lea(mem(rax, r14, 2, 7*8), r14) // r14 = rax + 2*cs_a(A for prefetching) + lea(mem(rbx, r8, 8, 7*8), r11) // r11 = rbx + 8*rs_b(B for prefetching) + // if n > 4, a second pointer which point to r11 + 4*cs_b + //is also used to prefetch from B matrix + lea(mem(r11, r9, 4), r15) // r15 = r11 + 4* cs_b(B for prefetching) + + /* Register usage: zmm0-5 are used to load A matrix + * zmm6-29 are used for accumulation + * zmm30-31 are used for broadcasting B matrix + */ + + // zero out all accumulation registers + vxorpd(zmm6, zmm6, zmm6) + vxorpd(zmm7, zmm7, zmm7) + vxorpd(zmm28, zmm28, zmm28) + vxorpd(zmm8, zmm8, zmm8) + vxorpd(zmm9, zmm9, zmm9) + vxorpd(zmm29, zmm29, zmm29) + vxorpd(zmm10, zmm10, zmm10) + vxorpd(zmm11, zmm11, zmm11) + vxorpd(zmm26, zmm26, zmm26) + vxorpd(zmm12, zmm12, zmm12) + vxorpd(zmm13, zmm13, zmm13) + vxorpd(zmm27,zmm27, zmm27) + vxorpd(zmm14, zmm14, zmm14) + vxorpd(zmm15, zmm15, zmm15) + vxorpd(zmm24, zmm24, zmm24) + vxorpd(zmm16, zmm16, zmm16) + vxorpd(zmm17, zmm17, zmm17) + vxorpd(zmm25, zmm25, zmm25) + vxorpd(zmm18, zmm18, zmm18) + vxorpd(zmm19, zmm19, zmm19) + vxorpd(zmm22, zmm22, zmm22) + + // K is unrolled by 8 to facilitate prefetch of B + // Assuming B to be col-stored, for each iteration of K, + //one cacheline of B_next is prefetched where b_next = b + (unroll)*rs_b + label(.DLOOPKITER) // main loop + mov(var(k_iter), rsi) // i = k_iter + sub(imm( 7+TAIL_NITER), rsi) // i -= NR + TAIL_NITER + jle(.PREFETCHLOOP) // jump if i <= 0 + + label(.LOOP1) + + // ---------------------------------- iteration 1 + + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm14 ) + vfmadd231pd( zmm1,zmm30,zmm15 ) + vfmadd231pd( zmm2,zmm30,zmm24 ) + vbroadcastsd( mem(r12,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm16 ) + vfmadd231pd( zmm1,zmm31,zmm17 ) + vfmadd231pd( zmm2,zmm31,zmm25 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm0,zmm30,zmm18 ) + vfmadd231pd( zmm1,zmm30,zmm19 ) + vfmadd231pd( zmm2,zmm30,zmm22 ) + + // ---------------------------------- iteration 2 + + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11,r9,1) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm5,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm5,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm5,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm5,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm24 ) + vbroadcastsd( mem(r12,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm4,zmm31,zmm17 ) + vfmadd231pd( zmm5,zmm31,zmm25 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm4,zmm30,zmm19 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) + + // ---------------------------------- iteration 3 + + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11,r9,2) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm14 ) + vfmadd231pd( zmm1,zmm30,zmm15 ) + vfmadd231pd( zmm2,zmm30,zmm24 ) + vbroadcastsd( mem(r12,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm16 ) + vfmadd231pd( zmm1,zmm31,zmm17 ) + vfmadd231pd( zmm2,zmm31,zmm25 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm0,zmm30,zmm18 ) + vfmadd231pd( zmm1,zmm30,zmm19 ) + vfmadd231pd( zmm2,zmm30,zmm22 ) + + // ---------------------------------- iteration 4 + + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11,r13,1) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm5,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm5,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm5,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm5,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm24 ) + vbroadcastsd( mem(r12,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm4,zmm31,zmm17 ) + vfmadd231pd( zmm5,zmm31,zmm25 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm4,zmm30,zmm19 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) + + // ---------------------------------- iteration 5 + + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r15) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm14 ) + vfmadd231pd( zmm1,zmm30,zmm15 ) + vfmadd231pd( zmm2,zmm30,zmm24 ) + vbroadcastsd( mem(r12,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm16 ) + vfmadd231pd( zmm1,zmm31,zmm17 ) + vfmadd231pd( zmm2,zmm31,zmm25 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm0,zmm30,zmm18 ) + vfmadd231pd( zmm1,zmm30,zmm19 ) + vfmadd231pd( zmm2,zmm30,zmm22 ) + + // ---------------------------------- iteration 6 + + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r15,r9,1) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm5,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm5,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm5,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm5,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm24 ) + vbroadcastsd( mem(r12,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm4,zmm31,zmm17 ) + vfmadd231pd( zmm5,zmm31,zmm25 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm4,zmm30,zmm19 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) + + // ---------------------------------- iteration 7 + + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r15,r9,2) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm14 ) + vfmadd231pd( zmm1,zmm30,zmm15 ) + vfmadd231pd( zmm2,zmm30,zmm24 ) + vbroadcastsd( mem(r12,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm16 ) + vfmadd231pd( zmm1,zmm31,zmm17 ) + vfmadd231pd( zmm2,zmm31,zmm25 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm0,zmm30,zmm18 ) + vfmadd231pd( zmm1,zmm30,zmm19 ) + vfmadd231pd( zmm2,zmm30,zmm22 ) + + // ---------------------------------- iteration 8 + + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm5,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm5,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm5,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm5,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm24 ) + vbroadcastsd( mem(r12,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm4,zmm31,zmm17 ) + vfmadd231pd( zmm5,zmm31,zmm25 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm4,zmm30,zmm19 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) + lea(mem(r11,r8,8), r11) // b_next += 8*rs_b + lea(mem(r15,r8,8), r15) // second pointer to b_next += 8*rs_b + dec(rsi) // i -= 1 + jnz(.LOOP1) // iterate again if i != 0. + + label(.PREFETCHLOOP) + add(imm(7), rsi) // i += NR + jle(.TAILITER) // jump if i <= 0. + + label(.LOOP2) + + // ---------------------------------- iteration 1 + prefetchw0( mem(rdx)) // prefetch C + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm14 ) + vfmadd231pd( zmm1,zmm30,zmm15 ) + vfmadd231pd( zmm2,zmm30,zmm24 ) + vbroadcastsd( mem(r12,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm16 ) + vfmadd231pd( zmm1,zmm31,zmm17 ) + vfmadd231pd( zmm2,zmm31,zmm25 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm0,zmm30,zmm18 ) + vfmadd231pd( zmm1,zmm30,zmm19 ) + vfmadd231pd( zmm2,zmm30,zmm22 ) + + // ---------------------------------- iteration 2 + prefetchw0( mem(rdx, 64)) // prefetch C + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11,r9,1) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm5,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm5,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm5,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm5,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm24 ) + vbroadcastsd( mem(r12,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm4,zmm31,zmm17 ) + vfmadd231pd( zmm5,zmm31,zmm25 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm4,zmm30,zmm19 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) + + // ---------------------------------- iteration 3 + prefetchw0( mem(rdx, 128)) // prefetch C + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11,r9,2) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm14 ) + vfmadd231pd( zmm1,zmm30,zmm15 ) + vfmadd231pd( zmm2,zmm30,zmm24 ) + vbroadcastsd( mem(r12,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm16 ) + vfmadd231pd( zmm1,zmm31,zmm17 ) + vfmadd231pd( zmm2,zmm31,zmm25 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm0,zmm30,zmm18 ) + vfmadd231pd( zmm1,zmm30,zmm19 ) + vfmadd231pd( zmm2,zmm30,zmm22 ) + + // ---------------------------------- iteration 4 + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11,r13,1) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm5,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm5,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm5,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm5,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm24 ) + vbroadcastsd( mem(r12,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm4,zmm31,zmm17 ) + vfmadd231pd( zmm5,zmm31,zmm25 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm4,zmm30,zmm19 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) + + // ---------------------------------- iteration 5 + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r15) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm14 ) + vfmadd231pd( zmm1,zmm30,zmm15 ) + vfmadd231pd( zmm2,zmm30,zmm24 ) + vbroadcastsd( mem(r12,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm16 ) + vfmadd231pd( zmm1,zmm31,zmm17 ) + vfmadd231pd( zmm2,zmm31,zmm25 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm0,zmm30,zmm18 ) + vfmadd231pd( zmm1,zmm30,zmm19 ) + vfmadd231pd( zmm2,zmm30,zmm22 ) + + // ---------------------------------- iteration 6 + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r15,r9,1) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm5,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm5,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm5,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm5,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm24 ) + vbroadcastsd( mem(r12,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm4,zmm31,zmm17 ) + vfmadd231pd( zmm5,zmm31,zmm25 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm4,zmm30,zmm19 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) + + // ---------------------------------- iteration 7 + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r15,r9,2) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm14 ) + vfmadd231pd( zmm1,zmm30,zmm15 ) + vfmadd231pd( zmm2,zmm30,zmm24 ) + vbroadcastsd( mem(r12,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm16 ) + vfmadd231pd( zmm1,zmm31,zmm17 ) + vfmadd231pd( zmm2,zmm31,zmm25 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm0,zmm30,zmm18 ) + vfmadd231pd( zmm1,zmm30,zmm19 ) + vfmadd231pd( zmm2,zmm30,zmm22 ) + + // ---------------------------------- iteration 8 + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm5,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm5,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm5,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm5,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm24 ) + vbroadcastsd( mem(r12,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm4,zmm31,zmm17 ) + vfmadd231pd( zmm5,zmm31,zmm25 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm4,zmm30,zmm19 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) + lea(mem(rdx, rdi, 1), rdx) // C += cs_c + lea(mem(r11,r8,8), r11) // b_next += 8*rs_b + lea(mem(r15,r8,8), r15) // second pointer of b_next += 8*rs_b + sub(imm(1), rsi) // i -= 1 + jnz(.LOOP2) // iterate again if i != 0. + label(.TAILITER) + add(imm(TAIL_NITER), rsi) // i += TAIL_NITER + jle(.TAIL) // jump if i <= 0 + + label(.LOOP3) + + // ---------------------------------- iteration 1 + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm14 ) + vfmadd231pd( zmm1,zmm30,zmm15 ) + vfmadd231pd( zmm2,zmm30,zmm24 ) + vbroadcastsd( mem(r12,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm16 ) + vfmadd231pd( zmm1,zmm31,zmm17 ) + vfmadd231pd( zmm2,zmm31,zmm25 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm0,zmm30,zmm18 ) + vfmadd231pd( zmm1,zmm30,zmm19 ) + vfmadd231pd( zmm2,zmm30,zmm22 ) + + // ---------------------------------- iteration 2 + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11,r9,1) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm5,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm5,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm5,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm5,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm24 ) + vbroadcastsd( mem(r12,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm4,zmm31,zmm17 ) + vfmadd231pd( zmm5,zmm31,zmm25 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm4,zmm30,zmm19 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) + + // ---------------------------------- iteration 3 + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11,r9,2) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm14 ) + vfmadd231pd( zmm1,zmm30,zmm15 ) + vfmadd231pd( zmm2,zmm30,zmm24 ) + vbroadcastsd( mem(r12,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm16 ) + vfmadd231pd( zmm1,zmm31,zmm17 ) + vfmadd231pd( zmm2,zmm31,zmm25 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm0,zmm30,zmm18 ) + vfmadd231pd( zmm1,zmm30,zmm19 ) + vfmadd231pd( zmm2,zmm30,zmm22 ) + + // ---------------------------------- iteration 4 + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11,r13,1) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm5,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm5,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm5,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm5,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm24 ) + vbroadcastsd( mem(r12,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm4,zmm31,zmm17 ) + vfmadd231pd( zmm5,zmm31,zmm25 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm4,zmm30,zmm19 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) + + // ---------------------------------- iteration 5 + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r15) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm14 ) + vfmadd231pd( zmm1,zmm30,zmm15 ) + vfmadd231pd( zmm2,zmm30,zmm24 ) + vbroadcastsd( mem(r12,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm16 ) + vfmadd231pd( zmm1,zmm31,zmm17 ) + vfmadd231pd( zmm2,zmm31,zmm25 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm0,zmm30,zmm18 ) + vfmadd231pd( zmm1,zmm30,zmm19 ) + vfmadd231pd( zmm2,zmm30,zmm22 ) + + // ---------------------------------- iteration 6 + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r15,r9,1) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm5,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm5,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm5,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm5,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm24 ) + vbroadcastsd( mem(r12,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm4,zmm31,zmm17 ) + vfmadd231pd( zmm5,zmm31,zmm25 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm4,zmm30,zmm19 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) + + // ---------------------------------- iteration 7 + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r15,r9,2) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm14 ) + vfmadd231pd( zmm1,zmm30,zmm15 ) + vfmadd231pd( zmm2,zmm30,zmm24 ) + vbroadcastsd( mem(r12,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm16 ) + vfmadd231pd( zmm1,zmm31,zmm17 ) + vfmadd231pd( zmm2,zmm31,zmm25 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm0,zmm30,zmm18 ) + vfmadd231pd( zmm1,zmm30,zmm19 ) + vfmadd231pd( zmm2,zmm30,zmm22 ) + + // ---------------------------------- iteration 8 + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm5,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm5,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm5,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm5,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm24 ) + vbroadcastsd( mem(r12,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm4,zmm31,zmm17 ) + vfmadd231pd( zmm5,zmm31,zmm25 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm4,zmm30,zmm19 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) + lea(mem(r11,r8,8), r11) // b_next += 8*rs_b + lea(mem(r15,r8,8), r15) // Second pointer of b_next += 8*rs_b + dec(rsi) // i -= 1 + jnz(.LOOP3) // iterate again if i != 0. + + + label(.TAIL) + mov(var(k_left), rsi) // i = k_left + test(rsi, rsi) // check i via logical AND + je(.DPOSTACCUM) // if i == 0, jump to post-accumulation + + label(.DLOOPKLEFT) // k_left loop + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm14 ) + vfmadd231pd( zmm1,zmm30,zmm15 ) + vfmadd231pd( zmm2,zmm30,zmm24 ) + vbroadcastsd( mem(r12,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm16 ) + vfmadd231pd( zmm1,zmm31,zmm17 ) + vfmadd231pd( zmm2,zmm31,zmm25 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm0,zmm30,zmm18 ) + vfmadd231pd( zmm1,zmm30,zmm19 ) + vfmadd231pd( zmm2,zmm30,zmm22 ) + dec(rsi) // i -= 1 + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + label(.DPOSTACCUM) + mov(var(alpha), rdx) // load address of alpha + vbroadcastsd(mem(rdx), zmm30) // broadcast alpha + mov(var(beta), rax) // load address of beta + vbroadcastsd(mem(rax), zmm31) // broadcast beta + + // scale by alpha + vmulpd( zmm30,zmm6,zmm6 ) + vmulpd( zmm30,zmm7,zmm7 ) + vmulpd( zmm30,zmm28,zmm28 ) + vmulpd( zmm30,zmm8,zmm8 ) + vmulpd( zmm30,zmm9,zmm9 ) + vmulpd( zmm30,zmm29,zmm29 ) + vmulpd( zmm30,zmm10,zmm10 ) + vmulpd( zmm30,zmm11,zmm11 ) + vmulpd( zmm30,zmm26,zmm26 ) + vmulpd( zmm30,zmm12,zmm12 ) + vmulpd( zmm30,zmm13,zmm13 ) + vmulpd( zmm30,zmm27,zmm27 ) + vmulpd( zmm30,zmm14,zmm14 ) + vmulpd( zmm30,zmm15,zmm15 ) + vmulpd( zmm30,zmm24,zmm24 ) + vmulpd( zmm30,zmm16,zmm16 ) + vmulpd( zmm30,zmm17,zmm17 ) + vmulpd( zmm30,zmm25,zmm25 ) + vmulpd( zmm30,zmm18,zmm18 ) + vmulpd( zmm30,zmm19,zmm19 ) + vmulpd( zmm30,zmm22,zmm22 ) + + + mov(var(rs_c), rsi) // load rs_c + lea(mem(, rsi, 8), rsi) // rsi = rs_c * sizeof(double) + lea(mem(rcx, rdi, 4), rdx) // rdx = rcx + 4 * cs_c + lea(mem(rdi, rdi, 2), r13) // r13 = 3*cs_c + vxorpd(ymm2, ymm2, ymm2) + vucomisd(xmm2, xmm31) // set ZF if beta == 0 + je(.DBETAZERO) // if ZF == 1, jump to beta == 0 case + + + cmp(imm(8), rdi) // set ZF if (8*cs_c) == 8 + + + jz(.DROWSTORED) // jump to row storage case + + label(.DCOLSTORED) + vfmadd231pd( mem(rcx),zmm31,zmm6) + vmovupd( zmm6,(rcx)) + vfmadd231pd( 0x40(rcx),zmm31,zmm7) + vmovupd( zmm7,0x40(rcx)) + vfmadd231pd( 0x80(rcx),zmm31,zmm28) + vmovupd( zmm28,0x80(rcx)) + vfmadd231pd( mem(rcx,rdi,1),zmm31,zmm8) + vmovupd( zmm8,(rcx,rdi,1)) + vfmadd231pd( 0x40(rcx,rdi,1),zmm31,zmm9) + vmovupd( zmm9,0x40(rcx,rdi,1)) + vfmadd231pd( 0x80(rcx,rdi,1),zmm31,zmm29) + vmovupd( zmm29,0x80(rcx,rdi,1)) + vfmadd231pd( mem(rcx,rdi,2),zmm31,zmm10) + vmovupd( zmm10,(rcx,rdi,2)) + vfmadd231pd( 0x40(rcx,rdi,2),zmm31,zmm11) + vmovupd( zmm11,0x40(rcx,rdi,2)) + vfmadd231pd( 0x80(rcx,rdi,2),zmm31,zmm26) + vmovupd( zmm26,0x80(rcx,rdi,2)) + vfmadd231pd( mem(rcx,r13,1),zmm31,zmm12) + vmovupd( zmm12,(rcx,r13,1)) + vfmadd231pd( 0x40(rcx,r13,1),zmm31,zmm13) + vmovupd( zmm13,0x40(rcx,r13,1)) + vfmadd231pd( 0x80(rcx,r13,1),zmm31,zmm27) + vmovupd( zmm27,0x80(rcx,r13,1)) + vfmadd231pd( mem(rdx),zmm31,zmm14) + vmovupd( zmm14,(rdx)) + vfmadd231pd( 0x40(rdx),zmm31,zmm15) + vmovupd( zmm15,0x40(rdx)) + vfmadd231pd( 0x80(rdx),zmm31,zmm24) + vmovupd( zmm24,0x80(rdx)) + vfmadd231pd( mem(rdx,rdi,1),zmm31,zmm16) + vmovupd( zmm16,(rdx,rdi,1)) + vfmadd231pd( 0x40(rdx,rdi,1),zmm31,zmm17) + vmovupd( zmm17,0x40(rdx,rdi,1)) + vfmadd231pd( 0x80(rdx,rdi,1),zmm31,zmm25) + vmovupd( zmm25,0x80(rdx,rdi,1)) + vfmadd231pd( mem(rdx,rdi,2),zmm31,zmm18) + vmovupd( zmm18,(rdx,rdi,2)) + vfmadd231pd( 0x40(rdx,rdi,2),zmm31,zmm19) + vmovupd( zmm19,0x40(rdx,rdi,2)) + vfmadd231pd( 0x80(rdx,rdi,2),zmm31,zmm22) + vmovupd( zmm22,0x80(rdx,rdi,2)) + + jmp(.DDONE) // jump to end. + + label(.DROWSTORED) + // r12 = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // r13 = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) + + UNPACK_LO_HIGH(16, 14, 0, 1, 20, 18, 2, 3) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + vbroadcastsd(mem(rax), zmm31) + UPDATE_MASKED_C + //First 8x7 tile updated + + UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + UNPACK_LO_HIGH(17, 15, 0, 1, 21, 19, 2, 3) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C + //Second 8x7 tile updated + + UNPACK_LO_HIGH(29, 28, 0, 1, 27, 26, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + UNPACK_LO_HIGH(25, 24, 0, 1, 23, 22, 2, 3) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C + //Third 8x7 tile updated + jmp(.DDONE) // jump to end. + + + label(.DBETAZERO) + cmp(imm(8), rdi) // set ZF if (8*cs_c) == 8 + + jz(.DROWSTORBZ) // jump to row storage case + label(.DCOLSTORBZ) + vmovupd( zmm6,(rcx)) + vmovupd( zmm7,0x40(rcx)) + vmovupd( zmm28,0x80(rcx)) + vmovupd( zmm8,(rcx,rdi,1)) + vmovupd( zmm9,0x40(rcx,rdi,1)) + vmovupd( zmm29,0x80(rcx,rdi,1)) + vmovupd( zmm10,(rcx,rdi,2)) + vmovupd( zmm11,0x40(rcx,rdi,2)) + vmovupd( zmm26,0x80(rcx,rdi,2)) + vmovupd( zmm12,(rcx,r13,1)) + vmovupd( zmm13,0x40(rcx,r13,1)) + vmovupd( zmm27,0x80(rcx,r13,1)) + vmovupd( zmm14,(rdx)) + vmovupd( zmm15,0x40(rdx)) + vmovupd( zmm24,0x80(rdx)) + vmovupd( zmm16,(rdx,rdi,1)) + vmovupd( zmm17,0x40(rdx,rdi,1)) + vmovupd( zmm25,0x80(rdx,rdi,1)) + vmovupd( zmm18,(rdx,rdi,2)) + vmovupd( zmm19,0x40(rdx,rdi,2)) + vmovupd( zmm22,0x80(rdx,rdi,2)) + + jmp(.DDONE) // jump to end. + + + label(.DROWSTORBZ) + // r12 = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // r13 = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) + + UNPACK_LO_HIGH(16, 14, 0, 1, 20, 18, 2, 3) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + UPDATE_MASKED_C_BZ + //First 8x7 tile updated + + UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + UNPACK_LO_HIGH(17, 15, 0, 1, 21, 19, 2, 3) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C_BZ + //Second 8x7 tile updated + + UNPACK_LO_HIGH(29, 28, 0, 1, 27, 26, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + UNPACK_LO_HIGH(25, 24, 0, 1, 23, 22, 2, 3) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C_BZ + //Third 8x7 tile updated + label(.DDONE) + + + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a8] "m" (ps_a8), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + [mask] "m" (mask) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm2", "xmm31", + "ymm2", + "zmm0", "zmm1", "zmm2", "zmm3", + "zmm4", "zmm5", "zmm6", "zmm7", "zmm8", "zmm9", "zmm10", + "zmm11", "zmm12", "zmm13", "zmm14", "zmm15", + "zmm16", "zmm17", "zmm18", "zmm19", + "zmm20", "zmm21", "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", + "zmm27", "zmm28", "zmm29", "zmm30", "zmm31", + "k2", "memory" + ) + } //mloop + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if (m_left) + { + const dim_t nr_cur = 7; + const dim_t i_edge = m0 - ( dim_t )m_left; + double *restrict cij = cbuf + i_edge * rs_c; + double *restrict ai = abuf + m_iter * ps_a; + double *restrict bj = bbuf; + // covers the range 16 < m_left <= 24 by using masked load/store instructions + if( 16 < m_left ) + { + bli_dgemmsup_rv_zen4_asm_24x7( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx); + } + // covers the range 8 < m_left <= 16 by using masked load/store instructions + else if( 8 < m_left ) + { + bli_dgemmsup_rv_zen4_asm_16x7( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx); + } + // covers the range 0 < m_left <= 8 by using masked load/store instructions + else if( 0 < m_left ) + { + bli_dgemmsup_rv_zen4_asm_8x7( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx); + } + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); +} + +void bli_dgemmsup_rv_zen5_asm_24x6m +( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + double *abuf = a; + double *bbuf = b; + double *cbuf = c; + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t m_iter = (uint64_t)m0 / 24; + uint64_t m_left = (uint64_t)m0 % 24; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_a8 = ps_a * sizeof( double ); + + uint64_t k_iter = (uint64_t)k0 / 8; + uint64_t k_left = (uint64_t)k0 % 8; + + uint8_t mask = (0xff >> (0x8 - (n0 & 7))); // calculate mask based on n_left + + if ( m_iter == 0 ) goto consider_edge_cases; + + /* For one iteration of this loop, a block of MRxNR is computed + * This loop moves along m-dimension of c matrix with steps of MR*rs_c. + */ + for(dim_t m=0; m < m_iter; m++) + { + + a = abuf + m * ps_a ; // Move to next MRXKC in MCXKC (where MC>=MR) + b = bbuf; //Same KCXNR is used across different MRXKC in MCXKC + c = cbuf + m * rs_c * 24; // Move to next MRxNR in MCxNR (where MC >= MR) + + // ------------------------------------------------------------------------- + begin_asm() + + mov(var(mask), rdx) // load mask + kmovw(edx, k(2)) // move mask to k2 register + mov(var(a), rax) // load address of a + mov(var(cs_a), r10) // load cs_a + mov(var(b), rbx) // load address of b + mov(var(rs_b), r8) // load rs_b + mov(var(cs_b), r9) // load cs_b + mov(var(c), rcx) // load address of c + mov(var(cs_c), rdi) // load cs_c + lea(mem(, r8, 8), r8) // rs_b *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_b *= sizeof(double) + lea(mem(, r10, 8), r10) // cs_a *= sizeof(double) + lea(mem(, rdi, 8), rdi) // cs_c *= sizeof(double) + lea(mem(r9, r9, 2 ), r13) // r13 = 3*cs_b + // if n > 4, a second pointer(r12) which points to rbx + 4*cs_b + //is also used to traverse B matrix + lea(mem(rbx, r9, 4), r12) // r12 = rbx + 4*cs_b + lea(mem(rcx, 7*8), rdx) // C for prefetching + mov(r10, r14) // col stride of A + lea(mem(rax, r14, 2, 7*8), r14) // r14 = rax + 2*cs_a(A for prefetching) + lea(mem(rbx, r8, 8, 7*8), r11) // r11 = rbx + 8*rs_b(B for prefetching) + // if n > 4, a second pointer which point to r11 + 4*cs_b + //is also used to prefetch from B matrix + lea(mem(r11, r9, 4), r15) // r15 = r11 + 4* cs_b(B for prefetching) + + /* Register usage: zmm0-5 are used to load A matrix + * zmm6-29 are used for accumulation + * zmm30-31 are used for broadcasting B matrix + */ + + // zero out all accumulation registers + vxorpd(zmm6, zmm6, zmm6) + vxorpd(zmm7, zmm7, zmm7) + vxorpd(zmm28, zmm28, zmm28) + vxorpd(zmm8, zmm8, zmm8) + vxorpd(zmm9, zmm9, zmm9) + vxorpd(zmm29, zmm29, zmm29) + vxorpd(zmm10, zmm10, zmm10) + vxorpd(zmm11, zmm11, zmm11) + vxorpd(zmm26, zmm26, zmm26) + vxorpd(zmm12, zmm12, zmm12) + vxorpd(zmm13, zmm13, zmm13) + vxorpd(zmm27,zmm27, zmm27) + vxorpd(zmm14, zmm14, zmm14) + vxorpd(zmm15, zmm15, zmm15) + vxorpd(zmm24, zmm24, zmm24) + vxorpd(zmm16, zmm16, zmm16) + vxorpd(zmm17, zmm17, zmm17) + vxorpd(zmm25, zmm25, zmm25) + + // K is unrolled by 8 to facilitate prefetch of B + // Assuming B to be col-stored, for each iteration of K, + //one cacheline of B_next is prefetched where b_next = b + (unroll)*rs_b + label(.DLOOPKITER) // main loop + mov(var(k_iter), rsi) // i = k_iter + sub(imm( 6+TAIL_NITER), rsi) // i -= NR + TAIL_NITER + jle(.PREFETCHLOOP) // jump if i <= 0 + + label(.LOOP1) + + // ---------------------------------- iteration 1 + + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm14 ) + vfmadd231pd( zmm1,zmm30,zmm15 ) + vfmadd231pd( zmm2,zmm30,zmm24 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm0,zmm31,zmm16 ) + vfmadd231pd( zmm1,zmm31,zmm17 ) + vfmadd231pd( zmm2,zmm31,zmm25 ) + + // ---------------------------------- iteration 2 + + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11,r9,1) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm5,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm5,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm5,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm5,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm24 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm4,zmm31,zmm17 ) + vfmadd231pd( zmm5,zmm31,zmm25 ) + + // ---------------------------------- iteration 3 + + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11,r9,2) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm14 ) + vfmadd231pd( zmm1,zmm30,zmm15 ) + vfmadd231pd( zmm2,zmm30,zmm24 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm0,zmm31,zmm16 ) + vfmadd231pd( zmm1,zmm31,zmm17 ) + vfmadd231pd( zmm2,zmm31,zmm25 ) + + // ---------------------------------- iteration 4 + + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11,r13,1) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm5,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm5,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm5,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm5,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm24 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm4,zmm31,zmm17 ) + vfmadd231pd( zmm5,zmm31,zmm25 ) + + // ---------------------------------- iteration 5 + + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r15) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm14 ) + vfmadd231pd( zmm1,zmm30,zmm15 ) + vfmadd231pd( zmm2,zmm30,zmm24 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm0,zmm31,zmm16 ) + vfmadd231pd( zmm1,zmm31,zmm17 ) + vfmadd231pd( zmm2,zmm31,zmm25 ) + + // ---------------------------------- iteration 6 + + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r15,r9,1) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm5,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm5,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm5,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm5,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm24 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm4,zmm31,zmm17 ) + vfmadd231pd( zmm5,zmm31,zmm25 ) + + // ---------------------------------- iteration 7 + + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm14 ) + vfmadd231pd( zmm1,zmm30,zmm15 ) + vfmadd231pd( zmm2,zmm30,zmm24 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm0,zmm31,zmm16 ) + vfmadd231pd( zmm1,zmm31,zmm17 ) + vfmadd231pd( zmm2,zmm31,zmm25 ) + + // ---------------------------------- iteration 8 + + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm5,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm5,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm5,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm5,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm24 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm4,zmm31,zmm17 ) + vfmadd231pd( zmm5,zmm31,zmm25 ) + lea(mem(r11,r8,8), r11) // b_next += 8*rs_b + lea(mem(r15,r8,8), r15) // second pointer to b_next += 8*rs_b + dec(rsi) // i -= 1 + jnz(.LOOP1) // iterate again if i != 0. + + label(.PREFETCHLOOP) + add(imm(6), rsi) // i += NR + jle(.TAILITER) // jump if i <= 0. + + label(.LOOP2) + + // ---------------------------------- iteration 1 + prefetchw0( mem(rdx)) // prefetch C + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm14 ) + vfmadd231pd( zmm1,zmm30,zmm15 ) + vfmadd231pd( zmm2,zmm30,zmm24 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm0,zmm31,zmm16 ) + vfmadd231pd( zmm1,zmm31,zmm17 ) + vfmadd231pd( zmm2,zmm31,zmm25 ) + + // ---------------------------------- iteration 2 + prefetchw0( mem(rdx, 64)) // prefetch C + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11,r9,1) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm5,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm5,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm5,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm5,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm24 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm4,zmm31,zmm17 ) + vfmadd231pd( zmm5,zmm31,zmm25 ) + + // ---------------------------------- iteration 3 + prefetchw0( mem(rdx, 128)) // prefetch C + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11,r9,2) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm14 ) + vfmadd231pd( zmm1,zmm30,zmm15 ) + vfmadd231pd( zmm2,zmm30,zmm24 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm0,zmm31,zmm16 ) + vfmadd231pd( zmm1,zmm31,zmm17 ) + vfmadd231pd( zmm2,zmm31,zmm25 ) + + // ---------------------------------- iteration 4 + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11,r13,1) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm5,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm5,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm5,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm5,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm24 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm4,zmm31,zmm17 ) + vfmadd231pd( zmm5,zmm31,zmm25 ) + + // ---------------------------------- iteration 5 + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r15) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm14 ) + vfmadd231pd( zmm1,zmm30,zmm15 ) + vfmadd231pd( zmm2,zmm30,zmm24 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm0,zmm31,zmm16 ) + vfmadd231pd( zmm1,zmm31,zmm17 ) + vfmadd231pd( zmm2,zmm31,zmm25 ) + + // ---------------------------------- iteration 6 + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r15,r9,1) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm5,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm5,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm5,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm5,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm24 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm4,zmm31,zmm17 ) + vfmadd231pd( zmm5,zmm31,zmm25 ) + + // ---------------------------------- iteration 7 + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm14 ) + vfmadd231pd( zmm1,zmm30,zmm15 ) + vfmadd231pd( zmm2,zmm30,zmm24 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm0,zmm31,zmm16 ) + vfmadd231pd( zmm1,zmm31,zmm17 ) + vfmadd231pd( zmm2,zmm31,zmm25 ) + + // ---------------------------------- iteration 8 + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm5,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm5,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm5,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm5,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm24 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm4,zmm31,zmm17 ) + vfmadd231pd( zmm5,zmm31,zmm25 ) + lea(mem(rdx, rdi, 1), rdx) // C += cs_c + lea(mem(r11,r8,8), r11) // b_next += 8*rs_b + lea(mem(r15,r8,8), r15) // second pointer of b_next += 8*rs_b + sub(imm(1), rsi) // i -= 1 + jnz(.LOOP2) // iterate again if i != 0. + label(.TAILITER) + add(imm(TAIL_NITER), rsi) // i += TAIL_NITER + jle(.TAIL) // jump if i <= 0 + + label(.LOOP3) + + // ---------------------------------- iteration 1 + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm14 ) + vfmadd231pd( zmm1,zmm30,zmm15 ) + vfmadd231pd( zmm2,zmm30,zmm24 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm0,zmm31,zmm16 ) + vfmadd231pd( zmm1,zmm31,zmm17 ) + vfmadd231pd( zmm2,zmm31,zmm25 ) + + // ---------------------------------- iteration 2 + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11,r9,1) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm5,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm5,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm5,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm5,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm24 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm4,zmm31,zmm17 ) + vfmadd231pd( zmm5,zmm31,zmm25 ) + + // ---------------------------------- iteration 3 + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11,r9,2) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm14 ) + vfmadd231pd( zmm1,zmm30,zmm15 ) + vfmadd231pd( zmm2,zmm30,zmm24 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm0,zmm31,zmm16 ) + vfmadd231pd( zmm1,zmm31,zmm17 ) + vfmadd231pd( zmm2,zmm31,zmm25 ) + + // ---------------------------------- iteration 4 + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11,r13,1) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm5,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm5,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm5,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm5,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm24 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm4,zmm31,zmm17 ) + vfmadd231pd( zmm5,zmm31,zmm25 ) + + // ---------------------------------- iteration 5 + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r15) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm14 ) + vfmadd231pd( zmm1,zmm30,zmm15 ) + vfmadd231pd( zmm2,zmm30,zmm24 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm0,zmm31,zmm16 ) + vfmadd231pd( zmm1,zmm31,zmm17 ) + vfmadd231pd( zmm2,zmm31,zmm25 ) + + // ---------------------------------- iteration 6 + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r15,r9,1) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm5,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm5,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm5,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm5,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm24 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm4,zmm31,zmm17 ) + vfmadd231pd( zmm5,zmm31,zmm25 ) + + // ---------------------------------- iteration 7 + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm14 ) + vfmadd231pd( zmm1,zmm30,zmm15 ) + vfmadd231pd( zmm2,zmm30,zmm24 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm0,zmm31,zmm16 ) + vfmadd231pd( zmm1,zmm31,zmm17 ) + vfmadd231pd( zmm2,zmm31,zmm25 ) + + // ---------------------------------- iteration 8 + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm5,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm5,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm5,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm5,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm24 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm4,zmm31,zmm17 ) + vfmadd231pd( zmm5,zmm31,zmm25 ) + lea(mem(r11,r8,8), r11) // b_next += 8*rs_b + lea(mem(r15,r8,8), r15) // Second pointer of b_next += 8*rs_b + dec(rsi) // i -= 1 + jnz(.LOOP3) // iterate again if i != 0. + + + label(.TAIL) + mov(var(k_left), rsi) // i = k_left + test(rsi, rsi) // check i via logical AND + je(.DPOSTACCUM) // if i == 0, jump to post-accumulation + + label(.DLOOPKLEFT) // k_left loop + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + vbroadcastsd( mem(r12,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm14 ) + vfmadd231pd( zmm1,zmm30,zmm15 ) + vfmadd231pd( zmm2,zmm30,zmm24 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm0,zmm31,zmm16 ) + vfmadd231pd( zmm1,zmm31,zmm17 ) + vfmadd231pd( zmm2,zmm31,zmm25 ) + dec(rsi) // i -= 1 + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + label(.DPOSTACCUM) + mov(var(alpha), rdx) // load address of alpha + vbroadcastsd(mem(rdx), zmm30) // broadcast alpha + mov(var(beta), rax) // load address of beta + vbroadcastsd(mem(rax), zmm31) // broadcast beta + + // scale by alpha + vmulpd( zmm30,zmm6,zmm6 ) + vmulpd( zmm30,zmm7,zmm7 ) + vmulpd( zmm30,zmm28,zmm28 ) + vmulpd( zmm30,zmm8,zmm8 ) + vmulpd( zmm30,zmm9,zmm9 ) + vmulpd( zmm30,zmm29,zmm29 ) + vmulpd( zmm30,zmm10,zmm10 ) + vmulpd( zmm30,zmm11,zmm11 ) + vmulpd( zmm30,zmm26,zmm26 ) + vmulpd( zmm30,zmm12,zmm12 ) + vmulpd( zmm30,zmm13,zmm13 ) + vmulpd( zmm30,zmm27,zmm27 ) + vmulpd( zmm30,zmm14,zmm14 ) + vmulpd( zmm30,zmm15,zmm15 ) + vmulpd( zmm30,zmm24,zmm24 ) + vmulpd( zmm30,zmm16,zmm16 ) + vmulpd( zmm30,zmm17,zmm17 ) + vmulpd( zmm30,zmm25,zmm25 ) + + + mov(var(rs_c), rsi) // load rs_c + lea(mem(, rsi, 8), rsi) // rsi = rs_c * sizeof(double) + lea(mem(rcx, rdi, 4), rdx) // rdx = rcx + 4 * cs_c + lea(mem(rdi, rdi, 2), r13) // r13 = 3*cs_c + vxorpd(ymm2, ymm2, ymm2) + vucomisd(xmm2, xmm31) // set ZF if beta == 0 + je(.DBETAZERO) // if ZF == 1, jump to beta == 0 case + + + cmp(imm(8), rdi) // set ZF if (8*cs_c) == 8 + + + jz(.DROWSTORED) // jump to row storage case + + label(.DCOLSTORED) + vfmadd231pd( mem(rcx),zmm31,zmm6) + vmovupd( zmm6,(rcx)) + vfmadd231pd( 0x40(rcx),zmm31,zmm7) + vmovupd( zmm7,0x40(rcx)) + vfmadd231pd( 0x80(rcx),zmm31,zmm28) + vmovupd( zmm28,0x80(rcx)) + vfmadd231pd( mem(rcx,rdi,1),zmm31,zmm8) + vmovupd( zmm8,(rcx,rdi,1)) + vfmadd231pd( 0x40(rcx,rdi,1),zmm31,zmm9) + vmovupd( zmm9,0x40(rcx,rdi,1)) + vfmadd231pd( 0x80(rcx,rdi,1),zmm31,zmm29) + vmovupd( zmm29,0x80(rcx,rdi,1)) + vfmadd231pd( mem(rcx,rdi,2),zmm31,zmm10) + vmovupd( zmm10,(rcx,rdi,2)) + vfmadd231pd( 0x40(rcx,rdi,2),zmm31,zmm11) + vmovupd( zmm11,0x40(rcx,rdi,2)) + vfmadd231pd( 0x80(rcx,rdi,2),zmm31,zmm26) + vmovupd( zmm26,0x80(rcx,rdi,2)) + vfmadd231pd( mem(rcx,r13,1),zmm31,zmm12) + vmovupd( zmm12,(rcx,r13,1)) + vfmadd231pd( 0x40(rcx,r13,1),zmm31,zmm13) + vmovupd( zmm13,0x40(rcx,r13,1)) + vfmadd231pd( 0x80(rcx,r13,1),zmm31,zmm27) + vmovupd( zmm27,0x80(rcx,r13,1)) + vfmadd231pd( mem(rdx),zmm31,zmm14) + vmovupd( zmm14,(rdx)) + vfmadd231pd( 0x40(rdx),zmm31,zmm15) + vmovupd( zmm15,0x40(rdx)) + vfmadd231pd( 0x80(rdx),zmm31,zmm24) + vmovupd( zmm24,0x80(rdx)) + vfmadd231pd( mem(rdx,rdi,1),zmm31,zmm16) + vmovupd( zmm16,(rdx,rdi,1)) + vfmadd231pd( 0x40(rdx,rdi,1),zmm31,zmm17) + vmovupd( zmm17,0x40(rdx,rdi,1)) + vfmadd231pd( 0x80(rdx,rdi,1),zmm31,zmm25) + vmovupd( zmm25,0x80(rdx,rdi,1)) + + jmp(.DDONE) // jump to end. + + label(.DROWSTORED) + // r12 = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // r13 = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) + + vunpcklpd(zmm16, zmm14, zmm0) + vunpckhpd(zmm16, zmm14, zmm1) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + vbroadcastsd(mem(rax), zmm31) + UPDATE_MASKED_C + //First 8x6 tile updated + + UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + vunpcklpd(zmm17, zmm15, zmm0) + vunpckhpd(zmm17, zmm15, zmm1) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C + //Second 8x6 tile updated + + UNPACK_LO_HIGH(29, 28, 0, 1, 27, 26, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + vunpcklpd(zmm25, zmm24, zmm0) + vunpckhpd(zmm25, zmm24, zmm1) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C + //Third 8x6 tile updated + jmp(.DDONE) // jump to end. + + + label(.DBETAZERO) + cmp(imm(8), rdi) // set ZF if (8*cs_c) == 8 + + jz(.DROWSTORBZ) // jump to row storage case + label(.DCOLSTORBZ) + vmovupd( zmm6,(rcx)) + vmovupd( zmm7,0x40(rcx)) + vmovupd( zmm28,0x80(rcx)) + vmovupd( zmm8,(rcx,rdi,1)) + vmovupd( zmm9,0x40(rcx,rdi,1)) + vmovupd( zmm29,0x80(rcx,rdi,1)) + vmovupd( zmm10,(rcx,rdi,2)) + vmovupd( zmm11,0x40(rcx,rdi,2)) + vmovupd( zmm26,0x80(rcx,rdi,2)) + vmovupd( zmm12,(rcx,r13,1)) + vmovupd( zmm13,0x40(rcx,r13,1)) + vmovupd( zmm27,0x80(rcx,r13,1)) + vmovupd( zmm14,(rdx)) + vmovupd( zmm15,0x40(rdx)) + vmovupd( zmm24,0x80(rdx)) + vmovupd( zmm16,(rdx,rdi,1)) + vmovupd( zmm17,0x40(rdx,rdi,1)) + vmovupd( zmm25,0x80(rdx,rdi,1)) + + jmp(.DDONE) // jump to end. + + + label(.DROWSTORBZ) + // r12 = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // r13 = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) + + vunpcklpd(zmm16, zmm14, zmm0) + vunpckhpd(zmm16, zmm14, zmm1) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + UPDATE_MASKED_C_BZ + //First 8x6 tile updated + + UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + vunpcklpd(zmm17, zmm15, zmm0) + vunpckhpd(zmm17, zmm15, zmm1) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C_BZ + //Second 8x6 tile updated + + UNPACK_LO_HIGH(29, 28, 0, 1, 27, 26, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + vunpcklpd(zmm25, zmm24, zmm0) + vunpckhpd(zmm25, zmm24, zmm1) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C_BZ + //Third 8x6 tile updated + label(.DDONE) + + + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a8] "m" (ps_a8), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + [mask] "m" (mask) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm2", "xmm31", + "ymm2", + "zmm0", "zmm1", "zmm2", "zmm3", + "zmm4", "zmm5", "zmm6", "zmm7", "zmm8", "zmm9", "zmm10", + "zmm11", "zmm12", "zmm13", "zmm14", "zmm15", + "zmm16", "zmm17", "zmm18", "zmm19", + "zmm20", "zmm21", "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", + "zmm27", "zmm28", "zmm29", "zmm30", "zmm31", + "k2", "memory" + ) + } //mloop + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if (m_left) + { + const dim_t nr_cur = 6; + const dim_t i_edge = m0 - ( dim_t )m_left; + double *restrict cij = cbuf + i_edge * rs_c; + double *restrict ai = abuf + m_iter * ps_a; + double *restrict bj = bbuf; + // covers the range 16 < m_left <= 24 by using masked load/store instructions + if( 16 < m_left ) + { + bli_dgemmsup_rv_zen4_asm_24x6( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx); + } + // covers the range 8 < m_left <= 16 by using masked load/store instructions + else if( 8 < m_left ) + { + bli_dgemmsup_rv_zen4_asm_16x6( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx); + } + // covers the range 0 < m_left <= 8 by using masked load/store instructions + else if( 0 < m_left ) + { + bli_dgemmsup_rv_zen4_asm_8x6( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx); + } + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); +} + +void bli_dgemmsup_rv_zen5_asm_24x5m +( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + double *abuf = a; + double *bbuf = b; + double *cbuf = c; + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t m_iter = (uint64_t)m0 / 24; + uint64_t m_left = (uint64_t)m0 % 24; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_a8 = ps_a * sizeof( double ); + + uint64_t k_iter = (uint64_t)k0 / 8; + uint64_t k_left = (uint64_t)k0 % 8; + + uint8_t mask = (0xff >> (0x8 - (n0 & 7))); // calculate mask based on n_left + + if ( m_iter == 0 ) goto consider_edge_cases; + + /* For one iteration of this loop, a block of MRxNR is computed + * This loop moves along m-dimension of c matrix with steps of MR*rs_c. + */ + for(dim_t m=0; m < m_iter; m++) + { + + a = abuf + m * ps_a ; // Move to next MRXKC in MCXKC (where MC>=MR) + b = bbuf; //Same KCXNR is used across different MRXKC in MCXKC + c = cbuf + m * rs_c * 24; // Move to next MRxNR in MCxNR (where MC >= MR) + + // ------------------------------------------------------------------------- + begin_asm() + + mov(var(mask), rdx) // load mask + kmovw(edx, k(2)) // move mask to k2 register + mov(var(a), rax) // load address of a + mov(var(cs_a), r10) // load cs_a + mov(var(b), rbx) // load address of b + mov(var(rs_b), r8) // load rs_b + mov(var(cs_b), r9) // load cs_b + mov(var(c), rcx) // load address of c + mov(var(cs_c), rdi) // load cs_c + lea(mem(, r8, 8), r8) // rs_b *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_b *= sizeof(double) + lea(mem(, r10, 8), r10) // cs_a *= sizeof(double) + lea(mem(, rdi, 8), rdi) // cs_c *= sizeof(double) + lea(mem(r9, r9, 2 ), r13) // r13 = 3*cs_b + // if n > 4, a second pointer(r12) which points to rbx + 4*cs_b + //is also used to traverse B matrix + lea(mem(rbx, r9, 4), r12) // r12 = rbx + 4*cs_b + lea(mem(rcx, 7*8), rdx) // C for prefetching + mov(r10, r14) // col stride of A + lea(mem(rax, r14, 2, 7*8), r14) // r14 = rax + 2*cs_a(A for prefetching) + lea(mem(rbx, r8, 8, 7*8), r11) // r11 = rbx + 8*rs_b(B for prefetching) + // if n > 4, a second pointer which point to r11 + 4*cs_b + //is also used to prefetch from B matrix + lea(mem(r11, r9, 4), r15) // r15 = r11 + 4* cs_b(B for prefetching) + + /* Register usage: zmm0-5 are used to load A matrix + * zmm6-29 are used for accumulation + * zmm30-31 are used for broadcasting B matrix + */ + + // zero out all accumulation registers + vxorpd(zmm6, zmm6, zmm6) + vxorpd(zmm7, zmm7, zmm7) + vxorpd(zmm28, zmm28, zmm28) + vxorpd(zmm8, zmm8, zmm8) + vxorpd(zmm9, zmm9, zmm9) + vxorpd(zmm29, zmm29, zmm29) + vxorpd(zmm10, zmm10, zmm10) + vxorpd(zmm11, zmm11, zmm11) + vxorpd(zmm26, zmm26, zmm26) + vxorpd(zmm12, zmm12, zmm12) + vxorpd(zmm13, zmm13, zmm13) + vxorpd(zmm27,zmm27, zmm27) + vxorpd(zmm14, zmm14, zmm14) + vxorpd(zmm15, zmm15, zmm15) + vxorpd(zmm24, zmm24, zmm24) + + // K is unrolled by 8 to facilitate prefetch of B + // Assuming B to be col-stored, for each iteration of K, + //one cacheline of B_next is prefetched where b_next = b + (unroll)*rs_b + label(.DLOOPKITER) // main loop + mov(var(k_iter), rsi) // i = k_iter + sub(imm( 5+TAIL_NITER), rsi) // i -= NR + TAIL_NITER + jle(.PREFETCHLOOP) // jump if i <= 0 + + label(.LOOP1) + + // ---------------------------------- iteration 1 + + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm0,zmm30,zmm14 ) + vfmadd231pd( zmm1,zmm30,zmm15 ) + vfmadd231pd( zmm2,zmm30,zmm24 ) + + // ---------------------------------- iteration 2 + + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11,r9,1) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm5,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm5,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm5,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm5,zmm31,zmm27 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm24 ) + + // ---------------------------------- iteration 3 + + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11,r9,2) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm0,zmm30,zmm14 ) + vfmadd231pd( zmm1,zmm30,zmm15 ) + vfmadd231pd( zmm2,zmm30,zmm24 ) + + // ---------------------------------- iteration 4 + + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11,r13,1) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm5,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm5,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm5,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm5,zmm31,zmm27 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm24 ) + + // ---------------------------------- iteration 5 + + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r15) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm0,zmm30,zmm14 ) + vfmadd231pd( zmm1,zmm30,zmm15 ) + vfmadd231pd( zmm2,zmm30,zmm24 ) + + // ---------------------------------- iteration 6 + + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm5,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm5,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm5,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm5,zmm31,zmm27 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm24 ) + + // ---------------------------------- iteration 7 + + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm0,zmm30,zmm14 ) + vfmadd231pd( zmm1,zmm30,zmm15 ) + vfmadd231pd( zmm2,zmm30,zmm24 ) + + // ---------------------------------- iteration 8 + + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm5,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm5,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm5,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm5,zmm31,zmm27 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm24 ) + lea(mem(r11,r8,8), r11) // b_next += 8*rs_b + lea(mem(r15,r8,8), r15) // second pointer to b_next += 8*rs_b + dec(rsi) // i -= 1 + jnz(.LOOP1) // iterate again if i != 0. + + label(.PREFETCHLOOP) + add(imm(5), rsi) // i += NR + jle(.TAILITER) // jump if i <= 0. + + label(.LOOP2) + + // ---------------------------------- iteration 1 + prefetchw0( mem(rdx)) // prefetch C + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm0,zmm30,zmm14 ) + vfmadd231pd( zmm1,zmm30,zmm15 ) + vfmadd231pd( zmm2,zmm30,zmm24 ) + + // ---------------------------------- iteration 2 + prefetchw0( mem(rdx, 64)) // prefetch C + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11,r9,1) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm5,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm5,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm5,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm5,zmm31,zmm27 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm24 ) + + // ---------------------------------- iteration 3 + prefetchw0( mem(rdx, 128)) // prefetch C + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11,r9,2) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm0,zmm30,zmm14 ) + vfmadd231pd( zmm1,zmm30,zmm15 ) + vfmadd231pd( zmm2,zmm30,zmm24 ) + + // ---------------------------------- iteration 4 + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11,r13,1) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm5,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm5,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm5,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm5,zmm31,zmm27 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm24 ) + + // ---------------------------------- iteration 5 + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r15) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm0,zmm30,zmm14 ) + vfmadd231pd( zmm1,zmm30,zmm15 ) + vfmadd231pd( zmm2,zmm30,zmm24 ) + + // ---------------------------------- iteration 6 + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm5,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm5,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm5,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm5,zmm31,zmm27 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm24 ) + + // ---------------------------------- iteration 7 + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm0,zmm30,zmm14 ) + vfmadd231pd( zmm1,zmm30,zmm15 ) + vfmadd231pd( zmm2,zmm30,zmm24 ) + + // ---------------------------------- iteration 8 + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm5,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm5,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm5,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm5,zmm31,zmm27 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm24 ) + lea(mem(rdx, rdi, 1), rdx) // C += cs_c + lea(mem(r11,r8,8), r11) // b_next += 8*rs_b + lea(mem(r15,r8,8), r15) // second pointer of b_next += 8*rs_b + sub(imm(1), rsi) // i -= 1 + jnz(.LOOP2) // iterate again if i != 0. + label(.TAILITER) + add(imm(TAIL_NITER), rsi) // i += TAIL_NITER + jle(.TAIL) // jump if i <= 0 + + label(.LOOP3) + + // ---------------------------------- iteration 1 + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm0,zmm30,zmm14 ) + vfmadd231pd( zmm1,zmm30,zmm15 ) + vfmadd231pd( zmm2,zmm30,zmm24 ) + + // ---------------------------------- iteration 2 + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11,r9,1) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm5,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm5,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm5,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm5,zmm31,zmm27 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm24 ) + + // ---------------------------------- iteration 3 + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11,r9,2) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm0,zmm30,zmm14 ) + vfmadd231pd( zmm1,zmm30,zmm15 ) + vfmadd231pd( zmm2,zmm30,zmm24 ) + + // ---------------------------------- iteration 4 + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11,r13,1) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm5,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm5,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm5,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm5,zmm31,zmm27 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm24 ) + + // ---------------------------------- iteration 5 + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r15) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm0,zmm30,zmm14 ) + vfmadd231pd( zmm1,zmm30,zmm15 ) + vfmadd231pd( zmm2,zmm30,zmm24 ) + + // ---------------------------------- iteration 6 + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm5,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm5,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm5,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm5,zmm31,zmm27 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm24 ) + + // ---------------------------------- iteration 7 + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm0,zmm30,zmm14 ) + vfmadd231pd( zmm1,zmm30,zmm15 ) + vfmadd231pd( zmm2,zmm30,zmm24 ) + + // ---------------------------------- iteration 8 + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm5,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm5,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm5,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm5,zmm31,zmm27 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm24 ) + lea(mem(r11,r8,8), r11) // b_next += 8*rs_b + lea(mem(r15,r8,8), r15) // Second pointer of b_next += 8*rs_b + dec(rsi) // i -= 1 + jnz(.LOOP3) // iterate again if i != 0. + + + label(.TAIL) + mov(var(k_left), rsi) // i = k_left + test(rsi, rsi) // check i via logical AND + je(.DPOSTACCUM) // if i == 0, jump to post-accumulation + + label(.DLOOPKLEFT) // k_left loop + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + vbroadcastsd( mem(r12),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + add( r8,r12 ) // second pointer of b += rs_b + vfmadd231pd( zmm0,zmm30,zmm14 ) + vfmadd231pd( zmm1,zmm30,zmm15 ) + vfmadd231pd( zmm2,zmm30,zmm24 ) + dec(rsi) // i -= 1 + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + label(.DPOSTACCUM) + mov(var(alpha), rdx) // load address of alpha + vbroadcastsd(mem(rdx), zmm30) // broadcast alpha + mov(var(beta), rax) // load address of beta + vbroadcastsd(mem(rax), zmm31) // broadcast beta + + // scale by alpha + vmulpd( zmm30,zmm6,zmm6 ) + vmulpd( zmm30,zmm7,zmm7 ) + vmulpd( zmm30,zmm28,zmm28 ) + vmulpd( zmm30,zmm8,zmm8 ) + vmulpd( zmm30,zmm9,zmm9 ) + vmulpd( zmm30,zmm29,zmm29 ) + vmulpd( zmm30,zmm10,zmm10 ) + vmulpd( zmm30,zmm11,zmm11 ) + vmulpd( zmm30,zmm26,zmm26 ) + vmulpd( zmm30,zmm12,zmm12 ) + vmulpd( zmm30,zmm13,zmm13 ) + vmulpd( zmm30,zmm27,zmm27 ) + vmulpd( zmm30,zmm14,zmm14 ) + vmulpd( zmm30,zmm15,zmm15 ) + vmulpd( zmm30,zmm24,zmm24 ) + + + mov(var(rs_c), rsi) // load rs_c + lea(mem(, rsi, 8), rsi) // rsi = rs_c * sizeof(double) + lea(mem(rcx, rdi, 4), rdx) // rdx = rcx + 4 * cs_c + lea(mem(rdi, rdi, 2), r13) // r13 = 3*cs_c + vxorpd(ymm2, ymm2, ymm2) + vucomisd(xmm2, xmm31) // set ZF if beta == 0 + je(.DBETAZERO) // if ZF == 1, jump to beta == 0 case + + + cmp(imm(8), rdi) // set ZF if (8*cs_c) == 8 + + + jz(.DROWSTORED) // jump to row storage case + + label(.DCOLSTORED) + vfmadd231pd( mem(rcx),zmm31,zmm6) + vmovupd( zmm6,(rcx)) + vfmadd231pd( 0x40(rcx),zmm31,zmm7) + vmovupd( zmm7,0x40(rcx)) + vfmadd231pd( 0x80(rcx),zmm31,zmm28) + vmovupd( zmm28,0x80(rcx)) + vfmadd231pd( mem(rcx,rdi,1),zmm31,zmm8) + vmovupd( zmm8,(rcx,rdi,1)) + vfmadd231pd( 0x40(rcx,rdi,1),zmm31,zmm9) + vmovupd( zmm9,0x40(rcx,rdi,1)) + vfmadd231pd( 0x80(rcx,rdi,1),zmm31,zmm29) + vmovupd( zmm29,0x80(rcx,rdi,1)) + vfmadd231pd( mem(rcx,rdi,2),zmm31,zmm10) + vmovupd( zmm10,(rcx,rdi,2)) + vfmadd231pd( 0x40(rcx,rdi,2),zmm31,zmm11) + vmovupd( zmm11,0x40(rcx,rdi,2)) + vfmadd231pd( 0x80(rcx,rdi,2),zmm31,zmm26) + vmovupd( zmm26,0x80(rcx,rdi,2)) + vfmadd231pd( mem(rcx,r13,1),zmm31,zmm12) + vmovupd( zmm12,(rcx,r13,1)) + vfmadd231pd( 0x40(rcx,r13,1),zmm31,zmm13) + vmovupd( zmm13,0x40(rcx,r13,1)) + vfmadd231pd( 0x80(rcx,r13,1),zmm31,zmm27) + vmovupd( zmm27,0x80(rcx,r13,1)) + vfmadd231pd( mem(rdx),zmm31,zmm14) + vmovupd( zmm14,(rdx)) + vfmadd231pd( 0x40(rdx),zmm31,zmm15) + vmovupd( zmm15,0x40(rdx)) + vfmadd231pd( 0x80(rdx),zmm31,zmm24) + vmovupd( zmm24,0x80(rdx)) + + jmp(.DDONE) // jump to end. + + label(.DROWSTORED) + // r12 = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // r13 = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) + + vunpcklpd(zmm16, zmm14, zmm0) + vunpckhpd(zmm16, zmm14, zmm1) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + vbroadcastsd(mem(rax), zmm31) + UPDATE_MASKED_C + //First 8x5 tile updated + + UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + vunpcklpd(zmm17, zmm15, zmm0) + vunpckhpd(zmm17, zmm15, zmm1) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C + //Second 8x5 tile updated + + UNPACK_LO_HIGH(29, 28, 0, 1, 27, 26, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + vunpcklpd(zmm25, zmm24, zmm0) + vunpckhpd(zmm25, zmm24, zmm1) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C + //Third 8x5 tile updated + jmp(.DDONE) // jump to end. + + + label(.DBETAZERO) + cmp(imm(8), rdi) // set ZF if (8*cs_c) == 8 + + jz(.DROWSTORBZ) // jump to row storage case + label(.DCOLSTORBZ) + vmovupd( zmm6,(rcx)) + vmovupd( zmm7,0x40(rcx)) + vmovupd( zmm28,0x80(rcx)) + vmovupd( zmm8,(rcx,rdi,1)) + vmovupd( zmm9,0x40(rcx,rdi,1)) + vmovupd( zmm29,0x80(rcx,rdi,1)) + vmovupd( zmm10,(rcx,rdi,2)) + vmovupd( zmm11,0x40(rcx,rdi,2)) + vmovupd( zmm26,0x80(rcx,rdi,2)) + vmovupd( zmm12,(rcx,r13,1)) + vmovupd( zmm13,0x40(rcx,r13,1)) + vmovupd( zmm27,0x80(rcx,r13,1)) + vmovupd( zmm14,(rdx)) + vmovupd( zmm15,0x40(rdx)) + vmovupd( zmm24,0x80(rdx)) + + jmp(.DDONE) // jump to end. + + + label(.DROWSTORBZ) + // r12 = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // r13 = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) + + vunpcklpd(zmm16, zmm14, zmm0) + vunpckhpd(zmm16, zmm14, zmm1) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + vbroadcastsd(mem(rax), zmm31) + UPDATE_MASKED_C_BZ + //First 8x5 tile updated + + UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + vunpcklpd(zmm17, zmm15, zmm0) + vunpckhpd(zmm17, zmm15, zmm1) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C_BZ + //Second 8x5 tile updated + + UNPACK_LO_HIGH(29, 28, 0, 1, 27, 26, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + vunpcklpd(zmm25, zmm24, zmm0) + vunpckhpd(zmm25, zmm24, zmm1) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C_BZ + //Third 8x5 tile updated + label(.DDONE) + + + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a8] "m" (ps_a8), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + [mask] "m" (mask) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm2", "xmm31", + "ymm2", + "zmm0", "zmm1", "zmm2", "zmm3", + "zmm4", "zmm5", "zmm6", "zmm7", "zmm8", "zmm9", "zmm10", + "zmm11", "zmm12", "zmm13", "zmm14", "zmm15", + "zmm16", "zmm17", "zmm18", "zmm19", + "zmm20", "zmm21", "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", + "zmm27", "zmm28", "zmm29", "zmm30", "zmm31", + "k2", "memory" + ) + } //mloop + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if (m_left) + { + const dim_t nr_cur = 5; + const dim_t i_edge = m0 - ( dim_t )m_left; + double *restrict cij = cbuf + i_edge * rs_c; + double *restrict ai = abuf + m_iter * ps_a; + double *restrict bj = bbuf; + // covers the range 16 < m_left <= 24 by using masked load/store instructions + if( 16 < m_left ) + { + bli_dgemmsup_rv_zen4_asm_24x5( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx); + } + // covers the range 8 < m_left <= 16 by using masked load/store instructions + else if( 8 < m_left ) + { + bli_dgemmsup_rv_zen4_asm_16x5( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx); + } + // covers the range 0 < m_left <= 8 by using masked load/store instructions + else if( 0 < m_left ) + { + bli_dgemmsup_rv_zen4_asm_8x5( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx); + } + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); +} + +void bli_dgemmsup_rv_zen5_asm_24x4m +( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + double *abuf = a; + double *bbuf = b; + double *cbuf = c; + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t m_iter = (uint64_t)m0 / 24; + uint64_t m_left = (uint64_t)m0 % 24; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_a8 = ps_a * sizeof( double ); + + uint64_t k_iter = (uint64_t)k0 / 8; + uint64_t k_left = (uint64_t)k0 % 8; + + uint8_t mask = (0xff >> (0x8 - (n0 & 7))); // calculate mask based on n_left + + if ( m_iter == 0 ) goto consider_edge_cases; + + /* For one iteration of this loop, a block of MRxNR is computed + * This loop moves along m-dimension of c matrix with steps of MR*rs_c. + */ + for(dim_t m=0; m < m_iter; m++) + { + + a = abuf + m * ps_a ; // Move to next MRXKC in MCXKC (where MC>=MR) + b = bbuf; //Same KCXNR is used across different MRXKC in MCXKC + c = cbuf + m * rs_c * 24; // Move to next MRxNR in MCxNR (where MC >= MR) + + // ------------------------------------------------------------------------- + begin_asm() + + mov(var(mask), rdx) // load mask + kmovw(edx, k(2)) // move mask to k2 register + mov(var(a), rax) // load address of a + mov(var(cs_a), r10) // load cs_a + mov(var(b), rbx) // load address of b + mov(var(rs_b), r8) // load rs_b + mov(var(cs_b), r9) // load cs_b + mov(var(c), rcx) // load address of c + mov(var(cs_c), rdi) // load cs_c + lea(mem(, r8, 8), r8) // rs_b *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_b *= sizeof(double) + lea(mem(, r10, 8), r10) // cs_a *= sizeof(double) + lea(mem(, rdi, 8), rdi) // cs_c *= sizeof(double) + lea(mem(r9, r9, 2 ), r13) // r13 = 3*cs_b + lea(mem(rcx, 7*8), rdx) // C for prefetching + mov(r10, r14) // col stride of A + lea(mem(rax, r14, 2, 7*8), r14) // r14 = rax + 2*cs_a(A for prefetching) + lea(mem(rbx, r8, 8, 7*8), r11) // r11 = rbx + 8*rs_b(B for prefetching) + + /* Register usage: zmm0-5 are used to load A matrix + * zmm6-29 are used for accumulation + * zmm30-31 are used for broadcasting B matrix + */ + + // zero out all accumulation registers + vxorpd(zmm6, zmm6, zmm6) + vxorpd(zmm7, zmm7, zmm7) + vxorpd(zmm28, zmm28, zmm28) + vxorpd(zmm8, zmm8, zmm8) + vxorpd(zmm9, zmm9, zmm9) + vxorpd(zmm29, zmm29, zmm29) + vxorpd(zmm10, zmm10, zmm10) + vxorpd(zmm11, zmm11, zmm11) + vxorpd(zmm26, zmm26, zmm26) + vxorpd(zmm12, zmm12, zmm12) + vxorpd(zmm13, zmm13, zmm13) + vxorpd(zmm27,zmm27, zmm27) + vxorpd(zmm14, zmm14, zmm14) + vxorpd(zmm15, zmm15, zmm15) + vxorpd(zmm16, zmm16, zmm16) + vxorpd(zmm17, zmm17, zmm17) + vxorpd(zmm18, zmm18, zmm18) + vxorpd(zmm19, zmm19, zmm19) + vxorpd(zmm20, zmm20, zmm20) + vxorpd(zmm21, zmm21, zmm21) + vxorpd(zmm22, zmm22, zmm22) + vxorpd(zmm23, zmm23, zmm23) + vxorpd(zmm24, zmm24, zmm24) + vxorpd(zmm25, zmm25, zmm25) + + // K is unrolled by 8 to facilitate prefetch of B + // Assuming B to be col-stored, for each iteration of K, + //one cacheline of B_next is prefetched where b_next = b + (unroll)*rs_b + label(.DLOOPKITER) // main loop + mov(var(k_iter), rsi) // i = k_iter + sub(imm( 4+TAIL_NITER), rsi) // i -= NR + TAIL_NITER + jle(.PREFETCHLOOP) // jump if i <= 0 + + label(.LOOP1) + /** + * This edge kernel uses two separate vector register bank + * to hold fma result. + * Once the K loop is completed these two vector register banks + * are added together and final result is available in one + * register bank. + * Here odd iterations uses vector register zmm6, zmm7, zmm28, + * zmm8, zmm9, zmm29, zmm10, zmm11, zmm26, zmm12, zmm13, zmm27 + * to hold fma result. + * While even iterations uses zmm14, zmm15, zmm16, zmm17, zmm18 + * zmm19, zmm20, zmm21, zmm22, zmm23, zmm24, zmm25 to hold fma + * result. + * At the end of K loop, these two banks are added together and + * final result is available in vector register zmm6, zmm7, zmm28, + * zmm8, zmm9, zmm29, zmm10, zmm11, zmm26, zmm12, zmm13, zmm27. + */ + // ---------------------------------- iteration 1 + + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + + // ---------------------------------- iteration 2 + + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11,r9,1) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm23 ) + vfmadd231pd( zmm4,zmm31,zmm24 ) + vfmadd231pd( zmm5,zmm31,zmm25 ) + + // ---------------------------------- iteration 3 + + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11,r9,2) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + + // ---------------------------------- iteration 4 + + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11,r13,1) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm23 ) + vfmadd231pd( zmm4,zmm31,zmm24 ) + vfmadd231pd( zmm5,zmm31,zmm25 ) + + // ---------------------------------- iteration 5 + + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + + // ---------------------------------- iteration 6 + + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm23 ) + vfmadd231pd( zmm4,zmm31,zmm24 ) + vfmadd231pd( zmm5,zmm31,zmm25 ) + + // ---------------------------------- iteration 7 + + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + + // ---------------------------------- iteration 8 + + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm23 ) + vfmadd231pd( zmm4,zmm31,zmm24 ) + vfmadd231pd( zmm5,zmm31,zmm25 ) + lea(mem(r11,r8,8), r11) // b_next += 8*rs_b + dec(rsi) // i -= 1 + jnz(.LOOP1) // iterate again if i != 0. + + label(.PREFETCHLOOP) + add(imm(4), rsi) // i += NR + jle(.TAILITER) // jump if i <= 0. + + label(.LOOP2) + + // ---------------------------------- iteration 1 + prefetchw0( mem(rdx)) // prefetch C + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + + // ---------------------------------- iteration 2 + prefetchw0( mem(rdx, 64)) // prefetch C + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11,r9,1) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm23 ) + vfmadd231pd( zmm4,zmm31,zmm24 ) + vfmadd231pd( zmm5,zmm31,zmm25 ) + + // ---------------------------------- iteration 3 + prefetchw0( mem(rdx, 128)) // prefetch C + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11,r9,2) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + + // ---------------------------------- iteration 4 + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11,r13,1) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm23 ) + vfmadd231pd( zmm4,zmm31,zmm24 ) + vfmadd231pd( zmm5,zmm31,zmm25 ) + + // ---------------------------------- iteration 5 + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + + // ---------------------------------- iteration 6 + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm23 ) + vfmadd231pd( zmm4,zmm31,zmm24 ) + vfmadd231pd( zmm5,zmm31,zmm25 ) + + // ---------------------------------- iteration 7 + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + + // ---------------------------------- iteration 8 + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm23 ) + vfmadd231pd( zmm4,zmm31,zmm24 ) + vfmadd231pd( zmm5,zmm31,zmm25 ) + lea(mem(rdx, rdi, 1), rdx) // C += cs_c + lea(mem(r11,r8,8), r11) // b_next += 8*rs_b + sub(imm(1), rsi) // i -= 1 + jnz(.LOOP2) // iterate again if i != 0. + label(.TAILITER) + add(imm(TAIL_NITER), rsi) // i += TAIL_NITER + jle(.TAIL) // jump if i <= 0 + + label(.LOOP3) + + // ---------------------------------- iteration 1 + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + + // ---------------------------------- iteration 2 + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11,r9,1) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm23 ) + vfmadd231pd( zmm4,zmm31,zmm24 ) + vfmadd231pd( zmm5,zmm31,zmm25 ) + + // ---------------------------------- iteration 3 + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11,r9,2) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + + // ---------------------------------- iteration 4 + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11,r13,1) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm23 ) + vfmadd231pd( zmm4,zmm31,zmm24 ) + vfmadd231pd( zmm5,zmm31,zmm25 ) + + // ---------------------------------- iteration 5 + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + + // ---------------------------------- iteration 6 + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm23 ) + vfmadd231pd( zmm4,zmm31,zmm24 ) + vfmadd231pd( zmm5,zmm31,zmm25 ) + + // ---------------------------------- iteration 7 + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + + // ---------------------------------- iteration 8 + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm23 ) + vfmadd231pd( zmm4,zmm31,zmm24 ) + vfmadd231pd( zmm5,zmm31,zmm25 ) + lea(mem(r11,r8,8), r11) // b_next += 8*rs_b + dec(rsi) // i -= 1 + jnz(.LOOP3) // iterate again if i != 0. + + vaddpd(zmm14, zmm6, zmm6) + vaddpd(zmm15, zmm7, zmm7) + vaddpd(zmm16, zmm28, zmm28) + vaddpd(zmm17, zmm8, zmm8) + vaddpd(zmm18, zmm9, zmm9) + vaddpd(zmm19, zmm29, zmm29) + vaddpd(zmm20, zmm10, zmm10) + vaddpd(zmm21, zmm11, zmm11) + vaddpd(zmm22, zmm26, zmm26) + vaddpd(zmm23, zmm12, zmm12) + vaddpd(zmm24, zmm13, zmm13) + vaddpd(zmm25, zmm27, zmm27) + + label(.TAIL) + mov(var(k_left), rsi) // i = k_left + test(rsi, rsi) // check i via logical AND + je(.DPOSTACCUM) // if i == 0, jump to post-accumulation + + label(.DLOOPKLEFT) // k_left loop + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + vbroadcastsd( mem(rbx,r13,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm12 ) + vfmadd231pd( zmm1,zmm31,zmm13 ) + vfmadd231pd( zmm2,zmm31,zmm27 ) + dec(rsi) // i -= 1 + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + label(.DPOSTACCUM) + mov(var(alpha), rdx) // load address of alpha + vbroadcastsd(mem(rdx), zmm30) // broadcast alpha + mov(var(beta), rax) // load address of beta + vbroadcastsd(mem(rax), zmm31) // broadcast beta + + // scale by alpha + vmulpd( zmm30,zmm6,zmm6 ) + vmulpd( zmm30,zmm7,zmm7 ) + vmulpd( zmm30,zmm28,zmm28 ) + vmulpd( zmm30,zmm8,zmm8 ) + vmulpd( zmm30,zmm9,zmm9 ) + vmulpd( zmm30,zmm29,zmm29 ) + vmulpd( zmm30,zmm10,zmm10 ) + vmulpd( zmm30,zmm11,zmm11 ) + vmulpd( zmm30,zmm26,zmm26 ) + vmulpd( zmm30,zmm12,zmm12 ) + vmulpd( zmm30,zmm13,zmm13 ) + vmulpd( zmm30,zmm27,zmm27 ) + + + mov(var(rs_c), rsi) // load rs_c + lea(mem(, rsi, 8), rsi) // rsi = rs_c * sizeof(double) + lea(mem(rdi, rdi, 2), r13) // r13 = 3*cs_c + vxorpd(ymm2, ymm2, ymm2) + vucomisd(xmm2, xmm31) // set ZF if beta == 0 + je(.DBETAZERO) // if ZF == 1, jump to beta == 0 case + + + cmp(imm(8), rdi) // set ZF if (8*cs_c) == 8 + + + jz(.DROWSTORED) // jump to row storage case + + label(.DCOLSTORED) + vfmadd231pd( mem(rcx),zmm31,zmm6) + vmovupd( zmm6,(rcx)) + vfmadd231pd( 0x40(rcx),zmm31,zmm7) + vmovupd( zmm7,0x40(rcx)) + vfmadd231pd( 0x80(rcx),zmm31,zmm28) + vmovupd( zmm28,0x80(rcx)) + vfmadd231pd( mem(rcx,rdi,1),zmm31,zmm8) + vmovupd( zmm8,(rcx,rdi,1)) + vfmadd231pd( 0x40(rcx,rdi,1),zmm31,zmm9) + vmovupd( zmm9,0x40(rcx,rdi,1)) + vfmadd231pd( 0x80(rcx,rdi,1),zmm31,zmm29) + vmovupd( zmm29,0x80(rcx,rdi,1)) + vfmadd231pd( mem(rcx,rdi,2),zmm31,zmm10) + vmovupd( zmm10,(rcx,rdi,2)) + vfmadd231pd( 0x40(rcx,rdi,2),zmm31,zmm11) + vmovupd( zmm11,0x40(rcx,rdi,2)) + vfmadd231pd( 0x80(rcx,rdi,2),zmm31,zmm26) + vmovupd( zmm26,0x80(rcx,rdi,2)) + vfmadd231pd( mem(rcx,r13,1),zmm31,zmm12) + vmovupd( zmm12,(rcx,r13,1)) + vfmadd231pd( 0x40(rcx,r13,1),zmm31,zmm13) + vmovupd( zmm13,0x40(rcx,r13,1)) + vfmadd231pd( 0x80(rcx,r13,1),zmm31,zmm27) + vmovupd( zmm27,0x80(rcx,r13,1)) + + jmp(.DDONE) // jump to end. + + label(.DROWSTORED) + // r12 = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // r13 = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) + + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + vbroadcastsd(mem(rax), zmm31) + UPDATE_MASKED_C + //First 8x4 tile updated + + UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C + //Second 8x4 tile updated + + UNPACK_LO_HIGH(29, 28, 0, 1, 27, 26, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C + //Third 8x4 tile updated + jmp(.DDONE) // jump to end. + + + label(.DBETAZERO) + cmp(imm(8), rdi) // set ZF if (8*cs_c) == 8 + + jz(.DROWSTORBZ) // jump to row storage case + label(.DCOLSTORBZ) + vmovupd( zmm6,(rcx)) + vmovupd( zmm7,0x40(rcx)) + vmovupd( zmm28,0x80(rcx)) + vmovupd( zmm8,(rcx,rdi,1)) + vmovupd( zmm9,0x40(rcx,rdi,1)) + vmovupd( zmm29,0x80(rcx,rdi,1)) + vmovupd( zmm10,(rcx,rdi,2)) + vmovupd( zmm11,0x40(rcx,rdi,2)) + vmovupd( zmm26,0x80(rcx,rdi,2)) + vmovupd( zmm12,(rcx,r13,1)) + vmovupd( zmm13,0x40(rcx,r13,1)) + vmovupd( zmm27,0x80(rcx,r13,1)) + + jmp(.DDONE) // jump to end. + + + label(.DROWSTORBZ) + // r12 = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // r13 = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) + + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + UPDATE_MASKED_C_BZ + //First 8x5 tile updated + + UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C_BZ + //Second 8x5 tile updated + + UNPACK_LO_HIGH(29, 28, 0, 1, 27, 26, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C_BZ + //Third 8x5 tile updated + label(.DDONE) + + + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a8] "m" (ps_a8), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + [mask] "m" (mask) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm2", "xmm31", + "ymm2", + "zmm0", "zmm1", "zmm2", "zmm3", + "zmm4", "zmm5", "zmm6", "zmm7", "zmm8", "zmm9", "zmm10", + "zmm11", "zmm12", "zmm13", "zmm14", "zmm15", + "zmm16", "zmm17", "zmm18", "zmm19", + "zmm20", "zmm21", "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", + "zmm27", "zmm28", "zmm29", "zmm30", "zmm31", + "k2", "memory" + ) + } //mloop + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if (m_left) + { + const dim_t nr_cur = 4; + const dim_t i_edge = m0 - ( dim_t )m_left; + double *restrict cij = cbuf + i_edge * rs_c; + double *restrict ai = abuf + m_iter * ps_a; + double *restrict bj = bbuf; + // covers the range 16 < m_left <= 24 by using masked load/store instructions + if( 16 < m_left ) + { + bli_dgemmsup_rv_zen4_asm_24x4( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx); + } + // covers the range 8 < m_left <= 16 by using masked load/store instructions + else if( 8 < m_left ) + { + bli_dgemmsup_rv_zen4_asm_16x4( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx); + } + // covers the range 0 < m_left <= 8 by using masked load/store instructions + else if( 0 < m_left ) + { + bli_dgemmsup_rv_zen4_asm_8x4( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx); + } + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); +} + +void bli_dgemmsup_rv_zen5_asm_24x3m +( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + double *abuf = a; + double *bbuf = b; + double *cbuf = c; + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t m_iter = (uint64_t)m0 / 24; + uint64_t m_left = (uint64_t)m0 % 24; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_a8 = ps_a * sizeof( double ); + + uint64_t k_iter = (uint64_t)k0 / 8; + uint64_t k_left = (uint64_t)k0 % 8; + + uint8_t mask = (0xff >> (0x8 - (n0 & 7))); // calculate mask based on n_left + + if ( m_iter == 0 ) goto consider_edge_cases; + + /* For one iteration of this loop, a block of MRxNR is computed + * This loop moves along m-dimension of c matrix with steps of MR*rs_c. + */ + for(dim_t m=0; m < m_iter; m++) + { + + a = abuf + m * ps_a ; // Move to next MRXKC in MCXKC (where MC>=MR) + b = bbuf; //Same KCXNR is used across different MRXKC in MCXKC + c = cbuf + m * rs_c * 24; // Move to next MRxNR in MCxNR (where MC >= MR) + + // ------------------------------------------------------------------------- + begin_asm() + + mov(var(mask), rdx) // load mask + kmovw(edx, k(2)) // move mask to k2 register + mov(var(a), rax) // load address of a + mov(var(cs_a), r10) // load cs_a + mov(var(b), rbx) // load address of b + mov(var(rs_b), r8) // load rs_b + mov(var(cs_b), r9) // load cs_b + mov(var(c), rcx) // load address of c + mov(var(cs_c), rdi) // load cs_c + lea(mem(, r8, 8), r8) // rs_b *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_b *= sizeof(double) + lea(mem(, r10, 8), r10) // cs_a *= sizeof(double) + lea(mem(, rdi, 8), rdi) // cs_c *= sizeof(double) + lea(mem(rcx, 7*8), rdx) // C for prefetching + mov(r10, r14) // col stride of A + lea(mem(rax, r14, 2, 7*8), r14) // r14 = rax + 2*cs_a(A for prefetching) + lea(mem(rbx, r8, 8, 7*8), r11) // r11 = rbx + 8*rs_b(B for prefetching) + + /* Register usage: zmm0-5 are used to load A matrix + * zmm6-29 are used for accumulation + * zmm30-31 are used for broadcasting B matrix + */ + + // zero out all accumulation registers + vxorpd(zmm6, zmm6, zmm6) + vxorpd(zmm7, zmm7, zmm7) + vxorpd(zmm28, zmm28, zmm28) + vxorpd(zmm8, zmm8, zmm8) + vxorpd(zmm9, zmm9, zmm9) + vxorpd(zmm29, zmm29, zmm29) + vxorpd(zmm10, zmm10, zmm10) + vxorpd(zmm11, zmm11, zmm11) + vxorpd(zmm26, zmm26, zmm26) + vxorpd(zmm14, zmm14, zmm14) + vxorpd(zmm15, zmm15, zmm15) + vxorpd(zmm16, zmm16, zmm16) + vxorpd(zmm17, zmm17, zmm17) + vxorpd(zmm18, zmm18, zmm18) + vxorpd(zmm19, zmm19, zmm19) + vxorpd(zmm20, zmm20, zmm20) + vxorpd(zmm21, zmm21, zmm21) + vxorpd(zmm22, zmm22, zmm22) + + // K is unrolled by 8 to facilitate prefetch of B + // Assuming B to be col-stored, for each iteration of K, + //one cacheline of B_next is prefetched where b_next = b + (unroll)*rs_b + label(.DLOOPKITER) // main loop + mov(var(k_iter), rsi) // i = k_iter + sub(imm( 3+TAIL_NITER), rsi) // i -= NR + TAIL_NITER + jle(.PREFETCHLOOP) // jump if i <= 0 + + /** + * This edge kernel uses two separate vector register bank + * to hold fma result. + * Once the K loop is completed these two vector register banks + * are added together and final result is available in one + * register bank. + * Here odd iterations uses vector register zmm6, zmm7, zmm28, + * zmm8, zmm9, zmm29, zmm10, zmm11, zmm26 to hold fma result. + * While even iterations uses zmm14, zmm15, zmm16, zmm17, zmm18 + * zmm19, zmm20, zmm21, zmm22 to hold fma + * result. + * At the end of K loop, these two banks are added together and + * final result is available in vector register zmm6, zmm7, zmm28, + * zmm8, zmm9, zmm29, zmm10, zmm11, zmm26. + */ + + label(.LOOP1) + + // ---------------------------------- iteration 1 + + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + + // ---------------------------------- iteration 2 + + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11,r9,1) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) + + // ---------------------------------- iteration 3 + + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11,r9,2) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + + // ---------------------------------- iteration 4 + + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) + + // ---------------------------------- iteration 5 + + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + + // ---------------------------------- iteration 6 + + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) + + // ---------------------------------- iteration 7 + + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + + // ---------------------------------- iteration 8 + + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) + lea(mem(r11,r8,8), r11) // b_next += 8*rs_b + dec(rsi) // i -= 1 + jnz(.LOOP1) // iterate again if i != 0. + + label(.PREFETCHLOOP) + add(imm(3), rsi) // i += NR + jle(.TAILITER) // jump if i <= 0. + + label(.LOOP2) + + // ---------------------------------- iteration 1 + prefetchw0( mem(rdx)) // prefetch C + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + + // ---------------------------------- iteration 2 + prefetchw0( mem(rdx, 64)) // prefetch C + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11,r9,1) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) + + // ---------------------------------- iteration 3 + prefetchw0( mem(rdx, 128)) // prefetch C + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11,r9,2) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + + // ---------------------------------- iteration 4 + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) + + // ---------------------------------- iteration 5 + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + + // ---------------------------------- iteration 6 + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) + + // ---------------------------------- iteration 7 + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + + // ---------------------------------- iteration 8 + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) + lea(mem(rdx, rdi, 1), rdx) // C += cs_c + lea(mem(r11,r8,8), r11) // b_next += 8*rs_b + sub(imm(1), rsi) // i -= 1 + jnz(.LOOP2) // iterate again if i != 0. + label(.TAILITER) + add(imm(TAIL_NITER), rsi) // i += TAIL_NITER + jle(.TAIL) // jump if i <= 0 + + label(.LOOP3) + + // ---------------------------------- iteration 1 + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + + // ---------------------------------- iteration 2 + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11,r9,1) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) + + // ---------------------------------- iteration 3 + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11,r9,2) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + + // ---------------------------------- iteration 4 + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) + + // ---------------------------------- iteration 5 + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + + // ---------------------------------- iteration 6 + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) + + // ---------------------------------- iteration 7 + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + + // ---------------------------------- iteration 8 + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) + lea(mem(r11,r8,8), r11) // b_next += 8*rs_b + dec(rsi) // i -= 1 + jnz(.LOOP3) // iterate again if i != 0. + + vaddpd(zmm14, zmm6, zmm6) + vaddpd(zmm15, zmm7, zmm7) + vaddpd(zmm16, zmm28, zmm28) + vaddpd(zmm17, zmm8, zmm8) + vaddpd(zmm18, zmm9, zmm9) + vaddpd(zmm19, zmm29, zmm29) + vaddpd(zmm20, zmm10, zmm10) + vaddpd(zmm21, zmm11, zmm11) + vaddpd(zmm22, zmm26, zmm26) + + label(.TAIL) + mov(var(k_left), rsi) // i = k_left + test(rsi, rsi) // check i via logical AND + je(.DPOSTACCUM) // if i == 0, jump to post-accumulation + + label(.DLOOPKLEFT) // k_left loop + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + vbroadcastsd( mem(rbx,r9,2),zmm30 ) + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm30,zmm10 ) + vfmadd231pd( zmm1,zmm30,zmm11 ) + vfmadd231pd( zmm2,zmm30,zmm26 ) + dec(rsi) // i -= 1 + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + label(.DPOSTACCUM) + mov(var(alpha), rdx) // load address of alpha + vbroadcastsd(mem(rdx), zmm30) // broadcast alpha + mov(var(beta), rax) // load address of beta + vbroadcastsd(mem(rax), zmm31) // broadcast beta + + // scale by alpha + vmulpd( zmm30,zmm6,zmm6 ) + vmulpd( zmm30,zmm7,zmm7 ) + vmulpd( zmm30,zmm28,zmm28 ) + vmulpd( zmm30,zmm8,zmm8 ) + vmulpd( zmm30,zmm9,zmm9 ) + vmulpd( zmm30,zmm29,zmm29 ) + vmulpd( zmm30,zmm10,zmm10 ) + vmulpd( zmm30,zmm11,zmm11 ) + vmulpd( zmm30,zmm26,zmm26 ) + + + mov(var(rs_c), rsi) // load rs_c + lea(mem(, rsi, 8), rsi) // rsi = rs_c * sizeof(double) + vxorpd(ymm2, ymm2, ymm2) + vucomisd(xmm2, xmm31) // set ZF if beta == 0 + je(.DBETAZERO) // if ZF == 1, jump to beta == 0 case + + + cmp(imm(8), rdi) // set ZF if (8*cs_c) == 8 + + + jz(.DROWSTORED) // jump to row storage case + + label(.DCOLSTORED) + vfmadd231pd( mem(rcx),zmm31,zmm6) + vmovupd( zmm6,(rcx)) + vfmadd231pd( 0x40(rcx),zmm31,zmm7) + vmovupd( zmm7,0x40(rcx)) + vfmadd231pd( 0x80(rcx),zmm31,zmm28) + vmovupd( zmm28,0x80(rcx)) + vfmadd231pd( mem(rcx,rdi,1),zmm31,zmm8) + vmovupd( zmm8,(rcx,rdi,1)) + vfmadd231pd( 0x40(rcx,rdi,1),zmm31,zmm9) + vmovupd( zmm9,0x40(rcx,rdi,1)) + vfmadd231pd( 0x80(rcx,rdi,1),zmm31,zmm29) + vmovupd( zmm29,0x80(rcx,rdi,1)) + vfmadd231pd( mem(rcx,rdi,2),zmm31,zmm10) + vmovupd( zmm10,(rcx,rdi,2)) + vfmadd231pd( 0x40(rcx,rdi,2),zmm31,zmm11) + vmovupd( zmm11,0x40(rcx,rdi,2)) + vfmadd231pd( 0x80(rcx,rdi,2),zmm31,zmm26) + vmovupd( zmm26,0x80(rcx,rdi,2)) + + jmp(.DDONE) // jump to end. + + label(.DROWSTORED) + // r12 = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // r13 = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) + + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + vbroadcastsd(mem(rax), zmm31) + UPDATE_MASKED_C + //First 8x3 tile updated + + UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C + //Second 8x3 tile updated + + UNPACK_LO_HIGH(29, 28, 0, 1, 27, 26, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C + //Third 8x3 tile updated + jmp(.DDONE) // jump to end. + + + label(.DBETAZERO) + cmp(imm(8), rdi) // set ZF if (8*cs_c) == 8 + + jz(.DROWSTORBZ) // jump to row storage case + label(.DCOLSTORBZ) + vmovupd( zmm6,(rcx)) + vmovupd( zmm7,0x40(rcx)) + vmovupd( zmm28,0x80(rcx)) + vmovupd( zmm8,(rcx,rdi,1)) + vmovupd( zmm9,0x40(rcx,rdi,1)) + vmovupd( zmm29,0x80(rcx,rdi,1)) + vmovupd( zmm10,(rcx,rdi,2)) + vmovupd( zmm11,0x40(rcx,rdi,2)) + vmovupd( zmm26,0x80(rcx,rdi,2)) + + jmp(.DDONE) // jump to end. + + + label(.DROWSTORBZ) + // r12 = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // r13 = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) + + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + vbroadcastsd(mem(rax), zmm31) + UPDATE_MASKED_C_BZ + //First 8x3 tile updated + + UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C_BZ + //Second 8x3 tile updated + + UNPACK_LO_HIGH(29, 28, 0, 1, 27, 26, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C_BZ + //Third 8x3 tile updated + label(.DDONE) + + + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a8] "m" (ps_a8), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + [mask] "m" (mask) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm2", "xmm31", + "ymm2", + "zmm0", "zmm1", "zmm2", "zmm3", + "zmm4", "zmm5", "zmm6", "zmm7", "zmm8", "zmm9", "zmm10", + "zmm11", "zmm12", "zmm13", "zmm14", "zmm15", + "zmm16", "zmm17", "zmm18", "zmm19", + "zmm20", "zmm21", "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", + "zmm27", "zmm28", "zmm29", "zmm30", "zmm31", + "k2", "memory" + ) + } //mloop + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if (m_left) + { + const dim_t nr_cur = 3; + const dim_t i_edge = m0 - ( dim_t )m_left; + double *restrict cij = cbuf + i_edge * rs_c; + double *restrict ai = abuf + m_iter * ps_a; + double *restrict bj = bbuf; + // covers the range 16 < m_left <= 24 by using masked load/store instructions + if( 16 < m_left ) + { + bli_dgemmsup_rv_zen4_asm_24x3( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx); + } + // covers the range 8 < m_left <= 16 by using masked load/store instructions + else if( 8 < m_left ) + { + bli_dgemmsup_rv_zen4_asm_16x3( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx); + } + // covers the range 0 < m_left <= 8 by using masked load/store instructions + else if( 0 < m_left ) + { + bli_dgemmsup_rv_zen4_asm_8x3( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx); + } + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); +} + +void bli_dgemmsup_rv_zen5_asm_24x2m +( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + double *abuf = a; + double *bbuf = b; + double *cbuf = c; + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t m_iter = (uint64_t)m0 / 24; + uint64_t m_left = (uint64_t)m0 % 24; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_a8 = ps_a * sizeof( double ); + + uint64_t k_iter = (uint64_t)k0 / 8; + uint64_t k_left = (uint64_t)k0 % 8; + + uint8_t mask = (0xff >> (0x8 - (n0 & 7))); // calculate mask based on n_left + + if ( m_iter == 0 ) goto consider_edge_cases; + + /* For one iteration of this loop, a block of MRxNR is computed + * This loop moves along m-dimension of c matrix with steps of MR*rs_c. + */ + for(dim_t m=0; m < m_iter; m++) + { + + a = abuf + m * ps_a ; // Move to next MRXKC in MCXKC (where MC>=MR) + b = bbuf; //Same KCXNR is used across different MRXKC in MCXKC + c = cbuf + m * rs_c * 24; // Move to next MRxNR in MCxNR (where MC >= MR) + + // ------------------------------------------------------------------------- + begin_asm() + + mov(var(mask), rdx) // load mask + kmovw(edx, k(2)) // move mask to k2 register + mov(var(a), rax) // load address of a + mov(var(cs_a), r10) // load cs_a + mov(var(b), rbx) // load address of b + mov(var(rs_b), r8) // load rs_b + mov(var(cs_b), r9) // load cs_b + mov(var(c), rcx) // load address of c + mov(var(cs_c), rdi) // load cs_c + lea(mem(, r8, 8), r8) // rs_b *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_b *= sizeof(double) + lea(mem(, r10, 8), r10) // cs_a *= sizeof(double) + lea(mem(, rdi, 8), rdi) // cs_c *= sizeof(double) + lea(mem(rcx, 7*8), rdx) // C for prefetching + mov(r10, r14) // col stride of A + lea(mem(rax, r14, 2, 7*8), r14) // r14 = rax + 2*cs_a(A for prefetching) + lea(mem(rbx, r8, 8, 7*8), r11) // r11 = rbx + 8*rs_b(B for prefetching) + + /* Register usage: zmm0-5 are used to load A matrix + * zmm6-29 are used for accumulation + * zmm30-31 are used for broadcasting B matrix + */ + + // zero out all accumulation registers + vxorpd(zmm6, zmm6, zmm6) + vxorpd(zmm7, zmm7, zmm7) + vxorpd(zmm28, zmm28, zmm28) + vxorpd(zmm8, zmm8, zmm8) + vxorpd(zmm9, zmm9, zmm9) + vxorpd(zmm29, zmm29, zmm29) + vxorpd(zmm14, zmm14, zmm14) + vxorpd(zmm15, zmm15, zmm15) + vxorpd(zmm16, zmm16, zmm16) + vxorpd(zmm17, zmm17, zmm17) + vxorpd(zmm18, zmm18, zmm18) + vxorpd(zmm19, zmm19, zmm19) + + // K is unrolled by 8 to facilitate prefetch of B + // Assuming B to be col-stored, for each iteration of K, + //one cacheline of B_next is prefetched where b_next = b + (unroll)*rs_b + label(.DLOOPKITER) // main loop + mov(var(k_iter), rsi) // i = k_iter + sub(imm( 2+TAIL_NITER), rsi) // i -= NR + TAIL_NITER + jle(.PREFETCHLOOP) // jump if i <= 0 + + /** + * This edge kernel uses two separate vector register bank + * to hold fma result. + * Once the K loop is completed these two vector register banks + * are added together and final result is available in one + * register bank. + * Here odd iterations uses vector register zmm6, zmm7, zmm28, + * zmm8, zmm9, zmm29 to hold fma result. + * While even iterations uses zmm14, zmm15, zmm16, zmm17, zmm18 + * zmm19, zmm20, zmm21 to hold fma result. + * At the end of K loop, these two banks are added together and + * final result is available in vector register zmm6, zmm7, zmm28, + * zmm8, zmm9, zmm29. + */ + + label(.LOOP1) + + // ---------------------------------- iteration 1 + + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + + // ---------------------------------- iteration 2 + + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11,r9,1) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) + + // ---------------------------------- iteration 3 + + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + + // ---------------------------------- iteration 4 + + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) + + // ---------------------------------- iteration 5 + + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + + // ---------------------------------- iteration 6 + + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) + + // ---------------------------------- iteration 7 + + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + + // ---------------------------------- iteration 8 + + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) + lea(mem(r11,r8,8), r11) // b_next += 8*rs_b + dec(rsi) // i -= 1 + jnz(.LOOP1) // iterate again if i != 0. + + label(.PREFETCHLOOP) + add(imm(2), rsi) // i += NR + jle(.TAILITER) // jump if i <= 0. + + label(.LOOP2) + + // ---------------------------------- iteration 1 + prefetchw0( mem(rdx)) // prefetch C + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + + // ---------------------------------- iteration 2 + prefetchw0( mem(rdx, 64)) // prefetch C + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11,r9,1) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) + + // ---------------------------------- iteration 3 + prefetchw0( mem(rdx, 128)) // prefetch C + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + + // ---------------------------------- iteration 4 + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) + + // ---------------------------------- iteration 5 + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + + // ---------------------------------- iteration 6 + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) + + // ---------------------------------- iteration 7 + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + + // ---------------------------------- iteration 8 + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) + lea(mem(rdx, rdi, 1), rdx) // C += cs_c + lea(mem(r11,r8,8), r11) // b_next += 8*rs_b + sub(imm(1), rsi) // i -= 1 + jnz(.LOOP2) // iterate again if i != 0. + label(.TAILITER) + add(imm(TAIL_NITER), rsi) // i += TAIL_NITER + jle(.TAIL) // jump if i <= 0 + + label(.LOOP3) + + // ---------------------------------- iteration 1 + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + + // ---------------------------------- iteration 2 + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11,r9,1) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) + + // ---------------------------------- iteration 3 + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + + // ---------------------------------- iteration 4 + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) + + // ---------------------------------- iteration 5 + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + + // ---------------------------------- iteration 6 + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) + + // ---------------------------------- iteration 7 + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + + // ---------------------------------- iteration 8 + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) + lea(mem(r11,r8,8), r11) // b_next += 8*rs_b + dec(rsi) // i -= 1 + jnz(.LOOP3) // iterate again if i != 0. + + vaddpd(zmm14, zmm6, zmm6) + vaddpd(zmm15, zmm7, zmm7) + vaddpd(zmm16, zmm28, zmm28) + vaddpd(zmm17, zmm8, zmm8) + vaddpd(zmm18, zmm9, zmm9) + vaddpd(zmm19, zmm29, zmm29) + + label(.TAIL) + mov(var(k_left), rsi) // i = k_left + test(rsi, rsi) // check i via logical AND + je(.DPOSTACCUM) // if i == 0, jump to post-accumulation + + label(.DLOOPKLEFT) // k_left loop + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + vbroadcastsd( mem(rbx,r9,1),zmm31 ) + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm31,zmm8 ) + vfmadd231pd( zmm1,zmm31,zmm9 ) + vfmadd231pd( zmm2,zmm31,zmm29 ) + dec(rsi) // i -= 1 + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + label(.DPOSTACCUM) + mov(var(alpha), rdx) // load address of alpha + vbroadcastsd(mem(rdx), zmm30) // broadcast alpha + mov(var(beta), rax) // load address of beta + vbroadcastsd(mem(rax), zmm31) // broadcast beta + + // scale by alpha + vmulpd( zmm30,zmm6,zmm6 ) + vmulpd( zmm30,zmm7,zmm7 ) + vmulpd( zmm30,zmm28,zmm28 ) + vmulpd( zmm30,zmm8,zmm8 ) + vmulpd( zmm30,zmm9,zmm9 ) + vmulpd( zmm30,zmm29,zmm29 ) + + + mov(var(rs_c), rsi) // load rs_c + lea(mem(, rsi, 8), rsi) // rsi = rs_c * sizeof(double) + vxorpd(ymm2, ymm2, ymm2) + vucomisd(xmm2, xmm31) // set ZF if beta == 0 + je(.DBETAZERO) // if ZF == 1, jump to beta == 0 case + + + cmp(imm(8), rdi) // set ZF if (8*cs_c) == 8 + + + jz(.DROWSTORED) // jump to row storage case + + label(.DCOLSTORED) + vfmadd231pd( mem(rcx),zmm31,zmm6) + vmovupd( zmm6,(rcx)) + vfmadd231pd( 0x40(rcx),zmm31,zmm7) + vmovupd( zmm7,0x40(rcx)) + vfmadd231pd( 0x80(rcx),zmm31,zmm28) + vmovupd( zmm28,0x80(rcx)) + vfmadd231pd( mem(rcx,rdi,1),zmm31,zmm8) + vmovupd( zmm8,(rcx,rdi,1)) + vfmadd231pd( 0x40(rcx,rdi,1),zmm31,zmm9) + vmovupd( zmm9,0x40(rcx,rdi,1)) + vfmadd231pd( 0x80(rcx,rdi,1),zmm31,zmm29) + vmovupd( zmm29,0x80(rcx,rdi,1)) + + jmp(.DDONE) // jump to end. + + label(.DROWSTORED) + // r12 = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // r13 = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + vunpcklpd( zmm8, zmm6, zmm0) + vunpckhpd( zmm8, zmm6, zmm1) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) + + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + vbroadcastsd(mem(rax), zmm31) + UPDATE_MASKED_C + //First 8x2 tile updated + + vunpcklpd( zmm9, zmm7, zmm0) + vunpckhpd( zmm9, zmm7, zmm1) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C + //Second 8x2 tile updated + + vunpcklpd( zmm29, zmm28, zmm0) + vunpckhpd( zmm29, zmm28, zmm1) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C + //Third 8x2 tile updated + jmp(.DDONE) // jump to end. + + + label(.DBETAZERO) + cmp(imm(8), rdi) // set ZF if (8*cs_c) == 8 + + jz(.DROWSTORBZ) // jump to row storage case + label(.DCOLSTORBZ) + vmovupd( zmm6,(rcx)) + vmovupd( zmm7,0x40(rcx)) + vmovupd( zmm28,0x80(rcx)) + vmovupd( zmm8,(rcx,rdi,1)) + vmovupd( zmm9,0x40(rcx,rdi,1)) + vmovupd( zmm29,0x80(rcx,rdi,1)) + + jmp(.DDONE) // jump to end. + + + label(.DROWSTORBZ) + // r12 = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // r13 = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + vunpcklpd( zmm8, zmm6, zmm0) + vunpckhpd( zmm8, zmm6, zmm1) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) + + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + UPDATE_MASKED_C_BZ + //First 8x2 tile updated + + vunpcklpd( zmm9, zmm7, zmm0) + vunpckhpd( zmm9, zmm7, zmm1) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C_BZ + //Second 8x2 tile updated + + vunpcklpd( zmm29, zmm28, zmm0) + vunpckhpd( zmm29, zmm28, zmm1) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C_BZ + //Third 8x2 tile updated + label(.DDONE) + + + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a8] "m" (ps_a8), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + [mask] "m" (mask) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm2", "xmm31", + "ymm2", + "zmm0", "zmm1", "zmm2", "zmm3", + "zmm4", "zmm5", "zmm6", "zmm7", "zmm8", "zmm9", "zmm10", + "zmm11", "zmm12", "zmm13", "zmm14", "zmm15", + "zmm16", "zmm17", "zmm18", "zmm19", + "zmm20", "zmm21", "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", + "zmm27", "zmm28", "zmm29", "zmm30", "zmm31", + "k2", "memory" + ) + } //mloop + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if (m_left) + { + const dim_t nr_cur = 2; + const dim_t i_edge = m0 - ( dim_t )m_left; + double *restrict cij = cbuf + i_edge * rs_c; + double *restrict ai = abuf + m_iter * ps_a; + double *restrict bj = bbuf; + // covers the range 16 < m_left <= 24 by using masked load/store instructions + if( 16 < m_left ) + { + bli_dgemmsup_rv_zen4_asm_24x2( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx); + } + // covers the range 8 < m_left <= 16 by using masked load/store instructions + else if( 8 < m_left ) + { + bli_dgemmsup_rv_zen4_asm_16x2( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx); + } + // covers the range 0 < m_left <= 8 by using masked load/store instructions + else if( 0 < m_left ) + { + bli_dgemmsup_rv_zen4_asm_8x2( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx); + } + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); +} + +void bli_dgemmsup_rv_zen5_asm_24x1m +( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + double *abuf = a; + double *bbuf = b; + double *cbuf = c; + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t m_iter = (uint64_t)m0 / 24; + uint64_t m_left = (uint64_t)m0 % 24; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_a8 = ps_a * sizeof( double ); + + uint64_t k_iter = (uint64_t)k0 / 8; + uint64_t k_left = (uint64_t)k0 % 8; + + uint8_t mask = (0xff >> (0x8 - (n0 & 7))); // calculate mask based on n_left + + if ( m_iter == 0 ) goto consider_edge_cases; + + /* For one iteration of this loop, a block of MRxNR is computed + * This loop moves along m-dimension of c matrix with steps of MR*rs_c. + */ + for(dim_t m=0; m < m_iter; m++) + { + + a = abuf + m * ps_a ; // Move to next MRXKC in MCXKC (where MC>=MR) + b = bbuf; //Same KCXNR is used across different MRXKC in MCXKC + c = cbuf + m * rs_c * 24; // Move to next MRxNR in MCxNR (where MC >= MR) + + // ------------------------------------------------------------------------- + begin_asm() + + mov(var(mask), rdx) // load mask + kmovw(edx, k(2)) // move mask to k2 register + mov(var(a), rax) // load address of a + mov(var(cs_a), r10) // load cs_a + mov(var(b), rbx) // load address of b + mov(var(rs_b), r8) // load rs_b + mov(var(cs_b), r9) // load cs_b + mov(var(c), rcx) // load address of c + mov(var(cs_c), rdi) // load cs_c + lea(mem(, r8, 8), r8) // rs_b *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_b *= sizeof(double) + lea(mem(, r10, 8), r10) // cs_a *= sizeof(double) + lea(mem(, rdi, 8), rdi) // cs_c *= sizeof(double) + lea(mem(rcx, 7*8), rdx) // C for prefetching + mov(r10, r14) // col stride of A + lea(mem(rax, r14, 2, 7*8), r14) // r14 = rax + 2*cs_a(A for prefetching) + lea(mem(rbx, r8, 8, 7*8), r11) // r11 = rbx + 8*rs_b(B for prefetching) + + /* Register usage: zmm0-5 are used to load A matrix + * zmm6-29 are used for accumulation + * zmm30-31 are used for broadcasting B matrix + */ + + // zero out all accumulation registers + vxorpd(zmm6, zmm6, zmm6) + vxorpd(zmm7, zmm7, zmm7) + vxorpd(zmm28, zmm28, zmm28) + vxorpd(zmm14, zmm14, zmm14) + vxorpd(zmm15, zmm15, zmm15) + vxorpd(zmm16, zmm16, zmm16) + // K is unrolled by 8 to facilitate prefetch of B + // Assuming B to be col-stored, for each iteration of K, + //one cacheline of B_next is prefetched where b_next = b + (unroll)*rs_b + label(.DLOOPKITER) // main loop + mov(var(k_iter), rsi) // i = k_iter + sub(imm( 1+TAIL_NITER), rsi) // i -= NR + TAIL_NITER + jle(.PREFETCHLOOP) // jump if i <= 0 + + /** + * This edge kernel uses two separate vector register bank + * to hold fma result. + * Once the K loop is completed these two vector register banks + * are added together and final result is available in one + * register bank. + * Here odd iterations uses vector register zmm6, zmm7, zmm28, + * to hold fma result. + * While even iterations uses zmm14, zmm15, zmm16 to hold fma + * result. + * At the end of K loop, these two banks are added together and + * final result is available in vector register zmm6, zmm7, zmm28, + */ + + label(.LOOP1) + + // ---------------------------------- iteration 1 + + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + + // ---------------------------------- iteration 2 + + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) + + // ---------------------------------- iteration 3 + + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + + // ---------------------------------- iteration 4 + + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) + + // ---------------------------------- iteration 5 + + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + + // ---------------------------------- iteration 6 + + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) + + // ---------------------------------- iteration 7 + + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + + // ---------------------------------- iteration 8 + + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) + lea(mem(r11,r8,8), r11) // b_next += 8*rs_b + dec(rsi) // i -= 1 + jnz(.LOOP1) // iterate again if i != 0. + + label(.PREFETCHLOOP) + add(imm(1), rsi) // i += NR + jle(.TAILITER) // jump if i <= 0. + + label(.LOOP2) + + // ---------------------------------- iteration 1 + prefetchw0( mem(rdx)) // prefetch C + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + + // ---------------------------------- iteration 2 + prefetchw0( mem(rdx, 64)) // prefetch C + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) + + // ---------------------------------- iteration 3 + prefetchw0( mem(rdx, 128)) // prefetch C + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + + // ---------------------------------- iteration 4 + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) + + // ---------------------------------- iteration 5 + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + + // ---------------------------------- iteration 6 + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) + + // ---------------------------------- iteration 7 + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + + // ---------------------------------- iteration 8 + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) + lea(mem(rdx, rdi, 1), rdx) // C += cs_c + lea(mem(r11,r8,8), r11) // b_next += 8*rs_b + sub(imm(1), rsi) // i -= 1 + jnz(.LOOP2) // iterate again if i != 0. + label(.TAILITER) + add(imm(TAIL_NITER), rsi) // i += TAIL_NITER + jle(.TAIL) // jump if i <= 0 + + label(.LOOP3) + + // ---------------------------------- iteration 1 + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + prefetch( 0,mem(r11) ) // prefetch B + vbroadcastsd( mem(rbx),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + + // ---------------------------------- iteration 2 + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) + + // ---------------------------------- iteration 3 + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + + // ---------------------------------- iteration 4 + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) + + // ---------------------------------- iteration 5 + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + + // ---------------------------------- iteration 6 + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) + + // ---------------------------------- iteration 7 + vmovupd( mem(rax),zmm3 ) // load A + vmovupd( 0x40(rax),zmm4 ) + vmovupd( 0x80(rax),zmm5 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + + // ---------------------------------- iteration 8 + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) + lea(mem(r11,r8,8), r11) // b_next += 8*rs_b + dec(rsi) // i -= 1 + jnz(.LOOP3) // iterate again if i != 0. + + vaddpd(zmm14, zmm6, zmm6) + vaddpd(zmm15, zmm7, zmm7) + vaddpd(zmm16, zmm28, zmm28) + + label(.TAIL) + mov(var(k_left), rsi) // i = k_left + test(rsi, rsi) // check i via logical AND + je(.DPOSTACCUM) // if i == 0, jump to post-accumulation + + label(.DLOOPKLEFT) // k_left loop + vmovupd( mem(rax),zmm0 ) // load A + vmovupd( 0x40(rax),zmm1 ) + vmovupd( 0x80(rax),zmm2 ) + add( r10,rax ) // a += cs_a + //prefetch 24 elements(3 cachelines) of the second next col in same panel of A + prefetch( 0,mem(r14) ) + prefetch( 0,0x40(r14) ) + prefetch( 0,0x80(r14) ) + add( r10,r14 ) // a_next += cs_a + vbroadcastsd( mem(rbx),zmm30 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm0,zmm30,zmm6 ) + vfmadd231pd( zmm1,zmm30,zmm7 ) + vfmadd231pd( zmm2,zmm30,zmm28 ) + dec(rsi) // i -= 1 + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + label(.DPOSTACCUM) + mov(var(alpha), rdx) // load address of alpha + vbroadcastsd(mem(rdx), zmm30) // broadcast alpha + mov(var(beta), rax) // load address of beta + vbroadcastsd(mem(rax), zmm31) // broadcast beta + + // scale by alpha + vmulpd( zmm30,zmm6,zmm6 ) + vmulpd( zmm30,zmm7,zmm7 ) + vmulpd( zmm30,zmm28,zmm28 ) + + + mov(var(rs_c), rsi) // load rs_c + lea(mem(, rsi, 8), rsi) // rsi = rs_c * sizeof(double) + vxorpd(ymm2, ymm2, ymm2) + vucomisd(xmm2, xmm31) // set ZF if beta == 0 + je(.DBETAZERO) // if ZF == 1, jump to beta == 0 case + + + cmp(imm(8), rdi) // set ZF if (8*cs_c) == 8 + + + jz(.DROWSTORED) // jump to row storage case + + label(.DCOLSTORED) + vfmadd231pd( mem(rcx),zmm31,zmm6) + vmovupd( zmm6,(rcx)) + vfmadd231pd( 0x40(rcx),zmm31,zmm7) + vmovupd( zmm7,0x40(rcx)) + vfmadd231pd( 0x80(rcx),zmm31,zmm28) + vmovupd( zmm28,0x80(rcx)) + + jmp(.DDONE) // jump to end. + + label(.DROWSTORED) + // r12 = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // r13 = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + vunpcklpd( zmm8, zmm6, zmm0) + vunpckhpd( zmm8, zmm6, zmm1) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) + + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + vbroadcastsd(mem(rax), zmm31) + UPDATE_MASKED_C + //First 8x1 tile updated + + vunpcklpd( zmm9, zmm7, zmm0) + vunpckhpd( zmm9, zmm7, zmm1) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C + //Second 8x1 tile updated + + vunpcklpd( zmm29, zmm28, zmm0) + vunpckhpd( zmm29, zmm28, zmm1) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C + //Third 8x1 tile updated + jmp(.DDONE) // jump to end. + + + label(.DBETAZERO) + cmp(imm(8), rdi) // set ZF if (8*cs_c) == 8 + + jz(.DROWSTORBZ) // jump to row storage case + label(.DCOLSTORBZ) + vmovupd( zmm6,(rcx)) + vmovupd( zmm7,0x40(rcx)) + vmovupd( zmm28,0x80(rcx)) + + jmp(.DDONE) // jump to end. + + + label(.DROWSTORBZ) + // r12 = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // r13 = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + vunpcklpd( zmm8, zmm6, zmm0) + vunpckhpd( zmm8, zmm6, zmm1) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) + + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + UPDATE_MASKED_C_BZ + //First 8x1 tile updated + + vunpcklpd( zmm9, zmm7, zmm0) + vunpckhpd( zmm9, zmm7, zmm1) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C_BZ + //Second 8x1 tile updated + + vunpcklpd( zmm29, zmm28, zmm0) + vunpckhpd( zmm29, zmm28, zmm1) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C_BZ + //Third 8x1 tile updated + label(.DDONE) + + + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a8] "m" (ps_a8), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + [mask] "m" (mask) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm2", "xmm31", + "ymm2", + "zmm0", "zmm1", "zmm2", "zmm3", + "zmm4", "zmm5", "zmm6", "zmm7", "zmm8", "zmm9", "zmm10", + "zmm11", "zmm12", "zmm13", "zmm14", "zmm15", + "zmm16", "zmm17", "zmm18", "zmm19", + "zmm20", "zmm21", "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", + "zmm27", "zmm28", "zmm29", "zmm30", "zmm31", + "k2", "memory" + ) + } //mloop + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if (m_left) + { + const dim_t nr_cur = 1; + const dim_t i_edge = m0 - ( dim_t )m_left; + double *restrict cij = cbuf + i_edge * rs_c; + double *restrict ai = abuf + m_iter * ps_a; + double *restrict bj = bbuf; + // covers the range 16 < m_left <= 24 by using masked load/store instructions + if( 16 < m_left ) + { + bli_dgemmsup_rv_zen4_asm_24x1( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx); + } + // covers the range 8 < m_left <= 16 by using masked load/store instructions + else if( 8 < m_left ) + { + bli_dgemmsup_rv_zen4_asm_16x1( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx); + } + // covers the range 0 < m_left <= 8 by using masked load/store instructions + else if( 0 < m_left ) + { + bli_dgemmsup_rv_zen4_asm_8x1( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx); + } + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); +} diff --git a/kernels/zen5/aocl_smart/bli_aocl_smart.c b/kernels/zen5/aocl_smart/bli_aocl_smart.c new file mode 100644 index 0000000000..b5166ce750 --- /dev/null +++ b/kernels/zen5/aocl_smart/bli_aocl_smart.c @@ -0,0 +1,98 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +/* This function determines if we need to take SUP or native path + for given matrix sizes for zen5 configuration. + * Returns TRUE if the dimensions fall under SUP range + * Returns FALSE if the dimensions fall under Native range +*/ +bool bli_cntx_gemmsup_thresh_is_met_zen5( obj_t* a, obj_t* b, obj_t* c, cntx_t* cntx ) +{ + num_t dt = bli_obj_dt( c ); + + if( dt == BLIS_DOUBLE ) + { + dim_t k = bli_obj_width_after_trans( a ); + dim_t m, n; + + const stor3_t stor_id = bli_obj_stor3_from_strides( c, a, b ); + + if ( bli_cntx_l3_sup_ker_dislikes_storage_of( c, stor_id, cntx ) ) + { + m = bli_obj_width(c); + n = bli_obj_length(c); + } + else + { + m = bli_obj_length( c ); + n = bli_obj_width( c ); + } + // For skinny sizes where one/two dimensions are small + if((m < 1000) || (n < 1000)) return TRUE; + // // For all combinations in small sizes + if((m < 2200) && (n < 2200) && (k < 2200)) return TRUE; + return FALSE; + } + else if( dt == BLIS_DCOMPLEX ) + { + dim_t k = bli_obj_width_after_trans( a ); + dim_t m, n; + + const stor3_t stor_id = bli_obj_stor3_from_strides( c, a, b ); + + if ( bli_cntx_l3_sup_ker_dislikes_storage_of( c, stor_id, cntx ) ) + { + m = bli_obj_width(c); + n = bli_obj_length(c); + } + else + { + m = bli_obj_length( c ); + n = bli_obj_width( c ); + } + // For skinny sizes where m and/or n is small + // The threshold for m is a single value, but for n, it is + // also based on the packing size of A, since the kernels are + // column preferential + if( ( m <= 84 ) || ( ( n <= 84 ) && ( ( m * k ) <= 983040 ) ) ) return TRUE; + + // For all combinations in small sizes + if( ( m <= 216 ) && ( n <= 216 ) && ( k <= 216 ) ) return TRUE; + return FALSE; + } + else + return bli_cntx_l3_sup_thresh_is_met( a, b, c, cntx ); +} diff --git a/kernels/zen5/bli_kernels_zen5.h b/kernels/zen5/bli_kernels_zen5.h new file mode 100644 index 0000000000..a1cea5b290 --- /dev/null +++ b/kernels/zen5/bli_kernels_zen5.h @@ -0,0 +1,67 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +// native dgemm kernel +GEMM_UKR_PROT( double, d, gemm_avx512_asm_8x24 ) + +// Dgemm sup RV kernels +GEMMSUP_KER_PROT( double, d, gemmsup_rv_zen5_asm_24x8m) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_zen5_asm_24x7m) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_zen5_asm_24x6m) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_zen5_asm_24x5m) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_zen5_asm_24x4m) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_zen5_asm_24x3m) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_zen5_asm_24x2m) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_zen5_asm_24x1m) + +void bli_dgemm_avx512_asm_8x24_macro_kernel +( + dim_t n, + dim_t m, + dim_t k, + double* c, + double* a, + double* b, + dim_t ldc, + double* beta +); + +// threshold functions +bool bli_cntx_gemmsup_thresh_is_met_zen5 +( + obj_t* a, + obj_t* b, + obj_t* c, + cntx_t* cntx +); diff --git a/ref_kernels/1/bli_scalv_ref.c b/ref_kernels/1/bli_scalv_ref.c index 4945b637b0..29d55e6261 100644 --- a/ref_kernels/1/bli_scalv_ref.c +++ b/ref_kernels/1/bli_scalv_ref.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -52,7 +53,7 @@ void PASTEMAC3(ch,opname,arch,suf) \ if ( PASTEMAC(ch,eq1)( *alpha ) ) return; \ \ /* If alpha is zero, use setv. */ \ - if ( PASTEMAC(ch,eq0)( *alpha ) ) \ + if ( PASTEMAC(ch,eq0)( *alpha ) && n > 0) \ { \ ctype* zero = PASTEMAC(ch,0); \ \ @@ -70,6 +71,8 @@ void PASTEMAC3(ch,opname,arch,suf) \ ); \ return; \ } \ +\ + dim_t n0 = bli_abs(n); \ \ ctype alpha_conj; \ \ @@ -78,14 +81,14 @@ void PASTEMAC3(ch,opname,arch,suf) \ if ( incx == 1 ) \ { \ PRAGMA_SIMD \ - for ( dim_t i = 0; i < n; ++i ) \ + for ( dim_t i = 0; i < n0; ++i ) \ { \ PASTEMAC(ch,scals)( alpha_conj, x[i] ); \ } \ } \ else \ { \ - for ( dim_t i = 0; i < n; ++i ) \ + for ( dim_t i = 0; i < n0; ++i ) \ { \ PASTEMAC(ch,scals)( alpha_conj, *x ); \ \ diff --git a/ref_kernels/1m/bli_packm_cxk_3mis_ref.c b/ref_kernels/1m/bli_packm_cxk_3mis_ref.c deleted file mode 100644 index 0647ec22fb..0000000000 --- a/ref_kernels/1m/bli_packm_cxk_3mis_ref.c +++ /dev/null @@ -1,1954 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, mnr, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - conj_t conja, \ - dim_t cdim, \ - dim_t n, \ - dim_t n_max, \ - ctype* restrict kappa, \ - ctype* restrict a, inc_t inca, inc_t lda, \ - ctype* restrict p, inc_t is_p, inc_t ldp, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const inc_t inca2 = 2 * inca; \ - const inc_t lda2 = 2 * lda; \ -\ - ctype* kappa_cast = kappa; \ - ctype_r* restrict kappa_r = ( ctype_r* )kappa; \ - ctype_r* restrict kappa_i = ( ctype_r* )kappa + 1; \ - ctype_r* restrict alpha1_r = ( ctype_r* )a; \ - ctype_r* restrict alpha1_i = ( ctype_r* )a + 1; \ - ctype_r* restrict pi1_r = ( ctype_r* )p; \ - ctype_r* restrict pi1_i = ( ctype_r* )p + is_p; \ - ctype_r* restrict pi1_rpi = ( ctype_r* )p + 2*is_p; \ -\ - if ( cdim == mnr ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - } \ - } \ - else /* if ( cdim < mnr ) */ \ - { \ - PASTEMAC(ch,scal2ri3s_mxn) \ - ( \ - conja, \ - cdim, \ - n, \ - kappa, \ - a, inca, lda, \ - p, 1, ldp, is_p \ - ); \ -\ - /* if ( cdim < mnr ) */ \ - { \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - const dim_t i = cdim; \ - const dim_t m_edge = mnr - i; \ - const dim_t n_edge = n_max; \ - ctype_r* p_edge_r = ( ctype_r* )p + (i )*1; \ - ctype_r* p_edge_i = ( ctype_r* )p + is_p + (i )*1; \ - ctype_r* p_edge_rpi = ( ctype_r* )p + 2*is_p + (i )*1; \ -\ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_r, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_i, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_rpi, 1, ldp, \ - cntx, \ - NULL \ - ); \ - } \ - } \ -\ - if ( n < n_max ) \ - { \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - const dim_t j = n; \ - const dim_t m_edge = mnr; \ - const dim_t n_edge = n_max - j; \ - ctype_r* p_edge_r = ( ctype_r* )p + (j )*ldp; \ - ctype_r* p_edge_i = ( ctype_r* )p + is_p + (j )*ldp; \ - ctype_r* p_edge_rpi = ( ctype_r* )p + 2*is_p + (j )*ldp; \ -\ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_r, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_i, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_rpi, 1, ldp, \ - cntx, \ - NULL \ - ); \ - } \ -} - -INSERT_GENTFUNCCO_BASIC3( packm_2xk_3mis, 2, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) - - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, mnr, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - conj_t conja, \ - dim_t cdim, \ - dim_t n, \ - dim_t n_max, \ - ctype* restrict kappa, \ - ctype* restrict a, inc_t inca, inc_t lda, \ - ctype* restrict p, inc_t is_p, inc_t ldp, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const inc_t inca2 = 2 * inca; \ - const inc_t lda2 = 2 * lda; \ -\ - ctype* kappa_cast = kappa; \ - ctype_r* restrict kappa_r = ( ctype_r* )kappa; \ - ctype_r* restrict kappa_i = ( ctype_r* )kappa + 1; \ - ctype_r* restrict alpha1_r = ( ctype_r* )a; \ - ctype_r* restrict alpha1_i = ( ctype_r* )a + 1; \ - ctype_r* restrict pi1_r = ( ctype_r* )p; \ - ctype_r* restrict pi1_i = ( ctype_r* )p + is_p; \ - ctype_r* restrict pi1_rpi = ( ctype_r* )p + 2*is_p; \ -\ - if ( cdim == mnr ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2), *(pi1_rpi + 2) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3), *(pi1_rpi + 3) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2), *(pi1_rpi + 2) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3), *(pi1_rpi + 3) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2), *(pi1_rpi + 2) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3), *(pi1_rpi + 3) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2), *(pi1_rpi + 2) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3), *(pi1_rpi + 3) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - } \ - } \ - else /* if ( cdim < mnr ) */ \ - { \ - PASTEMAC(ch,scal2ri3s_mxn) \ - ( \ - conja, \ - cdim, \ - n, \ - kappa, \ - a, inca, lda, \ - p, 1, ldp, is_p \ - ); \ -\ - /* if ( cdim < mnr ) */ \ - { \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - const dim_t i = cdim; \ - const dim_t m_edge = mnr - i; \ - const dim_t n_edge = n_max; \ - ctype_r* p_edge_r = ( ctype_r* )p + (i )*1; \ - ctype_r* p_edge_i = ( ctype_r* )p + is_p + (i )*1; \ - ctype_r* p_edge_rpi = ( ctype_r* )p + 2*is_p + (i )*1; \ -\ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_r, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_i, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_rpi, 1, ldp, \ - cntx, \ - NULL \ - ); \ - } \ - } \ -\ - if ( n < n_max ) \ - { \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - const dim_t j = n; \ - const dim_t m_edge = mnr; \ - const dim_t n_edge = n_max - j; \ - ctype_r* p_edge_r = ( ctype_r* )p + (j )*ldp; \ - ctype_r* p_edge_i = ( ctype_r* )p + is_p + (j )*ldp; \ - ctype_r* p_edge_rpi = ( ctype_r* )p + 2*is_p + (j )*ldp; \ -\ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_r, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_i, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_rpi, 1, ldp, \ - cntx, \ - NULL \ - ); \ - } \ -} - -INSERT_GENTFUNCCO_BASIC3( packm_4xk_3mis, 4, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) - - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, mnr, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - conj_t conja, \ - dim_t cdim, \ - dim_t n, \ - dim_t n_max, \ - ctype* restrict kappa, \ - ctype* restrict a, inc_t inca, inc_t lda, \ - ctype* restrict p, inc_t is_p, inc_t ldp, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const inc_t inca2 = 2 * inca; \ - const inc_t lda2 = 2 * lda; \ -\ - ctype* kappa_cast = kappa; \ - ctype_r* restrict kappa_r = ( ctype_r* )kappa; \ - ctype_r* restrict kappa_i = ( ctype_r* )kappa + 1; \ - ctype_r* restrict alpha1_r = ( ctype_r* )a; \ - ctype_r* restrict alpha1_i = ( ctype_r* )a + 1; \ - ctype_r* restrict pi1_r = ( ctype_r* )p; \ - ctype_r* restrict pi1_i = ( ctype_r* )p + is_p; \ - ctype_r* restrict pi1_rpi = ( ctype_r* )p + 2*is_p; \ -\ - if ( cdim == mnr ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2), *(pi1_rpi + 2) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3), *(pi1_rpi + 3) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4), *(pi1_rpi + 4) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5), *(pi1_rpi + 5) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2), *(pi1_rpi + 2) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3), *(pi1_rpi + 3) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4), *(pi1_rpi + 4) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5), *(pi1_rpi + 5) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2), *(pi1_rpi + 2) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3), *(pi1_rpi + 3) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4), *(pi1_rpi + 4) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5), *(pi1_rpi + 5) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2), *(pi1_rpi + 2) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3), *(pi1_rpi + 3) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4), *(pi1_rpi + 4) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5), *(pi1_rpi + 5) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - } \ - } \ - else /* if ( cdim < mnr ) */ \ - { \ - PASTEMAC(ch,scal2ri3s_mxn) \ - ( \ - conja, \ - cdim, \ - n, \ - kappa, \ - a, inca, lda, \ - p, 1, ldp, is_p \ - ); \ -\ - /* if ( cdim < mnr ) */ \ - { \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - const dim_t i = cdim; \ - const dim_t m_edge = mnr - i; \ - const dim_t n_edge = n_max; \ - ctype_r* p_edge_r = ( ctype_r* )p + (i )*1; \ - ctype_r* p_edge_i = ( ctype_r* )p + is_p + (i )*1; \ - ctype_r* p_edge_rpi = ( ctype_r* )p + 2*is_p + (i )*1; \ -\ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_r, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_i, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_rpi, 1, ldp, \ - cntx, \ - NULL \ - ); \ - } \ - } \ -\ - if ( n < n_max ) \ - { \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - const dim_t j = n; \ - const dim_t m_edge = mnr; \ - const dim_t n_edge = n_max - j; \ - ctype_r* p_edge_r = ( ctype_r* )p + (j )*ldp; \ - ctype_r* p_edge_i = ( ctype_r* )p + is_p + (j )*ldp; \ - ctype_r* p_edge_rpi = ( ctype_r* )p + 2*is_p + (j )*ldp; \ -\ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_r, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_i, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_rpi, 1, ldp, \ - cntx, \ - NULL \ - ); \ - } \ -} - -INSERT_GENTFUNCCO_BASIC3( packm_6xk_3mis, 6, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) - - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, mnr, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - conj_t conja, \ - dim_t cdim, \ - dim_t n, \ - dim_t n_max, \ - ctype* restrict kappa, \ - ctype* restrict a, inc_t inca, inc_t lda, \ - ctype* restrict p, inc_t is_p, inc_t ldp, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const inc_t inca2 = 2 * inca; \ - const inc_t lda2 = 2 * lda; \ -\ - ctype* kappa_cast = kappa; \ - ctype_r* restrict kappa_r = ( ctype_r* )kappa; \ - ctype_r* restrict kappa_i = ( ctype_r* )kappa + 1; \ - ctype_r* restrict alpha1_r = ( ctype_r* )a; \ - ctype_r* restrict alpha1_i = ( ctype_r* )a + 1; \ - ctype_r* restrict pi1_r = ( ctype_r* )p; \ - ctype_r* restrict pi1_i = ( ctype_r* )p + is_p; \ - ctype_r* restrict pi1_rpi = ( ctype_r* )p + 2*is_p; \ -\ - if ( cdim == mnr ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2), *(pi1_rpi + 2) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3), *(pi1_rpi + 3) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4), *(pi1_rpi + 4) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5), *(pi1_rpi + 5) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6), *(pi1_rpi + 6) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7), *(pi1_rpi + 7) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2), *(pi1_rpi + 2) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3), *(pi1_rpi + 3) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4), *(pi1_rpi + 4) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5), *(pi1_rpi + 5) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6), *(pi1_rpi + 6) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7), *(pi1_rpi + 7) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2), *(pi1_rpi + 2) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3), *(pi1_rpi + 3) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4), *(pi1_rpi + 4) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5), *(pi1_rpi + 5) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6), *(pi1_rpi + 6) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7), *(pi1_rpi + 7) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2), *(pi1_rpi + 2) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3), *(pi1_rpi + 3) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4), *(pi1_rpi + 4) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5), *(pi1_rpi + 5) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6), *(pi1_rpi + 6) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7), *(pi1_rpi + 7) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - } \ - } \ - else /* if ( cdim < mnr ) */ \ - { \ - PASTEMAC(ch,scal2ri3s_mxn) \ - ( \ - conja, \ - cdim, \ - n, \ - kappa, \ - a, inca, lda, \ - p, 1, ldp, is_p \ - ); \ -\ - /* if ( cdim < mnr ) */ \ - { \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - const dim_t i = cdim; \ - const dim_t m_edge = mnr - i; \ - const dim_t n_edge = n_max; \ - ctype_r* p_edge_r = ( ctype_r* )p + (i )*1; \ - ctype_r* p_edge_i = ( ctype_r* )p + is_p + (i )*1; \ - ctype_r* p_edge_rpi = ( ctype_r* )p + 2*is_p + (i )*1; \ -\ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_r, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_i, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_rpi, 1, ldp, \ - cntx, \ - NULL \ - ); \ - } \ - } \ -\ - if ( n < n_max ) \ - { \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - const dim_t j = n; \ - const dim_t m_edge = mnr; \ - const dim_t n_edge = n_max - j; \ - ctype_r* p_edge_r = ( ctype_r* )p + (j )*ldp; \ - ctype_r* p_edge_i = ( ctype_r* )p + is_p + (j )*ldp; \ - ctype_r* p_edge_rpi = ( ctype_r* )p + 2*is_p + (j )*ldp; \ -\ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_r, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_i, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_rpi, 1, ldp, \ - cntx, \ - NULL \ - ); \ - } \ -} - -INSERT_GENTFUNCCO_BASIC3( packm_8xk_3mis, 8, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) - - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, mnr, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - conj_t conja, \ - dim_t cdim, \ - dim_t n, \ - dim_t n_max, \ - ctype* restrict kappa, \ - ctype* restrict a, inc_t inca, inc_t lda, \ - ctype* restrict p, inc_t is_p, inc_t ldp, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const inc_t inca2 = 2 * inca; \ - const inc_t lda2 = 2 * lda; \ -\ - ctype* kappa_cast = kappa; \ - ctype_r* restrict kappa_r = ( ctype_r* )kappa; \ - ctype_r* restrict kappa_i = ( ctype_r* )kappa + 1; \ - ctype_r* restrict alpha1_r = ( ctype_r* )a; \ - ctype_r* restrict alpha1_i = ( ctype_r* )a + 1; \ - ctype_r* restrict pi1_r = ( ctype_r* )p; \ - ctype_r* restrict pi1_i = ( ctype_r* )p + is_p; \ - ctype_r* restrict pi1_rpi = ( ctype_r* )p + 2*is_p; \ -\ - if ( cdim == mnr ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2), *(pi1_rpi + 2) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3), *(pi1_rpi + 3) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4), *(pi1_rpi + 4) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5), *(pi1_rpi + 5) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6), *(pi1_rpi + 6) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7), *(pi1_rpi + 7) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8), *(pi1_rpi + 8) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9), *(pi1_rpi + 9) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2), *(pi1_rpi + 2) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3), *(pi1_rpi + 3) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4), *(pi1_rpi + 4) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5), *(pi1_rpi + 5) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6), *(pi1_rpi + 6) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7), *(pi1_rpi + 7) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8), *(pi1_rpi + 8) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9), *(pi1_rpi + 9) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2), *(pi1_rpi + 2) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3), *(pi1_rpi + 3) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4), *(pi1_rpi + 4) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5), *(pi1_rpi + 5) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6), *(pi1_rpi + 6) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7), *(pi1_rpi + 7) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8), *(pi1_rpi + 8) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9), *(pi1_rpi + 9) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2), *(pi1_rpi + 2) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3), *(pi1_rpi + 3) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4), *(pi1_rpi + 4) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5), *(pi1_rpi + 5) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6), *(pi1_rpi + 6) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7), *(pi1_rpi + 7) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8), *(pi1_rpi + 8) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9), *(pi1_rpi + 9) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - } \ - } \ - else /* if ( cdim < mnr ) */ \ - { \ - PASTEMAC(ch,scal2ri3s_mxn) \ - ( \ - conja, \ - cdim, \ - n, \ - kappa, \ - a, inca, lda, \ - p, 1, ldp, is_p \ - ); \ -\ - /* if ( cdim < mnr ) */ \ - { \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - const dim_t i = cdim; \ - const dim_t m_edge = mnr - i; \ - const dim_t n_edge = n_max; \ - ctype_r* p_edge_r = ( ctype_r* )p + (i )*1; \ - ctype_r* p_edge_i = ( ctype_r* )p + is_p + (i )*1; \ - ctype_r* p_edge_rpi = ( ctype_r* )p + 2*is_p + (i )*1; \ -\ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_r, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_i, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_rpi, 1, ldp, \ - cntx, \ - NULL \ - ); \ - } \ - } \ -\ - if ( n < n_max ) \ - { \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - const dim_t j = n; \ - const dim_t m_edge = mnr; \ - const dim_t n_edge = n_max - j; \ - ctype_r* p_edge_r = ( ctype_r* )p + (j )*ldp; \ - ctype_r* p_edge_i = ( ctype_r* )p + is_p + (j )*ldp; \ - ctype_r* p_edge_rpi = ( ctype_r* )p + 2*is_p + (j )*ldp; \ -\ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_r, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_i, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_rpi, 1, ldp, \ - cntx, \ - NULL \ - ); \ - } \ -} - -INSERT_GENTFUNCCO_BASIC3( packm_10xk_3mis, 10, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) - - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, mnr, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - conj_t conja, \ - dim_t cdim, \ - dim_t n, \ - dim_t n_max, \ - ctype* restrict kappa, \ - ctype* restrict a, inc_t inca, inc_t lda, \ - ctype* restrict p, inc_t is_p, inc_t ldp, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const inc_t inca2 = 2 * inca; \ - const inc_t lda2 = 2 * lda; \ -\ - ctype* kappa_cast = kappa; \ - ctype_r* restrict kappa_r = ( ctype_r* )kappa; \ - ctype_r* restrict kappa_i = ( ctype_r* )kappa + 1; \ - ctype_r* restrict alpha1_r = ( ctype_r* )a; \ - ctype_r* restrict alpha1_i = ( ctype_r* )a + 1; \ - ctype_r* restrict pi1_r = ( ctype_r* )p; \ - ctype_r* restrict pi1_i = ( ctype_r* )p + is_p; \ - ctype_r* restrict pi1_rpi = ( ctype_r* )p + 2*is_p; \ -\ - if ( cdim == mnr ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2), *(pi1_rpi + 2) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3), *(pi1_rpi + 3) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4), *(pi1_rpi + 4) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5), *(pi1_rpi + 5) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6), *(pi1_rpi + 6) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7), *(pi1_rpi + 7) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8), *(pi1_rpi + 8) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9), *(pi1_rpi + 9) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r +10*inca2), *(alpha1_i +10*inca2), *(pi1_r +10), *(pi1_i +10), *(pi1_rpi +10) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r +11*inca2), *(alpha1_i +11*inca2), *(pi1_r +11), *(pi1_i +11), *(pi1_rpi +11) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2), *(pi1_rpi + 2) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3), *(pi1_rpi + 3) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4), *(pi1_rpi + 4) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5), *(pi1_rpi + 5) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6), *(pi1_rpi + 6) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7), *(pi1_rpi + 7) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8), *(pi1_rpi + 8) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9), *(pi1_rpi + 9) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r +10*inca2), *(alpha1_i +10*inca2), *(pi1_r +10), *(pi1_i +10), *(pi1_rpi +10) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r +11*inca2), *(alpha1_i +11*inca2), *(pi1_r +11), *(pi1_i +11), *(pi1_rpi +11) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2), *(pi1_rpi + 2) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3), *(pi1_rpi + 3) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4), *(pi1_rpi + 4) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5), *(pi1_rpi + 5) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6), *(pi1_rpi + 6) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7), *(pi1_rpi + 7) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8), *(pi1_rpi + 8) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9), *(pi1_rpi + 9) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r +10*inca2), *(alpha1_i +10*inca2), *(pi1_r +10), *(pi1_i +10), *(pi1_rpi +10) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r +11*inca2), *(alpha1_i +11*inca2), *(pi1_r +11), *(pi1_i +11), *(pi1_rpi +11) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2), *(pi1_rpi + 2) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3), *(pi1_rpi + 3) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4), *(pi1_rpi + 4) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5), *(pi1_rpi + 5) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6), *(pi1_rpi + 6) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7), *(pi1_rpi + 7) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8), *(pi1_rpi + 8) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9), *(pi1_rpi + 9) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r +10*inca2), *(alpha1_i +10*inca2), *(pi1_r +10), *(pi1_i +10), *(pi1_rpi +10) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r +11*inca2), *(alpha1_i +11*inca2), *(pi1_r +11), *(pi1_i +11), *(pi1_rpi +11) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - } \ - } \ - else /* if ( cdim < mnr ) */ \ - { \ - PASTEMAC(ch,scal2ri3s_mxn) \ - ( \ - conja, \ - cdim, \ - n, \ - kappa, \ - a, inca, lda, \ - p, 1, ldp, is_p \ - ); \ -\ - /* if ( cdim < mnr ) */ \ - { \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - const dim_t i = cdim; \ - const dim_t m_edge = mnr - i; \ - const dim_t n_edge = n_max; \ - ctype_r* p_edge_r = ( ctype_r* )p + (i )*1; \ - ctype_r* p_edge_i = ( ctype_r* )p + is_p + (i )*1; \ - ctype_r* p_edge_rpi = ( ctype_r* )p + 2*is_p + (i )*1; \ -\ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_r, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_i, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_rpi, 1, ldp, \ - cntx, \ - NULL \ - ); \ - } \ - } \ -\ - if ( n < n_max ) \ - { \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - const dim_t j = n; \ - const dim_t m_edge = mnr; \ - const dim_t n_edge = n_max - j; \ - ctype_r* p_edge_r = ( ctype_r* )p + (j )*ldp; \ - ctype_r* p_edge_i = ( ctype_r* )p + is_p + (j )*ldp; \ - ctype_r* p_edge_rpi = ( ctype_r* )p + 2*is_p + (j )*ldp; \ -\ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_r, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_i, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_rpi, 1, ldp, \ - cntx, \ - NULL \ - ); \ - } \ -} - -INSERT_GENTFUNCCO_BASIC3( packm_12xk_3mis, 12, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) - - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, mnr, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - conj_t conja, \ - dim_t cdim, \ - dim_t n, \ - dim_t n_max, \ - ctype* restrict kappa, \ - ctype* restrict a, inc_t inca, inc_t lda, \ - ctype* restrict p, inc_t is_p, inc_t ldp, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const inc_t inca2 = 2 * inca; \ - const inc_t lda2 = 2 * lda; \ -\ - ctype* kappa_cast = kappa; \ - ctype_r* restrict kappa_r = ( ctype_r* )kappa; \ - ctype_r* restrict kappa_i = ( ctype_r* )kappa + 1; \ - ctype_r* restrict alpha1_r = ( ctype_r* )a; \ - ctype_r* restrict alpha1_i = ( ctype_r* )a + 1; \ - ctype_r* restrict pi1_r = ( ctype_r* )p; \ - ctype_r* restrict pi1_i = ( ctype_r* )p + is_p; \ - ctype_r* restrict pi1_rpi = ( ctype_r* )p + 2*is_p; \ -\ - if ( cdim == mnr ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2), *(pi1_rpi + 2) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3), *(pi1_rpi + 3) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4), *(pi1_rpi + 4) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5), *(pi1_rpi + 5) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6), *(pi1_rpi + 6) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7), *(pi1_rpi + 7) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8), *(pi1_rpi + 8) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9), *(pi1_rpi + 9) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r +10*inca2), *(alpha1_i +10*inca2), *(pi1_r +10), *(pi1_i +10), *(pi1_rpi +10) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r +11*inca2), *(alpha1_i +11*inca2), *(pi1_r +11), *(pi1_i +11), *(pi1_rpi +11) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r +12*inca2), *(alpha1_i +12*inca2), *(pi1_r +12), *(pi1_i +12), *(pi1_rpi +12) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r +13*inca2), *(alpha1_i +13*inca2), *(pi1_r +13), *(pi1_i +13), *(pi1_rpi +13) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2), *(pi1_rpi + 2) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3), *(pi1_rpi + 3) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4), *(pi1_rpi + 4) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5), *(pi1_rpi + 5) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6), *(pi1_rpi + 6) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7), *(pi1_rpi + 7) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8), *(pi1_rpi + 8) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9), *(pi1_rpi + 9) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r +10*inca2), *(alpha1_i +10*inca2), *(pi1_r +10), *(pi1_i +10), *(pi1_rpi +10) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r +11*inca2), *(alpha1_i +11*inca2), *(pi1_r +11), *(pi1_i +11), *(pi1_rpi +11) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r +12*inca2), *(alpha1_i +12*inca2), *(pi1_r +12), *(pi1_i +12), *(pi1_rpi +12) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r +13*inca2), *(alpha1_i +13*inca2), *(pi1_r +13), *(pi1_i +13), *(pi1_rpi +13) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2), *(pi1_rpi + 2) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3), *(pi1_rpi + 3) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4), *(pi1_rpi + 4) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5), *(pi1_rpi + 5) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6), *(pi1_rpi + 6) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7), *(pi1_rpi + 7) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8), *(pi1_rpi + 8) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9), *(pi1_rpi + 9) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r +10*inca2), *(alpha1_i +10*inca2), *(pi1_r +10), *(pi1_i +10), *(pi1_rpi +10) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r +11*inca2), *(alpha1_i +11*inca2), *(pi1_r +11), *(pi1_i +11), *(pi1_rpi +11) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r +12*inca2), *(alpha1_i +12*inca2), *(pi1_r +12), *(pi1_i +12), *(pi1_rpi +12) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r +13*inca2), *(alpha1_i +13*inca2), *(pi1_r +13), *(pi1_i +13), *(pi1_rpi +13) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2), *(pi1_rpi + 2) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3), *(pi1_rpi + 3) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4), *(pi1_rpi + 4) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5), *(pi1_rpi + 5) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6), *(pi1_rpi + 6) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7), *(pi1_rpi + 7) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8), *(pi1_rpi + 8) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9), *(pi1_rpi + 9) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r +10*inca2), *(alpha1_i +10*inca2), *(pi1_r +10), *(pi1_i +10), *(pi1_rpi +10) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r +11*inca2), *(alpha1_i +11*inca2), *(pi1_r +11), *(pi1_i +11), *(pi1_rpi +11) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r +12*inca2), *(alpha1_i +12*inca2), *(pi1_r +12), *(pi1_i +12), *(pi1_rpi +12) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r +13*inca2), *(alpha1_i +13*inca2), *(pi1_r +13), *(pi1_i +13), *(pi1_rpi +13) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - } \ - } \ - else /* if ( cdim < mnr ) */ \ - { \ - PASTEMAC(ch,scal2ri3s_mxn) \ - ( \ - conja, \ - cdim, \ - n, \ - kappa, \ - a, inca, lda, \ - p, 1, ldp, is_p \ - ); \ -\ - /* if ( cdim < mnr ) */ \ - { \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - const dim_t i = cdim; \ - const dim_t m_edge = mnr - i; \ - const dim_t n_edge = n_max; \ - ctype_r* p_edge_r = ( ctype_r* )p + (i )*1; \ - ctype_r* p_edge_i = ( ctype_r* )p + is_p + (i )*1; \ - ctype_r* p_edge_rpi = ( ctype_r* )p + 2*is_p + (i )*1; \ -\ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_r, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_i, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_rpi, 1, ldp, \ - cntx, \ - NULL \ - ); \ - } \ - } \ -\ - if ( n < n_max ) \ - { \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - const dim_t j = n; \ - const dim_t m_edge = mnr; \ - const dim_t n_edge = n_max - j; \ - ctype_r* p_edge_r = ( ctype_r* )p + (j )*ldp; \ - ctype_r* p_edge_i = ( ctype_r* )p + is_p + (j )*ldp; \ - ctype_r* p_edge_rpi = ( ctype_r* )p + 2*is_p + (j )*ldp; \ -\ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_r, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_i, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_rpi, 1, ldp, \ - cntx, \ - NULL \ - ); \ - } \ -} - -INSERT_GENTFUNCCO_BASIC3( packm_14xk_3mis, 14, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) - - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, mnr, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - conj_t conja, \ - dim_t cdim, \ - dim_t n, \ - dim_t n_max, \ - ctype* restrict kappa, \ - ctype* restrict a, inc_t inca, inc_t lda, \ - ctype* restrict p, inc_t is_p, inc_t ldp, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const inc_t inca2 = 2 * inca; \ - const inc_t lda2 = 2 * lda; \ -\ - ctype* kappa_cast = kappa; \ - ctype_r* restrict kappa_r = ( ctype_r* )kappa; \ - ctype_r* restrict kappa_i = ( ctype_r* )kappa + 1; \ - ctype_r* restrict alpha1_r = ( ctype_r* )a; \ - ctype_r* restrict alpha1_i = ( ctype_r* )a + 1; \ - ctype_r* restrict pi1_r = ( ctype_r* )p; \ - ctype_r* restrict pi1_i = ( ctype_r* )p + is_p; \ - ctype_r* restrict pi1_rpi = ( ctype_r* )p + 2*is_p; \ -\ - if ( cdim == mnr ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2), *(pi1_rpi + 2) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3), *(pi1_rpi + 3) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4), *(pi1_rpi + 4) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5), *(pi1_rpi + 5) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6), *(pi1_rpi + 6) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7), *(pi1_rpi + 7) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8), *(pi1_rpi + 8) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9), *(pi1_rpi + 9) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r +10*inca2), *(alpha1_i +10*inca2), *(pi1_r +10), *(pi1_i +10), *(pi1_rpi +10) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r +11*inca2), *(alpha1_i +11*inca2), *(pi1_r +11), *(pi1_i +11), *(pi1_rpi +11) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r +12*inca2), *(alpha1_i +12*inca2), *(pi1_r +12), *(pi1_i +12), *(pi1_rpi +12) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r +13*inca2), *(alpha1_i +13*inca2), *(pi1_r +13), *(pi1_i +13), *(pi1_rpi +13) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r +14*inca2), *(alpha1_i +14*inca2), *(pi1_r +14), *(pi1_i +14), *(pi1_rpi +14) ); \ - PASTEMAC(ch,copyjri3s)( *(alpha1_r +15*inca2), *(alpha1_i +15*inca2), *(pi1_r +15), *(pi1_i +15), *(pi1_rpi +15) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2), *(pi1_rpi + 2) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3), *(pi1_rpi + 3) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4), *(pi1_rpi + 4) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5), *(pi1_rpi + 5) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6), *(pi1_rpi + 6) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7), *(pi1_rpi + 7) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8), *(pi1_rpi + 8) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9), *(pi1_rpi + 9) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r +10*inca2), *(alpha1_i +10*inca2), *(pi1_r +10), *(pi1_i +10), *(pi1_rpi +10) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r +11*inca2), *(alpha1_i +11*inca2), *(pi1_r +11), *(pi1_i +11), *(pi1_rpi +11) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r +12*inca2), *(alpha1_i +12*inca2), *(pi1_r +12), *(pi1_i +12), *(pi1_rpi +12) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r +13*inca2), *(alpha1_i +13*inca2), *(pi1_r +13), *(pi1_i +13), *(pi1_rpi +13) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r +14*inca2), *(alpha1_i +14*inca2), *(pi1_r +14), *(pi1_i +14), *(pi1_rpi +14) ); \ - PASTEMAC(ch,copyri3s)( *(alpha1_r +15*inca2), *(alpha1_i +15*inca2), *(pi1_r +15), *(pi1_i +15), *(pi1_rpi +15) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2), *(pi1_rpi + 2) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3), *(pi1_rpi + 3) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4), *(pi1_rpi + 4) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5), *(pi1_rpi + 5) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6), *(pi1_rpi + 6) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7), *(pi1_rpi + 7) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8), *(pi1_rpi + 8) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9), *(pi1_rpi + 9) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r +10*inca2), *(alpha1_i +10*inca2), *(pi1_r +10), *(pi1_i +10), *(pi1_rpi +10) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r +11*inca2), *(alpha1_i +11*inca2), *(pi1_r +11), *(pi1_i +11), *(pi1_rpi +11) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r +12*inca2), *(alpha1_i +12*inca2), *(pi1_r +12), *(pi1_i +12), *(pi1_rpi +12) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r +13*inca2), *(alpha1_i +13*inca2), *(pi1_r +13), *(pi1_i +13), *(pi1_rpi +13) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r +14*inca2), *(alpha1_i +14*inca2), *(pi1_r +14), *(pi1_i +14), *(pi1_rpi +14) ); \ - PASTEMAC(ch,scal2jri3s)( *kappa_r, *kappa_i, *(alpha1_r +15*inca2), *(alpha1_i +15*inca2), *(pi1_r +15), *(pi1_i +15), *(pi1_rpi +15) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0), *(pi1_rpi + 0) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1), *(pi1_rpi + 1) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2), *(pi1_rpi + 2) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3), *(pi1_rpi + 3) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4), *(pi1_rpi + 4) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5), *(pi1_rpi + 5) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6), *(pi1_rpi + 6) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7), *(pi1_rpi + 7) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8), *(pi1_rpi + 8) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9), *(pi1_rpi + 9) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r +10*inca2), *(alpha1_i +10*inca2), *(pi1_r +10), *(pi1_i +10), *(pi1_rpi +10) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r +11*inca2), *(alpha1_i +11*inca2), *(pi1_r +11), *(pi1_i +11), *(pi1_rpi +11) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r +12*inca2), *(alpha1_i +12*inca2), *(pi1_r +12), *(pi1_i +12), *(pi1_rpi +12) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r +13*inca2), *(alpha1_i +13*inca2), *(pi1_r +13), *(pi1_i +13), *(pi1_rpi +13) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r +14*inca2), *(alpha1_i +14*inca2), *(pi1_r +14), *(pi1_i +14), *(pi1_rpi +14) ); \ - PASTEMAC(ch,scal2ri3s)( *kappa_r, *kappa_i, *(alpha1_r +15*inca2), *(alpha1_i +15*inca2), *(pi1_r +15), *(pi1_i +15), *(pi1_rpi +15) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - pi1_rpi += ldp; \ - } \ - } \ - } \ - } \ - else /* if ( cdim < mnr ) */ \ - { \ - PASTEMAC(ch,scal2ri3s_mxn) \ - ( \ - conja, \ - cdim, \ - n, \ - kappa, \ - a, inca, lda, \ - p, 1, ldp, is_p \ - ); \ -\ - /* if ( cdim < mnr ) */ \ - { \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - const dim_t i = cdim; \ - const dim_t m_edge = mnr - i; \ - const dim_t n_edge = n_max; \ - ctype_r* p_edge_r = ( ctype_r* )p + (i )*1; \ - ctype_r* p_edge_i = ( ctype_r* )p + is_p + (i )*1; \ - ctype_r* p_edge_rpi = ( ctype_r* )p + 2*is_p + (i )*1; \ -\ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_r, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_i, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_rpi, 1, ldp, \ - cntx, \ - NULL \ - ); \ - } \ - } \ -\ - if ( n < n_max ) \ - { \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - const dim_t j = n; \ - const dim_t m_edge = mnr; \ - const dim_t n_edge = n_max - j; \ - ctype_r* p_edge_r = ( ctype_r* )p + (j )*ldp; \ - ctype_r* p_edge_i = ( ctype_r* )p + is_p + (j )*ldp; \ - ctype_r* p_edge_rpi = ( ctype_r* )p + 2*is_p + (j )*ldp; \ -\ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_r, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_i, 1, ldp, \ - cntx, \ - NULL \ - ); \ - PASTEMAC2(chr,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - 0, \ - BLIS_NONUNIT_DIAG, \ - BLIS_DENSE, \ - m_edge, \ - n_edge, \ - zero_r, \ - p_edge_rpi, 1, ldp, \ - cntx, \ - NULL \ - ); \ - } \ -} - -INSERT_GENTFUNCCO_BASIC3( packm_16xk_3mis, 16, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) - diff --git a/ref_kernels/1m/bli_packm_cxk_4mi_ref.c b/ref_kernels/1m/bli_packm_cxk_4mi_ref.c deleted file mode 100644 index d0a4210675..0000000000 --- a/ref_kernels/1m/bli_packm_cxk_4mi_ref.c +++ /dev/null @@ -1,1450 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, mnr, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - conj_t conja, \ - dim_t cdim, \ - dim_t n, \ - dim_t n_max, \ - ctype* restrict kappa, \ - ctype* restrict a, inc_t inca, inc_t lda, \ - ctype* restrict p, inc_t is_p, inc_t ldp, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const inc_t inca2 = 2 * inca; \ - const inc_t lda2 = 2 * lda; \ -\ - ctype* kappa_cast = kappa; \ - ctype_r* restrict kappa_r = ( ctype_r* )kappa; \ - ctype_r* restrict kappa_i = ( ctype_r* )kappa + 1; \ - ctype_r* restrict alpha1_r = ( ctype_r* )a; \ - ctype_r* restrict alpha1_i = ( ctype_r* )a + 1; \ - ctype_r* restrict pi1_r = ( ctype_r* )p; \ - ctype_r* restrict pi1_i = ( ctype_r* )p + is_p; \ -\ - if ( cdim == mnr ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyris)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - } \ - } \ - else /* if ( cdim < mnr ) */ \ - { \ - PASTEMAC(ch,scal2ris_mxn) \ - ( \ - conja, \ - cdim, \ - n, \ - kappa, \ - a, inca, lda, \ - p, 1, ldp, is_p \ - ); \ -\ - /* if ( cdim < mnr ) */ \ - { \ - const dim_t i = cdim; \ - const dim_t m_edge = mnr - i; \ - const dim_t n_edge = n_max; \ - ctype_r* restrict p_edge_r = ( ctype_r* )p + (i )*1; \ - ctype_r* restrict p_edge_i = ( ctype_r* )p + is_p + (i )*1; \ -\ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_r, 1, ldp \ - ); \ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_i, 1, ldp \ - ); \ - } \ - } \ -\ - if ( n < n_max ) \ - { \ - const dim_t j = n; \ - const dim_t m_edge = mnr; \ - const dim_t n_edge = n_max - j; \ - ctype_r* restrict p_edge_r = ( ctype_r* )p + (j )*ldp; \ - ctype_r* restrict p_edge_i = ( ctype_r* )p + is_p + (j )*ldp; \ -\ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_r, 1, ldp \ - ); \ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_i, 1, ldp \ - ); \ - } \ -} - -INSERT_GENTFUNCCO_BASIC3( packm_2xk_4mi, 2, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) - - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, mnr, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - conj_t conja, \ - dim_t cdim, \ - dim_t n, \ - dim_t n_max, \ - ctype* restrict kappa, \ - ctype* restrict a, inc_t inca, inc_t lda, \ - ctype* restrict p, inc_t is_p, inc_t ldp, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const inc_t inca2 = 2 * inca; \ - const inc_t lda2 = 2 * lda; \ -\ - ctype* kappa_cast = kappa; \ - ctype_r* restrict kappa_r = ( ctype_r* )kappa; \ - ctype_r* restrict kappa_i = ( ctype_r* )kappa + 1; \ - ctype_r* restrict alpha1_r = ( ctype_r* )a; \ - ctype_r* restrict alpha1_i = ( ctype_r* )a + 1; \ - ctype_r* restrict pi1_r = ( ctype_r* )p; \ - ctype_r* restrict pi1_i = ( ctype_r* )p + is_p; \ -\ - if ( cdim == mnr ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyris)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - } \ - } \ - else /* if ( cdim < mnr ) */ \ - { \ - PASTEMAC(ch,scal2ris_mxn) \ - ( \ - conja, \ - cdim, \ - n, \ - kappa, \ - a, inca, lda, \ - p, 1, ldp, is_p \ - ); \ -\ - /* if ( cdim < mnr ) */ \ - { \ - const dim_t i = cdim; \ - const dim_t m_edge = mnr - i; \ - const dim_t n_edge = n_max; \ - ctype_r* restrict p_edge_r = ( ctype_r* )p + (i )*1; \ - ctype_r* restrict p_edge_i = ( ctype_r* )p + is_p + (i )*1; \ -\ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_r, 1, ldp \ - ); \ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_i, 1, ldp \ - ); \ - } \ - } \ -\ - if ( n < n_max ) \ - { \ - const dim_t j = n; \ - const dim_t m_edge = mnr; \ - const dim_t n_edge = n_max - j; \ - ctype_r* restrict p_edge_r = ( ctype_r* )p + (j )*ldp; \ - ctype_r* restrict p_edge_i = ( ctype_r* )p + is_p + (j )*ldp; \ -\ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_r, 1, ldp \ - ); \ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_i, 1, ldp \ - ); \ - } \ -} - -INSERT_GENTFUNCCO_BASIC3( packm_4xk_4mi, 4, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) - - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, mnr, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - conj_t conja, \ - dim_t cdim, \ - dim_t n, \ - dim_t n_max, \ - ctype* restrict kappa, \ - ctype* restrict a, inc_t inca, inc_t lda, \ - ctype* restrict p, inc_t is_p, inc_t ldp, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const inc_t inca2 = 2 * inca; \ - const inc_t lda2 = 2 * lda; \ -\ - ctype* kappa_cast = kappa; \ - ctype_r* restrict kappa_r = ( ctype_r* )kappa; \ - ctype_r* restrict kappa_i = ( ctype_r* )kappa + 1; \ - ctype_r* restrict alpha1_r = ( ctype_r* )a; \ - ctype_r* restrict alpha1_i = ( ctype_r* )a + 1; \ - ctype_r* restrict pi1_r = ( ctype_r* )p; \ - ctype_r* restrict pi1_i = ( ctype_r* )p + is_p; \ -\ - if ( cdim == mnr ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyris)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - } \ - } \ - else /* if ( cdim < mnr ) */ \ - { \ - PASTEMAC(ch,scal2ris_mxn) \ - ( \ - conja, \ - cdim, \ - n, \ - kappa, \ - a, inca, lda, \ - p, 1, ldp, is_p \ - ); \ -\ - /* if ( cdim < mnr ) */ \ - { \ - const dim_t i = cdim; \ - const dim_t m_edge = mnr - i; \ - const dim_t n_edge = n_max; \ - ctype_r* restrict p_edge_r = ( ctype_r* )p + (i )*1; \ - ctype_r* restrict p_edge_i = ( ctype_r* )p + is_p + (i )*1; \ -\ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_r, 1, ldp \ - ); \ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_i, 1, ldp \ - ); \ - } \ - } \ -\ - if ( n < n_max ) \ - { \ - const dim_t j = n; \ - const dim_t m_edge = mnr; \ - const dim_t n_edge = n_max - j; \ - ctype_r* restrict p_edge_r = ( ctype_r* )p + (j )*ldp; \ - ctype_r* restrict p_edge_i = ( ctype_r* )p + is_p + (j )*ldp; \ -\ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_r, 1, ldp \ - ); \ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_i, 1, ldp \ - ); \ - } \ -} - -INSERT_GENTFUNCCO_BASIC3( packm_6xk_4mi, 6, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) - - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, mnr, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - conj_t conja, \ - dim_t cdim, \ - dim_t n, \ - dim_t n_max, \ - ctype* restrict kappa, \ - ctype* restrict a, inc_t inca, inc_t lda, \ - ctype* restrict p, inc_t is_p, inc_t ldp, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const inc_t inca2 = 2 * inca; \ - const inc_t lda2 = 2 * lda; \ -\ - ctype* kappa_cast = kappa; \ - ctype_r* restrict kappa_r = ( ctype_r* )kappa; \ - ctype_r* restrict kappa_i = ( ctype_r* )kappa + 1; \ - ctype_r* restrict alpha1_r = ( ctype_r* )a; \ - ctype_r* restrict alpha1_i = ( ctype_r* )a + 1; \ - ctype_r* restrict pi1_r = ( ctype_r* )p; \ - ctype_r* restrict pi1_i = ( ctype_r* )p + is_p; \ -\ - if ( cdim == mnr ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyris)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - } \ - } \ - else /* if ( cdim < mnr ) */ \ - { \ - PASTEMAC(ch,scal2ris_mxn) \ - ( \ - conja, \ - cdim, \ - n, \ - kappa, \ - a, inca, lda, \ - p, 1, ldp, is_p \ - ); \ -\ - /* if ( cdim < mnr ) */ \ - { \ - const dim_t i = cdim; \ - const dim_t m_edge = mnr - i; \ - const dim_t n_edge = n_max; \ - ctype_r* restrict p_edge_r = ( ctype_r* )p + (i )*1; \ - ctype_r* restrict p_edge_i = ( ctype_r* )p + is_p + (i )*1; \ -\ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_r, 1, ldp \ - ); \ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_i, 1, ldp \ - ); \ - } \ - } \ -\ - if ( n < n_max ) \ - { \ - const dim_t j = n; \ - const dim_t m_edge = mnr; \ - const dim_t n_edge = n_max - j; \ - ctype_r* restrict p_edge_r = ( ctype_r* )p + (j )*ldp; \ - ctype_r* restrict p_edge_i = ( ctype_r* )p + is_p + (j )*ldp; \ -\ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_r, 1, ldp \ - ); \ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_i, 1, ldp \ - ); \ - } \ -} - -INSERT_GENTFUNCCO_BASIC3( packm_8xk_4mi, 8, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) - - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, mnr, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - conj_t conja, \ - dim_t cdim, \ - dim_t n, \ - dim_t n_max, \ - ctype* restrict kappa, \ - ctype* restrict a, inc_t inca, inc_t lda, \ - ctype* restrict p, inc_t is_p, inc_t ldp, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const inc_t inca2 = 2 * inca; \ - const inc_t lda2 = 2 * lda; \ -\ - ctype* kappa_cast = kappa; \ - ctype_r* restrict kappa_r = ( ctype_r* )kappa; \ - ctype_r* restrict kappa_i = ( ctype_r* )kappa + 1; \ - ctype_r* restrict alpha1_r = ( ctype_r* )a; \ - ctype_r* restrict alpha1_i = ( ctype_r* )a + 1; \ - ctype_r* restrict pi1_r = ( ctype_r* )p; \ - ctype_r* restrict pi1_i = ( ctype_r* )p + is_p; \ -\ - if ( cdim == mnr ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyris)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - } \ - } \ - else /* if ( cdim < mnr ) */ \ - { \ - PASTEMAC(ch,scal2ris_mxn) \ - ( \ - conja, \ - cdim, \ - n, \ - kappa, \ - a, inca, lda, \ - p, 1, ldp, is_p \ - ); \ -\ - /* if ( cdim < mnr ) */ \ - { \ - const dim_t i = cdim; \ - const dim_t m_edge = mnr - i; \ - const dim_t n_edge = n_max; \ - ctype_r* restrict p_edge_r = ( ctype_r* )p + (i )*1; \ - ctype_r* restrict p_edge_i = ( ctype_r* )p + is_p + (i )*1; \ -\ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_r, 1, ldp \ - ); \ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_i, 1, ldp \ - ); \ - } \ - } \ -\ - if ( n < n_max ) \ - { \ - const dim_t j = n; \ - const dim_t m_edge = mnr; \ - const dim_t n_edge = n_max - j; \ - ctype_r* restrict p_edge_r = ( ctype_r* )p + (j )*ldp; \ - ctype_r* restrict p_edge_i = ( ctype_r* )p + is_p + (j )*ldp; \ -\ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_r, 1, ldp \ - ); \ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_i, 1, ldp \ - ); \ - } \ -} - -INSERT_GENTFUNCCO_BASIC3( packm_10xk_4mi, 10, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) - - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, mnr, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - conj_t conja, \ - dim_t cdim, \ - dim_t n, \ - dim_t n_max, \ - ctype* restrict kappa, \ - ctype* restrict a, inc_t inca, inc_t lda, \ - ctype* restrict p, inc_t is_p, inc_t ldp, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const inc_t inca2 = 2 * inca; \ - const inc_t lda2 = 2 * lda; \ -\ - ctype* kappa_cast = kappa; \ - ctype_r* restrict kappa_r = ( ctype_r* )kappa; \ - ctype_r* restrict kappa_i = ( ctype_r* )kappa + 1; \ - ctype_r* restrict alpha1_r = ( ctype_r* )a; \ - ctype_r* restrict alpha1_i = ( ctype_r* )a + 1; \ - ctype_r* restrict pi1_r = ( ctype_r* )p; \ - ctype_r* restrict pi1_i = ( ctype_r* )p + is_p; \ -\ - if ( cdim == mnr ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r +10*inca2), *(alpha1_i +10*inca2), *(pi1_r +10), *(pi1_i +10) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r +11*inca2), *(alpha1_i +11*inca2), *(pi1_r +11), *(pi1_i +11) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyris)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r +10*inca2), *(alpha1_i +10*inca2), *(pi1_r +10), *(pi1_i +10) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r +11*inca2), *(alpha1_i +11*inca2), *(pi1_r +11), *(pi1_i +11) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r +10*inca2), *(alpha1_i +10*inca2), *(pi1_r +10), *(pi1_i +10) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r +11*inca2), *(alpha1_i +11*inca2), *(pi1_r +11), *(pi1_i +11) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r +10*inca2), *(alpha1_i +10*inca2), *(pi1_r +10), *(pi1_i +10) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r +11*inca2), *(alpha1_i +11*inca2), *(pi1_r +11), *(pi1_i +11) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - } \ - } \ - else /* if ( cdim < mnr ) */ \ - { \ - PASTEMAC(ch,scal2ris_mxn) \ - ( \ - conja, \ - cdim, \ - n, \ - kappa, \ - a, inca, lda, \ - p, 1, ldp, is_p \ - ); \ -\ - /* if ( cdim < mnr ) */ \ - { \ - const dim_t i = cdim; \ - const dim_t m_edge = mnr - i; \ - const dim_t n_edge = n_max; \ - ctype_r* restrict p_edge_r = ( ctype_r* )p + (i )*1; \ - ctype_r* restrict p_edge_i = ( ctype_r* )p + is_p + (i )*1; \ -\ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_r, 1, ldp \ - ); \ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_i, 1, ldp \ - ); \ - } \ - } \ -\ - if ( n < n_max ) \ - { \ - const dim_t j = n; \ - const dim_t m_edge = mnr; \ - const dim_t n_edge = n_max - j; \ - ctype_r* restrict p_edge_r = ( ctype_r* )p + (j )*ldp; \ - ctype_r* restrict p_edge_i = ( ctype_r* )p + is_p + (j )*ldp; \ -\ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_r, 1, ldp \ - ); \ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_i, 1, ldp \ - ); \ - } \ -} - -INSERT_GENTFUNCCO_BASIC3( packm_12xk_4mi, 12, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) - - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, mnr, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - conj_t conja, \ - dim_t cdim, \ - dim_t n, \ - dim_t n_max, \ - ctype* restrict kappa, \ - ctype* restrict a, inc_t inca, inc_t lda, \ - ctype* restrict p, inc_t is_p, inc_t ldp, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const inc_t inca2 = 2 * inca; \ - const inc_t lda2 = 2 * lda; \ -\ - ctype* kappa_cast = kappa; \ - ctype_r* restrict kappa_r = ( ctype_r* )kappa; \ - ctype_r* restrict kappa_i = ( ctype_r* )kappa + 1; \ - ctype_r* restrict alpha1_r = ( ctype_r* )a; \ - ctype_r* restrict alpha1_i = ( ctype_r* )a + 1; \ - ctype_r* restrict pi1_r = ( ctype_r* )p; \ - ctype_r* restrict pi1_i = ( ctype_r* )p + is_p; \ -\ - if ( cdim == mnr ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r +10*inca2), *(alpha1_i +10*inca2), *(pi1_r +10), *(pi1_i +10) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r +11*inca2), *(alpha1_i +11*inca2), *(pi1_r +11), *(pi1_i +11) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r +12*inca2), *(alpha1_i +12*inca2), *(pi1_r +12), *(pi1_i +12) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r +13*inca2), *(alpha1_i +13*inca2), *(pi1_r +13), *(pi1_i +13) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyris)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r +10*inca2), *(alpha1_i +10*inca2), *(pi1_r +10), *(pi1_i +10) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r +11*inca2), *(alpha1_i +11*inca2), *(pi1_r +11), *(pi1_i +11) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r +12*inca2), *(alpha1_i +12*inca2), *(pi1_r +12), *(pi1_i +12) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r +13*inca2), *(alpha1_i +13*inca2), *(pi1_r +13), *(pi1_i +13) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r +10*inca2), *(alpha1_i +10*inca2), *(pi1_r +10), *(pi1_i +10) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r +11*inca2), *(alpha1_i +11*inca2), *(pi1_r +11), *(pi1_i +11) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r +12*inca2), *(alpha1_i +12*inca2), *(pi1_r +12), *(pi1_i +12) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r +13*inca2), *(alpha1_i +13*inca2), *(pi1_r +13), *(pi1_i +13) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r +10*inca2), *(alpha1_i +10*inca2), *(pi1_r +10), *(pi1_i +10) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r +11*inca2), *(alpha1_i +11*inca2), *(pi1_r +11), *(pi1_i +11) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r +12*inca2), *(alpha1_i +12*inca2), *(pi1_r +12), *(pi1_i +12) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r +13*inca2), *(alpha1_i +13*inca2), *(pi1_r +13), *(pi1_i +13) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - } \ - } \ - else /* if ( cdim < mnr ) */ \ - { \ - PASTEMAC(ch,scal2ris_mxn) \ - ( \ - conja, \ - cdim, \ - n, \ - kappa, \ - a, inca, lda, \ - p, 1, ldp, is_p \ - ); \ -\ - /* if ( cdim < mnr ) */ \ - { \ - const dim_t i = cdim; \ - const dim_t m_edge = mnr - i; \ - const dim_t n_edge = n_max; \ - ctype_r* restrict p_edge_r = ( ctype_r* )p + (i )*1; \ - ctype_r* restrict p_edge_i = ( ctype_r* )p + is_p + (i )*1; \ -\ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_r, 1, ldp \ - ); \ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_i, 1, ldp \ - ); \ - } \ - } \ -\ - if ( n < n_max ) \ - { \ - const dim_t j = n; \ - const dim_t m_edge = mnr; \ - const dim_t n_edge = n_max - j; \ - ctype_r* restrict p_edge_r = ( ctype_r* )p + (j )*ldp; \ - ctype_r* restrict p_edge_i = ( ctype_r* )p + is_p + (j )*ldp; \ -\ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_r, 1, ldp \ - ); \ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_i, 1, ldp \ - ); \ - } \ -} - -INSERT_GENTFUNCCO_BASIC3( packm_14xk_4mi, 14, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) - - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, mnr, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - conj_t conja, \ - dim_t cdim, \ - dim_t n, \ - dim_t n_max, \ - ctype* restrict kappa, \ - ctype* restrict a, inc_t inca, inc_t lda, \ - ctype* restrict p, inc_t is_p, inc_t ldp, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const inc_t inca2 = 2 * inca; \ - const inc_t lda2 = 2 * lda; \ -\ - ctype* kappa_cast = kappa; \ - ctype_r* restrict kappa_r = ( ctype_r* )kappa; \ - ctype_r* restrict kappa_i = ( ctype_r* )kappa + 1; \ - ctype_r* restrict alpha1_r = ( ctype_r* )a; \ - ctype_r* restrict alpha1_i = ( ctype_r* )a + 1; \ - ctype_r* restrict pi1_r = ( ctype_r* )p; \ - ctype_r* restrict pi1_i = ( ctype_r* )p + is_p; \ -\ - if ( cdim == mnr ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r +10*inca2), *(alpha1_i +10*inca2), *(pi1_r +10), *(pi1_i +10) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r +11*inca2), *(alpha1_i +11*inca2), *(pi1_r +11), *(pi1_i +11) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r +12*inca2), *(alpha1_i +12*inca2), *(pi1_r +12), *(pi1_i +12) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r +13*inca2), *(alpha1_i +13*inca2), *(pi1_r +13), *(pi1_i +13) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r +14*inca2), *(alpha1_i +14*inca2), *(pi1_r +14), *(pi1_i +14) ); \ - PASTEMAC(ch,copyjris)( *(alpha1_r +15*inca2), *(alpha1_i +15*inca2), *(pi1_r +15), *(pi1_i +15) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,copyris)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r +10*inca2), *(alpha1_i +10*inca2), *(pi1_r +10), *(pi1_i +10) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r +11*inca2), *(alpha1_i +11*inca2), *(pi1_r +11), *(pi1_i +11) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r +12*inca2), *(alpha1_i +12*inca2), *(pi1_r +12), *(pi1_i +12) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r +13*inca2), *(alpha1_i +13*inca2), *(pi1_r +13), *(pi1_i +13) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r +14*inca2), *(alpha1_i +14*inca2), *(pi1_r +14), *(pi1_i +14) ); \ - PASTEMAC(ch,copyris)( *(alpha1_r +15*inca2), *(alpha1_i +15*inca2), *(pi1_r +15), *(pi1_i +15) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r +10*inca2), *(alpha1_i +10*inca2), *(pi1_r +10), *(pi1_i +10) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r +11*inca2), *(alpha1_i +11*inca2), *(pi1_r +11), *(pi1_i +11) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r +12*inca2), *(alpha1_i +12*inca2), *(pi1_r +12), *(pi1_i +12) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r +13*inca2), *(alpha1_i +13*inca2), *(pi1_r +13), *(pi1_i +13) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r +14*inca2), *(alpha1_i +14*inca2), *(pi1_r +14), *(pi1_i +14) ); \ - PASTEMAC(ch,scal2jris)( *kappa_r, *kappa_i, *(alpha1_r +15*inca2), *(alpha1_i +15*inca2), *(pi1_r +15), *(pi1_i +15) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0), *(pi1_i + 0) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1), *(pi1_i + 1) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2), *(pi1_i + 2) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3), *(pi1_i + 3) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4), *(pi1_i + 4) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5), *(pi1_i + 5) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6), *(pi1_i + 6) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7), *(pi1_i + 7) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8), *(pi1_i + 8) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9), *(pi1_i + 9) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r +10*inca2), *(alpha1_i +10*inca2), *(pi1_r +10), *(pi1_i +10) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r +11*inca2), *(alpha1_i +11*inca2), *(pi1_r +11), *(pi1_i +11) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r +12*inca2), *(alpha1_i +12*inca2), *(pi1_r +12), *(pi1_i +12) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r +13*inca2), *(alpha1_i +13*inca2), *(pi1_r +13), *(pi1_i +13) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r +14*inca2), *(alpha1_i +14*inca2), *(pi1_r +14), *(pi1_i +14) ); \ - PASTEMAC(ch,scal2ris)( *kappa_r, *kappa_i, *(alpha1_r +15*inca2), *(alpha1_i +15*inca2), *(pi1_r +15), *(pi1_i +15) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - pi1_i += ldp; \ - } \ - } \ - } \ - } \ - else /* if ( cdim < mnr ) */ \ - { \ - PASTEMAC(ch,scal2ris_mxn) \ - ( \ - conja, \ - cdim, \ - n, \ - kappa, \ - a, inca, lda, \ - p, 1, ldp, is_p \ - ); \ -\ - /* if ( cdim < mnr ) */ \ - { \ - const dim_t i = cdim; \ - const dim_t m_edge = mnr - i; \ - const dim_t n_edge = n_max; \ - ctype_r* restrict p_edge_r = ( ctype_r* )p + (i )*1; \ - ctype_r* restrict p_edge_i = ( ctype_r* )p + is_p + (i )*1; \ -\ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_r, 1, ldp \ - ); \ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_i, 1, ldp \ - ); \ - } \ - } \ -\ - if ( n < n_max ) \ - { \ - const dim_t j = n; \ - const dim_t m_edge = mnr; \ - const dim_t n_edge = n_max - j; \ - ctype_r* restrict p_edge_r = ( ctype_r* )p + (j )*ldp; \ - ctype_r* restrict p_edge_i = ( ctype_r* )p + is_p + (j )*ldp; \ -\ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_r, 1, ldp \ - ); \ - PASTEMAC(chr,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge_i, 1, ldp \ - ); \ - } \ -} - -INSERT_GENTFUNCCO_BASIC3( packm_16xk_4mi, 16, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) - diff --git a/ref_kernels/1m/bli_packm_cxk_rih_ref.c b/ref_kernels/1m/bli_packm_cxk_rih_ref.c deleted file mode 100644 index 9cc32e9a24..0000000000 --- a/ref_kernels/1m/bli_packm_cxk_rih_ref.c +++ /dev/null @@ -1,2498 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, mnr, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - conj_t conja, \ - pack_t schema, \ - dim_t cdim, \ - dim_t n, \ - dim_t n_max, \ - ctype* restrict kappa, \ - ctype* restrict a, inc_t inca, inc_t lda, \ - ctype* restrict p, inc_t ldp, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const inc_t inca2 = 2 * inca; \ - const inc_t lda2 = 2 * lda; \ -\ - ctype* kappa_cast = kappa; \ - ctype* restrict alpha1 = a; \ - ctype_r* restrict alpha1_r = ( ctype_r* )a; \ - ctype_r* restrict alpha1_i = ( ctype_r* )a + 1; \ - ctype_r* restrict pi1_r = ( ctype_r* )p; \ -\ -\ - if ( cdim == mnr ) \ - { \ - if ( bli_is_ro_packed( schema ) ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - /* This works regardless of conja since we are only copying - the real part. */ \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,copys)( *(alpha1_r + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 1*inca2), *(pi1_r + 1) ); \ - \ - alpha1_r += lda2; \ - pi1_r += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - } \ - } \ - else if ( bli_is_io_packed( schema ) ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,copys)( -*(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,copys)( *(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - } \ - } \ - else /* if ( bli_is_rpi_packed( schema ) ) */ \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,add3s)( *(alpha1_r + 0*inca2), -*(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 1*inca2), -*(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,add3s)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - } \ - } \ - } \ - else /* if ( cdim < mnr ) */ \ - { \ - PASTEMAC(ch,scal2rihs_mxn) \ - ( \ - schema, \ - conja, \ - cdim, \ - n, \ - kappa, \ - a, inca, lda, \ - p, 1, ldp \ - ); \ -\ - /* if ( cdim < mnr ) */ \ - { \ - const dim_t i = cdim; \ - const dim_t m_edge = mnr - cdim; \ - const dim_t n_edge = n_max; \ - ctype* restrict p_cast = p; \ - ctype* restrict p_edge = p_cast + (i )*1; \ -\ - PASTEMAC(ch,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge, 1, ldp \ - ); \ - } \ - } \ -\ - if ( n < n_max ) \ - { \ - const dim_t j = n; \ - const dim_t m_edge = mnr; \ - const dim_t n_edge = n_max - n; \ - ctype* restrict p_cast = p; \ - ctype* restrict p_edge = p_cast + (j )*ldp; \ -\ - PASTEMAC(ch,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge, 1, ldp \ - ); \ - } \ -} - -INSERT_GENTFUNCCO_BASIC3( packm_2xk_rih, 2, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) - - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, mnr, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - conj_t conja, \ - pack_t schema, \ - dim_t cdim, \ - dim_t n, \ - dim_t n_max, \ - ctype* restrict kappa, \ - ctype* restrict a, inc_t inca, inc_t lda, \ - ctype* restrict p, inc_t ldp, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const inc_t inca2 = 2 * inca; \ - const inc_t lda2 = 2 * lda; \ -\ - ctype* kappa_cast = kappa; \ - ctype* restrict alpha1 = a; \ - ctype_r* restrict alpha1_r = ( ctype_r* )a; \ - ctype_r* restrict alpha1_i = ( ctype_r* )a + 1; \ - ctype_r* restrict pi1_r = ( ctype_r* )p; \ -\ -\ - if ( cdim == mnr ) \ - { \ - if ( bli_is_ro_packed( schema ) ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - /* This works regardless of conja since we are only copying - the real part. */ \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,copys)( *(alpha1_r + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 3*inca2), *(pi1_r + 3) ); \ - \ - alpha1_r += lda2; \ - pi1_r += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - } \ - } \ - else if ( bli_is_io_packed( schema ) ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,copys)( -*(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 3*inca2), *(pi1_r + 3) ); \ - \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,copys)( *(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 3*inca2), *(pi1_r + 3) ); \ - \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - } \ - } \ - else /* if ( bli_is_rpi_packed( schema ) ) */ \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,add3s)( *(alpha1_r + 0*inca2), -*(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 1*inca2), -*(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 2*inca2), -*(alpha1_i + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 3*inca2), -*(alpha1_i + 3*inca2), *(pi1_r + 3) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,add3s)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - } \ - } \ - } \ - else /* if ( cdim < mnr ) */ \ - { \ - PASTEMAC(ch,scal2rihs_mxn) \ - ( \ - schema, \ - conja, \ - cdim, \ - n, \ - kappa, \ - a, inca, lda, \ - p, 1, ldp \ - ); \ -\ - /* if ( cdim < mnr ) */ \ - { \ - const dim_t i = cdim; \ - const dim_t m_edge = mnr - cdim; \ - const dim_t n_edge = n_max; \ - ctype* restrict p_cast = p; \ - ctype* restrict p_edge = p_cast + (i )*1; \ -\ - PASTEMAC(ch,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge, 1, ldp \ - ); \ - } \ - } \ -\ - if ( n < n_max ) \ - { \ - const dim_t j = n; \ - const dim_t m_edge = mnr; \ - const dim_t n_edge = n_max - n; \ - ctype* restrict p_cast = p; \ - ctype* restrict p_edge = p_cast + (j )*ldp; \ -\ - PASTEMAC(ch,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge, 1, ldp \ - ); \ - } \ -} - -INSERT_GENTFUNCCO_BASIC3( packm_4xk_rih, 4, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) - - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, mnr, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - conj_t conja, \ - pack_t schema, \ - dim_t cdim, \ - dim_t n, \ - dim_t n_max, \ - ctype* restrict kappa, \ - ctype* restrict a, inc_t inca, inc_t lda, \ - ctype* restrict p, inc_t ldp, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const inc_t inca2 = 2 * inca; \ - const inc_t lda2 = 2 * lda; \ -\ - ctype* kappa_cast = kappa; \ - ctype* restrict alpha1 = a; \ - ctype_r* restrict alpha1_r = ( ctype_r* )a; \ - ctype_r* restrict alpha1_i = ( ctype_r* )a + 1; \ - ctype_r* restrict pi1_r = ( ctype_r* )p; \ -\ -\ - if ( cdim == mnr ) \ - { \ - if ( bli_is_ro_packed( schema ) ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - /* This works regardless of conja since we are only copying - the real part. */ \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,copys)( *(alpha1_r + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 3*inca2), *(pi1_r + 3) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 4*inca2), *(pi1_r + 4) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 5*inca2), *(pi1_r + 5) ); \ - \ - alpha1_r += lda2; \ - pi1_r += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - } \ - } \ - else if ( bli_is_io_packed( schema ) ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,copys)( -*(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 3*inca2), *(pi1_r + 3) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 4*inca2), *(pi1_r + 4) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 5*inca2), *(pi1_r + 5) ); \ - \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,copys)( *(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 3*inca2), *(pi1_r + 3) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 4*inca2), *(pi1_r + 4) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 5*inca2), *(pi1_r + 5) ); \ - \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - } \ - } \ - else /* if ( bli_is_rpi_packed( schema ) ) */ \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,add3s)( *(alpha1_r + 0*inca2), -*(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 1*inca2), -*(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 2*inca2), -*(alpha1_i + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 3*inca2), -*(alpha1_i + 3*inca2), *(pi1_r + 3) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 4*inca2), -*(alpha1_i + 4*inca2), *(pi1_r + 4) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 5*inca2), -*(alpha1_i + 5*inca2), *(pi1_r + 5) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,add3s)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - } \ - } \ - } \ - else /* if ( cdim < mnr ) */ \ - { \ - PASTEMAC(ch,scal2rihs_mxn) \ - ( \ - schema, \ - conja, \ - cdim, \ - n, \ - kappa, \ - a, inca, lda, \ - p, 1, ldp \ - ); \ -\ - /* if ( cdim < mnr ) */ \ - { \ - const dim_t i = cdim; \ - const dim_t m_edge = mnr - cdim; \ - const dim_t n_edge = n_max; \ - ctype* restrict p_cast = p; \ - ctype* restrict p_edge = p_cast + (i )*1; \ -\ - PASTEMAC(ch,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge, 1, ldp \ - ); \ - } \ - } \ -\ - if ( n < n_max ) \ - { \ - const dim_t j = n; \ - const dim_t m_edge = mnr; \ - const dim_t n_edge = n_max - n; \ - ctype* restrict p_cast = p; \ - ctype* restrict p_edge = p_cast + (j )*ldp; \ -\ - PASTEMAC(ch,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge, 1, ldp \ - ); \ - } \ -} - -INSERT_GENTFUNCCO_BASIC3( packm_6xk_rih, 6, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) - - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, mnr, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - conj_t conja, \ - pack_t schema, \ - dim_t cdim, \ - dim_t n, \ - dim_t n_max, \ - ctype* restrict kappa, \ - ctype* restrict a, inc_t inca, inc_t lda, \ - ctype* restrict p, inc_t ldp, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const inc_t inca2 = 2 * inca; \ - const inc_t lda2 = 2 * lda; \ -\ - ctype* kappa_cast = kappa; \ - ctype* restrict alpha1 = a; \ - ctype_r* restrict alpha1_r = ( ctype_r* )a; \ - ctype_r* restrict alpha1_i = ( ctype_r* )a + 1; \ - ctype_r* restrict pi1_r = ( ctype_r* )p; \ -\ -\ - if ( cdim == mnr ) \ - { \ - if ( bli_is_ro_packed( schema ) ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - /* This works regardless of conja since we are only copying - the real part. */ \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,copys)( *(alpha1_r + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 3*inca2), *(pi1_r + 3) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 4*inca2), *(pi1_r + 4) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 5*inca2), *(pi1_r + 5) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 6*inca2), *(pi1_r + 6) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 7*inca2), *(pi1_r + 7) ); \ - \ - alpha1_r += lda2; \ - pi1_r += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 6*inca), *(pi1_r + 6) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 7*inca), *(pi1_r + 7) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 6*inca), *(pi1_r + 6) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 7*inca), *(pi1_r + 7) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - } \ - } \ - else if ( bli_is_io_packed( schema ) ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,copys)( -*(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 3*inca2), *(pi1_r + 3) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 4*inca2), *(pi1_r + 4) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 5*inca2), *(pi1_r + 5) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 6*inca2), *(pi1_r + 6) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 7*inca2), *(pi1_r + 7) ); \ - \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,copys)( *(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 3*inca2), *(pi1_r + 3) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 4*inca2), *(pi1_r + 4) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 5*inca2), *(pi1_r + 5) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 6*inca2), *(pi1_r + 6) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 7*inca2), *(pi1_r + 7) ); \ - \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 6*inca), *(pi1_r + 6) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 7*inca), *(pi1_r + 7) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 6*inca), *(pi1_r + 6) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 7*inca), *(pi1_r + 7) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - } \ - } \ - else /* if ( bli_is_rpi_packed( schema ) ) */ \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,add3s)( *(alpha1_r + 0*inca2), -*(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 1*inca2), -*(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 2*inca2), -*(alpha1_i + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 3*inca2), -*(alpha1_i + 3*inca2), *(pi1_r + 3) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 4*inca2), -*(alpha1_i + 4*inca2), *(pi1_r + 4) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 5*inca2), -*(alpha1_i + 5*inca2), *(pi1_r + 5) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 6*inca2), -*(alpha1_i + 6*inca2), *(pi1_r + 6) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 7*inca2), -*(alpha1_i + 7*inca2), *(pi1_r + 7) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,add3s)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 6*inca), *(pi1_r + 6) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 7*inca), *(pi1_r + 7) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 6*inca), *(pi1_r + 6) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 7*inca), *(pi1_r + 7) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - } \ - } \ - } \ - else /* if ( cdim < mnr ) */ \ - { \ - PASTEMAC(ch,scal2rihs_mxn) \ - ( \ - schema, \ - conja, \ - cdim, \ - n, \ - kappa, \ - a, inca, lda, \ - p, 1, ldp \ - ); \ -\ - /* if ( cdim < mnr ) */ \ - { \ - const dim_t i = cdim; \ - const dim_t m_edge = mnr - cdim; \ - const dim_t n_edge = n_max; \ - ctype* restrict p_cast = p; \ - ctype* restrict p_edge = p_cast + (i )*1; \ -\ - PASTEMAC(ch,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge, 1, ldp \ - ); \ - } \ - } \ -\ - if ( n < n_max ) \ - { \ - const dim_t j = n; \ - const dim_t m_edge = mnr; \ - const dim_t n_edge = n_max - n; \ - ctype* restrict p_cast = p; \ - ctype* restrict p_edge = p_cast + (j )*ldp; \ -\ - PASTEMAC(ch,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge, 1, ldp \ - ); \ - } \ -} - -INSERT_GENTFUNCCO_BASIC3( packm_8xk_rih, 8, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) - - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, mnr, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - conj_t conja, \ - pack_t schema, \ - dim_t cdim, \ - dim_t n, \ - dim_t n_max, \ - ctype* restrict kappa, \ - ctype* restrict a, inc_t inca, inc_t lda, \ - ctype* restrict p, inc_t ldp, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const inc_t inca2 = 2 * inca; \ - const inc_t lda2 = 2 * lda; \ -\ - ctype* kappa_cast = kappa; \ - ctype* restrict alpha1 = a; \ - ctype_r* restrict alpha1_r = ( ctype_r* )a; \ - ctype_r* restrict alpha1_i = ( ctype_r* )a + 1; \ - ctype_r* restrict pi1_r = ( ctype_r* )p; \ -\ -\ - if ( cdim == mnr ) \ - { \ - if ( bli_is_ro_packed( schema ) ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - /* This works regardless of conja since we are only copying - the real part. */ \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,copys)( *(alpha1_r + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 3*inca2), *(pi1_r + 3) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 4*inca2), *(pi1_r + 4) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 5*inca2), *(pi1_r + 5) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 6*inca2), *(pi1_r + 6) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 7*inca2), *(pi1_r + 7) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 8*inca2), *(pi1_r + 8) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 9*inca2), *(pi1_r + 9) ); \ - \ - alpha1_r += lda2; \ - pi1_r += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 6*inca), *(pi1_r + 6) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 7*inca), *(pi1_r + 7) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 8*inca), *(pi1_r + 8) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 9*inca), *(pi1_r + 9) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 6*inca), *(pi1_r + 6) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 7*inca), *(pi1_r + 7) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 8*inca), *(pi1_r + 8) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 9*inca), *(pi1_r + 9) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - } \ - } \ - else if ( bli_is_io_packed( schema ) ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,copys)( -*(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 3*inca2), *(pi1_r + 3) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 4*inca2), *(pi1_r + 4) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 5*inca2), *(pi1_r + 5) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 6*inca2), *(pi1_r + 6) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 7*inca2), *(pi1_r + 7) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 8*inca2), *(pi1_r + 8) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 9*inca2), *(pi1_r + 9) ); \ - \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,copys)( *(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 3*inca2), *(pi1_r + 3) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 4*inca2), *(pi1_r + 4) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 5*inca2), *(pi1_r + 5) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 6*inca2), *(pi1_r + 6) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 7*inca2), *(pi1_r + 7) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 8*inca2), *(pi1_r + 8) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 9*inca2), *(pi1_r + 9) ); \ - \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 6*inca), *(pi1_r + 6) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 7*inca), *(pi1_r + 7) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 8*inca), *(pi1_r + 8) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 9*inca), *(pi1_r + 9) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 6*inca), *(pi1_r + 6) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 7*inca), *(pi1_r + 7) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 8*inca), *(pi1_r + 8) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 9*inca), *(pi1_r + 9) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - } \ - } \ - else /* if ( bli_is_rpi_packed( schema ) ) */ \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,add3s)( *(alpha1_r + 0*inca2), -*(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 1*inca2), -*(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 2*inca2), -*(alpha1_i + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 3*inca2), -*(alpha1_i + 3*inca2), *(pi1_r + 3) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 4*inca2), -*(alpha1_i + 4*inca2), *(pi1_r + 4) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 5*inca2), -*(alpha1_i + 5*inca2), *(pi1_r + 5) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 6*inca2), -*(alpha1_i + 6*inca2), *(pi1_r + 6) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 7*inca2), -*(alpha1_i + 7*inca2), *(pi1_r + 7) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 8*inca2), -*(alpha1_i + 8*inca2), *(pi1_r + 8) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 9*inca2), -*(alpha1_i + 9*inca2), *(pi1_r + 9) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,add3s)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 6*inca), *(pi1_r + 6) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 7*inca), *(pi1_r + 7) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 8*inca), *(pi1_r + 8) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 9*inca), *(pi1_r + 9) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 6*inca), *(pi1_r + 6) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 7*inca), *(pi1_r + 7) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 8*inca), *(pi1_r + 8) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 9*inca), *(pi1_r + 9) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - } \ - } \ - } \ - else /* if ( cdim < mnr ) */ \ - { \ - PASTEMAC(ch,scal2rihs_mxn) \ - ( \ - schema, \ - conja, \ - cdim, \ - n, \ - kappa, \ - a, inca, lda, \ - p, 1, ldp \ - ); \ -\ - /* if ( cdim < mnr ) */ \ - { \ - const dim_t i = cdim; \ - const dim_t m_edge = mnr - cdim; \ - const dim_t n_edge = n_max; \ - ctype* restrict p_cast = p; \ - ctype* restrict p_edge = p_cast + (i )*1; \ -\ - PASTEMAC(ch,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge, 1, ldp \ - ); \ - } \ - } \ -\ - if ( n < n_max ) \ - { \ - const dim_t j = n; \ - const dim_t m_edge = mnr; \ - const dim_t n_edge = n_max - n; \ - ctype* restrict p_cast = p; \ - ctype* restrict p_edge = p_cast + (j )*ldp; \ -\ - PASTEMAC(ch,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge, 1, ldp \ - ); \ - } \ -} - -INSERT_GENTFUNCCO_BASIC3( packm_10xk_rih, 10, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) - - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, mnr, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - conj_t conja, \ - pack_t schema, \ - dim_t cdim, \ - dim_t n, \ - dim_t n_max, \ - ctype* restrict kappa, \ - ctype* restrict a, inc_t inca, inc_t lda, \ - ctype* restrict p, inc_t ldp, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const inc_t inca2 = 2 * inca; \ - const inc_t lda2 = 2 * lda; \ -\ - ctype* kappa_cast = kappa; \ - ctype* restrict alpha1 = a; \ - ctype_r* restrict alpha1_r = ( ctype_r* )a; \ - ctype_r* restrict alpha1_i = ( ctype_r* )a + 1; \ - ctype_r* restrict pi1_r = ( ctype_r* )p; \ -\ -\ - if ( cdim == mnr ) \ - { \ - if ( bli_is_ro_packed( schema ) ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - /* This works regardless of conja since we are only copying - the real part. */ \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,copys)( *(alpha1_r + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 3*inca2), *(pi1_r + 3) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 4*inca2), *(pi1_r + 4) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 5*inca2), *(pi1_r + 5) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 6*inca2), *(pi1_r + 6) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 7*inca2), *(pi1_r + 7) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 8*inca2), *(pi1_r + 8) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 9*inca2), *(pi1_r + 9) ); \ - PASTEMAC(chr,copys)( *(alpha1_r +10*inca2), *(pi1_r +10) ); \ - PASTEMAC(chr,copys)( *(alpha1_r +11*inca2), *(pi1_r +11) ); \ - \ - alpha1_r += lda2; \ - pi1_r += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 6*inca), *(pi1_r + 6) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 7*inca), *(pi1_r + 7) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 8*inca), *(pi1_r + 8) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 9*inca), *(pi1_r + 9) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 +10*inca), *(pi1_r +10) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 +11*inca), *(pi1_r +11) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 6*inca), *(pi1_r + 6) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 7*inca), *(pi1_r + 7) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 8*inca), *(pi1_r + 8) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 9*inca), *(pi1_r + 9) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 +10*inca), *(pi1_r +10) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 +11*inca), *(pi1_r +11) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - } \ - } \ - else if ( bli_is_io_packed( schema ) ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,copys)( -*(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 3*inca2), *(pi1_r + 3) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 4*inca2), *(pi1_r + 4) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 5*inca2), *(pi1_r + 5) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 6*inca2), *(pi1_r + 6) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 7*inca2), *(pi1_r + 7) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 8*inca2), *(pi1_r + 8) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 9*inca2), *(pi1_r + 9) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i +10*inca2), *(pi1_r +10) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i +11*inca2), *(pi1_r +11) ); \ - \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,copys)( *(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 3*inca2), *(pi1_r + 3) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 4*inca2), *(pi1_r + 4) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 5*inca2), *(pi1_r + 5) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 6*inca2), *(pi1_r + 6) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 7*inca2), *(pi1_r + 7) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 8*inca2), *(pi1_r + 8) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 9*inca2), *(pi1_r + 9) ); \ - PASTEMAC(chr,copys)( *(alpha1_i +10*inca2), *(pi1_r +10) ); \ - PASTEMAC(chr,copys)( *(alpha1_i +11*inca2), *(pi1_r +11) ); \ - \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 6*inca), *(pi1_r + 6) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 7*inca), *(pi1_r + 7) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 8*inca), *(pi1_r + 8) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 9*inca), *(pi1_r + 9) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 +10*inca), *(pi1_r +10) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 +11*inca), *(pi1_r +11) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 6*inca), *(pi1_r + 6) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 7*inca), *(pi1_r + 7) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 8*inca), *(pi1_r + 8) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 9*inca), *(pi1_r + 9) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 +10*inca), *(pi1_r +10) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 +11*inca), *(pi1_r +11) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - } \ - } \ - else /* if ( bli_is_rpi_packed( schema ) ) */ \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,add3s)( *(alpha1_r + 0*inca2), -*(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 1*inca2), -*(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 2*inca2), -*(alpha1_i + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 3*inca2), -*(alpha1_i + 3*inca2), *(pi1_r + 3) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 4*inca2), -*(alpha1_i + 4*inca2), *(pi1_r + 4) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 5*inca2), -*(alpha1_i + 5*inca2), *(pi1_r + 5) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 6*inca2), -*(alpha1_i + 6*inca2), *(pi1_r + 6) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 7*inca2), -*(alpha1_i + 7*inca2), *(pi1_r + 7) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 8*inca2), -*(alpha1_i + 8*inca2), *(pi1_r + 8) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 9*inca2), -*(alpha1_i + 9*inca2), *(pi1_r + 9) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r +10*inca2), -*(alpha1_i +10*inca2), *(pi1_r +10) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r +11*inca2), -*(alpha1_i +11*inca2), *(pi1_r +11) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,add3s)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r +10*inca2), *(alpha1_i +10*inca2), *(pi1_r +10) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r +11*inca2), *(alpha1_i +11*inca2), *(pi1_r +11) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 6*inca), *(pi1_r + 6) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 7*inca), *(pi1_r + 7) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 8*inca), *(pi1_r + 8) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 9*inca), *(pi1_r + 9) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 +10*inca), *(pi1_r +10) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 +11*inca), *(pi1_r +11) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 6*inca), *(pi1_r + 6) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 7*inca), *(pi1_r + 7) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 8*inca), *(pi1_r + 8) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 9*inca), *(pi1_r + 9) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 +10*inca), *(pi1_r +10) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 +11*inca), *(pi1_r +11) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - } \ - } \ - } \ - else /* if ( cdim < mnr ) */ \ - { \ - PASTEMAC(ch,scal2rihs_mxn) \ - ( \ - schema, \ - conja, \ - cdim, \ - n, \ - kappa, \ - a, inca, lda, \ - p, 1, ldp \ - ); \ -\ - /* if ( cdim < mnr ) */ \ - { \ - const dim_t i = cdim; \ - const dim_t m_edge = mnr - cdim; \ - const dim_t n_edge = n_max; \ - ctype* restrict p_cast = p; \ - ctype* restrict p_edge = p_cast + (i )*1; \ -\ - PASTEMAC(ch,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge, 1, ldp \ - ); \ - } \ - } \ -\ - if ( n < n_max ) \ - { \ - const dim_t j = n; \ - const dim_t m_edge = mnr; \ - const dim_t n_edge = n_max - n; \ - ctype* restrict p_cast = p; \ - ctype* restrict p_edge = p_cast + (j )*ldp; \ -\ - PASTEMAC(ch,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge, 1, ldp \ - ); \ - } \ -} - -INSERT_GENTFUNCCO_BASIC3( packm_12xk_rih, 12, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) - - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, mnr, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - conj_t conja, \ - pack_t schema, \ - dim_t cdim, \ - dim_t n, \ - dim_t n_max, \ - ctype* restrict kappa, \ - ctype* restrict a, inc_t inca, inc_t lda, \ - ctype* restrict p, inc_t ldp, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const inc_t inca2 = 2 * inca; \ - const inc_t lda2 = 2 * lda; \ -\ - ctype* kappa_cast = kappa; \ - ctype* restrict alpha1 = a; \ - ctype_r* restrict alpha1_r = ( ctype_r* )a; \ - ctype_r* restrict alpha1_i = ( ctype_r* )a + 1; \ - ctype_r* restrict pi1_r = ( ctype_r* )p; \ -\ -\ - if ( cdim == mnr ) \ - { \ - if ( bli_is_ro_packed( schema ) ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - /* This works regardless of conja since we are only copying - the real part. */ \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,copys)( *(alpha1_r + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 3*inca2), *(pi1_r + 3) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 4*inca2), *(pi1_r + 4) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 5*inca2), *(pi1_r + 5) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 6*inca2), *(pi1_r + 6) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 7*inca2), *(pi1_r + 7) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 8*inca2), *(pi1_r + 8) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 9*inca2), *(pi1_r + 9) ); \ - PASTEMAC(chr,copys)( *(alpha1_r +10*inca2), *(pi1_r +10) ); \ - PASTEMAC(chr,copys)( *(alpha1_r +11*inca2), *(pi1_r +11) ); \ - PASTEMAC(chr,copys)( *(alpha1_r +12*inca2), *(pi1_r +12) ); \ - PASTEMAC(chr,copys)( *(alpha1_r +13*inca2), *(pi1_r +13) ); \ - \ - alpha1_r += lda2; \ - pi1_r += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 6*inca), *(pi1_r + 6) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 7*inca), *(pi1_r + 7) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 8*inca), *(pi1_r + 8) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 9*inca), *(pi1_r + 9) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 +10*inca), *(pi1_r +10) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 +11*inca), *(pi1_r +11) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 +12*inca), *(pi1_r +12) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 +13*inca), *(pi1_r +13) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 6*inca), *(pi1_r + 6) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 7*inca), *(pi1_r + 7) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 8*inca), *(pi1_r + 8) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 9*inca), *(pi1_r + 9) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 +10*inca), *(pi1_r +10) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 +11*inca), *(pi1_r +11) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 +12*inca), *(pi1_r +12) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 +13*inca), *(pi1_r +13) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - } \ - } \ - else if ( bli_is_io_packed( schema ) ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,copys)( -*(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 3*inca2), *(pi1_r + 3) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 4*inca2), *(pi1_r + 4) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 5*inca2), *(pi1_r + 5) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 6*inca2), *(pi1_r + 6) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 7*inca2), *(pi1_r + 7) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 8*inca2), *(pi1_r + 8) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 9*inca2), *(pi1_r + 9) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i +10*inca2), *(pi1_r +10) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i +11*inca2), *(pi1_r +11) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i +12*inca2), *(pi1_r +12) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i +13*inca2), *(pi1_r +13) ); \ - \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,copys)( *(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 3*inca2), *(pi1_r + 3) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 4*inca2), *(pi1_r + 4) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 5*inca2), *(pi1_r + 5) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 6*inca2), *(pi1_r + 6) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 7*inca2), *(pi1_r + 7) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 8*inca2), *(pi1_r + 8) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 9*inca2), *(pi1_r + 9) ); \ - PASTEMAC(chr,copys)( *(alpha1_i +10*inca2), *(pi1_r +10) ); \ - PASTEMAC(chr,copys)( *(alpha1_i +11*inca2), *(pi1_r +11) ); \ - PASTEMAC(chr,copys)( *(alpha1_i +12*inca2), *(pi1_r +12) ); \ - PASTEMAC(chr,copys)( *(alpha1_i +13*inca2), *(pi1_r +13) ); \ - \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 6*inca), *(pi1_r + 6) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 7*inca), *(pi1_r + 7) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 8*inca), *(pi1_r + 8) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 9*inca), *(pi1_r + 9) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 +10*inca), *(pi1_r +10) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 +11*inca), *(pi1_r +11) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 +12*inca), *(pi1_r +12) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 +13*inca), *(pi1_r +13) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 6*inca), *(pi1_r + 6) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 7*inca), *(pi1_r + 7) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 8*inca), *(pi1_r + 8) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 9*inca), *(pi1_r + 9) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 +10*inca), *(pi1_r +10) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 +11*inca), *(pi1_r +11) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 +12*inca), *(pi1_r +12) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 +13*inca), *(pi1_r +13) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - } \ - } \ - else /* if ( bli_is_rpi_packed( schema ) ) */ \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,add3s)( *(alpha1_r + 0*inca2), -*(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 1*inca2), -*(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 2*inca2), -*(alpha1_i + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 3*inca2), -*(alpha1_i + 3*inca2), *(pi1_r + 3) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 4*inca2), -*(alpha1_i + 4*inca2), *(pi1_r + 4) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 5*inca2), -*(alpha1_i + 5*inca2), *(pi1_r + 5) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 6*inca2), -*(alpha1_i + 6*inca2), *(pi1_r + 6) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 7*inca2), -*(alpha1_i + 7*inca2), *(pi1_r + 7) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 8*inca2), -*(alpha1_i + 8*inca2), *(pi1_r + 8) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 9*inca2), -*(alpha1_i + 9*inca2), *(pi1_r + 9) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r +10*inca2), -*(alpha1_i +10*inca2), *(pi1_r +10) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r +11*inca2), -*(alpha1_i +11*inca2), *(pi1_r +11) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r +12*inca2), -*(alpha1_i +12*inca2), *(pi1_r +12) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r +13*inca2), -*(alpha1_i +13*inca2), *(pi1_r +13) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,add3s)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r +10*inca2), *(alpha1_i +10*inca2), *(pi1_r +10) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r +11*inca2), *(alpha1_i +11*inca2), *(pi1_r +11) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r +12*inca2), *(alpha1_i +12*inca2), *(pi1_r +12) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r +13*inca2), *(alpha1_i +13*inca2), *(pi1_r +13) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 6*inca), *(pi1_r + 6) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 7*inca), *(pi1_r + 7) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 8*inca), *(pi1_r + 8) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 9*inca), *(pi1_r + 9) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 +10*inca), *(pi1_r +10) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 +11*inca), *(pi1_r +11) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 +12*inca), *(pi1_r +12) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 +13*inca), *(pi1_r +13) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 6*inca), *(pi1_r + 6) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 7*inca), *(pi1_r + 7) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 8*inca), *(pi1_r + 8) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 9*inca), *(pi1_r + 9) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 +10*inca), *(pi1_r +10) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 +11*inca), *(pi1_r +11) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 +12*inca), *(pi1_r +12) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 +13*inca), *(pi1_r +13) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - } \ - } \ - } \ - else /* if ( cdim < mnr ) */ \ - { \ - PASTEMAC(ch,scal2rihs_mxn) \ - ( \ - schema, \ - conja, \ - cdim, \ - n, \ - kappa, \ - a, inca, lda, \ - p, 1, ldp \ - ); \ -\ - /* if ( cdim < mnr ) */ \ - { \ - const dim_t i = cdim; \ - const dim_t m_edge = mnr - cdim; \ - const dim_t n_edge = n_max; \ - ctype* restrict p_cast = p; \ - ctype* restrict p_edge = p_cast + (i )*1; \ -\ - PASTEMAC(ch,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge, 1, ldp \ - ); \ - } \ - } \ -\ - if ( n < n_max ) \ - { \ - const dim_t j = n; \ - const dim_t m_edge = mnr; \ - const dim_t n_edge = n_max - n; \ - ctype* restrict p_cast = p; \ - ctype* restrict p_edge = p_cast + (j )*ldp; \ -\ - PASTEMAC(ch,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge, 1, ldp \ - ); \ - } \ -} - -INSERT_GENTFUNCCO_BASIC3( packm_14xk_rih, 14, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) - - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, mnr, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - conj_t conja, \ - pack_t schema, \ - dim_t cdim, \ - dim_t n, \ - dim_t n_max, \ - ctype* restrict kappa, \ - ctype* restrict a, inc_t inca, inc_t lda, \ - ctype* restrict p, inc_t ldp, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const inc_t inca2 = 2 * inca; \ - const inc_t lda2 = 2 * lda; \ -\ - ctype* kappa_cast = kappa; \ - ctype* restrict alpha1 = a; \ - ctype_r* restrict alpha1_r = ( ctype_r* )a; \ - ctype_r* restrict alpha1_i = ( ctype_r* )a + 1; \ - ctype_r* restrict pi1_r = ( ctype_r* )p; \ -\ -\ - if ( cdim == mnr ) \ - { \ - if ( bli_is_ro_packed( schema ) ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - /* This works regardless of conja since we are only copying - the real part. */ \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,copys)( *(alpha1_r + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 3*inca2), *(pi1_r + 3) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 4*inca2), *(pi1_r + 4) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 5*inca2), *(pi1_r + 5) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 6*inca2), *(pi1_r + 6) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 7*inca2), *(pi1_r + 7) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 8*inca2), *(pi1_r + 8) ); \ - PASTEMAC(chr,copys)( *(alpha1_r + 9*inca2), *(pi1_r + 9) ); \ - PASTEMAC(chr,copys)( *(alpha1_r +10*inca2), *(pi1_r +10) ); \ - PASTEMAC(chr,copys)( *(alpha1_r +11*inca2), *(pi1_r +11) ); \ - PASTEMAC(chr,copys)( *(alpha1_r +12*inca2), *(pi1_r +12) ); \ - PASTEMAC(chr,copys)( *(alpha1_r +13*inca2), *(pi1_r +13) ); \ - PASTEMAC(chr,copys)( *(alpha1_r +14*inca2), *(pi1_r +14) ); \ - PASTEMAC(chr,copys)( *(alpha1_r +15*inca2), *(pi1_r +15) ); \ - \ - alpha1_r += lda2; \ - pi1_r += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 6*inca), *(pi1_r + 6) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 7*inca), *(pi1_r + 7) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 8*inca), *(pi1_r + 8) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 + 9*inca), *(pi1_r + 9) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 +10*inca), *(pi1_r +10) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 +11*inca), *(pi1_r +11) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 +12*inca), *(pi1_r +12) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 +13*inca), *(pi1_r +13) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 +14*inca), *(pi1_r +14) ); \ - PASTEMAC(ch,scal2jros)( *kappa_cast, *(alpha1 +15*inca), *(pi1_r +15) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 6*inca), *(pi1_r + 6) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 7*inca), *(pi1_r + 7) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 8*inca), *(pi1_r + 8) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 + 9*inca), *(pi1_r + 9) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 +10*inca), *(pi1_r +10) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 +11*inca), *(pi1_r +11) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 +12*inca), *(pi1_r +12) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 +13*inca), *(pi1_r +13) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 +14*inca), *(pi1_r +14) ); \ - PASTEMAC(ch,scal2ros)( *kappa_cast, *(alpha1 +15*inca), *(pi1_r +15) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - } \ - } \ - else if ( bli_is_io_packed( schema ) ) \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,copys)( -*(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 3*inca2), *(pi1_r + 3) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 4*inca2), *(pi1_r + 4) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 5*inca2), *(pi1_r + 5) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 6*inca2), *(pi1_r + 6) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 7*inca2), *(pi1_r + 7) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 8*inca2), *(pi1_r + 8) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i + 9*inca2), *(pi1_r + 9) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i +10*inca2), *(pi1_r +10) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i +11*inca2), *(pi1_r +11) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i +12*inca2), *(pi1_r +12) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i +13*inca2), *(pi1_r +13) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i +14*inca2), *(pi1_r +14) ); \ - PASTEMAC(chr,copys)( -*(alpha1_i +15*inca2), *(pi1_r +15) ); \ - \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,copys)( *(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 3*inca2), *(pi1_r + 3) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 4*inca2), *(pi1_r + 4) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 5*inca2), *(pi1_r + 5) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 6*inca2), *(pi1_r + 6) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 7*inca2), *(pi1_r + 7) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 8*inca2), *(pi1_r + 8) ); \ - PASTEMAC(chr,copys)( *(alpha1_i + 9*inca2), *(pi1_r + 9) ); \ - PASTEMAC(chr,copys)( *(alpha1_i +10*inca2), *(pi1_r +10) ); \ - PASTEMAC(chr,copys)( *(alpha1_i +11*inca2), *(pi1_r +11) ); \ - PASTEMAC(chr,copys)( *(alpha1_i +12*inca2), *(pi1_r +12) ); \ - PASTEMAC(chr,copys)( *(alpha1_i +13*inca2), *(pi1_r +13) ); \ - PASTEMAC(chr,copys)( *(alpha1_i +14*inca2), *(pi1_r +14) ); \ - PASTEMAC(chr,copys)( *(alpha1_i +15*inca2), *(pi1_r +15) ); \ - \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 6*inca), *(pi1_r + 6) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 7*inca), *(pi1_r + 7) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 8*inca), *(pi1_r + 8) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 + 9*inca), *(pi1_r + 9) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 +10*inca), *(pi1_r +10) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 +11*inca), *(pi1_r +11) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 +12*inca), *(pi1_r +12) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 +13*inca), *(pi1_r +13) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 +14*inca), *(pi1_r +14) ); \ - PASTEMAC(ch,scal2jios)( *kappa_cast, *(alpha1 +15*inca), *(pi1_r +15) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 6*inca), *(pi1_r + 6) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 7*inca), *(pi1_r + 7) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 8*inca), *(pi1_r + 8) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 + 9*inca), *(pi1_r + 9) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 +10*inca), *(pi1_r +10) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 +11*inca), *(pi1_r +11) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 +12*inca), *(pi1_r +12) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 +13*inca), *(pi1_r +13) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 +14*inca), *(pi1_r +14) ); \ - PASTEMAC(ch,scal2ios)( *kappa_cast, *(alpha1 +15*inca), *(pi1_r +15) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - } \ - } \ - else /* if ( bli_is_rpi_packed( schema ) ) */ \ - { \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,add3s)( *(alpha1_r + 0*inca2), -*(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 1*inca2), -*(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 2*inca2), -*(alpha1_i + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 3*inca2), -*(alpha1_i + 3*inca2), *(pi1_r + 3) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 4*inca2), -*(alpha1_i + 4*inca2), *(pi1_r + 4) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 5*inca2), -*(alpha1_i + 5*inca2), *(pi1_r + 5) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 6*inca2), -*(alpha1_i + 6*inca2), *(pi1_r + 6) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 7*inca2), -*(alpha1_i + 7*inca2), *(pi1_r + 7) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 8*inca2), -*(alpha1_i + 8*inca2), *(pi1_r + 8) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 9*inca2), -*(alpha1_i + 9*inca2), *(pi1_r + 9) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r +10*inca2), -*(alpha1_i +10*inca2), *(pi1_r +10) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r +11*inca2), -*(alpha1_i +11*inca2), *(pi1_r +11) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r +12*inca2), -*(alpha1_i +12*inca2), *(pi1_r +12) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r +13*inca2), -*(alpha1_i +13*inca2), *(pi1_r +13) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r +14*inca2), -*(alpha1_i +14*inca2), *(pi1_r +14) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r +15*inca2), -*(alpha1_i +15*inca2), *(pi1_r +15) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(chr,add3s)( *(alpha1_r + 0*inca2), *(alpha1_i + 0*inca2), *(pi1_r + 0) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 1*inca2), *(alpha1_i + 1*inca2), *(pi1_r + 1) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 2*inca2), *(alpha1_i + 2*inca2), *(pi1_r + 2) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 3*inca2), *(alpha1_i + 3*inca2), *(pi1_r + 3) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 4*inca2), *(alpha1_i + 4*inca2), *(pi1_r + 4) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 5*inca2), *(alpha1_i + 5*inca2), *(pi1_r + 5) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 6*inca2), *(alpha1_i + 6*inca2), *(pi1_r + 6) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 7*inca2), *(alpha1_i + 7*inca2), *(pi1_r + 7) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 8*inca2), *(alpha1_i + 8*inca2), *(pi1_r + 8) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r + 9*inca2), *(alpha1_i + 9*inca2), *(pi1_r + 9) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r +10*inca2), *(alpha1_i +10*inca2), *(pi1_r +10) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r +11*inca2), *(alpha1_i +11*inca2), *(pi1_r +11) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r +12*inca2), *(alpha1_i +12*inca2), *(pi1_r +12) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r +13*inca2), *(alpha1_i +13*inca2), *(pi1_r +13) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r +14*inca2), *(alpha1_i +14*inca2), *(pi1_r +14) ); \ - PASTEMAC(chr,add3s)( *(alpha1_r +15*inca2), *(alpha1_i +15*inca2), *(pi1_r +15) ); \ - \ - alpha1_r += lda2; \ - alpha1_i += lda2; \ - pi1_r += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 6*inca), *(pi1_r + 6) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 7*inca), *(pi1_r + 7) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 8*inca), *(pi1_r + 8) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 + 9*inca), *(pi1_r + 9) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 +10*inca), *(pi1_r +10) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 +11*inca), *(pi1_r +11) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 +12*inca), *(pi1_r +12) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 +13*inca), *(pi1_r +13) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 +14*inca), *(pi1_r +14) ); \ - PASTEMAC(ch,scal2jrpis)( *kappa_cast, *(alpha1 +15*inca), *(pi1_r +15) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - else \ - { \ - for ( dim_t k = n; k != 0; --k ) \ - { \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 0*inca), *(pi1_r + 0) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 1*inca), *(pi1_r + 1) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 2*inca), *(pi1_r + 2) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 3*inca), *(pi1_r + 3) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 4*inca), *(pi1_r + 4) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 5*inca), *(pi1_r + 5) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 6*inca), *(pi1_r + 6) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 7*inca), *(pi1_r + 7) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 8*inca), *(pi1_r + 8) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 + 9*inca), *(pi1_r + 9) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 +10*inca), *(pi1_r +10) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 +11*inca), *(pi1_r +11) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 +12*inca), *(pi1_r +12) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 +13*inca), *(pi1_r +13) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 +14*inca), *(pi1_r +14) ); \ - PASTEMAC(ch,scal2rpis)( *kappa_cast, *(alpha1 +15*inca), *(pi1_r +15) ); \ - \ - alpha1 += lda; \ - pi1_r += ldp; \ - } \ - } \ - } \ - } \ - } \ - else /* if ( cdim < mnr ) */ \ - { \ - PASTEMAC(ch,scal2rihs_mxn) \ - ( \ - schema, \ - conja, \ - cdim, \ - n, \ - kappa, \ - a, inca, lda, \ - p, 1, ldp \ - ); \ -\ - /* if ( cdim < mnr ) */ \ - { \ - const dim_t i = cdim; \ - const dim_t m_edge = mnr - cdim; \ - const dim_t n_edge = n_max; \ - ctype* restrict p_cast = p; \ - ctype* restrict p_edge = p_cast + (i )*1; \ -\ - PASTEMAC(ch,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge, 1, ldp \ - ); \ - } \ - } \ -\ - if ( n < n_max ) \ - { \ - const dim_t j = n; \ - const dim_t m_edge = mnr; \ - const dim_t n_edge = n_max - n; \ - ctype* restrict p_cast = p; \ - ctype* restrict p_edge = p_cast + (j )*ldp; \ -\ - PASTEMAC(ch,set0s_mxn) \ - ( \ - m_edge, \ - n_edge, \ - p_edge, 1, ldp \ - ); \ - } \ -} - -INSERT_GENTFUNCCO_BASIC3( packm_16xk_rih, 16, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) - diff --git a/ref_kernels/3/bli_gemmsup_ref.c b/ref_kernels/3/bli_gemmsup_ref.c index 1d3303505f..6cf074ebdc 100644 --- a/ref_kernels/3/bli_gemmsup_ref.c +++ b/ref_kernels/3/bli_gemmsup_ref.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -59,180 +59,225 @@ void PASTEMAC3(ch,opname,arch,suf) \ { \ /* NOTE: This microkernel can actually handle arbitrarily large values of m, n, and k. */ \ + const num_t dt = PASTEMAC(ch,type); \ + const dim_t MR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_MR, cntx ); \ + const dim_t NR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NR, cntx ); \ +\ + uint64_t ps_a = bli_auxinfo_ps_a( data ); \ + uint64_t ps_b = bli_auxinfo_ps_b( data ); \ +\ + ctype* restrict abuf = a; \ + ctype* restrict bbuf = b; \ \ if ( bli_is_noconj( conja ) && bli_is_noconj( conjb ) ) \ { \ /* Traverse c by rows. */ \ - for ( dim_t i = 0; i < m; ++i ) \ + for ( dim_t i = 0; i < m; i += MR ) \ { \ - ctype* restrict ci = &c[ i*rs_c ]; \ - ctype* restrict ai = &a[ i*rs_a ]; \ -\ - for ( dim_t j = 0; j < n; ++j ) \ + for ( dim_t ii = 0; ii < bli_min( MR, m-i ); ++ii ) \ { \ - ctype* restrict cij = &ci[ j*cs_c ]; \ - ctype* restrict bj = &b [ j*cs_b ]; \ - ctype ab; \ -\ - PASTEMAC(ch,set0s)( ab ); \ -\ - /* Perform a dot product to update the (i,j) element of c. */ \ - for ( dim_t l = 0; l < k; ++l ) \ - { \ - ctype* restrict aij = &ai[ l*cs_a ]; \ - ctype* restrict bij = &bj[ l*rs_b ]; \ -\ - PASTEMAC(ch,dots)( *aij, *bij, ab ); \ - } \ -\ - /* If beta is one, add ab into c. If beta is zero, overwrite c - with the result in ab. Otherwise, scale by beta and accumulate - ab to c. */ \ - if ( PASTEMAC(ch,eq1)( *beta ) ) \ - { \ - PASTEMAC(ch,axpys)( *alpha, ab, *cij ); \ - } \ - else if ( PASTEMAC(ch,eq0)( *beta ) ) \ - { \ - PASTEMAC(ch,scal2s)( *alpha, ab, *cij ); \ - } \ - else \ - { \ - PASTEMAC(ch,axpbys)( *alpha, ab, *beta, *cij ); \ + bbuf = b; \ + ctype* restrict ci = c + (i+ii) * rs_c; \ + ctype* restrict ai = abuf + ii * rs_a; \ +\ + for ( dim_t j = 0; j < n; j += NR ) \ + { \ + for ( dim_t jj = 0; jj < bli_min( NR, n-j ); ++jj ) \ + { \ + ctype* restrict cij = ci + (j+jj) * cs_c; \ + ctype* restrict bj = bbuf + jj * cs_b; \ + ctype ab; \ +\ + PASTEMAC(ch,set0s)( ab ); \ +\ + /* Perform a dot product to update the (i,j) element of c. */ \ + for ( dim_t l = 0; l < k; ++l ) \ + { \ + ctype* restrict aij = ai + l * cs_a; \ + ctype* restrict bij = bj + l * rs_b; \ +\ + PASTEMAC(ch,dots)( *aij, *bij, ab ); \ + } \ +\ + /* If beta is one, add ab into c. If beta is zero, overwrite c + with the result in ab. Otherwise, scale by beta and accumulate + ab to c. */ \ + if ( PASTEMAC(ch,eq1)( *beta ) ) \ + { \ + PASTEMAC(ch,axpys)( *alpha, ab, *cij ); \ + } \ + else if ( PASTEMAC(ch,eq0)( *beta ) ) \ + { \ + PASTEMAC(ch,scal2s)( *alpha, ab, *cij ); \ + } \ + else \ + { \ + PASTEMAC(ch,axpbys)( *alpha, ab, *beta, *cij ); \ + } \ + } \ + bbuf += ps_b; \ } \ } \ + abuf += ps_a; \ } \ } \ else if ( bli_is_noconj( conja ) && bli_is_conj( conjb ) ) \ { \ /* Traverse c by rows. */ \ - for ( dim_t i = 0; i < m; ++i ) \ + for ( dim_t i = 0; i < m; i += MR ) \ { \ - ctype* restrict ci = &c[ i*rs_c ]; \ - ctype* restrict ai = &a[ i*rs_a ]; \ -\ - for ( dim_t j = 0; j < n; ++j ) \ + for ( dim_t ii = 0; ii < bli_min( MR, m-i ); ++ii ) \ { \ - ctype* restrict cij = &ci[ j*cs_c ]; \ - ctype* restrict bj = &b [ j*cs_b ]; \ - ctype ab; \ -\ - PASTEMAC(ch,set0s)( ab ); \ -\ - /* Perform a dot product to update the (i,j) element of c. */ \ - for ( dim_t l = 0; l < k; ++l ) \ - { \ - ctype* restrict aij = &ai[ l*cs_a ]; \ - ctype* restrict bij = &bj[ l*rs_b ]; \ -\ - PASTEMAC(ch,axpyjs)( *aij, *bij, ab ); \ - } \ -\ - /* If beta is one, add ab into c. If beta is zero, overwrite c - with the result in ab. Otherwise, scale by beta and accumulate - ab to c. */ \ - if ( PASTEMAC(ch,eq1)( *beta ) ) \ - { \ - PASTEMAC(ch,axpys)( *alpha, ab, *cij ); \ - } \ - else if ( PASTEMAC(ch,eq0)( *beta ) ) \ - { \ - PASTEMAC(ch,scal2s)( *alpha, ab, *cij ); \ - } \ - else \ - { \ - PASTEMAC(ch,axpbys)( *alpha, ab, *beta, *cij ); \ + bbuf = b; \ + ctype* restrict ci = c + (i+ii) * rs_c; \ + ctype* restrict ai = abuf + ii * rs_a; \ +\ + for ( dim_t j = 0; j < n; j += NR ) \ + { \ + for ( dim_t jj = 0; jj < bli_min( NR, n-j ); ++jj ) \ + { \ + ctype* restrict cij = ci + (j+jj) * cs_c; \ + ctype* restrict bj = bbuf + jj * cs_b; \ + ctype ab; \ +\ + PASTEMAC(ch,set0s)( ab ); \ +\ + /* Perform a dot product to update the (i,j) element of c. */ \ + for ( dim_t l = 0; l < k; ++l ) \ + { \ + ctype* restrict aij = ai + l * cs_a; \ + ctype* restrict bij = bj + l * rs_b; \ +\ + PASTEMAC(ch,axpyjs)( *aij, *bij, ab ); \ + } \ +\ + /* If beta is one, add ab into c. If beta is zero, overwrite c + with the result in ab. Otherwise, scale by beta and accumulate + ab to c. */ \ + if ( PASTEMAC(ch,eq1)( *beta ) ) \ + { \ + PASTEMAC(ch,axpys)( *alpha, ab, *cij ); \ + } \ + else if ( PASTEMAC(ch,eq0)( *beta ) ) \ + { \ + PASTEMAC(ch,scal2s)( *alpha, ab, *cij ); \ + } \ + else \ + { \ + PASTEMAC(ch,axpbys)( *alpha, ab, *beta, *cij ); \ + } \ + } \ + bbuf += ps_b; \ } \ } \ + abuf += ps_a; \ } \ } \ else if ( bli_is_conj( conja ) && bli_is_noconj( conjb ) ) \ { \ /* Traverse c by rows. */ \ - for ( dim_t i = 0; i < m; ++i ) \ + for ( dim_t i = 0; i < m; i += MR ) \ { \ - ctype* restrict ci = &c[ i*rs_c ]; \ - ctype* restrict ai = &a[ i*rs_a ]; \ -\ - for ( dim_t j = 0; j < n; ++j ) \ + for ( dim_t ii = 0; ii < bli_min( MR, m-i ); ++ii ) \ { \ - ctype* restrict cij = &ci[ j*cs_c ]; \ - ctype* restrict bj = &b [ j*cs_b ]; \ - ctype ab; \ -\ - PASTEMAC(ch,set0s)( ab ); \ -\ - /* Perform a dot product to update the (i,j) element of c. */ \ - for ( dim_t l = 0; l < k; ++l ) \ - { \ - ctype* restrict aij = &ai[ l*cs_a ]; \ - ctype* restrict bij = &bj[ l*rs_b ]; \ -\ - PASTEMAC(ch,dotjs)( *aij, *bij, ab ); \ - } \ -\ - /* If beta is one, add ab into c. If beta is zero, overwrite c - with the result in ab. Otherwise, scale by beta and accumulate - ab to c. */ \ - if ( PASTEMAC(ch,eq1)( *beta ) ) \ - { \ - PASTEMAC(ch,axpys)( *alpha, ab, *cij ); \ - } \ - else if ( PASTEMAC(ch,eq0)( *beta ) ) \ - { \ - PASTEMAC(ch,scal2s)( *alpha, ab, *cij ); \ - } \ - else \ - { \ - PASTEMAC(ch,axpbys)( *alpha, ab, *beta, *cij ); \ + bbuf = b; \ + ctype* restrict ci = c + (i+ii) * rs_c; \ + ctype* restrict ai = abuf + ii * rs_a; \ +\ + for ( dim_t j = 0; j < n; j += NR ) \ + { \ + for ( dim_t jj = 0; jj < bli_min( NR, n-j ); ++jj ) \ + { \ + ctype* restrict cij = ci + (j+jj) * cs_c; \ + ctype* restrict bj = bbuf + jj * cs_b; \ + ctype ab; \ +\ + PASTEMAC(ch,set0s)( ab ); \ +\ + /* Perform a dot product to update the (i,j) element of c. */ \ + for ( dim_t l = 0; l < k; ++l ) \ + { \ + ctype* restrict aij = ai + l * cs_a; \ + ctype* restrict bij = bj + l * rs_b; \ +\ + PASTEMAC(ch,dotjs)( *aij, *bij, ab ); \ + } \ +\ + /* If beta is one, add ab into c. If beta is zero, overwrite c + with the result in ab. Otherwise, scale by beta and accumulate + ab to c. */ \ + if ( PASTEMAC(ch,eq1)( *beta ) ) \ + { \ + PASTEMAC(ch,axpys)( *alpha, ab, *cij ); \ + } \ + else if ( PASTEMAC(ch,eq0)( *beta ) ) \ + { \ + PASTEMAC(ch,scal2s)( *alpha, ab, *cij ); \ + } \ + else \ + { \ + PASTEMAC(ch,axpbys)( *alpha, ab, *beta, *cij ); \ + } \ + } \ + bbuf += ps_b; \ } \ } \ + abuf += ps_a; \ } \ } \ else /* if ( bli_is_conj( conja ) && bli_is_conj( conjb ) ) */ \ { \ /* Traverse c by rows. */ \ - for ( dim_t i = 0; i < m; ++i ) \ + for ( dim_t i = 0; i < m; i += MR ) \ { \ - ctype* restrict ci = &c[ i*rs_c ]; \ - ctype* restrict ai = &a[ i*rs_a ]; \ -\ - for ( dim_t j = 0; j < n; ++j ) \ + for ( dim_t ii = 0; ii < bli_min( MR, m-i ); ++ii ) \ { \ - ctype* restrict cij = &ci[ j*cs_c ]; \ - ctype* restrict bj = &b [ j*cs_b ]; \ - ctype ab; \ -\ - PASTEMAC(ch,set0s)( ab ); \ -\ - /* Perform a dot product to update the (i,j) element of c. */ \ - for ( dim_t l = 0; l < k; ++l ) \ - { \ - ctype* restrict aij = &ai[ l*cs_a ]; \ - ctype* restrict bij = &bj[ l*rs_b ]; \ -\ - PASTEMAC(ch,dots)( *aij, *bij, ab ); \ - } \ -\ - /* Conjugate the result to simulate conj(a^T) * conj(b). */ \ - PASTEMAC(ch,conjs)( ab ); \ -\ - /* If beta is one, add ab into c. If beta is zero, overwrite c - with the result in ab. Otherwise, scale by beta and accumulate - ab to c. */ \ - if ( PASTEMAC(ch,eq1)( *beta ) ) \ - { \ - PASTEMAC(ch,axpys)( *alpha, ab, *cij ); \ - } \ - else if ( PASTEMAC(ch,eq0)( *beta ) ) \ - { \ - PASTEMAC(ch,scal2s)( *alpha, ab, *cij ); \ - } \ - else \ - { \ - PASTEMAC(ch,axpbys)( *alpha, ab, *beta, *cij ); \ + bbuf = b; \ + ctype* restrict ci = c + (i+ii) * rs_c; \ + ctype* restrict ai = abuf + ii * rs_a; \ +\ + for ( dim_t j = 0; j < n; j += NR ) \ + { \ + for ( dim_t jj = 0; jj < bli_min( NR, n-j ); ++jj ) \ + { \ + ctype* restrict cij = ci + (j+jj) * cs_c; \ + ctype* restrict bj = bbuf + jj * cs_b; \ + ctype ab; \ +\ + PASTEMAC(ch,set0s)( ab ); \ +\ + /* Perform a dot product to update the (i,j) element of c. */ \ + for ( dim_t l = 0; l < k; ++l ) \ + { \ + ctype* restrict aij = ai + l * cs_a; \ + ctype* restrict bij = bj + l * rs_b; \ +\ + PASTEMAC(ch,dots)( *aij, *bij, ab ); \ + } \ +\ + /* Conjugate the result to simulate conj(a^T) * conj(b). */ \ + PASTEMAC(ch,conjs)( ab ); \ +\ + /* If beta is one, add ab into c. If beta is zero, overwrite c + with the result in ab. Otherwise, scale by beta and accumulate + ab to c. */ \ + if ( PASTEMAC(ch,eq1)( *beta ) ) \ + { \ + PASTEMAC(ch,axpys)( *alpha, ab, *cij ); \ + } \ + else if ( PASTEMAC(ch,eq0)( *beta ) ) \ + { \ + PASTEMAC(ch,scal2s)( *alpha, ab, *cij ); \ + } \ + else \ + { \ + PASTEMAC(ch,axpbys)( *alpha, ab, *beta, *cij ); \ + } \ + } \ + bbuf += ps_b; \ } \ } \ + abuf += ps_a; \ } \ } \ } diff --git a/ref_kernels/bli_cntx_ref.c b/ref_kernels/bli_cntx_ref.c index b0d47d26f1..3272624884 100644 --- a/ref_kernels/bli_cntx_ref.c +++ b/ref_kernels/bli_cntx_ref.c @@ -47,7 +47,7 @@ // -- Level-3 native micro-kernel prototype redefinitions ---------------------- -// -- prototypes for completely generic level-3 microkernels -- +// -- Prototypes for completely generic level-3 microkernels -- #undef gemm_ukr_name #define gemm_ukr_name GENARNAME(gemm) @@ -66,46 +66,7 @@ // -- Level-3 virtual micro-kernel prototype redefinitions --------------------- -// -- 3mh -- - -#undef gemm3mh_ukr_name -#define gemm3mh_ukr_name GENARNAME(gemm3mh) - -// -- 3m1 -- - -#undef gemm3m1_ukr_name -#define gemm3m1_ukr_name GENARNAME(gemm3m1) -#undef gemmtrsm3m1_l_ukr_name -#define gemmtrsm3m1_l_ukr_name GENARNAME(gemmtrsm3m1_l) -#undef gemmtrsm3m1_u_ukr_name -#define gemmtrsm3m1_u_ukr_name GENARNAME(gemmtrsm3m1_u) -#undef trsm3m1_l_ukr_name -#define trsm3m1_l_ukr_name GENARNAME(trsm3m1_l) -#undef trsm3m1_u_ukr_name -#define trsm3m1_u_ukr_name GENARNAME(trsm3m1_u) - -// -- 4mh -- - -#undef gemm4mh_ukr_name -#define gemm4mh_ukr_name GENARNAME(gemm4mh) - -// -- 4mb -- - -#undef gemm4mb_ukr_name -#define gemm4mb_ukr_name GENARNAME(gemm4mb) - -// -- 4m1 -- - -#undef gemm4m1_ukr_name -#define gemm4m1_ukr_name GENARNAME(gemm4m1) -#undef gemmtrsm4m1_l_ukr_name -#define gemmtrsm4m1_l_ukr_name GENARNAME(gemmtrsm4m1_l) -#undef gemmtrsm4m1_u_ukr_name -#define gemmtrsm4m1_u_ukr_name GENARNAME(gemmtrsm4m1_u) -#undef trsm4m1_l_ukr_name -#define trsm4m1_l_ukr_name GENARNAME(trsm4m1_l) -#undef trsm4m1_u_ukr_name -#define trsm4m1_u_ukr_name GENARNAME(trsm4m1_u) +// -- Prototypes for induced method level-3 microkernels -- // -- 1m -- @@ -184,59 +145,6 @@ #undef unpackm_16xk_ker_name #define unpackm_16xk_ker_name GENARNAME(unpackm_16xk) -#undef packm_2xk_3mis_ker_name -#define packm_2xk_3mis_ker_name GENARNAME(packm_2xk_3mis) -#undef packm_4xk_3mis_ker_name -#define packm_4xk_3mis_ker_name GENARNAME(packm_4xk_3mis) -#undef packm_6xk_3mis_ker_name -#define packm_6xk_3mis_ker_name GENARNAME(packm_6xk_3mis) -#undef packm_8xk_3mis_ker_name -#define packm_8xk_3mis_ker_name GENARNAME(packm_8xk_3mis) -#undef packm_10xk_3mis_ker_name -#define packm_10xk_3mis_ker_name GENARNAME(packm_10xk_3mis) -#undef packm_12xk_3mis_ker_name -#define packm_12xk_3mis_ker_name GENARNAME(packm_12xk_3mis) -#undef packm_14xk_3mis_ker_name -#define packm_14xk_3mis_ker_name GENARNAME(packm_14xk_3mis) -#undef packm_16xk_3mis_ker_name -#define packm_16xk_3mis_ker_name GENARNAME(packm_16xk_3mis) - -#undef packm_2xk_4mi_ker_name -#define packm_2xk_4mi_ker_name GENARNAME(packm_2xk_4mi) -#undef packm_3xk_4mi_ker_name -#define packm_3xk_4mi_ker_name GENARNAME(packm_3xk_4mi) -#undef packm_4xk_4mi_ker_name -#define packm_4xk_4mi_ker_name GENARNAME(packm_4xk_4mi) -#undef packm_6xk_4mi_ker_name -#define packm_6xk_4mi_ker_name GENARNAME(packm_6xk_4mi) -#undef packm_8xk_4mi_ker_name -#define packm_8xk_4mi_ker_name GENARNAME(packm_8xk_4mi) -#undef packm_10xk_4mi_ker_name -#define packm_10xk_4mi_ker_name GENARNAME(packm_10xk_4mi) -#undef packm_12xk_4mi_ker_name -#define packm_12xk_4mi_ker_name GENARNAME(packm_12xk_4mi) -#undef packm_14xk_4mi_ker_name -#define packm_14xk_4mi_ker_name GENARNAME(packm_14xk_4mi) -#undef packm_16xk_4mi_ker_name -#define packm_16xk_4mi_ker_name GENARNAME(packm_16xk_4mi) - -#undef packm_2xk_rih_ker_name -#define packm_2xk_rih_ker_name GENARNAME(packm_2xk_rih) -#undef packm_4xk_rih_ker_name -#define packm_4xk_rih_ker_name GENARNAME(packm_4xk_rih) -#undef packm_6xk_rih_ker_name -#define packm_6xk_rih_ker_name GENARNAME(packm_6xk_rih) -#undef packm_8xk_rih_ker_name -#define packm_8xk_rih_ker_name GENARNAME(packm_8xk_rih) -#undef packm_10xk_rih_ker_name -#define packm_10xk_rih_ker_name GENARNAME(packm_10xk_rih) -#undef packm_12xk_rih_ker_name -#define packm_12xk_rih_ker_name GENARNAME(packm_12xk_rih) -#undef packm_14xk_rih_ker_name -#define packm_14xk_rih_ker_name GENARNAME(packm_14xk_rih) -#undef packm_16xk_rih_ker_name -#define packm_16xk_rih_ker_name GENARNAME(packm_16xk_rih) - #undef packm_2xk_1er_ker_name #define packm_2xk_1er_ker_name GENARNAME(packm_2xk_1er) #undef packm_4xk_1er_ker_name @@ -340,7 +248,14 @@ PASTEMAC(c,opname), PASTEMAC(z,opname) ); \ } +// -- Helper function for 1m --------------------------------------------------- +void GENBAINAME(cntx_init_blkszs) + ( + ind_t method, + num_t dt, + cntx_t* cntx + ); // ----------------------------------------------------------------------------- @@ -404,8 +319,8 @@ void GENBARNAME(cntx_init) // NOTE: We set the virtual micro-kernel slots to contain the addresses // of the native micro-kernels. In general, the ukernels in the virtual // ukernel slots are always called, and if the function called happens to - // be a virtual micro-kernel, it will then know to find its native - // ukernel in the native ukernel slots. + // be a virtual micro-kernel, it will then know to find its native ukernel + // (i.e., in the native ukernel slots). gen_func_init( &funcs[ BLIS_GEMM_UKR ], gemm_ukr_name ); gen_func_init( &funcs[ BLIS_GEMMTRSM_L_UKR ], gemmtrsm_l_ukr_name ); gen_func_init( &funcs[ BLIS_GEMMTRSM_U_UKR ], gemmtrsm_u_ukr_name ); @@ -700,10 +615,6 @@ void GENBARNAME(cntx_init) // -- Set miscellaneous fields --------------------------------------------- bli_cntx_set_method( BLIS_NAT, cntx ); - - bli_cntx_set_schema_a_block( BLIS_PACKED_ROW_PANELS, cntx ); - bli_cntx_set_schema_b_panel( BLIS_PACKED_COL_PANELS, cntx ); - bli_cntx_set_schema_c_panel( BLIS_NOT_PACKED, cntx ); } // ----------------------------------------------------------------------------- @@ -711,7 +622,6 @@ void GENBARNAME(cntx_init) void GENBAINAME(cntx_init) ( ind_t method, - num_t dt, cntx_t* cntx ) { @@ -728,41 +638,7 @@ void GENBAINAME(cntx_init) funcs = bli_cntx_l3_vir_ukrs_buf( cntx ); - // 3mh, 4mh, and 4mb do not not support trsm. - bli_func_init_null( &funcs[ BLIS_GEMMTRSM_L_UKR ] ); - bli_func_init_null( &funcs[ BLIS_GEMMTRSM_U_UKR ] ); - bli_func_init_null( &funcs[ BLIS_TRSM_L_UKR ] ); - bli_func_init_null( &funcs[ BLIS_TRSM_U_UKR ] ); - - if ( method == BLIS_3MH ) - { - gen_func_init_co( &funcs[ BLIS_GEMM_UKR ], gemm3mh_ukr_name ); - } - else if ( method == BLIS_3M1 ) - { - gen_func_init_co( &funcs[ BLIS_GEMM_UKR ], gemm3m1_ukr_name ); - gen_func_init_co( &funcs[ BLIS_GEMMTRSM_L_UKR ], gemmtrsm3m1_l_ukr_name ); - gen_func_init_co( &funcs[ BLIS_GEMMTRSM_U_UKR ], gemmtrsm3m1_u_ukr_name ); - gen_func_init_co( &funcs[ BLIS_TRSM_L_UKR ], trsm3m1_l_ukr_name ); - gen_func_init_co( &funcs[ BLIS_TRSM_U_UKR ], trsm3m1_u_ukr_name ); - } - else if ( method == BLIS_4MH ) - { - gen_func_init_co( &funcs[ BLIS_GEMM_UKR ], gemm4mh_ukr_name ); - } - else if ( method == BLIS_4M1B ) - { - gen_func_init_co( &funcs[ BLIS_GEMM_UKR ], gemm4mb_ukr_name ); - } - else if ( method == BLIS_4M1A ) - { - gen_func_init_co( &funcs[ BLIS_GEMM_UKR ], gemm4m1_ukr_name ); - gen_func_init_co( &funcs[ BLIS_GEMMTRSM_L_UKR ], gemmtrsm4m1_l_ukr_name ); - gen_func_init_co( &funcs[ BLIS_GEMMTRSM_U_UKR ], gemmtrsm4m1_u_ukr_name ); - gen_func_init_co( &funcs[ BLIS_TRSM_L_UKR ], trsm4m1_l_ukr_name ); - gen_func_init_co( &funcs[ BLIS_TRSM_U_UKR ], trsm4m1_u_ukr_name ); - } - else if ( method == BLIS_1M ) + if ( method == BLIS_1M ) { gen_func_init_co( &funcs[ BLIS_GEMM_UKR ], gemm1m_ukr_name ); gen_func_init_co( &funcs[ BLIS_GEMMTRSM_L_UKR ], gemmtrsm1m_l_ukr_name ); @@ -781,7 +657,14 @@ void GENBAINAME(cntx_init) // For 1m, we employ an optimization which requires that we copy the native // real domain gemm ukernel function pointers to the corresponding real - // domain slots in the virtual gemm ukernel func_t. + // domain slots in the virtual gemm ukernel func_t. This optimization allows + // us to, under certain conditions, adjust various parameters within the gemm + // macrokernel so that the real-domain macrokernel (which will query and use + // the real-domain virtual gemm ukernel) can be called instead of calling the + // complex-domain macrokernel and the corresponding complex-domain virtual + // microkernel. The non-optimized code path would require an extra level of + // function call overhead, which can be avoided in most cases (i.e., when + // beta has a zero imaginary component and C is either row- or column-stored). if ( method == BLIS_1M ) { func_t* gemm_nat_ukrs = bli_cntx_get_l3_nat_ukrs( BLIS_GEMM_UKR, cntx ); @@ -802,40 +685,7 @@ void GENBAINAME(cntx_init) bli_func_init_null( &funcs[ i ] ); } - if ( method == BLIS_3MH || method == BLIS_4MH ) - { - gen_func_init_co( &funcs[ BLIS_PACKM_2XK_KER ], packm_2xk_rih_ker_name ); - gen_func_init_co( &funcs[ BLIS_PACKM_4XK_KER ], packm_4xk_rih_ker_name ); - gen_func_init_co( &funcs[ BLIS_PACKM_6XK_KER ], packm_6xk_rih_ker_name ); - gen_func_init_co( &funcs[ BLIS_PACKM_8XK_KER ], packm_8xk_rih_ker_name ); - gen_func_init_co( &funcs[ BLIS_PACKM_10XK_KER ], packm_10xk_rih_ker_name ); - gen_func_init_co( &funcs[ BLIS_PACKM_12XK_KER ], packm_12xk_rih_ker_name ); - gen_func_init_co( &funcs[ BLIS_PACKM_14XK_KER ], packm_14xk_rih_ker_name ); - gen_func_init_co( &funcs[ BLIS_PACKM_16XK_KER ], packm_16xk_rih_ker_name ); - } - else if ( method == BLIS_3M1 ) - { - gen_func_init_co( &funcs[ BLIS_PACKM_2XK_KER ], packm_2xk_3mis_ker_name ); - gen_func_init_co( &funcs[ BLIS_PACKM_4XK_KER ], packm_4xk_3mis_ker_name ); - gen_func_init_co( &funcs[ BLIS_PACKM_6XK_KER ], packm_6xk_3mis_ker_name ); - gen_func_init_co( &funcs[ BLIS_PACKM_8XK_KER ], packm_8xk_3mis_ker_name ); - gen_func_init_co( &funcs[ BLIS_PACKM_10XK_KER ], packm_10xk_3mis_ker_name ); - gen_func_init_co( &funcs[ BLIS_PACKM_12XK_KER ], packm_12xk_3mis_ker_name ); - gen_func_init_co( &funcs[ BLIS_PACKM_14XK_KER ], packm_14xk_3mis_ker_name ); - gen_func_init_co( &funcs[ BLIS_PACKM_16XK_KER ], packm_16xk_3mis_ker_name ); - } - else if ( method == BLIS_4M1A || method == BLIS_4M1B ) - { - gen_func_init_co( &funcs[ BLIS_PACKM_2XK_KER ], packm_2xk_4mi_ker_name ); - gen_func_init_co( &funcs[ BLIS_PACKM_4XK_KER ], packm_4xk_4mi_ker_name ); - gen_func_init_co( &funcs[ BLIS_PACKM_6XK_KER ], packm_6xk_4mi_ker_name ); - gen_func_init_co( &funcs[ BLIS_PACKM_8XK_KER ], packm_8xk_4mi_ker_name ); - gen_func_init_co( &funcs[ BLIS_PACKM_10XK_KER ], packm_10xk_4mi_ker_name ); - gen_func_init_co( &funcs[ BLIS_PACKM_12XK_KER ], packm_12xk_4mi_ker_name ); - gen_func_init_co( &funcs[ BLIS_PACKM_14XK_KER ], packm_14xk_4mi_ker_name ); - gen_func_init_co( &funcs[ BLIS_PACKM_16XK_KER ], packm_16xk_4mi_ker_name ); - } - else if ( method == BLIS_1M ) + if ( method == BLIS_1M ) { gen_func_init_co( &funcs[ BLIS_PACKM_2XK_KER ], packm_2xk_1er_ker_name ); gen_func_init_co( &funcs[ BLIS_PACKM_4XK_KER ], packm_4xk_1er_ker_name ); @@ -865,191 +715,75 @@ void GENBAINAME(cntx_init) // Modify the context with cache and register blocksizes (and multiples) // appropriate for the current induced method. - if ( method == BLIS_3MH ) + if ( method == BLIS_1M ) { - bli_cntx_set_ind_blkszs - ( - method, 6, - BLIS_NC, 1.0, 1.0, - BLIS_KC, 1.0, 1.0, - BLIS_MC, 1.0, 1.0, - BLIS_NR, 1.0, 1.0, - BLIS_MR, 1.0, 1.0, - BLIS_KR, 1.0, 1.0, - cntx - ); + //const bool is_pb = FALSE; + + // Call a helper function to initialize blocksizes for each complex + // datatype. + GENBAINAME(cntx_init_blkszs)( method, BLIS_SCOMPLEX, cntx ); + GENBAINAME(cntx_init_blkszs)( method, BLIS_DCOMPLEX, cntx ); } - else if ( method == BLIS_3M1 ) + else // if ( method == BLIS_NAT ) { - bli_cntx_set_ind_blkszs - ( - method, 6, - BLIS_NC, 1.0, 1.0, - BLIS_KC, 3.0, 3.0, - BLIS_MC, 1.0, 1.0, - BLIS_NR, 1.0, 1.0, - BLIS_MR, 1.0, 1.0, - BLIS_KR, 1.0, 1.0, - cntx - ); + // No change in blocksizes needed for native execution. } - else if ( method == BLIS_4MH ) +} + +// ----------------------------------------------------------------------------- + +void GENBAINAME(cntx_init_blkszs) + ( + ind_t method, + num_t dt, + cntx_t* cntx + ) +{ + // We MUST set the induced method in the context prior to calling + // bli_cntx_l3_vir_ukr_prefers_cols_dt() because that function queries + // the induced method. That function needs the induced method value in + // order to determine whether to evaluate the "prefers column storage" + // predicate using the storage preference of the kernel for dt, or + // the storage preference of the kernel for the real projection of + // dt. Failing to set the induced method here can lead to strange + // undefined behavior at runtime if the native complex kernel's + // storage preference happens to not equal that of the native real + // kernel. + bli_cntx_set_method( method, cntx ); + + // Initialize the blocksizes according to the micro-kernel preference as + // well as the algorithm. + if ( bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ) ) { + // This branch is used for algorithm 1m_c_bp. + bli_cntx_set_ind_blkszs ( - method, 6, + method, dt, 6, BLIS_NC, 1.0, 1.0, - BLIS_KC, 1.0, 1.0, - BLIS_MC, 1.0, 1.0, + BLIS_KC, 2.0, 2.0, // halve kc... + BLIS_MC, 2.0, 2.0, // halve mc... BLIS_NR, 1.0, 1.0, - BLIS_MR, 1.0, 1.0, + BLIS_MR, 2.0, 1.0, // ...and mr (but NOT packmr) BLIS_KR, 1.0, 1.0, cntx ); } - else if ( method == BLIS_4M1B ) - { - bli_cntx_set_ind_blkszs - ( - method, 6, - BLIS_NC, 2.0, 2.0, - BLIS_KC, 1.0, 1.0, - BLIS_MC, 2.0, 2.0, - BLIS_NR, 1.0, 1.0, - BLIS_MR, 1.0, 1.0, - BLIS_KR, 1.0, 1.0, - cntx - ); - } - else if ( method == BLIS_4M1A ) + else // if ( bli_cntx_l3_vir_ukr_prefers_rows_dt( dt, BLIS_GEMM_UKR, cntx ) ) { + // This branch is used for algorithm 1m_r_bp. + bli_cntx_set_ind_blkszs ( - method, 6, - BLIS_NC, 1.0, 1.0, - BLIS_KC, 2.0, 2.0, + method, dt, 6, + BLIS_NC, 2.0, 2.0, // halve nc... + BLIS_KC, 2.0, 2.0, // halve kc... BLIS_MC, 1.0, 1.0, - BLIS_NR, 1.0, 1.0, + BLIS_NR, 2.0, 1.0, // ...and nr (but NOT packnr) BLIS_MR, 1.0, 1.0, BLIS_KR, 1.0, 1.0, cntx ); } - else if ( method == BLIS_1M ) - { - const bool is_pb = FALSE; - - // We MUST set the induced method in the context prior to calling - // bli_cntx_l3_ukr_prefers_cols_dt() because that function queries - // the induced method. It needs the induced method value in order - // to determine whether to evaluate the "prefers column storage" - // predicate using the storage preference of the kernel for dt, or - // the storage preference of the kernel for the real projection of - // dt. Failing to set the induced method here can lead to strange - // undefined behavior at runtime if the native complex kernel's - // storage preference happens to not equal that of the native real - // kernel. - bli_cntx_set_method( method, cntx ); - - // Initialize the blocksizes according to the micro-kernel preference as - // well as the algorithm. - if ( bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ) ) - { - // This branch is used for algorithms 1m_c_bp, 1m_r_pb. - - // Set the pack_t schemas for the c_bp or r_pb algorithms. - if ( !is_pb ) - { - bli_cntx_set_schema_a_block( BLIS_PACKED_ROW_PANELS_1E, cntx ); - bli_cntx_set_schema_b_panel( BLIS_PACKED_COL_PANELS_1R, cntx ); - } - else // if ( is_pb ) - { - bli_cntx_set_schema_b_panel( BLIS_PACKED_ROW_PANELS_1R, cntx ); - bli_cntx_set_schema_a_block( BLIS_PACKED_COL_PANELS_1E, cntx ); - } - - bli_cntx_set_ind_blkszs - ( - method, 6, - BLIS_NC, 1.0, 1.0, - BLIS_KC, 2.0, 2.0, // halve kc... - BLIS_MC, 2.0, 2.0, // halve mc... - BLIS_NR, 1.0, 1.0, - BLIS_MR, 2.0, 1.0, // ...and mr (but NOT packmr) - BLIS_KR, 1.0, 1.0, - cntx - ); - } - else // if ( bli_cntx_l3_vir_ukr_prefers_rows_dt( dt, BLIS_GEMM_UKR, cntx ) ) - { - // This branch is used for algorithms 1m_r_bp, 1m_c_pb. - - // Set the pack_t schemas for the r_bp or c_pb algorithms. - if ( !is_pb ) - { - bli_cntx_set_schema_a_block( BLIS_PACKED_ROW_PANELS_1R, cntx ); - bli_cntx_set_schema_b_panel( BLIS_PACKED_COL_PANELS_1E, cntx ); - } - else // if ( is_pb ) - { - bli_cntx_set_schema_b_panel( BLIS_PACKED_ROW_PANELS_1E, cntx ); - bli_cntx_set_schema_a_block( BLIS_PACKED_COL_PANELS_1R, cntx ); - } - - bli_cntx_set_ind_blkszs - ( - method, 6, - BLIS_NC, 2.0, 2.0, // halve nc... - BLIS_KC, 2.0, 2.0, // halve kc... - BLIS_MC, 1.0, 1.0, - BLIS_NR, 2.0, 1.0, // ...and nr (but NOT packnr) - BLIS_MR, 1.0, 1.0, - BLIS_KR, 1.0, 1.0, - cntx - ); - } - } - else // if ( method == BLIS_NAT ) - { - // No change in blocksizes needed for native execution. - } - - - // -- Set misc. other fields ----------------------------------------------- - - if ( method == BLIS_3MH ) - { - // Schemas vary with _stage(). - } - else if ( method == BLIS_3M1 ) - { - bli_cntx_set_schema_a_block( BLIS_PACKED_ROW_PANELS_3MI, cntx ); - bli_cntx_set_schema_b_panel( BLIS_PACKED_COL_PANELS_3MI, cntx ); - } - else if ( method == BLIS_4MH ) - { - // Schemas vary with _stage(). - } - else if ( method == BLIS_4M1A || method == BLIS_4M1B ) - { - bli_cntx_set_schema_a_block( BLIS_PACKED_ROW_PANELS_4MI, cntx ); - bli_cntx_set_schema_b_panel( BLIS_PACKED_COL_PANELS_4MI, cntx ); - } - else if ( method == BLIS_1M ) - { - //const bool is_pb = FALSE; - - // Set the anti-preference field to TRUE when executing a panel-block - // algorithm, and FALSE otherwise. This will cause higher-level generic - // code to establish (if needed) disagreement between the storage of C and - // the micro-kernel output preference so that the two will come back into - // agreement in the panel-block macro-kernel (which implemented in terms - // of the block-panel macro-kernel with some induced transpositions). - //bli_cntx_set_anti_pref( is_pb, cntx ); - } - else // if ( method == BLIS_NAT ) - { - } } diff --git a/ref_kernels/ind/bli_gemm3m1_ref.c b/ref_kernels/ind/bli_gemm3m1_ref.c deleted file mode 100644 index a0a935a994..0000000000 --- a/ref_kernels/ind/bli_gemm3m1_ref.c +++ /dev/null @@ -1,336 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - dim_t k, \ - ctype* restrict alpha, \ - ctype* restrict a, \ - ctype* restrict b, \ - ctype* restrict beta, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - auxinfo_t* restrict data, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const num_t dt_r = PASTEMAC(chr,type); \ -\ - PASTECH(chr,gemm_ukr_ft) \ - rgemm_ukr = bli_cntx_get_l3_nat_ukr_dt( dt_r, BLIS_GEMM_UKR, cntx ); \ -\ - const dim_t mr = bli_cntx_get_blksz_def_dt( dt_r, BLIS_MR, cntx ); \ - const dim_t nr = bli_cntx_get_blksz_def_dt( dt_r, BLIS_NR, cntx ); \ -\ - const dim_t m = mr; \ - const dim_t n = nr; \ -\ - ctype_r ab_r[ BLIS_STACK_BUF_MAX_SIZE \ - / sizeof( ctype_r ) ] \ - __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - ctype_r ab_i[ BLIS_STACK_BUF_MAX_SIZE \ - / sizeof( ctype_r ) ] \ - __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - ctype_r ab_rpi[ BLIS_STACK_BUF_MAX_SIZE \ - / sizeof( ctype_r ) ] \ - __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - inc_t rs_ab; \ - inc_t cs_ab; \ -\ - const inc_t is_a = bli_auxinfo_is_a( data ); \ - const inc_t is_b = bli_auxinfo_is_b( data ); \ -\ - ctype_r* restrict a_r = ( ctype_r* )a; \ - ctype_r* restrict a_i = ( ctype_r* )a + is_a; \ - ctype_r* restrict a_rpi = ( ctype_r* )a + 2*is_a; \ -\ - ctype_r* restrict b_r = ( ctype_r* )b; \ - ctype_r* restrict b_i = ( ctype_r* )b + is_b; \ - ctype_r* restrict b_rpi = ( ctype_r* )b + 2*is_b; \ -\ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ -\ - ctype_r* restrict alpha_r = &PASTEMAC(ch,real)( *alpha ); \ - ctype_r* restrict alpha_i = &PASTEMAC(ch,imag)( *alpha ); \ -\ - const ctype_r beta_r = PASTEMAC(ch,real)( *beta ); \ - const ctype_r beta_i = PASTEMAC(ch,imag)( *beta ); \ -\ - void* a_next = bli_auxinfo_next_a( data ); \ - void* b_next = bli_auxinfo_next_b( data ); \ -\ - dim_t n_iter; \ - dim_t n_elem; \ -\ - inc_t incc, ldc; \ - inc_t incab, ldab; \ -\ - dim_t i, j; \ -\ -\ - /* SAFETY CHECK: The higher level implementation should never - allow an alpha with non-zero imaginary component to be passed - in, because it can't be applied properly using the 3m method. - If alpha is not real, then something is very wrong. */ \ - if ( !PASTEMAC(chr,eq0)( *alpha_i ) ) \ - bli_check_error_code( BLIS_NOT_YET_IMPLEMENTED ); \ -\ -\ - /* An optimization: Set local strides and loop bounds based on the - strides of c, so that (a) the micro-kernel accesses ct the same - way it would if it were updating c directly, and (b) c is updated - contiguously. For c with general stride, we access ct the same way - we would as if it were column-stored. */ \ - if ( bli_is_row_stored( rs_c, cs_c ) ) \ - { \ - rs_ab = n; n_iter = m; incc = cs_c; \ - cs_ab = 1; n_elem = n; ldc = rs_c; \ - } \ - else /* column-stored or general stride */ \ - { \ - rs_ab = 1; n_iter = n; incc = rs_c; \ - cs_ab = m; n_elem = m; ldc = cs_c; \ - } \ - incab = 1; \ - ldab = n_elem; \ -\ -\ - /* The following gemm micro-kernel calls implement all "phases" of the - 3m method: - - c = beta * c; - c_r += + a_r * b_r - a_i * b_i; - c_i += (a_r + a_i)(b_r + b_i) - a_r * b_r - a_i * b_i; - - NOTE: Scaling by alpha_r is not shown above, but is implemented - below. */ \ -\ -\ - bli_auxinfo_set_next_ab( a_i, b_i, data ); \ -\ - /* ab_r = alpha_r * a_r * b_r; */ \ - rgemm_ukr \ - ( \ - k, \ - alpha_r, \ - a_r, \ - b_r, \ - zero_r, \ - ab_r, rs_ab, cs_ab, \ - data, \ - cntx \ - ); \ -\ - bli_auxinfo_set_next_ab( a_rpi, b_rpi, data ); \ -\ - /* ab_i = alpha_r * a_i * b_i; */ \ - rgemm_ukr \ - ( \ - k, \ - alpha_r, \ - a_i, \ - b_i, \ - zero_r, \ - ab_i, rs_ab, cs_ab, \ - data, \ - cntx \ - ); \ -\ - bli_auxinfo_set_next_ab( a_next, b_next, data ); \ -\ - /* ct_i = alpha_r * a_ri * b_ri; */ \ - rgemm_ukr \ - ( \ - k, \ - alpha_r, \ - a_rpi, \ - b_rpi, \ - zero_r, \ - ab_rpi, rs_ab, cs_ab, \ - data, \ - cntx \ - ); \ -\ -\ - /* How we accumulate the intermediate matrix products stored in ab_r, - ab_i, and ab_rpi depends on the value of beta. */ \ - if ( !PASTEMAC(chr,eq0)( beta_i ) ) \ - { \ - /* c = beta * c; - c_r = c_r + ab_r - ab_i; - c_i = c_i + ab_rpi - ab_r - ab_i; */ \ - for ( j = 0; j < n_iter; ++j ) \ - for ( i = 0; i < n_elem; ++i ) \ - { \ - const ctype_r alphabeta11_r = *(ab_r + i*incab + j*ldab); \ - const ctype_r alphabeta11_i = *(ab_i + i*incab + j*ldab); \ - const ctype_r alphabeta11_rpi = *(ab_rpi + i*incab + j*ldab); \ - ctype* restrict gamma11 = c + i*incc + j*ldc ; \ - ctype_r* restrict gamma11_r = &PASTEMAC(ch,real)( *gamma11 ); \ - ctype_r* restrict gamma11_i = &PASTEMAC(ch,imag)( *gamma11 ); \ - ctype_r gamma11t_r; \ - ctype_r gamma11t_i; \ -\ - PASTEMAC(ch,copyris)( alphabeta11_r, \ - -alphabeta11_r, \ - gamma11t_r, \ - gamma11t_i ); \ -\ - PASTEMAC(ch,subris)( alphabeta11_i, \ - alphabeta11_i, \ - gamma11t_r, \ - gamma11t_i ); \ -\ - PASTEMAC(chr,adds)( alphabeta11_rpi, \ - gamma11t_i ); \ -\ - PASTEMAC(ch,xpbyris)( gamma11t_r, \ - gamma11t_i, \ - beta_r, \ - beta_i, \ - *gamma11_r, \ - *gamma11_i ); \ - } \ - } \ - else if ( PASTEMAC(chr,eq1)( beta_r ) ) \ - { \ - /* c_r = c_r + ab_r - ab_i; - c_i = c_i + ab_rpi - ab_r - ab_i; */ \ - for ( j = 0; j < n_iter; ++j ) \ - for ( i = 0; i < n_elem; ++i ) \ - { \ - const ctype_r alphabeta11_r = *(ab_r + i*incab + j*ldab); \ - const ctype_r alphabeta11_i = *(ab_i + i*incab + j*ldab); \ - const ctype_r alphabeta11_rpi = *(ab_rpi + i*incab + j*ldab); \ - ctype* restrict gamma11 = c + i*incc + j*ldc ; \ - ctype_r* restrict gamma11_r = &PASTEMAC(ch,real)( *gamma11 ); \ - ctype_r* restrict gamma11_i = &PASTEMAC(ch,imag)( *gamma11 ); \ - ctype_r gamma11t_r; \ - ctype_r gamma11t_i; \ -\ - PASTEMAC(ch,copyris)( alphabeta11_r, \ - -alphabeta11_r, \ - gamma11t_r, \ - gamma11t_i ); \ -\ - PASTEMAC(ch,subris)( alphabeta11_i, \ - alphabeta11_i, \ - gamma11t_r, \ - gamma11t_i ); \ -\ - PASTEMAC(chr,adds)( alphabeta11_rpi, \ - gamma11t_i ); \ -\ - PASTEMAC(ch,addris)( gamma11t_r, \ - gamma11t_i, \ - *gamma11_r, \ - *gamma11_i ); \ - } \ - } \ - else if ( !PASTEMAC(chr,eq0)( beta_r ) ) \ - { \ - /* c_r = beta_r * c_r + ab_r - ab_i; - c_i = beta_r * c_i + ab_rpi - ab_r - ab_i; */ \ - for ( j = 0; j < n_iter; ++j ) \ - for ( i = 0; i < n_elem; ++i ) \ - { \ - const ctype_r alphabeta11_r = *(ab_r + i*incab + j*ldab); \ - const ctype_r alphabeta11_i = *(ab_i + i*incab + j*ldab); \ - const ctype_r alphabeta11_rpi = *(ab_rpi + i*incab + j*ldab); \ - ctype* restrict gamma11 = c + i*incc + j*ldc ; \ - ctype_r* restrict gamma11_r = &PASTEMAC(ch,real)( *gamma11 ); \ - ctype_r* restrict gamma11_i = &PASTEMAC(ch,imag)( *gamma11 ); \ - ctype_r gamma11t_r; \ - ctype_r gamma11t_i; \ -\ - PASTEMAC(ch,copyris)( alphabeta11_r, \ - -alphabeta11_r, \ - gamma11t_r, \ - gamma11t_i ); \ -\ - PASTEMAC(ch,subris)( alphabeta11_i, \ - alphabeta11_i, \ - gamma11t_r, \ - gamma11t_i ); \ -\ - PASTEMAC(chr,adds)( alphabeta11_rpi, \ - gamma11t_i ); \ -\ - PASTEMAC(chr,xpbys)( gamma11t_r, beta_r, *gamma11_r ); \ - PASTEMAC(chr,xpbys)( gamma11t_i, beta_r, *gamma11_i ); \ - } \ - } \ - else /* if ( PASTEMAC(chr,eq0)( beta_r ) ) */ \ - { \ - /* c_r = ab_r - ab_i; - c_i = ab_rpi - ab_r - ab_i; */ \ - for ( j = 0; j < n_iter; ++j ) \ - for ( i = 0; i < n_elem; ++i ) \ - { \ - const ctype_r alphabeta11_r = *(ab_r + i*incab + j*ldab); \ - const ctype_r alphabeta11_i = *(ab_i + i*incab + j*ldab); \ - const ctype_r alphabeta11_rpi = *(ab_rpi + i*incab + j*ldab); \ - ctype* restrict gamma11 = c + i*incc + j*ldc ; \ - ctype_r* restrict gamma11_r = &PASTEMAC(ch,real)( *gamma11 ); \ - ctype_r* restrict gamma11_i = &PASTEMAC(ch,imag)( *gamma11 ); \ - ctype_r gamma11t_r; \ - ctype_r gamma11t_i; \ -\ - PASTEMAC(ch,copyris)( alphabeta11_r, \ - -alphabeta11_r, \ - gamma11t_r, \ - gamma11t_i ); \ -\ - PASTEMAC(ch,subris)( alphabeta11_i, \ - alphabeta11_i, \ - gamma11t_r, \ - gamma11t_i ); \ -\ - PASTEMAC(chr,adds)( alphabeta11_rpi, \ - gamma11t_i ); \ -\ - PASTEMAC(ch,copyris)( gamma11t_r, \ - gamma11t_i, \ - *gamma11_r, \ - *gamma11_i ); \ - } \ - } \ -} - -INSERT_GENTFUNCCO_BASIC2( gemm3m1, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) - diff --git a/ref_kernels/ind/bli_gemm3mh_ref.c b/ref_kernels/ind/bli_gemm3mh_ref.c deleted file mode 100644 index 1f242bc255..0000000000 --- a/ref_kernels/ind/bli_gemm3mh_ref.c +++ /dev/null @@ -1,297 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - dim_t k, \ - ctype* restrict alpha, \ - ctype* restrict a, \ - ctype* restrict b, \ - ctype* restrict beta, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - auxinfo_t* restrict data, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const num_t dt_r = PASTEMAC(chr,type); \ -\ - PASTECH(chr,gemm_ukr_ft) \ - rgemm_ukr = bli_cntx_get_l3_nat_ukr_dt( dt_r, BLIS_GEMM_UKR, cntx ); \ -\ - const dim_t mr = bli_cntx_get_blksz_def_dt( dt_r, BLIS_MR, cntx ); \ - const dim_t nr = bli_cntx_get_blksz_def_dt( dt_r, BLIS_NR, cntx ); \ -\ - const dim_t m = mr; \ - const dim_t n = nr; \ -\ - ctype_r ct[ BLIS_STACK_BUF_MAX_SIZE \ - / sizeof( ctype_r ) ] \ - __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - inc_t rs_ct; \ - inc_t cs_ct; \ -\ - ctype_r* restrict a_cast = ( ctype_r* )a; \ -\ - ctype_r* restrict b_cast = ( ctype_r* )b; \ -\ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ -\ - ctype_r* restrict alpha_r = &PASTEMAC(ch,real)( *alpha ); \ - ctype_r* restrict alpha_i = &PASTEMAC(ch,imag)( *alpha ); \ -\ - const ctype_r beta_r = PASTEMAC(ch,real)( *beta ); \ - const ctype_r beta_i = PASTEMAC(ch,imag)( *beta ); \ -\ - const pack_t schema = bli_auxinfo_schema_a( data ); \ -\ - dim_t n_iter; \ - dim_t n_elem; \ -\ - inc_t incc, ldc; \ - inc_t incct, ldct; \ -\ - dim_t i, j; \ -\ -\ - /* SAFETY CHECK: The higher level implementation should never - allow an alpha with non-zero imaginary component to be passed - in, because it can't be applied properly using the 3mh method. - If alpha is not real, then something is very wrong. */ \ - if ( !PASTEMAC(chr,eq0)( *alpha_i ) ) \ - bli_check_error_code( BLIS_NOT_YET_IMPLEMENTED ); \ -\ -\ - /* An optimization: Set local strides and loop bounds based on the - strides of c, so that (a) the micro-kernel accesses ct the same - way it would if it were updating c directly, and (b) c is updated - contiguously. For c with general stride, we access ct the same way - we would as if it were column-stored. */ \ - if ( bli_is_row_stored( rs_c, cs_c ) ) \ - { \ - rs_ct = n; n_iter = m; incc = cs_c; \ - cs_ct = 1; n_elem = n; ldc = rs_c; \ - } \ - else /* column-stored or general stride */ \ - { \ - rs_ct = 1; n_iter = n; incc = rs_c; \ - cs_ct = m; n_elem = m; ldc = cs_c; \ - } \ - incct = 1; \ - ldct = n_elem; \ -\ -\ - /* The following gemm micro-kernel call implements one "phase" of the - 3m method: - - c = beta * c; - c_r += + a_r * b_r - a_i * b_i; - c_i += (a_r + a_i)(b_r + b_i) - a_r * b_r - a_i * b_i; - - NOTE: Scaling by alpha_r is not shown above, but is implemented - below. */ \ -\ -\ - /* ct = alpha_r * a * b; */ \ - rgemm_ukr \ - ( \ - k, \ - alpha_r, \ - a_cast, \ - b_cast, \ - zero_r, \ - ct, rs_ct, cs_ct, \ - data, \ - cntx \ - ); \ -\ -/* -PASTEMAC(chr,fprintm)( stdout, "gemm3mh_ukr: ct", 4, 4, ct, rs_ct, cs_ct, "%4.1f", "" );*/ \ -\ - /* How we accumulate the intermediate matrix product stored in ct - depends on (a) the schemas of A and B (they are always the same), - and (b) the value of beta. */ \ - if ( bli_is_ro_packed( schema ) ) \ - { \ - if ( !PASTEMAC(chr,eq0)( beta_i ) ) \ - { \ - /* c = beta * c; - c_r = c_r + ct; - c_i = c_i - ct; */ \ - for ( j = 0; j < n_iter; ++j ) \ - for ( i = 0; i < n_elem; ++i ) \ - { \ - const ctype_r gamma11t = *(ct + i*incct + j*ldct); \ - ctype* restrict gamma11 = c + i*incc + j*ldc ; \ - ctype_r* restrict gamma11_r = &PASTEMAC(ch,real)( *gamma11 ); \ - ctype_r* restrict gamma11_i = &PASTEMAC(ch,imag)( *gamma11 ); \ -\ - PASTEMAC(ch,xpbyris)( gamma11t, \ - -gamma11t, \ - beta_r, \ - beta_i, \ - *gamma11_r, \ - *gamma11_i ); \ - } \ - } \ - else if ( PASTEMAC(chr,eq1)( beta_r ) ) \ - { \ - /* c_r = c_r + ct; - c_i = c_i - ct; */ \ - for ( j = 0; j < n_iter; ++j ) \ - for ( i = 0; i < n_elem; ++i ) \ - { \ - const ctype_r gamma11t = *(ct + i*incct + j*ldct); \ - ctype* restrict gamma11 = c + i*incc + j*ldc ; \ - ctype_r* restrict gamma11_r = &PASTEMAC(ch,real)( *gamma11 ); \ - ctype_r* restrict gamma11_i = &PASTEMAC(ch,imag)( *gamma11 ); \ -\ - PASTEMAC(chr,adds)( gamma11t, *gamma11_r ); \ - PASTEMAC(chr,subs)( gamma11t, *gamma11_i ); \ - } \ - } \ - else if ( !PASTEMAC(chr,eq0)( beta_r ) ) \ - { \ - /* c_r = beta_r * c_r + ct; - c_i = beta_r * c_i - ct; */ \ - for ( j = 0; j < n_iter; ++j ) \ - for ( i = 0; i < n_elem; ++i ) \ - { \ - const ctype_r gamma11t = *(ct + i*incct + j*ldct); \ - ctype* restrict gamma11 = c + i*incc + j*ldc ; \ - ctype_r* restrict gamma11_r = &PASTEMAC(ch,real)( *gamma11 ); \ - ctype_r* restrict gamma11_i = &PASTEMAC(ch,imag)( *gamma11 ); \ -\ - PASTEMAC(chr,xpbys)( gamma11t, beta_r, *gamma11_r ); \ - PASTEMAC(chr,xpbys)( -gamma11t, beta_r, *gamma11_i ); \ - } \ - } \ - else /* if PASTEMAC(chr,eq0)( beta_r ) */ \ - { \ - /* c_r = ct; - c_i = -ct; */ \ - for ( j = 0; j < n_iter; ++j ) \ - for ( i = 0; i < n_elem; ++i ) \ - { \ - const ctype_r gamma11t = *(ct + i*incct + j*ldct); \ - ctype* restrict gamma11 = c + i*incc + j*ldc ; \ - ctype_r* restrict gamma11_r = &PASTEMAC(ch,real)( *gamma11 ); \ - ctype_r* restrict gamma11_i = &PASTEMAC(ch,imag)( *gamma11 ); \ -\ - PASTEMAC(chr,copys)( gamma11t, *gamma11_r ); \ - PASTEMAC(chr,copys)( -gamma11t, *gamma11_i ); \ - } \ - } \ - } \ - else if ( bli_is_io_packed( schema ) ) \ - { \ - if ( PASTEMAC(chr,eq1)( beta_r ) ) \ - { \ - /* c_r = c_r - ct; - c_i = c_i - ct; */ \ - for ( j = 0; j < n_iter; ++j ) \ - for ( i = 0; i < n_elem; ++i ) \ - { \ - const ctype_r gamma11t = *(ct + i*incct + j*ldct); \ - ctype* restrict gamma11 = c + i*incc + j*ldc ; \ - ctype_r* restrict gamma11_r = &PASTEMAC(ch,real)( *gamma11 ); \ - ctype_r* restrict gamma11_i = &PASTEMAC(ch,imag)( *gamma11 ); \ -\ - PASTEMAC(chr,subs)( gamma11t, *gamma11_r ); \ - PASTEMAC(chr,subs)( gamma11t, *gamma11_i ); \ - } \ - } \ - else /* if PASTEMAC(chr,eq0)( beta_r ) */ \ - { \ - /* c_r = -ct; - c_i = -ct; */ \ - for ( j = 0; j < n_iter; ++j ) \ - for ( i = 0; i < n_elem; ++i ) \ - { \ - const ctype_r gamma11t = *(ct + i*incct + j*ldct); \ - ctype* restrict gamma11 = c + i*incc + j*ldc ; \ - ctype_r* restrict gamma11_r = &PASTEMAC(ch,real)( *gamma11 ); \ - ctype_r* restrict gamma11_i = &PASTEMAC(ch,imag)( *gamma11 ); \ -\ - PASTEMAC(chr,copys)( -gamma11t, *gamma11_r ); \ - PASTEMAC(chr,copys)( -gamma11t, *gamma11_i ); \ - } \ - } \ - } \ - else /* if ( bli_is_rpi_packed( schema ) ) */ \ - { \ - if ( PASTEMAC(chr,eq1)( beta_r ) ) \ - { \ - /* c_r = c_r + 0; - c_i = c_i + ct; */ \ - for ( j = 0; j < n_iter; ++j ) \ - for ( i = 0; i < n_elem; ++i ) \ - { \ - const ctype_r gamma11t = *(ct + i*incct + j*ldct); \ - ctype* restrict gamma11 = c + i*incc + j*ldc ; \ - ctype_r* restrict gamma11_i = &PASTEMAC(ch,imag)( *gamma11 ); \ -\ - PASTEMAC(chr,adds)( gamma11t, *gamma11_i ); \ - } \ - } \ - else /* if PASTEMAC(chr,eq0)( beta_r ) */ \ - { \ - /* c_r = 0; - c_i = ct; */ \ - for ( j = 0; j < n_iter; ++j ) \ - for ( i = 0; i < n_elem; ++i ) \ - { \ - const ctype_r gamma11t = *(ct + i*incct + j*ldct); \ - ctype* restrict gamma11 = c + i*incc + j*ldc ; \ - ctype_r* restrict gamma11_r = &PASTEMAC(ch,real)( *gamma11 ); \ - ctype_r* restrict gamma11_i = &PASTEMAC(ch,imag)( *gamma11 ); \ -\ - PASTEMAC(chr,set0s)( *gamma11_r ); \ - PASTEMAC(chr,copys)( gamma11t, *gamma11_i ); \ - } \ - } \ - } \ -\ -/*PASTEMAC(ch,fprintm)( stdout, "gemm3mh_ukr: c", 4, 4, c, rs_c, cs_c, "%4.1f", "" ); \ -*/ \ -\ -/*PASTEMAC(chr,fprintm)( stdout, "gemm3mh_ukr: b1", k, n, b_cast, n, 1, "%4.1f", "" ); \ -PASTEMAC(chr,fprintm)( stdout, "gemm3mh_ukr: a1", m, k, a_cast, 1, m, "%4.1f", "" );*/ \ -} - -INSERT_GENTFUNCCO_BASIC2( gemm3mh, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) diff --git a/ref_kernels/ind/bli_gemm4m1_ref.c b/ref_kernels/ind/bli_gemm4m1_ref.c deleted file mode 100644 index e214985156..0000000000 --- a/ref_kernels/ind/bli_gemm4m1_ref.c +++ /dev/null @@ -1,291 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - dim_t k, \ - ctype* restrict alpha, \ - ctype* restrict a, \ - ctype* restrict b, \ - ctype* restrict beta, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - auxinfo_t* restrict data, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const num_t dt_r = PASTEMAC(chr,type); \ -\ - PASTECH(chr,gemm_ukr_ft) \ - rgemm_ukr = bli_cntx_get_l3_nat_ukr_dt( dt_r, BLIS_GEMM_UKR, cntx ); \ -\ - const dim_t mr = bli_cntx_get_blksz_def_dt( dt_r, BLIS_MR, cntx ); \ - const dim_t nr = bli_cntx_get_blksz_def_dt( dt_r, BLIS_NR, cntx ); \ -\ - const dim_t m = mr; \ - const dim_t n = nr; \ -\ - ctype_r ct_r[ BLIS_STACK_BUF_MAX_SIZE \ - / sizeof( ctype_r ) ] \ - __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - ctype_r ct_i[ BLIS_STACK_BUF_MAX_SIZE \ - / sizeof( ctype_r ) ] \ - __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - inc_t rs_ct; \ - inc_t cs_ct; \ -\ - const inc_t is_a = bli_auxinfo_is_a( data ); \ - const inc_t is_b = bli_auxinfo_is_b( data ); \ -\ - ctype_r* restrict a_r = ( ctype_r* )a; \ - ctype_r* restrict a_i = ( ctype_r* )a + is_a; \ -\ - ctype_r* restrict b_r = ( ctype_r* )b; \ - ctype_r* restrict b_i = ( ctype_r* )b + is_b; \ -\ - ctype_r* restrict one_r = PASTEMAC(chr,1); \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ -\ - ctype_r* restrict alpha_r = &PASTEMAC(ch,real)( *alpha ); \ - ctype_r* restrict alpha_i = &PASTEMAC(ch,imag)( *alpha ); \ -\ - ctype_r m_alpha_r = -(*alpha_r); \ -\ - const ctype_r beta_r = PASTEMAC(ch,real)( *beta ); \ - const ctype_r beta_i = PASTEMAC(ch,imag)( *beta ); \ -\ - void* a_next = bli_auxinfo_next_a( data ); \ - void* b_next = bli_auxinfo_next_b( data ); \ -\ - dim_t n_iter; \ - dim_t n_elem; \ -\ - inc_t incc, ldc; \ - inc_t incct, ldct; \ -\ - dim_t i, j; \ -\ -\ -/* -PASTEMAC(chr,fprintm)( stdout, "gemm4m1_ukr: ap_r", m, k, \ - a_r, 1, PASTEMAC(chr,packmr), "%4.1f", "" ); \ -PASTEMAC(chr,fprintm)( stdout, "gemm4m1_ukr: ap_i", m, k, \ - a_i, 1, PASTEMAC(chr,packmr), "%4.1f", "" ); \ -PASTEMAC(chr,fprintm)( stdout, "gemm4m1_ukr: bp_r", k, n, \ - b_r, PASTEMAC(chr,packnr), 1, "%4.1f", "" ); \ -PASTEMAC(chr,fprintm)( stdout, "gemm4m1_ukr: bp_i", k, n, \ - b_i, PASTEMAC(chr,packnr), 1, "%4.1f", "" ); \ -*/ \ -\ -\ - /* SAFETY CHECK: The higher level implementation should never - allow an alpha with non-zero imaginary component to be passed - in, because it can't be applied properly using the 4m method. - If alpha is not real, then something is very wrong. */ \ - if ( !PASTEMAC(chr,eq0)( *alpha_i ) ) \ - bli_check_error_code( BLIS_NOT_YET_IMPLEMENTED ); \ -\ -\ - /* An optimization: Set local strides and loop bounds based on the - strides of c, so that (a) the micro-kernel accesses ct the same - way it would if it were updating c directly, and (b) c is updated - contiguously. For c with general stride, we access ct the same way - we would as if it were column-stored. */ \ - if ( bli_is_row_stored( rs_c, cs_c ) ) \ - { \ - rs_ct = n; n_iter = m; incc = cs_c; \ - cs_ct = 1; n_elem = n; ldc = rs_c; \ - } \ - else /* column-stored or general stride */ \ - { \ - rs_ct = 1; n_iter = n; incc = rs_c; \ - cs_ct = m; n_elem = m; ldc = cs_c; \ - } \ - incct = 1; \ - ldct = n_elem; \ -\ -\ - /* The following gemm micro-kernel calls implement all "phases" of - the 4m method: - - c = beta * c; - c_r += a_r * b_r - a_i * b_i; - c_i += a_r * b_i + a_i * b_r; - - NOTE: Scaling by alpha_r is not shown above, but is implemented - below. */ \ -\ -\ - bli_auxinfo_set_next_ab( a_r, b_i, data ); \ -\ - /* ct_r = alpha_r * a_r * b_r; */ \ - rgemm_ukr \ - ( \ - k, \ - alpha_r, \ - a_r, \ - b_r, \ - zero_r, \ - ct_r, rs_ct, cs_ct, \ - data, \ - cntx \ - ); \ -\ - bli_auxinfo_set_next_ab( a_i, b_r, data ); \ -\ - /* ct_i = alpha_r * a_r * b_i; */ \ - rgemm_ukr \ - ( \ - k, \ - alpha_r, \ - a_r, \ - b_i, \ - zero_r, \ - ct_i, rs_ct, cs_ct, \ - data, \ - cntx \ - ); \ -\ - bli_auxinfo_set_next_ab( a_i, b_i, data ); \ -\ - /* ct_i += alpha_r * a_i * b_r; */ \ - rgemm_ukr \ - ( \ - k, \ - alpha_r, \ - a_i, \ - b_r, \ - one_r, \ - ct_i, rs_ct, cs_ct, \ - data, \ - cntx \ - ); \ -\ - bli_auxinfo_set_next_ab( a_next, b_next, data ); \ -\ - /* ct_r += -alpha_r * a_i * b_i; */ \ - rgemm_ukr \ - ( \ - k, \ - &m_alpha_r, \ - a_i, \ - b_i, \ - one_r, \ - ct_r, rs_ct, cs_ct, \ - data, \ - cntx \ - ); \ -\ -\ - /* How we accumulate the intermediate matrix product stored in ct_r - and ct_i depends on the value of beta. */ \ - if ( !PASTEMAC(chr,eq0)( beta_i ) ) \ - { \ - /* c = beta * c + ct; */ \ - for ( j = 0; j < n_iter; ++j ) \ - for ( i = 0; i < n_elem; ++i ) \ - { \ - const ctype_r gamma11t_r = *(ct_r + i*incct + j*ldct); \ - const ctype_r gamma11t_i = *(ct_i + i*incct + j*ldct); \ - ctype* restrict gamma11 = c + i*incc + j*ldc ; \ - ctype_r* restrict gamma11_r = &PASTEMAC(ch,real)( *gamma11 ); \ - ctype_r* restrict gamma11_i = &PASTEMAC(ch,imag)( *gamma11 ); \ -\ - PASTEMAC(ch,xpbyris)( gamma11t_r, \ - gamma11t_i, \ - beta_r, \ - beta_i, \ - *gamma11_r, \ - *gamma11_i ); \ - } \ - } \ - else if ( PASTEMAC(chr,eq1)( beta_r ) ) \ - { \ - /* c_r = c_r + ct_r; */ \ - /* c_i = c_i + ct_i; */ \ - for ( j = 0; j < n_iter; ++j ) \ - for ( i = 0; i < n_elem; ++i ) \ - { \ - const ctype_r gamma11t_r = *(ct_r + i*incct + j*ldct); \ - const ctype_r gamma11t_i = *(ct_i + i*incct + j*ldct); \ - ctype* restrict gamma11 = c + i*incc + j*ldc ; \ - ctype_r* restrict gamma11_r = &PASTEMAC(ch,real)( *gamma11 ); \ - ctype_r* restrict gamma11_i = &PASTEMAC(ch,imag)( *gamma11 ); \ -\ - PASTEMAC(chr,adds)( gamma11t_r, *gamma11_r ); \ - PASTEMAC(chr,adds)( gamma11t_i, *gamma11_i ); \ - } \ - } \ - else if ( !PASTEMAC(chr,eq0)( beta_r ) ) \ - { \ - /* c_r = beta_r * c_r + ct_r; */ \ - /* c_i = beta_r * c_i + ct_i; */ \ - for ( j = 0; j < n_iter; ++j ) \ - for ( i = 0; i < n_elem; ++i ) \ - { \ - const ctype_r gamma11t_r = *(ct_r + i*incct + j*ldct); \ - const ctype_r gamma11t_i = *(ct_i + i*incct + j*ldct); \ - ctype* restrict gamma11 = c + i*incc + j*ldc ; \ - ctype_r* restrict gamma11_r = &PASTEMAC(ch,real)( *gamma11 ); \ - ctype_r* restrict gamma11_i = &PASTEMAC(ch,imag)( *gamma11 ); \ -\ - PASTEMAC(chr,xpbys)( gamma11t_r, beta_r, *gamma11_r ); \ - PASTEMAC(chr,xpbys)( gamma11t_i, beta_r, *gamma11_i ); \ - } \ - } \ - else /* if PASTEMAC(chr,eq0)( beta_r ) */ \ - { \ - /* c_r = ct_r; */ \ - /* c_i = ct_i; */ \ - for ( j = 0; j < n_iter; ++j ) \ - for ( i = 0; i < n_elem; ++i ) \ - { \ - const ctype_r gamma11t_r = *(ct_r + i*incct + j*ldct); \ - const ctype_r gamma11t_i = *(ct_i + i*incct + j*ldct); \ - ctype* restrict gamma11 = c + i*incc + j*ldc ; \ - ctype_r* restrict gamma11_r = &PASTEMAC(ch,real)( *gamma11 ); \ - ctype_r* restrict gamma11_i = &PASTEMAC(ch,imag)( *gamma11 ); \ -\ - PASTEMAC(chr,copys)( gamma11t_r, *gamma11_r ); \ - PASTEMAC(chr,copys)( gamma11t_i, *gamma11_i ); \ - } \ - } \ -} - -INSERT_GENTFUNCCO_BASIC2( gemm4m1, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) - diff --git a/ref_kernels/ind/bli_gemm4mb_ref.c b/ref_kernels/ind/bli_gemm4mb_ref.c deleted file mode 100644 index 12a6d46649..0000000000 --- a/ref_kernels/ind/bli_gemm4mb_ref.c +++ /dev/null @@ -1,345 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - dim_t k, \ - ctype* restrict alpha, \ - ctype* restrict a, \ - ctype* restrict b, \ - ctype* restrict beta, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - auxinfo_t* restrict data, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const num_t dt_r = PASTEMAC(chr,type); \ -\ - PASTECH(chr,gemm_ukr_ft) \ - rgemm_ukr = bli_cntx_get_l3_nat_ukr_dt( dt_r, BLIS_GEMM_UKR, cntx ); \ -\ - const dim_t mr = bli_cntx_get_blksz_def_dt( dt_r, BLIS_MR, cntx ); \ - const dim_t nr = bli_cntx_get_blksz_def_dt( dt_r, BLIS_NR, cntx ); \ -\ - const dim_t m = mr; \ - const dim_t n = nr; \ -\ - ctype_r ct_r[ BLIS_STACK_BUF_MAX_SIZE \ - / sizeof( ctype_r ) ] \ - __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - ctype_r ct_i[ BLIS_STACK_BUF_MAX_SIZE \ - / sizeof( ctype_r ) ] \ - __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - inc_t rs_ct; \ - inc_t cs_ct; \ -\ - const inc_t is_a = bli_auxinfo_is_a( data ); \ - const inc_t is_b = bli_auxinfo_is_b( data ); \ -\ - ctype_r* restrict a_r = ( ctype_r* )a; \ - ctype_r* restrict a_i = ( ctype_r* )a + is_a; \ -\ - ctype_r* restrict b_r = ( ctype_r* )b; \ - ctype_r* restrict b_i = ( ctype_r* )b + is_b; \ -\ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ -\ - ctype_r* restrict alpha_r = &PASTEMAC(ch,real)( *alpha ); \ - ctype_r* restrict alpha_i = &PASTEMAC(ch,imag)( *alpha ); \ -\ - const ctype_r beta_r = PASTEMAC(ch,real)( *beta ); \ - const ctype_r beta_i = PASTEMAC(ch,imag)( *beta ); \ -\ - ctype_r m_alpha_r = -PASTEMAC(ch,real)( *alpha ); \ -\ - const pack_t schema_b = bli_auxinfo_schema_b( data ); \ -\ - void* a_next = bli_auxinfo_next_a( data ); \ - void* b_next = bli_auxinfo_next_b( data ); \ -\ - dim_t n_iter; \ - dim_t n_elem; \ -\ - inc_t incc, ldc; \ - inc_t incct, ldct; \ -\ - dim_t i, j; \ -\ -\ - /* SAFETY CHECK: The higher level implementation should never - allow an alpha with non-zero imaginary component to be passed - in, because it can't be applied properly using the 4mb method. - If alpha is not real, then something is very wrong. */ \ - if ( !PASTEMAC(chr,eq0)( *alpha_i ) ) \ - bli_check_error_code( BLIS_NOT_YET_IMPLEMENTED ); \ -\ -\ - /* An optimization: Set local strides and loop bounds based on the - strides of c, so that (a) the micro-kernel accesses ct the same - way it would if it were updating c directly, and (b) c is updated - contiguously. For c with general stride, we access ct the same way - we would as if it were column-stored. */ \ - if ( bli_is_row_stored( rs_c, cs_c ) ) \ - { \ - rs_ct = n; n_iter = m; incc = cs_c; \ - cs_ct = 1; n_elem = n; ldc = rs_c; \ - } \ - else /* column-stored or general stride */ \ - { \ - rs_ct = 1; n_iter = n; incc = rs_c; \ - cs_ct = m; n_elem = m; ldc = cs_c; \ - } \ - incct = 1; \ - ldct = n_elem; \ -\ -\ -\ - if ( bli_is_ro_packed( schema_b ) ) \ - { \ - /* The following gemm micro-kernel calls implement the first half of - the 4mb method (which uses b_r): - - c = beta * c; - c_r += a_r * b_r; - c_i += a_i * b_r; - - NOTE: Scaling by alpha_r is not shown above, but is implemented - below. */ \ -\ - bli_auxinfo_set_next_ab( a_i, b_r, data ); \ -\ - rgemm_ukr \ - ( \ - k, \ - alpha_r, \ - a_r, \ - b_r, \ - zero_r, \ - ct_r, rs_ct, cs_ct, \ - data, \ - cntx \ - ); \ -\ - bli_auxinfo_set_next_ab( a_next, b_next, data ); \ -\ - rgemm_ukr \ - ( \ - k, \ - alpha_r, \ - a_i, \ - b_r, \ - zero_r, \ - ct_i, rs_ct, cs_ct, \ - data, \ - cntx \ - ); \ - } \ - else /* if ( bli_is_io_packed( schema_b ) ) */ \ - { \ - /* The following gemm micro-kernel calls implement the second half of - the 4mb method (which uses b_i): - - c_r += -a_i * b_i; - c_i += a_r * b_i; - - NOTE: Scaling by alpha_r is not shown above, but is implemented - below. */ \ -\ - bli_auxinfo_set_next_ab( a_i, b_i, data ); \ -\ - rgemm_ukr \ - ( \ - k, \ - alpha_r, \ - a_r, \ - b_i, \ - zero_r, \ - ct_i, rs_ct, cs_ct, \ - data, \ - cntx \ - ); \ -\ - bli_auxinfo_set_next_ab( a_next, b_next, data ); \ -\ - rgemm_ukr \ - ( \ - k, \ - &m_alpha_r, \ - a_i, \ - b_i, \ - zero_r, \ - ct_r, rs_ct, cs_ct, \ - data, \ - cntx \ - ); \ - } \ -\ -\ -\ - /* How we accumulate the intermediate matrix product stored in ct_r - and ct_i depends on (a) the schema of B, and (b) the value of - beta. */ \ - if ( bli_is_ro_packed( schema_b ) ) \ - { \ - if ( !PASTEMAC(chr,eq0)( beta_i ) ) \ - { \ - /* c = beta * c + ct; */ \ - for ( j = 0; j < n_iter; ++j ) \ - for ( i = 0; i < n_elem; ++i ) \ - { \ - const ctype_r gamma11t_r = *(ct_r + i*incct + j*ldct); \ - const ctype_r gamma11t_i = *(ct_i + i*incct + j*ldct); \ - ctype* restrict gamma11 = c + i*incc + j*ldc ; \ - ctype_r* restrict gamma11_r = &PASTEMAC(ch,real)( *gamma11 ); \ - ctype_r* restrict gamma11_i = &PASTEMAC(ch,imag)( *gamma11 ); \ -\ - PASTEMAC(ch,xpbyris)( gamma11t_r, \ - gamma11t_i, \ - beta_r, \ - beta_i, \ - *gamma11_r, \ - *gamma11_i ); \ - } \ - } \ - else if ( PASTEMAC(chr,eq1)( beta_r ) ) \ - { \ - /* c_r = c_r + ct_r; */ \ - /* c_i = c_i + ct_i; */ \ - for ( j = 0; j < n_iter; ++j ) \ - for ( i = 0; i < n_elem; ++i ) \ - { \ - const ctype_r gamma11t_r = *(ct_r + i*incct + j*ldct); \ - const ctype_r gamma11t_i = *(ct_i + i*incct + j*ldct); \ - ctype* restrict gamma11 = c + i*incc + j*ldc ; \ - ctype_r* restrict gamma11_r = &PASTEMAC(ch,real)( *gamma11 ); \ - ctype_r* restrict gamma11_i = &PASTEMAC(ch,imag)( *gamma11 ); \ -\ - PASTEMAC(chr,adds)( gamma11t_r, *gamma11_r ); \ - PASTEMAC(chr,adds)( gamma11t_i, *gamma11_i ); \ - } \ - } \ - else if ( !PASTEMAC(chr,eq0)( beta_r ) ) \ - { \ - /* c_r = beta_r * c_r + ct_r; */ \ - /* c_i = beta_r * c_i + ct_i; */ \ - for ( j = 0; j < n_iter; ++j ) \ - for ( i = 0; i < n_elem; ++i ) \ - { \ - const ctype_r gamma11t_r = *(ct_r + i*incct + j*ldct); \ - const ctype_r gamma11t_i = *(ct_i + i*incct + j*ldct); \ - ctype* restrict gamma11 = c + i*incc + j*ldc ; \ - ctype_r* restrict gamma11_r = &PASTEMAC(ch,real)( *gamma11 ); \ - ctype_r* restrict gamma11_i = &PASTEMAC(ch,imag)( *gamma11 ); \ -\ - PASTEMAC(chr,xpbys)( gamma11t_r, beta_r, *gamma11_r ); \ - PASTEMAC(chr,xpbys)( gamma11t_i, beta_r, *gamma11_i ); \ - } \ - } \ - else /* if PASTEMAC(chr,eq0)( beta_r ) */ \ - { \ - /* c_r = ct_r; */ \ - /* c_i = ct_i; */ \ - for ( j = 0; j < n_iter; ++j ) \ - for ( i = 0; i < n_elem; ++i ) \ - { \ - const ctype_r gamma11t_r = *(ct_r + i*incct + j*ldct); \ - const ctype_r gamma11t_i = *(ct_i + i*incct + j*ldct); \ - ctype* restrict gamma11 = c + i*incc + j*ldc ; \ - ctype_r* restrict gamma11_r = &PASTEMAC(ch,real)( *gamma11 ); \ - ctype_r* restrict gamma11_i = &PASTEMAC(ch,imag)( *gamma11 ); \ -\ - PASTEMAC(chr,copys)( gamma11t_r, *gamma11_r ); \ - PASTEMAC(chr,copys)( gamma11t_i, *gamma11_i ); \ - } \ - } \ - } \ - else /* if ( bli_is_io_packed( schema_b ) ) */ \ - { \ - /* NOTE: If this branch executes, it means we are in the second - half of the 4mb computation in which we multiply the b_i - sub-panel by the entire block of A. Here, we know that beta - will either be equal to one (for interior cases within gemm - macro-kernel), or zero (for edge cases). */ \ -\ - if ( PASTEMAC(chr,eq1)( beta_r ) ) \ - { \ - /* c_r = c_r + ct_r; */ \ - /* c_i = c_i + ct_i; */ \ - for ( j = 0; j < n_iter; ++j ) \ - for ( i = 0; i < n_elem; ++i ) \ - { \ - const ctype_r gamma11t_r = *(ct_r + i*incct + j*ldct); \ - const ctype_r gamma11t_i = *(ct_i + i*incct + j*ldct); \ - ctype* restrict gamma11 = c + i*incc + j*ldc ; \ - ctype_r* restrict gamma11_r = &PASTEMAC(ch,real)( *gamma11 ); \ - ctype_r* restrict gamma11_i = &PASTEMAC(ch,imag)( *gamma11 ); \ -\ - PASTEMAC(chr,adds)( gamma11t_r, *gamma11_r ); \ - PASTEMAC(chr,adds)( gamma11t_i, *gamma11_i ); \ - } \ - } \ - else /* if PASTEMAC(chr,eq0)( beta_r ) */ \ - { \ - /* c_r = ct_r; */ \ - /* c_i = ct_i; */ \ - for ( j = 0; j < n_iter; ++j ) \ - for ( i = 0; i < n_elem; ++i ) \ - { \ - const ctype_r gamma11t_r = *(ct_r + i*incct + j*ldct); \ - const ctype_r gamma11t_i = *(ct_i + i*incct + j*ldct); \ - ctype* restrict gamma11 = c + i*incc + j*ldc ; \ - ctype_r* restrict gamma11_r = &PASTEMAC(ch,real)( *gamma11 ); \ - ctype_r* restrict gamma11_i = &PASTEMAC(ch,imag)( *gamma11 ); \ -\ - PASTEMAC(chr,copys)( gamma11t_r, *gamma11_r ); \ - PASTEMAC(chr,copys)( gamma11t_i, *gamma11_i ); \ - } \ - } \ - } \ -\ -/*PASTEMAC(chr,fprintm)( stdout, "gemm4mb_ukr: b1_r", k, n, b_r, n, 1, "%4.1f", "" ); \ -PASTEMAC(chr,fprintm)( stdout, "gemm4mb_ukr: b1_i", k, n, b_i, n, 1, "%4.1f", "" );*/ \ -/*PASTEMAC(chr,fprintm)( stdout, "gemm4mb_ukr: a1_r", m, k, a_r, 1, m, "%4.1f", "" ); \ -PASTEMAC(chr,fprintm)( stdout, "gemm4mb_ukr: a1_i", m, k, a_i, 1, m, "%4.1f", "" );*/ \ -/*PASTEMAC(chr,fprintm)( stdout, "gemm4mb_ukr: ct_r", 8, 6, ct_r, rs_ct, cs_ct, "%4.1f", "" ); \ -PASTEMAC(chr,fprintm)( stdout, "gemm4mb_ukr: ct_i", 8, 6, ct_i, rs_ct, cs_ct, "%4.1f", "" );*/ \ -} - -INSERT_GENTFUNCCO_BASIC2( gemm4mb, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) - diff --git a/ref_kernels/ind/bli_gemm4mh_ref.c b/ref_kernels/ind/bli_gemm4mh_ref.c deleted file mode 100644 index afa76ce761..0000000000 --- a/ref_kernels/ind/bli_gemm4mh_ref.c +++ /dev/null @@ -1,286 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - dim_t k, \ - ctype* restrict alpha, \ - ctype* restrict a, \ - ctype* restrict b, \ - ctype* restrict beta, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - auxinfo_t* restrict data, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const num_t dt_r = PASTEMAC(chr,type); \ -\ - PASTECH(chr,gemm_ukr_ft) \ - rgemm_ukr = bli_cntx_get_l3_nat_ukr_dt( dt_r, BLIS_GEMM_UKR, cntx ); \ -\ - const dim_t mr = bli_cntx_get_blksz_def_dt( dt_r, BLIS_MR, cntx ); \ - const dim_t nr = bli_cntx_get_blksz_def_dt( dt_r, BLIS_NR, cntx ); \ -\ - const dim_t m = mr; \ - const dim_t n = nr; \ -\ - ctype_r ct[ BLIS_STACK_BUF_MAX_SIZE \ - / sizeof( ctype_r ) ] \ - __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - inc_t rs_ct; \ - inc_t cs_ct; \ -\ - ctype_r* restrict a_cast = ( ctype_r* )a; \ -\ - ctype_r* restrict b_cast = ( ctype_r* )b; \ -\ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ -\ - ctype_r* restrict alpha_r = &PASTEMAC(ch,real)( *alpha ); \ - ctype_r* restrict alpha_i = &PASTEMAC(ch,imag)( *alpha ); \ -\ - const ctype_r beta_r = PASTEMAC(ch,real)( *beta ); \ - const ctype_r beta_i = PASTEMAC(ch,imag)( *beta ); \ -\ - const pack_t schema_a = bli_auxinfo_schema_a( data ); \ - const pack_t schema_b = bli_auxinfo_schema_b( data ); \ -\ - dim_t n_iter; \ - dim_t n_elem; \ -\ - inc_t incc, ldc; \ - inc_t incct, ldct; \ -\ - dim_t i, j; \ -\ -\ - /* SAFETY CHECK: The higher level implementation should never - allow an alpha with non-zero imaginary component to be passed - in, because it can't be applied properly using the 4mh method. - If alpha is not real, then something is very wrong. */ \ - if ( !PASTEMAC(chr,eq0)( *alpha_i ) ) \ - bli_check_error_code( BLIS_NOT_YET_IMPLEMENTED ); \ -\ -\ - /* An optimization: Set local strides and loop bounds based on the - strides of c, so that (a) the micro-kernel accesses ct the same - way it would if it were updating c directly, and (b) c is updated - contiguously. For c with general stride, we access ct the same way - we would as if it were column-stored. */ \ - if ( bli_is_row_stored( rs_c, cs_c ) ) \ - { \ - rs_ct = n; n_iter = m; incc = cs_c; \ - cs_ct = 1; n_elem = n; ldc = rs_c; \ - } \ - else /* column-stored or general stride */ \ - { \ - rs_ct = 1; n_iter = n; incc = rs_c; \ - cs_ct = m; n_elem = m; ldc = cs_c; \ - } \ - incct = 1; \ - ldct = n_elem; \ -\ -\ - /* The following gemm micro-kernel call implement one "phase" of the - 4m method: - - c = beta * c; - c_r += a_r * b_r - a_i * b_i; - c_i += a_r * b_i + a_i * b_r; - - NOTE: Scaling by alpha_r is not shown above, but is implemented - below. */ \ -\ -\ - /* ct = alpha_r * a * b; */ \ - rgemm_ukr \ - ( \ - k, \ - alpha_r, \ - a_cast, \ - b_cast, \ - zero_r, \ - ct, rs_ct, cs_ct, \ - data, \ - cntx \ - ); \ -\ -\ - /* How we accumulate the intermediate matrix product stored in ct - depends on (a) the schemas of A and B, and (b) the value of - beta. */ \ - if ( bli_is_ro_packed( schema_a ) && \ - bli_is_ro_packed( schema_b ) ) \ - { \ - if ( !PASTEMAC(chr,eq0)( beta_i ) ) \ - { \ - /* c = beta * c; - c_r = c_r + ct; */ \ - for ( j = 0; j < n_iter; ++j ) \ - for ( i = 0; i < n_elem; ++i ) \ - { \ - const ctype_r gamma11t = *(ct + i*incct + j*ldct); \ - ctype* restrict gamma11 = c + i*incc + j*ldc ; \ - ctype_r* restrict gamma11_r = &PASTEMAC(ch,real)( *gamma11 ); \ -\ - PASTEMAC(ch,scals)( *beta, *gamma11 ); \ - PASTEMAC(chr,adds)( gamma11t, *gamma11_r ); \ - } \ - } \ - else if ( PASTEMAC(chr,eq1)( beta_r ) ) \ - { \ - /* c_r = c_r + ct; - c_i = c_i; */ \ - for ( j = 0; j < n_iter; ++j ) \ - for ( i = 0; i < n_elem; ++i ) \ - { \ - const ctype_r gamma11t = *(ct + i*incct + j*ldct); \ - ctype* restrict gamma11 = c + i*incc + j*ldc ; \ - ctype_r* restrict gamma11_r = &PASTEMAC(ch,real)( *gamma11 ); \ -\ - PASTEMAC(chr,adds)( gamma11t, *gamma11_r ); \ - } \ - } \ - else if ( !PASTEMAC(chr,eq0)( beta_r ) ) \ - { \ - /* c_r = beta_r * c_r + ct; - c_i = beta_r * c_i; */ \ - for ( j = 0; j < n_iter; ++j ) \ - for ( i = 0; i < n_elem; ++i ) \ - { \ - const ctype_r gamma11t = *(ct + i*incct + j*ldct); \ - ctype* restrict gamma11 = c + i*incc + j*ldc ; \ - ctype_r* restrict gamma11_r = &PASTEMAC(ch,real)( *gamma11 ); \ - ctype_r* restrict gamma11_i = &PASTEMAC(ch,imag)( *gamma11 ); \ -\ - PASTEMAC(chr,xpbys)( gamma11t, beta_r, *gamma11_r ); \ - PASTEMAC(chr,scals)( beta_r, *gamma11_i ); \ - } \ - } \ - else /* if PASTEMAC(chr,eq0)( beta_r ) */ \ - { \ - /* c_r = ct; - c_i = 0; */ \ - for ( j = 0; j < n_iter; ++j ) \ - for ( i = 0; i < n_elem; ++i ) \ - { \ - const ctype_r gamma11t = *(ct + i*incct + j*ldct); \ - ctype* restrict gamma11 = c + i*incc + j*ldc ; \ - ctype_r* restrict gamma11_r = &PASTEMAC(ch,real)( *gamma11 ); \ - ctype_r* restrict gamma11_i = &PASTEMAC(ch,imag)( *gamma11 ); \ -\ - PASTEMAC(chr,copys)( gamma11t, *gamma11_r ); \ - PASTEMAC(chr,set0s)( *gamma11_i ); \ - } \ - } \ - } \ - else if ( ( bli_is_ro_packed( schema_a ) && \ - bli_is_io_packed( schema_b ) ) || \ - ( bli_is_io_packed( schema_a ) && \ - bli_is_ro_packed( schema_b ) ) \ - ) \ - { \ - if ( PASTEMAC(chr,eq1)( beta_r ) ) \ - { \ - /* c_r = c_r + 0; - c_i = c_i + ct; */ \ - for ( j = 0; j < n_iter; ++j ) \ - for ( i = 0; i < n_elem; ++i ) \ - { \ - const ctype_r gamma11t = *(ct + i*incct + j*ldct); \ - ctype* restrict gamma11 = c + i*incc + j*ldc ; \ - ctype_r* restrict gamma11_i = &PASTEMAC(ch,imag)( *gamma11 ); \ -\ - PASTEMAC(chr,adds)( gamma11t, *gamma11_i ); \ - } \ - } \ - else /* if PASTEMAC(chr,eq0)( beta_r ) */ \ - { \ - /* c_r = 0; - c_i = ct; */ \ - for ( j = 0; j < n_iter; ++j ) \ - for ( i = 0; i < n_elem; ++i ) \ - { \ - const ctype_r gamma11t = *(ct + i*incct + j*ldct); \ - ctype* restrict gamma11 = c + i*incc + j*ldc ; \ - ctype_r* restrict gamma11_r = &PASTEMAC(ch,real)( *gamma11 ); \ - ctype_r* restrict gamma11_i = &PASTEMAC(ch,imag)( *gamma11 ); \ -\ - PASTEMAC(chr,set0s)( *gamma11_r ); \ - PASTEMAC(chr,copys)( gamma11t, *gamma11_i ); \ - } \ - } \ - } \ - else /* if ( bli_is_io_packed( schema_a ) && \ - bli_is_io_packed( schema_b ) ) */ \ - { \ - if ( PASTEMAC(chr,eq1)( beta_r ) ) \ - { \ - /* c_r = c_r - ct; - c_i = c_i + 0; */ \ - for ( j = 0; j < n_iter; ++j ) \ - for ( i = 0; i < n_elem; ++i ) \ - { \ - const ctype_r gamma11t = *(ct + i*incct + j*ldct); \ - ctype* restrict gamma11 = c + i*incc + j*ldc ; \ - ctype_r* restrict gamma11_r = &PASTEMAC(ch,real)( *gamma11 ); \ -\ - PASTEMAC(chr,subs)( gamma11t, *gamma11_r ); \ - } \ - } \ - else /* if PASTEMAC(chr,eq0)( beta_r ) */ \ - { \ - /* c_r = -ct; - c_i = 0; */ \ - for ( j = 0; j < n_iter; ++j ) \ - for ( i = 0; i < n_elem; ++i ) \ - { \ - const ctype_r gamma11t = *(ct + i*incct + j*ldct); \ - ctype* restrict gamma11 = c + i*incc + j*ldc ; \ - ctype_r* restrict gamma11_r = &PASTEMAC(ch,real)( *gamma11 ); \ - ctype_r* restrict gamma11_i = &PASTEMAC(ch,imag)( *gamma11 ); \ -\ - PASTEMAC(chr,copys)( -gamma11t, *gamma11_r ); \ - PASTEMAC(chr,set0s)( *gamma11_i ); \ - } \ - } \ - } \ -} - -INSERT_GENTFUNCCO_BASIC2( gemm4mh, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) diff --git a/ref_kernels/ind/bli_gemmtrsm1m_ref.c b/ref_kernels/ind/bli_gemmtrsm1m_ref.c index 7def665de6..5cfaee9ec6 100644 --- a/ref_kernels/ind/bli_gemmtrsm1m_ref.c +++ b/ref_kernels/ind/bli_gemmtrsm1m_ref.c @@ -78,7 +78,7 @@ void PASTEMAC3(ch,opname,arch,suf) \ \ const dim_t packnr = bli_cntx_get_blksz_max_dt( dt, BLIS_NR, cntx ); \ \ - const pack_t schema_b = bli_cntx_schema_b_panel( cntx ); \ + const pack_t schema_b = bli_auxinfo_schema_b( data ); \ \ const dim_t k2 = 2 * k; \ \ diff --git a/ref_kernels/ind/bli_gemmtrsm3m1_ref.c b/ref_kernels/ind/bli_gemmtrsm3m1_ref.c deleted file mode 100644 index 820a0ec2ba..0000000000 --- a/ref_kernels/ind/bli_gemmtrsm3m1_ref.c +++ /dev/null @@ -1,248 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, arch, suf, trsmkerid ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - dim_t k, \ - ctype* restrict alpha, \ - ctype* restrict a1x, \ - ctype* restrict a11, \ - ctype* restrict bx1, \ - ctype* restrict b11, \ - ctype* restrict c11, inc_t rs_c, inc_t cs_c, \ - auxinfo_t* restrict data, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const num_t dt = PASTEMAC(ch,type); \ - const num_t dt_r = PASTEMAC(chr,type); \ -\ - PASTECH(chr,gemm_ukr_ft) \ - rgemm_ukr = bli_cntx_get_l3_nat_ukr_dt( dt_r, BLIS_GEMM_UKR, cntx ); \ -\ - PASTECH(ch,trsm_ukr_ft) \ - ctrsm_vir_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, trsmkerid, cntx ); \ -\ - const dim_t mr = bli_cntx_get_blksz_def_dt( dt_r, BLIS_MR, cntx ); \ - const dim_t nr = bli_cntx_get_blksz_def_dt( dt_r, BLIS_NR, cntx ); \ -\ - const dim_t packnr = bli_cntx_get_blksz_max_dt( dt_r, BLIS_NR, cntx ); \ -\ - const dim_t m = mr; \ - const dim_t n = nr; \ -\ - ctype_r ab_r[ BLIS_STACK_BUF_MAX_SIZE \ - / sizeof( ctype_r ) ] \ - __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - ctype_r ab_i[ BLIS_STACK_BUF_MAX_SIZE \ - / sizeof( ctype_r ) ] \ - __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const inc_t rs_ab = 1; \ - const inc_t cs_ab = mr; \ -\ - const inc_t is_a = bli_auxinfo_is_a( data ); \ - const inc_t is_b = bli_auxinfo_is_b( data ); \ -\ - ctype_r* restrict a1x_r = ( ctype_r* )a1x; \ - ctype_r* restrict a1x_i = ( ctype_r* )a1x + is_a; \ - ctype_r* restrict a1x_ri = ( ctype_r* )a1x + 2*is_a; \ -\ - ctype_r* restrict bx1_r = ( ctype_r* )bx1; \ - ctype_r* restrict bx1_i = ( ctype_r* )bx1 + is_b; \ - ctype_r* restrict bx1_ri = ( ctype_r* )bx1 + 2*is_b; \ -\ - ctype_r* restrict b11_r = ( ctype_r* )b11; \ - ctype_r* restrict b11_i = ( ctype_r* )b11 + is_b; \ - ctype_r* restrict b11_ri = ( ctype_r* )b11 + 2*is_b; \ -\ - const inc_t rs_b = packnr; \ - const inc_t cs_b = 1; \ -\ - ctype_r* restrict one_r = PASTEMAC(chr,1); \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - ctype_r* restrict minus_one_r = PASTEMAC(chr,m1); \ -\ - ctype_r alpha_r = PASTEMAC(ch,real)( *alpha ); \ - ctype_r alpha_i = PASTEMAC(ch,imag)( *alpha ); \ -\ - void* a_next = bli_auxinfo_next_a( data ); \ - void* b_next = bli_auxinfo_next_b( data ); \ -\ - dim_t i, j; \ -\ -\ - /* Copy the contents of c to a temporary buffer ct. */ \ - if ( !PASTEMAC(chr,eq0)( alpha_i ) ) \ - { \ - /* We can handle a non-zero imaginary component on alpha, but to do - so we have to manually scale b and then use alpha == 1 for the - micro-kernel calls. */ \ - for ( i = 0; i < m; ++i ) \ - for ( j = 0; j < n; ++j ) \ - PASTEMAC(ch,scalris)( alpha_r, \ - alpha_i, \ - *(b11_r + i*rs_b + j*cs_b), \ - *(b11_i + i*rs_b + j*cs_b) ); \ -\ - /* Use alpha.r == 1.0. */ \ - alpha_r = *one_r; \ - } \ -\ -\ - /* lower: - b11.r = alpha.r * b11.r - ( + a10.r * b01.r - a10.i * b01.i ); - b11.i = alpha.r * b11.i - ( a10.ri * b01.ri - a10.r * b01.r - a10.i * b01.i ); - - upper: - b11.r = alpha.r * b11.r - ( + a12.r * b21.r - a12.i * b21.i ); - b11.i = alpha.r * b11.i - ( a12.ri * b21.ri - a12.r * b21.r - a12.i * b21.i ); */ \ -\ - bli_auxinfo_set_next_ab( a1x_i, bx1_i, data ); \ -\ - /* lower: ab.r = a10.r * b01.r; - upper: ab.r = a12.r * b21.r; */ \ - rgemm_ukr \ - ( \ - k, \ - one_r, \ - a1x_r, \ - bx1_r, \ - zero_r, \ - ab_r, rs_ab, cs_ab, \ - data, \ - cntx \ - ); \ -\ - bli_auxinfo_set_next_ab( a1x_ri, bx1_ri, data ); \ -\ - /* lower: ab.i = a10.i * b01.i; - upper: ab.i = a12.i * b21.i; */ \ - rgemm_ukr \ - ( \ - k, \ - one_r, \ - a1x_i, \ - bx1_i, \ - zero_r, \ - ab_i, rs_ab, cs_ab, \ - data, \ - cntx \ - ); \ -\ - bli_auxinfo_set_next_ab( a_next, b_next, data ); \ -\ - /* lower: b11.i = alpha.r * b11.i - a12.ri * b21.ri; - upper: b11.i = alpha.r * b11.i - a12.ri * b21.ri; */ \ - rgemm_ukr \ - ( \ - k, \ - minus_one_r, \ - a1x_ri, \ - bx1_ri, \ - &alpha_r, \ - b11_i, rs_b, cs_b, \ - data, \ - cntx \ - ); \ -\ -\ - /* b11.r = alpha.r * b11.r - ab.r; - b11.r = b11.r + ab.i; - b11.i = b11.i + ab.r; - b11.i = b11.i + ab.i; */ \ - for ( i = 0; i < m; ++i ) \ - for ( j = 0; j < n; ++j ) \ - { \ - ctype_r alphabeta_r = *(ab_r + i*rs_ab + j*cs_ab); \ - ctype_r alphabeta_i = *(ab_i + i*rs_ab + j*cs_ab); \ - ctype_r beta11_r = *(b11_r + i*rs_b + j*cs_b); \ - ctype_r beta11_i = *(b11_i + i*rs_b + j*cs_b); \ -\ - PASTEMAC(chr,scals)( alpha_r, beta11_r ); \ -\ - PASTEMAC(chr,subs)( alphabeta_r, beta11_r ); \ - PASTEMAC(chr,adds)( alphabeta_i, beta11_r ); \ - PASTEMAC(chr,adds)( alphabeta_r, beta11_i ); \ - PASTEMAC(chr,adds)( alphabeta_i, beta11_i ); \ -\ - /* Store the local values back to b11. */ \ - PASTEMAC(ch,copyris)( beta11_r, \ - beta11_i, \ - *(b11_r + i*rs_b + j*cs_b), \ - *(b11_i + i*rs_b + j*cs_b) ); \ -\ - /* Update the ri part of b11. */ \ - PASTEMAC(chr,add3s)( beta11_r, \ - beta11_i, \ - *(b11_ri + i*rs_b + j*cs_b) ); \ - } \ -\ -\ - /* b11 = inv(a11) * b11; - c11 = b11; */ \ - ctrsm_vir_ukr \ - ( \ - a11, \ - b11, \ - c11, rs_c, cs_c, \ - data, \ - cntx \ - ); \ -\ -\ -/* -PASTEMAC(chr,fprintm)( stdout, "gemmtrsm3m1_l_ukr: b11_r after", m, n, \ - b11_r, PASTEMAC(chr,packnr), 1, "%4.1f", "" ); \ -PASTEMAC(chr,fprintm)( stdout, "gemmtrsm3m1_l_ukr: b11_i after", m, n, \ - b11_i, PASTEMAC(chr,packnr), 1, "%4.1f", "" ); \ -*/ \ -/* -PASTEMAC(chr,fprintm)( stdout, "gemmtrsm3m1_l_ukr: b01_r", k, n, \ - b01_r, PASTEMAC(chr,packnr), 1, "%4.1f", "" ); \ -PASTEMAC(chr,fprintm)( stdout, "gemmtrsm3m1_l_ukr: b01_i", k, n, \ - b01_i, PASTEMAC(chr,packnr), 1, "%4.1f", "" ); \ -PASTEMAC(chr,fprintm)( stdout, "gemmtrsm3m1_l_ukr: b11_r", m, n, \ - b11_r, PASTEMAC(chr,packnr), 1, "%4.1f", "" ); \ -PASTEMAC(chr,fprintm)( stdout, "gemmtrsm3m1_l_ukr: b11_i", m, n, \ - b11_i, PASTEMAC(chr,packnr), 1, "%4.1f", "" ); \ -*/ \ -} - -INSERT_GENTFUNCCO_BASIC3( gemmtrsm3m1_l, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX, BLIS_TRSM_L_UKR ) -INSERT_GENTFUNCCO_BASIC3( gemmtrsm3m1_u, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX, BLIS_TRSM_U_UKR ) diff --git a/ref_kernels/ind/bli_gemmtrsm4m1_ref.c b/ref_kernels/ind/bli_gemmtrsm4m1_ref.c deleted file mode 100644 index 0988c457da..0000000000 --- a/ref_kernels/ind/bli_gemmtrsm4m1_ref.c +++ /dev/null @@ -1,230 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, arch, suf, trsmkerid ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - dim_t k, \ - ctype* restrict alpha, \ - ctype* restrict a1x, \ - ctype* restrict a11, \ - ctype* restrict bx1, \ - ctype* restrict b11, \ - ctype* restrict c11, inc_t rs_c, inc_t cs_c, \ - auxinfo_t* restrict data, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const num_t dt = PASTEMAC(ch,type); \ - const num_t dt_r = PASTEMAC(chr,type); \ -\ - PASTECH(chr,gemm_ukr_ft) \ - rgemm_ukr = bli_cntx_get_l3_nat_ukr_dt( dt_r, BLIS_GEMM_UKR, cntx ); \ -\ - PASTECH(ch,trsm_ukr_ft) \ - ctrsm_vir_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, trsmkerid, cntx ); \ -\ - const dim_t mr = bli_cntx_get_blksz_def_dt( dt_r, BLIS_MR, cntx ); \ - const dim_t nr = bli_cntx_get_blksz_def_dt( dt_r, BLIS_NR, cntx ); \ -\ - const dim_t packnr = bli_cntx_get_blksz_max_dt( dt_r, BLIS_NR, cntx ); \ -\ - const dim_t m = mr; \ - const dim_t n = nr; \ -\ - const inc_t is_a = bli_auxinfo_is_a( data ); \ - const inc_t is_b = bli_auxinfo_is_b( data ); \ -\ - ctype_r* restrict a1x_r = ( ctype_r* )a1x; \ - ctype_r* restrict a1x_i = ( ctype_r* )a1x + is_a; \ -\ - ctype_r* restrict bx1_r = ( ctype_r* )bx1; \ - ctype_r* restrict bx1_i = ( ctype_r* )bx1 + is_b; \ -\ - ctype_r* restrict b11_r = ( ctype_r* )b11; \ - ctype_r* restrict b11_i = ( ctype_r* )b11 + is_b; \ -\ - const inc_t rs_b = packnr; \ - const inc_t cs_b = 1; \ -\ - ctype_r* restrict one_r = PASTEMAC(chr,1); \ - ctype_r* restrict minus_one_r = PASTEMAC(chr,m1); \ -\ - /* A hack to avoid a 'restrict' warning triggered by passing in the - same address (one_r) for both alpha and beta when calling the last - of the four matrix products. We now use one_r for alpha and this - new local variable, onel, for beta. (See issue #328.) */ \ - ctype_r onel; \ - ctype_r* restrict onel_r = &onel; \ - PASTEMAC(chr,set1s)( onel ); \ -\ - ctype_r alpha_r = PASTEMAC(ch,real)( *alpha ); \ - ctype_r alpha_i = PASTEMAC(ch,imag)( *alpha ); \ -\ - void* a_next = bli_auxinfo_next_a( data ); \ - void* b_next = bli_auxinfo_next_b( data ); \ -\ - dim_t i, j; \ -\ -/* -printf( "gemmtrsm4m1_l_ukr: is_a = %lu is_b = %lu\n", is_a, is_b ); \ -PASTEMAC(chr,fprintm)( stdout, "gemmtrsm4m1_l_ukr: a1x11p_r", m, k+m, \ - a1x_r, 1, PASTEMAC(chr,packmr), "%4.1f", "" ); \ -PASTEMAC(chr,fprintm)( stdout, "gemmtrsm4m1_l_ukr: a1x11p_i", m, k+m, \ - a1x_i, 1, PASTEMAC(chr,packmr), "%4.1f", "" ); \ -PASTEMAC(chr,fprintm)( stdout, "gemmtrsm4m1_l_ukr: bx111p_r", k+m, n, \ - bx1_r, PASTEMAC(chr,packnr), 1, "%4.1f", "" ); \ -PASTEMAC(chr,fprintm)( stdout, "gemmtrsm4m1_l_ukr: bx111p_i", k+m, n, \ - bx1_i, PASTEMAC(chr,packnr), 1, "%4.1f", "" ); \ -*/ \ -\ - /* Copy the contents of c to a temporary buffer ct. */ \ - if ( !PASTEMAC(chr,eq0)( alpha_i ) ) \ - { \ - /* We can handle a non-zero imaginary component on alpha, but to do - so we have to manually scale b and then use alpha == 1 for the - micro-kernel calls. */ \ - for ( i = 0; i < m; ++i ) \ - for ( j = 0; j < n; ++j ) \ - PASTEMAC(ch,scalris)( alpha_r, \ - alpha_i, \ - *(b11_r + i*rs_b + j*cs_b), \ - *(b11_i + i*rs_b + j*cs_b) ); \ -\ - /* Use alpha.r == 1.0. */ \ - alpha_r = *one_r; \ - } \ -\ -\ - /* lower: b11.r = alpha.r * b11.r - ( a10.r * b01.r - a10.i * b01.i ); - b11.i = alpha.r * b11.i - ( a10.r * b01.i + a10.i * b01.r ); - - upper: b11.r = alpha.r * b11.r - ( a12.r * b21.r - a12.i * b21.i ); - b11.i = alpha.r * b11.i - ( a12.r * b21.i + a12.i * b21.r ); */ \ -\ - bli_auxinfo_set_next_ab( a1x_r, bx1_i, data ); \ -\ - /* lower: b11.r = alpha.r * b11.r - a10.r * b01.r; - upper: b11.r = alpha.r * b11.r - a12.r * b21.r; */ \ - rgemm_ukr \ - ( \ - k, \ - minus_one_r, \ - a1x_r, \ - bx1_r, \ - &alpha_r, \ - b11_r, rs_b, cs_b, \ - data, \ - cntx \ - ); \ -\ - bli_auxinfo_set_next_ab( a1x_i, bx1_r, data ); \ -\ - /* lower: b11.i = alpha.r * b11.i - a10.r * b01.i; - upper: b11.i = alpha.r * b11.i - a12.r * b21.i; */ \ - rgemm_ukr \ - ( \ - k, \ - minus_one_r, \ - a1x_r, \ - bx1_i, \ - &alpha_r, \ - b11_i, rs_b, cs_b, \ - data, \ - cntx \ - ); \ -\ - bli_auxinfo_set_next_ab( a1x_i, bx1_i, data ); \ -\ - /* lower: b11.i = 1.0 * b11.i - a10.i * b01.r; - upper: b11.i = 1.0 * b11.i - a12.i * b21.r; */ \ - rgemm_ukr \ - ( \ - k, \ - minus_one_r, \ - a1x_i, \ - bx1_r, \ - one_r, \ - b11_i, rs_b, cs_b, \ - data, \ - cntx \ - ); \ -\ - bli_auxinfo_set_next_ab( a_next, b_next, data ); \ -\ - /* lower: b11.r = 1.0 * b11.r + a10.i * b01.i; - upper: b11.r = 1.0 * b11.r + a12.i * b21.i; */ \ - rgemm_ukr \ - ( \ - k, \ - one_r, \ - a1x_i, \ - bx1_i, \ - onel_r, \ - b11_r, rs_b, cs_b, \ - data, \ - cntx \ - ); \ -/* -PASTEMAC(chr,fprintm)( stdout, "gemmtrsm4m1_l_ukr: bx111p_r post-gemm", k+m, n, \ - bx1_r, PASTEMAC(chr,packnr), 1, "%4.1f", "" ); \ -PASTEMAC(chr,fprintm)( stdout, "gemmtrsm4m1_l_ukr: bx111p_i post-gemm", k+m, n, \ - bx1_i, PASTEMAC(chr,packnr), 1, "%4.1f", "" ); \ -*/ \ -\ - /* b11 = inv(a11) * b11; - c11 = b11; */ \ - ctrsm_vir_ukr \ - ( \ - a11, \ - b11, \ - c11, rs_c, cs_c, \ - data, \ - cntx \ - ); \ -\ -/* -PASTEMAC(chr,fprintm)( stdout, "gemmtrsm4m1_l_ukr: bx111p_r after", k+m, n, \ - bx1_r, PASTEMAC(chr,packnr), 1, "%4.1f", "" ); \ -PASTEMAC(chr,fprintm)( stdout, "gemmtrsm4m1_l_ukr: bx111p_i after", k+m, n, \ - bx1_i, PASTEMAC(chr,packnr), 1, "%4.1f", "" ); \ -*/ \ -} - -INSERT_GENTFUNCCO_BASIC3( gemmtrsm4m1_l, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX, BLIS_TRSM_L_UKR ) -INSERT_GENTFUNCCO_BASIC3( gemmtrsm4m1_u, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX, BLIS_TRSM_U_UKR ) diff --git a/ref_kernels/ind/bli_trsm1m_ref.c b/ref_kernels/ind/bli_trsm1m_ref.c index a89d8b90d3..68717f7a6c 100644 --- a/ref_kernels/ind/bli_trsm1m_ref.c +++ b/ref_kernels/ind/bli_trsm1m_ref.c @@ -67,7 +67,7 @@ void PASTEMAC3(ch,opname,arch,suf) \ const inc_t ld_a = cs_a; \ const inc_t ld_b = rs_b; \ \ - const pack_t schema_b = bli_cntx_schema_b_panel( cntx ); \ + const pack_t schema_b = bli_auxinfo_schema_b( data ); \ \ dim_t iter, i, j, l; \ dim_t n_behind; \ @@ -277,7 +277,7 @@ void PASTEMAC3(ch,opname,arch,suf) \ const inc_t ld_a = cs_a; \ const inc_t ld_b = rs_b; \ \ - const pack_t schema_b = bli_cntx_schema_b_panel( cntx ); \ + const pack_t schema_b = bli_auxinfo_schema_b( data ); \ \ dim_t iter, i, j, l; \ dim_t n_behind; \ diff --git a/ref_kernels/ind/bli_trsm3m1_ref.c b/ref_kernels/ind/bli_trsm3m1_ref.c deleted file mode 100644 index c24c2f4e2a..0000000000 --- a/ref_kernels/ind/bli_trsm3m1_ref.c +++ /dev/null @@ -1,283 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - ctype* restrict a, \ - ctype* restrict b, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - auxinfo_t* restrict data, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const num_t dt_r = PASTEMAC(chr,type); \ -\ - const dim_t mr = bli_cntx_get_blksz_def_dt( dt_r, BLIS_MR, cntx ); \ - const dim_t nr = bli_cntx_get_blksz_def_dt( dt_r, BLIS_NR, cntx ); \ -\ - const inc_t packmr = bli_cntx_get_blksz_max_dt( dt_r, BLIS_MR, cntx ); \ - const inc_t packnr = bli_cntx_get_blksz_max_dt( dt_r, BLIS_NR, cntx ); \ -\ - const dim_t m = mr; \ - const dim_t n = nr; \ -\ - const inc_t is_a = bli_auxinfo_is_a( data ); \ - const inc_t is_b = bli_auxinfo_is_b( data ); \ -\ - ctype_r* restrict a_r = ( ctype_r* )a; \ - ctype_r* restrict a_i = ( ctype_r* )a + is_a; \ -\ - ctype_r* restrict b_r = ( ctype_r* )b; \ - ctype_r* restrict b_i = ( ctype_r* )b + is_b; \ - ctype_r* restrict b_ri = ( ctype_r* )b + 2*is_b; \ -\ - const inc_t rs_a = 1; \ - const inc_t cs_a = packmr; \ -\ - const inc_t rs_b = packnr; \ - const inc_t cs_b = 1; \ -\ - dim_t iter, i, j, l; \ - dim_t n_behind; \ -\ -\ - for ( iter = 0; iter < m; ++iter ) \ - { \ - i = iter; \ - n_behind = i; \ -\ - ctype_r* restrict alpha11_r = a_r + (i )*rs_a + (i )*cs_a; \ - ctype_r* restrict alpha11_i = a_i + (i )*rs_a + (i )*cs_a; \ - ctype_r* restrict a10t_r = a_r + (i )*rs_a + (0 )*cs_a; \ - ctype_r* restrict a10t_i = a_i + (i )*rs_a + (0 )*cs_a; \ - ctype_r* restrict b1_r = b_r + (i )*rs_b + (0 )*cs_b; \ - ctype_r* restrict b1_i = b_i + (i )*rs_b + (0 )*cs_b; \ - ctype_r* restrict b1_ri = b_ri + (i )*rs_b + (0 )*cs_b; \ - ctype_r* restrict B0_r = b_r + (0 )*rs_b + (0 )*cs_b; \ - ctype_r* restrict B0_i = b_i + (0 )*rs_b + (0 )*cs_b; \ -\ - /* b1 = b1 - a10t * B0; */ \ - /* b1 = b1 / alpha11; */ \ - for ( j = 0; j < n; ++j ) \ - { \ - ctype_r* restrict beta11_r = b1_r + (0 )*rs_b + (j )*cs_b; \ - ctype_r* restrict beta11_i = b1_i + (0 )*rs_b + (j )*cs_b; \ - ctype_r* restrict beta11_ri = b1_ri + (0 )*rs_b + (j )*cs_b; \ - ctype_r* restrict b01_r = B0_r + (0 )*rs_b + (j )*cs_b; \ - ctype_r* restrict b01_i = B0_i + (0 )*rs_b + (j )*cs_b; \ - ctype* restrict gamma11 = c + (i )*rs_c + (j )*cs_c; \ - ctype_r beta11c_r = *beta11_r; \ - ctype_r beta11c_i = *beta11_i; \ - ctype_r rho11_r; \ - ctype_r rho11_i; \ -\ - /* beta11 = beta11 - a10t * b01; */ \ - PASTEMAC(chr,set0s)( rho11_r ); \ - PASTEMAC(chr,set0s)( rho11_i ); \ - for ( l = 0; l < n_behind; ++l ) \ - { \ - ctype_r* restrict alpha10_r = a10t_r + (l )*cs_a; \ - ctype_r* restrict alpha10_i = a10t_i + (l )*cs_a; \ - ctype_r* restrict beta01_r = b01_r + (l )*rs_b; \ - ctype_r* restrict beta01_i = b01_i + (l )*rs_b; \ -\ - PASTEMAC(ch,axpyris)( *alpha10_r, \ - *alpha10_i, \ - *beta01_r, \ - *beta01_i, \ - rho11_r, \ - rho11_i ); \ - } \ - PASTEMAC(ch,subris)( rho11_r, \ - rho11_i, \ - beta11c_r, \ - beta11c_i ); \ -\ - /* beta11 = beta11 / alpha11; */ \ - /* NOTE: The INVERSE of alpha11 (1.0/alpha11) is stored instead - of alpha11, so we can multiply rather than divide. We store - the inverse of alpha11 intentionally to avoid expensive - division instructions within the micro-kernel. */ \ - PASTEMAC(ch,scalris)( *alpha11_r, \ - *alpha11_i, \ - beta11c_r, \ - beta11c_i ); \ -\ - /* Output final result to matrix c. */ \ - PASTEMAC(ch,sets)( beta11c_r, \ - beta11c_i, *gamma11 ); \ -\ - /* Store the local values back to b11. */ \ - PASTEMAC(chr,copys)( beta11c_r, *beta11_r ); \ - PASTEMAC(chr,copys)( beta11c_i, *beta11_i ); \ -\ - /* Update the ri part of the packed panel. */ \ - PASTEMAC(chr,add3s)( beta11c_r, \ - beta11c_i, \ - *beta11_ri ); \ - } \ - } \ -} - -INSERT_GENTFUNCCO_BASIC2( trsm3m1_l, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - ctype* restrict a, \ - ctype* restrict b, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - auxinfo_t* restrict data, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const num_t dt_r = PASTEMAC(chr,type); \ -\ - const dim_t mr = bli_cntx_get_blksz_def_dt( dt_r, BLIS_MR, cntx ); \ - const dim_t nr = bli_cntx_get_blksz_def_dt( dt_r, BLIS_NR, cntx ); \ -\ - const inc_t packmr = bli_cntx_get_blksz_max_dt( dt_r, BLIS_MR, cntx ); \ - const inc_t packnr = bli_cntx_get_blksz_max_dt( dt_r, BLIS_NR, cntx ); \ -\ - const dim_t m = mr; \ - const dim_t n = nr; \ -\ - const inc_t is_a = bli_auxinfo_is_a( data ); \ - const inc_t is_b = bli_auxinfo_is_b( data ); \ -\ - ctype_r* restrict a_r = ( ctype_r* )a; \ - ctype_r* restrict a_i = ( ctype_r* )a + is_a; \ -\ - ctype_r* restrict b_r = ( ctype_r* )b; \ - ctype_r* restrict b_i = ( ctype_r* )b + is_b; \ - ctype_r* restrict b_ri = ( ctype_r* )b + 2*is_b; \ -\ - const inc_t rs_a = 1; \ - const inc_t cs_a = packmr; \ -\ - const inc_t rs_b = packnr; \ - const inc_t cs_b = 1; \ -\ - dim_t iter, i, j, l; \ - dim_t n_behind; \ -\ -\ - for ( iter = 0; iter < m; ++iter ) \ - { \ - i = m - iter - 1; \ - n_behind = iter; \ -\ - ctype_r* restrict alpha11_r = a_r + (i )*rs_a + (i )*cs_a; \ - ctype_r* restrict alpha11_i = a_i + (i )*rs_a + (i )*cs_a; \ - ctype_r* restrict a12t_r = a_r + (i )*rs_a + (i+1)*cs_a; \ - ctype_r* restrict a12t_i = a_i + (i )*rs_a + (i+1)*cs_a; \ - ctype_r* restrict b1_r = b_r + (i )*rs_b + (0 )*cs_b; \ - ctype_r* restrict b1_i = b_i + (i )*rs_b + (0 )*cs_b; \ - ctype_r* restrict b1_ri = b_ri + (i )*rs_b + (0 )*cs_b; \ - ctype_r* restrict B2_r = b_r + (i+1)*rs_b + (0 )*cs_b; \ - ctype_r* restrict B2_i = b_i + (i+1)*rs_b + (0 )*cs_b; \ -\ - /* b1 = b1 - a12t * B2; */ \ - /* b1 = b1 / alpha11; */ \ - for ( j = 0; j < n; ++j ) \ - { \ - ctype_r* restrict beta11_r = b1_r + (0 )*rs_b + (j )*cs_b; \ - ctype_r* restrict beta11_i = b1_i + (0 )*rs_b + (j )*cs_b; \ - ctype_r* restrict beta11_ri = b1_ri + (0 )*rs_b + (j )*cs_b; \ - ctype_r* restrict b21_r = B2_r + (0 )*rs_b + (j )*cs_b; \ - ctype_r* restrict b21_i = B2_i + (0 )*rs_b + (j )*cs_b; \ - ctype* restrict gamma11 = c + (i )*rs_c + (j )*cs_c; \ - ctype_r beta11c_r = *beta11_r; \ - ctype_r beta11c_i = *beta11_i; \ - ctype_r rho11_r; \ - ctype_r rho11_i; \ -\ - /* beta11 = beta11 - a12t * b21; */ \ - PASTEMAC(chr,set0s)( rho11_r ); \ - PASTEMAC(chr,set0s)( rho11_i ); \ - for ( l = 0; l < n_behind; ++l ) \ - { \ - ctype_r* restrict alpha12_r = a12t_r + (l )*cs_a; \ - ctype_r* restrict alpha12_i = a12t_i + (l )*cs_a; \ - ctype_r* restrict beta21_r = b21_r + (l )*rs_b; \ - ctype_r* restrict beta21_i = b21_i + (l )*rs_b; \ -\ - PASTEMAC(ch,axpyris)( *alpha12_r, \ - *alpha12_i, \ - *beta21_r, \ - *beta21_i, \ - rho11_r, \ - rho11_i ); \ - } \ - PASTEMAC(ch,subris)( rho11_r, \ - rho11_i, \ - beta11c_r, \ - beta11c_i ); \ -\ - /* beta11 = beta11 / alpha11; */ \ - /* NOTE: The INVERSE of alpha11 (1.0/alpha11) is stored instead - of alpha11, so we can multiply rather than divide. We store - the inverse of alpha11 intentionally to avoid expensive - division instructions within the micro-kernel. */ \ - PASTEMAC(ch,scalris)( *alpha11_r, \ - *alpha11_i, \ - beta11c_r, \ - beta11c_i ); \ -\ - /* Output final result to matrix c. */ \ - PASTEMAC(ch,sets)( beta11c_r, \ - beta11c_i, *gamma11 ); \ -\ - /* Store the local values back to b11. */ \ - PASTEMAC(chr,copys)( beta11c_r, *beta11_r ); \ - PASTEMAC(chr,copys)( beta11c_i, *beta11_i ); \ -\ - /* Update the ri part of the packed panel. */ \ - PASTEMAC(chr,add3s)( beta11c_r, \ - beta11c_i, \ - *beta11_ri ); \ - } \ - } \ -} - -INSERT_GENTFUNCCO_BASIC2( trsm3m1_u, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) diff --git a/ref_kernels/ind/bli_trsm4m1_ref.c b/ref_kernels/ind/bli_trsm4m1_ref.c deleted file mode 100644 index 81d203e403..0000000000 --- a/ref_kernels/ind/bli_trsm4m1_ref.c +++ /dev/null @@ -1,284 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - ctype* restrict a, \ - ctype* restrict b, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - auxinfo_t* restrict data, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const num_t dt_r = PASTEMAC(chr,type); \ -\ - const dim_t mr = bli_cntx_get_blksz_def_dt( dt_r, BLIS_MR, cntx ); \ - const dim_t nr = bli_cntx_get_blksz_def_dt( dt_r, BLIS_NR, cntx ); \ -\ - const inc_t packmr = bli_cntx_get_blksz_max_dt( dt_r, BLIS_MR, cntx ); \ - const inc_t packnr = bli_cntx_get_blksz_max_dt( dt_r, BLIS_NR, cntx ); \ -\ - const dim_t m = mr; \ - const dim_t n = nr; \ -\ - const inc_t is_a = bli_auxinfo_is_a( data ); \ - const inc_t is_b = bli_auxinfo_is_b( data ); \ -\ - ctype_r* restrict a_r = ( ctype_r* )a; \ - ctype_r* restrict a_i = ( ctype_r* )a + is_a; \ -\ - ctype_r* restrict b_r = ( ctype_r* )b; \ - ctype_r* restrict b_i = ( ctype_r* )b + is_b; \ -\ - const inc_t rs_a = 1; \ - const inc_t cs_a = packmr; \ -\ - const inc_t rs_b = packnr; \ - const inc_t cs_b = 1; \ -\ - dim_t iter, i, j, l; \ - dim_t n_behind; \ -\ -/* -PASTEMAC(chr,fprintm)( stdout, "trsm4m1_l_ukr: a11p_r", m, m, \ - a_r, 1, PASTEMAC(chr,packmr), "%4.1f", "" ); \ -PASTEMAC(chr,fprintm)( stdout, "trsm4m1_l_ukr: a11p_i", m, m, \ - a_i, 1, PASTEMAC(chr,packmr), "%4.1f", "" ); \ -PASTEMAC(chr,fprintm)( stdout, "trsm4m1_l_ukr: b11p_r", m, n, \ - b_r, PASTEMAC(chr,packnr), 1, "%4.1f", "" ); \ -PASTEMAC(chr,fprintm)( stdout, "trsm4m1_l_ukr: b11p_i", m, n, \ - b_i, PASTEMAC(chr,packnr), 1, "%4.1f", "" ); \ -*/ \ -\ - for ( iter = 0; iter < m; ++iter ) \ - { \ - i = iter; \ - n_behind = i; \ -\ - ctype_r* restrict alpha11_r = a_r + (i )*rs_a + (i )*cs_a; \ - ctype_r* restrict alpha11_i = a_i + (i )*rs_a + (i )*cs_a; \ - ctype_r* restrict a10t_r = a_r + (i )*rs_a + (0 )*cs_a; \ - ctype_r* restrict a10t_i = a_i + (i )*rs_a + (0 )*cs_a; \ - ctype_r* restrict b1_r = b_r + (i )*rs_b + (0 )*cs_b; \ - ctype_r* restrict b1_i = b_i + (i )*rs_b + (0 )*cs_b; \ - ctype_r* restrict B0_r = b_r + (0 )*rs_b + (0 )*cs_b; \ - ctype_r* restrict B0_i = b_i + (0 )*rs_b + (0 )*cs_b; \ -\ - /* b1 = b1 - a10t * B0; */ \ - /* b1 = b1 / alpha11; */ \ - for ( j = 0; j < n; ++j ) \ - { \ - ctype_r* restrict beta11_r = b1_r + (0 )*rs_b + (j )*cs_b; \ - ctype_r* restrict beta11_i = b1_i + (0 )*rs_b + (j )*cs_b; \ - ctype_r* restrict b01_r = B0_r + (0 )*rs_b + (j )*cs_b; \ - ctype_r* restrict b01_i = B0_i + (0 )*rs_b + (j )*cs_b; \ - ctype* restrict gamma11 = c + (i )*rs_c + (j )*cs_c; \ - ctype_r beta11c_r = *beta11_r; \ - ctype_r beta11c_i = *beta11_i; \ - ctype_r rho11_r; \ - ctype_r rho11_i; \ -\ - /* beta11 = beta11 - a10t * b01; */ \ - PASTEMAC(chr,set0s)( rho11_r ); \ - PASTEMAC(chr,set0s)( rho11_i ); \ - for ( l = 0; l < n_behind; ++l ) \ - { \ - ctype_r* restrict alpha10_r = a10t_r + (l )*cs_a; \ - ctype_r* restrict alpha10_i = a10t_i + (l )*cs_a; \ - ctype_r* restrict beta01_r = b01_r + (l )*rs_b; \ - ctype_r* restrict beta01_i = b01_i + (l )*rs_b; \ -\ - PASTEMAC(ch,axpyris)( *alpha10_r, \ - *alpha10_i, \ - *beta01_r, \ - *beta01_i, \ - rho11_r, \ - rho11_i ); \ - } \ - PASTEMAC(ch,subris)( rho11_r, \ - rho11_i, \ - beta11c_r, \ - beta11c_i ); \ -\ - /* beta11 = beta11 / alpha11; */ \ - /* NOTE: The INVERSE of alpha11 (1.0/alpha11) is stored instead - of alpha11, so we can multiply rather than divide. We store - the inverse of alpha11 intentionally to avoid expensive - division instructions within the micro-kernel. */ \ - PASTEMAC(ch,scalris)( *alpha11_r, \ - *alpha11_i, \ - beta11c_r, \ - beta11c_i ); \ -\ - /* Output final result to matrix c. */ \ - PASTEMAC(ch,sets)( beta11c_r, \ - beta11c_i, *gamma11 ); \ -\ - /* Store the local values back to b11. */ \ - PASTEMAC(chr,copys)( beta11c_r, *beta11_r ); \ - PASTEMAC(chr,copys)( beta11c_i, *beta11_i ); \ - } \ - } \ -\ -/* -PASTEMAC(chr,fprintm)( stdout, "trsm4m1_l_ukr: b11p_r after", m, n, \ - b_r, PASTEMAC(chr,packnr), 1, "%4.1f", "" ); \ -PASTEMAC(chr,fprintm)( stdout, "trsm4m1_l_ukr: b11p_i after", m, n, \ - b_i, PASTEMAC(chr,packnr), 1, "%4.1f", "" ); \ -*/ \ -} - -INSERT_GENTFUNCCO_BASIC2( trsm4m1_l, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) - - -#undef GENTFUNCCO -#define GENTFUNCCO( ctype, ctype_r, ch, chr, opname, arch, suf ) \ -\ -void PASTEMAC3(ch,opname,arch,suf) \ - ( \ - ctype* restrict a, \ - ctype* restrict b, \ - ctype* restrict c, inc_t rs_c, inc_t cs_c, \ - auxinfo_t* restrict data, \ - cntx_t* restrict cntx \ - ) \ -{ \ - const num_t dt_r = PASTEMAC(chr,type); \ -\ - const dim_t mr = bli_cntx_get_blksz_def_dt( dt_r, BLIS_MR, cntx ); \ - const dim_t nr = bli_cntx_get_blksz_def_dt( dt_r, BLIS_NR, cntx ); \ -\ - const inc_t packmr = bli_cntx_get_blksz_max_dt( dt_r, BLIS_MR, cntx ); \ - const inc_t packnr = bli_cntx_get_blksz_max_dt( dt_r, BLIS_NR, cntx ); \ -\ - const dim_t m = mr; \ - const dim_t n = nr; \ -\ - const inc_t is_a = bli_auxinfo_is_a( data ); \ - const inc_t is_b = bli_auxinfo_is_b( data ); \ -\ - ctype_r* restrict a_r = ( ctype_r* )a; \ - ctype_r* restrict a_i = ( ctype_r* )a + is_a; \ -\ - ctype_r* restrict b_r = ( ctype_r* )b; \ - ctype_r* restrict b_i = ( ctype_r* )b + is_b; \ -\ - const inc_t rs_a = 1; \ - const inc_t cs_a = packmr; \ -\ - const inc_t rs_b = packnr; \ - const inc_t cs_b = 1; \ -\ - dim_t iter, i, j, l; \ - dim_t n_behind; \ -\ -\ - for ( iter = 0; iter < m; ++iter ) \ - { \ - i = m - iter - 1; \ - n_behind = iter; \ -\ - ctype_r* restrict alpha11_r = a_r + (i )*rs_a + (i )*cs_a; \ - ctype_r* restrict alpha11_i = a_i + (i )*rs_a + (i )*cs_a; \ - ctype_r* restrict a12t_r = a_r + (i )*rs_a + (i+1)*cs_a; \ - ctype_r* restrict a12t_i = a_i + (i )*rs_a + (i+1)*cs_a; \ - ctype_r* restrict b1_r = b_r + (i )*rs_b + (0 )*cs_b; \ - ctype_r* restrict b1_i = b_i + (i )*rs_b + (0 )*cs_b; \ - ctype_r* restrict B2_r = b_r + (i+1)*rs_b + (0 )*cs_b; \ - ctype_r* restrict B2_i = b_i + (i+1)*rs_b + (0 )*cs_b; \ -\ - /* b1 = b1 - a12t * B2; */ \ - /* b1 = b1 / alpha11; */ \ - for ( j = 0; j < n; ++j ) \ - { \ - ctype_r* restrict beta11_r = b1_r + (0 )*rs_b + (j )*cs_b; \ - ctype_r* restrict beta11_i = b1_i + (0 )*rs_b + (j )*cs_b; \ - ctype_r* restrict b21_r = B2_r + (0 )*rs_b + (j )*cs_b; \ - ctype_r* restrict b21_i = B2_i + (0 )*rs_b + (j )*cs_b; \ - ctype* restrict gamma11 = c + (i )*rs_c + (j )*cs_c; \ - ctype_r beta11c_r = *beta11_r; \ - ctype_r beta11c_i = *beta11_i; \ - ctype_r rho11_r; \ - ctype_r rho11_i; \ -\ - /* beta11 = beta11 - a12t * b21; */ \ - PASTEMAC(chr,set0s)( rho11_r ); \ - PASTEMAC(chr,set0s)( rho11_i ); \ - for ( l = 0; l < n_behind; ++l ) \ - { \ - ctype_r* restrict alpha12_r = a12t_r + (l )*cs_a; \ - ctype_r* restrict alpha12_i = a12t_i + (l )*cs_a; \ - ctype_r* restrict beta21_r = b21_r + (l )*rs_b; \ - ctype_r* restrict beta21_i = b21_i + (l )*rs_b; \ -\ - PASTEMAC(ch,axpyris)( *alpha12_r, \ - *alpha12_i, \ - *beta21_r, \ - *beta21_i, \ - rho11_r, \ - rho11_i ); \ - } \ - PASTEMAC(ch,subris)( rho11_r, \ - rho11_i, \ - beta11c_r, \ - beta11c_i ); \ -\ - /* beta11 = beta11 / alpha11; */ \ - /* NOTE: The INVERSE of alpha11 (1.0/alpha11) is stored instead - of alpha11, so we can multiply rather than divide. We store - the inverse of alpha11 intentionally to avoid expensive - division instructions within the micro-kernel. */ \ - PASTEMAC(ch,scalris)( *alpha11_r, \ - *alpha11_i, \ - beta11c_r, \ - beta11c_i ); \ -\ - /* Output final result to matrix c. */ \ - PASTEMAC(ch,sets)( beta11c_r, \ - beta11c_i, *gamma11 ); \ -\ - /* Store the local values back to b11. */ \ - PASTEMAC(chr,copys)( beta11c_r, *beta11_r ); \ - PASTEMAC(chr,copys)( beta11c_i, *beta11_i ); \ - } \ - } \ -} - -INSERT_GENTFUNCCO_BASIC2( trsm4m1_u, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) diff --git a/frame/3/gemm/ind/old/bli_gemm3m3_packa.c b/sandbox/gemmlike/bli_gemm_ex.c similarity index 51% rename from frame/3/gemm/ind/old/bli_gemm3m3_packa.c rename to sandbox/gemmlike/bli_gemm_ex.c index 24d575c814..96dae1a3a9 100644 --- a/frame/3/gemm/ind/old/bli_gemm3m3_packa.c +++ b/sandbox/gemmlike/bli_gemm_ex.c @@ -32,111 +32,60 @@ */ +// Given the current architecture of BLIS sandboxes, bli_gemm_ex() is the +// entry point to any sandbox implementation. + +// NOTE: This function is implemented functionally identically to the +// function that it overrides in frame/3/bli_l3_oapi_ex.c. This means that +// we are forgoing the option of customizing the implementations that +// underlie bli_gemm() and bli_?gemm() (which both call bli_gemm_ex()). +// Any new code defined in this sandbox directory, however, will be +// included in the BLIS. + #include "blis.h" -void bli_gemm3m3_packa +void bli_gemm_ex ( + obj_t* alpha, obj_t* a, obj_t* b, + obj_t* beta, obj_t* c, cntx_t* cntx, - cntl_t* cntl, - thrinfo_t* thread + rntm_t* rntm ) { - obj_t a_pack; - - // Make a copy of the context for each stage. - cntx_t cntx_ro = *cntx; - cntx_t cntx_io = *cntx; - cntx_t cntx_rpi = *cntx; - - // ----------------------------------------------------- - - // Initialize the context for the real-only stage. - bli_gemm3m3_cntx_stage( 0, &cntx_ro ); - - // Pack matrix the real-only part of A. - bli_l3_packm - ( - a, - &a_pack, - &cntx_ro, - cntl, - thread - ); - - // Proceed with execution using packed matrix A. - bli_gemm_int + bli_init_once(); + + // A switch to easily toggle whether we use the sandbox implementation + // of bls_gemm() as the implementation for bli_gemm(). (This allows for + // easy testing of bls_gemm() via the testsuite.) Changing the conditional + // to "0" will cause bli_gemm()/bli_gemm_ex() to *not* call the local + // sandbox implementation, though that implementation may still be called + // directly. + if ( 1 ) + { + bls_gemm_ex( alpha, a, b, beta, c, cntx, rntm ); + return; + } + + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_l; + if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); rntm = &rntm_l; } + else { rntm_l = *rntm; rntm = &rntm_l; } + + // Obtain a valid (native) context from the gks if necessary. + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); + + // Check the operands. + if ( bli_error_checking_is_enabled() ) + bli_gemm_check( alpha, a, b, beta, c, cntx ); + + // Invoke the operation's front end. + bli_gemm_front ( - &BLIS_ONE, - &a_pack, - b, - &BLIS_ONE, - c, - cntx, - bli_cntl_sub_node( cntl ), - bli_thrinfo_sub_node( thread ) + alpha, a, b, beta, c, cntx, rntm, NULL ); - - // Only apply beta within the first of three subproblems. - bli_obj_scalar_reset( c ); - - // ----------------------------------------------------- - - // Initialize the context for the imag-only stage. - bli_gemm3m3_cntx_stage( 1, &cntx_io ); - - // Pack matrix the imag-only part of A. - bli_l3_packm - ( - a, - &a_pack, - &cntx_io, - cntl, - thread - ); - - // Proceed with execution using packed matrix A. - bli_gemm_int - ( - &BLIS_ONE, - &a_pack, - b, - &BLIS_ONE, - c, - cntx, - bli_cntl_sub_node( cntl ), - bli_thrinfo_sub_node( thread ) - ); - - // ----------------------------------------------------- - - // Initialize the context for the real+imag stage. - bli_gemm3m3_cntx_stage( 2, &cntx_rpi ); - - // Pack matrix the real+imag part of A. - bli_l3_packm - ( - a, - &a_pack, - &cntx_rpi, - cntl, - thread - ); - - // Proceed with execution using packed matrix A. - bli_gemm_int - ( - &BLIS_ONE, - &a_pack, - b, - &BLIS_ONE, - c, - cntx, - bli_cntl_sub_node( cntl ), - bli_thrinfo_sub_node( thread ) - ); - } diff --git a/sandbox/gemmlike/bls_gemm.c b/sandbox/gemmlike/bls_gemm.c index 4ee3a773f2..0b15f21970 100644 --- a/sandbox/gemmlike/bls_gemm.c +++ b/sandbox/gemmlike/bls_gemm.c @@ -72,18 +72,20 @@ void bls_gemm_ex { bli_init_once(); - // -- bli_gemmnat() -------------------------------------------------------- + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_l; + if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); rntm = &rntm_l; } + else { rntm_l = *rntm; rntm = &rntm_l; } // Obtain a valid (native) context from the gks if necessary. // NOTE: This must be done before calling the _check() function, since // that function assumes the context pointer is valid. if ( cntx == NULL ) cntx = bli_gks_query_cntx(); - // Initialize a local runtime with global settings if necessary. Note - // that in the case that a runtime is passed in, we make a local copy. - rntm_t rntm_l; - if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); rntm = &rntm_l; } - else { rntm_l = *rntm; rntm = &rntm_l; } + // Check parameters. + if ( bli_error_checking_is_enabled() ) + bls_gemm_check( alpha, a, b, beta, c, cntx ); // -- bli_gemm_front() ----------------------------------------------------- @@ -91,12 +93,6 @@ void bls_gemm_ex obj_t b_local; obj_t c_local; - // Check parameters. - if ( bli_error_checking_is_enabled() ) - { - bls_gemm_check( alpha, a, b, beta, c, cntx ); - } - // If C has a zero dimension, return early. if ( bli_obj_has_zero_dim( c ) ) { @@ -145,11 +141,6 @@ void bls_gemm_ex bli_obj_induce_trans( &a_local ); bli_obj_induce_trans( &b_local ); bli_obj_induce_trans( &c_local ); - - // NOTE: This is probably not needed within the sandbox. - // We must also swap the pack schemas, which were set by bli_gemm_md() - // or the inlined code above. - //bli_obj_swap_pack_schemas( &a_local, &b_local ); } // Parse and interpret the contents of the rntm_t object to properly diff --git a/sandbox/ref99/bli_gemmnat.c b/sandbox/old/ref99/bli_gemmnat.c similarity index 95% rename from sandbox/ref99/bli_gemmnat.c rename to sandbox/old/ref99/bli_gemmnat.c index e180908fc3..eed9373a5a 100644 --- a/sandbox/ref99/bli_gemmnat.c +++ b/sandbox/old/ref99/bli_gemmnat.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2017 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2017 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -15,7 +15,7 @@ - Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. - - Neither the name of copyright holder(s) nor the names + - Neither the name(s) of the copyright holder(s) nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. diff --git a/sandbox/ref99/bli_sandbox.h b/sandbox/old/ref99/bli_sandbox.h similarity index 100% rename from sandbox/ref99/bli_sandbox.h rename to sandbox/old/ref99/bli_sandbox.h diff --git a/sandbox/ref99/blix.h b/sandbox/old/ref99/blix.h similarity index 100% rename from sandbox/ref99/blix.h rename to sandbox/old/ref99/blix.h diff --git a/sandbox/ref99/blx_gemm_ref_var2.c b/sandbox/old/ref99/blx_gemm_ref_var2.c similarity index 100% rename from sandbox/ref99/blx_gemm_ref_var2.c rename to sandbox/old/ref99/blx_gemm_ref_var2.c diff --git a/sandbox/ref99/blx_gemm_ref_var2.h b/sandbox/old/ref99/blx_gemm_ref_var2.h similarity index 100% rename from sandbox/ref99/blx_gemm_ref_var2.h rename to sandbox/old/ref99/blx_gemm_ref_var2.h diff --git a/sandbox/ref99/old/base/blx_blksz.c b/sandbox/old/ref99/old/base/blx_blksz.c similarity index 100% rename from sandbox/ref99/old/base/blx_blksz.c rename to sandbox/old/ref99/old/base/blx_blksz.c diff --git a/sandbox/ref99/old/base/blx_blksz.h b/sandbox/old/ref99/old/base/blx_blksz.h similarity index 100% rename from sandbox/ref99/old/base/blx_blksz.h rename to sandbox/old/ref99/old/base/blx_blksz.h diff --git a/sandbox/ref99/old/blx_gemm.h b/sandbox/old/ref99/old/blx_gemm.h similarity index 100% rename from sandbox/ref99/old/blx_gemm.h rename to sandbox/old/ref99/old/blx_gemm.h diff --git a/sandbox/ref99/old/blx_gemm_front.c b/sandbox/old/ref99/old/blx_gemm_front.c similarity index 100% rename from sandbox/ref99/old/blx_gemm_front.c rename to sandbox/old/ref99/old/blx_gemm_front.c diff --git a/sandbox/ref99/old/blx_gemm_front.h b/sandbox/old/ref99/old/blx_gemm_front.h similarity index 100% rename from sandbox/ref99/old/blx_gemm_front.h rename to sandbox/old/ref99/old/blx_gemm_front.h diff --git a/sandbox/ref99/old/blx_gemm_int.c b/sandbox/old/ref99/old/blx_gemm_int.c similarity index 100% rename from sandbox/ref99/old/blx_gemm_int.c rename to sandbox/old/ref99/old/blx_gemm_int.c diff --git a/sandbox/ref99/old/blx_gemm_int.h b/sandbox/old/ref99/old/blx_gemm_int.h similarity index 100% rename from sandbox/ref99/old/blx_gemm_int.h rename to sandbox/old/ref99/old/blx_gemm_int.h diff --git a/sandbox/ref99/old/cntl/blx_gemm_cntl.c b/sandbox/old/ref99/old/cntl/blx_gemm_cntl.c similarity index 100% rename from sandbox/ref99/old/cntl/blx_gemm_cntl.c rename to sandbox/old/ref99/old/cntl/blx_gemm_cntl.c diff --git a/sandbox/ref99/old/cntl/blx_gemm_cntl.h b/sandbox/old/ref99/old/cntl/blx_gemm_cntl.h similarity index 100% rename from sandbox/ref99/old/cntl/blx_gemm_cntl.h rename to sandbox/old/ref99/old/cntl/blx_gemm_cntl.h diff --git a/sandbox/ref99/old/cntl/blx_l3_cntl_if.c b/sandbox/old/ref99/old/cntl/blx_l3_cntl_if.c similarity index 100% rename from sandbox/ref99/old/cntl/blx_l3_cntl_if.c rename to sandbox/old/ref99/old/cntl/blx_l3_cntl_if.c diff --git a/sandbox/ref99/old/cntl/blx_l3_cntl_if.h b/sandbox/old/ref99/old/cntl/blx_l3_cntl_if.h similarity index 100% rename from sandbox/ref99/old/cntl/blx_l3_cntl_if.h rename to sandbox/old/ref99/old/cntl/blx_l3_cntl_if.h diff --git a/sandbox/ref99/old/cntl/blx_packm_cntl.c b/sandbox/old/ref99/old/cntl/blx_packm_cntl.c similarity index 100% rename from sandbox/ref99/old/cntl/blx_packm_cntl.c rename to sandbox/old/ref99/old/cntl/blx_packm_cntl.c diff --git a/sandbox/ref99/old/cntl/blx_packm_cntl.h b/sandbox/old/ref99/old/cntl/blx_packm_cntl.h similarity index 100% rename from sandbox/ref99/old/cntl/blx_packm_cntl.h rename to sandbox/old/ref99/old/cntl/blx_packm_cntl.h diff --git a/sandbox/ref99/old/packm/blx_l3_packm.c b/sandbox/old/ref99/old/packm/blx_l3_packm.c similarity index 100% rename from sandbox/ref99/old/packm/blx_l3_packm.c rename to sandbox/old/ref99/old/packm/blx_l3_packm.c diff --git a/sandbox/ref99/old/packm/blx_l3_packm.h b/sandbox/old/ref99/old/packm/blx_l3_packm.h similarity index 100% rename from sandbox/ref99/old/packm/blx_l3_packm.h rename to sandbox/old/ref99/old/packm/blx_l3_packm.h diff --git a/sandbox/ref99/old/thread/blx_gemm_thread.c b/sandbox/old/ref99/old/thread/blx_gemm_thread.c similarity index 100% rename from sandbox/ref99/old/thread/blx_gemm_thread.c rename to sandbox/old/ref99/old/thread/blx_gemm_thread.c diff --git a/sandbox/ref99/old/thread/blx_gemm_thread.h b/sandbox/old/ref99/old/thread/blx_gemm_thread.h similarity index 100% rename from sandbox/ref99/old/thread/blx_gemm_thread.h rename to sandbox/old/ref99/old/thread/blx_gemm_thread.h diff --git a/sandbox/ref99/old/vars/blx_gemm_blk_var1.c b/sandbox/old/ref99/old/vars/blx_gemm_blk_var1.c similarity index 100% rename from sandbox/ref99/old/vars/blx_gemm_blk_var1.c rename to sandbox/old/ref99/old/vars/blx_gemm_blk_var1.c diff --git a/sandbox/ref99/old/vars/blx_gemm_blk_var2.c b/sandbox/old/ref99/old/vars/blx_gemm_blk_var2.c similarity index 100% rename from sandbox/ref99/old/vars/blx_gemm_blk_var2.c rename to sandbox/old/ref99/old/vars/blx_gemm_blk_var2.c diff --git a/sandbox/ref99/old/vars/blx_gemm_blk_var3.c b/sandbox/old/ref99/old/vars/blx_gemm_blk_var3.c similarity index 100% rename from sandbox/ref99/old/vars/blx_gemm_blk_var3.c rename to sandbox/old/ref99/old/vars/blx_gemm_blk_var3.c diff --git a/sandbox/ref99/old/vars/blx_gemm_ker_var2.c b/sandbox/old/ref99/old/vars/blx_gemm_ker_var2.c similarity index 100% rename from sandbox/ref99/old/vars/blx_gemm_ker_var2.c rename to sandbox/old/ref99/old/vars/blx_gemm_ker_var2.c diff --git a/sandbox/ref99/old/vars/blx_gemm_packab.c b/sandbox/old/ref99/old/vars/blx_gemm_packab.c similarity index 100% rename from sandbox/ref99/old/vars/blx_gemm_packab.c rename to sandbox/old/ref99/old/vars/blx_gemm_packab.c diff --git a/sandbox/ref99/old/vars/blx_gemm_var.h b/sandbox/old/ref99/old/vars/blx_gemm_var.h similarity index 100% rename from sandbox/ref99/old/vars/blx_gemm_var.h rename to sandbox/old/ref99/old/vars/blx_gemm_var.h diff --git a/sandbox/ref99/old/vars/other/blx_gemm_ker_var2rr.c b/sandbox/old/ref99/old/vars/other/blx_gemm_ker_var2rr.c similarity index 100% rename from sandbox/ref99/old/vars/other/blx_gemm_ker_var2rr.c rename to sandbox/old/ref99/old/vars/other/blx_gemm_ker_var2rr.c diff --git a/sandbox/ref99/old/vars/other/blx_gemm_ker_var2sl.c b/sandbox/old/ref99/old/vars/other/blx_gemm_ker_var2sl.c similarity index 100% rename from sandbox/ref99/old/vars/other/blx_gemm_ker_var2sl.c rename to sandbox/old/ref99/old/vars/other/blx_gemm_ker_var2sl.c diff --git a/sandbox/power10/bli_gemmnat.c b/sandbox/power10/bli_gemm_ex.c similarity index 61% rename from sandbox/power10/bli_gemmnat.c rename to sandbox/power10/bli_gemm_ex.c index 846ccd35a8..3334dc4a53 100644 --- a/sandbox/power10/bli_gemmnat.c +++ b/sandbox/power10/bli_gemm_ex.c @@ -32,47 +32,48 @@ */ -// Given the current architecture of BLIS sandboxes, bli_gemmnat() is the +// Given the current architecture of BLIS sandboxes, bli_gemm_ex() is the // entry point to any sandbox implementation. -// NOTE: This function is implemented identically to the function that it -// overrides in frame/ind/oapi/bli_l3_nat_oapi.c. This means that we are -// forgoing the option of customizing the implementations that underlie -// bli_gemm() and bli_?gemm(). Any new code defined in this sandbox -// directory, however, will be included in the BLIS. +// NOTE: This function is implemented functionally identically to the +// function that it overrides in frame/3/bli_l3_oapi_ex.c. This means that +// we are forgoing the option of customizing the implementations that +// underlie bli_gemm() and bli_?gemm() (which both call bli_gemm_ex()). +// Any new code defined in this sandbox directory, however, will be +// included in the BLIS. #include "blis.h" -#undef GENFRONT -#define GENFRONT( opname, cname, imeth ) \ -\ -void PASTEMAC(opname,imeth) \ - ( \ - obj_t* alpha, \ - obj_t* a, \ - obj_t* b, \ - obj_t* beta, \ - obj_t* c, \ - cntx_t* cntx, \ - rntm_t* rntm \ - ) \ -{ \ - bli_init_once(); \ -\ - /* Obtain a valid (native) context from the gks if necessary. */ \ - if ( cntx == NULL ) cntx = bli_gks_query_cntx(); \ -\ - /* Initialize a local runtime with global settings if necessary. Note - that in the case that a runtime is passed in, we make a local copy. */ \ - rntm_t rntm_l; \ - if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); rntm = &rntm_l; } \ - else { rntm_l = *rntm; rntm = &rntm_l; } \ -\ - /* Invoke the operation's front end. */ \ - PASTEMAC(opname,_front) \ - ( \ - alpha, a, b, beta, c, cntx, rntm, NULL \ - ); \ +void bli_gemm_ex + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ) +{ + bli_init_once(); + + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_l; + if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); rntm = &rntm_l; } + else { rntm_l = *rntm; rntm = &rntm_l; } + + // Obtain a valid (native) context from the gks if necessary. + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); + + // Check the operands. + if ( bli_error_checking_is_enabled() ) + bli_gemm_check( alpha, a, b, beta, c, cntx ); + + // Invoke the operation's front end. + bli_gemm_front + ( + alpha, a, b, beta, c, cntx, rntm, NULL + ); } -GENFRONT( gemm, gemm, nat ) diff --git a/so_version b/so_version index 549d6b8284..93c505bc70 100644 --- a/so_version +++ b/so_version @@ -1,2 +1,2 @@ -4 -2.0 +5 +0.0 diff --git a/test/1m4m/Makefile b/test/1m4m/Makefile index 86ea82eb77..3d4cf5ebf4 100644 --- a/test/1m4m/Makefile +++ b/test/1m4m/Makefile @@ -1,11 +1,11 @@ -#!/bin/bash +# # # BLIS # An object-based framework for developing high-performance BLAS-like # libraries. # # Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2018 - 2024, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -186,20 +186,10 @@ BLA_DEF := -DBLAS EIG_DEF := -DEIGEN # Complex implementation type -D3MHW := -DIND=BLIS_3MH -D3M1 := -DIND=BLIS_3M1 -D4MHW := -DIND=BLIS_4MH -D4M1B := -DIND=BLIS_4M1B -D4M1A := -DIND=BLIS_4M1A D1M := -DIND=BLIS_1M DNAT := -DIND=BLIS_NAT # Implementation string -#STR_3MHW := -DSTR=\"3mhw\" -#STR_3M1 := -DSTR=\"3m1\" -#STR_4MHW := -DSTR=\"4mhw\" -#STR_4M1B := -DSTR=\"4m1b\" -STR_4M1A := -DSTR=\"4m1a_blis\" STR_1M := -DSTR=\"1m_blis\" STR_NAT := -DSTR=\"asm_blis\" STR_OBL := -DSTR=\"openblas\" @@ -234,19 +224,18 @@ all-st: blis-st openblas-st mkl-st all-1s: blis-1s openblas-1s mkl-1s all-2s: blis-2s openblas-2s mkl-2s -blis-st: blis-nat-st blis-1m-st blis-4m1a-st -blis-1s: blis-nat-1s blis-1m-1s blis-4m1a-1s -blis-2s: blis-nat-2s blis-1m-2s blis-4m1a-2s +blis-st: blis-nat-st blis-1m-st +blis-1s: blis-nat-1s blis-1m-1s +blis-2s: blis-nat-2s blis-1m-2s #blis-ind: blis-ind-st blis-ind-mt blis-nat: blis-nat-st blis-nat-1s blis-nat-2s blis-1m: blis-1m-st blis-1m-1s blis-1m-2s -blis-4m1a: blis-4m1a-st blis-4m1a-1s blis-4m1a-2s # Define the datatypes, operations, and implementations. DTS := s d c z OPS := gemm -BIMPLS := asm_blis 4m1a_blis 1m_blis openblas vendor +BIMPLS := asm_blis 1m_blis openblas vendor EIMPLS := eigen # Define functions to construct object filenames from the datatypes and @@ -265,13 +254,6 @@ BLIS_1M_1S_BINS := $(patsubst %.o,%.x,$(BLIS_1M_1S_OBJS)) BLIS_1M_2S_OBJS := $(call get-2s-objs,1m_blis) BLIS_1M_2S_BINS := $(patsubst %.o,%.x,$(BLIS_1M_2S_OBJS)) -BLIS_4M1A_ST_OBJS := $(call get-st-objs,4m1a_blis) -BLIS_4M1A_ST_BINS := $(patsubst %.o,%.x,$(BLIS_4M1A_ST_OBJS)) -BLIS_4M1A_1S_OBJS := $(call get-1s-objs,4m1a_blis) -BLIS_4M1A_1S_BINS := $(patsubst %.o,%.x,$(BLIS_4M1A_1S_OBJS)) -BLIS_4M1A_2S_OBJS := $(call get-2s-objs,4m1a_blis) -BLIS_4M1A_2S_BINS := $(patsubst %.o,%.x,$(BLIS_4M1A_2S_OBJS)) - BLIS_NAT_ST_OBJS := $(call get-st-objs,asm_blis) BLIS_NAT_ST_BINS := $(patsubst %.o,%.x,$(BLIS_NAT_ST_OBJS)) BLIS_NAT_1S_OBJS := $(call get-1s-objs,asm_blis) @@ -309,10 +291,6 @@ blis-1m-st: $(BLIS_1M_ST_BINS) blis-1m-1s: $(BLIS_1M_1S_BINS) blis-1m-2s: $(BLIS_1M_2S_BINS) -blis-4m1a-st: $(BLIS_4M1A_ST_BINS) -blis-4m1a-1s: $(BLIS_4M1A_1S_BINS) -blis-4m1a-2s: $(BLIS_4M1A_2S_BINS) - openblas-st: $(OPENBLAS_ST_BINS) openblas-1s: $(OPENBLAS_1S_BINS) openblas-2s: $(OPENBLAS_2S_BINS) @@ -337,7 +315,6 @@ armpl-2s: vendor-2s # automatically after building the binaries on which they depend. .INTERMEDIATE: $(BLIS_NAT_ST_OBJS) $(BLIS_NAT_1S_OBJS) $(BLIS_NAT_2S_OBJS) .INTERMEDIATE: $(BLIS_1M_ST_OBJS) $(BLIS_1M_1S_OBJS) $(BLIS_1M_2S_OBJS) -.INTERMEDIATE: $(BLIS_4M1A_ST_OBJS) $(BLIS_4M1A_1S_OBJS) $(BLIS_4M1A_2S_OBJS) .INTERMEDIATE: $(OPENBLAS_ST_OBJS) $(OPENBLAS_1S_OBJS) $(OPENBLAS_2S_OBJS) .INTERMEDIATE: $(EIGEN_ST_OBJS) $(EIGEN_1S_OBJS) $(EIGEN_2S_OBJS) .INTERMEDIATE: $(VENDOR_ST_OBJS) $(VENDOR_1S_OBJS) $(VENDOR_2S_OBJS) @@ -358,8 +335,7 @@ get-dt-cpp = $(strip \ get-in-cpp = $(strip \ $(if $(findstring 1m_blis,$(1)),-DIND=BLIS_1M,\ - $(if $(findstring 4m1a_blis,$(1)),-DIND=BLIS_4M1A,\ - -DIND=BLIS_NAT))) + -DIND=BLIS_NAT)) # A function to return other cpp macros that help the test driver # identify the implementation. @@ -371,7 +347,6 @@ get-in-cpp = $(strip \ get-bl-cpp = $(strip \ $(if $(findstring 1m_blis,$(1)),$(STR_1M) $(BLI_DEF),\ - $(if $(findstring 4m1a_blis,$(1)),$(STR_4M1A) $(BLI_DEF),\ $(if $(findstring asm_blis,$(1)),$(STR_NAT) $(BLI_DEF),\ $(if $(findstring openblas,$(1)),$(STR_OBL) $(BLA_DEF),\ $(if $(and $(findstring eigen,$(1)),\ @@ -379,7 +354,7 @@ get-bl-cpp = $(strip \ $(STR_EIG) $(EIG_DEF),\ $(if $(findstring eigen,$(1)),\ $(STR_EIG) $(BLA_DEF),\ - $(STR_VEN) $(BLA_DEF)))))))) + $(STR_VEN) $(BLA_DEF))))))) # Rules for BLIS and BLAS libraries. @@ -456,16 +431,6 @@ test_%_$(P2_MAX)_1m_blis_2s.x: test_%_$(P2_MAX)_1m_blis_2s.o $(LIBBLIS_LINK) $(CC) $(strip $< $(LIBBLIS_LINK) $(LDFLAGS) -o $@) -test_%_$(PS_MAX)_4m1a_blis_st.x: test_%_$(PS_MAX)_4m1a_blis_st.o $(LIBBLIS_LINK) - $(CC) $(strip $< $(LIBBLIS_LINK) $(LDFLAGS) -o $@) - -test_%_$(P1_MAX)_4m1a_blis_1s.x: test_%_$(P1_MAX)_4m1a_blis_1s.o $(LIBBLIS_LINK) - $(CC) $(strip $< $(LIBBLIS_LINK) $(LDFLAGS) -o $@) - -test_%_$(P2_MAX)_4m1a_blis_2s.x: test_%_$(P2_MAX)_4m1a_blis_2s.o $(LIBBLIS_LINK) - $(CC) $(strip $< $(LIBBLIS_LINK) $(LDFLAGS) -o $@) - - test_%_$(PS_MAX)_asm_blis_st.x: test_%_$(PS_MAX)_asm_blis_st.o $(LIBBLIS_LINK) $(CC) $(strip $< $(LIBBLIS_LINK) $(LDFLAGS) -o $@) diff --git a/test/1m4m/runme.sh b/test/1m4m/runme.sh index 881cf4776d..38236f64a7 100755 --- a/test/1m4m/runme.sh +++ b/test/1m4m/runme.sh @@ -80,11 +80,10 @@ test_dts="s d c z" test_ops="gemm" # Implementations to test. -#test_impls="openblas vendor asm_blis 1m_blis 4m1a_blis" -#test_impls="asm_blis 1m_blis 4m1a_blis" +#test_impls="openblas vendor asm_blis 1m_blis" +#test_impls="asm_blis 1m_blis" #test_impls="asm_blis" -#test_impls="4m1a_blis" -test_impls="asm_blis 4m1a_blis 1m_blis" +test_impls="asm_blis 1m_blis" # Save a copy of GOMP_CPU_AFFINITY so that if we have to unset it, we can # restore the value. @@ -136,7 +135,7 @@ for th in ${threads}; do for im in ${test_impls}; do if [ "${dt}" = "s" -o "${dt}" = "d" ] && \ - [ "${im}" = "1m_blis" -o "${im}" = "4m1a_blis" ]; then + [ "${im}" = "1m_blis" ]; then continue fi @@ -164,8 +163,7 @@ for th in ${threads}; do # Set the threading parameters based on the implementation # that we are preparing to run. if [ "${im}" = "asm_blis" ] || \ - [ "${im}" = "1m_blis" ] || \ - [ "${im}" = "4m1a_blis" ]; then + [ "${im}" = "1m_blis" ]; then unset OMP_NUM_THREADS export BLIS_JC_NT=${jc_nt} export BLIS_PC_NT=${pc_nt} diff --git a/test/1m4m/test_gemm.c b/test/1m4m/test_gemm.c index a58e6e5893..f9a855125f 100644 --- a/test/1m4m/test_gemm.c +++ b/test/1m4m/test_gemm.c @@ -108,9 +108,6 @@ int main( int argc, char** argv ) ind_t ind_mod = ind; - // A hack to use 3m1 as 1mpb (with 1m as 1mbp). - if ( ind == BLIS_3M1 ) ind_mod = BLIS_1M; - // Initialize a context for the current induced method and datatype. cntx = bli_gks_query_ind_cntx( ind_mod, dt ); @@ -120,8 +117,7 @@ int main( int argc, char** argv ) #elif 0 #ifdef BLIS - if ( ind == BLIS_4M1A ) k_input = 128; - else if ( ind == BLIS_1M ) k_input = 128; + if ( ind == BLIS_1M ) k_input = 128; else k_input = 256; #else k_input = 192; diff --git a/test/3/Makefile b/test/3/Makefile index 2d865c19e7..274eb2105e 100644 --- a/test/3/Makefile +++ b/test/3/Makefile @@ -1,11 +1,11 @@ -#!/bin/bash +# # # BLIS # An object-based framework for developing high-performance BLAS-like # libraries. # # Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2018 - 2024, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -195,20 +195,10 @@ BLA_DEF := -DBLAS EIG_DEF := -DEIGEN # Complex implementation type -D3MHW := -DIND=BLIS_3MH -D3M1 := -DIND=BLIS_3M1 -D4MHW := -DIND=BLIS_4MH -D4M1B := -DIND=BLIS_4M1B -D4M1A := -DIND=BLIS_4M1A D1M := -DIND=BLIS_1M DNAT := -DIND=BLIS_NAT # Implementation string -#STR_3MHW := -DSTR=\"3mhw\" -#STR_3M1 := -DSTR=\"3m1\" -#STR_4MHW := -DSTR=\"4mhw\" -#STR_4M1B := -DSTR=\"4m1b\" -#STR_4M1A := -DSTR=\"4m1a\" #STR_1M := -DSTR=\"1m\" STR_NAT := -DSTR=\"asm_blis\" STR_OBL := -DSTR=\"openblas\" diff --git a/test/3/test_gemm.c b/test/3/test_gemm.c index eb1a732648..344a08cd78 100644 --- a/test/3/test_gemm.c +++ b/test/3/test_gemm.c @@ -108,9 +108,6 @@ int main( int argc, char** argv ) ind_t ind_mod = ind; - // A hack to use 3m1 as 1mpb (with 1m as 1mbp). - if ( ind == BLIS_3M1 ) ind_mod = BLIS_1M; - // Initialize a context for the current induced method and datatype. cntx = bli_gks_query_ind_cntx( ind_mod, dt ); diff --git a/test/3/test_hemm.c b/test/3/test_hemm.c index e69a1ec574..8df46f0f01 100644 --- a/test/3/test_hemm.c +++ b/test/3/test_hemm.c @@ -86,9 +86,6 @@ int main( int argc, char** argv ) ind_t ind_mod = ind; - // A hack to use 3m1 as 1mpb (with 1m as 1mbp). - if ( ind == BLIS_3M1 ) ind_mod = BLIS_1M; - // Initialize a context for the current induced method and datatype. cntx = bli_gks_query_ind_cntx( ind_mod, dt ); diff --git a/test/3/test_herk.c b/test/3/test_herk.c index d2b51f9d4e..ab8686622c 100644 --- a/test/3/test_herk.c +++ b/test/3/test_herk.c @@ -88,9 +88,6 @@ int main( int argc, char** argv ) ind_t ind_mod = ind; - // A hack to use 3m1 as 1mpb (with 1m as 1mbp). - if ( ind == BLIS_3M1 ) ind_mod = BLIS_1M; - // Initialize a context for the current induced method and datatype. cntx = bli_gks_query_ind_cntx( ind_mod, dt ); diff --git a/test/3/test_trmm.c b/test/3/test_trmm.c index e70330ee52..4cd84e0ab8 100644 --- a/test/3/test_trmm.c +++ b/test/3/test_trmm.c @@ -91,9 +91,6 @@ int main( int argc, char** argv ) ind_t ind_mod = ind; - // A hack to use 3m1 as 1mpb (with 1m as 1mbp). - if ( ind == BLIS_3M1 ) ind_mod = BLIS_1M; - // Initialize a context for the current induced method and datatype. cntx = bli_gks_query_ind_cntx( ind_mod, dt ); diff --git a/test/3/test_trsm.c b/test/3/test_trsm.c index 8078dc0ad6..4e449ff22b 100644 --- a/test/3/test_trsm.c +++ b/test/3/test_trsm.c @@ -91,9 +91,6 @@ int main( int argc, char** argv ) ind_t ind_mod = ind; - // A hack to use 3m1 as 1mpb (with 1m as 1mbp). - if ( ind == BLIS_3M1 ) ind_mod = BLIS_1M; - // Initialize a context for the current induced method and datatype. cntx = bli_gks_query_ind_cntx( ind_mod, dt ); diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index d116e942d0..bbf30fc963 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -1,173 +1,142 @@ -##Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.## - -add_definitions(-DBLAS="AOCL") - -add_executable(TestAminv test_aminv.c) -target_link_libraries(TestAminv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestAminv OpenMP::OpenMP_CXX) -endif() -target_link_libraries(TestAminv optimized "${LIB_NAME}.lib") - -add_executable(TestAxpyv test_axpyv.c) -target_link_libraries(TestAxpyv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestAxpyv OpenMP::OpenMP_CXX) -endif() -target_link_libraries(TestAxpyv optimized "${LIB_NAME}.lib") - -add_executable(TestAxpbyv test_axpbyv.c) -target_link_libraries(TestAxpbyv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestAxpbyv OpenMP::OpenMP_CXX) -endif() -target_link_libraries(TestAxpbyv optimized "${LIB_NAME}.lib") - -add_executable(TestCopyv test_copyv.c) -target_link_libraries(TestCopyv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestCopyv OpenMP::OpenMP_CXX) -endif() -target_link_libraries(TestCopyv optimized "${LIB_NAME}.lib") - -add_executable(TestCabs1 test_cabs1.c) -target_link_libraries(TestCabs1 debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestCabs1 OpenMP::OpenMP_CXX) -endif() -target_link_libraries(TestCabs1 optimized "${LIB_NAME}.lib") - -add_executable(TestDotv test_dotv.c) -target_link_libraries(TestDotv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestDotv OpenMP::OpenMP_CXX) -endif() -target_link_libraries(TestDotv optimized "${LIB_NAME}.lib") - -add_executable(TestGemm test_gemm.c) -target_link_libraries(TestGemm debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestGemm OpenMP::OpenMP_CXX) -endif() -target_link_libraries(TestGemm optimized "${LIB_NAME}.lib") - -add_executable(TestGemmBatch test_gemm_batch.c) -target_link_libraries(TestGemmBatch debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestGemmBatch OpenMP::OpenMP_CXX) -endif() -target_link_libraries(TestGemmBatch optimized "${LIB_NAME}.lib") - -add_executable(TestGemm3m test_gemm3m.c) -target_link_libraries(TestGemm3m debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestGemm3m OpenMP::OpenMP_CXX) -endif() -target_link_libraries(TestGemm3m optimized "${LIB_NAME}.lib") - -add_executable(TestGemmt test_gemmt.c) -target_link_libraries(TestGemmt debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestGemmt OpenMP::OpenMP_CXX) -endif() -target_link_libraries(TestGemmt optimized "${LIB_NAME}.lib") - -add_executable(TestGemv test_gemv.c) -target_link_libraries(TestGemv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestGemv OpenMP::OpenMP_CXX) -endif() -target_link_libraries(TestGemv optimized "${LIB_NAME}.lib") - -add_executable(TestGer test_ger.c) -target_link_libraries(TestGer debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestGer OpenMP::OpenMP_CXX) -endif() -target_link_libraries(TestGer optimized "${LIB_NAME}.lib") - -add_executable(TestHemm test_hemm.c) -target_link_libraries(TestHemm debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestHemm OpenMP::OpenMP_CXX) -endif() -target_link_libraries(TestHemm optimized "${LIB_NAME}.lib") - -add_executable(TestHemv test_hemv.c) -target_link_libraries(TestHemv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestHemv OpenMP::OpenMP_CXX) -endif() -target_link_libraries(TestHemv optimized "${LIB_NAME}.lib") - -add_executable(TestHer test_her.c) -target_link_libraries(TestHer debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestHer OpenMP::OpenMP_CXX) -endif() -target_link_libraries(TestHer optimized "${LIB_NAME}.lib") - -add_executable(TestHer2 test_her2.c) -target_link_libraries(TestHer2 debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestHer2 OpenMP::OpenMP_CXX) -endif() -target_link_libraries(TestHer2 optimized "${LIB_NAME}.lib") - -add_executable(TestHer2k test_her2k.c) -target_link_libraries(TestHer2k debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestHer2k OpenMP::OpenMP_CXX) -endif() -target_link_libraries(TestHer2k optimized "${LIB_NAME}.lib") - -add_executable(TestHerk test_herk.c) -target_link_libraries(TestHerk debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestHerk OpenMP::OpenMP_CXX) -endif() -target_link_libraries(TestHerk optimized "${LIB_NAME}.lib") - -add_executable(TestScalv test_scalv.c) -target_link_libraries(TestScalv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestScalv OpenMP::OpenMP_CXX) -endif() -target_link_libraries(TestScalv optimized "${LIB_NAME}.lib") - -add_executable(TestSwapv test_swapv.c) -target_link_libraries(TestSwapv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestSwapv OpenMP::OpenMP_CXX) -endif() -target_link_libraries(TestSwapv optimized "${LIB_NAME}.lib") - -add_executable(TestTrmm test_trmm.c) -target_link_libraries(TestTrmm debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestTrmm OpenMP::OpenMP_CXX) -endif() -target_link_libraries(TestTrmm optimized "${LIB_NAME}.lib") - -add_executable(TestTrmv test_trmv.c) -target_link_libraries(TestTrmv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestTrmv OpenMP::OpenMP_CXX) -endif() -target_link_libraries(TestTrmv optimized "${LIB_NAME}.lib") - -add_executable(TestTrsm test_trsm.c) -target_link_libraries(TestTrsm debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestTrsm OpenMP::OpenMP_CXX) -endif() -target_link_libraries(TestTrsm optimized "${LIB_NAME}.lib") - -add_executable(TestTrsv test_trsv.c) -target_link_libraries(TestTrsv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestTrsv OpenMP::OpenMP_CXX) -endif() -target_link_libraries(TestTrsv optimized "${LIB_NAME}.lib") - - +##Copyright (C) 2022-2024, Advanced Micro Devices, Inc. All rights reserved.## +# Comments: +# Set the path to the BLIS installation. +set(BLIS_INSTALL_PATH "" CACHE STRING "Setting the path to a BLIS installation that needs testing.") +if(BLIS_INSTALL_PATH) + message(STATUS "BLIS_INSTALL_PATH :" ${BLIS_INSTALL_PATH}) +endif() + +# - DIST_PATH is assumed to not exist if BLIS_INSTALL_PATH is given. +# - We must use recursively expanded assignment for LIB_PATH and INC_PATH in +# the second case because CONFIG_NAME is not yet set. +# Override the value of CINCFLAGS so that the value of CFLAGS returned by +# get-user-cflags-for() is not cluttered up with include paths needed only +# while building BLIS. + +#if(NOT DEFINED BLIS_INSTALL_PATH) +if(BLIS_INSTALL_PATH STREQUAL "") + set(DIST_PATH ${CMAKE_BINARY_DIR}) + set(LIB_PATH ${DIST_PATH}/lib/${BLIS_CONFIG_FAMILY}) + set(INC_PATH ${DIST_PATH}/include/${BLIS_CONFIG_FAMILY}) + set(CINFLAGS ${INC_PATH}) + set(LIBBLIS ${libblis_link}) +else() + set(LIB_PATH ${BLIS_INSTALL_PATH}/lib) + set(INC_PATH ${BLIS_INSTALL_PATH}/include) + set(CINFLAGS ${INC_PATH}) + # Set up the library name. + if(WIN32) + set(LIB_BLIS AOCL-LibBlis-Win) + else() + set(LIB_BLIS ${libblis_link}) + endif() + # Append if threading is required. + if(NOT (ENABLE_THREADING STREQUAL "no")) + if(WIN32) + string(APPEND LIB_BLIS -MT) + else() + string(APPEND LIB_BLIS -mt) + endif() + endif() + # Append for dll if necessary. + if(WIN32 AND BUILD_SHARED_LIBS) + string(APPEND LIB_BLIS -dll) + endif() + # Setting the suffix for find_library(). + if(WIN32) + string(APPEND LIB_BLIS .lib) + else() + if(BUILD_SHARED_LIBS) + string(APPEND LIB_BLIS .so) + else() + string(APPEND LIB_BLIS .a) + endif() + endif() + set(LIBBLIS ${LIB_PATH}/${LIB_BLIS}) + message(STATUS "BLIS_INSTALL_PATH : " ${LIBBLIS}) +endif() + +if(WIN32) + set(LIBSUFFIX lib) +else() + set(LIBSUFFIX so) +endif() +set(CMAKE_EXECUTABLE_SUFFIX ".x") +set(MKL_PATH $ENV{MKLROOT} CACHE STRING "Set MKL_PATH.") +if(WIN32) + set(mkllib "${MKL_PATH}\\mkl_rt.lib" CACHE STRING "Set MKL_PATH.") +else() + set(mkllib "${MKL_PATH}/libmkl_rt.so" CACHE STRING "Set MKL_PATH.") +endif() +set(MKL_LIB ${mkllib}) +set(OPENBLAS_PATH "/home/amd/mylibs/openblas" CACHE STRING "Set OPENBLAS_PATH.") +set(OPENBLAS_LIB "${OPENBLAS_PATH}/libopenblas.${LIBSUFFIX}") + + +# Include the corresponding make_defs.cmake that holds the required compiler options. +include(${CMAKE_SOURCE_DIR}/config/${BLIS_CONFIG_FAMILY}/make_defs.cmake) + +# Gather all local source files. +file(GLOB file_list LIST_DIRECTORIES false RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}/" "*.c") + +# Create an executable using the sources above. +function(testexe extn) + set(dblas "aocl") + if(extn STREQUAL "mkl") + set(BLAS_LIBS ${MKL_LIB}) + set(dblas ${extn}) + elseif(extn STREQUAL "openblas") + set(BLAS_LIBS ${OPENBLAS_LIB}) + set(dblas ${extn}) + endif() + set(TEST_FLAGS -DBLAS="${dblas}") + foreach(src ${file_list}) + string(REGEX REPLACE ".c$" "" exec_name ${src}) + set(exec_name "${exec_name}_${extn}") + add_executable(${exec_name} ${src}) + target_compile_options(${exec_name} + PRIVATE + # load-var-for,COPTFLAGS + ${COPTFLAGS} + ) + if(WIN32 AND BUILD_SHARED_LIBS) + target_compile_definitions(${exec_name} + PRIVATE + # in get-noopt-cflags-for + ${VERS_DEF} + "-DBLIS_EXPORT=__declspec(dllimport)" + ${TEST_FLAGS} + ) + else() + target_compile_definitions(${exec_name} + PRIVATE + # in get-noopt-cflags-for + ${VERS_DEF} + ${TEST_FLAGS} + ) + endif() + target_include_directories(${exec_name} + BEFORE + PRIVATE + # in get-noopt-cflags-for + ${CINFLAGS} + ) + target_link_libraries(${exec_name} PRIVATE ${BLAS_LIBS} ${LIBBLIS} ${LDFLAGS}) + if(THREADING_MODEL STREQUAL "openmp") + target_link_libraries(${exec_name} PRIVATE OpenMP::OpenMP_C) + endif() + list(APPEND temp_executables ${exec_name}) + endforeach() + set(test_executables ${temp_executables} PARENT_SCOPE) +endfunction() + +testexe("blas") +add_custom_target(test_blis DEPENDS ${test_executables}) +testexe("mkl") +add_custom_target(test_mkl DEPENDS ${test_executables}) +testexe("openblas") +add_custom_target(test_openblas DEPENDS ${test_executables}) +add_custom_target(testall DEPENDS test_blis test_mkl test_openblas) + +# Put all those targets under test-targets folder name so that they appear all together in IDE. +set_target_properties(testall test_blis test_mkl test_openblas PROPERTIES FOLDER test-targets) \ No newline at end of file diff --git a/test/Makefile b/test/Makefile index 5d1958b876..ee6540d000 100644 --- a/test/Makefile +++ b/test/Makefile @@ -1,4 +1,3 @@ - # # # BLIS @@ -6,7 +5,7 @@ # libraries. # # Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2017 - 2020, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2017 - 2024, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are diff --git a/test/exec_sizes/Makefile b/test/exec_sizes/Makefile index eefc899186..c11d9d7995 100644 --- a/test/exec_sizes/Makefile +++ b/test/exec_sizes/Makefile @@ -1,4 +1,4 @@ -#!/bin/bash +# # # BLIS # An object-based framework for developing high-performance BLAS-like diff --git a/test/mixeddt/Makefile b/test/mixeddt/Makefile index 20e5378ffb..8e6e055277 100644 --- a/test/mixeddt/Makefile +++ b/test/mixeddt/Makefile @@ -1,4 +1,4 @@ -#!/bin/bash +# # # BLIS # An object-based framework for developing high-performance BLAS-like diff --git a/test/studies/skx/Makefile b/test/studies/skx/Makefile index 18a82c0ea2..d8d5a43ce8 100644 --- a/test/studies/skx/Makefile +++ b/test/studies/skx/Makefile @@ -1,4 +1,4 @@ -#!/bin/bash +# # # BLIS # An object-based framework for developing high-performance BLAS-like diff --git a/test/studies/thunderx2/Makefile b/test/studies/thunderx2/Makefile index ba45ebbe4d..50dbc0ffed 100644 --- a/test/studies/thunderx2/Makefile +++ b/test/studies/thunderx2/Makefile @@ -1,4 +1,4 @@ -#!/bin/bash +# # # BLIS # An object-based framework for developing high-performance BLAS-like diff --git a/test/sup/Makefile b/test/sup/Makefile index 6ee9f3ed1b..d1359359b1 100644 --- a/test/sup/Makefile +++ b/test/sup/Makefile @@ -1,11 +1,11 @@ -#!/bin/bash +# # # BLIS # An object-based framework for developing high-performance BLAS-like # libraries. # # Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2019 - 2024, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are diff --git a/test/sup/old/supmt/Makefile b/test/sup/old/supmt/Makefile index ad12b83e1a..0d77ed1d42 100644 --- a/test/sup/old/supmt/Makefile +++ b/test/sup/old/supmt/Makefile @@ -1,11 +1,11 @@ -#!/bin/bash +# # # BLIS # An object-based framework for developing high-performance BLAS-like # libraries. # # Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2019 - 2024, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are diff --git a/test/sup/old/supst/Makefile b/test/sup/old/supst/Makefile index c3eb0b5317..991618b99e 100644 --- a/test/sup/old/supst/Makefile +++ b/test/sup/old/supst/Makefile @@ -1,11 +1,11 @@ -#!/bin/bash +# # # BLIS # An object-based framework for developing high-performance BLAS-like # libraries. # # Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2019 - 2024, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are diff --git a/test/thread_ranges/Makefile b/test/thread_ranges/Makefile index 5af2ce533c..adb6c9f438 100644 --- a/test/thread_ranges/Makefile +++ b/test/thread_ranges/Makefile @@ -1,4 +1,4 @@ -#!/bin/bash +# # # BLIS # An object-based framework for developing high-performance BLAS-like diff --git a/testsuite/CMakeLists.txt b/testsuite/CMakeLists.txt index c0f96e621d..577aaec1ed 100644 --- a/testsuite/CMakeLists.txt +++ b/testsuite/CMakeLists.txt @@ -1,4 +1,36 @@ -##Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved.## +#[=[ + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +]=] # Comments: # - DIST_PATH is assumed to not exist if BLIS_INSTALL_PATH is given. @@ -10,7 +42,7 @@ if(NOT DEFINED BLIS_INSTALL_PATH) set(INC_PATH ${DIST_PATH}/include/${BLIS_CONFIG_FAMILY}) else() set(LIB_PATH ${BLIS_INSTALL_PATH}/lib) - set(INC_PATH ${BLIS_INSTALL_PATH}/include/blis) + set(INC_PATH ${BLIS_INSTALL_PATH}/include/${BLIS_CONFIG_FAMILY}) endif() # Include the corresponding make_defs.cmake that holds the required compiler options. @@ -23,6 +55,13 @@ file(GLOB testsuite_sources LIST_DIRECTORIES false ${CMAKE_CURRENT_SOURCE_DIR}/s # get-user-cflags-for() is not cluttered up with include paths needed only # while building BLIS. set(CINFLAGS ${INC_PATH}) +if((NOT WIN32) AND ENABLE_COVERAGE) + include(coverage.cmake) + list(APPEND LDFLAGS ${COVERAGE_FLAGS}) +endif() +if((NOT WIN32) AND ENABLE_ASAN) + list(APPEND LDFLAGS ${ASAN_FLAGS}) +endif() # Create an executable using the sources above. add_executable(test_libblis.x ${testsuite_sources}) @@ -61,16 +100,20 @@ target_include_directories(test_libblis.x # Add local header paths ${CMAKE_CURRENT_SOURCE_DIR}/src ) -target_link_libraries(test_libblis.x PRIVATE ${LDFLAGS} libblis) +target_link_libraries(test_libblis.x PRIVATE ${libblis_link} ${LDFLAGS}) if(THREADING_MODEL STREQUAL "openmp") - target_link_libraries(test_libblis.x PRIVATE OpenMP::OpenMP_C) + if((NOT ${OpenMP_libomp_LIBRARY} STREQUAL "") AND (NOT WIN32)) + target_link_libraries(test_libblis.x PRIVATE ${OpenMP_libomp_LIBRARY}) + else() + target_link_libraries(test_libblis.x PRIVATE OpenMP::OpenMP_C) + endif() endif() # -- Test run/check rules -- # Wrap the creation of testing helpers in this function. function(add_testblis flavour) if (NOT(flavour STREQUAL "")) - set(dotflavour .${flavour}) + set(dotflavour .${flavour}) set(dashflavour -${flavour}) set(printflavour "(${flavour})") endif() @@ -80,14 +123,14 @@ function(add_testblis flavour) COMMENT "Running test_libblis.x ${printflavour} with output redirected to ${CMAKE_CURRENT_BINARY_DIR}/output.testsuite${dotflavour}" DEPENDS test_libblis.x ${CMAKE_CURRENT_SOURCE_DIR}/input.general${dotflavour} ${CMAKE_CURRENT_SOURCE_DIR}/input.operations${dotflavour} BYPRODUCTS ${CMAKE_CURRENT_BINARY_DIR}/output.testsuite${dotflavour} - WORKING_DIRECTORY $ + WORKING_DIRECTORY $ VERBATIM ) # Check the results of the BLIS testsuite. add_custom_target(checkblis${dashflavour} COMMAND ${Python_EXECUTABLE} ${CMAKE_SOURCE_DIR}/build/cmake/check-blistest.py ${CMAKE_CURRENT_BINARY_DIR}/output.testsuite${dotflavour} DEPENDS testblis${dashflavour} - ) + ) endfunction() # Add testing targets using functions above for all input file options. diff --git a/testsuite/coverage.cmake b/testsuite/coverage.cmake new file mode 100644 index 0000000000..eb5fb39315 --- /dev/null +++ b/testsuite/coverage.cmake @@ -0,0 +1,84 @@ +#[=[ + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +]=] + +# Comments: + +find_program(LCOV NAMES lcov HINTS "/usr" PATH_SUFFIXES "bin" DOC "lcov - a graphical GCOV front-end" REQUIRED) +find_program(GCOV NAMES $ENV{GCOV_NAME} gcov HINTS "/usr" PATH_SUFFIXES "bin" DOC "GNU gcov binary" REQUIRED) +find_program(GENHTML NAMES genhtml HINTS "/usr" PATH_SUFFIXES "bin" DOC "genhtml - Generate HTML view from LCOV coverage data files" REQUIRED) + +if(NOT (LCOV AND GCOV) ) + message(FATAL_ERROR "locv or gcov not found! Aborting...") +endif() + +set(LCOV_FILTERS "'/usr/*';'/*/_deps/*';'/*/boost/*'") +set(LCOV_FLAGS "--rc;lcov_branch_coverage=1") +set(GENHTML_FLAGS "--branch-coverage;--rc;genhtml_med_limit=80;--rc;genhtml_hi_limit=95;--legend") + +message( STATUS "Code Coverage Module (LCOV)" ) + +add_custom_target( coverage-clean + COMMAND ${CMAKE_COMMAND} -E rm -rf coverage/ + COMMAND find . -name *.gcda -exec rm -v {} \; + WORKING_DIRECTORY ${PROJECT_BINARY_DIR} + COMMENT "Cleaning coverage related files" + VERBATIM +) + +add_custom_target( coverage-run + COMMAND ${CMAKE_MAKE_PROGRAM} coverage-clean + DEPENDS test_libblis.x + COMMAND test_libblis.x -g ${CMAKE_CURRENT_SOURCE_DIR}/input.general -o ${CMAKE_CURRENT_SOURCE_DIR}/input.operations > ${CMAKE_CURRENT_BINARY_DIR}/output.testsuite + COMMENT "Code Coverage takes some time : Running test_libblis.x with output redirected to ${CMAKE_CURRENT_BINARY_DIR}/output.testsuite" + WORKING_DIRECTORY ${PROJECT_BINARY_DIR} + VERBATIM +) + +add_custom_target( coverage-report + COMMAND ${CMAKE_MAKE_PROGRAM} coverage-run + COMMAND ${CMAKE_COMMAND} -E make_directory coverage/ + COMMAND ${LCOV} ${LCOV_FLAGS} -d .. -c -o coverage/coverage.info --gcov-tool ${GCOV} + COMMAND ${LCOV} ${LCOV_FLAGS} --remove coverage/coverage.info --gcov-tool ${GCOV} -o coverage/coverage_filtered.info ${LCOV_FILTERS} + COMMAND ${GENHTML} ${GENHTML_FLAGS} coverage/coverage_filtered.info --output coverage/html --title "AOCL-BLAS Code Coverage Report" + COMMENT "Building Code Coverage Report (LCOV)" + WORKING_DIRECTORY ${PROJECT_BINARY_DIR} + VERBATIM +) + +# Alias (only Makefile/Linux) +add_custom_target( coverage + DEPENDS coverage-report + WORKING_DIRECTORY ${PROJECT_BINARY_DIR} + VERBATIM +) diff --git a/testsuite/input.general b/testsuite/input.general index 7728402241..ae0d73b110 100644 --- a/testsuite/input.general +++ b/testsuite/input.general @@ -31,11 +31,6 @@ sdcz # Datatype(s) to test: 500 # Problem size: maximum to test 100 # Problem size: increment between experiments # Complex level-3 implementations to test: -0 # 3mh ('1' = enable; '0' = disable) -0 # 3m1 ('1' = enable; '0' = disable) -0 # 4mh ('1' = enable; '0' = disable) -0 # 4m1b ('1' = enable; '0' = disable) -0 # 4m1a ('1' = enable; '0' = disable) 1 # 1m ('1' = enable; '0' = disable) 1 # native ('1' = enable; '0' = disable) 1 # Simulate application-level threading: diff --git a/testsuite/input.general.fast b/testsuite/input.general.fast index 02b30b897d..06a89d16d9 100644 --- a/testsuite/input.general.fast +++ b/testsuite/input.general.fast @@ -31,12 +31,7 @@ sdcz # Datatype(s) to test: 100 # Problem size: maximum to test 100 # Problem size: increment between experiments # Complex level-3 implementations to test: -0 # 3mh ('1' = enable; '0' = disable) -0 # 3m1 ('1' = enable; '0' = disable) -0 # 4mh ('1' = enable; '0' = disable) -0 # 4m1b ('1' = enable; '0' = disable) -0 # 4m1a ('1' = enable; '0' = disable) -0 # 1m ('1' = enable; '0' = disable) +1 # 1m ('1' = enable; '0' = disable) 1 # native ('1' = enable; '0' = disable) 1 # Simulate application-level threading: # '1' = disable / use one testsuite thread; diff --git a/testsuite/input.general.mixed b/testsuite/input.general.mixed index 55a3f56c75..36a3e62a67 100644 --- a/testsuite/input.general.mixed +++ b/testsuite/input.general.mixed @@ -31,11 +31,6 @@ sdcz # Datatype(s) to test: 500 # Problem size: maximum to test 100 # Problem size: increment between experiments # Complex level-3 implementations to test: -0 # 3mh ('1' = enable; '0' = disable) -0 # 3m1 ('1' = enable; '0' = disable) -0 # 4mh ('1' = enable; '0' = disable) -0 # 4m1b ('1' = enable; '0' = disable) -0 # 4m1a ('1' = enable; '0' = disable) 1 # 1m ('1' = enable; '0' = disable) 1 # native ('1' = enable; '0' = disable) 1 # Simulate application-level threading: diff --git a/testsuite/input.general.salt b/testsuite/input.general.salt index ad52b68bba..2e8b8a284e 100644 --- a/testsuite/input.general.salt +++ b/testsuite/input.general.salt @@ -31,11 +31,6 @@ sdcz # Datatype(s) to test: 100 # Problem size: maximum to test 100 # Problem size: increment between experiments # Complex level-3 implementations to test: -0 # 3mh ('1' = enable; '0' = disable) -0 # 3m1 ('1' = enable; '0' = disable) -0 # 4mh ('1' = enable; '0' = disable) -0 # 4m1b ('1' = enable; '0' = disable) -0 # 4m1a ('1' = enable; '0' = disable) 1 # 1m ('1' = enable; '0' = disable) 1 # native ('1' = enable; '0' = disable) 4 # Simulate application-level threading: diff --git a/testsuite/src/test_hemm.c b/testsuite/src/test_hemm.c index faff475969..433b0af34b 100644 --- a/testsuite/src/test_hemm.c +++ b/testsuite/src/test_hemm.c @@ -287,8 +287,6 @@ void libblis_test_hemm_impl { case BLIS_TEST_SEQ_FRONT_END: bli_hemm( side, alpha, a, b, beta, c ); - //bli_hemm4m( side, alpha, a, b, beta, c ); - //bli_hemm3m( side, alpha, a, b, beta, c ); break; default: diff --git a/testsuite/src/test_her2k.c b/testsuite/src/test_her2k.c index 8c4b69d4cb..bef3cbb902 100644 --- a/testsuite/src/test_her2k.c +++ b/testsuite/src/test_her2k.c @@ -285,8 +285,6 @@ void libblis_test_her2k_impl { case BLIS_TEST_SEQ_FRONT_END: bli_her2k( alpha, a, b, beta, c ); - //bli_her2k4m( alpha, a, b, beta, c ); - //bli_her2k3m( alpha, a, b, beta, c ); break; default: diff --git a/testsuite/src/test_herk.c b/testsuite/src/test_herk.c index 0c9a8eb437..9212e6e69d 100644 --- a/testsuite/src/test_herk.c +++ b/testsuite/src/test_herk.c @@ -276,8 +276,6 @@ void libblis_test_herk_impl { case BLIS_TEST_SEQ_FRONT_END: bli_herk( alpha, a, beta, c ); - //bli_herk4m( alpha, a, beta, c ); - //bli_herk3m( alpha, a, beta, c ); break; default: diff --git a/testsuite/src/test_libblis.c b/testsuite/src/test_libblis.c index 566701bfcc..269d673690 100644 --- a/testsuite/src/test_libblis.c +++ b/testsuite/src/test_libblis.c @@ -550,26 +550,6 @@ void libblis_test_read_params_file( char* input_filename, test_params_t* params libblis_test_read_next_line( buffer, input_stream ); sscanf( buffer, "%u ", &(params->p_inc) ); - // Read whether to enable 3mh. - libblis_test_read_next_line( buffer, input_stream ); - sscanf( buffer, "%u ", &(params->ind_enable[ BLIS_3MH ]) ); - - // Read whether to enable 3m1. - libblis_test_read_next_line( buffer, input_stream ); - sscanf( buffer, "%u ", &(params->ind_enable[ BLIS_3M1 ]) ); - - // Read whether to enable 4mh. - libblis_test_read_next_line( buffer, input_stream ); - sscanf( buffer, "%u ", &(params->ind_enable[ BLIS_4MH ]) ); - - // Read whether to enable 4m1b (4mb). - libblis_test_read_next_line( buffer, input_stream ); - sscanf( buffer, "%u ", &(params->ind_enable[ BLIS_4M1B ]) ); - - // Read whether to enable 4m1a (4m1). - libblis_test_read_next_line( buffer, input_stream ); - sscanf( buffer, "%u ", &(params->ind_enable[ BLIS_4M1A ]) ); - // Read whether to enable 1m. libblis_test_read_next_line( buffer, input_stream ); sscanf( buffer, "%u ", &(params->ind_enable[ BLIS_1M ]) ); @@ -589,24 +569,13 @@ void libblis_test_read_params_file( char* input_filename, test_params_t* params // threads. if ( params->n_app_threads > 1 ) { - if ( params->ind_enable[ BLIS_3MH ] || - params->ind_enable[ BLIS_3M1 ] || - params->ind_enable[ BLIS_4MH ] || - params->ind_enable[ BLIS_4M1B ] || - params->ind_enable[ BLIS_4M1A ] || - params->ind_enable[ BLIS_1M ] - ) + if ( params->ind_enable[ BLIS_1M ] ) { // Due to an inherent race condition in the way induced methods // are enabled and disabled at runtime, all induced methods must be // disabled when simulating multiple application threads. libblis_test_printf_infoc( "simulating multiple application threads; disabling induced methods.\n" ); - params->ind_enable[ BLIS_3MH ] = 0; - params->ind_enable[ BLIS_3M1 ] = 0; - params->ind_enable[ BLIS_4MH ] = 0; - params->ind_enable[ BLIS_4M1B ] = 0; - params->ind_enable[ BLIS_4M1A ] = 0; params->ind_enable[ BLIS_1M ] = 0; } } @@ -1231,11 +1200,6 @@ void libblis_test_output_params_struct( FILE* os, test_params_t* params ) libblis_test_fprintf_c( os, "problem size: max to test %u\n", params->p_max ); libblis_test_fprintf_c( os, "problem size increment %u\n", params->p_inc ); libblis_test_fprintf_c( os, "complex implementations \n" ); - libblis_test_fprintf_c( os, " 3mh? %u\n", params->ind_enable[ BLIS_3MH ] ); - libblis_test_fprintf_c( os, " 3m1? %u\n", params->ind_enable[ BLIS_3M1 ] ); - libblis_test_fprintf_c( os, " 4mh? %u\n", params->ind_enable[ BLIS_4MH ] ); - libblis_test_fprintf_c( os, " 4m1b (4mb)? %u\n", params->ind_enable[ BLIS_4M1B ] ); - libblis_test_fprintf_c( os, " 4m1a (4m1)? %u\n", params->ind_enable[ BLIS_4M1A ] ); libblis_test_fprintf_c( os, " 1m? %u\n", params->ind_enable[ BLIS_1M ] ); libblis_test_fprintf_c( os, " native? %u\n", params->ind_enable[ BLIS_NAT ] ); libblis_test_fprintf_c( os, "simulated app-level threads %u\n", params->n_app_threads ); @@ -1790,8 +1754,8 @@ void libblis_test_op_driver } } - // Enumerate all combinations of datatype domains requested, but only - // for the gemm operation. + // Enumerate all combinations of datatypes requested, but only for the + // gemm operation. if ( !mixed_domain && mixed_precision && op->opid == BLIS_GEMM ) { @@ -2564,7 +2528,7 @@ void fill_string_with_n_spaces( char* str, unsigned int n_spaces ) { unsigned int i; - // Initialze to empty string in case n_spaces == 0. + // Initialize to empty string in case n_spaces == 0. sprintf( str, "%s", "" ); for ( i = 0; i < n_spaces; ++i ) diff --git a/testsuite/src/test_symm.c b/testsuite/src/test_symm.c index 50d6315bd2..87c6044f0f 100644 --- a/testsuite/src/test_symm.c +++ b/testsuite/src/test_symm.c @@ -287,8 +287,6 @@ void libblis_test_symm_impl { case BLIS_TEST_SEQ_FRONT_END: bli_symm( side, alpha, a, b, beta, c ); - //bli_symm4m( side, alpha, a, b, beta, c ); - //bli_symm3m( side, alpha, a, b, beta, c ); break; default: diff --git a/testsuite/src/test_syr2k.c b/testsuite/src/test_syr2k.c index 7c86bd46fc..671e189f8a 100644 --- a/testsuite/src/test_syr2k.c +++ b/testsuite/src/test_syr2k.c @@ -285,8 +285,6 @@ void libblis_test_syr2k_impl { case BLIS_TEST_SEQ_FRONT_END: bli_syr2k( alpha, a, b, beta, c ); - //bli_syr2k4m( alpha, a, b, beta, c ); - //bli_syr2k3m( alpha, a, b, beta, c ); break; default: diff --git a/testsuite/src/test_syrk.c b/testsuite/src/test_syrk.c index e54471edec..c94b0bacdb 100644 --- a/testsuite/src/test_syrk.c +++ b/testsuite/src/test_syrk.c @@ -276,8 +276,6 @@ void libblis_test_syrk_impl { case BLIS_TEST_SEQ_FRONT_END: bli_syrk( alpha, a, beta, c ); - //bli_syrk4m( alpha, a, beta, c ); - //bli_syrk3m( alpha, a, beta, c ); break; default: diff --git a/testsuite/src/test_trmm.c b/testsuite/src/test_trmm.c index 24f00dc5b2..f31d6db9a4 100644 --- a/testsuite/src/test_trmm.c +++ b/testsuite/src/test_trmm.c @@ -272,8 +272,6 @@ void libblis_test_trmm_impl { case BLIS_TEST_SEQ_FRONT_END: bli_trmm( side, alpha, a, b ); - //bli_trmm4m( side, alpha, a, b ); - //bli_trmm3m( side, alpha, a, b ); break; default: diff --git a/testsuite/src/test_trmm3.c b/testsuite/src/test_trmm3.c index 7c1789eb30..b50548b84f 100644 --- a/testsuite/src/test_trmm3.c +++ b/testsuite/src/test_trmm3.c @@ -288,8 +288,6 @@ void libblis_test_trmm3_impl { case BLIS_TEST_SEQ_FRONT_END: bli_trmm3( side, alpha, a, b, beta, c ); - //bli_trmm34m( side, alpha, a, b, beta, c ); - //bli_trmm33m( side, alpha, a, b, beta, c ); break; default: diff --git a/travis/do_sde.sh b/travis/do_sde.sh index efaf563b4b..c8eb5aa585 100755 --- a/travis/do_sde.sh +++ b/travis/do_sde.sh @@ -16,8 +16,16 @@ SDE=$SDE_VERSION/sde64 #curl --verbose --cookie jar.txt --output $SDE_TARBALL \ # https://software.intel.com/system/files/managed/2a/1a/$SDE_TARBALL -curl --verbose --output $SDE_TARBALL \ - https://software.intel.com/content/dam/develop/external/us/en/documents/downloads/$SDE_TARBALL +#curl --verbose --output $SDE_TARBALL \ +# https://software.intel.com/content/dam/develop/external/us/en/documents/downloads/$SDE_TARBALL + +CI_UTILS=ci-utils +CI_UTILS_URL=https://github.com/flame/${CI_UTILS}.git +CI_UTILS_SDE_DIR=sde +SDE_DIRPATH=$CI_UTILS/$CI_UTILS_SDE_DIR + +git clone $CI_UTILS_URL +mv $SDE_DIRPATH/$SDE_TARBALL . tar xvf $SDE_TARBALL @@ -37,7 +45,8 @@ for LIB in $LD_SO $LIBC_SO $LIBM_SO; do sudo mv .tmp $LIB done -for ARCH in penryn sandybridge haswell skx knl piledriver steamroller excavator zen; do +#for ARCH in penryn sandybridge haswell skx knl piledriver steamroller excavator zen; do +for ARCH in penryn sandybridge haswell skx knl zen; do if [ "$ARCH" = "knl" ]; then $SDE -knl -- ./test_libblis.x > output.testsuite else diff --git a/vendor/testcpp/CMakeLists.txt b/vendor/testcpp/CMakeLists.txt index 4e29b747ea..e14e6e5440 100644 --- a/vendor/testcpp/CMakeLists.txt +++ b/vendor/testcpp/CMakeLists.txt @@ -1,4 +1,36 @@ -##Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved.## +#[=[ + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2020 - 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +]=] # Comments: # - DIST_PATH is assumed to not exist if BLIS_INSTALL_PATH is given. @@ -10,7 +42,7 @@ if(NOT DEFINED BLIS_INSTALL_PATH) set(INC_PATH ${DIST_PATH}/include/${BLIS_CONFIG_FAMILY}) else() set(LIB_PATH ${BLIS_INSTALL_PATH}/lib) - set(INC_PATH ${BLIS_INSTALL_PATH}/include/blis) + set(INC_PATH ${BLIS_INSTALL_PATH}/include/${BLIS_CONFIG_FAMILY}) endif() # Include the corresponding make_defs.cmake that holds the required compiler options. @@ -51,9 +83,13 @@ foreach(source ${testcpp_sources}) ${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_SOURCE_DIR}/vendor/cpp ) - target_link_libraries(${exec_name} PRIVATE ${LDFLAGS} libblis) + target_link_libraries(${exec_name} PRIVATE ${LDFLAGS} ${libblis_link}) if(THREADING_MODEL STREQUAL "openmp") - target_link_libraries(${exec_name} PRIVATE OpenMP::OpenMP_C) + if((NOT ${OpenMP_libomp_LIBRARY} STREQUAL "") AND (NOT WIN32)) + target_link_libraries(${exec_name} PRIVATE ${OpenMP_libomp_LIBRARY}) + else() + target_link_libraries(${exec_name} PRIVATE OpenMP::OpenMP_C) + endif() endif() set_target_properties(${exec_name} PROPERTIES CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) # Put all those targets under vendor-testcpp-targets folder name so that they appear all together in IDE. diff --git a/vendor/testcpp/Makefile b/vendor/testcpp/Makefile index 36b2726a2e..c723400d05 100644 --- a/vendor/testcpp/Makefile +++ b/vendor/testcpp/Makefile @@ -1,9 +1,11 @@ +# +# # BLIS # An object-based framework for developing high-performance BLAS-like # libraries. # # Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2017 - 2023, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2017 - 2024, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are diff --git a/version b/version index 6aba2b245a..0062ac9718 100644 --- a/version +++ b/version @@ -1 +1 @@ -4.2.0 +5.0.0